In [1]:
# %%
import os
import gc
import torch
os.chdir("/work/pi_jensen_umass_edu/jnainani_umass_edu/CircuitAnalysisSAEs")
import json
from sae_lens import SAE, HookedSAETransformer
from circ4latents import data_gen
from functools import partial
import einops

# Function to manage CUDA memory and clean up
def cleanup_cuda():
   torch.cuda.empty_cache()
   gc.collect()
# cleanup_cuda()
# Load the config
with open("config.json", 'r') as file:
   config = json.load(file)
token = config.get('huggingface_token', None)
os.environ["HF_TOKEN"] = token

# Define device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

hf_cache = "/work/pi_jensen_umass_edu/jnainani_umass_edu/mechinterp/huggingface_cache/hub"
os.environ["HF_HOME"] = hf_cache

# Load the model
model = HookedSAETransformer.from_pretrained("google/gemma-2-9b", device=device, cache_dir=hf_cache)



Device: cuda


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2-9b into HookedTransformer


In [2]:
layers= [7, 14, 21, 40]
l0s = [92, 67, 129, 125]
saes = [SAE.from_pretrained(release="gemma-scope-9b-pt-res", sae_id=f"layer_{layers[i]}/width_16k/average_l0_{l0s[i]}", device=device)[0] for i in range(len(layers))]

# Data Gen

In [3]:
# %%
# Updated version to return JSON with more names and structure for correct and incorrect keying examples

import json
import random

# Expanding the name pool with a larger set of names
extended_name_pool = [
    "Bob", "Sam", "Lilly", "Rob", "Alice", "Charlie", "Sally", "Tom", "Jake", "Emily", 
    "Megan", "Chris", "Sophia", "James", "Oliver", "Isabella", "Mia", "Jackson", 
    "Emma", "Ava", "Lucas", "Benjamin", "Ethan", "Grace", "Olivia", "Liam", "Noah"
]

for name in extended_name_pool:
    assert len(model.tokenizer.encode(name)) == 2, f"Name {name} has more than 1 token"

# Function to generate the dataset with correct and incorrect keying into dictionaries
def generate_extended_dataset(name_pool, num_samples=5):
    dataset = []
    
    for _ in range(num_samples):
        # Randomly select 5 names from the pool
        selected_names = random.sample(name_pool, 5)
        # Assign random ages to the selected names
        age_dict = {name: random.randint(10, 19) for name in selected_names}
        
        # Create a correct example
        correct_name = random.choice(list(age_dict.keys()))
        correct_prompt = f'Type "help", "copyright", "credits" or "license" for more information.\n>>> age = {age_dict}\n>>> print(age["{correct_name}"])\n'
        correct_response = age_dict[correct_name]
        correct_token = str(correct_response)[0]
        
        # Create an incorrect example with a name not in the dictionary
        incorrect_name = random.choice([name for name in name_pool if name not in age_dict])
        incorrect_prompt = f'Type "help", "copyright", "credits" or "license" for more information.\n>>> age = {age_dict}\n>>> print(age["{incorrect_name}"])\n'
        incorrect_response = "Traceback"
        incorrect_token = "Traceback"
        
        # Append the pair of correct and incorrect examples
        dataset.append({
            "correct": {
                "prompt": correct_prompt,
                "response": correct_response,
                "token": correct_token
            },
            "error": {
                "prompt": incorrect_prompt,
                "response": incorrect_response,
                "token": incorrect_token
            }
        })
        
    return dataset

# Generate the extended dataset
json_dataset = generate_extended_dataset(extended_name_pool, num_samples=10000)

# Output the JSON structure

# %%
clean_prompts = []
corr_prompts = []

answer_token = model.to_single_token("1")
traceback_token = model.to_single_token("Traceback")

for item in json_dataset[:50]:
    corr_prompts.append(item["correct"]["prompt"])
    clean_prompts.append(item["error"]["prompt"])

clean_tokens = model.to_tokens(clean_prompts)
corr_tokens = model.to_tokens(corr_prompts)

# %%
def logit_diff_fn(logits):
    err = logits[:, -1, traceback_token]
    no_err = logits[:, -1, answer_token]
    return (err - no_err).mean()

# Disable gradients for all parameters
for param in model.parameters():
   param.requires_grad_(False)

# # Compute logits for clean and corrupted samples
logits = model(clean_tokens)
clean_diff = logit_diff_fn(logits)

logits = model(corr_tokens)
corr_diff = logit_diff_fn(logits)

print(f"clean_diff: {clean_diff}")
print(f"corr_diff: {corr_diff}")

