In [1]:

def imshow(tensor, renderer=None, xaxis="", yaxis="", font_size=None, show=True, color_continuous_midpoint=0.0, **kwargs):
    """Render a 2-D tensor as a diverging (RdBu) plotly heatmap.

    Args:
        tensor: Values to plot; converted with ``transformer_lens.utils.to_numpy``.
        renderer: Renderer name handed to ``fig.show`` (``None`` lets plotly choose).
        xaxis: Title of the x axis.
        yaxis: Title of the y axis.
        font_size: Optional font size. When given, the figure font is set to this
            size (black) and any ``x`` / ``y`` labels passed through ``kwargs`` are
            pinned as explicit tick values so that every label is drawn.
        show: If true the figure is displayed and ``None`` is returned; otherwise
            the figure object is returned without being displayed.
        color_continuous_midpoint: Value mapped to the centre of the colour scale.
        **kwargs: Forwarded unchanged to ``plotly.express.imshow``.

    Returns:
        The plotly figure when ``show`` is false, else ``None``.
    """
    import plotly.express as px
    import transformer_lens.utils as utils
    fig = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=color_continuous_midpoint,
        color_continuous_scale="RdBu",
        labels={"x": xaxis, "y": yaxis},
        **kwargs,
    )

    if font_size is not None:
        # The font is applied whenever a size is requested; it must not depend on
        # whether axis labels happen to be supplied as well.
        font_layout = {"font": dict(size=font_size, color="black")}
        for axis_name, label_key in (("xaxis", "x"), ("yaxis", "y")):
            if label_key in kwargs:
                font_layout[axis_name] = dict(
                    tickmode='array',
                    tickvals=kwargs[label_key],
                    ticktext=kwargs[label_key],
                )
        fig.update_layout(**font_layout)

    # Fixed canvas size with only a top margin (room for a title).
    fig.update_layout(
        width=800,
        height=600,
        autosize=False,
        showlegend=True,
        margin={"l": 0, "r": 0, "t": 100, "b": 0},
    )
    # Legend anchored in the top-left corner of the plotting area.
    fig.update_layout(legend=dict(
        yanchor="top",
        y=0.99,
        xanchor="left",
        x=0.01
    ))

    if show:
        fig.show(renderer)
        return None
    return fig


# requires
# pip install git+https://github.com/Phylliida/MambaLens.git

from mamba_lens import HookedMamba # this will take a little while to import
import torch
# Checkpoint identifier handed to HookedMamba.from_pretrained below.
model_path = "state-spaces/mamba-370m"

# NOTE! We need to monkeypatch transformer lens to use register_full_backward_hook
# Load the hooked model directly onto the GPU (device='cuda' requires a CUDA device).
model = HookedMamba.from_pretrained(model_path, device='cuda')
# Turn autograd off globally for the rest of the session; no gradients are tracked after this.
torch.set_grad_enabled(False)


from safetensors.torch import load_file
from collections import defaultdict

import sys
import os

# Root of the sae-k-sparse-mamba checkout. The `sae` package is imported from it and
# the per-hook checkpoint directories are listed from it.
# NOTE(review): absolute, machine-specific path; consider making it configurable.
SAE_REPO_DIR = "/home/dev/sae-k-sparse-mamba"

# Make `sae` importable. The membership test must use the very string that is
# appended, otherwise every re-run of the cell adds another duplicate entry.
if SAE_REPO_DIR not in sys.path:
    sys.path.append(SAE_REPO_DIR)
os.chdir(SAE_REPO_DIR)
# Index 0 is a placeholder; the per-layer load in the loop below is commented out,
# so `saes` currently stays `[None]`.
saes = [None]
from importlib import reload
from sae.sae import Sae

ckpt_dir = SAE_REPO_DIR + "/"
# List the directory once; the listing does not change between iterations.
ckpt_entries = sorted(os.listdir(ckpt_dir))
for i in range(1,22):
    print(i)
    hook = f'blocks.{i}.hook_out_proj'
    # First directory (in sorted order) whose name contains the hook name.
    matching_dirs = [ckpt_dir + f for f in ckpt_entries if hook in f]
    if not matching_dirs:
        raise FileNotFoundError(f"no checkpoint directory containing '{hook}' under {ckpt_dir}")
    path = matching_dirs[0] + "/" + f'hook_{hook}.pt'
    #path = f'/home/dev/sae-k-sparse-mamba/blocks.{i}.hook_resid_pre/hook_blocks.{i}.hook_resid_pre.pt'
    print(path)
    #saes.append(Sae.load_from_disk(path, hook=f'blocks.{i}.hook_resid_pre', device=model.cfg.device))


