# Max Alternative Attention Dataset Preparation - Max-Based Scoring

**Purpose:** Prepare dataset using MAX of concept similarities instead of SUM.

**Key Difference**:
- Original (0_prepare): `post_score = sum(similarity_to_each_concept)`
- This notebook (0c): `post_score = max(similarity_to_each_concept)`

Max-based scoring favors posts that are HIGHLY relevant to at least ONE concept ("specialists"), whereas sum-based scoring favors posts that are moderately relevant to many concepts ("generalists").

**Runtime:** ~40-50 minutes (same as original)

This notebook:
1. Loads training and test data from XML files
2. Uses SBERT to retrieve top-50 concept-relevant posts per subject (max-based scoring)
3. Pools post embeddings using max-based attention weights
4. Saves everything to `data/processed/max_alternative_attention_pipeline/`

## Section 0: Configuration & Setup

In [None]:
# Imports
import os
import glob
import re
import zipfile
import tempfile
import shutil
import json
import time

import numpy as np
import pandas as pd
import xml.etree.ElementTree as ET

import torch
from sentence_transformers import SentenceTransformer, util
from sklearn.model_selection import train_test_split
from tqdm import tqdm

print("✓ All imports successful")

In [None]:
# Seed the NumPy and PyTorch global RNGs with one shared value so that
# repeated runs of the notebook draw the same random numbers.
SEED = 42

for _seed_rng in (np.random.seed, torch.manual_seed):
    _seed_rng(SEED)

print(f"✓ Random seed set to {SEED}")

In [None]:
def _pick_device():
    """Return ``(device, message)``, trying MPS first, then CUDA, then CPU."""
    if torch.backends.mps.is_available():
        return "mps", "✓ Using MacBook GPU (MPS)"
    if torch.cuda.is_available():
        return "cuda", "✓ Using CUDA GPU"
    return "cpu", "⚠ Using CPU (will be slow)"


# DEVICE is the device string used by the rest of the notebook.
DEVICE, _device_message = _pick_device()
print(_device_message)

In [None]:
# Filesystem layout. The project root is taken to be two directories above
# the current working directory.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))
DATA_RAW = os.path.join(PROJECT_ROOT, "data/raw")
DATA_PROCESSED = os.path.join(PROJECT_ROOT, "data/processed")

# Raw training chunks, split by class
POS_DIR = os.path.join(DATA_RAW, "train/positive_examples_anonymous_chunks")
NEG_DIR = os.path.join(DATA_RAW, "train/negative_examples_anonymous_chunks")

# Raw test folder and its label file
TEST_DIR = os.path.join(DATA_RAW, "test")
TEST_LABELS = os.path.join(TEST_DIR, "test_golden_truth.txt")

# Questionnaire-derived concept labels
CONCEPTS_FILE = os.path.join(DATA_PROCESSED, "merged_questionnaires.csv")

# Output directory for the max-alternative pipeline; created if it is missing
SAVE_DIR = os.path.join(DATA_PROCESSED, "max_alternative_attention_pipeline")
os.makedirs(SAVE_DIR, exist_ok=True)

print("✓ Paths configured")
for _path_label, _path_value in (
    ("Project root", PROJECT_ROOT),
    ("Data save dir", SAVE_DIR),
):
    print(f"  {_path_label}: {_path_value}")

In [None]:
# BDI-II concept names, one per line. The list order is the order in which
# the concept embeddings are computed from these names.
CONCEPT_NAMES = [
    "Sadness",
    "Pessimism",
    "Past failure",
    "Loss of pleasure",
    "Guilty feelings",
    "Punishment feelings",
    "Self-dislike",
    "Self-criticalness",
    "Suicidal thoughts or wishes",
    "Crying",
    "Agitation",
    "Loss of interest",
    "Indecisiveness",
    "Worthlessness",
    "Loss of energy",
    "Changes in sleeping pattern",
    "Irritability",
    "Changes in appetite",
    "Concentration difficulty",
    "Tiredness or fatigue",
    "Loss of interest in sex",
]

# Number of concepts, derived from the list so the two cannot drift apart
N_CONCEPTS = len(CONCEPT_NAMES)

print(f"✓ Defined {N_CONCEPTS} BDI-II concepts")

