In [None]:
# Project helpers: result extraction and post-hoc metrics (average forgetting / average accuracy).
from src.toolkit.process_results import extract_results
from src.toolkit.post_metrics import compute_average_forgetting, compute_average_accuracy
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import os
import numpy as np
# Shared plotting style file; presumably resolved relative to the notebook's working directory — confirm.
plt.style.use("matplotlibrc.template")

In [None]:
# LoRA ranks that were run for each backbone.
RANKS = {"resnet": [1, 2, 3, 4, 5, 6, 7], "vit": [1, 2, 4, 6, 8, 16, 32]}

basedir = "MY_RESULTS_DIR"

BASENAMES = {"vit": os.path.join(basedir, "lora_vit_seeds"), "resnet": os.path.join(basedir, "lora_resnet_seeds")}

# Per-backbone multiplier: ppercent = params[network] * rank
# (presumably the % of parameters added per unit of LoRA rank — confirm).
params = {"resnet": 0.08 , "vit": 0.042}

# Number of tasks passed to the forgetting / average-accuracy post-metrics.
NUM_TASKS = 5

# Column used to keep only rows that carry evaluation results.
EVAL_KEY = "Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp000"

all_frames = []

for network in ["vit", "resnet"]:
    ranks = RANKS[network]
    basename = BASENAMES[network]
    for rank in ranks:
        path = os.path.join(basename, f"lora_forget_{rank}")

        try:
            results = extract_results(path)
            df = results["training"]
        except Exception as exc:
            # Best effort: skip missing/broken runs, but report why instead of hiding the cause.
            print(f"Could not load {path}: {exc!r}")
            continue

        df.pop("name")

        df["lora_rank"] = rank
        # "full" is not in RANKS; this branch only applies if a full-finetuning entry is added.
        if rank == "full":
            df["ppercent"] = 100
        else:
            df["ppercent"] = params[network] * rank

        df["network"] = network

        # The groupby below needs a "seed" column; presumably extract_results provides it —
        # for single-seed runs without one, uncomment the next line.
        #df["seed"] = 0

        df = compute_average_forgetting(df, num_tasks=NUM_TASKS)
        df = compute_average_accuracy(df, num_tasks=NUM_TASKS)

        # Clean: drop rows without evaluation results.
        df = df.dropna(subset=[EVAL_KEY])

        all_frames.append(df)

if not all_frames:
    raise RuntimeError(f"No results could be loaded from {sorted(BASENAMES.values())}")

frame = pd.concat(all_frames)
# One row per (step, rank, network, seed): GroupBy.first takes the first non-null value of each column.
frame = frame.groupby(["step", "lora_rank", "network", "seed"]).first()

In [None]:
# Decorate with training exp: each distinct logged step gets a 0-based
# "Training Exp" index, numbered in the order the steps appear in the index.

step_values = frame.index.get_level_values("step")
steps = step_values.unique()
print(steps)

step_idx_map = dict(zip(steps, range(len(steps))))
frame["Training Exp"] = step_values.map(step_idx_map)
training_exps = frame["Training Exp"].unique()

In [None]:
# Quick inspection: ResNet Task003 accuracy (all steps / seeds) for a single LoRA rank.
# Rank 0 is never loaded (it is not in RANKS), so pick a rank that actually exists.
network = "resnet"
tdf = frame.query("network == @network")

inspect_rank = min(RANKS[network])
tdf.loc[tdf.index.get_level_values("lora_rank") == inspect_rank, "Top1_Acc_Exp/eval_phase/test_stream/Task003/Exp003"]

In [None]:
# Figure individual tasks (Short) Accuracy

%matplotlib inline
plt.style.use("matplotlibrc.template")
task = 0

task_name = {0: "Imagenet", 1: "Cars", 2: "Flowers", 3: "Aircraft", 4: "Birds"}

add_legend = False

network = "vit"
vit_palette=sns.color_palette("Blues", 15)[4:]
ax = sns.lineplot(frame.query("network == @network"), x="Training Exp", y=f"Top1_Acc_Exp/eval_phase/test_stream/Task00{task}/Exp00{task}", hue="lora_rank", errorbar=("sd", 1.0), palette=vit_palette, legend=add_legend, marker="o")

