In [1]:
import torch
import pandas as pd
import numpy as np
import random
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import Parameter
from torch import Tensor
import torch.nn.functional as F
import matplotlib.image as img
import matplotlib.pyplot as plt
import torchvision
import PIL
import math
import datetime
import os
from bpemb import BPEmb
from torch import nn
from image_caption_dataset import ImageCaptionDataset
from multi_bpe import MultiBPE
from multimodal_model import MMLSTM, BenchmarkLSTM
from torch.utils.data import DataLoader
from tqdm import tqdm
import transformers

In [2]:
# --- Runtime / data location -------------------------------------------------
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # use the GPU when one is visible
LOAD_PATH = './data/'  # prefix that get_dataset() prepends to the CSV file name

# --- Reproducibility ---------------------------------------------------------
seed = 420

# --- Dataset sizing ----------------------------------------------------------
num_data = 200000  # row cap handed to get_dataset()
maxlen = 32        # forwarded to ImageCaptionDataset as `maxlen`
train_pct = 0.8    # share of the dataset used for training

# --- Optimisation ------------------------------------------------------------
epochs = 15
bs = 32
lr = 1.0 / math.sqrt(32 / bs)  # equals 1.0 at bs == 32 and scales with sqrt(bs / 32)

# --- Model variant switches (also select checkpoint / weight file names) ------
is_multimodal = False
train_visual_module = False

In [3]:
# Display which compute device was selected by the configuration cell.
device

'cuda'

In [4]:
# Seed the Python, NumPy and PyTorch random number generators from the shared `seed`.
random.seed(seed)        # Python's built-in `random` module
np.random.seed(seed)     # NumPy's global generator
torch.manual_seed(seed)  # PyTorch

# The explicit CUDA seeding and cuDNN determinism switches below are left disabled.
# torch.cuda.manual_seed(seed)
# torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True

In [5]:
def save_checkpoint(epoch, 
                    model, 
                    optimizer, 
                    scheduler,
                    loss,
                    FNAME):
    """Write a training checkpoint to ``./checkpoints/checkpoint-MM-DD-YYYY/FNAME``.

    The dated folder (today's date) is created on demand, relative to the
    current working directory.

    Args:
        epoch: Value stored under the ``'epoch'`` key.
        model: Object whose ``state_dict()`` is stored as ``'model_state_dict'``.
        optimizer: Object whose ``state_dict()`` is stored as
            ``'optimizer_state_dict'``.
        scheduler: Optional. When ``None``, ``'scheduler_state_dict'`` is
            stored as ``None``; otherwise its ``state_dict()`` is stored.
        loss: Value stored unchanged under the ``'loss'`` key.
        FNAME: Name of the checkpoint file inside the dated folder.
    """
    today = datetime.date.today()
    PATH = f'./checkpoints/{today.strftime("checkpoint-%m-%d-%Y")}/'
    # exist_ok=True replaces the separate "exists?" check, which could race with
    # another process creating the same folder between the check and the mkdir.
    os.makedirs(PATH, exist_ok=True)

    # Identity comparison: a scheduler object must never be asked to compare
    # itself against None via a custom __eq__/__ne__.
    scheduler_state = scheduler.state_dict() if scheduler is not None else None

    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler_state,
            'loss': loss,
            }, f'{PATH}{FNAME}')

