In [1]:
import sys
sys.path.append("../")

In [2]:
import copy

import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.nn import functional as F
import numpy as np

import inclearn

In [3]:
inclearn.train._set_seed(1)

Set seed 1


In [4]:
def get_task_dataset(max_class=100):
    train_transforms = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
    ])

    test_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
    ])


    train_dataset = datasets.cifar.CIFAR100("data", download=True, train=True, transform=train_transforms)
    test_dataset= datasets.cifar.CIFAR100("data", download=True, train=False, transform=test_transforms)
    select(train_dataset, max_class)
    select(test_dataset, max_class)
    
    train_loader = DataLoader(train_dataset, batch_size=128, num_workers=1, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=128, num_workers=1)
    return train_loader, test_loader

    
def select(dataset, max_class):
    targets = np.array(dataset.targets)
    
    indexes = np.where(targets <= max_class)[0]
    
    dataset.data = dataset.data[indexes]
    dataset.targets = targets[indexes].tolist()

In [5]:
device = torch.device("cuda:3")

In [6]:
def acc(model, loader):
    predictions = []
    all_targets = []
    
    for inputs, targets in loader:
        inputs = inputs.to(device)
        all_targets.append(targets.numpy())
        
        logits = model(inputs)
        predictions.append(logits.argmax(dim=1).cpu().numpy())

    predictions = np.concatenate(predictions)
    all_targets = np.concatenate(all_targets)
    
    total_acc = (predictions == all_targets).sum() / len(all_targets)
    
    return all_targets, predictions, total_acc

In [7]:
from torch.optim.lr_scheduler import _LRScheduler

class WarmUpLR(_LRScheduler):
    """warmup_training learning rate scheduler
    Args:
        optimizer: optimzier(e.g. SGD)
        total_iters: totoal_iters of warmup phase
    """
    def __init__(self, optimizer, total_iters, last_epoch=-1):
        
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """we will use the first m batches, and set the learning
        rate to base_lr * m / total_iters
        """
        return [base_lr * self.last_epoch / (self.total_iters + 1e-8) for base_lr in self.base_lrs]

In [8]:
def train(max_class):
    lr = 0.1
    scheduling = [60, 120, 160]
    gamma = 0.2
    warmup_epochs = 1
    n_epochs = 200
    
    train_loader, test_loader = get_task_dataset(max_class)
    
    model = torchvision.models.resnet34(num_classes=100).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, nesterov=True, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, scheduling, gamma=gamma)
    warmup = WarmUpLR(optimizer, len(train_loader) * warmup_epochs)
        
    best_model = None
    best_acc = -1.

    for epoch in range(n_epochs):
        if epoch < warmup_epochs:
            warmup.step()
        else:
            scheduler.step(epoch)
            
        epoch_loss = 0.

        model.train()
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            logits = model(inputs)
            loss = F.cross_entropy(logits, targets)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        model.eval()
        _, _, train_acc = acc(model, train_loader)
        _, _, test_acc = acc(model, test_loader)

        print("Max class: {}, Epoch {}/{}, train_loss: {}, train_acc: {}, test_acc: {}".format(
            max_class, epoch, n_epochs,
            round(epoch_loss / len(train_loader), 3),
            round(train_acc, 2), round(test_acc, 2)
        ))

        if test_acc > best_acc:
            print("Best acc! Saving model")
            best_model = copy.deepcopy(model)
            best_acc = test_acc
            
    del model, train_loader, test_loader
    return test_acc, best_acc

In [None]:
all_best, all_test = [], []

for i in range(10, 101, 10):
    print("Train up to {}.".format(i))
    t, b = train(i)
    all_test.append(t)
    all_best.append(b)

Train up to 10.
Files already downloaded and verified
Files already downloaded and verified
Max class: 10, Epoch 0/200, train_loss: 5.254, train_acc: 0.0, test_acc: 0.0
Best acc! Saving model
Max class: 10, Epoch 1/200, train_loss: 5.614, train_acc: 0.09, test_acc: 0.09
Best acc! Saving model
Max class: 10, Epoch 2/200, train_loss: 2.673, train_acc: 0.13, test_acc: 0.15
Best acc! Saving model
Max class: 10, Epoch 3/200, train_loss: 2.464, train_acc: 0.15, test_acc: 0.12
Max class: 10, Epoch 4/200, train_loss: 2.37, train_acc: 0.16, test_acc: 0.16
Best acc! Saving model
Max class: 10, Epoch 5/200, train_loss: 2.313, train_acc: 0.17, test_acc: 0.18
Best acc! Saving model
Max class: 10, Epoch 6/200, train_loss: 2.23, train_acc: 0.21, test_acc: 0.23
Best acc! Saving model
Max class: 10, Epoch 7/200, train_loss: 2.226, train_acc: 0.24, test_acc: 0.25
Best acc! Saving model
Max class: 10, Epoch 8/200, train_loss: 2.098, train_acc: 0.26, test_acc: 0.28
Best acc! Saving model
Max class: 10, Ep

