In [6]:
import gc
import json
import os
import random
from functools import partial
from functools import lru_cache
from pprint import pprint as pp
from typing import Tuple
from typing import TypedDict, Optional, Tuple, Union

import einops
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from datasets import load_dataset
from sae_lens import SAE, HookedSAETransformer
from torch import Tensor
from tqdm import tqdm
from transformer_lens.hook_points import (
    HookPoint,
) 
from transformers import AutoTokenizer

In [3]:
# Read the Hugging Face token from a local config file (kept out of the notebook itself).
with open("config.json", "r", encoding="utf-8") as file:
    config = json.load(file)
token = config.get("huggingface_token")
# os.environ only accepts str values (assigning None raises TypeError), so only export
# HF_TOKEN when the config actually provides one.
if token:
    os.environ["HF_TOKEN"] = token

# Use the GPU when available; the model below is loaded onto this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# NOTE(review): Hugging Face libraries usually read HF_HOME when first imported (already done
# in the imports cell); the explicit cache_dir below is what reliably selects the cache.
hf_cache = "/work/pi_jensen_umass_edu/jnainani_umass_edu/mechinterp/huggingface_cache/hub"
os.environ["HF_HOME"] = hf_cache

# Load Gemma-2-9B as a HookedSAETransformer so SAE hooks can be attached to forward passes.
model = HookedSAETransformer.from_pretrained("google/gemma-2-9b", device=device, cache_dir=hf_cache)



Device: cuda


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2-9b into HookedTransformer


In [5]:
pad_token_id = model.tokenizer.pad_token_id
# The base model is never trained here; freeze it so only the SAE masks receive gradients.
for param in model.parameters():
    param.requires_grad_(False)

# Reuse the `device` chosen when the model was loaded. Re-assigning device = "cuda" here would
# break CPU-only runs, where the model itself was loaded on the CPU.
layers = [7, 14, 21, 40]  # hook layer of each SAE
l0s = [92, 67, 129, 125]  # average_l0 variant selected for each layer (part of the SAE id)
assert len(layers) == len(l0s), "need exactly one average_l0 value per layer"
saes = [
    SAE.from_pretrained(
        release="gemma-scope-9b-pt-res",
        sae_id=f"layer_{layer}/width_16k/average_l0_{l0}",
        device=device,
    )[0]
    for layer, l0 in zip(layers, l0s)
]

# utilities

