# The Plan

I want to use the encoder-decoder with attention architecture pattern:

* Encode the image into a rich visual representation
* Decode our representation into a sequence of words, one word at a time
* Use attention to let the decoder focus on different image regions for each word

Encoder will be a CNN with spatial feature map (7x7x512)
Decoder will be an LSTM with attention

In [16]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from collections import Counter
import matplotlib.pyplot as plt
import os
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
# Seed the torch and NumPy RNGs so repeated runs draw the same random numbers.
torch.manual_seed(42)
np.random.seed(42)

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    # Report the name and total memory (in units of 1e9 bytes) of GPU 0.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

Using device: cuda
GPU: NVIDIA GeForce GTX 1650
Memory: 4.29 GB


# Building A Vocabulary

In [13]:
class Vocabulary:
    """Two-way mapping between caption words and integer indices.

    The special tokens ``<PAD>``, ``<START>``, ``<END>`` and ``<UNK>`` are
    registered first, in that order, so they receive indices 0 through 3.
    """

    def __init__(self, freq_threshold=5):
        """Create a vocabulary holding only the four special tokens.

        Args:
            freq_threshold: Minimum number of occurrences a word needs in
                ``build_vocabulary`` to be added.
        """
        self.freq_threshold = freq_threshold
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0  # next free index

        for special in ("<PAD>", "<START>", "<END>", "<UNK>"):
            self.add_word(special)

    def add_word(self, word):
        """Register ``word`` under the next free index; no-op if already known."""
        if word in self.word2idx:
            return
        slot = self.idx
        self.word2idx[word] = slot
        self.idx2word[slot] = word
        self.idx = slot + 1

    def __len__(self):
        """Return the number of registered words, special tokens included."""
        return len(self.word2idx)

    def build_vocabulary(self, caption_list):
        """Count tokens over all captions and add the frequent ones.

        Words are added in order of first appearance, provided their total
        count is at least ``self.freq_threshold``.
        """
        counts = Counter()
        for sentence in caption_list:
            counts.update(self.tokenize(sentence))

        frequent = [w for w, n in counts.items() if n >= self.freq_threshold]
        for w in frequent:
            self.add_word(w)

        print(f"Vocabulary built with {len(self)} words")
        print(f"Words appearing >= {self.freq_threshold} times")

    @staticmethod
    def tokenize(text):
        """Lowercase ``text`` and split it on whitespace."""
        lowered = text.lower()
        return lowered.split()

    def numericalize(self, text):
        """Convert ``text`` to a list of indices wrapped in <START>/<END>.

        Tokens that are not in the vocabulary map to the ``<UNK>`` index.
        """
        lookup = self.word2idx
        unk = lookup["<UNK>"]
        body = [lookup.get(tok, unk) for tok in self.tokenize(text)]
        return [lookup["<START>"], *body, lookup["<END>"]]

# Building A Dataset

In [15]:
class FlickrDataset(Dataset):
    """Image/caption dataset backed by a captions CSV and an ``Images`` folder.

    The CSV is read with pandas and must provide an ``image`` column (file
    name relative to ``<root_dir>/Images``) and a ``caption`` column.
    """

    def __init__(self, root_dir, captions_file, vocab=None, transform=None,
                 build_vocab=False, freq_threshold=5):
        """Load the captions table and attach or build a vocabulary.

        Args:
            root_dir: Directory that contains the ``Images`` sub-folder.
            captions_file: Path of the CSV file with the captions.
            vocab: Existing vocabulary to use when ``build_vocab`` is False.
            transform: Optional callable applied to each loaded PIL image.
            build_vocab: When True, build a new vocabulary from the captions
                in ``captions_file`` and ignore ``vocab``.
            freq_threshold: Minimum word frequency used when building the
                vocabulary (only relevant when ``build_vocab`` is True).
        """
        self.root_dir = root_dir
        self.transform = transform

        self.df = pd.read_csv(captions_file)
        print(f"Loaded {len(self.df)} image-caption pairs")

        if build_vocab:
            self.vocab = Vocabulary(freq_threshold=freq_threshold)
            self.vocab.build_vocabulary(self.df['caption'].tolist())
        else:
            self.vocab = vocab

    def __len__(self):
        """Return the number of image-caption pairs."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, caption_indices)`` for row ``idx``.

        The image is converted to RGB and passed through ``self.transform``
        when one is set; the caption is numericalized by the vocabulary and
        returned as a tensor.
        """
        # Fetch the row once instead of one positional lookup per column.
        row = self.df.iloc[idx]
        img_path = os.path.join(self.root_dir, 'Images', row['image'])

        # The context manager closes the underlying file deterministically;
        # convert() returns an independent in-memory copy.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)

        numericalized_caption = self.vocab.numericalize(row['caption'])

        return image, torch.tensor(numericalized_caption)

# Adding Custom Collate

