In [1]:

def imshow(tensor, renderer=None, xaxis="", yaxis="", font_size=None, show=True, color_continuous_midpoint=0.0, **kwargs):
    import plotly.express as px
    import transformer_lens.utils as utils
    fig = px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=color_continuous_midpoint, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs)
    if not font_size is None:
        if 'x' in kwargs:
            fig.update_layout(
              xaxis = dict(
                tickmode='array',
                tickvals = kwargs['x'],
                ticktext = kwargs['x'], 
                ),
               font=dict(size=font_size, color="black"))
        if 'y' in kwargs:
            fig.update_layout(
              yaxis = dict(
                tickmode='array',
                tickvals = kwargs['y'],
                ticktext = kwargs['y'], 
                ),
               font=dict(size=font_size, color="black"))
    plot_args = {
        'width': 800,
        'height': 600,
        "autosize": False,
        'showlegend': True,
        'margin': {"l":0,"r":0,"t":100,"b":0}
    }
    
    fig.update_layout(**plot_args)
    fig.update_layout(legend=dict(
        yanchor="top",
        y=0.99,
        xanchor="left",
        x=0.01
    ))
    if show:
        fig.show(renderer)
    else:
        return fig


# requires
# pip install git+https://github.com/Phylliida/MambaLens.git

from mamba_lens import HookedMamba # this will take a little while to import
import torch
model_path = "state-spaces/mamba-370m"

# NOTE! We need to monkeypatch transformer lens to use register_full_backward_hook
model = HookedMamba.from_pretrained(model_path, device='cuda')
torch.set_grad_enabled(False)


from safetensors.torch import load_file
from collections import defaultdict

import sys
if not "/home/dev/sae-k-sparse-mamba/sae" in sys.path:
    sys.path.append("/home/dev/sae-k-sparse-mamba")
import os
os.chdir('/home/dev/sae-k-sparse-mamba')
saes = [None]
from importlib import reload
from sae.sae import Sae

ckpt_dir = "/home/dev/sae-k-sparse-mamba/"
for i in range(1,22):
    print(i)
    hook = f'blocks.{i}.hook_out_proj'
    path = [ckpt_dir + f for f in sorted(list(os.listdir(ckpt_dir))) if hook in f][0] + "/" + f'hook_{hook}.pt'
    #path = f'/home/dev/sae-k-sparse-mamba/blocks.{i}.hook_resid_pre/hook_blocks.{i}.hook_resid_pre.pt'
    print(path)
    saes.append(Sae.load_from_disk(path, hook=f'blocks.{i}.hook_resid_pre', device=model.cfg.device))


global PATCHING_FORMAT_I
global patching_formats
def make_data(num_patching_pairs, patching, template_i, seed, valid_seed):
    constrain_to_answers = True
    # this makes our data size 800, first 400 is each (a,b) pair, and then second 400 is each pair swapped to be (b,a)
    has_symmetric_patching = True
    
    n1_patchings = ["""
    ABC BC A
    DBC BC D""",
        """
    ABC CB A
    DBC CB D"""]
    
    n2_patchings = ["""
    ABC AC B
    ADC AC D""",
        """
    ABC CA B
    ADC CA D"""]
    
    n3_patchings = ["""
    ABC AB C
    ABD AB D""",
        """
    ABC BA C
    ABD BA D"""]
    
    n4_patchings = ["""
    ABC AC B
    ABC BC A""",
        """
    ABC AB C
    ABC CB A""",
        """
    ABC BA C
    ABC CA B"""]
    
    n5_patchings = ["""
    ABC CA B
    ABC CB A
    """,
        """
    ABC BA C
    ABC BC A""",
        """
    ABC AB C
    ABC AC B"""]
    
    patchings = {
        'n1': n1_patchings,
        'n2': n2_patchings,
        'n3': n3_patchings,
        'n4': n4_patchings,
        'n5': n5_patchings
    }
    
    all_patchings = []
    for patching_set in patchings.values():
        all_patchings += patching_set
    all_patchings = sorted(all_patchings) # make deterministic 
    
    patch_all_names = ["""
    ABC AB C
    DEF DE F""",
        """
    ABC AC B
    DEF DF E""",     
        """
    ABC BA C
    DEF ED F""",
        """
    ABC BC A
    DEF EF D""",
        """
    ABC CA B
    DEF FD E""",
        """
    ABC CB A
    DEF FE D"""]
    
    
    patchings['all'] = all_patchings
    patchings['allatonce'] = patch_all_names
    from acdc.data.ioi import BABA_TEMPLATES, ABC_TEMPLATES
    from acdc.data.ioi import ioi_data_generator, ABC_TEMPLATES, get_all_single_name_abc_patching_formats
    from acdc.data.utils import generate_dataset
    templates = ABC_TEMPLATES
    #patching_formats = list(get_all_single_name_abc_patching_formats())
    global PATCHING_FORMAT_I
    global patching_formats
    PATCHING_FORMAT_I = patching
    patching_formats = ["\n".join([line.strip() for line in x.split("\n")]).strip() for x in patchings[PATCHING_FORMAT_I]]
    
    print("using patching format")
    for patch in patching_formats:
        print(patch)
        print("")
    #print(patching_formats)
    
    
    data = generate_dataset(model=model,
                      data_generator=ioi_data_generator,
                      num_patching_pairs=4,
                      seed=seed,
                      valid_seed=valid_seed,
                      constrain_to_answers=constrain_to_answers,
                      has_symmetric_patching=has_symmetric_patching, 
                      varying_data_lengths=True,
                      templates=templates,
                      patching_formats=patching_formats)
    
    
    import acdc.data.ioi
    from collections import defaultdict
    name_positions_map = defaultdict(lambda: [])
    for template in templates:
        name = acdc.data.ioi.good_names[0]
        template_filled_in = template.replace("[NAME]", name)
        template_filled_in = template_filled_in.replace("[PLACE]", acdc.data.ioi.good_nouns['[PLACE]'][0])
        template_filled_in = template_filled_in.replace("[OBJECT]", acdc.data.ioi.good_nouns['[OBJECT]'][0])
        # get the token positions of the [NAME] in the prompt
        name_positions = tuple([(i) for (i,s) in enumerate(model.to_str_tokens(torch.tensor(model.tokenizer.encode(template_filled_in)))) if s == f' {name}'])
        name_positions_map[name_positions].append(template)
    sorted_by_frequency = sorted(list(name_positions_map.items()), key=lambda x: -len(x[1]))
    most_frequent_name_positions, templates = sorted_by_frequency[0]
    print("using templates")
    templates = [templates[0]]
    for template in templates:
        print(template)
    print(f"with name positions {most_frequent_name_positions}")
    import acdc.data.ioi
    if 'Jesus' in acdc.data.ioi.good_names:
        print("removed jesus")
        acdc.data.ioi.good_names.remove("Jesus")
    data = generate_dataset(model=model,
                  data_generator=ioi_data_generator,
                  num_patching_pairs=num_patching_pairs,
                  seed=seed,
                  valid_seed=valid_seed,
                  constrain_to_answers=constrain_to_answers,
                  has_symmetric_patching=has_symmetric_patching, 
                  varying_data_lengths=True,
                  templates=templates,
                  patching_formats=patching_formats)
    
    print(model.to_str_tokens(data.data[0]))
    print(model.to_str_tokens(data.data[1]))
    return data


from transformer_lens.hook_points import HookPoint
# we do a hacky thing where this first hook clears the global storage
# second hook stores all the hooks
# then third hook computes the output (over all the hooks)
# this avoids recomputing and so is much faster
SAE_HOOKS = "sae hooks"
SAE_BATCHES = "sae batches"
SAE_OUTPUT = "sae output"
def sae_patching_storage_hook(
    x,
    hook: HookPoint,
    sae_feature_i: int,
    dummy: bool,
    position: int,
    layer: int,
    batch_start: int,
    batch_end: int,
    **kwargs,
):
    global sae_storage
    if not SAE_HOOKS in sae_storage:
        sae_storage[SAE_HOOKS] = [] # we can't do this above because it'll be emptied again on the next batch before this is called
    sae_storage[SAE_OUTPUT] = None # clear output
    sae_storage[SAE_HOOKS].append({"position": position, "sae_feature_i": sae_feature_i, "dummy": dummy})
    #print(f"sae feature i {sae_feature_i} position {position} layer {layer}")
    return x

from jaxtyping import Float
from einops import rearrange