In [6]:
def load_latest_checkpoint(PATH, model, optimizer, scheduler):
    """Restore model, optimizer and (optionally) scheduler state from a checkpoint.

    The file is expected to hold the dictionary layout written by
    ``save_checkpoint``: ``'epoch'``, ``'model_state_dict'``,
    ``'optimizer_state_dict'``, ``'scheduler_state_dict'`` and ``'loss'``.

    Args:
        PATH: Path of the checkpoint file to read.
        model: Receives ``checkpoint['model_state_dict']`` via ``load_state_dict``.
        optimizer: Receives ``checkpoint['optimizer_state_dict']``.
        scheduler: Optional. Skipped when it is ``None`` or when the checkpoint
            stored no scheduler state (``save_checkpoint`` writes ``None`` when
            training ran without a scheduler).

    Returns:
        ``(epoch, loss)`` exactly as stored in the checkpoint.
    """
    # Deserialise onto the CPU so a checkpoint written on a GPU machine can be
    # opened anywhere; load_state_dict copies values into the existing
    # parameters on whatever device they already live on.
    # NOTE(review): torch.load unpickles data - only open trusted checkpoint files.
    checkpoint = torch.load(PATH, map_location='cpu')

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    scheduler_state = checkpoint.get('scheduler_state_dict')
    if scheduler is not None and scheduler_state is not None:
        scheduler.load_state_dict(scheduler_state)

    return checkpoint['epoch'], checkpoint['loss']

In [7]:
def _init_fn(worker_id):
    """Reset NumPy's global RNG to the notebook-wide ``seed``.

    ``worker_id`` is accepted but not used, so every call seeds NumPy with the
    same value.
    NOTE(review): the signature looks like a DataLoader ``worker_init_fn``, but
    no visible cell passes this function to a DataLoader - confirm intended use.
    """
    fixed_seed = int(seed)
    np.random.seed(fixed_seed)
    

def get_dataset(tokenizer, path=LOAD_PATH, num_data = None, maxlen=64):
    """Build an ``ImageCaptionDataset`` from ``ml_stacked_data.csv``.

    Args:
        tokenizer: Passed straight through to ``ImageCaptionDataset``.
        path: String prefix concatenated with ``'ml_stacked_data.csv'``; it is
            joined by plain concatenation, so it must end with a path separator.
        num_data: When not ``None``, only the first ``num_data`` rows are used.
        maxlen: Forwarded to ``ImageCaptionDataset`` as ``maxlen``.

    Returns:
        The constructed ``ImageCaptionDataset``.
    """
    frame = pd.read_csv(path + 'ml_stacked_data.csv')
    # Keep the full frame unless an explicit row cap was requested.
    rows = frame if num_data is None else frame[:num_data]
    return ImageCaptionDataset(rows, tokenizer, maxlen=maxlen)

