# Part 4 — AntiDote Micro Hands-on (25 minutes)

This notebook gives you a **live, micro-scale bi-level training loop** that mirrors the paper's adversary-vs-defender dynamics.

## What you will run
- 8 harmful preference pairs (BeaverTails-style)
- 8 benign instruction examples (LIMA-style)
- Defender: LoRA rank = 4
- Adversary: tiny bottleneck MLP (hidden width 64) that perturbs the model's final hidden states
- Schedule per block:
  - 5 adversary steps (log `loss_a`)
  - 5 defender steps (log `loss_d`, `loss_cap`)
- 3 blocks total

You should see `loss_a` fall early, then flatten/rise once the defender starts adapting.


In [None]:
# If needed (e.g., fresh Colab runtime), uncomment:
# !pip -q install torch transformers peft datasets matplotlib


In [None]:
import json
import math
import random
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from IPython.display import clear_output

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model


In [None]:
# -----------------
# Config
# -----------------
# Model checkpoint and on-disk datasets.
model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # ~0.5B
harmful_path = "data/dpo_data.json"
benign_path = "data/instruction_tuning_data.json"

# Number of examples sampled from each dataset.
num_harmful, num_benign = 8, 8

# Capacity of the two players.
lora_rank = 4
adversary_hidden_dim = 64

# Alternating schedule: every block runs the adversary steps, then the defender steps.
num_blocks = 3
adv_steps_per_block = 5
def_steps_per_block = 5

# Optimisation hyper-parameters.
max_length = 256      # tokenizer truncation length
lr_adversary = 3e-4
lr_defender = 2e-4
lambda_cap = 1.0      # weight of the capability loss in the defender objective
attack_scale = 0.35   # multiplier applied to the adversary's perturbation

# Seed Python's and PyTorch's RNGs so sampling and initialisation are repeatable.
seed = 42
random.seed(seed)
torch.manual_seed(seed)

# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("device:", device)


In [None]:
# -----------------
# Load tiny datasets
# -----------------
# Decode explicitly as UTF-8 so the result does not depend on the platform's
# default locale encoding.
# NOTE(review): assumes each file holds a JSON array of records with a
# "prompt" key -- confirm against the data files.
harmful_all = json.loads(Path(harmful_path).read_text(encoding="utf-8"))
benign_all = json.loads(Path(benign_path).read_text(encoding="utf-8"))

# Sampling is without replacement; random.sample raises ValueError when a file
# holds fewer records than requested.
harmful_data = random.sample(harmful_all, num_harmful)
benign_data = random.sample(benign_all, num_benign)

print(f"harmful examples: {len(harmful_data)}")
print(f"benign examples: {len(benign_data)}")
print("sample harmful prompt:", harmful_data[0]["prompt"][:120], "...")
print("sample benign prompt:", benign_data[0]["prompt"][:120], "...")


In [None]:
# -----------------
# Tokenizer/model + defender LoRA
# -----------------
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
# Fall back to the EOS token for padding when the checkpoint defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# float16 weights on GPU, float32 on CPU.
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32 if device != "cuda" else torch.float16,
    device_map=None,
)

# Defender: LoRA adapters of rank `lora_rank` on the four attention projections.
lora_cfg = LoraConfig(
    task_type="CAUSAL_LM",
    r=lora_rank,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
model = get_peft_model(base_model, lora_cfg)
model = model.to(device)
model.print_trainable_parameters()


In [None]:
# -----------------
# Tiny adversary
# -----------------
# Hidden-state width, read from the loaded model's config.
hidden_size = model.config.hidden_size

class TinyAdversary(nn.Module):
    """Bottleneck MLP that turns a sequence's mean hidden state into one perturbation vector.

    The network is Linear(hidden_size -> bottleneck) -> Tanh ->
    Linear(bottleneck -> hidden_size).

    Args:
        hidden_size: Width of the hidden states being perturbed.
        bottleneck: Width of the intermediate layer.
    """

    def __init__(self, hidden_size, bottleneck=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_size, bottleneck),
            nn.Tanh(),
            nn.Linear(bottleneck, hidden_size),
        )

    def forward(self, hidden_states):
        """Return a perturbation that broadcasts over the time axis.

        Args:
            hidden_states: Tensor of shape [B, T, H]; any floating dtype.

        Returns:
            Tensor of shape [B, 1, H] in the same dtype as ``hidden_states``.
        """
        # The MLP may live in a different precision than the hidden states it
        # receives (e.g. float32 weights fed float16 activations), which would
        # make nn.Linear raise a dtype-mismatch error. Compute in the MLP's own
        # dtype and hand the result back in the caller's dtype so it can be
        # added to the hidden states without promoting them.
        net_dtype = self.net[0].weight.dtype
        pooled = hidden_states.to(net_dtype).mean(dim=1)   # [B, H]
        delta = self.net(pooled)                            # [B, H]
        return delta.unsqueeze(1).to(hidden_states.dtype)   # [B, 1, H]

