In [2]:
from white_box import ResidualStats
import torch as th

# Residual-stream statistics for Pythia 12B (deduped), loaded onto the CPU.
# `th.load` unpickles the file (the `ResidualStats` import above is presumably what
# makes the stored object resolvable), so only load trusted checkpoints.
stream_stats_path = "/mnt/ssd-1/nora/tuned-lenses/pythia/12b-deduped/affine/eval/stream_stats.pt"
stats = th.load(stream_stats_path, map_location="cpu")

In [3]:
# Stack every covariance matrix yielded by `stats.covariance()` along a new dim 0,
# then flatten each one into a row vector for the Frobenius comparisons below.
covmats = th.stack([cov for cov in stats.covariance()])
covs = th.flatten(covmats, start_dim=1)

In [4]:
# Batched symmetric eigendecomposition of the covariances. `eigh` returns the
# eigenvalues in ascending order, with matching eigenvectors as the columns of Q.
eigh_result = th.linalg.eigh(covmats)
L, Q = eigh_result.eigenvalues, eigh_result.eigenvectors

In [5]:
# Remove the top principal components from every covariance matrix.
# `th.linalg.eigh` returns eigenvalues in ascending order, so the largest
# N_TRIMMED_COMPONENTS eigenvalues are the last entries along the final dimension.
N_TRIMMED_COMPONENTS = 2

trimmed_L = L.clone()
# Slice from an explicit start index: `-0:` would select (and zero) everything.
trimmed_L[..., L.shape[-1] - N_TRIMMED_COMPONENTS:] = 0
# Rebuild Q diag(λ) Qᵀ. Scaling Q's columns by λ (broadcast over rows) equals
# multiplying by diag(λ). It avoids materializing a dense d×d diagonal matrix per
# layer and the extra batched d×d matmul that `Q @ th.diag_embed(λ)` performs.
trimmed_covmats = (Q * trimmed_L.unsqueeze(-2)) @ Q.mT
trimmed_covs = trimmed_covmats.flatten(1)

In [6]:
from transformers import AutoTokenizer

# Tokenizer for the model whose statistics are analyzed here (Pythia 12B deduped).
# The old "pythia-13b-deduped" name refers to the same model under its previous name.
# NOTE(review): `tokenizer` is not used by any later cell visible here.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-12b-deduped")

In [7]:
def frobenius_similarities(A, eps=1e-12):
    """Pairwise cosine similarity of the items in ``A`` under the Frobenius inner product.

    Every dimension after the first is flattened. For a stack of matrices of shape
    ``(n, d, d)``, entry ``[i, j]`` is
    ``<A_i, A_j>_F / (||A_i||_F * ||A_j||_F)``.

    Args:
        A: Tensor with at least two dimensions; dim 0 indexes the items to compare.
        eps: Lower bound applied to each norm. All-zero items then get similarity 0
            instead of NaN.

    Returns:
        An ``(n, n)`` tensor of cosine similarities.
    """
    A = A.flatten(1)
    # Normalize before the Gram product. The result is then bounded by 1 in magnitude,
    # and a zero row no longer produces 0/0 = NaN.
    unit = A / A.norm(dim=-1, keepdim=True).clamp_min(eps)
    return unit @ unit.T

In [19]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.express as px


# Figure width in px: 6.75 in of text width at 96 px/in.
TEXT_WIDTH = 6.75 * 96

fig = make_subplots(
    rows=2,
    cols=1,
    vertical_spacing=0.075,
    subplot_titles=("all principal components", "w/o top 2 components"),
    x_title="Layer",
    y_title="Layer",
)

# Row 1: similarity of the full covariances. Row 2: the same after trimming.
# Both heatmaps share one color axis, so they use a common color scale.
for row, flat_covs in enumerate((covs, trimmed_covs), start=1):
    fig.add_trace(
        go.Heatmap(z=frobenius_similarities(flat_covs).cpu(), coloraxis="coloraxis"),
        row=row,
        col=1,
    )

# make_subplots renders subplot titles and the shared x/y titles as annotations.
fig.update_annotations(font=dict(size=20, color="black"))
fig.update_layout(
    coloraxis_colorbar=dict(
        title=dict(side="right", text="Cosine similarity (Frobenius)"),
    ),
    font=dict(size=16, color="black"),
    width=TEXT_WIDTH,
    height=TEXT_WIDTH * 1.5,
    margin_l=70,
    margin_r=30,
    margin_t=30,
    margin_b=70,
)
fig.update_xaxes(dtick=5)
# Reverse the y axis so index 0 is at the top, like a matrix.
fig.update_yaxes(autorange="reversed", dtick=5)
fig

In [23]:
# Aggregate transfer metrics from the affine-lens eval directory of the same model.
root = "/mnt/ssd-1/nora/tuned-lenses/pythia/12b-deduped/"
transfer_metrics_path = root + "affine/eval/aggregate_transfer_metrics.pt"
vanilla_transfer = th.load(transfer_metrics_path, map_location="cpu")

transfer_ce = vanilla_transfer['transfer_ce']
# Assuming `transfer_ce` is a square 2-D matrix (TODO confirm), broadcasting subtracts
# diag()[j] from column j: penalties[i, j] = transfer_ce[i, j] - transfer_ce[j, j].
penalties = transfer_ce - transfer_ce.diag()

