# Optuna Hyperparameter Optimization - Custom CEM

**Goal**: Find optimal hyperparameters to maximize MCC on validation set

**Method**: TPE sampler + MedianPruner

**Estimated Time**: 6-8 hours (50-100 trials)

**Problem**: Current model predicts 93.7% false positives due to aggressive LDAM Loss + WeightedRandomSampler

**Solution**: Systematic hyperparameter search with early stopping

In [1]:
# Imports
import os
import json
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

from sklearn.metrics import (
    confusion_matrix,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    matthews_corrcoef,
    roc_auc_score,
)

import optuna

print("‚úì All imports successful")

‚úì All imports successful


In [None]:
# Base seed for reproducibility
BASE_SEED = 42

# Set Optuna sampler seed (for reproducible hyperparameter sampling)
np.random.seed(BASE_SEED)
torch.manual_seed(BASE_SEED)
pl.seed_everything(BASE_SEED)

print(f"‚úì Base seed set to {BASE_SEED}")
print("  Note: Per-trial seeds will be generated deterministically from trial.number")

In [3]:
# Detect device
if torch.backends.mps.is_available():
    DEVICE = "mps"
    print("‚úì Using MacBook GPU (MPS)")
elif torch.cuda.is_available():
    DEVICE = "cuda"
    print("‚úì Using CUDA GPU")
else:
    DEVICE = "cpu"
    print("‚ö† Using CPU")

‚úì Using MacBook GPU (MPS)


In [4]:
# Define paths
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
DATA_PROCESSED = os.path.join(PROJECT_ROOT, "data/processed")
DATASET_DIR = os.path.join(DATA_PROCESSED, "whole_pipeline")
OUTPUT_DIR = "outputs_optuna"

print("‚úì Paths configured")
print(f"  Dataset dir: {DATASET_DIR}")
print(f"  Output dir: {OUTPUT_DIR}")

‚úì Paths configured
  Dataset dir: /Users/gualtieromarencoturi/Desktop/thesis/Master-Thesis-CEM-Depression-etc-case-study/data/processed/whole_pipeline
  Output dir: outputs_optuna


In [5]:
# Define 21 BDI-II concept names
CONCEPT_NAMES = [
    "Sadness", "Pessimism", "Past failure", "Loss of pleasure",
    "Guilty feelings", "Punishment feelings", "Self-dislike", "Self-criticalness",
    "Suicidal thoughts or wishes", "Crying", "Agitation", "Loss of interest",
    "Indecisiveness", "Worthlessness", "Loss of energy", "Changes in sleeping pattern",
    "Irritability", "Changes in appetite", "Concentration difficulty",
    "Tiredness or fatigue", "Loss of interest in sex"
]
N_CONCEPTS = len(CONCEPT_NAMES)

print(f"‚úì Defined {N_CONCEPTS} BDI-II concepts")

‚úì Defined 21 BDI-II concepts


In [6]:
# Load training data
print("Loading preprocessed datasets...")

train_data = np.load(os.path.join(DATASET_DIR, "train_data.npz"))
X_train = train_data['X']
C_train = train_data['C']
y_train = train_data['y']

print(f"‚úì Loaded training data: {X_train.shape}")

# Load validation data
val_data = np.load(os.path.join(DATASET_DIR, "val_data.npz"))
X_val = val_data['X']
C_val = val_data['C']
y_val = val_data['y']

print(f"‚úì Loaded validation data: {X_val.shape}")

# Load class weights
with open(os.path.join(DATASET_DIR, "class_weights.json"), 'r') as f:
    class_info = json.load(f)

n_positive = class_info['n_positive']
n_negative = class_info['n_negative']
pos_weight = class_info['pos_weight']

print(f"\n‚úì Class distribution:")
print(f"  Negative: {n_negative}, Positive: {n_positive}")
print(f"  Ratio: 1:{pos_weight:.2f}")

print("\n‚ö† Test data will be loaded ONLY after optimization completes!")

Loading preprocessed datasets...
‚úì Loaded training data: (486, 384)
‚úì Loaded validation data: (200, 384)

‚úì Class distribution:
  Negative: 403, Positive: 83
  Ratio: 1:4.86

‚ö† Test data will be loaded ONLY after optimization completes!


In [7]:
# Fixed hyperparameters
FIXED_PARAMS = {
    "embedding_dim": 384,
    "n_concepts": 21,
    "n_tasks": 1,
    "batch_size_train": 32,
    "batch_size_eval": 64,
    "max_epochs": 100,
    "shared_prob_gen": True,
}

print("‚úì Fixed hyperparameters configured:")
for key, value in FIXED_PARAMS.items():
    print(f"  {key}: {value}")

‚úì Fixed hyperparameters configured:
  embedding_dim: 384
  n_concepts: 21
  n_tasks: 1
  batch_size_train: 32
  batch_size_eval: 64
  max_epochs: 100
  shared_prob_gen: True


In [8]:
# PyTorch Dataset
class CEMDataset(Dataset):
    def __init__(self, X, C, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.C = torch.tensor(C, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)
    
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx], self.C[idx]

# Create datasets (DataLoaders will be created per trial)
train_dataset = CEMDataset(X_train, C_train, y_train)
val_dataset = CEMDataset(X_val, C_val, y_val)

# Validation loader (fixed, no sampling)
val_loader = DataLoader(val_dataset, batch_size=FIXED_PARAMS['batch_size_eval'], shuffle=False)

print("‚úì Datasets created")
print("  Train DataLoader will be created per trial (with/without sampler)")
print("  Validation DataLoader created (fixed)")

‚úì Datasets created
  Train DataLoader will be created per trial (with/without sampler)
  Validation DataLoader created (fixed)


In [9]:
# LDAM Loss (for class imbalance)
class LDAMLoss(nn.Module):
    """
    Label-Distribution-Aware Margin (LDAM) Loss for long-tailed recognition.
    
    Creates class-dependent margins to make decision boundaries harder for minority classes.
    """
    def __init__(self, n_positive, n_negative, max_margin=0.5, scale=30):
        super(LDAMLoss, self).__init__()
        self.max_margin = max_margin
        self.scale = scale
        
        # Compute class frequencies
        total = n_positive + n_negative
        freq_pos = n_positive / total
        freq_neg = n_negative / total
        
        # Compute margins: minority class gets larger margin
        margin_pos = max_margin * (freq_pos ** (-0.25))
        margin_neg = max_margin * (freq_neg ** (-0.25))
        
        self.register_buffer('margin_pos', torch.tensor(margin_pos))
        self.register_buffer('margin_neg', torch.tensor(margin_neg))
    
    def forward(self, logits, targets):
        logits = logits.view(-1)
        targets = targets.view(-1).float()
        
        # Apply class-dependent margins
        margin = targets * self.margin_pos + (1 - targets) * (-self.margin_neg)
        adjusted_logits = (logits - margin) * self.scale
        
        return F.binary_cross_entropy_with_logits(adjusted_logits, targets, reduction='mean')


