# Load LLM, SAEs

sae layers: [7, 14, 21, 28, 40]
llm gemma 9b

In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer
import torch
import circuitsvis as cv
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix
import tqdm
import torch.nn as nn
import torch.optim as optim
import math
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import functools

# device= 'cuda' if torch.cuda.is_available() else 'cpu'
# torch.set_default_device(device)
# assert device == 'cuda', "This notebook is not optimized for CPU"

import transformer_lens

# Load a model
model = transformer_lens.HookedTransformer.from_pretrained("gemma-2-9b", device="cuda")
pad_token_id = model.tokenizer.pad_token_id
print('pad token id is', pad_token_id)

for param in model.parameters():
   param.requires_grad_(False)

device = "cuda"





Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]



Loaded pretrained model gemma-2-9b into HookedTransformer
pad token id is 0


In [2]:
device= 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_default_device(device)
assert device == 'cuda', "This notebook is not optimized for CPU"


In [3]:
from sae_lens import SAE
def get_gemma_9b_sae_id(layer):
    return f"layer_{layer}/width_16k/canonical"
# sae_id = get_gemma_9b_sae_id(10)
# sae, cfg_dict, sparsity = SAE.from_pretrained(release="gemma-scope-9b-pt-res-canonical", sae_id="layer_10/width_16k/canonical", device=device)
# saes_to_load = [7, 14, 21, 40]
saes_to_load = [14, 21]


saes = []
for sae_layer in saes_to_load:
    sae, cfg_dict, sparsity = SAE.from_pretrained(release="gemma-scope-9b-pt-res-canonical", sae_id=get_gemma_9b_sae_id(sae_layer), device=device)
    saes.append(sae)

In [4]:
for sae in saes:
    for param in sae.parameters():
        param.requires_grad_(False)

# Load Dataset

In [5]:
# Updated version to return JSON with more names and structure for correct and incorrect keying examples

import json
import random

# Expanding the name pool with a larger set of names
extended_name_pool = [
    "Bob", "Sam", "Lilly", "Rob", "Alice", "Charlie", "Sally", "Tom", "Jake", "Emily", 
    "Megan", "Chris", "Sophia", "James", "Oliver", "Isabella", "Mia", "Jackson", 
    "Emma", "Ava", "Lucas", "Benjamin", "Ethan", "Grace", "Olivia", "Liam", "Noah"
]

for name in extended_name_pool:
    assert len(model.tokenizer.encode(name)) == 2, f"Name {name} has more than 1 token"

# Function to generate the dataset with correct and incorrect keying into dictionaries
def generate_extended_dataset(name_pool, num_samples=5):
    dataset = []
    
    for _ in range(num_samples):
        # Randomly select 5 names from the pool
        selected_names = random.sample(name_pool, 5)
        # Assign random ages to the selected names
        age_dict = {name: random.randint(10, 19) for name in selected_names}
        
        # Create a correct example
        correct_name = random.choice(list(age_dict.keys()))
        correct_prompt = f'Type "help", "copyright", "credits" or "license" for more information.\n>>> age = {age_dict}\n>>> age["{correct_name}"]\n'
        correct_response = age_dict[correct_name]
        correct_token = str(correct_response)[0]
        
        # Create an incorrect example with a name not in the dictionary
        incorrect_name = random.choice([name for name in name_pool if name not in age_dict])
        incorrect_prompt = f'Type "help", "copyright", "credits" or "license" for more information.\n>>> age = {age_dict}\n>>> age["{incorrect_name}"]\n'
        incorrect_response = "Traceback"
        incorrect_token = "Traceback"
        
        # Append the pair of correct and incorrect examples
        dataset.append({
            "correct": {
                "prompt": correct_prompt,
                "response": correct_response,
                "token": correct_token
            },
            "error": {
                "prompt": incorrect_prompt,
                "response": incorrect_response,
                "token": incorrect_token
            }
        })
    
    
    return dataset

# Generate the extended dataset
json_dataset = generate_extended_dataset(extended_name_pool, num_samples=100_000)

# Output the JSON structure


In [6]:
json_dataset[0]

