# Complete CBM Pipeline - Concept Bottleneck Model

This notebook implements a complete pipeline for:
1. Loading preprocessed data from disk
2. Training a simple Concept Bottleneck Model (CBM)
3. Evaluating with detailed metrics and concept probabilities

**Architecture:** X → Concept Logits → Concepts → Task Prediction

## Section 0: Configuration & Setup

In [1]:
# Imports
import os
import json
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

from sklearn.metrics import (
    confusion_matrix,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    matthews_corrcoef,
    roc_auc_score,
    balanced_accuracy_score,
    classification_report,
)

print("✓ All imports successful")

✓ All imports successful


In [2]:
# Set random seeds for reproducibility
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
pl.seed_everything(SEED)

print(f"✓ Random seed set to {SEED}")

Global seed set to 42


✓ Random seed set to 42


In [3]:
# Detect device (MPS/CUDA/CPU)
if torch.backends.mps.is_available():
    DEVICE = "mps"
    print("✓ Using MacBook GPU (MPS)")
elif torch.cuda.is_available():
    DEVICE = "cuda"
    print("✓ Using CUDA GPU")
else:
    DEVICE = "cpu"
    print("⚠ Using CPU (will be slow)")

✓ Using MacBook GPU (MPS)


In [4]:
# Define paths
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
DATA_PROCESSED = os.path.join(PROJECT_ROOT, "data/processed")
DATASET_DIR = os.path.join(DATA_PROCESSED, "whole_attention_pipeline")
OUTPUT_DIR = "outputs_cbm"

print("✓ Paths configured")
print(f"  Project root: {PROJECT_ROOT}")
print(f"  Dataset dir: {DATASET_DIR}")
print(f"  Output dir: {OUTPUT_DIR}")

✓ Paths configured
  Project root: /Users/gualtieromarencoturi/Desktop/thesis/Master-Thesis-CEM-Depression-etc-case-study
  Dataset dir: /Users/gualtieromarencoturi/Desktop/thesis/Master-Thesis-CEM-Depression-etc-case-study/data/processed/whole_attention_pipeline
  Output dir: outputs_cbm


In [5]:
# Define 21 BDI-II concept names
CONCEPT_NAMES = [
    "Sadness", "Pessimism", "Past failure", "Loss of pleasure",
    "Guilty feelings", "Punishment feelings", "Self-dislike", "Self-criticalness",
    "Suicidal thoughts or wishes", "Crying", "Agitation", "Loss of interest",
    "Indecisiveness", "Worthlessness", "Loss of energy", "Changes in sleeping pattern",
    "Irritability", "Changes in appetite", "Concentration difficulty",
    "Tiredness or fatigue", "Loss of interest in sex"
]
N_CONCEPTS = len(CONCEPT_NAMES)

print(f"✓ Defined {N_CONCEPTS} BDI-II concepts")

✓ Defined 21 BDI-II concepts


In [6]:
# Hyperparameters - UPDATED for better learning
HYPERPARAMS = {
    # Model architecture
    "embedding_dim": 384,       # SBERT embedding dimension
    "n_concepts": 21,
    "n_tasks": 1,
    
    # Training - REDUCED LR, INCREASED BATCH SIZE
    "batch_size_train": 64,     # Increased for stability
    "batch_size_eval": 128,     # Increased for faster eval
    "max_epochs": 200,          # More epochs for convergence
    "learning_rate": 0.001,     # REDUCED LR for stability
    "weight_decay": 0.001,      # Reduced regularization
    
    # Loss weights - BALANCED LOSSES
    "concept_loss_weight": 1.0, # EQUAL weight for concepts and task
    "task_pos_weight": 4.86,    # TRUE imbalance ratio
    
    # Hard bottleneck configuration - RELAXED
    "hard_bottleneck": False,   # TRY SOFT BOTTLENECK FIRST
    "concept_threshold": 0.3,   # Higher threshold for more activation
}

## Section 1: Load Preprocessed Data

Load the datasets that were saved by the CEM pipeline

In [7]:
# Load training data
print("Loading preprocessed datasets...")

train_data = np.load(os.path.join(DATASET_DIR, "train_data.npz"))
X_train = train_data['X']
C_train = train_data['C']
y_train = train_data['y']
train_subject_ids = train_data['subject_ids']

print(f"✓ Loaded training data:")
print(f"  X_train: {X_train.shape}")
print(f"  C_train: {C_train.shape}")
print(f"  y_train: {y_train.shape}")
print(f"  Subject IDs: {len(train_subject_ids)}")

Loading preprocessed datasets...
✓ Loaded training data:
  X_train: (486, 384)
  C_train: (486, 21)
  y_train: (486,)
  Subject IDs: 486