Max class: 10, Epoch 94/200, train_loss: 0.188, train_acc: 0.95, test_acc: 0.72
Max class: 10, Epoch 95/200, train_loss: 0.192, train_acc: 0.95, test_acc: 0.73
Max class: 10, Epoch 96/200, train_loss: 0.197, train_acc: 0.94, test_acc: 0.72
Max class: 10, Epoch 97/200, train_loss: 0.208, train_acc: 0.94, test_acc: 0.71
Max class: 10, Epoch 98/200, train_loss: 0.181, train_acc: 0.95, test_acc: 0.72
Max class: 10, Epoch 99/200, train_loss: 0.206, train_acc: 0.94, test_acc: 0.72
Max class: 10, Epoch 100/200, train_loss: 0.2, train_acc: 0.95, test_acc: 0.74
Max class: 10, Epoch 101/200, train_loss: 0.181, train_acc: 0.94, test_acc: 0.71
Max class: 10, Epoch 102/200, train_loss: 0.179, train_acc: 0.96, test_acc: 0.73
Max class: 10, Epoch 103/200, train_loss: 0.173, train_acc: 0.95, test_acc: 0.73
Max class: 10, Epoch 104/200, train_loss: 0.186, train_acc: 0.95, test_acc: 0.74
Max class: 10, Epoch 105/200, train_loss: 0.19, train_acc: 0.95, test_acc: 0.72
Max class: 10, Epoch 106/200, train_l

Max class: 10, Epoch 195/200, train_loss: 0.02, train_acc: 1.0, test_acc: 0.75
Max class: 10, Epoch 196/200, train_loss: 0.015, train_acc: 1.0, test_acc: 0.75
Max class: 10, Epoch 197/200, train_loss: 0.019, train_acc: 1.0, test_acc: 0.76
Max class: 10, Epoch 198/200, train_loss: 0.018, train_acc: 1.0, test_acc: 0.74
Max class: 10, Epoch 199/200, train_loss: 0.022, train_acc: 1.0, test_acc: 0.74
Train up to 20.
Files already downloaded and verified
Files already downloaded and verified
Max class: 20, Epoch 0/200, train_loss: 4.933, train_acc: 0.01, test_acc: 0.01
Best acc! Saving model
Max class: 20, Epoch 1/200, train_loss: 4.475, train_acc: 0.05, test_acc: 0.05
Best acc! Saving model
Max class: 20, Epoch 2/200, train_loss: 3.397, train_acc: 0.05, test_acc: 0.05
Max class: 20, Epoch 3/200, train_loss: 3.472, train_acc: 0.06, test_acc: 0.06
Best acc! Saving model
Max class: 20, Epoch 4/200, train_loss: 3.178, train_acc: 0.08, test_acc: 0.09
Best acc! Saving model
Max class: 20, Epoch 5

Max class: 20, Epoch 89/200, train_loss: 1.161, train_acc: 0.63, test_acc: 0.56
Max class: 20, Epoch 90/200, train_loss: 1.183, train_acc: 0.61, test_acc: 0.56
Max class: 20, Epoch 91/200, train_loss: 1.198, train_acc: 0.59, test_acc: 0.53
Max class: 20, Epoch 92/200, train_loss: 1.237, train_acc: 0.63, test_acc: 0.56
Max class: 20, Epoch 93/200, train_loss: 1.142, train_acc: 0.62, test_acc: 0.54
Max class: 20, Epoch 94/200, train_loss: 1.176, train_acc: 0.65, test_acc: 0.58
Best acc! Saving model
Max class: 20, Epoch 95/200, train_loss: 1.11, train_acc: 0.62, test_acc: 0.56
Max class: 20, Epoch 96/200, train_loss: 1.172, train_acc: 0.65, test_acc: 0.56
Max class: 20, Epoch 97/200, train_loss: 1.088, train_acc: 0.65, test_acc: 0.55
Max class: 20, Epoch 98/200, train_loss: 1.071, train_acc: 0.59, test_acc: 0.53
Max class: 20, Epoch 99/200, train_loss: 1.197, train_acc: 0.61, test_acc: 0.54
Max class: 20, Epoch 100/200, train_loss: 1.163, train_acc: 0.63, test_acc: 0.56
Max class: 20, Ep

