
NaN values after a single gradient step #40

Closed
MartinPernus opened this issue Feb 10, 2021 · 10 comments
@MartinPernus commented Feb 10, 2021

Hi!

Using PyTorch 1.7.1, I get NaN values after a single parameter update:

import torch
import torch.nn as nn
import torch.nn.functional as F
import clip

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.model, _ = clip.load('RN50')

    def forward(self, imgs, tokens):
        image_features = self.model.encode_image(imgs)
        match_text_features = self.model.encode_text(tokens)

        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        match_text_features = match_text_features / match_text_features.norm(dim=-1, keepdim=True)

        similarity_match = image_features @ match_text_features.T
        return similarity_match

def compute_loss(similarity_match, labels):
    loss1 = F.cross_entropy(similarity_match, labels)
    loss2 = F.cross_entropy(similarity_match.T, labels)
    loss = (loss1 + loss2) / 2
    return loss

model = Model().cuda()
optimizer = torch.optim.Adam(model.parameters())

imgs = torch.randn(8, 3, 224, 224).cuda()
tokens = torch.randint(high=1000, size=(8, 77)).cuda()
labels = torch.arange(8).cuda()

similarity_match = model(imgs, tokens)
loss = compute_loss(similarity_match, labels)
loss.backward()
optimizer.step()

print(model(imgs, tokens))

Output:

tensor([[nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward>)
@jongwook (Collaborator) commented Feb 12, 2021

Hi, the model checkpoint contains fp16 parameters for speed, but gradients for these weights are very prone to overflow/underflow without careful loss scaling, causing nan outputs after a gradient step.

You can avoid this by casting all weights to fp32 with model.float(). Also, pass jit=False to clip.load when doing this, because the JIT-compiled code contains hard-coded dtype values and won't work.
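For reference, a minimal sketch of this fix (the learning rate below is just a placeholder, not from this thread):

import torch
import clip

# load without JIT so the dtype isn't hard-coded, then cast the fp16 weights to fp32
model, preprocess = clip.load('RN50', jit=False)
model = model.float().cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)  # placeholder learning rate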

@MartinPernus (Author)

Thank you! I didn't realize that a single step already breaks the fp16 parameters. The code works now.

@NotNANtoN

@MartinPernus Hey, I'm curious about fine-tuning CLIP.

Could you share some details on how many image-text pairs you used and what batch size? Did the fine-tuning work? I'd appreciate any input!

@MartinPernus (Author)

@NotNANtoN
I tried to fine-tune it on face image-text pairs, with different batch sizes. Unfortunately it did not work: while the training and validation loss on my dataset decreased, the model did not generalize to the test set. The problem might be the text part, which I generated manually from a set of binary attributes. I also tried to include some domain adaptation techniques, to no avail.

@NotNANtoN
Copy link

Interesting use-case, thanks!

Weird that the validation loss decreases while the test loss doesn't; I don't think I've ever seen that in my training runs.
But it seems like using text might be a bit overkill when you could just fine-tune the vision encoder to predict your binary attributes.

@MartinPernus (Author)

I should have elaborated further. The training and validation splits were constructed manually from a set of binary attributes, while the test set came from an actual dataset of face descriptions + face images. The motivation was that we could, in theory, leverage datasets that contain binary attributes to improve CLIP performance for text-based face search. This could be useful in criminal investigations, where the victim describes the suspect and we would not have to constrain ourselves to a predefined set of binary attributes.

Hope I have made myself clearer, I am happy to discuss this further.

@NotNANtoN

Hi Martin, thanks, that clears it up! Interesting use-case, but it makes sense that it is very hard for the model to extrapolate to fuller sentences.

Unfortunately, I can't think of a way to leverage datasets with binary attributes in a better way than you tried there.

@cheng321284


Thanks, it works for me~

@Alan-lab commented Sep 5, 2023

Freezing the loaded model may also solve the problem. For example:

model, _ = clip.load('ViT-B/32')
for name, param in model.named_parameters():
    param.requires_grad = False
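As an illustration of why freezing helps (the optimizer never updates the fp16 weights), here is a hedged sketch that uses the frozen model as a fixed feature extractor with a small trainable fp32 head; the head size and attribute count are made up for this example:

import torch
import torch.nn as nn
import clip

model, preprocess = clip.load('ViT-B/32', jit=False)
model = model.cuda().eval()
for param in model.parameters():
    param.requires_grad = False               # the fp16 CLIP weights are never touched by the optimizer

head = nn.Linear(512, 40).cuda()              # hypothetical attribute classifier; ViT-B/32 outputs 512-d features
optimizer = torch.optim.Adam(head.parameters(), lr=1e-4)

imgs = torch.randn(8, 3, 224, 224).cuda()
with torch.no_grad():
    feats = model.encode_image(imgs).float()  # cast the fp16 features to fp32 for the head
logits = head(feats)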

@Dinosaurcubs


I ran into the same problem recently: NaN after a random number of epochs. I have now switched the model to fp32 and am training again, but the training time is nearly doubled... Is there any way to use fp16 and avoid NaN? Why does the NaN happen when training the fp16 model? Thanks for your reply.
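For anyone who wants to keep most of the fp16 speed, one option (a sketch only, not verified in this thread) is the loss scaling mentioned above: keep the weights in fp32 via model.float() and use torch.cuda.amp for fp16 compute plus scaled gradients. Here loader is a hypothetical DataLoader yielding (imgs, tokens, labels) batches:

import torch
import torch.nn.functional as F
import clip

model, _ = clip.load('RN50', jit=False)
model = model.float().cuda()                       # fp32 master weights
optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)
scaler = torch.cuda.amp.GradScaler()               # handles the loss scaling

for imgs, tokens, labels in loader:                # hypothetical DataLoader
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():                # run the forward pass in fp16 where safe
        image_features = model.encode_image(imgs.cuda())
        text_features = model.encode_text(tokens.cuda())
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        logits = image_features @ text_features.T
        loss = (F.cross_entropy(logits, labels.cuda()) +
                F.cross_entropy(logits.T, labels.cuda())) / 2
    scaler.scale(loss).backward()                  # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)                         # unscales gradients and skips the step on inf/NaN
    scaler.update()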