{'correct': {'prompt': 'Type "help", "copyright", "credits" or "license" for more information.\n>>> age = {\'Grace\': 19, \'Rob\': 10, \'Liam\': 11, \'Alice\': 16, \'Benjamin\': 14}\n>>> age["Liam"]\n',
  'response': 11,
  'token': '1'},
 'error': {'prompt': 'Type "help", "copyright", "credits" or "license" for more information.\n>>> age = {\'Grace\': 19, \'Rob\': 10, \'Liam\': 11, \'Alice\': 16, \'Benjamin\': 14}\n>>> age["Oliver"]\n',
  'response': 'Traceback',
  'token': 'Traceback'}}

In [7]:
tokenized = model.to_tokens(["I like pie", "I like cake and pie", "birds", 'cats'])
torch.count_nonzero(tokenized != pad_token_id, dim=-1)-1
tokenized


tensor([[     2, 235285,   1154,   4506,      0,      0],
        [     2, 235285,   1154,  11514,    578,   4506],
        [     2,  39954,      0,      0,      0,      0],
        [     2,  34371,      0,      0,      0,      0]], device='cuda:0')

In [8]:
tensa = torch.tensor([1,2,3,])
tesb = torch.tensor([1,2,3,4])
torch.cat([tensa, tesb])

tensor([1, 2, 3, 1, 2, 3, 4], device='cuda:0')

In [9]:
class ContrastiveDatasetBatch:
    def __init__(self, dataset_items):
        self.correct_batch = [item["correct"] for item in dataset_items]
        self.error_batch = [item["error"] for item in dataset_items]
        self.batch_size = len(self.correct_batch)
        self.correct_token_idx = model.to_single_token("1")
        self.error_token_idx = model.to_single_token("Traceback")

        correct_tokenized = None
        error_tokenized = None

        correct_prompts = [example["prompt"] for example in self.correct_batch]
        error_prompts = [example["prompt"] for example in self.error_batch]
        assert len(correct_prompts) == len(error_prompts)
        all_prompts = correct_prompts + error_prompts
        all_tokenized = model.to_tokens(all_prompts)
        last_non_pad_idxs = torch.count_nonzero(all_tokenized != pad_token_id, dim=-1) - 1


        correct_tokenized = all_tokenized[:self.batch_size]
        correct_answer_idxs = last_non_pad_idxs[:self.batch_size]
        error_tokenized = all_tokenized[self.batch_size:]
        error_answer_idxs = last_non_pad_idxs[self.batch_size:]

        self.correct_tokenized = correct_tokenized
        self.error_tokenized = error_tokenized
        self.correct_answer_seq_idxs = correct_answer_idxs
        self.error_answer_seq_idxs = error_answer_idxs

        self.all_tokenized = all_tokenized
        self.all_answer_seq_idxs = last_non_pad_idxs
        self.all_prompts = all_prompts
        # the tokens are: [correct prompt, correct prompt, ..., error prompt, error prompt, ...]
        self.all_answers_tok_idxs = torch.cat(
            [
                torch.ones(self.batch_size, dtype=torch.int64)*self.correct_token_idx,
                torch.ones(self.batch_size, dtype=torch.int64)*self.error_token_idx
            ]
            )
        self.all_wrong_answers_tok_idxs = torch.cat(
            [
                torch.ones(self.batch_size, dtype=torch.int64)*self.error_token_idx,
                torch.ones(self.batch_size, dtype=torch.int64)*self.correct_token_idx
            ]
            )
    
    def get_logit_diffs(self, clean_logits, error_logits):
        # for the clean pass, get the logit for the "1" token and the "Traceback" token
        correct_code_right_logits = clean_logits[torch.arange(self.batch_size), self.correct_answer_seq_idxs, self.correct_token_idx]
        correct_code_wrong_logits = clean_logits[torch.arange(self.batch_size), self.correct_answer_seq_idxs, self.error_token_idx]
        correct_logit_diffs = correct_code_right_logits - correct_code_wrong_logits

        # for the error pass, get the logit for the Traceback token and the "1" token
        error_code_right_logits = error_logits[torch.arange(self.batch_size), self.error_answer_seq_idxs, self.error_token_idx]
        error_code_wrong_logits = error_logits[torch.arange(self.batch_size), self.error_answer_seq_idxs, self.correct_token_idx]
        error_logit_diffs = error_code_right_logits - error_code_wrong_logits
        return {"correct_code_diff": correct_logit_diffs, "error_code_diff": error_logit_diffs}
        