In [8]:
# Load validation data
val_data = np.load(os.path.join(DATASET_DIR, "val_data.npz"))
X_val = val_data['X']
C_val = val_data['C']
y_val = val_data['y']
val_subject_ids = val_data['subject_ids']

print(f"✓ Loaded validation data:")
print(f"  X_val: {X_val.shape}")
print(f"  C_val: {C_val.shape}")
print(f"  y_val: {y_val.shape}")

✓ Loaded validation data:
  X_val: (200, 384)
  C_val: (200, 21)
  y_val: (200,)


In [9]:
# Load test data
test_data = np.load(os.path.join(DATASET_DIR, "test_data.npz"))
X_test = test_data['X']
C_test = test_data['C']
y_test = test_data['y']
test_subject_ids = test_data['subject_ids']

print(f"✓ Loaded test data:")
print(f"  X_test: {X_test.shape}")
print(f"  C_test: {C_test.shape}")
print(f"  y_test: {y_test.shape}")

✓ Loaded test data:
  X_test: (201, 384)
  C_test: (201, 21)
  y_test: (201,)


In [10]:
# Load class weights
with open(os.path.join(DATASET_DIR, "class_weights.json"), 'r') as f:
    class_info = json.load(f)

n_positive = class_info['n_positive']
n_negative = class_info['n_negative']
pos_weight = class_info['pos_weight']

# Convert to tensor
pos_weight_tensor = torch.tensor([pos_weight], dtype=torch.float32)

print(f"✓ Loaded class weights:")
print(f"  Negative samples: {n_negative}")
print(f"  Positive samples: {n_positive}")
print(f"  Ratio: 1:{pos_weight:.2f}")
print(f"  pos_weight: {pos_weight:.4f}")

✓ Loaded class weights:
  Negative samples: 403
  Positive samples: 83
  Ratio: 1:4.86
  pos_weight: 4.8554


## Section 2: PyTorch Dataset & DataLoaders

In [11]:
class CBMDataset(Dataset):
    """PyTorch Dataset for CBM model."""
    
    def __init__(self, X, C, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.C = torch.tensor(C, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)
    
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx], self.C[idx]

print("✓ CBMDataset class defined")

✓ CBMDataset class defined


In [12]:
# Create datasets
train_dataset = CBMDataset(X_train, C_train, y_train)
val_dataset = CBMDataset(X_val, C_val, y_val)
test_dataset = CBMDataset(X_test, C_test, y_test)

print("✓ Datasets created")
print(f"  Train: {len(train_dataset)} samples")
print(f"  Val: {len(val_dataset)} samples")
print(f"  Test: {len(test_dataset)} samples")

✓ Datasets created
  Train: 486 samples
  Val: 200 samples
  Test: 201 samples


In [13]:
# Create DataLoaders
train_loader = DataLoader(
    train_dataset,
    batch_size=HYPERPARAMS['batch_size_train'],
    shuffle=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=HYPERPARAMS['batch_size_eval'],
    shuffle=False
)

test_loader = DataLoader(
    test_dataset,
    batch_size=HYPERPARAMS['batch_size_eval'],
    shuffle=False
)

print("✓ DataLoaders created")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")
print(f"  Test batches: {len(test_loader)}")

# Test batch
x_batch, y_batch, c_batch = next(iter(train_loader))
print(f"\n  Sample batch shapes:")
print(f"    X: {x_batch.shape}")
print(f"    y: {y_batch.shape}")
print(f"    C: {c_batch.shape}")

✓ DataLoaders created
  Train batches: 8
  Val batches: 2
  Test batches: 2

  Sample batch shapes:
    X: torch.Size([64, 384])
    y: torch.Size([64])
    C: torch.Size([64, 21])


## Section 3: Concept Bottleneck Model

Simple CBM architecture:
- X → Concept Extractor → Concept Logits
- Concept Probabilities (sigmoid) → Task Classifier → Task Logits

