In [None]:

def imshow(tensor, renderer=None, xaxis="", yaxis="", font_size=None, show=True, color_continuous_midpoint=0.0, **kwargs):
    import plotly.express as px
    import transformer_lens.utils as utils
    fig = px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=color_continuous_midpoint, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs)
    if not font_size is None:
        if 'x' in kwargs:
            fig.update_layout(
              xaxis = dict(
                tickmode='array',
                tickvals = kwargs['x'],
                ticktext = kwargs['x'], 
                ),
               font=dict(size=font_size, color="black"))
        if 'y' in kwargs:
            fig.update_layout(
              yaxis = dict(
                tickmode='array',
                tickvals = kwargs['y'],
                ticktext = kwargs['y'], 
                ),
               font=dict(size=font_size, color="black"))
    plot_args = {
        'width': 800,
        'height': 600,
        "autosize": False,
        'showlegend': True,
        'margin': {"l":0,"r":0,"t":100,"b":0}
    }
    
    fig.update_layout(**plot_args)
    fig.update_layout(legend=dict(
        yanchor="top",
        y=0.99,
        xanchor="left",
        x=0.01
    ))
    if show:
        fig.show(renderer)
    else:
        return fig


# requires
# pip install git+https://github.com/Phylliida/MambaLens.git

from mamba_lens import HookedMamba # this will take a little while to import
import torch
model_path = "state-spaces/mamba-370m"

# NOTE! We need to monkeypatch transformer lens to use register_full_backward_hook
model = HookedMamba.from_pretrained(model_path, device='cuda')
torch.set_grad_enabled(False)


from safetensors.torch import load_file
from collections import defaultdict

import sys
if not "/home/dev/sae-k-sparse-mamba/sae" in sys.path:
    sys.path.append("/home/dev/sae-k-sparse-mamba")
import os
os.chdir('/home/dev/sae-k-sparse-mamba')
saes = [None]
from importlib import reload
from sae.sae import Sae

ckpt_dir = "/home/dev/sae-k-sparse-mamba/"
for i in range(1,22):
    print(i)
    hook = f'blocks.{i}.hook_out_proj'
    path = [ckpt_dir + f for f in sorted(list(os.listdir(ckpt_dir))) if hook in f][0] + "/" + f'hook_{hook}.pt'
    #path = f'/home/dev/sae-k-sparse-mamba/blocks.{i}.hook_resid_pre/hook_blocks.{i}.hook_resid_pre.pt'
    print(path)
    #saes.append(Sae.load_from_disk(path, hook=f'blocks.{i}.hook_resid_pre', device=model.cfg.device))


global PATCHING_FORMAT_I
global patching_formats
def make_data(num_patching_pairs, patching, template_i, seed, valid_seed):
    constrain_to_answers = True
    # this makes our data size 800, first 400 is each (a,b) pair, and then second 400 is each pair swapped to be (b,a)
    has_symmetric_patching = True
    
    n1_patchings = ["""
    ABC BC A
    DBC BC D""",
        """
    ABC CB A
    DBC CB D"""]
    
    n2_patchings = ["""
    ABC AC B
    ADC AC D""",
        """
    ABC CA B
    ADC CA D"""]
    
    n3_patchings = ["""
    ABC AB C
    ABD AB D""",
        """
    ABC BA C
    ABD BA D"""]
    
    n4_patchings = ["""
    ABC AC B
    ABC BC A""",
        """
    ABC AB C
    ABC CB A""",
        """
    ABC BA C
    ABC CA B"""]
    
    n5_patchings = ["""
    ABC CA B
    ABC CB A
    """,
        """
    ABC BA C
    ABC BC A""",
        """
    ABC AB C
    ABC AC B"""]
    
    patchings = {
        'n1': n1_patchings,
        'n2': n2_patchings,
        'n3': n3_patchings,
        'n4': n4_patchings,
        'n5': n5_patchings
    }
    
    all_patchings = []
    for patching_set in patchings.values():
        all_patchings += patching_set
    all_patchings = sorted(all_patchings) # make deterministic 
    
    patch_all_names = ["""
    ABC AB C
    DEF DE F""",
        """
    ABC AC B
    DEF DF E""",     
        """
    ABC BA C
    DEF ED F""",
        """
    ABC BC A
    DEF EF D""",
        """
    ABC CA B
    DEF FD E""",
        """
    ABC CB A
    DEF FE D"""]
    
    
    patchings['all'] = all_patchings
    patchings['allatonce'] = patch_all_names
    from acdc.data.ioi import BABA_TEMPLATES, ABC_TEMPLATES
    from acdc.data.ioi import ioi_data_generator, ABC_TEMPLATES, get_all_single_name_abc_patching_formats
    from acdc.data.utils import generate_dataset
    templates = ABC_TEMPLATES
    #patching_formats = list(get_all_single_name_abc_patching_formats())
    global PATCHING_FORMAT_I
    global patching_formats
    PATCHING_FORMAT_I = patching
    patching_formats = ["\n".join([line.strip() for line in x.split("\n")]).strip() for x in patchings[PATCHING_FORMAT_I]]
    
    print("using patching format")
    for patch in patching_formats:
        print(patch)
        print("")
    #print(patching_formats)
    
    
    data = generate_dataset(model=model,
                      data_generator=ioi_data_generator,
                      num_patching_pairs=4,
                      seed=seed,
                      valid_seed=valid_seed,
                      constrain_to_answers=constrain_to_answers,
                      has_symmetric_patching=has_symmetric_patching, 
                      varying_data_lengths=True,
                      templates=templates,
                      patching_formats=patching_formats)
    
    
    import acdc.data.ioi
    from collections import defaultdict
    name_positions_map = defaultdict(lambda: [])
    for template in templates:
        name = acdc.data.ioi.good_names[0]
        template_filled_in = template.replace("[NAME]", name)
        template_filled_in = template_filled_in.replace("[PLACE]", acdc.data.ioi.good_nouns['[PLACE]'][0])
        template_filled_in = template_filled_in.replace("[OBJECT]", acdc.data.ioi.good_nouns['[OBJECT]'][0])
        # get the token positions of the [NAME] in the prompt
        name_positions = tuple([(i) for (i,s) in enumerate(model.to_str_tokens(torch.tensor(model.tokenizer.encode(template_filled_in)))) if s == f' {name}'])
        name_positions_map[name_positions].append(template)
    sorted_by_frequency = sorted(list(name_positions_map.items()), key=lambda x: -len(x[1]))
    most_frequent_name_positions, templates = sorted_by_frequency[0]
    print("using templates")
    templates = [templates[0]]
    for template in templates:
        print(template)
    print(f"with name positions {most_frequent_name_positions}")
    import acdc.data.ioi
    if 'Jesus' in acdc.data.ioi.good_names:
        print("removed jesus")
        acdc.data.ioi.good_names.remove("Jesus")
    data = generate_dataset(model=model,
                  data_generator=ioi_data_generator,
                  num_patching_pairs=num_patching_pairs,
                  seed=seed,
                  valid_seed=valid_seed,
                  constrain_to_answers=constrain_to_answers,
                  has_symmetric_patching=has_symmetric_patching, 
                  varying_data_lengths=True,
                  templates=templates,
                  patching_formats=patching_formats)
    
    print(model.to_str_tokens(data.data[0]))
    print(model.to_str_tokens(data.data[1]))
    return data


from transformer_lens.hook_points import HookPoint
# we do a hacky thing where this first hook clears the global storage
# second hook stores all the hooks
# then third hook computes the output (over all the hooks)
# this avoids recomputing and so is much faster
SAE_HOOKS = "sae hooks"
SAE_BATCHES = "sae batches"
SAE_OUTPUT = "sae output"
def sae_patching_storage_hook(
    x,
    hook: HookPoint,
    sae_feature_i: int,
    dummy: bool,
    position: int,
    layer: int,
    batch_start: int,
    batch_end: int,
    **kwargs,
):
    global sae_storage
    if not SAE_HOOKS in sae_storage:
        sae_storage[SAE_HOOKS] = [] # we can't do this above because it'll be emptied again on the next batch before this is called
    sae_storage[SAE_OUTPUT] = None # clear output
    sae_storage[SAE_HOOKS].append({"position": position, "sae_feature_i": sae_feature_i, "dummy": dummy})
    #print(f"sae feature i {sae_feature_i} position {position} layer {layer}")
    return x

from jaxtyping import Float
from einops import rearrange

global sae_storage
sae_storage = {}
def sae_patching_hook(
    x: Float[torch.Tensor, "B L E"],
    hook: HookPoint,
    input_hook_name: str,
    layer: int,
    **kwargs,
) -> Float[torch.Tensor, "B L E"]:
    global sae_storage
    ### This is identical to what the conv is doing
    # but we break it apart so we can patch on individual filters
    # we have two input hooks, the second one is the one we want
    input_hook_name = input_hook_name[1]
    # don't recompute these if we don't need to
    # because we stored all the hooks and batches in conv_storage, we can just do them all at once
    
    # they need to share an output because they write to the same output tensor
    if sae_storage[SAE_OUTPUT] is None:
        #print(f"running for layer {layer}")
        K = saes[layer].cfg.k
        sae = saes[layer]
        #print(f"layer {layer} storage {sae_storage}")
        sae_output = torch.zeros(x.size(), device=model.cfg.device)
        #print("layer", layer, "keys", conv_storage)
        def get_filter_key(i):
            return f'filter_{i}'
        sae_input_uncorrupted = x[::2]
        sae_input_corrupted = x[1::2]
        B, L, D = sae_input_uncorrupted.size()
        for l in range(L):
            # [B, NFeatures]                             [B,D]
            uncorrupted_features = sae.encode(sae_input_uncorrupted[:,l])
            # [B, NFeatures]                             [B,D]
            corrupted_features = sae.encode(sae_input_corrupted[:,l])
            patched_features = corrupted_features.clone()
            #patched_features = torch.zeros(corrupted_features.size(), device=model.cfg.device) # patch everything except the features we are keeping around
            # apply hooks (one hook applies to a single feature)
            #print(f"{len(sae_storage[SAE_HOOKS])} hooks")
            for hook_data in sae_storage[SAE_HOOKS]:
                position = hook_data['position']
                sae_feature_i = hook_data['sae_feature_i']
                dummy = hook_data['dummy']
                if not dummy and (position == l or position is None): # position is None means all positions
                    if copy_from_other:
                        patched_features[:,sae_feature_i] = corrupted_features[:,sae_feature_i]
                    else:
                        patched_features[:,sae_feature_i] = uncorrupted_features[:,sae_feature_i]
                    
                    #print(f"applying sae feature {sae_feature_i} to position {position} for layer {layer}")
                    #uncorrupted_features[:,sae_feature_i] = corrupted_features[:,sae_feature_i]
            # compute sae outputs
            patched_top_acts, patched_top_indices = patched_features.topk(K, sorted=False)
            corrupted_top_acts, corrupted_top_indices = corrupted_features.topk(K, sorted=False)      
            sae_output[::2,l] = sae.decode(patched_top_acts, patched_top_indices)     
            sae_output[1::2,l] = sae.decode(corrupted_top_acts, corrupted_top_indices)
        sae_storage = {} # clean up and prepare for next layer
        sae_storage[SAE_OUTPUT] = sae_output # store the output
    return sae_storage[SAE_OUTPUT]

from dataclasses import dataclass, field
@dataclass
class SAEFeature:
    """Class for keeping track of an item in inventory."""
    layer: int
    pos: int
    feature_i: int
    attr: float
    records: list = field(default_factory=lambda: [])

    def __repr__(self):
        return str(self.layer) + " " + str(self.pos) + " " + str(self.feature_i) + " " + str(self.attr)

    def __str__(self):
        return self.__repr__()

def get_batched_index_into(indices):
    '''
    given data that is [B,N,V] and indicies that are [B,N,K] with each index being an index into the V space
    this gives you indexes you can use to access your values
    '''
    first_axis = []
    second_axis = []
    third_axis = []
    B, _, _ = indices.size()
    for b in range(B):
        second, third = get_index_into(indices[b])
        first_axis.append(torch.full(second.size(), fill_value=b, device=model.cfg.device))
        second_axis.append(second)
        third_axis.append(third)

    return torch.cat(first_axis), torch.cat(second_axis), torch.cat(third_axis)

