In [1]:
import json
import logging
import os
import sys
import time
import datetime
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from textwrap import dedent
from collections import defaultdict

import numpy as np
import pandas as pd
import torch as t
from tqdm.auto import tqdm
from IPython.display import display, HTML
from tabulate import tabulate
import matplotlib.pyplot as plt
import seaborn as sns

# Disable gradients for inference
t.set_grad_enabled(False)

# Reduce logging noise
logging.disable(sys.maxsize)

# Device selection
if t.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if t.cuda.is_available() else "cpu"

print(f"Torch Device: {device}")
print(f"PyTorch Version: {t.__version__}")

Torch Device: cuda
PyTorch Version: 2.10.0+cu130


In [5]:
# ═══════════════════════════════════════════════════════════════════
# CONFIGURATION — adjust these parameters for your experiments
# ═══════════════════════════════════════════════════════════════════

# Target model (local, runs through NNSight)
TARGET_MODEL_NAME = "google/gemma-3-4b-it"

# Attacker / Judge model (remote, via OpenAI API)
ATTACKER_MODEL = "gpt-4o"

# NNSIGHT (False if using a local model)
REMOTE = False

print("Configuration:")
print(f"  Target model:  {TARGET_MODEL_NAME}")
print(f"  Attacker:      {ATTACKER_MODEL}")


Configuration:
  Target model:  google/gemma-3-4b-it
  Attacker:      gpt-4o


In [6]:
from nnsight import LanguageModel

print(f"Loading target model: {TARGET_MODEL_NAME}")
model = LanguageModel(
    TARGET_MODEL_NAME,
    device_map="auto",
    torch_dtype=t.bfloat16,
    # Use PyTorch's Scaled Dot-Product Attention (SDPA) for faster inference.
    # Mathematically identical to default eager attention — same outputs, same hidden states,
    # same SAE activations — just uses fused CUDA kernels for ~20-40% speedup.
    # Built into PyTorch 2.0+, no extra install needed (unlike flash-attn which requires nvcc).
    attn_implementation="sdpa",
)
tokenizer = model.tokenizer

# Set padding for batched generation
tokenizer.padding_side = "left"
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"  Architecture: {model.config.architectures}")
print(f"  Attention:    sdpa (PyTorch native)")
print(f"  Loaded successfully")

Loading target model: google/gemma-3-4b-it


  Architecture: ['Gemma3ForConditionalGeneration']
  Attention:    sdpa (PyTorch native)
  Loaded successfully


In [7]:
# ═══════════════════════════════════════════════════════════════════
# Speed test — NNSight vs vanilla HuggingFace generation throughput
# ═══════════════════════════════════════════════════════════════════

test_prompts = [
    [{"role": "user", "content": "What is the capital of France?"}],
    [{"role": "user", "content": "Explain photosynthesis in simple terms."}],
    [{"role": "user", "content": "Write a short poem about the ocean."}],
]
max_new = 200

def bench_nnsight(prompts, max_new_tokens):
    """Benchmark NNSight-wrapped generation (what the pipeline uses)."""
    total_tokens = 0
    total_time = 0.0
    for msgs in prompts:
        prompt_text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
        input_len = len(tokenizer.encode(prompt_text))

        t0 = time.perf_counter()
        with model.generate(prompt_text, max_new_tokens=max_new_tokens, repetition_penalty=1.2, remote=REMOTE) as gen:
            tokens = model.generator.output.save()
        elapsed = time.perf_counter() - t0

        new_tokens = tokens.shape[-1] - input_len
        total_tokens += new_tokens
        total_time += elapsed
    return total_tokens, total_time

def bench_huggingface(prompts, max_new_tokens):
    """Benchmark vanilla HuggingFace model.generate() (no NNSight wrapper)."""
    # Access the underlying HF model directly
    hf_model = model._model
    total_tokens = 0
    total_time = 0.0
    for msgs in prompts:
        prompt_text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
        input_ids = tokenizer(prompt_text, return_tensors="pt").input_ids.to(device)
        input_len = input_ids.shape[-1]

        t0 = time.perf_counter()
        with t.no_grad():
            output_ids = hf_model.generate(input_ids, max_new_tokens=max_new_tokens, repetition_penalty=1.2)
        elapsed = time.perf_counter() - t0

        new_tokens = output_ids.shape[-1] - input_len
        total_tokens += new_tokens
        total_time += elapsed
    return total_tokens, total_time

# Warmup run (first call is always slower due to CUDA graph compilation)
print("Warming up...")
_ = bench_nnsight(test_prompts[:1], 10)
_ = bench_huggingface(test_prompts[:1], 10)

# Benchmark NNSight
print(f"\nSpeed test ({len(test_prompts)} prompts, max_new_tokens={max_new})")
print("=" * 60)

nn_tokens, nn_time = bench_nnsight(test_prompts, max_new)
nn_tps = nn_tokens / nn_time if nn_time > 0 else 0
print(f"  NNSight:      {nn_tps:6.1f} tok/s  ({nn_tokens} tokens in {nn_time:.2f}s)")

hf_tokens, hf_time = bench_huggingface(test_prompts, max_new)
hf_tps = hf_tokens / hf_time if hf_time > 0 else 0
print(f"  HuggingFace:  {hf_tps:6.1f} tok/s  ({hf_tokens} tokens in {hf_time:.2f}s)")

print("=" * 60)
overhead_pct = 100 * (nn_time - hf_time) / hf_time if hf_time > 0 else 0
print(f"  NNSight overhead: {overhead_pct:+.1f}% vs vanilla HuggingFace")
print(f"  (Expected: NNSight adds tracing layer for interpretability access)")

Warming up...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]


Speed test (3 prompts, max_new_tokens=200)
  NNSight:        25.2 tok/s  (313 tokens in 12.43s)
  HuggingFace:    27.2 tok/s  (308 tokens in 11.33s)
  NNSight overhead: +9.7% vs vanilla HuggingFace
  (Expected: NNSight adds tracing layer for interpretability access)
