# Description

In this notebook, we train an image captioning model:
- Dataset: Flickr8k

In [None]:
import os 
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt
%matplotlib inline
from collections import Counter
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models

from data_utils import *

# 1. Prepare dataset

## 1.1. Load dataset

In [2]:
# CSV file with the image captions (loaded with pd.read_csv in the next cell).
PATH_FILE_CAPTION = 'data/captions.txt'
# Folder containing the images (passed to FlickrDataset as path_folder_images).
PATH_FOLDER_IMAGES = 'data/Images/'

In [None]:
# Load the caption table from disk.
df = pd.read_csv(PATH_FILE_CAPTION)

# Build the vocabulary from every caption string in the table.
all_captions = df["caption"].values
vocab = Vocabulary(freq_threshold=1)
vocab.build_vocab(all_captions)

print(f"Number of words in the vocab: {len(vocab)}")

In [None]:
# Image preprocessing: resize slightly larger than the target, take a random
# 224x224 crop, convert to a tensor and normalise each channel.
preprocessing_steps = [
    transforms.Resize((226, 226)),
    transforms.RandomCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(preprocessing_steps)

# Wrap the caption table and image folder in the custom dataset.
dataset = FlickrDataset(
    df,
    vocab,
    path_folder_images=PATH_FOLDER_IMAGES,
    transform=transform,
)

# Visual sanity check of one sample.
display_random_image(dataset)

## 1.2. Prepare the Data Loader

In [5]:
# Mini-batch size and number of loader worker processes.
BATCH_SIZE = 128
BATH_SIZE = BATCH_SIZE  # backward-compatible alias for the original (misspelled) name
NUM_WORKERS = 8

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Index of the <PAD> token; handed to the collate function that batches captions.
padding_value = dataset.vocab.str_2_int['<PAD>']
data_loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=True,
    collate_fn=Collate(padding_value),
)



# 2. Define model architecture

- Model: seq2seq model. 
- Encoder: pretrained MobileNetV3-Small model.
- Decoder: Bahdanau Attention and LSTM cell.

In [6]:
class EncoderCNN(nn.Module):
    """Frozen MobileNetV3-Small backbone that maps images to a sequence of region features.

    The backbone's parameters are frozen (``requires_grad=False``), so it acts
    as a fixed feature extractor; only the decoder is trained.
    """

    def __init__(self):
        super(EncoderCNN, self).__init__()
        # Load ImageNet-pretrained weights. The parameters are frozen right
        # below, so a randomly initialised backbone would stay random forever
        # and give the decoder meaningless features.
        weights_enum = getattr(models, "MobileNet_V3_Small_Weights", None)
        if weights_enum is not None:
            # torchvision >= 0.13 "weights" API
            base_model = models.mobilenet_v3_small(weights=weights_enum.DEFAULT)
        else:
            # older torchvision releases
            base_model = models.mobilenet_v3_small(pretrained=True)

        for param in base_model.parameters():
            param.requires_grad_(False)

        # Drop the last two child modules so only the convolutional feature
        # extractor remains and the spatial feature map is preserved.
        modules = list(base_model.children())[:-2]
        self.base_model = nn.Sequential(*modules)

    def train(self, mode=True):
        """Switch train/eval mode but always keep the frozen backbone in eval mode.

        Without this, the BatchNorm layers of the frozen backbone would use
        batch statistics and overwrite their running statistics during training.
        """
        super().train(mode)
        self.base_model.eval()
        return self

    def forward(self, images):
        """Encode a batch of images.

        Args:
            images: image batch, e.g. (batch_size, 3, 224, 224).

        Returns:
            Tensor of shape (batch_size, num_regions, channels); for 224x224
            inputs this is (batch_size, 49, 576).
        """
        features = self.base_model(images)                                    # (batch_size, 576, 7, 7)
        features = features.permute(0, 2, 3, 1)                               # (batch_size, 7, 7, 576)
        # reshape (not view): the permuted tensor is non-contiguous.
        features = features.reshape(features.size(0), -1, features.size(-1))  # (batch_size, 49, 576)
        return features
    
    