def get_index_into(indices):
    '''
    given data that is [N,V] and indicies that are [N,K] with each index being an index into the V space
    this gives you indexes you can use to access your values
    '''
    num_data, num_per_data = indices.size()
    # we want
    # [0,0,0,...,] num per data of these
    # [1,1,1,...,] num per data of these
    # ...
    # [num_data-1, num_data-1, ...]
    first_axis_index = torch.arange(num_data, dtype=torch.long).view(num_data, 1)*torch.ones([num_data, num_per_data], dtype=torch.long)
    # now we flatten it so it has an index for each term aligned with our indices
    first_axis_index = first_axis_index.flatten()
    second_axis_index = indices.flatten()
    return first_axis_index, second_axis_index
global buffer
buffer = None
global features_by_layer
def sae_hook(
    x,
    hook,
    layer,
):
    # s is [B,L,E]
    K = saes[layer].cfg.k
    sae = saes[layer]
    B,L,D = x.size()
    uncorrupted_features = sae.encode(x)
    top_acts, top_indices = uncorrupted_features.topk(K, sorted=False)
    buffer = torch.zeros(uncorrupted_features.size(), device=model.cfg.device)
    global features_by_layer
    # zero everything except the top k
    buffer[get_batched_index_into(top_indices)] = top_acts.flatten()
    for feature in features_by_layer[layer]:
        if feature.pos < L: # sometimes prompt is too small to consider this feature
            feature.records += buffer[:,feature.pos,feature.feature_i].tolist()
    # kernel can't handle doing all token positions at same time by default
    # but if we make it think B*L is a single batch index it works fine
    top_acts_flattened = top_acts.flatten(start_dim=0, end_dim=1)
    top_indices_flattened = top_indices.flatten(start_dim=0, end_dim=1)
    sae_out = sae.decode(top_acts_flattened, top_indices_flattened)
    sae_out = sae_out.unflatten(dim=0, sizes=(B,L))
    return sae_out


def get_name_counts(feature):
    name_counts = {}
    DATA_LEN = len(feature.records)
    records_tensor = torch.tensor(feature.records)
    non_zero_indices = torch.arange(DATA_LEN)[records_tensor!=0]
    non_zero_tokens = data.data[non_zero_indices,feature.pos].cpu()
    non_zero_records = records_tensor[non_zero_indices]
    name_tokens = torch.unique(non_zero_tokens)
    for name_token in name_tokens:
        name_str = model.to_str_tokens(name_token.view(1,1))[0]
        name_counts[name_str] = non_zero_records[non_zero_tokens==name_token.item()]
    #for t,c in template_counts.items():
    #    print(f" template {t} with count {torch.mean(torch.tensor(c)).item()}")
    name_counts = sorted(list(name_counts.items()), key=lambda x: -torch.mean(x[1]).item())
    return name_counts
    #for n,c in name_counts[:100]:
    #    print(f" name {n} with avg {torch.mean(c).item()} min {torch.min(c).item()} max {torch.max(c).item()}")

data = make_data(num_patching_pairs=2, patching="all", template_i=0, seed=24, valid_seed=23)

toks = model.to_str_tokens(data.data[0])
name_positions = [3,5,7,13,15]
position_map = {}
L = data.data.size()[1]
for l in range(L):
    position_map[l] = f'pos{l}{toks[l]}'
position_map[3] = 'n1'
position_map[5] = 'n2'
position_map[7] = 'n3'
position_map[13] = 'n4'
position_map[15] = 'n5'
position_map[19] = 'out'
import pickle
with open("cached_sae_feature_edges.pkl", "rb") as f:
    edges_to_keep = pickle.load(f)

h = model.to_str_tokens(torch.arange(model.tokenizer.vocab_size))
spaceThings = [(i, x) for (i, x) in enumerate(h) if x[0] == ' ' and len(x.strip()) > 0]
prefix = data.data[0][:3].view(1,-1)
new_data_toks = torch.tensor([tok for (tok,s) in spaceThings], device=model.cfg.device)
data_for_all_tokens = torch.cat([prefix.repeat((len(new_data_toks),1)), new_data_toks.view(-1,1)], dim=1)
data.data = data_for_all_tokens
with open("layer_15_features_more_more.pkl", "rb") as f:
    data, features = pickle.load(f)

'''
#names = sorted(list(acdc.data.ioi.good_names))
names = [x for (i,x) in spaceThings if len(x.strip()) > 0]
NUM_NAMES = len(names)
#name_to_i = dict([(" " + name, i) for (i, name) in enumerate(names)])
name_to_i = dict([(name, i) for (i, name) in enumerate(names)])

def get_name_vector(feature, feat_type):
    name_vec = torch.zeros(NUM_NAMES)
    for name,counts in get_name_counts(feature):
        if feat_type == 'mean':
            name_vec[name_to_i[name]] = torch.mean(counts)
        elif feat_type == 'min':
            name_vec[name_to_i[name]] = torch.min(counts)
        elif feat_type == 'max':
            name_vec[name_to_i[name]] = torch.max(counts)
            
    return name_vec
'''
features_sorted_by_feat_i = defaultdict(lambda: [])
for feature in features:
    features_sorted_by_feat_i[feature.feature_i].append(feature)

In [None]:
from tqdm import tqdm
all_features = []
del features
from collections import defaultdict
K = 400
feature_i_top_k_indices = defaultdict(lambda: torch.tensor([]))
feature_i_top_k_values = defaultdict(lambda: torch.tensor([]))
num_data_points_seen = 0
for i in range(100, 5900, 100):
    print(i)
    path = f'/home/dev/sae-k-sparse-mamba/layer_15_features_on_large_data{i}.pkl'
    with open(path, "rb") as f:
        features = pickle.load(f)
    
    num_records = 0
    for feature in tqdm(features):
        records = torch.tensor(feature.records).flatten()
        num_records = records.size()[0]
        top = torch.topk(records, K)
        # offset indices by total num seen so far
        top_indices = top.indices + num_data_points_seen
        top_values = top.values
        merged_top_indices = torch.concatenate([feature_i_top_k_indices[feature.feature_i], top_indices])
        merged_top_values = torch.concatenate([feature_i_top_k_values[feature.feature_i], top_values])
        merged_top = torch.topk(merged_top_values, K)
        feature_i_top_k_indices[feature.feature_i] = merged_top_indices[merged_top.indices]
        feature_i_top_k_values[feature.feature_i] = merged_top_values[merged_top.indices]
        del feature.records
    num_data_points_seen += num_records
    del features


In [None]:
from tqdm import tqdm
all_features = []
del features
from collections import defaultdict
K = 400
feature_i_top_k_indices = defaultdict(lambda: torch.tensor([]))
feature_i_top_k_values = defaultdict(lambda: torch.tensor([]))
num_data_points_seen = 0
for i in range(100, 5900, 100):
    print(i)
    path = f'/home/dev/sae-k-sparse-mamba/layer_15_features_on_large_data{i}.pkl'
    with open(path, "rb") as f:
        features = pickle.load(f)

    num_records = 0
    for feature in tqdm(features):
        records = torch.tensor(feature.records).flatten()
        num_records = records.size()[0]
        top = torch.topk(records, K)
        # offset indices by total num seen so far
        top_indices = top.indices + num_data_points_seen
        top_values = top.values
        merged_top_indices = torch.concatenate([feature_i_top_k_indices[feature.feature_i], top_indices])
        merged_top_values = torch.concatenate([feature_i_top_k_values[feature.feature_i], top_values])
        merged_top = torch.topk(merged_top_values, K)
        feature_i_top_k_indices[feature.feature_i] = merged_top_indices[merged_top.indices]
        feature_i_top_k_values[feature.feature_i] = merged_top_values[merged_top.indices]
        del feature.records
    num_data_points_seen += num_records
    del features


In [None]:
for i in range(100, 5900, 100):
    print(i)
    path = f'/home/dev/sae-k-sparse-mamba/layer_15_features_on_large_data{i}.pkl'
    with open(path + "h", "rb") as f:
        features = pickle.load(f)

    num_records = len(features[0].records)
    data = torch.zeros(len(all_feature_i), num_records, 128, device='cuda')
    for feature in tqdm.tqdm(features):
        data[feature_to_storage_index[feature.feature_i], :, feature.pos] = torch.tensor(feature.records)
        del feature.records
    
    with open(path + "h", "wb") as f:
        pickle.dump(data, f)
    print(num_records)

In [None]:
with open("layer_15_top_act_data.pkl", "wb") as f:
    pickle.dump((dict(feature_i_top_k_indices), dict(feature_i_top_k_values)), f)
del features

In [None]:
import pickle
#K = 400
if 'data' in globals():
    del data
if 'features' in globals():
    print("cleanup features")
    for feature in features:
        if hasattr(feature, 'records'):
            del feature.records
    del features
with open("layer_15_top_act_data.pkl", "rb") as f:
    feature_i_top_k_indices, feature_i_top_k_values = pickle.load(f)


all_feature_i = sorted(list(feature_i_top_k_indices.keys()))
feature_to_storage_index = dict([(feat_i,index) for (index,feat_i) in enumerate(all_feature_i)])

del feature_i_top_k_indices, feature_i_top_k_values

#all_data = torch.zeros(len(all_feature_i), K, 128, device=torch.device(model.cfg.device))

#num_data_points_seen = 0
import tqdm
data = torch.zeros(len(all_feature_i), num_records, 128, device='cuda')

for i in range(100, 5900, 100):
    print(i)
    if os.path.exists(f'/home/dev/sae-k-sparse-mamba/layer_15_features_on_large_data{i}.pkl' + 'h'):
        continue
    path = f'/home/dev/sae-k-sparse-mamba/layer_15_features_on_large_data{i}.pkl'
    with open(path, "rb") as f:
        features = pickle.load(f)

    num_records = len(features[0].records)
    if data.size()[1] != num_records:
        del data
        data = torch.zeros(len(all_feature_i), num_records, 128, device='cuda')
    for feature in tqdm.tqdm(features):
        data[feature_to_storage_index[feature.feature_i], :, feature.pos] = torch.tensor(feature.records)
        del feature.records
    
    with open(path + "h", "wb") as f:
        pickle.dump(data, f)
    '''
    for feature_i in tqdm.tqdm(all_feature_i):
        feature_storage_index = feature_to_storage_index[feature_i]
        indices = feature_i_top_k_indices[feature_i] - num_data_points_seen
        indices_in_this_dataset = indices[indices < num_records and indices >= 0].nonzero().flatten()
        if indicies_in_this_dataset.size()[0] > 0:
            for feature in features:
                if feature.feature_i == feature_i:
                    for storage_index in indices_in_this_dataset:
                        data_index = indices[storage_index]
                        all_data[feature_storage_index, storage_index, feature.pos] = feature.records[data_index]
    '''
    

In [None]:
import pickle
#K = 400
if 'data' in globals():
    del data
if 'features' in globals():
    print("cleanup features")
    for feature in features:
        if hasattr(feature, 'records'):
            del feature.records
    del features
with open("layer_15_top_act_data.pkl", "rb") as f:
    feature_i_top_k_indices, feature_i_top_k_values = pickle.load(f)
K = 400

all_feature_i = sorted(list(feature_i_top_k_indices.keys()))
feature_to_storage_index = dict([(feat_i,index) for (index,feat_i) in enumerate(all_feature_i)])


all_data = torch.zeros(len(all_feature_i), K, 128, device=torch.device(model.cfg.device))

#num_data_points_seen = 0
import tqdm

num_data_points_seen = 0
for i in range(100, 5900, 100):
    print(i)
    path = f'/home/dev/sae-k-sparse-mamba/layer_15_features_on_large_data{i}.pklh'
    if 'feature_data' in globals():
        del feature_data
    with open(path, "rb") as f:
        feature_data = pickle.load(f)
    num_records = feature_data.size()[1]
    for feature_i in tqdm.tqdm(all_feature_i):
        feature_storage_index = feature_to_storage_index[feature_i]
        indices = feature_i_top_k_indices[feature_i] - num_data_points_seen
        filter_index = torch.logical_and(0 <= indices, indices < num_records)
        indices_in_this_dataset = torch.round(indices[filter_index]).long()
        #print(indices_in_this_dataset)
        storage_top_k_indices = filter_index.nonzero().flatten()
        if indices_in_this_dataset.size()[0] > 0:
            for index_in_this_dataset, storage_index in zip(indices_in_this_dataset, storage_top_k_indices):
                all_data[feature_storage_index, storage_index, :] = feature_data[feature_storage_index, index_in_this_dataset, :]
    num_data_points_seen += num_records

