In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules
# (including the project's local .py files) are reloaded before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
# !pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
# !pip install matplotlib numpy pandas tqdm nltk

# for separating ingredients vs non-ingredients
# NOTE: on Windows, GNU Wget must be installed first for the command below to work
# !wget -c https://raw.githubusercontent.com/williamLyh/RecipeWithPlans/main/ingredient_set.json -O ingredient_set.json

In [3]:
import json
import math
import os
import random
import re
import string

import nltk
import numpy as np
import pandas as pd
# torch and optim are used directly in this notebook; import them explicitly
# instead of relying on the star imports below to leak them into the namespace.
import torch
from nltk.translate import meteor
from nltk.translate.bleu_score import corpus_bleu, sentence_bleu
from torch import optim
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR, MultiStepLR, CosineAnnealingWarmRestarts

from data import *
from encoder_decoder import *
from train import *
from eval import *
from utils import *

# required for METEOR (nltk's METEOR implementation looks up WordNet synonyms); BLEU does not need it
# nltk.download("wordnet")

  from .autonotebook import tqdm as notebook_tqdm


---

In [4]:
# ---- Global configuration for the whole notebook ----
SEED = 31989101
HIDDEN_SIZE = 256
MAX_INGR_LEN = 150  # fixed from assignment
MAX_RECIPE_LEN = 600
DROPOUT = 0.1

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")


def reset_rng():
    """Re-seed the PyTorch, NumPy and built-in ``random`` generators with ``SEED``.

    Called once below, and again later in the notebook before a new model is
    built, so each experiment starts from the same RNG state.
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(SEED)


# Seed once up front so everything that follows starts from a known state.
reset_rng()

# Widen pandas' column display so long ingredient lists and instructions can be read in full.
pd.set_option('display.max_colwidth', 2000)

print(f"Using device: {DEVICE}")

Using device: cuda


In [5]:
data_root = "./Cooking_Dataset"
add_intermediate_tag = False

# Read the raw train / dev / test CSVs, keeping only the Ingredients and Recipe columns.
train_df_orig, dev_df_orig, test_df_orig = (
    pd.read_csv(os.path.join(data_root, csv_name), usecols=['Ingredients', 'Recipe'])
    for csv_name in ("train.csv", "dev.csv", "test.csv")
)

In [6]:
# Preprocess the raw training split; the call reports the sample count before and after.
train_df = preprocess_data(
    train_df_orig,
    max_ingr_len=MAX_INGR_LEN,
    max_recipe_len=MAX_RECIPE_LEN,
    add_intermediate_tag=add_intermediate_tag,
)

Number of data samples before preprocessing: 101340
Number of data samples after preprocessing: 100637 (99.306%)


In [7]:
# Preprocess the raw dev split with the same limits as the training split.
dev_df = preprocess_data(
    dev_df_orig,
    max_ingr_len=MAX_INGR_LEN,
    max_recipe_len=MAX_RECIPE_LEN,
    add_intermediate_tag=add_intermediate_tag,
)

Number of data samples before preprocessing: 797
Number of data samples after preprocessing: 793 (99.498%)


In [8]:
# Preprocess the raw test split with the same limits as the training split.
test_df = preprocess_data(
    test_df_orig,
    max_ingr_len=MAX_INGR_LEN,
    max_recipe_len=MAX_RECIPE_LEN,
    add_intermediate_tag=add_intermediate_tag,
)

Number of data samples before preprocessing: 778
Number of data samples after preprocessing: 774 (99.486%)


In [9]:
# Build the vocabulary from the preprocessed training split only
# (dev and test are not passed to populate()).
vocab = Vocabulary(add_intermediate_tag=add_intermediate_tag)
vocab.populate(train_df)
# Last expression of the cell: displays vocab.n_unique_words as the cell output.
vocab.n_unique_words

  2%|▏         | 1555/100637 [00:00<00:06, 15547.09it/s]

100%|██████████| 100637/100637 [00:06<00:00, 15152.13it/s]


44315

In [10]:
# Wrap each preprocessed split in a RecipeDataset that shares the training vocabulary.
train_ds = RecipeDataset(train_df, vocab)
# The dev split is wrapped twice: with train=True it is used for getting the
# validation loss, with train=False for getting validation BLEU and other metrics.
dev_ds_val_loss = RecipeDataset(dev_df, vocab, train=True)
dev_ds_val_met = RecipeDataset(dev_df, vocab, train=False)
test_ds = RecipeDataset(test_df, vocab, train=False)

## Encoder-Decoder (Base)

In [11]:
embedding_size = 300

# Baseline encoder; padding_value is the vocabulary index of PAD_WORD.
encoder = EncoderRNN(
    vocab.n_unique_words,
    embedding_size=embedding_size,
    hidden_size=HIDDEN_SIZE,
    padding_value=vocab.word2index(PAD_WORD),
).to(DEVICE)

# in the training script, decoder is always fed a non-end token and thus never needs to generate padding
# also it should never generate "<UNKNOWN>"
# NOTE(review): the comment above names two tokens to exclude, yet only 1 is subtracted
# from the vocabulary size below - confirm the index layout DecoderRNN/Vocabulary assume.
decoder = DecoderRNN(
    embedding_size=embedding_size,
    hidden_size=HIDDEN_SIZE,
    output_size=vocab.n_unique_words - 1,
).to(DEVICE)

In [14]:
# Optimisation hyper-parameters for the baseline (no attention) model.
initial_lr = 1e-3
min_lr = 1e-5
n_epochs = 30
batch_size = 128

# Encoder and decoder each get their own Adam optimizer.
encoder_optimizer = optim.Adam(encoder.parameters(), lr=initial_lr)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=initial_lr)

# Cosine annealing from initial_lr towards min_lr, with T_max set to the number of epochs.
enc_scheduler = CosineAnnealingLR(encoder_optimizer, T_max=n_epochs, eta_min=min_lr)
dec_scheduler = CosineAnnealingLR(decoder_optimizer, T_max=n_epochs, eta_min=min_lr)
# Alternative step schedule, kept for reference:
# enc_scheduler = MultiStepLR(encoder_optimizer, milestones=[15], gamma=0.1)
# dec_scheduler = MultiStepLR(decoder_optimizer, milestones=[15], gamma=0.1)

identifier = "adam_without_intermediate_tags_wd0_lr1e-3"

epoch_losses, val_epoch_losses, log = train(
    encoder,
    decoder,
    encoder_optimizer,
    decoder_optimizer,
    train_ds,
    n_epochs=n_epochs,
    vocab=vocab,
    decoder_mode="basic",
    batch_size=batch_size,
    enc_lr_scheduler=enc_scheduler,
    dec_lr_scheduler=dec_scheduler,
    dev_ds_val_loss=dev_ds_val_loss,
    dev_ds_val_met=dev_ds_val_met,
    identifier=identifier,
    verbose_iter_interval=10,
)

# NOTE(review): this line is only reached when train() returns normally; if the run is
# interrupted (as in the recorded output of this cell) the log is never written.
save_log(identifier, log, encoder_optimizer, decoder_optimizer, enc_scheduler, dec_scheduler)

Starting epoch 1/30, enc lr scheduler: [0.001], dec lr scheduler: [0.001]
(Epoch 0, iter 10/787) Average loss so far: 10.210
(Epoch 0, iter 20/787) Average loss so far: 6.945
(Epoch 0, iter 30/787) Average loss so far: 6.156
(Epoch 0, iter 40/787) Average loss so far: 6.180
(Epoch 0, iter 50/787) Average loss so far: 6.071
(Epoch 0, iter 60/787) Average loss so far: 6.040
(Epoch 0, iter 70/787) Average loss so far: 6.024
(Epoch 0, iter 80/787) Average loss so far: 5.991
(Epoch 0, iter 90/787) Average loss so far: 5.968
(Epoch 0, iter 100/787) Average loss so far: 5.922
(Epoch 0, iter 110/787) Average loss so far: 5.875
(Epoch 0, iter 120/787) Average loss so far: 5.845
(Epoch 0, iter 130/787) Average loss so far: 5.806
(Epoch 0, iter 140/787) Average loss so far: 5.735
(Epoch 0, iter 150/787) Average loss so far: 5.680
(Epoch 0, iter 160/787) Average loss so far: 5.623
(Epoch 0, iter 170/787) Average loss so far: 5.582
(Epoch 0, iter 180/787) Average loss so far: 5.544
(Epoch 0, iter 1

7it [00:03,  2.05it/s]


validation loss: 4.239502702440534


 14%|█▍        | 28/199 [00:15<01:36,  1.77it/s]


KeyboardInterrupt: 

In [None]:
# Save the baseline encoder/decoder pair under the baseline run identifier plus a "_last" suffix.
save_model(encoder, decoder, "adam_without_intermediate_tags_wd0_lr1e-3_last")

## Encoder-Decoder (Attention)

In [11]:
# Re-seed all RNGs before the attention model is built and trained.
reset_rng()

In [12]:
embedding_size = 300

# Encoder for the attention variant; padding_value is the vocabulary index of PAD_WORD.
encoder_attn = EncoderRNN(
    vocab.n_unique_words,
    embedding_size=embedding_size,
    hidden_size=HIDDEN_SIZE,
    padding_value=vocab.word2index(PAD_WORD),
).to(DEVICE)

# in the training script, decoder is always fed a non-end token and thus never needs to generate padding
# also it should never generate "<UNKNOWN>"
# decoder = DecoderRNN(embedding_size=embedding_size,hidden_size=HIDDEN_SIZE, output_size=vocab.n_unique_words-2).to(DEVICE)
# NOTE(review): the commented-out variant above subtracts 2 while the live code subtracts 1 -
# confirm which output size matches the token layout of Vocabulary.
decoder_attn = AttnDecoderRNN(
    embedding_size,
    hidden_size=HIDDEN_SIZE,
    output_size=vocab.n_unique_words - 1,
    padding_val=vocab.word2index(PAD_WORD),
    dropout=DROPOUT,
).to(DEVICE)

In [13]:
# Optimisation hyper-parameters for the attention model (same values as the baseline run).
initial_lr = 1e-3
min_lr = 1e-5
n_epochs = 30
batch_size = 128

# Encoder and decoder each get their own Adam optimizer.
encoder_attn_optimizer = optim.Adam(encoder_attn.parameters(), lr=initial_lr)
decoder_attn_optimizer = optim.Adam(decoder_attn.parameters(), lr=initial_lr)

# Cosine annealing from initial_lr towards min_lr, with T_max set to the number of epochs.
enc_attn_scheduler = CosineAnnealingLR(encoder_attn_optimizer, T_max=n_epochs, eta_min=min_lr)
dec_attn_scheduler = CosineAnnealingLR(decoder_attn_optimizer, T_max=n_epochs, eta_min=min_lr)

identifier = "attn_adam_without_intermediate_tags_wd0_lr1e-3"

attn_epoch_losses, attn_val_epoch_losses, attn_log = train(
    encoder_attn,
    decoder_attn,
    encoder_attn_optimizer,
    decoder_attn_optimizer,
    train_ds,
    n_epochs=n_epochs,
    vocab=vocab,
    decoder_mode="attention",
    batch_size=batch_size,
    enc_lr_scheduler=enc_attn_scheduler,
    dec_lr_scheduler=dec_attn_scheduler,
    dev_ds_val_loss=dev_ds_val_loss,
    dev_ds_val_met=dev_ds_val_met,
    identifier=identifier,
    verbose_iter_interval=50,
)

# NOTE(review): only reached when train() returns normally; an interrupted run leaves no saved log.
save_log(
    identifier,
    attn_log,
    encoder_attn_optimizer,
    decoder_attn_optimizer,
    enc_attn_scheduler,
    dec_attn_scheduler,
)

Starting epoch 1/30, enc lr scheduler: [0.001], dec lr scheduler: [0.001]
(Epoch 0, iter 50/787) Average loss so far: 6.972
(Epoch 0, iter 100/787) Average loss so far: 6.023
(Epoch 0, iter 150/787) Average loss so far: 5.968
(Epoch 0, iter 200/787) Average loss so far: 5.927
(Epoch 0, iter 250/787) Average loss so far: 5.877
(Epoch 0, iter 300/787) Average loss so far: 5.818
(Epoch 0, iter 350/787) Average loss so far: 5.785
(Epoch 0, iter 400/787) Average loss so far: 5.725
(Epoch 0, iter 450/787) Average loss so far: 5.683
(Epoch 0, iter 500/787) Average loss so far: 5.621
(Epoch 0, iter 550/787) Average loss so far: 5.551
(Epoch 0, iter 600/787) Average loss so far: 5.440
(Epoch 0, iter 650/787) Average loss so far: 5.360
(Epoch 0, iter 700/787) Average loss so far: 5.215
(Epoch 0, iter 750/787) Average loss so far: 5.093
Average epoch loss: 5.701
This epoch took 8.194433637460072 mins. Time remaining: 3.0 hrs 57.0 mins.


7it [00:01,  5.45it/s]


validation loss: 4.93167952128819


100%|██████████| 7/7 [00:07<00:00,  1.13s/it]
100%|██████████| 793/793 [00:08<00:00, 95.55it/s] 


BLEU score: 0.002800126985071343, METEOR score: 0.08576382563509898
Starting epoch 2/30, enc lr scheduler: [0.0009972883382072953], dec lr scheduler: [0.0009972883382072953]
(Epoch 1, iter 50/787) Average loss so far: 4.863
(Epoch 1, iter 100/787) Average loss so far: 4.706
(Epoch 1, iter 150/787) Average loss so far: 4.641
(Epoch 1, iter 200/787) Average loss so far: 4.523
(Epoch 1, iter 250/787) Average loss so far: 4.437
(Epoch 1, iter 300/787) Average loss so far: 4.366
(Epoch 1, iter 350/787) Average loss so far: 4.295
(Epoch 1, iter 400/787) Average loss so far: 4.248
(Epoch 1, iter 450/787) Average loss so far: 4.186
(Epoch 1, iter 500/787) Average loss so far: 4.133
(Epoch 1, iter 550/787) Average loss so far: 4.082
(Epoch 1, iter 600/787) Average loss so far: 4.041
(Epoch 1, iter 650/787) Average loss so far: 4.004
(Epoch 1, iter 700/787) Average loss so far: 3.956
(Epoch 1, iter 750/787) Average loss so far: 3.932
Average epoch loss: 4.277
This epoch took 8.073306282361349 mi

7it [00:01,  5.24it/s]


validation loss: 3.926896367754255


100%|██████████| 7/7 [00:07<00:00,  1.05s/it]
100%|██████████| 793/793 [00:07<00:00, 112.04it/s]


BLEU score: 0.005528415216107959, METEOR score: 0.1087509998355728
Starting epoch 3/30, enc lr scheduler: [0.0009891830623632338], dec lr scheduler: [0.0009891830623632338]
(Epoch 2, iter 50/787) Average loss so far: 3.866
(Epoch 2, iter 100/787) Average loss so far: 3.844
(Epoch 2, iter 150/787) Average loss so far: 3.818
(Epoch 2, iter 200/787) Average loss so far: 3.778
(Epoch 2, iter 250/787) Average loss so far: 3.752
(Epoch 2, iter 300/787) Average loss so far: 3.761
(Epoch 2, iter 350/787) Average loss so far: 3.706
(Epoch 2, iter 400/787) Average loss so far: 3.690
(Epoch 2, iter 450/787) Average loss so far: 3.667
(Epoch 2, iter 500/787) Average loss so far: 3.662
(Epoch 2, iter 550/787) Average loss so far: 3.666
(Epoch 2, iter 600/787) Average loss so far: 3.646
(Epoch 2, iter 650/787) Average loss so far: 3.610
(Epoch 2, iter 700/787) Average loss so far: 3.615
(Epoch 2, iter 750/787) Average loss so far: 3.572
Average epoch loss: 3.704
This epoch took 8.144790331522623 min

7it [00:01,  5.53it/s]


validation loss: 3.6217985493796214


100%|██████████| 7/7 [00:04<00:00,  1.51it/s]
100%|██████████| 793/793 [00:03<00:00, 240.20it/s]


BLEU score: 0.015059414822228353, METEOR score: 0.1608885573065793
Starting epoch 4/30, enc lr scheduler: [0.0009757729755661011], dec lr scheduler: [0.0009757729755661011]
(Epoch 3, iter 50/787) Average loss so far: 3.552
(Epoch 3, iter 100/787) Average loss so far: 3.533
(Epoch 3, iter 150/787) Average loss so far: 3.512
(Epoch 3, iter 200/787) Average loss so far: 3.499
(Epoch 3, iter 250/787) Average loss so far: 3.503
(Epoch 3, iter 300/787) Average loss so far: 3.494
(Epoch 3, iter 350/787) Average loss so far: 3.480
(Epoch 3, iter 400/787) Average loss so far: 3.455
(Epoch 3, iter 450/787) Average loss so far: 3.445
(Epoch 3, iter 500/787) Average loss so far: 3.436
(Epoch 3, iter 550/787) Average loss so far: 3.436
(Epoch 3, iter 600/787) Average loss so far: 3.424
(Epoch 3, iter 650/787) Average loss so far: 3.410
(Epoch 3, iter 700/787) Average loss so far: 3.401
(Epoch 3, iter 750/787) Average loss so far: 3.412
Average epoch loss: 3.463
This epoch took 8.090149772167205 min

7it [00:01,  5.57it/s]


validation loss: 3.4686967645372664


100%|██████████| 7/7 [00:03<00:00,  1.87it/s]
100%|██████████| 793/793 [00:02<00:00, 350.55it/s]


BLEU score: 0.02478974721483497, METEOR score: 0.18132902088830255
Starting epoch 5/30, enc lr scheduler: [0.0009572050015330874], dec lr scheduler: [0.0009572050015330874]
(Epoch 4, iter 50/787) Average loss so far: 3.370
(Epoch 4, iter 100/787) Average loss so far: 3.354
(Epoch 4, iter 150/787) Average loss so far: 3.346
(Epoch 4, iter 200/787) Average loss so far: 3.360
(Epoch 4, iter 250/787) Average loss so far: 3.335
(Epoch 4, iter 300/787) Average loss so far: 3.325
(Epoch 4, iter 350/787) Average loss so far: 3.323
(Epoch 4, iter 400/787) Average loss so far: 3.308
(Epoch 4, iter 450/787) Average loss so far: 3.320
(Epoch 4, iter 500/787) Average loss so far: 3.303
(Epoch 4, iter 550/787) Average loss so far: 3.310
(Epoch 4, iter 600/787) Average loss so far: 3.296
(Epoch 4, iter 650/787) Average loss so far: 3.290
(Epoch 4, iter 700/787) Average loss so far: 3.293
(Epoch 4, iter 750/787) Average loss so far: 3.283
Average epoch loss: 3.318
This epoch took 8.103163651625316 min

7it [00:01,  5.50it/s]


validation loss: 3.376786708831787


100%|██████████| 7/7 [00:04<00:00,  1.44it/s]
100%|██████████| 793/793 [00:03<00:00, 216.71it/s]


BLEU score: 0.016797911257019238, METEOR score: 0.17466288946573208
Starting epoch 6/30, enc lr scheduler: [0.0009336825748732973], dec lr scheduler: [0.0009336825748732973]
(Epoch 5, iter 50/787) Average loss so far: 3.233
(Epoch 5, iter 100/787) Average loss so far: 3.250
(Epoch 5, iter 150/787) Average loss so far: 3.227
(Epoch 5, iter 200/787) Average loss so far: 3.236
(Epoch 5, iter 250/787) Average loss so far: 3.215
(Epoch 5, iter 300/787) Average loss so far: 3.229
(Epoch 5, iter 350/787) Average loss so far: 3.227
(Epoch 5, iter 400/787) Average loss so far: 3.229
(Epoch 5, iter 450/787) Average loss so far: 3.212
(Epoch 5, iter 500/787) Average loss so far: 3.217
(Epoch 5, iter 550/787) Average loss so far: 3.221
(Epoch 5, iter 600/787) Average loss so far: 3.205
(Epoch 5, iter 650/787) Average loss so far: 3.199
(Epoch 5, iter 700/787) Average loss so far: 3.194
(Epoch 5, iter 750/787) Average loss so far: 3.184
Average epoch loss: 3.217
This epoch took 8.09989018837611 min

7it [00:01,  5.31it/s]


validation loss: 3.3135062967027937


100%|██████████| 7/7 [00:03<00:00,  1.87it/s]
100%|██████████| 793/793 [00:02<00:00, 322.45it/s]


BLEU score: 0.026052313701649024, METEOR score: 0.1865945117369642
Starting epoch 7/30, enc lr scheduler: [0.0009054634122155991], dec lr scheduler: [0.0009054634122155991]
(Epoch 6, iter 50/787) Average loss so far: 3.171
(Epoch 6, iter 100/787) Average loss so far: 3.163
(Epoch 6, iter 150/787) Average loss so far: 3.161
(Epoch 6, iter 200/787) Average loss so far: 3.149
(Epoch 6, iter 250/787) Average loss so far: 3.141
(Epoch 6, iter 300/787) Average loss so far: 3.143
(Epoch 6, iter 350/787) Average loss so far: 3.139
(Epoch 6, iter 400/787) Average loss so far: 3.145
(Epoch 6, iter 450/787) Average loss so far: 3.148
(Epoch 6, iter 500/787) Average loss so far: 3.133
(Epoch 6, iter 550/787) Average loss so far: 3.126
(Epoch 6, iter 600/787) Average loss so far: 3.146
(Epoch 6, iter 650/787) Average loss so far: 3.121
(Epoch 6, iter 700/787) Average loss so far: 3.131
(Epoch 6, iter 750/787) Average loss so far: 3.117
Average epoch loss: 3.141
This epoch took 8.1756121357282 mins.

7it [00:01,  5.49it/s]


validation loss: 3.277291910988944


100%|██████████| 7/7 [00:03<00:00,  1.86it/s]
100%|██████████| 793/793 [00:02<00:00, 339.68it/s]


BLEU score: 0.0272790639424669, METEOR score: 0.1936208372227744
Starting epoch 8/30, enc lr scheduler: [0.0008728566886113102], dec lr scheduler: [0.0008728566886113102]
(Epoch 7, iter 50/787) Average loss so far: 3.102
(Epoch 7, iter 100/787) Average loss so far: 3.099
(Epoch 7, iter 150/787) Average loss so far: 3.106
(Epoch 7, iter 200/787) Average loss so far: 3.087
(Epoch 7, iter 250/787) Average loss so far: 3.083
(Epoch 7, iter 300/787) Average loss so far: 3.089
(Epoch 7, iter 350/787) Average loss so far: 3.085
(Epoch 7, iter 400/787) Average loss so far: 3.076
(Epoch 7, iter 450/787) Average loss so far: 3.084
(Epoch 7, iter 500/787) Average loss so far: 3.067
(Epoch 7, iter 550/787) Average loss so far: 3.081
(Epoch 7, iter 600/787) Average loss so far: 3.054
(Epoch 7, iter 650/787) Average loss so far: 3.071
(Epoch 7, iter 700/787) Average loss so far: 3.076
(Epoch 7, iter 750/787) Average loss so far: 3.075
Average epoch loss: 3.081
This epoch took 8.089503169059753 mins.

7it [00:01,  5.51it/s]


validation loss: 3.2438402516501292


100%|██████████| 7/7 [00:03<00:00,  2.02it/s]
100%|██████████| 793/793 [00:01<00:00, 431.65it/s]


BLEU score: 0.038146458721083895, METEOR score: 0.1998428588672739
Starting epoch 9/30, enc lr scheduler: [0.0008362196501476349], dec lr scheduler: [0.0008362196501476349]
(Epoch 8, iter 50/787) Average loss so far: 3.045
(Epoch 8, iter 100/787) Average loss so far: 3.044
(Epoch 8, iter 150/787) Average loss so far: 3.046
(Epoch 8, iter 200/787) Average loss so far: 3.050
(Epoch 8, iter 250/787) Average loss so far: 3.025
(Epoch 8, iter 300/787) Average loss so far: 3.033
(Epoch 8, iter 350/787) Average loss so far: 3.038
(Epoch 8, iter 400/787) Average loss so far: 3.021
(Epoch 8, iter 450/787) Average loss so far: 3.028
(Epoch 8, iter 500/787) Average loss so far: 3.025
(Epoch 8, iter 550/787) Average loss so far: 3.041
(Epoch 8, iter 600/787) Average loss so far: 3.017
(Epoch 8, iter 650/787) Average loss so far: 3.017
(Epoch 8, iter 700/787) Average loss so far: 3.039
(Epoch 8, iter 750/787) Average loss so far: 3.022
Average epoch loss: 3.031
This epoch took 8.09464071591695 mins

7it [00:01,  5.53it/s]


validation loss: 3.2192379406520297


100%|██████████| 7/7 [00:03<00:00,  2.01it/s]
100%|██████████| 793/793 [00:01<00:00, 409.40it/s]


BLEU score: 0.036190345080440604, METEOR score: 0.19858925888566942
Starting epoch 10/30, enc lr scheduler: [0.0007959536998847742], dec lr scheduler: [0.0007959536998847742]
(Epoch 9, iter 50/787) Average loss so far: 3.012
(Epoch 9, iter 100/787) Average loss so far: 2.984
(Epoch 9, iter 150/787) Average loss so far: 2.981
(Epoch 9, iter 200/787) Average loss so far: 2.988
(Epoch 9, iter 250/787) Average loss so far: 2.998
(Epoch 9, iter 300/787) Average loss so far: 2.991
(Epoch 9, iter 350/787) Average loss so far: 3.002
(Epoch 9, iter 400/787) Average loss so far: 2.992
(Epoch 9, iter 450/787) Average loss so far: 3.007
(Epoch 9, iter 500/787) Average loss so far: 2.980
(Epoch 9, iter 550/787) Average loss so far: 2.987
(Epoch 9, iter 600/787) Average loss so far: 2.970
(Epoch 9, iter 650/787) Average loss so far: 2.989
(Epoch 9, iter 700/787) Average loss so far: 2.997
(Epoch 9, iter 750/787) Average loss so far: 2.971
Average epoch loss: 2.990
This epoch took 8.119804537296295 m

7it [00:01,  5.47it/s]


validation loss: 3.2011922086988176


100%|██████████| 7/7 [00:03<00:00,  2.08it/s]
100%|██████████| 793/793 [00:01<00:00, 448.70it/s]


BLEU score: 0.03947045191983198, METEOR score: 0.19917691253253444
Starting epoch 11/30, enc lr scheduler: [0.0007525], dec lr scheduler: [0.0007525]
(Epoch 10, iter 50/787) Average loss so far: 2.958
(Epoch 10, iter 100/787) Average loss so far: 2.957
(Epoch 10, iter 150/787) Average loss so far: 2.941
(Epoch 10, iter 200/787) Average loss so far: 2.959
(Epoch 10, iter 250/787) Average loss so far: 2.947
(Epoch 10, iter 300/787) Average loss so far: 2.957
(Epoch 10, iter 350/787) Average loss so far: 2.952
(Epoch 10, iter 400/787) Average loss so far: 2.966
(Epoch 10, iter 450/787) Average loss so far: 2.948
(Epoch 10, iter 500/787) Average loss so far: 2.961
(Epoch 10, iter 550/787) Average loss so far: 2.954
(Epoch 10, iter 600/787) Average loss so far: 2.947
(Epoch 10, iter 650/787) Average loss so far: 2.955
(Epoch 10, iter 700/787) Average loss so far: 2.952
(Epoch 10, iter 750/787) Average loss so far: 2.965
Average epoch loss: 2.955
This epoch took 8.114036385218302 mins. Time 

7it [00:01,  5.57it/s]


validation loss: 3.1880315371922086


100%|██████████| 7/7 [00:03<00:00,  2.01it/s]
100%|██████████| 793/793 [00:01<00:00, 432.08it/s]


BLEU score: 0.03707379676793246, METEOR score: 0.20066902120742106
Starting epoch 12/30, enc lr scheduler: [0.0007063346383225212], dec lr scheduler: [0.0007063346383225212]
(Epoch 11, iter 50/787) Average loss so far: 2.919
(Epoch 11, iter 100/787) Average loss so far: 2.936
(Epoch 11, iter 150/787) Average loss so far: 2.912
(Epoch 11, iter 200/787) Average loss so far: 2.926
(Epoch 11, iter 250/787) Average loss so far: 2.933
(Epoch 11, iter 300/787) Average loss so far: 2.918
(Epoch 11, iter 350/787) Average loss so far: 2.928
(Epoch 11, iter 400/787) Average loss so far: 2.931
(Epoch 11, iter 450/787) Average loss so far: 2.915
(Epoch 11, iter 500/787) Average loss so far: 2.925
(Epoch 11, iter 550/787) Average loss so far: 2.922
(Epoch 11, iter 600/787) Average loss so far: 2.912
(Epoch 11, iter 650/787) Average loss so far: 2.919
(Epoch 11, iter 700/787) Average loss so far: 2.929
(Epoch 11, iter 750/787) Average loss so far: 2.939
Average epoch loss: 2.924
This epoch took 8.101

7it [00:01,  5.54it/s]


validation loss: 3.1738901478903636


100%|██████████| 7/7 [00:03<00:00,  1.97it/s]
100%|██████████| 793/793 [00:01<00:00, 420.49it/s]


BLEU score: 0.03867613565428751, METEOR score: 0.2017267310038669
Starting epoch 13/30, enc lr scheduler: [0.000657963412215599], dec lr scheduler: [0.000657963412215599]
(Epoch 12, iter 50/787) Average loss so far: 2.889
(Epoch 12, iter 100/787) Average loss so far: 2.888
(Epoch 12, iter 150/787) Average loss so far: 2.880
(Epoch 12, iter 200/787) Average loss so far: 2.896
(Epoch 12, iter 250/787) Average loss so far: 2.896
(Epoch 12, iter 300/787) Average loss so far: 2.903
(Epoch 12, iter 350/787) Average loss so far: 2.885
(Epoch 12, iter 400/787) Average loss so far: 2.904
(Epoch 12, iter 450/787) Average loss so far: 2.898
(Epoch 12, iter 500/787) Average loss so far: 2.898
(Epoch 12, iter 550/787) Average loss so far: 2.892
(Epoch 12, iter 600/787) Average loss so far: 2.911
(Epoch 12, iter 650/787) Average loss so far: 2.916
(Epoch 12, iter 700/787) Average loss so far: 2.894
(Epoch 12, iter 750/787) Average loss so far: 2.902
Average epoch loss: 2.896
This epoch took 8.079897

7it [00:01,  5.26it/s]


validation loss: 3.1668037346431186


100%|██████████| 7/7 [00:03<00:00,  1.94it/s]
100%|██████████| 793/793 [00:02<00:00, 390.27it/s]


BLEU score: 0.035783646119029455, METEOR score: 0.19836152461714607
Starting epoch 14/30, enc lr scheduler: [0.0006079162869547909], dec lr scheduler: [0.0006079162869547909]
(Epoch 13, iter 50/787) Average loss so far: 2.883
(Epoch 13, iter 100/787) Average loss so far: 2.870
(Epoch 13, iter 150/787) Average loss so far: 2.875
(Epoch 13, iter 200/787) Average loss so far: 2.870
(Epoch 13, iter 250/787) Average loss so far: 2.899
(Epoch 13, iter 300/787) Average loss so far: 2.864
(Epoch 13, iter 350/787) Average loss so far: 2.870
(Epoch 13, iter 400/787) Average loss so far: 2.875
(Epoch 13, iter 450/787) Average loss so far: 2.873
(Epoch 13, iter 500/787) Average loss so far: 2.873
(Epoch 13, iter 550/787) Average loss so far: 2.860
(Epoch 13, iter 600/787) Average loss so far: 2.863
(Epoch 13, iter 650/787) Average loss so far: 2.887
(Epoch 13, iter 700/787) Average loss so far: 2.874
(Epoch 13, iter 750/787) Average loss so far: 2.867
Average epoch loss: 2.873
This epoch took 8.05

7it [00:01,  5.44it/s]


validation loss: 3.1585819721221924


100%|██████████| 7/7 [00:03<00:00,  1.98it/s]
100%|██████████| 793/793 [00:01<00:00, 447.37it/s]


BLEU score: 0.04004954169047118, METEOR score: 0.20343329001800056
Starting epoch 15/30, enc lr scheduler: [0.0005567415893174886], dec lr scheduler: [0.0005567415893174886]
(Epoch 14, iter 50/787) Average loss so far: 2.854
(Epoch 14, iter 100/787) Average loss so far: 2.848
(Epoch 14, iter 150/787) Average loss so far: 2.840
(Epoch 14, iter 200/787) Average loss so far: 2.846
(Epoch 14, iter 250/787) Average loss so far: 2.842
(Epoch 14, iter 300/787) Average loss so far: 2.861
(Epoch 14, iter 350/787) Average loss so far: 2.832
(Epoch 14, iter 400/787) Average loss so far: 2.864
(Epoch 14, iter 450/787) Average loss so far: 2.858
(Epoch 14, iter 500/787) Average loss so far: 2.865
(Epoch 14, iter 550/787) Average loss so far: 2.849
(Epoch 14, iter 600/787) Average loss so far: 2.854
(Epoch 14, iter 650/787) Average loss so far: 2.849
(Epoch 14, iter 700/787) Average loss so far: 2.852
(Epoch 14, iter 750/787) Average loss so far: 2.866
Average epoch loss: 2.852
This epoch took 8.125

7it [00:01,  5.75it/s]


validation loss: 3.157423734664917


100%|██████████| 7/7 [00:03<00:00,  2.14it/s]
100%|██████████| 793/793 [00:01<00:00, 505.40it/s]


BLEU score: 0.04606436720940942, METEOR score: 0.2026929298648693
Starting epoch 16/30, enc lr scheduler: [0.0005050000000000002], dec lr scheduler: [0.0005050000000000002]
(Epoch 15, iter 50/787) Average loss so far: 2.824
(Epoch 15, iter 100/787) Average loss so far: 2.823
(Epoch 15, iter 150/787) Average loss so far: 2.835
(Epoch 15, iter 200/787) Average loss so far: 2.824
(Epoch 15, iter 250/787) Average loss so far: 2.828
(Epoch 15, iter 300/787) Average loss so far: 2.832
(Epoch 15, iter 350/787) Average loss so far: 2.837
(Epoch 15, iter 400/787) Average loss so far: 2.820
(Epoch 15, iter 450/787) Average loss so far: 2.829
(Epoch 15, iter 500/787) Average loss so far: 2.846
(Epoch 15, iter 550/787) Average loss so far: 2.845
(Epoch 15, iter 600/787) Average loss so far: 2.843
(Epoch 15, iter 650/787) Average loss so far: 2.838
(Epoch 15, iter 700/787) Average loss so far: 2.838
(Epoch 15, iter 750/787) Average loss so far: 2.839
Average epoch loss: 2.833
This epoch took 8.0893

7it [00:01,  5.77it/s]


validation loss: 3.1507908276149204


100%|██████████| 7/7 [00:03<00:00,  2.12it/s]
100%|██████████| 793/793 [00:01<00:00, 486.39it/s]


BLEU score: 0.04347215230255102, METEOR score: 0.20329518542651354
Starting epoch 17/30, enc lr scheduler: [0.0004532584106825117], dec lr scheduler: [0.0004532584106825117]
(Epoch 16, iter 50/787) Average loss so far: 2.810
(Epoch 16, iter 100/787) Average loss so far: 2.807
(Epoch 16, iter 150/787) Average loss so far: 2.817
(Epoch 16, iter 200/787) Average loss so far: 2.800
(Epoch 16, iter 250/787) Average loss so far: 2.812
(Epoch 16, iter 300/787) Average loss so far: 2.820
(Epoch 16, iter 350/787) Average loss so far: 2.824
(Epoch 16, iter 400/787) Average loss so far: 2.830
(Epoch 16, iter 450/787) Average loss so far: 2.818
(Epoch 16, iter 500/787) Average loss so far: 2.829
(Epoch 16, iter 550/787) Average loss so far: 2.807
(Epoch 16, iter 600/787) Average loss so far: 2.812
(Epoch 16, iter 650/787) Average loss so far: 2.816
(Epoch 16, iter 700/787) Average loss so far: 2.799
(Epoch 16, iter 750/787) Average loss so far: 2.830
Average epoch loss: 2.816
This epoch took 8.113

7it [00:01,  5.44it/s]


validation loss: 3.146672351019723


100%|██████████| 7/7 [00:03<00:00,  2.15it/s]
100%|██████████| 793/793 [00:01<00:00, 494.13it/s]


BLEU score: 0.04394619733107649, METEOR score: 0.2021709564532238
Starting epoch 18/30, enc lr scheduler: [0.00040208371304520916], dec lr scheduler: [0.00040208371304520916]
(Epoch 17, iter 50/787) Average loss so far: 2.801
(Epoch 17, iter 100/787) Average loss so far: 2.796
(Epoch 17, iter 150/787) Average loss so far: 2.805
(Epoch 17, iter 200/787) Average loss so far: 2.794
(Epoch 17, iter 250/787) Average loss so far: 2.784
(Epoch 17, iter 300/787) Average loss so far: 2.814
(Epoch 17, iter 350/787) Average loss so far: 2.790
(Epoch 17, iter 400/787) Average loss so far: 2.786
(Epoch 17, iter 450/787) Average loss so far: 2.808
(Epoch 17, iter 500/787) Average loss so far: 2.801
(Epoch 17, iter 550/787) Average loss so far: 2.797
(Epoch 17, iter 600/787) Average loss so far: 2.811
(Epoch 17, iter 650/787) Average loss so far: 2.809
(Epoch 17, iter 700/787) Average loss so far: 2.813
(Epoch 17, iter 750/787) Average loss so far: 2.814
Average epoch loss: 2.802
This epoch took 8.12

7it [00:01,  5.54it/s]


validation loss: 3.144554853439331


100%|██████████| 7/7 [00:03<00:00,  2.18it/s]
100%|██████████| 793/793 [00:01<00:00, 539.82it/s]


BLEU score: 0.050306488687707214, METEOR score: 0.2085596593048818
Starting epoch 19/30, enc lr scheduler: [0.00035203658778440114], dec lr scheduler: [0.00035203658778440114]
(Epoch 18, iter 50/787) Average loss so far: 2.776
(Epoch 18, iter 100/787) Average loss so far: 2.779
(Epoch 18, iter 150/787) Average loss so far: 2.778
(Epoch 18, iter 200/787) Average loss so far: 2.786
(Epoch 18, iter 250/787) Average loss so far: 2.788
(Epoch 18, iter 300/787) Average loss so far: 2.771
(Epoch 18, iter 350/787) Average loss so far: 2.792
(Epoch 18, iter 400/787) Average loss so far: 2.808
(Epoch 18, iter 450/787) Average loss so far: 2.795
(Epoch 18, iter 500/787) Average loss so far: 2.784
(Epoch 18, iter 550/787) Average loss so far: 2.794
(Epoch 18, iter 600/787) Average loss so far: 2.780
(Epoch 18, iter 650/787) Average loss so far: 2.786
(Epoch 18, iter 700/787) Average loss so far: 2.799
(Epoch 18, iter 750/787) Average loss so far: 2.798
Average epoch loss: 2.788
This epoch took 8.1

7it [00:01,  5.48it/s]


validation loss: 3.139380386897496


100%|██████████| 7/7 [00:03<00:00,  2.20it/s]
100%|██████████| 793/793 [00:01<00:00, 524.64it/s]


BLEU score: 0.05073499908229344, METEOR score: 0.20909495784132945
Starting epoch 20/30, enc lr scheduler: [0.00030366536167747904], dec lr scheduler: [0.00030366536167747904]
(Epoch 19, iter 50/787) Average loss so far: 2.776
(Epoch 19, iter 100/787) Average loss so far: 2.776
(Epoch 19, iter 150/787) Average loss so far: 2.768
(Epoch 19, iter 200/787) Average loss so far: 2.764
(Epoch 19, iter 250/787) Average loss so far: 2.763
(Epoch 19, iter 300/787) Average loss so far: 2.781
(Epoch 19, iter 350/787) Average loss so far: 2.774
(Epoch 19, iter 400/787) Average loss so far: 2.779
(Epoch 19, iter 450/787) Average loss so far: 2.769
(Epoch 19, iter 500/787) Average loss so far: 2.786
(Epoch 19, iter 550/787) Average loss so far: 2.772
(Epoch 19, iter 600/787) Average loss so far: 2.777
(Epoch 19, iter 650/787) Average loss so far: 2.790
(Epoch 19, iter 700/787) Average loss so far: 2.769
(Epoch 19, iter 750/787) Average loss so far: 2.797
Average epoch loss: 2.777
This epoch took 8.1

7it [00:01,  5.18it/s]


validation loss: 3.1408518382481168


100%|██████████| 7/7 [00:03<00:00,  2.24it/s]
100%|██████████| 793/793 [00:01<00:00, 574.52it/s]


BLEU score: 0.05228166600737611, METEOR score: 0.20709094525810967
Starting epoch 21/30, enc lr scheduler: [0.00025750000000000013], dec lr scheduler: [0.00025750000000000013]
(Epoch 20, iter 50/787) Average loss so far: 2.769
(Epoch 20, iter 100/787) Average loss so far: 2.758
(Epoch 20, iter 150/787) Average loss so far: 2.767
(Epoch 20, iter 200/787) Average loss so far: 2.772
(Epoch 20, iter 250/787) Average loss so far: 2.776
(Epoch 20, iter 300/787) Average loss so far: 2.772
(Epoch 20, iter 350/787) Average loss so far: 2.749
(Epoch 20, iter 400/787) Average loss so far: 2.774
(Epoch 20, iter 450/787) Average loss so far: 2.760
(Epoch 20, iter 500/787) Average loss so far: 2.759
(Epoch 20, iter 550/787) Average loss so far: 2.761
(Epoch 20, iter 600/787) Average loss so far: 2.778
(Epoch 20, iter 650/787) Average loss so far: 2.762
(Epoch 20, iter 700/787) Average loss so far: 2.777
(Epoch 20, iter 750/787) Average loss so far: 2.765
Average epoch loss: 2.767
This epoch took 8.1

7it [00:01,  5.42it/s]


validation loss: 3.1391042300633023


100%|██████████| 7/7 [00:02<00:00,  2.38it/s]
100%|██████████| 793/793 [00:01<00:00, 609.62it/s]


BLEU score: 0.05379213235939195, METEOR score: 0.2088759308593728
Starting epoch 22/30, enc lr scheduler: [0.00021404630011522585], dec lr scheduler: [0.00021404630011522585]
(Epoch 21, iter 50/787) Average loss so far: 2.759
(Epoch 21, iter 100/787) Average loss so far: 2.741
(Epoch 21, iter 150/787) Average loss so far: 2.746
(Epoch 21, iter 200/787) Average loss so far: 2.776
(Epoch 21, iter 250/787) Average loss so far: 2.746
(Epoch 21, iter 300/787) Average loss so far: 2.761
(Epoch 21, iter 350/787) Average loss so far: 2.762
(Epoch 21, iter 400/787) Average loss so far: 2.763
(Epoch 21, iter 450/787) Average loss so far: 2.759
(Epoch 21, iter 500/787) Average loss so far: 2.751
(Epoch 21, iter 550/787) Average loss so far: 2.757
(Epoch 21, iter 600/787) Average loss so far: 2.773
(Epoch 21, iter 650/787) Average loss so far: 2.748
(Epoch 21, iter 700/787) Average loss so far: 2.769
(Epoch 21, iter 750/787) Average loss so far: 2.743
Average epoch loss: 2.758
This epoch took 8.17

7it [00:01,  5.51it/s]


validation loss: 3.1356898035321916


100%|██████████| 7/7 [00:03<00:00,  2.20it/s]
100%|██████████| 793/793 [00:01<00:00, 558.30it/s]


BLEU score: 0.0521244945980193, METEOR score: 0.20665171059130683
Starting epoch 23/30, enc lr scheduler: [0.00017378034985236535], dec lr scheduler: [0.00017378034985236535]
(Epoch 22, iter 50/787) Average loss so far: 2.743
(Epoch 22, iter 100/787) Average loss so far: 2.736
(Epoch 22, iter 150/787) Average loss so far: 2.749
(Epoch 22, iter 200/787) Average loss so far: 2.755
(Epoch 22, iter 250/787) Average loss so far: 2.747
(Epoch 22, iter 300/787) Average loss so far: 2.757
(Epoch 22, iter 350/787) Average loss so far: 2.755
(Epoch 22, iter 400/787) Average loss so far: 2.744
(Epoch 22, iter 450/787) Average loss so far: 2.758
(Epoch 22, iter 500/787) Average loss so far: 2.746
(Epoch 22, iter 550/787) Average loss so far: 2.760
(Epoch 22, iter 600/787) Average loss so far: 2.766
(Epoch 22, iter 650/787) Average loss so far: 2.742
(Epoch 22, iter 700/787) Average loss so far: 2.747
(Epoch 22, iter 750/787) Average loss so far: 2.749
Average epoch loss: 2.750
This epoch took 7.98

7it [00:01,  5.36it/s]


validation loss: 3.1357997144971574


100%|██████████| 7/7 [00:02<00:00,  2.43it/s]
100%|██████████| 793/793 [00:01<00:00, 591.51it/s]


BLEU score: 0.05242640382134315, METEOR score: 0.20686833056692383
Starting epoch 24/30, enc lr scheduler: [0.00013714331138868998], dec lr scheduler: [0.00013714331138868998]
(Epoch 23, iter 50/787) Average loss so far: 2.736
(Epoch 23, iter 100/787) Average loss so far: 2.744
(Epoch 23, iter 150/787) Average loss so far: 2.725
(Epoch 23, iter 200/787) Average loss so far: 2.741
(Epoch 23, iter 250/787) Average loss so far: 2.745
(Epoch 23, iter 300/787) Average loss so far: 2.746
(Epoch 23, iter 350/787) Average loss so far: 2.734
(Epoch 23, iter 400/787) Average loss so far: 2.747
(Epoch 23, iter 450/787) Average loss so far: 2.741
(Epoch 23, iter 500/787) Average loss so far: 2.734
(Epoch 23, iter 550/787) Average loss so far: 2.739
(Epoch 23, iter 600/787) Average loss so far: 2.744
(Epoch 23, iter 650/787) Average loss so far: 2.765
(Epoch 23, iter 700/787) Average loss so far: 2.758
(Epoch 23, iter 750/787) Average loss so far: 2.746
Average epoch loss: 2.744
This epoch took 8.0

7it [00:01,  5.40it/s]


validation loss: 3.135296242577689


100%|██████████| 7/7 [00:03<00:00,  2.24it/s]
100%|██████████| 793/793 [00:01<00:00, 603.28it/s]


BLEU score: 0.051097938950738914, METEOR score: 0.20545257704474157
Starting epoch 25/30, enc lr scheduler: [0.00010453658778440108], dec lr scheduler: [0.00010453658778440108]
(Epoch 24, iter 50/787) Average loss so far: 2.748
(Epoch 24, iter 100/787) Average loss so far: 2.726
(Epoch 24, iter 150/787) Average loss so far: 2.731
(Epoch 24, iter 200/787) Average loss so far: 2.757
(Epoch 24, iter 250/787) Average loss so far: 2.718
(Epoch 24, iter 300/787) Average loss so far: 2.734
(Epoch 24, iter 350/787) Average loss so far: 2.749
(Epoch 24, iter 400/787) Average loss so far: 2.745
(Epoch 24, iter 450/787) Average loss so far: 2.738
(Epoch 24, iter 500/787) Average loss so far: 2.742
(Epoch 24, iter 550/787) Average loss so far: 2.754
(Epoch 24, iter 600/787) Average loss so far: 2.740
(Epoch 24, iter 650/787) Average loss so far: 2.746
(Epoch 24, iter 700/787) Average loss so far: 2.731
(Epoch 24, iter 750/787) Average loss so far: 2.722
Average epoch loss: 2.738
This epoch took 8.

7it [00:01,  5.42it/s]


validation loss: 3.1331892013549805


100%|██████████| 7/7 [00:02<00:00,  2.39it/s]
100%|██████████| 793/793 [00:01<00:00, 597.43it/s]


BLEU score: 0.05338881169928626, METEOR score: 0.20659685260841154
Starting epoch 26/30, enc lr scheduler: [7.631742512670285e-05], dec lr scheduler: [7.631742512670285e-05]
(Epoch 25, iter 50/787) Average loss so far: 2.725
(Epoch 25, iter 100/787) Average loss so far: 2.720
(Epoch 25, iter 150/787) Average loss so far: 2.750
(Epoch 25, iter 200/787) Average loss so far: 2.731
(Epoch 25, iter 250/787) Average loss so far: 2.725
(Epoch 25, iter 300/787) Average loss so far: 2.717
(Epoch 25, iter 350/787) Average loss so far: 2.742
(Epoch 25, iter 400/787) Average loss so far: 2.735
(Epoch 25, iter 450/787) Average loss so far: 2.736
(Epoch 25, iter 500/787) Average loss so far: 2.735
(Epoch 25, iter 550/787) Average loss so far: 2.738
(Epoch 25, iter 600/787) Average loss so far: 2.737
(Epoch 25, iter 650/787) Average loss so far: 2.732
(Epoch 25, iter 700/787) Average loss so far: 2.749
(Epoch 25, iter 750/787) Average loss so far: 2.738
Average epoch loss: 2.734
This epoch took 8.112

7it [00:01,  5.46it/s]


validation loss: 3.1343512194497243


100%|██████████| 7/7 [00:03<00:00,  2.24it/s]
100%|██████████| 793/793 [00:01<00:00, 591.37it/s]


BLEU score: 0.05337146516245549, METEOR score: 0.20767647224625604
Starting epoch 27/30, enc lr scheduler: [5.279499846691252e-05], dec lr scheduler: [5.279499846691252e-05]
(Epoch 26, iter 50/787) Average loss so far: 2.724
(Epoch 26, iter 100/787) Average loss so far: 2.720
(Epoch 26, iter 150/787) Average loss so far: 2.706
(Epoch 26, iter 200/787) Average loss so far: 2.737
(Epoch 26, iter 250/787) Average loss so far: 2.737
(Epoch 26, iter 300/787) Average loss so far: 2.726
(Epoch 26, iter 350/787) Average loss so far: 2.739
(Epoch 26, iter 400/787) Average loss so far: 2.741
(Epoch 26, iter 450/787) Average loss so far: 2.735
(Epoch 26, iter 500/787) Average loss so far: 2.735
(Epoch 26, iter 550/787) Average loss so far: 2.734
(Epoch 26, iter 600/787) Average loss so far: 2.735
(Epoch 26, iter 650/787) Average loss so far: 2.728
(Epoch 26, iter 700/787) Average loss so far: 2.725
(Epoch 26, iter 750/787) Average loss so far: 2.726
Average epoch loss: 2.730
This epoch took 8.101

7it [00:01,  5.39it/s]


validation loss: 3.1328334467751637


100%|██████████| 7/7 [00:03<00:00,  2.24it/s]
100%|██████████| 793/793 [00:01<00:00, 581.81it/s]


BLEU score: 0.05541456263767681, METEOR score: 0.21179274372666657
Starting epoch 28/30, enc lr scheduler: [3.4227024433899005e-05], dec lr scheduler: [3.4227024433899005e-05]
(Epoch 27, iter 50/787) Average loss so far: 2.747
(Epoch 27, iter 100/787) Average loss so far: 2.725
(Epoch 27, iter 150/787) Average loss so far: 2.725
(Epoch 27, iter 200/787) Average loss so far: 2.721
(Epoch 27, iter 250/787) Average loss so far: 2.731
(Epoch 27, iter 300/787) Average loss so far: 2.723
(Epoch 27, iter 350/787) Average loss so far: 2.726
(Epoch 27, iter 400/787) Average loss so far: 2.718
(Epoch 27, iter 450/787) Average loss so far: 2.734
(Epoch 27, iter 500/787) Average loss so far: 2.741
(Epoch 27, iter 550/787) Average loss so far: 2.731
(Epoch 27, iter 600/787) Average loss so far: 2.729
(Epoch 27, iter 650/787) Average loss so far: 2.714
(Epoch 27, iter 700/787) Average loss so far: 2.731
(Epoch 27, iter 750/787) Average loss so far: 2.726
Average epoch loss: 2.728
This epoch took 8.0

7it [00:01,  5.59it/s]


validation loss: 3.1330018043518066


100%|██████████| 7/7 [00:02<00:00,  2.40it/s]
100%|██████████| 793/793 [00:01<00:00, 605.09it/s]


BLEU score: 0.05391199400984629, METEOR score: 0.20699247679832458
Starting epoch 29/30, enc lr scheduler: [2.0816937636766188e-05], dec lr scheduler: [2.0816937636766188e-05]
(Epoch 28, iter 50/787) Average loss so far: 2.733
(Epoch 28, iter 100/787) Average loss so far: 2.710
(Epoch 28, iter 150/787) Average loss so far: 2.729
(Epoch 28, iter 200/787) Average loss so far: 2.698
(Epoch 28, iter 250/787) Average loss so far: 2.733
(Epoch 28, iter 300/787) Average loss so far: 2.728
(Epoch 28, iter 350/787) Average loss so far: 2.734
(Epoch 28, iter 400/787) Average loss so far: 2.728
(Epoch 28, iter 450/787) Average loss so far: 2.732
(Epoch 28, iter 500/787) Average loss so far: 2.728
(Epoch 28, iter 550/787) Average loss so far: 2.732
(Epoch 28, iter 600/787) Average loss so far: 2.721
(Epoch 28, iter 650/787) Average loss so far: 2.731
(Epoch 28, iter 700/787) Average loss so far: 2.728
(Epoch 28, iter 750/787) Average loss so far: 2.726
Average epoch loss: 2.726
This epoch took 7.9

7it [00:01,  5.47it/s]


validation loss: 3.132430451256888


100%|██████████| 7/7 [00:03<00:00,  2.30it/s]
100%|██████████| 793/793 [00:01<00:00, 598.63it/s]


BLEU score: 0.05513984026373013, METEOR score: 0.2094436932678814
Starting epoch 30/30, enc lr scheduler: [1.2711661792704668e-05], dec lr scheduler: [1.2711661792704668e-05]
(Epoch 29, iter 50/787) Average loss so far: 2.732
(Epoch 29, iter 100/787) Average loss so far: 2.722
(Epoch 29, iter 150/787) Average loss so far: 2.715
(Epoch 29, iter 200/787) Average loss so far: 2.739
(Epoch 29, iter 250/787) Average loss so far: 2.727
(Epoch 29, iter 300/787) Average loss so far: 2.729
(Epoch 29, iter 350/787) Average loss so far: 2.730
(Epoch 29, iter 400/787) Average loss so far: 2.714
(Epoch 29, iter 450/787) Average loss so far: 2.725
(Epoch 29, iter 500/787) Average loss so far: 2.711
(Epoch 29, iter 550/787) Average loss so far: 2.733
(Epoch 29, iter 600/787) Average loss so far: 2.737
(Epoch 29, iter 650/787) Average loss so far: 2.723
(Epoch 29, iter 700/787) Average loss so far: 2.724
(Epoch 29, iter 750/787) Average loss so far: 2.724
Average epoch loss: 2.725
This epoch took 7.99

7it [00:01,  5.51it/s]


validation loss: 3.1326404299054826


100%|██████████| 7/7 [00:02<00:00,  2.39it/s]
100%|██████████| 793/793 [00:01<00:00, 595.70it/s]

BLEU score: 0.055078340546709854, METEOR score: 0.20826842611097787





In [14]:
# Save the attention encoder/decoder pair as it stands after the training run above,
# tagged with the run identifier plus a "_last" suffix.
# NOTE(review): `identifier` is whatever the preceding training cell assigned -- re-running
# a later training cell first would change the name used here.
save_model(encoder_attn, decoder_attn, f"{identifier}_last")

## Encoder-Decoder (Extension: pretrained embeddings)

In [14]:
# Re-seed torch, numpy and `random` with SEED (see reset_rng in the config cell) so the
# pretrained-embedding experiment starts from the same RNG state as a fresh run.
reset_rng()

In [17]:
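#! TODO: add code to download the GloVe archive and extract glove.840B.300d.txt from the zip
#        (the next cell expects that file in the working directory). For example
#        (verify the URL against the GloVe project page before relying on it):
# !wget -c https://nlp.stanford.edu/data/glove.840B.300d.zip -O glove.840B.300d.zip
# !unzip -n glove.840B.300d.zip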

In [15]:
# Read the pretrained GloVe vectors from the local text file; the file must already be
# present in the working directory (nothing in this notebook downloads it yet).
# Presumably returns a word -> vector mapping -- confirm against create_pretrained_embedding_dict.
pretrained_embedding_dict = create_pretrained_embedding_dict("./glove.840B.300d.txt")

In [16]:
# Embedding width used by both networks; matches the "300d" GloVe file loaded above.
embedding_size = 300

# Encoder that is handed the pretrained vectors together with the vocabulary.
encoder_pretrained_embed = EncoderRNN(
    input_size=vocab.n_unique_words,
    embedding_size=embedding_size,
    hidden_size=HIDDEN_SIZE,
    padding_value=vocab.word2index(PAD_WORD),
    pretrained_embedding_dict=pretrained_embedding_dict,
    vocab=vocab,
)
encoder_pretrained_embed = encoder_pretrained_embed.to(DEVICE)

# Attention decoder. Its output layer is one entry smaller than the vocabulary:
# in the training script the decoder is always fed a non-end token, so it never
# needs to generate padding; it should also never generate "<UNKNOWN>".
decoder_pretrained_embed = AttnDecoderRNN(
    embedding_size=embedding_size,
    hidden_size=HIDDEN_SIZE,
    output_size=vocab.n_unique_words - 1,
    padding_val=vocab.word2index(PAD_WORD),
    dropout=DROPOUT,
    global_max_ing_len=MAX_INGR_LEN,
    pretrained_embedding_dict=pretrained_embedding_dict,
    vocab=vocab,
)
decoder_pretrained_embed = decoder_pretrained_embed.to(DEVICE)

100%|██████████| 44315/44315 [00:00<00:00, 788459.19it/s]


29462/44315 (0.665%) words have pretrained embeddings


100%|██████████| 44314/44314 [00:00<00:00, 753367.82it/s]

29462/44314 (0.665%) words have pretrained embeddings





In [18]:
# --- Run configuration for the pretrained-embedding experiment ---
initial_lr = 1e-3   # starting learning rate handed to Adam
min_lr = 1e-5       # eta_min of the cosine schedule
n_epochs = 30       # number of epochs; also used as the cosine T_max
batch_size = 128
# Tag passed to train() and save_log() to label this run.
identifier = "pretrained_emb_attn_adam_without_intermediate_tags_wd0_lr1e-3"

# Encoder: Adam optimizer with a cosine-annealed learning rate.
encoder_pretrained_embed_optimizer = optim.Adam(
    encoder_pretrained_embed.parameters(),
    lr=initial_lr,
)
enc_pretrained_embed_scheduler = CosineAnnealingLR(
    encoder_pretrained_embed_optimizer,
    T_max=n_epochs,
    eta_min=min_lr,
)

# Decoder: same optimizer/scheduler recipe, kept separate from the encoder's.
decoder_pretrained_embed_optimizer = optim.Adam(
    decoder_pretrained_embed.parameters(),
    lr=initial_lr,
)
dec_pretrained_embed_scheduler = CosineAnnealingLR(
    decoder_pretrained_embed_optimizer,
    T_max=n_epochs,
    eta_min=min_lr,
)

# Train in attention mode; the dev datasets are passed in for the per-epoch
# validation loss and metric evaluation, and progress is reported every 50 iterations.
pre_epoch_losses, pre_val_epoch_losses, pre_log = train(
    encoder_pretrained_embed,
    decoder_pretrained_embed,
    encoder_pretrained_embed_optimizer,
    decoder_pretrained_embed_optimizer,
    train_ds,
    n_epochs=n_epochs,
    vocab=vocab,
    decoder_mode="attention",
    batch_size=batch_size,
    enc_lr_scheduler=enc_pretrained_embed_scheduler,
    dec_lr_scheduler=dec_pretrained_embed_scheduler,
    dev_ds_val_loss=dev_ds_val_loss,
    dev_ds_val_met=dev_ds_val_met,
    identifier=identifier,
    verbose_iter_interval=50,
)

# Hand the training log plus optimizer/scheduler objects to save_log under the same tag.
save_log(
    identifier,
    pre_log,
    encoder_pretrained_embed_optimizer,
    decoder_pretrained_embed_optimizer,
    enc_pretrained_embed_scheduler,
    dec_pretrained_embed_scheduler,
)

Starting epoch 1/30, enc lr scheduler: [0.001], dec lr scheduler: [0.001]
(Epoch 0, iter 50/787) Average loss so far: 6.968
(Epoch 0, iter 100/787) Average loss so far: 6.040
(Epoch 0, iter 150/787) Average loss so far: 5.995
(Epoch 0, iter 200/787) Average loss so far: 5.955
(Epoch 0, iter 250/787) Average loss so far: 5.910
(Epoch 0, iter 300/787) Average loss so far: 5.859
(Epoch 0, iter 350/787) Average loss so far: 5.839
(Epoch 0, iter 400/787) Average loss so far: 5.759
(Epoch 0, iter 450/787) Average loss so far: 5.627
(Epoch 0, iter 500/787) Average loss so far: 5.404
(Epoch 0, iter 550/787) Average loss so far: 5.209
(Epoch 0, iter 600/787) Average loss so far: 5.001
(Epoch 0, iter 650/787) Average loss so far: 4.851
(Epoch 0, iter 700/787) Average loss so far: 4.692
(Epoch 0, iter 750/787) Average loss so far: 4.582
Average epoch loss: 5.528
This epoch took 8.08877116839091 mins. Time remaining: 3.0 hrs 54.0 mins.


7it [00:01,  5.60it/s]


validation loss: 4.4685172353472025


100%|██████████| 7/7 [00:07<00:00,  1.12s/it]
100%|██████████| 793/793 [00:09<00:00, 87.13it/s] 


BLEU score: 0.0031624657072353694, METEOR score: 0.0796926100408187
Starting epoch 2/30, enc lr scheduler: [0.0009972883382072953], dec lr scheduler: [0.0009972883382072953]
(Epoch 1, iter 50/787) Average loss so far: 4.410
(Epoch 1, iter 100/787) Average loss so far: 4.295
(Epoch 1, iter 150/787) Average loss so far: 4.269
(Epoch 1, iter 200/787) Average loss so far: 4.187
(Epoch 1, iter 250/787) Average loss so far: 4.128
(Epoch 1, iter 300/787) Average loss so far: 4.080
(Epoch 1, iter 350/787) Average loss so far: 4.027
(Epoch 1, iter 400/787) Average loss so far: 3.997
(Epoch 1, iter 450/787) Average loss so far: 3.949
(Epoch 1, iter 500/787) Average loss so far: 3.910
(Epoch 1, iter 550/787) Average loss so far: 3.871
(Epoch 1, iter 600/787) Average loss so far: 3.844
(Epoch 1, iter 650/787) Average loss so far: 3.814
(Epoch 1, iter 700/787) Average loss so far: 3.778
(Epoch 1, iter 750/787) Average loss so far: 3.759
Average epoch loss: 4.009
This epoch took 8.075345408916473 mi

7it [00:01,  5.45it/s]


validation loss: 3.7725216320582797


100%|██████████| 7/7 [00:06<00:00,  1.10it/s]
100%|██████████| 793/793 [00:05<00:00, 144.78it/s]


BLEU score: 0.008172939091770137, METEOR score: 0.1373603160673997
Starting epoch 3/30, enc lr scheduler: [0.0009891830623632338], dec lr scheduler: [0.0009891830623632338]
(Epoch 2, iter 50/787) Average loss so far: 3.698
(Epoch 2, iter 100/787) Average loss so far: 3.680
(Epoch 2, iter 150/787) Average loss so far: 3.658
(Epoch 2, iter 200/787) Average loss so far: 3.622
(Epoch 2, iter 250/787) Average loss so far: 3.602
(Epoch 2, iter 300/787) Average loss so far: 3.609
(Epoch 2, iter 350/787) Average loss so far: 3.561
(Epoch 2, iter 400/787) Average loss so far: 3.547
(Epoch 2, iter 450/787) Average loss so far: 3.528
(Epoch 2, iter 500/787) Average loss so far: 3.528
(Epoch 2, iter 550/787) Average loss so far: 3.529
(Epoch 2, iter 600/787) Average loss so far: 3.514
(Epoch 2, iter 650/787) Average loss so far: 3.481
(Epoch 2, iter 700/787) Average loss so far: 3.483
(Epoch 2, iter 750/787) Average loss so far: 3.445
Average epoch loss: 3.560
This epoch took 8.12299534479777 mins

7it [00:01,  5.61it/s]


validation loss: 3.5144453048706055


100%|██████████| 7/7 [00:04<00:00,  1.57it/s]
100%|██████████| 793/793 [00:03<00:00, 259.34it/s]


BLEU score: 0.015654428117596996, METEOR score: 0.15914722051182856
Starting epoch 4/30, enc lr scheduler: [0.0009757729755661011], dec lr scheduler: [0.0009757729755661011]
(Epoch 3, iter 50/787) Average loss so far: 3.431
(Epoch 3, iter 100/787) Average loss so far: 3.411
(Epoch 3, iter 150/787) Average loss so far: 3.391
(Epoch 3, iter 200/787) Average loss so far: 3.380
(Epoch 3, iter 250/787) Average loss so far: 3.386
(Epoch 3, iter 300/787) Average loss so far: 3.376
(Epoch 3, iter 350/787) Average loss so far: 3.366
(Epoch 3, iter 400/787) Average loss so far: 3.343
(Epoch 3, iter 450/787) Average loss so far: 3.333
(Epoch 3, iter 500/787) Average loss so far: 3.326
(Epoch 3, iter 550/787) Average loss so far: 3.330
(Epoch 3, iter 600/787) Average loss so far: 3.317
(Epoch 3, iter 650/787) Average loss so far: 3.302
(Epoch 3, iter 700/787) Average loss so far: 3.296
(Epoch 3, iter 750/787) Average loss so far: 3.307
Average epoch loss: 3.351
This epoch took 8.083766424655915 mi

7it [00:01,  5.69it/s]


validation loss: 3.3842404569898332


100%|██████████| 7/7 [00:03<00:00,  1.81it/s]
100%|██████████| 793/793 [00:02<00:00, 321.84it/s]


BLEU score: 0.021776681776078183, METEOR score: 0.1731301206546595
Starting epoch 5/30, enc lr scheduler: [0.0009572050015330874], dec lr scheduler: [0.0009572050015330874]
(Epoch 4, iter 50/787) Average loss so far: 3.269
(Epoch 4, iter 100/787) Average loss so far: 3.256
(Epoch 4, iter 150/787) Average loss so far: 3.247
(Epoch 4, iter 200/787) Average loss so far: 3.260
(Epoch 4, iter 250/787) Average loss so far: 3.232
(Epoch 4, iter 300/787) Average loss so far: 3.226
(Epoch 4, iter 350/787) Average loss so far: 3.227
(Epoch 4, iter 400/787) Average loss so far: 3.210
(Epoch 4, iter 450/787) Average loss so far: 3.223
(Epoch 4, iter 500/787) Average loss so far: 3.208
(Epoch 4, iter 550/787) Average loss so far: 3.213
(Epoch 4, iter 600/787) Average loss so far: 3.201
(Epoch 4, iter 650/787) Average loss so far: 3.194
(Epoch 4, iter 700/787) Average loss so far: 3.198
(Epoch 4, iter 750/787) Average loss so far: 3.187
Average epoch loss: 3.221
This epoch took 8.08472265402476 mins

7it [00:01,  5.71it/s]


validation loss: 3.299443074635097


100%|██████████| 7/7 [00:03<00:00,  1.91it/s]
100%|██████████| 793/793 [00:02<00:00, 371.62it/s]


BLEU score: 0.029468622995135987, METEOR score: 0.18938801124897647
Starting epoch 6/30, enc lr scheduler: [0.0009336825748732973], dec lr scheduler: [0.0009336825748732973]
(Epoch 5, iter 50/787) Average loss so far: 3.146
(Epoch 5, iter 100/787) Average loss so far: 3.159
(Epoch 5, iter 150/787) Average loss so far: 3.138
(Epoch 5, iter 200/787) Average loss so far: 3.145
(Epoch 5, iter 250/787) Average loss so far: 3.128
(Epoch 5, iter 300/787) Average loss so far: 3.140
(Epoch 5, iter 350/787) Average loss so far: 3.138
(Epoch 5, iter 400/787) Average loss so far: 3.138
(Epoch 5, iter 450/787) Average loss so far: 3.124
(Epoch 5, iter 500/787) Average loss so far: 3.125
(Epoch 5, iter 550/787) Average loss so far: 3.129
(Epoch 5, iter 600/787) Average loss so far: 3.115
(Epoch 5, iter 650/787) Average loss so far: 3.108
(Epoch 5, iter 700/787) Average loss so far: 3.107
(Epoch 5, iter 750/787) Average loss so far: 3.096
Average epoch loss: 3.128
This epoch took 8.069037703673045 mi

7it [00:01,  5.64it/s]


validation loss: 3.240325791495187


100%|██████████| 7/7 [00:03<00:00,  1.98it/s]
100%|██████████| 793/793 [00:01<00:00, 407.38it/s]


BLEU score: 0.03311247830068684, METEOR score: 0.1885298967343268
Starting epoch 7/30, enc lr scheduler: [0.0009054634122155991], dec lr scheduler: [0.0009054634122155991]
(Epoch 6, iter 50/787) Average loss so far: 3.083
(Epoch 6, iter 100/787) Average loss so far: 3.079
(Epoch 6, iter 150/787) Average loss so far: 3.076
(Epoch 6, iter 200/787) Average loss so far: 3.064
(Epoch 6, iter 250/787) Average loss so far: 3.057
(Epoch 6, iter 300/787) Average loss so far: 3.059
(Epoch 6, iter 350/787) Average loss so far: 3.053
(Epoch 6, iter 400/787) Average loss so far: 3.059
(Epoch 6, iter 450/787) Average loss so far: 3.061
(Epoch 6, iter 500/787) Average loss so far: 3.048
(Epoch 6, iter 550/787) Average loss so far: 3.040
(Epoch 6, iter 600/787) Average loss so far: 3.061
(Epoch 6, iter 650/787) Average loss so far: 3.037
(Epoch 6, iter 700/787) Average loss so far: 3.044
(Epoch 6, iter 750/787) Average loss so far: 3.034
Average epoch loss: 3.056
This epoch took 8.02225666443507 mins.

7it [00:01,  5.65it/s]


validation loss: 3.198171922138759


100%|██████████| 7/7 [00:03<00:00,  1.87it/s]
100%|██████████| 793/793 [00:02<00:00, 363.29it/s]


BLEU score: 0.029020290042606374, METEOR score: 0.18989639255673424
Starting epoch 8/30, enc lr scheduler: [0.0008728566886113102], dec lr scheduler: [0.0008728566886113102]
(Epoch 7, iter 50/787) Average loss so far: 3.021
(Epoch 7, iter 100/787) Average loss so far: 3.015
(Epoch 7, iter 150/787) Average loss so far: 3.023
(Epoch 7, iter 200/787) Average loss so far: 3.006
(Epoch 7, iter 250/787) Average loss so far: 3.001
(Epoch 7, iter 300/787) Average loss so far: 3.006
(Epoch 7, iter 350/787) Average loss so far: 3.002
(Epoch 7, iter 400/787) Average loss so far: 2.992
(Epoch 7, iter 450/787) Average loss so far: 3.000
(Epoch 7, iter 500/787) Average loss so far: 2.984
(Epoch 7, iter 550/787) Average loss so far: 2.998
(Epoch 7, iter 600/787) Average loss so far: 2.974
(Epoch 7, iter 650/787) Average loss so far: 2.990
(Epoch 7, iter 700/787) Average loss so far: 2.993
(Epoch 7, iter 750/787) Average loss so far: 2.989
Average epoch loss: 2.999
This epoch took 8.018091865380605 mi

7it [00:01,  5.54it/s]


validation loss: 3.1683931010110036


100%|██████████| 7/7 [00:03<00:00,  2.00it/s]
100%|██████████| 793/793 [00:01<00:00, 458.38it/s]


BLEU score: 0.040572019424333215, METEOR score: 0.20380746462583008
Starting epoch 9/30, enc lr scheduler: [0.0008362196501476349], dec lr scheduler: [0.0008362196501476349]
(Epoch 8, iter 50/787) Average loss so far: 2.965
(Epoch 8, iter 100/787) Average loss so far: 2.963
(Epoch 8, iter 150/787) Average loss so far: 2.964
(Epoch 8, iter 200/787) Average loss so far: 2.969
(Epoch 8, iter 250/787) Average loss so far: 2.949
(Epoch 8, iter 300/787) Average loss so far: 2.951
(Epoch 8, iter 350/787) Average loss so far: 2.958
(Epoch 8, iter 400/787) Average loss so far: 2.944
(Epoch 8, iter 450/787) Average loss so far: 2.949
(Epoch 8, iter 500/787) Average loss so far: 2.945
(Epoch 8, iter 550/787) Average loss so far: 2.958
(Epoch 8, iter 600/787) Average loss so far: 2.933
(Epoch 8, iter 650/787) Average loss so far: 2.937
(Epoch 8, iter 700/787) Average loss so far: 2.957
(Epoch 8, iter 750/787) Average loss so far: 2.939
Average epoch loss: 2.951
This epoch took 8.068367584546406 mi

7it [00:01,  5.66it/s]


validation loss: 3.145275660923549


100%|██████████| 7/7 [00:03<00:00,  2.00it/s]
100%|██████████| 793/793 [00:01<00:00, 420.62it/s]


BLEU score: 0.035724116920881344, METEOR score: 0.1954857121917429
Starting epoch 10/30, enc lr scheduler: [0.0007959536998847742], dec lr scheduler: [0.0007959536998847742]
(Epoch 9, iter 50/787) Average loss so far: 2.933
(Epoch 9, iter 100/787) Average loss so far: 2.911
(Epoch 9, iter 150/787) Average loss so far: 2.903
(Epoch 9, iter 200/787) Average loss so far: 2.910
(Epoch 9, iter 250/787) Average loss so far: 2.920
(Epoch 9, iter 300/787) Average loss so far: 2.914
(Epoch 9, iter 350/787) Average loss so far: 2.922
(Epoch 9, iter 400/787) Average loss so far: 2.912
(Epoch 9, iter 450/787) Average loss so far: 2.927
(Epoch 9, iter 500/787) Average loss so far: 2.903
(Epoch 9, iter 550/787) Average loss so far: 2.909
(Epoch 9, iter 600/787) Average loss so far: 2.892
(Epoch 9, iter 650/787) Average loss so far: 2.910
(Epoch 9, iter 700/787) Average loss so far: 2.914
(Epoch 9, iter 750/787) Average loss so far: 2.891
Average epoch loss: 2.911
This epoch took 8.082320443789165 mi

7it [00:01,  5.57it/s]


validation loss: 3.127824068069458


100%|██████████| 7/7 [00:03<00:00,  2.11it/s]
100%|██████████| 793/793 [00:01<00:00, 480.40it/s]


BLEU score: 0.039821027333439805, METEOR score: 0.1971742265366046
Starting epoch 11/30, enc lr scheduler: [0.0007525], dec lr scheduler: [0.0007525]
(Epoch 10, iter 50/787) Average loss so far: 2.885
(Epoch 10, iter 100/787) Average loss so far: 2.880
(Epoch 10, iter 150/787) Average loss so far: 2.865
(Epoch 10, iter 200/787) Average loss so far: 2.880
(Epoch 10, iter 250/787) Average loss so far: 2.870
(Epoch 10, iter 300/787) Average loss so far: 2.883
(Epoch 10, iter 350/787) Average loss so far: 2.875
(Epoch 10, iter 400/787) Average loss so far: 2.885
(Epoch 10, iter 450/787) Average loss so far: 2.871
(Epoch 10, iter 500/787) Average loss so far: 2.883
(Epoch 10, iter 550/787) Average loss so far: 2.876
(Epoch 10, iter 600/787) Average loss so far: 2.870
(Epoch 10, iter 650/787) Average loss so far: 2.877
(Epoch 10, iter 700/787) Average loss so far: 2.873
(Epoch 10, iter 750/787) Average loss so far: 2.885
Average epoch loss: 2.877
This epoch took 8.075558908780415 mins. Time 

7it [00:01,  5.60it/s]


validation loss: 3.111717598778861


100%|██████████| 7/7 [00:03<00:00,  2.07it/s]
100%|██████████| 793/793 [00:01<00:00, 485.53it/s]


BLEU score: 0.04188177528685199, METEOR score: 0.20008110559374
Starting epoch 12/30, enc lr scheduler: [0.0007063346383225212], dec lr scheduler: [0.0007063346383225212]
(Epoch 11, iter 50/787) Average loss so far: 2.845
(Epoch 11, iter 100/787) Average loss so far: 2.858
(Epoch 11, iter 150/787) Average loss so far: 2.841
(Epoch 11, iter 200/787) Average loss so far: 2.848
(Epoch 11, iter 250/787) Average loss so far: 2.858
(Epoch 11, iter 300/787) Average loss so far: 2.844
(Epoch 11, iter 350/787) Average loss so far: 2.851
(Epoch 11, iter 400/787) Average loss so far: 2.857
(Epoch 11, iter 450/787) Average loss so far: 2.842
(Epoch 11, iter 500/787) Average loss so far: 2.851
(Epoch 11, iter 550/787) Average loss so far: 2.847
(Epoch 11, iter 600/787) Average loss so far: 2.832
(Epoch 11, iter 650/787) Average loss so far: 2.844
(Epoch 11, iter 700/787) Average loss so far: 2.852
(Epoch 11, iter 750/787) Average loss so far: 2.860
Average epoch loss: 2.848
This epoch took 8.083603

7it [00:01,  5.68it/s]


validation loss: 3.0999631881713867


100%|██████████| 7/7 [00:03<00:00,  2.11it/s]
100%|██████████| 793/793 [00:01<00:00, 503.03it/s]


BLEU score: 0.044384897292243625, METEOR score: 0.20706484658028232
Starting epoch 13/30, enc lr scheduler: [0.000657963412215599], dec lr scheduler: [0.000657963412215599]
(Epoch 12, iter 50/787) Average loss so far: 2.819
(Epoch 12, iter 100/787) Average loss so far: 2.815
(Epoch 12, iter 150/787) Average loss so far: 2.809
(Epoch 12, iter 200/787) Average loss so far: 2.823
(Epoch 12, iter 250/787) Average loss so far: 2.823
(Epoch 12, iter 300/787) Average loss so far: 2.830
(Epoch 12, iter 350/787) Average loss so far: 2.808
(Epoch 12, iter 400/787) Average loss so far: 2.828
(Epoch 12, iter 450/787) Average loss so far: 2.822
(Epoch 12, iter 500/787) Average loss so far: 2.823
(Epoch 12, iter 550/787) Average loss so far: 2.818
(Epoch 12, iter 600/787) Average loss so far: 2.833
(Epoch 12, iter 650/787) Average loss so far: 2.837
(Epoch 12, iter 700/787) Average loss so far: 2.817
(Epoch 12, iter 750/787) Average loss so far: 2.824
Average epoch loss: 2.821
This epoch took 8.0355

7it [00:01,  5.47it/s]


validation loss: 3.087609222957066


100%|██████████| 7/7 [00:03<00:00,  2.05it/s]
100%|██████████| 793/793 [00:01<00:00, 459.58it/s]


BLEU score: 0.041033156261249384, METEOR score: 0.20340139642244925
Starting epoch 14/30, enc lr scheduler: [0.0006079162869547909], dec lr scheduler: [0.0006079162869547909]
(Epoch 13, iter 50/787) Average loss so far: 2.810
(Epoch 13, iter 100/787) Average loss so far: 2.799
(Epoch 13, iter 150/787) Average loss so far: 2.801
(Epoch 13, iter 200/787) Average loss so far: 2.796
(Epoch 13, iter 250/787) Average loss so far: 2.825
(Epoch 13, iter 300/787) Average loss so far: 2.793
(Epoch 13, iter 350/787) Average loss so far: 2.795
(Epoch 13, iter 400/787) Average loss so far: 2.803
(Epoch 13, iter 450/787) Average loss so far: 2.799
(Epoch 13, iter 500/787) Average loss so far: 2.798
(Epoch 13, iter 550/787) Average loss so far: 2.785
(Epoch 13, iter 600/787) Average loss so far: 2.788
(Epoch 13, iter 650/787) Average loss so far: 2.808
(Epoch 13, iter 700/787) Average loss so far: 2.801
(Epoch 13, iter 750/787) Average loss so far: 2.792
Average epoch loss: 2.799
This epoch took 8.04

7it [00:01,  5.56it/s]


validation loss: 3.080646344593593


100%|██████████| 7/7 [00:03<00:00,  2.14it/s]
100%|██████████| 793/793 [00:01<00:00, 515.81it/s]


BLEU score: 0.046692676994029854, METEOR score: 0.20559286007764238
Starting epoch 15/30, enc lr scheduler: [0.0005567415893174886], dec lr scheduler: [0.0005567415893174886]
(Epoch 14, iter 50/787) Average loss so far: 2.781
(Epoch 14, iter 100/787) Average loss so far: 2.776
(Epoch 14, iter 150/787) Average loss so far: 2.767
(Epoch 14, iter 200/787) Average loss so far: 2.774
(Epoch 14, iter 250/787) Average loss so far: 2.769
(Epoch 14, iter 300/787) Average loss so far: 2.789
(Epoch 14, iter 350/787) Average loss so far: 2.760
(Epoch 14, iter 400/787) Average loss so far: 2.791
(Epoch 14, iter 450/787) Average loss so far: 2.784
(Epoch 14, iter 500/787) Average loss so far: 2.793
(Epoch 14, iter 550/787) Average loss so far: 2.776
(Epoch 14, iter 600/787) Average loss so far: 2.781
(Epoch 14, iter 650/787) Average loss so far: 2.773
(Epoch 14, iter 700/787) Average loss so far: 2.780
(Epoch 14, iter 750/787) Average loss so far: 2.788
Average epoch loss: 2.779
This epoch took 8.12

7it [00:01,  5.66it/s]


validation loss: 3.0746142183031355


100%|██████████| 7/7 [00:03<00:00,  2.15it/s]
100%|██████████| 793/793 [00:01<00:00, 521.60it/s]


BLEU score: 0.046943470589521766, METEOR score: 0.20379152039697576
Starting epoch 16/30, enc lr scheduler: [0.0005050000000000002], dec lr scheduler: [0.0005050000000000002]
(Epoch 15, iter 50/787) Average loss so far: 2.756
(Epoch 15, iter 100/787) Average loss so far: 2.753
(Epoch 15, iter 150/787) Average loss so far: 2.763
(Epoch 15, iter 200/787) Average loss so far: 2.755
(Epoch 15, iter 250/787) Average loss so far: 2.757
(Epoch 15, iter 300/787) Average loss so far: 2.760
(Epoch 15, iter 350/787) Average loss so far: 2.764
(Epoch 15, iter 400/787) Average loss so far: 2.747
(Epoch 15, iter 450/787) Average loss so far: 2.753
(Epoch 15, iter 500/787) Average loss so far: 2.772
(Epoch 15, iter 550/787) Average loss so far: 2.772
(Epoch 15, iter 600/787) Average loss so far: 2.770
(Epoch 15, iter 650/787) Average loss so far: 2.764
(Epoch 15, iter 700/787) Average loss so far: 2.764
(Epoch 15, iter 750/787) Average loss so far: 2.762
Average epoch loss: 2.760
This epoch took 8.08

7it [00:01,  5.72it/s]


validation loss: 3.0693612098693848


100%|██████████| 7/7 [00:03<00:00,  2.08it/s]
100%|██████████| 793/793 [00:01<00:00, 496.10it/s]


BLEU score: 0.04693823811822118, METEOR score: 0.21012783449843356
Starting epoch 17/30, enc lr scheduler: [0.0004532584106825117], dec lr scheduler: [0.0004532584106825117]
(Epoch 16, iter 50/787) Average loss so far: 2.741
(Epoch 16, iter 100/787) Average loss so far: 2.739
(Epoch 16, iter 150/787) Average loss so far: 2.747
(Epoch 16, iter 200/787) Average loss so far: 2.731
(Epoch 16, iter 250/787) Average loss so far: 2.741
(Epoch 16, iter 300/787) Average loss so far: 2.748
(Epoch 16, iter 350/787) Average loss so far: 2.753
(Epoch 16, iter 400/787) Average loss so far: 2.756
(Epoch 16, iter 450/787) Average loss so far: 2.745
(Epoch 16, iter 500/787) Average loss so far: 2.757
(Epoch 16, iter 550/787) Average loss so far: 2.734
(Epoch 16, iter 600/787) Average loss so far: 2.740
(Epoch 16, iter 650/787) Average loss so far: 2.744
(Epoch 16, iter 700/787) Average loss so far: 2.728
(Epoch 16, iter 750/787) Average loss so far: 2.755
Average epoch loss: 2.745
This epoch took 8.109

7it [00:01,  5.33it/s]


validation loss: 3.065413134438651


100%|██████████| 7/7 [00:03<00:00,  2.03it/s]
100%|██████████| 793/793 [00:01<00:00, 503.49it/s]


BLEU score: 0.048756397533905496, METEOR score: 0.21361782160791687
Starting epoch 18/30, enc lr scheduler: [0.00040208371304520916], dec lr scheduler: [0.00040208371304520916]
(Epoch 17, iter 50/787) Average loss so far: 2.730
(Epoch 17, iter 100/787) Average loss so far: 2.725
(Epoch 17, iter 150/787) Average loss so far: 2.735
(Epoch 17, iter 200/787) Average loss so far: 2.726
(Epoch 17, iter 250/787) Average loss so far: 2.714
(Epoch 17, iter 300/787) Average loss so far: 2.742
(Epoch 17, iter 350/787) Average loss so far: 2.721
(Epoch 17, iter 400/787) Average loss so far: 2.716
(Epoch 17, iter 450/787) Average loss so far: 2.738
(Epoch 17, iter 500/787) Average loss so far: 2.729
(Epoch 17, iter 550/787) Average loss so far: 2.727
(Epoch 17, iter 600/787) Average loss so far: 2.739
(Epoch 17, iter 650/787) Average loss so far: 2.738
(Epoch 17, iter 700/787) Average loss so far: 2.740
(Epoch 17, iter 750/787) Average loss so far: 2.741
Average epoch loss: 2.731
This epoch took 8.

7it [00:01,  5.55it/s]


validation loss: 3.061034543173654


100%|██████████| 7/7 [00:03<00:00,  2.17it/s]
100%|██████████| 793/793 [00:01<00:00, 541.85it/s]


BLEU score: 0.05130481997908665, METEOR score: 0.21268616360288597
Starting epoch 19/30, enc lr scheduler: [0.00035203658778440114], dec lr scheduler: [0.00035203658778440114]
(Epoch 18, iter 50/787) Average loss so far: 2.706
(Epoch 18, iter 100/787) Average loss so far: 2.712
(Epoch 18, iter 150/787) Average loss so far: 2.709
(Epoch 18, iter 200/787) Average loss so far: 2.718
(Epoch 18, iter 250/787) Average loss so far: 2.717
(Epoch 18, iter 300/787) Average loss so far: 2.702
(Epoch 18, iter 350/787) Average loss so far: 2.721
(Epoch 18, iter 400/787) Average loss so far: 2.737
(Epoch 18, iter 450/787) Average loss so far: 2.725
(Epoch 18, iter 500/787) Average loss so far: 2.715
(Epoch 18, iter 550/787) Average loss so far: 2.726
(Epoch 18, iter 600/787) Average loss so far: 2.708
(Epoch 18, iter 650/787) Average loss so far: 2.715
(Epoch 18, iter 700/787) Average loss so far: 2.727
(Epoch 18, iter 750/787) Average loss so far: 2.723
Average epoch loss: 2.718
This epoch took 8.1

7it [00:01,  5.26it/s]


validation loss: 3.0552920273372104


100%|██████████| 7/7 [00:03<00:00,  2.28it/s]
100%|██████████| 793/793 [00:01<00:00, 586.20it/s]


BLEU score: 0.05395021872522183, METEOR score: 0.21410343910120513
Starting epoch 20/30, enc lr scheduler: [0.00030366536167747904], dec lr scheduler: [0.00030366536167747904]
(Epoch 19, iter 50/787) Average loss so far: 2.709
(Epoch 19, iter 100/787) Average loss so far: 2.710
(Epoch 19, iter 150/787) Average loss so far: 2.699
(Epoch 19, iter 200/787) Average loss so far: 2.694
(Epoch 19, iter 250/787) Average loss so far: 2.694
(Epoch 19, iter 300/787) Average loss so far: 2.711
(Epoch 19, iter 350/787) Average loss so far: 2.704
(Epoch 19, iter 400/787) Average loss so far: 2.708
(Epoch 19, iter 450/787) Average loss so far: 2.699
(Epoch 19, iter 500/787) Average loss so far: 2.716
(Epoch 19, iter 550/787) Average loss so far: 2.701
(Epoch 19, iter 600/787) Average loss so far: 2.707
(Epoch 19, iter 650/787) Average loss so far: 2.717
(Epoch 19, iter 700/787) Average loss so far: 2.700
(Epoch 19, iter 750/787) Average loss so far: 2.728
Average epoch loss: 2.707
This epoch took 8.0

7it [00:01,  5.48it/s]


validation loss: 3.0541302817208424


100%|██████████| 7/7 [00:02<00:00,  2.45it/s]
100%|██████████| 793/793 [00:01<00:00, 657.10it/s]


BLEU score: 0.05167331114826218, METEOR score: 0.2028060241888733
Starting epoch 21/30, enc lr scheduler: [0.00025750000000000013], dec lr scheduler: [0.00025750000000000013]
(Epoch 20, iter 50/787) Average loss so far: 2.700
(Epoch 20, iter 100/787) Average loss so far: 2.687
(Epoch 20, iter 150/787) Average loss so far: 2.699
(Epoch 20, iter 200/787) Average loss so far: 2.703
(Epoch 20, iter 250/787) Average loss so far: 2.706
(Epoch 20, iter 300/787) Average loss so far: 2.704
(Epoch 20, iter 350/787) Average loss so far: 2.679
(Epoch 20, iter 400/787) Average loss so far: 2.706
(Epoch 20, iter 450/787) Average loss so far: 2.690
(Epoch 20, iter 500/787) Average loss so far: 2.689
(Epoch 20, iter 550/787) Average loss so far: 2.693
(Epoch 20, iter 600/787) Average loss so far: 2.709
(Epoch 20, iter 650/787) Average loss so far: 2.693
(Epoch 20, iter 700/787) Average loss so far: 2.706
(Epoch 20, iter 750/787) Average loss so far: 2.695
Average epoch loss: 2.698
This epoch took 7.94

7it [00:01,  5.96it/s]


validation loss: 3.0529042993273054


100%|██████████| 7/7 [00:02<00:00,  2.36it/s]
100%|██████████| 793/793 [00:01<00:00, 604.51it/s]


BLEU score: 0.05473522071477695, METEOR score: 0.2107675382802335
Starting epoch 22/30, enc lr scheduler: [0.00021404630011522585], dec lr scheduler: [0.00021404630011522585]
(Epoch 21, iter 50/787) Average loss so far: 2.690
(Epoch 21, iter 100/787) Average loss so far: 2.674
(Epoch 21, iter 150/787) Average loss so far: 2.676
(Epoch 21, iter 200/787) Average loss so far: 2.707
(Epoch 21, iter 250/787) Average loss so far: 2.681
(Epoch 21, iter 300/787) Average loss so far: 2.691
(Epoch 21, iter 350/787) Average loss so far: 2.691
(Epoch 21, iter 400/787) Average loss so far: 2.695
(Epoch 21, iter 450/787) Average loss so far: 2.690
(Epoch 21, iter 500/787) Average loss so far: 2.681
(Epoch 21, iter 550/787) Average loss so far: 2.690
(Epoch 21, iter 600/787) Average loss so far: 2.704
(Epoch 21, iter 650/787) Average loss so far: 2.679
(Epoch 21, iter 700/787) Average loss so far: 2.699
(Epoch 21, iter 750/787) Average loss so far: 2.675
Average epoch loss: 2.689
This epoch took 7.52

7it [00:01,  6.14it/s]


validation loss: 3.0493842533656528


100%|██████████| 7/7 [00:03<00:00,  2.28it/s]
100%|██████████| 793/793 [00:01<00:00, 541.32it/s]


BLEU score: 0.05105392764669949, METEOR score: 0.2156661550691742
Starting epoch 23/30, enc lr scheduler: [0.00017378034985236535], dec lr scheduler: [0.00017378034985236535]
(Epoch 22, iter 50/787) Average loss so far: 2.676
(Epoch 22, iter 100/787) Average loss so far: 2.671
(Epoch 22, iter 150/787) Average loss so far: 2.682
(Epoch 22, iter 200/787) Average loss so far: 2.685
(Epoch 22, iter 250/787) Average loss so far: 2.676
(Epoch 22, iter 300/787) Average loss so far: 2.687
(Epoch 22, iter 350/787) Average loss so far: 2.687
(Epoch 22, iter 400/787) Average loss so far: 2.676
(Epoch 22, iter 450/787) Average loss so far: 2.688
(Epoch 22, iter 500/787) Average loss so far: 2.678
(Epoch 22, iter 550/787) Average loss so far: 2.694
(Epoch 22, iter 600/787) Average loss so far: 2.698
(Epoch 22, iter 650/787) Average loss so far: 2.673
(Epoch 22, iter 700/787) Average loss so far: 2.676
(Epoch 22, iter 750/787) Average loss so far: 2.681
Average epoch loss: 2.682
This epoch took 7.50

7it [00:01,  5.96it/s]


validation loss: 3.0472822529929027


100%|██████████| 7/7 [00:02<00:00,  2.37it/s]
100%|██████████| 793/793 [00:01<00:00, 605.43it/s]


BLEU score: 0.05496028256787088, METEOR score: 0.21032837797650386
Starting epoch 24/30, enc lr scheduler: [0.00013714331138868998], dec lr scheduler: [0.00013714331138868998]
(Epoch 23, iter 50/787) Average loss so far: 2.669
(Epoch 23, iter 100/787) Average loss so far: 2.674
(Epoch 23, iter 150/787) Average loss so far: 2.658
(Epoch 23, iter 200/787) Average loss so far: 2.672
(Epoch 23, iter 250/787) Average loss so far: 2.679
(Epoch 23, iter 300/787) Average loss so far: 2.678
(Epoch 23, iter 350/787) Average loss so far: 2.670
(Epoch 23, iter 400/787) Average loss so far: 2.677
(Epoch 23, iter 450/787) Average loss so far: 2.673
(Epoch 23, iter 500/787) Average loss so far: 2.668
(Epoch 23, iter 550/787) Average loss so far: 2.670
(Epoch 23, iter 600/787) Average loss so far: 2.674
(Epoch 23, iter 650/787) Average loss so far: 2.698
(Epoch 23, iter 700/787) Average loss so far: 2.690
(Epoch 23, iter 750/787) Average loss so far: 2.678
Average epoch loss: 2.676
This epoch took 7.4

7it [00:01,  6.00it/s]


validation loss: 3.0462141377585277


100%|██████████| 7/7 [00:02<00:00,  2.35it/s]
100%|██████████| 793/793 [00:01<00:00, 610.16it/s]


BLEU score: 0.05558312256170381, METEOR score: 0.21142443556612248
Starting epoch 25/30, enc lr scheduler: [0.00010453658778440108], dec lr scheduler: [0.00010453658778440108]
(Epoch 24, iter 50/787) Average loss so far: 2.682
(Epoch 24, iter 100/787) Average loss so far: 2.661
(Epoch 24, iter 150/787) Average loss so far: 2.664
(Epoch 24, iter 200/787) Average loss so far: 2.688
(Epoch 24, iter 250/787) Average loss so far: 2.652
(Epoch 24, iter 300/787) Average loss so far: 2.666
(Epoch 24, iter 350/787) Average loss so far: 2.681
(Epoch 24, iter 400/787) Average loss so far: 2.678
(Epoch 24, iter 450/787) Average loss so far: 2.669
(Epoch 24, iter 500/787) Average loss so far: 2.676
(Epoch 24, iter 550/787) Average loss so far: 2.685
(Epoch 24, iter 600/787) Average loss so far: 2.672
(Epoch 24, iter 650/787) Average loss so far: 2.678
(Epoch 24, iter 700/787) Average loss so far: 2.666
(Epoch 24, iter 750/787) Average loss so far: 2.655
Average epoch loss: 2.671
This epoch took 7.5

7it [00:01,  5.94it/s]


validation loss: 3.0450666631971086


100%|██████████| 7/7 [00:03<00:00,  2.32it/s]
100%|██████████| 793/793 [00:01<00:00, 580.66it/s]


BLEU score: 0.054168369543141096, METEOR score: 0.2190973874735378
Starting epoch 26/30, enc lr scheduler: [7.631742512670285e-05], dec lr scheduler: [7.631742512670285e-05]
(Epoch 25, iter 50/787) Average loss so far: 2.660
(Epoch 25, iter 100/787) Average loss so far: 2.655
(Epoch 25, iter 150/787) Average loss so far: 2.682
(Epoch 25, iter 200/787) Average loss so far: 2.663
(Epoch 25, iter 250/787) Average loss so far: 2.658
(Epoch 25, iter 300/787) Average loss so far: 2.653
(Epoch 25, iter 350/787) Average loss so far: 2.674
(Epoch 25, iter 400/787) Average loss so far: 2.667
(Epoch 25, iter 450/787) Average loss so far: 2.669
(Epoch 25, iter 500/787) Average loss so far: 2.668
(Epoch 25, iter 550/787) Average loss so far: 2.671
(Epoch 25, iter 600/787) Average loss so far: 2.666
(Epoch 25, iter 650/787) Average loss so far: 2.666
(Epoch 25, iter 700/787) Average loss so far: 2.679
(Epoch 25, iter 750/787) Average loss so far: 2.669
Average epoch loss: 2.667
This epoch took 7.473

7it [00:01,  6.05it/s]


validation loss: 3.0444539955684116


100%|██████████| 7/7 [00:02<00:00,  2.53it/s]
100%|██████████| 793/793 [00:01<00:00, 609.17it/s]


BLEU score: 0.05549097705999941, METEOR score: 0.2134507586141872
Starting epoch 27/30, enc lr scheduler: [5.279499846691252e-05], dec lr scheduler: [5.279499846691252e-05]
(Epoch 26, iter 50/787) Average loss so far: 2.655
(Epoch 26, iter 100/787) Average loss so far: 2.655
(Epoch 26, iter 150/787) Average loss so far: 2.641
(Epoch 26, iter 200/787) Average loss so far: 2.668
(Epoch 26, iter 250/787) Average loss so far: 2.671
(Epoch 26, iter 300/787) Average loss so far: 2.661
(Epoch 26, iter 350/787) Average loss so far: 2.672
(Epoch 26, iter 400/787) Average loss so far: 2.673
(Epoch 26, iter 450/787) Average loss so far: 2.667
(Epoch 26, iter 500/787) Average loss so far: 2.671
(Epoch 26, iter 550/787) Average loss so far: 2.665
(Epoch 26, iter 600/787) Average loss so far: 2.668
(Epoch 26, iter 650/787) Average loss so far: 2.661
(Epoch 26, iter 700/787) Average loss so far: 2.657
(Epoch 26, iter 750/787) Average loss so far: 2.660
Average epoch loss: 2.663
This epoch took 7.4265

7it [00:01,  6.26it/s]


validation loss: 3.0439279760633196


100%|██████████| 7/7 [00:02<00:00,  2.39it/s]
100%|██████████| 793/793 [00:01<00:00, 605.29it/s]


BLEU score: 0.05676168975713699, METEOR score: 0.2139917855970703
Starting epoch 28/30, enc lr scheduler: [3.4227024433899005e-05], dec lr scheduler: [3.4227024433899005e-05]
(Epoch 27, iter 50/787) Average loss so far: 2.679
(Epoch 27, iter 100/787) Average loss so far: 2.660
(Epoch 27, iter 150/787) Average loss so far: 2.659
(Epoch 27, iter 200/787) Average loss so far: 2.654
(Epoch 27, iter 250/787) Average loss so far: 2.664
(Epoch 27, iter 300/787) Average loss so far: 2.655
(Epoch 27, iter 350/787) Average loss so far: 2.660
(Epoch 27, iter 400/787) Average loss so far: 2.650
(Epoch 27, iter 450/787) Average loss so far: 2.669
(Epoch 27, iter 500/787) Average loss so far: 2.670
(Epoch 27, iter 550/787) Average loss so far: 2.664
(Epoch 27, iter 600/787) Average loss so far: 2.660
(Epoch 27, iter 650/787) Average loss so far: 2.650
(Epoch 27, iter 700/787) Average loss so far: 2.664
(Epoch 27, iter 750/787) Average loss so far: 2.661
Average epoch loss: 2.661
This epoch took 7.45

7it [00:01,  6.13it/s]


validation loss: 3.0432401725224087


100%|██████████| 7/7 [00:02<00:00,  2.35it/s]
100%|██████████| 793/793 [00:01<00:00, 595.50it/s]


BLEU score: 0.05472410622862404, METEOR score: 0.2129614425943201
Starting epoch 29/30, enc lr scheduler: [2.0816937636766188e-05], dec lr scheduler: [2.0816937636766188e-05]
(Epoch 28, iter 50/787) Average loss so far: 2.666
(Epoch 28, iter 100/787) Average loss so far: 2.647
(Epoch 28, iter 150/787) Average loss so far: 2.662
(Epoch 28, iter 200/787) Average loss so far: 2.633
(Epoch 28, iter 250/787) Average loss so far: 2.664
(Epoch 28, iter 300/787) Average loss so far: 2.664
(Epoch 28, iter 350/787) Average loss so far: 2.668
(Epoch 28, iter 400/787) Average loss so far: 2.659
(Epoch 28, iter 450/787) Average loss so far: 2.665
(Epoch 28, iter 500/787) Average loss so far: 2.660
(Epoch 28, iter 550/787) Average loss so far: 2.663
(Epoch 28, iter 600/787) Average loss so far: 2.655
(Epoch 28, iter 650/787) Average loss so far: 2.662
(Epoch 28, iter 700/787) Average loss so far: 2.662
(Epoch 28, iter 750/787) Average loss so far: 2.661
Average epoch loss: 2.660
This epoch took 7.44

7it [00:01,  5.94it/s]


validation loss: 3.0424865995134627


100%|██████████| 7/7 [00:02<00:00,  2.34it/s]
100%|██████████| 793/793 [00:01<00:00, 588.28it/s]


BLEU score: 0.054989968499987295, METEOR score: 0.21584912823420402
Starting epoch 30/30, enc lr scheduler: [1.2711661792704668e-05], dec lr scheduler: [1.2711661792704668e-05]
(Epoch 29, iter 50/787) Average loss so far: 2.667
(Epoch 29, iter 100/787) Average loss so far: 2.655
(Epoch 29, iter 150/787) Average loss so far: 2.648
(Epoch 29, iter 200/787) Average loss so far: 2.672
(Epoch 29, iter 250/787) Average loss so far: 2.658
(Epoch 29, iter 300/787) Average loss so far: 2.662
(Epoch 29, iter 350/787) Average loss so far: 2.664
(Epoch 29, iter 400/787) Average loss so far: 2.649
(Epoch 29, iter 450/787) Average loss so far: 2.657
(Epoch 29, iter 500/787) Average loss so far: 2.646
(Epoch 29, iter 550/787) Average loss so far: 2.667
(Epoch 29, iter 600/787) Average loss so far: 2.670
(Epoch 29, iter 650/787) Average loss so far: 2.655
(Epoch 29, iter 700/787) Average loss so far: 2.656
(Epoch 29, iter 750/787) Average loss so far: 2.661
Average epoch loss: 2.658
This epoch took 7.

7it [00:01,  6.06it/s]


validation loss: 3.0425445352281844


100%|██████████| 7/7 [00:02<00:00,  2.39it/s]
100%|██████████| 793/793 [00:01<00:00, 619.58it/s]


BLEU score: 0.057802547829115924, METEOR score: 0.21431412239195574


## Extension 4: Multi-Layer Encoder-Decoder + NeuroLogic Decoding

In [11]:
# Re-seed torch, numpy and random with SEED (see reset_rng in the config cell)
# so this experiment starts from a known RNG state.
reset_rng()

In [12]:
# Model hyper-parameters for the multi-layer (stacked LSTM) attention model.
embedding_size = 300
num_layers = 3

encoder_multilayer_attn = EncoderRNN(
    vocab.n_unique_words,
    embedding_size=embedding_size,
    hidden_size=HIDDEN_SIZE,
    padding_value=vocab.word2index(PAD_WORD),
    num_lstm_layers=num_layers,
).to(DEVICE)

# In the training script the decoder is always fed a non-end token, so it never
# needs to generate padding; it should also never generate "<UNKNOWN>".
# NOTE(review): output_size below excludes only ONE vocabulary entry
# (n_unique_words - 1), although the remark above names two tokens (an older
# non-attention DecoderRNN variant used n_unique_words - 2) -- confirm which
# token the -1 is meant to drop.
decoder_multilayer_attn = AttnDecoderRNN(
    embedding_size,
    hidden_size=HIDDEN_SIZE,
    output_size=vocab.n_unique_words - 1,
    padding_val=vocab.word2index(PAD_WORD),
    dropout=DROPOUT,
    num_lstm_layers=num_layers,
).to(DEVICE)

In [13]:
# Optimisation settings for the multi-layer attention run.
initial_lr = 1e-3
min_lr = 1e-5
n_epochs = 30
batch_size = 128

# Encoder and decoder each get their own Adam optimizer and their own cosine
# annealing schedule (from initial_lr down to min_lr, with T_max = n_epochs).
encoder_multilayer_attn_optimizer = optim.Adam(encoder_multilayer_attn.parameters(), lr=initial_lr)
decoder_multilayer_attn_optimizer = optim.Adam(decoder_multilayer_attn.parameters(), lr=initial_lr)
enc_multilayer_attn_scheduler = CosineAnnealingLR(encoder_multilayer_attn_optimizer, T_max=n_epochs, eta_min=min_lr)
dec_multilayer_attn_scheduler = CosineAnnealingLR(decoder_multilayer_attn_optimizer, T_max=n_epochs, eta_min=min_lr)

# Passed to train() and save_log() to label this run.
identifier = "multilayer_attn_adam_without_intermediate_tags_wd0_lr1e-3"

# The results are bound to multilayer_attn_* names, matching every other name in
# this cell. The previous generic attn_* names would silently overwrite the
# losses/log of any other attention run that uses them (hidden notebook state).
# verbose_iter_interval=10 prints the running loss every 10 iterations.
multilayer_attn_epoch_losses, multilayer_attn_val_epoch_losses, multilayer_attn_log = train(
    encoder_multilayer_attn,
    decoder_multilayer_attn,
    encoder_multilayer_attn_optimizer,
    decoder_multilayer_attn_optimizer,
    train_ds,
    n_epochs=n_epochs,
    vocab=vocab,
    decoder_mode="attention",
    batch_size=batch_size,
    enc_lr_scheduler=enc_multilayer_attn_scheduler,
    dec_lr_scheduler=dec_multilayer_attn_scheduler,
    dev_ds_val_loss=dev_ds_val_loss,
    dev_ds_val_met=dev_ds_val_met,
    identifier=identifier,
    verbose_iter_interval=10,
)

# Hand the training log plus optimizer/scheduler objects to save_log under the run identifier.
save_log(
    identifier,
    multilayer_attn_log,
    encoder_multilayer_attn_optimizer,
    decoder_multilayer_attn_optimizer,
    enc_multilayer_attn_scheduler,
    dec_multilayer_attn_scheduler,
)

Starting epoch 1/30, enc lr scheduler: [0.001], dec lr scheduler: [0.001]
(Epoch 0, iter 10/787) Average loss so far: 9.766
(Epoch 0, iter 20/787) Average loss so far: 6.787
(Epoch 0, iter 30/787) Average loss so far: 6.162
(Epoch 0, iter 40/787) Average loss so far: 6.129
(Epoch 0, iter 50/787) Average loss so far: 6.097
(Epoch 0, iter 60/787) Average loss so far: 6.059
(Epoch 0, iter 70/787) Average loss so far: 6.053
(Epoch 0, iter 80/787) Average loss so far: 6.079
(Epoch 0, iter 90/787) Average loss so far: 6.049
(Epoch 0, iter 100/787) Average loss so far: 6.047
(Epoch 0, iter 110/787) Average loss so far: 6.048
(Epoch 0, iter 120/787) Average loss so far: 6.034
(Epoch 0, iter 130/787) Average loss so far: 6.023
(Epoch 0, iter 140/787) Average loss so far: 6.032
(Epoch 0, iter 150/787) Average loss so far: 6.036
(Epoch 0, iter 160/787) Average loss so far: 6.031
(Epoch 0, iter 170/787) Average loss so far: 6.026
(Epoch 0, iter 180/787) Average loss so far: 6.046
(Epoch 0, iter 19

KeyboardInterrupt: 

---

## Evaluation

### Without attention

In [12]:
# Presumably restores the checkpoint saved under this identifier into `encoder`/`decoder`
# (the "_ep_24" suffix suggests the epoch-24 weights) -- load_model is defined outside
# this notebook, confirm there.
load_model(encoder, decoder, "adam_without_intermediate_tags_with_val_wd0_lr1e-3_ep_24")

In [13]:
# Run the no-attention ("basic") decoder over the test set and collect the generated
# recipes together with their ground-truth references.
# NOTE: `eval` here is presumably the project function star-imported from eval.py at the
# top of the notebook; it shadows Python's builtin eval.
# NOTE: all_decoder_outs / all_gt_recipes are reassigned by every evaluation section, so
# the metric cells score whichever eval cell ran most recently.
all_decoder_outs, all_gt_recipes = eval(
    encoder,
    decoder,
    test_ds,
    vocab,
    batch_size=128,
    decoder_mode="basic",
    max_recipe_len=MAX_RECIPE_LEN,
)

100%|██████████| 7/7 [00:02<00:00,  3.07it/s]


In [14]:
# BLEU for the no-attention model: references first, generations second.
# Depends on all_gt_recipes / all_decoder_outs from the eval cell directly above.
calc_bleu(all_gt_recipes, all_decoder_outs)

0.044537752361233085

In [15]:
# METEOR for the same no-attention outputs.
# split_gt=False: presumably because the references are already token lists -- confirm in calc_meteor.
calc_meteor(all_gt_recipes, all_decoder_outs, split_gt=False)

100%|██████████| 774/774 [00:01<00:00, 435.87it/s]


0.20157579551309057

### With attention

In [15]:
# Presumably restores the checkpoint saved under this identifier into the attention
# encoder/decoder (the "_ep_26" suffix suggests the epoch-26 weights) -- confirm in load_model.
load_model(encoder_attn, decoder_attn, "attn_adam_without_intermediate_tags_wd0_lr1e-3_ep_26")

In [16]:
# Run the attention decoder over the test set and collect generations plus references.
# NOTE: this overwrites the all_decoder_outs / all_gt_recipes produced by the
# no-attention evaluation; re-run this cell before the metric cells of this section.
all_decoder_outs, all_gt_recipes = eval(
    encoder_attn,
    decoder_attn,
    test_ds,
    vocab,
    batch_size=128,
    decoder_mode="attention",
    max_recipe_len=MAX_RECIPE_LEN,
)

100%|██████████| 7/7 [00:02<00:00,  2.40it/s]


In [18]:
# Qualitative check: reference tokens of the first test recipe.
all_gt_recipes[0]

['<RECIPE_START>',
 'cream',
 'butter',
 'and',
 'sugar',
 'add',
 'egg',
 'yolks',
 ',',
 'one',
 'at',
 'a',
 'time',
 'and',
 'beat',
 'until',
 'mixture',
 'is',
 'very',
 'light',
 'and',
 'lemon',
 'colored',
 'add',
 'vanilla',
 'and',
 'milk',
 'in',
 'another',
 'bowl',
 'beat',
 'egg',
 'whites',
 'until',
 'foamy',
 'fold',
 'into',
 'creamed',
 'mixture',
 'fold',
 'in',
 '11/2',
 'cups',
 'of',
 'the',
 'coconut',
 'pour',
 'into',
 'baked',
 'pie',
 'shell',
 'and',
 'sprinkle',
 'with',
 'nutmeg',
 'top',
 'with',
 'remaining',
 'grated',
 'coconut',
 'bake',
 'in',
 'preheated',
 '350',
 'oven',
 'for',
 'about',
 '35',
 'minutes',
 'or',
 'until',
 'filling',
 'is',
 'just',
 'set',
 'serve',
 'warm',
 'or',
 'cold',
 'courtesy',
 'of',
 'dale',
 '&',
 'gail',
 'shipp',
 ',',
 'columbia',
 'md',
 '<RECIPE_END>']

In [19]:
# Qualitative check: tokens generated for that same first test example.
all_decoder_outs[0]

['<RECIPE_START>',
 'in',
 'a',
 'medium',
 'saucepan',
 ',',
 'combine',
 'the',
 'sugar',
 'and',
 'cocoa',
 ',',
 'and',
 'cook',
 'over',
 'medium',
 'heat',
 ',',
 'stirring',
 'constantly',
 ',',
 'until',
 'the',
 'mixture',
 'is',
 'smooth',
 'and',
 'the',
 'sugar',
 'is',
 'dissolved',
 ',',
 'about',
 '5',
 'minutes',
 'remove',
 'from',
 'the',
 'heat',
 'and',
 'stir',
 'in',
 'the',
 'milk',
 ',',
 'vanilla',
 'and',
 'salt',
 'and',
 'mix',
 'well',
 'pour',
 'into',
 'a',
 '9',
 '-',
 'inch',
 'springform',
 'pan',
 'bake',
 'in',
 'a',
 'preheated',
 '350',
 'degree',
 'f',
 'oven',
 'for',
 '45',
 'minutes',
 'or',
 'until',
 'the',
 'custard',
 'is',
 'set',
 'cool',
 'completely',
 'on',
 'a',
 'wire',
 'rack',
 'and',
 'refrigerate',
 'for',
 'at',
 'least',
 '2',
 'hours',
 'before',
 'serving',
 '<RECIPE_END>']

In [17]:
# BLEU for the attention model: references first, generations second.
# Depends on the outputs of this section's eval cell being the most recent ones.
calc_bleu(all_gt_recipes, all_decoder_outs)

0.04949963728618708

In [20]:
# METEOR for the attention model's outputs.
# split_gt=False: presumably because the references are already token lists -- confirm in calc_meteor.
calc_meteor(all_gt_recipes, all_decoder_outs, split_gt=False)

100%|██████████| 774/774 [00:01<00:00, 636.00it/s]


0.20855536635420194

---

In [19]:
# Evaluate the attention model with pretrained embeddings on the test set.
# NOTE(review): unlike the two evaluation sections above, no load_model call precedes
# this cell in this part of the notebook, so it scores whatever weights
# encoder_pretrained_embed / decoder_pretrained_embed currently hold in memory --
# confirm that this is the intended checkpoint.
all_decoder_outs, all_gt_recipes = eval(
    encoder_pretrained_embed,
    decoder_pretrained_embed,
    test_ds,
    vocab,
    batch_size=128,
    decoder_mode="attention",
    max_recipe_len=MAX_RECIPE_LEN,
)

100%|██████████| 7/7 [00:03<00:00,  2.11it/s]


In [20]:
# BLEU for the pretrained-embedding model: references first, generations second.
# Depends on the outputs of the eval cell directly above being the most recent ones.
calc_bleu(all_gt_recipes, all_decoder_outs)

0.05519142942964554

In [21]:
# METEOR for the pretrained-embedding model's outputs.
# split_gt=False: presumably because the references are already token lists -- confirm in calc_meteor.
calc_meteor(all_gt_recipes, all_decoder_outs, split_gt=False)

100%|██████████| 774/774 [00:01<00:00, 586.11it/s]


0.2100281888663272

---

## Metric Sample

In [None]:
# Ingredient lexicon (ingredient_set.json, used to tell ingredients from
# non-ingredients) and the regex derived from it.
all_ings = get_all_ingredients("./ingredient_set.json")
all_ings_regex = get_ingredients_regex(all_ings)

# The metric sample: its input ingredients, gold recipe and generated recipe.
(
    metric_sample_ings,
    metric_sample_gold_recipe,
    metric_sample_generated_recipe,
) = load_metric_sample("./metric_sample.txt")

In [None]:
# Ingredient-coverage metrics for the generated metric-sample recipe: the proportion of
# input ingredients it uses and the number of extra ingredients it introduces.
prop_inp_ings, n_extra_ings = get_prop_input_num_extra_ingredients(
    metric_sample_ings,
    metric_sample_generated_recipe,
    all_ings_regex,
    verbose=True,
    metric_sample=True,
)
print(
    f"\nproportion of input ingredients: {prop_inp_ings}\nnumber of extra ingredients: {n_extra_ings}"
)

In [None]:
# BLEU and METEOR for the single metric sample; each recipe is wrapped in a one-element
# list because the metric helpers take collections of references and generations.
# split_gen=True: presumably the generated recipe is a raw string that still needs
# tokenising -- confirm in calc_bleu / calc_meteor.
# NOTE(review): calc_meteor is called without split_gt here, unlike the evaluation
# cells above (split_gt=False) -- confirm the default is what is wanted.
bleu_score = calc_bleu(
    [metric_sample_gold_recipe],
    [metric_sample_generated_recipe],
    split_gen=True,
)
meteor_score = calc_meteor(
    [metric_sample_gold_recipe],
    [metric_sample_generated_recipe],
    split_gen=True,
)
print(f"BLEU score: {bleu_score}, METEOR score: {meteor_score}")