In [1]:
# -------------------------------------------------------- Imports --------------------------------------------------------------
import os
import pandas as pd
import numpy as np
import torch 
import pickle
from PIL import Image
from torchvision.transforms import Resize,Compose, ToTensor
from torchvision.io import read_image
from torch.utils.data import DataLoader, Dataset,DataLoader,random_split

In [2]:
# -------------------------------------------------------- Global configs --------------------------------------------------------------
# Side length passed (twice) to the Resize transform, i.e. images become square.
RESHAPE_SIZE = 512
# Batch size handed to every DataLoader built in this notebook.
BATCH_SIZE = 32

# All dataset folders share one parent directory; only the leaf name differs.
_TEST_DATA_ROOT = '/notebooks/cnnimageretrieval-pytorch/data/test'
RPARIS6K_DATASET_FOLDERPATH = f'{_TEST_DATA_ROOT}/rparis6k'
ROXFORD5K_DATASET_FOLDERPATH = f'{_TEST_DATA_ROOT}/roxford5k'
PASCALVOC_DATASET_FOLDERPATH = f'{_TEST_DATA_ROOT}/pascalvoc'
CALTECH_DATASET_FOLDERPATH = f'{_TEST_DATA_ROOT}/caltech101'

In [3]:
# -------------------------------------------------------- Custom Dataset --------------------------------------------------------------

class ParisDataset(Dataset):
    """Image-only dataset driven by the ``gnd_rparis6k.pkl`` ground-truth file.

    The pickle located directly inside ``image_dir`` is read once at
    construction time. Its ``'imlist'`` entry supplies the image names, and each
    name is resolved to ``<image_dir>/jpg/<name>.jpg``.

    Args:
        image_dir: Folder containing ``gnd_rparis6k.pkl`` and a ``jpg`` sub-folder.
        transform: Optional callable applied to every opened PIL image.

    Raises:
        OSError: If the ground-truth pickle cannot be opened.
        KeyError: If the unpickled object has no ``'imlist'`` entry.
    """

    def __init__(self, image_dir, transform=None):
        gnd_path = os.path.join(image_dir, 'gnd_rparis6k.pkl')

        # The context manager guarantees the file is closed, even when
        # unpickling raises.
        # SECURITY: pickle.load can execute arbitrary code; only use trusted files.
        with open(gnd_path, 'rb') as handle:
            gnd = pickle.load(handle)

        self.image_paths = [
            os.path.join(image_dir, 'jpg', name + '.jpg') for name in gnd['imlist']
        ]
        self.transform = transform

    def __len__(self):
        """Return the number of images listed in the ground-truth file."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Open image ``idx`` and return it, transformed when a transform is set."""
        # The image is opened in whatever mode it is stored in (no RGB conversion).
        image = Image.open(self.image_paths[idx])

        if self.transform:
            image = self.transform(image)
        return image

class OxfordDataset(Dataset):
    """Image-only dataset driven by the ``gnd_roxford5k.pkl`` ground-truth file.

    The pickle located directly inside ``image_dir`` is read once at
    construction time. Its ``'imlist'`` entry supplies the image names, and each
    name is resolved to ``<image_dir>/jpg/<name>.jpg``.

    Args:
        image_dir: Folder containing ``gnd_roxford5k.pkl`` and a ``jpg`` sub-folder.
        transform: Optional callable applied to every opened PIL image.

    Raises:
        OSError: If the ground-truth pickle cannot be opened.
        KeyError: If the unpickled object has no ``'imlist'`` entry.
    """

    def __init__(self, image_dir, transform=None):
        gnd_path = os.path.join(image_dir, 'gnd_roxford5k.pkl')

        # The context manager guarantees the file is closed, even when
        # unpickling raises.
        # SECURITY: pickle.load can execute arbitrary code; only use trusted files.
        with open(gnd_path, 'rb') as handle:
            gnd = pickle.load(handle)

        # Paths come solely from the ground-truth list; the folder itself is
        # never scanned.
        self.image_paths = [
            os.path.join(image_dir, 'jpg', name + '.jpg') for name in gnd['imlist']
        ]
        self.transform = transform

    def __len__(self):
        """Return the number of images listed in the ground-truth file."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Open image ``idx`` and return it, transformed when a transform is set."""
        # The image is opened in whatever mode it is stored in (no RGB conversion).
        image = Image.open(self.image_paths[idx])

        if self.transform:
            image = self.transform(image)
        return image
    
