In [8]:
import torch
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
from transformers import AutoTokenizer, PreTrainedTokenizerFast
from dictionary_learning import AutoEncoder
from nnsight import LanguageModel
from dictionary_learning.utils import read_csv
from IPython.display import display, HTML

In [9]:
import scipy.sparse
from scipy.sparse import csr_matrix
from dictionary_learning.sparse_feature_writer import SparseFeatureWriter

import gc
import h5py
from collections import defaultdict

In [10]:
# Single device string shared by the autoencoder and the language model, so
# that both are placed on the same GPU.
device = "cuda:0"  # GPU

# Load the trained sparse autoencoder and move it onto the GPU
# (the explicit `.to(...)` is required for the weights to end up on the GPU).
ae = AutoEncoder.from_pretrained("/gpfs/helios/home/jpauklin/dictionary_learning/saes/estMedSae120425/trainer_0/ae_12042025.pt").to(device)

# Load the fine-tuned GPT-2 model wrapped in nnsight's LanguageModel so that
# intermediate activations can be captured with `model.trace(...)`.
model_name = "/gpfs/space/projects/stacc_health/gpt2_model/estMed-gpt2_fine_tuned4/estMed-gpt2_fine_tuned4"
model = LanguageModel(
    model_name,
    device_map=device,
)

# Size of the first block's LayerNorm (`ln_1`), read from its normalized_shape.
# NOTE(review): the original comment equates this with the MLP output
# dimension (768) — confirm against the model config.
activation_dim = model.transformer.h[0].ln_1.normalized_shape[0]

  state_dict = t.load(path)


In [11]:
# Read text batches from the synthetic-texts CSV.
# read_csv arguments: (csv_path, nr_of_text_batches_to_read) — here 2 batches.
data = read_csv("/gpfs/space/projects/stacc_health/data-synthetic/100k_synthetic_texts.csv", 2)
# read_csv returns an iterator; materialise it into a list so the texts can be
# passed to the tokenizer as one batch.
text = list(data)

In [12]:
#text.append("See on test lause, mis on käsitsi lisatud.")

In [13]:
# Build a fast tokenizer directly from the fine-tuned model's tokenizer.json.
tokenizer_filepath = "/gpfs/space/projects/stacc_health/gpt2_model/estMed-gpt2_fine_tuned4/estMed-gpt2_fine_tuned4/tokenizer.json"
tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_filepath)
# A pad token has to be defined before `padding=True` can be used below.
# NOTE(review): assumes "<pad>" is present in the tokenizer vocabulary — confirm.
tokenizer.pad_token = "<pad>"

# Pad every text to the length of the longest one so the batch forms
# rectangular tensors.
# NOTE(review): truncation is left disabled, so over-long texts are not
# shortened — confirm that no text exceeds the model's context window.
tokens = tokenizer(text, return_tensors="pt", padding=True)   # , truncation=True


In [14]:
# Using nnsight, hook into one MLP layer and capture its output during a
# forward pass. This yields the layer's activations (output of c_proj).

# Pass input_ids and attention_mask together as ONE input mapping.
# Supplying the mask as a second positional argument does not register it as
# an attention mask in nnsight's LanguageModel.
# NOTE(review): the "`loss_type=None` ..." warning this cell used to emit
# suggests the mask was being treated as `labels` — confirm against the
# installed nnsight version.
trace_inputs = {'input_ids': tokens['input_ids']}
attention_mask = tokens.get('attention_mask')
if attention_mask is not None:
    trace_inputs['attention_mask'] = attention_mask

# The "with" block only records what to capture; the model is actually run
# upon exiting the tracing context.
# (https://nnsight.net/notebooks/tutorials/walkthrough/)
with model.trace(trace_inputs) as tracer:

    # Select the specific layer to capture: last block's MLP output projection.
    mlp_output = model.transformer.h[11].mlp.c_proj.output.save()

# shape: (batch_size aka nr_of_inputs, seq_len, hidden_dim_activation_value)
activations = mlp_output.value

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


In [15]:
# Encode the captured MLP activations into sparse-autoencoder features.
# NOTE(review): presumably shaped (batch_size, seq_len, num_features), which is
# how the helper functions below index it — confirm.
features = ae.encode(activations) # get features from activations
# Decode the features back into activation space (the autoencoder's reconstruction).
reconstructed_activations = ae.decode(features)

