In [None]:
# general imports
import os
# Restrict this process to GPUs 4-7. The variable is set before `import torch`
# below; NOTE(review): presumably so CUDA only enumerates these devices. The
# ids are machine-specific — adjust them for the host you run on.
os.environ["CUDA_VISIBLE_DEVICES"]= "4,5,6,7"
import torch
from tqdm import tqdm
import plotly.express as px

# Disable autograd globally for the rest of the notebook. The trailing `;`
# suppresses the cell's output repr.
torch.set_grad_enabled(False);
# typing and helper imports
from torch import Tensor
from transformer_lens import utils
from functools import partial
from jaxtyping import Int, Float

In [None]:
from transformer_lens import HookedTransformer
from sae_lens import SAE
from torch.nn.parallel import DataParallel

# Layer whose `hook_resid_pre` activations the SAE below is loaded for.
# (Assigned once; the earlier comment said "layer 2" while the code used 8.)
layer = 8

# One device string shared by the model and the SAE so the two cannot drift apart.
device = "cuda:0"

model = HookedTransformer.from_pretrained("gpt2-small").to(device)

# Initialize SAE
sae, cfg_dict, _ = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id=f"blocks.{layer}.hook_resid_pre",
    device=device,
)

# get hook point
hook_point = sae.cfg.hook_name
print(hook_point)

In [None]:
import torch
# Environment report: emit the three diagnostics through a single print call,
# one per line.
print(
    f"PyTorch Version: {torch.__version__}",
    f"CUDA Available: {torch.cuda.is_available()}",
    f"CUDA Version: {torch.version.cuda}",
    sep="\n",
)

In [None]:
def get_top_k_active_directions_encoder(sae: SAE, activation: torch.Tensor, k: int) -> torch.Tensor:
    """Return the unit-norm encoder directions of the k most active SAE latents.

    Args:
        sae: Sparse autoencoder; only `W_enc` and `encode` are used here.
            Assumes `W_enc` is laid out as (d_in, d_sae), i.e. one column per
            latent — TODO confirm against the installed sae_lens version.
        activation: Activation to encode. Singleton dimensions of the encoded
            latents are squeezed away, so a single activation of shape
            (1, d_in) or (d_in,) yields one set of k directions.
        k: Number of latents to keep, ranked by absolute latent activation.

    Returns:
        For a single activation, a (k, d_in) tensor with one direction per row
        (in the order returned by `torch.topk`), each row scaled to unit L2 norm.
    """
    activation = activation.to(sae.W_enc.device)
    latent = sae.encode(activation)
    _, top_k_indices = torch.topk(latent.abs().squeeze(), k)
    # Rows of W_enc.T are per-latent encoder directions. Keep one direction per
    # row: the previous trailing `.T` produced a (d_in, k) matrix, so the
    # normalisation rescaled coordinates across latents instead of rescaling
    # each direction to unit length, and row indexing by callers was wrong.
    top_k_directions = sae.W_enc.T[top_k_indices]
    return torch.nn.functional.normalize(top_k_directions, p=2, dim=-1)

In [None]:
def get_top_k_active_directions_decoder(sae: SAE, activation: torch.Tensor, k: int) -> torch.Tensor:
    """Return the unit-norm decoder directions of the k most active SAE latents.

    Args:
        sae: Sparse autoencoder; only `W_dec` and `encode` are used here.
            Assumes `W_dec` is laid out as (d_sae, d_in), i.e. one row per
            latent — TODO confirm against the installed sae_lens version.
        activation: Activation to encode. Singleton dimensions of the encoded
            latents are squeezed away, so a single activation of shape
            (1, d_in) or (d_in,) yields one set of k directions.
        k: Number of latents to keep, ranked by absolute latent activation.

    Returns:
        For a single activation, a (k, d_in) tensor with one direction per row
        (in the order returned by `torch.topk`), each row scaled to unit L2 norm.
    """
    activation = activation.to(sae.W_dec.device)
    latent = sae.encode(activation)
    _, top_k_indices = torch.topk(latent.abs().squeeze(), k)
    # Indexing W_dec by latent id already gives one direction per row. The
    # previous trailing `.T` turned this into (d_in, k), so the normalisation
    # did not produce unit direction vectors and row indexing by callers was wrong.
    top_k_directions = sae.W_dec[top_k_indices]
    return torch.nn.functional.normalize(top_k_directions, p=2, dim=-1)

