#Set-up

In [1]:
from google.colab import drive
drive.mount('/content/drive')

!pip install transformers

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# set seeds
import random
import numpy as np
import torch

def set_seed(seed):
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)

set_seed(42)

In [3]:
## load custom functions from utils.py

import sys
sys.path.append('//content/drive/MyDrive/SAEs_for_Genomics')

import importlib
import utils
importlib.reload(utils)

<module 'utils' from '//content/drive/MyDrive/SAEs_for_Genomics/utils.py'>

# Load NT model

In [4]:
"loading smallest nucleotide transformer (50m params)"


from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch

num_params = 50 ## default 50

# Import the tokenizer and the model
tokenizer_nt = AutoTokenizer.from_pretrained(f"InstaDeepAI/nucleotide-transformer-v2-{num_params}m-multi-species", trust_remote_code=True)
model_nt = AutoModelForMaskedLM.from_pretrained(f"InstaDeepAI/nucleotide-transformer-v2-{num_params}m-multi-species", trust_remote_code=True)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/129 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/28.7k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/101 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.06k [00:00<?, ?B/s]

esm_config.py:   0%|          | 0.00/14.9k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/InstaDeepAI/nucleotide-transformer-v2-50m-multi-species:
- esm_config.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


modeling_esm.py:   0%|          | 0.00/58.2k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/InstaDeepAI/nucleotide-transformer-v2-50m-multi-species:
- modeling_esm.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.safetensors:   0%|          | 0.00/224M [00:00<?, ?B/s]

# Load and preprocess addgene dataset

In [5]:
import pandas as pd


# Constants
TEST_DATA_PATH = '/content/drive/MyDrive/NOO_paper/Datasets/WorldWide/BLAST_geac_ext_169k_val_random.csv'
TRAIN_DATA_PATH = '/content/drive/MyDrive/NOO_paper/Datasets/WorldWide/BLAST_geac_ext_169k_train_random.csv'
INFREQUENT_THRESHOLD = 10

def split_test_data(test_data):
    """Split test data into input and target variables."""
    y_test = test_data['nations']
    x_test = test_data[['sequence']]
    return x_test, y_test

def replace_infrequent_labels(labels, threshold=INFREQUENT_THRESHOLD):
    """Identify and replace infrequent labels."""
    label_counts = labels.value_counts()
    infrequent_labels = label_counts[label_counts < threshold].index
    return labels.replace(infrequent_labels, 'infrequent')

def map_labels_to_integers(labels):
    """Map labels to integers."""
    unique_labels = labels.unique()
    return {label: int(i) for i, label in enumerate(unique_labels)}

def without_US(data):
    """Filter out rows where the nation is 'UNITED STATES'."""
    data_wo_US = data[data['nations'] != 'UNITED STATES']
    data_wo_US.reset_index(drop=True, inplace=True)

    data_w_US = data[data['nations'] == 'UNITED STATES']
    data_w_US.reset_index(drop=True, inplace=True)
    return data_wo_US, data_w_US

def US_vs_them(labels):
    """Categorize labels into 'UNITED STATES' and 'NON US'."""
    return labels.apply(lambda x: x if x == 'UNITED STATES' else 'NON US')

def pad_sequence(seq, length, pad_char='N'):
    """Pad sequences to the specified length with the given character."""
    return seq.ljust(length, pad_char)[:length]

# Load data
train_data = pd.read_csv(TRAIN_DATA_PATH)
test_data = pd.read_csv(TEST_DATA_PATH)

print(f'test_data shape: {test_data.shape}')

# Remove US
# train_data, train_data_US = without_US(train_data)
# test_data, test_data_US = without_US(test_data)

print(f'test_data shape: {test_data.shape}')

# Split data
x_train, y_train = train_data[['sequence']], train_data['nations']
x_test, y_test = split_test_data(test_data)

print(f'test_data shape: {y_test.shape}')
print(f'x_train shape: {x_train.shape}')
print(f'y_train shape: {y_train.shape}')

# Combine labels from train and test datasets
processed_labels = pd.concat([y_train, y_test], axis=0, ignore_index=True)
label_to_int = map_labels_to_integers(processed_labels)


# map labels to integers
y_train = y_train.map(label_to_int)
y_test = y_test.map(label_to_int)

print(f'y_test shape: {y_test.shape}')


# reset indices before concat
x_train.reset_index(drop=True, inplace=True)
y_train.reset_index(drop=True, inplace=True)
x_test.reset_index(drop=True, inplace=True)
y_test.reset_index(drop=True, inplace=True)

df_train = pd.concat([x_train, y_train], axis=1)
df_val = pd.concat([x_test, y_test], axis=1)

print(f'test_data shape: {test_data.shape}')


# Filter out sequences shorter than min_length and clean them
min_length = 0
df_train = df_train[df_train['sequence'].str.len() > min_length]
df_val = df_val[df_val['sequence'].str.len() > min_length]

print(f'test_data shape: {test_data.shape}')


