# Download the Flickr8k dataset from Kaggle
First create a token, download it and upload it here. Follow these steps: https://www.kaggle.com/discussions/general/74235

In [None]:
!pip install -q kaggle
!pip install timm

Collecting timm
  Downloading timm-0.9.12-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: timm
Successfully installed timm-0.9.12


### Upload your Kaggle API Token

In [None]:
from google.colab import files

# Bind the result instead of leaving `files.upload()` as the last expression of
# the cell: the returned dict maps file names to the raw file bytes, so echoing
# it prints the Kaggle username and API key into the saved notebook output.
uploaded = files.upload()

Saving kaggle.json to kaggle.json


{'kaggle.json': b'{"username":"<redacted>","key":"<redacted>"}'}

### Download Script for the Flickr8k dataset

In [None]:
!mkdir ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

!kaggle datasets download -d sayanf/flickr8k

Downloading flickr8k.zip to /content
100% 1.04G/1.04G [00:05<00:00, 192MB/s]
100% 1.04G/1.04G [00:05<00:00, 186MB/s]


### Organizing annotations and images into separate folders

In [None]:
# Extract the archive quietly, then rename the extracted folders to the names
# the rest of the notebook reads from (Flickr8k_images / Flickr8k_annotations).
!unzip -q flickr8k.zip
!mv Flickr8k_Dataset Flickr8k_images
!mv Flickr8k_text Flickr8k_annotations

### Downloading the GLOVE Embeddings

In [None]:
# Fetch the GloVe 6B archive (822M according to the recorded download log).
!wget https://nlp.stanford.edu/data/glove.6B.zip

--2023-12-14 14:45:58--  https://nlp.stanford.edu/data/glove.6B.zip
Resolving nlp.stanford.edu (nlp.stanford.edu)... 171.64.67.140
Connecting to nlp.stanford.edu (nlp.stanford.edu)|171.64.67.140|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://downloads.cs.stanford.edu/nlp/data/glove.6B.zip [following]
--2023-12-14 14:45:58--  https://downloads.cs.stanford.edu/nlp/data/glove.6B.zip
Resolving downloads.cs.stanford.edu (downloads.cs.stanford.edu)... 171.64.64.22
Connecting to downloads.cs.stanford.edu (downloads.cs.stanford.edu)|171.64.64.22|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 862182613 (822M) [application/zip]
Saving to: ‘glove.6B.zip’


2023-12-14 14:48:49 (4.83 MB/s) - ‘glove.6B.zip’ saved [862182613/862182613]



In [None]:
# Extract the GloVe vectors into the annotations folder, which is where
# extract_embeddings looks for glove.6B.<dim>d.txt.
!unzip -d Flickr8k_annotations/ glove.6B.zip

Archive:  glove.6B.zip
  inflating: Flickr8k_annotations/glove.6B.50d.txt  
  inflating: Flickr8k_annotations/glove.6B.100d.txt  
  inflating: Flickr8k_annotations/glove.6B.200d.txt  
  inflating: Flickr8k_annotations/glove.6B.300d.txt  


### Download the images of our Test Dataset

In [None]:
import os
if os.path.isdir('test_dataset')==False:
    !gdown 1QHpHBFH9glz1P_vddU2g8Pg4pxFWlCv-
    !unzip -qq test_dataset.zip -d test_dataset
    !rm -rf test_dataset.zip

Downloading...
From: https://drive.google.com/uc?id=1QHpHBFH9glz1P_vddU2g8Pg4pxFWlCv-
To: /content/test_dataset.zip
100% 129M/129M [00:01<00:00, 70.6MB/s]


In [None]:
!echo "Number of images in the test dataset: $(ls -l test_dataset/test_dataset/ | wc -l) - 1"

Number of images in the test dataset: 176 - 1


### Downloading the captions for the images

In [None]:
# Download the captions of our own test images and store them next to the Flickr8k annotations.
!gdown 1KQhnrfWtPfXbApXHXGEiz6rGbGa4NdDJ
!mv test_captions_tokenized.txt Flickr8k_annotations/

Downloading...
From: https://drive.google.com/uc?id=1KQhnrfWtPfXbApXHXGEiz6rGbGa4NdDJ
To: /content/test_captions_tokenized.txt
  0% 0.00/60.0k [00:00<?, ?B/s]100% 60.0k/60.0k [00:00<00:00, 102MB/s]


In [None]:
# Sanity check: the number of caption lines should equal the second number
# printed below (175 * 5, i.e. 175 images with 5 captions each).
!echo "Number of captions in the test dataset: $(wc -l < Flickr8k_annotations/test_captions_tokenized.txt)"
!echo $((175 * 5))

Number of captions in the test dataset: 875
875


In [None]:
!echo "Number of images in the dataset: $(ls -l Flickr8k_images/ | wc -l)"

Number of images in the dataset: 8092


### Add the test images to the image folder

