In [1]:
import torch
from transformers import AutoTokenizer, PreTrainedTokenizerFast
from dictionary_learning import AutoEncoder
from nnsight import LanguageModel
from dictionary_learning.utils import read_csv

In [2]:
# load autoencoder
ae = AutoEncoder.from_pretrained("/gpfs/helios/home/jpauklin/dictionary_learning/saes/trainer_0/ae.pt").to("cuda") # to is rquired to load to GPU

device = "cuda:0"
# Load Model
model_name = "/gpfs/space/projects/stacc_health/gpt2_model/estMed-gpt2_fine_tuned4/estMed-gpt2_fine_tuned4"
model = LanguageModel(
    model_name,
    device_map=device,
)

activation_dim = model.transformer.h[0].ln_1.normalized_shape[0] # output dimension of the MLP = 768

  state_dict = t.load(path)


In [3]:
data = read_csv("/gpfs/space/projects/stacc_health/data-synthetic/100k_synthetic_texts.csv", 2) # (csv_path, nr_of_text_batches_to_read)
text = list(data) # list from iterator

In [4]:
tokenizer = PreTrainedTokenizerFast(tokenizer_file="/gpfs/space/projects/stacc_health/gpt2_model/estMed-gpt2_fine_tuned4/estMed-gpt2_fine_tuned4/tokenizer.json")

if tokenizer.pad_token is None: # The tokenizer is missing its padding token
    tokenizer.pad_token = "<pad>"

# Padding and truncating to have batched tensors with the same length.    
tokens = tokenizer(text, return_tensors="pt", padding=True, truncation=True)  


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [5]:
# Using nnsight, we hook into one of the MLP's layer and inspect them during a forward pass.
# This gets us the layer's activations (output of c_proj)

# The "with" keyword starts a context-like object. It defines that we will run the model with tracing. 
# The model is actually run upon exiting the tracing context. (https://nnsight.net/notebooks/tutorials/walkthrough/)
with model.trace(tokens['input_ids'], tokens.get('attention_mask')) as tracer: # ?? attention_mask
    
    # Selecting a specific layer to capture
    mlp_output = model.transformer.h[11].mlp.c_proj.output.save()

# After tracing, get the activations
activations = mlp_output.value  # shape: (batch_size aka nr_of_inputs, seq_len , hidden_dim_activation_value)

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


In [7]:
features = ae.encode(activations) # get features from activations
reconstructed_activations = ae.decode(features)

In [26]:
# features shape: (batch_size, seq_len, num_features)
batch_idx = 0
feature_activations_for_seq = features[batch_idx]  # shape: (seq_len, num_features)

# Get corresponding tokens
input_tokens = tokens['input_ids'][batch_idx]
decoded_tokens = tokenizer.convert_ids_to_tokens(input_tokens)

# For each token
for token_idx, token in enumerate(decoded_tokens):
    token_clean = token.replace("Ġ", "")
    print(f"Token: {token_clean}")
    feature_acts = feature_activations_for_seq[token_idx]
    
    # Show top 5 most active features for token
    topk = torch.topk(feature_acts, k=5)
    print(f"  Top features: {topk.indices.tolist()} with activations {topk.values.tolist()}")

Token: NAME
  Top features: [3401, 10404, 6611, 11128, 6052] with activations [28.05147361755371, 19.718608856201172, 8.891927719116211, 8.346430778503418, 7.992448806762695]
Token: :
  Top features: [3401, 8914, 10404, 2492, 3155] with activations [2.2227439880371094, 2.0554213523864746, 2.0495235919952393, 1.965754508972168, 1.9163177013397217]
Token: lÃ¤heb
  Top features: [2492, 6017, 1315, 8914, 6908] with activations [1.4795644283294678, 1.4043340682983398, 1.3195750713348389, 1.0100703239440918, 0.9864257574081421]
