# CBM Hard Bottleneck - Two-Stage Training (Alternative Pipeline)

**Runtime:** ~20-30 minutes

This notebook implements a **true two-stage** Concept Bottleneck Model:
1. **Stage 1**: Train concept predictor (X → Concepts)
2. **Stage 2**: Freeze concept predictor, train task classifier (Concepts → Y)

**Key Features:**
- Uses alternative dataset with SUM-based concept scoring
- Validation set has TRUE concept labels (better supervision)
- Hard bottleneck with discrete concept predictions
- LDAM loss for imbalanced task classification
- Separate optimizers for each stage

**Prerequisites:** Run `0_prepare_alternative_attention_dataset.ipynb` first!

## Section 0: Configuration & Setup

In [1]:
# Imports
import os
import json
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

from sklearn.metrics import (
    confusion_matrix,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    matthews_corrcoef,
    roc_auc_score,
    balanced_accuracy_score,
    classification_report,
)

print("✓ All imports successful")

✓ All imports successful


In [2]:
# Set random seeds for reproducibility
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
pl.seed_everything(SEED)

print(f"✓ Random seed set to {SEED}")

Global seed set to 42


✓ Random seed set to 42


In [3]:
# Detect device (MPS/CUDA/CPU)
if torch.backends.mps.is_available():
    DEVICE = "mps"
    print("✓ Using MacBook GPU (MPS)")
elif torch.cuda.is_available():
    DEVICE = "cuda"
    print("✓ Using CUDA GPU")
else:
    DEVICE = "cpu"
    print("⚠ Using CPU (will be slow)")

✓ Using MacBook GPU (MPS)


In [4]:
# Define paths - USING ALTERNATIVE PIPELINE
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
DATA_PROCESSED = os.path.join(PROJECT_ROOT, "data/processed")
DATASET_DIR = os.path.join(DATA_PROCESSED, "alternative_attention_pipeline")  # CHANGED
OUTPUT_DIR = "outputs_cbm_hard_alt"  # CHANGED

print("✓ Paths configured")
print(f"  Project root: {PROJECT_ROOT}")
print(f"  Dataset dir: {DATASET_DIR}")
print(f"  Output dir: {OUTPUT_DIR}")
print("\n  NOTE: Using ALTERNATIVE dataset (SUM-based concept scoring)")
print("        Validation set has TRUE concept labels!")

✓ Paths configured
  Project root: /Users/gualtieromarencoturi/Desktop/thesis/Master-Thesis-CEM-Depression-etc-case-study
  Dataset dir: /Users/gualtieromarencoturi/Desktop/thesis/Master-Thesis-CEM-Depression-etc-case-study/data/processed/alternative_attention_pipeline
  Output dir: outputs_cbm_hard_alt

  NOTE: Using ALTERNATIVE dataset (SUM-based concept scoring)
        Validation set has TRUE concept labels!


In [5]:
# Define 21 BDI-II concept names
CONCEPT_NAMES = [
    "Sadness", "Pessimism", "Past failure", "Loss of pleasure",
    "Guilty feelings", "Punishment feelings", "Self-dislike", "Self-criticalness",
    "Suicidal thoughts or wishes", "Crying", "Agitation", "Loss of interest",
    "Indecisiveness", "Worthlessness", "Loss of energy", "Changes in sleeping pattern",
    "Irritability", "Changes in appetite", "Concentration difficulty",
    "Tiredness or fatigue", "Loss of interest in sex"
]
N_CONCEPTS = len(CONCEPT_NAMES)

print(f"✓ Defined {N_CONCEPTS} BDI-II concepts")

✓ Defined 21 BDI-II concepts


In [6]:
# Hyperparameters - TWO-STAGE CONFIGURATION
HYPERPARAMS = {
    # Model architecture
    "embedding_dim": 384,
    "n_concepts": 21,
    "n_tasks": 1,
    
    # STAGE 1: Concept predictor training
    "stage1_batch_size_train": 32,
    "stage1_batch_size_eval": 64,
    "stage1_max_epochs": 100,
    "stage1_learning_rate": 0.001,
    "stage1_weight_decay": 0.0001,
    
    # STAGE 2: Task classifier training
    "stage2_batch_size_train": 32,
    "stage2_batch_size_eval": 64,
    "stage2_max_epochs": 100,
    "stage2_learning_rate": 0.01,
    "stage2_weight_decay": 0.0001,
    
    # Hard bottleneck
    "concept_threshold": 0.5,
    
    # LDAM Loss for Stage 2 (task)
    "use_ldam_loss": True,
    "n_positive": None,  # Will be set after loading data
    "n_negative": None,  # Will be set after loading data
    "ldam_max_margin": 0.5,
    "ldam_scale": 40,
}

