In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# !pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
# !pip install matplotlib numpy pandas tqdm nltk

# for separating ingredients vs non-ingredients
# NOTE: if using Windows to run this, need to download GNU Wget
# !wget -c https://raw.githubusercontent.com/williamLyh/RecipeWithPlans/main/ingredient_set.json -O ingredient_set.json

In [3]:
import os
import math
import re
import string
import numpy as np
import pandas as pd
import random
import json
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR, MultiStepLR, CosineAnnealingWarmRestarts
import nltk
from nltk.translate.bleu_score import corpus_bleu, sentence_bleu
from nltk.translate import meteor

from data import *
from encoder_decoder import *
from train import *
from eval import *
from utils import *
from neuro_dec import *

# required for bleu
nltk.download("wordnet")
nltk.download("punkt")

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package wordnet to /home/student/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /home/student/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

---

In [4]:
SEED = 31989101
HIDDEN_SIZE = 256
MAX_INGR_LEN = 150 # fixed from assignment
MAX_RECIPE_LEN = 600
DROPOUT = 0.1
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## ensuring reproducibility
def reset_rng():
    torch.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

reset_rng()

# to easily read ingredients and instructions
pd.set_option('display.max_colwidth', 2000)

print(f"Using device: {DEVICE}")

Using device: cuda


In [5]:
data_root = "./Cooking_Dataset"
add_intermediate_tag=False

train_df_orig = pd.read_csv(os.path.join(data_root, "train.csv"), usecols=['Ingredients', 'Recipe'])
dev_df_orig = pd.read_csv(os.path.join(data_root, "dev.csv"), usecols=['Ingredients', 'Recipe'])
test_df_orig = pd.read_csv(os.path.join(data_root, "test.csv"), usecols=['Ingredients', 'Recipe'])

In [6]:
train_df = preprocess_data(train_df_orig, max_ingr_len=MAX_INGR_LEN, max_recipe_len=MAX_RECIPE_LEN, add_intermediate_tag=add_intermediate_tag)

Number of data samples before preprocessing: 101340
Number of data samples after preprocessing: 100637 (99.306%)


In [7]:
dev_df = preprocess_data(dev_df_orig, max_ingr_len=MAX_INGR_LEN, max_recipe_len=MAX_RECIPE_LEN, add_intermediate_tag=add_intermediate_tag)

Number of data samples before preprocessing: 797
Number of data samples after preprocessing: 793 (99.498%)


In [8]:
test_df = preprocess_data(test_df_orig, max_ingr_len=MAX_INGR_LEN, max_recipe_len=MAX_RECIPE_LEN, add_intermediate_tag=add_intermediate_tag)

Number of data samples before preprocessing: 778
Number of data samples after preprocessing: 774 (99.486%)


In [9]:
vocab = Vocabulary(add_intermediate_tag=add_intermediate_tag)
vocab.populate(train_df)
vocab.n_unique_words

  0%|          | 0/100637 [00:00<?, ?it/s]

100%|██████████| 100637/100637 [00:06<00:00, 15730.54it/s]


44315

In [10]:
train_ds = RecipeDataset(train_df, vocab)
# subset_train_ds = RecipeDataset(train_df[:250], vocab) # ! REMOVE LATER
dev_ds_val_loss = RecipeDataset(dev_df, vocab, train=True) # used for getting validation loss
dev_ds_val_met = RecipeDataset(dev_df, vocab, train=False) # used for getting validation BLEU, and other metrics
test_ds = RecipeDataset(test_df, vocab, train=False)

## Encoder-Decoder (Base)

In [11]:
embedding_size=300
encoder = EncoderRNN(vocab.n_unique_words, embedding_size=embedding_size, hidden_size=HIDDEN_SIZE, padding_value=vocab.word2index(PAD_WORD)).to(DEVICE)
# in the training script, decoder is always fed a non-end token and thus never needs to generate padding
# also it should never generate "<UNKNOWN>"
decoder = DecoderRNN(embedding_size=embedding_size,hidden_size=HIDDEN_SIZE, output_size=vocab.n_unique_words-1).to(DEVICE)

In [12]:
initial_lr=1e-3
min_lr = 1e-5
n_epochs = 30
batch_size=128
encoder_optimizer = optim.Adam(encoder.parameters(), lr=initial_lr)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=initial_lr)
enc_scheduler = CosineAnnealingLR(encoder_optimizer, T_max=n_epochs, eta_min=min_lr)
dec_scheduler = CosineAnnealingLR(decoder_optimizer, T_max=n_epochs, eta_min=min_lr)
# enc_scheduler = MultiStepLR(encoder_optimizer, milestones=[15], gamma=0.1)
# dec_scheduler = MultiStepLR(decoder_optimizer, milestones=[15], gamma=0.1)
identifier="adam_without_intermediate_tags_with_val_wd0_lr1e-3"
epoch_losses, val_epoch_losses, log = train(encoder, decoder, encoder_optimizer, decoder_optimizer, train_ds, 
                     n_epochs=n_epochs, vocab=vocab, decoder_mode="basic", batch_size=batch_size, 
                     enc_lr_scheduler=enc_scheduler, dec_lr_scheduler=dec_scheduler, 
                     dev_ds_val_loss = dev_ds_val_loss, dev_ds_val_met=dev_ds_val_met, identifier=identifier,
                     verbose_iter_interval=50)

save_log(identifier, log, encoder_optimizer, decoder_optimizer, enc_scheduler, dec_scheduler)

Starting epoch 1/30, enc lr scheduler: [0.001], dec lr scheduler: [0.001]
(Epoch 0, iter 50/787) Average loss so far: 7.144
(Epoch 0, iter 100/787) Average loss so far: 5.967
(Epoch 0, iter 150/787) Average loss so far: 5.730
(Epoch 0, iter 200/787) Average loss so far: 5.470
(Epoch 0, iter 250/787) Average loss so far: 5.212
(Epoch 0, iter 300/787) Average loss so far: 5.019
(Epoch 0, iter 350/787) Average loss so far: 4.865
(Epoch 0, iter 400/787) Average loss so far: 4.737
(Epoch 0, iter 450/787) Average loss so far: 4.632
(Epoch 0, iter 500/787) Average loss so far: 4.540
(Epoch 0, iter 550/787) Average loss so far: 4.449
(Epoch 0, iter 600/787) Average loss so far: 4.382
(Epoch 0, iter 650/787) Average loss so far: 4.317
(Epoch 0, iter 700/787) Average loss so far: 4.246
(Epoch 0, iter 750/787) Average loss so far: 4.207
Average epoch loss: 4.956
This epoch took 5.9502798040707905 mins. Time remaining: 2.0 hrs 52.0 mins.


7it [00:00,  8.22it/s]


validation loss: 4.188371658325195


100%|██████████| 7/7 [00:07<00:00,  1.02s/it]
100%|██████████| 793/793 [00:09<00:00, 87.30it/s] 


BLEU score: 0.0040273883646875595, METEOR score: 0.09091902403865486
Starting epoch 2/30, enc lr scheduler: [0.0009972883382072953], dec lr scheduler: [0.0009972883382072953]
(Epoch 1, iter 50/787) Average loss so far: 4.108
(Epoch 1, iter 100/787) Average loss so far: 4.071
(Epoch 1, iter 150/787) Average loss so far: 4.044
(Epoch 1, iter 200/787) Average loss so far: 4.004
(Epoch 1, iter 250/787) Average loss so far: 3.953
(Epoch 1, iter 300/787) Average loss so far: 3.925
(Epoch 1, iter 350/787) Average loss so far: 3.904
(Epoch 1, iter 400/787) Average loss so far: 3.880
(Epoch 1, iter 450/787) Average loss so far: 3.846
(Epoch 1, iter 500/787) Average loss so far: 3.816
(Epoch 1, iter 550/787) Average loss so far: 3.792
(Epoch 1, iter 600/787) Average loss so far: 3.775
(Epoch 1, iter 650/787) Average loss so far: 3.753
(Epoch 1, iter 700/787) Average loss so far: 3.747
(Epoch 1, iter 750/787) Average loss so far: 3.701
Average epoch loss: 3.878
This epoch took 5.9652024269104 min

7it [00:00,  8.33it/s]


validation loss: 3.762834276471819


100%|██████████| 7/7 [00:03<00:00,  1.76it/s]
100%|██████████| 793/793 [00:03<00:00, 253.49it/s]


BLEU score: 0.014792012153973902, METEOR score: 0.17129374579676301
Starting epoch 3/30, enc lr scheduler: [0.0009891830623632338], dec lr scheduler: [0.0009891830623632338]
(Epoch 2, iter 50/787) Average loss so far: 3.655
(Epoch 2, iter 100/787) Average loss so far: 3.650
(Epoch 2, iter 150/787) Average loss so far: 3.604
(Epoch 2, iter 200/787) Average loss so far: 3.606
(Epoch 2, iter 250/787) Average loss so far: 3.587
(Epoch 2, iter 300/787) Average loss so far: 3.586
(Epoch 2, iter 350/787) Average loss so far: 3.569
(Epoch 2, iter 400/787) Average loss so far: 3.535
(Epoch 2, iter 450/787) Average loss so far: 3.553
(Epoch 2, iter 500/787) Average loss so far: 3.543
(Epoch 2, iter 550/787) Average loss so far: 3.522
(Epoch 2, iter 600/787) Average loss so far: 3.509
(Epoch 2, iter 650/787) Average loss so far: 3.499
(Epoch 2, iter 700/787) Average loss so far: 3.479
(Epoch 2, iter 750/787) Average loss so far: 3.466
Average epoch loss: 3.553
This epoch took 5.964460162321727 mi

7it [00:00,  8.39it/s]


validation loss: 3.558991943086897


100%|██████████| 7/7 [00:04<00:00,  1.70it/s]
100%|██████████| 793/793 [00:03<00:00, 208.83it/s]


BLEU score: 0.013738925056068545, METEOR score: 0.1597429689849473
Starting epoch 4/30, enc lr scheduler: [0.0009757729755661011], dec lr scheduler: [0.0009757729755661011]
(Epoch 3, iter 50/787) Average loss so far: 3.436
(Epoch 3, iter 100/787) Average loss so far: 3.407
(Epoch 3, iter 150/787) Average loss so far: 3.395
(Epoch 3, iter 200/787) Average loss so far: 3.389
(Epoch 3, iter 250/787) Average loss so far: 3.381
(Epoch 3, iter 300/787) Average loss so far: 3.382
(Epoch 3, iter 350/787) Average loss so far: 3.371
(Epoch 3, iter 400/787) Average loss so far: 3.364
(Epoch 3, iter 450/787) Average loss so far: 3.346
(Epoch 3, iter 500/787) Average loss so far: 3.333
(Epoch 3, iter 550/787) Average loss so far: 3.330
(Epoch 3, iter 600/787) Average loss so far: 3.347
(Epoch 3, iter 650/787) Average loss so far: 3.328
(Epoch 3, iter 700/787) Average loss so far: 3.305
(Epoch 3, iter 750/787) Average loss so far: 3.300
Average epoch loss: 3.359
This epoch took 5.93084941705068 mins

7it [00:00,  8.37it/s]


validation loss: 3.4334449768066406


100%|██████████| 7/7 [00:03<00:00,  2.01it/s]
100%|██████████| 793/793 [00:02<00:00, 284.67it/s]


BLEU score: 0.01996062579533466, METEOR score: 0.17012901848915718
Starting epoch 5/30, enc lr scheduler: [0.0009572050015330874], dec lr scheduler: [0.0009572050015330874]
(Epoch 4, iter 50/787) Average loss so far: 3.272
(Epoch 4, iter 100/787) Average loss so far: 3.258
(Epoch 4, iter 150/787) Average loss so far: 3.250
(Epoch 4, iter 200/787) Average loss so far: 3.243
(Epoch 4, iter 250/787) Average loss so far: 3.239
(Epoch 4, iter 300/787) Average loss so far: 3.247
(Epoch 4, iter 350/787) Average loss so far: 3.219
(Epoch 4, iter 400/787) Average loss so far: 3.228
(Epoch 4, iter 450/787) Average loss so far: 3.223
(Epoch 4, iter 500/787) Average loss so far: 3.212
(Epoch 4, iter 550/787) Average loss so far: 3.211
(Epoch 4, iter 600/787) Average loss so far: 3.203
(Epoch 4, iter 650/787) Average loss so far: 3.205
(Epoch 4, iter 700/787) Average loss so far: 3.205
(Epoch 4, iter 750/787) Average loss so far: 3.193
Average epoch loss: 3.227
This epoch took 5.959438494841257 min

7it [00:00,  8.29it/s]


validation loss: 3.358621767589024


100%|██████████| 7/7 [00:02<00:00,  2.36it/s]
100%|██████████| 793/793 [00:02<00:00, 380.90it/s]


BLEU score: 0.027415656741189266, METEOR score: 0.1823641169217577
Starting epoch 6/30, enc lr scheduler: [0.0009336825748732973], dec lr scheduler: [0.0009336825748732973]
(Epoch 5, iter 50/787) Average loss so far: 3.157
(Epoch 5, iter 100/787) Average loss so far: 3.144
(Epoch 5, iter 150/787) Average loss so far: 3.140
(Epoch 5, iter 200/787) Average loss so far: 3.148
(Epoch 5, iter 250/787) Average loss so far: 3.115
(Epoch 5, iter 300/787) Average loss so far: 3.129
(Epoch 5, iter 350/787) Average loss so far: 3.146
(Epoch 5, iter 400/787) Average loss so far: 3.139
(Epoch 5, iter 450/787) Average loss so far: 3.141
(Epoch 5, iter 500/787) Average loss so far: 3.121
(Epoch 5, iter 550/787) Average loss so far: 3.120
(Epoch 5, iter 600/787) Average loss so far: 3.127
(Epoch 5, iter 650/787) Average loss so far: 3.111
(Epoch 5, iter 700/787) Average loss so far: 3.113
(Epoch 5, iter 750/787) Average loss so far: 3.111
Average epoch loss: 3.129
This epoch took 5.956571054458618 min

7it [00:00,  8.34it/s]


validation loss: 3.3043336187090193


100%|██████████| 7/7 [00:02<00:00,  2.46it/s]
100%|██████████| 793/793 [00:01<00:00, 401.47it/s]


BLEU score: 0.029595775761603153, METEOR score: 0.1858161071099046
Starting epoch 7/30, enc lr scheduler: [0.0009054634122155991], dec lr scheduler: [0.0009054634122155991]
(Epoch 6, iter 50/787) Average loss so far: 3.067
(Epoch 6, iter 100/787) Average loss so far: 3.064
(Epoch 6, iter 150/787) Average loss so far: 3.049
(Epoch 6, iter 200/787) Average loss so far: 3.051
(Epoch 6, iter 250/787) Average loss so far: 3.078
(Epoch 6, iter 300/787) Average loss so far: 3.066
(Epoch 6, iter 350/787) Average loss so far: 3.067
(Epoch 6, iter 400/787) Average loss so far: 3.046
(Epoch 6, iter 450/787) Average loss so far: 3.054
(Epoch 6, iter 500/787) Average loss so far: 3.038
(Epoch 6, iter 550/787) Average loss so far: 3.039
(Epoch 6, iter 600/787) Average loss so far: 3.050
(Epoch 6, iter 650/787) Average loss so far: 3.060
(Epoch 6, iter 700/787) Average loss so far: 3.054
(Epoch 6, iter 750/787) Average loss so far: 3.032
Average epoch loss: 3.054
This epoch took 5.976559368769328 min

7it [00:00,  8.35it/s]


validation loss: 3.268099410193307


100%|██████████| 7/7 [00:03<00:00,  2.20it/s]
100%|██████████| 793/793 [00:02<00:00, 323.56it/s]


BLEU score: 0.024870246390936355, METEOR score: 0.1775517379842907
Starting epoch 8/30, enc lr scheduler: [0.0008728566886113102], dec lr scheduler: [0.0008728566886113102]
(Epoch 7, iter 50/787) Average loss so far: 3.002
(Epoch 7, iter 100/787) Average loss so far: 2.993
(Epoch 7, iter 150/787) Average loss so far: 2.997
(Epoch 7, iter 200/787) Average loss so far: 2.996
(Epoch 7, iter 250/787) Average loss so far: 2.998
(Epoch 7, iter 300/787) Average loss so far: 2.999
(Epoch 7, iter 350/787) Average loss so far: 2.998
(Epoch 7, iter 400/787) Average loss so far: 2.988
(Epoch 7, iter 450/787) Average loss so far: 2.995
(Epoch 7, iter 500/787) Average loss so far: 2.992
(Epoch 7, iter 550/787) Average loss so far: 2.994
(Epoch 7, iter 600/787) Average loss so far: 3.009
(Epoch 7, iter 650/787) Average loss so far: 2.974
(Epoch 7, iter 700/787) Average loss so far: 2.990
(Epoch 7, iter 750/787) Average loss so far: 2.992
Average epoch loss: 2.994
This epoch took 5.975848956902822 min

7it [00:00,  8.45it/s]


validation loss: 3.2407847813197543


100%|██████████| 7/7 [00:03<00:00,  2.18it/s]
100%|██████████| 793/793 [00:02<00:00, 335.07it/s]


BLEU score: 0.024665420986828435, METEOR score: 0.18374070954399746
Starting epoch 9/30, enc lr scheduler: [0.0008362196501476349], dec lr scheduler: [0.0008362196501476349]
(Epoch 8, iter 50/787) Average loss so far: 2.956
(Epoch 8, iter 100/787) Average loss so far: 2.948
(Epoch 8, iter 150/787) Average loss so far: 2.961
(Epoch 8, iter 200/787) Average loss so far: 2.941
(Epoch 8, iter 250/787) Average loss so far: 2.955
(Epoch 8, iter 300/787) Average loss so far: 2.945
(Epoch 8, iter 350/787) Average loss so far: 2.959
(Epoch 8, iter 400/787) Average loss so far: 2.946
(Epoch 8, iter 450/787) Average loss so far: 2.938
(Epoch 8, iter 500/787) Average loss so far: 2.930
(Epoch 8, iter 550/787) Average loss so far: 2.946
(Epoch 8, iter 600/787) Average loss so far: 2.937
(Epoch 8, iter 650/787) Average loss so far: 2.956
(Epoch 8, iter 700/787) Average loss so far: 2.938
(Epoch 8, iter 750/787) Average loss so far: 2.926
Average epoch loss: 2.945
This epoch took 5.916466542085012 mi

7it [00:00,  8.38it/s]


validation loss: 3.2253596442086354


100%|██████████| 7/7 [00:02<00:00,  2.75it/s]
100%|██████████| 793/793 [00:01<00:00, 526.72it/s]


BLEU score: 0.0382104543038602, METEOR score: 0.19038100657935783
Starting epoch 10/30, enc lr scheduler: [0.0007959536998847742], dec lr scheduler: [0.0007959536998847742]
(Epoch 9, iter 50/787) Average loss so far: 2.894
(Epoch 9, iter 100/787) Average loss so far: 2.909
(Epoch 9, iter 150/787) Average loss so far: 2.912
(Epoch 9, iter 200/787) Average loss so far: 2.911
(Epoch 9, iter 250/787) Average loss so far: 2.901
(Epoch 9, iter 300/787) Average loss so far: 2.888
(Epoch 9, iter 350/787) Average loss so far: 2.905
(Epoch 9, iter 400/787) Average loss so far: 2.903
(Epoch 9, iter 450/787) Average loss so far: 2.915
(Epoch 9, iter 500/787) Average loss so far: 2.916
(Epoch 9, iter 550/787) Average loss so far: 2.901
(Epoch 9, iter 600/787) Average loss so far: 2.902
(Epoch 9, iter 650/787) Average loss so far: 2.888
(Epoch 9, iter 700/787) Average loss so far: 2.900
(Epoch 9, iter 750/787) Average loss so far: 2.907
Average epoch loss: 2.904
This epoch took 5.962946979204814 min

7it [00:00,  8.39it/s]


validation loss: 3.2130907603672574


100%|██████████| 7/7 [00:02<00:00,  2.67it/s]
100%|██████████| 793/793 [00:01<00:00, 569.24it/s]


BLEU score: 0.04252454242984992, METEOR score: 0.19856427989507028
Starting epoch 11/30, enc lr scheduler: [0.0007525], dec lr scheduler: [0.0007525]
(Epoch 10, iter 50/787) Average loss so far: 2.861
(Epoch 10, iter 100/787) Average loss so far: 2.868
(Epoch 10, iter 150/787) Average loss so far: 2.872
(Epoch 10, iter 200/787) Average loss so far: 2.868
(Epoch 10, iter 250/787) Average loss so far: 2.862
(Epoch 10, iter 300/787) Average loss so far: 2.882
(Epoch 10, iter 350/787) Average loss so far: 2.873
(Epoch 10, iter 400/787) Average loss so far: 2.874
(Epoch 10, iter 450/787) Average loss so far: 2.877
(Epoch 10, iter 500/787) Average loss so far: 2.884
(Epoch 10, iter 550/787) Average loss so far: 2.858
(Epoch 10, iter 600/787) Average loss so far: 2.864
(Epoch 10, iter 650/787) Average loss so far: 2.873
(Epoch 10, iter 700/787) Average loss so far: 2.876
(Epoch 10, iter 750/787) Average loss so far: 2.870
Average epoch loss: 2.870
This epoch took 5.9633493542671205 mins. Time

7it [00:00,  8.27it/s]


validation loss: 3.20478309903826


100%|██████████| 7/7 [00:02<00:00,  3.00it/s]
100%|██████████| 793/793 [00:01<00:00, 693.53it/s]


BLEU score: 0.04374173844572283, METEOR score: 0.19134044687967858
Starting epoch 12/30, enc lr scheduler: [0.0007063346383225212], dec lr scheduler: [0.0007063346383225212]
(Epoch 11, iter 50/787) Average loss so far: 2.849
(Epoch 11, iter 100/787) Average loss so far: 2.842
(Epoch 11, iter 150/787) Average loss so far: 2.845
(Epoch 11, iter 200/787) Average loss so far: 2.840
(Epoch 11, iter 250/787) Average loss so far: 2.825
(Epoch 11, iter 300/787) Average loss so far: 2.858
(Epoch 11, iter 350/787) Average loss so far: 2.837
(Epoch 11, iter 400/787) Average loss so far: 2.829
(Epoch 11, iter 450/787) Average loss so far: 2.836
(Epoch 11, iter 500/787) Average loss so far: 2.840
(Epoch 11, iter 550/787) Average loss so far: 2.844
(Epoch 11, iter 600/787) Average loss so far: 2.850
(Epoch 11, iter 650/787) Average loss so far: 2.833
(Epoch 11, iter 700/787) Average loss so far: 2.848
(Epoch 11, iter 750/787) Average loss so far: 2.844
Average epoch loss: 2.842
This epoch took 5.965

7it [00:00,  8.35it/s]


validation loss: 3.197985989706857


100%|██████████| 7/7 [00:02<00:00,  2.95it/s]
100%|██████████| 793/793 [00:01<00:00, 654.03it/s]


BLEU score: 0.04589529895972353, METEOR score: 0.1966180841918361
Starting epoch 13/30, enc lr scheduler: [0.000657963412215599], dec lr scheduler: [0.000657963412215599]
(Epoch 12, iter 50/787) Average loss so far: 2.813
(Epoch 12, iter 100/787) Average loss so far: 2.823
(Epoch 12, iter 150/787) Average loss so far: 2.817
(Epoch 12, iter 200/787) Average loss so far: 2.818
(Epoch 12, iter 250/787) Average loss so far: 2.804
(Epoch 12, iter 300/787) Average loss so far: 2.821
(Epoch 12, iter 350/787) Average loss so far: 2.820
(Epoch 12, iter 400/787) Average loss so far: 2.824
(Epoch 12, iter 450/787) Average loss so far: 2.837
(Epoch 12, iter 500/787) Average loss so far: 2.799
(Epoch 12, iter 550/787) Average loss so far: 2.822
(Epoch 12, iter 600/787) Average loss so far: 2.821
(Epoch 12, iter 650/787) Average loss so far: 2.810
(Epoch 12, iter 700/787) Average loss so far: 2.827
(Epoch 12, iter 750/787) Average loss so far: 2.812
Average epoch loss: 2.817
This epoch took 5.939353

7it [00:00,  8.27it/s]


validation loss: 3.1950793266296387


100%|██████████| 7/7 [00:02<00:00,  2.84it/s]
100%|██████████| 793/793 [00:01<00:00, 566.05it/s]


BLEU score: 0.042610497365320306, METEOR score: 0.19640161197240455
Starting epoch 14/30, enc lr scheduler: [0.0006079162869547909], dec lr scheduler: [0.0006079162869547909]
(Epoch 13, iter 50/787) Average loss so far: 2.796
(Epoch 13, iter 100/787) Average loss so far: 2.800
(Epoch 13, iter 150/787) Average loss so far: 2.797
(Epoch 13, iter 200/787) Average loss so far: 2.792
(Epoch 13, iter 250/787) Average loss so far: 2.801
(Epoch 13, iter 300/787) Average loss so far: 2.789
(Epoch 13, iter 350/787) Average loss so far: 2.798
(Epoch 13, iter 400/787) Average loss so far: 2.793
(Epoch 13, iter 450/787) Average loss so far: 2.792
(Epoch 13, iter 500/787) Average loss so far: 2.802
(Epoch 13, iter 550/787) Average loss so far: 2.808
(Epoch 13, iter 600/787) Average loss so far: 2.802
(Epoch 13, iter 650/787) Average loss so far: 2.785
(Epoch 13, iter 700/787) Average loss so far: 2.791
(Epoch 13, iter 750/787) Average loss so far: 2.772
Average epoch loss: 2.795
This epoch took 5.94