In [9]:
def cleanup_cuda():
    """Free unreferenced Python objects, then return cached CUDA blocks to the driver.

    ``gc.collect()`` runs first so tensors kept alive only by reference cycles are
    released before ``torch.cuda.empty_cache()`` tries to hand their memory back.
    In the opposite order, that memory stays in PyTorch's caching allocator.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def clear_memory():
    """Drop every gradient held by the SAEs, their masks and the base model, then free memory.

    Setting ``.grad = None`` (rather than zeroing) releases the gradient tensors themselves.
    SAEs without a ``mask`` attribute (masks are only attached in do_training_run) are
    handled instead of raising AttributeError.
    """
    for sae in saes:
        for param in sae.parameters():
            param.grad = None
        sae_mask = getattr(sae, "mask", None)
        if sae_mask is not None:
            for param in sae_mask.parameters():
                param.grad = None

    for param in model.parameters():
        param.grad = None
    cleanup_cuda()


In [8]:
# Run the Python GC and release cached GPU memory before the experiments below.
cleanup_cuda()

# Mask

In [10]:
class SparseMask(nn.Module):
    """Learnable, temperature-annealed sigmoid gate over SAE latents.

    The parameter starts at ones, so every latent starts "on". In the soft (training)
    path the gate is ``sigmoid(mask * temperature)`` with
    ``temperature = max_temp ** ratio_trained``, which sharpens towards a hard 0/1 gate
    as ``ratio_trained`` grows. ``binary=True`` uses the hard gate ``mask > 0``.

    Args:
        shape: number of latents (e.g. ``sae.cfg.d_sae``).
        l1: coefficient for ``sparsity_loss`` (sum of all soft-gate entries).
        seq_len: if given, learn one gate row per sequence position, i.e. a
            ``(seq_len, shape)`` parameter.
        distinct_l1: coefficient for ``distinct_sparsity_loss`` (per-position gates only):
            the sum over latents of the max gate value across positions, a soft count of
            the distinct latents used anywhere in the sequence.
    """

    def __init__(self, shape, l1, seq_len=None, distinct_l1=0):
        super().__init__()
        if seq_len is not None:
            self.mask = nn.Parameter(torch.ones(seq_len, shape))
        else:
            self.mask = nn.Parameter(torch.ones(shape))
        self.l1 = l1
        self.distinct_l1 = distinct_l1
        self.max_temp = torch.tensor(1000.0)
        self.sparsity_loss = None
        self.ratio_trained = 1
        self.temperature = 1
        self.distinct_sparsity_loss = 0

    def forward(self, x, binary=False, mean_ablation=None):
        """Gate ``x`` (broadcast against the mask over its trailing dimensions).

        With ``mean_ablation`` given, a closed gate yields ``mean_ablation`` instead of 0:
        ``(x - mean_ablation) * gate + mean_ablation``. The soft path also refreshes
        ``temperature``, ``sparsity_loss`` and (for 2-D masks) ``distinct_sparsity_loss``.
        """
        if binary:
            gate = (self.mask > 0).float()
        else:
            self.temperature = self.max_temp ** self.ratio_trained
            gate = torch.sigmoid(self.mask * self.temperature)
            # sigmoid is non-negative, so the plain sum is its L1 norm. Each term is scaled by
            # its own coefficient (l1 -> sparsity, distinct_l1 -> distinct); with them swapped,
            # sparsity_loss was identically 0 for the default distinct_l1=0.
            self.sparsity_loss = gate.sum() * self.l1
            if gate.dim() == 2:
                self.distinct_sparsity_loss = gate.max(dim=0).values.sum() * self.distinct_l1

        if mean_ablation is None:
            return x * gate
        return (x - mean_ablation) * gate + mean_ablation

# for sae in saes:
#     sae.mask = SparseMask(sae.cfg.d_sae, 1.0, seq_len=65)

In [14]:
# Id of the tokenizer's beginning-of-sequence token. build_sae_hook_fn leaves positions
# holding this id at their original activation instead of the SAE reconstruction.
bos_token_id = model.tokenizer.bos_token_id

def build_sae_hook_fn(
    sae,
    sequence,
    circuit_mask=None,
    use_mask=False,
    binarize_mask=False,
    mean_mask=False,
    ig_mask_threshold=None,
    cache_sae_grads=False,
    cache_masked_activations=False,
    cache_sae_activations=False,
    mean_ablate=False,
    fake_activations=False):
    """Build a hook that swaps a hook point's activations for an SAE reconstruction.

    Positions where ``sequence`` equals the global ``bos_token_id`` keep their original
    activation. Every other position becomes ``sae.decode(features)``, where ``features``
    is ``sae.encode(value)`` after these optional steps, applied in this order:

    * ``fake_activations``: ``(hook_layer, tensor)``; if ``sae.cfg.hook_layer`` matches,
      the latents are replaced by ``tensor``. ``False`` or ``None`` disables it.
    * ``cache_sae_activations``: store a detached copy of the latents on ``sae.feature_acts``.
    * ``use_mask``: apply ``sae.mask`` (binary if ``binarize_mask``; relative to
      ``sae.mean_ablation`` if ``mean_mask``).
    * ``ig_mask_threshold``: apply ``sae.igmask`` with that threshold (same ``mean_mask`` rule).
    * ``cache_masked_activations``: store a detached copy of the post-mask latents.
    * ``mean_ablate``: decode ``sae.mean_ablation`` instead of the latents.

    ``cache_sae_grads`` and ``circuit_mask`` are unsupported. The hook raises
    ``NotImplementedError`` when it runs with either one set.
    """
    # True where the SAE reconstruction is used; BOS positions keep the original activation.
    keep_sae = torch.ones_like(sequence, dtype=torch.bool)
    keep_sae[sequence == bos_token_id] = False
    # Treat both False (the default) and None as "no fake activations" (None[0] would crash).
    use_fake = fake_activations is not None and fake_activations is not False

    def sae_hook(value, hook):
        feature_acts = sae.encode(value) * keep_sae.unsqueeze(-1)
        if use_fake and sae.cfg.hook_layer == fake_activations[0]:
            feature_acts = fake_activations[1]
        if cache_sae_grads:
            raise NotImplementedError("torch is confusing")
        if cache_sae_activations:
            sae.feature_acts = feature_acts.detach().clone()

        # Learned mask (soft during training, hard when binarize_mask).
        if use_mask:
            if mean_mask:
                feature_acts = sae.mask(feature_acts, binary=binarize_mask, mean_ablation=sae.mean_ablation)
            else:
                feature_acts = sae.mask(feature_acts, binary=binarize_mask)

        # IG mask.
        if ig_mask_threshold is not None:
            if mean_mask:
                feature_acts = sae.igmask(feature_acts, threshold=ig_mask_threshold, mean_ablation=sae.mean_ablation)
            else:
                feature_acts = sae.igmask(feature_acts, threshold=ig_mask_threshold)

        if circuit_mask is not None:
            raise NotImplementedError("mask interface not supported")

        if cache_masked_activations:
            sae.feature_acts = feature_acts.detach().clone()
        if mean_ablate:
            feature_acts = sae.mean_ablation

        reconstruction = sae.decode(feature_acts)
        # torch.where broadcasts the trailing singleton dimension over the hidden size.
        return torch.where(keep_sae.unsqueeze(-1), reconstruction, value)

    return sae_hook

def build_hooks_list(sequence,
                    cache_sae_activations=False,
                    cache_sae_grads=False,
                    circuit_mask=None,
                    use_mask=False,
                    binarize_mask=False,
                    mean_mask=False,
                    cache_masked_activations=False,
                    mean_ablate=False,
                    fake_activations: Union[Tuple[int, torch.Tensor], bool] = False,
                    ig_mask_threshold=None,
                    ):
    """Return ``(hook_name, hook_fn)`` pairs, one per SAE in ``saes``, for ``run_with_hooks``.

    ``sequence`` is the token batch the hooks will see (used to locate BOS positions).
    Every other keyword is forwarded unchanged to ``build_sae_hook_fn``.
    ``fake_activations`` is ``False`` (disabled) or a ``(hook_layer, tensor)`` pair.
    """
    hook_kwargs = dict(
        cache_sae_grads=cache_sae_grads,
        circuit_mask=circuit_mask,
        use_mask=use_mask,
        binarize_mask=binarize_mask,
        cache_masked_activations=cache_masked_activations,
        cache_sae_activations=cache_sae_activations,
        mean_mask=mean_mask,
        mean_ablate=mean_ablate,
        fake_activations=fake_activations,
        ig_mask_threshold=ig_mask_threshold,
    )
    return [(sae.cfg.hook_name, build_sae_hook_fn(sae, sequence, **hook_kwargs)) for sae in saes]

def build_sae_logitfn(**kwargs):
    """Return a callable mapping tokens to logits with the SAE hooks spliced in.

    ``kwargs`` go to ``build_hooks_list`` on every call, so the hooks are rebuilt
    for (and BOS-masked against) whatever tokens are passed in.
    """
    def logitfn(tokens):
        fwd_hooks = build_hooks_list(tokens, **kwargs)
        return model.run_with_hooks(tokens, return_type="logits", fwd_hooks=fwd_hooks)

    return logitfn

# Data 

In [13]:
import json
# Contrastive pairs, one JSON object per line (clean/patch prefix plus clean/patch answer).
file_path = 'data/sva/rc_train.json'
with open(file_path, 'r', encoding='utf-8') as file:
    # Skip blank lines (e.g. a trailing newline), which json.loads would reject.
    data = [json.loads(line) for line in file if line.strip()]
# Show one example record.
if data:
    print(data[0])

{'clean_prefix': 'The friends that the dancer visits', 'patch_prefix': 'The friend that the dancer visits', 'clean_answer': ' go', 'patch_answer': ' goes', 'case': 'plural_singular'}


In [93]:
# Keep only pairs whose clean prefix tokenises (with BOS) to exactly 7 tokens, so the
# prompts batch together without padding.
clean_data, corr_data, clean_labels, corr_labels = [], [], [], []
for entry in data:
    if model.to_tokens(entry['clean_prefix']).shape[-1] != 7:
        continue
    clean_data.append(entry['clean_prefix'])
    corr_data.append(entry['patch_prefix'])
    clean_labels.append(entry['clean_answer'])
    corr_labels.append(entry['patch_answer'])

In [71]:
# Quick qualitative probe: generate 5 tokens after a singular-subject prompt.
model.generate("The friend that the dancer visits", max_new_tokens=5)

  0%|          | 0/5 [00:00<?, ?it/s]

'The friend that the dancer visits <u>had just got fired'

In [92]:
# Probe: print the tokenisation and the model's top next-token predictions for a
# singular-subject prompt, scored against the plural answer " are".
from transformer_lens.utils import test_prompt
test_prompt("The doctor that the chef forgives", " are", model)

Tokenized prompt: ['<bos>', 'The', ' doctor', ' that', ' the', ' chef', ' for', 'gives']
Tokenized answer: [' are']


Top 0th token. Logit: 29.81 Prob: 20.53% Token: | is|
Top 1th token. Logit: 28.81 Prob:  7.56% Token: |,|
Top 2th token. Logit: 28.64 Prob:  6.40% Token: | and|
Top 3th token. Logit: 28.63 Prob:  6.35% Token: | the|
Top 4th token. Logit: 28.35 Prob:  4.76% Token: | for|
Top 5th token. Logit: 28.29 Prob:  4.49% Token: | him|
Top 6th token. Logit: 28.16 Prob:  3.96% Token: |.|
Top 7th token. Logit: 27.99 Prob:  3.33% Token: | his|
Top 8th token. Logit: 27.98 Prob:  3.30% Token: | in|
Top 9th token. Logit: 27.88 Prob:  3.00% Token: | himself|


In [89]:
# NOTE(review): `entry` is whatever an earlier loop over `data` left bound (hidden notebook
# state), not a fixed example.
model.to_tokens(entry['clean_prefix']).shape[-1]

7

In [94]:
# Spot-check one of the filtered 7-token clean prompts.
clean_data[2]

'The carpenters that the dancers praise'

In [None]:
# NOTE(review): relies on `entry` leaked from an earlier loop over `data` (hidden state).
model.to_tokens(entry['clean_prefix'])

In [110]:
# Tokenise the first N clean / corrupted prompts and their answers.
N = 10
clean_tokens = model.to_tokens(clean_data[:N])
corr_tokens = model.to_tokens(corr_data[:N])
# The length filter only checked clean prefixes; make sure corrupted prompts line up too.
assert clean_tokens.shape == corr_tokens.shape, (clean_tokens.shape, corr_tokens.shape)
clean_label_tokens = model.to_tokens(clean_labels[:N], prepend_bos=False).squeeze(-1)
corr_label_tokens = model.to_tokens(corr_labels[:N], prepend_bos=False).squeeze(-1)
# squeeze(-1) silently does nothing if an answer spans several tokens; fail loudly instead.
assert clean_label_tokens.dim() == 1 and corr_label_tokens.dim() == 1, "answers must be single tokens"
print(clean_tokens.shape, corr_tokens.shape)

torch.Size([10, 7]) torch.Size([10, 7])


In [117]:
use_mask, mean_mask = False, False
# Baseline: every SAE spliced into the model, no learned mask applied.
logits = model.run_with_hooks(
    clean_tokens,
    return_type="logits",
    fwd_hooks=build_hooks_list(
        clean_tokens,
        use_mask=use_mask,
        mean_mask=mean_mask,
    ),
)
def logit_diff_fn(logits, clean_labels, corr_labels, token_wise=False):
    """Clean-minus-corrupted answer logit at the final position.

    Args:
        logits: ``(batch, seq, vocab)`` logits.
        clean_labels / corr_labels: ``(batch,)`` answer token ids.
        token_wise: return the per-example differences instead of their mean.
    """
    final_logits = logits[:, -1, :]
    rows = torch.arange(final_logits.shape[0])
    per_example = final_logits[rows, clean_labels] - final_logits[rows, corr_labels]
    if token_wise:
        return per_example
    return per_example.mean()

In [118]:
# Baseline metric: mean (clean - corrupted) answer logit at the final position.
logit_diff_fn(logits, clean_label_tokens, corr_label_tokens)

tensor(2.8465, device='cuda:0', grad_fn=<MeanBackward0>)

In [None]:
@lru_cache(maxsize=None)
def memoize_encode(text):
    """Tokenise ``text`` with ``model.to_tokens``, caching one result per distinct string.

    Every cache hit returns the same tensor object, so callers must not modify it in place.
    """
    token_ids = model.to_tokens(text)
    return token_ids

def response_to_token(response):
    """Return the id of the first token of ``str(response)`` (index 1, just after BOS)."""
    flat_ids = memoize_encode(str(response)).squeeze()
    return flat_ids[1]

In [None]:
# Flatten contrastive examples into (prompt, label) pairs: each item contributes its
# "correct" and its "error" prompt, labelled with the first token of the matching response.
# NOTE(review): `json_dataset` is not defined anywhere in this notebook; load it first.
# The 'logit_diff' loss in do_training_run also needs these two tokens defined:
# answer_token = model.to_single_token("1")
# traceback_token = model.to_single_token("Traceback")
simple_dataset = []
simple_labels = []
for item in json_dataset:
    for variant in ("correct", "error"):
        simple_dataset.append(item[variant]["prompt"])
        simple_labels.append(response_to_token(item[variant]["response"]))

simple_dataset = model.to_tokens(simple_dataset)
# torch.stack keeps the label ids on the device the tokenizer put them on (presumably the
# model device); torch.tensor(list_of_tensors) rebuilds them element-wise on the default device.
simple_labels = torch.stack(simple_labels)



In [None]:
# Sorted distinct label token ids; get_highest_other_logit restricts its max to these.
vocab = simple_labels.view(-1).unique()

def get_highest_other_logit(logits, label, vocab_ids=None):
    """Largest label-vocabulary logit other than each row's own label.

    Args:
        logits: ``(..., vocab_size)`` logits.
        label: token id to exclude; either a scalar (excluded everywhere) or an integer
            tensor of shape ``logits.shape[:-1]`` (one id per row).
        vocab_ids: candidate competitor ids. Defaults to the module-level ``vocab``
            (the distinct ids in ``simple_labels``).

    Returns:
        ``logits.shape[:-1]`` tensor of maxima; ``-inf`` where no candidate remains.
    """
    if vocab_ids is None:
        vocab_ids = vocab
    other_logits = logits.clone()
    label = torch.as_tensor(label, device=logits.device).long()
    if label.dim() == 0:
        other_logits[..., label] = -float("inf")
    else:
        # Exclude each row's *own* label. `other_logits[..., label]` would blank every label
        # in the batch for every row, leaving no competitor once the batch covers the vocabulary.
        other_logits.scatter_(-1, label.unsqueeze(-1), -float("inf"))
    # Restrict the max to the label vocabulary by adding -inf everywhere else.
    vocab_filter = torch.full_like(other_logits, -float("inf"))
    vocab_filter[..., vocab_ids] = 0
    return (other_logits + vocab_filter).max(dim=-1).values

def get_highest_other_prob(logits, label):
    """Largest entry other than each row's own label, over the whole vocabulary.

    Despite the parameter name, callers pass probabilities (compute_token_probabilities
    passes a softmax). The excluded entry is set to 0, so non-negative inputs are assumed.

    Args:
        logits: ``(..., vocab_size)`` probabilities.
        label: scalar id (excluded everywhere) or integer tensor of shape
            ``logits.shape[:-1]`` (one id per row).
    """
    other_probs = logits.clone()
    label = torch.as_tensor(label, device=logits.device).long()
    if label.dim() == 0:
        other_probs[..., label] = 0
    else:
        # Zero only each row's own label (`[..., label]` would zero every batch label in every row).
        other_probs.scatter_(-1, label.unsqueeze(-1), 0.0)
    return other_probs.max(dim=-1).values


In [None]:
def compute_token_probabilities(logits, seq_idxs, labels):
    """Summarise next-token predictions at one answer position per example.

    Args:
        logits: ``(batch, seq, vocab)`` logits.
        seq_idxs: answer position per example (indexed together with ``arange(batch)``).
        labels: ``(batch,)`` target token ids.

    Returns:
        ``(prob_correct, prob_error, logit_diff, top1_acc, ce_loss)`` as Python floats:
        mean label probability; mean highest non-label probability; mean of
        (label logit - highest other label-vocabulary logit); top-1 accuracy; mean CE.
    """
    rows = torch.arange(logits.shape[0])
    # Logits at each example's answer position, computed once and reused below.
    answer_logits = logits[rows, seq_idxs]
    probs = F.softmax(answer_logits, dim=-1)

    prob_correct = probs[rows, labels].mean().item()
    top1_acc = (probs.argmax(dim=-1) == labels).float().mean().item()
    prob_error = get_highest_other_prob(probs, labels).mean().item()

    correct_logits = answer_logits[rows, labels]
    error_logits = get_highest_other_logit(answer_logits, labels)
    logit_diff = (correct_logits - error_logits).mean().item()

    ce_loss = F.cross_entropy(answer_logits, labels).item()
    return prob_correct, prob_error, logit_diff, top1_acc, ce_loss

def sanity_check_model_performance(logitfn):
    """Score ``logitfn`` on the last 10 contrastive items of ``json_dataset``.

    Both halves of each pair (correct prompt, error prompt) are run. Keys marked
    ``(bad)`` are quantities that should be small; ``(-=bad)`` means negative is bad.
    """
    batch = ContrastiveDatasetBatch(json_dataset[-10:])

    logits_on_correct = logitfn(batch.correct_tokenized)
    logits_on_error = logitfn(batch.error_tokenized)

    p_label_c, p_other_c, diff_c, top1_c, ce_c = compute_token_probabilities(
        logits_on_correct, batch.correct_answer_seq_idxs, batch.correct_labels
    )
    p_label_e, p_other_e, diff_e, top1_e, ce_e = compute_token_probabilities(
        logits_on_error, batch.error_answer_seq_idxs, batch.error_labels
    )

    return {
        'prob_age_given_correct': p_label_c,
        'prob_traceback_given_correct(bad)': p_other_c,
        'logit_diff_correct(-=bad)': diff_c,
        'top1_correct': top1_c,
        'ce_loss_correct': ce_c,
        'prob_traceback_given_error': p_label_e,
        'prob_age_given_error(bad)': p_other_e,
        'logit_diff_error(-=bad)': diff_e,
        'top1_error': top1_e,
        'ce_loss_error': ce_e,
    }

In [None]:
# Id of the tokenizer's beginning-of-sequence token (re-bound here for the redefined hooks
# below); BOS positions keep their original activation instead of the SAE reconstruction.
bos_token_id = model.tokenizer.bos_token_id

def build_sae_hook_fn(
    # Core components
    sae,
    sequence,

    # Masking options
    circuit_mask=None,
    use_mask=False,
    binarize_mask=False,
    mean_mask=False,
    ig_mask_threshold=None,

    # Caching behavior
    cache_sae_grads=False,
    cache_masked_activations=False,
    cache_sae_activations=False,

    # Ablation options
    mean_ablate=False,  # Controls mean ablation of the SAE
    fake_activations=False,  # Controls whether to use fake activations
    ):
    """Build a hook that swaps a hook point's activations for an SAE reconstruction.

    Positions where ``sequence`` equals the global ``bos_token_id`` keep their original
    activation. Every other position becomes ``sae.decode(features)``, where ``features``
    is ``sae.encode(value)`` after these optional steps, applied in this order:

    * ``fake_activations``: ``(hook_layer, tensor)``; if ``sae.cfg.hook_layer`` matches,
      the latents are replaced by ``tensor``. ``False`` or ``None`` disables it.
    * ``cache_sae_activations``: store a detached copy of the latents on ``sae.feature_acts``.
    * ``use_mask``: apply ``sae.mask`` (binary if ``binarize_mask``; relative to
      ``sae.mean_ablation`` if ``mean_mask``).
    * ``ig_mask_threshold``: apply ``sae.igmask`` with that threshold (same ``mean_mask`` rule).
    * ``cache_masked_activations``: store a detached copy of the post-mask latents.
    * ``mean_ablate``: decode ``sae.mean_ablation`` instead of the latents.

    ``cache_sae_grads`` and ``circuit_mask`` are unsupported. The hook raises
    ``NotImplementedError`` when it runs with either one set.
    """
    # True where the SAE reconstruction is used; BOS positions keep the original activation.
    keep_sae = torch.ones_like(sequence, dtype=torch.bool)
    keep_sae[sequence == bos_token_id] = False
    # Treat both False (the default) and None as "no fake activations" (None[0] would crash).
    use_fake = fake_activations is not None and fake_activations is not False

    def sae_hook(value, hook):
        feature_acts = sae.encode(value) * keep_sae.unsqueeze(-1)
        if use_fake and sae.cfg.hook_layer == fake_activations[0]:
            feature_acts = fake_activations[1]
        if cache_sae_grads:
            raise NotImplementedError("torch is confusing")
        if cache_sae_activations:
            sae.feature_acts = feature_acts.detach().clone()

        # Learned mask (soft during training, hard when binarize_mask).
        if use_mask:
            if mean_mask:
                feature_acts = sae.mask(feature_acts, binary=binarize_mask, mean_ablation=sae.mean_ablation)
            else:
                feature_acts = sae.mask(feature_acts, binary=binarize_mask)

        # IG mask.
        if ig_mask_threshold is not None:
            if mean_mask:
                feature_acts = sae.igmask(feature_acts, threshold=ig_mask_threshold, mean_ablation=sae.mean_ablation)
            else:
                feature_acts = sae.igmask(feature_acts, threshold=ig_mask_threshold)

        if circuit_mask is not None:
            raise NotImplementedError("mask interface not supported")

        if cache_masked_activations:
            sae.feature_acts = feature_acts.detach().clone()
        if mean_ablate:
            feature_acts = sae.mean_ablation

        reconstruction = sae.decode(feature_acts)
        # torch.where broadcasts the trailing singleton dimension over the hidden size.
        return torch.where(keep_sae.unsqueeze(-1), reconstruction, value)

    return sae_hook


def build_hooks_list(sequence,
                    cache_sae_activations=False,
                    cache_sae_grads=False,
                    circuit_mask=None,
                    use_mask=False,
                    binarize_mask=False,
                    mean_mask=False,
                    cache_masked_activations=False,
                    mean_ablate=False,
                    fake_activations: Union[Tuple[int, torch.Tensor], bool] = False,
                    ig_mask_threshold=None,
                    ):
    """Return ``(hook_name, hook_fn)`` pairs, one per SAE in ``saes``, for ``run_with_hooks``.

    ``sequence`` is the token batch the hooks will see (used to locate BOS positions).
    Every other keyword is forwarded unchanged to ``build_sae_hook_fn``.
    ``fake_activations`` is ``False`` (disabled) or a ``(hook_layer, tensor)`` pair.
    """
    hook_kwargs = dict(
        cache_sae_grads=cache_sae_grads,
        circuit_mask=circuit_mask,
        use_mask=use_mask,
        binarize_mask=binarize_mask,
        cache_masked_activations=cache_masked_activations,
        cache_sae_activations=cache_sae_activations,
        mean_mask=mean_mask,
        mean_ablate=mean_ablate,
        fake_activations=fake_activations,
        ig_mask_threshold=ig_mask_threshold,
    )
    return [(sae.cfg.hook_name, build_sae_hook_fn(sae, sequence, **hook_kwargs)) for sae in saes]

def build_sae_logitfn(**kwargs):
    """Return a callable mapping tokens to logits with the SAE hooks spliced in.

    ``kwargs`` go to ``build_hooks_list`` on every call, so the hooks are rebuilt
    for (and BOS-masked against) whatever tokens are passed in.
    """
    def logitfn(tokens):
        fwd_hooks = build_hooks_list(tokens, **kwargs)
        return model.run_with_hooks(tokens, return_type="logits", fwd_hooks=fwd_hooks)

    return logitfn

In [None]:
batch_size = 16

# Only reshape while the data is still flat (num_examples, seq_len), so re-running this
# cell does not try to re-batch tensors that are already (num_batches, batch_size, ...).
if simple_dataset.dim() == 2:
    prompt_len = simple_dataset.shape[-1]  # actual prompt length instead of a hard-coded 65
    # Drop the tail so each tensor's length is a multiple of the batch size.
    simple_dataset = simple_dataset[:batch_size * (len(simple_dataset) // batch_size)]
    simple_labels = simple_labels[:batch_size * (len(simple_labels) // batch_size)]

    simple_dataset = simple_dataset.view(-1, batch_size, prompt_len)
    simple_labels = simple_labels.view(-1, batch_size)


In [None]:
def logitfn(tokens):
    """Logits with the soft learned masks applied relative to each SAE's mean_ablation."""
    masked_hooks = build_hooks_list(tokens, use_mask=True, mean_mask=True)
    return model.run_with_hooks(tokens, return_type="logits", fwd_hooks=masked_hooks)


