# In [4]:
import torch,torchvision
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# In [1]:
class SimpleLoader(object):
    """Build train/valid/test datasets for a named dataset and hand out DataLoaders.

    Supported names are 'CIFAR10', 'CIFAR100' and 'ImageNet'. CIFAR data is
    downloaded into ``root``. ImageNet is read with ``ImageFolder`` from the
    ``train`` and ``val`` subdirectories of ``imagenet_root``.

    Args:
        dataset: name of the dataset (see above).
        batch_size: training batch size; the valid/test loaders use twice this.
        crop_size: random-crop size used by the default transform.
        transform: transform for the training split. Defaults to
            Resize(256) -> RandomCrop(crop_size) -> ToTensor -> Normalize
            (ImageNet mean/std).
        split: if True, carve a validation subset out of the training set;
            otherwise the test set doubles as the validation set.
        valid_size: fraction of the training set used for validation when
            ``split`` is True.
        root: download/cache directory for CIFAR.
        imagenet_root: directory containing the ImageNet ``train``/``val`` folders.
        num_workers: number of worker processes for every DataLoader.
        eval_drop_last: whether the valid/test loaders drop the final partial
            batch. Defaults to True for backward compatibility. NOTE: True means
            some evaluation samples are never evaluated.
        seed: if given, makes the train/valid split reproducible.
        test_transform: transform for the test split; defaults to ``transform``
            (which, by default, includes a *random* crop).

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """

    # Supported dataset names -> number of classes.
    _NUM_CLASSES = {'CIFAR10': 10, 'CIFAR100': 100, 'ImageNet': 1000}

    def __init__(self, dataset, batch_size, crop_size, transform=None, split=False, valid_size=0.1,
                 *, root='data', imagenet_root='/home/osilab5/hdd/ImageNet/Data', num_workers=2,
                 eval_drop_last=True, seed=None, test_transform=None):
        # Validate first. Previously an unknown *string* name slipped past the check
        # and only failed later with an unrelated AttributeError.
        if not isinstance(dataset, str) or dataset not in self._NUM_CLASSES:
            raise ValueError(f'wrong dataset: {dataset!r}')

        if transform is None:
            transform = transforms.Compose([transforms.Resize(256),
                                            transforms.RandomCrop(crop_size),
                                            transforms.ToTensor(),
                                            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                 std=[0.229, 0.224, 0.225])])
        if test_transform is None:
            test_transform = transform

        self.num_classes = self._NUM_CLASSES[dataset]
        if dataset in ('CIFAR10', 'CIFAR100'):
            cifar = datasets.CIFAR10 if dataset == 'CIFAR10' else datasets.CIFAR100
            self.train_ds = cifar(root=root, train=True, transform=transform, download=True)
            self.test_ds = cifar(root=root, train=False, transform=test_transform, download=True)
        else:  # 'ImageNet'
            self.train_ds = datasets.ImageFolder(root=f'{imagenet_root}/train', transform=transform)
            self.test_ds = datasets.ImageFolder(root=f'{imagenet_root}/val', transform=test_transform)

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.eval_drop_last = eval_drop_last

        if split:
            # The validation subset shares the training dataset object, so it
            # also uses the training transform.
            n_valid = int(valid_size * len(self.train_ds))
            n_train = len(self.train_ds) - n_valid
            split_kwargs = {} if seed is None else {'generator': torch.Generator().manual_seed(seed)}
            self.train_ds, self.valid_ds = data.random_split(self.train_ds, [n_train, n_valid],
                                                             **split_kwargs)
        else:
            self.valid_ds = self.test_ds

    def GetNumClasses(self):
        """Return the number of classes of the selected dataset."""
        return self.num_classes

    def GetTrainLoader(self):
        """Return a shuffled training DataLoader that drops the last partial batch."""
        return data.DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True,
                               num_workers=self.num_workers, pin_memory=True, drop_last=True)

    def _eval_loader(self, ds):
        """Build an unshuffled evaluation DataLoader with double the training batch size."""
        return data.DataLoader(ds, batch_size=self.batch_size * 2, num_workers=self.num_workers,
                               pin_memory=True, drop_last=self.eval_drop_last)

    def GetValidLoader(self):
        """Return the validation DataLoader (the test set when ``split`` was False)."""
        return self._eval_loader(self.valid_ds)

    def GetTestLoader(self):
        """Return the test DataLoader."""
        return self._eval_loader(self.test_ds)