In [2]:
# 03_forecasting.ipynb, still in progress

import pandas as pd
import numpy as np
from statsmodels.tsa.holtwinters import ExponentialSmoothing
from statsmodels.tsa.statespace.sarimax import SARIMAX
from sklearn.metrics import mean_absolute_error, mean_squared_error
import plotly.express as px
import plotly.graph_objects as go
from sqlalchemy import create_engine
import warnings
warnings.filterwarnings('ignore')

# Database connection
DATABASE_URI = 'postgresql://woodylin@localhost:5432/healthcare_costs'
engine = create_engine(DATABASE_URI)

# Read cleaned data
df = pd.read_csv('../data/processed/cleaned_medicare_spending.csv')

# 1. Prepare Time Series Data
def prepare_time_series():
    """Prepare yearly spending data as time series"""
    years = ['2018', '2019', '2020', '2021', '2022']
    yearly_spending = {}
    
    for year in years:
        yearly_spending[year] = df[f'Tot_Spndng_{year}'].sum()
    
    # Convert to time series
    ts_data = pd.Series(yearly_spending, name='total_spending')
    ts_data.index = pd.to_datetime(ts_data.index, format='%Y')
    
    return ts_data

# 2. Simple Exponential Smoothing Forecast
def exponential_smoothing_forecast(ts_data, periods=2):
    """Generate forecast using Exponential Smoothing"""
    # Fit model
    model = ExponentialSmoothing(ts_data,
                                trend='add',
                                seasonal=None,
                                seasonal_periods=None)
    fitted_model = model.fit()
    
    # Make forecast
    forecast = fitted_model.forecast(periods)
    
    # Calculate confidence intervals using residuals
    resid = fitted_model.resid
    resid_std = resid.std()
    
    # Create forecast DataFrame with confidence intervals
    forecast_ci = pd.DataFrame(index=forecast.index)
    forecast_ci['forecast'] = forecast
    forecast_ci['lower'] = forecast - 1.96 * resid_std
    forecast_ci['upper'] = forecast + 1.96 * resid_std
    
    return forecast_ci

# 3. SARIMA Forecast
def sarima_forecast(ts_data, periods=2):
    """Generate forecast using SARIMA"""
    # Fit SARIMA model
    model = SARIMAX(ts_data, order=(1,1,1))
    fitted_model = model.fit()
    
    # Make forecast
    forecast = fitted_model.forecast(periods)
    forecast_ci = fitted_model.get_forecast(periods).conf_int()
    
    # Combine forecast and confidence intervals
    forecast_results = pd.DataFrame({
        'forecast': forecast,
        'lower': forecast_ci.iloc[:,0],
        'upper': forecast_ci.iloc[:,1]
    })
    
    return forecast_results

# 4. Visualize Forecasts
def plot_forecasts(ts_data, exp_forecast, sarima_forecast):
    """Create visualization of forecasts"""
    # Create figure
    fig = go.Figure()
    
    # Add historical data
    fig.add_trace(go.Scatter(x=ts_data.index, 
                            y=ts_data.values/1e9,  # Convert to billions
                            mode='lines+markers',
                            name='Historical Data'))
    
    # Add exponential smoothing forecast
    fig.add_trace(go.Scatter(x=exp_forecast.index, 
                            y=exp_forecast['forecast']/1e9,
                            mode='lines',
                            line=dict(dash='dash'),
                            name='Exp Smoothing Forecast'))
    
    # Add SARIMA forecast
    fig.add_trace(go.Scatter(x=sarima_forecast.index, 
                            y=sarima_forecast['forecast']/1e9,
                            mode='lines',
                            line=dict(dash='dot'),
                            name='SARIMA Forecast'))
    
    # Add confidence intervals
    fig.add_trace(go.Scatter(x=exp_forecast.index.tolist() + exp_forecast.index.tolist()[::-1],
                            y=(exp_forecast['upper']/1e9).tolist() + (exp_forecast['lower']/1e9).tolist()[::-1],
                            fill='toself',
                            fillcolor='rgba(0,100,80,0.2)',
                            line=dict(color='rgba(255,255,255,0)'),
                            name='Exp Smoothing CI'))
    
    fig.update_layout(title='Medicare Part D Spending Forecast',
                     xaxis_title='Year',
                     yaxis_title='Total Spending (Billions $)',
                     showlegend=True)
    
    return fig