In [16]:
def get_batch(batch_idx, token_batch_size=462, path="features/featuresWithTokens.h5"):
    """Load the sparse feature activations and tokens of one batch from HDF5.

    The file stores a CSR matrix as the datasets 'data', 'indices' and
    'indptr' plus a 'shape' attribute, and a 'tokens' dataset. Only the
    slice belonging to the requested batch is read from disk instead of
    materialising the whole matrix.

    NOTE(review): assumes every CSR row is one token position, that batches
    are consecutive windows of `token_batch_size` rows, and that 'tokens'
    has one entry per CSR row — confirm against SparseFeatureWriter.

    Args:
        batch_idx: zero-based index of the batch to load.
        token_batch_size: number of rows (tokens) per batch.
        path: location of the HDF5 file.

    Returns:
        (feature_activations_for_seq, tokens_h5): a scipy CSR matrix with the
        rows of the batch (the last batch may hold fewer than
        `token_batch_size` rows) and the matching slice of the 'tokens'
        dataset.

    Raises:
        IndexError: if `batch_idx` is negative or lies past the stored rows.
    """
    if batch_idx < 0:
        raise IndexError(f"batch_idx must be non-negative, got {batch_idx}")

    start = batch_idx * token_batch_size

    with h5py.File(path, 'r') as f:
        n_rows, n_cols = (int(dim) for dim in f.attrs['shape'])
        if start >= n_rows:
            raise IndexError(
                f"batch_idx {batch_idx} is out of range for {n_rows} stored rows"
            )
        stop = min(start + token_batch_size, n_rows)

        # indptr[r]:indptr[r + 1] delimits row r inside data/indices, so the
        # window needs stop - start + 1 pointers.
        indptr = f['indptr'][start:stop + 1]
        lo, hi = int(indptr[0]), int(indptr[-1])

        # Re-base the row pointers so they index into the sliced arrays.
        feature_activations_for_seq = csr_matrix(
            (f['data'][lo:hi], f['indices'][lo:hi], indptr - lo),
            shape=(stop - start, n_cols),
        )
        tokens_h5 = f['tokens'][start:stop]

    return feature_activations_for_seq, tokens_h5

In [17]:

def token_top_features(tokens, features, batch_idx, k=5):
    """Print, for every token of one input text, its most active features.

    Args:
        tokens: mapping whose 'input_ids' entry is indexed by `batch_idx`
            to obtain the token ids of one text.
        features: feature activations, indexed as
            features[batch_idx][token_idx] -> 1-D tensor of activations
            (i.e. shape (batch_size, seq_len, num_features)).
        batch_idx: index of the text inside the batch.
        k: how many top features to show per token (default 5). Clamped to
            the number of available features so `torch.topk` cannot raise.

    Relies on the notebook-level `tokenizer` to decode token ids.
    """
    feature_activations_for_seq = features[batch_idx]

    # Token ids of the selected text.
    input_tokens = tokens['input_ids'][batch_idx]

    for token_idx, token in enumerate(input_tokens):
        decoded_token = tokenizer.decode(token)
        print(f"Token: {decoded_token}")
        feature_acts = feature_activations_for_seq[token_idx]

        # Show the top-k most active features for this token.
        topk = torch.topk(feature_acts, k=min(k, feature_acts.shape[-1]))
        print(f"  Top features: {topk.indices.tolist()} with activations {topk.values.tolist()}")

In [18]:
# Print the most active features for every token of the first text (batch index 0).
token_top_features(tokens, features, 0)

Token: NAME
  Top features: [11257, 8048, 10334, 8871, 2872] with activations [21.490509033203125, 21.448265075683594, 11.61661148071289, 7.110403537750244, 6.928773880004883]
Token:  :
  Top features: [5011, 8266, 8871, 6319, 8048] with activations [3.0666751861572266, 2.015342950820923, 1.3793467283248901, 1.2618141174316406, 0.8697214722633362]
Token:  läheb
  Top features: [9454, 2791, 11426, 5903, 8594] with activations [1.7806084156036377, 1.030806064605713, 0.7869167923927307, 0.6623857021331787, 0.5622596740722656]
Token:  pik
  Top features: [11149, 5903, 2378, 10450, 5543] with activations [0.8619814515113831, 0.6991864442825317, 0.6093074083328247, 0.5957626104354858, 0.5683003664016724]
Token: ale
  Top features: [5903, 1937, 10871, 11426, 6924] with activations [1.3473151922225952, 0.7340749502182007, 0.6217766404151917, 0.6039524078369141, 0.6014381051063538]
