# Hands-On Challenge: Running AntiDote's Arms Race in Miniature

**Welcome to the final challenge!**  
In this 25-minute activity, you'll simulate AntiDote's bi-level adversarial training on a tiny scale.  
Watch the defender LLM and adversary hypernetwork co-evolve in an arms race — right in your Colab.  

**Your task:**  
1. Load the model + datasets (micro subsets of BeaverTails and LIMA).  
2. Attach defender LoRA (rank=4).  
3. Instantiate tiny adversary hypernetwork.  
4. Run the training loop: 3 blocks of (5 adversary steps + 5 defender steps under attack + 5 defender capability steps on clean data).  
5. Plot the losses: adversary harm (should drop then rise), defender harm + capability.  

**Challenge question:**  
Modify *one* hyperparameter (assigned to your group: adversary_lr, lora_rank, or adv_def_ratio).  
Run again. Did the arms race stabilize (defender holds), did the adversary dominate (harm keeps dropping), or did the defender win early?  
Class discussion: which param matters most for early robustness? (Connects to paper's decoupled loss ablation.)  
(Report in final cell. Best analysis wins a prize!)  

**Time breakdown:**  
- 5' setup/reading  
- 15' run/modify/experiment (ask for hints!)  
- 5' report + prize  

**Hints:**  
- For speed, shrink the micro-datasets in Section 2 (e.g., 2 examples instead of 8); each subset is processed as a single batch.  
- If OOM, reduce rank to 2.  
- Expect ~5-10 min runtime on T4.  

Runtime: Colab T4 GPU (free tier).

## 1. Setup

In [None]:
%%capture
!pip install transformers datasets peft torch matplotlib --quiet

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, PeftModel
from datasets import load_dataset
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Device: {DEVICE}")

In [None]:
# Load Qwen2-0.5B-Instruct (small, aligned)
MODEL_ID = "Qwen/Qwen2-0.5B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Borrow EOS as the padding token only when the tokenizer ships without one.
# Aliasing pad to EOS unconditionally makes the two ids indistinguishable, so
# any masking keyed on pad_token_id would also hide genuine end-of-sequence tokens.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Plain `.to(DEVICE)` instead of `device_map=`: the model fits on one device and
# this avoids depending on `accelerate`, which the setup cell does not install.
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
).to(DEVICE)
base_model.eval()
# The base model is only ever used as a frozen reference / starting point, so
# switch off gradient tracking to save memory whenever it is run.
base_model.requires_grad_(False)

HIDDEN_DIM = base_model.config.hidden_size
N_LAYERS = base_model.config.num_hidden_layers
print(f"Layers: {N_LAYERS}, Hidden dim: {HIDDEN_DIM}")

## 2. Micro-Datasets

- Harmful: 8 from BeaverTails (prompt + harmful/safe pairs for DPO).
- Benign: 8 from LIMA (capability preservation).

In [None]:
# Load micro subsets
n_pairs = 8

# Per the BeaverTails dataset card, the splits are named "30k_train" /
# "330k_train" (there is no plain "train"), and each row is a single
# (prompt, response, is_safe) record rather than a ready-made safe/unsafe pair.
beaver = load_dataset("PKU-Alignment/BeaverTails", split="30k_train").shuffle(seed=42)
lima = load_dataset("GAIR/lima", split="train").shuffle(seed=42).select(range(8))

# Build preference pairs by grouping rows on the prompt: keep the first safe and
# the first unsafe response seen for each prompt, and stop as soon as enough
# prompts have both.
safe_by_prompt, unsafe_by_prompt = {}, {}
paired_prompts = []
for ex in beaver:
    if ex['is_safe']:
        own, other = safe_by_prompt, unsafe_by_prompt
    else:
        own, other = unsafe_by_prompt, safe_by_prompt
    prompt = ex['prompt']
    if prompt in own:
        continue
    own[prompt] = ex['response']
    if prompt in other:
        paired_prompts.append(prompt)
        if len(paired_prompts) == n_pairs:
            break

