# Mental Health Data Exploration

This notebook provides an initial exploration of the global mental health dataset, including:
- Data loading and overview
- Basic statistics and data quality assessment
- Initial visualizations
- Data structure analysis

## 1. Setup and Imports

In [None]:
import sys
import os

# Add src directory to path
sys.path.append('../src')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from pathlib import Path

# Plot defaults: apply matplotlib's 'default' style, select the seaborn
# "husl" palette, then set the figure size used when none is given.
plt.style.use('default')
sns.set_palette("husl")
plt.rcParams.update({'figure.figsize': (12, 8)})

# Install a blanket "ignore" filter for all warnings
import warnings
warnings.filterwarnings('ignore')

print("✓ All libraries imported successfully")

## 2. Download Data

In [None]:
# Project helper that performs the dataset downloads
from data.download_data import DataDownloader

# data_dir is passed as the relative path ../data
downloader = DataDownloader(data_dir="../data")

# Fetch everything in one call; the return value is used as a success flag
print("Downloading mental health datasets...")
success = downloader.download_all_data()

if not success:
    print("❌ Data download failed!")
else:
    print("✓ Data download completed successfully!")

    # Report what the downloader lists, one category at a time
    info = downloader.get_data_info()
    print("\nDownloaded files:")
    for category, files in info.items():
        print(f"  {category}: {len(files)} files")
        for downloaded in files:
            print(f"    - {downloaded.name}")

## 3. Load Raw Data

In [None]:
# Location of the raw prevalence CSV
data_path = Path("../data") / "raw" / "mental_health_prevalence.csv"

# raw_data stays None when the file is absent; later cells test for that
raw_data = pd.read_csv(data_path) if data_path.exists() else None

if raw_data is None:
    print("❌ Data file not found. Please run the data download cell first.")
else:
    n_rows, n_cols = raw_data.shape
    print(f"✓ Data loaded successfully: {raw_data.shape}")
    print(f"  Rows: {n_rows:,}")
    print(f"  Columns: {n_cols}")

## 4. Initial Data Overview

In [None]:
if raw_data is not None:
    print("=== DATA OVERVIEW ===")
    print(f"Dataset shape: {raw_data.shape}")

    # deep=True is passed so object columns are measured by their contents;
    # the byte total is converted to MB.
    memory_mb = raw_data.memory_usage(deep=True).sum() / 1024**2
    print(f"Memory usage: {memory_mb:.2f} MB")

    print("\n=== COLUMN INFORMATION ===")
    # DataFrame.info() writes its report to stdout itself and returns None,
    # so wrapping it in print() would emit a stray "None" line.
    raw_data.info()

    print("\n=== COLUMN NAMES ===")
    for position, col in enumerate(raw_data.columns, start=1):
        print(f"{position:2d}. {col}")

In [None]:
if raw_data is not None:
    # Preview both ends of the table: (heading, slice) pairs shown in order
    previews = (
        ("=== FIRST 10 ROWS ===", raw_data.head(10)),
        ("\n=== LAST 5 ROWS ===", raw_data.tail(5)),
    )
    for heading, frame in previews:
        print(heading)
        display(frame)

## 5. Data Quality Assessment

In [None]:
if raw_data is not None:
    print("=== MISSING VALUES ANALYSIS ===")

    # Null count per column and the same figure as a percentage of all rows
    missing_info = raw_data.isnull().sum()
    missing_pct = missing_info.div(len(raw_data)).mul(100)

    # One row per column, columns with the most nulls first
    missing_df = pd.DataFrame(
        {'Missing Count': missing_info, 'Missing Percentage': missing_pct}
    ).sort_values('Missing Count', ascending=False)

    has_missing = missing_df['Missing Count'] > 0
    print(missing_df[has_missing])

    if missing_df['Missing Count'].sum() <= 0:
        print("✓ No missing values found in the dataset!")
    else:
        # Bar chart restricted to the columns that actually contain nulls
        missing_cols = missing_df[has_missing]
        bar_positions = range(len(missing_cols))

        plt.figure(figsize=(12, 6))
        plt.bar(bar_positions, missing_cols['Missing Percentage'])
        plt.xlabel('Columns')
        plt.ylabel('Missing Percentage (%)')
        plt.title('Missing Data by Column')
        plt.xticks(bar_positions, missing_cols.index, rotation=45, ha='right')
        plt.tight_layout()
        plt.show()

In [None]:
if raw_data is not None:
    print("=== BASIC STATISTICS ===")

    # Summary statistics for every numeric column (skipped when there are none)
    numeric_cols = raw_data.select_dtypes(include=[np.number]).columns

    if len(numeric_cols) > 0:
        display(raw_data[numeric_cols].describe())

    print("\n=== CATEGORICAL COLUMNS ===")
    categorical_cols = raw_data.select_dtypes(include=['object']).columns

    for col in categorical_cols:
        unique_count = raw_data[col].nunique()
        print(f"{col}: {unique_count} unique values")
        if unique_count <= 20:  # Show values if not too many
            # nunique() does not count NaN, so NaN is dropped here as well.
            # Sorting with key=str avoids the TypeError that a plain sorted()
            # raises when an object column mixes strings with NaN or other
            # non-string values.
            distinct_values = raw_data[col].dropna().unique().tolist()
            print(f"  Values: {sorted(distinct_values, key=str)}")
        print()