def train(ml_model, 
          epochs, 
          train_data, 
          val_data, 
          loss_fct, 
          optimizer,
          scheduler = None,
          clip=2.0):
    """Train a model that consumes text and image inputs, checkpointing each epoch.

    Reads the module-level ``device``, ``is_multimodal``, ``train_visual_module``,
    ``num_data`` and ``lr`` (the last four only to build the checkpoint name).

    Args:
        ml_model: Module invoked as ``ml_model(text, img)``.
        epochs: Number of passes over ``train_data``.
        train_data: Sized iterable of batches; each batch is indexed with
            ``'input_ids'``, ``'image'`` and ``'label_ids'``.
        val_data: Same batch format; evaluated without gradients every epoch.
        loss_fct: Callable ``(logits, targets) -> loss``. Passing ``None`` falls
            back to ``torch.nn.functional.cross_entropy`` with default settings.
        optimizer: Stepped once per training batch.
        scheduler: Optional; stepped once per training batch when provided.
        clip: Maximum gradient norm handed to ``clip_grad_norm_``.

    Returns:
        ``(epoch_train_losses, epoch_val_losses)``: the per-epoch mean of the
        batch losses for training and validation.
    """
    # The supplied criterion used to be silently ignored in favour of plain
    # cross-entropy, which dropped options such as ``ignore_index``. Honour it,
    # keeping plain cross-entropy only as the fallback for ``None``.
    if loss_fct is None:
        loss_fct = torch.nn.functional.cross_entropy

    epoch_train_losses = []
    epoch_val_losses = []

    for epoch in range(epochs):
        ml_model.train()
        epoch_train_loss, num_train_steps = 0, 0
        for i, batch in enumerate(tqdm(train_data)):
            text, img, target = batch['input_ids'].to(device), batch['image'].to(device), batch['label_ids'].to(device)
            output = ml_model(text, img)
            ml_model.zero_grad()
            # Flatten to (tokens, vocab) using the model's own output width
            # instead of a hard-coded vocabulary size.
            loss = loss_fct(output.view(-1, output.size(-1)), target.view(-1))
            loss.backward()

            torch.nn.utils.clip_grad_norm_(ml_model.parameters(), clip)
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            epoch_train_loss += loss.item()
            num_train_steps += 1

            if i % 2000 == 1:
                print("Current train loss:", epoch_train_loss/num_train_steps)

        current_train_loss = epoch_train_loss/len(train_data)
        epoch_train_losses.append(current_train_loss)

        ml_model.eval()
        epoch_val_loss = 0
        with torch.no_grad():
            for batch in tqdm(val_data):
                text, img, target = batch['input_ids'].to(device), batch['image'].to(device), batch['label_ids'].to(device)
                output = ml_model(text, img)
                loss = loss_fct(output.view(-1, output.size(-1)), target.view(-1))
                epoch_val_loss += loss.item()

        current_val_loss = epoch_val_loss/len(val_data)
        epoch_val_losses.append(current_val_loss)

        # Previously ``perplexity = math.exp`` bound the function itself.
        # math.exp raises OverflowError for very large losses; report inf instead.
        try:
            perplexity = math.exp(current_val_loss)
        except OverflowError:
            perplexity = float('inf')

        print("=> Saving checkpoint...")
        # The three variants differ only in the file-name prefix.
        if is_multimodal:
            if train_visual_module:
                prefix = 'finetuned_visual_multimodal_lstm'
            else:
                prefix = 'multimodal_lstm'
        else:
            prefix = 'benchmark_model'
        FNAME = f'{prefix}_{num_data}_{lr}_{epoch}-{current_val_loss:.2f}'
        save_checkpoint(epoch, 
                        ml_model, 
                        optimizer, 
                        scheduler,
                        current_val_loss,
                        FNAME)

        print(f"Epoch {epoch}:\nTrain Loss: {current_train_loss}\nVal loss: {current_val_loss}")
        print(f"Val perplexity: {perplexity}")

    return epoch_train_losses, epoch_val_losses