In [None]:
import torch
from transformer_lens import HookedTransformer
from sae_lens import SAE

def get_last_token_directions(
    model: HookedTransformer,
    sae: SAE,
    prompt: str,
    k: int = 5,
    direction_source: str = "decoder",
) -> torch.Tensor:
    """Top-k active SAE directions for the last token of `prompt`.

    The original body called `get_top_k_active_directions`, which is not
    defined in this notebook; the call now dispatches to one of the two
    helpers that are defined.

    Args:
        model: Model used to tokenize the prompt and produce the activation cache.
        sae: SAE whose `cfg.hook_name` selects the cached activation to encode.
        prompt: Text to run; a BOS token is prepended.
        k: Number of directions to return.
        direction_source: "decoder" (default) uses
            `get_top_k_active_directions_decoder`; "encoder" uses
            `get_top_k_active_directions_encoder`.

    Returns:
        Whatever the selected helper returns for the last-token activation.

    Raises:
        ValueError: If `direction_source` is neither "decoder" nor "encoder".
    """
    if direction_source == "decoder":
        select_directions = get_top_k_active_directions_decoder
    elif direction_source == "encoder":
        select_directions = get_top_k_active_directions_encoder
    else:
        raise ValueError(f"Unknown direction_source {direction_source!r}")

    # Tokenize the prompt
    tokens = model.to_tokens(prompt, prepend_bos=True)

    # Run the model and get the cache
    _, cache = model.run_with_cache(tokens)

    # Activation at the SAE's hook point for the final position of every batch row
    last_token_activation = cache[sae.cfg.hook_name][:, -1, :]

    return select_directions(sae, last_token_activation, k)


prompt = "The capital of France is"
k = 5

# Reset before the attempt so a failed run cannot silently fall through to a
# `top_directions` left in the kernel by an earlier execution of this cell.
top_directions = None

try:
    top_directions = get_last_token_directions(model, sae, prompt, k)
    print("Successfully computed top directions")
    print(f"Top directions shape: {top_directions.shape}")
except Exception as e:
    print(f"An error occurred: {str(e)}")
    import traceback
    traceback.print_exc()

# If successful, compute angles
if top_directions is not None:
    # All pairwise dot products of the first k rows in one matmul; entry (i, j)
    # equals dot(top_directions[i], top_directions[j]).
    rows = top_directions[:k]
    cos_similarity = torch.clamp(rows @ rows.T, -1, 1)
    # Clamp guards acos against values marginally outside [-1, 1]; keep the
    # result on the CPU for printing.
    angles = (torch.acos(cos_similarity) * 180 / torch.pi).cpu()

    print("\nAngles between directions (in degrees):")
    print(angles)


In [None]:
import torch as th
import torch.nn.functional as F
from typing import NamedTuple, Literal, Callable
from tqdm.auto import tqdm

class CausalBasis(NamedTuple):
    """Named pair of tensors: `energies` and `vectors`, in that order.

    NOTE(review): none of the visible cells construct or read a CausalBasis,
    so the meaning and shapes of the fields cannot be confirmed here —
    presumably `vectors` holds basis directions and `energies` one score per
    direction; verify before relying on it.
    """
    energies: th.Tensor
    vectors: th.Tensor

def compute_direction_mean(
    batch: dict,
    model: th.nn.Module,
    layer_index: int,
    direction: th.Tensor
) -> float:
    """Mean projection of a layer's `hook_resid_pre` activations onto `direction`.

    Args:
        batch: Mapping with "input_ids" and "attention_mask" tensors; both are
            moved to `direction`'s device before the forward pass.
        model: Object exposing `run_with_cache(input_ids, attention_mask=...)`
            that returns `(output, cache)`.
        layer_index: Selects the cache entry `blocks.{layer_index}.hook_resid_pre`.
        direction: 1-D vector the activations are dotted with (used as given,
            it is not normalised here).

    Returns:
        Sum of the masked per-token projections divided by the sum of the
        attention mask, as a Python float.
    """
    hook_name = f"blocks.{layer_index}.hook_resid_pre"
    target_device = direction.device

    with th.no_grad():
        mask = batch["attention_mask"].to(target_device)
        token_ids = batch["input_ids"].to(target_device)

        _, cache = model.run_with_cache(token_ids, attention_mask=mask)

        # Dot every position's activation with the direction: one scalar per token.
        per_token = th.einsum("bld,d->bl", cache[hook_name], direction)

        # Positions where the mask is 0 contribute nothing to the sum.
        projection_sum = (per_token * mask).sum().item()
        token_count = mask.sum().item()

    return projection_sum / token_count