In [17]:
class CustomCollate:
    """Collate function that batches images and pads captions to equal length."""

    def __init__(self, pad_idx):
        """Remember the index used to fill the padded caption positions."""
        self.pad_idx = pad_idx

    def __call__(self, batch):
        """Turn a list of ``(image, caption)`` samples into batched tensors.

        Returns:
            A tuple ``(images, captions, lengths)``: the images stacked along
            a new first dimension, the captions padded with ``pad_idx`` into
            a batch-first tensor, and the list of unpadded caption lengths.
        """
        image_tensors = []
        caption_tensors = []
        for sample in batch:
            image_tensors.append(sample[0])
            caption_tensors.append(sample[1])

        stacked_images = torch.stack(image_tensors, dim=0)

        # Lengths are taken before padding so they reflect the real captions.
        lengths = list(map(len, caption_tensors))

        padded_captions = pad_sequence(
            caption_tensors,
            batch_first=True,
            padding_value=self.pad_idx,
        )

        return stacked_images, padded_captions, lengths


# CNN Architecture

In [None]:
class CNNEncoder(nn.Module):
    """Convolutional encoder that turns an image into a grid of region features.

    Five conv blocks (each ending in a 2x2 max-pool) reduce the spatial size
    by a factor of 32 and raise the channel count to 512. The feature map is
    then adaptively average-pooled to ``encoded_image_size`` x
    ``encoded_image_size`` and flattened to a sequence of region vectors for
    the attention module.
    """

    def __init__(self, encoded_image_size=7):
        """Build the conv blocks and initialise their weights.

        Args:
            encoded_image_size: Side length of the output feature grid. With
                the default of 7 and 224x224 inputs the adaptive pooling is
                an identity, because the conv blocks already produce 7x7.
        """
        super(CNNEncoder, self).__init__()

        self.enc_image_size = encoded_image_size

        self.block1 = self._conv_block(3, 64, num_convs=2)     # 224 -> 112
        self.block2 = self._conv_block(64, 128, num_convs=2)   # 112 -> 56
        self.block3 = self._conv_block(128, 256, num_convs=2)  # 56 -> 28
        self.block4 = self._conv_block(256, 512, num_convs=2)  # 28 -> 14
        self.block5 = self._conv_block(512, 512, num_convs=1)  # 14 -> 7

        # Makes `encoded_image_size` effective and fixes the output grid for
        # any input resolution. The layer has no parameters or buffers, so
        # state_dict keys are unchanged.
        self.adaptive_pool = nn.AdaptiveAvgPool2d(
            (encoded_image_size, encoded_image_size)
        )

        self._init_weights()

    @staticmethod
    def _conv_block(in_channels, out_channels, num_convs):
        """Return ``num_convs`` x (3x3 conv -> BatchNorm -> ReLU) plus a 2x2 max-pool."""
        layers = []
        channels = in_channels
        for _ in range(num_convs):
            layers.append(nn.Conv2d(channels, out_channels, kernel_size=3, padding=1))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU(inplace=True))
            channels = out_channels
        layers.append(nn.MaxPool2d(2, 2))
        return nn.Sequential(*layers)

    def _init_weights(self):
        """Kaiming-normal init for conv weights; unit scale / zero shift for BatchNorm."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, images):
        """Encode a batch of images.

        Args:
            images: Tensor of shape (batch, 3, H, W). The shape comments
                below assume H = W = 224.

        Returns:
            Tensor of shape (batch, encoded_image_size ** 2, 512).
        """
        x = self.block1(images)  # (batch, 64, 112, 112)
        x = self.block2(x)       # (batch, 128, 56, 56)
        x = self.block3(x)       # (batch, 256, 28, 28)
        x = self.block4(x)       # (batch, 512, 14, 14)
        x = self.block5(x)       # (batch, 512, 7, 7)
        x = self.adaptive_pool(x)  # (batch, 512, S, S), S = encoded_image_size

        # For attention: (batch, 512, S, S) -> (batch, S*S, 512)
        x = x.permute(0, 2, 3, 1)  # (batch, S, S, 512)
        # Use the actual channel size rather than a hard-coded 512.
        x = x.reshape(x.size(0), -1, x.size(-1))

        return x


# Adding The Bahdanau Attention

In [19]:
class BahdanauAttention(nn.Module):
    """Additive attention over encoder region features.

    Encoder features and the decoder hidden state are projected into a shared
    attention space, summed, passed through ReLU and scored by a linear layer;
    a softmax over the regions gives the attention weights.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        """Create the projection and scoring layers.

        Args:
            encoder_dim: Feature size of each encoder region vector.
            decoder_dim: Size of the decoder hidden state.
            attention_dim: Size of the shared attention space.
        """
        super(BahdanauAttention, self).__init__()

        # Projections into the shared attention space.
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)

        # Collapses each attention vector to one scalar score.
        self.full_att = nn.Linear(attention_dim, 1)

        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, encoder_out, decoder_hidden):
        """Compute the attention-weighted context vector.

        Args:
            encoder_out: Region features; the indexing below implies shape
                (batch, num_regions, encoder_dim).
            decoder_hidden: Decoder state; the ``unsqueeze(1)`` below implies
                shape (batch, decoder_dim).

        Returns:
            ``(context, alpha)``: the weighted sum of ``encoder_out`` over the
            region dimension, and the weights themselves (softmax over dim 1).
        """
        projected_regions = self.encoder_att(encoder_out)
        projected_state = self.decoder_att(decoder_hidden).unsqueeze(1)

        # The state projection is broadcast across all regions before scoring.
        combined = self.relu(projected_regions + projected_state)
        scores = self.full_att(combined).squeeze(2)

        alpha = self.softmax(scores)
        weighted = encoder_out * alpha.unsqueeze(2)
        context = weighted.sum(dim=1)

        return context, alpha