In [None]:
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
from tensorboard.backend.event_processing import tag_types
from tensorboard.plugins.hparams import plugin_data_pb2
from google.protobuf.json_format import MessageToDict
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

In [None]:
# Apply seaborn's default theme globally so every figure below shares one look.
sns.set()

In [None]:
# One accumulator per run of the sweep (lightning_logs versions 98..115).
# A size guidance of 0 for scalars asks TensorBoard to keep every logged
# point instead of subsampling them.
readers = [
    EventAccumulator(
        f"../lightning_logs/version_{version}",
        size_guidance={tag_types.SCALARS: 0},
    )
    for version in range(98, 116)
]

# Parse the event files from disk; nothing is readable before Reload().
for reader in readers:
    reader.Reload()

In [None]:
# Legend label -> TensorBoard scalar tag used for the weight-decay sweep plot.
tags = dict(
    Train='train/accuracy',
    Test='val/accuracy/dataloader_idx_1',
)

In [None]:
def get_hparams(reader):
    """Return the hyperparameters stored with a run as a plain dict.

    Parameters
    ----------
    reader
        Object exposing ``summary_metadata`` (here: a reloaded
        ``EventAccumulator``) that contains the
        ``'_hparams_/session_start_info'`` entry.

    Returns
    -------
    dict
        Hyperparameter name -> ``MessageToDict`` conversion of its stored
        protobuf value.

    Notes
    -----
    A run without the ``'_hparams_/session_start_info'`` entry is not handled
    here; the lookup error propagates to the caller.
    """
    serialized = reader.summary_metadata['_hparams_/session_start_info'].plugin_data.content
    session_info = plugin_data_pb2.HParamsPluginData.FromString(serialized).session_start_info
    return {key: MessageToDict(value) for key, value in session_info.hparams.items()}

In [None]:
# Build one row per (run, metric): the run's weight-decay coefficient and the
# mean of the last 50 logged values of that metric.
records = []

for reader in readers:
    weight_decay = get_hparams(reader)['l2_norm']
    records.extend(
        {
            "Weight Decay λ": weight_decay,
            "Accuracy": np.mean([event.value for event in reader.Scalars(tag)[-50:]]),
            "Metric": metric_name,
        }
        for metric_name, tag in tags.items()
    )

df = pd.DataFrame(records)

In [None]:
# Final accuracy as a function of the weight-decay coefficient (log x-axis).
fig, ax = plt.subplots(figsize=(10,6))
sns.lineplot(
    data=df,
    x="Weight Decay λ",
    y="Accuracy",
    hue="Metric",
    marker="o",
    alpha=0.7,
    palette="mako_r",
    ax=ax,
)
ax.set_xscale("log")
# Major ticks at 1x and 3x of every power of ten.
ax.xaxis.set_major_locator(ticker.LogLocator(base=10, subs=(1, 3)))

# Dashed vertical marker at λ = 0.3, coloured from the same palette as the lines.
threshold_color = sns.color_palette("mako_r", as_cmap=True)(0.8, 0.8)
ax.axvline(x=0.3, ymin=0, ymax=1, linestyle='--', color=threshold_color)
ax.text(
    0.26, 0.2, "Interpolation Threshold",
    ha='right',
    fontsize='large',
    fontweight='ultralight',
    color='black',
)

fig.savefig("static/lambda_vs_accuracy.png", dpi=300, bbox_inches="tight")

In [None]:
# Column label -> TensorBoard scalar tag for the single-run training curves.
# Insertion order matters: the plotting cell takes the first two entries for
# its first two panels.
train_tags = dict([
    ('Loss', 'train/loss'),
    ('Weight norm', 'train/l2_norm'),
    ('Train accuracy', 'train/accuracy'),
    ('Test accuracy', 'val/accuracy/dataloader_idx_1'),
])

In [None]:
# Training curves of the first run (readers[0]), one column per entry of
# `train_tags` plus a shared "Step" column.
#
# Each metric is indexed by the steps at which *it* was logged and the columns
# are then outer-joined on that index. Previously the "Step" column was simply
# overwritten by every tag, so the frame silently used the steps of the last
# tag and broke (or misaligned rows) whenever tags were logged at different
# steps. Steps missing for a metric now show up as NaN in its column.
curves = {}
for name, tag in train_tags.items():
    scalars = readers[0].Scalars(tag)
    series = pd.Series(
        [scalar.value for scalar in scalars],
        index=[scalar.step for scalar in scalars],
        dtype=float,
    )
    # Keep only the last value of a step that was logged more than once so
    # the step index is unique and the join below is well defined.
    curves[name] = series[~series.index.duplicated(keep="last")]

