## Overview

This notebook contains the code for visualizing model results to support analysis and interpretability.

### Plotting Accuracy Drops Across Tasks

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import ast
import os

def viz_acc_drops(accuracy_file_path, num_total_classes, num_classes_per_task, dataset):
    """Plot a heatmap of per-class accuracy after each task.

    The first line of the file is treated as a header and skipped. Every
    following non-blank line must have the form
    ``<task_id>,<overall_accuracy>,<list literal of per-class accuracies>``.
    Only the task id and the per-class list are used.

    Args:
        accuracy_file_path: Path of the text file to read.
        num_total_classes: Number of heatmap columns. Rows that report more
            accuracies than this are truncated with a warning; columns a row
            does not report stay NaN.
        num_classes_per_task: Spacing (in classes) of the dashed vertical
            separator lines.
        dataset: Name appended to the plot title.

    Returns:
        None. The figure is shown with ``plt.show()``. If no data rows are
        found, a message is printed and the function returns without plotting.

    Raises:
        ValueError: If a data line has fewer than three comma-separated
            fields, or its task id / accuracies cannot be converted.
        TypeError: If the per-class field does not evaluate to a list.
    """
    parsed_data = {}
    with open(accuracy_file_path, 'r') as f:
        f.readline()  # header line, not data

        for line_num, line in enumerate(f, start=2):
            line = line.strip()
            if not line:
                # Blank (typically trailing) lines carry no data.
                continue
            # Split on the first two commas only: the third field is a list
            # literal that contains commas of its own.
            fields = line.split(',', 2)
            if len(fields) != 3:
                raise ValueError(
                    f"Line {line_num} of '{accuracy_file_path}' does not have three comma-separated fields."
                )
            task_id = int(fields[0])
            per_class_acc_list = ast.literal_eval(fields[2].strip())
            if not isinstance(per_class_acc_list, list):
                raise TypeError(f"Parsed data for 'Per-Class Accuracy' is not a list (type: {type(per_class_acc_list)}).")
            parsed_data[task_id] = [float(acc) for acc in per_class_acc_list]

    num_tasks_found = len(parsed_data)
    if num_tasks_found == 0:
        print("No valid task data parsed. Exiting.")
        # Return instead of calling exit(): exit() would terminate the
        # interpreter / notebook kernel rather than just this call.
        return

    # Matrix with rows = evaluation points (tasks) and cols = classes.
    # NaN marks classes a row does not report.
    acc_matrix = np.full((num_tasks_found, num_total_classes), np.nan)
    sorted_task_ids = sorted(parsed_data)

    for task_idx, task_id in enumerate(sorted_task_ids):
        accuracies = parsed_data[task_id]
        num_classes_in_task_eval = len(accuracies)
        if num_classes_in_task_eval > num_total_classes:
            print(f"Warning: Task {task_id} reported {num_classes_in_task_eval} accuracies, exceeding total classes {num_total_classes}. Truncating.")
            num_classes_in_task_eval = num_total_classes
            accuracies = accuracies[:num_total_classes]
        acc_matrix[task_idx, :num_classes_in_task_eval] = accuracies

    # Plot the heatmap; the figure grows with the number of rows.
    plt.figure(figsize=(20, max(5, num_tasks_found * 0.7)))

    sns.heatmap(
        acc_matrix,
        annot=False,
        fmt=".2f",
        cmap="viridis",
        linewidths=0.2,
        linecolor='lightgrey',
        cbar_kws={'label': 'Per-Class Accuracy'},
        vmin=0.0,
        vmax=1.0
    )

    plt.xlabel("Class ID")
    plt.ylabel("Evaluation Point (After Task X Completed)")
    plt.title(f"Per-Class Accuracy After Each Task - {dataset}", fontsize=16)
    # Heatmap cells are centred at integer + 0.5, hence the offset.
    class_ticks = np.arange(0, num_total_classes, 5)
    plt.xticks(ticks=class_ticks + 0.5, labels=class_ticks, rotation=90, fontsize=8)
    plt.yticks(ticks=np.arange(num_tasks_found) + 0.5, labels=[f"After Task {t}" for t in sorted_task_ids], rotation=0)

    # Dashed separators every `num_classes_per_task` classes.
    for i in range(num_classes_per_task, num_total_classes, num_classes_per_task):
        plt.axvline(x=i, color='white', linestyle='--', linewidth=1.0)

    plt.tight_layout()
    plt.show()

In [None]:
# DTD texture dataset: 47 classes, separator lines drawn every 10 classes.
accuracy_file_path, num_total_classes, num_classes_per_task = 'accuracies-DTD.txt', 47, 10

