# In [None]:
import torch
from torch.utils.data import Dataset
from torchvision.io import read_image
from torchvision.transforms import ToTensor, Lambda, Normalize, CenterCrop, Resize
from torchvision.io.image import ImageReadMode

import cv2
import numpy as np



class CustomImageDataset(Dataset):
    """Image-classification dataset that loads image files on demand.

    Each sample is read as a 3-channel RGB tensor, resized, scaled to [0, 1],
    normalized with the standard ImageNet mean/std, and paired with a float
    one-hot label vector.
    """

    def __init__(self, files, labels, num_classes=3, size=(400, 800)):
        """Store sample paths/labels and build the per-sample transforms.

        Args:
            files: sequence of image file paths, one per sample.
            labels: sequence of integer class indices, in the same order as
                ``files``.
            num_classes: length of the one-hot label vector (default 3).
            size: ``(height, width)`` every image is resized to
                (default ``(400, 800)``).

        Raises:
            ValueError: if ``files`` and ``labels`` differ in length.
        """
        if len(files) != len(labels):
            raise ValueError(
                f"files and labels must have the same length "
                f"({len(files)} != {len(labels)})"
            )
        self.files = files
        self.labels = labels
        self.num_classes = num_classes
        # Built once here instead of once per __getitem__ call.
        self.resize = Resize(size)
        # Standard ImageNet channel statistics; they expect inputs in [0, 1].
        self.transform = Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        # A bound method (not a lambda) keeps the dataset picklable, which
        # DataLoader requires for num_workers > 0 under the "spawn" start method.
        self.target_transform = self._one_hot

    def _one_hot(self, y):
        """Return a float one-hot vector of length ``self.num_classes`` for class ``y``.

        Raises RuntimeError (from ``one_hot``) if ``y`` is negative or
        ``>= self.num_classes``.
        """
        return torch.nn.functional.one_hot(
            torch.as_tensor(y), num_classes=self.num_classes
        ).to(torch.float)

    def __len__(self):
        """Return the total number of samples (one per file path)."""
        return len(self.files)

    def __getitem__(self, index):
        """Load and preprocess sample ``index``; return ``(image, one_hot_label)``."""
        # Force 3 channels so grayscale/RGBA files match the 3-channel statistics.
        image = read_image(self.files[index], mode=ImageReadMode.RGB)
        image = self.resize(image)
        # read_image yields uint8 in [0, 255]; Normalize's ImageNet stats assume [0, 1].
        image = self.transform(image.to(torch.float) / 255.0)
        label = self.target_transform(self.labels[index])
        return image, label