# In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from statsmodels.tsa.seasonal import seasonal_decompose
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
from statsmodels.tsa.statespace.sarimax import SARIMAX
from prophet import Prophet
import warnings

# Silence every warning for cleaner output.
# NOTE(review): this blanket filter also hides any convergence or deprecation
# warnings the modelling libraries may emit -- consider narrowing it.
warnings.filterwarnings("ignore")

# --- 1. Load and Explore Data ---
# Keep the try body down to the single call that can raise FileNotFoundError.
try:
    df = pd.read_csv('train.csv')
except FileNotFoundError:
    # Raising SystemExit with a message writes it to stderr and terminates with
    # a non-zero status. The bare exit() used here before reported success
    # (status 0) and depends on the `site` module having been loaded.
    raise SystemExit(
        "Error: train.csv not found. Please ensure the file is in the correct directory."
    ) from None

print("DataFrame loaded successfully.")
print(df.head())
print("\nDataFrame Info:")
df.info()

# The datetime index is assembled from the 'Year' and 'Month' columns, so stop
# with a readable message (instead of a KeyError) when either one is absent.
missing_date_cols = [col for col in ('Year', 'Month') if col not in df.columns]
if missing_date_cols:
    raise SystemExit(
        f"Error: train.csv is missing required column(s): {missing_date_cols}"
    )

# Build 'YYYY-M' strings and parse them; each row is stamped on the first day
# of its month.
df['Date'] = pd.to_datetime(df['Year'].astype(str) + '-' + df['Month'].astype(str), format='%Y-%m')
df.set_index('Date', inplace=True)

# Make sure the target column is called 'Sunspot_Count'.
# When the CSV uses another name, fall back to the last numeric column that is
# not one of the date parts ('Year', 'Month') and rename it.
# include='number' matches every numeric dtype, not only float64/int64.
numeric_cols = df.select_dtypes(include='number').columns.tolist()
if 'Sunspot_Count' not in df.columns:
    # Failures raise SystemExit with a message so the process ends with a
    # non-zero status; a bare exit() would have reported success.
    if not numeric_cols:
        raise SystemExit(
            "\nError: 'Sunspot_Count' column not found and could not be inferred. Please check your CSV column names."
        )

    sunspot_col_candidates = [col for col in numeric_cols if col not in ('Year', 'Month')]
    if not sunspot_col_candidates:
        raise SystemExit(
            "\nCould not identify a suitable 'Sunspot_Count' column. Please check your CSV column names."
        )

    inferred_col = sunspot_col_candidates[-1]
    df = df.rename(columns={inferred_col: 'Sunspot_Count'})
    print(f"\nRenamed column '{inferred_col}' to 'Sunspot_Count'.")

# Select only the target variable and place it on a regular month-start grid.
# asfreq('MS') inserts NaN for any month that is absent from the source data.
ts = df['Sunspot_Count'].asfreq('MS')

# Close gaps right away: the series is passed to seasonal_decompose and the
# ACF/PACF plots further down, and seasonal_decompose refuses input that
# contains missing values. limit_direction='both' also covers NaNs at the very
# start or end of the series.
n_missing = int(ts.isna().sum())
if n_missing:
    print(f"\nFilling {n_missing} missing monthly value(s) by time interpolation.")
    ts = ts.interpolate(method='time', limit_direction='both')

print(f"\nTime series shape: {ts.shape}")
print(f"Time series start: {ts.index.min()}")
print(f"Time series end: {ts.index.max()}")

# Visualize the full series.
plt.figure(figsize=(15, 7))
plt.plot(ts)
plt.title('Monthly Average Sunspot Count (1749-2010)')
plt.xlabel('Date')
plt.ylabel('Sunspot Count')
plt.grid(True)
plt.show()

# Split the series into trend, seasonal and residual components.
# Sunspot activity is dominated by a cycle of roughly 11 years, which is far
# longer than the fixed short period seasonal_decompose is usually given.
# For simplicity a yearly period (12 monthly observations) is used here, with
# an additive model because the counts are absolute values.
decomposition = seasonal_decompose(
    ts,
    model='additive',
    period=12,  # 12 months -> yearly seasonality
)
fig = decomposition.plot()
fig.set_size_inches((12, 8))
fig.tight_layout()
plt.show()

