# Task-Optimized Attention CEM - Unified Pipeline

**Purpose:** End-to-end pipeline with learnable attention optimized for task loss.

**Key Innovation:** Attention weights learned via backpropagation from depression classification loss instead of static concept-based similarity.

**Runtime:** ~1 hour (data prep: 40min, training: 15-20min)

This notebook:
1. Loads and retrieves top-k concept-relevant posts
2. Encodes posts WITHOUT pooling (stores individual embeddings)
3. Trains CEM with learnable attention module
4. Analyzes attention patterns
5. Evaluates on test set

## PART 1: DATA PREPARATION

### Section 0: Configuration & Setup

In [1]:
# Imports
import os
import glob
import re
import zipfile
import tempfile
import shutil
import json
import time

import numpy as np
import pandas as pd
import xml.etree.ElementTree as ET

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

from sentence_transformers import SentenceTransformer, util

import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint

import torchmetrics
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    confusion_matrix,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    matthews_corrcoef,
    roc_auc_score,
    balanced_accuracy_score,
    classification_report,
)

from tqdm import tqdm

print("✓ All imports successful")

  from tqdm.autonotebook import tqdm, trange


✓ All imports successful


In [2]:
# Set random seeds for reproducibility
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
pl.seed_everything(SEED)

print(f"✓ Random seed set to {SEED}")

Global seed set to 42


✓ Random seed set to 42


In [3]:
# Detect device (MPS/CUDA/CPU)
if torch.backends.mps.is_available():
    DEVICE = "mps"
    print("✓ Using MacBook GPU (MPS)")
elif torch.cuda.is_available():
    DEVICE = "cuda"
    print("✓ Using CUDA GPU")
else:
    DEVICE = "cpu"
    print("⚠ Using CPU (will be slow)")

✓ Using MacBook GPU (MPS)


In [4]:
# Define paths
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
DATA_RAW = os.path.join(PROJECT_ROOT, "data/raw")
DATA_PROCESSED = os.path.join(PROJECT_ROOT, "data/processed")

# Training data paths
POS_DIR = os.path.join(DATA_RAW, "train/positive_examples_anonymous_chunks")
NEG_DIR = os.path.join(DATA_RAW, "train/negative_examples_anonymous_chunks")

# Test data paths
TEST_DIR = os.path.join(DATA_RAW, "test")
TEST_LABELS = os.path.join(TEST_DIR, "test_golden_truth.txt")

# Concept labels
CONCEPTS_FILE = os.path.join(DATA_PROCESSED, "merged_questionnaires.csv")

# Output directory
SAVE_DIR = os.path.join(DATA_PROCESSED, "attention_task_optimized")
os.makedirs(SAVE_DIR, exist_ok=True)

OUTPUT_DIR = "outputs_task_optimized_attention"
os.makedirs(OUTPUT_DIR, exist_ok=True)

print("✓ Paths configured")
print(f"  Project root: {PROJECT_ROOT}")
print(f"  Data save dir: {SAVE_DIR}")
print(f"  Output dir: {OUTPUT_DIR}")

✓ Paths configured
  Project root: /Users/gualtieromarencoturi/Desktop/thesis/Master-Thesis-CEM-Depression-etc-case-study
  Data save dir: /Users/gualtieromarencoturi/Desktop/thesis/Master-Thesis-CEM-Depression-etc-case-study/data/processed/attention_task_optimized
  Output dir: outputs_task_optimized_attention


In [5]:
# Define 21 BDI-II concept names
CONCEPT_NAMES = [
    "Sadness", "Pessimism", "Past failure", "Loss of pleasure",
    "Guilty feelings", "Punishment feelings", "Self-dislike", "Self-criticalness",
    "Suicidal thoughts or wishes", "Crying", "Agitation", "Loss of interest",
    "Indecisiveness", "Worthlessness", "Loss of energy", "Changes in sleeping pattern",
    "Irritability", "Changes in appetite", "Concentration difficulty",
    "Tiredness or fatigue", "Loss of interest in sex"
]
N_CONCEPTS = len(CONCEPT_NAMES)

print(f"✓ Defined {N_CONCEPTS} BDI-II concepts")

✓ Defined 21 BDI-II concepts


In [6]:
# Hyperparameters
HYPERPARAMS = {
    # Data preparation
    "k_posts": 50,              # Top-k posts per subject
    "sbert_model": "all-MiniLM-L6-v2",
    "post_dim": 384,            # SBERT embedding dimension
    
    # Model architecture
    "n_concepts": 21,
    "n_tasks": 1,
    "emb_size": 128,
    "attention_hidden": 128,    # Attention layer hidden dim
    
    # CEM-specific
    "shared_prob_gen": True,
    "intervention_prob": 0.25,
    
    # Training
    "batch_size_train": 32,
    "batch_size_eval": 64,
    "max_epochs": 100,
    "learning_rate": 0.01,
    "weight_decay": 4e-05,
    
    # Loss
    "concept_loss_weight": 1.0,
    "task_loss_weight": 1.0,
    
    # LDAM Loss
    "use_ldam_loss": True,
    "ldam_max_margin": 0.6,
    "ldam_scale": 40,
    
    # Weighted Sampler
    "use_weighted_sampler": True,
}

print("✓ Hyperparameters configured:")
for k, v in list(HYPERPARAMS.items())[:5]:
    print(f"  {k}: {v}")
print("  ...")

✓ Hyperparameters configured:
  k_posts: 50
  sbert_model: all-MiniLM-L6-v2
  post_dim: 384
  n_concepts: 21
  n_tasks: 1
  ...


### Section 1: Load Training Data

In [7]:
# Helper functions for XML parsing
import re
import xml.etree.ElementTree as ET

WHITESPACE_RE = re.compile(r"\s+")

def normalize_text(text):
    """Normalize text by removing null chars and extra whitespace."""
    if not text:
        return ""
    text = text.replace("\x00", "")
    text = WHITESPACE_RE.sub(" ", text).strip()
    return text

def extract_posts_from_xml(xml_path, min_chars=10):
    """Extract posts from a single XML file."""
    try:
        tree = ET.parse(xml_path)
        root = tree.getroot()
    except Exception as e:
        print(f"WARNING: Failed to parse {xml_path}: {e}")
        return []
    
    posts = []
    for writing in root.findall("WRITING"):
        title = writing.findtext("TITLE") or ""
        text = writing.findtext("TEXT") or ""
        
        combined = normalize_text(f"{title} {text}".strip())
        if len(combined) >= min_chars:
            posts.append(combined)
    
    return posts

print("✓ Helper functions defined")

✓ Helper functions defined


In [8]:
# Parse training XML files
print("Loading training data...")
start_time = time.time()

train_data = []

