In [None]:
# %cd drive/MyDrive/dl-assigment-2/

# !unzip nature_12K.zip

In [1]:
import glob
import os
import random

import cv2
import numpy as np
import torch
from pandas.core.common import flatten
# Seed every RNG this notebook relies on so runs are reproducible:
# torch's global RNG, Python's `random` (drives the train/val split in
# create_data via random.shuffle) and NumPy's global RNG.
torch.manual_seed(7)
random.seed(7)
np.random.seed(7)
# Release cached GPU memory left from earlier cell runs (no-op without CUDA).
torch.cuda.empty_cache()
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
from PIL import Image

In [2]:
# constants
IMG_MODE = 'RGB'  # PIL mode every sample is converted to before transforms
TRAIN_LABEL = 'train'
TEST_LABEL = 'test'

class iNaturalist(Dataset):
    """Map-style dataset over image files laid out as ``<root>/<class>/<image>``.

    Args:
        image_paths: list of image file paths; the name of each file's
            parent directory is its class.
        class_to_idx: mapping from class (directory) name to integer label.
        transform: callable applied to the PIL image (e.g. a torchvision
            ``Compose``); its output is returned as the sample.
    """

    def __init__(self, image_paths, class_to_idx, transform):
        self.all_images = image_paths
        self.current_transform = transform
        self.class_to_idx = class_to_idx

    def __len__(self):
        return len(self.all_images)

    @staticmethod
    def _class_name_of(image_filepath):
        """Return the parent-directory name of *image_filepath* (its class)."""
        # os.path understands the platform separator; splitting on '/'
        # breaks on Windows paths produced by glob.
        return os.path.basename(os.path.dirname(image_filepath))

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for sample *idx*.

        Raises:
            OSError: if OpenCV cannot read the file (missing or corrupt).
            KeyError: if the file's class directory is not in class_to_idx.
        """
        image_filepath = self.all_images[idx]
        image = cv2.imread(image_filepath)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"could not read image file: {image_filepath}")
        # OpenCV decodes to BGR; PIL / torchvision expect RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        y = self.class_to_idx[self._class_name_of(image_filepath)]

        X = Image.fromarray(np.asarray(image, dtype=np.uint8)).convert(IMG_MODE)
        X = self.current_transform(X)
        return X, y

def create_data(data_type, data_path,  data_aug, image_shape, b_size):
    classes = [image_path.split('/')[-1] for image_path in glob.glob(data_path + '/*')]

    all_images = [glob.glob(image_path + '/*') for image_path in glob.glob(data_path + '/*')]
    all_images = list(flatten(all_images))

    idx_to_class,class_to_idx = dict(),dict()
    for i, j in enumerate(classes):
        idx_to_class[i] = j
        class_to_idx[j] = i

    non_aug_tran = transforms.Compose([transforms.Resize((image_shape)),
                                transforms.ToTensor()
                                    ])
    if data_type == TEST_LABEL:
        test_image_paths=all_images
        test_dataset= iNaturalist(test_image_paths,class_to_idx,non_aug_tran)
        test_loader = DataLoader(test_dataset, batch_size=b_size, shuffle=True)

        return test_loader


    random.shuffle(all_images)

    tr_paths, v_paths = all_images[:int(0.9*len(all_images))], all_images[int(0.9*len(all_images)):] 

    tr_data,v_data = iNaturalist(tr_paths,class_to_idx,non_aug_tran),iNaturalist(v_paths,class_to_idx,non_aug_tran)

    if data_aug:
        augu_tran = transforms.Compose([transforms.Resize((image_shape)),
                transforms.RandomRotation(degrees=30),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomGrayscale(p=0.2),
                transforms.ToTensor(),
                            ])

        tr_data = iNaturalist(tr_paths,class_to_idx,augu_tran)
        v_data = iNaturalist(v_paths,class_to_idx,augu_tran)  

    t_loader,v_loader = DataLoader(tr_data, batch_size=b_size, shuffle=True),DataLoader(v_data, batch_size=b_size, shuffle=True)
    return t_loader,v_loader
