In [12]:
import torch
import torchvision.transforms as TF
from cl_lite.head.dynamic_simple import DynamicSimpleHead
import cl_lite.backbone as B
import os
import numpy as np
import sys
sys.path.append('..')
from rdfcil.datamodule import DataModule

dataset = "imagenet100"
num_classes = 100
data_root = '../rdfcil/data'
class_order = [68, 56, 78, 8, 23, 84, 90, 65, 74, 76, 40, 89, 3, 92, 55, 9, 26, 80, 43, 38, 58, 70, 77, 1, 85, 19, 17, 50, 28, 53, 13, 81, 45, 82, 6, 59, 83, 16, 15, 44, 91, 41, 72, 60, 79, 52, 20, 10, 31, 54, 37, 95, 14, 71, 96, 98, 97, 2, 64, 66, 42, 22, 35, 86, 24, 34, 87, 21, 99, 0, 88, 27, 18, 94, 11, 12, 47, 25, 30, 46, 62, 69, 36, 61, 7, 63, 75, 5, 32, 4, 51, 48, 73, 93, 39, 67, 29, 49, 57, 33]
num_tasks = 20

# Convert class_order to a tensor for faster indexing
class_order_tensor = torch.tensor(class_order).cuda()

# Create a tensor of zeros with the same length as class_order
# This tensor will be used to create a mapping where the index is the class order position
mapping_tensor = torch.zeros(len(class_order), dtype=torch.long).cuda()

# Assign the new class indices (which are just the indices of class_order_tensor) to the corresponding positions in mapping_tensor
mapping_tensor[class_order_tensor] = torch.arange(len(class_order_tensor)).cuda()

# Use the mapping tensor to map the labels
# fast_mapped_labels = mapping_tensor[labels]