print("✓ Hyperparameters configured (TWO-STAGE)")
print(f"  Stage 1: Concept predictor - {HYPERPARAMS['stage1_max_epochs']} epochs")
print(f"  Stage 2: Task classifier - {HYPERPARAMS['stage2_max_epochs']} epochs")
print(f"  Hard bottleneck threshold: {HYPERPARAMS['concept_threshold']}")

✓ Hyperparameters configured (TWO-STAGE)
  Stage 1: Concept predictor - 100 epochs
  Stage 2: Task classifier - 100 epochs
  Hard bottleneck threshold: 0.5


## Section 1: Load Preprocessed Data

In [7]:
# Load training data
print("Loading preprocessed datasets...")

train_data = np.load(os.path.join(DATASET_DIR, "train_data.npz"))
X_train = train_data['X']
C_train = train_data['C']
y_train = train_data['y']
train_subject_ids = train_data['subject_ids']

print(f"✓ Loaded training data:")
print(f"  X_train: {X_train.shape}")
print(f"  C_train: {C_train.shape}")
print(f"  y_train: {y_train.shape}")
print(f"  Subject IDs: {len(train_subject_ids)}")

Loading preprocessed datasets...
✓ Loaded training data:
  X_train: (388, 384)
  C_train: (388, 21)
  y_train: (388,)
  Subject IDs: 388


In [8]:
# Load validation data
val_data = np.load(os.path.join(DATASET_DIR, "val_data.npz"))
X_val = val_data['X']
C_val = val_data['C']
y_val = val_data['y']
val_subject_ids = val_data['subject_ids']

print(f"✓ Loaded validation data:")
print(f"  X_val: {X_val.shape}")
print(f"  C_val: {C_val.shape} - has TRUE concept labels!")
print(f"  y_val: {y_val.shape}")
print(f"  Non-zero concept values: {np.count_nonzero(C_val)}")

✓ Loaded validation data:
  X_val: (98, 384)
  C_val: (98, 21) - has TRUE concept labels!
  y_val: (98,)
  Non-zero concept values: 126


In [9]:
# Load test data
test_data = np.load(os.path.join(DATASET_DIR, "test_data.npz"))
X_test = test_data['X']
C_test = test_data['C']
y_test = test_data['y']
test_subject_ids = test_data['subject_ids']

print(f"✓ Loaded test data:")
print(f"  X_test: {X_test.shape}")
print(f"  C_test: {C_test.shape}")
print(f"  y_test: {y_test.shape}")

✓ Loaded test data:
  X_test: (401, 384)
  C_test: (401, 21)
  y_test: (401,)


In [10]:
# Load class weights
with open(os.path.join(DATASET_DIR, "class_weights.json"), 'r') as f:
    class_info = json.load(f)

n_positive = class_info['n_positive']
n_negative = class_info['n_negative']
pos_weight = class_info['pos_weight']

# Update HYPERPARAMS with actual class counts for LDAM
HYPERPARAMS['n_positive'] = n_positive
HYPERPARAMS['n_negative'] = n_negative

print(f"✓ Loaded class weights:")
print(f"  Negative samples: {n_negative}")
print(f"  Positive samples: {n_positive}")
print(f"  Ratio: 1:{pos_weight:.2f}")
print(f"  pos_weight: {pos_weight:.4f}")

✓ Loaded class weights:
  Negative samples: 322
  Positive samples: 66
  Ratio: 1:4.88
  pos_weight: 4.8788


## Section 2: PyTorch Dataset & DataLoaders

In [11]:
class CBMDataset(Dataset):
    """PyTorch Dataset for CBM model."""
    
    def __init__(self, X, C, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.C = torch.tensor(C, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)
    
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx], self.C[idx]

print("✓ CBMDataset class defined")

✓ CBMDataset class defined


In [12]:
# Create datasets
train_dataset = CBMDataset(X_train, C_train, y_train)
val_dataset = CBMDataset(X_val, C_val, y_val)
test_dataset = CBMDataset(X_test, C_test, y_test)

