In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

In [None]:
def read_curves(filename):
    """Load one exported CSV and return its curves, one row per series.

    The CSV is transposed so that each original column becomes a row. Rows
    whose label matches ``MIN``, ``MAX`` or ``_step`` are discarded. The first
    remaining row is cast to ``int`` and used as the column labels, then
    removed from the data. Finally every cell equal to ``0.0`` is replaced by
    ``NaN``.

    Parameters
    ----------
    filename
        Anything accepted by ``pandas.read_csv`` (path or file-like object).

    Returns
    -------
    pandas.DataFrame
        Rows are the kept series, columns are the integer labels taken from
        the first kept row.
    """
    table = pd.read_csv(filename).T

    # Drop the auxiliary rows (labels containing MIN, MAX or _step).
    is_auxiliary = table.index.str.contains("MIN|MAX|_step")
    table = table[~is_auxiliary]

    # The first remaining row holds the labels for the columns.
    header = table.iloc[0].astype(int)
    table.columns = header
    table = table.iloc[1:]

    # Cells that are exactly 0.0 become NaN.
    return table.replace(0.0, np.nan)

In [None]:
def filename_to_keys(filename):
    """Parse ``key=value`` tokens out of a file name.

    The stem of ``filename`` (the name without its final suffix) is split on
    ``_``; every token that contains ``=`` contributes one entry. Each value is
    converted to ``int`` if possible, otherwise to ``float``, otherwise it is
    kept as a string. If a key occurs more than once, the last token wins.

    Parameters
    ----------
    filename : str or os.PathLike
        Path or bare file name to parse.

    Returns
    -------
    dict
        Mapping from key to the parsed ``int`` / ``float`` / ``str`` value.
    """
    parsed = {}
    for token in Path(filename).stem.split("_"):
        if "=" not in token:
            continue
        # Split on the first "=" only, so a value that itself contains "="
        # is kept whole instead of being truncated.
        key, value = token.split("=", 1)
        # Try the narrowest numeric type first; fall back to the raw string.
        for cast in (int, float):
            try:
                value = cast(value)
                break
            except ValueError:
                continue
        parsed[key] = value
    return parsed

In [None]:
# Collect every CSV in the working directory and order the files by the
# (concurrent, sparsity) values encoded in their names, ascending.
filenames = sorted(
    Path("./").glob("*.csv"),
    key=lambda fname: tuple(
        filename_to_keys(fname)[field] for field in ("concurrent", "sparsity")
    ),
)

# One path per line.
print("\n".join(map(str, filenames)))

