In [1]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

In [2]:
# Let displayed DataFrames show up to 100 columns and 100 rows before truncating.
pd.set_option(
    "display.max_columns", 100,
    "display.max_rows", 100,
)

In [3]:
# Paper-sized seaborn theme on a white grid, with a serif font (CMU Serif).
font_rc = {
    "font.family": "serif",
    "font.serif": "CMU Serif",
}
sns.set_theme(context="paper", style="whitegrid", rc=font_rc)

# Switch matplotlib to the "pgf" backend for all figures saved below.
matplotlib.use("pgf")

# Permuted MNIST with strengthening

In [9]:
# Strengthening sweeps to visualise. Each entry pairs a CSV filename template
# (its "{strengthen}" placeholder is filled in once per factor) with the name
# of the metric to put on the y axis.
experiments = [
    dict(
        filename="no_coreset_strengthen/permuted_mnist_coreset_0_epochs_20-100_lr_5e-3_init_var_1e-4_strengthen_{strengthen}_fix_for_real.csv",
        metric="Accuracy",
    ),
    dict(
        filename="coreset_strengthen/permuted_mnist_coreset_200_epochs_20-100_lr_5e-3_init_var_1e-4_strengthen_{strengthen}_fix_for_real.csv",
        metric="Accuracy",
    ),
    dict(
        filename="no_coreset_strengthen_reg/permuted_mnist_reg_coreset_0_epochs_20-100_lr_3e-4_init_var_1e-4_strengthen_{strengthen}_fix_for_real.csv",
        metric="RMSE",
    ),
    dict(
        filename="coreset_strengthen_reg/permuted_mnist_reg_coreset_200_epochs_20-100_lr_3e-4_init_var_1e-4_strengthen_{strengthen}_fix_for_real.csv",
        metric="RMSE",
    ),
]

In [10]:
# Strengthening factors (tau) covered by the sweep. Each value is substituted
# verbatim into the "{strengthen}" placeholder of the filename templates, so the
# spelling matters: the ints 1 and 2 produce "strengthen_1" / "strengthen_2",
# whereas 1.0 / 2.0 would look for "strengthen_1.0" / "strengthen_2.0".
factors = [0.8, 1, 1.1, 1.2, 1.3, 1.5, 2]

In [None]:
# For every sweep in `experiments`, load one CSV per strengthening factor and
# write four figures to visualisations/<experiment_name>_{1,2,3,4}.pdf:
#   1. "acc" against the number of tasks seen, one line per factor
#   2. per-task performance of the final model, for a subset of factors
#   3. one facet per task: performance of the final model against the factor
#   4. all tasks against the factor (left) and their mean (right)

# Number of tasks per run; the CSVs carry one column per task, task_0 .. task_9.
n_tasks_total = 10
task_columns = [f"task_{i}" for i in range(n_tasks_total)]
task_ticks = range(1, n_tasks_total + 1)

# The swept factors double as x ticks (Figure 4) and fix the colour order
# (Figure 2); deriving them from `factors` keeps the two from drifting apart.
colors_to_plot = list(factors)
# Only these factors are drawn in Figure 2.
fig_2_colors_to_plot = [0.8, 1, 1.2, 1.5, 2]
# Tasks drawn opaque and labelled in Figure 4; the rest are faded.
tasks_to_highlight = [1, 4, 10]

# The palettes do not depend on the experiment, so build them once. The "hls"
# palette keeps its original 8 entries unless more factors than that are swept.
factor_palette = sns.color_palette("hls", max(8, len(colors_to_plot)))
task_palette = sns.color_palette("crest", n_tasks_total)


def get_factor_color(s):
    """Return the "hls" colour at the position of factor `s` in the sweep."""
    return factor_palette[colors_to_plot.index(s)]


def get_task_color(task_no):
    """Return the "crest" colour for a 1-based task number."""
    return task_palette[task_no - 1]


