In [1]:
import torch
import pandas as pd
import numpy as np
import random
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import Parameter
from torch import Tensor
import torch.nn.functional as F
import matplotlib.image as img
import matplotlib.pyplot as plt
import torchvision
import PIL
import math
import datetime
import os
from bpemb import BPEmb
from torch import nn
from image_caption_dataset import ImageCaptionDataset
from multi_bpe import MultiBPE
from multimodal_model import MMLSTM, BenchmarkLSTM, BenchmarkCustomLSTM
from torch.utils.data import DataLoader
from tqdm import tqdm
import transformers

In [2]:
# --- runtime / data settings ---
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
LOAD_PATH = './data/'
seed = 420
num_data = 200_000
maxlen = 49
train_pct = 0.8

# --- optimisation settings ---
epochs = 15
bs = 16
# Learning rate scaled by 1/sqrt(32 / batch size).
lr = 1.0/math.sqrt(32/bs)

# --- model variant flags ---
is_multimodal = False
train_visual_module = False

In [3]:
# Seed every RNG the notebook draws from.
# torch.manual_seed seeds the generators on all devices, CUDA included,
# so separate torch.cuda.manual_seed(_all) calls are not required.
np.random.seed(seed)     # NumPy legacy global RNG
torch.manual_seed(seed)  # PyTorch default generators
random.seed(seed)        # Python's built-in random module
# For deterministic cuDNN kernels (slower), additionally set:
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True

In [4]:
def save_checkpoint(epoch, 
                    model, 
                    optimizer, 
                    scheduler,
                    loss,
                    FNAME):
    """Save a training checkpoint to ./checkpoints/checkpoint-MM-DD-YYYY/FNAME.

    The saved dict contains 'epoch', 'model_state_dict', 'optimizer_state_dict',
    'scheduler_state_dict' (None when no scheduler is given) and 'loss'.

    Args:
        epoch: epoch index stored in the checkpoint.
        model: module whose state_dict() is saved.
        optimizer: optimizer whose state_dict() is saved.
        scheduler: optional LR scheduler; may be None.
        loss: loss value stored alongside the weights.
        FNAME: file name inside today's checkpoint directory.
    """
    today = datetime.date.today()
    PATH = f'./checkpoints/{today.strftime("checkpoint-%m-%d-%Y")}/'
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(PATH, exist_ok=True)
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
            'loss': loss,
            }, os.path.join(PATH, FNAME))