Max class: 20, Epoch 188/200, train_loss: 0.305, train_acc: 0.93, test_acc: 0.61
Max class: 20, Epoch 189/200, train_loss: 0.295, train_acc: 0.93, test_acc: 0.62
Max class: 20, Epoch 190/200, train_loss: 0.298, train_acc: 0.93, test_acc: 0.61
Max class: 20, Epoch 191/200, train_loss: 0.31, train_acc: 0.94, test_acc: 0.62
Max class: 20, Epoch 192/200, train_loss: 0.329, train_acc: 0.93, test_acc: 0.61
Max class: 20, Epoch 193/200, train_loss: 0.292, train_acc: 0.93, test_acc: 0.61
Max class: 20, Epoch 194/200, train_loss: 0.268, train_acc: 0.94, test_acc: 0.62
Max class: 20, Epoch 195/200, train_loss: 0.278, train_acc: 0.94, test_acc: 0.61
Max class: 20, Epoch 196/200, train_loss: 0.271, train_acc: 0.94, test_acc: 0.62
Max class: 20, Epoch 197/200, train_loss: 0.274, train_acc: 0.94, test_acc: 0.61
Max class: 20, Epoch 198/200, train_loss: 0.255, train_acc: 0.94, test_acc: 0.62
Max class: 20, Epoch 199/200, train_loss: 0.255, train_acc: 0.94, test_acc: 0.61
Train up to 30.
Files already

Max class: 30, Epoch 90/200, train_loss: 0.395, train_acc: 0.87, test_acc: 0.64
Max class: 30, Epoch 91/200, train_loss: 0.424, train_acc: 0.85, test_acc: 0.62
Max class: 30, Epoch 92/200, train_loss: 0.395, train_acc: 0.85, test_acc: 0.62
Max class: 30, Epoch 93/200, train_loss: 0.429, train_acc: 0.88, test_acc: 0.63
Max class: 30, Epoch 94/200, train_loss: 0.39, train_acc: 0.87, test_acc: 0.63
Max class: 30, Epoch 95/200, train_loss: 0.423, train_acc: 0.86, test_acc: 0.61
Max class: 30, Epoch 96/200, train_loss: 0.425, train_acc: 0.89, test_acc: 0.64
Max class: 30, Epoch 97/200, train_loss: 0.373, train_acc: 0.89, test_acc: 0.65
Max class: 30, Epoch 98/200, train_loss: 0.345, train_acc: 0.89, test_acc: 0.63
Max class: 30, Epoch 99/200, train_loss: 0.38, train_acc: 0.89, test_acc: 0.64
Max class: 30, Epoch 100/200, train_loss: 0.393, train_acc: 0.88, test_acc: 0.64
Max class: 30, Epoch 101/200, train_loss: 0.378, train_acc: 0.83, test_acc: 0.6
Max class: 30, Epoch 102/200, train_loss:

Max class: 30, Epoch 191/200, train_loss: 0.021, train_acc: 1.0, test_acc: 0.67
Max class: 30, Epoch 192/200, train_loss: 0.018, train_acc: 1.0, test_acc: 0.67
Max class: 30, Epoch 193/200, train_loss: 0.021, train_acc: 1.0, test_acc: 0.67
Max class: 30, Epoch 194/200, train_loss: 0.021, train_acc: 1.0, test_acc: 0.67
Max class: 30, Epoch 195/200, train_loss: 0.021, train_acc: 1.0, test_acc: 0.67
Max class: 30, Epoch 196/200, train_loss: 0.02, train_acc: 1.0, test_acc: 0.67
Max class: 30, Epoch 197/200, train_loss: 0.019, train_acc: 1.0, test_acc: 0.67
Max class: 30, Epoch 198/200, train_loss: 0.018, train_acc: 1.0, test_acc: 0.67
Max class: 30, Epoch 199/200, train_loss: 0.021, train_acc: 1.0, test_acc: 0.67
Train up to 40.
Files already downloaded and verified
Files already downloaded and verified
Max class: 40, Epoch 0/200, train_loss: 5.027, train_acc: 0.01, test_acc: 0.01
Best acc! Saving model
Max class: 40, Epoch 1/200, train_loss: 4.242, train_acc: 0.07, test_acc: 0.08
Best acc