print("✓ Datasets created")
print(f"  Train: {len(train_dataset)} samples")
print(f"  Val: {len(val_dataset)} samples")
print(f"  Test: {len(test_dataset)} samples")

✓ Datasets created
  Train: 388 samples
  Val: 98 samples
  Test: 401 samples


## Section 3: Model Definitions

In [13]:
# LDAM Loss (for class imbalance in Stage 2)
class LDAMLoss(nn.Module):
    """
    Label-Distribution-Aware Margin (LDAM) Loss for long-tailed recognition.
    Creates class-dependent margins to make decision boundaries harder for minority classes.
    """
    def __init__(self, n_positive, n_negative, max_margin=0.5, scale=30):
        super(LDAMLoss, self).__init__()
        self.max_margin = max_margin
        self.scale = scale
        
        # Compute class frequencies
        total = n_positive + n_negative
        freq_pos = n_positive / total
        freq_neg = n_negative / total
        
        # Compute margins: minority class gets larger margin
        margin_pos = max_margin * (freq_pos ** (-0.25))
        margin_neg = max_margin * (freq_neg ** (-0.25))
        
        self.register_buffer('margin_pos', torch.tensor(margin_pos))
        self.register_buffer('margin_neg', torch.tensor(margin_neg))
    
    def forward(self, logits, targets):
        logits = logits.view(-1)
        targets = targets.view(-1).float()
        
        # Apply class-dependent margins
        margin = targets * self.margin_pos + (1 - targets) * (-self.margin_neg)
        adjusted_logits = (logits - margin) * self.scale
        
        return F.binary_cross_entropy_with_logits(adjusted_logits, targets, reduction='mean')

print("✓ LDAM Loss defined")

✓ LDAM Loss defined


In [14]:
# STAGE 1: Concept Predictor
class ConceptPredictor(pl.LightningModule):
    """
    Stage 1: Concept predictor (X → C)
    
    This model learns to predict concept labels from input embeddings.
    """
    def __init__(
        self,
        input_dim,
        n_concepts,
        learning_rate=0.001,
        weight_decay=1e-4,
    ):
        super().__init__()
        self.save_hyperparameters()
        
        # Concept extractor network
        self.concept_extractor = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, n_concepts)
        )
        
        # Loss function for concepts
        self.concept_loss_fn = nn.BCEWithLogitsLoss()
    
    def forward(self, x):
        return self.concept_extractor(x)
    
    def training_step(self, batch, batch_idx):
        x, y, c = batch
        c_logits = self(x)
        loss = self.concept_loss_fn(c_logits, c)
        
        self.log('train_concept_loss', loss, on_epoch=True, prog_bar=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y, c = batch
        c_logits = self(x)
        loss = self.concept_loss_fn(c_logits, c)
        
        self.log('val_concept_loss', loss, on_epoch=True, prog_bar=True)
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay
        )
        
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.5,
            patience=10,
            min_lr=1e-6,
            verbose=True
        )
        
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_concept_loss",
                "interval": "epoch",
                "frequency": 1,
            },
        }

print("✓ Stage 1: ConceptPredictor defined")

✓ Stage 1: ConceptPredictor defined


