# Train CEM Model - Fast Pipeline

**Runtime:** ~10-15 minutes (just training)

This notebook:
1. Loads preprocessed data from `data/processed/whole_pipeline/`
2. Trains a Concept Embedding Model (CEM)
3. Evaluates and saves results

**Prerequisites:** Run `0_prepare_dataset.ipynb` first!

## Section 0: Setup & Configuration

In [None]:
# Imports
import os
import json
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint

from sklearn.metrics import (
    confusion_matrix,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    matthews_corrcoef,
    roc_auc_score,
    balanced_accuracy_score,
    classification_report,
)

from patched_model import PatchedConceptEmbeddingModel

print("✓ All imports successful")

In [None]:
# Seed every RNG entry point used by this notebook (NumPy, PyTorch, Lightning),
# in that order, so repeated runs start from the same random state.
SEED = 42
for _seed_fn in (np.random.seed, torch.manual_seed, pl.seed_everything):
    _seed_fn(SEED)

print(f"✓ Random seed set to {SEED}")

In [None]:
# Pick the accelerator: Apple MPS is checked first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    DEVICE, device_banner = "mps", "✓ Using MacBook GPU (MPS)"
elif torch.cuda.is_available():
    DEVICE, device_banner = "cuda", "✓ Using CUDA GPU"
else:
    DEVICE, device_banner = "cpu", "⚠ Using CPU"

print(device_banner)

In [None]:
# Resolve the directories used below. PROJECT_ROOT is the parent of the current
# working directory; OUTPUT_DIR stays relative to the working directory.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
DATA_PROCESSED = os.path.join(PROJECT_ROOT, "data/processed")
DATASET_DIR = os.path.join(DATA_PROCESSED, "whole_pipeline")
OUTPUT_DIR = "outputs"

print("✓ Paths configured")
for _label, _path in (("Dataset dir", DATASET_DIR), ("Output dir", OUTPUT_DIR)):
    print(f"  {_label}: {_path}")

In [None]:
# The 21 BDI-II concept names. Index i of this list labels column i of the
# predicted concept probabilities when the results are written out.
CONCEPT_NAMES = [
    "Sadness",
    "Pessimism",
    "Past failure",
    "Loss of pleasure",
    "Guilty feelings",
    "Punishment feelings",
    "Self-dislike",
    "Self-criticalness",
    "Suicidal thoughts or wishes",
    "Crying",
    "Agitation",
    "Loss of interest",
    "Indecisiveness",
    "Worthlessness",
    "Loss of energy",
    "Changes in sleeping pattern",
    "Irritability",
    "Changes in appetite",
    "Concentration difficulty",
    "Tiredness or fatigue",
    "Loss of interest in sex",
]
N_CONCEPTS = len(CONCEPT_NAMES)

print(f"✓ Defined {N_CONCEPTS} BDI-II concepts")

In [None]:
# Hyperparameters
# Single source of truth for model, training and loss settings. The class counts
# (n_positive / n_negative) are placeholders here and are filled in once the
# class statistics have been loaded.
HYPERPARAMS = {
    # Model architecture
    "embedding_dim": 384,
    "n_concepts": N_CONCEPTS,     # derived from CONCEPT_NAMES so the two cannot drift apart
    "n_tasks": 1,
    "emb_size": 128,

    # Training
    "batch_size_train": 32,
    "batch_size_eval": 64,
    "max_epochs": 100,
    "learning_rate": 0.01,
    "weight_decay": 4e-05,

    # Loss weights
    "concept_loss_weight": 1.0,
    "training_intervention_prob": 0.25,

    # ===== LOSS FUNCTION SELECTION (Enable ONE) =====
    # LDAM Loss (RECOMMENDED - Best for severe class imbalance)
    "use_ldam_loss": True,       # Enable LDAM Loss
    "use_focal_loss": False,      # Disable Focal Loss

    # LDAM Loss parameters (only used if use_ldam_loss=True)
    "n_positive": None,           # Will be set after loading data
    "n_negative": None,           # Will be set after loading data
    "ldam_max_margin": 0.5,       # Try: 0.3, 0.5, 0.7, 1.0
    "ldam_scale": 30,             # Try: 20, 30, 40, 50

    # Focal Loss parameters (only used if use_focal_loss=True)
    "focal_loss_alpha": 0.17,    # Proportion of positive class
    "focal_loss_gamma": 3.0,     # Focusing parameter (2.0-4.0)

    # Weighted Sampler (batch-level oversampling)
    "use_weighted_sampler": True,  # Enable WeightedRandomSampler
}

