In [1]:
import torch
import pandas as pd
import numpy as np
import random
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import Parameter
from torch import Tensor
import torch.nn.functional as F
import matplotlib.image as img
import matplotlib.pyplot as plt
import torchvision
import PIL
import math
import datetime
import os
from bpemb import BPEmb
from torch import nn
from image_caption_dataset import ImageCaptionDataset
from multi_bpe import MultiBPE
from multimodal_model import MMLSTM, BenchmarkLSTM
from torch.utils.data import DataLoader
from tqdm import tqdm
import transformers
from keras_nlp.metrics import Perplexity

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
LOAD_PATH = './data/'
seed = 420
num_data = 200000
maxlen = 32
epochs = 15
train_pct = 0.8
bs=32
lr = 1.0/math.sqrt(32/bs)
is_multimodal=True
train_visual_module=False

In [3]:
torch.manual_seed(seed)
# torch.cuda.manual_seed(seed)
# torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
np.random.seed(seed)  # Numpy module.
random.seed(seed)  # Python random module.
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True

In [4]:
def save_checkpoint(epoch, 
                    model, 
                    optimizer, 
                    scheduler,
                    loss,
                    FNAME):
    
    today = datetime.date.today()
    PATH = f'./checkpoints/{today.strftime("checkpoint-%m-%d-%Y")}/'
    if not os.path.exists(PATH):
        os.makedirs(PATH)
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler !=None else None,
            'loss': loss,
            }, f'{PATH}{FNAME}')

In [5]:
df = pd.read_csv(LOAD_PATH + 'ml_stacked_data.csv')

In [6]:
l = {}
for caption in df.caption:
    length = len(caption.split())
    if length in l:
        l[length] += 1  
    else:
        l[length] = 0

In [7]:
l

{11: 31894,
 14: 9235,
 18: 1305,
 8: 25591,
 12: 22805,
 13: 14901,
 10: 39529,
 9: 37084,
 16: 3337,
 7: 3297,
 15: 5511,
 19: 831,
 6: 447,
 21: 407,
 17: 2048,
 26: 89,
 20: 597,
 23: 192,
 5: 38,
 24: 140,
 25: 100,
 36: 7,
 22: 273,
 30: 32,
 28: 44,
 27: 59,
 33: 12,
 32: 24,
 34: 13,
 38: 5,
 37: 7,
 48: 5,
 29: 25,
 39: 11,
 31: 21,
 42: 4,
 41: 5,
 47: 1,
 4: 2,
 40: 5,
 52: 2,
 35: 6,
 45: 4,
 43: 2,
 44: 2,
 49: 0,
 2: 0,
 46: 0,
 59: 0,
 3: 0,
 53: 0}

In [8]:
def load_latest_checkpoint(PATH, model, optimizer, scheduler):
    checkpoint = torch.load(PATH)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    return epoch, loss

In [9]:
def _init_fn(worker_id):
    np.random.seed(int(seed))
    
def process_img(image_path): 
    transform = torchvision.transforms.Compose([
        # Resize image to 224 x 224 as required by most vision models
        torchvision.transforms.Resize(
            size=(224, 224)
        ),
        # Convert PIL image to tensor with image values in [0, 1]
        torchvision.transforms.ToTensor(),

        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])
    
    im = PIL.Image.open(image_path)
    image = im.convert('RGB')
    image = transform(image)
    
    return image.view(1, image.size(0), image.size(1), image.size(2))

def get_dataset(tokenizer, path=LOAD_PATH, num_data = None, maxlen=64):
    df = pd.read_csv(path + 'ml_stacked_data.csv')    
    if num_data != None:
        text_df = df[:num_data]
    else:
        text_df = df
    dataset = ImageCaptionDataset(text_df, tokenizer, maxlen=maxlen)

    return dataset