def forward_pass(batch, labels, logitfn, ratio_trained=1):
    """Return ``(final-position CE loss, mean per-SAE mask sparsity_loss)``.

    Sets ``ratio_trained`` on every SAE mask before running ``logitfn(batch)`` so the
    masks use the matching temperature.
    """
    for sae in saes:
        sae.mask.ratio_trained = ratio_trained
    final_logits = logitfn(batch)[:, -1, :]
    task_loss = F.cross_entropy(final_logits, labels)
    mean_sparsity = sum(sae.mask.sparsity_loss for sae in saes) / len(saes)
    return task_loss, mean_sparsity

def mask_logits(logits, vocab_ids=None):
    """
    Mask logits to only allow tokens in the label vocabulary by setting all other logits to -inf.
    Works with any number of leading batch dimensions.

    Args:
        logits: Tensor of shape [..., vocab_size] containing the logits
        vocab_ids: allowed token ids. Defaults to the distinct ids in ``simple_labels``.

    Returns:
        Tensor of same shape as logits with all non-vocab logits set to -inf
    """
    if vocab_ids is None:
        vocab_ids = simple_labels.view(-1).unique()
    # A 1-D boolean keep-mask broadcast over the leading dims, instead of a full float copy of logits.
    keep = torch.zeros(logits.shape[-1], dtype=torch.bool, device=logits.device)
    keep[torch.as_tensor(vocab_ids, device=logits.device)] = True
    return logits.masked_fill(~keep, float('-inf'))