for experiment in experiments:
    filename_template = experiment["filename"]
    metric = experiment["metric"]
    # The directory part of the template names the output files.
    experiment_name = filename_template.split('/')[0]

    # Load one CSV per factor and tag its rows with that factor.
    dfs = []
    for factor in factors:
        df = pd.read_csv(filename_template.format(strengthen=factor), index_col=0)
        df["strengthen"] = factor
        dfs.append(df)
    df_all = pd.concat(dfs, axis=0)

    # For RMSE experiments the stored values are square-rooted before plotting
    # (presumably the CSVs hold MSE - confirm against the training script).
    if metric == "RMSE":
        for col in ["acc"] + task_columns:
            df_all[col] = np.sqrt(df_all[col])

    # Figure 1: Accuracy/RMSE vs number of tasks seen, baseline (tau = 1) in bold black.
    fig_1, ax_1 = plt.subplots(figsize=(4, 3))
    for strengthen, gdf in df_all.groupby("strengthen"):
        is_baseline = strengthen == 1.0
        sns.lineplot(
            data=gdf,
            x="n_tasks",
            y="acc",
            label=f"$\\tau = {strengthen}$",
            errorbar=None,
            linewidth=2.5 if is_baseline else 1.5,
            marker='o',
            color="black" if is_baseline else None,
            alpha=1.0 if is_baseline else 0.7,
            ax=ax_1,
        )
    ax_1.set_xlabel("Number of tasks seen")
    ax_1.set_ylabel(metric)
    ax_1.set_xticks(task_ticks)
    fig_1.savefig(f"visualisations/{experiment_name}_1.pdf", bbox_inches='tight')
    plt.close(fig_1)

    # Long format: one row per (factor, n_tasks, task) with a recorded value.
    perf_by_task = df_all.melt(
        id_vars=["strengthen", "n_tasks"],
        value_vars=task_columns,
        var_name="metric",
        value_name="value",
    ).dropna()
    # "task_3" -> 4: 1-based task number used on the axes.
    perf_by_task["task_no"] = (
        perf_by_task["metric"].str.replace("task_", "", regex=False).astype(int) + 1
    )
    # Rows of the final model, i.e. after all tasks have been seen.
    final_perf = perf_by_task[perf_by_task["n_tasks"] == n_tasks_total]

    # Figure 2: Performance by task for the final model.
    fig_2, ax_2 = plt.subplots(figsize=(4, 3))
    for strengthen, gdf in final_perf.groupby("strengthen"):
        if strengthen not in fig_2_colors_to_plot:
            continue
        is_baseline = strengthen == 1
        sns.lineplot(
            data=gdf,
            x="task_no",
            y="value",
            label=f"$\\tau = {strengthen}$",
            errorbar=None,
            color="black" if is_baseline else get_factor_color(strengthen),
            linewidth=2 if is_baseline else 1.35,
            marker='o',
            ax=ax_2,
        )
    if metric == "Accuracy":
        ax_2.set_ylim(0.5, 0.95)
    ax_2.set_xlabel("Dataset \\#")
    ax_2.set_ylabel(metric)
    ax_2.set_xticks(task_ticks)
    ax_2.legend()
    fig_2.savefig(f"visualisations/{experiment_name}_2.pdf", bbox_inches='tight')
    plt.close(fig_2)

    # Figure 3: Multi-facet plot by task. relplot creates its own figure, so no
    # plt.figure() call precedes it (one would be left behind empty and open).
    g = sns.relplot(
        data=final_perf,
        x="strengthen", y="value",
        kind="line",
        col="task_no",
        col_wrap=5,
        marker="o",
    )
    g.set_titles("Task \\#{col_name}")
    g.set_axis_labels("$\\tau$", metric)
    for facet_ax in g.axes.flat:
        # Mark the baseline and force tick labels on every facet.
        facet_ax.axvline(x=1.0, linestyle="--", color="red", alpha=0.4)
        facet_ax.tick_params(axis='x', which='both', bottom=True, top=False, labelbottom=True, labelsize=10)
        facet_ax.tick_params(axis='y', which='both', left=True, right=False, labelleft=True, labelsize=10)
    g.figure.tight_layout()
    g.figure.savefig(f"visualisations/{experiment_name}_3.pdf", bbox_inches='tight')
    plt.close(g.figure)

    # Figure 4: Combined performance (left) and average over tasks (right).
    fig, ax = plt.subplots(1, 2, figsize=(7.5, 2.5), sharey=True)
    for task in task_ticks:
        highlighted = task in tasks_to_highlight
        sns.lineplot(
            data=final_perf[final_perf["task_no"] == task],
            x="strengthen",
            y="value",
            label=f"Task {task}" if highlighted else None,
            errorbar=None,
            color=get_task_color(task),
            alpha=1 if highlighted else 0.3,
            linewidth=2 if highlighted else 1,
            marker='o' if highlighted else None,
            ax=ax[0],
        )
    ax[0].axvline(x=1.0, linestyle="--", color="gray", alpha=0.6)
    ax[0].set_xlabel("$\\tau$")
    ax[0].set_ylabel(metric)
    ax[0].set_xticks(colors_to_plot)

    avg = final_perf.groupby("strengthen")["value"].mean()
    sns.lineplot(
        x=avg.index,
        y=avg.values,
        errorbar=None,
        color="black",
        linewidth=2,
        marker='o',
        ax=ax[1],
    )
    ax[1].axvline(x=1.0, linestyle="--", color="gray", alpha=0.6)
    ax[1].set_xlabel("$\\tau$")
    ax[1].set_xticks(colors_to_plot)

    fig.tight_layout()
    fig.savefig(f"visualisations/{experiment_name}_4.pdf", bbox_inches='tight')
    plt.close(fig)

    print(f"Saved visualizations for {experiment_name}")

