Importing the required libraries

In [1]:
import os
import spacy
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torchtext
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.models as models

from PIL import Image
from collections import Counter

from torchtext.vocab import vocab
from torchtext.vocab import vocab as build_vocab
from torchtext.data.utils import get_tokenizer

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, dataloader



Image Captions

Each image in Flickr8k has five reference captions; the captions file contains one row per (image, caption) pair.

In [2]:
# Captions file: one row per (image, caption) pair.
img_cap = pd.read_csv("Dataset/flickr8k/captions.txt")

# Fixed random_state so the preview shows the same rows on every Restart & Run All.
img_cap.sample(10, random_state=42)

Unnamed: 0,image,caption
34286,3683185795_704f445bf4.jpg,Boy and girl look on a puppy climbs tree
8438,238512430_30dc12b683.jpg,A person sits on the front deck of a ship and ...
4063,197504190_fd1fc3d4b7.jpg,Two children play soccer in the park .
17501,2950637275_98f1e30cca.jpg,A man doing a handstand outside of a garage .
13119,2677656448_6b7e7702af.jpg,The small brown and white dog is in the pool .
6302,2223382277_9efa58ec45.jpg,Two children are playing hockey on a frozen pond
23626,3252588185_3210fe94be.jpg,A woman in a white shirt and dark jacket stand...
24145,3270273940_61ef506f05.jpg,Children jumping off of cement .
23013,3225998968_ef786d86e0.jpg,A skier in red pants is on a snow covered slope .
2105,1417882092_c94c251eb3.jpg,Two guys standing side by side .


Loading the tokenizer and initializing the Counter

In [3]:
# Token-frequency table for the caption vocabulary.
counter = Counter()

# Word-level tokenizer: torchtext's "basic_english" preset.
tokenizer = get_tokenizer(tokenizer="basic_english")

Building the vocabulary for captions

In [4]:
# Count token frequencies over every caption. A fresh Counter is built here
# (instead of updating the one from the previous cell) so re-running this
# cell does not double-count tokens.
counter = Counter(
    token
    for caption in img_cap["caption"].tolist()
    for token in tokenizer(caption)
)

# Tokens seen fewer than MIN_FREQ times are excluded from the vocabulary;
# a later cell sets <unk> as the default index for such tokens.
MIN_FREQ = 5

# `build_vocab` is torchtext's `vocab` factory under an alias. The result is
# bound to `vocab`, which shadows the imported factory name; calling it
# through the alias keeps this cell re-runnable.
vocab = build_vocab(counter, min_freq=MIN_FREQ)

Adding the special tokens and setting the default index as UNK (unknown token)

In [5]:
# Special tokens, placed at fixed indices 0-3 (<unk>=0, <pad>=1, <sos>=2, <eos>=3).
unk_token = "<unk>"
pad_token = "<pad>"
sos_token = "<sos>"
eos_token = "<eos>"

# Vocab.insert_token raises if the token already exists, so tokens that are
# already present are skipped; this keeps the cell safe to re-run.
for position, special_token in enumerate([unk_token, pad_token, sos_token, eos_token]):
    if special_token not in vocab:
        vocab.insert_token(special_token, position)

# Tokens not in the vocabulary (e.g. those dropped by min_freq) map to <unk>.
vocab.set_default_index(vocab[unk_token])

Now we define the dataset class. Its `__getitem__` method returns an image together with its caption converted to a tensor of integer token indices, wrapped in `<sos>` / `<eos>`.

In [6]:
class FlickrDataset(Dataset):
    """Map-style dataset of (image, encoded caption) pairs.

    Each item is ``(img, caption_tensor)``. ``img`` is an RGB PIL image, or the
    output of ``transform`` if one is given. ``caption_tensor`` is a 1-D
    ``torch.long`` tensor of vocabulary indices starting with ``<sos>`` and
    ending with ``<eos>``.
    """

    def __init__(self, root_dir, captions_file, vocab, transform=None, tokenizer=None):
        """
        root_dir: Path to the images folder
        captions_file: Path to the CSV file with "image" and "caption" columns
        vocab: Vocabulary object mapping tokens to indices; must contain "<sos>" and "<eos>"
        transform: Optional transform to be applied on the images
        tokenizer: Optional callable mapping a caption string to a list of tokens;
            defaults to torchtext's "basic_english" tokenizer
        """
        self.root_dir = root_dir
        self.transform = transform
        self.vocab = vocab
        # Use the instance's own tokenizer rather than a notebook-global one.
        self.tokenizer = tokenizer if tokenizer is not None else get_tokenizer("basic_english")

        df = pd.read_csv(captions_file)

        # Plain lists give positional indexing; Series[] is label-based.
        self.captions = df["caption"].tolist()
        self.img_names = df["image"].tolist()
        self.length = len(df)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        img_location = os.path.join(self.root_dir, self.img_names[idx])

        # The context manager closes the underlying file; convert() returns a
        # separate, fully loaded image that outlives the `with` block.
        with Image.open(img_location) as raw_img:
            img = raw_img.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        return img, self._encode_caption(self.captions[idx])

    def _encode_caption(self, caption):
        """Tokenize `caption`, map tokens to indices, and wrap them in <sos>/<eos>."""
        # Always use self.vocab so a dataset built with a different vocabulary
        # gets that vocabulary's special-token indices.
        indices = [self.vocab["<sos>"]]
        indices.extend(self.vocab[token] for token in self.tokenizer(caption))
        indices.append(self.vocab["<eos>"])
        return torch.tensor(indices, dtype=torch.long)