In [None]:
def do_training_run(sparsity_multiplier, loss_function='ce', per_token_mask=False, use_mask=False, mean_mask=False, distinct_sparsity_multiplier=0):
    """Train a SparseMask on every SAE, then evaluate and save the resulting binary circuit.

    Masks are optimised with Adam on ``task loss + sparsity penalty``. The sigmoid
    temperature is annealed through ``ratio_trained = step / total_steps``, and training
    stops once ``step >= 1.3 * total_steps``. Afterwards the masks are binarised with
    ``mask > 0`` (the rule SparseMask uses for ``binary=True``), summarised, evaluated on
    the last batch of ``simple_dataset``, and written to ``"{sparsity_multiplier}_run.json"``.

    Args:
        sparsity_multiplier: weight on the mean per-SAE ``sparsity_loss``.
        loss_function: ``'ce'`` (final-position cross-entropy), ``'ce_vocab'`` (the same after
            ``mask_logits``) or ``'logit_diff'`` (needs ``answer_token`` / ``traceback_token``).
        per_token_mask: learn one mask row per sequence position.
        use_mask, mean_mask: forwarded to ``build_hooks_list``.
        distinct_sparsity_multiplier: weight on the distinct-latent penalty (per-token masks only).

    Side effects: replaces ``sae.mask`` on every SAE in ``saes``, logs to wandb, writes a JSON file.
    NOTE(review): ``wandb`` and ``KeyboardInterruptBlocker`` are not defined in this notebook.
    """
    # All prompts share one length; per-token masks need one row per position.
    seq_len = simple_dataset.shape[-1]

    def logitfn(tokens):
        logits = model.run_with_hooks(
            tokens,
            return_type="logits",
            fwd_hooks=build_hooks_list(tokens, use_mask=use_mask, mean_mask=mean_mask)
            )
        if loss_function == 'ce_vocab':
            return mask_logits(logits)
        return logits

    def forward_pass(batch, labels, logitfn, ratio_trained=1):
        for sae in saes:
            sae.mask.ratio_trained = ratio_trained
        logits = logitfn(batch)
        last_token_logits = logits[:, -1, :]
        if loss_function in ('ce', 'ce_vocab'):
            loss = F.cross_entropy(last_token_logits, labels)
        elif loss_function == 'logit_diff':
            rows = torch.arange(last_token_logits.shape[0])
            correct_logits = last_token_logits[rows, labels]
            incorrect_labels = torch.where(labels == answer_token, traceback_token, answer_token)
            incorrect_logits = last_token_logits[rows, incorrect_labels]
            loss = (incorrect_logits - correct_logits).mean()  # very negative when the model is right
        else:
            raise ValueError(f"Loss function {loss_function} not recognized")

        # Initialise both accumulators so the non-per-token path (which never adds to
        # distinct_sparsity_loss) still has a value to return.
        sparsity_loss = 0
        distinct_sparsity_loss = 0
        for sae in saes:
            sparsity_loss = sparsity_loss + sae.mask.sparsity_loss
            if per_token_mask:
                distinct_sparsity_loss = distinct_sparsity_loss + sae.mask.distinct_sparsity_loss
        return loss, sparsity_loss / len(saes), distinct_sparsity_loss / len(saes)

    print("doing a run with sparsity multiplier", sparsity_multiplier)
    config = {
        "batch_size": 16,
        "learning_rate": 0.05,
        "total_steps": simple_dataset.shape[0]*0.01,
        "sparsity_multiplier": sparsity_multiplier
    }

    all_optimized_params = []
    for sae in saes:
        # Both coefficients are 1.0, so the per-SAE sparsity term is the plain soft count of
        # kept latents whichever coefficient SparseMask applies to it; the multipliers below
        # do the weighting.
        sae.mask = SparseMask(sae.cfg.d_sae, 1.0, seq_len=seq_len if per_token_mask else None, distinct_l1=1.0)
        all_optimized_params.extend(sae.mask.parameters())
        sae.mask.max_temp = torch.tensor(200.0)

    wandb.init(project="sae circuits", config=config)
    optimizer = optim.Adam(all_optimized_params, lr=config["learning_rate"])
    total_steps = config["total_steps"]

    with tqdm(total=total_steps, desc="Training Progress") as pbar:
        for i, (x, y) in enumerate(zip(simple_dataset, simple_labels)):
            with KeyboardInterruptBlocker():
                optimizer.zero_grad()
                ratio_trained = i / total_steps
                # forward_pass sets ratio_trained on every mask before running the model.
                loss, sparsity_loss, distinct_sparsity_loss = forward_pass(x, y, logitfn, ratio_trained=ratio_trained)
                if per_token_mask:
                    # Per-position masks sum over every row; normalise to a per-position count.
                    sparsity_loss = sparsity_loss / seq_len

                avg_nonzero_elements = sparsity_loss
                avg_distinct_nonzero_elements = distinct_sparsity_loss

                sparsity_loss = sparsity_loss * config["sparsity_multiplier"] + distinct_sparsity_loss * distinct_sparsity_multiplier
                total_loss = loss + sparsity_loss
                # float() accepts both 0-dim tensors and the plain 0 of a disabled term.
                infodict = {
                    "Step": i,
                    "Progress": ratio_trained,
                    "Avg Nonzero Elements": float(avg_nonzero_elements),
                    "avg distinct lat/sae": float(avg_distinct_nonzero_elements),
                    "Task Loss": loss.item(),
                    "Sparsity Loss": float(sparsity_loss),
                    "temperature": float(saes[0].mask.temperature),
                }
                wandb.log(infodict)

                total_loss.backward()
                optimizer.step()

                pbar.set_postfix(infodict)
                pbar.update(1)
                if i >= total_steps*1.3:
                    break
    wandb.finish()

    optimizer.zero_grad()
    for sae in saes:
        for param in sae.parameters():
            param.grad = None
        for param in sae.mask.parameters():
            param.grad = None
    for param in model.parameters():
        param.grad = None
    torch.cuda.empty_cache()

    mask_dict = {}
    total_density = 0
    for sae in saes:
        # Same `mask > 0` rule as the binarised evaluation, so the saved circuit matches it.
        active = sae.mask.mask.detach() > 0
        indices = torch.nonzero(active)
        # 1-D masks -> latent indices; per-token (2-D) masks -> [position, latent] pairs.
        mask_dict[sae.cfg.hook_name] = (indices.squeeze(-1) if active.dim() == 1 else indices).tolist()
        total_density += int(active.sum().item())
    mask_dict["total_density"] = total_density
    mask_dict['avg_density'] = total_density / len(saes)

    if per_token_mask:
        print("total # latents in circuit: ", total_density)
    print("avg density", mask_dict['avg_density'])

    ### EVAL ###
    def masked_logit_fn(tokens):
        logits = model.run_with_hooks(
            tokens,
            return_type="logits",
            fwd_hooks=build_hooks_list(tokens, use_mask=use_mask, mean_mask=mean_mask, binarize_mask=True)
            )
        if loss_function == 'ce_vocab':
            return mask_logits(logits)
        return logits

    with torch.no_grad():
        for sae in saes:
            sae.mask.ratio_trained = 10
        last_token_logits = masked_logit_fn(simple_dataset[-1])[:, -1, :]
        loss = F.cross_entropy(last_token_logits, simple_labels[-1])
        print("CE loss:", loss)

    mask_dict['ce_loss'] = loss.item()

    # Close the file deterministically instead of leaking the handle.
    with open(f"{str(sparsity_multiplier)}_run.json", "w") as out_file:
        json.dump(mask_dict, out_file)
    
    

