# 消融实验 带解释的原数据

In [None]:
import json
from pathlib import Path

ORIG_PATH = "/home/hli962/Chunhou_Project/CL/dataset_mbti/mbti_dataset.json"
AUG_PATH  = "mbti_sample_with_all_views.json"
OUT_PATH  = "mbti_sample_with_all_views_TOPN.json"

# 若你的 JSON 是 {"data":[...]} 结构，则会自动取 data；如果本身就是列表也可
def load_list(path):
    """Load a JSON file and return its records as a list.

    Accepts either a top-level JSON list, or a JSON object whose ``"data"``
    entry is a list (in which case that inner list is returned).

    Raises:
        ValueError: if the document has neither of those two shapes.
    """
    with open(path, "r", encoding="utf-8") as handle:
        payload = json.load(handle)

    # Bare list: use as-is.
    if isinstance(payload, list):
        return payload

    # Wrapped form: {"data": [...]}.
    if isinstance(payload, dict):
        inner = payload.get("data")
        if isinstance(inner, list):
            return inner

    raise ValueError(f"{path} 应该是列表，或包含 data 为列表的 JSON")

def quick_match_rate(orig_list, cand_list, id_field="id", pair_fields=("type","post")):
    """Optional helper: estimate how many candidate records also occur in the originals.

    Each record is reduced to a key: ``("ID", str(id))`` when ``id_field`` is
    set and present on the record, otherwise ``("PAIR", "a||b")`` built from
    the two ``pair_fields`` when both are present, otherwise ``None``.

    Returns:
        ``(None, 0, 0)`` when no original record yields a key; otherwise
        ``(True, hits, total)`` where ``total`` is the number of candidates
        and ``hits`` is how many of their keys appear among the original keys.
    """
    def fingerprint(rec):
        is_record = isinstance(rec, dict)
        if id_field and is_record and rec.get(id_field) is not None:
            return ("ID", str(rec[id_field]))
        first, second = pair_fields
        if not is_record:
            return None
        if rec.get(first) is None or rec.get(second) is None:
            return None
        return ("PAIR", f"{rec[first]}||{rec[second]}")

    # Collect the keys of all originals that can be keyed at all.
    known = set()
    for rec in orig_list:
        fp = fingerprint(rec)
        if fp is not None:
            known.add(fp)
    if not known:
        return None, 0, 0

    # Candidates without a key (None) still count toward the total but never hit.
    total = 0
    hits = 0
    for rec in cand_list:
        total += 1
        if fingerprint(rec) in known:
            hits += 1
    return True, hits, total

def main():
    """Keep the first N augmented records, where N is the size of the original set.

    Loads ORIG_PATH and AUG_PATH, writes the sliced list to OUT_PATH as a
    plain JSON list, prints a summary, and then prints an optional overlap
    check between the originals and the kept slice.
    """
    originals = load_list(ORIG_PATH)
    augmented = load_list(AUG_PATH)

    N = len(originals)
    head = augmented[:N]

    # Output stays a bare list; wrap it in {"data": ...} yourself if needed.
    with open(OUT_PATH, "w", encoding="utf-8") as sink:
        json.dump(head, sink, ensure_ascii=False, indent=2)

    print(f"原始数据条数 N = {N}")
    print(f"已从增强集保留前 N 条，写入：{OUT_PATH}（共 {len(head)} 条）")

    # Optional consistency check; skipped when the originals yield no usable keys.
    ok, hit, tot = quick_match_rate(originals, head)
    if not ok:
        return
    rate = hit / tot if tot else 0
    print(f"快速匹配命中：{hit}/{tot}（≈ {rate:.1%}）。"
          f"若命中率很低，说明增强集的前N条未必对应原始样本顺序。")

if __name__ == "__main__":
    main()


In [None]:
import json
import re
from sklearn.model_selection import train_test_split
from collections import Counter

# ========== 工具函数 ==========
def norm_text(s: str) -> str:
    """Normalize a posts_cleaned string for matching.

    Unifies line endings, collapses every whitespace run to a single space,
    strips the ends, and lower-cases the result.
    """
    unified = s.replace("\r\n", "\n").replace("\r", "\n")
    collapsed = re.sub(r"\s+", " ", unified)
    trimmed = collapsed.strip()
    return trimmed.lower()

