In [None]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import torch as th


def index_to_name(index: int) -> str:
    """Map a layer index to its display name.

    Index 0 (and any non-positive index) is named ``"input"``; a positive
    index ``k`` is named ``"{k - 1}.ffn"``, i.e. the number is shifted down by one.
    """
    if not index > 0:
        return "input"
    shifted = index - 1
    return f"{shifted}.ffn"

def plot_transfer_perplexity(
        metrics: dict[str, dict],
        key: str = "transfer_ce",
        model_name: str = "Pythia"
    ) -> go.Figure:
    """Plot how a lens trained on one layer performs when tested on every layer.

    One subplot is drawn per entry of ``metrics``. ``metrics[name][key]`` is
    treated as a square matrix whose rows are train layers and whose columns
    are test layers. Each row becomes one line trace, and the diagonal is drawn
    as a dashed "baseline" trace. A slider selects which train layer's row is
    visible, in every subplot at once.

    Args:
        metrics: Mapping from subplot title to a results dict. ``metric[key]``
            must be a 2-D square torch tensor (it is used via ``.diag()`` and
            ``.max()``). All entries must have the same number of layers.
        key: Which entry of each results dict to plot.
        model_name: Name shown in the figure title.

    Returns:
        The configured ``go.Figure``.

    Raises:
        ValueError: If ``metrics`` is empty, or if a ``metric[key]`` is empty,
            not square, or sized differently from the first entry.
    """
    if not metrics:
        raise ValueError("`metrics` must contain at least one entry")

    first_metric = next(iter(metrics.values()))
    num_layers = len(first_metric[key])
    if num_layers == 0:
        raise ValueError(f"metrics[...][{key!r}] must contain at least one layer")
    # Every subplot shares the same slider and x ticks, so all matrices must be
    # square and the same size. Otherwise the per-trace visibility masks
    # below would not line up with the traces.
    for name, metric in metrics.items():
        shape = tuple(metric[key].shape)
        if shape != (num_layers, num_layers):
            raise ValueError(
                f"metrics[{name!r}][{key!r}] has shape {shape}; expected a "
                f"square ({num_layers}, {num_layers}) matrix shared by all entries"
            )

    def layer_label(idx: int) -> str:
        # Slider/tick label: position 0 is "input", position k > 0 shows k - 1.
        return str(idx - 1) if idx != 0 else "input"

    num_plots = len(metrics)
    # make_subplots rejects vertical_spacing > 1 / (rows - 1). Keep the original
    # 0.1 while it fits (up to 6 rows), and shrink it beyond that so the gaps
    # never take more than half of the figure height.
    spacing = 0.1 if num_plots == 1 else min(0.1, 0.5 / (num_plots - 1))
    fig = make_subplots(
        num_plots, 1,
        shared_yaxes=True,
        subplot_titles=list(metrics.keys()),
        vertical_spacing=spacing,
        x_title="Test layer",
        y_title="bits per byte",
    )

    # Initially show train layer 1 (as before), or layer 0 if it is the only one.
    active = min(1, num_layers - 1)
    layer_ids = th.arange(num_layers)
    max_y = 0.0

    for plot_idx, (name, metric) in enumerate(metrics.items()):
        values = metric[key]
        # If we stack the transfer results into a matrix where each row is a
        # train layer and each column is a test layer, then the diagonal entries
        # are the baseline perplexities, "transferring" from a layer to itself.
        baseline = values.diag()
        max_y = max(max_y, float(values.max()))

        transfer_traces = [
            go.Scatter(
                x=layer_ids,
                y=row,
                legendgroup=name,
                legendgrouptitle_text=name,
                # Highlight the point where test layer == train layer.
                marker_color=['#EF553B' if j == train_idx else '#636efa' for j in range(num_layers)],
                marker_size=[20 if j == train_idx else 10 for j in range(num_layers)],
                marker_symbol=["star" if j == train_idx else "circle" for j in range(num_layers)],
                mode="lines+markers",
                name=f"{index_to_name(train_idx)} transfer",
                visible=train_idx == active,
            )
            for train_idx, row in enumerate(values)
        ]
        baseline_trace = go.Scatter(
            x=layer_ids,
            y=baseline,
            mode="lines",
            name="baseline",
            line=dict(dash="dash"),
        )
        fig.add_traces(transfer_traces + [baseline_trace], rows=plot_idx + 1, cols=1)

    def visibility(train_idx: int) -> list[bool]:
        # Each subplot contributes num_layers transfer traces followed by one
        # baseline trace, in that order. Show only the selected transfer trace
        # plus the baseline, repeated explicitly for every subplot.
        one_panel = [j == train_idx for j in range(num_layers)] + [True]
        return one_panel * num_plots

    return fig.update_layout(
        autosize=False,
        height=500 * num_plots,
        width=1000,
    ).update_layout(
        hovermode="y unified",
        sliders=[
            dict(
                active=active,
                currentvalue=dict(prefix="Train layer: "),
                pad=dict(t=40), # Make room for the x-axis title
                steps=[
                    dict(
                        args=[dict(visible=visibility(train_idx))],
                        label=layer_label(train_idx),
                        method="restyle",
                    )
                    for train_idx in range(num_layers)
                ],
            )
        ],
        title=f"Tuned lens transfer perplexity ({model_name})",
    ).update_xaxes(
        tickvals=layer_ids,
        ticktext=[layer_label(idx) for idx in range(num_layers)],
    ).update_yaxes(
        range=[0, max_y],
    )

In [None]:
# NOTE(review): `vanilla_transfer` is not defined in this cell. It must come
# from an earlier cell (presumably a results dict containing "transfer_ce").
# Confirm it exists so this cell survives Restart & Run All.
fig = plot_transfer_perplexity(
    {'Tuned lens': vanilla_transfer},
    model_name="OPT 125M",
    key="transfer_ce",
)
fig