Max class: 40, Epoch 85/200, train_loss: 0.451, train_acc: 0.88, test_acc: 0.61
Max class: 40, Epoch 86/200, train_loss: 0.413, train_acc: 0.89, test_acc: 0.62
Max class: 40, Epoch 87/200, train_loss: 0.424, train_acc: 0.88, test_acc: 0.61
Max class: 40, Epoch 88/200, train_loss: 0.411, train_acc: 0.87, test_acc: 0.61
Max class: 40, Epoch 89/200, train_loss: 0.422, train_acc: 0.85, test_acc: 0.59
Max class: 40, Epoch 90/200, train_loss: 0.382, train_acc: 0.88, test_acc: 0.6
Max class: 40, Epoch 91/200, train_loss: 0.392, train_acc: 0.89, test_acc: 0.61
Max class: 40, Epoch 92/200, train_loss: 0.402, train_acc: 0.88, test_acc: 0.6
Max class: 40, Epoch 93/200, train_loss: 0.397, train_acc: 0.89, test_acc: 0.6
Max class: 40, Epoch 94/200, train_loss: 0.375, train_acc: 0.88, test_acc: 0.6
Max class: 40, Epoch 95/200, train_loss: 0.388, train_acc: 0.87, test_acc: 0.6
Max class: 40, Epoch 96/200, train_loss: 0.405, train_acc: 0.88, test_acc: 0.61
Max class: 40, Epoch 97/200, train_loss: 0.37

Max class: 40, Epoch 185/200, train_loss: 0.017, train_acc: 1.0, test_acc: 0.65
Max class: 40, Epoch 186/200, train_loss: 0.015, train_acc: 1.0, test_acc: 0.65
Max class: 40, Epoch 187/200, train_loss: 0.015, train_acc: 1.0, test_acc: 0.65
Max class: 40, Epoch 188/200, train_loss: 0.014, train_acc: 1.0, test_acc: 0.65
Max class: 40, Epoch 189/200, train_loss: 0.015, train_acc: 1.0, test_acc: 0.65
Max class: 40, Epoch 190/200, train_loss: 0.016, train_acc: 1.0, test_acc: 0.65
Max class: 40, Epoch 191/200, train_loss: 0.014, train_acc: 1.0, test_acc: 0.65
Max class: 40, Epoch 192/200, train_loss: 0.015, train_acc: 1.0, test_acc: 0.65
Max class: 40, Epoch 193/200, train_loss: 0.015, train_acc: 1.0, test_acc: 0.65
Max class: 40, Epoch 194/200, train_loss: 0.014, train_acc: 1.0, test_acc: 0.65
Best acc! Saving model
Max class: 40, Epoch 195/200, train_loss: 0.017, train_acc: 1.0, test_acc: 0.65
Max class: 40, Epoch 196/200, train_loss: 0.015, train_acc: 1.0, test_acc: 0.65
Max class: 40, Ep

Max class: 50, Epoch 77/200, train_loss: 0.628, train_acc: 0.84, test_acc: 0.6
Max class: 50, Epoch 78/200, train_loss: 0.612, train_acc: 0.83, test_acc: 0.6
Max class: 50, Epoch 79/200, train_loss: 0.605, train_acc: 0.83, test_acc: 0.6
Max class: 50, Epoch 80/200, train_loss: 0.605, train_acc: 0.83, test_acc: 0.58
Max class: 50, Epoch 81/200, train_loss: 0.596, train_acc: 0.84, test_acc: 0.6
Max class: 50, Epoch 82/200, train_loss: 0.576, train_acc: 0.84, test_acc: 0.59
Max class: 50, Epoch 83/200, train_loss: 0.581, train_acc: 0.84, test_acc: 0.59
Max class: 50, Epoch 84/200, train_loss: 0.574, train_acc: 0.84, test_acc: 0.58
Max class: 50, Epoch 85/200, train_loss: 0.566, train_acc: 0.84, test_acc: 0.58
Max class: 50, Epoch 86/200, train_loss: 0.559, train_acc: 0.84, test_acc: 0.59
Max class: 50, Epoch 87/200, train_loss: 0.555, train_acc: 0.83, test_acc: 0.57
Max class: 50, Epoch 88/200, train_loss: 0.551, train_acc: 0.86, test_acc: 0.59
Max class: 50, Epoch 89/200, train_loss: 0.5

