In [None]:
# ============================================================================
# 01_DATA_PREPARATION.IPYNB
# ============================================================================
# PURPOSE: Load, explore, and split dataset into train/val/test sets
# TIME ESTIMATE: 10-15 minutes (CPU only, no GPU needed)
# ============================================================================

# Multiclass Classification - Data Preparation (Stage 2 Input)

## Objective
Prepare the news article dataset for multiclass classification (SAFE, SENSITIVE, UNSAFE):
- Create stratified train/val/test splits
- Output files will be used as input for Stage 2 multiclass classifier training

## Strategy
- Use `summary_long_500` (400 words) as input text
- 70/15/15 stratified split by label AND category
- Save train/val/test datasets as stage2_input files

## 1. Setup and Imports

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from collections import Counter
import os
import json

# Reproducibility: one seed shared by numpy's global RNG and every sklearn split below
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

# Show every column; cap long cell text at 100 characters
for option_name, option_value in {'display.max_columns': None, 'display.max_colwidth': 100}.items():
    pd.set_option(option_name, option_value)

# The working directory is assumed to be one level below the project root
notebook_dir = os.getcwd()
PROJECT_ROOT = os.path.abspath(os.path.join(notebook_dir, os.pardir))

# Figures produced by this notebook are written here
RESULTS_DIR = os.path.join(PROJECT_ROOT, 'results', 'data_preparation')
os.makedirs(RESULTS_DIR, exist_ok=True)

for setup_line in (
    f"Project root: {PROJECT_ROOT}",
    f"Results directory: {RESULTS_DIR}",
    "✓ Setup complete",
):
    print(setup_line)

## 2. Load and Explore Dataset

**Rationale:** Understanding data distribution is critical for:
- Identifying class imbalance (UNSAFE is the minority class at ~19.7%)
- Ensuring all categories are represented
- Detecting any data quality issues

In [None]:
# Raw Stage 2 input; every later cell derives its data from this frame
DATA_PATH = os.path.join(PROJECT_ROOT, 'data', 'raw', 'stage2_multiclass_classification_input.csv')

df = pd.read_csv(DATA_PATH)

separator = "=" * 80
print(f"{separator}\nDATASET OVERVIEW\n{separator}")
print(f"\nShape: {df.shape}")
print(f"\nColumns: {list(df.columns)}")
print(f"\nData types:\n{df.dtypes}")
print(f"\nMissing values:\n{df.isna().sum()}")

In [None]:
rule = "=" * 80

# Peek at a few rows of the columns this notebook actually uses
print(f"\n{rule}\nSAMPLE DATA\n{rule}")
print(df[['category', 'label', 'summary_long_500']].head(3))

# Temporary length columns (dropped again before splitting)
summary_text = df['summary_long_500']
df['text_length_words'] = summary_text.str.split().str.len()
df['text_length_chars'] = summary_text.str.len()

word_lengths = df['text_length_words']
print(f"\n{rule}\nTEXT LENGTH STATISTICS (summary_long_500)\n{rule}")
print(f"Average words: {word_lengths.mean():.0f}")
print(f"Min words: {word_lengths.min()}")
print(f"Max words: {word_lengths.max()}")
print(f"Std words: {word_lengths.std():.0f}")

## 3. Analyze Label Distribution

**Key Insight:** UNSAFE is the minority class (~19.7%)
- This imbalance requires special handling (class weights, focal loss)
- Our goal: 90%+ UNSAFE recall (catch 90%+ of dangerous content)
- Trade-off: Accept lower precision (more false alarms)

In [None]:
print("\n" + "="*80)
print("LABEL DISTRIBUTION")
print("="*80)

label_counts = df['label'].value_counts()
label_percentages = df['label'].value_counts(normalize=True) * 100

label_stats = pd.DataFrame({
    'Count': label_counts,
    'Percentage': label_percentages
})
print(label_stats)

# Colors are keyed by label name. value_counts() orders by frequency, so a
# positional color list would recolor labels whenever their ranking changes.
LABEL_COLORS = {'SAFE': 'green', 'SENSITIVE': 'orange', 'UNSAFE': 'red'}
plot_labels = [lbl for lbl in LABEL_COLORS if lbl in label_counts.index]
plot_labels += [lbl for lbl in label_counts.index if lbl not in LABEL_COLORS]
plot_colors = [LABEL_COLORS.get(lbl, 'gray') for lbl in plot_labels]
plot_counts = label_counts.reindex(plot_labels)
plot_percentages = label_percentages.reindex(plot_labels)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Count plot; value labels are offset by 1% of the tallest bar so they scale with the data
axes[0].bar(plot_labels, plot_counts.values, color=plot_colors)
axes[0].set_title('Label Distribution (Counts)', fontsize=14, fontweight='bold')
axes[0].set_ylabel('Count')
axes[0].set_xlabel('Label')
count_offset = plot_counts.max() * 0.01
for i, v in enumerate(plot_counts.values):
    axes[0].text(i, v + count_offset, str(v), ha='center', fontweight='bold')