In [8]:
def train_benchmark(ml_model, 
                    epochs, 
                    train_data, 
                    val_data, 
                    loss_fct, 
                    optimizer,
                    scheduler=None,
                    clip=1.0):
    """Train a text-only model and halve the learning rate on perplexity spikes.

    Reads the module-level ``device``. No checkpoints are written here; the
    per-epoch checkpointing lives in ``train``.

    Args:
        ml_model: Module invoked as ``ml_model(text)``.
        epochs: Number of passes over ``train_data``.
        train_data: Sized iterable of batches; each batch is indexed with
            ``'input_ids'`` and ``'label_ids'``.
        val_data: Same batch format; evaluated without gradients every epoch.
        loss_fct: Callable ``(logits, targets) -> loss``.
        optimizer: Stepped once per training batch; its ``param_groups`` learning
            rates are halved when validation perplexity is at least three times
            the previous epoch's value.
        scheduler: Optional; stepped once per training batch when provided.
        clip: Maximum gradient norm handed to ``clip_grad_norm_``.

    Returns:
        ``(epoch_train_losses, epoch_val_losses)``: the per-epoch mean of the
        batch losses for training and validation.
    """
    epoch_train_losses = []
    epoch_val_losses = []
    perplexities = []
    for epoch in range(epochs):
        # Show the learning rate in effect for this epoch.
        print(optimizer.param_groups[0]["lr"])
        ml_model.train()
        epoch_train_loss, num_train_steps = 0, 0
        for i, batch in enumerate(tqdm(train_data)):
            # The model takes text only, so the image tensor is not copied to
            # the device.
            text, target = batch['input_ids'].to(device), batch['label_ids'].to(device)
            output = ml_model(text)
            ml_model.zero_grad()
            # Flatten to (tokens, vocab) using the model's own output width
            # instead of a hard-coded vocabulary size.
            loss = loss_fct(output.view(-1, output.size(-1)), target.view(-1))
            loss.backward()

            torch.nn.utils.clip_grad_norm_(ml_model.parameters(), clip)
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            epoch_train_loss += loss.item()
            num_train_steps += 1

            if i % 2000 == 1:
                print("Current train loss:", epoch_train_loss/num_train_steps)
        current_train_loss = epoch_train_loss/len(train_data)
        epoch_train_losses.append(current_train_loss)

        ml_model.eval()
        epoch_val_loss = 0
        with torch.no_grad():
            for batch in tqdm(val_data):
                text, target = batch['input_ids'].to(device), batch['label_ids'].to(device)
                output = ml_model(text)
                loss = loss_fct(output.view(-1, output.size(-1)), target.view(-1))
                epoch_val_loss += loss.item()

        current_val_loss = epoch_val_loss/len(val_data)
        epoch_val_losses.append(current_val_loss)

        # math.exp raises OverflowError for very large losses; treat as inf so a
        # diverged epoch triggers the LR cut instead of aborting the run.
        try:
            perplexity = math.exp(current_val_loss)
        except OverflowError:
            perplexity = float('inf')

        # Halve every group's learning rate when perplexity jumped to at least
        # 3x the previous epoch. The groups live on the optimizer instance;
        # ``torch.optim`` (the module) has no ``param_groups`` attribute.
        if perplexities and perplexity >= (3 * perplexities[-1]):
            for g in optimizer.param_groups:
                g['lr'] = g['lr']/2
        perplexities.append(perplexity)

        print(f"Epoch {epoch}:\nTrain Loss: {current_train_loss}\nVal loss: {current_val_loss}")

    return epoch_train_losses, epoch_val_losses

In [9]:
# Tokenizer and dataset; get_dataset() keeps only the first `num_data` CSV rows.
multi_bpe = MultiBPE()
all_data = get_dataset(tokenizer=multi_bpe, num_data=num_data,maxlen=maxlen)

# First split: `train_pct` of the data for training, the remainder held out.
train_size = int(train_pct * len(all_data))
test_size = len(all_data) - train_size
train_dataset, val_test_dataset = torch.utils.data.random_split(all_data, [train_size, test_size])

# Second split: halve the held-out part into validation and test.
# NOTE: `train_size` / `test_size` are reused here and now hold the
# validation / test sizes, not the training size computed above.
train_size = int(0.5 * len(val_test_dataset))
test_size = len(val_test_dataset) - train_size
val_dataset, test_dataset = torch.utils.data.random_split(val_test_dataset, [train_size, test_size])


# Only the training loader is shuffled.
train_data = DataLoader(train_dataset, batch_size=bs, shuffle=True)
val_data = DataLoader(val_dataset, batch_size=bs, shuffle=False)
test_data = DataLoader(test_dataset, batch_size=bs, shuffle=False)

# Text-only baseline; the multimodal model construction is disabled below.
benchmark_model = BenchmarkLSTM().to(device)
# visual_multimodal_lstm = MMLSTM(is_multimodal=is_multimodal, 
#                                       train_visual_module=train_visual_module).to(device)

# Plain SGD; the criterion skips targets equal to 320000
# (presumably the padding id - TODO confirm against the tokenizer).
optimizer = torch.optim.SGD(benchmark_model.parameters(), lr=lr)
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=320000, reduction='mean')

