In [None]:
from transformer_model_classification import DeepPHQTransformer
from dataset_classification import DeepPHQDataset, DeepPHQValDataset, build_vocab, create_balanced_dataloader, split_by_pid
import numpy as np
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch
import pandas as pd
# Pick the compute device once: Apple's Metal backend (MPS) when PyTorch
# reports it as available, otherwise fall back to the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("Using device: MPS" if device.type == "mps" else "Using device: CPU")

In [None]:
def load_model(pt_path, device):
    """Rebuild a DeepPHQTransformer from a checkpoint and return it in eval mode.

    The checkpoint is expected to be a mapping with the keys "config",
    "vocab" and "model_state". Architecture hyper-parameters are read from
    ``config["model"]``; ``num_items`` falls back to 8 when absent and
    ``num_levels`` is always 4.

    Args:
        pt_path: Location of the checkpoint, forwarded to ``torch.load``.
        device: Device used both as ``map_location`` and as the model's
            final location.

    Returns:
        Tuple ``(model, vocab, config)`` where ``vocab`` and ``config`` are
        the objects stored in the checkpoint.
    """
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects -- only point this at checkpoints from a trusted source.
    checkpoint = torch.load(pt_path, map_location=device, weights_only=False)

    # Configuration and vocabulary saved with the weights.
    config = checkpoint["config"]
    vocab = checkpoint["vocab"]
    arch = config["model"]

    # Re-create the network with the same hyper-parameters it was trained with.
    model = DeepPHQTransformer(
        input_size=len(vocab),
        hidden_dim=arch["hidden_dim"],
        nhead=arch["nhead"],
        num_layers=arch["num_layers"],
        dropout=arch["dropout"],
        num_items=arch.get("num_items", 8),
        num_levels=4,  # fixed for PHQ 0-3
    )

    # strict=True: every key in the state dict must match the model exactly.
    model.load_state_dict(checkpoint["model_state"], strict=True)
    model.to(device)

    # Author's workarounds: explicitly relocate the embedding weights and the
    # positional-encoding tensor after the module-level move.
    # NOTE(review): model.to(device) normally moves parameters and registered
    # buffers already -- confirm whether these two lines are still required.
    embedding = model.embedding
    embedding.weight.data = embedding.weight.data.to(device)
    positional = model.pos_encoder
    positional.pe = positional.pe.to(device)

    model.eval()
    return model, vocab, config
# Same device rule as the first cell: MPS when available, CPU otherwise.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Restore one trained model per input granularity, together with the vocab
# and config that were saved next to its weights.
model_word, vocab_word, cfg_word = load_model(
    pt_path="./checkpoints/deep_phq_word.pt", device=device
)
model_sentence, vocab_sentence, cfg_sentence = load_model(
    pt_path="./checkpoints/deep_phq_sentence.pt", device=device
)
model_dialogue, vocab_dialogue, cfg_dialogue = load_model(
    pt_path="./checkpoints/deep_phq_dialogue.pt", device=device
)

In [None]:
def build_loader(csv_path, vocab, max_length, batch_size=32, stride=128):
    """Create a non-shuffled DataLoader over a DeepPHQValDataset read from CSV.

    Args:
        csv_path: Path of the CSV file, read with ``pd.read_csv``.
        vocab: Vocabulary forwarded to ``DeepPHQValDataset``.
        max_length: Forwarded to ``DeepPHQValDataset`` as ``max_length``.
        batch_size: Number of samples per batch (default 32).
        stride: Forwarded to ``DeepPHQValDataset`` as ``stride``. Defaults to
            128, the value that was previously hard-coded here.
            (Presumably the step between consecutive windows -- confirm
            against the dataset implementation.)

    Returns:
        A ``DataLoader`` that iterates the dataset in order (``shuffle=False``).
    """
    df = pd.read_csv(csv_path)

    dataset = DeepPHQValDataset(
        df,
        vocab=vocab,
        max_length=max_length,
        stride=stride,
    )

    # Fixed order keeps evaluation deterministic across runs.
    return DataLoader(dataset, batch_size=batch_size, shuffle=False)

