In [None]:
# -*- coding: utf-8 -*-
"""
Qwen2.5-1.5B + QLoRA(4bit) + LoRA
对比实验：只用原始文本 vs 原文+三维解释
Transformers==4.55
"""

import os, json
from typing import Dict, Any

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

import numpy as np
import torch
import torch.nn.functional as F

from sklearn.metrics import (
    confusion_matrix, ConfusionMatrixDisplay,
    roc_curve, auc,
    precision_recall_fscore_support, classification_report, balanced_accuracy_score
)
from sklearn.preprocessing import label_binarize
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    BitsAndBytesConfig,
    DataCollatorWithPadding,
    Trainer, TrainingArguments,
    set_seed,
)

# Environment switches; set at import time, before main() builds the model and Trainer.
os.environ["ACCELERATE_MIXED_PRECISION"] = "no"
os.environ["BITSANDBYTES_NOWELCOME"] = "1"

# ============ Experiment input switch ============
# "posts_only": raw post text only; "all_views": raw posts + the three explanation views
INPUT_MODE   = "posts_only"   # <- change to "all_views" to switch back to the original setup
MODEL_NAME   = "Qwen/Qwen2.5-1.5B-Instruct"

# Per-field token budgets read by build_input(); a budget of 0 makes
# truncate_to_budget() return an empty string for that field.
# Author's stated intent: keep the total token budget roughly equal in both modes
# (≈312) so the comparison is fair.
# NOTE(review): as written the budgets sum to 320 (posts_only) vs 440 (all_views),
# which matches neither each other nor the ≈312 figure — confirm which is intended.
if INPUT_MODE == "posts_only":
    BUDGET = {"posts_cleaned": 320, "semantic_view": 0, "sentiment_view": 0, "linguistic_view": 0}
else:
    BUDGET = {"posts_cleaned": 320, "semantic_view": 64, "sentiment_view": 32, "linguistic_view": 24}

# max_length used for the final encoding (safety bound; the datasets built in main()
# truncate to this length).
# NOTE(review): in "all_views" mode the field budgets alone already sum to 440, so the
# section markers and the trailing [TASK] line presumably get truncated — confirm.
MAX_LEN      = 440

# Optimisation hyper-parameters passed to TrainingArguments in main().
SEED         = 42
EPOCHS       = 4
LR           = 2e-4
BSZ_TRN      = 8
BSZ_EVAL     = 4
GRAD_ACCUM   = 1
WARMUP_RATIO = 0.06
WEIGHT_DECAY = 0.01

# Short tag for the active input mode; used in the output directory and plot titles.
TAG = "postsOnly" if INPUT_MODE=="posts_only" else "allViews"
OUTPUT_DIR   = f"mbti_lora_qwen1.5b_{TAG}_new"

# Quantisation / LoRA settings consumed by BitsAndBytesConfig and LoraConfig in main().
USE_4BIT     = True
LORA_R       = 16
LORA_ALPHA   = 32
LORA_DROPOUT = 0.05
TARGET_MODULES = ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]

# Optional Hugging Face token; only forwarded to from_pretrained() when the env var is set.
HF_TOKEN = os.getenv("HF_TOKEN")
HF_KW = {"token": HF_TOKEN} if HF_TOKEN else {}

# The 16 class names; a name's position in this list is its integer class id.
MBTI_16 = [
    "INTJ","INTP","ENTJ","ENTP","INFJ","INFP","ENFJ","ENFP",
    "ISTJ","ISFJ","ESTJ","ESFJ","ISTP","ISFP","ESTP","ESFP"
]
MBTI2ID = {t: i for i, t in enumerate(MBTI_16)}