global PATCHING_FORMAT_I
global patching_formats
def make_data(num_patching_pairs, patching, template_i, seed, valid_seed):
    """Generate an IOI-style patching dataset restricted to a single ABC template.

    Uses the module-level ``model`` for tokenisation and dataset generation.

    Args:
        num_patching_pairs: Forwarded to the final ``generate_dataset`` call.
        patching: Key selecting the patching formats: ``'n1'`` .. ``'n5'``,
            ``'all'`` (n1..n5 concatenated and sorted) or ``'allatonce'``.
            An unknown key raises ``KeyError``.
        template_i: NOTE(review): accepted but never read in this function.
        seed: Forwarded to ``generate_dataset``.
        valid_seed: Forwarded to ``generate_dataset``.

    Returns:
        The object returned by the final ``generate_dataset`` call.

    Side effects:
        Sets the globals ``PATCHING_FORMAT_I`` and ``patching_formats``, prints
        progress information, and removes ``"Jesus"`` from
        ``acdc.data.ioi.good_names`` if it is present.
    """
    constrain_to_answers = True
    # this makes our data size 800, first 400 is each (a,b) pair, and then second 400 is each pair swapped to be (b,a)
    has_symmetric_patching = True
    
    # Patching formats, grouped by name. Each entry is a two-line string.
    # NOTE(review): presumably each line reads "<names in prompt> <names repeated> <answer>",
    # with the second line being the patched variant; confirm against acdc.data.ioi.
    n1_patchings = ["""
    ABC BC A
    DBC BC D""",
        """
    ABC CB A
    DBC CB D"""]
    
    n2_patchings = ["""
    ABC AC B
    ADC AC D""",
        """
    ABC CA B
    ADC CA D"""]
    
    n3_patchings = ["""
    ABC AB C
    ABD AB D""",
        """
    ABC BA C
    ABD BA D"""]
    
    n4_patchings = ["""
    ABC AC B
    ABC BC A""",
        """
    ABC AB C
    ABC CB A""",
        """
    ABC BA C
    ABC CA B"""]
    
    n5_patchings = ["""
    ABC CA B
    ABC CB A
    """,
        """
    ABC BA C
    ABC BC A""",
        """
    ABC AB C
    ABC AC B"""]
    
    patchings = {
        'n1': n1_patchings,
        'n2': n2_patchings,
        'n3': n3_patchings,
        'n4': n4_patchings,
        'n5': n5_patchings
    }
    
    # 'all' is the union of n1..n5; sorting happens on the raw (unstripped) strings.
    all_patchings = []
    for patching_set in patchings.values():
        all_patchings += patching_set
    all_patchings = sorted(all_patchings) # make deterministic 
    
    # Formats in which all three letters differ between the two lines (ABC vs DEF).
    patch_all_names = ["""
    ABC AB C
    DEF DE F""",
        """
    ABC AC B
    DEF DF E""",     
        """
    ABC BA C
    DEF ED F""",
        """
    ABC BC A
    DEF EF D""",
        """
    ABC CA B
    DEF FD E""",
        """
    ABC CB A
    DEF FE D"""]
    
    
    patchings['all'] = all_patchings
    patchings['allatonce'] = patch_all_names
    # NOTE(review): ABC_TEMPLATES is imported twice; BABA_TEMPLATES and
    # get_all_single_name_abc_patching_formats are not used by live code in this function.
    from acdc.data.ioi import BABA_TEMPLATES, ABC_TEMPLATES
    from acdc.data.ioi import ioi_data_generator, ABC_TEMPLATES, get_all_single_name_abc_patching_formats
    from acdc.data.utils import generate_dataset
    templates = ABC_TEMPLATES
    #patching_formats = list(get_all_single_name_abc_patching_formats())
    global PATCHING_FORMAT_I
    global patching_formats
    PATCHING_FORMAT_I = patching
    # Strip the indentation from every line of each selected format and drop leading/trailing blank lines.
    patching_formats = ["\n".join([line.strip() for line in x.split("\n")]).strip() for x in patchings[PATCHING_FORMAT_I]]
    
    print("using patching format")
    for patch in patching_formats:
        print(patch)
        print("")
    #print(patching_formats)
    
    
    # NOTE(review): the result of this first call is never read; `data` is reassigned
    # by the second generate_dataset call below, and num_patching_pairs is hard-coded
    # to 4 here. Kept as-is in case generate_dataset has side effects; confirm and remove.
    data = generate_dataset(model=model,
                      data_generator=ioi_data_generator,
                      num_patching_pairs=4,
                      seed=seed,
                      valid_seed=valid_seed,
                      constrain_to_answers=constrain_to_answers,
                      has_symmetric_patching=has_symmetric_patching, 
                      varying_data_lengths=True,
                      templates=templates,
                      patching_formats=patching_formats)
    
    
    import acdc.data.ioi
    from collections import defaultdict
    # Group the templates by the token positions at which the name appears once
    # the placeholders are filled with the first name / place / object.
    name_positions_map = defaultdict(lambda: [])
    for template in templates:
        name = acdc.data.ioi.good_names[0]
        template_filled_in = template.replace("[NAME]", name)
        template_filled_in = template_filled_in.replace("[PLACE]", acdc.data.ioi.good_nouns['[PLACE]'][0])
        template_filled_in = template_filled_in.replace("[OBJECT]", acdc.data.ioi.good_nouns['[OBJECT]'][0])
        # get the token positions of the [NAME] in the prompt
        name_positions = tuple([(i) for (i,s) in enumerate(model.to_str_tokens(torch.tensor(model.tokenizer.encode(template_filled_in)))) if s == f' {name}'])
        name_positions_map[name_positions].append(template)
    # Keep the name-position layout shared by the most templates ...
    sorted_by_frequency = sorted(list(name_positions_map.items()), key=lambda x: -len(x[1]))
    most_frequent_name_positions, templates = sorted_by_frequency[0]
    print("using templates")
    # ... and of those templates use only the first one.
    # NOTE(review): `template_i` looks like it was meant to pick the template here; confirm.
    templates = [templates[0]]
    for template in templates:
        print(template)
    print(f"with name positions {most_frequent_name_positions}")
    import acdc.data.ioi
    # Mutates the library's module-level name list, so the removal persists for later calls.
    if 'Jesus' in acdc.data.ioi.good_names:
        print("removed jesus")
        acdc.data.ioi.good_names.remove("Jesus")
    # The dataset that is actually returned: single template, caller-supplied pair count.
    data = generate_dataset(model=model,
                  data_generator=ioi_data_generator,
                  num_patching_pairs=num_patching_pairs,
                  seed=seed,
                  valid_seed=valid_seed,
                  constrain_to_answers=constrain_to_answers,
                  has_symmetric_patching=has_symmetric_patching, 
                  varying_data_lengths=True,
                  templates=templates,
                  patching_formats=patching_formats)
    
    # Show the first two prompts as token strings for a visual check.
    print(model.to_str_tokens(data.data[0]))
    print(model.to_str_tokens(data.data[1]))
    return data


