In [1]:
from transformer_lens import HookedTransformer

In [2]:
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

device

'cuda'

In [3]:
base_model = HookedTransformer.from_pretrained("gpt2-small")
base_model.to(device)

Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda


HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (h

## Evaluate the Base Model (MLP)

In [4]:
from transformer_lens.evals import evaluate

In [5]:
base_results = evaluate(base_model)

36718


  0%|          | 0/293 [00:00<?, ?it/s]

wiki: 3.3630548208066733
10000


  0%|          | 0/1373 [00:00<?, ?it/s]

owt: 3.149438685709887
10000


  0%|          | 0/2119 [00:00<?, ?it/s]

pile: 2.855179746552269
45404


  0%|          | 0/23199 [00:00<?, ?it/s]

code: 2.044976301712565


In [6]:
base_results

{'wiki_loss': 3.3630548208066733,
 'owt_loss': 3.149438685709887,
 'pile_loss': 2.855179746552269,
 'code_loss': 2.044976301712565}

In [7]:
# evaluate clusterability

In [5]:
import torch as t

def clusterability(matrix, cluster_U_indices, cluster_V_indices):
    num_clusters = len(cluster_U_indices)
    A = matrix ** 2
    mask = t.zeros_like(A, dtype=t.bool)
    
    for cluster_idx in range(num_clusters):
        u_indices = t.tensor(cluster_U_indices[cluster_idx], dtype=t.long)
        v_indices = t.tensor(cluster_V_indices[cluster_idx], dtype=t.long)
        mask[u_indices.unsqueeze(1), v_indices] = True
    
    intra_cluster_out_sum = t.sum(A[mask])
    total_out_sum = t.sum(A)
    
    return intra_cluster_out_sum / total_out_sum

In [6]:
from sklearn.cluster import KMeans
from scipy.sparse.linalg import svds
from collections import defaultdict
import numpy as np

def bipartite_spectral_clustering(similarity_matrix, k, cluster_U_indices=None, cluster_V_indices=None):
    
    A = similarity_matrix.detach().cpu().numpy() # transform from gpu to cpu, and then to numpy
    A = np.abs(A)
    D_U = np.diag(np.sum(A, axis=1))
    D_V = np.diag(np.sum(A, axis=0))
    D_U_inv_sqrt = np.linalg.inv(np.sqrt(D_U))
    D_V_inv_sqrt = np.linalg.inv(np.sqrt(D_V))
    A_tilde = D_U_inv_sqrt @ A @ D_V_inv_sqrt
    U, Sigma, Vt = svds(A_tilde, k=k)
    if cluster_U_indices is None:
        kmeans_U = KMeans(n_clusters=k, random_state=42).fit(U)
        labels_U = kmeans_U.labels_
        cluster_U_indices = defaultdict(list)
        for i, label in enumerate(labels_U):
            cluster_U_indices[label].append(i)
    if cluster_V_indices is None:
        kmeans_V = KMeans(n_clusters=k, random_state=42).fit(Vt.T)
        labels_V = kmeans_V.labels_
        cluster_V_indices = defaultdict(list)
        for i, label in enumerate(labels_V):
            cluster_V_indices[label].append(i)

    return cluster_U_indices, cluster_V_indices

In [7]:
num_clusters = 4
similarity_matrix = base_model.blocks[5].mlp.W_in
cluster_U_indices, cluster_V_indices = bipartite_spectral_clustering(similarity_matrix, num_clusters)

for i in range(num_clusters):
    print(f'Cluster {i} has {len(cluster_U_indices[i])} nodes in U and {len(cluster_V_indices[i])} nodes in V')

clusterability_score = clusterability(similarity_matrix, cluster_U_indices, cluster_V_indices)
print(f'Clusterability score: {round(clusterability_score.item(), 3)}')

Cluster 0 has 332 nodes in U and 688 nodes in V
Cluster 1 has 338 nodes in U and 711 nodes in V
Cluster 2 has 1 nodes in U and 646 nodes in V
Cluster 3 has 97 nodes in U and 1027 nodes in V
Clusterability score: 0.243


In [8]:
similarity_matrix.shape

torch.Size([768, 3072])

In [9]:
# sequential equal clusters for the MLP W_in matrix

num_clusters = 4
cluster_size = (similarity_matrix.shape[0] // num_clusters, similarity_matrix.shape[1] // num_clusters)
cluster_U_indices = {i: list(range(i*cluster_size[0], (i+1)*cluster_size[0])) for i in range(num_clusters)}
cluster_V_indices = {i: list(range(i*cluster_size[1], (i+1)*cluster_size[1])) for i in range(num_clusters)}

clusterability_score = clusterability(similarity_matrix, cluster_U_indices, cluster_V_indices)

for i in range(num_clusters):
    print(f'Cluster {i} has {len(cluster_U_indices[i])} nodes in U and {len(cluster_V_indices[i])} nodes in V')
print(f'Clusterability score: {round(clusterability_score.item(), 3)}')

Cluster 0 has 192 nodes in U and 768 nodes in V
Cluster 1 has 192 nodes in U and 768 nodes in V
Cluster 2 has 192 nodes in U and 768 nodes in V
Cluster 3 has 192 nodes in U and 768 nodes in V
Clusterability score: 0.25


In [10]:
from transformer_lens.evals import make_wiki_data_loader, make_pile_data_loader, make_owt_data_loader, make_code_data_loader

datasets = {
    'wiki': make_wiki_data_loader(base_model.tokenizer, batch_size=8),
    'pile': make_pile_data_loader(base_model.tokenizer, batch_size=8),
    'owt': make_owt_data_loader(base_model.tokenizer, batch_size=8),
    'code': make_code_data_loader(base_model.tokenizer, batch_size=8),
}

36718
10000
10000
45404


In [21]:
for idx, batch in enumerate(datasets['wiki']):
    print(idx)
    # print(batch['tokens'].shape)

    # # get loss of the model on the batch
    # tokens = batch['tokens'].to(device)
    # loss = base_model(tokens, return_type='loss')
    # print(loss)

    break

0


In [None]:
## Expt 1: Wiki on all layer MLP_in

cluster_losses = []
train_losses = []
test_losses = []
lomda = 20.0
model = HookedTransformer.from_pretrained("gpt2-small")
blocks = [model.blocks[i].mlp.W_in for i in range(12)]
path = './checkpoints/'
num_epochs = 30

for block_idx, block in enumerate(blocks):
    num_clusters = 4
    cluster_size = (block.shape[0] // num_clusters, block.shape[1] // num_clusters)
    cluster_U_indices = {i: list(range(i*cluster_size[0], (i+1)*cluster_size[0])) for i in range(num_clusters)}
    cluster_V_indices = {i: list(range(i*cluster_size[1], (i+1)*cluster_size[1])) for i in range(num_clusters)}
    model = HookedTransformer.from_pretrained("gpt2-small")
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    model.train()
    for epoch in range(num_epochs):
        for idx, batch in enumerate(datasets['wiki']):
            tokens = batch['tokens'].to(device)
            block_new = model.blocks[block_idx].mlp.W_in
            cluster_loss_mlp_in = clusterability(block_new, cluster_U_indices, cluster_V_indices)
            train_loss = model(tokens, return_type="loss")
            cluster_loss = cluster_loss_mlp_in
            cluster_losses.append(cluster_loss.item())
            loss = train_loss - lomda * cluster_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if idx % 100 == 0:
                # print(f'Epoch {epoch+1}, Batch {idx}, Train Loss: {train_loss.item()}, Cluster Loss: {cluster_loss.item()}')
                print(f'Epoch {epoch+1}, Batch {idx}, Cluster Loss: {cluster_loss.item()}')
    # save the model
    torch.save(model.state_dict(), path + f'wiki_mlp_in_{block_idx}.pt')

Loaded pretrained model gpt2-small into HookedTransformer
Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda
Epoch 1, Batch 0, Cluster Loss: 0.24992132186889648
Epoch 1, Batch 100, Cluster Loss: 0.28389742970466614
Epoch 1, Batch 200, Cluster Loss: 0.3224729895591736
Epoch 2, Batch 0, Cluster Loss: 0.3622661232948303
Epoch 2, Batch 100, Cluster Loss: 0.4083871841430664
Epoch 2, Batch 200, Cluster Loss: 0.4566418528556824
Epoch 3, Batch 0, Cluster Loss: 0.5020801424980164
Epoch 3, Batch 100, Cluster Loss: 0.5501216053962708
Epoch 3, Batch 200, Cluster Loss: 0.5960720777511597
Epoch 4, Batch 0, Cluster Loss: 0.636108934879303
Epoch 4, Batch 100, Cluster Loss: 0.6757259964942932
Epoch 4, Batch 200, Cluster Loss: 0.7115509510040283
Epoch 5, Batch 0, Cluster Loss: 0.7414720058441162
Epoch 5, Batch 100, Cluster Loss: 0.7701610922813416
Epoch 5, Batch 200, Cluster Loss: 0.7955107688903809
Epoch 6, Batch 0, Cluster Loss: 0.8163723945617676
Epoch 6, Batch 10

KeyboardInterrupt: 