# ============ 基础函数 ============
def load_rows(path: str):
    """Load a JSON file and keep only the usable samples.

    A sample is kept when it is a dict whose ``"type"`` value is one of the
    keys of ``MBTI2ID``; everything else is dropped.

    Args:
        path: Path to a UTF-8 encoded JSON file holding an iterable of samples.

    Returns:
        The list of retained sample dicts, in file order.

    Raises:
        ValueError: If no sample survives the filter.
    """
    with open(path, "r", encoding="utf-8") as handle:
        payload = json.load(handle)

    kept = []
    for entry in payload:
        # Non-dict entries are skipped before the label lookup is attempted.
        if not isinstance(entry, dict):
            continue
        if entry.get("type") in MBTI2ID:
            kept.append(entry)

    if len(kept) == 0:
        raise ValueError(f"{path} 中没有合法样本。")
    return kept

def mbti_to_4d(m: str):
    """Split a four-letter MBTI code into four binary axis values.

    Each position is compared with one reference letter: ``I``, ``S``, ``F``
    and ``P`` map to 0, and any other character at that position maps to 1.

    Args:
        m: MBTI code; only the first four characters are read.

    Returns:
        A 4-tuple of ints ``(ei, ns, tf, jp)``, each 0 or 1.
    """
    first, second, third, fourth = m[0], m[1], m[2], m[3]
    return (
        int(first != "I"),
        int(second != "S"),
        int(third != "F"),
        int(fourth != "P"),
    )

def truncate_to_budget(tok: "AutoTokenizer", text: str, budget: int) -> str:
    """Cut ``text`` down to at most ``budget`` tokens and return it as a string.

    The text is tokenized without special tokens, the first ``budget`` token
    ids are kept, and those ids are decoded back to text.

    Args:
        tok: Tokenizer; called on the text and used for ``decode``.
        text: Input text; a falsy value (``None``, ``""``) is treated as ``""``.
        budget: Maximum number of tokens to keep.

    Returns:
        The decoded, truncated text, or ``""`` when ``budget`` is not positive.
    """
    if budget <= 0:
        return ""
    source = text if text else ""
    token_ids = tok(source, add_special_tokens=False)["input_ids"]
    kept_ids = token_ids[:budget]
    return tok.decode(kept_ids)

def build_input(item: Dict[str, Any], tok: "AutoTokenizer") -> str:
    """Assemble the prompt text for one sample according to ``INPUT_MODE``.

    The prompt always starts with a ``[POSTS]`` section and ends with the
    ``[TASK]`` line listing all MBTI types. Outside ``"posts_only"`` mode the
    ``[SEMANTIC]``, ``[SENTIMENT]`` and ``[LINGUISTIC]`` sections are inserted
    in between. Every section body is truncated to its entry in ``BUDGET``.

    Args:
        item: Sample dict; ``posts_cleaned`` is used when the key is present,
            otherwise ``posts``.
        tok: Tokenizer forwarded to ``truncate_to_budget``.

    Returns:
        The assembled prompt string.
    """
    # Raw text: use posts_cleaned when the key exists, else fall back to posts.
    raw_posts = item.get("posts_cleaned", item.get("posts", "")) or ""
    sections = [("[POSTS]", truncate_to_budget(tok, raw_posts, BUDGET["posts_cleaned"]))]

    if INPUT_MODE != "posts_only":
        extra_views = (
            ("[SEMANTIC]", "semantic_view"),
            ("[SENTIMENT]", "sentiment_view"),
            ("[LINGUISTIC]", "linguistic_view"),
        )
        for marker, key in extra_views:
            view_text = truncate_to_budget(tok, item.get(key, "") or "", BUDGET[key])
            sections.append((marker, view_text))

    body = "".join(f"{marker}\n{content}\n" for marker, content in sections)
    task_line = f"[TASK] Predict MBTI type among {', '.join(MBTI_16)}."
    return body + task_line

