In [None]:
# Colab setup: install the two interpretability libraries this notebook imports below.
# NOTE(review): no versions are pinned, so a later re-run may pull incompatible releases.
!pip install sae-lens transformer-lens

In [None]:
import os
import gc
import requests
import time
from google.colab import drive
from huggingface_hub import login
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' # to be able to encounter less OOM errors

import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer
from sae_lens import SAE, HookedSAETransformer

In [None]:
# Inference only: switch autograd off globally for the rest of the notebook.
torch.set_grad_enabled(False)

# Backend preference order: Apple MPS, then CUDA, then plain CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

In [None]:
# Mount Google Drive: the dataset and the checkpoints live under /content/drive.
drive.mount('/content/drive')

# Authenticate with the Hugging Face Hub (needed to download the gated Gemma weights).
# The token is deliberately NOT written into the notebook, so it cannot leak when the
# notebook is shared or committed. It is read from the HF_TOKEN environment variable;
# when that is unset, `token=None` makes `login` fall back to its interactive prompt.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
# Hugging Face id of the model under study: the instruction-tuned 9B Gemma-2 model from Google.
# Per the author's note, this model is available in Gemma-Scope and Neuronpedia.
MODEL_NAME = "google/gemma-2-9b-it"

In [None]:
# Read the raw dataset and keep only the post text and its label.
df = pd.read_csv('/content/drive/MyDrive/thesis/data/SuicideAndDepression_Detection.csv')[["text", "class"]]

# Quick overview: size, label distribution and character-length statistics.
print(f"Dataset shape: {df.shape} \n\n")
print(f"Classes: {df['class'].value_counts()} \n\n")  # the author reports perfectly balanced classes
print(f"Sample text length: \n\n{df['text'].str.len().describe()}")
print("\n\n")
df.head()


In [None]:
# How much of the dataset survives each candidate character-length cut-off?
for threshold in (1000, 2000, 3000, 4000, 5000):
    pct = 100 * (df['text'].str.len() <= threshold).mean()
    print(f"<= {threshold} chars: {pct:.1f}% of data")

In [None]:
# Very long texts cannot be processed with the available VRAM. According to the author,
# a 4k-character cap still keeps more than 95% of the data.
threshold = 4000
df = df.loc[df['text'].str.len() <= threshold]

# Label counts after the filter.
df["class"].value_counts()

In [None]:
# Draw 20,000 rows from each class with a fixed seed and stack them in this class order.
df = pd.concat(
    [
        df[df["class"] == label].sample(n=20000, random_state=42)
        for label in ("teenagers", "SuicideWatch", "depression")
    ]
)

# Persist the sample as TSV (without the index) so later sessions can start from it.
df.to_csv(
    "/content/drive/MyDrive/thesis/data/suicide_teen_depression_sampled.tsv",
    sep='\t',
    index=False,
)

In [None]:
# Start from the persisted sample rather than the in-memory frame.
df = pd.read_csv(
    "/content/drive/MyDrive/thesis/data/suicide_teen_depression_sampled.tsv",
    sep='\t',
)
df.head(n=2)

In [None]:
# The tokenizer is only used for the chat template; the model does its own tokenisation.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Hooked model wrapper that can run with SAEs attached, placed on the selected device.
model = HookedSAETransformer.from_pretrained(MODEL_NAME, device=device)

In [None]:
# Features are taken from a late layer so that they are more semantically rich; per the
# author's note, layer 31 is the last one available on Neuronpedia.
# The 16k-wide SAE is used because of limited compute, although a wider version exists.
# NOTE(review): the original note says "canonical refers to L1 strength" -- confirm this
# against the Gemma Scope documentation.

LAYER = 31
sae = SAE.from_pretrained(
    device=device,
    sae_id="layer_{}/width_16k/canonical".format(LAYER),
    release="gemma-scope-9b-it-res-canonical",
)

print(f"Loaded SAE for {MODEL_NAME} layer {LAYER}")
print(f"SAE config: d_in={sae.cfg.d_in}, d_sae={sae.cfg.d_sae}")

In [None]:
# Texts are processed one at a time rather than in batches: the model and its SAE together
# hold so many parameters that batching led to more OOM errors.