# Custom CEM Implementation
class CustomCEM(pl.LightningModule):
    """
    Custom Concept Embedding Model (CEM) implementation.
    
    Architecture:
      X ‚Üí concept_extractor ‚Üí context_layers ‚Üí prob_generator ‚Üí dual_embeddings ‚Üí task_classifier ‚Üí y
    """
    def __init__(
        self,
        n_concepts=21,
        emb_size=128,
        input_dim=384,
        shared_prob_gen=True,
        intervention_prob=0.25,
        concept_loss_weight=1.0,
        learning_rate=0.01,
        weight_decay=4e-05,
        use_ldam_loss=True,
        n_positive=83,
        n_negative=403,
        ldam_max_margin=0.5,
        ldam_scale=30,
    ):
        super().__init__()
        self.save_hyperparameters()
        
        self.n_concepts = n_concepts
        self.emb_size = emb_size
        self.intervention_prob = intervention_prob
        self.concept_loss_weight = concept_loss_weight
        
        # Stage 1: Concept Extractor (X ‚Üí Pre-Concept Features)
        self.concept_extractor = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 256)
        )
        
        # Stage 2: Context Generators (Features ‚Üí Dual Embeddings)
        self.context_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(256, emb_size * 2),
                nn.LeakyReLU()
            ) for _ in range(n_concepts)
        ])
        
        # Stage 3: Probability Generator
        if shared_prob_gen:
            self.prob_generator = nn.Linear(emb_size * 2, 1)
        else:
            self.prob_generators = nn.ModuleList([
                nn.Linear(emb_size * 2, 1) for _ in range(n_concepts)
            ])
        
        self.shared_prob_gen = shared_prob_gen
        
        # Stage 4: Task Classifier
        self.task_classifier = nn.Sequential(
            nn.Linear(n_concepts * emb_size, 128),
            nn.LeakyReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 1)
        )
        
        # Loss functions
        self.concept_loss_fn = nn.BCEWithLogitsLoss()
        if use_ldam_loss:
            self.task_loss_fn = LDAMLoss(n_positive, n_negative, ldam_max_margin, ldam_scale)
        else:
            self.task_loss_fn = nn.BCEWithLogitsLoss()
    
    def forward(self, x, c_true=None, train=False):
        # Extract features
        pre_features = self.concept_extractor(x)
        
        # Generate contexts and probabilities
        contexts = []
        c_logits_list = []
        
        for i, context_layer in enumerate(self.context_layers):
            context = context_layer(pre_features)
            
            if self.shared_prob_gen:
                logit = self.prob_generator(context)
            else:
                logit = self.prob_generators[i](context)
            
            contexts.append(context)
            c_logits_list.append(logit)
        
        c_logits = torch.cat(c_logits_list, dim=1)
        c_probs = torch.sigmoid(c_logits)
        
        # Apply intervention
        if train and self.intervention_prob > 0 and c_true is not None:
            intervention_mask = torch.bernoulli(
                torch.ones_like(c_probs) * self.intervention_prob
            )
            c_probs = c_probs * (1 - intervention_mask) + c_true * intervention_mask
        
        # Mix dual embeddings
        concept_embeddings = []
        for i, context in enumerate(contexts):
            emb_true = context[:, :self.emb_size]
            emb_false = context[:, self.emb_size:]
            
            prob = c_probs[:, i:i+1]
            mixed_emb = emb_true * prob + emb_false * (1 - prob)
            concept_embeddings.append(mixed_emb)
        
        c_embeddings = torch.cat(concept_embeddings, dim=1)
        y_logits = self.task_classifier(c_embeddings)
        
        return c_logits, y_logits
    
    def training_step(self, batch, batch_idx):
        x, y, c_true = batch
        c_logits, y_logits = self.forward(x, c_true=c_true, train=True)
        
        task_loss = self.task_loss_fn(y_logits.squeeze(), y.squeeze())
        concept_loss = self.concept_loss_fn(c_logits, c_true)
        loss = task_loss + self.concept_loss_weight * concept_loss
        
        self.log('train_loss', loss, on_epoch=True)
        self.log('train_task_loss', task_loss, on_epoch=True)
        self.log('train_concept_loss', concept_loss, on_epoch=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y, c_true = batch
        c_logits, y_logits = self.forward(x, c_true=c_true, train=False)
        
        task_loss = self.task_loss_fn(y_logits.squeeze(), y.squeeze())
        concept_loss = self.concept_loss_fn(c_logits, c_true)
        loss = task_loss + self.concept_loss_weight * concept_loss
        
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_task_loss', task_loss, on_epoch=True)
        self.log('val_concept_loss', concept_loss, on_epoch=True)
        
        return loss
    
    def configure_optimizers(self):
        return torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay
        )

print("‚úì CustomCEM and LDAMLoss classes defined")

‚úì CustomCEM and LDAMLoss classes defined


In [None]:
# Objective function for Optuna
def objective(trial):
    """
    Optuna objective function to maximize validation F1 while achieving 75% recall.
    
    Strategy:
      1. Train model with sampled hyperparameters
      2. Find threshold that achieves 75% recall on validation set
      3. Return F1 score at that threshold
    
    Returns:
        float: F1 score on validation set (with threshold optimized for 75% recall)
    """
    # ============================================================================
    # STEP 0: Per-Trial Deterministic Seeding
    # ============================================================================
    trial_seed = BASE_SEED + trial.number
    
    np.random.seed(trial_seed)
    torch.manual_seed(trial_seed)
    pl.seed_everything(trial_seed, workers=True)
    
    trial.set_user_attr('trial_seed', trial_seed)
    
    # ============================================================================
    # STEP 1: Sample hyperparameters
    # ============================================================================
    use_ldam = trial.suggest_categorical('use_ldam_loss', [True, False])
    ldam_margin = trial.suggest_float('ldam_max_margin', 0.1, 1.0)
    ldam_scale = trial.suggest_int('ldam_scale', 10, 50)
    use_sampler = trial.suggest_categorical('use_weighted_sampler', [True, False])
    lr = trial.suggest_float('learning_rate', 0.001, 0.05, log=True)
    concept_weight = trial.suggest_float('concept_loss_weight', 0.5, 2.0)
    emb_size = trial.suggest_categorical('emb_size', [64, 128, 256])
    intervention = trial.suggest_float('intervention_prob', 0.0, 0.5)
    wd = trial.suggest_float('weight_decay', 1e-5, 1e-3, log=True)
    
    # Log all hyperparameters to trial attributes
    trial.set_user_attr('use_ldam_loss', use_ldam)
    trial.set_user_attr('ldam_max_margin', ldam_margin)
    trial.set_user_attr('ldam_scale', ldam_scale)
    trial.set_user_attr('use_weighted_sampler', use_sampler)
    trial.set_user_attr('learning_rate', lr)
    trial.set_user_attr('concept_loss_weight', concept_weight)
    trial.set_user_attr('emb_size', emb_size)
    trial.set_user_attr('intervention_prob', intervention)
    trial.set_user_attr('weight_decay', wd)
    
    # ============================================================================
    # STEP 2: Create DataLoader with or without sampler
    # ============================================================================
    if use_sampler:
        # Batch-level oversampling
        class_sample_counts = np.bincount(y_train.astype(int))
        weights = 1.0 / class_sample_counts
        sample_weights = weights[y_train.astype(int)]
        
        train_sampler = WeightedRandomSampler(
            weights=sample_weights,
            num_samples=len(sample_weights),
            replacement=True,
            generator=torch.Generator().manual_seed(trial_seed)  # Seeded generator
        )
        
        # Log sampler statistics
        trial.set_user_attr('sampler_negative_weight', float(weights[0]))
        trial.set_user_attr('sampler_positive_weight', float(weights[1]))
        trial.set_user_attr('sampler_expected_pos_ratio', 
                           float(weights[1]/(weights[0]+weights[1])))
        
        train_loader = DataLoader(
            train_dataset,
            batch_size=FIXED_PARAMS['batch_size_train'],
            sampler=train_sampler,
            worker_init_fn=lambda worker_id: np.random.seed(trial_seed + worker_id)
        )
    else:
        train_loader = DataLoader(
            train_dataset,
            batch_size=FIXED_PARAMS['batch_size_train'],
            shuffle=True,
            generator=torch.Generator().manual_seed(trial_seed),  # Seeded generator
            worker_init_fn=lambda worker_id: np.random.seed(trial_seed + worker_id)
        )
    
    # ============================================================================
    # STEP 3: Create model
    # ============================================================================
    model = CustomCEM(
        n_concepts=FIXED_PARAMS['n_concepts'],
        emb_size=emb_size,
        input_dim=FIXED_PARAMS['embedding_dim'],
        shared_prob_gen=FIXED_PARAMS['shared_prob_gen'],
        intervention_prob=intervention,
        concept_loss_weight=concept_weight,
        learning_rate=lr,
        weight_decay=wd,
        use_ldam_loss=use_ldam,
        n_positive=n_positive,
        n_negative=n_negative,
        ldam_max_margin=ldam_margin,
        ldam_scale=ldam_scale,
    )
    
    # ============================================================================
    # STEP 4: Setup trainer with EarlyStopping
    # ============================================================================
    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        patience=15,  # Stop if no improvement for 15 epochs
        mode='min',
        verbose=False
    )
    
    trainer = pl.Trainer(
        max_epochs=FIXED_PARAMS['max_epochs'],
        accelerator=DEVICE,
        devices=1,
        callbacks=[early_stop_callback],
        enable_progress_bar=False,
        logger=False,
        enable_checkpointing=False,
    )
    
    # ============================================================================
    # STEP 5: Train
    # ============================================================================
    try:
        trainer.fit(model, train_loader, val_loader)
    except optuna.TrialPruned:
        raise
    
    # ============================================================================
    # STEP 6: Run validation inference with proper device handling
    # ============================================================================
    model.eval()
    device_obj = torch.device(DEVICE)
    model = model.to(device_obj)
    
    y_prob_val = []
    y_true_val = []
    
    with torch.no_grad():
        for x_batch, y_batch, c_batch in val_loader:
            # Move to device
            x_batch = x_batch.to(device_obj)
            c_batch = c_batch.to(device_obj)
            
            # Forward pass
            c_logits, y_logits = model(x_batch)
            
            # Move to CPU for numpy conversion
            y_probs = torch.sigmoid(y_logits).cpu().squeeze().numpy()
            y_true_batch = y_batch.cpu().numpy().astype(int)  # Explicit CPU
            
            # Collect predictions
            if y_probs.ndim == 0:  # Single sample edge case
                y_prob_val.append(float(y_probs))
                y_true_val.append(int(y_true_batch))
            else:
                y_prob_val.extend(y_probs.tolist())
                y_true_val.extend(y_true_batch.tolist())
    
    y_prob_val = np.array(y_prob_val)
    y_true_val = np.array(y_true_val)
    
    # ============================================================================
    # STEP 7: Find threshold achieving 75% recall on validation set
    # ============================================================================
    target_recall = 0.75
    best_threshold, achieved_recall, precision = find_threshold_for_target_recall(
        y_true_val, y_prob_val, target_recall=target_recall
    )
    
    # Calculate F1 for this threshold
    if achieved_recall > 0 and precision > 0:
        f1 = 2 * (precision * achieved_recall) / (precision + achieved_recall)
    else:
        f1 = 0.0
    
    # Calculate MCC as well
    y_pred_val = (y_prob_val >= best_threshold).astype(int)
    mcc = matthews_corrcoef(y_true_val, y_pred_val)
    
    # ============================================================================
    # STEP 8: Log all metrics to trial user attributes
    # ============================================================================
    trial.set_user_attr('best_threshold', float(best_threshold))
    trial.set_user_attr('achieved_recall', float(achieved_recall))
    trial.set_user_attr('precision', float(precision))
    trial.set_user_attr('f1_score', float(f1))
    trial.set_user_attr('mcc', float(mcc))
    trial.set_user_attr('target_recall', float(target_recall))
    
    # Log training info
    if hasattr(trainer, 'callback_metrics'):
        metrics = trainer.callback_metrics
        if 'val_loss' in metrics:
            trial.set_user_attr('final_val_loss', float(metrics['val_loss']))
        if 'train_loss' in metrics:
            trial.set_user_attr('final_train_loss', float(metrics['train_loss']))
    
    trial.set_user_attr('num_epochs_trained', trainer.current_epoch)
    trial.set_user_attr('early_stopped', trainer.current_epoch < FIXED_PARAMS['max_epochs'])
    
    # ============================================================================
    # STEP 9: Return F1 score (objective to maximize)
    # ============================================================================
    return f1