In [5]:
def load_latest_checkpoint(PATH, model, optimizer, scheduler, map_location=None):
    """Restore model, optimizer and (optionally) scheduler state from a checkpoint.

    This is the counterpart of save_checkpoint. It expects a dict with
    'model_state_dict', 'optimizer_state_dict', 'scheduler_state_dict', 'epoch'
    and 'loss'.

    Args:
        PATH: path to the checkpoint file.
        model: module to load weights into (modified in place).
        optimizer: optimizer to load state into (modified in place).
        scheduler: scheduler to restore, or None. Scheduler state is only loaded
            when a scheduler is given AND the checkpoint stored one. save_checkpoint
            writes None there when it is called without a scheduler.
        map_location: forwarded to torch.load (e.g. 'cpu' to load a GPU
            checkpoint on a CPU-only machine). The default None keeps
            torch.load's default behavior.

    Returns:
        (epoch, loss) as stored in the checkpoint.
    """
    checkpoint = torch.load(PATH, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler_state = checkpoint.get('scheduler_state_dict')
    # Previously this called scheduler.load_state_dict unconditionally, which
    # crashed for scheduler=None or for checkpoints saved without a scheduler.
    if scheduler is not None and scheduler_state is not None:
        scheduler.load_state_dict(scheduler_state)
    return checkpoint['epoch'], checkpoint['loss']

In [6]:
def _init_fn(worker_id):
    np.random.seed(int(seed))
    
# Preprocessing shared by every process_img call; built once instead of per image.
_IMAGE_TRANSFORM = torchvision.transforms.Compose([
    # Resize image to 224 x 224 as required by most vision models
    torchvision.transforms.Resize(size=(224, 224)),
    # Convert PIL image to tensor with image values in [0, 1]
    torchvision.transforms.ToTensor(),
    # Per-channel normalization with the standard ImageNet mean/std
    torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


def process_img(image_path):
    """Load an image file and return it as a normalized 1 x C x 224 x 224 tensor.

    Args:
        image_path: path of the image file to open.

    Returns:
        Float tensor with a leading batch dimension of size 1. The image is
        converted to RGB, resized to 224x224 and normalized.
    """
    # The context manager closes the file handle. convert() loads the pixel
    # data, so the returned image stays valid after the file is closed.
    with PIL.Image.open(image_path) as im:
        image = im.convert('RGB')
    image = _IMAGE_TRANSFORM(image)
    # Add the batch dimension (equivalent to the previous view(1, C, H, W)).
    return image.unsqueeze(0)

def get_dataset(tokenizer, path=LOAD_PATH, num_data=None, maxlen=64):
    """Load the English rows of ml_stacked_data.csv as an ImageCaptionDataset.

    Args:
        tokenizer: passed through to ImageCaptionDataset.
        path: directory containing 'ml_stacked_data.csv'. A trailing path
            separator is optional.
        num_data: if not None, keep only the first num_data English rows.
        maxlen: forwarded to ImageCaptionDataset.

    Returns:
        ImageCaptionDataset over the selected rows.
    """
    df = pd.read_csv(os.path.join(path, 'ml_stacked_data.csv'))
    text_df = df[df.label == 'eng'].reset_index(drop=True)
    if num_data is not None:
        # Positional slice of the first num_data rows.
        text_df = text_df.iloc[:num_data]
    return ImageCaptionDataset(text_df, tokenizer, maxlen=maxlen)

def train(ml_model, 
          epochs, 
          train_data, 
          val_data, 
          loss_fct, 
          optimizer,
          scheduler = None,
          clip=2.0):
    """Train ml_model on text+image batches, validating and checkpointing every epoch.

    Each batch must provide 'input_ids', 'image' and 'label_ids' tensors; text
    and image are both passed to the model as ml_model(text, img). After each
    epoch this function:
      - computes the mean validation loss and its perplexity,
      - writes a checkpoint via save_checkpoint,
      - prints a summary.

    Reads the notebook globals `device` (tensor placement) and `is_multimodal`,
    `train_visual_module`, `num_data`, `lr` (only to build the checkpoint file name).

    Args:
        ml_model: model returning logits whose last dimension is the class /
            vocabulary dimension.
        epochs: number of passes over train_data.
        train_data: iterable of training batch dicts (e.g. a DataLoader).
        val_data: iterable of validation batch dicts.
        loss_fct: criterion applied to (N, C) logits and (N,) targets. It is
            used for BOTH training and validation, so options like ignore_index
            take effect.
        optimizer: optimizer stepped once per training batch.
        scheduler: optional LR scheduler, stepped once per training batch.
        clip: max total gradient norm for clip_grad_norm_.

    Returns:
        (epoch_train_losses, epoch_val_losses): mean per-batch loss for each epoch.
    """
    epoch_train_losses = []
    epoch_val_losses = []

    for epoch in range(epochs):
        ml_model.train()
        epoch_train_loss, num_train_steps = 0, 0
        for i, batch in enumerate(tqdm(train_data)):
            text = batch['input_ids'].to(device)
            img = batch['image'].to(device)
            target = batch['label_ids'].to(device)

            ml_model.zero_grad()
            output = ml_model(text, img)
            # Use the passed criterion (the old code silently ignored it, so
            # ignore_index was never applied) and the model's real output
            # width instead of a hard-coded vocabulary size.
            loss = loss_fct(output.reshape(-1, output.size(-1)), target.reshape(-1))
            loss.backward()

            torch.nn.utils.clip_grad_norm_(ml_model.parameters(), clip)
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            epoch_train_loss += loss.item()
            num_train_steps += 1

            if i % 2000 == 1:
                print("Current train loss:", epoch_train_loss/num_train_steps)

        current_train_loss = epoch_train_loss/len(train_data)
        epoch_train_losses.append(current_train_loss)

        ml_model.eval()
        epoch_val_loss = 0
        with torch.no_grad():
            for batch in tqdm(val_data):
                text = batch['input_ids'].to(device)
                img = batch['image'].to(device)
                target = batch['label_ids'].to(device)
                output = ml_model(text, img)
                loss = loss_fct(output.reshape(-1, output.size(-1)), target.reshape(-1))
                epoch_val_loss += loss.item()

        current_val_loss = epoch_val_loss/len(val_data)
        epoch_val_losses.append(current_val_loss)
        # Previously `math.exp` itself was assigned here instead of being called.
        perplexity = math.exp(current_val_loss)

        print("=> Saving checkpoint...")
        if is_multimodal and train_visual_module:
            prefix = 'finetuned_visual_multimodal_lstm'
        elif is_multimodal:
            prefix = 'multimodal_lstm'
        else:
            prefix = 'benchmark_model'
        FNAME = f'{prefix}_{num_data}_{lr}_{epoch}-{current_val_loss:.2f}'
        save_checkpoint(epoch,
                        ml_model,
                        optimizer,
                        scheduler,
                        current_val_loss,
                        FNAME)

        print(f"Epoch {epoch}:\nTrain Loss: {current_train_loss}\nVal loss: {current_val_loss}\nVal perplexity: {perplexity}")

    return epoch_train_losses, epoch_val_losses

In [7]:
def train_benchmark(ml_model, 
                    epochs, 
                    train_data, 
                    val_data, 
                    loss_fct, 
                    optimizer,
                    scheduler=None,
                    clip=1.0):
    """Train a text-only model, halving the LR when validation perplexity spikes.

    The model is always called as ml_model(text, None), so batch images are
    never used. After each epoch the mean validation loss is converted to a
    perplexity. If it is at least 3x the previous epoch's perplexity, the
    learning rate of every optimizer param group is halved. Reads the notebook
    global `device`.

    Args:
        ml_model: model returning logits whose last dimension is the class /
            vocabulary dimension.
        epochs: number of passes over train_data.
        train_data: iterable of batch dicts with 'input_ids' and 'label_ids'.
        val_data: iterable of validation batch dicts with the same keys.
        loss_fct: criterion applied to (N, C) logits and (N,) targets.
        optimizer: optimizer stepped once per training batch.
        scheduler: optional LR scheduler, stepped once per training batch.
        clip: max total gradient norm for clip_grad_norm_.

    Returns:
        (epoch_train_losses, epoch_val_losses): mean per-batch loss for each epoch.
    """
    epoch_train_losses = []
    epoch_val_losses = []
    perplexities = []
    for epoch in range(epochs):
        print(optimizer.param_groups[0]["lr"])
        ml_model.train()
        epoch_train_loss, num_train_steps = 0, 0
        for i, batch in enumerate(tqdm(train_data)):
            # Images are not fed to the model, so they are not copied to the device.
            text = batch['input_ids'].to(device)
            target = batch['label_ids'].to(device)
            output = ml_model(text, None)
            ml_model.zero_grad()
            # Use the model's real output width rather than a hard-coded vocabulary size.
            loss = loss_fct(output.reshape(-1, output.size(-1)), target.reshape(-1))
            loss.backward()

            torch.nn.utils.clip_grad_norm_(ml_model.parameters(), clip)
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            epoch_train_loss += loss.item()
            num_train_steps += 1

            if i % 2000 == 1:
                print("Current train loss:", epoch_train_loss/num_train_steps)
        current_train_loss = epoch_train_loss/len(train_data)
        epoch_train_losses.append(current_train_loss)

        ml_model.eval()
        epoch_val_loss = 0
        with torch.no_grad():
            for batch in tqdm(val_data):
                text = batch['input_ids'].to(device)
                target = batch['label_ids'].to(device)
                output = ml_model(text, None)
                loss = loss_fct(output.reshape(-1, output.size(-1)), target.reshape(-1))
                epoch_val_loss += loss.item()

        current_val_loss = epoch_val_loss/len(val_data)
        epoch_val_losses.append(current_val_loss)
        perplexity = math.exp(current_val_loss)

        # Halve the LR on a perplexity spike. The old code iterated over
        # `torch.optim.param_groups`, which does not exist (AttributeError).
        if perplexities and perplexity >= 3 * perplexities[-1]:
            for group in optimizer.param_groups:
                group['lr'] = group['lr'] / 2
        perplexities.append(perplexity)

        print(f"Epoch {epoch}:\nTrain Loss: {current_train_loss}\nVal loss: {current_val_loss}")

    return epoch_train_losses, epoch_val_losses

In [8]:
# Target id excluded from the loss via ignore_index. This is presumably
# MultiBPE's padding id — TODO confirm against MultiBPE.
PAD_TOKEN_ID = 320000

multi_bpe = MultiBPE()
all_data = get_dataset(tokenizer=multi_bpe, num_data=num_data, maxlen=maxlen)

# Split: train_pct of the rows go to training; the held-out remainder is split
# in half into validation and test. random_split draws from the global torch
# RNG seeded in the setup cell, so the split is reproducible on Restart & Run All.
train_size = int(train_pct * len(all_data))
holdout_size = len(all_data) - train_size
train_dataset, val_test_dataset = torch.utils.data.random_split(all_data, [train_size, holdout_size])

val_size = int(0.5 * len(val_test_dataset))
test_size = len(val_test_dataset) - val_size
val_dataset, test_dataset = torch.utils.data.random_split(val_test_dataset, [val_size, test_size])


train_data = DataLoader(train_dataset, batch_size=bs, shuffle=True)
val_data = DataLoader(val_dataset, batch_size=bs, shuffle=False)
test_data = DataLoader(test_dataset, batch_size=bs, shuffle=False)

# Benchmark model: constructed with both multimodal flags turned off.
benchmark_model = MMLSTM(is_multimodal=False,
                         train_visual_module=False).to(device)
# visual_multimodal_lstm = MMLSTM(is_multimodal=is_multimodal, 
#                                       train_visual_module=train_visual_module).to(device)

optimizer = torch.optim.SGD(benchmark_model.parameters(), lr=lr)
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN_ID, reduction='mean')

# A warmup schedule (e.g. transformers.get_linear_schedule_with_warmup) can be
# passed via `scheduler=`. This run uses a constant learning rate.
train_benchmark(benchmark_model, 
                epochs, 
                train_data, 
                val_data, 
                loss_fct, 
                optimizer,
                scheduler=None)

0.7071067811865475


  0%|          | 2/10000 [00:00<59:54,  2.78it/s]  

Current train loss: 12.65517807006836


 20%|██        | 2002/10000 [06:48<27:09,  4.91it/s]

Current train loss: 5.486799412316733


 40%|████      | 4002/10000 [13:38<20:20,  4.91it/s]

Current train loss: 4.901512066046635


 60%|██████    | 6002/10000 [20:26<13:34,  4.91it/s]

Current train loss: 4.583584970770102


 80%|████████  | 8002/10000 [27:15<06:34,  5.07it/s]

Current train loss: 4.370415696857036


100%|██████████| 10000/10000 [34:03<00:00,  4.89it/s]
100%|██████████| 1250/1250 [03:10<00:00,  6.55it/s]


Epoch 0:
Train Loss: 4.215248198866844
Val loss: 3.452218783760071
0.7071067811865475


  0%|          | 2/10000 [00:00<35:20,  4.71it/s]

Current train loss: 3.1885465383529663


 20%|██        | 2002/10000 [06:47<27:41,  4.81it/s]

Current train loss: 3.4585332109497027


 40%|████      | 4002/10000 [13:34<20:49,  4.80it/s]

Current train loss: 3.4194504866773996


 60%|██████    | 6003/10000 [20:22<13:24,  4.97it/s]

Current train loss: 3.3897459021571157


 80%|████████  | 8002/10000 [27:09<07:03,  4.71it/s]

Current train loss: 3.3596343746545223


100%|██████████| 10000/10000 [33:58<00:00,  4.90it/s]
100%|██████████| 1250/1250 [03:09<00:00,  6.58it/s]


Epoch 1:
Train Loss: 3.3325112164258957
Val loss: 3.1107340593338013
0.7071067811865475


  0%|          | 2/10000 [00:00<33:32,  4.97it/s]

Current train loss: 3.3606032133102417


 20%|██        | 2002/10000 [06:47<27:05,  4.92it/s]

Current train loss: 3.1549121821438755


 40%|████      | 4002/10000 [13:35<21:02,  4.75it/s]

Current train loss: 3.144089566297021


 60%|██████    | 6003/10000 [20:22<13:11,  5.05it/s]

Current train loss: 3.1318892823263473


 80%|████████  | 8002/10000 [27:10<06:50,  4.87it/s]

Current train loss: 3.116915875510912


100%|██████████| 10000/10000 [33:57<00:00,  4.91it/s]
100%|██████████| 1250/1250 [03:09<00:00,  6.58it/s]


Epoch 2:
Train Loss: 3.105567686152458
Val loss: 2.9717999406814575
0.7071067811865475


  0%|          | 2/10000 [00:00<34:23,  4.85it/s]

Current train loss: 3.0287365913391113


 20%|██        | 2003/10000 [06:46<26:26,  5.04it/s]

Current train loss: 3.008124230386732


 40%|████      | 4002/10000 [13:33<21:08,  4.73it/s]

Current train loss: 3.005432755574174


 60%|██████    | 6002/10000 [20:22<13:48,  4.83it/s]

Current train loss: 2.995436712409289


 80%|████████  | 8003/10000 [27:10<06:41,  4.98it/s]

Current train loss: 2.987162598131061


100%|██████████| 10000/10000 [33:56<00:00,  4.91it/s]
100%|██████████| 1250/1250 [03:10<00:00,  6.57it/s]


Epoch 3:
Train Loss: 2.980543468928337
Val loss: 2.886949514579773
0.7071067811865475


  0%|          | 3/10000 [00:00<33:11,  5.02it/s]

Current train loss: 2.933606505393982


 20%|██        | 2003/10000 [06:49<26:45,  4.98it/s]

Current train loss: 2.906056476520611


 40%|████      | 4003/10000 [13:36<19:41,  5.07it/s]

Current train loss: 2.9060316326497855


 60%|██████    | 6002/10000 [20:24<13:27,  4.95it/s]

Current train loss: 2.903435939830448


 80%|████████  | 8002/10000 [27:11<06:53,  4.83it/s]

Current train loss: 2.9002465695269373


100%|██████████| 10000/10000 [34:00<00:00,  4.90it/s]
100%|██████████| 1250/1250 [03:10<00:00,  6.55it/s]


Epoch 4:
Train Loss: 2.8956520588874817
Val loss: 2.831635028266907
0.7071067811865475


  0%|          | 2/10000 [00:00<34:42,  4.80it/s]

Current train loss: 2.6853089332580566


 20%|██        | 2002/10000 [06:47<27:25,  4.86it/s]

Current train loss: 2.8350156831455515


 40%|████      | 4002/10000 [13:36<20:38,  4.84it/s]

Current train loss: 2.8361195475384333


 60%|██████    | 6003/10000 [20:26<13:23,  4.98it/s]

Current train loss: 2.8391417348357053


 80%|████████  | 8002/10000 [27:16<06:48,  4.89it/s]

Current train loss: 2.8388014084039646


100%|██████████| 10000/10000 [34:05<00:00,  4.89it/s]
100%|██████████| 1250/1250 [03:10<00:00,  6.55it/s]


Epoch 5:
Train Loss: 2.833346119689941
Val loss: 2.793127955055237
0.7071067811865475


  0%|          | 2/10000 [00:00<34:21,  4.85it/s]

Current train loss: 2.8515820503234863


 20%|██        | 2002/10000 [06:51<27:07,  4.91it/s]

Current train loss: 2.7832986507025157


 40%|████      | 4002/10000 [13:40<20:26,  4.89it/s]

Current train loss: 2.7858049416172688


 60%|██████    | 6002/10000 [20:28<13:56,  4.78it/s]

Current train loss: 2.786292972345425


 80%|████████  | 8003/10000 [27:16<06:36,  5.04it/s]

Current train loss: 2.7846410672267177


100%|██████████| 10000/10000 [34:09<00:00,  4.88it/s]
100%|██████████| 1250/1250 [03:11<00:00,  6.52it/s]


Epoch 6:
Train Loss: 2.782998184108734
Val loss: 2.757483325958252
0.7071067811865475


  0%|          | 3/10000 [00:00<33:53,  4.92it/s]

Current train loss: 2.4028910398483276


 20%|██        | 2002/10000 [06:50<28:04,  4.75it/s]

Current train loss: 2.744975947476291


 40%|████      | 4002/10000 [13:37<20:40,  4.84it/s]

Current train loss: 2.744411613689787


 60%|██████    | 6003/10000 [20:26<13:22,  4.98it/s]

Current train loss: 2.7454949406455733


 80%|████████  | 8003/10000 [27:16<06:33,  5.08it/s]

Current train loss: 2.7441470695268926


100%|██████████| 10000/10000 [34:03<00:00,  4.89it/s]
100%|██████████| 1250/1250 [03:10<00:00,  6.55it/s]


Epoch 7:
Train Loss: 2.7426236793637275
Val loss: 2.732710922527313
0.7071067811865475


  0%|          | 2/10000 [00:00<33:59,  4.90it/s]

Current train loss: 2.831613063812256


 20%|██        | 2002/10000 [06:50<27:23,  4.87it/s]

Current train loss: 2.701992702055406


 40%|████      | 4002/10000 [13:37<20:02,  4.99it/s]

Current train loss: 2.703667904245204


 60%|██████    | 6002/10000 [20:25<14:03,  4.74it/s]

Current train loss: 2.7065084498550687


 80%|████████  | 8002/10000 [27:12<06:51,  4.86it/s]

Current train loss: 2.707014627171826


100%|██████████| 10000/10000 [33:59<00:00,  4.90it/s]
100%|██████████| 1250/1250 [03:10<00:00,  6.56it/s]


Epoch 8:
Train Loss: 2.7070716718554495
Val loss: 2.7063455171585082
0.7071067811865475


  0%|          | 2/10000 [00:00<34:46,  4.79it/s]

Current train loss: 2.577356457710266


 20%|██        | 2003/10000 [06:47<26:14,  5.08it/s]

Current train loss: 2.677229053728826


 40%|████      | 4002/10000 [13:36<20:27,  4.89it/s]

Current train loss: 2.672297399083356


 60%|██████    | 6002/10000 [20:25<13:37,  4.89it/s]

Current train loss: 2.6705483693116827


 80%|████████  | 8002/10000 [27:12<06:51,  4.86it/s]

Current train loss: 2.6733637565197097


100%|██████████| 10000/10000 [34:00<00:00,  4.90it/s]
100%|██████████| 1250/1250 [03:10<00:00,  6.57it/s]


Epoch 9:
Train Loss: 2.6761163468003275
Val loss: 2.691673587036133
0.7071067811865475


  0%|          | 3/10000 [00:00<34:23,  4.84it/s]

Current train loss: 2.4280706644058228


 20%|██        | 2003/10000 [06:46<25:45,  5.17it/s]

Current train loss: 2.6442384535258823


 40%|████      | 4003/10000 [13:34<20:15,  4.93it/s]

Current train loss: 2.644126569074967


 60%|██████    | 6003/10000 [20:22<12:57,  5.14it/s]

Current train loss: 2.644143968234655


 80%|████████  | 8002/10000 [27:08<06:41,  4.98it/s]

Current train loss: 2.647243438631676


100%|██████████| 10000/10000 [33:55<00:00,  4.91it/s]
100%|██████████| 1250/1250 [03:10<00:00,  6.57it/s]


Epoch 10:
Train Loss: 2.6490342474341393
Val loss: 2.6720113839149473
0.7071067811865475


  0%|          | 2/10000 [00:00<34:01,  4.90it/s]

Current train loss: 2.6188780069351196


 20%|██        | 2003/10000 [06:48<27:00,  4.93it/s]

Current train loss: 2.6172939818817658


 40%|████      | 4002/10000 [13:36<20:50,  4.80it/s]

Current train loss: 2.620273868540774


 60%|██████    | 6003/10000 [20:26<13:22,  4.98it/s]

Current train loss: 2.621713917480235


 80%|████████  | 8002/10000 [27:15<06:48,  4.89it/s]

Current train loss: 2.6238167699472035


100%|██████████| 10000/10000 [34:02<00:00,  4.90it/s]
100%|██████████| 1250/1250 [03:09<00:00,  6.58it/s]


Epoch 11:
Train Loss: 2.624616628873348
Val loss: 2.662219844532013
0.7071067811865475


  0%|          | 2/10000 [00:00<34:06,  4.89it/s]

Current train loss: 2.875465512275696


 20%|██        | 2003/10000 [06:49<26:27,  5.04it/s]

Current train loss: 2.5950734136583327


 40%|████      | 4002/10000 [13:36<20:19,  4.92it/s]

Current train loss: 2.598049786673493


 60%|██████    | 6003/10000 [20:22<13:27,  4.95it/s]

Current train loss: 2.600277048752571


 80%|████████  | 8002/10000 [27:09<06:53,  4.83it/s]

Current train loss: 2.6021970332607394


100%|██████████| 10000/10000 [33:56<00:00,  4.91it/s]
100%|██████████| 1250/1250 [03:10<00:00,  6.57it/s]


Epoch 12:
Train Loss: 2.602926277446747
Val loss: 2.6498695282936096
0.7071067811865475


  0%|          | 2/10000 [00:00<34:49,  4.79it/s]

Current train loss: 2.5155704021453857


 20%|██        | 2003/10000 [06:48<26:16,  5.07it/s]

Current train loss: 2.5740064970858687


 40%|████      | 4002/10000 [13:33<20:34,  4.86it/s]

Current train loss: 2.5809164546180643


 60%|██████    | 6003/10000 [20:22<13:07,  5.08it/s]

Current train loss: 2.5809111460293583


 80%|████████  | 8002/10000 [27:10<07:00,  4.75it/s]

Current train loss: 2.58119693995118


100%|██████████| 10000/10000 [33:58<00:00,  4.91it/s]
100%|██████████| 1250/1250 [03:10<00:00,  6.57it/s]


Epoch 13:
Train Loss: 2.581852628874779
Val loss: 2.6409975255966187
0.7071067811865475


  0%|          | 2/10000 [00:00<35:24,  4.71it/s]

Current train loss: 2.5386228561401367


 20%|██        | 2002/10000 [06:48<26:54,  4.95it/s]

Current train loss: 2.548977495371164


 40%|████      | 4003/10000 [13:35<19:58,  5.01it/s]

Current train loss: 2.557039044428801


 60%|██████    | 6003/10000 [20:22<13:30,  4.93it/s]

Current train loss: 2.5597556584559054


 80%|████████  | 8003/10000 [27:16<06:45,  4.93it/s]

Current train loss: 2.562566021924256


100%|██████████| 10000/10000 [34:08<00:00,  4.88it/s]
100%|██████████| 1250/1250 [03:10<00:00,  6.58it/s]

Epoch 14:
Train Loss: 2.5635516109228136
Val loss: 2.632882562351227





([4.215248198866844,
  3.3325112164258957,
  3.105567686152458,
  2.980543468928337,
  2.8956520588874817,
  2.833346119689941,
  2.782998184108734,
  2.7426236793637275,
  2.7070716718554495,
  2.6761163468003275,
  2.6490342474341393,
  2.624616628873348,
  2.602926277446747,
  2.581852628874779,
  2.5635516109228136],
 [3.452218783760071,
  3.1107340593338013,
  2.9717999406814575,
  2.886949514579773,
  2.831635028266907,
  2.793127955055237,
  2.757483325958252,
  2.732710922527313,
  2.7063455171585082,
  2.691673587036133,
  2.6720113839149473,
  2.662219844532013,
  2.6498695282936096,
  2.6409975255966187,
  2.632882562351227])

In [9]:
([4.612036231708527, 3.849266048717499, 3.684289860057831, 3.5965754370212557],
 [3.955835499954224,
  3.7210617919921876,
  3.623933087158203,
  3.5507487411499024])

([4.612036231708527, 3.849266048717499, 3.684289860057831, 3.5965754370212557],
 [3.955835499954224,
  3.7210617919921876,
  3.623933087158203,
  3.5507487411499024])

In [10]:
def test(ml_model, 
          epochs, 
          val_data, 
          loss_fct, 
          optimizer,
          scheduler = None,
          clip=2.0):
    """Evaluate ml_model as a text-only model and report mean loss and perplexity.

    Only ml_model, val_data and loss_fct are used. epochs, optimizer, scheduler
    and clip are accepted but ignored, so the call signature mirrors the
    training functions. The model is called as ml_model(text, None). Reads the
    notebook global `device`.

    Args:
        ml_model: model returning logits whose last dimension is the class /
            vocabulary dimension.
        val_data: iterable of batch dicts with 'input_ids' and 'label_ids'
            (any split, e.g. the test loader).
        loss_fct: criterion applied to (N, C) logits and (N,) targets.

    Returns:
        (mean per-batch loss, exp(mean loss)).
    """
    ml_model.eval()
    epoch_val_loss = 0
    with torch.no_grad():
        for batch in tqdm(val_data):
            # Images are not fed to the model, so they are not copied to the device.
            text = batch['input_ids'].to(device)
            target = batch['label_ids'].to(device)
            output = ml_model(text, None)
            loss = loss_fct(output.reshape(-1, output.size(-1)), target.reshape(-1))
            epoch_val_loss += loss.item()

    current_val_loss = epoch_val_loss/len(val_data)
    perplexity = math.exp(current_val_loss)

    print(f"Val loss: {current_val_loss}")
    print(f"Val perp: {perplexity}")

    return current_val_loss, perplexity

In [11]:
# Evaluate the trained benchmark on the held-out test split.
test(ml_model=benchmark_model,
     epochs=epochs,
     val_data=test_data,
     loss_fct=loss_fct,
     optimizer=optimizer,
     scheduler=None)

100%|██████████| 1250/1250 [03:10<00:00,  6.55it/s]

Val loss: 2.628850291252136
Val perp: 13.857828273378601





(2.628850291252136, 13.857828273378601)

In [12]:
# Persist the final weights under ./saved_models/.
# NOTE(review): the multimodal branches assume a model named
# `visual_multimodal_lstm` was constructed earlier — TODO confirm.
# (The old else-branch referenced the undefined `train_visual_multimodal_lstm`.)
os.makedirs('./saved_models', exist_ok=True)
if is_multimodal:
    model_to_save = visual_multimodal_lstm
    if train_visual_module:
        save_prefix, save_version = 'finetuned_visual_multimodal_lstm', 'v3'
    else:
        save_prefix, save_version = 'multimodal_lstm', 'v3'
else:
    model_to_save = benchmark_model
    save_prefix, save_version = 'benchmark_model', 'v5'
torch.save(model_to_save.state_dict(),
           f'./saved_models/{save_prefix}_{num_data}_{lr}_{epochs}{save_version}')

In [13]:
([4.528013765186734,
  3.832729446686639,
  3.6829322583516437,
  3.6024026798884075,
  3.5487994864993624,
  3.5103443424860634,
  3.479441665818956,
  3.454497324244181,
  3.4344203785790337,
  3.416726737170749,
  3.4017402518802218,
  3.387020494418674,
  3.3755083185831705,
  3.364240313635932,
  3.354363308567471],
 [3.8837243381500244,
  3.669478569030762,
  3.588050753593445,
  3.522890328216553,
  3.492208413696289,
  3.4554949699401853,
  3.435982433128357,
  3.4159389377593996,
  3.4012541484832766,
  3.3922290019989014,
  3.381755703544617,
  3.3686582359313966,
  3.364826725578308,
  3.3517054290771484,
  3.344561079216003])

([4.528013765186734,
  3.832729446686639,
  3.6829322583516437,
  3.6024026798884075,
  3.5487994864993624,
  3.5103443424860634,
  3.479441665818956,
  3.454497324244181,
  3.4344203785790337,
  3.416726737170749,
  3.4017402518802218,
  3.387020494418674,
  3.3755083185831705,
  3.364240313635932,
  3.354363308567471],
 [3.8837243381500244,
  3.669478569030762,
  3.588050753593445,
  3.522890328216553,
  3.492208413696289,
  3.4554949699401853,
  3.435982433128357,
  3.4159389377593996,
  3.4012541484832766,
  3.3922290019989014,
  3.381755703544617,
  3.3686582359313966,
  3.364826725578308,
  3.3517054290771484,
  3.344561079216003])

In [14]:
# NOTE(review): hard-coded list of 15 per-epoch loss values (matching epochs = 15).
# These are presumably validation losses copied from an earlier run, but the run
# they came from is not recorded here — TODO record provenance.
# A later cell prints math.exp of each value, i.e. the corresponding perplexity.
a = [3.8373831869125365,
  3.5927175075531004,
  3.5031064222335817,
  3.443180693435669,
  3.4116617540359497,
  3.387208477973938,
  3.362659523963928,
  3.3494730602264404,
  3.3324638856887816,
  3.3229802640914916,
  3.3158617279052733,
  3.3066658948898318,
  3.2975340688705446,
  3.287569703865051,
  3.2815385608673098]

In [15]:
# Quick sanity check: perplexity corresponding to a loss of 3.
# (math is already imported in the first cell, so the re-import was removed.)
math.exp(3)

20.085536923187668

In [16]:
# Print "(epoch, perplexity)" pairs, 1-indexed, where perplexity = exp(loss).
for epoch_number, loss_value in enumerate(a, start=1):
    print(f"({epoch_number}, {math.exp(loss_value)})")

(1, 46.403885127011236)
(2, 36.332676214320756)
(3, 33.21848248023084)
(4, 31.286312245101627)
(5, 30.31557946197708)
(6, 29.583254591081086)
(7, 28.865858326555315)
(8, 28.48771837617349)
(9, 28.00726345632081)
(10, 27.742908669424494)
(11, 27.54612102114138)
(12, 27.293972626149017)
(13, 27.045863386113886)
(14, 26.777706755456222)
(15, 26.61669261458621)
