In [None]:
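# cifar_100.get_batches() returns a list of Subsets, one per group of classes.
# With classes_per_batch = 10 on the training set:
# list[0] = samples of the classes at positions 0-9 of the class order (5000 elements)
# list[1] = samples of the classes at positions 10-19 of the class order (5000 elements)
# ...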

import torch
import torchvision
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split

from torchvision.datasets import VisionDataset
from torch.utils.data import Subset, DataLoader
import matplotlib.pyplot as plt
import numpy as np
import random
import itertools
import sys

# Directory passed as `root` to torchvision's CIFAR100 (download target / cache).
ROOT = './data'

In [None]:
# This class handles the CIFAR-100 dataset.
# The constructor downloads the data through torchvision.datasets.
# The method get_batches splits the data into batches, each holding the
# samples of `classes_per_batch` classes.
# Trainset  50k elements
# Testset   10k elements

class cifar_100(VisionDataset):
    """CIFAR-100 wrapper that groups samples into batches of classes.

    The wrapped torchvision dataset is exposed as ``self.dataset`` and the
    per-group ``Subset`` objects built by :meth:`get_batches` are stored in
    ``self.batches``. Indexing the wrapper yields ``(index, image, label)``.
    """

    def __init__(self, classes_per_batch, split, rand_seed=None):
        """Load the requested split and fix the order of the classes.

        Args:
            classes_per_batch: number of classes grouped into each batch.
            split: ``'train'`` selects the training set (with augmentation);
                any other value selects the test set.
            rand_seed: when not None, seeds Python's global ``random`` module
                and draws a random class order; when None a fixed, hard-coded
                order is used.
        """
        super(cifar_100, self).__init__(root=ROOT)
        self.__classes_per_batch = classes_per_batch

        # normalize should not be used but if removed we cannot reach paper results
        normalize = transforms.Normalize(
            (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))

        self.__transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

        self.__transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

        self.__split = split

        is_train = self.__split == 'train'
        self.dataset = torchvision.datasets.CIFAR100(
            root=ROOT,
            train=is_train,
            download=True,
            transform=(self.__transform_train if is_train
                       else self.__transform_test),
        )

        if rand_seed is not None:
            # Seeding the global RNG is kept on purpose: code running after
            # the constructor may rely on the seeded state.
            random.seed(rand_seed)
            self.__classes = random.sample(range(0, 100), 100)
        else:
            # Fixed default class order used when no seed is given.
            self.__classes = [
                36, 61, 49, 58, 92, 90, 68, 32, 28, 52,
                47, 87, 1, 41, 93, 6, 88, 12, 38, 91,
                81, 33, 8, 48, 60, 27, 50, 17, 56, 97,
                34, 42, 84, 66, 62, 26, 29, 51, 3, 72,
                39, 9, 37, 85, 13, 25, 11, 67, 99, 74,
                30, 2, 64, 71, 19, 35, 31, 63, 54, 15,
                43, 73, 40, 55, 7, 78, 14, 10, 70, 44,
                0, 86, 79, 57, 75, 46, 83, 82, 22, 4,
                45, 18, 89, 5, 59, 21, 95, 96, 69, 16,
                98, 23, 80, 65, 76, 77, 20, 24, 94, 53,
            ]

        # Maps each original class label to its position in the class order.
        self.__dictionary = {c: i for i, c in enumerate(self.__classes)}
        self.batches = []
        self.__num_classes = len(self.__classes)

    def __getitem__(self, index):
        """Return ``(index, image, label)`` for the sample at ``index``."""
        # Fetch the sample once: indexing twice would run the (random)
        # transform pipeline twice for a single item.
        sample = self.dataset[index]
        return index, sample[0], sample[1]

    def get_dictionary(self):
        """Return the mapping from class label to position in the class order."""
        return self.__dictionary

    def cif_len(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset.targets)

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset.targets)

    def get_tot_classes_count(self):
        """Return the total number of classes."""
        return self.__num_classes

    def get_classes_per_batch(self):
        """Return the number of classes grouped into each batch."""
        return self.__classes_per_batch

    def getClassIndexes(self, classes):
        """Return the ascending dataset indexes whose label is in ``classes``."""
        labels = self.dataset.targets
        mask = np.isin(labels, classes)
        return np.arange(len(labels))[mask]

    def get_classes_list(self):
        """Return the list of class labels in the order used for batching."""
        return self.__classes

    def get_batches(self):
        """Split the dataset into one ``Subset`` per group of classes.

        Classes are taken in the stored class order, ``classes_per_batch`` at
        a time. Every group is emitted, including the last one (and a final
        shorter group when the class count is not a multiple of
        ``classes_per_batch``). ``self.batches`` is rebuilt on every call, so
        calling the method repeatedly does not accumulate duplicates.

        Returns:
            The list of ``Subset`` objects, also stored in ``self.batches``.

        Raises:
            ValueError: if ``classes_per_batch`` is smaller than 1.
        """
        step = self.__classes_per_batch
        if step < 1:
            raise ValueError("classes_per_batch must be a positive integer")

        labels = np.asarray(self.dataset.targets)
        batches = []
        for start in range(0, len(self.__classes), step):
            group = self.__classes[start:start + step]
            # Ascending positions of every sample belonging to this group.
            ids = np.flatnonzero(np.isin(labels, group)).tolist()
            batches.append(Subset(self.dataset, ids))

        self.batches = batches
        return self.batches

    def getIndexesLabels(self, indexes):
        """Return the labels of the samples listed in ``indexes``.

        The result follows ascending dataset-index order, not the order in
        which ``indexes`` is given.
        """
        labels = np.array(self.dataset.targets)
        mask = np.isin(np.arange(len(labels)), indexes)
        return labels[mask]

    def retrieveTrainVal(self, bat, bat_ind, cif):
        """Split ``bat`` into stratified train (90%) and validation (10%) subsets.

        Args:
            bat: the batch (dataset) to split.
            bat_ind: numpy array with the dataset indexes of the elements of
                ``bat``.
                NOTE(review): labels are looked up in ascending index order,
                so ``bat_ind`` is presumably sorted ascending - confirm.
            cif: the ``cifar_100`` object used to look up the labels.

        Returns:
            ``(train_dataset, val_dataset)`` as ``Subset`` objects over ``bat``.
        """
        positions = range(0, len(bat_ind))
        labels = cif.getIndexesLabels(bat_ind.tolist())
        train, _val, _y_train, _y_val = train_test_split(
            positions, labels, test_size=0.1, random_state=42, stratify=labels)

        # Set membership keeps the partition linear in the batch size.
        train_positions = set(train)
        index_train = [i for i in positions if i in train_positions]
        index_val = [i for i in positions if i not in train_positions]

        return Subset(bat, index_train), Subset(bat, index_val)

    def get_transform_test(self):
        """Return the transform pipeline used for the test split."""
        return self.__transform_test

    def get_first_k_batches(self, k):
        """Return the samples of the first ``k`` batches as one flat list.

        Batches are built on demand when ``get_batches`` has not been called
        yet. When ``k`` is out of range an error is printed to stderr and
        None is returned.
        """
        # Ceiling division, so a final shorter batch counts as well.
        k_max = -(-self.__num_classes // self.__classes_per_batch)
        if k > k_max or k < 0:
            print("ERROR: K must be a positive integer from 1 to", int(k_max), file=sys.stderr)
            return None

        if not self.batches:
            self.get_batches()

        first_k = []
        for batch in self.batches[:k]:
            first_k.extend(batch)

        return first_k

    def divide_in_classes(self, batch):
        """Group the data of ``batch`` by label.

        Args:
            batch: iterable of samples where ``el[0]`` is the data and
                ``el[1]`` is the label.

        Returns:
            A list with one inner list of data per distinct label, ordered by
            ascending label. An empty batch yields an empty list.
        """
        # Materialise once so the batch is iterated a single time.
        samples = [(el[0], el[1]) for el in batch]
        data = [sample[0] for sample in samples]
        labels = np.array([sample[1] for sample in samples])

        data_in_classes = []
        for c in np.unique(labels):
            members = np.flatnonzero(labels == c)
            data_in_classes.append([data[i] for i in members])

        return data_in_classes

# Marker printed once the definitions above have been executed.
print("IMPORT CIFAR DONE Rseed")

IMPORT CIFAR DONE
