In [None]:
import torch
from torch.nn.functional import cosine_similarity
import plotly.graph_objects as go

In [None]:
# Load the saved linear probe for each of the 8 layer indices (0-7), keyed by
# layer index. `squeeze(0)` drops the leading dimension when it has size 1
# (presumably a singleton batch/probe axis -- TODO confirm against how the
# probes were saved).
# map_location="cpu" lets the checkpoints load on machines that lack the
# device they were saved from (e.g. GPU-trained probes on a CPU-only laptop).
probes = {
    layer: torch.load(
        f"probes/linear/resid_{layer}_linear.pth",
        map_location="cpu",
    ).squeeze(0)
    for layer in range(8)
}

In [None]:
# Build an 8x8 lower-triangular table of probe similarities.
# Entry [row][col] with col <= row is the cosine similarity between the two
# layers' probes taken along dim 0, averaged over whatever dimensions remain,
# and rounded to two decimals. Entries above the diagonal are the integer 0
# placeholder (kept as an int so it renders as the string "0").
cos_sims = [
    [
        round(
            cosine_similarity(probes[row], probes[col], dim=0).mean().item(),
            2,
        )
        if col <= row
        else 0
        for col in range(8)
    ]
    for row in range(8)
]

In [None]:
import plotly.figure_factory as ff
import plotly.express as px

# Work on a copy of the palette: `px.colors.sequential.Blues` is a shared
# module-level list, so editing it in place would change the "Blues" scale
# for every other figure created in this session.
blues = list(px.colors.sequential.Blues)
# Make the lowest colour pure white so the placeholder cells above the
# diagonal (value 0) blend into the background.
blues[0] = "rgb(255, 255, 255)"

_fig = ff.create_annotated_heatmap(
    z=cos_sims,
    x=[f"Layer {x+1}" for x in range(8)],
    y=[f"Layer {x+1}" for x in range(8)],
    colorscale=blues,
    annotation_text=cos_sims,
    showscale=True,
    colorbar=dict(tickfont=dict(size=18)),
)
_fig.update_xaxes(side="bottom")
_fig.update_layout(
    yaxis_autorange="reversed",  # put "Layer 1" at the top, matrix-style
    xaxis_showgrid=False,
    yaxis_showgrid=False,
    yaxis=dict(tickfont=dict(size=20)),
    xaxis=dict(tickfont=dict(size=20)),
)

# Enlarge the per-cell labels and blank out the integer-0 placeholders
# (their annotation text is exactly "0"; real similarities are floats such
# as "0.0" and are left untouched).
for annotation in _fig.layout.annotations:
    annotation.font.size = 20
    if annotation.text == "0":
        annotation.text = ""

_fig.show()
_fig.write_image("cos_sims.pdf")

In [None]:
# Plain graph_objects version of the probe-similarity heatmap.
# `go.Heatmap` has no `annotation_text` property, so the per-cell labels are
# supplied as explicit layout annotations instead, and the figure is created
# only after both the trace and the layout are fully specified.
tick_labels = [str(x) for x in range(len(cos_sims))]

heatmap = go.Heatmap(
    z=cos_sims,
    x=tick_labels,
    y=tick_labels,
    colorscale="blues",
    zmin=0,
    zmax=1,
    text=cos_sims,
)

# One label per cell on or below the diagonal; cells above it only hold the
# 0 placeholder and are left unlabelled.
cell_annotations = [
    dict(
        x=tick_labels[col],
        y=tick_labels[row],
        text=str(value),
        showarrow=False,
        # Light text on the dark (high-value) end of the scale for contrast.
        font=dict(color="white" if value >= 0.5 else "black"),
    )
    for row, row_values in enumerate(cos_sims)
    for col, value in enumerate(row_values)
    if col <= row
]

layout = go.Layout(
    yaxis_autorange="reversed",  # row 0 at the top, matrix-style
    xaxis_showgrid=False,
    yaxis_showgrid=False,
    annotations=cell_annotations,
)

fig = go.Figure(data=[heatmap], layout=layout)
fig.show()