7it [00:00,  8.31it/s]


validation loss: 3.190805230821882


100%|██████████| 7/7 [00:01<00:00,  3.54it/s]
100%|██████████| 793/793 [00:01<00:00, 774.35it/s]


BLEU score: 0.04117386932250176, METEOR score: 0.1918381125873577
Starting epoch 15/30, enc lr scheduler: [0.0005567415893174886], dec lr scheduler: [0.0005567415893174886]
(Epoch 14, iter 50/787) Average loss so far: 2.771
(Epoch 14, iter 100/787) Average loss so far: 2.772
(Epoch 14, iter 150/787) Average loss so far: 2.768
(Epoch 14, iter 200/787) Average loss so far: 2.766
(Epoch 14, iter 250/787) Average loss so far: 2.776
(Epoch 14, iter 300/787) Average loss so far: 2.773
(Epoch 14, iter 350/787) Average loss so far: 2.768
(Epoch 14, iter 400/787) Average loss so far: 2.789
(Epoch 14, iter 450/787) Average loss so far: 2.763
(Epoch 14, iter 500/787) Average loss so far: 2.781
(Epoch 14, iter 550/787) Average loss so far: 2.784
(Epoch 14, iter 600/787) Average loss so far: 2.787
(Epoch 14, iter 650/787) Average loss so far: 2.773
(Epoch 14, iter 700/787) Average loss so far: 2.782
(Epoch 14, iter 750/787) Average loss so far: 2.792
Average epoch loss: 2.776
This epoch took 5.9348

7it [00:00,  8.31it/s]


validation loss: 3.191071476255144


100%|██████████| 7/7 [00:02<00:00,  2.97it/s]
100%|██████████| 793/793 [00:01<00:00, 619.10it/s]


BLEU score: 0.045987255785282295, METEOR score: 0.19803943527851992
Starting epoch 16/30, enc lr scheduler: [0.0005050000000000002], dec lr scheduler: [0.0005050000000000002]
(Epoch 15, iter 50/787) Average loss so far: 2.756
(Epoch 15, iter 100/787) Average loss so far: 2.755
(Epoch 15, iter 150/787) Average loss so far: 2.756
(Epoch 15, iter 200/787) Average loss so far: 2.764
(Epoch 15, iter 250/787) Average loss so far: 2.748
(Epoch 15, iter 300/787) Average loss so far: 2.737
(Epoch 15, iter 350/787) Average loss so far: 2.749
(Epoch 15, iter 400/787) Average loss so far: 2.763
(Epoch 15, iter 450/787) Average loss so far: 2.753
(Epoch 15, iter 500/787) Average loss so far: 2.776
(Epoch 15, iter 550/787) Average loss so far: 2.766
(Epoch 15, iter 600/787) Average loss so far: 2.768
(Epoch 15, iter 650/787) Average loss so far: 2.773
(Epoch 15, iter 700/787) Average loss so far: 2.767
(Epoch 15, iter 750/787) Average loss so far: 2.761
Average epoch loss: 2.760
This epoch took 5.95

7it [00:00,  8.36it/s]


validation loss: 3.1885389941079274


100%|██████████| 7/7 [00:01<00:00,  3.75it/s]
100%|██████████| 793/793 [00:01<00:00, 763.55it/s]


BLEU score: 0.039472464139260074, METEOR score: 0.1900444808031801
Starting epoch 17/30, enc lr scheduler: [0.0004532584106825117], dec lr scheduler: [0.0004532584106825117]
(Epoch 16, iter 50/787) Average loss so far: 2.728
(Epoch 16, iter 100/787) Average loss so far: 2.737
(Epoch 16, iter 150/787) Average loss so far: 2.737
(Epoch 16, iter 200/787) Average loss so far: 2.738
(Epoch 16, iter 250/787) Average loss so far: 2.746
(Epoch 16, iter 300/787) Average loss so far: 2.736
(Epoch 16, iter 350/787) Average loss so far: 2.753
(Epoch 16, iter 400/787) Average loss so far: 2.752
(Epoch 16, iter 450/787) Average loss so far: 2.745
(Epoch 16, iter 500/787) Average loss so far: 2.738
(Epoch 16, iter 550/787) Average loss so far: 2.760
(Epoch 16, iter 600/787) Average loss so far: 2.751
(Epoch 16, iter 650/787) Average loss so far: 2.750
(Epoch 16, iter 700/787) Average loss so far: 2.748
(Epoch 16, iter 750/787) Average loss so far: 2.753
Average epoch loss: 2.745
This epoch took 6.020

7it [00:00,  8.42it/s]


validation loss: 3.190322296960013


100%|██████████| 7/7 [00:02<00:00,  3.03it/s]
100%|██████████| 793/793 [00:01<00:00, 657.50it/s]


BLEU score: 0.04504238384476892, METEOR score: 0.19752496817224274
Starting epoch 18/30, enc lr scheduler: [0.00040208371304520916], dec lr scheduler: [0.00040208371304520916]
(Epoch 17, iter 50/787) Average loss so far: 2.726
(Epoch 17, iter 100/787) Average loss so far: 2.726
(Epoch 17, iter 150/787) Average loss so far: 2.718
(Epoch 17, iter 200/787) Average loss so far: 2.724
(Epoch 17, iter 250/787) Average loss so far: 2.715
(Epoch 17, iter 300/787) Average loss so far: 2.739
(Epoch 17, iter 350/787) Average loss so far: 2.729
(Epoch 17, iter 400/787) Average loss so far: 2.737
(Epoch 17, iter 450/787) Average loss so far: 2.733
(Epoch 17, iter 500/787) Average loss so far: 2.740
(Epoch 17, iter 550/787) Average loss so far: 2.755
(Epoch 17, iter 600/787) Average loss so far: 2.744
(Epoch 17, iter 650/787) Average loss so far: 2.728
(Epoch 17, iter 700/787) Average loss so far: 2.723
(Epoch 17, iter 750/787) Average loss so far: 2.736
Average epoch loss: 2.732
This epoch took 6.0

7it [00:00,  8.38it/s]


validation loss: 3.1880320140293668


100%|██████████| 7/7 [00:02<00:00,  3.22it/s]
100%|██████████| 793/793 [00:01<00:00, 660.58it/s]


BLEU score: 0.045373700729723125, METEOR score: 0.19635808499055457
Starting epoch 19/30, enc lr scheduler: [0.00035203658778440114], dec lr scheduler: [0.00035203658778440114]
(Epoch 18, iter 50/787) Average loss so far: 2.717
(Epoch 18, iter 100/787) Average loss so far: 2.697
(Epoch 18, iter 150/787) Average loss so far: 2.725
(Epoch 18, iter 200/787) Average loss so far: 2.716
(Epoch 18, iter 250/787) Average loss so far: 2.707
(Epoch 18, iter 300/787) Average loss so far: 2.734
(Epoch 18, iter 350/787) Average loss so far: 2.719
(Epoch 18, iter 400/787) Average loss so far: 2.715
(Epoch 18, iter 450/787) Average loss so far: 2.727
(Epoch 18, iter 500/787) Average loss so far: 2.713
(Epoch 18, iter 550/787) Average loss so far: 2.742
(Epoch 18, iter 600/787) Average loss so far: 2.717
(Epoch 18, iter 650/787) Average loss so far: 2.716
(Epoch 18, iter 700/787) Average loss so far: 2.738
(Epoch 18, iter 750/787) Average loss so far: 2.728
Average epoch loss: 2.721
This epoch took 6.

7it [00:00,  8.31it/s]


validation loss: 3.1870854582105363


100%|██████████| 7/7 [00:02<00:00,  3.22it/s]
100%|██████████| 793/793 [00:01<00:00, 733.39it/s]


BLEU score: 0.04436411201089971, METEOR score: 0.19867751657974117
Starting epoch 20/30, enc lr scheduler: [0.00030366536167747904], dec lr scheduler: [0.00030366536167747904]
(Epoch 19, iter 50/787) Average loss so far: 2.706
(Epoch 19, iter 100/787) Average loss so far: 2.711
(Epoch 19, iter 150/787) Average loss so far: 2.708
(Epoch 19, iter 200/787) Average loss so far: 2.707
(Epoch 19, iter 250/787) Average loss so far: 2.707
(Epoch 19, iter 300/787) Average loss so far: 2.711
(Epoch 19, iter 350/787) Average loss so far: 2.696
(Epoch 19, iter 400/787) Average loss so far: 2.702
(Epoch 19, iter 450/787) Average loss so far: 2.721
(Epoch 19, iter 500/787) Average loss so far: 2.712
(Epoch 19, iter 550/787) Average loss so far: 2.713
(Epoch 19, iter 600/787) Average loss so far: 2.713
(Epoch 19, iter 650/787) Average loss so far: 2.717
(Epoch 19, iter 700/787) Average loss so far: 2.725
(Epoch 19, iter 750/787) Average loss so far: 2.714
Average epoch loss: 2.711
This epoch took 6.0

7it [00:00,  7.83it/s]


validation loss: 3.1880082402910506


100%|██████████| 7/7 [00:02<00:00,  2.81it/s]
100%|██████████| 793/793 [00:01<00:00, 632.48it/s]


BLEU score: 0.0472837279646999, METEOR score: 0.2002107631916008
Starting epoch 21/30, enc lr scheduler: [0.00025750000000000013], dec lr scheduler: [0.00025750000000000013]
(Epoch 20, iter 50/787) Average loss so far: 2.690
(Epoch 20, iter 100/787) Average loss so far: 2.689
(Epoch 20, iter 150/787) Average loss so far: 2.699
(Epoch 20, iter 200/787) Average loss so far: 2.695
(Epoch 20, iter 250/787) Average loss so far: 2.702
(Epoch 20, iter 300/787) Average loss so far: 2.697
(Epoch 20, iter 350/787) Average loss so far: 2.718
(Epoch 20, iter 400/787) Average loss so far: 2.697
(Epoch 20, iter 450/787) Average loss so far: 2.714
(Epoch 20, iter 500/787) Average loss so far: 2.712
(Epoch 20, iter 550/787) Average loss so far: 2.702
(Epoch 20, iter 600/787) Average loss so far: 2.698
(Epoch 20, iter 650/787) Average loss so far: 2.717
(Epoch 20, iter 700/787) Average loss so far: 2.711
(Epoch 20, iter 750/787) Average loss so far: 2.709
Average epoch loss: 2.703
This epoch took 6.040

7it [00:00,  8.25it/s]


validation loss: 3.188002722603934


100%|██████████| 7/7 [00:02<00:00,  2.97it/s]
100%|██████████| 793/793 [00:01<00:00, 700.04it/s]


BLEU score: 0.046293812487382985, METEOR score: 0.19752825008718805
Starting epoch 22/30, enc lr scheduler: [0.00021404630011522585], dec lr scheduler: [0.00021404630011522585]
(Epoch 21, iter 50/787) Average loss so far: 2.694
(Epoch 21, iter 100/787) Average loss so far: 2.706
(Epoch 21, iter 150/787) Average loss so far: 2.679
(Epoch 21, iter 200/787) Average loss so far: 2.691
(Epoch 21, iter 250/787) Average loss so far: 2.688
(Epoch 21, iter 300/787) Average loss so far: 2.695
(Epoch 21, iter 350/787) Average loss so far: 2.691
(Epoch 21, iter 400/787) Average loss so far: 2.712
(Epoch 21, iter 450/787) Average loss so far: 2.682
(Epoch 21, iter 500/787) Average loss so far: 2.702
(Epoch 21, iter 550/787) Average loss so far: 2.687
(Epoch 21, iter 600/787) Average loss so far: 2.695
(Epoch 21, iter 650/787) Average loss so far: 2.695
(Epoch 21, iter 700/787) Average loss so far: 2.698
(Epoch 21, iter 750/787) Average loss so far: 2.708
Average epoch loss: 2.695
This epoch took 6.

7it [00:00,  8.35it/s]


validation loss: 3.190535000392369


100%|██████████| 7/7 [00:02<00:00,  3.29it/s]
100%|██████████| 793/793 [00:01<00:00, 751.69it/s]


BLEU score: 0.04469983055676468, METEOR score: 0.19735370122639811
Starting epoch 23/30, enc lr scheduler: [0.00017378034985236535], dec lr scheduler: [0.00017378034985236535]
(Epoch 22, iter 50/787) Average loss so far: 2.677
(Epoch 22, iter 100/787) Average loss so far: 2.690
(Epoch 22, iter 150/787) Average loss so far: 2.684
(Epoch 22, iter 200/787) Average loss so far: 2.692
(Epoch 22, iter 250/787) Average loss so far: 2.683
(Epoch 22, iter 300/787) Average loss so far: 2.682
(Epoch 22, iter 350/787) Average loss so far: 2.693
(Epoch 22, iter 400/787) Average loss so far: 2.699
(Epoch 22, iter 450/787) Average loss so far: 2.694
(Epoch 22, iter 500/787) Average loss so far: 2.678
(Epoch 22, iter 550/787) Average loss so far: 2.692
(Epoch 22, iter 600/787) Average loss so far: 2.682
(Epoch 22, iter 650/787) Average loss so far: 2.681
(Epoch 22, iter 700/787) Average loss so far: 2.692
(Epoch 22, iter 750/787) Average loss so far: 2.701
Average epoch loss: 2.689
This epoch took 5.9

7it [00:00,  8.23it/s]


validation loss: 3.1892076901027133


100%|██████████| 7/7 [00:02<00:00,  3.18it/s]
100%|██████████| 793/793 [00:01<00:00, 695.99it/s]


BLEU score: 0.044842831330928805, METEOR score: 0.19575345815276526
Starting epoch 24/30, enc lr scheduler: [0.00013714331138868998], dec lr scheduler: [0.00013714331138868998]
(Epoch 23, iter 50/787) Average loss so far: 2.671
(Epoch 23, iter 100/787) Average loss so far: 2.682
(Epoch 23, iter 150/787) Average loss so far: 2.683
(Epoch 23, iter 200/787) Average loss so far: 2.668
(Epoch 23, iter 250/787) Average loss so far: 2.684
(Epoch 23, iter 300/787) Average loss so far: 2.681
(Epoch 23, iter 350/787) Average loss so far: 2.674
(Epoch 23, iter 400/787) Average loss so far: 2.688
(Epoch 23, iter 450/787) Average loss so far: 2.690
(Epoch 23, iter 500/787) Average loss so far: 2.687
(Epoch 23, iter 550/787) Average loss so far: 2.682
(Epoch 23, iter 600/787) Average loss so far: 2.701
(Epoch 23, iter 650/787) Average loss so far: 2.692
(Epoch 23, iter 700/787) Average loss so far: 2.691
(Epoch 23, iter 750/787) Average loss so far: 2.677
Average epoch loss: 2.683
This epoch took 5.

7it [00:00,  8.54it/s]


validation loss: 3.1896793161119734


100%|██████████| 7/7 [00:02<00:00,  3.23it/s]
100%|██████████| 793/793 [00:01<00:00, 697.41it/s]


BLEU score: 0.04446433657061311, METEOR score: 0.19553098270527547
Starting epoch 25/30, enc lr scheduler: [0.00010453658778440108], dec lr scheduler: [0.00010453658778440108]
(Epoch 24, iter 50/787) Average loss so far: 2.673
(Epoch 24, iter 100/787) Average loss so far: 2.677
(Epoch 24, iter 150/787) Average loss so far: 2.681
(Epoch 24, iter 200/787) Average loss so far: 2.669
(Epoch 24, iter 250/787) Average loss so far: 2.675
(Epoch 24, iter 300/787) Average loss so far: 2.675
(Epoch 24, iter 350/787) Average loss so far: 2.677
(Epoch 24, iter 400/787) Average loss so far: 2.678
(Epoch 24, iter 450/787) Average loss so far: 2.683
(Epoch 24, iter 500/787) Average loss so far: 2.677
(Epoch 24, iter 550/787) Average loss so far: 2.681
(Epoch 24, iter 600/787) Average loss so far: 2.674
(Epoch 24, iter 650/787) Average loss so far: 2.688
(Epoch 24, iter 700/787) Average loss so far: 2.680
(Epoch 24, iter 750/787) Average loss so far: 2.682
Average epoch loss: 2.679
This epoch took 5.8

7it [00:00,  8.59it/s]


validation loss: 3.1895811557769775


100%|██████████| 7/7 [00:02<00:00,  3.10it/s]
100%|██████████| 793/793 [00:01<00:00, 672.54it/s]


BLEU score: 0.047831302586897266, METEOR score: 0.1988243943422679
Starting epoch 26/30, enc lr scheduler: [7.631742512670285e-05], dec lr scheduler: [7.631742512670285e-05]
(Epoch 25, iter 50/787) Average loss so far: 2.664
(Epoch 25, iter 100/787) Average loss so far: 2.655
(Epoch 25, iter 150/787) Average loss so far: 2.672
(Epoch 25, iter 200/787) Average loss so far: 2.664
(Epoch 25, iter 250/787) Average loss so far: 2.686
(Epoch 25, iter 300/787) Average loss so far: 2.691
(Epoch 25, iter 350/787) Average loss so far: 2.671
(Epoch 25, iter 400/787) Average loss so far: 2.693
(Epoch 25, iter 450/787) Average loss so far: 2.666
(Epoch 25, iter 500/787) Average loss so far: 2.676
(Epoch 25, iter 550/787) Average loss so far: 2.679
(Epoch 25, iter 600/787) Average loss so far: 2.676
(Epoch 25, iter 650/787) Average loss so far: 2.688
(Epoch 25, iter 700/787) Average loss so far: 2.677
(Epoch 25, iter 750/787) Average loss so far: 2.676
Average epoch loss: 2.676
This epoch took 5.903

7it [00:00,  8.29it/s]


validation loss: 3.191049098968506


100%|██████████| 7/7 [00:02<00:00,  3.08it/s]
100%|██████████| 793/793 [00:01<00:00, 717.71it/s]


BLEU score: 0.045387576782303124, METEOR score: 0.19554948188506097
Starting epoch 27/30, enc lr scheduler: [5.279499846691252e-05], dec lr scheduler: [5.279499846691252e-05]
(Epoch 26, iter 50/787) Average loss so far: 2.678
(Epoch 26, iter 100/787) Average loss so far: 2.663
(Epoch 26, iter 150/787) Average loss so far: 2.672
(Epoch 26, iter 200/787) Average loss so far: 2.686
(Epoch 26, iter 250/787) Average loss so far: 2.678
(Epoch 26, iter 300/787) Average loss so far: 2.668
(Epoch 26, iter 350/787) Average loss so far: 2.687
(Epoch 26, iter 400/787) Average loss so far: 2.667
(Epoch 26, iter 450/787) Average loss so far: 2.661
(Epoch 26, iter 500/787) Average loss so far: 2.681
(Epoch 26, iter 550/787) Average loss so far: 2.689
(Epoch 26, iter 600/787) Average loss so far: 2.655
(Epoch 26, iter 650/787) Average loss so far: 2.680
(Epoch 26, iter 700/787) Average loss so far: 2.657
(Epoch 26, iter 750/787) Average loss so far: 2.667
Average epoch loss: 2.672
This epoch took 5.97

7it [00:00,  8.29it/s]


validation loss: 3.190984317234584


100%|██████████| 7/7 [00:02<00:00,  3.10it/s]
100%|██████████| 793/793 [00:01<00:00, 678.46it/s]


BLEU score: 0.04477507155847262, METEOR score: 0.19652649610987258
Starting epoch 28/30, enc lr scheduler: [3.4227024433899005e-05], dec lr scheduler: [3.4227024433899005e-05]
(Epoch 27, iter 50/787) Average loss so far: 2.688
(Epoch 27, iter 100/787) Average loss so far: 2.668
(Epoch 27, iter 150/787) Average loss so far: 2.669
(Epoch 27, iter 200/787) Average loss so far: 2.668
(Epoch 27, iter 250/787) Average loss so far: 2.677
(Epoch 27, iter 300/787) Average loss so far: 2.678
(Epoch 27, iter 350/787) Average loss so far: 2.672
(Epoch 27, iter 400/787) Average loss so far: 2.675
(Epoch 27, iter 450/787) Average loss so far: 2.645
(Epoch 27, iter 500/787) Average loss so far: 2.663
(Epoch 27, iter 550/787) Average loss so far: 2.667
(Epoch 27, iter 600/787) Average loss so far: 2.666
(Epoch 27, iter 650/787) Average loss so far: 2.684
(Epoch 27, iter 700/787) Average loss so far: 2.673
(Epoch 27, iter 750/787) Average loss so far: 2.664
Average epoch loss: 2.670
This epoch took 5.9

7it [00:00,  8.23it/s]


validation loss: 3.191187177385603


100%|██████████| 7/7 [00:02<00:00,  3.29it/s]
100%|██████████| 793/793 [00:01<00:00, 711.73it/s]


BLEU score: 0.04755797880330766, METEOR score: 0.20054605031876846
Starting epoch 29/30, enc lr scheduler: [2.0816937636766188e-05], dec lr scheduler: [2.0816937636766188e-05]
(Epoch 28, iter 50/787) Average loss so far: 2.658
(Epoch 28, iter 100/787) Average loss so far: 2.680
(Epoch 28, iter 150/787) Average loss so far: 2.687
(Epoch 28, iter 200/787) Average loss so far: 2.667
(Epoch 28, iter 250/787) Average loss so far: 2.670
(Epoch 28, iter 300/787) Average loss so far: 2.676
(Epoch 28, iter 350/787) Average loss so far: 2.678
(Epoch 28, iter 400/787) Average loss so far: 2.680
(Epoch 28, iter 450/787) Average loss so far: 2.667
(Epoch 28, iter 500/787) Average loss so far: 2.670
(Epoch 28, iter 550/787) Average loss so far: 2.669
(Epoch 28, iter 600/787) Average loss so far: 2.635
(Epoch 28, iter 650/787) Average loss so far: 2.672
(Epoch 28, iter 700/787) Average loss so far: 2.669
(Epoch 28, iter 750/787) Average loss so far: 2.660
Average epoch loss: 2.669
This epoch took 5.9

7it [00:00,  8.34it/s]


validation loss: 3.1912351676395962


100%|██████████| 7/7 [00:02<00:00,  3.14it/s]
100%|██████████| 793/793 [00:01<00:00, 689.02it/s]


BLEU score: 0.04576639965453813, METEOR score: 0.19814575935905007
Starting epoch 30/30, enc lr scheduler: [1.2711661792704668e-05], dec lr scheduler: [1.2711661792704668e-05]
(Epoch 29, iter 50/787) Average loss so far: 2.674
(Epoch 29, iter 100/787) Average loss so far: 2.674
(Epoch 29, iter 150/787) Average loss so far: 2.675
(Epoch 29, iter 200/787) Average loss so far: 2.647
(Epoch 29, iter 250/787) Average loss so far: 2.673
(Epoch 29, iter 300/787) Average loss so far: 2.653
(Epoch 29, iter 350/787) Average loss so far: 2.662
(Epoch 29, iter 400/787) Average loss so far: 2.661
(Epoch 29, iter 450/787) Average loss so far: 2.664
(Epoch 29, iter 500/787) Average loss so far: 2.695
(Epoch 29, iter 550/787) Average loss so far: 2.670
(Epoch 29, iter 600/787) Average loss so far: 2.660
(Epoch 29, iter 650/787) Average loss so far: 2.679
(Epoch 29, iter 700/787) Average loss so far: 2.666
(Epoch 29, iter 750/787) Average loss so far: 2.656
Average epoch loss: 2.668
This epoch took 5.9

7it [00:00,  8.36it/s]


validation loss: 3.191310610089983


100%|██████████| 7/7 [00:02<00:00,  3.16it/s]
100%|██████████| 793/793 [00:01<00:00, 697.13it/s]

BLEU score: 0.045667970073548995, METEOR score: 0.19664536921554326





## Encoder-Decoder (Attention)

In [11]:
reset_rng()

In [12]:
embedding_size=300
encoder_attn = EncoderRNN(vocab.n_unique_words, embedding_size=embedding_size, hidden_size=HIDDEN_SIZE, padding_value=vocab.word2index(PAD_WORD)).to(DEVICE)
# in the training script, decoder is always fed a non-end token and thus never needs to generate padding
# also it should never generate "<UNKNOWN>"
# decoder = DecoderRNN(embedding_size=embedding_size,hidden_size=HIDDEN_SIZE, output_size=vocab.n_unique_words-2).to(DEVICE)
decoder_attn = AttnDecoderRNN(embedding_size, hidden_size=HIDDEN_SIZE, output_size=vocab.n_unique_words-1, padding_val=vocab.word2index(PAD_WORD), 
                              dropout=DROPOUT).to(DEVICE)

