In [23]:
#imports
import os
import pandas as pd
import spacy
import torch
from torchvision import transforms
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from PIL import Image

spacy_eng = spacy.load("en_core_web_sm")

In [97]:
torch.cuda.set_device(3) if torch.cuda.is_available() else 'cpu'

In [84]:
class Vocabulary:
    def __init__(self, freq_threshold):
        self.itos = {0:"<PAD>", 1:"<SOS>", 2:"<EOS>", 3:"<UNK>"}
        self.stoi = {"<PAD>":0, "<SOS>":1, "<EOS>":2, "<UNK>":3}
        self.freq_threshold = freq_threshold

    def __len__(self):
        return len(self.itos)
    
    @staticmethod
    def tokenizer_eng(text):
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]
    
    def build_vocabulary(self, sentence_list):
        frequencies = {}
        idx = 4
        
        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                if word not in frequencies:
                    frequencies[word] = 1
                    
                else:
                    frequencies[word] += 1
                    
                if frequencies[word] == self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        tokenized_text = self.tokenizer_eng(text)
            
        return [
            self.stoi[token] if token in self.stoi else self.stoi["<UNK>"]
            for token in tokenized_text
        ]

In [85]:

class FlickrDataset(Dataset):
    def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
        self.root_dir = root_dir
        self.df = pd.read_csv(captions_file)
        self.transform = transform
        
        self.imgs = self.df['image']
        self.captions = self.df['caption']
        
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        caption = self.captions[index]
        img_id = self.imgs[index]
        img = Image.open(os.path.join(self.root_dir, img_id)).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        
        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<EOS>"])
        
        return img, torch.tensor(numericalized_caption)

In [86]:
class MyCollate:
    def __init__(self, pad_idx):
        self.pad_idx = pad_idx
        
    def __call__(self, batch):
        imgs = [item[0].unsqueeze(0) for item in batch]
        imgs = torch.cat(imgs, dim=0)
        targets = [item[1] for item in batch]
        targets = pad_sequence(targets, batch_first=False, padding_value=self.pad_idx)
    
        return imgs, targets

In [99]:
def get_loader(
    root_folder,
    annotation_file,
    transform,
    freq_threshold,
    batch_size=32,
    num_workers=0,
    shuffle=True,
    pin_memory=False
):
    dataset = FlickrDataset(root_folder, annotation_file, transform=transform, freq_threshold=freq_threshold)
    pad_idx = dataset.vocab.stoi["<PAD>"]
    
    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=pin_memory,
        collate_fn=MyCollate(pad_idx=pad_idx)
    )
    
    return loader

In [100]:
root_folder = r'flickr8k/images'
annotation_file = r'flickr8k/captions.txt' 
transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((224,224))])
dataloader = get_loader(root_folder=root_folder, annotation_file=annotation_file, transform=transform, freq_threshold=5)

In [101]:
for idx,(img, cc) in enumerate(dataloader):
    print(img)
    print(cc)
    break

tensor([[[[1.4550e-01, 2.1450e-01, 2.7372e-01,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [1.3006e-01, 2.1889e-01, 2.9534e-01,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [1.6801e-01, 2.5218e-01, 3.4924e-01,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          ...,
          [2.0393e-01, 2.0973e-01, 2.0010e-01,  ..., 3.0726e-03,
           3.2015e-03, 5.8145e-03],
          [2.0275e-01, 2.0016e-01, 1.9788e-01,  ..., 7.8431e-03,
           9.4977e-03, 1.2417e-02],
          [2.0086e-01, 2.0905e-01, 2.0062e-01,  ..., 7.8431e-03,
           1.2360e-02, 1.1213e-02]],

         [[1.2032e-01, 1.6349e-01, 2.1151e-01,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [1.1010e-01, 2.0110e-01, 2.5481e-01,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [2.0049e-01, 2.5605e-01, 3.2306e-01,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          ...,
          [3.3086e-01, 3.3274e-01, 3.3045e-01,  ..., 3.0726