In [1]:
from sae_lens import HookedSAETransformer
import torch
import os

In [2]:
def generate_unigrams(transformer, layer, hook_name, all_tokens, t_fn, device, batch_size=32):
    """Collect the final-position activation at ``hook_name`` for ``t_fn(t)``, for every token ``t``.

    Args:
        transformer: model exposing ``run_with_cache`` (in this notebook a sae_lens
            ``HookedSAETransformer``).
        layer: forwarded as ``stop_at_layer`` so the forward pass ends early. It must
            be large enough that ``hook_name`` is still computed (TransformerLens
            treats ``stop_at_layer`` as exclusive).
        hook_name: name of the cached activation to read.
        all_tokens: iterable of token ids; consumed exactly once.
        t_fn: maps a token id to a list of token ids. The sequences of a batch are
            stacked with ``torch.tensor``, so they must all have the same length.
        device: device on which the token-id batches are created.
        batch_size: number of sequences per forward pass.

    Returns:
        A CPU tensor whose i-th row is the final-position activation for the i-th
        token's sequence, or ``None`` when ``all_tokens`` is empty.
    """
    sequences = [t_fn(t) for t in all_tokens]

    # Gather per-batch results and concatenate once at the end; growing a tensor
    # with torch.concat inside the loop re-copies everything collected so far on
    # every batch.
    chunks = []
    with torch.no_grad():
        for start in range(0, len(sequences), batch_size):
            ids = torch.tensor(sequences[start:start + batch_size], device=device)
            # Cache only the hook we read instead of every hook in the model.
            _, cache = transformer.run_with_cache(
                ids,
                prepend_bos=False,
                stop_at_layer=layer,
                names_filter=hook_name,
            )
            # Keep only the last sequence position, copied to CPU.
            chunks.append(cache[hook_name][:, -1, :].to('cpu'))

    if not chunks:
        return None
    return torch.cat(chunks, dim=0)
    

def all_unigrams(transformer, layer, hook_name, all_tokens, token_sequence_batch_size, device, bos_token_id=50256):
    """Run ``generate_unigrams`` for every token under a fixed set of prompt templates.

    Each template builds a short token-id sequence around a single token ``t``
    (leading BOS, padding, spaces, a repeat of ``t``, ...). ``generate_unigrams``
    then records the ``hook_name`` activation at the final position of that sequence.

    Args:
        transformer: model forwarded to ``generate_unigrams``.
        layer: forwarded to ``generate_unigrams`` (used there as ``stop_at_layer``).
        hook_name: name of the cached activation to read.
        all_tokens: iterable of token ids to probe. It is materialized once, so
            one-shot iterators work.
        token_sequence_batch_size: sequences per forward pass.
        device: device on which token-id batches are created.
        bos_token_id: id placed in the ``<bos>`` and ``<pad>`` slots of the
            templates. Defaults to 50256 (GPT-2's ``<|endoftext|>``); override for
            other tokenizers.

    Returns:
        dict mapping each template key (e.g. ``'[<bos>, t]'``) to the tensor returned
        by ``generate_unigrams`` for that template, in the order listed below.
    """
    # Materialize once: the tokens are consumed once per template, so a one-shot
    # iterator would otherwise leave every template after the first empty.
    tokens = list(all_tokens)
    bos = bos_token_id

    # Template key -> builder of the token-id sequence for token t.
    # In the keys, token 0 is labelled "!" and token 220 is labelled " ".
    templates = {
        '[<bos>, t]': lambda t: [bos, t],
        '[<bos>, t, !]': lambda t: [bos, t, 0],
        '[t]': lambda t: [t],
        '[<bos>, <pad>, t]': lambda t: [bos, bos, t],
        '[t, t]': lambda t: [t, t],
        '[" ", t]': lambda t: [220, t],
        '[" ", " ", t]': lambda t: [220, 220, t],
        '[37233, t]': lambda t: [37233, t],
        # NOTE(review): this template ignores t, so every row comes from the same
        # sequence and the forward pass repeats len(tokens) times for one distinct
        # input. Confirm that is intended (vs. e.g. [<bos>, 37233, t]).
        '[<bos>, 37233]': lambda t: [bos, 37233],
    }

    store = {}
    with torch.no_grad():
        for key, t_fn in templates.items():
            store[key] = generate_unigrams(
                transformer, layer, hook_name, tokens, t_fn, device,
                batch_size=token_sequence_batch_size,
            )
    return store

