In [1]:
import sys
sys.path.append("../..")
import numpy as np
from iirc.datasets_loader import get_lifelong_datasets
from iirc.definitions import PYTORCH, IIRC_SETUP
from iirc.utils.download_cifar import download_extract_cifar100
from __future__ import print_function
from __future__ import division

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt

from torchvision.datasets import CIFAR100
import torchvision.transforms as tt
import torch.nn.functional as F
from torchmetrics.classification import MultilabelJaccardIndex


import time
import os
import copy
print("PyTorch Version: ", torch.__version__)
print("Torchvision Version: ", torchvision.__version__)

from IIRC_CIFAR_HIERARCHY import classHierarchy
# Detect if we have a GPU available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")

PyTorch Version:  1.13.0
Torchvision Version:  0.14.0
device: cuda:0


In [2]:
download_extract_cifar100("../../data")

extracting CIFAR 100
dataset extracted




In [3]:
import torchvision.transforms as transforms

essential_transforms_fn = transforms.ToTensor()
augmentation_transforms_fn = transforms.Compose([
    transforms.RandomCrop(32,padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])

In [4]:
dataset_splits, tasks, class_names_to_idx = \
    get_lifelong_datasets(dataset_name = "iirc_cifar100",
                          dataset_root = "../../data", # the imagenet folder (where the train and val folders reside, or the parent directory of cifar-100-python folder
                          setup = IIRC_SETUP,
                          framework = PYTORCH,
                          tasks_configuration_id = 0,
                          essential_transforms_fn = essential_transforms_fn,
                          augmentation_transforms_fn = augmentation_transforms_fn,
                          joint = False
                         )

Creating iirc_cifar100
Setup used: IIRC
Using PyTorch
Dataset created


In [5]:
# print(len(tasks))
n_classes_per_task = []
for task in tasks:
    n_classes_per_task.append(len(task))
n_classes_per_task = np.array(n_classes_per_task)

In [6]:
# lifelong_datasets['train'].choose_task(2)
# print(list(zip(*lifelong_datasets['train']))[1])
for i in dataset_splits:
    print(i)


train
intask_valid
posttask_valid
test


In [37]:
# initialize a pretrained model (imageNet)
model_name = "resnet" #choosing alexnet since it is "relatively" easy to train
# model_name = "squeezenet" # changed to squeezeNet since it gets same acc as alex but smaller
num_classes = 9 # in cifar100

batch_size = 8

num_epochs = 14

feature_extract = False #set to false so we can finetune entire model

In [38]:
def train_model(model, trainloader, testloader, criterion, optimizer, num_classes, num_epochs=5 ):
    since = time.time() # including this just because
    
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        
                
        running_loss = 0.0
        running_corrects = 0

        # iterate over data
        for inputs,label1,label2 in trainloader:
            inputs = inputs.to(device)
            label1 = torch.from_numpy(np.array([class_names_to_idx[i] for i in label1]))
#             label1 = F.one_hot(label1, num_classes=num_classes)
#             label1 = label1.to(torch.float32)
            label1 = label1.to(device)
#             label2 = label2.to(device)


            #empty the gradients
            optimizer.zero_grad()

            outputs = model(inputs)
#             print(label1.dtype, outputs.dtype)
#             print(outputs,label1)
            loss = criterion(outputs, label1)
            loss.backward()
            optimizer.step()





            # statistics
            running_loss += loss.item() * inputs.size(0)
#             running_corrects += torch.sum(preds == labels.data)
                
        epoch_loss = running_loss / len(trainloader.dataset)
        print("len dataset = ",len(trainloader.dataset))
#             epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            
        print('{} Loss: {:.4f}'.format('train', epoch_loss))

        print()
        test_model(model, testloader, num_classes, mode=0)
        model = model.to(device)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    # load best model weights
    return model

In [39]:
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

In [40]:
class MultilabelClassifier(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        
        self.model_wo_fc = nn.Sequential(*(list(self.resnet.children())[:-1]))
        self.num_ftrs = self.resnet.fc.in_features
        
        self.fc = nn.Linear(self.num_ftrs, num_classes)
        
    def forward(self, x):
        x = self.model_wo_fc(x)
        x = torch.flatten(x, 1)
        x = torch.sigmoid(self.fc(x))
        return x

In [41]:
def initialize_model(num_classes):
    model_ft = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, num_classes)
    
    return model_ft

In [44]:
def test_model(model,testloader,num_classes,mode=0):
    correct = 0
    total = len(testloader.dataset)
    running_corrects = 0 
    JS = MultilabelJaccardIndex(num_labels = int(num_classes), average='weighted')
    model = model.to(torch.device("cpu"))
    with torch.no_grad():
        for i,data in enumerate(testloader):
            images, label1,label2 = data
            # since subclass labels are introduced after their corresponding superclass labels,
            # in case we encounter a subclass label, we can assume it's superclass label has already been introduced
            # or that it's superclass does not exist
#             if label1 in classHierarchy or label1 in classHierarchy.values(): #if subclass has superclass or is superclass
#                 if label1 in classHierarchy.values(): # if label is superclass label
#                     label1 = torch.from_numpy(np.array([class_names_to_idx[i] for i in label1]))
                    
#                     label1 = F.one_hot(label1, num_classes=num_classes)
#                     label1 = label1.to(torch.int32)

#                     label = label1
#                 else: # if label is subclass and has superclass
#                     label2 = label1
#                     label1 = classHierarchy[label1]
                    
#                     label1 = torch.from_numpy(np.array([class_names_to_idx[i] for i in label1]))
#                     label1 = F.one_hot(label1, num_classes=num_classes)
#                     label1 = label1.to(torch.int32)

#                     label2 = torch.from_numpy(np.array([class_names_to_idx[i] for i in label2]))
#                     label2 = F.one_hot(label2, num_classes=num_classes)
#                     label2 = label2.to(torch.int32)

#                     label = label1 + label2
                    
#             else: # subclass has no superclass
            label1 = torch.from_numpy(np.array([class_names_to_idx[i] for i in label1]))
#             label1 = F.one_hot(label1, num_classes=num_classes)
#             label1 = label1.to(torch.int32)

            label = label1

#             label = label.to(torch.int32)
            outputs = model(images) # sigmoidless activation
            _, preds = torch.max(outputs, 1)
#             print(predicted)

            correct += JS(outputs, F.one_hot(label, num_classes=num_classes).to(torch.int32))
            running_corrects += torch.sum(preds == label.data)
#             correct += (predicted == label).sum().item()
#             correct /= batch_size

#             print(preds,label)
        if mode == 0:
            print(f"In-task validation JS: {correct / total} ")
            print(f"In-task validation accuracy: {running_corrects / total} ")
            
        elif mode == 1:
            print(f"Post-task validation JS: {correct / total} ")
            print(f"Post-task validation accuracy: {running_corrects / total} ")
            
        elif mode == 2:
            print(f"Final Test JS: {correct /total} ")
            print(f"Final Test accuracy: {running_corrects / total} ")
            
            


In [50]:
# Setup 
# BCE loss for multi-label classification
# sigmoid activation after FC layer 
# everything above 0.5 is a predicted label

criterion = nn.CrossEntropyLoss() # as output is sigmoidless

# get dataset corresponding to each split
train_data = dataset_splits["train"]
intask_val_data = dataset_splits["intask_valid"]
posttask_val_data = dataset_splits["posttask_valid"]
test_data = dataset_splits["test"]

# pre-trained Model on imageNet 
resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
# resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)


seen_classes = 0
# initialize data to train on first task
for task in range(len(tasks)):
    train_data.choose_task(task)
    intask_val_data.choose_task(task)
    posttask_val_data.choose_task(task)
    
    trainloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=2)
    InTask_valloader = torch.utils.data.DataLoader(intask_val_data, batch_size=batch_size, shuffle=True, num_workers=2)
    PostTask_valloader = torch.utils.data.DataLoader(posttask_val_data, batch_size=batch_size, shuffle=True, num_workers=2)
    
    seen_classes += n_classes_per_task[task]
    
        
    new_fc = nn.Linear(2048, seen_classes)
    
    for cl in range(seen_classes-n_classes_per_task[task]):
        new_fc.weight[cl].data = resnet.fc.weight[cl].data
            
    resnet.fc = new_fc
    resnet = resnet.to(device)
    params_to_update = resnet.parameters()

    optimizer_ft = optim.SGD(params_to_update, lr=0.001, momentum=0.9)
    
    resnet = train_model(resnet, trainloader, InTask_valloader, criterion, optimizer_ft , seen_classes,num_epochs)