In [14]:
class ConceptBottleneckModel(pl.LightningModule):
    """
    Concept Bottleneck Model with HARD bottleneck support.
    
    Architecture:
      X -> concept extractor -> concept logits -> discrete concepts (hard bottleneck) -> task classifier -> y
    
    Hard bottleneck ensures:
      - Concepts are discrete (0 or 1), not continuous probabilities
      - Task classifier only sees binary concept values
      - True interpretability and intervention capability
    """
    def __init__(
        self,
        input_dim,
        n_concepts,
        task_output_dim,
        c_extractor_arch,
        learning_rate=0.01,
        weight_decay=1e-4,
        concept_loss_weight=1.0,
        task_pos_weight=2.0,
        hard_bottleneck=True,
        concept_threshold=0.5,
    ):
        super().__init__()
        self.save_hyperparameters(ignore=['c_extractor_arch'])
        
        # Concept extractor
        self.concept_extractor = c_extractor_arch(output_dim=n_concepts)
        
        # IMPROVED Task classifier - small MLP instead of linear
        self.task_classifier = nn.Sequential(
            nn.Linear(n_concepts, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, task_output_dim)
        )
        
        # Loss functions
        self.concept_loss_fn = nn.BCEWithLogitsLoss()
        
        # Task loss: plain BCEWithLogitsLoss with modest pos_weight
        # Use the passed task_pos_weight parameter
        task_pos_weight_tensor = torch.tensor([task_pos_weight], dtype=torch.float32)
        self.task_loss_fn = nn.BCEWithLogitsLoss(pos_weight=task_pos_weight_tensor)
        print(f"✓ Task BCEWithLogitsLoss pos_weight set to {task_pos_weight_tensor.item():.2f}")

        # For threshold optimization during validation
        self.validation_outputs = []
    
    def forward(self, x, use_ground_truth_concepts=False, c_true=None):
        """
        Forward pass with hard bottleneck support.
        
        Args:
            x: Input features
            use_ground_truth_concepts: If True, use c_true instead of predictions (for intervention)
            c_true: Ground truth concepts (optional)
        
        Returns:
            concept_logits: Raw concept predictions (None if using ground truth)
            task_logits: Task predictions
        """
        if use_ground_truth_concepts and c_true is not None:
            # Use ground truth concepts (for intervention experiments)
            concept_input = c_true
            concept_logits = None
        else:
            # Get concept predictions
            concept_logits = self.concept_extractor(x)
            concept_probs = torch.sigmoid(concept_logits)
            
            if self.hparams.hard_bottleneck:
                # HARD BOTTLENECK: Use discrete predictions, detach gradients
                # This creates a true bottleneck - task classifier only sees binary concept values
                # and cannot backpropagate gradients to the concept extractor
                concept_input = (concept_probs >= self.hparams.concept_threshold).float().detach()
            else:
                # SOFT BOTTLENECK: Use continuous probabilities (end-to-end training)
                concept_input = concept_probs
        
        # Task prediction from concepts
        task_logits = self.task_classifier(concept_input)
        
        return concept_logits, task_logits
    
    def training_step(self, batch, batch_idx):
        x, y, c = batch
        concept_logits, task_logits = self(x)
        
        # Concept loss
        concept_loss = self.concept_loss_fn(concept_logits, c)
        
        # Task loss
        task_loss = self.task_loss_fn(task_logits.squeeze(), y.squeeze())
        
        # Combined loss (single-stage)
        loss = task_loss + self.hparams.concept_loss_weight * concept_loss
        
        # Log metrics
        self.log('train_loss', loss, on_epoch=True, prog_bar=True)
        self.log('train_concept_loss', concept_loss, on_epoch=True)
        self.log('train_task_loss', task_loss, on_epoch=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y, c = batch
        concept_logits, task_logits = self(x)
        
        # Concept loss
        concept_loss = self.concept_loss_fn(concept_logits, c)
        
        # Task loss
        task_loss = self.task_loss_fn(task_logits.squeeze(), y.squeeze())
        
        # Combined loss
        loss = task_loss + self.hparams.concept_loss_weight * concept_loss
        
        # Store outputs for threshold optimization
        self.validation_outputs.append({
            'task_logits': task_logits.detach().cpu().float(),  # Convert to float32
            'y_true': y.detach().cpu().float()  # Convert to float32
        })
        
        # Log metrics
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_concept_loss', concept_loss, on_epoch=True)
        self.log('val_task_loss', task_loss, on_epoch=True)
        
        return loss
    
    def on_validation_epoch_end(self):
        """Find best threshold on validation set at end of each epoch."""
        if len(self.validation_outputs) == 0:
            return
        
        # Collect all validation predictions
        all_logits = torch.cat([x['task_logits'] for x in self.validation_outputs])
        all_y_true = torch.cat([x['y_true'] for x in self.validation_outputs])
        
        # Convert to numpy as float32 (not float64)
        logits = all_logits.squeeze().numpy().astype(np.float32)
        y_true = all_y_true.numpy().astype(np.int32)
        probs = 1.0 / (1.0 + np.exp(-logits))  # sigmoid
        
        # Log logit statistics - ensure float32 before converting to Python float
        self.log('val_logit_min', float(np.min(logits).astype(np.float32)))
        self.log('val_logit_mean', float(np.mean(logits).astype(np.float32)))
        self.log('val_logit_max', float(np.max(logits).astype(np.float32)))
        
        # Sweep thresholds to find best MCC
        from sklearn.metrics import matthews_corrcoef
        
        best_mcc = -1.0
        best_threshold = 0.5
        
        for threshold in np.arange(0.1, 0.9, 0.05):
            y_pred = (probs >= threshold).astype(np.int32)
            mcc = matthews_corrcoef(y_true, y_pred)
            if mcc > best_mcc:
                best_mcc = mcc
                best_threshold = float(threshold)
        
        # Log as float32 to avoid MPS issues
        self.log('val_best_threshold', float(np.float32(best_threshold)))
        self.log('val_best_mcc', float(np.float32(best_mcc)))
        
        # Clear for next epoch
        self.validation_outputs.clear()
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay
        )
        
        # Add learning rate scheduler for better convergence
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 
            mode='min', 
            factor=0.5, 
            patience=10, 
            min_lr=1e-6,
            verbose=True
        )
        
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
                "interval": "epoch",
                "frequency": 1,
            },
        }