# Draw the ACF (top) and PACF (bottom) for the first 50 lags; these guide the
# choice of AR and MA orders for the SARIMA model.
_, (acf_axis, pacf_axis) = plt.subplots(nrows=2, ncols=1, figsize=(14, 8))
plot_acf(ts, lags=50, ax=acf_axis, title='Autocorrelation Function (ACF)')
plot_pacf(ts, lags=50, ax=pacf_axis, title='Partial Autocorrelation Function (PACF)')
plt.tight_layout()
plt.show()

# --- 2. Data Preprocessing ---
# The datetime index and the month-start frequency were set when the series was
# built; asfreq('MS') marks any month missing from the source with NaN.
# Report how many values are missing.
missing_total = ts.isnull().sum()
print(f"\nNumber of NaNs in time series: {missing_total}")
# Gaps, when present, are typically closed by interpolation, e.g.
# ts = ts.interpolate(method='time')

# --- 3. Model Selection and Training ---

# Training window: every observation dated 2010 or earlier.
is_training_month = ts.index.year <= 2010
train_data = ts[is_training_month]
print(f"\nTraining data ends: {train_data.index.max()}")

# --- Model 1: SARIMA (Seasonal AutoRegressive Integrated Moving Average) ---
# order          = (p, d, q)    -> non-seasonal AR / differencing / MA terms
# seasonal_order = (P, D, Q, S) -> seasonal counterparts with period S
#
# Sunspot data follows a cycle of about 11 years. Expressed in months that
# would be S = 11 * 12 = 132, but SARIMAX with such a long seasonal period can
# be computationally intensive. S = 12 (yearly seasonality on monthly data) is
# used as a starting point, accepting that the 11-year cycle is harder to
# capture this way.
#
# The orders below are example values. The ACF/PACF plots help to pick
# p, d, q, P, D and Q; a production model would compare several candidates by
# AIC/BIC, typically through a grid search. Differencing (d = 1 or 2) and
# seasonal differencing (D = 1) are common choices for this kind of series.

print("\n--- Training SARIMA Model ---")
sarima_spec = {
    'order': (1, 1, 1),               # example non-seasonal order
    'seasonal_order': (1, 1, 1, 12),  # example seasonal order (P, D, Q, S)
    'enforce_stationarity': False,
    'enforce_invertibility': False,
}
try:
    sarima_model = SARIMAX(train_data, **sarima_spec)
    sarima_results = sarima_model.fit(disp=False)
    print(sarima_results.summary())
except Exception as e:
    # Best effort: report the failure and let the rest of the script continue.
    print(f"Error training SARIMA model: {e}")
    print("SARIMA might struggle with very long cycles or non-stationary data. Consider differencing more or trying different orders.")

# --- Model 2: Facebook Prophet ---
# Prophet handles several seasonalities at once and tolerates missing data and
# outliers. It expects a frame with two columns: 'ds' (datetime) and 'y'
# (observed value).
print("\n--- Training Prophet Model ---")
prophet_df = train_data.reset_index().rename(columns={'Date': 'ds', 'Sunspot_Count': 'y'})

# Instantiate the model.
prophet_model = Prophet(
    growth='linear',
    seasonality_mode='additive',
    yearly_seasonality=True,      # fit a yearly seasonal component
    weekly_seasonality=False,     # monthly data, no weekly pattern
    daily_seasonality=False,      # monthly data, no daily pattern
    seasonality_prior_scale=10,   # larger values allow stronger seasonality
    changepoint_prior_scale=0.05  # larger values allow a more flexible trend
)

# Custom seasonality for the ~11-year sunspot cycle.
# Prophet measures seasonality periods in DAYS, so the cycle length must be
# given as 11 years of days; 11 * 12 would describe a 132-day cycle instead.
prophet_model.add_seasonality(name='sunspot_cycle', period=11 * 365.25, fourier_order=10)