eval_dataset = ContrastiveDatasetBatch(json_dataset[:10])

In [10]:
simple_dataset = []
simple_labels = []

answer_token = model.to_single_token("1")
traceback_token = model.to_single_token("Traceback")

for item in json_dataset:
    simple_dataset.append(item["correct"]["prompt"])
    simple_dataset.append(item["error"]["prompt"])
    simple_labels.append(answer_token)
    simple_labels.append(traceback_token)


simple_dataset = model.to_tokens(simple_dataset)
simple_labels = torch.tensor(simple_labels)



In [11]:
permutation = torch.randperm(len(simple_dataset))
simple_dataset = simple_dataset[permutation]
simple_labels = simple_labels[permutation]

In [12]:
print(f"token:'{model.to_str_tokens(eval_dataset.all_tokenized[15])[eval_dataset.all_answer_seq_idxs[15]]}'")
print("Answer:")
model.to_string([eval_dataset.all_answers_tok_idxs[15]])

token:'
'
Answer:


'Traceback'

# Helper Fns

In [58]:
class SparseMask(nn.Module):
    def __init__(self, shape, l1):
        super().__init__()
        self.mask = nn.Parameter(torch.ones(shape))
        self.l1 = l1
        self.max_temp = torch.tensor(1000.0)
        self.sparsity_loss = None
        self.ratio_trained = 1


    def forward(self, x):
        temperature = self.max_temp ** self.ratio_trained
        mask = torch.sigmoid(self.mask * temperature)
        self.sparsity_loss = torch.abs(mask).sum() * self.l1
        return x * mask


saes[0].mask = SparseMask(saes[0].cfg.d_sae, 1.0)
saes[1].mask = SparseMask(saes[1].cfg.d_sae, 1.0)

In [59]:
bos_token_id = model.tokenizer.bos_token_id



def build_sae_hook_fn(sae, sequence, cache_grads=False, circuit_mask=None, use_mask=False):
    # make the mask for the sequence
    mask = torch.ones_like(sequence, dtype=torch.bool)
    mask[sequence == pad_token_id] = False
    mask[sequence == bos_token_id] = False # where mask is false, keep original
    def sae_hook(value, hook):
        # print(f"sae {sae.cfg.hook_name} running at layer {hook.layer()}")
        feature_acts = sae.encode(value)
        if cache_grads:
            sae.feature_acts = feature_acts
            sae.feature_acts.retain_grad()
        
        if use_mask:
            feature_acts = sae.mask(feature_acts) 

        if circuit_mask is not None:
            mask_method = circuit_mask['mask_method']
            mask_indices = circuit_mask[sae.cfg.hook_name]
            if mask_method == 'keep_only':
                # any activations not in the mask are set to 0
                expanded_circuit_mask = torch.zeros_like(feature_acts)
                expanded_circuit_mask[:, :, mask_indices] = 1
                feature_acts = feature_acts * expanded_circuit_mask
            elif mask_method == 'zero_only':
                feature_acts[:, :, mask_indices] = 0
            else:
                raise ValueError(f"mask_method {mask_method} not recognized")


        out = sae.decode(feature_acts)
        # choose out or value based on the mask
        mask_expanded = mask.unsqueeze(-1).expand_as(value)
        value = torch.where(mask_expanded, out, value)
        return value
    return sae_hook


    # def sae_hook_ablate(value, hook):
    # feature_acts = sae.encode(value)
    # # feature_acts[:, :, topsae_attr_indices] = 0
    # out = sae.decode(feature_acts)
    # return out


def build_hooks_list(sequence, cache_sae_grads=False, circuit_mask=None, use_mask=False):
    hooks = []
    # blocks.0.hook_resid_pre
    # # fake hook that adds zero so gradients propagate through the model
    param = nn.Parameter(torch.tensor(0.0, requires_grad=True))
    hooks.append(
        (
            "blocks.0.hook_resid_pre",
            lambda value, hook: value + param,
        )
    )
    for sae in saes:
        hooks.append(
            (
            sae.cfg.hook_name,
            build_sae_hook_fn(sae, sequence, cache_grads=cache_sae_grads, circuit_mask=circuit_mask, use_mask=use_mask),
            )
        )
    return hooks 

