In [None]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import math
import os
import pandas as pd
import random
import gc
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, TaskType

# ==========================================
# 0. Global Settings
# ==========================================
def set_seed(seed=42):
    """Seed every RNG used by this script and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy, and PyTorch
            (CPU generator and all CUDA devices).

    Side effects:
        Sets ``torch.backends.cudnn.deterministic = True`` and
        ``torch.backends.cudnn.benchmark = False`` globally.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning may pick a different convolution algorithm on each run, which
    # undermines the deterministic flag above, so switch it off explicitly.
    torch.backends.cudnn.benchmark = False

# Seed all RNGs once, before the model and data are created.
set_seed(42)

# Directory that the save step at the end of the script writes into;
# created up front (no error if it already exists).
SAVE_DIR = "./checkpoints/ugid_attn_only"
os.makedirs(SAVE_DIR, exist_ok=True)

# ==========================================
# 1. Load Model (BF16 weights + LoRA)
# ==========================================
print("1) Clearing GPU memory & loading model...")
# If a `model` name already exists in this scope (e.g. the cell is re-run in
# the same session), drop it first, then run the garbage collector and release
# cached CUDA memory before loading again.
if "model" in locals():
    del model
gc.collect()
torch.cuda.empty_cache()

model_id = "NousResearch/Meta-Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# NOTE(review): the run output recorded in this file warns that `torch_dtype`
# is deprecated in favor of `dtype`, and that `output_attentions` /
# `output_hidden_states` are reported as invalid generation flags when passed
# here. The training loop passes both flags explicitly on every forward call.
# `attn_implementation="eager"` is presumably required so that attention
# weights can be returned -- confirm against the transformers documentation.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    output_attentions=True,
    output_hidden_states=True,
    attn_implementation="eager"
)

# ===== LoRA (ATTENTION ONLY) =====
# Adapters are attached to the four attention projections only
# (rank 32, alpha 64, dropout 0.05).
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # << attn-only
)
# Wrap the loaded model with the LoRA configuration above.
model = get_peft_model(model, peft_config)
print("Model is ready: BF16 base + LoRA(attn-only).")

# ==========================================
# 2. Data
# ==========================================
# Occupation prompts as (text_a, text_b): the two sentences are identical
# except for the final pronoun, which is swapped. The list is repeated 10x.
debias_pairs = [
    (f"The {role} said that {first}", f"The {role} said that {second}")
    for role, first, second in (
        ("doctor", "he", "she"),
        ("nurse", "she", "he"),
        ("engineer", "he", "she"),
        ("teacher", "he", "she"),
        ("CEO", "he", "she"),
        ("secretary", "she", "he"),
        ("developer", "he", "she"),
        ("manager", "he", "she"),
        ("cleaner", "she", "he"),
        ("driver", "he", "she"),
    )
] * 10

# Anchor prompts as (text, text): both members of a pair are the same
# sentence. The list is repeated 10x.
anchor_pairs = [
    (sentence, sentence)
    for sentence in (
        "The king said that he",
        "The queen said that she",
        "The father said that he",
        "The mother said that she",
        "The brother said that he",
        "The sister said that she",
    )
] * 10

print(f"Data prepared: Debias={len(debias_pairs)} | Anchor={len(anchor_pairs)}")

# ==========================================
# 3. Core Functions
# ==========================================
def get_exact_spectrum(attn_matrix):
    """Return, per position, normalized off-diagonal column mass minus self-attention.

    For each position ``i`` along the last two dimensions, the column sum of
    the attention matrix with its diagonal entry removed is divided by
    ``S - i`` and the diagonal entry ``A[i, i]`` is then subtracted.

    Args:
        attn_matrix: 4-D tensor unpacked as ``(B, H, S, _)``; the last two
            dimensions are expected to be square.

    Returns:
        Tensor with one value per ``(batch, head, position)``.
    """
    _, _, seq_len, _ = attn_matrix.shape
    self_attn = attn_matrix.diagonal(dim1=-2, dim2=-1)
    # Column totals without the diagonal: attention each position receives
    # from every other row.
    received = attn_matrix.sum(dim=-2) - self_attn
    positions = torch.arange(seq_len, device=attn_matrix.device).view(1, 1, seq_len)
    # NOTE(review): the divisor is S - i (never below 1, so the clamp is a
    # no-op here); confirm S - i rather than S - 1 - i is intended.
    divisor = (seq_len - positions).float().clamp(min=1.0)
    return received / divisor - self_attn

def get_adaptive_weights(attn_a, attn_b, pronoun_idx=-1):
    """Average the attention rows at ``pronoun_idx`` from two attention tensors.

    Args:
        attn_a: Attention tensor; the row is taken along the second-to-last
            dimension.
        attn_b: Attention tensor indexed the same way as ``attn_a``.
        pronoun_idx: Row index to select (defaults to the last row).

    Returns:
        The element-wise mean of the two selected rows, detached from the
        autograd graph.
    """
    row_a, row_b = (attn[..., pronoun_idx, :] for attn in (attn_a, attn_b))
    averaged = (row_a + row_b) * 0.5
    return averaged.detach()

def get_surrogate_topk_loss(attn_student, attn_teacher, k=10):
    """Mean absolute difference between student and teacher at the teacher's top-k entries.

    The top-k positions are chosen along the last dimension of
    ``attn_teacher``; ``k`` is capped at the size of that dimension.

    Args:
        attn_student: Tensor gathered at the teacher's top-k indices.
        attn_teacher: Tensor that determines the top-k indices.
        k: Maximum number of entries kept per row.

    Returns:
        Scalar tensor with the L1 loss (mean reduction) over the kept entries.
    """
    n_keep = min(k, attn_teacher.size(-1))
    top_idx = attn_teacher.topk(n_keep, dim=-1).indices
    student_top = attn_student.gather(-1, top_idx)
    teacher_top = attn_teacher.gather(-1, top_idx)
    return F.l1_loss(student_top, teacher_top)

def get_masked_kl_loss(logits_student, logits_teacher, input_ids, sensitive_ids):
    """Average per-token KL(teacher || student), skipping sensitive token positions.

    Args:
        logits_student: Student logits; softmax is taken over the last dim.
        logits_teacher: Teacher logits with the same shape as the student's.
        input_ids: Token ids, one per position of the per-token KL tensor.
        sensitive_ids: Iterable of token ids whose positions are excluded.

    Returns:
        Scalar tensor: the sum of KL over kept positions divided by the number
        of kept positions plus ``1e-6``.
    """
    student_logp = F.log_softmax(logits_student, dim=-1)
    teacher_p = F.softmax(logits_teacher, dim=-1)
    token_kl = F.kl_div(student_logp, teacher_p, reduction='none').sum(dim=-1)

    # Mark every position holding one of the sensitive ids.
    is_sensitive = torch.zeros_like(input_ids, dtype=torch.bool)
    for token_id in sensitive_ids:
        is_sensitive |= input_ids == token_id
    keep = (~is_sensitive).to(torch.float32)

    kept_total = (token_kl * keep).sum()
    # The epsilon avoids division by zero when every position is masked.
    return kept_total / (keep.sum() + 1e-6)

def strip_last_pronoun(text):
    """Remove a trailing `` he`` or `` she`` (including its leading space).

    Args:
        text: Input sentence.

    Returns:
        ``text`` without the trailing pronoun, or ``text`` unchanged when it
        ends with neither suffix.
    """
    for suffix in (" he", " she"):
        if text.endswith(suffix):
            return text[:-len(suffix)]
    return text

# ==========================================
# 4. Training (UGID-attn-only)
# ==========================================
# AdamW over everything returned by model.parameters().
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

# ===== weights =====
# Coefficients of the loss terms combined in the training loop. A coefficient
# of 0.0 removes that term's contribution to the loss; the loop still
# computes the term.
lambda_a = 20.0     # ASIT ON: weight of the attention-spectrum term
lambda_v = 0.0      # VSIT OFF (not read anywhere in the training loop below)
lambda_k = 0.0      # weight of the top-k attention term
lambda_kl = 0.0     # weight of the masked KL term against the reference model
lambda_logit = 0.0  # weight of the squared he/she log-probability gap
lambda_anchor = 1.0  # <<<<< anchor weight lowered for stability

# Indices into the per-layer attention tuple returned by the model.
target_layers = [13, 15, 17]
# Token ids for " he" and " she". Index 1 of each encoding is used --
# presumably index 0 is a BOS token added by the tokenizer; confirm.
sensitive_ids = [tokenizer.encode(" he")[1], tokenizer.encode(" she")[1]]
id_he, id_she = sensitive_ids

print("Starting training: UGID-attn-only...")
model.train()

# Train for 5 epochs, one example per optimizer step. Each example is either a
# "debias" pair (attention-based losses between the two pronoun variants) or
# an "anchor" pair (KL to the adapter-free reference model).
for epoch in range(5):
    total_loss = 0.0
    # Tag every pair with its task type and reshuffle the mix each epoch.
    combined_data = [(x, y, "debias") for x, y in debias_pairs] + [(x, y, "anchor") for x, y in anchor_pairs]
    random.shuffle(combined_data)

    progress_bar = tqdm(combined_data, desc=f"Epoch {epoch+1}")
    for text_a, text_b, task_type in progress_bar:
        inputs_a = tokenizer(text_a, return_tensors="pt").to(model.device)

        # P_init reference = base (disable_adapter)
        # Reference forward pass with the LoRA adapter disabled and no
        # gradient tracking. Attentions are requested for both task types,
        # although only the debias branch reads them.
        with model.disable_adapter():
            with torch.no_grad():
                ref_outputs_a = model(**inputs_a, output_attentions=True, output_hidden_states=False)

        if task_type == "debias":
            inputs_b = tokenizer(text_b, return_tensors="pt").to(model.device)

            # NOTE(review): hidden states are requested here but never read
            # in this loop.
            outputs_a = model(**inputs_a, output_attentions=True, output_hidden_states=True)
            outputs_b = model(**inputs_b, output_attentions=True, output_hidden_states=True)

            # KL between adapted and reference logits on text_a, skipping
            # positions that hold the " he" / " she" token ids.
            loss_kl_val = get_masked_kl_loss(outputs_a.logits, ref_outputs_a.logits, inputs_a.input_ids, sensitive_ids)

            loss_asit = 0.0
            loss_topk = 0.0

            for layer_idx in target_layers:
                # Per-position spectrum for each pronoun variant.
                # assumes text_a and text_b tokenize to the same length so the
                # element-wise difference below is well-defined -- TODO confirm
                lam_a = get_exact_spectrum(outputs_a.attentions[layer_idx])
                lam_b = get_exact_spectrum(outputs_b.attentions[layer_idx])
                # Detached weights: mean of the last attention row of a and b.
                w = get_adaptive_weights(outputs_a.attentions[layer_idx], outputs_b.attentions[layer_idx])

                # Exclude position 0 from the sum (presumably the BOS token
                # -- confirm).
                mask = torch.ones(lam_a.shape[-1], device=model.device)
                mask[0] = 0
                mask = mask.view(1, 1, -1)

                # Weighted squared spectrum difference, summed over batch,
                # heads and positions, accumulated across target layers.
                loss_asit += (mask * w * (lam_a - lam_b) ** 2).sum()
                # L1 distance to the reference attention at its top-k entries.
                loss_topk += get_surrogate_topk_loss(outputs_a.attentions[layer_idx], ref_outputs_a.attentions[layer_idx])

            # Behavior term: next-token distribution at the last position of
            # the prompt with the trailing pronoun removed; penalizes the
            # squared gap between log P(" he") and log P(" she").
            # NOTE(review): this extra forward pass runs even when
            # lambda_logit is 0.
            prompt = strip_last_pronoun(text_a)
            inputs_p = tokenizer(prompt, return_tensors="pt").to(model.device)
            outputs_p = model(**inputs_p, output_attentions=False, output_hidden_states=False)
            logits_p = outputs_p.logits[0, -1, :]
            log_probs_p = F.log_softmax(logits_p, dim=-1)
            loss_logit_val = (log_probs_p[id_he] - log_probs_p[id_she]) ** 2

            # Weighted sum of the four debias terms.
            loss = (
                lambda_a * loss_asit +
                lambda_k * loss_topk +
                lambda_kl * loss_kl_val +
                lambda_logit * loss_logit_val
            )

        else:
            # Anchor example: keep the adapted model's output distribution
            # close to the reference on text_a (text_b is not used here).
            outputs_a = model(**inputs_a, output_attentions=False, output_hidden_states=False)
            log_probs = F.log_softmax(outputs_a.logits, dim=-1)
            probs_ref = F.softmax(ref_outputs_a.logits, dim=-1)
            # NOTE(review): per the PyTorch docs 'batchmean' divides the
            # summed KL by the size of the first dimension; with one sequence
            # per step this is presumably the sum over all token positions
            # -- confirm that is intended.
            loss_kl_anchor = F.kl_div(log_probs, probs_ref, reduction='batchmean')
            loss = lambda_anchor * loss_kl_anchor

        # One optimizer step per example, with gradient-norm clipping at 1.0.
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += float(loss.item())
        progress_bar.set_postfix({"loss": float(loss.item())})

    # Mean loss over all examples of the epoch (debias and anchor mixed).
    print(f"Epoch {epoch+1} Avg Loss: {total_loss/len(combined_data):.4f}")

print("Training finished: UGID-attn-only")

# ==========================================
# 5) Save LoRA adapter
# ==========================================
# Switch to eval mode, then write the PEFT-wrapped model and the tokenizer
# files into SAVE_DIR.
# NOTE(review): for a PEFT model, save_pretrained presumably writes only the
# adapter weights and config rather than the full base model -- confirm.
model.eval()
model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)
print(f"[Saved] LoRA adapter to: {SAVE_DIR}")

  from .autonotebook import tqdm as notebook_tqdm


1) Clearing GPU memory & loading model...


`torch_dtype` is deprecated! Use `dtype` instead!
The following generation flags are not valid and may be ignored: ['output_attentions', 'output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Loading checkpoint shards: 100%|██████████| 4/4 [00:05<00:00,  1.34s/it]


Model is ready: BF16 base + LoRA(attn-only).
Data prepared: Debias=100 | Anchor=60
Starting training: UGID-attn-only...


Epoch 1: 100%|██████████| 160/160 [01:05<00:00,  2.45it/s, loss=0.00318] 


Epoch 1 Avg Loss: 0.0914


Epoch 2: 100%|██████████| 160/160 [01:04<00:00,  2.49it/s, loss=0.0442]  


Epoch 2 Avg Loss: 0.0420


Epoch 3: 100%|██████████| 160/160 [01:04<00:00,  2.49it/s, loss=0.000546]


Epoch 3 Avg Loss: 0.0117


Epoch 4: 100%|██████████| 160/160 [01:04<00:00,  2.49it/s, loss=0.000554]


Epoch 4 Avg Loss: 0.0071


Epoch 5: 100%|██████████| 160/160 [01:04<00:00,  2.49it/s, loss=0.00252] 


Epoch 5 Avg Loss: 0.0053
Training finished: UGID-attn-only
[Saved] LoRA adapter to: ./checkpoints/ugid_attn_only