Saved visualizations for no_coreset_strengthen
Saved visualizations for coreset_strengthen
Saved visualizations for no_coreset_strengthen_reg
Saved visualizations for coreset_strengthen_reg


# Main plot: Permuted MNIST

In [86]:
def plot_experiments(experiments, filename, n_tasks, ylim=(0.6, 1.0), metric="Accuracy", *, figsize=(8, 3), outside_legend=False):
    """Overlay the "acc" curve of several runs on one figure and save it.

    Parameters
    ----------
    experiments : list of dict
        One dict per run with a "name" (legend label) and a "filename" (CSV
        path). Each CSV is read with its first column as the index and must
        provide "n_tasks" and "acc" columns.
    filename : str
        Output path handed to ``savefig``.
    n_tasks : int
        The x axis gets ticks at 1..n_tasks.
    ylim : tuple
        (bottom, top) limits of the y axis.
    metric : str
        y-axis label. When it equals "RMSE", the "acc" column is square-rooted
        before plotting.
    figsize : tuple
        Figure size in inches.
    outside_legend : bool
        Place a frameless legend to the right of the axes instead of inside.
    """
    fig, axes = plt.subplots(figsize=figsize)

    for spec in experiments:
        run = pd.read_csv(spec["filename"], index_col=0)
        if metric == "RMSE":
            run["acc"] = np.sqrt(run["acc"])
        sns.lineplot(
            data=run,
            x="n_tasks",
            y="acc",
            label=spec["name"],
            errorbar=None,
            linewidth=1.5,
            marker='o',
            alpha=0.7,
            ax=axes,
        )

    # Either anchor the legend just outside the right edge or use the default spot.
    legend_options = {}
    if outside_legend:
        legend_options = dict(loc='center left', bbox_to_anchor=(1.0, 0.5), frameon=False)
    axes.legend(**legend_options)

    axes.set_ylim(*ylim)
    axes.set_xlabel("Number of tasks seen")
    axes.set_ylabel(metric)
    axes.set_xticks(range(1, n_tasks+1))

    # Write the figure to disk, then release it.
    fig.savefig(filename, bbox_inches='tight')
    plt.close(fig)