# Ensure indices are reset correctly
df_train.reset_index(drop=True, inplace=True)
df_val.reset_index(drop=True, inplace=True)

# Display the split data
print("Train Data Shape:", df_train.shape)
print("Validation Data Shape:", df_val.shape)


test_data shape: (15551, 4)
test_data shape: (15551, 4)
test_data shape: (15551,)
x_train shape: (93306, 1)
y_train shape: (93306,)
y_test shape: (15551,)
test_data shape: (15551, 4)
test_data shape: (15551, 4)
Train Data Shape: (93306, 2)
Validation Data Shape: (15551, 2)


# Set-up & Load SAE

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

cfg = {
    "seed": 49,
    "batch_size": 4096*6,
    "buffer_mult": 384,
    "lr": 5e-5,
    "num_tokens": tokenizer_nt.vocab_size,
    "d_model": 512,
    "l1_coeff": 1e-1,
    "beta1": 0.9,
    "beta2": 0.999,
    "dict_mult": 8, # hidden_d = d_model * dict_mult
    "seq_len": 512,
    "d_mlp": 512,
    "enc_dtype":"fp32",
    "remove_rare_dir": False,
    "total_training_steps": 10000,
    "lr_warm_up_steps": 1000,
    "device": "cuda"
}
cfg["model_batch_size"] = 64
cfg["buffer_size"] = cfg["batch_size"] * cfg["buffer_mult"]
cfg["buffer_batches"] = cfg["buffer_size"] // cfg["seq_len"]

DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}

class AutoEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # HP-choices
        d_hidden = cfg["d_mlp"] * cfg["dict_mult"]
        d_mlp = cfg["d_mlp"]
        self.l0_coeff = cfg.get("l0_coeff", 5)
        self.threshold = cfg.get("activation_threshold", 0.3)
        # Temperature for sigmoid approximation
        self.temperature = cfg.get("temperature", 1.0)
        dtype = DTYPES[cfg["enc_dtype"]]
        torch.manual_seed(cfg["seed"])

        self.W_enc = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(d_mlp, d_hidden, dtype=dtype)))
        self.W_dec = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(d_hidden, d_mlp, dtype=dtype)))
        self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=dtype))
        self.b_dec = nn.Parameter(torch.zeros(d_mlp, dtype=dtype))
        self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)

        self.d_hidden = d_hidden
        self.to("cuda") if torch.cuda.is_available() else self.to("cpu")

    def get_continuous_l0(self, x):
        """
        Compute continuous relaxation of L0 norm using sigmoid
        This provides useful gradients unlike the discrete L0
        """
        # Shifted sigmoid to approximate step function
        return torch.sigmoid((x.abs() - self.threshold) / self.temperature)

    def forward(self, x):
        # encoding and decoding of input vec
        x_cent = x - self.b_dec
        pre_acts = x_cent @ self.W_enc + self.b_enc
        acts = F.relu(pre_acts)

        # Compute continuous L0 approximation before thresholding
        l0_proxy = self.get_continuous_l0(acts)

        # Apply hard threshold for forward pass --- This is actually jumprelu (I think!)
        acts_sparse = (acts.abs() > self.threshold).float() * acts
        x_reconstruct = acts_sparse @ self.W_dec + self.b_dec

        # L2 Loss (Reconstruction Loss)
        l2_loss = F.mse_loss(x_reconstruct.float(), x.float(), reduction='none')
        l2_loss = l2_loss.sum(-1)
        l2_loss = l2_loss.mean()

        # Normalized MSE for reporting
        nmse = torch.norm(x - x_reconstruct, p=2) / torch.norm(x, p=2)

        # Continuous L0 loss (using sigmoid approximation)
        l0_loss = l0_proxy.sum(dim=1).mean()

        # Total Loss: reconstruction + sparsity
        loss = l2_loss + self.l0_coeff * l0_loss

        # For monitoring: true L0 count (not used in optimization)
        true_l0 = (acts_sparse.float().abs() > 0).float().sum(dim=1).mean()

        # For monitoring: L1 loss
        l1_loss = acts_sparse.float().abs().sum(-1).mean()

        return loss, x_reconstruct, acts_sparse, l2_loss, nmse, l1_loss, true_l0

    @torch.no_grad()
    def remove_parallel_component_of_grads(self):
        W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
        W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(-1, keepdim=True) * W_dec_normed
        self.W_dec.grad -= W_dec_grad_proj



sae_model = AutoEncoder(cfg)

## Load already-trained SAE

In [7]:
random_weights_path = "/content/drive/MyDrive/SAEs_for_Genomics/Weights/nt50m_sae_+40mtokens.pt"
state_dict = torch.load(random_weights_path)
sae_model = AutoEncoder(cfg)
sae_model.load_state_dict(state_dict)



  state_dict = torch.load(random_weights_path)


<All keys matched successfully>

# Using trained SAE to interpret the NuclTrans

## Analysing Rare Features (copied & adapted)