def _save_sae_checkpoint(
    checkpoint_path,
    cont_activations,
    any_activations,
    last_token_activations,
    binary_sum_activations,
    texts,
    classes,
    last_processed_idx,
):
    """Write everything accumulated so far to `checkpoint_path` via a temporary file."""
    # The temporary name has to end in ".npz"; otherwise numpy appends the extension itself
    # and the rename below would not find the file.
    tmp_path = f"{checkpoint_path}.tmp.npz"
    np.savez_compressed(
        tmp_path,
        cont_activations=np.array(cont_activations),
        any_activations=np.array(any_activations),
        last_token_activations=np.array(last_token_activations),
        binary_sum_activations=np.array(binary_sum_activations),
        texts=np.array(texts, dtype=object),
        classes=np.array(classes, dtype=object),
        last_processed_idx=np.array(last_processed_idx),
    )
    # Only swap the finished file into place: an interruption during the (slow) write can
    # then no longer corrupt the previous checkpoint.
    os.replace(tmp_path, checkpoint_path)


def extract_sae_activations(
    texts,
    classes,
    model,
    sae,
    tokenizer,
    checkpoint_dir,
    checkpoint_prefix,
    checkpoint_interval=100,
    start_idx=0
    ):
    """Run every text through the model with the SAE attached and aggregate the SAE activations.

    Each text is wrapped in the tokenizer's chat template as a single user turn, tokenised
    by the model and run with `sae` attached. The SAE activations are then reduced over
    token positions in four ways (see Returns). Results are checkpointed to
    `<checkpoint_dir>/<checkpoint_prefix>_checkpoint.npz`. If that file already exists, the
    run resumes after the last index recorded in it and `start_idx` is ignored.

    Texts that raise an error are reported and skipped, so the outputs can contain fewer
    rows than `texts`. The returned texts and classes list exactly the rows that were kept.

    Args:
        texts: Sequence of input strings.
        classes: Sequence of labels, aligned with `texts`.
        model: Model exposing `eval`, `to_tokens` and `run_with_cache_with_saes`.
        sae: SAE whose `cfg.metadata.hook_name` identifies the hooked activation.
        tokenizer: Tokenizer providing `apply_chat_template`.
        checkpoint_dir: Directory for the checkpoint file (created if missing).
        checkpoint_prefix: File-name prefix of the checkpoint.
        checkpoint_interval: A checkpoint is written whenever `(index + 1)` is a multiple
            of this value, and once more at the end.
        start_idx: Index to start from when no checkpoint exists.

    Returns:
        Tuple `(processed_texts, processed_classes, cont, any, last_token, binary_sum)`.
        The four arrays hold one row per kept text:
        `cont` is the summed activation per feature, `any` whether the feature fired on any
        position, `last_token` the activations at one fixed position near the end, and
        `binary_sum` the number of positions on which the feature fired.

    Raises:
        ValueError: If `texts` and `classes` differ in length or `checkpoint_interval` < 1.
    """
    if len(texts) != len(classes):
        raise ValueError(
            f"texts and classes must have the same length, got {len(texts)} and {len(classes)}"
        )
    if checkpoint_interval < 1:
        raise ValueError(f"checkpoint_interval must be >= 1, got {checkpoint_interval}")

    model.eval()

    os.makedirs(checkpoint_dir, exist_ok=True)

    all_cont_activations = []
    all_any_activations = []
    all_last_token_activations = []
    all_binary_sum_activations = []
    processed_texts = []
    processed_classes = []

    checkpoint_path = os.path.join(checkpoint_dir, f"{checkpoint_prefix}_checkpoint.npz")
    resume_from = start_idx

    if os.path.exists(checkpoint_path):
        print(f"Found existing checkpoint at {checkpoint_path}")
        # allow_pickle is required for the object arrays holding texts and classes. Pickle
        # can execute arbitrary code, so only load checkpoints written by this function.
        with np.load(checkpoint_path, allow_pickle=True) as checkpoint_data:
            # list(array) keeps each row as a numpy array. The previous `.tolist()` turned
            # rows into lists of Python floats, which inflated memory and changed the dtype
            # of the saved results after a resume.
            all_cont_activations = list(checkpoint_data['cont_activations'])
            all_any_activations = list(checkpoint_data['any_activations'])
            all_last_token_activations = list(checkpoint_data['last_token_activations'])
            all_binary_sum_activations = list(checkpoint_data['binary_sum_activations'])
            processed_texts = checkpoint_data['texts'].tolist()
            processed_classes = checkpoint_data['classes'].tolist()
            resume_from = checkpoint_data['last_processed_idx'].item() + 1

        print(f"Resuming from index {resume_from} (already processed {resume_from} items)")
    else:
        print(f"No checkpoint found. Starting from index {start_idx}")

    # Loop-invariant cache key of the SAE's post-activation hook.
    hook_key = f"{sae.cfg.metadata.hook_name}.hook_sae_acts_post"

    # NOTE(review): presumably the first 4 positions are the leading chat-template tokens
    # and position `end - 6` is the last token of the user text -- confirm against the
    # tokenizer's template. The three windowed aggregations run to the end of the sequence
    # and therefore include the positions after `end - 6`.
    start_pos = 4
    last_token_offset = 6

    failed_indices = []

    for i in tqdm(range(resume_from, len(texts)), desc="Extracting SAE activations", initial=resume_from, total=len(texts)):
        text = texts[i]
        cls = classes[i]
        failed = False

        try:
            # Formatting and tokenisation are inside the try block as well, so that one
            # malformed text is skipped like any other failure instead of ending the run.
            formatted_text = tokenizer.apply_chat_template(
                conversation=[{"role": "user", "content": text}],
                tokenize=False,
                add_generation_prompt=True
            )
            tokens = model.to_tokens(formatted_text, prepend_bos=False)

            with torch.inference_mode():
                _, cache = model.run_with_cache_with_saes(tokens, saes=[sae])
                sae_acts_cpu = cache[hook_key].cpu()
                # Free the device-side cache before aggregating on the CPU copy.
                del cache
                del tokens
                torch.cuda.empty_cache()

                end_pos = sae_acts_cpu.shape[1]
                window = sae_acts_cpu[0, start_pos:end_pos, :]
                fired = window > 0

                # we compute 4 types of aggregations
                cont_activation = window.sum(dim=0).numpy()
                any_token_activation = fired.any(dim=0).numpy()
                last_token_activation = sae_acts_cpu[0, end_pos - last_token_offset, :].numpy()
                binary_sum_activation = fired.sum(dim=0).numpy()

                del window, fired, sae_acts_cpu
        except Exception as e:
            # Deliberate best effort: report the failure and move on to the next text.
            print(f"Error processing index {i}: {e}")
            failed_indices.append(i)
            failed = True
        else:
            all_cont_activations.append(cont_activation)
            all_any_activations.append(any_token_activation)
            all_last_token_activations.append(last_token_activation)
            all_binary_sum_activations.append(binary_sum_activation)
            processed_texts.append(text)
            processed_classes.append(cls)

        # Periodic cleanup, and always after a failure (typically an OOM), which is when
        # releasing memory matters most.
        if failed or i % 10 == 0:
            gc.collect()
            torch.cuda.empty_cache()

        # Checkpoint on schedule even when the current item failed.
        if (i + 1) % checkpoint_interval == 0:
            _save_sae_checkpoint(
                checkpoint_path,
                all_cont_activations,
                all_any_activations,
                all_last_token_activations,
                all_binary_sum_activations,
                processed_texts,
                processed_classes,
                i,
            )

    if failed_indices:
        print(f"Skipped {len(failed_indices)} item(s) that raised errors: {failed_indices}")

    # Final checkpoint marks the whole input as processed.
    _save_sae_checkpoint(
        checkpoint_path,
        all_cont_activations,
        all_any_activations,
        all_last_token_activations,
        all_binary_sum_activations,
        processed_texts,
        processed_classes,
        len(texts) - 1,
    )

    return (
        processed_texts,
        processed_classes,
        np.array(all_cont_activations),
        np.array(all_any_activations),
        np.array(all_last_token_activations),
        np.array(all_binary_sum_activations)
    )

In [None]:
# Rows of the sampled dataset labelled "depression".
depr = df.loc[df["class"] == "depression"]
depr.shape

In [None]:
# Positional row window of the depression subset handled in this run (end is exclusive).
start = 15000
end = 20000
data_chunk = depr.iloc[slice(start, end)]
data_chunk.shape

In [None]:
# The model id contains "/", which must not end up in a file name.
MODEL_NAME_CLEAN = "_".join(MODEL_NAME.split("/"))

# Extract activations for the current chunk. The checkpoint name encodes class, model,
# layer and row window, so every chunk gets its own checkpoint file.
result = extract_sae_activations(
    texts=data_chunk['text'].tolist(),
    classes=data_chunk['class'].tolist(),
    model=model,
    sae=sae,
    tokenizer=tokenizer,
    checkpoint_dir='/content/drive/MyDrive/thesis/checkpoints/',
    checkpoint_prefix=f'depression_{MODEL_NAME_CLEAN}_{LAYER}_{start}_{end}',
    checkpoint_interval=300,
)