# 1. Load Parameters From Config File

In [None]:
import sys
from pathlib import Path
# Make project-level modules importable from this notebook: the project root
# is taken to be three directory levels above the current working directory.
_working_dir = Path.cwd().resolve()
PROJECT_ROOT = _working_dir.parents[2]
sys.path.extend([str(PROJECT_ROOT)])
print("Loaded project root:", PROJECT_ROOT)

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm import tqdm
from dataset_allscore import DeepPHQDataset, DeepPHQValDataset, build_vocab, create_balanced_dataloader, split_by_pid
from transformer_model_allscore import DeepPHQTransformer
import yaml

# Load config
# The path is relative, so it is resolved against the current working
# directory: start the notebook from the folder that holds config.yaml.
CONFIG_PATH = Path("config.yaml")

# Pin the encoding: without it, open() falls back to the platform locale and
# a config containing non-ASCII characters may decode differently (or fail)
# depending on the machine.
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# Bare expression: shows the parsed config as the notebook cell output.
config

# 2. Load processed CSV data

In [None]:
data_cfg = config["data"]
root = Path(data_cfg["data_root"])
level = data_cfg["level"]

# Every supported granularity has its CSV file name stored in the config
# under the key "<level>_csv" (word_csv / sentence_csv / dialogue_csv).
SUPPORTED_LEVELS = ("word", "sentence", "dialogue")
if level not in SUPPORTED_LEVELS:
    raise ValueError("Unknown level in config")
csv_path = root / data_cfg[f"{level}_csv"]

print("Loading:", csv_path)
df = pd.read_csv(csv_path)

# Bare expression: previews the first rows as the notebook cell output.
df.head()

# 3. Build Vocab and Balanced DataLoader

In [None]:
# Build the vocabulary from the "Text" column of the full table.
# NOTE(review): this uses every row of `df`, i.e. also the rows that are
# later assigned to the validation/test splits - confirm this is intended.
all_texts = df["Text"].tolist()
min_token_freq = config["vocab"]["min_freq"]
vocab = build_vocab(all_texts, min_freq=min_token_freq)

In [None]:
# --- Step 1: partition the table into train / val / test frames ---
# (presumably grouped by participant id, going by the helper's name)
train_df, val_df, test_df = split_by_pid(df)

max_len = config["data"]["max_length"]
EVAL_STRIDE = 128       # stride handed to DeepPHQValDataset for val and test
EVAL_BATCH_SIZE = 32    # fixed batch size for the non-shuffled eval loaders

# --- Step 2: datasets ---
# Training samples are (PID, Text, PHQ_Score) triples taken row by row.
train_records = list(
    zip(train_df["PID"], train_df["Text"], train_df["PHQ_Score"])
)
train_dataset = DeepPHQDataset(
    data=train_records, vocab=vocab, max_length=max_len
)

# Validation and test datasets are built with identical settings.
eval_kwargs = {"max_length": max_len, "stride": EVAL_STRIDE}
val_dataset = DeepPHQValDataset(val_df, vocab, **eval_kwargs)
test_dataset = DeepPHQValDataset(test_df, vocab, **eval_kwargs)

# --- Step 3: dataloaders ---
# Eval loaders keep dataset order; the training loader comes from the
# project's create_balanced_dataloader helper.
val_loader = DataLoader(val_dataset, batch_size=EVAL_BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=EVAL_BATCH_SIZE, shuffle=False)
train_loader = create_balanced_dataloader(
    train_dataset, batch_size=config["dataloader"]["batch_size"]
)

# --- Step 4: sanity check - print the tensor shapes of one training batch ---
batch = next(iter(train_loader))
for field in ("input_ids", "label", "pid"):
    print(batch[field].shape)

# 4. Init Transformer Model

In [None]:
# ---- Model hyperparameters come from the "model" section of the config ----
model_cfg = config["model"]

# ---- Pick the compute device: CUDA first, then MPS, otherwise CPU ----
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "mps" if torch.backends.mps.is_available() else "cpu"

print("Using device:", device)

# ---- Build the transformer and move it to the selected device ----
# The input size is tied to the vocabulary size; everything else is read
# from the config under the same name as the constructor argument.
vocab_size = len(vocab)
hparams = {
    name: model_cfg[name]
    for name in ("output_size", "hidden_dim", "nhead", "num_layers", "dropout")
}
model = DeepPHQTransformer(input_size=vocab_size, **hparams)
model = model.to(device)

