# Custom CEM Model - Alternative Attention Pipeline (SUM-based)

**Runtime:** ~15-20 minutes

This notebook:
1. Implements CEM from scratch using PyTorch
2. Uses LDAM Loss + WeightedRandomSampler for class imbalance
3. **Uses alternative dataset with SUM-based concept scoring**
4. **Validation set now has TRUE concept labels** (held out as a 20% split of the training data)

**Key Difference from Original:**
- Uses data from `alternative_attention_pipeline` (SUM of concept similarities)
- Validation set has ground-truth concept labels (not zeros)
- Concept prediction can therefore be evaluated properly during validation

**Prerequisites:** Run `0_prepare_alternative_attention_dataset.ipynb` first!

## Section 0: Setup & Configuration

In [1]:
# Imports
import os
import json
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

from sklearn.metrics import (
    confusion_matrix,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    matthews_corrcoef,
    roc_auc_score,
    balanced_accuracy_score,
    classification_report,
)

print("✓ All imports successful")

✓ All imports successful


In [2]:
# Set random seeds
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
pl.seed_everything(SEED)

print(f"✓ Random seed set to {SEED}")

Global seed set to 42


✓ Random seed set to 42


In [3]:
# Detect device
if torch.backends.mps.is_available():
    DEVICE = "mps"
    print("✓ Using MacBook GPU (MPS)")
elif torch.cuda.is_available():
    DEVICE = "cuda"
    print("✓ Using CUDA GPU")
else:
    DEVICE = "cpu"
    print("⚠ Using CPU")

✓ Using MacBook GPU (MPS)


In [4]:
# Define paths - CHANGED FOR ALTERNATIVE PIPELINE
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
DATA_PROCESSED = os.path.join(PROJECT_ROOT, "data/processed")
DATASET_DIR = os.path.join(DATA_PROCESSED, "alternative_attention_pipeline")  # CHANGED
OUTPUT_DIR = "outputs_alternate_cem"  # CHANGED

print("✓ Paths configured")
print(f"  Dataset dir: {DATASET_DIR}")
print(f"  Output dir: {OUTPUT_DIR}")
print("\n  NOTE: Using ALTERNATIVE dataset (SUM-based concept scoring)")
print("        Validation set has TRUE concept labels!")

✓ Paths configured
  Dataset dir: /Users/gualtieromarencoturi/Desktop/thesis/Master-Thesis-CEM-Depression-etc-case-study/data/processed/alternative_attention_pipeline
  Output dir: outputs_alternate_cem

  NOTE: Using ALTERNATIVE dataset (SUM-based concept scoring)
        Validation set has TRUE concept labels!


In [5]:
# Define 21 BDI-II concept names
CONCEPT_NAMES = [
    "Sadness", "Pessimism", "Past failure", "Loss of pleasure",
    "Guilty feelings", "Punishment feelings", "Self-dislike", "Self-criticalness",
    "Suicidal thoughts or wishes", "Crying", "Agitation", "Loss of interest",
    "Indecisiveness", "Worthlessness", "Loss of energy", "Changes in sleeping pattern",
    "Irritability", "Changes in appetite", "Concentration difficulty",
    "Tiredness or fatigue", "Loss of interest in sex"
]
N_CONCEPTS = len(CONCEPT_NAMES)

print(f"✓ Defined {N_CONCEPTS} BDI-II concepts")

✓ Defined 21 BDI-II concepts


In [6]:
# Hyperparameters
HYPERPARAMS = {
    # Model architecture
    "embedding_dim": 384,
    "n_concepts": 21,
    "n_tasks": 1,
    "emb_size": 128,
    
    # CEM-specific
    "shared_prob_gen": True,        # Share probability generator across concepts
    "intervention_prob": 0.1,      # Training intervention probability
    "concept_temperature": 2.0,

    # Training
    "batch_size_train": 32,
    "batch_size_eval": 64,
    "max_epochs": 100,
    "learning_rate": 0.01,
    "weight_decay": 4e-05,
    
    # Loss
    "concept_loss_weight": 1.0,
    
    # LDAM Loss
    "use_ldam_loss": True,
    "n_positive": None,               # Will be set after loading data
    "n_negative": None,               # Will be set after loading data
    "ldam_max_margin": 0.7,           # Try: 0.3, 0.5, 0.7, 1.0
    "ldam_scale": 40,                 # Try: 20, 30, 40, 50
    
    # Weighted Sampler
    "use_weighted_sampler": True,
}

print("✓ Hyperparameters configured")
if HYPERPARAMS['use_ldam_loss']:
    print(f"  Using LDAM LOSS (margin={HYPERPARAMS['ldam_max_margin']}, scale={HYPERPARAMS['ldam_scale']})")
else:
    print(f"  Using standard BCE loss")

✓ Hyperparameters configured
  Using LDAM LOSS (margin=0.7, scale=40)


## Section 1: Load Preprocessed Data

In [7]:
# Load training data
print("Loading preprocessed datasets...")

train_data = np.load(os.path.join(DATASET_DIR, "train_data.npz"))
X_train = train_data['X']
C_train = train_data['C']
y_train = train_data['y']
train_subject_ids = train_data['subject_ids']
# ===============================
# STEP 4.1: concept loss weighting
# ===============================

# C_train: (N_train_samples, n_concepts), values in {0,1}
concept_pos_counts = C_train.sum(axis=0)
concept_neg_counts = C_train.shape[0] - concept_pos_counts

concept_pos_weight = torch.tensor(
    concept_neg_counts / (concept_pos_counts + 1e-6),
    dtype=torch.float32
)

print(f"✓ Loaded training data: {X_train.shape}")

Loading preprocessed datasets...
✓ Loaded training data: (388, 384)
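The `concept_pos_weight` vector holds, for each concept, the ratio of negative to positive training samples. This matches the convention PyTorch documents for the `pos_weight` argument of `BCEWithLogitsLoss`, where the positive term of each concept's loss is multiplied by its weight. A minimal NumPy sketch of that weighted loss on a toy label matrix follows. The helper names are hypothetical, and labels are assumed to be binary:

```python
import numpy as np

def pos_weight_from_labels(C, eps=1e-6):
    """Per-concept negatives/positives ratio, computed as in the cell above."""
    pos = C.sum(axis=0)
    neg = C.shape[0] - pos
    return neg / (pos + eps)

def weighted_bce_with_logits(logits, targets, pos_weight):
    """Mean of -[w * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))].

    Uses -log(sigmoid(x)) = softplus(-x) and -log(1 - sigmoid(x)) = softplus(x),
    with softplus(z) = log(1 + exp(z)) computed stably via np.logaddexp.
    """
    pos_term = pos_weight * targets * np.logaddexp(0.0, -logits)
    neg_term = (1 - targets) * np.logaddexp(0.0, logits)
    return np.mean(pos_term + neg_term)

# Toy labels: 4 samples x 2 concepts (concept 0: 1 positive, concept 1: 2 positives)
C = np.array([[1, 1], [0, 1], [0, 0], [0, 0]], dtype=float)
w = pos_weight_from_labels(C)
print(w)                                                  # approximately [3. 1.]
print(weighted_bce_with_logits(np.zeros_like(C), C, w))   # approximately 1.25 * log(2)
```

Weights this large amplify rare positives strongly. In this dataset the rarest concept reaches about 96, as the statistics printed next show, so clipping the weights may be worth considering. That choice is a design decision rather than part of the loss.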


In [8]:
# Inspect the range of the per-concept pos_weight values
print("Concept pos_weight stats:")
print(" min:", concept_pos_weight.min().item())
print(" mean:", concept_pos_weight.mean().item())
print(" max:", concept_pos_weight.max().item())

Concept pos_weight stats:
 min: 3.9743590354919434
 mean: 26.74332618713379
 max: 95.9999771118164


In [9]:
# Load validation data
val_data = np.load(os.path.join(DATASET_DIR, "val_data.npz"))
X_val = val_data['X']
C_val = val_data['C']
y_val = val_data['y']
val_subject_ids = val_data['subject_ids']

print(f"✓ Loaded validation data: {X_val.shape}")
print(f"  Validation concept matrix has {np.count_nonzero(C_val)} non-zero values")
print(f"  (Previously was all zeros, now has TRUE labels!)")

✓ Loaded validation data: (98, 384)
  Validation concept matrix has 126 non-zero values
  (Previously was all zeros, now has TRUE labels!)


In [10]:
# Load test data
test_data = np.load(os.path.join(DATASET_DIR, "test_data.npz"))
X_test = test_data['X']
C_test = test_data['C']
y_test = test_data['y']
test_subject_ids = test_data['subject_ids']

print(f"✓ Loaded test data: {X_test.shape}")

✓ Loaded test data: (401, 384)


In [11]:
# Load class weights
with open(os.path.join(DATASET_DIR, "class_weights.json"), 'r') as f:
    class_info = json.load(f)

n_positive = class_info['n_positive']
n_negative = class_info['n_negative']
pos_weight = class_info['pos_weight']

# Update HYPERPARAMS with actual class counts for LDAM
HYPERPARAMS['n_positive'] = n_positive
HYPERPARAMS['n_negative'] = n_negative

print(f"✓ Loaded class weights:")
print(f"  Negative: {n_negative}, Positive: {n_positive}")
print(f"  Ratio: 1:{pos_weight:.2f}")

✓ Loaded class weights:
  Negative: 322, Positive: 66
  Ratio: 1:4.88


## Section 2: PyTorch Dataset & DataLoaders

In [12]:
class CEMDataset(Dataset):
    def __init__(self, X, C, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.C = torch.tensor(C, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)
    
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx], self.C[idx]

# Create datasets
train_dataset = CEMDataset(X_train, C_train, y_train)
val_dataset = CEMDataset(X_val, C_val, y_val)
test_dataset = CEMDataset(X_test, C_test, y_test)

# Create WeightedRandomSampler for batch-level oversampling (if enabled)
if HYPERPARAMS['use_weighted_sampler']:
    # Compute class sample counts
    class_sample_counts = np.bincount(y_train.astype(int))  # [n_negative, n_positive]
    weights = 1. / class_sample_counts
    sample_weights = weights[y_train.astype(int)]
    
    # Create sampler
    train_sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True  # Allow positive samples to appear multiple times
    )
    
    print(f"✓ WeightedRandomSampler created:")
    print(f"  Negative weight: {weights[0]:.4f}")
    print(f"  Positive weight: {weights[1]:.4f}")
    # Total sampling mass per class = class count * per-sample weight
    class_mass = weights * class_sample_counts
    print(f"  Expected positive ratio per batch: ~{class_mass[1] / class_mass.sum():.1%}")
    
    # Create train loader with sampler (shuffle must not be set when passing a sampler)
    train_loader = DataLoader(train_dataset, batch_size=HYPERPARAMS['batch_size_train'], sampler=train_sampler)
else:
    # Standard train loader with shuffle
    train_loader = DataLoader(train_dataset, batch_size=HYPERPARAMS['batch_size_train'], shuffle=True)
    print("✓ Using standard DataLoader (shuffle=True)")

# Validation and test loaders (no sampling)
val_loader = DataLoader(val_dataset, batch_size=HYPERPARAMS['batch_size_eval'], shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=HYPERPARAMS['batch_size_eval'], shuffle=False)

print("✓ All DataLoaders created")

✓ WeightedRandomSampler created:
  Negative weight: 0.0031
  Positive weight: 0.0152
  Expected positive ratio per batch: ~50.0%
✓ All DataLoaders created
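With inverse-frequency weights, each class carries the same total sampling mass (count × per-sample weight = 1). A sampler that draws indices with probability proportional to weight therefore returns positives about half the time. A NumPy sketch follows, assuming the 322/66 split loaded above. `rng.choice` stands in for `WeightedRandomSampler`'s weighted draw with replacement:

```python
import numpy as np

y = np.array([0] * 322 + [1] * 66)            # class split from class_weights.json
counts = np.bincount(y)
sample_weights = (1.0 / counts)[y]            # per-sample inverse-frequency weight

# Analytic expectation: share of the total weight that belongs to positives
expected_pos = sample_weights[y == 1].sum() / sample_weights.sum()

# Monte Carlo check: many draws with replacement, probability proportional to weight
rng = np.random.default_rng(42)
draws = rng.choice(len(y), size=100_000, replace=True,
                   p=sample_weights / sample_weights.sum())
empirical_pos = y[draws].mean()

print(f"expected {expected_pos:.3f}, simulated {empirical_pos:.3f}")
```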


## Section 3: Custom CEM Model Definition

In [13]:
# LDAM Loss (for class imbalance)
class LDAMLoss(nn.Module):
    """
    Label-Distribution-Aware Margin (LDAM) loss, adapted to a single binary logit.
    
    Each class gets a margin proportional to n_class^(-1/4), so the minority
    (positive) class must clear a larger margin than the majority class.
    """
    def __init__(self, n_positive, n_negative, max_margin=0.5, scale=30):
        super(LDAMLoss, self).__init__()
        self.max_margin = max_margin
        self.scale = scale
        
        # Compute class frequencies
        total = n_positive + n_negative
        freq_pos = n_positive / total
        freq_neg = n_negative / total
        
        # Compute margins: minority class gets the larger margin (∝ freq^(-1/4)).
        # Unlike the original LDAM, the margins are not rescaled so that the
        # largest equals max_margin; since freq < 1, both margins exceed it.
        margin_pos = max_margin * (freq_pos ** (-0.25))
        margin_neg = max_margin * (freq_neg ** (-0.25))
        
        self.register_buffer('margin_pos', torch.tensor(margin_pos))
        self.register_buffer('margin_neg', torch.tensor(margin_neg))
    
    def forward(self, logits, targets):
        logits = logits.view(-1)
        targets = targets.view(-1).float()
        
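        # Apply class-dependent margins: positives are scored as logit - margin_pos and
        # negatives as logit + margin_neg, so each must clear its margin to be correct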
        margin = targets * self.margin_pos + (1 - targets) * (-self.margin_neg)
        adjusted_logits = (logits - margin) * self.scale
        
        return F.binary_cross_entropy_with_logits(adjusted_logits, targets, reduction='mean')


# Custom CEM Implementation
class CustomCEM(pl.LightningModule):
    """
    Custom Concept Embedding Model (CEM) implementation.
    
    Architecture:
      X → concept_extractor → context_layers → prob_generator → dual_embeddings → task_classifier → y
    """
    def __init__(
        self,
        n_concepts=21,
        emb_size=128,
        input_dim=384,
        shared_prob_gen=True,
        intervention_prob=0.25,
        concept_loss_weight=1.0,
        learning_rate=0.01,
        weight_decay=4e-05,
        use_ldam_loss=True,
        n_positive=83,
        n_negative=403,
        ldam_max_margin=0.5,
        ldam_scale=30,
        concept_temperature=2.0,
        concept_pos_weight=None,
    ):
        super().__init__()
        self.save_hyperparameters()
        
        self.n_concepts = n_concepts
        self.emb_size = emb_size
        self.intervention_prob = intervention_prob
        self.concept_loss_weight = concept_loss_weight
        self.concept_temperature = concept_temperature

        # Stage 1: Concept Extractor (X → Pre-Concept Features)
        self.concept_extractor = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 256)  # Pre-concept features
        )
        
        # Stage 2: Context Generators (Features → Dual Embeddings)
        # Each concept gets its own context generator
        self.context_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(256, emb_size * 2),  # Dual embeddings (true/false)
                nn.LeakyReLU()
            ) for _ in range(n_concepts)
        ])
        
        # Stage 3: Probability Generator (Contexts → Concept Probabilities)
        if shared_prob_gen:
            # Single shared generator for all concepts
            self.prob_generator = nn.Linear(emb_size * 2, 1)
        else:
            # Per-concept probability generators
            self.prob_generators = nn.ModuleList([
                nn.Linear(emb_size * 2, 1) for _ in range(n_concepts)
            ])
        
        self.shared_prob_gen = shared_prob_gen
        
        # Stage 4: Task Classifier (Concept Embeddings → Task Output)
        self.task_classifier = nn.Sequential(
            nn.Linear(n_concepts * emb_size, 128),
            nn.LeakyReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 1)  # Binary classification
        )
        
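        # Loss functions
        # Note: concept_pos_weight is saved by save_hyperparameters() but is not passed
        # here, so the concept loss is unweighted; applying it would require
        # nn.BCEWithLogitsLoss(pos_weight=concept_pos_weight)
        self.concept_loss_fn = nn.BCEWithLogitsLoss()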
        if use_ldam_loss:
            self.task_loss_fn = LDAMLoss(n_positive, n_negative, ldam_max_margin, ldam_scale)
        else:
            self.task_loss_fn = nn.BCEWithLogitsLoss()
    
    def forward(self, x, c_true=None, train=False):
        # Step 1: Extract pre-concept features
        pre_features = self.concept_extractor(x)  # (B, 256)
        
        # Step 2: Generate contexts and probabilities per concept
        contexts = []
        c_logits_list = []
        
        for i, context_layer in enumerate(self.context_layers):
            context = context_layer(pre_features)  # (B, emb_size*2)
            
            # Get probability logit
            if self.shared_prob_gen:
                logit = self.prob_generator(context)  # (B, 1)
            else:
                logit = self.prob_generators[i](context)
            
            contexts.append(context)
            c_logits_list.append(logit)
        
        c_logits = torch.cat(c_logits_list, dim=1)  # (B, 21)
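        # Temperature-scaled probabilities drive the embedding mixing below;
        # the concept loss in training_step uses the unscaled c_logits
        c_probs = torch.sigmoid(c_logits / self.concept_temperature)  # (B, 21)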
        
        # Step 3: Apply intervention (optional during training)
        if train and self.intervention_prob > 0 and c_true is not None:
            intervention_mask = torch.bernoulli(
                torch.ones_like(c_probs) * self.intervention_prob
            )
            c_probs = c_probs * (1 - intervention_mask) + c_true * intervention_mask
        
        # Step 4: Mix dual embeddings based on probabilities
        concept_embeddings = []
        for i, context in enumerate(contexts):
            # Split into true/false embeddings
            emb_true = context[:, :self.emb_size]       # (B, emb_size)
            emb_false = context[:, self.emb_size:]      # (B, emb_size)
            
            # Concept probability
            prob = c_probs[:, i:i+1]  # (B, 1)

            # Soft confidence gating: distance of prob from 0.5, rescaled to [0, 1]
            confidence = torch.abs(prob - 0.5) * 2.0

            # Floor at 0.3 so uncertain concepts still pass a weak signal (range [0.3, 1])
            confidence = 0.3 + 0.7 * confidence

            # Mix true/false embeddings by probability, then scale by confidence
            mixed_emb = emb_true * prob + emb_false * (1 - prob)
            mixed_emb = mixed_emb * confidence

            concept_embeddings.append(mixed_emb)

        # Concatenate all concept embeddings
        c_embeddings = torch.cat(concept_embeddings, dim=1)  # (B, 21*emb_size)
        
        # Step 5: Task prediction
        y_logits = self.task_classifier(c_embeddings)  # (B, 1)
        
        return c_logits, y_logits
    
    def training_step(self, batch, batch_idx):
        x, y, c_true = batch
        c_logits, y_logits = self.forward(x, c_true=c_true, train=True)
        
        # Task loss (LDAM)
        task_loss = self.task_loss_fn(y_logits.squeeze(), y.squeeze())
        
        # Concept loss (BCE)
        concept_loss = self.concept_loss_fn(c_logits, c_true)
        
        # Combined loss
        loss = task_loss + self.concept_loss_weight * concept_loss
        
        # Logging
        self.log('train_loss', loss, on_epoch=True, prog_bar=True)
        self.log('train_task_loss', task_loss, on_epoch=True)
        self.log('train_concept_loss', concept_loss, on_epoch=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y, c_true = batch
        c_logits, y_logits = self.forward(x, c_true=c_true, train=False)
        
        # Task loss
        task_loss = self.task_loss_fn(y_logits.squeeze(), y.squeeze())
        
        # Concept loss (NOW MEANINGFUL - validation has true concept labels!)
        concept_loss = self.concept_loss_fn(c_logits, c_true)
        
        # Combined loss
        loss = task_loss + self.concept_loss_weight * concept_loss
        
        # Logging
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_task_loss', task_loss, on_epoch=True)
        self.log('val_concept_loss', concept_loss, on_epoch=True)
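        # Concept-logit statistics on validation batches (Lightning already disables
        # gradients in validation_step, so no torch.no_grad() block is needed)
        self.log("val_c_logit_mean", c_logits.mean(), on_epoch=True, prog_bar=False)
        self.log("val_c_logit_std", c_logits.std(), on_epoch=True, prog_bar=False)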

        return loss
    
    def configure_optimizers(self):
        return torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay
        )

print("✓ Custom CEM model defined")

✓ Custom CEM model defined
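The `forward()` method above scores each concept's context vector with a shared probability generator. During training it optionally swaps in the ground-truth label (intervention), then mixes the two halves of the context and applies the confidence gate. A NumPy sketch of those steps for a whole batch follows. The function name `mix_concepts`, the random inputs, and `emb_size=8` are illustrative assumptions:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def mix_concepts(contexts, w_prob, b_prob=0.0, temperature=2.0,
                 c_true=None, intervention_prob=0.0, rng=None):
    """Sketch of the CEM mixing for contexts of shape (B, K, 2*E).

    Returns concept logits (B, K) and concatenated embeddings (B, K*E).
    """
    B, K, two_e = contexts.shape
    E = two_e // 2
    logits = contexts @ w_prob + b_prob           # shared probability generator
    probs = sigmoid(logits / temperature)

    # Training-time intervention: replace each probability by its true label
    # with probability `intervention_prob`
    if c_true is not None and intervention_prob > 0:
        if rng is None:
            rng = np.random.default_rng()
        mask = rng.random(probs.shape) < intervention_prob
        probs = np.where(mask, c_true, probs)

    p = probs[..., None]                          # (B, K, 1)
    emb_true, emb_false = contexts[..., :E], contexts[..., E:]
    confidence = 0.3 + 0.7 * np.abs(p - 0.5) * 2.0   # soft gate in [0.3, 1]
    mixed = (emb_true * p + emb_false * (1 - p)) * confidence
    return logits, mixed.reshape(B, K * E)        # same order as torch.cat(dim=1)

rng = np.random.default_rng(0)
ctx = rng.normal(size=(4, 21, 16))                # B=4, 21 concepts, emb_size=8
w = rng.normal(size=16)
c_logits, emb = mix_concepts(ctx, w)
print(c_logits.shape, emb.shape)                  # (4, 21) (4, 168)
```

With `intervention_prob=1.0` and all-ones labels, every probability becomes 1. The gate then equals 1, and the output reduces to the "true" half of each context.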


## Section 4: Model Initialization

In [14]:
# Initialize Custom CEM model
custom_cem = CustomCEM(
    n_concepts=HYPERPARAMS['n_concepts'],
    emb_size=HYPERPARAMS['emb_size'],
    input_dim=HYPERPARAMS['embedding_dim'],
    shared_prob_gen=HYPERPARAMS['shared_prob_gen'],
    intervention_prob=HYPERPARAMS['intervention_prob'],
    concept_loss_weight=HYPERPARAMS['concept_loss_weight'],
    learning_rate=HYPERPARAMS['learning_rate'],
    weight_decay=HYPERPARAMS['weight_decay'],
    use_ldam_loss=HYPERPARAMS['use_ldam_loss'],
    n_positive=HYPERPARAMS['n_positive'],
    n_negative=HYPERPARAMS['n_negative'],
    ldam_max_margin=HYPERPARAMS['ldam_max_margin'],
    ldam_scale=HYPERPARAMS['ldam_scale'],
    concept_temperature=HYPERPARAMS["concept_temperature"],
    concept_pos_weight=concept_pos_weight,
)

print("✓ Custom CEM model initialized")
print(f"  Using LDAM Loss (margin={HYPERPARAMS['ldam_max_margin']}, scale={HYPERPARAMS['ldam_scale']})")
print(f"  Concept embedding size: {HYPERPARAMS['emb_size']}")
print(f"  Intervention probability: {HYPERPARAMS['intervention_prob']}")
print(f"  Shared probability generator: {HYPERPARAMS['shared_prob_gen']}")
print(f"  Class counts: {HYPERPARAMS['n_positive']} positive, {HYPERPARAMS['n_negative']} negative")

✓ Custom CEM model initialized
  Using LDAM Loss (margin=0.7, scale=40)
  Concept embedding size: 128
  Intervention probability: 0.1
  Shared probability generator: True
  Class counts: 66 positive, 322 negative
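With the class counts printed above (66 positive, 322 negative) and `ldam_max_margin=0.7`, the `freq ** -0.25` rule in `LDAMLoss` gives both classes a margin above 0.7. The sketch below compares it with the rescaling used in the original LDAM formulation. There, margins proportional to n^(-1/4) are scaled so that the largest equals the maximum margin. The function names are hypothetical:

```python
import numpy as np

def notebook_margins(n_pos, n_neg, max_margin):
    """Margins as computed in LDAMLoss above: max_margin * freq ** -0.25."""
    total = n_pos + n_neg
    return max_margin * (n_pos / total) ** -0.25, max_margin * (n_neg / total) ** -0.25

def rescaled_margins(n_pos, n_neg, max_margin):
    """Margins proportional to n ** -0.25, rescaled so the largest equals max_margin."""
    raw = np.array([n_pos, n_neg], dtype=float) ** -0.25
    m = raw * (max_margin / raw.max())
    return m[0], m[1]

def ldam_adjusted_logits(logits, targets, m_pos, m_neg, scale):
    """Positives are scored as (logit - m_pos), negatives as (logit + m_neg), then scaled."""
    margin = np.where(targets == 1, m_pos, -m_neg)
    return (logits - margin) * scale

m_pos, m_neg = notebook_margins(66, 322, 0.7)
print(f"notebook margins: pos={m_pos:.2f}, neg={m_neg:.2f}")   # pos=1.09, neg=0.73
r_pos, r_neg = rescaled_margins(66, 322, 0.7)
print(f"rescaled margins: pos={r_pos:.2f}, neg={r_neg:.2f}")   # pos=0.70, neg=0.47
```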


## Section 5: Training

In [15]:
# Setup trainer
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath=os.path.join(OUTPUT_DIR, "models"),
    filename="alternate-cem-{epoch:02d}-{val_loss:.2f}",
    save_top_k=1,
    mode="min"
)
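# Early stopping watches the concept loss, while the checkpoint above keeps the best
# val_loss; after fit(), `custom_cem` holds the last-epoch weights, not the checkpoint
early_stop_cb = EarlyStopping(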
    monitor="val_concept_loss",
    patience=10,
    mode="min",
    verbose=True
)

trainer = pl.Trainer(
    max_epochs=HYPERPARAMS['max_epochs'],
    accelerator=DEVICE,
    devices=1,
    logger=CSVLogger(save_dir=os.path.join(OUTPUT_DIR, "logs"), name="alternate_cem_pipeline"),
    log_every_n_steps=10,
    callbacks=[checkpoint_callback, early_stop_cb],
    enable_progress_bar=True
)

print("✓ Trainer configured")

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


✓ Trainer configured


In [16]:
# Train model
print("\nStarting training...\n")
trainer.fit(custom_cem, train_loader, val_loader)
print("\n✓ Training complete!")


Starting training...



  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

  | Name              | Type              | Params
--------------------------------------------------------
0 | concept_extractor | Sequential        | 164 K 
1 | context_layers    | ModuleList        | 1.4 M 
2 | prob_generator    | Linear            | 257   
3 | task_classifier   | Sequential        | 344 K 
4 | concept_loss_fn   | BCEWithLogitsLoss | 0     
5 | task_loss_fn      | LDAMLoss          | 0     
--------------------------------------------------------
1.9 M     Trainable params
0         Non-trainable params
1.9 M     Total params
7.562     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_concept_loss improved. New best score: 0.549


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_concept_loss improved by 0.011 >= min_delta = 0.0. New best score: 0.538


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_concept_loss improved by 0.102 >= min_delta = 0.0. New best score: 0.435


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_concept_loss improved by 0.097 >= min_delta = 0.0. New best score: 0.338


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Monitored metric val_concept_loss did not improve in the last 10 records. Best score: 0.338. Signaling Trainer to stop.



✓ Training complete!
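In the log above, the best `val_concept_loss` (0.338) is followed by ten validation runs without improvement. At that point `EarlyStopping(patience=10)` stops training. Below is a simplified pure-Python sketch of that patience rule for a minimized metric. It illustrates the logic only and is not Lightning's implementation. The loss sequence is synthetic:

```python
def early_stopping_trace(values, patience=10, min_delta=0.0):
    """Return (best_index, stop_index) for a metric that should decrease.

    An epoch counts as an improvement when it beats the best value so far by
    more than min_delta; training stops once `patience` epochs in a row fail
    to improve. If that never happens, the last index is returned as stop_index.
    """
    best, best_idx, wait = float("inf"), None, 0
    for idx, v in enumerate(values):
        if v < best - min_delta:
            best, best_idx, wait = v, idx, 0
        else:
            wait += 1
            if wait >= patience:
                return best_idx, idx
    return best_idx, len(values) - 1

# Synthetic losses: improvements at epochs 0, 2, 5 and 11, then ten flat epochs
vals = [1.0, 1.1, 0.9, 0.95, 0.95, 0.8] + [0.85] * 5 + [0.7] + [0.75] * 10
print(early_stopping_trace(vals, patience=10))   # (11, 21)
```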


In [17]:
# ===============================
# STEP 5.1: Concept temperature calibration (validation)
# ===============================

custom_cem.eval()
device_obj = torch.device(DEVICE)
custom_cem = custom_cem.to(device_obj)

temps = torch.linspace(0.5, 5.0, steps=20)
best_temp = None
best_concept_loss = float("inf")

with torch.no_grad():
    for T in temps:
        total_loss = 0.0
        n_batches = 0

        for x_batch, _, c_batch in val_loader:
            x_batch = x_batch.to(device_obj)
            c_batch = c_batch.to(device_obj)

            c_logits, _ = custom_cem(x_batch)
            c_probs = torch.sigmoid(c_logits / T)

            loss = F.binary_cross_entropy(
                c_probs, c_batch, reduction="mean"
            )

            total_loss += loss.item()
            n_batches += 1

        avg_loss = total_loss / n_batches

        if avg_loss < best_concept_loss:
            best_concept_loss = avg_loss
            best_temp = T.item()

print(f"✓ Best concept temperature found: {best_temp:.3f}")


✓ Best concept temperature found: 0.500
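The selected temperature, 0.500, is the smallest value in the `torch.linspace(0.5, 5.0, steps=20)` grid, so the optimum may lie below it. The NumPy sketch below runs the same kind of grid search and also reports whether the winner sits on an edge of the grid. It uses the stable BCE-with-logits form in place of `F.binary_cross_entropy` on probabilities. The helper names and synthetic data are illustrative assumptions. Soft targets generated at a known temperature make the expected minimizer exact:

```python
import numpy as np

def bce_with_logits(logits, targets):
    """Numerically stable mean BCE: max(x, 0) - x*y + log(1 + exp(-|x|))."""
    return np.mean(np.maximum(logits, 0) - logits * targets + np.log1p(np.exp(-np.abs(logits))))

def fit_temperature(logits, targets, temps):
    """Grid-search T minimising BCE(logits / T); also flag a minimum on the grid edge."""
    losses = np.array([bce_with_logits(logits / T, targets) for T in temps])
    best = int(np.argmin(losses))
    on_edge = best in (0, len(temps) - 1)
    return float(temps[best]), on_edge

rng = np.random.default_rng(0)
z = rng.normal(scale=3.0, size=(500, 21))
temps = np.linspace(0.5, 5.0, 10)                 # 0.5, 1.0, ..., 5.0

# Soft targets generated at T = 2: the minimiser is interior
print(fit_temperature(z, 1 / (1 + np.exp(-z / 2.0)), temps))    # (2.0, False)
# Targets sharper than any candidate: the minimiser sits on the lower edge
print(fit_temperature(z, 1 / (1 + np.exp(-z / 0.25)), temps))   # (0.5, True)
```

When `on_edge` is true, extending the grid past that edge is the natural next step.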


In [18]:
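# Note: forward() also uses concept_temperature to mix the concept embeddings, so this
# changes the task predictions of every cell below, not only the concept probabilities
custom_cem.concept_temperature = best_temp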

In [19]:
# ===============================
# Concept importance via gradients
# ===============================

custom_cem.eval()
device_obj = torch.device(DEVICE)
custom_cem = custom_cem.to(device_obj)

concept_grad_accumulator = torch.zeros(N_CONCEPTS, device=device_obj)
n_batches = 0

for x_batch, _, _ in val_loader:
    x_batch = x_batch.to(device_obj)

    # c_logits already requires grad through the model parameters,
    # so the input itself does not need requires_grad
    c_logits, y_logits = custom_cem(x_batch)
    y_prob = torch.sigmoid(y_logits).mean()

    # d(mean task probability) / d(concept logits), shape (B, n_concepts)
    grads = torch.autograd.grad(
        outputs=y_prob,
        inputs=c_logits,
        retain_graph=False,
        create_graph=False
    )[0]

    concept_grad_accumulator += grads.abs().mean(dim=0)
    n_batches += 1

concept_importance = (concept_grad_accumulator / n_batches).cpu().numpy()

# Normalize so the importances sum to 1
concept_importance /= concept_importance.sum()

print("\n" + "="*70)
print("            CONCEPT IMPORTANCE (GRADIENT-BASED)")
print("="*70)

for name, score in zip(CONCEPT_NAMES, concept_importance):
    print(f"{name:<35}: {score:.4f}")



            CONCEPT IMPORTANCE (GRADIENT-BASED)
Sadness                            : 0.0287
Pessimism                          : 0.0601
Past failure                       : 0.0060
Loss of pleasure                   : 0.0050
Guilty feelings                    : 0.0837
Punishment feelings                : 0.0528
Self-dislike                       : 0.1029
Self-criticalness                  : 0.0612
Suicidal thoughts or wishes        : 0.0815
Crying                             : 0.0099
Agitation                          : 0.0098
Loss of interest                   : 0.0087
Indecisiveness                     : 0.0175
Worthlessness                      : 0.0551
Loss of energy                     : 0.0866
Changes in sleeping pattern        : 0.0475
Irritability                       : 0.0067
Changes in appetite                : 0.0336
Concentration difficulty           : 0.0625
Tiredness or fatigue               : 0.1636
Loss of interest in sex            : 0.0166
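The scores above are mean absolute derivatives of the predicted task probability with respect to each concept logit, normalized to sum to 1. Consider a toy task head whose logit is linear in the concept logits, s = w·c. The derivative is then σ'(s)·w_i, and the normalized importance reduces to |w_i| / Σ|w|. The NumPy sketch below estimates the same quantity with central finite differences. Finite differences stand in for `torch.autograd.grad` here so the check runs without PyTorch, and the linear head is an assumption:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gradient_importance(task_prob_fn, C, eps=1e-5):
    """Mean |d task_prob / d c_i| over samples, normalised to sum to 1.

    task_prob_fn maps concept logits (N, K) to task probabilities (N,);
    derivatives are estimated with central finite differences.
    """
    N, K = C.shape
    grads = np.zeros((N, K))
    for i in range(K):
        step = np.zeros(K)
        step[i] = eps
        grads[:, i] = (task_prob_fn(C + step) - task_prob_fn(C - step)) / (2 * eps)
    importance = np.abs(grads).mean(axis=0)
    return importance / importance.sum()

w = np.array([2.0, -1.0, 0.0, 1.0])              # toy linear task head

def task_prob(C):
    return sigmoid(C @ w)

C = np.random.default_rng(0).normal(size=(200, 4))
scores = gradient_importance(task_prob, C)
print(np.round(scores, 3))                        # close to |w| / sum(|w|) = [0.5, 0.25, 0, 0.25]
```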


In [20]:
# ===============================
# Concept ablation study
# ===============================

custom_cem.eval()

baseline_probs = []
baseline_true = []

with torch.no_grad():
    for x_batch, y_batch, _ in val_loader:
        x_batch = x_batch.to(device_obj)
        _, y_logits = custom_cem(x_batch)
        baseline_probs.extend(torch.sigmoid(y_logits).cpu().numpy().reshape(-1).tolist())
        baseline_true.extend(y_batch.numpy().astype(int).tolist())

baseline_probs = np.array(baseline_probs)
baseline_true = np.array(baseline_true)
best_threshold = 0.5

baseline_mcc = matthews_corrcoef(
    baseline_true,
    (baseline_probs >= best_threshold).astype(int)
)

print(f"\nBaseline MCC: {baseline_mcc:.4f}")

print("\nConcept ablation MCC drops:")
print("="*70)

for i, name in enumerate(CONCEPT_NAMES):
    probs = []
    true = []

    with torch.no_grad():
        for x_batch, y_batch, c_batch in val_loader:
            x_batch = x_batch.to(device_obj)

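            c_logits, y_logits = custom_cem(x_batch)
            # NOTE: this runs after forward() has already computed y_logits,
            # so it does not affect the task prediction (hence ΔMCC = 0 below)
            c_logits[:, i] = 0.0

            y_probs = torch.sigmoid(y_logits).cpu().numpy().reshape(-1)

            probs.extend(y_probs.tolist())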
            true.extend(y_batch.numpy().astype(int).tolist())

    mcc = matthews_corrcoef(true, (np.array(probs) >= best_threshold).astype(int))
    drop = baseline_mcc - mcc

    print(f"{name:<35}: ΔMCC = {drop:+.4f}")



Baseline MCC: 0.2220

Concept ablation MCC drops:
Sadness                            : ΔMCC = +0.0000
Pessimism                          : ΔMCC = +0.0000
Past failure                       : ΔMCC = +0.0000
Loss of pleasure                   : ΔMCC = +0.0000
Guilty feelings                    : ΔMCC = +0.0000
Punishment feelings                : ΔMCC = +0.0000
Self-dislike                       : ΔMCC = +0.0000
Self-criticalness                  : ΔMCC = +0.0000
Suicidal thoughts or wishes        : ΔMCC = +0.0000
Crying                             : ΔMCC = +0.0000
Agitation                          : ΔMCC = +0.0000
Loss of interest                   : ΔMCC = +0.0000
Indecisiveness                     : ΔMCC = +0.0000
Worthlessness                      : ΔMCC = +0.0000
Loss of energy                     : ΔMCC = +0.0000
Changes in sleeping pattern        : ΔMCC = +0.0000
Irritability                       : ΔMCC = +0.0000
Changes in appetite                : ΔMCC = +0.0000
Concentration
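Every ΔMCC above is exactly zero because `c_logits[:, i] = 0.0` is applied after `forward()` has already produced `y_logits`. An ablation has to overwrite the concept probability before the embeddings are mixed. The NumPy sketch below shows that ordering on a toy CEM head with a linear task layer. `task_logits` and `ablate_to` are hypothetical names. The value 0.5 mirrors the `c_logits = 0` choice in the cell above:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def task_logits(contexts, c_logits, head_w, temperature=2.0, ablate=None, ablate_to=0.5):
    """Toy CEM head: contexts (B, K, 2E), c_logits (B, K), head_w (K*E,) -> (B,).

    If `ablate` is a concept index, that concept's probability is overwritten
    with `ablate_to` *before* the embeddings are mixed.
    """
    B, K, two_e = contexts.shape
    E = two_e // 2
    probs = sigmoid(c_logits / temperature)       # new array, safe to modify
    if ablate is not None:
        probs[:, ablate] = ablate_to
    p = probs[..., None]
    confidence = 0.3 + 0.7 * np.abs(p - 0.5) * 2.0
    mixed = (contexts[..., :E] * p + contexts[..., E:] * (1 - p)) * confidence
    return mixed.reshape(B, K * E) @ head_w

rng = np.random.default_rng(0)
B, K, E = 8, 3, 4
ctx = rng.normal(size=(B, K, 2 * E))
c_log = rng.normal(scale=3.0, size=(B, K))
head_w = rng.normal(size=K * E)
head_w[2 * E:] = 0.0                              # the toy head ignores concept 2

base = task_logits(ctx, c_log, head_w)
for i in range(K):
    delta = np.abs(task_logits(ctx, c_log, head_w, ablate=i) - base).max()
    print(f"concept {i}: max |change in task logit| = {delta:.4f}")
```

Only concepts the head actually reads (0 and 1 here) change the task logit when ablated this way.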

## Section 6: Test Evaluation

In [21]:
# -----------------------------
# VALIDATION THRESHOLD SELECTION
# -----------------------------
print("\nSelecting decision threshold on validation set...")

custom_cem.eval()

# Move model to device
device_obj = torch.device(DEVICE)
custom_cem = custom_cem.to(device_obj)

y_val_true = []
y_val_prob = []

with torch.no_grad():
    for x_batch, y_batch, c_batch in val_loader:
        x_batch = x_batch.to(device_obj)

        _, y_logits = custom_cem(x_batch)
        # reshape(-1) keeps a 1-D array even when a batch holds a single sample
        y_probs = torch.sigmoid(y_logits).cpu().numpy().reshape(-1)

        y_val_true.extend(y_batch.numpy().astype(int).tolist())
        y_val_prob.extend(y_probs.tolist())

y_val_true = np.array(y_val_true)
y_val_prob = np.array(y_val_prob)

best_threshold = 0.5
best_precision = 0.0
target_recall = 0.80

for threshold in np.linspace(0.01, 0.50, 50):
    y_pred_temp = (y_val_prob >= threshold).astype(int)

    if np.sum(y_pred_temp) == 0:
        continue

    recall = recall_score(y_val_true, y_pred_temp)
    precision = precision_score(y_val_true, y_pred_temp)

    if recall >= target_recall and precision > best_precision:
        best_precision = precision
        best_threshold = threshold

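if best_precision > 0:
    print(f"✓ Selected validation threshold: {best_threshold:.2f}")
    print(f"✓ Validation recall ≥ {target_recall}, precision = {best_precision:.3f}")
else:
    # No candidate met the recall target, so best_threshold kept its 0.5 default
    print(f"⚠ No threshold in [0.01, 0.50] reached recall ≥ {target_recall}; "
          f"keeping the default {best_threshold:.2f}")


Selecting decision threshold on validation set...
⚠ No threshold in [0.01, 0.50] reached recall ≥ 0.8; keeping the default 0.50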

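The validation search above accepts only thresholds whose recall reaches `target_recall`, and among those keeps the one with the highest precision. If no threshold qualifies, it falls back to 0.5. A NumPy sketch of the same rule follows. It makes the infeasible case explicit by returning `None`. `select_threshold` is a hypothetical helper, and the toy scores are illustrative:

```python
import numpy as np

def select_threshold(y_true, y_prob, target_recall=0.8, grid=None):
    """Return (threshold, precision) with the best precision among thresholds whose
    recall >= target_recall, or (None, 0.0) if no threshold reaches the target."""
    if grid is None:
        grid = np.linspace(0.01, 0.50, 50)        # same grid as the cell above
    y_true = np.asarray(y_true).astype(int)
    y_prob = np.asarray(y_prob)
    n_pos = int(y_true.sum())
    best_t, best_prec = None, 0.0
    for t in grid:
        pred = y_prob >= t
        n_pred = int(pred.sum())
        if n_pred == 0 or n_pos == 0:
            continue
        tp = int(np.sum(pred & (y_true == 1)))
        recall, precision = tp / n_pos, tp / n_pred
        if recall >= target_recall and precision > best_prec:
            best_t, best_prec = float(t), precision
    return best_t, best_prec

y = np.array([0, 0, 0, 1, 1])
print(select_threshold(y, [0.05, 0.105, 0.30, 0.20, 0.40]))   # best precision is 2/3
print(select_threshold(y, [0.30, 0.30, 0.30, 0.001, 0.002]))  # (None, 0.0)
```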

In [22]:
# Run inference on test set
print("Running inference on test set...")

custom_cem.eval()
device_obj = torch.device(DEVICE)
custom_cem = custom_cem.to(device_obj)

y_true_list = []
y_prob_list = []
concept_probs_list = []

with torch.no_grad():
    for x_batch, y_batch, c_batch in test_loader:
        x_batch = x_batch.to(device_obj)
        
        c_logits, y_logits = custom_cem(x_batch)
        # Concept probabilities here use the raw logits (T = 1), not the
        # calibrated custom_cem.concept_temperature
        c_probs = torch.sigmoid(c_logits).cpu().numpy()
        # reshape(-1) keeps a 1-D array even when a batch holds a single sample
        y_probs = torch.sigmoid(y_logits).cpu().numpy().reshape(-1)
        
        y_true_list.extend(y_batch.numpy().astype(int).tolist())
        y_prob_list.extend(y_probs.tolist())
        concept_probs_list.extend(c_probs.tolist())

y_true = np.array(y_true_list)
y_prob = np.array(y_prob_list)
concept_probs = np.array(concept_probs_list)

print("✓ Inference complete")

Running inference on test set...
✓ Inference complete


In [23]:
# Print per-concept probability statistics and a few sample values
print("\n" + "=" * 70)
print("            CONCEPT ACTIVATION PROBABILITY STATISTICS")
print("=" * 70)

for i, concept_name in enumerate(CONCEPT_NAMES):
    probs = concept_probs[:, i]

    print(f"\nConcept: {concept_name}")
    print(f"  Mean probability:      {np.mean(probs):.4f}")
    print(f"  Std deviation:         {np.std(probs):.4f}")
    print(f"  Min probability:       {np.min(probs):.4f}")
    print(f"  Max probability:       {np.max(probs):.4f}")
    print(f"  Median probability:    {np.median(probs):.4f}")
    print(f"  25th percentile:       {np.percentile(probs, 25):.4f}")
    print(f"  75th percentile:       {np.percentile(probs, 75):.4f}")
print("\n" + "=" * 70)
print("            SAMPLE CONCEPT PROBABILITIES (FIRST 5)")
print("=" * 70)

n_samples_to_show = min(5, concept_probs.shape[0])

for i, concept_name in enumerate(CONCEPT_NAMES):
    probs = concept_probs[:n_samples_to_show, i]
    probs_str = ", ".join([f"{p:.3f}" for p in probs])
    print(f"{concept_name:<30}: {probs_str}")


            CONCEPT ACTIVATION PROBABILITY STATISTICS

Concept: Sadness
  Mean probability:      0.3065
  Std deviation:         0.1001
  Min probability:       0.0007
  Max probability:       0.4374
  Median probability:    0.3314
  25th percentile:       0.2575
  75th percentile:       0.3874

Concept: Pessimism
  Mean probability:      0.3237
  Std deviation:         0.2254
  Min probability:       0.0000
  Max probability:       0.9893
  Median probability:    0.2551
  25th percentile:       0.1536
  75th percentile:       0.4543

Concept: Past failure
  Mean probability:      0.0944
  Std deviation:         0.1647
  Min probability:       0.0000
  Max probability:       0.7055
  Median probability:    0.0018
  25th percentile:       0.0002
  75th percentile:       0.1119

Concept: Loss of pleasure
  Mean probability:      0.2404
  Std deviation:         0.0943
  Min probability:       0.0000
  Max probability:       0.3928
  Median probability:    0.2543
  25th percentile:       
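The per-concept loop above prints one block per concept. A more compact, equivalent view is a single table with one row per concept. The sketch below assumes `concept_probs` is an `(n_samples, n_concepts)` array and `CONCEPT_NAMES` a matching list, as in the cell above. The helper `summarize_concepts` is hypothetical, and the demo values are random placeholders, not model outputs.

```python
import numpy as np
import pandas as pd


def summarize_concepts(concept_probs: np.ndarray, concept_names: list) -> pd.DataFrame:
    """Return one row per concept with the same statistics the loop above prints."""
    if concept_probs.ndim != 2 or concept_probs.shape[1] != len(concept_names):
        raise ValueError("concept_probs must be (n_samples, n_concepts) matching concept_names")
    return pd.DataFrame(
        {
            "mean": concept_probs.mean(axis=0),
            "std": concept_probs.std(axis=0),  # population std (ddof=0), like np.std
            "min": concept_probs.min(axis=0),
            "p25": np.percentile(concept_probs, 25, axis=0),
            "median": np.median(concept_probs, axis=0),
            "p75": np.percentile(concept_probs, 75, axis=0),
            "max": concept_probs.max(axis=0),
        },
        index=concept_names,
    )


# Placeholder data: 100 samples x 3 concepts drawn uniformly from [0, 1).
rng = np.random.default_rng(0)
demo_probs = rng.random((100, 3))
summary = summarize_concepts(demo_probs, ["Sadness", "Pessimism", "Past failure"])
print(summary.round(4))
```

Usage: `summarize_concepts(concept_probs, CONCEPT_NAMES)`. NumPy's `std` is used deliberately because `pandas.DataFrame.std` defaults to the sample standard deviation (`ddof=1`), which would not match the numbers printed by the loop.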

In [24]:
# Binarize predicted probabilities at the decision threshold
y_pred = (y_prob >= best_threshold).astype(int)
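The cell above applies `best_threshold`, which is defined earlier in the notebook and not shown in this section. One common way to pick such a threshold is to sweep candidate values on the validation set and keep the one that maximizes positive-class F1. This is offered as an assumption, not as the notebook's actual procedure. The helpers `f1_at_threshold` and `select_threshold` and the toy validation arrays are hypothetical.

```python
import numpy as np


def f1_at_threshold(y_true: np.ndarray, y_prob: np.ndarray, threshold: float) -> float:
    """Positive-class F1 when predicting 1 for probabilities >= threshold."""
    y_pred = (y_prob >= threshold).astype(int)
    tp = int(np.sum((y_pred == 1) & (y_true == 1)))
    fp = int(np.sum((y_pred == 1) & (y_true == 0)))
    fn = int(np.sum((y_pred == 0) & (y_true == 1)))
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom > 0 else 0.0


def select_threshold(y_true, y_prob, grid=None):
    """Return (threshold, f1) maximizing F1 over a grid; ties keep the lowest threshold."""
    y_true = np.asarray(y_true).astype(int)
    y_prob = np.asarray(y_prob, dtype=float)
    if grid is None:
        # 0.05, 0.10, ..., 0.95 (rounded to avoid floating-point drift from arange)
        grid = np.round(np.arange(0.05, 0.96, 0.05), 2)
    scores = [f1_at_threshold(y_true, y_prob, t) for t in grid]
    best = int(np.argmax(scores))  # argmax returns the first maximum
    return float(grid[best]), float(scores[best])


# Toy validation data (hypothetical, not from this notebook).
val_true = np.array([0, 0, 0, 0, 1, 1, 0, 1])
val_prob = np.array([0.10, 0.30, 0.35, 0.60, 0.55, 0.80, 0.20, 0.40])
threshold, f1 = select_threshold(val_true, val_prob)
print(threshold, round(f1, 3))  # 0.4 0.857
```

Selecting the threshold on validation data, never on the test set, keeps the reported test metrics from being optimistically biased. Maximizing F1 is only one option; balanced accuracy or Youden's J are common alternatives under class imbalance.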

## Section 7: Results Display

In [25]:
# Compute all metrics
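# labels=[0, 1] fixes the matrix at 2x2, so ravel() always yields four counts
# even if one class is absent from both y_true and y_pred
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
tn, fp, fn, tp = cm.ravel()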

acc = accuracy_score(y_true, y_pred)
balanced_acc = balanced_accuracy_score(y_true, y_pred)
roc_auc = roc_auc_score(y_true, y_prob)
mcc = matthews_corrcoef(y_true, y_pred)
f1_binary = f1_score(y_true, y_pred, pos_label=1)
f1_macro = f1_score(y_true, y_pred, average='macro')
precision_binary = precision_score(y_true, y_pred, pos_label=1)
recall_binary = recall_score(y_true, y_pred, pos_label=1)

# Print results
print("\n" + "="*70)
print("                    TEST SET EVALUATION")
print("="*70)
print(f"\nDecision Threshold: {best_threshold:.2f}")

# Enhanced Confusion Matrix Display
print(f"\n{'CONFUSION MATRIX':^50}")
print("="*50)
print(f"{'':>20} │ {'Predicted Negative':^12} │ {'Predicted Positive':^12}")
print("─"*50)
print(f"{'Actual Negative':>20} │ {f'TN = {tn}':^12} │ {f'FP = {fp}':^12}")
print(f"{'Actual Positive':>20} │ {f'FN = {fn}':^12} │ {f'TP = {tp}':^12}")
print("="*50)
print(f"\n  True Positives:  {tp:>3}/{int(np.sum(y_true)):<3} ({100*tp/np.sum(y_true):>5.1f}% of depression cases caught)")
print(f"  False Negatives: {fn:>3}/{int(np.sum(y_true)):<3} ({100*fn/np.sum(y_true):>5.1f}% of depression cases MISSED)")
print(f"  True Negatives:  {tn:>3}/{int(len(y_true)-np.sum(y_true)):<3} ({100*tn/(len(y_true)-np.sum(y_true)):>5.1f}% of healthy correctly identified)")
print(f"  False Positives: {fp:>3}/{int(len(y_true)-np.sum(y_true)):<3} ({100*fp/(len(y_true)-np.sum(y_true)):>5.1f}% false alarms)")

print(f"\nPerformance Metrics:")
print(f"  Accuracy:                  {acc:.4f}")
print(f"  Balanced Accuracy:         {balanced_acc:.4f}")
print(f"  ROC-AUC:                   {roc_auc:.4f}")
print(f"  Matthews Correlation:      {mcc:.4f}")
print(f"\n  F1 Score (Binary):         {f1_binary:.4f}")
print(f"  F1 Score (Macro):          {f1_macro:.4f}")
print(f"  Precision (Binary):        {precision_binary:.4f}")
print(f"  Recall (Binary):           {recall_binary:.4f}")

print("\n" + classification_report(y_true, y_pred, target_names=['Negative', 'Positive']))
print("="*70)


                    TEST SET EVALUATION

Decision Threshold: 0.50

                 CONFUSION MATRIX                 
                     │ Predicted Negative │ Predicted Positive
──────────────────────────────────────────────────
     Actual Negative │   TN = 250   │   FP = 99   
     Actual Positive │   FN = 25    │   TP = 27   

  True Positives:   27/52  ( 51.9% of depression cases caught)
  False Negatives:  25/52  ( 48.1% of depression cases MISSED)
  True Negatives:  250/349 ( 71.6% of healthy correctly identified)
  False Positives:  99/349 ( 28.4% false alarms)

Performance Metrics:
  Accuracy:                  0.6908
  Balanced Accuracy:         0.6178
  ROC-AUC:                   0.6777
  Matthews Correlation:      0.1705

  F1 Score (Binary):         0.3034
  F1 Score (Macro):          0.5523
  Precision (Binary):        0.2143
  Recall (Binary):           0.5192

              precision    recall  f1-score   support

    Negative       0.91      0.72      0.80       349
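As a consistency check, every headline metric above except ROC-AUC follows from the four confusion-matrix counts (TN = 250, FP = 99, FN = 25, TP = 27). The NumPy-only sketch below counts them with fixed label positions, so it always returns four values even if a class is absent, and recomputes the metrics with the standard formulas. `binary_counts` and `metrics_from_counts` are illustrative helpers, not part of the notebook, and they do not guard against zero denominators.

```python
import math

import numpy as np


def binary_counts(y_true, y_pred):
    """Return (tn, fp, fn, tp) for 0/1 labels; always four values, even if a class is missing."""
    y_true = np.asarray(y_true).astype(int)
    y_pred = np.asarray(y_pred).astype(int)
    tn = int(np.sum((y_true == 0) & (y_pred == 0)))
    fp = int(np.sum((y_true == 0) & (y_pred == 1)))
    fn = int(np.sum((y_true == 1) & (y_pred == 0)))
    tp = int(np.sum((y_true == 1) & (y_pred == 1)))
    return tn, fp, fn, tp


def metrics_from_counts(tn, fp, fn, tp):
    """Standard binary metrics computed directly from confusion-matrix counts."""
    recall = tp / (tp + fn)                  # sensitivity
    specificity = tn / (tn + fp)
    precision = tp / (tp + fp)
    f1_pos = 2 * tp / (2 * tp + fp + fn)     # F1 of the positive class
    f1_neg = 2 * tn / (2 * tn + fn + fp)     # F1 of the negative class
    mcc = (tp * tn - fp * fn) / math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return {
        "accuracy": (tp + tn) / (tn + fp + fn + tp),
        "balanced_accuracy": (recall + specificity) / 2,
        "precision": precision,
        "recall": recall,
        "f1_binary": f1_pos,
        "f1_macro": (f1_pos + f1_neg) / 2,
        "mcc": mcc,
    }


# Rebuild label arrays that reproduce the reported test-set counts.
y_true_demo = np.array([0] * 250 + [0] * 99 + [1] * 25 + [1] * 27)
y_pred_demo = np.array([0] * 250 + [1] * 99 + [0] * 25 + [1] * 27)

counts = binary_counts(y_true_demo, y_pred_demo)
for name, value in metrics_from_counts(*counts).items():
    print(f"{name:<18} {value:.4f}")
```

The printed values match the metrics reported above (for example, accuracy = 277/401 ≈ 0.6908 and F1 = 54/178 ≈ 0.3034). ROC-AUC cannot be recovered this way because it depends on the ranking of the raw probabilities, not on a single thresholded confusion matrix.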


In [26]:
# Save results
metrics_dict = {
    "model_type": "alternate_cem",
    "dataset": "alternative_attention_pipeline",
    "threshold": float(best_threshold),
    "n_samples": int(len(y_true)),
    "n_positive": int(np.sum(y_true)),
    "n_negative": int(len(y_true) - np.sum(y_true)),
    "accuracy": float(acc),
    "balanced_accuracy": float(balanced_acc),
    "roc_auc": float(roc_auc),
    "mcc": float(mcc),
    "f1_binary": float(f1_binary),
    "f1_macro": float(f1_macro),
    "precision_binary": float(precision_binary),
    "recall_binary": float(recall_binary),
    "confusion_matrix": {"tn": int(tn), "fp": int(fp), "fn": int(fn), "tp": int(tp)}
}

os.makedirs(os.path.join(OUTPUT_DIR, "results"), exist_ok=True)
with open(os.path.join(OUTPUT_DIR, "results", "test_metrics.json"), "w") as f:
    json.dump(metrics_dict, f, indent=4)

# Save predictions
predictions_df = pd.DataFrame({
    'subject_id': test_subject_ids,
    'y_true': y_true,
    'y_pred': y_pred,
    'y_prob': y_prob
})

for i, concept_name in enumerate(CONCEPT_NAMES):
    predictions_df[concept_name] = concept_probs[:, i]

predictions_df.to_csv(os.path.join(OUTPUT_DIR, "results", "test_predictions.csv"), index=False)

print(f"✓ Results saved to {OUTPUT_DIR}/results/")

✓ Results saved to outputs_alternate_cem/results/
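The `float(...)` and `int(...)` casts in `metrics_dict` matter: the standard `json` module cannot serialize NumPy integer scalars (such as the value `np.sum` returns for an integer array), and it also rejects `np.float32`. The sketch below shows the failure, then writes and re-reads a small metrics file in a temporary directory to confirm the round trip. The file layout mirrors the cell above, but the values are placeholders.

```python
import json
import os
import tempfile

import numpy as np

# NumPy integer scalars are not JSON-serializable by the standard library.
try:
    json.dumps({"n_positive": np.sum(np.array([0, 1, 1]))})
except TypeError as err:
    print("json.dumps failed:", err)

# Casting to built-in Python types (as metrics_dict does) fixes it.
metrics = {
    "n_positive": int(np.sum(np.array([0, 1, 1]))),
    "accuracy": float(np.float32(0.75)),
    "confusion_matrix": {"tn": 1, "fp": 0, "fn": 0, "tp": 2},
}

with tempfile.TemporaryDirectory() as out_dir:
    results_dir = os.path.join(out_dir, "results")
    os.makedirs(results_dir, exist_ok=True)
    path = os.path.join(results_dir, "test_metrics.json")
    with open(path, "w") as f:
        json.dump(metrics, f, indent=4)
    with open(path) as f:
        reloaded = json.load(f)

print(reloaded == metrics)  # True
```

An alternative is passing a converter through `json.dump(..., default=...)`, but explicit casts keep the saved schema obvious at the point where the dictionary is built.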


In [27]:
print("\n" + "="*70)
print("        ALTERNATE CEM TRAINING COMPLETE (SUM-BASED DATASET)")
print("="*70)
print(f"\nGenerated files:")
print(f"  Model checkpoint: {OUTPUT_DIR}/models/")
print(f"  Metrics JSON:     {OUTPUT_DIR}/results/test_metrics.json")
print(f"  Predictions CSV:  {OUTPUT_DIR}/results/test_predictions.csv")
print("\nKey differences from original pipeline:")
print("  - Uses SUM-based concept scoring (captures multi-concept posts)")
print("  - Validation set has TRUE concept labels (from train split)")
print("  - Concept predictions can be evaluated against ground truth during validation")
print("="*70)


        ALTERNATE CEM TRAINING COMPLETE (SUM-BASED DATASET)

Generated files:
  Model checkpoint: outputs_alternate_cem/models/
  Metrics JSON:     outputs_alternate_cem/results/test_metrics.json
  Predictions CSV:  outputs_alternate_cem/results/test_predictions.csv

Key differences from original pipeline:
  - Uses SUM-based concept scoring (captures multi-concept posts)
  - Validation set has TRUE concept labels (from train split)
  - Concept predictions can be evaluated against ground truth during validation