In [None]:
with open("all_15_data_topk.pkl", "wb") as f:
    pickle.dump(all_data, f)

In [None]:
if 'feature_data' in globals():
    del feature_data


from datasets import load_dataset
from sae.data import chunk_and_tokenize
dataset = load_dataset(
    "togethercomputer/RedPajama-Data-1T-Sample",
    split="train",
    trust_remote_code=True,
)
tokenizer = model.tokenizer
# too many processes crashes, probably memory issue
tokenized = chunk_and_tokenize(dataset, tokenizer, num_proc=8)

with open("all_15_data_topk.pkl", "rb") as f:
    top_k_data = pickle.load(f)


with open("layer_15_top_act_data.pkl", "rb") as f:
    feature_i_top_k_indices, feature_i_top_k_values = pickle.load(f)

In [None]:

K = 400

all_feature_i = sorted(list(feature_i_top_k_indices.keys()))
feature_to_storage_index = dict([(feat_i,index) for (index,feat_i) in enumerate(all_feature_i)])


token_data = torch.zeros(len(all_feature_i), K, 128, device=torch.device(model.cfg.device), dtype=torch.long)

import tqdm
for feature_i in tqdm.tqdm(all_feature_i):
    storage_index = feature_to_storage_index[feature_i]
    indices = feature_i_top_k_indices[feature_i]
    for k, index in enumerate(indices):
        token_data[storage_index, k, :] = tokenized[torch.round(index).long().item()]['input_ids'][:128]
    


In [None]:
with open("all_15_top_dataset_tokens.pkl", "wb") as f:
    pickle.dump(token_data, f)

In [None]:
saes = [None]
from importlib import reload
from sae.sae import Sae

ckpt_dir = "/home/dev/sae-k-sparse-mamba/"
for i in range(1,22):
    print(i)
    hook = f'blocks.{i}.hook_out_proj'
    path = [ckpt_dir + f for f in sorted(list(os.listdir(ckpt_dir))) if hook in f][0] + "/" + f'hook_{hook}.pt'
    #path = f'/home/dev/sae-k-sparse-mamba/blocks.{i}.hook_resid_pre/hook_blocks.{i}.hook_resid_pre.pt'
    print(path)
    saes.append(Sae.load_from_disk(path, hook=f'blocks.{i}.hook_resid_pre', device=model.cfg.device))


# # Name comparisons

In [None]:
import acdc.data.ioi
from functools import partial

with open("/home/dev/mamba_interp/MoreNames.txt", "r") as f:
    all_names = [x.strip() for x in f.read().split("\n") if len(x.strip()) > 0]
    # regenerate names, but more
    acdc.data.ioi.NAMES = sorted(list(set(all_names)))
    acdc.data.ioi.good_names = None 


N_TOKENS = 50
data = make_data(num_patching_pairs=100, patching="all", template_i=0, seed=24, valid_seed=23)

acdc.data.ioi.good_names = sorted(acdc.data.ioi.good_names)
from tqdm import tqdm
name_test_data = []
for name1 in tqdm(acdc.data.ioi.good_names[:]):
    for name2 in acdc.data.ioi.good_names[:1]:
        prompt = f'Lately,' + f' {name1},'*N_TOKENS
        toks = torch.tensor([model.tokenizer.bos_token_id] + model.tokenizer.encode(prompt), device=model.cfg.device)
        name_test_data.append(toks)

name_test_data = torch.stack(name_test_data)



def get_batched_index_into(indices):
    '''
    given data that is [B,N,V] and indicies that are [B,N,K] with each index being an index into the V space
    this gives you indexes you can use to access your values
    '''
    first_axis = []
    second_axis = []
    third_axis = []
    B, _, _ = indices.size()
    for b in range(B):
        second, third = get_index_into(indices[b])
        first_axis.append(torch.full(second.size(), fill_value=b, device=model.cfg.device))
        second_axis.append(second)
        third_axis.append(third)

    return torch.cat(first_axis), torch.cat(second_axis), torch.cat(third_axis)

def get_index_into(indices):
    '''
    given data that is [N,V] and indicies that are [N,K] with each index being an index into the V space
    this gives you indexes you can use to access your values
    '''
    num_data, num_per_data = indices.size()
    # we want
    # [0,0,0,...,] num per data of these
    # [1,1,1,...,] num per data of these
    # ...
    # [num_data-1, num_data-1, ...]
    first_axis_index = torch.arange(num_data, dtype=torch.long).view(num_data, 1)*torch.ones([num_data, num_per_data], dtype=torch.long)
    # now we flatten it so it has an index for each term aligned with our indices
    first_axis_index = first_axis_index.flatten()
    second_axis_index = indices.flatten()
    return first_axis_index, second_axis_index
global buffer
buffer = None
global features_by_layer
def sae_hook(
    x,
    hook,
    layer,
):
    # s is [B,L,E]
    K = saes[layer].cfg.k
    sae = saes[layer]
    B,L,D = x.size()
    uncorrupted_features = sae.encode(x)
    top_acts, top_indices = uncorrupted_features.topk(K, sorted=False)
    buffer = torch.zeros(uncorrupted_features.size(), device=model.cfg.device)
    global features_by_layer
    # zero everything except the top k
    buffer[get_batched_index_into(top_indices)] = top_acts.flatten()
    for feature in features_by_layer[layer]:
        if feature.pos < L: # sometimes prompt is too small to consider this feature
            feature.records += buffer[:,feature.pos,feature.feature_i].tolist()
            feature.full_records += uncorrupted_features[:,feature.pos,feature.feature_i].tolist()
    # kernel can't handle doing all token positions at same time by default
    # but if we make it think B*L is a single batch index it works fine
    top_acts_flattened = top_acts.flatten(start_dim=0, end_dim=1)
    top_indices_flattened = top_indices.flatten(start_dim=0, end_dim=1)
    sae_out = sae.decode(top_acts_flattened, top_indices_flattened)
    sae_out = sae_out.unflatten(dim=0, sizes=(B,L))
    return sae_out

def forward_check_features(data, features, batch_size):
    
    global features_by_layer

    features_by_layer = defaultdict(lambda: [])
    for feature in features:
        feature.records = []
        feature.full_records = []
        features_by_layer[feature.layer].append(feature)

    # only bother with SAE on the layers we are checking
    layers_to_apply_sae = sorted(list(features_by_layer.keys()))
    hooks = [(f'blocks.{layer}.hook_out_proj', partial(sae_hook, layer=layer)) for layer in layers_to_apply_sae]
    DATA_LEN = data.size()[0]
    for batch_start in tqdm(list(range(0, DATA_LEN, batch_size))):
        batch_end = min(DATA_LEN, batch_start+batch_size)
        data_batch = data[batch_start:batch_end]
        _ = model.run_with_hooks(input=data_batch, fwd_hooks=hooks, fast_ssm=True, fast_conv=True)

feature_i = 22605
L = name_test_data.size()[1]
for feature in features:
    if feature.feature_i == feature_i and feature.layer == 15:
        feat = feature
features_with_i = []
for i in range(1, L):
    features_with_i.append(SAEFeature(layer=feat.layer, pos=i, feature_i=feature_i, attr=feat.attr))
forward_check_features(name_test_data, features=features_with_i, batch_size=10)

activations = torch.zeros(name_test_data.size())
for feature in features_with_i:
    activations[:,feature.pos] = torch.tensor(feature.records)


In [63]:

activations2 = torch.zeros(name_test_data.size())
for feature in features_with_i:
    activations2[:,feature.pos] = torch.tensor(feature.full_records)

In [None]:

acts = activations.reshape(-1, 1, L)

def imshow(tensor, renderer=None, xaxis="", yaxis="", font_size=None, show=True, color_continuous_midpoint=0.0, **kwargs):
    import plotly.express as px
    import transformer_lens.utils as utils
    fig = px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=color_continuous_midpoint, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs)
    if not font_size is None:
        if 'x' in kwargs:
            fig.update_layout(
              xaxis = dict(
                tickmode='array',
                tickvals = kwargs['x'],
                ticktext = kwargs['x'], 
                ),
               font=dict(size=font_size, color="black"))
        if 'y' in kwargs:
            fig.update_layout(
              yaxis = dict(
                tickmode='array',
                tickvals = kwargs['y'],
                ticktext = kwargs['y'], 
                ),
               font=dict(size=font_size, color="black"))
    plot_args = {
        'width': 800*3,
        'height': 600*3,
        "autosize": False,
        'showlegend': True,
        'margin': {"l":0,"r":0,"t":100,"b":0}
    }
    
    fig.update_layout(**plot_args)
    fig.update_layout(legend=dict(
        yanchor="top",
        y=0.99,
        xanchor="left",
        x=0.01
    ))
    if show:
        fig.show(renderer)
    else:
        return fig

names_labels = acdc.data.ioi.good_names[:]

imshow(acts[:,:,-1], x=['h'], y=names_labels, font_size=12)




# y is first name, x is second

In [62]:

def detect_parity(l):
    num_swap = 0
    for k in range(len(l)-1):
        if l[k] != l[k+1]:
            num_swap += 1
    return num_swap

def sign(x):
    if x > 0:
        return 1
    elif x < 0:
        return -1
    else:
        return 0
print(model.to_str_tokens(name_test_data[0]))
L = name_test_data.size()[1]
acts = activations.reshape(-1, 1, L)
acts2 = activations2.reshape(-1, 1, L)
print(acts[0])
outs = []
outs2 = []
diffs = []
import math
for i in range(acts.size()[0]):
    out = []
    out2 = []
    for j in range(N_TOKENS):
        out.append(acts[i,0,4+j*2].item())
        out2.append(acts2[i,0,4+j*2].item())
    diff = []
    for j in range(len(out)-1):
        diff.append(sign(out[j]-out[j+1]))
    diff.append(detect_parity(diff))
    diff.append(i)
    diffs.append(diff)
    #out.append(i)
    outs.append(tuple(out))
    outs2.append(out2)
#outs = sorted(outs)
#diffs = sorted(diffs, key=lambda x: outs[x[-1]][0])
import time
i = 0
import pandas as pd
inds = sorted(list(range(len(outs))), key=lambda i: (outs[i][1], outs[i][0]))
import numpy as np
outsarr = np.array(outs).T

def display_inds(inds, title, prange=6):
    if len(inds) == 0:
        print(f"none for {title} ({len(inds)}/{len(diffs)} names)")
        return
    names_labels = acdc.data.ioi.good_names[:]
    batch_names_labels = [names_labels[i] for i in inds]
    batchoutsarr = outsarr[:,np.array(inds)]
    df = pd.DataFrame(batchoutsarr[:prange], columns=batch_names_labels)
    import plotly.express as px
    fig = px.line(df, x=df.index, y = df.columns, title=title + f" ({len(inds)}/{len(diffs)} names)")
    fig.show()

display_inds(inds, "all", prange=N_TOKENS)


always_zero = []
for ind in [i for i in inds]:
    if max(outs[ind]) == 0:
        always_zero.append(ind)
        inds.remove(ind)

display_inds(always_zero, "always zero")

goes_to_always_zero_at_1 = []
for ind in [i for i in inds]:
    if max(outs[ind][1:]) == 0:
        goes_to_always_zero_at_1.append(ind)
        inds.remove(ind)
display_inds(goes_to_always_zero_at_1, "always zero after 1")

goes_to_always_zero_at_2_decrease = []
for ind in [i for i in inds]:
    if max(outs[ind][2:]) == 0 and outs[ind][0] > outs[ind][1]:
        goes_to_always_zero_at_2_decrease.append(ind)
        inds.remove(ind)
display_inds(goes_to_always_zero_at_2_decrease, "always zero after 2 (decreasing)")