# Percentage plot
axes[1].bar(plot_labels, plot_percentages.values, color=plot_colors)
axes[1].set_title('Label Distribution (Percentage)', fontsize=14, fontweight='bold')
axes[1].set_ylabel('Percentage (%)')
axes[1].set_xlabel('Label')
for i, v in enumerate(plot_percentages.values):
    axes[1].text(i, v + 1, f'{v:.1f}%', ha='center', fontweight='bold')

fig.tight_layout()

label_plot_path = os.path.join(RESULTS_DIR, 'label_distribution.png')
fig.savefig(label_plot_path, dpi=300, bbox_inches='tight')
plt.show()
print(f"✓ Label distribution saved to: {label_plot_path}")

# Calculate imbalance ratio
unsafe_count = label_counts['UNSAFE']
safe_count = label_counts['SAFE']
imbalance_ratio = safe_count / unsafe_count
print(f"\nImbalance ratio (SAFE:UNSAFE): {imbalance_ratio:.2f}:1")
print("→ This is MODERATE imbalance (not extreme)")
print("→ Weighted loss + focal loss should handle this well")

## 4. Analyze Category Distribution

**Rationale:** Ensure stratified split maintains category balance
- Different categories may have different UNSAFE patterns
- Want representative samples from each category in train/val/test

In [None]:
print("\n" + "="*80)
print("CATEGORY DISTRIBUTION")
print("="*80)

category_counts = df['category'].value_counts()
print(category_counts)

# Cross-tabulation: Label x Category
print("\n" + "="*80)
print("LABEL × CATEGORY CROSS-TABULATION")
print("="*80)
cross_tab = pd.crosstab(df['category'], df['label'], margins=True)
print(cross_tab)

# Visualize. DataFrame.plot() opens its own figure unless `ax` is passed, so the
# axes are created explicitly: figsize then applies and no empty figure is left over.
LABEL_COLORS = {'SAFE': 'green', 'SENSITIVE': 'orange', 'UNSAFE': 'red'}
cross_tab_pct = pd.crosstab(df['category'], df['label'], normalize='index') * 100
fig, ax = plt.subplots(figsize=(12, 6))
cross_tab_pct.plot(
    kind='bar',
    stacked=False,
    ax=ax,
    # color by label name rather than by column position
    color=[LABEL_COLORS.get(lbl, 'gray') for lbl in cross_tab_pct.columns],
)
ax.set_title('Label Distribution by Category', fontsize=14, fontweight='bold')
ax.set_ylabel('Percentage (%)')
ax.set_xlabel('Category')
ax.legend(title='Label')
ax.tick_params(axis='x', labelrotation=45)
fig.tight_layout()
category_plot_path = os.path.join(RESULTS_DIR, 'category_label_distribution.png')
fig.savefig(category_plot_path, dpi=300, bbox_inches='tight')
plt.show()
print(f"✓ Category-label distribution saved to: {category_plot_path}")

# Summaries are computed from the data so they cannot drift from the dataset
print("\n→ Categories have varying UNSAFE percentages")
if 'UNSAFE' in cross_tab_pct.columns:
    unsafe_by_category = cross_tab_pct['UNSAFE']
    print(f"→ UNSAFE share ranges from {unsafe_by_category.min():.1f}% ({unsafe_by_category.idxmin()}) "
          f"to {unsafe_by_category.max():.1f}% ({unsafe_by_category.idxmax()})")
print(f"→ {category_counts.idxmax()} has highest sample count ({category_counts.max()})")
print(f"→ {category_counts.idxmin()} has lowest sample count ({category_counts.min()})")
print("→ Stratified split will maintain these proportions")

## 5. Prepare Data for Splitting

**Strategy:** 
- Select only needed columns: text, label, category
- Create combined stratification key (label_category)
- This ensures splits maintain both label AND category proportions

In [None]:
# Keep only the model inputs and give the text column a generic name
df_clean = df[['summary_long_500', 'label', 'category']].rename(
    columns={'summary_long_500': 'text'}
).copy()

# Remove temporary length columns from original df (added during exploration)
df = df.drop(['text_length_words', 'text_length_chars'], axis=1, errors='ignore')

