In [1]:
import torch
import torch.nn as nn
from pathlib import Path
import spacy
from torch.nn.utils.rnn import pad_sequence
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision
from pathlib import Path
import os
import pandas as pd
from PIL import Image
from torch.utils.data import (
    Dataset,
    DataLoader,
)
spacy_eng = spacy.load("en_core_web_sm")

In [2]:
class MyCollate:
    def __init__(self, pad_idx):
        print("Collate initialization")
        self.pad_idx = pad_idx

    def __call__(self, batch):
        print("Collate call")
        imgs = [item[0].unsqueeze(0) for item in batch]
        imgs = torch.cat(imgs, dim = 0)

        labels = [item[1] for item in batch]
        labels = pad_sequence(labels, batch_first=False, padding_value=self.pad_idx)

        return imgs, labels

In [3]:
class Vocabulary:
    def __init__(self, frequence_treshold):
        print("VVocab Init")
        self.itos = {0:"<PAD>", 1:"<SOS>", 2:"<EOS>", 3:"<UNK>"}
        self.stoi = {"<PAD>":0, "<SOS>":1, "<EOS>":2, "<UNK>":3}
        self.freq_threshold = frequence_treshold

    def __len__(self):
        print("Get len vocab")
        return len(self.itos)
    
    @staticmethod
    def tokenizer_eng(text):
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]
    
    def build_vocabulary(self, sentence_list):
        print("Vocab build")
        idx = 4
        frequencies = {}
        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                if word not in frequencies:
                    frequencies[word] = 1
                else:
                    frequencies[word] += 1

                if frequencies[word] == self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        tokenized_text = self.tokenizer_eng(text)
        return [self.stoi[token] if token in self.stoi else self.stoi["UNK"] for token in tokenized_text]

In [4]:
class FlickrDataset(Dataset):
    def __init__(self, root, captionsFile, transform = None, frequency_treshold = 5):
        print("Dataset initialized")
        self.root = root
        self.df = pd.read_csv(captionsFile)
        self.transform = transform

        self.images = self.df["image"]
        self.captions = self.df["caption"]

        self.vocab = Vocabulary(frequency_treshold)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self):
        print("Get len dataset")
        return len(self.df)
    
    def __getitem__(self, index):
        print("Get item dataset")
        caption = self.captions[index]
        image_id = self.images[index]
        img = Image.open(os.path.join(self.root, image_id)).convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        numericCaption = [self.vocab.stoi["<SOS>"]]
        numericCaption += self.vocab.numericalize(caption)
        numericCaption.append(self.vocab.stoi["<EOS>"])
        print("Working on")
        return img, torch.tensor(numericCaption)

In [5]:
def get_loader(root, annotation, transform, batch_size = 32, num_workers = 1, shuffle = False, pin_memory = True):
    print("Get loader")
    dataset = FlickrDataset(root, annotation, transform=transform)
    pad_idx = dataset.vocab.stoi['<PAD>']

    loader = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle, pin_memory=pin_memory, collate_fn=MyCollate(pad_idx=pad_idx))
    
    return loader

In [6]:
dataloader = get_loader("../../../data/flickr8k/Images/", "../../../data/flickr8k/captions.txt", transform = transforms.ToTensor())

for idx, (img, captions) in enumerate(dataloader):
    print(img.shape)
    print(captions.shape)
    

Get loader
Dataset initialized
VVocab Init
Vocab build
Collate initialization
