# Complete CEM Pipeline - End-to-End Depression Detection

This notebook implements a complete pipeline for:
1. Loading and processing training/test data
2. Retrieving top-20 concept-relevant posts per subject
3. Averaging post embeddings into single vectors
4. Training a Concept Embedding Model (CEM)
5. Evaluating with detailed metrics and concept probabilities

## Section 0: Configuration & Setup

In [2]:
# Imports
import os
import glob
import re
import zipfile
import tempfile
import shutil
import json
import time
from collections import defaultdict

import numpy as np
import pandas as pd
import xml.etree.ElementTree as ET

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from sentence_transformers import SentenceTransformer, util
from scipy.special import expit  # sigmoid

import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    confusion_matrix,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    matthews_corrcoef,
    roc_auc_score,
    balanced_accuracy_score,
    classification_report,
)

from patched_model import PatchedConceptEmbeddingModel

print("✓ All imports successful")

  from tqdm.autonotebook import tqdm, trange


✓ All imports successful


In [3]:
# Set random seeds for reproducibility
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
pl.seed_everything(SEED)

print(f"✓ Random seed set to {SEED}")

Global seed set to 42


✓ Random seed set to 42


In [4]:
# Detect device (MPS/CUDA/CPU)
if torch.backends.mps.is_available():
    DEVICE = "mps"
    print("✓ Using MacBook GPU (MPS)")
elif torch.cuda.is_available():
    DEVICE = "cuda"
    print("✓ Using CUDA GPU")
else:
    DEVICE = "cpu"
    print("⚠ Using CPU (will be slow)")

✓ Using MacBook GPU (MPS)


In [5]:
# Define paths
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
DATA_RAW = os.path.join(PROJECT_ROOT, "data/raw")
DATA_PROCESSED = os.path.join(PROJECT_ROOT, "data/processed")
OUTPUT_DIR = "outputs"

# Training data paths (in train/ subdirectory)
POS_DIR = os.path.join(DATA_RAW, "train/positive_examples_anonymous_chunks")
NEG_DIR = os.path.join(DATA_RAW, "train/negative_examples_anonymous_chunks")

# Test data paths
TEST_DIR = os.path.join(DATA_RAW, "test")
TEST_LABELS = os.path.join(TEST_DIR, "test_golden_truth.txt")

# Concept labels
CONCEPTS_FILE = os.path.join(DATA_PROCESSED, "merged_questionnaires.csv")

print("✓ Paths configured")
print(f"  Project root: {PROJECT_ROOT}")
print(f"  Positive dir: {POS_DIR}")
print(f"  Negative dir: {NEG_DIR}")
print(f"  Output dir: {OUTPUT_DIR}")

✓ Paths configured
  Project root: /Users/gualtieromarencoturi/Desktop/thesis/Master-Thesis-CEM-Depression-etc-case-study
  Positive dir: /Users/gualtieromarencoturi/Desktop/thesis/Master-Thesis-CEM-Depression-etc-case-study/data/raw/train/positive_examples_anonymous_chunks
  Negative dir: /Users/gualtieromarencoturi/Desktop/thesis/Master-Thesis-CEM-Depression-etc-case-study/data/raw/train/negative_examples_anonymous_chunks
  Output dir: outputs


In [6]:
# Define 21 BDI-II concept names
CONCEPT_NAMES = [
    "Sadness", "Pessimism", "Past failure", "Loss of pleasure",
    "Guilty feelings", "Punishment feelings", "Self-dislike", "Self-criticalness",
    "Suicidal thoughts or wishes", "Crying", "Agitation", "Loss of interest",
    "Indecisiveness", "Worthlessness", "Loss of energy", "Changes in sleeping pattern",
    "Irritability", "Changes in appetite", "Concentration difficulty",
    "Tiredness or fatigue", "Loss of interest in sex"
]
N_CONCEPTS = len(CONCEPT_NAMES)

print(f"✓ Defined {N_CONCEPTS} BDI-II concepts")

✓ Defined 21 BDI-II concepts


In [7]:
# Hyperparameters
HYPERPARAMS = {
    # Data
    "k_posts": 20,              # Top-k posts per subject
    "sbert_model": "all-MiniLM-L6-v2",
    "embedding_dim": 384,       # SBERT embedding dimension
    
    # Model architecture
    "n_concepts": 21,
    "n_tasks": 1,
    "emb_size": 128,
    
    # Training
    "batch_size_train": 32,
    "batch_size_eval": 64,
    "max_epochs": 100,
    "learning_rate": 0.01,
    "weight_decay": 4e-05,
    
    # Loss weights
    "concept_loss_weight": 1.0,
    "training_intervention_prob": 0.25,
    
    # Focal Loss (set use_focal_loss=True to enable)
    "use_focal_loss": False,     # Set to True to use Focal Loss instead of BCE
    "focal_loss_alpha": 0.17,    # Proportion of positive class (83/486 ≈ 0.17)
    "focal_loss_gamma": 2.0,     # Focusing parameter (higher = more focus on hard examples)
}

print("✓ Hyperparameters configured:")
for k, v in HYPERPARAMS.items():
    print(f"  {k}: {v}")

✓ Hyperparameters configured:
  k_posts: 20
  sbert_model: all-MiniLM-L6-v2
  embedding_dim: 384
  n_concepts: 21
  n_tasks: 1
  emb_size: 128
  batch_size_train: 32
  batch_size_eval: 64
  max_epochs: 100
  learning_rate: 0.01
  weight_decay: 4e-05
  concept_loss_weight: 1.0
  training_intervention_prob: 0.25
  use_focal_loss: False
  focal_loss_alpha: 0.17
  focal_loss_gamma: 2.0


