In [1]:
import json
from sae_lens import SAE, HookedSAETransformer
from functools import partial
import einops
import os
import gc
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.optim as optim
from datasets import load_dataset
from transformers import AutoTokenizer
from transformer_lens.hook_points import (
    HookPoint,
) 
import numpy as np
import pandas as pd
from pprint import pprint as pp
from typing import Tuple
from torch import Tensor
from functools import lru_cache
from typing import TypedDict, Optional, Tuple, Union
from tqdm import tqdm
import random

In [2]:
with open("config.json", 'r') as file:
   config = json.load(file)
token = config.get('huggingface_token', None)
os.environ["HF_TOKEN"] = token

# Define device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

hf_cache = "/work/pi_jensen_umass_edu/jnainani_umass_edu/mechinterp/huggingface_cache/hub"
os.environ["HF_HOME"] = hf_cache

# Load the model
model = HookedSAETransformer.from_pretrained("google/gemma-2-9b", device=device, cache_dir=hf_cache) 

pad_token_id = model.tokenizer.pad_token_id
for param in model.parameters():
   param.requires_grad_(False)



Device: cuda


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2-9b into HookedTransformer


In [3]:
device = "cuda"
layers= [7, 14, 21, 40]
l0s = [92, 67, 129, 125]
saes = [SAE.from_pretrained(release="gemma-scope-9b-pt-res", sae_id=f"layer_{layers[i]}/width_16k/average_l0_{l0s[i]}", device=device)[0] for i in range(len(layers))]

In [4]:
def cleanup_cuda():
   torch.cuda.empty_cache()
   gc.collect()

def clear_memory():
   for sae in saes:
      for param in sae.parameters():
         param.grad = None
      for param in sae.mask.parameters():
         param.grad = None

   for param in model.parameters():
      param.grad = None
   cleanup_cuda()


In [5]:
class SAEMasks(nn.Module):
    def __init__(self, hook_points, masks):
        super().__init__()
        self.hook_points = hook_points  # list of strings
        self.masks = masks

    def forward(self, x, sae_hook_point, mean_ablation=None):
        index = self.hook_points.index(sae_hook_point)
        mask = self.masks[index]
        censored_activations = torch.ones_like(x)
        if mean_ablation is not None:
            censored_activations = censored_activations * mean_ablation
        else:
            censored_activations = censored_activations * 0
        
        diff_to_x = x - censored_activations
        return censored_activations + diff_to_x * mask

    def print_mask_statistics(self):
        """
        Prints statistics about each binary mask:
          - total number of elements (latents)
          - total number of 'on' latents (mask == 1)
          - average on-latents per token
            * If shape == [latent_dim], there's effectively 1 token
            * If shape == [seq, latent_dim], it's 'sum of on-latents / seq'
        """
        for i, mask in enumerate(self.masks):
            shape = list(mask.shape)
            total_latents = mask.numel()
            total_on = mask.sum().item()  # number of 1's in the mask

            # Average on-latents per token depends on dimensions
            if len(shape) == 1:
                # e.g., shape == [latent_dim]
                avg_on_per_token = total_on  # only one token
            elif len(shape) == 2:
                # e.g., shape == [seq, latent_dim]
                seq_len = shape[0]
                avg_on_per_token = total_on / seq_len if seq_len > 0 else 0
            else:
                # If there's more than 2 dims, adapt as needed;
                # we'll just define "token" as the first dimension.
                seq_len = shape[0]
                avg_on_per_token = total_on / seq_len if seq_len > 0 else 0

            print(f"Statistics for mask '{self.hook_points[i]}':")
            print(f"  - Shape: {shape}")
            print(f"  - Total latents: {total_latents}")
            print(f"  - Latents ON (mask=1): {int(total_on)}")
            print(f"  - Average ON per token: {avg_on_per_token:.4f}\n")

    def save(self, save_dir, file_name="sae_masks.pt"):
        """
        Saves hook_points and masks to a single file (file_name) within save_dir.
        If you want multiple mask sets in the same directory, call save() with
        different file_name values. The directory is created if it does not exist.
        """
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        save_path = os.path.join(save_dir, file_name)
        checkpoint = {
            "hook_points": self.hook_points,
            "masks": self.masks
        }
        torch.save(checkpoint, save_path)
        print(f"SAEMasks saved to {save_path}")

    @classmethod
    def load(cls, load_dir, file_name="sae_masks.pt"):
        """
        Loads hook_points and masks from a single file (file_name) within load_dir,
        returning an instance of SAEMasks. If you stored multiple mask sets in the
        directory, specify the file_name to load the correct one.
        """
        load_path = os.path.join(load_dir, file_name)
        if not os.path.isfile(load_path):
            raise FileNotFoundError(f"No saved SAEMasks found at {load_path}")

        checkpoint = torch.load(load_path)
        hook_points = checkpoint["hook_points"]
        masks = checkpoint["masks"]

        instance = cls(hook_points=hook_points, masks=masks)
        print(f"SAEMasks loaded from {load_path}")
        return instance
    def get_num_latents(self):
        num_latents = 0
        for mask in self.masks:
            num_latents += (mask>0).sum().item()
        return num_latents