## 6. Geographic and Temporal Coverage

In [None]:
# Initialised up front so both names exist even when no data was loaded
country_col = None
year_col = None

if raw_data is not None:
    # Detect the entity/country and year columns by name. The first matching
    # column wins, so a later column such as "Country Code" cannot displace
    # "Country". str() keeps the check working for non-string column labels.
    for col in raw_data.columns:
        col_name = str(col).lower()
        if 'entity' in col_name or 'country' in col_name:
            if country_col is None:
                country_col = col
        elif 'year' in col_name:
            if year_col is None:
                year_col = col

    if country_col is not None and year_col is not None:
        print(f"=== GEOGRAPHIC COVERAGE (using {country_col}) ===")
        print(f"Total unique countries/entities: {raw_data[country_col].nunique()}")

        # Show top countries by data points
        country_counts = raw_data[country_col].value_counts().head(15)
        print("\nTop 15 countries by data points:")
        for country, count in country_counts.items():
            print(f"  {country}: {count} records")

        print(f"\n=== TEMPORAL COVERAGE (using {year_col}) ===")
        year_range = raw_data[year_col].agg(['min', 'max'])
        print(f"Year range: {year_range['min']} - {year_range['max']}")
        print(f"Total years covered: {raw_data[year_col].nunique()}")

        # Two side-by-side line charts of coverage over time
        plt.figure(figsize=(14, 6))

        # Left panel: number of rows recorded for each year
        plt.subplot(1, 2, 1)
        year_counts = raw_data[year_col].value_counts().sort_index()
        plt.plot(year_counts.index, year_counts.values, marker='o')
        plt.xlabel('Year')
        plt.ylabel('Number of Records')
        plt.title('Data Points per Year')
        plt.grid(True, alpha=0.3)

        # Right panel: number of distinct countries/entities for each year
        plt.subplot(1, 2, 2)
        countries_per_year = raw_data.groupby(year_col)[country_col].nunique()
        plt.plot(countries_per_year.index, countries_per_year.values, marker='s', color='green')
        plt.xlabel('Year')
        plt.ylabel('Number of Countries')
        plt.title('Countries with Data per Year')
        plt.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()
    else:
        print("Could not identify country and year columns for coverage analysis")

## 7. Mental Health Metrics Overview

In [None]:
if raw_data is not None:
    # Every column whose name contains "prevalence" (case-insensitive)
    prevalence_cols = [col for col in raw_data.columns if 'prevalence' in col.lower()]

    print(f"=== MENTAL HEALTH METRICS ({len(prevalence_cols)} found) ===")
    for i, col in enumerate(prevalence_cols, 1):
        print(f"{i}. {col}")

    if prevalence_cols:
        print("\n=== PREVALENCE STATISTICS ===")
        display(raw_data[prevalence_cols].describe())

        # Histograms for at most the first four metrics on a two-row grid
        plotted_cols = prevalence_cols[:4]
        grid_columns = min(2, len(plotted_cols))

        # squeeze=False makes subplots() return a 2-D array of Axes for every
        # grid size; flattening it gives one Axes per position. Without this,
        # three or more metrics left `axes` two-dimensional and indexing it
        # returned a row of Axes rather than a single Axes.
        fig, axes = plt.subplots(2, grid_columns, figsize=(15, 10), squeeze=False)
        axes = axes.flatten()

        for ax, col in zip(axes, plotted_cols):
            data_to_plot = raw_data[col].dropna()
            if len(data_to_plot) > 0:
                ax.hist(data_to_plot, bins=30, alpha=0.7, edgecolor='black')
                ax.set_title(col.replace('_', ' ').title())
                ax.set_xlabel('Prevalence (%)')
                ax.set_ylabel('Frequency')
                ax.grid(True, alpha=0.3)

        # Hide grid positions that have no metric assigned
        for ax in axes[len(plotted_cols):]:
            ax.set_visible(False)

        plt.suptitle('Distribution of Mental Health Prevalence Metrics', fontsize=16)
        plt.tight_layout()
        plt.show()
    else:
        print("No prevalence columns found in the dataset")

## 8. Sample Data Exploration