# Warm-up schedulers are disabled; training runs with scheduler=None.
# num_warmup_steps = (len(train_dataset) // bs)
# num_training_steps = (len(train_dataset) // bs) * epochs
# scheduler = transformers.get_linear_schedule_with_warmup(optimizer, 
#                                                          num_warmup_steps=num_warmup_steps, 
#                                                          num_training_steps=num_training_steps)
# scheduler = transformers.get_constant_schedule_with_warmup(optimizer, 
#                                                          num_warmup_steps=num_warmup_steps) 
# Last expression of the cell: the returned (train_losses, val_losses) tuple
# is what the notebook displays as the cell output.
train_benchmark(benchmark_model, 
                epochs, 
                train_data, 
                val_data, 
                loss_fct, 
                optimizer,
                scheduler=None)

1.0


  0%|          | 2/5000 [00:00<21:59,  3.79it/s]

Current train loss: 12.646284580230713


 40%|████      | 2002/5000 [08:28<12:25,  4.02it/s]

Current train loss: 5.077624455198541


 80%|████████  | 4002/5000 [16:46<04:08,  4.02it/s]

Current train loss: 4.480915155367873


100%|██████████| 5000/5000 [20:54<00:00,  3.99it/s]
100%|██████████| 625/625 [02:04<00:00,  5.03it/s]


Epoch 0:
Train Loss: 4.3108990769863125
Val loss: 3.4786369148254392
1.0


  0%|          | 2/5000 [00:00<21:37,  3.85it/s]

Current train loss: 3.4299252033233643


 40%|████      | 2002/5000 [08:18<12:26,  4.02it/s]

Current train loss: 3.451037641171809


 80%|████████  | 4002/5000 [16:35<04:07,  4.03it/s]

Current train loss: 3.3848061294093363


100%|██████████| 5000/5000 [20:43<00:00,  4.02it/s]
100%|██████████| 625/625 [02:03<00:00,  5.04it/s]


Epoch 1:
Train Loss: 3.3553966972351073
Val loss: 3.1218619747161864
1.0


  0%|          | 2/5000 [00:00<21:20,  3.90it/s]

Current train loss: 3.474177360534668


 40%|████      | 2002/5000 [08:16<12:33,  3.98it/s]

Current train loss: 3.157481400521247


 80%|████████  | 4002/5000 [16:33<04:08,  4.02it/s]

Current train loss: 3.1263588131933675


100%|██████████| 5000/5000 [20:41<00:00,  4.03it/s]
100%|██████████| 625/625 [02:04<00:00,  5.02it/s]


Epoch 2:
Train Loss: 3.113920765209198
Val loss: 2.979832181930542
1.0


  0%|          | 2/5000 [00:00<21:48,  3.82it/s]

Current train loss: 3.0782840251922607


 40%|████      | 2002/5000 [08:16<12:30,  4.00it/s]

Current train loss: 3.0080290212259664


 80%|████████  | 4002/5000 [16:33<04:07,  4.04it/s]

Current train loss: 2.9872314812599687


100%|██████████| 5000/5000 [20:41<00:00,  4.03it/s]
100%|██████████| 625/625 [02:04<00:00,  5.03it/s]


Epoch 3:
Train Loss: 2.9792554649829865
Val loss: 2.889918435668945
1.0


  0%|          | 2/5000 [00:00<21:31,  3.87it/s]

Current train loss: 3.002747416496277


 40%|████      | 2002/5000 [08:17<12:24,  4.03it/s]

Current train loss: 2.9031917698733456


 80%|████████  | 4002/5000 [16:35<04:09,  4.00it/s]

Current train loss: 2.894876620997077


100%|██████████| 5000/5000 [20:44<00:00,  4.02it/s]
100%|██████████| 625/625 [02:04<00:00,  5.01it/s]


Epoch 4:
Train Loss: 2.8890707806110383
Val loss: 2.805567984390259
1.0


  0%|          | 2/5000 [00:00<21:59,  3.79it/s]

Current train loss: 2.9276753664016724


 40%|████      | 2002/5000 [08:18<12:31,  3.99it/s]

Current train loss: 2.828395057153273


 80%|████████  | 4002/5000 [16:35<04:07,  4.03it/s]