In [36]:
from white_box.stats import spearmanr

# Rank-correlate covariance similarity with transfer penalties. The similarity matrix is
# cropped to the size of `penalties` instead of a hard-coded 36.
# NOTE(review): this keeps the *first* n rows/columns of the similarity matrix; confirm
# that `stats` and the transfer metrics index layers the same way.
n_transfer = penalties.shape[0]
cov_sims = frobenius_similarities(trimmed_covs)[:n_transfer, :n_transfer]
spearmanr(cov_sims.cpu().flatten(), penalties.cpu().flatten())

tensor(-0.7826)

In [33]:
# Sanity check: the cropped similarity matrix should have as many entries as `penalties`.
# The crop size comes from `penalties` rather than a hard-coded 36.
n_transfer = penalties.shape[0]
frobenius_similarities(trimmed_covs)[:n_transfer, :n_transfer].cpu().flatten().shape

torch.Size([1296])

In [26]:
# Number of entries in the flattened penalty matrix (compare with the cell above).
th.flatten(penalties.cpu()).shape

torch.Size([1296])

In [25]:
# Export the covariance-similarity figure. Plotly infers the format from the `.pdf`
# extension; static export needs an image engine (e.g. kaleido) installed.
fig.write_image(file="/mnt/ssd-1/nora/pythia-12b-cov.pdf")

In [None]:
import torch.nn.functional as F

# Pairwise cosine similarity of the flattened covariance matrices, via
# `F.cosine_similarity` (which clamps tiny norms with its `eps`).
stats.covariance().pairwise_map(
    lambda cov_a, cov_b: F.cosine_similarity(cov_a.reshape(-1), cov_b.reshape(-1), dim=0)
)

In [None]:
# Norm of each probe's `bias` (one 0-d tensor per probe).
# NOTE(review): `lens` is not defined anywhere in this notebook (hidden state from a
# deleted cell?) — TODO load the lens earlier so this runs on a fresh kernel.
lens_biases = list(map(lambda layer_probe: layer_probe.bias.data.norm(), lens))
lens_biases

In [None]:
# NOTE(review): `lens_biases` holds bias *norms* (0-d tensors), so `mean + bias_norm`
# adds one scalar to every coordinate of each mean. If the intent was mean + bias
# vector, build the list from `probe.bias` itself instead — TODO confirm.
shifted_mean_norms = stats.mean().zip_map(
    lambda mean, bias_norm: (mean + bias_norm).norm().item(), lens_biases
)
shifted_mean_norms.plot()
# Baseline: norm of each mean with no bias added.
stats.mean().map(lambda mean: mean.norm().item()).plot()

In [None]:
# NOTE(review): this repeats the first plot of the previous cell. `lens_biases` holds
# scalar bias norms, so `mean + bias_norm` shifts every coordinate by the same amount.
shifted_norms = stats.mean().zip_map(
    lambda mean, bias_norm: (mean + bias_norm).norm().item(), lens_biases
)
shifted_norms.plot()

In [None]:
# NOTE(review): `residuals` is never defined in this notebook. It is presumably a
# statistics object from a deleted cell, and it is used again a few cells below.
# TODO define it so the notebook survives Restart & Run All.
residuals.mean_norm()

In [None]:
import matplotlib.pyplot as plt

# Ratio of `th.norm` of each `residuals` mean to the matching `stats.mean_norm()`
# entry, plotted on a log y-axis.
# NOTE(review): `residuals` is not defined anywhere in this notebook (hidden state from
# a deleted cell?) — TODO define it so this cell runs on a fresh kernel.
mean_norm_ratio = residuals.mean().map(th.norm).zip_map(
    lambda numerator, denominator: numerator / denominator, stats.mean_norm()
)
mean_norm_ratio.map(lambda ratio: ratio.cpu()).plot()
plt.yscale('log')

In [None]:
import torch.distributions as D

# One multivariate Gaussian per entry, pairing each mean with its covariance.
# MultivariateNormal Cholesky-factors the covariance, so this raises if any matrix is
# not positive definite.
dists = stats.mean().zip_map(
    lambda mean, cov: D.MultivariateNormal(mean, covariance_matrix=cov),
    stats.covariance(),
)

In [None]:
from tuned_lens.stats import gaussian_wasserstein_l2

# Pairwise `gaussian_wasserstein_l2` between the Gaussians built above.
# NOTE(review): this cell imports from `tuned_lens.stats` while earlier cells import
# from `white_box` — confirm which package is installed, or unify the imports.
wass = dists.pairwise_map(lambda p, q: gaussian_wasserstein_l2(p, q))

In [None]:
# Move every pairwise distance to the CPU, then plot.
wass_cpu = wass.map(lambda distance: distance.cpu())
wass_cpu.plot()

In [None]:
import torch as th

In [None]:
def _plot_series_on_cpu(series):
    """Move every element of ``series`` to the CPU, then plot it."""
    series.map(lambda value: value.cpu()).plot()


# Norm of each mean, then `stats.mean_norm()` (presumably the mean of per-sample norms).
_plot_series_on_cpu(stats.mean().map(th.norm))
_plot_series_on_cpu(stats.mean_norm())