In [1]:
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms, utils

import pandas as pd
import numpy as np
import os
import torch
import glob

In [2]:
class SonDataset(Dataset):
    
    def __init__(self, updated_csv_path, img_path, transform = None):
        self.img_path = img_path
        self.votes_df = pd.read_csv(updated_csv_path)
        self.transform = transform
    
    def __getitem__(self, idx):
        self.img_score = self.votes_df.loc[idx, 'Average']
        self.img_name = self.votes_df.loc[idx, 'ID'].astype(str)
        self.img_file = []
        
        for directory, _ , _ in os.walk(self.img_path):
            self.img_file.extend(glob.glob(os.path.join(directory, self.img_name + '.jpg')))
            
        im = Image.open(self.img_file[0])
        img_as_img = im.convert('RGB')
        
        if self.transform:
            img_as_img = self.transform(img_as_img)
            
        sample = {'image' : img_as_img,
                  'image_name' : self.img_name,
                  'image_path' : self.img_file[0],
                  'image_score' : self.img_score}
        return sample
    
    def __len__(self):
        return len(votes_df)
        
        
    

In [3]:
#test the dataset
son_dataset = SonDataset('../data/updated_votes.csv', 
                         '/raid/data/datasets/SoN/images', 
                         transform = transforms.Compose([transforms.CenterCrop(224),
                                                         transforms.ToTensor(),
                                                         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                              std=[0.229, 0.224, 0.225])
                                                        ]))

In [5]:
son_dataset[0]

{'image': tensor([[[ 2.0948,  2.0948,  2.0948,  ...,  2.1462,  2.1290,  2.1119],
          [ 2.0777,  2.0948,  2.0948,  ...,  2.1290,  2.1119,  2.1119],
          [ 2.0948,  2.0948,  2.1119,  ...,  2.1119,  2.1119,  2.0948],
          ...,
          [-1.2617, -0.9877, -0.9877,  ..., -0.9877, -0.8507, -0.8678],
          [-0.9534, -0.7650, -0.9705,  ..., -0.6965, -0.6109, -0.6965],
          [-0.9705, -1.0562, -1.0904,  ..., -0.5424, -0.5938, -0.8164]],
 
         [[ 2.1835,  2.1835,  2.1835,  ...,  2.2360,  2.2185,  2.2010],
          [ 2.1660,  2.1835,  2.1835,  ...,  2.2185,  2.2010,  2.2010],
          [ 2.1835,  2.1835,  2.2010,  ...,  2.2010,  2.2010,  2.1835],
          ...,
          [-1.3529, -1.0728, -1.0378,  ..., -1.0553, -0.9503, -0.9678],
          [-1.0028, -0.8102, -0.9678,  ..., -0.8978, -0.7927, -0.8803],
          [-1.0203, -1.0553, -1.0903,  ..., -0.8277, -0.8803, -1.1078]],
 
         [[ 2.3437,  2.3437,  2.3437,  ...,  2.3960,  2.3786,  2.3611],
          [ 2.3263,