In [1]:
# %%
import gc
import json
import os
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Kept ahead of the project-local import below (presumably so that
# `circ4latents` resolves from the repository root -- confirm).
os.chdir("/work/pi_jensen_umass_edu/jnainani_umass_edu/CircuitAnalysisSAEs")

import einops
from sae_lens import SAE, HookedSAETransformer

from circ4latents import data_gen

# Function to manage CUDA memory and clean up
def cleanup_cuda():
    """Run the Python garbage collector, then release PyTorch's cached CUDA memory.

    The collector runs first so that tensors kept alive only by reference
    cycles are freed before the cache is emptied; in the opposite order the
    blocks freed by the collector would remain in PyTorch's cache.
    """
    gc.collect()
    torch.cuda.empty_cache()
# cleanup_cuda()
# Load the config
with open("config.json", 'r') as file:
    config = json.load(file)
token = config.get('huggingface_token', None)
# os.environ only accepts strings: assigning None raises TypeError, so the
# token is exported only when the config actually provides one.
if token:
    os.environ["HF_TOKEN"] = token

# Define device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# NOTE(review): HF_HOME is set after the Hugging Face-backed libraries were
# imported; the explicit cache_dir below is what is passed to the loader --
# confirm whether the environment variable still has any effect here.
hf_cache = "/work/pi_jensen_umass_edu/jnainani_umass_edu/mechinterp/huggingface_cache/hub"
os.environ["HF_HOME"] = hf_cache

# Load the model
model = HookedSAETransformer.from_pretrained("google/gemma-2-9b", device=device, cache_dir=hf_cache)



Device: cuda


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2-9b into HookedTransformer


In [2]:
# Residual-stream layers to analyse, paired position-by-position with the
# average-L0 value that appears in each SAE's id.
layers = [7, 14, 21, 40]
l0s = [92, 67, 129, 125]

# from_pretrained's result is indexed with [0]; only that first element is kept.
saes = [
    SAE.from_pretrained(
        release="gemma-scope-9b-pt-res",
        sae_id=f"layer_{layer}/width_16k/average_l0_{l0}",
        device=device,
    )[0]
    for layer, l0 in zip(layers, l0s)
]

In [3]:
# %%
# Updated version to return JSON with more names and structure for correct and incorrect keying examples

import json
import random

# Pool of first names used as dictionary keys in the generated prompts
extended_name_pool = [
    "Bob", "Sam", "Lilly", "Rob", "Alice", "Charlie", "Sally", "Tom", "Jake",
    "Emily", "Megan", "Chris", "Sophia", "James", "Oliver", "Isabella", "Mia",
    "Jackson", "Emma", "Ava", "Lucas", "Benjamin", "Ethan", "Grace", "Olivia",
    "Liam", "Noah",
]

# Every name must encode to exactly two ids
# (presumably one special token plus one name token -- confirm).
for name in extended_name_pool:
    n_ids = len(model.tokenizer.encode(name))
    assert n_ids == 2, f"Name {name} has more than 1 token"

# Function to generate the dataset with correct and incorrect keying into dictionaries
def generate_extended_dataset(name_pool, num_samples=5, seed=None):
    """Build paired REPL prompts that index an ``age`` dict with a present / absent key.

    Each sample draws 5 names from ``name_pool``, assigns each an age in
    [10, 19], and produces two prompts that share the same dictionary line:
    one looks up a name that is in the dictionary ("correct"), the other a
    name from the pool that is not ("error").

    Args:
        name_pool: Sequence of candidate names (hashable, used as dict keys).
        num_samples: Number of correct/error pairs to generate.
        seed: Optional seed. When given, a private ``random.Random(seed)`` is
            used so the result is reproducible and the global ``random`` state
            is untouched. When ``None`` the global ``random`` module is used.

    Returns:
        A list of ``num_samples`` dicts of the form
        ``{"correct": {...}, "error": {...}}``, each inner dict holding
        ``"prompt"``, ``"response"`` and ``"token"``.

    Raises:
        ValueError: If samples are requested but the pool has fewer than 6
            distinct names, so an absent key could not be guaranteed.
    """
    names_per_dict = 5
    distinct_names = len(set(name_pool))
    if num_samples > 0 and distinct_names <= names_per_dict:
        raise ValueError(
            f"name_pool needs at least {names_per_dict + 1} distinct names, got {distinct_names}"
        )

    rng = random if seed is None else random.Random(seed)
    banner = 'Type "help", "copyright", "credits" or "license" for more information.\n'

    def build_prompt(age_dict, key):
        # Both prompts of a pair share the banner and dictionary line.
        return f'{banner}>>> age = {age_dict}\n>>> var = age["{key}"]\n'

    dataset = []
    for _ in range(num_samples):
        # Randomly select the names and assign each a random age
        selected_names = rng.sample(name_pool, names_per_dict)
        age_dict = {name: rng.randint(10, 19) for name in selected_names}

        # Correct example: the looked-up name is a key of the dictionary
        correct_name = rng.choice(list(age_dict.keys()))
        correct_response = age_dict[correct_name]

        # Error example: the looked-up name is in the pool but not in the dictionary
        incorrect_name = rng.choice([name for name in name_pool if name not in age_dict])

        dataset.append({
            "correct": {
                "prompt": build_prompt(age_dict, correct_name),
                "response": correct_response,
                # First character of the age rendered as a string
                "token": str(correct_response)[0],
            },
            "error": {
                "prompt": build_prompt(age_dict, incorrect_name),
                "response": "Traceback",
                "token": "Traceback",
            },
        })

    return dataset

# Seed the sampler so the generated prompts (and every number derived from
# them) are the same on each Restart & Run All.
random.seed(0)

# Generate the extended dataset
json_dataset = generate_extended_dataset(extended_name_pool, num_samples=100)

# Output the JSON structure

# %%
answer_token = model.to_single_token(">>>")
traceback_token = model.to_single_token("Traceback")

# Only the first 50 pairs are used. "clean" holds the prompts whose key is
# missing from the dictionary; "corr" holds the prompts whose key is present.
paired_subset = json_dataset[:50]
clean_prompts = [pair["error"]["prompt"] for pair in paired_subset]
corr_prompts = [pair["correct"]["prompt"] for pair in paired_subset]

clean_tokens = model.to_tokens(clean_prompts)
corr_tokens = model.to_tokens(corr_prompts)

# %%
def logit_diff_fn(logits):
    """Mean gap between the `traceback_token` and `answer_token` logits.

    Reads the last index along the second axis of `logits` (assumed to be
    laid out as batch x position x vocabulary -- confirm) and averages
    ``logit[traceback_token] - logit[answer_token]`` over the first axis.
    """
    final_position = logits[:, -1]
    per_example_gap = final_position[:, traceback_token] - final_position[:, answer_token]
    return per_example_gap.mean()

# Freeze every model weight: no parameter of `model` requires gradients below
for weight in model.parameters():
    weight.requires_grad_(False)

# Baseline logit differences on the clean and corrupted prompt batches.
# The full logits are never bound to a name, so nothing has to be deleted.
clean_diff = logit_diff_fn(model(clean_tokens))
corr_diff = logit_diff_fn(model(corr_tokens))

print(f"clean_diff: {clean_diff}")
print(f"corr_diff: {corr_diff}")

# Release whatever the two forward passes left behind
cleanup_cuda()

# # Define error type metric
def _err_type_metric(logits, clean_logit_diff, corr_logit_diff):
    """Rescale the logit difference of `logits` against two baselines.

    Returns ``(logit_diff_fn(logits) - corr_logit_diff) /
    (clean_logit_diff - corr_logit_diff)``: 0 when the patched difference
    equals the corrupted baseline and 1 when it equals the clean baseline.
    """
    baseline_span = clean_logit_diff - corr_logit_diff
    recovered = logit_diff_fn(logits) - corr_logit_diff
    return recovered / baseline_span


# Metric with both baselines bound to the values measured on this dataset
err_metric_denoising = partial(
    _err_type_metric,
    clean_logit_diff=clean_diff,
    corr_logit_diff=corr_diff,
)

clean_diff: 3.529353141784668
corr_diff: -4.800076007843018


In [None]:
class SparseMask(nn.Module):
    """Learnable element-wise gate with a temperature-scaled sigmoid and an L1 penalty.

    The gate is ``sigmoid(mask * max_temp ** ratio_trained)``. With
    ``ratio_trained == 0`` the temperature is 1; with ``ratio_trained == 1``
    it is ``max_temp``.

    Args:
        shape: Shape of the learnable mask (initialised to ones).
        l1: Coefficient multiplying the summed gate values in ``sparsity_loss``.
        max_temp: Temperature reached when ``ratio_trained`` is 1.

    Attributes set at runtime:
        ratio_trained: Exponent applied to ``max_temp``; set by the caller
            before each forward pass (starts at 1).
        sparsity_loss: ``l1 * gate.sum()`` from the most recent forward pass,
            or ``None`` before the first call.
    """

    def __init__(self, shape, l1, max_temp=1000.0):
        super().__init__()
        self.mask = nn.Parameter(torch.ones(shape))
        self.l1 = l1
        # Registered as a buffer so it follows the module across .to(...);
        # non-persistent so the state_dict still contains only `mask`.
        self.register_buffer("max_temp", torch.tensor(float(max_temp)), persistent=False)
        self.sparsity_loss = None
        self.ratio_trained = 1

    def forward(self, x):
        """Multiply `x` by the current gate and record the sparsity penalty."""
        temperature = self.max_temp ** self.ratio_trained
        gate = torch.sigmoid(self.mask * temperature)
        # A sigmoid is never negative, so its plain sum already is the L1 norm.
        self.sparsity_loss = gate.sum() * self.l1
        return x * gate


# Give every SAE its own mask: the forward pass and the optimizer setup both
# read `sae.mask` for each entry of `saes`, so an SAE without one would raise
# AttributeError. The mask is moved to `device` because its parameter is
# created on the CPU by default.
for sae in saes:
    sae.mask = SparseMask(sae.cfg.d_sae, 1.0).to(device)

In [None]:
batch_size = 16

# Flatten to one row per example first, so that re-running this cell on the
# already-batched tensors gives the same result instead of truncating batches.
# Assumes the last axis of simple_dataset is the sequence axis (length 65 in
# the original version of this cell) -- TODO confirm against where it is built.
seq_len = simple_dataset.shape[-1]
flat_dataset = simple_dataset.reshape(-1, seq_len)
flat_labels = simple_labels.reshape(-1)
assert flat_dataset.shape[0] == flat_labels.shape[0], "dataset and labels differ in length"

# Make the number of examples a multiple of the batch size (drop the remainder)
n_usable = batch_size * (flat_dataset.shape[0] // batch_size)

simple_dataset = flat_dataset[:n_usable].reshape(-1, batch_size, seq_len)
simple_labels = flat_labels[:n_usable].reshape(-1, batch_size)


In [None]:
def logitfn(tokens):
    """Run `model` on `tokens` with the hooks from `build_hooks_list(..., use_mask=True)`.

    Returns whatever `model.run_with_hooks` yields for ``return_type="logits"``.
    """
    masked_hooks = build_hooks_list(tokens, use_mask=True)
    return model.run_with_hooks(tokens, return_type="logits", fwd_hooks=masked_hooks)


def forward_pass(batch, labels, logitfn, ratio_trained=0):
    """Compute the task loss and the mean mask sparsity penalty for one batch.

    Args:
        batch: Token batch handed to `logitfn`.
        labels: Targets for the cross-entropy on the last-position logits.
        logitfn: Callable mapping tokens to logits.
        ratio_trained: Value written to every SAE mask's `ratio_trained`
            before the forward pass.

    Returns:
        ``(loss, sparsity_loss)``: the cross-entropy at the last position and
        the masks' `sparsity_loss` values averaged over `saes`.
    """
    # Push the schedule value into every mask before running the model
    for sae in saes:
        sae.mask.ratio_trained = ratio_trained

    final_logits = logitfn(batch)[:, -1, :]
    task_loss = F.cross_entropy(final_logits, labels)

    # Each mask stored its penalty during the forward pass above
    mean_sparsity = sum(sae.mask.sparsity_loss for sae in saes) / len(saes)

    return task_loss, mean_sparsity

In [None]:
# Only the mask parameters of each SAE are handed to the optimizer
all_optimized_params = [
    param
    for sae in saes
    for param in sae.mask.parameters()
]

optimizer = optim.Adam(all_optimized_params, lr=0.1)


In [None]:
from tqdm import tqdm

# One optimisation step per batch, so the batch count is the step count.
# This is used both as the progress-bar total and as the denominator of the
# schedule, which keeps ratio_trained inside [0, 1).
total_steps = simple_dataset.shape[0]

# Divisor applied to the sparsity penalty before it is added to the task loss
sparsity_divisor = 50

with tqdm(total=total_steps, desc="Training Progress") as pbar:
    for i, (x, y) in enumerate(zip(simple_dataset, simple_labels)):
        optimizer.zero_grad()

        # Fraction of training completed; forward_pass writes it to every
        # SAE mask, where it is the exponent of the mask temperature.
        ratio_trained = i / total_steps

        # Forward pass with the current ratio (passing a constant 0 here
        # would reset the schedule on every step).
        loss, sparsity_loss = forward_pass(x, y, logitfn, ratio_trained=ratio_trained)
        avg_nonzero_elements = sparsity_loss
        sparsity_loss = sparsity_loss / sparsity_divisor
        total_loss = loss + sparsity_loss

        # Backward pass and optimizer step
        total_loss.backward()
        optimizer.step()

        # Update tqdm bar with relevant metrics
        pbar.set_postfix({
            'Step': i,
            'Progress': f"{ratio_trained:.2%}",
            "Avg Nonzero Elements": f"{avg_nonzero_elements.item():.2f}",
            'Task Loss': f"{loss.item():.4f}",
            'Sparsity Loss': f"{sparsity_loss.item():.4f}"
        })

        # Update the tqdm progress bar
        pbar.update(1)
