

1.   Convert text to numerical values
2.   Build a vocabulary that maps each word to an index
3.   Set up a PyTorch dataset to load the data
4.   Set up padding for every batch
5.   Set up a DataLoader



In [4]:
import torch
import torch.nn as nn 
import torch.optim as optim 
import torch.nn.functional as F # activation functions.....etc
from torch.utils.data import DataLoader # easier dataset management
import torchvision.datasets as datasets 
import torchvision.transforms as transforms 
import torchvision.models as models
import torchvision
from PIL import Image
import os
import pandas as pd
import spacy # for tokenization 
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset


In [10]:
# English spaCy pipeline; only its tokenizer is used by this notebook's Vocabulary.
# spaCy v3 removed the "en" shortcut link, so load the full package name first
# and fall back to the legacy shortcut for older spaCy v2 installs that only
# have the link.
try:
    spacy_eng = spacy.load("en_core_web_sm")
except OSError:
    spacy_eng = spacy.load("en")

In [36]:
class Vocabulary:
    """Bidirectional word <-> index mapping built from a list of sentences.

    Indices 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``,
    ``<EOS>`` and ``<UNK>``. Every other word receives the next free index
    the first time its count reaches ``freq_threshold``.
    """

    def __init__(self, freq_threshold):
        # itos: index -> string, stoi: string -> index; always kept in sync.
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold

    def __len__(self):
        """Return the vocabulary size, special tokens included."""
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text):
        """Return the lowercased token strings of *text* from the spaCy tokenizer."""
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

    def build_vocabulary(self, sentence_list):
        """Add every word occurring at least ``freq_threshold`` times.

        Words are numbered in the order in which they reach the threshold.
        Calling this again extends the vocabulary rather than overwriting
        existing indices; word counts are not carried over between calls.

        Args:
            sentence_list: Iterable of raw sentence strings.
        """
        frequencies = {}
        # Continue after the last index in use (4 on a fresh vocabulary) so a
        # repeated call cannot clobber entries added earlier.
        idx = len(self.itos)

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                frequencies[word] = frequencies.get(word, 0) + 1
                # `>=` rather than `==` so a threshold <= 1 still admits words;
                # the membership test prevents assigning a word a second index.
                if frequencies[word] >= self.freq_threshold and word not in self.stoi:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        """Convert *text* to a list of indices, mapping unknown words to ``<UNK>``."""
        unk_idx = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk_idx) for token in self.tokenizer_eng(text)]

In [37]:
class MyCollate:
    """Batch collate function: stacks images and pads caption index sequences.

    Calling an instance on a list of ``(image, caption)`` samples returns
    ``(imgs, targets)``:

    * ``imgs`` stacks the per-sample image tensors along a new leading
      batch dimension;
    * ``targets`` is the captions padded with ``pad_idx`` using
      ``batch_first=False``, i.e. shaped ``(max_len, batch)``.
    """

    def __init__(self, pad_idx):
        # Value written into the padded positions of shorter captions.
        self.pad_idx = pad_idx

    def __call__(self, batch):
        images = []
        captions = []
        for sample in batch:
            images.append(sample[0])
            captions.append(sample[1])

        # stack(dim=0) == cat of each tensor unsqueezed at dim 0.
        stacked = torch.stack(images, dim=0)
        padded = pad_sequence(captions, batch_first=False, padding_value=self.pad_idx)
        return stacked, padded