from transformer_lens.hook_points import HookPoint
# Keys into the module-level `sae_storage` dict that the SAE hooks share.
# The hooks cooperate through that dict so the expensive work happens once:
# - sae_patching_storage_hook records one patch request per call under SAE_HOOKS
#   and resets the cached output stored under SAE_OUTPUT;
# - sae_patching_hook computes the output for all recorded requests the first time
#   it sees no cached output, stores it under SAE_OUTPUT and returns the cached
#   value afterwards.
# This avoids recomputing and so is much faster.
# NOTE(review): SAE_BATCHES is not referenced anywhere in the visible code.
SAE_HOOKS = "sae hooks"
SAE_BATCHES = "sae batches"
SAE_OUTPUT = "sae output"
def sae_patching_storage_hook(
    x,
    hook: HookPoint,
    sae_feature_i: int,
    dummy: bool,
    position: int,
    layer: int,
    batch_start: int,
    batch_end: int,
    **kwargs,
):
    """Queue one SAE feature patch request; the activation itself is not modified.

    The request (``position``, ``sae_feature_i``, ``dummy``) is appended to the list
    kept under ``SAE_HOOKS`` in the global ``sae_storage``, and the cached output
    under ``SAE_OUTPUT`` is reset to ``None``. ``hook``, ``layer``, ``batch_start``,
    ``batch_end`` and ``kwargs`` are accepted but not used.

    Returns:
        ``x`` unchanged.
    """
    request = {"position": position, "sae_feature_i": sae_feature_i, "dummy": dummy}
    # The list is created lazily here: the storage dict is replaced after every
    # patching pass, so it cannot be pre-populated elsewhere.
    pending_requests = sae_storage.setdefault(SAE_HOOKS, [])
    # Invalidate any output cached by an earlier pass.
    sae_storage[SAE_OUTPUT] = None
    pending_requests.append(request)
    return x

from jaxtyping import Float
from einops import rearrange