print("✓ CBM model defined with HARD bottleneck support")

✓ CBM model defined with HARD bottleneck support


## Section 4: Model Initialization

In [15]:
def c_extractor_arch(output_dim):
    """IMPROVED Concept extractor architecture - deeper network."""
    return nn.Sequential(
        nn.Linear(HYPERPARAMS['embedding_dim'], 512),
        nn.BatchNorm1d(512),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(512, 256),
        nn.BatchNorm1d(256),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(256, output_dim)
    )

print("✓ Concept extractor architecture defined")

✓ Concept extractor architecture defined


In [16]:
# Get class counts
n_positive = class_info['n_positive']
n_negative = class_info['n_negative']

# Initialize CBM model
cbm_model = ConceptBottleneckModel(
    input_dim=HYPERPARAMS['embedding_dim'],
    n_concepts=HYPERPARAMS['n_concepts'],
    task_output_dim=1,
    c_extractor_arch=c_extractor_arch,
    learning_rate=HYPERPARAMS['learning_rate'],
    weight_decay=HYPERPARAMS['weight_decay'],
    concept_loss_weight=HYPERPARAMS['concept_loss_weight'],
    task_pos_weight=HYPERPARAMS['task_pos_weight'],
    hard_bottleneck=HYPERPARAMS['hard_bottleneck'],
    concept_threshold=HYPERPARAMS['concept_threshold'],
)


# Initialize task classifier bias to match positive prevalence
pos_frac = n_positive / (n_positive + n_negative)  # e.g., 26/201
bias_init = np.log(pos_frac / (1 - pos_frac))      # log-odds

# Safely access bias from the last layer in the Sequential model
if hasattr(cbm_model.task_classifier[-1], 'bias') and cbm_model.task_classifier[-1].bias is not None:
    cbm_model.task_classifier[-1].bias.data.fill_(bias_init)
    print(f"✓ Task classifier bias initialized to log-odds: {bias_init:.4f}")
else:
    print(f"⚠ Could not initialize bias (last layer has no bias), bias_init would be: {bias_init:.4f}")
print("✓ CBM model initialized")
print(f"  Task pos_weight: {HYPERPARAMS['task_pos_weight']}")
print(f"  Hard bottleneck: {HYPERPARAMS['hard_bottleneck']} (HARD with threshold={HYPERPARAMS['concept_threshold']})")
print(f"  Concept threshold: {HYPERPARAMS['concept_threshold']} (lowered from 0.5 to match learned distributions)")
print(f"  Concept loss weight: {HYPERPARAMS['concept_loss_weight']} (reduced to prioritize task)")
print(f"  Learning rate: {HYPERPARAMS['learning_rate']} (increased to escape local minima)")

✓ Task BCEWithLogitsLoss pos_weight set to 4.86
✓ Task classifier bias initialized to log-odds: -1.5801
✓ CBM model initialized
  Task pos_weight: 4.86
  Hard bottleneck: False (HARD with threshold=0.3)
  Concept threshold: 0.3 (lowered from 0.5 to match learned distributions)
  Concept loss weight: 1.0 (reduced to prioritize task)
  Learning rate: 0.001 (increased to escape local minima)


## Section 5: Training

In [17]:
print("="*70)
print("          SINGLE-STAGE HARD BOTTLENECK TRAINING")
print("="*70)
print()
print("Training concepts and task jointly with hard bottleneck (threshold=0.15)")
print("Lower threshold allows ~30-50% of concepts to activate (not all zeros)")
print()
print("="*70)

# Early stopping callback
early_stop = EarlyStopping(
    monitor='val_loss',
    patience=20,
    mode='min',
    verbose=True
)