In [60]:
def sanity_check_model_performance(logitfn):
    baseline_dataset = ContrastiveDatasetBatch(json_dataset[0:10])
    correct_logits = logitfn(baseline_dataset.correct_tokenized)
    error_logits = logitfn(baseline_dataset.error_tokenized)
    print("probability of predicting the correct age in the correct example")
    print(F.softmax(correct_logits[torch.arange(correct_logits.shape[0]), baseline_dataset.correct_answer_seq_idxs], dim=-1)[:, baseline_dataset.correct_token_idx].mean())
    print("and of traceback in that example")
    print(F.softmax(correct_logits[torch.arange(correct_logits.shape[0]), baseline_dataset.correct_answer_seq_idxs], dim=-1)[:, baseline_dataset.error_token_idx].mean())
    print("logit difference:")
    correct_logit_values = correct_logits[torch.arange(correct_logits.shape[0]), baseline_dataset.correct_answer_seq_idxs, baseline_dataset.correct_token_idx]
    error_logit_values = correct_logits[torch.arange(correct_logits.shape[0]), baseline_dataset.correct_answer_seq_idxs, baseline_dataset.error_token_idx]
    diff = correct_logit_values - error_logit_values
    print(diff.mean())

    print("probability of predicting the traceback in the error code example")
    print(F.softmax(error_logits[torch.arange(error_logits.shape[0]), baseline_dataset.error_answer_seq_idxs], dim=-1)[:, baseline_dataset.error_token_idx].mean())
    print("and of the correct age in that example")
    print(F.softmax(error_logits[torch.arange(error_logits.shape[0]), baseline_dataset.error_answer_seq_idxs], dim=-1)[:, baseline_dataset.correct_token_idx].mean())    
    print("logit difference:")
    correct_logit_values = error_logits[torch.arange(error_logits.shape[0]), baseline_dataset.error_answer_seq_idxs, baseline_dataset.error_token_idx]
    error_logit_values = error_logits[torch.arange(error_logits.shape[0]), baseline_dataset.error_answer_seq_idxs, baseline_dataset.correct_token_idx]
    diff = correct_logit_values - error_logit_values
    print(diff.mean())


In [61]:
def all_contrastive_difference(logitfn):
    baseline_dataset = ContrastiveDatasetBatch(json_dataset[0:10])
    all_tokenized = baseline_dataset.all_tokenized
    all_answer_seq_idxs = baseline_dataset.all_answer_seq_idxs
    all_answers_tok_idxs = baseline_dataset.all_answers_tok_idxs
    all_wrong_answers_tok_idxs = baseline_dataset.all_wrong_answers_tok_idxs
    all_prompts = baseline_dataset.all_prompts
    
    all_logits = logitfn(all_tokenized)
    all_answer_logits = all_logits[torch.arange(all_logits.shape[0]), all_answer_seq_idxs]
    all_answers_correct_logits = all_answer_logits[torch.arange(all_answer_logits.shape[0]), all_answers_tok_idxs]
    all_answers_wrong_logits = all_answer_logits[torch.arange(all_answer_logits.shape[0]), all_wrong_answers_tok_idxs]
    all_logit_diffs = all_answers_correct_logits - all_answers_wrong_logits
    # 6.7 dif for correct, 1.8 for error
    all_logit_diffs[0:baseline_dataset.batch_size] = all_logit_diffs[0:baseline_dataset.batch_size]/6.7
    all_logit_diffs[baseline_dataset.batch_size:] = all_logit_diffs[baseline_dataset.batch_size:]/1.8
    return all_logit_diffs



# SAE vs no SAE (sanity check and basic setup)

In [15]:
baseline_dataset = ContrastiveDatasetBatch(json_dataset[10:15])