def value_counts(data, key="type"):
    """Return ``{value: "count (share%)"}`` for ``key`` across the records in ``data``.

    Records lacking ``key`` are counted under ``None``. Shares are formatted
    with two decimals.
    """
    tally = Counter()
    for record in data:
        tally[record.get(key, None)] += 1

    # Guard the denominator so an empty input cannot divide by zero.
    denom = sum(tally.values()) or 1

    summary = {}
    for label, count in tally.items():
        summary[label] = f"{count} ({count/denom:.2%})"
    return summary

# ========== 1. Load data ==========
with open("mbti_sample_with_all_views_TOPN.json", "r", encoding="utf-8") as f:
    full_data = json.load(f)

with open("test对应的原始数据.json", "r", encoding="utf-8") as f:
    test_data = json.load(f)

print("完整数据:", len(full_data))
print("测试集:", len(test_data))

# ========== 2. Drop every record whose normalized posts_cleaned is in the test set ==========
test_keys = {norm_text(d["posts_cleaned"]) for d in test_data}
dev_data = [d for d in full_data if norm_text(d["posts_cleaned"]) not in test_keys]

print("去掉测试集后，剩余 dev:", len(dev_data))

# ========== 3. Split the remaining dev data into train/val ==========
X = dev_data
y = [d["type"] for d in dev_data]

try:
    train_data, val_data = train_test_split(
        X,
        test_size=0.2,       # validation share; adjust as needed
        random_state=42,
        stratify=y           # stratify by class label
    )
except ValueError as e:
    # Stratification raised; fall back to a plain random split.
    print("⚠️ 分层失败，改为随机拆分:", e)
    train_data, val_data = train_test_split(
        X,
        test_size=0.2,
        random_state=42,
        stratify=None
    )

# ========== 4. Print split statistics ==========
print("\n样本数分布：")
print("Train:", len(train_data), value_counts(train_data))
print("Val  :", len(val_data), value_counts(val_data))
print("Test :", len(test_data), value_counts(test_data))

# ========== 5. Save the splits ==========
# One mapping drives both the writes and the final message, so the reported
# file names cannot drift from the names actually written.
out_files = {
    "train消融.json": train_data,
    "val消融.json": val_data,
    "test消融.json": test_data,
}
for out_name, out_rows in out_files.items():
    with open(out_name, "w", encoding="utf-8") as f:
        json.dump(out_rows, f, ensure_ascii=False, indent=2)

print("\n已保存 " + " / ".join(out_files))


In [None]:
# -*- coding: utf-8 -*-
"""
将 mbti_sample_with_all_views.json 分层切分为 8:1:1 的 train/val/test
随机种子固定为 42，确保可复现
"""
import json
from pathlib import Path
from sklearn.model_selection import train_test_split

# The 16 MBTI type codes accepted as valid labels; list order fixes each type's integer id.
MBTI_16 = [
    "INTJ","INTP","ENTJ","ENTP","INFJ","INFP","ENFJ","ENFP",
    "ISTJ","ISFJ","ESTJ","ESFJ","ISTP","ISFP","ESTP","ESFP"
]
# Type code -> position in MBTI_16; load_rows uses membership in this dict to filter records.
MBTI2ID = {t: i for i, t in enumerate(MBTI_16)}

def load_rows(path: Path):
    """Read a JSON list from ``path`` and keep only dict records with a known MBTI ``type``.

    Raises:
        ValueError: if no record survives the filter.
    """
    with path.open("r", encoding="utf-8") as handle:
        loaded = json.load(handle)

    valid = []
    for rec in loaded:
        if isinstance(rec, dict) and rec.get("type") in MBTI2ID:
            valid.append(rec)

    if len(valid) == 0:
        raise ValueError("输入文件中没有合法样本（缺少 'type' 或不在 16 类里）。")
    return valid

def save_json(rows, path: Path):
    """Write ``rows`` to ``path`` as indented UTF-8 JSON, creating parent directories first."""
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    with open(path, mode="w", encoding="utf-8") as handle:
        json.dump(rows, handle, indent=2, ensure_ascii=False)