In [None]:
val_seqs = df_val['sequence'].tolist()
val_tokens = tokenizer_nt(val_seqs, max_length=512, padding='max_length', truncation=True, return_tensors="pt")

For each feature we can get the frequency at which it's non-zero (per token, averaged across a bunch of batches), and plot a histogram

In [None]:
@torch.no_grad()
def get_freqs(num_batches=20, local_encoder=None):
    if local_encoder is None:
        local_encoder = encoder

    # initialise frequency counters to 0 for all hidden neurons
    act_freq_scores = torch.zeros(sae_model.d_hidden, dtype=torch.float32).cuda()
    total = 0

    for i in range(num_batches):
        # prepare batch of tokens to input to the model
        tokens = val_tokens[i*cfg['model_batch_size']:(i+1)*cfg['model_batch_size']]

        # run model on batch of tokens
        #_, cache = model.run_with_cache(tokens, stop_at_layer=1, names_filter=utils.get_act_name("post", 0))
        mlp_act = utils.get_layer_activations(model_nt, tokens['input_ids'].cuda(), tokens['attention_mask'].cuda())
        mlp_act = mlp_act[0] # unnest

        # extract mlp activations and reshape for SAE
        mlp_act = mlp_act.reshape(-1, d_mlp)

        # normalise using same approach as for traning (optional)



        # input the acts into an SAE, get the SAEs hidden acts
        loss, x_reconstruct, hidden, l2_loss, nmse, l1_loss, true_l0 = local_encoder(mlp_act) ## acts is the second/third output -> 1/2
        act_freq_scores += (hidden > 0).sum(0) # increase counter if act > 0
        total+=hidden.shape[0]

    act_freq_scores /= total # turn counts into frequencies

    # calc and print number of never activated SAE units
    num_dead = (act_freq_scores==0).float().mean()
    print("Num dead", num_dead)

    return act_freq_scores


In [None]:
d_model = cfg["d_model"]
d_mlp = cfg["d_mlp"]
model = model_nt.cuda()

sae_model.cuda()
sae_model.eval()

freqs = get_freqs(num_batches = 50,
                  local_encoder = sae_model) # what % of time is a hidden unit activated > 0?

Num dead tensor(0.0001, device='cuda:0')


In [None]:
# how many sae units are rarely activated?

rare_T = 1e-4 #

print(f'Of {d_model*cfg["dict_mult"]} hidden SAE units, {sum(freqs < rare_T).item()} are very rarely activated')	#


# how many features are dense i.e. activate very often?

often_T = 0.3
print(f'Of {d_model*cfg["dict_mult"]} hidden SAE units, {sum(freqs > often_T).item()} are activated very often')	#

Of 16384 hidden SAE units, 124 are very rarely activated
Of 16384 hidden SAE units, 9322 are activated very often


In [None]:
import plotly.express as px

# Add 1e-9 so that dead features show up as log_freq -9
log_freq = (freqs + 10**-9).log10()
log_freq = log_freq.cpu().detach().numpy()

px.histogram(log_freq, title="Log Frequency of Features", histnorm='percent')

In [None]:
import numpy as np

# Get indices where freqs is not 0
mask = freqs != 0
indices = torch.where(mask)[0]
print(indices)

# Save indices to a file
np.save('non_rare_feature_indices.npy', indices.cpu().numpy())

tensor([   0,    1,    2,  ..., 8189, 8190, 8191], device='cuda:0')


In [None]:
encoder = sae_model # just renaming for simplicity


# Q: why encoder (as opposed to decoder) matrix?

is_rare = freqs < 1e-4 # get bool mask
rare_enc = encoder.W_enc[:, is_rare] # get cols from enc matrix
rare_mean = rare_enc.mean(-1) # average these cols

# cosine similarity of rare features to average rare feature and plot
cosine_sim = rare_mean @ encoder.W_enc / rare_mean.norm() / encoder.W_enc.norm(dim=0)

# move to cpu
cosine_sim = cosine_sim.cpu().detach().numpy()
is_rare = is_rare.cpu().detach().numpy()

px.histogram(cosine_sim,
             title="Cosine Sim with Average Rare Feature",
             color=is_rare,
             labels={"color": "is_rare", "count": "percent", "value": "cosine_sim"},
             marginal="box", histnorm="percent", barmode='overlay')

## Loading test-sequence w annotations

In [8]:
import pandas as pd
import torch
from transformers import AutoTokenizer

def load_and_process_annotations(file_path):
    """Load CSV and add 'valseq_' prefix to seq_id column if not already present."""
    df = pd.read_csv(file_path)
    df['seq_id'] = df['seq_id'].astype(str)
    # Add 'valseq_' prefix only if it's not already there
    df['seq_id'] = df['seq_id'].apply(lambda x: x if x.startswith('valseq_') else f'valseq_{x}')
    return df

