# Mental Health Time Series Analysis

This notebook performs comprehensive time series analysis including:
- Trend analysis and decomposition
- Seasonality detection
- Forecasting future trends
- Cross-country comparative analysis
- Statistical significance testing

## 1. Setup and Imports

In [None]:
import os
import sys
import warnings
from pathlib import Path

# Make the project's ``src`` directory importable (path is relative to the
# notebook's working directory).
sys.path.append('../src')

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Suppress every warning for the remainder of the session (set before the
# modelling libraries below are imported).
warnings.filterwarnings('ignore')

# Time series models and evaluation / statistics helpers
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tsa.holtwinters import ExponentialSmoothing
from statsmodels.tsa.seasonal import seasonal_decompose
from sklearn.metrics import mean_absolute_error, mean_squared_error
from scipy import stats

# Project-local analysis and plotting modules (found via ../src)
from analysis.statistical_tests import StatisticalTests
from analysis.time_series import TimeSeriesAnalyzer
from visualization.static_plots import StaticPlots

# Plot defaults: apply the base style first, since style.use resets rcParams.
plt.style.use('default')
sns.set_palette("husl")
plt.rcParams.update({'figure.figsize': (14, 8)})

print("✓ All libraries imported successfully")

## 2. Load Cleaned Data

In [None]:
# Load processed data
data_path = Path("../data/processed/mental_health_cleaned.csv")


def make_sample_data(countries, years, seed=42):
    """Return a synthetic country-year panel used when cleaned data is absent.

    Each country draws a base rate from U(4, 8) and shares a linear trend of
    0.03 per year since 1990. Depression and anxiety each get their OWN
    N(0, 0.5) noise draw: reusing a single draw for both would make anxiety an
    exact per-country shift of depression, so the metric-correlation tests
    later in the notebook would be trivially near-perfect.

    Args:
        countries: Iterable of country names (``Entity`` values).
        years: Iterable of integer years (``Year`` values).
        seed: Seed for a private ``RandomState`` so output is reproducible
            without mutating NumPy's global RNG.

    Returns:
        DataFrame with columns Entity, Year, Depression_prevalence and
        Anxiety_prevalence; prevalence values are clipped at 0.
    """
    rng = np.random.RandomState(seed)
    rows = []
    for country in countries:
        base_rate = rng.uniform(4, 8)
        for year in years:
            trend = (year - 1990) * 0.03
            rows.append({
                'Entity': country,
                'Year': year,
                'Depression_prevalence': max(0.0, base_rate + trend + rng.normal(0, 0.5)),
                'Anxiety_prevalence': max(0.0, base_rate * 0.8 + trend + rng.normal(0, 0.5)),
            })
    return pd.DataFrame(rows)


if data_path.exists():
    data = pd.read_csv(data_path)
    print(f"✓ Cleaned data loaded: {data.shape}")
else:
    print("❌ Cleaned data not found. Please run data cleaning notebook first.")
    # Fall back to synthetic data so the rest of the notebook can still run
    data = make_sample_data(
        ['United States', 'United Kingdom', 'Germany', 'France', 'Japan'],
        range(1990, 2024),
    )
    print(f"⚠️ Using sample data: {data.shape}")

# Display basic info
print("\nDataset overview:")
print(f"  Countries: {data['Entity'].nunique()}")
print(f"  Years: {data['Year'].min()} - {data['Year'].max()}")
print(f"  Records: {len(data):,}")

# Every column whose name mentions "prevalence" is treated as a metric
prevalence_cols = [col for col in data.columns if 'prevalence' in col.lower()]
print(f"  Mental health metrics: {len(prevalence_cols)}")
for col in prevalence_cols:
    print(f"    • {col}")

## 3. Initialize Analysis Tools

In [None]:
# Build the project's analysis helpers once; later cells reuse these objects.
ts_analyzer, stat_tests, plotter = TimeSeriesAnalyzer(), StatisticalTests(), StaticPlots()

print("✓ Analysis tools initialized")

# Column names shared by every analysis cell below
country_col, year_col = 'Entity', 'Year'

# The first detected prevalence column drives the analysis; None disables it.
main_metric = None
if prevalence_cols:
    main_metric = prevalence_cols[0]

print(f"Primary analysis metric: {main_metric}")

## 4. Global Trend Analysis

In [None]:
if main_metric:
    print("=== GLOBAL TREND ANALYSIS ===")

    # Cross-country aggregate per year: mean, spread and number of reporting countries
    global_trends = data.groupby(year_col)[main_metric].agg(['mean', 'std', 'count']).reset_index()
    global_trends.columns = [year_col, 'Mean_Prevalence', 'Std_Prevalence', 'Country_Count']
    # A year whose metric values are all missing aggregates to a NaN mean; drop it so
    # the least-squares fits below (analyze_trend, np.polyfit) receive finite inputs.
    global_trends = global_trends.dropna(subset=['Mean_Prevalence']).reset_index(drop=True)

    print(f"Global trend data shape: {global_trends.shape}")
    print(f"Year range: {global_trends[year_col].min()} - {global_trends[year_col].max()}")

    # Trend test on the yearly global mean (project helper). Section 10 reports this result.
    trend_analysis = ts_analyzer.analyze_trend(global_trends[year_col], global_trends['Mean_Prevalence'])

    print("\nGlobal trend analysis:")
    print(f"  Slope: {trend_analysis['slope']:.4f} per year")
    print(f"  P-value: {trend_analysis['p_value']:.6f}")
    print(f"  R-squared: {trend_analysis['r_squared']:.4f}")
    print(f"  Trend direction: {trend_analysis['trend_direction']}")
    print(f"  Significance: {'Significant' if trend_analysis['is_significant'] else 'Not significant'}")

    years = global_trends[year_col]
    means = global_trends['Mean_Prevalence']
    stds = global_trends['Std_Prevalence']
    metric_label = main_metric.replace("_", " ").title()

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    ax_trend, ax_yoy, ax_hist, ax_count = axes.flatten()

    # Panel 1: global mean with a +/- 1 SD band across countries (the sample SD is
    # NaN, so no band is drawn, for years with a single reporting country) and a
    # least-squares linear trend line.
    ax_trend.plot(years, means, 'o-', linewidth=2, markersize=6)
    ax_trend.fill_between(years, means - stds, means + stds, alpha=0.3)
    slope, intercept = np.polyfit(years, means, 1)
    ax_trend.plot(years, slope * years + intercept, '--r', linewidth=2,
                  label=f'Trend (slope={slope:.4f})')
    ax_trend.set_xlabel('Year')
    ax_trend.set_ylabel(f'Average {metric_label} (%)')
    ax_trend.set_title('Global Mental Health Trend')
    ax_trend.legend()
    ax_trend.grid(True, alpha=0.3)

    # Panel 2: relative change versus the previous available year; the first
    # year has no predecessor, so it is skipped positionally.
    yoy_change = means.pct_change() * 100
    ax_yoy.bar(years.iloc[1:], yoy_change.iloc[1:], alpha=0.7)
    ax_yoy.axhline(y=0, color='red', linestyle='--', alpha=0.7)
    ax_yoy.set_xlabel('Year')
    ax_yoy.set_ylabel('Year-over-Year Change (%)')
    ax_yoy.set_title('Annual Change in Global Prevalence')
    ax_yoy.grid(True, alpha=0.3)

    # Panel 3: distribution of the yearly global means
    ax_hist.hist(means, bins=15, alpha=0.7, edgecolor='black')
    ax_hist.set_xlabel(f'{metric_label} (%)')
    ax_hist.set_ylabel('Frequency')
    ax_hist.set_title('Distribution of Annual Global Averages')
    ax_hist.grid(True, alpha=0.3)

    # Panel 4: number of countries with a non-missing value each year
    ax_count.plot(years, global_trends['Country_Count'], 'g-o', linewidth=2)
    ax_count.set_xlabel('Year')
    ax_count.set_ylabel('Number of Countries')
    ax_count.set_title('Data Availability Over Time')
    ax_count.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

## 5. Country-Specific Trend Analysis

In [None]:
if main_metric:
    print("=== COUNTRY-SPECIFIC TRENDS ===")

    # Rank countries by number of rows and analyse the 10 best covered
    country_data_counts = data.groupby(country_col).size().sort_values(ascending=False)
    top_countries = country_data_counts.head(10).index.tolist()

    print(f"Analyzing trends for top {len(top_countries)} countries with most data")

    min_points = 10  # minimum non-missing observations required for a per-country fit
    country_trends = {}
    trend_summary = []

    for country in top_countries:
        # Drop missing metric values so the point count reflects usable observations
        country_data = (
            data[data[country_col] == country]
            .dropna(subset=[main_metric])
            .sort_values(year_col)
        )
        if len(country_data) < min_points:
            continue

        # Deliberately NOT named `trend_analysis`: that name holds the global trend
        # computed in section 4, which section 10 saves and reports.
        country_trend = ts_analyzer.analyze_trend(country_data[year_col], country_data[main_metric])

        country_trends[country] = {
            'data': country_data,
            'trend': country_trend,
        }
        trend_summary.append({
            'Country': country,
            'Slope': country_trend['slope'],
            'P_Value': country_trend['p_value'],
            'R_Squared': country_trend['r_squared'],
            'Trend_Direction': country_trend['trend_direction'],
            'Is_Significant': country_trend['is_significant'],
            'Data_Points': len(country_data),
        })

    # Explicit columns keep the frame well-formed even when no country qualified
    summary_columns = ['Country', 'Slope', 'P_Value', 'R_Squared',
                       'Trend_Direction', 'Is_Significant', 'Data_Points']
    trend_df = pd.DataFrame(trend_summary, columns=summary_columns)
    trend_df = trend_df.sort_values('Slope', ascending=False)

    significant = trend_df['Is_Significant'].astype(bool)
    n_increasing = int((significant & (trend_df['Slope'] > 0)).sum())
    n_decreasing = int((significant & (trend_df['Slope'] < 0)).sum())

    print(f"\nTrend analysis completed for {len(trend_df)} countries")
    print(f"Significant increasing trends: {n_increasing}")
    print(f"Significant decreasing trends: {n_decreasing}")

    # Display trend summary
    print("\nTop 10 Countries by Trend Slope:")
    display(trend_df.head(10)[['Country', 'Slope', 'Trend_Direction', 'Is_Significant', 'R_Squared']])

## 6. Visualize Country Trends

In [None]:
def _significance_stars(p):
    """Return conventional significance stars ('***', '**', '*', '') for p-value *p*."""
    if p < 0.001:
        return '***'
    if p < 0.01:
        return '**'
    if p < 0.05:
        return '*'
    return ''


# Check main_metric first: when it is falsy, section 5 never ran and
# `country_trends` does not exist, so it must not be evaluated.
if main_metric and country_trends:
    print("=== COUNTRY TREND VISUALIZATION ===")

    metric_label = main_metric.replace("_", " ").title()
    countries_to_plot = list(country_trends)[:6]
    colors = plt.cm.Set1(np.linspace(0, 1, len(countries_to_plot)))

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    axes = axes.flatten()

    for ax, country, color in zip(axes, countries_to_plot, colors):
        # NaN values would make np.polyfit fail, so fit only observed points
        country_data = country_trends[country]['data'].dropna(subset=[main_metric])
        trend_info = country_trends[country]['trend']
        x = country_data[year_col]
        y = country_data[main_metric]

        ax.scatter(x, y, alpha=0.7, s=50, color=color)
        slope, intercept = np.polyfit(x, y, 1)
        ax.plot(x, slope * x + intercept, '--', linewidth=2, color='red')

        stars = _significance_stars(trend_info['p_value'])
        ax.set_title(
            f"{country}\nSlope: {trend_info['slope']:.4f}{stars}\nR²: {trend_info['r_squared']:.3f}",
            fontsize=10,
        )
        ax.set_xlabel('Year')
        ax.set_ylabel(f'{metric_label} (%)')
        ax.grid(True, alpha=0.3)

    # Hide panels left empty when fewer than six countries qualified
    for ax in axes[len(countries_to_plot):]:
        ax.set_visible(False)

    plt.suptitle(f'Mental Health Trends by Country ({metric_label})', fontsize=16)
    plt.tight_layout()
    plt.show()

    # Comparative trend plot, reusing the per-country colours from the panels above
    plt.figure(figsize=(15, 8))
    for country, color in zip(countries_to_plot, colors):
        country_data = country_trends[country]['data']
        plt.plot(country_data[year_col], country_data[main_metric],
                 'o-', label=country, linewidth=2, markersize=4, color=color)

    plt.xlabel('Year')
    plt.ylabel(f'{metric_label} (%)')
    plt.title('Comparative Mental Health Trends')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

## 7. Time Series Decomposition

In [None]:
if main_metric and len(global_trends) >= 8:
    print("=== TIME SERIES DECOMPOSITION ===")

    # NOTE(review): the series is annual, so the "seasonal" component with
    # period 4 describes a repeating 4-year pattern, not calendar seasonality.
    try:
        # Yearly global mean indexed by year; seasonal_decompose rejects NaNs.
        ts_data = global_trends.set_index(year_col)['Mean_Prevalence'].dropna()

        decomposition = seasonal_decompose(ts_data, model='additive', period=min(4, len(ts_data)//2))

        # One panel per component: (series, title, line colour, y-axis label)
        panels = [
            (decomposition.observed, 'Original Time Series', None, 'Prevalence (%)'),
            (decomposition.trend, 'Trend Component', 'orange', 'Trend'),
            (decomposition.seasonal, 'Seasonal Component', 'green', 'Seasonal'),
            (decomposition.resid, 'Residual Component', 'red', 'Residual'),
        ]
        fig, axes = plt.subplots(4, 1, figsize=(15, 12))
        for ax, (series, title, color, ylabel) in zip(axes, panels):
            style = {} if color is None else {'color': color}
            series.plot(ax=ax, title=title, **style)
            ax.set_ylabel(ylabel)
            ax.grid(True, alpha=0.3)
        axes[-1].set_xlabel('Year')

        plt.suptitle('Time Series Decomposition - Global Mental Health Trend', fontsize=16)
        plt.tight_layout()
        plt.show()

        # Strength of trend / seasonality (Hyndman & Athanasopoulos, FPP3):
        #   F_T = max(0, 1 - Var(R) / Var(T + R))
        #   F_S = max(0, 1 - Var(R) / Var(S + R))
        # All variances are taken over the same positions. The trend and
        # residual estimates can be NaN at the series ends, so those rows are dropped.
        components = pd.DataFrame({
            'trend': decomposition.trend,
            'seasonal': decomposition.seasonal,
            'resid': decomposition.resid,
        }).dropna()
        resid_var = components['resid'].var()
        trend_strength = max(0.0, 1 - resid_var / (components['trend'] + components['resid']).var())
        seasonal_strength = max(0.0, 1 - resid_var / (components['seasonal'] + components['resid']).var())

        print(f"\nDecomposition Analysis:")
        print(f"  Trend strength: {trend_strength:.3f}")
        print(f"  Seasonal strength: {seasonal_strength:.3f}")
        print(f"  Residual variance: {resid_var:.4f}")

    except Exception as e:
        # Best-effort cell: report why decomposition failed and let the notebook continue
        print(f"Could not perform decomposition: {e}")
        print("This might be due to insufficient data points or irregular time series")
else:
    print("Insufficient data for time series decomposition (need at least 8 data points)")

## 8. Forecasting

In [None]:
if main_metric and len(global_trends) >= 5:
    print("=== MENTAL HEALTH TREND FORECASTING ===")

    # Yearly global mean indexed by year; NaN years carry no information to fit on.
    ts_data = global_trends.set_index(year_col)['Mean_Prevalence'].dropna().sort_index()
    forecast_horizon = 5  # 5 years into the future
    last_year = ts_data.index.max()
    last_observed = ts_data.iloc[-1]

    print(f"Training data: {len(ts_data)} years ({ts_data.index.min()} - {last_year})")
    print(f"Forecast horizon: {forecast_horizon} years")

    # Method 1: Linear trend extrapolation (least-squares line through history)
    print("\n1. Linear Trend Extrapolation")
    future_years = list(range(last_year + 1, last_year + forecast_horizon + 1))
    z = np.polyfit(ts_data.index, ts_data.values, 1)
    linear_forecast = np.poly1d(z)(future_years)

    # Method 2: Exponential smoothing with additive trend (best effort)
    exp_forecast = None
    try:
        print("\n2. Exponential Smoothing")
        exp_fit = ExponentialSmoothing(ts_data, trend='add', seasonal=None).fit()
        exp_forecast = exp_fit.forecast(forecast_horizon)
        print(f"  Exponential smoothing successful")
    except Exception as e:
        print(f"  Exponential smoothing failed: {e}")

    # Method 3: ARIMA(1, 1, 1) (best effort)
    arima_forecast = None
    try:
        print("\n3. ARIMA Model")
        arima_fit = ARIMA(ts_data, order=(1, 1, 1)).fit()
        arima_forecast = arima_fit.forecast(forecast_horizon)
        print(f"  ARIMA model successful")
    except Exception as e:
        print(f"  ARIMA model failed: {e}")

    # Visualize forecasts. Model forecasts are plotted positionally against
    # future_years (both have forecast_horizon entries).
    plt.figure(figsize=(15, 10))
    plt.plot(ts_data.index, ts_data.values, 'o-', linewidth=2, markersize=6, label='Historical Data')
    plt.plot(future_years, linear_forecast, 's--', linewidth=2, label='Linear Trend Forecast', alpha=0.8)
    if exp_forecast is not None:
        plt.plot(future_years, exp_forecast, '^--', linewidth=2, label='Exponential Smoothing', alpha=0.8)
    if arima_forecast is not None:
        plt.plot(future_years, arima_forecast, 'd--', linewidth=2, label='ARIMA Forecast', alpha=0.8)

    # Vertical line separating history from forecast
    plt.axvline(x=last_year, color='red', linestyle=':', alpha=0.7, label='Forecast Start')

    plt.xlabel('Year')
    plt.ylabel(f'{main_metric.replace("_", " ").title()} (%)')
    plt.title('Mental Health Trend Forecasting')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    # Tabulate forecast values
    print(f"\n=== FORECAST RESULTS ===")
    forecast_df = pd.DataFrame({
        'Year': future_years,
        'Linear_Trend': linear_forecast,
    })
    if exp_forecast is not None:
        forecast_df['Exponential_Smoothing'] = exp_forecast.values
    if arima_forecast is not None:
        forecast_df['ARIMA'] = arima_forecast.values

    display(forecast_df)

    # Change over the full horizon is measured from the last observed value.
    # forecast[-1] - forecast[0] would only span forecast_horizon - 1 years.
    # Values are prevalence percentages, so differences are percentage points.
    print(f"\nForecast Summary (relative to observed value in {last_year}):")
    print(f"  Linear trend change over {forecast_horizon} years: {linear_forecast[-1] - last_observed:+.3f} pp")
    if exp_forecast is not None:
        print(f"  Exponential smoothing change: {exp_forecast.iloc[-1] - last_observed:+.3f} pp")
    if arima_forecast is not None:
        print(f"  ARIMA change: {arima_forecast.iloc[-1] - last_observed:+.3f} pp")

else:
    print("Insufficient data for forecasting (need at least 5 data points)")

## 9. Statistical Significance Testing

In [None]:
if main_metric and len(trend_df) >= 2:
    print("=== STATISTICAL SIGNIFICANCE TESTING ===")

    # Test 1: one-sample t-test of the per-country slopes against zero
    print("\n1. Testing if country trends are significantly different from zero:")
    slopes = trend_df['Slope'].to_numpy(dtype=float)
    t_stat, p_value = stats.ttest_1samp(slopes, 0)

    print(f"  Mean slope: {slopes.mean():.6f}")
    # ddof=1: sample standard deviation, the same estimator the t-test uses
    print(f"  Standard deviation: {slopes.std(ddof=1):.6f}")
    print(f"  T-statistic: {t_stat:.4f}")
    print(f"  P-value: {p_value:.6f}")
    print(f"  Conclusion: {'Trends are significantly different from zero' if p_value < 0.05 else 'No significant trend'}")

    # Test 2: Compare trends between regions (if region data available)
    if 'Region' in data.columns:
        print("\n2. Regional comparison of trends:")

        regional_trends = []
        # Per-country slopes grouped by region: these are the samples the ANOVA compares
        region_country_slopes = {}
        metric_data = data.dropna(subset=[main_metric])

        # groupby drops rows with a missing Region; sort=False keeps first-appearance order
        for region, region_data in metric_data.groupby('Region', sort=False):
            if len(region_data) < 10:
                continue
            region_trend = region_data.groupby(year_col)[main_metric].mean()
            if len(region_trend) < 5:
                continue

            # Slope of the region's yearly mean
            slope, _, _, p_val, _ = stats.linregress(region_trend.index, region_trend.values)
            regional_trends.append({
                'Region': region,
                'Slope': slope,
                'P_Value': p_val,
                'Countries': region_data[country_col].nunique(),
                'Data_Points': len(region_data),
            })

            # Individual country slopes (need >= 3 distinct years for a meaningful fit)
            country_slopes = [
                stats.linregress(c_data[year_col], c_data[main_metric]).slope
                for _, c_data in region_data.groupby(country_col)
                if c_data[year_col].nunique() >= 3
            ]
            if country_slopes:
                region_country_slopes[region] = country_slopes

        if regional_trends:
            regional_df = pd.DataFrame(regional_trends)
            print("  Regional trend analysis:")
            display(regional_df)

            # One-way ANOVA needs several observations per group. Compare the
            # distributions of country-level slopes; a single regional slope per
            # group leaves zero within-group degrees of freedom.
            anova_groups = {r: s for r, s in region_country_slopes.items() if len(s) >= 2}
            if len(anova_groups) >= 2:
                f_stat, f_p_value = stats.f_oneway(*anova_groups.values())
                print(f"\n  ANOVA F-test for regional differences:")
                print(f"    F-statistic: {f_stat:.4f}")
                print(f"    P-value: {f_p_value:.6f}")
                print(f"    Conclusion: {'Significant regional differences' if f_p_value < 0.05 else 'No significant regional differences'}")
            else:
                print("\n  ANOVA skipped: need at least two regions with two or more countries each")

    # Test 3: Correlation between prevalence metrics
    if len(prevalence_cols) >= 2:
        print("\n3. Correlation between mental health metrics:")

        correlation_matrix = data[prevalence_cols].corr()
        print("  Correlation matrix:")
        display(correlation_matrix)

        # Lower-triangle heatmap (upper triangle including diagonal is masked)
        plt.figure(figsize=(10, 8))
        mask = np.triu(np.ones_like(correlation_matrix, dtype=bool))
        sns.heatmap(correlation_matrix, mask=mask, annot=True, cmap='coolwarm', center=0,
                    square=True, fmt='.3f', cbar_kws={'label': 'Correlation Coefficient'})
        plt.title('Correlation Between Mental Health Metrics')
        plt.tight_layout()
        plt.show()

        # Pearson significance for each unordered pair, on rows where both are present
        print("\n  Correlation significance tests:")
        for i, col1 in enumerate(prevalence_cols):
            for col2 in prevalence_cols[i+1:]:
                clean_data = data[[col1, col2]].dropna()
                if len(clean_data) >= 10:
                    corr_coef, corr_p_value = stats.pearsonr(clean_data[col1], clean_data[col2])
                    significance = '***' if corr_p_value < 0.001 else '**' if corr_p_value < 0.01 else '*' if corr_p_value < 0.05 else ''
                    print(f"    {col1} vs {col2}: r = {corr_coef:.3f}{significance} (p = {corr_p_value:.6f})")
else:
    print("Insufficient data for statistical testing")

## 10. Save Analysis Results

In [None]:
# Save time series analysis results
import json

results_path = Path("../data/processed")
results_path.mkdir(parents=True, exist_ok=True)


def to_json_builtin(obj):
    """``json.dump`` fallback converting NumPy/pandas values to built-in types.

    Aggregates such as ``Series.min()`` return NumPy scalars (e.g. ``np.int64``)
    and comparisons yield ``np.bool_``; the ``json`` module cannot serialize
    these on its own.

    Raises:
        TypeError: for any other unsupported type, matching ``json``'s own error.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, (np.ndarray, pd.Series, pd.Index)):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Record every forecasting method that actually produced a forecast in section 8.
# At notebook top level locals() is the global namespace, so these checks see
# variables defined by earlier cells.
forecast_methods = []
if 'linear_forecast' in locals():
    forecast_methods.append('Linear Trend')
if 'exp_forecast' in locals() and exp_forecast is not None:
    forecast_methods.append('Exponential Smoothing')
if 'arima_forecast' in locals() and arima_forecast is not None:
    forecast_methods.append('ARIMA')

# Compile analysis results; each section falls back to an empty value if its cell did not run
analysis_results = {
    'global_trend': {
        'data_shape': global_trends.shape if 'global_trends' in locals() else (0, 0),
        'trend_analysis': trend_analysis if 'trend_analysis' in locals() else {},
        'year_range': [global_trends[year_col].min(), global_trends[year_col].max()] if 'global_trends' in locals() else []
    },
    'country_trends': {
        'countries_analyzed': len(trend_df) if 'trend_df' in locals() else 0,
        'significant_increasing': len(trend_df[(trend_df['Is_Significant']) & (trend_df['Slope'] > 0)]) if 'trend_df' in locals() else 0,
        'significant_decreasing': len(trend_df[(trend_df['Is_Significant']) & (trend_df['Slope'] < 0)]) if 'trend_df' in locals() else 0
    },
    'forecasting': {
        'methods_used': forecast_methods,
        'forecast_horizon': forecast_horizon if 'forecast_horizon' in locals() else 0,
        'linear_forecast': linear_forecast.tolist() if 'linear_forecast' in locals() else []
    },
    'statistical_tests': {
        'trend_significance_test': {
            'mean_slope': slopes.mean() if 'slopes' in locals() else 0,
            'p_value': p_value if 'p_value' in locals() else 1,
            'significant': p_value < 0.05 if 'p_value' in locals() else False
        }
    }
}

# Save results to JSON (NumPy scalars/arrays converted via to_json_builtin)
results_file = results_path / "time_series_analysis_results.json"
with open(results_file, 'w', encoding='utf-8') as f:
    json.dump(analysis_results, f, indent=2, default=to_json_builtin)

print(f"✅ Time series analysis results saved to: {results_file}")

# Save country trends if available
if 'trend_df' in locals():
    trend_file = results_path / "country_trends.csv"
    trend_df.to_csv(trend_file, index=False)
    print(f"✅ Country trends saved to: {trend_file}")

# Save forecast data if available
if 'forecast_df' in locals():
    forecast_file = results_path / "forecast_results.csv"
    forecast_df.to_csv(forecast_file, index=False)
    print(f"✅ Forecast results saved to: {forecast_file}")

print(f"\n🎉 Time series analysis completed successfully!")
print(f"📊 Key findings:")
if 'trend_analysis' in locals():
    print(f"  • Global trend: {trend_analysis['trend_direction']} ({'significant' if trend_analysis['is_significant'] else 'not significant'})")
if 'trend_df' in locals():
    print(f"  • Countries analyzed: {len(trend_df)}")
    print(f"  • Countries with significant trends: {len(trend_df[trend_df['Is_Significant']])}")
if 'forecast_horizon' in locals():
    print(f"  • Forecast horizon: {forecast_horizon} years")

print(f"\n📁 Proceed to notebook 04_visualization.ipynb for advanced visualizations")