network = "resnet"
resnet_palette=sns.color_palette("Greens", 15)[4:]
sns.lineplot(frame.query("network == @network"), x="Training Exp", y=f"Top1_Acc_Exp/eval_phase/test_stream/Task00{task}/Exp00{task}", hue="lora_rank", errorbar=("sd", 1.0), palette=resnet_palette, legend=False, marker="o")

handles, labels = ax.get_legend_handles_labels()

if add_legend:
    indexes_keep = [0, 1, 5, 7, 8]
    new_handles = [handles[i] for i in indexes_keep]
    new_labels = [labels[i] for i in indexes_keep]
    ax.legend(title="LoRA Rank", handles=new_handles, labels=new_labels, loc="lower left")
    sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))

plt.ylabel(f"{task_name[task]} Accuracy")
plt.xlabel("Training Task")

plt.xticks(training_exps[task:], ["Imagenet", "Cars", "Flowers", "Aircraft", "Birds"][task:])


In [None]:
# Figure individual tasks (Short) Forgetting

%matplotlib qt
plt.style.use("matplotlibrc.template")
task = 0

task_name = {0: "Imagenet", 1: "Cars", 2: "Flowers", 3: "Aircraft", 4: "Birds"}

add_legend = False

network = "vit"
vit_palette=sns.color_palette("Blues", 15)[4:]
ax = sns.lineplot(frame.query("network == @network"), x="Training Exp", y=f"Forgetting_Top1_Acc_Exp/eval_phase/test_stream/Task00{task}/Exp00{task}", hue="lora_rank", errorbar=("sd", 1.0), palette=vit_palette, legend=add_legend, marker="o")

network = "resnet"
resnet_palette=sns.color_palette("Greens", 15)[4:]
sns.lineplot(frame.query("network == @network"), x="Training Exp", y=f"Forgetting_Top1_Acc_Exp/eval_phase/test_stream/Task00{task}/Exp00{task}", hue="lora_rank", errorbar=("sd", 1.0), palette=resnet_palette, legend=False, marker="o")

handles, labels = ax.get_legend_handles_labels()

if add_legend:
    indexes_keep = [0, 1, 5, 7, 8]
    new_handles = [handles[i] for i in indexes_keep]
    new_labels = [labels[i] for i in indexes_keep]
    ax.legend(title="LoRA Rank", handles=new_handles, labels=new_labels, loc="lower left")
    sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))

plt.ylabel(f"{task_name[task]} Accuracy")
plt.xlabel("Training Task")

plt.xticks(training_exps[task:], ["Imagenet", "Cars", "Flowers", "Aircraft", "Birds"][task:])


In [None]:
# Legend generation: a standalone figure holding only two rank colorbars
# (ViT in Blues, ResNet in Greens); the host axes is hidden at the end.

fig, ax = plt.subplots()

colorbar_specs = (
    ("Blues", 40, [4, 8, 16, 32], "LoRA Rank (ViT)"),
    ("Greens", 8, [1, 3, 5, 7], "LoRA Rank (Resnet)"),
)

for cmap_name, vmax, rank_ticks, cbar_title in colorbar_specs:
    sm = plt.cm.ScalarMappable(cmap=cmap_name, norm=plt.Normalize(0, vmax))
    cbar = fig.colorbar(sm, ax=ax)
    cbar.set_ticks(rank_ticks)
    cbar.set_ticklabels(rank_ticks)
    cbar.set_label(cbar_title)

plt.gca().set_visible(False)

In [None]:
# Debug: number of legend handles collected by the last figure cell that called get_legend_handles_labels().
n_handles = len(handles)
n_handles

In [None]:
# Figure individual tasks (Long)

%matplotlib inline
plt.style.use("matplotlibrc.template")
task = 1

task_name = {0: "Imagenet", 1: "Cars", 2: "Flowers", 3: "Aircraft", 4: "Birds"}

add_legend = False

network = "vit"
vit_palette=sns.color_palette("Blues", 15)[4:]
ax = sns.lineplot(frame.query("network == @network"), x="Training Exp", y=f"Top1_Acc_Exp/eval_phase/test_stream/Task00{task}/Exp00{task}", hue="lora_rank", errorbar=("sd", 1.0), palette=vit_palette, legend=add_legend, marker="o")