Current train loss: 2.828215938875045


100%|██████████| 5000/5000 [20:44<00:00,  4.02it/s]
100%|██████████| 625/625 [02:04<00:00,  5.02it/s]


Epoch 5:
Train Loss: 2.821729228401184
Val loss: 2.766284902191162
1.0


  0%|          | 2/5000 [00:00<21:44,  3.83it/s]

Current train loss: 2.84767484664917


 40%|████      | 2002/5000 [08:18<12:23,  4.03it/s]

Current train loss: 2.775845371402584


 80%|████████  | 4002/5000 [16:36<04:08,  4.01it/s]

Current train loss: 2.7721536929222537


100%|██████████| 5000/5000 [20:46<00:00,  4.01it/s]
100%|██████████| 625/625 [02:06<00:00,  4.95it/s]


Epoch 6:
Train Loss: 2.7696691148281096
Val loss: 2.72775007019043
1.0


  0%|          | 2/5000 [00:00<21:36,  3.85it/s]

Current train loss: 2.4008957147598267


 40%|████      | 2002/5000 [08:23<12:31,  3.99it/s]

Current train loss: 2.7305252931930206


 80%|████████  | 4002/5000 [16:46<04:07,  4.03it/s]

Current train loss: 2.7282316796723634


100%|██████████| 5000/5000 [20:56<00:00,  3.98it/s]
100%|██████████| 625/625 [02:06<00:00,  4.95it/s]


Epoch 7:
Train Loss: 2.7262459456920625
Val loss: 2.6945456871032714
1.0


  0%|          | 2/5000 [00:00<22:40,  3.67it/s]

Current train loss: 2.8408461809158325


 40%|████      | 2002/5000 [08:25<12:34,  3.97it/s]

Current train loss: 2.6886003649794494


 80%|████████  | 4002/5000 [16:49<04:10,  3.99it/s]

Current train loss: 2.6904490878735703


100%|██████████| 5000/5000 [21:02<00:00,  3.96it/s]
100%|██████████| 625/625 [02:06<00:00,  4.93it/s]


Epoch 8:
Train Loss: 2.6899516554832457
Val loss: 2.669086597442627
1.0


  0%|          | 2/5000 [00:00<22:23,  3.72it/s]

Current train loss: 2.571129322052002


 40%|████      | 2002/5000 [08:24<12:36,  3.96it/s]

Current train loss: 2.655496286821889


 80%|████████  | 4002/5000 [16:47<04:10,  3.99it/s]

Current train loss: 2.6553199086887487


100%|██████████| 5000/5000 [21:03<00:00,  3.96it/s]
100%|██████████| 625/625 [02:05<00:00,  4.98it/s]


Epoch 9:
Train Loss: 2.6576039593219756
Val loss: 2.655347897338867
1.0


  0%|          | 2/5000 [00:00<21:50,  3.81it/s]

Current train loss: 2.5654048919677734


 40%|████      | 2002/5000 [08:18<12:14,  4.08it/s]

Current train loss: 2.627405057062993


 80%|████████  | 4002/5000 [16:36<04:05,  4.06it/s]

Current train loss: 2.628507357665981


100%|██████████| 5000/5000 [20:45<00:00,  4.02it/s]
100%|██████████| 625/625 [02:04<00:00,  5.01it/s]


Epoch 10:
Train Loss: 2.6295034252643585
Val loss: 2.634938250732422
1.0


  0%|          | 2/5000 [00:00<21:37,  3.85it/s]

Current train loss: 2.6325340270996094


 40%|████      | 2002/5000 [08:18<12:32,  3.98it/s]

Current train loss: 2.603545538671724


 80%|████████  | 4002/5000 [16:37<04:08,  4.02it/s]

Current train loss: 2.605269845934405


100%|██████████| 5000/5000 [20:45<00:00,  4.01it/s]
100%|██████████| 625/625 [02:04<00:00,  5.02it/s]