# 5. Train and Evaluate the Model

In [None]:
def evaluate_pid_level(model, val_loader, device="cuda", return_preds=False):
    """Score ``model`` on ``val_loader`` with errors measured per PID.

    Every batch must provide ``"input_ids"``, ``"label"`` and ``"pid"``
    tensors. All prediction vectors that share a PID are averaged
    element-wise, the averaged vector is summed into one total score, and
    that total is compared with the sum of the PID's label vector.

    Args:
        model: Module called as ``model(input_ids)``; it is switched to eval
            mode and left in that mode.
        val_loader: Iterable of batch dicts as described above.
        device: Device the input ids are moved to before the forward pass.
        return_preds: When true, also return the per-PID totals.

    Returns:
        ``(mse, mae, rmse)`` over PIDs, or
        ``(mse, mae, rmse, pred_totals, label_totals)`` when ``return_preds``
        is true. The two arrays are ordered by first appearance of each PID.
    """
    model.eval()

    preds_by_pid = {}   # pid -> list of prediction vectors for that pid
    label_by_pid = {}   # pid -> label vector (the last one seen is kept)

    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch["input_ids"].to(device)
            batch_labels = batch["label"].cpu().numpy()
            batch_pids = batch["pid"].cpu().numpy()
            batch_preds = model(input_ids).cpu().numpy()

            for pid, label_vec, pred_vec in zip(batch_pids, batch_labels, batch_preds):
                if pid not in preds_by_pid:
                    preds_by_pid[pid] = []
                preds_by_pid[pid].append(pred_vec)
                label_by_pid[pid] = label_vec

    # Collapse each PID to a single predicted total and a single true total.
    final_pred_totals = np.array(
        [np.stack(vectors).mean(axis=0).sum() for vectors in preds_by_pid.values()]
    )
    final_label_totals = np.array(
        [label_by_pid[pid].sum() for pid in preds_by_pid]
    )

    mse = ((final_pred_totals - final_label_totals) ** 2).mean()
    mae = np.abs(final_pred_totals - final_label_totals).mean()
    rmse = np.sqrt(mse)

    if not return_preds:
        return mse, mae, rmse
    return mse, mae, rmse, final_pred_totals, final_label_totals

In [None]:
# Training loop: optimise the item-level MSE, track total-score metrics on
# the training batches for monitoring, and run the PID-level validation
# after every epoch.
train_cfg = config["training"]

# Numeric hyperparameters are cast explicitly (presumably because the YAML
# loader can hand back strings for values written like "1e-4"); the clipping
# threshold gets the same treatment as the optimizer settings.
optimizer = AdamW(
    model.parameters(),
    lr=float(train_cfg["learning_rate"]),
    weight_decay=float(train_cfg["weight_decay"])
)
grad_clip = float(train_cfg["gradient_clip"])

criterion = nn.MSELoss()

# ---- history arrays (one entry per epoch) ----
train_item_losses = []   # training objective: item-level MSE
train_total_mses = []    # total-score MSE on training batches
train_total_maes = []    # total-score MAE on training batches
train_total_rmses = []   # total-score RMSE on training batches

val_mses = []
val_maes = []
val_rmses = []