In [None]:
# Move our test images (jpg, png and jpeg files) into the common image folder
# so that the dataset classes can load every image from Flickr8k_images.
!mv test_dataset/test_dataset/*.jpg Flickr8k_images/
!mv test_dataset/test_dataset/*.png Flickr8k_images/
!mv test_dataset/test_dataset/*.jpeg Flickr8k_images/

In [None]:
!echo "Number of images in the train dataset after adding test images: $(ls -l Flickr8k_images/ | wc -l)"

Number of images in the train dataset after adding test images: 8267


### Remove all unnecessary directories

In [None]:
# Remove archives and folders that are no longer needed once their contents
# have been extracted/moved. Only the working-directory copy of kaggle.json is
# deleted here; the copy placed in ~/.kaggle by the download cell is untouched.
!rm -rf sample_data/
!rm -f flickr8k.zip
!rm -f glove.6B.zip
!rm -rf test_dataset/
!rm -f kaggle.json

# Necessary imports

In [None]:
import os
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import json
from tqdm import tqdm
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
import timm
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
import string
import nltk
import shutil
import time
from nltk.stem import WordNetLemmatizer
from collections import Counter
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerDecoderLayer, TransformerDecoder
from nltk.translate.bleu_score import corpus_bleu
from torch.utils.data import DataLoader


In [None]:
## CONSTANTS

EMBEDDING_SIZE = 300 # Dimension of the GloVe vectors; selects the glove.6B.<EMBEDDING_SIZE>d.txt file
MAX_LEN = 60  # Fixed token length of every decoder input/target sequence (padding included)

# Indices of the special tokens. They must agree with the hard-coded
# special-token entries at the start of the vocabulary built in
# create_vocab / extract_embeddings.
START_IDX = 1
END_IDX = 2
PAD_IDX = 0
UNK_IDX = 3

# String forms of the special tokens (same order as the indices above)
START_TOKEN = "<start>"
END_TOKEN = "<end>"
PAD_TOKEN = "<pad>"
UNK_TOKEN = "<unk>"

## TRAINING HYPERPARAMS

# DataLoader settings for the training set
BATCH_SIZE = 32
SHUFFLE = True
NUM_WORKERS = 1
DROP_LAST = True
NUM_OF_EPOCHS = 20  # Upper bound; the training loop may stop earlier (patience)
L2_PENALTY = 0.5  # Passed to AdamW as weight_decay
LEARNING_RATE = 0.000008
GRADIENT_CLIPPING = 2.0  # Max gradient norm passed to clip_grad_norm_
EVAL_PERIOD = 1  # Validation BLEU is computed every EVAL_PERIOD epochs

## MODEL

IMG_FEATURE_CHANNELS = 2048  # Channel count of the CNN feature map fed to the decoder's image projection
IMG_SIZE = 256  # Images are resized and center-cropped to IMG_SIZE
DECODER_LAYERS = 8
D_MODEL = 512
FF_DIM = 1024
ATTENTION_HEADS = 16
DROPOUT = 0.5


In [None]:
def load_captions(data):
    """Groups raw caption lines by image id.

    Every usable line of ``data`` looks like "<image_name> <caption words...>".
    The image id is the part of the image name before its first "." and the
    caption is the remaining tokens joined by single spaces.

    Arguments:
        data (str): Full contents of a caption file, one sample per line
    Returns:
        image2caption (dict): Maps image id to the list of its captions, in file order
    """
    image2caption = dict()
    for sample in data.split("\n"):
        tokens = sample.split()
        # Skip whitespace-only lines (they have no token to index) as well as
        # degenerate lines shorter than two characters.
        if not tokens or len(sample) < 2:
            continue
        image_name, caption_words = tokens[0], tokens[1:]

        image_id = image_name.split(".")[0]
        image2caption.setdefault(image_id, []).append(" ".join(caption_words))

    return image2caption

In [None]:
def preprocess_caption(caption):
    """Normalizes a single caption.

    Words are lower-cased and stripped of punctuation; afterwards every word
    that is shorter than two characters or contains a non-alphabetic character
    (e.g. a digit) is dropped.

    Arguments:
        caption (str): Raw caption, words separated by whitespace
    Returns:
        (str): Remaining words joined by single spaces
    """
    strip_punct = str.maketrans("", "", string.punctuation)
    kept_words = []
    for raw_word in caption.split():
        word = raw_word.lower().translate(strip_punct)
        # Drops leftovers such as "a", "s" or empty strings, and tokens with numbers
        if len(word) > 1 and word.isalpha():
            kept_words.append(word)
    return " ".join(kept_words)


def clean_captions(id2annotation):
    """Applies preprocess_caption to every caption of every image.

    A new mapping with new caption lists is built, so the lists stored in
    ``id2annotation`` are left untouched.

    Arguments:
        id2annotation (dict): Maps image id to a list of raw captions
    Returns:
        image2caption_clean (dict): Same keys, each caption cleaned
    """
    image2caption_clean = {
        image_id: [preprocess_caption(caption) for caption in captions]
        for image_id, captions in id2annotation.items()
    }
    return image2caption_clean

In [None]:
def create_vocab(image2caption):
    """Builds a word -> index vocabulary from all captions.

    The four special tokens occupy indices 0-3; the remaining words follow in
    sorted order so that the assigned indices are the same on every run
    (iteration order of a set of strings is not stable between runs).

    Arguments:
        image2caption (dict): Maps image id to a list of caption strings
    Returns:
        word2idx (dict): Maps every special token and caption word to an index
    """
    word2idx = {"<pad>": 0, "<start>": 1, "<end>": 2, "<unk>": 3}
    words = set()
    for captions in image2caption.values():
        for caption in captions:
            words.update(caption.split())

    # Special tokens that also appear as caption words keep their reserved index
    words.difference_update(word2idx)

    starting_len = len(word2idx)
    for offset, word in enumerate(sorted(words)):
        word2idx[word] = starting_len + offset

    return word2idx

In [None]:
def extract_embeddings(vocab):
    """Extracts GloVe vectors for the caption vocabulary and saves them to disk.

    Reads Flickr8k_annotations/glove.6B.<EMBEDDING_SIZE>d.txt, keeps the vector
    of every GloVe word that occurs in ``vocab`` and writes two files:
    Flickr8k_annotations/word2idx.json (word -> row index) and
    Flickr8k_annotations/embeddings.txt (one embedding row per index).
    Words of ``vocab`` that are not found in the GloVe file get no entry.

    Arguments:
        vocab (dict or set): Container of known words; only membership is tested
    Returns:
        new_vocab (dict): Maps special tokens and every kept word to its embedding row
    """
    # Fixed seed so that the random special-token embeddings are reproducible.
    # Note that this reseeds NumPy's global random state.
    np.random.seed(20231104)
    glove_dir = "Flickr8k_annotations"
    embeddings_config = {
        "path": "Flickr8k_annotations/embeddings.txt",
        "size": EMBEDDING_SIZE
    }
    save_path_emb = embeddings_config["path"]
    embedding_dim = embeddings_config["size"]

    punct_table = str.maketrans("", "", string.punctuation)

    vectors = []
    # Rows 0-3 are reserved for the special tokens; real words start at index 4
    new_vocab = {"<pad>": 0, "<start>": 1, "<end>": 2, "<unk>": 3}
    i = len(new_vocab)

    embedding_file_name = "glove.6B.{}d.txt".format(embedding_dim)
    embeddings_path = os.path.join(glove_dir, embedding_file_name)
    with open(embeddings_path, "rb") as f:
        for line in f:
            # Each line: the word followed by its vector components
            line = line.decode().split()
            word = line[0]
            word = word.strip().lower()
            word = word.translate(punct_table)
            # Several GloVe entries can collapse to the same word once
            # punctuation is stripped; only the first one encountered is kept.
            if word in vocab and word not in new_vocab:
                embedding_vec = np.array(line[1:], dtype="float")
                vectors += [embedding_vec]
                new_vocab[word] = i
                i += 1
    with open("Flickr8k_annotations/word2idx.json", "w", encoding="utf8") as f:
        json.dump(new_vocab, f)

    vectors = np.array(vectors)
    # Embedding vector for tokens used for padding the input sequence
    pad_embedding = np.zeros((embedding_dim,))
    # Embedding vector for start of the sequence
    sos_embedding = np.random.normal(size=(embedding_dim,))
    # Embedding vector for end of the sequence
    eos_embedding = np.random.normal(size=(embedding_dim,))
    # Embedding vector for unknown token
    unk_embedding =  np.random.normal(size=(embedding_dim,))

    # Make sure the random SOS/EOS vectors do not coincide with each other or
    # with the vector of a real word
    assert not np.allclose(sos_embedding, eos_embedding), "SOS and EOS embeddings are too close!"
    for emb_vec in vectors:
        assert not np.allclose(sos_embedding, emb_vec), "SOS embedding is too close to other embedding!"
        assert not np.allclose(eos_embedding, emb_vec), "EOS embedding is too close to other embedding!"


    print("Embedding vectors shape without SOS EOS UNK and PAD: ", vectors.shape)
    # Stack in the same order as the special-token indices in new_vocab (pad, start, end, unk)
    vectors = np.vstack([pad_embedding, sos_embedding, eos_embedding, unk_embedding, vectors])

    print("Embedding vectors shape having added SOS EOS UNK and PAD: ", vectors.shape)
    np.savetxt(save_path_emb, vectors)

    print("\nExtracted GloVe embeddings for all tokens in the training set.")
    print("Embedding vectors size:", embedding_dim)
    print("Vocab size:", len(new_vocab))
    return new_vocab

In [None]:
def save_captions(image2caption, subset_imgs, save_path):
    """Writes "<image_name> <caption>" lines for the images of one subset.

    Image names whose id (file name without extension) has no entry in
    ``image2caption`` are skipped.

    Arguments:
        image2caption (dict): Maps image id to a list of captions
        subset_imgs (list of str): Image file names belonging to the subset
        save_path (str): Destination text file (overwritten)
    """
    lines = [
        "{} {}\n".format(image_name, caption)
        for image_name in subset_imgs
        for caption in image2caption.get(os.path.splitext(image_name)[0], ())
    ]

    with open(save_path, "w") as out_file:
        out_file.write("".join(lines))

def split_dataset(image2caption, split_images_paths, save_paths):
    """Creates one caption file per subset.

    Arguments:
        image2caption (dict): Maps image id to a list of captions
        split_images_paths (list of str): Files listing the image names of each subset, one per line
        save_paths (list of str): Destination caption file for each subset (same order)
    """
    for image_list_path, captions_path in zip(split_images_paths, save_paths):
        with open(image_list_path, "r") as image_list:
            subset_imgs = [line.replace("\n", "") for line in image_list]
        save_captions(image2caption, subset_imgs, captions_path)

In [None]:
# Size of the original Flickr8k train and test image lists, before they are merged below.
!echo "Number of images in the train dataset: $(cat Flickr8k_annotations/Flickr_8k.trainImages.txt | wc -l)"
!echo "Number of images in the test dataset: $(cat Flickr8k_annotations/Flickr_8k.testImages.txt | wc -l)"

Number of images in the train dataset: 6000
Number of images in the test dataset: 1000


### Add the test images to the training set

In [None]:
# Append the original Flickr8k test image list to the train image list.
# NOTE(review): the destination is opened in append mode, so running this cell
# more than once adds the test images to the train list again.
source_path = "Flickr8k_annotations/Flickr_8k.testImages.txt"
destination_path = "Flickr8k_annotations/Flickr_8k.trainImages.txt"

with open(destination_path, "a") as f_dest, open(source_path, "r") as f_source:
    shutil.copyfileobj(f_source, f_dest)

In [None]:
!echo "Number of images in the test dataset after adding training images: $(cat Flickr8k_annotations/Flickr_8k.trainImages.txt | wc -l)"

Number of images in the test dataset after adding training images: 7000


In [None]:
# Number of caption lines in the Flickr8k caption file.
!echo "Number of lines in the tokenized dataset: $(cat Flickr8k_annotations/Flickr8k.token.txt | wc -l)"

Number of lines in the tokenized dataset: 40460


In [None]:
# Number of caption lines in the caption file of our own test images.
!echo "Number of lines in our tokenized dataset: $(cat Flickr8k_annotations/test_captions_tokenized.txt | wc -l)"

Number of lines in our tokenized dataset: 875


#### Removes test image IDs

In [None]:
# The original test list has been appended to the train list above; it is
# recreated from our own test captions in the next cell.
!rm Flickr8k_annotations/Flickr_8k.testImages.txt

#### Collect all unique image IDs from our test set and write them to the test IDs file

In [None]:
# From every caption line keep only the leading image file name (everything up
# to and including the jpg/jpeg/png extension), then drop repeated names.
# Assumes the captions of one image are on consecutive lines, since `uniq`
# only merges adjacent duplicates.
!sed -n 's/\([^#]\+\)\(jpg\|jpeg\|png\).*/\1\2/p' Flickr8k_annotations/test_captions_tokenized.txt | uniq > Flickr8k_annotations/Flickr_8k.testImages.txt