print(f"\nCleaned dataset shape: {df_clean.shape}")
print(f"Columns: {df_clean.columns.tolist()}")

# Fail fast on missing values before they can propagate into the stratification key
assert df_clean.isnull().sum().sum() == 0, "ERROR: Missing values detected!"
print("\n✓ No missing values")

# Create stratification column
# This ensures splits maintain proportions across BOTH label AND category
df_clean['strat_key'] = df_clean['label'].astype(str) + '_' + df_clean['category'].astype(str)

n_labels = df_clean['label'].nunique()
n_categories = df_clean['category'].nunique()
print(f"\nUnique stratification keys: {df_clean['strat_key'].nunique()}")
print(f"(At most {n_labels * n_categories}: {n_labels} labels × {n_categories} categories)")
# train_test_split needs >= 2 samples per stratum, and the 30% holdout is split again
print(f"Smallest stratification group: {df_clean['strat_key'].value_counts().min()} samples")

# Identical texts that land in different splits would leak evaluation content into training
n_duplicate_texts = int(df_clean['text'].duplicated().sum())
if n_duplicate_texts:
    print(f"⚠ {n_duplicate_texts} duplicate text(s) found — they may end up in different splits (leakage)")
else:
    print("✓ No duplicate texts")

## 6. Stratified Train/Val/Test Split

**Split Strategy:**
- 70% Train (6,925 samples) - For model training
- 15% Validation (1,484 samples) - For hyperparameter tuning, threshold selection
- 15% Test (1,484 samples) - For final evaluation (NEVER touch until end)

**Why Stratified?**
- Maintains ~19.7% UNSAFE in all splits
- Maintains category proportions in all splits
- Prevents train/val/test from having different distributions

In [None]:
print(f"\n{'=' * 80}\nPERFORMING STRATIFIED SPLIT\n{'=' * 80}")

# Step 1: 70% train / 30% holdout, stratified on the label+category key
train_df, temp_df = train_test_split(
    df_clean, test_size=0.30, random_state=RANDOM_SEED, stratify=df_clean['strat_key']
)

# Step 2: halve the holdout, giving 15% validation and 15% test overall
val_df, test_df = train_test_split(
    temp_df, test_size=0.50, random_state=RANDOM_SEED, stratify=temp_df['strat_key']
)

# The key only served the split; drop it and renumber rows from 0
train_df, val_df, test_df = (
    part.drop(columns='strat_key').reset_index(drop=True)
    for part in (train_df, val_df, test_df)
)

n_total = len(df_clean)
for split_label, split_part in (("Train", train_df), ("Val", val_df), ("Test", test_df)):
    print(f"{split_label} size: {len(split_part)} ({len(split_part)/n_total*100:.1f}%)")

## 7. Verify Split Quality

**Critical Check:** Ensure the UNSAFE percentage is consistent across splits
- If the splits have very different UNSAFE percentages, the split failed
- All should be ~19.7% and agree to within ~1.5 percentage points (the tolerance used in the check below)

In [None]:
print(f"\n{'=' * 80}\nSPLIT VERIFICATION\n{'=' * 80}")

def analyze_split(df, split_name):
    """Print the label and category breakdown of one split.

    Args:
        df: the split, with 'label' and 'category' columns.
        split_name: display name; shown upper-cased in the header.

    Returns:
        Percentage (0-100) of rows labelled 'UNSAFE', or 0 when the label is absent.
    """
    print(f"\n{split_name.upper()} SET:")
    print(f"  Total samples: {len(df)}")

    counts_by_label = df['label'].value_counts()
    share_by_label = df['label'].value_counts(normalize=True) * 100

    # Fixed label order so every split prints the same rows, even with a missing class
    print("\n  Label distribution:")
    for label_name in ('SAFE', 'SENSITIVE', 'UNSAFE'):
        print(f"    {label_name}: {counts_by_label.get(label_name, 0)} "
              f"({share_by_label.get(label_name, 0):.1f}%)")

    print("\n  Category distribution:")
    for category_name, category_count in df['category'].value_counts().items():
        print(f"    {category_name}: {category_count}")

    return share_by_label.get('UNSAFE', 0)

# Analyze each split
train_unsafe_pct = analyze_split(train_df, 'train')
val_unsafe_pct = analyze_split(val_df, 'validation')
test_unsafe_pct = analyze_split(test_df, 'test')

# Maximum allowed spread (percentage points) of UNSAFE share across all splits
UNSAFE_PCT_TOLERANCE = 1.5