In [None]:
import torch as th

def remove_subspace(u: th.Tensor, v: th.Tensor, direction_mean: th.Tensor, mode: str = "mean") -> th.Tensor:
    """Remove the component of each row of `u` along the direction `v`.

    Args:
        u: 2-D tensor, one vector per row (the einsum below requires "bi").
        v: Direction to remove; squeezed to 1-D and normalised to unit length.
        direction_mean: Replacement value used in "mean" mode. A scalar (Python
            number or one-element tensor) is treated as a coefficient along the
            unit direction, so it is added as `direction_mean * v_normalized`.
            Previously a scalar was broadcast onto every coordinate, which
            shifted the result off the intended direction. Any other tensor is
            added unchanged, as before. Ignored in "zero" mode.
        mode: "mean" replaces the removed component with `direction_mean`;
            "zero" just removes it.

    Returns:
        Tensor shaped like `u` with the component along `v` replaced or removed.

    Raises:
        ValueError: If `mode` is not "mean" or "zero".
    """
    if mode not in ("mean", "zero"):
        raise ValueError(f"Unknown mode {mode}")

    v = v.squeeze()  # Ensure v is 1D
    v_normalized = v / v.norm()

    # Per-row projection of u onto the unit direction, shared by both modes.
    proj_u = th.einsum("bi,i->b", u, v_normalized).unsqueeze(1) * v_normalized.unsqueeze(0)
    result = u - proj_u
    if mode == "zero":
        return result

    mean = direction_mean if th.is_tensor(direction_mean) else th.as_tensor(direction_mean)
    if mean.numel() == 1:
        # Scalar mean projection: put it back along v only.
        mean = mean.reshape(()).to(v_normalized) * v_normalized
    return result + mean

In [None]:
import torch as th
import matplotlib.pyplot as plt
import einops

def compute_functional_similarity(orig_activations, abl_activations, orig_logits, abl_logits, beta=0):
    """Blend an activation-space MSD with a logit-space KL divergence.

    Both terms measure difference, so a lower return value means the ablated
    run is closer to the original.

    Args:
        orig_activations: Sequence of tensors from the original run; the last
            entry is skipped (the loop treats it as the logits).
        abl_activations: Matching sequence from the ablated run.
        orig_logits: Logits of the original run, vocabulary on the last axis.
        abl_logits: Logits of the ablated run, same shape.
        beta: Weight of the MSD term; the KL term gets `1 - beta`.

    Returns:
        Tuple `(similarity, avg_msd, kl_div)` where
        `similarity = beta * avg_msd + (1 - beta) * kl_div`.
    """
    # Element-wise mean squared difference per intermediate activation. The
    # previous code compared whole-tensor norms, which scores any change that
    # preserves the norm as zero difference and did not match the logits
    # fallback below.
    msd_list = [
        th.mean((orig_act - abl_act) ** 2)
        for orig_act, abl_act in zip(orig_activations[:-1], abl_activations[:-1])  # Exclude logits
    ]

    # Average mean squared difference across layers
    if msd_list:
        avg_msd = th.mean(th.stack(msd_list))
    else:
        # If there are no intermediate activations, use only the logits
        avg_msd = th.mean((orig_logits - abl_logits) ** 2)

    # Collapse all leading axes so each row is one distribution over the vocab
    # (same result as einops "... vocab -> (...) vocab", without the dependency).
    orig_logits_flat = orig_logits.reshape(-1, orig_logits.shape[-1])
    abl_logits_flat = abl_logits.reshape(-1, abl_logits.shape[-1])

    # KL(original || ablated), averaged over rows
    kl_div = F.kl_div(
        F.log_softmax(abl_logits_flat, dim=-1),
        F.log_softmax(orig_logits_flat, dim=-1),
        log_target=True,
        reduction="batchmean",
    )

    # Weighted sum of the two difference measures (both enter with positive sign).
    similarity = beta * avg_msd + (1 - beta) * kl_div

    return similarity, avg_msd, kl_div