In [7]:
# Runs compared on Permuted MNIST, as (legend label, results CSV) pairs.
permuted_mnist_experiments = [
    {"name": label, "filename": csv_path}
    for label, csv_path in [
        (
            "VCL",
            "no_coreset_strengthen/permuted_mnist_coreset_0_epochs_20-100_lr_5e-3_init_var_1e-4_strengthen_1_fix_for_real.csv",
        ),
        (
            "VCL + coreset",
            "coreset_strengthen/permuted_mnist_coreset_200_epochs_20-100_lr_5e-3_init_var_1e-4_strengthen_1_fix_for_real.csv",
        ),
        (
            "VCL + amplification ($\\tau = 1.2$, best)",
            "no_coreset_strengthen/permuted_mnist_coreset_0_epochs_20-100_lr_5e-3_init_var_1e-4_strengthen_1.2_fix_for_real.csv",
        ),
        (
            "VCL + coreset + amplification ($\\tau = 1.2$)",
            "coreset_strengthen/permuted_mnist_coreset_200_epochs_20-100_lr_5e-3_init_var_1e-4_strengthen_1.2_fix_for_real.csv",
        ),
        (
            "EWC ($\\lambda = 1$)",
            "permuted_mnist/permuted_mnist_EWC_lambda_1_epochs_4_lr_5e-3_hidden_100_approx_2000.csv",
        ),
        (
            "LP ($\\lambda = 0.1$)",
            "permuted_mnist/permuted_mnist_LP_lambda_0.1_epochs_20_lr_5e-3_hidden_100_approx_2000.csv",
        ),
        (
            "SI ($\\lambda = 0.5$)",
            "permuted_mnist/permuted_mnist_SI_epochs_20_lr_2e-3_lambda_0.5_xi_1e-3.csv",
        ),
        (
            "Naive",
            "permuted_mnist/permuted_mnist_naive_epochs_20_lr_5e-3.csv",
        ),
    ]
]
plot_experiments(
    experiments=permuted_mnist_experiments,
    filename="visualisations/permuted_mnist_experiments.pdf",
    n_tasks=10,
)

In [8]:
# Runs compared on Split MNIST, as (legend label, results CSV) pairs.
split_mnist_experiments = [
    {"name": label, "filename": csv_path}
    for label, csv_path in [
        (
            "VCL",
            "split_mnist/split_mnist_coreset_0_epochs_50_lr_2e-3_init_var_1e-6.csv",
        ),
        (
            "VCL + coreset",
            "split_mnist/split_mnist_coreset_200_epochs_50_lr_1e-3_init_var_1e-4.csv",
        ),
        (
            "EWC ($\\lambda = 0.1$)",
            "split_mnist/split_mnist_EWC_lambda_0.1_epochs_50_lr_2e-3_hidden_100_approx_2000.csv",
        ),
        (
            "LP ($\\lambda = 0.1$)",
            "split_mnist/split_mnist_LP_lambda_0.1_epochs_50_lr_2e-3_hidden_100_approx_2000.csv",
        ),
        (
            "SI ($\\lambda = 1.0$)",
            "split_mnist/split_mnist_SI_epochs_50_lr_2e-3_lambda_1_xi_1e-3.csv",
        ),
    ]
]
plot_experiments(
    experiments=split_mnist_experiments,
    filename="visualisations/split_mnist_experiments.pdf",
    n_tasks=5,
)

