In [None]:
from matplotlib.image import NonUniformImage
import matplotlib.pyplot as plt
from pathlib import Path
from glob import glob
import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib
import json
import re

sns.set_theme(style="darkgrid")
matplotlib.rcParams['figure.figsize'] = (20, 10)

%matplotlib inline

In [None]:
# Load every results/<model>.json into one long-format frame with one row per
# (model, task). Each problem's values are stored as a numpy array.
results_files = sorted(glob("results/*.json"))
all_results = []
layers = {}  # model name -> the "layers" entry from that model's results file
for file in results_files:
    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(file, "r", encoding="utf-8") as f:
        results = json.load(f)
    # The file stem is the model name. Unlike a "results/(.*?).json" regex,
    # this works with any OS path separator and with dots in the name.
    model = Path(file).stem
    layers[model] = results.pop("layers")
    # Every remaining top-level key is a task mapping problem name -> values.
    for task_name, problem in results.items():
        all_results.append(
            {"task": task_name, "model": model, "model/task": f"{model}/{task_name}"}
            | {problem_name: np.array(values) for problem_name, values in problem.items()}
        )
df = pd.DataFrame(all_results)
df.head()

In [None]:
def _rows_for_models(pattern: str) -> pd.DataFrame:
    """Return the rows of `df` whose "model" name matches `pattern` from its start."""
    model_matches = df["model"].str.match(pattern)
    return df.loc[model_matches]


resnet_df = _rows_for_models(r"resnet\d+")
base_df = _rows_for_models(r"[a-z0-9]+-base")
large_df = _rows_for_models(r"[a-z0-9]+-large")

In [None]:
def plot_channels(results: pd.DataFrame, title: str = None):
    """Plot log-RMS heatmaps, one stacked panel per (task, model) pair.

    Panels are grouped by model: all tasks for the first model come first,
    then all tasks for the next model, and so on.

    Args:
        results: Rows with "task", "model" and "RMS" columns. Each "RMS" entry
            must be 2-D: its rows are drawn along y and its columns along x.
        title: Optional figure-level title.
    """
    tasks = results["task"].unique()
    models = results["model"].unique()
    num_tasks = len(tasks)
    num_models = len(models)
    if num_tasks == 0 or num_models == 0:
        return  # nothing to plot; plt.subplots(0) would raise
    # squeeze=False keeps `axes` an array even for a single panel, so the
    # indexing below also works when there is one task and one model.
    fig, axes = plt.subplots(num_tasks * num_models, squeeze=False)
    axes = axes[:, 0]
    for i, task in enumerate(tasks):
        for j, model in enumerate(models):
            ax = axes[i + j * num_tasks]
            ax.set_title(f"{task}/{model}")
            ax.set_xticks([])
            ax.set_yticks([])
            rows = results[(results["task"] == task) & (results["model"] == model)]
            if rows.empty:
                # This model has no results for this task: leave the panel blank.
                continue
            rms = np.asarray(rows.iloc[0]["RMS"])
            # The tiny epsilon avoids log(0) = -inf for all-zero entries.
            log_rms = np.log(rms + 1e-30)
            num_rows, num_cols = log_rms.shape
            x = np.linspace(0, 1, num_cols)
            y = np.linspace(0, 1, num_rows)
            image = NonUniformImage(ax, interpolation="nearest", cmap="hot", extent=(0, 1, 0, 1))
            image.set_data(A=log_rms, x=x, y=y)
            ax.add_image(image)
    if title is not None:
        fig.suptitle(title)


# Backward-compatible alias for this function's original name.
activations = plot_channels


In [None]:
# ResNet models are currently skipped; re-enable with:
#     plot_channels(resnet_df, title="resnet")
plot_channels(results=base_df, title="base")
plot_channels(results=large_df, title="large")