# Build the adversary with the configured bottleneck width and move it to the training device.
adversary = TinyAdversary(hidden_size, adversary_hidden_dim).to(device)


In [None]:
# -----------------
# Utilities
# -----------------
def tokenize_pair(prompt, answer):
    """Tokenize one (prompt, answer) chat turn and mask the prompt out of the labels.

    Args:
        prompt: User message text.
        answer: Assistant reply text.

    Returns:
        Tuple ``(input_ids, attention_mask, labels)`` of tensors on ``device``
        for a batch of one sequence. ``labels`` is a copy of ``input_ids`` with
        the prompt positions (and any padding positions) set to -100 so that
        only the answer contributes to the loss.
    """
    messages = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": answer},
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False)

    # Same conversation up to and including the assistant header; its token
    # count tells us how many leading positions belong to the prompt.
    prompt_only = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )

    # The chat template has already written every special token into the text.
    # Letting the tokenizer add its own again would duplicate BOS/EOS on
    # tokenizers that insert them, so disable that for both encodings.
    tok_kwargs = dict(
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
        add_special_tokens=False,
    )
    tok_full = tokenizer(text, **tok_kwargs)
    tok_prompt = tokenizer(prompt_only, **tok_kwargs)

    input_ids = tok_full["input_ids"].to(device)
    attention_mask = tok_full["attention_mask"].to(device)

    labels = input_ids.clone()
    prompt_len = tok_prompt["input_ids"].shape[1]
    labels[:, :prompt_len] = -100
    labels[attention_mask == 0] = -100
    return input_ids, attention_mask, labels


def sequence_nll_from_logits(logits, labels):
    """Mean next-token negative log-likelihood over the labelled positions.

    Args:
        logits: Tensor of shape [B, T, V].
        labels: Integer tensor of shape [B, T]; positions equal to -100 are ignored.

    Returns:
        Scalar float32 tensor: summed cross-entropy divided by the number of
        non-ignored target tokens. Returns 0 when every target is ignored
        (the divisor is clamped to at least 1).
    """
    # Position t predicts token t+1, so drop the last logit and the first label.
    # Upcast before the softmax: the backbone may emit float16 logits, and
    # summing per-token losses in half precision loses accuracy.
    shift_logits = logits[:, :-1, :].float()
    shift_labels = labels[:, 1:]

    token_loss = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        ignore_index=-100,
        reduction="sum",
    )

    valid_tokens = (shift_labels != -100).sum().clamp_min(1)
    return token_loss / valid_tokens


def forward_with_attack(input_ids, attention_mask):
    """Run the model, perturb its last hidden states with the adversary, and re-project to logits.

    Args:
        input_ids: Token ids passed to the model.
        attention_mask: Attention mask passed to the model.

    Returns:
        Logits produced by applying the LM head to the perturbed hidden states.
    """
    # Last entry of the hidden-state tuple returned by the model.
    final_hidden = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=True,
        use_cache=False,
    ).hidden_states[-1]

    # Scaled additive perturbation from the adversary.
    perturbation = adversary(final_hidden)
    return model.base_model.lm_head(final_hidden + attack_scale * perturbation)


def harmful_losses(example):
    """Compute the adversary and defender losses for one harmful preference pair.

    Args:
        example: Mapping with "prompt", "chosen" and "rejected" entries.

    Returns:
        Tuple ``(loss_a, loss_d)`` where ``loss_a = nll_harm - nll_safe`` and
        ``loss_d = nll_safe - nll_harm``, both measured under attack.
    """
    prompt = example["prompt"]

    # In this dataset: chosen is harmful answer, rejected is safe answer.
    harmful_answer, safe_answer = example["chosen"], example["rejected"]

    harmful_batch = tokenize_pair(prompt, harmful_answer)
    safe_batch = tokenize_pair(prompt, safe_answer)

    def attacked_nll(batch):
        # Attacked forward pass followed by the masked mean NLL of the answer.
        ids, mask, labels = batch
        return sequence_nll_from_logits(forward_with_attack(ids, mask), labels)

    nll_harm = attacked_nll(harmful_batch)
    nll_safe = attacked_nll(safe_batch)

    # Adversary wants the harmful answer easier than the safe one; the
    # defender wants the opposite.
    return nll_harm - nll_safe, nll_safe - nll_harm


