In [1]:
import os
from collections import Counter
import numpy as np
import pandas as pd
import spacy
from PIL import Image
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T

In [2]:
class Vocabulary:
    """Bidirectional word <-> index mapping built from text with a frequency cut-off.

    Indices 0-3 are reserved for the special tokens <PAD>, <SOS>, <EOS> and <UNK>.
    Words that reach ``freq_threshold`` occurrences (counted within a single
    ``build_vocab`` call) receive the next free index, in the order they reach it.
    """

    def __init__(self, threshold=5, spacy_model="en_core_web_lg"):
        """
        Args:
            threshold: minimum number of occurrences of a word within one
                ``build_vocab`` call before it gets its own index.
            spacy_model: name of the spaCy pipeline to load; only its
                ``tokenizer`` is used by this class.
        """
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {word: index for index, word in self.itos.items()}
        self.freq_threshold = threshold
        self.spacy_eng = spacy.load(spacy_model)

    def __len__(self):
        """Number of entries in the vocabulary, special tokens included."""
        return len(self.itos)

    def tokenize(self, text):
        """Split ``text`` with the spaCy tokenizer and lowercase every token.

        Returns:
            list[str]: lowercased token strings in order.
        """
        return [token.text.lower() for token in self.spacy_eng.tokenizer(text)]

    def build_vocab(self, text_list):
        """Add every word occurring at least ``freq_threshold`` times in ``text_list``.

        Safe to call repeatedly: words already in the vocabulary keep their index,
        and new words are appended after the current highest index.

        Args:
            text_list: iterable of raw text strings.
        """
        frequencies = Counter()
        for text in text_list:
            for word in self.tokenize(text):
                frequencies[word] += 1
                # `>=` + membership check (instead of `== threshold` with a fixed
                # start index) keeps repeated calls from re-indexing or overwriting
                # existing entries; indices stay contiguous, so len(itos) is the
                # next free slot.
                if frequencies[word] >= self.freq_threshold and word not in self.stoi:
                    next_index = len(self.itos)
                    self.stoi[word] = next_index
                    self.itos[next_index] = word

    def string_to_numerical(self, text):
        """Map ``text`` to a list of vocabulary indices; unknown words map to <UNK>.

        Returns:
            list[int]: one index per token produced by ``tokenize``.
        """
        unk_index = self.stoi['<UNK>']
        return [self.stoi.get(token, unk_index) for token in self.tokenize(text)]

In [16]:
class Collator:
    """``collate_fn`` for DataLoader: stacks images and pads caption index tensors.

    Each batch element is expected to be an ``(image_tensor, caption_tensor)`` pair
    whose image tensors all share one shape (e.g. the dataset's transform ends with
    ``ToTensor``) — assumption, not checked here.
    """

    def __init__(self, pad_idx, batch_first=False):
        """
        Args:
            pad_idx: value used to pad shorter caption tensors.
            batch_first: if True, captions come out as (batch, seq_len);
                otherwise as (seq_len, batch), following ``pad_sequence``.
        """
        self.pad_idx = pad_idx
        self.batch_first = batch_first

    def __call__(self, batch):
        """Return ``(images, captions)``; images gain a new leading batch dimension."""
        stacked_images = torch.stack([sample[0] for sample in batch], dim=0)
        caption_list = [sample[1] for sample in batch]
        padded_captions = pad_sequence(
            caption_list,
            batch_first=self.batch_first,
            padding_value=self.pad_idx,
        )
        return stacked_images, padded_captions

In [12]:
class LoadData(Dataset):
    """Map-style dataset of ``(image, caption_indices)`` pairs read from a captions CSV.

    Args:
        image_dir: directory containing the image files named in the CSV.
        caption_dir: path to the captions CSV *file* (despite the name); it must
            have ``image`` and ``caption`` columns.
        transform: optional callable applied to each RGB PIL image.
        freq_threshold: minimum word count for a word to enter the vocabulary.
    """

    def __init__(self, image_dir, caption_dir, transform=None, freq_threshold=5):
        self.image_dir = image_dir
        self.caption_dir = caption_dir
        self.caption_df = pd.read_csv(caption_dir)
        self.image_names = self.caption_df['image']
        self.captions = self.caption_df['caption']
        # The vocabulary is built from every caption in the file.
        self.vocab = Vocabulary(threshold=freq_threshold)
        self.vocab.build_vocab(self.captions.tolist())
        self.transform = transform

    def __len__(self):
        """Number of (image, caption) rows in the CSV."""
        return len(self.caption_df)

    def __getitem__(self, idx):
        """Load row ``idx``.

        Returns:
            tuple: ``(img, caption_tensor)`` where ``img`` is the (optionally
            transformed) RGB image and ``caption_tensor`` is a 1-D tensor of
            ``[<SOS>, *word indices, <EOS>]``.
        """
        # Positional access: independent of the CSV's index labels.
        caption = self.captions.iloc[idx]
        img_name = self.image_names.iloc[idx]
        # os.path.join works whether or not image_dir ends with a separator.
        img_path = os.path.join(self.image_dir, img_name)
        # The context manager closes the file; convert() returns a separate copy.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        if self.transform:
            img = self.transform(img)
        stoi = self.vocab.stoi
        caption_vector = [
            stoi['<SOS>'],
            *self.vocab.string_to_numerical(caption),
            stoi['<EOS>'],
        ]
        return img, torch.tensor(caption_vector)

In [5]:
# Data-loading settings; num_worker = 0 loads batches in the main process.
batch_size = 32
num_worker = 0
# Per-channel mean/std published for torchvision's ImageNet-pretrained models.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
# Resize the shorter side to 256, take a random 224x224 crop, then normalize.
transform = T.Compose(
    [
        T.Resize(256),
        T.RandomCrop(224),
        T.ToTensor(),
        T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)

In [17]:
# Image file names from the CSV are appended to image_dir, so keep its trailing '/'.
DATA_ROOT = '../data/'
dataset = LoadData(
    image_dir=DATA_ROOT + 'Images/',
    caption_dir=DATA_ROOT + 'captions.txt',
    transform=transform,
)

In [18]:
# Captions are padded with the vocabulary's <PAD> index; batch_first=True makes
# each caption batch (batch, seq_len).
pad_idx = dataset.vocab.stoi['<PAD>']
collator = Collator(pad_idx=pad_idx, batch_first=True)
data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_worker,
    collate_fn=collator,
)