In [None]:
if raw_data is not None and country_col and year_col:
    print("=== SAMPLE COUNTRY ANALYSIS ===")

    # The five entities with the most rows serve as the sample
    country_data_counts = raw_data[country_col].value_counts()
    most_frequent = country_data_counts.head(5)
    sample_countries = most_frequent.index.tolist()

    print(f"Analyzing sample countries: {sample_countries}")

    if prevalence_cols:
        # Only the first prevalence metric is charted
        metric = prevalence_cols[0]
        metric_label = metric.replace('_', ' ').title()

        plt.figure(figsize=(15, 8))

        # One Set1 colour per sample country, evenly spaced over the colormap
        colors = plt.cm.Set1(np.linspace(0, 1, len(sample_countries)))

        for country, line_color in zip(sample_countries, colors):
            is_country = raw_data[country_col] == country
            country_data = raw_data[is_country].sort_values(year_col)

            if len(country_data) == 0 or metric not in country_data.columns:
                continue
            plt.plot(country_data[year_col], country_data[metric],
                     'o-', label=country, color=line_color, linewidth=2, markersize=6)

        plt.xlabel('Year')
        plt.ylabel(metric_label)
        plt.title(f'{metric_label} Trends for Sample Countries')
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

        # Tabular view of the first sample country, ordered by year
        sample_country = sample_countries[0]
        is_sample = raw_data[country_col] == sample_country
        sample_data = raw_data[is_sample].sort_values(year_col)

        print(f"\n=== DETAILED DATA FOR {sample_country.upper()} ===")
        display(sample_data[[year_col] + prevalence_cols].head(10))

## 9. Data Quality Summary

In [None]:
if raw_data is not None:
    print("=== DATA QUALITY SUMMARY ===")

    # Headline counts; 'Unknown' stands in when a column was not identified
    total_records = len(raw_data)
    if country_col:
        total_countries = raw_data[country_col].nunique()
    else:
        total_countries = 'Unknown'
    if year_col:
        total_years = raw_data[year_col].nunique()
    else:
        total_years = 'Unknown'
    total_metrics = len(prevalence_cols)

    print("📊 Dataset Overview:")
    print(f"   • Total records: {total_records:,}")
    print(f"   • Countries/Entities: {total_countries}")
    print(f"   • Years covered: {total_years}")
    print(f"   • Mental health metrics: {total_metrics}")

    # Share of non-null values per prevalence column, as a percentage
    if prevalence_cols:
        non_null_counts = raw_data[prevalence_cols].notna().sum()
        completeness = (non_null_counts / len(raw_data)) * 100
        avg_completeness = completeness.mean()

        print("\n📈 Data Completeness:")
        print(f"   • Average completeness: {avg_completeness:.1f}%")

        for metric, comp in completeness.items():
            # Marker thresholds: at least 80%, at least 50%, anything lower
            if comp >= 80:
                status = "✓"
            elif comp >= 50:
                status = "⚠️"
            else:
                status = "❌"
            print(f"   • {metric}: {comp:.1f}% {status}")

    # Fixed list of follow-up steps, then the closing messages
    closing_lines = (
        "\n🔍 Next Steps:",
        "   1. Data cleaning and preprocessing needed",
        "   2. Handle missing values and outliers",
        "   3. Standardize country names and validate data",
        "   4. Create derived features and regional groupings",
        "   5. Proceed to detailed analysis",
        "\n✅ Data exploration completed successfully!",
        "📁 Proceed to notebook 02_data_cleaning.ipynb for data preprocessing",
    )
    for line in closing_lines:
        print(line)

## 10. Save Exploration Results

In [None]:
if raw_data is not None:
    import json

    def convert_numpy(obj):
        """Return *obj* with NumPy scalars and arrays replaced by built-in types.

        Dicts, lists and tuples are walked recursively (tuples come back as
        lists, which is how JSON stores them anyway). Anything else is
        returned unchanged.
        """
        if isinstance(obj, dict):
            # Scalar NumPy keys are unwrapped too; json rejects them as keys
            return {
                (key.item() if isinstance(key, np.generic) else key): convert_numpy(value)
                for key, value in obj.items()
            }
        if isinstance(obj, (list, tuple)):
            return [convert_numpy(item) for item in obj]
        if isinstance(obj, np.ndarray):
            return convert_numpy(obj.tolist())
        if isinstance(obj, np.generic):
            return obj.item()
        return obj

    # Collect the basic exploration results
    exploration_results = {
        'dataset_shape': raw_data.shape,
        'total_countries': raw_data[country_col].nunique() if country_col else 0,
        'total_years': raw_data[year_col].nunique() if year_col else 0,
        'year_range': [raw_data[year_col].min(), raw_data[year_col].max()] if year_col else [],
        'prevalence_metrics': prevalence_cols,
        'missing_data_summary': raw_data.isnull().sum().to_dict(),
        'top_countries': raw_data[country_col].value_counts().head(10).to_dict() if country_col else {}
    }

    # The conversion must reach nested values: 'year_range' is a list of
    # NumPy scalars, which json cannot serialize as-is.
    clean_results = convert_numpy(exploration_results)

    results_path = Path("../data/processed/exploration_results.json")
    results_path.parent.mkdir(parents=True, exist_ok=True)

    # Serialize before opening the file so a serialization error cannot
    # leave a truncated JSON file behind.
    payload = json.dumps(clean_results, indent=2)
    with open(results_path, 'w', encoding='utf-8') as f:
        f.write(payload)

    print(f"✅ Exploration results saved to: {results_path}")
    print(f"📝 Summary: {len(clean_results)} key findings documented")
else:
    print("❌ No data available to save results")