# Check consistency
print("\n" + "="*80)
print("UNSAFE PERCENTAGE CONSISTENCY CHECK")
print("="*80)
print(f"Train: {train_unsafe_pct:.2f}%")
print(f"Val:   {val_unsafe_pct:.2f}%")
print(f"Test:  {test_unsafe_pct:.2f}%")

# Use the max-min spread over all three splits, so a val-vs-test gap is caught
# too (comparing each split only against train would let it through).
split_unsafe_pcts = (train_unsafe_pct, val_unsafe_pct, test_unsafe_pct)
unsafe_pct_spread = max(split_unsafe_pcts) - min(split_unsafe_pcts)
print(f"Difference: {unsafe_pct_spread:.2f} percentage points")

if unsafe_pct_spread < UNSAFE_PCT_TOLERANCE:
    print("\n✓ SPLIT QUALITY: EXCELLENT")
    print("  All splits have similar UNSAFE percentages")
else:
    print("\n⚠ SPLIT QUALITY: CHECK NEEDED")
    print(f"  Splits differ by {unsafe_pct_spread:.2f} pp (tolerance {UNSAFE_PCT_TOLERANCE} pp)")

## 8. Ready to Save Datasets

**Next Step:** Save train/val/test splits as Stage 2 input files for multiclass classifier training.

In [None]:
rule = "=" * 80
print(f"\n{rule}\nPREPARING STAGE 2 INPUT DATASETS\n{rule}")

# Preview the file names written by the save cell further down
print("\nDatasets will be saved as:")
for split_prefix in ("train", "val", "test"):
    print(f"  - {split_prefix}_stage2_input.csv")

print("\nThese files contain 3-class labels (SAFE, SENSITIVE, UNSAFE)\n"
      "and will be used for multiclass classifier training.")

## 9. Calculate Class Weights (Optional)

**Purpose:** Compute class weights for handling imbalanced classes during training.

In [None]:
rule = "=" * 80
print(f"\n{rule}\nCALCULATING CLASS WEIGHTS\n{rule}")

# 'balanced' weights are n_samples / (n_classes * class_count), computed on TRAIN only.
# np.unique returns the classes sorted, so the weights follow that (alphabetical) order.
train_labels = train_df['label'].values
label_classes = np.unique(train_labels)
class_weights = compute_class_weight(
    class_weight='balanced', classes=label_classes, y=train_labels
)

class_weight_dict = dict(zip(label_classes, class_weights))

print("\nClass Weights (for multiclass classification):")
for cls, weight in class_weight_dict.items():
    print(f"  {cls}: {weight:.2f}")

print("\n→ These weights can be used during model training to handle class imbalance")

## 10. Save Datasets as Stage 2 Input Files

**Purpose:** Save train/val/test splits as stage2_input files for multiclass classifier training.

In [None]:
print("\n" + "="*80)
print("SAVING STAGE 2 INPUT DATASETS")
print("="*80)

# data/processed/ is not created anywhere else in this notebook; ensure it exists
PROCESSED_DIR = os.path.join(PROJECT_ROOT, 'data', 'processed')
os.makedirs(PROCESSED_DIR, exist_ok=True)

# Save datasets as stage2_input files
for split_name, split_df in (('train', train_df), ('val', val_df), ('test', test_df)):
    split_filename = f"{split_name}_stage2_input.csv"
    split_df.to_csv(os.path.join(PROCESSED_DIR, split_filename), index=False)
    print(f"✓ Saved {split_filename}")

# Save metadata. Values are cast to built-in str/int/float because json.dump
# cannot serialize numpy scalar types such as np.int64.
metadata = {
    'total_samples': len(df_clean),
    'train_samples': len(train_df),
    'val_samples': len(val_df),
    'test_samples': len(test_df),
    'label_distribution': {str(k): int(v) for k, v in df_clean['label'].value_counts().items()},
    'category_distribution': {str(k): int(v) for k, v in df_clean['category'].value_counts().items()},
    'class_weights': {str(k): float(v) for k, v in class_weight_dict.items()},
    'random_seed': RANDOM_SEED
}

with open(os.path.join(PROCESSED_DIR, 'metadata.json'), 'w', encoding='utf-8') as f:
    json.dump(metadata, f, indent=2)
print("✓ Saved metadata.json")

## 11. Create Keyword Filter List (Optional)

**Purpose:** Catch obvious UNSAFE content without running the model
- Expected to add ~5-7% to UNSAFE recall (estimate — confirm on the validation set)
- Reduces model load
- Acts as a first line of defense

**Strategy:**
- Extract high-frequency words from UNSAFE articles
- Manually curate known danger signals
- Use for preprocessing in inference pipeline