In [15]:
# STAGE 2: Task Classifier
class TaskClassifier(pl.LightningModule):
    """
    Stage 2: Task classifier (C → Y)
    
    This model takes predicted concepts and classifies the task.
    The concept predictor is frozen during this stage.
    """
    def __init__(
        self,
        concept_predictor,
        n_concepts,
        task_output_dim,
        concept_threshold=0.5,
        learning_rate=0.01,
        weight_decay=1e-4,
        use_ldam_loss=True,
        n_positive=83,
        n_negative=403,
        ldam_max_margin=0.5,
        ldam_scale=30,
    ):
        super().__init__()
        self.save_hyperparameters(ignore=['concept_predictor'])
        
        # Frozen concept predictor from Stage 1
        self.concept_predictor = concept_predictor
        # Freeze all concept predictor parameters
        for param in self.concept_predictor.parameters():
            param.requires_grad = False
        
        # Task classifier (operates on discrete concepts)
        self.task_classifier = nn.Sequential(
            nn.Linear(n_concepts, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, task_output_dim)
        )
        
        # Loss function for task
        if use_ldam_loss:
            self.task_loss_fn = LDAMLoss(n_positive, n_negative, ldam_max_margin, ldam_scale)
            print(f"  Using LDAM Loss (margin={ldam_max_margin}, scale={ldam_scale})")
        else:
            pos_weight_tensor = torch.tensor([n_negative / n_positive], dtype=torch.float32)
            self.task_loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor)
            print(f"  Using BCE Loss with pos_weight={pos_weight_tensor.item():.2f}")
    
    def forward(self, x):
        # Get concept predictions from frozen predictor
        with torch.no_grad():
            c_logits = self.concept_predictor(x)
            c_probs = torch.sigmoid(c_logits)
        
        # Hard bottleneck: binarize concepts
        c_discrete = (c_probs >= self.hparams.concept_threshold).float()
        
        # Task prediction from discrete concepts
        y_logits = self.task_classifier(c_discrete)
        
        return y_logits, c_discrete
    
    def training_step(self, batch, batch_idx):
        x, y, c = batch
        y_logits, _ = self(x)
        loss = self.task_loss_fn(y_logits.squeeze(), y.squeeze())
        
        self.log('train_task_loss', loss, on_epoch=True, prog_bar=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y, c = batch
        y_logits, _ = self(x)
        loss = self.task_loss_fn(y_logits.squeeze(), y.squeeze())
        
        self.log('val_task_loss', loss, on_epoch=True, prog_bar=True)
        return loss
    
    def configure_optimizers(self):
        # Only optimize task classifier parameters (concept predictor is frozen)
        optimizer = torch.optim.Adam(
            self.task_classifier.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay
        )
        
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.5,
            patience=10,
            min_lr=1e-6,
            verbose=True
        )
        
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_task_loss",
                "interval": "epoch",
                "frequency": 1,
            },
        }

print("✓ Stage 2: TaskClassifier defined")

✓ Stage 2: TaskClassifier defined


## Section 4: STAGE 1 - Train Concept Predictor

In [16]:
print("="*70)
print("          STAGE 1: TRAINING CONCEPT PREDICTOR (X → C)")
print("="*70)
print()
print("Training concept predictor to predict all 21 BDI-II concepts")
print("Validation set has TRUE concept labels (better supervision!)")
print()
print("="*70)

          STAGE 1: TRAINING CONCEPT PREDICTOR (X → C)

Training concept predictor to predict all 21 BDI-II concepts
Validation set has TRUE concept labels (better supervision!)



In [17]:
# Create Stage 1 DataLoaders
stage1_train_loader = DataLoader(
    train_dataset,
    batch_size=HYPERPARAMS['stage1_batch_size_train'],
    shuffle=True
)

stage1_val_loader = DataLoader(
    val_dataset,
    batch_size=HYPERPARAMS['stage1_batch_size_eval'],
    shuffle=False
)

print(f"✓ Stage 1 DataLoaders created")
print(f"  Train batches: {len(stage1_train_loader)}")
print(f"  Val batches: {len(stage1_val_loader)}")

✓ Stage 1 DataLoaders created
  Train batches: 13
  Val batches: 2


In [18]:
# Initialize Stage 1 model
concept_predictor = ConceptPredictor(
    input_dim=HYPERPARAMS['embedding_dim'],
    n_concepts=HYPERPARAMS['n_concepts'],
    learning_rate=HYPERPARAMS['stage1_learning_rate'],
    weight_decay=HYPERPARAMS['stage1_weight_decay'],
)

print("✓ Stage 1 model initialized")
print(f"  Learning rate: {HYPERPARAMS['stage1_learning_rate']}")
print(f"  Max epochs: {HYPERPARAMS['stage1_max_epochs']}")

✓ Stage 1 model initialized
  Learning rate: 0.001
  Max epochs: 100


In [19]:
# Stage 1 callbacks
stage1_early_stop = EarlyStopping(
    monitor='val_concept_loss',
    patience=20,
    mode='min',
    verbose=True
)

stage1_checkpoint = ModelCheckpoint(
    monitor="val_concept_loss",
    dirpath=os.path.join(OUTPUT_DIR, "models/stage1"),
    filename="concept-predictor-{epoch:02d}-{val_concept_loss:.2f}",
    save_top_k=1,
    mode="min"
)

# Stage 1 trainer
stage1_trainer = pl.Trainer(
    max_epochs=HYPERPARAMS['stage1_max_epochs'],
    accelerator=DEVICE,
    devices=1,
    logger=CSVLogger(save_dir=os.path.join(OUTPUT_DIR, "logs"), name="stage1_concepts"),
    log_every_n_steps=10,
    callbacks=[stage1_early_stop, stage1_checkpoint],
    enable_progress_bar=True
)