#     test_model(resnet, PostTask_valloader,seen_classes, mode=1)

# resnet = train_model(resnet, dataloader_dict, criterion, optimizer_ft, num_epochs=num_epochs, is_inception=(model_name=="inception"))

Epoch 1/14
len dataset =  8160
train Loss: 2.0059

In-task validation JS: 0.04680483415722847 
In-task validation accuracy: 0.3656862676143646 
Epoch 2/14
len dataset =  8160
train Loss: 1.7206

In-task validation JS: 0.05225110054016113 
In-task validation accuracy: 0.45098039507865906 
Epoch 3/14
len dataset =  8160
train Loss: 1.5141

In-task validation JS: 0.05766602233052254 
In-task validation accuracy: 0.5166666507720947 
Epoch 4/14
len dataset =  8160
train Loss: 1.3910

In-task validation JS: 0.05719464644789696 
In-task validation accuracy: 0.5460784435272217 
Epoch 5/14
len dataset =  8160
train Loss: 1.2503

In-task validation JS: 0.05770163610577583 
In-task validation accuracy: 0.5794117450714111 
Epoch 6/14
len dataset =  8160
train Loss: 1.2055

In-task validation JS: 0.0602080412209034 
In-task validation accuracy: 0.6107842922210693 
Epoch 7/14
len dataset =  8160
train Loss: 1.1261

