#### 1) We need to convert Text to Numerical Values
#### 2) We need a vocabulary mapping each word to an index
#### 3) We need to set up a PyTorch dataset to load the data
#### 4) Set up padding for every batch (so that all examples in a batch have the same sequence length)
#### 5) Set up the DataLoader

Ref: https://www.youtube.com/watch?v=9sHcLvVXsns&list=PLhhyoLH6IjfxeoooqP9rhU3HJIAVAJ3Vz&index=12&ab_channel=AladdinPersson

Ref: https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/Basics/custom_dataset_txt/loader_customtext.py     

In [66]:
import os
import pandas as pd
import spacy # For Tokenization
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from PIL import Image

In [49]:
# Load the `en_core_web_sm` spaCy pipeline once at module level; only its
# tokenizer is used (see Vocabulary.tokenizer_eng).
spacy_eng = spacy.load("en_core_web_sm")

In [50]:
class Vocabulary:
    """Word-level vocabulary mapping tokens to integer indices and back.

    Indices 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``,
    ``<EOS>`` and ``<UNK>``. Regular words are added by
    :meth:`build_vocabulary` once they have been seen ``freq_threshold``
    times; rarer words are mapped to ``<UNK>`` by :meth:`numericalize`.
    """

    def __init__(self, freq_threshold):
        """Create a vocabulary that holds only the special tokens.

        Args:
            freq_threshold: Number of occurrences a word needs before it is
                given its own index.
        """
        # Index to string
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        # String to index. Derived from itos so the two tables are always
        # exact inverses (the unknown token is keyed "<UNK>" in both).
        self.stoi = {token: idx for idx, token in self.itos.items()}
        self.freq_threshold = freq_threshold

    def __len__(self):
        """Return the number of entries, special tokens included."""
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text):
        """Split ``text`` into lower-cased tokens with the spaCy tokenizer."""
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

    def build_vocabulary(self, sentence_list):
        """Add every sufficiently frequent word of ``sentence_list``.

        A word receives the next free index at the moment its running count
        reaches ``freq_threshold``, so indices follow the order in which
        words cross the threshold. Counts are kept per call.

        Args:
            sentence_list: Iterable of raw sentences (strings).
        """
        frequencies = {}
        # Continue after the entries already present (the 4 special tokens on
        # a fresh vocabulary) so existing indices are never overwritten.
        idx = len(self.itos)

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                frequencies[word] = frequencies.get(word, 0) + 1

                # `==` (not `>=`) registers each word exactly once; the
                # membership test keeps a repeated call from re-indexing it.
                if (
                    frequencies[word] == self.freq_threshold
                    and word not in self.stoi
                ):
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        """Convert ``text`` to a list of token indices.

        Tokens missing from the vocabulary are replaced by the ``<UNK>``
        index. ``<SOS>``/``<EOS>`` are not added here.
        """
        unk_idx = self.stoi["<UNK>"]
        return [
            self.stoi.get(token, unk_idx)
            for token in self.tokenizer_eng(text)
        ]

In [51]:
class FlickerDataset(Dataset):
    """Image/caption dataset backed by a captions CSV file.

    The CSV is read with pandas and must provide an ``image`` column (file
    name relative to ``root_dir``) and a ``caption`` column. A ``Vocabulary``
    is built from all captions when the dataset is constructed.
    """

    def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
        """Read the captions file and build the vocabulary.

        Args:
            root_dir: Directory that contains the image files.
            captions_file: Path of the CSV with ``image``/``caption`` columns.
            transform: Optional callable applied to each loaded PIL image.
            freq_threshold: Forwarded to ``Vocabulary``.
        """
        self.root_dir = root_dir
        self.transform = transform

        self.df = pd.read_csv(captions_file)
        self.imgs = self.df["image"]
        self.captions = self.df["caption"]

        # The vocabulary is derived from every caption in the file.
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self):
        """Return the number of rows in the captions file."""
        return self.df.shape[0]

    def __getitem__(self, index):
        """Return ``(image, caption_indices)`` for row ``index``.

        The image is opened as RGB and passed through ``transform`` when one
        was given. The caption is returned as a 1-D tensor of token indices
        wrapped in ``<SOS>`` ... ``<EOS>``.
        """
        text = self.captions[index]
        image_path = os.path.join(self.root_dir, self.imgs[index])
        img = Image.open(image_path).convert("RGB")

        if self.transform:
            img = self.transform(img)

        stoi = self.vocab.stoi
        token_ids = [
            stoi["<SOS>"],
            *self.vocab.numericalize(text),
            stoi["<EOS>"],
        ]
        return img, torch.tensor(token_ids)

