In [None]:
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns
from pathlib import Path
import numpy as np

In [None]:
# use wandb cli to get training curve data from wandb
# wandb artifact get wm-comp-limit-1/run-jk2dpehx-history:v0
# pd.read_parquet('../../artifacts/run-jk2dpehx-history:v0/0000.parquet')

In [None]:
# basic experiment
# filename1 = "./84m62mn1,n_reg=2,concurrent_reg=2,eval_acc.csv"
# filename2 = "./x3ef5jj1,n_reg=3,concurrent_reg=3,eval_acc.csv"
# filename3 = "./zctyznwg,n_reg=3,concurrent_reg=2,eval_acc.csv"

# n_reg = 50
# filename1 = "./shwv3iys,n_reg=50,concurrent_reg=2,seq_len=20,eval_acc.csv"
# filename2 = "./8qejrp9n,n_reg=50,concurrent_reg=3,seq_len=20,eval_acc.csv"
# filename3 = "./x0wla7ce,n_reg=50,concurrent_reg=4,seq_len=20,eval_acc.csv"


# basic experiment but with n_train=10_000
# filename1 = "rb7u44gj,n_reg=2,concurrent_reg=2,n_train=10_000,eval_acc.csv"
# filename2 = "gv8oc4n3,n_reg=3,concurrent_reg=2,n_train=10_000,eval_acc.csv"
# filename3 = "7dbv1sd9,n_reg=3,concurrent_reg=3,n_train=10_000,eval_acc.csv"

#
# Gather the CSV exports to plot. Only one glob pattern is live at a time; the
# commented-out patterns below select other experiment families.
filenames = []
# filenames.extend(Path("./").glob("*_5_*,eval_acc.csv"))
# filenames.extend(Path("./").glob("*_5_5,eval_acc.csv"))
# filenames.extend(Path("./").glob("*_4_4,eval_acc.csv"))
# filenames.extend(Path("./").glob("*_3_3,eval_acc.csv"))
# filenames.extend(Path("./").glob("*_2_2,eval_acc.csv"))
# filenames.extend(Path("./").glob("*_2_2,eval_acc.csv"))
# filenames.extend(Path("./").glob("*_5_2,eval_acc.csv"))

# filenames.extend(Path("./").glob("*_3_3,eval_acc.csv"))
# filenames.extend(Path("./").glob("*_4_4,eval_acc.csv"))

# filenames.extend(Path("./").glob("*split_set*.csv"))
# filenames.extend(Path("./").glob("100_2-4---64_swept/*.csv"))
# filenames.extend(Path("./").glob("100_*_post_training/*.csv"))
# filenames.extend(Path("./").glob("100_*_split_set_control/*.csv"))
filenames.extend(Path("./").glob("100_*_4items_swept/*.csv"))
# Order by the full path string so panel/colour assignment is stable.
filenames.sort(key=str)

# Uncomment for post-training curves.
# Prepending None keeps the conditions aligned across all plots: the
# post-training data has no curve for the 2-concurrent-regs condition.
# filenames = [None, *filenames]

# One selected path per line (an empty selection still prints a blank line).
print("\n".join(str(path) for path in filenames))


def read_curves(filename, offset=420):
    """Load a CSV of training curves as a frame with one row per curve.

    The CSV is transposed so that each CSV column becomes a row. Rows whose
    label contains ``MIN``, ``MAX`` or ``_step`` are discarded. The first
    remaining row (presumably the logging-step column of a wandb export --
    TODO confirm) is cast to ``int`` and used as the column labels, then
    removed from the data.

    Args:
        filename: Path (or anything ``pd.read_csv`` accepts) of the CSV file.
        offset: Currently unused; it was only referenced by a disabled
            adjustment that shifted the step labels of post-training runs.

    Returns:
        DataFrame indexed by the surviving CSV column names, with integer
        step values as column labels.
    """
    by_series = pd.read_csv(filename).T

    # Drop the auxiliary min/max/step rows; the pattern is a regex alternation.
    is_auxiliary = by_series.index.str.contains("MIN|MAX|_step")
    kept = by_series[~is_auxiliary]

    # Promote the first remaining row to the column header ...
    step_labels = kept.iloc[0].astype(int)
    kept.columns = step_labels

    # ... and return everything below it.
    return kept.iloc[1:]