goes_to_always_zero_at_2_increase = []
for ind in [i for i in inds]:
    if max(outs[ind][2:]) == 0 and outs[ind][0] <= outs[ind][1]:
        goes_to_always_zero_at_2_increase.append(ind)
        inds.remove(ind)
display_inds(goes_to_always_zero_at_2_increase, "always zero after 2 (increasing or equal)")

goes_to_always_zero_at_3 = []
for ind in [i for i in inds]:
    if max(outs[ind][3:]) == 0:
        goes_to_always_zero_at_3.append(ind)
        inds.remove(ind)
display_inds(goes_to_always_zero_at_3, "always zero after 3")

goes_to_zero_at_0123 = []
for ind in [i for i in inds]:
    if outs[ind][0] == 0 and outs[ind][1] == 0 and outs[ind][2] == 0 and outs[ind][3] == 0:
        goes_to_zero_at_0123.append(ind)
        inds.remove(ind)
display_inds(goes_to_zero_at_0123, "zero at 0,1,2, and 3 (but eventually non-zero after)")

goes_to_zero_at_123 = []
for ind in [i for i in inds]:
    if outs[ind][1] == 0 and outs[ind][2] == 0 and outs[ind][3] == 0:
        goes_to_zero_at_123.append(ind)
        inds.remove(ind)
display_inds(goes_to_zero_at_123, "zero at 1,2, and 3 (but eventually non-zero after)")

goes_to_zero_at_12 = []
for ind in [i for i in inds]:
    if outs[ind][1] == 0 and outs[ind][2] == 0:
        goes_to_zero_at_12.append(ind)
        inds.remove(ind)
display_inds(goes_to_zero_at_12, "zero at 1,2 (but eventually non-zero after)")

goes_to_zero_at_1 = []
for ind in [i for i in inds]:
    if outs[ind][1] == 0:
        goes_to_zero_at_1.append(ind)
        inds.remove(ind)
display_inds(goes_to_zero_at_1, "zero at 1 (but eventually non-zero after)")

goes_to_zero_at_02 = []
for ind in [i for i in inds]:
    if outs[ind][2] == 0 and outs[ind][0] > outs[ind][1]:
        goes_to_zero_at_02.append(ind)
        inds.remove(ind)
display_inds(goes_to_zero_at_02, "zero at 2, decreasing 0>1>2(and eventually non-zero after)")

goes_to_zero_at_02dec = []
for ind in [i for i in inds]:
    if outs[ind][2] == 0 and outs[ind][0] <= outs[ind][1]:
        goes_to_zero_at_02dec.append(ind)
        inds.remove(ind)
display_inds(goes_to_zero_at_02dec, "zero at 2, 0<=1>2(and eventually non-zero after)")


goes_to_zero_at_3 = []
for ind in [i for i in inds]:
    if outs[ind][3] == 0:
        goes_to_zero_at_3.append(ind)
        inds.remove(ind)
display_inds(goes_to_zero_at_3, "zero at 3 (and eventually non-zero after)")



decreasing_terms = []
for ind in [i for i in inds]:
    decreasing = True
    for i in range(10):
        if outs[ind][i] <= outs[ind][i+1]:
            decreasing = False
    if decreasing:
        decreasing_terms.append(ind)
        inds.remove(ind)
display_inds(decreasing_terms, "strictly decrasing for first 10")


decreasing_terms = []
for ind in [i for i in inds]:
    decreasing = True
    for i in range(1,10):
        if outs[ind][i] <= outs[ind][i+1]:
            decreasing = False
    if decreasing:
        decreasing_terms.append(ind)
        inds.remove(ind)
display_inds(decreasing_terms, "strictly decrasing for 1-10")

always_non_zero = []
for ind in [i for i in inds]:
    if min(outs[ind]) > 0:
        always_non_zero.append(ind)
        inds.remove(ind)
display_inds(always_non_zero, "always non-zero")


four_sign_changes = []
for ind in [i for i in inds]:
    decreasing = True
    if diffs[ind][-2] == 5:
        four_sign_changes.append(ind)
        inds.remove(ind)
display_inds(four_sign_changes, "five sign changes", prange=300)

four_sign_changes = []
for ind in [i for i in inds]:
    decreasing = True
    if diffs[ind][-2] == 4:
        four_sign_changes.append(ind)
        inds.remove(ind)
display_inds(four_sign_changes, "four sign changes", prange=300)


four_sign_changes = []
for ind in [i for i in inds]:
    decreasing = True
    if diffs[ind][-2] == 3:
        four_sign_changes.append(ind)
        inds.remove(ind)
display_inds(four_sign_changes, "three sign changes", prange=300)


four_sign_changes = []
for ind in [i for i in inds]:
    decreasing = True
    if diffs[ind][-2] == 2:
        four_sign_changes.append(ind)
        inds.remove(ind)
display_inds(four_sign_changes, "two sign changes", prange=300)



four_sign_changes = []
for ind in [i for i in inds]:
    decreasing = True
    if diffs[ind][-2] == 1:
        four_sign_changes.append(ind)
        inds.remove(ind)
display_inds(four_sign_changes, "one sign changes", prange=300)

display_inds(inds, "other", prange=300)
'''
zero_for_first_100 = []
for ind in [i for i in inds]:
    if max(outs[ind][:100]) == 0:
        zero_for_first_100.append(ind)
        inds.remove(ind)
display_inds(zero_for_first_100, "zero for first 100", prange=300)

zero_for_first_50 = []
for ind in [i for i in inds]:
    if max(outs[ind][:50]) == 0:
        zero_for_first_50.append(ind)
        inds.remove(ind)
display_inds(zero_for_first_50, "zero for first 50", prange=300)


zero_for_first_20 = []
for ind in [i for i in inds]:
    if max(outs[ind][:20]) == 0:
        zero_for_first_20.append(ind)
        inds.remove(ind)
display_inds(zero_for_first_20, "zero for first 20", prange=300)

zero_for_first_10 = []
for ind in [i for i in inds]:
    if max(outs[ind][:10]) == 0:
        zero_for_first_10.append(ind)
        inds.remove(ind)
display_inds(zero_for_first_10, "zero for first 10", prange=300)

zero_for_last_100 = []
for ind in [i for i in inds]:
    if max(outs[ind][-100:]) == 0:
        zero_for_last_100.append(ind)
        inds.remove(ind)
display_inds(zero_for_last_100, "zero for last 100", prange=300)

zero_for_last_50 = []
for ind in [i for i in inds]:
    if max(outs[ind][-50:]) == 0:
        zero_for_last_50.append(ind)
        inds.remove(ind)
display_inds(zero_for_last_50, "zero for last 50", prange=300)




import numpy as np
import pandas as pd
# [Time, Names]
outs2arr = np.array(outs2).T
outsarr = np.array(outs).T
'''



'''
for diff in diffs:
    print(acdc.data.ioi.good_names[diff[-1]], max(outs[diff[-1]])) #, diff[::-1][0:30])
    #print(outs[diff[-1]][:30], torch.argmax(torch.tensor(outs[diff[-1]])))
    print(diff[-2])
    print(outs[diff[-1]])
    print(outs2[diff[-1]])
    #print(acdc.data.ioi.good_names[diff[-1]], ['{:.5f}'.format(d) for d in diff[::-1]]) # out[::-1])
    #print(outs[diff[-1]][::-1])
    time.sleep(0.1)
    i += 1
    if i > 20: break
'''

['<|endoftext|>', 'L', 'ately', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',', ' Aaron', ',']
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1001,
         0.0000, 0.1334, 0.0000, 0.1347, 0.0000, 0.1348, 0.0000, 0.1354, 0.0000,
         0.1357, 0.0000, 0.1357, 0.0000, 0.1353, 0.

none for always zero after 2 (decreasing) (0/978 names)
none for always zero after 2 (increasing or equal) (0/978 names)
none for always zero after 3 (0/978 names)


none for zero at 2, decreasing 0>1>2(and eventually non-zero after) (0/978 names)
none for zero at 2, 0<=1>2(and eventually non-zero after) (0/978 names)
none for zero at 3 (and eventually non-zero after) (0/978 names)
none for strictly decrasing for first 10 (0/978 names)


none for five sign changes (0/978 names)


none for other (0/978 names)


"\nfor diff in diffs:\n    print(acdc.data.ioi.good_names[diff[-1]], max(outs[diff[-1]])) #, diff[::-1][0:30])\n    #print(outs[diff[-1]][:30], torch.argmax(torch.tensor(outs[diff[-1]])))\n    print(diff[-2])\n    print(outs[diff[-1]])\n    print(outs2[diff[-1]])\n    #print(acdc.data.ioi.good_names[diff[-1]], ['{:.5f}'.format(d) for d in diff[::-1]]) # out[::-1])\n    #print(outs[diff[-1]][::-1])\n    time.sleep(0.1)\n    i += 1\n    if i > 20: break\n"

In [None]:
len((0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0))

In [None]:
acts = activations.reshape(-1, 1, L)

inds = torch.argsort(acts[:,0,4])
for i in inds:
    print(acdc.data.ioi.good_names[i], acts[i,0,4].item())

# Manual Test Feature

In [None]:

indices = torch.round(feature_i_top_k_indices[2380]).long()
print(tokenized[indices[0].item()]['input_ids'])
feature_i_top_k_values[2380]

In [139]:
from dataclasses import dataclass, field
from functools import partial


from tqdm import tqdm

with open("features_all_tokens_layer_15_with_collected_data.pkl", "rb") as f:
    features = pickle.load(f)



def forward_check_featuresf(data, features):
    
    global features_by_layer
    
    #with open("layer_15_features_on_large_data.pkl", "rb") as f:
    #    features = pickle.load(f)
    features_by_layer = defaultdict(lambda: [])
    for feature in features:
        feature.records = []
        feature.full_records = []
        features_by_layer[feature.layer].append(feature)

    # only bother with SAE on the layers we are checking
    layers_to_apply_sae = sorted(list(features_by_layer.keys()))
    hooks = [(f'blocks.{layer}.hook_out_proj', partial(sae_hook, layer=layer)) for layer in layers_to_apply_sae]
    _ = model.run_with_hooks(input=data, fwd_hooks=hooks, fast_ssm=True, fast_conv=True)
        
import traceback
def clicked(arg):
    with outputTesting:
        clear_output()
        try:
            if eval_data.value.strip() == "":
                text = text_itemw.value
                tokenized_input = torch.tensor([model.tokenizer.bos_token_id] + model.tokenizer.encode(text), device=model.cfg.device).reshape(1,-1)
            else:
                tokenized_input = eval(eval_data.value).reshape(1, -1)
                print("eval to", tokenized_input)
            feature_i = int(feature_index.value)
            L = tokenized_input.size()[1]
            for feature in features:
                if feature.feature_i == feature_i and feature.layer == 15:
                    feat = feature
            features_with_i = []
            for i in range(1, L):
                features_with_i.append(SAEFeature(layer=feat.layer, pos=i, feature_i=feature_i, attr=feat.attr))
            forward_check_featuresf(tokenized_input, features=features_with_i)
            activations = torch.zeros(L)
            for feature in features_with_i:
                activations[feature.pos] = feature.records[0]
            toks = model.to_str_tokens(tokenized_input[0])
            print(toks)
            token_pos = torch.argmax(activations).item()
            
            if activations[token_pos] == 0.0:
                token_pos = L
            out_toks = []
            print(activations)
            for j,tok in enumerate(toks):
                if len(tok.strip()) == 0:
                    tok = repr(tok)
                tok = tok.replace("\n", "\\n")
                colored = f"<span id='ayy'><font color='white'>{tok}</font></span>"
                if j < 1: continue
                if j == token_pos:
                    colored = f"<span id='ayy'><font color='red'>{tok}</font></span>"
                elif activations[j].item() > 0.01:
                    colored = f"<span id='ayy'><font color='pink'>{tok}</font></span>"
                if activations[j].item() > 0.01:
                    out_toks.append(f"{colored}{activations[j].item():.3f}")
                else:
                    out_toks.append(colored)
            simpler = model.tokenizer.decode(tokenized_input[0,1:token_pos+1])
            print(simpler)
            if token_pos == L:
                print("all zero")
            else:
                display(HTML(toks[token_pos] + "\t<br/><br/>\t" + "<span id='ayy'>" + simpler + "</span>\t<br/><br/>\t" + "".join(out_toks))) 
        except:
            print(traceback.format_exc())

                