In [None]:
print("\n" + "="*80)
print("CREATING UNSAFE KEYWORD FILTER")
print("="*80)

# TODO: optionally augment the curated list below with data-driven terms
# (e.g. word frequency or TF-IDF over train_df[train_df['label'] == 'UNSAFE']['text']).
# Derive such terms from the TRAIN split only so val/test stay untouched.

# Manual high-confidence UNSAFE keywords (curated list).
# NOTE(review): broad terms such as 'death', 'injury', 'victim' or 'shot' are common
# in ordinary news reporting and may produce many false positives — review.
manual_keywords = [
    'rape', 'raped', 'sexual assault', 'sexually assaulted', 'molest', 'molestation',
    'pedophile', 'child abuse', 'abused', 'violence', 'violent', 'murder', 'killed',
    'death', 'weapon', 'gun', 'shooting', 'shot', 'blood', 'injury', 'injured',
    'assault', 'attacked', 'victim', 'predator', 'harassment', 'harassed',
    'explicit', 'pornography', 'nude', 'naked', 'inappropriate', 'misconduct', 'suicide', 'assassination'
]

# De-duplicate and sort so the saved file is stable across edits to the list
unique_keywords = sorted(set(manual_keywords))

print(f"Manual high-confidence keywords: {len(manual_keywords)}")
print(f"Unique keywords written: {len(unique_keywords)}")

# Preview
print("\nSample keywords:")
print(manual_keywords[:20])

# Save to file; data/keywords/ is not created anywhere else in this notebook
KEYWORDS_DIR = os.path.join(PROJECT_ROOT, 'data', 'keywords')
os.makedirs(KEYWORDS_DIR, exist_ok=True)
with open(os.path.join(KEYWORDS_DIR, 'unsafe_keywords.txt'), 'w', encoding='utf-8') as f:
    f.writelines(f"{keyword}\n" for keyword in unique_keywords)

print("\n✓ Saved unsafe_keywords.txt")
print("\nNOTE: Review and refine this list manually!")
print("Add domain-specific terms based on news categories")

## 12. Summary and Next Steps

**What we accomplished:**
✓ Loaded and explored news articles from stage2_multiclass_classification_input.csv
✓ Created stratified 70/15/15 train/val/test split
✓ Saved datasets as stage2_input files (train_stage2_input.csv, val_stage2_input.csv, test_stage2_input.csv)
✓ Calculated class weights for imbalance handling
✓ Created keyword filter for preprocessing (optional)

**Next steps:**
1. Train multiclass classifier model → Notebook 02 (02_multiclass_classifier_training.ipynb)
2. Train kid-safe rewriter model → Notebook 03 (03_kid_safe_rewriter_training.ipynb)

In [None]:
rule = "=" * 80
print(f"\n{rule}\nDATA PREPARATION COMPLETE!\n{rule}")

# Sizes come from the in-memory splits, not from the planning numbers in the markdown
train_size, val_size, test_size = len(train_df), len(val_df), len(test_df)

print("\nGenerated files:")
print("  data/processed/")
for split_prefix, split_count in (("train", train_size), ("val", val_size), ("test", test_size)):
    print(f"    ├── {split_prefix}_stage2_input.csv ({split_count:,} samples)")
print("    └── metadata.json")
print("  data/keywords/")
print("    └── unsafe_keywords.txt (optional)")

# Display split breakdown
total_size = len(df_clean)
print(f"\n{rule}\nDATASET SPLIT SUMMARY\n{rule}")
print(f"\nTotal samples: {total_size:,}")
for split_tag, split_count in (("Train:", train_size), ("Val:  ", val_size), ("Test: ", test_size)):
    print(f"  {split_tag} {split_count:,} ({split_count/total_size*100:.1f}%)")

# Label distribution across splits
print("\nLabel distribution across splits:")
for split_name, split_df in [('Train', train_df), ('Val', val_df), ('Test', test_df)]:
    n_rows = len(split_df)
    share = {lbl: (split_df['label'] == lbl).sum() / n_rows * 100
             for lbl in ('UNSAFE', 'SENSITIVE', 'SAFE')}
    print(f"  {split_name:5s}: UNSAFE={share['UNSAFE']:5.1f}%  "
          f"SENSITIVE={share['SENSITIVE']:5.1f}%  SAFE={share['SAFE']:5.1f}%")

print(f"\n{rule}\nREADY FOR MULTICLASS CLASSIFIER TRAINING!\n{rule}")
print("\nNext: Open 02_multiclass_classifier_training.ipynb")
print("\nThese files (train/val/test_stage2_input.csv) will be used as input for Stage 2 multiclass classifier training.")