
Best practice for mixed precision training #57

Closed
qingerVT opened this issue Mar 10, 2021 · 17 comments

Comments

@qingerVT

qingerVT commented Mar 10, 2021

I am trying to fine-tune CLIP models on new datasets. What's the best practice for mixed precision training?

Using Adam, I get either nan or inf errors, since the eps attribute is hard to choose for both half and float32 parameters. My workaround is to divide the parameters into two groups and specify a different eps for each. Are there any better solutions?

@KeremTurgutlu

KeremTurgutlu commented Mar 10, 2021

You can use fastai, which already provides to_fp16(); under the hood it uses torch autocast, so it's as simple as adding that single call to your code. There are other libraries like PyTorch Lightning, Apex, etc. that allow mixed precision training with similar ease.

Also, I've already implemented training code in fastai; you can take a look at it. You just need to add a few lines for loading the open source CLIP weights for fine-tuning; my code currently trains from scratch, but it should only be a couple of lines of change. The code is here, along with an example training script.
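
For example, a minimal sketch of the fastai route (assuming fastai v2, where Learner.to_fp16() wraps torch autocast; dls, model, and loss_func are placeholders, not names from my repo):

from fastai.vision.all import *   # brings in Learner and the fp16 callback

learn = Learner(dls, model, loss_func=loss_func)
learn = learn.to_fp16()           # enable mixed precision training
learn.fit_one_cycle(1, 1e-4)      # then train as usual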

@qingerVT
Author

Thanks @KeremTurgutlu. The basic idea is to convert parameters into half and then use torch autocast?

@KeremTurgutlu

Thanks @KeremTurgutlu. The basic idea is to convert parameters into half and then use torch autocast?

You can read the materials at https://pytorch.org/docs/stable/amp.html.
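
For example, a generic torch.cuda.amp training loop looks roughly like this (a sketch, not CLIP-specific; model, optimizer, loader, and loss_fn are placeholders):

import torch

scaler = torch.cuda.amp.GradScaler()

for images, texts in loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():    # run ops in fp16 where it is safe, fp32 elsewhere
        logits_per_image, logits_per_text = model(images, texts)
        loss = loss_fn(logits_per_image, logits_per_text)
    scaler.scale(loss).backward()      # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)             # unscale the grads, then call optimizer.step()
    scaler.update()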

@jongwook
Collaborator

We did the memory-intensive tensor operations (matmul, convolutions) in fp16, while doing aggregation (batchnorms/layernorms) in fp32. We also kept the model weights and gradients in fp32 (which only takes constant memory w.r.t. the batch size) so that Adam only sees fp32 numbers.

For a general introduction to mixed-precision training, I found this video tutorial and document by NVIDIA particularly useful.
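
A minimal sketch of the selective-casting part of this recipe (not the actual CLIP training code; the repo's clip.model.convert_weights does something similar for Linear/Conv and attention weights):

import torch.nn as nn

def cast_matmul_layers_to_fp16(model: nn.Module):
    # Cast Linear/Conv weights and biases to fp16; LayerNorm/BatchNorm stay in fp32.
    for module in model.modules():
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            module.weight.data = module.weight.data.half()
            if module.bias is not None:
                module.bias.data = module.bias.data.half()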

@qingerVT
Author

@jongwook Thanks. fp32 for optimization and fp16 for forward/backward. Switching back and forth works!

@WujiangXu

@jongwook Thanks. fp32 for optimization and fp16 for forward/backward. Switching back and forth works!

Hello, I met the same problem as you.
Could you show the specific solution to this problem? I tried the code below, but it does not work for me:

with amp.autocast():
    out = net(photo1, token)

@qingerVT
Author

qingerVT commented Mar 11, 2021

@cupcakefromchina Convert the parameters and grads to fp32 before applying Adam, then convert them back:

def convert_models_to_fp32(model):
    # Cast parameters and gradients to fp32 so Adam (and its eps) only sees fp32 values.
    for p in model.parameters():
        p.data = p.data.float()
        if p.grad is not None:
            p.grad.data = p.grad.data.float()

def convert_models_to_mix(model):
    # Cast the weights back to mixed precision for the next forward/backward pass.
    clip.model.convert_weights(model)

# ---- inside your train function ----

convert_models_to_fp32(model)
optimizer.step()
convert_models_to_mix(model)
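
For context, here is roughly where those calls sit in one training step (a sketch; images, texts, and loss_fn are placeholders):

logits_per_image, logits_per_text = model(images, texts)   # forward/backward run in mixed precision
loss = loss_fn(logits_per_image, logits_per_text)
optimizer.zero_grad()
loss.backward()

convert_models_to_fp32(model)   # Adam only ever sees fp32 params and grads
optimizer.step()
convert_models_to_mix(model)    # back to mixed precision for the next iteration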

@WujiangXu

@cupcakefromchina Convert parameters and grads to fp 32 before applying Adam, then convert it back

def convert_models_to_fp32(model): 
    for p in model.parameters(): 
        p.data = p.data.float() 
        p.grad.data = p.grad.data.float() 
def convert_models_to_mix(model):
    clip.model.convert_weights(model)

---- your train function ------

convert_models_to_fp32(model) 
optimizer.step()
convert_models_to_mix(model)

Thank you for your solution. But I still have one question: which object does "clip" refer to? I don't understand this function and can't find the API in the original code.
def convert_models_to_mix(model): clip.model.convert_weights(model)

@nishanthcgit

@cupcakefromchina here

@nishanthcgit

@qingerVT I was trying out training on my own dataset too and was facing the exact same issue. Thanks for pointing out the details about mixed precision training; I will definitely try it out!
I'm curious, though: are you able to see the loss decreasing in that case? For me it always ends up blowing up; maybe the small batch size is a problem, since I only have a single-GPU machine.
@KeremTurgutlu Could you weigh in too on whether training is working out for you? Also, I've taken a look at your code and it would be really useful to me. Could you point out how I can go about loading the open source CLIP weights for fine-tuning? I'm still quite new to PyTorch.

@KeremTurgutlu

@KeremTurgutlu Could you weigh in too on whether training is working out for you? Also, I've taken a look at your code and it would be really useful to me. Could you point out how I can go about loading the open source CLIP weights for fine-tuning? I'm still quite new to PyTorch.

Yes, I have successfully trained CLIP on my own data, and performance is close to the open source model with just 5.5 million image-text pairs. I am working on wrapping up the code for both training from scratch and fine-tuning. In the meantime you can watch this branch; it should be complete in 1-2 days.

@qingerVT
Author

@cupcakefromchina Convert parameters and grads to fp 32 before applying Adam, then convert it back

def convert_models_to_fp32(model): 
    for p in model.parameters(): 
        p.data = p.data.float() 
        p.grad.data = p.grad.data.float() 
def convert_models_to_mix(model):
    clip.model.convert_weights(model)

---- your train function ------

convert_models_to_fp32(model) 
optimizer.step()
convert_models_to_mix(model)

Thank you for your solution. But I still have one question: which object does "clip" refer to? I don't understand this function and can't find the API in the original code.
def convert_models_to_mix(model): clip.model.convert_weights(model)

Yes, the loss is decreasing. The function convert_weights() is in clip/model.py. Try:
export PYTHONPATH=/your/path/CLIP:$PYTHONPATH
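
For example (a sketch; this assumes the CLIP repo is importable as clip after setting PYTHONPATH or pip-installing it):

import clip
from clip.model import convert_weights

model, preprocess = clip.load("ViT-B/32", jit=False)   # jit=False returns a regular nn.Module
convert_weights(model)                                 # cast Linear/Conv/attention weights to fp16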

@vinson2233

Can I ask what the difference is between

def convert_models_to_fp32(model): 
    for p in model.parameters(): 
        p.data = p.data.float() 
        p.grad.data = p.grad.data.float() 

vs
model.float() ?

@qingerVT
Author

Can I ask what the difference is between

def convert_models_to_fp32(model): 
    for p in model.parameters(): 
        p.data = p.data.float() 
        p.grad.data = p.grad.data.float() 

vs
model.float() ?

Ah, no difference :)

@nbl97

nbl97 commented Oct 1, 2021

@jongwook @qingerVT Hi, I have a simple question. I understand the solution quoted below for mixed precision training, and I know that converting to fp16 saves CUDA memory with negligible loss of accuracy. However, what are the advantages of the method below over the usual approaches, such as Apex and torch.cuda.amp, which also work well?

@cupcakefromchina Convert parameters and grads to fp 32 before applying Adam, then convert it back

def convert_models_to_fp32(model): 
    for p in model.parameters(): 
        p.data = p.data.float() 
        p.grad.data = p.grad.data.float() 
def convert_models_to_mix(model):
    clip.model.convert_weights(model)

---- your train function ------

convert_models_to_fp32(model) 
optimizer.step()
convert_models_to_mix(model)

@jongwook
Collaborator

jongwook commented Oct 1, 2021

@nbl97 Not much, just that we wanted granular control over which operation runs in which dtype. Now that torch AMP is mature, I'd use it if I were to start a project from scratch.

@nbl97

nbl97 commented Oct 3, 2021

@nbl97 Not much, just that we wanted granular control over which operation runs in which dtype. Now that torch AMP is mature, I'd use it if I were to start a project from scratch.

@jongwook As we all know, converting fp32 to fp16 may drop accuracy; however, what is the performance if we convert fp16 to fp32? When using the CLIP pretrained model on my own task, what is the performance if I always use fp32 and never convert the model? Do you have any experience with that? Thanks in advance.