In [None]:
def plot_lines(
    filename,
    fig=None,
    ax=None,
    base_epochs=60,
):
    """Plot every curve of one CSV file plus their mean and a SEM band.

    Each row returned by ``read_curves`` is drawn as a faint line (green when
    the file stem contains ``*``, blue otherwise). The column-wise mean is
    drawn as a dashed black line and the standard error of the mean as a
    shaded band around it. A text box shows the ``concurrent`` and
    ``sparsity`` values parsed from the file name.

    Parameters
    ----------
    filename : pathlib.Path or str
        CSV file to plot. Its stem must contain ``concurrent=<value>`` and
        ``sparsity=<value>`` tokens.
    fig : matplotlib.figure.Figure, optional
        Figure that owns ``ax``. Ignored when ``ax`` is not given.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. When omitted a new 20x7 figure is created. When given
        without ``fig``, the axes' own figure is used.
    base_epochs : int or float, default 60
        Numerator of ``base_epochs / (1 - sparsity)``, the value used to place
        the text box and to bound the minor x ticks.
        NOTE(review): presumably the epoch budget of a dense (sparsity 0) run;
        confirm against the training setup.

    Raises
    ------
    KeyError
        If the file name lacks a ``concurrent`` or ``sparsity`` key.
    ZeroDivisionError
        If ``sparsity`` equals 1.
    """
    path = Path(filename)
    curves = read_curves(path)

    # read keys of the form key=value from the filename
    keys = filename_to_keys(path)
    concurrent = keys["concurrent"]
    sparsity = keys["sparsity"]
    epochs = base_epochs / (1 - sparsity)

    print(f"found {len(curves)} curves in {filename}")

    # Resolve the drawing target once; everything below uses `fig` / `ax`.
    if ax is None:
        fig, ax = plt.subplots(figsize=(20, 7))
    elif fig is None:
        # Honour a caller-supplied axes even when its figure was not passed.
        fig = ax.figure

    # Plot individual time-series
    color = "g" if "*" in path.stem else "b"
    for row in range(len(curves)):
        sns.lineplot(data=curves.iloc[row, :], alpha=0.3, color=color, ax=ax)

    ax.set_yticks(np.arange(0.55, 1.01 - 1e-6, 0.05))
    ax.set_ylim(0.55, 1.01)

    # Plot mean time-series with error area
    mean_curve = curves.mean()
    # DataFrame.sem() divides by the per-column count of non-NaN values, which
    # matters because read_curves turns 0.0 cells into NaN.
    sem_curve = curves.sem()
    sns.lineplot(
        data=mean_curve,
        color="k",
        linewidth=2,
        linestyle="--",
        label="Mean",
        ax=ax,
    )

    ax.text(
        epochs / 2,  # x position
        0.65,
        f"roles={concurrent}, {sparsity=}",
        size=20,
        rotation=0.0,
        ha="center",
        va="center",
        bbox=dict(
            boxstyle="round",
            ec=(1.0, 0.5, 0.5),
            fc=(1.0, 0.8, 0.8),
        ),
    )

    # Use the mean's own index for x so x and y always have matching lengths,
    # independent of how the line above was rendered.
    mean_values = mean_curve.to_numpy(dtype=float)
    sem_values = sem_curve.to_numpy(dtype=float)
    ax.fill_between(
        mean_curve.index.to_numpy(),
        mean_values + sem_values,
        mean_values - sem_values,
        alpha=0.2,
        color="k",
        label="SEM",
    )

    # Minor ticks: every integer from the first column label up to `epochs`.
    ax.set_xticks(
        np.arange(
            curves.columns[0],
            epochs,
        ),
        minor=True,
    )
    for tick in ax.xaxis.get_minor_ticks():
        tick.label1.set_verticalalignment("bottom")
        tick.label1.set_horizontalalignment("center")

    # Major ticks: every fifth column label.
    ax.set_xticks(curves.columns[0::5], minor=False)

    ax.grid(which="major", linestyle="-", linewidth=0.5, alpha=0.3, color="gray")
    ax.set_title(str(filename))
    ax.set_xlabel("Training epoch")
    ax.set_ylabel("Held-out accuracy")
    ax.legend()
    # Lay out the figure actually drawn on, not whatever pyplot considers current.
    fig.tight_layout()

