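# Train the Model

Trains the `Encoder`/`Decoder` image-captioning model on MSCOCO 2017 captions loaded through `MSCOCOInterface`, reporting training/validation loss and perplexity per epoch and checkpointing the model to `./model/`.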

## Import Libraries

In [12]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from models import Encoder, Decoder

from pathlib import Path
from DatasetInterface import MSCOCOInterface
import json
import numpy as np
import time

## Load Dataset Interface and DataLoader

In [2]:
root = Path('Data')
# Dataset layout used here: Data/train2017 and
# Data/annotations_trainval2017/annotations/captions_train2017.json.
# Alternative layout (images/ and annotations/ directly under Data):
#   imgs_path = root/'images'/'train2017'
#   captions_path = root/'annotations'/'captions_train2017.json'
imgs_path = root/'train2017'
captions_path = root/'annotations_trainval2017'/'annotations'/'captions_train2017.json'

# Load the pre-built vocabulary. The files are decoded explicitly as UTF-8 so
# reading them does not depend on the platform's default locale encoding.
with open('vocabulary/idx_to_string.json', encoding='utf-8') as json_file:
    idx_to_string_json = json.load(json_file)

# JSON object keys are always strings; convert them back to int indices.
idx_to_string = {int(key): token for key, token in idx_to_string_json.items()}

with open('vocabulary/string_to_index.json', encoding='utf-8') as json_file:
    string_to_index = json.load(json_file)


In [3]:
# Optional CUDA speed-up via the cuDNN autotuner:
#   torch.backends.cudnn.benchmark = True
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
# Every split shares the same source files, vocabulary, caption settings and
# seed; only `stage` differs, so each split's kwargs extend one base dict.
# NOTE(review): the shared seed presumably keeps the three stages drawn from
# one consistent split — confirm against MSCOCOInterface.
base_interface_params = {
    'imgs_path': imgs_path,
    'captions_path': captions_path,
    'freq_threshold': 1,
    'sequence_length': 20,
    'caps_per_img': 1,
    'idx_to_string': idx_to_string,
    'string_to_index': string_to_index,
    'seed': 706
}

train_interface_params = {**base_interface_params, 'stage': "train"}
val_interface_params = {**base_interface_params, 'stage': "validation"}
test_interface_params = {**base_interface_params, 'stage': "test"}


# Training Interface
coco_interface_train = MSCOCOInterface(**train_interface_params)

# Validation Interface
coco_interface_val = MSCOCOInterface(**val_interface_params)

# Testing Interface
coco_interface_test = MSCOCOInterface(**test_interface_params)


print("Lenght of training image: {}, Lenght of Validation image: {} Lenght of Testing image: {}" \
      .format(len(coco_interface_train), len(coco_interface_val), len(coco_interface_test)))

Lenght of training image: 15000, Lenght of Validation image: 5000 Lenght of Testing image: 5000


In [5]:
batch_size = 1
# Seed the training shuffle so the batch order is reproducible between runs
# (706 is the same seed the dataset interfaces are built with).
loader_seed = 706
train_loader = data.DataLoader(coco_interface_train, batch_size=batch_size, shuffle=True,
                               generator=torch.Generator().manual_seed(loader_seed))
# Evaluation splits are read in a fixed order.
val_loader = data.DataLoader(coco_interface_val, batch_size=batch_size, shuffle=False)
test_loader = data.DataLoader(coco_interface_test, batch_size=batch_size, shuffle=False)

## Parameters

In [6]:
# Model size and training length.
num_layers = 1
embed_size = 512    # passed as embed_size to both Encoder and Decoder
hidden_size = 512
total_epochs = 1
# Output vocabulary size: number of entries in the training index->token map.
vocab_size = len(coco_interface_train.idx_to_string)

## Encoder and Decoder

In [7]:
# Build the encoder/decoder pair; both are sized by `embed_size`.
# NOTE(review): with pretrained=False it is not visible here whether
# `model_weight_path` is actually read — check models.Encoder.
encoder = Encoder(
    embed_size=embed_size,
    pretrained=False,
    model_weight_path="./model/resnet152_model.pth",
)
decoder = Decoder(
    embed_size=embed_size,
    hidden_size=hidden_size,
    vocab_size=vocab_size,
    num_layers=num_layers,
)
print("########################################READY########################################")

########################################READY########################################


In [8]:
# Cross-entropy over the vocabulary; targets equal to the <PAD> index are
# ignored so padding positions do not contribute to the loss.
criterion = nn.CrossEntropyLoss(
    ignore_index=coco_interface_train.string_to_index["<PAD>"]
)

# Parameters handed to the optimizer: the whole decoder plus only the
# encoder's `embed` sub-module (other encoder parameters are not optimized).
params = [*decoder.parameters(), *encoder.embed.parameters()]

# Adam optimizer settings.
opt_pars = dict(lr=1e-5, weight_decay=1e-3, betas=(0.9, 0.999), eps=1e-08)
optimizer = optim.Adam(params, **opt_pars)

## Train

In [9]:
def save_model(epoch, encoder, decoder, optimizer, training_loss, validation_loss, checkpoint_path):
    """Serialize a training checkpoint to `checkpoint_path` via `torch.save`.

    The saved object is a dict with the epoch number, the state dicts of the
    encoder, decoder and optimizer, and the training/validation loss histories
    exactly as passed in.
    """
    checkpoint = {}
    checkpoint['epoch'] = epoch
    checkpoint['encoder_state_dict'] = encoder.state_dict()
    checkpoint['decoder_state_dict'] = decoder.state_dict()
    checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    checkpoint['training_loss'] = training_loss
    checkpoint['validation_loss'] = validation_loss
    torch.save(checkpoint, checkpoint_path)

def load_model(encoder, decoder, optimizer, training_loss, validation_loss, checkpoint_path,
               map_location=None):
    """Restore encoder, decoder and optimizer state from a `save_model` checkpoint.

    Args:
        encoder: module whose state dict is loaded in place.
        decoder: module whose state dict is loaded in place.
        optimizer: optimizer whose state dict is loaded in place.
        training_loss: ignored; kept so existing call sites keep working.
        validation_loss: ignored; kept so existing call sites keep working.
        checkpoint_path: path of a file written by `save_model`.
        map_location: forwarded to `torch.load` (e.g. "cpu" to load a
            checkpoint saved on a GPU machine onto a CPU-only machine).

    Returns:
        (encoder, decoder, optimizer, training_loss, validation_loss): the same
        module/optimizer objects that were passed in, plus the loss histories
        read from the checkpoint.
    """
    # torch.load unpickles the file — only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    training_loss = checkpoint['training_loss']
    validation_loss = checkpoint['validation_loss']

    return encoder, decoder, optimizer, training_loss, validation_loss

In [14]:
def _caption_loss(encoder, decoder, criterion, batch, device):
    """Run one (idx, images, captions) batch through the model and return its loss."""
    _, images, captions = batch  # sample indices are not needed for the loss
    images, captions = images.to(device), captions.to(device)
    outputs = decoder(encoder(images), captions)
    # Flatten to (N, C) logits vs. (N,) targets for the criterion. C is the
    # decoder's last output dimension, assumed to equal the vocab size it was
    # built with.
    return criterion(outputs.reshape(-1, outputs.size(-1)), captions.reshape(-1))


def train(encoder, decoder, criterion, optimizer, train_loader, val_loader, total_epoch, checkpoint_path,
          device=None, log_every=1):
    """Train `encoder`/`decoder` for `total_epoch` epochs, validating after each one.

    After every epoch the mean training and validation losses are appended to
    the histories. A checkpoint is then written with `save_model`, which is
    skipped when `checkpoint_path` is None. Finally a one-line summary is
    printed: losses, perplexities (exp of the mean loss) and the epoch's wall
    time in minutes.

    Args:
        encoder, decoder: the captioning modules; moved to `device`.
        criterion: loss applied to flattened logits vs. flattened captions.
        optimizer: optimizer over the trainable parameters.
        train_loader, val_loader: iterables of (idx, images, captions) batches
            that also support len().
        total_epoch: number of epochs to run (0 returns empty histories).
        checkpoint_path: where `save_model` writes after each epoch, or None.
        device: torch.device to train on; defaults to CUDA if available, else CPU.
        log_every: print the training loss every `log_every` batches
            (1 = every batch; 0 or None disables per-batch printing).

    Returns:
        (training_loss, validation_loss): lists with one mean loss per epoch.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    encoder.to(device)
    decoder.to(device)

    training_loss = []
    validation_loss = []
    summary = ("Epoch: {0:d}. Training Loss = {1:.4f}, Training Perplexity: {2:.4f}. "
               "Validation Loss: {3:.4f}, Validation Perplexity: {4:.4f}. Time: {5:.2f} min")

    for epoch in range(total_epoch):
        start_time = time.time()

        # Training phase
        encoder.train()
        decoder.train()
        train_epoch_loss = 0.0
        for batch_idx, batch in enumerate(train_loader):
            # Zero the gradients.
            encoder.zero_grad()
            decoder.zero_grad()

            loss = _caption_loss(encoder, decoder, criterion, batch, device)
            loss.backward()
            optimizer.step()

            train_epoch_loss += loss.item()
            if log_every and (batch_idx + 1) % log_every == 0:
                print(loss.item())
        # max(..., 1) guards against ZeroDivisionError on an empty loader.
        train_epoch_loss /= max(len(train_loader), 1)
        training_loss.append(train_epoch_loss)

        # Validation phase: no autograd graph is needed, which saves memory.
        encoder.eval()
        decoder.eval()
        val_epoch_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                val_epoch_loss += _caption_loss(encoder, decoder, criterion, batch, device).item()
        val_epoch_loss /= max(len(val_loader), 1)
        validation_loss.append(val_epoch_loss)

        epoch_time = (time.time() - start_time) / 60

        # Checkpoint every epoch so an interrupted run keeps its progress.
        if checkpoint_path is not None:
            save_model(epoch, encoder, decoder, optimizer, training_loss, validation_loss, checkpoint_path)

        print(summary.format(epoch, train_epoch_loss, np.exp(train_epoch_loss),
                             val_epoch_loss, np.exp(val_epoch_loss), epoch_time))

    return training_loss, validation_loss
    

In [15]:
train_params = {
    'encoder': encoder,
    'decoder': decoder,
    'criterion': criterion,
    'optimizer': optimizer,
    'train_loader': train_loader,
    'val_loader': val_loader,
    # Use the epoch count from the Parameters cell instead of repeating it here.
    'total_epoch': total_epochs,
    'checkpoint_path': './model/image_captioning_model_v0.pth'
}

training_loss, validation_loss = train(**train_params)

4.597378730773926
4.937851905822754
6.679941654205322
5.0036540031433105
4.918497085571289
6.342258930206299
4.928537368774414
6.742866039276123
6.400446891784668
5.012894153594971
5.112613201141357
5.347065448760986
5.154609680175781
5.558534622192383
5.89182710647583
5.395740985870361
6.5538482666015625
5.207512378692627


KeyboardInterrupt: 