In [19]:
import os
import pandas as pd
import spacy
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from PIL import Image 
import torchvision.transforms as T

# Load the spaCy English pipeline once, at module level, so it is shared by
# every call to Vocabulary.tokenizer_eng (which uses its `.tokenizer`).
# NOTE(review): presumably requires the "en_core_web_sm" model package to be
# installed in the kernel's environment -- confirm before a fresh run.
spacy_eng = spacy.load("en_core_web_sm")

In [2]:
# We want to convert text -> numerical values:
# 1. We need a Vocabulary mapping each word to an index.
# 2. We need to set up a PyTorch Dataset to load the data.
# 3. We need to pad every batch (all examples in a batch must have
#    the same seq_len) and set up the DataLoader.

In [23]:
class Vocabulary:
    """Word-level vocabulary mapping lower-cased tokens to integer indices.

    Four special tokens are always present: ``<PAD>`` (0), ``<SOS>`` (1),
    ``<EOS>`` (2) and ``<UNK>`` (3). Regular words are registered by
    :meth:`build_vocabulary` once they have been seen ``freq_thresh`` times.
    """

    def __init__(self, freq_thresh):
        """Create a vocabulary holding only the special tokens.

        Args:
            freq_thresh: number of occurrences a word must reach before it is
                added. Counts start at 1 and are compared with ``==``, so a
                threshold below 1 is never reached and no word is added.
        """
        # index to string
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        # string to index
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_thresh = freq_thresh

    def __len__(self):
        """Return the number of entries (special tokens included)."""
        return len(self.itos)

    # "I love peanuts" -> ["i", "love", "peanuts"]
    @staticmethod
    def tokenizer_eng(text):
        """Split ``text`` with the module-level spaCy tokenizer and lower-case each token."""
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

    def build_vocabulary(self, sentence_list):
        """Add every word that occurs ``freq_thresh`` times in ``sentence_list``.

        Indices are handed out in the order in which words reach the
        threshold. The method may be called more than once: new words are
        appended after the existing entries and words that are already
        registered keep their index.

        Args:
            sentence_list: iterable of sentences (strings) to count words in.
        """
        frequencies = {}
        # Continue after the entries already present (the four special tokens
        # on a fresh vocabulary) instead of a hard-coded start index, so a
        # repeated call cannot overwrite existing mappings.
        idx = len(self.itos)

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                frequencies[word] = frequencies.get(word, 0) + 1

                # Register the word the moment it has been seen often enough;
                # `==` (rather than `>=`) makes this fire only once per call.
                if frequencies[word] == self.freq_thresh and word not in self.stoi:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        """Tokenize ``text`` and return the list of vocabulary indices.

        Tokens that are not in the vocabulary map to the ``<UNK>`` index.
        """
        unk_idx = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk_idx) for token in self.tokenizer_eng(text)]
    


class FlickrDataset(Dataset):
    """Image/caption dataset backed by a captions CSV and an image folder.

    The CSV read from ``captions_file`` must provide an ``image`` column
    (file names joined onto ``root_dir``) and a ``caption`` column. A
    ``Vocabulary`` is built from all captions at construction time.
    """

    def __init__(self, root_dir, captions_file, transform=None, freq_thresh=5):
        """
        Args:
            root_dir: directory the image file names are relative to.
            captions_file: path of the CSV with ``image`` and ``caption`` columns.
            transform: optional callable applied to each loaded RGB image.
            freq_thresh: frequency threshold forwarded to ``Vocabulary``.
        """
        self.root_dir = root_dir
        self.df = pd.read_csv(captions_file)
        self.transform = transform

        # get the image, caption columns
        self.imgs = self.df['image']
        self.captions = self.df['caption']

        # initialize vocab and build vocab
        self.vocab = Vocabulary(freq_thresh)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self):
        """Return the number of (image, caption) rows."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(image, caption_indices)`` for the row at position ``index``.

        The image is converted to RGB and passed through ``transform`` when
        one was given. The caption is returned as a 1-D tensor of vocabulary
        indices wrapped in ``<SOS>`` ... ``<EOS>``.
        """
        # Positional lookup: plain `series[index]` is label-based, which breaks
        # for negative indices or any non-default DataFrame index.
        caption = self.captions.iloc[index]
        img_id = self.imgs.iloc[index]

        # The context manager closes the file handle deterministically;
        # convert() returns an independent, fully loaded image.
        with Image.open(os.path.join(self.root_dir, img_id), mode='r') as img_file:
            img = img_file.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        # stoi is string to index
        numericalized_caption = (
            [self.vocab.stoi["<SOS>"]]
            + list(self.vocab.numericalize(caption))
            + [self.vocab.stoi["<EOS>"]]
        )

        return img, torch.tensor(numericalized_caption)
    

