In [43]:
import torch
import torchvision
from torch.utils.data import Dataset
import torchvision.transforms.functional as fn
import torchvision.transforms as T
import matplotlib.pyplot as plt
from utilities import createAnnotation
import pandas as pd
from IPython.display import display
from PIL import Image 
import random
import numpy as np

In [44]:
# Root folder of the image collection; passed to ImageDataset as rootDir below.
IMAGESROOTDIR = 'NINCO_OOD_classes'

In [45]:
class ImageDataset(Dataset):
    """Dataset of (PIL image, label) pairs driven by an annotation table.

    Constructing the dataset runs createAnnotation on rootDir and then loads
    'output.csv' as the annotation table. Column 0 of the table is used as
    the image path and column 1 as the label.
    """

    def __init__(self, rootDir):
        self.rootDir = rootDir
        # presumably createAnnotation is what writes 'output.csv' — confirm in utilities
        createAnnotation(self.rootDir)
        self.annotation = pd.read_csv('output.csv')

    def __len__(self):
        # One sample per row of the annotation table.
        return len(self.annotation.index)

    def __getitem__(self, index):
        table = self.annotation
        # Path first, then the (lazily opened) image, then the label.
        picture = Image.open(table.iloc[index, 0])
        return picture, table.iloc[index, 1]

In [46]:
# Dataset instance built over IMAGESROOTDIR; indexing it yields (PIL image, label).
# Per the original author's note it contains all 765 images with their respective labels.
imageDataset = ImageDataset(rootDir=IMAGESROOTDIR)

In [47]:
# Class which is used to rescale the image to a given width
# input: outputSize: int
# return: tuple(PIL Image, label)
class Rescale:
    """Resize an image so its width equals outputSize, keeping the aspect ratio.

    Instances are callables that take a (image, label) tuple and return
    a (resized image, label) tuple; the label is passed through untouched.

    NOTE(review): only the width is matched to outputSize, so a landscape
    input ends up with a height smaller than outputSize. torchvision's
    CenterCrop zero-pads images that are smaller than the crop size —
    confirm that letterboxing is intended for such images.
    """

    def __init__(self, outputSize):
        # Target width in pixels.
        self.outputSize = outputSize

    def __calculateNewSize(self, size):
        """Return the (width, height) to resize to for an image of `size`.

        size: (width, height) of the source image.
        The height is clamped to at least 1 pixel: for extremely wide images
        the proportional height rounds to 0, which is not a valid image size.
        """
        initialWidth, initialHeight = size

        # Factor by which the width has to shrink (or grow) to hit outputSize.
        ratio = initialWidth / self.outputSize
        newHeight = initialHeight / ratio

        return (round(self.outputSize), max(1, round(newHeight)))

    # sample data is a tuple(image, label)
    def __call__(self, sampleData):
        image, label = sampleData

        newWidth, newHeight = self.__calculateNewSize(image.size)

        # PIL reports size as (width, height) while resize expects [height, width].
        transformedImage = fn.resize(image, [newHeight, newWidth])

        return transformedImage, label

In [48]:
# Center-crops images that are not already square at the requested size
# input: outputSize: int
# return: tuple(PIL Image, label)
class CenterCrop:
    """Callable turning a (image, label) tuple into a (square image, label) tuple."""

    def __init__(self, outputSize):
        self.outputSize = outputSize

    def __call__(self, sampleData):
        picture, tag = sampleData
        w, h = picture.size

        # Already a square of exactly the requested side: hand it back untouched.
        if w == h and w == self.outputSize:
            return picture, tag

        cropper = torchvision.transforms.CenterCrop(self.outputSize)
        return cropper(picture), tag

In [49]:
# Constants for the size of the images (in pixels)
# RESCALE is the target width given to Rescale, CROP the size given to CenterCrop.
RESCALE = 240
CROP = 240
# objects for resizing
rescale = Rescale(RESCALE)
crop = CenterCrop(CROP)
# Pipeline applied to a (image, label) tuple: first rescale, then crop.
composed = T.Compose([rescale, crop])

# Given an index, returns the transformed image
# input: index: int
# return: tuple(PIL Image, label)
def transform(index):
    """Fetch sample `index` from imageDataset and run it through `composed`.

    index: position of the sample in imageDataset.
    Returns the (image, label) tuple produced by the composed transforms.
    Raises AssertionError when index is past the end of the dataset.
    """
    # The last valid position is len - 1; the former `<=` check let
    # index == len through, which then failed inside the dataset lookup.
    assert index < len(imageDataset), "index is past the end of imageDataset"
    return composed(imageDataset[index])

In [50]:
# objects for tensor transformation
# pilToTensor is torchvision's ToTensor transform (used on the PIL images below);
# tensorToPil is torchvision's ToPILImage transform for the opposite direction.
pilToTensor = T.ToTensor()
tensorToPil = T.ToPILImage()

In [61]:
# Class which is used to get the resized images with label
# input: datasetLength: int
# output: {'image': Tensor, 'label': String}
class DataLoader(Dataset):
    """Dataset view returning transformed samples as dictionaries.

    datasetLength: number of addressable samples; valid indices are
    0 .. datasetLength - 1.
    Indexing returns {'image': <result of pilToTensor>, 'label': <label>}.
    """

    def __init__(self, datasetLength):
        self.datasetLength = datasetLength

    def __len__(self):
        # Lets len() and length-aware consumers know how many samples exist.
        return self.datasetLength

    def __getitem__(self, index):
        # Zero-based bounds: index 0 is the first sample and datasetLength
        # itself is already out of range (the old check had both ends wrong).
        assert 0 <= index < self.datasetLength, "index out of range"
        # Remember the most recently requested index (kept for compatibility).
        self.index = index
        picture, label = transform(index)
        return {'image': pilToTensor(picture), 'label': label}

In [63]:
# Amount of random samples
BATCHSIZE = 4

dataloader = DataLoader(len(imageDataset))
for i in range(BATCHSIZE):
    # randrange excludes the upper bound, so every draw is a valid position;
    # randint(0, n) is inclusive and could return n, one past the last sample.
    index = random.randrange(len(imageDataset))
    sample = dataloader[index]
  
    
    