def main(
    input_file="mbti_sample_with_all_views_TOPN.json",
    outdir=".",
    seed=42
):
    """Split ``input_file`` into stratified train/val/test files of roughly 8:1:1.

    Args:
        input_file: JSON list of records; filtered through ``load_rows``.
        outdir: directory that receives ``train_topn.json``, ``val_topn.json``
            and ``test_topn.json``.
        seed: ``random_state`` passed to both ``train_test_split`` calls.
    """
    inp = Path(input_file)
    out = Path(outdir)

    rows = load_rows(inp)
    y = [r["type"] for r in rows]

    # First carve out 10% as TEST (stratified by type).
    trainval_rows, test_rows = train_test_split(
        rows, test_size=0.10, random_state=seed, stratify=y
    )
    # Then take 0.1 / 0.9 of the remaining 90% as VAL, i.e. about 10% of the total.
    y_trainval = [r["type"] for r in trainval_rows]
    train_rows, val_rows = train_test_split(
        trainval_rows, test_size=0.1111111111, random_state=seed, stratify=y_trainval
    )

    # Report the path that was actually written; the previous messages named
    # train.json / val.json / test.json, which this function never creates.
    splits = (
        ("train_topn.json", train_rows),
        ("val_topn.json", val_rows),
        ("test_topn.json", test_rows),
    )
    for file_name, split_rows in splits:
        target = out / file_name
        save_json(split_rows, target)
        print(f"✅ 已保存：{target}（{len(split_rows)} 条）")

if __name__ == "__main__":
    main()


In [None]:
# -*- coding: utf-8 -*-
"""
DeepSeek-R1-Distill-Qwen-1.5B + QLoRA(4bit) + LoRA
读取 train_topn.json / val_topn.json / test_topn.json 进行训练与评测（VAL & TEST）
Transformers==4.55  (peft / bitsandbytes 按需安装)
"""

import os, json
from typing import Dict, Any

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

import numpy as np
import torch
import torch.nn.functional as F

from sklearn.metrics import (
    confusion_matrix, ConfusionMatrixDisplay,
    roc_curve, auc,
    precision_recall_fscore_support, classification_report, balanced_accuracy_score
)
from sklearn.preprocessing import label_binarize
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    BitsAndBytesConfig,
    DataCollatorWithPadding,
    Trainer, TrainingArguments,
    set_seed,
)

# Environment switches, set before any model is constructed.
# NOTE(review): presumably these turn off Accelerate mixed precision and the
# bitsandbytes welcome banner — confirm against those libraries' docs.
os.environ["ACCELERATE_MIXED_PRECISION"] = "no"
os.environ["BITSANDBYTES_NOWELCOME"] = "1"

# ============ Configuration ============
# Checkpoint id handed to the tokenizer / model from_pretrained calls in main().
MODEL_NAME   = "Qwen/Qwen2.5-1.5B-Instruct"
# max_length used when MBTIDataset tokenizes the assembled prompt.
MAX_LEN      = 440
# Per-field token budgets applied by build_input via truncate_to_budget.
# NOTE(review): the budgets sum to 440 == MAX_LEN, so the section tags and the
# [TASK] line that build_input adds have no headroom when every field fills its
# budget; the tokenizer's truncation then cuts the prompt — confirm which end is cut.
BUDGET = {"posts_cleaned": 320, "semantic_view": 64, "sentiment_view": 32, "linguistic_view": 24}

# Training hyper-parameters consumed by TrainingArguments in main().
SEED         = 42
EPOCHS       = 4
LR           = 2e-4
BSZ_TRN      = 8
BSZ_EVAL     = 4
GRAD_ACCUM   = 1
WARMUP_RATIO = 0.06
WEIGHT_DECAY = 0.01
# Output directory for checkpoints, plots, reports and the saved adapter.
OUTPUT_DIR   = "mbti_lora_qwen1.5b-kaggle_xiaorongExplain_ckpt_new"

# Quantization / LoRA settings consumed by BitsAndBytesConfig and LoraConfig in main().
USE_4BIT     = True
LORA_R       = 16
LORA_ALPHA   = 32
LORA_DROPOUT = 0.05
TARGET_MODULES = ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]

# Optional Hugging Face token read from the environment; when set it is passed
# as token=... to the from_pretrained calls.
HF_TOKEN = os.getenv("HF_TOKEN")
HF_KW = {"token": HF_TOKEN} if HF_TOKEN else {}

# The 16 class labels; list order defines the integer class ids.
MBTI_16 = [
    "INTJ","INTP","ENTJ","ENTP","INFJ","INFP","ENFJ","ENFP",
    "ISTJ","ISFJ","ESTJ","ESFJ","ISTP","ISFP","ESTP","ESFP"
]
MBTI2ID = {t: i for i, t in enumerate(MBTI_16)}