Epoch 11:
Train Loss: 2.6051934757709505
Val loss: 2.61940299949646
1.0


  0%|          | 2/5000 [00:00<21:22,  3.90it/s]

Current train loss: 2.7626837491989136


 40%|████      | 2002/5000 [08:22<12:21,  4.04it/s]

Current train loss: 2.579836149792095


 80%|████████  | 4002/5000 [16:43<04:09,  4.01it/s]

Current train loss: 2.5822726332027277


100%|██████████| 5000/5000 [20:53<00:00,  3.99it/s]
100%|██████████| 625/625 [02:05<00:00,  4.99it/s]


Epoch 12:
Train Loss: 2.5822831481933592
Val loss: 2.608886410522461
1.0


  0%|          | 2/5000 [00:00<21:26,  3.88it/s]

Current train loss: 2.498419165611267


 40%|████      | 2002/5000 [08:19<12:32,  3.98it/s]

Current train loss: 2.562600621691236


 80%|████████  | 4002/5000 [16:36<04:10,  3.99it/s]

Current train loss: 2.561991814253987


100%|██████████| 5000/5000 [20:45<00:00,  4.01it/s]
100%|██████████| 625/625 [02:04<00:00,  5.02it/s]


Epoch 13:
Train Loss: 2.561886626958847
Val loss: 2.5930460548400878
1.0


  0%|          | 2/5000 [00:00<21:45,  3.83it/s]

Current train loss: 2.4724626541137695


 40%|████      | 2002/5000 [08:20<12:15,  4.08it/s]

Current train loss: 2.538843631148934


 80%|████████  | 4002/5000 [16:37<04:07,  4.04it/s]

Current train loss: 2.5422238860947677


100%|██████████| 5000/5000 [20:46<00:00,  4.01it/s]
100%|██████████| 625/625 [02:04<00:00,  5.02it/s]

Epoch 14:
Train Loss: 2.5432419083595277
Val loss: 2.583121911621094





([4.3108990769863125,
  3.3553966972351073,
  3.113920765209198,
  2.9792554649829865,
  2.8890707806110383,
  2.821729228401184,
  2.7696691148281096,
  2.7262459456920625,
  2.6899516554832457,
  2.6576039593219756,
  2.6295034252643585,
  2.6051934757709505,
  2.5822831481933592,
  2.561886626958847,
  2.5432419083595277],
 [3.4786369148254392,
  3.1218619747161864,
  2.979832181930542,
  2.889918435668945,
  2.805567984390259,
  2.766284902191162,
  2.72775007019043,
  2.6945456871032714,
  2.669086597442627,
  2.655347897338867,
  2.634938250732422,
  2.61940299949646,
  2.608886410522461,
  2.5930460548400878,
  2.583121911621094])

In [10]:
# Literal pair of four-element lists, shaped like the (train_losses, val_losses)
# tuple returned by the training functions.
# NOTE(review): presumably pasted from an earlier, shorter run - origin not
# recorded in this notebook; confirm or remove.
([4.612036231708527, 3.849266048717499, 3.684289860057831, 3.5965754370212557],
 [3.955835499954224,
  3.7210617919921876,
  3.623933087158203,
  3.5507487411499024])

([4.612036231708527, 3.849266048717499, 3.684289860057831, 3.5965754370212557],
 [3.955835499954224,
  3.7210617919921876,
  3.623933087158203,
  3.5507487411499024])