## Section 1: Load Training Data

Extract 486 training subjects with posts and concept labels

In [8]:
# Helper functions for XML parsing
WHITESPACE_RE = re.compile(r"\s+")

def normalize_text(text):
    """Normalize text by removing null chars and extra whitespace."""
    if not text:
        return ""
    text = text.replace("\u0000", "")
    text = WHITESPACE_RE.sub(" ", text).strip()
    return text

def extract_posts_from_xml(xml_path, min_chars=10):
    """Extract posts from a single XML file."""
    try:
        tree = ET.parse(xml_path)
        root = tree.getroot()
    except Exception as e:
        print(f"WARNING: Failed to parse {xml_path}: {e}")
        return []
    
    posts = []
    for writing in root.findall("WRITING"):
        title = writing.findtext("TITLE") or ""
        text = writing.findtext("TEXT") or ""
        
        combined = normalize_text(f"{title} {text}".strip())
        if len(combined) >= min_chars:
            posts.append(combined)
    
    return posts

print("✓ Helper functions defined")

✓ Helper functions defined


In [9]:
# Parse training XML files
print("Loading training data...")
start_time = time.time()

train_data = []

# Process positive examples
print("  Processing positive examples...")
pos_files = glob.glob(os.path.join(POS_DIR, "**", "*.xml"), recursive=True)
for xml_file in pos_files:
    filename = os.path.basename(xml_file)
    # Extract subject_id (e.g., train_subject1095_1.xml -> subject1095)
    match = re.match(r"train_(subject\d+)_\d+\.xml", filename)
    if match:
        subject_id = match.group(1)
        posts = extract_posts_from_xml(xml_file)
        for post in posts:
            train_data.append({
                "subject_id": subject_id,
                "label": 1,  # Positive (depression)
                "text": post
            })

print(f"  Loaded {len([d for d in train_data if d['label']==1])} posts from positive subjects")

# Process negative examples
print("  Processing negative examples...")
neg_files = glob.glob(os.path.join(NEG_DIR, "**", "*.xml"), recursive=True)
for xml_file in neg_files:
    filename = os.path.basename(xml_file)
    match = re.match(r"train_(subject\d+)_\d+\.xml", filename)
    if match:
        subject_id = match.group(1)
        posts = extract_posts_from_xml(xml_file)
        for post in posts:
            train_data.append({
                "subject_id": subject_id,
                "label": 0,  # Negative (control)
                "text": post
            })

train_posts_df = pd.DataFrame(train_data)

print(f"\n✓ Loaded training data in {time.time()-start_time:.1f}s")
print(f"  Total posts: {len(train_posts_df):,}")
print(f"  Unique subjects: {train_posts_df['subject_id'].nunique()}")
print(f"  Label distribution:")
print(train_posts_df.groupby('label')['subject_id'].nunique())

Loading training data...
  Processing positive examples...
  Loaded 29868 posts from positive subjects
  Processing negative examples...

✓ Loaded training data in 2.8s
  Total posts: 286,740
  Unique subjects: 486
  Label distribution:
label
0    403
1     83
Name: subject_id, dtype: int64


In [10]:
# Load concept labels from questionnaires
print("Loading concept labels...")

concepts_df = pd.read_csv(CONCEPTS_FILE)
# Clean subject_id (remove "train_" prefix)
concepts_df["subject_id"] = concepts_df["Subject"].str.replace("train_", "", regex=True)

# Binarize concept values (any value > 0 becomes 1)
concept_cols = [col for col in concepts_df.columns if col in CONCEPT_NAMES]
for col in concept_cols:
    concepts_df[col] = (concepts_df[col] > 0).astype(int)

print(f"✓ Loaded concept labels for {len(concepts_df)} subjects")
print(f"  Concept columns: {len(concept_cols)}")
print(f"  First few subjects:")
print(concepts_df[["subject_id", "Diagnosis"] + concept_cols[:3]].head())

Loading concept labels...
✓ Loaded concept labels for 486 subjects
  Concept columns: 21
  First few subjects:
    subject_id  Diagnosis  Sadness  Pessimism  Past failure
0  subject4550          0        0          0             0
1  subject4181          0        0          0             0
2  subject8202          0        0          0             0
3  subject6783          0        0          0             0
4  subject1642          0        0          0             0


## Section 2: Load Test Data

Extract 401 test subjects and split into validation (200) and test (201)

In [11]:
# Extract test ZIP files to temporary directory
print("Extracting test data...")
temp_dir = tempfile.mkdtemp(prefix="test_chunks_")
print(f"  Temp directory: {temp_dir}")

for i in range(1, 11):
    zip_path = os.path.join(TEST_DIR, f"chunk {i}.zip")
    if os.path.exists(zip_path):
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(os.path.join(temp_dir, f"chunk_{i}"))
        if i % 3 == 0:
            print(f"  Extracted chunk {i}/10")
    else:
        print(f"  WARNING: {zip_path} not found")

print("✓ Test data extracted")

Extracting test data...
  Temp directory: /var/folders/gb/m6c_r5xx6_14p7mlfjwk29900000gn/T/test_chunks_804t3r5a


  Extracted chunk 3/10
  Extracted chunk 6/10
  Extracted chunk 9/10
✓ Test data extracted