if len(paired_prompts) < n_pairs:
    raise ValueError(
        f"Only found {len(paired_prompts)} prompts with both a safe and an unsafe "
        f"response; need {n_pairs}."
    )

# Format for DPO: (prompt, safe, harmful) — the two lists are index-aligned by prompt.
safe_data = [(p, safe_by_prompt[p]) for p in paired_prompts]
harm_data = [(p, unsafe_by_prompt[p]) for p in paired_prompts]

# Benign for capability: (prompt, response) — first user turn and first reply.
cap_data = [(ex['conversations'][0], ex['conversations'][1]) for ex in lima]

print(f"Harmful pairs: {len(harm_data)}, Benign: {len(cap_data)}")

In [None]:
MAX_LEN = 128

def tokenize_pairs(pairs, is_dpo=False):
    """Tokenize (prompt, response) pairs into one padded batch with LM labels.

    Args:
        pairs: iterable of (prompt, response) string tuples.
        is_dpo: unused; kept so existing callers that pass it keep working.

    Returns:
        The tokenizer's batch encoding moved to DEVICE, with an extra "labels"
        entry: a copy of "input_ids" where padded positions are set to -100 so
        that cross-entropy ignores them.
    """
    texts = [f"User: {p}\nAssistant: {r}{tokenizer.eos_token}" for p, r in pairs]
    enc = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=MAX_LEN).to(DEVICE)
    labels = enc["input_ids"].clone()
    # Mask by attention mask rather than by pad_token_id: when the pad token is
    # the same id as EOS, an id-based mask would also erase the real EOS that
    # closes every response.
    labels[enc["attention_mask"] == 0] = -100
    enc["labels"] = labels
    return enc

safe_enc = tokenize_pairs(safe_data)
harm_enc = tokenize_pairs(harm_data)
cap_enc = tokenize_pairs(cap_data)

## 3. Adversary Hypernetwork

Tiny version: input activations → self-attn → FF → U/V heads.

In [None]:
class TinyHypernetwork(nn.Module):
    """Map a batch of activations to the two factors of a low-rank weight patch.

    Pipeline: self-attention over the sequence -> mean-pool over positions ->
    feed-forward -> two linear heads that are reshaped into `u` and `v`.

    Args:
        input_dim: feature size of the incoming activations. When omitted, the
            notebook-level HIDDEN_DIM is looked up at construction time.
        rank: rank of the generated patch.
        hidden_dim: width of the feed-forward bottleneck.
        n_heads: number of attention heads (must divide `input_dim`).
    """

    def __init__(self, input_dim=None, rank=4, hidden_dim=128, n_heads=2):
        super().__init__()
        if input_dim is None:
            input_dim = HIDDEN_DIM
        # Remember the size actually used so forward() reshapes with it instead
        # of a global that may not match a non-default `input_dim`.
        self.input_dim = input_dim
        self.rank = rank
        self.attn = nn.MultiheadAttention(input_dim, n_heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        )
        self.head_u = nn.Linear(input_dim, rank * input_dim)
        self.head_v = nn.Linear(input_dim, input_dim * rank)

    def forward(self, x):
        """Return `(u, v)` for activations `x` of shape (B, T, input_dim).

        `u` has shape (B, rank, input_dim) and `v` has shape
        (B, input_dim, rank).
        """
        # Attention weights are never used, so skip computing them.
        attn_out, _ = self.attn(x, x, x, need_weights=False)  # (B, T, D)
        pooled = attn_out.mean(dim=1)  # (B, D)
        emb = self.ff(pooled)
        u = self.head_u(emb).view(-1, self.rank, self.input_dim)
        v = self.head_v(emb).view(-1, self.input_dim, self.rank)
        return u, v

## 4. Activation Hooks + Patch Injection

In [None]:
def _find_decoder_layers(model):
    """Return the first `.layers` container found by following `.model` links.

    Wrapped models (for example a PEFT wrapper around a causal-LM head around
    the decoder) nest the decoder stack at different depths, so walk down
    instead of hard-coding `model.model.layers`.
    """
    node = model
    for _ in range(8):  # bounded walk; guards against reference cycles
        layers = getattr(node, "layers", None)
        if layers is not None:
            return layers
        node = getattr(node, "model", None)
        if node is None:
            break
    raise AttributeError("Could not locate a `.layers` container under the given model")


