In [7]:
import numpy as np
import torch
from torch import nn

import pickle
import yaml
from easydict import EasyDict as edict
from warnings import filterwarnings
# Suppress every warning for the rest of the kernel session so that cell
# outputs stay readable. NOTE(review): this also hides deprecation warnings.
filterwarnings('ignore')

from source.utils import set_random_seed, get_data
# Fixed seed of 42 — presumably seeds the relevant RNGs for reproducibility;
# confirm against source.utils.set_random_seed.
set_random_seed(42)


In [8]:
# Read the pickled base arguments and the two YAML override files.
# Each file is opened in a context manager so its handle is closed as soon as
# the content has been parsed (the previous version left all three open).
# NOTE(review): pickle.load can execute arbitrary code; only load an args.pkl
# that comes from a trusted source.
with open('./args/args.pkl','rb') as f:
    args = pickle.load(f)
with open('./args/cifar_dytox.yaml', 'r', encoding='utf-8') as f1:
    args1 = yaml.safe_load(f1)
with open('./args/cifar100_order1.yaml', 'r', encoding='utf-8') as f2:
    args2 = yaml.safe_load(f2)

In [9]:
# Layer the two YAML configs over the pickled defaults. Later updates win on
# duplicate keys, so args2 takes precedence over args1.
for overrides in (args1, args2):
    args.update(overrides)
args = edict(args)  # enables the attribute-style access used below

# Notebook-specific settings.
args.data_set = 'CIFAR'
args.data_path = '/home/choiyj/pycil/data'  # NOTE(review): machine-specific absolute path
args.output_basedir = ''
args.distributed = False

# Class-incremental schedule.
args.initial_increment = 0
args.increment = 10
incremental_classes = 10  # NOTE(review): same value as args.increment; keep the two in sync

# Number of incremental steps over the hard-coded total of 100 classes
# (presumably CIFAR-100). A non-zero initial increment counts as one extra step.
total_step = (100 - args.initial_increment) // incremental_classes
if args.initial_increment != 0:
    total_step += 1

In [10]:
from source.datasets import build_dataset

# Training scenario; the second return value is stored on args as nb_classes.
scenario_train, args.nb_classes = build_dataset(args=args, is_train=True)
# Validation scenario; the second return value is discarded here.
scenario_val, _ = build_dataset(args=args, is_train=False)

Files already downloaded and verified
Files already downloaded and verified


In [38]:
from source.ViT_exp import ViT_clf

# `initial` and `total_steps` select which checkpoint family is read from
# ./weights (they are interpolated into the file name below).
initial = 0
total_steps = 10

# Prefer the first GPU, but fall back to the CPU so the notebook still runs on
# a machine without CUDA instead of failing at model.to(device).
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

model = ViT_clf(num_classes=10, img_size=32, patch_size=4, num_patches=64,
                in_chans=3, embed_dim=384, depth=6, num_heads=args.num_heads, mlp_ratio=4.0,
                qkv_bias=False, qk_scale=False, drop_rate=0., attn_drop_rate=0.0, drop_path=args.drop_path, norm_layer=nn.LayerNorm,
                attention_type='GPSA')
model.to(device)

# One saved state per step, mapped onto `device` at load time. The number of
# checkpoints is driven by `total_steps` so it cannot drift from the file name.
states = [torch.load(f'./weights/state_{i}_{initial}-{total_steps}steps.pt', map_location=device) for i in range(total_steps)]
# The 'task_tokens' entry of each saved state, indexed by step.
tokens = [state['task_tokens'] for state in states]
# Element [1] of get_data's result for each step — presumably the validation
# loader (the name suggests so); confirm against source.utils.get_data.
val_loaders = [get_data(i, args=args, scenario_train=scenario_train, scenario_val=scenario_val)[1] for i in range(total_steps)]

In [39]:
from source.ViT_exp import inference
from sklearn.metrics import accuracy_score
import pandas as pd

In [None]:
# NaN-initialised 10x10 result tables. df_inner is indexed as
# [loaded checkpoint, evaluated task]; df_outer is allocated but not filled
# in this cell.
df_outer = pd.DataFrame(np.full((10, 10), np.nan))
df_inner = pd.DataFrame(np.full((10, 10), np.nan))

# Build a fresh model for this cell: the loop below modifies it
# (set_teacher_task_token / classifier_expand), so an already-modified
# instance must not be reused on a re-run.
model = ViT_clf(
    num_classes=10, img_size=32, patch_size=4, num_patches=64,
    in_chans=3, embed_dim=384, depth=6, num_heads=args.num_heads, mlp_ratio=4.0,
    qkv_bias=False, qk_scale=False, drop_rate=0., attn_drop_rate=0.0,
    drop_path=args.drop_path, norm_layer=nn.LayerNorm,
    attention_type='GPSA',
)
model.to(device)

for task_id in range(10):
    print(f'model{task_id}')
    model.load_state_dict(states[task_id])

    # Evaluate the checkpoint on every task up to and including task_id.
    acc_inners = []
    seen_loaders = val_loaders[:task_id+1]
    for task, loader in enumerate(seen_loaders):
        logits, targets = inference(model, loader, task, tokens[task], device)

        # Shift the argmax index by the task's offset (task*10) before
        # comparing with targets — presumably logits cover only this task's
        # 10 classes while targets are global class ids; confirm in inference().
        predictions = logits.argmax(1) + task*10
        acc_inner = accuracy_score(targets, predictions)
        acc_inners.append(acc_inner)
        df_inner.iloc[task_id, task] = acc_inner
    print(np.mean(acc_inners))

    # Prepare the model for the next checkpoint: the teacher task token is set
    # once, after the first checkpoint, and the classifier is expanded by 10
    # after every checkpoint.
    if task_id == 0:
        model.set_teacher_task_token(nn.Parameter(torch.zeros(1, 1, 384)))
    model.classifier_expand(10)

In [41]:
# Append the per-row mean across columns (NaN cells are skipped by pandas'
# default) as an 'average' column, then display the table.
df_inner['average'] = df_inner.mean(axis='columns')
df_inner

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,average
0,0.926,,,,,,,,,,0.926
1,0.921,0.879,,,,,,,,,0.9
2,0.923,0.87,0.859,,,,,,,,0.884
3,0.926,0.87,0.852,0.856,,,,,,,0.876
4,0.923,0.867,0.849,0.866,0.885,,,,,,0.878
5,0.926,0.866,0.85,0.865,0.882,0.92,,,,,0.884833
6,0.925,0.863,0.85,0.86,0.884,0.915,0.819,,,,0.873714
7,0.919,0.856,0.841,0.853,0.882,0.918,0.822,0.905,,,0.8745
8,0.919,0.858,0.845,0.852,0.875,0.913,0.817,0.902,0.892,,0.874778
9,0.92,0.847,0.841,0.838,0.874,0.907,0.825,0.895,0.891,0.853,0.8691