In [37]:
def baseline_model_logit_fn(tokens):
    correct_logits = model.run_with_hooks(
        tokens, 
        return_type="logits", 
        fwd_hooks=[
            #(
            #utils.get_act_name("pre", 0), # v=attention out
            #"blocks.0.hook_mlp_out",
            #mlp_ablation_hook,
            #),
            ]
        )
    return correct_logits

with torch.no_grad():
    sanity_check_model_performance(baseline_model_logit_fn)

probability of predicting the correct age in the correct example
tensor(0.9827, device='cuda:0')
and of traceback in that example
tensor(0.0010, device='cuda:0')
logit difference:
tensor(7.0165, device='cuda:0')
probability of predicting the traceback in the error code example
tensor(0.9648, device='cuda:0')
and of the correct age in that example
tensor(0.0008, device='cuda:0')
logit difference:
tensor(7.1005, device='cuda:0')


In [17]:
pad_token_id
bos_token_id = model.tokenizer.bos_token_id
bos_token_id

2

In [38]:
with torch.no_grad():
    def logitfn(tokens):
        return model.run_with_hooks(
            tokens, 
            return_type="logits", 
            fwd_hooks=build_hooks_list(tokens)
            )
    sanity_check_model_performance(logitfn)

probability of predicting the correct age in the correct example
tensor(0.9518, device='cuda:0')
and of traceback in that example
tensor(0.0048, device='cuda:0')
logit difference:
tensor(5.3536, device='cuda:0')
probability of predicting the traceback in the error code example
tensor(0.6051, device='cuda:0')
and of the correct age in that example
tensor(0.0273, device='cuda:0')
logit difference:
tensor(3.1331, device='cuda:0')


In [19]:
with torch.no_grad():
    def logitfn(tokens):
        return model.run_with_hooks(
            tokens, 
            return_type="logits", 
            fwd_hooks=build_hooks_list(tokens)
            )
    logits = logitfn(baseline_dataset.error_tokenized)

In [20]:
test_idx = 4
topk = torch.topk(logits[:, -1, :][test_idx], k=3)
print(model.to_str_tokens(topk.indices))
print(topk.values)
logits[:, -1, :][test_idx][baseline_dataset.correct_token_idx]

['Traceback', '>>>', 'None']
tensor([25.2613, 23.6651, 21.7949], device='cuda:0')


tensor(21.5102, device='cuda:0')

# Contrastive Logit Diff With Positive and Negative

In [41]:
with torch.no_grad():
    def logitfn(tokens):
        return model.run_with_hooks(
            tokens, 
            return_type="logits", 
            fwd_hooks=build_hooks_list(tokens)
            )
    all_logit_diffs = all_contrastive_difference(logitfn)
    print(all_logit_diffs)

tensor([0.7308, 0.8568, 0.7782, 0.7600, 0.8709, 0.8032, 0.8597, 0.7421, 0.7308,
        0.8580, 1.5148, 1.7967, 1.8832, 1.5217, 1.9293, 1.7721, 1.8547, 1.7349,
        2.0564, 1.3420], device='cuda:0')


# Attribution

In [17]:
def logitfn(tokens):
    return model.run_with_hooks(
        tokens, 
        return_type="logits", 
        fwd_hooks=build_hooks_list(tokens, cache_sae_grads=True)
        )

all_logit_diffs = all_contrastive_difference(logitfn)

In [18]:
loss = all_logit_diffs.sum() * -1 # maximize the difference

In [19]:
loss.backward()

In [20]:
loss

tensor(-26.7789, device='cuda:0', grad_fn=<MulBackward0>)

In [21]:
no_ablate = {"mask_method": "keep_only"}

In [22]:
num_features = saes[0].cfg.d_sae

In [23]:
with torch.no_grad():
    for sae in saes:
        delta_loss = torch.abs((sae.feature_acts.grad * sae.feature_acts).view(-1, sae.feature_acts.shape[-1]).sum(dim=0))
        topk = torch.topk(delta_loss, k=63)
        no_ablate[sae.cfg.hook_name] = topk.indices

In [24]:
import random
# deep copy
import copy
random_ablate = copy.deepcopy(no_ablate)

num_features = saes[0].cfg.d_sae

for key in random_ablate.keys():
    if key == "mask_method":
        continue
#     for i in range(len(random_ablate[key])):
#         random_ablate[key][i] = random.randint(0, num_features-1)
# random_ablate