def train(ml_model, 
          epochs, 
          train_data, 
          val_data, 
          loss_fct, 
          optimizer,
          scheduler = None,
          clip=2.0):
    epoch_train_losses = []
    epoch_val_losses = []
    perplexities = []
    for epoch in range(epochs):
        print(optimizer.param_groups[0]["lr"])
        ml_model.train()
        epoch_train_loss, num_train_steps = 0, 0
        for i, batch in enumerate(tqdm(train_data)):
            text, img, target = batch['input_ids'].to(device), batch['image'].to(device), batch['label_ids'].to(device)
            output = ml_model(text, img)
            ml_model.zero_grad()
            loss = loss_fct(output.view(-1, 320001), target.view(-1))
            loss.backward()
            
            torch.nn.utils.clip_grad_norm_(ml_model.parameters(), clip)
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            epoch_train_loss += loss.item()
            num_train_steps += 1
            
            if i % 2000 == 1:
                print("Current train loss:", epoch_train_loss/num_train_steps)
        current_train_loss = epoch_train_loss/len(train_data)
        epoch_train_losses.append(current_train_loss)
        
        ml_model.eval()
        epoch_val_loss = 0
        for i, batch in enumerate(tqdm(val_data)):
            text, img, target = batch['input_ids'].to(device), batch['image'].to(device), batch['label_ids'].to(device)
            
            with torch.no_grad():
                output = ml_model(text, img)
                loss = loss_fct(output.view(-1, output.size(-1)), target.view(-1))
               
            epoch_val_loss += loss.item()
            
        current_val_loss = epoch_val_loss/len(val_data)
        epoch_val_losses.append(current_val_loss)
        perplexity = math.exp(current_val_loss)
        
        if len(perplexities) == 0:
            perplexities.append(perplexity)
        else:
            if perplexity >= (3 * perplexities[-1]):
                for g in torch.optim.param_groups:
                    g['lr'] = g['lr']/2
            perplexities.append(perplexity)
            
        print(f"Epoch {epoch}:\nTrain Loss: {current_train_loss}\nVal loss: {current_val_loss}")

#         print("=> Saving checkpoint...")
#         if is_multimodal:
#             if train_visual_module:
#                 FNAME = f'finetuned_visual_multimodal_lstm_{num_data}_{lr}_{epoch}-{current_val_loss:.2f}'
#                 save_checkpoint(epoch, 
#                                 ml_model, 
#                                 optimizer, 
#                                 scheduler,
#                                 current_val_loss,
#                                 FNAME)
#             else:
#                 FNAME = f'multimodal_lstm_{num_data}_{lr}_{epoch}-{current_val_loss:.2f}'
#                 save_checkpoint(epoch, 
#                                 ml_model, 
#                                 optimizer, 
#                                 scheduler,
#                                 current_val_loss,
#                                 FNAME)

#         else:
#             FNAME = f'benchmark_model_{num_data}_{lr}_{epoch}-{current_val_loss:.2f}'
#             save_checkpoint(epoch, 
#                             ml_model, 
#                             optimizer, 
#                             scheduler,
#                             current_val_loss,
#                             FNAME)
    return epoch_train_losses, epoch_val_losses

In [10]:
multi_bpe = MultiBPE()
all_data = get_dataset(tokenizer=multi_bpe, num_data=num_data,maxlen=maxlen)

train_size = int(train_pct * len(all_data))
test_size = len(all_data) - train_size
train_dataset, val_test_dataset = torch.utils.data.random_split(all_data, [train_size, test_size])

train_size = int(0.5 * len(val_test_dataset))
test_size = len(val_test_dataset) - train_size
val_dataset, test_dataset = torch.utils.data.random_split(val_test_dataset, [train_size, test_size])


train_data = DataLoader(train_dataset, batch_size=bs, shuffle=True)
val_data = DataLoader(val_dataset, batch_size=bs, shuffle=False)
test_data = DataLoader(test_dataset, batch_size=bs, shuffle=False)

# benchmark_model = BenchmarkLSTM().to(device)
visual_multimodal_lstm = MMLSTM(is_multimodal=is_multimodal, 
                                      train_visual_module=train_visual_module).to(device)