In [12]:
# Load test labels
test_labels_df = pd.read_csv(TEST_LABELS, sep='\t', header=None, names=['subject_id', 'label'])
test_labels_df['subject_id'] = test_labels_df['subject_id'].str.strip()

print(f"✓ Loaded test labels for {len(test_labels_df)} subjects")
print(f"  Label distribution:")
print(test_labels_df['label'].value_counts())

✓ Loaded test labels for 401 subjects
  Label distribution:
label
0    349
1     52
Name: count, dtype: int64


In [13]:
# Parse test XML files
print("Loading test posts...")
test_data = []

test_xml_files = glob.glob(os.path.join(temp_dir, "**", "*.xml"), recursive=True)
print(f"  Found {len(test_xml_files)} XML files")

for xml_file in test_xml_files:
    filename = os.path.basename(xml_file)
    match = re.match(r"(test_subject\d+)_\d+\.xml", filename)
    if match:
        subject_id = match.group(1)
        # Get label from test_labels_df
        label_row = test_labels_df[test_labels_df['subject_id'] == subject_id]
        if len(label_row) > 0:
            label = label_row.iloc[0]['label']
            posts = extract_posts_from_xml(xml_file)
            for post in posts:
                test_data.append({
                    "subject_id": subject_id,
                    "label": label,
                    "text": post
                })

test_posts_df = pd.DataFrame(test_data)

print(f"✓ Loaded test posts")
print(f"  Total posts: {len(test_posts_df):,}")
print(f"  Unique subjects: {test_posts_df['subject_id'].nunique()}")

Loading test posts...
  Found 4010 XML files
✓ Loaded test posts
  Total posts: 229,746
  Unique subjects: 401


In [14]:
# Split test data into validation and test sets (stratified 50/50)
print("Splitting test data into validation and test...")

# Get unique subjects with labels
test_subjects = test_posts_df.groupby('subject_id')['label'].first().reset_index()

# Stratified split
val_subjects, test_subjects_final = train_test_split(
    test_subjects['subject_id'],
    test_size=0.5,
    stratify=test_subjects['label'],
    random_state=SEED
)

# Create validation and test dataframes
val_posts_df = test_posts_df[test_posts_df['subject_id'].isin(val_subjects)].copy()
test_posts_df_final = test_posts_df[test_posts_df['subject_id'].isin(test_subjects_final)].copy()

print(f"✓ Split complete")
print(f"  Validation: {val_posts_df['subject_id'].nunique()} subjects, {len(val_posts_df)} posts")
print(f"  Test: {test_posts_df_final['subject_id'].nunique()} subjects, {len(test_posts_df_final)} posts")
print(f"\n  Validation label distribution:")
print(val_posts_df.groupby('label')['subject_id'].nunique())
print(f"\n  Test label distribution:")
print(test_posts_df_final.groupby('label')['subject_id'].nunique())

Splitting test data into validation and test...
✓ Split complete
  Validation: 200 subjects, 116544 posts
  Test: 201 subjects, 113202 posts

  Validation label distribution:
label
0    174
1     26
Name: subject_id, dtype: int64

  Test label distribution:
label
0    175
1     26
Name: subject_id, dtype: int64


## Section 3: SBERT Setup & Concept Embeddings

In [15]:
# Load SBERT model
print(f"Loading SBERT model: {HYPERPARAMS['sbert_model']}")
sbert_model = SentenceTransformer(HYPERPARAMS['sbert_model'])
sbert_model = sbert_model.to(DEVICE)

print(f"✓ SBERT model loaded on {DEVICE}")
print(f"  Embedding dimension: {sbert_model.get_sentence_embedding_dimension()}")

Loading SBERT model: all-MiniLM-L6-v2
✓ SBERT model loaded on mps
  Embedding dimension: 384


In [16]:
# Create concept embeddings
print(f"Creating embeddings for {N_CONCEPTS} concepts...")
concept_embeddings = sbert_model.encode(
    CONCEPT_NAMES,
    convert_to_tensor=True,
    show_progress_bar=False
)

print(f"✓ Concept embeddings created")
print(f"  Shape: {concept_embeddings.shape}")

Creating embeddings for 21 concepts...
✓ Concept embeddings created
  Shape: torch.Size([21, 384])


## Section 4: Post Retrieval (Top-20 per Subject)

Select the 20 most concept-relevant posts for each subject

In [17]:
def retrieve_top_k_posts(subject_id, posts_df, concept_embs, sbert, k=20):
    """
    Retrieve top-k posts for a subject based on concept similarity.
    
    Returns:
        List of k selected post texts
    """
    # Get subject's posts
    subj_posts = posts_df[posts_df['subject_id'] == subject_id]['text'].tolist()
    
    if len(subj_posts) == 0:
        return []
    
    # Handle subjects with fewer than k posts
    if len(subj_posts) <= k:
        if len(subj_posts) < k:
            # Pad with random duplicates
            extra_needed = k - len(subj_posts)
            padding = list(np.random.choice(subj_posts, size=extra_needed, replace=True))
            return subj_posts + padding
        else:
            return subj_posts
    
    # Encode all subject's posts
    post_embeddings = sbert.encode(
        subj_posts,
        convert_to_tensor=True,
        show_progress_bar=False
    )
    
    # Compute cosine similarity: (num_posts, num_concepts)
    cos_scores = util.cos_sim(post_embeddings, concept_embs)
    
    # For each post, take max similarity across all concepts
    max_sim_scores = cos_scores.max(dim=1).values.cpu().numpy()
    
    # Select top-k posts
    top_k_indices = np.argpartition(-max_sim_scores, range(min(k, len(subj_posts))))[:k]
    
    selected_posts = [subj_posts[i] for i in top_k_indices]
    
    return selected_posts