def calculate_number_parameters(model):
    """Return the total number of scalar elements over all parameters of ``model``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

In [None]:
# Smoke test: push a dummy all-zero batch through the encoder and inspect the result.
test_image = torch.zeros(32, 3, 224, 224)
encoder = EncoderCNN()
test_output_encoder = encoder(test_image)

print(f"Test Output Shape: {test_output_encoder.shape}")
print(f"Number of parameters in the model: {calculate_number_parameters(encoder)}")

In [8]:
#Bahdanau Attention
class Attention(nn.Module):
    """Additive (Bahdanau-style) attention over a sequence of encoder features.

    Args:
        encoder_dim: size of each encoder feature vector.
        decoder_dim: size of the decoder hidden state.
        attention_dim: size of the shared space both are projected into.
    """

    def __init__(self, encoder_dim,decoder_dim,attention_dim):
        super().__init__()

        self.attention_dim = attention_dim

        # W projects the decoder state, U projects the encoder features,
        # A turns each combined vector into a single score.
        self.W = nn.Linear(decoder_dim, attention_dim)
        self.U = nn.Linear(encoder_dim, attention_dim)
        self.A = nn.Linear(attention_dim, 1)

    def forward(self, features, hidden_state):
        """Compute attention weights and the attended context vector.

        Args:
            features: (batch_size, num_features, encoder_dim)
            hidden_state: (batch_size, decoder_dim)

        Returns:
            alpha: (batch_size, num_features) weights that sum to 1 over dim 1.
            context: (batch_size, encoder_dim) weighted sum of ``features``.
        """
        projected_features = self.U(features)                    # (batch_size, num_features, attention_dim)
        projected_hidden = self.W(hidden_state)[:, None, :]      # (batch_size, 1, attention_dim)

        energy = torch.tanh(projected_features + projected_hidden)  # (batch_size, num_features, attention_dim)
        scores = self.A(energy).squeeze(2)                       # (batch_size, num_features)

        alpha = F.softmax(scores, dim=1)                         # (batch_size, num_features)

        # Weight every feature vector by its attention and sum them up.
        context = (features * alpha.unsqueeze(2)).sum(dim=1)     # (batch_size, encoder_dim)

        return alpha, context

In [9]:
#Attention Decoder
class DecoderRNN(nn.Module):
    """Attention-based LSTM decoder that predicts caption tokens from encoder features.

    Args:
        embed_size: dimension of the word embeddings.
        vocab_size: number of tokens in the vocabulary (size of the output layer).
        attention_dim: hidden size of the attention module.
        encoder_dim: size of each encoder feature vector.
        decoder_dim: size of the LSTM hidden/cell state.
        drop_prob: dropout probability applied before the output layer.
    """

    def __init__(self, embed_size, vocab_size, attention_dim,encoder_dim,decoder_dim,drop_prob=0.3):
        super().__init__()

        # save the model params
        self.vocab_size = vocab_size
        self.attention_dim = attention_dim
        self.decoder_dim = decoder_dim

        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)

        # Linear maps from the mean encoder feature to the initial LSTM state.
        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        # Each step consumes the word embedding concatenated with the context vector.
        self.lstm_cell = nn.LSTMCell(embed_size + encoder_dim, decoder_dim, bias=True)
        # Not used by forward()/generate_caption(); kept so the parameter set
        # (and therefore saved state_dicts) stays unchanged.
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)

        self.fcn = nn.Linear(decoder_dim, vocab_size)
        self.drop = nn.Dropout(drop_prob)

    def forward(self, features, captions):
        """Teacher-forced decoding.

        Args:
            features: (batch_size, num_features, encoder_dim) encoder output.
            captions: (batch_size, caption_len) token indices.

        Returns:
            preds: (batch_size, caption_len - 1, vocab_size) scores for the next token.
            alphas: (batch_size, caption_len - 1, num_features) attention weights.
        """
        embeds = self.embedding(captions)

        # Initialize LSTM state
        h, c = self.init_hidden_state(features)  # (batch_size, decoder_dim)

        # The last token is never fed as input: there is nothing to predict after it.
        seq_length = captions.size(1) - 1
        batch_size = captions.size(0)
        num_features = features.size(1)

        # Allocate on the same device/dtype as the inputs instead of relying on
        # a notebook-global `device`, so the module works wherever it lives.
        preds = features.new_zeros(batch_size, seq_length, self.vocab_size)
        alphas = features.new_zeros(batch_size, seq_length, num_features)

        for s in range(seq_length):
            alpha, context = self.attention(features, h)
            lstm_input = torch.cat((embeds[:, s], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))

            output = self.fcn(self.drop(h))

            preds[:, s] = output
            alphas[:, s] = alpha

        return preds, alphas

    def generate_caption(self, features, max_len=20, vocab=None):
        """Greedy decoding for a single image.

        Args:
            features: (1, num_features, encoder_dim) encoder output for one image.
            max_len: maximum number of tokens to generate.
            vocab: object exposing ``str_2_int`` and ``int_2_str`` mappings; required.

        Returns:
            (words, alphas): the generated tokens as strings (including a final
            "<EOS>" if one was produced) and one attention array per step.

        Raises:
            ValueError: if ``vocab`` is missing or ``features`` holds more than one image.

        Note:
            Dropout is applied as in training unless the module is in eval mode,
            so call ``model.eval()`` before generating.
        """
        if vocab is None:
            raise ValueError("generate_caption requires a vocab with str_2_int / int_2_str mappings")

        batch_size = features.size(0)
        if batch_size != 1:
            raise ValueError(f"generate_caption supports a single image, got batch size {batch_size}")

        h, c = self.init_hidden_state(features)  # (batch_size, decoder_dim)

        alphas = []

        # starting input: the <SOS> token, created on the same device as the features
        word = torch.tensor([[vocab.str_2_int['<SOS>']]], device=features.device)
        embeds = self.embedding(word)

        captions = []

        for _ in range(max_len):
            alpha, context = self.attention(features, h)

            # store the alpha scores
            alphas.append(alpha.cpu().detach().numpy())

            lstm_input = torch.cat((embeds[:, 0], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            output = self.fcn(self.drop(h))
            output = output.view(batch_size, -1)

            # select the highest-scoring word
            predicted_word_idx = output.argmax(dim=1)
            token = predicted_word_idx.item()

            # save the generated word
            captions.append(token)

            # stop once <EOS> is produced
            if vocab.int_2_str[token] == "<EOS>":
                break

            # feed the generated word back in as the next input
            embeds = self.embedding(predicted_word_idx.unsqueeze(0))

        # convert the vocab indices to words and return the sentence
        return [vocab.int_2_str[idx] for idx in captions], alphas

    def init_hidden_state(self, encoder_out):
        """Build the initial LSTM state from the mean of the encoder features."""
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)  # (batch_size, decoder_dim)
        c = self.init_c(mean_encoder_out)
        return h, c

In [None]:
# Smoke test: run the decoder on the encoder's dummy output with random token ids.
decoder = DecoderRNN(
    embed_size=128,
    vocab_size=len(dataset.vocab),
    attention_dim=256,
    encoder_dim=576,
    decoder_dim=512,
)

dummy_captions = torch.randint(0, 100, (32, 16))
test_output_decoder, _ = decoder(test_output_encoder, dummy_captions)

print(f"Test Output Shape: {test_output_decoder.shape}")
print(f"Number of parameters in decoder: {calculate_number_parameters(decoder)}")

In [11]:
class EncoderDecoder(nn.Module):
    """Full captioning model: an ``EncoderCNN`` followed by a ``DecoderRNN``.

    The constructor arguments are forwarded unchanged to ``DecoderRNN``.
    """

    def __init__(self,embed_size, vocab_size, attention_dim, encoder_dim, decoder_dim):
        super().__init__()
        self.encoder = EncoderCNN()

        decoder_kwargs = {
            "embed_size": embed_size,
            "vocab_size": vocab_size,
            "attention_dim": attention_dim,
            "encoder_dim": encoder_dim,
            "decoder_dim": decoder_dim,
        }
        self.decoder = DecoderRNN(**decoder_kwargs)

    def forward(self, images, captions):
        """Encode ``images`` and return whatever the decoder produces for ``captions``."""
        return self.decoder(self.encoder(images), captions)

# 3. Training 

In [12]:
# Hyperparameters (also read by save_model when writing a checkpoint)
embed_size = 128
vocab_size = len(dataset.vocab)
attention_dim = 256
encoder_dim = 576
decoder_dim = 256
learning_rate = 3e-4

# Build the model and move it to the selected device.
model = EncoderDecoder(
    embed_size=embed_size,
    vocab_size=vocab_size,
    attention_dim=attention_dim,
    encoder_dim=encoder_dim,
    decoder_dim=decoder_dim,
)
model = model.to(device)

# Padding positions are excluded from the loss.
pad_index = dataset.vocab.str_2_int["<PAD>"]
criterion = nn.CrossEntropyLoss(ignore_index=pad_index)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Report the parameter count of the full encoder-decoder model.
print(f"Number of parameters in the model: {calculate_number_parameters(model)}")

In [14]:
def save_model(model, num_epochs, path='attention_model_state.pth'):
    """
    Save a checkpoint containing the model weights and its hyperparameters.

    The hyperparameters are read from the notebook-level variables
    ``embed_size``, ``attention_dim``, ``encoder_dim``, ``decoder_dim`` and
    from ``dataset.vocab``.

    Args:
        model: module whose ``state_dict()`` is stored.
        num_epochs: number of epochs completed so far (stored as metadata).
        path: destination file; defaults to the original hard-coded location.
    """
    model_state = {
        'num_epochs':num_epochs,
        'embed_size':embed_size,
        'vocab_size':len(dataset.vocab),
        'attention_dim':attention_dim,
        'encoder_dim':encoder_dim,
        'decoder_dim':decoder_dim,
        'state_dict':model.state_dict()
    }

    # Write to a temporary file first and then swap it in, so an interrupted
    # save never leaves a truncated checkpoint at `path`.
    tmp_path = f"{path}.tmp"
    torch.save(model_state, tmp_path)
    os.replace(tmp_path, path)

In [15]:
def show_image(img, title=None):
    """Display a normalised (3, H, W) image tensor with matplotlib.

    The per-channel normalisation (mean 0.485/0.456/0.406, std
    0.229/0.224/0.225) is undone on a copy, so the caller's tensor is left
    untouched and the function can be called repeatedly on the same image.

    Args:
        img: image tensor of shape (3, H, W); may live on any device.
        title: optional figure title.
    """
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    # Un-normalise out of place on a detached CPU copy (the original code
    # modified `img` in place, corrupting the caller's tensor).
    img = img.detach().cpu().to(torch.float32) * std + mean
    # Keep values in the displayable range; avoids matplotlib clipping warnings.
    img = img.clamp(0.0, 1.0)

    img = img.numpy().transpose((1, 2, 0))  # (H, W, 3) for imshow

    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated

In [None]:
num_epochs = 100
print_every = 3

for epoch in range(1, num_epochs+1):
    # Accumulate the loss over the epoch so the logged value is the epoch mean
    # rather than the (noisy) loss of the final mini-batch.
    running_loss = 0.0
    num_batches = 0

    for image, captions in data_loader:
        image, captions = image.to(device), captions.to(device)

        optimizer.zero_grad()

        outputs, attentions = model(image, captions)

        # The decoder predicts token t+1 from token t, so the targets are the
        # captions without the leading <SOS> token.
        targets = captions[:, 1:]
        loss = criterion(outputs.reshape(-1, vocab_size), targets.reshape(-1))

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    if epoch % print_every == 0:
        # max(..., 1) keeps this safe if the loader yielded no batches.
        mean_loss = running_loss / max(num_batches, 1)
        print("Epoch: {} loss: {:.5f}".format(epoch, mean_loss))

        # generate a caption for one image as a qualitative check
        model.eval()
        with torch.no_grad():
            dataiter = iter(data_loader)
            img,_ = next(dataiter)
            features = model.encoder(img[0:1].to(device))
            caps,alphas = model.decoder.generate_caption(features,vocab=dataset.vocab)
            caption = ' '.join(caps)
            show_image(img[0],title=caption)

        model.train()

    # save the latest model
    save_model(model, epoch)