In [None]:
def plot_split_experiments(experiments, filename, n_tasks):
    """Plot per-task accuracy over training, one panel per task, and save it.

    Parameters
    ----------
    experiments : list of dict
        One dict per run with a "name" (legend label) and a "filename" (CSV
        path). Each CSV is read with its first column as the index and must
        provide an "n_tasks" column plus one "task_<i>" column per task; rows
        where a task column is NaN are skipped for that task's panel.
    filename : str or None
        Output path handed to ``savefig``. When falsy nothing is written and
        the figure is left open.
    n_tasks : int
        Number of tasks, i.e. number of panels and the upper end of the x axis.
    """
    # squeeze=False keeps a 2-D axes array, so indexing also works for n_tasks == 1.
    fig, axes_grid = plt.subplots(
        1, n_tasks, figsize=(2*n_tasks, 2), sharey=True, squeeze=False
    )
    ax = axes_grid[0]

    # One handle per experiment (taken from the first panel) for the shared legend.
    lines = []
    labels = []

    for experiment in experiments:
        df = pd.read_csv(experiment["filename"], index_col=0)
        for i in range(n_tasks):
            col_name = f"task_{i}"
            datapoints = df.dropna(subset=[col_name])[["n_tasks", col_name]]
            line = sns.lineplot(
                data=datapoints,
                x="n_tasks",
                y=col_name,
                label=experiment["name"],
                errorbar=None,
                linewidth=1.5,
                marker='o',
                alpha=0.7,
                ax=ax[i]
            )

            if i == 0:
                # lineplot returns the Axes; its most recent line is the one just drawn.
                lines.append(line.lines[-1])
                labels.append(experiment["name"])

    for i in range(n_tasks):
        # Drop the per-panel legend in favour of the shared one; a panel has
        # none when nothing was drawn on it.
        panel_legend = ax[i].get_legend()
        if panel_legend is not None:
            panel_legend.remove()
        # Pad the x range by 0.2 on both sides of 1..n_tasks.
        ax[i].set_xlim(0.8, n_tasks + 0.2)
        ax[i].set_ylim(-0.05, 1.05)
        ax[i].set_xticks(range(1, n_tasks+1))
        # MNIST digits are 0-9, so task i+1 separates digit 2i from digit 2i+1.
        # NOTE(review): assumes the usual Split MNIST pairing 0/1, 2/3, ... -
        # confirm against the training script.
        ax[i].set_title(f"Task {i+1} ({2*i} or {2*i+1})")
        ax[i].set_ylabel("Accuracy")
        ax[i].set_xlabel("Tasks")

    fig.legend(lines, labels, loc='upper center', bbox_to_anchor=(0.5, 1.10),
              ncol=max(1, len(experiments)), frameon=False)

    fig.tight_layout()

    if filename:
        # Save as pdf, then release the figure (as plot_experiments does).
        fig.savefig(filename, bbox_inches='tight')
        plt.close(fig)

plot_split_experiments(split_mnist_experiments, "visualisations/split_mnist_experiments_by_task.pdf", 5)

In [None]:
# Runs compared on the Permuted MNIST regression variant, as (legend label,
# results CSV) pairs.
permuted_mnist_reg_experiments = [
    {"name": label, "filename": csv_path}
    for label, csv_path in [
        (
            "VCL",
            "no_coreset_strengthen_reg/permuted_mnist_reg_coreset_0_epochs_20-100_lr_3e-4_init_var_1e-4_strengthen_1_fix_for_real.csv",
        ),
        (
            "VCL + amplification ($\\tau = 1.1$)",
            "no_coreset_strengthen_reg/permuted_mnist_reg_coreset_0_epochs_20-100_lr_3e-4_init_var_1e-4_strengthen_1.1_fix_for_real.csv",
        ),
        (
            "VCL + coreset",
            "coreset_strengthen_reg/permuted_mnist_reg_coreset_200_epochs_20-100_lr_3e-4_init_var_1e-4_strengthen_1_fix_for_real.csv",
        ),
        (
            "VCL + coreset + amplification ($\\tau = 1.2$)",
            "coreset_strengthen_reg/permuted_mnist_reg_coreset_200_epochs_20-100_lr_3e-4_init_var_1e-4_strengthen_1.2_fix_for_real.csv",
        ),
        (
            "Naive",
            "permuted_mnist_reg/permuted_mnist_reg_naive_epochs_20_lr_1e-3_hidden_100.csv",
        ),
        (
            "EWC ($\\lambda = 100$)",
            "permuted_mnist_reg/permuted_mnist_reg_EWC_lambda_100_epochs_20_lr_1e-3_hidden_100_approx_2000.csv",
        ),
        (
            "LP ($\\lambda = 0.1$)",
            "permuted_mnist_reg/permuted_mnist_reg_LP_lambda_0.1_epochs_20_lr_1e-3_hidden_100_approx_2000.csv",
        ),
        (
            "SI ($\\lambda = 0.1$)",
            "permuted_mnist_reg/permuted_mnist_reg_SI_epochs_20_lr_1e-3_lambda_0.1_xi_1e-3.csv",
        ),
    ]
]

