In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]= "6,7"
import torch
from tqdm import tqdm
import plotly.express as px

torch.set_grad_enabled(False);

# %%
from transformer_lens import HookedTransformer
from sae_lens import SAE

# Choose a layer you want to focus on
device= "cuda:0"
model = HookedTransformer.from_pretrained("gpt2-small").to('cuda')

# Initialize SAE
layer = 8
sae, cfg_dict, _ = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id=f"blocks.{layer}.hook_resid_pre",
    device= "cuda:0"
)

# get hook point
hook_point = sae.cfg.hook_name
print(hook_point)

Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda
blocks.8.hook_resid_pre


In [5]:
from sae_lens import SAETrainingRunner, LanguageModelSAERunnerConfig, SAE
from sae_lens.training.training_sae import TrainingSAE
import torch
from typing import Any, Tuple, cast
from transformer_lens import HookedTransformer
from typing import Optional
from datasets import Dataset, IterableDataset
from datasets import load_dataset
import os
import torch.nn.functional as F
import wandb

class RSAETrainingRunner(SAETrainingRunner):
    def __init__(
        self,
        cfg: LanguageModelSAERunnerConfig,
        r_weight: float = 1,
        n_steps: int = 50,
        k: int = 7,
    ):
        super().__init__(cfg)
        self.r_weight = r_weight
        self.n_steps = n_steps
        self.k = k
        self.tokens_processed = 0
        self.dataset_iterator: Optional[Iterator] = None
        self.dataset: Optional[Dataset | IterableDataset] = None
        self.optimizer = torch.optim.Adam(self.sae.parameters(), lr=self.cfg.lr)
        self.step = 0

    def run(self):
        if self.cfg.log_to_wandb:
            wandb.init(
                project=self.cfg.wandb_project,
                config=cast(Any, self.cfg),
                name=self.cfg.run_name,
                id=self.cfg.wandb_id,
            )

        self._compile_if_needed()

        try:
            # Main training loop
            while self.tokens_processed < self.cfg.training_tokens:
                batch = self.get_next_batch()
                loss_dict = self.training_step(batch)
                
                # Update tokens processed
                self.tokens_processed += batch.numel()
                
                # Log metrics
                if self.cfg.log_to_wandb and self.step % self.cfg.wandb_log_frequency == 0:
                    wandb.log({**loss_dict, "tokens_processed": self.tokens_processed}, step=self.step)

                self.step += 1

        except KeyboardInterrupt:
            print("Training interrupted. Saving checkpoint...")
            self.save_checkpoint(self.step)

        if self.cfg.log_to_wandb:
            wandb.finish()

        return self.sae

    def setup_dataset(self):
        if self.cfg.is_dataset_tokenized:
            self.dataset = load_dataset(self.cfg.dataset_path, streaming=self.cfg.streaming, split="train")
        else:
            # If the dataset is not tokenized, we need to tokenize it
            raw_dataset = load_dataset(self.cfg.dataset_path, streaming=self.cfg.streaming, split="train")
            self.dataset = raw_dataset.map(lambda examples: self.model.tokenizer(examples['text'], truncation=True, padding='max_length', max_length=self.cfg.context_size), batched=True)
        
        self.dataset_iterator = iter(self.dataset)

    def check_weight_updates(self):
        if not hasattr(self, 'old_weights'):
            self.old_weights = {name: param.clone().detach() for name, param in self.sae.named_parameters()}
            return

        total_diff = 0
        max_diff = 0
        for name, param in self.sae.named_parameters():
            diff = (param.data - self.old_weights[name]).abs().sum().item()
            total_diff += diff
            max_diff = max(max_diff, diff)
            self.old_weights[name] = param.clone().detach()
        
        print(f"Total weight diff: {total_diff:.6f}, Max weight diff: {max_diff:.6f}")
        #wandb.log({"total_weight_diff": total_diff, "max_weight_diff": max_diff}, step=self.step)

    def get_next_batch(self) -> torch.Tensor:
        if self.dataset_iterator is None:
            self.setup_dataset()
        
        batch_size = self.cfg.train_batch_size_tokens // self.cfg.context_size
        batch = []

        while len(batch) < batch_size:
            try:
                item = next(self.dataset_iterator)
                if self.cfg.is_dataset_tokenized:
                    input_ids = torch.tensor(item['input_ids'], dtype=torch.long)
                else:
                    input_ids = torch.tensor(item['input_ids'], dtype=torch.long)
                
                # Ensure the sequence is the correct length
                if input_ids.size(0) < self.cfg.context_size:
                    input_ids = torch.cat([input_ids, torch.zeros(self.cfg.context_size - input_ids.size(0), dtype=torch.long)])
                elif input_ids.size(0) > self.cfg.context_size:
                    input_ids = input_ids[:self.cfg.context_size]
                
                batch.append(input_ids)

            except StopIteration:
                # If we've reached the end of the dataset, reset the iterator
                self.dataset_iterator = iter(self.dataset)
                # If we couldn't fill the batch, break and return what we have
                if len(batch) > 0:
                    break
                else:
                    continue  # If the batch is empty, try again with the reset iterator

        # Stack the batch items into a single tensor
        batch_tensor = torch.stack(batch)

        return batch_tensor.to(self.cfg.device)

    def training_step(self, input_ids: torch.Tensor) -> dict[str, float]:
        # Compute activations and integrated gradients
        activations, attributions = self.integrated_gradients(
            self.model,
            input_ids,
            self.cfg.hook_name,
            self.n_steps,
            largest_logit=False
        )

        with torch.enable_grad():
            # Compute R score for each token in the sequence
            r_scores = []
            for i in range(activations.shape[1]): 
                 # Iterate over sequence length
                r_score = self.compute_r_score(
                    self.sae,
                    activations[:, i],
                    attributions[:, i],
                    self.k
                )
            
                r_scores.append(r_score)
            
            if i % 5 == 0:
                self.check_weight_updates()
            avg_r_score = sum(r_scores) / len(r_scores)
            sae_output = self.sae.training_forward_pass(activations, self.cfg.l1_coefficient)
            r_loss = 1 - avg_r_score
            print(f"mse and l0 loss: {sae_output.loss}")
            print(f"R_loss: {r_loss}")
            print(f"Average R_score: {avg_r_score}")
            total_loss = r_loss + (sae_output.loss / 800000)

        # Perform optimization step
        self.optimizer.zero_grad()
        total_loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(self.sae.parameters(), max_norm=float('inf'))
        print(f"gradient_norm: {grad_norm}")
        #wandb.log({"gradient_norm": grad_norm}, step=self.step)
        self.optimizer.step()

        # Create a dictionary for logging
        loss_dict = {
            'mse_loss': sae_output.mse_loss,
            'l1_loss': sae_output.l1_loss,
            'ghost_grad_loss': sae_output.ghost_grad_loss,
            'auxiliary_reconstruction_loss': sae_output.auxiliary_reconstruction_loss,
            'r_score': avg_r_score,
            'total_loss': total_loss.item()
        }
        return loss_dict


    def compute_r_score(
        self,
        sae: SAE,
        activation: torch.Tensor,
        attribution: torch.Tensor,
        k: int = 5
    ) -> torch.Tensor:
        """
        Compute the R-score for multiple sets of vectors and attributions.
        
        Args:
        sae_directions: Tensor of shape [batch_size, num_vectors, vector_dim]
        attribution: Tensor of shape [batch_size, vector_dim]
        top_k_values: Tensor of shape [batch_size, num_vectors]
        epsilon: Small value to avoid division by zero
        
        Returns:
        r_scores: Tensor of shape [batch_size]
        """
        epsilon: float = 1e-8
        sae_directions, top_k_values = self.get_top_k_active_directions(sae, activation, k)
        sae_directions = sae_directions.permute(2, 1, 0)

        #print(f"attribtion: {attribution.shape}")
        #print(f"sae_directions: {sae_directions.shape}")

        # Compute the projection matrix
        V = sae_directions.transpose(1, 2)  # [batch_size, vector_dim, num_vectors]
        VTV = torch.bmm(V.transpose(1, 2), V)  # [batch_size, num_vectors, num_vectors]
        VTV_inv = torch.inverse(VTV + epsilon * torch.eye(VTV.size(1), device=VTV.device).unsqueeze(0))
        P = torch.bmm(torch.bmm(V, VTV_inv), V.transpose(1, 2))  # [batch_size, vector_dim, vector_dim]
        
        # Compute the projection of attribution onto the span of weighted directions
        u = attribution.unsqueeze(-1)  # [batch_size, vector_dim, 1]
        proj_u = torch.bmm(P, u)  # [batch_size, vector_dim, 1]
        
        
        # Compute the R-score
        r_scores = torch.sqrt(
            (torch.bmm(u.transpose(1, 2), proj_u).squeeze(-1) + epsilon) /
            (torch.sum(u**2, dim=1) + epsilon)
        ).squeeze(-1)
        #print(f"r_scores {r_scores}")
        return r_scores.mean()

    def integrated_gradients(
        self,
        model: HookedTransformer,
        input_ids: torch.Tensor,
        target_layer: str,
        n_steps: int = 30,
        largest_logit: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        device = input_ids.device
        layer_index = int(target_layer.split('.')[1])

        # Get baseline and input activations
        with torch.no_grad():
            _, cache = model.run_with_cache(input_ids)
            input_activation = cache[target_layer]
        
        baseline_activation = torch.zeros_like(input_activation)
        delta = (input_activation - baseline_activation) / n_steps
        integrated_gradients = torch.zeros_like(input_activation)

        for i in range(n_steps + 1):
            with torch.enable_grad():
                with torch.autograd.set_detect_anomaly(True):
                    interpolated_activation = baseline_activation + i * delta
                    interpolated_activation.requires_grad_(True)
                    logits = self.compute_subsequent_outputs_logits_only(model, interpolated_activation, layer_index)

                    if largest_logit:
                        last_token_logits = logits[:, -1]
                        max_logit_value, max_logit_index = torch.max(last_token_logits, dim=-1)
                        
                        # Compute gradient only for the highest logit
                        max_logit_value.sum().backward(retain_graph=True)
                        gradient = interpolated_activation.grad

                        integrated_gradients += gradient / n_steps
                        interpolated_activation.grad = None  
                    else:
                        loss = logits[:, -1, input_ids[:, -1]].sum()
                        loss.backward(create_graph=False)
                        gradient = interpolated_activation.grad

                        integrated_gradients += gradient / n_steps

        attributions = integrated_gradients * (input_activation - baseline_activation)
        return input_activation, attributions

    
    def get_top_k_active_directions(self, sae: SAE, activation: torch.Tensor, k: int):
        # Assuming activation shape is [16, 768]
        latent = sae.encode(activation)  # [16, n_features]
        top_k_values, top_k_indices = torch.topk(latent.abs(), k, dim=-1)  # Both: [16, 5]
        
        # Gather top k directions for each item in the batch
        top_k_directions = sae.W_enc[:, top_k_indices.flatten()].view(sae.W_enc.shape[0], k, -1)  # [768, 5, 16]
        
        return top_k_directions, top_k_values

    def compute_subsequent_outputs_logits_only(self, model: torch.nn.Module, activation: torch.Tensor, layer_index: int) -> torch.Tensor:
        # Ensure activation requires gradients
        with torch.enable_grad():
            if not activation.requires_grad:
                activation.requires_grad_(True)
            if activation.ndim == 2:
                activation = activation.unsqueeze(1)       
            current_activation = activation
            for block in model.blocks[layer_index+1:]:
                current_activation = block(current_activation)
                
            current_activation = model.ln_final(current_activation)
            logits = model.unembed(current_activation)
        return logits
    



cfg = LanguageModelSAERunnerConfig(
    # Model and SAE Configuration
    model_name="gpt2-small",
    hook_name="blocks.8.hook_resid_pre",
    hook_layer=8,
    d_in=768,
    expansion_factor=32,

    # Dataset Configuration
    dataset_path="apollo-research/Skylion007-openwebtext-tokenizer-gpt2",
    is_dataset_tokenized=True,
    streaming=True,

    # Training Parameters
    train_batch_size_tokens=400,
    context_size=256,
    training_tokens=10_000_000,  # 10 million tokens for finetuning
    lr=5e-3,  # Lower learning rate for finetuning
    lr_scheduler_name="cosine",
    l1_coefficient=5e-3,
    lp_norm=1.0,

    # Logging and Checkpointing
    log_to_wandb=True,
    wandb_project="sae_finetuning_r_score",
    n_checkpoints=5,
    checkpoint_path="./checkpoints",

    # Hardware Configuration
    device="cuda",
    dtype="float32",
)
os.environ["CUDA_VISIBLE_DEVICES"]= "4,5,6,7"
pretrained_sae, original_cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
    device="cuda:0",
)
# Initialize the runner with the pretrained SAE
runner = RSAETrainingRunner(cfg, pretrained_sae, n_steps=30, k=7)