global sae_storage
sae_storage = {}
def sae_patching_hook(
    x: Float[torch.Tensor, "B L E"],
    hook: HookPoint,
    input_hook_name: str,
    layer: int,
    **kwargs,
) -> Float[torch.Tensor, "B L E"]:
    global sae_storage
    ### This is identical to what the conv is doing
    # but we break it apart so we can patch on individual filters
    # we have two input hooks, the second one is the one we want
    input_hook_name = input_hook_name[1]
    # don't recompute these if we don't need to
    # because we stored all the hooks and batches in conv_storage, we can just do them all at once
    
    # they need to share an output because they write to the same output tensor
    if sae_storage[SAE_OUTPUT] is None:
        #print(f"running for layer {layer}")
        K = saes[layer].cfg.k
        sae = saes[layer]
        #print(f"layer {layer} storage {sae_storage}")
        sae_output = torch.zeros(x.size(), device=model.cfg.device)
        #print("layer", layer, "keys", conv_storage)
        def get_filter_key(i):
            return f'filter_{i}'
        sae_input_uncorrupted = x[::2]
        sae_input_corrupted = x[1::2]
        B, L, D = sae_input_uncorrupted.size()
        for l in range(L):
            # [B, NFeatures]                             [B,D]
            uncorrupted_features = sae.encode(sae_input_uncorrupted[:,l])
            # [B, NFeatures]                             [B,D]
            corrupted_features = sae.encode(sae_input_corrupted[:,l])
            patched_features = corrupted_features.clone()
            #patched_features = torch.zeros(corrupted_features.size(), device=model.cfg.device) # patch everything except the features we are keeping around
            # apply hooks (one hook applies to a single feature)
            #print(f"{len(sae_storage[SAE_HOOKS])} hooks")
            for hook_data in sae_storage[SAE_HOOKS]:
                position = hook_data['position']
                sae_feature_i = hook_data['sae_feature_i']
                dummy = hook_data['dummy']
                if not dummy and (position == l or position is None): # position is None means all positions
                    if copy_from_other:
                        patched_features[:,sae_feature_i] = corrupted_features[:,sae_feature_i]
                    else:
                        patched_features[:,sae_feature_i] = uncorrupted_features[:,sae_feature_i]
                    
                    #print(f"applying sae feature {sae_feature_i} to position {position} for layer {layer}")
                    #uncorrupted_features[:,sae_feature_i] = corrupted_features[:,sae_feature_i]
            # compute sae outputs
            patched_top_acts, patched_top_indices = patched_features.topk(K, sorted=False)
            corrupted_top_acts, corrupted_top_indices = corrupted_features.topk(K, sorted=False)      
            sae_output[::2,l] = sae.decode(patched_top_acts, patched_top_indices)     
            sae_output[1::2,l] = sae.decode(corrupted_top_acts, corrupted_top_indices)
        sae_storage = {} # clean up and prepare for next layer
        sae_storage[SAE_OUTPUT] = sae_output # store the output
    return sae_storage[SAE_OUTPUT]

import pickle
with open("cached_sae_feature_edges.pkl", "rb") as f:
    edges_to_keep = pickle.load(f)

  return self.fget.__get__(instance, owner)()
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Moving model to device:  cuda
1
/home/dev/sae-k-sparse-mamba/0.0001414213562373095 initialTry190.txtblocks.1.hook_out_proj/hook_blocks.1.hook_out_proj.pt
2
/home/dev/sae-k-sparse-mamba/0.0001414213562373095 initialTry191.txtblocks.2.hook_out_proj/hook_blocks.2.hook_out_proj.pt
3
/home/dev/sae-k-sparse-mamba/0.0001414213562373095 initialTry192.txtblocks.3.hook_out_proj/hook_blocks.3.hook_out_proj.pt
4
/home/dev/sae-k-sparse-mamba/0.0001414213562373095 initialTry193.txtblocks.4.hook_out_proj/hook_blocks.4.hook_out_proj.pt
5
/home/dev/sae-k-sparse-mamba/0.0001414213562373095 initialTry194.txtblocks.5.hook_out_proj/hook_blocks.5.hook_out_proj.pt
6
/home/dev/sae-k-sparse-mamba/0.0001414213562373095 initialTry195.txtblocks.6.hook_out_proj/hook_blocks.6.hook_out_proj.pt
7
/home/dev/sae-k-sparse-mamba/0.0001414213562373095 initialTry196.txtblocks.7.hook_out_proj/hook_blocks.7.hook_out_proj.pt
8
/home/dev/sae-k-sparse-mamba/0.0001414213562373095 initialTry200.txtblocks.8.hook_out_proj/hook_bloc

In [4]:
data = make_data(num_patching_pairs=20000, patching="all", template_i=0, seed=24, valid_seed=23)

using patching format
ABC AB C
ABC AC B

ABC AB C
ABC CB A

ABC AB C
ABD AB D

ABC AC B
ABC BC A

ABC AC B
ADC AC D

ABC BA C
ABC BC A

ABC BA C
ABC CA B

ABC BA C
ABD BA D

ABC BC A
DBC BC D

ABC CA B
ABC CB A

ABC CA B
ADC CA D

ABC CB A
DBC CB D

using templates
Then, [NAME], [NAME] and [NAME] went to the [PLACE]. [NAME] and [NAME] gave a [OBJECT] to
with name positions (2, 4, 6, 12, 14)
['<|endoftext|>', 'Then', ',', ' Olivia', ',', ' Ian', ' and', ' Aaron', ' went', ' to', ' the', ' restaurant', '.', ' Aaron', ' and', ' Olivia', ' gave', ' a', ' computer', ' to']
['<|endoftext|>', 'Then', ',', ' Olivia', ',', ' Ian', ' and', ' Aaron', ' went', ' to', ' the', ' restaurant', '.', ' Aaron', ' and', ' Ian', ' gave', ' a', ' computer', ' to']


In [11]:

toks = model.to_str_tokens(data.data[0])
name_positions = [3,5,7,13,15]
position_map = {}
L = data.data.size()[1]
for l in range(L):
    position_map[l] = f'pos{l}{toks[l]}'
position_map[3] = 'n1'
position_map[5] = 'n2'
position_map[7] = 'n3'
position_map[13] = 'n4'
position_map[15] = 'n5'
position_map[19] = 'out'


print(len(edges_to_keep))
sae_edges = defaultdict(lambda: defaultdict(lambda: []))
counts = defaultdict(lambda: defaultdict(lambda: 0))
num_to_show = 400
iters = 0

for edge in edges_to_keep:
    if '.sae' in edge.output_node and not edge.label is None:
        # [pos:feature_i]
        label = edge.label[1:-1]
        pos, feature_i = label.split(":")
        pos = int(pos)
        if feature_i == 'KEEP': continue # dummy edge used to ensure sae always applied
        iters += 1
        feature_i = int(feature_i)
        layer = int(edge.output_node.split(".")[0])
        attr = edge.score_diff_when_patched
        sae_edges[layer][pos].append((attr, feature_i))
        counts[layer][feature_i] += 1
        #print(layer, pos, position_map[pos], feature_i, attr)
        #print(layer, pos, feature_i, attr)
        #if iters > num_to_show: break


8487


In [12]:
total_num_features = 0
for layer in sorted(list(sae_edges.keys())):
    print(f"layer {layer} with {len(counts[layer])} unique features ({len([x for x in counts[layer] if x > 1])} duplicated)")
    total_num_features += len(counts[layer])
    values = sae_edges[layer]
    for pos in sorted(list(values.keys())):
        print(f"  pos {position_map[pos]} num sae {len(values[pos])} min attr scaled {'{:.3f}'.format(1000*min([x[0] for x in values[pos]]))}")
print(f"total num features {total_num_features}")

layer 1 with 144 unique features (144 duplicated)
  pos n1 num sae 23 min attr scaled -16.656
  pos n2 num sae 30 min attr scaled -10.017
  pos n3 num sae 28 min attr scaled -14.836
  pos n4 num sae 65 min attr scaled -84.558
  pos n5 num sae 61 min attr scaled -152.670
layer 2 with 17 unique features (17 duplicated)
  pos n1 num sae 4 min attr scaled -18.957
  pos n2 num sae 1 min attr scaled -19.427
  pos n3 num sae 6 min attr scaled -11.387
  pos n4 num sae 6 min attr scaled -29.547
  pos n5 num sae 8 min attr scaled -28.615
