# Training cycle debugging

This notebook runs the code copied from `train_seq2seq.py` cell by cell, so that bugs in the training cycle can be caught interactively.

In [1]:
import os
import sys

# Root of the ALFRED checkout. A value already exported in the environment is
# respected; the path below is only a fallback for the original author's machine.
os.environ.setdefault('ALFRED_ROOT', '/home/hoyeung/alfred/')

# Make the repo root and its models/ directory importable. The membership check
# keeps sys.path free of duplicate entries when this cell is re-run.
for alfred_path in (os.environ['ALFRED_ROOT'],
                    os.path.join(os.environ['ALFRED_ROOT'], 'models')):
    if alfred_path not in sys.path:
        sys.path.append(alfred_path)

# from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import torch
import pprint
import json
from data.preprocess import Dataset
from importlib import import_module, reload
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from models.utils.helper_utils import optimizer_to

In [2]:
# import these if we want to debug a model such as seq2seq_nl_baseline.py

import torch
import pprint
import json
from data.preprocess import Dataset
from importlib import import_module, reload
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from models.utils.helper_utils import optimizer_to

# Parser

In [3]:
# These are the default flags present in train_seq2seq.py
# (entries marked "#!" carry the same marker as the original per-flag assignments)

parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
# no command line inside a notebook: parse an empty argument list to get a blank namespace
args = parser.parse_args('')

# populate the namespace in one go; keyword order is the attribute order
vars(args).update(
    # settings
    seed=123,
    data='data/json_feat_2.1.0',
    splits='data/splits/oct21.json',
    preprocess=False,  #!
    pp_folder='pp',
    monitor_train_every=10,
    save_every_epoch=False,  #!
    model='seq2seq_nl_baseline',
    gpu=True,
    dout='exp/model:seq2seq_nl_baseline',
    resume=False,  #!

    # hyper parameters
    batch=8,
    epoch=20,
    lr=1e-4,
    decay_epoch=10,
    dhid=512,
    dframe=2500,
    demb=100,
    pframe=300,
    mask_loss_wt=1.,
    action_loss_wt=1.,
    subgoal_aux_loss_wt=0.,
    pm_aux_loss_wt=0.,

    # dropouts
    zero_goal=False,  #!
    zero_instr=False,  #!
    act_dropout=0.,
    lang_dropout=0.,
    input_dropout=0.,
    vis_dropout=0.3,
    hstate_dropout=0.3,
    attn_dropout=0.,
    actor_dropout=0.,
    word_dropout=0.,

    # other settings
    dec_teacher_forcing=False,  #!
    temp_no_history=False,  #!

    # debugging
    fast_epoch=False,  #!
    dataset_fraction=0,
)

In [4]:
# overwrite the default flags

vars(args).update(
    preprocess=False,  # Turn this to True if running for the first time

    model='seq2seq_nl_baseline',  # found under models/model/ directory
    dout='exp/model:seq2seq_nl_baseline',

    dec_teacher_forcing=True,
    gpu=False,

    # light setup for debugging
    fast_epoch=True,  # Turn this to False if running for the first time to preprocess data properly
    epoch=20,
)

In [5]:
# Seed torch's random number generator with the configured args.seed.
# The call is the last expression of the cell, so its return value is displayed.
torch.manual_seed(args.seed)

<torch._C.Generator at 0x7f92fc05f670>

# Setup and load data splits

In [6]:
# make output dir
# exist_ok=True creates the directory (and any missing parents) when needed and is a
# no-op when it already exists, so the cell can be re-run without a separate
# isdir() check and without a check-then-create race.
os.makedirs(args.dout, exist_ok=True)

print('args.out ', args.dout)

args.out  exp/model:seq2seq_nl_baseline


In [7]:
# load train/valid/tests splits
with open(args.splits) as f:
    splits = json.load(f)

# summarise the splits: one entry per split name, mapped to the number of items it holds
pprint.pprint(dict((name, len(entries)) for name, entries in splits.items()))

{'tests_seen': 1533,
 'tests_unseen': 1529,
 'train': 21023,
 'valid_seen': 820,
 'valid_unseen': 821}


