<a href="https://colab.research.google.com/github/alessandronicolini/IncrementalLearning/blob/main/cifar100.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

import numpy as np
import torch
from torchvision import transforms
from torchvision.datasets import VisionDataset
from PIL import Image
import random
import torchvision
# Directory passed to torchvision as the CIFAR-100 root (files are downloaded there).
ROOT = './data'
class ilCIFAR100(VisionDataset):
    """
    CIFAR-100 wrapper that groups the 100 classes into incremental-learning batches.

    The class labels 0..99 are shuffled once (driven by ``seed``) and then cut into
    consecutive groups of ``classes_per_batch`` labels. The helpers below expose
    those groups either as class labels (``getbatches``) or as sample indexes of
    the wrapped dataset (``get_batch_indexes`` / ``get_train_val``).

    Indexing the dataset returns ``(index, (image, label))``: the sample index is
    returned together with the item produced by the wrapped torchvision dataset.

    Args:
        classes_per_batch (int): number of classes in every batch. If it does not
            divide 100 evenly, the trailing classes that do not fill a whole
            batch are left out.
        seed (int): seed given to ``random.seed`` before the class order is drawn.
        val_size (float, optional): accepted for backward compatibility, currently
            unused. Pass the validation fraction to ``get_train_val`` instead.
        train (bool, optional): if True wraps the CIFAR-100 training split (with
            random crop / horizontal flip augmentation), otherwise the test split.
        transform (callable, optional): accepted for backward compatibility,
            currently unused; the built-in train/test transforms are applied.
        target_transform (callable, optional): accepted for backward
            compatibility, currently unused.
        download (bool, optional): accepted for backward compatibility, currently
            unused; the wrapped dataset is always created with ``download=True``.
    """

    def __init__(self, classes_per_batch, seed, val_size=0.1, train=True,
                 transform=None, target_transform=False, download=True):
        # Hand the real data directory to VisionDataset instead of a dummy value.
        super(ilCIFAR100, self).__init__(root=ROOT)
        self.classes_per_batch = classes_per_batch
        self.__rs = seed  # seed kept for reference
        self.train = train

        normalize = transforms.Normalize((0.5071, 0.4867, 0.4408),
                                         (0.2675, 0.2565, 0.2761))
        self.__transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        self.__transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

        if self.train:
            self.dataset = torchvision.datasets.CIFAR100(
                root=ROOT, train=True, download=True,
                transform=self.__transform_train)
        else:
            self.dataset = torchvision.datasets.CIFAR100(
                root=ROOT, train=False, download=True,
                transform=self.__transform_test)

        # Targets as an array so batches can be selected with boolean masks.
        self.targets = np.array(self.dataset.targets)

        # Seed the global `random` module, then draw the shuffled class order.
        random.seed(seed)
        self.classes = random.sample(range(0, 100), 100)

        # Map each original class label to its position in the shuffled order.
        self.__dictionary = {c: i for i, c in enumerate(self.classes)}

    def _class_batches(self):
        """Return the shuffled class list cut into groups of `classes_per_batch`."""
        step = self.classes_per_batch
        n_batches = int(len(self.classes) / step)
        return [self.classes[int(i * step):int(i * step + step)]
                for i in range(n_batches)]

    def get_dict(self):
        """Return the dict mapping original class label -> position in the shuffled order."""
        return self.__dictionary

    def __getitem__(self, index):
        """Return ``(index, item)`` where ``item`` is the wrapped dataset's sample at ``index``."""
        return index, self.dataset[index]

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def getbatches(self):
        """Return a list with, for every batch, the list of class labels it contains."""
        # The slice bounds come from classes_per_batch (previously the start
        # of the upper bound was hard-coded to i*10, which broke any batch size
        # other than 10).
        return self._class_batches()

    def get_batch_indexes(self):
        """Return a list with, for every batch, the array of sample indexes whose
        target belongs to that batch's classes (in ascending index order)."""
        all_indexes = np.arange(len(self.targets))
        return [all_indexes[np.isin(self.targets, batch)]
                for batch in self._class_batches()]

    def get_train_val(self, valid):
        """Split every batch's sample indexes into train and validation parts.

        Each batch is shuffled with the global `random` module, then the first
        ``int(valid * len(batch))`` indexes go to validation and the rest to train.

        Args:
            valid (float): fraction of each batch reserved for validation.

        Returns:
            tuple: ``(train, val)``, two lists of index arrays, one entry per batch.
        """
        train = []
        val = []
        for batch in self.get_batch_indexes():
            random.shuffle(batch)
            split = int(valid * len(batch))
            val.append(batch[0:split])
            train.append(batch[split:])
        return train, val

