In [8]:
import sys
import argparse
import datetime
import random
import numpy as np
import time
import torch
import pandas as pd
import os
import sys 
sys.path.append('..')
from pathlib import Path
import models
from timm.models import create_model
from datasets import build_continual_dataloader

# pandas display max column 30
pd.set_option('display.max_columns', 30)


In [9]:
fine_labels = [
    'apple',  # id 0
    'aquarium_fish',
    'baby',
    'bear',
    'beaver',
    'bed',
    'bee',
    'beetle',
    'bicycle',
    'bottle',
    'bowl',
    'boy',
    'bridge',
    'bus',
    'butterfly',
    'camel',
    'can',
    'castle',
    'caterpillar',
    'cattle',
    'chair',
    'chimpanzee',
    'clock',
    'cloud',
    'cockroach',
    'couch',
    'crab',
    'crocodile',
    'cup',
    'dinosaur',
    'dolphin',
    'elephant',
    'flatfish',
    'forest',
    'fox',
    'girl',
    'hamster',
    'house',
    'kangaroo',
    'computer_keyboard',
    'lamp',
    'lawn_mower',
    'leopard',
    'lion',
    'lizard',
    'lobster',
    'man',
    'maple_tree',
    'motorcycle',
    'mountain',
    'mouse',
    'mushroom',
    'oak_tree',
    'orange',
    'orchid',
    'otter',
    'palm_tree',
    'pear',
    'pickup_truck',
    'pine_tree',
    'plain',
    'plate',
    'poppy',
    'porcupine',
    'possum',
    'rabbit',
    'raccoon',
    'ray',
    'road',
    'rocket',
    'rose',
    'sea',
    'seal',
    'shark',
    'shrew',
    'skunk',
    'skyscraper',
    'snail',
    'snake',
    'spider',
    'squirrel',
    'streetcar',
    'sunflower',
    'sweet_pepper',
    'table',
    'tank',
    'telephone',
    'television',
    'tiger',
    'tractor',
    'train',
    'trout',
    'tulip',
    'turtle',
    'wardrobe',
    'whale',
    'willow_tree',
    'wolf',
    'woman',
    'worm',
]

fine_label_map = {i: label for i, label in enumerate(fine_labels)}

In [10]:
fine_label_map[0]

'apple'

In [11]:
# pre-load the original_model

sample_checkpoint = torch.load('../output_preserve/checkpoint/task3_checkpoint.pth')
sample_args = sample_checkpoint['args']

original_model = create_model(
    sample_args.model,
    pretrained=sample_args.pretrained,
    num_classes=sample_args.nb_classes,
    drop_rate=sample_args.drop,
    drop_path_rate=sample_args.drop_path,
    drop_block_rate=None,
)


model = create_model(
    sample_args.model,
    pretrained=sample_args.pretrained,
    num_classes=sample_args.nb_classes,
    drop_rate=sample_args.drop,
    drop_path_rate=sample_args.drop_path,
    drop_block_rate=None,
    prompt_length=sample_args.length,
    embedding_key=sample_args.embedding_key,
    prompt_init=sample_args.prompt_key_init,
    prompt_pool=sample_args.prompt_pool,
    prompt_key=sample_args.prompt_key,
    pool_size=sample_args.size,
    top_k=sample_args.top_k,
    batchwise_prompt=sample_args.batchwise_prompt,
    prompt_key_init=sample_args.prompt_key_init,
    head_type=sample_args.head_type,
    use_prompt_mask=sample_args.use_prompt_mask,
    )

# change data-path
sample_args.data_path = '../local_datasets/'
data_loader, class_mask = build_continual_dataloader(sample_args)
device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')

model = model.to(device)
original_model = original_model.to(device)


Files already downloaded and verified
Files already downloaded and verified


In [12]:


with torch.no_grad():
    for model_task_id in range(10):
        checkpoint = torch.load(f'../output_preserve/checkpoint/task{model_task_id+1}_checkpoint.pth')
        print(f"Loaded checkpoint for task {model_task_id+1}")
        args = checkpoint['args']
        model.load_state_dict(checkpoint['model'])
        model_task_outputs = []

        for eval_task_id, task_specific_loader in enumerate(data_loader):
            print(f"Processing task {eval_task_id+1}")

            for input, target in task_specific_loader['val']:
                
                input = input.to(device, non_blocking=True)
                target = target.to(device, non_blocking=True)

                # Get Query Embeddings
                output = original_model(input)
                cls_features = output['pre_logits']

                # Get Prompt Embeddings
                output = model(input, task_id=eval_task_id, cls_features=cls_features)
                logits = output['logits']

                pred_label = logits.argmax(dim=1)
                top_2_values, top_2_indices = logits.topk(2, dim=1)

                top_k_num = output['prompt_idx'].shape[1]
                total_prompt_num = output['similarity'].shape[1]


                # Get Prompt Info
                prompt_idx_df = pd.DataFrame(output['prompt_idx'].cpu().numpy(), columns=[f'selected_prompt_idx_{i}' for i in range(top_k_num)])
                similarity_df = pd.DataFrame(output['similarity'].detach().cpu().numpy(), columns=[f'similarity_{i}' for i in range(total_prompt_num)])

                # Get Pred Info
                pred_df = pd.DataFrame(pred_label.cpu().numpy(), columns=['y_pred'])
                pred_df['y_pred_str'] = pred_df['y_pred'].apply(lambda x: fine_label_map[x])
                pred_df['correct'] = pred_df['y_pred'] == target.cpu().numpy()
                top_2_df = pd.DataFrame(top_2_indices.cpu().numpy(), columns=['top_1_pred', 'top_2_pred'])
                top_2_values_df = pd.DataFrame(top_2_values.cpu().numpy(), columns=['top_1_value', 'top_2_value'])


                batch_prompt_info_df = pd.DataFrame(target.cpu().numpy(), columns=['y_true'])
                batch_prompt_info_df['y_true_str'] = batch_prompt_info_df['y_true'].apply(lambda x: fine_label_map[x])
                batch_prompt_info_df['current_task'] = model_task_id + 1
                batch_prompt_info_df['eval_task'] = eval_task_id + 1
                batch_prompt_info_df['num_prompts'] = total_prompt_num 
                batch_prompt_info_df['top_k_num'] = top_k_num
                column_order = ['current_task', 'eval_task', 'y_true','y_pred', 'correct', 'y_pred_str', 'y_true_str','top_1_pred', 'top_2_pred',
                                'top_1_value', 'top_2_value',
                                'selected_prompt_idx_0', 'selected_prompt_idx_1',
                                'selected_prompt_idx_2', 'selected_prompt_idx_3',
                                'selected_prompt_idx_4', 'similarity_0', 'similarity_1', 'similarity_2',
                                'similarity_3', 'similarity_4', 'similarity_5', 'similarity_6',
                                'similarity_7', 'similarity_8', 'similarity_9', 'num_prompts', 'top_k_num']

                batch_total_df = pd.concat([batch_prompt_info_df, prompt_idx_df, similarity_df, pred_df, top_2_df, top_2_values_df], axis=1)
                batch_total_df = batch_total_df[column_order]
                model_task_outputs.append(batch_total_df)
            
            if eval_task_id == model_task_id:
                # The CL model has not seen the task beyond this point
                model_task_outputs = pd.concat(model_task_outputs, ignore_index=True)
                # Save the model task outputs
                model_task_outputs.to_csv(f"prompt_analysis/cifar100_task{str(model_task_id+1).zfill(2)}_top{str(top_k_num).zfill(2)}_prompt{str(total_prompt_num).zfill(2)}.csv", index=False)
                # save as pickle
                model_task_outputs.to_pickle(f"prompt_analysis/cifar100_task{str(model_task_id+1).zfill(2)}_top{str(top_k_num).zfill(2)}_prompt{str(total_prompt_num).zfill(2)}.pkl")
                break