In [None]:
def plot_lines(
    filename,
    measurement_epoch=100,
    ulimit=900,
    fig=None,
    ax=None,
    percent_perfect=False,
):
    """Plot the curves stored in one CSV file on a single axes.

    In the default mode every individual curve is drawn faintly, together
    with the mean curve and a shaded band of +/- one standard error of the
    mean. With ``percent_perfect=True`` only the per-step count of curves
    whose value is >= 0.999 is drawn.

    Args:
        filename: ``pathlib.Path`` of the CSV (``.stem`` is used for the
            annotation box, and the second ``/``-separated component of the
            path string for the title).
        measurement_epoch: Unused; kept for call compatibility.
        ulimit: Unused; kept for call compatibility.
        fig: Existing figure to draw into. A new figure/axes pair is created
            unless both ``fig`` and ``ax`` are given.
        ax: Existing axes to draw into (see ``fig``).
        percent_perfect: Plot the number of "perfect" curves per step instead
            of the raw curves and their mean.

    Returns:
        None. The axes is modified in place and a summary line is printed.
    """
    curves = read_curves(filename)
    print(f"found {len(curves)} curves in {filename}")

    f, a = fig, ax
    if f is None or a is None:
        f, a = plt.subplots(figsize=(20, 7))

    if not percent_perfect:
        # Individual time-series, drawn faintly behind the mean.
        for i in range(len(curves)):
            sns.lineplot(data=curves.iloc[i, :], alpha=0.3, color="b", ax=a)
        a.set_yticks(np.arange(0.50, 1.01 - 1e-6, 0.05))
        a.set_ylim(0.55, 1.01)

    # Summary curve: count of perfect runs per step, or the mean across runs.
    if percent_perfect:
        summary = (curves >= 0.999).sum(axis=0)
    else:
        summary = curves.mean()
    sns.lineplot(
        data=summary,
        color="k",
        linewidth=2,
        linestyle="--",
        label="Mean",
        ax=a,
    )

    # Annotation box showing the second underscore-separated token of the file
    # stem (presumably the register count -- TODO confirm). The position is
    # hard-coded in data coordinates.
    n_reg = filename.stem.split("_")[1]
    a.text(
        58,
        0.75,
        f"{n_reg}",
        size=20,
        rotation=0.0,
        ha="center",
        va="center",
        bbox=dict(
            boxstyle="round",
            ec=(1.0, 0.5, 0.5),
            fc=(1.0, 0.8, 0.8),
        ),
    )

    if not percent_perfect:
        # Standard error of the mean across runs. The x-values come from the
        # mean curve itself: the first line on the axes is an individual run,
        # and the axes may already hold lines from an earlier call.
        sem = curves.std() / curves.shape[0] ** 0.5
        a.fill_between(
            summary.index.to_numpy(),
            (summary + sem).to_numpy(dtype=float),
            (summary - sem).to_numpy(dtype=float),
            alpha=0.2,
            color="k",
            label="SEM",
        )

    epochs = 60
    # NOTE(review): the per-epoch step argument of np.arange is disabled, so
    # this places a minor tick at every integer step between the first and
    # last column rather than one per epoch -- confirm this is intended.
    a.set_xticks(
        np.arange(
            curves.columns[0],
            curves.columns[-1] + curves.columns[-1] / epochs,
        ),
        minor=True,
    )
    for tick in a.xaxis.get_minor_ticks():
        tick.label1.set_verticalalignment("bottom")
        tick.label1.set_horizontalalignment("center")

    a.set_xticks(curves.columns[0::5], minor=False)

    a.grid(which="major")
    a.grid(which="minor", linestyle="--", linewidth=0.5, alpha=0.2, color="red")
    a.set_title(str(filename).split("/")[1])
    a.set_xlabel("Training epoch")
    # Label the resolved axes `a` (the `ax` argument may be None) and describe
    # what is actually plotted in each mode.
    a.set_ylabel(
        "Proportion/number of runs that reach perfect accuracy"
        if percent_perfect
        else "Held-out accuracy"
    )
    a.legend()
    # Lay out the figure that owns `a`, not whichever figure is current.
    f.tight_layout()