# Checkpoint callback
checkpoint = ModelCheckpoint(
    monitor="val_loss",
    dirpath=os.path.join(OUTPUT_DIR, "models"),
    filename="cbm-{epoch:02d}-{val_loss:.2f}",
    save_top_k=1,
    mode="min"
)

# ADDITIONAL CALLBACK: Monitor concept activation
class ConceptActivationCallback(pl.Callback):
    def on_validation_epoch_end(self, trainer, pl_module):
        # Log concept activation statistics
        if hasattr(pl_module, 'validation_outputs') and len(pl_module.validation_outputs) > 0:
            # Get concept predictions from a few validation batches
            concept_activations = []
            for batch in trainer.val_dataloaders[0]:
                x_batch, y_batch, c_batch = batch
                x_batch = x_batch.to(pl_module.device)
                c_logits, _ = pl_module(x_batch)
                c_probs = torch.sigmoid(c_logits).detach().cpu().numpy()
                concept_activations.append(c_probs)
                if len(concept_activations) >= 5:  # Sample from 5 batches
                    break
            
            if concept_activations:
                avg_activations = np.mean(np.concatenate(concept_activations), axis=0)
                for i, activation in enumerate(avg_activations):
                    pl_module.log(f'concept_{i}_activation', float(activation))
                
                # Log overall concept sparsity
                sparsity = np.mean(avg_activations < 0.1)  # Concepts with <10% activation
                pl_module.log('concept_sparsity', float(sparsity))

# Trainer with additional monitoring
trainer = pl.Trainer(
    max_epochs=HYPERPARAMS['max_epochs'],
    accelerator=DEVICE,
    devices=1,
    logger=CSVLogger(save_dir=os.path.join(OUTPUT_DIR, "logs"), name="cbm"),
    log_every_n_steps=10,
    callbacks=[early_stop, checkpoint, ConceptActivationCallback()],
    enable_progress_bar=True
)

# DEBUGGING CELL: Analyze model behavior before training
print("\n" + "="*70)
print("MODEL ANALYSIS BEFORE TRAINING")
print("="*70)

# Test forward pass on a single batch
x_batch, y_batch, c_batch = next(iter(train_loader))
x_batch = x_batch.to(device_obj) if 'device_obj' in locals() else x_batch

with torch.no_grad():
    c_logits, y_logits = cbm_model(x_batch)
    c_probs = torch.sigmoid(c_logits)
    y_probs = torch.sigmoid(y_logits)

print(f"Input shape: {x_batch.shape}")
print(f"Concept logits shape: {c_logits.shape}")
print(f"Task logits shape: {y_logits.shape}")
print(f"Concept probs range: [{c_probs.min().item():.3f}, {c_probs.max().item():.3f}]")
print(f"Task probs range: [{y_probs.min().item():.3f}, {y_probs.max().item():.3f}]")

# Check concept activation distribution
concept_means = c_probs.mean(dim=0)
print(f"\nConcept activation means: {concept_means}")
print(f"Concepts with >50% activation: {(concept_means > 0.5).sum().item()}/{len(concept_means)}")

# Check class distribution in batch
print(f"\nBatch class distribution: {y_batch.sum().item()}/{len(y_batch)} positive samples")

print("\n" + "="*70)
print("STARTING TRAINING...")
print("="*70)

# Train
trainer.fit(cbm_model, train_loader, val_loader)

print("\n" + "="*70)
print("✓ TRAINING COMPLETE")
print("="*70)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


          SINGLE-STAGE HARD BOTTLENECK TRAINING

Training concepts and task jointly with hard bottleneck (threshold=0.15)
Lower threshold allows ~30-50% of concepts to activate (not all zeros)


MODEL ANALYSIS BEFORE TRAINING
Input shape: torch.Size([64, 384])
Concept logits shape: torch.Size([64, 21])
Task logits shape: torch.Size([64, 1])
Concept probs range: [0.053, 0.942]
Task probs range: [0.033, 0.338]

Concept activation means: tensor([0.5356, 0.5261, 0.4431, 0.4537, 0.4602, 0.4069, 0.5374, 0.5406, 0.5243,
        0.5960, 0.4888, 0.5445, 0.5317, 0.5341, 0.4959, 0.5237, 0.4854, 0.3957,
        0.4813, 0.4105, 0.5528])
Concepts with >50% activation: 11/21

Batch class distribution: 8.0/64 positive samples

STARTING TRAINING...



  | Name              | Type              | Params
