In [1]:
from sae_lens import HookedSAETransformer
import torch
import os

In [2]:
def generate_unigrams(transformer, layer, hook_name, all_tokens, t_fn, device, batch_size=32):
    """Collect the last-position activation at ``hook_name`` for one short prompt per token.

    Every token ``t`` in ``all_tokens`` is expanded into a prompt ``t_fn(t)`` (a list
    of token ids). The prompts are run through ``transformer`` in batches and, for
    each prompt, the activation at the final sequence position is kept.

    This function does not disable autograd itself; ``all_unigrams`` calls it
    inside ``torch.no_grad()``.

    Args:
        transformer: Model exposing ``run_with_cache``.
        layer: Forwarded to ``run_with_cache`` as ``stop_at_layer``.
        hook_name: Key of the cached activation to read.
        all_tokens: Iterable of token ids; it is consumed exactly once.
        t_fn: Callable mapping a token id to a list of token ids. Prompts that end
            up in the same batch must have equal length, because each batch is
            converted into a single rectangular tensor.
        device: Device on which the input id tensors are created.
        batch_size: Number of prompts per forward pass.

    Returns:
        A CPU tensor with one row per token, in input order, or ``None`` when
        ``all_tokens`` is empty (no forward pass is run in that case).
    """
    sequences = [t_fn(t) for t in all_tokens]

    # Gather per-batch slices and concatenate once at the end; re-concatenating
    # the growing result on every iteration copies all earlier rows each time.
    chunks = []
    for start in range(0, len(sequences), batch_size):
        ids = torch.tensor(sequences[start:start + batch_size], device=device)

        # names_filter restricts the cache to the single hook that is read below,
        # so the other activations of the forward pass are not stored.
        _, cache = transformer.run_with_cache(
            ids,
            prepend_bos=False,
            stop_at_layer=layer,
            names_filter=hook_name,
        )
        chunks.append(cache[hook_name][:, -1, :].to('cpu'))

    if not chunks:
        return None
    return torch.cat(chunks, dim=0)
    

def all_unigrams(transformer, layer, hook_name, all_tokens, token_sequence_batch_size, device):
    """Run ``generate_unigrams`` once per prompt context and collect the results.

    Each context wraps a token ``t`` in a different fixed prompt. The same
    ``transformer``, ``layer``, ``hook_name``, ``all_tokens`` and ``device`` are
    forwarded to every call, with ``token_sequence_batch_size`` used as the batch
    size. All calls run under ``torch.no_grad()``.

    Returns:
        dict mapping the context name (``'bos'``, ``'repeat'``, ``'space'``,
        ``'newline'``, ``'space2'``, ``'rand'``) to whatever ``generate_unigrams``
        returned for that context.
    """
    def _prefixed(*prefix_ids):
        # Prompt builder that places the fixed ids in front of the token.
        return lambda token: [*prefix_ids, token]

    # NOTE(review): 220 / 198 / 37233 are hard-coded token ids; judging by the key
    # names they are presumably the GPT-2 ids for a space, a newline and an
    # arbitrary token - confirm before using another tokenizer.
    prompt_builders = {
        # The BOS id is looked up lazily, only when a prompt is actually built.
        'bos': lambda token: [transformer.tokenizer.bos_token_id, token],
        'repeat': lambda token: [token, token],
        'space': _prefixed(220),
        'newline': _prefixed(198),
        'space2': _prefixed(220, 220),
        'rand': _prefixed(37233),
    }

    with torch.no_grad():
        return {
            context: generate_unigrams(
                transformer,
                layer,
                hook_name,
                all_tokens,
                build_prompt,
                device,
                batch_size=token_sequence_batch_size,
            )
            for context, build_prompt in prompt_builders.items()
        }

In [3]:
# Use the GPU when one is present and fall back to CPU otherwise, so that
# loading the model does not fail on a machine without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
transformer = HookedSAETransformer.from_pretrained('gpt2-small', device=device)
tok = transformer.tokenizer

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]



Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
# One prompt per vocabulary entry.
all_tokens = list(range(tok.vocab_size))
# Take the device from the loaded model instead of repeating a hard-coded string,
# so the input tensors are always created on the device that holds the weights.
# NOTE(review): 11 is forwarded as stop_at_layer; presumably that bound is
# exclusive, making block 10 (whose residual output is read) the last one run - confirm.
store = all_unigrams(transformer, 11, 'blocks.10.hook_resid_post', all_tokens, 1024, transformer.cfg.device)

In [5]:
# exist_ok=True makes the directory creation idempotent and removes the
# check-then-create race of a separate os.path.exists() test.
os.makedirs('../cruft', exist_ok=True)
torch.save(store, '../cruft/unigrams_gpt2_blocks.10.hook_resid_post.pth')

In [6]:
# Reload the file saved above to check that it round-trips.
# NOTE(review): torch.load unpickles its input - only point it at files produced by this notebook.
a = torch.load('../cruft/unigrams_gpt2_blocks.10.hook_resid_post.pth')

In [7]:
# Show which context names are present in the reloaded dict.
a.keys()

dict_keys(['bos', 'repeat', 'space', 'newline', 'space2', 'rand'])