<h1>Dataset and DataLoader for Image Captioning</h1>


<h3><span style='color:yellow'>In the previous three tutorials, we highlighted Torchtext's ability to handle text data from JSON, CSV, TSV, and built-in torch datasets.</span></h3>

<h3><span style='color:yellow'>In this tutorial, we will demonstrate how to build a custom dataset for image captioning applications. Here, we construct dataset objects for both image and text data, reducing the demand for Torchtext.</span></h3>

<h3><span style='color:yellow'>We will use the Flicker 8k dataset, which includes a folder of images and a corresponding caption text file, stacking captions for each image (5 captions for every image).</span></h3>

<div style="display: flex; justify-content: center;">
    <img src='imagcaption.png' width='600'>
</div>

<h3><span style='color:yellow'>The workflow is as follows:</span></h3>

<ul style='font-size: 1.2em;'>
    <li> Convert text to numerical values, utilize tokenization, and map each word to an index.</li>
    <li> Build a Pytorch dataset to load the data.</li>
    <li> Set up batch padding (each example is attributed with padding to ensure samples have equal length).</li>
    <li> Configure the DataLoader to batch images and captions together, then pass them to the model for training and inference.</li>
</ul>


In [58]:
# Importing Libraries
import os
import torch
import pandas as pd
import spacy
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from PIL import Image

In [59]:
# Constructing the primary Flickr dataset class
# Step: 1

Image_path='./datastes/image captioning/Images/'
annotations='./datastes/image captioning/captions.txt'

class FlikerDataset(Dataset):
    def __init__(self,Image_path, captions_file, transform=None,frequancy_thresh=5):
        self.Image_path=Image_path
        self.df=pd.read_csv(captions_file)
        self.transform=transform
        
        # Getting the name of the image and captions
        self.img_name=self.df['image']
        self.captions=self.df['caption']
        
        # Building and initialize the vocabulary
        self.vocab=Vocabulary(frequancy_thresh)
        self.vocab.build_vocabulary(self.captions.tolist())
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        caption=self.captions[index]
        img_id=self.img_name[index]
        image=Image.open(os.path.join(self.Image_path,img_id)).convert('RGB')
        
        if self.transform:
            image=self.transform(image)
        
        numericalized_caption=[self.vocab.stoi['<SOS>']]
        numericalized_caption+=self.vocab.numericalize(caption)    
        numericalized_caption.append(self.vocab.stoi['<EOS>'])
        
        return image, torch.tensor(numericalized_caption)
    
    
    


In [60]:
# Building a vocabulary class
#Step :2
spacy_eng=spacy.load('en_core_web_sm')

class Vocabulary:
    def __init__(self,frequancy_threshold):
        self.itos={0:'<PAD>',1:'<SOS>',2:'<EOS>',3:'<UKN>'} # UKN= unknown words
        # The inverse of itos
        self.stoi={'<PAD>':0,'<SOS>':1,'<EOS>':2, '<UKN>':3}
        self.frequancy_threshold=frequancy_threshold
        
    def __len__(self):
        return len(self.itos)
    
    #  Define the English tokenizer
    @staticmethod
    def tokenizer(text):
        return [token.text.lower() for token in spacy_eng.tokenizer(text)]
    
    def build_vocabulary(self, sentence_list):
        frequancies={}  # To count specific wrod repeated in the dictionary
        index=4  # 4 because we already included th eidxs from 0-3  in the dictionary as special tokens
        
        for sentence in sentence_list:
            for word in self.tokenizer(sentence):
                if word not in frequancies:
                    frequancies[word]=1
                else:
                    frequancies[word]+=1
                    
                if frequancies[word]==self.frequancy_threshold:
                    self.stoi[word]=index
                    index=1
    def numericalize(self,text):
        tokenized_text=self.tokenizer(text)
        return [self.stoi[token] if token in self.stoi else self.stoi['<UKN>'] for token in tokenized_text]
                    
        
    

In [61]:
# Padding the tokenized sentences to have equal lengths.
# This can be achieved either by truncating all sentences to the length of the shortest sentence,
# or by padding all sentences to match the length of the longest sentence.

class Mycollate: # collate  is a collable entity in Pytorch DataLoader
    def __init__(self,pad_idx):
        self.pad_idx=pad_idx
        
    def __call__(self,batch): # Each batch is a list of examples. For each example, there will be an image and a numericalized token.
        # For batching purposes, we need an extra dimension for image data.
        images=[item[0].unsqueeze(0) for item in batch]
        images=torch.cat(images,dim=0)
        targets=[item[1] for item in batch]
        targets=pad_sequence(targets,batch_first=False,padding_value=self.pad_idx)
        return images,targets
        
        


In [62]:
# 4 Define the dataloader
def get_loader(
    root_folder,
    annotation_file,
    transform,
    batch_size=32,
    num_workers=4,
    shuffle=True,
    pin_memory=True):
    
    dataset=FlikerDataset(Image_path,annotation_file,transform=transform)
    pad_idx=dataset.vocab.stoi["<PAD>"]
    loader=DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=pin_memory,
        collate_fn=Mycollate(pad_idx=pad_idx)
    )
    
    return loader
    

In [63]:
transform=transforms.Compose([transforms.Resize((256, 256), interpolation=Image.BILINEAR),transforms.ToTensor(),transforms.Normalize((0.5,),(0.5))]) #( (mean), (std))

# Since we have multiple workers, it's better to wrap the following code under the main function.
def main():
    dataLoader=get_loader(root_folder=Image_path,annotation_file=annotations,transform=transform)

    for idx,(images,captions) in enumerate(dataLoader):
        print(images.shape)
        print(captions.shape)
if __name__=='__main__':
    main()

torch.Size([32, 3, 256, 256])
torch.Size([22, 32])
torch.Size([32, 3, 256, 256])
torch.Size([21, 32])
torch.Size([32, 3, 256, 256])
torch.Size([24, 32])
torch.Size([32, 3, 256, 256])
torch.Size([21, 32])
torch.Size([32, 3, 256, 256])
torch.Size([25, 32])
torch.Size([32, 3, 256, 256])
torch.Size([28, 32])
torch.Size([32, 3, 256, 256])
torch.Size([22, 32])
torch.Size([32, 3, 256, 256])
torch.Size([22, 32])
torch.Size([32, 3, 256, 256])
torch.Size([21, 32])
torch.Size([32, 3, 256, 256])
torch.Size([25, 32])
torch.Size([32, 3, 256, 256])
torch.Size([25, 32])
torch.Size([32, 3, 256, 256])
torch.Size([19, 32])
torch.Size([32, 3, 256, 256])
torch.Size([23, 32])
torch.Size([32, 3, 256, 256])
torch.Size([25, 32])
torch.Size([32, 3, 256, 256])
torch.Size([23, 32])
torch.Size([32, 3, 256, 256])
torch.Size([19, 32])
torch.Size([32, 3, 256, 256])
torch.Size([20, 32])
torch.Size([32, 3, 256, 256])
torch.Size([21, 32])
torch.Size([32, 3, 256, 256])
torch.Size([21, 32])
torch.Size([32, 3, 256, 256])
t