# Setup

In [2]:
# %pip install 'torch==2.5.1' einops datasets jaxtyping sae_lens transformer_lens openai tabulate "nbformat>=4.2.0" umap-learn hdbscan eindex-callum git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python git+https://github.com/callummcdougall/sae_vis.git@callum/v3

In [3]:
import os
import sys

In [4]:
os.environ['HF_TOKEN'] = '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'

## Funzioni per vedere la memoria

In [5]:
import gc

import torch as t
from openai import OpenAIError
from tabulate import tabulate


def get_tensor_size(obj):
    size = 0
    if t.is_tensor(obj):
        size += obj.element_size() * obj.nelement()
    return size


def get_tensors_size(obj):
    if isinstance(obj, t.nn.Module):
        return sum(get_tensor_size(p) for p in obj.parameters())
    if hasattr(obj, "state_dict"):
        return sum(get_tensor_size(t) for t in obj.state_dict().values())
    return get_tensor_size(obj)


def get_device(obj):
    if t.is_tensor(obj):
        return str(obj.device)
    if isinstance(obj, t.nn.Module):
        try:
            return str(next(iter(obj.parameters())).device)
        except StopIteration:
            return "N/A"
    return "N/A"


def print_memory_status():
    t.cuda.synchronize()
    allocated = t.cuda.memory_allocated(0)
    total = t.cuda.get_device_properties(0).total_memory
    free = total - allocated
    print(f"Allocated: {allocated / 1024**3:.2f} GB")
    print(f"Total:  {total / 1024**3:.2f} GB")
    print(f"Free:  {free / 1024**3:.2f} GB")


def profile_pytorch_memory(namespace: dict, n_top: int = 10, filter_device: str = None):
    print_memory_status()

    object_sizes = {}
    for name, obj in namespace.items():
        try:
            obj_type = (
                type(obj).__name__
                if isinstance(obj, t.nn.Module)
                else f"Tensor {tuple(obj.shape)}"
                if t.is_tensor(obj)
                else None
            )
            if obj_type is None:
                continue
            device = get_device(obj)
            if filter_device and device != filter_device:
                continue
            size = get_tensors_size(obj)
            object_sizes[name] = (obj_type, device, size / (1024**3))
        except (OpenAIError, ReferenceError):
            # OpenAIError: we can't inspect the type of certain objects without triggering API request
            # ReferenceError: this object might have been garbage collected, so we don't care about it
            continue

    # Convert bytes to GB, sort by size & print
    sorted_sizes = sorted(object_sizes.items(), key=lambda x: x[1][2], reverse=True)[:n_top]
    table_data = [(name, obj_type, device, size) for name, (obj_type, device, size) in sorted_sizes]
    print(
        tabulate(
            table_data, headers=["Name", "Object", "Device", "Size (GB)"], floatfmt=".2f", tablefmt="simple_outline"
        )
    )


def find_cuda_tensors():
    cuda_tensors = []
    for obj in gc.get_objects():
        try:
            if t.is_tensor(obj) and obj.is_cuda:
                cuda_tensors.append(obj)
        except:
            pass
    return cuda_tensors


## Altre utilità

In [6]:
import matplotlib.pyplot as plt
import seaborn as sns

In [7]:
import math
def make_square(num: int) -> tuple[int, int]:
    """
    Trova i due fattori più vicini alla radice quadrata di un numero.

    Args:
        numero: Il numero intero di cui trovare i fattori.

    Returns:
        Una tupla contenente i due fattori più vicini alla radice quadrata.
    """

    if num < 1:
        raise ValueError("Il numero deve essere maggiore o uguale a 1.")

    radice_quadrata = math.sqrt(num)
    fattore_inferiore = math.floor(radice_quadrata)
    fattore_superiore = math.ceil(radice_quadrata)

    while num % fattore_inferiore != 0:
        fattore_inferiore -= 1

    while num % fattore_superiore != 0:
        fattore_superiore += 1

    return fattore_inferiore, fattore_superiore

## Imports

In [8]:
import gc
import itertools
import math
import os
import random
import sys
from collections import Counter
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, Callable, Literal, TypeAlias

import einops
import numpy as np
import pandas as pd
import plotly.express as px
import requests
import torch as t
import openai
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from IPython.display import HTML, IFrame, clear_output, display
from jaxtyping import Float, Int
from rich import print as rprint
from rich.table import Table
from sae_lens import (
    SAE,
    ActivationsStore,
    HookedSAETransformer,
    LanguageModelSAERunnerConfig,
    SAEConfig,
    SAETrainingRunner,
    upload_saes_to_huggingface,
)
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from sae_vis import SaeVisConfig, SaeVisData, SaeVisLayoutConfig
from tabulate import tabulate
from torch import Tensor, nn
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from tqdm.auto import tqdm
from transformer_lens import ActivationCache, HookedTransformer, utils
from transformer_lens.hook_points import HookPoint

## Check memoria