# ============ 基础函数 ============
def load_rows(path: str):
    """Load a JSON list from ``path`` and keep only dict records whose ``type`` is a known MBTI code.

    Raises:
        ValueError: if the filtered list is empty.
    """
    with open(path, "r", encoding="utf-8") as handle:
        loaded = json.load(handle)

    kept = []
    for rec in loaded:
        if not isinstance(rec, dict):
            continue
        if rec.get("type") in MBTI2ID:
            kept.append(rec)

    if len(kept) == 0:
        raise ValueError(f"{path} 中没有合法样本。")
    return kept

def mbti_to_4d(m: str):
    """Encode a four-letter MBTI code as a tuple of four 0/1 flags.

    Position by position, the letters ``I``, ``S``, ``F`` and ``P`` map to 0;
    any other character in that position maps to 1.
    """
    zero_letters = "ISFP"
    return tuple(0 if m[pos] == zero_letters[pos] else 1 for pos in range(4))

def truncate_to_budget(tok: AutoTokenizer, text: str, budget: int) -> str:
    """Return ``text`` cut down to its first ``budget`` tokens.

    The text (empty string when falsy) is tokenized without special tokens,
    the id list is sliced to ``budget`` entries, and the slice is decoded back
    to a string.
    """
    source = text if text else ""
    token_ids = tok(source, add_special_tokens=False)["input_ids"]
    kept = token_ids[:budget]
    return tok.decode(kept)

def build_input(item: Dict[str, Any], tok: AutoTokenizer) -> str:
    """Assemble the classification prompt for one record.

    Each of the four text fields is clipped to its BUDGET entry and placed
    under its own tag, followed by a task line listing all MBTI types.
    ``posts`` is only consulted when the record has no ``posts_cleaned`` key.
    """
    def clip(raw, field):
        # Falsy values (None, "") are treated as empty text.
        return truncate_to_budget(tok, raw or "", BUDGET[field])

    p = clip(item.get("posts_cleaned", item.get("posts", "")), "posts_cleaned")
    sem = clip(item.get("semantic_view", ""), "semantic_view")
    sen = clip(item.get("sentiment_view", ""), "sentiment_view")
    lin = clip(item.get("linguistic_view", ""), "linguistic_view")

    prompt = (
        f"[POSTS]\n{p}\n[SEMANTIC]\n{sem}\n[SENTIMENT]\n{sen}\n[LINGUISTIC]\n{lin}\n"
        f"[TASK] Predict MBTI type among {', '.join(MBTI_16)}."
    )
    return prompt

class MBTIDataset(torch.utils.data.Dataset):
    """Map-style dataset that turns MBTI records into tokenized classification examples.

    Items are built lazily: the prompt is assembled with ``build_input`` and
    tokenized (truncated to ``max_len``) on each access. Padding is not applied
    here.
    """

    def __init__(self, rows, tokenizer, max_len=512):
        self.rows, self.tok, self.max_len = rows, tokenizer, max_len

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        record = self.rows[idx]
        prompt = build_input(record, self.tok)
        label_id = MBTI2ID[record["type"]]
        encoded = self.tok(prompt, truncation=True, max_length=self.max_len)
        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "labels": label_id,
        }