In-task validation JS: 0.05925273895263672 
In-task validation accuracy: 0.6441176533

len dataset =  1600
train Loss: 1.3697

In-task validation JS: 0.03833705186843872 
In-task validation accuracy: 0.6949999928474426 
Epoch 2/14
len dataset =  1600
train Loss: 0.7878

In-task validation JS: 0.039363838732242584 
In-task validation accuracy: 0.7850000262260437 
Epoch 3/14
len dataset =  1600
train Loss: 0.6980

In-task validation JS: 0.03694196045398712 
In-task validation accuracy: 0.7850000262260437 
Epoch 4/14
len dataset =  1600
train Loss: 0.6040

In-task validation JS: 0.03845982253551483 
In-task validation accuracy: 0.7850000262260437 
Epoch 5/14
len dataset =  1600
train Loss: 0.5636

In-task validation JS: 0.035915177315473557 
In-task validation accuracy: 0.7950000166893005 
Epoch 6/14
len dataset =  1600
train Loss: 0.5512

In-task validation JS: 0.03760044649243355 
In-task validation accuracy: 0.7699999809265137 
Epoch 7/14
len dataset =  1600
train Loss: 0.5399

In-task validation JS: 0.03978794813156128 
In-task validation accuracy: 0.7599999904632568 
E

len dataset =  1760
train Loss: 0.9167

In-task validation JS: 0.03423295542597771 
In-task validation accuracy: 0.7318181991577148 
Epoch 3/14
len dataset =  1760
train Loss: 0.8009

In-task validation JS: 0.03792613744735718 
In-task validation accuracy: 0.7136363387107849 
Epoch 4/14
len dataset =  1760
train Loss: 0.7154

In-task validation JS: 0.04204545542597771 
In-task validation accuracy: 0.7181817889213562 
Epoch 5/14
len dataset =  1760
train Loss: 0.6099

In-task validation JS: 0.03764204680919647 
In-task validation accuracy: 0.7227272987365723 
Epoch 6/14
len dataset =  1760
train Loss: 0.6161