print("✓ Post retrieval function defined")

✓ Post retrieval function defined


In [18]:
# Retrieve top-k posts for all subjects
print(f"Retrieving top-{HYPERPARAMS['k_posts']} posts for all subjects...")
start_time = time.time()

# Training subjects
print("  Processing training subjects...")
train_selected = {}
train_subjects = train_posts_df['subject_id'].unique()
for idx, subject_id in enumerate(train_subjects):
    selected = retrieve_top_k_posts(
        subject_id,
        train_posts_df,
        concept_embeddings,
        sbert_model,
        k=HYPERPARAMS['k_posts']
    )
    train_selected[subject_id] = selected
    
    if (idx + 1) % 100 == 0:
        print(f"    Processed {idx + 1}/{len(train_subjects)} subjects")

# Validation subjects
print("  Processing validation subjects...")
val_selected = {}
val_subjects = val_posts_df['subject_id'].unique()
for idx, subject_id in enumerate(val_subjects):
    selected = retrieve_top_k_posts(
        subject_id,
        val_posts_df,
        concept_embeddings,
        sbert_model,
        k=HYPERPARAMS['k_posts']
    )
    val_selected[subject_id] = selected
    
    if (idx + 1) % 50 == 0:
        print(f"    Processed {idx + 1}/{len(val_subjects)} subjects")

# Test subjects
print("  Processing test subjects...")
test_selected = {}
test_subjects = test_posts_df_final['subject_id'].unique()
for idx, subject_id in enumerate(test_subjects):
    selected = retrieve_top_k_posts(
        subject_id,
        test_posts_df_final,
        concept_embeddings,
        sbert_model,
        k=HYPERPARAMS['k_posts']
    )
    test_selected[subject_id] = selected
    
    if (idx + 1) % 50 == 0:
        print(f"    Processed {idx + 1}/{len(test_subjects)} subjects")

print(f"\n✓ Post retrieval complete in {time.time()-start_time:.1f}s")
print(f"  Train: {len(train_selected)} subjects x {HYPERPARAMS['k_posts']} posts")
print(f"  Val: {len(val_selected)} subjects x {HYPERPARAMS['k_posts']} posts")
print(f"  Test: {len(test_selected)} subjects x {HYPERPARAMS['k_posts']} posts")

Retrieving top-20 posts for all subjects...
  Processing training subjects...
    Processed 100/486 subjects
    Processed 200/486 subjects
    Processed 300/486 subjects
    Processed 400/486 subjects
  Processing validation subjects...
    Processed 50/200 subjects
    Processed 100/200 subjects
    Processed 150/200 subjects
    Processed 200/200 subjects
  Processing test subjects...
    Processed 50/201 subjects
    Processed 100/201 subjects
    Processed 150/201 subjects
    Processed 200/201 subjects

✓ Post retrieval complete in 2327.5s
  Train: 486 subjects x 20 posts
  Val: 200 subjects x 20 posts
  Test: 201 subjects x 20 posts


## Section 5: Embedding Aggregation (Simple Averaging)

Encode the selected posts and average them into single embeddings per subject

In [19]:
def encode_and_average(selected_posts_dict, sbert, batch_size=64):
    """
    Encode selected posts and average them per subject.
    
    Args:
        selected_posts_dict: {subject_id: [post1, post2, ...], ...}
    
    Returns:
        averaged_embeddings: (n_subjects, embedding_dim)
        subject_ids: list of subject_ids in same order
    """
    subject_ids = list(selected_posts_dict.keys())
    averaged_embeddings = []
    
    for subject_id in subject_ids:
        posts = selected_posts_dict[subject_id]
        
        # Encode all posts for this subject
        post_embs = sbert.encode(
            posts,
            convert_to_numpy=True,
            show_progress_bar=False
        )
        
        # Average across posts (axis=0)
        avg_emb = np.mean(post_embs, axis=0)
        averaged_embeddings.append(avg_emb)
    
    return np.array(averaged_embeddings), subject_ids

print("✓ Encoding and averaging function defined")

✓ Encoding and averaging function defined


In [20]:
# Encode and average for all splits
print("Encoding and averaging embeddings...")
start_time = time.time()

print("  Training set...")
X_train, train_subject_ids = encode_and_average(train_selected, sbert_model)
print(f"    X_train shape: {X_train.shape}")

print("  Validation set...")
X_val, val_subject_ids = encode_and_average(val_selected, sbert_model)
print(f"    X_val shape: {X_val.shape}")

print("  Test set...")
X_test, test_subject_ids = encode_and_average(test_selected, sbert_model)
print(f"    X_test shape: {X_test.shape}")

print(f"\n✓ Encoding complete in {time.time()-start_time:.1f}s")

Encoding and averaging embeddings...
  Training set...
    X_train shape: (486, 384)
  Validation set...
    X_val shape: (200, 384)
  Test set...
    X_test shape: (201, 384)

✓ Encoding complete in 502.0s


In [21]:
# Build concept matrices and label vectors
print("Building concept matrices and labels...")

# Training: get concepts from questionnaires
C_train = []
y_train = []
for subject_id in train_subject_ids:
    # Get label
    label = train_posts_df[train_posts_df['subject_id'] == subject_id]['label'].iloc[0]
    y_train.append(label)
    
    # Get concepts from questionnaire
    concept_row = concepts_df[concepts_df['subject_id'] == subject_id]
    if len(concept_row) > 0:
        concepts = concept_row[concept_cols].values[0]
    else:
        # If missing, use zeros
        concepts = np.zeros(N_CONCEPTS)
    C_train.append(concepts)