layer 3 with 213 unique features (213 duplicated)
  pos n1 num sae 35 min attr scaled -16.394
  pos n2 num sae 77 min attr scaled -46.406
  pos pos6 and num sae 1 min attr scaled -6.787
  pos n3 num sae 83 min attr scaled -27.689
  pos pos11 garden num sae 13 min attr scaled -20.751
  pos pos12. num sae 7 min attr scaled -12.354
  pos n4 num sae 84 min attr scaled -44.882
  pos n5 num sae 71 min attr scaled -77.533
layer 4 with 22 unique features (22 duplicated)

In [144]:
from dataclasses import dataclass, field

def get_batched_index_into(indices):
    '''
    given data that is [B,N,V] and indicies that are [B,N,K] with each index being an index into the V space
    this gives you indexes you can use to access your values
    '''
    first_axis = []
    second_axis = []
    third_axis = []
    B, _, _ = indices.size()
    for b in range(B):
        second, third = get_index_into(indices[b])
        first_axis.append(torch.full(second.size(), fill_value=b, device=model.cfg.device))
        second_axis.append(second)
        third_axis.append(third)

    return torch.cat(first_axis), torch.cat(second_axis), torch.cat(third_axis)

def get_index_into(indices):
    '''
    given data that is [N,V] and indicies that are [N,K] with each index being an index into the V space
    this gives you indexes you can use to access your values
    '''
    num_data, num_per_data = indices.size()
    # we want
    # [0,0,0,...,] num per data of these
    # [1,1,1,...,] num per data of these
    # ...
    # [num_data-1, num_data-1, ...]
    first_axis_index = torch.arange(num_data, dtype=torch.long).view(num_data, 1)*torch.ones([num_data, num_per_data], dtype=torch.long)
    # now we flatten it so it has an index for each term aligned with our indices
    first_axis_index = first_axis_index.flatten()
    second_axis_index = indices.flatten()
    return first_axis_index, second_axis_index
global buffer
buffer = None
global features_by_layer
def sae_hook(
    x,
    hook,
    layer,
):
    # s is [B,L,E]
    K = saes[layer].cfg.k
    sae = saes[layer]
    B,L,D = x.size()
    uncorrupted_features = sae.encode(x)
    top_acts, top_indices = uncorrupted_features.topk(K, sorted=False)
    global buffer
    if buffer is None:
        buffer = torch.zeros(uncorrupted_features.size(), device=model.cfg.device)
    buffer[:] = 0
    
    # zero everything except the top k
    buffer[get_batched_index_into(top_indices)] = top_acts.flatten()
    for feature in features_by_layer[layer]:
        feature.records += [x.item() for x in buffer[:,feature.pos,feature.feature_i]]
    # kernel can't handle doing all token positions at same time by default
    # but if we make it think B*L is a single batch index it works fine
    top_acts_flattened = top_acts.flatten(start_dim=0, end_dim=1)
    top_indices_flattened = top_indices.flatten(start_dim=0, end_dim=1)
    sae_out = sae.decode(top_acts_flattened, top_indices_flattened)
    sae_out = sae_out.unflatten(dim=0, sizes=(B,L))
    return sae_out



@dataclass
class SAEFeature:
    """Class for keeping track of an item in inventory."""
    layer: int
    pos: int
    feature_i: int
    attr: float
    records: list = field(default_factory=lambda: [])

    def __repr__(self):
        return str(self.layer) + " " + str(self.pos) + " " + str(self.feature_i) + " " + str(self.attr)

    def __str__(self):
        return self.__repr__()

def parse_feature(feature_str):
    layer, pos, pos_name, feature_i, attr = feature_str.split()
    layer = int(layer)
    pos = int(pos)
    feature_i = int(feature_i)
    attr = float(attr)
    return SAEFeature(layer=layer, pos=pos, feature_i=feature_i, attr=attr)
features = """
15 13 n4 15921 -1.4511811931733973
15 13 n4 11839 -1.1555863783723908
20 15 n5 27256 -1.0338027551188134
20 13 n4 23228 -0.9890130006533582
20 13 n4 2724 -0.9646386316744611
15 13 n4 25771 -0.8934519871108932
15 15 n5 26824 -0.8928024366832688
20 13 n4 23731 -0.844180965796113
15 15 n5 11839 -0.8429616270004772
15 15 n5 27758 -0.8418691491242498
14 15 n5 13971 -0.7999667014519218
15 13 n4 17259 -0.7739449571818113
20 15 n5 2724 -0.737606625945773
15 15 n5 31021 -0.6951032060023863
15 13 n4 7440 -0.6218199331851793
15 15 n5 8113 -0.6123286176807596
15 13 n4 8113 -0.6059157950221561
20 15 n5 25369 -0.5720785861194599
15 13 n4 6146 -0.5615684384829365
20 13 n4 25369 -0.5487742855912074
15 13 n4 26824 -0.5487118880992057
15 13 n4 31021 -0.5430846622330137
19 5 n2 30561 -0.5199790453916648
20 15 n5 23228 -0.5164780604463886
15 13 n4 28222 -0.5075037965434603
14 13 n4 13971 -0.4765838644379983
14 15 n5 32567 -0.4751656675944105
11 13 n4 22965 -0.46325298480223864
15 15 n5 6146 -0.45198706406517886
19 15 n5 30740 -0.43911108660540776
20 15 n5 29653 -0.43348417259403504
15 15 n5 8935 -0.4290462978715368
20 15 n5 8196 -0.40242351567940204
15 15 n5 16138 -0.4004057846032083
19 3 n1 30561 -0.3884095794055611
20 15 n5 1899 -0.3876068123790901
20 5 n2 27256 -0.3814757227519294
15 13 n4 2344 -0.38085291545576183
19 15 n5 30561 -0.37879287756368285
15 15 n5 22790 -0.37808844797109487
15 15 n5 2344 -0.3775725119630806
15 13 n4 12167 -0.3703001577523537
21 15 n5 7554 -0.34814969086801284
21 13 n4 7554 -0.34735435343463905
14 13 n4 32567 -0.3427205775587936
15 15 n5 29892 -0.33304538365337066
20 15 n5 23731 -0.33183777012163773
15 15 n5 8649 -0.3311291775389691
15 13 n4 2380 -0.32058422826230526
15 15 n5 1349 -0.317370749762631
15 15 n5 28979 -0.3154674572579097
15 13 n4 30976 -0.3153745725285262
20 13 n4 17612 -0.31191760893852916
15 13 n4 22801 -0.3013547840891988
12 15 n5 6008 -0.30098827754409285
12 15 n5 4851 -0.2981748393503949
20 13 n4 1899 -0.29404354887083173
14 5 n2 13971 -0.29324889490089845
15 13 n4 22790 -0.28884116747940425
15 15 n5 17259 -0.2834519026146154
19 5 n2 9076 -0.28338390760472976
20 5 n2 23228 -0.2785688600561116
15 13 n4 29892 -0.2779534184228396
20 15 n5 24925 -0.27310544914143975
15 13 n4 32240 -0.2724338702391833
19 7 n3 30561 -0.27105611146544106
20 13 n4 29653 -0.26252533006481826
19 13 n4 30740 -0.26012522503879154
20 5 n2 25369 -0.25853115250356495
11 13 n4 19600 -0.2510134789190488
20 15 n5 6758 -0.24530868871806888
15 7 n3 8935 -0.24060688240570016
11 13 n4 18719 -0.23991555671091191
15 13 n4 28979 -0.23681864549871534
20 15 n5 17612 -0.23483685902829166
15 3 n1 25771 -0.22723528533242643
15 15 n5 17920 -0.22721617101342417
20 15 n5 6986 -0.22323725407477468
14 15 n5 28831 -0.22306419239612296
15 15 n5 22801 -0.22277752275113016
20 13 n4 24925 -0.22063778418123547
15 3 n1 8113 -0.21664783913001884
11 15 n5 19600 -0.21422454587991524
16 15 n5 19800 -0.21313862562237773
20 13 n4 27256 -0.2103884415628272
15 7 n3 26824 -0.208646113776922
14 5 n2 32567 -0.20791079929767875
20 15 n5 10083 -0.20755451168952277
11 15 n5 18719 -0.20682069531903835
15 15 n5 30976 -0.20519282953318907
15 15 n5 10252 -0.19859311032632831
20 5 n2 8455 -0.19858658533848939
20 13 n4 3156 -0.18876934882428031
20 5 n2 2724 -0.1882108402205631
20 13 n4 6986 -0.18641833098081406
20 3 n1 25369 -0.18377804876945447
20 13 n4 1672 -0.18313410242262762
20 7 n3 17612 -0.17767891606126796
14 3 n1 13971 -0.177633166378655
20 13 n4 31901 -0.17577282109414227
15 5 n2 27758 -0.1747781486083113
15 7 n3 2380 -0.174283399428532
15 15 n5 9746 -0.1737501583957055
20 3 n1 29653 -0.17337747645069612
15 5 n2 26824 -0.16980820387470885
15 13 n4 1349 -0.16958206321578473
15 3 n1 2344 -0.16653591250360478
20 13 n4 15013 -0.16628824583312962
19 15 n5 27888 -0.16485086461761966
14 13 n4 28831 -0.16269497356734064
15 15 n5 25903 -0.15941790843498893
20 15 n5 1672 -0.15898824847317883
21 7 n3 6419 -0.15865897008188767
15 15 n5 3888 -0.1581162586571736
20 7 n3 8455 -0.1577744372516463
19 13 n4 9076 -0.15575948629702907
15 15 n5 25771 -0.15522614054498263
1 15 n5 25764 -0.15266994523699395
15 15 n5 12167 -0.15147657443594653
15 15 n5 15921 -0.15116392161144176
15 3 n1 31021 -0.1500888324371772
20 3 n1 23228 -0.14868561428011162
11 5 n2 19600 -0.147709660937835
15 13 n4 25903 -0.1476323436727398
15 15 n5 26556 -0.1475745009338425
14 7 n3 13971 -0.1464709811261855
20 13 n4 10083 -0.14522044641853427
21 5 n2 7554 -0.14491799040115438
11 7 n3 19600 -0.1438251133004087
20 7 n3 27256 -0.14189712883853645
15 3 n1 28222 -0.14156844822718995
15 7 n3 27758 -0.13967746320849983
20 13 n4 6758 -0.1395779878639587
16 5 n2 8413 -0.1384436698135687
15 15 n5 15762 -0.1374775571275677
15 5 n2 25771 -0.1365963689131604
14 7 n3 28831 -0.13502024069566687
18 13 n4 24113 -0.13279320938272576
20 3 n1 23731 -0.13276870549452724
20 3 n1 24925 -0.12945301840591128
15 5 n2 11839 -0.12674832851189421
20 15 n5 1336 -0.12278529405011795
20 3 n1 1899 -0.12241411584545858
20 7 n3 6758 -0.12176360644843953
20 15 n5 21539 -0.1217278007643472
15 13 n4 9187 -0.11823179898783565
20 15 n5 13117 -0.11767573590623215
15 7 n3 31021 -0.11718380186539434
11 7 n3 18719 -0.11692520580254495
10 7 n3 11071 -0.11588574545021402""".strip()
features = [parse_feature(line.strip()) for line in features.split("\n")]
features = []