def find_threshold_for_target_recall(y_true, y_prob, target_recall=0.75):
    """
    Find threshold that achieves target recall with best precision.
    
    For depression screening: Prioritize catching cases (recall) over precision.
    
    Args:
        y_true: True labels
        y_prob: Predicted probabilities
        target_recall: Minimum recall required (default: 0.75)
    
    Returns:
        best_threshold: Threshold achieving target recall
        achieved_recall: Actual recall achieved
        precision: Precision at that threshold
    """
    best_precision = 0
    best_threshold = 0.5
    achieved_recall = 0
    
    for threshold in np.arange(0.01, 0.99, 0.01):  # Fine-grained search
        y_pred = (y_prob >= threshold).astype(int)
        
        # Skip if no positives predicted
        if np.sum(y_pred) == 0:
            continue
        
        try:
            recall = recall_score(y_true, y_pred)
            precision = precision_score(y_true, y_pred)
        except:
            continue
        
        # Only consider thresholds that meet recall target
        if recall >= target_recall:
            if precision > best_precision:
                best_precision = precision
                best_threshold = threshold
                achieved_recall = recall
    
    # If no threshold achieves target recall, return best recall achieved
    if achieved_recall == 0:
        print(f"‚ö† Cannot achieve {target_recall:.0%} recall. Finding best achievable...")
        best_recall = 0
        for threshold in np.arange(0.01, 0.99, 0.01):
            y_pred = (y_prob >= threshold).astype(int)
            if np.sum(y_pred) == 0:
                continue
            try:
                recall = recall_score(y_true, y_pred)
                precision = precision_score(y_true, y_pred)
            except:
                continue
            if recall > best_recall:
                best_recall = recall
                best_precision = precision
                best_threshold = threshold
                achieved_recall = recall
    
    return best_threshold, achieved_recall, best_precision

print("‚úì Objective function defined")

In [None]:
# Create Optuna study
study = optuna.create_study(
    direction='maximize',  
    sampler=optuna.samplers.TPESampler(seed=BASE_SEED),
    pruner=optuna.pruners.MedianPruner(
        n_startup_trials=5,       # Don't prune first 5 trials
        n_warmup_steps=10,        # Wait 10 epochs before pruning
        interval_steps=5,         # Check every 5 epochs
    )
)

print("="*70)
print("                 OPTUNA STUDY CREATED")
print("="*70)
print("\nConfiguration:")
print("  Objective:     Achieve 75% recall with maximum precision (maximize F1)")
print("  Sampler:       TPE (Tree-structured Parzen Estimator)")
print("  Pruner:        MedianPruner (early stopping)")
print("  Search Space:  9 hyperparameters")
print("\nHyperparameters to optimize:")
print("  - use_ldam_loss: [True, False]")
print("  - ldam_max_margin: [0.1, 1.0]")
print("  - ldam_scale: [10, 50]")
print("  - use_weighted_sampler: [True, False]")
print("  - learning_rate: [0.001, 0.05] (log scale)")
print("  - concept_loss_weight: [0.5, 2.0]")
print("  - emb_size: [64, 128, 256]")
print("  - intervention_prob: [0.0, 0.5]")
print("  - weight_decay: [1e-5, 1e-3] (log scale)")
print("="*70)

In [12]:
# Run optimization
n_trials = 10
timeout = 0.5 * 3600  #3600=1h

print("\n" + "="*70)
print("                STARTING OPTIMIZATION")
print("="*70)
print(f"\nSettings:")
print(f"  Max trials:        {n_trials}")
print(f"  Timeout:           {timeout/3600:.0f} hours")
print(f"  Expected runtime:  6-8 hours")
print(f"\n‚è∞ This will take several hours. Monitor progress below...\n")
print("="*70)

study.optimize(
    objective,
    n_trials=n_trials,
    timeout=timeout,
    show_progress_bar=True
)

print("\n" + "="*70)
print("                OPTIMIZATION COMPLETE")
print("="*70)
print(f"\nResults:")
print(f"  Completed trials:  {len(study.trials)}")
print(f"  Pruned trials:     {len([t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED])}")
print(f"  Best MCC:          {study.best_value:.4f}")
print("="*70)


                STARTING OPTIMIZATION

Settings:
  Max trials:        10
  Timeout:           0 hours
  Expected runtime:  6-8 hours

‚è∞ This will take several hours. Monitor progress below...



  0%|          | 0/10 [00:00<?, ?it/s]

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name              | Type              | Params
--------------------------------------------------------
0 | concept_extractor | Sequential        | 164 K 
1 | context_layers    | ModuleList        | 1.4 M 
2 | prob_generator    | Linear            | 257   
3 | task_classifier   | Sequential        | 344 K 
4 | concept_loss_fn   | BCEWithLogitsLoss | 0     
5 | task_loss_fn      | BCEWithLogitsLoss | 0     
--------------------------------------------------------
1.9 M     Trainable params
0         Non-trainable params
1.9 M     Total params
7.562     Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name              | Type              | Params
--------