# 5. Drug-specific Forecasting
def forecast_top_drugs(n_drugs=5):
    """Generate forecasts for top drugs by spending"""
    # Get top drugs by 2022 spending
    top_drugs = df.nlargest(n_drugs, 'Tot_Spndng_2022')
    
    forecasts = {}
    for _, drug in top_drugs.iterrows():
        drug_spending = pd.Series({
            '2018': drug['Tot_Spndng_2018'],
            '2019': drug['Tot_Spndng_2019'],
            '2020': drug['Tot_Spndng_2020'],
            '2021': drug['Tot_Spndng_2021'],
            '2022': drug['Tot_Spndng_2022']
        })
        drug_spending.index = pd.to_datetime(drug_spending.index, format='%Y')
        
        # Forecast using exponential smoothing
        model = ExponentialSmoothing(drug_spending, trend='add')
        fitted = model.fit()
        forecast = fitted.forecast(2)
        
        forecasts[drug['Brnd_Name']] = forecast
    
    return pd.DataFrame(forecasts)

# 6. Calculate Forecast Accuracy Metrics
def calculate_accuracy_metrics(actual, predicted):
    """Calculate MAE and RMSE for model evaluation"""
    mae = mean_absolute_error(actual, predicted)
    rmse = np.sqrt(mean_squared_error(actual, predicted))
    return {'MAE': mae, 'RMSE': rmse}

# Run forecasting analysis
if __name__ == "__main__":
    print("Starting Medicare Part D Spending Forecast Analysis...\n")
    
    try:
        # Prepare time series data
        ts_data = prepare_time_series()
        print("Historical Spending Data (Billions $):")
        print(ts_data/1e9)
        
        # Generate forecasts
        exp_forecast = exponential_smoothing_forecast(ts_data)
        print("\nExponential Smoothing Forecast for 2023-2024 (Billions $):")
        print(exp_forecast/1e9)
        
        sarima_result = sarima_forecast(ts_data)
        print("\nSARIMA Forecast for 2023-2024 (Billions $):")
        print(sarima_result/1e9)
        
        # Generate and display visualization
        fig = plot_forecasts(ts_data, exp_forecast, sarima_result)
        fig.show()
        
        # Generate drug-specific forecasts
        top_drug_forecasts = forecast_top_drugs()
        print("\nTop Drug-Specific Forecasts for 2023-2024 (Millions $):")
        print(top_drug_forecasts/1e6)
        
        # Save forecasts
        exp_forecast.to_csv('../data/processed/total_spending_forecast.csv')
        top_drug_forecasts.to_csv('../data/processed/top_drugs_forecast.csv')
        
        print("\nForecasting analysis completed successfully!")
        
    except Exception as e:
        print(f"Error during forecasting: {str(e)}")

Starting Medicare Part D Spending Forecast Analysis...

Historical Spending Data (Billions $):
2018-01-01    327.033135
2019-01-01    362.151394
2020-01-01    395.873623
2021-01-01    431.217677
2022-01-01    480.863549
Name: total_spending, dtype: float64

Exponential Smoothing Forecast for 2023-2024 (Billions $):
              forecast       lower       upper
2023-01-01  515.906964  479.895108  551.918819
2024-01-01  551.023041  515.011185  587.034896
RUNNING THE L-BFGS-B CODE

           * * *

Machine precision = 2.220D-16
 N =            3     M =           10

At X0         0 variables are exactly at the bounds

At iterate    0    f=  3.32581D+01    |proj g|=  7.41658D+01

At iterate    5    f=  2.80929D+01    |proj g|=  1.85350D+01

At iterate   10    f=  2.12792D+01    |proj g|=  7.21469D+00

At iterate   15    f=  1.96207D+01    |proj g|=  1.56824D+00

At iterate   20    f=  1.93383D+01    |proj g|=  1.05005D+00

At iterate   25    f=  1.92860D+01    |proj g|=  2.94214D-01

At

 This problem is unconstrained.



Top Drug-Specific Forecasts for 2023-2024 (Millions $):
                 Eliquis    Trulicity     Revlimid
2023-01-01  17531.180170  7134.641213  6560.156855
2024-01-01  19844.214565  8047.057206  6978.689293

Forecasting analysis completed successfully!