Max class: 50, Epoch 180/200, train_loss: 0.017, train_acc: 1.0, test_acc: 0.63
Max class: 50, Epoch 181/200, train_loss: 0.017, train_acc: 1.0, test_acc: 0.63
Max class: 50, Epoch 182/200, train_loss: 0.016, train_acc: 1.0, test_acc: 0.62
Max class: 50, Epoch 183/200, train_loss: 0.016, train_acc: 1.0, test_acc: 0.63
Max class: 50, Epoch 184/200, train_loss: 0.016, train_acc: 1.0, test_acc: 0.63
Max class: 50, Epoch 185/200, train_loss: 0.016, train_acc: 1.0, test_acc: 0.63
Max class: 50, Epoch 186/200, train_loss: 0.017, train_acc: 1.0, test_acc: 0.63
Max class: 50, Epoch 187/200, train_loss: 0.015, train_acc: 1.0, test_acc: 0.63
Max class: 50, Epoch 188/200, train_loss: 0.015, train_acc: 1.0, test_acc: 0.63
Max class: 50, Epoch 189/200, train_loss: 0.015, train_acc: 1.0, test_acc: 0.63
Max class: 50, Epoch 190/200, train_loss: 0.016, train_acc: 1.0, test_acc: 0.63
Best acc! Saving model
Max class: 50, Epoch 191/200, train_loss: 0.016, train_acc: 1.0, test_acc: 0.63
Max class: 50, Ep

Max class: 60, Epoch 74/200, train_loss: 0.636, train_acc: 0.83, test_acc: 0.61
Max class: 60, Epoch 75/200, train_loss: 0.636, train_acc: 0.82, test_acc: 0.6
Max class: 60, Epoch 76/200, train_loss: 0.62, train_acc: 0.83, test_acc: 0.6
Max class: 60, Epoch 77/200, train_loss: 0.616, train_acc: 0.81, test_acc: 0.59
Max class: 60, Epoch 78/200, train_loss: 0.59, train_acc: 0.83, test_acc: 0.6
Max class: 60, Epoch 79/200, train_loss: 0.595, train_acc: 0.81, test_acc: 0.59
Max class: 60, Epoch 80/200, train_loss: 0.587, train_acc: 0.81, test_acc: 0.58
Max class: 60, Epoch 81/200, train_loss: 0.582, train_acc: 0.81, test_acc: 0.59
Max class: 60, Epoch 82/200, train_loss: 0.58, train_acc: 0.83, test_acc: 0.6
Max class: 60, Epoch 83/200, train_loss: 0.575, train_acc: 0.83, test_acc: 0.59
Max class: 60, Epoch 84/200, train_loss: 0.556, train_acc: 0.85, test_acc: 0.59
Max class: 60, Epoch 85/200, train_loss: 0.547, train_acc: 0.83, test_acc: 0.59
Max class: 60, Epoch 86/200, train_loss: 0.552,

Max class: 60, Epoch 176/200, train_loss: 0.015, train_acc: 1.0, test_acc: 0.63
Max class: 60, Epoch 177/200, train_loss: 0.015, train_acc: 1.0, test_acc: 0.64
Best acc! Saving model
Max class: 60, Epoch 178/200, train_loss: 0.015, train_acc: 1.0, test_acc: 0.63
Max class: 60, Epoch 179/200, train_loss: 0.014, train_acc: 1.0, test_acc: 0.63
Max class: 60, Epoch 180/200, train_loss: 0.014, train_acc: 1.0, test_acc: 0.63
Max class: 60, Epoch 181/200, train_loss: 0.014, train_acc: 1.0, test_acc: 0.63
Max class: 60, Epoch 182/200, train_loss: 0.014, train_acc: 1.0, test_acc: 0.64
Max class: 60, Epoch 183/200, train_loss: 0.014, train_acc: 1.0, test_acc: 0.64
Max class: 60, Epoch 184/200, train_loss: 0.015, train_acc: 1.0, test_acc: 0.64
Max class: 60, Epoch 185/200, train_loss: 0.014, train_acc: 1.0, test_acc: 0.64
Max class: 60, Epoch 186/200, train_loss: 0.014, train_acc: 1.0, test_acc: 0.64
Max class: 60, Epoch 187/200, train_loss: 0.015, train_acc: 1.0, test_acc: 0.63
Max class: 60, Ep