[I 2025-12-13 19:19:35,837] Trial 0 finished with value: 0.7384615384615384 and parameters: {'use_ldam_loss': False, 'ldam_max_margin': 0.7587945476302645, 'ldam_scale': 34, 'use_weighted_sampler': True, 'learning_rate': 0.001255111517297384, 'concept_loss_weight': 1.7992642186624028, 'emb_size': 128, 'intervention_prob': 0.48495492608099716, 'weight_decay': 0.000462258900102083}. Best is trial 0 with value: 0.7384615384615384.


  rank_zero_warn(
  rank_zero_warn(
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name              | Type              | Params
--------------------------------------------------------
0 | concept_extractor | Sequential        | 164 K 
1 | context_layers    | ModuleList        | 690 K 
2 | prob_generator    | Linear            | 129   
3 | task_classifier   | Sequential        | 172 K 
4 | concept_loss_fn   | BCEWithLogitsLoss | 0     
5 | task_loss_fn      | BCEWithLogitsLoss | 0     
--------------------------------------------------------
1.0 M     Trainable params
0         Non-trainable params
1.0 M     Total params
4.110     Total estimated model params size (MB)


[I 2025-12-13 19:19:53,595] Trial 1 finished with value: 0.7924528301886792 and parameters: {'use_ldam_loss': True, 'ldam_max_margin': 0.2650640588680905, 'ldam_scale': 22, 'use_weighted_sampler': True, 'learning_rate': 0.003124565071260871, 'concept_loss_weight': 1.4177793420835691, 'emb_size': 256, 'intervention_prob': 0.22803499210851796, 'weight_decay': 0.00037183641805732076}. Best is trial 1 with value: 0.7924528301886792.


  rank_zero_warn(
  rank_zero_warn(
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name              | Type              | Params
--------------------------------------------------------
0 | concept_extractor | Sequential        | 164 K 
1 | context_layers    | ModuleList        | 1.4 M 
2 | prob_generator    | Linear            | 257   
3 | task_classifier   | Sequential        | 344 K 
4 | concept_loss_fn   | BCEWithLogitsLoss | 0     
5 | task_loss_fn      | LDAMLoss          | 0     
--------------------------------------------------------
1.9 M     Trainable params
0         Non-trainable params
1.9 M     Total params
7.562     Total estimated model params size (MB)


[I 2025-12-13 19:20:10,393] Trial 2 finished with value: 0.7241379310344829 and parameters: {'use_ldam_loss': False, 'ldam_max_margin': 0.6331731119758383, 'ldam_scale': 11, 'use_weighted_sampler': True, 'learning_rate': 0.0012897950480855534, 'concept_loss_weight': 1.92332830588, 'emb_size': 64, 'intervention_prob': 0.048836057003191935, 'weight_decay': 0.000233596350262616}. Best is trial 1 with value: 0.7924528301886792.


  rank_zero_warn(
  rank_zero_warn(
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name              | Type              | Params
--------------------------------------------------------
0 | concept_extractor | Sequential        | 164 K 
1 | context_layers    | ModuleList        | 2.8 M 
2 | prob_generator    | Linear            | 513   
3 | task_classifier   | Sequential        | 688 K 
4 | concept_loss_fn   | BCEWithLogitsLoss | 0     
5 | task_loss_fn      | LDAMLoss          | 0     
--------------------------------------------------------
3.6 M     Trainable params
0         Non-trainable params
3.6 M     Total params
14.466    Total estimated model params size (MB)


[I 2025-12-13 19:20:27,936] Trial 3 finished with value: 0.6666666666666667 and parameters: {'use_ldam_loss': True, 'ldam_max_margin': 0.5456592191001431, 'ldam_scale': 11, 'use_weighted_sampler': True, 'learning_rate': 0.01335381908879058, 'concept_loss_weight': 0.9675666141341164, 'emb_size': 128, 'intervention_prob': 0.4847923138822793, 'weight_decay': 0.0003550304858128307}. Best is trial 1 with value: 0.7924528301886792.


  rank_zero_warn(
  rank_zero_warn(
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name              | Type              | Params
--------------------------------------------------------
0 | concept_extractor | Sequential        | 164 K 
1 | context_layers    | ModuleList        | 690 K 
2 | prob_generator    | Linear            | 129   
3 | task_classifier   | Sequential        | 172 K 
4 | concept_loss_fn   | BCEWithLogitsLoss | 0     
5 | task_loss_fn      | LDAMLoss          | 0     
--------------------------------------------------------
1.0 M     Trainable params
0         Non-trainable params
1.0 M     Total params
4.110     Total estimated model params size (MB)


[I 2025-12-13 19:20:55,558] Trial 4 finished with value: 0.6060606060606061 and parameters: {'use_ldam_loss': True, 'ldam_max_margin': 0.6381099809299766, 'ldam_scale': 47, 'use_weighted_sampler': False, 'learning_rate': 0.0011935477742481386, 'concept_loss_weight': 0.9879954961448965, 'emb_size': 256, 'intervention_prob': 0.17837666334679464, 'weight_decay': 3.6464395589807184e-05}. Best is trial 1 with value: 0.7924528301886792.


  rank_zero_warn(
  rank_zero_warn(
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name              | Type              | Params
--------------------------------------------------------
0 | concept_extractor | Sequential        | 164 K 
1 | context_layers    | ModuleList        | 2.8 M 
2 | prob_generator    | Linear            | 513   
3 | task_classifier   | Sequential        | 688 K 
4 | concept_loss_fn   | BCEWithLogitsLoss | 0     
5 | task_loss_fn      | LDAMLoss          | 0     
--------------------------------------------------------
3.6 M     Trainable params
0         Non-trainable params
3.6 M     Total params
14.466    Total estimated model params size (MB)


[I 2025-12-13 19:21:07,722] Trial 5 finished with value: 0.7333333333333334 and parameters: {'use_ldam_loss': True, 'ldam_max_margin': 0.8219772826786357, 'ldam_scale': 13, 'use_weighted_sampler': True, 'learning_rate': 0.002175764980119757, 'concept_loss_weight': 0.5082831756854036, 'emb_size': 64, 'intervention_prob': 0.38563517334297287, 'weight_decay': 1.4063366777718176e-05}. Best is trial 1 with value: 0.7924528301886792.


  rank_zero_warn(
  rank_zero_warn(
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name              | Type              | Params
--------------------------------------------------------
0 | concept_extractor | Sequential        | 164 K 
1 | context_layers    | ModuleList        | 2.8 M 
2 | prob_generator    | Linear            | 513   
3 | task_classifier   | Sequential        | 688 K 
4 | concept_loss_fn   | BCEWithLogitsLoss | 0     
5 | task_loss_fn      | BCEWithLogitsLoss | 0     
--------------------------------------------------------
3.6 M     Trainable params
0         Non-trainable params
3.6 M     Total params
14.466    Total estimated model params size (MB)


[I 2025-12-13 19:21:26,482] Trial 6 finished with value: 0.7924528301886792 and parameters: {'use_ldam_loss': True, 'ldam_max_margin': 0.8767930832880342, 'ldam_scale': 35, 'use_weighted_sampler': True, 'learning_rate': 0.0033755895712060846, 'concept_loss_weight': 0.9877749830401206, 'emb_size': 256, 'intervention_prob': 0.23610746258097465, 'weight_decay': 1.7345566642360933e-05}. Best is trial 1 with value: 0.7924528301886792.


  rank_zero_warn(
  rank_zero_warn(
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name              | Type              | Params
--------------------------------------------------------
0 | concept_extractor | Sequential        | 164 K 
1 | context_layers    | ModuleList        | 690 K 
2 | prob_generator    | Linear            | 129   
3 | task_classifier   | Sequential        | 172 K 
4 | concept_loss_fn   | BCEWithLogitsLoss | 0     
5 | task_loss_fn      | LDAMLoss          | 0     
--------------------------------------------------------
1.0 M     Trainable params
0         Non-trainable params
1.0 M     Total params
4.110     Total estimated model params size (MB)


[I 2025-12-13 19:21:39,612] Trial 7 finished with value: 0.7692307692307693 and parameters: {'use_ldam_loss': False, 'ldam_max_margin': 0.6051494778125466, 'ldam_scale': 41, 'use_weighted_sampler': False, 'learning_rate': 0.005325732706437205, 'concept_loss_weight': 0.5381286901161428, 'emb_size': 256, 'intervention_prob': 0.15717799053816334, 'weight_decay': 0.0001040258761588385}. Best is trial 1 with value: 0.7924528301886792.


  rank_zero_warn(
  rank_zero_warn(
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name              | Type              | Params
--------------------------------------------------------
0 | concept_extractor | Sequential        | 164 K 
1 | context_layers    | ModuleList        | 2.8 M 
2 | prob_generator    | Linear            | 513   
3 | task_classifier   | Sequential        | 688 K 
4 | concept_loss_fn   | BCEWithLogitsLoss | 0     
5 | task_loss_fn      | BCEWithLogitsLoss | 0     
--------------------------------------------------------
3.6 M     Trainable params
0         Non-trainable params
3.6 M     Total params
14.466    Total estimated model params size (MB)


[I 2025-12-13 19:21:59,089] Trial 8 finished with value: 0.7241379310344829 and parameters: {'use_ldam_loss': True, 'ldam_max_margin': 0.4693446307320668, 'ldam_scale': 40, 'use_weighted_sampler': True, 'learning_rate': 0.0031065548585819088, 'concept_loss_weight': 0.7418319308810066, 'emb_size': 64, 'intervention_prob': 0.43573029509385885, 'weight_decay': 0.00040489662225846743}. Best is trial 1 with value: 0.7924528301886792.


  rank_zero_warn(
  rank_zero_warn(


[I 2025-12-13 19:22:13,950] Trial 9 finished with value: 0.6875 and parameters: {'use_ldam_loss': False, 'ldam_max_margin': 0.5854080177240857, 'ldam_scale': 43, 'use_weighted_sampler': True, 'learning_rate': 0.0015380658115982007, 'concept_loss_weight': 0.8419027438129125, 'emb_size': 256, 'intervention_prob': 0.0034760652655953517, 'weight_decay': 0.00010507384024181397}. Best is trial 1 with value: 0.7924528301886792.

                OPTIMIZATION COMPLETE

Results:
  Completed trials:  10
  Pruned trials:     0
  Best MCC:          0.7925


# Custom CEM Model - PyTorch Implementation

**Runtime:** ~15-20 minutes

This notebook:
1. Implements CEM from scratch using PyTorch
2. Uses LDAM Loss + WeightedRandomSampler for class imbalance
3. Same hyperparameters as `1_train_cem.ipynb`

**Prerequisites:** Run `0_prepare_dataset.ipynb` first!

## Section 0: Setup & Configuration

In [13]:
# Imports
import os
import json
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint

from sklearn.metrics import (
    confusion_matrix,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    matthews_corrcoef,
    roc_auc_score,
    balanced_accuracy_score,
    classification_report,
)

print("‚úì All imports successful")

‚úì All imports successful


In [14]:
# Set random seeds
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
pl.seed_everything(SEED)

print(f"‚úì Random seed set to {SEED}")

Global seed set to 42


‚úì Random seed set to 42


In [15]:
# Detect device
if torch.backends.mps.is_available():
    DEVICE = "mps"
    print("‚úì Using MacBook GPU (MPS)")
elif torch.cuda.is_available():
    DEVICE = "cuda"
    print("‚úì Using CUDA GPU")
else:
    DEVICE = "cpu"
    print("‚ö† Using CPU")

‚úì Using MacBook GPU (MPS)


In [16]:
# Define paths
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
DATA_PROCESSED = os.path.join(PROJECT_ROOT, "data/processed")
DATASET_DIR = os.path.join(DATA_PROCESSED, "whole_pipeline")
OUTPUT_DIR = "outputs_custom_cem"

print("‚úì Paths configured")
print(f"  Dataset dir: {DATASET_DIR}")
print(f"  Output dir: {OUTPUT_DIR}")

‚úì Paths configured
  Dataset dir: /Users/gualtieromarencoturi/Desktop/thesis/Master-Thesis-CEM-Depression-etc-case-study/data/processed/whole_pipeline
  Output dir: outputs_custom_cem


In [17]:
# Define 21 BDI-II concept names
CONCEPT_NAMES = [
    "Sadness", "Pessimism", "Past failure", "Loss of pleasure",
    "Guilty feelings", "Punishment feelings", "Self-dislike", "Self-criticalness",
    "Suicidal thoughts or wishes", "Crying", "Agitation", "Loss of interest",
    "Indecisiveness", "Worthlessness", "Loss of energy", "Changes in sleeping pattern",
    "Irritability", "Changes in appetite", "Concentration difficulty",
    "Tiredness or fatigue", "Loss of interest in sex"
]
N_CONCEPTS = len(CONCEPT_NAMES)

print(f"‚úì Defined {N_CONCEPTS} BDI-II concepts")

‚úì Defined 21 BDI-II concepts


In [18]:
# Hyperparameters
HYPERPARAMS = {
    # Model architecture
    "embedding_dim": 384,
    "n_concepts": 21,
    "n_tasks": 1,
    "emb_size": 128,
    
    # CEM-specific
    "shared_prob_gen": True,        # Share probability generator across concepts
    "intervention_prob": 0.25,      # Training intervention probability
    
    # Training
    "batch_size_train": 32,
    "batch_size_eval": 64,
    "max_epochs": 100,
    "learning_rate": 0.01,
    "weight_decay": 4e-05,
    
    # Loss
    "concept_loss_weight": 1.0,
    
    # LDAM Loss
    "use_ldam_loss": True,
    "n_positive": None,               # Will be set after loading data
    "n_negative": None,               # Will be set after loading data
    "ldam_max_margin": 0.3,           # Try: 0.3, 0.5, 0.7, 1.0
    "ldam_scale": 20,                 # Try: 20, 30, 40, 50
    
    # Weighted Sampler
    "use_weighted_sampler": False,
}

print("‚úì Hyperparameters configured")
if HYPERPARAMS['use_ldam_loss']:
    print(f"  Using LDAM LOSS (margin={HYPERPARAMS['ldam_max_margin']}, scale={HYPERPARAMS['ldam_scale']})")
else:
    print(f"  Using standard BCE loss")

‚úì Hyperparameters configured
  Using LDAM LOSS (margin=0.3, scale=20)


## Section 1: Load Preprocessed Data

In [19]:
# Load training data
print("Loading preprocessed datasets...")

train_data = np.load(os.path.join(DATASET_DIR, "train_data.npz"))
X_train = train_data['X']
C_train = train_data['C']
y_train = train_data['y']
train_subject_ids = train_data['subject_ids']

print(f"‚úì Loaded training data: {X_train.shape}")

Loading preprocessed datasets...
‚úì Loaded training data: (486, 384)


In [20]:
# Load validation data
val_data = np.load(os.path.join(DATASET_DIR, "val_data.npz"))
X_val = val_data['X']
C_val = val_data['C']
y_val = val_data['y']
val_subject_ids = val_data['subject_ids']

print(f"‚úì Loaded validation data: {X_val.shape}")

‚úì Loaded validation data: (200, 384)


In [21]:
# Load test data
test_data = np.load(os.path.join(DATASET_DIR, "test_data.npz"))
X_test = test_data['X']
C_test = test_data['C']
y_test = test_data['y']
test_subject_ids = test_data['subject_ids']

print(f"‚úì Loaded test data: {X_test.shape}")

‚úì Loaded test data: (201, 384)


In [None]:
# Load class weights
with open(os.path.join(DATASET_DIR, "class_weights.json"), 'r') as f:
    class_info = json.load(f)

n_positive = class_info['n_positive']
n_negative = class_info['n_negative']
pos_weight = class_info['pos_weight']

# Update HYPERPARAMS with actual class counts for LDAM
HYPERPARAMS['n_positive'] = n_positive
HYPERPARAMS['n_negative'] = n_negative

print(f"‚úì Loaded class weights:")
print(f"  Negative: {n_negative}, Positive: {n_positive}")
print(f"  Ratio: 1:{pos_weight:.2f}")

## Section 2: PyTorch Dataset & DataLoaders

In [None]:
class CEMDataset(Dataset):
    def __init__(self, X, C, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.C = torch.tensor(C, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)
    
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx], self.C[idx]

# Create datasets
train_dataset = CEMDataset(X_train, C_train, y_train)
val_dataset = CEMDataset(X_val, C_val, y_val)
test_dataset = CEMDataset(X_test, C_test, y_test)

# Create WeightedRandomSampler for batch-level oversampling (if enabled)
if HYPERPARAMS['use_weighted_sampler']:
    # Compute class sample counts
    class_sample_counts = np.bincount(y_train.astype(int))  # [n_negative, n_positive]
    weights = 1. / class_sample_counts
    sample_weights = weights[y_train.astype(int)]
    
    # Create sampler
    train_sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True  # Allow positive samples to appear multiple times
    )
    
    print(f"‚úì WeightedRandomSampler created:")
    print(f"  Negative weight: {weights[0]:.4f}")
    print(f"  Positive weight: {weights[1]:.4f}")
    print(f"  Expected positive ratio per batch: ~{weights[1]/(weights[0]+weights[1]):.1%}")
    
    # Create train loader with sampler (shuffle=False when using sampler)
    train_loader = DataLoader(train_dataset, batch_size=HYPERPARAMS['batch_size_train'], sampler=train_sampler)
else:
    # Standard train loader with shuffle
    train_loader = DataLoader(train_dataset, batch_size=HYPERPARAMS['batch_size_train'], shuffle=True)
    print("‚úì Using standard DataLoader (shuffle=True)")

# Validation and test loaders (no sampling)
val_loader = DataLoader(val_dataset, batch_size=HYPERPARAMS['batch_size_eval'], shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=HYPERPARAMS['batch_size_eval'], shuffle=False)

print("‚úì All DataLoaders created")

## Section 3: Custom CEM Model Definition

In [None]:
# LDAM Loss (for class imbalance)
class LDAMLoss(nn.Module):
    """
    Label-Distribution-Aware Margin (LDAM) Loss for long-tailed recognition.
    
    Creates class-dependent margins to make decision boundaries harder for minority classes.
    """
    def __init__(self, n_positive, n_negative, max_margin=0.5, scale=30):
        super(LDAMLoss, self).__init__()
        self.max_margin = max_margin
        self.scale = scale
        
        # Compute class frequencies
        total = n_positive + n_negative
        freq_pos = n_positive / total
        freq_neg = n_negative / total
        
        # Compute margins: minority class gets larger margin
        margin_pos = max_margin * (freq_pos ** (-0.25))
        margin_neg = max_margin * (freq_neg ** (-0.25))
        
        self.register_buffer('margin_pos', torch.tensor(margin_pos))
        self.register_buffer('margin_neg', torch.tensor(margin_neg))
    
    def forward(self, logits, targets):
        logits = logits.view(-1)
        targets = targets.view(-1).float()
        
        # Apply class-dependent margins
        margin = targets * self.margin_pos + (1 - targets) * (-self.margin_neg)
        adjusted_logits = (logits - margin) * self.scale
        
        return F.binary_cross_entropy_with_logits(adjusted_logits, targets, reduction='mean')


# Custom CEM Implementation
class CustomCEM(pl.LightningModule):
    """
    Custom Concept Embedding Model (CEM) implementation.
    
    Architecture:
      X ‚Üí concept_extractor ‚Üí context_layers ‚Üí prob_generator ‚Üí dual_embeddings ‚Üí task_classifier ‚Üí y
    """
    def __init__(
        self,
        n_concepts=21,
        emb_size=128,
        input_dim=384,
        shared_prob_gen=True,
        intervention_prob=0.25,
        concept_loss_weight=1.0,
        learning_rate=0.01,
        weight_decay=4e-05,
        use_ldam_loss=True,
        n_positive=83,
        n_negative=403,
        ldam_max_margin=0.5,
        ldam_scale=30,
    ):
        super().__init__()
        self.save_hyperparameters()
        
        self.n_concepts = n_concepts
        self.emb_size = emb_size
        self.intervention_prob = intervention_prob
        self.concept_loss_weight = concept_loss_weight
        
        # Stage 1: Concept Extractor (X ‚Üí Pre-Concept Features)
        self.concept_extractor = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 256)  # Pre-concept features
        )
        
        # Stage 2: Context Generators (Features ‚Üí Dual Embeddings)
        # Each concept gets its own context generator
        self.context_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(256, emb_size * 2),  # Dual embeddings (true/false)
                nn.LeakyReLU()
            ) for _ in range(n_concepts)
        ])
        
        # Stage 3: Probability Generator (Contexts ‚Üí Concept Probabilities)
        if shared_prob_gen:
            # Single shared generator for all concepts
            self.prob_generator = nn.Linear(emb_size * 2, 1)
        else:
            # Per-concept probability generators
            self.prob_generators = nn.ModuleList([
                nn.Linear(emb_size * 2, 1) for _ in range(n_concepts)
            ])
        
        self.shared_prob_gen = shared_prob_gen
        
        # Stage 4: Task Classifier (Concept Embeddings ‚Üí Task Output)
        self.task_classifier = nn.Sequential(
            nn.Linear(n_concepts * emb_size, 128),
            nn.LeakyReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 1)  # Binary classification
        )
        
        # Loss functions
        self.concept_loss_fn = nn.BCEWithLogitsLoss()
        if use_ldam_loss:
            self.task_loss_fn = LDAMLoss(n_positive, n_negative, ldam_max_margin, ldam_scale)
        else:
            self.task_loss_fn = nn.BCEWithLogitsLoss()
    
    def forward(self, x, c_true=None, train=False):
        # Step 1: Extract pre-concept features
        pre_features = self.concept_extractor(x)  # (B, 256)
        
        # Step 2: Generate contexts and probabilities per concept
        contexts = []
        c_logits_list = []
        
        for i, context_layer in enumerate(self.context_layers):
            context = context_layer(pre_features)  # (B, emb_size*2)
            
            # Get probability logit
            if self.shared_prob_gen:
                logit = self.prob_generator(context)  # (B, 1)
            else:
                logit = self.prob_generators[i](context)
            
            contexts.append(context)
            c_logits_list.append(logit)
        
        c_logits = torch.cat(c_logits_list, dim=1)  # (B, 21)
        c_probs = torch.sigmoid(c_logits)           # (B, 21)
        
        # Step 3: Apply intervention (optional during training)
        if train and self.intervention_prob > 0 and c_true is not None:
            intervention_mask = torch.bernoulli(
                torch.ones_like(c_probs) * self.intervention_prob
            )
            c_probs = c_probs * (1 - intervention_mask) + c_true * intervention_mask
        
        # Step 4: Mix dual embeddings based on probabilities
        concept_embeddings = []
        for i, context in enumerate(contexts):
            # Split into true/false embeddings
            emb_true = context[:, :self.emb_size]       # First half
            emb_false = context[:, self.emb_size:]      # Second half
            
            # Weight by probability
            prob = c_probs[:, i:i+1]  # (B, 1)
            mixed_emb = emb_true * prob + emb_false * (1 - prob)
            concept_embeddings.append(mixed_emb)
        
        # Concatenate all concept embeddings
        c_embeddings = torch.cat(concept_embeddings, dim=1)  # (B, 21*emb_size)
        
        # Step 5: Task prediction
        y_logits = self.task_classifier(c_embeddings)  # (B, 1)
        
        return c_logits, y_logits
    
    def training_step(self, batch, batch_idx):
        x, y, c_true = batch
        c_logits, y_logits = self.forward(x, c_true=c_true, train=True)
        
        # Task loss (LDAM)
        task_loss = self.task_loss_fn(y_logits.squeeze(), y.squeeze())
        
        # Concept loss (BCE)
        concept_loss = self.concept_loss_fn(c_logits, c_true)
        
        # Combined loss
        loss = task_loss + self.concept_loss_weight * concept_loss
        
        # Logging
        self.log('train_loss', loss, on_epoch=True, prog_bar=True)
        self.log('train_task_loss', task_loss, on_epoch=True)
        self.log('train_concept_loss', concept_loss, on_epoch=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y, c_true = batch
        c_logits, y_logits = self.forward(x, c_true=c_true, train=False)
        
        # Task loss
        task_loss = self.task_loss_fn(y_logits.squeeze(), y.squeeze())
        
        # Concept loss
        concept_loss = self.concept_loss_fn(c_logits, c_true)
        
        # Combined loss
        loss = task_loss + self.concept_loss_weight * concept_loss
        
        # Logging
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_task_loss', task_loss, on_epoch=True)
        self.log('val_concept_loss', concept_loss, on_epoch=True)
        
        return loss
    
    def configure_optimizers(self):
        return torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay
        )

print("‚úì Custom CEM model defined")

## Section 4: Model Initialization

In [None]:
# Initialize Custom CEM model
custom_cem = CustomCEM(
    n_concepts=HYPERPARAMS['n_concepts'],
    emb_size=HYPERPARAMS['emb_size'],
    input_dim=HYPERPARAMS['embedding_dim'],
    shared_prob_gen=HYPERPARAMS['shared_prob_gen'],
    intervention_prob=HYPERPARAMS['intervention_prob'],
    concept_loss_weight=HYPERPARAMS['concept_loss_weight'],
    learning_rate=HYPERPARAMS['learning_rate'],
    weight_decay=HYPERPARAMS['weight_decay'],
    use_ldam_loss=HYPERPARAMS['use_ldam_loss'],
    n_positive=HYPERPARAMS['n_positive'],
    n_negative=HYPERPARAMS['n_negative'],
    ldam_max_margin=HYPERPARAMS['ldam_max_margin'],
    ldam_scale=HYPERPARAMS['ldam_scale']
)

print("‚úì Custom CEM model initialized")
print(f"  Using LDAM Loss (margin={HYPERPARAMS['ldam_max_margin']}, scale={HYPERPARAMS['ldam_scale']})")
print(f"  Concept embedding size: {HYPERPARAMS['emb_size']}")
print(f"  Intervention probability: {HYPERPARAMS['intervention_prob']}")
print(f"  Shared probability generator: {HYPERPARAMS['shared_prob_gen']}")
print(f"  Class counts: {HYPERPARAMS['n_positive']} positive, {HYPERPARAMS['n_negative']} negative")

## Section 5: Training

In [None]:
# Setup trainer
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath=os.path.join(OUTPUT_DIR, "models"),
    filename="custom-cem-{epoch:02d}-{val_loss:.2f}",
    save_top_k=1,
    mode="min"
)

trainer = pl.Trainer(
    max_epochs=HYPERPARAMS['max_epochs'],
    accelerator=DEVICE,
    devices=1,
    logger=CSVLogger(save_dir=os.path.join(OUTPUT_DIR, "logs"), name="custom_cem_pipeline"),
    log_every_n_steps=10,
    callbacks=[checkpoint_callback],
    enable_progress_bar=True
)

print("‚úì Trainer configured")

In [None]:
# Train model
print("\nStarting training...\n")
trainer.fit(custom_cem, train_loader, val_loader)
print("\n‚úì Training complete!")

## Section 6: Test Evaluation

In [None]:
# Run inference on test set
print("Running inference on test set...")

custom_cem.eval()
device_obj = torch.device(DEVICE)
custom_cem = custom_cem.to(device_obj)

y_true_list = []
y_prob_list = []
concept_probs_list = []

with torch.no_grad():
    for x_batch, y_batch, c_batch in test_loader:
        x_batch = x_batch.to(device_obj)
        
        c_logits, y_logits = custom_cem(x_batch)
        c_probs = torch.sigmoid(c_logits).cpu().numpy()
        y_probs = torch.sigmoid(y_logits).cpu().squeeze().numpy()
        
        y_true_list.extend(y_batch.numpy().astype(int).tolist())
        y_prob_list.extend(y_probs.tolist() if isinstance(y_probs, np.ndarray) else [y_probs])
        concept_probs_list.extend(c_probs.tolist())

y_true = np.array(y_true_list)
y_prob = np.array(y_prob_list)
concept_probs = np.array(concept_probs_list)

print("‚úì Inference complete")

## Section 7: Results Display

In [None]:
# MANUAL THRESHOLD TEST - Using threshold=0.1 to achieve 75%+ recall
best_threshold = 0.1  # ‚Üê MANUALLY SET TO 0.1 (will catch 20/26 = 77% recall)

print(f"\nüîç TESTING THRESHOLD: {best_threshold:.2f}")
print(f"   Expected: Catch 20/26 depression cases (77% recall)")

# Apply threshold
y_pred = (y_prob >= best_threshold).astype(int)

print(f"\n‚úì Predictions created")
print(f"  Predicted positive: {np.sum(y_pred)} / {len(y_pred)}")


In [None]:
# Compute all metrics
cm = confusion_matrix(y_true, y_pred)
tn, fp, fn, tp = cm.ravel()

acc = accuracy_score(y_true, y_pred)
balanced_acc = balanced_accuracy_score(y_true, y_pred)
roc_auc = roc_auc_score(y_true, y_prob)
mcc = matthews_corrcoef(y_true, y_pred)
f1_binary = f1_score(y_true, y_pred, pos_label=1)
f1_macro = f1_score(y_true, y_pred, average='macro')
precision_binary = precision_score(y_true, y_pred, pos_label=1)
recall_binary = recall_score(y_true, y_pred, pos_label=1)

# Print results
print("\n" + "="*70)
print("                    TEST SET EVALUATION")
print("="*70)
print(f"\nDecision Threshold: {best_threshold:.2f}")

# Enhanced Confusion Matrix Display
print(f"\n{'CONFUSION MATRIX':^50}")
print("="*50)
print(f"{'':>20} ‚îÇ {'Predicted Negative':^12} ‚îÇ {'Predicted Positive':^12}")
print("‚îÄ"*50)
print(f"{'Actual Negative':>20} ‚îÇ {f'TN = {tn}':^12} ‚îÇ {f'FP = {fp}':^12}")
print(f"{'Actual Positive':>20} ‚îÇ {f'FN = {fn}':^12} ‚îÇ {f'TP = {tp}':^12}")
print("="*50)
print(f"\n  True Positives:  {tp:>3}/{int(np.sum(y_true)):<3} ({100*tp/np.sum(y_true):>5.1f}% of depression cases caught)")
print(f"  False Negatives: {fn:>3}/{int(np.sum(y_true)):<3} ({100*fn/np.sum(y_true):>5.1f}% of depression cases MISSED)")
print(f"  True Negatives:  {tn:>3}/{int(len(y_true)-np.sum(y_true)):<3} ({100*tn/(len(y_true)-np.sum(y_true)):>5.1f}% of healthy correctly identified)")
print(f"  False Positives: {fp:>3}/{int(len(y_true)-np.sum(y_true)):<3} ({100*fp/(len(y_true)-np.sum(y_true)):>5.1f}% false alarms)")

print(f"\nPerformance Metrics:")
print(f"  Accuracy:                  {acc:.4f}")
print(f"  Balanced Accuracy:         {balanced_acc:.4f}")
print(f"  ROC-AUC:                   {roc_auc:.4f}")
print(f"  Matthews Correlation:      {mcc:.4f}")
print(f"\n  F1 Score (Binary):         {f1_binary:.4f}")
print(f"  F1 Score (Macro):          {f1_macro:.4f}")
print(f"  Precision (Binary):        {precision_binary:.4f}")
print(f"  Recall (Binary):           {recall_binary:.4f}")

print("\n" + classification_report(y_true, y_pred, target_names=['Negative', 'Positive']))
print("="*70)

In [None]:
# Save results
metrics_dict = {
    "model_type": "custom_cem",
    "threshold": float(best_threshold),
    "n_samples": int(len(y_true)),
    "n_positive": int(np.sum(y_true)),
    "n_negative": int(len(y_true) - np.sum(y_true)),
    "accuracy": float(acc),
    "balanced_accuracy": float(balanced_acc),
    "roc_auc": float(roc_auc),
    "mcc": float(mcc),
    "f1_binary": float(f1_binary),
    "f1_macro": float(f1_macro),
    "precision_binary": float(precision_binary),
    "recall_binary": float(recall_binary),
    "confusion_matrix": {"tn": int(tn), "fp": int(fp), "fn": int(fn), "tp": int(tp)}
}

os.makedirs(os.path.join(OUTPUT_DIR, "results"), exist_ok=True)
with open(os.path.join(OUTPUT_DIR, "results/test_metrics.json"), 'w') as f:
    json.dump(metrics_dict, f, indent=4)

# Save predictions
predictions_df = pd.DataFrame({
    'subject_id': test_subject_ids,
    'y_true': y_true,
    'y_pred': y_pred,
    'y_prob': y_prob
})

for i, concept_name in enumerate(CONCEPT_NAMES):
    predictions_df[concept_name] = concept_probs[:, i]

predictions_df.to_csv(os.path.join(OUTPUT_DIR, "results/test_predictions.csv"), index=False)

print(f"‚úì Results saved to {OUTPUT_DIR}/results/")

In [None]:
print("\n" + "="*70)
print("              CUSTOM CEM TRAINING COMPLETE")
print("="*70)
print(f"\nGenerated files:")
print(f"  Model checkpoint: {OUTPUT_DIR}/models/")
print(f"  Metrics JSON:     {OUTPUT_DIR}/results/test_metrics.json")
print(f"  Predictions CSV:  {OUTPUT_DIR}/results/test_predictions.csv")
print("="*70)

In [None]:
# Display best hyperparameters
print("="*70)
print("                 BEST HYPERPARAMETERS")
print("="*70)

best_params = study.best_params
best_threshold = study.best_trial.user_attrs['best_threshold']

print("\nOptimal hyperparameters:")
for key, value in sorted(best_params.items()):
    if isinstance(value, float):
        print(f"  {key:<25} {value:.6f}")
    else:
        print(f"  {key:<25} {value}")

print(f"\n  {'best_threshold':<25} {best_threshold:.2f} (optimized on validation)")
print(f"\nValidation Performance:")
print(f"  MCC:                      {study.best_value:.4f}")
print("="*70)

In [None]:
# Save best hyperparameters
best_config = {
    **study.best_params,
    'best_threshold': float(best_threshold),
    'validation_mcc': float(study.best_value),
    'n_trials': len(study.trials),
    'n_pruned': len([t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]),
}

os.makedirs(OUTPUT_DIR, exist_ok=True)
with open(os.path.join(OUTPUT_DIR, 'best_hyperparameters.json'), 'w') as f:
    json.dump(best_config, f, indent=4)

print(f"‚úì Saved best hyperparameters to {OUTPUT_DIR}/best_hyperparameters.json")

In [None]:
print("\n" + "="*70)
print("           TRAINING FINAL MODEL WITH BEST HYPERPARAMETERS")
print("="*70)

# Load test data
print("\nLoading test data...")
test_data = np.load(os.path.join(DATASET_DIR, "test_data.npz"))
X_test = test_data['X']
C_test = test_data['C']
y_test = test_data['y']
test_subject_ids = test_data['subject_ids']

print(f"‚úì Loaded test data: {X_test.shape}")

In [None]:
# Combine train + val for final training
print("\nCombining train + validation sets...")

X_train_full = np.concatenate([X_train, X_val], axis=0)
C_train_full = np.concatenate([C_train, C_val], axis=0)
y_train_full = np.concatenate([y_train, y_val], axis=0)

print(f"‚úì Combined: {X_train_full.shape}")

train_full_dataset = CEMDataset(X_train_full, C_train_full, y_train_full)
test_dataset = CEMDataset(X_test, C_test, y_test)

In [None]:
# Load best hyperparameters
print("\nLoading best hyperparameters...")

with open(os.path.join(OUTPUT_DIR, 'best_hyperparameters.json'), 'r') as f:
    best_config = json.load(f)

print("‚úì Best hyperparameters loaded:")
for key, value in sorted(best_config.items()):
    if key not in ['n_trials', 'n_pruned', 'validation_mcc']:
        print(f"  {key:<25} {value}")

In [None]:
# Create DataLoader with best configuration
print("\nCreating final DataLoader...")

use_sampler = best_config['use_weighted_sampler']
final_seed = BASE_SEED + 9999

if use_sampler:
    class_sample_counts = np.bincount(y_train_full.astype(int))
    weights = 1.0 / class_sample_counts
    sample_weights = weights[y_train_full.astype(int)]

    train_full_sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True,
        generator=torch.Generator().manual_seed(final_seed)
    )

    train_full_loader = DataLoader(
        train_full_dataset,
        batch_size=FIXED_PARAMS['batch_size_train'],
        sampler=train_full_sampler,
        worker_init_fn=lambda worker_id: np.random.seed(final_seed + worker_id)
    )
    print("‚úì Using WeightedRandomSampler")
else:
    train_full_loader = DataLoader(
        train_full_dataset,
        batch_size=FIXED_PARAMS['batch_size_train'],
        shuffle=True,
        generator=torch.Generator().manual_seed(final_seed),
        worker_init_fn=lambda worker_id: np.random.seed(final_seed + worker_id)
    )
    print("‚úì Using standard shuffle")

test_loader = DataLoader(test_dataset, batch_size=FIXED_PARAMS['batch_size_eval'], shuffle=False)

In [None]:
# Seed for final model
np.random.seed(final_seed)
torch.manual_seed(final_seed)
pl.seed_everything(final_seed, workers=True)

# Create final model
print("\nInitializing final model...")

final_model = CustomCEM(
    n_concepts=FIXED_PARAMS['n_concepts'],
    emb_size=best_config['emb_size'],
    input_dim=FIXED_PARAMS['embedding_dim'],
    shared_prob_gen=FIXED_PARAMS['shared_prob_gen'],
    intervention_prob=best_config['intervention_prob'],
    concept_loss_weight=best_config['concept_loss_weight'],
    learning_rate=best_config['learning_rate'],
    weight_decay=best_config['weight_decay'],
    use_ldam_loss=best_config['use_ldam_loss'],
    n_positive=n_positive,
    n_negative=n_negative,
    ldam_max_margin=best_config['ldam_max_margin'],
    ldam_scale=best_config['ldam_scale'],
)

print("‚úì Final model initialized")

In [None]:
# Setup trainer
print("\nSetting up trainer...")

final_checkpoint = ModelCheckpoint(
    monitor="train_loss",
    dirpath=os.path.join(OUTPUT_DIR, "final_model"),
    filename="final-custom-cem-{epoch:02d}-{train_loss:.2f}",
    save_top_k=1,
    mode="min"
)

final_trainer = pl.Trainer(
    max_epochs=FIXED_PARAMS['max_epochs'],
    accelerator=DEVICE,
    devices=1,
    callbacks=[final_checkpoint],
    enable_progress_bar=True,
    logger=CSVLogger(save_dir=os.path.join(OUTPUT_DIR, "logs"), name="final_model"),
)

print("\nStarting final model training...\n")
final_trainer.fit(final_model, train_full_loader)
print("\n‚úì Training complete!")

In [None]:
# Test set inference
print("\n" + "="*70)
print("                  FINAL MODEL - TEST SET EVALUATION")
print("="*70)

print("\nRunning inference...")

final_model.eval()
device_obj = torch.device(DEVICE)
final_model = final_model.to(device_obj)

y_true_test = []
y_prob_test = []
concept_probs_test = []

with torch.no_grad():
    for x_batch, y_batch, c_batch in test_loader:
        x_batch = x_batch.to(device_obj)
        c_logits, y_logits = final_model(x_batch)

        c_probs = torch.sigmoid(c_logits).cpu().numpy()
        y_probs = torch.sigmoid(y_logits).cpu().squeeze().numpy()
        y_true_batch = y_batch.cpu().numpy().astype(int)

        if y_probs.ndim == 0:
            y_prob_test.append(float(y_probs))
            y_true_test.append(int(y_true_batch))
            concept_probs_test.append(c_probs.squeeze())
        else:
            y_prob_test.extend(y_probs.tolist())
            y_true_test.extend(y_true_batch.tolist())
            concept_probs_test.extend(c_probs.tolist())

y_true_test = np.array(y_true_test)
y_prob_test = np.array(y_prob_test)
concept_probs_test = np.array(concept_probs_test)

print(f"‚úì Predictions for {len(y_true_test)} samples")

In [None]:
# Apply best threshold and compute metrics
best_threshold = best_config['best_threshold']
print(f"\nApplying threshold: {best_threshold:.2f}")

y_pred_test = (y_prob_test >= best_threshold).astype(int)

cm = confusion_matrix(y_true_test, y_pred_test)
tn, fp, fn, tp = cm.ravel()

test_accuracy = accuracy_score(y_true_test, y_pred_test)
test_balanced_acc = balanced_accuracy_score(y_true_test, y_pred_test)
test_roc_auc = roc_auc_score(y_true_test, y_prob_test)
test_mcc = matthews_corrcoef(y_true_test, y_pred_test)
test_f1 = f1_score(y_true_test, y_pred_test)
test_precision = precision_score(y_true_test, y_pred_test) if (tp + fp) > 0 else 0.0
test_recall = recall_score(y_true_test, y_pred_test) if (tp + fn) > 0 else 0.0

# Display results
print("\n" + "="*70)
print("                    TEST SET RESULTS")
print("="*70)
print(f"\n{'CONFUSION MATRIX':^50}")
print("="*50)
print(f"{'':>20} ‚îÇ {'Predicted Neg':^15} ‚îÇ {'Predicted Pos':^15}")
print("‚îÄ"*50)
print(f"{'Actual Negative':>20} ‚îÇ {f'TN = {tn}':^15} ‚îÇ {f'FP = {fp}':^15}")
print(f"{'Actual Positive':>20} ‚îÇ {f'FN = {fn}':^15} ‚îÇ {f'TP = {tp}':^15}")
print("="*50)

n_pos = int(np.sum(y_true_test))
n_neg = int(len(y_true_test) - n_pos)

print(f"\n  TP: {tp}/{n_pos} ({100*tp/n_pos if n_pos > 0 else 0:.1f}% caught)")
print(f"  FN: {fn}/{n_pos} ({100*fn/n_pos if n_pos > 0 else 0:.1f}% missed)")

print(f"\nMetrics:")
print(f"  MCC:        {test_mcc:.4f}")
print(f"  F1:         {test_f1:.4f}")
print(f"  Recall:     {test_recall:.4f}")
print(f"  Precision:  {test_precision:.4f}")
print(f"  ROC-AUC:    {test_roc_auc:.4f}")

print("\n" + classification_report(y_true_test, y_pred_test, target_names=['Negative', 'Positive']))
print("="*70)

In [None]:
# Save results
print("\nSaving results...")

final_results = {
    'optimization_summary': {
        'n_trials': best_config['n_trials'],
        'best_validation_mcc': best_config['validation_mcc'],
    },
    'best_hyperparameters': {k: v for k, v in best_config.items()
                             if k not in ['n_trials', 'n_pruned', 'validation_mcc']},
    'test_metrics': {
        'threshold': float(best_threshold),
        'mcc': float(test_mcc),
        'f1': float(test_f1),
        'recall': float(test_recall),
        'precision': float(test_precision),
        'roc_auc': float(test_roc_auc),
        'accuracy': float(test_accuracy),
        'confusion_matrix': {'tn': int(tn), 'fp': int(fp), 'fn': int(fn), 'tp': int(tp)}
    }
}

with open(os.path.join(OUTPUT_DIR, 'final_test_results.json'), 'w') as f:
    json.dump(final_results, f, indent=4)

# Save predictions
predictions_df = pd.DataFrame({
    'subject_id': test_subject_ids,
    'y_true': y_true_test,
    'y_pred': y_pred_test,
    'y_prob': y_prob_test
})

for i, concept_name in enumerate(CONCEPT_NAMES):
    predictions_df[concept_name] = concept_probs_test[:, i]

predictions_df.to_csv(os.path.join(OUTPUT_DIR, 'final_test_predictions.csv'), index=False)

print(f"‚úì Saved to {OUTPUT_DIR}/")

In [None]:
print("\n" + "="*70)
print("              OPTIMIZATION COMPLETE")
print("="*70)

print("\nüìä SUMMARY:")
print(f"  Trials:              {best_config['n_trials']}")
print(f"  Best val MCC:        {best_config['validation_mcc']:.4f}")

print("\nüèÜ BEST HYPERPARAMETERS:")
print(f"  Embedding size:      {best_config['emb_size']}")
print(f"  Learning rate:       {best_config['learning_rate']:.6f}")
print(f"  Use LDAM:            {best_config['use_ldam_loss']}")
print(f"  Use sampler:         {best_config['use_weighted_sampler']}")

print("\nüéØ TEST PERFORMANCE:")
print(f"  MCC:                 {test_mcc:.4f}")
print(f"  Recall:              {test_recall:.4f} ({tp}/{n_pos} caught)")
print(f"  Precision:           {test_precision:.4f}")

print("\nüìÅ FILES:")
print(f"  Best params:         {OUTPUT_DIR}/best_hyperparameters.json")
print(f"  Test results:        {OUTPUT_DIR}/final_test_results.json")
print(f"  Predictions:         {OUTPUT_DIR}/final_test_predictions.csv")
print(f"  Model:               {OUTPUT_DIR}/final_model/")

print("\n‚úÖ Final model ready for deployment!")
print("="*70)

In [None]:
# Visualization: Optimization History
fig = optuna.visualization.plot_optimization_history(study)
fig.write_html(os.path.join(OUTPUT_DIR, 'optimization_history.html'))
fig.show()

print(f"‚úì Saved optimization history plot to {OUTPUT_DIR}/optimization_history.html")

In [None]:
# Visualization: Parameter Importances
fig = optuna.visualization.plot_param_importances(study)
fig.write_html(os.path.join(OUTPUT_DIR, 'param_importances.html'))
fig.show()

print(f"‚úì Saved parameter importance plot to {OUTPUT_DIR}/param_importances.html")

In [None]:
# Visualization: Parallel Coordinate Plot
fig = optuna.visualization.plot_parallel_coordinate(study)
fig.write_html(os.path.join(OUTPUT_DIR, 'parallel_coordinate.html'))
fig.show()

print(f"‚úì Saved parallel coordinate plot to {OUTPUT_DIR}/parallel_coordinate.html")

In [None]:
# Save final results
final_results = {
    'best_hyperparameters': best_config,
    'test_metrics': {
        'threshold': float(best_threshold),
        'mcc': float(test_mcc),
        'recall': float(test_recall),
        'precision': float(test_precision),
        'f1': float(test_f1),
        'roc_auc': float(test_roc_auc),
        'accuracy': float(test_accuracy),
        'confusion_matrix': {
            'tn': int(tn),
            'fp': int(fp),
            'fn': int(fn),
            'tp': int(tp)
        }
    }
}

with open(os.path.join(OUTPUT_DIR, 'final_test_results.json'), 'w') as f:
    json.dump(final_results, f, indent=4)

# Save predictions
predictions_df = pd.DataFrame({
    'subject_id': test_subject_ids,
    'y_true': y_true_test,
    'y_pred': y_pred_test,
    'y_prob': y_prob_test
})

for i, concept_name in enumerate(CONCEPT_NAMES):
    predictions_df[concept_name] = concept_probs_test[:, i]

predictions_df.to_csv(os.path.join(OUTPUT_DIR, 'test_predictions.csv'), index=False)

print(f"\n‚úì Saved final results to {OUTPUT_DIR}/final_test_results.json")
print(f"‚úì Saved predictions to {OUTPUT_DIR}/test_predictions.csv")

print("\n" + "="*70)
print("              OPTUNA OPTIMIZATION COMPLETE")
print("="*70)
print(f"\nGenerated files:")
print(f"  Best hyperparameters:      {OUTPUT_DIR}/best_hyperparameters.json")
print(f"  Final test results:        {OUTPUT_DIR}/final_test_results.json")
print(f"  Test predictions:          {OUTPUT_DIR}/test_predictions.csv")
print(f"  Best model checkpoint:     {OUTPUT_DIR}/models/")
print(f"  Optimization history:      {OUTPUT_DIR}/optimization_history.html")
print(f"  Parameter importances:     {OUTPUT_DIR}/param_importances.html")
print(f"  Parallel coordinate plot:  {OUTPUT_DIR}/parallel_coordinate.html")
print("="*70)