# The two custom task losses are documented above as mutually exclusive; fail
# fast here instead of silently training with an ambiguous configuration.
if HYPERPARAMS['use_ldam_loss'] and HYPERPARAMS['use_focal_loss']:
    raise ValueError(
        "Enable only ONE of HYPERPARAMS['use_ldam_loss'] and HYPERPARAMS['use_focal_loss']"
    )

print("✓ Hyperparameters configured")
if HYPERPARAMS['use_ldam_loss']:
    print(f"  Using LDAM LOSS (margin={HYPERPARAMS['ldam_max_margin']}, scale={HYPERPARAMS['ldam_scale']})")
elif HYPERPARAMS['use_focal_loss']:
    print(f"  Using FOCAL LOSS (alpha={HYPERPARAMS['focal_loss_alpha']}, gamma={HYPERPARAMS['focal_loss_gamma']})")
else:
    print("  Using standard BCE loss with class weights")

## Section 1: Load Preprocessed Data

In [None]:
# Load training data
print("Loading preprocessed datasets...")

# np.load on an .npz returns a lazy NpzFile that keeps the archive open.
# Reading the arrays inside a `with` block materialises them in memory and then
# releases the file handle (re-running the cell no longer leaks handles).
with np.load(os.path.join(DATASET_DIR, "train_data.npz")) as train_data:
    X_train = train_data['X']
    C_train = train_data['C']
    y_train = train_data['y']
    train_subject_ids = train_data['subject_ids']

print(f"✓ Loaded training data: {X_train.shape}")

In [None]:
# Load validation data
# The `with` block closes the lazy NpzFile once the arrays have been read into
# memory, so the archive's file handle is not left open.
with np.load(os.path.join(DATASET_DIR, "val_data.npz")) as val_data:
    X_val = val_data['X']
    C_val = val_data['C']
    y_val = val_data['y']
    val_subject_ids = val_data['subject_ids']

print(f"✓ Loaded validation data: {X_val.shape}")

In [None]:
# Load test data
# The `with` block closes the lazy NpzFile once the arrays have been read into
# memory, so the archive's file handle is not left open.
with np.load(os.path.join(DATASET_DIR, "test_data.npz")) as test_data:
    X_test = test_data['X']
    C_test = test_data['C']
    y_test = test_data['y']
    test_subject_ids = test_data['subject_ids']

print(f"✓ Loaded test data: {X_test.shape}")

In [None]:
# Read the class-balance statistics stored in the dataset directory
class_weights_path = os.path.join(DATASET_DIR, "class_weights.json")
with open(class_weights_path, 'r') as fh:
    class_info = json.load(fh)

n_positive = class_info['n_positive']
n_negative = class_info['n_negative']
pos_weight = class_info['pos_weight']

# Fill in the class-count placeholders that the LDAM settings left as None
HYPERPARAMS.update(n_positive=n_positive, n_negative=n_negative)

# One-element float32 tensor holding the positive-class weight
pos_weight_tensor = torch.tensor([pos_weight], dtype=torch.float32)

print("✓ Loaded class weights:")
print(f"  Negative: {n_negative}, Positive: {n_positive}")
print(f"  Ratio: 1:{pos_weight:.2f}")

## Section 2: PyTorch Dataset & DataLoaders

In [None]:
from torch.utils.data import WeightedRandomSampler

