# Mental Health Data Cleaning and Preprocessing

This notebook handles:
- Data cleaning and validation
- Missing value treatment
- Outlier detection and handling
- Feature engineering
- Data standardization

## 1. Setup and Imports

In [None]:
import sys
# Make the project's ../src directory importable so that
# `data.preprocessing` below resolves (relative to the notebook's CWD).
sys.path.append('../src')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings
# NOTE(review): this silences *every* warning for the rest of the session
# (pandas FutureWarning / SettingWithCopyWarning included); consider a
# narrower filter when debugging the pipeline.
warnings.filterwarnings('ignore')

# Import our preprocessing module (project-local; found via the ../src path
# added above).
from data.preprocessing import DataPreprocessor

print("✓ All libraries imported successfully")

## 2. Load Raw Data

In [None]:
# Create the preprocessor rooted at the shared data directory.
preprocessor = DataPreprocessor(data_dir="../data")

# Load the raw dataset. On any failure, report it and fall back to None so
# the following cells can detect the missing data and skip themselves.
try:
    raw_data = preprocessor.load_raw_data()
    n_rows, n_cols = raw_data.shape
    print(f"✓ Raw data loaded successfully: {raw_data.shape}")
    print(f"  Rows: {n_rows:,}")
    print(f"  Columns: {n_cols}")

    print("\nColumn names:")
    for position, column_name in enumerate(raw_data.columns, start=1):
        print(f"{position:2d}. {column_name}")
except Exception as e:
    print(f"❌ Error loading data: {e}")
    raw_data = None

## 3. Data Cleaning

In [None]:
if raw_data is not None:
    print("=== BEFORE CLEANING ===")
    print(f"Shape: {raw_data.shape}")
    print(f"Missing values: {raw_data.isnull().sum().sum()}")

    # Apply the preprocessor's cleaning step.
    cleaned_data = preprocessor.clean_mental_health_data(raw_data)

    print("\n=== AFTER CLEANING ===")
    print(f"Shape: {cleaned_data.shape}")
    print(f"Missing values: {cleaned_data.isnull().sum().sum()}")

    # Summarise how much cleaning changed the data. The percentage is guarded
    # so an empty raw dataset does not raise ZeroDivisionError.
    raw_rows = raw_data.shape[0]
    rows_removed = raw_rows - cleaned_data.shape[0]
    removed_pct = rows_removed / raw_rows * 100 if raw_rows else 0.0
    print("\n📊 Data cleaning summary:")
    print(f"  • Rows removed: {rows_removed:,} ({removed_pct:.1f}%)")
    print(f"  • Columns standardized: {len(cleaned_data.columns)}")

    # `display` is injected by IPython/Jupyter; this cell assumes a notebook kernel.
    display(cleaned_data.head())
else:
    print("❌ No raw data available for cleaning")

## 4. Missing Values Analysis

In [None]:
if 'cleaned_data' in locals():
    print("=== MISSING VALUES ANALYSIS ===")

    # Per-column missing counts and the share of rows they represent.
    missing_info = cleaned_data.isnull().sum()
    missing_pct = (missing_info / len(cleaned_data)) * 100

    missing_df = pd.DataFrame({
        'Missing Count': missing_info,
        'Missing Percentage': missing_pct
    }).sort_values('Missing Count', ascending=False)

    # Only columns that actually have gaps are shown and plotted.
    missing_cols = missing_df[missing_df['Missing Count'] > 0]

    print("Missing values by column:")
    display(missing_cols)

    if not missing_cols.empty:
        fig, axes = plt.subplots(1, 2, figsize=(15, 6))

        # Left panel: percentage of missing values per affected column.
        axes[0].barh(range(len(missing_cols)), missing_cols['Missing Percentage'])
        axes[0].set_yticks(range(len(missing_cols)))
        axes[0].set_yticklabels(missing_cols.index)
        axes[0].set_xlabel('Missing Percentage (%)')
        axes[0].set_title('Missing Data by Column')

        # Right panel: row-level missingness pattern on a sample of rows.
        # A fixed random_state keeps the figure reproducible across re-runs.
        sample_size = min(1000, len(cleaned_data))
        sample_data = cleaned_data.sample(sample_size, random_state=42)

        missing_matrix = sample_data.isnull().astype(int)
        if missing_matrix.sum().sum() > 0:
            sns.heatmap(missing_matrix.T, cbar=True, ax=axes[1],
                        cmap='viridis', yticklabels=True)
            axes[1].set_title(f'Missing Data Pattern (Sample of {sample_size} rows)')
            axes[1].set_xlabel('Records')
        else:
            # The sampled rows contain no gaps; say so rather than leaving an
            # empty, unlabeled panel next to the bar chart.
            axes[1].text(0.5, 0.5, 'No missing values in sampled rows',
                         ha='center', va='center', transform=axes[1].transAxes)
            axes[1].set_axis_off()

        plt.tight_layout()
        plt.show()
    else:
        print("✅ No missing values found after cleaning!")

## 5. Outlier Detection

In [None]:
if 'cleaned_data' in locals():
    print("=== OUTLIER DETECTION ===")

    # Any column whose name mentions "prevalence" is treated as a metric.
    prevalence_cols = [name for name in cleaned_data.columns if 'prevalence' in name]

    if not prevalence_cols:
        print("No prevalence columns found for outlier detection")
    else:
        n_metrics = len(prevalence_cols)
        # squeeze=False keeps `axes` 2-D (plot row x metric) even for one metric.
        fig, axes = plt.subplots(2, n_metrics, figsize=(4 * n_metrics, 10), squeeze=False)

        outlier_summary = {}

        for idx, metric in enumerate(prevalence_cols):
            values = cleaned_data[metric].dropna()
            if values.empty:
                # Nothing to plot or measure for an all-missing column.
                continue

            pretty_name = metric.replace("_", " ").title()
            box_ax = axes[0, idx]
            hist_ax = axes[1, idx]

            # Top row: box plot.
            box_ax.boxplot(values)
            box_ax.set_title(f'{pretty_name}\nBox Plot')
            box_ax.set_ylabel('Prevalence (%)')

            # Bottom row: histogram.
            hist_ax.hist(values, bins=30, alpha=0.7, edgecolor='black')
            hist_ax.set_title(f'{pretty_name}\nDistribution')
            hist_ax.set_xlabel('Prevalence (%)')
            hist_ax.set_ylabel('Frequency')

            # IQR fences: values more than 1.5 * IQR outside the quartiles.
            q1, q3 = values.quantile([0.25, 0.75])
            spread = q3 - q1
            low_fence = q1 - 1.5 * spread
            high_fence = q3 + 1.5 * spread

            flagged = values[~values.between(low_fence, high_fence)]
            outlier_summary[metric] = {
                'count': len(flagged),
                'percentage': len(flagged) / len(values) * 100,
                'lower_bound': low_fence,
                'upper_bound': high_fence,
            }

        plt.tight_layout()
        plt.show()

        # Text summary of the fences and how many points fall outside them.
        print("\nOutlier Detection Summary (IQR method):")
        for metric, stats in outlier_summary.items():
            print(f"\n{metric}:")
            print(f"  • Outliers found: {stats['count']} ({stats['percentage']:.1f}%)")
            print(f"  • Valid range: {stats['lower_bound']:.2f} - {stats['upper_bound']:.2f}")

## 6. Feature Engineering

In [None]:
if 'cleaned_data' in locals():
    print("=== FEATURE ENGINEERING ===")

    # Let the preprocessor append its derived columns, then diff the column
    # sets to see what was added.
    enhanced_data = preprocessor.add_derived_features(cleaned_data)
    original_columns = set(cleaned_data.columns)
    new_features = set(enhanced_data.columns) - original_columns

    print(f"Original columns: {len(cleaned_data.columns)}")
    print(f"Enhanced columns: {len(enhanced_data.columns)}")
    print(f"New features added: {len(new_features)}")

    print("\nNew features:")
    for rank, feature_name in enumerate(sorted(new_features), start=1):
        print(f"{rank:2d}. {feature_name}")

    # Preview of the enriched frame.
    print("\nSample of enhanced data:")
    display(enhanced_data.head())

    # Regional breakdown, only when the preprocessor produced a region column.
    if 'region' in enhanced_data.columns:
        print("\nRegional distribution:")
        region_counts = enhanced_data['region'].value_counts()
        for region_name, n_records in region_counts.items():
            print(f"  {region_name}: {n_records} records")

        plt.figure(figsize=(10, 6))
        region_counts.plot.bar()
        plt.title('Data Distribution by Region')
        plt.xlabel('Region')
        plt.ylabel('Number of Records')
        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.show()

## 7. Load and Merge External Data

In [None]:
if 'enhanced_data' in locals():
    print("=== LOADING EXTERNAL DATA ===")

    # Mapping of dataset name -> DataFrame (iterated with .items() below).
    external_data = preprocessor.load_external_data()

    print(f"External datasets found: {len(external_data)}")
    for name, df in external_data.items():
        print(f"  • {name}: {df.shape}")

    if external_data:
        # Merge the external datasets into the main dataset.
        merged_data = preprocessor.merge_datasets(enhanced_data, external_data)

        print(f"\nData before merge: {enhanced_data.shape}")
        print(f"Data after merge: {merged_data.shape}")

        new_cols = set(merged_data.columns) - set(enhanced_data.columns)
        # Sorted once so the listing and the preview share a stable order
        # (set iteration order is arbitrary).
        ordered_new_cols = sorted(new_cols)
        if new_cols:
            print("\nNew columns from external data:")
            for col in ordered_new_cols:
                print(f"  • {col}")

        # Preview identifiers + newly merged columns. Only key columns that
        # actually exist are selected, so the preview cannot raise KeyError.
        key_cols = [c for c in ('country', 'year') if c in merged_data.columns]
        print("\nSample of merged data:")
        display(merged_data[key_cols + ordered_new_cols].head())
    else:
        print("No external data found, proceeding with enhanced data only")
        merged_data = enhanced_data

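## 8. Complete Data Processing Pipeline

This cell runs the preprocessor's end-to-end pipeline (`process_all_data()`) independently of the step-by-step cells above. The resulting `final_data` is what the validation and overview sections below operate on.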

In [None]:
print("=== RUNNING COMPLETE PROCESSING PIPELINE ===")

try:
    # Run the complete processing pipeline in one call.
    final_data = preprocessor.process_all_data()

    print("✅ Processing completed successfully!")
    print(f"Final dataset shape: {final_data.shape}")

    # Show final data summary
    print("\n=== FINAL DATA SUMMARY ===")
    print(f"Total records: {len(final_data):,}")
    print(f"Total columns: {len(final_data.columns)}")
    print(f"Countries: {final_data['country'].nunique()}")
    print(f"Years: {final_data['year'].nunique()}")
    print(f"Year range: {final_data['year'].min()} - {final_data['year'].max()}")

    # Missing values in final data. Guard against an empty frame: size == 0
    # would raise ZeroDivisionError and be misreported below as a failure.
    missing_final = final_data.isnull().sum().sum()
    missing_share = missing_final / final_data.size * 100 if final_data.size else 0.0
    print(f"Missing values: {missing_final} ({missing_share:.2f}%)")

    print("\n=== COLUMN CATEGORIES ===")

    # Keyword-based grouping. The groups are NOT mutually exclusive: a column
    # whose name contains e.g. both "prevalence" and "change" is listed under
    # both headings. "Other" holds whatever matched none of the groups.
    id_cols = ['country', 'year']
    prevalence_cols = [col for col in final_data.columns if 'prevalence' in col]
    derived_cols = [col for col in final_data.columns
                    if any(keyword in col for keyword in ['change', 'decade', 'since', 'total'])]
    external_cols = [col for col in final_data.columns
                     if any(keyword in col.lower() for keyword in ['population', 'gdp'])]
    # A set gives O(1) membership instead of scanning the concatenated lists.
    categorized = set(id_cols).union(prevalence_cols, derived_cols, external_cols)
    other_cols = [col for col in final_data.columns if col not in categorized]

    print(f"Identifier columns ({len(id_cols)}): {id_cols}")
    print(f"Prevalence metrics ({len(prevalence_cols)}): {prevalence_cols}")
    print(f"Derived features ({len(derived_cols)}): {derived_cols}")
    print(f"External data ({len(external_cols)}): {external_cols}")
    print(f"Other columns ({len(other_cols)}): {other_cols}")

except Exception as e:
    # Notebook-level boundary: report the failure with its traceback rather
    # than aborting the kernel run.
    print(f"❌ Processing failed: {e}")
    import traceback
    traceback.print_exc()

## 9. Data Validation

In [None]:
if 'final_data' in locals():
    print("=== DATA VALIDATION ===")

    # Outcome of each check, keyed by check name, so the results can be
    # inspected programmatically after this cell has run.
    validation_results = {}

    # 1. Check data types
    print("1. Data Types Validation:")
    numeric_cols = final_data.select_dtypes(include=[np.number]).columns
    object_cols = final_data.select_dtypes(include=['object']).columns

    print(f"   ✓ Numeric columns: {len(numeric_cols)}")
    print(f"   ✓ Text columns: {len(object_cols)}")
    validation_results['numeric_columns'] = list(numeric_cols)
    validation_results['text_columns'] = list(object_cols)

    # 2. Check value ranges for prevalence (expected to be percentages, 0-100)
    print("\n2. Prevalence Values Validation:")
    prevalence_cols = [col for col in final_data.columns if 'prevalence' in col]
    prevalence_checks = {}

    for col in prevalence_cols:
        values = final_data[col].dropna()
        if values.empty:
            # min()/max() of an empty Series are NaN; report that explicitly
            # instead of printing a meaningless "nan% - nan%" range.
            print(f"   ⚠️ {col}: no non-missing values")
            prevalence_checks[col] = False
            continue

        min_val, max_val = values.min(), values.max()

        # Check if values are reasonable (0-100% for prevalence)
        valid_range = bool((min_val >= 0) and (max_val <= 100))
        prevalence_checks[col] = valid_range
        status = "✓" if valid_range else "⚠️"

        print(f"   {status} {col}: {min_val:.2f}% - {max_val:.2f}%")
    validation_results['prevalence_in_range'] = prevalence_checks

    # 3. Check for duplicate country-year rows
    print("\n3. Duplicate Records Check:")
    if 'country' in final_data.columns and 'year' in final_data.columns:
        duplicates = final_data.duplicated(subset=['country', 'year']).sum()
        status = "✓" if duplicates == 0 else "⚠️"
        print(f"   {status} Duplicate country-year combinations: {duplicates}")
        validation_results['duplicate_country_year'] = int(duplicates)

    # 4. Check temporal consistency
    print("\n4. Temporal Consistency:")
    if 'year' in final_data.columns:
        year_range = final_data['year'].max() - final_data['year'].min()
        print(f"   ✓ Year span: {year_range} years")

        # Plausible window for observation years. The upper bound follows the
        # current calendar year (never below the original 2025) so newly
        # released data is not flagged as unreasonable. Missing years count
        # as unreasonable, as before.
        min_year = 1990
        max_year = max(2025, pd.Timestamp.now().year)
        reasonable_years = final_data['year'].between(min_year, max_year)
        unreasonable_count = (~reasonable_years).sum()
        status = "✓" if unreasonable_count == 0 else "⚠️"
        print(f"   {status} Records with unreasonable years: {unreasonable_count}")
        validation_results['unreasonable_years'] = int(unreasonable_count)

    # 5. Geographic coverage
    print("\n5. Geographic Coverage:")
    if 'country' in final_data.columns:
        total_countries = final_data['country'].nunique()
        print(f"   ✓ Total countries: {total_countries}")
        validation_results['total_countries'] = int(total_countries)

        # Countries with most data
        top_countries = final_data['country'].value_counts().head(5)
        print("   ✓ Top countries by data points:")
        for country, count in top_countries.items():
            print(f"      • {country}: {count} records")

    print("\n✅ Data validation completed!")

## 10. Final Data Overview

In [None]:
if 'final_data' in locals():
    print("=== FINAL PROCESSED DATA OVERVIEW ===")

    # Display basic info
    print(f"Dataset shape: {final_data.shape}")
    print(f"Memory usage: {final_data.memory_usage(deep=True).sum() / 1024**2:.2f} MB")

    # Display first few rows
    print("\nFirst 5 rows:")
    display(final_data.head())

    # Summary statistics for numeric columns
    numeric_cols = final_data.select_dtypes(include=[np.number]).columns
    if len(numeric_cols) > 0:
        print("\nSummary statistics for numeric columns:")
        display(final_data[numeric_cols].describe())

    # Coverage plots: records per year, and the best-covered countries.
    if 'country' in final_data.columns and 'year' in final_data.columns:
        plt.figure(figsize=(12, 6))

        # Data coverage over time
        plt.subplot(1, 2, 1)
        year_counts = final_data['year'].value_counts().sort_index()
        plt.plot(year_counts.index, year_counts.values, 'o-', linewidth=2, markersize=6)
        plt.xlabel('Year')
        plt.ylabel('Number of Records')
        plt.title('Data Coverage Over Time')
        plt.grid(True, alpha=0.3)

        # Top countries by data availability
        plt.subplot(1, 2, 2)
        top_countries = final_data['country'].value_counts().head(10)
        plt.barh(range(len(top_countries)), top_countries.values)
        plt.yticks(range(len(top_countries)), top_countries.index)
        # value_counts() is descending; invert so the country with the most
        # records is drawn at the top of the horizontal bar chart.
        plt.gca().invert_yaxis()
        plt.xlabel('Number of Records')
        plt.title('Top 10 Countries by Data Availability')

        plt.tight_layout()
        plt.show()

    print("\n📊 Data preprocessing completed successfully!")
    # This notebook never writes these files itself (presumably
    # process_all_data() does — verify), so only claim they were saved when
    # they are actually present on disk.
    expected_outputs = [
        ("📁 Processed data saved to", "../data/processed/mental_health_processed.csv"),
        ("📋 Data summary saved to", "../data/processed/data_summary.json"),
    ]
    for label, output_path in expected_outputs:
        if Path(output_path).exists():
            print(f"{label}: {output_path}")
        else:
            print(f"⚠️ Expected output not found: {output_path}")
    print("🔄 Ready for analysis - proceed to notebook 03_time_series_analysis.ipynb")