def get_activations(model, inputs, layer_idx):
    """Run `model` once and return the detached output of decoder layer `layer_idx`.

    Args:
        model: module whose decoder stack is reachable through `.model` links.
        inputs: mapping of keyword arguments for the forward pass; a "labels"
            entry, if present, is dropped because no loss is needed here.
        layer_idx: index into the decoder layer list.

    Returns:
        The layer's hidden-state tensor, detached from autograd.
    """
    captured = {}

    def hook(module, inp, out):
        # Layers may return either a tuple whose first item is the hidden
        # state, or the hidden-state tensor itself.
        hidden = out[0] if isinstance(out, (tuple, list)) else out
        captured['act'] = hidden.detach()

    forward_kwargs = {k: v for k, v in inputs.items() if k != "labels"}
    handle = _find_decoder_layers(model)[layer_idx].register_forward_hook(hook)
    try:
        # The result is detached anyway, so don't build an autograd graph.
        with torch.no_grad():
            model(**forward_kwargs)
    finally:
        # Always unhook, even if the forward pass raises.
        handle.remove()
    return captured['act']

class _PatchedForward(nn.Module):
    """Callable view of a shared model that applies a weight patch only while it runs.

    While `delta_w` is set, each call temporarily hooks the target linear
    layer so that its output becomes `W x + delta_w x`, i.e. exactly what the
    layer would produce with weight `W + delta_w`. The hook is removed before
    the call returns, so the wrapped model is never left modified.
    """

    def __init__(self, wrapped, target):
        super().__init__()
        self.wrapped = wrapped
        # Kept in a tuple so the target is not registered a second time as a
        # submodule (it already lives inside `wrapped`).
        self._target = (target,)
        self.delta_w = None

    def forward(self, *args, **kwargs):
        delta_w = self.delta_w
        if delta_w is None:
            return self.wrapped(*args, **kwargs)

        def add_patch(module, inputs, output):
            return output + F.linear(inputs[0], delta_w.to(output.dtype))

        handle = self._target[0].register_forward_hook(add_patch)
        try:
            return self.wrapped(*args, **kwargs)
        finally:
            handle.remove()


class AttackedModel:
    """A model viewed under an adversarial low-rank weight patch.

    The underlying model is shared, not copied, and the patch is applied
    differentiably. Gradients of any loss computed through `self.model`
    therefore reach both the patch factors (the adversary) and the underlying
    model's own trainable parameters (the defender).

    Args:
        base_model: model to attack; its decoder stack must be reachable by
            following `.model` links down to a `.layers` container.
        layer_idx: index of the decoder layer to patch.
        target_name: dotted path, inside that layer, of the linear module to
            patch. The hypernetwork emits square hidden-by-hidden patches, so
            the default is the attention output projection rather than an MLP
            up-projection (whose weight is not square).
    """

    def __init__(self, base_model, layer_idx, target_name="self_attn.o_proj"):
        self.layer_idx = layer_idx
        layer = self._find_layers(base_model)[layer_idx]
        self._target = layer.get_submodule(target_name)
        self.model = _PatchedForward(base_model, self._target)

    @staticmethod
    def _find_layers(model):
        """Follow `.model` links until a `.layers` container is found."""
        node = model
        for _ in range(8):  # bounded walk; guards against reference cycles
            layers = getattr(node, "layers", None)
            if layers is not None:
                return layers
            node = getattr(node, "model", None)
            if node is None:
                break
        raise AttributeError("Could not locate a `.layers` container under the given model")

    def apply_patch(self, u, v):
        """Set the active patch to the batch-mean of `v @ u`.

        Args:
            u: (B, rank, D_in) or (rank, D_in) factor.
            v: (B, D_out, rank) or (D_out, rank) factor.

        Raises:
            ValueError: if the resulting (D_out, D_in) patch does not match the
                target layer's weight shape.
        """
        if u.dim() == 2:
            u = u.unsqueeze(0)
        if v.dim() == 2:
            v = v.unsqueeze(0)
        # One shared (D_out, D_in) patch: average the per-example low-rank
        # products without materialising the (B, D_out, D_in) intermediate.
        delta_w = torch.einsum("bor,bri->oi", v, u) / v.shape[0]
        weight = getattr(self._target, "weight", None)
        if weight is not None and tuple(delta_w.shape) != tuple(weight.shape):
            raise ValueError(
                f"Patch shape {tuple(delta_w.shape)} does not match target weight "
                f"shape {tuple(weight.shape)}"
            )
        self.model.delta_w = delta_w

    def forward(self, **kwargs):
        """Run the shared model with the current patch applied."""
        return self.model(**kwargs)