In [25]:
# test the ablation
    
with torch.no_grad():
    def logitfn(tokens):
        return model.run_with_hooks(
            tokens, 
            return_type="logits", 
            fwd_hooks=build_hooks_list(
                tokens,
                cache_sae_grads=False,
                circuit_mask=no_ablate)
            )
    
    all_logit_diffs = sanity_check_model_performance(logitfn)
    #all_logit_diffs = san(logitfn).mean()
    # print(all_logit_diffs)

probability of predicting the correct age in the correct example
tensor(0.0433, device='cuda:0')
and of traceback in that example
tensor(0.0035, device='cuda:0')
logit difference:
tensor(2.5023, device='cuda:0')
probability of predicting the traceback in the error code example
tensor(0.0041, device='cuda:0')
and of the correct age in that example
tensor(0.0426, device='cuda:0')
logit difference:
tensor(-2.3316, device='cuda:0')


# Optimize Binary Mask

In [62]:
batch_size = 16

# makethe length of the dataset a multiple of the batch size
simple_dataset = simple_dataset[:batch_size*(len(simple_dataset)//batch_size)]
simple_labels = simple_labels[:batch_size*(len(simple_labels)//batch_size)]

simple_dataset = simple_dataset.view(-1, batch_size, 65)
simple_labels = simple_labels.view(-1, batch_size)


In [75]:
def logitfn(tokens):
    return model.run_with_hooks(
        tokens, 
        return_type="logits", 
        fwd_hooks=build_hooks_list(tokens, use_mask=True)
        )


def forward_pass(batch, labels, logitfn, ratio_trained=0):
    for sae in saes:
        sae.mask.ratio_trained = ratio_trained
    tokens = batch
    logits = logitfn(tokens)
    last_token_logits = logits[:, -1, :]
    loss = F.cross_entropy(last_token_logits, labels)
    sparsity_loss = 0
    for sae in saes:
        sparsity_loss = sparsity_loss + sae.mask.sparsity_loss
    
    sparsity_loss = sparsity_loss / len(saes)

    return loss, sparsity_loss

In [68]:
all_optimized_params = []
for sae in saes:
    all_optimized_params.extend(list(sae.mask.parameters()))

In [69]:
optimizer = optim.Adam(all_optimized_params, lr=0.1)


In [66]:
from tqdm import tqdm

total_steps = simple_dataset.shape[0]*0.01

with tqdm(total=total_steps, desc="Training Progress") as pbar:
    for i, (x, y) in enumerate(zip(simple_dataset, simple_labels)):
        optimizer.zero_grad()
        
        # Calculate ratio trained
        ratio_trained = i / total_steps
        
        # Update mask ratio for each SAE
        for sae in saes:
            sae.mask.ratio_trained = ratio_trained
        
        # Forward pass with updated ratio_trained
        loss, sparsity_loss = forward_pass(x, y, logitfn, ratio_trained=0)
        avg_nonzero_elements = sparsity_loss
        sparsity_loss = sparsity_loss/50
        total_loss = loss + sparsity_loss
        
        # Backward pass and optimizer step
        total_loss.backward()
        optimizer.step()
        
        # Update tqdm bar with relevant metrics
        pbar.set_postfix({
            'Step': i,
            'Progress': f"{ratio_trained:.2%}",
            "Avg Nonzero Elements": f"{avg_nonzero_elements:.2f}",
            'Task Loss': f"{loss.item():.4f}",
            'Sparsity Loss': f"{sparsity_loss.item():.4f}"
        })
        
        # Update the tqdm progress bar
        pbar.update(1)


  full_bar = Bar(frac,
Training Progress: 230it [06:16,  1.64s/it, Step=229, Progress=183.26%, Avg Nonzero Elements=143.31, Task Loss=0.1822, Sparsity Loss=2.8663]                                        


KeyboardInterrupt: 

In [70]:
from tqdm import tqdm

total_steps = simple_dataset.shape[0]*0.015

with tqdm(total=total_steps, desc="Training Progress") as pbar:
    for i, (x, y) in enumerate(zip(simple_dataset, simple_labels)):
        optimizer.zero_grad()
        
        # Calculate ratio trained
        ratio_trained = i / total_steps
        
        # Update mask ratio for each SAE
        for sae in saes:
            sae.mask.ratio_trained = ratio_trained
        
        # Forward pass with updated ratio_trained
        loss, sparsity_loss = forward_pass(x, y, logitfn, ratio_trained=ratio_trained)
        avg_nonzero_elements = sparsity_loss
        sparsity_loss = sparsity_loss/50
        total_loss = loss + sparsity_loss
        
        # Backward pass and optimizer step
        total_loss.backward()
        optimizer.step()
        
        # Update tqdm bar with relevant metrics
        pbar.set_postfix({
            'Step': i,
            'Progress': f"{ratio_trained:.2%}",
            "Avg Nonzero Elements": f"{avg_nonzero_elements:.2f}",
            'Task Loss': f"{loss.item():.4f}",
            'Sparsity Loss': f"{sparsity_loss.item():.4f}"
        })
        
        # Update the tqdm progress bar
        pbar.update(1)


Training Progress: 202it [05:31,  1.64s/it, Step=201, Progress=107.23%, Avg Nonzero Elements=73.50, Task Loss=0.7459, Sparsity Loss=1.4700]                           


KeyboardInterrupt: 

Nonzero elements in mask for blocks.14.hook_resid_post: 74
Nonzero elements in mask for blocks.21.hook_resid_post: 65


In [30]:
saes[0].cfg.hook_name

'blocks.14.hook_resid_post'

In [28]:
torch.where(torch.sigmoid(saes[0].mask.mask*1000))


(tensor([  357,   849,   850,  1401,  1641,  1672,  1788,  1919,  2186,  2567,
          2770,  2796,  2936,  2959,  2992,  3082,  3330,  3373,  3468,  3714,
          3742,  4141,  4149,  4160,  4300,  4472,  4698,  4713,  4718,  4736,
          4825,  4834,  4873,  4882,  4996,  5003,  5459,  5606,  5713,  5736,
          5741,  5932,  5938,  6150,  6205,  6400,  6436,  7216,  7435,  7570,
          7598,  7761,  7857,  8036,  8269,  8303,  8321,  8429,  8482,  8512,
          8746,  8930,  9368, 10044, 10244, 10512, 10545, 10550, 11150, 11333,
         11578, 12180, 12355, 12514, 12658, 12722, 12749, 13138, 13148, 13518,
         13585, 13691, 13770, 13958, 13962, 14057, 14187, 14201, 14309, 14502,
         14572, 14632, 14864, 14951, 15137, 15625, 15628, 15633, 15646, 15761,
         15849], device='cuda:0'),)

In [33]:
sae.mask

SparseMask()

In [71]:

optimizer.zero_grad()

for param in saes[0].parameters():
    param.grad = None

for param in model.parameters():
    param.grad = None

for param in saes[0].mask.parameters():
    param.grad = None

torch.cuda.empty_cache()

## Eval

In [72]:


for sae in saes:
    mask = sae.mask.mask
    print(f"Nonzero elements in mask for {sae.cfg.hook_name}: {torch.count_nonzero(torch.sigmoid(mask*1000))}")

Nonzero elements in mask for blocks.14.hook_resid_post: 78
Nonzero elements in mask for blocks.21.hook_resid_post: 69


In [44]:
saes[0].mask.ratio_trained

1.1470337174562526

In [76]:
with torch.no_grad():
    sanity_check_model_performance(logitfn)

probability of predicting the correct age in the correct example
tensor(0.6323, device='cuda:0')
and of traceback in that example
tensor(0.0257, device='cuda:0')
logit difference:
tensor(3.2191, device='cuda:0')
probability of predicting the traceback in the error code example
tensor(0.4098, device='cuda:0')
and of the correct age in that example
tensor(0.0472, device='cuda:0')
logit difference:
tensor(2.1808, device='cuda:0')


In [40]:
# save saes[0].mask
torch.save(saes[0].mask, 'mask.pt')

# Clear Grads

In [123]:
# this gives bytes, convert to GB
torch.cuda.memory_reserved() / 1e9, "GB"

(72.30980096, 'GB')