In [9]:
device = t.device("mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [10]:
t.cuda.empty_cache()

In [11]:
# Profile memory usage, and delete gemma models if we've loaded them in
namespace = globals().copy() | locals()
profile_pytorch_memory(namespace=namespace, filter_device="cuda:0")

Allocated: 0.00 GB
Total:  6.00 GB
Free:  6.00 GB
┌────────┬──────────┬──────────┬─────────────┐
│ Name   │ Object   │ Device   │ Size (GB)   │
├────────┼──────────┼──────────┼─────────────┤
└────────┴──────────┴──────────┴─────────────┘


# Training del SAE

Per i parametri riporto i config dei SAE allenati da google per gemma-scope-2b. 
Dall'HF google/gemma-scope-2b-pt-res:\
What Is gemma-scope-2b-pt-res?
* gemma-scope-: Gemma Scope is a comprehensive, open suite of sparse autoencoders for Gemma 2 9B and 2B. Sparse Autoencoders are a "microscope" of sorts that can help us break down a model’s internal activations into the underlying concepts, just as biologists use microscopes to study the individual cells of plants and animals.
* 2b-pt-: These SAEs were trained on Gemma v2 2B base model.
* res: These SAEs were trained on the model's residual stream.


Prendiamo un SAE di gemmascope. Per esempio quello che viene usato su neuronpedia per vederne i config:

In [16]:
release = "gemma-scope-2b-pt-res"
sae_id = "layer_20/width_16k/average_l0_71"
sae_gemmascope_cfg = SAE.from_pretrained(release, sae_id)[1]

In [17]:
sae_gemmascope_cfg

{'architecture': 'jumprelu',
 'd_in': 2304,
 'd_sae': 16384,
 'dtype': 'float32',
 'model_name': 'gemma-2-2b',
 'hook_name': 'blocks.20.hook_resid_post',
 'hook_layer': 20,
 'hook_head_index': None,
 'activation_fn_str': 'relu',
 'finetuning_scaling_factor': False,
 'sae_lens_training_version': None,
 'prepend_bos': True,
 'dataset_path': 'monology/pile-uncopyrighted',
 'context_size': 1024,
 'dataset_trust_remote_code': True,
 'apply_b_dec_to_input': False,
 'normalize_activations': None,
 'device': 'cpu',
 'neuronpedia_id': None}

In [18]:
sae_gemmascope_cfg['d_sae']/sae_gemmascope_cfg['d_in']

7.111111111111111

Traineremo un SAE sul livello 20. Dal paper https://arxiv.org/pdf/2408.05147 troviamo i token usati, ovvero 4B per una dimensione del SAE di circa 16k
La batch size usata è 4096 e la learning rate 7e-5

Parametri importanti:

In [28]:
# parametri del training
total_training_steps = 300_000  # Calculated from training_tokens / batch_size
batch_size = 4096
total_training_tokens = total_training_steps * batch_size
print(f"Total training tokens: {total_training_tokens:,}")
print(f"Total training steps: {total_training_steps:,}")
learning_rate = 7e-5
print(F"Learning rate: {learning_rate}")

lr_warm_up_steps = l1_warm_up_steps = total_training_steps // 10  # 10% of training
lr_decay_steps = total_training_steps // 5  # 20% of training

# wandb

log_to_wandb = True
print("Logging in to WANDB")

# parametri del modello
model_name = 'gemma-2-2b'
hook_name = 'blocks.20.hook_resid_post'
hook_layer = 20
d_sae = 16384
d_in = 2304

l1_coefficient = 2

context_size = 1024          # il pretrained ha 128

print(f'tokens per step: {batch_size*context_size}')

Total training tokens: 1,228,800,000
Total training steps: 300,000
Learning rate: 7e-05
Logging in to WANDB
tokens per step: 4194304


In [29]:
cfg = LanguageModelSAERunnerConfig(
    #
    # Data generation
    model_name=model_name,
    hook_name=hook_name,
    hook_layer=hook_layer,
    d_in=2304,
    dataset_path="chanind/openwebtext-gemma",
    is_dataset_tokenized=True,
    # dataset_path="HuggingFaceFW/fineweb",
    # is_dataset_tokenized=False,
    prepend_bos=True,
    streaming=True,
    train_batch_size_tokens=batch_size,
    context_size=context_size,
    #
    # SAE architecture
    architecture="jumprelu",
    d_sae=d_sae,
    b_dec_init_method="zeros",
    apply_b_dec_to_input=True,
    #
    # Activations store
    n_batches_in_buffer=32,
    training_tokens=total_training_tokens,
    store_batch_size_prompts=16,
    #
    # Training hyperparameters (standard)
    lr=learning_rate,
    adam_beta1=0.9,
    adam_beta2=0.999,
    lr_scheduler_name="constant",
    lr_warm_up_steps=lr_warm_up_steps,
    lr_decay_steps=lr_decay_steps,
    #
    # Training hyperparameters (SAE-specific)
    l1_coefficient=l1_coefficient,
    l1_warm_up_steps=l1_warm_up_steps,
    use_ghost_grads=False,
    feature_sampling_window=5000,
    dead_feature_window=5000,
    dead_feature_threshold=1e-6,
    #
    # Logging / evals
    log_to_wandb=True,
    wandb_project="gemma2-2bL20-16k",
    wandb_log_frequency=50,
    eval_every_n_wandb_logs=20,
    #
    # Misc.
    device=str(device),
    seed=42,
    n_checkpoints=5,
    checkpoint_path="checkpoints",
    dtype="float32",
)

In [None]:
t.set_grad_enabled(True)
runner = SAETrainingRunner(cfg)

In [None]:
import torch._dynamo
torch._dynamo.config.suppress_errors = True

In [None]:
sae = runner.run()

In [None]:
hf_repo_id = "Ale21m9/GEMMA2-2b-16k"
sae_id = cfg.hook_name

upload_saes_to_huggingface({sae_id: sae}, hf_repo_id=hf_repo_id)