In [None]:
# Retrieval / encoding hyperparameters
HYPERPARAMS = dict(
    k_posts=50,                      # Top-k posts per subject
    sbert_model="all-MiniLM-L6-v2",  # SentenceTransformer model name
    embedding_dim=384,
)

# -------------------------------------------------
# Debug / sanity-check switches
# -------------------------------------------------
DEBUG = True
DEBUG_N_SUBJECTS = 3          # how many subjects to inspect
DEBUG_TOP_N_POSTS = 5         # how many top posts to print
DEBUG_PRINT_CONCEPTS = True   # print per-concept similarity stats

print("✓ Hyperparameters configured:")
for _param_name in HYPERPARAMS:
    print(f"  {_param_name}: {HYPERPARAMS[_param_name]}")

In [None]:
# Knobs that bound memory use during encoding
MEMORY_CONFIG = dict(
    post_batch_size=32,            # Encode N posts at a time
    subject_cache_interval=10,     # Clear GPU cache every N subjects
    use_no_grad=True,              # Disable gradient tracking
    move_to_cpu_immediately=True,  # Move results to CPU after computation
)

print("✓ Memory configuration:")
for _setting_name in MEMORY_CONFIG:
    print(f"  {_setting_name}: {MEMORY_CONFIG[_setting_name]}")

In [None]:
import gc

def clear_gpu_cache():
    """Release cached accelerator memory, then run the garbage collector.

    Reads the module-level ``DEVICE`` string. For "mps" or "cuda" the matching
    ``torch.<backend>.empty_cache()`` is called; any other value goes straight
    to ``gc.collect()``.
    """
    if DEVICE in ("mps", "cuda"):
        # Resolve torch.mps / torch.cuda lazily, only for the active backend.
        getattr(torch, DEVICE).empty_cache()
    gc.collect()


print("✓ GPU cache clearing utility defined")

## Section 1: Load Training Data

Extract the 486 training subjects, together with their posts and concept labels.

In [None]:
# Helper functions for XML parsing
# Matches any run of whitespace so it can be collapsed into a single space.
WHITESPACE_RE = re.compile(r"\s+")


def normalize_text(text):
    """Return *text* with NUL characters removed and whitespace collapsed.

    Every run of whitespace becomes one space and the ends are trimmed.
    Falsy input (``None`` or ``""``) yields an empty string.
    """
    if text:
        without_nulls = text.replace("\u0000", "")
        return WHITESPACE_RE.sub(" ", without_nulls).strip()
    return ""


def extract_posts_from_xml(xml_path, min_chars=10):
    """Return the normalized posts found in one XML file.

    Each ``WRITING`` element directly under the root contributes one post:
    its ``TITLE`` and ``TEXT`` joined by a space and passed through
    ``normalize_text``. Posts shorter than ``min_chars`` characters are
    dropped. If the file cannot be parsed, a warning is printed and an empty
    list is returned.
    """
    try:
        root = ET.parse(xml_path).getroot()
    except Exception as e:
        print(f"WARNING: Failed to parse {xml_path}: {e}")
        return []

    def _as_post(writing):
        # A missing or empty TITLE/TEXT contributes an empty string.
        title = writing.findtext("TITLE") or ""
        body = writing.findtext("TEXT") or ""
        return normalize_text(f"{title} {body}".strip())

    candidates = (_as_post(writing) for writing in root.findall("WRITING"))
    return [post for post in candidates if len(post) >= min_chars]


print("✓ Helper functions defined")

In [None]:
# Parse training XML files into one row per post (subject_id, label, text).
print("Loading training data...")
start_time = time.time()

train_data = []

# Chunk files are named like "train_subject1234_5.xml"; group 1 is the subject id.
train_file_re = re.compile(r"train_(subject\d+)_\d+\.xml")

# glob makes no promise about result order, so sort the paths: the order of
# posts and subjects then no longer depends on the filesystem. The sort is
# lexicographic, so "_10.xml" comes before "_2.xml".
pos_files = sorted(glob.glob(os.path.join(POS_DIR, "**", "*.xml"), recursive=True))
neg_files = sorted(glob.glob(os.path.join(NEG_DIR, "**", "*.xml"), recursive=True))