# Process positive examples
print("  Processing positive examples...")
pos_files = glob.glob(os.path.join(POS_DIR, "**", "*.xml"), recursive=True)
for xml_file in tqdm(pos_files, desc="Processing positive examples"):
    filename = os.path.basename(xml_file)
    match = re.match(r"train_(subject\d+)_\d+\.xml", filename)
    if match:
        subject_id = match.group(1)
        posts = extract_posts_from_xml(xml_file)
        for post in posts:
            train_data.append({
                "subject_id": subject_id,
                "label": 1,
                "text": post
            })

print(f"  Loaded {sum(d['label'] == 1 for d in train_data)} posts from positive subjects")

# Process negative examples
print("  Processing negative examples...")
neg_files = glob.glob(os.path.join(NEG_DIR, "**", "*.xml"), recursive=True)
for xml_file in tqdm(neg_files, desc="Processing negative examples"):
    filename = os.path.basename(xml_file)
    match = re.match(r"train_(subject\d+)_\d+\.xml", filename)
    if match:
        subject_id = match.group(1)
        posts = extract_posts_from_xml(xml_file)
        for post in posts:
            train_data.append({
                "subject_id": subject_id,
                "label": 0,
                "text": post
            })

train_posts_df = pd.DataFrame(train_data)

print(f"✓ Loaded training data in {time.time()-start_time:.1f}s")
print(f"  Total posts: {len(train_posts_df):,}")
print(f"  Unique subjects: {train_posts_df['subject_id'].nunique()}")

Loading training data...
  Processing positive examples...


Processing positive examples: 100%|██████████| 830/830 [00:00<00:00, 2537.29it/s]



  Loaded 29868 posts from positive subjects
  Processing negative examples...


Processing negative examples: 100%|██████████| 4031/4031 [00:02<00:00, 1733.06it/s]



✓ Loaded training data in 2.8s
  Total posts: 286,740
  Unique subjects: 486


In [9]:
# Load concept labels from questionnaires
print("Loading concept labels...")

concepts_df = pd.read_csv(CONCEPTS_FILE)
concepts_df["subject_id"] = concepts_df["Subject"].str.replace("train_", "", regex=True)

# Binarize concept values
concept_cols = [col for col in concepts_df.columns if col in CONCEPT_NAMES]
for col in concept_cols:
    concepts_df[col] = (concepts_df[col] > 0).astype(int)

print(f"✓ Loaded concept labels for {len(concepts_df)} subjects")

Loading concept labels...
✓ Loaded concept labels for 486 subjects


### Section 2: Load Test Data

In [10]:
# Extract test ZIP files to temporary directory
print("Extracting test data...")
temp_dir = tempfile.mkdtemp(prefix="test_chunks_")
print(f"  Temp directory: {temp_dir}")

for i in range(1, 11):
    zip_path = os.path.join(TEST_DIR, f"chunk {i}.zip")
    if os.path.exists(zip_path):
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(os.path.join(temp_dir, f"chunk_{i}"))
        if i % 3 == 0:
            print(f"  Extracted chunk {i}/10")

print("✓ Test data extracted")

Extracting test data...
  Temp directory: /var/folders/gb/m6c_r5xx6_14p7mlfjwk29900000gn/T/test_chunks_5701uwv8
  Extracted chunk 3/10
  Extracted chunk 3/10
  Extracted chunk 6/10
  Extracted chunk 6/10
  Extracted chunk 9/10
  Extracted chunk 9/10
✓ Test data extracted
✓ Test data extracted


In [11]:
# Load test labels
test_labels_df = pd.read_csv(TEST_LABELS, sep='	', header=None, names=['subject_id', 'label'])
test_labels_df['subject_id'] = test_labels_df['subject_id'].str.strip()

print(f"✓ Loaded test labels for {len(test_labels_df)} subjects")

✓ Loaded test labels for 401 subjects


In [12]:
# Parse test XML files
print("Loading test posts...")
test_data = []

test_xml_files = glob.glob(os.path.join(temp_dir, "**", "*.xml"), recursive=True)
print(f"  Found {len(test_xml_files)} XML files")

for xml_file in test_xml_files:
    filename = os.path.basename(xml_file)
    match = re.match(r"(test_subject\d+)_\d+\.xml", filename)
    if match:
        subject_id = match.group(1)
        label_row = test_labels_df[test_labels_df['subject_id'] == subject_id]
        if len(label_row) > 0:
            label = label_row.iloc[0]['label']
            posts = extract_posts_from_xml(xml_file)
            for post in posts:
                test_data.append({
                    "subject_id": subject_id,
                    "label": label,
                    "text": post
                })

test_posts_df = pd.DataFrame(test_data)

print(f"✓ Loaded test posts")
print(f"  Total posts: {len(test_posts_df):,}")
print(f"  Unique subjects: {test_posts_df['subject_id'].nunique()}")

Loading test posts...
  Found 4010 XML files
✓ Loaded test posts
  Total posts: 229,746
  Unique subjects: 401
✓ Loaded test posts
  Total posts: 229,746
  Unique subjects: 401


In [13]:
# Split test data into validation and test sets (stratified 50/50)
print("Splitting test data into validation and test...")

test_subjects = test_posts_df.groupby('subject_id')['label'].first().reset_index()

val_subjects, test_subjects_final = train_test_split(
    test_subjects['subject_id'],
    test_size=0.5,
    stratify=test_subjects['label'],
    random_state=SEED
)

val_posts_df = test_posts_df[test_posts_df['subject_id'].isin(val_subjects)].copy()
test_posts_df_final = test_posts_df[test_posts_df['subject_id'].isin(test_subjects_final)].copy()

print(f"✓ Split complete")
print(f"  Validation: {val_posts_df['subject_id'].nunique()} subjects")
print(f"  Test: {test_posts_df_final['subject_id'].nunique()} subjects")

Splitting test data into validation and test...
✓ Split complete
  Validation: 200 subjects
  Test: 201 subjects


### Section 3: SBERT Setup & Concept Embeddings

In [14]:
# Load SBERT model
print(f"Loading SBERT model: {HYPERPARAMS['sbert_model']}")
sbert_model = SentenceTransformer(HYPERPARAMS['sbert_model'])
sbert_model = sbert_model.to(DEVICE)

print(f"✓ SBERT model loaded on {DEVICE}")
print(f"  Embedding dimension: {sbert_model.get_sentence_embedding_dimension()}")

Loading SBERT model: all-MiniLM-L6-v2
✓ SBERT model loaded on mps
  Embedding dimension: 384
✓ SBERT model loaded on mps
  Embedding dimension: 384


In [15]:
# Create concept embeddings
print(f"Creating embeddings for {N_CONCEPTS} concepts...")
concept_embeddings = sbert_model.encode(
    CONCEPT_NAMES,
    convert_to_tensor=True,
    show_progress_bar=False
)

print(f"✓ Concept embeddings created")
print(f"  Shape: {concept_embeddings.shape}")

Creating embeddings for 21 concepts...
✓ Concept embeddings created
  Shape: torch.Size([21, 384])
✓ Concept embeddings created
  Shape: torch.Size([21, 384])


