In [26]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import torch.nn as nn
import cem
from cem.models.cem import ConceptEmbeddingModel
import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger


# Loader #

In [27]:
# Load prepared data
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code. Only point this at a trusted, locally produced file.
# The `with` block closes the underlying .npz archive once the arrays are read.
with np.load("../data/processed/cem_input.npz", allow_pickle=True) as data:
    X = data["X"]                # embeddings (n_subjects, input_dim)
    C = data["C"]                # concept matrix (n_subjects, n_concepts)
    y = data["y"]                # labels (n_subjects,)
    subject_ids = data["subject_ids"]
    concept_names = data["concept_names"]

# Fail fast if the arrays are not row-aligned or the concept names do not
# match the concept matrix columns.
assert X.shape[0] == C.shape[0] == y.shape[0], "X, C and y must have the same number of rows"
assert C.shape[1] == len(concept_names), "one concept name is expected per column of C"

print("Shapes:")
print("X:", X.shape, "C:", C.shape, "y:", y.shape)
print("Concepts:", concept_names)

Shapes:
X: (486, 1152) C: (486, 21) y: (486,)
Concepts: ['Sadness' 'Pessimism' 'Past failure' 'Loss of pleasure' 'Guilty feelings'
 'Punishment feelings' 'Self-dislike' 'Self-criticalness'
 'Suicidal thoughts or wishes' 'Crying' 'Agitation' 'Loss of interest'
 'Indecisiveness' 'Worthlessness' 'Loss of energy'
 'Changes in sleeping pattern' 'Irritability' 'Changes in appetite'
 'Concentration difficulty' 'Tiredness or fatigue'
 'Loss of interest in sex']


In [40]:
# Quick re-check of the feature matrix shape (the load cell prints the same value).
print(X.shape)

(486, 1152)


In [33]:
# PyTorch Dataset 
class CEMDataset(Dataset):
    """In-memory dataset yielding ``(features, task_label, concepts)`` triples.

    All three inputs are copied into float32 tensors. The task label is kept as
    float (not long) because the model's task loss is ``BCEWithLogitsLoss``,
    which takes floating-point targets.
    """

    def __init__(self, X, C, y):
        """Convert the inputs to tensors and check that they are row-aligned.

        Args:
            X: array-like of input features, one row per sample.
            C: array-like of concept annotations, one row per sample.
            y: array-like of task labels, one entry per sample.

        Raises:
            ValueError: if X, C and y do not have the same number of samples.
        """
        self.X = torch.tensor(X, dtype=torch.float32)
        self.C = torch.tensor(C, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)  # <-- float, not long!

        n_x, n_c, n_y = len(self.X), len(self.C), len(self.y)
        if not (n_x == n_c == n_y):
            # Catch misaligned inputs here rather than as an IndexError in a batch.
            raise ValueError(
                f"X, C and y must have the same number of samples, "
                f"got {n_x}, {n_c} and {n_y}"
            )

    def __len__(self):
        """Return the number of samples."""
        return len(self.y)

    def __getitem__(self, idx):
        """Return ``(x, y, c)`` for ``idx`` (label before concepts)."""
        return self.X[idx], self.y[idx], self.C[idx]


In [34]:
# Stratified train/validation/test partition of the row indices.
# Both splits hold out 20% with the same seed: first the test set is carved
# off, then the remainder is split again to obtain the validation set.
all_idx = np.arange(len(y))
holdout_kwargs = dict(test_size=0.2, random_state=42)

train_val_idx, test_idx = train_test_split(all_idx, stratify=y, **holdout_kwargs)
train_idx, val_idx = train_test_split(
    train_val_idx, stratify=y[train_val_idx], **holdout_kwargs
)

print(f"Train: {len(train_idx)}, Val: {len(val_idx)}, Test: {len(test_idx)}")

Train: 310, Val: 78, Test: 98


In [35]:
# Wrap each index split in a CEMDataset (indices come from the split cell)
train_ds, val_ds, test_ds = (
    CEMDataset(X[split_idx], C[split_idx], y[split_idx])
    for split_idx in (train_idx, val_idx, test_idx)
)

# Only the training loader shuffles; the evaluation loaders keep a fixed order
eval_loader_kwargs = dict(batch_size=64, shuffle=False)
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = DataLoader(val_ds, **eval_loader_kwargs)
test_loader = DataLoader(test_ds, **eval_loader_kwargs)

# Sanity check: pull one training batch and report the tensor shapes.
# Batch order follows CEMDataset.__getitem__: features, task labels, concepts.
first_batch = next(iter(train_loader))
xb, yb, cb = first_batch
for batch_tag, batch_part in zip(("x", "y", "c"), first_batch):
    print(f"{batch_tag} batch:", batch_part.shape)

x batch: torch.Size([32, 1152])
y batch: torch.Size([32])
c batch: torch.Size([32, 21])


In [36]:
# Parameters (sizes are read from the loaded arrays, not hard-coded)
input_dim = X.shape[1]          # width of the input embeddings (1152 per the load-cell output, not 384)
n_concepts = C.shape[1]         # number of concept columns in C (21 per the load-cell output)
n_tasks = 1                     # binary classification (depressed vs control)
emb_size = 128                  # size of concept embedding inside CEM

