In [1]:
#%%
# %load_ext autoreload
# %autoreload 2

In [2]:
#%%
# Import libraries
from dataclasses import dataclass
import torch
import os
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)
from trl import ORPOConfig, ORPOTrainer, setup_chat_format
from sae_lens import HookedSAETransformer, SAE

In [3]:
#%%
# Import our custom ISAERFT components
try:
    # When imported as a module
    from model_components.IsaerftConfig import IsaerftConfig
    from model_components.IsaerftPeft import IsaerftPeft
except ImportError:
    # When run directly as a script or from a notebook kernel
    import sys
    import os
    # `__file__` is not defined inside a notebook / interactive kernel, so fall
    # back to the current working directory there instead of raising NameError.
    try:
        _this_dir = os.path.dirname(os.path.abspath(__file__))
    except NameError:
        _this_dir = os.getcwd()
    # Add the parent directory to the path (only once, so that re-running this
    # cell does not keep growing sys.path).
    _parent_dir = os.path.dirname(_this_dir)
    if _parent_dir not in sys.path:
        sys.path.append(_parent_dir)
    from model_components.IsaerftConfig import IsaerftConfig
    from model_components.IsaerftPeft import IsaerftPeft

In [4]:
#%%
import os
# Uncomment to restrict this process to GPU 0 (must be set before CUDA is initialised).
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Use only GPU 0
import torch
# Print the number of CUDA devices torch reports.
print(torch.cuda.device_count())  # Expected to be 1 only when the restriction above is enabled (or the host has one GPU)

In [5]:
#%%
# Authenticate to Hugging Face
# The token is read from the HUGGINGFACE_WRITE_KEY environment variable;
# if that variable is not set, the lookup raises KeyError before login is attempted.
from huggingface_hub import login

login(token=os.environ['HUGGINGFACE_WRITE_KEY'])

In [6]:
#%%
# Load dataset
# No `split` argument is passed, so whatever `load_dataset` returns by default
# for "trl-lib/ultrafeedback_binarized" is bound to `dataset`.
dataset = load_dataset(path="trl-lib/ultrafeedback_binarized")

In [7]:
#%%
# Name of the pretrained checkpoint that will be fine-tuned
model_name = "EleutherAI/pythia-70m-deduped"

# Choose the compute device in order of preference: CUDA, then MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda:0"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
# The rest of the notebook is only expected to run on a CUDA device.
assert 'cuda' in device

In [8]:
#%%
# Model to fine-tune
# Load the checkpoint named by `model_name` as a HookedSAETransformer and move it to `device`.
model = HookedSAETransformer.from_pretrained(
    model_name,
    # Disabled loader options kept for reference:
    # torch_dtype=torch.float32,
    # device_map=device
).to(device)
# model.config.use_cache = False
# tokenizer = AutoTokenizer.from_pretrained(model_name)

# Tokenizer loaded from the same checkpoint name as the model.
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [9]:
#%%
# non_hooked_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b").to('cpu')

In [10]:
#%%
# del non_hooked_model
# chat = [
#     { "role": "user", "content": "Write a hello world program" },
# ]
# prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
# print(prompt)
# #%%
# _, tokenizer = setup_chat_format(non_hooked_model, tokenizer) # TODO: Figure out if this is a problem that I'm not applying the chat format to the model

In [11]:
#%%
# Build the ISAERFT configuration (the adapter itself is applied in a later cell)
# NOTE(review): `SAE` is already imported in the notebook's main import cell and is not used here.
from sae_lens import SAE

# (release, sae_id) pair passed as the single entry of `target_hooks`.
# Presumably these identify a pretrained SAE release and the hook point it attaches to — TODO confirm.
release = "pythia-70m-deduped-res-sm"
sae_id = "blocks.4.hook_resid_post"
isaerft_config = IsaerftConfig(
    target_hooks=[
        (release, sae_id),
    ],
    depth=-1  # Bias-only for simplicity
)