class PascalVOCEasyDataset(Dataset):
    """Image-only dataset driven by the ``gnd_pascalvoc_700.pkl`` ground-truth file.

    The pickle located directly inside ``image_dir`` is read once at
    construction time. Its ``'imlist'`` entry supplies the image names, and each
    name is resolved to ``<image_dir>/jpg/<name>.jpg``.

    Args:
        image_dir: Folder containing ``gnd_pascalvoc_700.pkl`` and a ``jpg``
            sub-folder.
        transform: Optional callable applied to every opened PIL image.

    Raises:
        OSError: If the ground-truth pickle cannot be opened.
        KeyError: If the unpickled object has no ``'imlist'`` entry.
    """

    def __init__(self, image_dir, transform=None):
        gnd_path = os.path.join(image_dir, 'gnd_pascalvoc_700.pkl')

        # The context manager guarantees the file is closed, even when
        # unpickling raises.
        # SECURITY: pickle.load can execute arbitrary code; only use trusted files.
        with open(gnd_path, 'rb') as handle:
            gnd = pickle.load(handle)

        self.image_paths = [
            os.path.join(image_dir, 'jpg', name + '.jpg') for name in gnd['imlist']
        ]
        self.transform = transform

    def __len__(self):
        """Return the number of images listed in the ground-truth file."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Open image ``idx`` and return it, transformed when a transform is set."""
        # The image is opened in whatever mode it is stored in (no RGB conversion).
        image = Image.open(self.image_paths[idx])

        if self.transform:
            image = self.transform(image)
        return image
    
    
class PascalVOCMediumDataset(Dataset):
    """Image-only dataset driven by ``gnd_pascalvoc_700_medium.pkl``.

    The pickle located directly inside ``image_dir`` is read once at
    construction time. Its ``'imlist'`` entry supplies the image names, and each
    name is resolved to ``<image_dir>/jpg/<name>.jpg``.

    Args:
        image_dir: Folder containing ``gnd_pascalvoc_700_medium.pkl`` and a
            ``jpg`` sub-folder.
        transform: Optional callable applied to every opened PIL image.

    Raises:
        OSError: If the ground-truth pickle cannot be opened.
        KeyError: If the unpickled object has no ``'imlist'`` entry.
    """

    def __init__(self, image_dir, transform=None):
        gnd_path = os.path.join(image_dir, 'gnd_pascalvoc_700_medium.pkl')

        # The context manager guarantees the file is closed, even when
        # unpickling raises.
        # SECURITY: pickle.load can execute arbitrary code; only use trusted files.
        with open(gnd_path, 'rb') as handle:
            gnd = pickle.load(handle)

        self.image_paths = [
            os.path.join(image_dir, 'jpg', name + '.jpg') for name in gnd['imlist']
        ]
        self.transform = transform

    def __len__(self):
        """Return the number of images listed in the ground-truth file."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Open image ``idx`` and return it, transformed when a transform is set."""
        # The image is opened in whatever mode it is stored in (no RGB conversion).
        image = Image.open(self.image_paths[idx])

        if self.transform:
            image = self.transform(image)
        return image
    
    
class PascalVOCHardDataset(Dataset):
    """Image-only dataset driven by ``gnd_pascalvoc_700_no_bbx.pkl``.

    The pickle located directly inside ``image_dir`` is read once at
    construction time. Its ``'imlist'`` entry supplies the image names, and each
    name is resolved to ``<image_dir>/jpg/<name>.jpg``.

    Args:
        image_dir: Folder containing ``gnd_pascalvoc_700_no_bbx.pkl`` and a
            ``jpg`` sub-folder.
        transform: Optional callable applied to every opened PIL image.

    Raises:
        OSError: If the ground-truth pickle cannot be opened.
        KeyError: If the unpickled object has no ``'imlist'`` entry.
    """

    def __init__(self, image_dir, transform=None):
        gnd_path = os.path.join(image_dir, 'gnd_pascalvoc_700_no_bbx.pkl')

        # The context manager guarantees the file is closed, even when
        # unpickling raises.
        # SECURITY: pickle.load can execute arbitrary code; only use trusted files.
        with open(gnd_path, 'rb') as handle:
            gnd = pickle.load(handle)

        self.image_paths = [
            os.path.join(image_dir, 'jpg', name + '.jpg') for name in gnd['imlist']
        ]
        self.transform = transform

    def __len__(self):
        """Return the number of images listed in the ground-truth file."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Open image ``idx`` and return it, transformed when a transform is set."""
        # The image is opened in whatever mode it is stored in (no RGB conversion).
        image = Image.open(self.image_paths[idx])

        if self.transform:
            image = self.transform(image)
        return image
    
    
    
