Loss function that takes a data dependent 3rd input #21
@freud14 Just tagging you because you made the last commit on the repo. Please loop in anyone else who might be a better fit to answer this question. |
Hi,

```python
def custom_loss(output, y_true):
    y_pred, input_tensor = output
    return F.binary_cross_entropy(y_pred, y_true) + input_tensor.mean()

class MyPyTorchModule(nn.Module):
    def forward(self, input):
        ...
        return y_pred, input_tensor  # or y_pred, input_tensor.mean() in your example
```

Thank you.
Frédérik |
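The pattern above can be sketched end to end in plain PyTorch. This is a minimal sketch, not this repo's API: the module internals, layer sizes, and tensor shapes are made-up placeholders; only the "return a tuple from `forward`, unpack it in the loss" idea comes from the answer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def custom_loss(output, y_true):
    # The model's forward returns a tuple; unpack it inside the loss.
    y_pred, input_tensor = output
    return F.binary_cross_entropy(y_pred, y_true) + input_tensor.mean()

class MyPyTorchModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)  # hypothetical layer for the sketch

    def forward(self, input):
        y_pred = torch.sigmoid(self.linear(input))
        # Return the input alongside the prediction so the loss can use it.
        return y_pred, input

x = torch.rand(8, 4)
y = torch.randint(0, 2, (8, 1)).float()
model = MyPyTorchModule()
loss = custom_loss(model(x), y)
loss.backward()  # loss is a scalar, so backward works as usual
```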
Awesome! Definitely felt there would be a simple solution to this. Thank you :) |
Oops, looks like I closed it by accident. I'm looking at a case that is slightly different from the examples above. I have a Dataset that outputs 3 tensors (say `x`, `y` and `mask`). Now, I'm implementing this.
Is there a better way to do this? |
Not sure to understand your case, but something like this should do it:

```python
def custom_loss(y_pred, input):
    y_true, mask = input
    return F.cross_entropy(input=y_pred, target=y_true, reduction='none')[mask]
```

Where your dataset has to output a tuple like `(x, (y_true, mask))`. Btw, if you did |
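On toy tensors, the masked loss above behaves like this. The shapes, logits, and boolean mask below are made up purely for illustration:

```python
import torch
import torch.nn.functional as F

def custom_loss(y_pred, input):
    y_true, mask = input
    # Per-sample losses (reduction='none'), then keep only masked positions.
    return F.cross_entropy(input=y_pred, target=y_true, reduction='none')[mask]

y_pred = torch.randn(5, 3)               # logits for 5 samples, 3 classes
y_true = torch.tensor([0, 2, 1, 1, 0])   # class indices
mask = torch.tensor([True, False, True, True, False])

losses = custom_loss(y_pred, (y_true, mask))
print(losses.shape)  # torch.Size([3]) -- one loss per unmasked sample
```

Note that this returns a vector of per-sample losses; you would typically take its `.mean()` before calling `backward()`.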
The output of the dataset is a tuple like `(x, y, mask)`. Also, would pytoune be able to infer the device from the original pytorch model? |
Yes. More specifically, you have to pass a function for the `collate_fn` parameter of the `DataLoader` that transforms your list of samples into that format.
Yes, if you've called |
Here is a code skeleton for what I mean:

```python
class ClientDataset(torch.utils.data.Dataset):
    def __getitem__(self, index):
        ...
        return x, y, mask

def my_collate_function(samples):
    x, y, mask = list(zip(*samples))  # transform the sample list into a list for each input
    return x, (y, mask)

loader = torch.utils.data.DataLoader(my_dataset_instance, collate_fn=my_collate_function)
```
|
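A runnable version of that skeleton could look like the following. This is a sketch: the dataset contents are dummy tensors, and it additionally stacks the per-sample tensors into batches (which the skeleton leaves implicit):

```python
import torch

class ClientDataset(torch.utils.data.Dataset):
    def __init__(self, n=10):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        # Dummy data standing in for the real x, y, mask tensors.
        x = torch.rand(4)
        y = torch.tensor(index % 3)
        mask = torch.tensor(index % 2 == 0)
        return x, y, mask

def my_collate_function(samples):
    # Turn a list of (x, y, mask) samples into batched tensors,
    # grouped as (x, (y, mask)) so the loss receives (y, mask) together.
    x, y, mask = zip(*samples)
    return torch.stack(x), (torch.stack(y), torch.stack(mask))

loader = torch.utils.data.DataLoader(
    ClientDataset(), batch_size=5, collate_fn=my_collate_function
)
x_batch, (y_batch, mask_batch) = next(iter(loader))
print(x_batch.shape, y_batch.shape, mask_batch.shape)
# torch.Size([5, 4]) torch.Size([5]) torch.Size([5])
```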
Hi, thank you for the suggestion :) I had something quite similar too. |
Hi, |
Hi, have I answered all your questions? |
Hi @freud14,

Is there any way to implement the functionality described in this post? https://stackoverflow.com/questions/46464549/keras-custom-loss-function-accessing-current-input-pattern

Gist: implementing a loss wrapper like the one in that post, which can be called with `loss=custom_loss_wrapper(model.input)`.
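The idea in the linked post is a closure that captures the input tensor so the returned loss function can use it. That same wrapper pattern can be sketched in plain PyTorch (names mirror the linked post; the tensors below are dummies, and this is not this repo's API):

```python
import torch
import torch.nn.functional as F

def custom_loss_wrapper(input_tensor):
    # Return a loss function that closes over the captured input batch.
    def custom_loss(y_pred, y_true):
        return F.binary_cross_entropy(y_pred, y_true) + input_tensor.mean()
    return custom_loss

x = torch.rand(6, 3)
y_pred = torch.rand(6, 1)
y_true = torch.randint(0, 2, (6, 1)).float()

# Analogous to loss=custom_loss_wrapper(model.input) in the Keras example.
loss_fn = custom_loss_wrapper(x)
loss = loss_fn(y_pred, y_true)
print(loss.dim())  # 0 -- a scalar loss
```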