# # Cleanup
del logits
cleanup_cuda()

# # Define error type metric
def _err_type_metric(logits, clean_logit_diff, corr_logit_diff):
    patched_logit_diff = logit_diff_fn(logits)
    return (patched_logit_diff - corr_logit_diff) / (clean_logit_diff - corr_logit_diff)

err_metric_denoising = partial(_err_type_metric, clean_logit_diff=clean_diff, corr_logit_diff=corr_diff)

clean_diff: 6.4912004470825195
corr_diff: -6.798116207122803


In [33]:
# simple_dataset = []
# simple_labels = []
clean_dataset = []
corr_dataset = []
clean_labels = []
corr_labels = []

answer_token = model.to_single_token("1")
traceback_token = model.to_single_token("Traceback")

for item in json_dataset:
    clean_dataset.append(item["error"]["prompt"])
    corr_dataset.append(item["correct"]["prompt"])
    clean_labels.append(traceback_token)
    corr_labels.append(answer_token)


clean_tok_dataset = model.to_tokens(clean_dataset)
clean_labels = torch.tensor(clean_labels)

corr_tok_dataset = model.to_tokens(corr_dataset)
corr_labels = torch.tensor(corr_labels)


In [35]:
permutation = torch.randperm(len(clean_tok_dataset))
clean_tok_dataset = clean_tok_dataset[permutation]
clean_labels = clean_labels[permutation]

corr_tok_dataset = corr_tok_dataset[permutation]
corr_labels = corr_labels[permutation]

# Sparse Mask Definition

In [47]:
import signal
import torch.nn as nn
class KeyboardInterruptBlocker:
    def __enter__(self):
        # Ignore SIGINT (KeyboardInterrupt) and save the old handler
        self.original_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the original SIGINT handler
        signal.signal(signal.SIGINT, self.original_handler)

class SparseMask(nn.Module):
    def __init__(self, shape, l1, seq_len=None):
        super().__init__()
        if seq_len is not None:
            self.mask = nn.Parameter(torch.ones(seq_len, shape))
        else:
            self.mask = nn.Parameter(torch.ones(shape))
        self.l1 = l1
        self.max_temp = torch.tensor(1000.0)
        self.sparsity_loss = None
        self.ratio_trained = 1
        self.temperature = 1


    def forward(self, x, binary=False, mean_ablation=None):
        if binary and mean_ablation is not None:
            binarized = (self.mask > 0).float().to(x.device)
            return x * binarized + mean_ablation * (~binarized.bool())
        if binary:
            # binary mask, 0 if negative, 1 if positive
            binarized = (self.mask > 0).float()
            return x * binarized
        
        self.temperature = self.max_temp ** self.ratio_trained
        mask = torch.sigmoid(self.mask * self.temperature)
        self.sparsity_loss = torch.abs(mask).sum() * self.l1

        if mean_ablation is None:
            return x * mask
        else:
            return x * mask + mean_ablation * (~mask.bool())

for sae in saes:
    sae.mask = SparseMask(sae.cfg.d_sae, 1.0, seq_len=67)


In [48]:

def apply_mask(mask_idxs, sae):
    mask = torch.full_like(sae.mask.mask, -10)
    mask[mask_idxs] = 10
    sae.mask.mask.data = mask

def load_sparsemask(mask_path):
    json_mask = json.load(open(mask_path))
    for sae in saes:
        apply_mask(json_mask[sae.cfg.hook_name], sae)

bos_token_id = model.tokenizer.bos_token_id
pad_token_id = model.tokenizer.pad_token_id
def build_sae_hook_fn(sae, sequence, cache_grads=False, circuit_mask=None, use_mask=False, binarize_mask=False, cache_masked_activations=False, cache_sae_activations=False, mean_ablate=False, mean_mask=False):
    # make the mask for the sequence
    mask = torch.ones_like(sequence, dtype=torch.bool)
    mask[sequence == pad_token_id] = False
    mask[sequence == bos_token_id] = False # where mask is false, keep original
    def sae_hook(value, hook):
        # print(f"sae {sae.cfg.hook_name} running at layer {hook.layer()}")
        feature_acts = sae.encode(value)
        if cache_grads:
            sae.feature_acts = feature_acts
            sae.feature_acts.retain_grad()
        
        if cache_sae_activations:
            sae.feature_acts = feature_acts.detach().clone()
        
        if use_mask:
            if mean_mask:
                feature_acts = sae.mask(feature_acts, binary=binarize_mask, mean_ablation=sae.mean_ablation)
            else:
                feature_acts = sae.mask(feature_acts, binary=binarize_mask)

        if circuit_mask is not None:
            mask_method = circuit_mask['mask_method']
            mask_indices = circuit_mask[sae.cfg.hook_name]
            if mask_method == 'keep_only':
                # any activations not in the mask are set to 0
                expanded_circuit_mask = torch.zeros_like(feature_acts)
                expanded_circuit_mask[:, :, mask_indices] = 1
                feature_acts = feature_acts * expanded_circuit_mask
            elif mask_method == 'zero_only':
                feature_acts[:, :, mask_indices] = 0
            else:
                raise ValueError(f"mask_method {mask_method} not recognized")
            
        if cache_masked_activations:
            sae.feature_acts = feature_acts.detach().clone()
        if mean_ablate:
            feature_acts = sae.mean_ablation

        out = sae.decode(feature_acts)
        # choose out or value based on the mask
        mask_expanded = mask.unsqueeze(-1).expand_as(value)
        value = torch.where(mask_expanded, out, value)
        return value
    return sae_hook