Token: pik
  Top features: [2492, 3155, 1880, 1315, 8914] with activations [2.2087817192077637, 1.1387423276901245, 1.0914673805236816, 1.0670545101165771, 1.03082275390625]
Token: ale
  Top features: [2492, 3155, 8914, 5866, 4271] with activations [2.407727003097534, 1.7267181873321533, 1.608194351196289, 1.3907339572906494, 1.2961288690567017]
Token: lennureisi
  Top features: [10684, 2492, 1315, 8112, 10125] with activations [1.5269229412078857, 1.3731985092163086,

In [51]:
# given a specific feature and a batch of input text, outputs the feature's activation value for each token.

feature_nr = 2492

# features shape: (batch_size, seq_len, num_features)
batch_idx = 0
feature_activations_for_seq = features[batch_idx]  # shape: (seq_len, num_features)

# Get corresponding tokens
input_tokens = tokens['input_ids'][batch_idx]
decoded_tokens = tokenizer.convert_ids_to_tokens(input_tokens)

# Print given feature's activation value for each token
print(f"Feature: {feature_nr}")
for token_idx, token in enumerate(decoded_tokens):
    token_clean = token.replace("Ġ", "")
    feature_acts = feature_activations_for_seq[token_idx]
    print(f"{token_clean: >20} = {feature_acts[feature_nr]}")

Feature: 2492
                NAME = 3.4449825286865234
                   : = 1.965754508972168
              lÃ¤heb = 1.4795644283294678
                 pik = 2.2087817192077637
                 ale = 2.407727003097534
          lennureisi = 1.3731985092163086
                  le = 0.6448439359664917
                   . = 1.8179036378860474
                seal = 1.0401058197021484
              tunneb = 1.4472966194152832
                   , = 1.1269505023956299
                  et = 1.7242414951324463
                  ja = 1.5632731914520264
               bajal = 1.4519038200378418
                  ad = 1.2593913078308105
           valutavad = 1.2101409435272217
                   , = 1.3528906106948853
                kurk = 1.1858915090560913
                  on = 1.703606128692627
                kuiv = 1.4350690841674805
               kÃ¶ha = 0.08191561698913574
                   . = 0.8507479429244995
                 obj = 1.2404780387878418
                   . =

In [40]:
from IPython.display import display, HTML
import matplotlib.pyplot as plt
import matplotlib

def visualize_feature_activations(tokens, activations, feature_idx):
    """
    Visualize how strongly a particular feature is activated across tokens.
    
    Parameters:
        tokens (List[str]): List of tokens (strings).
        activations (Tensor): Tensor of shape (seq_len, num_features).
        feature_idx (int): Index of the feature to visualize.
    """
    # Get activations for the specified feature
    # : all on first axis, specific feature on 2nd axis. detach from computation graph. Move to CPU memory, since numpy only works on CPU tensors.
    feature_activations = activations[:, feature_idx].detach().cpu().numpy()

    # Color mapping
    max_val = feature_activations.max()
    min_val = feature_activations.min()
    #print(max_val, min_val)
    norm = matplotlib.colors.Normalize(vmin=min_val, vmax=max_val) # LogNorm
    cmap = matplotlib.colormaps.get_cmap('Oranges')

    # Build HTML with color intensity
    html = ""
    for token, activation in zip(tokens, feature_activations):
        token_clean = token.replace("Ġ", "")
        color = matplotlib.colors.rgb2hex(cmap(norm(activation)))
        html += f'<span style="background-color:{color}; padding: 2px; margin: 2px; border-radius: 3px;">{token_clean}</span> '

    display(HTML(html))

In [42]:
# tokens: List of decoded tokens (str), e.g., from tokenizer.convert_ids_to_tokens
# feature_activations_for_seq: shape (seq_len, num_features)

visualize_feature_activations(
    tokens=decoded_tokens,
    activations=feature_activations_for_seq,  # Tensor [seq_len, num_features]
    feature_idx=2429  # Choose a specific feature index
)

0.38095915 0.0