In [6]:
class SparseMask(nn.Module):
    def __init__(self, shape, l1, seq_len=None, distinct_l1=0):
        super().__init__()
        if seq_len is not None:
            self.mask = nn.Parameter(torch.ones(seq_len, shape))
        else:
            self.mask = nn.Parameter(torch.ones(shape))
        self.l1 = l1
        self.distinct_l1 = distinct_l1
        self.max_temp = torch.tensor(1000.0)
        self.sparsity_loss = None
        self.ratio_trained = 1
        self.temperature = 1
        self.distinct_sparsity_loss = 0


    def forward(self, x, binary=False, mean_ablation=None):
        if binary:
            # binary mask, 0 if negative, 1 if positive
            binarized = (self.mask > 0).float()
            if mean_ablation is None:
                return x * binarized
            else:
                diff = x - mean_ablation
                return diff * binarized + mean_ablation
            

        self.temperature = self.max_temp ** self.ratio_trained
        mask = torch.sigmoid(self.mask * self.temperature)
        # mask = self.mask
        self.sparsity_loss = torch.abs(mask).sum() * self.l1
        # print("hello", torch.abs(mask).sum()) 
        # if len(mask.shape) == 2:
        #     self.distinct_sparsity_loss = torch.abs(mask).max(dim=0).values.sum() * self.distinct_l1

        if mean_ablation is None:
            return x * mask
        else:
            diff = x - mean_ablation
            return diff * mask + mean_ablation

# for sae in saes:
#     sae.mask = SparseMask(sae.cfg.d_sae, 1.0, seq_len=65)

In [7]:
class IGMask(nn.Module):
    # igscores is seq x num_sae_latents
    def __init__(self, ig_scores):
        super().__init__()
        self.ig_scores = ig_scores

    def forward(self, x, threshold, mean_ablation = None):
        censored_activations = torch.ones_like(x)
        if mean_ablation != None:
            censored_activations = censored_activations * mean_ablation
        else:
            censored_activations = censored_activations * 0

        mask = (self.ig_scores.abs() > threshold).float()
        
        diff_to_x = x - censored_activations
        return censored_activations + diff_to_x * mask
    
    def get_threshold_info(self, threshold):
        mask = (self.ig_scores.abs() > threshold).float()

        total_latents = mask.sum()
        avg_latents_per_tok = mask.sum()/mask.shape[0]
        latents_per_tok = mask.sum(dim=-1)
        return {"total_latents":total_latents,
                "avg_latents_per_tok":avg_latents_per_tok,
                "latents_per_tok":latents_per_tok}
    
    def get_binarized_mask(self, threshold):
        return (self.ig_scores.abs()>threshold).float()
    
def refresh_class():
    for sae in saes:
        if hasattr(sae, 'igmask'):
            sae.igmask = IGMask(sae.igmask.ig_scores)

try:
    refresh_class()
except Exception as e:
    print(e)

refresh_class()



def produce_ig_binary_masks(threshold=0.01):
    hook_points = []
    masks = []

    for sae in saes:
        hook_point = sae.cfg.hook_name
        mask = sae.igmask.get_binarized_mask(threshold=threshold)
        hook_points.append(hook_point)
        masks.append(mask)
    
    return SAEMasks(
        hook_points=hook_points,
        masks=masks
    )

In [8]:
bos_token_id = model.tokenizer.bos_token_id

