# In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision import datasets
from PIL import Image


class Cifar100(torch.utils.data.Dataset):
    """CIFAR-100 wrapper that exposes one group of classes at a time.

    The class labels are shuffled with a seed and partitioned into groups
    (10 groups of 10 labels with the default settings), stored in
    ``self.batch_splits`` as ``{0: [labels], 1: [labels], ...}``.
    After a group has been selected with :meth:`set_classes_batch`, the
    dataset only serves the samples whose label belongs to that group.
    Labels returned by ``__getitem__`` are remapped through
    ``self.label_map`` (original label -> position in the shuffled order).

    Until :meth:`set_classes_batch` is called the dataset is empty
    (``len(ds) == 0``).

    Args:
        root: Forwarded to torchvision's ``CIFAR100`` as ``root``.
        train: Forwarded to ``CIFAR100`` as ``train``.
        download: Forwarded to ``CIFAR100`` as ``download``.
        random_state: Seed used for the label shuffle (see :meth:`make_split`).
        transform: Optional callable applied to each PIL image in
            ``__getitem__`` while transforms are enabled.
    """

    def __init__(self, root, train, download, random_state, transform=None):
        super().__init__()
        self.train = train
        self.transform = transform
        self.is_transform_enabled = True

        # The wrapped dataset is built with transform=None on purpose: the
        # transform is applied in __getitem__ so that it can be toggled.
        self.dataset = datasets.cifar.CIFAR100(
            root=root,
            train=train,
            download=download,
            transform=None)

        # Labels as an array so they can be masked in set_classes_batch().
        self.targets = np.array(self.dataset.targets)

        # No split selected yet: start empty so __len__ is well defined
        # instead of failing with AttributeError.
        self.batch_idx = []
        self.idxes = np.array([], dtype=int)
        self.batches_mapping = {}

        self.batch_splits = self.make_split(random_state)

    def make_split(self, random_state, num_classes=100, num_splits=10):
        """Shuffle the class labels and partition them into splits.

        Also stores ``self.label_map``, mapping every original label to its
        position in the shuffled label list.

        Note:
            This reseeds NumPy's *global* random generator with
            ``random_state`` as a side effect.

        Args:
            random_state: Seed passed to ``np.random.seed``.
            num_classes: Number of labels to shuffle (``0..num_classes-1``).
            num_splits: Number of chunks the shuffled labels are cut into.

        Returns:
            Dict ``{split_id: [labels]}`` with integer keys
            ``0..num_splits-1`` and lists of Python ints as values.
        """
        np.random.seed(random_state)
        random_labels = list(range(num_classes))
        np.random.shuffle(random_labels)

        # original label -> index of that label in the shuffled order
        self.label_map = {
            label: position for position, label in enumerate(random_labels)
        }

        chunks = np.array_split(random_labels, num_splits)
        return {split_id: chunk.tolist() for split_id, chunk in enumerate(chunks)}

    def set_classes_batch(self, batch_idx):
        """Restrict the dataset to the samples of the given classes.

        Args:
            batch_idx: Collection of original class labels to keep
                (e.g. ``self.batch_splits[0]``).
        """
        self.batch_idx = batch_idx

        # Positions in the wrapped dataset whose target is one of batch_idx.
        mask = np.isin(self.targets, self.batch_idx)
        self.idxes = np.where(mask)[0]

        # Dense index used by __getitem__ (0..len-1) -> position in the
        # wrapped dataset.
        self.batches_mapping = dict(enumerate(self.idxes.tolist()))

    def __len__(self):
        """Return the number of samples in the currently selected split."""
        return len(self.batches_mapping)

    def __getitem__(self, fake_idxes):
        """Return ``(image, mapped_label)`` for a position in the current split.

        Args:
            fake_idxes: Integer in ``[0, len(self))``.

        Returns:
            Tuple of the image (a PIL image, or the transform's output when a
            transform is set and enabled) and the label remapped through
            ``self.label_map``.

        Raises:
            IndexError: If ``fake_idxes`` is not a valid position. Raising
                IndexError (not KeyError) is what lets ``for x in dataset``
                terminate normally.
        """
        try:
            real_idx = self.batches_mapping[fake_idxes]
        except KeyError:
            raise IndexError(
                f"index {fake_idxes!r} is out of range for a split of "
                f"{len(self)} samples"
            ) from None

        image = Image.fromarray(self.dataset.data[real_idx])
        label = self.dataset.targets[real_idx]

        # Preprocessing is applied on access, only while it is enabled.
        if self.transform is not None and self.is_transform_enabled:
            image = self.transform(image)

        return image, self.label_map[label]

    def enable_transform(self):
        """Apply ``self.transform`` in subsequent ``__getitem__`` calls."""
        self.is_transform_enabled = True

    def disable_transform(self):
        """Return untransformed PIL images from subsequent ``__getitem__`` calls."""
        self.is_transform_enabled = False

    def transform_status(self):
        """Return whether transforms are currently enabled."""
        return self.is_transform_enabled