text_itemw = widgets.Text(
    value='',
    placeholder='Test String',
    description='Test String',
    disabled=False,
    continuous_update=False,
)

feature_index = widgets.Text(
    value='',
    placeholder='Feature Index',
    description='Feature Index',
    disabled=False,
    continuous_update=False,
)
eval_data = widgets.Text(
    value='',
    placeholder='Eval Tokens',
    description='Eval Tokens',
    disabled=False,
    continuous_update=False,
)

button_download = widgets.Button(description = 'Test')   
button_download.on_click(clicked)

outputTesting = widgets.Output()

display(text_itemw)
display(feature_index)
display(eval_data)
display(button_download)
display(outputTesting)


Text(value='', continuous_update=False, description='Test String', placeholder='Test String')

Text(value='', continuous_update=False, description='Feature Index', placeholder='Feature Index')

Text(value='', continuous_update=False, description='Eval Tokens', placeholder='Eval Tokens')

Button(description='Test', style=ButtonStyle())

Output()

In [None]:
top_k_data[feature_to_storage_index[2380], 1]

# Display TopK Activations from Dataset subset

In [140]:
%%html
<style>
/*overwrite hard coded write background by vscode for ipywidges */
.cell-output-ipywidget-background {
   background-color: transparent !important;
}

/*set widget foreground text and color of interactive widget to vs dark theme color */
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}
</style>

In [142]:

with open("layer_15_top_act_data.pkl", "rb") as f:
    feature_i_top_k_indices, feature_i_top_k_values = pickle.load(f)

with open("all_15_data_topk.pkl", "rb") as f:
    top_k_data = pickle.load(f)

with open("all_15_top_dataset_tokens.pkl", "rb") as f:
    token_data = pickle.load(f)


with open("layer_15_features.pkl", "rb") as f:
    feature_labels = pickle.load(f)
with open("layer_15_features_take_two.pkl", "rb") as f:
    feature_labelsb = pickle.load(f)

K = 400

all_feature_i = sorted(list(feature_i_top_k_indices.keys()))
feature_to_storage_index = dict([(feat_i,index) for (index,feat_i) in enumerate(all_feature_i)])



from IPython.display import display, clear_output
from ipywidgets import widgets
from IPython.display import display, HTML
global cur_feature_ind
cur_feature_ind = None
def display_unlabeled_feature():
    global cur_feature_ind
    global feature_labels
    cur_feature_ind = 15112
    text_item.value = feature_labels[cur_feature_ind]
    display_feat(cur_feature_ind)
    available_features = features_sorted_by_feat_i.keys() - feature_labelsb.keys()
    maybe = []
    for f in available_features:
        #if len(feature_labels[f].strip()) != 1: continue
        if '?' not in feature_labels[f]:
           continue
        maybe.append((features_sorted_by_feat_i[f][0].attr, f))
        continue
    return
    maybe.sort(key=lambda x: x[0])
    for attr, f in maybe:
        f = 22605
        attr = features_sorted_by_feat_i[f][0].attr
        print(attr)
        print(f"feature {f}")
        cur_feature_ind = f
        text_item.value = feature_labels[f]
        #if any([position_map[feat.pos][0] != 'n' for feat in feats]):
            #print(f"warning, feature {feat_i} has non name poses, all pos are {[position_map[f.pos] for f in feats]})")
        #    continue
        display_feat(cur_feature_ind)
        return
    num_left = 0
    for f in feature_labels.keys():
        if "?" in feature_labels[f]:
            num_left += 1
    print(f"num left {num_left}")
    for f in feature_labels.keys():
        if "?" in feature_labels[f]:
            cur_feature_ind = f
            text_item.value = feature_labels[f]
            display_feats(features_sorted_by_feat_i[f])
            return

display(HTML("""<link href='https://fonts.googleapis.com/css?family=Noto Sans' rel='stylesheet'>
<style>
/* Base Noto Sans and Serif for Latin, Greek, and Cyrillic */
@import url('https://fonts.googleapis.com/css2?family=Noto+Sans:wght@400;700&family=Noto+Serif:wght@400;700&display=swap');

/* East Asian scripts */
@import url('https://fonts.googleapis.com/css2?family=Noto+Sans+JP:wght@400;700&family=Noto+Sans+KR:wght@400;700&family=Noto+Sans+SC:wght@400;700&family=Noto+Sans+TC:wght@400;700&display=swap');

/* South Asian scripts */
@import url('https://fonts.googleapis.com/css2?family=Noto+Sans+Devanagari:wght@400;700&family=Noto+Sans+Bengali:wght@400;700&family=Noto+Sans+Tamil:wght@400;700&display=swap');

/* Middle Eastern scripts */
@import url('https://fonts.googleapis.com/css2?family=Noto+Sans+Arabic:wght@400;700&family=Noto+Sans+Hebrew:wght@400;700&display=swap');

/* Other scripts */
@import url('https://fonts.googleapis.com/css2?family=Noto+Sans+Thai:wght@400;700&family=Noto+Sans+Ethiopic:wght@400;700&display=swap');

/* Specialty fonts */
@import url('https://fonts.googleapis.com/css2?family=Noto+Sans+Mono:wght@400;700&family=Noto+Color+Emoji&display=swap');

#ayy {
  font-family: 'Noto Sans', 'Noto Sans JP', 'Noto Sans KR', 'Noto Sans SC', 'Noto Sans TC', 
               'Noto Sans Devanagari', 'Noto Sans Bengali', 'Noto Sans Tamil', 
               'Noto Sans Arabic', 'Noto Sans Hebrew', 'Noto Sans Thai', 'Noto Sans Ethiopic',
               sans-serif;
}

/* Language-specific rules */
:lang(ja) { font-family: 'Noto Sans JP', sans-serif; }
:lang(ko) { font-family: 'Noto Sans KR', sans-serif; }
:lang(zh-CN) { font-family: 'Noto Sans SC', sans-serif; }
:lang(zh-TW) { font-family: 'Noto Sans TC', sans-serif; }
:lang(hi) { font-family: 'Noto Sans Devanagari', sans-serif; }
:lang(bn) { font-family: 'Noto Sans Bengali', sans-serif; }
:lang(ta) { font-family: 'Noto Sans Tamil', sans-serif; }
:lang(ar) { font-family: 'Noto Sans Arabic', sans-serif; }
:lang(he) { font-family: 'Noto Sans Hebrew', sans-serif; }
:lang(th) { font-family: 'Noto Sans Thai', sans-serif; }
:lang(am), :lang(ti) { font-family: 'Noto Sans Ethiopic', sans-serif; }

/* Emoji support */
.emoji {
  font-family: 'Noto Color Emoji', sans-serif;
}
</style>"""))
def display_feat(feature_i):
    covered_already = set()
    simpler_words = []
    for k in range(50):
        storage_index = feature_to_storage_index[feature_i]
        activations = top_k_data[storage_index, k]
        tokens = token_data[storage_index, k]
        token_pos = torch.argmax(activations).item()
        toks = model.to_str_tokens(tokens)
        relevant_str = "".join(toks[:token_pos+1])
        if relevant_str in covered_already:
            continue
        covered_already.add(relevant_str)
        out_toks = []
        colors = [''] + ['red', 'orange', 'yellow', 'green']*256
        for j,tok in enumerate(toks):
            if len(tok.strip()) == 0:
                tok = repr(tok)
            tok = tok.replace("\n", "\\n")
            colored = f"<span id='ayy'><font color='white'>{tok}</font></span>"
            if j < 1: continue
            if j == token_pos:
                colored = f"<span id='ayy'><font color='red'>{tok}</font></span>"
            elif activations[j].item() > 0.01:
                colored = f"<span id='ayy'><font color='pink'>{tok}</font></span>"
            if activations[j].item() > 0.01:
                out_toks.append(f"{colored}{activations[j].item():.3f}")
            else:
                out_toks.append(colored)
        simpler = model.tokenizer.decode(tokens[1:token_pos+1])
        simpler_words.append(simpler.strip())

        display(HTML(toks[token_pos] + "\t||\t" + "<span id='ayy'>" + simpler + "</span>\t||\t" + "".join(out_toks)))
        if len(simpler_words) == 20:
            out_s = "<span id='ayy'>What do "
            for s in simpler_words:
                out_s += f'"{s.strip()}", '
            out_s += "have in common? Take a deep breath and think step by step."
            display(HTML(out_s + "</span>"))
    '''
    feat_vecs = [get_name_vector(feat, 'mean') for feat in feats]
    avg_vec = torch.stack(feat_vecs).mean(dim=0)
    min_vec = torch.stack([get_name_vector(feat, 'min') for feat in feats]).min(dim=0).values
    max_vec = torch.stack([get_name_vector(feat, 'max') for feat in feats]).max(dim=0).values
    sorted_names = torch.argsort(-avg_vec)
    #print(avg_vec, min_vec, max_vec, sorted_names)
    for name_i in sorted_names[:100]:
        #print(name_i)
        print(f" name {names[name_i]} with avg {avg_vec[name_i]} min {min_vec[name_i]} max {max_vec[name_i]}")
    '''
    '''
    for feat in feats:
        if position_map[feat.pos][0] == 'n':
            print(position_map[feat.pos], detect_single_letter(feat))
            pretty_print_list_first_letter_info(list_first_letter_info(feat))
            print(improved_first_letter(feat))
    
    diffs = torch.zeros(len(feats), len(feats))
    for i,featv1 in enumerate(feat_vecs):
        for j,featv2 in enumerate(feat_vecs):
            diffs[i,j] = torch.mean(torch.abs(featv1-featv2))
    labels = [position_map[feat.pos] for feat in feats]
    imshow(diffs, x=labels, y=labels, font_size=9)
    '''
def save_labels():
    with open("layer_15_features_take_two.pkl", "wb") as f:
        global feature_labelsb
        pickle.dump(feature_labelsb, f)
        print(f"done saving {len(feature_labelsb)}")
import traceback
out = widgets.Output()
global cur_feature_ind
global feature_labels
def submitted(change):
    global cur_feature_ind
    global feature_labels
    if len(text_item.value.strip()) > 0 and (text_item.value != feature_labels[cur_feature_ind]):
        with out:
            try:
                times = 0
                clear_output()
                feature_labelsb[cur_feature_ind] = text_item.value
                save_labels()
                display_unlabeled_feature()
            except:
                print(traceback.format_exc())
                



text_item = widgets.Text(
    value='ffff',
    placeholder='Type something',
    description='String:',
    disabled=False,
    continuous_update=False,
)

display(text_item)
display(out)
text_item.observe(submitted, names='value')

with out:
    display_unlabeled_feature()
    






Text(value='ffff', continuous_update=False, description='String:', placeholder='Type something')

Output()

In [None]:
import torch
with open("textlayer15.pkl", "rb") as f:
    dat = torch.load(f)

In [None]:
import torch
with open("textlayer15.pkl", "wb") as f:
    torch.save(features, f)

In [None]:
global feature_labels
import pickle
with open("layer_15_features.pkl", "rb") as f:
    feature_labels = pickle.load(f)



In [None]:
global feature_labels
import pickle
with open("layer_15_features_take_two.pkl", "rb") as f:
    feature_labelsb = pickle.load(f)


In [None]:
print(feature_labelsb)

In [None]:

from IPython.display import display, clear_output
from ipywidgets import widgets
from IPython.display import display, HTML
global cur_feature_ind
cur_feature_ind = None
def display_unlabeled_feature():
    global cur_feature_ind
    global feature_labels
    available_features = features_sorted_by_feat_i.keys() - feature_labelsb.keys()
    for f in available_features:
        if not '?' in feature_labels[f]:
            continue
        print(f"feature {f}")
        cur_feature_ind = f
        feats = features_sorted_by_feat_i[f]
        text_item.value = feature_labels[f]
        #if any([position_map[feat.pos][0] != 'n' for feat in feats]):
            #print(f"warning, feature {feat_i} has non name poses, all pos are {[position_map[f.pos] for f in feats]})")
        #    continue
        display_feats(feats)
        return
    num_left = 0
    for f in feature_labels.keys():
        if "?" in feature_labels[f]:
            num_left += 1
    print(f"num left {num_left}")
    for f in feature_labels.keys():
        if "?" in feature_labels[f]:
            cur_feature_ind = f
            text_item.value = feature_labels[f]
            display_feats(features_sorted_by_feat_i[f])
            return