def build_hooks_list(sequence,
                    cache_sae_activations=False,
                    cache_sae_grads=False,
                    circuit_mask=None,
                    use_mask=False,
                    binarize_mask=False,
                    mean_mask=False,
                    cache_masked_activations=False,
                    mean_ablate=False,
                    ):
    hooks = []
    # # fake hook that adds zero so gradients propagate through the model
    param = nn.Parameter(torch.tensor(0.0, requires_grad=True))
    hooks.append(
        (
            "blocks.0.hook_resid_pre",
            lambda value, hook: value + param,
        )
    )
    for sae in saes:
        hooks.append(
            (
            sae.cfg.hook_name,
            build_sae_hook_fn(sae, sequence, cache_grads=cache_sae_grads, circuit_mask=circuit_mask, use_mask=use_mask, binarize_mask=binarize_mask, cache_masked_activations=cache_masked_activations, cache_sae_activations=cache_sae_activations, mean_ablate=mean_ablate, mean_mask=mean_mask),
            )
        )
    return hooks 

import gc
def cleanup_cuda():
    torch.cuda.empty_cache()
    gc.collect()


In [38]:
with torch.no_grad():
    def logitfn(tokens):
        return model.run_with_hooks(
            tokens,
            return_type="logits",
            fwd_hooks=build_hooks_list(tokens, cache_sae_activations=True)
            )
    logits = logitfn(corr_tokens)
    for sae in saes:
        print(sae.feature_acts.shape)
    del logits
    cleanup_cuda()

torch.Size([50, 67, 16384])
torch.Size([50, 67, 16384])
torch.Size([50, 67, 16384])
torch.Size([50, 67, 16384])


# Token Level Running Mean

In [39]:
from tqdm import tqdm
def running_mean_tensor(old_mean, new_value, n):
    """Update the running mean tensor using the current batch."""
    return old_mean + (new_value - old_mean) / n
def get_sae_means(saes, dataset, total_steps, batch_size=16):
    """
    Compute token-level means across the dataset in a batched manner.
    Args:
        dataset (Tensor): The input dataset of tokenized data.
        total_steps (int): Number of steps to process.
        batch_size (int): Number of examples per batch.
    """
    for sae in saes:
        # Initialize mean_ablation with correct shape
        sae.mean_ablation = torch.zeros((dataset[0].shape[0], sae.cfg.d_sae)).float().to(device)
    total_samples = len(dataset)
    num_batches = (total_samples + batch_size - 1) // batch_size  # Calculate number of batches
    with tqdm(total=min(total_steps, num_batches), desc="Mean Accum Progress") as pbar:
        sample_count = 0  # To track total number of samples processed
        for i in range(0, total_samples, batch_size):
            # Batch selection
            batch_x = dataset[i:i+batch_size]
            with torch.no_grad():
                _ = logitfn(batch_x)  # Get logits (forward pass)
                for sae in saes:
                    # Compute batch mean over tokens
                    batch_mean = sae.feature_acts.mean(dim=0)  # Mean across the batch
                    sample_count += len(batch_x)  # Update sample count
                    # Update running mean tensor
                    sae.mean_ablation = running_mean_tensor(
                        sae.mean_ablation,
                        batch_mean,
                        sample_count
                    )
            pbar.update(1)  # Update progress bar
            # Stop if we've processed enough steps
            if i // batch_size >= total_steps:
                break
            cleanup_cuda()
# get_sae_means(saes, corr_tok_dataset, 5, batch_size=16)

# Completeness

In [56]:
import torch

# Number of elements to remove from the last dimension
num_remove = 5

# Number of batches to process
num_batches = 3  # Adjust this as needed
batch_size = 16  # Batch size for processing

