# 1. Load Parameters From Config File

In [None]:
import sys
from pathlib import Path
# The project root is taken to be three directory levels above the current
# working directory (parents[0] is the direct parent, parents[2] the
# great-grandparent).
PROJECT_ROOT = Path.cwd().resolve().parents[2]

# Register the root on the import path only once, so that re-running this
# cell does not keep appending duplicate entries to sys.path.
# Presumably this is what lets the project-local modules imported below
# resolve -- verify against the repository layout.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))
print("Loaded project root:", PROJECT_ROOT)

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm import tqdm
from dataset_classification import DeepPHQDataset, DeepPHQValDataset, build_vocab, create_balanced_dataloader, split_by_pid
from transformer_model_classification import DeepPHQTransformer
import yaml

# Load config
# The path is relative, so it is resolved against the current working
# directory of the notebook kernel.
CONFIG_PATH = Path("config.yaml")

# Decode explicitly as UTF-8: without an encoding argument open() falls back
# to the platform locale, which can mis-read non-ASCII characters in the YAML
# on systems whose default is not UTF-8.
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    # safe_load only builds plain Python objects (no arbitrary object tags).
    config = yaml.safe_load(f)

# Bare expression: shows the parsed configuration as the cell output.
config

# 2. Load processed CSV data

In [None]:
# Pick the processed CSV that matches the configured granularity level.
data_cfg = config["data"]
root = Path(data_cfg["data_root"])

level = data_cfg["level"]

# Each supported level is paired with the config key holding its CSV file
# name; the first pair whose level matches wins.
for known_level, csv_key in (
    ("word", "word_csv"),
    ("sentence", "sentence_csv"),
    ("dialogue", "dialogue_csv"),
):
    if level == known_level:
        csv_path = root / data_cfg[csv_key]
        break
else:
    # No supported level matched the configured value.
    raise ValueError("Unknown level in config")

print("Loading:", csv_path)

# Read the table and preview its first rows as the cell output.
df = pd.read_csv(csv_path)
df.head()

# 3. Build Vocab and Balanced DataLoader

In [None]:
# Collect every text row of the loaded table and build the vocabulary from it,
# using the minimum token frequency given in the config.
# NOTE(review): the vocabulary is built from all rows of `df`, i.e. before the
# train/val/test split is made -- confirm that this is intended.
all_texts = df["Text"].tolist()
min_token_freq = config["vocab"]["min_freq"]
vocab = build_vocab(all_texts, min_freq=min_token_freq)

In [None]:
# 1. split the table into train / validation / test frames
train_df, val_df, test_df = split_by_pid(df)

# 2. create datasets
# Training samples are handed over as (PID, Text, PHQ_Score) tuples.
train_records = list(zip(train_df["PID"], train_df["Text"], train_df["PHQ_Score"]))
train_dataset = DeepPHQDataset(
    data=train_records,
    vocab=vocab,
    max_length=config["data"]["max_length"],
)

# Validation and test datasets are built the same way (same max_length and a
# fixed stride of 128); only the source frame differs.
val_dataset, test_dataset = (
    DeepPHQValDataset(
        frame,
        vocab,
        max_length=config["data"]["max_length"],
        stride=128,
    )
    for frame in (val_df, test_df)
)

# Evaluation loaders keep the dataset order (no shuffling), batch size 32.
val_loader, test_loader = (
    DataLoader(eval_dataset, batch_size=32, shuffle=False)
    for eval_dataset in (val_dataset, test_dataset)
)

# 3. dataloaders
# The training loader comes from the project's balanced-loader factory and
# takes its batch size from the config.
train_loader = create_balanced_dataloader(
    train_dataset,
    batch_size=config["dataloader"]["batch_size"],
)

# 4. verify shapes
# Pull one training batch and report shape/dtype of each field as a sanity check.
batch = next(iter(train_loader))
for caption, field in (
    ("input_ids:", "input_ids"),
    ("labels:", "label"),
    ("pid:", "pid"),
):
    print(caption, batch[field].shape, batch[field].dtype)

print("Unique labels in first batch:")
print(torch.unique(batch["label"]))

# 4. Init Transformer Model

In [None]:
# ---- Load model config ----
model_cfg = config["model"]

# ---- Auto-select device ----
# Preference order: CUDA, then Apple MPS, then CPU. The MPS check is only
# evaluated when CUDA is unavailable.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

print("Using device:", device)

