In [116]:
import os
import matplotlib as plp
import numpy as np
import pandas as pd
import torch
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor, Resize
from torch.utils.data import Dataset
import PIL.Image as Image

In [117]:
# Load data
class Dataset(Dataset):
    """Image/caption dataset backed by a pipe-separated annotation file.

    NOTE: this class shadows ``torch.utils.data.Dataset`` imported above; after
    this cell runs, the name ``Dataset`` refers to this subclass.

    The annotation file is read from ``os.path.join(root, ann_file)`` with
    ``sep='|'``. Column 0 of each row is the image file name (resolved under
    ``<root>/images/``) and column 2 is the caption text.

    Rows are assumed to be grouped so that each image occupies
    ``captions_per_image`` consecutive rows (default 5 — TODO confirm against
    the annotation file). ``__getitem__(idx)`` returns the image referenced by
    row ``idx`` together with every caption in that row's group.

    Args:
        root: Dataset root directory.
        ann_file: Annotation file name, relative to ``root``.
        transform: Optional callable applied to the loaded RGB PIL image.
        captions_per_image: Number of consecutive rows forming one caption group.

    Raises:
        ValueError: If ``captions_per_image`` is smaller than 1.
    """

    def __init__(self, root, ann_file, transform=None, captions_per_image=5):
        if captions_per_image < 1:
            raise ValueError("captions_per_image must be a positive integer")
        self.root = root
        ann_file_dir = os.path.join(self.root, ann_file)
        self.labels = pd.read_csv(ann_file_dir, sep='|')
        self.transform = transform
        self.captions_per_image = captions_per_image

    def __len__(self):
        """Return the number of annotation rows (one sample per row)."""
        return len(self.labels)

    def _captions(self, idx):
        """Return every caption in the contiguous group containing row ``idx``.

        ``idx`` must already be a non-negative, in-range row index.
        """
        group_size = self.captions_per_image
        start_id = idx - (idx % group_size)
        # Clamp so a trailing, incomplete group does not raise IndexError.
        stop_id = min(start_id + group_size, len(self.labels))
        return self.labels.iloc[start_id:stop_id, 2].tolist()

    def __getitem__(self, idx):
        """Return ``(image, captions)`` for annotation row ``idx``.

        Negative indices count from the end, as with a Python sequence.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_rows = len(self.labels)
        if idx < 0:
            idx += n_rows
        if not 0 <= idx < n_rows:
            raise IndexError(f"index out of range for dataset of size {n_rows}")

        image_path = os.path.join(self.root, "images", self.labels.iloc[idx, 0])
        # The context manager releases the file handle; convert() forces the pixel
        # data to load first and yields 3 channels, so grayscale/RGBA/palette
        # images match a 3-channel Normalize.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        # All captions belonging to this row's image group.
        labels = self._captions(idx)

        if self.transform:
            image = self.transform(image)

        return image, labels

In [118]:
# Resize every image to a fixed 224x224 input, convert it to a tensor, then
# standardize each RGB channel with the statistics below.
# NOTE(review): these per-channel mean/std values differ from the commonly used
# ImageNet statistics; presumably they were computed on this dataset — TODO confirm.
IMAGE_SIZE = (224, 224)
CHANNEL_MEAN = (0.444, 0.421, 0.385)
CHANNEL_STD = (0.285, 0.277, 0.286)

preprocess = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(CHANNEL_MEAN, CHANNEL_STD),
])

# Annotations come from data/results.csv; image files are resolved under data/images/.
data = Dataset(root="data", ann_file="results.csv", transform=preprocess)

In [119]:
# Show the n-th sample: its preprocessed image tensor and every caption in its group.
n = 1000
# Fetch the sample once: each `data[n]` access re-reads the image from disk
# and re-runs the transform.
image, captions = data[n]
print(image)
for caption in captions:
    print(caption)

tensor([[[0.8088, 0.8638, 0.8776,  ..., 0.8638, 0.8226, 0.7675],
         [0.9189, 0.9189, 0.8638,  ..., 1.0427, 0.9877, 0.9602],
         [0.9602, 0.9051, 0.8914,  ..., 1.1115, 1.0565, 1.0290],
         ...,
         [1.0702, 1.0702, 1.0427,  ..., 1.0565, 0.9739, 0.8776],
         [0.9464, 1.0152, 1.0427,  ..., 1.1253, 1.0290, 0.9602],
         [0.8501, 0.8638, 0.9051,  ..., 1.0152, 1.0565, 1.0152]],

        [[0.9577, 1.0285, 1.0426,  ..., 0.9718, 0.9152, 0.8727],
         [1.0285, 1.0426, 1.0001,  ..., 1.1417, 1.0709, 1.0426],
         [1.0851, 1.0426, 1.0285,  ..., 1.1983, 1.1134, 1.0709],
         ...,
         [1.1700, 1.1700, 1.1417,  ..., 1.1134, 1.0001, 0.9010],
         [1.0285, 1.0992, 1.1417,  ..., 1.1842, 1.1276, 1.0568],
         [0.9152, 0.9152, 0.9860,  ..., 1.0992, 1.1417, 1.0851]],

        [[0.7517, 0.8203, 0.8340,  ..., 0.7380, 0.7106, 0.6832],
         [0.8477, 0.8477, 0.8066,  ..., 0.9300, 0.8614, 0.8340],
         [0.8889, 0.8340, 0.8340,  ..., 0.9986, 0.9437, 0.