In [13]:
initial_lr=1e-3
min_lr = 1e-5
n_epochs = 30
batch_size=128
encoder_attn_optimizer = optim.Adam(encoder_attn.parameters(), lr=initial_lr)
decoder_attn_optimizer = optim.Adam(decoder_attn.parameters(), lr=initial_lr)
# enc_attn_scheduler = CosineAnnealingLR(encoder_attn_optimizer, T_max=n_epochs, eta_min=min_lr)
# dec_attn_scheduler = CosineAnnealingLR(decoder_attn_optimizer, T_max=n_epochs, eta_min=min_lr)
enc_attn_scheduler = CosineAnnealingLR(encoder_attn_optimizer, T_max=n_epochs, eta_min=min_lr)
dec_attn_scheduler = CosineAnnealingLR(decoder_attn_optimizer, T_max=n_epochs, eta_min=min_lr)
identifier="attn_adam_without_intermediate_tags_wd0_lr1e-3"

attn_epoch_losses, attn_val_epoch_losses, attn_log = train(
    encoder_attn, decoder_attn, encoder_attn_optimizer, decoder_attn_optimizer, train_ds, 
    n_epochs=n_epochs, vocab=vocab, decoder_mode="attention", batch_size=batch_size, 
    enc_lr_scheduler=enc_attn_scheduler, dec_lr_scheduler=dec_attn_scheduler, 
    dev_ds_val_loss = dev_ds_val_loss, dev_ds_val_met=dev_ds_val_met, identifier=identifier,
    verbose_iter_interval=50)

save_log(identifier, attn_log, encoder_attn_optimizer, decoder_attn_optimizer, 
         enc_attn_scheduler, dec_attn_scheduler)

Starting epoch 1/30, enc lr scheduler: [0.001], dec lr scheduler: [0.001]
(Epoch 0, iter 50/787) Average loss so far: 6.972
(Epoch 0, iter 100/787) Average loss so far: 6.023
(Epoch 0, iter 150/787) Average loss so far: 5.968
(Epoch 0, iter 200/787) Average loss so far: 5.927
(Epoch 0, iter 250/787) Average loss so far: 5.877
(Epoch 0, iter 300/787) Average loss so far: 5.818
(Epoch 0, iter 350/787) Average loss so far: 5.785
(Epoch 0, iter 400/787) Average loss so far: 5.725
(Epoch 0, iter 450/787) Average loss so far: 5.683
(Epoch 0, iter 500/787) Average loss so far: 5.621
(Epoch 0, iter 550/787) Average loss so far: 5.551
(Epoch 0, iter 600/787) Average loss so far: 5.440
(Epoch 0, iter 650/787) Average loss so far: 5.360
(Epoch 0, iter 700/787) Average loss so far: 5.215
(Epoch 0, iter 750/787) Average loss so far: 5.093
Average epoch loss: 5.701
This epoch took 8.194433637460072 mins. Time remaining: 3.0 hrs 57.0 mins.


7it [00:01,  5.45it/s]


validation loss: 4.93167952128819


100%|██████████| 7/7 [00:07<00:00,  1.13s/it]
100%|██████████| 793/793 [00:08<00:00, 95.55it/s] 


BLEU score: 0.002800126985071343, METEOR score: 0.08576382563509898
Starting epoch 2/30, enc lr scheduler: [0.0009972883382072953], dec lr scheduler: [0.0009972883382072953]
(Epoch 1, iter 50/787) Average loss so far: 4.863
(Epoch 1, iter 100/787) Average loss so far: 4.706
(Epoch 1, iter 150/787) Average loss so far: 4.641
(Epoch 1, iter 200/787) Average loss so far: 4.523
(Epoch 1, iter 250/787) Average loss so far: 4.437
(Epoch 1, iter 300/787) Average loss so far: 4.366
(Epoch 1, iter 350/787) Average loss so far: 4.295
(Epoch 1, iter 400/787) Average loss so far: 4.248
(Epoch 1, iter 450/787) Average loss so far: 4.186
(Epoch 1, iter 500/787) Average loss so far: 4.133
(Epoch 1, iter 550/787) Average loss so far: 4.082
(Epoch 1, iter 600/787) Average loss so far: 4.041
(Epoch 1, iter 650/787) Average loss so far: 4.004
(Epoch 1, iter 700/787) Average loss so far: 3.956
(Epoch 1, iter 750/787) Average loss so far: 3.932
Average epoch loss: 4.277
This epoch took 8.073306282361349 mi

7it [00:01,  5.24it/s]


validation loss: 3.926896367754255


100%|██████████| 7/7 [00:07<00:00,  1.05s/it]
100%|██████████| 793/793 [00:07<00:00, 112.04it/s]


BLEU score: 0.005528415216107959, METEOR score: 0.1087509998355728
Starting epoch 3/30, enc lr scheduler: [0.0009891830623632338], dec lr scheduler: [0.0009891830623632338]
(Epoch 2, iter 50/787) Average loss so far: 3.866
(Epoch 2, iter 100/787) Average loss so far: 3.844
(Epoch 2, iter 150/787) Average loss so far: 3.818
(Epoch 2, iter 200/787) Average loss so far: 3.778
(Epoch 2, iter 250/787) Average loss so far: 3.752
(Epoch 2, iter 300/787) Average loss so far: 3.761
(Epoch 2, iter 350/787) Average loss so far: 3.706
(Epoch 2, iter 400/787) Average loss so far: 3.690
(Epoch 2, iter 450/787) Average loss so far: 3.667
(Epoch 2, iter 500/787) Average loss so far: 3.662
(Epoch 2, iter 550/787) Average loss so far: 3.666
(Epoch 2, iter 600/787) Average loss so far: 3.646
(Epoch 2, iter 650/787) Average loss so far: 3.610
(Epoch 2, iter 700/787) Average loss so far: 3.615
(Epoch 2, iter 750/787) Average loss so far: 3.572
Average epoch loss: 3.704
This epoch took 8.144790331522623 min

7it [00:01,  5.53it/s]


validation loss: 3.6217985493796214


100%|██████████| 7/7 [00:04<00:00,  1.51it/s]
100%|██████████| 793/793 [00:03<00:00, 240.20it/s]


BLEU score: 0.015059414822228353, METEOR score: 0.1608885573065793
Starting epoch 4/30, enc lr scheduler: [0.0009757729755661011], dec lr scheduler: [0.0009757729755661011]
(Epoch 3, iter 50/787) Average loss so far: 3.552
(Epoch 3, iter 100/787) Average loss so far: 3.533
(Epoch 3, iter 150/787) Average loss so far: 3.512
(Epoch 3, iter 200/787) Average loss so far: 3.499
(Epoch 3, iter 250/787) Average loss so far: 3.503
(Epoch 3, iter 300/787) Average loss so far: 3.494
(Epoch 3, iter 350/787) Average loss so far: 3.480
(Epoch 3, iter 400/787) Average loss so far: 3.455
(Epoch 3, iter 450/787) Average loss so far: 3.445
(Epoch 3, iter 500/787) Average loss so far: 3.436
(Epoch 3, iter 550/787) Average loss so far: 3.436
(Epoch 3, iter 600/787) Average loss so far: 3.424
(Epoch 3, iter 650/787) Average loss so far: 3.410
(Epoch 3, iter 700/787) Average loss so far: 3.401
(Epoch 3, iter 750/787) Average loss so far: 3.412
Average epoch loss: 3.463
This epoch took 8.090149772167205 min

7it [00:01,  5.57it/s]


validation loss: 3.4686967645372664


100%|██████████| 7/7 [00:03<00:00,  1.87it/s]
100%|██████████| 793/793 [00:02<00:00, 350.55it/s]


BLEU score: 0.02478974721483497, METEOR score: 0.18132902088830255
Starting epoch 5/30, enc lr scheduler: [0.0009572050015330874], dec lr scheduler: [0.0009572050015330874]
(Epoch 4, iter 50/787) Average loss so far: 3.370
(Epoch 4, iter 100/787) Average loss so far: 3.354
(Epoch 4, iter 150/787) Average loss so far: 3.346
(Epoch 4, iter 200/787) Average loss so far: 3.360
(Epoch 4, iter 250/787) Average loss so far: 3.335
(Epoch 4, iter 300/787) Average loss so far: 3.325
(Epoch 4, iter 350/787) Average loss so far: 3.323
(Epoch 4, iter 400/787) Average loss so far: 3.308
(Epoch 4, iter 450/787) Average loss so far: 3.320
(Epoch 4, iter 500/787) Average loss so far: 3.303
(Epoch 4, iter 550/787) Average loss so far: 3.310
(Epoch 4, iter 600/787) Average loss so far: 3.296
(Epoch 4, iter 650/787) Average loss so far: 3.290
(Epoch 4, iter 700/787) Average loss so far: 3.293
(Epoch 4, iter 750/787) Average loss so far: 3.283
Average epoch loss: 3.318
This epoch took 8.103163651625316 min

7it [00:01,  5.50it/s]


validation loss: 3.376786708831787


100%|██████████| 7/7 [00:04<00:00,  1.44it/s]
100%|██████████| 793/793 [00:03<00:00, 216.71it/s]


BLEU score: 0.016797911257019238, METEOR score: 0.17466288946573208
Starting epoch 6/30, enc lr scheduler: [0.0009336825748732973], dec lr scheduler: [0.0009336825748732973]
(Epoch 5, iter 50/787) Average loss so far: 3.233
(Epoch 5, iter 100/787) Average loss so far: 3.250
(Epoch 5, iter 150/787) Average loss so far: 3.227
(Epoch 5, iter 200/787) Average loss so far: 3.236
(Epoch 5, iter 250/787) Average loss so far: 3.215
(Epoch 5, iter 300/787) Average loss so far: 3.229
(Epoch 5, iter 350/787) Average loss so far: 3.227
(Epoch 5, iter 400/787) Average loss so far: 3.229
(Epoch 5, iter 450/787) Average loss so far: 3.212
(Epoch 5, iter 500/787) Average loss so far: 3.217
(Epoch 5, iter 550/787) Average loss so far: 3.221
(Epoch 5, iter 600/787) Average loss so far: 3.205
(Epoch 5, iter 650/787) Average loss so far: 3.199
(Epoch 5, iter 700/787) Average loss so far: 3.194
(Epoch 5, iter 750/787) Average loss so far: 3.184
Average epoch loss: 3.217
This epoch took 8.09989018837611 min

7it [00:01,  5.31it/s]


validation loss: 3.3135062967027937


100%|██████████| 7/7 [00:03<00:00,  1.87it/s]
100%|██████████| 793/793 [00:02<00:00, 322.45it/s]


BLEU score: 0.026052313701649024, METEOR score: 0.1865945117369642
Starting epoch 7/30, enc lr scheduler: [0.0009054634122155991], dec lr scheduler: [0.0009054634122155991]
(Epoch 6, iter 50/787) Average loss so far: 3.171
(Epoch 6, iter 100/787) Average loss so far: 3.163
(Epoch 6, iter 150/787) Average loss so far: 3.161
(Epoch 6, iter 200/787) Average loss so far: 3.149
(Epoch 6, iter 250/787) Average loss so far: 3.141
(Epoch 6, iter 300/787) Average loss so far: 3.143
(Epoch 6, iter 350/787) Average loss so far: 3.139
(Epoch 6, iter 400/787) Average loss so far: 3.145
(Epoch 6, iter 450/787) Average loss so far: 3.148
(Epoch 6, iter 500/787) Average loss so far: 3.133
(Epoch 6, iter 550/787) Average loss so far: 3.126
(Epoch 6, iter 600/787) Average loss so far: 3.146
(Epoch 6, iter 650/787) Average loss so far: 3.121
(Epoch 6, iter 700/787) Average loss so far: 3.131
(Epoch 6, iter 750/787) Average loss so far: 3.117
Average epoch loss: 3.141
This epoch took 8.1756121357282 mins.

7it [00:01,  5.49it/s]


validation loss: 3.277291910988944


100%|██████████| 7/7 [00:03<00:00,  1.86it/s]
100%|██████████| 793/793 [00:02<00:00, 339.68it/s]


BLEU score: 0.0272790639424669, METEOR score: 0.1936208372227744
Starting epoch 8/30, enc lr scheduler: [0.0008728566886113102], dec lr scheduler: [0.0008728566886113102]
(Epoch 7, iter 50/787) Average loss so far: 3.102
(Epoch 7, iter 100/787) Average loss so far: 3.099
(Epoch 7, iter 150/787) Average loss so far: 3.106
(Epoch 7, iter 200/787) Average loss so far: 3.087
(Epoch 7, iter 250/787) Average loss so far: 3.083
(Epoch 7, iter 300/787) Average loss so far: 3.089
(Epoch 7, iter 350/787) Average loss so far: 3.085
(Epoch 7, iter 400/787) Average loss so far: 3.076
(Epoch 7, iter 450/787) Average loss so far: 3.084
(Epoch 7, iter 500/787) Average loss so far: 3.067
(Epoch 7, iter 550/787) Average loss so far: 3.081
(Epoch 7, iter 600/787) Average loss so far: 3.054
(Epoch 7, iter 650/787) Average loss so far: 3.071
(Epoch 7, iter 700/787) Average loss so far: 3.076
(Epoch 7, iter 750/787) Average loss so far: 3.075
Average epoch loss: 3.081
This epoch took 8.089503169059753 mins.

7it [00:01,  5.51it/s]


validation loss: 3.2438402516501292


100%|██████████| 7/7 [00:03<00:00,  2.02it/s]
100%|██████████| 793/793 [00:01<00:00, 431.65it/s]


BLEU score: 0.038146458721083895, METEOR score: 0.1998428588672739
Starting epoch 9/30, enc lr scheduler: [0.0008362196501476349], dec lr scheduler: [0.0008362196501476349]
(Epoch 8, iter 50/787) Average loss so far: 3.045
(Epoch 8, iter 100/787) Average loss so far: 3.044
(Epoch 8, iter 150/787) Average loss so far: 3.046
(Epoch 8, iter 200/787) Average loss so far: 3.050
(Epoch 8, iter 250/787) Average loss so far: 3.025
(Epoch 8, iter 300/787) Average loss so far: 3.033
(Epoch 8, iter 350/787) Average loss so far: 3.038
(Epoch 8, iter 400/787) Average loss so far: 3.021
(Epoch 8, iter 450/787) Average loss so far: 3.028
(Epoch 8, iter 500/787) Average loss so far: 3.025
(Epoch 8, iter 550/787) Average loss so far: 3.041
(Epoch 8, iter 600/787) Average loss so far: 3.017
(Epoch 8, iter 650/787) Average loss so far: 3.017
(Epoch 8, iter 700/787) Average loss so far: 3.039
(Epoch 8, iter 750/787) Average loss so far: 3.022
Average epoch loss: 3.031
This epoch took 8.09464071591695 mins

7it [00:01,  5.53it/s]


validation loss: 3.2192379406520297


100%|██████████| 7/7 [00:03<00:00,  2.01it/s]
100%|██████████| 793/793 [00:01<00:00, 409.40it/s]


BLEU score: 0.036190345080440604, METEOR score: 0.19858925888566942
Starting epoch 10/30, enc lr scheduler: [0.0007959536998847742], dec lr scheduler: [0.0007959536998847742]
(Epoch 9, iter 50/787) Average loss so far: 3.012
(Epoch 9, iter 100/787) Average loss so far: 2.984
(Epoch 9, iter 150/787) Average loss so far: 2.981
(Epoch 9, iter 200/787) Average loss so far: 2.988
(Epoch 9, iter 250/787) Average loss so far: 2.998
(Epoch 9, iter 300/787) Average loss so far: 2.991
(Epoch 9, iter 350/787) Average loss so far: 3.002
(Epoch 9, iter 400/787) Average loss so far: 2.992
(Epoch 9, iter 450/787) Average loss so far: 3.007
(Epoch 9, iter 500/787) Average loss so far: 2.980
(Epoch 9, iter 550/787) Average loss so far: 2.987
(Epoch 9, iter 600/787) Average loss so far: 2.970
(Epoch 9, iter 650/787) Average loss so far: 2.989
(Epoch 9, iter 700/787) Average loss so far: 2.997
(Epoch 9, iter 750/787) Average loss so far: 2.971
Average epoch loss: 2.990
This epoch took 8.119804537296295 m

7it [00:01,  5.47it/s]


validation loss: 3.2011922086988176


100%|██████████| 7/7 [00:03<00:00,  2.08it/s]
100%|██████████| 793/793 [00:01<00:00, 448.70it/s]


BLEU score: 0.03947045191983198, METEOR score: 0.19917691253253444
Starting epoch 11/30, enc lr scheduler: [0.0007525], dec lr scheduler: [0.0007525]
(Epoch 10, iter 50/787) Average loss so far: 2.958
(Epoch 10, iter 100/787) Average loss so far: 2.957
(Epoch 10, iter 150/787) Average loss so far: 2.941
(Epoch 10, iter 200/787) Average loss so far: 2.959
(Epoch 10, iter 250/787) Average loss so far: 2.947
(Epoch 10, iter 300/787) Average loss so far: 2.957
(Epoch 10, iter 350/787) Average loss so far: 2.952
(Epoch 10, iter 400/787) Average loss so far: 2.966
(Epoch 10, iter 450/787) Average loss so far: 2.948
(Epoch 10, iter 500/787) Average loss so far: 2.961
(Epoch 10, iter 550/787) Average loss so far: 2.954
(Epoch 10, iter 600/787) Average loss so far: 2.947
(Epoch 10, iter 650/787) Average loss so far: 2.955
(Epoch 10, iter 700/787) Average loss so far: 2.952
(Epoch 10, iter 750/787) Average loss so far: 2.965
Average epoch loss: 2.955
This epoch took 8.114036385218302 mins. Time 

7it [00:01,  5.57it/s]


validation loss: 3.1880315371922086


100%|██████████| 7/7 [00:03<00:00,  2.01it/s]
100%|██████████| 793/793 [00:01<00:00, 432.08it/s]


BLEU score: 0.03707379676793246, METEOR score: 0.20066902120742106
Starting epoch 12/30, enc lr scheduler: [0.0007063346383225212], dec lr scheduler: [0.0007063346383225212]
(Epoch 11, iter 50/787) Average loss so far: 2.919
(Epoch 11, iter 100/787) Average loss so far: 2.936
(Epoch 11, iter 150/787) Average loss so far: 2.912
(Epoch 11, iter 200/787) Average loss so far: 2.926
(Epoch 11, iter 250/787) Average loss so far: 2.933
(Epoch 11, iter 300/787) Average loss so far: 2.918
(Epoch 11, iter 350/787) Average loss so far: 2.928
(Epoch 11, iter 400/787) Average loss so far: 2.931
(Epoch 11, iter 450/787) Average loss so far: 2.915
(Epoch 11, iter 500/787) Average loss so far: 2.925
(Epoch 11, iter 550/787) Average loss so far: 2.922
(Epoch 11, iter 600/787) Average loss so far: 2.912
(Epoch 11, iter 650/787) Average loss so far: 2.919
(Epoch 11, iter 700/787) Average loss so far: 2.929
(Epoch 11, iter 750/787) Average loss so far: 2.939
Average epoch loss: 2.924
This epoch took 8.101

7it [00:01,  5.54it/s]


validation loss: 3.1738901478903636


100%|██████████| 7/7 [00:03<00:00,  1.97it/s]
100%|██████████| 793/793 [00:01<00:00, 420.49it/s]


BLEU score: 0.03867613565428751, METEOR score: 0.2017267310038669
Starting epoch 13/30, enc lr scheduler: [0.000657963412215599], dec lr scheduler: [0.000657963412215599]
(Epoch 12, iter 50/787) Average loss so far: 2.889
(Epoch 12, iter 100/787) Average loss so far: 2.888
(Epoch 12, iter 150/787) Average loss so far: 2.880
(Epoch 12, iter 200/787) Average loss so far: 2.896
(Epoch 12, iter 250/787) Average loss so far: 2.896
(Epoch 12, iter 300/787) Average loss so far: 2.903
(Epoch 12, iter 350/787) Average loss so far: 2.885
(Epoch 12, iter 400/787) Average loss so far: 2.904
(Epoch 12, iter 450/787) Average loss so far: 2.898
(Epoch 12, iter 500/787) Average loss so far: 2.898
(Epoch 12, iter 550/787) Average loss so far: 2.892
(Epoch 12, iter 600/787) Average loss so far: 2.911
(Epoch 12, iter 650/787) Average loss so far: 2.916
(Epoch 12, iter 700/787) Average loss so far: 2.894
(Epoch 12, iter 750/787) Average loss so far: 2.902
Average epoch loss: 2.896
This epoch took 8.079897

7it [00:01,  5.26it/s]


validation loss: 3.1668037346431186


100%|██████████| 7/7 [00:03<00:00,  1.94it/s]
100%|██████████| 793/793 [00:02<00:00, 390.27it/s]


BLEU score: 0.035783646119029455, METEOR score: 0.19836152461714607
Starting epoch 14/30, enc lr scheduler: [0.0006079162869547909], dec lr scheduler: [0.0006079162869547909]
(Epoch 13, iter 50/787) Average loss so far: 2.883
(Epoch 13, iter 100/787) Average loss so far: 2.870
(Epoch 13, iter 150/787) Average loss so far: 2.875
(Epoch 13, iter 200/787) Average loss so far: 2.870
(Epoch 13, iter 250/787) Average loss so far: 2.899
(Epoch 13, iter 300/787) Average loss so far: 2.864
(Epoch 13, iter 350/787) Average loss so far: 2.870
(Epoch 13, iter 400/787) Average loss so far: 2.875
(Epoch 13, iter 450/787) Average loss so far: 2.873
(Epoch 13, iter 500/787) Average loss so far: 2.873
(Epoch 13, iter 550/787) Average loss so far: 2.860
(Epoch 13, iter 600/787) Average loss so far: 2.863
(Epoch 13, iter 650/787) Average loss so far: 2.887
(Epoch 13, iter 700/787) Average loss so far: 2.874
(Epoch 13, iter 750/787) Average loss so far: 2.867
Average epoch loss: 2.873
This epoch took 8.05

7it [00:01,  5.44it/s]


validation loss: 3.1585819721221924


100%|██████████| 7/7 [00:03<00:00,  1.98it/s]
100%|██████████| 793/793 [00:01<00:00, 447.37it/s]


BLEU score: 0.04004954169047118, METEOR score: 0.20343329001800056
Starting epoch 15/30, enc lr scheduler: [0.0005567415893174886], dec lr scheduler: [0.0005567415893174886]
(Epoch 14, iter 50/787) Average loss so far: 2.854
(Epoch 14, iter 100/787) Average loss so far: 2.848
(Epoch 14, iter 150/787) Average loss so far: 2.840
(Epoch 14, iter 200/787) Average loss so far: 2.846
(Epoch 14, iter 250/787) Average loss so far: 2.842
(Epoch 14, iter 300/787) Average loss so far: 2.861
(Epoch 14, iter 350/787) Average loss so far: 2.832
(Epoch 14, iter 400/787) Average loss so far: 2.864
(Epoch 14, iter 450/787) Average loss so far: 2.858
(Epoch 14, iter 500/787) Average loss so far: 2.865
(Epoch 14, iter 550/787) Average loss so far: 2.849
(Epoch 14, iter 600/787) Average loss so far: 2.854
(Epoch 14, iter 650/787) Average loss so far: 2.849
(Epoch 14, iter 700/787) Average loss so far: 2.852
(Epoch 14, iter 750/787) Average loss so far: 2.866
Average epoch loss: 2.852
This epoch took 8.125

7it [00:01,  5.75it/s]


validation loss: 3.157423734664917


100%|██████████| 7/7 [00:03<00:00,  2.14it/s]
100%|██████████| 793/793 [00:01<00:00, 505.40it/s]


BLEU score: 0.04606436720940942, METEOR score: 0.2026929298648693
Starting epoch 16/30, enc lr scheduler: [0.0005050000000000002], dec lr scheduler: [0.0005050000000000002]
(Epoch 15, iter 50/787) Average loss so far: 2.824
(Epoch 15, iter 100/787) Average loss so far: 2.823
(Epoch 15, iter 150/787) Average loss so far: 2.835
(Epoch 15, iter 200/787) Average loss so far: 2.824
(Epoch 15, iter 250/787) Average loss so far: 2.828
(Epoch 15, iter 300/787) Average loss so far: 2.832
(Epoch 15, iter 350/787) Average loss so far: 2.837
(Epoch 15, iter 400/787) Average loss so far: 2.820
(Epoch 15, iter 450/787) Average loss so far: 2.829
(Epoch 15, iter 500/787) Average loss so far: 2.846
(Epoch 15, iter 550/787) Average loss so far: 2.845
(Epoch 15, iter 600/787) Average loss so far: 2.843
(Epoch 15, iter 650/787) Average loss so far: 2.838
(Epoch 15, iter 700/787) Average loss so far: 2.838
(Epoch 15, iter 750/787) Average loss so far: 2.839
Average epoch loss: 2.833
This epoch took 8.0893

7it [00:01,  5.77it/s]


validation loss: 3.1507908276149204


100%|██████████| 7/7 [00:03<00:00,  2.12it/s]
100%|██████████| 793/793 [00:01<00:00, 486.39it/s]