# ---- Init model ----
# The vocabulary size determines the model's input size; all remaining
# hyper-parameters are read from the "model" section of the config.
model_kwargs = {
    "input_size": len(vocab),
    "num_levels": model_cfg["num_classes_per_item"],
    "num_items": model_cfg["num_items"],
    "hidden_dim": model_cfg["hidden_dim"],
    "nhead": model_cfg["nhead"],
    "num_layers": model_cfg["num_layers"],
    "dropout": model_cfg["dropout"],
}
model = DeepPHQTransformer(**model_kwargs).to(device)

# 5. Train and Evaluate the Model

In [None]:
def evaluate_pid_level(model, val_loader, device="cuda", return_preds=False):
    """Score the model on total questionnaire score, one prediction per PID.

    Every batch from ``val_loader`` must provide ``"input_ids"``, ``"label"``
    and ``"pid"`` tensors. The model output is passed through a sigmoid and
    summed over dim 2, giving one expected ordinal score per item for each
    sample. Samples sharing a PID are averaged item-wise and the item scores
    are then summed into one predicted total. The label total of a PID is the
    sum of the label vector seen last for that PID.

    Args:
        model: Module called as ``model(input_ids)``; put into eval mode here.
        val_loader: Iterable of batch dicts as described above.
        device: Device the input ids are moved to before the forward pass.
        return_preds: When true, the per-PID arrays are returned as well.

    Returns:
        ``(mse, mae, rmse)`` over the per-PID totals. With ``return_preds``
        the tuple is extended by the predicted totals, the label totals and
        the list of PIDs, all aligned and ordered by first appearance.
    """
    model.eval()

    scores_by_pid = {}   # pid -> list of per-sample item-score vectors
    labels_by_pid = {}   # pid -> item-label vector (last one seen wins)

    with torch.no_grad():
        for batch in val_loader:
            token_ids = batch["input_ids"].to(device)
            batch_labels = batch["label"].cpu().numpy()
            batch_pids = batch["pid"].cpu().numpy()

            # sigmoid turns boundary logits into P(score >= k); summing those
            # over the boundary axis gives the expected ordinal score.
            boundary_probs = torch.sigmoid(model(token_ids))
            batch_scores = boundary_probs.sum(dim=2).cpu().numpy()

            for pid, item_labels, item_scores in zip(batch_pids, batch_labels, batch_scores):
                if pid not in scores_by_pid:
                    scores_by_pid[pid] = []
                scores_by_pid[pid].append(item_scores)
                labels_by_pid[pid] = item_labels

    # ----------------------------
    # aggregate per PID
    # ----------------------------
    pid_order = list(scores_by_pid)

    # Average item scores across a PID's samples, then add up the items.
    final_pred_totals = np.array(
        [np.stack(scores_by_pid[pid]).mean(axis=0).sum() for pid in pid_order]
    )
    final_label_totals = np.array([labels_by_pid[pid].sum() for pid in pid_order])

    mse = ((final_pred_totals - final_label_totals) ** 2).mean()
    mae = np.abs(final_pred_totals - final_label_totals).mean()
    rmse = np.sqrt(mse)

    if not return_preds:
        return mse, mae, rmse

    return mse, mae, rmse, final_pred_totals, final_label_totals, pid_order

In [None]:
def coral_loss(logits, labels):
    """Compute the CORAL ordinal loss from boundary logits.

    Each ordinal label ``y`` is expanded into binary targets ``I[y >= k]`` for
    ``k = 1 .. K-1``, where ``K-1`` is the size of the last logits dimension
    (3 boundaries correspond to the 4 levels 0-3). The loss is the mean binary
    cross-entropy over all boundaries:

        CORAL loss = mean_k BCE(sigmoid(logit_k), I[y >= k])

    The cross-entropy is evaluated directly on the logits instead of on
    ``sigmoid(logits)``. Feeding saturated probabilities to ``BCELoss`` clamps
    the log term and produces a zero gradient for confidently wrong
    predictions; the with-logits form stays exact and keeps the gradient.

    Args:
        logits: Boundary logits of shape (B, I, K-1), e.g. (B, 8, 3).
        labels: Integer labels of shape (B, I) with values in 0 .. K-1.

    Returns:
        Tuple ``(loss, probs)``: the scalar mean loss and the boundary
        probabilities ``sigmoid(logits)`` with the same shape as ``logits``.

    Raises:
        ValueError: If ``logits`` does not have exactly three dimensions
            (raised by the shape unpacking).
    """
    # K-1 = 3 boundaries correspond to 4 levels (0-3).
    _, _, num_boundaries = logits.shape

    # thresholds = [1, 2, ..., K-1], shaped to broadcast against (B, I, 1).
    thresholds = torch.arange(
        1, num_boundaries + 1, device=labels.device
    ).view(1, 1, num_boundaries)

    # y >= k -> 1 else 0, shape (B, I, K-1). Cast to the logits' dtype/device
    # so the loss also works when the logits are not float32.
    boundary_targets = (labels.unsqueeze(-1) >= thresholds).to(
        device=logits.device, dtype=logits.dtype
    )

    # Binary cross-entropy on every boundary, computed stably from the logits.
    loss = nn.functional.binary_cross_entropy_with_logits(
        logits, boundary_targets, reduction="mean"
    )

    # Boundary probabilities are still returned for the caller's metrics.
    probs = torch.sigmoid(logits)

    return loss, probs

