In [1]:
from transformer_lens import HookedTransformer
import torch
import torch as t
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = HookedTransformer.from_pretrained("gpt2-small")
base_model.to(device)
from transformer_lens.evals import evaluate
base_results = evaluate(base_model)

Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda
36718


  0%|          | 0/293 [00:00<?, ?it/s]

wiki: 3.3675413816282065
10000


  0%|          | 0/1373 [00:00<?, ?it/s]

owt: 3.143429225034053
10000


  0%|          | 0/2119 [00:00<?, ?it/s]

pile: 2.946532333251273
45404


  0%|          | 0/23199 [00:00<?, ?it/s]

code: 1.9871331866424862


In [2]:
base_results

{'wiki_loss': 3.3675413816282065,
 'owt_loss': 3.143429225034053,
 'pile_loss': 2.946532333251273,
 'code_loss': 1.9871331866424862}

In [3]:
def clusterability(matrix, cluster_U_indices=None, cluster_V_indices=None, num_clusters=4):
    A = matrix ** 2
    mask = t.zeros_like(A, dtype=t.bool)
    
    cluster_size = (A.shape[0] // num_clusters, A.shape[1] // num_clusters)
    cluster_U_indices = {i: list(range(i*cluster_size[0], (i+1)*cluster_size[0])) for i in range(num_clusters)}
    cluster_V_indices = {i: list(range(i*cluster_size[1], (i+1)*cluster_size[1])) for i in range(num_clusters)}

    for cluster_idx in range(num_clusters):
        u_indices = t.tensor(cluster_U_indices[cluster_idx], dtype=t.long)
        v_indices = t.tensor(cluster_V_indices[cluster_idx], dtype=t.long)
        mask[u_indices.unsqueeze(1), v_indices] = True
    
    intra_cluster_out_sum = t.sum(A[mask])
    total_out_sum = t.sum(A)
    
    return intra_cluster_out_sum / total_out_sum

In [4]:
from transformer_lens.evals import make_wiki_data_loader, make_pile_data_loader, make_owt_data_loader, make_code_data_loader

datasets = {
    'wiki': make_wiki_data_loader(base_model.tokenizer, batch_size=8),
    'pile': make_pile_data_loader(base_model.tokenizer, batch_size=8),
    'owt': make_owt_data_loader(base_model.tokenizer, batch_size=8),
    'code': make_code_data_loader(base_model.tokenizer, batch_size=8),
}

36718
10000
10000
45404


In [5]:
clusterability(base_model.blocks[4].mlp.W_in)

tensor(0.2493, device='cuda:0', grad_fn=<DivBackward0>)

In [40]:
cluster_losses = []
train_losses = []
lomda = 40.0
model = HookedTransformer.from_pretrained("gpt2-small")
model.to(device)
blocks_to_cluster = [model.blocks[i].mlp.W_in for i in range(12)]
path = './checkpoints/'
num_epochs = 2
num_clusters = 4
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
model.train()
    
for epoch in range(num_epochs):
    for idx, batch in enumerate(datasets['wiki']):
        tokens = batch['tokens'].to(device)
        # cluster_loss = sum([clusterability(block) for block in blocks_to_cluster]) / len(blocks_to_cluster)
        cluster_loss = 0
        train_loss = model(tokens, return_type="loss")
        # cluster_losses.append(cluster_loss.item())
        cluster_losses.append(0)
        train_losses.append(train_loss.item())
        loss = train_loss - lomda * cluster_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if idx % 100 == 0:
            print(f'Epoch {epoch+1}, Batch {idx}, Train Loss: {round(train_loss.item(), 4)}, Clusterability: {round(cluster_loss, 4)}')    
    torch.save(model.state_dict(), path + f'wiki_non_modular_mlp_in_model_epoch_{epoch+1}.pt')

# store the cluster losses and train losses
import pickle
with open(path + 'wiki_non_modular_mlp_in_cluster_losses.pkl', 'wb') as f:
    pickle.dump(cluster_losses, f)
with open(path + 'wiki_non_modular_mlp_in_train_losses.pkl', 'wb') as f:
    pickle.dump(train_losses, f)

Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda
Epoch 1, Batch 0, Train Loss: 3.2632, Clusterability: 0
Epoch 1, Batch 100, Train Loss: 2.8552, Clusterability: 0
Epoch 1, Batch 200, Train Loss: 2.9276, Clusterability: 0
Epoch 2, Batch 0, Train Loss: 2.5913, Clusterability: 0
Epoch 2, Batch 100, Train Loss: 2.7196, Clusterability: 0
Epoch 2, Batch 200, Train Loss: 2.5478, Clusterability: 0


In [9]:
# TRAINING PLOTS

In [10]:
import plotly.graph_objects as go
import plotly.express as px

In [18]:
# pick every 100, 200, etc. element from the list
def pick_elements(l, n):
    return l[::n]

train_losses_picked = pick_elements(train_losses, 100)
cluster_losses_picked = pick_elements(cluster_losses, 100)

In [23]:
# normalizing the losses
train_losses_picked = [x / max(train_losses) for x in train_losses]
cluster_losses_picked = [x / max(cluster_losses) for x in cluster_losses]

In [35]:
fig = go.Figure()
fig.add_trace(go.Scatter(x=list(range(len(train_losses_picked))), y=train_losses_picked, mode='lines', name='Train Loss (CE)', line=dict(color='darkred', width=3)))
fig.add_trace(go.Scatter(x=list(range(len(cluster_losses_picked))), y=cluster_losses_picked, mode='lines', name='Clusterability', line=dict(color='darkblue', width=3)))
fig.add_shape(type="line", x0=0, y0=0.25, x1=len(train_losses_picked), y1=0.25, name='Baseline', line=dict(color="gray", width=2, dash="dash"))
fig.update_layout(title='', xaxis_title='Training Steps', yaxis_title='Train Loss v. Clusterability')

fig.add_trace(go.Scatter(x=[0], y=[0], mode='lines', name='Baseline', line=dict(color='gray', width=2, dash='dash')))


fig.update_layout({'plot_bgcolor': 'rgba(255, 255, 255, 1)',})
# show fine grid lines on both axes on both subplotsß
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='LightGray')
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGray')

