[Bug?] Interesting load behavior with amp checkpointing #706

voldemortX opened this issue Feb 6, 2020 · 3 comments

@voldemortX
I was training my models in mixed precision today and ran into some very interesting behavior around checkpointing with amp.
If I load a set of weights into my amp-initialized model and test it in eval mode, I get a result (call it result No. 1). If I then load another set of weights and test again in eval mode, I still get result No. 1 instead of what I should be getting (result No. 2). Occasionally the output is neither result No. 1 nor result No. 2, but I can't reproduce that case.
Here is how I reproduced the "always result No. 1" case (with PyTorch 1.2.0 & CUDA 10.0):

  1. Generate two different sets of weights using seed 4396 with code Part 1 in the script below.
  2. Keep Part 3 commented out and run the script: the two printed outputs are the same, even though the loaded weights and amp.state_dict() are not. This is the interesting part.
  3. If I run the same script without amp (normal float32 training, no amp.initialize(), loading with is_mixed_precision=False), the outputs are different as expected, so this is probably not a problem in PyTorch itself.
  4. Now the most bizarre part: if I uncomment Part 3 (and the seed line as well), the outputs are also different as expected. So it seems that when two consecutive loads have no training step in between, the second load is null and void.
import torch
from apex import amp

# All hail Clearlove, 7th of his name!
# torch.manual_seed(4396)


# Save model checkpoints(supports amp)
def save_checkpoint(net, optimizer, lr_scheduler, is_mixed_precision, filename='temp.pt'):
    checkpoint = {
        'model': net.state_dict(),
        'optimizer': optimizer.state_dict() if optimizer is not None else None,
        'lr_scheduler': lr_scheduler.state_dict() if lr_scheduler is not None else None,
        'amp': amp.state_dict() if is_mixed_precision else None
    }
    torch.save(checkpoint, filename)


# Load model checkpoints(supports amp)
def load_checkpoint(net, optimizer, lr_scheduler, is_mixed_precision, filename):
    checkpoint = torch.load(filename)
    net.load_state_dict(checkpoint['model'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
    if lr_scheduler is not None:
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    if is_mixed_precision and checkpoint['amp'] is not None:
        amp.load_state_dict(checkpoint['amp'])


# Test model
class TinyNet(torch.nn.Module):
    def __init__(self):
        super(TinyNet, self).__init__()
        self.fc = torch.nn.Linear(3, 3)

    def forward(self, x):
        x = self.fc(x)

        return x


if __name__ == '__main__':
    # Part 1, seed 4396
    ################
    # net1 = TinyNet()
    # net2 = TinyNet()
    # device = torch.device('cuda:0')
    # net1.to(device)
    # net2.to(device)
    # print(net1.state_dict())
    # print(net2.state_dict())
    # [net1, net2] = amp.initialize(models=[net1, net2], opt_level='O1')
    # save_checkpoint(net1, None, None, True, 'weights1.pt')
    # save_checkpoint(net2, None, None, True, 'weights2.pt')
    ################

    net = TinyNet()
    device = torch.device('cuda:0')
    net.to(device)
    op = torch.optim.SGD(net.parameters(), lr=0.1)
    net, op = amp.initialize(net, op, opt_level='O1')
    x = torch.ones((1, 3)).to(device)

    load_checkpoint(net, None, None, True, 'weights1.pt')
    print(net.state_dict())
    print(amp.state_dict())
    net.eval()
    with torch.no_grad():
        print(net(x))

    # Part 3
    ################
    # net.train()
    # z = torch.randn((1, 3)).to(device)
    # c = torch.nn.MSELoss()
    # y = torch.tensor([[1., 1., 1.]]).to(device)
    # t = net(z)  # Anything
    # print(t)
    # l = c(t, y)
    # with amp.scale_loss(l, op) as sl:
    #     sl.backward()
    ################

    load_checkpoint(net, None, None, True, 'weights2.pt')
    print(net.state_dict())
    print(amp.state_dict())
    net.eval()
    with torch.no_grad():
        print(net(x))
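
For reference, instead of eyeballing the printed state dicts, a small check along these lines (just a sketch, assuming the same weights1.pt / weights2.pt generated by Part 1) makes the mismatch explicit: the two checkpoints hold different parameters, yet the two eval outputs printed above are identical.

# Sketch: compare the two checkpoints directly (assumes weights1.pt / weights2.pt
# come from Part 1 above). The parameter tensors differ key by key, even though
# the two eval outputs printed by the script are the same under amp O1.
ckpt1 = torch.load('weights1.pt')
ckpt2 = torch.load('weights2.pt')
for name in ckpt1['model']:
    same = torch.equal(ckpt1['model'][name], ckpt2['model'][name])
    print(name, 'identical in both checkpoints:', same)  # expected: False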

@voldemortX
Author

I literally have no idea why this happened.

@donglixp
Contributor

donglixp commented Feb 11, 2020

Is this the same issue as in #480?

@voldemortX
Author

I think they are different (this problem seems to have nothing to do with DDP or optim), but they might just be the same problem "fundamentally".