In [12]:
#%%
# Apply the ISAERFT adapter: wrap the loaded model with IsaerftPeft and rebind `model` to the wrapper.
# NOTE(review): because `model` is rebound to its own wrapped version, re-running this cell
# passes the wrapper back into IsaerftPeft — confirm that IsaerftPeft tolerates this.
model = IsaerftPeft(model, isaerft_config)

In [13]:
# %%
# Enable IPython's autoreload extension so edits to imported modules are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [14]:
#%%
# Apply the ISAERFT adapter
# `model` is rebound to its own wrapped version, so running this cell on an
# already-wrapped model would wrap the wrapper. Only apply the adapter when
# `model` is not yet an IsaerftPeft instance, which keeps the cell idempotent.
if not isinstance(model, IsaerftPeft):
    model = IsaerftPeft(model, isaerft_config)

In [15]:
#%%
# Apply the ISAERFT adapter
# Guard against double-wrapping: this cell rebinds `model` to IsaerftPeft(model, ...),
# so it must be a no-op when `model` has already been wrapped.
if not isinstance(model, IsaerftPeft):
    model = IsaerftPeft(model, isaerft_config)

In [16]:
#%%
# Display the type of `model` after the adapter has been applied.
type(model)

model_components.IsaerftPeft.IsaerftPeft

In [17]:
#%%
# Display the type of the object exposed by the wrapper's `.model` attribute.
type(model.model)

sae_lens.analysis.hooked_sae_transformer.HookedSAETransformer

In [18]:
# NOTE(review): the recorded output of this cell is a bound-method repr, i.e. `saes`
# is a method that is displayed here rather than called — confirm whether a call was intended.
model.model.saes

<bound method HookedSAETransformer.saes of HookedSAETransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-3): 4 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): 

In [19]:
#%%
# Check device placement of model components
def check_device():
    """Sanity-check device placement and run two smoke-test forward passes.

    Reads the notebook-level globals ``model``, ``tokenizer`` and ``device``.

    Steps:
      1. Assert that the first parameter of ``model.base_model`` lives on
         ``device`` and that every parameter reported by
         ``model.named_parameters()`` is on that same device.
      2. Run a forward pass on the token ids of a plain string.
      3. Render a one-message chat through the tokenizer's chat template,
         tokenize it, run a second forward pass, and assert that the result is
         on the model's device.

    The forward pass may return either a bare logits tensor or an object that
    exposes a ``.logits`` attribute; both are accepted.

    Returns:
        None. Progress and shapes are printed.

    Raises:
        AssertionError: if any parameter or the final output is on an
            unexpected device.
    """

    def _as_logits(result):
        # Accept both a bare logits tensor and an object exposing `.logits`.
        return getattr(result, "logits", result)

    print("Checking device placement of model components...")
    model_device = next(model.base_model.parameters()).device
    print(f"{torch.device(device)=}")
    print(f"{model_device=}")
    assert model_device == torch.device(device)
    # Check base model
    print(f"Base model device: {model_device}")

    # Every parameter (including any added by the adapter) must share the device.
    assert all(param.device == model_device for _, param in model.named_parameters())
    text = "Hello, world!"
    input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
    print(f"Input shape: {input_ids.shape}")

    # Forward pass on raw token ids
    output = model(input_ids)
    print(f"Output shape: {_as_logits(output).shape}")
    # apply_chat_template expects a list of {"role", "content"} messages, not a
    # bare string, so wrap the text in a single user message.
    print(
        tokenizer.apply_chat_template(
            [{"role": "user", "content": text}],
            tokenize=False,
            add_generation_prompt=True,
        )
    )
    # Run an example through the model to verify forward pass
    print("Testing model forward pass...")

    # Create a simple test input
    test_chat = [
        {"role": "user", "content": "Write a hello world program"}
    ]
    test_prompt = tokenizer.apply_chat_template(test_chat, tokenize=False, add_generation_prompt=True)
    print(test_prompt)
    # Tokenize input
    inputs = tokenizer(test_prompt, return_tensors="pt").to(device)
    print(inputs)
    # Run forward pass with the tokenizer's outputs passed as keyword arguments
    outputs = model(**inputs)
    print(outputs)
    print("Forward pass successful!")
    logits = _as_logits(outputs)
    print(f"Output shape: {logits.shape}")
    print(f"Output device: {logits.device}")
    assert logits.device == model_device, "Output device doesn't match model device"