for pos, feats in sae_edges[15].items():
    features += [SAEFeature(layer=15, pos=pos, feature_i=feature_i, attr=attr) for (attr, feature_i) in feats]
from functools import partial
from tqdm import tqdm
def forward_check_features(data, features, batch_size):
    
    global features_by_layer

    features_by_layer = defaultdict(lambda: [])
    for feature in features:
        feature.records = []
        features_by_layer[feature.layer].append(feature)

    # only bother with SAE on the layers we are checking
    layers_to_apply_sae = sorted(list(features_by_layer.keys()))
    hooks = [(f'blocks.{layer}.hook_out_proj', partial(sae_hook, layer=layer)) for layer in layers_to_apply_sae]

    DATA_LEN = data.size()[0]
    for batch_start in tqdm(list(range(0, DATA_LEN, batch_size))):
        batch_end = min(DATA_LEN, batch_start+batch_size)
        data_batch = data[batch_start:batch_end]
        _ = model.run_with_hooks(input=data_batch, fwd_hooks=hooks, fast_ssm=True, fast_conv=True)
    

import acdc.data.ioi

with open("/home/dev/mamba_interp/MoreNames.txt", "r") as f:
    all_names = [x.strip() for x in f.read().split("\n") if len(x.strip()) > 0]
    # regenerate names, but more
    acdc.data.ioi.NAMES = sorted(list(set(all_names)))
    acdc.data.ioi.good_names = None 

In [None]:

data = make_data(num_patching_pairs=50000, patching="all", template_i=0, seed=24, valid_seed=23)
forward_check_features(data=data.data, features=features, batch_size=200)

In [17]:
with open("features_layer_15_with_collected_data.pkl", "wb") as f:
    pickle.dump(features, f)

In [3]:
with open("features_all_tokens_layer_15_with_collected_data.pkl", "rb") as f:
    features = pickle.load(f)
features = modifiedFeatures
data.data = data_for_all_tokens
global template_to_i
template_to_i = {}

global all_templates
all_templates = []

def extract_template(data_point):
    a,b,c,d,e = data_point[3], data_point[5], data_point[7], data_point[13], data_point[15]
    lookup = {}
    template = ""
    order = 'ABCDEF'
    order_ind = 0
    for name in [a,b,c,d,e]:
        if not name in lookup:
            lookup[name] = order[order_ind]
            order_ind += 1
        template += lookup[name]
    global all_templates
    
    if not template in template_to_i.keys():
        all_templates.append(template)
        template_to_i[template] = len(all_templates)-1
    return template_to_i[template]
'''
from tqdm import tqdm
data_to_template = []
for d in tqdm(range(data.data.size()[0])):
    data_to_template.append(extract_template(data.data[d]))
'''
def get_name_counts(feature):
    name_counts = {}
    DATA_LEN = len(feature.records)
    records_tensor = torch.tensor(feature.records)
    non_zero_indices = torch.arange(DATA_LEN)[records_tensor!=0]
    non_zero_tokens = data.data[non_zero_indices,feature.pos].cpu()
    non_zero_records = records_tensor[non_zero_indices]
    name_tokens = torch.unique(non_zero_tokens)
    for name_token in name_tokens:
        name_str = model.to_str_tokens(name_token.view(1,1))[0]
        name_counts[name_str] = non_zero_records[non_zero_tokens==name_token.item()]
    #for t,c in template_counts.items():
    #    print(f" template {t} with count {torch.mean(torch.tensor(c)).item()}")
    name_counts = sorted(list(name_counts.items()), key=lambda x: -torch.mean(x[1]).item())
    return name_counts
    #for n,c in name_counts[:100]:
    #    print(f" name {n} with avg {torch.mean(c).item()} min {torch.min(c).item()} max {torch.max(c).item()}")

#names = sorted(list(acdc.data.ioi.good_names))
names = [x for (i,x) in spaceThings if len(x.strip()) > 0]
NUM_NAMES = len(names)
#name_to_i = dict([(" " + name, i) for (i, name) in enumerate(names)])
name_to_i = dict([(name, i) for (i, name) in enumerate(names)])

def get_name_vector(feature, feat_type):
    name_vec = torch.zeros(NUM_NAMES)
    for name,counts in get_name_counts(feature):
        if feat_type == 'mean':
            name_vec[name_to_i[name]] = torch.mean(counts)
        elif feat_type == 'min':
            name_vec[name_to_i[name]] = torch.min(counts)
        elif feat_type == 'max':
            name_vec[name_to_i[name]] = torch.max(counts)
            
    return name_vec

#vecs = []
#for feature in features:
#    get_name_vector(feature)
    

features_sorted_by_feat_i = defaultdict(lambda: [])
for feature in features:
    features_sorted_by_feat_i[feature.feature_i].append(feature)
            

AttributeError: Can't get attribute 'SAEFeature' on <module '__main__'>

In [167]:

first_letter_to_name = defaultdict(lambda: [])
for (name_i, name) in enumerate(names):
    if len(name.strip()) > 0:
        first_letter_to_name[name.strip()[0].lower()].append((name, name_i))
#for letter in 'abcdefghijklmnopqrstuvwxyz':
#    first_letter_to_name[letter] = [(name, name_i) for (name_i, name) in enumerate(names) if len(name.strip().lower()) > 0 and name.strip().lower()[0] == letter]

contains_letter_to_name = {}
for letter in 'abcdefghijklmnopqrstuvwxyz':
    contains_letter_to_name[letter] = [(name, name_i) for (name_i, name) in enumerate(names) if letter in name.strip().lower()]

def detect_single_letter(feature):
    name_vec = get_name_vector(feature, 'mean')
    top_inds = torch.argsort(-name_vec)
    top_letter = names[top_inds[0]].strip()[0].lower()
    num_matched = 0
    for ind in top_inds[:200]:
        if top_letter != names[ind].strip()[0].lower():
            break
        else:
            num_matched += 1
    #print("num matched", num_matched, "num names of that", len(first_letter_to_name[top_letter]))
    return top_letter, num_matched / float(len(first_letter_to_name[top_letter])), len(first_letter_to_name[top_letter])