In [79]:
class FlickerDataset(Dataset):
    """Flickr8k-style image/caption dataset yielding ``(image, caption_indices)``.

    Args:
        root: Directory containing the image files.
        captions_file: CSV read with ``pd.read_csv``; must provide an ``image``
            column (file name relative to *root*) and a ``caption`` column.
        transforms: Optional callable applied to each RGB PIL image.
        freq_thershold: Minimum word count for a word to enter the vocabulary
            (parameter name, misspelling included, kept for caller compatibility).
    """

    def __init__(self, root, captions_file, transforms=None, freq_thershold=5):
        self.root = root
        self.df = pd.read_csv(captions_file)
        # Bug fix: this previously read an undefined/global name `transform`,
        # silently ignoring the argument actually passed in.
        self.transform = transforms

        # Image file names and their captions, one row per (image, caption) pair.
        self.img = self.df["image"]
        self.caption = self.df["caption"]

        # Build the vocabulary from every caption in the file.
        self.vocab = Vocabulary(freq_thershold)
        self.vocab.build_vocabulary(self.caption.tolist())

    def __len__(self):
        """Return the number of (image, caption) rows."""
        return len(self.df)

    def __getitem__(self, index):
        """Return the (optionally transformed) image and its caption as a tensor.

        The caption tensor is ``<SOS>`` + numericalized caption + ``<EOS>``.
        """
        # Positional lookup: sampler indices are 0..len-1 regardless of the
        # DataFrame's index labels.
        caption = self.caption.iloc[index]
        img_id = self.img.iloc[index]

        # Close the underlying file handle; convert() fully loads the pixels
        # into a new image, so `img` stays valid after the file is closed.
        with Image.open(os.path.join(self.root, img_id)) as raw_img:
            img = raw_img.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        stoi = self.vocab.stoi
        numerical_caption = [
            stoi["<SOS>"],
            *self.vocab.numericalize(caption),
            stoi["<EOS>"],
        ]
        return img, torch.tensor(numerical_caption)

In [80]:
def get_loader(
    root,
    annotation_file,
    transform,
    batch_size=32,
    num_workers=2,
    shuffle=True,
    pin_memory=True,
):
    """Create a FlickerDataset and a DataLoader that pads captions per batch.

    Args:
        root: Image directory passed to ``FlickerDataset``.
        annotation_file: Captions CSV passed to ``FlickerDataset``.
        transform: Image transform passed to ``FlickerDataset``.
        batch_size, num_workers, shuffle, pin_memory: Forwarded to ``DataLoader``.

    Returns:
        ``(loader, dataset)``. The dataset is returned too so callers can
        reach its vocabulary.
    """
    dataset = FlickerDataset(root, annotation_file, transform)
    # Pad shorter captions in each batch with the vocabulary's <PAD> index.
    collate = MyCollate(dataset.vocab.stoi["<PAD>"])

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=collate,
    )
    return loader, dataset

In [81]:
# Resize every image to one fixed size so the collate step can stack them
# into a single batch tensor, then convert PIL images to tensors.
transform = transforms.Compose(
    [
        transforms.Resize((300, 300)),
        transforms.ToTensor(),
    ]
)

dataloader, dataset = get_loader(
    root="/content/drive/MyDrive/Pytorch Tutorial/Datasets/flickr8k/images",
    annotation_file="/content/drive/MyDrive/Pytorch Tutorial/Datasets/flickr8k/captions.txt",
    transform=transform,
)

In [82]:
# Sanity check: print the image batch shape and padded captions of the
# second batch (index 1), then stop iterating.
for idx, (img, caption) in enumerate(dataloader):
    if idx != 1:
        continue
    print(img.shape)
    print(caption)
    break

torch.Size([32, 3, 300, 300])
tensor([[   1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1],
        [   4,    4,  139,    4,    4,    4,   10,    4,    4,    4,   71,    4,
           50,    4,    4,    4,    4,    4,    4,    4,   96,   71,    4,    4,
            4,    4,   57,    4,    4,  574,  166,    4],
        [  20,  402,   61,   14,   61,   14,  362,    6,   21,   80, 2564,    6,
           51,   61,  196,   61,  610,   14,   14,   80,   98,   97,   80, 2214,
          907,   56,   16,   20,   20, 2032,   97,   28],
        [  16,  116,   68,   17,  145,   17,   17,  111,    6,    8,  122,   17,
           34,   43,   80,   43,    8,   12,    8, 1462,  337,   34,   79,  316,
            7,  111,  543,    6,    6,  629,  108, 1558],
        [  21, 1021,  205,   32,    7,    3,  578,   13,   17,    4,    4,   29,
         