In [None]:
# ==========================================
# UGID-MLP-only (FINAL, reviewer-proof)
# Purpose:
#   Show that constraining ONLY MLP (nodes)
#   cannot block bias routing via attention.
# ==========================================

import torch
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import os
import random
import gc
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, TaskType

# --------------------------
# 0. Seed
# --------------------------
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed handed, in order, to ``random.seed``,
            ``np.random.seed``, ``torch.manual_seed`` and
            ``torch.cuda.manual_seed_all``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Seed all RNGs before any model or data setup runs.
set_seed(42)

# Directory that receives the trained adapter and tokenizer files at the end
# of the script; created up front (no error if it already exists).
SAVE_DIR = "./checkpoints/ugid_mlp_only"
os.makedirs(SAVE_DIR, exist_ok=True)

# --------------------------
# 1. Load base model (MLP-only LoRA)
# --------------------------
print("Loading base model (UGID-MLP-only)...")
# Release Python garbage and cached CUDA memory before loading the checkpoint.
gc.collect()
torch.cuda.empty_cache()

model_id = "NousResearch/Meta-Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Base model in bfloat16 with automatic device placement. Hidden states are
# requested by default because the training loop reads them for the
# representation-matching loss; attention maps are not needed.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    output_attentions=False,      # ★ no attention loss is needed
    output_hidden_states=True,
)

# ===== LoRA: MLP ONLY =====
# Adapters are attached only to the MLP projections (gate/up/down); the
# attention projections are deliberately left without adapters.
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    target_modules=["gate_proj", "up_proj", "down_proj"],  # ★ MLP-only
)
model = get_peft_model(model, peft_config)
# Training mode: note this also enables the LoRA dropout configured above.
model.train()

print("Model ready: BF16 base + LoRA(MLP-only)")

# --------------------------
# 2. Data
# --------------------------
# Counterfactual prompt pairs: the two prompts of a pair are identical except
# for the final pronoun, which is swapped (he <-> she).
_DEBIAS_BASE_PAIRS = (
    ("The doctor said that he", "The doctor said that she"),
    ("The nurse said that she", "The nurse said that he"),
    ("The engineer said that he", "The engineer said that she"),
    ("The teacher said that he", "The teacher said that she"),
    ("The CEO said that he", "The CEO said that she"),
    ("The secretary said that she", "The secretary said that he"),
    ("The developer said that he", "The developer said that she"),
    ("The manager said that he", "The manager said that she"),
    ("The cleaner said that she", "The cleaner said that he"),
    ("The driver said that he", "The driver said that she"),
)

# Anchor prompts: each sentence is paired with itself, so no pronoun is swapped.
_ANCHOR_SENTENCES = (
    "The king said that he",
    "The queen said that she",
    "The father said that he",
    "The mother said that she",
)

# Each base set is repeated 10 times to form the per-epoch example pool.
debias_pairs = list(_DEBIAS_BASE_PAIRS) * 10
anchor_pairs = [(sentence, sentence) for sentence in _ANCHOR_SENTENCES] * 10

print(f"Data prepared: Debias={len(debias_pairs)} | Anchor={len(anchor_pairs)}")

# --------------------------
# 3. Loss utilities
# --------------------------
def get_masked_kl_loss(logits_student, logits_teacher, input_ids, sensitive_ids):
    """
    Stability term: keep non-sensitive tokens close to the base model.

    Computes, per position, the KL divergence of the student distribution
    from the teacher distribution (softmax over the last dimension of each
    logits tensor), then averages it over the positions whose input id is
    NOT one of ``sensitive_ids``.

    Args:
        logits_student: Logits of the model being trained.
        logits_teacher: Logits of the reference model, same shape.
        input_ids: Token ids; positions equal to a sensitive id are excluded.
        sensitive_ids: Iterable of token ids to exclude from the average.

    Returns:
        Scalar tensor: masked KL sum divided by (number of kept positions
        + 1e-6), so the result is 0 when every position is excluded.
    """
    student_log_probs = F.log_softmax(logits_student, dim=-1)
    teacher_probs = F.softmax(logits_teacher, dim=-1)
    per_token_kl = F.kl_div(
        student_log_probs, teacher_probs, reduction="none"
    ).sum(dim=-1)

    # Mark every position holding a sensitive token, then invert to obtain
    # the float weights (1.0 = keep, 0.0 = drop).
    is_sensitive = torch.zeros_like(input_ids, dtype=torch.bool)
    for token_id in sensitive_ids:
        is_sensitive |= input_ids == token_id
    keep = (~is_sensitive).to(torch.float32)

    kept_kl_sum = (per_token_kl * keep).sum()
    kept_count = keep.sum() + 1e-6  # epsilon guards against an all-masked input
    return kept_kl_sum / kept_count

# --------------------------
# 4. Training setup
# --------------------------
# Hand the optimizer only the parameters that require gradients (the LoRA
# adapter weights), not every parameter of the wrapped base model.
optimizer = optim.AdamW(
    [p for p in model.parameters() if p.requires_grad], lr=1e-4
)

# ===== Loss weights =====
lambda_v = 20.0      # ★ ONLY node / MLP constraint
lambda_kl = 1.0      # ★ stability, but weak
lambda_anchor = 1.0  # ★★★ CRITICAL: small anchor, do NOT hide failure
lambda_logit = 0.0   # ★★★ MUST be zero (no behavioral shortcut)

# Indices into the model's hidden-state tuple used by the representation loss.
target_layers = [13, 15, 17]