Run name: 24576-L1-0.005-LR-0.005-Tokens-1.000e+07
n_tokens_per_buffer (millions): 0.16384
Lower bound: n_contexts_per_buffer (millions): 0.00064
Total training steps: 25000
Total wandb updates: 2500
n_tokens_per_feature_sampling_window (millions): 204.8
n_tokens_per_dead_feature_window (millions): 102.4
We will reset the sparsity calculation 12 times.
Number tokens in sparsity calculation window: 8.00e+05
Loaded pretrained model gpt2-small into HookedTransformer


Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

Objective value: 7284788.0000:   2%|▏         | 2/100 [00:00<00:00, 209.66it/s]


In [7]:
import os
import json
import torch
import wandb
from sae_lens import SAE

def save_finetuned_sae(sae: SAE, config: dict, step: int, save_path: str, log_to_wandb: bool = True):
    # Create the save directory
    checkpoint_path = os.path.join(save_path, f"finetuned_sae_step_{step}")
    os.makedirs(checkpoint_path, exist_ok=True)
    
    # Save the SAE locally
    sae.save_model(checkpoint_path)
    
    # Save the configuration
    config_path = os.path.join(checkpoint_path, "config.json")
    with open(config_path, "w") as f:
        json.dump(config, f)
    
    print(f"Finetuned SAE saved locally to: {checkpoint_path}")
    
    # Log to wandb if enabled
    if log_to_wandb:
        # Initialize wandb if it's not already running
        if wandb.run is None:
            wandb.init(project="ae_finetuning_r_score")
        
        # Create a wandb Artifact
        artifact = wandb.Artifact(
            name=f"finetuned_sae_step_{step}",
            type="model",
            description="Finetuned Sparse Autoencoder",
            metadata=config
        )
        
        # Add files to the artifact
        artifact.add_file(os.path.join(checkpoint_path, "SAE_WEIGHTS_PATH"))
        artifact.add_file(os.path.join(checkpoint_path, "SAE_CFG_PATH"))
        artifact.add_file(config_path)
        
        # Log the artifact to wandb
        wandb.log_artifact(artifact)
        
        print(f"Finetuned SAE logged to wandb as artifact: {artifact.name}")