plot_experiments(
    experiments=permuted_mnist_reg_experiments,
    filename="visualisations/permuted_mnist_reg_experiments.pdf",
    n_tasks=10,
    ylim=(0.0, 0.35),
    metric="RMSE",
    figsize=(6.5, 2.2),
    outside_legend=True,
)

# Permuted MNIST Telemetry

In [5]:
# Load the telemetry CSV (first column as index) and add a 1-based "n_tasks"
# column derived from that index.
df = (
    pd.read_csv(
        "permuted_mnist_reg_coreset_200_epochs_20-100_lr_1e-3_init_var_1e-4_telemetry.csv",
        index_col=0,
    )
    .assign(n_tasks=lambda telemetry: telemetry.index + 1)
)
df

Unnamed: 0,kl_divergence,nt,logvar_0.01,logvar_0.05,logvar_0.1,logvar_0.25,logvar_0.5,logvar_0.75,logvar_0.9,logvar_0.95,logvar_0.99,logvar_sample_0.01,logvar_sample_0.05,logvar_sample_0.1,logvar_sample_0.25,logvar_sample_0.5,logvar_sample_0.75,logvar_sample_0.9,logvar_sample_0.95,logvar_sample_0.99,weight_logvar_0.01,weight_logvar_0.05,weight_logvar_0.1,weight_logvar_0.25,weight_logvar_0.5,weight_logvar_0.75,weight_logvar_0.9,weight_logvar_0.95,weight_logvar_0.99,n_tasks
0,1025.733032,59800,-13.214473,-12.848153,-12.614787,-12.130378,-11.340388,-10.171642,-7.30296,-4.858655,-1.464109,-13.348482,-12.907009,-12.650014,-12.134521,-11.320284,-10.111753,-7.261482,-4.844698,-1.462484,-10.902492,-9.297417,-9.190764,-8.999,-8.180964,-4.733675,-2.226076,-2.226076,-2.226076,1
1,404.086426,59800,-2.739844,-2.291058,-2.058115,-1.672338,-1.238763,-0.750275,-0.305368,-0.155527,0.13315,-4.266385,-3.381829,-2.975412,-2.358928,-1.717609,-1.095165,-0.453719,-0.091774,0.674053,-10.899358,-9.297994,-9.19103,-8.998678,-8.180174,-4.736662,-2.230999,-2.226076,-2.220692,2
2,469.843842,59800,-9.93003,-8.578997,-7.73653,-6.368092,-4.775432,-2.373379,-1.156059,-0.576215,0.317234,-10.804141,-9.099922,-8.11902,-6.560419,-4.785201,-2.448761,-1.161144,-0.547402,0.526286,-10.929689,-9.301622,-9.192933,-8.999923,-8.181121,-4.742733,-2.250977,-2.229248,-2.212143,3
3,752.552185,59800,-11.05621,-9.791208,-8.867528,-7.358881,-5.548737,-2.824545,-1.233064,-0.599703,0.193605,-12.225459,-10.489489,-9.430038,-7.653301,-5.598947,-2.925893,-1.270214,-0.609745,0.340695,-10.978423,-9.30538,-9.195692,-9.000999,-8.182693,-4.74624,-2.277544,-2.239425,-2.210519,4
4,833.392944,59800,-12.31625,-10.935529,-10.072799,-8.562836,-6.380282,-3.322545,-1.548571,-0.666036,0.452706,-13.38793,-11.617848,-10.557641,-8.725464,-6.313477,-3.315763,-1.509033,-0.680852,0.561372,-11.019742,-9.308682,-9.197462,-9.001947,-8.183541,-4.749356,-2.298354,-2.251477,-2.210664,5
5,1247.96875,59800,-12.36267,-10.761329,-9.852358,-8.338249,-6.208897,-3.06968,-1.178706,-0.494615,0.365401,-13.572958,-11.689734,-10.568143,-8.680516,-6.199579,-3.121244,-1.241519,-0.511766,0.491522,-11.048819,-9.311558,-9.199407,-9.002887,-8.185619,-4.753814,-2.315789,-2.261422,-2.21312,6
6,1352.182373,59800,-12.833796,-10.894087,-9.915742,-8.301992,-6.058897,-3.063347,-1.275443,-0.604287,0.107875,-13.901664,-11.739895,-10.49492,-8.49045,-5.998794,-3.101852,-1.341538,-0.631176,0.249833,-11.073164,-9.313547,-9.199894,-9.003615,-8.186623,-4.756352,-2.329983,-2.270302,-2.215938,7
7,1738.232056,59800,-13.016774,-11.157318,-10.184424,-8.697143,-6.456819,-3.256651,-1.09917,-0.506622,0.267181,-13.774371,-11.786106,-10.646578,-8.839869,-6.367016,-3.264037,-1.204004,-0.505337,0.470083,-11.090906,-9.317068,-9.200944,-9.003911,-8.188776,-4.760313,-2.348584,-2.28135,-2.220403,8
8,2308.476074,59800,-13.483713,-11.335684,-10.237982,-8.576843,-6.19698,-3.084686,-1.116185,-0.541687,0.390916,-14.409443,-12.076397,-10.769348,-8.72947,-6.115722,-3.097612,-1.250346,-0.599528,0.557979,-11.108126,-9.319154,-9.202432,-9.004732,-8.188422,-4.761469,-2.362373,-2.289928,-2.223845,9
9,2303.439941,59800,-13.642403,-11.270894,-10.321984,-8.40469,-6.001174,-2.94242,-1.213159,-0.628926,0.058002,-14.393868,-11.899479,-10.677003,-8.586313,-5.999774,-2.999852,-1.296402,-0.679523,0.150457,-11.125982,-9.321228,-9.203913,-9.004635,-8.189644,-4.765498,-2.375914,-2.301266,-2.228275,10