# Token ids of " he" / " she". Encoding without special tokens and taking the
# first id avoids depending on whether the tokenizer prepends a BOS token.
sensitive_ids = [
    tokenizer.encode(word, add_special_tokens=False)[0]
    for word in (" he", " she")
]

print("Training UGID-MLP-only...")

# --------------------------
# 5. Training loop
# --------------------------
for epoch in range(5):
    # Rebuild the tagged example list each epoch and shuffle it, so debias
    # and anchor examples are interleaved in a fresh random order.
    data = [(a, b, "debias") for a, b in debias_pairs] + \
           [(a, b, "anchor") for a, b in anchor_pairs]
    random.shuffle(data)

    pbar = tqdm(data, desc=f"Epoch {epoch+1}")
    total = 0.0

    for a, b, typ in pbar:
        inp_a = tokenizer(a, return_tensors="pt").to(model.device)

        # Base reference: adapters disabled and no gradients. Only the logits
        # are consumed below, so hidden states are not collected here.
        with model.disable_adapter(), torch.no_grad():
            ref = model(**inp_a, output_hidden_states=False)

        if typ == "debias":
            inp_b = tokenizer(b, return_tensors="pt").to(model.device)

            # The representation loss compares the two prompts position by
            # position, so both must tokenize to the same length.
            seq_len = inp_a.input_ids.shape[1]
            if inp_b.input_ids.shape[1] != seq_len:
                raise ValueError(
                    f"Debias pair tokenizes to different lengths: {a!r} vs {b!r}"
                )

            out_a = model(**inp_a, output_hidden_states=True)
            out_b = model(**inp_b, output_hidden_states=True)

            # stability: stay close to the base model on non-pronoun tokens
            loss_kl = get_masked_kl_loss(
                out_a.logits, ref.logits, inp_a.input_ids, sensitive_ids
            )

            # Position mask shared by all target layers: drop the first
            # position (BOS) and the last position (the pronoun token, so the
            # bias is not erased at its source).
            position_mask = torch.ones(seq_len, device=model.device)
            position_mask[0] = 0
            position_mask[-1] = 0
            position_mask = position_mask.view(1, -1, 1)

            # ===== VSIT: MLP-only, downstream-only =====
            # NOTE(review): the two prompts of a pair differ only in their
            # final word, and that position is masked out. If each pronoun is
            # a single token, the remaining positions see identical prefixes
            # under causal attention, so this term would be driven only by
            # stochastic effects (e.g. LoRA dropout) - confirm this is the
            # intended training signal.
            # NOTE(review): the difference of consecutive hidden-state entries
            # presumably spans a whole decoder layer (attention + MLP), not
            # the MLP alone - confirm this matches the "MLP-only" intent.
            loss_v = 0.0
            for layer_idx in target_layers:
                ha = out_a.hidden_states[layer_idx + 1] - out_a.hidden_states[layer_idx]
                hb = out_b.hidden_states[layer_idx + 1] - out_b.hidden_states[layer_idx]

                # .to() is a no-op when the layer output already lives on the
                # mask's device; it only matters for multi-device placement.
                layer_mask = position_mask.to(ha.device)
                loss_v += (layer_mask * (ha - hb) ** 2).mean()

            loss = lambda_v * loss_v + lambda_kl * loss_kl

        else:
            # anchor stability: keep the adapted model's distribution close to
            # the base model on prompts whose pronoun should not change.
            # NOTE(review): "batchmean" divides by the size of the first
            # (batch) dimension only, so with one prompt per step this is the
            # KL summed over positions and vocabulary, not a per-token mean
            # like get_masked_kl_loss - confirm the scale is intended.
            out = model(**inp_a, output_hidden_states=False)
            loss = lambda_anchor * F.kl_div(
                F.log_softmax(out.logits, dim=-1),
                F.softmax(ref.logits, dim=-1),
                reduction="batchmean"
            )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once; it is used for both the running total and
        # the progress bar.
        loss_value = loss.item()
        total += loss_value
        pbar.set_postfix(loss=f"{loss_value:.4f}")

    print(f"Epoch {epoch+1} avg loss: {total/len(data):.4f}")

# --------------------------
# 6. Save adapter
# --------------------------
# Switch to eval mode, then write the model (via its save_pretrained) and the
# tokenizer files into SAVE_DIR.
model.eval()
for artifact in (model, tokenizer):
    artifact.save_pretrained(SAVE_DIR)
print(f"[Saved] UGID-MLP-only to {SAVE_DIR}")

Loading base model (UGID-MLP-only)...


Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.01s/it]


Model ready: BF16 base + LoRA(MLP-only)
Data prepared: Debias=100 | Anchor=40
Training UGID-MLP-only...


Epoch 1: 100%|██████████| 140/140 [00:34<00:00,  4.03it/s, loss=0.0068]


Epoch 1 avg loss: 0.1105


Epoch 2: 100%|██████████| 140/140 [00:35<00:00,  3.90it/s, loss=0.0144]


Epoch 2 avg loss: 0.0731


Epoch 3: 100%|██████████| 140/140 [00:36<00:00,  3.80it/s, loss=0.0035]


Epoch 3 avg loss: 0.0720


Epoch 4: 100%|██████████| 140/140 [00:33<00:00,  4.20it/s, loss=0.0688]


Epoch 4 avg loss: 0.0497


Epoch 5: 100%|██████████| 140/140 [00:33<00:00,  4.20it/s, loss=0.0046] 


Epoch 5 avg loss: 0.0343
[Saved] UGID-MLP-only to ./checkpoints/ugid_mlp_only