def detect_contains_letter(feature):
    name_vec = get_name_vector(feature, 'mean')
    top_inds = torch.argsort(-name_vec)
    top_letter = names[top_inds[0]].strip()[0].lower()
    num_matched = 0
    for ind in top_inds[:200]:
        if top_letter != names[ind].strip()[0].lower():
            break
        else:
            num_matched += 1
    #print("num matched", num_matched, "num names of that", len(first_letter_to_name[top_letter]))
    return top_letter, num_matched / float(len(first_letter_to_name[top_letter])), len(first_letter_to_name[top_letter])




# if we are the only letter that is seen 0.85 of them after seeing two letters, that's good enough
def improved_detect_first_letter(feature):
    first_letter_info = list_first_letter_info(feature)
    for i in range(1,4):
        i_good_letters = []
        for letter, freqdict in first_letter_info.items():
            val = max([freqdict[k] for k in range(1,i+1)])
            if val > 0.85:
                i_good_letters.append((letter, val))
        if len(i_good_letters) == 1:
            letter, confidance = i_good_letters[0]
            return letter, i, confidance, len(first_letter_to_name[letter])
        elif len(i_good_letters) > 1:
            return None, 0, 0.0, 0
    return None, 0, 0.0, 0

# gives a dict with [letter][ind]
# where ind is the number of distinct letters seen so far
# and the value of the dict[letter][ind] is the proportion of total
# names seen of that letter so far (for ind or less)
def list_first_letter_info(feature):
    
    name_vec = get_name_vector(feature, 'mean')
    top_inds = torch.argsort(-name_vec)
    top_letter = names[top_inds[0]].strip()[0].lower()
    frequencies = defaultdict(lambda: defaultdict(lambda: 0))
    for ind in top_inds[:500]:
        letter = names[ind].strip()[0].lower()
        num_letters = len(frequencies)
        if not letter in frequencies:
            num_letters += 1
        frequencies[letter][num_letters] += 1
        if num_letters > 3:
            break
    resultProportions = defaultdict(lambda: defaultdict(lambda: 0))
    for letter, freqdict in frequencies.items():
        nums = sorted(list(freqdict.keys()))
        total = 0
        for n in nums:
            total += freqdict[n]
            resultProportions[letter][n] = total / float(len(first_letter_to_name[letter]))
    return resultProportions


single_letter_feats = defaultdict(lambda: defaultdict(lambda: []))

from tqdm import tqdm

for feat_i, feats in tqdm(list(features_sorted_by_feat_i.items())):
    is_single_letter = True
    have_any = False
    max_num_seen = 0
    for feat in feats:
        if position_map[feat.pos][0] == 'n':
            letter, num_seen, rating, num_of_that_letter = improved_detect_first_letter(feat)
            max_num_seen = max(num_seen, max_num_seen)
            if rating < 0.85 or num_of_that_letter < 4:
                is_single_letter = False
            else:
                have_any = True # needed incase they are all non name tokens
        else:
            print(f"warning, feature {feat_i} has non name pos {position_map[feat.pos]} (all pos are {[position_map[f.pos] for f in feats]})")
    if is_single_letter and have_any:
        single_letter_feats[max_num_seen][feat_i] = feats

for num_seen in sorted(list(single_letter_feats.keys())):
    print(f"num seen {num_seen} single letter {len(single_letter_feats[num_seen])} num total {len(features_sorted_by_feat_i)}")


100%|█████████████████████████████████████████████████████████████████████████████████| 307/307 [06:20<00:00,  1.24s/it]


In [169]:

letters = defaultdict(lambda: 0)
for k in single_letter_feats.keys():
    print(f"k of {k}")
    for feat_i, feats in single_letter_feats[k].items():
        print(f"feature {feat_i}")
        letter = None
        for feat in feats:
            if position_map[feat.pos][0] == 'n':
                letter = detect_single_letter(feat)
                print(position_map[feat.pos], letter)
                letter = letter
            else:
                print(f"warning, feature {feat_i} has non name pos {position_map[feat.pos]} (all pos are {[position_map[f.pos] for f in feats]})")
        letters[letter[0]] += 1
        continue
        
        feat_vecs = [get_name_vector(feat, 'mean') for feat in feats]
        diffs = torch.zeros(len(feats), len(feats))
        for i,featv1 in enumerate(feat_vecs):
            for j,featv2 in enumerate(feat_vecs):
                diffs[i,j] = torch.mean(torch.abs(featv1-featv2))
        labels = [position_map[feat.pos] for feat in feats]
        imshow(diffs, x=labels, y=labels, font_size=9)
        
        avg_vec = torch.stack(feat_vecs).mean(dim=0)
        min_vec = torch.stack([get_name_vector(feat, 'min') for feat in feats]).min(dim=0).values
        max_vec = torch.stack([get_name_vector(feat, 'max') for feat in feats]).max(dim=0).values
        sorted_names = torch.argsort(-avg_vec)
        #print(avg_vec, min_vec, max_vec, sorted_names)
        for name_i in sorted_names[:50]:
            print(name_i)
            print(f" name {names[name_i]} with avg {avg_vec[name_i]} min {min_vec[name_i]} max {max_vec[name_i]}")

In [71]:
letters

defaultdict(<function __main__.<lambda>()>,
            {'p': 1,
             's': 2,
             't': 3,
             'l': 3,
             'v': 3,
             'n': 2,
             'a': 1,
             'b': 1,
             'i': 2,
             'r': 4,
             'w': 1,
             'z': 2})

In [None]:



def pretty_print_list_first_letter_info(info):
    for k,v in info.items():
        print(k, sorted(list(v.items())))

for feat_i, feats in features_sorted_by_feat_i.items():
    if any([position_map[feat.pos][0] != 'n' for feat in feats]):
        print(f"warning, feature {feat_i} has non name poses, all pos are {[position_map[f.pos] for f in feats]})")
        continue
    found = False
    for k in single_letter_feats.keys():
        if feat_i in single_letter_feats[k]:
            print(feat_i)
            print(position_map[feats[0].pos])
            print(f"feature {feat_i} is {k} seen, single letter feat ", detect_single_letter(feats[0]))
            found = True
            break
    if found:
        continue
        
    for feat in feats:
        if position_map[feat.pos][0] == 'n':
            print(position_map[feat.pos], detect_single_letter(feat))
            pretty_print_list_first_letter_info(list_first_letter_info(feat))
            print(improved_first_letter(feat))
    
    feat_vecs = [get_name_vector(feat, 'mean') for feat in feats]
    diffs = torch.zeros(len(feats), len(feats))
    for i,featv1 in enumerate(feat_vecs):
        for j,featv2 in enumerate(feat_vecs):
            diffs[i,j] = torch.mean(torch.abs(featv1-featv2))
    labels = [position_map[feat.pos] for feat in feats]
    imshow(diffs, x=labels, y=labels, font_size=9)
    
    avg_vec = torch.stack(feat_vecs).mean(dim=0)
    min_vec = torch.stack([get_name_vector(feat, 'min') for feat in feats]).min(dim=0).values
    max_vec = torch.stack([get_name_vector(feat, 'max') for feat in feats]).max(dim=0).values
    sorted_names = torch.argsort(-avg_vec)
    #print(avg_vec, min_vec, max_vec, sorted_names)
    for name_i in sorted_names[:50]:
        print(name_i)
        print(f" name {names[name_i]} with avg {avg_vec[name_i]} min {min_vec[name_i]} max {max_vec[name_i]}")

(2416, ' writ')

In [120]:
model.tokenizer.decode(torch.tensor([2416]).view(1))

' writ'

In [None]:
h = model.to_str_tokens(torch.arange(model.tokenizer.vocab_size))
spaceThings = [(i, x) for (i, x) in enumerate(h) if x[0] == ' ' and len(x.strip()) > 0]
modifiedFeatures = []
for feature in features:
    modifiedFeatures.append(SAEFeature(layer=feature.layer, pos=3, feature_i=feature.feature_i, attr=feature.attr))
prefix = data.data[0][:3].view(1,-1)
new_data_toks = torch.tensor([tok for (tok,s) in spaceThings], device=model.cfg.device)
data_for_all_tokens = torch.cat([prefix.repeat((len(new_data_toks),1)), new_data_toks.view(-1,1)], dim=1)

In [165]:
'''
['<|endoftext|>',
 'Then',
 ',',
'''

forward_check_features(data=data_for_all_tokens, features=modifiedFeatures, batch_size=200)

