In [1]:
! pip install datasets jaxtyping



In [2]:
import wandb
from google.colab import userdata
import os
os.environ["WANDB_API_KEY"] = userdata.get('WANDB_API_KEY')
wandb.login()
from huggingface_hub.hf_api import HfFolder

HfFolder.save_token(userdata.get('HF_TOKEN'))

[34m[1mwandb[0m: Currently logged in as: [33mjacktpayne51[0m ([33mjacktpayne51-macquarie-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
# %%
import os
from IPython import get_ipython

ipython = get_ipython()
# Code to automatically update the HookedTransformer code as its edited without restarting the kernel
if ipython is not None:
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

import plotly.io as pio
pio.renderers.default = "jupyterlab"

# Import stuff
import einops
import json
import argparse

from datasets import load_dataset
from pathlib import Path
import plotly.express as px
from torch.distributions.categorical import Categorical
from tqdm import tqdm
import torch
import numpy as np
from typing import Optional, Union, Dict, Any
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
import wandb

def to_numpy(tensor):
    """Convert a tensor to numpy array."""
    if isinstance(tensor, np.ndarray):
        return tensor
    elif isinstance(tensor, (list, tuple)):
        return np.array(tensor)
    elif isinstance(tensor, (int, float)):
        return np.array(tensor)
    elif isinstance(tensor, torch.Tensor):
        return tensor.detach().cpu().numpy()
    else:
        raise ValueError(f"Unsupported type for to_numpy conversion: {type(tensor)}")

class ObservableModel:
    """
    A wrapper for HuggingFace models that allows for activation capture and intervention.
    Replaces TransformerLens functionality with native PyTorch hooks.
    """
    def __init__(
        self,
        model_name: str,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
    ):
        self.dtype = dtype
        self.device = device

        # Configure model initialization properly
        model_kwargs = {
            "torch_dtype": self.dtype,
            "device_map": device,
            "use_cache": False,
            "cache_implementation": None  # Disable hybrid cache implementation when use_cache is False
        }

        self._model = AutoModelForCausalLM.from_pretrained(
            model_name,
            **model_kwargs
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.cfg = self._model.config

        # Cache for storing activations during forward pass
        self.activation_cache = {}

        # Print model structure on initialization
        print("\nAvailable hook points:")
        self.print_available_hook_points()

    def print_available_hook_points(self):
        """Print the available hook points in the model structure"""
        def get_named_modules(model, prefix=""):
            for name, module in model.named_modules():
                if name:  # Skip empty names
                    print(f"- {name}")

        get_named_modules(self._model)

    def _find_module(self, hook_point: str) -> nn.Module:
        """Finds a module given its name in dot notation."""
        try:
            submodules = hook_point.split(".")
            module = self._model
            while submodules:
                module = getattr(module, submodules.pop(0))
            return module
        except Exception as e:
            raise ValueError(f"Could not find module {hook_point}: {str(e)}")

    def run_with_cache(
        self,
        input_ids: torch.Tensor,
        names_filter: Optional[str] = None
    ) -> tuple[Any, Dict[str, torch.Tensor]]:
        """
        Run the model while caching activations at specified points.
        Similar to TransformerLens's run_with_cache but for HF models.
        """
        self.activation_cache = {}

        def cache_hook(name: str):
            def hook(mod, inputs, outputs):
                # For LayerNorm, we want to cache the normalized activations
                if isinstance(mod, (nn.LayerNorm, nn.RMSNorm)):
                    if isinstance(outputs, tuple):
                        self.activation_cache[name] = outputs[0].detach()
                    else:
                        self.activation_cache[name] = outputs.detach()
                # For attention layers, we want to cache the input
                elif "input_layernorm" in name:
                    self.activation_cache[name] = inputs[0].detach()
                    # Also cache the scale for later use
                    if hasattr(mod, 'weight'):
                        scale = torch.rsqrt(inputs[0].pow(2).mean(-1, keepdim=True) + mod.eps) * mod.weight
                        self.activation_cache[f"{name}.scale"] = scale.detach()
                else:
                    if isinstance(outputs, tuple):
                        self.activation_cache[name] = outputs[0].detach()
                    else:
                        self.activation_cache[name] = outputs.detach()
                return outputs
            return hook

        handles = []
        if names_filter:
            try:
                module = self._find_module(names_filter)
                handles.append(module.register_forward_hook(cache_hook(names_filter)))
            except Exception as e:
                print(f"Error registering hook for {names_filter}: {str(e)}")
                raise

        try:
            with torch.no_grad():
                outputs = self._model.forward(
                    input_ids,
                    use_cache=False,  # Explicitly disable caching
                    output_hidden_states=True  # Ensure we get all hidden states
                )
        finally:
            for handle in handles:
                handle.remove()

        return outputs, self.activation_cache

    @property
    def cfg(self):
        """Access to model config, similar to TransformerLens."""
        return self._model.config

    @cfg.setter
    def cfg(self, value):
        self._model.config = value

from jaxtyping import Float
#from transformer_lens.hook_points import HookPoint

from functools import partial

from IPython.display import HTML

#from transformer_lens.utils import to_numpy
import pandas as pd

from html import escape
import colorsys

import plotly.graph_objects as go

update_layout_set = {
    "xaxis_range", "yaxis_range", "hovermode", "xaxis_title", "yaxis_title", "colorbar", "colorscale", "coloraxis",
    "title_x", "bargap", "bargroupgap", "xaxis_tickformat", "yaxis_tickformat", "title_y", "legend_title_text", "xaxis_showgrid",
    "xaxis_gridwidth", "xaxis_gridcolor", "yaxis_showgrid", "yaxis_gridwidth"
}

def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    if isinstance(tensor, list):
        tensor = torch.stack(tensor)
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    if "facet_labels" in kwargs_pre:
        facet_labels = kwargs_pre.pop("facet_labels")
    else:
        facet_labels = None
    if "color_continuous_scale" not in kwargs_pre:
        kwargs_pre["color_continuous_scale"] = "RdBu"
    fig = px.imshow(to_numpy(tensor), color_continuous_midpoint=0.0,labels={"x":xaxis, "y":yaxis}, **kwargs_pre).update_layout(**kwargs_post)
    if facet_labels:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i]['text'] = label

    fig.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(y=to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, return_fig=False, **kwargs):
    x = to_numpy(x)
    y = to_numpy(y)
    fig = px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs)
    if return_fig:
        return fig
    fig.show(renderer)

def lines(lines_list, x=None, mode='lines', labels=None, xaxis='', yaxis='', title = '', log_y=False, hover=None, **kwargs):
    # Helper function to plot multiple lines
    if type(lines_list)==torch.Tensor:
        lines_list = [lines_list[i] for i in range(lines_list.shape[0])]
    if x is None:
        x=np.arange(len(lines_list[0]))
    fig = go.Figure(layout={'title':title})
    fig.update_xaxes(title=xaxis)
    fig.update_yaxes(title=yaxis)
    for c, line in enumerate(lines_list):
        if type(line)==torch.Tensor:
            line = to_numpy(line)
        if labels is not None:
            label = labels[c]
        else:
            label = c
        fig.add_trace(go.Scatter(x=x, y=line, mode=mode, name=label, hovertext=hover, **kwargs))
    if log_y:
        fig.update_layout(yaxis_type="log")
    fig.show()

def bar(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.bar(
        y=to_numpy(tensor),
        labels={"x": xaxis, "y": yaxis},
        template="simple_white",
        **kwargs).show(renderer)

def create_html(strings, values, saturation=0.5, allow_different_length=False):
    # escape strings to deal with tabs, newlines, etc.
    escaped_strings = [escape(s, quote=True) for s in strings]
    processed_strings = [
        s.replace("\n", "<br/>").replace("\t", "&emsp;").replace(" ", "&nbsp;")
        for s in escaped_strings
    ]

    if isinstance(values, torch.Tensor) and len(values.shape)>1:
        values = values.flatten().tolist()

    if not allow_different_length:
        assert len(processed_strings) == len(values)

    # scale values
    max_value = max(max(values), -min(values))+1e-3
    scaled_values = [v / max_value * saturation for v in values]

    # create html
    html = ""
    for i, s in enumerate(processed_strings):
        if i<len(scaled_values):
            v = scaled_values[i]
        else:
            v = 0
        if v < 0:
            hue = 0  # hue for red in HSV
        else:
            hue = 0.66  # hue for blue in HSV
        rgb_color = colorsys.hsv_to_rgb(
            hue, v, 1
        )  # hsv color with hue 0.66 (blue), saturation as v, value 1
        hex_color = "#%02x%02x%02x" % (
            int(rgb_color[0] * 255),
            int(rgb_color[1] * 255),
            int(rgb_color[2] * 255),
        )
        html += f'<span style="background-color: {hex_color}; border: 1px solid lightgray; font-size: 16px; border-radius: 3px;">{s}</span>'

    display(HTML(html))

# crosscoder stuff

def arg_parse_update_cfg(default_cfg):
    """
    Helper function to take in a dictionary of arguments, convert these to command line arguments, look at what was passed in, and return an updated dictionary.

    If in Ipython, just returns with no changes
    """
    if get_ipython() is not None:
        # Is in IPython
        print("In IPython - skipped argparse")
        return default_cfg
    cfg = dict(default_cfg)
    parser = argparse.ArgumentParser()
    for key, value in default_cfg.items():
        if type(value) == bool:
            # argparse for Booleans is broken rip. Now you put in a flag to change the default --{flag} to set True, --{flag} to set False
            if value:
                parser.add_argument(f"--{key}", action="store_false")
            else:
                parser.add_argument(f"--{key}", action="store_true")

        else:
            parser.add_argument(f"--{key}", type=type(value), default=value)
    args = parser.parse_args()
    parsed_args = vars(args)
    cfg.update(parsed_args)
    print("Updated config")
    print(json.dumps(cfg, indent=2))
    return cfg

def load_pile_lmsys_mixed_tokens():
    try:
        print("Loading data from disk")
        all_tokens = torch.load("/workspace/data/pile-lmsys-mix-1m-tokenized-gemma-2.pt")
    except:
        print("Data is not cached. Loading data from HF")
        data = load_dataset(
            "ckkissane/pile-lmsys-mix-1m-tokenized-gemma-2",
            split="train",
            cache_dir="/workspace/cache/"
        )
        data.save_to_disk("/workspace/data/pile-lmsys-mix-1m-tokenized-gemma-2.hf")
        data.set_format(type="torch", columns=["input_ids"])
        all_tokens = data["input_ids"]
        torch.save(all_tokens, "/workspace/data/pile-lmsys-mix-1m-tokenized-gemma-2.pt")
        print(f"Saved tokens to disk")
    return all_tokens

In [4]:
#from utils import *

from torch import nn
import pprint
import torch.nn.functional as F
from typing import Optional, Union
from huggingface_hub import hf_hub_download

from typing import NamedTuple

DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
SAVE_DIR = Path("/workspace/crosscoder-model-diff-replication/checkpoints")

class LossOutput(NamedTuple):
    # loss: torch.Tensor
    l2_loss: torch.Tensor
    l1_loss: torch.Tensor
    l0_loss: torch.Tensor
    explained_variance: torch.Tensor
    explained_variance_A: torch.Tensor
    explained_variance_B: torch.Tensor

class CrossCoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        d_hidden = self.cfg["dict_size"]
        d_in = self.cfg["d_in"]
        self.dtype = DTYPES[self.cfg["enc_dtype"]]
        torch.manual_seed(self.cfg["seed"])
        # hardcoding n_models to 2
        self.W_enc = nn.Parameter(
            torch.empty(2, d_in, d_hidden, dtype=self.dtype)
        )
        self.W_dec = nn.Parameter(
            torch.nn.init.normal_(
                torch.empty(
                    d_hidden, 2, d_in, dtype=self.dtype
                )
            )
        )
        self.W_dec = nn.Parameter(
            torch.nn.init.normal_(
                torch.empty(
                    d_hidden, 2, d_in, dtype=self.dtype
                )
            )
        )
        # Make norm of W_dec 0.1 for each column, separate per layer
        self.W_dec.data = (
            self.W_dec.data / self.W_dec.data.norm(dim=-1, keepdim=True) * self.cfg["dec_init_norm"]
        )
        # Initialise W_enc to be the transpose of W_dec
        self.W_enc.data = einops.rearrange(
            self.W_dec.data.clone(),
            "d_hidden n_models d_model -> n_models d_model d_hidden",
        )
        self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=self.dtype))
        self.b_dec = nn.Parameter(
            torch.zeros((2, d_in), dtype=self.dtype)
        )
        self.d_hidden = d_hidden

        self.to(self.cfg["device"])
        self.save_dir = None
        self.save_version = 0

    def encode(self, x, apply_relu=True):
        # x: [batch, n_models, d_model]
        x_enc = einops.einsum(
            x,
            self.W_enc,
            "batch n_models d_model, n_models d_model d_hidden -> batch d_hidden",
        )
        if apply_relu:
            acts = F.relu(x_enc + self.b_enc)
        else:
            acts = x_enc + self.b_enc
        return acts

    def decode(self, acts):
        # acts: [batch, d_hidden]
        acts_dec = einops.einsum(
            acts,
            self.W_dec,
            "batch d_hidden, d_hidden n_models d_model -> batch n_models d_model",
        )
        return acts_dec + self.b_dec

    def forward(self, x):
        # x: [batch, n_models, d_model]
        acts = self.encode(x)
        return self.decode(acts)

    def get_losses(self, x):
        # x: [batch, n_models, d_model]
        x = x.to(self.dtype)
        acts = self.encode(x)
        # acts: [batch, d_hidden]
        x_reconstruct = self.decode(acts)
        diff = x_reconstruct.float() - x.float()
        squared_diff = diff.pow(2)
        l2_per_batch = einops.reduce(squared_diff, 'batch n_models d_model -> batch', 'sum')
        l2_loss = l2_per_batch.mean()

        total_variance = einops.reduce((x - x.mean(0)).pow(2), 'batch n_models d_model -> batch', 'sum')
        explained_variance = 1 - l2_per_batch / total_variance

        per_token_l2_loss_A = (x_reconstruct[:, 0, :] - x[:, 0, :]).pow(2).sum(dim=-1).squeeze()
        total_variance_A = (x[:, 0, :] - x[:, 0, :].mean(0)).pow(2).sum(-1).squeeze()
        explained_variance_A = 1 - per_token_l2_loss_A / total_variance_A

        per_token_l2_loss_B = (x_reconstruct[:, 1, :] - x[:, 1, :]).pow(2).sum(dim=-1).squeeze()
        total_variance_B = (x[:, 1, :] - x[:, 1, :].mean(0)).pow(2).sum(-1).squeeze()
        explained_variance_B = 1 - per_token_l2_loss_B / total_variance_B

        decoder_norms = self.W_dec.norm(dim=-1)
        # decoder_norms: [d_hidden, n_models]
        total_decoder_norm = einops.reduce(decoder_norms, 'd_hidden n_models -> d_hidden', 'sum')
        l1_loss = (acts * total_decoder_norm[None, :]).sum(-1).mean(0)

        l0_loss = (acts>0).float().sum(-1).mean()

        return LossOutput(l2_loss=l2_loss, l1_loss=l1_loss, l0_loss=l0_loss, explained_variance=explained_variance, explained_variance_A=explained_variance_A, explained_variance_B=explained_variance_B)

    def create_save_dir(self):
        try:
            # For Colab, use a local directory
            base_dir = Path("./checkpoints")
            base_dir.mkdir(parents=True, exist_ok=True)

            version_list = [
                int(file.name.split("_")[1])
                for file in list(base_dir.iterdir())
                if "version" in str(file)
            ]
            if len(version_list):
                version = 1 + max(version_list)
            else:
                version = 0
            self.save_dir = base_dir / f"version_{version}"
            self.save_dir.mkdir(parents=True, exist_ok=True)
            print(f"Created save directory at {self.save_dir}")
        except Exception as e:
            print(f"Error creating save directory: {str(e)}")
            # Fallback to a simple directory structure
            self.save_dir = Path("./checkpoints/version_0")
            self.save_dir.mkdir(parents=True, exist_ok=True)
            print(f"Created fallback save directory at {self.save_dir}")

    def save(self):
        try:
            if self.save_dir is None:
                self.create_save_dir()
            weight_path = self.save_dir / f"{self.save_version}.pt"
            cfg_path = self.save_dir / f"{self.save_version}_cfg.json"

            torch.save(self.state_dict(), weight_path)
            with open(cfg_path, "w") as f:
                json.dump(self.cfg, f)

            print(f"Saved as version {self.save_version} in {self.save_dir}")
            self.save_version += 1
        except Exception as e:
            print(f"Error saving model: {str(e)}")
            # Try to save in current directory as fallback
            try:
                torch.save(self.state_dict(), f"crosscoder_checkpoint_{self.save_version}.pt")
                print(f"Saved fallback checkpoint as crosscoder_checkpoint_{self.save_version}.pt")
                self.save_version += 1
            except Exception as e2:
                print(f"Failed to save fallback checkpoint: {str(e2)}")

    @classmethod
    def load_from_hf(
        cls,
        repo_id: str = "ckkissane/crosscoder-gemma-2-2b-model-diff",
        path: str = "blocks.14.hook_resid_pre",
        device: Optional[Union[str, torch.device]] = None
    ) -> "CrossCoder":
        """
        Load CrossCoder weights and config from HuggingFace.

        Args:
            repo_id: HuggingFace repository ID
            path: Path within the repo to the weights/config
            model: The transformer model instance needed for initialization
            device: Device to load the model to (defaults to cfg device if not specified)

        Returns:
            Initialized CrossCoder instance
        """

        # Download config and weights
        config_path = hf_hub_download(
            repo_id=repo_id,
            filename=f"{path}/cfg.json"
        )
        weights_path = hf_hub_download(
            repo_id=repo_id,
            filename=f"{path}/cc_weights.pt"
        )

        # Load config
        with open(config_path, 'r') as f:
            cfg = json.load(f)

        # Override device if specified
        if device is not None:
            cfg["device"] = str(device)

        # Initialize CrossCoder with config
        instance = cls(cfg)

        # Load weights
        state_dict = torch.load(weights_path, map_location=cfg["device"])
        instance.load_state_dict(state_dict)

        return instance

    @classmethod
    def load(cls, version_dir, checkpoint_version):
        #save_dir = Path("/workspace/crosscoder-model-diff-replication/checkpoints") / str(version_dir)
        load_dir_colab = Path("/content/checkpoints") / str(version_dir)
        cfg_path = load_dir_colab / f"{str(checkpoint_version)}_cfg.json"
        weight_path = load_dir_colab / f"{str(checkpoint_version)}.pt"

        cfg = json.load(open(cfg_path, "r"))
        pprint.pprint(cfg)
        self = cls(cfg=cfg)
        self.load_state_dict(torch.load(weight_path))
        return self

In [5]:
#from utils import *
import tqdm

class Buffer:
    """
    This defines a data buffer, to store a stack of acts across both models that can be used to train the autoencoder.
    It'll automatically run the model to generate more when it gets halfway empty.
    Modified to work with HuggingFace models directly instead of TransformerLens.
    """

    def __init__(self, cfg, model_A, model_B, all_tokens):
        assert model_A.cfg.hidden_size == model_B.cfg.hidden_size
        self.cfg = cfg
        self.buffer_size = cfg["batch_size"] * cfg["buffer_mult"]
        self.buffer_batches = self.buffer_size // (cfg["seq_len"] - 1)
        self.buffer_size = self.buffer_batches * (cfg["seq_len"] - 1)
        self.buffer = torch.zeros(
            (self.buffer_size, 2, model_A.cfg.hidden_size),
            dtype=torch.bfloat16,
            requires_grad=False,
        ).to(cfg["device"]) # hardcoding 2 for model diffing
        self.model_A = model_A
        self.model_B = model_B
        self.token_pointer = 0
        self.first = True
        self.normalize = True
        self.all_tokens = all_tokens

        try:
            estimated_norm_scaling_factor_A = self.estimate_norm_scaling_factor(cfg["model_batch_size"], model_A)
            estimated_norm_scaling_factor_B = self.estimate_norm_scaling_factor(cfg["model_batch_size"], model_B)

            self.normalisation_factor = torch.tensor(
                [
                    estimated_norm_scaling_factor_A,
                    estimated_norm_scaling_factor_B,
                ],
                device=cfg["device"],
                dtype=torch.float32,
            )
            self.refresh()
        except Exception as e:
            print(f"Error during initialization: {str(e)}")
            raise

    @torch.no_grad()
    def estimate_norm_scaling_factor(self, batch_size, model, n_batches_for_norm_estimate: int = 100):
        """Estimate normalization scaling factor for model activations."""
        norms_per_batch = []
        for i in tqdm.tqdm(
            range(n_batches_for_norm_estimate), desc="Estimating norm scaling factor"
        ):
            try:
                tokens = self.all_tokens[i * batch_size : (i + 1) * batch_size].to(self.cfg["device"])
                if tokens.shape[0] == 0:
                    continue

                # Ensure consistent sequence length
                if tokens.shape[1] > self.cfg["seq_len"]:
                    tokens = tokens[:, :self.cfg["seq_len"]]

                _, cache = model.run_with_cache(
                    tokens,
                    names_filter=self.cfg["hook_point"],
                )
                acts = cache[self.cfg["hook_point"]]
                # Drop BOS token
                acts = acts[:, 1:, :]
                norms_per_batch.append(acts.norm(dim=-1).mean().item())
            except Exception as e:
                print(f"Error in batch {i}: {str(e)}")
                continue

        if not norms_per_batch:
            raise ValueError("Failed to compute any norms")

        mean_norm = np.mean(norms_per_batch)
        scaling_factor = np.sqrt(model.cfg.hidden_size) / mean_norm
        return scaling_factor

    @torch.no_grad()
    def refresh(self):
        """Refresh the buffer with new activations."""
        self.pointer = 0
        print("Refreshing the buffer!")
        with torch.autocast("cuda", torch.bfloat16):
            if self.first:
                num_batches = self.buffer_batches
            else:
                num_batches = self.buffer_batches // 2
            self.first = False

            batch_size = self.cfg["model_batch_size"]

            for _ in tqdm.trange(0, num_batches, batch_size):
                try:
                    # Get a batch of tokens
                    end_idx = min(self.token_pointer + batch_size, len(self.all_tokens))
                    tokens = self.all_tokens[self.token_pointer:end_idx].to(self.cfg["device"])

                    # Skip if we got an empty batch
                    if tokens.shape[0] == 0:
                        continue

                    # Ensure consistent sequence length
                    if tokens.shape[1] > self.cfg["seq_len"]:
                        tokens = tokens[:, :self.cfg["seq_len"]]

                    # Get activations from both models
                    _, cache_A = self.model_A.run_with_cache(
                        tokens,
                        names_filter=self.cfg["hook_point"],
                    )

                    _, cache_B = self.model_B.run_with_cache(
                        tokens,
                        names_filter=self.cfg["hook_point"],
                    )

                    # Extract activations and drop BOS token
                    acts_A = cache_A[self.cfg["hook_point"]][:, 1:, :]
                    acts_B = cache_B[self.cfg["hook_point"]][:, 1:, :]

                    # Stack and reshape
                    acts = torch.stack([acts_A, acts_B], dim=0)
                    acts = einops.rearrange(
                        acts,
                        "n_layers batch seq_len d_model -> (batch seq_len) n_layers d_model",
                    )

                    # Update buffer
                    available_space = self.buffer.size(0) - self.pointer
                    if acts.shape[0] > 0:
                        if acts.shape[0] > available_space:
                            acts = acts[:available_space]
                        self.buffer[self.pointer : self.pointer + acts.shape[0]] = acts
                        self.pointer += acts.shape[0]
                        self.token_pointer = end_idx

                except Exception as e:
                    print(f"Error in refresh batch: {str(e)}")
                    continue

            # Reset token pointer if needed
            if self.token_pointer >= len(self.all_tokens) - batch_size:
                self.token_pointer = 0

            # Shuffle buffer
            self.buffer = self.buffer[torch.randperm(self.buffer.shape[0]).to(self.cfg["device"])]
            self.pointer = 0

    @torch.no_grad()
    def next(self):
        """Get next batch of activations from the buffer."""
        out = self.buffer[self.pointer : self.pointer + self.cfg["batch_size"]].float()
        self.pointer += self.cfg["batch_size"]
        if self.pointer > self.buffer.shape[0] // 2 - self.cfg["batch_size"]:
            self.refresh()
        if self.normalize:
            out = out * self.normalisation_factor[None, :, None]
        return out


In [6]:
#from utils import *
#from crosscoder import CrossCoder
#from buffer import Buffer
import tqdm

from torch.nn.utils import clip_grad_norm_
class Trainer:
    def __init__(self, cfg, model_A, model_B, all_tokens):
        self.cfg = cfg
        self.model_A = model_A
        self.model_B = model_B
        self.crosscoder = CrossCoder(cfg)
        self.buffer = Buffer(cfg, model_A, model_B, all_tokens)
        self.total_steps = cfg["num_tokens"] // cfg["batch_size"]

        self.optimizer = torch.optim.Adam(
            self.crosscoder.parameters(),
            lr=cfg["lr"],
            betas=(cfg["beta1"], cfg["beta2"]),
        )
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer, self.lr_lambda
        )
        self.step_counter = 0

        wandb.init(project=cfg["wandb_project"], entity=cfg["wandb_entity"])

    def lr_lambda(self, step):
        if step < 0.8 * self.total_steps:
            return 1.0
        else:
            return 1.0 - (step - 0.8 * self.total_steps) / (0.2 * self.total_steps)

    def get_l1_coeff(self):
        # Linearly increases from 0 to cfg["l1_coeff"] over the first 0.05 * self.total_steps steps, then keeps it constant
        if self.step_counter < 0.05 * self.total_steps:
            return self.cfg["l1_coeff"] * self.step_counter / (0.05 * self.total_steps)
        else:
            return self.cfg["l1_coeff"]

    def step(self):
        acts = self.buffer.next()
        losses = self.crosscoder.get_losses(acts)
        loss = losses.l2_loss + self.get_l1_coeff() * losses.l1_loss
        loss.backward()
        clip_grad_norm_(self.crosscoder.parameters(), max_norm=1.0)
        self.optimizer.step()
        self.scheduler.step()
        self.optimizer.zero_grad()

        loss_dict = {
            "loss": loss.item(),
            "l2_loss": losses.l2_loss.item(),
            "l1_loss": losses.l1_loss.item(),
            "l0_loss": losses.l0_loss.item(),
            "l1_coeff": self.get_l1_coeff(),
            "lr": self.scheduler.get_last_lr()[0],
            "explained_variance": losses.explained_variance.mean().item(),
            "explained_variance_A": losses.explained_variance_A.mean().item(),
            "explained_variance_B": losses.explained_variance_B.mean().item(),
        }
        self.step_counter += 1
        return loss_dict

    def log(self, loss_dict):
        wandb.log(loss_dict, step=self.step_counter)
        print(loss_dict)

    def save(self):
        self.crosscoder.save()

    def train(self):
        self.step_counter = 0
        try:
            for i in tqdm.trange(self.total_steps):
                loss_dict = self.step()
                if i % self.cfg["log_every"] == 0:
                    self.log(loss_dict)
                if (i + 1) % self.cfg["save_every"] == 0:
                    self.save()
        finally:
            self.save()

In [7]:
#chat_model

In [8]:
# %%
#from utils import *
#from trainer import Trainer

# %%
device = 'cuda:0'

# Initialize our models using the ObservableModel wrapper
base_model = ObservableModel(
    "google/gemma-2-2b",
    device=device,
)

chat_model = ObservableModel(
    "google/gemma-2-2b-it",
    device=device,
)

# %%
all_tokens = load_pile_lmsys_mixed_tokens()

# %%
default_cfg = {
    "seed": 49,
    "batch_size": 4096,
    "buffer_mult": 128,
    "lr": 5e-5,
    "num_tokens": 2_000_000,
    "l1_coeff": 2,
    "beta1": 0.9,
    "beta2": 0.999,
    "d_in": base_model.cfg.hidden_size,
    "dict_size": 2**14,
    "seq_len": 1024,
    "enc_dtype": "fp32",
    "model_name": "google/gemma-2b",
    "site": "resid_pre",
    "device": "cuda:0",
    "model_batch_size": 4,
    "log_every": 100,
    "save_every": 30000,
    "dec_init_norm": 0.08,
    "hook_point": "model.layers.14.input_layernorm", # Adjust based on model architecture
    "wandb_project": "mats",
    "wandb_entity": "jacktpayne51-macquarie-university",
}

def verify_hook_point(model, hook_point):
    try:
        module = model._find_module(hook_point)
        print(f"Successfully found hook point: {hook_point}")
        print(f"Module type: {type(module)}")
        return True
    except Exception as e:
        print(f"Error finding hook point {hook_point}: {str(e)}")
        return False

if not verify_hook_point(base_model, default_cfg["hook_point"]):
    raise ValueError("Invalid hook point specified")

cfg = arg_parse_update_cfg(default_cfg)

trainer = Trainer(cfg, base_model, chat_model, all_tokens)
# %%

config.json:   0%|          | 0.00/818 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/481M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/168 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/46.4k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]


Available hook points:
- model
- model.embed_tokens
- model.layers
- model.layers.0
- model.layers.0.self_attn
- model.layers.0.self_attn.q_proj
- model.layers.0.self_attn.k_proj
- model.layers.0.self_attn.v_proj
- model.layers.0.self_attn.o_proj
- model.layers.0.mlp
- model.layers.0.mlp.gate_proj
- model.layers.0.mlp.up_proj
- model.layers.0.mlp.down_proj
- model.layers.0.mlp.act_fn
- model.layers.0.input_layernorm
- model.layers.0.post_attention_layernorm
- model.layers.0.pre_feedforward_layernorm
- model.layers.0.post_feedforward_layernorm
- model.layers.1
- model.layers.1.self_attn
- model.layers.1.self_attn.q_proj
- model.layers.1.self_attn.k_proj
- model.layers.1.self_attn.v_proj
- model.layers.1.self_attn.o_proj
- model.layers.1.mlp
- model.layers.1.mlp.gate_proj
- model.layers.1.mlp.up_proj
- model.layers.1.mlp.down_proj
- model.layers.1.mlp.act_fn
- model.layers.1.input_layernorm
- model.layers.1.post_attention_layernorm
- model.layers.1.pre_feedforward_layernorm
- model.laye

config.json:   0%|          | 0.00/838 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/241M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/47.0k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]


Available hook points:
- model
- model.embed_tokens
- model.layers
- model.layers.0
- model.layers.0.self_attn
- model.layers.0.self_attn.q_proj
- model.layers.0.self_attn.k_proj
- model.layers.0.self_attn.v_proj
- model.layers.0.self_attn.o_proj
- model.layers.0.mlp
- model.layers.0.mlp.gate_proj
- model.layers.0.mlp.up_proj
- model.layers.0.mlp.down_proj
- model.layers.0.mlp.act_fn
- model.layers.0.input_layernorm
- model.layers.0.post_attention_layernorm
- model.layers.0.pre_feedforward_layernorm
- model.layers.0.post_feedforward_layernorm
- model.layers.1
- model.layers.1.self_attn
- model.layers.1.self_attn.q_proj
- model.layers.1.self_attn.k_proj
- model.layers.1.self_attn.v_proj
- model.layers.1.self_attn.o_proj
- model.layers.1.mlp
- model.layers.1.mlp.gate_proj
- model.layers.1.mlp.up_proj
- model.layers.1.mlp.down_proj
- model.layers.1.mlp.act_fn
- model.layers.1.input_layernorm
- model.layers.1.post_attention_layernorm
- model.layers.1.pre_feedforward_layernorm
- model.laye


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



README.md:   0%|          | 0.00/292 [00:00<?, ?B/s]

train-00000-of-00008.parquet:   0%|          | 0.00/209M [00:00<?, ?B/s]

train-00001-of-00008.parquet:   0%|          | 0.00/209M [00:00<?, ?B/s]

train-00002-of-00008.parquet:   0%|          | 0.00/209M [00:00<?, ?B/s]

train-00003-of-00008.parquet:   0%|          | 0.00/209M [00:00<?, ?B/s]

train-00004-of-00008.parquet:   0%|          | 0.00/210M [00:00<?, ?B/s]

train-00005-of-00008.parquet:   0%|          | 0.00/209M [00:00<?, ?B/s]

train-00006-of-00008.parquet:   0%|          | 0.00/209M [00:00<?, ?B/s]

train-00007-of-00008.parquet:   0%|          | 0.00/209M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/963566 [00:00<?, ? examples/s]

Saving the dataset (0/8 shards):   0%|          | 0/963566 [00:00<?, ? examples/s]

Saved tokens to disk
Successfully found hook point: model.layers.14.input_layernorm
Module type: <class 'transformers.models.gemma2.modeling_gemma2.Gemma2RMSNorm'>
In IPython - skipped argparse


Estimating norm scaling factor: 100%|██████████| 100/100 [00:17<00:00,  5.60it/s]
Estimating norm scaling factor: 100%|██████████| 100/100 [00:16<00:00,  5.93it/s]


Refreshing the buffer!


100%|██████████| 128/128 [00:42<00:00,  3.00it/s]
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


In [9]:
trainer.train()


  0%|          | 2/488 [00:00<02:17,  3.54it/s]

{'loss': 4971.8173828125, 'l2_loss': 4971.8173828125, 'l1_loss': 118.55587768554688, 'l0_loss': 8172.958984375, 'l1_coeff': 0.0, 'lr': 5e-05, 'explained_variance': -0.5006786584854126, 'explained_variance_A': -0.5174552202224731, 'explained_variance_B': -0.48572492599487305}


 13%|█▎        | 62/488 [00:11<01:19,  5.38it/s]

Refreshing the buffer!



  0%|          | 0/64 [00:00<?, ?it/s][A
  2%|▏         | 1/64 [00:00<00:14,  4.43it/s][A
  3%|▎         | 2/64 [00:00<00:17,  3.46it/s][A
  5%|▍         | 3/64 [00:00<00:18,  3.22it/s][A
  6%|▋         | 4/64 [00:01<00:19,  3.13it/s][A
  8%|▊         | 5/64 [00:01<00:19,  3.08it/s][A
  9%|▉         | 6/64 [00:01<00:19,  3.05it/s][A
 11%|█         | 7/64 [00:02<00:18,  3.03it/s][A
 12%|█▎        | 8/64 [00:02<00:18,  3.02it/s][A
 14%|█▍        | 9/64 [00:02<00:18,  3.01it/s][A
 16%|█▌        | 10/64 [00:03<00:17,  3.01it/s][A
 17%|█▋        | 11/64 [00:03<00:17,  3.01it/s][A
 19%|█▉        | 12/64 [00:03<00:17,  3.00it/s][A
 20%|██        | 13/64 [00:04<00:16,  3.00it/s][A
 22%|██▏       | 14/64 [00:04<00:16,  3.00it/s][A
 23%|██▎       | 15/64 [00:04<00:16,  3.00it/s][A
 25%|██▌       | 16/64 [00:05<00:16,  3.00it/s][A
 27%|██▋       | 17/64 [00:05<00:15,  3.00it/s][A
 28%|██▊       | 18/64 [00:05<00:15,  3.00it/s][A
 30%|██▉       | 19/64 [00:06<00:15,  3.00it/s]

Created save directory at checkpoints/version_0
Saved as version 0 in checkpoints/version_0


KeyboardInterrupt: 

In [None]:
wandb.finish()


In [None]:
# %%
#from utils import *
#from crosscoder import CrossCoder
torch.set_grad_enabled(False);
# %%
VERSION_DIR = "version_9"  # The version directory containing your trained model
CHECKPOINT_VERSION = 0     # The checkpoint number you want to load
SAVE_DIR = Path("/content/checkpoints/")

cross_coder = CrossCoder.load(VERSION_DIR, CHECKPOINT_VERSION)

# %%
norms = cross_coder.W_dec.norm(dim=-1)
norms.shape
# %%
relative_norms = norms[:, 1] / norms.sum(dim=-1)
relative_norms.shape
# %%

fig = px.histogram(
    relative_norms.detach().cpu().numpy(),
    title="Gemma 2 2B Base vs IT Model Diff",
    labels={"value": "Relative decoder norm strength"},
    nbins=200,
)

fig.update_layout(showlegend=False)
fig.update_yaxes(title_text="Number of Latents")

# Update x-axis ticks
fig.update_xaxes(
    tickvals=[0, 0.25, 0.5, 0.75, 1.0],
    ticktext=['0', '0.25', '0.5', '0.75', '1.0']
)

fig.show()






In [None]:
# %%
shared_latent_mask = (relative_norms < 0.7) & (relative_norms > 0.3)
shared_latent_mask.shape
# %%
# Cosine similarity of recoder vectors between models

cosine_sims = (cross_coder.W_dec[:, 0, :] * cross_coder.W_dec[:, 1, :]).sum(dim=-1) / (cross_coder.W_dec[:, 0, :].norm(dim=-1) * cross_coder.W_dec[:, 1, :].norm(dim=-1))
cosine_sims.shape
# %%
import plotly.express as px
import torch

fig = px.histogram(
    cosine_sims[shared_latent_mask].to(torch.float32).detach().cpu().numpy(),
    #title="Cosine similarity of decoder vectors between models",
    log_y=True,  # Sets the y-axis to log scale
    range_x=[-1, 1],  # Sets the x-axis range from -1 to 1
    nbins=100,  # Adjust this value to change the number of bins
    labels={"value": "Cosine similarity of decoder vectors between models"}
)

fig.update_layout(showlegend=False)
fig.update_yaxes(title_text="Number of Latents (log scale)")

fig.show()
# %%

In [None]:
# %%

import plotly.express as px
import plotly.io as pio
import torch

# Configure plotly to work in Colab
pio.renderers.default = "colab"
torch.set_grad_enabled(False)

# %%
# Load locally trained crosscoder
# Replace version_dir and checkpoint_version with your actual values
VERSION_DIR = "version_9"  # The version directory containing your trained model
CHECKPOINT_VERSION = 0     # The checkpoint number you want to load

try:
    cross_coder = CrossCoder.load(VERSION_DIR, CHECKPOINT_VERSION)
    print(f"Successfully loaded crosscoder from version {VERSION_DIR}, checkpoint {CHECKPOINT_VERSION}")
except Exception as e:
    print(f"Error loading local crosscoder: {str(e)}")
    print("Falling back to HuggingFace model")
    cross_coder = CrossCoder.load_from_hf()

# %%
# Calculate norms and print shape for debugging
norms = cross_coder.W_dec.norm(dim=-1)
print("Norms shape:", norms.shape)

# %%
# Calculate relative norms and print shape for debugging
relative_norms = norms[:, 1] / norms.sum(dim=-1)
print("Relative norms shape:", relative_norms.shape)

# %%
# First histogram - Relative decoder norm strength
fig1 = px.histogram(
    relative_norms.detach().cpu().numpy(),
    title="Gemma 2 2B Base vs IT Model Diff",
    labels={"value": "Relative decoder norm strength"},
    nbins=200,
)

fig1.update_layout(
    showlegend=False,
    title_x=0.5,  # Center the title
    yaxis_title="Number of Latents",
    xaxis=dict(
        tickvals=[0, 0.25, 0.5, 0.75, 1.0],
        ticktext=['0', '0.25', '0.5', '0.75', '1.0']
    )
)

# Display the figure
fig1.show()

# %%
# Calculate shared latent mask and print shape for debugging
shared_latent_mask = (relative_norms < 0.7) & (relative_norms > 0.3)
print("Shared latent mask shape:", shared_latent_mask.shape)
print("Number of shared latents:", shared_latent_mask.sum().item())

# %%
# Calculate cosine similarities and print shape for debugging
cosine_sims = (cross_coder.W_dec[:, 0, :] * cross_coder.W_dec[:, 1, :]).sum(dim=-1) / (
    cross_coder.W_dec[:, 0, :].norm(dim=-1) * cross_coder.W_dec[:, 1, :].norm(dim=-1)
)
print("Cosine similarities shape:", cosine_sims.shape)

# %%
# Second histogram - Cosine similarity
fig2 = px.histogram(
    cosine_sims[shared_latent_mask].to(torch.float32).detach().cpu().numpy(),
    title="Cosine Similarity Distribution of Shared Features",
    labels={"value": "Cosine similarity of decoder vectors between models"},
    log_y=True,
    range_x=[-1, 1],
    nbins=100,
)

fig2.update_layout(
    showlegend=False,
    title_x=0.5,  # Center the title
    yaxis_title="Number of Latents (log scale)",
    xaxis_title="Cosine Similarity"
)

# Display the figure
fig2.show()

# Print summary statistics
print("\nSummary Statistics:")
cos_sims_filtered = cosine_sims[shared_latent_mask]
print(f"Mean cosine similarity: {cos_sims_filtered.mean().item():.3f}")
print(f"Median cosine similarity: {cos_sims_filtered.median().item():.3f}")
print(f"Std cosine similarity: {cos_sims_filtered.std().item():.3f}")
# %%