In [8]:
# preprocess and save -- only need to preprocess once
if not args.preprocess:
    # nothing to preprocess: read the vocab stored under the data folder
    vocab_dir = args.data
else:
    print("\nPreprocessing dataset and saving to %s folders ... This will take a while. Do this once as required." % args.pp_folder)
    dataset = Dataset(args, None)
    dataset.preprocess_splits(splits)
    # after preprocessing, the vocab is read from the output folder instead
    vocab_dir = args.dout

vocab = torch.load(os.path.join(vocab_dir, "%s.vocab" % args.pp_folder))

print(vocab)

{'word': Vocab(2360), 'action_low': Vocab(15), 'action_high': Vocab(93)}


# Model and Training

In [9]:
# load the model architecture
args.gpu = False
args.resume = None # os.path.join(args.dout, 'best_seen.pth')

# args.model names a module inside the `model` package
M = import_module('model.{}'.format(args.model))
if not args.resume:
    # fresh run: new model, no optimizer yet, start counting epochs from zero
    model, optimizer, start_epoch = M.Module(args, vocab), None, 0
else:
    # resume: the checkpoint supplies the model, optimizer and epoch to restart at
    print("Loading: " + args.resume)
    model, optimizer, start_epoch = M.Module.load(args.resume)
    print("Restarting at epoch {}/{}".format(start_epoch, args.epoch-1))
    if start_epoch >= args.epoch:
        print('Checkpoint already finished {}/{} epochs.'.format(start_epoch, args.epoch))
        sys.exit(0)

if args.gpu:
    cuda = torch.device('cuda')
    model = model.to(cuda)
    # an optimizer only exists when resuming; move it to the same device as the model
    if optimizer is not None:
        optimizer_to(optimizer, cuda)

In [10]:
# examine model layers -- the bare expression on the last line of the cell
# makes the notebook display the module's repr
model

Module(
  (emb_word): Embedding(2360, 100)
  (emb_action_low): Embedding(15, 100)
  (enc): LSTM(100, 512, batch_first=True, bidirectional=True)
  (enc_att): SelfAttn(
    (scorer): Linear(in_features=1024, out_features=1, bias=True)
  )
  (dec): LanguageDecoder(
    (emb): Embedding(2360, 100)
    (cell): LSTMCell(1124, 1024)
    (attn): DotAttn()
    (input_dropout): Dropout(p=0.0, inplace=False)
    (attn_dropout): Dropout(p=0.0, inplace=False)
    (hstate_dropout): Dropout(p=0.3, inplace=False)
    (word_dropout): Dropout(p=0.0, inplace=False)
    (word): Linear(in_features=2148, out_features=100, bias=True)
    (h_tm1_fc): Linear(in_features=1024, out_features=1024, bias=True)
  )
  (act_dropout): Dropout(p=0.0, inplace=True)
)

In [11]:
# main training loop -- debug here if breakpoints were inserted
# Receives the loaded splits together with the optimizer and start_epoch chosen when
# the model was built (None and 0 for a fresh run, checkpoint values when resuming).
model.run_train(splits, optimizer=optimizer, start_epoch=start_epoch)

