In [1]:
import os
import pickle
import numpy as np

In [2]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.nn.utils.rnn import pack_padded_sequence

In [3]:
from preprocesses import CustomCocoDataset, TokenIndex

# Define Encoder & Decoder

## CNN Encoder

In [17]:
class CNNEncoder(nn.Module):
    """Image encoder: a frozen, pretrained ResNet-50 followed by a trainable projection.

    The ResNet classification head is dropped. The remaining backbone output is
    flattened, projected to ``embedding_size`` by a linear layer and passed
    through 1-D batch normalisation. Only ``linear_layer`` and ``batch_norm``
    are meant to be trained; the backbone is run without gradients.

    Args:
        embedding_size: Dimensionality of the image embedding returned by ``forward``.
    """

    def __init__(self, embedding_size):
        super(CNNEncoder, self).__init__()
        # `weights=True` is the deprecated legacy spelling; name the weights explicitly
        # (IMAGENET1K_V1 is what the legacy boolean resolved to).
        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        layers = list(resnet.children())[:-1]  # drop the final fully connected layer
        self.resnet_base = nn.Sequential(*layers)
        self.linear_layer = nn.Linear(resnet.fc.in_features, embedding_size)
        self.batch_norm = nn.BatchNorm1d(embedding_size, momentum=0.01)
        # `torch.no_grad()` alone does not freeze BatchNorm running statistics;
        # keeping the backbone in eval mode does.
        self.resnet_base.eval()

    def train(self, mode=True):
        """Switch the trainable head between train/eval while keeping the backbone in eval mode.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super(CNNEncoder, self).train(mode)
        self.resnet_base.eval()
        return self

    def forward(self, img):
        """Extract feature vectors from input images.

        Args:
            img: Batch of images accepted by the ResNet backbone; the first
                dimension is the batch dimension.

        Returns:
            Tensor of shape ``(batch_size, embedding_size)``.
        """
        # The backbone is frozen: no gradients are recorded for it.
        with torch.no_grad():
            features = self.resnet_base(img)
        # Flatten everything after the batch dimension before the projection.
        features = features.reshape(features.size(0), -1)
        final_features = self.batch_norm(self.linear_layer(features))
        return final_features

## LSTM Decoder

In [31]:
class RNNModel(nn.Module):
    """LSTM caption decoder conditioned on an image embedding.

    The image embedding is fed to the LSTM as the first time step, followed by
    the embedded caption tokens. A linear layer maps every LSTM hidden state to
    vocabulary logits.

    Args:
        embedding_size: Size of the token embeddings (and of the image embedding).
        hidden_layer_size: Hidden size of the LSTM.
        vocabulary_size: Number of distinct tokens.
        num_layers: Number of stacked LSTM layers.
        max_seq_len: Number of tokens generated by ``sample``.
    """

    def __init__(self, embedding_size, hidden_layer_size, vocabulary_size, num_layers, max_seq_len=20):
        super().__init__()
        self.embedding_layer = nn.Embedding(vocabulary_size, embedding_size)
        self.lstm_layer = nn.LSTM(
            input_size=embedding_size,
            hidden_size=hidden_layer_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.linear_layer = nn.Linear(hidden_layer_size, vocabulary_size)
        self.max_seq_len = max_seq_len

    def forward(self, input_embeds, captions, lengths):
        """Compute vocabulary logits for every valid (non-padded) time step.

        Args:
            input_embeds: Image embeddings, ``(batch_size, embedding_size)``.
            captions: Padded token indices, ``(batch_size, padded_length)``.
            lengths: Valid length of each sequence, passed to
                ``pack_padded_sequence``.

        Returns:
            Logits for the packed time steps, ``(sum(lengths), vocabulary_size)``.
        """
        image_step = input_embeds.unsqueeze(1)
        word_steps = self.embedding_layer(captions)
        # Image embedding first, then the caption tokens, along the time axis.
        sequence = torch.cat((image_step, word_steps), dim=1)
        packed_inputs = pack_padded_sequence(sequence, lengths, batch_first=True)
        packed_hidden, _ = self.lstm_layer(packed_inputs)
        # `.data` holds the hidden states of the packed (padding-free) steps.
        return self.linear_layer(packed_hidden.data)

    def sample(self, input_embeds, rnn_states=None):
        """Greedily generate ``max_seq_len`` token indices per image.

        Args:
            input_embeds: Image embeddings, ``(batch_size, embedding_size)``.
            rnn_states: Optional initial LSTM state.

        Returns:
            Long tensor of token indices, ``(batch_size, max_seq_len)``.
        """
        step_input = input_embeds.unsqueeze(1)
        chosen = []
        for _ in range(self.max_seq_len):
            # step_hidden: (batch_size, 1, hidden_size)
            step_hidden, rnn_states = self.lstm_layer(step_input, rnn_states)
            logits = self.linear_layer(step_hidden.squeeze(1))
            best_tokens = torch.max(logits, dim=1).indices
            chosen.append(best_tokens)
            # The predicted token becomes the next input step.
            step_input = self.embedding_layer(best_tokens).unsqueeze(1)
        return torch.stack(chosen, dim=1)
        

# Model Training

## Preparation

Create the `models` directory, which will hold the checkpoints written during training.

In [32]:
# Create the checkpoint directory. `exist_ok=True` makes the cell safe to re-run
# and avoids the check-then-create race of a separate `os.path.exists` test.
os.makedirs("models", exist_ok=True)

Load the pre-built `TokenIndex` vocabulary instance that was pickled to `token_index.pkl`.

In [33]:
# Restore the vocabulary object that was pickled to `token_index.pkl`; its length
# is later used as the decoder's vocabulary size.
# NOTE(review): `pickle.load` can execute arbitrary code while unpickling, so only
# load this file if it comes from a trusted source.
with open('token_index.pkl', 'rb') as file:
    token2index = pickle.load(file)

# Bare last expression so the notebook displays the loaded object's repr.
token2index

<preprocesses.TokenIndex at 0x7e36c7d48a90>

Define the transformation pipeline applied to each image.

In [34]:
# Training-time image preprocessing: random 224x224 crop, random horizontal flip,
# conversion to a tensor, then per-channel normalisation.
# NOTE(review): the mean/std values are presumably the ImageNet statistics used
# with torchvision's pretrained models — confirm.
transform_pipeline = transforms.Compose(
    [
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)

In [35]:
# Build the training dataset from the resized images and the caption annotations.
# NOTE(review): argument meanings (image dir, annotation file, vocabulary,
# transform) are inferred from the values passed — confirm against `preprocesses`.
coco_dataset = CustomCocoDataset(
    "data/reshaped_train2014",
    "data/annotations/captions_train2014.json",
    token2index,
    transform_pipeline,
)
coco_dataset

loading annotations into memory...
Done (t=0.83s)
creating index...
index created!


<preprocesses.CustomCocoDataset at 0x7e361f873f10>

Define a `collate_fn` for the PyTorch **DataLoader**. In PyTorch, `collate_fn` is the function the DataLoader uses to merge a list of data samples (from the dataset) into a single batch. It is particularly useful when working with datasets that have variable-length inputs or that require custom preprocessing during batching.

In [36]:
def collate_function(batch):
    """Create mini-batch tensors from a list of (image, caption) tuples.

    A custom collate function is needed because the default ``collate_fn``
    cannot merge variable-length captions (which requires padding).

    Args:
        batch: list of tuples (image, caption).
            - image: torch tensor of shape (3, H, W); all images in the batch
              must share the same shape (224x224 after the notebook's
              ``RandomCrop(224)``).
            - caption: 1-D torch tensor of variable length.

    Returns:
        images: torch tensor of shape (batch_size, 3, H, W).
        targets: long tensor of shape (batch_size, padded_length), zero-padded.
        lengths: list; valid length of each padded caption, in descending order.
    """
    # Order samples by caption length (descending), as expected by
    # `pack_padded_sequence`. `sorted` leaves the caller's list untouched.
    ordered = sorted(batch, key=lambda pair: len(pair[1]), reverse=True)
    imgs, caps = zip(*ordered)

    # Merge images: tuple of 3-D tensors -> one 4-D tensor (batch_size, 3, H, W).
    imgs = torch.stack(imgs, 0)

    # Merge captions: tuple of 1-D tensors -> one zero-padded 2-D long tensor.
    cap_lens = [len(cap) for cap in caps]
    targets = torch.zeros(len(caps), max(cap_lens)).long()
    for row, (cap, length) in enumerate(zip(caps, cap_lens)):
        targets[row, :length] = cap
    return imgs, targets, cap_lens

In [37]:
# Shuffled mini-batches of 256 samples, merged with the custom collate function.
data_loader = torch.utils.data.DataLoader(
    dataset=coco_dataset,
    batch_size=256,
    shuffle=True,
    collate_fn=collate_function,
)
print(f"There are {len(data_loader)} batches in each epochs")

There are 1618 batches in each epochs


## Construct the Model

In [38]:
# Prefer the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Code type using for Pytorch: {device}")

Code type using for Pytorch: cuda


In [39]:
# Image encoder producing 256-dimensional embeddings, moved to the selected device.
encoder = CNNEncoder(embedding_size=256)
encoder = encoder.to(device)

In [40]:
# Single-layer LSTM decoder; its embedding size matches the encoder's output size.
decoder = RNNModel(
    embedding_size=256,
    hidden_layer_size=512,
    vocabulary_size=len(token2index),
    num_layers=1,
)
decoder = decoder.to(device)

In [41]:
# Token-level cross-entropy over the packed decoder outputs.
loss_criterion = nn.CrossEntropyLoss()
# Only the decoder and the encoder's projection head (linear + batch norm) are
# optimised; the ResNet backbone parameters are deliberately left out.
parameters = [
    *decoder.parameters(),
    *encoder.linear_layer.parameters(),
    *encoder.batch_norm.parameters(),
]
optimizer = torch.optim.Adam(parameters, lr=1e-3)

## Training Loop

In [43]:
batch_num = len(data_loader)
epoch_num = 10
model_path = os.path.join('models',)

# Make the cell self-sufficient: the checkpoint directory must exist, and both
# models must be in training mode even if an earlier cell switched them to eval().
os.makedirs(model_path, exist_ok=True)
encoder.train()
decoder.train()

for epoch in range(epoch_num):
    for id_, (imgs, caps, lens) in enumerate(data_loader):
        # Move the batch to the training device.
        imgs = imgs.to(device)
        caps = caps.to(device)
        # Flatten the padded captions into the same packed (padding-free) layout
        # that the decoder's output uses, so logits and targets line up.
        tgts = pack_padded_sequence(caps, lens, batch_first=True)[0]

        # Forward pass.
        features = encoder(imgs)
        outputs = decoder(features, caps, lens)
        losses = loss_criterion(outputs, tgts)

        # Backward pass and parameter update. The optimizer holds every trainable
        # parameter, so a single zero_grad() call on it is sufficient.
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        # Log every 100 steps; perplexity is exp(cross-entropy loss).
        if id_ % 100 == 0:
            loss_value = losses.item()  # read the scalar once per log line
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'.format(epoch, epoch_num, id_, batch_num, loss_value, np.exp(loss_value)))

    # Save one checkpoint pair per epoch.
    torch.save(decoder.state_dict(), os.path.join(model_path, 'decoder-{}.ckpt'.format(epoch)))
    torch.save(encoder.state_dict(), os.path.join(model_path, 'encoder-{}.ckpt'.format(epoch)))

Epoch [0/5], Step [0/1618], Loss: 8.9880, Perplexity: 8006.6793
Epoch [0/5], Step [100/1618], Loss: 3.7494, Perplexity: 42.4946
Epoch [0/5], Step [200/1618], Loss: 3.3279, Perplexity: 27.8801
Epoch [0/5], Step [300/1618], Loss: 3.0066, Perplexity: 20.2182
Epoch [0/5], Step [400/1618], Loss: 2.9582, Perplexity: 19.2630
Epoch [0/5], Step [500/1618], Loss: 2.6639, Perplexity: 14.3517
Epoch [0/5], Step [600/1618], Loss: 2.6400, Perplexity: 14.0135
Epoch [0/5], Step [700/1618], Loss: 2.6156, Perplexity: 13.6749
Epoch [0/5], Step [800/1618], Loss: 2.6118, Perplexity: 13.6235
Epoch [0/5], Step [900/1618], Loss: 2.6401, Perplexity: 14.0147
Epoch [0/5], Step [1000/1618], Loss: 2.3962, Perplexity: 10.9809
Epoch [0/5], Step [1100/1618], Loss: 2.4257, Perplexity: 11.3097
Epoch [0/5], Step [1200/1618], Loss: 2.4685, Perplexity: 11.8052
Epoch [0/5], Step [1300/1618], Loss: 2.4187, Perplexity: 11.2314
Epoch [0/5], Step [1400/1618], Loss: 2.2768, Perplexity: 9.7454
Epoch [0/5], Step [1500/1618], Loss:

Save the final decoder and encoder.

In [44]:
# Persist the current decoder and encoder weights under fixed file names,
# alongside the per-epoch checkpoints in `models/`. The encoder state dict
# includes the ResNet backbone, since it is a submodule of the encoder.
torch.save(decoder.state_dict(), os.path.join('models', 'decoder.ckpt'))
torch.save(encoder.state_dict(), os.path.join('models', 'encoder.ckpt'))