In [None]:
def masked_logit_fn(tokens):
    """Logits with the learned masks binarised (mask > 0) and applied relative to mean_ablation."""
    binary_hooks = build_hooks_list(tokens, use_mask=True, mean_mask=True, binarize_mask=True)
    return model.run_with_hooks(tokens, return_type="logits", fwd_hooks=binary_hooks)

def eval_ce_loss(batch, labels, logitfn, ratio_trained=10):
    """Cross-entropy of ``logitfn(batch)`` at the final position against ``labels``.

    ``ratio_trained`` is written to every SAE mask first; SparseMask's binary path
    ignores it.
    """
    for sae in saes:
        sae.mask.ratio_trained = ratio_trained
    final_position_logits = logitfn(batch)[:, -1, :]
    return F.cross_entropy(final_position_logits, labels)

# Evaluate the binarised masks on one fixed batch. NOTE(review): index 21 looks arbitrary;
# confirm it is not among the batches do_training_run trained on.
with torch.no_grad():
    loss = eval_ce_loss(simple_dataset[21], simple_labels[21], masked_logit_fn)
    
    print("CE loss:", loss)



In [None]:
# Launch one mask-training run. Earlier configurations are kept below for reference.
# do_training_run(14, loss_function='logit_diff', per_token_mask=True, use_mask=True, mean_mask=True)
# do_training_run(0.1, loss_function="ce", per_token_mask=True, use_mask=True, mean_mask=True, distinct_sparsity_multiplier=0.02)
# do_training_run(0.025, per_token_mask=False, use_mask=True, mean_mask=True)
do_training_run(0.6, loss_function="ce", per_token_mask=True, use_mask=True, mean_mask=True)