BLEU score: 0.04347215230255102, METEOR score: 0.20329518542651354
Starting epoch 17/30, enc lr scheduler: [0.0004532584106825117], dec lr scheduler: [0.0004532584106825117]
(Epoch 16, iter 50/787) Average loss so far: 2.810
(Epoch 16, iter 100/787) Average loss so far: 2.807
(Epoch 16, iter 150/787) Average loss so far: 2.817
(Epoch 16, iter 200/787) Average loss so far: 2.800
(Epoch 16, iter 250/787) Average loss so far: 2.812
(Epoch 16, iter 300/787) Average loss so far: 2.820
(Epoch 16, iter 350/787) Average loss so far: 2.824
(Epoch 16, iter 400/787) Average loss so far: 2.830
(Epoch 16, iter 450/787) Average loss so far: 2.818
(Epoch 16, iter 500/787) Average loss so far: 2.829
(Epoch 16, iter 550/787) Average loss so far: 2.807
(Epoch 16, iter 600/787) Average loss so far: 2.812
(Epoch 16, iter 650/787) Average loss so far: 2.816
(Epoch 16, iter 700/787) Average loss so far: 2.799
(Epoch 16, iter 750/787) Average loss so far: 2.830
Average epoch loss: 2.816
This epoch took 8.113

7it [00:01,  5.44it/s]


validation loss: 3.146672351019723


100%|██████████| 7/7 [00:03<00:00,  2.15it/s]
100%|██████████| 793/793 [00:01<00:00, 494.13it/s]


BLEU score: 0.04394619733107649, METEOR score: 0.2021709564532238
Starting epoch 18/30, enc lr scheduler: [0.00040208371304520916], dec lr scheduler: [0.00040208371304520916]
(Epoch 17, iter 50/787) Average loss so far: 2.801
(Epoch 17, iter 100/787) Average loss so far: 2.796
(Epoch 17, iter 150/787) Average loss so far: 2.805
(Epoch 17, iter 200/787) Average loss so far: 2.794
(Epoch 17, iter 250/787) Average loss so far: 2.784
(Epoch 17, iter 300/787) Average loss so far: 2.814
(Epoch 17, iter 350/787) Average loss so far: 2.790
(Epoch 17, iter 400/787) Average loss so far: 2.786
(Epoch 17, iter 450/787) Average loss so far: 2.808
(Epoch 17, iter 500/787) Average loss so far: 2.801
(Epoch 17, iter 550/787) Average loss so far: 2.797
(Epoch 17, iter 600/787) Average loss so far: 2.811
(Epoch 17, iter 650/787) Average loss so far: 2.809
(Epoch 17, iter 700/787) Average loss so far: 2.813
(Epoch 17, iter 750/787) Average loss so far: 2.814
Average epoch loss: 2.802
This epoch took 8.12

7it [00:01,  5.54it/s]


validation loss: 3.144554853439331


100%|██████████| 7/7 [00:03<00:00,  2.18it/s]
100%|██████████| 793/793 [00:01<00:00, 539.82it/s]


BLEU score: 0.050306488687707214, METEOR score: 0.2085596593048818
Starting epoch 19/30, enc lr scheduler: [0.00035203658778440114], dec lr scheduler: [0.00035203658778440114]
(Epoch 18, iter 50/787) Average loss so far: 2.776
(Epoch 18, iter 100/787) Average loss so far: 2.779
(Epoch 18, iter 150/787) Average loss so far: 2.778
(Epoch 18, iter 200/787) Average loss so far: 2.786
(Epoch 18, iter 250/787) Average loss so far: 2.788
(Epoch 18, iter 300/787) Average loss so far: 2.771
(Epoch 18, iter 350/787) Average loss so far: 2.792
(Epoch 18, iter 400/787) Average loss so far: 2.808
(Epoch 18, iter 450/787) Average loss so far: 2.795
(Epoch 18, iter 500/787) Average loss so far: 2.784
(Epoch 18, iter 550/787) Average loss so far: 2.794
(Epoch 18, iter 600/787) Average loss so far: 2.780
(Epoch 18, iter 650/787) Average loss so far: 2.786
(Epoch 18, iter 700/787) Average loss so far: 2.799
(Epoch 18, iter 750/787) Average loss so far: 2.798
Average epoch loss: 2.788
This epoch took 8.1

7it [00:01,  5.48it/s]


validation loss: 3.139380386897496


100%|██████████| 7/7 [00:03<00:00,  2.20it/s]
100%|██████████| 793/793 [00:01<00:00, 524.64it/s]


BLEU score: 0.05073499908229344, METEOR score: 0.20909495784132945
Starting epoch 20/30, enc lr scheduler: [0.00030366536167747904], dec lr scheduler: [0.00030366536167747904]
(Epoch 19, iter 50/787) Average loss so far: 2.776
(Epoch 19, iter 100/787) Average loss so far: 2.776
(Epoch 19, iter 150/787) Average loss so far: 2.768
(Epoch 19, iter 200/787) Average loss so far: 2.764
(Epoch 19, iter 250/787) Average loss so far: 2.763
(Epoch 19, iter 300/787) Average loss so far: 2.781
(Epoch 19, iter 350/787) Average loss so far: 2.774
(Epoch 19, iter 400/787) Average loss so far: 2.779
(Epoch 19, iter 450/787) Average loss so far: 2.769
(Epoch 19, iter 500/787) Average loss so far: 2.786
(Epoch 19, iter 550/787) Average loss so far: 2.772
(Epoch 19, iter 600/787) Average loss so far: 2.777
(Epoch 19, iter 650/787) Average loss so far: 2.790
(Epoch 19, iter 700/787) Average loss so far: 2.769
(Epoch 19, iter 750/787) Average loss so far: 2.797
Average epoch loss: 2.777
This epoch took 8.1

7it [00:01,  5.18it/s]


validation loss: 3.1408518382481168


100%|██████████| 7/7 [00:03<00:00,  2.24it/s]
100%|██████████| 793/793 [00:01<00:00, 574.52it/s]


BLEU score: 0.05228166600737611, METEOR score: 0.20709094525810967
Starting epoch 21/30, enc lr scheduler: [0.00025750000000000013], dec lr scheduler: [0.00025750000000000013]
(Epoch 20, iter 50/787) Average loss so far: 2.769
(Epoch 20, iter 100/787) Average loss so far: 2.758
(Epoch 20, iter 150/787) Average loss so far: 2.767
(Epoch 20, iter 200/787) Average loss so far: 2.772
(Epoch 20, iter 250/787) Average loss so far: 2.776
(Epoch 20, iter 300/787) Average loss so far: 2.772
(Epoch 20, iter 350/787) Average loss so far: 2.749
(Epoch 20, iter 400/787) Average loss so far: 2.774
(Epoch 20, iter 450/787) Average loss so far: 2.760
(Epoch 20, iter 500/787) Average loss so far: 2.759
(Epoch 20, iter 550/787) Average loss so far: 2.761
(Epoch 20, iter 600/787) Average loss so far: 2.778
(Epoch 20, iter 650/787) Average loss so far: 2.762
(Epoch 20, iter 700/787) Average loss so far: 2.777
(Epoch 20, iter 750/787) Average loss so far: 2.765
Average epoch loss: 2.767
This epoch took 8.1

7it [00:01,  5.42it/s]


validation loss: 3.1391042300633023


100%|██████████| 7/7 [00:02<00:00,  2.38it/s]
100%|██████████| 793/793 [00:01<00:00, 609.62it/s]


BLEU score: 0.05379213235939195, METEOR score: 0.2088759308593728
Starting epoch 22/30, enc lr scheduler: [0.00021404630011522585], dec lr scheduler: [0.00021404630011522585]
(Epoch 21, iter 50/787) Average loss so far: 2.759
(Epoch 21, iter 100/787) Average loss so far: 2.741
(Epoch 21, iter 150/787) Average loss so far: 2.746
(Epoch 21, iter 200/787) Average loss so far: 2.776
(Epoch 21, iter 250/787) Average loss so far: 2.746
(Epoch 21, iter 300/787) Average loss so far: 2.761
(Epoch 21, iter 350/787) Average loss so far: 2.762
(Epoch 21, iter 400/787) Average loss so far: 2.763
(Epoch 21, iter 450/787) Average loss so far: 2.759
(Epoch 21, iter 500/787) Average loss so far: 2.751
(Epoch 21, iter 550/787) Average loss so far: 2.757
(Epoch 21, iter 600/787) Average loss so far: 2.773
(Epoch 21, iter 650/787) Average loss so far: 2.748
(Epoch 21, iter 700/787) Average loss so far: 2.769
(Epoch 21, iter 750/787) Average loss so far: 2.743
Average epoch loss: 2.758
This epoch took 8.17

7it [00:01,  5.51it/s]


validation loss: 3.1356898035321916


100%|██████████| 7/7 [00:03<00:00,  2.20it/s]
100%|██████████| 793/793 [00:01<00:00, 558.30it/s]


BLEU score: 0.0521244945980193, METEOR score: 0.20665171059130683
Starting epoch 23/30, enc lr scheduler: [0.00017378034985236535], dec lr scheduler: [0.00017378034985236535]
(Epoch 22, iter 50/787) Average loss so far: 2.743
(Epoch 22, iter 100/787) Average loss so far: 2.736
(Epoch 22, iter 150/787) Average loss so far: 2.749
(Epoch 22, iter 200/787) Average loss so far: 2.755
(Epoch 22, iter 250/787) Average loss so far: 2.747
(Epoch 22, iter 300/787) Average loss so far: 2.757
(Epoch 22, iter 350/787) Average loss so far: 2.755
(Epoch 22, iter 400/787) Average loss so far: 2.744
(Epoch 22, iter 450/787) Average loss so far: 2.758
(Epoch 22, iter 500/787) Average loss so far: 2.746
(Epoch 22, iter 550/787) Average loss so far: 2.760
(Epoch 22, iter 600/787) Average loss so far: 2.766
(Epoch 22, iter 650/787) Average loss so far: 2.742
(Epoch 22, iter 700/787) Average loss so far: 2.747
(Epoch 22, iter 750/787) Average loss so far: 2.749
Average epoch loss: 2.750
This epoch took 7.98

7it [00:01,  5.36it/s]


validation loss: 3.1357997144971574


100%|██████████| 7/7 [00:02<00:00,  2.43it/s]
100%|██████████| 793/793 [00:01<00:00, 591.51it/s]


BLEU score: 0.05242640382134315, METEOR score: 0.20686833056692383
Starting epoch 24/30, enc lr scheduler: [0.00013714331138868998], dec lr scheduler: [0.00013714331138868998]
(Epoch 23, iter 50/787) Average loss so far: 2.736
(Epoch 23, iter 100/787) Average loss so far: 2.744
(Epoch 23, iter 150/787) Average loss so far: 2.725
(Epoch 23, iter 200/787) Average loss so far: 2.741
(Epoch 23, iter 250/787) Average loss so far: 2.745
(Epoch 23, iter 300/787) Average loss so far: 2.746
(Epoch 23, iter 350/787) Average loss so far: 2.734
(Epoch 23, iter 400/787) Average loss so far: 2.747
(Epoch 23, iter 450/787) Average loss so far: 2.741
(Epoch 23, iter 500/787) Average loss so far: 2.734
(Epoch 23, iter 550/787) Average loss so far: 2.739
(Epoch 23, iter 600/787) Average loss so far: 2.744
(Epoch 23, iter 650/787) Average loss so far: 2.765
(Epoch 23, iter 700/787) Average loss so far: 2.758
(Epoch 23, iter 750/787) Average loss so far: 2.746
Average epoch loss: 2.744
This epoch took 8.0

7it [00:01,  5.40it/s]


validation loss: 3.135296242577689


100%|██████████| 7/7 [00:03<00:00,  2.24it/s]
100%|██████████| 793/793 [00:01<00:00, 603.28it/s]


BLEU score: 0.051097938950738914, METEOR score: 0.20545257704474157
Starting epoch 25/30, enc lr scheduler: [0.00010453658778440108], dec lr scheduler: [0.00010453658778440108]
(Epoch 24, iter 50/787) Average loss so far: 2.748
(Epoch 24, iter 100/787) Average loss so far: 2.726
(Epoch 24, iter 150/787) Average loss so far: 2.731
(Epoch 24, iter 200/787) Average loss so far: 2.757
(Epoch 24, iter 250/787) Average loss so far: 2.718
(Epoch 24, iter 300/787) Average loss so far: 2.734
(Epoch 24, iter 350/787) Average loss so far: 2.749
(Epoch 24, iter 400/787) Average loss so far: 2.745
(Epoch 24, iter 450/787) Average loss so far: 2.738
(Epoch 24, iter 500/787) Average loss so far: 2.742
(Epoch 24, iter 550/787) Average loss so far: 2.754
(Epoch 24, iter 600/787) Average loss so far: 2.740
(Epoch 24, iter 650/787) Average loss so far: 2.746
(Epoch 24, iter 700/787) Average loss so far: 2.731
(Epoch 24, iter 750/787) Average loss so far: 2.722
Average epoch loss: 2.738
This epoch took 8.

7it [00:01,  5.42it/s]


validation loss: 3.1331892013549805


100%|██████████| 7/7 [00:02<00:00,  2.39it/s]
100%|██████████| 793/793 [00:01<00:00, 597.43it/s]


BLEU score: 0.05338881169928626, METEOR score: 0.20659685260841154
Starting epoch 26/30, enc lr scheduler: [7.631742512670285e-05], dec lr scheduler: [7.631742512670285e-05]
(Epoch 25, iter 50/787) Average loss so far: 2.725
(Epoch 25, iter 100/787) Average loss so far: 2.720
(Epoch 25, iter 150/787) Average loss so far: 2.750
(Epoch 25, iter 200/787) Average loss so far: 2.731
(Epoch 25, iter 250/787) Average loss so far: 2.725
(Epoch 25, iter 300/787) Average loss so far: 2.717
(Epoch 25, iter 350/787) Average loss so far: 2.742
(Epoch 25, iter 400/787) Average loss so far: 2.735
(Epoch 25, iter 450/787) Average loss so far: 2.736
(Epoch 25, iter 500/787) Average loss so far: 2.735
(Epoch 25, iter 550/787) Average loss so far: 2.738
(Epoch 25, iter 600/787) Average loss so far: 2.737
(Epoch 25, iter 650/787) Average loss so far: 2.732
(Epoch 25, iter 700/787) Average loss so far: 2.749
(Epoch 25, iter 750/787) Average loss so far: 2.738
Average epoch loss: 2.734
This epoch took 8.112

7it [00:01,  5.46it/s]


validation loss: 3.1343512194497243


100%|██████████| 7/7 [00:03<00:00,  2.24it/s]
100%|██████████| 793/793 [00:01<00:00, 591.37it/s]


BLEU score: 0.05337146516245549, METEOR score: 0.20767647224625604
Starting epoch 27/30, enc lr scheduler: [5.279499846691252e-05], dec lr scheduler: [5.279499846691252e-05]
(Epoch 26, iter 50/787) Average loss so far: 2.724
(Epoch 26, iter 100/787) Average loss so far: 2.720
(Epoch 26, iter 150/787) Average loss so far: 2.706
(Epoch 26, iter 200/787) Average loss so far: 2.737
(Epoch 26, iter 250/787) Average loss so far: 2.737
(Epoch 26, iter 300/787) Average loss so far: 2.726
(Epoch 26, iter 350/787) Average loss so far: 2.739
(Epoch 26, iter 400/787) Average loss so far: 2.741
(Epoch 26, iter 450/787) Average loss so far: 2.735
(Epoch 26, iter 500/787) Average loss so far: 2.735
(Epoch 26, iter 550/787) Average loss so far: 2.734
(Epoch 26, iter 600/787) Average loss so far: 2.735
(Epoch 26, iter 650/787) Average loss so far: 2.728
(Epoch 26, iter 700/787) Average loss so far: 2.725
(Epoch 26, iter 750/787) Average loss so far: 2.726
Average epoch loss: 2.730
This epoch took 8.101

7it [00:01,  5.39it/s]


validation loss: 3.1328334467751637


100%|██████████| 7/7 [00:03<00:00,  2.24it/s]
100%|██████████| 793/793 [00:01<00:00, 581.81it/s]


BLEU score: 0.05541456263767681, METEOR score: 0.21179274372666657
Starting epoch 28/30, enc lr scheduler: [3.4227024433899005e-05], dec lr scheduler: [3.4227024433899005e-05]
(Epoch 27, iter 50/787) Average loss so far: 2.747
(Epoch 27, iter 100/787) Average loss so far: 2.725
(Epoch 27, iter 150/787) Average loss so far: 2.725
(Epoch 27, iter 200/787) Average loss so far: 2.721
(Epoch 27, iter 250/787) Average loss so far: 2.731
(Epoch 27, iter 300/787) Average loss so far: 2.723
(Epoch 27, iter 350/787) Average loss so far: 2.726
(Epoch 27, iter 400/787) Average loss so far: 2.718
(Epoch 27, iter 450/787) Average loss so far: 2.734
(Epoch 27, iter 500/787) Average loss so far: 2.741
(Epoch 27, iter 550/787) Average loss so far: 2.731
(Epoch 27, iter 600/787) Average loss so far: 2.729
(Epoch 27, iter 650/787) Average loss so far: 2.714
(Epoch 27, iter 700/787) Average loss so far: 2.731
(Epoch 27, iter 750/787) Average loss so far: 2.726
Average epoch loss: 2.728
This epoch took 8.0

7it [00:01,  5.59it/s]


validation loss: 3.1330018043518066


100%|██████████| 7/7 [00:02<00:00,  2.40it/s]
100%|██████████| 793/793 [00:01<00:00, 605.09it/s]


BLEU score: 0.05391199400984629, METEOR score: 0.20699247679832458
Starting epoch 29/30, enc lr scheduler: [2.0816937636766188e-05], dec lr scheduler: [2.0816937636766188e-05]
(Epoch 28, iter 50/787) Average loss so far: 2.733
(Epoch 28, iter 100/787) Average loss so far: 2.710
(Epoch 28, iter 150/787) Average loss so far: 2.729
(Epoch 28, iter 200/787) Average loss so far: 2.698
(Epoch 28, iter 250/787) Average loss so far: 2.733
(Epoch 28, iter 300/787) Average loss so far: 2.728
(Epoch 28, iter 350/787) Average loss so far: 2.734
(Epoch 28, iter 400/787) Average loss so far: 2.728
(Epoch 28, iter 450/787) Average loss so far: 2.732
(Epoch 28, iter 500/787) Average loss so far: 2.728
(Epoch 28, iter 550/787) Average loss so far: 2.732
(Epoch 28, iter 600/787) Average loss so far: 2.721
(Epoch 28, iter 650/787) Average loss so far: 2.731
(Epoch 28, iter 700/787) Average loss so far: 2.728
(Epoch 28, iter 750/787) Average loss so far: 2.726
Average epoch loss: 2.726
This epoch took 7.9

7it [00:01,  5.47it/s]


validation loss: 3.132430451256888


100%|██████████| 7/7 [00:03<00:00,  2.30it/s]
100%|██████████| 793/793 [00:01<00:00, 598.63it/s]


BLEU score: 0.05513984026373013, METEOR score: 0.2094436932678814
Starting epoch 30/30, enc lr scheduler: [1.2711661792704668e-05], dec lr scheduler: [1.2711661792704668e-05]
(Epoch 29, iter 50/787) Average loss so far: 2.732
(Epoch 29, iter 100/787) Average loss so far: 2.722
(Epoch 29, iter 150/787) Average loss so far: 2.715
(Epoch 29, iter 200/787) Average loss so far: 2.739
(Epoch 29, iter 250/787) Average loss so far: 2.727
(Epoch 29, iter 300/787) Average loss so far: 2.729
(Epoch 29, iter 350/787) Average loss so far: 2.730
(Epoch 29, iter 400/787) Average loss so far: 2.714
(Epoch 29, iter 450/787) Average loss so far: 2.725
(Epoch 29, iter 500/787) Average loss so far: 2.711
(Epoch 29, iter 550/787) Average loss so far: 2.733
(Epoch 29, iter 600/787) Average loss so far: 2.737
(Epoch 29, iter 650/787) Average loss so far: 2.723
(Epoch 29, iter 700/787) Average loss so far: 2.724
(Epoch 29, iter 750/787) Average loss so far: 2.724
Average epoch loss: 2.725
This epoch took 7.99

7it [00:01,  5.51it/s]


validation loss: 3.1326404299054826


100%|██████████| 7/7 [00:02<00:00,  2.39it/s]
100%|██████████| 793/793 [00:01<00:00, 595.70it/s]

BLEU score: 0.055078340546709854, METEOR score: 0.20826842611097787





In [14]:
save_model(encoder_attn, decoder_attn, f"{identifier}_last")

## Encoder-Decoder (Extension: pretrained embeddings)

