## Load Data

In [0]:
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data as data
import torchvision as tv
import torchvision.models as models
from PIL import Image
import glob
import os


class DatasetManager:
    """Load an image dataset, keep a random fraction of it, and build DataLoaders.

    Typical use::

        dat = DatasetManager('cifar10', percent_data=10.0, percent_val=20.0)
        dat.ImportDataset(batch_size=5)
        for images, labels in dat.train_loader:
            ...

    After ``ImportDataset`` the instance exposes ``trainset``, ``valset``,
    ``testset`` and the matching ``train_loader``, ``val_loader``,
    ``test_loader``. ``testset`` / ``test_loader`` are ``None`` when no
    held-out test data is available (see ``ImportDataset``).

    NOTE(review): the same augmenting transform (random crop + random flip) is
    applied to the training, validation and test data, so evaluation results
    are not deterministic. Consider a separate deterministic eval transform.
    """

    SUPPORTED_DATASETS = ('hymenoptera', 'cifar10', 'cifar100')

    def __init__(self, dataset = 'cifar10', percent_data = 10.0, percent_val = 20.0, data_path = './data', seed = None):
        """Store the configuration and build the preprocessing transform.

        Args:
            dataset: one of 'hymenoptera', 'cifar10' or 'cifar100'.
            percent_data: percentage (0-100) of the full training set to keep.
                The same percentage is also applied to the test set.
            percent_val: percentage (0-100) of the *kept* training data that is
                set aside as validation data.
            data_path: root directory of the dataset (download target for CIFAR).
            seed: optional integer; when given, the random splits are
                reproducible. ``None`` keeps torch's global RNG behaviour.

        Raises:
            ValueError: if ``dataset`` is not supported or a percentage lies
                outside the range 0-100.
        """
        # Fail early with a clear message instead of an AttributeError later on.
        if dataset not in self.SUPPORTED_DATASETS:
            raise ValueError(
                "Unsupported dataset {!r}; expected one of {}".format(
                    dataset, self.SUPPORTED_DATASETS))

        for label, value in (('percent_data', percent_data), ('percent_val', percent_val)):
            if not 0.0 <= value <= 100.0:
                raise ValueError(
                    "{} must be between 0 and 100, got {!r}".format(label, value))

        self.dataset = dataset
        self.data_path = data_path
        self.percent_data = percent_data
        self.percent_val = percent_val
        self.seed = seed

        # Only the normalisation statistics differ between the datasets.
        if self.dataset == 'hymenoptera':
            norm_mean, norm_std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
        else:
            norm_mean, norm_std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

        self.transform = tv.transforms.Compose([
            tv.transforms.RandomResizedCrop(224),
            tv.transforms.RandomHorizontalFlip(),
            tv.transforms.ToTensor(),
            tv.transforms.Normalize(norm_mean, norm_std)])

    def ImportDataset(self, batch_size=5):
        """Load the configured dataset, sub-sample it and create the loaders.

        For 'hymenoptera', ``data_path`` may either contain ``train/`` and
        ``val/`` sub-directories (each holding one folder per class), in which
        case ``val/`` is used as the test set, or contain the class folders
        directly, in which case no test set is available and ``testset`` /
        ``test_loader`` are set to ``None``.

        Args:
            batch_size: batch size used by all generated DataLoaders.

        Raises:
            ValueError: if ``self.dataset`` is not a supported dataset name.
        """
        self.batch_size = batch_size

        if self.dataset == 'hymenoptera':
            train_root = os.path.join(self.data_path, 'train')
            test_root = os.path.join(self.data_path, 'val')

            if os.path.isdir(train_root) and os.path.isdir(test_root):
                # Pointing ImageFolder at the parent directory would treat
                # 'train' and 'val' themselves as the two classes.
                self.trainset = tv.datasets.ImageFolder(root=train_root,
                                                        transform=self.transform)
                self.testset = tv.datasets.ImageFolder(root=test_root,
                                                       transform=self.transform)
            else:
                self.trainset = tv.datasets.ImageFolder(root=self.data_path,
                                                        transform=self.transform)
                self.testset = None

        elif self.dataset in ('cifar10', 'cifar100'):
            if self.dataset == 'cifar10':
                dataset_cls = tv.datasets.CIFAR10
            else:
                dataset_cls = tv.datasets.CIFAR100

            self.trainset = dataset_cls(root=self.data_path, train=True,
                                        download=True, transform=self.transform)
            self.testset = dataset_cls(root=self.data_path, train=False,
                                       download=True, transform=self.transform)

        else:
            # Only reachable if self.dataset was changed after construction.
            raise ValueError(
                "Unsupported dataset {!r}; expected one of {}".format(
                    self.dataset, self.SUPPORTED_DATASETS))

        self.SplitData()
        self.GenerateLoaders()

    def SplitData(self):
        """Randomly reduce the loaded data to the configured fractions.

        Replaces ``self.trainset`` with the active training subset, creates
        ``self.valset`` and replaces ``self.testset`` with its active subset
        (or leaves it ``None``), then prints the resulting sizes. Because the
        attributes are overwritten, call this once per freshly loaded dataset.
        """
        split_kwargs = {}
        if self.seed is not None:
            # One generator shared by all splits keeps the whole procedure
            # reproducible for a given seed.
            split_kwargs['generator'] = torch.Generator().manual_seed(self.seed)

        len_full = len(self.trainset)
        len_active = int(np.round(len_full*self.percent_data/100.0))

        # First keep `percent_data` percent of the full training set ...
        _, active = torch.utils.data.random_split(
            self.trainset, (len_full-len_active, len_active), **split_kwargs)

        # ... then carve the validation share out of what was kept.
        len_val = int(np.round(len_active*self.percent_val/100.0))
        len_train = len_active - len_val

        self.valset, self.trainset = torch.utils.data.random_split(
            active, (len_val, len_train), **split_kwargs)

        if self.testset is None:
            len_full_test = 0
            len_test = 0
        else:
            len_full_test = len(self.testset)
            len_test = int(np.round(len_full_test*self.percent_data/100.0))

            _, self.testset = torch.utils.data.random_split(
                self.testset, (len_full_test-len_test, len_test), **split_kwargs)

        print('\nFull training set size: {}'.format(len_full))
        print('Full test set size: {}'.format(len_full_test))
        print('\nActive training set size: {}'.format(len_train))
        print('Active validation set size: {}'.format(len_val))
        print('Active test set size: {}'.format(len_test))

    def GenerateLoaders(self):
        """Create ``train_loader``, ``val_loader`` and ``test_loader``.

        ``test_loader`` is ``None`` when there is no test set.
        """
        self.train_loader = self._make_loader(self.trainset)
        self.val_loader = self._make_loader(self.valset)

        if self.testset is None:
            self.test_loader = None
        else:
            self.test_loader = self._make_loader(self.testset)

    def _make_loader(self, dataset):
        """Wrap `dataset` in a shuffling, single-process DataLoader."""
        return torch.utils.data.DataLoader(dataset, batch_size=self.batch_size,
                                           shuffle=True, num_workers=0)


In [17]:
# Smoke test for data import: load 10% of CIFAR-10 (20% of that as validation)
# and verify the resulting split sizes instead of only printing them.

dat = DatasetManager('cifar10', 10.0, 20.0)
dat.ImportDataset(5)

# CIFAR-10 has 50,000 training and 10,000 test images, so keeping 10% of each
# with a 20% validation share must yield 4000 / 1000 / 1000 samples.
assert len(dat.trainset) == 4000
assert len(dat.valset) == 1000
assert len(dat.testset) == 1000

# One batch should contain 5 RGB images cropped to 224x224.
images, labels = next(iter(dat.train_loader))
assert images.shape == (5, 3, 224, 224)
assert labels.shape == (5,)



Files already downloaded and verified
Files already downloaded and verified

Full training set size: 50000
Full test set size: 10000

Active training set size: 4000
Active validation set size: 1000
Active test set size: 1000