print("✓ Stage 1 trainer configured")

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


✓ Stage 1 trainer configured


In [20]:
# Train Stage 1
print("\nStarting Stage 1 training...\n")
stage1_trainer.fit(concept_predictor, stage1_train_loader, stage1_val_loader)
print("\n✓ Stage 1 training complete!")


  | Name              | Type              | Params
--------------------------------------------------------
0 | concept_extractor | Sequential        | 335 K 
1 | concept_loss_fn   | BCEWithLogitsLoss | 0     
--------------------------------------------------------
335 K     Trainable params
0         Non-trainable params
335 K     Total params
1.342     Total estimated model params size (MB)



Starting Stage 1 training...



Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_concept_loss improved. New best score: 0.509


Validation: 0it [00:00, ?it/s]

Metric val_concept_loss improved by 0.248 >= min_delta = 0.0. New best score: 0.261


Validation: 0it [00:00, ?it/s]

Metric val_concept_loss improved by 0.052 >= min_delta = 0.0. New best score: 0.209


Validation: 0it [00:00, ?it/s]

Metric val_concept_loss improved by 0.010 >= min_delta = 0.0. New best score: 0.199


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Monitored metric val_concept_loss did not improve in the last 20 records. Best score: 0.199. Signaling Trainer to stop.



✓ Stage 1 training complete!


## Section 5: STAGE 2 - Train Task Classifier

In [21]:
print("="*70)
print("          STAGE 2: TRAINING TASK CLASSIFIER (C → Y)")
print("="*70)
print()
print("Freezing concept predictor and training task classifier")
print(f"Using LDAM Loss for imbalanced classification")
print(f"Hard bottleneck threshold: {HYPERPARAMS['concept_threshold']}")
print()
print("="*70)

          STAGE 2: TRAINING TASK CLASSIFIER (C → Y)

Freezing concept predictor and training task classifier
Using LDAM Loss for imbalanced classification
Hard bottleneck threshold: 0.5



In [22]:
# Create Stage 2 DataLoaders
stage2_train_loader = DataLoader(
    train_dataset,
    batch_size=HYPERPARAMS['stage2_batch_size_train'],
    shuffle=True
)

stage2_val_loader = DataLoader(
    val_dataset,
    batch_size=HYPERPARAMS['stage2_batch_size_eval'],
    shuffle=False
)

print(f"✓ Stage 2 DataLoaders created")
print(f"  Train batches: {len(stage2_train_loader)}")
print(f"  Val batches: {len(stage2_val_loader)}")

✓ Stage 2 DataLoaders created
  Train batches: 13
  Val batches: 2


In [23]:
# Initialize Stage 2 model with frozen concept predictor
task_classifier = TaskClassifier(
    concept_predictor=concept_predictor,
    n_concepts=HYPERPARAMS['n_concepts'],
    task_output_dim=1,
    concept_threshold=HYPERPARAMS['concept_threshold'],
    learning_rate=HYPERPARAMS['stage2_learning_rate'],
    weight_decay=HYPERPARAMS['stage2_weight_decay'],
    use_ldam_loss=HYPERPARAMS['use_ldam_loss'],
    n_positive=HYPERPARAMS['n_positive'],
    n_negative=HYPERPARAMS['n_negative'],
    ldam_max_margin=HYPERPARAMS['ldam_max_margin'],
    ldam_scale=HYPERPARAMS['ldam_scale'],
)

# Initialize task classifier bias
pos_frac = HYPERPARAMS['n_positive'] / (HYPERPARAMS['n_positive'] + HYPERPARAMS['n_negative'])
bias_init = np.log(pos_frac / (1 - pos_frac))
if hasattr(task_classifier.task_classifier[-1], 'bias') and task_classifier.task_classifier[-1].bias is not None:
    task_classifier.task_classifier[-1].bias.data.fill_(bias_init)
    print(f"✓ Task classifier bias initialized to log-odds: {bias_init:.4f}")

print("\n✓ Stage 2 model initialized")
print(f"  Concept predictor: FROZEN (no gradient updates)")
print(f"  Learning rate: {HYPERPARAMS['stage2_learning_rate']}")
print(f"  Max epochs: {HYPERPARAMS['stage2_max_epochs']}")

  Using LDAM Loss (margin=0.5, scale=40)
✓ Task classifier bias initialized to log-odds: -1.5849