In [6]:
# Variance (exp of the log-variance) in the "logvar_0.01" column of the first row.
np.exp(df["logvar_0.01"].iloc[0])

1.824010388347273e-06

## Plot aleatoric uncertainty centiles

In [9]:
palette = sns.color_palette()
band_color = palette[0]
tasks_seen = df["n_tasks"]

centile_fig, centile_ax = plt.subplots(figsize=(4, 3))

# Line for the 0.5 column, shaded bands for the 0.25-0.75 and 0.01-0.99 columns.
# The columns hold log-variances, hence the exp().
centile_ax.plot(tasks_seen, np.exp(df["logvar_0.5"]), color=band_color)
for lower_col, upper_col in (
    ("logvar_0.25", "logvar_0.75"),
    ("logvar_0.01", "logvar_0.99"),
):
    centile_ax.fill_between(
        tasks_seen,
        np.exp(df[lower_col]),
        np.exp(df[upper_col]),
        alpha=0.2,
        color=band_color,
    )

centile_ax.set_yscale("log")
centile_ax.set_ylabel("Aleatoric variance")
centile_ax.set_xlabel("Number of tasks seen")

centile_fig.tight_layout()

centile_fig.savefig("visualisations/permuted_mnist_aleatoric_centiles.pdf")

In [10]:
palette = sns.color_palette()
band_color = palette[0]
tasks_seen = df["n_tasks"]

weight_fig, weight_ax = plt.subplots(figsize=(4, 3))

# Line for the 0.5 column, shaded bands for the 0.25-0.75 and 0.01-0.99 columns.
# The columns hold log-variances, hence the exp().
weight_ax.plot(tasks_seen, np.exp(df["weight_logvar_0.5"]), color=band_color)
for lower_col, upper_col in (
    ("weight_logvar_0.25", "weight_logvar_0.75"),
    ("weight_logvar_0.01", "weight_logvar_0.99"),
):
    weight_ax.fill_between(
        tasks_seen,
        np.exp(df[lower_col]),
        np.exp(df[upper_col]),
        alpha=0.2,
        color=band_color,
    )

weight_ax.set_yscale("log")
weight_ax.set_ylabel("Weight variance (log)")
weight_ax.set_xlabel("Number of tasks seen")

weight_fig.tight_layout()

weight_fig.savefig("visualisations/permuted_mnist_weight_var_centiles.pdf")