C_train = np.array(C_train, dtype=np.float32)
y_train = np.array(y_train, dtype=np.float32)

# Validation: zeros for concepts (no ground truth)
C_val = np.zeros((len(val_subject_ids), N_CONCEPTS), dtype=np.float32)
y_val = []
for subject_id in val_subject_ids:
    label = val_posts_df[val_posts_df['subject_id'] == subject_id]['label'].iloc[0]
    y_val.append(label)
y_val = np.array(y_val, dtype=np.float32)

# Test: zeros for concepts
C_test = np.zeros((len(test_subject_ids), N_CONCEPTS), dtype=np.float32)
y_test = []
for subject_id in test_subject_ids:
    label = test_posts_df_final[test_posts_df_final['subject_id'] == subject_id]['label'].iloc[0]
    y_test.append(label)
y_test = np.array(y_test, dtype=np.float32)

print("✓ Matrices built")
print(f"  Train: X={X_train.shape}, C={C_train.shape}, y={y_train.shape}")
print(f"  Val:   X={X_val.shape}, C={C_val.shape}, y={y_val.shape}")
print(f"  Test:  X={X_test.shape}, C={C_test.shape}, y={y_test.shape}")
print(f"\n  Training label distribution: {np.bincount(y_train.astype(int))}")
print(f"  Validation label distribution: {np.bincount(y_val.astype(int))}")
print(f"  Test label distribution: {np.bincount(y_test.astype(int))}")

Building concept matrices and labels...
✓ Matrices built
  Train: X=(486, 384), C=(486, 21), y=(486,)
  Val:   X=(200, 384), C=(200, 21), y=(200,)
  Test:  X=(201, 384), C=(201, 21), y=(201,)

  Training label distribution: [403  83]
  Validation label distribution: [174  26]
  Test label distribution: [175  26]


### Class Imbalance Handling

Compute class weights to address the imbalance (403 negative vs 83 positive)

In [22]:
# Compute class weights for imbalanced dataset
n_negative = int(np.sum(y_train == 0))
n_positive = int(np.sum(y_train == 1))
pos_weight = n_negative / n_positive

# Convert to tensor
pos_weight_tensor = torch.tensor([pos_weight], dtype=torch.float32)

print(f"Class imbalance:")
print(f"  Negative samples: {n_negative}")
print(f"  Positive samples: {n_positive}")
print(f"  Ratio: 1:{pos_weight:.2f}")
print(f"  Computed pos_weight: {pos_weight:.4f}")

Class imbalance:
  Negative samples: 403
  Positive samples: 83
  Ratio: 1:4.86
  Computed pos_weight: 4.8554


### Save Datasets

Save the processed datasets to disk for later use

In [23]:
# Save processed datasets to disk for later use
SAVE_DIR = os.path.join(DATA_PROCESSED, "whole_pipeline")
os.makedirs(SAVE_DIR, exist_ok=True)

print("Saving datasets...")

# Save numpy arrays
np.savez_compressed(
    os.path.join(SAVE_DIR, "train_data.npz"),
    X=X_train,
    C=C_train,
    y=y_train,
    subject_ids=np.array(train_subject_ids)
)

np.savez_compressed(
    os.path.join(SAVE_DIR, "val_data.npz"),
    X=X_val,
    C=C_val,
    y=y_val,
    subject_ids=np.array(val_subject_ids)
)

np.savez_compressed(
    os.path.join(SAVE_DIR, "test_data.npz"),
    X=X_test,
    C=C_test,
    y=y_test,
    subject_ids=np.array(test_subject_ids)
)

# Save class weights info
class_info = {
    "n_positive": n_positive,
    "n_negative": n_negative,
    "pos_weight": float(pos_weight)
}

with open(os.path.join(SAVE_DIR, "class_weights.json"), 'w') as f:
    json.dump(class_info, f, indent=4)

print(f"✓ Datasets saved to {SAVE_DIR}")
print(f"  train_data.npz: {X_train.shape[0]} samples")
print(f"  val_data.npz:   {X_val.shape[0]} samples")
print(f"  test_data.npz:  {X_test.shape[0]} samples")
print(f"  class_weights.json")
print("\nTo load later:")
print("  train = np.load(os.path.join(DATA_PROCESSED, 'whole_pipeline/train_data.npz'))")
print("  X_train = train['X']")
print("  C_train = train['C']")
print("  y_train = train['y']")
print("  train_subject_ids = train['subject_ids']")

Saving datasets...
✓ Datasets saved to /Users/gualtieromarencoturi/Desktop/thesis/Master-Thesis-CEM-Depression-etc-case-study/data/processed/whole_pipeline
  train_data.npz: 486 samples
  val_data.npz:   200 samples
  test_data.npz:  201 samples
  class_weights.json

To load later:
  train = np.load(os.path.join(DATA_PROCESSED, 'whole_pipeline/train_data.npz'))
  X_train = train['X']
  C_train = train['C']
  y_train = train['y']
  train_subject_ids = train['subject_ids']


## Section 6: PyTorch Dataset & DataLoaders

In [24]:
class CEMDataset(Dataset):
    """PyTorch Dataset for CEM model."""
    
    def __init__(self, X, C, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.C = torch.tensor(C, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)
    
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx], self.C[idx]