optimizer = torch.optim.SGD(visual_multimodal_lstm.parameters(), lr=lr)
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=320000, reduction='mean')

# num_warmup_steps = (len(train_dataset) // bs)
# num_training_steps = (len(train_dataset) // bs) * epochs
# scheduler = transformers.get_linear_schedule_with_warmup(optimizer, 
#                                                          num_warmup_steps=num_warmup_steps, 
#                                                          num_training_steps=num_training_steps)
# scheduler = transformers.get_constant_schedule_with_warmup(optimizer, 
#                                                          num_warmup_steps=num_warmup_steps) 
train(visual_multimodal_lstm, 
    epochs, 
    train_data, 
    val_data, 
    loss_fct, 
    optimizer,
    scheduler=None)

1.0


  0%|          | 2/5000 [00:00<29:43,  2.80it/s]

Current train loss: 12.661616802215576


 40%|████      | 2002/5000 [09:54<14:55,  3.35it/s]

Current train loss: 4.896476965565067


 80%|████████  | 4002/5000 [19:48<04:53,  3.40it/s]

Current train loss: 4.24222665408562


100%|██████████| 5000/5000 [24:44<00:00,  3.37it/s]
100%|██████████| 625/625 [02:31<00:00,  4.14it/s]


Epoch 0:
Train Loss: 4.05910000834465
Val loss: 3.161041009902954
1.0


  0%|          | 2/5000 [00:00<25:32,  3.26it/s]

Current train loss: 3.2082626819610596


 40%|████      | 2002/5000 [09:54<14:57,  3.34it/s]

Current train loss: 3.1361558980398723


 80%|████████  | 4002/5000 [19:48<04:54,  3.39it/s]

Current train loss: 3.0756328315034263


100%|██████████| 5000/5000 [24:45<00:00,  3.37it/s]
100%|██████████| 625/625 [02:31<00:00,  4.12it/s]


Epoch 1:
Train Loss: 3.047209130334854
Val loss: 2.813968430328369
1.0


  0%|          | 2/5000 [00:00<25:16,  3.30it/s]

Current train loss: 2.783403992652893


 40%|████      | 2002/5000 [09:54<14:51,  3.36it/s]

Current train loss: 2.8508966820580617


 80%|████████  | 4002/5000 [19:49<04:57,  3.36it/s]

Current train loss: 2.820929088335166


100%|██████████| 5000/5000 [24:46<00:00,  3.36it/s]
100%|██████████| 625/625 [02:31<00:00,  4.12it/s]


Epoch 2:
Train Loss: 2.8076358087539672
Val loss: 2.674389543533325
1.0


  0%|          | 2/5000 [00:00<25:22,  3.28it/s]

Current train loss: 2.685961127281189


 40%|████      | 2002/5000 [09:55<14:54,  3.35it/s]

Current train loss: 2.6943757485200117


 80%|████████  | 4002/5000 [19:52<04:54,  3.39it/s]

Current train loss: 2.6766548437931608


100%|██████████| 5000/5000 [24:49<00:00,  3.36it/s]
100%|██████████| 625/625 [02:33<00:00,  4.07it/s]


Epoch 3:
Train Loss: 2.671585415840149
Val loss: 2.5785831508636474
1.0


  0%|          | 2/5000 [00:00<25:34,  3.26it/s]

Current train loss: 2.2325196266174316


 40%|████      | 2002/5000 [10:00<15:03,  3.32it/s]

Current train loss: 2.589764584789981


 80%|████████  | 4002/5000 [20:00<04:58,  3.34it/s]

Current train loss: 2.58264017957023


100%|██████████| 5000/5000 [24:59<00:00,  3.33it/s]
100%|██████████| 625/625 [02:33<00:00,  4.06it/s]


Epoch 4:
Train Loss: 2.5797937920093537
Val loss: 2.514091403579712
1.0


  0%|          | 2/5000 [00:00<25:32,  3.26it/s]

