In [18]:
from transformer_lens import (HookedTransformer, utils)
from transformer_lens.hook_points import HookPoint
import functools
import torch

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from torch import Tensor
from torch.nn import functional as F
device = 'cuda:7' if torch.cuda.is_available() else 'cpu'
from transformers import PatchTSTForPrediction
from transformers.models.patchtst.modeling_patchtst import (
    PatchTSTForPredictionOutput
)
from data_loader import *
import pandas as pd

import torch
import os

import plotly.express as px
from sae_lens import (
    SAE,
    upload_saes_to_huggingface,
    LanguageModelSAERunnerConfig,
    TimeSeriesModelSAERunnerConfig,
    TimeSeriesModelSAETrainingRunner,
    SAETrainingRunner,
    StandardTrainingSAEConfig,
    LoggingConfig,
    HookedSAETransformer,
    ActivationsStore,
    run_evals,
)
import json

from sae_lens.evals import EvalConfig
from sae_lens.util import extract_stop_at_layer_from_tlens_hook_name
from sae_lens.training.activation_scaler import ActivationScaler
from datasets import load_dataset

In [3]:
gpt2: HookedSAETransformer = HookedSAETransformer.from_pretrained("gpt2-small", device=device)

gpt2_sae = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.7.hook_resid_pre",
    device=str(device),
)

Loaded pretrained model gpt2-small into HookedTransformer


In [17]:
# prompt = "Mitigating the risk of extinction from AI should be a global"
prompt = "The New York City police union generally doesn't provide lawyers for law enforcement charged with crimes not associated with their"
print(gpt2.to_str_tokens(prompt))
hook_name = "blocks.7.hook_resid_pre"
sae_hook_name = f"{hook_name}.hook_sae_acts_post"
logits, cache = gpt2.run_with_cache_with_saes(prompt, saes=[gpt2_sae])
print(cache[sae_hook_name][0, gpt2.to_str_tokens(prompt).index(" law"), 1])

top_prob, token_id_prediction = logits[0, -1].softmax(-1).max(-1)
print(f"Top prediction {top_prob.item():.2%} |{gpt2.to_string(token_id_prediction)}|")

['<|endoftext|>', 'The', ' New', ' York', ' City', ' police', ' union', ' generally', ' doesn', "'t", ' provide', ' lawyers', ' for', ' law', ' enforcement', ' charged', ' with', ' crimes', ' not', ' associated', ' with', ' their']
tensor(52.7983, device='cuda:7')
Top prediction 6.27% | union|


In [None]:
# get a random prompt
# get sae activations
# get top k activation values and their token/feature #
# shift the prompt
# get new sae activations
# get the activation values (for the specified feature) of the shifted tokens
# If the activation value doesn't change very much then we know that the sequence dimension is "aligned"
# If the activation values differe a lot, then the sequence dimension is not "aligned"

In [None]:
# Load TinyStories Dataset
ds = load_dataset("roneneldan/TinyStories")

README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00004-2d5a1467fff108(…):   0%|          | 0.00/249M [00:00<?, ?B/s]

data/train-00001-of-00004-5852b56a2bd28f(…):   0%|          | 0.00/248M [00:00<?, ?B/s]

data/train-00002-of-00004-a26307300439e9(…):   0%|          | 0.00/246M [00:00<?, ?B/s]

data/train-00003-of-00004-d243063613e5a0(…):   0%|          | 0.00/248M [00:00<?, ?B/s]

data/validation-00000-of-00001-869c898b5(…):   0%|          | 0.00/9.99M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2119719 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/21990 [00:00<?, ? examples/s]