In [None]:
# get forgetting results
total_task_acc=[]
for t in range(1,num_tasks+1):
    print(f"Task {t}")
    
    # get the model
    if dataset.startswith("imagenet"):
        backbone = B.resnet.resnet18()
    else:
        backbone = ISCF_ResNet()

    # prefix = './ImageNet-100/imnet100_version_675_rdfcil_5task_49.44/task_{}'.format(t-1)
    # prefix = './ImageNet-100/version_508_imnet_5task_54.64/task_{}'.format(t-1)
    # prefix = './ImageNet-100/version_507_imnet_20task_32.90/task_{}/'.format(t-1)
    prefix = './ImageNet-100/version_793_imnet_20t_rdfcil/task_{}/'.format(t-1)
    # prefix = './ImageNet-100/version_451_imnet_10task_45.18/task_{}/'.format(t-1)
    # prefix = './ImageNet-100/imnet100_version_430_rdfcil_10task_40.7/task_{}/'.format(t-1)
    state_dict = torch.load(os.path.join(prefix,"checkpoints/best_acc.ckpt"))['state_dict']

    # dataload
    data_module = DataModule(root=data_root, 
                            dataset=dataset, 
                            batch_size=128, 
                            num_workers=4,
                            num_tasks=num_tasks,
                            class_order=class_order,
                            current_task=t-1,
                            )
    data_module.setup()
    head = DynamicSimpleHead(num_classes=data_module.num_classes, num_features=backbone.num_features, bias=False)
    # head = DynamicSimpleHead(num_classes=data_module.num_classes, num_features=backbone.num_features, bias=True)
    
    backbone_state= {}
    head_state = {}
    for _ in range(t-1):
        head.append(num_classes//num_tasks)
    for k,v in state_dict.items():
        if k.startswith('backbone'):
            backbone_state[k[9:]] = v
        elif k.startswith('head'):
            head_state[k[5:]] = v
            # head_state[k[17:]] = v
            
        
    backbone.load_state_dict(backbone_state)
    backbone.eval()
    head.load_state_dict(head_state)


    # train_dataloader = data_module.train_dataloader()
    val_dataloader = data_module.val_dataloader()

    backbone.cuda()
    head.cuda()

    task_correct= [0 for _ in range(t)]
    task_total = [0 for _ in range(t)]
    idx=0
    for batch in val_dataloader:
        images, labels = batch
        images = images.cuda()
        labels = labels.cuda()
        with torch.no_grad():
            output = backbone(images)
            output = head(output)
            # print(output[0],labels[0])
            labels = mapping_tensor[labels]
            for i in range(t):
                t_indices = torch.nonzero(torch.bitwise_and(num_classes//num_tasks*(i+1) >= labels, num_classes//num_tasks*(i) < labels) ).view(-1)
                # print(t_indices.view(-1))
                # task accuracy
                labels_t = labels[t_indices] # - i*num_classes//num_tasks
                output_t = output[t_indices]
                # if i==0:
                #     output_t = output_t[:,:num_classes//num_tasks*(i+1)]
                # else: output_t = output_t[:,num_classes//num_tasks*i:num_classes//num_tasks*(i+1)]
                task_correct[i] += (output_t.argmax(dim=1) == labels_t).sum().item()
                task_total[i] += len(labels_t)
        idx+=1
        print('\r idx: {}'.format(idx), end='')
    print()
    task_acc = [float(cc)/ct for cc,ct in zip(task_correct,task_total)]
    print(task_acc)

    for j in range(num_tasks-t):
        task_acc.append(0)
    total_task_acc.append(task_acc)
total_task_acc = np.array(total_task_acc)
print(total_task_acc)
result = []
for i in range(num_tasks):
    if i == 0:
        result.append(0)
    else:
        res = 0
        for j in range(i + 1):
            res += (np.max(total_task_acc[:, j]) - total_task_acc[i][j])
        res = res / i
        result.append(100 * res)

        
print('Forgetting result:')
print(result)
print(sum(result)/len(result))

Task 1
 idx: 2
[0.83]
Task 2
 idx: 4
[0.644, 0.79]
Task 3
 idx: 6
[0.548, 0.652, 0.71]
Task 4
 idx: 8
[0.328, 0.492, 0.708, 0.73]
Task 5
 idx: 10
[0.368, 0.232, 0.676, 0.42, 0.74]
Task 6
 idx: 12
[0.292, 0.328, 0.484, 0.38, 0.56, 0.775]
Task 7
 idx: 14
[0.276, 0.296, 0.56, 0.292, 0.384, 0.576, 0.635]
Task 8
 idx: 16
[0.32, 0.276, 0.264, 0.252, 0.348, 0.304, 0.588, 0.745]
Task 9
 idx: 18
[0.284, 0.184, 0.38, 0.244, 0.364, 0.348, 0.552, 0.496, 0.645]
Task 10
 idx: 20
[0.332, 0.18, 0.328, 0.232, 0.32, 0.224, 0.44, 0.296, 0.484, 0.6]
Task 11
 idx: 22
[0.172, 0.136, 0.268, 0.204, 0.296, 0.224, 0.376, 0.284, 0.416, 0.504, 0.795]
Task 12
 idx: 24
[0.22, 0.188, 0.2, 0.2, 0.292, 0.24, 0.332, 0.276, 0.428, 0.256, 0.548, 0.64]
Task 13
 idx: 26
[0.2, 0.188, 0.156, 0.144, 0.312, 0.14, 0.18, 0.172, 0.364, 0.228, 0.54, 0.584, 0.79]
Task 14
 idx: 28
[0.172, 0.156, 0.2, 0.16, 0.288, 0.212, 0.188, 0.256, 0.344, 0.204, 0.4, 0.484, 0.488, 0.735]
Task 15
 idx: 30
[0.168, 0.076, 0.124, 0.132, 0.26, 0.204, 0

In [None]:

forgetting = np.max(total_task_acc,axis=0) - np.min(total_task_acc,axis=0)
print(forgetting)
avg_forgetting = np.mean(forgetting)
print(avg_forgetting)