Current train loss: 2.4148359298706055


 40%|████      | 2002/5000 [09:59<15:00,  3.33it/s]

Current train loss: 2.5091880094278585


 80%|████████  | 4002/5000 [20:01<04:58,  3.34it/s]

Current train loss: 2.5128862317474647


100%|██████████| 5000/5000 [24:59<00:00,  3.33it/s]
100%|██████████| 625/625 [02:33<00:00,  4.06it/s]


Epoch 5:
Train Loss: 2.509734126996994
Val loss: 2.465382401275635
1.0


  0%|          | 2/5000 [00:00<25:37,  3.25it/s]

Current train loss: 2.6227495670318604


 40%|████      | 2002/5000 [10:03<15:32,  3.22it/s]

Current train loss: 2.4609735205576015


 80%|████████  | 4002/5000 [20:04<04:55,  3.37it/s]

Current train loss: 2.4547252281792815


100%|██████████| 5000/5000 [25:02<00:00,  3.33it/s]
100%|██████████| 625/625 [02:32<00:00,  4.11it/s]


Epoch 6:
Train Loss: 2.4545627468824387
Val loss: 2.4319754432678224
1.0


  0%|          | 2/5000 [00:00<25:11,  3.31it/s]

Current train loss: 2.544795513153076


 40%|████      | 2002/5000 [09:56<14:47,  3.38it/s]

Current train loss: 2.41085917603124


 80%|████████  | 4002/5000 [19:53<05:00,  3.32it/s]

Current train loss: 2.411234011774001


100%|██████████| 5000/5000 [24:52<00:00,  3.35it/s]
100%|██████████| 625/625 [02:31<00:00,  4.11it/s]


Epoch 7:
Train Loss: 2.410029220175743
Val loss: 2.3976976949691773
1.0


  0%|          | 2/5000 [00:00<25:18,  3.29it/s]

Current train loss: 2.108996033668518


 40%|████      | 2002/5000 [09:57<14:52,  3.36it/s]

Current train loss: 2.3635367720158067


 80%|████████  | 4002/5000 [19:55<04:57,  3.35it/s]

Current train loss: 2.3707668482810482


100%|██████████| 5000/5000 [24:53<00:00,  3.35it/s]
100%|██████████| 625/625 [02:32<00:00,  4.10it/s]


Epoch 8:
Train Loss: 2.3715607931613922
Val loss: 2.38222297668457
1.0


  0%|          | 2/5000 [00:00<25:46,  3.23it/s]

Current train loss: 2.23163640499115


 40%|████      | 2002/5000 [09:56<14:50,  3.37it/s]

Current train loss: 2.3339185299692335


 80%|████████  | 4002/5000 [19:51<05:00,  3.32it/s]

Current train loss: 2.3397968494850416


100%|██████████| 5000/5000 [24:49<00:00,  3.36it/s]
100%|██████████| 625/625 [02:32<00:00,  4.09it/s]


Epoch 9:
Train Loss: 2.33947609603405
Val loss: 2.355729240036011
1.0


  0%|          | 2/5000 [00:00<25:16,  3.30it/s]

Current train loss: 2.169552505016327


 40%|████      | 2002/5000 [09:55<14:44,  3.39it/s]

Current train loss: 2.3037221312046525


 80%|████████  | 4002/5000 [19:49<04:55,  3.38it/s]

Current train loss: 2.3081037288960786


100%|██████████| 5000/5000 [24:46<00:00,  3.36it/s]
100%|██████████| 625/625 [02:31<00:00,  4.12it/s]


Epoch 10:
Train Loss: 2.3107679262161254
Val loss: 2.331479719352722
1.0


  0%|          | 2/5000 [00:00<25:04,  3.32it/s]

Current train loss: 2.2140849828720093


 40%|████      | 2002/5000 [09:52<14:50,  3.37it/s]

Current train loss: 2.276515412818897


 80%|████████  | 4002/5000 [19:46<04:55,  3.37it/s]

