In [9]:
import torch
import torch.nn as nn # NN networks (CNN, RNN, losses)
import torch.optim as optim # Optimizers (Adam, Adadelta, Adagrad)
import torch.nn.functional as F # Activation functions (ReLU, Sigmoid), also available in nn
from torch.utils.data import DataLoader, Dataset # Dataset manager
from torch.nn.utils.rnn import pad_sequence
import torchvision.datasets as datasets # Datasets
import torchvision.transforms as transforms # Transformation to datasets
import torchvision
import pandas as pd
import os
import spacy # Tokenizer
from PIL import Image
from skimage import io 

In [10]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# spaCy pipeline whose tokenizer is used to split caption text into tokens.
spacy_eng = spacy.load("en_core_web_sm")


In [11]:
# Convert text to numbers
# Create a vocabulary
# Use a DataLoader
# Set up padding

In [12]:
class MyCollate:
    """Collate callable that turns a list of (image, caption) samples into a batch.

    Images are joined along a new leading batch dimension; captions are padded
    with ``pad_idx`` up to the longest caption in the batch.
    """

    def __init__(self, pad_idx) -> None:
        # Fill value for captions shorter than the longest one in the batch.
        self.pad_idx = pad_idx

    def __call__(self, batch):
        """Merge ``batch`` (a sequence of ``(image, caption)`` pairs) into two tensors.

        Returns:
            ``(imgs, targets)`` where ``imgs`` is the images concatenated along
            dim 0 and ``targets`` is the padded caption tensor laid out with the
            sequence dimension first (``batch_first=False``).
        """
        image_rows = []
        caption_seqs = []
        for sample in batch:
            # Give each image a leading axis of size 1 so they can be concatenated.
            image_rows.append(sample[0].unsqueeze(0))
            caption_seqs.append(sample[1])

        stacked_imgs = torch.cat(image_rows, dim=0)
        padded_captions = pad_sequence(
            caption_seqs,
            batch_first=False,
            padding_value=self.pad_idx,
        )
        return stacked_imgs, padded_captions
       
        

In [13]:
class Vocabulary:
    """Bidirectional word <-> index mapping built from caption sentences.

    Indices 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``,
    ``<EOS>`` and ``<UNK>``. Regular words receive indices from 4 upward, in
    the order in which their running count reaches ``freq_threshold``.
    """

    def __init__(self, freq_threshold) -> None:
        # itos: index -> token, stoi: token -> index (kept in sync).
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        # Number of occurrences a word needs before it enters the vocabulary.
        # A threshold smaller than 1 is never matched, so no words are added.
        self.freq_threshold = freq_threshold

    def __len__(self):
        """Return the number of tokens in the vocabulary, special tokens included."""
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text):
        """Split ``text`` with the module-level spaCy tokenizer and lowercase each token.

        Returns:
            A list of lowercase token strings.
        """
        # lower() must be *called*: without the parentheses the list would hold
        # bound-method objects instead of strings, which breaks both the
        # frequency counting and every stoi lookup.
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

    def build_vocabulary(self, sentence_list):
        """Add every word that occurs at least ``freq_threshold`` times.

        Args:
            sentence_list: iterable of sentences (strings) to count words in.
        """
        frequencies = {}
        idx = 4  # first free index after the four special tokens

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                frequencies[word] = frequencies.get(word, 0) + 1
                # Register the word exactly once: at the moment its count
                # reaches the threshold.
                if frequencies[word] == self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        """Convert ``text`` into a list of vocabulary indices.

        Tokens that are not in the vocabulary map to the ``<UNK>`` index.
        """
        unk_idx = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk_idx) for token in self.tokenizer_eng(text)]

In [None]:
class FlickrDataset(Dataset):
    """Image-caption dataset backed by a caption CSV and a folder of images.

    Each item is a pair ``(image, caption_tensor)`` where ``caption_tensor``
    holds the vocabulary indices of the caption wrapped in ``<SOS>`` / ``<EOS>``.
    """

    def __init__(self, root_dir, caption_file, transform=None, freq_threshold=5) -> None:
        """Read the caption table and build the vocabulary from all captions.

        Args:
            root_dir: directory that contains the image files.
            caption_file: path passed to ``pd.read_csv``; the table must have
                an ``image`` and a ``caption`` column.
            transform: optional callable applied to each loaded RGB image.
            freq_threshold: occurrence threshold forwarded to ``Vocabulary``.
        """
        self.root_dir = root_dir
        self.df = pd.read_csv(caption_file)
        self.transform = transform

        # Get image and caption columns
        self.imgs = self.df["image"]
        self.captions = self.df["caption"]

        # Init and build a vocab
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self):
        """Return the number of rows (image/caption pairs) in the caption table."""
        return len(self.df)

    def __getitem__(self, index):
        """Load the ``index``-th image and numericalize its caption.

        Returns:
            ``(img, caption)``: the RGB image (after ``transform`` when one is
            set) and a 1-D tensor ``[<SOS>, token indices..., <EOS>]``.
        """
        # The attribute is `captions` (plural), as assigned in __init__;
        # reading `self.caption` raised AttributeError on every access.
        # .iloc makes the lookup explicitly positional.
        caption = self.captions.iloc[index]
        img_id = self.imgs.iloc[index]
        img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<EOS>"])

        return img, torch.tensor(numericalized_caption)
    
        

In [None]:
def get_loader(root_folder,
               annotetion_file,
               transform=None,
               batch_size=32,
               num_workers=3,
               shuffle=True,
               pin_memory=True):
    """Build a DataLoader over a FlickrDataset with caption padding.

    Args:
        root_folder: image directory handed to ``FlickrDataset``.
        annotetion_file: caption file handed to ``FlickrDataset``.
        transform: optional image transform forwarded to the dataset.
        batch_size, num_workers, shuffle, pin_memory: forwarded unchanged to
            ``DataLoader``.

    Returns:
        The configured ``DataLoader``; batches are assembled by ``MyCollate``
        using the dataset vocabulary's ``<PAD>`` index.
    """
    flickr = FlickrDataset(root_folder, annotetion_file, transform=transform)
    # Captions in a batch are padded with the vocabulary's <PAD> index.
    collate = MyCollate(pad_idx=flickr.vocab.stoi["<PAD>"])

    return DataLoader(
        dataset=flickr,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=pin_memory,
        collate_fn=collate,
    )

In [None]:
# Build a loader over the Flickr8k files without any image transform.
annotation_file = "./dataset/flickr8k/captions.txt"
root_folder = "./dataset/flickr8k/images/"

dataloader = get_loader(root_folder, annotation_file, transform=None)

In [None]:
# Resize every image to 224x224 and convert it to a tensor, so all images in a
# batch share one shape and can be concatenated by the collate function.
transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)

# Forward slashes keep the paths portable; a backslash in ".\src" forms the
# invalid escape sequence "\s" and only resolves as a separator on Windows.
root_folder = "./src/dataset/flickr8k/images/"
annotation_file = "./src/dataset/flickr8k/captions.txt"

# get_loader returns only the DataLoader, so it must not be unpacked into two
# names; the dataset it wraps is available through the `dataset` attribute.
loader = get_loader(root_folder, annotation_file, transform=transform)
dataset = loader.dataset

# Print the image-batch and caption-batch shapes for every batch.
for imgs, captions in loader:
    print(imgs.shape)
    print(captions.shape)