
loss spike after checkpoint reload #480

Closed

williamFalcon opened this issue Sep 6, 2019 · 26 comments

@williamFalcon

williamFalcon commented Sep 6, 2019

When running a model with apex + DDP, my loss spikes dramatically after the model restarts.
If I disable apex, it works fine.

Currently, I've set up apex this way:

optimizer = Adam()
schedulers = LR

torch.cuda.set_device(gpu_nb)
model.cuda(gpu_nb)

# apex
model, optimizer = amp.initialize(model, optimizer, opt_level='O2')    

# ddp
model = DDP(model)   

# restore state
ckpt = torch.load(path)
model.load_state_dict(ckpt['state_dict'])
optimizer.load_state_dict(ckpt['opt_dict'])
LR.load_state_dict(ckpt['lr_dict'])
amp.load_state_dict(ckpt['amp'])   

# continue ....
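
For reference, the save side presumably mirrors these keys; a minimal sketch, assuming the same model, optimizer, LR, and path objects as above, and amp.state_dict() as the counterpart of the amp.load_state_dict() call:

import torch
from apex import amp

checkpoint = {
    'state_dict': model.state_dict(),    # model weights
    'opt_dict': optimizer.state_dict(),  # Adam moments and step counts
    'lr_dict': LR.state_dict(),          # LR scheduler state
    'amp': amp.state_dict(),             # dynamic loss-scale state
}
torch.save(checkpoint, path)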

Actual code is here:

@williamFalcon williamFalcon changed the title Restore checkpoint in CPU first? Restore checkpoint on CPU first? Sep 6, 2019
@williamFalcon
Author

[screenshot: training loss curve]

@williamFalcon williamFalcon changed the title Restore checkpoint on CPU first? loss spike after checkpoint reload Sep 6, 2019
@williamFalcon
Author

@ptrblck

@ptrblck
Contributor

ptrblck commented Sep 6, 2019

Thanks for the code @williamFalcon!
We'll try to reproduce and debug it.

Do you see this loss spike only using LightningDistributedDataParallel or also PyTorch's DDP?

@williamFalcon
Author

Lightning's DDP is PyTorch DDP, except that it routes the forward call to training_step or validation_step; otherwise it's the same.

@williamFalcon
Author

williamFalcon commented Sep 6, 2019

You can replicate this by doing the following:

model = MNISTModel()   

trainer = Trainer(gpus=[0,1], use_amp=True)  
trainer.fit(model)

Run the above script via a SLURM script on a node with 2 GPUs (I used two V100s with 32 GB each). Set the walltime to 10 minutes (so the loss can go down). At 7 minutes the job will resubmit itself and you'll see the problem.

MNIST might be too trivial, though; I'd try CIFAR-10 instead.

@williamFalcon
Author

williamFalcon commented Sep 6, 2019

Here's a full working example to replicate. Make sure to install lightning from master:

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
import torchvision.transforms as transforms
from test_tube import Experiment, SlurmCluster, HyperOptArgumentParser
import numpy as np

import pytorch_lightning as pl


PORT = np.random.randint(12000, 20000, 1)[0]
SEED = 2334
torch.manual_seed(SEED)
np.random.seed(SEED)

"""
To run in interactive node:
python issue.py

To run as a cluster job (submits job to cluster):
python issue.py --cluster 

"""


class CIFAR100LM(pl.LightningModule):

    def __init__(self, save_path):
        super(CIFAR100LM, self).__init__()

        self.save_path = save_path
        self.l1 = torch.nn.Linear(32 * 32*3, 1028)
        self.l2 = torch.nn.Linear(1028, 2048)
        self.l3 = torch.nn.Linear(2048, 100)

    def forward(self, x):
        x = x.view(x.size(0), -1)

        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = self.l3(x)

        return x

    def training_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x)
        return {'loss': F.cross_entropy(y_hat, y)}

    def validation_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'avg_val_loss': avg_loss}

    def configure_optimizers(self):
        # can return multiple optimizers and learning_rate schedulers
        return torch.optim.Adam(self.parameters(), lr=0.0002)

    @pl.data_loader
    def tng_dataloader(self):
        return DataLoader(CIFAR100(self.save_path, train=True, download=True,
                                   transform=transforms.ToTensor()), batch_size=32)

    @pl.data_loader
    def val_dataloader(self):
        return DataLoader(CIFAR100(self.save_path, train=True, download=True,
                                   transform=transforms.ToTensor()), batch_size=32)


def main(*args, **kwargs):

    log_path = '/log/path'
    exp = Experiment(
        name='apex_bug_test',
        save_dir=log_path,
        autosave=False,
        description='amp test'
    )
    exp.save()

    data_path =  '/data/CIFAR100'
    model = CIFAR100LM(data_path)
    trainer = pl.Trainer(
        gpus=[0, 1],
        use_amp=True,
        experiment=exp,
        min_nb_epochs=200,
        distributed_backend='ddp'
    )
    trainer.fit(model)


def launch_cluster_job(args):
    # enable cluster training
    slurm_log_path = '/log/path'
    cluster = SlurmCluster(
        log_path=slurm_log_path,
        hyperparam_optimizer=args
    )

    # email for cluster coms
    cluster.notify_job_status(email='your@email.com', on_done=True, on_fail=True)

    # configure cluster
    cluster.per_experiment_nb_nodes = 1
    cluster.per_experiment_nb_gpus = 2
    cluster.job_time = '00:10:00'
    cluster.memory_mb_per_node = 0
    cluster.per_experiment_nb_cpus = 2

    # any modules for code to run in env
    cluster.add_command('source activate your_conda_env')
    cluster.add_command('export NCCL_SOCKET_IFNAME=^docker0,lo')
    cluster.add_command('export NCCL_DEBUG=INFO')
    cluster.add_command('export PYTHONFAULTHANDLER=1')
    cluster.add_command(f'export MASTER_PORT={PORT}')
    cluster.load_modules(['NCCL/2.4.7-1-cuda.10.0'])

    cluster.python_cmd = 'python'
    cluster.add_slurm_cmd(cmd='constraint', value='volta32gb', comment='use 32gb gpus')
    cluster.add_slurm_cmd(cmd='ntasks-per-node', value=2, comment='1 task per gpu')

    # name of exp
    job_display_name = 'apex_bug_test'

    # run hopt
    print('submitting jobs...')
    cluster.optimize_parallel_cluster_gpu(
        main,
        nb_trials=1,
        job_name=job_display_name
    )


if __name__ == '__main__':
    parser = HyperOptArgumentParser()
    parser.add_argument('--cluster', dest='cluster', action='store_true')
    args = parser.parse_args()

    if args.cluster:
        launch_cluster_job(args)
    else:
        main()

@williamFalcon
Author

williamFalcon commented Sep 6, 2019

[screenshot: training loss curves]

This is from the above example.

Orange: before checkpointing; blue: after resuming training.

@williamFalcon
Author

OK, digging into this more...

Here is the loss after 3 reloads (each color is a different reload).
[screenshot: training loss curves across reloads]

However, when tracking the training accuracy, I see that it remains high after a reload, even when the loss spikes.

[screenshot: training accuracy curves]

This suggests that the model loads correctly but the loss scaling is off (which points to amp.load_state_dict()).

In this simple model it's not a problem, but in more complex ones with losses sensitive to scaling, the loss goes to NaN after the model restarts.
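
One quick way to sanity-check that hypothesis: a hypothetical diagnostic, assuming the checkpoint layout from the first post and that amp.state_dict() mirrors what amp.load_state_dict() restores (run after amp.initialize, as in the restore code above):

import torch
from apex import amp

# Compare the loss-scaler state stored in the checkpoint with the state
# amp actually reports after the restore.
ckpt = torch.load(path, map_location='cpu')
print('amp state in checkpoint:', ckpt['amp'])

amp.load_state_dict(ckpt['amp'])
print('amp state after restore:', amp.state_dict())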

@ibeltagy

ibeltagy commented Oct 22, 2019

@williamFalcon, @ptrblck, did you figure this out? We are having the same problem.

@ibeltagy

Turns out the problem is related to reloading the optimizer. When you call amp.initialize, optimizer.state needs to be an empty dictionary, or this problem occurs. As a workaround, we empty optimizer.state before amp.initialize and use a short warmup for the Adam optimizer to recover its moving averages (thanks to @yaroslavvb for the warmup suggestion).
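
A minimal sketch of that workaround; model, optimizer, path, opt_level, and the checkpoint keys are assumed from the first post, and the LR warmup itself is sketched a few comments below:

import collections
import torch
from apex import amp

# Restore the model and optimizer as usual ...
ckpt = torch.load(path, map_location='cpu')
model.load_state_dict(ckpt['state_dict'])
optimizer.load_state_dict(ckpt['opt_dict'])

# ... then drop the restored Adam state before amp.initialize, since amp expects an
# optimizer whose state is still empty. (A defaultdict keeps later optimizer.step() calls happy.)
optimizer.state = collections.defaultdict(dict)

model, optimizer = amp.initialize(model, optimizer, opt_level='O2')
amp.load_state_dict(ckpt['amp'])

# Adam's moving averages are gone at this point, so follow up with a short LR warmup.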

@mranzinger

@ibeltagy Could you please explain how you're warming up the optimizer? All that comes to mind for me is calling optim.step(), but that would actually apply the gradients to the model, which seems like exactly what I don't want to happen. Are you saving the original state of your model, stepping the optimizer a few times, and then restoring the model back to the original?

I'm dealing with a model that explodes after the first step upon reload, so aside from reloading and training with opt_level O0, I'm not sure what to do.

@ibeltagy

I mean slowly increasing the learning rate from zero to the value you want, as in the sketch below.
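
A minimal warmup sketch using a standard PyTorch scheduler; the warmup length is a placeholder, and train_loader, model, and optimizer are assumed from the setup above:

import torch
from apex import amp

# Hypothetical linear warmup after a restart: ramp the LR from 0 to its target
# over the first few hundred steps so Adam can rebuild its moment estimates.
warmup_steps = 500  # placeholder value

def warmup_factor(step):
    return min(1.0, (step + 1) / warmup_steps)

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_factor)

for batch in train_loader:
    optimizer.zero_grad()
    loss = model(batch)             # assumes the model returns its loss
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()
    scheduler.step()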

@donglixp
Contributor

We also encountered the same issue with PyTorch's DDP. @ptrblck The loss becomes quite large when we reload the checkpoint.

@krishansubudhi

krishansubudhi commented Feb 17, 2020

Same issue here. The loss spikes after loading a checkpoint. Not loading the optimizer helps sometimes, but not always. NVIDIA needs to fix this ASAP.

I don't see quick responses from them on GitHub.
Is it time to search for an alternative to apex?

@ibeltagy

ibeltagy commented Feb 17, 2020

I found an ugly hack, but it seems to work. It goes like this:

import torch
from apex import amp

# load optimizer from file
optimizer = torch.optim.AdamW(...)
optimizer_state_dict = torch.load(f_opt)
optimizer.load_state_dict(optimizer_state_dict)

# then remove optimizer state to make amp happy
optimizer.state = {}

# init amp and load it from file
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
amp.load_state_dict(torch.load(f_amp))

# forward, backward, optimizer step, zero_grad
loss = model(random_input)
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
model.zero_grad()

# then load the optimizer state_dict again (this time without removing optimizer.state)
optimizer.load_state_dict(optimizer_state_dict)

@krishansubudhi

So amp is expecting one optimizer step before loading the checkpoint smoothly?

Has anyone looked into the source code to find the root cause?

@ibeltagy

No, amp.initialize expects an optimizer with an empty optimizer.state. With the first optimizer.step, the optimizer initializes its state and it gets registered with apex somehow. At that point, replacing optimizer.state with the state from the checkpoint seems to work.

@daniel347x

daniel347x commented Feb 19, 2020

Unfortunately, I cannot get the workaround to work. I may have to disable fp16 entirely - unless I find that the loss spike is harmless in terms of actual model improvement.

I am grateful for fp16, but this does seem like a nearly show-stopping issue! It should be possible to restart from a checkpoint and continue where you left off - right? I'm a bit surprised this isn't a bigger issue - are folks not using fp16?

@williamFalcon
Author

I don't remember how we solved this, but we did in PyTorch Lightning. You could try running it there with fp16 enabled.

@ibeltagy

@williamFalcon, I couldn't find the relevant code (it's not here: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/trainer/training_io.py#L374), and I also didn't find amp.load_state_dict (which is needed for the loss scale).

@mranzinger

@daniel347x FWIW, I've found the severity of the bug to be somewhat task dependent. For some tasks, the loss spike disappears quickly, and it's indiscernible whether or not it had a lasting effect on convergence. For others, it will legitimately ruin the model.

My current best rule of thumb is that classification tasks are affected most, and in particular, the more possible classes there are, the worse the spike affects the model.

@krishansubudhi

Could this be because the model params are actually FP16 when they are saved? After loading, the optimizer gets the FP16 params as FP32, which causes a loss in precision.
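
One way to check this hypothesis; a hypothetical diagnostic, with path and the checkpoint keys following the first post:

import torch

# Compare the precision of the weights stored in the checkpoint with the
# precision of the live, amp-initialized model.
ckpt = torch.load(path, map_location='cpu')
print({name: t.dtype for name, t in ckpt['state_dict'].items()})  # dtypes saved in the checkpoint
print({name: p.dtype for name, p in model.named_parameters()})    # dtypes of the live model (fp16 under O2)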

@qywu

qywu commented Mar 22, 2020

I have also encountered the same problem, and the loss diverged at the O2 level. But I found @ibeltagy's hack useful.
I suspect the problem is that scale_loss and amp.initialize do not correctly handle an optimizer that has already been loaded. Maybe a simple fix can be done there.

@enijkamp

@ptrblck Any updates on this?

@williamFalcon
Author

We're now using native AMP with PyTorch 1.6+ in PyTorch Lightning. I would just switch to that.
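
For anyone making that switch, a minimal sketch of the native torch.cuda.amp pattern, including saving and restoring the GradScaler state so the dynamic loss scale survives a restart; the checkpoint keys, path, model, optimizer, and train_loader are placeholders/assumptions:

import torch

scaler = torch.cuda.amp.GradScaler()

# Restore everything, including the scaler's dynamic loss-scale state.
ckpt = torch.load(path, map_location='cpu')
model.load_state_dict(ckpt['state_dict'])
optimizer.load_state_dict(ckpt['opt_dict'])
scaler.load_state_dict(ckpt['scaler'])

for batch in train_loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(batch)          # assumes the model returns its loss
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

# When checkpointing, save the scaler state alongside the model and optimizer.
torch.save({'state_dict': model.state_dict(),
            'opt_dict': optimizer.state_dict(),
            'scaler': scaler.state_dict()}, path)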

@ruotianluo

ruotianluo commented Sep 25, 2020

Adding optimizer.load_state_dict right before the first optimizer.step works for me too. (I basically added manual checkpoint loading in optimizer_step of LightningModule.)

For general PyTorch users, this is what I have:

state_dict = torch.load(optimizer_ckpt)
optimizer.load_state_dict(state_dict)
....
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.load_state_dict(state_dict)
optimizer.step()
optimizer.zero_grad()

It seems the reason is that when you load the first time, the saved states are cast to fp16, while at that point the optimizer states are not properly initialized because of the lazy init. After the first call to amp.scale_loss, the optimizer states are properly initialized and recast to fp32, and the precision difference causes the problem.

However, if amp's lazy init is run earlier, the loss still spikes (though not as badly).
