# Training cycle debugging

This notebook runs the code copied from `train_seq2seq.py` cell by cell, so that bugs in the training cycle can be caught and debugged interactively.

In [1]:
import os
import sys

# Export the location of the ALFRED checkout through the environment.
os.environ['ALFRED_ROOT'] = '/root/data/home/hoyeung/alfred/'

# Put the checkout root and its `models` sub-directory on the import path,
# in that order, so the project imports further down can be resolved.
alfred_root = os.environ['ALFRED_ROOT']
sys.path.extend([alfred_root, os.path.join(alfred_root, 'models')])

# from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import torch
import pprint
import json
from data.preprocess import Dataset
from importlib import import_module, reload
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from models.utils.helper_utils import optimizer_to

# Parser

In [2]:
# These are the default flags present in train_seq2seq.py
#
# An empty parser is used only to obtain a blank argparse.Namespace; every flag
# is then attached to it by hand so the values can be edited inline.
# Lines tagged `#!` are all boolean flags that are False here; the meaning of
# the tag itself is not documented in this notebook.

parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
# parse_args expects a list of strings; pass an explicit empty list rather than
# relying on '' happening to iterate to nothing.
args = parser.parse_args([])

# settings
args.seed = 123
args.data = '/root/data_alfred/json_feat_2.1.0'
args.splits = '/root/data_alfred/splits/may17.json'
args.object_vocab = 'objects_20200522'
args.preprocess = False #!
args.pp_folder = 'pp'
args.monitor_train_every = 10
args.save_every_epoch = False #!
args.model = 'seq2seq_per_subgoal'
args.gpu = True
args.dout = 'exp/model:seq2seq_per_subgoal'
args.resume = False #!

# hyper parameters
args.batch = 8
args.epoch = 20
args.lr = 1e-4
args.decay_epoch = 10
args.dhid = 512
args.dframe = 2500
args.demb = 100
args.pframe = 300
args.mask_loss_wt = 1.
args.action_loss_wt = 1.
args.subgoal_aux_loss_wt = 0.
args.pm_aux_loss_wt = 0.

# architecture ablations
# args.maxpool_over_object_states = False
# args.aux_loss_over_object_states = False
args.encoder_addons = 'none'
args.decoder_addons = 'none'
args.object_repr = 'type'
args.reweight_aux_bce = False

# dropouts
args.zero_goal = False #!
args.zero_instr = False #!
args.act_dropout = 0.
args.lang_dropout = 0.
args.input_dropout = 0.
args.vis_dropout = 0.3
args.hstate_dropout = 0.3
args.attn_dropout = 0.
args.actor_dropout = 0.
args.word_dropout = 0.

# other settings
args.train_teacher_forcing = False #!
args.train_student_forcing_prob = 0.1
args.temp_no_history = False #!

# debugging
args.fast_epoch = False #!
args.dataset_fraction = 0

In [3]:
# Session-specific overrides layered on top of the default flags.
# Writing into the namespace's attribute dict is equivalent to assigning
# each `args.<flag>` individually.
vars(args).update(
    # Turn this to True if running for the first time
    preprocess=False,

    # found under models/model/ directory
    model='seq2seq_per_subgoal',
    dout='/root/data_alfred/exp/model:seq2seq_per_subgoal_biattn_instance_fast_epoch',

    train_teacher_forcing=True,
    gpu=False,

    # light setup for debugging
    # Turn fast_epoch to False if running for the first time to preprocess data properly
    fast_epoch=True,
    epoch=5,
)

In [4]:
# Architecture ablation switches used for this run.
# The two object-state flags below are left commented out, as before:
#   args.maxpool_over_object_states = True
#   args.aux_loss_over_object_states = True
vars(args).update(
    encoder_addons='biattn_obj',
    decoder_addons='aux_loss',
    object_repr='instance',
    reweight_aux_bce=True,
)


In [5]:
# Seed torch's random number generator with the configured seed.
# The call is the last expression of the cell, so the returned Generator is shown as output.
torch.manual_seed(args.seed)

<torch._C.Generator at 0x7f724809d5d0>

