In [None]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Subset, DataLoader
from torch.backends import cudnn

import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR100

from PIL import Image
import matplotlib.pyplot as plt

In [None]:
!git clone https://github.com/akamaster/pytorch_resnet_cifar10.git

Cloning into 'pytorch_resnet_cifar10'...
remote: Enumerating objects: 5, done.[K
remote: Counting objects: 100% (5/5), done.[K
remote: Compressing objects: 100% (5/5), done.[K
remote: Total 81 (delta 0), reused 3 (delta 0), pack-reused 76[K
Unpacking objects: 100% (81/81), done.


In [None]:
from pytorch_resnet_cifar10.resnet import resnet32

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
DEVICE

device(type='cuda')

# Prepare the CIFAR-100 dataset

In [None]:
class MyCIFAR100:
    '''
    Builds the CIFAR-100 train/test datasets and sub-samples them by class label.

    https://www.cs.toronto.edu/~kriz/cifar.html
    100 classes containing 600 images each, 500 training images and 100 testing images per class
    '''

    def __init__(self, root='./dataset'):
        '''
        Args:
            root: directory holding the CIFAR-100 files; they are downloaded
                there if they are not already present.
        '''
        self.num_classes = 100
        self.root = root
        self.trainset, self.testset = self.get_dataset()

    def get_dataset(self):
        '''
        Create the train and test datasets together with their transforms.

        The transforms are also stored on the instance as `train_transform`
        and `test_transform`.

        Returns:
            (trainset, testset): the two torchvision CIFAR100 datasets.
        '''
        # NOTE(review): these values look like the commonly used ImageNet channel
        # statistics rather than CIFAR-100-specific ones -- confirm this is intended.
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
        # Augmentation (random crop + horizontal flip) is applied to training data only.
        self.train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            normalize,
        ])
        self.test_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

        # download=True does not download again when the verified files are already
        # in `root`, so no directory check is needed. (A bare isdir() check also
        # broke when the directory existed but the files did not.)
        trainset = CIFAR100(self.root, train=True, download=True,
                            transform=self.train_transform)
        testset = CIFAR100(self.root, train=False, download=True,
                           transform=self.test_transform)

        return trainset, testset

    @staticmethod
    def _indexes_with_labels(dataset, wanted_labels):
        '''Return the positions of the samples in `dataset` whose label is in `wanted_labels`.'''
        targets = getattr(dataset, 'targets', None)
        if targets is None:
            # Fallback for datasets that expose no `targets` list: read the label
            # from each (sample, label) item. This runs the dataset's transform.
            targets = (dataset[index][1] for index in range(len(dataset)))
        return [index for index, label in enumerate(targets) if label in wanted_labels]

    def sub_sample(self, sub_labels):
        '''
        Sub-sample a dataset, taking only those samples with label in sub_labels

        Args:
            sub_labels: iterable of the class labels to keep.

        Returns:
            (sub_trainset, sub_testset): `Subset` views over `self.trainset` and
            `self.testset` restricted to the requested labels.
        '''
        # A set gives constant-time membership tests and lets any iterable
        # (including a one-shot generator) be reused for both splits.
        wanted_labels = set(sub_labels)

        train_indexes = self._indexes_with_labels(self.trainset, wanted_labels)
        sub_trainset = Subset(self.trainset, train_indexes)

        test_indexes = self._indexes_with_labels(self.testset, wanted_labels)
        sub_testset = Subset(self.testset, test_indexes)

        return sub_trainset, sub_testset

In [None]:
# Build the full CIFAR-100 datasets, then keep only the samples of classes 0-9.
cifar = MyCIFAR100()
trainset, testset = cifar.sub_sample(sub_labels=range(10))

In [None]:
# Sanity check: 10 classes x 500 train / 100 test images per class -> (5000, 1000).
(len(trainset), len(testset))

(5000, 1000)

In [None]:
# BATCH_SIZE was used here without being defined anywhere earlier, so the cell
# raised NameError on a fresh kernel. Define it explicitly; tune as needed.
BATCH_SIZE = 128

# Training batches are reshuffled every epoch and the last incomplete batch is
# dropped; the test loader keeps the original order and every sample.
train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)
test_loader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

# Learn new classes without any strategy

# Implement LwF (Learning without Forgetting)

In [None]:
def common_loss(outputs, labels):
    """Classification loss: cross-entropy between `outputs` and `labels`.

    Args:
        outputs: model outputs passed as the `input` of `F.cross_entropy`.
        labels: targets passed as the `target` of `F.cross_entropy`.

    Returns:
        The loss tensor produced by `F.cross_entropy` with its default settings.
    """
    classification_loss = F.cross_entropy(input=outputs, target=labels)
    return classification_loss

In [None]:
def distillation_loss(student_outputs, teacher_outputs, T=2.0):
    """Knowledge-distillation loss: KL divergence between softened distributions.

    Both sets of logits are divided by the temperature `T` before the softmax,
    and the result is KL(teacher || student) averaged over the batch.

    Args:
        student_outputs: logits of the model being trained; classes are taken
            along dim 1 (i.e. shape (batch, num_classes)).
        teacher_outputs: logits of the reference model, same shape as
            `student_outputs`.
        T: softmax temperature. It used to be read from an undefined global;
            it is now an explicit parameter with a default value.

    Returns:
        Scalar tensor with the batch-averaged KL divergence.
    """
    # Pass `dim` explicitly: the implicit choice is deprecated and emits a warning.
    student_log_probs = F.log_softmax(student_outputs / T, dim=1)
    teacher_probs = F.softmax(teacher_outputs / T, dim=1)
    # 'batchmean' sums over classes and divides by the batch size, which matches
    # the mathematical KL definition; the default 'mean' also divides by the
    # number of classes.
    return F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')

# Implement iCaRL (Incremental Classifier and Representation Learning)