def capability_loss(example):
    """Mean answer NLL on a benign example, computed without the adversary.

    Args:
        example: Mapping with "prompt" and "response" entries.

    Returns:
        Scalar loss tensor from ``sequence_nll_from_logits``.
    """
    input_ids, attention_mask, labels = tokenize_pair(
        example["prompt"], example["response"]
    )
    clean_logits = model(
        input_ids=input_ids, attention_mask=attention_mask, use_cache=False
    ).logits
    return sequence_nll_from_logits(clean_logits, labels)


In [None]:
# -----------------
# Optimizers + freeze helpers
# -----------------
# The defender optimizer receives only the model parameters that require grad
# at this point; the adversary optimizer receives all adversary parameters.
defender_params = list(filter(lambda p: p.requires_grad, model.parameters()))
opt_a = torch.optim.AdamW(adversary.parameters(), lr=lr_adversary)
opt_d = torch.optim.AdamW(defender_params, lr=lr_defender)


def set_trainable_for_adversary_phase():
    """Freeze every model parameter and make every adversary parameter trainable."""
    model.requires_grad_(False)
    adversary.requires_grad_(True)


def set_trainable_for_defender_phase():
    """Freeze the adversary and re-enable gradients for the defender's parameters only.

    Only the tensors in ``defender_params`` (the ones ``opt_d`` updates) are
    switched back on. Turning ``requires_grad`` on for every model parameter
    would make backward compute and store gradients for the whole backbone
    even though the optimizer never applies them.
    """
    for p in adversary.parameters():
        p.requires_grad = False
    # Base-model weights are deliberately left untouched here.
    for p in defender_params:
        p.requires_grad = True


In [None]:
# -----------------
# Bi-level micro loop
# -----------------
model.train()

# One list per logged series; a series not measured at a given step stores NaN
# so that every list stays aligned with "step".
history = {name: [] for name in ("step", "loss_a", "loss_d", "loss_cap")}

global_step = 0

for block in range(num_blocks):
    print(f"\n===== Block {block+1}/{num_blocks} =====")

    # 1) Adversary phase: model frozen, adversary minimises loss_a.
    set_trainable_for_adversary_phase()
    for s in range(adv_steps_per_block):
        loss_a, _ = harmful_losses(random.choice(harmful_data))

        opt_a.zero_grad()
        loss_a.backward()
        torch.nn.utils.clip_grad_norm_(adversary.parameters(), 1.0)
        opt_a.step()

        global_step += 1
        history["step"].append(global_step)
        history["loss_a"].append(loss_a.item())
        history["loss_d"].append(math.nan)
        history["loss_cap"].append(math.nan)

        print(f"Adv step {s+1}/{adv_steps_per_block} | loss_a={loss_a.item():.4f}")

    # 2) Defender phase: adversary frozen, defender minimises
    #    loss_d + lambda_cap * loss_cap.
    set_trainable_for_defender_phase()
    for s in range(def_steps_per_block):
        # Draw the harmful example first, then the benign one.
        harmful_ex = random.choice(harmful_data)
        benign_ex = random.choice(benign_data)

        _, loss_d = harmful_losses(harmful_ex)
        loss_cap = capability_loss(benign_ex)
        loss_total = loss_d + lambda_cap * loss_cap

        opt_d.zero_grad()
        loss_total.backward()
        torch.nn.utils.clip_grad_norm_(defender_params, 1.0)
        opt_d.step()

        global_step += 1
        history["step"].append(global_step)
        history["loss_a"].append(math.nan)
        history["loss_d"].append(loss_d.item())
        history["loss_cap"].append(loss_cap.item())

        print(
            f"Def step {s+1}/{def_steps_per_block} | "
            f"loss_d={loss_d.item():.4f} | loss_cap={loss_cap.item():.4f}"
        )

    # Live plot after each block (replaces the cell's previous output).
    clear_output(wait=True)
    plt.figure(figsize=(10, 4))
    for series, legend_label in (
        ("loss_a", "loss_a (adversary)"),
        ("loss_d", "loss_d (defender safety)"),
        ("loss_cap", "loss_cap (defender capability)"),
    ):
        plt.plot(history["step"], history[series], marker="o", label=legend_label)
    plt.title("Micro AntiDote Dynamics")
    plt.xlabel("micro-step")
    plt.ylabel("loss")
    plt.legend()
    plt.grid(alpha=0.3)
    plt.show()

print("Done.")


## Challenge knobs (for groups)
Run the notebook again with **one** knob changed:

1. `lr_adversary` (e.g., 1e-4, 1e-3)
2. `lora_rank` (e.g., 2, 8)
3. `adv_steps_per_block:def_steps_per_block` (e.g., 2:8 or 8:2)

Then report:
- Did it stabilize?
- Did adversary dominate?
- Did defender hold from the start?