Please download [pretrained embeddings](https://huggingface.co/stanfordnlp/glove/resolve/main/glove.840B.300d.zip) and extract the contents to the project folder.

In [14]:
reset_rng()

In [13]:
pretrained_embedding_dict = create_pretrained_embedding_dict("./glove.840B.300d.txt")

In [14]:
embedding_size=300
encoder_pretrained_embed = EncoderRNN(
    input_size=vocab.n_unique_words, embedding_size=embedding_size, hidden_size=HIDDEN_SIZE, 
    padding_value=vocab.word2index(PAD_WORD), pretrained_embedding_dict=pretrained_embedding_dict, 
    vocab=vocab).to(DEVICE)
# in the training script, decoder is always fed a non-end token and thus never needs to generate padding
# also it should never generate "<UNKNOWN>"
decoder_pretrained_embed = AttnDecoderRNN(
    embedding_size=embedding_size,hidden_size=HIDDEN_SIZE, output_size=vocab.n_unique_words-1,
    padding_val=vocab.word2index(PAD_WORD), dropout=DROPOUT, global_max_ing_len=MAX_INGR_LEN,
    pretrained_embedding_dict=pretrained_embedding_dict, vocab=vocab).to(DEVICE)

100%|██████████| 44315/44315 [00:00<00:00, 371565.78it/s]


29462/44315 (0.665%) words have pretrained embeddings


100%|██████████| 44314/44314 [00:00<00:00, 357392.06it/s]


29462/44314 (0.665%) words have pretrained embeddings


In [18]:
initial_lr=1e-3
min_lr = 1e-5
n_epochs = 30
batch_size=128
encoder_pretrained_embed_optimizer = optim.Adam(encoder_pretrained_embed.parameters(), lr=initial_lr)
decoder_pretrained_embed_optimizer = optim.Adam(decoder_pretrained_embed.parameters(), lr=initial_lr)
# enc_scheduler = CosineAnnealingLR(encoder_pretrained_embed_optimizer, T_max=n_epochs, eta_min=min_lr)
# dec_scheduler = CosineAnnealingLR(decoder_pretrained_embed_optimizer, T_max=n_epochs, eta_min=min_lr)
enc_pretrained_embed_scheduler = CosineAnnealingLR(encoder_pretrained_embed_optimizer, T_max=n_epochs, eta_min=min_lr)
dec_pretrained_embed_scheduler = CosineAnnealingLR(decoder_pretrained_embed_optimizer, T_max=n_epochs, eta_min=min_lr)
identifier="pretrained_emb_attn_adam_without_intermediate_tags_wd0_lr1e-3"

pre_epoch_losses, pre_val_epoch_losses, pre_log = train(
    encoder_pretrained_embed, decoder_pretrained_embed, encoder_pretrained_embed_optimizer, decoder_pretrained_embed_optimizer, 
    train_ds, n_epochs=n_epochs, vocab=vocab, decoder_mode="attention", batch_size=batch_size, 
    enc_lr_scheduler=enc_pretrained_embed_scheduler, dec_lr_scheduler=dec_pretrained_embed_scheduler, 
    dev_ds_val_loss = dev_ds_val_loss, dev_ds_val_met=dev_ds_val_met, identifier=identifier,
    verbose_iter_interval=50)

save_log(identifier, pre_log, encoder_pretrained_embed_optimizer, decoder_pretrained_embed_optimizer, 
         enc_pretrained_embed_scheduler, dec_pretrained_embed_scheduler)

Starting epoch 1/30, enc lr scheduler: [0.001], dec lr scheduler: [0.001]
(Epoch 0, iter 50/787) Average loss so far: 6.968
(Epoch 0, iter 100/787) Average loss so far: 6.040
(Epoch 0, iter 150/787) Average loss so far: 5.995
(Epoch 0, iter 200/787) Average loss so far: 5.955
(Epoch 0, iter 250/787) Average loss so far: 5.910
(Epoch 0, iter 300/787) Average loss so far: 5.859
(Epoch 0, iter 350/787) Average loss so far: 5.839
(Epoch 0, iter 400/787) Average loss so far: 5.759
(Epoch 0, iter 450/787) Average loss so far: 5.627
(Epoch 0, iter 500/787) Average loss so far: 5.404
(Epoch 0, iter 550/787) Average loss so far: 5.209
(Epoch 0, iter 600/787) Average loss so far: 5.001
(Epoch 0, iter 650/787) Average loss so far: 4.851
(Epoch 0, iter 700/787) Average loss so far: 4.692
(Epoch 0, iter 750/787) Average loss so far: 4.582
Average epoch loss: 5.528
This epoch took 8.08877116839091 mins. Time remaining: 3.0 hrs 54.0 mins.


7it [00:01,  5.60it/s]


validation loss: 4.4685172353472025


100%|██████████| 7/7 [00:07<00:00,  1.12s/it]
100%|██████████| 793/793 [00:09<00:00, 87.13it/s] 


BLEU score: 0.0031624657072353694, METEOR score: 0.0796926100408187
Starting epoch 2/30, enc lr scheduler: [0.0009972883382072953], dec lr scheduler: [0.0009972883382072953]
(Epoch 1, iter 50/787) Average loss so far: 4.410
(Epoch 1, iter 100/787) Average loss so far: 4.295
(Epoch 1, iter 150/787) Average loss so far: 4.269
(Epoch 1, iter 200/787) Average loss so far: 4.187
(Epoch 1, iter 250/787) Average loss so far: 4.128
(Epoch 1, iter 300/787) Average loss so far: 4.080
(Epoch 1, iter 350/787) Average loss so far: 4.027
(Epoch 1, iter 400/787) Average loss so far: 3.997
(Epoch 1, iter 450/787) Average loss so far: 3.949
(Epoch 1, iter 500/787) Average loss so far: 3.910
(Epoch 1, iter 550/787) Average loss so far: 3.871
(Epoch 1, iter 600/787) Average loss so far: 3.844
(Epoch 1, iter 650/787) Average loss so far: 3.814
(Epoch 1, iter 700/787) Average loss so far: 3.778
(Epoch 1, iter 750/787) Average loss so far: 3.759
Average epoch loss: 4.009
This epoch took 8.075345408916473 mi

7it [00:01,  5.45it/s]


validation loss: 3.7725216320582797


100%|██████████| 7/7 [00:06<00:00,  1.10it/s]
100%|██████████| 793/793 [00:05<00:00, 144.78it/s]


BLEU score: 0.008172939091770137, METEOR score: 0.1373603160673997
Starting epoch 3/30, enc lr scheduler: [0.0009891830623632338], dec lr scheduler: [0.0009891830623632338]
(Epoch 2, iter 50/787) Average loss so far: 3.698
(Epoch 2, iter 100/787) Average loss so far: 3.680
(Epoch 2, iter 150/787) Average loss so far: 3.658
(Epoch 2, iter 200/787) Average loss so far: 3.622
(Epoch 2, iter 250/787) Average loss so far: 3.602
(Epoch 2, iter 300/787) Average loss so far: 3.609
(Epoch 2, iter 350/787) Average loss so far: 3.561
(Epoch 2, iter 400/787) Average loss so far: 3.547
(Epoch 2, iter 450/787) Average loss so far: 3.528
(Epoch 2, iter 500/787) Average loss so far: 3.528
(Epoch 2, iter 550/787) Average loss so far: 3.529
(Epoch 2, iter 600/787) Average loss so far: 3.514
(Epoch 2, iter 650/787) Average loss so far: 3.481
(Epoch 2, iter 700/787) Average loss so far: 3.483
(Epoch 2, iter 750/787) Average loss so far: 3.445
Average epoch loss: 3.560
This epoch took 8.12299534479777 mins

7it [00:01,  5.61it/s]


validation loss: 3.5144453048706055


100%|██████████| 7/7 [00:04<00:00,  1.57it/s]
100%|██████████| 793/793 [00:03<00:00, 259.34it/s]


BLEU score: 0.015654428117596996, METEOR score: 0.15914722051182856
Starting epoch 4/30, enc lr scheduler: [0.0009757729755661011], dec lr scheduler: [0.0009757729755661011]
(Epoch 3, iter 50/787) Average loss so far: 3.431
(Epoch 3, iter 100/787) Average loss so far: 3.411
(Epoch 3, iter 150/787) Average loss so far: 3.391
(Epoch 3, iter 200/787) Average loss so far: 3.380
(Epoch 3, iter 250/787) Average loss so far: 3.386
(Epoch 3, iter 300/787) Average loss so far: 3.376
(Epoch 3, iter 350/787) Average loss so far: 3.366
(Epoch 3, iter 400/787) Average loss so far: 3.343
(Epoch 3, iter 450/787) Average loss so far: 3.333
(Epoch 3, iter 500/787) Average loss so far: 3.326
(Epoch 3, iter 550/787) Average loss so far: 3.330
(Epoch 3, iter 600/787) Average loss so far: 3.317
(Epoch 3, iter 650/787) Average loss so far: 3.302
(Epoch 3, iter 700/787) Average loss so far: 3.296
(Epoch 3, iter 750/787) Average loss so far: 3.307
Average epoch loss: 3.351
This epoch took 8.083766424655915 mi

7it [00:01,  5.69it/s]


validation loss: 3.3842404569898332


100%|██████████| 7/7 [00:03<00:00,  1.81it/s]
100%|██████████| 793/793 [00:02<00:00, 321.84it/s]


BLEU score: 0.021776681776078183, METEOR score: 0.1731301206546595
Starting epoch 5/30, enc lr scheduler: [0.0009572050015330874], dec lr scheduler: [0.0009572050015330874]
(Epoch 4, iter 50/787) Average loss so far: 3.269
(Epoch 4, iter 100/787) Average loss so far: 3.256
(Epoch 4, iter 150/787) Average loss so far: 3.247
(Epoch 4, iter 200/787) Average loss so far: 3.260
(Epoch 4, iter 250/787) Average loss so far: 3.232
(Epoch 4, iter 300/787) Average loss so far: 3.226
(Epoch 4, iter 350/787) Average loss so far: 3.227
(Epoch 4, iter 400/787) Average loss so far: 3.210
(Epoch 4, iter 450/787) Average loss so far: 3.223
(Epoch 4, iter 500/787) Average loss so far: 3.208
(Epoch 4, iter 550/787) Average loss so far: 3.213
(Epoch 4, iter 600/787) Average loss so far: 3.201
(Epoch 4, iter 650/787) Average loss so far: 3.194
(Epoch 4, iter 700/787) Average loss so far: 3.198
(Epoch 4, iter 750/787) Average loss so far: 3.187
Average epoch loss: 3.221
This epoch took 8.08472265402476 mins

7it [00:01,  5.71it/s]


validation loss: 3.299443074635097


100%|██████████| 7/7 [00:03<00:00,  1.91it/s]
100%|██████████| 793/793 [00:02<00:00, 371.62it/s]


BLEU score: 0.029468622995135987, METEOR score: 0.18938801124897647
Starting epoch 6/30, enc lr scheduler: [0.0009336825748732973], dec lr scheduler: [0.0009336825748732973]
(Epoch 5, iter 50/787) Average loss so far: 3.146
(Epoch 5, iter 100/787) Average loss so far: 3.159
(Epoch 5, iter 150/787) Average loss so far: 3.138
(Epoch 5, iter 200/787) Average loss so far: 3.145
(Epoch 5, iter 250/787) Average loss so far: 3.128
(Epoch 5, iter 300/787) Average loss so far: 3.140
(Epoch 5, iter 350/787) Average loss so far: 3.138
(Epoch 5, iter 400/787) Average loss so far: 3.138
(Epoch 5, iter 450/787) Average loss so far: 3.124
(Epoch 5, iter 500/787) Average loss so far: 3.125
(Epoch 5, iter 550/787) Average loss so far: 3.129
(Epoch 5, iter 600/787) Average loss so far: 3.115
(Epoch 5, iter 650/787) Average loss so far: 3.108
(Epoch 5, iter 700/787) Average loss so far: 3.107
(Epoch 5, iter 750/787) Average loss so far: 3.096
Average epoch loss: 3.128
This epoch took 8.069037703673045 mi

7it [00:01,  5.64it/s]


validation loss: 3.240325791495187


100%|██████████| 7/7 [00:03<00:00,  1.98it/s]
100%|██████████| 793/793 [00:01<00:00, 407.38it/s]


BLEU score: 0.03311247830068684, METEOR score: 0.1885298967343268
Starting epoch 7/30, enc lr scheduler: [0.0009054634122155991], dec lr scheduler: [0.0009054634122155991]
(Epoch 6, iter 50/787) Average loss so far: 3.083
(Epoch 6, iter 100/787) Average loss so far: 3.079
(Epoch 6, iter 150/787) Average loss so far: 3.076
(Epoch 6, iter 200/787) Average loss so far: 3.064
(Epoch 6, iter 250/787) Average loss so far: 3.057
(Epoch 6, iter 300/787) Average loss so far: 3.059
(Epoch 6, iter 350/787) Average loss so far: 3.053
(Epoch 6, iter 400/787) Average loss so far: 3.059
(Epoch 6, iter 450/787) Average loss so far: 3.061
(Epoch 6, iter 500/787) Average loss so far: 3.048
(Epoch 6, iter 550/787) Average loss so far: 3.040
(Epoch 6, iter 600/787) Average loss so far: 3.061
(Epoch 6, iter 650/787) Average loss so far: 3.037
(Epoch 6, iter 700/787) Average loss so far: 3.044
(Epoch 6, iter 750/787) Average loss so far: 3.034
Average epoch loss: 3.056
This epoch took 8.02225666443507 mins.

7it [00:01,  5.65it/s]


validation loss: 3.198171922138759


100%|██████████| 7/7 [00:03<00:00,  1.87it/s]
100%|██████████| 793/793 [00:02<00:00, 363.29it/s]


BLEU score: 0.029020290042606374, METEOR score: 0.18989639255673424
Starting epoch 8/30, enc lr scheduler: [0.0008728566886113102], dec lr scheduler: [0.0008728566886113102]
(Epoch 7, iter 50/787) Average loss so far: 3.021
(Epoch 7, iter 100/787) Average loss so far: 3.015
(Epoch 7, iter 150/787) Average loss so far: 3.023
(Epoch 7, iter 200/787) Average loss so far: 3.006
(Epoch 7, iter 250/787) Average loss so far: 3.001
(Epoch 7, iter 300/787) Average loss so far: 3.006
(Epoch 7, iter 350/787) Average loss so far: 3.002
(Epoch 7, iter 400/787) Average loss so far: 2.992
(Epoch 7, iter 450/787) Average loss so far: 3.000
(Epoch 7, iter 500/787) Average loss so far: 2.984
(Epoch 7, iter 550/787) Average loss so far: 2.998
(Epoch 7, iter 600/787) Average loss so far: 2.974
(Epoch 7, iter 650/787) Average loss so far: 2.990
(Epoch 7, iter 700/787) Average loss so far: 2.993
(Epoch 7, iter 750/787) Average loss so far: 2.989
Average epoch loss: 2.999
This epoch took 8.018091865380605 mi

7it [00:01,  5.54it/s]


validation loss: 3.1683931010110036


100%|██████████| 7/7 [00:03<00:00,  2.00it/s]
100%|██████████| 793/793 [00:01<00:00, 458.38it/s]


BLEU score: 0.040572019424333215, METEOR score: 0.20380746462583008
Starting epoch 9/30, enc lr scheduler: [0.0008362196501476349], dec lr scheduler: [0.0008362196501476349]
(Epoch 8, iter 50/787) Average loss so far: 2.965
(Epoch 8, iter 100/787) Average loss so far: 2.963
(Epoch 8, iter 150/787) Average loss so far: 2.964
(Epoch 8, iter 200/787) Average loss so far: 2.969
(Epoch 8, iter 250/787) Average loss so far: 2.949
(Epoch 8, iter 300/787) Average loss so far: 2.951
(Epoch 8, iter 350/787) Average loss so far: 2.958
(Epoch 8, iter 400/787) Average loss so far: 2.944
(Epoch 8, iter 450/787) Average loss so far: 2.949
(Epoch 8, iter 500/787) Average loss so far: 2.945
(Epoch 8, iter 550/787) Average loss so far: 2.958
(Epoch 8, iter 600/787) Average loss so far: 2.933
(Epoch 8, iter 650/787) Average loss so far: 2.937
(Epoch 8, iter 700/787) Average loss so far: 2.957
(Epoch 8, iter 750/787) Average loss so far: 2.939
Average epoch loss: 2.951
This epoch took 8.068367584546406 mi

7it [00:01,  5.66it/s]


validation loss: 3.145275660923549


100%|██████████| 7/7 [00:03<00:00,  2.00it/s]
100%|██████████| 793/793 [00:01<00:00, 420.62it/s]


BLEU score: 0.035724116920881344, METEOR score: 0.1954857121917429
Starting epoch 10/30, enc lr scheduler: [0.0007959536998847742], dec lr scheduler: [0.0007959536998847742]
(Epoch 9, iter 50/787) Average loss so far: 2.933
(Epoch 9, iter 100/787) Average loss so far: 2.911
(Epoch 9, iter 150/787) Average loss so far: 2.903
(Epoch 9, iter 200/787) Average loss so far: 2.910
(Epoch 9, iter 250/787) Average loss so far: 2.920
(Epoch 9, iter 300/787) Average loss so far: 2.914
(Epoch 9, iter 350/787) Average loss so far: 2.922
(Epoch 9, iter 400/787) Average loss so far: 2.912
(Epoch 9, iter 450/787) Average loss so far: 2.927
(Epoch 9, iter 500/787) Average loss so far: 2.903
(Epoch 9, iter 550/787) Average loss so far: 2.909
(Epoch 9, iter 600/787) Average loss so far: 2.892
(Epoch 9, iter 650/787) Average loss so far: 2.910
(Epoch 9, iter 700/787) Average loss so far: 2.914
(Epoch 9, iter 750/787) Average loss so far: 2.891
Average epoch loss: 2.911
This epoch took 8.082320443789165 mi

7it [00:01,  5.57it/s]


validation loss: 3.127824068069458


100%|██████████| 7/7 [00:03<00:00,  2.11it/s]
100%|██████████| 793/793 [00:01<00:00, 480.40it/s]


BLEU score: 0.039821027333439805, METEOR score: 0.1971742265366046
Starting epoch 11/30, enc lr scheduler: [0.0007525], dec lr scheduler: [0.0007525]
(Epoch 10, iter 50/787) Average loss so far: 2.885
(Epoch 10, iter 100/787) Average loss so far: 2.880
(Epoch 10, iter 150/787) Average loss so far: 2.865
(Epoch 10, iter 200/787) Average loss so far: 2.880
(Epoch 10, iter 250/787) Average loss so far: 2.870
(Epoch 10, iter 300/787) Average loss so far: 2.883
(Epoch 10, iter 350/787) Average loss so far: 2.875
(Epoch 10, iter 400/787) Average loss so far: 2.885
(Epoch 10, iter 450/787) Average loss so far: 2.871
(Epoch 10, iter 500/787) Average loss so far: 2.883
(Epoch 10, iter 550/787) Average loss so far: 2.876
(Epoch 10, iter 600/787) Average loss so far: 2.870
(Epoch 10, iter 650/787) Average loss so far: 2.877
(Epoch 10, iter 700/787) Average loss so far: 2.873
(Epoch 10, iter 750/787) Average loss so far: 2.885
Average epoch loss: 2.877
This epoch took 8.075558908780415 mins. Time 

7it [00:01,  5.60it/s]


validation loss: 3.111717598778861


100%|██████████| 7/7 [00:03<00:00,  2.07it/s]
100%|██████████| 793/793 [00:01<00:00, 485.53it/s]


BLEU score: 0.04188177528685199, METEOR score: 0.20008110559374
Starting epoch 12/30, enc lr scheduler: [0.0007063346383225212], dec lr scheduler: [0.0007063346383225212]
(Epoch 11, iter 50/787) Average loss so far: 2.845
(Epoch 11, iter 100/787) Average loss so far: 2.858
(Epoch 11, iter 150/787) Average loss so far: 2.841
(Epoch 11, iter 200/787) Average loss so far: 2.848
(Epoch 11, iter 250/787) Average loss so far: 2.858
(Epoch 11, iter 300/787) Average loss so far: 2.844
(Epoch 11, iter 350/787) Average loss so far: 2.851
(Epoch 11, iter 400/787) Average loss so far: 2.857
(Epoch 11, iter 450/787) Average loss so far: 2.842
(Epoch 11, iter 500/787) Average loss so far: 2.851
(Epoch 11, iter 550/787) Average loss so far: 2.847
(Epoch 11, iter 600/787) Average loss so far: 2.832
(Epoch 11, iter 650/787) Average loss so far: 2.844
(Epoch 11, iter 700/787) Average loss so far: 2.852
(Epoch 11, iter 750/787) Average loss so far: 2.860
Average epoch loss: 2.848
This epoch took 8.083603

7it [00:01,  5.68it/s]


validation loss: 3.0999631881713867


100%|██████████| 7/7 [00:03<00:00,  2.11it/s]
100%|██████████| 793/793 [00:01<00:00, 503.03it/s]


BLEU score: 0.044384897292243625, METEOR score: 0.20706484658028232
Starting epoch 13/30, enc lr scheduler: [0.000657963412215599], dec lr scheduler: [0.000657963412215599]
(Epoch 12, iter 50/787) Average loss so far: 2.819
(Epoch 12, iter 100/787) Average loss so far: 2.815
(Epoch 12, iter 150/787) Average loss so far: 2.809
(Epoch 12, iter 200/787) Average loss so far: 2.823
(Epoch 12, iter 250/787) Average loss so far: 2.823
(Epoch 12, iter 300/787) Average loss so far: 2.830
(Epoch 12, iter 350/787) Average loss so far: 2.808
(Epoch 12, iter 400/787) Average loss so far: 2.828
(Epoch 12, iter 450/787) Average loss so far: 2.822
(Epoch 12, iter 500/787) Average loss so far: 2.823
(Epoch 12, iter 550/787) Average loss so far: 2.818
(Epoch 12, iter 600/787) Average loss so far: 2.833
(Epoch 12, iter 650/787) Average loss so far: 2.837
(Epoch 12, iter 700/787) Average loss so far: 2.817
(Epoch 12, iter 750/787) Average loss so far: 2.824
Average epoch loss: 2.821
This epoch took 8.0355

7it [00:01,  5.47it/s]


validation loss: 3.087609222957066


100%|██████████| 7/7 [00:03<00:00,  2.05it/s]
100%|██████████| 793/793 [00:01<00:00, 459.58it/s]


BLEU score: 0.041033156261249384, METEOR score: 0.20340139642244925
Starting epoch 14/30, enc lr scheduler: [0.0006079162869547909], dec lr scheduler: [0.0006079162869547909]
(Epoch 13, iter 50/787) Average loss so far: 2.810
(Epoch 13, iter 100/787) Average loss so far: 2.799
(Epoch 13, iter 150/787) Average loss so far: 2.801
(Epoch 13, iter 200/787) Average loss so far: 2.796
(Epoch 13, iter 250/787) Average loss so far: 2.825
(Epoch 13, iter 300/787) Average loss so far: 2.793
(Epoch 13, iter 350/787) Average loss so far: 2.795
(Epoch 13, iter 400/787) Average loss so far: 2.803
(Epoch 13, iter 450/787) Average loss so far: 2.799
(Epoch 13, iter 500/787) Average loss so far: 2.798
(Epoch 13, iter 550/787) Average loss so far: 2.785
(Epoch 13, iter 600/787) Average loss so far: 2.788
(Epoch 13, iter 650/787) Average loss so far: 2.808
(Epoch 13, iter 700/787) Average loss so far: 2.801
(Epoch 13, iter 750/787) Average loss so far: 2.792
Average epoch loss: 2.799
This epoch took 8.04

7it [00:01,  5.56it/s]


validation loss: 3.080646344593593


100%|██████████| 7/7 [00:03<00:00,  2.14it/s]
100%|██████████| 793/793 [00:01<00:00, 515.81it/s]


BLEU score: 0.046692676994029854, METEOR score: 0.20559286007764238
Starting epoch 15/30, enc lr scheduler: [0.0005567415893174886], dec lr scheduler: [0.0005567415893174886]
(Epoch 14, iter 50/787) Average loss so far: 2.781
(Epoch 14, iter 100/787) Average loss so far: 2.776
(Epoch 14, iter 150/787) Average loss so far: 2.767
(Epoch 14, iter 200/787) Average loss so far: 2.774
(Epoch 14, iter 250/787) Average loss so far: 2.769
(Epoch 14, iter 300/787) Average loss so far: 2.789
(Epoch 14, iter 350/787) Average loss so far: 2.760
(Epoch 14, iter 400/787) Average loss so far: 2.791
(Epoch 14, iter 450/787) Average loss so far: 2.784
(Epoch 14, iter 500/787) Average loss so far: 2.793
(Epoch 14, iter 550/787) Average loss so far: 2.776
(Epoch 14, iter 600/787) Average loss so far: 2.781
(Epoch 14, iter 650/787) Average loss so far: 2.773
(Epoch 14, iter 700/787) Average loss so far: 2.780
(Epoch 14, iter 750/787) Average loss so far: 2.788
Average epoch loss: 2.779
This epoch took 8.12

7it [00:01,  5.66it/s]


validation loss: 3.0746142183031355


100%|██████████| 7/7 [00:03<00:00,  2.15it/s]
100%|██████████| 793/793 [00:01<00:00, 521.60it/s]


BLEU score: 0.046943470589521766, METEOR score: 0.20379152039697576
Starting epoch 16/30, enc lr scheduler: [0.0005050000000000002], dec lr scheduler: [0.0005050000000000002]
(Epoch 15, iter 50/787) Average loss so far: 2.756
(Epoch 15, iter 100/787) Average loss so far: 2.753
(Epoch 15, iter 150/787) Average loss so far: 2.763
(Epoch 15, iter 200/787) Average loss so far: 2.755
(Epoch 15, iter 250/787) Average loss so far: 2.757
(Epoch 15, iter 300/787) Average loss so far: 2.760
(Epoch 15, iter 350/787) Average loss so far: 2.764
(Epoch 15, iter 400/787) Average loss so far: 2.747
(Epoch 15, iter 450/787) Average loss so far: 2.753
(Epoch 15, iter 500/787) Average loss so far: 2.772
(Epoch 15, iter 550/787) Average loss so far: 2.772
(Epoch 15, iter 600/787) Average loss so far: 2.770
(Epoch 15, iter 650/787) Average loss so far: 2.764
(Epoch 15, iter 700/787) Average loss so far: 2.764
(Epoch 15, iter 750/787) Average loss so far: 2.762
Average epoch loss: 2.760
This epoch took 8.08

7it [00:01,  5.72it/s]


validation loss: 3.0693612098693848


100%|██████████| 7/7 [00:03<00:00,  2.08it/s]
100%|██████████| 793/793 [00:01<00:00, 496.10it/s]


BLEU score: 0.04693823811822118, METEOR score: 0.21012783449843356
Starting epoch 17/30, enc lr scheduler: [0.0004532584106825117], dec lr scheduler: [0.0004532584106825117]
(Epoch 16, iter 50/787) Average loss so far: 2.741
(Epoch 16, iter 100/787) Average loss so far: 2.739
(Epoch 16, iter 150/787) Average loss so far: 2.747
(Epoch 16, iter 200/787) Average loss so far: 2.731
(Epoch 16, iter 250/787) Average loss so far: 2.741
(Epoch 16, iter 300/787) Average loss so far: 2.748
(Epoch 16, iter 350/787) Average loss so far: 2.753
(Epoch 16, iter 400/787) Average loss so far: 2.756
(Epoch 16, iter 450/787) Average loss so far: 2.745
(Epoch 16, iter 500/787) Average loss so far: 2.757
(Epoch 16, iter 550/787) Average loss so far: 2.734
(Epoch 16, iter 600/787) Average loss so far: 2.740
(Epoch 16, iter 650/787) Average loss so far: 2.744
(Epoch 16, iter 700/787) Average loss so far: 2.728
(Epoch 16, iter 750/787) Average loss so far: 2.755
Average epoch loss: 2.745
This epoch took 8.109

7it [00:01,  5.33it/s]


validation loss: 3.065413134438651


100%|██████████| 7/7 [00:03<00:00,  2.03it/s]
100%|██████████| 793/793 [00:01<00:00, 503.49it/s]


BLEU score: 0.048756397533905496, METEOR score: 0.21361782160791687
Starting epoch 18/30, enc lr scheduler: [0.00040208371304520916], dec lr scheduler: [0.00040208371304520916]
(Epoch 17, iter 50/787) Average loss so far: 2.730
(Epoch 17, iter 100/787) Average loss so far: 2.725
(Epoch 17, iter 150/787) Average loss so far: 2.735
(Epoch 17, iter 200/787) Average loss so far: 2.726
(Epoch 17, iter 250/787) Average loss so far: 2.714
(Epoch 17, iter 300/787) Average loss so far: 2.742
(Epoch 17, iter 350/787) Average loss so far: 2.721
(Epoch 17, iter 400/787) Average loss so far: 2.716
(Epoch 17, iter 450/787) Average loss so far: 2.738
(Epoch 17, iter 500/787) Average loss so far: 2.729
(Epoch 17, iter 550/787) Average loss so far: 2.727
(Epoch 17, iter 600/787) Average loss so far: 2.739
(Epoch 17, iter 650/787) Average loss so far: 2.738
(Epoch 17, iter 700/787) Average loss so far: 2.740
(Epoch 17, iter 750/787) Average loss so far: 2.741
Average epoch loss: 2.731
This epoch took 8.

7it [00:01,  5.55it/s]


validation loss: 3.061034543173654


100%|██████████| 7/7 [00:03<00:00,  2.17it/s]
100%|██████████| 793/793 [00:01<00:00, 541.85it/s]


BLEU score: 0.05130481997908665, METEOR score: 0.21268616360288597
Starting epoch 19/30, enc lr scheduler: [0.00035203658778440114], dec lr scheduler: [0.00035203658778440114]
(Epoch 18, iter 50/787) Average loss so far: 2.706
(Epoch 18, iter 100/787) Average loss so far: 2.712
(Epoch 18, iter 150/787) Average loss so far: 2.709
(Epoch 18, iter 200/787) Average loss so far: 2.718
(Epoch 18, iter 250/787) Average loss so far: 2.717
(Epoch 18, iter 300/787) Average loss so far: 2.702
(Epoch 18, iter 350/787) Average loss so far: 2.721
(Epoch 18, iter 400/787) Average loss so far: 2.737
(Epoch 18, iter 450/787) Average loss so far: 2.725
(Epoch 18, iter 500/787) Average loss so far: 2.715
(Epoch 18, iter 550/787) Average loss so far: 2.726
(Epoch 18, iter 600/787) Average loss so far: 2.708
(Epoch 18, iter 650/787) Average loss so far: 2.715
(Epoch 18, iter 700/787) Average loss so far: 2.727
(Epoch 18, iter 750/787) Average loss so far: 2.723
Average epoch loss: 2.718
This epoch took 8.1

7it [00:01,  5.26it/s]


validation loss: 3.0552920273372104


100%|██████████| 7/7 [00:03<00:00,  2.28it/s]
100%|██████████| 793/793 [00:01<00:00, 586.20it/s]


BLEU score: 0.05395021872522183, METEOR score: 0.21410343910120513
Starting epoch 20/30, enc lr scheduler: [0.00030366536167747904], dec lr scheduler: [0.00030366536167747904]
(Epoch 19, iter 50/787) Average loss so far: 2.709
(Epoch 19, iter 100/787) Average loss so far: 2.710
(Epoch 19, iter 150/787) Average loss so far: 2.699
(Epoch 19, iter 200/787) Average loss so far: 2.694
(Epoch 19, iter 250/787) Average loss so far: 2.694
(Epoch 19, iter 300/787) Average loss so far: 2.711
(Epoch 19, iter 350/787) Average loss so far: 2.704
(Epoch 19, iter 400/787) Average loss so far: 2.708
(Epoch 19, iter 450/787) Average loss so far: 2.699
(Epoch 19, iter 500/787) Average loss so far: 2.716
(Epoch 19, iter 550/787) Average loss so far: 2.701
(Epoch 19, iter 600/787) Average loss so far: 2.707
(Epoch 19, iter 650/787) Average loss so far: 2.717
(Epoch 19, iter 700/787) Average loss so far: 2.700
(Epoch 19, iter 750/787) Average loss so far: 2.728
Average epoch loss: 2.707
This epoch took 8.0

7it [00:01,  5.48it/s]


validation loss: 3.0541302817208424


100%|██████████| 7/7 [00:02<00:00,  2.45it/s]
100%|██████████| 793/793 [00:01<00:00, 657.10it/s]


BLEU score: 0.05167331114826218, METEOR score: 0.2028060241888733
Starting epoch 21/30, enc lr scheduler: [0.00025750000000000013], dec lr scheduler: [0.00025750000000000013]
(Epoch 20, iter 50/787) Average loss so far: 2.700
(Epoch 20, iter 100/787) Average loss so far: 2.687
(Epoch 20, iter 150/787) Average loss so far: 2.699
(Epoch 20, iter 200/787) Average loss so far: 2.703
(Epoch 20, iter 250/787) Average loss so far: 2.706
(Epoch 20, iter 300/787) Average loss so far: 2.704
(Epoch 20, iter 350/787) Average loss so far: 2.679
(Epoch 20, iter 400/787) Average loss so far: 2.706
(Epoch 20, iter 450/787) Average loss so far: 2.690
(Epoch 20, iter 500/787) Average loss so far: 2.689
(Epoch 20, iter 550/787) Average loss so far: 2.693
(Epoch 20, iter 600/787) Average loss so far: 2.709
(Epoch 20, iter 650/787) Average loss so far: 2.693
(Epoch 20, iter 700/787) Average loss so far: 2.706
(Epoch 20, iter 750/787) Average loss so far: 2.695
Average epoch loss: 2.698
This epoch took 7.94

7it [00:01,  5.96it/s]


validation loss: 3.0529042993273054


100%|██████████| 7/7 [00:02<00:00,  2.36it/s]
100%|██████████| 793/793 [00:01<00:00, 604.51it/s]


BLEU score: 0.05473522071477695, METEOR score: 0.2107675382802335
Starting epoch 22/30, enc lr scheduler: [0.00021404630011522585], dec lr scheduler: [0.00021404630011522585]
(Epoch 21, iter 50/787) Average loss so far: 2.690
(Epoch 21, iter 100/787) Average loss so far: 2.674
(Epoch 21, iter 150/787) Average loss so far: 2.676
(Epoch 21, iter 200/787) Average loss so far: 2.707
(Epoch 21, iter 250/787) Average loss so far: 2.681
(Epoch 21, iter 300/787) Average loss so far: 2.691
(Epoch 21, iter 350/787) Average loss so far: 2.691
(Epoch 21, iter 400/787) Average loss so far: 2.695
(Epoch 21, iter 450/787) Average loss so far: 2.690
(Epoch 21, iter 500/787) Average loss so far: 2.681
(Epoch 21, iter 550/787) Average loss so far: 2.690
(Epoch 21, iter 600/787) Average loss so far: 2.704
(Epoch 21, iter 650/787) Average loss so far: 2.679
(Epoch 21, iter 700/787) Average loss so far: 2.699
(Epoch 21, iter 750/787) Average loss so far: 2.675
Average epoch loss: 2.689
This epoch took 7.52

7it [00:01,  6.14it/s]


validation loss: 3.0493842533656528


100%|██████████| 7/7 [00:03<00:00,  2.28it/s]
100%|██████████| 793/793 [00:01<00:00, 541.32it/s]


BLEU score: 0.05105392764669949, METEOR score: 0.2156661550691742
Starting epoch 23/30, enc lr scheduler: [0.00017378034985236535], dec lr scheduler: [0.00017378034985236535]
(Epoch 22, iter 50/787) Average loss so far: 2.676
(Epoch 22, iter 100/787) Average loss so far: 2.671
(Epoch 22, iter 150/787) Average loss so far: 2.682
(Epoch 22, iter 200/787) Average loss so far: 2.685
(Epoch 22, iter 250/787) Average loss so far: 2.676
(Epoch 22, iter 300/787) Average loss so far: 2.687
(Epoch 22, iter 350/787) Average loss so far: 2.687
(Epoch 22, iter 400/787) Average loss so far: 2.676
(Epoch 22, iter 450/787) Average loss so far: 2.688
(Epoch 22, iter 500/787) Average loss so far: 2.678
(Epoch 22, iter 550/787) Average loss so far: 2.694
(Epoch 22, iter 600/787) Average loss so far: 2.698
(Epoch 22, iter 650/787) Average loss so far: 2.673
(Epoch 22, iter 700/787) Average loss so far: 2.676
(Epoch 22, iter 750/787) Average loss so far: 2.681
Average epoch loss: 2.682
This epoch took 7.50

7it [00:01,  5.96it/s]


validation loss: 3.0472822529929027


100%|██████████| 7/7 [00:02<00:00,  2.37it/s]
100%|██████████| 793/793 [00:01<00:00, 605.43it/s]


BLEU score: 0.05496028256787088, METEOR score: 0.21032837797650386
Starting epoch 24/30, enc lr scheduler: [0.00013714331138868998], dec lr scheduler: [0.00013714331138868998]
(Epoch 23, iter 50/787) Average loss so far: 2.669
(Epoch 23, iter 100/787) Average loss so far: 2.674
(Epoch 23, iter 150/787) Average loss so far: 2.658
(Epoch 23, iter 200/787) Average loss so far: 2.672
(Epoch 23, iter 250/787) Average loss so far: 2.679
(Epoch 23, iter 300/787) Average loss so far: 2.678
(Epoch 23, iter 350/787) Average loss so far: 2.670
(Epoch 23, iter 400/787) Average loss so far: 2.677
(Epoch 23, iter 450/787) Average loss so far: 2.673
(Epoch 23, iter 500/787) Average loss so far: 2.668
(Epoch 23, iter 550/787) Average loss so far: 2.670
(Epoch 23, iter 600/787) Average loss so far: 2.674
(Epoch 23, iter 650/787) Average loss so far: 2.698
(Epoch 23, iter 700/787) Average loss so far: 2.690
(Epoch 23, iter 750/787) Average loss so far: 2.678
Average epoch loss: 2.676
This epoch took 7.4

7it [00:01,  6.00it/s]


validation loss: 3.0462141377585277


100%|██████████| 7/7 [00:02<00:00,  2.35it/s]
100%|██████████| 793/793 [00:01<00:00, 610.16it/s]


BLEU score: 0.05558312256170381, METEOR score: 0.21142443556612248
Starting epoch 25/30, enc lr scheduler: [0.00010453658778440108], dec lr scheduler: [0.00010453658778440108]
(Epoch 24, iter 50/787) Average loss so far: 2.682
(Epoch 24, iter 100/787) Average loss so far: 2.661
(Epoch 24, iter 150/787) Average loss so far: 2.664
(Epoch 24, iter 200/787) Average loss so far: 2.688
(Epoch 24, iter 250/787) Average loss so far: 2.652
(Epoch 24, iter 300/787) Average loss so far: 2.666
(Epoch 24, iter 350/787) Average loss so far: 2.681
(Epoch 24, iter 400/787) Average loss so far: 2.678
(Epoch 24, iter 450/787) Average loss so far: 2.669
(Epoch 24, iter 500/787) Average loss so far: 2.676
(Epoch 24, iter 550/787) Average loss so far: 2.685
(Epoch 24, iter 600/787) Average loss so far: 2.672
(Epoch 24, iter 650/787) Average loss so far: 2.678
(Epoch 24, iter 700/787) Average loss so far: 2.666
(Epoch 24, iter 750/787) Average loss so far: 2.655
Average epoch loss: 2.671
This epoch took 7.5

7it [00:01,  5.94it/s]


validation loss: 3.0450666631971086


100%|██████████| 7/7 [00:03<00:00,  2.32it/s]
100%|██████████| 793/793 [00:01<00:00, 580.66it/s]


BLEU score: 0.054168369543141096, METEOR score: 0.2190973874735378
Starting epoch 26/30, enc lr scheduler: [7.631742512670285e-05], dec lr scheduler: [7.631742512670285e-05]
(Epoch 25, iter 50/787) Average loss so far: 2.660
(Epoch 25, iter 100/787) Average loss so far: 2.655
(Epoch 25, iter 150/787) Average loss so far: 2.682
(Epoch 25, iter 200/787) Average loss so far: 2.663
(Epoch 25, iter 250/787) Average loss so far: 2.658
(Epoch 25, iter 300/787) Average loss so far: 2.653
(Epoch 25, iter 350/787) Average loss so far: 2.674
(Epoch 25, iter 400/787) Average loss so far: 2.667
(Epoch 25, iter 450/787) Average loss so far: 2.669
(Epoch 25, iter 500/787) Average loss so far: 2.668
(Epoch 25, iter 550/787) Average loss so far: 2.671
(Epoch 25, iter 600/787) Average loss so far: 2.666
(Epoch 25, iter 650/787) Average loss so far: 2.666
(Epoch 25, iter 700/787) Average loss so far: 2.679
(Epoch 25, iter 750/787) Average loss so far: 2.669
Average epoch loss: 2.667
This epoch took 7.473

7it [00:01,  6.05it/s]


validation loss: 3.0444539955684116


100%|██████████| 7/7 [00:02<00:00,  2.53it/s]
100%|██████████| 793/793 [00:01<00:00, 609.17it/s]


BLEU score: 0.05549097705999941, METEOR score: 0.2134507586141872
Starting epoch 27/30, enc lr scheduler: [5.279499846691252e-05], dec lr scheduler: [5.279499846691252e-05]
(Epoch 26, iter 50/787) Average loss so far: 2.655
(Epoch 26, iter 100/787) Average loss so far: 2.655
(Epoch 26, iter 150/787) Average loss so far: 2.641
(Epoch 26, iter 200/787) Average loss so far: 2.668
(Epoch 26, iter 250/787) Average loss so far: 2.671
(Epoch 26, iter 300/787) Average loss so far: 2.661
(Epoch 26, iter 350/787) Average loss so far: 2.672
(Epoch 26, iter 400/787) Average loss so far: 2.673
(Epoch 26, iter 450/787) Average loss so far: 2.667
(Epoch 26, iter 500/787) Average loss so far: 2.671
(Epoch 26, iter 550/787) Average loss so far: 2.665
(Epoch 26, iter 600/787) Average loss so far: 2.668
(Epoch 26, iter 650/787) Average loss so far: 2.661
(Epoch 26, iter 700/787) Average loss so far: 2.657
(Epoch 26, iter 750/787) Average loss so far: 2.660
Average epoch loss: 2.663
This epoch took 7.4265

7it [00:01,  6.26it/s]


validation loss: 3.0439279760633196


100%|██████████| 7/7 [00:02<00:00,  2.39it/s]
100%|██████████| 793/793 [00:01<00:00, 605.29it/s]


BLEU score: 0.05676168975713699, METEOR score: 0.2139917855970703
Starting epoch 28/30, enc lr scheduler: [3.4227024433899005e-05], dec lr scheduler: [3.4227024433899005e-05]
(Epoch 27, iter 50/787) Average loss so far: 2.679
(Epoch 27, iter 100/787) Average loss so far: 2.660
(Epoch 27, iter 150/787) Average loss so far: 2.659
(Epoch 27, iter 200/787) Average loss so far: 2.654
(Epoch 27, iter 250/787) Average loss so far: 2.664
(Epoch 27, iter 300/787) Average loss so far: 2.655
(Epoch 27, iter 350/787) Average loss so far: 2.660
(Epoch 27, iter 400/787) Average loss so far: 2.650
(Epoch 27, iter 450/787) Average loss so far: 2.669
(Epoch 27, iter 500/787) Average loss so far: 2.670
(Epoch 27, iter 550/787) Average loss so far: 2.664
(Epoch 27, iter 600/787) Average loss so far: 2.660
(Epoch 27, iter 650/787) Average loss so far: 2.650
(Epoch 27, iter 700/787) Average loss so far: 2.664
(Epoch 27, iter 750/787) Average loss so far: 2.661
Average epoch loss: 2.661
This epoch took 7.45

7it [00:01,  6.13it/s]


validation loss: 3.0432401725224087


100%|██████████| 7/7 [00:02<00:00,  2.35it/s]
100%|██████████| 793/793 [00:01<00:00, 595.50it/s]


BLEU score: 0.05472410622862404, METEOR score: 0.2129614425943201
Starting epoch 29/30, enc lr scheduler: [2.0816937636766188e-05], dec lr scheduler: [2.0816937636766188e-05]
(Epoch 28, iter 50/787) Average loss so far: 2.666
(Epoch 28, iter 100/787) Average loss so far: 2.647
(Epoch 28, iter 150/787) Average loss so far: 2.662
(Epoch 28, iter 200/787) Average loss so far: 2.633
(Epoch 28, iter 250/787) Average loss so far: 2.664
(Epoch 28, iter 300/787) Average loss so far: 2.664
(Epoch 28, iter 350/787) Average loss so far: 2.668
(Epoch 28, iter 400/787) Average loss so far: 2.659
(Epoch 28, iter 450/787) Average loss so far: 2.665
(Epoch 28, iter 500/787) Average loss so far: 2.660
(Epoch 28, iter 550/787) Average loss so far: 2.663
(Epoch 28, iter 600/787) Average loss so far: 2.655
(Epoch 28, iter 650/787) Average loss so far: 2.662
(Epoch 28, iter 700/787) Average loss so far: 2.662
(Epoch 28, iter 750/787) Average loss so far: 2.661
Average epoch loss: 2.660
This epoch took 7.44

7it [00:01,  5.94it/s]


validation loss: 3.0424865995134627


100%|██████████| 7/7 [00:02<00:00,  2.34it/s]
100%|██████████| 793/793 [00:01<00:00, 588.28it/s]


BLEU score: 0.054989968499987295, METEOR score: 0.21584912823420402
Starting epoch 30/30, enc lr scheduler: [1.2711661792704668e-05], dec lr scheduler: [1.2711661792704668e-05]
(Epoch 29, iter 50/787) Average loss so far: 2.667
(Epoch 29, iter 100/787) Average loss so far: 2.655
(Epoch 29, iter 150/787) Average loss so far: 2.648
(Epoch 29, iter 200/787) Average loss so far: 2.672
(Epoch 29, iter 250/787) Average loss so far: 2.658
(Epoch 29, iter 300/787) Average loss so far: 2.662
(Epoch 29, iter 350/787) Average loss so far: 2.664
(Epoch 29, iter 400/787) Average loss so far: 2.649
(Epoch 29, iter 450/787) Average loss so far: 2.657
(Epoch 29, iter 500/787) Average loss so far: 2.646
(Epoch 29, iter 550/787) Average loss so far: 2.667
(Epoch 29, iter 600/787) Average loss so far: 2.670
(Epoch 29, iter 650/787) Average loss so far: 2.655
(Epoch 29, iter 700/787) Average loss so far: 2.656
(Epoch 29, iter 750/787) Average loss so far: 2.661
Average epoch loss: 2.658
This epoch took 7.

7it [00:01,  6.06it/s]


validation loss: 3.0425445352281844


100%|██████████| 7/7 [00:02<00:00,  2.39it/s]
100%|██████████| 793/793 [00:01<00:00, 619.58it/s]


BLEU score: 0.057802547829115924, METEOR score: 0.21431412239195574


## Extension 4: Multi-Layer Encoder-Decoder + Neurologic Decoding

In [24]:
reset_rng()

In [15]:
embedding_size=300
num_layers=3
encoder_multilayer_attn = EncoderRNN(vocab.n_unique_words, embedding_size=embedding_size, hidden_size=HIDDEN_SIZE, 
                          padding_value=vocab.word2index(PAD_WORD), num_lstm_layers=num_layers).to(DEVICE)
# in the training script, decoder is always fed a non-end token and thus never needs to generate padding
# also it should never generate "<UNKNOWN>"
# decoder = DecoderRNN(embedding_size=embedding_size,hidden_size=HIDDEN_SIZE, output_size=vocab.n_unique_words-2).to(DEVICE)
decoder_multilayer_attn = AttnDecoderRNN(embedding_size, hidden_size=HIDDEN_SIZE, output_size=vocab.n_unique_words-1, 
                              padding_val=vocab.word2index(PAD_WORD), dropout=DROPOUT, num_lstm_layers=num_layers).to(DEVICE)

In [26]:
initial_lr=1e-3
min_lr = 1e-5
n_epochs = 30
batch_size=128
encoder_multilayer_attn_optimizer = optim.Adam(encoder_multilayer_attn.parameters(), lr=initial_lr)
decoder_multilayer_attn_optimizer = optim.Adam(decoder_multilayer_attn.parameters(), lr=initial_lr)
enc_multilayer_attn_scheduler = CosineAnnealingLR(encoder_multilayer_attn_optimizer, T_max=n_epochs, eta_min=min_lr)
dec_multilayer_attn_scheduler = CosineAnnealingLR(decoder_multilayer_attn_optimizer, T_max=n_epochs, eta_min=min_lr)
identifier="multilayer_attn_adam_without_intermediate_tags_wd0_lr1e-3"

multilayer_attn_epoch_losses, multilayer_attn_val_epoch_losses, multilayer_attn_log = train(
    encoder_multilayer_attn, decoder_multilayer_attn, encoder_multilayer_attn_optimizer, decoder_multilayer_attn_optimizer, train_ds, 
    n_epochs=n_epochs, vocab=vocab, decoder_mode="attention", batch_size=batch_size, 
    enc_lr_scheduler=enc_multilayer_attn_scheduler, dec_lr_scheduler=dec_multilayer_attn_scheduler, 
    dev_ds_val_loss = dev_ds_val_loss, dev_ds_val_met=dev_ds_val_met, identifier=identifier,
    verbose_iter_interval=50)

save_log(identifier, multilayer_attn_log, encoder_multilayer_attn_optimizer, decoder_multilayer_attn_optimizer, 
         enc_multilayer_attn_scheduler, dec_multilayer_attn_scheduler)

Starting epoch 1/30, enc lr scheduler: [0.001], dec lr scheduler: [0.001]
(Epoch 0, iter 50/787) Average loss so far: 6.988
(Epoch 0, iter 100/787) Average loss so far: 6.058
(Epoch 0, iter 150/787) Average loss so far: 6.035
(Epoch 0, iter 200/787) Average loss so far: 6.032
(Epoch 0, iter 250/787) Average loss so far: 6.019
(Epoch 0, iter 300/787) Average loss so far: 6.011
(Epoch 0, iter 350/787) Average loss so far: 6.015
(Epoch 0, iter 400/787) Average loss so far: 6.006
(Epoch 0, iter 450/787) Average loss so far: 6.002
(Epoch 0, iter 500/787) Average loss so far: 5.990
(Epoch 0, iter 550/787) Average loss so far: 5.972
(Epoch 0, iter 600/787) Average loss so far: 5.961
(Epoch 0, iter 650/787) Average loss so far: 5.934
(Epoch 0, iter 700/787) Average loss so far: 5.930
(Epoch 0, iter 750/787) Average loss so far: 5.931
Average epoch loss: 6.053
This epoch took 8.784394307931263 mins. Time remaining: 4.0 hrs 14.0 mins.


7it [00:01,  5.15it/s]


validation loss: 5.933405944279262


100%|██████████| 7/7 [00:08<00:00,  1.15s/it]
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
100%|██████████| 793/793 [00:06<00:00, 114.02it/s]


BLEU score: 5.8211892454135815e-80, METEOR score: 0.034986121334209475
Starting epoch 2/30, enc lr scheduler: [0.0009972883382072953], dec lr scheduler: [0.0009972883382072953]
(Epoch 1, iter 50/787) Average loss so far: 5.916
(Epoch 1, iter 100/787) Average loss so far: 5.902
(Epoch 1, iter 150/787) Average loss so far: 5.896
(Epoch 1, iter 200/787) Average loss so far: 5.880
(Epoch 1, iter 250/787) Average loss so far: 5.882
(Epoch 1, iter 300/787) Average loss so far: 5.891
(Epoch 1, iter 350/787) Average loss so far: 5.881
(Epoch 1, iter 400/787) Average loss so far: 5.874
(Epoch 1, iter 450/787) Average loss so far: 5.868
(Epoch 1, iter 500/787) Average loss so far: 5.830
(Epoch 1, iter 550/787) Average loss so far: 5.730
(Epoch 1, iter 600/787) Average loss so far: 5.635
(Epoch 1, iter 650/787) Average loss so far: 5.547
(Epoch 1, iter 700/787) Average loss so far: 5.400
(Epoch 1, iter 750/787) Average loss so far: 5.205
Average epoch loss: 5.724
This epoch took 8.680376549561819

7it [00:01,  5.27it/s]


validation loss: 5.041280337742397


100%|██████████| 7/7 [00:07<00:00,  1.08s/it]
100%|██████████| 793/793 [00:06<00:00, 114.85it/s]


BLEU score: 0.003278639860407049, METEOR score: 0.08561562097555378
Starting epoch 3/30, enc lr scheduler: [0.0009891830623632338], dec lr scheduler: [0.0009891830623632338]
(Epoch 2, iter 50/787) Average loss so far: 4.912
(Epoch 2, iter 100/787) Average loss so far: 4.751
(Epoch 2, iter 150/787) Average loss so far: 4.599
(Epoch 2, iter 200/787) Average loss so far: 4.489
(Epoch 2, iter 250/787) Average loss so far: 4.364
(Epoch 2, iter 300/787) Average loss so far: 4.268
(Epoch 2, iter 350/787) Average loss so far: 4.199
(Epoch 2, iter 400/787) Average loss so far: 4.150
(Epoch 2, iter 450/787) Average loss so far: 4.090
(Epoch 2, iter 500/787) Average loss so far: 4.023
(Epoch 2, iter 550/787) Average loss so far: 3.958
(Epoch 2, iter 600/787) Average loss so far: 3.944
(Epoch 2, iter 650/787) Average loss so far: 3.891
(Epoch 2, iter 700/787) Average loss so far: 3.854
(Epoch 2, iter 750/787) Average loss so far: 3.827
Average epoch loss: 4.201
This epoch took 8.660929580529531 mi

7it [00:01,  5.22it/s]


validation loss: 3.8312723296029225


100%|██████████| 7/7 [00:06<00:00,  1.14it/s]
100%|██████████| 793/793 [00:05<00:00, 150.06it/s]


BLEU score: 0.009649228203172842, METEOR score: 0.15070114988789704
Starting epoch 4/30, enc lr scheduler: [0.0009757729755661011], dec lr scheduler: [0.0009757729755661011]
(Epoch 3, iter 50/787) Average loss so far: 3.738
(Epoch 3, iter 100/787) Average loss so far: 3.721
(Epoch 3, iter 150/787) Average loss so far: 3.697
(Epoch 3, iter 200/787) Average loss so far: 3.681
(Epoch 3, iter 250/787) Average loss so far: 3.687
(Epoch 3, iter 300/787) Average loss so far: 3.624
(Epoch 3, iter 350/787) Average loss so far: 3.616
(Epoch 3, iter 400/787) Average loss so far: 3.585
(Epoch 3, iter 450/787) Average loss so far: 3.555
(Epoch 3, iter 500/787) Average loss so far: 3.548
(Epoch 3, iter 550/787) Average loss so far: 3.543
(Epoch 3, iter 600/787) Average loss so far: 3.546
(Epoch 3, iter 650/787) Average loss so far: 3.513
(Epoch 3, iter 700/787) Average loss so far: 3.485
(Epoch 3, iter 750/787) Average loss so far: 3.485
Average epoch loss: 3.596
This epoch took 8.700705035527546 mi

7it [00:01,  5.06it/s]


validation loss: 3.5433250835963657


100%|██████████| 7/7 [00:04<00:00,  1.49it/s]
100%|██████████| 793/793 [00:03<00:00, 237.64it/s]


BLEU score: 0.01680922041802637, METEOR score: 0.17252114463959134
Starting epoch 5/30, enc lr scheduler: [0.0009572050015330874], dec lr scheduler: [0.0009572050015330874]
(Epoch 4, iter 50/787) Average loss so far: 3.437
(Epoch 4, iter 100/787) Average loss so far: 3.421
(Epoch 4, iter 150/787) Average loss so far: 3.419
(Epoch 4, iter 200/787) Average loss so far: 3.411
(Epoch 4, iter 250/787) Average loss so far: 3.377
(Epoch 4, iter 300/787) Average loss so far: 3.361
(Epoch 4, iter 350/787) Average loss so far: 3.362
(Epoch 4, iter 400/787) Average loss so far: 3.337
(Epoch 4, iter 450/787) Average loss so far: 3.360
(Epoch 4, iter 500/787) Average loss so far: 3.334
(Epoch 4, iter 550/787) Average loss so far: 3.319
(Epoch 4, iter 600/787) Average loss so far: 3.328
(Epoch 4, iter 650/787) Average loss so far: 3.312
(Epoch 4, iter 700/787) Average loss so far: 3.283
(Epoch 4, iter 750/787) Average loss so far: 3.285
Average epoch loss: 3.354
This epoch took 8.66706737279892 mins

7it [00:01,  5.22it/s]


validation loss: 3.3929422923496793


100%|██████████| 7/7 [00:04<00:00,  1.69it/s]
100%|██████████| 793/793 [00:02<00:00, 300.09it/s]


BLEU score: 0.02410393818638161, METEOR score: 0.18903294738920962
Starting epoch 6/30, enc lr scheduler: [0.0009336825748732973], dec lr scheduler: [0.0009336825748732973]
(Epoch 5, iter 50/787) Average loss so far: 3.268
(Epoch 5, iter 100/787) Average loss so far: 3.238
(Epoch 5, iter 150/787) Average loss so far: 3.226
(Epoch 5, iter 200/787) Average loss so far: 3.217
(Epoch 5, iter 250/787) Average loss so far: 3.211
(Epoch 5, iter 300/787) Average loss so far: 3.213
(Epoch 5, iter 350/787) Average loss so far: 3.203
(Epoch 5, iter 400/787) Average loss so far: 3.193
(Epoch 5, iter 450/787) Average loss so far: 3.200
(Epoch 5, iter 500/787) Average loss so far: 3.213
(Epoch 5, iter 550/787) Average loss so far: 3.187
(Epoch 5, iter 600/787) Average loss so far: 3.182
(Epoch 5, iter 650/787) Average loss so far: 3.167
(Epoch 5, iter 700/787) Average loss so far: 3.176
(Epoch 5, iter 750/787) Average loss so far: 3.172
Average epoch loss: 3.202
This epoch took 8.660624651114146 min

7it [00:01,  5.19it/s]


validation loss: 3.2941173825945174


100%|██████████| 7/7 [00:04<00:00,  1.53it/s]
100%|██████████| 793/793 [00:03<00:00, 240.61it/s]


BLEU score: 0.01932799321202002, METEOR score: 0.18168938969894916
Starting epoch 7/30, enc lr scheduler: [0.0009054634122155991], dec lr scheduler: [0.0009054634122155991]
(Epoch 6, iter 50/787) Average loss so far: 3.107
(Epoch 6, iter 100/787) Average loss so far: 3.130
(Epoch 6, iter 150/787) Average loss so far: 3.095
(Epoch 6, iter 200/787) Average loss so far: 3.118
(Epoch 6, iter 250/787) Average loss so far: 3.103
(Epoch 6, iter 300/787) Average loss so far: 3.102
(Epoch 6, iter 350/787) Average loss so far: 3.094
(Epoch 6, iter 400/787) Average loss so far: 3.105
(Epoch 6, iter 450/787) Average loss so far: 3.091
(Epoch 6, iter 500/787) Average loss so far: 3.089
(Epoch 6, iter 550/787) Average loss so far: 3.082
(Epoch 6, iter 600/787) Average loss so far: 3.077
(Epoch 6, iter 650/787) Average loss so far: 3.076
(Epoch 6, iter 700/787) Average loss so far: 3.068
(Epoch 6, iter 750/787) Average loss so far: 3.063
Average epoch loss: 3.091
This epoch took 8.661432000001271 min

7it [00:01,  5.31it/s]


validation loss: 3.220607246671404


100%|██████████| 7/7 [00:03<00:00,  1.91it/s]
100%|██████████| 793/793 [00:01<00:00, 416.84it/s]


BLEU score: 0.034996658118198035, METEOR score: 0.19608702213958482
Starting epoch 8/30, enc lr scheduler: [0.0008728566886113102], dec lr scheduler: [0.0008728566886113102]
(Epoch 7, iter 50/787) Average loss so far: 3.020
(Epoch 7, iter 100/787) Average loss so far: 3.023
(Epoch 7, iter 150/787) Average loss so far: 3.018
(Epoch 7, iter 200/787) Average loss so far: 3.026
(Epoch 7, iter 250/787) Average loss so far: 3.011
(Epoch 7, iter 300/787) Average loss so far: 3.009
(Epoch 7, iter 350/787) Average loss so far: 3.006
(Epoch 7, iter 400/787) Average loss so far: 3.008
(Epoch 7, iter 450/787) Average loss so far: 3.012
(Epoch 7, iter 500/787) Average loss so far: 3.002
(Epoch 7, iter 550/787) Average loss so far: 3.000
(Epoch 7, iter 600/787) Average loss so far: 2.995
(Epoch 7, iter 650/787) Average loss so far: 2.978
(Epoch 7, iter 700/787) Average loss so far: 2.987
(Epoch 7, iter 750/787) Average loss so far: 2.994
Average epoch loss: 3.005
This epoch took 8.668269244829814 mi

7it [00:01,  5.30it/s]


validation loss: 3.1842188835144043


100%|██████████| 7/7 [00:03<00:00,  1.98it/s]
100%|██████████| 793/793 [00:01<00:00, 460.40it/s]


BLEU score: 0.04036306790727657, METEOR score: 0.20022729874967282
Starting epoch 9/30, enc lr scheduler: [0.0008362196501476349], dec lr scheduler: [0.0008362196501476349]
(Epoch 8, iter 50/787) Average loss so far: 2.934
(Epoch 8, iter 100/787) Average loss so far: 2.937
(Epoch 8, iter 150/787) Average loss so far: 2.937
(Epoch 8, iter 200/787) Average loss so far: 2.949
(Epoch 8, iter 250/787) Average loss so far: 2.949
(Epoch 8, iter 300/787) Average loss so far: 2.938
(Epoch 8, iter 350/787) Average loss so far: 2.934
(Epoch 8, iter 400/787) Average loss so far: 2.917
(Epoch 8, iter 450/787) Average loss so far: 2.927
(Epoch 8, iter 500/787) Average loss so far: 2.950
(Epoch 8, iter 550/787) Average loss so far: 2.952
(Epoch 8, iter 600/787) Average loss so far: 2.941
(Epoch 8, iter 650/787) Average loss so far: 2.934
(Epoch 8, iter 700/787) Average loss so far: 2.926
(Epoch 8, iter 750/787) Average loss so far: 2.933
Average epoch loss: 2.936
This epoch took 8.682707707087198 min

7it [00:01,  5.38it/s]


validation loss: 3.14445652280535


100%|██████████| 7/7 [00:03<00:00,  2.11it/s]
100%|██████████| 793/793 [00:01<00:00, 543.89it/s]


BLEU score: 0.04614282375504496, METEOR score: 0.19953643331690282
Starting epoch 10/30, enc lr scheduler: [0.0007959536998847742], dec lr scheduler: [0.0007959536998847742]
(Epoch 9, iter 50/787) Average loss so far: 2.881
(Epoch 9, iter 100/787) Average loss so far: 2.873
(Epoch 9, iter 150/787) Average loss so far: 2.894
(Epoch 9, iter 200/787) Average loss so far: 2.899
(Epoch 9, iter 250/787) Average loss so far: 2.899
(Epoch 9, iter 300/787) Average loss so far: 2.882
(Epoch 9, iter 350/787) Average loss so far: 2.858
(Epoch 9, iter 400/787) Average loss so far: 2.883
(Epoch 9, iter 450/787) Average loss so far: 2.864
(Epoch 9, iter 500/787) Average loss so far: 2.867
(Epoch 9, iter 550/787) Average loss so far: 2.876
(Epoch 9, iter 600/787) Average loss so far: 2.872
(Epoch 9, iter 650/787) Average loss so far: 2.900
(Epoch 9, iter 700/787) Average loss so far: 2.871
(Epoch 9, iter 750/787) Average loss so far: 2.870
Average epoch loss: 2.879
This epoch took 8.667925226688386 mi

7it [00:01,  5.29it/s]


validation loss: 3.120482785361154


100%|██████████| 7/7 [00:03<00:00,  2.12it/s]
100%|██████████| 793/793 [00:01<00:00, 558.82it/s]


BLEU score: 0.047825885570907486, METEOR score: 0.205678013742886
Starting epoch 11/30, enc lr scheduler: [0.0007525], dec lr scheduler: [0.0007525]
(Epoch 10, iter 50/787) Average loss so far: 2.819
(Epoch 10, iter 100/787) Average loss so far: 2.829
(Epoch 10, iter 150/787) Average loss so far: 2.827
(Epoch 10, iter 200/787) Average loss so far: 2.831
(Epoch 10, iter 250/787) Average loss so far: 2.829
(Epoch 10, iter 300/787) Average loss so far: 2.839
(Epoch 10, iter 350/787) Average loss so far: 2.841
(Epoch 10, iter 400/787) Average loss so far: 2.835
(Epoch 10, iter 450/787) Average loss so far: 2.823
(Epoch 10, iter 500/787) Average loss so far: 2.816
(Epoch 10, iter 550/787) Average loss so far: 2.818
(Epoch 10, iter 600/787) Average loss so far: 2.836
(Epoch 10, iter 650/787) Average loss so far: 2.844
(Epoch 10, iter 700/787) Average loss so far: 2.833
(Epoch 10, iter 750/787) Average loss so far: 2.826
Average epoch loss: 2.829
This epoch took 8.670869362354278 mins. Time r

7it [00:01,  5.27it/s]


validation loss: 3.101747683116368


100%|██████████| 7/7 [00:02<00:00,  2.34it/s]
100%|██████████| 793/793 [00:01<00:00, 617.97it/s]


BLEU score: 0.051791655004959615, METEOR score: 0.20470572480986682
Starting epoch 12/30, enc lr scheduler: [0.0007063346383225212], dec lr scheduler: [0.0007063346383225212]
(Epoch 11, iter 50/787) Average loss so far: 2.805
(Epoch 11, iter 100/787) Average loss so far: 2.796
(Epoch 11, iter 150/787) Average loss so far: 2.791
(Epoch 11, iter 200/787) Average loss so far: 2.791
(Epoch 11, iter 250/787) Average loss so far: 2.781
(Epoch 11, iter 300/787) Average loss so far: 2.795
(Epoch 11, iter 350/787) Average loss so far: 2.779
(Epoch 11, iter 400/787) Average loss so far: 2.771
(Epoch 11, iter 450/787) Average loss so far: 2.784
(Epoch 11, iter 500/787) Average loss so far: 2.782
(Epoch 11, iter 550/787) Average loss so far: 2.787
(Epoch 11, iter 600/787) Average loss so far: 2.787
(Epoch 11, iter 650/787) Average loss so far: 2.784
(Epoch 11, iter 700/787) Average loss so far: 2.785
(Epoch 11, iter 750/787) Average loss so far: 2.784
Average epoch loss: 2.786
This epoch took 8.65

7it [00:01,  5.30it/s]


validation loss: 3.090313128062657


100%|██████████| 7/7 [00:02<00:00,  2.36it/s]
100%|██████████| 793/793 [00:01<00:00, 648.16it/s]


BLEU score: 0.05567891393706685, METEOR score: 0.21062660135719127
Starting epoch 13/30, enc lr scheduler: [0.000657963412215599], dec lr scheduler: [0.000657963412215599]
(Epoch 12, iter 50/787) Average loss so far: 2.742
(Epoch 12, iter 100/787) Average loss so far: 2.729
(Epoch 12, iter 150/787) Average loss so far: 2.754
(Epoch 12, iter 200/787) Average loss so far: 2.751
(Epoch 12, iter 250/787) Average loss so far: 2.738
(Epoch 12, iter 300/787) Average loss so far: 2.758
(Epoch 12, iter 350/787) Average loss so far: 2.751
(Epoch 12, iter 400/787) Average loss so far: 2.749
(Epoch 12, iter 450/787) Average loss so far: 2.746
(Epoch 12, iter 500/787) Average loss so far: 2.754
(Epoch 12, iter 550/787) Average loss so far: 2.743
(Epoch 12, iter 600/787) Average loss so far: 2.757
(Epoch 12, iter 650/787) Average loss so far: 2.753
(Epoch 12, iter 700/787) Average loss so far: 2.768
(Epoch 12, iter 750/787) Average loss so far: 2.748
Average epoch loss: 2.749
This epoch took 8.67595

7it [00:01,  5.21it/s]


validation loss: 3.078939846583775


100%|██████████| 7/7 [00:03<00:00,  2.16it/s]
100%|██████████| 793/793 [00:01<00:00, 640.63it/s]


BLEU score: 0.05447397472432387, METEOR score: 0.20767725983235402
Starting epoch 14/30, enc lr scheduler: [0.0006079162869547909], dec lr scheduler: [0.0006079162869547909]
(Epoch 13, iter 50/787) Average loss so far: 2.712
(Epoch 13, iter 100/787) Average loss so far: 2.701
(Epoch 13, iter 150/787) Average loss so far: 2.727
(Epoch 13, iter 200/787) Average loss so far: 2.724
(Epoch 13, iter 250/787) Average loss so far: 2.723
(Epoch 13, iter 300/787) Average loss so far: 2.716
(Epoch 13, iter 350/787) Average loss so far: 2.729
(Epoch 13, iter 400/787) Average loss so far: 2.720
(Epoch 13, iter 450/787) Average loss so far: 2.703
(Epoch 13, iter 500/787) Average loss so far: 2.726
(Epoch 13, iter 550/787) Average loss so far: 2.717
(Epoch 13, iter 600/787) Average loss so far: 2.694
(Epoch 13, iter 650/787) Average loss so far: 2.711
(Epoch 13, iter 700/787) Average loss so far: 2.699
(Epoch 13, iter 750/787) Average loss so far: 2.720
Average epoch loss: 2.715
This epoch took 8.702

7it [00:01,  5.30it/s]


validation loss: 3.0758470467158725


100%|██████████| 7/7 [00:03<00:00,  2.31it/s]
100%|██████████| 793/793 [00:01<00:00, 655.92it/s]


BLEU score: 0.05610531776640169, METEOR score: 0.2134749913995397
Starting epoch 15/30, enc lr scheduler: [0.0005567415893174886], dec lr scheduler: [0.0005567415893174886]
(Epoch 14, iter 50/787) Average loss so far: 2.683
(Epoch 14, iter 100/787) Average loss so far: 2.684
(Epoch 14, iter 150/787) Average loss so far: 2.682
(Epoch 14, iter 200/787) Average loss so far: 2.683
(Epoch 14, iter 250/787) Average loss so far: 2.675
(Epoch 14, iter 300/787) Average loss so far: 2.713
(Epoch 14, iter 350/787) Average loss so far: 2.683
(Epoch 14, iter 400/787) Average loss so far: 2.698
(Epoch 14, iter 450/787) Average loss so far: 2.701
(Epoch 14, iter 500/787) Average loss so far: 2.686
(Epoch 14, iter 550/787) Average loss so far: 2.683
(Epoch 14, iter 600/787) Average loss so far: 2.664
(Epoch 14, iter 650/787) Average loss so far: 2.689
(Epoch 14, iter 700/787) Average loss so far: 2.686
(Epoch 14, iter 750/787) Average loss so far: 2.683
Average epoch loss: 2.686
This epoch took 8.6992

7it [00:01,  5.34it/s]


validation loss: 3.067518813269479


100%|██████████| 7/7 [00:02<00:00,  2.48it/s]
100%|██████████| 793/793 [00:01<00:00, 718.04it/s]


BLEU score: 0.05618418002695732, METEOR score: 0.2173668475571425
Starting epoch 16/30, enc lr scheduler: [0.0005050000000000002], dec lr scheduler: [0.0005050000000000002]
(Epoch 15, iter 50/787) Average loss so far: 2.676
(Epoch 15, iter 100/787) Average loss so far: 2.648
(Epoch 15, iter 150/787) Average loss so far: 2.659
(Epoch 15, iter 200/787) Average loss so far: 2.662
(Epoch 15, iter 250/787) Average loss so far: 2.662
(Epoch 15, iter 300/787) Average loss so far: 2.658
(Epoch 15, iter 350/787) Average loss so far: 2.671
(Epoch 15, iter 400/787) Average loss so far: 2.657
(Epoch 15, iter 450/787) Average loss so far: 2.644
(Epoch 15, iter 500/787) Average loss so far: 2.656
(Epoch 15, iter 550/787) Average loss so far: 2.664
(Epoch 15, iter 600/787) Average loss so far: 2.657
(Epoch 15, iter 650/787) Average loss so far: 2.666
(Epoch 15, iter 700/787) Average loss so far: 2.645
(Epoch 15, iter 750/787) Average loss so far: 2.655
Average epoch loss: 2.659
This epoch took 8.6669

7it [00:01,  5.05it/s]


validation loss: 3.0657551969800676


100%|██████████| 7/7 [00:02<00:00,  2.42it/s]
100%|██████████| 793/793 [00:01<00:00, 717.99it/s]


BLEU score: 0.05615294677876267, METEOR score: 0.2133591944674418
Starting epoch 17/30, enc lr scheduler: [0.0004532584106825117], dec lr scheduler: [0.0004532584106825117]
(Epoch 16, iter 50/787) Average loss so far: 2.639
(Epoch 16, iter 100/787) Average loss so far: 2.630
(Epoch 16, iter 150/787) Average loss so far: 2.627
(Epoch 16, iter 200/787) Average loss so far: 2.626
(Epoch 16, iter 250/787) Average loss so far: 2.629
(Epoch 16, iter 300/787) Average loss so far: 2.650
(Epoch 16, iter 350/787) Average loss so far: 2.658
(Epoch 16, iter 400/787) Average loss so far: 2.633
(Epoch 16, iter 450/787) Average loss so far: 2.648
(Epoch 16, iter 500/787) Average loss so far: 2.617
(Epoch 16, iter 550/787) Average loss so far: 2.637
(Epoch 16, iter 600/787) Average loss so far: 2.631
(Epoch 16, iter 650/787) Average loss so far: 2.639
(Epoch 16, iter 700/787) Average loss so far: 2.638
(Epoch 16, iter 750/787) Average loss so far: 2.647
Average epoch loss: 2.636
This epoch took 8.6888

7it [00:01,  5.22it/s]


validation loss: 3.0647445065634593


100%|██████████| 7/7 [00:02<00:00,  2.72it/s]
100%|██████████| 793/793 [00:00<00:00, 797.12it/s]


BLEU score: 0.05162445812511413, METEOR score: 0.2073657691543935
Starting epoch 18/30, enc lr scheduler: [0.00040208371304520916], dec lr scheduler: [0.00040208371304520916]
(Epoch 17, iter 50/787) Average loss so far: 2.605
(Epoch 17, iter 100/787) Average loss so far: 2.604
(Epoch 17, iter 150/787) Average loss so far: 2.611
(Epoch 17, iter 200/787) Average loss so far: 2.609
(Epoch 17, iter 250/787) Average loss so far: 2.625
(Epoch 17, iter 300/787) Average loss so far: 2.628
(Epoch 17, iter 350/787) Average loss so far: 2.601
(Epoch 17, iter 400/787) Average loss so far: 2.615
(Epoch 17, iter 450/787) Average loss so far: 2.599
(Epoch 17, iter 500/787) Average loss so far: 2.617
(Epoch 17, iter 550/787) Average loss so far: 2.613
(Epoch 17, iter 600/787) Average loss so far: 2.625
(Epoch 17, iter 650/787) Average loss so far: 2.611
(Epoch 17, iter 700/787) Average loss so far: 2.633
(Epoch 17, iter 750/787) Average loss so far: 2.617
Average epoch loss: 2.615
This epoch took 8.66

7it [00:01,  5.27it/s]


validation loss: 3.062435899462019


100%|██████████| 7/7 [00:02<00:00,  2.50it/s]
100%|██████████| 793/793 [00:01<00:00, 737.56it/s]


BLEU score: 0.05444598116058762, METEOR score: 0.21361483967861572
Starting epoch 19/30, enc lr scheduler: [0.00035203658778440114], dec lr scheduler: [0.00035203658778440114]
(Epoch 18, iter 50/787) Average loss so far: 2.583
(Epoch 18, iter 100/787) Average loss so far: 2.577
(Epoch 18, iter 150/787) Average loss so far: 2.597
(Epoch 18, iter 200/787) Average loss so far: 2.593
(Epoch 18, iter 250/787) Average loss so far: 2.596
(Epoch 18, iter 300/787) Average loss so far: 2.615
(Epoch 18, iter 350/787) Average loss so far: 2.590
(Epoch 18, iter 400/787) Average loss so far: 2.584
(Epoch 18, iter 450/787) Average loss so far: 2.598
(Epoch 18, iter 500/787) Average loss so far: 2.606
(Epoch 18, iter 550/787) Average loss so far: 2.591
(Epoch 18, iter 600/787) Average loss so far: 2.604
(Epoch 18, iter 650/787) Average loss so far: 2.588
(Epoch 18, iter 700/787) Average loss so far: 2.603
(Epoch 18, iter 750/787) Average loss so far: 2.613
Average epoch loss: 2.596
This epoch took 8.6

7it [00:01,  5.30it/s]


validation loss: 3.0591438157217845


100%|██████████| 7/7 [00:02<00:00,  2.38it/s]
100%|██████████| 793/793 [00:01<00:00, 758.33it/s]


BLEU score: 0.05424589140761788, METEOR score: 0.21430080632880433
Starting epoch 20/30, enc lr scheduler: [0.00030366536167747904], dec lr scheduler: [0.00030366536167747904]
(Epoch 19, iter 50/787) Average loss so far: 2.582
(Epoch 19, iter 100/787) Average loss so far: 2.569
(Epoch 19, iter 150/787) Average loss so far: 2.571
(Epoch 19, iter 200/787) Average loss so far: 2.587
(Epoch 19, iter 250/787) Average loss so far: 2.571
(Epoch 19, iter 300/787) Average loss so far: 2.579
(Epoch 19, iter 350/787) Average loss so far: 2.584
(Epoch 19, iter 400/787) Average loss so far: 2.580
(Epoch 19, iter 450/787) Average loss so far: 2.582
(Epoch 19, iter 500/787) Average loss so far: 2.583
(Epoch 19, iter 550/787) Average loss so far: 2.581
(Epoch 19, iter 600/787) Average loss so far: 2.582
(Epoch 19, iter 650/787) Average loss so far: 2.563
(Epoch 19, iter 700/787) Average loss so far: 2.587
(Epoch 19, iter 750/787) Average loss so far: 2.584
Average epoch loss: 2.580
This epoch took 8.6

7it [00:01,  5.19it/s]


validation loss: 3.0615645817347934


100%|██████████| 7/7 [00:02<00:00,  2.65it/s]
100%|██████████| 793/793 [00:00<00:00, 795.05it/s]


BLEU score: 0.05251383464691767, METEOR score: 0.2116644609179294
Starting epoch 21/30, enc lr scheduler: [0.00025750000000000013], dec lr scheduler: [0.00025750000000000013]
(Epoch 20, iter 50/787) Average loss so far: 2.563
(Epoch 20, iter 100/787) Average loss so far: 2.560
(Epoch 20, iter 150/787) Average loss so far: 2.569
(Epoch 20, iter 200/787) Average loss so far: 2.559
(Epoch 20, iter 250/787) Average loss so far: 2.561
(Epoch 20, iter 300/787) Average loss so far: 2.551
(Epoch 20, iter 350/787) Average loss so far: 2.561
(Epoch 20, iter 400/787) Average loss so far: 2.556
(Epoch 20, iter 450/787) Average loss so far: 2.571
(Epoch 20, iter 500/787) Average loss so far: 2.571
(Epoch 20, iter 550/787) Average loss so far: 2.564
(Epoch 20, iter 600/787) Average loss so far: 2.581
(Epoch 20, iter 650/787) Average loss so far: 2.567
(Epoch 20, iter 700/787) Average loss so far: 2.577
(Epoch 20, iter 750/787) Average loss so far: 2.566
Average epoch loss: 2.565
This epoch took 8.65

7it [00:01,  5.25it/s]


validation loss: 3.0620672702789307


100%|██████████| 7/7 [00:02<00:00,  2.73it/s]
100%|██████████| 793/793 [00:01<00:00, 774.12it/s]


BLEU score: 0.05325391809008416, METEOR score: 0.21406562669861987
Starting epoch 22/30, enc lr scheduler: [0.00021404630011522585], dec lr scheduler: [0.00021404630011522585]
(Epoch 21, iter 50/787) Average loss so far: 2.539
(Epoch 21, iter 100/787) Average loss so far: 2.549
(Epoch 21, iter 150/787) Average loss so far: 2.547
(Epoch 21, iter 200/787) Average loss so far: 2.550
(Epoch 21, iter 250/787) Average loss so far: 2.555
(Epoch 21, iter 300/787) Average loss so far: 2.554
(Epoch 21, iter 350/787) Average loss so far: 2.556
(Epoch 21, iter 400/787) Average loss so far: 2.543
(Epoch 21, iter 450/787) Average loss so far: 2.542
(Epoch 21, iter 500/787) Average loss so far: 2.563
(Epoch 21, iter 550/787) Average loss so far: 2.551
(Epoch 21, iter 600/787) Average loss so far: 2.551
(Epoch 21, iter 650/787) Average loss so far: 2.552
(Epoch 21, iter 700/787) Average loss so far: 2.554
(Epoch 21, iter 750/787) Average loss so far: 2.568
Average epoch loss: 2.552
This epoch took 8.6

7it [00:01,  5.22it/s]


validation loss: 3.0650003296988353


100%|██████████| 7/7 [00:02<00:00,  2.71it/s]
100%|██████████| 793/793 [00:01<00:00, 792.15it/s]


BLEU score: 0.05365965581686292, METEOR score: 0.21241248063528498
Starting epoch 23/30, enc lr scheduler: [0.00017378034985236535], dec lr scheduler: [0.00017378034985236535]
(Epoch 22, iter 50/787) Average loss so far: 2.535
(Epoch 22, iter 100/787) Average loss so far: 2.549
(Epoch 22, iter 150/787) Average loss so far: 2.538
(Epoch 22, iter 200/787) Average loss so far: 2.552
(Epoch 22, iter 250/787) Average loss so far: 2.555
(Epoch 22, iter 300/787) Average loss so far: 2.536
(Epoch 22, iter 350/787) Average loss so far: 2.558
(Epoch 22, iter 400/787) Average loss so far: 2.551
(Epoch 22, iter 450/787) Average loss so far: 2.554
(Epoch 22, iter 500/787) Average loss so far: 2.525
(Epoch 22, iter 550/787) Average loss so far: 2.539
(Epoch 22, iter 600/787) Average loss so far: 2.538
(Epoch 22, iter 650/787) Average loss so far: 2.521
(Epoch 22, iter 700/787) Average loss so far: 2.532
(Epoch 22, iter 750/787) Average loss so far: 2.527
Average epoch loss: 2.541
This epoch took 8.6

7it [00:01,  5.24it/s]


validation loss: 3.064901147569929


100%|██████████| 7/7 [00:02<00:00,  2.53it/s]
100%|██████████| 793/793 [00:01<00:00, 751.96it/s]


BLEU score: 0.05429019863071828, METEOR score: 0.2108751834143225
Starting epoch 24/30, enc lr scheduler: [0.00013714331138868998], dec lr scheduler: [0.00013714331138868998]
(Epoch 23, iter 50/787) Average loss so far: 2.534
(Epoch 23, iter 100/787) Average loss so far: 2.524
(Epoch 23, iter 150/787) Average loss so far: 2.534
(Epoch 23, iter 200/787) Average loss so far: 2.544
(Epoch 23, iter 250/787) Average loss so far: 2.528
(Epoch 23, iter 300/787) Average loss so far: 2.532
(Epoch 23, iter 350/787) Average loss so far: 2.524
(Epoch 23, iter 400/787) Average loss so far: 2.545
(Epoch 23, iter 450/787) Average loss so far: 2.529
(Epoch 23, iter 500/787) Average loss so far: 2.531
(Epoch 23, iter 550/787) Average loss so far: 2.525
(Epoch 23, iter 600/787) Average loss so far: 2.529
(Epoch 23, iter 650/787) Average loss so far: 2.528
(Epoch 23, iter 700/787) Average loss so far: 2.537
(Epoch 23, iter 750/787) Average loss so far: 2.539
Average epoch loss: 2.532
This epoch took 8.65

7it [00:01,  5.12it/s]


validation loss: 3.0661655834742954


100%|██████████| 7/7 [00:02<00:00,  2.60it/s]
100%|██████████| 793/793 [00:01<00:00, 733.89it/s]


BLEU score: 0.05508513737660636, METEOR score: 0.21475239961444473
Starting epoch 25/30, enc lr scheduler: [0.00010453658778440108], dec lr scheduler: [0.00010453658778440108]
(Epoch 24, iter 50/787) Average loss so far: 2.527
(Epoch 24, iter 100/787) Average loss so far: 2.526
(Epoch 24, iter 150/787) Average loss so far: 2.524
(Epoch 24, iter 200/787) Average loss so far: 2.541
(Epoch 24, iter 250/787) Average loss so far: 2.518
(Epoch 24, iter 300/787) Average loss so far: 2.529
(Epoch 24, iter 350/787) Average loss so far: 2.522
(Epoch 24, iter 400/787) Average loss so far: 2.517
(Epoch 24, iter 450/787) Average loss so far: 2.510
(Epoch 24, iter 500/787) Average loss so far: 2.528
(Epoch 24, iter 550/787) Average loss so far: 2.537
(Epoch 24, iter 600/787) Average loss so far: 2.520
(Epoch 24, iter 650/787) Average loss so far: 2.530
(Epoch 24, iter 700/787) Average loss so far: 2.523
(Epoch 24, iter 750/787) Average loss so far: 2.518
Average epoch loss: 2.524
This epoch took 8.6

7it [00:01,  5.27it/s]


validation loss: 3.06569732938494


100%|██████████| 7/7 [00:02<00:00,  2.73it/s]
100%|██████████| 793/793 [00:01<00:00, 781.48it/s]


BLEU score: 0.05240369931105146, METEOR score: 0.21010115919574035
Starting epoch 26/30, enc lr scheduler: [7.631742512670285e-05], dec lr scheduler: [7.631742512670285e-05]
(Epoch 25, iter 50/787) Average loss so far: 2.519
(Epoch 25, iter 100/787) Average loss so far: 2.513
(Epoch 25, iter 150/787) Average loss so far: 2.532
(Epoch 25, iter 200/787) Average loss so far: 2.528
(Epoch 25, iter 250/787) Average loss so far: 2.509
(Epoch 25, iter 300/787) Average loss so far: 2.518
(Epoch 25, iter 350/787) Average loss so far: 2.528
(Epoch 25, iter 400/787) Average loss so far: 2.508
(Epoch 25, iter 450/787) Average loss so far: 2.514
(Epoch 25, iter 500/787) Average loss so far: 2.526
(Epoch 25, iter 550/787) Average loss so far: 2.500
(Epoch 25, iter 600/787) Average loss so far: 2.517
(Epoch 25, iter 650/787) Average loss so far: 2.523
(Epoch 25, iter 700/787) Average loss so far: 2.510
(Epoch 25, iter 750/787) Average loss so far: 2.519
Average epoch loss: 2.518
This epoch took 8.698

7it [00:01,  5.42it/s]


validation loss: 3.066434485571725


100%|██████████| 7/7 [00:02<00:00,  2.52it/s]
100%|██████████| 793/793 [00:01<00:00, 770.50it/s]


BLEU score: 0.053463804062823625, METEOR score: 0.21029893221004417
Starting epoch 27/30, enc lr scheduler: [5.279499846691252e-05], dec lr scheduler: [5.279499846691252e-05]
(Epoch 26, iter 50/787) Average loss so far: 2.511
(Epoch 26, iter 100/787) Average loss so far: 2.507
(Epoch 26, iter 150/787) Average loss so far: 2.510
(Epoch 26, iter 200/787) Average loss so far: 2.510
(Epoch 26, iter 250/787) Average loss so far: 2.525
(Epoch 26, iter 300/787) Average loss so far: 2.518
(Epoch 26, iter 350/787) Average loss so far: 2.520
(Epoch 26, iter 400/787) Average loss so far: 2.511
(Epoch 26, iter 450/787) Average loss so far: 2.507
(Epoch 26, iter 500/787) Average loss so far: 2.520
(Epoch 26, iter 550/787) Average loss so far: 2.494
(Epoch 26, iter 600/787) Average loss so far: 2.519
(Epoch 26, iter 650/787) Average loss so far: 2.506
(Epoch 26, iter 700/787) Average loss so far: 2.524
(Epoch 26, iter 750/787) Average loss so far: 2.526
Average epoch loss: 2.513
This epoch took 8.69

7it [00:01,  4.92it/s]


validation loss: 3.0676874773842946


100%|██████████| 7/7 [00:02<00:00,  2.45it/s]
100%|██████████| 793/793 [00:01<00:00, 761.40it/s]


BLEU score: 0.05483833839420476, METEOR score: 0.21259949218767468
Starting epoch 28/30, enc lr scheduler: [3.4227024433899005e-05], dec lr scheduler: [3.4227024433899005e-05]
(Epoch 27, iter 50/787) Average loss so far: 2.514
(Epoch 27, iter 100/787) Average loss so far: 2.506
(Epoch 27, iter 150/787) Average loss so far: 2.513
(Epoch 27, iter 200/787) Average loss so far: 2.510
(Epoch 27, iter 250/787) Average loss so far: 2.520
(Epoch 27, iter 300/787) Average loss so far: 2.524
(Epoch 27, iter 350/787) Average loss so far: 2.502
(Epoch 27, iter 400/787) Average loss so far: 2.503
(Epoch 27, iter 450/787) Average loss so far: 2.507
(Epoch 27, iter 500/787) Average loss so far: 2.508
(Epoch 27, iter 550/787) Average loss so far: 2.498
(Epoch 27, iter 600/787) Average loss so far: 2.507
(Epoch 27, iter 650/787) Average loss so far: 2.519
(Epoch 27, iter 700/787) Average loss so far: 2.508
(Epoch 27, iter 750/787) Average loss so far: 2.507
Average epoch loss: 2.510
This epoch took 8.7

7it [00:01,  5.33it/s]


validation loss: 3.068269661494664


100%|██████████| 7/7 [00:02<00:00,  2.37it/s]
100%|██████████| 793/793 [00:01<00:00, 730.91it/s]


BLEU score: 0.05357286007825388, METEOR score: 0.21187169877005785
Starting epoch 29/30, enc lr scheduler: [2.0816937636766188e-05], dec lr scheduler: [2.0816937636766188e-05]
(Epoch 28, iter 50/787) Average loss so far: 2.503
(Epoch 28, iter 100/787) Average loss so far: 2.517
(Epoch 28, iter 150/787) Average loss so far: 2.502
(Epoch 28, iter 200/787) Average loss so far: 2.504
(Epoch 28, iter 250/787) Average loss so far: 2.514
(Epoch 28, iter 300/787) Average loss so far: 2.506
(Epoch 28, iter 350/787) Average loss so far: 2.518
(Epoch 28, iter 400/787) Average loss so far: 2.496
(Epoch 28, iter 450/787) Average loss so far: 2.508
(Epoch 28, iter 500/787) Average loss so far: 2.511
(Epoch 28, iter 550/787) Average loss so far: 2.514
(Epoch 28, iter 600/787) Average loss so far: 2.509
(Epoch 28, iter 650/787) Average loss so far: 2.483
(Epoch 28, iter 700/787) Average loss so far: 2.515
(Epoch 28, iter 750/787) Average loss so far: 2.511
Average epoch loss: 2.507
This epoch took 8.6

7it [00:01,  5.29it/s]


validation loss: 3.0683638708932057


100%|██████████| 7/7 [00:02<00:00,  2.39it/s]
100%|██████████| 793/793 [00:01<00:00, 760.32it/s]


BLEU score: 0.053924556203878635, METEOR score: 0.2136804902885653
Starting epoch 30/30, enc lr scheduler: [1.2711661792704668e-05], dec lr scheduler: [1.2711661792704668e-05]
(Epoch 29, iter 50/787) Average loss so far: 2.496
(Epoch 29, iter 100/787) Average loss so far: 2.494
(Epoch 29, iter 150/787) Average loss so far: 2.521
(Epoch 29, iter 200/787) Average loss so far: 2.511
(Epoch 29, iter 250/787) Average loss so far: 2.510
(Epoch 29, iter 300/787) Average loss so far: 2.511
(Epoch 29, iter 350/787) Average loss so far: 2.505
(Epoch 29, iter 400/787) Average loss so far: 2.493
(Epoch 29, iter 450/787) Average loss so far: 2.497
(Epoch 29, iter 500/787) Average loss so far: 2.521
(Epoch 29, iter 550/787) Average loss so far: 2.491
(Epoch 29, iter 600/787) Average loss so far: 2.503
(Epoch 29, iter 650/787) Average loss so far: 2.506
(Epoch 29, iter 700/787) Average loss so far: 2.509
(Epoch 29, iter 750/787) Average loss so far: 2.515
Average epoch loss: 2.506
This epoch took 8.6

7it [00:01,  5.22it/s]


validation loss: 3.0682195595332553


100%|██████████| 7/7 [00:02<00:00,  2.51it/s]
100%|██████████| 793/793 [00:01<00:00, 777.01it/s]

BLEU score: 0.05360497700665978, METEOR score: 0.21295206921320584





---

## Evaluation

In [16]:
all_ingredients_lst = get_all_ingredients("./ingredient_set.json")

### Without attention

In [16]:
load_model(encoder, decoder, "adam_without_intermediate_tags_with_val_wd0_lr1e-3_ep_24")

In [17]:
all_decoder_outs, all_gt_recipes, all_gt_ingredients = eval(
    encoder, decoder, test_ds, vocab, batch_size=128, decoder_mode="basic",
    max_recipe_len=MAX_RECIPE_LEN)

  0%|          | 0/7 [00:00<?, ?it/s]

100%|██████████| 7/7 [00:05<00:00,  1.30it/s]


In [18]:
calc_bleu(all_gt_recipes, all_decoder_outs)

0.04449619754164479

In [19]:
calc_meteor(all_gt_recipes, all_decoder_outs, split_gt=False)

  0%|          | 0/774 [00:00<?, ?it/s]

100%|██████████| 774/774 [00:03<00:00, 209.73it/s]


0.20156387133765763

In [21]:
convert_eval_out_to_get_ingredient_metrics(all_decoder_outs, all_gt_ingredients,
                                           vocab, all_ingredients_lst)

774it [00:00, 1141.23it/s]

Avg. % given ingredients: 25.563%
Avg. number of extra ingredients: 2.978





### With attention

In [45]:
load_model(encoder_attn, decoder_attn, "attn_adam_without_intermediate_tags_wd0_lr1e-3_ep_26")

In [46]:
all_decoder_outs, all_gt_recipes, all_gt_ingredients = eval(
    encoder_attn, decoder_attn, test_ds, vocab, batch_size=128, decoder_mode="attention",
    max_recipe_len=MAX_RECIPE_LEN)

  0%|          | 0/7 [00:00<?, ?it/s]

100%|██████████| 7/7 [00:06<00:00,  1.04it/s]


In [47]:
calc_bleu(all_gt_recipes, all_decoder_outs)

0.04975297680717382

In [48]:
calc_meteor(all_gt_recipes, all_decoder_outs, split_gt=False)

  0%|          | 0/774 [00:00<?, ?it/s]

100%|██████████| 774/774 [00:02<00:00, 292.46it/s]


0.20857380991725616

In [26]:
convert_eval_out_to_get_ingredient_metrics(all_decoder_outs, all_gt_ingredients,
                                           vocab, all_ingredients_lst)

774it [00:00, 1107.21it/s]

Avg. % given ingredients: 33.453%
Avg. number of extra ingredients: 2.295





## Attention + pretrained embeddings

In [49]:
load_model(encoder_pretrained_embed, decoder_pretrained_embed, 
           "pretrained_emb_attn_adam_without_intermediate_tags_wd0_lr1e-3_ep_29")

In [56]:
all_decoder_outs, all_gt_recipes, all_gt_ingredients = eval(
    encoder_pretrained_embed, decoder_pretrained_embed, test_ds, vocab, batch_size=128, 
    decoder_mode="attention", max_recipe_len=MAX_RECIPE_LEN)

  0%|          | 0/7 [00:00<?, ?it/s]

100%|██████████| 7/7 [00:07<00:00,  1.04s/it]


In [57]:
calc_bleu(all_gt_recipes, all_decoder_outs)

0.05573056259570414

In [58]:
calc_meteor(all_gt_recipes, all_decoder_outs, split_gt=False)

100%|██████████| 774/774 [00:02<00:00, 289.11it/s]


0.20985559115168426

In [36]:
convert_eval_out_to_get_ingredient_metrics(all_decoder_outs, all_gt_ingredients,
                                           vocab, all_ingredients_lst)

774it [00:00, 1054.15it/s]

Avg. % given ingredients: 31.074%
Avg. number of extra ingredients: 2.104





## Attention + Multilayer + Neurologic Decoding

### Without Neurologic Decoding

In [22]:
load_model(encoder_multilayer_attn, decoder_multilayer_attn, "multilayer_attn_adam_without_intermediate_tags_wd0_lr1e-3_ep_14")

FileNotFoundError: [Errno 2] No such file or directory: 'saved_models/encoder_multilayer_attn_adam_without_intermediate_tags_wd0_lr1e-3_ep_14.pth'

In [38]:
all_decoder_outs_multi, all_gt_recipes_multi, all_gt_ingredients_multi = eval(
    encoder_multilayer_attn, decoder_multilayer_attn, test_ds, vocab, batch_size=128, decoder_mode="attention",
    max_recipe_len=MAX_RECIPE_LEN)

100%|██████████| 7/7 [00:02<00:00,  2.64it/s]


In [39]:
calc_bleu(all_decoder_outs_multi, all_gt_recipes_multi)

0.053654397203939004

In [40]:
calc_meteor(all_gt_recipes_multi, all_decoder_outs_multi, split_gt=False)

100%|██████████| 774/774 [00:01<00:00, 724.46it/s]


0.21415571426435778

In [41]:
convert_eval_out_to_get_ingredient_metrics(all_decoder_outs_multi, all_gt_ingredients_multi,
                                           vocab, all_ingredients_lst)

774it [00:00, 1134.83it/s]

Avg. % given ingredients: 32.267%
Avg. number of extra ingredients: 2.223





### With Neurologic Decoding

! WARNING: this code takes around 20-30 mins to run

In [46]:
k=3
alpha=50 # likelihood
beta = 2 # num constraints satisfied
neg_constraint_penalty, likelihood_penalty, low_irr_satisfaction_penalty = 10, 0.5, 0.5
lam = 5.0

all_decoder_outs, all_gt_recipes, all_gt_ings = eval_neuro_decoding(
    encoder_multilayer_attn, decoder_multilayer_attn, test_ds, vocab, all_ingredients_lst, k=k, alpha=alpha, beta=beta, 
    neg_constraint_penalty=neg_constraint_penalty, likelihood_penalty=likelihood_penalty, 
    low_irr_satisfaction_penalty=low_irr_satisfaction_penalty, lam=lam)


100%|██████████| 774/774 [16:36<00:00,  1.29s/it]


In [47]:
calc_bleu(all_gt_recipes, all_decoder_outs)

0.05810241560787664

In [48]:
calc_meteor(all_gt_recipes, all_decoder_outs, split_gt=False)

100%|██████████| 774/774 [00:01<00:00, 712.33it/s]


0.23339527798724602

In [49]:
convert_eval_out_to_get_ingredient_metrics(all_decoder_outs, all_gt_ings,
                                           vocab, all_ingredients_lst, skip_ing_processing=True)

774it [00:00, 1072.67it/s]

Avg. % given ingredients: 79.462%
Avg. number of extra ingredients: 0.274





---

## Metric Sample

In [51]:
all_ings = get_all_ingredients("./ingredient_set.json")
all_ings_regex = get_ingredients_regex(all_ings)
metric_sample_ings, metric_sample_gold_recipe, metric_sample_generated_recipe = \
    load_metric_sample("./metric_sample.txt")

In [52]:
prop_inp_ings, n_extra_ings = get_prop_input_num_extra_ingredients_m(
    metric_sample_ings, metric_sample_generated_recipe, all_ings_regex, verbose=True)
print(f"\nproportion of input ingredients: {prop_inp_ings}\nnumber of extra ingredients: {n_extra_ings}")

=====Input ingredients in text=====
['lemon juice', 'orange juice', 'water', 'sugar', 'strawberries']

=====All ingredients in text===== 
['lemon juice', 'orange juice', 'water', 'cantaloupe', 'vanilla ice cream', 'sugar', 'strawberries']

proportion of input ingredients: 1.0
number of extra ingredients: 2


In [55]:
bleu_score = calc_bleu([metric_sample_gold_recipe], [metric_sample_generated_recipe], split_gt=True, split_gen=True)
meteor_score = calc_meteor([metric_sample_gold_recipe], [metric_sample_generated_recipe], split_gt=True, split_gen=True)
print(f"BLEU score: {bleu_score}, METEOR score: {meteor_score}")

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:00<00:00, 458.90it/s]