viz_acc_drops(
    accuracy_file_path=accuracy_file_path,
    num_total_classes=num_total_classes,
    num_classes_per_task=num_classes_per_task,
    dataset='DTD',
)

### Plotting Relative Forgetting Measure

The Average Relative Forgetting Measure is the proportional drop in each task's accuracy from its peak, averaged over all previously seen tasks.

In [None]:
# using relative forgetting
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import ast

def viz_relative_forgetting_measure(filepath, dataset):
    """
    Calculates and visualizes the average relative forgetting.

    Each line of ``filepath`` is parsed as
    ``<task>,<overall accuracy>,<list literal of per-class accuracies>``.
    Lines with fewer than three comma-separated fields (including blank
    lines) are ignored, and a leading header line whose first field is not
    an integer is skipped. Rows are ordered by task id.

    The classes of task i are the per-class entries that first appear in
    row i (everything beyond the length of row i-1). The "peak" accuracy of
    task i is the mean accuracy of those classes in row i itself, i.e.
    Acc(i, i), measured right after task i was learned.

    Relative forgetting for a previous task i at row t is
    ``max(0, (peak_i - acc_i_at_t) / peak_i)``; it is 0 when ``peak_i`` is
    not greater than 1e-6. The average relative forgetting at row t is the
    mean of this value over all rows i < t.

    Args:
        filepath: Path of the accuracy text file.
        dataset: Name appended to the plot title.

    Returns:
        None. Prints the intermediate results and shows a line plot. If no
        rows can be parsed, prints a message and returns without plotting.
    """
    # (task id, per-class accuracies) per parsed line
    records = []
    header_allowed = True
    with open(filepath, 'r') as f:
        for line in f:
            parts = line.strip().split(',', 2)
            if len(parts) != 3:
                continue
            is_first_row, header_allowed = header_allowed, False
            try:
                task = int(parts[0])
            except ValueError:
                if is_first_row:
                    # Column-title line, as skipped by viz_acc_drops.
                    continue
                raise
            per_class_list = ast.literal_eval(parts[2].strip())
            records.append((task, [float(acc) for acc in per_class_list]))

    if not records:
        print("No valid task data parsed.")
        return

    records.sort(key=lambda rec: rec[0])
    num_tasks = len(records)

    # find classes per task: whatever row i reports beyond row i-1
    classes_per_task = []
    class_indices_per_task = []
    cumulative_classes = 0
    for _, accuracies in records:
        current_total_classes = len(accuracies)
        classes_per_task.append(current_total_classes - cumulative_classes)
        class_indices_per_task.append(range(cumulative_classes, current_total_classes))
        cumulative_classes = current_total_classes

    print(f"Detected {num_tasks} tasks.")
    print(f"Classes introduced per task: {classes_per_task}")

    # peak acc per task: mean over the task's own classes in its own row
    peak_task_accuracies = {}
    for i, (_, accuracies) in enumerate(records):
        own = [accuracies[idx] for idx in class_indices_per_task[i]]
        peak_task_accuracies[i] = float(np.mean(own)) if own else 0.0

    print(f"Peak average accuracy per task (Acc(i,i)): {peak_task_accuracies}")

    # get average relative forgetting for every row after the first
    average_relative_forgetting = []
    evaluation_points = []  # task ids after which forgetting is measured
    for t in range(1, num_tasks):
        task_id_at_t, accuracies_at_eval_t = records[t]

        forgetting_per_task = []
        for i in range(t):
            # accuracies of task i's classes measured at row t
            task_i_accuracies_at_t = [accuracies_at_eval_t[idx] for idx in class_indices_per_task[i]]
            current_acc_i_at_t = float(np.mean(task_i_accuracies_at_t)) if task_i_accuracies_at_t else 0.0
            peak_acc_i = peak_task_accuracies[i]

            if peak_acc_i > 1e-6:
                # Clamp at 0 so a task that improved does not count as
                # negative forgetting (the documented contract).
                forgetting_per_task.append(max(0.0, (peak_acc_i - current_acc_i_at_t) / peak_acc_i))
            else:
                forgetting_per_task.append(0.0)

        average_relative_forgetting.append(sum(forgetting_per_task) / len(forgetting_per_task))
        # Use the task id from the file so the axis is right for both
        # 0-based and 1-based task numbering.
        evaluation_points.append(task_id_at_t)

    print(f"\nAverage Relative Forgetting after tasks {evaluation_points}: {average_relative_forgetting}")

    # plotting
    plt.figure(figsize=(8, 5))
    plt.plot(evaluation_points, average_relative_forgetting, marker='o', linestyle='-')
    plt.xlabel("Task Number Completed (Evaluation Point)")
    plt.ylabel("Average Relative Forgetting")
    plt.title(f"Average Relative Forgetting vs. Number of Tasks Completed - {dataset}")
    plt.xticks(evaluation_points)
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.ylim(bottom=0, top=1.05)
    plt.show()