✓ Stage 2 model initialized
  Concept predictor: FROZEN (no gradient updates)
  Learning rate: 0.01
  Max epochs: 100


In [24]:
# Stage 2 callbacks
stage2_early_stop = EarlyStopping(
    monitor='val_task_loss',
    patience=20,
    mode='min',
    verbose=True
)

stage2_checkpoint = ModelCheckpoint(
    monitor="val_task_loss",
    dirpath=os.path.join(OUTPUT_DIR, "models/stage2"),
    filename="task-classifier-{epoch:02d}-{val_task_loss:.2f}",
    save_top_k=1,
    mode="min"
)

# Stage 2 trainer
stage2_trainer = pl.Trainer(
    max_epochs=HYPERPARAMS['stage2_max_epochs'],
    accelerator=DEVICE,
    devices=1,
    logger=CSVLogger(save_dir=os.path.join(OUTPUT_DIR, "logs"), name="stage2_task"),
    log_every_n_steps=10,
    callbacks=[stage2_early_stop, stage2_checkpoint],
    enable_progress_bar=True
)

print("✓ Stage 2 trainer configured")

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


✓ Stage 2 trainer configured


In [25]:
# Train Stage 2
print("\nStarting Stage 2 training...\n")
stage2_trainer.fit(task_classifier, stage2_train_loader, stage2_val_loader)
print("\n✓ Stage 2 training complete!")


  | Name              | Type             | Params
-------------------------------------------------------
0 | concept_predictor | ConceptPredictor | 335 K 
1 | task_classifier   | Sequential       | 11.5 K
2 | task_loss_fn      | LDAMLoss         | 0     
-------------------------------------------------------
11.5 K    Trainable params
335 K     Non-trainable params
346 K     Total params
1.388     Total estimated model params size (MB)



Starting Stage 2 training...



Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_task_loss improved. New best score: 11.822


Validation: 0it [00:00, ?it/s]

Metric val_task_loss improved by 2.426 >= min_delta = 0.0. New best score: 9.396


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_task_loss improved by 0.137 >= min_delta = 0.0. New best score: 9.259


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_task_loss improved by 0.141 >= min_delta = 0.0. New best score: 9.118


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_task_loss improved by 0.029 >= min_delta = 0.0. New best score: 9.089


Validation: 0it [00:00, ?it/s]

Metric val_task_loss improved by 0.222 >= min_delta = 0.0. New best score: 8.867


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_task_loss improved by 0.925 >= min_delta = 0.0. New best score: 7.942


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Monitored metric val_task_loss did not improve in the last 20 records. Best score: 7.942. Signaling Trainer to stop.



✓ Stage 2 training complete!


## Section 6: Test Evaluation

In [26]:
# Set model to evaluation mode
task_classifier.eval()

# Move model to device
device_obj = torch.device(DEVICE)
task_classifier = task_classifier.to(device_obj)

print("✓ Model set to evaluation mode")

✓ Model set to evaluation mode


In [27]:
# Create test DataLoader
test_loader = DataLoader(
    test_dataset,
    batch_size=HYPERPARAMS['stage2_batch_size_eval'],
    shuffle=False
)

print(f"✓ Test DataLoader created ({len(test_loader)} batches)")

✓ Test DataLoader created (7 batches)


In [28]:
# Run inference on validation set for threshold selection
print("\nSelecting decision threshold on validation set...")

y_val_true = []
y_val_prob = []

with torch.no_grad():
    for x_batch, y_batch, c_batch in stage2_val_loader:
        x_batch = x_batch.to(device_obj)
        
        y_logits, _ = task_classifier(x_batch)
        y_probs = torch.sigmoid(y_logits).cpu().squeeze().numpy()
        
        y_val_true.extend(y_batch.numpy().astype(int).tolist())
        y_val_prob.extend(y_probs.tolist() if isinstance(y_probs, np.ndarray) else [y_probs])

y_val_true = np.array(y_val_true)
y_val_prob = np.array(y_val_prob)

# Find best threshold
best_threshold = 0.5
best_mcc = -1.0

for threshold in np.linspace(0.1, 0.9, 50):
    y_pred_temp = (y_val_prob >= threshold).astype(int)
    mcc = matthews_corrcoef(y_val_true, y_pred_temp)
    if mcc > best_mcc:
        best_mcc = mcc
        best_threshold = threshold