In-task validation JS: 0.03707386180758476 
In-task validation accuracy: 0.7727272510528564 
Epoch 7/14
len dataset =  1760
train Loss: 0.5683

In-task validation JS: 0.03863636404275894 
In-task validation accuracy: 0.7318181991577148 
Epoch 8/14
len dataset =  1760
train Loss: 0.5263

In-task validation JS: 0.037784092128276825 
In-task validation accuracy: 0.7863636612892151 
Ep

In-task validation JS: 0.04762931168079376 
In-task validation accuracy: 0.8137931227684021 
Epoch 3/14
len dataset =  2320
train Loss: 0.5466

In-task validation JS: 0.04461206868290901 
In-task validation accuracy: 0.817241370677948 
Epoch 4/14
len dataset =  2320
train Loss: 0.4864

In-task validation JS: 0.044173337519168854 
In-task validation accuracy: 0.8310344815254211 
Epoch 5/14
len dataset =  2320
train Loss: 0.5106

In-task validation JS: 0.04535098746418953 
In-task validation accuracy: 0.8379310369491577 
Epoch 6/14
len dataset =  2320
train Loss: 0.4615

In-task validation JS: 0.0470135472714901 
In-task validation accuracy: 0.8413792848587036 
Epoch 7/14
len dataset =  2320
train Loss: 0.4673

In-task validation JS: 0.04371151700615883 
In-task validation accuracy: 0.8448275923728943 
Epoch 8/14
len dataset =  2320
train Loss: 0.4092

In-task validation JS: 0.04668257385492325 
In-task validation accuracy: 0.8448275923728943 
Epoch 9/14
len dataset =  2320
train Loss: 0

len dataset =  1760
train Loss: 0.6741

In-task validation JS: 0.03693181648850441 
In-task validation accuracy: 0.7590909004211426 
Epoch 4/14
len dataset =  1760
train Loss: 0.5374

In-task validation JS: 0.038778409361839294 
In-task validation accuracy: 0.7681818008422852 
Epoch 5/14
len dataset =  1760
train Loss: 0.4982

In-task validation JS: 0.04028002917766571 
In-task validation accuracy: 0.8136363625526428 
Epoch 6/14
len dataset =  1760
train Loss: 0.4466

In-task validation JS: 0.03892045468091965 
In-task validation accuracy: 0.8045454621315002 
Epoch 7/14
len dataset =  1760
train Loss: 0.4539

In-task validation JS: 0.04005681723356247 
In-task validation accuracy: 0.7818182110786438 
Epoch 8/14
len dataset =  1760
train Loss: 0.4054

In-task validation JS: 0.038352273404598236 
In-task validation accuracy: 0.8045454621315002 
Epoch 9/14
len dataset =  1760
train Loss: 0.3937

In-task validation JS: 0.04218750074505806 
In-task validation accuracy: 0.8136363625526428 
E

len dataset =  1680
train Loss: 0.9472

In-task validation JS: 0.04211309552192688 
In-task validation accuracy: 0.6523809432983398 
Epoch 4/14
len dataset =  1680
train Loss: 0.8431

In-task validation JS: 0.04077380895614624 
In-task validation accuracy: 0.6714285612106323 
Epoch 5/14
len dataset =  1680
train Loss: 0.7515

In-task validation JS: 0.03660714253783226 
In-task validation accuracy: 0.776190459728241 
Epoch 6/14
len dataset =  1680
train Loss: 0.7309

In-task validation JS: 0.03854166716337204 
In-task validation accuracy: 0.723809540271759 
Epoch 7/14
len dataset =  1680
train Loss: 0.7103

In-task validation JS: 0.0424107126891613 
In-task validation accuracy: 0.699999988079071 
Epoch 8/14
len dataset =  1680
train Loss: 0.6369

In-task validation JS: 0.03883928433060646 
In-task validation accuracy: 0.761904776096344 
Epoch 9/14
len dataset =  1680
train Loss: 0.6273

In-task validation JS: 0.0431547611951828 
In-task validation accuracy: 0.6904761791229248 
Epoch 10/