100%|█████████████████████████████████████████████████████████████████████████████████| 144/144 [03:53<00:00,  1.62s/it]


In [6]:
from dataclasses import dataclass, field
@dataclass
class SAEFeature:
    """Class for keeping track of an item in inventory."""
    layer: int
    pos: int
    feature_i: int
    attr: float
    records: list = field(default_factory=lambda: [])

    def __repr__(self):
        return str(self.layer) + " " + str(self.pos) + " " + str(self.feature_i) + " " + str(self.attr)

    def __str__(self):
        return self.__repr__()
from tqdm import tqdm
data = make_data(num_patching_pairs=2, patching="all", template_i=0, seed=24, valid_seed=23)
h = model.to_str_tokens(torch.arange(model.tokenizer.vocab_size))
#spaceThings = [(i, x) for (i, x) in enumerate(h) if x[0] == ' ' and len(x.strip()) > 0]
spaceThings = [(i, x) for (i, x) in enumerate(h) if len(x.strip()) > 0]
prefix = data.data[0][:1].view(1,-1)
new_data_toks = torch.tensor([tok for (tok,s) in spaceThings], device=model.cfg.device)
data_for_all_tokens = torch.cat([prefix.repeat((len(new_data_toks),1)), new_data_toks.view(-1,1)], dim=1)
data.data = data_for_all_tokens
with open("features_all_tokens_layer_15_with_collected_data.pkl", "rb") as f:
    features = pickle.load(f)

DATA_LEN = data.data.size()[0]

TOP_K = 40
batch_size = 200
new_data = []
for batch_start in tqdm(list(range(0, DATA_LEN, batch_size))):
    batch_end = min(DATA_LEN, batch_start+batch_size)
    data_batch = data.data[batch_start:batch_end]
    # [B,L,V]
    logits = model(input=data_batch, fast_ssm=True, fast_conv=True)
    inds = torch.argsort(-logits[:,-1,:], dim=1)
    prs = torch.softmax(logits[:,-1,:], dim=1)
    for b in range(batch_end-batch_start):
        for t in range(TOP_K):
            pr = prs[b,inds[b,t]]
            if pr > 0.05:
                new_data.append(torch.concatenate([data_batch[b],inds[b,t:t+1]]))
    # [B,20]


using patching format
ABC AB C
ABC AC B

ABC AB C
ABC CB A

ABC AB C
ABD AB D

ABC AC B
ABC BC A

ABC AC B
ADC AC D

ABC BA C
ABC BC A

ABC BA C
ABC CA B

ABC BA C
ABD BA D

ABC BC A
DBC BC D

ABC CA B
ABC CB A

ABC CA B
ADC CA D

ABC CB A
DBC CB D

using templates
Then, [NAME], [NAME] and [NAME] went to the [PLACE]. [NAME] and [NAME] gave a [OBJECT] to
with name positions (2, 4, 6, 12, 14)
['<|endoftext|>', 'Then', ',', ' Olivia', ',', ' Ian', ' and', ' Aaron', ' went', ' to', ' the', ' restaurant', '.', ' Aaron', ' and', ' Olivia', ' gave', ' a', ' computer', ' to']
['<|endoftext|>', 'Then', ',', ' Olivia', ',', ' Ian', ' and', ' Aaron', ' went', ' to', ' the', ' restaurant', '.', ' Aaron', ' and', ' Ian', ' gave', ' a', ' computer', ' to']


100%|█████████████████████████████████████████████████████████████████████████████████| 250/250 [03:05<00:00,  1.35it/s]


In [7]:
data_for_all_tokens = torch.stack(new_data)
print(data_for_all_tokens.size())

'''
data_for_all_tokens = torch.cat([prefix.repeat((len(new_data_toks),1)), new_data_toks.view(-1,1)], dim=1)
data_for_all_tokens = data_for_all_tokens.repeat((TOP_K, 1))
next_tok = torch.concatenate(next_tokens).view(-1,1)
print(next_tok[:20])
print(next_tokens[0][0])
data_for_top_k_tokens = torch.cat([data_for_all_tokens, next_tok], dim=1)
print(data_for_top_k_tokens.size())
'''
for i in range(10):
    import random
    ind = random.randint(0,data_for_all_tokens.size()[0])
    print(model.to_str_tokens(data_for_all_tokens[ind]))



DATA_LEN = data_for_all_tokens.size()[0]
TOP_K = 20
batch_size = 200
new_data = []
for batch_start in tqdm(list(range(0, DATA_LEN, batch_size))):
    batch_end = min(DATA_LEN, batch_start+batch_size)
    data_batch = data_for_all_tokens[batch_start:batch_end]
    # [B,L,V]
    logits = model(input=data_batch, fast_ssm=True, fast_conv=True)
    inds = torch.argsort(-logits[:,-1,:], dim=1)
    prs = torch.softmax(logits[:,-1,:], dim=1)
    for b in range(batch_end-batch_start):
        for t in range(TOP_K):
            pr = prs[b,inds[b,t]]
            if pr > 0.05:
                new_data.append(torch.concatenate([data_batch[b],inds[b,t:t+1]]))

data_for_all_tokens2 = torch.stack(new_data)
print(data_for_all_tokens2.size())

for i in range(10):
    import random
    ind = random.randint(0,data_for_all_tokens2.size()[0])
    print(model.to_str_tokens(data_for_all_tokens2[ind]))

DATA_LEN = data_for_all_tokens2.size()[0]
TOP_K = 20
batch_size = 200
new_data = []
for batch_start in tqdm(list(range(0, DATA_LEN, batch_size))):
    batch_end = min(DATA_LEN, batch_start+batch_size)
    data_batch = data_for_all_tokens2[batch_start:batch_end]
    # [B,L,V]
    logits = model(input=data_batch, fast_ssm=True, fast_conv=True)
    inds = torch.argsort(-logits[:,-1,:], dim=1)
    prs = torch.softmax(logits[:,-1,:], dim=1)
    for b in range(batch_end-batch_start):
        for t in range(TOP_K):
            pr = prs[b,inds[b,t]]
            if pr > 0.05:
                new_data.append(torch.concatenate([data_batch[b],inds[b,t:t+1]]))

data_for_all_tokens3 = torch.stack(new_data)
print(data_for_all_tokens3.size())

for i in range(100):
    import random
    ind = random.randint(0,data_for_all_tokens2.size()[0])
    print(model.to_str_tokens(data_for_all_tokens2[ind]))



torch.Size([112552, 3])
['<|endoftext|>', ' antagon', 'ism']
['<|endoftext|>', ' citiz', 'in']
['<|endoftext|>', ' stalled', '\n']
['<|endoftext|>', ' ammonium', ' nitrate']
['<|endoftext|>', 'ouble', 'ts']
['<|endoftext|>', ' Waste', 'water']
['<|endoftext|>', ' appeal', '\n']
['<|endoftext|>', 'Paper', 'back']
['<|endoftext|>', 'ividual', ' and']
['<|endoftext|>', 'ulfide', '\n']


100%|█████████████████████████████████████████████████████████████████████████████████| 563/563 [05:47<00:00,  1.62it/s]


torch.Size([202620, 4])
['<|endoftext|>', 'ィ', 'ィ', 'ア']
['<|endoftext|>', ' съ', 'ем', 'оч']
['<|endoftext|>', ' file', '1', '.']
['<|endoftext|>', ' 297', '*', 'w']
['<|endoftext|>', ' spo', 'ilt', ' by']
['<|endoftext|>', '}}}(', '1', ')']
['<|endoftext|>', 'Browser', '-', 'based']
['<|endoftext|>', ' recurring', '?', '  ']
['<|endoftext|>', ' пол', 'ож', 'итель']
['<|endoftext|>', ' Lots', ' of', ' people']


100%|███████████████████████████████████████████████████████████████████████████████| 1014/1014 [10:34<00:00,  1.60it/s]