In [None]:
# Train the model with the CORAL loss and track per-epoch training and
# validation metrics. All hyper-parameters come from the "training" section
# of the config.
train_cfg = config["training"]

# learning_rate and weight_decay are cast to float explicitly
# (presumably because the YAML values may be parsed as strings -- TODO confirm).
optimizer = AdamW(
    model.parameters(),
    lr=float(train_cfg["learning_rate"]),
    weight_decay=float(train_cfg["weight_decay"])
)

# -------------------------------
# history tracking
# -------------------------------
# One entry is appended to each list per epoch; they are later plotted and
# stored in the checkpoint.
train_coral_losses = []
val_coral_losses = []
train_total_mses = []
val_mses = []


# ===========================
# Training Loop
# ===========================
for epoch in range(train_cfg["num_epochs"]):

    model.train()
    epoch_train_coral = 0      # running sum of batch losses for this epoch
    batch_train_mses = []      # per-batch MSE of the predicted total score

    for batch in train_loader:
        input_ids = batch["input_ids"].to(device)
        labels = batch["label"].to(device)   # (B,8)

        optimizer.zero_grad()

        logits = model(input_ids)            # (B,8,3)

        # CORAL loss
        loss, probs = coral_loss(logits, labels)
        loss.backward()

        # Clip the global gradient norm before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), train_cfg["gradient_clip"])
        optimizer.step()

        epoch_train_coral += loss.item()

        # expected ordinal score = sum(P >= k)
        expected_scores = probs.sum(dim=2)   # (B,8)
        # Sum the item scores to a total score per sample.
        pred_total = expected_scores.sum(dim=1)
        true_total = labels.sum(dim=1).float()

        # NOTE: this training MSE is computed per sample within each batch,
        # whereas the validation MSE below is aggregated per PID.
        mse_batch = ((pred_total - true_total) ** 2).mean().item()
        batch_train_mses.append(mse_batch)

    # Epoch averages over batches.
    train_coral_losses.append(epoch_train_coral / len(train_loader))
    train_total_mses.append(sum(batch_train_mses) / len(batch_train_mses))


    # ===========================
    # Validation
    # ===========================
    # 1) CORAL loss on val
    model.eval()
    total_val_coral = 0
    val_batches = 0

    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch["input_ids"].to(device)
            labels = batch["label"].to(device)

            logits = model(input_ids)
            vloss, _ = coral_loss(logits, labels)

            total_val_coral += vloss.item()
            val_batches += 1

    # Mean validation loss per batch (raises ZeroDivisionError if the
    # validation loader yields no batches).
    val_coral_losses.append(total_val_coral / val_batches)

    # 2) MSE on val
    # Second pass over the validation loader: PID-level total-score metrics.
    # Only the MSE is stored in the history; MAE and RMSE are unused here.
    val_mse, val_mae, val_rmse = evaluate_pid_level(model, val_loader, device)
    val_mses.append(val_mse)

    print(f"[Epoch {epoch+1}] Train CORAL={train_coral_losses[-1]:.4f} | "
          f"Val CORAL={val_coral_losses[-1]:.4f} | "
          f"Train MSE={train_total_mses[-1]:.2f} | "
          f"Val MSE={val_mses[-1]:.2f}")

In [None]:
import matplotlib.pyplot as plt