In [None]:
def logitfn(tokens, vocab_mask=False):
    """Logits with binarised, mean-relative masks; optionally restricted to the label vocabulary."""
    hooks = build_hooks_list(tokens, use_mask=True, mean_mask=True, binarize_mask=True)
    logits = model.run_with_hooks(tokens, return_type="logits", fwd_hooks=hooks)
    return mask_logits(logits) if vocab_mask else logits

# Report metrics for the binarised circuit on the last 10 contrastive items.
with torch.no_grad():
    pp(sanity_check_model_performance(logitfn))

In [None]:
def logitfn(tokens, vocab_mask=False):
    """Logits with binarised, mean-relative masks; optionally restricted to the label vocabulary.

    (An unreachable ``return loss, sparsity_loss`` and a commented-out forward_pass that
    trailed this body have been removed.)
    """
    hooks = build_hooks_list(tokens, use_mask=True, mean_mask=True, binarize_mask=True)
    logits = model.run_with_hooks(tokens, return_type="logits", fwd_hooks=hooks)
    if vocab_mask:
        return mask_logits(logits)
    return logits
# Inspect the binarised circuit's prediction on one example from the last batch.
with torch.no_grad():
    idx = 0
    print("original example")
    print(model.to_string(simple_dataset[-1][idx]))
    print("original label")
    print(model.to_string(simple_labels[-1][idx:idx+1]))

    # simple_dataset[-1][idx:idx+1][-1] is just simple_dataset[-1][idx]: a single 1-D prompt.
    # [-1][-1] then selects the last batch row and its last position.
    logits = logitfn(simple_dataset[-1][idx])[-1][-1]
    print(logits.shape)
    probs = F.softmax(logits, dim=-1)
    print(probs.shape)

    topk = torch.topk(probs, k=3)  # three most likely next tokens
    print(model.to_str_tokens(topk.indices))
    print(topk.values)

    loss = F.cross_entropy(logits.unsqueeze(0), simple_labels[-1][idx:idx+1])
    print(f"ce loss {loss}")