# Setup and load data splits

In [6]:
# make output dir
# exist_ok=True turns this into a no-op when the directory already exists and
# removes the gap between the old isdir() check and the makedirs() call.
os.makedirs(args.dout, exist_ok=True)

# Echo the resolved output directory for the record.
print('args.out ', args.dout)

args.out  /root/data_alfred/exp/model:seq2seq_per_subgoal_biattn_instance_fast_epoch


In [7]:
# Read the train/valid/tests split definitions from the JSON file named in args.splits
with open(args.splits) as splits_file:
    splits = json.load(splits_file)

# Report how many entries each split contains
pprint.pprint({split_name: len(entries) for split_name, entries in splits.items()})

{'tests_seen': 1533,
 'tests_unseen': 1529,
 'train': 20806,
 'train_sanity': 246,
 'train_sanity_v1': 246,
 'valid_seen': 814,
 'valid_seen_v1': 249,
 'valid_unseen': 818,
 'valid_unseen_v1': 254}


In [8]:
# Preprocess (only needed once) and then load the vocabulary.
# A fresh preprocessing run reads the vocab back from the output directory;
# otherwise the vocab already stored next to the data is used.
if args.preprocess:
    print("\nPreprocessing dataset and saving to %s folders ... This will take a while. Do this once as required." % args.pp_folder)
    dataset = Dataset(args, None)
    dataset.preprocess_splits(splits)
    vocab_dir = args.dout
else:
    vocab_dir = args.data

vocab = torch.load(os.path.join(vocab_dir, "%s.vocab" % args.pp_folder))

print(vocab)

{'action_high': Vocab(93), 'word': Vocab(2360), 'action_low': Vocab(15)}


In [9]:
# Load the object vocabulary file named by args.object_vocab from the data directory
object_vocab_path = os.path.join(args.data, '%s.vocab' % args.object_vocab)
object_vocab = torch.load(object_vocab_path)

# Model and Training

In [10]:
# load the model architecture
# These two flags are set again here, overriding whatever the configuration
# cells above chose: run on CPU and start from scratch. Point args.resume at a
# checkpoint (see the commented path) to continue a previous run instead.
args.gpu = False
args.resume = None # os.path.join(args.dout, 'best_seen.pth')

# Import the module that defines the selected architecture; it must expose a
# `Module` class. Presumably resolved through the `models` directory that was
# appended to sys.path at the top of the notebook.
M = import_module('model.{}'.format(args.model))
if args.resume:
    # Restore model, optimizer and epoch counter from the checkpoint.
    print("Loading: " + args.resume)
    model, optimizer, start_epoch = M.Module.load(args.resume)
    print("Restarting at epoch {}/{}".format(start_epoch, args.epoch-1))
    if start_epoch >= args.epoch:
        # Nothing left to train; mirrors the early exit of the training script.
        print('Checkpoint already finished {}/{} epochs.'.format(start_epoch, args.epoch))
        sys.exit(0)
else:
    # Fresh model; no optimizer is passed on, and training starts at epoch 0.
    model = M.Module(args, vocab, object_vocab)
    optimizer = None
    start_epoch = 0

if args.gpu:
    # Move the model, and the restored optimizer if there is one, to the GPU.
    model = model.to(torch.device('cuda'))
    if optimizer is not None:
        optimizer_to(optimizer, torch.device('cuda'))

In [11]:
# examine model layers
# A bare expression as the last line makes the notebook render the model's repr as the cell output.
model