print("✓ CEMDataset class defined")

✓ CEMDataset class defined


In [25]:
# Create datasets
train_dataset = CEMDataset(X_train, C_train, y_train)
val_dataset = CEMDataset(X_val, C_val, y_val)
test_dataset = CEMDataset(X_test, C_test, y_test)

print("✓ Datasets created")
print(f"  Train: {len(train_dataset)} samples")
print(f"  Val: {len(val_dataset)} samples")
print(f"  Test: {len(test_dataset)} samples")

✓ Datasets created
  Train: 486 samples
  Val: 200 samples
  Test: 201 samples


In [26]:
# Create DataLoaders
train_loader = DataLoader(
    train_dataset,
    batch_size=HYPERPARAMS['batch_size_train'],
    shuffle=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=HYPERPARAMS['batch_size_eval'],
    shuffle=False
)

test_loader = DataLoader(
    test_dataset,
    batch_size=HYPERPARAMS['batch_size_eval'],
    shuffle=False
)

print("✓ DataLoaders created")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")
print(f"  Test batches: {len(test_loader)}")

# Test batch
x_batch, y_batch, c_batch = next(iter(train_loader))
print(f"\n  Sample batch shapes:")
print(f"    X: {x_batch.shape}")
print(f"    y: {y_batch.shape}")
print(f"    C: {c_batch.shape}")

✓ DataLoaders created
  Train batches: 16
  Val batches: 4
  Test batches: 4

  Sample batch shapes:
    X: torch.Size([32, 384])
    y: torch.Size([32])
    C: torch.Size([32, 21])


## Section 7: CEM Model Initialization

In [27]:
def c_extractor_arch(output_dim):
    """Concept extractor architecture."""
    return nn.Sequential(
        nn.Linear(HYPERPARAMS['embedding_dim'], 256),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(256, output_dim or 256)
    )

print("✓ Concept extractor architecture defined")

✓ Concept extractor architecture defined


In [28]:
# Initialize CEM model
cem_model = PatchedConceptEmbeddingModel(
    n_concepts=HYPERPARAMS['n_concepts'],
    n_tasks=HYPERPARAMS['n_tasks'],
    input_dim=HYPERPARAMS['embedding_dim'],
    emb_size=HYPERPARAMS['emb_size'],
    concept_loss_weight=HYPERPARAMS['concept_loss_weight'],
    training_intervention_prob=HYPERPARAMS['training_intervention_prob'],
    c_extractor_arch=c_extractor_arch,
    learning_rate=HYPERPARAMS['learning_rate'],
    weight_decay=HYPERPARAMS['weight_decay'],
    c2y_model=None,
    task_class_weights=pos_weight_tensor,  # Use class weights for imbalanced data
    use_focal_loss=HYPERPARAMS['use_focal_loss'],
    focal_loss_alpha=HYPERPARAMS['focal_loss_alpha'],
    focal_loss_gamma=HYPERPARAMS['focal_loss_gamma']
)

print("✓ CEM model initialized")
if HYPERPARAMS['use_focal_loss']:
    print(f"  Using Focal Loss (alpha={HYPERPARAMS['focal_loss_alpha']}, gamma={HYPERPARAMS['focal_loss_gamma']})")
else:
    print(f"  Using BCE Loss with pos_weight={pos_weight:.4f}")
print(cem_model)

✓ CEM model initialized
  Using BCE Loss with pos_weight=4.8554
PatchedConceptEmbeddingModel(
  (pre_concept_model): Sequential(
    (0): Linear(in_features=384, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=256, out_features=256, bias=True)
  )
  (concept_context_generators): ModuleList(
    (0-20): 21 x Sequential(
      (0): Linear(in_features=256, out_features=256, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
    )
  )
  (concept_prob_generators): ModuleList(
    (0): Linear(in_features=256, out_features=1, bias=True)
  )
  (c2y_model): Sequential(
    (0): Linear(in_features=2688, out_features=1, bias=True)
  )
  (loss_concept): BCEWithLogitsLoss()
  (loss_task): BCEWithLogitsLoss()
)


## Section 8: Training

In [29]:
# Setup trainer
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath=os.path.join(OUTPUT_DIR, "models"),
    filename="cem-{epoch:02d}-{val_loss:.2f}",
    save_top_k=1,
    mode="min"
)

trainer = pl.Trainer(
    max_epochs=HYPERPARAMS['max_epochs'],
    accelerator=DEVICE,
    devices=1,
    logger=CSVLogger(save_dir=os.path.join(OUTPUT_DIR, "logs"), name="cem_pipeline"),
    log_every_n_steps=10,
    callbacks=[checkpoint_callback],
    enable_progress_bar=True
)

print("✓ Trainer configured")
print(f"  Device: {DEVICE}")
print(f"  Max epochs: {HYPERPARAMS['max_epochs']}")

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


✓ Trainer configured
  Device: mps
  Max epochs: 100


In [30]:
# Train model
print("\nStarting training...\n")
print("="*70)

trainer.fit(cem_model, train_loader, val_loader)

print("="*70)
print("\n✓ Training complete!")


Starting training...



  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

  | Name                       | Type              | Params
-----------------------------------------------------------------
0 | pre_concept_model          | Sequential        | 164 K 
1 | concept_context_generators | ModuleList        | 1.4 M 
2 | concept_prob_generators    | ModuleList        | 257   
3 | c2y_model                  | Sequential        | 2.7 K 
4 | loss_concept               | BCEWithLogitsLoss | 0     
5 | loss_task                  | BCEWithLogitsLoss | 0     
-----------------------------------------------------------------
1.5 M     Trainable params
0         Non-trainable params
1.5 M     Total params
6.196     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.



