In [None]:
import random
import os
import numpy as np

from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from torchvision.datasets import *
from torchvision.datasets.folder import *
import torchvision.transforms as transforms
from torchvision import transforms



In [None]:
def pil_image_loader(path):
    """Read the image stored at ``path`` with PIL and return it in RGB mode.

    Both the raw file handle and the PIL image object are closed before
    returning; ``convert('RGB')`` hands back a converted copy, which is what
    the caller receives.
    """
    with open(path, 'rb') as handle, Image.open(handle) as source:
        rgb = source.convert('RGB')
    return rgb


def accimage_loader(path):
    """Load ``path`` with the accimage backend, falling back to PIL on IOError.

    ``accimage`` is imported inside the function, so it only needs to be
    installed when this loader is actually called. If accimage raises
    ``IOError`` for the file, it is re-read through ``pil_image_loader``.
    """
    import accimage
    try:
        image = accimage.Image(path)
    except IOError:
        image = pil_image_loader(path)
    return image


In [None]:
def default_loader(path):
    """Load ``path`` using the image backend torchvision is configured for.

    Uses ``accimage_loader`` when ``torchvision.get_image_backend()`` reports
    ``'accimage'``; any other backend falls through to ``pil_image_loader``.
    """
    from torchvision import get_image_backend
    use_accimage = get_image_backend() == 'accimage'
    chosen_loader = accimage_loader if use_accimage else pil_image_loader
    return chosen_loader(path)
        

In [None]:
def make_dataset(dir, class_to_idx):
    """Collect ``(path, label)`` pairs for every image file under ``dir``.

    ``dir`` is expected to contain one sub-directory per class. Each class
    directory is walked recursively and every file accepted by
    ``is_image_file`` is recorded with that class's label.

    Args:
        dir: dataset root; a leading ``~`` is expanded.
        class_to_idx: mapping from class sub-directory name to integer label.
            Sub-directories whose names are not keys of this mapping are
            skipped (previously they raised ``KeyError``).

    Returns:
        List of ``(path, label)`` tuples, ordered by sorted class name, then
        sorted walk order, then sorted file name.

    Raises:
        FileNotFoundError / OSError: propagated from ``os.listdir`` when
            ``dir`` does not exist or is unreadable.
    """
    root_dir = os.path.expanduser(dir)
    samples = []
    for class_name in sorted(os.listdir(root_dir)):
        class_dir = os.path.join(root_dir, class_name)
        # Ignore plain files and any folder that is not part of the class
        # mapping, instead of failing on the dictionary lookup below.
        if class_name not in class_to_idx or not os.path.isdir(class_dir):
            continue

        label = class_to_idx[class_name]
        for walk_root, _, fnames in sorted(os.walk(class_dir)):
            for fname in sorted(fnames):
                if is_image_file(fname):
                    samples.append((os.path.join(walk_root, fname), label))

    return samples

In [None]:
def find_classes(dir):
    """Discover class names from the immediate sub-directories of ``dir``.

    A leading ``~`` in ``dir`` is expanded, so home-relative dataset roots
    work here just as they do when the same root is later scanned for images.

    Args:
        dir: dataset root containing one sub-directory per class.

    Returns:
        ``(classes, class_to_idx)`` where ``classes`` is the sorted list of
        sub-directory names and ``class_to_idx`` maps each name to its
        position in that list.

    Raises:
        FileNotFoundError / OSError: propagated from ``os.listdir``.
    """
    root_dir = os.path.expanduser(dir)
    classes = sorted(
        entry for entry in os.listdir(root_dir)
        if os.path.isdir(os.path.join(root_dir, entry))
    )
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx


In [None]:
# Dataset of bird images, organized as one sub-directory per class
class BirdDataset(Dataset):
    """Image-folder dataset where ``root`` holds one sub-directory per class.

    Items are ``(image, label)`` pairs. The image is produced by ``loader``
    from the file path and optionally passed through ``transform``. The
    integer label comes from ``find_classes`` (sorted directory order) and is
    optionally passed through ``target_transform``.

    Raises:
        RuntimeError: from ``__init__`` when no image files are found.
    """

    def __init__(self, root, transform=None, target_transform=None,loader=default_loader):
        classes, class_to_idx = find_classes(root)
        samples = make_dataset(root, class_to_idx)
        if not samples:
            raise RuntimeError("No image found: " + root + "\n")

        self.root = root
        self.imgs = samples  # list of (path, label) tuples
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __len__(self):
        """Return the number of discovered (path, label) samples."""
        return len(self.imgs)

    def __getitem__(self, index):
        """Load sample ``index`` and apply the optional image/label transforms."""
        path, label = self.imgs[index]
        image = self.loader(path)
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label

In [None]:
class Convert(object):
    """Turn an array-like image into a float tensor with a new leading axis.

    The input is copied via ``np.array``, wrapped with ``torch.from_numpy``,
    given a size-1 leading dimension and cast to float.
    NOTE(review): presumably intended for single-channel images (e.g. 28x28);
    an RGB input keeps its trailing channel axis, giving (1, H, W, 3).
    """
    def __call__(self, img):
        pixels = np.array(img)
        tensor = torch.from_numpy(pixels)
        return tensor.unsqueeze(0).float()

class OneHot(object):
    """Label-transform placeholder.

    NOTE(review): despite its name this does NOT one-hot encode; calling it
    returns the label object unchanged (an identity transform).
    """
    def __call__(self, label):
        unchanged = label
        return unchanged

class Flatten(object):
    """Reshape a tensor into a 1-D vector via ``Tensor.view``.

    Args:
        size: target length passed to ``view``. Defaults to ``28 * 28`` (784),
            the value this transform previously hard-coded, so ``Flatten()``
            behaves exactly as before. Pass ``-1`` to flatten a tensor of any
            number of elements.
    """
    def __init__(self, size=28 * 28):
        self.size = size

    def __call__(self, img):
        # ``view`` requires a compatible (e.g. contiguous) layout, as before.
        return img.view(self.size)




In [None]:

def fetch_dataloader(types, data_dir, params, **kwargs):
    """Build a ``DataLoader`` for each requested split of the bird dataset.

    Args:
        types: collection of split names to build; only ``'train'``, ``'val'``
            and ``'test'`` are recognised, anything else is ignored.
        data_dir: directory containing one sub-directory per split
            (``data_dir/train``, ``data_dir/val``, ``data_dir/test``).
        params: object providing ``width``, ``batch_size`` and ``num_workers``.
        **kwargs: accepted for call-site compatibility; currently unused.

    Returns:
        dict mapping each requested split name to its ``DataLoader``.
        Only the training loader shuffles.
    """
    # Deterministic preprocessing shared by every split: center-crop to
    # 128 (height) x params.width (width) pixels, then convert to a tensor.
    crop_transform = transforms.Compose([
        transforms.CenterCrop((128, params.width)),
        transforms.ToTensor(),
    ])

    dataloaders = {}
    for split in ['train', 'val', 'test']:
        if split not in types:
            continue
        dataset = BirdDataset(os.path.join(data_dir, split), transform=crop_transform)
        dataloaders[split] = DataLoader(
            dataset,
            batch_size=params.batch_size,
            shuffle=(split == 'train'),
            num_workers=params.num_workers,
        )

    return dataloaders