In [None]:
def test_nlp_sequence_alignment(num_tests=10, k=5, shift_amt=2, prompt_len=30):
    """
    Test sequence dimension alignment by comparing SAE activations before and after shifting prompts.
    
    Args:
        num_tests: Number of random prompts to test
        k: Number of top activations to track
        shift_amt: Number of tokens to shift the prompt by
        prompt_len: Length of the prompt in tokens
    """
    alignment_results = []
    
    for test_idx in range(num_tests):
        print(f"\n=== Test {test_idx + 1}/{num_tests} ===")
        
        # Step 1: Get a random story and tokenize it
        random_story_idx = np.random.randint(0, len(ds['train']))
        full_text = ds['train'][random_story_idx]['text']
        
        # Tokenize the full text to work with token indices
        full_tokens = gpt2.to_str_tokens(full_text)
        N = len(full_tokens)
        
        print(f"Full text has {N} tokens")
        
        # Check if we have enough tokens for the operation
        min_required_tokens = shift_amt + prompt_len
        if N < min_required_tokens:
            print(f"Story too short ({N} tokens), need at least {min_required_tokens}. Skipping...")
            continue
        
        # Choose start_ind such that we can create both prompts
        # start_ind must be >= shift_amt (so shifted prompt doesn't go negative)
        # start_ind + prompt_len must be <= N (so original prompt doesn't exceed bounds)
        max_start_ind = N - prompt_len
        min_start_ind = shift_amt
        
        if max_start_ind < min_start_ind:
            print(f"Cannot create valid prompts with current parameters. Skipping...")
            continue
            
        start_ind = np.random.randint(min_start_ind, max_start_ind + 1)
        
        print(f"Selected start_ind: {start_ind}")
        
        # Step 2: Create original and shifted prompts
        # Original: full_text[start_ind:start_ind+prompt_len]
        original_tokens = full_tokens[start_ind:start_ind + prompt_len]
        # original_prompt = gpt2.to_string(original_tokens)
        original_prompt = ''.join(original_tokens)
        
        # Shifted: full_text[start_ind-shift_amt:start_ind+prompt_len]
        shifted_start = start_ind - shift_amt
        shifted_tokens = full_tokens[shifted_start:start_ind + prompt_len]
        shifted_prompt = ''.join(shifted_tokens)
        # shifted_prompt = gpt2.to_string(shifted_tokens)
        
        print(f"Original prompt tokens ({len(original_tokens)}): {original_tokens}")
        print(f"Shifted prompt tokens ({len(shifted_tokens)}): {shifted_tokens}")
        print(f"Original prompt: {original_prompt[:100]}...")
        print(f"Shifted prompt: {shifted_prompt[:100]}...")
        
        # Step 3: Get SAE activations for original prompt
        logits, cache = gpt2.run_with_cache_with_saes(original_prompt, saes=[gpt2_sae])
        original_activations = cache[sae_hook_name][0]  # Shape: [seq_len, n_features]
        
        # Step 4: Get top k activation values and their token/feature indices
        # Focus on a middle token to avoid edge effects
        middle_token_idx = len(original_tokens) // 2
        token_activations = original_activations[middle_token_idx]  # Shape: [n_features]
        
        # Get top k features
        top_k_values, top_k_features = torch.topk(token_activations, k)
        
        print(f"Analyzing token at position {middle_token_idx}: '{original_tokens[middle_token_idx]}'")
        print(f"Top {k} features: {top_k_features.tolist()}")
        print(f"Top {k} values: {top_k_values.tolist()}")
        
        # Step 5: Get SAE activations for shifted prompt
        shifted_logits, shifted_cache = gpt2.run_with_cache_with_saes(shifted_prompt, saes=[gpt2_sae])
        shifted_activations = shifted_cache[sae_hook_name][0]
        
        # Step 6: Get activation values for the same features at corresponding positions
        # The corresponding token in the shifted prompt should be at position middle_token_idx + shift_amt
        # because the shifted prompt has shift_amt additional tokens at the beginning
        corresponding_token_idx = middle_token_idx + shift_amt
        
        if corresponding_token_idx < len(shifted_tokens):
            shifted_token_activations = shifted_activations[corresponding_token_idx]
            shifted_feature_values = shifted_token_activations[top_k_features]
            
            print(f"Corresponding token in shifted prompt at position {corresponding_token_idx}: '{shifted_tokens[corresponding_token_idx]}'")
            print(f"Original feature values: {top_k_values.tolist()}")
            print(f"Shifted feature values: {shifted_feature_values.tolist()}")
            
            # Verify that we're comparing the same actual token
            original_token_text = original_tokens[middle_token_idx]
            shifted_token_text = shifted_tokens[corresponding_token_idx]
            tokens_match = original_token_text == shifted_token_text
            print(f"Tokens match: {tokens_match} ('{original_token_text}' vs '{shifted_token_text}')")
            
            # Step 7: Calculate alignment metric
            cosine_sim = torch.nn.functional.cosine_similarity(
                top_k_values.unsqueeze(0), 
                shifted_feature_values.unsqueeze(0)
            ).item()
            
            l2_distance = torch.linalg.vector_norm(top_k_values - shifted_feature_values).item()
            relative_change = (l2_distance / torch.linalg.vector_norm(top_k_values).item()) * 100
            
            print(f"Cosine similarity: {cosine_sim:.4f}")
            print(f"L2 distance: {l2_distance:.4f}")
            print(f"Relative change: {relative_change:.2f}%")
            
            # Determine alignment
            is_aligned = cosine_sim > 0.8 and relative_change < 20  # Thresholds can be adjusted
            alignment_status = "ALIGNED" if is_aligned else "NOT ALIGNED"
            print(f"Alignment status: {alignment_status}")
            
            alignment_results.append({
                'test_idx': test_idx,
                'cosine_similarity': cosine_sim,
                'l2_distance': l2_distance,
                'relative_change': relative_change,
                'is_aligned': is_aligned,
                'tokens_match': tokens_match,
                'original_token': original_token_text,
                'shifted_token': shifted_token_text,
                'start_ind': start_ind,
                'shift_amt': shift_amt
            })
        else:
            print("Corresponding token index out of bounds, skipping this test")
    
    # Summary statistics
    print(f"\n=== SUMMARY ===")
    if alignment_results:
        aligned_count = sum(1 for r in alignment_results if r['is_aligned'])
        tokens_match_count = sum(1 for r in alignment_results if r['tokens_match'])
        avg_cosine_sim = np.mean([r['cosine_similarity'] for r in alignment_results])
        avg_relative_change = np.mean([r['relative_change'] for r in alignment_results])
        
        print(f"Tests completed: {len(alignment_results)}")
        print(f"Token matches: {tokens_match_count}/{len(alignment_results)} ({tokens_match_count/len(alignment_results)*100:.1f}%)")
        print(f"Aligned sequences: {aligned_count}/{len(alignment_results)} ({aligned_count/len(alignment_results)*100:.1f}%)")
        print(f"Average cosine similarity: {avg_cosine_sim:.4f}")
        print(f"Average relative change: {avg_relative_change:.2f}%")
        
        if aligned_count > len(alignment_results) / 2:
            print("CONCLUSION: Sequence dimension appears to be ALIGNED")
        else:
            print("CONCLUSION: Sequence dimension appears to be NOT ALIGNED")
    
    return alignment_results