--------------------------------------------------------
0 | concept_extractor | Sequential        | 335 K 
1 | task_classifier   | Sequential        | 1.6 K 
2 | concept_loss_fn   | BCEWithLogitsLoss | 0     
3 | task_loss_fn      | BCEWithLogitsLoss | 0     
--------------------------------------------------------
336 K     Trainable params
0         Non-trainable params
336 K     Total params
1.348     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved. New best score: 1.971


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.145 >= min_delta = 0.0. New best score: 1.826


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.185 >= min_delta = 0.0. New best score: 1.641


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.217 >= min_delta = 0.0. New best score: 1.424


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.175 >= min_delta = 0.0. New best score: 1.249


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.118 >= min_delta = 0.0. New best score: 1.131


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.145 >= min_delta = 0.0. New best score: 0.986


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.120 >= min_delta = 0.0. New best score: 0.866


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.042 >= min_delta = 0.0. New best score: 0.824


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.078 >= min_delta = 0.0. New best score: 0.746


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 0.738


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.028 >= min_delta = 0.0. New best score: 0.709


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Monitored metric val_loss did not improve in the last 20 records. Best score: 0.709. Signaling Trainer to stop.



✓ TRAINING COMPLETE


## Section 6: Test Evaluation

In [18]:
# Set model to evaluation mode
cbm_model.eval()

# Move model to device
device_obj = torch.device(DEVICE)
cbm_model = cbm_model.to(device_obj)

print("✓ Model set to evaluation mode")

✓ Model set to evaluation mode


In [19]:
# Run inference on test set
print("Running inference on test set...")

y_true_list = []
y_prob_list = []
y_logits_list = []
concept_probs_list = []

with torch.no_grad():
    for x_batch, y_batch, c_batch in test_loader:
        x_batch = x_batch.to(device_obj)
        
        # Forward pass
        c_logits, y_logits = cbm_model(x_batch)
        
        # Apply sigmoid to get probabilities
        c_probs = torch.sigmoid(c_logits).cpu().numpy()
        y_probs = torch.sigmoid(y_logits).cpu().squeeze().numpy()
        
        # Collect results
        y_true_list.extend(y_batch.numpy().astype(int).tolist())
        y_prob_list.extend(y_probs.tolist() if isinstance(y_probs, np.ndarray) else [y_probs])
        y_logits_list.extend(y_logits.cpu().squeeze().numpy().tolist() if y_logits.dim() > 1 else [y_logits.cpu().squeeze().item()])
        concept_probs_list.extend(c_probs.tolist())

# Convert to arrays
y_true = np.array(y_true_list)
y_prob = np.array(y_prob_list)
y_logits = np.array(y_logits_list)
concept_probs = np.array(concept_probs_list)

print("✓ Inference complete")
print(f"  Probabilities shape: {y_prob.shape}")
print(f"  Concept probs shape: {concept_probs.shape}")

# Log test logit statistics
print(f"\nTest set logit statistics:")
print(f"  Min:  {np.min(y_logits):.4f}")
print(f"  Mean: {np.mean(y_logits):.4f}")
print(f"  Max:  {np.max(y_logits):.4f}")
print(f"  Std:  {np.std(y_logits):.4f}")

# Check for diversity (should NOT be all identical)
if np.std(y_logits) < 0.01:
    print("  ⚠ WARNING: Logits have very low variance - model may have collapsed!")
else:
    print("  ✓ Logits show good diversity")

# Get best threshold from final validation epoch
# Read from trainer logs
import pandas as pd
log_dir = os.path.join(OUTPUT_DIR, "logs/cbm/version_0/metrics.csv")
if os.path.exists(log_dir):
    metrics_df = pd.read_csv(log_dir)
    # Get last non-null val_best_threshold
    best_threshold = metrics_df['val_best_threshold'].dropna().iloc[-1]
    print(f"\n✓ Using threshold from validation: {best_threshold:.2f}")
else:
    # Fallback: find best threshold on validation set manually
    print("\n⚠ No validation logs found, using threshold=0.5")
    best_threshold = 0.5

# Use best threshold for final predictions
y_pred = (y_prob >= best_threshold).astype(int)
print(f"  Predictions shape: {y_pred.shape}")

Running inference on test set...
✓ Inference complete
  Probabilities shape: (201,)
  Concept probs shape: (201, 21)

Test set logit statistics:
  Min:  -4.8226
  Mean: -2.6702
  Max:  7.6939
  Std:  2.4685
  ✓ Logits show good diversity

✓ Using threshold from validation: 0.45
  Predictions shape: (201,)


In [20]:
# Compute all metrics
print("Computing metrics...")

# Confusion matrix
cm = confusion_matrix(y_true, y_pred)
tn, fp, fn, tp = cm.ravel()

# Metrics
acc = accuracy_score(y_true, y_pred)
balanced_acc = balanced_accuracy_score(y_true, y_pred)
roc_auc = roc_auc_score(y_true, y_prob)
mcc = matthews_corrcoef(y_true, y_pred)