class Caltech101Dataset(Dataset):
    """Image-only dataset driven by the ``gnd_caltech101_700.pkl`` ground-truth file.

    The pickle located directly inside ``image_dir`` is read once at
    construction time. Its ``'imlist'`` entry supplies the image names, and each
    name is resolved to ``<image_dir>/jpg/<name>.jpg``.

    Args:
        image_dir: Folder containing ``gnd_caltech101_700.pkl`` and a ``jpg``
            sub-folder.
        transform: Optional callable applied to every opened PIL image.

    Raises:
        OSError: If the ground-truth pickle cannot be opened.
        KeyError: If the unpickled object has no ``'imlist'`` entry.
    """

    def __init__(self, image_dir, transform=None):
        gnd_path = os.path.join(image_dir, 'gnd_caltech101_700.pkl')

        # The context manager guarantees the file is closed, even when
        # unpickling raises.
        # SECURITY: pickle.load can execute arbitrary code; only use trusted files.
        with open(gnd_path, 'rb') as handle:
            gnd = pickle.load(handle)

        self.image_paths = [
            os.path.join(image_dir, 'jpg', name + '.jpg') for name in gnd['imlist']
        ]
        self.transform = transform

    def __len__(self):
        """Return the number of images listed in the ground-truth file."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Open image ``idx`` and return it, transformed when a transform is set."""
        # The image is opened in whatever mode it is stored in (no RGB conversion).
        image = Image.open(self.image_paths[idx])

        if self.transform:
            image = self.transform(image)
        return image
    

In [4]:
# -------------------------------------------------------- Custom Transforms --------------------------------------------------------------

class ExpandDimension:
    """Transform that repeats a sample three times along its first axis
    when that axis has length 1; any other sample is passed through as-is."""

    def __call__(self, sample):
        # Nothing to do unless the leading dimension is exactly 1.
        if sample.shape[0] != 1:
            return sample
        return sample.repeat(3, 1, 1)

In [5]:
# -------------------------------------------------------- Data Loaders  --------------------------------------------------------------

# ---- Dataset Transforms ----
# Pipeline order: PIL image -> tensor -> square resize -> ExpandDimension.
content_transform = Compose([
    ToTensor(),
    Resize((RESHAPE_SIZE, RESHAPE_SIZE)),
    ExpandDimension(),
])

# ---- Dataset Reads  ----
dataset = OxfordDataset(
    image_dir=ROXFORD5K_DATASET_FOLDERPATH,
    transform=content_transform,
)

# ---- Dataset Split  ----
TRAIN_PERCENT = 0.9
VALIDATION_PERCENT = 0.05
TEST_PERCENT = 0.05

# Train and validation sizes are truncated towards zero; the test split takes
# whatever remains, so the three sizes always add up to len(dataset).
train_size = int(TRAIN_PERCENT * len(dataset))
validation_size = int(VALIDATION_PERCENT * len(dataset))
test_size = len(dataset) - train_size - validation_size

# A freshly seeded generator keeps the split identical from run to run.
train_dataset, validation_dataset, test_dataset = random_split(
    dataset,
    lengths=[train_size, validation_size, test_size],
    generator=torch.Generator().manual_seed(420),
)

In [25]:
#  ---- Dataloaders ----
# Every split gets identical loader settings: BATCH_SIZE samples, no shuffling.
train_dataloader, validation_dataloader, test_dataloader = (
    DataLoader(split, batch_size=BATCH_SIZE, shuffle=False)
    for split in (train_dataset, validation_dataset, test_dataset)
)