# ============ 评价指标 ============
def compute_metrics(eval_pred):
    """Compute 16-class and per-dimension MBTI metrics from logits and labels.

    ``eval_pred`` is either a ``(predictions, labels)`` tuple or an object with
    ``predictions`` and ``label_ids`` attributes. If the predictions are a
    list/tuple, only the first element is used.

    Returns:
        dict with accuracy, balanced accuracy, micro/macro/weighted
        precision-recall-F1 over the 16 classes, per-letter accuracies
        (``acc_ei``, ``acc_ns``, ``acc_tf``, ``acc_jp``) and ``acc_4D``
        (all four letters correct).
    """
    if isinstance(eval_pred, tuple):
        logits, gold = eval_pred
    else:
        logits, gold = eval_pred.predictions, eval_pred.label_ids
    if isinstance(logits, (list, tuple)):
        logits = logits[0]
    logits = np.asarray(logits)
    gold = np.asarray(gold)

    pred_ids = logits.argmax(-1)
    metrics = {
        "acc_16": float((pred_ids == gold).mean()),
        "bal_acc_16": balanced_accuracy_score(gold, pred_ids),
    }

    # 16-class precision / recall / F1 under each averaging scheme.
    for avg in ("micro", "macro", "weighted"):
        prec, rec, f1, _ = precision_recall_fscore_support(
            gold, pred_ids, average=avg, zero_division=0
        )
        metrics[f"p_{avg}"] = prec
        metrics[f"r_{avg}"] = rec
        metrics[f"f1_{avg}"] = f1

    # Per-dimension agreement between predicted and true type codes.
    pred_codes = [mbti_to_4d(MBTI_16[i]) for i in pred_ids]
    true_codes = [mbti_to_4d(MBTI_16[i]) for i in gold]
    dim_hits = [0, 0, 0, 0]
    full_hits = 0
    for p4, t4 in zip(pred_codes, true_codes):
        same = [a == b for a, b in zip(p4, t4)]
        for pos, flag in enumerate(same):
            dim_hits[pos] += flag
        full_hits += all(same)

    n = len(gold)
    for name, hits in zip(("acc_ei", "acc_ns", "acc_tf", "acc_jp"), dim_hits):
        metrics[name] = hits / n
    metrics["acc_4D"] = full_hits / n
    return metrics

# ============ 作图：混淆矩阵 & ROC ============
def plot_confusion_and_roc(y_true, y_prob, class_names, out_dir, suffix=""):
    """Save a confusion-matrix plot and a micro/macro-average ROC plot into ``out_dir``.

    Args:
        y_true: integer class ids, one per sample.
        y_prob: per-class scores; indexed as ``y_prob[:, i]`` and arg-maxed over
            the last axis to obtain predictions.
        class_names: display names; class ids are taken to be ``0..len(class_names)-1``.
        out_dir: output directory, created if missing.
        suffix: appended to both the plot titles and the file names.

    Writes ``confusion_matrix{suffix}.png`` and ``roc_micro_macro{suffix}.png``.
    """
    os.makedirs(out_dir, exist_ok=True)
    y_pred = np.argmax(y_prob, axis=-1)

    # --- Confusion matrix over all classes (including ones absent from y_true) ---
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
    fig_cm, ax_cm = plt.subplots(figsize=(8, 8), dpi=150)
    disp.plot(ax=ax_cm, xticks_rotation=45, cmap="Blues", colorbar=False)
    ax_cm.set_title(f"Confusion Matrix{suffix}")
    fig_cm.tight_layout()
    fig_cm.savefig(os.path.join(out_dir, f"confusion_matrix{suffix}.png"))
    plt.close(fig_cm)

    # --- One-vs-rest ROC per class ---
    # NOTE(review): for a class with no positive samples in y_true the per-class
    # curve is presumably undefined, which would propagate into the macro
    # average below — confirm how roc_curve behaves in that case.
    Y_true_bin = label_binarize(y_true, classes=list(range(len(class_names))))
    fpr = {}; tpr = {}; roc_auc = {}
    for i in range(len(class_names)):
        fpr[i], tpr[i], _ = roc_curve(Y_true_bin[:, i], y_prob[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # Micro average: pool every (sample, class) decision into one binary problem.
    fpr["micro"], tpr["micro"], _ = roc_curve(Y_true_bin.ravel(), y_prob.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    # Macro average: interpolate each class curve onto the union of FPR points,
    # then average the TPRs with equal weight per class.
    all_fpr = np.unique(np.concatenate([fpr[i] for i in range(len(class_names))]))
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(len(class_names)):
        mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
    mean_tpr /= len(class_names)
    fpr["macro"] = all_fpr
    tpr["macro"] = mean_tpr
    roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

    # --- Plot only the two averaged curves plus the chance diagonal ---
    fig_roc, ax_roc = plt.subplots(figsize=(7, 7), dpi=150)
    ax_roc.plot(fpr["micro"], tpr["micro"],
                label=f"micro-average ROC (AUC = {roc_auc['micro']:.3f})", linewidth=2)
    ax_roc.plot(fpr["macro"], tpr["macro"],
                label=f"macro-average ROC (AUC = {roc_auc['macro']:.3f})", linewidth=2)
    ax_roc.plot([0, 1], [0, 1], "k--", linewidth=1)
    ax_roc.set_xlim([0.0, 1.0]); ax_roc.set_ylim([0.0, 1.05])
    ax_roc.set_xlabel("False Positive Rate"); ax_roc.set_ylabel("True Positive Rate")
    ax_roc.set_title(f"Multiclass ROC (micro & macro){suffix}")
    ax_roc.legend(loc="lower right")
    fig_roc.tight_layout()
    fig_roc.savefig(os.path.join(out_dir, f"roc_micro_macro{suffix}.png"))
    plt.close(fig_roc)

# ============ 向量抽取 & t-SNE ============
def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor):
    """Average token vectors over the sequence axis, counting only unmasked positions.

    ``last_hidden_state`` is ``[B, T, H]`` and ``attention_mask`` is ``[B, T]``;
    the result is ``[B, H]``. The divisor is clamped to at least 1e-6, so a row
    whose mask is all zero yields zeros instead of dividing by zero.
    """
    weights = attention_mask[..., None].type_as(last_hidden_state)   # [B, T, 1]
    weighted_sum = torch.sum(last_hidden_state * weights, dim=1)     # [B, H]
    token_count = torch.clamp(torch.sum(weights, dim=1), min=1e-6)   # [B, 1]
    return weighted_sum / token_count

@torch.no_grad()
def extract_embeddings(model, tokenizer, dataset, device="cuda:0", batch_size=4):
    """Run ``model`` over ``dataset`` and return mean-pooled last-layer embeddings.

    Args:
        model: called with ``input_ids``, ``attention_mask`` and
            ``output_hidden_states=True``; its output must expose ``hidden_states``.
        tokenizer: used by the padding collator.
        dataset: yields dicts with ``input_ids``, ``attention_mask`` and ``labels``.
        device: device the input tensors are moved to.
        batch_size: evaluation batch size (batches are not shuffled).

    Returns:
        ``(embs, ys)``: NumPy arrays with one embedding row and one label per
        sample, in dataset order. Puts ``model`` into eval mode as a side effect.
    """
    from torch.utils.data import DataLoader
    dl = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                    collate_fn=DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8))
    embs = []
    ys   = []
    model.eval()
    for batch in dl:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].cpu().numpy()

        out = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        last_h = out.hidden_states[-1]  # [B,T,H]
        vec = mean_pool(last_h, attention_mask)  # [B,H]
        # Upcast to float32 before converting: NumPy has no bfloat16 dtype, so a
        # bf16 hidden state would make .numpy() raise. No-op for float32 tensors.
        embs.append(vec.float().cpu().numpy())
        ys.append(labels)
    embs = np.concatenate(embs, axis=0)
    ys   = np.concatenate(ys, axis=0)
    return embs, ys