# legend inside the plot in a box
fig.update_layout(legend=dict(x=0.55, y=0.93, traceorder="normal", bgcolor="white", bordercolor="black", borderwidth=1))

# width and height
fig.update_layout(width=500, height=500)

# make y axis start from 0
fig.update_yaxes(range=[0, 1])

# # ticks on both axes
fig.update_xaxes(tickmode='linear', tick0=0, dtick=500)
fig.update_yaxes(tickmode='linear', tick0=0, dtick=0.2)

# everthing latex font (for research paper)
fig.update_layout(font=dict(family='serif', size=15, color='black'))
fig.update_xaxes(title_font=dict(family='serif', size=18, color='black'))
fig.update_yaxes(title_font=dict(family='serif', size=18, color='black'))
fig.update_xaxes(tickfont=dict(family='serif', size=18, color='black'))
fig.update_yaxes(tickfont=dict(family='serif', size=18, color='black'))
fig.update_xaxes(showline=True, linewidth=1, linecolor='black', mirror=False)

fig.show()

In [37]:
# save
fig.write_image('plots/wiki_modular_mlp_in_training.pdf')

In [20]:
# load models

paths = ['./checkpoints/wiki_modular_mlp_in_model_epoch_1.pt', './checkpoints/wiki_modular_mlp_in_model_epoch_2.pt', './checkpoints/wiki_modular_mlp_in_model_epoch_3.pt', './checkpoints/wiki_modular_mlp_in_model_epoch_4.pt', './checkpoints/wiki_modular_mlp_in_model_epoch_5.pt', './checkpoints/wiki_modular_mlp_in_model_epoch_6.pt', './checkpoints/wiki_modular_mlp_in_model_epoch_7.pt', './checkpoints/wiki_modular_mlp_in_model_epoch_8.pt', './checkpoints/wiki_modular_mlp_in_model_epoch_9.pt', './checkpoints/wiki_modular_mlp_in_model_epoch_10.pt']

evals = []

for path in paths:

    model = HookedTransformer.from_pretrained("gpt2-small")
    model.load_state_dict(torch.load(path))
    model.to(device)
    evals.append(evaluate(model))

    print(path)

Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda



You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



36718


  0%|          | 0/293 [00:00<?, ?it/s]