# Remove random subsets of the mask along the [-1] dimension
def get_indices_to_remove(mask, num_remove):
    active_indices = (mask > 0).nonzero(as_tuple=True)[-1]  # Get indices of active elements in the last dimension
    if len(active_indices) < num_remove:
        raise ValueError("Not enough active elements to remove.")
    indices_to_remove = active_indices[torch.randperm(len(active_indices))[:num_remove]].to(mask.device)  # Move to the same device
    return indices_to_remove

def apply_subset_removal(mask, indices_to_remove):
    indices_to_remove = indices_to_remove.to(mask.device)  # Ensure indices are on the same device
    modified_mask = mask.clone()
    modified_mask.index_fill_(-1, indices_to_remove, 0)  # Set selected elements to 0
    return modified_mask

# Circuit (random temporary mask)
for sae in saes:
    sae.original_mask = SparseMask(sae.cfg.d_sae, l1=1.0, seq_len=67)
    with torch.no_grad():
        sae.original_mask.mask.data = torch.randint(0, 2, sae.original_mask.mask.shape).float().to(device)

# Generate means for the corrupted distribution
get_sae_means(saes, corr_tok_dataset, 5, batch_size=16)

# Helper function for batched processing
def calculate_logit_diff(data, n_batches, batch_size, use_circuit):
    total_logit_diff = 0.0
    for batch_idx in range(n_batches):
        # Get the batch data
        batch_data = data[batch_idx * batch_size : (batch_idx + 1) * batch_size]
        with torch.no_grad():
            logits = model.run_with_hooks(
                batch_data,
                return_type="logits",
                fwd_hooks=build_hooks_list(
                    batch_data, 
                    cache_sae_activations=False, 
                    use_mask=True, 
                    binarize_mask=True, 
                    mean_mask=True
                )
            )
            logit_diff = logit_diff_fn(logits).item()
            total_logit_diff += logit_diff
            del logits
            cleanup_cuda()
    # Return the average logit difference across batches
    return total_logit_diff / n_batches

# Calculate logit diff for clean and corrupted masks
batch_clean_data = clean_tok_dataset[:batch_size * num_batches]
cleanup_cuda()

# Model logit diff (All ones mask)
for sae in saes:
    with torch.no_grad():
        sae.mask.mask.data = torch.ones_like(sae.mask.mask.data).float().to(device)
model_logit_diff = calculate_logit_diff(batch_clean_data, num_batches, batch_size, use_circuit=False)
print(f"Model logit diff: {model_logit_diff}")

# Circuit logit diff (Original mask)
for sae in saes:
    with torch.no_grad():
        sae.mask.mask.data = sae.original_mask.mask.data
circuit_logit_diff = calculate_logit_diff(batch_clean_data, num_batches, batch_size, use_circuit=True)
print(f"Circuit logit diff: {circuit_logit_diff}")

# Evaluate F(C \ K) and F(M \ K) for N batches
for case in ['circuit', 'model']:
    total_knock_logit_diff = 0.0
    for batch_idx in range(num_batches):
        batch_data = batch_clean_data[batch_idx * batch_size : (batch_idx + 1) * batch_size]
        with torch.no_grad():
            for sae in saes:
                # Get indices to remove based on the original mask
                indices_to_remove = get_indices_to_remove(sae.original_mask.mask.data, num_remove)
                if case == 'circuit':
                    sae.mask.mask.data = apply_subset_removal(sae.original_mask.mask.data, indices_to_remove).to(device)
                else:
                    sae.mask.mask.data = apply_subset_removal(sae.all_ones_mask.mask.data, indices_to_remove).to(device)
            logits = model.run_with_hooks(
                batch_data,
                return_type="logits",
                fwd_hooks=build_hooks_list(
                    batch_data, 
                    cache_sae_activations=False, 
                    use_mask=True, 
                    binarize_mask=True, 
                    mean_mask=True
                )
            )
            logit_diff = logit_diff_fn(logits).item()
            total_knock_logit_diff += logit_diff
            del logits
            cleanup_cuda()
    # Average logit difference for the case
    avg_knock_logit_diff = total_knock_logit_diff / num_batches
    if case == 'circuit':
        print(f"F(C \ K): {avg_knock_logit_diff}")
    else:
        print(f"F(M \ K): {avg_knock_logit_diff}")


Mean Accum Progress: 6it [00:08,  1.38s/it]                       


Model logit diff: 3.799908002217611
Circuit logit diff: -16.73712158203125
F(C \ K): -16.738837560017902
F(M \ K): 3.792098601659139