### Section 4: Post Retrieval (Top-k per Subject)

In [16]:
def retrieve_top_k_posts(subject_id, posts_df, concept_embs, sbert, k=50):
    """Retrieve top-k posts for a subject based on concept similarity."""
    subj_posts = posts_df[posts_df['subject_id'] == subject_id]['text'].tolist()
    
    if len(subj_posts) == 0:
        return []
    
    # Pad if needed
    if len(subj_posts) <= k:
        if len(subj_posts) < k:
            extra_needed = k - len(subj_posts)
            padding = list(np.random.choice(subj_posts, size=extra_needed, replace=True))
            return subj_posts + padding
        else:
            return subj_posts
    
    # Encode and rank
    post_embeddings = sbert.encode(
        subj_posts,
        convert_to_tensor=True,
        show_progress_bar=False
    )
    
    cos_scores = util.cos_sim(post_embeddings, concept_embs)
    max_sim_scores = cos_scores.max(dim=1).values.cpu().numpy()
    top_k_indices = np.argpartition(-max_sim_scores, range(min(k, len(subj_posts))))[:k]
    
    return [subj_posts[i] for i in top_k_indices]

print("✓ Post retrieval function defined")

✓ Post retrieval function defined


In [17]:
# Retrieve top-k posts for all subjects
print(f"Retrieving top-{HYPERPARAMS['k_posts']} posts for all subjects...")
print("⏰ This will take ~40 minutes")
start_time = time.time()

# Training subjects
print("  Processing training subjects...")
train_selected = {}
train_subjects = train_posts_df['subject_id'].unique()
for subject_id in tqdm(train_subjects, desc="Processing training subjects"):
    selected = retrieve_top_k_posts(
        subject_id, train_posts_df, concept_embeddings, sbert_model, k=HYPERPARAMS['k_posts']
    )
    train_selected[subject_id] = selected

# Validation subjects
print("  Processing validation subjects...")
val_selected = {}
val_subjects = val_posts_df['subject_id'].unique()
for subject_id in tqdm(val_subjects, desc="Processing validation subjects"):
    selected = retrieve_top_k_posts(
        subject_id, val_posts_df, concept_embeddings, sbert_model, k=HYPERPARAMS['k_posts']
    )
    val_selected[subject_id] = selected

# Test subjects
print("  Processing test subjects...")
test_selected = {}
test_subjects = test_posts_df_final['subject_id'].unique()
for subject_id in tqdm(test_subjects, desc="Processing test subjects"):
    selected = retrieve_top_k_posts(
        subject_id, test_posts_df_final, concept_embeddings, sbert_model, k=HYPERPARAMS['k_posts']
    )
    test_selected[subject_id] = selected

print(f"✓ Post retrieval complete in {time.time()-start_time:.1f}s ({(time.time()-start_time)/60:.1f} min)")

Retrieving top-50 posts for all subjects...
⏰ This will take ~40 minutes
  Processing training subjects...


Processing training subjects: 100%|██████████| 486/486 [17:04<00:00,  2.11s/it]
Processing training subjects: 100%|██████████| 486/486 [17:04<00:00,  2.11s/it]


  Processing validation subjects...


Processing validation subjects: 100%|██████████| 200/200 [09:09<00:00,  2.75s/it]



  Processing test subjects...


Processing test subjects: 100%|██████████| 201/201 [09:11<00:00,  2.74s/it]

✓ Post retrieval complete in 2126.0s (35.4 min)





### Section 5: Encode Posts (WITHOUT Pooling)

**KEY CHANGE:** Store individual post embeddings instead of pre-pooling them. This allows attention to be learned during training.

In [18]:
def encode_posts_no_pooling(selected_posts_dict, sbert, max_posts=50):
    """
    Encode selected posts WITHOUT pooling - keep individual embeddings.
    
    This is the KEY CHANGE from concept-based attention:
    - OLD: Encode → Pool with concept-based weights → Save [n, 384]
    - NEW: Encode → Save individual posts → [n, k, 384]
    
    Args:
        selected_posts_dict: {subject_id: [post1, post2, ...]}
        sbert: Sentence-BERT model
        max_posts: Number of posts per subject (pad/truncate)
    
    Returns:
        post_embeddings: [n_subjects, max_posts, 384]
        subject_ids: List of subject IDs
    """
    subject_ids = list(selected_posts_dict.keys())
    all_post_embeddings = []

    for subject_id in tqdm(subject_ids, desc="Encoding posts"):
        posts = selected_posts_dict[subject_id]

        # Encode all posts for this subject
        post_embs = sbert.encode(
            posts,
            convert_to_tensor=True,
            show_progress_bar=False
        )  # shape: [k, 384]

        # Pad or truncate to exactly max_posts
        if len(post_embs) < max_posts:
            # Pad with zeros
            padding = torch.zeros(max_posts - len(post_embs), post_embs.shape[1])
            post_embs = torch.cat([post_embs, padding], dim=0)
        elif len(post_embs) > max_posts:
            # Truncate (shouldn't happen if retrieval is correct)
            post_embs = post_embs[:max_posts]

        all_post_embeddings.append(post_embs.cpu().numpy())

    return np.stack(all_post_embeddings), subject_ids

print("✓ Encoding function defined (NO pooling)")

✓ Encoding function defined (NO pooling)


In [19]:
# Encode posts for all splits (keep individual embeddings)
print("Encoding posts WITHOUT pooling...")
print("⏰ This will take ~8-10 minutes")

start_time = time.time()

print("  Encoding training set...")
X_train, train_subject_ids = encode_posts_no_pooling(
    train_selected,
    sbert_model,
    max_posts=HYPERPARAMS['k_posts']
)
print(f"    X_train shape: {X_train.shape}")  # Should be [n_train, 50, 384]

print("  Encoding validation set...")
X_val, val_subject_ids = encode_posts_no_pooling(
    val_selected,
    sbert_model,
    max_posts=HYPERPARAMS['k_posts']
)
print(f"    X_val shape: {X_val.shape}")

print("  Encoding test set...")
X_test, test_subject_ids = encode_posts_no_pooling(
    test_selected,
    sbert_model,
    max_posts=HYPERPARAMS['k_posts']
)
print(f"    X_test shape: {X_test.shape}")

print(f"✓ Encoding complete in {time.time()-start_time:.1f}s ({(time.time()-start_time)/60:.1f} min)")

# Verify shapes
assert X_train.shape[1] == HYPERPARAMS['k_posts'], "Train shape mismatch!"
assert X_train.shape[2] == HYPERPARAMS['post_dim'], "Embedding dim mismatch!"
print(f"✓ Shape verification passed: [n, {HYPERPARAMS['k_posts']}, {HYPERPARAMS['post_dim']}]")

Encoding posts WITHOUT pooling...
⏰ This will take ~8-10 minutes
  Encoding training set...