class MyCollate:
    """Collate function that batches images and pads captions to equal length.

    Intended to be passed as ``collate_fn`` to a ``DataLoader`` whose dataset
    yields ``(image_tensor, caption_tensor)`` pairs.
    """

    def __init__(self, pad_idx, batch_first=False):
        """
        Args:
            pad_idx: value used to pad captions shorter than the longest one.
            batch_first: forwarded to ``pad_sequence``. With the default
                ``False`` the padded captions are ``(max_len, batch)``;
                with ``True`` they are ``(batch, max_len)``.
        """
        self.pad_idx = pad_idx
        self.batch_first = batch_first

    # allows you to call the object itself as a function
    def __call__(self, batch):
        """Return ``(imgs, targets)`` for a list of ``(image, caption)`` items.

        ``imgs`` stacks the images along a new leading batch dimension, so
        all images must share the same shape. ``targets`` holds the captions
        padded with ``pad_idx``.
        """
        # stack == unsqueeze(0) on every image followed by cat along dim 0
        imgs = torch.stack([item[0] for item in batch], dim=0)
        targets = pad_sequence(
            [item[1] for item in batch],
            batch_first=self.batch_first,
            padding_value=self.pad_idx,
        )
        return imgs, targets
    
def get_loader(root_folder, annotation_file, transform, batch_size=32,
               num_workers=8, shuffle=True, pin_memory=True, freq_thresh=5):
    """Build a ``DataLoader`` over a ``FlickrDataset`` with padded caption batches.

    Args:
        root_folder: directory containing the images.
        annotation_file: captions CSV passed to ``FlickrDataset``.
        transform: callable applied to each loaded image (may be ``None``).
        batch_size: number of samples per batch.
        num_workers: number of ``DataLoader`` worker processes.
        shuffle: whether the ``DataLoader`` shuffles the samples.
        pin_memory: forwarded to ``DataLoader``.
        freq_thresh: vocabulary frequency threshold forwarded to
            ``FlickrDataset`` (same default as the dataset's own).

    Returns:
        The configured ``DataLoader``; the dataset (and its vocabulary) is
        reachable through ``loader.dataset``.
    """
    dataset = FlickrDataset(root_folder, annotation_file, transform=transform,
                            freq_thresh=freq_thresh)
    # pad with the pad token
    pad_idx = dataset.vocab.stoi["<PAD>"]
    # the collate function batches the images and pads the captions
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=pin_memory,
        collate_fn=MyCollate(pad_idx),
    )

In [24]:
# Resize every image to 224x224 and convert it to a tensor so that all images
# in a batch share one shape.
transforms = T.Compose([T.Resize((224, 224)), T.ToTensor()])

dataloader = get_loader(
    root_folder="dataset/flickr8k/Images/",
    annotation_file="dataset/flickr8k/captions.txt",
    transform=transforms,
)

# Sanity check: fetch only the first batch and show the tensor shapes.
for idx, (imgs, captions) in enumerate(dataloader):
    print(imgs.shape)
    print(captions.shape)
    break

torch.Size([32, 3, 224, 224])
torch.Size([22, 32])