## 5. Losses

- Harm: Negative DPO (adversary max, defender min).
- Cap: CE + KL to base.

In [None]:
def dpo_loss(pi_safe, pi_harm):
    """Simplified (reference-free) DPO objective: mean of -log sigmoid(pi_safe - pi_harm)."""
    return -F.logsigmoid(pi_safe - pi_harm).mean()


def _mean_token_logprob(logits, labels):
    """Per-sequence average log-probability of the label tokens.

    Logits at position t are scored against the label at t + 1. Positions
    whose label is -100 (the ignore index used for padding) contribute nothing.
    Returns a tensor of shape (B,).
    """
    logprobs = F.log_softmax(logits[:, :-1], dim=-1)
    targets = labels[:, 1:]
    valid = targets != -100
    # -100 is not a legal gather index; substitute 0 and zero those entries out.
    picked = logprobs.gather(-1, targets.clamp(min=0).unsqueeze(-1)).squeeze(-1)
    picked = picked.masked_fill(~valid, 0.0)
    return picked.sum(dim=1) / valid.sum(dim=1).clamp(min=1)


def harm_loss(model, safe_enc, harm_enc):
    """Adversary objective: the negated DPO loss of safe over harmful responses.

    Minimising this value makes the harmful response relatively more likely;
    a defender should minimise the negation instead.

    Args:
        model: callable returning an object with `.logits` of shape (B, T, V).
        safe_enc / harm_enc: encodings with a "labels" entry (ignore index
            -100), index-aligned so row i of each shares the same prompt.

    Returns:
        Scalar tensor.
    """
    pi_safe = _mean_token_logprob(model(**safe_enc).logits, safe_enc['labels'])
    pi_harm = _mean_token_logprob(model(**harm_enc).logits, harm_enc['labels'])
    return -dpo_loss(pi_safe, pi_harm)  # Negative for adversary max

def cap_loss(model, cap_enc, base_model):
    """Capability loss: the model's own LM loss plus 0.1 x KL(base || model).

    Args:
        model: trainable model; its output must expose `.loss` and `.logits`.
        cap_enc: encoding passed as keyword arguments to both models. If it has
            an "attention_mask" entry, the KL term is averaged over unmasked
            positions only.
        base_model: reference model whose output exposes `.logits`.

    Returns:
        Scalar tensor.
    """
    out = model(**cap_enc)
    ce = out.loss
    # The reference distribution is a fixed target: never backprop into it.
    with torch.no_grad():
        base_logp = F.log_softmax(base_model(**cap_enc).logits, dim=-1)
    model_logp = F.log_softmax(out.logits, dim=-1)
    # Per-token KL, then a token-level mean. A 'batchmean' reduction on a
    # (B, T, V) tensor would divide by B only, so the penalty would grow with
    # sequence length and would include padded positions.
    per_token_kl = F.kl_div(model_logp, base_logp, reduction='none', log_target=True).sum(dim=-1)
    mask = cap_enc["attention_mask"] if "attention_mask" in cap_enc else None
    if mask is None:
        kl = per_token_kl.mean()
    else:
        mask = mask.to(per_token_kl.dtype)
        kl = (per_token_kl * mask).sum() / mask.sum().clamp(min=1)
    return ce + 0.1 * kl

## 6. Training Loop

Alternate phases A (adv/def on attacked) and B (def on clean).