f1_binary = f1_score(y_true, y_pred, pos_label=1)
f1_macro = f1_score(y_true, y_pred, average='macro')
f1_micro = f1_score(y_true, y_pred, average='micro')

precision_binary = precision_score(y_true, y_pred, pos_label=1)
recall_binary = recall_score(y_true, y_pred, pos_label=1)

print("✓ Metrics computed")

Computing metrics...
✓ Metrics computed


In [21]:
# Print formatted results
print("\n" + "="*70)
print("                    TEST SET EVALUATION")
print("="*70)
print()
print(f"Dataset Statistics:")
print(f"  Test subjects:        {len(y_true)}")
print(f"  Positive cases:       {np.sum(y_true)} ({100*np.sum(y_true)/len(y_true):.1f}%)")
print(f"  Negative cases:       {len(y_true)-np.sum(y_true)} ({100*(len(y_true)-np.sum(y_true))/len(y_true):.1f}%)")
print()
print(f"Performance Metrics:")
print(f"  Accuracy:                  {acc:.4f}")
print(f"  Balanced Accuracy:         {balanced_acc:.4f}")
print(f"  ROC-AUC:                   {roc_auc:.4f}")
print(f"  Matthews Correlation:      {mcc:.4f}")
print()
print(f"  F1 Score (Binary):         {f1_binary:.4f}")
print(f"  F1 Score (Macro):          {f1_macro:.4f}")
print(f"  F1 Score (Micro):          {f1_micro:.4f}")
print()
print(f"  Precision (Binary):        {precision_binary:.4f}")
print(f"  Recall (Binary):           {recall_binary:.4f}")

# Enhanced confusion matrix display
print(f"\n{'CONFUSION MATRIX':^50}")
print("="*50)
print(f"{'':>20} │ {'Predicted Negative':^12} │ {'Predicted Positive':^12}")
print("─"*50)
print(f"{'Actual Negative':>20} │ {f'TN = {tn}':^12} │ {f'FP = {fp}':^12}")
print(f"{'Actual Positive':>20} │ {f'FN = {fn}':^12} │ {f'TP = {tp}':^12}")
print("="*50)

print(f"\n  True Positives:  {tp:>3}/{int(np.sum(y_true)):<3} ({100*tp/np.sum(y_true):>5.1f}% of depression cases caught)")
print(f"  False Negatives: {fn:>3}/{int(np.sum(y_true)):<3} ({100*fn/np.sum(y_true):>5.1f}% of depression cases MISSED)")
print(f"  True Negatives:  {tn:>3}/{int(len(y_true)-np.sum(y_true)):<3} ({100*tn/(len(y_true)-np.sum(y_true)):>5.1f}% of healthy correctly identified)")
print(f"  False Positives: {fp:>3}/{int(len(y_true)-np.sum(y_true)):<3} ({100*fp/(len(y_true)-np.sum(y_true)):>5.1f}% false alarms)")

print()
print("Classification Report:")
print(classification_report(y_true, y_pred, target_names=['Negative', 'Positive']))
print("="*70)


                    TEST SET EVALUATION

Dataset Statistics:
  Test subjects:        201
  Positive cases:       26 (12.9%)
  Negative cases:       175 (87.1%)

Performance Metrics:
  Accuracy:                  0.8557
  Balanced Accuracy:         0.6715
  ROC-AUC:                   0.8101
  Matthews Correlation:      0.3489

  F1 Score (Binary):         0.4314
  F1 Score (Macro):          0.6744
  F1 Score (Micro):          0.8557

  Precision (Binary):        0.4400
  Recall (Binary):           0.4231

                 CONFUSION MATRIX                 
                     │ Predicted Negative │ Predicted Positive