In [None]:
# Build the training-split wrapper: 10 classes per batch, seed 12.
train=ilCIFAR100(10,12)

Files already downloaded and verified
[60, 34, 84, 67, 85, 44, 18, 48, 1, 47, 61, 35, 82, 58, 76, 29, 71, 0, 79, 93, 56, 90, 20, 43, 26, 7, 73, 25, 9, 65, 95, 51, 11, 2, 74, 28, 96, 27, 99, 64, 70, 42, 62, 8, 98, 77, 39, 88, 10, 94, 3, 52, 68, 32, 5, 72, 38, 75, 69, 30, 40, 41, 24, 55, 91, 45, 12, 16, 22, 53, 63, 57, 31, 33, 21, 83, 49, 81, 59, 78, 97, 19, 46, 17, 36, 87, 6, 13, 14, 89, 80, 23, 86, 37, 54, 92, 50, 66, 15, 4]


In [None]:
# Show the class labels assigned to each batch.
train.getbatches()

0
1
2
3
4
5
6
7
8
9


[[60, 34, 84, 67, 85, 44, 18, 48, 1, 47],
 [61, 35, 82, 58, 76, 29, 71, 0, 79, 93],
 [56, 90, 20, 43, 26, 7, 73, 25, 9, 65],
 [95, 51, 11, 2, 74, 28, 96, 27, 99, 64],
 [70, 42, 62, 8, 98, 77, 39, 88, 10, 94],
 [3, 52, 68, 32, 5, 72, 38, 75, 69, 30],
 [40, 41, 24, 55, 91, 45, 12, 16, 22, 53],
 [63, 57, 31, 33, 21, 83, 49, 81, 59, 78],
 [97, 19, 46, 17, 36, 87, 6, 13, 14, 89],
 [80, 23, 86, 37, 54, 92, 50, 66, 15, 4]]

In [None]:
# Show, for each batch, the indexes of the samples whose target is in that batch.
train.get_batch_indexes()

[array([    4,    25,    32, ..., 49978, 49985, 49987]),
 array([    1,     2,    12, ..., 49947, 49958, 49967]),
 array([    6,    33,    44, ..., 49996, 49998, 49999]),
 array([    3,     7,    11, ..., 49976, 49980, 49989]),
 array([   10,    15,    16, ..., 49952, 49962, 49982]),
 array([   27,    62,    82, ..., 49969, 49975, 49997]),
 array([   35,    38,    41, ..., 49984, 49991, 49992]),
 array([    9,    21,    24, ..., 49972, 49979, 49993]),
 array([    0,    13,    17, ..., 49983, 49988, 49994]),
 array([    5,     8,    18, ..., 49951, 49963, 49995])]

In [None]:
# Shuffle each batch's indexes and split them, passing 0.1 as the validation fraction.
train.get_train_val(0.1)

<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>


([array([21903, 21602, 12483, ...,  2836, 22245,  6316]),
  array([ 8637, 29096, 40345, ..., 16194,  4607, 11271]),
  array([13700,  4674, 26629, ..., 49855, 41539,  2655]),
  array([48775,   469, 32098, ...,  1041, 25662, 46272]),
  array([37888, 23351, 25206, ..., 20634, 42682,  5745]),
  array([35295, 13048, 43419, ..., 27295, 34638,  2227]),
  array([21019, 15880, 17768, ..., 49586, 29959,  5828]),
  array([13663, 10584, 32793, ..., 39499, 32112, 24474]),
  array([36018, 28713, 45431, ..., 17997, 30932, 46801]),
  array([44872, 14190, 20922, ..., 42279, 45137, 24219])],
 [array([14504, 29552, 19577, 35523, 10693, 19015, 28754, 30467, 24798,
          3412,  2118, 14335, 22915, 34697,  2426, 25182,  4365, 18719,
          7185, 42765, 12995, 18379, 32752, 48075, 47802, 27891, 18020,
         12913, 39134, 45763, 49262, 26266, 39768,  2952, 49439, 45766,
         38246, 36973, 19440, 29434, 40532, 40602, 16425, 17581, 12039,
         36569, 43708, 39571,  8501, 28672, 43631,  9338,  