# Scratch dict shared by the SAE hooks (holds the SAE_HOOKS / SAE_OUTPUT entries).
sae_storage = {}
def sae_patching_hook(
    x: Float[torch.Tensor, "B L E"],
    hook: HookPoint,
    input_hook_name: str,
    layer: int,
    **kwargs,
) -> Float[torch.Tensor, "B L E"]:
    """Replace the hooked activation with its SAE reconstruction, patching selected features.

    ``x`` interleaves two runs along the batch axis: even rows are treated as the
    uncorrupted inputs and odd rows as the corrupted ones. For every sequence
    position both halves are encoded with ``saes[layer]``. The patched run starts
    from the corrupted feature activations; for each request queued under
    ``SAE_HOOKS`` that is not a dummy and targets this position (``None`` means
    every position) that feature's activation is taken from the uncorrupted run.
    The top-``k`` features are then decoded: the patched reconstruction is
    written to the even rows, the plain corrupted reconstruction to the odd rows.

    The result is cached under ``SAE_OUTPUT`` in the global ``sae_storage`` (which
    is otherwise emptied), so later calls return it without recomputing until a
    storage hook resets the cache.

    Args:
        x: Activation at the hook point.
        hook: Hook point object (unused).
        input_hook_name: Accepted for interface compatibility; not used.
        layer: Index into the global ``saes`` list.
        **kwargs: Ignored.

    Returns:
        The cached or newly computed tensor, with the same shape as ``x``.

    Raises:
        KeyError: if ``sae_storage`` has no ``SAE_OUTPUT`` entry yet, i.e. no
            storage hook ran before this hook.
    """
    global sae_storage

    # All requests write into the same output tensor, so it is computed once and shared.
    if sae_storage[SAE_OUTPUT] is None:
        sae = saes[layer]
        K = sae.cfg.k
        # Optional notebook-level switch. It is not defined in every session, so
        # fall back to False (patch from the uncorrupted run) instead of raising
        # NameError. When true, the corrupted value is kept, i.e. nothing is patched.
        keep_corrupted = globals().get("copy_from_other", False)

        sae_output = torch.zeros(x.size(), device=model.cfg.device)
        sae_input_uncorrupted = x[::2]
        sae_input_corrupted = x[1::2]
        seq_len = sae_input_uncorrupted.size()[1]
        for pos in range(seq_len):
            uncorrupted_features = sae.encode(sae_input_uncorrupted[:, pos])
            corrupted_features = sae.encode(sae_input_corrupted[:, pos])
            patched_features = corrupted_features.clone()
            # Apply the queued requests; each one targets a single feature.
            for hook_data in sae_storage[SAE_HOOKS]:
                position = hook_data['position']
                sae_feature_i = hook_data['sae_feature_i']
                if hook_data['dummy']:
                    continue
                if position is not None and position != pos:
                    continue
                if not keep_corrupted:
                    patched_features[:, sae_feature_i] = uncorrupted_features[:, sae_feature_i]
            # Decode only the K strongest features of each run.
            patched_top_acts, patched_top_indices = patched_features.topk(K, sorted=False)
            corrupted_top_acts, corrupted_top_indices = corrupted_features.topk(K, sorted=False)
            sae_output[::2, pos] = sae.decode(patched_top_acts, patched_top_indices)
            sae_output[1::2, pos] = sae.decode(corrupted_top_acts, corrupted_top_indices)
        # Drop the processed requests and keep only the freshly computed output.
        sae_storage = {}
        sae_storage[SAE_OUTPUT] = sae_output
    return sae_storage[SAE_OUTPUT]

from dataclasses import dataclass, field
@dataclass
class SAEFeature:
    """One SAE feature as observed at a particular layer and token position.

    Its text form is the four scalar fields separated by single spaces:
    ``"<layer> <pos> <feature_i> <attr>"``.
    """
    # Layer the feature belongs to.
    layer: int
    # Token position the records refer to.
    pos: int
    # Index of the feature within the SAE.
    feature_i: int
    # Score attached to the feature (presumably an attribution value; confirm with the producer).
    attr: float
    # Per-example values recorded for this feature; every instance gets its own list.
    records: list = field(default_factory=list)

    def __repr__(self):
        return f"{self.layer} {self.pos} {self.feature_i} {self.attr}"

    def __str__(self):
        return repr(self)

def get_name_counts(feature):
    """Group a feature's non-zero records by the token found at the feature's position.

    Reads the globals ``data`` (token ids, indexed as ``data.data[row, pos]``) and
    ``model`` (used to turn a token id into its string).

    Args:
        feature: Object with ``records`` (one value per dataset row) and ``pos``.

    Returns:
        A list of ``(token_string, values)`` pairs, where ``values`` is the tensor
        of non-zero records observed for that token, ordered by descending mean.
    """
    activations = torch.tensor(feature.records)
    # Dataset rows on which the feature has a non-zero record.
    active_rows = torch.arange(len(feature.records))[activations != 0]
    active_values = activations[active_rows]
    active_tokens = data.data[active_rows, feature.pos].cpu()

    values_by_token = {}
    for token_id in torch.unique(active_tokens):
        token_str = model.to_str_tokens(token_id.view(1, 1))[0]
        values_by_token[token_str] = active_values[active_tokens == token_id.item()]

    def negated_mean(entry):
        # Negating the mean makes the ascending sort return the largest mean first.
        return -torch.mean(entry[1]).item()

    # To inspect: iterate over the first entries and print torch.mean / min / max of each tensor.
    return sorted(values_by_token.items(), key=negated_mean)

# Generate the patching dataset; data.data[0] supplies the token strings used for
# the position labels below.
data = make_data(num_patching_pairs=2, patching="all", template_i=0, seed=24, valid_seed=23)