In [None]:
# Hyperparams — MODIFY ONE FOR CHALLENGE!
adv_lr = 1e-4  # Group 1: try 5e-4 or 1e-5
def_lr = 3e-5
lora_rank = 4  # Group 2: try 2 or 8
adv_steps_per_block = 5  # Group 3: try 2:8 or 8:2 ratio
def_steps_per_block = 5
n_blocks = 3
target_layer = 10  # Fixed for micro

# Defender LoRA
lora_config = LoraConfig(r=lora_rank, lora_alpha=16, target_modules=["q_proj", "v_proj"])
defender = get_peft_model(deepcopy(base_model), lora_config).to(DEVICE)
# Optimise only the parameters that actually require gradients (the adapter
# weights), instead of handing every backbone tensor to Adam.
def_opt = torch.optim.Adam([p for p in defender.parameters() if p.requires_grad], lr=def_lr)

# Adversary
adversary = TinyHypernetwork(rank=lora_rank).to(DEVICE)
adv_opt = torch.optim.Adam(adversary.parameters(), lr=adv_lr)

adv_losses = []
def_harm_losses = []
cap_losses = []

for block in range(n_blocks):
    print(f"Block {block+1}/{n_blocks}")

    # Phase A: Adversary max harm
    # harm_loss is the *negated* DPO loss, so minimising it pushes the patched
    # model toward preferring the harmful response.
    for _ in range(adv_steps_per_block):
        acts = get_activations(defender, safe_enc, target_layer)  # Or harm_enc
        u, v = adversary(acts)
        attacked = AttackedModel(defender, target_layer)
        attacked.apply_patch(u, v)
        loss_a = harm_loss(attacked.model, safe_enc, harm_enc)
        adv_opt.zero_grad()
        loss_a.backward()
        adv_opt.step()
        adv_losses.append(loss_a.item())

    # Phase A: Defender min harm under attack
    for _ in range(def_steps_per_block):
        acts = get_activations(defender, safe_enc, target_layer)
        # The adversary is held fixed during defender steps, so its graph is
        # not needed.
        with torch.no_grad():
            u, v = adversary(acts)
        attacked = AttackedModel(defender, target_layer)
        attacked.apply_patch(u, v)
        # The defender must oppose the adversary: flip the sign so it minimises
        # the DPO loss itself. Reusing harm_loss unchanged would have both
        # players pushing toward harm.
        loss_d = -harm_loss(attacked.model, safe_enc, harm_enc)
        def_opt.zero_grad()
        loss_d.backward()
        def_opt.step()
        def_harm_losses.append(loss_d.item())

    # Phase B: Defender min cap on clean
    for _ in range(def_steps_per_block):
        loss_c = cap_loss(defender, cap_enc, base_model)
        def_opt.zero_grad()
        loss_c.backward()
        def_opt.step()
        cap_losses.append(loss_c.item())

# Extend lists to same length for plotting (NaN entries are simply not drawn)
max_len = max(len(adv_losses), len(def_harm_losses), len(cap_losses))
adv_losses += [np.nan] * (max_len - len(adv_losses))
def_harm_losses += [np.nan] * (max_len - len(def_harm_losses))
cap_losses += [np.nan] * (max_len - len(cap_losses))

## 7. Plot Losses

In [None]:
# One curve per recorded loss series, all drawn against the same step axis.
steps = range(max_len)
curves = (
    (adv_losses, 'r-', 'Adversary Harm Loss (max)'),
    (def_harm_losses, 'b-', 'Defender Harm Loss (min)'),
    (cap_losses, 'g-', 'Capability Loss (min)'),
)

plt.figure(figsize=(10, 6))
for series, line_style, legend_label in curves:
    plt.plot(steps, series, line_style, label=legend_label)
plt.xlabel('Training Steps')
plt.ylabel('Loss')
plt.title('AntiDote Arms Race: Losses Over Time')
plt.legend()
plt.grid(alpha=0.3)
plt.show()

## 8. Challenge Report

Modified param: ________  
New value: ________  

Outcome: (stabilize / adv dominate / def early win)  

Insight: (1-2 sentences on why this param affects early robustness, link to paper)