In [1]:
import re
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import wandb


In [2]:
# Authenticate without embedding the API key in the notebook. Called without `key`,
# wandb.login() looks up credentials itself (WANDB_API_KEY environment variable or
# the key stored in ~/.netrc by an earlier login) and prompts if none is found.
# NOTE: the key previously hard-coded in this cell was committed in plain text and
# should be revoked/rotated in the W&B user settings.
wandb.login()
api = wandb.Api()
fill_value = "N/A"

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /export/home/0schindl/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mbelaschindler[0m ([33mbelaschindler-university-hamburg[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# All analysed sweeps live in the same W&B project; only the sweep hash differs.
SWEEP_PROJECT = "belaschindler-university-hamburg/0schindl-LayUp_sweeps_question1_selection_method"

sweep_id = f"{SWEEP_PROJECT}/wk4w5q0t"
sweep2_id = f"{SWEEP_PROJECT}/6uix7luf"
sweep3_id = f"{SWEEP_PROJECT}/lpeje9a7"

# Resolve each sweep (in order 1, 2, 3) and keep a handle on its runs.
sweep, sweep2, sweep3 = (api.sweep(sid) for sid in (sweep_id, sweep2_id, sweep3_id))
runs1, runs2, runs3 = sweep.runs, sweep2.runs, sweep3.runs

# Dataset names to keep from runs1, runs2 and runs3 respectively.
cil_datasets1 = ["cifar100", "cub", "imageneta", "vtab", "omnibenchmark"]
cil_datasets2 = ["imagenetr", "cars"]
dil_datasets = ["cddb", "dil_imagenetr", "limited_domainnet"]


In [4]:
# Per-task accuracy metric names such as "task_3/acc". Aggregates like
# "task_mean/acc" / "task_wmean/acc" never match because the index must be numeric.
_TASK_ACC_KEY = re.compile(r"task_(\d+)/acc")

# Selection methods get_forgetting keeps unless told otherwise.
FORGETTING_SELECTION_METHODS = ("around", "inv_ws_div")


def _task_index(key):
    """Return the task index encoded in a per-task accuracy key, else None.

    "task_3/acc" -> 3; any other key (e.g. "task_mean/acc", "loss") -> None.
    """
    match = _TASK_ACC_KEY.fullmatch(key) if isinstance(key, str) else None
    return int(match.group(1)) if match else None


def _task_accuracies(run, num_E):
    """Collect end and initial accuracies of every task index t < num_E of one run.

    Args:
        run: a W&B run exposing `.summary`, `.history()` and `.id`.
        num_E: the run's "moe_max_experts" config value; only tasks with an
            index below it are reported.

    Returns:
        dict with "end_task_<t>/acc" (value from run.summary) and
        "initial_task_<t>/acc" (value picked from run.history()).
    """
    accuracies = {}

    # End accuracy: the value stored in the run summary.
    for key, value in run.summary.items():
        t = _task_index(key)
        if t is not None and t < num_E:
            accuracies["end_" + key] = value

    # Initial accuracy: the (num_E - t - 1)-th logged (non-missing) value of task t.
    # Assumes task_t/acc is first logged after task t and then after every later
    # task, i.e. this picks the value measured after task num_E - 1 — TODO confirm.
    # NOTE(review): run.history() returns a *sampled* history by default; if a run
    # logs many rows, task_t/acc entries may be dropped and the positional pick below
    # would be wrong. Consider run.scan_history() if that happens.
    history = run.history()
    for key, values in history.items():
        t = _task_index(key)
        if t is None or t >= num_E:
            continue
        observed = values.dropna()
        position = num_E - t - 1
        if position >= len(observed):
            # Don't abort the whole sweep for one incomplete run; leave it missing.
            warnings.warn(
                f"run {run.id}: {key} has only {len(observed)} logged values, "
                f"need index {position}; initial accuracy left missing"
            )
            continue
        accuracies["initial_" + key] = observed.iloc[position]

    return accuracies


def _collect_rows(runs, dataset_names, selection_methods=None):
    """Build one row dict per finished run whose config "dataset" is in dataset_names.

    When `selection_methods` is not None, runs whose config "selection_method" is
    not in it are skipped too. Each row holds run_id, dataset, selection_method and
    the per-task accuracies from _task_accuracies.
    """
    rows = []
    for run in runs:
        config = run.config
        dataset = config.get("dataset")
        selection_method = config.get("selection_method")
        if run.state != "finished" or dataset not in dataset_names:
            continue
        if selection_methods is not None and selection_method not in selection_methods:
            continue
        row = {"run_id": run.id, "dataset": dataset, "selection_method": selection_method}
        row.update(_task_accuracies(run, config.get("moe_max_experts")))
        rows.append(row)
    return rows


def get_forgetting(runs, dataset_names, data, selection_methods=FORGETTING_SELECTION_METHODS):
    """Append per-task end/initial accuracies of matching runs to `data` and return it.

    Args:
        runs: iterable of W&B runs (e.g. `sweep.runs`).
        dataset_names: datasets to keep.
        data: list of row dicts, extended in place so several sweeps can be
            accumulated across calls.
        selection_methods: selection methods to keep (default "around" and
            "inv_ws_div"); None keeps all.

    Returns:
        The same `data` list.
    """
    data.extend(_collect_rows(runs, dataset_names, selection_methods))
    return data


def get_bwt(runs, dataset_names, data):
    """Collect per-task accuracies of all selection methods and return a DataFrame.

    Same row contents as get_forgetting but without the selection-method filter.
    `data` is extended in place and the accumulated rows are returned as a
    pandas DataFrame (one row per run).
    """
    data.extend(_collect_rows(runs, dataset_names))
    return pd.DataFrame(data)


def get_average_forgetting(df, num_tasks=5):
    """Return the mean per-task forgetting (initial minus end accuracy) of each row.

    Args:
        df: DataFrame with "initial_task_<i>/acc" and "end_task_<i>/acc" columns.
        num_tasks: tasks 0 .. num_tasks-1 are considered. None uses every task index
            that has an "end_task_<i>/acc" column.

    Returns:
        List with one float per row of df (same order). Tasks missing for a row are
        ignored; the result is NaN when no task has both values.
    """
    if num_tasks is None:
        indices = sorted(
            t
            for t in (
                _task_index(col[len("end_"):])
                for col in df.columns
                if isinstance(col, str) and col.startswith("end_")
            )
            if t is not None
        )
    else:
        indices = range(num_tasks)

    per_task = {}
    for i in indices:
        end_col = f"end_task_{i}/acc"
        initial_col = f"initial_task_{i}/acc"
        if end_col in df.columns and initial_col in df.columns:
            initial = pd.to_numeric(df[initial_col], errors="coerce")
            end = pd.to_numeric(df[end_col], errors="coerce")
            per_task[i] = (initial - end).to_numpy(dtype=float, na_value=np.nan)

    if not per_task:
        # No task has both columns: every row's average is undefined.
        return [np.nan] * len(df)
    # skipna mean ignores missing tasks per row and yields NaN (without the
    # "Mean of empty slice" warning of np.nanmean) when all are missing.
    return pd.DataFrame(per_task).mean(axis=1, skipna=True).tolist()

In [5]:
# Collect per-task accuracies of the three sweeps into one table.
print("####### 1 #######")
data = get_forgetting(runs1, cil_datasets1, [])
print("####### 2 #######")
data = get_forgetting(runs2, cil_datasets2, data)
print("####### 3 #######")
data = get_forgetting(runs3, dil_datasets, data)

# Add the average forgetting as a new column
data = pd.DataFrame(data)
data['average_forgetting'] = get_average_forgetting(data)

print("\nDataFrame der Sweep-Runs:")
# Widen the text repr only for this print. option_context restores the previous
# settings afterwards, whereas the former set_option "reset" left
# display.max_colwidth at None instead of its prior value.
with pd.option_context('display.width', 5000, 'display.max_colwidth', None):
    print(data.sort_values(by=['dataset', "selection_method"]))


####### 1 #######
####### 2 #######
####### 3 #######

DataFrame der Sweep-Runs:
      run_id            dataset selection_method  end_task_0/acc  end_task_1/acc  end_task_2/acc  end_task_3/acc  end_task_4/acc  initial_task_3/acc  initial_task_2/acc  initial_task_0/acc  initial_task_1/acc  initial_task_4/acc  average_forgetting
12  doicok2o               cars           around        0.487256        0.295285        0.260024        0.252370        0.409478            0.428910            0.410693            0.623688            0.387097            0.550425            0.139280
13  k4z61mie               cars       inv_ws_div        0.533733        0.325062        0.323208        0.323460        0.357230            0.464455            0.448360            0.619190            0.359801            0.411908            0.088204
16  5jdhgw9k               cddb           around        0.599000        0.827500        0.597673        0.582500        0.477778            0.582500            0.597673    

In [13]:
print("####### BWT FROM DIL #######")
data = get_bwt(runs3, dil_datasets, [])
data["average_forgetting"] = get_average_forgetting(data)
# Widen the text repr only for this print. option_context restores the previous
# settings afterwards, whereas the former set_option "reset" left
# display.max_colwidth at None instead of its prior value.
with pd.option_context('display.width', 5000, 'display.max_colwidth', None):
    print(data.sort_values(by=['dataset', "selection_method"]))

####### BWT FROM DIL #######
      run_id            dataset selection_method  end_task_0/acc  end_task_1/acc  end_task_2/acc  end_task_3/acc  end_task_4/acc  initial_task_3/acc  initial_task_0/acc  initial_task_1/acc  initial_task_2/acc  initial_task_4/acc  average_forgetting
9   5jdhgw9k               cddb           around        0.599000        0.827500        0.597673        0.582500        0.477778            0.582500            0.599000            0.827500            0.597673            0.477778            0.000000
8   ze5jnwiv               cddb       eucld_dist        0.599000        0.827500        0.597673        0.582500        0.477778            0.582500            0.599000            0.827500            0.597673            0.477778            0.000000
6   x2635mud               cddb   inv_eucld_dist        0.599000        0.827500        0.597673        0.582500        0.477778            0.582500            0.599000            0.827500            0.597673            0.47

In [14]:
sweep4_id = "belaschindler-university-hamburg/0schindl-LayUp_sweeps_question1_selection_method/hbvo6qhj"
sweep4 = api.sweep(sweep4_id)
runs4 = sweep4.runs

print("####### BWT FROM DIL 5 SEEDS #######")
data = get_bwt(runs4, dil_datasets, [])
data["average_forgetting"] = get_average_forgetting(data)
# Widen the text repr only for this print. option_context restores the previous
# settings afterwards, whereas the former set_option "reset" left
# display.max_colwidth at None instead of its prior value.
with pd.option_context('display.width', 5000, 'display.max_colwidth', None):
    print(data.sort_values(by=['dataset', "selection_method"]))

####### BWT FROM DIL 5 SEEDS #######
     run_id        dataset selection_method  end_task_0/acc  end_task_1/acc  end_task_2/acc  end_task_3/acc  end_task_4/acc  initial_task_0/acc  initial_task_2/acc  initial_task_3/acc  initial_task_1/acc  initial_task_4/acc  average_forgetting
0  vza83nsz  dil_imagenetr       inv_ws_div        0.866667        0.700137        0.652000        0.789100        0.613497            0.855914            0.665333            0.770142            0.705601            0.638037            0.002726
1  5cgl8u3s  dil_imagenetr       inv_ws_div        0.873118        0.706284        0.656000        0.793839        0.613497            0.850538            0.669333            0.789100            0.713798            0.625767            0.001159
2  59mip3ad  dil_imagenetr       inv_ws_div        0.846237        0.702869        0.658667        0.793839        0.619632            0.834409            0.672000            0.777251            0.712432            0.625767        