### Confusion Matrix Plotting

In [None]:
# Test split of DTD, downloaded into `data_root` if it is not already there.
# NOTE(review): `datasets`, `data_root` and `transform` are not defined in the
# cells visible here - presumably set up in an earlier session; confirm.
test_dataset = datasets.DTD(root=data_root, split='test', transform=transform, download=True)

# test_dataset = FGVCAircraft(
#     root='./data',
#     split='test',
#     annotation_level=granularity,
#     transform=transform,
#     download=True
# )

In [None]:
import numpy as np
import torch
import numpy as np
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset
from torchvision import models
import torch.optim as optim

# seeds
# `random` and `cudnn` are used below but were never imported in this
# notebook; import them here so the cell runs on a fresh kernel.
import random

import torch.backends.cudnn as cudnn

seed = 88
# Seed every RNG in play: Python, NumPy, and PyTorch (CPU + all CUDA devices).
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Ask cuDNN for deterministic kernels and disable its autotuner.
cudnn.deterministic = True
cudnn.benchmark = False
# Ask PyTorch to use deterministic algorithms throughout.
torch.use_deterministic_algorithms(True)
# Dedicated generator, handed to the DataLoader further down.
g = torch.Generator()
g.manual_seed(seed)

def seed_worker(worker_id):
    """Seed Python's and NumPy's RNGs from the current torch seed.

    The torch seed is reduced modulo 2**32 before it is handed to
    ``random`` and ``numpy``. ``worker_id`` is accepted but not used.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    random.seed(derived_seed)
    np.random.seed(derived_seed)

In [None]:
import torch.nn as nn
def modify_resnet_head(model, num_classes):
    """
    Replace ``model.fc`` with a linear head that outputs ``num_classes`` values.

    The new head is created on the same device and with the same dtype as the
    old one, so the function also works on machines without CUDA. When the old
    head has no more outputs than the new one, its weights (and bias, if any)
    are copied into the first rows of the new head; the remaining rows keep
    ``nn.Linear``'s default initialisation. When the old head is larger than
    ``num_classes``, nothing is copied and the new head is freshly initialised.

    Args:
        model: Module with a linear ``fc`` attribute; modified in place.
        num_classes: Number of outputs of the new head.

    Returns:
        The same ``model`` object, with ``fc`` replaced.
    """
    old_fc = model.fc
    old_num_classes = old_fc.out_features
    has_bias = old_fc.bias is not None

    # Create the new head where the old one lives instead of forcing CUDA.
    new_fc = nn.Linear(old_fc.in_features, num_classes, bias=has_bias).to(
        device=old_fc.weight.device, dtype=old_fc.weight.dtype
    )

    # Copy weights and biases from the old head. `<=` (not `<`) so that an
    # unchanged class count keeps the existing weights instead of discarding them.
    if old_num_classes <= num_classes:
        new_fc.weight.data[:old_num_classes, :] = old_fc.weight.data
        if has_bias:
            new_fc.bias.data[:old_num_classes] = old_fc.bias.data

    model.fc = new_fc
    return model

In [None]:
def get_confusion_matrix(model, test_loader, num_classes):
    """
    Computes the confusion matrix for a given model and dataset.

    The model is switched to eval mode (and left in it). Each batch of images
    is moved to the device of the model's first parameter (CPU if the model
    has no parameters) rather than unconditionally to CUDA. The predicted
    class is the argmax over dimension 1 of the model output.

    Args:
        model: Module called as ``model(imgs)``.
        test_loader: Iterable yielding ``(imgs, labels)`` tensor pairs.
        num_classes: Labels ``0 .. num_classes - 1`` index the matrix.

    Returns:
        The result of ``sklearn.metrics.confusion_matrix(true, predicted)``
        over all batches, restricted to the labels above.
    """
    model.eval()

    # Follow the model's device instead of assuming a GPU is present.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for imgs, labels in test_loader:
            output = model(imgs.to(device))
            all_preds.extend(output.argmax(dim=1).cpu().tolist())
            # Labels are only needed on the host; no device round trip.
            all_labels.extend(labels.cpu().tolist())

    return confusion_matrix(all_labels, all_preds, labels=list(range(num_classes)))

num_classes = 47

# Build the architecture only. Every weight is overwritten by the checkpoint
# loaded below, so the ImageNet weights are not downloaded (this also avoids
# torchvision's deprecated `pretrained` argument).
model = models.resnet18()
modify_resnet_head(model, num_classes)
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
# map_location="cpu" keeps loading independent of the device the checkpoint
# was saved from; the model is moved to the GPU right after.
model.load_state_dict(torch.load("model-naive-DTD.pth", map_location="cpu"))
model.cuda()

# Fixed order plus seeded workers/generator keep the evaluation repeatable.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=256,
    shuffle=False,
    num_workers=4,
    worker_init_fn=seed_worker,
    generator=g
)

# get the confusion matrix
cm = get_confusion_matrix(model, test_loader, num_classes)

# --- display the Confusion Matrix ---
plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=False, fmt="d", cmap="Blues", linewidths=.5)
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.title("Confusion Matrix")
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import ast
import os

def calculate_and_visualize_bwt(accuracy_file_paths, dataset_names):
    """
    Calculates Backward Transfer (BWT) from per-class accuracies in multiple text files
    and visualizes them with a grouped bar chart.

    Each non-blank line of a file is parsed as
    ``<task_id>,<ignored>,<list literal of per-class accuracies>``; a leading
    header line whose first field is not an integer is skipped.

    For every pair of consecutive task ids (in sorted order) the plotted value
    is the mean of all accuracies reported at the earlier task minus the mean
    accuracy of those same leading classes at the later task, so a positive
    bar means accuracy was lost.
    NOTE(review): the usual BWT definition has the opposite sign and compares
    against the final model - confirm which convention is intended.

    Args:
        accuracy_file_paths: Paths of the accuracy files, one per dataset.
        dataset_names: Legend labels, paired with the paths by position.

    Returns:
        None. Shows the figure. Prints a message and returns early if no paths
        are given or a file does not exist.

    Raises:
        ValueError: If a data line does not have three comma-separated fields
            or its values cannot be converted.
    """
    if not accuracy_file_paths:
        print("No accuracy files given. Nothing to plot.")
        return

    num_datasets = len(accuracy_file_paths)
    all_bwt_values = []
    all_bwt_task_ids = []

    for file_path in accuracy_file_paths:
        try:
            f = open(file_path, 'r')
        except FileNotFoundError:
            print(f"Error: Accuracy file not found at '{file_path}'")
            return

        parsed_data = {}
        header_allowed = True
        with f:
            for line_num, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue
                is_first_row, header_allowed = header_allowed, False
                fields = line.split(',', 2)
                if len(fields) != 3:
                    raise ValueError(
                        f"Line {line_num} of '{file_path}' does not have three comma-separated fields."
                    )
                try:
                    task_id = int(fields[0])
                except ValueError:
                    if is_first_row:
                        # Column-title line, not data.
                        continue
                    raise
                per_class_acc_list = ast.literal_eval(fields[2].strip())
                parsed_data[task_id] = [float(acc) for acc in per_class_acc_list]

        # calculate BWT over consecutive task ids actually present in the
        # file, so 0-based and 1-based numbering both cover every transition
        task_ids = sorted(parsed_data)
        bwt_values = []
        bwt_task_ids = []
        for prev_id, cur_id in zip(task_ids, task_ids[1:]):
            prev_task_accuracies = parsed_data[prev_id]
            relevant_classes = len(prev_task_accuracies)
            if relevant_classes == 0:
                continue
            initial_performance = np.mean(prev_task_accuracies)
            current_performance = np.mean(parsed_data[cur_id][:relevant_classes])
            bwt_values.append(initial_performance - current_performance)
            bwt_task_ids.append(cur_id)

        all_bwt_values.append(bwt_values)
        all_bwt_task_ids.append(bwt_task_ids)

    # Plot BWT values; the axis is sized and labelled from the dataset with
    # the most values so that shorter datasets still fit.
    label_ids = max(all_bwt_task_ids, key=len)
    num_tasks = len(label_ids)
    task_labels = [f"Task {task_id}" for task_id in label_ids]
    bar_width = 0.8 / num_datasets
    x_positions = np.arange(num_tasks)

    fig, ax = plt.subplots(figsize=(10, 6))

    # zip pairs names with results and stops at the shorter of the two.
    for i, (dataset_name, bwt_values) in enumerate(zip(dataset_names, all_bwt_values)):
        ax.bar(np.arange(len(bwt_values)) + i * bar_width, bwt_values, width=bar_width, label=dataset_name)

    ax.set_xlabel("Task")
    ax.set_ylabel("Backward Transfer (BWT)")
    ax.set_title("Backward Transfer Comparison")
    # Centre each tick under its group of bars.
    ax.set_xticks(x_positions + bar_width * (num_datasets - 1) / 2)
    ax.set_xticklabels(task_labels, rotation=45, ha='right')
    ax.legend()
    plt.tight_layout()
    plt.show()