torch.Size([389372, 5])
['<|endoftext|>', 'Were', ' there', ' any']
['<|endoftext|>', ' in', ' -', '4']
['<|endoftext|>', ' decor', '.', 'css']
['<|endoftext|>', 'zag', 're', 'ba']
['<|endoftext|>', ' BUSINESS', ' RE', 'PORT']
['<|endoftext|>', ' Feld', 'k', 'irc']
['<|endoftext|>', ' Eigen', 'values', ' and']
['<|endoftext|>', ' van', ' der', ' Wa']
['<|endoftext|>', 'Tg', 'A', 'i']
['<|endoftext|>', ' Ref', 'erral', ' of']
['<|endoftext|>', ' semiconductor', ' devices', ' such']
['<|endoftext|>', ' assignment', ':', '\n']
['<|endoftext|>', ' Mend', 'elian', ' inheritance']
['<|endoftext|>', 'bf', 'c', '_']
['<|endoftext|>', '---------------------------------------------------', '\n', '--']
['<|endoftext|>', ' Derby', ' (', 'film']
['<|endoftext|>', 'EV', 'ERY', 'ONE']
['<|endoftext|>', ' affinity', ':', ' {']
['<|endoftext|>', '#', 'include', ' <']
['<|endoftext|>', 'pmb', '\n', '\n']
['<|endoftext|>', ' attribut', 'es', ' to']
['<|endoftext|>', ' MAKE', ' IT', ' RIGHT']
['<|endoftex

In [9]:

features_sorted_by_feat_i = defaultdict(lambda: [])
for feature in features:
    new_feats = []
    for pos in range(1,5):
        feat1 = SAEFeature(layer=15, pos=pos, feature_i=feature.feature_i, attr=feature.attr)
        new_feats.append(feat1)
    features_sorted_by_feat_i[feature.feature_i] = new_feats

new_modified_feats = []
for f,feats in features_sorted_by_feat_i.items():
    new_modified_feats += feats
from functools import partial

def get_batched_index_into(indices):
    '''
    given data that is [B,N,V] and indicies that are [B,N,K] with each index being an index into the V space
    this gives you indexes you can use to access your values
    '''
    first_axis = []
    second_axis = []
    third_axis = []
    B, _, _ = indices.size()
    for b in range(B):
        second, third = get_index_into(indices[b])
        first_axis.append(torch.full(second.size(), fill_value=b, device=model.cfg.device))
        second_axis.append(second)
        third_axis.append(third)

    return torch.cat(first_axis), torch.cat(second_axis), torch.cat(third_axis)

def get_index_into(indices):
    '''
    given data that is [N,V] and indicies that are [N,K] with each index being an index into the V space
    this gives you indexes you can use to access your values
    '''
    num_data, num_per_data = indices.size()
    # we want
    # [0,0,0,...,] num per data of these
    # [1,1,1,...,] num per data of these
    # ...
    # [num_data-1, num_data-1, ...]
    first_axis_index = torch.arange(num_data, dtype=torch.long).view(num_data, 1)*torch.ones([num_data, num_per_data], dtype=torch.long)
    # now we flatten it so it has an index for each term aligned with our indices
    first_axis_index = first_axis_index.flatten()
    second_axis_index = indices.flatten()
    return first_axis_index, second_axis_index
global buffer
buffer = None
global features_by_layer
def sae_hook(
    x,
    hook,
    layer,
):
    # s is [B,L,E]
    K = saes[layer].cfg.k
    sae = saes[layer]
    B,L,D = x.size()
    uncorrupted_features = sae.encode(x)
    top_acts, top_indices = uncorrupted_features.topk(K, sorted=False)
    global buffer
    if buffer is None:
        buffer = torch.zeros(uncorrupted_features.size(), device=model.cfg.device)
    buffer[:] = 0
    
    # zero everything except the top k
    buffer[get_batched_index_into(top_indices)] = top_acts.flatten()
    for feature in features_by_layer[layer]:
        feature.records += [x.item() for x in buffer[:,feature.pos,feature.feature_i]]
    # kernel can't handle doing all token positions at same time by default
    # but if we make it think B*L is a single batch index it works fine
    top_acts_flattened = top_acts.flatten(start_dim=0, end_dim=1)
    top_indices_flattened = top_indices.flatten(start_dim=0, end_dim=1)
    sae_out = sae.decode(top_acts_flattened, top_indices_flattened)
    sae_out = sae_out.unflatten(dim=0, sizes=(B,L))
    return sae_out


def forward_check_features(data, features, batch_size):
    
    global features_by_layer

    features_by_layer = defaultdict(lambda: [])
    for feature in features:
        feature.records = []
        features_by_layer[feature.layer].append(feature)

    # only bother with SAE on the layers we are checking
    layers_to_apply_sae = sorted(list(features_by_layer.keys()))
    hooks = [(f'blocks.{layer}.hook_out_proj', partial(sae_hook, layer=layer)) for layer in layers_to_apply_sae]
    DATA_LEN = data.size()[0]
    for batch_start in tqdm(list(range(0, DATA_LEN, batch_size))):
        batch_end = min(DATA_LEN, batch_start+batch_size)
        data_batch = data[batch_start:batch_end]
        _ = model.run_with_hooks(input=data_batch, fwd_hooks=hooks, fast_ssm=True, fast_conv=True)
    


forward_check_features(data=data_for_all_tokens3, features=new_modified_feats, batch_size=200)

100%|█████████████████████████████████████████████████████████████████████████████| 1947/1947 [1:22:09<00:00,  2.53s/it]


In [11]:
with open("layer_15_features_more_more.pkl", "wb") as f:
    pickle.dump((data_for_all_tokens3, new_modified_feats), f)

In [2]:
from datasets import load_dataset
from sae.data import chunk_and_tokenize
dataset = load_dataset(
    "togethercomputer/RedPajama-Data-1T-Sample",
    split="train",
    trust_remote_code=True,
)
tokenizer = model.tokenizer
# too many processes crashes, probably memory issue
tokenized = chunk_and_tokenize(dataset, tokenizer, num_proc=8)

In [None]:
from dataclasses import dataclass, field
from functools import partial


@dataclass
class SAEFeature:
    """Class for keeping track of an item in inventory."""
    layer: int
    pos: int
    feature_i: int
    attr: float
    records: list = field(default_factory=lambda: [])

    def __repr__(self):
        return str(self.layer) + " " + str(self.pos) + " " + str(self.feature_i) + " " + str(self.attr)

    def __str__(self):
        return self.__repr__()




with open("features_all_tokens_layer_15_with_collected_data.pkl", "rb") as f:
    features = pickle.load(f)





def get_batched_index_into(indices):
    '''
    given data that is [B,N,V] and indicies that are [B,N,K] with each index being an index into the V space
    this gives you indexes you can use to access your values
    '''
    first_axis = []
    second_axis = []
    third_axis = []
    B, _, _ = indices.size()
    for b in range(B):
        second, third = get_index_into(indices[b])
        first_axis.append(torch.full(second.size(), fill_value=b, device=model.cfg.device))
        second_axis.append(second)
        third_axis.append(third)

    return torch.cat(first_axis), torch.cat(second_axis), torch.cat(third_axis)

def get_index_into(indices):
    '''
    given data that is [N,V] and indicies that are [N,K] with each index being an index into the V space
    this gives you indexes you can use to access your values
    '''
    num_data, num_per_data = indices.size()
    # we want
    # [0,0,0,...,] num per data of these
    # [1,1,1,...,] num per data of these
    # ...
    # [num_data-1, num_data-1, ...]
    first_axis_index = torch.arange(num_data, dtype=torch.long).view(num_data, 1)*torch.ones([num_data, num_per_data], dtype=torch.long)
    # now we flatten it so it has an index for each term aligned with our indices
    first_axis_index = first_axis_index.flatten()
    second_axis_index = indices.flatten()
    return first_axis_index, second_axis_index
global buffer
buffer = None
global features_by_layer
def sae_hook(
    x,
    hook,
    layer,
):
    # s is [B,L,E]
    K = saes[layer].cfg.k
    sae = saes[layer]
    B,L,D = x.size()
    uncorrupted_features = sae.encode(x)
    top_acts, top_indices = uncorrupted_features.topk(K, sorted=False)
    global buffer
    if buffer is None:
        buffer = torch.zeros(uncorrupted_features.size(), device=model.cfg.device)
    buffer[:] = 0
    global features_by_layer
    # zero everything except the top k
    buffer[get_batched_index_into(top_indices)] = top_acts.flatten()
    for feature in features_by_layer[layer]:
        feature.records += buffer[:,feature.pos,feature.feature_i].tolist()
    # kernel can't handle doing all token positions at same time by default
    # but if we make it think B*L is a single batch index it works fine
    top_acts_flattened = top_acts.flatten(start_dim=0, end_dim=1)
    top_indices_flattened = top_indices.flatten(start_dim=0, end_dim=1)
    sae_out = sae.decode(top_acts_flattened, top_indices_flattened)
    sae_out = sae_out.unflatten(dim=0, sizes=(B,L))
    return sae_out
from tqdm import tqdm

def forward_check_features(data, features, batch_size):
    
    global features_by_layer
    
    #with open("layer_15_features_on_large_data.pkl", "rb") as f:
    #    features = pickle.load(f)
    features_by_layer = defaultdict(lambda: [])
    for feature in features:
        feature.records = []
        features_by_layer[feature.layer].append(feature)

    # only bother with SAE on the layers we are checking
    layers_to_apply_sae = sorted(list(features_by_layer.keys()))
    hooks = [(f'blocks.{layer}.hook_out_proj', partial(sae_hook, layer=layer)) for layer in layers_to_apply_sae]
    DATA_LEN = len(data)
    indd = 0
    num_already_processed = len(features[0].records)
    for batch_start in tqdm(list(range(0, DATA_LEN, batch_size))):
        indd += 1
        if indd % 100 == 0:
            with open(f"layer_15_features_on_large_data{indd}.pkl", "wb") as f:
                pickle.dump(features, f)
                print("saving")
            for feature in features:
                del feature.records
                feature.records = []
        batch_end = min(DATA_LEN, batch_start+batch_size)
        if batch_end <= num_already_processed: continue
        data_batch = data[batch_start:batch_end]['input_ids'][:,:128]
        _ = model.run_with_hooks(input=data_batch, fwd_hooks=hooks, fast_ssm=True, fast_conv=True)

features_sorted_by_feat_i = defaultdict(lambda: [])
for feature in features:
    new_feats = []
    for pos in range(1,128):
        feat1 = SAEFeature(layer=15, pos=pos, feature_i=feature.feature_i, attr=feature.attr)
        new_feats.append(feat1)
    features_sorted_by_feat_i[feature.feature_i] = new_feats

new_modified_feats = []
for f,feats in features_sorted_by_feat_i.items():
    new_modified_feats += feats

forward_check_features(tokenized, new_modified_feats, batch_size=100)

  2%|█▎                                                                             | 99/5880 [06:02<6:47:19,  4.23s/it]

saving


  3%|██▋                                                                           | 199/5880 [12:40<6:23:19,  4.05s/it]

saving


  5%|███▉                                                                          | 299/5880 [19:18<8:49:10,  5.69s/it]

saving


  7%|█████▎                                                                        | 399/5880 [25:46<6:25:53,  4.22s/it]

saving


  8%|██████▌                                                                       | 499/5880 [32:16<6:05:27,  4.08s/it]

saving


 10%|███████▉                                                                      | 599/5880 [38:50<8:20:11,  5.68s/it]

saving


 12%|█████████▎                                                                    | 699/5880 [45:26<5:59:13,  4.16s/it]

saving


 14%|██████████▌                                                                   | 799/5880 [52:04<5:49:03,  4.12s/it]

saving


 15%|███████████▉                                                                  | 899/5880 [58:42<7:45:09,  5.60s/it]

saving


 17%|████████████▉                                                               | 999/5880 [1:05:17<5:35:41,  4.13s/it]

saving


 19%|██████████████                                                             | 1099/5880 [1:11:52<5:24:21,  4.07s/it]

saving


 20%|███████████████▎                                                           | 1199/5880 [1:18:29<7:23:45,  5.69s/it]

saving


 22%|████████████████▌                                                          | 1299/5880 [1:25:42<5:19:20,  4.18s/it]

saving


 22%|████████████████▊                                                          | 1320/5880 [1:27:41<3:49:06,  3.01s/it]

In [22]:
def get_name_counts(feature):
    name_counts = {}
    DATA_LEN = len(feature.records)
    records_tensor = torch.tensor(feature.records)
    non_zero_indices = torch.arange(DATA_LEN)[records_tensor!=0]
    non_zero_tokens = data.data[non_zero_indices,feature.pos].cpu()
    non_zero_records = records_tensor[non_zero_indices]
    name_tokens = torch.unique(non_zero_tokens)
    for name_token in name_tokens:
        name_str = model.to_str_tokens(name_token.view(1,1))[0]
        name_counts[name_str] = non_zero_records[non_zero_tokens==name_token.item()]
    #for t,c in template_counts.items():
    #    print(f" template {t} with count {torch.mean(torch.tensor(c)).item()}")
    name_counts = sorted(list(name_counts.items()), key=lambda x: -torch.mean(x[1]).item())
    return name_counts
    #for n,c in name_counts[:100]:
    #    print(f" name {n} with avg {torch.mean(c).item()} min {torch.min(c).item()} max {torch.max(c).item()}")

#names = sorted(list(acdc.data.ioi.good_names))
names = [x for (i,x) in spaceThings if len(x.strip()) > 0]
NUM_NAMES = len(names)
#name_to_i = dict([(" " + name, i) for (i, name) in enumerate(names)])
name_to_i = dict([(name, i) for (i, name) in enumerate(names)])

def get_name_vector(feature, feat_type):
    name_vec = torch.zeros(NUM_NAMES)
    for name,counts in get_name_counts(feature):
        if feat_type == 'mean':
            name_vec[name_to_i[name]] = torch.mean(counts)
        elif feat_type == 'min':
            name_vec[name_to_i[name]] = torch.min(counts)
        elif feat_type == 'max':
            name_vec[name_to_i[name]] = torch.max(counts)
            
    return name_vec

toks = model.to_str_tokens(data.data[0])
name_positions = [3,5,7,13,15]
position_map = {}
L = data.data.size()[1]
for l in range(L):
    position_map[l] = f'pos{l}{toks[l]}'
position_map[3] = 'n1'
position_map[5] = 'n2'
position_map[7] = 'n3'
position_map[13] = 'n4'
position_map[15] = 'n5'
position_map[19] = 'out'

NameError: name 'spaceThings' is not defined

In [4]:
global feature_labels
feature_labels = {}

In [1]:

from IPython.display import display, clear_output
from ipywidgets.widgets import Output
from ipywidgets import widgets
out = Output()
print(len(features_sorted_by_feat_i))

global cur_feature_ind

def display_unlabeled_feature():
    global feature_labels
    available_features = features_sorted_by_feat_i.keys() - feature_labels.keys()
    for f in available_features:
        print(f"feature {f}")
        feats = features_sorted_by_feat_i[f]
        #if any([position_map[feat.pos][0] != 'n' for feat in feats]):
            #print(f"warning, feature {feat_i} has non name poses, all pos are {[position_map[f.pos] for f in feats]})")
        #    continue
        global cur_feature_ind
        cur_feature_ind = f
        display_feats(feats)
        return
    

def display_feats(feats):
    
    feat_vecs = [get_name_vector(feat, 'mean') for feat in feats]
    avg_vec = torch.stack(feat_vecs).mean(dim=0)
    min_vec = torch.stack([get_name_vector(feat, 'min') for feat in feats]).min(dim=0).values
    max_vec = torch.stack([get_name_vector(feat, 'max') for feat in feats]).max(dim=0).values
    sorted_names = torch.argsort(-avg_vec)
    #print(avg_vec, min_vec, max_vec, sorted_names)
    for name_i in sorted_names[:50]:
        #print(name_i)
        print(f" name {names[name_i]} with avg {avg_vec[name_i]} min {min_vec[name_i]} max {max_vec[name_i]}")
    
    for feat in feats:
        if position_map[feat.pos][0] == 'n':
            print(position_map[feat.pos], detect_single_letter(feat))
            pretty_print_list_first_letter_info(list_first_letter_info(feat))
            print(improved_first_letter(feat))
    
    diffs = torch.zeros(len(feats), len(feats))
    for i,featv1 in enumerate(feat_vecs):
        for j,featv2 in enumerate(feat_vecs):
            diffs[i,j] = torch.mean(torch.abs(featv1-featv2))
    labels = [position_map[feat.pos] for feat in feats]
    imshow(diffs, x=labels, y=labels, font_size=9)
    

def save_labels():
    with open("layer_15_features.pkl", "wb") as f:
        global feature_labels
        pickle.dump(feature_labels, f)
        print(f"done saving {len(feature_labels)}")

def submitted(arg):
    if len(text_item.value.strip()) > 0:
        with out:
            clear_output()
            global cur_feature_ind
            global feature_labels
            feature_labels[cur_feature_ind] = text_item.value
            save_labels()
            display_unlabeled_feature()
        
        text_item.value = ''
text_item = widgets.Text(
    value='',
    placeholder='Type something',
    description='String:',
    disabled=False
)
text_item.on_submit(submitted, names="value")
display(text_item)
display(out)

with out:
    clear_output()
    display_unlabeled_feature()

NameError: name 'features_sorted_by_feat_i' is not defined

In [206]:
print(feature_labels)

{}