toks = model.to_str_tokens(data.data[0])
# NOTE(review): name_positions is not read again in the visible code; the same
# positions are written into position_map by hand below.
name_positions = [3,5,7,13,15]
# Readable label per token position: 'pos<index><token>' by default, overridden
# with short names for the hand-picked positions.
position_map = {}
L = data.data.size()[1]
for l in range(L):
    position_map[l] = f'pos{l}{toks[l]}'
position_map[3] = 'n1'
position_map[5] = 'n2'
position_map[7] = 'n3'
position_map[13] = 'n4'
position_map[15] = 'n5'
position_map[19] = 'out'
import pickle
# NOTE(review): pickle.load can execute arbitrary code; only open files you trust.
with open("cached_sae_feature_edges.pkl", "rb") as f:
    edges_to_keep = pickle.load(f)

# String form of every token id in the vocabulary.
h = model.to_str_tokens(torch.arange(model.tokenizer.vocab_size))
# (token id, string) for every token that starts with a space and is not blank.
# NOTE(review): x[0] raises IndexError if any token decodes to an empty string.
spaceThings = [(i, x) for (i, x) in enumerate(h) if x[0] == ' ' and len(x.strip()) > 0]
# First three tokens of the first prompt, shaped (1, 3).
prefix = data.data[0][:3].view(1,-1)
new_data_toks = torch.tensor([tok for (tok,s) in spaceThings], device=model.cfg.device)
# One row per selected token: the shared prefix followed by that token.
data_for_all_tokens = torch.cat([prefix.repeat((len(new_data_toks),1)), new_data_toks.view(-1,1)], dim=1)
data.data = data_for_all_tokens
# NOTE(review): the load below rebinds `data`, so the assignment to data.data above
# is discarded; from here on `data` and `features` come from the pickle file.
with open("layer_15_features_more_more.pkl", "rb") as f:
    data, features = pickle.load(f)

'''
#names = sorted(list(acdc.data.ioi.good_names))
names = [x for (i,x) in spaceThings if len(x.strip()) > 0]
NUM_NAMES = len(names)
#name_to_i = dict([(" " + name, i) for (i, name) in enumerate(names)])
name_to_i = dict([(name, i) for (i, name) in enumerate(names)])

def get_name_vector(feature, feat_type):
    name_vec = torch.zeros(NUM_NAMES)
    for name,counts in get_name_counts(feature):
        if feat_type == 'mean':
            name_vec[name_to_i[name]] = torch.mean(counts)
        elif feat_type == 'min':
            name_vec[name_to_i[name]] = torch.min(counts)
        elif feat_type == 'max':
            name_vec[name_to_i[name]] = torch.max(counts)
            
    return name_vec
'''
# Bucket the loaded features by their SAE feature index; unseen indices map to [].
features_sorted_by_feat_i = defaultdict(list)
for feature in features:
    same_index_features = features_sorted_by_feat_i[feature.feature_i]
    same_index_features.append(feature)

  return self.fget.__get__(instance, owner)()
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Moving model to device:  cuda
1
/home/dev/sae-k-sparse-mamba/0.0001414213562373095 initialTry190.txtblocks.1.hook_out_proj/hook_blocks.1.hook_out_proj.pt
2
/home/dev/sae-k-sparse-mamba/0.0001414213562373095 initialTry191.txtblocks.2.hook_out_proj/hook_blocks.2.hook_out_proj.pt
3
/home/dev/sae-k-sparse-mamba/0.0001414213562373095 initialTry192.txtblocks.3.hook_out_proj/hook_blocks.3.hook_out_proj.pt
4
/home/dev/sae-k-sparse-mamba/0.0001414213562373095 initialTry193.txtblocks.4.hook_out_proj/hook_blocks.4.hook_out_proj.pt
5
/home/dev/sae-k-sparse-mamba/0.0001414213562373095 initialTry194.txtblocks.5.hook_out_proj/hook_blocks.5.hook_out_proj.pt
6
/home/dev/sae-k-sparse-mamba/0.0001414213562373095 initialTry195.txtblocks.6.hook_out_proj/hook_blocks.6.hook_out_proj.pt
7
/home/dev/sae-k-sparse-mamba/0.0001414213562373095 initialTry196.txtblocks.7.hook_out_proj/hook_blocks.7.hook_out_proj.pt
8
/home/dev/sae-k-sparse-mamba/0.0001414213562373095 initialTry200.txtblocks.8.hook_out_proj/hook_bloc

In [2]:
import pickle

# Labels stored in layer_15_features.pkl: feature index -> label string.
# NOTE(review): pickle.load can execute arbitrary code; only open files you trust.
with open("layer_15_features.pkl", "rb") as labels_file:
    feature_labels = pickle.load(labels_file)



In [3]:
import pickle

# Labels stored in layer_15_features_take_two.pkl: feature index -> label string.
# NOTE(review): pickle.load can execute arbitrary code; only open files you trust.
with open("layer_15_features_take_two.pkl", "rb") as labels_file:
    feature_labelsb = pickle.load(labels_file)