Encoding posts: 100%|██████████| 486/486 [03:25<00:00,  2.37it/s]
Encoding posts: 100%|██████████| 486/486 [03:25<00:00,  2.37it/s]


    X_train shape: (486, 50, 384)
  Encoding validation set...


Encoding posts: 100%|██████████| 200/200 [01:07<00:00,  2.97it/s]
Encoding posts: 100%|██████████| 200/200 [01:07<00:00,  2.97it/s]


    X_val shape: (200, 50, 384)
  Encoding test set...


Encoding posts: 100%|██████████| 201/201 [01:08<00:00,  2.95it/s]

    X_test shape: (201, 50, 384)
✓ Encoding complete in 340.9s (5.7 min)
✓ Shape verification passed: [n, 50, 384]





### Section 6: Build Concept Matrices and Labels

In [20]:
# Build concept matrices and label vectors
print("Building concept matrices and labels...")

# Training: get concepts from questionnaires
C_train = []
y_train = []
for subject_id in train_subject_ids:
    label = train_posts_df[train_posts_df['subject_id'] == subject_id]['label'].iloc[0]
    y_train.append(label)
    
    concept_row = concepts_df[concepts_df['subject_id'] == subject_id]
    if len(concept_row) > 0:
        concepts = concept_row[concept_cols].values[0]
    else:
        concepts = np.zeros(N_CONCEPTS)
    C_train.append(concepts)

C_train = np.array(C_train, dtype=np.float32)
y_train = np.array(y_train, dtype=np.float32)

# Validation: zeros for concepts (no ground truth)
C_val = np.zeros((len(val_subject_ids), N_CONCEPTS), dtype=np.float32)
y_val = []
for subject_id in val_subject_ids:
    label = val_posts_df[val_posts_df['subject_id'] == subject_id]['label'].iloc[0]
    y_val.append(label)
y_val = np.array(y_val, dtype=np.float32)

# Test: zeros for concepts
C_test = np.zeros((len(test_subject_ids), N_CONCEPTS), dtype=np.float32)
y_test = []
for subject_id in test_subject_ids:
    label = test_posts_df_final[test_posts_df_final['subject_id'] == subject_id]['label'].iloc[0]
    y_test.append(label)
y_test = np.array(y_test, dtype=np.float32)

print("✓ Matrices built")
print(f"  Train: X={X_train.shape}, C={C_train.shape}, y={y_train.shape}")
print(f"  Val:   X={X_val.shape}, C={C_val.shape}, y={y_val.shape}")
print(f"  Test:  X={X_test.shape}, C={C_test.shape}, y={y_test.shape}")
print(f"Training label distribution: {np.bincount(y_train.astype(int))}")

Building concept matrices and labels...
✓ Matrices built
  Train: X=(486, 50, 384), C=(486, 21), y=(486,)
  Val:   X=(200, 50, 384), C=(200, 21), y=(200,)
  Test:  X=(201, 50, 384), C=(201, 21), y=(201,)
Training label distribution: [403  83]
✓ Matrices built
  Train: X=(486, 50, 384), C=(486, 21), y=(486,)
  Val:   X=(200, 50, 384), C=(200, 21), y=(200,)
  Test:  X=(201, 50, 384), C=(201, 21), y=(201,)
Training label distribution: [403  83]


### Section 7: Compute Class Weights

In [21]:
# Compute class weights for imbalanced dataset
n_negative = int(np.sum(y_train == 0))
n_positive = int(np.sum(y_train == 1))
pos_weight = n_negative / n_positive

# Update hyperparameters
HYPERPARAMS['n_positive'] = n_positive
HYPERPARAMS['n_negative'] = n_negative

print(f"Class imbalance:")
print(f"  Negative samples: {n_negative}")
print(f"  Positive samples: {n_positive}")
print(f"  Ratio: 1:{pos_weight:.2f}")

Class imbalance:
  Negative samples: 403
  Positive samples: 83
  Ratio: 1:4.86


### Section 8: Save Datasets

In [22]:
# Save processed datasets to disk
print("Saving datasets...")

# Save numpy arrays (NEW SHAPE: [n, k, 384])
np.savez_compressed(
    os.path.join(SAVE_DIR, "train_data.npz"),
    X=X_train,  # [n_train, 50, 384] - individual posts!
    C=C_train,
    y=y_train,
    subject_ids=np.array(train_subject_ids)
)

np.savez_compressed(
    os.path.join(SAVE_DIR, "val_data.npz"),
    X=X_val,
    C=C_val,
    y=y_val,
    subject_ids=np.array(val_subject_ids)
)

np.savez_compressed(
    os.path.join(SAVE_DIR, "test_data.npz"),
    X=X_test,
    C=C_test,
    y=y_test,
    subject_ids=np.array(test_subject_ids)
)

# Save class weights info
class_info = {
    "n_positive": n_positive,
    "n_negative": n_negative,
    "pos_weight": float(pos_weight)
}

with open(os.path.join(SAVE_DIR, "class_weights.json"), 'w') as f:
    json.dump(class_info, f, indent=4)

print(f"✓ Datasets saved to {SAVE_DIR}")
print(f"  train_data.npz: {X_train.shape}")
print(f"  val_data.npz:   {X_val.shape}")
print(f"  test_data.npz:  {X_test.shape}")
print(f"  class_weights.json")
print(f"NEW DATA SHAPE: [{X_train.shape[0]}, {X_train.shape[1]}, {X_train.shape[2]}]")
print(f"  OLD DATA SHAPE would have been: [{X_train.shape[0]}, {X_train.shape[2]}] (pre-pooled)")

Saving datasets...
✓ Datasets saved to /Users/gualtieromarencoturi/Desktop/thesis/Master-Thesis-CEM-Depression-etc-case-study/data/processed/attention_task_optimized
  train_data.npz: (486, 50, 384)
  val_data.npz:   (200, 50, 384)
  test_data.npz:  (201, 50, 384)
  class_weights.json
NEW DATA SHAPE: [486, 50, 384]
  OLD DATA SHAPE would have been: [486, 384] (pre-pooled)
✓ Datasets saved to /Users/gualtieromarencoturi/Desktop/thesis/Master-Thesis-CEM-Depression-etc-case-study/data/processed/attention_task_optimized
  train_data.npz: (486, 50, 384)
  val_data.npz:   (200, 50, 384)
  test_data.npz:  (201, 50, 384)
  class_weights.json
NEW DATA SHAPE: [486, 50, 384]
  OLD DATA SHAPE would have been: [486, 384] (pre-pooled)


In [23]:
# Clean up temporary directory
try:
    shutil.rmtree(temp_dir)
    print(f"✓ Cleaned up temporary directory")
except Exception as e:
    print(f"⚠ Failed to clean up temporary directory: {e}")

✓ Cleaned up temporary directory