In [None]:
for task in range(len(tasks)):
    test_data.choose_task(task)
    testloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True, num_workers=2)
    
    test_model(resnet, testloader,seen_classes, mode=2)
    

Final Test JS: 0.005005310755223036 
Final Test accuracy: 0.0 
Final Test JS: 0.012279164977371693 
Final Test accuracy: 0.0 
Final Test JS: 0.0077430554665625095 
Final Test accuracy: 0.0 
Final Test JS: 0.005166666582226753 
Final Test accuracy: 0.0 
Final Test JS: 0.0077416663989424706 
Final Test accuracy: 0.0 
Final Test JS: 0.005958333145827055 
Final Test accuracy: 0.0 
Final Test JS: 0.005324999336153269 
Final Test accuracy: 0.0 
Final Test JS: 0.008962499909102917 
Final Test accuracy: 0.0 
Final Test JS: 0.005108333658427 
Final Test accuracy: 0.0 
Final Test JS: 0.003666666802018881 
Final Test accuracy: 0.0 
Final Test JS: 0.007703703828155994 
Final Test accuracy: 0.0 
Final Test JS: 0.00904999952763319 
Final Test accuracy: 0.0 
Final Test JS: 0.011649850755929947 
Final Test accuracy: 0.0 


In [None]:
PATH = "models/resnet_IIRC.pth"
torch.save(resnet.state_dict(), PATH)

In [47]:
for i in PostTask_valloader:
    print([class_names_to_idx[j] for j in i[1]])

0
[7, 7, 1, 7, 4, 8, 0, 0]
[6, 7, 3, 0, 0, 1, 0, 9]
[6, 3, 3, 2, 7, 5, 3, 1]
[18, 9, 8, 0, 7, 6, 6, 5]
[7, 7, 1, 0, 7, 5, 8, 5]
[12, 4, 7, 6, 7, 8, 0, 0]
[8, 2, 6, 12, 7, 7, 4, 2]
[6, 3, 8, 9, 8, 7, 8, 6]
[6, 5, 0, 9, 7, 4, 4, 7]
[0, 9, 7, 12, 1, 12, 5, 2]
[3, 2, 1, 3, 8, 0, 4, 9]
[5, 3, 7, 7, 7, 7, 9, 8]
[4, 3, 2, 4, 6, 2, 4, 1]
[6, 1, 12, 0, 9, 4, 12, 2]
[3, 5, 4, 7, 1, 8, 5, 8]
[3, 1, 4, 9, 8, 9, 4, 18]
[7, 3, 6, 9, 18, 1, 5, 9]
[3, 1, 3, 9, 7, 8, 2, 1]
[9, 9, 9, 6, 4, 4, 7, 18]
[2, 8, 4, 7, 18, 2, 9, 9]
[6, 6, 18, 6, 9, 7, 8, 3]
[0, 2, 2, 6, 7, 1, 2, 7]
[3, 5, 0, 0, 7, 3, 0, 3]
[9, 9, 7, 7, 9, 9, 7, 5]
[0, 8, 9, 5, 1, 9, 9, 12]
[18, 0, 5, 7, 5, 4, 7, 7]
[1, 4, 8, 5, 1, 0, 6, 9]
[3, 1, 6, 9, 12, 8, 3, 2]
[3, 9, 2, 0, 2, 8, 7, 18]
[2, 6, 3, 3, 1, 3, 8, 9]
[12, 7, 5, 0, 7, 0, 6, 8]
[4, 6, 2, 1, 1, 2, 7, 9]
[9, 8, 3, 3, 5, 8, 5, 7]
[1, 2, 3, 0, 6, 9, 4, 7]
[6, 0, 1, 4, 5, 1, 4, 2]
[7, 1, 2, 6, 1, 7, 6, 6]
[4, 0, 0, 3, 7, 3, 2, 6]
[8, 4, 4, 7, 6, 12, 0, 0]
[9, 1, 6, 6, 0, 0, 9, 4]
[6, 4