In [5]:
# Dump the labels loaded from layer_15_features_take_two.pkl.
print(feature_labelsb)

{6146: 'M (also fires on final token of M words)', 4104: 'May, five, fifth ', 10252: 'R ', 12348: 'Famous Actors, actor portrayal, embodied, playing', 24649: "? used to be 1800's and 1900's", 22605: 'Repeated Token', 12365: '? ', 2129: 'Michigan Locations (previously MMA Wrestling and Sports)', 12371: 'Middle East Diplomacy (also haemostatic/thrombosis/punk rock/hematopoiesis)', 6227: 'Russian Names, 3 digit numbers, renal excretions, and escaped quotes', 2165: '? Muilti token phrases maybe?', 4214: 'quantum terminology ', 6266: 'me/you/your (previously interactions with webpage)', 20604: 'Kidney Function', 16512: '?, accountants, marriage and health (microbiome, antioxidants) related terms', 6272: 'Human Suffering (Cancer, Auschwitz, Guantanamo Bay, Divorce, Chernobyl) previously renal cancer terms/sports teams', 30887: 'denominator of/simplify sqrt(, previouslywine tasting', 14506: 'words put together with no space (prevoiusly all caps licensing/warranty terms)', 20651: 'Russia relat

In [8]:

from IPython.display import display, clear_output
from ipywidgets import widgets
from IPython.display import display, HTML
# Index of the feature currently shown in the labelling widget (None until one is shown).
cur_feature_ind = None
def display_unlabeled_feature():
    """Pick the next feature whose label still contains '?' and display it.

    Features that have records but no entry in ``feature_labelsb`` are tried first.
    If none of them needs a label, the number of labels in ``feature_labels`` that
    still contain '?' is printed and the first such feature is shown instead.
    Showing a feature means: remember it in ``cur_feature_ind``, put its current
    label into ``text_item`` and call ``display_feats`` on its records.
    """
    global cur_feature_ind

    def needs_label(feature_ind):
        return '?' in feature_labels[feature_ind]

    nothing_found = object()
    not_in_second_pass = features_sorted_by_feat_i.keys() - feature_labelsb.keys()
    chosen = next((ind for ind in not_in_second_pass if needs_label(ind)), nothing_found)

    if chosen is not nothing_found:
        print(f"feature {chosen}")
    else:
        still_unclear = [ind for ind in feature_labels.keys() if needs_label(ind)]
        print(f"num left {len(still_unclear)}")
        if not still_unclear:
            return
        chosen = still_unclear[0]

    # cur_feature_ind must be set before the text box: changing the box value fires
    # the observer, which compares the value against the current feature's label.
    cur_feature_ind = chosen
    text_item.value = feature_labels[chosen]
    display_feats(features_sorted_by_feat_i[chosen])

# Print the shape of `data` as a quick sanity check.
print(data.size())
# Inject web-font imports and the '#ayy' rule into the page so that token strings in
# many scripts (and emoji) get a font that can render them.
display(HTML("""<link href='https://fonts.googleapis.com/css?family=Noto Sans' rel='stylesheet'>
<style>
/* Base Noto Sans and Serif for Latin, Greek, and Cyrillic */
@import url('https://fonts.googleapis.com/css2?family=Noto+Sans:wght@400;700&family=Noto+Serif:wght@400;700&display=swap');

/* East Asian scripts */
@import url('https://fonts.googleapis.com/css2?family=Noto+Sans+JP:wght@400;700&family=Noto+Sans+KR:wght@400;700&family=Noto+Sans+SC:wght@400;700&family=Noto+Sans+TC:wght@400;700&display=swap');

/* South Asian scripts */
@import url('https://fonts.googleapis.com/css2?family=Noto+Sans+Devanagari:wght@400;700&family=Noto+Sans+Bengali:wght@400;700&family=Noto+Sans+Tamil:wght@400;700&display=swap');

/* Middle Eastern scripts */
@import url('https://fonts.googleapis.com/css2?family=Noto+Sans+Arabic:wght@400;700&family=Noto+Sans+Hebrew:wght@400;700&display=swap');

/* Other scripts */
@import url('https://fonts.googleapis.com/css2?family=Noto+Sans+Thai:wght@400;700&family=Noto+Sans+Ethiopic:wght@400;700&display=swap');

/* Specialty fonts */
@import url('https://fonts.googleapis.com/css2?family=Noto+Sans+Mono:wght@400;700&family=Noto+Color+Emoji&display=swap');

#ayy {
  font-family: 'Noto Sans', 'Noto Sans JP', 'Noto Sans KR', 'Noto Sans SC', 'Noto Sans TC', 
               'Noto Sans Devanagari', 'Noto Sans Bengali', 'Noto Sans Tamil', 
               'Noto Sans Arabic', 'Noto Sans Hebrew', 'Noto Sans Thai', 'Noto Sans Ethiopic',
               sans-serif;
}

/* Language-specific rules */
:lang(ja) { font-family: 'Noto Sans JP', sans-serif; }
:lang(ko) { font-family: 'Noto Sans KR', sans-serif; }
:lang(zh-CN) { font-family: 'Noto Sans SC', sans-serif; }
:lang(zh-TW) { font-family: 'Noto Sans TC', sans-serif; }
:lang(hi) { font-family: 'Noto Sans Devanagari', sans-serif; }
:lang(bn) { font-family: 'Noto Sans Bengali', sans-serif; }
:lang(ta) { font-family: 'Noto Sans Tamil', sans-serif; }
:lang(ar) { font-family: 'Noto Sans Arabic', sans-serif; }
:lang(he) { font-family: 'Noto Sans Hebrew', sans-serif; }
:lang(th) { font-family: 'Noto Sans Thai', sans-serif; }
:lang(am), :lang(ti) { font-family: 'Noto Sans Ethiopic', sans-serif; }

/* Emoji support */
.emoji {
  font-family: 'Noto Color Emoji', sans-serif;
}
</style>"""))
def display_feats(feats):
    """Show the strongest activations of one SAE feature as coloured HTML rows.

    Relies on the globals ``model`` (tokenisation) and ``data`` (token ids, one
    prompt per row).

    Args:
        feats: Objects describing the same feature at different token positions.
            Each needs ``pos`` (token position) and ``records`` (one activation
            per dataset row, aligned with the rows of ``data``).

    For up to 500 of the largest (row, position) activations, skipping prompts
    whose text up to the activating token was already shown, one row is
    displayed: the activating token, the decoded prompt prefix, and every token
    after the first, coloured by position (the activating one in pink) and
    followed by its activation when that exceeds 0.01. Once 20 rows have been
    shown, a "what do these have in common" prompt listing the 20 prefixes is
    displayed as well.
    """
    import html  # stdlib; imported locally so the function does not depend on cell order

    if not feats:
        # Nothing recorded for this feature; avoid indexing into an empty list.
        print("no features to display")
        return

    colors = ['', 'red', 'orange', 'yellow', 'green']

    # Flatten all records and remember, for every flattened entry, the token
    # position and dataset row it came from. Deriving both from the features
    # themselves avoids assuming exactly four features ordered by position.
    all_records = []
    pos_chunks = []
    row_chunks = []
    records_by_pos = {}
    for feat in feats:
        print(feat.pos)
        all_records += feat.records
        n_rows = len(feat.records)
        pos_chunks.append(torch.full((n_rows,), feat.pos, dtype=torch.long))
        row_chunks.append(torch.arange(n_rows))
        # The first feature seen for a position wins.
        records_by_pos.setdefault(feat.pos, feat.records)
    records = torch.tensor(all_records)
    which_pos = torch.cat(pos_chunks)
    which_row = torch.cat(row_chunks)
    print(records.size())

    top_act_inds = torch.argsort(-records)
    covered_already = set()
    simpler_words = []
    # Slicing bounds the loop by the number of records actually available.
    for ind in top_act_inds[:500].tolist():
        token_pos = which_pos[ind].item()
        data_pos = which_row[ind].item()
        toks = model.to_str_tokens(data[data_pos])
        relevant_str = "".join(toks[:token_pos+1])
        if relevant_str in covered_already:
            continue
        covered_already.add(relevant_str)

        out_toks = []
        for j, tok in enumerate(toks):
            if j < 1:
                continue  # the first token is never shown
            if len(tok.strip()) == 0:
                tok = repr(tok)  # make whitespace-only tokens visible
            # Escape so tokens such as '<' or '&' are shown literally instead of
            # being parsed as markup.
            tok = html.escape(tok.replace("\n", "\\n"))
            if j == token_pos:
                color = 'pink'
            else:
                color = colors[j] if j < len(colors) else ''
            colored = f"<span id='ayy'><font color='{color}'>{tok}</font></span>"
            pos_records = records_by_pos.get(j)
            act = pos_records[data_pos] if pos_records is not None else 0
            if act > 0.01:
                out_toks.append(f"{colored}{act:.3f}")
            else:
                out_toks.append(colored)

        simpler = model.tokenizer.decode(data[data_pos][1:token_pos+1])
        simpler_words.append(simpler.strip())

        display(HTML(html.escape(toks[token_pos]) + "\t||\t" + "<span id='ayy'>" + html.escape(simpler) + "</span>\t||\t" + "".join(out_toks)))
        if len(simpler_words) == 20:
            out_s = "<span id='ayy'>What do "
            for s in simpler_words:
                out_s += f'"{html.escape(s.strip(), quote=False)}", '
            out_s += "have in common? Take a deep breath and think step by step."
            display(HTML(out_s + "</span>"))
def save_labels():
    """Pickle the global ``feature_labelsb`` dict to layer_15_features_take_two.pkl and report its size."""
    with open("layer_15_features_take_two.pkl", "wb") as labels_file:
        pickle.dump(feature_labelsb, labels_file)
        saved_count = len(feature_labelsb)
        print(f"done saving {saved_count}")
import traceback
# Output area that receives everything printed or displayed while labelling.
out = widgets.Output()
def submitted(change):
    """Observer for the label text box: store the new label, save, show the next feature.

    Nothing happens when no feature is currently shown, when the text is blank,
    or when the text equals the label already stored for the current feature in
    ``feature_labels``. The last case is also what keeps the programmatic update
    of the text box from being treated as a new label.

    Args:
        change: Change notification passed by ``observe``; unused, the value is
            read from ``text_item`` directly.
    """
    new_label = text_item.value
    if cur_feature_ind is None or len(new_label.strip()) == 0:
        # No feature is on screen yet (or the box is empty): nothing to record.
        return
    if new_label == feature_labels.get(cur_feature_ind):
        return
    with out:
        try:
            clear_output()
            feature_labelsb[cur_feature_ind] = new_label
            save_labels()
            display_unlabeled_feature()
        except Exception:
            # Show the failure inside the output widget; a widget callback has no
            # other visible place to report it. KeyboardInterrupt and SystemExit
            # are deliberately not caught.
            print(traceback.format_exc())
                



# Text box in which the label for the currently shown feature is typed.
# continuous_update=False makes the value change on submission rather than on
# every keystroke, so `submitted` does not fire for each character.
text_item = widgets.Text(
    value='ffff',
    placeholder='Type something',
    description='String:',
    disabled=False,
    continuous_update=False,
)

display(text_item)
display(out)
# Call `submitted` whenever the box value changes.
text_item.observe(submitted, names='value')

# Show the first feature to label; its output is captured in `out`.
with out:
    display_unlabeled_feature()
    

torch.Size([389372, 5])


Text(value='ffff', continuous_update=False, description='String:', placeholder='Type something')

Output()

In [9]:
from IPython.display import display, FileLink
import os
# Render a link to layer_15_features.pkl in the notebook output.
display(FileLink("layer_15_features.pkl"))

In [25]:






























# Merge the labels from both label files into one cleaned-up string per feature.
# Only features present in feature_labels are considered.
combined = {}
for f in feature_labels.keys():
    first_pass = feature_labels[f]
    if f not in feature_labelsb:
        merged = first_pass
    else:
        second_pass = feature_labelsb[f]
        if first_pass.replace("?", "").strip() == "":
            # The first label is empty apart from '?': use the second one alone.
            merged = second_pass
        elif second_pass.strip() == first_pass.strip():
            # Both labels say the same thing: keep a single copy.
            merged = first_pass
        else:
            merged = first_pass + " MOREDATA: " + second_pass.replace("?", " ").strip()
    # Replace '?', '(' and ')' with spaces and trim the ends.
    merged = merged.replace("?", " ").replace("(", " ").replace(")", " ").strip()
    # Labels that end up empty, or consist only of '||', are left out.
    if merged not in ("", "||"):
        combined[f] = merged

# Case-insensitive alphabetical listing of (feature index, label).
labs = sorted(combined.items(), key=lambda entry: entry[1].lower().strip())
for f, l in labs:
    if l.strip():
        print(f, l)
print(f"total num features identified {len(labs)}")

24649 1800's and 1900's MOREDATA: used to be 1800's and 1900's
15582 1800-1999
30568 1980's-1990s pop culture
6227 3 digit numbers, renal excretions, and escaped quotes MOREDATA: Russian Names, 3 digit numbers, renal excretions, and escaped quotes
8103 5/May/Five/Fifth
32395 [ with some other symbols  [", **[, etc.
31021 A
29892 A
14819 about names
1312 account/login terminology and oxidation/breathing
12348 actor portrayal, embodied, playing MOREDATA: Famous Actors, actor portrayal, embodied, playing
19724 aggressive/vigorously
21271 Airline Brands and Punk Rock bands
15562 all caps hurricaine/military codename stuff
14506 all caps licensing/warranty terms MOREDATA: words put together with no space  prevoiusly all caps licensing/warranty terms
23646 assay chemical terms
22790 B
28979 B
21593 blood related medical terminology
28482 Canada related terms and names
16927 Canada terms
29448 Cancer research cell lines
9622 cancer treatment related terminology and wording: radiation related 

In [19]:
# Inspect the current contents of feature_labelsb.
feature_labelsb

{6146: 'M (also fires on final token of M words)',
 4104: 'May, five, fifth ',
 10252: 'R ',
 12348: 'Famous Actors, actor portrayal, embodied, playing',
 24649: "? used to be 1800's and 1900's",
 22605: 'Repeated Token'}

In [28]:
# Scratch check: str.strip() removes the leading space of a token string.
" Purushottam".strip()

'Purushottam'