def build_sae_hook_fn(
    # Core components
    sae,
    sequence,
    
    # Masking options
    circuit_mask: Optional[SAEMasks] = None,
    use_mask=False,
    binarize_mask=False,
    mean_mask=False,
    ig_mask_threshold=None,
    
    # Caching behavior
    cache_sae_grads=False,
    cache_masked_activations=False,
    cache_sae_activations=False,
    
    # Ablation options
    mean_ablate=False,  # Controls mean ablation of the SAE
    fake_activations=False,  # Controls whether to use fake activations
    ):    # make the mask for the sequence
    mask = torch.ones_like(sequence, dtype=torch.bool)
    # mask[sequence == pad_token_id] = False
    mask[sequence == bos_token_id] = False # where mask is false, keep original
    def sae_hook(value, hook):
        # print(f"sae {sae.cfg.hook_name} running at layer {hook.layer()}")
        feature_acts = sae.encode(value)
        feature_acts = feature_acts * mask.unsqueeze(-1)
        if fake_activations != False and sae.cfg.hook_layer == fake_activations[0]:
            feature_acts = fake_activations[1]
        if cache_sae_grads:
            raise NotImplementedError("torch is confusing")
            sae.feature_acts = feature_acts.requires_grad_(True)
            sae.feature_acts.retain_grad()
        
        if cache_sae_activations:
            sae.feature_acts = feature_acts.detach().clone()
        
        
        
        # Learned Binary Masking
        if use_mask:
            if mean_mask:
                # apply the mask, with mean ablations
                feature_acts = sae.mask(feature_acts, binary=binarize_mask, mean_ablation=sae.mean_ablation)
            else:
                # apply the mask, without mean ablations
                feature_acts = sae.mask(feature_acts, binary=binarize_mask)

        # IG Masking
        if ig_mask_threshold != None:
            # apply the ig mask
            if mean_mask:
                feature_acts = sae.igmask(feature_acts, threshold=ig_mask_threshold, mean_ablation=sae.mean_ablation)
            else:
                feature_acts = sae.igmask(feature_acts, threshold=ig_mask_threshold)

                
        if circuit_mask is not None:
            hook_point = sae.cfg.hook_name
            if mean_mask==True:
                feature_acts = circuit_mask(feature_acts, hook_point, mean_ablation=sae.mean_ablation)
            else:
                feature_acts = circuit_mask(feature_acts, hook_point)
            
        if cache_masked_activations:
            sae.feature_acts = feature_acts.detach().clone()
        if mean_ablate:
            feature_acts = sae.mean_ablation

        out = sae.decode(feature_acts)
        # choose out or value based on the mask
        mask_expanded = mask.unsqueeze(-1).expand_as(value)
        value = torch.where(mask_expanded, out, value)
        return value
    return sae_hook


    # def sae_hook_ablate(value, hook):
    # feature_acts = sae.encode(value)
    # # feature_acts[:, :, topsae_attr_indices] = 0
    # out = sae.decode(feature_acts)
    # return out


def build_hooks_list(sequence,
                    cache_sae_activations=False,
                    cache_sae_grads=False,
                    circuit_mask=None,
                    use_mask=False,
                    binarize_mask=False,
                    mean_mask=False,
                    cache_masked_activations=False,
                    mean_ablate=False,
                    fake_activations: Tuple[int, torch.Tensor] = False,
                    ig_mask_threshold=None,
                    ):
    hooks = []
    # blocks.0.hook_resid_pre
    # # fake hook that adds zero so gradients propagate through the model
    # param = nn.Parameter(torch.tensor(0.0, requires_grad=True))
    # hooks.append(
    #     (
    #         "blocks.0.hook_resid_pre",
    #         lambda value, hook: value + param,
    #     )
    # )
    for sae in saes:
        hooks.append(
            (
            sae.cfg.hook_name,
            build_sae_hook_fn(sae, sequence, cache_sae_grads=cache_sae_grads, circuit_mask=circuit_mask, use_mask=use_mask, binarize_mask=binarize_mask, cache_masked_activations=cache_masked_activations, cache_sae_activations=cache_sae_activations, mean_mask=mean_mask, mean_ablate=mean_ablate, fake_activations=fake_activations, ig_mask_threshold=ig_mask_threshold),
            )
        )
    return hooks 

def build_sae_logitfn(**kwargs):
    def logitfn(tokens):
        return model.run_with_hooks(
            tokens, 
            return_type="logits", 
            fwd_hooks=build_hooks_list(tokens, **kwargs)
            )
    return logitfn

# data

In [9]:
import json
file_path = 'data/ioi/ioi_train.json'
with open(file_path, 'r') as file:
    data = [json.loads(line) for line in file]
