# Single Digit Memorization Visualizations

In [1]:
import warnings
# Suppress all FutureWarning messages for the rest of the session.
warnings.simplefilter(action='ignore', category=FutureWarning)

In [2]:
import matplotlib.pyplot as plt

In [3]:
import os

In [4]:
import torch
import pickle
import networkx as nx
import seaborn as sns

import circuit_visual as cv

In [5]:
# Collect the identifiers of all tasks that have been patched.
#
# Results are stored as "<task>-all_blocks.pkl" (that is the name the cells
# below open), so the task id is derived only from files carrying that suffix.
# Anything else in the directory (hidden files, other artifacts) is skipped so
# it cannot turn into a bogus task id that fails to load later.
PATCHING_SUFFIX = "-all_blocks.pkl"

tasks = {
    fname[: -len(PATCHING_SUFFIX)]
    for fname in os.listdir("../data/patching_circuit")
    if fname.endswith(PATCHING_SUFFIX)
}

In [6]:
# Render one circuit diagram per task and save it as "<task>-circuit.pdf"
# in the current working directory.
for task in tasks:
    # Use a context manager so the file handle is closed after each load
    # instead of leaking one open file per task.
    # NOTE: pickle can execute arbitrary code on load; only read trusted files.
    with open(f"../data/patching_circuit/{task}-all_blocks.pkl", "rb") as f:
        attn_patching = pickle.load(f)

    # Select the heads to draw and lay them out as a graph.
    # NOTE(review): 0.05 is presumably a score threshold for keeping a head;
    # confirm against circuit_visual.get_top_attn_heads.
    heads = cv.get_top_attn_heads(attn_patching, 0.05)
    nodes, edges, vals = cv.attn_heads_multipartite(heads)
    G, pos, cvals = cv.make_circuit_graph(nodes, edges, vals, color="viridis", scale=1.5)

    plt.figure(figsize=(3, 10))
    ax = plt.gca()
    nx.draw_networkx_edges(G, pos, arrows=True, width=1, arrowsize=10)

    # Nodes are drawn by hand as rounded boxes: nodes whose name contains
    # "MLP" are white, every other node takes its colour from cvals.
    for node in G.nodes():
        color = "white" if "MLP" in node else cvals[node]
        cv.draw_rounded_node(ax, pos, node, color, width=0.18, height=0.05)

    ax.axis('off')
    plt.axis("scaled")

    plt.savefig(f"{task}-circuit.pdf")
    # Close the figure so figures do not accumulate across iterations.
    plt.close()

In [None]:
# Load the patching scores of every task into a dict keyed by task id.
# These are the inputs for the pairwise squared-Pearson-correlation heatmap
# computed further down.

attn_scores = {}

for task in tasks:
    # Context manager closes each file right after it is read.
    # NOTE: pickle can execute arbitrary code on load; only read trusted files.
    with open(f"../data/patching_circuit/{task}-all_blocks.pkl", "rb") as f:
        attn_scores[task] = pickle.load(f)

In [None]:
import numpy as np

In [None]:
# Build the task-by-task matrix of squared Pearson correlations between the
# flattened patching scores.

# Order the tasks by the integers at character positions 0 and 2 of their id.
keys = sorted(attn_scores, key=lambda x: (int(x[0]), int(x[2])))

# Flatten every score array once up front; the flattening does not depend on
# the pair being compared, so doing it inside the double loop is wasted work.
flat_scores = {k: attn_scores[k].flatten() for k in keys}

hm = np.zeros((len(keys), len(keys)))

for i, k1 in enumerate(keys):
    for j, k2 in enumerate(keys):
        # np.corrcoef returns the 2x2 correlation matrix of the two vectors;
        # entry [0, 1] is their Pearson r, which is squared here.
        hm[i, j] = np.corrcoef(flat_scores[k1], flat_scores[k2])[0, 1] ** 2

In [None]:
# Draw the squared-correlation matrix as a heatmap and save it to
# "pearson-corr.pdf" in the current working directory.
plt.figure(figsize=(13, 10))

# Hand the labels to seaborn instead of calling set_xticks(range(n)):
# heatmap cells are centred at i + 0.5, so ticks placed at integer positions
# sit on the cell edges and every label ends up half a cell off.
# The colour scale is fixed to [0, 0.4]; larger values (e.g. the
# self-correlations on the diagonal) saturate at the top colour.
ax = sns.heatmap(hm, vmin=0, vmax=0.4, xticklabels=keys, yticklabels=keys)

# Vertical x labels and horizontal y labels keep the task ids readable.
plt.xticks(rotation=90)
plt.yticks(rotation=0)
plt.savefig("pearson-corr.pdf")