# Define concept extractor architecture
def c_extractor_arch(
    output_dim: "int | None",
    *,
    in_features: "int | None" = None,
    hidden_dim: int = 256,
    dropout: float = 0.3,
):
    """Build the MLP backbone handed to the CEM as its concept extractor.

    The network is ``Linear -> ReLU -> Dropout -> Linear``.

    Args:
        output_dim: Width of the final linear layer. ``None`` keeps the final
            layer at ``hidden_dim`` so the output size stays stable; the
            printed model shows this branch in use (final Linear 256 -> 256).
        in_features: Width of the input features. Defaults to the
            notebook-level ``input_dim`` when omitted.
        hidden_dim: Width of the hidden layer (256 by default, as before).
        dropout: Dropout probability after the ReLU (0.3 by default, as before).

    Returns:
        An ``nn.Sequential`` with the four layers described above.
    """
    if in_features is None:
        # Fall back to the embedding width defined in the parameters cell.
        in_features = input_dim
    # None means no explicit head size was requested: reuse the hidden width.
    final_dim = hidden_dim if output_dim is None else output_dim
    return nn.Sequential(
        nn.Linear(in_features, hidden_dim),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, final_dim),
    )


# Seed every RNG *before* building the model so the initial weights are
# reproducible across kernel restarts; otherwise the first seeding call in
# this notebook only happens in the training cell, after construction.
pl.seed_everything(42)

# Instantiate Concept Embedding Model
cem_model = ConceptEmbeddingModel(
    n_concepts=n_concepts,
    n_tasks=n_tasks,
    emb_size=emb_size,
    concept_loss_weight=1.0,
    training_intervention_prob=0.25,
    c_extractor_arch=c_extractor_arch,
    # None: no custom head is supplied. The printed model shows a single
    # Linear(2688 -> 1) as c2y_model (2688 = 21 concepts x 128 emb_size).
    c2y_model=None
)

print(cem_model)


ConceptEmbeddingModel(
  (pre_concept_model): Sequential(
    (0): Linear(in_features=1152, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=256, out_features=256, bias=True)
  )
  (concept_context_generators): ModuleList(
    (0-20): 21 x Sequential(
      (0): Linear(in_features=256, out_features=256, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
    )
  )
  (concept_prob_generators): ModuleList(
    (0): Linear(in_features=256, out_features=1, bias=True)
  )
  (c2y_model): Sequential(
    (0): Linear(in_features=2688, out_features=1, bias=True)
  )
  (sig): Sigmoid()
  (loss_concept): BCELoss()
  (loss_task): BCEWithLogitsLoss()
)


In [37]:

# Fix the global RNG seed before training starts
pl.seed_everything(42)

# CSV metrics logger, configured with save_dir "../logs" and experiment name "cem_experiment"
logger = CSVLogger(save_dir="../logs", name="cem_experiment")

# Trainer options, gathered in one place.
# The accelerator check only looks for CUDA, so any other backend falls back to CPU.
use_cuda = torch.cuda.is_available()
trainer_kwargs = dict(
    max_epochs=100,
    accelerator="gpu" if use_cuda else "cpu",
    devices="auto",
    logger=logger,
    log_every_n_steps=10,
    check_val_every_n_epoch=1,
)
trainer = pl.Trainer(**trainer_kwargs)

# Fit the CEM on the training loader, validating on the validation loader each epoch
trainer.fit(cem_model, train_loader, val_loader)


Global seed set to 42
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(

  | Name                       | Type              | Params
-----------------------------------------------------------------
0 | pre_concept_model          | Sequential        | 360 K 
1 | concept_context_generators | ModuleList        | 1.4 M 
2 | concept_prob_generators    | ModuleList        | 257   
3 | c2y_model                  | Sequential        | 2.7 K 
4 | sig                        | Sigmoid           | 0     
5 | loss_concept               | BCELoss           | 0     
6 | loss_task                  | BCEWithLogitsLoss | 0     
-----------------------------------------------------------------
1.7 M     Trainable params
0         Non-trainable params
1.7 M     Total params
6.982     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


In [None]:
#--------WIP EVALUATION, just a sketch------
from sklearn.metrics import accuracy_score
from scipy.special import expit  # sigmoid

# Collect ground-truth labels and thresholded predictions over the test set
y_true, y_pred = [], []
cem_model.eval()  # inference mode (disables dropout)

with torch.no_grad():
    for xb, yb, cb in test_loader:
        xb = xb.to(cem_model.device)
        # forward pass → 3-tuple; the last element is used as the task logits
        _, _, y_logits = cem_model(xb)
        # Flatten first: the task head is Linear(..., 1), so logits presumably
        # arrive as (B, 1) while the labels are (B,).
        y_probs = expit(y_logits.cpu().numpy().reshape(-1))  # apply sigmoid
        preds = (y_probs >= 0.5).astype(int)

        # Store plain Python ints on both sides so the metric compares two flat
        # 1-D sequences (labels are stored as float 0/1 in the dataset).
        y_true.extend(yb.cpu().numpy().astype(int).tolist())
        y_pred.extend(preds.tolist())

# Accuracy
acc = accuracy_score(y_true, y_pred)
print(f"Quick test accuracy: {acc*100:.2f}%")


Quick test accuracy: 91.84%