wiki: 2.638025111491137
10000


  0%|          | 0/1373 [00:00<?, ?it/s]

owt: 3.852511294997565
10000


  0%|          | 0/2119 [00:00<?, ?it/s]

pile: 3.473926794410932
45404


  0%|          | 0/23199 [00:00<?, ?it/s]

code: 2.4532968667474124
./checkpoints/wiki_modular_mlp_in_model_epoch_1.pt
Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda
36718


  0%|          | 0/293 [00:00<?, ?it/s]

wiki: 2.5146239181556322
10000


  0%|          | 0/1373 [00:00<?, ?it/s]

owt: 4.25492256938821
10000


  0%|          | 0/2119 [00:00<?, ?it/s]

pile: 3.9084981479267085
45404


  0%|          | 0/23199 [00:00<?, ?it/s]

code: 2.7840299228630445
./checkpoints/wiki_modular_mlp_in_model_epoch_2.pt
Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda
36718


  0%|          | 0/293 [00:00<?, ?it/s]

wiki: 2.407421657354525
10000


  0%|          | 0/1373 [00:00<?, ?it/s]

owt: 4.681574830914488
10000


  0%|          | 0/2119 [00:00<?, ?it/s]

pile: 4.329598629828727
45404


  0%|          | 0/23199 [00:00<?, ?it/s]

code: 3.2547852284837475
./checkpoints/wiki_modular_mlp_in_model_epoch_3.pt
Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda
36718


  0%|          | 0/293 [00:00<?, ?it/s]

wiki: 2.2387401774378106
10000


  0%|          | 0/1373 [00:00<?, ?it/s]

owt: 5.0804329437784626
10000


  0%|          | 0/2119 [00:00<?, ?it/s]

pile: 4.703941373541804
45404


  0%|          | 0/23199 [00:00<?, ?it/s]

code: 3.7843725161977333
./checkpoints/wiki_modular_mlp_in_model_epoch_4.pt
Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda
36718


  0%|          | 0/293 [00:00<?, ?it/s]

wiki: 2.030227206721164
10000


  0%|          | 0/1373 [00:00<?, ?it/s]

owt: 5.5064976522237945
10000


  0%|          | 0/2119 [00:00<?, ?it/s]

pile: 5.103645905409709
45404


  0%|          | 0/23199 [00:00<?, ?it/s]

code: 4.000698682105187
./checkpoints/wiki_modular_mlp_in_model_epoch_5.pt
Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda
36718


  0%|          | 0/293 [00:00<?, ?it/s]

wiki: 1.8109123777635028
10000


  0%|          | 0/1373 [00:00<?, ?it/s]

owt: 5.919877226990049
10000


  0%|          | 0/2119 [00:00<?, ?it/s]

pile: 5.590842173831297
45404


  0%|          | 0/23199 [00:00<?, ?it/s]

code: 4.596446960279257
./checkpoints/wiki_modular_mlp_in_model_epoch_6.pt
Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda
36718


  0%|          | 0/293 [00:00<?, ?it/s]

wiki: 1.564398703008595
10000


  0%|          | 0/1373 [00:00<?, ?it/s]

owt: 6.313268000536626
10000


  0%|          | 0/2119 [00:00<?, ?it/s]

pile: 5.956836983709052
45404


  0%|          | 0/23199 [00:00<?, ?it/s]

code: 5.055065230567856
./checkpoints/wiki_modular_mlp_in_model_epoch_7.pt
Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda
36718


  0%|          | 0/293 [00:00<?, ?it/s]

wiki: 1.3214703715673768
10000


  0%|          | 0/1373 [00:00<?, ?it/s]

owt: 6.715229785088265
10000


  0%|          | 0/2119 [00:00<?, ?it/s]

pile: 6.388510260251489
45404


  0%|          | 0/23199 [00:00<?, ?it/s]

code: 5.1993512611578
./checkpoints/wiki_modular_mlp_in_model_epoch_8.pt
Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda
36718


  0%|          | 0/293 [00:00<?, ?it/s]

wiki: 1.1072587707255146
10000


  0%|          | 0/1373 [00:00<?, ?it/s]

owt: 7.219124477688629
10000


  0%|          | 0/2119 [00:00<?, ?it/s]