network = "resnet"
resnet_palette=sns.color_palette("Greens", 15)[4:]
sns.lineplot(frame.query("network == @network"), x="Training Exp", y=f"Top1_Acc_Exp/eval_phase/test_stream/Task00{task}/Exp00{task}", hue="lora_rank", errorbar=("sd", 1.0), palette=resnet_palette, legend=False, marker="o")

plt.ylabel(f"{task_name[task]} Accuracy")
plt.xlabel("Training Task")

plt.xticks(training_exps[task:], ["Imagenet", "Cars", "Flowers", "Aircraft", "Birds", "Cars", "Flowers", "Aircraft", "Birds"][task:])

plt.xticks(fontsize=15)

#plt.xticks(training_exps[task:], ["Imagenet", "Cars", "Flowers", "Aircraft", "Birds"][task:])
#plt.xticks(training_exps[task:], ["Imagenet", "Cars", "Flowers", "Aircraft", "Birds", "Cars", "Flowers", "Aircraft", "Birds"][task:])
#plt.savefig("./cars_long.pdf")

In [None]:
# Figure Average accuracy

add_legend = False

# Same palettes as the per-task figures; redefined so this cell does not depend on them having run.
vit_palette = sns.color_palette("Blues", 15)[4:]
resnet_palette = sns.color_palette("Greens", 15)[4:]

# Tick labels derived from the data: one per training experience, Imagenet first and the
# remaining datasets cycling after it, so ticks and labels always have equal length.
dataset_names = ["Imagenet", "Cars", "Flowers", "Aircraft", "Birds"]
n_finegrained = len(dataset_names) - 1
exp_labels = [dataset_names[0]] + [dataset_names[1 + i % n_finegrained] for i in range(len(training_exps) - 1)]

network = "vit"
ax = sns.lineplot(frame.query("network == @network"), x="Training Exp", y="Average_Accuracy", hue="lora_rank", errorbar=("sd", 1.0), palette=vit_palette, marker="o", legend=add_legend)

network = "resnet"
sns.lineplot(frame.query("network == @network"), x="Training Exp", y="Average_Accuracy", hue="lora_rank", errorbar=("sd", 1.0), palette=resnet_palette, marker="o", legend=False)

handles, labels = ax.get_legend_handles_labels()

if add_legend:
    ax.legend(title="LoRA Rank", handles=handles, labels=labels, loc="lower left")

plt.xticks(training_exps, exp_labels, fontsize=15)

# Fixed y-range for the figure.
plt.ylim(0.35, 0.85)

plt.xlabel("Training Task")
plt.ylabel("Average Accuracy")

In [None]:
# Sanity check: "ppercent" against LoRA rank for each backbone, drawn on the same axes.
for network in ("vit", "resnet"):
    ppercent_ax = sns.lineplot(frame.query("network == @network"), x="lora_rank", y="ppercent")
ppercent_ax

In [None]:
# Data for the tables: ResNet average forgetting / accuracy at the final logged step, per selected rank.

network = "resnet"
network_frame = frame.query("network == @network")
# Final step of THIS backbone: a max over all networks would select nothing here
# if the backbones logged different final steps. Loop-invariant, so computed once.
max_index = network_frame.index.get_level_values("step").max()

for r in [1, 3, 4, 5, 7]:
    df = network_frame.query("lora_rank == @r and step == @max_index")

    print("Forgetting:", df["Average_Forgetting"])
    print("Average Acc:", df["Average_Accuracy"])

In [None]:

# Data for the tables: ViT average forgetting / accuracy at the final logged step, per selected rank.
network = "vit"
network_frame = frame.query("network == @network")
# Final step of THIS backbone: a max over all networks would select nothing here
# if the backbones logged different final steps. Loop-invariant, so computed once.
max_index = network_frame.index.get_level_values("step").max()

for r in [1, 2, 4, 8, 16, 32]:
    df = network_frame.query("lora_rank == @r and step == @max_index")

    print("Forgetting:", df["Average_Forgetting"])
    print("Average Acc:", df["Average_Accuracy"])

In [None]:
# Plot: ViT average forgetting at the final logged step as a function of LoRA rank.
network = "vit"
network_frame = frame.query("network == @network")
# Final step of THIS backbone (a max over all networks could select nothing here).
max_index = network_frame.index.get_level_values("step").max()

df = network_frame.query("step == @max_index")

sns.lineplot(df, x="lora_rank", y="Average_Forgetting")