In [1]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import torch.nn as nn
import cem
import torchmetrics
from cbm_model import SimpleCBM
import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger
import os
import matplotlib.pyplot as plt

# Loader #

In [2]:
# Load prepared data
data = np.load("../data/processed/cem_input.npz", allow_pickle=True)
X = data["X"]                # embeddings (n_subjects, input_dim)
C = data["C"]                # concept matrix (n_subjects, n_concepts)
y = data["y"]                # labels (n_subjects,)
subject_ids = data["subject_ids"]
concept_names = data["concept_names"]

print("Shapes:")
print("X:", X.shape, "C:", C.shape, "y:", y.shape)
print("Concepts:", concept_names)

Shapes:
X: (486, 1152) C: (486, 21) y: (486,)
Concepts: ['Sadness' 'Pessimism' 'Past failure' 'Loss of pleasure' 'Guilty feelings'
 'Punishment feelings' 'Self-dislike' 'Self-criticalness'
 'Suicidal thoughts or wishes' 'Crying' 'Agitation' 'Loss of interest'
 'Indecisiveness' 'Worthlessness' 'Loss of energy'
 'Changes in sleeping pattern' 'Irritability' 'Changes in appetite'
 'Concentration difficulty' 'Tiredness or fatigue'
 'Loss of interest in sex']


In [3]:
print(X.shape)

(486, 1152)


In [4]:
# PyTorch Dataset 
class CEMDataset(Dataset):
    def __init__(self, X, C, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.C = torch.tensor(C, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)  # <-- float, not long!

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx], self.C[idx]

In [5]:
#  Train/Val/Test split
train_idx, test_idx = train_test_split(
    np.arange(len(y)), test_size=0.2, stratify=y, random_state=42
)
train_idx, val_idx = train_test_split(
    train_idx, test_size=0.2, stratify=y[train_idx], random_state=42
)

print(f"Train: {len(train_idx)}, Val: {len(val_idx)}, Test: {len(test_idx)}")

Train: 310, Val: 78, Test: 98


In [6]:
# Recreate datasets (using previously computed splits)
train_ds = CEMDataset(X[train_idx], C[train_idx], y[train_idx])
val_ds   = CEMDataset(X[val_idx], C[val_idx], y[val_idx])
test_ds  = CEMDataset(X[test_idx], C[test_idx], y[test_idx])

# Create DataLoaders
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader   = DataLoader(val_ds, batch_size=64, shuffle=False)
test_loader  = DataLoader(test_ds, batch_size=64, shuffle=False)

# Sanity check: inspect one batch and shapes
xb, yb, cb = next(iter(train_loader))
print("x batch:", xb.shape)   # (B, input_dim)
print("y batch:", yb.shape)   # (B,)  <-- task labels
print("c batch:", cb.shape)   # (B, n_concepts)

x batch: torch.Size([32, 1152])
y batch: torch.Size([32])
c batch: torch.Size([32, 21])


In [7]:
# Parameters 
input_dim = X.shape[1]          # SBERT embedding size (e.g., 384)
n_concepts = C.shape[1]         # number of concepts from your questionnaire
n_tasks = 1                     # binary classification (depressed vs control)

# Define concept extractor architecture
def c_extractor_arch(output_dim: int):
    layers = [
        nn.Linear(input_dim, 256),
        nn.ReLU(),
        nn.Dropout(0.3),
    ]
    if output_dim is not None:
        layers.append(nn.Linear(256, output_dim))
    else:
        # keep dimensionality stable for CEM to inspect
        layers.append(nn.Linear(256, 256))
    return nn.Sequential(*layers)


# Instantiate Simple Concept Bottleneck Model
cbm_model = SimpleCBM(
    n_concepts=n_concepts,
    n_tasks=n_tasks,
    input_dim=input_dim,
    c_extractor_arch=c_extractor_arch,
    concept_loss_weight=1.0,
    task_loss_weight=1.0,
)

print(cbm_model)


SimpleCBM(
  (x_to_c_model): Sequential(
    (0): Linear(in_features=1152, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=256, out_features=21, bias=True)
  )
  (c_to_y_model): Linear(in_features=21, out_features=1, bias=True)
  (loss_concept): BCEWithLogitsLoss()
  (loss_task): BCEWithLogitsLoss()
)


In [8]:

# Reproducibility
pl.seed_everything(42)

# Logger (saves metrics to logs/)
logger = CSVLogger(save_dir="../logs", name="cbm_experiment")

# Determine accelerator
if torch.cuda.is_available():
    accelerator = "gpu"
elif torch.backends.mps.is_available():
    accelerator = "mps"
else:
    accelerator = "cpu"

# Trainer
trainer = pl.Trainer(
    max_epochs=100,
    accelerator=accelerator,
    devices=1 if accelerator == "mps" else "auto",
    logger=logger,
    log_every_n_steps=10,
    check_val_every_n_epoch=1
)

# Train the CBM model
trainer.fit(cbm_model, train_loader, val_loader)

Global seed set to 42
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: ../logs/cbm_experiment

  | Name         | Type              | Params
---------------------------------------------------
0 | x_to_c_model | Sequential        | 300 K 
1 | c_to_y_model | Linear            | 22    
2 | loss_concept | BCEWithLogitsLoss | 0     
3 | loss_task    | BCEWithLogitsLoss | 0     
---------------------------------------------------
300 K     Trainable params
0         Non-trainable params
300 K     Total params
1.202     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


