# Train the Model

## Import Libraries

In [1]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from models import Encoder, Decoder

from pathlib import Path
from DatasetInterface import MSCOCOInterface, get_loader
from data_prep_utils import *
from utils import train, save_model, load_model, plot_loss
import json
import numpy as np
import math
import time

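###### download the data we need
# NOTE: `!cd` runs in a throwaway subshell and does not change the notebook's working
# directory, so the downloads below would land in the current folder. Use the `%cd`
# magic instead, and `%cd` back to the repository root (~/INM706-image-captioning)
# afterwards, because the paths used later (e.g. Datasets/coco, ./model) are relative to it.
# %cd ~/INM706-image-captioning/Datasets/coco/images/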
# !wget http://images.cocodataset.org/zips/train2017.zip
# !wget http://images.cocodataset.org/zips/val2017.zip
# !unzip train2017.zip
# !unzip val2017.zip
# !rm train2017.zip
# !rm val2017.zip

##### run the code below if nltk hasn't been set up in the cloud instance yet
# !python -m nltk.downloader -d /usr/local/share/nltk_data all

###### run code below to save pre-trained weights if needed
# cd ~/INM706-image-captioning/model
# !wget https://download.pytorch.org/models/resnet152-394f9c45.pth
# !mv resnet152-394f9c45.pth resnet152_model.pth

## Load Dataset Interface and DataLoader

In [2]:
# All dataset paths are relative to the notebook's working directory.
root = Path('Datasets/coco')
imgs_path = root / 'images' / 'train2017'     # images for the train and validation interfaces
imgs_path_test = root / 'images' / 'val2017'  # images for the test interface

# One-off preprocessing steps, kept for reference; uncomment to re-run.
# prepare_datasets presumably produces the custom_captions_*.json files used below.
# prepare_datasets(train_percent = 0.87, super_categories=None,
#                  max_train=45000, max_val=9000, max_test=3000)
#
# Build the vocabulary from the full original COCO train captions.
# NOTE(review): this uses freq_threshold=2 while the interfaces below pass 5; with
# vocab_from_file=True the stored vocab presumably wins — confirm.
# build_vocab(freq_threshold=2, sequence_length=40,
#             captions_file='captions_train2017.json')

train_captions_path, val_captions_path, test_captions_path = (
    root / 'annotations' / 'custom_captions_train.json',
    root / 'annotations' / 'custom_captions_val.json',
    root / 'annotations' / 'custom_captions_test.json',
)



In [3]:
# Uncomment to let cuDNN auto-tune its convolution algorithms, which boosts CUDA
# throughput when input sizes stay the same between batches:
# torch.backends.cudnn.benchmark = True

# Seed the random number generators so that weight initialisation and batch sampling
# are repeatable across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators
np.random.seed(SEED)     # NOTE(review): only matters if the dataset's index sampling uses numpy's global RNG — confirm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
def _interface_params(imgs, captions, caps_per_img):
    """Build the keyword arguments passed to ``get_loader`` for one data split.

    All splits share ``freq_threshold=5`` and ``vocab_from_file=True``. Only the image
    folder, the caption file and the number of captions per image differ between splits.
    ``sequence_length`` and ``stage`` are not passed.

    Args:
        imgs: folder containing the split's images.
        captions: path of the split's caption JSON file.
        caps_per_img: number of captions per image to use.

    Returns:
        dict with keys imgs_path, captions_path, freq_threshold, caps_per_img, vocab_from_file.
    """
    return {
        'imgs_path': imgs,
        'captions_path': captions,
        'freq_threshold': 5,
        'caps_per_img': caps_per_img,
        'vocab_from_file': True,
    }


# 5 captions per image for training, 1 for validation and test.
train_interface_params = _interface_params(imgs_path, train_captions_path, caps_per_img=5)
val_interface_params = _interface_params(imgs_path, val_captions_path, caps_per_img=1)
test_interface_params = _interface_params(imgs_path_test, test_captions_path, caps_per_img=1)


In [5]:
batch_size = 32

# Build the train and validation DataLoaders directly from the interface parameters.
# (This replaces an earlier manual MSCOCOInterface + data.DataLoader setup.)
train_loader, val_loader = (
    get_loader(**split_params, batch_size=batch_size)
    for split_params in (train_interface_params, val_interface_params)
)

####################
Vocab size is 16232
####################

Obtaining caption lengths...


100%|██████████| 225000/225000 [00:22<00:00, 10006.65it/s]


####################
Vocab size is 16232
####################

Obtaining caption lengths...


100%|██████████| 9000/9000 [00:01<00:00, 8710.12it/s] 


In [6]:
# Number of samples exposed by each split's dataset.
print(f"training captions: {len(train_loader.dataset)}\nValidation captions: {len(val_loader.dataset)}")

training captions: 225000
Validation captions: 9000


In [7]:
# Smoke-test the loader: fetch one batch and show the shapes of (indices, images, captions).
idx, im, cap = next(iter(train_loader))
print(*(tensor.shape for tensor in (idx, im, cap)))

torch.Size([32]) torch.Size([32, 3, 256, 256]) torch.Size([32, 10])


## Parameters

In [8]:
# Model hyper-parameters.
embed_size = hidden_size = 512  # passed to Encoder/Decoder as embed_size and to Decoder as hidden_size
num_layers = 1                  # passed to Decoder as num_layers
# Vocabulary size of the training dataset (16232 in the recorded run).
vocab_size = len(train_loader.dataset.idx_to_string)

## Encoder and Decoder

In [9]:
# NOTE(review): the original note said that even pretrained=False uses ResNet weights,
# loaded from the local .pth file. Confirm how Encoder combines pretrained=True with
# model_weight_path (the weights file is downloaded to this path in the setup cell).
encoder = Encoder(
    embed_size=embed_size,
    pretrained=True,
    model_weight_path="./model/resnet152_model.pth",
)
decoder = Decoder(
    embed_size=embed_size,
    hidden_size=hidden_size,
    vocab_size=vocab_size,
    num_layers=num_layers,
)
print("########################################READY########################################")

########################################READY########################################


In [15]:
# Cross-entropy over the vocabulary. No ignore_index is set for <PAD>: each batch is
# sampled so that all its captions share one length, so padding should not be needed.
# To mask padding if it is reintroduced, pass:
#   ignore_index=train_loader.dataset.string_to_index["<PAD>"]
criterion = nn.CrossEntropyLoss()

# Only the decoder and the encoder's `embed` sub-module are handed to the optimizer;
# the rest of the encoder is not updated by it.
params = [*decoder.parameters(), *encoder.embed.parameters()]

# Adam hyper-parameters.
opt_pars = dict(lr=1e-3, weight_decay=1e-3, betas=(0.9, 0.999), eps=1e-08)
optimizer = optim.Adam(params, **opt_pars)

## Train

In [16]:
def save_model(epoch, encoder, decoder, training_loss, validation_loss, checkpoint_path):
    """Write a training checkpoint to ``checkpoint_path`` (overwriting any existing file).

    NOTE: this redefines, and therefore shadows, the ``save_model`` imported from ``utils``.

    Args:
        epoch: index of the last completed epoch.
        encoder, decoder: modules whose ``state_dict()`` is stored.
        training_loss, validation_loss: per-epoch loss histories, stored as given.
        checkpoint_path: destination file.

    The optimizer state is not stored, so resuming from this checkpoint restarts the optimizer.
    """
    torch.save({
        'epoch': epoch,
        'encoder_state_dict': encoder.state_dict(),
        'decoder_state_dict': decoder.state_dict(),
        'training_loss': training_loss,
        'validation_loss': validation_loss
    }, checkpoint_path)


def load_model(encoder, decoder, checkpoint_path, map_location='cpu'):
    """Restore encoder/decoder weights and loss histories from a ``save_model`` checkpoint.

    NOTE: this redefines, and therefore shadows, the ``load_model`` imported from ``utils``.

    Args:
        encoder, decoder: modules to load the stored state dicts into (modified in place).
        checkpoint_path: file written by ``save_model``.
        map_location: forwarded to ``torch.load``. The default 'cpu' lets a checkpoint
            saved on a GPU machine load on a CPU-only one. ``load_state_dict`` then copies
            the tensors onto whatever device the modules already live on.

    Returns:
        (encoder, decoder, training_loss, validation_loss)
    """
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_state_dict'])
    training_loss = checkpoint['training_loss']
    validation_loss = checkpoint['validation_loss']

    return encoder, decoder, training_loss, validation_loss


In [17]:
# Resume from an existing checkpoint if one is present; otherwise start fresh.
# NOTE: only the encoder/decoder weights and loss lists are restored. The optimizer state
# is not part of the checkpoint, and `train` starts a new loss history at epoch 0.
CHECKPOINT = './model/image_captioning_model_v8.pth'
if Path(CHECKPOINT).exists():
    encoder, decoder, training_loss, validation_loss = load_model(encoder, decoder, CHECKPOINT)
else:
    print(f'{CHECKPOINT} file does not exist, training starting from scratch')

./model/image_captioning_model_v8.pth file does not exist, training startging from scratch


In [18]:
def _steps_per_epoch(loader):
    """Number of batches needed to cover ``loader``'s captions once, rounded up."""
    return math.ceil(len(loader.dataset.caption_lengths) / loader.batch_sampler.batch_size)


def _sample_same_length_batch(loader):
    """Draw one batch whose captions share a length and move it to the global ``device``.

    The dataset's ``get_train_indices()`` picks the indices. They are installed as the
    loader's sampler (a ``SubsetRandomSampler``) before a fresh iterator is created, so
    the next batch is drawn from those indices.
    """
    indices = loader.dataset.get_train_indices()
    loader.batch_sampler.sampler = data.sampler.SubsetRandomSampler(indices=indices)
    _, images, captions = next(iter(loader))
    return images.to(device), captions.to(device)


def _caption_loss(criterion, encoder, decoder, images, captions):
    """Run encoder + decoder on a batch and score every predicted token against ``captions``."""
    outputs = decoder(encoder(images), captions)
    # Flatten predictions to (num_tokens, vocab) and targets to (num_tokens,). The vocab
    # width is read from the decoder output rather than from a global.
    return criterion(outputs.reshape(-1, outputs.size(-1)), captions.reshape(-1))


train_step = _steps_per_epoch(train_loader)
val_step = _steps_per_epoch(val_loader)


def train(encoder, decoder, criterion, optimizer, train_loader, val_loader, total_epoch,
          checkpoint_path, log_every=100):
    """Train ``encoder``/``decoder`` for ``total_epoch`` epochs, validating after each one.

    Every step draws a fresh batch of equal-length captions, so an "epoch" is
    ``ceil(#captions / batch_size)`` randomly sampled batches rather than a strict pass
    over the data. Step counts are derived from the loaders passed in. Both models are
    moved to the module-level ``device``.

    After each epoch, the checkpoint at ``checkpoint_path`` is overwritten via
    ``save_model``. A summary line is then printed with losses, perplexities and the
    cumulative minutes since training started.

    Args:
        encoder, decoder: models to train (moved to ``device``).
        criterion: token-level loss, e.g. ``nn.CrossEntropyLoss``.
        optimizer: optimizer over the trainable parameters.
        train_loader, val_loader: loaders whose datasets provide ``get_train_indices()``
            and ``caption_lengths``.
        total_epoch: number of epochs to run.
        checkpoint_path: file the checkpoint is written to after every epoch.
        log_every: print the current step loss every ``log_every`` steps (must be > 0).

    Returns:
        (training_loss, validation_loss): per-epoch mean losses for this call only.
        Histories from a previously loaded checkpoint are not continued.
    """
    encoder.to(device)
    decoder.to(device)

    n_train_steps = _steps_per_epoch(train_loader)
    n_val_steps = _steps_per_epoch(val_loader)

    training_loss = []
    validation_loss = []

    start_time = time.time()
    for epoch in range(total_epoch):
        # ---- training phase ----
        encoder.train()
        decoder.train()
        train_epoch_loss = 0.0
        for i_step in range(n_train_steps):
            images, captions = _sample_same_length_batch(train_loader)

            encoder.zero_grad()
            decoder.zero_grad()
            loss = _caption_loss(criterion, encoder, decoder, images, captions)
            loss.backward()
            optimizer.step()

            train_epoch_loss += loss.item()
            if i_step % log_every == 0:
                print('Training: {}  :{:.4f}'.format(i_step, loss.item()))
        train_epoch_loss /= n_train_steps
        training_loss.append(train_epoch_loss)

        # ---- validation phase: no autograd graph is needed, which saves memory ----
        encoder.eval()
        decoder.eval()
        val_epoch_loss = 0.0
        with torch.no_grad():
            for i_step in range(n_val_steps):
                images, captions = _sample_same_length_batch(val_loader)
                loss = _caption_loss(criterion, encoder, decoder, images, captions)

                val_epoch_loss += loss.item()
                if i_step % log_every == 0:
                    print('Validation: {}  :{:.4f}'.format(i_step, loss.item()))
        val_epoch_loss /= n_val_steps
        validation_loss.append(val_epoch_loss)

        # Cumulative minutes since training started (not the duration of this epoch).
        epoch_time = (time.time() - start_time) / 60

        save_model(epoch, encoder, decoder, training_loss, validation_loss, checkpoint_path)

        print("Epoch: {:d}. Training Loss = {:.4f}, Training Perplexity: {:.4f}. Validation Loss: {:.4f}, Validation Perplexity: {:.4f}. Time: {:f}"
              .format(epoch, train_epoch_loss, np.exp(train_epoch_loss), val_epoch_loss,
                      np.exp(val_epoch_loss), epoch_time))

    return training_loss, validation_loss
    

In [None]:
train_params = dict(
    encoder=encoder,
    decoder=decoder,
    criterion=criterion,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    total_epoch=10,             # number of epochs
    checkpoint_path=CHECKPOINT,
)

training_loss, validation_loss = train(**train_params)

Training: 0  :5.1558
Training: 100  :6.8887
Training: 200  :4.7070
Training: 300  :5.3139
Training: 400  :4.2956
Training: 500  :5.6933
Training: 600  :5.2031
Training: 700  :5.7740
Training: 800  :3.8395
Training: 900  :4.6064
Training: 1000  :3.4492
Training: 1100  :5.4491
Training: 1200  :6.1125
Training: 1300  :5.4897
Training: 1400  :6.6075
Training: 1500  :4.3259
Training: 1600  :4.6551
Training: 1700  :3.8328
Training: 1800  :3.4387
Training: 1900  :4.1924
Training: 2000  :5.0335
Training: 2100  :5.2689
Training: 2200  :3.6711
Training: 2300  :5.5258
Training: 2400  :4.7051
Training: 2500  :3.8897
Training: 2600  :5.2689
Training: 2700  :6.3200
Training: 2800  :5.6591
Training: 2900  :3.6919
Training: 3000  :6.2831
Training: 3100  :5.1877
Training: 3200  :6.0560
Training: 3300  :3.8607
Training: 3400  :5.2743
Training: 3500  :6.6732
Training: 3600  :7.4037
Training: 3700  :4.6636
Training: 3800  :5.5909
Training: 3900  :3.8991
Training: 4000  :3.4328
Training: 4100  :4.3386
Trai

In [None]:
# Reload the latest checkpoint (same logic as the earlier checkpoint cell) before
# continuing training. NOTE: only the encoder/decoder weights and loss lists are
# restored. The optimizer state is not part of the checkpoint.
CHECKPOINT = './model/image_captioning_model_v8.pth'
if Path(CHECKPOINT).exists():
    encoder, decoder, training_loss, validation_loss = load_model(encoder, decoder, CHECKPOINT)
else:
    print(f'{CHECKPOINT} file does not exist, training starting from scratch')

In [None]:
# Continue training from the weights restored above. `train` starts a fresh loss history
# (epoch numbering restarts at 0). Its per-epoch checkpoint overwrites CHECKPOINT with
# this run's history only.
(training_loss,
 validation_loss) = train(**train_params)