In [None]:
from matplotlib.image import NonUniformImage
import matplotlib.pyplot as plt
from glob import glob
import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib
import json
import re

# Apply seaborn's "darkgrid" theme to the matplotlib figures drawn below.
sns.set_theme(style="darkgrid")
# Default figure size (width, height) in inches; the plt.subplots() calls in
# this notebook pass no figsize, so they pick this value up.
matplotlib.rcParams['figure.figsize'] = (20, 10)

# IPython magic: render figures inline in the notebook output.
# NOTE(review): some older IPython/ipykernel releases reapply their own
# figure.figsize when the inline backend is activated - confirm the size set
# above survives this line, or move the magic above the rcParams assignment.
%matplotlib inline

In [None]:
def load_results_df(key: str) -> pd.DataFrame:
    """Collect every ``results_*-{key}.json`` file in the working directory.

    Each file is read as a JSON object whose keys are task names and whose
    values are converted to a NumPy array with ``np.array``. The model name
    is the part of the file name between ``results_`` and ``.json`` (so it
    still carries the ``-{key}`` suffix).

    Args:
        key: Suffix used to select files, e.g. ``"base"`` matches
            ``results_<anything>-base.json``.

    Returns:
        A DataFrame with one row per (file, task) and the columns ``task``,
        ``model``, ``task/model`` and ``activations``. When no file matches,
        the frame is empty but still has those four columns, so downstream
        column look-ups do not fail with ``KeyError``.
    """
    columns = ["task", "model", "task/model", "activations"]
    # The dot is escaped and the pattern is anchored at the end: with the
    # unescaped, lazy form a name such as "results_xjson-base.json" would
    # stop at the first "<any char>json" and yield an empty model name.
    name_pattern = re.compile(r"results_(.*)\.json$")
    all_results = []
    # glob() returns files in filesystem order; sorting keeps the row order
    # (and therefore the plot order) reproducible between runs and machines.
    for file in sorted(glob(f"results_*-{key}.json")):
        with open(file, "r", encoding="utf-8") as f:
            results = json.load(f)
        model = name_pattern.search(file).group(1)
        for task, values in results.items():
            all_results.append({
                "task": task,
                "model": model,
                "task/model": f"{task}/{model}",
                "activations": np.array(values)
            })
    return pd.DataFrame(all_results, columns=columns)

In [None]:
# Load the "base" and "large" result sets, then preview the first rows of the
# "base" frame as the cell output.
base_results, large_results = (
    load_results_df(suffix) for suffix in ("base", "large")
)
base_results.head()

In [None]:
def plot_channels(results: pd.DataFrame, title: str = None):
    """Draw one heat map of log-activations per (task, model) combination.

    The panels are stacked vertically in a single figure, grouped by model:
    all tasks of the first model, then all tasks of the second, and so on.
    Each panel shows ``log(activations + 1e-30)`` with the "hot" colour map,
    stretched over the unit square regardless of the array's size.

    Args:
        results: Frame with ``task``, ``model`` and ``activations`` columns.
            Every ``activations`` entry must be a 2-D array, because its
            shape is unpacked into (rows, columns).
        title: Optional figure-level title.

    Returns:
        None. The figure is left open for the notebook backend to display.
    """
    tasks = results["task"].unique()
    models = results["model"].unique()
    num_tasks = len(tasks)
    num_models = len(models)
    # squeeze=False always yields a 2-D array of Axes; without it a single
    # panel comes back as a bare Axes object and indexing it raises TypeError.
    fig, axes = plt.subplots(num_tasks * num_models, squeeze=False)
    axes = axes.ravel()
    for j, model in enumerate(models):
        for i, task in enumerate(tasks):
            ax = axes[i + j * num_tasks]
            ax.set_title(f"{task}/{model}")
            ax.set_xticks([])
            ax.set_yticks([])
            selected = results[(results["task"] == task) & (results["model"] == model)]
            if selected.empty:
                # This model has no entry for this task: keep the titled
                # panel blank instead of failing on the shape unpacking.
                continue
            # .iloc[0] takes the stored array directly (first match wins if
            # the combination is duplicated).
            activations = selected["activations"].iloc[0]
            # The tiny offset keeps log() finite where an activation is 0.
            activations = np.log(activations + 1e-30)
            Y, X = activations.shape
            # Pixel centres are spread evenly over [0, 1], matching both the
            # image extent and matplotlib's default axis limits.
            x = np.linspace(0, 1, X)
            y = np.linspace(0, 1, Y)
            image = NonUniformImage(ax, interpolation="nearest", cmap="hot", extent=(0, 1, 0, 1))
            image.set_data(A=activations, x=x, y=y)
            ax.add_image(image)
    fig.suptitle(title)


In [None]:
# Activation heat maps: one figure for the "base" results, one for "large".
plot_channels(results=base_results, title="base")
plot_channels(results=large_results, title="large")

In [None]:
def plot_counts(results: pd.DataFrame, title: str = None):
    """Plot, per model, how many activations in each row are strictly positive.

    For every (task, model) combination the ``activations`` array is reduced
    with ``(activations > 0).sum(axis=1)``, giving one count per row; the row
    index is reported as ``layer``. Each model gets its own bar chart with
    ``layer`` on the x-axis, ``count`` on the y-axis and one hue per task.

    Args:
        results: Frame with ``task``, ``model`` and ``activations`` columns;
            each ``activations`` entry must be an array with at least two
            dimensions so that ``axis=1`` exists.
        title: Optional figure-level title.

    Returns:
        None. The figure is left open for the notebook backend to display.
    """
    tasks = results["task"].unique()
    models = results["model"].unique()
    # squeeze=False always returns an array of Axes, so the single-model
    # case needs no special handling.
    fig, axes = plt.subplots(len(models), squeeze=False)
    axes = axes.ravel()
    channel_counts = []
    for task in tasks:
        for model in models:
            selected = results[(results["task"] == task) & (results["model"] == model)]
            if selected.empty:
                # This model has no entry for this task; nothing to count.
                continue
            activations = selected["activations"].iloc[0]
            counts = (activations > 0).sum(axis=1)
            channel_counts.extend(
                {"task": task, "model": model, "layer": layer, "count": count}
                for layer, count in enumerate(counts)
            )
    # Explicit columns keep the per-model filter below valid even when no
    # rows were collected.
    channel_counts = pd.DataFrame(channel_counts, columns=["task", "model", "layer", "count"])

    for ax, model in zip(axes, models):
        # data= is passed by keyword: older seaborn releases do not accept
        # the frame as the first positional argument.
        sns.barplot(data=channel_counts[channel_counts["model"] == model], x="layer", y="count", hue="task", ax=ax)
        ax.set_title(model)
        ax.set_xlabel("")
    # Title the figure created above rather than whichever figure pyplot
    # currently treats as active.
    fig.suptitle(title)


In [None]:
# Positive-activation counts per layer: one figure for "base", one for "large".
plot_counts(results=base_results, title="base")
plot_counts(results=large_results, title="large")