In [1]:
%pip install -q datasetsforecast neuralforecast utilsforecast statsforecast

Note: you may need to restart the kernel to use updated packages.


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from neuralforecast import NeuralForecast
from neuralforecast.models import KAN
from sklearn.metrics import mean_squared_error

# Load the daily minimum temperature series and reshape it into the long
# format NeuralForecast works with: `unique_id`, `ds` (timestamp), `y` (target).
df = (
    pd.read_csv("datasets/dailymintemp.csv")
    .rename(columns={"Date": "ds", "Daily minimum temperatures": "y"})
    .assign(
        unique_id="TempSeries",
        ds=lambda frame: pd.to_datetime(frame["ds"]),
    )
)

# Everything before the split date trains the model; the rest is held out.
split_date = "1990-01-01"
train_df = df[df["ds"] < split_date]
test_df = df[df["ds"] >= split_date]

# Forecast one week ahead, looking back over a window twice that long.
horizon = 7  # Predicting the next week

kan_model = KAN(
    h=horizon,
    input_size=horizon * 2,
    scaler_type='standard',
    max_steps=1000,
    early_stop_patience_steps=3,
)

# FLOPs Calculation Function
def calculate_kan_flops(din, dout, G, K):
    """Estimate the floating-point operations of a single KAN layer.

    The count is ``(din * dout) * (9*K*(G + 1.5*K) + 2*G - 2.5*K - 1)``.
    That is, a per-edge cost multiplied by the number of input/output edges.
    NOTE(review): the per-edge term looks like the published B-spline FLOP
    formula for KANs; confirm the source before relying on absolute numbers.

    Args:
        din: Number of inputs to the layer.
        dout: Number of outputs of the layer.
        G: Spline grid size.
        K: Spline order.

    Returns:
        The estimated FLOP count. This is a float for integer inputs, because
        of the ``1.5 * K`` and ``2.5 * K`` terms.
    """
    edge_count = din * dout
    per_edge_cost = 9 * K * (G + 1.5 * K) + 2 * G - 2.5 * K - 1
    return edge_count * per_edge_cost

# Rough FLOP estimate treating the model as one KAN layer that maps the
# lookback window (`input_size`) directly onto the forecast horizon (`h`).
# NOTE(review): if the KAN model contains hidden layers, the true cost is
# higher than this single-layer figure — confirm against the model config.
grid_size = 5                    # Assumed grid size used in KAN
spline_order = 3                 # Assumed spline order for KAN
din, dout = kan_model.input_size, kan_model.h
flops = calculate_kan_flops(din=din, dout=dout, G=grid_size, K=spline_order)
print(f"Estimated FLOPs for KAN model: {flops}")

# Reserve 10% of the training length as the validation window for `fit`.
# The model's early-stopping setting needs a validation set to monitor.
val_size = int(len(train_df) * 0.1)

# Wrap the model in a daily-frequency NeuralForecast pipeline and train it.
nf = NeuralForecast(models=[kan_model], freq='D')
nf.fit(df=train_df, val_size=val_size)

# Generate predictions.
# `nf.predict()` forecasts exactly `horizon` steps past the end of the
# training data. The KAN model uses no future exogenous features, so no
# `futr_df` is needed. Passing the whole test year does not yield
# forecasts for all of it.
neural_preds = nf.predict()

# Score only the test dates that actually received a forecast. The previous
# left merge kept every 1990 date and filled all dates beyond the horizon
# with NaN predictions, which `mean_squared_error` rejects.
# TODO: to score the full test year, use rolling-origin evaluation
# (e.g. `nf.cross_validation`) instead of a single forecast.
results = pd.merge(test_df, neural_preds, on=['unique_id', 'ds'], how='inner')
if results.empty:
    raise ValueError(
        "No overlap between forecast dates and test dates; "
        "check that training ends right before the test period."
    )

# Calculate Mean Squared Error (MSE) over the forecast window
mse = mean_squared_error(results['y'], results['KAN'])
print(f"Mean Squared Error (MSE) on Test Data: {mse:.4f}")

# Overlay actual and forecast temperatures for the evaluated dates.
fig, ax = plt.subplots(figsize=(10, 6))
for column, label in (('y', 'Actual'), ('KAN', 'Predicted (KAN)')):
    ax.plot(results['ds'], results[column], label=label)
ax.legend()
ax.set_xlabel('Date')
ax.set_ylabel('Daily Minimum Temperature')
ax.set_title('KAN Forecast for Daily Minimum Temperatures')
plt.show()
