Runtime Error: Expected object of scalar type Long but got scalar type Float for argument #2 'target' in call to _thnn_nll_loss_forward #6

Open
YCAyca opened this issue Apr 5, 2024 · 1 comment

YCAyca commented Apr 5, 2024

Hello, I'm getting the error mentioned in the title. The documentation for torch.nn.functional.cross_entropy says the target can be either the ground-truth class indices or class probabilities. The second option fits this case, since the targets are not hard 0/1 labels but predictions coming from the pretrained model. The code looks correct for the probability case, but it raises this runtime error, so I had to switch to ground-truth class indices, which seem to work well. I don't know whether this change significantly reduces accuracy, though. Any ideas?

(Screenshot attached: 2024-04-05 18-12-59)
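For reference, here is a minimal sketch of the two target formats that `torch.nn.functional.cross_entropy` accepts. It assumes PyTorch 1.10 or newer, which is when probability ("soft") targets were added. The `_thnn_nll_loss_forward` name in the error suggests an older build, where only class indices are accepted. The last part computes the same soft-target loss by hand with `log_softmax`, which also works on older versions. All shapes and values are made up for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical batch: 4 samples, 3 classes (raw, unnormalized logits)
logits = torch.randn(4, 3)

# 1) Class-index targets: shape (N,), dtype torch.long
index_target = torch.tensor([0, 2, 1, 2])
loss_index = F.cross_entropy(logits, index_target)

# 2) Probability targets: same shape as logits, floating dtype, rows sum to 1
#    (e.g. softmax outputs of a pretrained model; requires PyTorch >= 1.10)
teacher_logits = torch.randn(4, 3)
prob_target = F.softmax(teacher_logits, dim=1)
loss_soft = F.cross_entropy(logits, prob_target)

# 3) The same soft-target loss written by hand, usable on older PyTorch versions
loss_soft_manual = -(prob_target * F.log_softmax(logits, dim=1)).sum(dim=1).mean()

# A one-hot probability target reproduces the class-index loss
one_hot = F.one_hot(index_target, num_classes=3).float()
loss_one_hot = F.cross_entropy(logits, one_hot)
```

The one-hot comparison is a quick sanity check. Hard labels are a special case of probability targets, so switching to class indices throws away only the information carried by the non-maximal probabilities.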

@Thunderbeee (Owner)

To resolve this error, you need to ensure that the target tensor contains long integers representing the class indices. Here are a few possible solutions:
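Convert the target tensor to long integers. This is only valid if the floats already hold whole-number class indices, because .long() truncates toward zero: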

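import torch

# Placeholder float tensor that already holds class indices
target_tensor = torch.tensor([0.0, 2.0, 1.0])
# Cast to torch.long (int64), the dtype expected for class-index targets
target_tensor = target_tensor.long()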
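For the soft-label case in the question, casting alone does not help: probabilities such as 0.7 truncate to 0, so every row collapses to class 0. If hard class indices are what you want from probability targets, taking `argmax` over the class dimension is the appropriate reduction (essentially the workaround described in the question). A small sketch with made-up values:

```python
import torch

# Float targets that already hold whole-number class indices: .long() is safe
index_like = torch.tensor([0.0, 2.0, 1.0])
print(index_like.long())  # tensor([0, 2, 1])

# Soft probability targets (each row sums to 1): .long() truncates every entry to 0
probs = torch.tensor([[0.7, 0.2, 0.1],
                      [0.1, 0.3, 0.6],
                      [0.2, 0.5, 0.3]])
print(probs.long())  # all zeros, so the class information is lost

# Reduce probabilities to hard class indices by taking the most likely class
hard_labels = probs.argmax(dim=1)
print(hard_labels)  # tensor([0, 2, 1])
```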

Create the target tensor with the correct data type:

import torch

# Placeholder class-index labels; replace with your own target data
target_data = [0, 2, 1]
target_tensor = torch.tensor(target_data, dtype=torch.long)
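In a training pipeline it is usually simplest to fix the dtype once, at the point where labels become tensors, instead of casting inside the loss call. Note that torch.tensor already infers int64 from Python ints and the default float dtype (float32) from Python floats. The sketch below uses TensorDataset and DataLoader. The features and labels are placeholders, and the linear layer is a stand-in for a real model.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Placeholder data: 8 samples with 5 features; labels are class indices in [0, 3)
features = torch.randn(8, 5)
labels = [0, 1, 2, 1, 0, 2, 2, 1]

# Build the label tensor as torch.long up front so every batch has the right dtype
dataset = TensorDataset(features, torch.tensor(labels, dtype=torch.long))
loader = DataLoader(dataset, batch_size=4, shuffle=True)

model = nn.Linear(5, 3)  # stand-in classifier with 3 output classes

for x, y in loader:
    logits = model(x)
    loss = F.cross_entropy(logits, y)  # y is already int64, so no cast is needed
    print(y.dtype, loss.item())
```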

Check the data type of your target tensor:

import torch

target_tensor = torch.tensor([0, 2, 1])  # placeholder target
print(target_tensor.dtype)  # torch.int64, i.e. torch.long

If the data type is not torch.long (printed as torch.int64), convert it accordingly.
Once the target holds class indices as torch.long, the "Expected object of scalar type Long but got scalar type Float for argument #2 'target' in call to _thnn_nll_loss_forward" error should go away.
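To confirm the diagnosis, here is a minimal reproduction. A float target of shape (N,) passed as class indices fails inside nll_loss, and the same values cast to torch.long work. The `check_target` helper is hypothetical. It only illustrates a dtype and shape check before the loss call, and it assumes 2D (N, C) logits. Because the exact error text differs between PyTorch versions, the sketch relies only on the exception type.

```python
import torch
import torch.nn.functional as F

def check_target(logits: torch.Tensor, target: torch.Tensor) -> str:
    """Hypothetical helper: report which target format cross_entropy would see.

    Assumes logits have shape (N, C).
    """
    if target.dtype == torch.long and target.shape == logits.shape[:1]:
        return "class indices"
    if target.is_floating_point() and target.shape == logits.shape:
        return "class probabilities (PyTorch >= 1.10)"
    return f"unsupported: dtype={target.dtype}, shape={tuple(target.shape)}"

log_probs = F.log_softmax(torch.randn(3, 4), dim=1)
float_target = torch.tensor([1.0, 3.0, 0.0])

print(check_target(log_probs, float_target))  # unsupported: dtype=torch.float32, shape=(3,)

try:
    F.nll_loss(log_probs, float_target)
    failed = False
except RuntimeError as err:  # message wording varies by PyTorch version
    failed = True
    print("RuntimeError:", err)

long_target = float_target.long()
print(check_target(log_probs, long_target))  # class indices
loss = F.nll_loss(log_probs, long_target)
```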