def extract_and_tokenize_sequences(df_annotations, df_val, tokenizer_nt):
    """Extract sequence IDs, get corresponding sequences, and tokenize them."""
    # Extract and sort sequence IDs
    seq_ids = list(set(df_annotations['seq_id']))
    # More robust parsing of sequence IDs
    parsed_ids = []
    for seq_id in seq_ids:
        try:
            if 'valseq_' in seq_id:
                parsed_ids.append(int(seq_id.split('valseq_')[1]))
            else:
                parsed_ids.append(int(seq_id))
        except ValueError:
            print(f"Warning: Could not parse seq_id: {seq_id}")
            continue

    seq_ids = sorted(parsed_ids)

    # Get and tokenize sequences
    sequences = df_val['sequence'].iloc[seq_ids].tolist()
    tokens = tokenizer_nt(
        sequences,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors="pt"
    )

    return tokens, seq_ids

# File paths
base_path = '/content/drive/MyDrive/SAEs_for_Genomics/Annotated_seqs'
files = {
    's0': f'{base_path}/ann_of_1000_seqs_set0.csv',
    's1': f'{base_path}/ann_of_1000_seqs_set1.csv',
    's2': f'{base_path}/ann_of_1000_seqs_set2.csv',
}

# Process all files
dfs = {key: load_and_process_annotations(path) for key, path in files.items()}

# Extract and tokenize sequences for each dataset
results = {
    key: extract_and_tokenize_sequences(df, df_val, tokenizer_nt)
    for key, df in dfs.items()
}

# Unpack results if needed
tokens_s0, seq_ids_s0 = results['s0']
tokens_s1, seq_ids_s1 = results['s1']
tokens_s2, seq_ids_s2 = results['s2']

## From tokenised sequences create df of each token with annotation

### skip for N >= 1000

In [None]:
# Create a table that lists each token in the sequences alongside its annotation(s)

token_df = utils.make_token_df_new(
                      tokens = tokens_3rd['input_ids'].squeeze(),
                      tokenizer = tokenizer_nt,
                      df_annotated = df_annotated_3k_3rd,
                      seq_ids = seq_ids,
                      len_prefix = 2, ## choice: what should these be?
                      len_suffix = 2,
                      nucleotides_per_token = 6, # particular to this model
                      descriptor_col = 'Feature' # values: Feature, Type, Description
)
token_df

# save token_df
token_df.to_csv(f'/content/drive/MyDrive/SAEs_for_Genomics/Annotated_seqs/token_df_val_3k_3rd.csv', index=False)

### and load directly

In [9]:
# load token_df for >= 1000 seqs
token_df_1k_s1 = pd.read_csv('/content/drive/MyDrive/SAEs_for_Genomics/Annotated_seqs/token_df_1k_ss1.csv')
token_df_1k_s2 = pd.read_csv('/content/drive/MyDrive/SAEs_for_Genomics/Annotated_seqs/token_df_1k_ss2.csv')
token_df_1k_s0 = pd.read_csv('/content/drive/MyDrive/SAEs_for_Genomics/Annotated_seqs/token_df_1k_ss0.csv')


### Running SAE

Let's go and investigate a non rare feature.

We start by getting the SAE activations for (all) token in our dataset

In [10]:
from torch.cuda.amp import autocast
from tqdm import tqdm

d_model = cfg["d_model"]
d_mlp = cfg["d_mlp"]
num_layer = 11 # @param
batch_size = 26

tokens = tokens_s1 #@param options:

# Calculate batch information
total_tokens = tokens['input_ids'].shape[0] * tokens['input_ids'].shape[1]
num_batches = (total_tokens + batch_size - 1) // batch_size

all_latents = []
all_acts = []

# Ensure models are in eval mode
sae_model.eval()
model_nt.eval()

# Add progress bar
for i in tqdm(range(num_batches), desc="Processing batches", unit="batch"):
    start_idx = i * batch_size
    end_idx = min((i + 1) * batch_size, total_tokens)

    # Reshape tokens for current batch
    batch_input_ids = tokens['input_ids'][start_idx:end_idx].cuda()
    batch_attention_mask = tokens['attention_mask'][start_idx:end_idx].cuda()

    with torch.no_grad():
        #add mixed precision
        with autocast():
            # Get MLP activations
            mlp_act = utils.get_layer_activations(model_nt.cuda(),
                                                batch_input_ids,
                                                batch_attention_mask,
                                                layer_N=num_layer)

            mlp_act = mlp_act[0].reshape(-1, d_mlp)
            all_acts.append(mlp_act)

            # Forward pass through SAE
            loss, x_reconstruct, latents, l2_loss, nmse, l1_loss, true_l0 = sae_model(mlp_act)
            all_latents.append(latents)

# Combine results, move to cpu before
all_acts = torch.cat(all_acts, dim=0).cpu()
all_latents = [x.cpu() for x in all_latents]
combined_latents = torch.cat(all_latents, dim=0).cpu()
torch.cuda.empty_cache()

  with autocast():