# Top-1 accuracy of the binarised circuit on 16 examples of the second-to-last batch.
with torch.no_grad():
    correct = 0
    num_items = 16
    for i in range(num_items):
        example, label = simple_dataset[-2][i], simple_labels[-2][i]
        logits = logitfn(example)[-1][-1]
        top_logit = torch.argmax(logits)
        is_right = bool(top_logit == label)
        print("correct" if is_right else "error", model.to_string(label))
        correct += int(is_right)

    print(f"correct: {correct} out of {num_items}")

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import ListedColormap, BoundaryNorm

# Heatmap of how many latents each SAE mask keeps (mask > 0) at every token
# position of one example, followed by per-SAE and total kept-latent counts.
# NOTE(review): assumes each mask is indexed [position, latent] with at least
# len(tokens) positions — TODO confirm against SparseMask.
tokens = model.to_str_tokens(simple_dataset[2][0])
num_masks = len(saes)  # one mask per SAE
counts_per_mask = []
for sae in saes:
    binarized = sae.mask.mask.detach() > 0.0
    counts_per_mask.append(
        [torch.count_nonzero(binarized[i]).item() for i in range(len(tokens))]
    )

data = np.array(counts_per_mask)  # shape: (num_masks, num_tokens)
zero_mask = data == 0  # seaborn leaves masked (zero-count) cells blank
cmap = sns.color_palette("viridis", as_cmap=True)

fig, ax = plt.subplots(figsize=(14, 6))
sns.heatmap(
    data,
    annot=True,
    fmt='d',
    cmap=cmap,
    mask=zero_mask,
    cbar_kws={'label': 'Counts'},
    linewidths=0.5,
    linecolor='gray',
    ax=ax,
)

# Bottom axis: token strings.
ax.set_xticks(np.arange(len(tokens)) + 0.5)
ax.set_xticklabels(tokens, rotation=90, fontsize=8)
ax.set_xlabel('Tokens')

# Top axis: numeric token indices on a twin axis aligned with the heatmap.
ax2 = ax.twiny()
ax2.set_xlim(ax.get_xlim())
ax2.set_xticks(np.arange(len(tokens)) + 0.5)
ax2.set_xticklabels(np.arange(len(tokens)), rotation=90)
ax2.set_xlabel('Token Indices')

# Set labels on `ax` explicitly. After twiny() the pyplot "current axes" is
# ax2, whose y-axis is hidden, so plt.xlabel/plt.ylabel would overwrite the
# top label and put the y label on the hidden axis.
ax.set_yticks(np.arange(num_masks) + 0.5)
ax.set_yticklabels([f'SAE {i}' for i in range(num_masks)], rotation=0)
ax.set_ylabel('SAE Number Active Latents')
ax.set_title('Active SAE Latents per Token per Mask (Zero Counts Hidden)')
fig.tight_layout()
plt.show()