BLEU score: 0.2757364975156813, METEOR score: 0.5479209577754892





---

# Qualitative Evaluation

In [29]:
ingredients_sample_orig = "2 c sugar, 1/4 c lemon juice, 1 c water, 1/3 c orange juice, 8 c strawberries"

perform relevant preprocessing steps (note: these operations are taken from `preprocess_data` and is the same preprocessing done on the dataset)

In [None]:
def prepare_qual_ingredients()

In [30]:
ingredients_sample = re.sub("([^0-9a-zA-Z.'\"/ ])", r" \1 ", ingredients_sample_orig)
ingredients_sample = '<INGREDIENT_START> ' + ingredients_sample + ' <INGREDIENT_END>'
ingredients_sample = re.sub('[ ]{2,}', " ", ingredients_sample)

In [32]:
ingredients_sample

'<INGREDIENT_START> 2 c sugar , 1/4 c lemon juice , 1 c water , 1/3 c orange juice , 8 c strawberries <INGREDIENT_END>'

In [37]:
ingredients_sample_idxs = torch.tensor([vocab.word2index(w) for w in ingredients_sample.split(" ")],
                                       dtype=torch.long, device=DEVICE)
# convert to batch form with batch size 1 as model is expecting batched input
ingredients_sample_idxs = ingredients_sample_idxs[None] 