Current train loss: 2.28374005588277


100%|██████████| 5000/5000 [24:42<00:00,  3.37it/s]
100%|██████████| 625/625 [02:31<00:00,  4.13it/s]


Epoch 11:
Train Loss: 2.2846662587165834
Val loss: 2.327339119338989
1.0


  0%|          | 2/5000 [00:00<24:57,  3.34it/s]

Current train loss: 2.3561549186706543


 40%|████      | 2002/5000 [09:53<14:50,  3.37it/s]

Current train loss: 2.249010990370999


 80%|████████  | 4002/5000 [19:46<04:59,  3.34it/s]

Current train loss: 2.2598615006051737


100%|██████████| 5000/5000 [24:42<00:00,  3.37it/s]
100%|██████████| 625/625 [02:30<00:00,  4.14it/s]


Epoch 12:
Train Loss: 2.261883927178383
Val loss: 2.323257204246521
1.0


  0%|          | 2/5000 [00:00<25:04,  3.32it/s]

Current train loss: 2.308452010154724


 40%|████      | 2002/5000 [09:53<14:50,  3.37it/s]

Current train loss: 2.227920084923774


 80%|████████  | 4002/5000 [19:50<04:53,  3.40it/s]

Current train loss: 2.237965793206893


100%|██████████| 5000/5000 [24:45<00:00,  3.36it/s]
100%|██████████| 625/625 [02:31<00:00,  4.13it/s]


Epoch 13:
Train Loss: 2.240810235309601
Val loss: 2.3015333766937256
1.0


  0%|          | 2/5000 [00:00<25:14,  3.30it/s]

Current train loss: 2.2389070987701416


 40%|████      | 2002/5000 [09:52<14:40,  3.41it/s]

Current train loss: 2.2099995046824246


 80%|████████  | 4002/5000 [19:44<04:57,  3.35it/s]

Current train loss: 2.2186962259763003


100%|██████████| 5000/5000 [24:41<00:00,  3.38it/s]
100%|██████████| 625/625 [02:31<00:00,  4.12it/s]

Epoch 14:
Train Loss: 2.2235176665067673
Val loss: 2.292942211151123





([4.05910000834465,
  3.047209130334854,
  2.8076358087539672,
  2.671585415840149,
  2.5797937920093537,
  2.509734126996994,
  2.4545627468824387,
  2.410029220175743,
  2.3715607931613922,
  2.33947609603405,
  2.3107679262161254,
  2.2846662587165834,
  2.261883927178383,
  2.240810235309601,
  2.2235176665067673],
 [3.161041009902954,
  2.813968430328369,
  2.674389543533325,
  2.5785831508636474,
  2.514091403579712,
  2.465382401275635,
  2.4319754432678224,
  2.3976976949691773,
  2.38222297668457,
  2.355729240036011,
  2.331479719352722,
  2.327339119338989,
  2.323257204246521,
  2.3015333766937256,
  2.292942211151123])

In [11]:
torchvision.__version__

'0.13.0+cu102'

Epoch 0:
Train Loss: 4.454108733892441
Val loss: 3.682773498916626
0.7071067811865475
  0%|          | 2/5000 [00:00<19:00,  4.38it/s]
Current train loss: 3.5417309999465942
 40%|████      | 2002/5000 [07:30<11:23,  4.39it/s]
Current train loss: 3.650782839758889
 80%|████████  | 4002/5000 [14:59<03:42,  4.48it/s]
Current train loss: 3.593524467105093
100%|██████████| 5000/5000 [18:42<00:00,  4.45it/s]
100%|██████████| 625/625 [01:52<00:00,  5.55it/s]
Epoch 1:
Train Loss: 3.5651441644191744
Val loss: 3.3736032577514647
0.7071067811865475
  0%|          | 2/5000 [00:00<19:49,  4.20it/s]
Current train loss: 3.327363133430481
 40%|████      | 2002/5000 [07:27<10:59,  4.54it/s]