# Run the test
results = test_nlp_sequence_alignment(num_tests=100, k=10, shift_amt=5, prompt_len=30)


=== Test 1/100 ===
Full text has 145 tokens
Selected start_ind: 72
Original prompt tokens (30): [' kept', ' giving', ' the', ' sw', 'an', ' more', ' and', ' more', ' bread', '.', ' She', ' didn', "'t", ' know', ' that', ' too', ' much', ' bread', ' could', ' spoil', ' the', ' sw', 'an', "'s", ' tum', 'my', '.', ' The', ' sw', 'an']
Shifted prompt tokens (35): ['.', ' But', ' the', ' little', ' girl', ' kept', ' giving', ' the', ' sw', 'an', ' more', ' and', ' more', ' bread', '.', ' She', ' didn', "'t", ' know', ' that', ' too', ' much', ' bread', ' could', ' spoil', ' the', ' sw', 'an', "'s", ' tum', 'my', '.', ' The', ' sw', 'an']
Original prompt:  kept giving the swan more and more bread. She didn't know that too much bread could spoil the swan'...
Shifted prompt: . But the little girl kept giving the swan more and more bread. She didn't know that too much bread ...
Analyzing token at position 15: ' too'
Top 10 features: [9541, 21234, 11750, 11179, 20918, 12871, 3879, 16779, 1395, 

In [46]:
patchtst = HookedSAETransformer.from_pretrained("patchtst_relu", center_unembed=False).to(device)

hf_repo_id = "Coaster41/patchtst-sae-16-4.0"
# hf_repo_id = "Coaster41/patchtst-sae-8-0.5"
sae_id = "blocks.2.hook_mlp_out"
sae_patchtst_hook_name = f"{sae_id}.hook_sae_acts_post"

patchtst_sae = SAE.from_pretrained(
    release=hf_repo_id, sae_id=sae_id, device=str(device)
)

# Loading tsmixup dataset
train_dataset, val_dataset = create_cached_tsmixup_datasets(
        max_samples=300000,
        context_length=1024,
        prediction_length=96, # 1 or 96
        num_workers=16,
        cache_dir="/extra/datalab_scratch0/ctadler/time_series_models/mechanistic_interpretability/data/tsmixup_cache/",
        processed_cache_path="/extra/datalab_scratch0/ctadler/time_series_models/mechanistic_interpretability/data/tsmixup_cache/tsmixup_processed_300000_1024_96.pkl",
        batch_size=4000
    )

def load_tsmixup(ts=1000, ctx_len=512):
    x = []
    # y = []
    for i in range(ts) if isinstance(ts, int) else ts:
        val_dict = val_dataset[i]
        x.append(val_dict['past_values'])
        # y.append(val_dict['future_values'])
    x = torch.stack(x)[:, -ctx_len:]
    # y = torch.stack(y)[:,:pred_length]
    return x.to(device)



Loaded pretrained model patchtst_relu into HookedTransformer
Moving model to device:  cuda:7


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


🚀 CREATING CACHED TSMIXUP DATASETS
📂 Found existing processed data at /extra/datalab_scratch0/ctadler/time_series_models/mechanistic_interpretability/data/tsmixup_cache/tsmixup_processed_300000_1024_96.pkl
⚡ Loading preprocessed data from cache...
✅ Loaded 113,892 preprocessed samples
📅 Cache created: 2025-08-07 10:25:17

📊 DATASET SUMMARY:
  Total processed samples: 113,892
  Context length: 1024
  Prediction length: 96
🔀 Shuffling data...
📈 Data split:
  Training samples: 102,502
  Validation samples: 11,390
  Train ratio: 90.0%
🏗️  Creating PyTorch datasets...
🏗️  Dataset created with 102,502 samples
📊 Augmentation: ON
📈 Dataset Statistics (from 1000 samples):
  Sequence lengths: min=1120, max=2048, mean=1551
  Value ranges: min=-41.9283, max=49.6888
  Value stats: mean=0.9387, std=2.0463
🏗️  Dataset created with 11,390 samples
📊 Augmentation: OFF
📈 Dataset Statistics (from 1000 samples):
  Sequence lengths: min=1120, max=2048, mean=1556
  Value ranges: min=-52.9891, max=118.3337
  

In [49]:
def patchify(x: Tensor, patch_len=16):
    return x.reshape(x.shape[0], x.shape[1]//patch_len, patch_len)

def unpatchify(x: Tensor):
    return x.flatten(start_dim=1)

def test_ts_sequence_alignment(num_tests=10, k=5, shift_amt=2, prompt_len=32, patch_len=16):
    """
    Test sequence dimension alignment by comparing SAE activations before and after shifting prompts.
    
    Args:
        num_tests: Number of random prompts to test
        k: Number of top activations to track
        shift_amt: Number of tokens to shift the prompt by
        prompt_len: Length of the prompt in tokens
    """
    alignment_results = []
    
    for test_idx in range(num_tests):
        print(f"\n=== Test {test_idx + 1}/{num_tests} ===")
        
        # Step 1: Get a random story and tokenize it
        random_story_idx = np.random.randint(0, len(val_dataset))
        full_text = load_tsmixup([random_story_idx], ctx_len=1024)
        
        # Tokenize the full text to work with token indices
        full_tokens = patchify(full_text, patch_len)
        N = full_tokens.shape[1]
        
        print(f"Full text has {N} patches")
        
        # Check if we have enough tokens for the operation
        min_required_tokens = shift_amt + prompt_len
        if N < min_required_tokens:
            print(f"Story too short ({N} patches), need at least {min_required_tokens}. Skipping...")
            continue
        
        # Choose start_ind such that we can create both prompts
        # start_ind must be >= shift_amt (so shifted prompt doesn't go negative)
        # start_ind + prompt_len must be <= N (so original prompt doesn't exceed bounds)
        max_start_ind = N - prompt_len
        min_start_ind = shift_amt
        
        if max_start_ind < min_start_ind:
            print(f"Cannot create valid prompts with current parameters. Skipping...")
            continue
            
        start_ind = np.random.randint(min_start_ind, max_start_ind + 1)
        
        print(f"Selected start_ind: {start_ind}")
        
        # Step 2: Create original and shifted prompts
        # Original: full_text[start_ind:start_ind+prompt_len]
        original_tokens = full_tokens[:,start_ind:start_ind + prompt_len]
        original_prompt = unpatchify(original_tokens)
        
        # Shifted: full_text[start_ind-shift_amt:start_ind+prompt_len]
        shifted_start = start_ind - shift_amt
        shifted_tokens = full_tokens[:,shifted_start:start_ind + prompt_len]
        shifted_prompt = unpatchify(shifted_tokens)
        
        print(f"Original prompt tokens ({original_tokens.shape[1]}): {original_tokens}")
        print(f"Shifted prompt tokens ({shifted_tokens.shape[1]}): {shifted_tokens}")
        print(f"Original prompt: {original_prompt[0,:100]}...")
        print(f"Shifted prompt: {shifted_prompt[0,:100]}...")
        
        # Step 3: Get SAE activations for original prompt
        logits, cache = patchtst.run_with_cache_with_saes(original_prompt, saes=[patchtst_sae])
        original_activations = cache[sae_patchtst_hook_name][0]  # Shape: [seq_len, n_features]
        
        # Step 4: Get top k activation values and their token/feature indices
        # Focus on a middle token to avoid edge effects
        middle_token_idx = original_tokens.shape[1] // 2
        token_activations = original_activations[middle_token_idx]  # Shape: [n_features]
        
        # Get top k features
        top_k_values, top_k_features = torch.topk(token_activations, k)
        
        print(f"Analyzing token at position {middle_token_idx}: '{original_tokens[0,middle_token_idx]}'")
        print(f"Top {k} features: {top_k_features.tolist()}")
        print(f"Top {k} values: {top_k_values.tolist()}")
        
        # Step 5: Get SAE activations for shifted prompt
        shifted_logits, shifted_cache = patchtst.run_with_cache_with_saes(shifted_prompt, saes=[patchtst_sae])
        shifted_activations = shifted_cache[sae_patchtst_hook_name][0]
        
        # Step 6: Get activation values for the same features at corresponding positions
        # The corresponding token in the shifted prompt should be at position middle_token_idx + shift_amt
        # because the shifted prompt has shift_amt additional tokens at the beginning
        corresponding_token_idx = middle_token_idx + shift_amt
        
        if corresponding_token_idx < shifted_tokens.shape[1]:
            shifted_token_activations = shifted_activations[corresponding_token_idx]
            shifted_feature_values = shifted_token_activations[top_k_features]
            
            print(f"Corresponding token in shifted prompt at position {corresponding_token_idx}: '{shifted_tokens[0,corresponding_token_idx]}'")
            print(f"Original feature values: {top_k_values.tolist()}")
            print(f"Shifted feature values: {shifted_feature_values.tolist()}")
            
            # Verify that we're comparing the same actual token
            original_token_text = original_tokens[0,middle_token_idx]
            shifted_token_text = shifted_tokens[0,corresponding_token_idx]
            tokens_match = torch.all(original_token_text == shifted_token_text)
            print(f"Tokens match: {tokens_match} ('{original_token_text}' vs '{shifted_token_text}')")
            
            # Step 7: Calculate alignment metric
            cosine_sim = torch.nn.functional.cosine_similarity(
                top_k_values.unsqueeze(0), 
                shifted_feature_values.unsqueeze(0)
            ).item()
            
            l2_distance = torch.linalg.vector_norm(top_k_values - shifted_feature_values).item()
            relative_change = (l2_distance / torch.linalg.vector_norm(top_k_values).item()) * 100
            
            print(f"Cosine similarity: {cosine_sim:.4f}")
            print(f"L2 distance: {l2_distance:.4f}")
            print(f"Relative change: {relative_change:.2f}%")
            
            # Determine alignment
            is_aligned = cosine_sim > 0.8 and relative_change < 20  # Thresholds can be adjusted
            alignment_status = "ALIGNED" if is_aligned else "NOT ALIGNED"
            print(f"Alignment status: {alignment_status}")
            
            alignment_results.append({
                'test_idx': test_idx,
                'cosine_similarity': cosine_sim,
                'l2_distance': l2_distance,
                'relative_change': relative_change,
                'is_aligned': is_aligned,
                'tokens_match': tokens_match,
                'original_token': original_token_text,
                'shifted_token': shifted_token_text,
                'start_ind': start_ind,
                'shift_amt': shift_amt
            })
        else:
            print("Corresponding token index out of bounds, skipping this test")
    
    # Summary statistics
    print(f"\n=== SUMMARY ===")
    if alignment_results:
        aligned_count = sum(1 for r in alignment_results if r['is_aligned'])
        tokens_match_count = sum(1 for r in alignment_results if r['tokens_match'])
        avg_cosine_sim = np.mean([r['cosine_similarity'] for r in alignment_results])
        avg_relative_change = np.mean([r['relative_change'] for r in alignment_results])
        
        print(f"Tests completed: {len(alignment_results)}")
        print(f"Token matches: {tokens_match_count}/{len(alignment_results)} ({tokens_match_count/len(alignment_results)*100:.1f}%)")
        print(f"Aligned sequences: {aligned_count}/{len(alignment_results)} ({aligned_count/len(alignment_results)*100:.1f}%)")
        print(f"Average cosine similarity: {avg_cosine_sim:.4f}")
        print(f"Average relative change: {avg_relative_change:.2f}%")
        
        if aligned_count > len(alignment_results) / 2:
            print("CONCLUSION: Sequence dimension appears to be ALIGNED")
        else:
            print("CONCLUSION: Sequence dimension appears to be NOT ALIGNED")
    
    return alignment_results

# Run the test
results = test_ts_sequence_alignment(num_tests=5, k=10, shift_amt=1, prompt_len=32)


=== Test 1/5 ===
Full text has 64 patches
Selected start_ind: 25
Original prompt tokens (32): tensor([[[0.2503, 0.2218, 0.1673, 0.1897, 0.2077, 0.1901, 0.1998, 0.2940,
          0.4001, 0.3645, 0.3967, 0.3605, 0.3850, 0.4365, 0.3613, 1.0611],
         [1.1176, 0.7635, 1.0720, 0.5913, 0.8422, 0.9411, 0.8443, 1.1271,
          0.4149, 0.3786, 0.3242, 0.2256, 0.7387, 0.2257, 0.3044, 0.2468],
         [0.3573, 0.4487, 0.4591, 0.6553, 0.6384, 0.5180, 0.4625, 0.5518,
          1.5172, 1.0426, 0.5288, 1.0398, 0.5566, 0.7937, 0.6713, 0.5195],
         [0.4621, 0.4169, 0.2997, 0.2838, 0.2207, 0.1794, 0.1582, 0.1643,
          0.1486, 0.2957, 0.3810, 0.4118, 0.4271, 0.4157, 0.5784, 0.5344],
         [0.5926, 0.5439, 1.1740, 0.8232, 0.4216, 0.3624, 0.8927, 0.4088,
          0.7375, 1.1750, 0.1983, 0.2145, 0.1079, 0.2006, 0.2562, 0.3553],
         [0.4170, 0.3255, 0.4017, 0.3116, 0.4052, 0.4038, 0.3531, 0.4290,
          0.4946, 1.0316, 0.4843, 0.5601, 0.4892, 0.2749, 0.3069, 0.2324],
         [0