# IntCEM Model - Intervention-Aware CEM with MAX Alternative Pipeline

**Runtime:** ~25-30 minutes (longer than standard CEM due to intervention rollouts)

This notebook:
1. Trains **IntCEM** (Intervention-Aware CEM) using MAX-based concept similarity data
2. Implements a **learned intervention policy** that decides which concepts to query
3. Uses **intervention rollouts** during training to simulate expert feedback
4. Keeps the LDAM loss for class imbalance (same as 1d_CEM_max_Gold)

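**Key difference from 1d_CEM_max_Gold:**
- 1d: Random interventions only (each concept is replaced by its ground truth with 25% probability during training)
- 1f (this notebook): Keeps the random interventions and adds a learned policy network that scores which concept to intervene on next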

**Prerequisites:** Run `0c_prepare_max_alt_dataset.ipynb` first!

## Section 0: Setup & Configuration

In [1]:
# Imports
import os
import json
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint

from sklearn.metrics import (
    confusion_matrix,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    matthews_corrcoef,
    roc_auc_score,
    balanced_accuracy_score,
    classification_report,
)

print("✓ All imports successful")

✓ All imports successful


In [2]:
# Set random seeds
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
pl.seed_everything(SEED)

print(f"✓ Random seed set to {SEED}")

Global seed set to 42


✓ Random seed set to 42


In [3]:
# Detect device
if torch.backends.mps.is_available():
    DEVICE = "mps"
    print("✓ Using MacBook GPU (MPS)")
elif torch.cuda.is_available():
    DEVICE = "cuda"
    print("✓ Using CUDA GPU")
else:
    DEVICE = "cpu"
    print("⚠ Using CPU")

✓ Using MacBook GPU (MPS)


In [4]:
# Define paths
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
DATA_PROCESSED = os.path.join(PROJECT_ROOT, "data/processed")
DATASET_DIR = os.path.join(DATA_PROCESSED, "max_alternative_attention_pipeline")
OUTPUT_DIR = "outputs_intcem_max"

print("✓ Paths configured")
print(f"  Dataset dir: {DATASET_DIR}")
print(f"  Output dir: {OUTPUT_DIR}")

✓ Paths configured
  Dataset dir: /Users/gualtieromarencoturi/Desktop/thesis/Master-Thesis-CEM-Depression-etc-case-study/data/processed/max_alternative_attention_pipeline
  Output dir: outputs_intcem_max


In [5]:
# Define 21 BDI-II concept names
CONCEPT_NAMES = [
    "Sadness", "Pessimism", "Past failure", "Loss of pleasure",
    "Guilty feelings", "Punishment feelings", "Self-dislike", "Self-criticalness",
    "Suicidal thoughts or wishes", "Crying", "Agitation", "Loss of interest",
    "Indecisiveness", "Worthlessness", "Loss of energy", "Changes in sleeping pattern",
    "Irritability", "Changes in appetite", "Concentration difficulty",
    "Tiredness or fatigue", "Loss of interest in sex"
]
N_CONCEPTS = len(CONCEPT_NAMES)

print(f"✓ Defined {N_CONCEPTS} BDI-II concepts")

✓ Defined 21 BDI-II concepts


In [6]:
# Hyperparameters
HYPERPARAMS = {
    # Model architecture
    "embedding_dim": 384,
    "n_concepts": 21,
    "n_tasks": 1,
    "emb_size": 128,
    
    # CEM-specific
    "shared_prob_gen": True,        # Share probability generator across concepts
    "intervention_prob": 0.25,      # Per-concept probability of a random ground-truth intervention in the standard training forward pass
    
    # IntCEM-specific (NEW)
    "intervention_task_discount": 1.1,     # Rollout task loss at step t is weighted by discount**t
    "intervention_weight": 3.0,            # Policy loss weight
    "max_horizon": 6,                      # Max interventions per sample
    "initial_horizon": 2,                  # Starting horizon
    "horizon_rate": 1.005,                 # Horizon is multiplied by this after every epoch
    "num_rollouts": 1,                     # Monte Carlo samples
    "intervention_policy_hidden": [256, 128],  # Policy MLP architecture
    
    # Training
    "batch_size_train": 32,
    "batch_size_eval": 64,
    "max_epochs": 100,
    "learning_rate": 0.01,
    "weight_decay": 4e-05,
    
    # Loss
    "concept_loss_weight": 1.0,
    
    # LDAM Loss
    "use_ldam_loss": True,
    "n_positive": None,               # Will be set after loading data
    "n_negative": None,               # Will be set after loading data
    "ldam_max_margin": 0.5,           # Try: 0.3, 0.5, 0.7, 1.0
    "ldam_scale": 30,                 # Try: 20, 30, 40, 50
    
    # Weighted Sampler
    "use_weighted_sampler": False,
}

print("✓ Hyperparameters configured")
print(f"  Model: IntCEM (Intervention-Aware CEM)")
if HYPERPARAMS['use_ldam_loss']:
    print(f"  Using LDAM LOSS (margin={HYPERPARAMS['ldam_max_margin']}, scale={HYPERPARAMS['ldam_scale']})")
else:
    print(f"  Using standard BCE loss")
print(f"  Intervention policy weight: {HYPERPARAMS['intervention_weight']}")
print(f"  Horizon curriculum: {HYPERPARAMS['initial_horizon']} → {HYPERPARAMS['max_horizon']}")

✓ Hyperparameters configured
  Model: IntCEM (Intervention-Aware CEM)
  Using LDAM LOSS (margin=0.5, scale=30)
  Intervention policy weight: 3.0
  Horizon curriculum: 2 → 6
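
To make `intervention_task_discount` concrete: in the rollout loss defined in Section 3, the task loss at rollout step `t` is multiplied by `discount ** t`, so a value above 1 gives more weight to errors made after more interventions. The sketch below only reproduces that weighting; `rollout_step_weights` is a hypothetical helper, not part of the notebook.

```python
def rollout_step_weights(discount, horizon):
    """Weight applied to the task loss at each rollout step (step 0 first)."""
    return [discount ** step for step in range(int(horizon))]

# Values taken from HYPERPARAMS: discount 1.1, initial horizon 2, max horizon 6.
initial_weights = rollout_step_weights(1.1, 2)
full_weights = rollout_step_weights(1.1, 6)

print([round(w, 3) for w in initial_weights])
print([round(w, 3) for w in full_weights])
```

Because the weights grow with the step index, a longer horizon also increases the total magnitude of the rollout task loss relative to the base task loss.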


## Section 1: Load Preprocessed Data

In [7]:
# Load training data
print("Loading preprocessed datasets...")

train_data = np.load(os.path.join(DATASET_DIR, "train_data.npz"))
X_train = train_data['X']
C_train = train_data['C']
y_train = train_data['y']
train_subject_ids = train_data['subject_ids']

print(f"✓ Loaded training data: {X_train.shape}")

Loading preprocessed datasets...
✓ Loaded training data: (388, 384)


In [8]:
# Load validation data
val_data = np.load(os.path.join(DATASET_DIR, "val_data.npz"))
X_val = val_data['X']
C_val = val_data['C']
y_val = val_data['y']
val_subject_ids = val_data['subject_ids']

print(f"✓ Loaded validation data: {X_val.shape}")

✓ Loaded validation data: (98, 384)


In [9]:
# Load test data
test_data = np.load(os.path.join(DATASET_DIR, "test_data.npz"))
X_test = test_data['X']
C_test = test_data['C']
y_test = test_data['y']
test_subject_ids = test_data['subject_ids']

print(f"✓ Loaded test data: {X_test.shape}")

✓ Loaded test data: (401, 384)


In [10]:
# Load class weights
with open(os.path.join(DATASET_DIR, "class_weights.json"), 'r') as f:
    class_info = json.load(f)

n_positive = class_info['n_positive']
n_negative = class_info['n_negative']
pos_weight = class_info['pos_weight']

# Update HYPERPARAMS with actual class counts for LDAM
HYPERPARAMS['n_positive'] = n_positive
HYPERPARAMS['n_negative'] = n_negative

print(f"✓ Loaded class weights:")
print(f"  Negative: {n_negative}, Positive: {n_positive}")
print(f"  Ratio: 1:{pos_weight:.2f}")

✓ Loaded class weights:
  Negative: 322, Positive: 66
  Ratio: 1:4.88
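
These class counts drive the margins in the `LDAMLoss` class defined in Section 3. As a rough worked example, the sketch below applies the same formula used there (`max_margin * freq ** -0.25`) to the counts printed above; `ldam_margins` is a hypothetical helper written for illustration only.

```python
def ldam_margins(n_positive, n_negative, max_margin=0.5):
    """Class margins as computed in this notebook's LDAMLoss."""
    total = n_positive + n_negative
    margin_pos = max_margin * (n_positive / total) ** -0.25
    margin_neg = max_margin * (n_negative / total) ** -0.25
    return margin_pos, margin_neg

# Counts from class_weights.json as printed above: 66 positive, 322 negative.
margin_pos, margin_neg = ldam_margins(n_positive=66, n_negative=322)
print(f"margin_pos={margin_pos:.3f}, margin_neg={margin_neg:.3f}")
```

The minority (positive) class receives the larger margin. Both values come out above `ldam_max_margin` because class frequencies are below 1, so in this formulation the parameter behaves as a base scale rather than a cap.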


## Section 2: PyTorch Dataset & DataLoaders

In [11]:
class CEMDataset(Dataset):
    def __init__(self, X, C, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.C = torch.tensor(C, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)
    
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx], self.C[idx]

# Create datasets
train_dataset = CEMDataset(X_train, C_train, y_train)
val_dataset = CEMDataset(X_val, C_val, y_val)
test_dataset = CEMDataset(X_test, C_test, y_test)

# Create WeightedRandomSampler for batch-level oversampling (if enabled)
if HYPERPARAMS['use_weighted_sampler']:
    # Compute class sample counts
    class_sample_counts = np.bincount(y_train.astype(int))  # [n_negative, n_positive]
    weights = 1. / class_sample_counts
    sample_weights = weights[y_train.astype(int)]
    
    # Create sampler
    train_sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True  # Allow positive samples to appear multiple times
    )
    
    print(f"✓ WeightedRandomSampler created:")
    print(f"  Negative weight: {weights[0]:.4f}")
    print(f"  Positive weight: {weights[1]:.4f}")
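    # Expected share of positives per draw = total positive weight / total weight
    expected_pos = (weights[1] * class_sample_counts[1]) / (weights * class_sample_counts).sum()
    print(f"  Expected positive ratio per batch: ~{expected_pos:.1%}")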
    
    # Create train loader with sampler (shuffle=False when using sampler)
    train_loader = DataLoader(train_dataset, batch_size=HYPERPARAMS['batch_size_train'], sampler=train_sampler)
else:
    # Standard train loader with shuffle
    train_loader = DataLoader(train_dataset, batch_size=HYPERPARAMS['batch_size_train'], shuffle=True)
    print("✓ Using standard DataLoader (shuffle=True)")

# Validation and test loaders (no sampling)
val_loader = DataLoader(val_dataset, batch_size=HYPERPARAMS['batch_size_eval'], shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=HYPERPARAMS['batch_size_eval'], shuffle=False)

print("✓ All DataLoaders created")

✓ Using standard DataLoader (shuffle=True)
✓ All DataLoaders created
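
If `use_weighted_sampler` is switched on, each sample is drawn with probability proportional to the inverse of its class count. A small sketch of the resulting class balance, assuming the training counts reported earlier (322 negative, 66 positive); `expected_positive_share` is a hypothetical helper, not notebook code.

```python
def expected_positive_share(n_negative, n_positive):
    """Probability that one draw is positive under inverse-frequency sample weights."""
    weight_neg = 1.0 / n_negative
    weight_pos = 1.0 / n_positive
    mass_pos = n_positive * weight_pos   # total sampling mass of the positive class
    mass_neg = n_negative * weight_neg   # total sampling mass of the negative class
    return mass_pos / (mass_pos + mass_neg)

share = expected_positive_share(n_negative=322, n_positive=66)
print(f"Expected positive share per batch: {share:.1%}")
```

Each class ends up with the same total sampling mass, so batches are roughly balanced regardless of the original ratio.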


## Section 3: IntCEM Model Definition

In [12]:
# LDAM Loss (for class imbalance)
class LDAMLoss(nn.Module):
    """
    Label-Distribution-Aware Margin (LDAM) Loss for long-tailed recognition.
    
    Applies class-dependent margins so that the minority class must be
    classified with a larger margin than the majority class.
    """
    def __init__(self, n_positive, n_negative, max_margin=0.5, scale=30):
        super(LDAMLoss, self).__init__()
        self.max_margin = max_margin
        self.scale = scale
        
        # Compute class frequencies
        total = n_positive + n_negative
        freq_pos = n_positive / total
        freq_neg = n_negative / total
        
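        # Compute margins: minority class gets larger margin.
        # Note: frequencies are < 1, so both margins end up >= max_margin;
        # max_margin acts as a base scale here, not an upper bound.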
        margin_pos = max_margin * (freq_pos ** (-0.25))
        margin_neg = max_margin * (freq_neg ** (-0.25))
        
        self.register_buffer('margin_pos', torch.tensor(margin_pos))
        self.register_buffer('margin_neg', torch.tensor(margin_neg))
    
    def forward(self, logits, targets):
        logits = logits.view(-1)
        targets = targets.view(-1).float()
        
        # Apply class-dependent margins
        margin = targets * self.margin_pos + (1 - targets) * (-self.margin_neg)
        adjusted_logits = (logits - margin) * self.scale
        
        return F.binary_cross_entropy_with_logits(adjusted_logits, targets, reduction='mean')


# Intervention-Aware CEM Implementation
class IntCEM(pl.LightningModule):
    """
    Intervention-Aware Concept Embedding Model (IntCEM).
    
    Extends CEM with a learned intervention policy that decides which concepts
    to query for ground truth during inference.
    
    Architecture:
      X → concept_extractor → context_layers → prob_generator → dual_embeddings → task_classifier → y
      NEW: concept_rank_model (policy network) → intervention decisions
    """
    def __init__(
        self,
        n_concepts=21,
        emb_size=128,
        input_dim=384,
        shared_prob_gen=True,
        intervention_prob=0.25,
        concept_loss_weight=1.0,
        learning_rate=0.01,
        weight_decay=4e-05,
        use_ldam_loss=True,
        n_positive=83,
        n_negative=403,
        ldam_max_margin=0.5,
        ldam_scale=30,
        # IntCEM-specific parameters
        intervention_task_discount=1.1,
        intervention_weight=3.0,
        max_horizon=6,
        initial_horizon=2,
        horizon_rate=1.005,
        num_rollouts=1,
        intervention_policy_hidden=(256, 128),  # tuple avoids a mutable default argument
    ):
        super().__init__()
        self.save_hyperparameters()
        
        self.n_concepts = n_concepts
        self.emb_size = emb_size
        self.intervention_prob = intervention_prob
        self.concept_loss_weight = concept_loss_weight
        
        # Stage 1: Concept Extractor (X → Pre-Concept Features)
        self.concept_extractor = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 256)  # Pre-concept features
        )
        
        # Stage 2: Context Generators (Features → Dual Embeddings)
        # Each concept gets its own context generator
        self.context_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(256, emb_size * 2),  # Dual embeddings (true/false)
                nn.LeakyReLU()
            ) for _ in range(n_concepts)
        ])
        
        # Stage 3: Probability Generator (Contexts → Concept Probabilities)
        if shared_prob_gen:
            # Single shared generator for all concepts
            self.prob_generator = nn.Linear(emb_size * 2, 1)
        else:
            # Per-concept probability generators
            self.prob_generators = nn.ModuleList([
                nn.Linear(emb_size * 2, 1) for _ in range(n_concepts)
            ])
        
        self.shared_prob_gen = shared_prob_gen
        
        # Stage 4: Task Classifier (Concept Embeddings → Task Output)
        self.task_classifier = nn.Sequential(
            nn.Linear(n_concepts * emb_size, 128),
            nn.LeakyReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 1)  # Binary classification
        )
        
        # Stage 5 (NEW): Intervention Policy Network
        # Input: [c_probs, prev_interventions] = (21 + 21 = 42 dims)
        # Output: scores for each concept (21 dims)
        layers = []
        prev_dim = 2 * n_concepts  # 42
        
        for hidden_dim in intervention_policy_hidden:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.LeakyReLU(),
                nn.Dropout(0.2)
            ])
            prev_dim = hidden_dim
        
        layers.append(nn.Linear(prev_dim, n_concepts))
        self.concept_rank_model = nn.Sequential(*layers)
        
        # Horizon curriculum state
        self.current_horizon = initial_horizon
        
        # Cached contexts (for _forward_task)
        self.cached_contexts = None
        
        # Loss functions
        self.concept_loss_fn = nn.BCEWithLogitsLoss()
        if use_ldam_loss:
            self.task_loss_fn = LDAMLoss(n_positive, n_negative, ldam_max_margin, ldam_scale)
        else:
            self.task_loss_fn = nn.BCEWithLogitsLoss()
    
    def forward(self, x, c_true=None, train=False):
        # Step 1: Extract pre-concept features
        pre_features = self.concept_extractor(x)  # (B, 256)
        
        # Step 2: Generate contexts and probabilities per concept
        contexts = []
        c_logits_list = []
        
        for i, context_layer in enumerate(self.context_layers):
            context = context_layer(pre_features)  # (B, emb_size*2)
            
            # Get probability logit
            if self.shared_prob_gen:
                logit = self.prob_generator(context)  # (B, 1)
            else:
                logit = self.prob_generators[i](context)
            
            contexts.append(context)
            c_logits_list.append(logit)
        
        c_logits = torch.cat(c_logits_list, dim=1)  # (B, 21)
        c_probs = torch.sigmoid(c_logits)           # (B, 21)
        
        # Cache contexts for intervention rollouts
        self.cached_contexts = contexts
        
        # Step 3: Apply intervention (optional during training)
        if train and self.intervention_prob > 0 and c_true is not None:
            intervention_mask = torch.bernoulli(
                torch.ones_like(c_probs) * self.intervention_prob
            )
            c_probs = c_probs * (1 - intervention_mask) + c_true * intervention_mask
        
        # Step 4: Mix dual embeddings based on probabilities
        concept_embeddings = []
        for i, context in enumerate(contexts):
            # Split into true/false embeddings
            emb_true = context[:, :self.emb_size]       # First half
            emb_false = context[:, self.emb_size:]      # Second half
            
            # Weight by probability
            prob = c_probs[:, i:i+1]  # (B, 1)
            mixed_emb = emb_true * prob + emb_false * (1 - prob)
            concept_embeddings.append(mixed_emb)
        
        # Concatenate all concept embeddings
        c_embeddings = torch.cat(concept_embeddings, dim=1)  # (B, 21*emb_size)
        
        # Step 5: Task prediction
        y_logits = self.task_classifier(c_embeddings)  # (B, 1)
        
        return c_logits, y_logits
    
    def _forward_concepts(self, x):
        """Extract concept predictions without task prediction."""
        pre_features = self.concept_extractor(x)
        contexts = []
        c_logits_list = []
        
        for i, context_layer in enumerate(self.context_layers):
            context = context_layer(pre_features)
            logit = self.prob_generator(context) if self.shared_prob_gen else self.prob_generators[i](context)
            contexts.append(context)
            c_logits_list.append(logit)
        
        c_logits = torch.cat(c_logits_list, dim=1)
        self.cached_contexts = contexts  # Cache for _forward_task
        return c_logits, contexts
    
    def _forward_task(self, c_probs):
        """Compute task prediction from concept probabilities."""
        concept_embeddings = []
        for i, context in enumerate(self.cached_contexts):
            emb_true = context[:, :self.emb_size]
            emb_false = context[:, self.emb_size:]
            prob = c_probs[:, i:i+1]
            mixed_emb = emb_true * prob + emb_false * (1 - prob)
            concept_embeddings.append(mixed_emb)
        
        c_embeddings = torch.cat(concept_embeddings, dim=1)
        y_logits = self.task_classifier(c_embeddings)
        return y_logits
    
    def _prior_int_distribution(self, c_probs, intervention_mask=None):
        """Compute intervention policy scores."""
        if intervention_mask is None:
            intervention_mask = torch.zeros_like(c_probs)
        
        # Concatenate current predictions with intervention history
        policy_input = torch.cat([c_probs, intervention_mask], dim=1)  # (B, 42)
        
        # Get scores from policy network
        scores = self.concept_rank_model(policy_input)  # (B, 21)
        
        # Mask out already-intervened concepts
        scores = scores.masked_fill(intervention_mask.bool(), float('-inf'))
        
        return scores
    
    def _intervention_rollout_loss(self, x, c_true, y_true, train=True):
        """Perform intervention rollouts with SIMPLIFIED target selection."""
        batch_size = x.shape[0]
        device = x.device
        
        # Get initial concept predictions
        c_logits, _ = self._forward_concepts(x)
        c_probs = torch.sigmoid(c_logits)
        c_probs = torch.clamp(c_probs, 1e-6, 1 - 1e-6)  # Prevent NaN
        
        # Accumulate losses
        total_intervention_loss = 0.0
        total_task_loss = 0.0
        
        # Track interventions
        intervention_mask = torch.zeros_like(c_probs)
        
        for step in range(int(self.current_horizon)):
            # Get policy scores
            policy_scores = self._prior_int_distribution(c_probs, intervention_mask)
            
            # Sample using Gumbel-softmax (differentiable)
            if train:
                temperature = max(0.1, 0.5 * (0.95 ** self.current_epoch))
                intervention_sample = F.gumbel_softmax(policy_scores, tau=temperature, hard=True, dim=1)
            else:
                intervention_sample = F.one_hot(policy_scores.argmax(dim=1), num_classes=self.n_concepts).float()
            
            # Update mask
            intervention_mask = torch.clamp(intervention_mask + intervention_sample, 0, 1)
            
            # Apply intervention
            c_probs_intervened = c_probs * (1 - intervention_mask) + c_true * intervention_mask
            
            # Task prediction with intervened concepts
            y_logits = self._forward_task(c_probs_intervened)
            
            # Discounted task loss
            task_loss_step = self.task_loss_fn(y_logits.squeeze(), y_true.squeeze())
            total_task_loss += task_loss_step * (self.hparams.intervention_task_discount ** step)
            
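            # SIMPLIFIED policy loss: target the most uncertain concept that has
            # not been intervened on yet. policy_scores holds -inf for concepts
            # chosen in earlier steps; targeting one of those would make the
            # cross-entropy infinite, so they are excluded from the target.
            with torch.no_grad():
                already_chosen = torch.isinf(policy_scores)
                uncertainty = torch.abs(c_probs - 0.5)  # Distance from 0.5
                uncertainty = uncertainty.masked_fill(already_chosen, float('inf'))
                optimal_concept = uncertainty.argmin(dim=1)  # Most uncertain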
            
            policy_loss_step = F.cross_entropy(policy_scores, optimal_concept)
            total_intervention_loss += policy_loss_step
        
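        # A single rollout is run per call, so dividing by num_rollouts only
        # rescales the summed per-step losses (a no-op for num_rollouts=1).
        intervention_loss = total_intervention_loss / self.hparams.num_rollouts
        task_loss_rollout = total_task_loss / self.hparams.num_rollouts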
        avg_interventions = intervention_mask.sum().item() / batch_size
        
        return intervention_loss, task_loss_rollout, avg_interventions
    
    def training_step(self, batch, batch_idx):
        x, y, c_true = batch
        
        # Standard forward pass (with random interventions)
        c_logits, y_logits = self.forward(x, c_true=c_true, train=True)
        
        # Concept loss
        concept_loss = self.concept_loss_fn(c_logits, c_true)
        
        # Standard task loss
        task_loss_base = self.task_loss_fn(y_logits.squeeze(), y.squeeze())
        
        # Intervention rollout losses
        intervention_loss, task_loss_rollout, avg_interventions = self._intervention_rollout_loss(
            x, c_true, y, train=True
        )
        
        # Combined loss
        loss = (
            self.hparams.concept_loss_weight * concept_loss +
            task_loss_base +
            self.hparams.intervention_weight * intervention_loss +
            task_loss_rollout
        )
        
        # Logging
        self.log('train_loss', loss, on_epoch=True, prog_bar=True)
        self.log('train_concept_loss', concept_loss, on_epoch=True)
        self.log('train_task_loss_base', task_loss_base, on_epoch=True)
        self.log('train_intervention_loss', intervention_loss, on_epoch=True)
        self.log('train_task_loss_rollout', task_loss_rollout, on_epoch=True)
        self.log('train_avg_interventions', avg_interventions, on_epoch=True)
        self.log('current_horizon', self.current_horizon, on_epoch=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y, c_true = batch
        c_logits, y_logits = self.forward(x, c_true=c_true, train=False)
        
        # Task loss
        task_loss = self.task_loss_fn(y_logits.squeeze(), y.squeeze())
        
        # Concept loss
        concept_loss = self.concept_loss_fn(c_logits, c_true)
        
        # Combined loss
        loss = task_loss + self.hparams.concept_loss_weight * concept_loss
        
        # Logging
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_task_loss', task_loss, on_epoch=True)
        self.log('val_concept_loss', concept_loss, on_epoch=True)
        
        return loss
    
    def on_train_epoch_end(self):
        """Update horizon curriculum at end of each epoch."""
        self.current_horizon = min(
            self.hparams.max_horizon,
            self.current_horizon * self.hparams.horizon_rate
        )
    
    def configure_optimizers(self):
        return torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay
        )

print("✓ IntCEM model defined")

✓ IntCEM model defined
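
The rollout supervises the policy with a simplified target: the concept whose predicted probability is closest to 0.5. Since `_prior_int_distribution` masks already-intervened concepts to `-inf`, a sensible target is the most uncertain concept among those not yet queried. The plain-Python sketch below illustrates that selection rule on made-up probabilities; `most_uncertain_unqueried` is a hypothetical helper and not the notebook's tensor code.

```python
def most_uncertain_unqueried(concept_probs, already_queried):
    """Index of the concept closest to 0.5 among those not yet intervened on.

    Returns None when every concept has already been queried.
    """
    candidates = [
        (abs(p - 0.5), i)
        for i, (p, done) in enumerate(zip(concept_probs, already_queried))
        if not done
    ]
    if not candidates:
        return None
    return min(candidates)[1]

# Made-up concept probabilities for one sample with four concepts.
probs = [0.92, 0.48, 0.10, 0.55]
queried = [False, False, False, False]
order = []
for _ in range(3):  # three rollout steps
    choice = most_uncertain_unqueried(probs, queried)
    order.append(choice)
    queried[choice] = True
print(order)
```

Each step picks a different concept, which mirrors how the intervention mask grows by one concept per rollout step.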


## Section 4: Model Initialization

In [13]:
# Initialize IntCEM model
intcem = IntCEM(
    n_concepts=HYPERPARAMS['n_concepts'],
    emb_size=HYPERPARAMS['emb_size'],
    input_dim=HYPERPARAMS['embedding_dim'],
    shared_prob_gen=HYPERPARAMS['shared_prob_gen'],
    intervention_prob=HYPERPARAMS['intervention_prob'],
    concept_loss_weight=HYPERPARAMS['concept_loss_weight'],
    learning_rate=HYPERPARAMS['learning_rate'],
    weight_decay=HYPERPARAMS['weight_decay'],
    use_ldam_loss=HYPERPARAMS['use_ldam_loss'],
    n_positive=HYPERPARAMS['n_positive'],
    n_negative=HYPERPARAMS['n_negative'],
    ldam_max_margin=HYPERPARAMS['ldam_max_margin'],
    ldam_scale=HYPERPARAMS['ldam_scale'],
    # IntCEM-specific parameters
    intervention_task_discount=HYPERPARAMS['intervention_task_discount'],
    intervention_weight=HYPERPARAMS['intervention_weight'],
    max_horizon=HYPERPARAMS['max_horizon'],
    initial_horizon=HYPERPARAMS['initial_horizon'],
    horizon_rate=HYPERPARAMS['horizon_rate'],
    num_rollouts=HYPERPARAMS['num_rollouts'],
    intervention_policy_hidden=HYPERPARAMS['intervention_policy_hidden']
)

print("✓ IntCEM model initialized")
print(f"  Model type: Intervention-Aware CEM")
print(f"  Using LDAM Loss (margin={HYPERPARAMS['ldam_max_margin']}, scale={HYPERPARAMS['ldam_scale']})")
print(f"  Concept embedding size: {HYPERPARAMS['emb_size']}")
print(f"  Intervention policy hidden: {HYPERPARAMS['intervention_policy_hidden']}")
print(f"  Horizon curriculum: {HYPERPARAMS['initial_horizon']} → {HYPERPARAMS['max_horizon']} (rate={HYPERPARAMS['horizon_rate']})")
print(f"  Intervention weight: {HYPERPARAMS['intervention_weight']}")
print(f"  Class counts: {HYPERPARAMS['n_positive']} positive, {HYPERPARAMS['n_negative']} negative")
print(f"  Data source: MAX alternative attention pipeline (specialist posts)")

✓ IntCEM model initialized
  Model type: Intervention-Aware CEM
  Using LDAM Loss (margin=0.5, scale=30)
  Concept embedding size: 128
  Intervention policy hidden: [256, 128]
  Horizon curriculum: 2 → 6 (rate=1.005)
  Intervention weight: 3.0
  Class counts: 66 positive, 322 negative
  Data source: MAX alternative attention pipeline (specialist posts)
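
To see what the curriculum does in practice, the two per-epoch schedules in the model can be replayed outside of training: the rollout length (`int(current_horizon)`, multiplied by `horizon_rate` after each epoch) and the Gumbel-softmax temperature (`max(0.1, 0.5 * 0.95 ** epoch)`). This is a sketch that assumes the update rules exactly as written in the `IntCEM` class and a 0-based epoch counter; `training_schedules` is a hypothetical helper.

```python
def training_schedules(initial_horizon=2, max_horizon=6, horizon_rate=1.005, max_epochs=100):
    """Replay the per-epoch rollout length and Gumbel-softmax temperature."""
    horizon = float(initial_horizon)
    rollout_steps, temperatures = [], []
    for epoch in range(max_epochs):
        rollout_steps.append(int(horizon))                    # as in range(int(current_horizon))
        temperatures.append(max(0.1, 0.5 * (0.95 ** epoch)))  # as in _intervention_rollout_loss
        horizon = min(max_horizon, horizon * horizon_rate)    # as in on_train_epoch_end
    return rollout_steps, temperatures

rollout_steps, temperatures = training_schedules()
first_epoch_with_3 = rollout_steps.index(3)
first_epoch_at_floor = next(e for e, t in enumerate(temperatures) if t == 0.1)

print(f"Rollout length: starts at {rollout_steps[0]}, ends at {rollout_steps[-1]}")
print(f"First epoch with 3 rollout steps: {first_epoch_with_3}")
print(f"First epoch at temperature floor: {first_epoch_at_floor}")
```

Under these settings the replay suggests the rollout length stays at 2 for most of training and only reaches 3 late in the run, so `max_horizon=6` is not reached within 100 epochs. A larger `horizon_rate` or more epochs would be needed if the full horizon is intended.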


## Section 5: Training

In [14]:
# Setup trainer
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath=os.path.join(OUTPUT_DIR, "models"),
    filename="intcem-max-{epoch:02d}-{val_loss:.2f}",
    save_top_k=1,
    mode="min"
)

trainer = pl.Trainer(
    max_epochs=HYPERPARAMS['max_epochs'],
    accelerator=DEVICE,
    devices=1,
    logger=CSVLogger(save_dir=os.path.join(OUTPUT_DIR, "logs"), name="intcem_max"),
    log_every_n_steps=10,
    callbacks=[checkpoint_callback],
    enable_progress_bar=True
)

print("✓ Trainer configured")

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


✓ Trainer configured


In [15]:
# Train model
print("\nStarting IntCEM training...\n")
trainer.fit(intcem, train_loader, val_loader)
print("\n✓ Training complete!")

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

  | Name               | Type              | Params
---------------------------------------------------------
0 | concept_extractor  | Sequential        | 164 K 
1 | context_layers     | ModuleList        | 1.4 M 
2 | prob_generator     | Linear            | 257   
3 | task_classifier    | Sequential        | 344 K 
4 | concept_rank_model | Sequential        | 47.4 K
5 | concept_loss_fn    | BCEWithLogitsLoss | 0     
6 | task_loss_fn       | LDAMLoss          | 0     
---------------------------------------------------------
1.9 M     Trainable params
0         Non-trainable params
1.9 M     Total params
7.752     Total estimated model params size (MB)



Starting IntCEM training...



Sanity Checking: 0it [00:00, ?it/s]


Training: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.



✓ Training complete!
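
Since the trainer logs through `CSVLogger` and checkpoints on `val_loss`, the logged metrics can be scanned to find which epoch had the lowest validation loss. The sketch below is an assumption-laden illustration: it presumes the log is a CSV with `epoch` and `val_loss` columns and empty cells where a metric was not logged in that row, and it runs on a tiny made-up sample rather than the real file. `best_row` is a hypothetical helper.

```python
import csv
import io

def best_row(csv_file, metric="val_loss"):
    """Return (epoch, value) for the row with the lowest value of `metric`."""
    best = None
    for row in csv.DictReader(csv_file):
        raw = row.get(metric, "")
        if raw in ("", None):  # metric not logged in this row
            continue
        value = float(raw)
        if best is None or value < best[1]:
            best = (int(row["epoch"]), value)
    return best

# Tiny made-up sample standing in for the logged metrics; real values will differ.
sample = io.StringIO(
    "epoch,step,train_loss_epoch,val_loss\n"
    "0,12,,1.25\n"
    "0,12,2.10,\n"
    "1,25,,0.98\n"
    "1,25,1.70,\n"
    "2,38,,1.05\n"
)
result = best_row(sample)
print(result)
```

To use it on a real run, open the metrics file written under the logger's `save_dir` and pass the file object instead of `sample`; the exact path and column names should be checked against the actual log.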


## Section 6: Test Evaluation

In [16]:
# Run inference on test set
print("Running inference on test set...")

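# Note: `intcem` holds the weights from the final training epoch. The best
# checkpoint by val_loss is saved at checkpoint_callback.best_model_path but is
# not loaded here.
intcem.eval()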
device_obj = torch.device(DEVICE)
intcem = intcem.to(device_obj)

y_true_list = []
y_prob_list = []
concept_probs_list = []

with torch.no_grad():
    for x_batch, y_batch, c_batch in test_loader:
        x_batch = x_batch.to(device_obj)
        
        c_logits, y_logits = intcem(x_batch)
        c_probs = torch.sigmoid(c_logits).cpu().numpy()
        # reshape(-1) keeps a 1-D array even if the last batch holds a single sample
        y_probs = torch.sigmoid(y_logits).cpu().reshape(-1).numpy()
        
        y_true_list.extend(y_batch.numpy().astype(int).tolist())
        y_prob_list.extend(y_probs.tolist())
        concept_probs_list.extend(c_probs.tolist())

y_true = np.array(y_true_list)
y_prob = np.array(y_prob_list)
concept_probs = np.array(concept_probs_list)

print("✓ Inference complete")

Running inference on test set...
✓ Inference complete


In [17]:
# ------------------------------
# Run inference on validation set to find best threshold
# ------------------------------

val_true_list = []
val_prob_list = []

intcem.eval()
with torch.no_grad():
    for x_batch, y_batch, _ in val_loader:
        x_batch = x_batch.to(device_obj)
        _, y_logits = intcem(x_batch)
        # reshape(-1) keeps a 1-D array even if the last batch holds a single sample
        y_probs = torch.sigmoid(y_logits).cpu().reshape(-1).numpy()
        
        val_true_list.extend(y_batch.numpy().astype(int).tolist())
        val_prob_list.extend(y_probs.tolist())

val_true = np.array(val_true_list)
val_prob = np.array(val_prob_list)

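# Sweep thresholds on the validation set only, so the test set stays untouched.
# If the best value lands on the edge of this grid, consider extending the grid.
thresholds_to_try = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
best_f1 = 0.0
best_threshold = 0.5  # fallback if no threshold yields a positive prediction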

for threshold in thresholds_to_try:
    y_pred_temp = (val_prob >= threshold).astype(int)
    if np.sum(y_pred_temp) == 0:
        continue
    f1 = f1_score(val_true, y_pred_temp)
    if f1 > best_f1:
        best_f1 = f1
        best_threshold = threshold

print(f"\n✓ Best threshold (from validation): {best_threshold:.2f} (F1={best_f1:.4f})")



✓ Best threshold (from validation): 0.10 (F1=0.5854)
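
The selected threshold of 0.10 is the smallest value in the grid, so the F1-optimal threshold may lie below it. The following is a minimal sketch of a sweep that also reports whether the best value sits on the grid boundary. `f1_at_threshold` and `sweep_thresholds` are hypothetical helpers, not part of this notebook, and the toy arrays are illustrative only.

```python
import numpy as np

def f1_at_threshold(y_true, y_prob, threshold):
    """Binary F1 for the positive class at a given probability cutoff."""
    y_pred = y_prob >= threshold
    tp = np.sum(y_pred & (y_true == 1))
    fp = np.sum(y_pred & (y_true == 0))
    fn = np.sum(~y_pred & (y_true == 1))
    denom = 2 * tp + fp + fn
    return float(2 * tp / denom) if denom > 0 else 0.0

def sweep_thresholds(y_true, y_prob, thresholds):
    """Return (best_threshold, best_f1, on_edge) over a list of candidate cutoffs."""
    y_true = np.asarray(y_true).astype(int)
    y_prob = np.asarray(y_prob, dtype=float)
    scores = [f1_at_threshold(y_true, y_prob, t) for t in thresholds]
    best = int(np.argmax(scores))  # ties resolve to the first (lowest) threshold
    on_edge = best in (0, len(thresholds) - 1)
    return float(thresholds[best]), scores[best], on_edge

# Toy validation data (illustrative only)
toy_true = [0, 0, 1, 1]
toy_prob = [0.1, 0.4, 0.35, 0.8]

best_t, best_score, on_edge = sweep_thresholds(toy_true, toy_prob, [0.2, 0.3, 0.5])
print(best_t, round(best_score, 4), on_edge)  # 0.2 0.8 True
```

When `on_edge` is true, it may be worth rerunning the sweep with a grid that extends past that boundary, still using validation data only.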


In [18]:
# Apply the threshold found on validation set to test set
y_pred = (y_prob >= best_threshold).astype(int)


## Section 7: Results Display

In [19]:
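# Compute all metrics
# labels=[0, 1] guarantees a 2x2 matrix even if one class is absent from y_true/y_pred
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
tn, fp, fn, tp = cm.ravel()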

acc = accuracy_score(y_true, y_pred)
balanced_acc = balanced_accuracy_score(y_true, y_pred)
roc_auc = roc_auc_score(y_true, y_prob)
mcc = matthews_corrcoef(y_true, y_pred)
f1_binary = f1_score(y_true, y_pred, pos_label=1)
f1_macro = f1_score(y_true, y_pred, average='macro')
precision_binary = precision_score(y_true, y_pred, pos_label=1)
recall_binary = recall_score(y_true, y_pred, pos_label=1)

# Print results
print("\n" + "="*70)
print("                    TEST SET EVALUATION (IntCEM)")
print("="*70)
print(f"\nDecision Threshold: {best_threshold:.2f}")

# Enhanced Confusion Matrix Display
print(f"\n{'CONFUSION MATRIX':^50}")
print("="*50)
print(f"{'':>20} │ {'Predicted Negative':^12} │ {'Predicted Positive':^12}")
print("─"*50)
print(f"{'Actual Negative':>20} │ {f'TN = {tn}':^12} │ {f'FP = {fp}':^12}")
print(f"{'Actual Positive':>20} │ {f'FN = {fn}':^12} │ {f'TP = {tp}':^12}")
print("="*50)
print(f"\n  True Positives:  {tp:>3}/{int(np.sum(y_true)):<3} ({100*tp/np.sum(y_true):>5.1f}% of depression cases caught)")
print(f"  False Negatives: {fn:>3}/{int(np.sum(y_true)):<3} ({100*fn/np.sum(y_true):>5.1f}% of depression cases MISSED)")
print(f"  True Negatives:  {tn:>3}/{int(len(y_true)-np.sum(y_true)):<3} ({100*tn/(len(y_true)-np.sum(y_true)):>5.1f}% of healthy correctly identified)")
print(f"  False Positives: {fp:>3}/{int(len(y_true)-np.sum(y_true)):<3} ({100*fp/(len(y_true)-np.sum(y_true)):>5.1f}% false alarms)")

print(f"\nPerformance Metrics:")
print(f"  Accuracy:                  {acc:.4f}")
print(f"  Balanced Accuracy:         {balanced_acc:.4f}")
print(f"  ROC-AUC:                   {roc_auc:.4f}")
print(f"  Matthews Correlation:      {mcc:.4f}")
print(f"\n  F1 Score (Binary):         {f1_binary:.4f}")
print(f"  F1 Score (Macro):          {f1_macro:.4f}")
print(f"  Precision (Binary):        {precision_binary:.4f}")
print(f"  Recall (Binary):           {recall_binary:.4f}")

print("\n" + classification_report(y_true, y_pred, target_names=['Negative', 'Positive']))
print("="*70)


                    TEST SET EVALUATION (IntCEM)

Decision Threshold: 0.10

                 CONFUSION MATRIX                 
                     │ Predicted Negative │ Predicted Positive
──────────────────────────────────────────────────
     Actual Negative │   TN = 302   │   FP = 47   
     Actual Positive │   FN = 11    │   TP = 41   

  True Positives:   41/52  ( 78.8% of depression cases caught)
  False Negatives:  11/52  ( 21.2% of depression cases MISSED)
  True Negatives:  302/349 ( 86.5% of healthy correctly identified)
  False Positives:  47/349 ( 13.5% false alarms)

Performance Metrics:
  Accuracy:                  0.8554
  Balanced Accuracy:         0.8269
  ROC-AUC:                   0.9101
  Matthews Correlation:      0.5307

  F1 Score (Binary):         0.5857
  F1 Score (Macro):          0.7491
  Precision (Binary):        0.4659
  Recall (Binary):           0.7885

              precision    recall  f1-score   support

    Negative       0.96      0.87      0.91  
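
The reported metrics can be cross-checked directly from the four confusion-matrix counts above. The sketch below uses only the standard library and the counts printed in the output (TN = 302, FP = 47, FN = 11, TP = 41). The variable names are illustrative, and the formulas are the standard definitions rather than anything specific to this notebook.

```python
import math

# Counts copied from the confusion matrix printed above
tn, fp, fn, tp = 302, 47, 11, 41

precision = tp / (tp + fp)                      # 41 / 88
recall = tp / (tp + fn)                         # 41 / 52
specificity = tn / (tn + fp)                    # 302 / 349
accuracy = (tp + tn) / (tp + tn + fp + fn)      # 343 / 401
balanced_accuracy = (recall + specificity) / 2
f1_positive = 2 * tp / (2 * tp + fp + fn)
f1_negative = 2 * tn / (2 * tn + fn + fp)
f1_macro = (f1_positive + f1_negative) / 2

# Matthews correlation coefficient for the binary case
mcc = (tp * tn - fp * fn) / math.sqrt(
    (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
)

for name, value in [
    ("accuracy", accuracy),
    ("balanced_accuracy", balanced_accuracy),
    ("mcc", mcc),
    ("f1_positive", f1_positive),
    ("f1_macro", f1_macro),
    ("precision", precision),
    ("recall", recall),
]:
    print(f"{name:>18}: {value:.4f}")
```

The precision of about 0.47 against a recall of about 0.79 reflects the low decision threshold: most positive cases are caught, at the cost of 47 false alarms among 349 negatives.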

In [20]:
# Save results
metrics_dict = {
    "model_type": "intcem_max",
    "data_source": "max_alternative_attention_pipeline",
    "threshold": float(best_threshold),
    "n_samples": int(len(y_true)),
    "n_positive": int(np.sum(y_true)),
    "n_negative": int(len(y_true) - np.sum(y_true)),
    "accuracy": float(acc),
    "balanced_accuracy": float(balanced_acc),
    "roc_auc": float(roc_auc),
    "mcc": float(mcc),
    "f1_binary": float(f1_binary),
    "f1_macro": float(f1_macro),
    "precision_binary": float(precision_binary),
    "recall_binary": float(recall_binary),
    "confusion_matrix": {"tn": int(tn), "fp": int(fp), "fn": int(fn), "tp": int(tp)}
}

os.makedirs(os.path.join(OUTPUT_DIR, "results"), exist_ok=True)
with open(os.path.join(OUTPUT_DIR, "results/test_metrics.json"), 'w') as f:
    json.dump(metrics_dict, f, indent=4)

# Save predictions
predictions_df = pd.DataFrame({
    'subject_id': test_subject_ids,
    'y_true': y_true,
    'y_pred': y_pred,
    'y_prob': y_prob
})

for i, concept_name in enumerate(CONCEPT_NAMES):
    predictions_df[concept_name] = concept_probs[:, i]

predictions_df.to_csv(os.path.join(OUTPUT_DIR, "results/test_predictions.csv"), index=False)

print(f"✓ Results saved to {OUTPUT_DIR}/results/")

✓ Results saved to outputs_intcem_max/results/


In [21]:
print("\n" + "="*70)
print("              IntCEM TRAINING COMPLETE")
print("="*70)
print(f"\nGenerated files:")
print(f"  Model checkpoint: {OUTPUT_DIR}/models/")
print(f"  Metrics JSON:     {OUTPUT_DIR}/results/test_metrics.json")
print(f"  Predictions CSV:  {OUTPUT_DIR}/results/test_predictions.csv")
print(f"\nModel type: Intervention-Aware CEM (IntCEM)")
print(f"Data source: max_alternative_attention_pipeline")
print(f"  - Uses MAX-based concept similarity (specialist posts)")
print(f"  - Learns intervention policy to select optimal concepts")
print(f"  - Horizon curriculum: {HYPERPARAMS['initial_horizon']} → {HYPERPARAMS['max_horizon']}")
print("="*70)


              IntCEM TRAINING COMPLETE

Generated files:
  Model checkpoint: outputs_intcem_max/models/
  Metrics JSON:     outputs_intcem_max/results/test_metrics.json
  Predictions CSV:  outputs_intcem_max/results/test_predictions.csv

Model type: Intervention-Aware CEM (IntCEM)
Data source: max_alternative_attention_pipeline
  - Uses MAX-based concept similarity (specialist posts)
  - Learns intervention policy to select optimal concepts
  - Horizon curriculum: 2 → 6
