# Advanced Mental Health Data Visualizations

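This notebook builds a comprehensive set of mental health visualizations, including:
- Interactive Plotly charts (the standalone dashboard is launched separately via `src/dashboard/app.py`)
- Geographic choropleth maps and regional comparisons
- Statistical distribution plots (histogram, box, violin, and QQ plots)
- Publication-ready static figures (exported as PNG and PDF)
- Custom plotting styles for publication output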

## 1. Setup and Imports

In [None]:
# Setup cell: standard library, the plotting stack, and project-local plotting helpers.
import sys
import os  # NOTE(review): `os` is not referenced in the visible cells

# Add src directory to path
# Must run before the `visualization.*` imports below, which live under ../src
# (presumably the notebook is executed from its own directory — confirm).
sys.path.append('../src')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.figure_factory as ff  # NOTE(review): `ff` is not referenced in the visible cells
from pathlib import Path
import warnings
# Silences every warning for the rest of the session (pandas/plotly
# deprecation notices included); narrow this when debugging.
warnings.filterwarnings('ignore')

# Import custom modules
# Project-local classes from src/visualization; their API is not visible here.
from visualization.static_plots import StaticPlots
from visualization.interactive_plots import InteractivePlots
from visualization.dashboard_components import DashboardComponents

# Configure plotting
# Matplotlib defaults for the exploratory cells; section 9 switches to a
# publication style later in the session.
plt.style.use('default')
sns.set_palette("husl")
plt.rcParams['figure.figsize'] = (14, 8)

# Plotly configuration
import plotly.io as pio
# Render Plotly figures inline in the notebook frontend.
pio.renderers.default = "notebook"

print("✓ All visualization libraries imported successfully")

## 2. Load Processed Data

In [None]:
# Load cleaned data; fall back to a seeded synthetic panel so every later cell
# can run end-to-end even without the processed CSV.
data_path = Path("../data/processed/mental_health_cleaned.csv")

if data_path.exists():
    data = pd.read_csv(data_path)
    print(f"✓ Processed data loaded: {data.shape}")
else:
    print("⚠️ Processed data not found. Creating sample data for demonstration.")
    # Create comprehensive sample data (seeded, so reruns are reproducible)
    np.random.seed(42)

    countries = ['United States', 'United Kingdom', 'Germany', 'France', 'Japan',
                 'Canada', 'Australia', 'Brazil', 'India', 'China', 'Russia',
                 'South Africa', 'Mexico', 'Italy', 'Spain', 'Netherlands']
    years = list(range(1990, 2024))
    regions = {'United States': 'North America', 'Canada': 'North America',
               'United Kingdom': 'Europe', 'Germany': 'Europe', 'France': 'Europe',
               'Italy': 'Europe', 'Spain': 'Europe', 'Netherlands': 'Europe',
               'Japan': 'Asia', 'China': 'Asia', 'India': 'Asia',
               'Australia': 'Oceania', 'Brazil': 'South America', 'Mexico': 'North America',
               'Russia': 'Europe', 'South Africa': 'Africa'}

    sample_data = []
    for country in countries:
        base_depression = np.random.uniform(3, 8)
        base_anxiety = np.random.uniform(2, 6)
        base_bipolar = np.random.uniform(0.5, 1.5)
        base_eating = np.random.uniform(0.2, 1.0)
        # Population and GDP are country attributes: draw them once per country
        # and grow them smoothly. Redrawing them every year made the animated
        # bubble chart jump randomly in size and position between frames.
        base_population = np.random.randint(10000000, 1400000000)
        base_gdp = np.random.uniform(5000, 80000)

        for year in years:
            years_elapsed = year - 1990
            trend_factor = years_elapsed * 0.025
            # A single noise draw is shared by all four metrics for this
            # country-year, so the metrics move together (positively correlated).
            noise = np.random.normal(0, 0.3)

            sample_data.append({
                'Entity': country,
                'Year': year,
                'Region': regions.get(country, 'Other'),
                'Depression_prevalence': max(0, base_depression + trend_factor + noise),
                'Anxiety_prevalence': max(0, base_anxiety + trend_factor * 0.8 + noise),
                'Bipolar_prevalence': max(0, base_bipolar + trend_factor * 0.3 + noise * 0.5),
                'Eating_disorders_prevalence': max(0, base_eating + trend_factor * 0.2 + noise * 0.3),
                'Population': int(base_population * 1.01 ** years_elapsed),  # ~1%/yr growth
                'GDP_per_capita': base_gdp * 1.015 ** years_elapsed,  # ~1.5%/yr growth
            })

    data = pd.DataFrame(sample_data)
    print(f"✓ Sample data created: {data.shape}")

# Load analysis results if available
results_path = Path("../data/processed/time_series_analysis_results.json")
if results_path.exists():
    import json
    with open(results_path, 'r', encoding='utf-8') as f:
        analysis_results = json.load(f)
    print("✓ Time series analysis results loaded")
else:
    analysis_results = None
    print("⚠️ Time series analysis results not found")

# Display dataset overview. Metric columns are detected by name: any column
# whose name contains "prevalence" (case-insensitive).
prevalence_cols = [col for col in data.columns if 'prevalence' in col.lower()]
print("\nDataset overview:")
print(f"  Countries: {data['Entity'].nunique()}")
print(f"  Years: {data['Year'].min()} - {data['Year'].max()}")
print(f"  Mental health metrics: {len(prevalence_cols)}")
if 'Region' in data.columns:
    print(f"  Regions: {data['Region'].nunique()}")

## 3. Initialize Visualization Tools

In [None]:
# Instantiate the project plotting helpers. They are kept as notebook globals;
# the cells below do not call them directly.
static_plots, interactive_plots, dashboard_components = (
    StaticPlots(),
    InteractivePlots(),
    DashboardComponents(),
)
print("✓ Visualization tools initialized")

# Headline metric: the first detected prevalence column, else a fixed fallback name.
if prevalence_cols:
    main_metric = prevalence_cols[0]
else:
    main_metric = 'Depression_prevalence'

# Bounds of the time axis shared by every chart below.
earliest_year, latest_year = data['Year'].min(), data['Year'].max()

print(f"Primary visualization metric: {main_metric}")
print(f"Time range: {earliest_year} - {latest_year}")

## 4. Geographic Visualizations

In [None]:
print("=== GEOGRAPHIC VISUALIZATIONS ===")

# Snapshot of every country in the most recent year; later cells reuse it.
latest_data = data[data['Year'] == latest_year].copy()

# Human-readable metric label, e.g. "Depression Prevalence".
_metric_label = main_metric.replace('_', ' ').title()

# 1. World choropleth for the latest year (skipped when that year has no rows)
if len(latest_data) > 0:
    print(f"\n1. World Map - {_metric_label} ({latest_year})")

    fig_world = px.choropleth(
        latest_data,
        locations='Entity',
        locationmode='country names',
        color=main_metric,
        hover_name='Entity',
        hover_data={main_metric: ':.2f'},
        color_continuous_scale='Reds',
        title=f'Global {_metric_label} - {latest_year}',
    )
    _world_geo = dict(showframe=False, showcoastlines=True, projection_type='natural earth')
    fig_world.update_layout(title_x=0.5, geo=_world_geo, width=1000, height=600)
    fig_world.show()

# 2. Regional comparison: horizontal bars of the latest-year regional means
if 'Region' in data.columns:
    print(f"\n2. Regional Comparison - Average {_metric_label}")

    regional_avg = (
        latest_data
        .groupby('Region')[main_metric]
        .mean()
        .reset_index()
    )
    regional_avg['Region_Display'] = regional_avg['Region']

    _bar_frame = regional_avg.sort_values(main_metric, ascending=True)
    fig_regional = px.bar(
        _bar_frame,
        x=main_metric,
        y='Region',
        orientation='h',
        color=main_metric,
        color_continuous_scale='viridis',
        title=f'Regional Average {_metric_label} - {latest_year}',
        text=main_metric,
    )
    fig_regional.update_traces(texttemplate='%{text:.2f}%', textposition='inside')
    fig_regional.update_layout(
        title_x=0.5,
        xaxis_title=f'{_metric_label} (%)',
        yaxis_title='Region',
        width=800,
        height=500,
    )
    fig_regional.show()

## 5. Time Series Visualizations

In [None]:
print("=== TIME SERIES VISUALIZATIONS ===")

# 1. Global trend over time
print("\n1. Global Trend Analysis")

# Cross-country mean/spread per year. Drop years whose metric is entirely
# missing (a NaN mean breaks both the band polygon and np.polyfit), and treat
# the undefined std of a single-observation year as 0.
global_trend = data.groupby('Year')[main_metric].agg(['mean', 'std', 'count']).reset_index()
global_trend.columns = ['Year', 'Mean', 'Std', 'Count']
global_trend = global_trend.dropna(subset=['Mean']).fillna({'Std': 0})

years_arr = global_trend['Year'].to_numpy()
mean_arr = global_trend['Mean'].to_numpy()
std_arr = global_trend['Std'].to_numpy()

fig_global = go.Figure()

# Add main trend line
fig_global.add_trace(go.Scatter(
    x=years_arr,
    y=mean_arr,
    mode='lines+markers',
    name='Global Average',
    line=dict(color='blue', width=3),
    marker=dict(size=8)
))

# Shaded ±1 standard-deviation band across countries (a spread band, not a
# confidence interval): trace the upper edge forward, the lower edge backward,
# and fill the resulting polygon.
fig_global.add_trace(go.Scatter(
    x=np.concatenate([years_arr, years_arr[::-1]]),
    y=np.concatenate([mean_arr + std_arr, (mean_arr - std_arr)[::-1]]),
    fill='toself',
    fillcolor='rgba(0,100,80,0.2)',
    line=dict(color='rgba(255,255,255,0)'),
    name='±1 Standard Deviation',
    hoverinfo='skip'
))

# Least-squares linear trend; a line fit needs at least two years.
if len(global_trend) >= 2:
    z = np.polyfit(years_arr, mean_arr, 1)
    trend_line = np.poly1d(z)(years_arr)

    fig_global.add_trace(go.Scatter(
        x=years_arr,
        y=trend_line,
        mode='lines',
        name=f'Linear Trend (slope: {z[0]:.4f})',
        line=dict(color='red', dash='dash', width=2)
    ))

fig_global.update_layout(
    title=f'Global {main_metric.replace("_", " ").title()} Trend Over Time',
    xaxis_title='Year',
    yaxis_title=f'{main_metric.replace("_", " ").title()} (%)',
    width=1000,
    height=600,
    hovermode='x unified'
)

fig_global.show()

# 2. Country comparison time series
print("\n2. Top Countries Comparison")

# Top 8 countries by latest-year value (nlargest skips NaN values)
top_countries = latest_data.nlargest(8, main_metric)['Entity'].tolist()

fig_countries = go.Figure()

colors = px.colors.qualitative.Set1
for i, country in enumerate(top_countries):
    country_data = data[data['Entity'] == country].sort_values('Year')

    fig_countries.add_trace(go.Scatter(
        x=country_data['Year'],
        y=country_data[main_metric],
        mode='lines+markers',
        name=country,
        line=dict(color=colors[i % len(colors)], width=2),  # cycle the palette past 9 countries
        marker=dict(size=6)
    ))

fig_countries.update_layout(
    title=f'Top Countries - {main_metric.replace("_", " ").title()} Trends',
    xaxis_title='Year',
    yaxis_title=f'{main_metric.replace("_", " ").title()} (%)',
    width=1000,
    height=600,
    hovermode='x unified'
)

fig_countries.show()

## 6. Multi-Metric Comparisons

In [None]:
if len(prevalence_cols) >= 2:
    print("=== MULTI-METRIC COMPARISONS ===")

    # 1. Correlation heatmap (pandas' default Pearson, pooled over all rows)
    print("\n1. Correlation Matrix of Mental Health Metrics")

    correlation_data = data[prevalence_cols].corr()

    fig_corr = px.imshow(
        correlation_data,
        text_auto='.2f',  # text_auto=True printed each coefficient at full float precision
        aspect='auto',
        color_continuous_scale='RdBu_r',
        color_continuous_midpoint=0,
        title='Correlation Matrix - Mental Health Metrics'
    )

    fig_corr.update_layout(
        width=800,
        height=600,
        title_x=0.5
    )

    fig_corr.show()

    # 2. Radar chart for regional profiles
    if 'Region' in data.columns:
        print("\n2. Regional Mental Health Profiles (Latest Year)")

        regional_profiles = latest_data.groupby('Region')[prevalence_cols].mean().reset_index()
        # Axis labels are identical for every region, so build them once.
        radar_axes = [col.replace('_prevalence', '').replace('_', ' ').title() for col in prevalence_cols]

        fig_radar = go.Figure()

        for _, row in regional_profiles.iterrows():
            fig_radar.add_trace(go.Scatterpolar(
                r=[row[col] for col in prevalence_cols],
                theta=radar_axes,
                fill='toself',
                name=row['Region'],
                line=dict(width=2)
            ))

        fig_radar.update_layout(
            polar=dict(
                radialaxis=dict(
                    visible=True,
                    # Shared radial scale: 0 up to the largest regional mean of any metric
                    range=[0, regional_profiles[prevalence_cols].max().max()]
                )
            ),
            title='Regional Mental Health Profiles - All Metrics',
            width=800,
            height=600,
            title_x=0.5
        )

        fig_radar.show()

    # 3. Scatter plot matrix
    print("\n3. Relationships Between Mental Health Metrics")

    # Subsample for rendering speed; a fixed random_state keeps the figure
    # reproducible. Use a distinct name so the loading cell's `sample_data`
    # is not overwritten.
    scatter_sample = latest_data.sample(min(100, len(latest_data)), random_state=42)

    fig_scatter = px.scatter_matrix(
        scatter_sample,
        dimensions=prevalence_cols[:4],  # Limit to first 4 metrics
        color='Region' if 'Region' in scatter_sample.columns else None,
        title='Scatter Plot Matrix - Mental Health Metrics',
        hover_name='Entity'
    )

    fig_scatter.update_layout(
        width=1000,
        height=800
    )

    fig_scatter.show()
else:
    print("Multiple metrics not available for comparison")

## 7. Statistical Distribution Visualizations

In [None]:
print("=== STATISTICAL DISTRIBUTION VISUALIZATIONS ===")

# 1. Four views of the latest-year distribution in one 2x2 grid
print("\n1. Distribution of Mental Health Prevalence")

_metric_label = main_metric.replace("_", " ").title()
_latest_values = latest_data[main_metric]

fig_dist = make_subplots(
    rows=2, cols=2,
    subplot_titles=('Histogram', 'Box Plot', 'Violin Plot', 'QQ Plot'),
    specs=[[{'secondary_y': False} for _ in range(2)] for _ in range(2)],
)

# Top-left: histogram
fig_dist.add_trace(go.Histogram(x=_latest_values, nbinsx=20, name='Distribution'), row=1, col=1)

# Top-right: one box per region, or a single box when regions are unavailable
if 'Region' in latest_data.columns:
    _box_groups = [
        (region, latest_data.loc[latest_data['Region'] == region, main_metric])
        for region in latest_data['Region'].unique()
    ]
else:
    _box_groups = [('All Countries', _latest_values)]
for _box_name, _box_values in _box_groups:
    fig_dist.add_trace(go.Box(y=_box_values, name=_box_name), row=1, col=2)

# Bottom-left: violin with an embedded box
fig_dist.add_trace(go.Violin(y=_latest_values, name='Distribution', box_visible=True), row=2, col=1)

# Bottom-right: normal QQ plot. probplot returns
# ((theoretical quantiles, ordered values), (slope, intercept, r)).
from scipy.stats import probplot
qq_data = probplot(_latest_values.dropna(), dist='norm')
(_qq_x, _qq_y), (_qq_slope, _qq_intercept, _) = qq_data

fig_dist.add_trace(go.Scatter(x=_qq_x, y=_qq_y, mode='markers', name='Data Points'), row=2, col=2)
fig_dist.add_trace(
    go.Scatter(x=_qq_x, y=_qq_intercept + _qq_slope * _qq_x,
               mode='lines', name='Normal Line', line=dict(color='red')),
    row=2, col=2,
)

fig_dist.update_layout(
    title=f'Statistical Distribution Analysis - {_metric_label}',
    width=1000,
    height=800,
    showlegend=False,
)
fig_dist.show()

# 2. Decade-by-decade violins (adds a 'Decade' column to `data`)
print("\n2. Evolution of Distribution Over Time")

data['Decade'] = data['Year'] // 10 * 10
decades = sorted(data['Decade'].unique())

# One violin per decade, limited to the four most recent decades
fig_evolution = go.Figure(data=[
    go.Violin(
        y=data.loc[data['Decade'] == decade, main_metric],
        name=f'{decade}s',
        box_visible=True,
        meanline_visible=True,
    )
    for decade in decades[-4:]
])

fig_evolution.update_layout(
    title=f'Distribution Evolution by Decade - {_metric_label}',
    yaxis_title=f'{_metric_label} (%)',
    xaxis_title='Decade',
    width=800,
    height=600,
)
fig_evolution.show()

## 8. Advanced Interactive Visualizations

In [None]:
print("=== ADVANCED INTERACTIVE VISUALIZATIONS ===")

metric_label = main_metric.replace("_", " ").title()

# 1. Animated bubble chart
print("\n1. Animated Bubble Chart - Mental Health vs GDP")

# Bubble size uses Population, so both columns are required. The previous
# check on GDP_per_capita alone raised KeyError on data without Population.
if {'GDP_per_capita', 'Population'}.issubset(data.columns):
    # Prepare data for animation: one frame per year from 2000 onward
    bubble_data = data[data['Year'] >= 2000].copy()
    bubble_data = bubble_data.dropna(subset=[main_metric, 'GDP_per_capita', 'Population'])

    if bubble_data.empty:
        print("No complete rows from 2000 onward; skipping bubble chart")
    else:
        fig_bubble = px.scatter(
            bubble_data,
            x='GDP_per_capita',
            y=main_metric,
            size='Population',
            color='Region' if 'Region' in bubble_data.columns else 'Entity',
            hover_name='Entity',
            animation_frame='Year',
            animation_group='Entity',
            size_max=50,
            # Fixed axis ranges so axes don't rescale between animation frames
            range_x=[bubble_data['GDP_per_capita'].min() * 0.9, bubble_data['GDP_per_capita'].max() * 1.1],
            range_y=[bubble_data[main_metric].min() * 0.9, bubble_data[main_metric].max() * 1.1],
            title=f'{metric_label} vs GDP per Capita Over Time'
        )

        fig_bubble.update_layout(
            xaxis_title='GDP per Capita (USD)',
            yaxis_title=f'{metric_label} (%)',
            width=1000,
            height=600
        )

        fig_bubble.show()

# 2. Parallel coordinates plot
print("\n2. Parallel Coordinates - Multi-dimensional Analysis")

if len(prevalence_cols) >= 3:
    parallel_metrics = prevalence_cols[:4]
    # 'Region' is optional throughout the notebook; only select it when present.
    id_cols = ['Entity'] + (['Region'] if 'Region' in latest_data.columns else [])
    parallel_data = latest_data[id_cols + parallel_metrics].dropna()

    # Each Parcoords dimension carries its own axis range, so raw values are
    # plotted directly. The previously computed (and unused) min-max normalised
    # columns were dropped; they also divided by zero for constant metrics.
    dimensions = [
        dict(
            label=col.replace('_prevalence', '').replace('_', ' ').title(),
            values=parallel_data[col],
            range=[parallel_data[col].min(), parallel_data[col].max()]
        )
        for col in parallel_metrics
    ]

    fig_parallel = go.Figure(data=
        go.Parcoords(
            line=dict(color=parallel_data[prevalence_cols[0]],
                      colorscale='Viridis',
                      showscale=True,
                      colorbar=dict(title=prevalence_cols[0].replace('_', ' ').title())),
            dimensions=dimensions
        )
    )

    fig_parallel.update_layout(
        title='Parallel Coordinates - Mental Health Metrics Profile',
        width=1000,
        height=600
    )

    fig_parallel.show()

# 3. Sunburst chart for hierarchical data
if 'Region' in data.columns:
    print("\n3. Sunburst Chart - Regional and Country Breakdown")

    # Sector sizes are summed up the hierarchy, so they must be additive.
    # Using the prevalence rate as the size made each region's total equal the
    # sum of its countries' rates. Instead, size by population when available
    # (equal-sized countries otherwise) and encode the metric as color; parent
    # colors then become size-weighted averages of their children.
    agg_spec = {main_metric: 'mean'}
    has_population = 'Population' in latest_data.columns
    if has_population:
        agg_spec['Population'] = 'mean'
    sunburst_data = latest_data.groupby(['Region', 'Entity']).agg(agg_spec).reset_index().dropna()
    sunburst_data['All Countries'] = 'World'

    fig_sunburst = px.sunburst(
        sunburst_data,
        path=['All Countries', 'Region', 'Entity'],
        values='Population' if has_population else None,
        color=main_metric,
        color_continuous_scale='Reds',
        title=f'Hierarchical View - {metric_label} by Region and Country'
    )

    fig_sunburst.update_layout(
        width=800,
        height=600
    )

    fig_sunburst.show()

## 9. Publication-Ready Static Plots

In [None]:
print("=== PUBLICATION-READY STATIC PLOTS ===")

# Set publication style. Matplotlib 3.6 renamed the bundled seaborn styles to
# 'seaborn-v0_8-*'; fall back to the old name on older installations.
_pub_style = 'seaborn-v0_8-whitegrid'
plt.style.use(_pub_style if _pub_style in plt.style.available else 'seaborn-whitegrid')
plt.rcParams.update({
    'figure.figsize': (12, 8),
    'font.size': 12,
    'axes.titlesize': 14,
    'axes.labelsize': 12,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
    'figure.titlesize': 16
})

metric_label = main_metric.replace("_", " ").title()

# 1. Multi-panel figure
print("\n1. Multi-panel Summary Figure")

fig, axes = plt.subplots(2, 3, figsize=(18, 12))
fig.suptitle('Mental Health Analysis Summary', fontsize=20, y=0.98)

# Panel A: Global trend (yearly cross-country mean plus a linear fit).
# Years with no data are dropped so np.polyfit never sees NaN.
global_trend = data.groupby('Year')[main_metric].mean().dropna()
axes[0, 0].plot(global_trend.index, global_trend.values, 'o-', linewidth=2, markersize=6)
if len(global_trend) >= 2:
    z = np.polyfit(global_trend.index, global_trend.values, 1)
    axes[0, 0].plot(global_trend.index, np.poly1d(z)(global_trend.index), '--r', linewidth=2)
axes[0, 0].set_title('A. Global Trend Over Time')
axes[0, 0].set_xlabel('Year')
axes[0, 0].set_ylabel(f'{metric_label} (%)')
axes[0, 0].grid(True, alpha=0.3)

# Panel B: Regional comparison
if 'Region' in data.columns:
    regional_avg = latest_data.groupby('Region')[main_metric].mean().sort_values(ascending=True)
    bars = axes[0, 1].barh(range(len(regional_avg)), regional_avg.values, color='lightcoral')
    axes[0, 1].set_yticks(range(len(regional_avg)))
    axes[0, 1].set_yticklabels(regional_avg.index)
    axes[0, 1].set_title('B. Regional Comparison')
    axes[0, 1].set_xlabel(f'{metric_label} (%)')

    # Add value labels just past the end of each bar
    for bar in bars:
        width = bar.get_width()
        axes[0, 1].text(width + 0.1, bar.get_y() + bar.get_height()/2,
                        f'{width:.2f}', ha='left', va='center')
else:
    axes[0, 1].axis('off')  # no region information: leave the slot blank

# Panel C: Distribution
axes[0, 2].hist(latest_data[main_metric], bins=15, alpha=0.7, color='skyblue', edgecolor='black')
axes[0, 2].axvline(latest_data[main_metric].mean(), color='red', linestyle='--', linewidth=2, label='Mean')
axes[0, 2].axvline(latest_data[main_metric].median(), color='orange', linestyle='--', linewidth=2, label='Median')
axes[0, 2].set_title('C. Distribution (Latest Year)')
axes[0, 2].set_xlabel(f'{metric_label} (%)')
axes[0, 2].set_ylabel('Frequency')
axes[0, 2].legend()
axes[0, 2].grid(True, alpha=0.3)

# Panel D: Top vs Bottom countries
country_latest = latest_data.groupby('Entity')[main_metric].mean().sort_values()
top_5 = country_latest.tail(5)
bottom_5 = country_latest.head(5)

# Paired bars: each "highest" bar sits directly above its "lowest" partner.
# The bar height must equal the 0.4 offset; with the default height of 0.8
# the two bars in each pair overlapped.
bar_height = 0.4
y_pos = np.arange(len(top_5))
axes[1, 0].barh(y_pos, top_5.values, height=bar_height, alpha=0.8, color='red', label='Highest')
axes[1, 0].barh(y_pos - bar_height, bottom_5.values, height=bar_height, alpha=0.8, color='green', label='Lowest')
axes[1, 0].set_yticks(y_pos - bar_height / 2)
axes[1, 0].set_yticklabels([f'{top}\nvs\n{bot}' for top, bot in zip(top_5.index, bottom_5.index)])
axes[1, 0].set_title('D. Highest vs Lowest Countries')
axes[1, 0].set_xlabel(f'{metric_label} (%)')
axes[1, 0].legend()

# Panel E: Correlation of the main metric with the second detected metric
scatter_data = None
if len(prevalence_cols) >= 2:
    x_metric = prevalence_cols[1]
    scatter_data = latest_data[[main_metric, x_metric]].dropna()

if scatter_data is not None and len(scatter_data) >= 2:
    axes[1, 1].scatter(scatter_data[x_metric], scatter_data[main_metric], alpha=0.7, s=60)

    # Add least-squares line
    z = np.polyfit(scatter_data[x_metric], scatter_data[main_metric], 1)
    axes[1, 1].plot(scatter_data[x_metric], np.poly1d(z)(scatter_data[x_metric]), '--r', linewidth=2)

    # Pearson correlation
    corr = scatter_data[main_metric].corr(scatter_data[x_metric])
    axes[1, 1].set_title(f'E. Correlation (r = {corr:.3f})')
    axes[1, 1].set_xlabel(f'{x_metric.replace("_", " ").title()} (%)')
    axes[1, 1].set_ylabel(f'{metric_label} (%)')
    axes[1, 1].grid(True, alpha=0.3)
else:
    axes[1, 1].axis('off')  # not enough metrics/rows to correlate

# Panel F: Between-country spread per year.
# 'darkpurple' is not a valid Matplotlib color name (it raised ValueError);
# 'indigo' is a named dark purple.
yearly_std = data.groupby('Year')[main_metric].std()
axes[1, 2].fill_between(yearly_std.index, yearly_std.values, alpha=0.6, color='purple')
axes[1, 2].plot(yearly_std.index, yearly_std.values, 'o-', color='indigo', linewidth=2)
axes[1, 2].set_title('F. Between-Country Variability')
axes[1, 2].set_xlabel('Year')
axes[1, 2].set_ylabel('Standard Deviation')
axes[1, 2].grid(True, alpha=0.3)

plt.tight_layout()
plt.subplots_adjust(top=0.93)  # leave room for the suptitle
plt.show()

# 2. High-quality time series plot
print("\n2. Publication-Quality Time Series")

fig, ax = plt.subplots(figsize=(14, 8))

# Yearly cross-country mean with a normal-approximation 95% CI (mean ± 1.96·SEM);
# years with no data are dropped so the linear fit below is well-defined.
global_stats = data.groupby('Year')[main_metric].agg(['mean', 'std', 'sem']).dropna(subset=['mean'])

ax.fill_between(global_stats.index,
                global_stats['mean'] - 1.96 * global_stats['sem'],
                global_stats['mean'] + 1.96 * global_stats['sem'],
                alpha=0.3, color='lightblue', label='95% Confidence Interval')

ax.plot(global_stats.index, global_stats['mean'], 'o-',
        linewidth=3, markersize=8, color='darkblue', label='Global Average')

# Add trend line
if len(global_stats) >= 2:
    z = np.polyfit(global_stats.index, global_stats['mean'], 1)
    trend_line = np.poly1d(z)(global_stats.index)
    ax.plot(global_stats.index, trend_line, '--', linewidth=2, color='red',
            label=f'Linear Trend (slope = {z[0]:.4f} per year)')

ax.set_xlabel('Year', fontsize=14)
ax.set_ylabel(f'{metric_label} (%)', fontsize=14)
ax.set_title(f'Global {metric_label} Trends: {earliest_year}-{latest_year}',
             fontsize=16, pad=20)
ax.legend(fontsize=12, loc='best')
ax.grid(True, alpha=0.3)

# Annotate the peak year of the global average
if not global_stats.empty:
    max_year = global_stats['mean'].idxmax()
    max_value = global_stats['mean'].max()
    ax.annotate(f'Peak: {max_value:.2f}% ({max_year})',
                xy=(max_year, max_value), xytext=(max_year-3, max_value+0.3),
                arrowprops=dict(arrowstyle='->', color='black', alpha=0.7),
                fontsize=11, ha='center')

plt.tight_layout()
plt.show()

## 10. Save Visualizations

In [None]:
print("=== SAVING VISUALIZATIONS ===")

# Create visualizations directory
viz_path = Path("../visualizations")
viz_path.mkdir(parents=True, exist_ok=True)

# Record only the files actually written, so the summary JSON never lists a
# plot that was skipped (the regional plot needs a 'Region' column).
static_files = []
interactive_files = []

metric_label = main_metric.replace("_", " ").title()

# Save publication-ready plots
print("\n1. Saving static plots...")

# Global trend plot: yearly cross-country mean with a ±1 SD band
fig, ax = plt.subplots(figsize=(12, 8))
global_stats = data.groupby('Year')[main_metric].agg(['mean', 'std'])
ax.fill_between(global_stats.index,
                global_stats['mean'] - global_stats['std'],
                global_stats['mean'] + global_stats['std'],
                alpha=0.3, color='lightblue')
ax.plot(global_stats.index, global_stats['mean'], 'o-', linewidth=2, markersize=6)
ax.set_xlabel('Year')
ax.set_ylabel(f'{metric_label} (%)')
ax.set_title(f'Global {metric_label} Trend')
ax.grid(True, alpha=0.3)
fig.tight_layout()
fig.savefig(viz_path / 'global_trend.png', dpi=300, bbox_inches='tight')
fig.savefig(viz_path / 'global_trend.pdf', bbox_inches='tight')
plt.close(fig)
static_files += ['global_trend.png', 'global_trend.pdf']

# Regional comparison
if 'Region' in data.columns:
    fig, ax = plt.subplots(figsize=(10, 6))
    regional_avg = latest_data.groupby('Region')[main_metric].mean().sort_values()
    bars = ax.barh(regional_avg.index, regional_avg.values, color='lightcoral')
    ax.set_xlabel(f'{metric_label} (%)')
    ax.set_title(f'Regional Comparison - {metric_label} ({latest_year})')

    # Add value labels just past the end of each bar
    for bar in bars:
        width = bar.get_width()
        ax.text(width + 0.1, bar.get_y() + bar.get_height()/2,
                f'{width:.2f}%', ha='left', va='center')

    fig.tight_layout()
    fig.savefig(viz_path / 'regional_comparison.png', dpi=300, bbox_inches='tight')
    fig.savefig(viz_path / 'regional_comparison.pdf', bbox_inches='tight')
    plt.close(fig)
    static_files += ['regional_comparison.png', 'regional_comparison.pdf']

print(f"✓ Static plots saved to {viz_path}")

# Save interactive plots as HTML
print("\n2. Saving interactive plots...")

# Recreate and save world map
fig_world = px.choropleth(
    latest_data,
    locations='Entity',
    locationmode='country names',
    color=main_metric,
    hover_name='Entity',
    title=f'Global {metric_label} - {latest_year}'
)
fig_world.write_html(str(viz_path / 'world_map.html'))
interactive_files.append('world_map.html')

# Save global trend interactive
fig_global_interactive = px.line(global_stats.reset_index(), x='Year', y='mean',
                                 title=f'Interactive Global {metric_label} Trend')
fig_global_interactive.write_html(str(viz_path / 'global_trend_interactive.html'))
interactive_files.append('global_trend_interactive.html')

print(f"✓ Interactive plots saved to {viz_path}")

# Create visualization summary
viz_summary = {
    'static_plots': static_files,
    'interactive_plots': interactive_files,
    'analysis_parameters': {
        'primary_metric': main_metric,
        # Year bounds come from a pandas column as numpy.int64, which the json
        # module cannot serialize; convert to built-in ints.
        'year_range': [int(earliest_year), int(latest_year)],
        'countries_analyzed': int(data['Entity'].nunique()),
        'total_records': len(data)
    }
}

# Save summary
import json
with open(viz_path / 'visualization_summary.json', 'w', encoding='utf-8') as f:
    json.dump(viz_summary, f, indent=2)

print("✓ Visualization summary saved")
print("\n🎉 All visualizations completed successfully!")
print(f"📁 Visualizations saved in: {viz_path.absolute()}")
print("\n📊 Summary:")
print(f"  • Static plots: {len(viz_summary['static_plots'])} files")
print(f"  • Interactive plots: {len(viz_summary['interactive_plots'])} files")
print(f"  • Primary metric: {main_metric.replace('_', ' ').title()}")
print(f"  • Data coverage: {data['Entity'].nunique()} countries, {earliest_year}-{latest_year}")

print("\n🚀 Project visualization phase completed!")
print("📈 Next: Run the dashboard with 'python src/dashboard/app.py'")