In [None]:
# Draw one panel per file on a grid with three panels per row.
n_cols = 3
# Keep the 4x3 layout as a minimum, but add rows when there are more than
# 12 files so indexing into the grid cannot run out of axes.
n_rows = max(4, -(-len(filenames) // n_cols))  # ceil division
fig, ax = plt.subplots(n_rows, n_cols, figsize=(25, 15 * n_rows / 4))

for i, f in enumerate(filenames):
    print(f)
    plot_lines(f, fig=fig, ax=ax[i // n_cols, i % n_cols])

# share x and y axes across all subplots
# reference = (0, 2)
# for a in ax.flat:
#     a.label_outer()  # Hide x labels and tick labels for top plots and y ticks for right plots
#     a.set_xticks(ax[reference].get_xticks())
#     a.set_yticks(ax[reference].get_yticks())
#     a.set_xlim(ax[reference].get_xlim())
#     a.set_ylim(ax[reference].get_ylim())

fig.tight_layout()
plt.show()


In [None]:
# In this section we load the curves from every file and flatten them into a single
# long-format dataframe: each row is one snapshot of one run's time series, and each
# column is a data header such as accuracy, sparsity, concurrent roles, etc.
# The epochs of each file are normalized by that file's final epoch, so the curves are
# comparable across sparsity levels (which trained for different numbers of epochs but
# effectively the same number of gradient updates).
# Individual seeds are kept as separate rows; averaging across seeds is left to the plots.

In [None]:
# Peek at the first column of the first file's curves.
read_curves(filenames[0]).iloc[:, 0]

In [None]:
# Flatten every file's curves into long format: one record per (run, epoch).
flattened_data = []

curves = [*map(read_curves, filenames)]
for fname, run_curves in zip(filenames, curves):
    # replace cell values == 0.0 with NaN; this also yields a fresh frame, so
    # relabelling its columns below leaves the frames stored in `curves` untouched
    run_curves = run_curves.replace(0.0, np.nan)

    # normalize the epoch labels by this file's last epoch so they end at 1
    actual_epochs = run_curves.columns.to_list()
    max_epoch = run_curves.columns[-1]
    run_curves.columns = run_curves.columns / max_epoch

    # metadata encoded in the file name (parsed once per file)
    keys = filename_to_keys(fname)
    concurrent = keys["concurrent"]
    sparsity = keys["sparsity"]
    sweep_id = fname.stem.split("_")[0]

    # add each cell as a datapoint with the corresponding metadata columns;
    # `col_pos` is kept distinct from any outer loop variable
    for col_pos, (epoch_norm, acc_series) in enumerate(run_curves.items()):
        for run_id, acc in acc_series.items():
            flattened_data.append(
                {
                    "epoch_norm": epoch_norm,
                    "epoch": actual_epochs[col_pos],
                    "accuracy": acc,
                    # keep only the first whitespace-separated token of the row label
                    "run_id": run_id.split()[0],
                    "sparsity": sparsity,
                    "concurrent": concurrent,
                    "sweep_id": sweep_id,
                }
            )

datapoints = pd.DataFrame(flattened_data)
datapoints

In [None]:
# Plot every datapoint; swap in one of the filters below to look at a subset.
# view = datapoints[datapoints["sparsity"] == 0]
# view = datapoints[datapoints["concurrent"] == 3]
view = datapoints

# One panel per sparsity value, one hue per `concurrent` value.
g = sns.FacetGrid(
    data=view,
    col="sparsity",
    hue="concurrent",
    palette=sns.color_palette("rocket", n_colors=4),
    ylim=(0.65, 1.01),
    height=4,
    aspect=1.6,
    gridspec_kws=dict(hspace=0.1, wspace=0.1),
)

# Accuracy against normalized epoch, with a standard-error band.
g.map_dataframe(
    sns.lineplot,
    x="epoch_norm",
    y="accuracy",
    errorbar="se",
    err_kws=dict(alpha=0.2, linewidth=0),
    alpha=0.9,
    markers=True,
    dashes=False,
    markersize=5,
    markeredgewidth=0,
)
g.add_legend()

# Light major grid on every panel.
for ax in g.axes.flat:
    ax.grid(which="major", linestyle="-", linewidth=0.5, alpha=0.4, color="grey")

g.tight_layout()
plt.show()

In [None]:
# compute the area-under-the-curve (AUC) for each run and make a bar plot

# np.trapezoid only exists on NumPy >= 2.0; fall back to np.trapz on older versions.
trapezoid = getattr(np, "trapezoid", None) or np.trapz

# One accuracy value per (sparsity, concurrent, epoch_norm, run_id).
# fillna(0) means a missing accuracy counts as 0 in the integral.
accuracy_per_epoch = (
    view.fillna(0)
    .groupby(["sparsity", "concurrent", "epoch_norm", "run_id"])
    .agg({"accuracy": "mean"})
    .reset_index()
)

# Integrate accuracy over epoch_norm for each run. groupby sorted the rows above,
# so epoch_norm is ascending within every run. Selecting the two needed columns
# keeps the grouping columns out of `apply`, which newer pandas versions warn about.
auc = (
    accuracy_per_epoch.groupby(["sparsity", "concurrent", "run_id"])[
        ["epoch_norm", "accuracy"]
    ]
    .apply(lambda run: trapezoid(run["accuracy"], run["epoch_norm"]))
    .reset_index(name="auc")
)

fig, ax = plt.subplots(figsize=(7, 4))
ax.grid(which="major", linestyle="-", linewidth=0.5, alpha=0.4, color="grey")

sns.barplot(
    data=auc,
    x="sparsity",
    y="auc",
    hue="concurrent",
    errorbar="se",
    err_kws={
        "alpha": 0.9,
        "linewidth": 2,
        "marker": ",",
        "linestyle": "-",
        "color": "k",
    },
    palette=sns.color_palette("rocket", n_colors=4),
    ax=ax,
)

ax.set_ybound(0.65, 1.01)