# Train the Model

## Import Libraries

In [1]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from models import Encoder, Decoder

from pathlib import Path
from DatasetInterface import MSCOCOInterface
import json
import numpy as np
import time

## Load Dataset Interface and DataLoader

In [2]:
root = Path('Data')
#imgs_path = root/'images'/'train2017'
train_imgs_path = root/'train2017'
test_imgs_path = root/'val2017'
#captions_path = root/'annotations'/'captions_train2017.json'
train_captions_path = root/'annotations_trainval2017'/'annotations'/'sports_captions_train.json'
val_captions_path = root/'annotations_trainval2017'/'annotations'/'sports_captions_val.json'
test_captions_path = root/'annotations_trainval2017'/'annotations'/'sports_captions_test.json'

"""
# load vocab
with open('vocabulary/idx_to_string.json') as json_file:
    idx_to_string_json = json.load(json_file)
        
idx_to_string = dict()
for key in idx_to_string_json:
    idx_to_string[int(key)] = idx_to_string_json[key]
    
with open('vocabulary/string_to_index.json') as json_file:
    string_to_index = json.load(json_file)
"""

"\n# load vocab\nwith open('vocabulary/idx_to_string.json') as json_file:\n    idx_to_string_json = json.load(json_file)\n        \nidx_to_string = dict()\nfor key in idx_to_string_json:\n    idx_to_string[int(key)] = idx_to_string_json[key]\n    \nwith open('vocabulary/string_to_index.json') as json_file:\n    string_to_index = json.load(json_file)\n"

In [3]:
# to boost the performence of CUDA use:
# torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
train_interface_params = {
    'imgs_path': train_imgs_path,
    'captions_path': train_captions_path,
    'freq_threshold': 5,
    'sequence_length': 20,
    'caps_per_img': 1,
    'stage': "train",
    'idx_to_string': None,
    'string_to_index': None,
}

val_interface_params = {
    'imgs_path': train_imgs_path,
    'captions_path': val_captions_path,
    'freq_threshold': 5,
    'sequence_length': 20,
    'caps_per_img': 1,
    'stage': "validation",
    'idx_to_string': None,
    'string_to_index': None,
}

test_interface_params = {
    'imgs_path': test_imgs_path,
    'captions_path': test_captions_path,
    'freq_threshold': 5,
    'sequence_length': 20,
    'caps_per_img': 1,
    'stage': "test",
    'idx_to_string': None,
    'string_to_index': None,
}


# Training Interface
coco_interface_train = MSCOCOInterface(**train_interface_params)

# Validation Interface
coco_interface_val = MSCOCOInterface(**val_interface_params)

# Testing Interface
coco_interface_test = MSCOCOInterface(**test_interface_params)


In [6]:
print("Lenght of training image: {}, Lenght of Validation image: {} Lenght of Testing image: {}"\
      .format(len(coco_interface_train), len(coco_interface_val), len(coco_interface_test)))

print(f"Lenght of vocabulary: {len(coco_interface_train.idx_to_string)}")

Lenght of training image: 16252, Lenght of Validation image: 6966 Lenght of Testing image: 938
Lenght of vocabulary: 2688


In [6]:
batch_size = 1
train_loader = data.DataLoader(coco_interface_train, batch_size=batch_size, shuffle=True)
val_loader = data.DataLoader(coco_interface_val, batch_size=batch_size, shuffle=False)
test_loader = data.DataLoader(coco_interface_test, batch_size=batch_size, shuffle=False)

## Parameters

In [7]:
embed_size = 512
hidden_size = 512
vocab_size = len(coco_interface_train.idx_to_string)
num_layers = 1
total_epochs = 1

## Encoder and Decoder

In [8]:
encoder = Encoder(embed_size=embed_size, pretrained=False, model_weight_path="./model/resnet152_model.pth")
decoder = Decoder(embed_size=embed_size, hidden_size=hidden_size, vocab_size=vocab_size, num_layers=num_layers)
print("########################################READY########################################")

########################################READY########################################


In [9]:
# the loss is a cross entropy loss and ignore the index of <PAD> since it doesn't make any difference
criterion = nn.CrossEntropyLoss(ignore_index=coco_interface_train.string_to_index["<PAD>"])

# combine the paramters of decoder and ecnoder
params = list(decoder.parameters()) + list(encoder.embed.parameters())

# Adam optimizer
opt_pars = {'lr':1e-5, 'weight_decay':1e-3, 'betas':(0.9, 0.999), 'eps':1e-08}
optimizer = optim.Adam(params, **opt_pars)