In [3]:
# Load GPT-2 small through sae_lens directly onto the GPU and keep a short alias to
# its tokenizer (its vocab_size drives the token sweep).
transformer = HookedSAETransformer.from_pretrained(
    'gpt2-small',
    device='cuda',
)
tok = transformer.tokenizer



Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
# Probe every token id in the vocabulary. layer=11 is passed on as stop_at_layer;
# TransformerLens treats it as exclusive, so block 10 (whose residual-stream output
# is read) is the last block run.
all_tokens = list(range(tok.vocab_size))
store = all_unigrams(
    transformer,
    layer=11,
    hook_name='blocks.10.hook_resid_post',
    all_tokens=all_tokens,
    token_sequence_batch_size=1024,
    device='cuda',
)

In [None]:
# Inspect each template's activations: key, shape and overall standard deviation.
# Printing the key makes every number attributable to its template; printing the
# shape instead of the tensor avoids dumping a vocab-sized repr.
for k, v in store.items():
    print(k, tuple(v.shape), v.std())

In [5]:
# Z-score each activation dimension across tokens (rows), separately per template,
# replacing the entries of `store` in place. The 1e-6 avoids division by zero.
# NOTE(review): a template whose rows are all identical, such as '[<bos>, 37233]'
# (its sequence ignores t), has per-dimension std near 0. Here that divides
# floating-point noise by roughly 1e-6, so it does not end up with unit std.
# Consider excluding or special-casing such templates.
for k in list(store):
    v = store[k]
    batch_mean, batch_std = v.mean(dim=0, keepdim=True), v.std(dim=0, keepdim=True)
    normalized = (v - batch_mean) / (batch_std + 1e-6)
    store[k] = normalized

In [11]:
# Overall standard deviation of each template's tensor, labelled by template key.
for template, acts in store.items():
    print(template, acts.std())

[<bos>, t] tensor(1.0000)
[<bos>, t, !] tensor(1.0000)
[t] tensor(1.0000)
[<bos>, <pad>, t] tensor(1.0000)
[t, t] tensor(1.0000)
[" ", t] tensor(1.0000)
[" ", " ", t] tensor(1.0000)
[37233, t] tensor(1.0000)
[<bos>, 37233] tensor(1.6570)


In [7]:
# Persist the per-template activations under a "_norm" filename, creating the output
# directory when missing. exist_ok=True replaces the separate (racy) existence check.
os.makedirs('../cruft', exist_ok=True)
torch.save(store, '../cruft/unigrams_gpt2_blocks.10.hook_resid_post_norm.pth')

In [20]:
import torch

In [35]:
# Reload a previously saved unigram store (a dict of template key -> tensor).
unigram_store_path = '../cruft/unigrams_gpt2_blocks.10.hook_resid_post.pth'
a = torch.load(unigram_store_path)

In [36]:
# Show which prompt templates the reloaded store contains.
a.keys()

dict_keys(['[<bos>, t]', '[<bos>, t, !]', '[t]', '[<bos>, <pad>, t]', '[t, t]', '[" ", t]', '[" ", " ", t]', '[37233, t]'])

In [37]:
# Overall std of each template's tensor in the reloaded store, labelled by key so
# each value can be attributed to its template.
for k, v in a.items():
    print(k, v.std())

tensor(1.0000)
tensor(1.0000)
tensor(1.0000)
tensor(1.0000)
tensor(1.0000)
tensor(1.0000)
tensor(1.0000)
tensor(1.0000)


In [31]:
# Overall std of each template's tensor before z-scoring, labelled by template key.
for k, v in a.items():
    print(k, v.std())

tensor(8.6994)
tensor(9.9991)
tensor(112.4044)
tensor(8.6484)
tensor(8.8983)
tensor(8.6744)
tensor(8.7332)
tensor(8.3452)


In [32]:

# Z-score each activation dimension of the reloaded store across tokens (rows), per
# template, replacing the entries of `a` in place. The 1e-6 avoids division by zero.
# NOTE(review): running this on a file that was already saved in normalized form
# (or re-running the cell) z-scores data that is already z-scored.
for k in list(a):
    v = a[k]
    batch_mean, batch_std = v.mean(dim=0, keepdim=True), v.std(dim=0, keepdim=True)
    normalized = (v - batch_mean) / (batch_std + 1e-6)
    a[k] = normalized

In [33]:
# Overall std of each template's tensor after z-scoring (expected close to 1),
# labelled by template key.
for k, v in a.items():
    print(k, v.std())

tensor(1.0000)
tensor(1.0000)
tensor(1.0000)
tensor(1.0000)
tensor(1.0000)
tensor(1.0000)
tensor(1.0000)
tensor(1.0000)


In [34]:
# NOTE(review): this writes back to the same path `a` was loaded from, so the
# un-normalized activations on disk are replaced by their z-scored version, and any
# later reload returns already-normalized data. Consider a separate "_norm"
# filename if the raw values are still needed.
torch.save(obj=a, f='../cruft/unigrams_gpt2_blocks.10.hook_resid_post.pth')