Processing batches: 100%|██████████| 19535/19535 [04:05<00:00, 79.58batch/s]


In [17]:
latent_id = 1640 # @param or set particular int value in range 0, 4095

# we avoid modifying token_df directly as its very time-consuming to reload if we mess it up
token_df_copy = token_df_1k_s1.copy() # @param

# get the activation value for the N-th unit in the SAE for each input in batch
hidden_act_feature_id = combined_latents[:, latent_id] # N = feature_id

# add this to the dataframe
token_df_copy[f"latent-{latent_id}-act"] = hidden_act_feature_id.cpu().detach().numpy()

# sort to show the most activating tokens on top, add colours
token_df_copy.sort_values(f"latent-{latent_id}-act", ascending=False).head(500
                                                                           ).style.background_gradient("coolwarm")

Unnamed: 0,seq_id,token_pos,tokens,context,token_annotations,context_annotations,e-value annotation,percentage match,latent-1640-act
432542,12987,414,,||,['special token: '],[],,,200.200241
432532,12987,404,,||,['special token: '],[],,,185.950241
432541,12987,413,,||,['special token: '],[],,,185.200241
174379,4974,299,CGCCCT,GGTGGTTACGCGCAGCGTGACCGCTACACTTGCCAG |CGCCCT| AGCGCCCGCTCCTTTCGCTTTCTTCCCTTCCTTTCT,['f1 ori'],['f1 ori'],[0.],[100.],180.575241
432537,12987,409,,||,['special token: '],[],,,172.950241
432618,12987,490,,||,['special token: '],[],,,172.325241
432488,12987,360,,||,['special token: '],[],,,171.450241
432520,12987,392,,||,['special token: '],[],,,169.450241
432605,12987,477,,||,['special token: '],[],,,166.450241
174549,4974,469,CGCCCT,GGTGGTTACGCGCAGCGTGACCGCTACACTTGCCAG |CGCCCT| AGCGCCCGCTCCTTTCGCTTTCTTCCCTTCCTTTCT,['M13 ori'],['M13 ori'],[0.],[100.],166.450241


We can now sort and display the top tokens that activate the hidden SAE unit


In [14]:
# I want to plot latent_id (0, 4096) against max activation

max_act_per_latent = []
pbar = tqdm(range(4096))
for latent_id in pbar:
    pbar.set_description(f"Processing latent {latent_id}")
    # we avoid modifying token_df directly as its very time-consuming to reload if we mess it up
    token_df_copy = token_df_1k_s1.copy()

    # get the activation value for the N-th unit in the SAE for each input in batch
    hidden_act_feature_id = combined_latents[:, latent_id] # N = feature_id

    # add this to the dataframe
    token_df_copy[f"latent-{latent_id}-act"] = hidden_act_feature_id.cpu().detach().numpy()

    # get highest act value
    max_act_per_latent.append(token_df_copy[f"latent-{latent_id}-act"].max())



# now plot
import plotly.express as px

px.scatter(x=range(4096), y=max_act_per_latent)





Processing latent 4095: 100%|██████████| 4096/4096 [02:48<00:00, 24.34it/s]


## Auto-searching monosemantic latents

1. Searching *functional* SAE Latents

In [None]:
def safe_get_annotations(ann_entry):
    if isinstance(ann_entry, str):
        try:
            return eval(ann_entry)
        except:
            return []
    return ann_entry  # already a list

N_latents = 4096
latent_dict = {}
# Then modify the analysis:
for latent_id in range(N_latents):
    hidden_act_feature_id = combined_latents[:, latent_id]
    token_df_copy[f"latent-{latent_id}-act"] = hidden_act_feature_id.cpu().detach().numpy()

    most_activating_tokens = token_df_copy.sort_values(f"latent-{latent_id}-act", ascending=False).head(20)

    # Skip if any activations are 0
    if (most_activating_tokens[f"latent-{latent_id}-act"] == 0).any():
        continue

    annotations = [safe_get_annotations(ann) for ann in most_activating_tokens['token_annotations']]

    if annotations:
        annotation_counts = {}
        for ann_list in annotations:
            for ann in ann_list:
                annotation_counts[ann] = annotation_counts.get(ann, 0) + 1

        common_annotations = {ann for ann, count in annotation_counts.items()
                            if count >= 10 and ann not in {'special token: <cls>', 'special token: <pad>'}}

        if common_annotations:
            latent_dict[latent_id] = common_annotations
            print(f"\nLatent {latent_id} appears to detect: {common_annotations}")
            print("Top 20 activating tokens and their annotations:")
            for _, row in most_activating_tokens.iterrows():
                print(f"Token: {row['tokens']}, Annotations: {safe_get_annotations(row['token_annotations'])}, "
                      f"Activation: {row[f'latent-{latent_id}-act']:.3f}")

  token_df_copy[f"latent-{latent_id}-act"] = hidden_act_feature_id.cpu().detach().numpy()
  token_df_copy[f"latent-{latent_id}-act"] = hidden_act_feature_id.cpu().detach().numpy()
  token_df_copy[f"latent-{latent_id}-act"] = hidden_act_feature_id.cpu().detach().numpy()
  token_df_copy[f"latent-{latent_id}-act"] = hidden_act_feature_id.cpu().detach().numpy()
  token_df_copy[f"latent-{latent_id}-act"] = hidden_act_feature_id.cpu().detach().numpy()
  token_df_copy[f"latent-{latent_id}-act"] = hidden_act_feature_id.cpu().detach().numpy()
  token_df_copy[f"latent-{latent_id}-act"] = hidden_act_feature_id.cpu().detach().numpy()
  token_df_copy[f"latent-{latent_id}-act"] = hidden_act_feature_id.cpu().detach().numpy()
  token_df_copy[f"latent-{latent_id}-act"] = hidden_act_feature_id.cpu().detach().numpy()
  token_df_copy[f"latent-{latent_id}-act"] = hidden_act_feature_id.cpu().detach().numpy()
  token_df_copy[f"latent-{latent_id}-act"] = hidden_act_feature_id.cpu().detach().numpy()
  token_df