# One evaluation loader per granularity, each built from the data settings
# stored in that model's own checkpoint config.
# NOTE(review): the CSV path is a plain string concatenation, so "data_root"
# presumably has to end with a path separator -- verify against the configs.
loader_word = build_loader(
    csv_path=cfg_word["data"]["data_root"] + cfg_word["data"]["word_csv"],
    vocab=vocab_word,
    max_length=cfg_word["data"]["max_length"],
)

loader_sentence = build_loader(
    csv_path=cfg_sentence["data"]["data_root"] + cfg_sentence["data"]["sentence_csv"],
    vocab=vocab_sentence,
    max_length=cfg_sentence["data"]["max_length"],
)

loader_dialogue = build_loader(
    csv_path=cfg_dialogue["data"]["data_root"] + cfg_dialogue["data"]["dialogue_csv"],
    vocab=vocab_dialogue,
    max_length=cfg_dialogue["data"]["max_length"],
)

In [None]:
def evaluate_pid_level(model, val_loader, device="cpu", return_preds=False):
    """Compute per-PID total-score errors by pooling window-level predictions.

    Every batch from ``val_loader`` must provide "input_ids", "label" and
    "pid" tensors. The model output is passed through a sigmoid and summed
    over its last axis to give one expected score per item. All windows that
    share a PID are averaged item-wise, the averaged items are summed into a
    predicted total, and that total is compared with the sum of the PID's
    label vector.

    Args:
        model: Callable module; switched to eval mode before use.
        val_loader: Iterable of batches as described above.
        device: Device the input ids are moved to before the forward pass.
        return_preds: When true, also return the per-PID arrays.

    Returns:
        ``(mse, mae, rmse)``, or
        ``(mse, mae, rmse, pred_totals, label_totals, pid_order)`` when
        ``return_preds`` is true. ``pid_order`` lists PIDs in first-seen order
        and lines up with the two arrays.
    """
    model.eval()

    window_scores = {}  # pid -> list of per-window item-score vectors
    pid_labels = {}     # pid -> label vector (the last window seen wins)

    with torch.no_grad():
        for batch in val_loader:
            # Only the model input has to live on the target device;
            # labels and pids are consumed as NumPy arrays on the CPU.
            input_ids = batch["input_ids"].to(device)
            batch_labels = batch["label"].numpy()
            batch_pids = batch["pid"].numpy()

            # Expected ordinal score per item = sum over thresholds of
            # sigmoid(logit), i.e. the sum of P(score >= k).
            threshold_probs = torch.sigmoid(model(input_ids))
            item_scores = threshold_probs.sum(dim=2).cpu().numpy()

            for pid, label_vec, score_vec in zip(batch_pids, batch_labels, item_scores):
                if pid not in window_scores:
                    window_scores[pid] = []
                window_scores[pid].append(score_vec)
                pid_labels[pid] = label_vec

    # Collapse each PID's windows: mean over windows, then sum over items.
    pid_order = list(window_scores)
    final_pred_totals = np.array(
        [np.stack(window_scores[pid]).mean(axis=0).sum() for pid in pid_order]
    )
    final_label_totals = np.array([pid_labels[pid].sum() for pid in pid_order])

    # Error metrics on the per-PID totals.
    errors = final_pred_totals - final_label_totals
    mse = (errors ** 2).mean()
    mae = np.abs(errors).mean()
    rmse = np.sqrt(mse)

    if not return_preds:
        return mse, mae, rmse

    return mse, mae, rmse, final_pred_totals, final_label_totals, pid_order

In [None]:
# PID-level evaluation of each model on its own loader. With return_preds=True
# every result is the tuple
# (mse, mae, rmse, pred_totals, label_totals, pid_order).
eval_word = evaluate_pid_level(
    model=model_word, val_loader=loader_word, device=device, return_preds=True
)
eval_sentence = evaluate_pid_level(
    model=model_sentence, val_loader=loader_sentence, device=device, return_preds=True
)
eval_dialogue = evaluate_pid_level(
    model=model_dialogue, val_loader=loader_dialogue, device=device, return_preds=True
)