In [None]:
# Number of image names in the newly created test list.
!echo "Number of images in our test dataset: $(cat Flickr8k_annotations/Flickr_8k.testImages.txt | wc -l)"

Number of images in our test dataset: 175


### Open the tokenized captions of Flickr Images

#### Create image2caption (dict): a mapping from image id to all captions of that image that occurred in the dataset

#### Clean the captions, remove stopwords etc.

#### Create a vocabulary of used words

#### Extract the embeddings for words from GLOVE

In [None]:
# Load and clean all Flickr8k captions, then build the vocabulary from them.
dataset_path = "Flickr8k_annotations/Flickr8k.token.txt"
with open(dataset_path, "r") as f:
  data = f.read()

image2caption = load_captions(data)
image2caption = clean_captions(image2caption)
vocab = create_vocab(image2caption)

# Keeps only the words that have a GloVe vector; also writes word2idx.json and
# embeddings.txt, which the dataset and model classes read later.
new_vocab = extract_embeddings(vocab)

Embedding vectors shape without SOS EOS UNK and PAD:  (7886, 300)
Embedding vectors shape having added SOS EOS UNK and PAD:  (7890, 300)

Extracted GloVe embeddings for all tokens in the training set.
Embedding vectors size: 300
Vocab size: 7890


#### Create the image-caption data for the network: line of (image_id caption[i]) * 5 for every image

In [None]:
# Files listing the image names of every subset
split_images = {
  "train": "Flickr8k_annotations/Flickr_8k.trainImages.txt",
  "validation": "Flickr8k_annotations/Flickr_8k.devImages.txt",
  "test" : "Flickr8k_annotations/Flickr_8k.testImages.txt",
}

# Caption files that are generated for every subset
split_save = {
  "train": "Flickr8k_annotations/train.txt",
  "validation": "Flickr8k_annotations/validation.txt",
  "test": "Flickr8k_annotations/test.txt",
}

# Only train and validation captions come from the Flickr8k caption file; the
# test subset uses its own caption file and is handled separately.
flickr_subsets = ("train", "validation")
split_images_paths = [split_images[subset_name] for subset_name in flickr_subsets]
split_save_paths = [split_save[subset_name] for subset_name in flickr_subsets]

print(split_images_paths)

split_dataset(image2caption, split_images_paths, split_save_paths)

['Flickr8k_annotations/Flickr_8k.trainImages.txt', 'Flickr8k_annotations/Flickr_8k.devImages.txt']


In [None]:
# Same load/clean steps as for Flickr8k, applied to the captions of our own test images.
dataset_path = "Flickr8k_annotations/test_captions_tokenized.txt"
with open(dataset_path, "r") as f:
    data_test = f.read()

image2caption_test = load_captions(data_test)
image2caption_test = clean_captions(image2caption_test)

# The last entry of both mappings is the "test" subset
split_images_paths_test = list(split_images.values())[-1]
split_save_paths_test = list(split_save.values())[-1]

split_dataset(image2caption_test, [split_images_paths_test], [split_save_paths_test])

#### Set up GPU device

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
use_gpu = torch.cuda.is_available()
device = torch.device("cuda" if use_gpu else "cpu")
device

device(type='cuda')

### Create The Dataset Class