In [24]:
print("" + "="*70)
print("          DATA PREPARATION COMPLETE")
print("="*70)
print(f"Saved data with shape: [n_samples, {HYPERPARAMS['k_posts']}, {HYPERPARAMS['post_dim']}]")
print(f"  Memory per split: ~{X_train.nbytes / 1024**2:.1f} MB")
print("Next: Train model with task-optimized attention")
print("="*70)

          DATA PREPARATION COMPLETE
Saved data with shape: [n_samples, 50, 384]
  Memory per split: ~35.6 MB
Next: Train model with task-optimized attention


## PART 2: MODEL TRAINING WITH TASK-OPTIMIZED ATTENTION

### Section 9: PyTorch Dataset

In [25]:
class CEMDataset(Dataset):
    """
    Dataset for CEM with individual post embeddings.
    
    X shape: [n_samples, k_posts, post_dim] - NEW!
    C shape: [n_samples, n_concepts]
    y shape: [n_samples, 1]
    """
    def __init__(self, X, C, y):
        self.X = torch.tensor(X, dtype=torch.float32)  # [n, 50, 384]
        self.C = torch.tensor(C, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)
    
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx], self.C[idx]

# Create datasets
train_dataset = CEMDataset(X_train, C_train, y_train)
val_dataset = CEMDataset(X_val, C_val, y_val)
test_dataset = CEMDataset(X_test, C_test, y_test)

print(f"✓ Datasets created")
print(f"  Train: {len(train_dataset)} samples")
print(f"  Val:   {len(val_dataset)} samples")
print(f"  Test:  {len(test_dataset)} samples")

✓ Datasets created
  Train: 486 samples
  Val:   200 samples
  Test:  201 samples


In [26]:
# Create DataLoaders
if HYPERPARAMS['use_weighted_sampler']:
    class_sample_counts = np.bincount(y_train.astype(int))
    weights = 1. / class_sample_counts
    sample_weights = weights[y_train.astype(int)]
    
    train_sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )
    
    print(f"✓ Using WeightedRandomSampler")
    train_loader = DataLoader(train_dataset, batch_size=HYPERPARAMS['batch_size_train'], sampler=train_sampler)
else:
    print("✓ Using standard DataLoader (shuffle=True)")
    train_loader = DataLoader(train_dataset, batch_size=HYPERPARAMS['batch_size_train'], shuffle=True)

val_loader = DataLoader(val_dataset, batch_size=HYPERPARAMS['batch_size_eval'], shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=HYPERPARAMS['batch_size_eval'], shuffle=False)

print("✓ All DataLoaders created")

✓ Using WeightedRandomSampler
✓ All DataLoaders created


### Section 10: Loss Functions

In [27]:
class LDAMLoss(nn.Module):
    """
    Label-Distribution-Aware Margin (LDAM) Loss for long-tailed recognition.
    
    Creates class-dependent margins to make decision boundaries harder for minority classes.
    """
    def __init__(self, n_positive, n_negative, max_margin=0.5, scale=30):
        super(LDAMLoss, self).__init__()
        self.max_margin = max_margin
        self.scale = scale
        
        # Compute class frequencies
        total = n_positive + n_negative
        freq_pos = n_positive / total
        freq_neg = n_negative / total
        
        # Compute margins: minority class gets larger margin
        margin_pos = max_margin * (freq_pos ** (-0.25))
        margin_neg = max_margin * (freq_neg ** (-0.25))
        
        self.register_buffer('margin_pos', torch.tensor(margin_pos))
        self.register_buffer('margin_neg', torch.tensor(margin_neg))
    
    def forward(self, logits, targets):
        logits = logits.view(-1)
        targets = targets.view(-1).float()
        
        # Apply class-dependent margins
        margin = targets * self.margin_pos + (1 - targets) * (-self.margin_neg)
        adjusted_logits = (logits - margin) * self.scale
        
        return F.binary_cross_entropy_with_logits(adjusted_logits, targets, reduction='mean')

print("✓ LDAM Loss defined")

✓ LDAM Loss defined


### Section 11: Task-Optimized Attention CEM Model

**KEY INNOVATION:** Learnable attention layer that computes post importance scores, optimized via task loss gradients.