KeyboardInterrupt: 

2. Searching *syntactic* SAE latents

In [None]:
import pandas as pd
from tqdm import tqdm

## here we create a short list of candidate monosemantic latents for **kmers** by looking at the top-50 most activating tokens
## and asking: do at least half of them share a kmer?

def analyze_latent_features_fast(token_df, combined_latents, k=4, n_latents=4096, top_n=10):
    """Optimized version of latent feature analysis"""

    def get_kmers(token, k):
        if not isinstance(token, str) or k <= 0:
            return set()
        token = token.strip()
        if not token or len(token) < k:
            return set()
        return {token[i:i+k] for i in range(len(token) - k + 1)}

    # Convert combined_latents to numpy once
    activations_array = combined_latents.cpu().detach().numpy()

    # Pre-compute valid tokens and their k-mers
    valid_tokens = token_df['tokens'].dropna()
    token_to_kmers = {token: get_kmers(str(token), k) for token in valid_tokens}

    latent_dict = {}
    tokens_array = token_df['tokens'].values

    # Process in batches for better memory usage
    batch_size = 100
    for batch_start in tqdm(range(0, n_latents, batch_size)):
        batch_end = min(batch_start + batch_size, n_latents)

        # Process batch of latents
        batch_activations = activations_array[:, batch_start:batch_end]

        # Find top_n indices for each latent in batch
        top_indices = np.argpartition(-batch_activations, top_n, axis=0)[:top_n]

        # Process each latent in batch
        for i, latent_id in enumerate(range(batch_start, batch_end)):
            # Get tokens for top activations
            top_tokens = tokens_array[top_indices[:, i]]

            # Get k-mer sets for valid tokens
            kmer_sets = [token_to_kmers[token] for token in top_tokens
                        if pd.notna(token) and token in token_to_kmers]

            if kmer_sets:
                common_kmers = set.intersection(*kmer_sets)
                if common_kmers:
                    latent_dict[latent_id] = common_kmers
                    print(f"\nLatent {latent_id} appears to detect: {common_kmers}")

    return latent_dict

kmer_latent_dict_ = analyze_latent_features_fast(token_df, combined_latents, k=4, n_latents=4096, top_n=50) ## set n_latents to 100 to quickly test

In [None]:
#save latent dict as csv file
import pandas as pd

df = pd.DataFrame(list(latent_dict.items()), columns=['latent_id', 'annotation'])
df

# save
df.to_csv('/content/drive/MyDrive/SAEs_for_Genomics/Latent_dict_func_monosem_nt50m_sae_l10_+40mtokens.csv', index=False)

In [None]:
dict_values = latent_dict.values()

# turn list of sets into one large set
flat_set = set.union(*dict_values)
print(flat_set)

## Auto-Searching of MLP

1. Of the MLP neurons are there any that are somewhat monosemantic for a functional annotation?

In [None]:
for latent_id in range(512):
    # we avoid modifying token_df directly as its very time-consuming to reload if we mess it up
    token_df_copy = token_df.copy()

    # get the activation value for the N-th unit in the SAE for each input in batch
    hidden_act_feature_id = mlp_act[:, latent_id] # N = feature_id

    # add this to the dataframe
    token_df_copy[f"latent-{latent_id}-act"] = hidden_act_feature_id.cpu().detach().numpy()

    # print latent_id if the five most activating tokens share a token annotation
    most_activating_tokens = token_df_copy.sort_values(f"latent-{latent_id}-act", ascending=False).head(5)

    # Get annotations as lists
    annotations = most_activating_tokens['token_annotations'].tolist()

    # Check if there's any intersection between all annotation lists
    if annotations:
        # Convert all annotations to sets for intersection
        annotation_sets = [set(ann) for ann in annotations]
        common_annotations = set.intersection(*annotation_sets)
        filtered_annotations = common_annotations - {'special token: <cls>', 'special token: <pad>'}

        if filtered_annotations:  # If there are any shared annotations
            print(f"\nLatent {latent_id} appears to detect: {common_annotations}")
            print("Top 5 activating tokens and their annotations:")
            for _, row in most_activating_tokens.iterrows():
                print(f"Token: {row['tokens']}, Annotations: {row['token_annotations']}, "
                      f"Activation: {row[f'latent-{latent_id}-act']:.3f}")