In [None]:
# One panel per selected CSV, three panels per row. The grid keeps the
# original 2x3 layout (25x10 inches) for up to six files and grows by whole
# rows beyond that, so indexing the axes array can no longer go out of bounds.
N_COLS = 3
n_rows = max(2, -(-len(filenames) // N_COLS))  # ceiling division
fig, ax = plt.subplots(n_rows, N_COLS, figsize=(25, 5 * n_rows), squeeze=False)
for i, f in enumerate(filenames):
    # A None entry is a placeholder that keeps conditions in the same panel
    # position across plots; its panel is left empty.
    if f:
        print(f)
        plot_lines(f, fig=fig, ax=ax[i // N_COLS, i % N_COLS], percent_perfect=False)
fig.tight_layout()
plt.show()


In [None]:
def plot_means(
    filenames: list,
    measurement_epoch: int = 950,
    exclude_initial: int = 0,
    percent_perfect: bool = False,
    perfect_threshold: float = 0.999,
):
    """Overlay one summary curve per CSV file on a shared axes.

    For each file the curves are loaded with ``read_curves`` and restricted to
    the column positions ``exclude_initial .. measurement_epoch`` (inclusive,
    positional -- not step values). The default mode draws the mean across
    runs with a shaded +/- SEM band. With ``percent_perfect=True`` it draws
    the fraction of runs whose value is >= ``perfect_threshold``, shifted by a
    small per-file jitter so overlapping lines stay distinguishable.

    Args:
        filenames: CSV paths, one per condition. ``None`` entries are skipped
            but still consume a colour slot, so colours stay aligned between
            plots that lack a condition. Duplicate paths are plotted once.
        measurement_epoch: Last column position to include.
        exclude_initial: First column position to include.
        percent_perfect: Plot the fraction of perfect runs instead of the mean.
        perfect_threshold: Value at or above which a run counts as perfect;
            only used when ``percent_perfect`` is true.

    Returns:
        None. A new figure is created and populated.
    """
    f, a = plt.subplots(figsize=(14, 6))
    if not percent_perfect:
        a.set_ylim(0.4, 1.01)
    else:
        a.set_ylim(-0.05, 0.50)
        a.set_yticks(np.arange(0.0, 1.01, 0.05))

    # Discrete colour map; wrap around if there are more files than colours.
    cmap = plt.get_cmap("tab10")
    # Local, seeded generator: the jitter is reproducible between runs and
    # does not consume the global NumPy random state.
    jitter_rng = np.random.default_rng(0)

    curves = None  # columns of the last plotted file drive the x ticks below
    # dict.fromkeys drops duplicate paths while preserving order.
    for i, filename in enumerate(dict.fromkeys(filenames)):
        if filename is None:
            continue
        curves = read_curves(filename, offset=0).iloc[
            :, exclude_initial : measurement_epoch + 1
        ]
        color = cmap(i % cmap.N)

        if percent_perfect:
            n_perfect = (curves >= perfect_threshold).sum(axis=0)
            # One constant offset per file, applied to the count before
            # normalising by the number of runs.
            jitter = jitter_rng.uniform(-2e-1, 2e-1)
            summary = (n_perfect + jitter) / curves.shape[0]
        else:
            summary = curves.mean()

        sns.lineplot(
            data=summary,
            linewidth=2,
            linestyle="-",
            label=f"{(str(filename).split('/')[1][: -4 - 8])} (shading: SEM)",
            color=color,
            alpha=0.7,
            ax=a,
        )

        if not percent_perfect:
            # +/- one standard error of the mean, using this file's own x
            # values rather than those of the first line on the axes.
            sem = curves.std() / curves.shape[0] ** 0.5
            a.fill_between(
                summary.index.to_numpy(),
                (summary + sem).to_numpy(dtype=float),
                (summary - sem).to_numpy(dtype=float),
                alpha=0.1,
                color=color,
            )

    handles, labels = a.get_legend_handles_labels()
    pretraining_handles = [
        h for h, l in zip(handles, labels) if "post_training" not in l
    ]
    pretraining_labels = [l for l in labels if "post_training" not in l]

    legend1 = a.legend(
        pretraining_handles,
        pretraining_labels,
        title="Training from scratch",
    )
    # Registered as a separate artist so a second legend (e.g. for
    # post-training curves) could be added without replacing this one.
    a.add_artist(legend1)

    epochs = 60
    if curves is not None and curves.shape[1] > 0:
        first_step, last_step = curves.columns[0], curves.columns[-1]
        # Assumes the last plotted column corresponds to epoch `epochs`
        # -- TODO confirm when measurement_epoch truncates the curves.
        steps_per_epoch = last_step / epochs
        if steps_per_epoch > 0:
            minor_ticks = np.arange(
                first_step, last_step + steps_per_epoch, steps_per_epoch
            )
            a.set_xticks(minor_ticks, minor=True)
            # Derive one label per tick so the label count always matches
            # the tick count; each label is the nearest epoch number.
            a.set_xticklabels(
                np.round(minor_ticks / steps_per_epoch).astype(int),
                minor=True,
                color="red",
            )
            for tick in a.xaxis.get_minor_ticks():
                tick.label1.set_verticalalignment("bottom")
                tick.label1.set_horizontalalignment("center")

        a.set_xticks(curves.columns[::50], minor=False)

    a.grid(which="major")
    a.grid(which="minor", linestyle="--", linewidth=0.5, alpha=0.2, color="red")
    a.set_xlabel("Batch (nearest Epoch in red)")

    if not percent_perfect:
        a.set_ylabel("Held-out accuracy")
        a.set_title("Mean accuracy across 20 random seeds colored by condition")
    else:
        a.set_ylabel("Proportion of runs that reach perfect accuracy")
        a.set_title(
            "Proportion of runs that reach perfect accuracy across 20 random seeds colored by condition"
        )

    # Run the layout last so it accounts for the labels and title set above.
    f.tight_layout()


# Overlay the summary curve of every selected file on one axes.
# NOTE(review): plot_means slices columns positionally as
# [exclude_initial : measurement_epoch + 1], so measurement_epoch=0 keeps only
# the first column of each file -- confirm this is intended and not a leftover.
# perfect_threshold is only consulted when percent_perfect=True.
plot_means(
    filenames,
    measurement_epoch=0,
    exclude_initial=0,
    percent_perfect=False,
    perfect_threshold=0.99,
)