print(data.size())
display(HTML("""<link href='https://fonts.googleapis.com/css?family=Noto Sans' rel='stylesheet'>
<style>
/* Base Noto Sans and Serif for Latin, Greek, and Cyrillic */
@import url('https://fonts.googleapis.com/css2?family=Noto+Sans:wght@400;700&family=Noto+Serif:wght@400;700&display=swap');

/* East Asian scripts */
@import url('https://fonts.googleapis.com/css2?family=Noto+Sans+JP:wght@400;700&family=Noto+Sans+KR:wght@400;700&family=Noto+Sans+SC:wght@400;700&family=Noto+Sans+TC:wght@400;700&display=swap');

/* South Asian scripts */
@import url('https://fonts.googleapis.com/css2?family=Noto+Sans+Devanagari:wght@400;700&family=Noto+Sans+Bengali:wght@400;700&family=Noto+Sans+Tamil:wght@400;700&display=swap');

/* Middle Eastern scripts */
@import url('https://fonts.googleapis.com/css2?family=Noto+Sans+Arabic:wght@400;700&family=Noto+Sans+Hebrew:wght@400;700&display=swap');

/* Other scripts */
@import url('https://fonts.googleapis.com/css2?family=Noto+Sans+Thai:wght@400;700&family=Noto+Sans+Ethiopic:wght@400;700&display=swap');

/* Specialty fonts */
@import url('https://fonts.googleapis.com/css2?family=Noto+Sans+Mono:wght@400;700&family=Noto+Color+Emoji&display=swap');

#ayy {
  font-family: 'Noto Sans', 'Noto Sans JP', 'Noto Sans KR', 'Noto Sans SC', 'Noto Sans TC', 
               'Noto Sans Devanagari', 'Noto Sans Bengali', 'Noto Sans Tamil', 
               'Noto Sans Arabic', 'Noto Sans Hebrew', 'Noto Sans Thai', 'Noto Sans Ethiopic',
               sans-serif;
}

/* Language-specific rules */
:lang(ja) { font-family: 'Noto Sans JP', sans-serif; }
:lang(ko) { font-family: 'Noto Sans KR', sans-serif; }
:lang(zh-CN) { font-family: 'Noto Sans SC', sans-serif; }
:lang(zh-TW) { font-family: 'Noto Sans TC', sans-serif; }
:lang(hi) { font-family: 'Noto Sans Devanagari', sans-serif; }
:lang(bn) { font-family: 'Noto Sans Bengali', sans-serif; }
:lang(ta) { font-family: 'Noto Sans Tamil', sans-serif; }
:lang(ar) { font-family: 'Noto Sans Arabic', sans-serif; }
:lang(he) { font-family: 'Noto Sans Hebrew', sans-serif; }
:lang(th) { font-family: 'Noto Sans Thai', sans-serif; }
:lang(am), :lang(ti) { font-family: 'Noto Sans Ethiopic', sans-serif; }

/* Emoji support */
.emoji {
  font-family: 'Noto Color Emoji', sans-serif;
}
</style>"""))
def display_feats(feats):
    all_records = []
    feat_len = len(feats[0].records)
    for feat in feats:
        print(feat.pos)
        all_records += feat.records
    records = torch.tensor(all_records)
    dats = [torch.tensor([1]).repeat(feat_len),torch.tensor([2]).repeat(feat_len),torch.tensor([3]).repeat(feat_len),torch.tensor([4]).repeat(feat_len)]
    which_pos = torch.cat(dats)
    print(records.size())
    top_act_inds = torch.argsort(-records)
    covered_already = set()
    simpler_words = []
    for i in range(500):
        ind = top_act_inds[i]
        token_pos = which_pos[ind]
        data_pos = ind % feat_len
        act = records[ind]
        acts = [0]
        for j in range(1,5):
            acts.append([feat.records[data_pos] for feat in feats if feat.pos == j][0])
        toks = model.to_str_tokens(data[data_pos])
        relevant_str = "".join(toks[:token_pos+1])
        if relevant_str in covered_already:
            continue
        covered_already.add(relevant_str)
        out_toks = []
        colors = ['', 'red', 'orange', 'yellow', 'green']
        for j,tok in enumerate(toks):
            if len(tok.strip()) == 0:
                tok = repr(tok)
            tok = tok.replace("\n", "\\n")
            colored = f"<span id='ayy'><font color='{colors[j]}'>{tok}</font></span>"
            if j < 1: continue
            if j == token_pos:
                colored = f"<span id='ayy'><font color='pink'>{tok}</font</span>"
            if acts[j] > 0.01:
                out_toks.append(f"{colored}{acts[j]:.3f}")
            else:
                out_toks.append(colored)
        simpler = model.tokenizer.decode(data[data_pos][1:token_pos+1])
        simpler_words.append(simpler.strip())

        display(HTML(toks[token_pos] + "\t||\t" + "<span id='ayy'>" + simpler + "</span>\t||\t" + "".join(out_toks)))
        if len(simpler_words) == 20:
            out_s = "<span id='ayy'>What do "
            for s in simpler_words:
                out_s += f'"{s.strip()}", '
            out_s += "have in common? Take a deep breath and think step by step."
            display(HTML(out_s + "</span>"))
    '''
    feat_vecs = [get_name_vector(feat, 'mean') for feat in feats]
    avg_vec = torch.stack(feat_vecs).mean(dim=0)
    min_vec = torch.stack([get_name_vector(feat, 'min') for feat in feats]).min(dim=0).values
    max_vec = torch.stack([get_name_vector(feat, 'max') for feat in feats]).max(dim=0).values
    sorted_names = torch.argsort(-avg_vec)
    #print(avg_vec, min_vec, max_vec, sorted_names)
    for name_i in sorted_names[:100]:
        #print(name_i)
        print(f" name {names[name_i]} with avg {avg_vec[name_i]} min {min_vec[name_i]} max {max_vec[name_i]}")
    '''
    '''
    for feat in feats:
        if position_map[feat.pos][0] == 'n':
            print(position_map[feat.pos], detect_single_letter(feat))
            pretty_print_list_first_letter_info(list_first_letter_info(feat))
            print(improved_first_letter(feat))
    
    diffs = torch.zeros(len(feats), len(feats))
    for i,featv1 in enumerate(feat_vecs):
        for j,featv2 in enumerate(feat_vecs):
            diffs[i,j] = torch.mean(torch.abs(featv1-featv2))
    labels = [position_map[feat.pos] for feat in feats]
    imshow(diffs, x=labels, y=labels, font_size=9)
    '''
def save_labels():
    with open("layer_15_features_take_two.pkl", "wb") as f:
        global feature_labelsb
        pickle.dump(feature_labelsb, f)
        print(f"done saving {len(feature_labelsb)}")
import traceback
out = widgets.Output()
global cur_feature_ind
global feature_labels
def submitted(change):
    global cur_feature_ind
    global feature_labels
    if len(text_item.value.strip()) > 0 and (text_item.value != feature_labels[cur_feature_ind]):
        with out:
            try:
                times = 0
                clear_output()
                feature_labelsb[cur_feature_ind] = text_item.value
                save_labels()
                display_unlabeled_feature()
            except:
                print(traceback.format_exc())
                



text_item = widgets.Text(
    value='ffff',
    placeholder='Type something',
    description='String:',
    disabled=False,
    continuous_update=False,
)

display(text_item)
display(out)
text_item.observe(submitted, names='value')

with out:
    display_unlabeled_feature()
    

In [None]:
from IPython.display import display, FileLink
import os
display(FileLink("layer_15_features.pkl"))

In [119]:






























combined = {}
for f in feature_labels.keys():
    if f in feature_labelsb:
        combined[f] = feature_labels[f] + " MOREDATA: " + feature_labelsb[f].replace("?", " ").strip()
    else:
        combined[f] = feature_labels[f]
    
    if f in feature_labelsb and feature_labelsb[f].strip()  == feature_labels[f].strip():
        combined[f] = feature_labels[f]
    
    if feature_labels[f].replace("?", "").strip() == "" and f in feature_labelsb:
        combined[f] = feature_labelsb[f]
    combined[f] = combined[f].replace("?", " ").replace("(", " ").replace(")", " ").strip()
    if combined[f].strip() == "" or combined[f].strip() == '||':
        del combined[f]
labs = sorted(list(combined.items()), key=lambda x: x[1].lower().strip())
for f,l in labs:
    #if not '?' in l:
    if l.strip() != "":
        print(f, l)
        print(f"  attr {features_sorted_by_feat_i[f][0].attr}")
for f in feature_labels.keys():
    if len([(f2,l) for (f2,l) in labs if f2 == f]) == 0:
        print(f"unknown {f} with attr {features_sorted_by_feat_i[f][0].attr}")
        combined[f] = feature_labels[f]
labs = sorted(list(combined.items()), key=lambda x: features_sorted_by_feat_i[x[0]][0].attr)
for f,l in labs:
    #if not '?' in l:
    if l.strip() != "":
        print(f, l)
        print(f"  attr {features_sorted_by_feat_i[f][0].attr}")



print(f"total num features identified {len(labs)}")


# 27638 fires slightly only the second time the name is mentioned for a 1/8 of the names

24649 1800's and 1900's MOREDATA: used to be 1800's and 1900's
  attr -0.013689302893567401
15582 1800-1999
  attr -0.009123188330249832
30568 1980's-1990s pop culture
  attr -0.012585221676090441
6227 3 digit numbers, renal excretions, and escaped quotes MOREDATA: Russian Names, 3 digit numbers, renal excretions, and escaped quotes
  attr -0.008696446947169534
8103 5/May/Five/Fifth
  attr -0.01274918073158915