In [28]:
class TaskOptimizedAttentionCEM(pl.LightningModule):
    """
    CEM with learnable attention that optimizes post selection for task loss.

    Architecture:
        1. Attention Layer: Learns which posts matter for task (NEW!)
        2. Attention Pooling: Weighted average of posts
        3. Concept Extractor: Extracts pre-concept features
        4. Concept Layers: 21 dual-embedding concepts
        5. Task Classifier: Final depression prediction
        
    KEY DIFFERENCE from original:
        - Input: [batch, k_posts, post_dim] instead of [batch, post_dim]
        - Has learnable attention_layer
        - Attention weights optimized via task loss gradients
    """
    def __init__(
        self,
        n_concepts=21,
        emb_size=128,
        k_posts=50,
        post_dim=384,
        attention_hidden=128,
        shared_prob_gen=True,
        intervention_prob=0.25,
        concept_loss_weight=1.0,
        task_loss_weight=1.0,
        learning_rate=0.01,
        weight_decay=4e-05,
        use_ldam_loss=True,
        n_positive=83,
        n_negative=403,
        ldam_max_margin=0.5,
        ldam_scale=30,
    ):
        super().__init__()
        self.save_hyperparameters()
        
        # =====================================================================
        # STAGE 0: ATTENTION MODULE (NEW!)
        # =====================================================================
        self.attention_layer = nn.Sequential(
            nn.Linear(post_dim, attention_hidden),
            nn.Tanh(),
            nn.Dropout(0.2),
            nn.Linear(attention_hidden, 1)  # Score for each post
        )
        
        # =====================================================================
        # STAGE 1: CONCEPT EXTRACTOR (receives pooled embedding)
        # =====================================================================
        self.concept_extractor = nn.Sequential(
            nn.Linear(post_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 256)
        )
        
        # =====================================================================
        # STAGE 2: CONCEPT CONTEXT GENERATORS (21 concepts, dual embeddings)
        # =====================================================================
        self.context_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(256, emb_size * 2),  # Dual embeddings
                nn.LeakyReLU()
            ) for _ in range(n_concepts)
        ])
        
        # =====================================================================
        # STAGE 3: CONCEPT PROBABILITY GENERATORS
        # =====================================================================
        if shared_prob_gen:
            self.prob_generator = nn.Linear(emb_size * 2, 1)
        else:
            self.prob_generators = nn.ModuleList([
                nn.Linear(emb_size * 2, 1) for _ in range(n_concepts)
            ])
        
        # =====================================================================
        # STAGE 4: TASK CLASSIFIER (C2Y model)
        # =====================================================================
        self.task_classifier = nn.Sequential(
            nn.Linear(n_concepts * emb_size, 128),
            nn.LeakyReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 1)  # Binary classification
        )
        
        # =====================================================================
        # LOSSES
        # =====================================================================
        self.concept_loss_fn = nn.BCEWithLogitsLoss()
        if use_ldam_loss:
            self.task_loss_fn = LDAMLoss(n_positive, n_negative, ldam_max_margin, ldam_scale)
        else:
            self.task_loss_fn = nn.BCEWithLogitsLoss()
    
    def forward(self, post_embeddings, c_true=None, train=False):
        """
        Forward pass with learnable attention.
        
        Args:
            post_embeddings: [batch, k_posts, post_dim] - individual posts
            c_true: [batch, n_concepts] - concept labels (for interventions)
            train: bool - training mode
        
        Returns:
            c_logits: [batch, 21] - concept predictions
            y_logits: [batch, 1] - task prediction
            attn_weights: [batch, k_posts] - attention weights (for analysis)
        """
        batch_size = post_embeddings.shape[0]
        
        # ====================================================================
        # STEP 1: Compute attention weights (LEARNABLE!)
        # ====================================================================
        attn_scores = self.attention_layer(post_embeddings)  # [batch, k, 1]
        attn_weights = torch.softmax(attn_scores.squeeze(-1), dim=1)  # [batch, k]
        
        # ====================================================================
        # STEP 2: Attention-weighted pooling
        # ====================================================================
        pooled = torch.sum(
            attn_weights.unsqueeze(-1) * post_embeddings,  # [batch, k, 1] * [batch, k, 384]
            dim=1
        )  # [batch, 384]
        
        # ====================================================================
        # STEP 3: Concept extraction (existing CEM pipeline)
        # ====================================================================
        pre_c = self.concept_extractor(pooled)  # [batch, 256]
        
        # ====================================================================
        # STEP 4: Concept context and probability generation
        # ====================================================================
        contexts = []
        c_logits_list = []
        
        for i, context_layer in enumerate(self.context_layers):
            # Select probability generator
            if self.hparams.shared_prob_gen:
                prob_gen = self.prob_generator
            else:
                prob_gen = self.prob_generators[i]
            
            # Generate context
            context = context_layer(pre_c)  # [batch, 2*emb_size]
            
            # Generate concept probability (logit)
            logit = prob_gen(context)  # [batch, 1]
            c_logits_list.append(logit)
            
            contexts.append(context.unsqueeze(1))  # [batch, 1, 2*emb_size]
        
        c_logits = torch.cat(c_logits_list, dim=-1)  # [batch, 21]
        contexts = torch.cat(contexts, dim=1)  # [batch, 21, 2*emb_size]
        
        # ====================================================================
        # STEP 5: Concept interventions (during training)
        # ====================================================================
        c_probs = torch.sigmoid(c_logits)  # [batch, 21]
        
        if train and (self.hparams.intervention_prob > 0) and (c_true is not None):
            intervention_mask = torch.bernoulli(
                torch.ones_like(c_probs) * self.hparams.intervention_prob
            )
            c_probs = c_probs * (1 - intervention_mask) + c_true * intervention_mask
        
        # ====================================================================
        # STEP 6: Dual embeddings and task prediction
        # ====================================================================
        c_pred_embs = (
            contexts[:, :, :self.hparams.emb_size] * c_probs.unsqueeze(-1) +
            contexts[:, :, self.hparams.emb_size:] * (1 - c_probs.unsqueeze(-1))
        )
        c_pred_embs = c_pred_embs.view(batch_size, -1)  # [batch, 21*emb_size]
        
        y_logits = self.task_classifier(c_pred_embs)  # [batch, 1]
        
        return c_logits, y_logits, attn_weights
    
    def training_step(self, batch, batch_idx):
        post_embs, y, c_true = batch
        c_logits, y_logits, _ = self.forward(post_embs, c_true=c_true, train=True)
        
        # Task loss
        task_loss = self.task_loss_fn(y_logits.squeeze(), y.squeeze())
        
        # Concept loss
        concept_loss = self.concept_loss_fn(c_logits, c_true)
        
        # Combined loss
        loss = (self.hparams.task_loss_weight * task_loss + 
                self.hparams.concept_loss_weight * concept_loss)
        
        # Logging
        self.log('train_loss', loss, on_epoch=True, prog_bar=True)
        self.log('train_task_loss', task_loss, on_epoch=True)
        self.log('train_concept_loss', concept_loss, on_epoch=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        post_embs, y, c_true = batch
        c_logits, y_logits, _ = self.forward(post_embs, c_true=c_true, train=False)
        
        # Task loss
        task_loss = self.task_loss_fn(y_logits.squeeze(), y.squeeze())
        
        # Concept loss
        concept_loss = self.concept_loss_fn(c_logits, c_true)
        
        # Combined loss
        loss = (self.hparams.task_loss_weight * task_loss +
                self.hparams.concept_loss_weight * concept_loss)
        
        # Logging
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_task_loss', task_loss, on_epoch=True)
        
        return loss
    
    def configure_optimizers(self):
        return torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay
        )

print("✓ TaskOptimizedAttentionCEM model defined")
print("  Key innovation: Learnable attention layer optimized via task loss")

✓ TaskOptimizedAttentionCEM model defined
  Key innovation: Learnable attention layer optimized via task loss


### Section 12: Model Initialization & Training

In [29]:
# Initialize TaskOptimizedAttentionCEM model
model = TaskOptimizedAttentionCEM(
    n_concepts=HYPERPARAMS['n_concepts'],
    emb_size=HYPERPARAMS['emb_size'],
    k_posts=HYPERPARAMS['k_posts'],
    post_dim=HYPERPARAMS['post_dim'],
    attention_hidden=HYPERPARAMS['attention_hidden'],
    shared_prob_gen=HYPERPARAMS['shared_prob_gen'],
    intervention_prob=HYPERPARAMS['intervention_prob'],
    concept_loss_weight=HYPERPARAMS['concept_loss_weight'],
    task_loss_weight=HYPERPARAMS['task_loss_weight'],
    learning_rate=HYPERPARAMS['learning_rate'],
    weight_decay=HYPERPARAMS['weight_decay'],
    use_ldam_loss=HYPERPARAMS['use_ldam_loss'],
    n_positive=HYPERPARAMS['n_positive'],
    n_negative=HYPERPARAMS['n_negative'],
    ldam_max_margin=HYPERPARAMS['ldam_max_margin'],
    ldam_scale=HYPERPARAMS['ldam_scale']
)

print("✓ Model initialized")
print(f"  Attention hidden dim: {HYPERPARAMS['attention_hidden']}")
print(f"  Concept embedding size: {HYPERPARAMS['emb_size']}")
print(f"  Using LDAM Loss: {HYPERPARAMS['use_ldam_loss']}")

✓ Model initialized
  Attention hidden dim: 128
  Concept embedding size: 128
  Using LDAM Loss: True


In [30]:
# Setup trainer
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath=os.path.join(OUTPUT_DIR, "models"),
    filename="task-attention-cem-{epoch:02d}-{val_loss:.2f}",
    save_top_k=1,
    mode="min"
)

