In [1]:
import torch
import pandas as pd
import numpy as np
import random
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import Parameter
from torch import Tensor
import torch.nn.functional as F
import matplotlib.image as img
import matplotlib.pyplot as plt
import torchvision
import PIL
import math
import datetime
import os
from bpemb import BPEmb
from torch import nn
from image_caption_dataset import ImageCaptionDataset
from multi_bpe import MultiBPE
from multimodal_model import MMLSTM, BenchmarkLSTM
from torch.utils.data import DataLoader
from tqdm import tqdm
import transformers

  return torch._C._cuda_getDeviceCount() > 0


In [2]:
# --- Run configuration --------------------------------------------------------
# Data
LOAD_PATH = './data/'
num_data = 200_000          # rows of ml_stacked_data.csv passed to get_dataset
maxlen = 32                 # maxlen forwarded to ImageCaptionDataset
train_pct = 0.8             # fraction of rows in the training split

# Optimisation
seed = 420
epochs = 15
bs = 32
lr = 1.0 / math.sqrt(32 / bs)   # equals sqrt(bs / 32): scales with batch size relative to 32

# Model variant
is_multimodal = True
train_visual_module = False

# Compute device
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
# Seed each RNG used in the notebook (torch, then numpy, then Python's random).
for _seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    _seed_rng(seed)
del _seed_rng
# Optional extra determinism for CUDA runs (currently disabled):
# torch.cuda.manual_seed(seed)
# torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True

In [4]:
def save_checkpoint(epoch, 
                    model, 
                    optimizer, 
                    scheduler,
                    loss,
                    FNAME):
    """Save a training checkpoint under ``./checkpoints/checkpoint-MM-DD-YYYY/``.

    The saved dict holds the epoch, the model and optimizer ``state_dict``s,
    the scheduler ``state_dict`` (``None`` when no scheduler is given) and the
    loss value.

    Args:
        epoch: epoch number stored in the checkpoint.
        model: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved.
        scheduler: LR scheduler, or ``None``.
        loss: loss value stored alongside the weights.
        FNAME: file name inside today's checkpoint directory.

    Returns:
        The path the checkpoint was written to.
    """
    today = datetime.date.today()
    ckpt_dir = os.path.join('.', 'checkpoints', today.strftime('checkpoint-%m-%d-%Y'))
    # exist_ok=True replaces the racy "check, then create" pattern.
    os.makedirs(ckpt_dir, exist_ok=True)
    ckpt_path = os.path.join(ckpt_dir, FNAME)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
        'loss': loss,
    }, ckpt_path)
    return ckpt_path

In [5]:
df = pd.read_csv(f'{LOAD_PATH}ml_stacked_data.csv')  # raw caption table

In [6]:
# Histogram of caption lengths in whitespace-separated words: length -> count.
caption_length_counts = {}
for caption in df.caption:
    length = len(caption.split())
    # A newly seen length starts at 1, not 0: the caption that introduced it counts.
    caption_length_counts[length] = caption_length_counts.get(length, 0) + 1
l = caption_length_counts  # short alias kept for any code that still refers to `l`

In [9]:
def _init_fn(worker_id):
    np.random.seed(int(seed))
    