In [10]:
# ---------- Evaluation block ----------
import numpy as np
import torch
from scipy.special import expit  # sigmoid
from sklearn.metrics import (
    confusion_matrix,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    matthews_corrcoef,
    roc_auc_score,
    balanced_accuracy_score,
    classification_report,
)

# Put model in eval mode
cbm_model.eval()

y_true_list = []
y_pred_list = []
y_prob_list = []

# Determine device to put inputs on
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
cbm_model.to(device)

with torch.no_grad():
    for xb, yb, cb in test_loader:
        xb = xb.to(device)
        
        _, task_logits = cbm_model(xb)

        # move logits to cpu numpy for metric computation
        task_logits = task_logits.detach().cpu().squeeze()

        # If logits shape is (B,1) or (B,), make sure we have (B,)
        if task_logits.ndim == 2 and task_logits.shape[1] == 1:
            task_logits = task_logits[:, 0]

        # Convert ground truth to numpy ints
        yb_numpy = yb.numpy().astype(int).ravel()

        # Convert logits to probabilities using sigmoid for binary
        y_probs = expit(task_logits.numpy()) if isinstance(task_logits, torch.Tensor) else expit(task_logits)

        y_pred = (y_probs >= 0.5).astype(int)

        y_true_list.extend(yb_numpy.tolist())
        y_pred_list.extend(y_pred.tolist())
        y_prob_list.extend(y_probs.tolist())

# Convert to numpy arrays
y_true = np.array(y_true_list, dtype=int)
y_pred = np.array(y_pred_list, dtype=int)
y_prob = np.array(y_prob_list, dtype=float)

# Basic checks
assert y_true.shape == y_pred.shape, "y_true and y_pred shape mismatch"
n_samples = len(y_true)

# Confusion matrix
cm = confusion_matrix(y_true, y_pred)
# For binary classification scikit returns array([[TN, FP], [FN, TP]])
tn, fp, fn, tp = cm.ravel() if cm.size == 4 else (None, None, None, None)

# Metrics
acc = accuracy_score(y_true, y_pred)
precision_macro = precision_score(y_true, y_pred, average="macro", zero_division=0)
recall_macro = recall_score(y_true, y_pred, average="macro", zero_division=0)
f1_macro = f1_score(y_true, y_pred, average="macro", zero_division=0)

precision_micro = precision_score(y_true, y_pred, average="micro", zero_division=0)
recall_micro = recall_score(y_true, y_pred, average="micro", zero_division=0)
f1_micro = f1_score(y_true, y_pred, average="micro", zero_division=0)

precision_binary = precision_score(y_true, y_pred, pos_label=1, zero_division=0)
recall_binary = recall_score(y_true, y_pred, pos_label=1, zero_division=0)
f1_binary = f1_score(y_true, y_pred, pos_label=1, zero_division=0)

mcc = matthews_corrcoef(y_true, y_pred)
balanced_acc = balanced_accuracy_score(y_true, y_pred)

# ROC AUC for binary, requires at least two classes present in y_true
try:
    roc_auc = roc_auc_score(y_true, y_prob)
except ValueError:
    roc_auc = float("nan")  # not defined if only one class present in y_true

# Print results
print("\nEvaluation summary")
print("==================")
print(f"Samples evaluated: {n_samples}")
print(f"Accuracy: {acc:.4f}")
print(f"Balanced accuracy: {balanced_acc:.4f}")
print(f"ROC AUC: {roc_auc:.4f}")
print(f"Matthews correlation coefficient: {mcc:.4f}")
print("")
print("F1 scores:")
print(f"  binary (positive label=1): {f1_binary:.4f}")
print(f"  macro average: {f1_macro:.4f}")
print(f"  micro average: {f1_micro:.4f}")
print("")
print("Confusion matrix (rows=true, cols=pred):")
print(cm)
if tn is not None:
    print(f"TN: {tn}, FP: {fp}, FN: {fn}, TP: {tp}")
print("")
print("Classification report (precision, recall, f1, support):")
print(classification_report(y_true, y_pred, zero_division=0))

# Optionally, save metrics to the trainer logger CSV or a json file
metrics_dict = {
    "n_samples": n_samples,
    "accuracy": float(acc),
    "balanced_accuracy": float(balanced_acc),
    "roc_auc": float(roc_auc) if not np.isnan(roc_auc) else None,
    "mcc": float(mcc),
    "f1_binary": float(f1_binary),
    "f1_macro": float(f1_macro),
    "f1_micro": float(f1_micro),
}
# Example saving to a file
import json
with open("../logs/eval_metrics_cbm.json", "w") as fh:
    json.dump(metrics_dict, fh, indent=2)
print("Saved metrics to ../logs/eval_metrics_cbm.json")



Evaluation summary
Samples evaluated: 98
Accuracy: 0.9184
Balanced accuracy: 0.8344
ROC AUC: 0.9085
Matthews correlation coefficient: 0.7034

F1 scores:
  binary (positive label=1): 0.7500
  macro average: 0.8506
  micro average: 0.9184

Confusion matrix (rows=true, cols=pred):
[[78  3]
 [ 5 12]]
TN: 78, FP: 3, FN: 5, TP: 12

Classification report (precision, recall, f1, support):
              precision    recall  f1-score   support

           0       0.94      0.96      0.95        81
           1       0.80      0.71      0.75        17

    accuracy                           0.92        98
   macro avg       0.87      0.83      0.85        98
weighted avg       0.92      0.92      0.92        98

Saved metrics to ../logs/eval_metrics_cbm.json