✓ Training complete!


## Section 9: Test Evaluation

In [31]:
# Set model to evaluation mode
cem_model.eval()

# Move model to device
device_obj = torch.device(DEVICE)
cem_model = cem_model.to(device_obj)

print("✓ Model set to evaluation mode")

✓ Model set to evaluation mode


In [32]:
# Run inference on test set
print("Running inference on test set...")

y_true_list = []
y_pred_list = []
y_prob_list = []
concept_probs_list = []

with torch.no_grad():
    for x_batch, y_batch, c_batch in test_loader:
        x_batch = x_batch.to(device_obj)
        
        # Forward pass
        c_logits, _, y_logits = cem_model(x_batch)
        
        # Apply sigmoid to get probabilities
        c_probs = torch.sigmoid(c_logits).cpu().numpy()
        y_probs = torch.sigmoid(y_logits).cpu().squeeze().numpy()
        
        # Threshold at 0.5 for predictions
        y_pred = (y_probs >= 0.5).astype(int)
        
        # Collect results
        y_true_list.extend(y_batch.numpy().astype(int).tolist())
        y_pred_list.extend(y_pred.tolist() if isinstance(y_pred, np.ndarray) else [y_pred])
        y_prob_list.extend(y_probs.tolist() if isinstance(y_probs, np.ndarray) else [y_probs])
        concept_probs_list.extend(c_probs.tolist())

# Convert to arrays
y_true = np.array(y_true_list)
y_pred = np.array(y_pred_list)
y_prob = np.array(y_prob_list)
concept_probs = np.array(concept_probs_list)

print("✓ Inference complete")
print(f"  Predictions shape: {y_pred.shape}")
print(f"  Concept probs shape: {concept_probs.shape}")

Running inference on test set...
✓ Inference complete
  Predictions shape: (201,)
  Concept probs shape: (201, 21)


## Section 10: Metrics & Results Display

In [33]:
# Compute all metrics
print("Computing metrics...")

# Confusion matrix
cm = confusion_matrix(y_true, y_pred)
tn, fp, fn, tp = cm.ravel()

# Metrics
acc = accuracy_score(y_true, y_pred)
balanced_acc = balanced_accuracy_score(y_true, y_pred)
roc_auc = roc_auc_score(y_true, y_prob)
mcc = matthews_corrcoef(y_true, y_pred)

f1_binary = f1_score(y_true, y_pred, pos_label=1)
f1_macro = f1_score(y_true, y_pred, average='macro')
f1_micro = f1_score(y_true, y_pred, average='micro')

precision_binary = precision_score(y_true, y_pred, pos_label=1)
recall_binary = recall_score(y_true, y_pred, pos_label=1)

print("✓ Metrics computed")

Computing metrics...
✓ Metrics computed


In [34]:
# Print formatted results
print("\n" + "="*70)
print("                    TEST SET EVALUATION")
print("="*70)
print()
print(f"Dataset Statistics:")
print(f"  Test subjects:        {len(y_true)}")
print(f"  Positive cases:       {np.sum(y_true)} ({100*np.sum(y_true)/len(y_true):.1f}%)")
print(f"  Negative cases:       {len(y_true)-np.sum(y_true)} ({100*(len(y_true)-np.sum(y_true))/len(y_true):.1f}%)")
print()
print(f"Performance Metrics:")
print(f"  Accuracy:                  {acc:.4f}")
print(f"  Balanced Accuracy:         {balanced_acc:.4f}")
print(f"  ROC-AUC:                   {roc_auc:.4f}")
print(f"  Matthews Correlation:      {mcc:.4f}")
print()
print(f"  F1 Score (Binary):         {f1_binary:.4f}")
print(f"  F1 Score (Macro):          {f1_macro:.4f}")
print(f"  F1 Score (Micro):          {f1_micro:.4f}")
print()
print(f"  Precision (Binary):        {precision_binary:.4f}")
print(f"  Recall (Binary):           {recall_binary:.4f}")
print()
print(f"Confusion Matrix:")
print(f"                    Predicted Neg    Predicted Pos")
print(f"Actual Neg          {tn:^16d} {fp:^16d}")
print(f"Actual Pos          {fn:^16d} {tp:^16d}")
print()
print("Classification Report:")
print(classification_report(y_true, y_pred, target_names=['Negative', 'Positive']))
print("="*70)


                    TEST SET EVALUATION

Dataset Statistics:
  Test subjects:        201
  Positive cases:       26 (12.9%)
  Negative cases:       175 (87.1%)

Performance Metrics:
  Accuracy:                  0.8507
  Balanced Accuracy:         0.6851
  ROC-AUC:                   0.8798
  Matthews Correlation:      0.3587

  F1 Score (Binary):         0.4444
  F1 Score (Macro):          0.6791
  F1 Score (Micro):          0.8507

  Precision (Binary):        0.4286
  Recall (Binary):           0.4615

Confusion Matrix:
                    Predicted Neg    Predicted Pos
Actual Neg                159               16       
Actual Pos                 14               12       

Classification Report:
              precision    recall  f1-score   support

    Negative       0.92      0.91      0.91       175
    Positive       0.43      0.46      0.44        26

    accuracy                           0.85       201
   macro avg       0.67      0.69      0.68       201
weighted avg     

