From 2678b36c3508ace5fbe91c44312798ef082fefad Mon Sep 17 00:00:00 2001
From: mrdrozdov
Date: Thu, 17 Feb 2022 12:36:12 -0500
Subject: [PATCH] Resume mode for aligner.

---
 align_cfg/main.py | 30 +++++++++++++++++++-----------
 1 file changed, 19 insertions(+), 11 deletions(-)

diff --git a/align_cfg/main.py b/align_cfg/main.py
index 80eb7eaa..9f8837ee 100644
--- a/align_cfg/main.py
+++ b/align_cfg/main.py
@@ -1181,7 +1181,7 @@ def forward(self, batch_map, input_vector=None):
 
         return output, labels, labels_mask, label_node_ids
 
-def load_checkpoint(path, net, cuda=False):
+def load_checkpoint(path, net, opt, cuda=False):
     try:
         toload = torch.load(path)
     except:
@@ -1195,7 +1195,7 @@ def load_checkpoint(path, net, cuda=False):
             print('[load] copying {}'.format(k))
             toload['state_dict'][k] = state_dict[k]
 
-        assert v.shape == toload['state_dict'][k].shape, k
+        assert v.shape == toload['state_dict'][k].shape, (k, v.shape, toload['state_dict'][k].shape)
 
         seen.add(k)
 
@@ -1208,8 +1208,14 @@ def load_checkpoint(path, net, cuda=False):
 
     # TODO: Verify that vocab lines up.
     net.load_state_dict(toload['state_dict'])
 
+    try:
+        opt.load_state_dict(toload['opt_state_dict'])
+    except Exception as e:
+        print(e)
+        print('WARNING: Failed to load opt state dict. Be careful if resuming training.')
+
-def save_checkpoint(path, dataset, net, metrics=None):
+def save_checkpoint(path, dataset, net, opt, metrics=None):
     state_dict = net.state_dict()
 
     for k, v in net.named_parameters():
@@ -1222,6 +1228,7 @@ def save_checkpoint(path, dataset, net, metrics=None):
     tosave['text_vocab'] = dataset.text_tokenizer.vocab
     tosave['amr_vocab'] = dataset.amr_tokenizer.vocab
     tosave['metrics'] = metrics
+    tosave['opt_state_dict'] = opt.state_dict()
 
     try:
         torch.save(tosave, path, _use_new_zipfile_serialization=False)
@@ -1570,8 +1577,13 @@ def main(args):
 
     # Init model.
     net = Net.from_dataset_and_config(trn_dataset, model_config, args.cache_dir)
+
+    # OPTIMIZER
+    opt = optim.Adam(net.parameters(), lr=lr)
+
+    # LOAD
     if args.load is not None:
-        load_checkpoint(args.load, net)
+        load_checkpoint(args.load, net, opt)
 
     if args.cuda:
         net.cuda()
@@ -1585,10 +1597,6 @@
 
         maybe_write(context)
 
-    # OPTIMIZER
-
-    opt = optim.Adam(net.parameters(), lr=lr)
-
     # CACHE dataset items.
 
     for dset in [trn_dataset] + val_dataset_list:
@@ -1691,10 +1699,10 @@ def func():
             print('trn epoch = {}, loss = {:.3f}, loss-nr = {:.3f}, ppl = {:.3f}, pr = {:.3f}'.format(
                 epoch, trn_loss, trn_loss_notreduced, trn_ppl, trn_pr))
 
-        save_checkpoint(os.path.join(args.log_dir, 'model.latest.pt'), trn_dataset, net, metrics=dict(epoch=epoch, trn_loss=trn_loss, trn_loss_notreduced=trn_loss_notreduced))
+        save_checkpoint(os.path.join(args.log_dir, 'model.latest.pt'), trn_dataset, net, opt, metrics=dict(epoch=epoch, trn_loss=trn_loss, trn_loss_notreduced=trn_loss_notreduced))
 
-        if (epoch + 1) % args.save_every_epoch == 0:
-            save_checkpoint(os.path.join(args.log_dir, 'model.epoch_{}.pt'.format(epoch)), trn_dataset, net, metrics=dict(epoch=epoch, trn_loss=trn_loss, trn_loss_notreduced=trn_loss_notreduced))
+        if epoch % args.save_every_epoch == 0:
+            save_checkpoint(os.path.join(args.log_dir, 'model.epoch_{}.pt'.format(epoch)), trn_dataset, net, opt, metrics=dict(epoch=epoch, trn_loss=trn_loss, trn_loss_notreduced=trn_loss_notreduced))
 
         # VALIDATION
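
Reviewer note (illustration, not part of the patch): the change has two parts.
save_checkpoint now writes the optimizer state under a new 'opt_state_dict'
key, and the optimizer is constructed before load_checkpoint runs, so Adam's
per-parameter moment buffers can be restored together with the model weights.
The patch also switches periodic snapshots from (epoch + 1) to epoch modulo
args.save_every_epoch, so they land on epochs 0, N, 2N, and so on. Below is a
minimal, self-contained sketch of the same save/resume round trip; the
nn.Linear stand-in, lr value, and /tmp path are placeholders, and only the
checkpoint keys ('state_dict', 'opt_state_dict', 'metrics') come from the
patch itself:

    import torch
    import torch.nn as nn
    import torch.optim as optim

    def save_checkpoint(path, net, opt, metrics=None):
        # Mirrors the patched save_checkpoint: model weights, optimizer
        # state, and metrics all go into a single file.
        torch.save({
            'state_dict': net.state_dict(),
            'opt_state_dict': opt.state_dict(),  # new key added by this patch
            'metrics': metrics,
        }, path)

    def load_checkpoint(path, net, opt):
        ckpt = torch.load(path)
        net.load_state_dict(ckpt['state_dict'])
        try:
            # Without this, a resumed run silently restarts Adam's running
            # moment estimates from zero.
            opt.load_state_dict(ckpt['opt_state_dict'])
        except Exception as e:
            print(e)
            print('WARNING: Failed to load opt state dict. Be careful if resuming training.')

    # The optimizer must exist before loading, hence the reordering in main().
    net = nn.Linear(4, 2)  # stand-in for the aligner Net
    opt = optim.Adam(net.parameters(), lr=2e-3)
    save_checkpoint('/tmp/model.latest.pt', net, opt, metrics=dict(epoch=0))
    load_checkpoint('/tmp/model.latest.pt', net, opt)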