# The Plan

I want to use the encoder-decoder with attention architecture pattern:

* Encode the image into a rich visual representation
* Decode our representation into a sequence of words, one word at a time
* Use attention to let the decoder focus on different image regions for each word

The encoder will be a CNN that outputs a spatial feature map (7x7x512), i.e. 49 region vectors of 512 dimensions each.

The decoder will be an LSTM with attention over those 49 region vectors.

In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from collections import Counter
import matplotlib.pyplot as plt
import os
from PIL import Image
from tqdm import tqdm
# Fix the NumPy and PyTorch RNG seeds so runs are reproducible.
np.random.seed(42)
torch.manual_seed(42)

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU,
# and report which hardware this session is running on.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

Using device: cuda
GPU: NVIDIA GeForce GTX 1650
Memory: 4.29 GB


# Building A Vocabulary

In [13]:
class Vocabulary:
    """Map caption tokens to integer ids and back.

    The constructor registers four special tokens first, so they always
    hold ids 0-3: ``<PAD>`` = 0, ``<START>`` = 1, ``<END>`` = 2,
    ``<UNK>`` = 3. Corpus words are appended after them by
    :meth:`build_vocabulary`.

    Args:
        freq_threshold: Minimum number of occurrences a token needs in the
            captions given to :meth:`build_vocabulary` to be added.
    """

    def __init__(self, freq_threshold=5):
        self.freq_threshold = freq_threshold
        self.word2idx = {}  # token -> id
        self.idx2word = {}  # id -> token
        self.idx = 0        # next id to assign; add_word keeps it == len(word2idx)

        for special_token in ("<PAD>", "<START>", "<END>", "<UNK>"):
            self.add_word(special_token)

    def add_word(self, word):
        """Register *word* under the next free id; do nothing if it is known."""
        if word in self.word2idx:
            return
        new_id = self.idx
        self.word2idx[word] = new_id
        self.idx2word[new_id] = word
        self.idx = new_id + 1

    def __len__(self):
        """Return the number of known tokens, special tokens included."""
        return len(self.word2idx)

    def build_vocabulary(self, caption_list):
        """Count tokens across *caption_list* and add every frequent one.

        A token is added when it occurs at least ``freq_threshold`` times.
        Tokens are added in first-seen order, so the ids are deterministic
        for a given caption order. Prints a short summary when done.
        """
        counts = Counter(
            token
            for caption in caption_list
            for token in self.tokenize(caption)
        )

        frequent = [tok for tok, n in counts.items() if n >= self.freq_threshold]
        for token in frequent:
            self.add_word(token)

        print(f"Vocabulary built with {len(self)} words")
        print(f"Words appearing >= {self.freq_threshold} times")

    @staticmethod
    def tokenize(text):
        """Lower-case *text* and split it on whitespace.

        Punctuation is not stripped, so ``"dog."`` and ``"dog"`` are
        distinct tokens.
        """
        lowered = text.lower()
        return lowered.split()

    def numericalize(self, text):
        """Convert *text* to a list of ids wrapped in ``<START>`` ... ``<END>``.

        Tokens missing from the vocabulary map to the ``<UNK>`` id.
        """
        lookup = self.word2idx
        unknown_id = lookup["<UNK>"]
        body = [lookup.get(token, unknown_id) for token in self.tokenize(text)]
        return [lookup["<START>"], *body, lookup["<END>"]]

# Building A Dataset

In [15]:
class FlickrDataset(Dataset):
    """Image/caption pairs read from a captions CSV.

    Each CSV row is one sample. The ``image`` column holds the image file
    name and the ``caption`` column holds the caption text. Images are loaded
    from ``<root_dir>/Images/<image>``.

    Args:
        root_dir: Directory containing the ``Images`` sub-folder.
        captions_file: Path to the CSV with ``image`` and ``caption`` columns.
        vocab: Prebuilt ``Vocabulary``, used when ``build_vocab`` is False.
        transform: Optional callable applied to each RGB ``PIL.Image``.
        build_vocab: If True, build a new ``Vocabulary`` from every caption
            in the CSV; ``vocab`` is then ignored.
        freq_threshold: Minimum word count used when ``build_vocab`` is True
            (defaults to 5, the previously hard-coded value).
    """

    def __init__(self, root_dir, captions_file, vocab=None, transform=None,
                 build_vocab=False, freq_threshold=5):
        self.root_dir = root_dir
        self.transform = transform

        self.df = pd.read_csv(captions_file)
        print(f"Loaded {len(self.df)} image-caption pairs")

        if build_vocab:
            # NOTE(review): this vocabulary sees every caption in the CSV. If the
            # same file later backs a validation/test split, build the vocab from
            # the training rows only and pass it in via `vocab`.
            self.vocab = Vocabulary(freq_threshold=freq_threshold)
            self.vocab.build_vocabulary(self.df['caption'].tolist())
        else:
            self.vocab = vocab

    def __len__(self):
        """Return the number of image-caption rows in the CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, caption_ids)`` for the row at position *idx*.

        ``image`` is whatever ``transform`` returns, or an RGB ``PIL.Image``
        when no transform is set. ``caption_ids`` is a 1-D ``torch.long``
        tensor produced by ``self.vocab.numericalize``.

        NOTE: caption tensors differ in length between rows, so batching
        them needs a ``collate_fn`` that pads (``<PAD>`` is id 0 in
        ``Vocabulary``).
        """
        row = self.df.iloc[idx]  # one positional lookup instead of two
        img_path = os.path.join(self.root_dir, 'Images', row['image'])

        # Release the file handle as soon as the pixels are decoded.
        # convert() returns a new image that stays valid after the source closes.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)

        caption_ids = self.vocab.numericalize(row['caption'])
        return image, torch.tensor(caption_ids, dtype=torch.long)