In [35]:
# Save metrics to JSON
metrics_dict = {
    "n_samples": int(len(y_true)),
    "n_positive": int(np.sum(y_true)),
    "n_negative": int(len(y_true) - np.sum(y_true)),
    "accuracy": float(acc),
    "balanced_accuracy": float(balanced_acc),
    "roc_auc": float(roc_auc),
    "mcc": float(mcc),
    "f1_binary": float(f1_binary),
    "f1_macro": float(f1_macro),
    "f1_micro": float(f1_micro),
    "precision_binary": float(precision_binary),
    "recall_binary": float(recall_binary),
    "confusion_matrix": {
        "tn": int(tn),
        "fp": int(fp),
        "fn": int(fn),
        "tp": int(tp)
    }
}

os.makedirs(os.path.join(OUTPUT_DIR, "results"), exist_ok=True)
metrics_path = os.path.join(OUTPUT_DIR, "results/test_metrics.json")

with open(metrics_path, 'w') as f:
    json.dump(metrics_dict, f, indent=4)

print(f"✓ Metrics saved to {metrics_path}")

✓ Metrics saved to outputs/results/test_metrics.json


In [36]:
# Create predictions DataFrame with concept probabilities
predictions_df = pd.DataFrame({
    'subject_id': test_subject_ids,
    'y_true': y_true,
    'y_pred': y_pred,
    'y_prob': y_prob
})

# Add concept probabilities
for i, concept_name in enumerate(CONCEPT_NAMES):
    predictions_df[concept_name] = concept_probs[:, i]

# Save to CSV
predictions_path = os.path.join(OUTPUT_DIR, "results/test_predictions.csv")
predictions_df.to_csv(predictions_path, index=False)

print(f"✓ Predictions saved to {predictions_path}")
print(f"\nFirst 10 subjects with concept probabilities:")
print(predictions_df.head(10))

✓ Predictions saved to outputs/results/test_predictions.csv

First 10 subjects with concept probabilities:
         subject_id  y_true  y_pred        y_prob   Sadness  Pessimism  \
0  test_subject4471       1       0  2.887360e-03  0.041965   0.097764   
1  test_subject8981       0       0  2.285121e-09  0.005434   0.058009   
2  test_subject8777       0       0  2.217279e-21  0.000088   0.014087   
3  test_subject1372       0       0  4.164572e-10  0.004552   0.055387   
4  test_subject1830       0       0  1.097032e-05  0.014577   0.037514   
5  test_subject3791       0       0  3.811828e-21  0.000133   0.017802   
6  test_subject2284       0       0  1.210110e-09  0.004736   0.049253   
7  test_subject5689       0       0  3.861887e-12  0.002627   0.033180   
8  test_subject7467       1       0  1.204088e-08  0.007306   0.068030   
9  test_subject7578       0       0  4.536292e-26  0.000011   0.002672   

   Past failure  Loss of pleasure  Guilty feelings  Punishment feelings  ...  

In [37]:
# Display concept activation statistics
print("\nConcept Activation Statistics:")
print("="*70)
print(f"{'Concept':<35} {'Mean':>10} {'Std':>10} {'Max':>10}")
print("-"*70)
for i, concept_name in enumerate(CONCEPT_NAMES):
    mean_act = np.mean(concept_probs[:, i])
    std_act = np.std(concept_probs[:, i])
    max_act = np.max(concept_probs[:, i])
    print(f"{concept_name:<35} {mean_act:>10.4f} {std_act:>10.4f} {max_act:>10.4f}")
print("="*70)


Concept Activation Statistics:
Concept                                   Mean        Std        Max
----------------------------------------------------------------------
Sadness                                 0.0356     0.0693     0.3065
Pessimism                               0.1028     0.1734     0.7077
Past failure                            0.0968     0.1896     0.7864
Loss of pleasure                        0.0262     0.0619     0.3623
Guilty feelings                         0.0196     0.0479     0.2155
Punishment feelings                     0.0507     0.0846     0.3506
Self-dislike                            0.1501     0.2114     0.9281
Self-criticalness                       0.0467     0.0904     0.3340
Suicidal thoughts or wishes             0.0138     0.0351     0.1587
Crying                                  0.0091     0.0229     0.1382
Agitation                               0.0066     0.0127     0.0497
Loss of interest                        0.0526     0.0954     0.4219


## Section 11: Cleanup

In [38]:
# Clean up temporary directory
try:
    shutil.rmtree(temp_dir)
    print(f"✓ Cleaned up temporary directory: {temp_dir}")
except Exception as e:
    print(f"⚠ Failed to clean up temporary directory: {e}")

✓ Cleaned up temporary directory: /var/folders/gb/m6c_r5xx6_14p7mlfjwk29900000gn/T/test_chunks_804t3r5a


In [39]:
print("\n" + "="*70)
print("              PIPELINE EXECUTION COMPLETE")
print("="*70)
print("\nGenerated files:")
print(f"  Model checkpoint: {OUTPUT_DIR}/models/")
print(f"  Metrics JSON:     {OUTPUT_DIR}/results/test_metrics.json")
print(f"  Predictions CSV:  {OUTPUT_DIR}/results/test_predictions.csv")
print(f"  Training logs:    {OUTPUT_DIR}/logs/")
print("="*70)


              PIPELINE EXECUTION COMPLETE

Generated files:
  Model checkpoint: outputs/models/
  Metrics JSON:     outputs/results/test_metrics.json
  Predictions CSV:  outputs/results/test_predictions.csv
  Training logs:    outputs/logs/