# Number of kept (mask > 0) entries per SAE hook and in the whole circuit.
total_latents = 0
for sae in saes:
    n_active = int(torch.count_nonzero(sae.mask.mask > 0).item())
    print(f"{sae.cfg.hook_name} latents: {n_active}")
    total_latents += n_active
print(f"total # latents: {total_latents}")


In [None]:
# Latent indices kept by saes[2]'s mask at index 64 of its first dimension
# (presumably a token position — confirm against SparseMask).
(saes[2].mask.mask[64] > 0).nonzero()

In [None]:
# Decode example 2 of the first dataset split for a quick look.
example_text = model.to_string(simple_dataset[0][2])
print(example_text)

In [None]:
def masked_logit_fn(tokens):
    """Return logits with the mask hooks from ``build_hooks_list`` attached.

    The hooks are built with ``use_mask``, ``binarize_mask`` and ``mean_mask``
    all set to True.
    """
    hook_list = build_hooks_list(
        tokens,
        use_mask=True,
        binarize_mask=True,
        mean_mask=True,
    )
    return model.run_with_hooks(tokens, return_type="logits", fwd_hooks=hook_list)

def baseline_logit_fn(tokens):
    """Return logits from the model with no forward hooks attached."""
    no_hooks = []
    return model.run_with_hooks(tokens, return_type="logits", fwd_hooks=no_hooks)

def baseline_sae_logit_fn(tokens):
    """Return logits with the hooks ``build_hooks_list`` builds at its defaults.

    No mask flags are passed to ``build_hooks_list``.
    """
    sae_hooks = build_hooks_list(tokens)
    return model.run_with_hooks(tokens, return_type="logits", fwd_hooks=sae_hooks)

# Serialize the learned masks. For each SAE hook, store the coordinates of every
# mask entry kept by the binarized mask (mask > 0, the same threshold the
# latent-count code uses). torch.nonzero returns one coordinate row per kept
# entry, so a multi-dimensional mask keeps every index (e.g. position and
# latent), not just the first one.
mask_dict = {}
for sae in saes:
    kept = sae.mask.mask > 0
    mask_dict[sae.cfg.hook_name] = torch.nonzero(kept).tolist()

with open("mask_dict.json", "w") as mask_file:
    json.dump(mask_dict, mask_file)

# Count kept entries with the same `mask > 0` threshold used for serialization.
for sae in saes:
    n_kept = torch.count_nonzero(sae.mask.mask > 0).item()
    print(f"Nonzero elements in mask for {sae.cfg.hook_name}: {n_kept}")

def eval_ce_loss(batch, labels, logitfn, ratio_trained=10):
    """Cross-entropy of ``logitfn``'s final-position logits against ``labels``.

    Side effect: assigns ``ratio_trained`` to every ``sae.mask`` in the global
    ``saes`` list before the forward pass; the value stays set afterwards.

    NOTE(review): ``eval_ce_loss`` is redefined later in this cell with a version
    that returns a Python float; this one returns a tensor.

    Args:
        batch: tokens forwarded unchanged to ``logitfn``.
        labels: target ids for the final position (presumably shape (batch,)
            — TODO confirm).
        logitfn: callable returning logits indexed as [batch, position, class].
        ratio_trained: value assigned to each ``sae.mask.ratio_trained``.

    Returns:
        The mean cross-entropy loss as a 0-d tensor.
    """
    for sae in saes:
        sae.mask.ratio_trained = ratio_trained
    logits = logitfn(batch)
    # Only the prediction at the final position is scored.
    return F.cross_entropy(logits[:, -1, :], labels)

# CE loss on the last split for the unhooked model, the model with default SAE
# hooks, and the model with the binarized mask hooks.
for loss_caption, variant_logit_fn in (
    ("baseline loss:", baseline_logit_fn),
    ("sae loss: ", baseline_sae_logit_fn),
    ("ablated loss: ", masked_logit_fn),
):
    with torch.no_grad():
        loss = eval_ce_loss(simple_dataset[-1], simple_labels[-1], variant_logit_fn)
        print(loss_caption, loss)

def eval_ce_loss(batch, labels, logitfn, ratio_trained=10):
    """Cross-entropy of ``logitfn``'s final-position logits, as a Python float.

    Side effect: assigns ``ratio_trained`` to every ``sae.mask`` in the global
    ``saes`` list before the forward pass; the value stays set afterwards.

    Args:
        batch: tokens forwarded unchanged to ``logitfn``.
        labels: target ids for the final position (presumably shape (batch,)
            — TODO confirm).
        logitfn: callable returning logits indexed as [batch, position, class].
        ratio_trained: value assigned to each ``sae.mask.ratio_trained``.

    Returns:
        The mean cross-entropy loss converted to a float via ``.item()``.
    """
    for sae in saes:
        sae.mask.ratio_trained = ratio_trained
    logits = logitfn(batch)
    # Only the prediction at the final position is scored.
    loss = F.cross_entropy(logits[:, -1, :], labels)
    return loss.item()



In [None]:
# Compare three configurations on the sanity-check metrics plus CE loss on the
# last split. Each entry is (display name, logit function).
methods = [
    ("VANILLA GEMMA 9B", baseline_logit_fn),
    ("GEMMA 9B WITH SAE (no masks)", baseline_sae_logit_fn),
    ("GEMMA 9B WITH SAE Masked", masked_logit_fn),
]

results_list = []
for method_name, method_logit_fn in methods:
    with torch.no_grad():
        results = sanity_check_model_performance(method_logit_fn)
        ce_loss = eval_ce_loss(simple_dataset[-1], simple_labels[-1], method_logit_fn)
        results['ce_loss'] = ce_loss
        results['method'] = method_name
        results_list.append(results)

# One row per method, with descriptive column names.
df = (
    pd.DataFrame(results_list)
    .set_index('method')
    .rename(columns={
        'prob_correct_in_correct': 'P(Age|Correct Ex)',
        'prob_error_in_correct': 'P(Traceback|Correct Ex)',
        'logit_diff_correct': 'Logit Diff (Correct Ex)',
        'prob_error_in_error': 'P(Traceback|Error Ex)',
        'prob_correct_in_error': 'P(Age|Error Ex)',
        'logit_diff_error': 'Logit Diff (Error Ex)',
        'ce_loss': 'Cross-Entropy Loss',
    })
)

df
 