## Train

In [10]:
def save_model(epoch, encoder, decoder, training_loss, validation_loss, checkpoint_path):
    torch.save({
        'epoch': epoch,
        'encoder_state_dict': encoder.state_dict(),
        'decoder_state_dict': decoder.state_dict(),
        'training_loss': training_loss,
        'validation_loss': validation_loss
    }, checkpoint_path)

def load_model(encoder, decoder, checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_state_dict'])
    training_loss = checkpoint['training_loss']
    validation_loss = checkpoint['validation_loss']
    
    return encoder, decoder, training_loss, validation_loss

In [11]:
encoder, decoder, training_loss, validation_loss = load_model(encoder, decoder, './model/image_captioning_model_v0.pth')

In [12]:
def train(encoder, decoder, criterion, optimizer, train_loader, val_loader, total_epoch, checkpoint_path):
    
    encoder.to(device)
    decoder.to(device)
    
    training_loss = []
    validation_loss = []
    
    start_time = time.time()
    for epoch in range(total_epoch):
        train_epoch_loss = 0
        val_epoch_loss = 0
        
        
        # Training phase
        encoder.train()
        decoder.train()
        
        for id, batch in enumerate(train_loader):
            idx, images, captions = batch
            images, captions = images.to(device), captions.to(device)
            
            # Zero the gradients.
            encoder.zero_grad()
            decoder.zero_grad()
            
            features = encoder(images)
            outputs = decoder(features, captions)
            
            loss = criterion(outputs.view(-1, vocab_size), captions.contiguous().view(-1))
            
            loss.backward()
            optimizer.step()
        
            train_epoch_loss += loss.item()
            if id % 100 == 0:
                print('Training: ', id, ' ', loss.item())
                
        train_epoch_loss /= len(train_loader)
        training_loss.append(train_epoch_loss)
        
        # validation phase
        encoder.eval()
        decoder.eval()
        
        for id, batch in enumerate(val_loader):
            idx, images, captions = batch
            images, captions = images.to(device), captions.to(device)
            features = encoder(images)
            outputs = decoder(features, captions)
            loss = criterion(outputs.view(-1, vocab_size), captions.contiguous().view(-1))
            val_epoch_loss += loss.item()
            if id % 100 == 0:
                print('Validation: ', id, ' ', loss.item())
            
        val_epoch_loss /= len(val_loader)
        validation_loss.append(val_epoch_loss)
    
    epoch_time = (time.time() - start_time) /60**1
    
    save_model(epoch, encoder, decoder, optimizer, training_loss, validation_loss, checkpoint_path)
    
    print("Epoch: {1:d}. Training Loss = {1:.4f}, Training Perplexity: {2:.4f}. Validation Loss: {3:.4f}, Validation Perplexity: {4:.4f}. Time: {5:f}" \
          .format(epoch, train_epoch_loss, np.exp(train_epoch_loss), val_epoch_loss, np.exp(val_epoch_loss), epoch_time))
    
    return training_loss, validation_loss
    

In [13]:
train_params = {
    'encoder': encoder,
    'decoder': decoder,
    'criterion': criterion,
    'optimizer': optimizer,
    'train_loader': train_loader,
    'val_loader': val_loader,
    'total_epoch': 1,
    'checkpoint_path': './model/image_captioning_model_v0.pth'
}

training_loss, validation_loss = train(**train_params) 

Training:  0   5.7463603019714355
Training:  100   4.77365255355835
Training:  200   2.5623116493225098
Training:  300   3.593135118484497
Training:  400   3.482023000717163
Training:  500   3.6291959285736084
Training:  600   3.4522058963775635
Training:  700   4.048513889312744
Training:  800   4.016190052032471
Training:  900   3.854299783706665
Training:  1000   4.702203750610352
Training:  1100   4.716959476470947
Training:  1200   5.448296546936035
Training:  1300   3.0563738346099854
Training:  1400   4.283986568450928
Training:  1500   5.271600723266602
Training:  1600   4.686922073364258
Training:  1700   4.246876239776611
Training:  1800   5.75971794128418
Training:  1900   3.0879921913146973
Training:  2000   2.7001264095306396
Training:  2100   3.2407281398773193
Training:  2200   6.608675003051758
Training:  2300   3.7038471698760986
Training:  2400   3.2610924243927
Training:  2500   3.7188451290130615
Training:  2600   4.580766677856445
Training:  2700   4.31274652481079

TypeError: save_model() takes 6 positional arguments but 7 were given