# Train the Model

## Import Libraries

In [1]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from models import Encoder, Decoder

from pathlib import Path
from DatasetInterface import MSCOCOInterface, get_loader
from data_prep_utils import *
from utils import train, save_model, load_model, plot_loss
import json
import numpy as np
import math
import time

###### download the data we need (use %cd: a `!cd` runs in a throwaway subshell,
###### so it would not change the directory used by the commands that follow)
# %cd ~/INM706-image-captioning/Datasets/coco/images/
# !wget http://images.cocodataset.org/zips/train2017.zip
# !wget http://images.cocodataset.org/zips/val2017.zip
# !unzip train2017.zip
# !unzip val2017.zip
# !rm train2017.zip
# !rm val2017.zip
# %cd -

##### run the code below if NLTK has not been set up on the cloud instance yet
# !python -m nltk.downloader -d /usr/local/share/nltk_data all

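###### run the code below to save the pre-trained ResNet-152 weights if needed
###### (%cd - returns to the project root so the relative paths below keep working)
# %cd ~/INM706-image-captioning/model
# !wget https://download.pytorch.org/models/resnet152-394f9c45.pth
# !mv resnet152-394f9c45.pth resnet152_model.pth
# %cd -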

## Load Dataset Interface and DataLoader

In [2]:
# COCO layout used here: Datasets/coco/{images/train2017, images/val2017, annotations}.
# The custom train/val splits use train2017 images; the custom test split uses val2017.
root = Path('Datasets/coco')
images_dir = root / 'images'
imgs_path = images_dir / 'train2017'
imgs_path_test = images_dir / 'val2017'

# Write the custom caption splits (options are passed straight through to
# data_prep_utils.prepare_datasets; the printed summary shows the resulting sizes).
prepare_datasets(
    train_percent=0.87,
    super_categories=['sports'],
    max_train=15000,
    max_val=3000,
    max_test=3000,
)

#### The vocabulary is built from the full original COCO train captions. Uncomment to rebuild it.
# build_vocab(freq_threshold=2, sequence_length=40,
#             captions_file='captions_train2017.json')

annotations_dir = root / 'annotations'
train_captions_path = annotations_dir / 'custom_captions_train.json'
val_captions_path = annotations_dir / 'custom_captions_val.json'
test_captions_path = annotations_dir / 'custom_captions_test.json'



train dataset has 15000 images
 val dataset has 3000 images
 test dataset has 938 images


In [3]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the RNGs so that batch sampling and random weight initialisation repeat across
# Restart & Run All (torch.manual_seed also seeds every CUDA device).
# NOTE(review): assumes the dataset's get_train_indices() draws from numpy's global RNG;
# if it uses Python's `random` module instead, seed that as well.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# To let cuDNN benchmark convolution algorithms (faster for fixed input sizes, but
# non-deterministic), uncomment:
# torch.backends.cudnn.benchmark = True

In [4]:
def _interface_params(imgs_dir, captions_path, caps_per_img):
    """Keyword arguments for get_loader / MSCOCOInterface describing one data split.

    Only the image directory, caption file and captions-per-image differ between the
    splits; freq_threshold and vocab_from_file are shared.
    """
    return {
        'imgs_path': imgs_dir,
        'captions_path': captions_path,
        'freq_threshold': 5,
        'caps_per_img': caps_per_img,
        'vocab_from_file': True,
    }


# Training uses 5 captions per image; validation and test use 1.
train_interface_params = _interface_params(imgs_path, train_captions_path, caps_per_img=5)
val_interface_params = _interface_params(imgs_path, val_captions_path, caps_per_img=1)
test_interface_params = _interface_params(imgs_path_test, test_captions_path, caps_per_img=1)


####################
Vocab size is 16232
####################

Obtaining caption lengths...


100%|██████████| 75000/75000 [00:07<00:00, 10116.06it/s]


####################
Vocab size is 16232
####################

Obtaining caption lengths...


100%|██████████| 3000/3000 [00:00<00:00, 9982.21it/s] 


In [7]:
batch_size = 32

# get_loader wraps the dataset interface in a DataLoader (presumably building
# MSCOCOInterface(**params) internally; confirm in DatasetInterface.py).
# The loaders' `.dataset` is used later for the vocabulary and index sampling.
train_loader, val_loader = (
    get_loader(**split_params, batch_size=batch_size)
    for split_params in (train_interface_params, val_interface_params)
)

####################
Vocab size is 16232
####################

Obtaining caption lengths...


100%|██████████| 75000/75000 [00:07<00:00, 9874.91it/s] 


####################
Vocab size is 16232
####################

Obtaining caption lengths...


100%|██████████| 3000/3000 [00:00<00:00, 10215.84it/s]


In [9]:
# Report how many captions each split contributes.
caption_counts = (len(train_loader.dataset), len(val_loader.dataset))
print("training captions: {}\nValidation captions: {}".format(*caption_counts))

training captions: 75000
Validation captions: 3000


## Parameters

In [10]:
# Model hyper-parameters. The decoder's output size is taken from the training
# dataset's index-to-token table so that both share one vocabulary.
embed_size = hidden_size = 512
num_layers = 1
vocab_size = len(train_loader.dataset.idx_to_string)

## Encoder and Decoder

In [16]:
# Even with pretrained=False the encoder uses pre-trained ResNet weights: it reads
# them from the local .pth file below rather than fetching them.
encoder = Encoder(
    embed_size=embed_size,
    pretrained=False,
    model_weight_path="./model/resnet152_model.pth",
)
decoder = Decoder(
    embed_size=embed_size,
    hidden_size=hidden_size,
    vocab_size=vocab_size,
    num_layers=num_layers,
)
print("########################################READY########################################")

########################################READY########################################


In [22]:
# Cross-entropy over the vocabulary; targets equal to the <PAD> index are ignored so
# padding does not contribute to the loss.
# The PAD index is read from the training loader's dataset: `coco_interface_train`
# only exists in commented-out code, so referencing it fails on a fresh kernel.
# Assumes get_loader's dataset exposes `string_to_index`, as it exposes `idx_to_string`.
pad_idx = train_loader.dataset.string_to_index["<PAD>"]
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

# Only the decoder and the encoder's `embed` sub-module are handed to the optimizer;
# no other encoder parameter is updated by it.
params = list(decoder.parameters()) + list(encoder.embed.parameters())

# Adam optimizer (weight_decay here is Adam's coupled L2 penalty).
opt_pars = {'lr': 5e-4, 'weight_decay': 1e-3, 'betas': (0.9, 0.999), 'eps': 1e-08}
optimizer = optim.Adam(params, **opt_pars)

## Train

In [23]:
def save_model(epoch, encoder, decoder, training_loss, validation_loss, checkpoint_path):
    """Serialise a training checkpoint to ``checkpoint_path`` with ``torch.save``.

    The checkpoint is a dict containing the epoch index, the encoder and decoder
    state dicts, and the training/validation loss histories. This definition
    shadows ``save_model`` imported from ``utils`` in the first cell.
    """
    checkpoint = {
        'epoch': epoch,
        'encoder_state_dict': encoder.state_dict(),
        'decoder_state_dict': decoder.state_dict(),
        'training_loss': training_loss,
        'validation_loss': validation_loss,
    }
    torch.save(checkpoint, checkpoint_path)

def load_model(encoder, decoder, checkpoint_path, map_location="cpu"):
    """Restore encoder/decoder weights and loss histories from a ``save_model`` checkpoint.

    Args:
        encoder, decoder: modules whose state dicts are overwritten in place.
        checkpoint_path: file written by ``save_model``.
        map_location: forwarded to ``torch.load``. It defaults to "cpu" so that a
            checkpoint saved from CUDA modules can be read on a CPU-only machine.
            ``load_state_dict`` copies the values into the modules' existing
            parameters, so the modules stay on whatever device they are on.

    Returns:
        (encoder, decoder, training_loss, validation_loss). The modules are the same
        objects that were passed in.
    """
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_state_dict'])
    return encoder, decoder, checkpoint['training_loss'], checkpoint['validation_loss']


In [33]:
# Restore weights and loss history from an existing checkpoint, if there is one.
# NOTE(review): train() below starts again from epoch 0 with empty loss lists and
# overwrites this same checkpoint file, so a restored loss history is not carried over.
CHECKPOINT = './model/image_captioning_model_v7.pth'
if Path(CHECKPOINT).exists():
    encoder, decoder, training_loss, validation_loss = load_model(encoder, decoder, CHECKPOINT)
else:
    print(f'{CHECKPOINT} file does not exist, training starting from scratch')

./model/image_captioning_model_v7.pth file does not exist, training startging from scratch


In [37]:
# Batches per epoch for each phase: ceil(number of captions / batch size).
train_step, val_step = (
    math.ceil(len(loader.dataset.caption_lengths) / loader.batch_sampler.batch_size)
    for loader in (train_loader, val_loader)
)

def _steps_per_epoch(loader):
    """Batches per epoch for ``loader``: ceil(#captions / batch size)."""
    return math.ceil(len(loader.dataset.caption_lengths) / loader.batch_sampler.batch_size)


def _sample_batch(loader, device):
    """Fetch one batch built from ``loader.dataset.get_train_indices()`` and move it to ``device``.

    The indices are intended as a sample whose captions all share one length. A fresh
    SubsetRandomSampler over them is installed on the loader's batch sampler, and a new
    iterator yields exactly one batch from it.
    """
    indices = loader.dataset.get_train_indices()
    loader.batch_sampler.sampler = data.sampler.SubsetRandomSampler(indices=indices)
    _, images, captions = next(iter(loader))
    return images.to(device), captions.to(device)


def _caption_loss(encoder, decoder, criterion, images, captions):
    """Run encoder -> decoder and score the logits against ``captions``.

    Logits are flattened to (-1, last_dim) and captions to (-1,). This assumes the
    decoder's last output dimension is the vocabulary size (TODO confirm in models.py).
    """
    outputs = decoder(encoder(images), captions)
    return criterion(outputs.reshape(-1, outputs.size(-1)), captions.reshape(-1))


def train(encoder, decoder, criterion, optimizer, train_loader, val_loader, total_epoch, checkpoint_path,
          device=None):
    """Train ``encoder``/``decoder`` for ``total_epoch`` epochs, validating and checkpointing each epoch.

    Each epoch has ceil(#captions / batch size) steps per loader, and every step draws
    one randomly sampled batch (see ``_sample_batch``). The validation loss is therefore
    a sampled estimate. Validation runs under ``torch.no_grad()``. After every epoch,
    ``save_model`` (looked up in the notebook's global scope) overwrites the checkpoint
    at ``checkpoint_path``.

    Args:
        encoder, decoder: modules to train; both are moved to ``device``.
        criterion: loss applied to the flattened logits vs. the flattened captions.
        optimizer: steps whichever parameters it was constructed with.
        train_loader, val_loader: DataLoaders whose datasets expose ``caption_lengths``
            and ``get_train_indices()``.
        total_epoch: number of epochs to run.
        checkpoint_path: file that ``save_model`` writes after each epoch.
        device: torch device to run on; defaults to CUDA when available, otherwise CPU.

    Returns:
        (training_loss, validation_loss): lists holding the mean loss of each epoch.

    Raises:
        ValueError: if either loader yields zero steps per epoch (empty dataset).
    """
    # Derive step counts from the loaders actually passed in, not from notebook globals.
    train_steps = _steps_per_epoch(train_loader)
    val_steps = _steps_per_epoch(val_loader)
    if train_steps == 0 or val_steps == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch per epoch")

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder.to(device)
    decoder.to(device)

    training_loss = []
    validation_loss = []

    start_time = time.time()
    for epoch in range(total_epoch):
        # Training phase
        encoder.train()
        decoder.train()
        train_epoch_loss = 0.0
        for i_step in range(train_steps):
            images, captions = _sample_batch(train_loader, device)

            encoder.zero_grad()
            decoder.zero_grad()
            loss = _caption_loss(encoder, decoder, criterion, images, captions)
            loss.backward()
            optimizer.step()

            train_epoch_loss += loss.item()
            if i_step % 100 == 0:
                print('Training: {}  :{:.4f}'.format(i_step, loss.item()))

        train_epoch_loss /= train_steps
        training_loss.append(train_epoch_loss)

        # Validation phase: no autograd graph is needed, so gradients are disabled.
        encoder.eval()
        decoder.eval()
        val_epoch_loss = 0.0
        with torch.no_grad():
            for i_step in range(val_steps):
                images, captions = _sample_batch(val_loader, device)
                loss = _caption_loss(encoder, decoder, criterion, images, captions)

                val_epoch_loss += loss.item()
                if i_step % 100 == 0:
                    print('Validation: {}  :{:.4f}'.format(i_step, loss.item()))

        val_epoch_loss /= val_steps
        validation_loss.append(val_epoch_loss)

        # Minutes elapsed since training started (cumulative, not per epoch).
        epoch_time = (time.time() - start_time) / 60

        save_model(epoch, encoder, decoder, training_loss, validation_loss, checkpoint_path)

        print("Epoch: {:d}. Training Loss = {:.4f}, Training Perplexity: {:.4f}. Validation Loss: {:.4f}, Validation Perplexity: {:.4f}. Time: {:f}"
              .format(epoch, train_epoch_loss, np.exp(train_epoch_loss), val_epoch_loss, np.exp(val_epoch_loss), epoch_time))

    return training_loss, validation_loss
    

In [38]:
# Everything train() needs, kept in one dict so the run configuration can be inspected later.
train_params = {
    'encoder': encoder, 'decoder': decoder,
    'criterion': criterion, 'optimizer': optimizer,
    'train_loader': train_loader, 'val_loader': val_loader,
    'total_epoch': 10, 'checkpoint_path': CHECKPOINT,
}

training_loss, validation_loss = train(**train_params)

Training: 0  :4.0140
Training: 100  :3.6668
Training: 200  :2.4205
Training: 300  :2.6587
Training: 400  :1.8600
Training: 500  :3.0281
Training: 600  :1.9365
Training: 700  :1.9621
Training: 800  :2.4196
Training: 900  :4.0941
Training: 1000  :1.6715
Training: 1100  :3.8746
Training: 1200  :3.4008
Training: 1300  :2.8754
Training: 1400  :2.3959
Training: 1500  :2.5596
Training: 1600  :3.2610
Training: 1700  :4.4235
Training: 1800  :2.8325
Training: 1900  :2.9630
Training: 2000  :2.3668
Training: 2100  :5.0385
Training: 2200  :2.5006
Training: 2300  :1.7134
Validation: 0  :3.0233
Epoch: 0. Training Loss = 3.1229, Training Perplexity: 22.7113. Validation Loss: 3.0649, Validation Perplexity: 21.4314. Time: 3.008752
Training: 0  :3.1670
Training: 100  :4.0962
Training: 200  :3.0191
Training: 300  :3.0917
Training: 400  :1.7003
Training: 500  :5.6103
Training: 600  :6.4488
Training: 700  :4.2989
Training: 800  :3.7549
Training: 900  :3.1569
Training: 1000  :1.5605
Training: 1100  :1.6264
T