In [4]:
# --- Import libraries ---
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import statsmodels.api as sm
from statsmodels.tsa.seasonal import STL
from statsmodels.tsa.statespace.sarimax import SARIMAX
from prophet import Prophet
from pandas.plotting import register_matplotlib_converters

register_matplotlib_converters()

# --- Load and clean your data ---
file_path = '/mnt/data/fludataset2023-2024.xlsx'

# Read the file, skipping the first two rows (important)
df = pd.read_excel(file_path, sheet_name='Sheet1', skiprows=2)

# Rename columns properly
df.columns = ['Surveillance_Week', 'Year', 'A_Subtyped', 'A_H3N2', 'A_H1N1', 'Influenza_B', 'Percent_Positive_A', 'Percent_Positive_B']

# Create a 'Date' column (week ending on Saturday)
df['Date'] = pd.to_datetime(df['Year'].astype(str) + '-W' + df['Surveillance_Week'].astype(str) + '-6', format='%G-W%V-%u')

# Calculate total Influenza Cases
df['Cases'] = df['A_Subtyped'] + df['A_H3N2'] + df['A_H1N1'] + df['Influenza_B']

# Create a 'Virus' column to match your filtering logic
df['Virus'] = 'Influenza'

# Set Date as index
df.set_index('Date', inplace=True)

# --- Prepare for modeling ---
df_single = df[df['Virus'] == 'Influenza'].sort_index()
df_weekly = df_single['Cases'].asfreq('W-SAT')
df_weekly = df_weekly.interpolate(method='linear')

# --- Fourier Terms ---
def add_fourier_terms(df, period=52, order=2):
    t = np.arange(len(df))
    for k in range(1, order + 1):
        df[f'sin_{k}'] = np.sin(2 * np.pi * k * t / period)
        df[f'cos_{k}'] = np.cos(2 * np.pi * k * t / period)
    return df

# Add Fourier terms
df_fourier = df_weekly.to_frame(name='Cases')
df_fourier = add_fourier_terms(df_fourier.copy(), period=52, order=2)
X = df_fourier[[col for col in df_fourier.columns if 'sin' in col or 'cos' in col]]
X = sm.add_constant(X)
y = df_fourier['Cases']
model = sm.OLS(y, X).fit()
df_fourier['Fitted'] = model.predict(X)

# Forecast Fourier into future
future_index = pd.date_range(start=df_fourier.index[-1] + pd.Timedelta(weeks=1), periods=12, freq='W-SAT')
df_future = pd.DataFrame(index=future_index)
df_future = add_fourier_terms(df_future, period=52, order=2)
df_future = sm.add_constant(df_future)
df_future['Forecast'] = model.predict(df_future)

# --- STL decomposition (log scale) ---
df_log = np.log(df_weekly + 1e-6)  # Avoid log(0)
stl_log = STL(df_log, period=52)
result_log = stl_log.fit()
stl_fitted = np.exp(result_log.trend + result_log.seasonal)

# --- SARIMA model ---
sarima_model = SARIMAX(df_weekly, order=(1, 1, 1), seasonal_order=(1, 1, 1, 42))
sarima_result = sarima_model.fit(disp=False)
sarima_forecast = sarima_result.get_forecast(steps=12)
sarima_pred = sarima_forecast.predicted_mean

# --- Prophet model ---
df_prophet = df_weekly.reset_index()
df_prophet.columns = ['ds', 'y']  # Prophet needs 'ds' and 'y'

prophet = Prophet(weekly_seasonality=True)
prophet.fit(df_prophet)

df_future_prophet = prophet.make_future_dataframe(periods=12, freq='W-SAT')
forecast_prophet = prophet.predict(df_future_prophet)

# --- Plot all results ---
plt.figure(figsize=(14, 8))

# Actual data points
plt.plot(df_weekly.index, df_weekly, label='Actual Cases', color='black', marker='o')

# Fourier fit and forecast
plt.plot(df_fourier.index, df_fourier['Fitted'], label='Fourier Fit', color='red')
plt.plot(df_future.index, df_future['Forecast'], '--', color='red', label='Fourier Forecast')

# STL Trend + Seasonality
plt.plot(df_weekly.index, stl_fitted, label='STL (log) Trend + Seasonality', color='green')

# SARIMA Forecast
plt.plot(sarima_pred.index, sarima_pred.values, '--', label='SARIMA Forecast', color='blue')

# Prophet Forecast
plt.plot(forecast_prophet['ds'], forecast_prophet['yhat'], '--', label='Prophet Forecast', color='orange')

# --- Custom X-axis: show Surveillance Weeks ---
weeks = df_single['Surveillance_Week'].values

# Set ticks and labels
plt.xticks(ticks=df_weekly.index[::2],  # every 2 weeks to avoid clutter
           labels=weeks[::2], rotation=45, ha='right', fontsize=10)

plt.title('Flu Case Forecasting with Fourier, STL (log), SARIMA, and Prophet', fontsize=16)
plt.xlabel('Surveillance Week', fontsize=14)
plt.ylabel('Cases', fontsize=14)
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()


ValueError: time data "Year-WSurveilaince Week-6" doesn't match format "%G-W%V-%u", at position 0. You might want to try:
    - passing `format` if your strings have a consistent format;
    - passing `format='ISO8601'` if your strings are all ISO8601 but not necessarily in exactly the same format;
    - passing `format='mixed'`, and the format will be inferred for each element individually. You might want to use `dayfirst` alongside this.