def process_img(image_path): 
    """Load an image file as a normalized ``(1, 3, 224, 224)`` float tensor.

    The image is converted to RGB and resized to 224x224. ``ToTensor`` then
    scales it to [0, 1], and it is normalized with the per-channel mean/std
    below (the standard ImageNet statistics). A leading batch dimension of
    size 1 is added.

    Args:
        image_path: path of the image file to open.

    Returns:
        Tensor of shape ``(1, 3, 224, 224)``.
    """
    # A bare `import PIL` does not guarantee the `PIL.Image` submodule is loaded.
    from PIL import Image

    transform = torchvision.transforms.Compose([
        # Resize image to 224 x 224 as required by most vision models
        torchvision.transforms.Resize(
            size=(224, 224)
        ),
        # Convert PIL image to tensor with image values in [0, 1]
        torchvision.transforms.ToTensor(),

        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    # The context manager closes the file handle. convert() loads the pixel
    # data into a new image before the file is closed.
    with Image.open(image_path) as im:
        image = im.convert('RGB')
    image = transform(image)

    # Add the batch dimension: (C, H, W) -> (1, C, H, W).
    return image.unsqueeze(0)

def get_dataset(tokenizer, path=LOAD_PATH, num_data = None, maxlen=64):
    """Read ``ml_stacked_data.csv`` from ``path`` and wrap it in an ImageCaptionDataset.

    Args:
        tokenizer: tokenizer handed to ``ImageCaptionDataset``.
        path: directory containing ``ml_stacked_data.csv``.
        num_data: if not ``None``, only the first ``num_data`` rows are used.
        maxlen: forwarded to ``ImageCaptionDataset`` as ``maxlen``.

    Returns:
        The constructed ``ImageCaptionDataset``.
    """
    caption_df = pd.read_csv(os.path.join(path, 'ml_stacked_data.csv'))
    if num_data is not None:
        # Keep the first `num_data` rows (explicitly positional).
        caption_df = caption_df.iloc[:num_data]
    return ImageCaptionDataset(caption_df, tokenizer, maxlen=maxlen)

def train(ml_model, 
          epochs, 
          train_data, 
          val_data, 
          loss_fct, 
          optimizer,
          scheduler = None,
          clip=2.0):
    """Train ``ml_model`` for ``epochs`` epochs, validating after each one.

    Each epoch has two passes:

    - Training: one pass over ``train_data`` with forward, loss, backward,
      gradient-norm clipping, optimizer step and optional per-batch scheduler step.
    - Validation: a no-grad pass over ``val_data``.

    If validation perplexity reaches at least 3x the previous epoch's, the
    learning rate of every optimizer param group is halved.

    Args:
        ml_model: model called as ``ml_model(text, images)``. Its output is
            flattened to ``(-1, output.size(-1))`` before the loss.
        epochs: number of epochs to run.
        train_data: iterable of batches (dicts with ``'input_ids'``,
            ``'image'`` and ``'label_ids'`` tensors) used for training.
        val_data: iterable of batches with the same keys, used for validation.
        loss_fct: loss called as ``loss_fct(logits, targets)``.
        optimizer: torch optimizer over ``ml_model``'s parameters.
        scheduler: optional LR scheduler, stepped once per training batch.
        clip: max gradient norm passed to ``clip_grad_norm_``.

    Returns:
        ``(epoch_train_losses, epoch_val_losses)``: lists with the mean
        per-batch loss of each epoch (``nan`` if a loader yielded no batches).
    """
    def _to_device(batch):
        # Move the three tensors the model and loss need onto the configured device.
        return (batch['input_ids'].to(device),
                batch['image'].to(device),
                batch['label_ids'].to(device))

    epoch_train_losses = []
    epoch_val_losses = []
    perplexities = []
    for epoch in range(epochs):
        print(optimizer.param_groups[0]["lr"])

        # ---- training pass ----
        ml_model.train()
        epoch_train_loss, num_train_steps = 0.0, 0
        for i, batch in enumerate(tqdm(train_data)):
            text, images, target = _to_device(batch)
            output = ml_model(text, images)
            ml_model.zero_grad()
            # Use the model's actual output width instead of a hard-coded vocab size.
            loss = loss_fct(output.reshape(-1, output.size(-1)), target.view(-1))
            loss.backward()

            torch.nn.utils.clip_grad_norm_(ml_model.parameters(), clip)
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            epoch_train_loss += loss.item()
            num_train_steps += 1

            if i % 2000 == 1:
                print("Current train loss:", epoch_train_loss/num_train_steps)
        current_train_loss = epoch_train_loss / num_train_steps if num_train_steps else float('nan')
        epoch_train_losses.append(current_train_loss)

        # ---- validation pass ----
        ml_model.eval()
        epoch_val_loss, num_val_steps = 0.0, 0
        with torch.no_grad():
            for batch in tqdm(val_data):
                text, images, target = _to_device(batch)
                output = ml_model(text, images)
                loss = loss_fct(output.reshape(-1, output.size(-1)), target.view(-1))
                epoch_val_loss += loss.item()
                num_val_steps += 1
        current_val_loss = epoch_val_loss / num_val_steps if num_val_steps else float('nan')
        epoch_val_losses.append(current_val_loss)

        try:
            perplexity = math.exp(current_val_loss)
        except OverflowError:
            # A diverged loss would otherwise crash the run here.
            perplexity = float('inf')

        # Halve the LR when validation perplexity blows up relative to last epoch.
        if perplexities and perplexity >= 3 * perplexities[-1]:
            for g in optimizer.param_groups:
                g['lr'] = g['lr'] / 2
        perplexities.append(perplexity)

        print(f"Epoch {epoch}:\nTrain Loss: {current_train_loss}\nVal loss: {current_val_loss}")

        # Per-epoch checkpointing is disabled. To enable it, call
        # save_checkpoint(epoch, ml_model, optimizer, scheduler, current_val_loss, FNAME) here.
    return epoch_train_losses, epoch_val_losses

In [10]:
# Build the dataset, then split it: `train_pct` for training, and the
# remainder divided evenly into validation and test sets.
multi_bpe = MultiBPE()
all_data = get_dataset(tokenizer=multi_bpe, num_data=num_data, maxlen=maxlen)

train_size = int(train_pct * len(all_data))
holdout_size = len(all_data) - train_size
train_dataset, val_test_dataset = torch.utils.data.random_split(all_data, [train_size, holdout_size])

val_size = int(0.5 * len(val_test_dataset))
test_size = len(val_test_dataset) - val_size
val_dataset, test_dataset = torch.utils.data.random_split(val_test_dataset, [val_size, test_size])

train_data = DataLoader(train_dataset, batch_size=bs, shuffle=True)
val_data = DataLoader(val_dataset, batch_size=bs, shuffle=False)
test_data = DataLoader(test_dataset, batch_size=bs, shuffle=False)

# benchmark_model = BenchmarkLSTM().to(device)
visual_multimodal_lstm = MMLSTM(is_multimodal=is_multimodal,
                                train_visual_module=train_visual_module).to(device)

optimizer = torch.optim.SGD(visual_multimodal_lstm.parameters(), lr=lr)
# Targets equal to 320000 are excluded from the loss (presumably the padding id — TODO confirm against MultiBPE).
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=320000, reduction='mean')

# No LR scheduler for this run. `train` steps a scheduler once per batch when one
# is passed (e.g. transformers.get_linear_schedule_with_warmup).
train_losses, val_losses = train(visual_multimodal_lstm,
                                 epochs,
                                 train_data,
                                 val_data,
                                 loss_fct,
                                 optimizer,
                                 scheduler=None)
train_losses, val_losses

1.0


  0%|          | 0/5000 [00:00<?, ?it/s]


RuntimeError: [Errno 2] No such file or directory: './images/000000046310.jpg'

In [None]:
# Persist the trained weights. The file name encodes the configuration
# (finetuned-visual multimodal / multimodal / benchmark), num_data, lr and epochs.
if is_multimodal:
    if train_visual_module:
        save_name = f'finetuned_visual_multimodal_lstm_{num_data}_{lr}_{epochs}v5'
    else:
        save_name = f'multimodal_lstm_{num_data}_{lr}_{epochs}v6'
else:
    save_name = f'benchmark_model_{num_data}_{lr}_{epochs}v2'

os.makedirs('./saved_models', exist_ok=True)
# The training cell always builds and trains `visual_multimodal_lstm` (the
# BenchmarkLSTM line there is commented out), so `benchmark_model` is never
# defined. Save the model that was actually trained.
torch.save(visual_multimodal_lstm.state_dict(),
           os.path.join('./saved_models', save_name))

In [None]:
def test(ml_model, 
          epochs, 
          val_data, 
          loss_fct, 
          optimizer,
          scheduler = None,
          clip=2.0):
    """Evaluate ``ml_model`` on ``val_data``; print and return mean loss and perplexity.

    ``epochs``, ``optimizer``, ``scheduler`` and ``clip`` are accepted so the
    call mirrors ``train``'s signature but are not used here.

    Args:
        ml_model: model called as ``ml_model(text, images)``.
        val_data: iterable of batches (dicts with ``'input_ids'``, ``'image'``
            and ``'label_ids'`` tensors).
        loss_fct: loss called as ``loss_fct(logits, targets)``.

    Returns:
        ``(mean_loss, perplexity)``. ``mean_loss`` is ``nan`` if ``val_data``
        yields no batches. ``perplexity`` is ``inf`` if ``exp`` overflows.
    """
    ml_model.eval()
    total_loss, num_batches = 0.0, 0
    with torch.no_grad():
        for batch in tqdm(val_data):
            text = batch['input_ids'].to(device)
            images = batch['image'].to(device)
            target = batch['label_ids'].to(device)
            output = ml_model(text, images)
            loss = loss_fct(output.reshape(-1, output.size(-1)), target.view(-1))
            total_loss += loss.item()
            num_batches += 1

    # Guard against an empty loader instead of raising ZeroDivisionError.
    current_val_loss = total_loss / num_batches if num_batches else float('nan')
    try:
        perplexity = math.exp(current_val_loss)
    except OverflowError:
        perplexity = float('inf')

    print(f"Val loss: {current_val_loss}")
    print(f"Val perp: {perplexity}")

    return current_val_loss, perplexity

# Evaluate the trained model on the held-out test split. The returned
# (loss, perplexity) tuple is this cell's displayed output.
test(
    ml_model=visual_multimodal_lstm,
    epochs=epochs,
    val_data=test_data,
    loss_fct=loss_fct,
    optimizer=optimizer,
    scheduler=None,
)