In [None]:
'''© 2024 Nokia
Licensed under the BSD 3-Clause Clear License
SPDX-License-Identifier: BSD-3-Clause-Clear '''
import numpy as np
import pandas as pd
import os
import sys
import torch
import seaborn as sns
import matplotlib.pyplot as plt
sns.set(style='whitegrid')

import torchvision
import torchvision.transforms as T

%matplotlib inline
%load_ext autoreload
%autoreload 2

sys.path.append('../')

from utils import tonp
import utils

# This notebook only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)
# Use the GPU when one is present; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

import matplotlib

# Figure defaults for the whole notebook: 200 dpi on a white background.
matplotlib.rcParams.update({
    "figure.dpi": 200,
    "figure.facecolor": 'white',
})

## Loading task encoder and retrieving tasks

In [None]:
# Architecture name used to select rows of the agreement-score table.
TASK_TYPE = 'resnet18'

# Batch size requested from the training dataloader.
BATCH_SIZE = 64

# TODO: Set the task ID (wandb id), checkpoint name, repo name, dataset name.
TASK_ID = 'your_task_id'
CHECKPOINT_NAME = 'your_checkpoint_name'
REPO_NAME = 'your_repository_name'
DATASET = 'your_dataset_name'

# Absolute location of the checkpoint to load. If you did not edit the
# checkpoint path in the taskdiscovery code, you may not need to edit this.
PATH = os.path.abspath(
    f"../{REPO_NAME}/{TASK_ID}/checkpoints/{CHECKPOINT_NAME}.ckpt"
)

In [None]:
from models.as_uniformity import ASUniformityTraining

# Restore the model from the checkpoint at PATH, move it to the selected
# device and switch it to evaluation mode.
model = ASUniformityTraining.load_from_checkpoint(PATH, dataset=DATASET)
model = model.to(device).eval()

In [4]:
trainloader = model.data_module.train_dataloader(batch_size=BATCH_SIZE)

# Walk the dataset ONCE and keep both fields of each item. The previous version
# iterated the dataset twice, which ran every transform twice.
# NOTE(review): item[0] is assumed to be the image tensor and item[2] the
# sample index, matching how batches are consumed further down — confirm.
pil_images, data_sample_indices = [], []
for item in trainloader.dataset:
    pil_images.append(item[0])
    data_sample_indices.append(item[2])

# Image per sample index.
df = pd.DataFrame({"PIL_image": pil_images, "data_sample_index": data_sample_indices})

# File name per sample index; `imgs` holds (path, ...) entries in dataset order.
filenames = [item[0] for item in trainloader.dataset.imgs]
df_filenames = pd.DataFrame({"filename": filenames, "data_sample_index": data_sample_indices})

# Join images and file names on the shared sample index.
merged_df = df.merge(df_filenames, on='data_sample_index')

In [5]:
# Run the model over the whole training loader and keep, per batch, the input
# images, their sample indices, the per-task logits and the binarised labels.
xs, idxs, logits, tasks = [], [], [], []
for batch in trainloader:
    images, sample_indices = batch[0], batch[2]
    batch_logits = model.logits_all_tasks(images.to(model.device)).cpu()
    xs.append(images)
    idxs.append(sample_indices)
    logits.append(batch_logits)
    # A positive logit is labelled 1, everything else 0.
    tasks.append((batch_logits > 0).long())

# Stitch the per-batch chunks together along the first (sample) dimension.
xs, idxs, logits, tasks = (torch.cat(chunks) for chunks in (xs, idxs, logits, tasks))

## Visualize top k images from each task/class

In [None]:
# Number of highest- and lowest-logit samples to retrieve per task.
# Change TOP_K to retrieve more or fewer samples.
TOP_K = 20
# topk raises when k exceeds the number of samples, so clamp it for small datasets.
top_k = min(TOP_K, logits.size(0))
# Lowest logits per task (taken as the top-k of the negated logits) ...
val_botk, idx_botk = (-logits).topk(top_k, dim=0)
# ... and highest logits per task.
val_topk, idx_topk = logits.topk(top_k, dim=0)

In [None]:
# The agreement score (AS) is calculated with the same architecture as the task-net,
# e.g. if task-net=mlp then the agreement score was calculated using mlp architectures.
# If you do not have an agreement-score table for the visualization, create one from
# your wandb directory. A dummy AS table is provided for reference.
as_table_path = f'../assets/tasks/agreement_table_{TASK_ID}.csv'

# `as_table` stays None when the file is missing or holds no row for TASK_TYPE.
as_table = None
if os.path.exists(as_table_path):
    as_scores = pd.read_csv(as_table_path)
    as_scores = as_scores[as_scores['task'] == TASK_TYPE]
    if len(as_scores) > 0:
        as_table = as_scores.drop(columns=['task']).set_index('task_idx')

# Write the filtered table next to the notebook so it can be inspected.
# Skipped when there is no table: calling to_csv on None used to crash this cell.
file_path = f'./as_table_check_{TASK_ID}.csv'
if as_table is not None:
    as_table.to_csv(file_path, index=True)