In [None]:
class Flickr8KDataset(Dataset):
    """Image-caption dataset read from a "<image_name> <caption words>" text file.

    Indexing yields one (image, caption) training sample. ``inference_batch``
    yields batches of images together with all reference captions of every
    image, which is what BLEU evaluation needs.
    """

    def __init__(self, path, type_data="train"):
        """Reads the caption file, loads the vocabulary and prepares the samples.

        Arguments:
            path (str): Caption file with one "<image_name> <caption words>" line per sample
            type_data (str): Name of the subset; stored but not otherwise used by this class
        """
        with open(path, "r") as f:
            self._data = [line.replace("\n", "") for line in f.readlines()]

        self._inference_captions = self._group_captions(self._data)
        self._type = type_data

        with open("Flickr8k_annotations/word2idx.json", "r", encoding="utf8") as f:
            self._word2idx = json.load(f)
        # Reverse mapping, keyed by the index converted to a string
        self._idx2word = {str(idx): word for word, idx in self._word2idx.items()}

        self._start_idx = START_IDX
        self._end_idx = END_IDX
        self._pad_idx = PAD_IDX
        self._UNK_idx = UNK_IDX
        self._START_token = START_TOKEN
        self._END_token = END_TOKEN
        self._PAD_token = PAD_TOKEN
        self._UNK_token = UNK_TOKEN

        self._max_len = MAX_LEN

        self._img_feature_channels = IMG_FEATURE_CHANNELS
        self._img_size = IMG_SIZE

        self._image_transform = self._construct_image_transform(self._img_size)
        self.image_dir = "Flickr8k_images"

        self._data = self._create_input_label_mappings(self._data)

        self._dataset_size = len(self._data)

    def _construct_image_transform(self, image_size):
        """Builds the resize -> center crop -> tensor -> normalize pipeline.

        Arguments:
            image_size (int): Size passed to both the resize and the center crop
        Returns:
            preprocessing (transforms.Compose): Transform applied to every loaded image
        """
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
        preprocessing = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            normalize,
        ])
        return preprocessing

    def _group_captions(self, data):
        """Groups the tokenized captions by image name.

        Arguments:
            data (list of str): Lines of the form "<image_name> <caption words>"
        Returns:
            grouped_captions (dict): Maps image name to a list of token lists
        """
        grouped_captions = {}

        for line in data:
            caption_data = line.split()
            if not caption_data:
                # Blank line: there is no image name to read
                continue
            img_name, img_caption = caption_data[0].split("#")[0], caption_data[1:]
            grouped_captions.setdefault(img_name, []).append(img_caption)

        return grouped_captions

    def _create_input_label_mappings(self, data):
        """Converts every line into an (image name, caption tokens) pair.

        Arguments:
            data (list of str): Lines of the form "<image_name> <caption words>"
        Returns:
            processed_data (list of tuple): One (img_name, caption_words) pair per non-blank line
        """
        processed_data = []
        for line in data:
            tokens = line.split()
            if not tokens:
                # Blank line: there is no image name to read
                continue
            img_name, caption_words = tokens[0].split("#")[0], tokens[1:]
            processed_data.append((img_name, caption_words))

        return processed_data

    def _load_and_prepare_image(self, image_name):
        """Loads an image from ``image_dir`` as RGB and applies the image transform."""
        image_path = os.path.join(self.image_dir, image_name)
        raw_image = Image.open(image_path).convert("RGB")
        image_tensors = self._image_transform(raw_image)
        return image_tensors

    def inference_batch(self, batch_size):
        """Yields batches of images with all reference captions of each image.

        Every image is yielded exactly once; the last batch is smaller when the
        number of images is not a multiple of ``batch_size``.

        Arguments:
            batch_size (int): Maximum number of images per batch
        Yields:
            batch_imgs (torch.Tensor): Stacked image tensors. When ``batch_size``
                is 1 the leading batch dimension is squeezed away.
            batch_captions (list): For every image, the list of its tokenized captions
        """
        caption_data_items = list(self._inference_captions.items())

        for start in range(0, len(caption_data_items), batch_size):
            caption_samples = caption_data_items[start: start + batch_size]
            batch_imgs = []
            batch_captions = []

            for image_name, captions in caption_samples:
                batch_captions.append(captions)
                batch_imgs.append(self._load_and_prepare_image(image_name))

            batch_imgs = torch.stack(batch_imgs, dim=0)
            if batch_size == 1:
                batch_imgs = batch_imgs.squeeze(0)

            yield batch_imgs, batch_captions

    def __len__(self):
        return self._dataset_size

    def __getitem__(self, index):
        """Returns one training sample.

        Returns:
            image_tensor (torch.Tensor): Preprocessed image
            input_tokens (torch.Tensor): Indices of <start> + caption, padded to max_len
            tgt_tokens (torch.Tensor): Indices of caption + <end>, padded to max_len
            tgt_padding_mask (torch.Tensor): Bool mask of length max_len, True on padded positions
        """
        image_id, tokens = self._data[index]

        image_tensor = self._load_and_prepare_image(image_id)

        # Leave room for the <start>/<end> token so that the decoder input and
        # the target are never longer than max_len
        tokens = tokens[:self._max_len - 1]

        tokens = [token.strip().lower() for token in tokens]
        tokens = [self._START_token] + tokens + [self._END_token]

        # Target is the input shifted by one position
        input_tokens = tokens[:-1]
        tgt_tokens = tokens[1:]

        sample_size = len(input_tokens)
        padding = [self._PAD_token] * (self._max_len - sample_size)
        input_tokens = input_tokens + padding
        tgt_tokens = tgt_tokens + padding

        # Words without an entry in the vocabulary are mapped to the unknown token
        input_ids = [self._word2idx.get(token, self._UNK_idx) for token in input_tokens]
        tgt_ids = [self._word2idx.get(token, self._UNK_idx) for token in tgt_tokens]

        input_tokens = torch.tensor(input_ids, dtype=torch.long)
        tgt_tokens = torch.tensor(tgt_ids, dtype=torch.long)

        tgt_padding_mask = torch.ones(self._max_len, dtype=torch.bool)
        tgt_padding_mask[:sample_size] = False

        return image_tensor, input_tokens, tgt_tokens, tgt_padding_mask

In [None]:
# Build the train and validation subsets from the caption files generated by split_dataset.
train_set = Flickr8KDataset(split_save["train"], "train")
valid_set = Flickr8KDataset(split_save["validation"], "valid")

print("Number of image:caption pairs in the train set: ", len(train_set))
print("Number of image:caption pairs in the validation set: ", len(valid_set))

Number of image:caption pairs in the train set:  35000
Number of image:caption pairs in the validation set:  5000


In [None]:
# Only the training set goes through a DataLoader; validation uses Flickr8KDataset.inference_batch.
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=SHUFFLE, num_workers=NUM_WORKERS, drop_last=DROP_LAST)