prophet_model.fit(prophet_df)

# --- 4. Forecasting ---

# Forecast horizon: every month start from January 2011 through December 2020.
forecast_start_date, forecast_end_date = '2011-01-01', '2020-12-01'
forecast_dates = pd.date_range(forecast_start_date, forecast_end_date, freq='MS')
print(f"\nForecast period: {forecast_start_date} to {forecast_end_date}")
print(f"Number of forecast months: {forecast_dates.size}")

# --- SARIMA Forecast ---
# Start from a genuinely empty, typed series. It stays empty whenever SARIMA
# training or forecasting fails, so later `.empty` checks can detect that.
sarima_forecast = pd.Series(dtype=float, name='SARIMA_Forecast')

# `sarima_results` is only bound when SARIMA training succeeded; probing the
# name directly works in any scope, unlike a lookup in locals().
try:
    fitted_sarima = sarima_results
except NameError:
    fitted_sarima = None

if fitted_sarima is not None:
    try:
        sarima_forecast = fitted_sarima.predict(
            start=forecast_start_date, end=forecast_end_date
        ).rename('SARIMA_Forecast')
    except Exception as e:
        # Best effort: report the problem and keep the empty series. An
        # all-NaN placeholder would look like a valid forecast to `.empty`.
        print(f"Error generating SARIMA forecast: {e}")
    else:
        print("\nSARIMA Forecast (first 5 values):\n", sarima_forecast.head())


# --- Prophet Forecast ---
# Extend the training timeline by one row per forecast month and predict over
# the whole frame (history plus future).
future_prophet = prophet_model.make_future_dataframe(periods=len(forecast_dates), freq='MS')
prophet_forecast = prophet_model.predict(future_prophet)

# Keep only the rows that fall inside the forecast window (both ends inclusive).
on_or_after_start = prophet_forecast['ds'] >= forecast_start_date
on_or_before_end = prophet_forecast['ds'] <= forecast_end_date
prophet_forecast_period = prophet_forecast[on_or_after_start & on_or_before_end]

# Reduce to a date-indexed series of point predictions.
prophet_forecast_values = (
    prophet_forecast_period
    .set_index('ds')['yhat']
    .rename('Prophet_Forecast')
)
print("\nProphet Forecast (first 5 values):\n", prophet_forecast_values.head())

# --- 5. Visualization ---

# Plot the training history and overlay each forecast that is available.
plt.figure(figsize=(18, 9))
plt.plot(train_data, label='Historical Data (1749-2010)')

# (model name, forecast series, line colour, line style) -- SARIMA first.
forecast_styles = (
    ('SARIMA', sarima_forecast, 'orange', '--'),
    ('Prophet', prophet_forecast_values, 'green', '-.'),
)
for model_name, forecast_series, line_color, line_style in forecast_styles:
    if forecast_series.empty:
        print(f"\n{model_name} forecast not available due to errors during training/forecasting.")
        continue
    plt.plot(
        forecast_series,
        label=f'{model_name} Forecast (2011-2020)',
        color=line_color,
        linestyle=line_style,
    )

plt.title('Monthly Average Sunspot Count Forecast')
plt.xlabel('Date')
plt.ylabel('Sunspot Count')
plt.legend()
plt.grid(True)
plt.show()

# Print every forecast in full, SARIMA first and Prophet second.
print("\n--- Predicted Monthly Average Sunspot Count (2011-2020) ---")
forecast_reports = (
    ("\nSARIMA Forecast:", sarima_forecast),
    ("\nProphet Forecast:", prophet_forecast_values),
)
for report_heading, report_series in forecast_reports:
    print(report_heading)
    print(report_series.to_string())

# Write each forecast to its own CSV file, including the header row.
sarima_forecast.to_csv('sarima_sunspot_forecast_2011_2020.csv', header=True)
prophet_forecast_values.to_csv('prophet_sunspot_forecast_2011_2020.csv', header=True)
print("\nForecasts saved to 'sarima_sunspot_forecast_2011_2020.csv' and 'prophet_sunspot_forecast_2011_2020.csv'")