print(f"✓ Selected validation threshold: {best_threshold:.2f}")
print(f"  Validation MCC: {best_mcc:.4f}")


Selecting decision threshold on validation set...
✓ Selected validation threshold: 0.77
  Validation MCC: 0.5075


In [29]:
# Run inference on test set
print("\nRunning inference on test set...")

y_true_list = []
y_prob_list = []
concept_probs_list = []

with torch.no_grad():
    for x_batch, y_batch, c_batch in test_loader:
        x_batch = x_batch.to(device_obj)
        
        # Get concept predictions
        c_logits = concept_predictor(x_batch)
        c_probs = torch.sigmoid(c_logits).cpu().numpy()
        
        # Get task predictions
        y_logits, _ = task_classifier(x_batch)
        y_probs = torch.sigmoid(y_logits).cpu().squeeze().numpy()
        
        y_true_list.extend(y_batch.numpy().astype(int).tolist())
        y_prob_list.extend(y_probs.tolist() if isinstance(y_probs, np.ndarray) else [y_probs])
        concept_probs_list.extend(c_probs.tolist())

# Convert to arrays
y_true = np.array(y_true_list)
y_prob = np.array(y_prob_list)
concept_probs = np.array(concept_probs_list)
y_pred = (y_prob >= best_threshold).astype(int)

print("✓ Inference complete")
print(f"  Probabilities shape: {y_prob.shape}")
print(f"  Concept probs shape: {concept_probs.shape}")


Running inference on test set...
✓ Inference complete
  Probabilities shape: (401,)
  Concept probs shape: (401, 21)


## Section 7: Results Display

In [30]:
# Compute all metrics
cm = confusion_matrix(y_true, y_pred)
tn, fp, fn, tp = cm.ravel()

acc = accuracy_score(y_true, y_pred)
balanced_acc = balanced_accuracy_score(y_true, y_pred)
roc_auc = roc_auc_score(y_true, y_prob)
mcc = matthews_corrcoef(y_true, y_pred)
f1_binary = f1_score(y_true, y_pred, pos_label=1)
f1_macro = f1_score(y_true, y_pred, average='macro')
precision_binary = precision_score(y_true, y_pred, pos_label=1)
recall_binary = recall_score(y_true, y_pred, pos_label=1)

# Print results
print("\n" + "="*70)
print("                    TEST SET EVALUATION")
print("="*70)
print(f"\nDecision Threshold: {best_threshold:.2f}")

# Enhanced Confusion Matrix Display
print(f"\n{'CONFUSION MATRIX':^50}")
print("="*50)
print(f"{'':>20} │ {'Predicted Negative':^12} │ {'Predicted Positive':^12}")
print("─"*50)
print(f"{'Actual Negative':>20} │ {f'TN = {tn}':^12} │ {f'FP = {fp}':^12}")
print(f"{'Actual Positive':>20} │ {f'FN = {fn}':^12} │ {f'TP = {tp}':^12}")
print("="*50)
print(f"\n  True Positives:  {tp:>3}/{int(np.sum(y_true)):<3} ({100*tp/np.sum(y_true):>5.1f}% of depression cases caught)")
print(f"  False Negatives: {fn:>3}/{int(np.sum(y_true)):<3} ({100*fn/np.sum(y_true):>5.1f}% of depression cases MISSED)")
print(f"  True Negatives:  {tn:>3}/{int(len(y_true)-np.sum(y_true)):<3} ({100*tn/(len(y_true)-np.sum(y_true)):>5.1f}% of healthy correctly identified)")
print(f"  False Positives: {fp:>3}/{int(len(y_true)-np.sum(y_true)):<3} ({100*fp/(len(y_true)-np.sum(y_true)):>5.1f}% false alarms)")

print(f"\nPerformance Metrics:")
print(f"  Accuracy:                  {acc:.4f}")
print(f"  Balanced Accuracy:         {balanced_acc:.4f}")
print(f"  ROC-AUC:                   {roc_auc:.4f}")
print(f"  Matthews Correlation:      {mcc:.4f}")
print(f"\n  F1 Score (Binary):         {f1_binary:.4f}")
print(f"  F1 Score (Macro):          {f1_macro:.4f}")
print(f"  Precision (Binary):        {precision_binary:.4f}")
print(f"  Recall (Binary):           {recall_binary:.4f}")

print("\n" + classification_report(y_true, y_pred, target_names=['Negative', 'Positive']))
print("="*70)


                    TEST SET EVALUATION