df2 = (
    pd.concat(curves, axis=1)   # dict keys become the column names
    .sort_index()
    .rename_axis("Step")
    .reset_index()
)

In [None]:
# Display the training-curve frame as a quick sanity check before plotting.
df2

In [None]:
# Three stacked panels sharing the step axis:
#   0: training loss, 1: weight norm, 2: train vs. test accuracy.
fig, axes = plt.subplots(3, 1, sharex=True, figsize=(16,10))

# Smallest logged step at which training accuracy exceeds 99.9%; it splits the
# run into the two phases annotated below.
first = df2[df2["Train accuracy"] > 0.999]["Step"].min()
max_step = 30000

marker_color = sns.color_palette("mako_r", as_cmap=True)(0.8, 0.8)
for i, ax in enumerate(axes):
    ax.set(xlim=[0, max_step])
    # ymax is given as a fraction of the axes height. On the two lower panels
    # it overshoots the top (1.2, with clipping disabled) so the dashed marker
    # bridges the gap to the panel above; the top panel stops at its own edge.
    ax.axvline(
        first, 0, 1 if i == 0 else 1.2,
        linestyle='--',
        color=marker_color,
        zorder=0, clip_on=False
    )

# Loss and weight norm are single lines. No `palette` is passed here: without
# a `hue` variable seaborn ignores it (and warns in recent versions).
for i, col in enumerate(list(train_tags.keys())[:2]):
    sns.lineplot(data=df2, x="Step", y=col, ax=axes[i])

# Long format (one row per step and metric) so both accuracies share a panel
# and are told apart by hue.
accuracy_long = df2[["Step", "Train accuracy", "Test accuracy"]].melt(
    id_vars=["Step"], var_name="Metric", value_name="Accuracy"
)
sns.lineplot(data=accuracy_long, x="Step", y="Accuracy", hue="Metric", ax=axes[2], palette="mako_r")

# Phase labels, placed in data coordinates of the weight-norm panel.
axes[1].text(first/2, 80, "Fitting the\n training set",
    ha='center',
    va='center',
    fontsize='large',
    fontweight='ultralight',
    color='black'
)

axes[1].text((first+max_step)/2, 80, "Finding simpler interpolations",
    ha='center',
    va='center',
    fontsize='x-large',
    fontweight='ultralight',
    color='black'
)

# Print the threshold step just below the bottom panel, like an extra tick
# label (y is in data coordinates of the accuracy panel).
axes[2].text(first, -0.11, str(first),
    ha='center',
    va='top',
    fontsize=11.4,
    fontweight='ultralight',
    color='black'
)

fig.savefig("static/training_curves.png", dpi=300, bbox_inches="tight")

# BatchNorm with Weight Decay

In [None]:
# Test-accuracy curves of the two runs compared in this section, in long
# format: one row per logged point, labelled with the run's "Epsilon" setting.
runs = [98, 117]

records = []

for run in runs:
    # Run 98 is labelled "1e-2"; any other run gets the default label.
    epsilon_label = "1e-2" if run == 98 else "1e-5 (PyTorch default)"

    accumulator = EventAccumulator(
        f"../lightning_logs/version_{run}",
        size_guidance={tag_types.SCALARS: 0},
    )
    accumulator.Reload()

    records.extend(
        {
            "Accuracy": event.value,
            "Step": event.step,
            "Epsilon": epsilon_label,
        }
        for event in accumulator.Scalars("val/accuracy/dataloader_idx_1")
    )

df3 = pd.DataFrame(records)

In [None]:
# Test accuracy over training steps, one line per "Epsilon" label.
fig, ax = plt.subplots(figsize=(12,6))
sns.lineplot(
    data=df3,
    x="Step",
    y="Accuracy",
    hue="Epsilon",
    palette="mako_r",
    ax=ax,
)
ax.set_xlim([0, 40000])
fig.savefig("static/batch_norm_and_weight_decay.png", dpi=300, bbox_inches="tight")