trainer = pl.Trainer(
    max_epochs=HYPERPARAMS['max_epochs'],
    accelerator=DEVICE,
    devices=1,
    logger=CSVLogger(save_dir=os.path.join(OUTPUT_DIR, "logs"), name="task_attention_cem"),
    log_every_n_steps=10,
    callbacks=[checkpoint_callback],
    enable_progress_bar=True
)

print("✓ Trainer configured")

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


✓ Trainer configured


In [31]:
# Train model
print("Starting training...")
print("⏰ This will take ~15-20 minutes")

trainer.fit(model, train_loader, val_loader)

print("✓ Training complete!")

Starting training...
⏰ This will take ~15-20 minutes


  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

  | Name              | Type              | Params
--------------------------------------------------------
0 | attention_layer   | Sequential        | 49.4 K
1 | concept_extractor | Sequential        | 164 K 
2 | context_layers    | ModuleList        | 1.4 M 
3 | prob_generator    | Linear            | 257   
4 | task_classifier   | Sequential        | 344 K 
5 | concept_loss_fn   | BCEWithLogitsLoss | 0     
6 | task_loss_fn      | LDAMLoss          | 0     
--------------------------------------------------------
1.9 M     Trainable params
0         Non-trainable params
1.9 M     Total params
7.760     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


✓ Training complete!


### Section 13: Validation Threshold Selection

In [32]:
# Select decision threshold on validation set
print("Selecting decision threshold on validation set...")

model.eval()
device_obj = torch.device(DEVICE)
model = model.to(device_obj)

y_val_true = []
y_val_prob = []

with torch.no_grad():
    for post_embs, y_batch, c_batch in val_loader:
        post_embs = post_embs.to(device_obj)
        
        _, y_logits, _ = model(post_embs)
        y_probs = torch.sigmoid(y_logits).cpu().squeeze().numpy()
        
        y_val_true.extend(y_batch.numpy().astype(int).tolist())
        y_val_prob.extend(y_probs.tolist() if isinstance(y_probs, np.ndarray) else [y_probs])

y_val_true = np.array(y_val_true)
y_val_prob = np.array(y_val_prob)

# Find best threshold for 80% recall
best_threshold = 0.5
best_precision = 0.0
target_recall = 0.80

for threshold in np.linspace(0.01, 0.50, 50):
    y_pred_temp = (y_val_prob >= threshold).astype(int)
    
    if np.sum(y_pred_temp) == 0:
        continue
    
    recall = recall_score(y_val_true, y_pred_temp)
    precision = precision_score(y_val_true, y_pred_temp)
    
    if recall >= target_recall and precision > best_precision:
        best_precision = precision
        best_threshold = threshold

print(f"✓ Selected threshold: {best_threshold:.2f}")
print(f"  Target recall ≥ {target_recall}, achieved precision = {best_precision:.3f}")

Selecting decision threshold on validation set...
✓ Selected threshold: 0.50
  Target recall ≥ 0.8, achieved precision = 0.000
✓ Selected threshold: 0.50
  Target recall ≥ 0.8, achieved precision = 0.000


### Section 14: Attention Analysis

**NEW CAPABILITY:** Analyze which posts the model focuses on for its predictions.

In [33]:
# Analyze attention weights on validation set
print("" + "="*70)
print("           ATTENTION WEIGHT ANALYSIS")
print("="*70)

model.eval()
model = model.to(device_obj)

all_attn_weights = []
all_y_true = []
all_y_pred = []

with torch.no_grad():
    for post_embs, y, c in val_loader:
        post_embs = post_embs.to(device_obj)
        
        c_logits, y_logits, attn_weights = model(post_embs)
        
        all_attn_weights.append(attn_weights.cpu().numpy())
        all_y_true.append(y.cpu().numpy())
        all_y_pred.append(torch.sigmoid(y_logits).cpu().numpy())

attn_weights_val = np.vstack(all_attn_weights)  # [n_val, k_posts]
y_true_val = np.concatenate(all_y_true)
y_pred_val = np.concatenate(all_y_pred)

print(f"Attention weights shape: {attn_weights_val.shape}")
print(f"Attention statistics:")
print(f"  Mean weight: {attn_weights_val.mean():.4f}")
print(f"  Std weight:  {attn_weights_val.std():.4f}")
print(f"  Max weight:  {attn_weights_val.max():.4f}")
print(f"  Min weight:  {attn_weights_val.min():.4f}")

# Compute entropy of attention distribution
def attention_entropy(weights):
    """Compute entropy of attention distribution"""
    entropy = -np.sum(weights * np.log(weights + 1e-10), axis=1)
    return entropy

entropy_val = attention_entropy(attn_weights_val)
max_entropy = np.log(HYPERPARAMS['k_posts'])

print(f"Attention entropy:")
print(f" Mean: {entropy_val.mean():.4f} (max = {max_entropy:.4f} for uniform)")
print(f" Std:  {entropy_val.std():.4f}")
print(f" Concentration: {(max_entropy - entropy_val.mean()) / max_entropy * 100:.1f}%")

# Compare attention for correct vs incorrect predictions
y_pred_binary = (y_pred_val > 0.5).astype(int).flatten()
y_true_binary = y_true_val.astype(int).flatten()
correct_mask = (y_pred_binary == y_true_binary)

print(f"Attention by prediction quality:")
print(f"  Correct predictions:   entropy = {entropy_val[correct_mask].mean():.4f}")
print(f"  Incorrect predictions: entropy = {entropy_val[~correct_mask].mean():.4f}")

# Save attention weights for further analysis
np.savez(
    os.path.join(OUTPUT_DIR, 'attention_analysis.npz'),
    attn_weights=attn_weights_val,
    y_true=y_true_val,
    y_pred=y_pred_val,
    entropy=entropy_val,
    correct_mask=correct_mask
)

print(f"✓ Saved attention analysis to {OUTPUT_DIR}/attention_analysis.npz")
print("="*70)

           ATTENTION WEIGHT ANALYSIS
Attention weights shape: (200, 50)
Attention statistics:
  Mean weight: 0.0200
  Std weight:  0.1156
  Max weight:  1.0000
  Min weight:  0.0000
Attention entropy:
 Mean: 0.5934 (max = 3.9120 for uniform)
 Std:  0.5254
 Concentration: 84.8%
Attention by prediction quality:
  Correct predictions:   entropy = 0.5673
  Incorrect predictions: entropy = 0.8052
✓ Saved attention analysis to outputs_task_optimized_attention/attention_analysis.npz


### Section 15: Test Set Evaluation

In [34]:
# Run inference on test set
print("" + "="*70)
print("                  TEST SET EVALUATION")
print("="*70)
print("Running inference on test set...")

model.eval()
model = model.to(device_obj)

y_true_list = []
y_prob_list = []
concept_probs_list = []