pile: 6.9105392729881965
45404


  0%|          | 0/23199 [00:00<?, ?it/s]

code: 5.925149955371819
./checkpoints/wiki_modular_mlp_in_model_epoch_9.pt
Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda
36718


  0%|          | 0/293 [00:00<?, ?it/s]

wiki: 0.8762277789635233
10000


  0%|          | 0/1373 [00:00<?, ?it/s]

owt: 7.583900919054995
10000


  0%|          | 0/2119 [00:00<?, ?it/s]

pile: 7.247782697772036
45404


  0%|          | 0/23199 [00:00<?, ?it/s]

code: 6.183383214591753
./checkpoints/wiki_modular_mlp_in_model_epoch_10.pt


In [21]:
evals

[{'wiki_loss': 2.638025111491137,
  'owt_loss': 3.852511294997565,
  'pile_loss': 3.473926794410932,
  'code_loss': 2.4532968667474124},
 {'wiki_loss': 2.5146239181556322,
  'owt_loss': 4.25492256938821,
  'pile_loss': 3.9084981479267085,
  'code_loss': 2.7840299228630445},
 {'wiki_loss': 2.407421657354525,
  'owt_loss': 4.681574830914488,
  'pile_loss': 4.329598629828727,
  'code_loss': 3.2547852284837475},
 {'wiki_loss': 2.2387401774378106,
  'owt_loss': 5.0804329437784626,
  'pile_loss': 4.703941373541804,
  'code_loss': 3.7843725161977333},
 {'wiki_loss': 2.030227206721164,
  'owt_loss': 5.5064976522237945,
  'pile_loss': 5.103645905409709,
  'code_loss': 4.000698682105187},
 {'wiki_loss': 1.8109123777635028,
  'owt_loss': 5.919877226990049,
  'pile_loss': 5.590842173831297,
  'code_loss': 4.596446960279257},
 {'wiki_loss': 1.564398703008595,
  'owt_loss': 6.313268000536626,
  'pile_loss': 5.956836983709052,
  'code_loss': 5.055065230567856},
 {'wiki_loss': 1.3214703715673768,
  'o

In [22]:
# plot the evals

import plotly.graph_objects as go
import plotly.express as px

wiki_losses = [evals[i]['wiki_loss'] for i in range(10)]
pile_losses = [evals[i]['pile_loss'] for i in range(10)]
owt_losses = [evals[i]['owt_loss'] for i in range(10)]
code_losses = [evals[i]['code_loss'] for i in range(10)]

fig = go.Figure()
fig.add_trace(go.Scatter(x=list(range(len(wiki_losses))), y=wiki_losses, mode='lines', name='Wiki', line=dict(color='darkred', width=3)))
fig.add_trace(go.Scatter(x=list(range(len(pile_losses))), y=pile_losses, mode='lines', name='Pile', line=dict(color='darkblue', width=3)))
fig.add_trace(go.Scatter(x=list(range(len(owt_losses))), y=owt_losses, mode='lines', name='OWT', line=dict(color='darkgreen', width=3)))
fig.add_trace(go.Scatter(x=list(range(len(code_losses))), y=code_losses, mode='lines', name='Code', line=dict(color='darkorange', width=3)))
fig.update_layout(title='', xaxis_title='Epochs', yaxis_title='Loss')

fig.update_layout({'plot_bgcolor': 'rgba(255, 255, 255, 1)',})

# show fine grid lines on both axes on both subplotsß
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='LightGray')
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGray')

# legend inside the plot in a box
fig.update_layout(legend=dict(x=0.55, y=0.87, traceorder="normal", bgcolor="white", bordercolor="black", borderwidth=1))

# width and height
fig.update_layout(width=500, height=500)

# # ticks on both axes
# fig.update_xaxes(tickmode='linear', tick0=0, dtick=100)

# everthing latex font (for research paper)
fig.update_layout(font=dict(family='serif', size=15, color='black'))
fig.update_xaxes(title_font=dict(family='serif', size=18, color='black'))
fig.update_yaxes(title_font=dict(family='serif', size=18, color='black'))
fig.update_xaxes(tickfont=dict(family='serif', size=18, color='black'))
fig.update_yaxes(tickfont=dict(family='serif', size=18, color='black'))

fig.show()