Current train loss: 3.3840376391158355
 80%|████████  | 4002/5000 [14:55<03:46,  4.41it/s]
Current train loss: 3.352826803103499
100%|██████████| 5000/5000 [18:39<00:00,  4.46it/s]
100%|██████████| 625/625 [01:53<00:00,  5.52it/s]
Epoch 2:
Train Loss: 3.3399432120323183
Val loss: 3.2226122886657715
0.7071067811865475
  0%|          | 2/5000 [00:00<18:53,  4.41it/s]
Current train loss: 3.6570446491241455
 40%|████      | 2002/5000 [07:25<11:08,  4.49it/s]
Current train loss: 3.238890047792669


In [12]:
([4.255402976131439,
  3.378325371956825,
  3.166072261953354,
  3.0472596979379656,
  2.968185928463936,
  2.9084950731039045,
  2.8601422364234925,
  2.8210828330516815,
  2.787460916996002,
  2.7586146072626114,
  2.7335047327280044,
  2.7104552837371827,
  2.6902655891776086,
  2.6708284264326094,
  2.6541572811841965],
 [3.490547196960449,
  3.176377497291565,
  3.0337498971939088,
  2.952458595657349,
  2.892481524276733,
  2.8503320302963258,
  2.813573787879944,
  2.7869199739456176,
  2.762750726890564,
  2.7455632947921753,
  2.731031491279602,
  2.7182074075698854,
  2.6980771926879883,
  2.692284686756134,
  2.676084305477142])

([4.255402976131439,
  3.378325371956825,
  3.166072261953354,
  3.0472596979379656,
  2.968185928463936,
  2.9084950731039045,
  2.8601422364234925,
  2.8210828330516815,
  2.787460916996002,
  2.7586146072626114,
  2.7335047327280044,
  2.7104552837371827,
  2.6902655891776086,
  2.6708284264326094,
  2.6541572811841965],
 [3.490547196960449,
  3.176377497291565,
  3.0337498971939088,
  2.952458595657349,
  2.892481524276733,
  2.8503320302963258,
  2.813573787879944,
  2.7869199739456176,
  2.762750726890564,
  2.7455632947921753,
  2.731031491279602,
  2.7182074075698854,
  2.6980771926879883,
  2.692284686756134,
  2.676084305477142])

In [13]:
if is_multimodal:
    if train_visual_module:
        torch.save(visual_multimodal_lstm.state_dict(), 
                   f'./saved_models/finetuned_visual_multimodal_lstm_{num_data}_{lr}_{epochs}v5')
    else:
        torch.save(visual_multimodal_lstm.state_dict(), 
                   f'./saved_models/multimodal_lstm_{num_data}_{lr}_{epochs}v6')
else:
    torch.save(benchmark_model.state_dict(), 
               f'./saved_models/benchmark_model_{num_data}_{lr}_{epochs}v2')

In [14]:
def test(ml_model, 
          epochs, 
          val_data, 
          loss_fct, 
          optimizer,
          scheduler = None,
          clip=2.0):
    
    ml_model.eval()
    epoch_val_loss = 0
    for i, batch in enumerate(tqdm(val_data)):
        text, img, target = batch['input_ids'].to(device), batch['image'].to(device), batch['label_ids'].to(device)

        with torch.no_grad():
            output = ml_model(text, img)
            loss = loss_fct(output.view(-1, output.size(-1)), target.view(-1))

        epoch_val_loss += loss.item()

    current_val_loss = epoch_val_loss/len(val_data)
    perplexity = math.exp(current_val_loss)


    print(f"Val loss: {current_val_loss}")
    print(f"Val perp: {perplexity}")

    return current_val_loss, perplexity

test(visual_multimodal_lstm, 
    epochs, 
    test_data, 
    loss_fct, 
    optimizer,
    scheduler=None)

100%|██████████| 625/625 [02:31<00:00,  4.12it/s]

Val loss: 2.2953844356536863
Val perp: 9.928252052795216





(2.2953844356536863, 9.928252052795216)