In [None]:
def load_curve(pt_path):
    """Return the ``(train_mse, val_mse)`` histories stored in a checkpoint.

    The checkpoint is always mapped to the CPU, and the two entries are
    returned exactly as they were saved.
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects -- only use
    # this on checkpoints from a trusted source.
    checkpoint = torch.load(pt_path, map_location="cpu", weights_only=False)
    train_history = checkpoint["train_mse"]
    val_history = checkpoint["val_mse"]
    return train_history, val_history
    
# Training / validation histories saved inside each checkpoint
# (w = word, s = sentence, d = dialogue).
train_w, val_w = load_curve(pt_path="./checkpoints/deep_phq_word.pt")
train_s, val_s = load_curve(pt_path="./checkpoints/deep_phq_sentence.pt")
train_d, val_d = load_curve(pt_path="./checkpoints/deep_phq_dialogue.pt")

In [None]:
# PID-level MSE per input granularity.
# evaluate_pid_level(..., return_preds=True) returns
# (mse, mae, rmse, pred_totals, label_totals, pid_order), so element 0 of each
# eval_* tuple is the MSE. These names were previously used without ever being
# assigned, which raised NameError on a fresh kernel.
mse_w, mse_s, mse_d = eval_word[0], eval_sentence[0], eval_dialogue[0]

mse_values = [mse_w, mse_s, mse_d]
labels = ["word", "sentence", "dialogue"]

plt.figure(figsize=(6,4))
plt.bar(labels, mse_values, color=["#4c72b0", "#55a868", "#c44e52"])

plt.title("Test MSE Comparison")
plt.ylabel("PID-Level MSE")
plt.xlabel("Input-Level")

# Write each bar's value just above it.
for i, v in enumerate(mse_values):
    plt.text(i, v + 0.1, f"{v:.2f}", ha='center', fontsize=10)

plt.grid(axis='y', linestyle='--', alpha=0.6)
plt.show()

In [None]:
# Training curves of the three granularities on a single axes.
# NOTE(review): the series come from the checkpoints' "train_mse" entries,
# while the title and y-label call them CORAL loss -- confirm which quantity
# is actually stored.
fig, ax = plt.subplots(figsize=(10, 5))

ax.plot(train_w, label="word")
ax.plot(train_s, label="sentence")
ax.plot(train_d, label="dialogue")

ax.set_title("Training CORAL Loss Across Granularity")
ax.set_xlabel("Epoch")
ax.set_ylabel("CORAL Loss")
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Validation curves of the three granularities on a single axes.
# NOTE(review): the series come from the checkpoints' "val_mse" entries,
# while the title and y-label call them CORAL loss -- confirm which quantity
# is actually stored.
fig, ax = plt.subplots(figsize=(10, 5))

ax.plot(val_w, label="word")
ax.plot(val_s, label="sentence")
ax.plot(val_d, label="dialogue")

ax.set_title("Validation CORAL Loss Across Granularity")
ax.set_xlabel("Epoch")
ax.set_ylabel("CORAL Loss")
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Train and validation curves for every granularity, overlaid for comparison.
# The plotting order is kept (train/val per granularity) so the colour cycle
# assigns the same colours as before.
fig, ax = plt.subplots(figsize=(10, 5))

ax.plot(train_w, label="train word")
ax.plot(val_w, label="val word")

ax.plot(train_s, label="train sentence")
ax.plot(val_s, label="val sentence")

ax.plot(train_d, label="train dialogue")
ax.plot(val_d, label="val dialogue")

ax.set_title("Train/Val PID-Level MSE Across Granularity")
ax.set_xlabel("Epoch")
ax.set_ylabel("MSE")
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Report the last recorded value of every training curve ...
print("Final Train Loss:")
for row_label, curve in (
    ("  word     =", train_w),
    ("  sentence =", train_s),
    ("  dialogue =", train_d),
):
    print(row_label, curve[-1])

# ... and of every validation curve.
print("\nFinal Val Loss:")
for row_label, curve in (
    ("  word     =", val_w),
    ("  sentence =", val_s),
    ("  dialogue =", val_d),
):
    print(row_label, curve[-1])