class CEMDataset(Dataset):
    """In-memory dataset that serves (features, label, concepts) triples as float32 tensors."""

    def __init__(self, X, C, y):
        """Copy the three array-likes into float32 tensors.

        Args:
            X: Input features, indexed by sample.
            C: Concept annotations, indexed by sample.
            y: Task labels, indexed by sample.
        """
        self.X, self.C, self.y = (
            torch.tensor(values, dtype=torch.float32) for values in (X, C, y)
        )

    def __len__(self):
        """Return the number of samples, taken from the label tensor."""
        return len(self.y)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` in the order (features, label, concepts)."""
        features = self.X[idx]
        label = self.y[idx]
        concepts = self.C[idx]
        return features, label, concepts

# Create datasets
train_dataset = CEMDataset(X_train, C_train, y_train)
val_dataset = CEMDataset(X_val, C_val, y_val)
test_dataset = CEMDataset(X_test, C_test, y_test)

# Create WeightedRandomSampler for batch-level oversampling (if enabled)
# NOTE(review): with use_ldam_loss also enabled, the model receives the raw
# n_positive / n_negative counts while the batches are rebalanced here —
# confirm that correcting for imbalance in both places is intended.
if HYPERPARAMS['use_weighted_sampler']:
    train_labels = y_train.astype(int)

    # minlength=2 keeps a slot for both classes even if one is absent from y_train
    class_sample_counts = np.bincount(train_labels, minlength=2)  # [n_negative, n_positive]
    if np.any(class_sample_counts[:2] == 0):
        raise ValueError(
            "WeightedRandomSampler needs both classes in y_train, "
            f"got class counts {class_sample_counts.tolist()}"
        )

    # Inverse-frequency weight per class, then one weight per training sample
    weights = 1. / class_sample_counts
    sample_weights = weights[train_labels]

    # Create sampler
    train_sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True  # Allow positive samples to appear multiple times
    )

    # Chance that one draw is a positive sample = positive weight mass / total
    # weight mass. (The ratio of the two per-class weights is NOT this quantity;
    # with inverse-frequency weights each class carries equal total mass.)
    expected_positive_ratio = sample_weights[train_labels == 1].sum() / sample_weights.sum()

    print("✓ WeightedRandomSampler created:")
    print(f"  Negative weight: {weights[0]:.4f}")
    print(f"  Positive weight: {weights[1]:.4f}")
    print(f"  Expected positive ratio per batch: ~{expected_positive_ratio:.1%}")

    # Create train loader with sampler (shuffle=False when using sampler)
    train_loader = DataLoader(train_dataset, batch_size=HYPERPARAMS['batch_size_train'], sampler=train_sampler)
else:
    # Standard train loader with shuffle
    train_loader = DataLoader(train_dataset, batch_size=HYPERPARAMS['batch_size_train'], shuffle=True)
    print("✓ Using standard DataLoader (shuffle=True)")

# Validation and test loaders (no sampling, fixed order)
val_loader = DataLoader(val_dataset, batch_size=HYPERPARAMS['batch_size_eval'], shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=HYPERPARAMS['batch_size_eval'], shuffle=False)

print("✓ All DataLoaders created")

## Section 3: CEM Model Initialization

In [None]:
def c_extractor_arch(output_dim):
    """Build the small MLP that is passed to the model as its concept extractor.

    Args:
        output_dim: Width of the final linear layer. Any falsy value
            (e.g. ``None`` or ``0``) falls back to 256.

    Returns:
        nn.Sequential: Linear(embedding_dim -> 256), ReLU, Dropout(0.3),
        Linear(256 -> final width).
    """
    hidden_width = 256
    final_width = output_dim if output_dim else hidden_width
    layers = [
        nn.Linear(HYPERPARAMS['embedding_dim'], hidden_width),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(hidden_width, final_width),
    ]
    return nn.Sequential(*layers)

# Initialize CEM model with LDAM Loss support
# Class weights are only handed to the model when neither custom loss is active.
custom_loss_enabled = HYPERPARAMS['use_focal_loss'] or HYPERPARAMS['use_ldam_loss']
task_class_weights = None if custom_loss_enabled else pos_weight_tensor

# Settings forwarded to the constructor under the same name they have in HYPERPARAMS
passthrough_keys = (
    # architecture / optimisation
    'n_concepts', 'n_tasks', 'emb_size',
    'concept_loss_weight', 'training_intervention_prob',
    'learning_rate', 'weight_decay',
    # Focal Loss params
    'use_focal_loss', 'focal_loss_alpha', 'focal_loss_gamma',
    # LDAM Loss params
    'use_ldam_loss', 'n_positive', 'n_negative', 'ldam_max_margin', 'ldam_scale',
)

cem_model = PatchedConceptEmbeddingModel(
    input_dim=HYPERPARAMS['embedding_dim'],
    c_extractor_arch=c_extractor_arch,
    c2y_model=None,
    task_class_weights=task_class_weights,
    **{key: HYPERPARAMS[key] for key in passthrough_keys},
)

print("✓ CEM model initialized")
if HYPERPARAMS['use_ldam_loss']:
    print(f"  Using LDAM Loss (margin={HYPERPARAMS['ldam_max_margin']}, scale={HYPERPARAMS['ldam_scale']})")
    print(f"  Class counts: {HYPERPARAMS['n_positive']} positive, {HYPERPARAMS['n_negative']} negative")
elif HYPERPARAMS['use_focal_loss']:
    print(f"  Using Focal Loss (alpha={HYPERPARAMS['focal_loss_alpha']}, gamma={HYPERPARAMS['focal_loss_gamma']})")
else:
    print(f"  Using BCE Loss with pos_weight={pos_weight:.4f}")

## Section 4: Training

In [None]:
# Setup trainer
models_dir = os.path.join(OUTPUT_DIR, "models")
logs_dir = os.path.join(OUTPUT_DIR, "logs")

# Keep only the single checkpoint with the lowest monitored "val_loss"
checkpoint_callback = ModelCheckpoint(
    dirpath=models_dir,
    filename="cem-{epoch:02d}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_top_k=1,
)

# Metrics are written as CSV under <OUTPUT_DIR>/logs/cem_pipeline
csv_logger = CSVLogger(save_dir=logs_dir, name="cem_pipeline")

trainer = pl.Trainer(
    accelerator=DEVICE,
    devices=1,
    max_epochs=HYPERPARAMS['max_epochs'],
    logger=csv_logger,
    callbacks=[checkpoint_callback],
    log_every_n_steps=10,
    enable_progress_bar=True,
)

print("✓ Trainer configured")

In [None]:
# Train model
# Fits cem_model on train_loader and validates on val_loader. The trainer was
# built with a ModelCheckpoint callback that monitors "val_loss", and with
# max_epochs taken from HYPERPARAMS.
print("\nStarting training...\n")
trainer.fit(cem_model, train_loader, val_loader)
print("\n✓ Training complete!")

## Section 5: Test Evaluation

In [None]:
# Run inference on test set
print("Running inference on test set...")

# NOTE(review): this scores the in-memory weights left after the last training
# epoch; the best-"val_loss" checkpoint tracked by checkpoint_callback is not
# reloaded here — confirm that is the intended model to report.
cem_model.eval()
device_obj = torch.device(DEVICE)
cem_model = cem_model.to(device_obj)

y_true_list = []
y_prob_list = []
concept_probs_list = []

with torch.no_grad():
    for x_batch, y_batch, _c_batch in test_loader:
        x_batch = x_batch.to(device_obj)

        c_logits, _, y_logits = cem_model(x_batch)
        # NOTE(review): sigmoid is applied to the first model output as if it
        # were logits — confirm PatchedConceptEmbeddingModel does not already
        # return probabilities there.
        c_probs = torch.sigmoid(c_logits).cpu().numpy()

        # reshape(-1) keeps a 1-D array even when the final batch holds a single
        # sample. The previous squeeze() collapsed that case to a 0-d array whose
        # .tolist() is a bare float, which list.extend() rejects with TypeError.
        # (Flattening assumes a single task output per sample, n_tasks == 1.)
        y_probs = torch.sigmoid(y_logits).reshape(-1).cpu().numpy()

        y_true_list.extend(y_batch.numpy().astype(int).reshape(-1).tolist())
        y_prob_list.extend(y_probs.tolist())
        concept_probs_list.extend(c_probs.tolist())

y_true = np.array(y_true_list)
y_prob = np.array(y_prob_list)
concept_probs = np.array(concept_probs_list)

print("✓ Inference complete")

In [None]:
# Choose the decision threshold on the VALIDATION set, then apply it to the test set.
# Picking the threshold that maximises F1 on the test labels themselves would leak
# those labels into the reported test metrics and make them optimistically biased.
print("\nTesting different decision thresholds (validation set)...")

# Score the validation set with the same forward pass used for test inference
val_true_list = []
val_prob_list = []

cem_model.eval()
with torch.no_grad():
    for x_batch, y_batch, _ in val_loader:
        _, _, y_logits = cem_model(x_batch.to(device_obj))
        # reshape(-1) keeps the batch 1-D even for a single-sample batch
        val_prob_list.extend(torch.sigmoid(y_logits).reshape(-1).cpu().numpy().tolist())
        val_true_list.extend(y_batch.numpy().astype(int).reshape(-1).tolist())

y_val_true = np.array(val_true_list)
y_val_prob = np.array(val_prob_list)

print(f"{'Threshold':<12} {'Recall':<10} {'Precision':<10} {'F1':<10}")
print("-"*50)

# Falls back to 0.5 if no candidate threshold yields any positive prediction
best_f1 = 0
best_threshold = 0.5

for threshold in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]:
    y_pred_temp = (y_val_prob >= threshold).astype(int)

    # Skip thresholds that predict no positives at all (precision undefined)
    if np.sum(y_pred_temp) == 0:
        continue

    recall = recall_score(y_val_true, y_pred_temp, zero_division=0)
    precision = precision_score(y_val_true, y_pred_temp, zero_division=0)
    f1 = f1_score(y_val_true, y_pred_temp, zero_division=0)

    print(f"{threshold:<12.1f} {recall:<10.4f} {precision:<10.4f} {f1:<10.4f}")

    if f1 > best_f1:
        best_f1 = f1
        best_threshold = threshold

print(f"\n✓ Best threshold: {best_threshold:.2f} (validation F1={best_f1:.4f})")

# Apply the validation-selected threshold to the held-out test probabilities
y_pred = (y_prob >= best_threshold).astype(int)

## Section 6: Results Display

In [None]:
# Compute all metrics
# labels=[0, 1] pins the confusion matrix to 2x2, so the four-way unpack below
# cannot fail when one class is missing from both y_true and y_pred.
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
tn, fp, fn, tp = cm.ravel()

# Ground-truth class counts, reused for the per-cell percentages below
n_pos_true = int(np.sum(y_true))
n_neg_true = int(len(y_true) - np.sum(y_true))

acc = accuracy_score(y_true, y_pred)
balanced_acc = balanced_accuracy_score(y_true, y_pred)
roc_auc = roc_auc_score(y_true, y_prob)
mcc = matthews_corrcoef(y_true, y_pred)
# zero_division=0 returns 0 (without a warning) when a score is undefined,
# e.g. precision when no positive was predicted
f1_binary = f1_score(y_true, y_pred, pos_label=1, zero_division=0)
f1_macro = f1_score(y_true, y_pred, average='macro', zero_division=0)
precision_binary = precision_score(y_true, y_pred, pos_label=1, zero_division=0)
recall_binary = recall_score(y_true, y_pred, pos_label=1, zero_division=0)

# Print results
print("\n" + "="*70)
print("                    TEST SET EVALUATION")
print("="*70)
print(f"\nDecision Threshold: {best_threshold:.2f}")

# Enhanced Confusion Matrix Display
# Column width matches the widest header ("Predicted Negative", 18 chars); the
# previous width of 12 let the headers overflow and misalign with the rows.
col_w = 18
table_w = 20 + 2 * (3 + col_w)
tn_cell, fp_cell = f"TN = {tn}", f"FP = {fp}"
fn_cell, tp_cell = f"FN = {fn}", f"TP = {tp}"

print(f"\n{'CONFUSION MATRIX':^{table_w}}")
print("=" * table_w)
print(f"{'':>20} │ {'Predicted Negative':^{col_w}} │ {'Predicted Positive':^{col_w}}")
print("─" * table_w)
print(f"{'Actual Negative':>20} │ {tn_cell:^{col_w}} │ {fp_cell:^{col_w}}")
print(f"{'Actual Positive':>20} │ {fn_cell:^{col_w}} │ {tp_cell:^{col_w}}")
print("=" * table_w)
print(f"\n  True Positives:  {tp:>3}/{n_pos_true:<3} ({100*tp/n_pos_true:>5.1f}% of depression cases caught)")
print(f"  False Negatives: {fn:>3}/{n_pos_true:<3} ({100*fn/n_pos_true:>5.1f}% of depression cases MISSED)")
print(f"  True Negatives:  {tn:>3}/{n_neg_true:<3} ({100*tn/n_neg_true:>5.1f}% of healthy correctly identified)")
print(f"  False Positives: {fp:>3}/{n_neg_true:<3} ({100*fp/n_neg_true:>5.1f}% false alarms)")

print("\nPerformance Metrics:")
print(f"  Accuracy:                  {acc:.4f}")
print(f"  Balanced Accuracy:         {balanced_acc:.4f}")
print(f"  ROC-AUC:                   {roc_auc:.4f}")
print(f"  Matthews Correlation:      {mcc:.4f}")
print(f"\n  F1 Score (Binary):         {f1_binary:.4f}")
print(f"  F1 Score (Macro):          {f1_macro:.4f}")
print(f"  Precision (Binary):        {precision_binary:.4f}")
print(f"  Recall (Binary):           {recall_binary:.4f}")

# labels=[0, 1] keeps the two target_names valid even if a class is absent
print("\n" + classification_report(
    y_true, y_pred, labels=[0, 1], target_names=['Negative', 'Positive'], zero_division=0
))
print("="*70)

In [None]:
# Save results
# 1) Scalar metrics -> JSON (key order below is the order written to the file)
n_test_positive = int(np.sum(y_true))

metrics_dict = {
    "threshold": float(best_threshold),
    "n_samples": int(len(y_true)),
    "n_positive": n_test_positive,
    "n_negative": int(len(y_true) - np.sum(y_true)),
}
metrics_dict.update(
    (metric_name, float(metric_value))
    for metric_name, metric_value in (
        ("accuracy", acc),
        ("balanced_accuracy", balanced_acc),
        ("roc_auc", roc_auc),
        ("mcc", mcc),
        ("f1_binary", f1_binary),
        ("f1_macro", f1_macro),
        ("precision_binary", precision_binary),
        ("recall_binary", recall_binary),
    )
)
metrics_dict["confusion_matrix"] = {
    cell_name: int(cell_count)
    for cell_name, cell_count in zip(("tn", "fp", "fn", "tp"), (tn, fp, fn, tp))
}

os.makedirs(os.path.join(OUTPUT_DIR, "results"), exist_ok=True)
with open(os.path.join(OUTPUT_DIR, "results/test_metrics.json"), 'w') as metrics_file:
    json.dump(metrics_dict, metrics_file, indent=4)

# 2) Per-subject predictions -> CSV: labels and task probability first, then
#    one column per concept (column i of concept_probs is CONCEPT_NAMES[i])
prediction_columns = {
    'subject_id': test_subject_ids,
    'y_true': y_true,
    'y_pred': y_pred,
    'y_prob': y_prob,
}
for col_idx, concept_name in enumerate(CONCEPT_NAMES):
    prediction_columns[concept_name] = concept_probs[:, col_idx]

predictions_df = pd.DataFrame(prediction_columns)
predictions_df.to_csv(os.path.join(OUTPUT_DIR, "results/test_predictions.csv"), index=False)

print(f"✓ Results saved to {OUTPUT_DIR}/results/")

In [None]:
# Closing banner listing where this notebook wrote its artifacts
banner = "=" * 70
summary_lines = [
    "\n" + banner,
    "              CEM TRAINING COMPLETE",
    banner,
    "\nGenerated files:",
    f"  Model checkpoint: {OUTPUT_DIR}/models/",
    f"  Metrics JSON:     {OUTPUT_DIR}/results/test_metrics.json",
    f"  Predictions CSV:  {OUTPUT_DIR}/results/test_predictions.csv",
    banner,
]
print("\n".join(summary_lines))