def compute_functional_similarity_logits_only(orig_logits, abl_logits, beta=0):
    """KL divergence between the softmax distributions of two logit tensors.

    Args:
        orig_logits: Logits of the original run, vocabulary on the last axis.
        abl_logits: Logits of the ablated run.
        beta: Accepted for signature parity with `compute_functional_similarity`;
            not used by this function.

    Returns:
        Tuple `(kl_div, 0, kl_div)`; the middle slot is a constant 0.
    """
    def _log_probs(logits):
        # Merge every leading axis so each row is one distribution over the vocab.
        flat = einops.rearrange(logits, "... vocab -> (...) vocab")
        return F.log_softmax(flat, dim=-1)

    # Input is the ablated log-probs, target the original log-probs.
    divergence = F.kl_div(
        _log_probs(abl_logits),
        target=_log_probs(orig_logits),
        reduction="batchmean",
        log_target=True,
    )

    return divergence, 0, divergence

def compute_subsequent_outputs(model: th.nn.Module, activation: th.Tensor, layer_index: int) -> list[th.Tensor]:
    """Push `activation` through the remaining blocks and collect every stage.

    Args:
        model: Object exposing `blocks` (sliceable sequence of callables),
            `ln_final` and `unembed`.
        activation: Starting activation; a 2-D input gets a dimension inserted
            at axis 1 before it is fed to the blocks.
        layer_index: Blocks `layer_index + 1` onward are applied.
            NOTE(review): if `activation` is the *input* of block `layer_index`
            (e.g. a `hook_resid_pre` value), that block itself is skipped —
            confirm the intended convention with callers.

    Returns:
        List holding the activation exactly as passed in, then the output of
        each applied block, then the logits from `unembed(ln_final(...))`.
    """
    # First entry is the caller's tensor, before any axis is inserted.
    collected = [activation]

    hidden = activation.unsqueeze(1) if activation.ndim == 2 else activation

    for later_block in model.blocks[layer_index + 1:]:
        hidden = later_block(hidden)
        collected.append(hidden)

    # Final layer norm followed by the unembedding gives the logits.
    collected.append(model.unembed(model.ln_final(hidden)))

    return collected

def compute_subsequent_outputs_logits_only(model: th.nn.Module, activation: th.Tensor, layer_index: int) -> th.Tensor:
    """Push `activation` through the remaining blocks and return only the logits.

    Args:
        model: Object exposing `blocks` (sliceable sequence of callables),
            `ln_final` and `unembed`.
        activation: Starting activation; a 2-D input gets a dimension inserted
            at axis 1 before it is fed to the blocks.
        layer_index: Blocks `layer_index + 1` onward are applied.
            NOTE(review): if `activation` is the *input* of block `layer_index`
            (e.g. a `hook_resid_pre` value), that block itself is skipped —
            confirm the intended convention with callers.

    Returns:
        The single logits tensor from `unembed(ln_final(...))`. The return
        annotation previously claimed `list[th.Tensor]`.
    """
    if activation.ndim == 2:
        activation = activation.unsqueeze(1)
    current_activation = activation

    for block in model.blocks[layer_index + 1:]:
        current_activation = block(current_activation)

    # Final layer norm followed by the unembedding gives the logits.
    return model.unembed(model.ln_final(current_activation))


In [None]:
import torch
import torch.nn.functional as F
import einops