with torch.no_grad():
    for post_embs, y_batch, c_batch in test_loader:
        post_embs = post_embs.to(device_obj)
        
        c_logits, y_logits, _ = model(post_embs)
        c_probs = torch.sigmoid(c_logits).cpu().numpy()
        y_probs = torch.sigmoid(y_logits).cpu().squeeze().numpy()
        
        y_true_list.extend(y_batch.numpy().astype(int).tolist())
        y_prob_list.extend(y_probs.tolist() if isinstance(y_probs, np.ndarray) else [y_probs])
        concept_probs_list.extend(c_probs.tolist())

y_true = np.array(y_true_list)
y_prob = np.array(y_prob_list)
concept_probs = np.array(concept_probs_list)

# Apply threshold
y_pred = (y_prob >= best_threshold).astype(int)

print("✓ Inference complete")

                  TEST SET EVALUATION
Running inference on test set...
✓ Inference complete
✓ Inference complete


In [35]:
# Compute all metrics
cm = confusion_matrix(y_true, y_pred)
tn, fp, fn, tp = cm.ravel()

acc = accuracy_score(y_true, y_pred)
balanced_acc = balanced_accuracy_score(y_true, y_pred)
roc_auc = roc_auc_score(y_true, y_prob)
mcc = matthews_corrcoef(y_true, y_pred)
f1_binary = f1_score(y_true, y_pred, pos_label=1)
precision_binary = precision_score(y_true, y_pred, pos_label=1)
recall_binary = recall_score(y_true, y_pred, pos_label=1)

# Print results
print(f"Decision Threshold: {best_threshold:.2f}")

print(f"{'CONFUSION MATRIX':^50}")
print("="*50)
print(f"{'':>20} │ {'Predicted Neg':^12} │ {'Predicted Pos':^12}")
print("─"*50)
print(f"{'Actual Negative':>20} │ {f'TN = {tn}':^12} │ {f'FP = {fp}':^12}")
print(f"{'Actual Positive':>20} │ {f'FN = {fn}':^12} │ {f'TP = {tp}':^12}")
print("="*50)

n_pos = int(np.sum(y_true))
n_neg = int(len(y_true) - n_pos)

print(f"TP: {tp}/{n_pos} ({100*tp/n_pos if n_pos > 0 else 0:.1f}% caught)")
print(f"  FN: {fn}/{n_pos} ({100*fn/n_pos if n_pos > 0 else 0:.1f}% missed)")

print(f"Performance Metrics:")
print(f"  MCC:                {mcc:.4f}")
print(f"  F1 Score:           {f1_binary:.4f}")
print(f"  Recall:             {recall_binary:.4f}")
print(f"  Precision:          {precision_binary:.4f}")
print(f"  ROC-AUC:            {roc_auc:.4f}")
print(f"  Accuracy:           {acc:.4f}")
print(f"  Balanced Accuracy:  {balanced_acc:.4f}")

print("" + classification_report(y_true, y_pred, target_names=['Negative', 'Positive']))
print("="*70)

Decision Threshold: 0.50
                 CONFUSION MATRIX                 
                     │ Predicted Neg │ Predicted Pos
──────────────────────────────────────────────────
     Actual Negative │   TN = 170   │    FP = 5   
     Actual Positive │   FN = 16    │   TP = 10   
TP: 10/26 (38.5% caught)
  FN: 16/26 (61.5% missed)
Performance Metrics:
  MCC:                0.4547
  F1 Score:           0.4878
  Recall:             0.3846
  Precision:          0.6667
  ROC-AUC:            0.7857
  Accuracy:           0.8955
  Balanced Accuracy:  0.6780
              precision    recall  f1-score   support

    Negative       0.91      0.97      0.94       175
    Positive       0.67      0.38      0.49        26

    accuracy                           0.90       201
   macro avg       0.79      0.68      0.71       201
weighted avg       0.88      0.90      0.88       201



In [36]:
# Save results
metrics_dict = {
    "model_type": "task_optimized_attention_cem",
    "threshold": float(best_threshold),
    "n_samples": int(len(y_true)),
    "n_positive": int(np.sum(y_true)),
    "n_negative": int(len(y_true) - np.sum(y_true)),
    "accuracy": float(acc),
    "balanced_accuracy": float(balanced_acc),
    "roc_auc": float(roc_auc),
    "mcc": float(mcc),
    "f1_binary": float(f1_binary),
    "precision_binary": float(precision_binary),
    "recall_binary": float(recall_binary),
    "confusion_matrix": {"tn": int(tn), "fp": int(fp), "fn": int(fn), "tp": int(tp)}
}

os.makedirs(os.path.join(OUTPUT_DIR, "results"), exist_ok=True)
with open(os.path.join(OUTPUT_DIR, "results/test_metrics.json"), 'w') as f:
    json.dump(metrics_dict, f, indent=4)

# Save predictions
predictions_df = pd.DataFrame({
    'subject_id': test_subject_ids,
    'y_true': y_true,
    'y_pred': y_pred,
    'y_prob': y_prob
})

for i, concept_name in enumerate(CONCEPT_NAMES):
    predictions_df[concept_name] = concept_probs[:, i]

predictions_df.to_csv(os.path.join(OUTPUT_DIR, "results/test_predictions.csv"), index=False)

print(f"✓ Results saved to {OUTPUT_DIR}/results/")

✓ Results saved to outputs_task_optimized_attention/results/


In [37]:
print("" + "="*70)
print("        TASK-OPTIMIZED ATTENTION CEM - COMPLETE")
print("="*70)

print(f"Generated files:")
print(f"  Data:      {SAVE_DIR}/")
print(f"  Model:     {OUTPUT_DIR}/models/")
print(f"  Metrics:   {OUTPUT_DIR}/results/test_metrics.json")
print(f"  Predictions: {OUTPUT_DIR}/results/test_predictions.csv")
print(f"  Attention: {OUTPUT_DIR}/attention_analysis.npz")

print(f"Key Results:")
print(f"  Test MCC:       {mcc:.4f}")
print(f"  Test F1:        {f1_binary:.4f}")
print(f"  Test Recall:    {recall_binary:.4f}")
print(f"  Attention concentration: {(max_entropy - entropy_val.mean()) / max_entropy * 100:.1f}%")

print("✅ Unified pipeline with task-optimized attention complete!")
print("="*70)

        TASK-OPTIMIZED ATTENTION CEM - COMPLETE
Generated files:
  Data:      /Users/gualtieromarencoturi/Desktop/thesis/Master-Thesis-CEM-Depression-etc-case-study/data/processed/attention_task_optimized/
  Model:     outputs_task_optimized_attention/models/
  Metrics:   outputs_task_optimized_attention/results/test_metrics.json
  Predictions: outputs_task_optimized_attention/results/test_predictions.csv
  Attention: outputs_task_optimized_attention/attention_analysis.npz
Key Results:
  Test MCC:       0.4547
  Test F1:        0.4878
  Test Recall:    0.3846
  Attention concentration: 84.8%
✅ Unified pipeline with task-optimized attention complete!