for split_name, xml_files, split_label in (
    ("positive", pos_files, 1),  # Positive (depression)
    ("negative", neg_files, 0),  # Negative (control)
):
    print(f"  Processing {split_name} examples...")
    posts_before = len(train_data)
    for xml_file in tqdm(xml_files, desc=f"Processing {split_name} examples"):
        match = train_file_re.match(os.path.basename(xml_file))
        if match is None:
            # Files that do not follow the naming scheme are skipped.
            continue
        subject_id = match.group(1)
        train_data.extend(
            {"subject_id": subject_id, "label": split_label, "text": post}
            for post in extract_posts_from_xml(xml_file)
        )
    print(f"  Loaded {len(train_data) - posts_before} posts from {split_name} subjects")

# Explicit columns keep the frame well-formed even when no posts were found,
# so the summary below reports zeros instead of failing with a KeyError.
train_posts_df = pd.DataFrame(train_data, columns=["subject_id", "label", "text"])

print(f"\n✓ Loaded training data in {time.time()-start_time:.1f}s")
print(f"  Total posts: {len(train_posts_df):,}")
print(f"  Unique subjects: {train_posts_df['subject_id'].nunique()}")
print(f"  Label distribution:")
print(train_posts_df.groupby('label')['subject_id'].nunique())

In [None]:
# Read the questionnaire CSV and turn each concept column into a 0/1 label.
print("Loading concept labels...")

concepts_df = pd.read_csv(CONCEPTS_FILE)
# Drop the "train_" prefix from the Subject column to form subject_id.
concepts_df["subject_id"] = concepts_df["Subject"].str.replace("train_", "", regex=True)

# Keep the CSV columns whose name is a known concept.
# NOTE(review): the order here follows the CSV, not CONCEPT_NAMES, and the
# list can be shorter than N_CONCEPTS if the CSV lacks a column — confirm the
# CSV has all concepts in the expected order.
concept_cols = [name for name in concepts_df.columns if name in CONCEPT_NAMES]

# Binarize: any value above zero counts as the concept being present.
for concept in concept_cols:
    concepts_df[concept] = concepts_df[concept].gt(0).astype(int)

print(f"✓ Loaded concept labels for {len(concepts_df)} subjects")

## Section 2: Load Test Data

Load all 401 test subjects from the test folder (all of them are used as the test set).

In [None]:
# Unpack "chunk 1.zip" .. "chunk 10.zip" from the test folder into a fresh
# temporary directory, one sub-folder per chunk. Missing archives are skipped.
print("Extracting test data...")
temp_dir = tempfile.mkdtemp(prefix="test_chunks_")
print(f"  Temp directory: {temp_dir}")

for chunk_no in range(1, 11):
    archive_path = os.path.join(TEST_DIR, f"chunk {chunk_no}.zip")
    if not os.path.exists(archive_path):
        continue
    chunk_target = os.path.join(temp_dir, f"chunk_{chunk_no}")
    with zipfile.ZipFile(archive_path, 'r') as archive:
        archive.extractall(chunk_target)
    # Progress is reported on every third extracted chunk only.
    if chunk_no % 3 == 0:
        print(f"  Extracted chunk {chunk_no}/10")

print("✓ Test data extracted")

In [None]:
# Read the label file: tab-separated, no header row, (subject_id, label).
test_labels_df = pd.read_csv(
    TEST_LABELS,
    sep='\t',
    header=None,
    names=['subject_id', 'label'],
)
# Trim stray whitespace around the ids.
test_labels_df = test_labels_df.assign(
    subject_id=test_labels_df['subject_id'].str.strip()
)

print(f"✓ Loaded test labels for {len(test_labels_df)} subjects")
print(f"  Label distribution:")
print(test_labels_df['label'].value_counts())

In [None]:
# Parse test XML files into one row per post (subject_id, label, text).
print("Loading test posts...")
test_data = []

# glob makes no promise about result order, so sort the paths: the order of
# posts and subjects then no longer depends on the filesystem.
test_xml_files = sorted(glob.glob(os.path.join(temp_dir, "**", "*.xml"), recursive=True))
print(f"  Found {len(test_xml_files)} XML files")