In [20]:
#%%

# model, tokenizer = setup_chat_format(model, tokenizer)

In [21]:
#%%
# Set our name for the finetune to be saved &/ uploaded to
# `finetune_name` is later passed as `hub_model_id` in the ORPO config.
finetune_name = "PYTHIA-FT-ORPO-ISAERFT"
# `finetune_tags` is later passed as the `tags` of the wandb run.
finetune_tags = ["smol-course", "module_1", "isaerft"]

In [22]:
#%%
# Configure ORPO training (this cell only builds the config; no training is started here)
orpo_args = ORPOConfig(
    # Small learning rate to prevent catastrophic forgetting
    learning_rate=8e-6,
    # Linear learning rate decay over training
    lr_scheduler_type="linear",
    # Maximum combined length of prompt + completion
    max_length=1024,
    # Maximum length for input prompts
    max_prompt_length=512,
    # Controls weight of the odds ratio loss (λ in paper)
    beta=0.1,
    # Batch size for training
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    # Helps with training stability by accumulating gradients before updating
    gradient_accumulation_steps=8,
    # Memory-efficient optimizer for CUDA, falls back to adamw_torch for CPU/MPS
    optim="paged_adamw_8bit" if ("cuda" in device)else "adamw_torch",
    # When to run evaluation
    eval_strategy="steps",
    # Evaluate every 20% of training
    eval_steps=0.2,
    # Log metrics every step
    logging_steps=1,
    # Gradual learning rate warmup
    warmup_steps=10,
    # Send training logs to Weights & Biases
    report_to="wandb",
    # Where to save model/checkpoints
    output_dir="./results/",
    # Enable MPS (Metal Performance Shaders) only when the selected device is "mps"
    use_mps_device=device == "mps",
    # Hub repository id, taken from the finetune name defined earlier
    hub_model_id=finetune_name,
    # Training for a shorter time for this example: 1/4 * 0.25 = 0.0625 of an epoch
    num_train_epochs=(1/4*.25),
)

In [23]:
#%%
# Initialize wandb
import wandb
from datetime import datetime
import uuid
# Close any run left open by a previous execution of this cell.
wandb.finish()
# The API key is read from the WANDB_KEY environment variable (KeyError if unset).
wandb.login(key=os.environ['WANDB_KEY'])

# Start a run whose name combines a minute-resolution timestamp with a short
# random suffix. The project name matches the Pythia-70M model being
# fine-tuned in this notebook (it previously pointed at a Gemma project).
wandb.init(
    project="pythia70m-orpo-isaerft",
    name=f"run-{datetime.now().strftime('%Y%m%d-%H%M')}-{uuid.uuid4().hex[:6]}",
    tags=finetune_tags
)

<wandb.sdk.wandb_run.Run at 0x7fec5f981a50>

In [24]:
#%%
# Initialize wandb
import wandb
from datetime import datetime
import uuid
# Finish whichever run is currently active before starting a new one.
wandb.finish()
# The API key is read from the WANDB_KEY environment variable (KeyError if unset).
wandb.login(key=os.environ['WANDB_KEY'])

# Start a new run in the "pythia70m-orpo-isaerft" project. The run name combines a
# minute-resolution timestamp with a 6-character random hex suffix.
wandb.init(
    project="pythia70m-orpo-isaerft",
    name=f"run-{datetime.now().strftime('%Y%m%d-%H%M')}-{uuid.uuid4().hex[:6]}",
    tags=finetune_tags
)