32395 [ with some other symbols  [", **[, etc.
  attr -0.006224751563763675
31021 A
  attr -0.09784678883806919
29892 A
  attr -0.0719999106204341
14819 about names
  attr -0.009356041136925342
1312 account/login terminology and oxidation/breathing
  attr -0.017313731845206348
12348 actor portrayal, embodied, playing MOREDATA: Famous Actors, actor portrayal, embodied, playing
  attr -0.010314146090422582
32334 advertising/spam text, usually abuilt fine china  on the ina  or dating  on the meet , also "get your feet tapping"
  attr -0.013396876420983972
19724 aggressive/vigorously

In [None]:
feature_labelsb

In [None]:
" Purushottam".strip()



from torch.nn.functional import relu

# Do Features Use Cross Talk?


In [138]:


def get_batched_index_into(indices):
    '''
    given data that is [B,N,V] and indicies that are [B,N,K] with each index being an index into the V space
    this gives you indexes you can use to access your values
    '''
    first_axis = []
    second_axis = []
    third_axis = []
    B, _, _ = indices.size()
    for b in range(B):
        second, third = get_index_into(indices[b])
        first_axis.append(torch.full(second.size(), fill_value=b, device=model.cfg.device))
        second_axis.append(second)
        third_axis.append(third)

    return torch.cat(first_axis), torch.cat(second_axis), torch.cat(third_axis)

def get_index_into(indices):
    '''
    given data that is [N,V] and indicies that are [N,K] with each index being an index into the V space
    this gives you indexes you can use to access your values
    '''
    num_data, num_per_data = indices.size()
    # we want
    # [0,0,0,...,] num per data of these
    # [1,1,1,...,] num per data of these
    # ...
    # [num_data-1, num_data-1, ...]
    first_axis_index = torch.arange(num_data, dtype=torch.long).view(num_data, 1)*torch.ones([num_data, num_per_data], dtype=torch.long)
    # now we flatten it so it has an index for each term aligned with our indices
    first_axis_index = first_axis_index.flatten()
    second_axis_index = indices.flatten()
    return first_axis_index, second_axis_index
global buffer
buffer = None
def sae_hook(
    x,
    hook,
    layer,
):
    # s is [B,L,E]
    K = saes[layer].cfg.k
    sae = saes[layer]
    B,L,D = x.size()
    uncorrupted_features = sae.encode(x)
    top_acts, top_indices = uncorrupted_features.topk(K, sorted=False)
    buffer = torch.zeros(uncorrupted_features.size(), device=model.cfg.device)
    global features_by_layer
    # zero everything except the top k
    buffer[get_batched_index_into(top_indices)] = top_acts.flatten()
    for feature in features_by_layer[layer]:
        if feature.pos < L: # sometimes prompt is too small to consider this feature
            feature.records += buffer[:,feature.pos,feature.feature_i].tolist()
            feature.full_records += uncorrupted_features[:,feature.pos,feature.feature_i].tolist()
    # kernel can't handle doing all token positions at same time by default
    # but if we make it think B*L is a single batch index it works fine
    top_acts_flattened = top_acts.flatten(start_dim=0, end_dim=1)
    top_indices_flattened = top_indices.flatten(start_dim=0, end_dim=1)
    sae_out = sae.decode(top_acts_flattened, top_indices_flattened)
    sae_out = sae_out.unflatten(dim=0, sizes=(B,L))
    return sae_out


prompt_dla = 'Recently, John, James, and Sam went to the market. James'
prompt_tokens_dla = model.to_tokens(prompt_dla)
print(model.to_str_tokens(prompt_tokens_dla))
logits_dla, activations_dla = model.run_with_cache(prompt_tokens_dla)

global features_by_layer
features_by_layer = defaultdict(lambda: [])
layers_to_apply_sae = [15]
feats = []
L = prompt_tokens_dla.size()[1]
for layer in layers_to_apply_sae:
    for f, feats in features_sorted_by_feat_i.items():
        if len(feats) == 0:
            print("huh zero", f, feats)
            continue
        feat = feats[0]
        if feat.layer == layer:
            for i in range(1, L):
                features_by_layer[layer].append(SAEFeature(layer=feat.layer, pos=i, feature_i=feat.feature_i, attr=feat.attr))
                features_by_layer[layer][-1].records = []
                features_by_layer[layer][-1].full_records = []
hooks = [(f'blocks.{layer}.hook_out_proj', partial(sae_hook, layer=layer)) for layer in layers_to_apply_sae]
_ = model.run_with_hooks(input=prompt_tokens_dla, fwd_hooks=hooks, fast_ssm=True, fast_conv=True)




for layer in layers_to_apply_sae:
    # [B,L,E]
    x = activations_dla[f'blocks.{layer}.hook_ssm_input']
    # [B,L,E]
    y = activations_dla[f'blocks.{layer}.hook_y']
    # [B,L,E,N]
    A_bar = activations_dla[f'blocks.{layer}.hook_A_bar']
    # [B,L,E,N]
    B_bar = activations_dla[f'blocks.{layer}.hook_B_bar']
    # [B,L,E]
    skip = F.silu(activations_dla[f'blocks.{layer}.hook_skip'])
    # [B,L,N]
    C = activations_dla[f'blocks.{layer}.hook_C']
    # [D,E]
    W_out = model.blocks[layer].out_proj
    # [E]
    W_D = model.blocks[layer].W_D
    
    
    # [B,L,E,N]
    #h = torch.zeros([B,L,E,N], device=model.cfg.device)
    ys = []
    h = torch.zeros([B,E,N], device=model.cfg.device)
    for l in range(L):
        if l < 1:
            # [B,E,N]   [B,E,N]     [B,E,N]          [B,E,N]          [B,E]
            h        =    h    *  A_bar[:,l,:,:]  + B_bar[:,l,:,:] * x[:,l].view(B, E, 1) 
            h_0 = h
        else:       
            # [B,E,N]  [B,E,N]      [B,E,N] 
            h_0       =  h_0   *  A_bar[:,l,:,:] # do the A_bar multiply but ignore other x's
            # [B,E,N]   [B,E,N]     [B,E,N]          [B,E,N]          [B,E]
            h        =   h_0  + B_bar[:,l,:,:] * x[:,l].view(B, E, 1)
        # [B,E]    [B,E,N]       [B,N,1]   # this is like [E,N] x [N,1] for each batch

        y_l       =   h     @   C[:,l,:].view(B,N,1)
        # [B,E]              [B,E,1]
        y_l      =    y_l.view(B,E)
        ys.append(y_l)
    y = torch.stack(ys, dim=1)
    '''
    for l in range(L):
        # [B,E,N]
        h_l = activations_dla[f'blocks.{layer}.hook_h.{l}']
        # [B,L,E,N][:,l,:,:]   [B,E,N]
        h[:,l,:,:]             = h_l
    #                   (       [B,L,E,1]        )
    # [B,L,E]           [B,L,E,N] x  [B,L,N,1]                     [B,L,E]  [E]
    y_out         =     (   h     @ C.view(B,L,N,1)).view(B,L,E) +   x    *  W_D
    '''
    y_out         =     y +   x    *  W_D
    # [B,L,E]       [B,L,E]   [B,L,E]
    y_out         = y_out   *   skip
    # [B,L,D]       [B,L,E] x    [E,D]
    y_out         = y_out   @ W_out.weight.T

    print("wwbwf")
    sae = saes[layer]
    # F is feature size
    # [B,L,F]         [B,L,D]       [D]                [D,F]
    K = sae.cfg.k
    sae_vals = sae.encode(y_out)
    top_acts, top_indices = sae_vals.topk(K, sorted=False)
    buffer = torch.zeros(sae_vals.size(), device=model.cfg.device)
    # zero everything except the top k
    buffer[get_batched_index_into(top_indices)] = top_acts.flatten()
    #sae_vals      = relu((y_out - sae.b_dec) @ sae.encoder.weight.T)
    tests = 200
    t = 0
    for feat in sorted(list(features_by_layer[layer]), key=lambda f:  -abs(f.records[0] - buffer[0,f.pos,f.feature_i])):
        diff = buffer[0,feat.pos,feat.feature_i] - feat.records[0]
        if feat.pos == L-1:
            if abs(diff) > 0.001:
                print(diff.item(), feat.records[0], '->', buffer[0,l,feat.feature_i].item())
                print(feat.feature_i, feat.attr, dict(labs)[feat.feature_i])
                t += 1
        if t >= 200: break
    if t >= 200: break
    

['<|endoftext|>', 'Recently', ',', ' John', ',', ' James', ',', ' and', ' Sam', ' went', ' to', ' the', ' market', '.', ' James']
huh zero <_io.TextIOWrapper name='/home/dev/mamba_interp/MoreNames.txt' mode='r' encoding='UTF-8'> []
wwbwf
0.21070288121700287 0.0 -> 0.21070288121700287
22605 -0.016981449274680926 Repeated Token
-0.10474039614200592 0.1802479773759842 -> 0.07550758123397827
27638 -0.019120284184737102 ?
-0.08901895582675934 0.08901895582675934 -> 0.0
5441 -0.02206722709161113 three digit numbers and determined/resolved/finishes MOREDATA: code notation  previously three digit numbers and determined/resolved/finishes
-0.07798303663730621 0.07798303663730621 -> 0.0
15112 -0.006430887430042276 maybe double letters
0.0596872977912426 0.0 -> 0.0596872977912426
11188 -0.013412890531981247 Fighter/Bomber Wings
-0.055812593549489975 0.055812593549489975 -> 0.0
8113 -0.07060992660626653 N
0.052919819951057434 0.0 -> 0.052919819951057434
23490 -0.012551901419328715 lipids/cholestero

# Feature Decomposition


In [64]:
# construct a linear program to find the minimal change in state that modifies a target feature

# [B,E,N]

import torch
import torch.nn.functional as F
import pandas as pd
from einops import rearrange
import ipywidgets
from IPython.display import display, clear_output

prompt_dla = 'Lately, John, James, and John'
prompt_tokens_dla = model.to_tokens(prompt_dla)
print(model.to_str_tokens(prompt_tokens_dla))
logits_dla, activations_dla = model.run_with_cache(prompt_tokens_dla)
token_labels_dla = [f"{token}_{index}" for index, token in enumerate(model.to_str_tokens(prompt_tokens_dla[0]))]



B,L = prompt_tokens_dla.size()

D, V, E, N = model.cfg.D, model.cfg.V, model.cfg.E, model.cfg.N


# last layer after adding y_out
# [B,L,D]
# resid_passed_into_norm = activations_dla[f'blocks.{model.cfg.n_layers-1}.hook_resid_post']

# norms along the d dimension
# [B,L,1]
# divide_magnitudes = torch.rsqrt(resid_passed_into_norm.pow(2).mean(-1, keepdim=True)+1e-5)

# [1,1,D]
# W_N = model.norm.weight.view(1,1,D)

# W_unembed = model.lm_head

# [B,L,V]     [D->V]               [B,L,D]                  [B,L,1]        [1,1,D]
# embed_logits = W_unembed(activations_dla[f'hook_embed'] * divide_magnitudes  *  W_N)

# [B,L,V]
# res_logits_simple = embed_logits.clone()
# res_logits_by_e = embed_logits.clone()

# [B,L]
# res_logits_by_e_n = embed_logits.clone()[:,:,target_token_dla]

contributions_h_e = torch.zeros([model.cfg.n_layers, model.cfg.N], device=model.cfg.device)

contributions_y = torch.zeros([model.cfg.n_layers], device=model.cfg.device)

contributions_xd = torch.zeros([model.cfg.n_layers], device=model.cfg.device)

contributions_h_e_n = torch.zeros([model.cfg.n_layers, model.cfg.E, model.cfg.N], device=model.cfg.device)


LAYERS_STUDYING = [15]

for layer in range(LAYERS_STUDYING):
    # [B,L,E]
    x = activations_dla[f'blocks.{layer}.hook_ssm_input']
    # [B,L,E]
    y = activations_dla[f'blocks.{layer}.hook_y']
    # [B,L,E]
    skip = F.silu(activations_dla[f'blocks.{layer}.hook_skip'])
    # [B,L,N]
    C = activations_dla[f'blocks.{layer}.hook_C']
    # [D,E]
    W_out = model.blocks[layer].out_proj

    
    
    # [B,L,E,N]
    h = torch.zeros([B,L,E,N], device=model.cfg.device)
    for l in range(L):
        # [B,E,N]
        h_l = activations_dla[f'blocks.{layer}.hook_h.{l}']
        # [B,L,E,N][:,l,:,:]   [B,E,N]
        h[:,l,:,:]             = h_l

    
    # we want to make something that does E -> V
    # this requires W_out, the norm stuff, and W_u

    # fold the norm into W_unembed
    # this doesn't include the divide_magnitudes, but those are [B,L,1] so they commute
    # its fine to just use .weight because neither of these have bias
    # [D,V]                     [D,V]          [D,1]
    # W_unembed_and_norm = W_unembed.weight.T * W_N.view(D,1)
    # [E,V]                [E,D]            [D,V]
    # W_unembed_from_e = W_out.weight.T @ W_unembed_and_norm

    # we can use this to do DLA of individual streams (there are ExN of them)
    # [B,L,E,N]       [B,L,E,N]         [B,L,1,N]         [B,L,E,1]                    [B,L,1,1]
    #y_out           =     h       * C.view(B,L,1,N)  * skip.view(B,L,E,1) # * divide_magnitudes.view(B,L,1,1)

    # we have E streams of size N
    # if we just matmul W_unembed_from_e to each E sized stream, we get V streams of size N
    # that matmul is just dot product, which is just element-wise product and then sum
    # we want the element-wise product before the sum
    # unfortunately this would be [B,L,E,N,V] which is too big for memory
    # so we will just compute the logit for our target v
    # [E]                     [E,V][:,target_token_dla]
    #logit_proj_vec = W_unembed_from_e[:,target_token_dla]
    #sae = saes[layer]

    # move E to last axis so easier to do matmuls
    # [B,L,N,D]              [B,L,N,E]                     x     [E,D]  
    #y_out    =    rearrange(y_out, 'B L E N -> B L N E')    @  W_out.weight.T
    
    # F = num features
    # [B,L,D,F]       [B,L,N,D]       [D]               [D,F]
    #sae_features =   (y_out   - sae.b_dec.weight) @ sae.encoder.weight.T

    
    # [E,F]                      [E,D]             [D,]
    #proj_into_sae_feature =  W_out.weight.T @ 
   
    #                   (       [B,L,E,1]        )
    # [B,L,E]           [B,L,E,N] x  [B,L,N,1]
    y_out         =     (   h     @ C.view(B,L,N,1)).view(B,L,E)
    # [B,L,E]       [B,L,E]   [B,L,E]
    y_out         = y_out   *   skip
    # [B,L,D]       [B,L,E] x    [E,D]
    y_out         = y_out   @ W_out.weight.T
    
    sae = saes[layer]
    # F is feature size
    #               [B,L,D]       [D]                [D,F]
    sae_vals      = (y_out - sae.b_dec.weight) @ sae.encoder.weight.T





    # [B,L,E,N]    [B,L,E,N]        [1,1,E,1]
    logit_contrib = y_out   *   logit_proj_vec.view(1,1,E,1)

    # [E,N]                         [B,L,E,N][0,-1]
    contributions_h_e_n[layer] = logit_contrib[0,-1]

    # [B,L]                 [B,L,E,N].sum(along E and N axis)
    stream_logit_total = logit_contrib.sum(dim=-1).sum(dim=-1)
    
    # [B,L]                   [B,L]
    res_logits_by_e_n += stream_logit_total

    # [E]
    W_D = model.blocks[layer].W_D

    y_from_c = torch.zeros([B,L,model.cfg.E], device=model.cfg.device)
    for n in range(N):
        
        # [B,L,E]       [B,L,1]                [B,L,E,N][:,:,:,n]
        y_from_c     += C[:,:,n].view(B,L,1)   *    h[:,:,:,n]

        # [B,L,V]                 [D->V]    [E->D]         [B,L,1]              [B,L,E,N][:,:,:,n]    [B,L,E]            [B,L,1]          [1,1,D]
        logit_contribution_h_e = W_unembed( W_out(    C[:,:,n].view(B,L,1)   *    h[:,:,:,n]  *         skip )       * divide_magnitudes * W_N )
        res_logits_by_e += logit_contribution_h_e
        contributions_h_e[layer,n] = logit_contribution_h_e[0,-1,target_token_dla]

    # make sure we are computing it right
    assert(torch.allclose(y, y_from_c, atol=0.01))
    

    # [B,L,V]                 [D->V]  [E->D] [B,L,E]         [B,L,E]     [B,L,1]       [1,1,D]
    logit_contribution_y  = W_unembed(W_out(    y    *        skip  )*divide_magnitudes* W_N )

    # [B,L,V]                 [D->V]  [E->D] [B,L,E]       [1,1,E]                  [B,L,E]     [B,L,1]        [1,1,D]
    logit_contribution_xd = W_unembed(W_out(    x    *  W_D.view(1,1,E)      *       skip )*divide_magnitudes * W_N)
    
    #before_add_to_resid = activations_dla[f'blocks.{layer}.hook_out_proj']
    #simple_test = W_unembed(before_add_to_resid*divide_magnitudes*W_N)

    #after_skip = activations_dla[f'blocks.{layer}.hook_after_skip']
    #simple_test = W_unembed(W_out(after_skip)*divide_magnitudes*W_N)
    contributions_y[layer] = logit_contribution_y[0,-1,target_token_dla]
    contributions_xd[layer] = logit_contribution_xd[0,-1,target_token_dla]
    
    res_logits_simple += logit_contribution_y + logit_contribution_xd
    res_logits_by_e += logit_contribution_xd # we already did the y in the loop
    # [B,L]                    [B,L,V][:,:,target_token_dla]
    res_logits_by_e_n += logit_contribution_xd[:,:,target_token_dla] # we already did the y above
    
    #res_logits += simple_test
    
# simplest W_unembed(resid_passed_into_norm*divide_magnitudes*W_N)

# make sure we did it right
assert(torch.allclose(res_logits_simple, logits_dla, atol=0.01))
assert(torch.allclose(res_logits_by_e, logits_dla, atol=0.01))
#                       [B,L]             [B,L,V][:,:,target_token_dla]
assert(torch.allclose(res_logits_by_e_n, logits_dla[:,:,target_token_dla], atol=0.01))

layer_names = ['layer ' + str(i) for i in range(model.cfg.n_layers)]

annotation = f"Prompt: '{prompt_dla}' -> '{prompt_answer_dla}'" 

ANNOTATION_FONT_SIZE = 10
def add_annotation(fig, annotation):
    fig.add_annotation(dict(font=dict(color='black',size=ANNOTATION_FONT_SIZE),
                                        x=0,
                                        y=1,
                                        showarrow=False,
                                        text=annotation,
                                        textangle=0,
                                        xanchor='left',
                                        yanchor='bottom',
                                        xref="paper",
                                        yref="paper"))

fig = px.imshow(utils.to_numpy(contributions_h_e.T), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":"layer", "y":'N'}, x=layer_names, title=f"Logit contributions of E-sized h's")
add_annotation(fig, annotation=annotation)
fig.show()

fig = px.imshow(utils.to_numpy(contributions_h_e.T[:,:-1]), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":"layer", "y":'N'}, x=layer_names[:-1], title=f"Logit contributions of E-sized h's (no last layer)")
add_annotation(fig, annotation=annotation)
fig.show()

def bar_chart(data, x_labels, y_label, title, annotation, font_size=None):
    # it requires a pandas dict with the columns and rows named, annoying
    # by default rows and columns are named with ints so we relabel them accordingly
    renames = dict([(i, x_labels[i]) for i in range(len(x_labels))])
    ps = pd.DataFrame(data.cpu().numpy()).rename(renames, axis='rows').rename({0: y_label}, axis='columns')
    fig = px.bar(ps, y=y_label, x=x_labels, title=title)
    add_annotation(fig, annotation=annotation)
    if not font_size is None:
        fig.update_layout(
          xaxis = dict(
            tickmode='array',
            tickvals = x_labels,
            ticktext = x_labels, 
            ),
           font=dict(size=font_size, color="black"))
        #fig.update_xaxes(title_font=dict(size=font_size))
    
    fig.show()




bar_chart(data=contributions_y, x_labels=layer_names, y_label='logit contribution', title='Direct Logit Attribution y', annotation=annotation)
bar_chart(data=contributions_y[:-1], x_labels=layer_names[:-1], y_label='logit contribution', title='Direct Logit Attribution y (no last layer)', annotation=annotation)
bar_chart(data=contributions_xd, x_labels=layer_names, y_label='logit contribution', title='Direct Logit Attribution x*W_D', annotation=annotation)
bar_chart(data=contributions_xd[:-1], x_labels=layer_names[:-1], y_label='logit contribution', title='Direct Logit Attribution x*W_D (no last layer)', annotation=annotation)




# also display embed contribution here
contributions_ssm = torch.zeros([model.cfg.n_layers+1])
contributions_ssm[1:] = contributions_y + contributions_xd
contributions_ssm[0] = embed_logits[0,-1,target_token_dla]
contributions_ssm_labels = ['embed'] + layer_names
bar_chart(data=contributions_ssm, x_labels=contributions_ssm_labels, y_label='logit contribution', title='Direct Logit Attribution ssm=y+x*W_D', annotation=annotation)
bar_chart(data=contributions_ssm[:-1], x_labels=contributions_ssm_labels[:-1], y_label='logit contribution', title='Direct Logit Attribution ssm=y+x*W_D (no last layer)', annotation=annotation)



top_n = 100

def get_top_n(data, label_prefix):
    inds = torch.argsort(-torch.abs(data.flatten()))
    e_ind = inds // N
    n_ind = inds % N
    top_n_data = []
    top_n_labels = []
    only_neg = data.clone()
    only_neg[data>0] = 0
    tail_mag_neg = torch.sum(only_neg.flatten()[inds[top_n:]]).item()
    only_pos = data.clone()
    only_pos[data<0] = 0
    tail_mag_pos = torch.sum(only_pos.flatten()[inds[top_n:]]).item()
    tail_mag_total = torch.sum(torch.abs(data).flatten()[inds[top_n:]]).item()
    for i, (e, n) in enumerate(zip(e_ind[:top_n], n_ind[:top_n])):
        top_n_data.append(data[e,n].item())
        top_n_labels.append(f'{label_prefix} e={e} n={n}')
    return top_n_data, top_n_labels, (tail_mag_pos, tail_mag_neg, tail_mag_total)

def get_top_n_from_layers(data, layers):
    all_data, all_labels = [], []
    tail_mags = torch.zeros([3])
    for layer in layers:
        # this could be done simpler by flattening the whole thing but i can't be bothered to figure out how to get 3D index from that
        top_n_data, top_n_labels, mags = get_top_n(data=data[layer], label_prefix=f'layer={layer}')
        all_data += top_n_data
        all_labels += top_n_labels
        tail_mags += torch.tensor(list(mags))
        
    all_data = torch.tensor(all_data)
    tops = torch.argsort(-torch.abs(all_data))[:top_n]
    all_data = all_data[tops]
    all_labels = [all_labels[i.item()] for i in tops]
    return all_data, all_labels, tail_mags


FONT_SIZE = 7

def get_tail_mag_str(tail_mags):
    return f'  sum of streams not present +:{tail_mags[0].item()} -:{tail_mags[1].item()} all:{tail_mags[2].item()}'

all_layers_data, all_layers_labels, tail_mags = get_top_n_from_layers(data=contributions_h_e_n, layers=range(model.cfg.n_layers))
tail_mag_str = get_tail_mag_str(tail_mags)
bar_chart(data=all_layers_data, x_labels=all_layers_labels, y_label='logit contribution', title=f'Direct Logit Attribution of top {top_n} streams', annotation=annotation + tail_mag_str, font_size=FONT_SIZE)

all_layers_data, all_layers_labels, tail_mags = get_top_n_from_layers(data=contributions_h_e_n, layers=range(model.cfg.n_layers-2))
tail_mag_str = get_tail_mag_str(tail_mags)
bar_chart(data=all_layers_data, x_labels=all_layers_labels, y_label='logit contribution', title=f'Direct Logit Attribution of top {top_n} streams (excluding last two layers)', annotation=annotation + tail_mag_str, font_size=FONT_SIZE)


def display_dla(layer_ind):
    top_n = 100
    with output: # this lets the stuff we output here be visible
        clear_output()
        batch_size = N*4
        # [E,N]
        data = contributions_h_e_n[layer_ind]

        top_n_data, top_n_labels, tail_mags = get_top_n(data=data, label_prefix="")
        tail_mag_str = get_tail_mag_str(torch.tensor(list(tail_mags)))
        bar_chart(data=torch.tensor(top_n_data), x_labels=top_n_labels, y_label='logit contribution', title=f'Direct Logit Attribution of top {top_n} streams on layer {layer_ind}', annotation=annotation + tail_mag_str, font_size=FONT_SIZE)
        
        ''' show all in batches
        for start_e in range(0, E, batch_size):
            end_e = min(E, start_e + batch_size)
            # [end_e-start_e,N]
            batch_data = data[start_e:end_e,:]
            X = [str(e) for e in range(start_e, end_e)]
            fig = px.imshow(utils.to_numpy(data.T), x=X, color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":"N", "y":'E'}, title=f"Logit contributions of streams at layer {layer_ind}")
            add_annotation(fig, annotation=annotation)
            fig.show()
        '''
       
choose_layer_dropdown_dla = ipywidgets.Dropdown(
    options=layer_names,
    value=layer_names[0],
    description='Layer',
) 

def choose_layer_dla(change):
    if change['type'] == 'change' and change['name'] == 'value':
        choose_layer_dropdown_dla.layer = layer_names.index(change['new'])
        display_dla(layer_ind=choose_layer_dropdown_dla.layer)

choose_layer_dropdown_dla.layer = 0

choose_layer_dropdown_dla.observe(choose_layer_dla)

display(choose_layer_dropdown_dla)

# you can't just display stuff inside a widget callback, you need a wrap any display code in this
output = ipywidgets.Output()
display(output)

display_dla(layer_ind=choose_layer_dropdown_dla.layer)

['<|endoftext|>', 'L', 'ately', ',', ' John', ',', ' James', ',', ' and', ' John']


TypeError: 'list' object cannot be interpreted as an integer

In [68]:
import sympy
import torch
from sympy import Matrix

a = Matrix([1,2,3])
from sympy.abc import a,b,c
torch.tensor([a,b,c])

RuntimeError: Could not infer dtype of Symbol