# eXtreme Alternative Attention Dataset Preparation - Temperature-Sharpened MAX

**Purpose:** Prepare the dataset using attention pooling driven by the temperature-sharpened MAX of concept similarities.

**Key Difference from 0c_prepare_max**:
- 0c: `post_score = max(similarity_to_each_concept)`
- This (0d): `post_score = max(similarity_to_each_concept / COSINE_TEMPERATURE)`

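Dividing by a positive temperature is monotonic, so it changes neither which concept wins the MAX nor how posts rank: `max(s / T) = max(s) / T`. Post retrieval uses the unscaled MAX and therefore selects the same posts as plain MAX scoring. The sharpening takes effect in attention pooling, where the scaled scores are clamped at 0 and passed through a softmax with its own temperature of 0.2. The effective softmax temperature on the raw cosine similarities is therefore `COSINE_TEMPERATURE × 0.2`.
Lower temperature → more extreme (winner-take-all) attention weights; higher temperature → smoother weights.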
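The following numpy sketch checks how the two temperatures interact, using hypothetical similarity values (4 posts × 3 concepts). The `softmax` helper and the numbers are illustrative assumptions, not taken from the notebook: scaling before or after the MAX gives identical scores and ranking, and the two scalings collapse into a single softmax temperature.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax over a 1-d array."""
    e = np.exp(x - x.max())
    return e / e.sum()

# Hypothetical cosine similarities: 4 posts x 3 concepts
cos = np.array([[0.60, 0.10, 0.05],
                [0.45, 0.50, 0.20],
                [0.10, 0.15, 0.30],
                [-0.05, 0.02, 0.01]])
COSINE_TEMPERATURE = 0.5
ATTN_TEMPERATURE = 0.2  # softmax temperature hard-coded in the pooling cell

# Scale -> MAX -> clamp (pooling-cell order) vs. MAX -> clamp -> scale
scale_first = np.maximum((cos / COSINE_TEMPERATURE).max(axis=1), 0.0)
max_first = np.maximum(cos.max(axis=1), 0.0) / COSINE_TEMPERATURE

# Two-stage scaling equals one softmax at temperature 0.5 * 0.2 = 0.1
w_two_stage = softmax(scale_first / ATTN_TEMPERATURE)
w_single = softmax(np.maximum(cos.max(axis=1), 0.0) / (COSINE_TEMPERATURE * ATTN_TEMPERATURE))
```

In other words, `COSINE_TEMPERATURE` acts as a second knob on the attention softmax rather than on post selection.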

**Runtime:** ~40-50 minutes (same as 0c)

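This notebook:
1. Loads training data from XML files and test data from zipped XML chunks
2. Splits the training subjects 80/20 into train and validation sets, stratified by label
3. Uses SBERT to retrieve the top-50 posts per subject, ranked by their MAX cosine similarity to the 21 concepts
4. Scores each retrieved post by the MAX of its temperature-scaled concept similarities
5. Pools post embeddings with softmax attention over those scores
6. Builds concept matrices, labels, and class weights
7. Saves everything to `data/processed/extreme_alternative_attention_pipeline/`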

## Section 0: Configuration & Setup

In [1]:
# Imports
import os
import glob
import re
import zipfile
import tempfile
import shutil
import json
import time

import numpy as np
import pandas as pd
import xml.etree.ElementTree as ET

import torch
from sentence_transformers import SentenceTransformer, util
from sklearn.model_selection import train_test_split
from tqdm import tqdm

print("✓ All imports successful")

✓ All imports successful


In [2]:
# Set random seeds for reproducibility
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

print(f"✓ Random seed set to {SEED}")

✓ Random seed set to 42


In [3]:
# Detect device (MPS/CUDA/CPU)
if torch.backends.mps.is_available():
    DEVICE = "mps"
    print("✓ Using MacBook GPU (MPS)")
elif torch.cuda.is_available():
    DEVICE = "cuda"
    print("✓ Using CUDA GPU")
else:
    DEVICE = "cpu"
    print("⚠ Using CPU (will be slow)")

✓ Using CUDA GPU


In [4]:
# Define paths
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
DATA_RAW = os.path.join(PROJECT_ROOT, "data/raw")
DATA_PROCESSED = os.path.join(PROJECT_ROOT, "data/processed")

# Training data paths
POS_DIR = os.path.join(DATA_RAW, "train/positive_examples_anonymous_chunks")
NEG_DIR = os.path.join(DATA_RAW, "train/negative_examples_anonymous_chunks")

# Test data paths
TEST_DIR = os.path.join(DATA_RAW, "test")
TEST_LABELS = os.path.join(TEST_DIR, "test_golden_truth.txt")

# Concept labels
CONCEPTS_FILE = os.path.join(DATA_PROCESSED, "merged_questionnaires.csv")

# Output directory - CHANGED FOR EXTREME ALTERNATIVE PIPELINE
SAVE_DIR = os.path.join(DATA_PROCESSED, "extreme_alternative_attention_pipeline")
os.makedirs(SAVE_DIR, exist_ok=True)

print("✓ Paths configured")
print(f"  Project root: {PROJECT_ROOT}")
print(f"  Data save dir: {SAVE_DIR}")

✓ Paths configured
  Project root: /teamspace/studios/this_studio/Master-Thesis-CEM-Depression-etc-case-study
  Data save dir: /teamspace/studios/this_studio/Master-Thesis-CEM-Depression-etc-case-study/data/processed/max_alternative_attention_pipeline


In [5]:
# Define 21 BDI-II concept names
CONCEPT_NAMES = [
    "Sadness", "Pessimism", "Past failure", "Loss of pleasure",
    "Guilty feelings", "Punishment feelings", "Self-dislike", "Self-criticalness",
    "Suicidal thoughts or wishes", "Crying", "Agitation", "Loss of interest",
    "Indecisiveness", "Worthlessness", "Loss of energy", "Changes in sleeping pattern",
    "Irritability", "Changes in appetite", "Concentration difficulty",
    "Tiredness or fatigue", "Loss of interest in sex"
]
N_CONCEPTS = len(CONCEPT_NAMES)

print(f"✓ Defined {N_CONCEPTS} BDI-II concepts")

✓ Defined 21 BDI-II concepts


In [6]:
# Hyperparameters
HYPERPARAMS = {
    "k_posts": 50,              # Top-k posts per subject
    "sbert_model": "all-MiniLM-L6-v2",
    "embedding_dim": 384,
}
# =========================
# DEBUG / SANITY CHECK CONFIG
# =========================
DEBUG = True
DEBUG_N_SUBJECTS = 3          # how many subjects to inspect
DEBUG_TOP_N_POSTS = 5         # how many top posts to print
DEBUG_PRINT_CONCEPTS = True   # print per-concept similarity stats

print("✓ Hyperparameters configured:")
for k, v in HYPERPARAMS.items():
    print(f"  {k}: {v}")

✓ Hyperparameters configured:
  k_posts: 50
  sbert_model: all-MiniLM-L6-v2
  embedding_dim: 384


In [None]:
# Temperature sharpening parameter
# Scales cosine similarities before MAX. This does not change post ranking;
# it sharpens the attention softmax (effective temperature = COSINE_TEMPERATURE * 0.2).
COSINE_TEMPERATURE = 0.5  # Lower = more extreme (winner-take-all), Higher = smoother
# Try: 0.3 (extreme), 0.5 (moderate), 0.7 (mild)

print(f"✓ Cosine temperature configured: {COSINE_TEMPERATURE}")
print("  Lower temp = sharper attention weights (more extreme)")
print("  Higher temp = smoother attention weights")

In [7]:
# Memory Management Configuration
MEMORY_CONFIG = {
    "post_batch_size": 32,        # Encode N posts at a time
    "subject_cache_interval": 10,  # Clear GPU cache every N subjects
    "use_no_grad": True,           # Disable gradient tracking
    "move_to_cpu_immediately": True # Move results to CPU after computation
}

print("✓ Memory configuration:")
for k, v in MEMORY_CONFIG.items():
    print(f"  {k}: {v}")

✓ Memory configuration:
  post_batch_size: 32
  subject_cache_interval: 10
  use_no_grad: True
  move_to_cpu_immediately: True


In [8]:
import gc

def clear_gpu_cache():
    """Clear GPU cache and run garbage collection."""
    if DEVICE == "mps":
        torch.mps.empty_cache()
    elif DEVICE == "cuda":
        torch.cuda.empty_cache()
    gc.collect()

print("✓ GPU cache clearing utility defined")

✓ GPU cache clearing utility defined


## Section 1: Load Training Data

Extract 486 training subjects with posts and concept labels

In [9]:
# Helper functions for XML parsing
WHITESPACE_RE = re.compile(r"\s+")

def normalize_text(text):
    """Normalize text by removing null chars and extra whitespace."""
    if not text:
        return ""
    text = text.replace("\u0000", "")
    text = WHITESPACE_RE.sub(" ", text).strip()
    return text

def extract_posts_from_xml(xml_path, min_chars=10):
    """Extract posts from a single XML file."""
    try:
        tree = ET.parse(xml_path)
        root = tree.getroot()
    except Exception as e:
        print(f"WARNING: Failed to parse {xml_path}: {e}")
        return []
    
    posts = []
    for writing in root.findall("WRITING"):
        title = writing.findtext("TITLE") or ""
        text = writing.findtext("TEXT") or ""
        
        combined = normalize_text(f"{title} {text}".strip())
        if len(combined) >= min_chars:
            posts.append(combined)
    
    return posts

print("✓ Helper functions defined")

✓ Helper functions defined
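The parsing helpers can be exercised without touching disk. This sketch reuses the same normalization and filtering logic on an in-memory string; the `<INDIVIDUAL>`/`<WRITING>`/`<TITLE>`/`<TEXT>` layout is an assumption inferred from the `findall("WRITING")` and `findtext(...)` calls above, and `extract_posts_from_string` is a hypothetical variant of `extract_posts_from_xml`.

```python
import re
import xml.etree.ElementTree as ET

WHITESPACE_RE = re.compile(r"\s+")

def normalize_text(text):
    """Same behaviour as the notebook helper: drop NULs, collapse whitespace."""
    if not text:
        return ""
    return WHITESPACE_RE.sub(" ", text.replace("\u0000", "")).strip()

# Hypothetical file content; the tag layout is an assumption
SAMPLE_XML = """<INDIVIDUAL>
  <ID>train_subject0001</ID>
  <WRITING><TITLE>  Bad week </TITLE><TEXT>Could not   sleep at all.</TEXT></WRITING>
  <WRITING><TITLE></TITLE><TEXT>ok</TEXT></WRITING>
  <WRITING><TITLE>Update</TITLE><TEXT>Feeling a bit
  better today.</TEXT></WRITING>
</INDIVIDUAL>"""

def extract_posts_from_string(xml_string, min_chars=10):
    """Title and text are joined, normalized, and kept if at least min_chars long."""
    root = ET.fromstring(xml_string)
    posts = []
    for writing in root.findall("WRITING"):
        title = writing.findtext("TITLE") or ""
        text = writing.findtext("TEXT") or ""
        combined = normalize_text(f"{title} {text}".strip())
        if len(combined) >= min_chars:  # drops very short posts such as "ok"
            posts.append(combined)
    return posts

posts = extract_posts_from_string(SAMPLE_XML)
```

Note that `min_chars=10` silently discards short replies, which affects the post counts reported below.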


In [10]:
# Parse training XML files
print("Loading training data...")
start_time = time.time()

train_data = []

# Process positive examples
print("  Processing positive examples...")
pos_files = glob.glob(os.path.join(POS_DIR, "**", "*.xml"), recursive=True)
for xml_file in tqdm(pos_files, desc="Processing positive examples"):
    filename = os.path.basename(xml_file)
    match = re.match(r"train_(subject\d+)_\d+\.xml", filename)
    if match:
        subject_id = match.group(1)
        posts = extract_posts_from_xml(xml_file)
        for post in posts:
            train_data.append({
                "subject_id": subject_id,
                "label": 1,  # Positive (depression)
                "text": post
            })

print(f"  Loaded {sum(d['label'] == 1 for d in train_data)} posts from positive subjects")

# Process negative examples
print("  Processing negative examples...")
neg_files = glob.glob(os.path.join(NEG_DIR, "**", "*.xml"), recursive=True)
for xml_file in tqdm(neg_files, desc="Processing negative examples"):
    filename = os.path.basename(xml_file)
    match = re.match(r"train_(subject\d+)_\d+\.xml", filename)
    if match:
        subject_id = match.group(1)
        posts = extract_posts_from_xml(xml_file)
        for post in posts:
            train_data.append({
                "subject_id": subject_id,
                "label": 0,  # Negative (control)
                "text": post
            })

train_posts_df = pd.DataFrame(train_data)

print(f"\n✓ Loaded training data in {time.time()-start_time:.1f}s")
print(f"  Total posts: {len(train_posts_df):,}")
print(f"  Unique subjects: {train_posts_df['subject_id'].nunique()}")
print(f"  Label distribution:")
print(train_posts_df.groupby('label')['subject_id'].nunique())

Loading training data...
  Processing positive examples...


Processing positive examples: 100%|██████████| 830/830 [00:00<00:00, 1268.71it/s]


  Loaded 29868 posts from positive subjects
  Processing negative examples...


Processing negative examples: 100%|██████████| 4031/4031 [00:05<00:00, 805.55it/s]



✓ Loaded training data in 6.0s
  Total posts: 286,740
  Unique subjects: 486
  Label distribution:
label
0    403
1     83
Name: subject_id, dtype: int64
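The log shows 830 positive files for 83 positive subjects, i.e. ten files per subject on average, so each subject's history is spread across several chunk files and the filename regex is what merges them. A sketch of that grouping on hypothetical filenames (only the naming patterns are taken from the notebook): training IDs drop the `train_` prefix, while test IDs keep `test_`, which is why the concept cell strips `train_` from the questionnaire `Subject` column.

```python
import re
from collections import defaultdict

# Same patterns as the training and test parsing cells
TRAIN_RE = re.compile(r"train_(subject\d+)_\d+\.xml")
TEST_RE = re.compile(r"(test_subject\d+)_\d+\.xml")

# Hypothetical filenames; only the naming pattern comes from the notebook
filenames = [
    "train_subject9115_1.xml",
    "train_subject9115_2.xml",
    "train_subject42_10.xml",
    "README.txt",
]

files_per_subject = defaultdict(list)
for name in filenames:
    match = TRAIN_RE.match(name)
    if match:  # non-matching files are skipped without a warning
        files_per_subject[match.group(1)].append(name)

test_id = TEST_RE.match("test_subject6048_3.xml").group(1)
```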


In [11]:
# Load concept labels from questionnaires
print("Loading concept labels...")

concepts_df = pd.read_csv(CONCEPTS_FILE)
concepts_df["subject_id"] = concepts_df["Subject"].str.replace("train_", "", regex=True)

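# Binarize concept values (any nonzero item score -> concept present)
# Order columns as in CONCEPT_NAMES so they line up with concept_embeddings
concept_cols = [col for col in CONCEPT_NAMES if col in concepts_df.columns]
missing_concepts = set(CONCEPT_NAMES) - set(concept_cols)
assert not missing_concepts, f"Concepts missing from {CONCEPTS_FILE}: {missing_concepts}"
for col in concept_cols:
    concepts_df[col] = (concepts_df[col] > 0).astype(int)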

print(f"✓ Loaded concept labels for {len(concepts_df)} subjects")

Loading concept labels...
✓ Loaded concept labels for 486 subjects
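The `> 0` rule turns each questionnaire item into a presence flag, discarding severity (BDI-II items are scored 0 to 3). A small pandas sketch on hypothetical rows with a truncated concept list; the `Total` column is an invented example of a non-concept column that the filter ignores.

```python
import pandas as pd

CONCEPT_NAMES = ["Sadness", "Pessimism", "Crying"]  # truncated list for illustration

# Hypothetical questionnaire rows with item scores 0-3
concepts_df = pd.DataFrame({
    "Subject": ["train_subject1", "train_subject2"],
    "Sadness": [0, 2],
    "Pessimism": [1, 0],
    "Crying": [3, 0],
    "Total": [4, 2],  # non-concept column, ignored by the filter below
})
# Strip the prefix so IDs match the ones produced by the XML filename regex
concepts_df["subject_id"] = concepts_df["Subject"].str.replace("train_", "", regex=False)

# Keep concept columns in CONCEPT_NAMES order, then binarize
concept_cols = [c for c in CONCEPT_NAMES if c in concepts_df.columns]
binary = (concepts_df[concept_cols] > 0).astype(int)
```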


## Section 2: Load Test Data

Load all 401 test subjects from test folder (will be used entirely as test set)

In [12]:
# Extract test ZIP files to temporary directory
print("Extracting test data...")
temp_dir = tempfile.mkdtemp(prefix="test_chunks_")
print(f"  Temp directory: {temp_dir}")

for i in range(1, 11):
    zip_path = os.path.join(TEST_DIR, f"chunk {i}.zip")
    if os.path.exists(zip_path):
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(os.path.join(temp_dir, f"chunk_{i}"))
        if i % 3 == 0:
            print(f"  Extracted chunk {i}/10")

print("✓ Test data extracted")

Extracting test data...
  Temp directory: /tmp/test_chunks_q60_s85r
  Extracted chunk 3/10
  Extracted chunk 6/10
  Extracted chunk 9/10
✓ Test data extracted


In [13]:
# Load test labels
test_labels_df = pd.read_csv(TEST_LABELS, sep='\t', header=None, names=['subject_id', 'label'])
test_labels_df['subject_id'] = test_labels_df['subject_id'].str.strip()

print(f"✓ Loaded test labels for {len(test_labels_df)} subjects")
print(f"  Label distribution:")
print(test_labels_df['label'].value_counts())

✓ Loaded test labels for 401 subjects
  Label distribution:
label
0    349
1     52
Name: count, dtype: int64


In [14]:
# Parse test XML files
print("Loading test posts...")
test_data = []

test_xml_files = glob.glob(os.path.join(temp_dir, "**", "*.xml"), recursive=True)
print(f"  Found {len(test_xml_files)} XML files")

for xml_file in test_xml_files:
    filename = os.path.basename(xml_file)
    match = re.match(r"(test_subject\d+)_\d+\.xml", filename)
    if match:
        subject_id = match.group(1)
        label_row = test_labels_df[test_labels_df['subject_id'] == subject_id]
        if len(label_row) > 0:
            label = label_row.iloc[0]['label']
            posts = extract_posts_from_xml(xml_file)
            for post in posts:
                test_data.append({
                    "subject_id": subject_id,
                    "label": label,
                    "text": post
                })

test_posts_df = pd.DataFrame(test_data)

print(f"✓ Loaded test posts")
print(f"  Total posts: {len(test_posts_df):,}")
print(f"  Unique subjects: {test_posts_df['subject_id'].nunique()}")

Loading test posts...
  Found 4010 XML files
✓ Loaded test posts
  Total posts: 229,746
  Unique subjects: 401
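The test-parsing cell filters `test_labels_df` once per XML file, which is about 1.6 million row comparisons for 4,010 files and 401 label rows. A dictionary built once gives the same labels with constant-time lookups. A sketch on hypothetical rows; `label_map` and `label_for` are illustrative names, not part of the notebook.

```python
import pandas as pd

# Hypothetical golden-truth rows in the same two-column layout
test_labels_df = pd.DataFrame({
    "subject_id": ["test_subject6048 ", "test_subject7"],
    "label": [0, 1],
})
test_labels_df["subject_id"] = test_labels_df["subject_id"].str.strip()

# Build the lookup once instead of scanning the DataFrame for every file
label_map = dict(zip(test_labels_df["subject_id"], test_labels_df["label"]))

def label_for(subject_id):
    """Return the label, or None if the subject is absent from the truth file."""
    return label_map.get(subject_id)
```

Inside the parsing loop, `label = label_for(subject_id)` followed by `if label is not None:` would replace the filter-and-`iloc` pattern.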


In [15]:
# Split TRAINING data into train and validation (80/20)
print("Splitting training data into train (80%) and validation (20%)...")

train_subjects = train_posts_df.groupby('subject_id')['label'].first().reset_index()

train_subjects_final, val_subjects = train_test_split(
    train_subjects['subject_id'],
    test_size=0.2,
    stratify=train_subjects['label'],
    random_state=SEED
)

# Create new train dataframe with only 80% of subjects
train_posts_df_final = train_posts_df[train_posts_df['subject_id'].isin(train_subjects_final)].copy()

# Create validation dataframe from remaining 20% of training subjects
val_posts_df = train_posts_df[train_posts_df['subject_id'].isin(val_subjects)].copy()

# Keep ALL test data as test set (no split)
test_posts_df_final = test_posts_df.copy()

print(f"✓ Split complete")
print(f"  Training: {train_posts_df_final['subject_id'].nunique()} subjects (80% of original train)")
print(f"  Validation: {val_posts_df['subject_id'].nunique()} subjects (20% of original train)")
print(f"  Test: {test_posts_df_final['subject_id'].nunique()} subjects (100% of test folder)")

# Show label distributions
print(f"\n  Training label distribution:")
print(train_posts_df_final.groupby('label')['subject_id'].nunique())
print(f"\n  Validation label distribution:")
print(val_posts_df.groupby('label')['subject_id'].nunique())
print(f"\n  Test label distribution:")
print(test_posts_df_final.groupby('label')['subject_id'].nunique())

Splitting training data into train (80%) and validation (20%)...
✓ Split complete
  Training: 388 subjects (80% of original train)
  Validation: 98 subjects (20% of original train)
  Test: 401 subjects (100% of test folder)

  Training label distribution:
label
0    322
1     66
Name: subject_id, dtype: int64

  Validation label distribution:
label
0    81
1    17
Name: subject_id, dtype: int64

  Test label distribution:
label
0    349
1     52
Name: subject_id, dtype: int64
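The split is done on subject IDs rather than on posts, so no subject contributes posts to both train and validation. This sketch reproduces it on a synthetic subject table with the logged class counts (403 negative, 83 positive). With `test_size=0.2`, scikit-learn rounds the validation size up to 98 of 486 subjects, and stratification puts 17 positives in validation, consistent with the log above.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

SEED = 42
# Synthetic subject table matching the logged class counts
subjects = pd.DataFrame({
    "subject_id": [f"subject{i}" for i in range(486)],
    "label": [0] * 403 + [1] * 83,
})

train_ids, val_ids = train_test_split(
    subjects["subject_id"],
    test_size=0.2,
    stratify=subjects["label"],
    random_state=SEED,
)

# Count positives that landed in the validation split
val_pos = int(subjects.set_index("subject_id").loc[val_ids, "label"].sum())
```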


## Section 3: SBERT Setup & Concept Embeddings

In [16]:
# Load SBERT model
print(f"Loading SBERT model: {HYPERPARAMS['sbert_model']}")
sbert_model = SentenceTransformer(HYPERPARAMS['sbert_model'])
sbert_model = sbert_model.to(DEVICE)

print(f"✓ SBERT model loaded on {DEVICE}")
print(f"  Embedding dimension: {sbert_model.get_sentence_embedding_dimension()}")

Loading SBERT model: all-MiniLM-L6-v2
✓ SBERT model loaded on cuda
  Embedding dimension: 384


In [17]:
# Create concept embeddings
print(f"Creating embeddings for {N_CONCEPTS} concepts...")
concept_embeddings = sbert_model.encode(
    CONCEPT_NAMES,
    convert_to_tensor=True,
    show_progress_bar=False
)

print(f"✓ Concept embeddings created")
print(f"  Shape: {concept_embeddings.shape}")

Creating embeddings for 21 concepts...
✓ Concept embeddings created
  Shape: torch.Size([21, 384])


In [18]:
def retrieve_top_k_posts_max(subject_id, posts_df, concept_embs, sbert, k=50, batch_size=32, debug=False):
    """
    Retrieve top-k posts for a subject based on MAX of concept similarities.
    OPTIMIZED: Uses batching to prevent memory exhaustion.
    
    For each post, takes MAX similarity across all 21 concepts.
    Selects posts that are highly relevant to at least ONE concept.
    """
    subj_posts = posts_df[posts_df['subject_id'] == subject_id]['text'].tolist()

    if len(subj_posts) == 0:
        return []

    if len(subj_posts) <= k:
        if len(subj_posts) < k:
            extra_needed = k - len(subj_posts)
            padding = list(np.random.choice(subj_posts, size=extra_needed, replace=True))
            return subj_posts + padding
        else:
            return subj_posts

    # Batch encoding to prevent memory issues
    max_sim_scores = []

    with torch.no_grad():  # Disable gradient tracking
        for i in range(0, len(subj_posts), batch_size):
            batch_posts = subj_posts[i:i + batch_size]

            # Encode batch
            batch_embeddings = sbert.encode(
                batch_posts,
                convert_to_tensor=True,
                show_progress_bar=False
            )

            # Compute similarities for this batch
            cos_scores = util.cos_sim(batch_embeddings, concept_embs)  # [batch, 21]
            # KEY: Take MAX instead of SUM
            batch_max_scores = cos_scores.max(dim=1)[0].cpu().numpy()  # [0] gets values, not indices

            max_sim_scores.extend(batch_max_scores)

            # Clear references
            del batch_embeddings, cos_scores, batch_max_scores

    max_sim_scores = np.array(max_sim_scores)

    if debug:
        print("\n" + "="*60)
        print(f"[DEBUG] Subject: {subject_id}")
        print(f"[DEBUG] Total posts: {len(subj_posts)}")
        print("[DEBUG] Max similarity stats:")
        print(f"  min={max_sim_scores.min():.4f} "
              f"max={max_sim_scores.max():.4f} "
              f"mean={max_sim_scores.mean():.4f} "
              f"std={max_sim_scores.std():.4f}")

        top_idx_sorted = np.argsort(-max_sim_scores)
        print(f"\n[DEBUG] Top-{DEBUG_TOP_N_POSTS} retrieved posts:")
        for rank, i in enumerate(top_idx_sorted[:DEBUG_TOP_N_POSTS]):
            print(f"\n  Rank {rank+1}")
            print(f"  Score: {max_sim_scores[i]:.4f}")
            print(f"  Text: {subj_posts[i][:300]}")

    # Select top-k posts
    top_k_indices = np.argpartition(-max_sim_scores, range(min(k, len(subj_posts))))[:k]

    return [subj_posts[i] for i in top_k_indices]

print("✓ Batched post retrieval function defined (MAX-based)")

✓ Batched post retrieval function defined (MAX-based)
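The retrieval logic can be mirrored in plain numpy on toy vectors. In this sketch, `select_top_k` and `cosine_matrix` are hypothetical stand-ins for `retrieve_top_k_posts_max` and `util.cos_sim`; it returns post indices instead of texts. It also checks that the notebook's `np.argpartition(-scores, range(k))[:k]` yields the same ordering as a plain `argsort` when there are no ties.

```python
import numpy as np

def cosine_matrix(a, b):
    """Pairwise cosine similarity between rows of a [n, d] and b [m, d]."""
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return a @ b.T

def select_top_k(post_embs, concept_embs, k, rng):
    """Hypothetical numpy stand-in for retrieve_top_k_posts_max; returns indices."""
    n = post_embs.shape[0]
    if n == 0:
        return np.array([], dtype=int)
    if n <= k:
        # Short histories: keep every post, then pad by resampling with replacement
        pad = rng.choice(n, size=k - n, replace=True)
        return np.concatenate([np.arange(n), pad])
    max_scores = cosine_matrix(post_embs, concept_embs).max(axis=1)  # MAX over concepts
    return np.argsort(-max_scores)[:k]  # highest-scoring posts first

# Toy example: 3 orthogonal "concepts" and 5 posts in a 3-d space
rng = np.random.default_rng(42)
concepts = np.eye(3)
posts = np.array([[0.2, 1.0, 0.0],
                  [1.0, 1.0, 0.0],
                  [1.0, 1.0, 1.0],
                  [3.0, 0.0, 0.1],
                  [1.0, 2.0, 2.0]])
top2 = select_top_k(posts, concepts, k=2, rng=rng)
padded = select_top_k(posts[:2], concepts, k=5, rng=rng)
```

One consequence of padding worth keeping in mind: duplicated posts receive identical attention scores in the pooling step, so a subject with few posts effectively up-weights whichever posts were resampled.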


In [19]:
# Retrieve top-k posts for all subjects
print(f"Retrieving top-{HYPERPARAMS['k_posts']} posts (MAX-based scoring with batching)...")
print("⏰ This will be faster and more memory-efficient")
start_time = time.time()

# Training subjects (80% of original training data)
print("  Processing training subjects (80% of train data)...")
train_selected = {}
train_subjects = train_posts_df_final['subject_id'].unique()

for idx, subject_id in enumerate(tqdm(train_subjects, desc="Train subjects")):
    selected = retrieve_top_k_posts_max(
        subject_id,
        train_posts_df_final,
        concept_embeddings,
        sbert_model,
        k=HYPERPARAMS['k_posts'],
        batch_size=MEMORY_CONFIG['post_batch_size'],
        debug=(DEBUG and idx < DEBUG_N_SUBJECTS)
    )
    train_selected[subject_id] = selected

    # Clear GPU cache periodically
    if (idx + 1) % MEMORY_CONFIG['subject_cache_interval'] == 0:
        clear_gpu_cache()

# Validation subjects (20% of original training data)
print("\n  Processing validation subjects (20% of train data)...")
val_selected = {}
val_subjects = val_posts_df['subject_id'].unique()

for idx, subject_id in enumerate(tqdm(val_subjects, desc="Val subjects")):
    selected = retrieve_top_k_posts_max(
        subject_id,
        val_posts_df,
        concept_embeddings,
        sbert_model,
        k=HYPERPARAMS['k_posts'],
        batch_size=MEMORY_CONFIG['post_batch_size'],
        debug=(DEBUG and idx < DEBUG_N_SUBJECTS)
    )
    val_selected[subject_id] = selected

    if (idx + 1) % MEMORY_CONFIG['subject_cache_interval'] == 0:
        clear_gpu_cache()

# Test subjects (100% of test folder)
print("\n  Processing test subjects (100% of test folder)...")
test_selected = {}
test_subjects = test_posts_df_final['subject_id'].unique()

for idx, subject_id in enumerate(tqdm(test_subjects, desc="Test subjects")):
    selected = retrieve_top_k_posts_max(
        subject_id,
        test_posts_df_final,
        concept_embeddings,
        sbert_model,
        k=HYPERPARAMS['k_posts'],
        batch_size=MEMORY_CONFIG['post_batch_size'],
        debug=(DEBUG and idx < DEBUG_N_SUBJECTS)
    )
    test_selected[subject_id] = selected

    if (idx + 1) % MEMORY_CONFIG['subject_cache_interval'] == 0:
        clear_gpu_cache()

# Final cache clear
clear_gpu_cache()

print(f"\n✓ Post retrieval complete in {time.time()-start_time:.1f}s ({(time.time()-start_time)/60:.1f} min)")
print(f"  Memory-optimized processing: {len(train_subjects) + len(val_subjects) + len(test_subjects)} subjects")

Retrieving top-50 posts (MAX-based scoring with batching)...
⏰ This will be faster and more memory-efficient
  Processing training subjects (80% of train data)...


Train subjects:   1%|          | 2/388 [00:01<04:00,  1.60it/s]


[DEBUG] Subject: subject9115
[DEBUG] Total posts: 1254
[DEBUG] Max similarity stats:
  min=-0.0131 max=0.6024 mean=0.2002 std=0.0868

[DEBUG] Top-5 retrieved posts:

  Rank 1
  Score: 0.6024
  Text: As someone who knows what it's like to suffer depression, I've always condoned suicide if the person really feels they have nothing left.

  Rank 2
  Score: 0.5655
  Text: "It's a permanent solution to a temporary problem" really angers me. I'm talking about suicide, of course. My mom says that phrase all the damn time, and I just want to say, "Or it could be a permanent solution to a problem that never goes away and eats at you until you have to do something to end i

  Rank 3
  Score: 0.5450
  Text: "Suicide is a permanent solution to a temporary problem."

  Rank 4
  Score: 0.5372
  Text: What is the purpose of crying? I understand you might cry if you need to get something out of your eye, but what purpose does crying when you're hurt or because you're sad serve? Oh, I don't mean why d

Train subjects: 100%|██████████| 388/388 [03:16<00:00,  1.98it/s]



  Processing validation subjects (20% of train data)...


Val subjects:   3%|▎         | 3/98 [00:00<00:28,  3.28it/s]


[DEBUG] Subject: subject1095
[DEBUG] Total posts: 1019
[DEBUG] Max similarity stats:
  min=0.0050 max=0.4646 mean=0.1622 std=0.0623

[DEBUG] Top-5 retrieved posts:

  Rank 1
  Score: 0.4646
  Text: The quality of 0 is still null.

  Rank 2
  Score: 0.4530
  Text: I lost my virginity.

  Rank 3
  Score: 0.4064
  Text: Suiciders has been pretty good.

  Rank 4
  Score: 0.4053
  Text: It is a thing. But you totally probably have a sadness fetish.

  Rank 5
  Score: 0.3968
  Text: Older women having high libido, being so rare, it gets its own term, hence "cougar"?

[DEBUG] Subject: subject1457
[DEBUG] Total posts: 93
[DEBUG] Max similarity stats:
  min=0.0673 max=0.4384 mean=0.2524 std=0.0822

[DEBUG] Top-5 retrieved posts:

  Rank 1
  Score: 0.4384
  Text: When I was a teenager, my method of suicide was drugs. I thought I had pumped enough into my system to just finally end it all. But then I woke up. The immediate thoughts were "I'm such a fuck-up I can't even kill myself right." It was

Val subjects: 100%|██████████| 98/98 [00:52<00:00,  1.87it/s]



  Processing test subjects (100% of test folder)...


Test subjects:   0%|          | 2/401 [00:00<01:49,  3.63it/s]


[DEBUG] Subject: test_subject6048
[DEBUG] Total posts: 621
[DEBUG] Max similarity stats:
  min=0.0102 max=0.4379 mean=0.1624 std=0.0683

[DEBUG] Top-5 retrieved posts:

  Rank 1
  Score: 0.4379
  Text: Came here to say basically the same thing- so frustrated at the moment.

  Rank 2
  Score: 0.4201
  Text: Your probably waking up pretty dehydrated like most people, which is probably making your muscles pretty cramped. Combine this with having your body being completely stationary for the last 8 or so hours and you're gonna have a bumpy start. I'd say just keep getting up and doing it, your body will n

  Rank 3
  Score: 0.4090
  Text: How was the sex?

  Rank 4
  Score: 0.3784
  Text: It's an illness, probably has diarrhea

  Rank 5
  Score: 0.3749
  Text: Its pretty much become impossible not to like him.


Test subjects: 100%|██████████| 401/401 [03:30<00:00,  1.90it/s]



✓ Post retrieval complete in 459.7s (7.7 min)
  Memory-optimized processing: 887 subjects


In [20]:
def encode_and_attention_pool_max(selected_posts_dict, sbert, concept_embs,
                                   normalize=True, debug=False):
    """
    Encode posts and pool them with attention driven by the MAX concept similarity.
    Memory-conscious: runs under torch.no_grad() and frees GPU tensors per subject.

    For each subject's posts:
    1. Compute cosine similarity to all 21 concepts
    2. Divide by COSINE_TEMPERATURE, take the MAX over concepts, clamp at 0
    3. Attention weights = softmax(post_scores / 0.2) over the subject's posts
    4. Weighted-sum pooling gives the subject embedding

    Note: `normalize` is currently unused; pooled embeddings are not re-normalized.
    """
    subject_ids = list(selected_posts_dict.keys())
    pooled_embeddings = []

    with torch.no_grad():  # Disable gradient tracking
        for idx, subject_id in enumerate(subject_ids):
            posts = selected_posts_dict[subject_id]

            # Handle empty posts
            if len(posts) == 0:
                print(f"WARNING: No posts for subject {subject_id}, using zero embedding")
                pooled_embeddings.append(np.zeros(384))
                continue

            # Filter out empty posts
            posts = [p for p in posts if p.strip()]
            if len(posts) == 0:
                print(f"WARNING: All posts empty for subject {subject_id}, using zero embedding")
                pooled_embeddings.append(np.zeros(384))
                continue

            # Encode posts
            post_embs = sbert.encode(
                posts,
                convert_to_tensor=True,
                show_progress_bar=False
            )

            if post_embs.shape[0] == 0 or post_embs.shape[1] == 0:
                print(f"WARNING: Empty embeddings for subject {subject_id}, using zero embedding")
                pooled_embeddings.append(np.zeros(384))
                continue

            # Compute similarity to concepts
            cos_scores = util.cos_sim(post_embs, concept_embs)

            # Temperature sharpening on cosine similarities
            cos_scores = cos_scores / COSINE_TEMPERATURE

            # Take MAX over concepts after sharpening ([0] gets values, not indices)
            post_scores = cos_scores.max(dim=1)[0]

            # Clamp negative scores to 0 after MAX
            post_scores = torch.clamp(post_scores, min=0.0)

            # Attention weights over posts (softmax temperature 0.2;
            # effective temperature on raw cosines = COSINE_TEMPERATURE * 0.2)
            ATTN_TEMPERATURE = 0.2
            attn_weights = torch.softmax(post_scores / ATTN_TEMPERATURE, dim=0)

            if debug and idx < DEBUG_N_SUBJECTS:
                print("\n" + "="*60)
                print(f"[DEBUG][ATTENTION] Subject: {subject_id}")
                attn_np = attn_weights.cpu().numpy()
                print("[DEBUG][ATTENTION] Weight stats:")
                print(f"  min={attn_np.min():.6f} "
                      f"max={attn_np.max():.6f} "
                      f"mean={attn_np.mean():.6f} "
                      f"entropy={-np.sum(attn_np * np.log(attn_np + 1e-12)):.4f}")

                top_attn_idx = np.argsort(-attn_np)[:DEBUG_TOP_N_POSTS]
                print(f"\n[DEBUG][ATTENTION] Top-{DEBUG_TOP_N_POSTS} attended posts:")
                for rank, i in enumerate(top_attn_idx):
                    print(f"\n  Rank {rank+1}")
                    print(f"  Attention: {attn_np[i]:.6f}")
                    print(f"  Text: {posts[i][:300]}")

            # Weighted sum pooling
            pooled = torch.sum(attn_weights.unsqueeze(1) * post_embs, dim=0)
            pooled_embeddings.append(pooled.cpu().numpy())

            # Clean up GPU memory
            del post_embs, cos_scores, attn_weights, pooled

    return np.vstack(pooled_embeddings), subject_ids

print("✓ Memory-optimized attention pooling function defined (MAX-based)")

✓ Memory-optimized attention pooling function defined (MAX-based)
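For intuition about the debug statistics, here is a numpy mirror of the pooling math on random vectors; `attention_pool` is a hypothetical helper, not the notebook's function. The entropy uses the same formula as the debug print. Its upper bound for 50 posts is ln 50 ≈ 3.912 (uniform weights), so the logged entropy of 3.8648 for the first training subject suggests attention that is still close to uniform, even at an effective temperature of 0.1.

```python
import numpy as np

def attention_pool(post_embs, concept_embs, cosine_temperature=0.5, attn_temperature=0.2):
    """Hypothetical numpy mirror of the pooling cell: returns (pooled, weights, entropy)."""
    p = post_embs / np.linalg.norm(post_embs, axis=1, keepdims=True)
    c = concept_embs / np.linalg.norm(concept_embs, axis=1, keepdims=True)
    # Sharpen, take MAX over concepts, clamp at 0
    scores = np.maximum((p @ c.T / cosine_temperature).max(axis=1), 0.0)
    logits = scores / attn_temperature
    w = np.exp(logits - logits.max())
    w /= w.sum()                                      # softmax over posts
    pooled = w @ post_embs                            # weighted sum of the given embeddings
    entropy = float(-(w * np.log(w + 1e-12)).sum())  # same formula as the debug print
    return pooled, w, entropy

rng = np.random.default_rng(0)
posts = rng.normal(size=(50, 8))     # 50 hypothetical post embeddings
concepts = rng.normal(size=(4, 8))   # 4 hypothetical concept embeddings
pooled, w, H = attention_pool(posts, concepts)
```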


In [21]:
# Encode and pool for all splits
print("Encoding and pooling embeddings (MAX-based attention, memory-optimized)...")

start_time = time.time()

print("  Training set...")
X_train, train_subject_ids = encode_and_attention_pool_max(
    train_selected,
    sbert_model,
    concept_embeddings,
    normalize=True,
    debug=DEBUG
)
clear_gpu_cache()
print(f"    X_train shape: {X_train.shape}")

print("  Validation set...")
X_val, val_subject_ids = encode_and_attention_pool_max(
    val_selected, sbert_model, concept_embeddings, normalize=True
)
clear_gpu_cache()
print(f"    X_val shape: {X_val.shape}")

print("  Test set...")
X_test, test_subject_ids = encode_and_attention_pool_max(
    test_selected, sbert_model, concept_embeddings, normalize=True
)
clear_gpu_cache()
print(f"    X_test shape: {X_test.shape}")

print(f"\n✓ Encoding complete in {time.time()-start_time:.1f}s ({(time.time()-start_time)/60:.1f} min)")

Encoding and pooling embeddings (MAX-based attention, memory-optimized)...
  Training set...

[DEBUG][ATTENTION] Subject: subject9115
[DEBUG][ATTENTION] Weight stats:
  min=0.014175 max=0.044475 mean=0.020000 entropy=3.8648

[DEBUG][ATTENTION] Top-5 attended posts:

  Rank 1
  Attention: 0.044475
  Text: As someone who knows what it's like to suffer depression, I've always condoned suicide if the person really feels they have nothing left.

  Rank 2
  Attention: 0.036982
  Text: "It's a permanent solution to a temporary problem" really angers me. I'm talking about suicide, of course. My mom says that phrase all the damn time, and I just want to say, "Or it could be a permanent solution to a problem that never goes away and eats at you until you have to do something to end i

  Rank 3
  Attention: 0.033384
  Text: "Suicide is a permanent solution to a temporary problem."

  Rank 4
  Attention: 0.032098
  Text: What is the purpose of crying? I understand you might cry if you need to get 

## Section 6: Build Concept Matrices and Labels

In [22]:
# Build concept matrices and label vectors
print("Building concept matrices and labels...")

# Training: get concepts from questionnaires (80% of training data)
C_train = []
y_train = []
for subject_id in train_subject_ids:
    label = train_posts_df_final[train_posts_df_final['subject_id'] == subject_id]['label'].iloc[0]
    y_train.append(label)
    
    concept_row = concepts_df[concepts_df['subject_id'] == subject_id]
    if len(concept_row) > 0:
        concepts = concept_row[concept_cols].values[0]
    else:
        concepts = np.zeros(N_CONCEPTS)
    C_train.append(concepts)

C_train = np.array(C_train, dtype=np.float32)
y_train = np.array(y_train, dtype=np.float32)

# Validation: get concepts from questionnaires (20% of training data)
C_val = []
y_val = []
for subject_id in val_subject_ids:
    label = val_posts_df[val_posts_df['subject_id'] == subject_id]['label'].iloc[0]
    y_val.append(label)
    
    concept_row = concepts_df[concepts_df['subject_id'] == subject_id]
    if len(concept_row) > 0:
        concepts = concept_row[concept_cols].values[0]
    else:
        concepts = np.zeros(N_CONCEPTS)
    C_val.append(concepts)

C_val = np.array(C_val, dtype=np.float32)
y_val = np.array(y_val, dtype=np.float32)

# Test: zeros for concepts (no ground truth available)
C_test = np.zeros((len(test_subject_ids), N_CONCEPTS), dtype=np.float32)
y_test = []
for subject_id in test_subject_ids:
    label = test_posts_df_final[test_posts_df_final['subject_id'] == subject_id]['label'].iloc[0]
    y_test.append(label)
y_test = np.array(y_test, dtype=np.float32)

print("✓ Matrices built")
print(f"  Train: X={X_train.shape}, C={C_train.shape}, y={y_train.shape}")
print(f"  Val:   X={X_val.shape}, C={C_val.shape}, y={y_val.shape}")
print(f"  Test:  X={X_test.shape}, C={C_test.shape}, y={y_test.shape}")
print(f"\n  Training label distribution: {np.bincount(y_train.astype(int))}")
print(f"  Validation label distribution: {np.bincount(y_val.astype(int))}")
print(f"  Test label distribution: {np.bincount(y_test.astype(int))}")

Building concept matrices and labels...
✓ Matrices built
  Train: X=(388, 384), C=(388, 21), y=(388,)
  Val:   X=(98, 384), C=(98, 21), y=(98,)
  Test:  X=(401, 384), C=(401, 21), y=(401,)

  Training label distribution: [322  66]
  Validation label distribution: [81 17]
  Test label distribution: [349  52]
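The row order of `C` must match the order of `X`, which follows the subject IDs returned by the pooling function. The loops above achieve this one subject at a time and silently zero-fill subjects without a questionnaire row. A vectorized alternative with `reindex`, sketched on hypothetical data, keeps the same alignment and zero-fill while making missing subjects easy to list. It assumes `subject_id` is unique in the questionnaire table, which `reindex` requires.

```python
import numpy as np
import pandas as pd

CONCEPT_COLS = ["Sadness", "Pessimism", "Crying"]  # truncated list for illustration

# Hypothetical binarized questionnaire table and embedding subject order
concepts_df = pd.DataFrame({
    "subject_id": ["subject1", "subject2"],
    "Sadness": [1, 0],
    "Pessimism": [0, 1],
    "Crying": [1, 1],
})
subject_ids = ["subject2", "subjectX", "subject1"]  # subjectX has no questionnaire

C = (concepts_df.set_index("subject_id")[CONCEPT_COLS]
     .reindex(subject_ids)   # align rows to the embedding order
     .fillna(0)              # missing subjects -> all-zero concepts, like the loop
     .to_numpy(dtype=np.float32))

known = set(concepts_df["subject_id"])
missing = [s for s in subject_ids if s not in known]
```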


## Section 7: Compute Class Weights

In [23]:
# Compute class weights for imbalanced dataset
n_negative = int(np.sum(y_train == 0))
n_positive = int(np.sum(y_train == 1))
pos_weight = n_negative / n_positive

print(f"Class imbalance:")
print(f"  Negative samples: {n_negative}")
print(f"  Positive samples: {n_positive}")
print(f"  Ratio: 1:{pos_weight:.2f}")
print(f"  Computed pos_weight: {pos_weight:.4f}")

Class imbalance:
  Negative samples: 322
  Positive samples: 66
  Ratio: 1:4.88
  Computed pos_weight: 4.8788
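`pos_weight` scales the loss of positive examples only. This section does not show how the training notebooks consume it; one common use is `torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight]))`. The numpy sketch below follows the formula documented for that loss with a scalar weight; `weighted_bce_with_logits` is a hypothetical helper.

```python
import numpy as np

# Training class counts from the log above
n_negative, n_positive = 322, 66
pos_weight = n_negative / n_positive  # about 4.8788

def weighted_bce_with_logits(logits, targets, pos_weight):
    """Mean weighted BCE on logits, following the formula documented for
    torch.nn.BCEWithLogitsLoss with a scalar pos_weight."""
    log_sig = -np.logaddexp(0.0, -logits)           # log(sigmoid(x)), stable
    log_one_minus_sig = -np.logaddexp(0.0, logits)  # log(1 - sigmoid(x)), stable
    losses = -(pos_weight * targets * log_sig + (1.0 - targets) * log_one_minus_sig)
    return float(losses.mean())
```

With this weighting, each of the 66 positive subjects contributes roughly as much total loss as 4.88 negatives, which balances the two classes in aggregate.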


## Section 8: Save All Datasets

Save everything for fast loading by training pipelines

In [24]:
# Save processed datasets to disk
print("Saving datasets...")

# Save numpy arrays
np.savez_compressed(
    os.path.join(SAVE_DIR, "train_data.npz"),
    X=X_train,
    C=C_train,
    y=y_train,
    subject_ids=np.array(train_subject_ids)
)

np.savez_compressed(
    os.path.join(SAVE_DIR, "val_data.npz"),
    X=X_val,
    C=C_val,
    y=y_val,
    subject_ids=np.array(val_subject_ids)
)

np.savez_compressed(
    os.path.join(SAVE_DIR, "test_data.npz"),
    X=X_test,
    C=C_test,
    y=y_test,
    subject_ids=np.array(test_subject_ids)
)

# Save class weights info
class_info = {
    "n_positive": n_positive,
    "n_negative": n_negative,
    "pos_weight": float(pos_weight)
}

with open(os.path.join(SAVE_DIR, "class_weights.json"), 'w') as f:
    json.dump(class_info, f, indent=4)

print(f"✓ Datasets saved to {SAVE_DIR}")
print(f"  train_data.npz: {X_train.shape[0]} samples")
print(f"  val_data.npz:   {X_val.shape[0]} samples")
print(f"  test_data.npz:  {X_test.shape[0]} samples")
print(f"  class_weights.json")

Saving datasets...
✓ Datasets saved to /teamspace/studios/this_studio/Master-Thesis-CEM-Depression-etc-case-study/data/processed/max_alternative_attention_pipeline
  train_data.npz: 388 samples
  val_data.npz:   98 samples
  test_data.npz:  401 samples
  class_weights.json
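Downstream notebooks can load these files with plain numpy and json. In this sketch, `load_split` and `load_pos_weight` are hypothetical helpers, exercised on a throwaway directory with tiny arrays. No `allow_pickle` is needed because the subject IDs are saved as a fixed-width Unicode array rather than Python objects.

```python
import json
import os
import tempfile

import numpy as np

def load_split(save_dir, split):
    """Load X, C, y and subject IDs from {split}_data.npz (split: train/val/test)."""
    with np.load(os.path.join(save_dir, f"{split}_data.npz")) as data:
        return data["X"], data["C"], data["y"], data["subject_ids"].tolist()

def load_pos_weight(save_dir):
    """Read pos_weight from class_weights.json."""
    with open(os.path.join(save_dir, "class_weights.json")) as f:
        return json.load(f)["pos_weight"]

# Round-trip check in a throwaway directory with tiny hypothetical arrays
with tempfile.TemporaryDirectory() as tmp:
    np.savez_compressed(
        os.path.join(tmp, "val_data.npz"),
        X=np.zeros((2, 384), dtype=np.float32),
        C=np.zeros((2, 21), dtype=np.float32),
        y=np.array([0.0, 1.0], dtype=np.float32),
        subject_ids=np.array(["subject1", "subject2"]),
    )
    with open(os.path.join(tmp, "class_weights.json"), "w") as f:
        json.dump({"n_positive": 66, "n_negative": 322, "pos_weight": 322 / 66}, f)
    X, C, y, ids = load_split(tmp, "val")
    pw = load_pos_weight(tmp)
```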


## Section 9: Cleanup

In [25]:
# Clean up temporary directory
try:
    shutil.rmtree(temp_dir)
    print(f"✓ Cleaned up temporary directory: {temp_dir}")
except Exception as e:
    print(f"⚠ Failed to clean up temporary directory: {e}")

✓ Cleaned up temporary directory: /tmp/test_chunks_q60_s85r


In [26]:
print("\n" + "="*70)
print("      EXTREME ALTERNATIVE DATASET PREPARATION COMPLETE")
print("="*70)
print("\nSaved files:")
print(f"  {SAVE_DIR}/train_data.npz")
print(f"  {SAVE_DIR}/val_data.npz")
print(f"  {SAVE_DIR}/test_data.npz")
print(f"  {SAVE_DIR}/class_weights.json")
print("\nData split strategy:")
print(f"  - Training: 80% of train folder ({len(train_subject_ids)} subjects)")
print(f"  - Validation: 20% of train folder ({len(val_subject_ids)} subjects)")
print(f"  - Test: 100% of test folder ({len(test_subject_ids)} subjects)")
print("\nKey difference from 0c_prepare_max:")
print("  - Scores posts by the MAX of temperature-scaled concept similarities")
print("  - Scaling preserves post ranking, so retrieval matches plain MAX selection")
print(f"  - COSINE_TEMPERATURE={COSINE_TEMPERATURE} sharpens the attention weights used for pooling")
print("\nUse this data with CEM/CBM training notebooks!")
print("="*70)


      MAX ALTERNATIVE DATASET PREPARATION COMPLETE

Saved files:
  /teamspace/studios/this_studio/Master-Thesis-CEM-Depression-etc-case-study/data/processed/max_alternative_attention_pipeline/train_data.npz
  /teamspace/studios/this_studio/Master-Thesis-CEM-Depression-etc-case-study/data/processed/max_alternative_attention_pipeline/val_data.npz
  /teamspace/studios/this_studio/Master-Thesis-CEM-Depression-etc-case-study/data/processed/max_alternative_attention_pipeline/test_data.npz
  /teamspace/studios/this_studio/Master-Thesis-CEM-Depression-etc-case-study/data/processed/max_alternative_attention_pipeline/class_weights.json

Data split strategy:
  - Training: 80% of train folder (~389 subjects)
  - Validation: 20% of train folder (~97 subjects)
  - Test: 100% of test folder (401 subjects)

Key difference from original:
  - Uses MAX of concept similarities instead of SUM
  - Captures posts highly relevant to at least ONE concept
  - Focuses on 'specialist' posts rather than 'generali