# Part 1: Setup & Data Preparation

This notebook:
1. **Loads models**: Gemma-2B-IT, the SAE, and a toxicity classifier
2. **Captures activations**: Runs prompts through Gemma, hooks layer 12 residuals
3. **Encodes with SAE**: Converts residuals to sparse feature codes
4. **Generates labels**: Hallucination labels (NQ-Open) and toxicity labels (RTP + HH)

**Output**: Labeled datasets with SAE codes saved to `data/processed/`

## Setup: Load Models and Config

In [1]:
import sys, pathlib, json
PROJECT_ROOT = pathlib.Path('..').resolve()
SRC_PATH = PROJECT_ROOT / 'src'
if str(SRC_PATH) not in sys.path:
    sys.path.append(str(SRC_PATH))

from typing import Dict
from datasets import load_dataset
import torch
import pandas as pd
from tqdm import tqdm

from config import CONFIG
from gemma_interface import GemmaInterface
from sae_wrapper import SparseAutoencoder
from toxicity_wrapper import ToxicityWrapper
from utils_io import ensure_dir, save_table, save_json

print('Project root:', PROJECT_ROOT)
print('Python:', sys.version)

Project root: /Users/cth/Desktop/ECE685-final-project
Python: 3.12.6 | packaged by conda-forge | (main, Sep 30 2024, 17:55:20) [Clang 17.0.6 ]


In [2]:
# Load Gemma-2B-IT and SAE
print('Loading Gemma-2B-IT...')
gemma = GemmaInterface(CONFIG.model.gemma_model_name)
hidden_size = gemma.model.config.hidden_size
print(f'✓ Gemma loaded. Hidden size: {hidden_size}')

print('\nLoading SAE...')
sae = SparseAutoencoder.load(hidden_size=hidden_size)
print(f'✓ SAE loaded. Shape: {sae.encoder_weight.shape}')

# Load toxicity classifier
print('\nLoading toxicity classifier...')
tox = ToxicityWrapper(CONFIG.model.toxicity_model_name)
print('✓ Toxicity classifier loaded')

Loading Gemma-2B-IT...


`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

✓ Gemma loaded. Hidden size: 2048

Loading SAE...
⚠️  sae-lens not installed.
    Run: pip install sae-lens
    Falling back to random SAE (FOR TESTING ONLY)
    Creating random SAE: 2048 → 16384 features
    ⚠️  THIS IS A RANDOM SAE - NOT USEFUL FOR REAL EXPERIMENTS!
✓ SAE loaded. Shape: torch.Size([16384, 2048])

Loading toxicity classifier...
✓ Toxicity classifier loaded


## Data Capture: Run Gemma + Encode Activations

For each dataset:
1. Load prompts from Hugging Face
2. Run through Gemma (captures layer 12 residuals)
3. Encode residuals with SAE → sparse codes
4. Save codes to parquet
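Step 3 can be illustrated with a minimal NumPy sketch of a standard ReLU sparse-autoencoder encoder. This is an assumption about the general technique, not the project's `sae_wrapper` implementation, which may use a different activation or normalization. The names `sae_encode`, `w_enc`, `b_enc`, and `b_dec` are hypothetical, and the sizes are toy values standing in for the 2048 → 16384 mapping reported above.

```python
import numpy as np

def sae_encode(residual, w_enc, b_enc, b_dec):
    """Standard ReLU SAE encoder: relu((x - b_dec) @ W_enc^T + b_enc)."""
    pre_activation = (residual - b_dec) @ w_enc.T + b_enc
    return np.maximum(pre_activation, 0.0)

rng = np.random.default_rng(0)
hidden_size, n_features = 8, 32  # toy sizes; the notebook reports 2048 and 16384

# Encoder weight laid out as (n_features, hidden_size), matching the shape printed above
w_enc = rng.normal(size=(n_features, hidden_size))
b_enc = np.full(n_features, -1.0)  # a negative bias pushes more features to exactly zero
b_dec = np.zeros(hidden_size)

residual = rng.normal(size=(1, hidden_size))  # one captured residual vector
codes = sae_encode(residual, w_enc, b_enc, b_dec)
print(codes.shape, int((codes > 0).sum()), "active features")
```

The ReLU is what makes the codes sparse: every feature whose pre-activation is negative is stored as an exact zero.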

In [3]:
DATASETS: Dict[str, Dict] = {
    'nq_open': {'hf_name': 'nq_open', 'split': 'validation', 'text_field': 'question'},
    'real_toxicity_prompts': {'hf_name': 'allenai/real-toxicity-prompts', 'split': 'train', 'text_field': 'prompt'},
    'anthropic_hh': {'hf_name': 'Anthropic/hh-rlhf', 'split': 'test', 'text_field': 'chosen'}
}

def load_prompts(name: str):
    """Load dataset from Hugging Face"""
    info = DATASETS[name]
    ds = load_dataset(info['hf_name'], split=info['split'])
    return ds

def extract_text(row, name: str):
    """Extract prompt text from dataset row"""
    info = DATASETS[name]
    field = info['text_field']
    if field == 'prompt' and isinstance(row[field], dict):
        return row[field]['text']
    return row[field]

def encode_prompt(prompt: str) -> torch.Tensor:
    """Run Gemma, capture residual, encode with SAE"""
    result = gemma.generate(prompt, max_new_tokens=8)
    residual = result['residual']
    code = sae.encode(residual)
    return code.squeeze(0).cpu()

def process_dataset(name: str, limit: int | None = None):
    """Capture and encode activations for entire dataset"""
    ds = load_prompts(name)
    records = []
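    # Cap the row count at the dataset size so the progress bar total stays accurate
    n_rows = len(ds) if limit is None else min(limit, len(ds))
    for i, row in tqdm(enumerate(ds), desc=f'Processing {name}', total=n_rows):
        if i >= n_rows:
            break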
        prompt = extract_text(row, name)
        code = encode_prompt(prompt)
        records.append({
            'id': i,
            'prompt': prompt,
            'codes': code.numpy()
        })
    out_path = CONFIG.data.processed_dir / f'{name}_codes.parquet'
    save_table(out_path, records)
    print(f'✓ Saved {len(records)} records to {out_path}')
    return out_path

In [4]:
# Process all datasets (use limit for testing, None for full run)
LIMIT = 500  # Set to None for full dataset

print('\n=== Capturing Activations ====')
for name in DATASETS:
    process_dataset(name, limit=LIMIT)


=== Capturing Activations ====


Processing nq_open: 100%|██████████| 500/500 [04:19<00:00,  1.93it/s]



✓ Saved 500 records to /Users/cth/Desktop/ECE685-final-project/data/processed/nq_open_codes.parquet


Processing real_toxicity_prompts:  10%|▉         | 48/500 [00:29<04:40,  1.61it/s]



KeyboardInterrupt: 

## Label Generation

### Hallucination Labels (NQ-Open)
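Generate answers and flag likely hallucinations. The cell below uses a placeholder length heuristic; comparison against the NQ-Open reference answers is not yet implemented.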

In [None]:
nq_codes = pd.read_parquet(CONFIG.data.processed_dir / 'nq_open_codes.parquet')

def label_nq(row) -> dict:
    """Generate answer and detect hallucination"""
    prompt = row['prompt']
    result = gemma.generate(prompt, max_new_tokens=50)
    model_answer = result['text'].replace(prompt, '').strip()
    
    # Placeholder heuristic: flags answers shorter than two words.
    # A real check would compare model_answer with the NQ-Open reference answers.
    is_hallucinated = len(model_answer.split()) < 2
    
    return {
        'id': row['id'],
        'prompt': prompt,
        'model_answer': model_answer,
        'label': int(is_hallucinated),
        'codes': row['codes']
    }

print('\n=== Labeling NQ-Open (Hallucination) ====')
nq_labels = [label_nq(row) for _, row in tqdm(nq_codes.iterrows(), total=len(nq_codes))]
nq_df = pd.DataFrame(nq_labels)
out_path = CONFIG.data.processed_dir / 'nq_open_labeled.parquet'
nq_df.to_parquet(out_path)
print(f'✓ Saved {len(nq_df)} labeled NQ records ({nq_df["label"].sum()} hallucinated)')


=== Labeling NQ-Open (Hallucination) ====


100%|██████████| 100/100 [03:21<00:00,  2.02s/it]

✓ Saved 100 labeled NQ records (1 hallucinated)
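The labeling cell above flags an answer only when it is shorter than two words. A reference-based check could look like the sketch below, assuming the reference answers (the `answer` list in the Hugging Face `nq_open` dataset) are carried alongside each prompt, which the codes parquet written earlier does not do. The helpers `normalize` and `is_hallucinated` are hypothetical and not part of the project code.

```python
import re
import string

def normalize(text: str) -> str:
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

def is_hallucinated(model_answer: str, references: list[str]) -> bool:
    """Flag the answer when no normalized reference appears in it as whole words."""
    answer = f" {normalize(model_answer)} "
    refs = [normalize(ref) for ref in references]
    return not any(ref and f" {ref} " in answer for ref in refs)

print(is_hallucinated("The capital of France is Paris.", ["Paris"]))
print(is_hallucinated("It is Lyon.", ["Paris"]))
```

Whole-word containment is a deliberately loose criterion: it accepts verbose answers that mention a reference, but it also accepts answers that mention the reference while asserting something wrong, so the resulting labels remain approximate.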





### Safety Labels (RTP + HH)
Generate completions and score toxicity

In [None]:
rtp_codes = pd.read_parquet(CONFIG.data.processed_dir / 'real_toxicity_prompts_codes.parquet')
hh_codes = pd.read_parquet(CONFIG.data.processed_dir / 'anthropic_hh_codes.parquet')

def label_toxicity(row) -> dict:
    """Generate completion and score toxicity"""
    prompt = row['prompt']
    result = gemma.generate(prompt, max_new_tokens=50)
    completion = result['text'].replace(prompt, '').strip()
    score = tox.score(completion)
    
    return {
        'id': row['id'],
        'prompt': prompt,
        'completion': completion,
        'toxicity_probability': score.probability,
        'label': score.label,
        'codes': row['codes']
    }

print('\n=== Labeling Safety Datasets ====')
rtp_labels = [label_toxicity(row) for _, row in tqdm(rtp_codes.iterrows(), total=len(rtp_codes), desc='RTP')]
hh_labels = [label_toxicity(row) for _, row in tqdm(hh_codes.iterrows(), total=len(hh_codes), desc='HH')]

pd.DataFrame(rtp_labels).to_parquet(CONFIG.data.processed_dir / 'rtp_labeled.parquet')
pd.DataFrame(hh_labels).to_parquet(CONFIG.data.processed_dir / 'hh_labeled.parquet')

rtp_toxic = sum(1 for r in rtp_labels if r['label'] == 1)
hh_toxic = sum(1 for h in hh_labels if h['label'] == 1)
print(f'✓ Saved {len(rtp_labels)} RTP ({rtp_toxic} toxic) and {len(hh_labels)} HH ({hh_toxic} toxic) records')


=== Labeling Safety Datasets ====


RTP: 100%|██████████| 100/100 [05:30<00:00,  3.30s/it]
HH: 100%|██████████| 100/100 [03:46<00:00,  2.27s/it]



✓ Saved 100 RTP (0 toxic) and 100 HH (0 toxic) records
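The labeling cell reads `score.probability` and `score.label` from the toxicity wrapper. How the wrapper derives the binary label is not shown in this notebook; a common approach is a fixed probability threshold, sketched below. The `ToxicityScore` dataclass, the `to_score` helper, and the 0.5 default are assumptions for illustration, not the actual `ToxicityWrapper` behavior.

```python
from dataclasses import dataclass

@dataclass
class ToxicityScore:
    probability: float  # classifier's toxicity probability in [0, 1]
    label: int          # 1 = toxic, 0 = non-toxic

def to_score(probability: float, threshold: float = 0.5) -> ToxicityScore:
    """Turn a probability into a binary label using a fixed threshold (assumed)."""
    return ToxicityScore(probability=probability, label=int(probability >= threshold))

# Illustrative probabilities, not measured values
probabilities = [0.02, 0.31, 0.48, 0.77]
for threshold in (0.5, 0.25):
    n_toxic = sum(to_score(p, threshold).label for p in probabilities)
    print(f"threshold={threshold}: {n_toxic} of {len(probabilities)} labeled toxic")
```

When a run yields very few positive labels, inspecting the distribution of `toxicity_probability` before settling on a threshold helps tell apart a genuinely benign sample from a threshold that is too strict.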


## Summary

**Outputs created:**
- `data/processed/nq_open_labeled.parquet` - Questions with hallucination labels
- `data/processed/rtp_labeled.parquet` - RealToxicityPrompts prompts with toxicity labels
- `data/processed/hh_labeled.parquet` - Anthropic HH texts with toxicity labels

Each file contains:
- `prompt`: Original text
- `codes`: SAE sparse codes (16,384 features)
- `label`: Binary label (1 = hallucinated or toxic)
- Model outputs (answers/completions)
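As a rough sketch of how a downstream step might consume these columns, the per-row `codes` arrays can be stacked into a feature matrix with `label` as the target vector. The records below are toy stand-ins with four features rather than 16,384; with a loaded DataFrame `df`, the equivalent would be `np.stack(df['codes'].to_numpy())`.

```python
import numpy as np

# Toy stand-ins for rows of a labeled parquet file (illustrative values only)
records = [
    {"prompt": "q1", "codes": np.array([0.0, 1.5, 0.0, 0.2]), "label": 0},
    {"prompt": "q2", "codes": np.array([0.7, 0.0, 0.0, 0.0]), "label": 1},
    {"prompt": "q3", "codes": np.array([0.0, 0.0, 0.0, 0.0]), "label": 0},
]

X = np.stack([r["codes"] for r in records])  # shape: (n_examples, n_features)
y = np.array([r["label"] for r in records])  # shape: (n_examples,)

# Fraction of active (non-zero) features per example, a quick sparsity check
density = (X > 0).mean(axis=1)
print(X.shape, y.tolist(), density.tolist())
```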

**Next**: Run notebook 02 to discover features and build detectors