In [38]:
ingredients_sample_idxs

tensor([[   0,   22,   12,   16,   78,   32,   12,  641,  900,   78,   17,   12,
          311,   78,  235,   12,  791,  900,   78,  385,   12, 1218,    1]],
       device='cuda:0')

In [39]:
ingr_lens = torch.tensor([len(x) for x in ingredients_sample_idxs], dtype=torch.long, device=DEVICE)

In [40]:
ingr_lens

tensor([23], device='cuda:0')

In [41]:
qual_res_basic = get_predictions_iter(ingredients_sample_idxs, ingr_lens, encoder, decoder,
                                        vocab)

In [43]:
' '.join(qual_res_basic[0])

'<RECIPE_START> combine sugar , cornstarch and water in a saucepan bring to a boil , stirring constantly boil for 1 minute remove from heat and stir in lemon juice and lemon juice pour into a large bowl and stir in the orange juice and lemon juice pour into a large bowl and stir in the orange juice and lemon juice pour into a large bowl and chill <RECIPE_END>'

In [60]:
qual_res_attn = get_predictions_iter(ingredients_sample_idxs, ingr_lens, encoder_attn, decoder_attn,
                                    vocab, decoder_mode="attention")

In [62]:
' '.join(qual_res_attn[0])

'<RECIPE_START> combine sugar , water , lemon juice and salt in a saucepan bring to a boil , stirring constantly boil for 5 minutes , stirring constantly remove from heat and stir in lemon juice and pour into a sterilized jars seal and store in refrigerator <RECIPE_END>'

In [63]:
qual_res_attn_pre_embed = get_predictions_iter(
    ingredients_sample_idxs, ingr_lens, encoder_pretrained_embed, decoder_pretrained_embed,
    vocab, decoder_mode="attention")

In [65]:
' '.join(qual_res_attn_pre_embed[0])

'<RECIPE_START> combine sugar , water , and lemon juice in a saucepan bring to a boil over medium heat , stirring constantly , until sugar dissolves remove from heat and cool to room temperature add vanilla and mix well chill <RECIPE_END>'