# Alternative Attention Dataset Preparation - Sum-Based Scoring

**Purpose:** Prepare the dataset using the SUM of concept similarities instead of the MAX.

**Key Difference**:
- Original: `post_score = max(similarity_to_each_concept)`
- Alternative: `post_score = sum(similarity_to_each_concept)`

This captures posts relevant to MULTIPLE concepts rather than just the most similar one.

**Runtime:** ~40-50 minutes (same as original)

This notebook:
1. Loads training and test data from XML files
2. Uses SBERT to retrieve top-50 concept-relevant posts per subject (sum-based scoring)
3. Pools post embeddings using sum-based attention weights
4. Saves everything to `data/processed/alternative_attention_pipeline/`

## Section 0: Configuration & Setup

In [1]:
# Imports
import os
import glob
import re
import zipfile
import tempfile
import shutil
import json
import time

import numpy as np
import pandas as pd
import xml.etree.ElementTree as ET

import torch
from sentence_transformers import SentenceTransformer, util
from sklearn.model_selection import train_test_split
from tqdm import tqdm

print("✓ All imports successful")

  from tqdm.autonotebook import tqdm, trange


✓ All imports successful


In [2]:
# Seed every random number generator used in this notebook so that a fresh
# Restart-&-Run-All reproduces the same results.
SEED = 42
for _seed_fn in (np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

print(f"✓ Random seed set to {SEED}")

✓ Random seed set to 42


In [3]:
# Detect device (MPS/CUDA/CPU)
# Preference order: Apple-silicon MPS, then CUDA, then CPU as the fallback.
# The CUDA check only runs when MPS is unavailable.
DEVICE = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

if DEVICE == "cpu":
    print("⚠ Using CPU (will be slow)")
elif DEVICE == "cuda":
    print("✓ Using CUDA GPU")
else:
    print("✓ Using MacBook GPU (MPS)")

✓ Using MacBook GPU (MPS)


In [4]:
# Define paths
# The project root is resolved as two directory levels above the current
# working directory.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))
DATA_RAW, DATA_PROCESSED = (
    os.path.join(PROJECT_ROOT, _subdir) for _subdir in ("data/raw", "data/processed")
)

# Training data paths (one folder per class)
POS_DIR, NEG_DIR = (
    os.path.join(DATA_RAW, _subdir)
    for _subdir in (
        "train/positive_examples_anonymous_chunks",
        "train/negative_examples_anonymous_chunks",
    )
)

# Test data paths
TEST_DIR = os.path.join(DATA_RAW, "test")
TEST_LABELS = os.path.join(TEST_DIR, "test_golden_truth.txt")

# Concept labels
CONCEPTS_FILE = os.path.join(DATA_PROCESSED, "merged_questionnaires.csv")

# Output directory - CHANGED FOR ALTERNATIVE PIPELINE
# Created eagerly so later save cells can write into it.
SAVE_DIR = os.path.join(DATA_PROCESSED, "alternative_attention_pipeline")
os.makedirs(SAVE_DIR, exist_ok=True)

print("✓ Paths configured")
print(f"  Project root: {PROJECT_ROOT}")
print(f"  Data save dir: {SAVE_DIR}")

✓ Paths configured
  Project root: /Users/gualtieromarencoturi/Desktop/thesis/Master-Thesis-CEM-Depression-etc-case-study
  Data save dir: /Users/gualtieromarencoturi/Desktop/thesis/Master-Thesis-CEM-Depression-etc-case-study/data/processed/alternative_attention_pipeline


In [5]:
# Define 21 BDI-II concept names
# The order of this list fixes the row order of the concept embeddings that
# are built from it later in the notebook.
CONCEPT_NAMES = [
    "Sadness",
    "Pessimism",
    "Past failure",
    "Loss of pleasure",
    "Guilty feelings",
    "Punishment feelings",
    "Self-dislike",
    "Self-criticalness",
    "Suicidal thoughts or wishes",
    "Crying",
    "Agitation",
    "Loss of interest",
    "Indecisiveness",
    "Worthlessness",
    "Loss of energy",
    "Changes in sleeping pattern",
    "Irritability",
    "Changes in appetite",
    "Concentration difficulty",
    "Tiredness or fatigue",
    "Loss of interest in sex",
]
N_CONCEPTS = len(CONCEPT_NAMES)

print(f"✓ Defined {N_CONCEPTS} BDI-II concepts")

✓ Defined 21 BDI-II concepts


In [6]:
# Hyperparameters
HYPERPARAMS = dict(
    k_posts=50,                       # Top-k posts per subject
    sbert_model="all-MiniLM-L6-v2",
    embedding_dim=384,
)

# =========================
# DEBUG / SANITY CHECK CONFIG
# =========================
DEBUG = True
DEBUG_N_SUBJECTS = 3          # how many subjects to inspect
DEBUG_TOP_N_POSTS = 5         # how many top posts to print
DEBUG_PRINT_CONCEPTS = True   # print per-concept similarity stats

print("✓ Hyperparameters configured:")
print("\n".join(f"  {_name}: {_value}" for _name, _value in HYPERPARAMS.items()))
    
    

✓ Hyperparameters configured:
  k_posts: 50
  sbert_model: all-MiniLM-L6-v2
  embedding_dim: 384


In [7]:
# Memory Management Configuration
MEMORY_CONFIG = dict(
    post_batch_size=32,            # Encode N posts at a time
    subject_cache_interval=10,     # Clear GPU cache every N subjects
    use_no_grad=True,              # Disable gradient tracking
    move_to_cpu_immediately=True,  # Move results to CPU after computation
)

print("✓ Memory configuration:")
print("\n".join(f"  {_name}: {_value}" for _name, _value in MEMORY_CONFIG.items()))

✓ Memory configuration:
  post_batch_size: 32
  subject_cache_interval: 10
  use_no_grad: True
  move_to_cpu_immediately: True


In [8]:
import gc


def clear_gpu_cache():
    """Release cached accelerator memory, then run Python garbage collection.

    Reads the module-level ``DEVICE`` string: for ``"mps"`` or ``"cuda"`` the
    matching ``torch.<device>.empty_cache()`` is called; for any other value
    only ``gc.collect()`` runs.
    """
    if DEVICE in ("mps", "cuda"):
        # torch.mps / torch.cuda both expose empty_cache(); resolve by name.
        getattr(torch, DEVICE).empty_cache()
    gc.collect()


print("✓ GPU cache clearing utility defined")

✓ GPU cache clearing utility defined


## Section 1: Load Training Data

Extract the 486 training subjects together with their posts and concept labels.

In [9]:
# Helper functions for XML parsing
WHITESPACE_RE = re.compile(r"\s+")


def normalize_text(text):
    """Return ``text`` with NUL characters removed and whitespace collapsed.

    Runs of whitespace become a single space and the result is stripped at
    both ends. Falsy input (``None`` or ``""``) yields an empty string.
    """
    without_nulls = text.replace("\u0000", "") if text else ""
    return WHITESPACE_RE.sub(" ", without_nulls).strip()

def extract_posts_from_xml(xml_path, min_chars=10):
    """Return the normalized posts found in one XML file.

    Each ``WRITING`` element directly under the root contributes one post made
    of its ``TITLE`` and ``TEXT`` joined by a space and passed through
    ``normalize_text``. Posts shorter than ``min_chars`` characters are dropped.
    If the file cannot be parsed, a warning is printed and ``[]`` is returned.
    """
    try:
        root = ET.parse(xml_path).getroot()
    except Exception as e:
        print(f"WARNING: Failed to parse {xml_path}: {e}")
        return []

    def _field(node, tag):
        # Missing or empty elements are treated as an empty string.
        return node.findtext(tag) or ""

    merged_posts = (
        normalize_text(f"{_field(writing, 'TITLE')} {_field(writing, 'TEXT')}".strip())
        for writing in root.findall("WRITING")
    )
    return [post for post in merged_posts if len(post) >= min_chars]


print("✓ Helper functions defined")

✓ Helper functions defined


In [10]:
# Parse training XML files
print("Loading training data...")
start_time = time.time()

# File names look like train_subject1234_5.xml; group 1 is the subject id.
_TRAIN_FILE_RE = re.compile(r"train_(subject\d+)_\d+\.xml")


def _records_from_files(xml_files, label, desc):
    """Return one {subject_id, label, text} record per post in ``xml_files``."""
    records = []
    for xml_file in tqdm(xml_files, desc=desc):
        match = _TRAIN_FILE_RE.match(os.path.basename(xml_file))
        if not match:
            # Files that do not follow the naming scheme are ignored.
            continue
        subject_id = match.group(1)
        records.extend(
            {"subject_id": subject_id, "label": label, "text": post}
            for post in extract_posts_from_xml(xml_file)
        )
    return records


train_data = []

# Process positive examples
print("  Processing positive examples...")
pos_files = glob.glob(os.path.join(POS_DIR, "**", "*.xml"), recursive=True)
train_data += _records_from_files(pos_files, 1, "Processing positive examples")  # 1 = depression

print(f"  Loaded {sum(d['label'] == 1 for d in train_data)} posts from positive subjects")

# Process negative examples
print("  Processing negative examples...")
neg_files = glob.glob(os.path.join(NEG_DIR, "**", "*.xml"), recursive=True)
train_data += _records_from_files(neg_files, 0, "Processing negative examples")  # 0 = control

train_posts_df = pd.DataFrame(train_data)

print(f"\n✓ Loaded training data in {time.time()-start_time:.1f}s")
print(f"  Total posts: {len(train_posts_df):,}")
print(f"  Unique subjects: {train_posts_df['subject_id'].nunique()}")
print(f"  Label distribution:")
print(train_posts_df.groupby('label')['subject_id'].nunique())

Loading training data...
  Processing positive examples...


Processing positive examples: 100%|██████████| 830/830 [00:00<00:00, 2387.88it/s]


  Loaded 29868 posts from positive subjects
  Processing negative examples...


Processing negative examples: 100%|██████████| 4031/4031 [00:02<00:00, 1709.58it/s]



✓ Loaded training data in 2.8s
  Total posts: 286,740
  Unique subjects: 486
  Label distribution:
label
0    403
1     83
Name: subject_id, dtype: int64


In [11]:
# Load concept labels from questionnaires
print("Loading concept labels...")

concepts_df = pd.read_csv(CONCEPTS_FILE)
# Questionnaire ids carry a "train_" prefix; strip it to get the subject id.
concepts_df["subject_id"] = concepts_df["Subject"].str.replace("train_", "", regex=True)

# Take the concept columns in CONCEPT_NAMES order (not CSV column order) so
# that column j of any concept matrix built from `concept_cols` refers to
# CONCEPT_NAMES[j], the same order used for the concept embeddings.
concept_cols = [name for name in CONCEPT_NAMES if name in concepts_df.columns]
missing_concepts = [name for name in CONCEPT_NAMES if name not in concepts_df.columns]
if missing_concepts:
    # Surface the problem instead of silently producing a narrower matrix.
    print(f"WARNING: {len(missing_concepts)} concept column(s) missing from "
          f"{CONCEPTS_FILE}: {missing_concepts}")

# Binarize concept values: any score above 0 counts as "concept present".
for col in concept_cols:
    concepts_df[col] = (concepts_df[col] > 0).astype(int)

print(f"✓ Loaded concept labels for {len(concepts_df)} subjects")

Loading concept labels...
✓ Loaded concept labels for 486 subjects


## Section 2: Load Test Data

Extract 401 test subjects and split into validation (200) and test (201)

In [12]:
# Extract test ZIP files to temporary directory
print("Extracting test data...")
temp_dir = tempfile.mkdtemp(prefix="test_chunks_")
print(f"  Temp directory: {temp_dir}")

# Archives are named "chunk 1.zip" ... "chunk 10.zip"; missing ones are skipped.
for chunk_no in range(1, 11):
    zip_path = os.path.join(TEST_DIR, f"chunk {chunk_no}.zip")
    if not os.path.exists(zip_path):
        continue
    destination = os.path.join(temp_dir, f"chunk_{chunk_no}")
    with zipfile.ZipFile(zip_path, 'r') as archive:
        archive.extractall(destination)
    # Progress is reported only for every third chunk.
    if chunk_no % 3 == 0:
        print(f"  Extracted chunk {chunk_no}/10")

print("✓ Test data extracted")

Extracting test data...
  Temp directory: /var/folders/gb/m6c_r5xx6_14p7mlfjwk29900000gn/T/test_chunks_knny8xhm
  Extracted chunk 3/10
  Extracted chunk 6/10
  Extracted chunk 9/10
✓ Test data extracted


In [13]:
# Load test labels
# The golden-truth file is read as tab-separated with no header row; the two
# columns are named subject_id and label, and ids are stripped of whitespace.
test_labels_df = pd.read_csv(
    TEST_LABELS, sep='\t', header=None, names=['subject_id', 'label']
).assign(subject_id=lambda frame: frame['subject_id'].str.strip())

print(f"✓ Loaded test labels for {len(test_labels_df)} subjects")
print(f"  Label distribution:")
print(test_labels_df['label'].value_counts())

✓ Loaded test labels for 401 subjects
  Label distribution:
label
0    349
1     52
Name: count, dtype: int64


In [14]:
# Parse test XML files
print("Loading test posts...")
test_data = []

test_xml_files = glob.glob(os.path.join(temp_dir, "**", "*.xml"), recursive=True)
print(f"  Found {len(test_xml_files)} XML files")

# Build the subject -> label lookup once instead of filtering the label
# DataFrame for every XML file. If a subject id appears more than once, the
# first row wins.
label_by_subject = (
    test_labels_df
    .drop_duplicates(subset='subject_id', keep='first')
    .set_index('subject_id')['label']
    .to_dict()
)

# File names look like test_subject1234_5.xml; group 1 is the subject id.
_TEST_FILE_RE = re.compile(r"(test_subject\d+)_\d+\.xml")

for xml_file in test_xml_files:
    match = _TEST_FILE_RE.match(os.path.basename(xml_file))
    if match is None:
        continue
    subject_id = match.group(1)
    if subject_id not in label_by_subject:
        # Subjects without a golden-truth label are skipped.
        continue
    label = label_by_subject[subject_id]
    test_data.extend(
        {"subject_id": subject_id, "label": label, "text": post}
        for post in extract_posts_from_xml(xml_file)
    )

# Explicit columns keep the frame well-formed even when no post was loaded,
# so the summary below (and later cells) do not fail with a KeyError.
test_posts_df = pd.DataFrame(test_data, columns=["subject_id", "label", "text"])

print(f"✓ Loaded test posts")
print(f"  Total posts: {len(test_posts_df):,}")
print(f"  Unique subjects: {test_posts_df['subject_id'].nunique()}")

Loading test posts...
  Found 4010 XML files
✓ Loaded test posts
  Total posts: 229,746
  Unique subjects: 401


In [15]:
# Split test data into validation and test sets (stratified 50/50)
print("Splitting test data into validation and test...")

# One row per subject, labelled with that subject's first post label, so the
# split happens at subject level.
test_subjects = test_posts_df.groupby('subject_id')['label'].first().reset_index()

val_subjects, test_subjects_final = train_test_split(
    test_subjects['subject_id'],
    test_size=0.5,
    stratify=test_subjects['label'],
    random_state=SEED,
)

_in_validation = test_posts_df['subject_id'].isin(val_subjects)
_in_test = test_posts_df['subject_id'].isin(test_subjects_final)
val_posts_df = test_posts_df[_in_validation].copy()
test_posts_df_final = test_posts_df[_in_test].copy()

print(f"✓ Split complete")
print(f"  Validation: {val_posts_df['subject_id'].nunique()} subjects")
print(f"  Test: {test_posts_df_final['subject_id'].nunique()} subjects")

Splitting test data into validation and test...
✓ Split complete
  Validation: 200 subjects
  Test: 201 subjects


## Section 3: SBERT Setup & Concept Embeddings

In [16]:
# Load SBERT model
print(f"Loading SBERT model: {HYPERPARAMS['sbert_model']}")
# Instantiate the encoder, then move it to the device selected earlier.
sbert_model = SentenceTransformer(HYPERPARAMS['sbert_model']).to(DEVICE)

print(f"✓ SBERT model loaded on {DEVICE}")
print(f"  Embedding dimension: {sbert_model.get_sentence_embedding_dimension()}")

Loading SBERT model: all-MiniLM-L6-v2
✓ SBERT model loaded on mps
  Embedding dimension: 384


In [17]:
# Create concept embeddings
print(f"Creating embeddings for {N_CONCEPTS} concepts...")
# One embedding row per entry of CONCEPT_NAMES, returned as a torch tensor.
_concept_encode_kwargs = {"convert_to_tensor": True, "show_progress_bar": False}
concept_embeddings = sbert_model.encode(CONCEPT_NAMES, **_concept_encode_kwargs)

print(f"✓ Concept embeddings created")
print(f"  Shape: {concept_embeddings.shape}")

Creating embeddings for 21 concepts...
✓ Concept embeddings created
  Shape: torch.Size([21, 384])


In [19]:
def retrieve_top_k_posts_sum(subject_id, posts_df, concept_embs, sbert, k=50, batch_size=32, debug=False):
    """
    Pick the k posts of one subject with the highest SUM of concept similarities.

    Args:
        subject_id: Value matched against ``posts_df['subject_id']``.
        posts_df: DataFrame with at least ``subject_id`` and ``text`` columns.
        concept_embs: Concept embedding tensor passed to ``util.cos_sim``.
        sbert: Encoder exposing ``encode(list_of_str, convert_to_tensor=True, ...)``.
        k: Number of posts to return.
        batch_size: Number of posts encoded per ``sbert.encode`` call.
        debug: If True, print score statistics and the top posts.

    Returns:
        A list of post texts. It is empty when the subject has no posts. With
        exactly k posts they are returned unchanged; with fewer than k the
        posts are padded to length k by resampling with replacement through
        ``np.random.choice``. Otherwise the k best-scoring posts are returned
        in descending score order.
    """
    subj_posts = posts_df.loc[posts_df['subject_id'] == subject_id, 'text'].tolist()
    n_posts = len(subj_posts)

    # Guard clauses for subjects that need no scoring at all.
    if n_posts == 0:
        return []
    if n_posts == k:
        return subj_posts
    if n_posts < k:
        padding = list(np.random.choice(subj_posts, size=k - n_posts, replace=True))
        return subj_posts + padding

    # Score posts batch by batch, keeping only the per-post sums on the CPU.
    score_chunks = []
    with torch.no_grad():
        for start in range(0, n_posts, batch_size):
            chunk_embeddings = sbert.encode(
                subj_posts[start:start + batch_size],
                convert_to_tensor=True,
                show_progress_bar=False
            )
            chunk_sims = util.cos_sim(chunk_embeddings, concept_embs)  # [batch, n_concepts]
            score_chunks.append(chunk_sims.sum(dim=1).cpu().numpy())

            # Drop tensor references before the next batch.
            del chunk_embeddings, chunk_sims

    sum_sim_scores = np.concatenate(score_chunks)

    if debug:
        print("\n" + "="*60)
        print(f"[DEBUG] Subject: {subject_id}")
        print(f"[DEBUG] Total posts: {len(subj_posts)}")
        print("[DEBUG] Sum similarity stats:")
        print(f"  min={sum_sim_scores.min():.4f} "
              f"max={sum_sim_scores.max():.4f} "
              f"mean={sum_sim_scores.mean():.4f} "
              f"std={sum_sim_scores.std():.4f}")

        ranking = np.argsort(-sum_sim_scores)
        print(f"\n[DEBUG] Top-{DEBUG_TOP_N_POSTS} retrieved posts:")
        for position, post_idx in enumerate(ranking[:DEBUG_TOP_N_POSTS], start=1):
            print(f"\n  Rank {position}")
            print(f"  Score: {sum_sim_scores[post_idx]:.4f}")
            print(f"  Text: {subj_posts[post_idx][:300]}")

    # Partitioning on every position 0..k-1 puts the k best scores, in order,
    # at the front of the index array.
    kth_positions = range(min(k, n_posts))
    top_k_indices = np.argpartition(-sum_sim_scores, kth_positions)[:k]

    return [subj_posts[post_idx] for post_idx in top_k_indices]

print("✓ Batched post retrieval function defined (SUM-based)")

✓ Batched post retrieval function defined (SUM-based)


In [None]:
# Retrieve top-k posts for all subjects
print(f"Retrieving top-{HYPERPARAMS['k_posts']} posts (SUM-based scoring with batching)...")
print("⏰ This will be faster and more memory-efficient")
start_time = time.time()


def _retrieve_split(posts_df, progress_desc):
    """Run sum-based top-k retrieval for every subject of one split.

    Returns ``(subjects, selected)`` where ``subjects`` is the array of unique
    subject ids in ``posts_df`` and ``selected`` maps each id to its posts.
    """
    subjects = posts_df['subject_id'].unique()
    selected_by_subject = {}
    for idx, subject_id in enumerate(tqdm(subjects, desc=progress_desc)):
        selected_by_subject[subject_id] = retrieve_top_k_posts_sum(
            subject_id,
            posts_df,
            concept_embeddings,
            sbert_model,
            k=HYPERPARAMS['k_posts'],
            batch_size=MEMORY_CONFIG['post_batch_size'],
            debug=(DEBUG and idx < DEBUG_N_SUBJECTS)  # only the first few subjects
        )

        # Clear GPU cache periodically
        if (idx + 1) % MEMORY_CONFIG['subject_cache_interval'] == 0:
            clear_gpu_cache()
    return subjects, selected_by_subject


# Training subjects
print("  Processing training subjects...")
train_subjects, train_selected = _retrieve_split(train_posts_df, "Train subjects")

# Validation subjects
print("\n  Processing validation subjects...")
val_subjects, val_selected = _retrieve_split(val_posts_df, "Val subjects")

# Test subjects
print("\n  Processing test subjects...")
test_subjects, test_selected = _retrieve_split(test_posts_df_final, "Test subjects")

# Final cache clear
clear_gpu_cache()

print(f"\n✓ Post retrieval complete in {time.time()-start_time:.1f}s ({(time.time()-start_time)/60:.1f} min)")
print(f"  Memory-optimized processing: {len(train_subjects) + len(val_subjects) + len(test_subjects)} subjects")

Retrieving top-50 posts (SUM-based scoring with batching)...
⏰ This will be faster and more memory-efficient
  Processing training subjects...


Train subjects:   0%|          | 1/486 [00:00<04:17,  1.88it/s]


[DEBUG] Subject: subject6760
[DEBUG] Total posts: 124
[DEBUG] Sum similarity stats:
  min=-0.9500 max=5.2591 mean=1.5585 std=1.3026

[DEBUG] Top-5 retrieved posts:

  Rank 1
  Score: 5.2591
  Text: Not feeling negative could be considered as better.

  Rank 2
  Score: 4.8450
  Text: The 12 Worst Habits for Your Mental Health

  Rank 3
  Score: 4.2735
  Text: Can you explain?

  Rank 4
  Score: 4.2102
  Text: A roller-coaster in which I always seem to be lost in the past or in the future, never living the moment.

  Rank 5
  Score: 3.8703
  Text: Selfish motives seem to make you succeed. I'm shocked!


Train subjects:   1%|          | 3/486 [00:01<02:43,  2.96it/s]


[DEBUG] Subject: subject7326
[DEBUG] Total posts: 118
[DEBUG] Sum similarity stats:
  min=-1.0593 max=3.1569 mean=1.0799 std=0.8906

[DEBUG] Top-5 retrieved posts:

  Rank 1
  Score: 3.1569
  Text: I have had a good friend in school, who changed A LOT in personality when he picked up these kind of games at the age of ~ 14. He isolated himself from others, started hating around, joined a "shooting club" (i dont know the englsih word for this) and started talking about killing people frequently,

  Rank 2
  Score: 3.0400
  Text: I actually DO think that playing ego-shooters increases your willingness to use violence and causes emotional blunting. edit: gettin' downvoted on an unpopular opinion thread, i did it reddit! ppl playing shooters are mad as fuck.

  Rank 3
  Score: 2.9997
  Text: "actively sabotage" lol. He offered her food. That's about it. He didn't fill her food with hidden chocolate and neither did he shove down pizza down her throat when she was asleep. There's NO point of

Train subjects: 100%|██████████| 486/486 [24:19<00:00,  3.00s/it]  



  Processing validation subjects...


Val subjects:   0%|          | 1/200 [00:01<04:36,  1.39s/it]


[DEBUG] Subject: test_subject3081
[DEBUG] Total posts: 121
[DEBUG] Sum similarity stats:
  min=-0.7924 max=4.1711 mean=1.5028 std=0.9968

[DEBUG] Top-5 retrieved posts:

  Rank 1
  Score: 4.1711
  Text: Sometimes I feel like a monster..

  Rank 2
  Score: 4.1295
  Text: Easily one of the most frustrating things about a relationship

  Rank 3
  Score: 3.3703
  Text: What's clunky?

  Rank 4
  Score: 3.3338
  Text: Burn it. Burn it all

  Rank 5
  Score: 3.2723
  Text: OMG DESCENT! !!


Val subjects:   1%|          | 2/200 [00:02<03:31,  1.07s/it]


[DEBUG] Subject: test_subject2751
[DEBUG] Total posts: 175
[DEBUG] Sum similarity stats:
  min=-0.7178 max=4.7579 mean=0.9268 std=0.8323

[DEBUG] Top-5 retrieved posts:

  Rank 1
  Score: 4.7579
  Text: What are some Causes Effects that most people don't realise?

  Rank 2
  Score: 2.8359
  Text: What sort of bizarre situation do you often simulate in your mind? When watching a live performance (e.g. concert/play/musical), I often think of what would happen if my mind is suddenly swapped with the performer. I would think of how much I'd panic as I don't know their material, and how the perfo

  Rank 3
  Score: 2.6792
  Text: I dislike games with traitor mechanism or those that require backstabbing. I just can't identify any fun factor. Why do people find lying and deceit to be fun? The only game I can stand in this genre is One Night Ultimate Werewolf because it lasts 10 minutes.

  Rank 4
  Score: 2.6308
  Text: **The Resistance: Avalon**, **Game of Thrones (2nd ed)**, **BSG**, **Cos

Val subjects:   2%|▏         | 3/200 [00:03<04:07,  1.26s/it]


[DEBUG] Subject: test_subject6974
[DEBUG] Total posts: 74
[DEBUG] Sum similarity stats:
  min=-0.7505 max=3.9495 mean=1.4575 std=0.9522

[DEBUG] Top-5 retrieved posts:

  Rank 1
  Score: 3.9495
  Text: Eating food that fell on the floor.

  Rank 2
  Score: 3.2105
  Text: blackholing

  Rank 3
  Score: 3.1062
  Text: BURRRRRRN.

  Rank 4
  Score: 3.0664
  Text: Good. Fuck them.

  Rank 5
  Score: 3.0180
  Text: Health Care.


Val subjects:  26%|██▌       | 51/200 [02:41<09:58,  4.02s/it]

In [None]:
# Retrieve top-k posts for all subjects
#
# NOTE: the preceding retrieval cell already fills train_selected,
# val_selected and test_selected. Running this cell as well would repeat the
# whole encoding pass and, because subjects with fewer than k posts are padded
# with np.random.choice, could replace the earlier selections with different
# ones. The work is therefore skipped when the selections already exist.
# Set FORCE_RERUN_RETRIEVAL = True to recompute on purpose.
FORCE_RERUN_RETRIEVAL = False

_selections_exist = all(
    _name in globals() for _name in ("train_selected", "val_selected", "test_selected")
)

if _selections_exist and not FORCE_RERUN_RETRIEVAL:
    print("✓ Top-k posts already retrieved - skipping duplicate retrieval")
    print(f"  Subjects: train={len(train_selected)}, "
          f"val={len(val_selected)}, test={len(test_selected)}")
else:
    print(f"Retrieving top-{HYPERPARAMS['k_posts']} posts (SUM-based scoring)...")
    print("⏰ This will take some time")
    start_time = time.time()

    def _select_posts_for_split(posts_df, progress_every=None):
        """Return (subjects, {subject_id: selected posts}) for one split."""
        subjects = posts_df['subject_id'].unique()
        selected_by_subject = {}
        for idx, subject_id in enumerate(subjects):
            selected_by_subject[subject_id] = retrieve_top_k_posts_sum(
                subject_id,
                posts_df,
                concept_embeddings,
                sbert_model,
                k=HYPERPARAMS['k_posts'],
                debug=(DEBUG and idx < DEBUG_N_SUBJECTS)
            )
            if progress_every and (idx + 1) % progress_every == 0:
                print(f"    Processed {idx + 1}/{len(subjects)} subjects")
        return subjects, selected_by_subject

    # Training subjects (progress line every 100 subjects)
    print("  Processing training subjects...")
    train_subjects, train_selected = _select_posts_for_split(train_posts_df, progress_every=100)

    # Validation subjects - scored against their own posts
    print("  Processing validation subjects...")
    val_subjects, val_selected = _select_posts_for_split(val_posts_df)

    # Test subjects - scored against their own posts
    print("  Processing test subjects...")
    test_subjects, test_selected = _select_posts_for_split(test_posts_df_final)

    print(f"\n✓ Post retrieval complete in {time.time()-start_time:.1f}s ({(time.time()-start_time)/60:.1f} min)")

In [None]:
def encode_and_attention_pool_sum(selected_posts_dict, sbert, concept_embs,
                                   normalize=True, debug=False):
    """
    Encode each subject's posts and pool them into one vector, using the SUM
    of concept similarities as attention scores.

    For every subject the posts are encoded with ``sbert``, the cosine
    similarity of each post to each concept is computed, the similarities are
    summed per post, and a softmax over those sums gives the weights of the
    weighted sum of post embeddings.

    Args:
        selected_posts_dict: Mapping ``subject_id -> list of post texts``.
        sbert: Encoder exposing ``encode(list_of_str, convert_to_tensor=True, ...)``.
        concept_embs: 2-D tensor of concept embeddings (one row per concept);
            its second dimension is used as the embedding size.
        normalize: If True, every concept column of the similarity matrix is
            L1-normalized across the subject's posts before summing, so each
            concept contributes the same total mass to the post scores. If
            False, raw cosine similarities are summed.
        debug: If True, print attention statistics for the first
            ``DEBUG_N_SUBJECTS`` subjects.

    Returns:
        ``(X, subject_ids)``: ``X`` has one pooled row per subject, in the
        order of ``subject_ids``. Subjects without usable posts get an
        all-zero row. An empty input yields an array of shape ``(0, dim)``.

    Note:
        The previous normalization divided each row by its own sum and then
        summed the row again, which gives ~1.0 for every post; the softmax was
        therefore uniform and the result was plain mean pooling. Normalizing
        per concept (across posts) keeps the scores informative.
    """
    subject_ids = list(selected_posts_dict.keys())
    emb_dim = int(concept_embs.shape[1])
    eps = 1e-10
    pooled_embeddings = []

    def _zero_embedding():
        # float32 so that one placeholder row does not upcast the whole matrix.
        return np.zeros(emb_dim, dtype=np.float32)

    with torch.no_grad():  # inference only, no gradient tracking needed
        for idx, subject_id in enumerate(subject_ids):
            posts = selected_posts_dict[subject_id]

            # Handle empty posts
            if len(posts) == 0:
                print(f"WARNING: No posts for subject {subject_id}, using zero embedding")
                pooled_embeddings.append(_zero_embedding())
                continue

            # Filter out empty posts
            posts = [p for p in posts if p.strip()]
            if len(posts) == 0:
                print(f"WARNING: All posts empty for subject {subject_id}, using zero embedding")
                pooled_embeddings.append(_zero_embedding())
                continue

            # Encode posts
            post_embs = sbert.encode(
                posts,
                convert_to_tensor=True,
                show_progress_bar=False
            )

            if post_embs.shape[0] == 0 or post_embs.shape[1] == 0:
                print(f"WARNING: Empty embeddings for subject {subject_id}, using zero embedding")
                pooled_embeddings.append(_zero_embedding())
                continue

            # Compute similarity to concepts: [n_posts, n_concepts]
            cos_scores = util.cos_sim(post_embs, concept_embs)

            # Sum across all concepts
            if normalize:
                # Scale every concept column by its total absolute mass over
                # the posts (dim=0). abs() keeps the divisor non-negative, so
                # mixed-sign similarities cannot cancel it to ~0.
                concept_mass = cos_scores.abs().sum(dim=0, keepdim=True)
                post_scores = (cos_scores / (concept_mass + eps)).sum(dim=1)
            else:
                post_scores = cos_scores.sum(dim=1)

            # Attention weights over the subject's posts
            attn_weights = torch.softmax(post_scores, dim=0)

            if debug and idx < DEBUG_N_SUBJECTS:
                print("\n" + "="*60)
                print(f"[DEBUG][ATTENTION] Subject: {subject_id}")
                attn_np = attn_weights.cpu().numpy()
                print("[DEBUG][ATTENTION] Weight stats:")
                print(f"  min={attn_np.min():.6f} "
                      f"max={attn_np.max():.6f} "
                      f"mean={attn_np.mean():.6f} "
                      f"entropy={-np.sum(attn_np * np.log(attn_np + 1e-12)):.4f}")

                top_attn_idx = np.argsort(-attn_np)[:DEBUG_TOP_N_POSTS]
                print(f"\n[DEBUG][ATTENTION] Top-{DEBUG_TOP_N_POSTS} attended posts:")
                for rank, i in enumerate(top_attn_idx):
                    print(f"\n  Rank {rank+1}")
                    print(f"  Attention: {attn_np[i]:.6f}")
                    print(f"  Text: {posts[i][:300]}")

            # Weighted sum pooling
            pooled = torch.sum(attn_weights.unsqueeze(1) * post_embs, dim=0)
            pooled_embeddings.append(pooled.cpu().numpy())

            # Drop tensor references before the next subject
            del post_embs, cos_scores, attn_weights, pooled

    if not pooled_embeddings:
        # np.vstack rejects an empty list; return a well-shaped empty matrix.
        return np.zeros((0, emb_dim), dtype=np.float32), subject_ids

    return np.vstack(pooled_embeddings), subject_ids

print("✓ Memory-optimized attention pooling function defined")

In [None]:
# Encode and pool for all splits
print("Encoding and pooling embeddings (SUM-based attention, memory-optimized)...")

start_time = time.time()


def _pool_split(selected_posts, **extra_kwargs):
    """Pool one split with normalization on, then release cached GPU memory."""
    features, ids = encode_and_attention_pool_sum(
        selected_posts, sbert_model, concept_embeddings, normalize=True, **extra_kwargs
    )
    clear_gpu_cache()
    return features, ids


print("  Training set...")
# Debug output is requested for the training split only.
X_train, train_subject_ids = _pool_split(train_selected, debug=DEBUG)
print(f"    X_train shape: {X_train.shape}")

print("  Validation set...")
X_val, val_subject_ids = _pool_split(val_selected)
print(f"    X_val shape: {X_val.shape}")

print("  Test set...")
X_test, test_subject_ids = _pool_split(test_selected)
print(f"    X_test shape: {X_test.shape}")

print(f"\n✓ Encoding complete in {time.time()-start_time:.1f}s ({(time.time()-start_time)/60:.1f} min)")

In [None]:
# Encode and pool for all splits
#
# NOTE: the preceding cell already computes X_train / X_val / X_test (and
# clears the GPU cache between splits). Encoding everything a second time only
# costs runtime, so this cell is skipped when those results already exist.
# Set FORCE_RERUN_POOLING = True to recompute on purpose.
FORCE_RERUN_POOLING = False

_pooled_names = (
    "X_train", "X_val", "X_test",
    "train_subject_ids", "val_subject_ids", "test_subject_ids",
)

if all(_name in globals() for _name in _pooled_names) and not FORCE_RERUN_POOLING:
    print("✓ Embeddings already pooled - skipping duplicate encoding")
    print(f"    X_train shape: {X_train.shape}")
    print(f"    X_val shape: {X_val.shape}")
    print(f"    X_test shape: {X_test.shape}")
else:
    print("Encoding and pooling embeddings (SUM-based attention)...")

    start_time = time.time()

    print("  Training set...")
    X_train, train_subject_ids = encode_and_attention_pool_sum(
        train_selected,
        sbert_model,
        concept_embeddings,
        normalize=True,
        debug=DEBUG
    )
    print(f"    X_train shape: {X_train.shape}")

    print("  Validation set...")
    X_val, val_subject_ids = encode_and_attention_pool_sum(
        val_selected, sbert_model, concept_embeddings, normalize=True
    )
    print(f"    X_val shape: {X_val.shape}")

    print("  Test set...")
    X_test, test_subject_ids = encode_and_attention_pool_sum(
        test_selected, sbert_model, concept_embeddings, normalize=True
    )
    print(f"    X_test shape: {X_test.shape}")

    print(f"\n✓ Encoding complete in {time.time()-start_time:.1f}s ({(time.time()-start_time)/60:.1f} min)")

## Section 6: Build Concept Matrices and Labels

In [None]:
# Build concept matrices and label vectors
print("Building concept matrices and labels...")


def _labels_for(subject_ids, posts_df):
    """Return the float32 label vector for ``subject_ids``.

    Each subject's label is taken from its first row in ``posts_df``. The
    lookup table is built once per split instead of filtering the whole posts
    DataFrame for every subject.
    """
    first_label = (
        posts_df.drop_duplicates(subset='subject_id', keep='first')
        .set_index('subject_id')['label']
    )
    return first_label.loc[list(subject_ids)].to_numpy(dtype=np.float32)


# Training: get concepts from questionnaires (first row per subject wins)
_concepts_by_subject = (
    concepts_df.drop_duplicates(subset='subject_id', keep='first')
    .set_index('subject_id')[concept_cols]
)
_subjects_without_concepts = [
    sid for sid in train_subject_ids if sid not in _concepts_by_subject.index
]
if _subjects_without_concepts:
    # These rows are zero-filled below; report it instead of doing so silently.
    print(f"WARNING: {len(_subjects_without_concepts)} training subject(s) have no "
          f"questionnaire row; their concepts are set to zero")

# reindex() aligns rows with train_subject_ids and leaves NaN for missing
# subjects, which become zeros. Every row has len(concept_cols) columns, so
# the matrix cannot end up ragged.
C_train = (
    _concepts_by_subject.reindex(list(train_subject_ids))
    .fillna(0)
    .to_numpy(dtype=np.float32)
)
y_train = _labels_for(train_subject_ids, train_posts_df)

# Validation: zeros for concepts (no ground truth)
C_val = np.zeros((len(val_subject_ids), N_CONCEPTS), dtype=np.float32)
y_val = _labels_for(val_subject_ids, val_posts_df)

# Test: zeros for concepts
C_test = np.zeros((len(test_subject_ids), N_CONCEPTS), dtype=np.float32)
y_test = _labels_for(test_subject_ids, test_posts_df_final)

print("✓ Matrices built")
print(f"  Train: X={X_train.shape}, C={C_train.shape}, y={y_train.shape}")
print(f"  Val:   X={X_val.shape}, C={C_val.shape}, y={y_val.shape}")
print(f"  Test:  X={X_test.shape}, C={C_test.shape}, y={y_test.shape}")
print(f"\n  Training label distribution: {np.bincount(y_train.astype(int))}")
print(f"  Validation label distribution: {np.bincount(y_val.astype(int))}")
print(f"  Test label distribution: {np.bincount(y_test.astype(int))}")

## Section 7: Compute Class Weights

In [None]:
# Compute class weights for imbalanced dataset
# pos_weight is the negative-to-positive ratio of the training labels.
n_negative = int(np.count_nonzero(y_train == 0))
n_positive = int(np.count_nonzero(y_train == 1))
pos_weight = n_negative / n_positive

print(f"Class imbalance:")
print(f"  Negative samples: {n_negative}")
print(f"  Positive samples: {n_positive}")
print(f"  Ratio: 1:{pos_weight:.2f}")
print(f"  Computed pos_weight: {pos_weight:.4f}")

## Section 8: Save All Datasets

Save everything for fast loading by training pipelines

In [None]:
# Save processed datasets to disk
print("Saving datasets...")

# Save numpy arrays: one compressed archive per split, each holding the
# features (X), concept matrix (C), labels (y) and the subject ids.
_splits_to_save = (
    ("train_data.npz", X_train, C_train, y_train, train_subject_ids),
    ("val_data.npz", X_val, C_val, y_val, val_subject_ids),
    ("test_data.npz", X_test, C_test, y_test, test_subject_ids),
)
for _filename, _features, _concept_matrix, _labels, _ids in _splits_to_save:
    np.savez_compressed(
        os.path.join(SAVE_DIR, _filename),
        X=_features,
        C=_concept_matrix,
        y=_labels,
        subject_ids=np.array(_ids)
    )

# Save class weights info
class_info = dict(
    n_positive=n_positive,
    n_negative=n_negative,
    pos_weight=float(pos_weight),
)

with open(os.path.join(SAVE_DIR, "class_weights.json"), 'w') as f:
    json.dump(class_info, f, indent=4)

print(f"✓ Datasets saved to {SAVE_DIR}")
print(f"  train_data.npz: {X_train.shape[0]} samples")
print(f"  val_data.npz:   {X_val.shape[0]} samples")
print(f"  test_data.npz:  {X_test.shape[0]} samples")
print(f"  class_weights.json")

## Section 9: Cleanup

In [None]:
# Clean up temporary directory
# Best effort: a failure to delete the extracted test chunks is reported but
# does not stop the notebook.
try:
    shutil.rmtree(temp_dir)
except Exception as cleanup_error:
    print(f"⚠ Failed to clean up temporary directory: {cleanup_error}")
else:
    print(f"✓ Cleaned up temporary directory: {temp_dir}")

In [None]:
# Final summary banner, emitted as one print of newline-joined lines.
_summary_lines = [
    "\n" + "="*70,
    "      ALTERNATIVE DATASET PREPARATION COMPLETE (SUM-BASED)",
    "="*70,
    "\nSaved files:",
    f"  {SAVE_DIR}/train_data.npz",
    f"  {SAVE_DIR}/val_data.npz",
    f"  {SAVE_DIR}/test_data.npz",
    f"  {SAVE_DIR}/class_weights.json",
    "\nKey difference from original:",
    "  - Uses SUM of concept similarities instead of MAX",
    "  - Captures posts relevant to MULTIPLE concepts",
    "  - Normalization enabled (each concept weighted equally)",
    "\nUse this data with CEM/CBM training notebooks!",
    "="*70,
]
print("\n".join(_summary_lines))