In [11]:
def test(ml_model, 
          epochs, 
          val_data, 
          loss_fct, 
          optimizer,
          scheduler = None,
          clip=2.0):
    """Evaluate a text-only model on a data loader and report loss and perplexity.

    Reads the module-level ``device``.

    Args:
        ml_model: Module invoked as ``ml_model(text)``; switched to eval mode.
        epochs: Unused; kept so existing call sites keep working.
        val_data: Sized iterable of batches; each batch is indexed with
            ``'input_ids'`` and ``'label_ids'``.
        loss_fct: Callable ``(logits, targets) -> loss``.
        optimizer: Unused; kept for call-site compatibility.
        scheduler: Unused; kept for call-site compatibility.
        clip: Unused; kept for call-site compatibility.

    Returns:
        ``(mean_batch_loss, perplexity)`` where perplexity is
        ``exp(mean_batch_loss)``.
    """
    ml_model.eval()
    epoch_val_loss = 0
    with torch.no_grad():
        for batch in tqdm(val_data):
            # The model takes text only, so the image tensor is not copied to
            # the device.
            text, target = batch['input_ids'].to(device), batch['label_ids'].to(device)
            output = ml_model(text)
            loss = loss_fct(output.view(-1, output.size(-1)), target.view(-1))
            epoch_val_loss += loss.item()

    current_val_loss = epoch_val_loss/len(val_data)
    # math.exp raises OverflowError for very large losses; report inf instead.
    try:
        perplexity = math.exp(current_val_loss)
    except OverflowError:
        perplexity = float('inf')

    print(f"Val loss: {current_val_loss}")
    print(f"Val perp: {perplexity}")

    return current_val_loss, perplexity

In [12]:
test(benchmark_model, 
    epochs, 
    test_data, 
    loss_fct, 
    optimizer,
    scheduler=None)

100%|██████████| 625/625 [02:04<00:00,  5.03it/s]

Val loss: 2.5833786556243896
Val perp: 13.241802154368015





(2.5833786556243896, 13.241802154368015)

In [13]:
# Persist the final weights of whichever model variant this run trained.
# torch.save does not create missing directories, so make sure the target exists.
os.makedirs('./saved_models', exist_ok=True)
if is_multimodal:
    if train_visual_module:
        torch.save(visual_multimodal_lstm.state_dict(), 
                   f'./saved_models/finetuned_visual_multimodal_lstm_{num_data}_{lr}_{epochs}v4')
    else:
        # This branch used to reference `train_visual_multimodal_lstm`, a name
        # that is never defined; the multimodal model is `visual_multimodal_lstm`
        # in both variants.
        torch.save(visual_multimodal_lstm.state_dict(), 
                   f'./saved_models/multimodal_lstm_{num_data}_{lr}_{epochs}v4')
else:
    torch.save(benchmark_model.state_dict(), 
               f'./saved_models/benchmark_model_{num_data}_{lr}_{epochs}v7')

In [4]:
# Fifteen hard-coded float values, later converted with math.exp().
# NOTE(review): presumably per-epoch validation losses copied from another
# training run - they do not match the run shown above; confirm the source.
a =  [3.161041009902954,
  2.813968430328369,
  2.674389543533325,
  2.5785831508636474,
  2.514091403579712,
  2.465382401275635,
  2.4319754432678224,
  2.3976976949691773,
  2.38222297668457,
  2.355729240036011,
  2.331479719352722,
  2.327339119338989,
  2.323257204246521,
  2.3015333766937256,
  2.292942211151123]

In [2]:
import math

# Scratch check: e**3, i.e. the perplexity that exp(loss) yields for a loss of 3.0.
math.exp(3)

20.085536923187668

In [5]:
# Print "(position, exp(value))" for every entry of `a`, counting from 1.
for position, value in enumerate(a, start=1):
    print(f"({position}, {math.exp(value)})")

(1, 23.595145929028277)
(2, 16.67596448532435)
(3, 14.50349338529962)
(4, 13.178453045973917)
(5, 12.355377625709929)
(6, 11.76798138276216)
(7, 11.38134308505272)
(8, 10.997826858568226)
(9, 10.828948627760905)
(10, 10.54581648988851)
(11, 10.293161259721039)
(12, 10.250629510314461)
(13, 10.208872593083909)
(14, 9.989488365594339)
(15, 9.90403461660927)