In [None]:
def plot_counts(results: pd.DataFrame, title: str = None):
    """Bar-plot the per-layer outlier share for each model, one subplot per model.

    Bars are grouped by layer and coloured by task. Each "outliers" value is
    multiplied by 100 and shown as "percent" (presumably the input values are
    fractions in [0, 1] - verify against the results producer).

    Args:
        results: Rows with "task", "model" and "outliers" columns, where
            "outliers" holds one value per layer.
        title: Figure-level title.
    """
    task_names = results["task"].unique()
    model_names = results["model"].unique()
    fig, axes = plt.subplots(len(model_names), squeeze=False)

    records = [
        {"task": task, "model": model, "layer": layer, "percent": proportion * 100}
        for task in task_names
        for model in model_names
        for layer, proportion in enumerate(
            results[(results["task"] == task) & (results["model"] == model)].squeeze(0)["outliers"]
        )
    ]
    percentages = pd.DataFrame(records)

    # With squeeze=False, the first column holds one Axes per model.
    for ax, model in zip(axes[:, 0], model_names):
        model_rows = percentages[percentages["model"] == model]
        sns.barplot(model_rows, x="layer", y="percent", hue="task", ax=ax)
        ax.set_title(model)
        ax.set_xlabel("")
    plt.suptitle(title)


In [None]:
plot_counts(results=base_df, title="base")
plot_counts(results=large_df, title="large")

In [None]:
def plot_kurtosis(results: pd.DataFrame, title: str = None, label_offset: float = 10):
    """Compare per-layer kurtosis before and after rotation, one subplot per model.

    For every layer, seaborn draws an "unrotated" and a "rotated" bar,
    aggregated across tasks. Each bar is labelled with its height just above
    its error bar.

    Args:
        results: Rows with "task", "model", "kurtosis" and "rotated_kurtosis"
            columns. The last two hold per-layer sequences, which are zipped
            together, so the shorter one bounds the layer count.
        title: Figure-level title.
        label_offset: Vertical gap, in data units, between the top of an error
            bar and its value label. The default of 10 matches the previously
            hard-coded offset.
    """
    tasks = results["task"].unique()
    models = results["model"].unique()
    if len(models) == 0:
        return  # nothing to plot; plt.subplots(0) would raise
    # squeeze=False keeps `axes` 2-D even for a single model.
    fig, axes = plt.subplots(len(models), squeeze=False)

    records = []
    for task in tasks:
        for model in models:
            row = results[(results["task"] == task) & (results["model"] == model)].squeeze(0)
            per_layer = zip(row["kurtosis"], row["rotated_kurtosis"])
            for layer, (value, rotated_value) in enumerate(per_layer):
                records.append({"task": task, "type": "unrotated", "model": model,
                                "layer": layer, "kurtosis": value})
                records.append({"task": task, "type": "rotated", "model": model,
                                "layer": layer, "kurtosis": rotated_value})
    kurtosis_df = pd.DataFrame(records)

    for ax, model in zip(axes[:, 0], models):
        sns.barplot(kurtosis_df[kurtosis_df["model"] == model], x="layer", y="kurtosis", hue="type", ax=ax)

        # Pair each bar with its error-bar line, in the order the plotting
        # call added them, and put the label just above the error bar.
        for bar, line in zip(ax.patches, ax.lines):
            height = bar.get_height()
            ydata = np.asarray(line.get_ydata(), dtype=float)
            finite = ydata[np.isfinite(ydata)]
            # Fall back to the bar top when the error bar is empty or all-NaN,
            # instead of raising on max([]) or placing the label at NaN.
            top = finite.max() if finite.size else height
            ax.text(bar.get_x() + bar.get_width() / 2., top + label_offset,
                    # One decimal for small values, none for large ones.
                    f"{height:.{int(height < 10)}f}", ha="center", va="bottom", size=10)

        ax.set_title(model)
        ax.set_xlabel("")
    plt.suptitle(title)

In [None]:
plot_kurtosis(results=base_df, title="base")
plot_kurtosis(results=large_df, title="large")