### Importing libraries for dataset creation

In [1]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

In [2]:
from torch.utils.data import DataLoader, Dataset

In [3]:
import pandas as pd
import string

In [1]:
from torchvision import transforms
from torchvision.io import read_image

In [5]:
from pathlib import Path

In [1]:
import pickle


In [7]:
def remove_spaces(str_):
    return ' '.join(str_.split())

In [10]:
class ImageCaptioningDataset(Dataset):
    def __init__(self,caption_file,image_dir,image_size=(224,224),image_transform=transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])):
        self.image_dir = Path(image_dir)
        captions_csv = pd.read_csv(caption_file, sep="|")
        self.captions_csv = captions_csv.rename(lambda x: x.strip(), axis=1)
        self.captions_csv.comment.loc[19999] = self.captions_csv.comment_number.loc[19999][3:]
        self.captions_csv.comment_number.loc[19999] = 4
        self.captions_csv['comment'] = self.captions_csv['comment'].str.lower()
        self.captions_csv['comment'] = self.captions_csv['comment'].str.strip()
        self.captions_csv['comment'] = self.captions_csv['comment'].apply(remove_spaces)
        self.captions_csv['comment'] = self.captions_csv['comment'].str.translate(str.maketrans('', '', string.punctuation))
        
        self.image_transform = transforms.Compose([transforms.Resize(image_size),transforms.ToTensor(),image_transform])        

        self.tokenizer = get_tokenizer('spacy', language='en_core_web_sm',)
        self.list_of_tokens = self.captions_csv['comment'].apply(self.tokenizer)
        self.vocab = build_vocab_from_iterator(self.list_of_tokens,min)
        self.vocab.append_token('<START>')
        self.vocab.append_token('<END>')
        self.vocab.append_token('<PAD>')
        self.list_of_tokens = list(map(lambda x: ['<START>'] + x + ['<END>'], self.list_of_tokens))
        self.list_of_tokens = [torch.IntTensor(self.vocab(tokens)) for tokens in self.list_of_tokens]
        self.images = {image_name: False for image_name in self.captions_csv.image_name.unique()}
        
            
    def __len__(self):
        return self.captions_csv.shape[0]
    
    def __getitem__(self, idx):
        if isinstance(self.images[self.captions_csv.loc[idx,'image_name']], torch.Tensor):
            return self.images[self.captions_csv.loc[idx,'image_name']], self.list_of_tokens[idx], self.list_of_tokens[idx].shape[0]
        image_path = self.image_dir / Path(self.captions_csv.loc[idx,'image_name'])
        image = Image.open(image_path)
        image = self.image_transform(image)
        self.images[self.captions_csv.loc[idx,'image_name']] = image
        return image, self.list_of_tokens[idx], self.list_of_tokens[idx].shape[0]

In [15]:
class ImageCaptioningDatasetOnlyFeatures(Dataset):
    def __init__(self,caption_file,image_dir,encoder=None,cache_file=None):
        self.image_dir = Path(image_dir)
        captions_csv = pd.read_csv(caption_file, sep="|")
        
        self.captions_csv = captions_csv.rename(lambda x: x.strip(), axis=1)
        self.captions_csv.comment.loc[19999] = self.captions_csv.comment_number.loc[19999][3:]
        self.captions_csv.comment_number.loc[19999] = 4
        self.captions_csv['comment'] = self.captions_csv['comment'].str.lower()
        self.captions_csv['comment'] = self.captions_csv['comment'].str.strip()
        self.captions_csv['comment'] = self.captions_csv['comment'].apply(remove_spaces)
        self.captions_csv['comment'] = self.captions_csv['comment'].str.translate(str.maketrans('', '', string.punctuation))
    
        self.tokenizer = get_tokenizer('spacy', language='en_core_web_sm',)
        self.list_of_tokens = self.captions_csv['comment'].apply(self.tokenizer)
        self.vocab = build_vocab_from_iterator(self.list_of_tokens, min_freq=10,special_first=True,specials=["<START>", "<END>", "<PAD>"],)
        self.vocab.set_default_index(-1)
        self.list_of_tokens = list(map(lambda x: ['<START>'] + x + ['<END>'], self.list_of_tokens))
        self.list_of_tokens = [torch.IntTensor([ele for ele in self.vocab(tokens) if ele !=-1]) for tokens in self.list_of_tokens]
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if encoder:
            self.encoder=encoder.to(self.device).eval()
        if cache_file:
            with open(cache_file, 'rb') as f:
                self.images_features = pickle.load(f)
        else:
            self.images_features = {image_name: False for image_name in self.captions_csv.image_name.unique()}
            
    def __len__(self):
        return self.captions_csv.shape[0]
    
    def __getitem__(self, idx):
        if isinstance(self.images_features[self.captions_csv.loc[idx,'image_name']], torch.Tensor):
            return self.images_features[self.captions_csv.loc[idx,'image_name']], self.list_of_tokens[idx], self.list_of_tokens[idx].shape[0], self.image_dir / Path(self.captions_csv.loc[idx,'image_name'])
        image_path = self.image_dir / Path(self.captions_csv.loc[idx,'image_name'])
        image = read_image(str(image_path))
        features = self.encoder(image.to(self.device)).squeeze().cpu()
        self.images_features[self.captions_csv.loc[idx,'image_name']] = features
        return features, self.list_of_tokens[idx], self.list_of_tokens[idx].shape[0], image_path

