In [1]:
import os

import torch
import torch as t
from transformer_lens import HookedTransformer
from transformer_lens.evals import evaluate

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Unmodified pretrained GPT-2 small, kept as the reference model.
base_model = HookedTransformer.from_pretrained("gpt2-small")
base_model.to(device)

# Baseline losses of the untouched model; evaluate() returns a dict keyed by
# '<dataset>_loss' (see the displayed value of base_results below).
base_results = evaluate(base_model)

Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda
36718


  0%|          | 0/293 [00:00<?, ?it/s]

wiki: 3.3703102546163124
10000


  0%|          | 0/1373 [00:00<?, ?it/s]

owt: 3.140637449698873
10000


  0%|          | 0/2119 [00:00<?, ?it/s]

pile: 2.903512966514814
45404


  0%|          | 0/23199 [00:00<?, ?it/s]

code: 1.9757940934436156


In [2]:
# Display the baseline loss dict produced by evaluate(base_model) in the first cell.
base_results

{'wiki_loss': 3.3703102546163124,
 'owt_loss': 3.140637449698873,
 'pile_loss': 2.903512966514814,
 'code_loss': 1.9757940934436156}

In [14]:
def clusterability(matrix, cluster_U_indices=None, cluster_V_indices=None, num_clusters=4):
    """Fraction of squared weight mass that lies inside matching row/column clusters.

    Cluster ``i`` pairs the rows in ``cluster_U_indices[i]`` with the columns in
    ``cluster_V_indices[i]``. The score is the sum of ``matrix ** 2`` over all
    such (row, column) pairs divided by the sum of ``matrix ** 2`` over the whole
    matrix, so 1.0 means every non-zero entry sits inside a cluster block.

    Args:
        matrix: 2-D tensor of weights. The result stays attached to its autograd
            graph, so it can be used as a (differentiable) training objective.
        cluster_U_indices: Optional mapping/sequence such that
            ``cluster_U_indices[i]`` holds the row indices of cluster ``i`` for
            every ``i`` in ``range(num_clusters)``. When ``None``, rows are split
            into ``num_clusters`` equal contiguous chunks; leftover rows (when
            the row count is not divisible) belong to no cluster.
        cluster_V_indices: Same as ``cluster_U_indices`` but for columns.
        num_clusters: Number of clusters to evaluate.

    Returns:
        0-dim tensor with the intra-cluster share of squared weight. It is NaN
        when the matrix is all zeros (0 / 0).
    """
    A = matrix ** 2

    # Only build the default contiguous partition when the caller did not supply
    # one; previously the arguments were overwritten and silently ignored.
    if cluster_U_indices is None:
        rows_per_cluster = A.shape[0] // num_clusters
        cluster_U_indices = {
            i: t.arange(i * rows_per_cluster, (i + 1) * rows_per_cluster, device=A.device)
            for i in range(num_clusters)
        }
    if cluster_V_indices is None:
        cols_per_cluster = A.shape[1] // num_clusters
        cluster_V_indices = {
            i: t.arange(i * cols_per_cluster, (i + 1) * cols_per_cluster, device=A.device)
            for i in range(num_clusters)
        }

    # Boolean mask marking every (row, col) pair that belongs to the same cluster.
    mask = t.zeros_like(A, dtype=t.bool)
    for cluster_idx in range(num_clusters):
        u_indices = t.as_tensor(cluster_U_indices[cluster_idx], dtype=t.long, device=A.device)
        v_indices = t.as_tensor(cluster_V_indices[cluster_idx], dtype=t.long, device=A.device)
        # (n_rows, 1) broadcast against (n_cols,) selects the full rows x cols block.
        mask[u_indices.unsqueeze(1), v_indices] = True

    intra_cluster_out_sum = t.sum(A[mask])
    total_out_sum = t.sum(A)

    return intra_cluster_out_sum / total_out_sum

In [4]:
from transformer_lens.evals import make_wiki_data_loader, make_pile_data_loader, make_owt_data_loader, make_code_data_loader

# One data loader per evaluation corpus, all tokenized with the base model's
# tokenizer. The pairs are listed in the order the loaders are constructed.
datasets = {
    name: make_loader(base_model.tokenizer, batch_size=8)
    for name, make_loader in (
        ('wiki', make_wiki_data_loader),
        ('pile', make_pile_data_loader),
        ('owt', make_owt_data_loader),
        ('code', make_code_data_loader),
    )
}

36718
10000
10000
45404


In [15]:
# Clusterability of the pretrained model's MLP input weights at block index 4,
# using the default of 4 equal contiguous clusters. With that default the
# cluster mask covers 1/4 of the entries (when both dimensions divide evenly
# by 4), so a score near 0.25 is what evenly spread weight mass would give.
clusterability(base_model.blocks[4].mlp.W_in)

tensor(0.2493, device='cuda:0', grad_fn=<DivBackward0>)

In [18]:
# Fine-tune a fresh GPT-2 small on the wiki loader while rewarding clusterable
# MLP input weights: the optimized objective is
#     language-model loss - lomda * mean clusterability over the MLP W_in matrices.

# Per-batch histories (one entry per optimizer step).
cluster_losses = []
train_losses = []
test_losses = []  # not populated in this cell
lomda = 20.0  # weight of the clusterability term in the objective
model = HookedTransformer.from_pretrained("gpt2-small")
model.to(device)
# MLP input weight matrix of every transformer block in the model.
blocks_to_cluster = [block.mlp.W_in for block in model.blocks]
path = './checkpoints/'
num_epochs = 1
num_clusters = 4
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
model.train()

# Create the checkpoint directory up front so a missing folder fails (or is
# fixed) before the training run rather than at the final torch.save call.
os.makedirs(path, exist_ok=True)

for epoch in range(num_epochs):
    for idx, batch in enumerate(datasets['wiki']):
        tokens = batch['tokens'].to(device)
        # Mean clusterability across the selected weight matrices.
        cluster_loss = sum(
            clusterability(block, num_clusters=num_clusters) for block in blocks_to_cluster
        ) / len(blocks_to_cluster)
        train_loss = model(tokens, return_type="loss")
        cluster_losses.append(cluster_loss.item())
        train_losses.append(train_loss.item())
        # Subtracting the term means gradient descent increases clusterability.
        loss = train_loss - lomda * cluster_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if idx % 100 == 0:
            print(f'Epoch {epoch+1}, Batch {idx}, Train Loss: {round(train_loss.item(), 4)}, Clusterability: {round(cluster_loss.item(), 4)}')

torch.save(model.state_dict(), os.path.join(path, 'wiki_modular_mlp_in_model.pt'))

Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda
Epoch 1, Batch 0, Train Loss: 3.3778, Clusterability: 0.25
Epoch 1, Batch 100, Train Loss: 2.8608, Clusterability: 0.282
Epoch 1, Batch 200, Train Loss: 2.8833, Clusterability: 0.3212