def ablation_operation(v: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    """
    Ablation operation v ⊖ u: subtract |u| from v element-wise and floor the
    result at zero.
    """
    shifted = v - torch.abs(u)
    floor = torch.zeros_like(v)
    return torch.maximum(floor, shifted)

def find_ablation_magnitude(
    sae: SAE,
    activation: torch.Tensor,
    latent_direction: torch.Tensor,
    threshold: float = 0.5,
    latent_index: "int | None" = None,
) -> float:
    """
    Find the magnitude c such that the latent just stops activating.

    c starts at 0 and grows in steps of 1.0 (at most 100 steps). At each step
    the activation is ablated with `c * latent_direction` and re-encoded; the
    first c whose latent activation is below `threshold` is returned.

    Args:
        sae: Sparse autoencoder; `encode` is used, and `W_dec` too when
            `latent_index` is omitted.
        activation: Activation to ablate.
        latent_direction: Direction scaled by c and passed to `ablation_operation`.
        threshold: The latent counts as inactive once its value is below this.
        latent_index: Index of the latent to watch. The original code indexed
            the latents with `latent_direction` itself (a float vector), which
            cannot work as an index. When omitted, the index is inferred as the
            `W_dec` row with the highest cosine similarity to `latent_direction`.

    Returns:
        The first c that silences the latent, or the last c tried if none did.
    """
    if latent_index is None:
        decoder_rows = torch.nn.functional.normalize(sae.W_dec, p=2, dim=-1)
        query = torch.nn.functional.normalize(
            latent_direction.to(sae.W_dec).reshape(-1), p=2, dim=0
        )
        latent_index = int(torch.argmax(decoder_rows @ query))

    c = 0.0
    step = 1.0
    max_iterations = 100

    for _ in range(max_iterations):
        ablated_activation = ablation_operation(activation, c * latent_direction)
        latent = sae.encode(ablated_activation)
        # `...` keeps this valid for both 1-D and batched latent tensors.
        if latent[..., latent_index].max() < threshold:
            return c
        c += step

    return c  # Return the last c if we didn't converge

def compute_sq_score(sae: SAE, model: torch.nn.Module, activations: torch.Tensor, latent_index: int, layer_index: int, num_samples: int = 10):
    """
    Compute the SQ score for a given latent.

    For each activation, the latent's decoder direction is ablated with the
    magnitude found by `find_ablation_magnitude`, and the resulting logit
    divergence `s_l` is divided by `s_max`, the largest divergence obtained
    from `num_samples` random directions of the same magnitude.

    Args:
        sae: Sparse autoencoder; `W_dec[latent_index]` is the ablated direction.
        model: Model passed to `compute_subsequent_outputs_logits_only`.
        activations: Iterated along the first axis, one activation per step.
        latent_index: Index of the latent to evaluate.
        layer_index: Forwarded to `compute_subsequent_outputs_logits_only`.
        num_samples: Number of random directions used to estimate `s_max`.

    Returns:
        Mean of the per-activation scores as a tensor. An activation whose
        `s_max` is not positive contributes a score of 0.
    """
    device = activations.device
    # A single decoder row is already 1-D, so no transpose is needed.
    latent_direction = sae.W_dec[latent_index].to(device)

    sq_scores = []

    for activation in activations:
        # Find c such that the latent just stops activating
        c = find_ablation_magnitude(sae, activation, latent_direction)

        # Compute x-
        x_minus = ablation_operation(activation, c * latent_direction)

        # Compute functional difference
        orig_logits = compute_subsequent_outputs_logits_only(model, activation.unsqueeze(0), layer_index)
        abl_logits = compute_subsequent_outputs_logits_only(model, x_minus.unsqueeze(0), layer_index)

        s_l, _, _ = compute_functional_similarity_logits_only(orig_logits, abl_logits)

        # Compute S_max from random directions rescaled to norm c
        s_max_samples = []
        for _ in range(num_samples):
            random_direction = torch.randn_like(latent_direction)
            random_direction = random_direction / random_direction.norm() * c
            x_random = ablation_operation(activation, random_direction)
            random_logits = compute_subsequent_outputs_logits_only(model, x_random.unsqueeze(0), layer_index)
            s_random, _, _ = compute_functional_similarity_logits_only(orig_logits, random_logits)
            s_max_samples.append(s_random)

        s_max = max(s_max_samples)

        # The fallback must be a tensor: a bare Python 0 made the final
        # torch.stack fail whenever s_max was not positive.
        sq_score = s_l / s_max if s_max > 0 else torch.zeros_like(s_l)
        sq_scores.append(sq_score)

    # Return the average SQ score
    return torch.mean(torch.stack(sq_scores))

# Example usage: score one latent on a prompt's last-token activation, reusing
# the `model`, `sae` and `layer` defined earlier in the notebook. The previous
# placeholder calls (`SAE(...)`, `TransformerModel(...)`) overwrote the loaded
# objects and could not run.
example_tokens = model.to_tokens("The capital of France is", prepend_bos=True)
_, example_cache = model.run_with_cache(example_tokens)

# One activation per row; here only the final position is kept.
activations = example_cache[sae.cfg.hook_name][0, -1:, :].to(sae.W_dec.device)

# Evaluate the latent that fires most strongly on that activation.
latent_index = int(sae.encode(activations).squeeze(0).argmax())

# compute_subsequent_outputs_logits_only applies model.blocks[layer_index + 1:].
# The SAE reads `blocks.{layer}.hook_resid_pre`, i.e. the input of block
# `layer`, so pass layer - 1 to resume the forward pass at block `layer` itself.
layer_index = layer - 1

# The score samples random directions; seed for a reproducible cell.
torch.manual_seed(0)

sq_score = compute_sq_score(sae, model, activations, latent_index, layer_index)
print(f"SQ Score for latent {latent_index}: {sq_score.item()}")