In [77]:
class MyCollate:
    """Collate function that batches images and pads captions.

    Each batch item is an ``(image_tensor, caption_tensor)`` pair. Images are
    stacked along a new leading batch dimension; captions are padded with
    ``pad_idx`` to the longest caption in the batch.
    """

    def __init__(self, pad_idx):
        """Remember the index used to pad shorter captions."""
        self.pad_idx = pad_idx

    def __call__(self, batch):
        """Return ``(imgs, targets)`` for a list of dataset items.

        ``targets`` is laid out as (sequence, batch) because ``pad_sequence``
        is called with ``batch_first=False``.
        """
        # Stacking on a new dim 0 yields the same result as unsqueezing each
        # image and concatenating them.
        imgs = torch.stack([sample[0] for sample in batch], dim=0)
        captions = [sample[1] for sample in batch]
        targets = pad_sequence(
            captions, batch_first=False, padding_value=self.pad_idx
        )
        return imgs, targets

In [78]:
def get_loader(root_folder,
               annotation_file,
               transform,
               batch_size=32,
               num_workers=8,
               shuffle=True,
               pin_memory=True):
    """Build a ``DataLoader`` over a ``FlickerDataset``.

    Args:
        root_folder: Directory that contains the image files.
        annotation_file: Path of the captions CSV.
        transform: Callable applied to every loaded image.
        batch_size: Number of samples per batch.
        num_workers: Number of loader worker processes.
        shuffle: Whether to reshuffle the data every epoch.
        pin_memory: Forwarded to ``DataLoader``.

    Returns:
        A ``DataLoader`` whose batches are produced by ``MyCollate``, padding
        captions with the vocabulary's ``<PAD>`` index.
    """
    dataset = FlickerDataset(root_folder, annotation_file, transform=transform)

    # Captions in a batch differ in length, so collate with padding.
    collate = MyCollate(pad_idx=dataset.vocab.stoi["<PAD>"])

    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=pin_memory,
        collate_fn=collate,
    )

In [79]:
# Location of the Flickr8k images and of the captions file, relative to the
# notebook's working directory.
root_dir = "../../Data/data5/flickr8k/images/"
annotation_dir = "../../Data/data5/flickr8k/captions.txt"

# Resize every image to 224x224 and convert it to a tensor so that all images
# in a batch share one shape and can be stacked by the collate function.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)
# Uses get_loader's defaults (batch_size=32, num_workers=8, shuffle=True).
dataloader = get_loader(root_dir, annotation_dir, transform=transform)

In [81]:
# Pull a single batch to sanity-check the tensor shapes, then stop.
for idx, (imgs, captions) in enumerate(dataloader):
    print(imgs.shape)  # (batch, channels, height, width)
    # (longest caption in the batch, batch): MyCollate pads with
    # batch_first=False, so the sequence dimension comes first.
    print(captions.shape)
    break

torch.Size([32, 3, 224, 224])
torch.Size([29, 32])


In [84]:
# First image of the batch: a (3, 224, 224) tensor produced by the
# Resize + ToTensor transform.
imgs[0]

tensor([[[0.1647, 0.1686, 0.1843,  ..., 0.1922, 0.1725, 0.1725],
         [0.1647, 0.1765, 0.1765,  ..., 0.2549, 0.1804, 0.1725],
         [0.1686, 0.2157, 0.2196,  ..., 0.1725, 0.1451, 0.1569],
         ...,
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

        [[0.1373, 0.1529, 0.1961,  ..., 0.1843, 0.1569, 0.1451],
         [0.1765, 0.1647, 0.1882,  ..., 0.2745, 0.1608, 0.1804],
         [0.1882, 0.2000, 0.2627,  ..., 0.1725, 0.0980, 0.1373],
         ...,
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

        [[0.0471, 0.0392, 0.0941,  ..., 0.1098, 0.0902, 0.0863],
         [0.0627, 0.0745, 0.0941,  ..., 0.1843, 0.0980, 0.1098],
         [0.0941, 0.1098, 0.1412,  ..., 0.0980, 0.0667, 0.

In [83]:
# captions is (sequence, batch), so row 0 is the first token of every caption
# in the batch, not the first caption. Each caption starts with <SOS>
# (index 1), hence a row of 32 ones. Use captions[:, 0] for the first caption.
captions[0]

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1])