──────────────────────────────────────────────────
     Actual Negative │   TN = 161   │   FP = 14   
     Actual Positive │   FN = 15    │   TP = 11   

  True Positives:   11/26  ( 42.3% of depression cases caught)
  False Negatives:  15/26  ( 57.7% of depression cases MISSED)
  True Negatives:  161/175 ( 92.0% of healthy correctly identified)
  False Positives:  14/175 (

In [22]:
# Save metrics to JSON
metrics_dict = {
    "n_samples": int(len(y_true)),
    "n_positive": int(np.sum(y_true)),
    "n_negative": int(len(y_true) - np.sum(y_true)),
    "best_threshold": float(best_threshold),
    "accuracy": float(acc),
    "balanced_accuracy": float(balanced_acc),
    "roc_auc": float(roc_auc),
    "mcc": float(mcc),
    "f1_binary": float(f1_binary),
    "f1_macro": float(f1_macro),
    "f1_micro": float(f1_micro),
    "precision_binary": float(precision_binary),
    "recall_binary": float(recall_binary),
    "confusion_matrix": {
        "tn": int(tn),
        "fp": int(fp),
        "fn": int(fn),
        "tp": int(tp)
    }
}

os.makedirs(os.path.join(OUTPUT_DIR, "results"), exist_ok=True)
metrics_path = os.path.join(OUTPUT_DIR, "results/test_metrics.json")

with open(metrics_path, 'w') as f:
    json.dump(metrics_dict, f, indent=4)

print(f"✓ Metrics saved to {metrics_path}")

✓ Metrics saved to outputs_cbm/results/test_metrics.json


In [23]:
# Create predictions DataFrame with concept probabilities
predictions_df = pd.DataFrame({
    'subject_id': test_subject_ids,
    'y_true': y_true,
    'y_pred': y_pred,
    'y_prob': y_prob
})

# Add concept probabilities
for i, concept_name in enumerate(CONCEPT_NAMES):
    predictions_df[concept_name] = concept_probs[:, i]

# Save to CSV
predictions_path = os.path.join(OUTPUT_DIR, "results/test_predictions.csv")
predictions_df.to_csv(predictions_path, index=False)

print(f"✓ Predictions saved to {predictions_path}")
print(f"\nFirst 10 subjects with concept probabilities:")
print(predictions_df.head(10))

✓ Predictions saved to outputs_cbm/results/test_predictions.csv

First 10 subjects with concept probabilities:
         subject_id  y_true  y_pred    y_prob   Sadness  Pessimism  \
0  test_subject4471       1       1  0.869716  0.089046   0.227653   
1  test_subject8981       0       0  0.046299  0.025544   0.096699   
2  test_subject8777       0       0  0.016334  0.014689   0.036407   
3  test_subject1372       0       0  0.008605  0.186242   0.254578   
4  test_subject1830       0       0  0.020472  0.023861   0.056897   
5  test_subject3791       0       0  0.014678  0.019136   0.018401   
6  test_subject2284       0       0  0.018185  0.046446   0.068418   
7  test_subject5689       0       0  0.021459  0.038876   0.072512   
8  test_subject7467       1       0  0.256077  0.049246   0.118178   
9  test_subject7578       0       0  0.014783  0.020362   0.028903   

   Past failure  Loss of pleasure  Guilty feelings  Punishment feelings  ...  \
0      0.225428          0.040993     

In [24]:
# Display concept activation statistics
print("\nConcept Activation Statistics:")
print("="*70)
print(f"{'Concept':<35} {'Mean':>10} {'Std':>10} {'Max':>10}")
print("-"*70)
for i, concept_name in enumerate(CONCEPT_NAMES):
    mean_act = np.mean(concept_probs[:, i])
    std_act = np.std(concept_probs[:, i])
    max_act = np.max(concept_probs[:, i])
    print(f"{concept_name:<35} {mean_act:>10.4f} {std_act:>10.4f} {max_act:>10.4f}")
print("="*70)


Concept Activation Statistics:
Concept                                   Mean        Std        Max
----------------------------------------------------------------------
Sadness                                 0.0685     0.0663     0.3273
Pessimism                               0.1220     0.1210     0.6445
Past failure                            0.1359     0.1281     0.5971
Loss of pleasure                        0.0444     0.0495     0.2993
Guilty feelings                         0.0449     0.0579     0.3415
Punishment feelings                     0.0739     0.0904     0.5012
Self-dislike                            0.1780     0.1821     0.8220
Self-criticalness                       0.0744     0.0538     0.2930
Suicidal thoughts or wishes             0.0444     0.0340     0.1735
Crying                                  0.0459     0.0507     0.2610
Agitation                               0.0313     0.0290     0.1750
Loss of interest                        0.0705     0.0795     0.5363


## Section 8: Summary

In [25]:
print("\n" + "="*70)
print("              CBM PIPELINE EXECUTION COMPLETE")
print("="*70)
print("\nGenerated files:")
print(f"  Model checkpoint: {OUTPUT_DIR}/models/")
print(f"  Metrics JSON:     {OUTPUT_DIR}/results/test_metrics.json")
print(f"  Predictions CSV:  {OUTPUT_DIR}/results/test_predictions.csv")
print(f"  Training logs:    {OUTPUT_DIR}/logs/")
print("="*70)


              CBM PIPELINE EXECUTION COMPLETE

Generated files:
  Model checkpoint: outputs_cbm/models/
  Metrics JSON:     outputs_cbm/results/test_metrics.json
  Predictions CSV:  outputs_cbm/results/test_predictions.csv
  Training logs:    outputs_cbm/logs/
