In [1]:
from transformer_lens import HookedTransformer
import torch
from plotly import graph_objects as go
import plotly.express as px
import numpy as np

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = HookedTransformer.from_pretrained("gpt2-small")

Loaded pretrained model gpt2-small into HookedTransformer


In [3]:
model.to(device)
print('lol')

Moving model to device:  cuda
lol


In [4]:
import torch as t

def clusterability(matrix, cluster_U_indices, cluster_V_indices):
    num_clusters = len(cluster_U_indices)
    A = matrix ** 2
    mask = t.zeros_like(A, dtype=t.bool)
    
    for cluster_idx in range(num_clusters):
        u_indices = t.tensor(cluster_U_indices[cluster_idx], dtype=t.long)
        v_indices = t.tensor(cluster_V_indices[cluster_idx], dtype=t.long)
        mask[u_indices.unsqueeze(1), v_indices] = True
    
    intra_cluster_out_sum = t.sum(A[mask])
    total_out_sum = t.sum(A)
    
    return intra_cluster_out_sum / total_out_sum

In [5]:
num_clusters = 4
cluster_size = (model.blocks[0].mlp.W_in.shape[0] // num_clusters, model.blocks[0].mlp.W_in.shape[1] // num_clusters)
cluster_U_indices = {i: list(range(i*cluster_size[0], (i+1)*cluster_size[0])) for i in range(num_clusters)}
cluster_V_indices = {i: list(range(i*cluster_size[1], (i+1)*cluster_size[1])) for i in range(num_clusters)}

## Expt 1 and 2: MLP in and out Modularity (Wiki)

In [6]:
path = './checkpoints/wiki_mlp_'

In [7]:
clusterability_mlp_in = []
clusterability_mlp_out = []
clusterability_mlp_in_conditional = []
clusterability_mlp_out_conditional = []
blocks = range(12)

In [8]:
for block in blocks:
    model.load_state_dict(torch.load(path + 'in_' + str(block) + '.pt'))
    b = model.blocks[block].mlp.W_in
    cluster_size = (b.shape[0] // num_clusters, b.shape[1] // num_clusters)
    cluster_U_indices = {i: list(range(i*cluster_size[0], (i+1)*cluster_size[0])) for i in range(num_clusters)}
    cluster_V_indices = {i: list(range(i*cluster_size[1], (i+1)*cluster_size[1])) for i in range(num_clusters)}
    clusterability_mlp_in.append(clusterability(model.blocks[block].mlp.W_in, cluster_U_indices, cluster_V_indices).item())
    model.load_state_dict(torch.load(path + 'out_' + str(block) + '.pt'))
    b = model.blocks[block].mlp.W_out
    cluster_size = (b.shape[0] // num_clusters, b.shape[1] // num_clusters)
    cluster_U_indices = {i: list(range(i*cluster_size[0], (i+1)*cluster_size[0])) for i in range(num_clusters)}
    cluster_V_indices = {i: list(range(i*cluster_size[1], (i+1)*cluster_size[1])) for i in range(num_clusters)}
    clusterability_mlp_out.append(clusterability(model.blocks[block].mlp.W_out, cluster_U_indices, cluster_V_indices).item())
    model.load_state_dict(torch.load(path + 'in_' + str(block) + '_conditional.pt'))
    b = model.blocks[block].mlp.W_in
    cluster_size = (b.shape[0] // num_clusters, b.shape[1] // num_clusters)
    cluster_U_indices = {i: list(range(i*cluster_size[0], (i+1)*cluster_size[0])) for i in range(num_clusters)}
    cluster_V_indices = {i: list(range(i*cluster_size[1], (i+1)*cluster_size[1])) for i in range(num_clusters)}
    clusterability_mlp_in_conditional.append(clusterability(model.blocks[block].mlp.W_in, cluster_U_indices, cluster_V_indices).item())
    model.load_state_dict(torch.load(path + 'out_' + str(block) + '_conditional.pt'))
    b = model.blocks[block].mlp.W_out
    cluster_size = (b.shape[0] // num_clusters, b.shape[1] // num_clusters)
    cluster_U_indices = {i: list(range(i*cluster_size[0], (i+1)*cluster_size[0])) for i in range(num_clusters)}
    cluster_V_indices = {i: list(range(i*cluster_size[1], (i+1)*cluster_size[1])) for i in range(num_clusters)}
    clusterability_mlp_out_conditional.append(clusterability(model.blocks[block].mlp.W_out, cluster_U_indices, cluster_V_indices).item())

    print(block)

  model.load_state_dict(torch.load(path + 'in_' + str(block) + '.pt'))
  model.load_state_dict(torch.load(path + 'out_' + str(block) + '.pt'))
  model.load_state_dict(torch.load(path + 'in_' + str(block) + '_conditional.pt'))


  model.load_state_dict(torch.load(path + 'out_' + str(block) + '_conditional.pt'))


0
1
2
3
4
5
6
7
8
9
10
11


In [27]:
# plot the clusterability of the mlp in and out weights on the same plot using plotly

fig = go.Figure()

# thick lines
fig.add_trace(go.Scatter(x=list(blocks), y=clusterability_mlp_in, mode='lines+markers', name='MLP In', line=dict(color='darkred', width=2)))
fig.add_trace(go.Scatter(x=list(blocks), y=clusterability_mlp_out, mode='lines+markers', name='MLP Out', line=dict(color='darkblue', width=2)))

fig.update_layout(title='', xaxis_title='Layer', yaxis_title='Clusterability')

fig.update_layout({'plot_bgcolor': 'rgba(255, 255, 255, 1)',})
# show fine grid lines on both axes on both subplotsß
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='LightGray')
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGray')

# legend inside the plot in a box
fig.update_layout(legend=dict(x=0.05, y=0.1, traceorder="normal", bgcolor="white", bordercolor="black", borderwidth=1))

# width and height
fig.update_layout(width=500, height=350)

# y axis from 0 to 1
fig.update_yaxes(range=[0, 1])
# x axis from 0 to 11
fig.update_xaxes(range=[0, 11])

# ticks on both axes
fig.update_xaxes(tickmode='linear', tick0=0, dtick=1)
fig.update_yaxes(tickmode='linear', tick0=0, dtick=0.2)

    # show all x ticks
fig.update_xaxes(tickvals=np.arange(10))

# remove space from top of figure and add some space at the bottom
fig.update_layout(margin=dict(t=50, b=140))

# everthing latex font (for research paper)
fig.update_layout(font=dict(family='serif', size=15, color='black'))
fig.update_xaxes(title_font=dict(family='serif', size=18, color='black'))
fig.update_yaxes(title_font=dict(family='serif', size=18, color='black'))
fig.update_xaxes(tickfont=dict(family='serif', size=18, color='black'))
fig.update_yaxes(tickfont=dict(family='serif', size=18, color='black'))
fig.update_xaxes(showline=True, linewidth=1, linecolor='black', mirror=False)

# fig.update_layout(
#     font_family="Computer Modern",
#     font_size=20,
# )


fig.show()

In [29]:
# save the plot in pdf format
fig.write_image("plots/clusterability_mlp.pdf")