# Build the subject -> label lookup once instead of scanning the labels frame
# for every file. keep="first" means the first row wins for a repeated id.
label_by_subject = (
    test_labels_df.drop_duplicates(subset="subject_id", keep="first")
    .set_index("subject_id")["label"]
    .to_dict()
)

# Chunk files are named like "test_subject1234_5.xml"; group 1 is the subject id.
test_file_re = re.compile(r"(test_subject\d+)_\d+\.xml")

for xml_file in test_xml_files:
    match = test_file_re.match(os.path.basename(xml_file))
    if match is None:
        continue
    subject_id = match.group(1)
    if subject_id not in label_by_subject:
        # Subjects without a label are left out of the test set.
        continue
    label = label_by_subject[subject_id]
    test_data.extend(
        {"subject_id": subject_id, "label": label, "text": post}
        for post in extract_posts_from_xml(xml_file)
    )

# Explicit columns keep the frame well-formed even when no posts were found,
# so the summary below reports zeros instead of failing with a KeyError.
test_posts_df = pd.DataFrame(test_data, columns=["subject_id", "label", "text"])

print(f"✓ Loaded test posts")
print(f"  Total posts: {len(test_posts_df):,}")
print(f"  Unique subjects: {test_posts_df['subject_id'].nunique()}")

In [None]:
# Subject-level 80/20 split of the training folder into train and validation.
# Whole subjects are assigned to one side, stratified by label.
print("Splitting training data into train (80%) and validation (20%)...")

# One row per subject, carrying that subject's label
train_subjects = train_posts_df.groupby('subject_id')['label'].first().reset_index()

train_subjects_final, val_subjects = train_test_split(
    train_subjects['subject_id'],
    stratify=train_subjects['label'],
    test_size=0.2,
    random_state=SEED,
)

# Route every post to the side its subject was assigned to
_in_train = train_posts_df['subject_id'].isin(train_subjects_final)
_in_val = train_posts_df['subject_id'].isin(val_subjects)
train_posts_df_final = train_posts_df.loc[_in_train].copy()
val_posts_df = train_posts_df.loc[_in_val].copy()

# The test folder is used whole, with no split
test_posts_df_final = test_posts_df.copy()

_split_summary = (
    ("Training", train_posts_df_final, "80% of original train"),
    ("Validation", val_posts_df, "20% of original train"),
    ("Test", test_posts_df_final, "100% of test folder"),
)

print(f"✓ Split complete")
for _split_name, _split_df, _split_note in _split_summary:
    print(f"  {_split_name}: {_split_df['subject_id'].nunique()} subjects ({_split_note})")

# Subjects per label in each split
for _split_name, _split_df, _split_note in _split_summary:
    print(f"\n  {_split_name} label distribution:")
    print(_split_df.groupby('label')['subject_id'].nunique())

## Section 3: SBERT Setup & Concept Embeddings

In [None]:
# Load the SentenceTransformer named in HYPERPARAMS and move it to DEVICE.
_sbert_name = HYPERPARAMS['sbert_model']
print(f"Loading SBERT model: {_sbert_name}")
sbert_model = SentenceTransformer(_sbert_name).to(DEVICE)

print(f"✓ SBERT model loaded on {DEVICE}")
print(f"  Embedding dimension: {sbert_model.get_sentence_embedding_dimension()}")

In [None]:
# Encode the concept names once; the resulting tensor is reused for every
# post-to-concept similarity computation.
print(f"Creating embeddings for {N_CONCEPTS} concepts...")
_concept_encode_options = dict(convert_to_tensor=True, show_progress_bar=False)
concept_embeddings = sbert_model.encode(CONCEPT_NAMES, **_concept_encode_options)

print(f"✓ Concept embeddings created")
print(f"  Shape: {concept_embeddings.shape}")