In [8]:
import wandb
finetuned_sae = runner.run()

save_finetuned_sae(
    sae=finetuned_sae,
    config=runner.cfg.to_dict(),
    step=runner.step,
    save_path="./checkpoints",
    log_to_wandb=True
)

Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

mse and l0 loss: 85418.8515625
R_loss: 0.9062745571136475
Average R_score: 0.09372545778751373
gradient_norm: 0.6121602058410645
Total weight diff: 186176.226668, Max weight diff: 93280.453125
mse and l0 loss: 9503695.0
R_loss: 0.9079520106315613
Average R_score: 0.09204798936843872
gradient_norm: 17.791101455688477
Total weight diff: 130041.577149, Max weight diff: 68210.812500
mse and l0 loss: 259743.140625
R_loss: 0.9051637053489685
Average R_score: 0.0948362722992897
gradient_norm: 0.9893543124198914
Total weight diff: 104796.920078, Max weight diff: 52589.210938
mse and l0 loss: 735669.6875
R_loss: 0.9033849239349365
Average R_score: 0.09661506116390228
gradient_norm: 2.7058353424072266
Total weight diff: 78044.481183, Max weight diff: 43656.484375
mse and l0 loss: 1044603.3125
R_loss: 0.9098259806632996
Average R_score: 0.09017402678728104
gradient_norm: 1.4878594875335693
Total weight diff: 73737.491240, Max weight diff: 38093.335938
mse and l0 loss: 395614.0
R_loss: 0.899923205

wandb: Network error (ReadTimeout), entering retry loop.


Total weight diff: 946.054619, Max weight diff: 528.405396
mse and l0 loss: 4315.19775390625
R_loss: 0.5173982381820679
Average R_score: 0.48260176181793213
gradient_norm: 0.04197254776954651
Total weight diff: 870.173353, Max weight diff: 473.916443
mse and l0 loss: 4211.33984375
R_loss: 0.4625621438026428
Average R_score: 0.5374378561973572
gradient_norm: 0.041554149240255356
Total weight diff: 818.193115, Max weight diff: 441.718750
mse and l0 loss: 5780.1884765625
R_loss: 0.7025718688964844
Average R_score: 0.2974281311035156
gradient_norm: 0.06989959627389908
Total weight diff: 827.403580, Max weight diff: 451.834686
mse and l0 loss: 4019.802978515625
R_loss: 0.4446961283683777
Average R_score: 0.5553038716316223
gradient_norm: 0.03980289399623871
Total weight diff: 770.751506, Max weight diff: 412.522675
mse and l0 loss: 4595.3212890625
R_loss: 0.6894697546958923
Average R_score: 0.31053024530410767
gradient_norm: 0.05557461082935333
Total weight diff: 743.218094, Max weight diff