# Two side-by-side panels: CORAL loss on the left, total-score MSE on the
# right. Each panel is described by its title and its (series, legend label)
# curves, then drawn with the same sequence of pyplot calls.
curve_panels = (
    # ---- CORAL loss ----
    (
        "CORAL Loss (Training vs Validation)",
        (
            (train_coral_losses, "Train CORAL Loss"),
            (val_coral_losses, "Val CORAL Loss"),
        ),
    ),
    # ---- MSE ----
    (
        "Total Score MSE (PID-level)",
        (
            (train_total_mses, "Train Total MSE"),
            (val_mses, "Val Total MSE"),
        ),
    ),
)

plt.figure(figsize=(12,5))

for panel_index, (panel_title, curves) in enumerate(curve_panels, start=1):
    plt.subplot(1, 2, panel_index)
    for series, legend_label in curves:
        plt.plot(series, label=legend_label)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(panel_title)
    plt.grid(True)
    plt.legend()

plt.tight_layout()
plt.show()

# 7. Test Case

In [None]:
# Evaluate once on the held-out test loader, keeping the per-PID outputs so
# they can be listed here and plotted afterwards.
print("\nRunning FINAL TEST evaluation...")

test_results = evaluate_pid_level(model, test_loader, device, return_preds=True)
(
    test_mse,
    test_mae,
    test_rmse,
    test_pred_totals,
    test_label_totals,
    test_pids,
) = test_results

# Aggregate metrics first, followed by the header of the per-PID listing.
report_lines = [
    "===== FINAL TEST RESULTS =====",
    f"Test MSE  = {test_mse:.4f}",
    f"Test MAE  = {test_mae:.4f}",
    f"Test RMSE = {test_rmse:.4f}",
    "\n===== TEST SET CASES (PID-level) =====",
]
for report_line in report_lines:
    print(report_line)

# One line per PID: predicted total, true total and absolute error.
for pid, pred, true in zip(test_pids, test_pred_totals, test_label_totals):
    print(f"PID {pid}:  Pred Total = {pred:.2f}   |   True Total = {true:.2f}   |   Error = {abs(pred-true):.2f}")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def _label_and_show(xlabel, ylabel, title):
    """Label the current axes, switch the grid on and display the figure."""
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.show()


# Per-PID totals produced by the final test evaluation.
true_scores = test_label_totals      # shape (N,)
pred_scores = test_pred_totals       # shape (N,)

# -------------------------------
# (1) Scatter: True vs Predicted
# -------------------------------
plt.figure(figsize=(6,6))
plt.scatter(true_scores, pred_scores, alpha=0.6)
# The diagonal reference runs from 0 to the largest value on either axis.
maxv = max(max(true_scores), max(pred_scores))
plt.plot([0, maxv], [0, maxv], 'r--')  # y=x line
_label_and_show(
    "True Total PHQ-8 Score",
    "Predicted Total PHQ-8 Score",
    "True vs Predicted PHQ-8 Total Scores",
)

# -------------------------------
# (2) Residual Plot
# -------------------------------
# Positive residual = over-prediction, negative = under-prediction.
residual = pred_scores - true_scores

plt.figure(figsize=(6,5))
plt.scatter(true_scores, residual, alpha=0.6)
plt.axhline(0, color='r', linestyle='--')
_label_and_show(
    "True Total PHQ-8 Score",
    "Prediction Error (Pred - True)",
    "Residual Plot",
)

# -------------------------------
# (3) Error Distribution
# -------------------------------
plt.figure(figsize=(6,5))
plt.hist(residual, bins=20, alpha=0.7, color='purple')
plt.axvline(0, color='r', linestyle='--')
_label_and_show("Prediction Error", "Count", "Error Histogram")

# 8. Save Result

In [None]:
# The checkpoint file is named after the configured granularity level.
level = config["data"]["level"]   # "word", "sentence", "dialogue"

# Make sure the target directory exists before writing into it.
save_dir = Path(config["checkpoint"]["save_dir"])
print("Save directory:", save_dir)
save_dir.mkdir(parents=True, exist_ok=True)

model_name = f"deep_phq_{level}.pt"
save_path = save_dir.joinpath(model_name)

# Bundle the weights with the vocabulary, the config and the training curves
# into a single checkpoint object.
checkpoint = {
    "model_state": model.state_dict(),
    "vocab": vocab,
    "config": config,

    # === Training / validation history ===
    "train_loss": train_coral_losses,
    "val_loss": val_coral_losses,
    "train_mse": train_total_mses,
    "val_mse": val_mses,
}
torch.save(checkpoint, save_path)

print(f"[✓] Saved model + curves to: {save_path}")