Loaded checkpoint for task 1
Processing task 1
Loaded checkpoint for task 2
Processing task 1
Processing task 2
Loaded checkpoint for task 3
Processing task 1
Processing task 2
Processing task 3
Loaded checkpoint for task 4
Processing task 1
Processing task 2
Processing task 3
Processing task 4
Loaded checkpoint for task 5
Processing task 1
Processing task 2
Processing task 3
Processing task 4
Processing task 5
Loaded checkpoint for task 6
Processing task 1
Processing task 2
Processing task 3
Processing task 4
Processing task 5
Processing task 6
Loaded checkpoint for task 7
Processing task 1
Processing task 2
Processing task 3
Processing task 4
Processing task 5
Processing task 6
Processing task 7
Loaded checkpoint for task 8
Processing task 1
Processing task 2
Processing task 3
Processing task 4
Processing task 5
Processing task 6
Processing task 7
Processing task 8
Loaded checkpoint for task 9
Processing task 1
Processing task 2
Processing task 3
Processing task 4
Processing task 5
P

In [13]:
model_task_outputs

Unnamed: 0,current_task,eval_task,y_true,y_pred,correct,y_pred_str,y_true_str,top_1_pred,top_2_pred,top_1_value,top_2_value,selected_prompt_idx_0,selected_prompt_idx_1,selected_prompt_idx_2,selected_prompt_idx_3,selected_prompt_idx_4,similarity_0,similarity_1,similarity_2,similarity_3,similarity_4,similarity_5,similarity_6,similarity_7,similarity_8,similarity_9,num_prompts,top_k_num
0,10,1,0,0,True,apple,apple,0,57,8.334985,4.439066,1,6,3,0,9,0.168496,0.170098,0.109390,0.168889,0.110622,0.112097,0.169646,0.108429,0.074886,0.166834,10,5
1,10,1,8,8,True,bicycle,bicycle,8,48,8.052672,2.769730,8,7,2,4,5,0.167451,0.168473,0.191274,0.167684,0.188641,0.184920,0.167878,0.194242,0.230847,0.163738,10,5
2,10,1,8,8,True,bicycle,bicycle,8,48,7.875314,2.391398,1,6,3,0,9,0.151491,0.158051,0.072830,0.152704,0.071429,0.070072,0.155395,0.073359,0.073442,0.139758,10,5
3,10,1,4,4,True,beaver,beaver,4,55,8.793460,5.722698,7,2,4,5,9,0.209097,0.208866,0.256777,0.207719,0.252587,0.247815,0.208274,0.256845,0.217684,0.225385,10,5
4,10,1,6,6,True,bee,bee,6,14,8.855971,2.801218,0,3,9,6,1,0.198590,0.196920,0.181595,0.198302,0.182610,0.184047,0.197970,0.180869,0.173747,0.198250,10,5
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,10,10,97,97,True,wolf,wolf,97,38,4.919327,4.084279,7,2,8,4,5,0.203094,0.203449,0.349487,0.201402,0.336886,0.327498,0.203448,0.353423,0.347133,0.227263,10,5
9996,10,10,95,95,True,whale,whale,95,30,6.782552,5.408639,7,2,4,5,9,0.271061,0.266677,0.463722,0.269426,0.455091,0.446660,0.268104,0.465440,0.248067,0.349964,10,5
9997,10,10,97,97,True,wolf,wolf,97,42,7.388253,2.740940,7,2,4,5,9,0.125025,0.123019,0.214582,0.123688,0.210589,0.206101,0.123204,0.216436,0.122216,0.159413,10,5
9998,10,10,96,47,False,maple_tree,willow_tree,47,52,6.212194,5.442663,8,7,2,4,5,0.205754,0.206914,0.414306,0.205703,0.398494,0.385857,0.206960,0.424383,0.573680,0.207531,10,5


In [14]:
model_task_outputs.columns

Index(['current_task', 'eval_task', 'y_true', 'y_pred', 'correct',
       'y_pred_str', 'y_true_str', 'top_1_pred', 'top_2_pred', 'top_1_value',
       'top_2_value', 'selected_prompt_idx_0', 'selected_prompt_idx_1',
       'selected_prompt_idx_2', 'selected_prompt_idx_3',
       'selected_prompt_idx_4', 'similarity_0', 'similarity_1', 'similarity_2',
       'similarity_3', 'similarity_4', 'similarity_5', 'similarity_6',
       'similarity_7', 'similarity_8', 'similarity_9', 'num_prompts',
       'top_k_num'],
      dtype='object')