In [None]:
def retrieve_top_k_posts_max(subject_id, posts_df, concept_embs, sbert, k=50, batch_size=32, debug=False):
    """
    Pick k posts for one subject, ranked by their best concept similarity.

    A post's score is the MAX cosine similarity between its embedding and any
    concept embedding, so a post that matches a single concept strongly ranks
    above one that matches several concepts weakly.

    Args:
        subject_id: Subject whose posts are selected.
        posts_df: DataFrame with 'subject_id' and 'text' columns.
        concept_embs: Concept embeddings passed to ``util.cos_sim``.
        sbert: Encoder whose ``encode`` method embeds a list of posts.
        k: Number of posts to return.
        batch_size: Number of posts encoded per ``sbert.encode`` call.
        debug: Print score statistics and the top-scoring posts.

    Returns:
        A list of post texts:
        - empty if the subject has no posts;
        - all posts, padded to length k with posts re-drawn at random with
          replacement (NumPy global RNG), if the subject has fewer than k;
        - all posts unchanged if the subject has exactly k;
        - otherwise the k highest-scoring posts, best first.
        Nothing is encoded unless the subject has more than k posts.
    """
    subject_mask = posts_df['subject_id'] == subject_id
    subj_posts = posts_df.loc[subject_mask, 'text'].tolist()
    n_posts = len(subj_posts)

    # Short-circuit cases: nothing to rank.
    if n_posts == 0:
        return []
    if n_posts == k:
        return subj_posts
    if n_posts < k:
        padding = list(np.random.choice(subj_posts, size=k - n_posts, replace=True))
        return subj_posts + padding

    # Score posts batch by batch so only one batch of embeddings is alive at a time.
    score_chunks = []
    with torch.no_grad():
        for start in range(0, n_posts, batch_size):
            chunk_embs = sbert.encode(
                subj_posts[start:start + batch_size],
                convert_to_tensor=True,
                show_progress_bar=False,
            )
            sims = util.cos_sim(chunk_embs, concept_embs)  # [batch, n_concepts]
            # KEY: MAX over concepts, not SUM
            score_chunks.append(sims.max(dim=1).values.cpu().numpy())
            del chunk_embs, sims

    max_sim_scores = np.concatenate(score_chunks)

    if debug:
        print("\n" + "="*60)
        print(f"[DEBUG] Subject: {subject_id}")
        print(f"[DEBUG] Total posts: {n_posts}")
        print("[DEBUG] Max similarity stats:")
        print(
            f"  min={max_sim_scores.min():.4f} max={max_sim_scores.max():.4f} "
            f"mean={max_sim_scores.mean():.4f} std={max_sim_scores.std():.4f}"
        )

        ranked = np.argsort(-max_sim_scores)
        print(f"\n[DEBUG] Top-{DEBUG_TOP_N_POSTS} retrieved posts:")
        for rank, post_idx in enumerate(ranked[:DEBUG_TOP_N_POSTS], start=1):
            print(f"\n  Rank {rank}")
            print(f"  Score: {max_sim_scores[post_idx]:.4f}")
            print(f"  Text: {subj_posts[post_idx][:300]}")

    # Here n_posts > k. Partitioning on every position 0..k-1 places the k
    # best scores first, in descending order.
    top_k_indices = np.argpartition(-max_sim_scores, range(k))[:k]
    return [subj_posts[post_idx] for post_idx in top_k_indices]


print("✓ Batched post retrieval function defined (MAX-based)")

In [None]:
# Select the top-k posts of every subject in each split.
print(f"Retrieving top-{HYPERPARAMS['k_posts']} posts (MAX-based scoring with batching)...")
print("⏰ This will be faster and more memory-efficient")
start_time = time.time()


def _select_posts_for_split(posts_df, subject_ids, progress_label):
    """Run the max-based retrieval for each subject; return {subject_id: posts}."""
    picked = {}
    for position, subject_id in enumerate(tqdm(subject_ids, desc=progress_label), start=1):
        picked[subject_id] = retrieve_top_k_posts_max(
            subject_id,
            posts_df,
            concept_embeddings,
            sbert_model,
            k=HYPERPARAMS['k_posts'],
            batch_size=MEMORY_CONFIG['post_batch_size'],
            # Debug output only for the first DEBUG_N_SUBJECTS subjects
            debug=(DEBUG and position <= DEBUG_N_SUBJECTS),
        )
        # Clear GPU cache periodically
        if position % MEMORY_CONFIG['subject_cache_interval'] == 0:
            clear_gpu_cache()
    return picked


# Training subjects (80% of original training data)
print("  Processing training subjects (80% of train data)...")
train_subjects = train_posts_df_final['subject_id'].unique()
train_selected = _select_posts_for_split(train_posts_df_final, train_subjects, "Train subjects")