In [None]:
class ResidualBlock(nn.Module):
    """Two linear layers with a LeakyReLU in between, wrapped by a skip connection.

    Input and output have the same feature size ``input_dim``.
    """

    def __init__(self, input_dim):
        super().__init__()
        layers = [
            nn.Linear(input_dim, input_dim),
            nn.LeakyReLU(),
            nn.Linear(input_dim, input_dim),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        # Output is the input plus the transformed input
        return x + self.block(x)


class Normalize(nn.Module):
    """Scales the input to unit L2 norm along one dimension.

    ``eps`` is added to the norm so that all-zero inputs do not cause a
    division by zero.
    """

    def __init__(self, eps=1e-5):
        super(Normalize, self).__init__()
        self.register_buffer("eps", torch.Tensor([eps]))

    def forward(self, x, dim=-1):
        """Divides ``x`` by its L2 norm computed along ``dim``.

        Arguments:
            x (torch.Tensor): Input tensor
            dim (int): Dimension along which the norm is computed
        Returns:
            (torch.Tensor): Tensor of the same shape as ``x``
        """
        # keepdim keeps the reduced dimension in place, so the division
        # broadcasts correctly for any ``dim`` (not only the last one)
        norm = x.norm(2, dim=dim, keepdim=True)
        return x / (norm + self.eps)


class PositionalEncodings(nn.Module):
    """Adds fixed (non-trainable) positional encodings to a sequence, followed by dropout.

    For position ``pos`` and feature index ``j`` the angle is
    ``pos / 10000 ** (2 * j / d_model)``; even feature indices store the cosine
    of the angle and odd feature indices store the sine.
    """

    def __init__(self, seq_len, d_model, p_dropout):
        """
        Arguments:
            seq_len (int): Maximum sequence length that can be encoded
            d_model (int): Feature size of every sequence element
            p_dropout (float): Dropout probability applied after adding the encodings
        """
        super(PositionalEncodings, self).__init__()
        token_positions = torch.arange(start=0, end=seq_len).view(-1, 1)
        dim_positions = torch.arange(start=0, end=d_model).view(1, -1)
        angles = token_positions / (10000 ** ((2 * dim_positions) / d_model))

        encodings = torch.zeros(1, seq_len, d_model)
        encodings[0, :, ::2] = torch.cos(angles[:, ::2])
        encodings[0, :, 1::2] = torch.sin(angles[:, 1::2])
        # Buffers are moved together with the module but are not trained
        self.register_buffer("positional_encodings", encodings)

        self.dropout = nn.Dropout(p_dropout)

    def forward(self, x):
        """Adds the encodings of the first ``x.size(1)`` positions to ``x``.

        Arguments:
            x (torch.Tensor): Input of shape (batch, length, d_model) with length <= seq_len
        Returns:
            (torch.Tensor): Tensor of the same shape as ``x``
        """
        # Slicing allows sequences shorter than seq_len
        x = x + self.positional_encodings[:, :x.size(1)]
        x = self.dropout(x)
        return x

class Encoder(nn.Module):
    """Pretrained CNN that turns images into a sequence of feature vectors."""

    def __init__(self, train=False):
        """
        Arguments:
            train (bool): When False, all parameters of the CNN are frozen
        """
        super().__init__()
        backbone = timm.create_model('seresnet152d', pretrained=True)
        # Drop the last two child modules of the backbone
        # (presumably the pooling and classification head - confirm against timm)
        self.resnet = nn.Sequential(*list(backbone.children())[:-2])
        if train:
            return
        for weight in self.resnet.parameters():
            weight.requires_grad = False

    def forward(self, images):
        """Returns features of shape (batch_size, spatial positions, channels)."""
        feature_maps = self.resnet(images)  # (batch_size, 2048, image_size/32, image_size/32)
        batch_size, channels = feature_maps.size(0), feature_maps.size(1)
        # Flatten the spatial grid, then move the channels to the last dimension
        flattened = feature_maps.view(batch_size, channels, -1)
        return flattened.transpose(1, 2)


class CaptionDecoder(nn.Module):
    """Transformer decoder that predicts caption words from image features.

    Word indices are embedded with frozen pretrained embeddings, projected to
    ``d_model`` and decoded while attending to the projected image features.
    """

    def __init__(self, decoder_layers, d_model, ff_dim, attention_heads, dropout, embedding_dim, img_feature_channels, word_embeddings, vocab_size, device):
        """
        Arguments:
            decoder_layers (int): Number of transformer decoder layers
            d_model (int): Feature size used inside the decoder
            ff_dim (int): Size of the feed-forward layer of every decoder layer
            attention_heads (int): Number of attention heads
            dropout (float): Dropout probability (positional encodings and decoder layers)
            embedding_dim (int): Size of the pretrained word embeddings
            img_feature_channels (int): Feature size of the incoming image features
            word_embeddings (torch.Tensor): Pretrained embedding matrix, one row per vocabulary index
            vocab_size (int): Number of output classes
            device (torch.device): Device on which the causal mask is created
        """
        super(CaptionDecoder, self).__init__()

        self.embedding_layer = nn.Embedding.from_pretrained(
            word_embeddings,
            freeze=True,
            padding_idx=0
        )

        self.entry_mapping_words = nn.Linear(embedding_dim, d_model)
        self.entry_mapping_img = nn.Linear(img_feature_channels, d_model)

        self.res_block = ResidualBlock(d_model)

        # Same length as the causal mask below
        self.positional_encodings = PositionalEncodings(MAX_LEN, d_model, dropout)
        transformer_decoder_layer = TransformerDecoderLayer(
            d_model=d_model,
            nhead=attention_heads,
            dim_feedforward=ff_dim,
            dropout=dropout
        )
        self.decoder = TransformerDecoder(transformer_decoder_layer, decoder_layers)
        self.classifier = nn.Linear(d_model, vocab_size)
        self.set_up_causal_mask(MAX_LEN, device)

    def set_up_causal_mask(self, seq_len, device):
        """Creates the boolean mask that hides future positions.

        Entry (i, j) is True when position i must NOT attend to position j,
        i.e. for every j > i.

        Arguments:
            seq_len (int): Side length of the square mask
            device (torch.device): Device on which the mask is created
        """
        mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()
        # Non-persistent buffer: follows the module across devices without
        # becoming part of the saved state dict
        self.register_buffer("casual_mask", mask, persistent=False)

    def forward(self, x, image_features, tgt_padding_mask=None):
        """Predicts vocabulary logits for every position of the input sequence.

        Arguments:
            x (torch.Tensor): Word indices of shape (batch, length)
            image_features (torch.Tensor): Image features of shape (batch, positions, img_feature_channels)
            tgt_padding_mask (torch.Tensor, optional): Mask of shape (batch, length),
                True/non-zero on padded positions. No padding is masked when None.
        Returns:
            (torch.Tensor): Logits of shape (batch, length, vocab_size)
        """
        image_features = self.entry_mapping_img(image_features)
        # The decoder expects sequence-first tensors
        image_features = image_features.permute(1, 0, 2)
        image_features = F.leaky_relu(image_features)

        x = self.embedding_layer(x)
        x = self.entry_mapping_words(x)
        x = F.leaky_relu(x)

        x = self.res_block(x)
        x = F.leaky_relu(x)

        x = self.positional_encodings(x)

        x = x.permute(1, 0, 2)

        if tgt_padding_mask is not None:
            tgt_padding_mask = tgt_padding_mask.type(torch.bool)
        # Cut the mask to the actual sequence length
        seq_len = x.size(0)
        causal_mask = self.casual_mask[:seq_len, :seq_len]

        x = self.decoder(
            tgt=x,
            memory=image_features,
            tgt_key_padding_mask=tgt_padding_mask,
            tgt_mask=causal_mask
        )
        x = x.permute(1, 0, 2)

        x = self.classifier(x)
        return x

class CNNtoTransformer(nn.Module):
    """Full captioning model: CNN image encoder followed by the caption decoder."""

    def __init__(self, decoder_layers, d_model, ff_dim, attention_heads, dropout, embedding_dim, img_feature_channels, vocab_size, device):
        """Creates the encoder (put into eval mode) and the decoder.

        The pretrained word embeddings are read from
        Flickr8k_annotations/embeddings.txt; all arguments are forwarded to
        CaptionDecoder.
        """
        super().__init__()

        self.encoderCNN = Encoder()
        self.encoderCNN.eval()

        embedding_matrix = np.loadtxt("Flickr8k_annotations/embeddings.txt")
        word_embeddings = torch.Tensor(embedding_matrix)
        self.decoderRNN = CaptionDecoder(
            decoder_layers, d_model, ff_dim, attention_heads, dropout,
            embedding_dim, img_feature_channels, word_embeddings, vocab_size, device
        )

    def forward(self, images, x_words, tgt_padding_mask=None):
        """Encodes the images and decodes word logits for ``x_words``."""
        return self.forward_decoder(x_words, self.forward_encoder(images), tgt_padding_mask)

    def forward_encoder(self, images):
        """Runs only the image encoder."""
        return self.encoderCNN(images)

    def forward_decoder(self, x_words, image_features, tgt_padding_mask):
        """Runs only the caption decoder on already computed image features."""
        return self.decoderRNN(x_words, image_features, tgt_padding_mask)

    def set_encoder_train_mode(self, mode=True):
        """Puts the encoder into train mode when ``mode`` is truthy, else into eval mode."""
        self.encoderCNN.train(bool(mode))

    def set_decoder_train_mode(self, mode=True):
        """Puts the decoder into train mode when ``mode`` is truthy, else into eval mode."""
        self.decoderRNN.train(bool(mode))


In [None]:
# Build the captioning model. vocab_size is the size of the vocabulary returned
# by extract_embeddings, i.e. the number of rows in embeddings.txt.
encoder_decoder = CNNtoTransformer(
    decoder_layers = DECODER_LAYERS,
    d_model = D_MODEL,
    ff_dim = FF_DIM,
    attention_heads = ATTENTION_HEADS,
    dropout = DROPOUT,
    embedding_dim = EMBEDDING_SIZE,
    img_feature_channels = IMG_FEATURE_CHANNELS,
    vocab_size = len(new_vocab),
    device = device
    )

encoder_decoder = encoder_decoder.to(device)
# The encoder stays in eval mode; only the decoder is put into train mode
encoder_decoder.set_encoder_train_mode(False)
encoder_decoder.set_decoder_train_mode(True)

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

In [None]:
# All model parameters are handed to the optimizer, including the encoder
# weights that Encoder froze via requires_grad = False.
optimizer = torch.optim.AdamW(
    encoder_decoder.parameters(),
    lr=LEARNING_RATE,
    weight_decay=L2_PENALTY
)
# Cross entropy over the vocabulary with label smoothing
loss_fcn = nn.CrossEntropyLoss(label_smoothing=0.1)

In [None]:
def save_checkpoint(model, optimizer, start_time, epoch):
    """Saves the state dicts of the model and the optimizer.

    Files are written to checkpoints/<start_time>/ as model_<epoch>.pth and
    optimizer_<epoch>.pth; the directory is created when missing.

    Arguments:
        model (nn.Module): Model whose state dict is saved
        optimizer (torch.optim.Optimizer): Optimizer whose state dict is saved
        start_time: Identifier of the training run (converted to str)
        epoch (int): Epoch number used in the file names
    """
    target_dir = os.path.join("checkpoints", str(start_time))
    os.makedirs(target_dir, exist_ok=True)

    for prefix, stateful in (("model", model), ("optimizer", optimizer)):
        save_path = os.path.join(target_dir, f"{prefix}_{epoch}.pth")
        torch.save(stateful.state_dict(), save_path)
    print("Model saved.")

In [None]:
def greedy_decoding(model, img_features_batched, sos_id, eos_id, pad_id, idx2word, max_len, device):
    """Generates one caption per image by always picking the most probable next word.

    Arguments:
        model: Object with a forward_decoder(x_words, image_features, padding_mask)
            method that returns logits of shape (batch, max_len, vocab)
        img_features_batched (torch.Tensor): Image features, batch dimension first
        sos_id (int): Index of the start token
        eos_id (int): Index of the end token
        pad_id (int): Index of the padding token
        idx2word (dict): Maps str(index) to the corresponding word
        max_len (int): Length of the decoder input; at most max_len - 1 words are generated
        device (torch.device): Device on which the decoder inputs are created
    Returns:
        generated_captions (list of list of str): Generated words per image, without the end token
    """
    batch_size = img_features_batched.size(0)

    # Initial decoder input: the start token followed by padding
    x_words = torch.full((batch_size, max_len), pad_id, dtype=torch.long, device=device)
    x_words[:, 0] = sos_id
    # True marks positions the decoder must ignore; they are unmasked step by step
    padd_mask = torch.ones((batch_size, max_len), dtype=torch.bool, device=device)

    # Is each image from the batch decoded
    is_decoded = [False] * batch_size
    generated_captions = [[] for _ in range(batch_size)]

    for i in range(max_len - 1):
        # Update the padding masks
        padd_mask[:, i] = False

        # Get the model prediction for the next word
        y_pred_prob = model.forward_decoder(x_words, img_features_batched, padd_mask)
        # Most probable word at the current (next word) position
        y_pred = y_pred_prob[:, i].argmax(-1)

        # Single transfer to Python ints instead of one .item() call per image
        for batch_idx, token_id in enumerate(y_pred.tolist()):
            if is_decoded[batch_idx]:
                continue
            if token_id == eos_id:
                # Caption has been fully generated for this image; the end
                # token itself is not part of the returned caption
                is_decoded[batch_idx] = True
            else:
                generated_captions[batch_idx].append(idx2word[str(token_id)])

        if all(is_decoded):
            break

        # Update the input tokens for the next iteration
        x_words[:, i + 1] = y_pred

    return generated_captions

In [None]:
def evaluate(subset, encoder_decoder, device, max_len=60, batch_size=32):
    """Evaluates (BLEU score) caption generation model on a given subset.

    Captions are generated with greedy decoding and compared against all
    reference captions of every image. The train/eval mode of the model is not
    changed here; the caller is expected to set it.

    Arguments:
        subset (Flickr8KDataset): Train/Val/Test subset
        encoder_decoder (nn.Module): Model providing forward_encoder (image features)
            and forward_decoder (used by greedy decoding)
        device (torch.device): Device on which to port used tensors
        max_len (int): Decoder sequence length used during decoding
        batch_size (int): Number of images decoded at once
    Returns:
        bleu (list of float): BLEU-{1:4} scores in percent on the entire subset - corpus bleu
    """
    # Uniform n-gram weights that sum to one for every BLEU order
    bleu_w = {
        "bleu-{}".format(order): [1.0 / order] * order
        for order in range(1, 5)
    }

    idx2word = subset._idx2word

    sos_id = subset._start_idx
    eos_id = subset._end_idx
    pad_id = subset._pad_idx

    references_total = []
    predictions_total = []

    print("Evaluating model.")
    # No gradients are needed for evaluation
    with torch.no_grad():
        for x_img_batched, y_caption_batched in subset.inference_batch(batch_size):

            x_img_batched = x_img_batched.to(device)
            img_features = encoder_decoder.forward_encoder(x_img_batched)
            img_features = img_features.detach()

            predictions = greedy_decoding(encoder_decoder, img_features, sos_id, eos_id, pad_id, idx2word, max_len, device)

            references_total += y_caption_batched
            predictions_total += predictions

    # Evaluate BLEU score of the generated captions
    bleu = [
        corpus_bleu(references_total, predictions_total, weights=bleu_w[name]) * 100
        for name in ("bleu-1", "bleu-2", "bleu-3", "bleu-4")
    ]
    return bleu

In [None]:
# Timestamp that names the checkpoint folder of this training run
start_time = time.strftime("%b-%d_%H-%M-%S")

# Validation BLEU-1..4 used as the reference for early stopping
max_bleu = [0.0, 0.0, 0.0, 0.0]
# Number of evaluations without any BLEU improvement that are tolerated
patience = 5

for epoch in range(NUM_OF_EPOCHS):
    print("Epoch:", epoch)
    encoder_decoder.set_encoder_train_mode(False)
    encoder_decoder.set_decoder_train_mode(True)

    # NOTE(review): the message hard-codes 32 while the count uses BATCH_SIZE
    print("Number of batches of 32 on training set: ", len(train_set)//BATCH_SIZE)
    batch_step = 0
    batch_loss = 0
    for x_img, x_words, y, tgt_padding_mask in train_loader:

        optimizer.zero_grad()
        batch_step += 1

        x_img, x_words = x_img.to(device), x_words.to(device)
        y = y.to(device)
        tgt_padding_mask = tgt_padding_mask.to(device)


        y_pred = encoder_decoder(x_img, x_words, tgt_padding_mask)
        # The dataset marks padded positions with True; invert the mask so that
        # only real (non-padded) positions contribute to the loss
        tgt_padding_mask = torch.logical_not(tgt_padding_mask)
        y_pred = y_pred[tgt_padding_mask]

        y = y[tgt_padding_mask]

        loss = loss_fcn(y_pred, y.long())
        batch_loss += loss.detach()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(encoder_decoder.parameters(), GRADIENT_CLIPPING)
        optimizer.step()

        if batch_step % 50 == 0:
            print("Batch Step: ", batch_step)
            print("Batch Loss: ", loss.detach())

    print("Epoch mean loss: ", batch_loss/batch_step)

    save_checkpoint(encoder_decoder, optimizer, start_time, epoch)
    if (epoch + 1) % EVAL_PERIOD == 0:
        with torch.no_grad():
            encoder_decoder.set_decoder_train_mode(False)

            valid_bleu = evaluate(valid_set, encoder_decoder, device)

            # NOTE(review): an improvement of any single BLEU score replaces all
            # four reference scores, including those that got worse
            if any(x > y for x, y in zip(valid_bleu, max_bleu)):
                max_bleu = valid_bleu
            else:
                # NOTE(review): the printed value is the patience before it is
                # decremented, and patience is never reset after an improvement
                print("None of the BLEU scores increased, patience down: ", patience)
                patience -= 1

            print(valid_bleu)

            encoder_decoder.set_decoder_train_mode(True)

    if patience == 0:
        print("Training stopped due to early stopping.")
        break
    print()

Epoch: 0
Number of batches of 32 on training set:  1093
Batch Step:  50
Batch Loss:  tensor(5.7370, device='cuda:0')
Batch Step:  100
Batch Loss:  tensor(5.7799, device='cuda:0')
Batch Step:  150
Batch Loss:  tensor(5.8365, device='cuda:0')
Batch Step:  200
Batch Loss:  tensor(5.6079, device='cuda:0')
Batch Step:  250
Batch Loss:  tensor(5.7523, device='cuda:0')
Batch Step:  300
Batch Loss:  tensor(5.9917, device='cuda:0')
Batch Step:  350
Batch Loss:  tensor(5.8114, device='cuda:0')
Batch Step:  400
Batch Loss:  tensor(5.5458, device='cuda:0')
Batch Step:  450
Batch Loss:  tensor(5.7606, device='cuda:0')
Batch Step:  500
Batch Loss:  tensor(5.7411, device='cuda:0')
Batch Step:  550
Batch Loss:  tensor(5.7722, device='cuda:0')
Batch Step:  600
Batch Loss:  tensor(5.4226, device='cuda:0')
Batch Step:  650
Batch Loss:  tensor(5.7959, device='cuda:0')
Batch Step:  700
Batch Loss:  tensor(5.7883, device='cuda:0')
Batch Step:  750
Batch Loss:  tensor(5.7276, device='cuda:0')
Batch Step:  80

In [None]:
class TestDataset(Dataset):
    """Dataset yielding one image together with all of its reference captions.

    Captions are read from a split file and grouped by image name, so each
    index corresponds to a distinct image rather than to a single caption.
    """

    def __init__(self, split_file):
        """Load the captions listed in ``split_file`` and build the transforms.

        Args:
            split_file: Path to a text file with one caption per line (see
                ``group_captions`` for the expected line layout).
        """
        self.image_dir = "Flickr8k_images"
        self.data = self.load_data(split_file)
        # Snapshot the image ids once so that __getitem__ is a plain list
        # lookup instead of rebuilding the key list on every access.
        self._image_ids = list(self.data)
        self.setup_transforms()

    def setup_transforms(self):
        """Create the display transform and the model-input transform."""
        # Un-normalised tensor version of the image, kept for visualisation.
        self.show_transform = transforms.Compose([
            transforms.Resize(IMG_SIZE),
            transforms.CenterCrop(IMG_SIZE),
            transforms.ToTensor(),
        ])
        # NOTE(review): these look like the usual ImageNet channel statistics
        # for a pretrained encoder - confirm against the encoder in use.
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
        # Same resize/crop as show_transform, followed by normalisation.
        self.model_transform = transforms.Compose([
            transforms.Resize(IMG_SIZE),
            transforms.CenterCrop(IMG_SIZE),
            transforms.ToTensor(),
            self.normalize,
        ])

    def load_data(self, split_file):
        """Read ``split_file`` and return its captions grouped by image name.

        Args:
            split_file: Path of the text file to read.

        Returns:
            dict mapping image name -> list of captions, each caption being a
            list of word strings.
        """
        with open(split_file, "r") as f:
            data = [line.replace("\n", "") for line in f]
        return self.group_captions(data)

    def group_captions(self, data):
        """Group caption lines by the image they describe.

        The first whitespace-separated token of a line identifies the image;
        anything from a ``#`` onwards in that token is discarded. The
        remaining tokens are the caption words. Blank or whitespace-only
        lines are skipped.

        Args:
            data: Iterable of caption lines (without trailing newlines).

        Returns:
            dict mapping image name -> list of captions (lists of words), in
            first-seen image order.
        """
        grouped_captions = {}
        for line in data:
            caption_data = line.split()
            if not caption_data:
                # Nothing to parse (e.g. a trailing empty line in the file).
                continue
            img_name = caption_data[0].split("#")[0]
            grouped_captions.setdefault(img_name, []).append(caption_data[1:])
        return grouped_captions

    def __len__(self):
        """Return the number of distinct images in the split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the tensors and reference captions of the ``idx``-th image.

        Returns:
            Tuple ``(image_tensor, raw_image_tensor, caption_words_list)``:
            the output of ``model_transform``, the output of
            ``show_transform``, and the list of reference captions (each a
            list of words) for that image.
        """
        img_id = self._image_ids[idx]
        caption_words_list = self.data[img_id]
        img_path = os.path.join(self.image_dir, img_id)
        raw_image = Image.open(img_path).convert("RGB")

        image_tensor = self.model_transform(raw_image)
        raw_image_tensor = self.show_transform(raw_image)

        return image_tensor, raw_image_tensor, caption_words_list

    def display_image(self, image_tensor):
        """Show an image tensor with matplotlib, axes hidden.

        Args:
            image_tensor: 3-D tensor whose first axis is moved last before
                plotting; presumably (C, H, W) with values in [0, 1], such as
                the raw tensor returned by ``__getitem__`` - TODO confirm.
        """
        image_array = image_tensor.permute(1, 2, 0).numpy()
        image_array = (image_array * 255).astype('uint8')
        plt.imshow(image_array)
        plt.axis('off')
        plt.show()

In [None]:
from nltk.translate.bleu_score import sentence_bleu

def evaluate_bleu(encoder_decoder, test_dataloader, device, subset):
    """Print predicted vs. reference captions and the mean BLEU-1/2/3 scores.

    For every batch, the caption of the first image is generated with
    ``greedy_decoding`` and scored against each reference caption separately;
    the best score per BLEU order is kept for that image. The per-image best
    scores are then averaged over all batches and printed.

    Scores are unsmoothed sentence-level BLEU, so a hypothesis without any
    matching n-gram of the requested order scores 0 (NLTK emits a warning).

    NOTE(review): this function does not switch the decoder out of training
    mode, whereas the training loop calls
    ``encoder_decoder.set_decoder_train_mode(False)`` before evaluating -
    confirm the model is in the intended mode before calling.

    Args:
        encoder_decoder: Model exposing ``forward_encoder``; it is also handed
            to ``greedy_decoding``.
        test_dataloader: Iterable yielding ``(model_image, raw_image,
            captions)`` triples. Only ``predictions[0]`` is scored, so a
            batch size of 1 is expected.
        device: Device the image batch is moved to before encoding.
        subset: Object providing ``_start_idx``, ``_end_idx``, ``_pad_idx``
            and ``_idx2word`` (the vocabulary used for decoding).

    Returns:
        None. All results are printed.
    """
    max_len = 60
    sos_id, eos_id, pad_id = subset._start_idx, subset._end_idx, subset._pad_idx
    idx2word = subset._idx2word

    # Uniform n-gram weights for BLEU-1, BLEU-2 and BLEU-3. Each list sums to
    # exactly 1 so the score is a proper geometric mean of the precisions.
    bleu_weights = [
        [1.0],
        [1 / 2, 1 / 2],
        [1 / 3, 1 / 3, 1 / 3],
    ]

    overall = [0.0] * len(bleu_weights)
    num_images = 0

    print("Evaluating model.")

    # Pure inference: no autograd graph is needed for encoding or decoding.
    with torch.no_grad():
        for x_img_batched, _, caption_words_list in test_dataloader:
            x_img_batched = x_img_batched.to(device)
            img_features = encoder_decoder.forward_encoder(x_img_batched).detach()
            predictions = greedy_decoding(
                encoder_decoder, img_features, sos_id, eos_id, pad_id, idx2word, max_len, device
            )
            prediction = [token for token in predictions[0] if "<unk>" not in token]

            # The dataloader's collation wraps every word in a 1-element
            # sequence; unwrap it so each caption is a plain list of strings.
            references = [
                [word[0] if isinstance(word, (tuple, list)) else word for word in caption]
                for caption in caption_words_list
            ]

            print("------------------------------------")
            print("Predicted Caption:")
            print(" ".join(prediction))
            print()
            print("Actual Caption:")

            # Best score over the reference captions, per BLEU order.
            best = [0.0] * len(bleu_weights)
            for caption in references:
                print(" ".join(caption))
                for order, weights in enumerate(bleu_weights):
                    score = sentence_bleu([caption], prediction, weights=weights) * 100
                    best[order] = max(best[order], score)
            print("------------------------------------")

            for order, score in enumerate(best):
                overall[order] += score
            num_images += 1

    if num_images == 0:
        # Avoid dividing by zero when the dataloader yields nothing.
        print("No samples to evaluate.")
        return

    bleu_1_overall, bleu_2_overall, bleu_3_overall = (total / num_images for total in overall)

    print(f"bleu_1 = {bleu_1_overall:.2f}, bleu_2 = {bleu_2_overall:.2f}, bleu_3 = {bleu_3_overall:.2f}")

# Build the evaluation dataset from the split file stored under split_save["test"].
test_dataset = TestDataset(split_save["test"])
# evaluate_bleu only scores predictions[0] of each batch, hence batch_size=1;
# shuffle=False keeps the printed captions in a fixed order.
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
# train_set supplies the vocabulary (start/end/pad indices and idx2word) used for decoding.
evaluate_bleu(encoder_decoder, test_dataloader, device, train_set)


Evaluating model.
------------------------------------
Predicted Caption:
woman and white dog and white and white dog is standing on her dog

Actual Caption:
girl in pink shirt riding large white dog indoors
child on big dog looking towards camera inside house
young girl sitting on back of white furry dog
large white dog carrying girl in pink indoors
indoors girl wearing hat rides on big white dog
------------------------------------


The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


------------------------------------
Predicted Caption:
woman with black and black dog with sunglasses is wearing blue jacket and black and black and red shirt and red shirt and white shirt is wearing white shirt and black shirt and white shirt and white shirt is standing on the sand

Actual Caption:
tan dog wearing red sunglasses on beach
small dog with tongue out sitting on sand
happy dog in sunglasses laying on beach
dog wearing shades relaxing by ocean
sunny day dog with red glasses on shore
------------------------------------
------------------------------------
Predicted Caption:
dog is running on the grass

Actual Caption:
dog chasing ball on grassy field
two dogs playing catch outside
brown dog running after pink ball
large white dog walking on grass
people watching dogs play in park
------------------------------------
------------------------------------
Predicted Caption:
two dogs are playing in the grass

Actual Caption:
three dogs playing together in park area
woman stand