def plot_tsne(embs, labels, class_names, out_png, title="t-SNE (PCA init)"):
    """Project embeddings to 2-D (PCA to at most 50 dims, then t-SNE) and save a scatter plot.

    Args:
        embs: 2-D array of embeddings, one row per sample.
        labels: integer class id per sample (any array-like).
        class_names: names indexed by class id; classes with no samples are skipped.
        out_png: path of the image to write.
        title: plot title.
    """
    labels = np.asarray(labels)  # boolean masking below needs an array, not a list
    n = len(labels)
    n_samples, n_features = embs.shape

    # PCA cannot produce more components than min(n_samples, n_features),
    # so small evaluation sets must lower the 50-component target.
    k = min(50, n_features, n_samples)
    embs50 = PCA(n_components=k, random_state=SEED).fit_transform(embs)

    perplexity = max(5, min(50, n // 20))  # rule of thumb
    # t-SNE requires perplexity < n_samples; only affects very small inputs.
    perplexity = min(perplexity, max(1, n_samples - 1))
    tsne2 = TSNE(n_components=2, init="pca", random_state=SEED, perplexity=perplexity, learning_rate="auto")
    X2 = tsne2.fit_transform(embs50)

    plt.figure(figsize=(8, 7), dpi=150)
    for cid, cname in enumerate(class_names):
        idx = (labels == cid)
        if idx.sum() == 0:
            continue
        plt.scatter(X2[idx, 0], X2[idx, 1], s=10, alpha=0.7, label=cname)
    plt.title(title)
    plt.xticks([]); plt.yticks([])
    plt.legend(markerscale=2, fontsize=8, ncol=2, frameon=False)
    plt.tight_layout()
    plt.savefig(out_png)
    plt.close()

# ============ 主流程 ============
def main():
    """Fine-tune MODEL_NAME with LoRA on the ablation splits and evaluate on VAL and TEST.

    Steps: load the three split files, build the tokenizer and the (optionally
    4-bit quantized) sequence-classification model, wrap it with LoRA, train
    with ``Trainer``, then for VAL and TEST write confusion-matrix / ROC plots,
    classification reports and t-SNE plots into OUTPUT_DIR, save the model via
    the trainer, and finally run one sample prediction from TEST.
    """
    torch.cuda.set_device(0)
    set_seed(SEED)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Load the pre-split files.
    train_rows = load_rows("train消融.json")
    val_rows   = load_rows("val消融.json")
    test_rows  = load_rows("test消融.json")

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True, **HF_KW)
    if tokenizer.pad_token is None:
        # No pad token defined: reuse EOS so batches can be padded.
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=USE_4BIT,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    ) if USE_4BIT else None

    model_kwargs = dict(
        num_labels=16,
        quantization_config=quant_cfg,
        device_map={"": "cuda:0"},
        low_cpu_mem_usage=True,
        **HF_KW,
    )
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, **model_kwargs)
    model.config.pad_token_id = tokenizer.pad_token_id
    model.config.use_cache = False

    from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)
    try:
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    except Exception as exc:
        # Best-effort: training continues without checkpointing, but the
        # failure is reported instead of being silently discarded.
        print(f"[warn] gradient checkpointing not enabled: {exc}")
    peft_cfg = LoraConfig(
        task_type=TaskType.SEQ_CLS,
        r=LORA_R, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT,
        target_modules=TARGET_MODULES, bias="none"
    )
    model = get_peft_model(model, peft_cfg)
    model = model.to("cuda:0")

    # Keep Trainer from calling .to() a second time on the quantized + PEFT model.
    def _noop_to(self, *args, **kwargs): return self
    model.to = _noop_to.__get__(model, type(model))

    train_ds = MBTIDataset(train_rows, tokenizer, max_len=MAX_LEN)
    val_ds   = MBTIDataset(val_rows,   tokenizer, max_len=MAX_LEN)
    test_ds  = MBTIDataset(test_rows,  tokenizer, max_len=MAX_LEN)
    collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)

    args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        per_device_train_batch_size=BSZ_TRN,
        per_device_eval_batch_size=BSZ_EVAL,
        gradient_accumulation_steps=GRAD_ACCUM,
        num_train_epochs=EPOCHS,
        learning_rate=LR,
        warmup_ratio=WARMUP_RATIO,
        weight_decay=WEIGHT_DECAY,
        lr_scheduler_type="linear",
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=2,
        logging_steps=50,
        bf16=False, fp16=False,
        report_to="none",
        load_best_model_at_end=True,
        metric_for_best_model="eval_acc_4D",
        greater_is_better=True,
        optim="paged_adamw_8bit",
        eval_accumulation_steps=12,
        gradient_checkpointing=False,
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        tokenizer=tokenizer,
        data_collator=collator,
        compute_metrics=compute_metrics,
    )

    # Train.
    trainer.train()

    # ===== VAL predictions & plots =====
    val_output = trainer.predict(val_ds)
    val_logits = val_output.predictions[0] if isinstance(val_output.predictions, (list, tuple)) else val_output.predictions
    val_probs  = F.softmax(torch.tensor(val_logits, dtype=torch.float32), dim=-1).cpu().numpy()
    val_y_true = val_output.label_ids
    plot_confusion_and_roc(val_y_true, val_probs, MBTI_16, OUTPUT_DIR, suffix="")
    print(f"Saved: {os.path.join(OUTPUT_DIR, 'confusion_matrix.png')}")
    print(f"Saved: {os.path.join(OUTPUT_DIR, 'roc_micro_macro.png')}")

    # VAL metrics (including P/R/F1).
    eval_metrics = trainer.evaluate(eval_dataset=val_ds)
    print("\n=== Final Eval (on VAL) ===")
    for k, v in eval_metrics.items():
        try:
            print(f"{k}: {float(v):.4f}")
        except (TypeError, ValueError):
            # Non-numeric metric value: print it raw. Narrowed from a bare
            # except so KeyboardInterrupt / SystemExit are not swallowed.
            print(k, v)

    # Save the VAL classification report.
    val_pred_ids = val_probs.argmax(-1)
    val_report = classification_report(val_y_true, val_pred_ids, target_names=MBTI_16, digits=4, zero_division=0)
    with open(os.path.join(OUTPUT_DIR, "classification_report_val.txt"), "w", encoding="utf-8") as f:
        f.write(val_report)
    print("\n=== VAL Classification Report ===\n", val_report)

    # ===== TEST evaluation & plots =====
    test_output = trainer.predict(test_ds)
    test_logits = test_output.predictions[0] if isinstance(test_output.predictions, (list, tuple)) else test_output.predictions
    test_probs  = F.softmax(torch.tensor(test_logits, dtype=torch.float32), dim=-1).cpu().numpy()
    test_y_true = test_output.label_ids
    plot_confusion_and_roc(test_y_true, test_probs, MBTI_16, OUTPUT_DIR, suffix="_test")
    print(f"Saved: {os.path.join(OUTPUT_DIR, 'confusion_matrix_test.png')}")
    print(f"Saved: {os.path.join(OUTPUT_DIR, 'roc_micro_macro_test.png')}")

    # Save the TEST classification report.
    test_pred_ids = test_probs.argmax(-1)
    test_report = classification_report(test_y_true, test_pred_ids, target_names=MBTI_16, digits=4, zero_division=0)
    with open(os.path.join(OUTPUT_DIR, "classification_report_test.txt"), "w", encoding="utf-8") as f:
        f.write(test_report)
    print("\n=== TEST Classification Report ===\n", test_report)

    # TEST per-dimension (4D) metrics, printed in the same format as before.
    acc16 = float((test_pred_ids == test_y_true).mean())
    pred_types = [MBTI_16[i] for i in test_pred_ids]
    true_types = [MBTI_16[i] for i in test_y_true]
    c_ei=c_ns=c_tf=c_jp=c_all=0
    for pt, tt in zip(pred_types, true_types):
        pei,pns,ptf,pjp = mbti_to_4d(pt)
        tei,tns,ttf,tjp = mbti_to_4d(tt)
        c_ei += (pei==tei); c_ns += (pns==tns); c_tf += (ptf==ttf); c_jp += (pjp==tjp)
        c_all+= (pei==tei and pns==tns and ptf==ttf and pjp==tjp)
    n = len(test_y_true)
    print("\n=== Final Test (held-out TEST) ===")
    print(f"acc_16: {acc16:.4f}")
    print(f"acc_ei: {c_ei/n:.4f}  acc_ns: {c_ns/n:.4f}  acc_tf: {c_tf/n:.4f}  acc_jp: {c_jp/n:.4f}  acc_4D: {c_all/n:.4f}")

    # ====== Embedding visualisation for VAL / TEST (t-SNE) ======
    val_embs, val_y = extract_embeddings(model, tokenizer, val_ds, device="cuda:0", batch_size=BSZ_EVAL)
    plot_tsne(val_embs, val_y, MBTI_16, os.path.join(OUTPUT_DIR, "tsne_val.png"), title="t-SNE (VAL)")

    test_embs, test_y = extract_embeddings(model, tokenizer, test_ds, device="cuda:0", batch_size=BSZ_EVAL)
    plot_tsne(test_embs, test_y, MBTI_16, os.path.join(OUTPUT_DIR, "tsne_test.png"), title="t-SNE (TEST)")
    print(f"Saved: {os.path.join(OUTPUT_DIR, 'tsne_val.png')}")
    print(f"Saved: {os.path.join(OUTPUT_DIR, 'tsne_test.png')}")

    # Save the LoRA adapter.
    trainer.save_model(OUTPUT_DIR)
    print(f"\n✅ LoRA adapter saved to: {OUTPUT_DIR}")

    # Inference example on the first TEST sample.
    model.eval()
    sample = test_rows[0]
    text = build_input(sample, tokenizer)
    batch = tokenizer(text, return_tensors="pt", truncation=True, max_length=MAX_LEN)
    batch = {k: v.to("cuda:0") for k, v in batch.items()}
    with torch.no_grad():
        logits = model(**batch).logits
        pred_id = int(torch.argmax(logits, dim=-1))
        pred_mbti = MBTI_16[pred_id]
    print("\n[Inference on TEST sample]")
    print("原标签:", sample["type"], " | 预测:", pred_mbti)

if __name__ == "__main__":
    main()
