
Loss function that takes a data dependent 3rd input #21

Closed · prabhuteja12 opened this issue Feb 18, 2019 · 12 comments

@prabhuteja12
Is there any way to implement the functionality described in this post? https://stackoverflow.com/questions/46464549/keras-custom-loss-function-accessing-current-input-pattern

Gist: implementing something like the code below (Keras), where the loss has access to the model's input tensor:

    from keras import backend as K

    def custom_loss_wrapper(input_tensor):
        def custom_loss(y_true, y_pred):
            return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
        return custom_loss

This can then be called with loss=custom_loss_wrapper(model.input).

@prabhuteja12 (Author)

@freud14 Just tagging you because you made the last commit on the repo. Please loop in anyone else who might be a better fit to answer this question.

@freud14 (Collaborator) commented Mar 4, 2019

Hi,
Sorry for the delay, I didn't get a notification from your initial post. The way to do this in PyToune is to simply return the input (or the input-dependent quantity) as part of the output of your module. Then, in your loss function, you compute your loss and add the input-dependent quantity to it.

    import torch.nn as nn
    import torch.nn.functional as F

    def custom_loss(output, y_true):
        # The module's output is a tuple; unpack the prediction and the
        # input-dependent quantity.
        y_pred, input_tensor = output
        return F.binary_cross_entropy(y_pred, y_true) + input_tensor.mean()

    class MyPyTorchModule(nn.Module):
        def forward(self, input):
            ...
            return y_pred, input_tensor  # or y_pred, input_tensor.mean() in your example

Thank you.

Frédérik
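
A minimal sketch, not from the thread, of how this wires into PyToune's Model wrapper; the tiny network and random data below are hypothetical placeholders standing in for MyPyTorchModule:

    # Hypothetical example: custom_loss is the function defined above.
    import torch
    import torch.nn as nn
    from pytoune.framework import Model

    class TinyModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(10, 1)

        def forward(self, input):
            y_pred = torch.sigmoid(self.fc(input))
            return y_pred, input  # the loss receives this whole tuple

    network = TinyModule()
    model = Model(network, torch.optim.Adam(network.parameters()), custom_loss)

    x = torch.rand(32, 10)
    y = torch.randint(2, (32, 1)).float()
    loss = model.train_on_batch(x, y)  # custom_loss gets (y_pred, x) and y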

@prabhuteja12 (Author)

Awesome! Definitely felt there would be a simple solution to this. Thank you :)

@prabhuteja12 (Author)

Oops! Looks like I closed it by accident.

I'm looking at a case that is slightly different from the examples above. I have a Dataset that outputs 3 tensors (say x, y, mask), and I'm trying to compute the loss with something like this:

    F.cross_entropy(input=x, target=y, reduction='none')[mask]

Right now, I'm implementing it as

    local_model.loss_function = partial(loss, ignore_mask=mask)
    local_model.train_on_batch(x.to(device), y.long().to(device))

Is there a better way to do this?

@freud14 (Collaborator) commented Mar 4, 2019

I'm not sure I fully understand your case, but something like this should do it:

    import torch.nn.functional as F

    def custom_loss(y_pred, input):
        y_true, mask = input
        # Select the masked entries and reduce to a scalar for backprop.
        return F.cross_entropy(input=y_pred, target=y_true, reduction='none')[mask].mean()

This assumes your dataset outputs a tuple like (x, (y, mask)). You will also need to implement the collate_fn of the DataLoader.

By the way, if you did local_model.to(device), you don't need to call .to(device) on your input tensors; PyToune will do it for you.

@prabhuteja12 (Author)

The output of the dataset is a tuple like (x, y, mask), not in the structure (x, (y, mask)). Would I have to write a Lambda dataset for that?

Also, would PyToune be able to infer the device from the original PyTorch model?

@freud14 (Collaborator) commented Mar 4, 2019

> The output of the dataset is a tuple like (x, y, mask), not in the structure (x, (y, mask)). Would I have to write a Lambda dataset for that?

Yes. More specifically, you have to pass a function for the collate_fn parameter of the DataLoader that transforms your (x, y, mask) tuples into an (x, (y, mask)) batch.

> Also, would PyToune be able to infer the device from the original PyTorch model?

Yes, if you've called local_model.to(device) beforehand.
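
Concretely, a small sketch of this behavior (reusing the hypothetical TinyModule, custom_loss, x, and y from the earlier sketch; assumes PyToune's Model.to(device) method):

    import torch
    from pytoune.framework import Model

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    network = TinyModule()
    model = Model(network, torch.optim.SGD(network.parameters(), lr=0.01), custom_loss)
    model.to(device)  # moves the network; incoming batches are moved automatically too

    loss = model.train_on_batch(x, y)  # no manual x.to(device) / y.to(device) needed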

@freud14 (Collaborator) commented Mar 5, 2019

Here is a code skeleton for what I mean:

    import torch

    class ClientDataset(torch.utils.data.Dataset):
        def __getitem__(self, index):
            ...
            return x, y, mask

    def my_collate_function(samples):
        # Transform the list of (x, y, mask) samples into one batched
        # tensor per input.
        x, y, mask = (torch.stack(seq) for seq in zip(*samples))
        return x, (y, mask)

    loader = torch.utils.data.DataLoader(my_dataset_instance, collate_fn=my_collate_function)
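
Putting the pieces together, a hedged end-to-end sketch (assumes PyToune's fit_generator method; my_dataset_instance, my_collate_function, and the masked custom_loss are from the comments above, and network is a hypothetical module producing y_pred):

    import torch
    from pytoune.framework import Model

    train_loader = torch.utils.data.DataLoader(
        my_dataset_instance, batch_size=32, collate_fn=my_collate_function)

    model = Model(network, torch.optim.Adam(network.parameters()), custom_loss)
    model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

    # Each batch arrives as (x, (y, mask)); PyToune feeds x to the network
    # and passes (y, mask) as the second argument of custom_loss.
    model.fit_generator(train_loader, epochs=10)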

@prabhuteja12 (Author)

Hi,

Thank you for the suggestion :) I had something quite similar too.
By the way, do you have any profiling information on how much overhead PyToune adds over plain PyTorch?

@freud14 (Collaborator) commented Mar 6, 2019

Hi,
I haven't done any profiling, but it shouldn't add any observable overhead, since it only does the things you would do anyway in your own training loop.
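
For anyone who wants to verify this, a rough, hypothetical timing sketch (not from the thread) comparing PyToune's train_on_batch with a hand-written step on a toy network:

    import time
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from pytoune.framework import Model

    net = nn.Linear(10, 2)
    opt = torch.optim.SGD(net.parameters(), lr=0.01)
    x, y = torch.rand(64, 10), torch.randint(2, (64,))

    def manual_step():
        # The usual zero_grad / forward / backward / step cycle.
        opt.zero_grad()
        loss = F.cross_entropy(net(x), y)
        loss.backward()
        opt.step()

    model = Model(net, opt, F.cross_entropy)

    for name, step in [('manual', manual_step),
                       ('pytoune', lambda: model.train_on_batch(x, y))]:
        start = time.perf_counter()
        for _ in range(1000):
            step()
        print(name, time.perf_counter() - start)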

@freud14 (Collaborator) commented Mar 12, 2019

Hi, have I answered all your questions?

@prabhuteja12 (Author)

Hi @freud14,
I was just about to reply when I saw your message. Yes! Thanks for clearing up all of my questions.
I'll close this now.
