
RuntimeError: expected scalar type Long but found Int in [cd_spot_the_diff_mnist_wine.ipynb] #411

Closed · tomaszek0 opened this issue Dec 16, 2021 · 17 comments · Fixed by #423
Labels: Priority: High · Type: Bug (Something isn't working)

@tomaszek0

I am getting the following error when trying to execute the code in cell [10] of the section "Interpretable Drift Detection on the Wine Quality Dataset":

RuntimeError Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_3564/2414863533.py in
10 )
11
---> 12 preds_h0 = cd.predict(x_h0)
13 preds_corr = cd.predict(x_corr)

~\AppData\Roaming\Python\Python39\site-packages\alibi_detect\cd\spot_the_diff.py in predict(self, x, return_p_val, return_distance, return_probs, return_model)
173 data, and the trained model.
174 """
--> 175 return self._detector.predict(x, return_p_val, return_distance, return_probs, return_model)

~\AppData\Roaming\Python\Python39\site-packages\alibi_detect\cd\pytorch\spot_the_diff.py in predict(self, x, return_p_val, return_distance, return_probs, return_model)
212 data, and the trained model.
213 """
--> 214 preds = self._detector.predict(x, return_p_val, return_distance, return_probs, return_model=True)
215 preds['data']['diffs'] = preds['data']['model'].diffs.detach().cpu().numpy() # type: ignore
216 preds['data']['diff_coeffs'] = preds['data']['model'].coeffs.detach().cpu().numpy() # type: ignore

~\AppData\Roaming\Python\Python39\site-packages\alibi_detect\cd\base.py in predict(self, x, return_p_val, return_distance, return_probs, return_model)
241 """
242 # compute drift scores
--> 243 p_val, dist, probs_ref, probs_test = self.score(x)
244 drift_pred = int(p_val < self.p_val)
245

~\AppData\Roaming\Python\Python39\site-packages\alibi_detect\cd\pytorch\classifier.py in score(self, x)
182 self.model = self.model.to(self.device)
183 train_args = [self.model, self.loss_fn, dl_tr, self.device]
--> 184 trainer(*train_args, **self.train_kwargs) # type: ignore
185 preds = self.predict_fn(x_te, self.model.eval())
186 preds_oof_list.append(preds)

~\AppData\Roaming\Python\Python39\site-packages\alibi_detect\models\pytorch\trainer.py in trainer(model, loss_fn, dataloader, device, optimizer, learning_rate, preprocess_fn, epochs, reg_loss_fn, verbose)
53 y_hat = model(x)
54 optimizer.zero_grad() # type: ignore
---> 55 loss = loss_fn(y_hat, y) + reg_loss_fn(model)
56 loss.backward()
57 optimizer.step() # type: ignore

~\anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []

~\anaconda3\lib\site-packages\torch\nn\modules\loss.py in forward(self, input, target)
1148
1149 def forward(self, input: Tensor, target: Tensor) -> Tensor:
-> 1150 return F.cross_entropy(input, target, weight=self.weight,
1151 ignore_index=self.ignore_index, reduction=self.reduction,
1152 label_smoothing=self.label_smoothing)

~\anaconda3\lib\site-packages\torch\nn\functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
2844 if size_average is not None or reduce is not None:
2845 reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2846 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
2847
2848

RuntimeError: expected scalar type Long but found Int

I use Python 3.8.8 / Win10, installed on an AMD Ryzen with integrated AMD graphics.

@jklaise
Member

jklaise commented Jan 5, 2022

@tomaszek0 what's your torch version?

@tomaszek0
Author

@jklaise
print(torch.__version__) gave this output: 1.10.0

!pip install torch gave this output:
Requirement already satisfied: torch in c:\users\tsobi\anaconda3\lib\site-packages (1.10.0)
Requirement already satisfied: typing_extensions in c:\users\tsobi\appdata\roaming\python\python39\site-packages (from torch) (3.7.4.3)

@jklaise
Member

jklaise commented Jan 6, 2022

I can't reproduce on a Linux machine.

Have you had any problems training regular Pytorch models? Specifically evaluating loss functions (which is where the error comes from)? This is most likely a Pytorch issue, possibly related to your setup (Windows + AMD CPU); there are a few issues that come up about this: https://github.com/pytorch/pytorch/search?q=expected+scalar+type+long+but+found+int&type=issues

I would suggest trying to train some vanilla Pytorch models first and see if the same or similar issues happen.
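
For example, something as small as the following exercises the same CrossEntropyLoss-on-integer-labels path outside of alibi-detect (just a throwaway sketch on random data):

import torch
import torch.nn as nn

# tiny classifier on random data; the point is only to hit the same loss call as the trainer
x = torch.randn(64, 11)
y = torch.randint(0, 2, (64,))                 # integer labels, int64 by default
model = nn.Sequential(nn.Linear(11, 16), nn.ReLU(), nn.Linear(16, 2))
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(10):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
print(loss.item())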

@tomaszek0
Author

tomaszek0 commented Jan 7, 2022

@jklaise, Thank you for your suggestions.
Testing on a CUDA-free system (AMD Ryzen 5-4500U CPU with Radeon Graphics; driver no. 27.20.21020.4003, 2021):

The output of the code (import torch // torch.cuda.is_available()): False

  1. I checked my PyTorch installation according to https://docs.microsoft.com/en-us/windows/ai/windows-ml/tutorials/pytorch-installation (which follows the official PyTorch instructions). At this level of inquiry, it is ok.

  2. I ran code from this link (https://docs.microsoft.com/en-us/windows/ai/windows-ml/tutorials/pytorch-data) and the next tutorial pages to test the PyTorch execution flow on my computer. After removing the error (a known issue, 37), it produced the expected result using both the Jupyter Notebook and Visual Studio 2019. The model reported that “The model will be running on cpu device”.

  3. I again executed the "Interpretable Drift Detection on the Wine Quality Dataset" tutorial - I obtained the same runtime error in the Jupyter notebook, along with the info "No GPU detected, fall back on CPU". Btw, the execution of the first code lines in Visual Studio failed and gave this output:

W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'cudart64_110.dll'; dlerror: cudart64_110.dll not found
I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'nvcuda.dll'; dlerror: nvcuda.dll not found
W tensorflow/stream_executor/cuda/cuda_driver.cc:269] failed call to cuInit: UNKNOWN ERROR (303)
I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: DESKTOP-8J5K8HK
I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: DESKTOP-8J5K8HK
I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX AVX2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.

Testing on an Intel Core i5 CPU and NVIDIA GTX 460 GPU (336 CUDA cores, 32 ROPs, and 56 texture units; driver no. 23.21.13.8813, 2017):

The output of the code (import torch // torch.cuda.is_available()): False

  1. PyTorch installation checking as above; result: ok

  2. a tutorial script executed as above (the CIFAR10 dataset); result: 2 x ok, but computation was slower. The model gave info that “The model will be running on cpu device”.

  3. a tutorial script executed as above (the Wine dataset); result in the Jupyter Notebook: RuntimeError: expected scalar type Long but found Int; result in Visual Studio: failed and gave this output:

Traceback (most recent call last):
File "…\anaconda3\lib\runpy.py", line 194, in _run_module_as_main
return _run_code(code, main_globals, None,
File "…\anaconda3\lib\runpy.py", line 87, in run_code
exec(code, run_globals)
File "…\microsoft visual studio\2019\community\common7\ide\extensions\microsoft\python\core\debugpy_main
.py", line 45, in
cli.main()
File "…\microsoft visual studio\2019\community\common7\ide\extensions\microsoft\python\core\debugpy/..\debugpy\server\cli.py", line 430, in main
run()
File "…\microsoft visual studio\2019\community\common7\ide\extensions\microsoft\python\core\debugpy/..\debugpy\server\cli.py", line 267, in run_file
runpy.run_path(options.target, run_name=compat.force_str("__main__"))
File "…\anaconda3\lib\runpy.py", line 265, in run_path
return _run_module_code(code, init_globals, run_name,
File "…\anaconda3\lib\runpy.py", line 97, in _run_module_code
_run_code(code, mod_globals, init_globals,
File "…\anaconda3\lib\runpy.py", line 87, in _run_code
exec(code, run_globals)
File "…\source\repos\Python_Interpretable_drift_Wine\Python_Interpretable_drift_Wine.py", line 41, in
preds_h0 = cd.predict(x_h0)
File "…\anaconda3\lib\site-packages\alibi_detect\cd\spot_the_diff.py", line 175, in predict
return self._detector.predict(x, return_p_val, return_distance, return_probs, return_model)
File "…\anaconda3\lib\site-packages\alibi_detect\cd\pytorch\spot_the_diff.py", line 214, in predict
preds = self._detector.predict(x, return_p_val, return_distance, return_probs, return_model=True)
File "…\anaconda3\lib\site-packages\alibi_detect\cd\base.py", line 243, in predict
p_val, dist, probs_ref, probs_test = self.score(x)
File "…\anaconda3\lib\site-packages\alibi_detect\cd\pytorch\classifier.py", line 184, in score
trainer(*train_args, **self.train_kwargs) # type: ignore
File "…\anaconda3\lib\site-packages\alibi_detect\models\pytorch\trainer.py", line 55, in trainer
loss = loss_fn(y_hat, y) + reg_loss_fn(model)
File "…\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "…\anaconda3\lib\site-packages\torch\nn\modules\loss.py", line 1150, in forward
return F.cross_entropy(input, target, weight=self.weight,
File "…\anaconda3\lib\site-packages\torch\nn\functional.py", line 2846, in cross_entropy
return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
RuntimeError: expected scalar type Long but found Int

preds_h0 = cd.predict(x_h0) RuntimeError: expected scalar type Long but found Int

Testing on an Intel Core i5 CPU and NVIDIA GTX 460 GPU (336 CUDA cores, 32 ROPs, and 56 texture units; updated driver no. 23.21.13.9135, 2018):

The output of the code (import torch // torch.cuda.is_available()): False

As mentioned on https://discuss.pytorch.org/t/pytorch-nvidia-gtx460-version/62461 GTX 460 is too weak for PyTorch.

@jklaise
Member

jklaise commented Jan 7, 2022

@tomaszek0 thanks for that, this gives something to work with. It's encouraging that the basic Pytorch training tutorial works.

On the 1st CUDA-free system you say that the first few lines failed, but it looks like it's just the usual TensorFlow behaviour of looking for CUDA, not finding it, and falling back to CPU. Can you confirm that no actual runtime error is raised?

It looks like for all 3 scenarios the execution was attempted on CPU (due to the older, unsupported GPU as you mention). The code should execute the same way on either CPU or GPU, so I think we can disregard any GPU presence for further debugging.

I think the best thing to do is to check with a debugger what the types are before hitting the error. The steps would be: make a script with the code that causes the error, set a breakpoint at the line where the error happens (preds_h0 = cd.predict(x_h0)), step through the debugger until you hit line 55 in alibi_detect.models.pytorch.trainer, and look at the dtype of y_hat, y, loss_fn(y_hat, y) and reg_loss_fn(model). From the logs it looks like the issue is with the loss_fn part, as that is the cross_entropy call.

For reference, running this in my environment I get the following dtypes:
y_hat: torch.float32
y: torch.int64 <-- this might be the culprit; int64 is torch's Long type, so if yours differs the loss will complain
loss_fn(y_hat, y): torch.float32 <-- I suspect this is the call that fails for you
reg_loss_fn(model): torch.float32

Would you be able to run through these steps?

EDIT: Even more simply, can you evaluate torch.nn.CrossEntropyLoss()(y_hat, y), where y_hat is a torch vector of torch.float32 type and y is a torch vector of int64 type?
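
For example (a rough standalone sketch; the tensor shapes here are arbitrary):

import torch

loss_fn = torch.nn.CrossEntropyLoss()
y_hat = torch.randn(8, 2)                              # float32 logits
y = torch.randint(0, 2, (8,), dtype=torch.int64)       # int64 (Long) targets

print(loss_fn(y_hat, y))                               # should work
print(loss_fn(y_hat, y.to(torch.int32)))               # on torch 1.10 this reproduces "expected scalar type Long but found Int"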

@tomaszek0
Author

@jklaise, so far here is what I have found.
As to the first point: in the case of the "Interpretable Drift Detection on the Wine Quality Dataset" tutorial, I unfortunately got the same logs as listed above (section 1.3) after running the import-libraries section (the first code segment of the tutorial) in Visual Studio (on the AMD-based system).

It is a bit strange. So for comparison, I executed the tutorial entitled “Online detection with MMD (Maximum Mean Discrepancy) and Pytorch” (https://docs.seldon.io/projects/alibi-detect/en/latest/examples/cd_online_wine.html). The whole tutorial code gave the expected output on both the AMD- and Intel-based systems (Jupyter Notebook). Running the code on the AMD-based system in Visual Studio gave the usual logs:

“W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'cudart64_110.dll'; dlerror: cudart64_110.dll not found
I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.”

and yield the expected result (chart).

After executing the next code segment, which defines the online drift detector ("from alibi_detect.cd import MMDDriftOnline"), it gave these logs and the expected output:

“W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'cudart64_110.dll'; dlerror: cudart64_110.dll not found
I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'nvcuda.dll'; dlerror: nvcuda.dll not found
W tensorflow/stream_executor/cuda/cuda_driver.cc:269] failed call to cuInit: UNKNOWN ERROR (303)
I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: DESKTOP-8J5K8HK
I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: DESKTOP-8J5K8HK
I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX AVX2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
No GPU detected, fall back on CPU.
Generating permutations of kernel matrix...

Running the next code segments also yields the expected output (though some computations complete without displaying the charts).
Running the code on the Intel/NVIDIA-based system in Visual Studio also gave the expected result.

I also checked the outputs of the code:
import torch
torch.cuda.is_available()
False

print(torch.rand(3,3).cuda())
AssertionError: Torch not compiled with CUDA enabled

conda install cuda -c nvidia
All requested packages already installed.

It seems that installing TensorFlow and PyTorch in the same environment results in some issues.

@jklaise
Member

jklaise commented Jan 10, 2022

@tomaszek0 thanks for that. I'm not sure I fully follow: it seems you have both tensorflow (required by alibi-detect) and torch (optional for alibi-detect) installed, and it's only certain notebooks that give you issues? Did you manage to check if you can evaluate the torch loss with the different types?

@tomaszek0
Author

@jklaise,
Yes, the "Online detection..." tutorial worked fine, while the "Interpretable drift..." tutorial gave an error only on the "Wine" dataset.
Unfortunately, I am a beginner in Visual Studio. I got stuck at the breakpoint in Visual Studio: after debugging, I see cd.predict and <bound method SpotTheDiffDrift.predict of <alibi_detect.cd.spot_the_diff.SpotTheDiffDrift object at 0x000001CC7F210FA0>> on the right, and I can expand this to a list of special variables.

@jklaise
Member

jklaise commented Jan 11, 2022

@tomaszek0 another way of doing it is writing breakpoint() in the Python script; when executed in the terminal it will drop you into the Python debugger, from where you can navigate with shortcuts: https://docs.python.org/3/library/pdb.html#debugger-commands. It's more primitive but can often get the job done.
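
For example (a rough sketch; x_ref and x_h0 here are random stand-ins for the wine data, so build them exactly as in the notebook instead):

import numpy as np
from alibi_detect.cd import SpotTheDiffDrift

x_ref = np.random.randn(200, 12).astype(np.float32)   # stand-in for the reference wine features
x_h0 = np.random.randn(200, 12).astype(np.float32)    # stand-in for the test split

cd = SpotTheDiffDrift(x_ref, backend='pytorch', n_diffs=1)
breakpoint()                     # execution stops here and drops into pdb
preds_h0 = cd.predict(x_h0)      # 's' steps into the call, 'n' steps over, 'p y.dtype' prints, 'c' continues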

@tomaszek0
Author

tomaszek0 commented Jan 11, 2022

@jklaise, thanks for the lead. I found something like this (in pdb).
After debugging the line that causes the error "RuntimeError: expected scalar type Long but found Int":

...\anaconda\...\alibi_detect\models\pytorch\trainer.py(55)trainer()
-> loss = loss_fn(y_hat, y) + reg_loss_fn(model)
(Pdb) args loss
model = InterpretableClf( ..
(Pdb) ll
55 -> loss = loss_fn ...
(Pdb) p loss_fn
CrossEntropyLoss()
(Pdb) p y
tensor ... , dtype=torch.int32
(Pdb) p y_hat
tensor ... , grad_fn=<CatBackward0>
(Pdb) p loss_fn(y_hat, y)
*** RuntimeError: expected scalar type Long but found Int
(Pdb) p reg_loss_fn(model)
tensor(5.0263e-06, grad_fn=<MulBackward0>)

@jklaise
Member

jklaise commented Jan 11, 2022

@tomaszek0 can you try evaluating loss_fn(y_hat.detach(), y)? Basically the .detach() gets rid of gradient information so you're left with pure float32 and int32 tensors.

Curiously, on my machine y is of type torch.int64, whereas yours is torch.int32; that difference could be the source of the problem. It would be worth tracing the type of y backwards from the dataloader it comes from to see what type it is initially.
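
Concretely, at that breakpoint the two expressions to try would be (the second one casts the targets to int64/Long, so if it succeeds the target dtype is the problem):

(Pdb) p loss_fn(y_hat.detach(), y)
(Pdb) p loss_fn(y_hat, y.long())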

@jklaise
Member

jklaise commented Jan 11, 2022

Oh dear, I think I've found the issue and it's a numpy bug in Windows... numpy/numpy#17640

Here we create labels to distinguish the reference and test dataset, which should just be integers; this is the astype(int) call in alibi_detect.cd.base.

According to the above, on Windows astype(int) results in int32, whilst elsewhere it is the expected int64. And unfortunately it appears that evaluating loss functions with int32 targets doesn't work in torch.

I think a fix from our end would be to change from astype(int) to astype(np.int64) as explained in the Github issue. You could confirm whether this works by manually editing the line in alibi_detect.cd.base.
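
A quick way to confirm the behaviour and the proposed cast (a sketch, not the exact alibi-detect code):

import numpy as np
import torch
import torch.nn.functional as F

# ndarray.astype(int) uses the platform's default integer: int32 on Windows, int64 on most Linux/macOS builds
y_np = np.concatenate([np.zeros(5), np.ones(5)]).astype(int)
print(y_np.dtype)                            # int32 on Windows, int64 elsewhere

y = torch.from_numpy(y_np)                   # torch keeps the numpy dtype
logits = torch.randn(10, 2)
# F.cross_entropy(logits, y)                 # fails with "expected scalar type Long but found Int" when y is int32
print(F.cross_entropy(logits, y.long()))     # casting to int64 (Long) works, which is what astype(np.int64) achieves upstream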

@tomaszek0
Author

@jklaise, I obtained the same debugging output on Intel and AMD-based machines.

(Pdb) p loss_fn(y_hat.detach(), y)
*** RuntimeError: expected scalar type Long but found Int

@tomaszek0
Author

tomaszek0 commented Jan 11, 2022

@jklaise,
This solution was successful: I obtained the expected output when running the code. It is a simple but efficient fix. Thanks for guiding me through this task.
However, I noticed that the final chart does not contain all the bars shown in the tutorial (there are no bars for pH, density, and citric acid).

jklaise self-assigned this on Jan 12, 2022
jklaise added the Priority: High and Type: Bug labels on Jan 12, 2022
@jklaise
Member

jklaise commented Jan 12, 2022

@tomaszek0 great to hear that worked. We will be submitting a fix for this soon as well as executing our test suite on a Windows platform to catch these types of bugs early.

For the "no bars" issue, are they completely absent from the figure or just close to zero? Perhaps @ojcobb can comment.

@tomaszek0
Author

@jklaise,
I checked again several times; the result is different each time. Perhaps this is typical classifier behaviour, but it is worrying that the classification result may appear on the right side of the final chart in one run and on the left side in another. This completely changes the interpretation of the data.

@jklaise
Member

jklaise commented Jan 14, 2022

@tomaszek0 we're aware of this (see #390); there is still some uncontrolled randomness which requires further investigation. That being said, for these use cases there is some merit in not relying on reproducibility, as one may fall into the trap of getting lucky with the randomness and obtaining results that are not fully representative (although I'm aware this raises some questions about the robustness of the method @ojcobb @ascillitoe).