# Validation subjects (20% of original training data)
print("\n  Processing validation subjects (20% of train data)...")
val_subjects = val_posts_df['subject_id'].unique()
val_selected = _select_posts_for_split(val_posts_df, val_subjects, "Val subjects")

# Test subjects (100% of test folder)
print("\n  Processing test subjects (100% of test folder)...")
test_subjects = test_posts_df_final['subject_id'].unique()
test_selected = _select_posts_for_split(test_posts_df_final, test_subjects, "Test subjects")

# Final cache clear
clear_gpu_cache()

print(f"\n✓ Post retrieval complete in {time.time()-start_time:.1f}s ({(time.time()-start_time)/60:.1f} min)")
print(f"  Memory-optimized processing: {len(train_subjects) + len(val_subjects) + len(test_subjects)} subjects")

In [None]:
def encode_and_attention_pool_max(selected_posts_dict, sbert, concept_embs,
                                   normalize=True, debug=False, temperature=0.2):
    """
    Encode each subject's posts and pool them into one embedding per subject.

    Per subject:
    1. Blank posts are dropped and the rest are encoded with ``sbert``.
    2. Cosine similarity of every post to every concept is computed.
    3. A post's score is its MAX similarity over concepts; negative scores
       are clamped to 0.
    4. softmax(scores / temperature) gives the attention weights.
    5. The attention-weighted sum of the post embeddings is the pooled vector.

    Args:
        selected_posts_dict: Mapping of subject_id -> list of post texts.
        sbert: Encoder whose ``encode`` method embeds a list of posts.
        concept_embs: 2-D concept embeddings (concepts x dim). The second
            dimension is used as the width of the pooled output.
        normalize: Accepted for interface compatibility but currently has NO
            effect; the pooled vectors are returned as computed.
            NOTE(review): decide whether L2-normalisation was intended here.
        debug: Print attention statistics for the first DEBUG_N_SUBJECTS
            subjects.
        temperature: Softmax temperature; smaller values concentrate the
            weight on the highest-scoring posts. Must be positive.

    Returns:
        ``(X, subject_ids)`` where ``X`` has one row per subject, in the
        iteration order of ``selected_posts_dict``, and ``subject_ids`` is
        the list of keys in that same order. A subject with no usable posts
        gets an all-zero row. An empty mapping yields an array of shape
        ``(0, dim)``.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    subject_ids = list(selected_posts_dict.keys())
    # Post and concept embeddings must share a width for cos_sim to work, so
    # the concept width is used for zero rows instead of a hard-coded 384.
    embedding_dim = int(concept_embs.shape[1])
    pooled_embeddings = []

    def _zero_embedding():
        # float32 so a zero row does not upcast the whole stacked matrix
        return np.zeros(embedding_dim, dtype=np.float32)

    with torch.no_grad():  # Disable gradient tracking
        for idx, subject_id in enumerate(subject_ids):
            posts = selected_posts_dict[subject_id]

            # Handle empty posts
            if len(posts) == 0:
                print(f"WARNING: No posts for subject {subject_id}, using zero embedding")
                pooled_embeddings.append(_zero_embedding())
                continue

            # Filter out empty posts
            posts = [p for p in posts if p.strip()]
            if len(posts) == 0:
                print(f"WARNING: All posts empty for subject {subject_id}, using zero embedding")
                pooled_embeddings.append(_zero_embedding())
                continue

            # Encode posts
            post_embs = sbert.encode(
                posts,
                convert_to_tensor=True,
                show_progress_bar=False
            )

            if post_embs.shape[0] == 0 or post_embs.shape[1] == 0:
                print(f"WARNING: Empty embeddings for subject {subject_id}, using zero embedding")
                pooled_embeddings.append(_zero_embedding())
                continue

            # Compute similarity to concepts
            cos_scores = util.cos_sim(post_embs, concept_embs)

            # KEY: Take MAX instead of SUM ([0] gets values, not indices)
            post_scores = cos_scores.max(dim=1)[0]

            # Remove negative similarities
            post_scores = torch.clamp(post_scores, min=0.0)

            # Attention weights over this subject's posts
            attn_weights = torch.softmax(post_scores / temperature, dim=0)

            if debug and idx < DEBUG_N_SUBJECTS:
                print("\n" + "="*60)
                print(f"[DEBUG][ATTENTION] Subject: {subject_id}")
                attn_np = attn_weights.cpu().numpy()
                print("[DEBUG][ATTENTION] Weight stats:")
                print(f"  min={attn_np.min():.6f} "
                      f"max={attn_np.max():.6f} "
                      f"mean={attn_np.mean():.6f} "
                      f"entropy={-np.sum(attn_np * np.log(attn_np + 1e-12)):.4f}")

                top_attn_idx = np.argsort(-attn_np)[:DEBUG_TOP_N_POSTS]
                print(f"\n[DEBUG][ATTENTION] Top-{DEBUG_TOP_N_POSTS} attended posts:")
                for rank, i in enumerate(top_attn_idx):
                    print(f"\n  Rank {rank+1}")
                    print(f"  Attention: {attn_np[i]:.6f}")
                    print(f"  Text: {posts[i][:300]}")

            # Weighted sum pooling
            pooled = torch.sum(attn_weights.unsqueeze(1) * post_embs, dim=0)
            pooled_embeddings.append(pooled.cpu().numpy())

            # Clean up GPU memory
            del post_embs, cos_scores, attn_weights, pooled

    if not pooled_embeddings:
        # np.vstack rejects an empty list; return a well-formed empty matrix.
        return np.zeros((0, embedding_dim), dtype=np.float32), subject_ids

    return np.vstack(pooled_embeddings), subject_ids


print("✓ Memory-optimized attention pooling function defined (MAX-based)")

In [None]:
# Build the pooled embedding matrix for each split. The GPU cache is cleared
# after every split.
print("Encoding and pooling embeddings (MAX-based attention, memory-optimized)...")

start_time = time.time()

# Training: the only split that passes the DEBUG flag through
print("  Training set...")
X_train, train_subject_ids = encode_and_attention_pool_max(
    train_selected, sbert_model, concept_embeddings, normalize=True, debug=DEBUG
)
clear_gpu_cache()
print(f"    X_train shape: {X_train.shape}")

# Validation
print("  Validation set...")
X_val, val_subject_ids = encode_and_attention_pool_max(
    val_selected,
    sbert_model,
    concept_embeddings,
    normalize=True,
)
clear_gpu_cache()
print(f"    X_val shape: {X_val.shape}")

# Test
print("  Test set...")
X_test, test_subject_ids = encode_and_attention_pool_max(
    test_selected,
    sbert_model,
    concept_embeddings,
    normalize=True,
)
clear_gpu_cache()
print(f"    X_test shape: {X_test.shape}")

print(f"\n✓ Encoding complete in {time.time()-start_time:.1f}s ({(time.time()-start_time)/60:.1f} min)")

## Section 6: Build Concept Matrices and Labels

In [None]:
# Assemble the concept matrix C and label vector y for each split, with rows
# in the same order as the subject-id lists returned by the pooling step.
print("Building concept matrices and labels...")


def _concepts_and_labels(subject_ids, posts_df):
    """Return (C, y) as float32 arrays for subject_ids.

    The label is read from the subject's first row in posts_df. The concept
    row comes from the questionnaire table; a subject missing from it gets
    all zeros.
    """
    concept_rows = []
    labels = []
    for subject_id in subject_ids:
        subject_labels = posts_df.loc[posts_df['subject_id'] == subject_id, 'label']
        labels.append(subject_labels.iloc[0])

        questionnaire = concepts_df[concepts_df['subject_id'] == subject_id]
        if len(questionnaire) > 0:
            concept_rows.append(questionnaire[concept_cols].values[0])
        else:
            concept_rows.append(np.zeros(N_CONCEPTS))
    return (
        np.array(concept_rows, dtype=np.float32),
        np.array(labels, dtype=np.float32),
    )


# Training (80% of training data) and validation (20% of training data)
C_train, y_train = _concepts_and_labels(train_subject_ids, train_posts_df_final)
C_val, y_val = _concepts_and_labels(val_subject_ids, val_posts_df)

# Test: zeros for concepts (no ground truth available)
C_test = np.zeros((len(test_subject_ids), N_CONCEPTS), dtype=np.float32)
y_test = np.array(
    [
        test_posts_df_final.loc[test_posts_df_final['subject_id'] == subject_id, 'label'].iloc[0]
        for subject_id in test_subject_ids
    ],
    dtype=np.float32,
)

print("✓ Matrices built")
print(f"  Train: X={X_train.shape}, C={C_train.shape}, y={y_train.shape}")
print(f"  Val:   X={X_val.shape}, C={C_val.shape}, y={y_val.shape}")
print(f"  Test:  X={X_test.shape}, C={C_test.shape}, y={y_test.shape}")
print(f"\n  Training label distribution: {np.bincount(y_train.astype(int))}")
print(f"  Validation label distribution: {np.bincount(y_val.astype(int))}")
print(f"  Test label distribution: {np.bincount(y_test.astype(int))}")

## Section 7: Compute Class Weights

In [None]:
# Class balance of the training labels. pos_weight is the number of negative
# subjects per positive subject.
n_negative, n_positive = (int(np.sum(y_train == class_value)) for class_value in (0, 1))
pos_weight = n_negative / n_positive

print(f"Class imbalance:")
for _summary_line in (
    f"  Negative samples: {n_negative}",
    f"  Positive samples: {n_positive}",
    f"  Ratio: 1:{pos_weight:.2f}",
    f"  Computed pos_weight: {pos_weight:.4f}",
):
    print(_summary_line)

## Section 8: Save All Datasets

Save everything for fast loading by training pipelines

In [None]:
# Write each split to a compressed .npz archive and the class-balance
# numbers to a JSON file, all under SAVE_DIR.
print("Saving datasets...")

_split_archives = (
    ("train_data.npz", X_train, C_train, y_train, train_subject_ids),
    ("val_data.npz", X_val, C_val, y_val, val_subject_ids),
    ("test_data.npz", X_test, C_test, y_test, test_subject_ids),
)

# Every archive holds the same four arrays: X, C, y and subject_ids
for _file_name, _X, _C, _y, _ids in _split_archives:
    np.savez_compressed(
        os.path.join(SAVE_DIR, _file_name),
        X=_X,
        C=_C,
        y=_y,
        subject_ids=np.array(_ids),
    )

# Class weights info
class_info = dict(
    n_positive=n_positive,
    n_negative=n_negative,
    pos_weight=float(pos_weight),
)

with open(os.path.join(SAVE_DIR, "class_weights.json"), 'w') as f:
    json.dump(class_info, f, indent=4)

print(f"✓ Datasets saved to {SAVE_DIR}")
print(f"  train_data.npz: {X_train.shape[0]} samples")
print(f"  val_data.npz:   {X_val.shape[0]} samples")
print(f"  test_data.npz:  {X_test.shape[0]} samples")
print(f"  class_weights.json")

## Section 9: Cleanup

In [None]:
# Remove the temporary directory that held the extracted test chunks.
# This is best-effort: a failure is reported but does not stop the notebook.
try:
    shutil.rmtree(temp_dir)
except Exception as e:
    print(f"⚠ Failed to clean up temporary directory: {e}")
else:
    print(f"✓ Cleaned up temporary directory: {temp_dir}")

In [None]:
# Closing banner: where the outputs were written and how this pipeline
# differs from the sum-based one.
_banner = "=" * 70

print("\n" + _banner)
print("      MAX ALTERNATIVE DATASET PREPARATION COMPLETE")
print(_banner)

print("\nSaved files:")
for _saved_name in ("train_data.npz", "val_data.npz", "test_data.npz", "class_weights.json"):
    print(f"  {SAVE_DIR}/{_saved_name}")

for _summary_line in (
    "\nData split strategy:",
    "  - Training: 80% of train folder (~389 subjects)",
    "  - Validation: 20% of train folder (~97 subjects)",
    "  - Test: 100% of test folder (401 subjects)",
    "\nKey difference from original:",
    "  - Uses MAX of concept similarities instead of SUM",
    "  - Captures posts highly relevant to at least ONE concept",
    "  - Focuses on 'specialist' posts rather than 'generalist' posts",
    "\nUse this data with CEM/CBM training notebooks!",
):
    print(_summary_line)

print(_banner)