2. Of the MLP neurons are there any that are somewhat monosemantic for some *syntactic* pattern?

In [None]:
k = 4 ## kmer length
latent_dict = {}

for latent_id in range(512):
    # we avoid modifying token_df directly as its very time-consuming to reload if we mess it up
    token_df_copy = token_df.copy()

    # get the activation value for the N-th unit in the SAE for each input in batch
    hidden_act_feature_id = combined_acts[:, latent_id] # N = feature_id

    # add this to the dataframe
    token_df_copy[f"latent-{latent_id}-act"] = hidden_act_feature_id.cpu().detach().numpy()

    # print latent_id if the five most activating tokens share a token annotation
    most_activating_tokens = token_df_copy.sort_values(f"latent-{latent_id}-act", ascending=False).head(10)


    # Get a set of all kmers for each most activating token
    def get_kmers(token, k):
        if not isinstance(token, str) or k <= 0:
            raise ValueError("Invalid input: token must be string and k must be positive")
        if len(token) < k:
            return set()
        return {token[i:i+k] for i in range(len(token) - k + 1)}

    kmer_sets = [get_kmers(token, k) for token in most_activating_tokens['tokens']]

    # Check if there's any intersection between all kmer sets stored
    if kmer_sets:
        common_kmers = set.intersection(*kmer_sets)
        if common_kmers:  # If there are any shared kmers
            latent_dict[latent_id] = common_kmers
            print(f"\nLatent {latent_id} appears to detect: {common_kmers}")
            print("Top 5 activating tokens and their annotations:")
            for _, row in most_activating_tokens.iterrows():
                print(f"Token: {row['tokens']}, Annotations: {row['token_annotations']}, "
                      f"Activation: {row[f'latent-{latent_id}-act']:.3f}")



In [None]:
latent_id = 188  #np.random.randint(0, 4096) # or set particular int value in this range


# we avoid modifying token_df directly as its very time-consuming to reload if we mess it up
token_df_copy = token_df.copy()

# get the activation value for the N-th unit in the SAE for each input in batch
hidden_act_feature_id = combined_acts[:, latent_id] # N = feature_id

# add this to the dataframe
token_df_copy[f"latent-{latent_id}-act"] = hidden_act_feature_id.cpu().detach().numpy()

# sort to show the most activating tokens on top, add colours
token_df_copy.sort_values(f"latent-{latent_id}-act", ascending=False).head(300).style.background_gradient("coolwarm")


## Calc sensitivity and specificity of SAE latent for Functional or Syntactic Feature

In [None]:
def contains_kmers(tokens: str, kmers: list) -> bool:
    """Check if a token sequence contains any of the kmers in the list"""
    if not isinstance(tokens, str):
        return False
    return any(k in tokens for k in kmers)

def contains_annotations(token_annotation: str, annotations: list) -> bool:
    """Check if a token sequence contains any of the given annotations"""
    if not isinstance(token_annotation, str):
        return False
    return any(annotation in token_annotation for annotation in annotations)

def calculate_stats(df, act_threshold, meaning, check: str):
    """Calculate various statistics about TAG tokens and activations"""

    # Create a function that's partially applied with the specific annotation
    if check == 'kmer':
      check_kmer = lambda x: contains_kmers(x, meaning)
      check_fn = check_kmer
      col = 'tokens'

    elif check == 'annotation':
      check_annotation = lambda x: contains_annotations(x, meaning)
      check_fn = check_annotation
      col = 'token_annotations'

    else: raise ValueError("check must be 'kmer' or 'annotation'")

    # Q1: Fraction of above-threshold activations containing TAG
    above_threshold = df[df[latent_column] > act_threshold]
    above_threshold_tag_fraction = above_threshold[col].apply(check_fn).mean()

    # Q2: Fraction of below-or-equal threshold activations containing TAG
    below_threshold = df[df[latent_column] <= act_threshold]
    below_threshold_tag_fraction = below_threshold[col].apply(check_fn).mean()

    # Q3: Overall fraction of rows containing TAG
    overall_tag_fraction = df[col].apply(check_fn).mean()

    # Q4: Fraction of tokens with positive activation
    positive_activation_fraction = (df[latent_column] > act_threshold).mean()

    # Q5: For rows containing TAG, fraction with positive activation
    tag_rows = df[df[col].apply(check_fn)]
    tag_positive_fraction = (tag_rows[latent_column] > act_threshold).mean()

    return {
        'above_threshold_tag': above_threshold_tag_fraction,
        'below_threshold_tag': below_threshold_tag_fraction,
        'overall_tag': overall_tag_fraction,
        'positive_activation': positive_activation_fraction,
        'tag_positive': tag_positive_fraction
    }