epoch:   0%|          | 0/20 [00:00<?, ?it/s]
batch:   0%|          | 0/2 [00:00<?, ?it/s][A

Saving to: exp/model:seq2seq_nl_baseline



batch:  50%|█████     | 1/2 [00:04<00:04,  4.96s/it][A
batch: 100%|██████████| 2/2 [00:10<00:00,  5.35s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:03<00:03,  3.95s/it][A
batch: 100%|██████████| 2/2 [00:06<00:00,  3.21s/it][A
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
ba


Found new best valid_seen!! Saving...
Found new best valid_unseen!! Saving...


epoch:   5%|▌         | 1/20 [00:29<09:26, 29.80s/it]
batch:   0%|          | 0/2 [00:00<?, ?it/s][A

{'epoch': 0,
 'train': {'loss_lang_instr': 5.071922063827515,
           'total_loss': 5.071922063827515},
 'valid_seen': {'BLEU': 3.471718950714187e-232,
                'loss_lang_instr': 4.792473793029785,
                'total_loss': 4.792473793029785},
 'valid_unseen': {'BLEU': 2.8859132646597725e-232,
                  'loss_lang_instr': 4.029761552810669,
                  'total_loss': 4.029761552810669}}
epoch_time                    29.8                                    
forward_batch_train           10.712                                  
forward_batch_valid_seen      6.431                                   
forward_batch_valid_unseen    4.808                                   
compute_metrics_valid_seen    3.995                                   
compute_metrics_valid_unseen  3.379                                   
torch_save_valid_unseen       0.114                                   
torch_save_valid_seen         0.113                                   
torch_save_las


batch:  50%|█████     | 1/2 [00:02<00:02,  2.88s/it][A
batch: 100%|██████████| 2/2 [00:05<00:00,  2.76s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.36s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.24s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.04s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.07s/it][A



Found new best valid_seen!! Saving...
Found new best valid_unseen!! Saving...


epoch:  10%|█         | 2/20 [00:45<07:38, 25.49s/it]
batch:   0%|          | 0/2 [00:00<?, ?it/s][A

{'epoch': 1,
 'train': {'loss_lang_instr': 4.661640405654907,
           'total_loss': 4.661640405654907},
 'valid_seen': {'BLEU': 1.0420441773186205e-156,
                'loss_lang_instr': 4.327688097953796,
                'total_loss': 4.327688097953796},
 'valid_unseen': {'BLEU': 6.4378053614892896e-232,
                  'loss_lang_instr': 3.788380742073059,
                  'total_loss': 3.788380742073059}}
epoch_time                    45.241                                  
forward_batch_train           16.23                                   
forward_batch_valid_seen      8.919                                   
forward_batch_valid_unseen    6.972                                   
compute_metrics_valid_seen    4.249                                   
compute_metrics_valid_unseen  3.649                                   
torch_save_valid_seen         3.436                                   
torch_save_last               1.241                                   
torch_save_va


batch:  50%|█████     | 1/2 [00:03<00:03,  3.98s/it][A
batch: 100%|██████████| 2/2 [00:07<00:00,  3.87s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.43s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.43s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.59s/it][A
batch: 100%|██████████| 2/2 [00:03<00:00,  1.66s/it][A



Found new best valid_seen!! Saving...
Found new best valid_unseen!! Saving...


epoch:  15%|█▌        | 3/20 [01:00<06:23, 22.53s/it]
batch:   0%|          | 0/2 [00:00<?, ?it/s][A

{'epoch': 2,
 'train': {'loss_lang_instr': 4.332609176635742,
           'total_loss': 4.332609176635742},
 'valid_seen': {'BLEU': 1.1532114505674793e-156,
                'loss_lang_instr': 4.21867561340332,
                'total_loss': 4.21867561340332},
 'valid_unseen': {'BLEU': 1.1667015974228088e-156,
                  'loss_lang_instr': 3.5581694841384888,
                  'total_loss': 3.5581694841384888}}
epoch_time                    60.856                                  
forward_batch_train           23.966                                  
forward_batch_valid_seen      11.81                                   
forward_batch_valid_unseen    10.322                                  
compute_metrics_valid_seen    4.519                                   
torch_save_valid_seen         4.102                                   
compute_metrics_valid_unseen  3.907                                   
torch_save_last               1.393                                   
torch_save_va


batch:  50%|█████     | 1/2 [00:03<00:03,  3.38s/it][A
batch: 100%|██████████| 2/2 [00:06<00:00,  3.38s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.35s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.27s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.40s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.36s/it][A


Found new best valid_unseen!! Saving...


epoch:  20%|██        | 4/20 [01:14<05:17, 19.84s/it]
batch:   0%|          | 0/2 [00:00<?, ?it/s][A

{'epoch': 3,
 'train': {'loss_lang_instr': 4.126596212387085,
           'total_loss': 4.126596212387085},
 'valid_seen': {'BLEU': 1.1532114505674793e-156,
                'loss_lang_instr': 4.300110816955566,
                'total_loss': 4.300110816955566},
 'valid_unseen': {'BLEU': 2.1693517720781004e-156,
                  'loss_lang_instr': 3.5424221754074097,
                  'total_loss': 3.5424221754074097}}
epoch_time                    74.406                                  
forward_batch_train           30.726                                  
forward_batch_valid_seen      14.371                                  
forward_batch_valid_unseen    13.071                                  
compute_metrics_valid_seen    4.793                                   
compute_metrics_valid_unseen  4.167                                   
torch_save_valid_seen         4.102                                   
torch_save_last               1.541                                   
torch_save_


batch:  50%|█████     | 1/2 [00:03<00:03,  3.25s/it][A
batch: 100%|██████████| 2/2 [00:07<00:00,  3.53s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.32s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.26s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.26s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.26s/it][A



Found new best valid_seen!! Saving...
Found new best valid_unseen!! Saving...


epoch:  25%|██▌       | 5/20 [01:28<04:30, 18.03s/it]
batch:   0%|          | 0/2 [00:00<?, ?it/s][A

{'epoch': 4,
 'train': {'loss_lang_instr': 4.101462244987488,
           'total_loss': 4.101462244987488},
 'valid_seen': {'BLEU': 3.05968150351665e-80,
                'loss_lang_instr': 4.090612888336182,
                'total_loss': 4.090612888336182},
 'valid_unseen': {'BLEU': 6.397363076815268e-80,
                  'loss_lang_instr': 3.4359012842178345,
                  'total_loss': 3.4359012842178345}}
epoch_time                    88.218                                  
forward_batch_train           37.782                                  
forward_batch_valid_seen      16.907                                  
forward_batch_valid_unseen    15.601                                  
compute_metrics_valid_seen    5.069                                   
torch_save_valid_seen         4.758                                   
compute_metrics_valid_unseen  4.477                                   
torch_save_last               1.692                                   
torch_save_valid


batch:  50%|█████     | 1/2 [00:02<00:02,  2.85s/it][A
batch: 100%|██████████| 2/2 [00:07<00:00,  3.73s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.39s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.42s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.20s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.31s/it][A


Found new best valid_unseen!! Saving...


epoch:  30%|███       | 6/20 [01:42<03:57, 16.98s/it]
batch:   0%|          | 0/2 [00:00<?, ?it/s][A

{'epoch': 5,
 'train': {'loss_lang_instr': 3.7053922414779663,
           'total_loss': 3.7053922414779663},
 'valid_seen': {'BLEU': 3.05968150351665e-80,
                'loss_lang_instr': 4.1355592012405396,
                'total_loss': 4.1355592012405396},
 'valid_unseen': {'BLEU': 6.461741836178205e-80,
                  'loss_lang_instr': 3.4506880044937134,
                  'total_loss': 3.4506880044937134}}
epoch_time                    102.734                                 
forward_batch_train           45.247                                  
forward_batch_valid_seen      19.764                                  
forward_batch_valid_unseen    18.245                                  
compute_metrics_valid_seen    5.333                                   
compute_metrics_valid_unseen  4.76                                    
torch_save_valid_seen         4.758                                   
torch_save_valid_unseen       1.887                                   
torch_save_l


batch:  50%|█████     | 1/2 [00:05<00:05,  5.35s/it][A
batch: 100%|██████████| 2/2 [00:08<00:00,  4.22s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.41s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.33s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.50s/it][A
batch: 100%|██████████| 2/2 [00:04<00:00,  2.05s/it][A



Found new best valid_seen!! Saving...
Found new best valid_unseen!! Saving...


epoch:  35%|███▌      | 7/20 [01:59<03:40, 16.96s/it]
batch:   0%|          | 0/2 [00:00<?, ?it/s][A

{'epoch': 6,
 'train': {'loss_lang_instr': 3.6735459566116333,
           'total_loss': 3.6735459566116333},
 'valid_seen': {'BLEU': 3.3230108299214737e-80,
                'loss_lang_instr': 4.290937900543213,
                'total_loss': 4.290937900543213},
 'valid_unseen': {'BLEU': 6.95797567463231e-80,
                  'loss_lang_instr': 3.5591847896575928,
                  'total_loss': 3.5591847896575928}}
epoch_time                    119.668                                 
forward_batch_train           53.698                                  
forward_batch_valid_seen      22.45                                   
forward_batch_valid_unseen    22.375                                  
compute_metrics_valid_seen    5.613                                   
torch_save_valid_seen         5.415                                   
compute_metrics_valid_unseen  5.034                                   
torch_save_valid_unseen       2.036                                   
torch_save_la


batch:  50%|█████     | 1/2 [00:03<00:03,  3.70s/it][A
batch: 100%|██████████| 2/2 [00:08<00:00,  4.06s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.45s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.38s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.23s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.22s/it][A



Found new best valid_seen!! Saving...
Found new best valid_unseen!! Saving...


epoch:  40%|████      | 8/20 [02:14<03:16, 16.40s/it]
batch:   0%|          | 0/2 [00:00<?, ?it/s][A

{'epoch': 7,
 'train': {'loss_lang_instr': 3.5684820413589478,
           'total_loss': 3.5684820413589478},
 'valid_seen': {'BLEU': 0.0010979161592196537,
                'loss_lang_instr': 4.2684584856033325,
                'total_loss': 4.2684584856033325},
 'valid_unseen': {'BLEU': 8.332688496156772e-80,
                  'loss_lang_instr': 3.5604666471481323,
                  'total_loss': 3.5604666471481323}}
epoch_time                    134.75                                  
forward_batch_train           61.819                                  
forward_batch_valid_seen      25.224                                  
forward_batch_valid_unseen    24.836                                  
torch_save_valid_seen         6.073                                   
compute_metrics_valid_seen    5.913                                   
compute_metrics_valid_unseen  5.314                                   
torch_save_valid_unseen       2.186                                   
torch_save_


batch:  50%|█████     | 1/2 [00:03<00:03,  3.21s/it][A
batch: 100%|██████████| 2/2 [00:06<00:00,  3.44s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.39s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.26s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.35s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.25s/it][A


Found new best valid_unseen!! Saving...


epoch:  45%|████▌     | 9/20 [02:28<02:50, 15.50s/it]
batch:   0%|          | 0/2 [00:00<?, ?it/s][A

{'epoch': 8,
 'train': {'loss_lang_instr': 3.420241594314575,
           'total_loss': 3.420241594314575},
 'valid_seen': {'BLEU': 1.2449261541901141e-79,
                'loss_lang_instr': 4.183590531349182,
                'total_loss': 4.183590531349182},
 'valid_unseen': {'BLEU': 1.1007265126069781e-79,
                  'loss_lang_instr': 3.5060874223709106,
                  'total_loss': 3.5060874223709106}}
epoch_time                    148.159                                 
forward_batch_train           68.699                                  
forward_batch_valid_seen      27.762                                  
forward_batch_valid_unseen    27.346                                  
compute_metrics_valid_seen    6.178                                   
torch_save_valid_seen         6.073                                   
compute_metrics_valid_unseen  5.578                                   
torch_save_valid_unseen       2.841                                   
torch_save_la


batch:  50%|█████     | 1/2 [00:03<00:03,  3.51s/it][A
batch: 100%|██████████| 2/2 [00:06<00:00,  3.06s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.27s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.26s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.31s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.31s/it][A


Found new best valid_unseen!! Saving...


epoch:  50%|█████     | 10/20 [02:41<02:27, 14.71s/it]
batch:   0%|          | 0/2 [00:00<?, ?it/s][A

{'epoch': 9,
 'train': {'loss_lang_instr': 3.5509294271469116,
           'total_loss': 3.5509294271469116},
 'valid_seen': {'BLEU': 1.7796410444059788e-79,
                'loss_lang_instr': 4.20244288444519,
                'total_loss': 4.20244288444519},
 'valid_unseen': {'BLEU': 0.0012093792678755603,
                  'loss_lang_instr': 3.4363412857055664,
                  'total_loss': 3.4363412857055664}}
epoch_time                    161.009                                 
forward_batch_train           74.813                                  
forward_batch_valid_seen      30.294                                  
forward_batch_valid_unseen    29.987                                  
compute_metrics_valid_seen    6.487                                   
torch_save_valid_seen         6.073                                   
compute_metrics_valid_unseen  5.878                                   
torch_save_valid_unseen       3.504                                   
torch_save_las


batch:  50%|█████     | 1/2 [00:04<00:04,  4.46s/it][A
batch: 100%|██████████| 2/2 [00:08<00:00,  4.42s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.24s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.19s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.33s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.46s/it][A



Found new best valid_seen!! Saving...


epoch:  55%|█████▌    | 11/20 [02:57<02:16, 15.15s/it]
batch:   0%|          | 0/2 [00:00<?, ?it/s][A

{'epoch': 10,
 'train': {'BLEU': 0.007043339815624692,
           'loss_lang_instr': 3.239856481552124,
           'total_loss': 3.239856481552124},
 'valid_seen': {'BLEU': 0.001420766451137129,
                'loss_lang_instr': 4.201016783714294,
                'total_loss': 4.201016783714294},
 'valid_unseen': {'BLEU': 0.0012093792678755603,
                  'loss_lang_instr': 3.4575302600860596,
                  'total_loss': 3.4575302600860596}}
epoch_time                    177.183                                 
forward_batch_train           83.648                                  
forward_batch_valid_unseen    32.938                                  
forward_batch_valid_seen      32.699                                  
compute_metrics_valid_seen    6.775                                   
torch_save_valid_seen         6.732                                   
compute_metrics_valid_unseen  6.156                                   
torch_save_valid_unseen       3.504          


batch:  50%|█████     | 1/2 [00:03<00:03,  3.68s/it][A
batch: 100%|██████████| 2/2 [00:07<00:00,  3.60s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.48s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.38s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.18s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.18s/it][A



Found new best valid_seen!! Saving...


epoch:  60%|██████    | 12/20 [03:11<01:58, 14.77s/it]
batch:   0%|          | 0/2 [00:00<?, ?it/s][A

{'epoch': 11,
 'train': {'loss_lang_instr': 3.269356608390808,
           'total_loss': 3.269356608390808},
 'valid_seen': {'BLEU': 0.0017290277616482297,
                'loss_lang_instr': 4.214152097702026,
                'total_loss': 4.214152097702026},
 'valid_unseen': {'BLEU': 0.0012093792678755603,
                  'loss_lang_instr': 3.4412035942077637,
                  'total_loss': 3.4412035942077637}}
epoch_time                    191.073                                 
forward_batch_train           90.852                                  
forward_batch_valid_seen      35.474                                  
forward_batch_valid_unseen    35.324                                  
torch_save_valid_seen         7.386                                   
compute_metrics_valid_seen    7.072                                   
compute_metrics_valid_unseen  6.434                                   
torch_save_valid_unseen       3.504                                   
torch_save_las


batch:  50%|█████     | 1/2 [00:02<00:02,  2.94s/it][A
batch: 100%|██████████| 2/2 [00:05<00:00,  2.81s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.25s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.22s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.29s/it][A
batch: 100%|██████████| 2/2 [00:03<00:00,  1.52s/it][A
epoch:  65%|██████▌   | 13/20 [03:24<01:39, 14.21s/it]
batch:   0%|          | 0/2 [00:00<?, ?it/s][A

{'epoch': 12,
 'train': {'loss_lang_instr': 3.225303053855896,
           'total_loss': 3.225303053855896},
 'valid_seen': {'BLEU': 0.0017290277616482297,
                'loss_lang_instr': 4.218314051628113,
                'total_loss': 4.218314051628113},
 'valid_unseen': {'BLEU': 0.0012093792678755603,
                  'loss_lang_instr': 3.464127779006958,
                  'total_loss': 3.464127779006958}}
epoch_time                    203.978                                 
forward_batch_train           96.466                                  
forward_batch_valid_unseen    38.373                                  
forward_batch_valid_seen      37.922                                  
torch_save_valid_seen         7.386                                   
compute_metrics_valid_seen    7.367                                   
compute_metrics_valid_unseen  6.706                                   
torch_save_last               3.782                                   
torch_save_valid


batch:  50%|█████     | 1/2 [00:02<00:02,  2.57s/it][A
batch: 100%|██████████| 2/2 [00:05<00:00,  2.69s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.37s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.24s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.25s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.16s/it][A
epoch:  70%|███████   | 14/20 [03:35<01:20, 13.40s/it]
batch:   0%|          | 0/2 [00:00<?, ?it/s][A

{'epoch': 13,
 'train': {'loss_lang_instr': 3.228430986404419,
           'total_loss': 3.228430986404419},
 'valid_seen': {'BLEU': 0.0017290277616482297,
                'loss_lang_instr': 4.224886894226074,
                'total_loss': 4.224886894226074},
 'valid_unseen': {'BLEU': 0.0012093792678755603,
                  'loss_lang_instr': 3.4623749256134033,
                  'total_loss': 3.4623749256134033}}
epoch_time                    215.495                                 
forward_batch_train           101.84                                  
forward_batch_valid_unseen    40.706                                  
forward_batch_valid_seen      40.421                                  
compute_metrics_valid_seen    7.628                                   
torch_save_valid_seen         7.386                                   
compute_metrics_valid_unseen  6.966                                   
torch_save_last               4.435                                   
torch_save_val


batch:  50%|█████     | 1/2 [00:02<00:02,  2.60s/it][A
batch: 100%|██████████| 2/2 [00:06<00:00,  3.19s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:02<00:02,  2.15s/it][A
batch: 100%|██████████| 2/2 [00:04<00:00,  2.25s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.26s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.19s/it][A



Found new best valid_seen!! Saving...


epoch:  75%|███████▌  | 15/20 [03:50<01:09, 13.82s/it]
batch:   0%|          | 0/2 [00:00<?, ?it/s][A

{'epoch': 14,
 'train': {'loss_lang_instr': 3.2111645936965942,
           'total_loss': 3.2111645936965942},
 'valid_seen': {'BLEU': 0.004694297237361548,
                'loss_lang_instr': 4.19797158241272,
                'total_loss': 4.19797158241272},
 'valid_unseen': {'BLEU': 0.0012093792678755603,
                  'loss_lang_instr': 3.460483193397522,
                  'total_loss': 3.460483193397522}}
epoch_time                    230.298                                 
forward_batch_train           108.23                                  
forward_batch_valid_seen      44.942                                  
forward_batch_valid_unseen    43.106                                  
torch_save_valid_seen         8.048                                   
compute_metrics_valid_seen    7.891                                   
compute_metrics_valid_unseen  7.24                                    
torch_save_last               4.584                                   
torch_save_valid_


batch:  50%|█████     | 1/2 [00:02<00:02,  2.66s/it][A
batch: 100%|██████████| 2/2 [00:05<00:00,  2.79s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.28s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.19s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.16s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.25s/it][A



Found new best valid_seen!! Saving...
Found new best valid_unseen!! Saving...


epoch:  80%|████████  | 16/20 [04:02<00:53, 13.33s/it]
batch:   0%|          | 0/2 [00:00<?, ?it/s][A

{'epoch': 15,
 'train': {'loss_lang_instr': 3.230161666870117,
           'total_loss': 3.230161666870117},
 'valid_seen': {'BLEU': 0.005666077834081894,
                'loss_lang_instr': 4.2196513414382935,
                'total_loss': 4.2196513414382935},
 'valid_unseen': {'BLEU': 0.003183516167733599,
                  'loss_lang_instr': 3.4635472297668457,
                  'total_loss': 3.4635472297668457}}
epoch_time                    242.464                                 
forward_batch_train           113.817                                 
forward_batch_valid_seen      47.343                                  
forward_batch_valid_unseen    45.628                                  
torch_save_valid_seen         8.705                                   
compute_metrics_valid_seen    8.165                                   
compute_metrics_valid_unseen  7.513                                   
torch_save_last               4.735                                   
torch_save_val


batch:  50%|█████     | 1/2 [00:03<00:03,  3.11s/it][A
batch: 100%|██████████| 2/2 [00:05<00:00,  3.00s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.69s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.49s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.53s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.35s/it][A
epoch:  85%|████████▌ | 17/20 [04:15<00:39, 13.26s/it]
batch:   0%|          | 0/2 [00:00<?, ?it/s][A

{'epoch': 16,
 'train': {'loss_lang_instr': 3.169914484024048,
           'total_loss': 3.169914484024048},
 'valid_seen': {'BLEU': 0.005654025358317886,
                'loss_lang_instr': 4.202825427055359,
                'total_loss': 4.202825427055359},
 'valid_unseen': {'BLEU': 0.003183516167733599,
                  'loss_lang_instr': 3.486653447151184,
                  'total_loss': 3.486653447151184}}
epoch_time                    255.555                                 
forward_batch_train           119.82                                  
forward_batch_valid_seen      50.341                                  
forward_batch_valid_unseen    48.343                                  
torch_save_valid_seen         8.705                                   
compute_metrics_valid_seen    8.448                                   
compute_metrics_valid_unseen  7.809                                   
torch_save_last               5.39                                    
torch_save_valid_u


batch:  50%|█████     | 1/2 [00:02<00:02,  2.66s/it][A
batch: 100%|██████████| 2/2 [00:05<00:00,  2.92s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.38s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.22s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.10s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.11s/it][A



Found new best valid_seen!! Saving...
Found new best valid_unseen!! Saving...


epoch:  90%|█████████ | 18/20 [04:27<00:25, 12.94s/it]
batch:   0%|          | 0/2 [00:00<?, ?it/s][A

{'epoch': 17,
 'train': {'loss_lang_instr': 3.2095788717269897,
           'total_loss': 3.2095788717269897},
 'valid_seen': {'BLEU': 0.006230817273158473,
                'loss_lang_instr': 4.206936478614807,
                'total_loss': 4.206936478614807},
 'valid_unseen': {'BLEU': 0.005084008324076588,
                  'loss_lang_instr': 3.5179401636123657,
                  'total_loss': 3.5179401636123657}}
epoch_time                    267.76                                  
forward_batch_train           125.669                                 
forward_batch_valid_seen      52.802                                  
forward_batch_valid_unseen    50.581                                  
torch_save_valid_seen         9.36                                    
compute_metrics_valid_seen    8.72                                    
compute_metrics_valid_unseen  8.087                                   
torch_save_last               5.537                                   
torch_save_val


batch:  50%|█████     | 1/2 [00:02<00:02,  2.71s/it][A
batch: 100%|██████████| 2/2 [00:05<00:00,  2.93s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.17s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.07s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.20s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.16s/it][A
epoch:  95%|█████████▌| 19/20 [04:39<00:12, 12.57s/it]
batch:   0%|          | 0/2 [00:00<?, ?it/s][A

{'epoch': 18,
 'train': {'loss_lang_instr': 3.1863350868225098,
           'total_loss': 3.1863350868225098},
 'valid_seen': {'BLEU': 0.006230817273158473,
                'loss_lang_instr': 4.167171001434326,
                'total_loss': 4.167171001434326},
 'valid_unseen': {'BLEU': 0.002381648093789713,
                  'loss_lang_instr': 3.522326707839966,
                  'total_loss': 3.522326707839966}}
epoch_time                    279.445                                 
forward_batch_train           131.527                                 
forward_batch_valid_seen      54.948                                  
forward_batch_valid_unseen    52.916                                  
torch_save_valid_seen         9.36                                    
compute_metrics_valid_seen    8.983                                   
compute_metrics_valid_unseen  8.373                                   
torch_save_last               6.194                                   
torch_save_valid


batch:  50%|█████     | 1/2 [00:04<00:04,  4.80s/it][A
batch: 100%|██████████| 2/2 [00:09<00:00,  4.64s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.35s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.26s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:01<00:01,  1.15s/it][A
batch: 100%|██████████| 2/2 [00:02<00:00,  1.15s/it][A



Found new best valid_seen!! Saving...


epoch: 100%|██████████| 20/20 [04:57<00:00, 14.89s/it]

{'epoch': 19,
 'train': {'loss_lang_instr': 3.1608837842941284,
           'total_loss': 3.1608837842941284},
 'valid_seen': {'BLEU': 0.006351188978705782,
                'loss_lang_instr': 4.219351053237915,
                'total_loss': 4.219351053237915},
 'valid_unseen': {'BLEU': 0.001968472204978389,
                  'loss_lang_instr': 3.4981753826141357,
                  'total_loss': 3.4981753826141357}}
epoch_time                    297.739                                 
forward_batch_train           140.8                                   
forward_batch_valid_seen      57.485                                  
forward_batch_valid_unseen    55.246                                  
torch_save_valid_seen         11.96                                   
compute_metrics_valid_seen    9.276                                   
compute_metrics_valid_unseen  8.696                                   
torch_save_last               6.95                                    
torch_save_val