else:
    print(f"No agreement scores for task '{TASK_TYPE}' found in {as_table_path}; continuing without them.")

In [None]:
def _match_filenames(images, merged_df):
    """Return (row label, filename) for every `merged_df` row whose image equals one of `images`."""
    pairs = []
    for image in images:
        # Iterating the columns directly avoids building a Series per row (as
        # iterrows does). Every row is checked, so duplicates are all reported.
        # NOTE(review): this exact-equality scan is costly; the sample indices
        # gathered alongside `xs` may allow a direct lookup — confirm before changing.
        for label, candidate, filename in zip(
            merged_df.index, merged_df['PIL_image'], merged_df['filename']
        ):
            if torch.all(torch.eq(candidate, image)):
                pairs.append((label, filename))
    return pairs


def _export_matches(images, merged_df, out_dir, csv_name):
    """Write the filenames matching `images` to `out_dir/csv_name`; return False when nothing matched."""
    pairs = _match_filenames(images, merged_df)
    if not pairs:
        return False
    idx_list, cor_file_names = zip(*pairs)
    os.makedirs(out_dir, exist_ok=True)
    cor_file_names_df = pd.DataFrame({'idx': idx_list, 'filename': cor_file_names})
    cor_file_names_df.to_csv(f"{out_dir}/{csv_name}", index=False)
    return True


def vis_topk(idx_topk, idxbotk, renorm=True, as_table=None, merged_df=None):
    """Plot, for every task, the selected class-1 and class-0 images side by side.

    Reads the notebook globals `tasks` (for the number of tasks), `xs` (the
    images), `DATASET` and `TASK_ID` (for the output locations).

    Args:
        idx_topk: index tensor into `xs`; column `tid` selects the images shown
            in the left ("class 1") panel of task `tid`.
        idxbotk: index tensor into `xs`; column `tid` selects the images shown
            in the right ("class 0") panel of task `tid`.
        renorm: accepted for backward compatibility; currently has no effect.
            NOTE(review): images are shown exactly as stored in `xs` — confirm
            whether de-normalisation is wanted here.
        as_table: optional DataFrame indexed by task index with an "as" column.
            When given, the score is appended to the row label of every task
            present in the table.
        merged_df: optional DataFrame with "PIL_image" and "filename" columns.
            When given, the filenames of the displayed images are written to
            `{DATASET}/filenames_{TASK_ID}/taskid_{tid}_class{0,1}_filenames.csv`.
            When None, only the figure is drawn.

    Returns:
        None. If `merged_df` is given and no file matches the images of a
        panel, a message is printed and the function returns immediately,
        leaving the remaining tasks undrawn.
    """
    nb_tasks = tasks.size(1)
    xs_viz = xs
    out_dir = f"{DATASET}/filenames_{TASK_ID}"

    # squeeze=False keeps `ax` two-dimensional even when there is a single task.
    fig, ax = plt.subplots(nb_tasks, 2, figsize=(2*7, nb_tasks*2), squeeze=False)

    for tid in range(nb_tasks):
        top_rows = idx_topk[:, tid]
        # Use the argument itself; this function previously ignored `idxbotk`
        # and silently read the global `idx_botk`.
        bot_rows = idxbotk[:, tid]

        # Left panel: class 1.
        img_grid = torchvision.utils.make_grid(xs_viz[top_rows], nrow=10)
        ax[tid, 0].imshow(img_grid.permute(1, 2, 0))

        if merged_df is not None:
            exported = _export_matches(
                xs[top_rows], merged_df, out_dir, f"taskid_{tid}_class1_filenames.csv"
            )
            if not exported:
                print("Empty cor_file_index_pairs")
                return

        # Right panel: class 0.
        img_grid = torchvision.utils.make_grid(xs_viz[bot_rows], nrow=10)
        ax[tid, 1].imshow(img_grid.permute(1, 2, 0))

        ylabel = f'task {tid}'
        # The table may cover fewer tasks than the model has; label only known ones.
        if as_table is not None and tid in as_table.index:
            ylabel += f'\nAS={as_table.loc[tid]["as"]:.2f}'
        ax[tid, 0].set_ylabel(ylabel, fontsize=15)

        if merged_df is not None:
            exported = _export_matches(
                xs[bot_rows], merged_df, out_dir, f"taskid_{tid}_class0_filenames.csv"
            )
            if not exported:
                print("Empty cor_file_index_pairs")
                return

    ax[0, 0].set_title('class 1', size=20)
    ax[0, 1].set_title('class 0', fontsize=20)

    for i in range(nb_tasks):
        # The left panel keeps its axis so the y-label stays visible.
        ax[i, 0].grid(False)
        ax[i, 0].set_yticklabels([])
        ax[i, 0].set_xticklabels([])
        ax[i, 1].axis("off")

    plt.tight_layout()
    os.makedirs(DATASET, exist_ok=True)

    # Uncomment to also save the figure to disk:
    #plt.savefig(f'{DATASET}/{TASK_ID}_viz.png')

In [None]:
# Draw the class-1 / class-0 image grids for every task and export the matched filenames to CSV.
vis_topk(idx_topk, idx_botk, as_table=as_table, merged_df=merged_df)