In [1]:
from sae_lens import HookedSAETransformer
import torch

In [2]:
def generate_unigrams(transformer, layer, hook_name, all_tokens, t_fn, device, batch_size=32):
    """Collect the last-position activation at ``hook_name`` for every token.

    Each token ``t`` in ``all_tokens`` is expanded into a short token-id
    sequence by ``t_fn(t)``. The sequences are run through ``transformer`` in
    batches, and for every sequence the activation at the final position is
    kept.

    Args:
        transformer: Model object exposing ``run_with_cache``.
        layer: Forwarded unchanged as ``stop_at_layer`` to ``run_with_cache``.
        hook_name: Key of the cached activation to read from each batch.
        all_tokens: Iterable of token ids.
        t_fn: Callable mapping one token id to a list of token ids. All
            sequences inside a batch must have the same length, because each
            batch is turned into a single tensor.
        device: Device on which the id tensors are created.
        batch_size: Number of sequences per forward pass.

    Returns:
        A CPU tensor with one row per token, in ``all_tokens`` order, or
        ``None`` when ``all_tokens`` is empty.
    """
    sequences = [t_fn(t) for t in all_tokens]

    # Gather the per-batch slices and concatenate once at the end; growing a
    # tensor with concat inside the loop re-copies every earlier row each time.
    chunks = []
    for start in range(0, len(sequences), batch_size):
        ids = torch.tensor(sequences[start:start + batch_size], device=device)

        # Only `hook_name` is read back, so ask the cache to keep just that
        # entry instead of every hook up to `layer`.
        _, cache = transformer.run_with_cache(
            ids,
            prepend_bos=False,
            stop_at_layer=layer,
            names_filter=hook_name,
        )
        chunks.append(cache[hook_name][:, -1, :].to('cpu'))

    if not chunks:
        # No tokens were supplied, so there is nothing to concatenate.
        return None
    return torch.cat(chunks, dim=0)
    

def all_unigrams(transformer, layer, hook_name, all_tokens, token_sequence_batch_size, device):
    """Run ``generate_unigrams`` once per fixed prefix context.

    Every token in ``all_tokens`` is placed after several different prefixes
    (the tokenizer's BOS id, the token itself, and a few hard-coded token
    ids), and the resulting activations are gathered per context.

    Args:
        transformer: Model handed to ``generate_unigrams``; its
            ``tokenizer.bos_token_id`` supplies the ``'bos'`` prefix.
        layer: Forwarded to ``generate_unigrams``.
        hook_name: Forwarded to ``generate_unigrams``.
        all_tokens: Token ids to evaluate in every context.
        token_sequence_batch_size: Forwarded as ``batch_size``.
        device: Forwarded to ``generate_unigrams``.

    Returns:
        Dict keyed by ``'bos'``, ``'repeat'``, ``'space'``, ``'newline'``,
        ``'space2'`` and ``'rand'``, each holding the value returned by
        ``generate_unigrams`` for that context.
    """
    # One sequence builder per context; the token under test is always last.
    # NOTE(review): 220 and 198 look like the GPT-2 ids for ' ' and '\n', and
    # 37233 an arbitrary fixed token -- confirm before reusing this with a
    # different tokenizer.
    sequence_builders = {
        'bos': lambda t: [transformer.tokenizer.bos_token_id, t],
        'repeat': lambda t: [t, t],
        'space': lambda t: [220, t],
        'newline': lambda t: [198, t],
        'space2': lambda t: [220, 220, t],
        'rand': lambda t: [37233, t],
    }

    results = {}
    # Gradient tracking is disabled for all of the forward passes.
    with torch.no_grad():
        for context_name, build_sequence in sequence_builders.items():
            results[context_name] = generate_unigrams(
                transformer,
                layer,
                hook_name,
                all_tokens,
                build_sequence,
                device,
                batch_size=token_sequence_batch_size,
            )

    return results

In [3]:
# Load the pretrained 'gpt2-small' model, requesting the CUDA device, and keep
# a short alias for its tokenizer for the cells below.
transformer = HookedSAETransformer.from_pretrained('gpt2-small', device='cuda')
tok = transformer.tokenizer

Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
# Exploratory only: list the tokenizer's attributes. No later cell uses this
# cell's output.
dir(tok)

['SPECIAL_TOKENS_ATTRIBUTES',
 '__annotations__',
 '__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_add_tokens',
 '_additional_special_tokens',
 '_auto_class',
 '_batch_encode_plus',
 '_bos_token',
 '_call_one',
 '_cls_token',
 '_compile_jinja_template',
 '_convert_encoding',
 '_convert_id_to_token',
 '_convert_token_to_id_with_added_voc',
 '_create_repo',
 '_decode',
 '_decode_use_source_tokenizer',
 '_encode_plus',
 '_eos_token',
 '_eventual_warn_about_too_long_sequence',
 '_eventually_correct_t5_max_length',
 '_from_pretrained',
 '_get_files_timestamps',
 '_get_padding_truncation_strategies',
 '_in_target_context_manager',
 

In [6]:
# Show the tokenizer's vocabulary size; the next cell enumerates token ids
# 0 .. vocab_size - 1 from this value.
tok.vocab_size

50257

In [7]:

# Enumerate every token id in the tokenizer's vocabulary.
all_tokens = [*range(tok.vocab_size)]

# Collect activations for each token in every prefix context.
# NOTE(review): `layer` is forwarded as `stop_at_layer`; 11 is presumably an
# exclusive bound so that block 10 (the hook read here) still runs -- confirm
# against the TransformerLens documentation.
store = all_unigrams(
    transformer,
    layer=11,
    hook_name='blocks.10.hook_resid_post',
    all_tokens=all_tokens,
    token_sequence_batch_size=1024,
    device='cuda',
)

In [9]:
# Persist the dict of per-context activations; the file name repeats the hook
# name used when `store` was built.
# NOTE(review): the path is relative to the working directory and the
# '../cruft' directory presumably has to exist already -- confirm.
torch.save(store, '../cruft/unigrams_blocks.10.hook_resid_post.pth')