for epoch in range(train_cfg["num_epochs"]):

    # evaluate_pid_level() leaves the model in eval mode, so re-enable
    # training mode at the start of every epoch.
    model.train()
    epoch_item_loss = 0

    # per-batch total-score stats, averaged at the end of the epoch
    batch_total_mses = []
    batch_total_maes = []

    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}")

    for batch in progress:
        input_ids = batch["input_ids"].to(device)
        labels = batch["label"].to(device)  # (B, 8)

        optimizer.zero_grad()

        outputs = model(input_ids)          # (B, 8)

        # ---------- (1) item-level MSE loss (training target) ----------
        loss = criterion(outputs, labels)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        # Read the scalar once; it is reused for the running sum and the
        # progress bar.
        batch_loss = loss.item()
        epoch_item_loss += batch_loss

        # ---------- (2) total-score loss (for visualization only) ----------
        # Monitoring only: no_grad keeps autograd from recording these ops
        # on top of the already-consumed training graph.
        with torch.no_grad():
            pred_total = outputs.sum(dim=1)    # (B,)
            label_total = labels.sum(dim=1)    # (B,)
            total_err = pred_total - label_total

            mse_batch = (total_err ** 2).mean().item()
            mae_batch = total_err.abs().mean().item()

        batch_total_mses.append(mse_batch)
        batch_total_maes.append(mae_batch)

        progress.set_postfix(loss=batch_loss)

    # ---- epoch-level stats ----
    avg_item_loss = epoch_item_loss / len(train_loader)
    avg_total_mse = sum(batch_total_mses) / len(batch_total_mses)
    avg_total_mae = sum(batch_total_maes) / len(batch_total_maes)
    # RMSE is derived from the epoch MSE, matching how evaluate_pid_level
    # computes it (sqrt of the MSE); averaging per-batch RMSEs would
    # systematically under-report it.
    avg_total_rmse = avg_total_mse ** 0.5

    train_item_losses.append(avg_item_loss)
    train_total_mses.append(avg_total_mse)
    train_total_maes.append(avg_total_mae)
    train_total_rmses.append(avg_total_rmse)

    # ---- validation (PID-level) ----
    val_mse, val_mae, val_rmse = evaluate_pid_level(model, val_loader, device)

    val_mses.append(val_mse)
    val_maes.append(val_mae)
    val_rmses.append(val_rmse)

    print(
        f"[Epoch {epoch+1}] "
        f"Train ItemLoss={avg_item_loss:.4f} | "
        f"Train Total MSE={avg_total_mse:.4f} | "
        f"Val MSE={val_mse:.4f} | "
        f"Val MAE={val_mae:.4f} | "
        f"Val RMSE={val_rmse:.4f}"
    )

In [None]:
import matplotlib.pyplot as plt

# Per-epoch total-score MSE: training batches vs. PID-level validation.
ax = plt.gca()
for series, series_label in (
    (train_total_mses, "Train MSE (total)"),
    (val_mses, "Val MSE (total)"),
):
    ax.plot(series, label=series_label)
ax.legend()
ax.grid(True)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Training vs Validation Loss (PID-level)")
plt.show()

# 7. Test Case

In [None]:
# Final held-out evaluation. The test loader is scored exactly once (a
# second, discarded pass only doubled the inference cost), and the status
# line is printed before the evaluation starts rather than after it.
print("\nRunning FINAL TEST evaluation...")

# return_preds=True also yields the per-PID predicted and true totals,
# which are kept in test_preds / test_labels for later inspection.
test_mse, test_mae, test_rmse, test_preds, test_labels = evaluate_pid_level(
    model,
    test_loader,
    device,
    return_preds=True
)

print("===== FINAL TEST RESULTS =====")
print(f"Test MSE  = {test_mse:.4f}")
print(f"Test MAE  = {test_mae:.4f}")
print(f"Test RMSE = {test_rmse:.4f}")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# ===== Figure 1: true vs. predicted total score =====
# The dashed red diagonal is the y = x reference, drawn from 0 up to the
# largest true total.
ax = plt.figure(figsize=(6, 6)).gca()
ax.scatter(test_labels, test_preds, alpha=0.6)
diag_end = max(test_labels)
ax.plot([0, diag_end], [0, diag_end], 'r--')
ax.set_xlabel("True Total PHQ Score")
ax.set_ylabel("Predicted Total PHQ Score")
ax.set_title("True vs Predicted PHQ Total")
ax.grid(True)
plt.show()

# ===== Figure 2: residuals against the true total =====
# Residual is prediction minus truth, so positive values are over-estimates.
residual = test_preds - test_labels

ax = plt.figure(figsize=(6, 5)).gca()
ax.scatter(test_labels, residual, alpha=0.6)
ax.axhline(0, color='r', linestyle='--')
ax.set_xlabel("True Total PHQ Score")
ax.set_ylabel("Prediction Error (Pred - True)")
ax.set_title("Residual Plot")
ax.grid(True)
plt.show()

# ===== Figure 3: histogram of the residuals =====
ax = plt.figure(figsize=(6, 5)).gca()
ax.hist(residual, bins=20, alpha=0.7, color='purple')
ax.axvline(0, color='r', linestyle='--')
ax.set_xlabel("Prediction Error")
ax.set_ylabel("Count")
ax.set_title("Error Histogram")
ax.grid(True)
plt.show()