Module(
  (emb_word): Embedding(2360, 100)
  (emb_action_low): Embedding(15, 100)
  (emb_object): Embedding(111, 100, padding_idx=0)
  (instance_fc): Linear(in_features=201, out_features=512, bias=True)
  (enc): ActionFrameAttnEncoderPerSubgoalObjAttn(
    (emb): Embedding(15, 100)
    (vis_dropout): Dropout(p=0.3, inplace=False)
    (act_dropout): Dropout(p=0.0, inplace=True)
    (vis_encoder): ResnetVisualEncoder(
      (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
      (conv2): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
      (fc): Linear(in_features=3136, out_features=2500, bias=True)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (enc_att): SelfAttn(
      (scorer): Linear(in_features=1024, out_features=1, bias=True)
    )
    (encoder): LSTM(2600, 512, batch_first=True, bidirectional=True)
    (input_dropout)

In [12]:
# main training loop -- debug here if breakpoints were inserted
# Uses the splits loaded earlier together with the optimizer and start_epoch
# chosen when the model was built.
# NOTE(review): the saved output of this cell shows an epoch progress bar of 50
# while args.epoch is set to 5 earlier in the notebook -- the output may come
# from a different run; confirm by re-running.
model.run_train(splits, optimizer=optimizer, start_epoch=start_epoch)

epoch:   0%|          | 0/50 [00:00<?, ?it/s]
batch:   0%|          | 0/2 [00:00<?, ?it/s][A

Saving to: /root/data_alfred/exp/model:seq2seq_per_subgoal_biattn_instance_fast_epoch


  empty_tensor = torch.ones(torch.tensor(v[0][0][0]).unsqueeze(0).shape, device=device, dtype=torch.float if ('frames' in k) else torch.long) * self.pad
  seqs.append(torch.tensor(v[subgoal_i][batch_i], device=device, dtype=torch.float if ('frames' in k) else torch.long))

batch:  50%|█████     | 1/2 [00:16<00:16, 16.84s/it][A
batch: 100%|██████████| 2/2 [00:35<00:00, 17.51s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:03<00:03,  3.85s/it][A
batch: 100%|██████████| 2/2 [00:05<00:00,  2.85s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:04<00:04,  4.25s/it][A
batch: 100%|██████████| 2/2 [00:15<00:00,  7.66s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:04<00:04,  4.69s/it][A
batch: 100%|██████████| 2/2 [00:12<00:00,  6.10s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:19<00:19, 19.59s/it][A
batch: 100%|██████████| 2/


Found new best valid_seen!! Saving...
Found new best valid_unseen!! Saving...
Found new best train_sanity!! Saving...


epoch:   2%|▏         | 1/50 [03:17<2:41:02, 197.20s/it]
batch:   0%|          | 0/2 [00:00<?, ?it/s][A

{'epoch': 0,
 'train_sanity': {'BLEU': 2.150927457282457e-232,
                  'perplexity': 262.3821563720703,
                  'total_loss': 1471.2233238220215},
 'valid_seen': {'BLEU': 2.575713873220704e-232,
                'perplexity': 279.86607360839844,
                'total_loss': 1694.620002746582},
 'valid_unseen': {'BLEU': 2.656039007870951e-232,
                  'perplexity': 283.98423767089844,
                  'total_loss': 115.86846160888672}}
epoch_time                    197.196                                 
compute_metrics_validation_sets158.658                                 
forward_batch_train_with_iterate35.02                                   
forward_batch_train           34.605                                  
torch_save_valid_seen         1.91                                    
torch_save_valid_unseen       0.461                                   
iterate_featurize             0.436                                   
torch_save_last               


batch:  50%|█████     | 1/2 [00:18<00:18, 18.58s/it][A
batch: 100%|██████████| 2/2 [00:34<00:00, 17.09s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:04<00:04,  4.76s/it][A
batch: 100%|██████████| 2/2 [00:06<00:00,  3.36s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:04<00:04,  4.58s/it][A
batch: 100%|██████████| 2/2 [00:09<00:00,  4.59s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:04<00:04,  4.08s/it][A
batch: 100%|██████████| 2/2 [00:07<00:00,  3.89s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:17<00:17, 17.89s/it][A
batch: 100%|██████████| 2/2 [00:27<00:00, 13.89s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:23<00:23, 23.62s/it][A
batch: 100%|██████████| 2/2 [00:43<00:00, 21.82s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:2


Found new best valid_seen!! Saving...
Found new best valid_unseen!! Saving...
Found new best train_sanity!! Saving...


epoch:   4%|▍         | 2/50 [06:26<2:35:51, 194.83s/it]
batch:   0%|          | 0/2 [00:00<?, ?it/s][A

{'epoch': 1,
 'train_sanity': {'BLEU': 3.0002294792922473e-232,
                  'perplexity': 147.49627685546875,
                  'total_loss': 3726.3920612335205},
 'valid_seen': {'BLEU': 3.552694461707597e-232,
                'perplexity': 160.00331115722656,
                'total_loss': 4326.737593173981},
 'valid_unseen': {'BLEU': 3.5704321450811896e-232,
                  'perplexity': 155.53447723388672,
                  'total_loss': 11.868228912353516}}
epoch_time                    386.499                                 
compute_metrics_validation_sets296.367                                 
forward_batch_train_with_iterate69.208                                  
forward_batch_train           68.343                                  
torch_save_last               7.305                                   
torch_save_valid_seen         6.82                                    
torch_save_valid_unseen       3.85                                    
torch_save_train_sanity    


batch:  50%|█████     | 1/2 [00:24<00:24, 24.24s/it][A
batch: 100%|██████████| 2/2 [00:41<00:00, 20.74s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:05<00:05,  5.66s/it][A
batch: 100%|██████████| 2/2 [00:09<00:00,  4.80s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:07<00:07,  7.65s/it][A
batch: 100%|██████████| 2/2 [00:12<00:00,  6.32s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:07<00:07,  7.13s/it][A
batch: 100%|██████████| 2/2 [00:12<00:00,  6.35s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:40<00:40, 40.34s/it][A
batch: 100%|██████████| 2/2 [00:50<00:00, 25.34s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [00:39<00:39, 39.43s/it][A
batch: 100%|██████████| 2/2 [01:23<00:00, 41.55s/it][A

batch:   0%|          | 0/2 [00:00<?, ?it/s][A
batch:  50%|█████     | 1/2 [01:1

In [13]:
# Load the epoch-0 debug predictions for the train_sanity split.
# The file lives in the run's output directory (presumably written during
# model.run_train), so the path is derived from args.dout instead of repeating
# the absolute directory by hand.
debug_preds_path = os.path.join(args.dout, 'train_sanity.debug_epoch_0.preds.json')
with open(debug_preds_path, 'r') as f:
    debug_dat = json.load(f)

In [17]:
# List which fields were recorded for one example trial in the debug predictions.
debug_dat['trial_T20190907_200104_945677_0'].keys()

dict_keys(['root', 'action_low', 'action_high', 'p_lang_instr', 'obj_token_id', 'p_obj_vis', 'p_state_change', 'l_obj_vis', 'l_state_change'])

In [19]:
# Show the 'p_state_change' entry of the same example trial.
# NOTE(review): judging by the name and the values in the saved output, these look like
# predicted state-change probabilities -- confirm against the model code.
debug_dat['trial_T20190907_200104_945677_0']['p_state_change']

{'0': [0.0691295862197876,
  0.31164777278900146,
  0.9970205426216125,
  0.015615975484251976,
  0.014363543130457401,
  0.17510473728179932,
  0.0005708530079573393,
  0.08195849508047104,
  0.9996793270111084,
  0.9995750784873962,
  0.9997149109840393,
  0.9995577931404114,
  0.9995660185813904,
  0.9997116923332214,
  0.9995922446250916,
  0.999503493309021,
  0.9993239641189575,
  0.9991711378097534,
  0.9988883137702942,
  0.9985225796699524,
  0.9984582662582397,
  0.49693214893341064,
  0.6733318567276001,
  0.6518407464027405,
  0.4758867025375366,
  0.3146660625934601,
  0.9996658563613892,
  2.830060475389473e-05,
  9.804693945625331e-06,
  0.06095347926020622,
  0.9994327425956726,
  9.529480303172022e-05,
  4.818804154638201e-05,
  0.999860405921936,
  0.9998681545257568,
  0.9996261596679688,
  0.9994298815727234,
  0.903156042098999,
  0.9140000343322754,
  0.0016576650086790323,
  0.5512198805809021,
  0.819766640663147,
  0.6092637181282043,
  0.9517333507537842,
  0.