Decision Threshold: 0.77

                 CONFUSION MATRIX                 
                     │ Predicted Negative │ Predicted Positive
──────────────────────────────────────────────────
     Actual Negative │   TN = 337   │   FP = 12   
     Actual Positive │   FN = 46    │    TP = 6   

  True Positives:    6/52  ( 11.5% of depression cases caught)
  False Negatives:  46/52  ( 88.5% of depression cases MISSED)
  True Negatives:  337/349 ( 96.6% of healthy correctly identified)
  False Positives:  12/349 (  3.4% false alarms)

Performance Metrics:
  Accuracy:                  0.8554
  Balanced Accuracy:         0.5405
  ROC-AUC:                   0.5965
  Matthews Correlation:      0.1314

  F1 Score (Binary):         0.1714
  F1 Score (Macro):          0.5461
  Precision (Binary):        0.3333
  Recall (Binary):           0.1154

              precision    recall  f1-score   support

    Negative       0.88      0.97      0.92       349


In [31]:
# Save metrics to JSON
metrics_dict = {
    "model_type": "cbm_hard_two_stage",
    "dataset": "alternative_attention_pipeline",
    "threshold": float(best_threshold),
    "n_samples": int(len(y_true)),
    "n_positive": int(np.sum(y_true)),
    "n_negative": int(len(y_true) - np.sum(y_true)),
    "accuracy": float(acc),
    "balanced_accuracy": float(balanced_acc),
    "roc_auc": float(roc_auc),
    "mcc": float(mcc),
    "f1_binary": float(f1_binary),
    "f1_macro": float(f1_macro),
    "precision_binary": float(precision_binary),
    "recall_binary": float(recall_binary),
    "confusion_matrix": {"tn": int(tn), "fp": int(fp), "fn": int(fn), "tp": int(tp)}
}

os.makedirs(os.path.join(OUTPUT_DIR, "results"), exist_ok=True)
with open(os.path.join(OUTPUT_DIR, "results/test_metrics.json"), 'w') as f:
    json.dump(metrics_dict, f, indent=4)

print(f"✓ Metrics saved to {OUTPUT_DIR}/results/test_metrics.json")

✓ Metrics saved to outputs_cbm_hard_alt/results/test_metrics.json


In [32]:
# Save predictions
predictions_df = pd.DataFrame({
    'subject_id': test_subject_ids,
    'y_true': y_true,
    'y_pred': y_pred,
    'y_prob': y_prob
})

for i, concept_name in enumerate(CONCEPT_NAMES):
    predictions_df[concept_name] = concept_probs[:, i]

predictions_df.to_csv(os.path.join(OUTPUT_DIR, "results/test_predictions.csv"), index=False)

print(f"✓ Predictions saved to {OUTPUT_DIR}/results/test_predictions.csv")

✓ Predictions saved to outputs_cbm_hard_alt/results/test_predictions.csv


In [33]:
print("\n" + "="*70)
print("        CBM HARD TWO-STAGE TRAINING COMPLETE")
print("="*70)
print(f"\nGenerated files:")
print(f"  Stage 1 checkpoint: {OUTPUT_DIR}/models/stage1/")
print(f"  Stage 2 checkpoint: {OUTPUT_DIR}/models/stage2/")
print(f"  Metrics JSON:       {OUTPUT_DIR}/results/test_metrics.json")
print(f"  Predictions CSV:    {OUTPUT_DIR}/results/test_predictions.csv")
print("\nKey features:")
print("  - Two-stage training (concepts then task)")
print("  - Hard bottleneck with discrete concepts")
print("  - LDAM loss for imbalanced classification")
print("  - Alternative dataset (SUM-based scoring)")
print("  - Validation set with TRUE concept labels")
print("="*70)


        CBM HARD TWO-STAGE TRAINING COMPLETE

Generated files:
  Stage 1 checkpoint: outputs_cbm_hard_alt/models/stage1/
  Stage 2 checkpoint: outputs_cbm_hard_alt/models/stage2/
  Metrics JSON:       outputs_cbm_hard_alt/results/test_metrics.json
  Predictions CSV:    outputs_cbm_hard_alt/results/test_predictions.csv

Key features:
  - Two-stage training (concepts then task)
  - Hard bottleneck with discrete concepts
  - LDAM loss for imbalanced classification
  - Alternative dataset (SUM-based scoring)
  - Validation set with TRUE concept labels