def find_largest_consecutive_tag_sequence(df):
    """Find largest N where top-N rows all contain TAG"""
    sorted_df = df.sort_values(latent_column, ascending=False)
    check_cmv = lambda x: contains_annotations(x, ['CMV enhancer', 'CMV promoter', 'CMV IE94 promoter'])

    for N in range(1, len(df) + 1):
        top_n = sorted_df.head(N)
        if not all(top_n['token_annotations'].apply(check_cmv)):
            return N - 1
    return len(df)

# create empty pd df with column for latent_id, annotation and evidence_for_act_from_tag
columns = ['latent_id', 'annotation', 'evidence_for_act_from_ann', 'evidence_for_ann_from_act', 'precision', 'recall']
df = pd.DataFrame(columns=columns)

using_kmer = False
using_annotation = not using_kmer

# Iterate over latent dict
for latent_id, meaning in latent_dict.items():

    # Calculate all statistics
    act_threshold = 0.0  # Activation threshold
    latent_column = f"latent-{latent_id}-act"


    # get most activating tokens for latent id
    token_df_copy = token_df.copy()
    hidden_act_feature_id = combined_latents_new[:, latent_id] # N = feature_id
    token_df_copy[f"latent-{latent_id}-act"] = hidden_act_feature_id.cpu().detach().numpy()
    token_df_copy.sort_values(f"latent-{latent_id}-act", ascending=False).head(300).style.background_gradient("coolwarm")

    ### input all to calc stats
    if using_annotation:
      annotation = list(meaning)
      stats = calculate_stats(token_df_copy, act_threshold, meaning = annotation, check = 'annotation')

    elif using_kmer:
      kmer_strings = [''.join(kmer) for kmer in meaning]
      stats = calculate_stats(token_df_copy, act_threshold, meaning = kmer_strings, check = 'kmer')



    ## Calculate the posterior to prior odds ratios
    evidence_for_act_from_tag = (stats['tag_positive']/(1-stats['tag_positive'])) / (stats['positive_activation']/(1-stats['positive_activation']))
    evidence_for_tag_from_act = (stats['above_threshold_tag']/(1-stats['above_threshold_tag'])) / (stats['overall_tag']/(1-stats['overall_tag']))


    # only print at least moderately-monosemantic latents
    if min(evidence_for_act_from_tag, evidence_for_tag_from_act) > 20 or max(evidence_for_act_from_tag, evidence_for_tag_from_act)>200:

        print(f"\nLatent {latent_id} appears to detect: {meaning}")

        print(f"Strength of evidence for act > {act_threshold} from {meaning} (as BayesF): {evidence_for_act_from_tag:.3f}")
        print(f"Strength of evidence for {meaning} from act > {act_threshold} (as BayesF): {evidence_for_tag_from_act:.3f}")



            # Print results in a formatted way
        print(f"\n {meaning} Token Analysis Results")
        print("=" * 50)
        print(f"Analysis for activation threshold: {act_threshold}")
        print("-" * 50)
        print(f" P(token annotated with {meaning}):                      {stats['overall_tag']:.3f}")
        print(f" P(token annotated with {meaning}|activation > {act_threshold}):     {stats['above_threshold_tag']:.3f}")
        print(f" P(activation > {act_threshold}):                        {stats['positive_activation']:.3f}")
        print(f" P(activation > {act_threshold}|token annotated with {meaning}):     {stats['tag_positive']:.3f}")


        # add to df
        df.loc[len(df)] = [latent_id, meaning, evidence_for_act_from_tag, evidence_for_tag_from_act, stats['above_threshold_tag'], stats['tag_positive']]

        # Find and print largest consecutive sequence0
        #largest_n = find_largest_consecutive_tag_sequence(token_df_copy)
        #print("-" * 50)
        #print(f"Largest N where top-N rows all contain {annotation}: {largest_n}")

# save df
df.to_csv('/content/drive/MyDrive/SAEs_for_Genomics/Latent_dict_4MER_monosem_nt50m_sae_+40mtokens.csv', index=False)

  evidence_for_tag_from_act = (stats['above_threshold_tag']/(1-stats['above_threshold_tag'])) / (stats['overall_tag']/(1-stats['overall_tag']))
  evidence_for_tag_from_act = (stats['above_threshold_tag']/(1-stats['above_threshold_tag'])) / (stats['overall_tag']/(1-stats['overall_tag']))
  evidence_for_tag_from_act = (stats['above_threshold_tag']/(1-stats['above_threshold_tag'])) / (stats['overall_tag']/(1-stats['overall_tag']))
  evidence_for_tag_from_act = (stats['above_threshold_tag']/(1-stats['above_threshold_tag'])) / (stats['overall_tag']/(1-stats['overall_tag']))