for entry in data:
    print(entry)
    break
example_length = 7

{'clean_prefix': 'Then, Richard and Erica went to the hospital. Richard gave a drink to', 'patch_prefix': 'Then, Matthew and Crystal went to the hospital. Megan gave a drink to', 'clean_answer': ' Erica', 'patch_answer': ' Richard', 'case': 'BABA_16'}


In [9]:
file_path = 'data/ioi/ioi_train_21.json'
with open(file_path, 'r') as file:
    data = json.load(file)

In [10]:
from transformer_lens.utils import test_prompt
test_prompt("Then, Matthew and Crystal went to the hospital. Megan gave a drink to", " Erica", model)

Tokenized prompt: ['<bos>', 'Then', ',', ' Matthew', ' and', ' Crystal', ' went', ' to', ' the', ' hospital', '.', ' Megan', ' gave', ' a', ' drink', ' to']
Tokenized answer: [' Erica']


Top 0th token. Logit: 28.58 Prob: 43.22% Token: | Matthew|
Top 1th token. Logit: 27.45 Prob: 13.99% Token: | Crystal|
Top 2th token. Logit: 27.27 Prob: 11.63% Token: | her|
Top 3th token. Logit: 27.19 Prob: 10.73% Token: | the|
Top 4th token. Logit: 26.28 Prob:  4.33% Token: | them|
Top 5th token. Logit: 25.45 Prob:  1.88% Token: | both|
Top 6th token. Logit: 25.22 Prob:  1.50% Token: | him|
Top 7th token. Logit: 25.16 Prob:  1.41% Token: | a|
Top 8th token. Logit: 25.14 Prob:  1.39% Token: | each|
Top 9th token. Logit: 24.26 Prob:  0.57% Token: | Matt|


In [15]:
example_length = 20

In [16]:
clean_data = []
corr_data = []
clean_labels = []
corr_labels = []
for entry in data:
    if model.to_tokens(entry['clean_prefix']).shape[-1] == example_length:
        clean_data.append(entry['clean_prefix'])
        corr_data.append(entry['patch_prefix'])
        clean_labels.append(entry['clean_answer'])
        corr_labels.append(entry['patch_answer'])
print(len(clean_data))

2966


In [17]:
clean_data = []
corr_data = []
clean_labels = []
corr_labels = []
for entry in data:
    if model.to_tokens(entry['clean_prefix']).shape[-1] == example_length:
        clean_data.append(entry['clean_prefix'])
        corr_data.append(entry['patch_prefix'])
        clean_labels.append(entry['clean_answer'])
        corr_labels.append(entry['patch_answer'])

N = 2900
clean_tokens = model.to_tokens(clean_data[:N])
corr_tokens = model.to_tokens(corr_data[:N])
clean_label_tokens = model.to_tokens(clean_labels[:N], prepend_bos=False).squeeze(-1)
corr_label_tokens = model.to_tokens(corr_labels[:N], prepend_bos=False).squeeze(-1)
print(clean_tokens.shape, corr_tokens.shape)

def logit_diff_fn(logits, clean_labels, corr_labels, token_wise=False):
    clean_logits = logits[torch.arange(logits.shape[0]), -1, clean_labels]
    corr_logits = logits[torch.arange(logits.shape[0]), -1, corr_labels]
    return (clean_logits - corr_logits).mean() if not token_wise else (clean_logits - corr_logits)

