In [None]:
from __future__ import print_function
import pandas as pd
import numpy as np
import csv
import os
import torch

import matplotlib
import matplotlib.pyplot as plt
import random

from PIL import Image
from skimage import io
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader

In [None]:
class PETS_dataset(Dataset):
    """Image-classification dataset built from a flat folder of image files.

    Every regular file in ``dat_folder`` is assigned to a class named after
    the part of its filename before the last underscore (``great_dane_12.jpg``
    belongs to class ``great_dane``). Files without an underscore all fall
    into the class with the empty name ``""``.

    On first use a per-class random 70/30 train/test split is written to
    ``data_Train.csv`` and ``data_Test.csv`` in the directory that contains
    ``dat_folder``. Later instantiations reuse those files, so the split stays
    fixed until the CSV files are deleted.

    Args:
        dat_folder: Path of the folder holding the image files.
        train: If true, serve the training split; otherwise the test split.
        transform: Optional callable applied to the PIL image of each sample.
    """

    # Fraction of every class's files that is assigned to the training split.
    TRAIN_FRACTION = 0.7
    # Column layout shared by both split files; __getitem__ relies on the
    # positions of "class" (2) and "location" (4).
    CSV_HEADER = ["counter", "set", "class", "label", "location"]

    def __init__(self, dat_folder, train=True, transform=None):
        self.root_dir = os.path.dirname(dat_folder)
        self.dataset_dir = dat_folder
        # Sorted so that class ids do not depend on the arbitrary order in
        # which the operating system lists the directory.
        filenames = sorted(
            f for f in os.listdir(dat_folder)
            if os.path.isfile(os.path.join(self.dataset_dir, f))
        )

        self.classes_indexes(filenames)

        csv_file_train = os.path.join(self.root_dir, "data_Train.csv")
        csv_file_test = os.path.join(self.root_dir, "data_Test.csv")

        # Test the joined paths directly: root_dir is "" when dat_folder is a
        # bare relative name, and os.listdir("") would raise.
        if not (os.path.isfile(csv_file_train) and os.path.isfile(csv_file_test)):
            train_rows, test_rows = self._split_rows(filenames)
            self._write_split(csv_file_train, train_rows)
            self._write_split(csv_file_test, test_rows)

        self.PETS_datafile = pd.read_csv(csv_file_train if train else csv_file_test)
        self.transform = transform

    def _split_rows(self, filenames):
        """Shuffle each class and return ``(train_rows, test_rows)`` for the CSVs."""
        train_rows, test_rows = [], []
        for class_id, label in enumerate(self.classes):
            members = self.indexes[class_id]
            random.shuffle(members)  # in place, so self.indexes is reordered too
            n_train = len(members) * self.TRAIN_FRACTION
            for counter, file_idx in enumerate(members):
                location = os.path.join(self.dataset_dir, filenames[file_idx])
                if counter < n_train:
                    train_rows.append([counter, 1, class_id, label, location])
                else:
                    test_rows.append([counter, 2, class_id, label, location])
        return train_rows, test_rows

    def _write_split(self, csv_path, rows):
        """Write the header followed by ``rows`` to ``csv_path``."""
        # newline="" is required by the csv module; without it Windows would
        # emit an extra carriage return on every line.
        with open(csv_path, "w", newline="") as f:
            writer = csv.writer(f, delimiter=',')
            writer.writerow(self.CSV_HEADER)
            writer.writerows(rows)

    def classes_indexes(self, filenames):
        """Group ``filenames`` by class name.

        Sets ``self.classes`` to the class names in order of first appearance
        and ``self.indexes`` to a parallel list whose i-th entry holds the
        positions (within ``filenames``) of the files of class i.
        """
        self.classes = []
        self.indexes = []
        slot_of = {}  # class name -> position in self.classes / self.indexes

        for i, filename in enumerate(filenames):
            # Everything before the last underscore; "" when there is none.
            name_class = filename.rpartition("_")[0]
            slot = slot_of.get(name_class)
            if slot is None:
                slot_of[name_class] = len(self.classes)
                self.classes.append(name_class)
                self.indexes.append([i])
            else:
                self.indexes[slot].append(i)

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.PETS_datafile)

    def __getitem__(self, idx):
        """Return ``(sample, target)`` for row ``idx`` of the selected split.

        ``target`` is the value of the "class" column. With a transform the
        image is converted to an RGB PIL image and passed through it; without
        one the array produced by the loader is returned unchanged.
        """
        img_name = self.PETS_datafile.iloc[idx, 4]
        image = self.__loadfile(img_name)
        target = self.PETS_datafile.iloc[idx, 2]
        if self.transform:
            image = Image.fromarray(image)
            # Normalise every non-RGB mode (RGBA, LA, ...) so the transform
            # always receives three channels.
            if image.mode != 'RGB':
                image = image.convert('RGB')
            sample = self.transform(image)
        else:
            sample = image
        return (sample, target)

    def __loadfile(self, data_file):
        """Read ``data_file`` as an array, replicating 2-D images to three channels."""
        image = io.imread(data_file)
        if image.ndim < 3:
            image = np.stack((image,) * 3, axis=-1)
        return image
    

In [None]:
# Example usage (point the path at your local copy of the dataset folder):
# a = PETS_dataset("/home/rita/JupyterProjects/EYE-SEA/DataSets/Pets/Pet_Datasets")
