# Lesson 5: Training a Reward Model on Mac M3 (Hugging Face)

Yes, you can run Hugging Face training on your Apple M3!
We use **PyTorch MPS (Metal Performance Shaders)** to accelerate training on your Mac's GPU.

**Goal**: Train a small "Reward Model" that can judge if an answer is correct.
This replaces the "Rule-Based Verifier" from Lesson 3 with a learned neural network.

### Prerequisites
*   `pip install trl peft`
*   **Note**: Training requires significant RAM. We will use a tiny model (`Qwen/Qwen2.5-0.5B-Instruct`) to ensure it fits comfortably.
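Argument names in TRL (and the Trainer APIs it builds on) have changed across releases, so when an example like this one fails, the installed versions are a useful first thing to check. Here is a minimal sketch that uses only the standard library's `importlib.metadata`. The package list is our assumption about what this lesson needs.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("torch", "transformers", "datasets", "trl", "peft")):
    """Map each package name to its installed version string, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for pkg, ver in installed_versions().items():
    print(f"{pkg:>12}: {ver or 'NOT INSTALLED'}")
```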

In [None]:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer, RewardConfig
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model

# Use the Apple Silicon GPU (MPS) when available, otherwise fall back to the CPU
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Training on device: {device}")

## 1. Prepare Data
To train a Reward Model, we need **pairs** of responses to the same prompt: one *chosen* (preferred) and one *rejected*.
Real projects use a preference dataset such as `Anthropic/hh-rlhf`. To keep this lesson fast, we build a tiny synthetic dataset instead.
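Pairs are required because a reward model is usually trained with a pairwise (Bradley-Terry) objective. For each pair it minimizes `-log(sigmoid(r_chosen - r_rejected))`, which only pushes the chosen score above the rejected one. TRL's `RewardTrainer` minimizes this loss, averaged over the batch. Below is a pure-Python sketch of the objective. The function names are ours, not TRL's.

```python
import math


def softplus(x: float) -> float:
    """log(1 + e^x), computed without overflow for large |x|."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def pairwise_loss(r_chosen: float, r_rejected: float) -> float:
    """Bradley-Terry loss for one pair: -log(sigmoid(r_chosen - r_rejected))."""
    # -log(sigmoid(m)) == softplus(-m)
    return softplus(-(r_chosen - r_rejected))


def batch_loss(chosen_scores, rejected_scores) -> float:
    """Mean pairwise loss over a batch of (chosen, rejected) score pairs."""
    losses = [pairwise_loss(c, r) for c, r in zip(chosen_scores, rejected_scores)]
    return sum(losses) / len(losses)


print(pairwise_loss(0.0, 0.0))   # equal scores: log(2) ≈ 0.6931
print(pairwise_loss(3.0, -1.0))  # chosen ranked well above rejected: ≈ 0.0181
print(pairwise_loss(-1.0, 3.0))  # ranked the wrong way round: ≈ 4.0181
```

An untrained model that scores both responses equally gets a loss of about 0.693 (log 2). This makes it a useful baseline when reading training logs.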

In [None]:
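# A tiny synthetic preference dataset for demonstration.
# We fold each prompt into both responses, giving the "implicit prompt" format
# (only "chosen" and "rejected" text columns) that TRL's RewardTrainer documents;
# this way the model is guaranteed to see the prompt during training.
pairs = [
    {"prompt": "2+2=", "chosen": "The answer is 4.", "rejected": "The answer is 5."},
    {"prompt": "Is 91 prime?", "chosen": "No, 91 is 7*13.", "rejected": "Yes, 91 is prime."},
    {"prompt": "Capital of France?", "chosen": "Paris", "rejected": "London"},
]
data = [
    {"chosen": f"{p['prompt']} {p['chosen']}", "rejected": f"{p['prompt']} {p['rejected']}"}
    for p in pairs
] * 10  # 30 examples, so training runs for several optimizer steps

from datasets import Dataset
dataset = Dataset.from_list(data)
print(f"Dataset prepared: {len(dataset)} examples.")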

## 2. Load Model (with LoRA)
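We use **LoRA (Low-Rank Adaptation)**. Instead of updating all of the model's weights, we freeze them and train small low-rank adapter matrices added to selected layers. This needs far less memory and compute.
We use `Qwen/Qwen2.5-0.5B-Instruct` because, at about 0.5B parameters, it is small enough to train quickly on a Mac.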
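Concretely, LoRA freezes a pretrained weight matrix `W` (shape `d_out × d_in`) and learns a low-rank update `B @ A`. Here `A` has shape `r × d_in` and `B` has shape `d_out × r`, and the update is scaled by `lora_alpha / r` (32 / 8 = 4 for the config used in this lesson). Each adapted layer then trains `r * (d_in + d_out)` parameters instead of `d_in * d_out`. `B` starts at zero, so training begins from the unchanged base model. Here is a toy pure-Python sketch. The layer sizes are made up for illustration.

```python
import random


def matvec(M, v):
    """Multiply a matrix (list of rows) by a vector (list)."""
    return [sum(m * v_i for m, v_i in zip(row, v)) for row in M]


def lora_forward(W, A, B, x, alpha, r):
    """y = W x + (alpha / r) * B (A x). W stays frozen; only A and B are trained."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]


def lora_trainable_params(d_in, d_out, r):
    """Number of trainable LoRA parameters for one d_out x d_in linear layer."""
    return r * (d_in + d_out)


# Hypothetical toy layer: d_in=4, d_out=3, rank r=2, alpha=8 (scale 4.0)
d_in, d_out, r, alpha = 4, 3, 2, 8
rng = random.Random(0)
W = [[rng.uniform(-1, 1) for _ in range(d_in)] for _ in range(d_out)]
A = [[rng.uniform(-0.1, 0.1) for _ in range(d_in)] for _ in range(r)]
B = [[0.0] * r for _ in range(d_out)]  # zero init: the adapter changes nothing at step 0
x = [1.0, 2.0, 3.0, 4.0]

print(lora_forward(W, A, B, x, alpha, r) == matvec(W, x))  # True
print(lora_trainable_params(1024, 1024, 8), "vs", 1024 * 1024)  # 16384 vs 1048576
```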

In [None]:
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
print(f"Loading {model_name}...")

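tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:  # batching needs a pad token; fall back to EOS if none is defined
    tokenizer.pad_token = tokenizer.eos_token

# Load Model with Classification Head (1 label = Score)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=1,
    torch_dtype=torch.float32,  # fp32 (~2 GB at 0.5B params): trainable fp16 weights can't use fp16 AMP and are unstable without it
    device_map=device,
)
# The classification head scores the last non-padding token, so it needs the pad id
model.config.pad_token_id = tokenizer.pad_token_id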

# Define LoRA Config
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1
)
# We don't call get_peft_model(model) ourselves: RewardTrainer wraps the model with LoRA when given peft_config
print("Model and PEFT Config ready.")
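Why set `model.config.pad_token_id`? For decoder-only models such as Qwen, `AutoModelForSequenceClassification` produces one score per sequence. It applies its head to the hidden state of the last non-padding token, which it finds by comparing `input_ids` with the pad id. The sketch below shows just that index lookup in plain Python. It is a simplified illustration, not the library's code, and the token ids are hypothetical.

```python
def last_non_pad_index(input_ids, pad_token_id):
    """Position of the rightmost non-padding token, or -1 if the row is all padding."""
    for i in range(len(input_ids) - 1, -1, -1):
        if input_ids[i] != pad_token_id:
            return i
    return -1


PAD = 0  # hypothetical pad token id
batch = [
    [11, 12, 13, PAD, PAD],  # shorter sequence, right-padded
    [21, 22, 23, 24, 25],    # full-length sequence
]
print([last_non_pad_index(row, PAD) for row in batch])  # [2, 4]
```

Without a pad id, the model cannot tell padding from content in a padded batch, and transformers refuses to score batches larger than one.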

## 3. Train on M3
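Recent TRL releases take a `RewardConfig` (a subclass of `TrainingArguments` with reward-model options) rather than plain `TrainingArguments`. They also receive the tokenizer through `processing_class`, an argument that older releases called `tokenizer`.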

In [None]:
# RewardConfig extends TrainingArguments with reward-model options (e.g. max_length)
training_args = RewardConfig(
    output_dir="./reward_model_output",
    per_device_train_batch_size=2,
    num_train_epochs=1,
    learning_rate=1e-4,
    logging_steps=1,  # log the loss every step so we can inspect it afterwards
    # No use_mps_device flag: it is deprecated, and the Trainer already picks MPS when available
    bf16=False,
    fp16=False,  # full precision; fp16 AMP can fail on MPS with an "unscale FP16 gradients" error
)

trainer = RewardTrainer(
    model=model,
    processing_class=tokenizer,
    args=training_args,
    train_dataset=dataset,
    peft_config=peft_config,
)

print("Starting Training on M3...")
trainer.train()

## 4. Inference
Now we can use this trained model to score new answers.
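A reward score has no absolute scale. The pairwise training loss depends only on differences, so adding the same constant to every score changes nothing, and scores are best compared between answers to the same prompt. Under the Bradley-Terry model, `sigmoid(score_a - score_b)` can be read as the probability that answer A is preferred over B. Here is a small stdlib sketch. The helper name is ours.

```python
import math


def preference_probability(score_a: float, score_b: float) -> float:
    """Bradley-Terry probability that answer A is preferred over answer B."""
    diff = score_a - score_b
    # Branch on the sign so math.exp only ever sees a non-positive argument
    if diff >= 0:
        return 1.0 / (1.0 + math.exp(-diff))
    z = math.exp(diff)
    return z / (1.0 + z)


print(preference_probability(1.5, 1.5))   # 0.5: no preference either way
print(preference_probability(2.0, -1.0))  # ≈ 0.9526: A is clearly preferred
```

With the scores computed in this section, `preference_probability(s1, s2)` summarizes how strongly the model prefers the correct answer.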

In [None]:
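reward_model = trainer.model  # the LoRA-wrapped model that was actually trained
reward_model.eval()  # disable LoRA dropout so scores are deterministic
model_device = next(reward_model.parameters()).device

def score_answer(prompt, answer):
    # Join prompt and answer into one text; keep this layout consistent with the training examples
    text = f"{prompt} {answer}"
    inputs = tokenizer(text, return_tensors="pt", truncation=True).to(model_device)
    with torch.no_grad():
        score = reward_model(**inputs).logits[0, 0].item()
    return score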

s1 = score_answer("Is 91 prime?", "No, it is 7*13.")
s2 = score_answer("Is 91 prime?", "Yes it is.")

print(f"Score (Correct): {s1:.4f}")
print(f"Score (Wrong): {s2:.4f}")
if s1 > s2:
    print("SUCCESS: Model prefers the correct answer.")
else:
    print("The model does not prefer the correct answer yet.")

## 5. Verification: Did it actually learn?
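We can check the training logs to see whether the **loss** decreased.
A falling loss means the model is getting better at scoring the chosen response above the rejected one. Because our dataset is just three pairs repeated ten times, this shows the model fit those pairs, not that it will generalize to new questions.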
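A more direct check than training loss is pairwise accuracy on pairs the model never saw. This is the fraction of pairs where the chosen answer outscores the rejected one; 0.5 is chance level for a random scorer. Here is a minimal sketch. The helper and the example scores are hypothetical; in practice the scores would come from `score_answer` on held-out pairs.

```python
def pairwise_accuracy(chosen_scores, rejected_scores):
    """Fraction of pairs in which the chosen answer scores higher than the rejected one."""
    if not chosen_scores or len(chosen_scores) != len(rejected_scores):
        raise ValueError("expected two non-empty score lists of equal length")
    wins = sum(c > r for c, r in zip(chosen_scores, rejected_scores))
    return wins / len(chosen_scores)


# Hypothetical scores for three held-out (chosen, rejected) pairs
print(pairwise_accuracy([2.1, 0.4, 1.0], [-0.5, 0.9, 0.2]))  # 2 of 3 pairs ranked correctly
```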

In [None]:
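# log_history holds one dict per logging step plus a final training summary
# (which has "train_loss" but no "loss"), so keep only entries that report "loss"
losses = [entry["loss"] for entry in trainer.state.log_history if "loss" in entry]
if losses:
    print(f"Initial Loss: {losses[0]:.4f}")
    print(f"Final Loss:   {losses[-1]:.4f}")
    if losses[-1] < losses[0]:
        print("\nFinal < Initial: the loss went down during training.")
    else:
        print("\nThe loss did not decrease during training.")
else:
    print("Loss history not found (run training first).")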