class MBTIDataset(torch.utils.data.Dataset):
    """Map-style dataset that builds and tokenizes the prompt on every access.

    Each item is a dict with ``input_ids``, ``attention_mask`` and an integer
    ``labels`` entry taken from ``MBTI2ID``. No padding is applied here.
    """

    def __init__(self, rows, tokenizer, max_len=512):
        """Store the samples, the tokenizer and the truncation length."""
        self.rows = rows
        self.tok = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of samples."""
        return len(self.rows)

    def __getitem__(self, idx):
        """Build, tokenize and label the sample at position ``idx``."""
        record = self.rows[idx]
        prompt = build_input(record, self.tok)
        label_id = MBTI2ID[record["type"]]
        encoded = self.tok(prompt, truncation=True, max_length=self.max_len)
        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "labels": label_id,
        }
# ============ 指标 ============
def compute_metrics(eval_pred):
    """Compute 16-class and per-axis accuracy metrics from logits and labels.

    Args:
        eval_pred: Either a ``(predictions, labels)`` tuple or an object with
            ``predictions`` and ``label_ids`` attributes. When the predictions
            are themselves a list/tuple, the first element is used.

    Returns:
        Dict with 16-class accuracy and balanced accuracy, micro / macro /
        weighted precision, recall and F1, the accuracy on each of the four
        MBTI axes, and ``acc_4D`` (all four axes correct).
    """
    if isinstance(eval_pred, tuple):
        scores, gold = eval_pred
    else:
        scores, gold = eval_pred.predictions, eval_pred.label_ids
    if isinstance(scores, (list, tuple)):
        scores = scores[0]
    scores = np.asarray(scores)
    gold = np.asarray(gold)

    guess = scores.argmax(-1)
    acc16 = float((guess == gold).mean())
    bal_acc16 = balanced_accuracy_score(gold, guess)

    # (precision, recall, f1) for each averaging mode.
    prf = {
        mode: precision_recall_fscore_support(gold, guess, average=mode, zero_division=0)[:3]
        for mode in ("micro", "macro", "weighted")
    }

    # One row of four axis codes per class id; comparing the rows selected by
    # prediction and by label gives a per-sample, per-axis agreement matrix.
    axis_table = np.array([mbti_to_4d(name) for name in MBTI_16])
    agree = axis_table[guess] == axis_table[gold]
    axis_hits = agree.sum(axis=0)
    full_hits = agree.all(axis=1).sum()
    total = len(gold)

    return {
        "acc_16": acc16,
        "bal_acc_16": bal_acc16,
        "p_micro": prf["micro"][0],
        "r_micro": prf["micro"][1],
        "f1_micro": prf["micro"][2],
        "p_macro": prf["macro"][0],
        "r_macro": prf["macro"][1],
        "f1_macro": prf["macro"][2],
        "p_weighted": prf["weighted"][0],
        "r_weighted": prf["weighted"][1],
        "f1_weighted": prf["weighted"][2],
        "acc_ei": int(axis_hits[0]) / total,
        "acc_ns": int(axis_hits[1]) / total,
        "acc_tf": int(axis_hits[2]) / total,
        "acc_jp": int(axis_hits[3]) / total,
        "acc_4D": int(full_hits) / total,
    }

# ============ 可视化 ============
def plot_confusion_and_roc(y_true, y_prob, class_names, out_dir, suffix=""):
    """Save a confusion-matrix plot and a micro/macro-average ROC plot.

    Two PNG files are written into ``out_dir``:
    ``confusion_matrix{suffix}.png`` and ``roc_micro_macro{suffix}.png``.

    Args:
        y_true: 1-D sequence of integer class ids in ``range(len(class_names))``.
        y_prob: 2-D array of per-class scores, one row per sample and one
            column per class; predictions are the row-wise argmax.
        class_names: Display names, indexed by class id.
        out_dir: Output directory; created when missing.
        suffix: Text appended to both file names and both plot titles.

    Returns:
        None.
    """
    os.makedirs(out_dir, exist_ok=True)
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    class_ids = list(range(len(class_names)))
    y_pred = np.argmax(y_prob, axis=-1)

    # ---- confusion matrix ----
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
    fig_cm, ax_cm = plt.subplots(figsize=(8, 8), dpi=150)
    disp.plot(ax=ax_cm, xticks_rotation=45, cmap="Blues", colorbar=False)
    ax_cm.set_title(f"Confusion Matrix{suffix}")
    fig_cm.tight_layout()
    fig_cm.savefig(os.path.join(out_dir, f"confusion_matrix{suffix}.png"))
    plt.close(fig_cm)

    # ---- one-vs-rest ROC per class ----
    Y_true_bin = label_binarize(y_true, classes=class_ids)
    fpr = {}
    tpr = {}
    roc_auc = {}
    valid_ids = []
    for i in class_ids:
        column = Y_true_bin[:, i]
        positives = int(column.sum())
        # A class with no positives (or no negatives) has an undefined ROC:
        # roc_curve would return NaN rates and poison the macro average.
        if positives == 0 or positives == len(column):
            continue
        fpr[i], tpr[i], _ = roc_curve(column, y_prob[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])
        valid_ids.append(i)

    skipped = [class_names[i] for i in class_ids if i not in roc_auc]
    if skipped:
        print(f"[warn] ROC skipped for classes without both positive and negative samples: {skipped}")

    # ---- micro average: pool every (sample, class) decision ----
    fpr["micro"], tpr["micro"], _ = roc_curve(Y_true_bin.ravel(), y_prob.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    # ---- macro average: mean of the per-class curves on a shared FPR grid ----
    if valid_ids:
        all_fpr = np.unique(np.concatenate([fpr[i] for i in valid_ids]))
        mean_tpr = np.zeros_like(all_fpr)
        for i in valid_ids:
            mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
        mean_tpr /= len(valid_ids)
        fpr["macro"] = all_fpr
        tpr["macro"] = mean_tpr
        roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

    fig_roc, ax_roc = plt.subplots(figsize=(7, 7), dpi=150)
    ax_roc.plot(fpr["micro"], tpr["micro"], label=f"micro-average ROC (AUC = {roc_auc['micro']:.3f})", linewidth=2)
    if "macro" in roc_auc:
        ax_roc.plot(fpr["macro"], tpr["macro"], label=f"macro-average ROC (AUC = {roc_auc['macro']:.3f})", linewidth=2)
    ax_roc.plot([0, 1], [0, 1], "k--", linewidth=1)  # chance diagonal
    ax_roc.set_xlim([0.0, 1.0])
    ax_roc.set_ylim([0.0, 1.05])
    ax_roc.set_xlabel("False Positive Rate")
    ax_roc.set_ylabel("True Positive Rate")
    ax_roc.set_title(f"Multiclass ROC (micro & macro){suffix}")
    ax_roc.legend(loc="lower right")
    fig_roc.tight_layout()
    fig_roc.savefig(os.path.join(out_dir, f"roc_micro_macro{suffix}.png"))
    plt.close(fig_roc)

def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor):
    """Average the hidden states over dim 1, counting only unmasked positions.

    Args:
        last_hidden_state: Tensor whose dim 1 indexes positions and whose last
            dim holds the features.
        attention_mask: Mask with one entry per position (non-zero = keep); it
            is given a trailing singleton dim and cast to the hidden-state type.

    Returns:
        The masked mean over dim 1. The divisor is clamped to at least 1e-6, so
        a fully masked row yields zeros instead of a division by zero.
    """
    weights = attention_mask.unsqueeze(-1).type_as(last_hidden_state)
    kept_positions = torch.clamp(weights.sum(dim=1), min=1e-6)
    weighted_total = torch.sum(last_hidden_state * weights, dim=1)
    return weighted_total / kept_positions

@torch.no_grad()
def extract_embeddings(model, tokenizer, dataset, device="cuda:0", batch_size=4):
    """Encode every sample and return mean-pooled last-layer embeddings.

    The model is switched to eval mode and run without gradients. For each
    batch the last entry of ``hidden_states`` is mean-pooled over the
    non-padded positions with ``mean_pool``.

    Args:
        model: Model that accepts ``output_hidden_states=True`` and exposes
            ``hidden_states`` on its output.
        tokenizer: Tokenizer used by the padding collator.
        dataset: Dataset yielding ``input_ids``, ``attention_mask`` and
            ``labels``.
        device: Device the input tensors are moved to.
        batch_size: Number of samples per forward pass.

    Returns:
        ``(embeddings, labels)`` as NumPy arrays concatenated over all batches
        in dataset order; the embeddings are float32.
    """
    from torch.utils.data import DataLoader

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,  # keep dataset order so embeddings line up with labels
        collate_fn=DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8),
    )
    embs, ys = [], []
    model.eval()
    for batch in loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].cpu().numpy()
        out = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        vec = mean_pool(out.hidden_states[-1], attention_mask)
        # NumPy has no bfloat16 dtype, so a bf16 tensor cannot be converted
        # directly; casting to float32 first is a no-op for fp32 outputs.
        embs.append(vec.float().cpu().numpy())
        ys.append(labels)
    return np.concatenate(embs, 0), np.concatenate(ys, 0)

def plot_tsne(embs, labels, class_names, out_png, title="t-SNE (PCA init)"):
    """Project embeddings to 2-D with PCA followed by t-SNE and save a scatter plot.

    The embeddings are first reduced with PCA (at most 50 components), then
    embedded in 2-D with t-SNE (PCA initialisation, seeded with ``SEED``).
    Points are coloured by class; classes with no samples are left out.

    Args:
        embs: 2-D array, one row per sample.
        labels: Integer class ids, one per row of ``embs``.
        class_names: Display names, indexed by class id.
        out_png: Path of the image to write; its directory is created if needed.
        title: Plot title.

    Returns:
        None.
    """
    embs = np.asarray(embs)
    labels = np.asarray(labels)  # the boolean masks below need an array, not a list
    n = len(labels)

    # PCA cannot produce more components than there are samples or features.
    k = min(50, embs.shape[1], n)
    embs50 = PCA(n_components=k, random_state=SEED).fit_transform(embs)

    # Same heuristic as before (n // 20 limited to [5, 50]), additionally kept
    # below the sample count, which t-SNE requires for small splits.
    perplexity = max(5, min(50, n // 20))
    perplexity = min(perplexity, max(1, n - 1))
    tsne2 = TSNE(n_components=2, init="pca", random_state=SEED, perplexity=perplexity, learning_rate="auto")
    X2 = tsne2.fit_transform(embs50)

    fig, ax = plt.subplots(figsize=(8, 7), dpi=150)
    for cid, cname in enumerate(class_names):
        idx = labels == cid
        if idx.sum() == 0:
            continue
        ax.scatter(X2[idx, 0], X2[idx, 1], s=10, alpha=0.7, label=cname)
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.legend(markerscale=2, fontsize=8, ncol=2, frameon=False)
    fig.tight_layout()

    out_dir = os.path.dirname(out_png)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(out_png)
    plt.close(fig)

# ============ 主流程 ============
def _predict_probs(trainer, dataset):
    """Run ``trainer.predict`` and return ``(softmax probabilities, true label ids)``."""
    output = trainer.predict(dataset)
    preds = output.predictions
    # Predictions may come back as a list/tuple; the logits are its first element.
    logits = preds[0] if isinstance(preds, (list, tuple)) else preds
    probs = F.softmax(torch.tensor(logits, dtype=torch.float32), dim=-1).cpu().numpy()
    return probs, output.label_ids


def _save_classification_report(y_true, probs, out_path):
    """Write the per-class report for the argmax predictions to ``out_path`` and return its text."""
    pred_ids = probs.argmax(-1)
    # Explicit labels keep the report at 16 rows; without them the call raises
    # when a split (or the predictions) does not contain every class.
    report = classification_report(
        y_true, pred_ids,
        labels=list(range(len(MBTI_16))),
        target_names=MBTI_16, digits=4, zero_division=0,
    )
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(report)
    return report


def main():
    """Fine-tune the classifier with LoRA, evaluate on VAL and TEST, and save artifacts.

    Steps, in order:
      1. Load the train / val / test JSON splits.
      2. Load the tokenizer and the (optionally 4-bit quantized) model and wrap
         the model with LoRA adapters.
      3. Train with ``Trainer``, selecting the best checkpoint by ``eval_acc_4D``.
      4. For VAL and TEST: save confusion-matrix / ROC plots and a
         classification report, and print the metrics.
      5. Save t-SNE plots of the pooled embeddings, save the adapter, and run
         one example prediction.

    All artifacts are written to ``OUTPUT_DIR``. GPU 0 is used throughout.
    """
    torch.cuda.set_device(0)
    set_seed(SEED)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Reuse the same data splits for every input mode (fair comparison).
    train_rows = load_rows("train消融.json")
    val_rows   = load_rows("val消融.json")
    test_rows  = load_rows("test消融.json")

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True, **HF_KW)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=USE_4BIT,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    ) if USE_4BIT else None

    model_kwargs = dict(
        num_labels=16,
        quantization_config=quant_cfg,
        device_map={"": "cuda:0"},
        low_cpu_mem_usage=True,
        **HF_KW,
    )
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, **model_kwargs)
    model.config.pad_token_id = tokenizer.pad_token_id
    model.config.use_cache = False

    from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)
    try:
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    except Exception as exc:
        # Best effort: training still works without checkpointing, but say so
        # instead of hiding the failure.
        print(f"[warn] gradient checkpointing could not be enabled: {exc!r}")
    peft_cfg = LoraConfig(
        task_type=TaskType.SEQ_CLS,
        r=LORA_R, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT,
        target_modules=TARGET_MODULES, bias="none"
    )
    model = get_peft_model(model, peft_cfg).to("cuda:0")

    # Replace .to() with a no-op so the Trainer cannot move the model again
    # (author's workaround; the model is already placed on cuda:0 above).
    def _noop_to(self, *args, **kwargs): return self
    model.to = _noop_to.__get__(model, type(model))

    train_ds = MBTIDataset(train_rows, tokenizer, max_len=MAX_LEN)
    val_ds   = MBTIDataset(val_rows,   tokenizer, max_len=MAX_LEN)
    test_ds  = MBTIDataset(test_rows,  tokenizer, max_len=MAX_LEN)
    collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)

    args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        per_device_train_batch_size=BSZ_TRN,
        per_device_eval_batch_size=BSZ_EVAL,
        gradient_accumulation_steps=GRAD_ACCUM,
        num_train_epochs=EPOCHS,
        learning_rate=LR,
        warmup_ratio=WARMUP_RATIO,
        weight_decay=WEIGHT_DECAY,
        lr_scheduler_type="linear",
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=2,
        logging_steps=50,
        bf16=False, fp16=False,
        report_to="none",
        load_best_model_at_end=True,
        metric_for_best_model="eval_acc_4D",
        greater_is_better=True,
        optim="paged_adamw_8bit",
        eval_accumulation_steps=12,
        gradient_checkpointing=False,
    )

    # NOTE(review): recent Transformers releases appear to deprecate `tokenizer=`
    # in favour of `processing_class=` — confirm against the pinned version.
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        tokenizer=tokenizer,
        data_collator=collator,
        compute_metrics=compute_metrics,
    )

    # Train
    trainer.train()

    # ===== VAL =====
    val_probs, val_y_true = _predict_probs(trainer, val_ds)
    plot_confusion_and_roc(val_y_true, val_probs, MBTI_16, OUTPUT_DIR, suffix="")
    print(f"Saved: {os.path.join(OUTPUT_DIR, 'confusion_matrix.png')}")
    print(f"Saved: {os.path.join(OUTPUT_DIR, 'roc_micro_macro.png')}")

    eval_metrics = trainer.evaluate(eval_dataset=val_ds)
    print("\n=== Final Eval (on VAL) ===")
    for k, v in eval_metrics.items():
        try:
            print(f"{k}: {float(v):.4f}")
        except (TypeError, ValueError):
            # Value that cannot be converted to float: print it as-is.
            print(k, v)

    val_report = _save_classification_report(
        val_y_true, val_probs, os.path.join(OUTPUT_DIR, "classification_report_val.txt")
    )
    print("\n=== VAL Classification Report ===\n", val_report)

    # ===== TEST =====
    test_probs, test_y_true = _predict_probs(trainer, test_ds)
    plot_confusion_and_roc(test_y_true, test_probs, MBTI_16, OUTPUT_DIR, suffix="_test")
    print(f"Saved: {os.path.join(OUTPUT_DIR, 'confusion_matrix_test.png')}")
    print(f"Saved: {os.path.join(OUTPUT_DIR, 'roc_micro_macro_test.png')}")

    test_report = _save_classification_report(
        test_y_true, test_probs, os.path.join(OUTPUT_DIR, "classification_report_test.txt")
    )
    print("\n=== TEST Classification Report ===\n", test_report)

    # Per-axis (4D) metrics: reuse compute_metrics instead of a second copy of
    # the counting loop. The probabilities are passed so the argmax matches the
    # predictions used for the report above.
    test_metrics = compute_metrics((test_probs, test_y_true))
    print("\n=== Final Test (held-out TEST) ===")
    print(f"acc_16: {test_metrics['acc_16']:.4f}")
    print(
        f"acc_ei: {test_metrics['acc_ei']:.4f}  acc_ns: {test_metrics['acc_ns']:.4f}  "
        f"acc_tf: {test_metrics['acc_tf']:.4f}  acc_jp: {test_metrics['acc_jp']:.4f}  "
        f"acc_4D: {test_metrics['acc_4D']:.4f}"
    )

    # ===== t-SNE of the pooled embeddings (optional) =====
    val_embs, val_y = extract_embeddings(model, tokenizer, val_ds, device="cuda:0", batch_size=BSZ_EVAL)
    plot_tsne(val_embs, val_y, MBTI_16, os.path.join(OUTPUT_DIR, "tsne_val.png"), title=f"t-SNE (VAL) - {TAG}")

    test_embs, test_y = extract_embeddings(model, tokenizer, test_ds, device="cuda:0", batch_size=BSZ_EVAL)
    plot_tsne(test_embs, test_y, MBTI_16, os.path.join(OUTPUT_DIR, "tsne_test.png"), title=f"t-SNE (TEST) - {TAG}")
    print(f"Saved: {os.path.join(OUTPUT_DIR, 'tsne_val.png')}")
    print(f"Saved: {os.path.join(OUTPUT_DIR, 'tsne_test.png')}")

    # Save the LoRA adapter
    trainer.save_model(OUTPUT_DIR)
    print(f"\n✅ LoRA adapter saved to: {OUTPUT_DIR}")

    # Example prediction on the first TEST sample
    model.eval()
    sample = test_rows[0]
    text = build_input(sample, tokenizer)
    batch = tokenizer(text, return_tensors="pt", truncation=True, max_length=MAX_LEN)
    batch = {k: v.to("cuda:0") for k, v in batch.items()}
    with torch.no_grad():
        logits = model(**batch).logits
        pred_id = int(torch.argmax(logits, dim=-1))
        pred_mbti = MBTI_16[pred_id]
    print("\n[Inference on TEST sample]")
    print("原标签:", sample["type"], " | 预测:", pred_mbti)

# Run the full pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