Token:  lennureisi
  Top features: [1937, 5903, 2422, 5499, 5167] with activations [2.7617714405059814, 1.93541359

In [20]:
# given a specific feature and a batch of input text, outputs the feature's activation value for each token.

def feature_activations(tokens, features, batch_idx, feature_idx):
    """Print one feature's activation value next to every token of a text.

    Args:
        tokens: mapping whose 'input_ids' entry is indexed by `batch_idx`.
        features: activations indexed as
            features[batch_idx][token_idx][feature_idx]
            (shape (batch_size, seq_len, num_features)).
        batch_idx: which text of the batch to inspect.
        feature_idx: which feature to report.

    Decoding uses the notebook-level `tokenizer`.
    """
    # Activations of the chosen text: (seq_len, num_features).
    per_token_features = features[batch_idx]

    # Token ids of the same text.
    token_ids = tokens['input_ids'][batch_idx]

    print(f"Feature: {feature_idx}")
    for position in range(len(token_ids)):
        label = tokenizer.decode(token_ids[position])
        value = per_token_features[position][feature_idx]
        # Right-align the token in a 20-character column.
        print(f"{label: >20} = {value}")

In [31]:
# Print feature 7417's activation on every token of the first text (batch index 0).
feature_activations(tokens, features, 0, 7417)

Feature: 7417
                NAME = 0.0
                   : = 0.0
               läheb = 0.0
                 pik = 0.0
                 ale = 0.11386876553297043
          lennureisi = 0.014625731855630875
                  le = 0.048650261014699936
                   . = 0.0
                seal = 0.0
              tunneb = 0.0
                   , = 0.0
                  et = 0.0
                  ja = 0.15995392203330994
               bajal = 0.15718324482440948
                  ad = 0.0
           valutavad = 0.4343339800834656
                   , = 0.0
                kurk = 0.018008112907409668
                  on = 0.0
                kuiv = 0.0
                köha = 0.015588922426104546
                   . = 0.0
                 obj = 0.05980239436030388
                   . = 0.17642372846603394
                kops = 0.0
               puhas = 0.3031104803085327
                   , = 0.0
                neel = 0.08607755601406097
             punetav = 0.0
         

In [22]:
def visualize_feature_activations(tokens, features, batch_idx, feature_idx):
    """Render one text as HTML with tokens shaded by a feature's activation.

    Each token becomes a <span> whose background colour comes from the
    'Oranges' colormap: the stronger the activation, the darker the shade.

    Args:
        tokens: mapping with 'input_ids' and, optionally, 'attention_mask',
            both indexed by `batch_idx`.
        features: tensor indexed as features[batch_idx, :, feature_idx]
            (shape (batch_size, seq_len, num_features)).
        batch_idx: which text of the batch to render.
        feature_idx: which feature drives the colouring.

    Positions whose attention mask is 0 (padding) are left out. When no
    'attention_mask' is present every token is rendered. Decoded tokens are
    HTML-escaped so that text such as "<pad>" or "&" is shown literally
    instead of being interpreted as markup.

    Relies on the notebook-level `tokenizer`, `matplotlib`, `display` and
    `HTML`.
    """
    # Local import: the notebook's import cells do not bring in html.escape.
    from html import escape

    # Detach from the computation graph and move to CPU memory, since numpy
    # only works on CPU tensors.
    activations = features[batch_idx, :, feature_idx].detach().cpu().numpy()

    # Token ids (and padding mask, if available) of the selected text.
    input_tokens = tokens['input_ids'][batch_idx]
    attention_mask = tokens['attention_mask'][batch_idx] if 'attention_mask' in tokens else None

    # Colour mapping: a cube-root scale spreads out small activations.
    max_val = 30
    min_val = 0
    forward = lambda x: x ** (1/3)  # transforms data values before the colormap is applied
    inverse = lambda x: x ** 3  # inverse of `forward`, required by FuncNorm
    # (https://matplotlib.org/stable/users/explain/colors/colormapnorms.html)
    norm = matplotlib.colors.FuncNorm((forward, inverse), vmin=min_val, vmax=max_val)
    cmap = matplotlib.colormaps.get_cmap('Oranges')

    # Build the HTML with colour highlighting, one span per real token.
    spans = []
    for position, (token, activation) in enumerate(zip(input_tokens, activations)):
        # Do not emit a span for padding tokens.
        if attention_mask is not None and not bool(attention_mask[position]):
            continue

        decoded_token = escape(tokenizer.decode(token))
        color = matplotlib.colors.rgb2hex(cmap(norm(activation)))
        spans.append(
            f'<span style="background-color:{color}; padding: 2px; margin: 2px; border-radius: 3px;">{decoded_token}</span> '
        )

    display(HTML("".join(spans)))

In [32]:
# Render the first text (batch index 0) with tokens shaded by feature 7417's activation.
visualize_feature_activations(tokens, features, 0, 7417)