batch_size = 16 
clean_tokens = clean_tokens[:batch_size*(len(clean_tokens)//batch_size)]
corr_tokens = corr_tokens[:batch_size*(len(corr_tokens)//batch_size)]
clean_label_tokens = clean_label_tokens[:batch_size*(len(clean_label_tokens)//batch_size)]
corr_label_tokens = corr_label_tokens[:batch_size*(len(corr_label_tokens)//batch_size)]

clean_tokens = clean_tokens.reshape(-1, batch_size, clean_tokens.shape[-1])
corr_tokens = corr_tokens.reshape(-1, batch_size, corr_tokens.shape[-1])
clean_label_tokens = clean_label_tokens.reshape(-1, batch_size)
corr_label_tokens = corr_label_tokens.reshape(-1, batch_size)

print(clean_tokens.shape, corr_tokens.shape, clean_label_tokens.shape, corr_label_tokens.shape)

torch.Size([2900, 20]) torch.Size([2900, 20])
torch.Size([181, 16, 20]) torch.Size([181, 16, 20]) torch.Size([181, 16]) torch.Size([181, 16])


In [18]:
use_mask = False 
mean_mask = False
avg_logit_diff = 0
cleanup_cuda()
with torch.no_grad():
    for i in range(10):
        logits = model.run_with_hooks(
            clean_tokens[i], 
            return_type="logits", 
            fwd_hooks=build_hooks_list(clean_tokens[i], use_mask=use_mask, mean_mask=mean_mask)
            )
        ld = logit_diff_fn(logits, clean_label_tokens[i], corr_label_tokens[i])
        print(ld)
        avg_logit_diff += ld
        del logits
        cleanup_cuda()
avg_logit_diff = (avg_logit_diff / 10).item()
print("Average LD: ", avg_logit_diff)

tensor(4.3811, device='cuda:0')
tensor(4.2698, device='cuda:0')
tensor(4.0039, device='cuda:0')
tensor(3.3653, device='cuda:0')
tensor(3.4589, device='cuda:0')
tensor(3.2856, device='cuda:0')
tensor(3.1497, device='cuda:0')
tensor(3.6056, device='cuda:0')
tensor(3.6865, device='cuda:0')
tensor(3.4956, device='cuda:0')
Average LD:  3.6701977252960205


In [19]:
for sae in saes:
    sae.mask = SparseMask(sae.cfg.d_sae, 1.0, seq_len=example_length).to(device)

In [20]:
def running_mean_tensor(old_mean, new_value, n):
    return old_mean + (new_value - old_mean) / n

def get_sae_means(mean_tokens, total_batches, batch_size, per_token_mask=False):
    for sae in saes:
        sae.mean_ablation = torch.zeros(sae.cfg.d_sae).float().to(device)
    
    with tqdm(total=total_batches*batch_size, desc="Mean Accum Progress") as pbar:
        for i in range(total_batches):
            for j in range(batch_size):
                with torch.no_grad():
                    _ = model.run_with_hooks(
                        mean_tokens[i, j], 
                        return_type="logits", 
                        fwd_hooks=build_hooks_list(mean_tokens[i, j], cache_sae_activations=True)
                        )
                    for sae in saes:
                        sae.mean_ablation = running_mean_tensor(sae.mean_ablation, sae.feature_acts, i+1)
                    cleanup_cuda()
                pbar.update(1)

            if i >= total_batches:
                break

get_sae_means(corr_tokens, 25, 16)

Mean Accum Progress: 100%|██████████| 400/400 [01:29<00:00,  4.46it/s]


In [17]:
use_mask

False

In [19]:
saes[0].mean_ablation.shape

torch.Size([1, 16, 16384])

In [21]:
avg_mean_diff = 0
with torch.no_grad():
    for i in range(5):
        logits = model.run_with_hooks(
            clean_tokens[i], 
            return_type="logits", 
            fwd_hooks=build_hooks_list(clean_tokens[i], mean_ablate=True)
            )
        ld = logit_diff_fn(logits, clean_label_tokens[i], corr_label_tokens[i])
        print(ld)
        avg_mean_diff += ld
        del logits
        cleanup_cuda()
print("Average LD: ", avg_mean_diff)

tensor(0.1156, device='cuda:0')
tensor(-0.1215, device='cuda:0')
tensor(-0.1009, device='cuda:0')
tensor(-0.1535, device='cuda:0')
tensor(-0.0856, device='cuda:0')
Average LD:  tensor(-0.3460, device='cuda:0')


# mask training 

In [22]:
import torch.nn.functional as F
import wandb

In [23]:
import signal
class KeyboardInterruptBlocker:
    def __enter__(self):
        # Block SIGINT and store old mask
        self.old_mask = signal.pthread_sigmask(signal.SIG_BLOCK, {signal.SIGINT})

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore old mask (unblock SIGINT)
        signal.pthread_sigmask(signal.SIG_SETMASK, self.old_mask)

class Range:
    def __init__(self, *args):
        # Support for range(start, stop, step) or range(stop)
        self.args = args

        # Validate input like the built-in range does
        if len(self.args) not in {1, 2, 3}:
            raise TypeError(f"Range expected at most 3 arguments, got {len(self.args)}")
        
        self.range = __builtins__.range(*self.args)  # Create the range object

    def __iter__(self):
        for i in self.range:
            try:
                with KeyboardInterruptBlocker():
                    yield i
            except KeyboardInterrupt:
                print("Keyboard interrupt received. Exiting iteration.")
                break

    def __len__(self):
        return len(self.range)

In [None]:
def do_training_run(token_dataset, labels_dataset, corr_labels_dataset, sparsity_multiplier, task, example_length=6, loss_function='ce', per_token_mask=False, use_mask=True, mean_mask=False, distinct_sparsity_multiplier=0 ):

    def logitfn(tokens):
        logits =  model.run_with_hooks(
            tokens, 
            return_type="logits", 
            fwd_hooks=build_hooks_list(tokens, use_mask=use_mask, mean_mask=mean_mask)
            )
        return logits

    def forward_pass(batch, clean_label_tokens, corr_label_tokens, logitfn, ratio_trained=1, loss_function='ce'):
        for sae in saes:
            sae.mask.ratio_trained = ratio_trained
        tokens = batch
        logits = logitfn(tokens)
        last_token_logits = logits[:, -1, :]
        if loss_function == 'ce':
            loss = F.cross_entropy(last_token_logits, clean_label_tokens)
        elif loss_function == 'logit_diff':
            fwd_logit_diff = logit_diff_fn(logits, clean_label_tokens, corr_label_tokens)
            loss = torch.abs(avg_logit_diff - fwd_logit_diff)

        sparsity_loss = 0
        distinct_sparsity_loss = 0
        # if per_token_mask:
        #     distinct_sparsity_loss = 0
        for sae in saes:
            sparsity_loss = sparsity_loss + sae.mask.sparsity_loss
            # if per_token_mask:
            #     distinct_sparsity_loss = distinct_sparsity_loss + sae.mask.distinct_sparsity_loss
        
        sparsity_loss = sparsity_loss / len(saes)
        distinct_sparsity_loss = distinct_sparsity_loss / len(saes)

        return loss, sparsity_loss, distinct_sparsity_loss

    print("doing a run with sparsity multiplier", sparsity_multiplier)
    all_optimized_params = []
    config = {
        "batch_size": 16,
        "learning_rate": 1,
        "total_steps": token_dataset.shape[0]*0.8,
        "sparsity_multiplier": sparsity_multiplier
    }

    for sae in saes:
        if per_token_mask:
            sae.mask = SparseMask(sae.cfg.d_sae, 1.0, seq_len=example_length).to(device)
        else:
            sae.mask = SparseMask(sae.cfg.d_sae, 1.0).to(device)
        all_optimized_params.extend(list(sae.mask.parameters()))
        sae.mask.max_temp = torch.tensor(200.0)
    
    wandb.init(project="sae circuits", config=config)
    optimizer = optim.Adam(all_optimized_params, lr=config["learning_rate"])
    total_steps = config["total_steps"] #*20 #*config["batch_size"]
    # epochs = 20
    with tqdm(total=total_steps, desc="Training Progress") as pbar:
        # for epoch in range(epochs):
        for i, (x, y, z) in enumerate(zip(token_dataset, labels_dataset, corr_labels_dataset)):
            with KeyboardInterruptBlocker():
                optimizer.zero_grad()
                
                # Calculate ratio trained
                ratio_trained =  i / total_steps*1.1
                
                # Update mask ratio for each SAE
                for sae in saes:
                    sae.mask.ratio_trained = ratio_trained
                
                # Forward pass with updated ratio_trained
                loss, sparsity_loss, _ = forward_pass(x, y, z, logitfn, ratio_trained=ratio_trained, loss_function=loss_function)
                # if per_token_mask:
                #     sparsity_loss = sparsity_loss / example_length
                # print("sp loss", sparsity_loss)
                # print("sae 0 sp loss", saes[0].mask.sparsity_loss)
                # print("sae 1 sp loss", saes[1].mask.sparsity_loss)
                # print("sae 2 sp loss", saes[2].mask.sparsity_loss)
                # print("sae 3 sp loss", saes[3].mask.sparsity_loss)
                
                avg_nonzero_elements = sparsity_loss
                # avg_distinct_nonzero_elements = distinct_sparsity_loss
                    
                sparsity_loss = sparsity_loss * config["sparsity_multiplier"] #+ distinct_sparsity_loss * distinct_sparsity_multiplier
                total_loss = sparsity_loss + loss
                infodict  = {"Step": i, "Progress": ratio_trained, "Avg Nonzero Elements": avg_nonzero_elements.item(), "Task Loss": loss.item(), "Sparsity Loss": sparsity_loss.item(), "temperature": saes[0].mask.temperature} # "avg distinct lat/sae":avg_distinct_nonzero_elements.item()
                wandb.log(infodict)
                # break
                # Backward pass and optimizer step
                total_loss.backward()
                optimizer.step()
                
                # Update tqdm bar with relevant metrics
                pbar.set_postfix(infodict)
                
                # Update the tqdm progress bar
                pbar.update(1)
                # break
                if i >= total_steps*1.1:
                    break
    wandb.finish()

    optimizer.zero_grad()

    for sae in saes:
        for param in sae.parameters():
            param.grad = None
        for param in sae.mask.parameters():
            param.grad = None
    
    for param in model.parameters():
        param.grad = None

    torch.cuda.empty_cache()

    mask_dict = {}

    total_density = 0
    for sae in saes:
        mask_dict[sae.cfg.hook_name] = torch.where(sae.mask.mask > 0)[1].tolist()   # rob thinks .view(-1) needed here
        total_density += (sae.mask.mask > 0).sum().item()
    mask_dict["total_density"] = total_density
    mask_dict['avg_density'] = total_density / len(saes)

    if per_token_mask:
        print("total # latents in circuit: ", total_density)
    print("avg density", mask_dict['avg_density'])

    ### EVAL ###
    def masked_logit_fn(tokens):
        logits =  model.run_with_hooks(
            tokens, 
            return_type="logits", 
            fwd_hooks=build_hooks_list(tokens, use_mask=use_mask, mean_mask=mean_mask, binarize_mask=True)
            )
        return logits

    def eval_ce_loss(batch, labels, logitfn, ratio_trained=10):
        for sae in saes:
            sae.mask.ratio_trained = ratio_trained
        tokens = batch
        logits = logitfn(tokens)
        last_token_logits = logits[:, -1, :]
        loss = F.cross_entropy(last_token_logits, labels)
        return loss
    
    def eval_logit_diff(num_batches, batch, clean_labels, corr_labels, logitfn, ratio_trained=10):
        for sae in saes:
            sae.mask.ratio_trained = ratio_trained
        avg_ld = 0
        for i in range(num_batches):
            tokens = batch[-i]
            logits = logitfn(tokens)
            ld = logit_diff_fn(logits, clean_labels[-i], corr_labels[-i])
            avg_ld += ld
            del logits
            cleanup_cuda()
        return (avg_ld / num_batches).item()

    with torch.no_grad():
        loss = eval_ce_loss(token_dataset[-1], labels_dataset[-1], masked_logit_fn)
        print("CE loss:", loss)
        cleanup_cuda()
        logit_diff = eval_logit_diff(10, token_dataset, labels_dataset, corr_labels_dataset, masked_logit_fn)
        print("Logit Diff:", logit_diff)
        cleanup_cuda()

    save_path = f"masks/{task}/{loss_function}_{str(sparsity_multiplier)}_run/"
    os.makedirs(save_path, exist_ok=True)
    mask_dict['ce_loss'] = loss.item()
    mask_dict['logit_diff'] = logit_diff
    faithfulness = logit_diff / avg_logit_diff
    mask_dict['faithfulness'] = faithfulness
    
    for idx, sae in enumerate(saes):
        mask_path = f"sae_mask_{idx}.pt"
        torch.save(sae.mask.state_dict(), os.path.join(save_path,mask_path))
        print(f"Saved mask for SAE {idx} to {mask_path}")

    json.dump(mask_dict, open(os.path.join(save_path,f"{str(sparsity_multiplier)}_run.json"), "w"))

In [27]:
do_training_run(token_dataset =clean_tokens, labels_dataset= clean_label_tokens, corr_labels_dataset=corr_label_tokens, sparsity_multiplier=100, task='ioi/baba21/', example_length=20, loss_function="logit_diff", per_token_mask=False, use_mask=True, mean_mask=True)

doing a run with sparsity multiplier 100


VBox(children=(Label(value='0.012 MB of 0.012 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Avg Nonzero Elements,▁▁▂▂▃▄▄▅▅▆▆▇▇▇██████████████████████████
Progress,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
Sparsity Loss,▁▁▂▂▃▄▄▅▅▆▆▇▇▇██████████████████████████
Step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
Task Loss,█▇▇█▅▃▆▄▃▄▃▃▂▃▁▂▃▂▂▂▂▄▂▃▁▃▁▃▁▂▂▁▂▁▁▁▁▃▂▂
temperature,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▅▅▅▆▆▇█

0,1
Avg Nonzero Elements,16384.0
Progress,0.75967
Sparsity Loss,1638400.0
Step,100.0
Task Loss,0.67491
temperature,55.97834


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01111281055620768, max=1.0)…

  full_bar = Bar(frac,
Training Progress: 161it [02:06,  1.28it/s, Step=160, Progress=1.22, Avg Nonzero Elements=16384.0, Task Loss=0.517, Sparsity Loss=1.64e+6, temperature=tensor(626.3716)]                           


VBox(children=(Label(value='0.021 MB of 0.021 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Avg Nonzero Elements,▁▂▃▄▅▆▆▇▇███████████████████████████████
Progress,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
Sparsity Loss,▁▂▃▄▅▆▆▇▇███████████████████████████████
Step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
Task Loss,█▇█▄▄▃▃▂▃▁▃▂▁▃▁▁▂▃▂▂▄▂▁▃▂▂▂▂▂▂▁▃▂▁▆▂▂▂▂▁
temperature,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▅▆▆█

0,1
Avg Nonzero Elements,16384.0
Progress,1.21547
Sparsity Loss,1638400.0
Step,160.0
Task Loss,0.5165
temperature,626.37158


IndexError: tuple index out of range

In [45]:
torch.abs(saes[0].mask.mask).sum() * saes[0].mask.distinct_l1

tensor(262144., device='cuda:0', grad_fn=<MulBackward0>)

In [65]:
torch.abs(saes[0].mask.mask).sum()

tensor(16384., device='cuda:0', grad_fn=<SumBackward0>)

In [66]:
saes[0].mask.temperature

tensor(1.)

In [70]:
torch.sigmoid(saes[0].mask.mask * saes[0].mask.temperature)

tensor([0.7311, 0.7311, 0.7311,  ..., 0.7311, 0.7311, 0.7311], device='cuda:0',
       grad_fn=<SigmoidBackward0>)

In [41]:
saes[0].mask.sparsity_loss

tensor(191642.6094, device='cuda:0', grad_fn=<MulBackward0>)

In [40]:
torch.sum(saes[0].mask.mask == 1)

tensor(262144, device='cuda:0')

In [43]:
torch.sum(saes[0].mask.mask>0)

tensor(262144, device='cuda:0')

In [44]:
16056*16

256000

In [31]:
0.8*1.3

1.04

In [91]:
from data.ioi import ioi_dataset

ioi_data = ioi_dataset.IOIDataset(prompt_type="BABA",nb_templates=2, N = 6000)
abc_data = (
    ioi_data.gen_flipped_prompts(("IO", "RAND"))
    .gen_flipped_prompts(("S", "RAND"))
    .gen_flipped_prompts(("S1", "RAND"))
    )



In [92]:
lower = 0
i = 0
for pr in ioi_data.ioi_prompts:
    ll = len(model.tokenizer.encode(pr['text']))
    # print(ll)
    # print(pr['text'])
    if ll == 21:
        lower += 1
    # i += 1
    # if i > 10:
    #     break
print(lower)

2966


In [94]:
import re
def truncate_to_end_with_to(text):
    return re.sub(r'\s+\S+$', '', text)

In [95]:
full_ioi_dataset = []
for ind, pr in enumerate(ioi_data.ioi_prompts):
    pr_abc = abc_data.ioi_prompts[ind]
    ll = len(model.tokenizer.encode(pr['text']))
    labv = len(model.tokenizer.encode(pr_abc['text']))
    if ll == 21 and ll == labv:
        ioi_dict = {"clean_prefix": truncate_to_end_with_to(pr['text']), "patch_prefix": truncate_to_end_with_to(pr_abc['text']), "clean_answer": ' '+pr['IO'], "patch_answer": ' '+pr['S'], "case": "BABA_21"}
        full_ioi_dataset.append(ioi_dict)

len(full_ioi_dataset)

2966

In [96]:
# save the ioi dataset
with open('data/ioi_train_21.json', 'w') as outfile:
    json.dump(full_ioi_dataset, outfile)

In [98]:
import json
file_path = 'data/ioi/ioi_train_21.json'
with open(file_path, 'r') as file:
    data = json.load(file)