# In [None]:
# ==================================================================================
# PowerCast: Electricity Price Forecasting Challenge - EDA Script
#
# This script performs exploratory data analysis (EDA) for the project.
# It loads all datasets, performs initial inspections, and visualizes trends,
# correlations, and lag effects that will help us to answer the challenge questions:
#
# 1. Market Trends & Price Fluctuations
# 2. Correlation & Feature Relationships
# 3. Price & Consumption Impact Analysis
# ==================================================================================

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import pickle
import shap
import matplotlib.patches as mpatches

from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.model_selection import RandomizedSearchCV
from scipy.stats import randint
from math import sqrt

# =============================================================================
# SETTINGS & PLOTTING STYLE
# =============================================================================

# Global look shared by every figure below: seaborn's "whitegrid" theme and a
# wide 14 x 7 inch default canvas.
sns.set_theme(style="whitegrid")
plt.rcParams.update({"figure.figsize": (14, 7)})

# =============================================================================
# HELPER FUNCTION: Load Excel File and Set Date Index
# =============================================================================

def load_file(file_path, date_col='Date'):
    """Read an Excel file and return it indexed by its parsed date column.

    Parameters
    ----------
    file_path :
        Location passed unchanged to ``pd.read_excel``.
    date_col :
        Name of the column converted with ``pd.to_datetime`` and then promoted
        to the index (it is no longer a regular column afterwards).

    Returns
    -------
    pandas.DataFrame
        The sheet contents with a datetime index named ``date_col``.
    """
    frame = pd.read_excel(file_path)
    frame[date_col] = pd.to_datetime(frame[date_col])
    return frame.set_index(date_col)

# =============================================================================
# Load Datasets
# =============================================================================

# Placeholder locations of the Excel exports -- adjust these file paths as needed.
prices_path = r'path/tofile'
actual_cons_path = r'path/tofile'
forecast_cons_path = r'path/tofile'
actual_gen_path = r'path/tofile'
forecast_gen_da_path = r'path/tofile'
gen_forecast_intraday_path = r'path/tofile'
cross_border_path = r'path/tofile'
scheduled_exchanges_path = r'path/tofile'
imported_balancing_path = r'path/tofile'
exported_balancing_path = r'path/tofile'
balancing_energy_path = r'path/tofile'
costs_path = r'path/tofile'     # costs of TSOs
auto_frr_path = r'path/tofile'  # automatic frequency restoration reserve
fcr_path = r'path/tofile'       # frequency containment reserve

# Every file is read with the same helper, in the order listed, so each
# DataFrame below is indexed by its parsed 'Date' column.
(
    df_prices,
    df_actual_cons,
    df_forecast_cons,
    df_actual_gen,
    df_forecast_gen_da,
    df_gen_forecast_intraday,
    df_cross_border,
    df_scheduled_exchanges,
    df_imported_balancing,
    df_exported_balancing,
    df_balancing_energy,
    df_costs,
    df_auto_frr,
    df_fcr,
) = [
    load_file(path)
    for path in (
        prices_path,
        actual_cons_path,
        forecast_cons_path,
        actual_gen_path,
        forecast_gen_da_path,
        gen_forecast_intraday_path,
        cross_border_path,
        scheduled_exchanges_path,
        imported_balancing_path,
        exported_balancing_path,
        balancing_energy_path,
        costs_path,
        auto_frr_path,
        fcr_path,
    )
]

# Quick inspection of key DataFrames: a heading followed by the first 5 rows.
for _heading, _frame in (
    ("Day-Ahead Prices (first 5 rows):", df_prices),
    ("\nActual Consumption (first 5 rows):", df_actual_cons),
    ("\nForecasted Consumption (first 5 rows):", df_forecast_cons),
):
    print(_heading)
    print(_frame.head())

# Day-ahead price trends across countries: hourly, daily and weekly views.

# Keep every "dayahead" price column but drop the "neighbours" series, which
# would make the charts too noisy.
all_price_cols = [name for name in df_prices.columns if "dayahead" in name.lower()]
selected_price_cols = [name for name in all_price_cols if "neighbours" not in name.lower()]

print("Selected day-ahead price columns (excluding neighbours):")
for col in selected_price_cols:
    print(" -", col)

# Daily and weekly means of the selected series; the raw frame is plotted as-is.
df_prices_daily = df_prices[selected_price_cols].resample('D').mean()
df_prices_weekly = df_prices[selected_price_cols].resample('W').mean()

price_views = (
    (df_prices, "Hourly Day-Ahead Prices", "Price (€/MWh)"),
    (df_prices_daily, "Daily Average Day-Ahead Prices", "Average Price (€/MWh)"),
    (df_prices_weekly, "Weekly Average Day-Ahead Prices", "Average Price (€/MWh)"),
)
for view, view_title, view_ylabel in price_views:
    plt.figure()
    for col in selected_price_cols:
        plt.plot(view.index, view[col], label=col)
    plt.title(view_title)
    plt.xlabel("Date")
    plt.ylabel(view_ylabel)
    # Many series: put a two-column legend outside the axes, to the right.
    plt.legend(ncol=2, bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.show()

# Consumption data on Prices
# Relationship between consumption and price fluctuations: three stacked
# panels sharing the time axis (DE/LU day-ahead price, actual grid load,
# residual load). Drawn only when all three columns exist.
de_lu_price_col = "Germany/Luxembourg [€/MWh]_dayahead"
grid_load_col = "Total (grid load) [MWh]_actual_cons"
residual_load_col = "Residual load [MWh]_actual_cons"

have_price = de_lu_price_col in df_prices.columns
have_loads = {grid_load_col, residual_load_col}.issubset(df_actual_cons.columns)

if have_price and have_loads:
    fig, axes = plt.subplots(nrows=3, ncols=1, sharex=True, figsize=(14, 10))

    panels = (
        (df_prices[de_lu_price_col], "Price (€/MWh)", 'blue'),
        (df_actual_cons[grid_load_col], "Actual Consumption (MWh)", 'red'),
        (df_actual_cons[residual_load_col], "Residual Load (MWh)", 'green'),
    )
    for ax, (series, panel_ylabel, colour) in zip(axes, panels):
        ax.plot(series.index, series, color=colour)
        ax.set_ylabel(panel_ylabel, color=colour)
        ax.tick_params(axis='y', labelcolor=colour)

    axes[0].set_title("Day-Ahead Price, Actual Consumption, and Residual Load")
    # NOTE: fixed limits -- price values outside 0-450 €/MWh (including any
    # negative prices) fall outside the visible range of the top panel.
    axes[0].set_ylim(0, 450)
    axes[2].set_xlabel("Date")

    plt.tight_layout()
    plt.show()
    
# Consumption: Actual vs. Forecasted
actual_load_col = "Total (grid load) [MWh]_actual_cons"
forecast_load_col = "Total (grid load) [MWh]_forecast_cons"

if actual_load_col in df_actual_cons.columns and forecast_load_col in df_forecast_cons.columns:
    # Align actual and forecast grid load on their timestamps, add the
    # forecast error (actual - forecast), then average everything per day.
    df_consumption = pd.DataFrame({
        "Actual": df_actual_cons[actual_load_col],
        "Forecast": df_forecast_cons[forecast_load_col],
    })
    df_consumption["Difference"] = df_consumption["Actual"] - df_consumption["Forecast"]
    df_consumption_daily = df_consumption.resample('D').mean()

    plt.figure()
    for column, line_label, line_style in (
        ("Actual", "Actual Consumption", {}),
        ("Forecast", "Forecasted Consumption", {}),
        ("Difference", "Difference", {"linestyle": "--"}),
    ):
        plt.plot(df_consumption_daily.index, df_consumption_daily[column],
                 label=line_label, **line_style)
    plt.title("Daily Consumption: Actual vs. Forecasted")
    plt.xlabel("Date")
    plt.ylabel("Consumption (MWh)")
    plt.legend()
    plt.tight_layout()
    plt.show()

# =============================================================================
# Visualization for Aggregated Trends with Seasonal Highlights
# =============================================================================


def _shade_seasons(ax, index):
    """Shade summer and winter periods on *ax* over the time range of *index*.

    Summer is Jun 20 - Sep 23 (yellow) and winter is Dec 21 - Mar 20 of the
    following year (light blue). Each span is drawn whenever it overlaps
    ``[index.min(), index.max()]`` and is clipped to that range. A season that
    starts before the first data point (e.g. Jan-Mar of the first year, or a
    weekly index whose first label falls after Jan 1) is therefore still
    shaded, and the x-axis is never stretched past the data.

    Returns the two legend patches (summer, winter) describing the shading.
    """
    summer = mpatches.Patch(color='yellow', alpha=0.2, label='Summer (Jun 20 – Sep 23)')
    winter = mpatches.Patch(color='lightblue', alpha=0.2, label='Winter (Dec 21 – Mar 20)')
    if len(index) == 0:
        return summer, winter

    first, last = index.min(), index.max()
    tz = getattr(index, "tz", None)  # build boundaries in the index's timezone, if any

    def day(year, month, dom):
        return pd.Timestamp(year=year, month=month, day=dom, tz=tz)

    def shade(begin, end, colour):
        begin, end = max(begin, first), min(end, last)
        if begin < end:
            ax.axvspan(begin, end, color=colour, alpha=0.2)

    # Start one year early so the winter that began the previous December
    # (covering Jan-Mar of the first year in the data) is included.
    for year in range(first.year - 1, last.year + 1):
        shade(day(year, 6, 20), day(year, 9, 23), 'yellow')
        shade(day(year, 12, 21), day(year + 1, 3, 20), 'lightblue')
    return summer, winter


# --- Daily aggregates ---------------------------------------------------------
# Resample daily data from the hourly prices
df_prices_daily = df_prices[selected_price_cols].resample('D').mean()

# Cross-country average and median per day. Both are taken over the price
# columns only: a median over the whole frame would also include the
# just-added 'Overall_Average' column and skew the result.
df_prices_daily['Overall_Average'] = df_prices_daily[selected_price_cols].mean(axis=1)
df_prices_daily['Overall_Median'] = df_prices_daily[selected_price_cols].median(axis=1)

# 7-day moving average of the overall average (7 rows of a daily index)
df_prices_daily['Overall_Average_MA7'] = df_prices_daily['Overall_Average'].rolling(window=7).mean()

plt.figure(figsize=(14, 7))
plt.plot(df_prices_daily.index, df_prices_daily['Overall_Average'],
         label='Overall Daily Average', color='blue', alpha=0.7)
plt.plot(df_prices_daily.index, df_prices_daily['Overall_Average_MA7'],
         label='7-Day Moving Average', color='red', linewidth=2)
summer_patch, winter_patch = _shade_seasons(plt.gca(), df_prices_daily.index)

# Proxy handles so the legend lists the seasons and both series
line_daily = plt.Line2D([], [], color='blue', label='Overall Daily Average')
line_ma7 = plt.Line2D([], [], color='red', label='7-Day Moving Average')

plt.grid(True, axis='y')
plt.title("Aggregated Daily Average Day-Ahead Prices with Seasonal Highlights")
plt.xlabel("Date")
plt.ylabel("Price (€/MWh)")
plt.legend(handles=[summer_patch, winter_patch, line_daily, line_ma7], loc='upper left')
plt.tight_layout()
plt.show()

# Boxplot: Distribution of Daily Overall Average Prices by Weekday
df_prices_daily['Weekday'] = df_prices_daily.index.day_name()

plt.figure(figsize=(14, 7))
sns.boxplot(
    x='Weekday',
    y='Overall_Average',
    data=df_prices_daily.reset_index(),
    order=['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday'],
    palette='pastel'
)
plt.title("Distribution of Daily Overall Average Prices by Weekday")
plt.xlabel("Weekday")
plt.ylabel("Price (€/MWh)")
plt.tight_layout()
plt.show()


# --- Hourly aggregates --------------------------------------------------------
# 1. Overall Hourly Average: the mean across all countries per timestamp.
# 2. 24-Hour Moving Average: smooths out short-term fluctuations
#    (a 24-row window, i.e. 24 hours on an hourly index).
df_prices_hourly = df_prices[selected_price_cols].copy()
df_prices_hourly['Overall_Hourly_Average'] = df_prices_hourly[selected_price_cols].mean(axis=1)
df_prices_hourly['Rolling_24h_Average'] = df_prices_hourly['Overall_Hourly_Average'].rolling(window=24).mean()

plt.figure(figsize=(14, 7))
plt.plot(df_prices_hourly.index, df_prices_hourly['Overall_Hourly_Average'],
         label='Overall Hourly Average', color='blue', alpha=0.7)
plt.plot(df_prices_hourly.index, df_prices_hourly['Rolling_24h_Average'],
         label='24-Hour Moving Average', color='red', linewidth=2)
summer_patch, winter_patch = _shade_seasons(plt.gca(), df_prices_hourly.index)

line_hourly = plt.Line2D([], [], color='blue', label='Overall Hourly Average')
line_24h = plt.Line2D([], [], color='red', label='24-Hour Moving Average')

plt.legend(handles=[summer_patch, winter_patch, line_hourly, line_24h], loc='upper left')
plt.grid(True, axis='y')
plt.title("Aggregated Hourly Average Day-Ahead Prices with Seasonal Highlights")
plt.xlabel("Date")
plt.ylabel("Price (€/MWh)")
plt.tight_layout()
plt.show()

# Distribution of the overall average price for each hour of the day
df_prices_hourly['Hour'] = df_prices_hourly.index.hour

plt.figure(figsize=(14, 7))
sns.boxplot(
    x='Hour',
    y='Overall_Hourly_Average',
    data=df_prices_hourly.reset_index(),
    color='lightblue'
)
plt.title("Distribution of Hourly Prices by Hour of Day")
plt.xlabel("Hour of Day (0-23)")
plt.ylabel("Price (€/MWh)")
plt.tight_layout()
plt.show()


# --- Weekly aggregates --------------------------------------------------------
# Resample weekly data from hourly prices (weeks labelled by their end date)
df_prices_weekly = df_prices[selected_price_cols].resample('W').mean()

# Cross-country average and median per week, over the price columns only
df_prices_weekly['Overall_Average'] = df_prices_weekly[selected_price_cols].mean(axis=1)
df_prices_weekly['Overall_Median'] = df_prices_weekly[selected_price_cols].median(axis=1)
# Apply a smoothing window: 4-week moving average
df_prices_weekly['Overall_Average_MA4'] = df_prices_weekly['Overall_Average'].rolling(window=4).mean()

plt.figure(figsize=(14, 7))
plt.plot(df_prices_weekly.index, df_prices_weekly['Overall_Average'],
         label='Overall Weekly Average', color='blue', alpha=0.7)
plt.plot(df_prices_weekly.index, df_prices_weekly['Overall_Average_MA4'],
         label='4-Week Moving Average', color='red', linewidth=2)
summer_patch, winter_patch = _shade_seasons(plt.gca(), df_prices_weekly.index)

line_weekly = plt.Line2D([], [], color='blue', label='Overall Weekly Average')
line_ma4 = plt.Line2D([], [], color='red', label='4-Week Moving Average')

plt.grid(True, axis='y')
plt.title("Aggregated Weekly Average Day-Ahead Prices with Seasonal Highlights")
plt.xlabel("Date")
plt.ylabel("Price (€/MWh)")
plt.legend(handles=[summer_patch, winter_patch, line_weekly, line_ma4], loc='upper left')
plt.tight_layout()
plt.show()

# =============================================================================
# Electricity Generation vs. Price Trends Analysis
# =============================================================================

# --- Feature Engineering for Generation Data ---

# Actual generation: every per-source column ends with '_actual_gen'; their
# row-wise sum is the total. min_count=1 keeps timestamps/days with no data at
# all as NaN instead of turning them into a misleading 0 MWh (which would show
# up as fake dips and as huge "errors" in the scatter plot below).
actual_gen_cols = [col for col in df_actual_gen.columns if '_actual_gen' in col.lower()]
df_actual_gen['Total_Gen'] = df_actual_gen[actual_gen_cols].sum(axis=1, min_count=1)

# Forecasted generation: the file has redundant columns, so use the total
# forecast column "Total [MWh]_forecast_dayahead_gen".
df_forecast_gen_da['Total_Gen_Forecast'] = df_forecast_gen_da["Total [MWh]_forecast_dayahead_gen"]

# Daily totals (sum of the values within each calendar day)
df_actual_gen_daily = df_actual_gen['Total_Gen'].resample('D').sum(min_count=1)
df_forecast_gen_daily = df_forecast_gen_da['Total_Gen_Forecast'].resample('D').sum(min_count=1)


# Dual Axis Time Series Plot
# Daily overall price on the left y-axis, generation on the right y-axis.

fig, ax1 = plt.subplots(figsize=(14, 7))

# Overall Daily Price on the left y-axis
line1, = ax1.plot(df_prices_daily.index, df_prices_daily['Overall_Average'],
                  label="Overall Daily Price", color='blue', linewidth=2)
ax1.set_xlabel("Date")
ax1.set_ylabel("Price (€/MWh)", color='blue')
ax1.tick_params(axis='y', labelcolor='blue')

# Keep only horizontal gridlines, and only on ax1
ax1.grid(True, which='both', axis='y')
ax1.grid(False, which='both', axis='x')

# Second y-axis for generation data, with its gridlines disabled
ax2 = ax1.twinx()
line2, = ax2.plot(df_actual_gen_daily.index, df_actual_gen_daily,
                  label="Actual Generation", color='green', linestyle='-', linewidth=2)
line3, = ax2.plot(df_forecast_gen_daily.index, df_forecast_gen_daily,
                  label="Forecast Generation", color='orange', linestyle='--', linewidth=2)
ax2.set_ylabel("Generation (MWh)", color='black')
ax2.tick_params(axis='y', labelcolor='black')
ax2.grid(False)

# One legend for both axes, placed outside the chart on the right
lines = [line1, line2, line3]
labels = [line.get_label() for line in lines]
ax1.legend(lines, labels, loc='upper left', bbox_to_anchor=(1.05, 1), borderaxespad=0.)

plt.title("Daily Electricity Price and Generation (Actual vs. Forecast)")
plt.tight_layout()
plt.show()

# Scatter Plot of Generation Error vs. Price
# Generation error per day (Actual - Forecast); NaN on days where either side
# has no data.
df_gen_error = pd.DataFrame({
    "Actual": df_actual_gen_daily,
    "Forecast": df_forecast_gen_daily
})
df_gen_error["Error"] = df_gen_error["Actual"] - df_gen_error["Forecast"]

# Align the errors with the daily prices on their shared dates. NaN pairs are
# not drawn as points.
common_dates = df_gen_error.index.intersection(df_prices_daily.index)
error = df_gen_error.loc[common_dates, "Error"]
price = df_prices_daily.loc[common_dates, "Overall_Average"]

plt.figure(figsize=(10, 6))
sns.scatterplot(x=error, y=price, color='purple', alpha=0.7)
plt.xlabel("Generation Error (Actual - Forecast) (MWh)")
plt.ylabel("Overall Daily Price (€/MWh)")
plt.title("Scatter Plot of Generation Error vs. Price")
plt.tight_layout()
plt.show()

# =============================================================================
# Electricity Generation vs. Price Trends: Split Analysis
# =============================================================================

# --- Feature Engineering for Actual Generation ---

# Renewable sources: Biomass, Hydropower, Wind offshore/onshore,
# Photovoltaics and Other renewable.
renewable_cols_actual = [
    "Biomass [MWh]_actual_gen", "Hydropower [MWh]_actual_gen",
    "Wind offshore [MWh]_actual_gen", "Wind onshore [MWh]_actual_gen",
    "Photovoltaics [MWh]_actual_gen", "Other renewable [MWh]_actual_gen"
]
# Conventional sources: Nuclear, Lignite, Hard coal, Fossil gas,
# Hydro pumped storage and Other conventional.
conventional_cols_actual = [
    "Nuclear [MWh]_actual_gen", "Lignite [MWh]_actual_gen",
    "Hard coal [MWh]_actual_gen", "Fossil gas [MWh]_actual_gen",
    "Hydro pumped storage [MWh]_actual_gen", "Other conventional [MWh]_actual_gen"
]

# Row-wise totals for each group
for total_name, source_cols in (("Renewables_Actual", renewable_cols_actual),
                                ("Conventional_Actual", conventional_cols_actual)):
    df_actual_gen[total_name] = df_actual_gen[source_cols].sum(axis=1)

# --- Feature Engineering for Forecasted Generation ---
# The day-ahead forecast only has coarser columns:
#   renewables   -> "Photovoltaics and wind [MWh]_forecast_dayahead_gen"
#   conventional -> "Other [MWh]_forecast_dayahead_gen" (instead of the total)
# NOTE(review): the forecast "renewables" covers PV and wind only, while the
# actual renewables above also include biomass, hydropower and other
# renewables -- the two series are not like-for-like.
df_forecast_gen_da["Renewables_Forecast"] = df_forecast_gen_da["Photovoltaics and wind [MWh]_forecast_dayahead_gen"]
df_forecast_gen_da["Conventional_Forecast"] = df_forecast_gen_da["Other [MWh]_forecast_dayahead_gen"]

# Daily totals of both groups
df_actual_gen_daily = df_actual_gen[["Renewables_Actual", "Conventional_Actual"]].resample('D').sum()
df_forecast_gen_daily = df_forecast_gen_da[["Renewables_Forecast", "Conventional_Forecast"]].resample('D').sum()


# =============================================================================
# Figure 1: Renewables vs. Price Trends
# Figure 2: Conventional Generation vs. Price Trends
# =============================================================================
# Each figure stacks three panels sharing the date axis: overall daily price,
# actual generation of the group, forecast generation of the group.

for group in ("Renewables", "Conventional"):
    fig, axes = plt.subplots(nrows=3, ncols=1, sharex=True, figsize=(14, 10))

    panels = (
        (axes[0], df_prices_daily['Overall_Average'], {'color': 'blue'},
         "Price (€/MWh)", "Overall Daily Price"),
        (axes[1], df_actual_gen_daily[f"{group}_Actual"], {'color': 'green'},
         f"Actual {group} (MWh)", f"Actual {group} Generation"),
        (axes[2], df_forecast_gen_daily[f"{group}_Forecast"], {'color': 'orange', 'linestyle': '--'},
         f"Forecast {group} (MWh)", f"Forecast {group} Generation"),
    )
    for ax, series, line_kwargs, panel_ylabel, panel_title in panels:
        ax.plot(series.index, series, linewidth=2, **line_kwargs)
        ax.set_ylabel(panel_ylabel)
        ax.set_title(panel_title)
        # Horizontal gridlines only
        ax.grid(True, which='both', axis='y')
        ax.grid(False, which='both', axis='x')

    axes[2].set_xlabel("Date")
    plt.tight_layout()
    plt.show()


# =============================================================================
# Cross-border Physical Flows vs. Scheduled Commercial Exchanges Analysis
# =============================================================================
# Merge the scheduled-exchange and cross-border datasets on their Date index,
# focusing on the "Net export" columns. The inner join keeps only timestamps
# present in both; sorting makes the index valid for time-based rolling windows.
sched_net_col = 'Net export [MWh]_scheduled_exchanges'
cross_net_col = 'Net export [MWh]_cross_border'

df_flows = pd.merge(
    df_scheduled_exchanges[[sched_net_col]],
    df_cross_border[[cross_net_col]],
    left_index=True, right_index=True, how='inner'
).sort_index()

# Difference and percentage difference between scheduled and physical flows.
# A zero scheduled net export has no meaningful percentage, so zeros in the
# denominator become NaN instead of producing ±inf.
df_flows['Difference'] = df_flows[sched_net_col] - df_flows[cross_net_col]
df_flows['Perc_Difference'] = (df_flows['Difference'] /
                               df_flows[sched_net_col].replace(0, np.nan)) * 100

# Time Series of Scheduled vs. Cross-border Net Exports
plt.figure(figsize=(14, 7))
plt.plot(df_flows.index, df_flows[sched_net_col],
         label='Scheduled Net Export', linewidth=2)
plt.plot(df_flows.index, df_flows[cross_net_col],
         label='Cross-border Net Export', linewidth=2)
plt.title("Scheduled vs. Cross-border Net Exports Over Time")
plt.xlabel("Date")
plt.ylabel("Net Export (MWh)")
plt.legend()
plt.tight_layout()
plt.show()

# Time Series of the Difference (Scheduled - Cross-border)
plt.figure(figsize=(14, 7))
plt.plot(df_flows.index, df_flows['Difference'], color='red',
         label='Difference (Scheduled - Cross-border)', linewidth=2)
plt.title("Difference Between Scheduled and Cross-border Net Exports")
plt.xlabel("Date")
plt.ylabel("Difference (MWh)")
plt.ylim(-4500, 4000)  # fixed zoom window in MWh; values outside it are clipped -- adjust as needed
plt.legend()
plt.tight_layout()
plt.show()

# Scatter Plot of Scheduled vs. Cross-border Net Exports
plt.figure(figsize=(8, 6))
sns.scatterplot(x=sched_net_col, y=cross_net_col,
                data=df_flows, color='purple', alpha=0.7)
# 45° reference line spanning the range of BOTH measures (as in the aggregated
# export/import scatter plots), so it covers every point.
min_val = df_flows[[sched_net_col, cross_net_col]].min().min()
max_val = df_flows[[sched_net_col, cross_net_col]].max().max()
plt.plot([min_val, max_val], [min_val, max_val], color='black', linestyle='--', label='45° Line')
plt.title("Scatter Plot: Scheduled vs. Cross-border Net Exports")
plt.xlabel("Scheduled Net Export (MWh)")
plt.ylabel("Cross-border Net Export (MWh)")
plt.legend()
plt.tight_layout()
plt.show()

# 7-Day Moving Averages for Smoothing Trends.
# A time-based '7D' window always spans seven calendar days whatever the
# sampling frequency; an integer window=7 would only cover 7 rows (7 hours on
# hourly data). The first windows, with under 7 days of history, average
# whatever data is available.
df_flows['Scheduled_MA7'] = df_flows[sched_net_col].rolling('7D').mean()
df_flows['CrossBorder_MA7'] = df_flows[cross_net_col].rolling('7D').mean()

plt.figure(figsize=(14, 7))
plt.plot(df_flows.index, df_flows['Scheduled_MA7'],
         label='Scheduled Net Export 7-Day MA', linewidth=2)
plt.plot(df_flows.index, df_flows['CrossBorder_MA7'],
         label='Cross-border Net Export 7-Day MA', linewidth=2)
plt.title("7-Day Moving Average: Scheduled vs. Cross-border Net Exports")
plt.xlabel("Date")
plt.ylabel("Net Export (MWh)")
plt.legend()
plt.tight_layout()
plt.show()

# =============================================================================
# Aggregating Exports and Imports for Scheduled Exchanges and Cross-Border Flows
# =============================================================================


def _flow_direction_cols(frame, direction):
    """List the columns of *frame* whose lower-cased name contains ``"(<direction>)"``."""
    tag = f"({direction})"
    return [name for name in frame.columns if tag in name.lower()]


# Per-border export/import columns of each dataset
export_cols_sched = _flow_direction_cols(df_scheduled_exchanges, "export")
import_cols_sched = _flow_direction_cols(df_scheduled_exchanges, "import")
export_cols_cross = _flow_direction_cols(df_cross_border, "export")
import_cols_cross = _flow_direction_cols(df_cross_border, "import")

# Row-wise totals over all borders
df_scheduled_exchanges['Total_Exports_scheduled'] = df_scheduled_exchanges[export_cols_sched].sum(axis=1)
df_scheduled_exchanges['Total_Imports_scheduled'] = df_scheduled_exchanges[import_cols_sched].sum(axis=1)
df_cross_border['Total_Exports_cross'] = df_cross_border[export_cols_cross].sum(axis=1)
df_cross_border['Total_Imports_cross'] = df_cross_border[import_cols_cross].sum(axis=1)

# Put the aggregated columns side by side (inner join on the Date index)
df_flows_extended = pd.merge(
    df_scheduled_exchanges[['Total_Exports_scheduled', 'Total_Imports_scheduled']],
    df_cross_border[['Total_Exports_cross', 'Total_Imports_cross']],
    left_index=True, right_index=True, how='inner'
)

print("Aggregated Exports and Imports (Scheduled vs Cross-Border):")
print(df_flows_extended.head())

# =============================================================================
# Visualization of Aggregated Exports and Imports
# =============================================================================

# Time series comparison: one figure for exports, then one for imports
for flow_kind, sched_colour, cross_colour in (("Exports", 'blue', 'green'),
                                              ("Imports", 'red', 'orange')):
    plt.figure(figsize=(14, 7))
    plt.plot(df_flows_extended.index, df_flows_extended[f'Total_{flow_kind}_scheduled'],
             label=f'Scheduled {flow_kind}', linewidth=2, color=sched_colour)
    plt.plot(df_flows_extended.index, df_flows_extended[f'Total_{flow_kind}_cross'],
             label=f'Cross-Border {flow_kind}', linewidth=2, color=cross_colour)
    plt.title(f"Time Series Comparison of {flow_kind}")
    plt.xlabel("Date")
    plt.ylabel(f"Total {flow_kind} (MWh)")
    plt.legend()
    plt.tight_layout()
    plt.show()

# Scatter plots (scheduled vs. cross-border) with a 45° line spanning both columns
for flow_kind, point_colour in (("Exports", 'blue'), ("Imports", 'red')):
    pair = [f'Total_{flow_kind}_scheduled', f'Total_{flow_kind}_cross']
    plt.figure(figsize=(8, 6))
    sns.scatterplot(x=pair[0], y=pair[1],
                    data=df_flows_extended, color=point_colour, alpha=0.7)
    lowest = df_flows_extended[pair].min().min()
    highest = df_flows_extended[pair].max().max()
    plt.plot([lowest, highest], [lowest, highest], color='black', linestyle='--', label='45° Line')
    plt.title(f"Scatter Plot: Scheduled vs. Cross-Border {flow_kind}")
    plt.xlabel(f"Scheduled {flow_kind} (MWh)")
    plt.ylabel(f"Cross-Border {flow_kind} (MWh)")
    plt.legend()
    plt.tight_layout()
    plt.show()

# In [None]:
# =============================================================================
# 2. CORRELATION & FEATURE RELATIONSHIPS
# =============================================================================

# Identify the features most strongly correlated with day-ahead prices.
# For simplicity, the correlations use the Germany/Luxembourg day-ahead price column.

# df_corr is indexed like df_prices. Every feature assigned below is aligned
# on that index: timestamps not in df_prices are dropped, missing ones become NaN.
df_corr = pd.DataFrame(index=df_prices.index)
df_corr["Price_DE_LU"] = df_prices["Germany/Luxembourg [€/MWh]_dayahead"]

# Generation features: aggregate actual generation into renewables and conventionals.
renewable_cols = ["Biomass [MWh]_actual_gen", "Hydropower [MWh]_actual_gen",
                  "Wind offshore [MWh]_actual_gen", "Wind onshore [MWh]_actual_gen",
                  "Photovoltaics [MWh]_actual_gen", "Other renewable [MWh]_actual_gen"]

if all(col in df_actual_gen.columns for col in renewable_cols):
    df_corr["Total_Renewables"] = df_actual_gen[renewable_cols].sum(axis=1)
else:
    print("Not all renewable columns found in df_actual_gen.")

conventional_cols = ["Nuclear [MWh]_actual_gen", "Lignite [MWh]_actual_gen",
                     "Hard coal [MWh]_actual_gen", "Fossil gas [MWh]_actual_gen",
                     "Hydro pumped storage [MWh]_actual_gen", "Other conventional [MWh]_actual_gen"]

if all(col in df_actual_gen.columns for col in conventional_cols):
    df_corr["Total_Conventionals"] = df_actual_gen[conventional_cols].sum(axis=1)
else:
    print("Not all conventional columns found in df_actual_gen.")

# Add TSO costs columns (2024)
for col in df_costs.columns:
    df_corr[col] = df_costs[col]

# Add Automatic Frequency Restoration Reserve columns (2023)
for col in df_auto_frr.columns:
    df_corr[col] = df_auto_frr[col]

# Add Frequency Containment Reserve columns (2023)
for col in df_fcr.columns:
    df_corr[col] = df_fcr[col]

# Add the net-export columns of the scheduled commercial exchanges and of the
# cross-border physical flows
if "Net export [MWh]_scheduled_exchanges" in df_scheduled_exchanges.columns:
    df_corr["Scheduled Exchanges"] = df_scheduled_exchanges["Net export [MWh]_scheduled_exchanges"]

if "Net export [MWh]_cross_border" in df_cross_border.columns:
    df_corr["Cross Border"] = df_cross_border["Net export [MWh]_cross_border"]

# Add Balancing Energy Price
if "Price [€/MWh]_balancing_energy" in df_balancing_energy.columns:
    df_corr["Balancing_energy_price"] = df_balancing_energy["Price [€/MWh]_balancing_energy"]

# Rows where every feature is present. The sources above are annotated as
# covering different years (2023 vs. 2024), so this can be small or even
# empty -- report its size rather than silently correlating on it.
df_corr_clean = df_corr.dropna()
print(f"Rows with all features present: {len(df_corr_clean)} of {len(df_corr)}")

# Pairwise correlation over the numeric columns: each pair of features uses
# every timestamp where *both* are present, instead of only the rows that
# survive a global dropna(). Pairs sharing fewer than MIN_CORR_OBS
# observations are reported as NaN instead of a noisy estimate.
MIN_CORR_OBS = 30
corr_matrix = df_corr.select_dtypes(include="number").corr(min_periods=MIN_CORR_OBS)

# Plot heatmap for correlations (NaN cells are left blank)
plt.figure(figsize=(16, 12))
sns.heatmap(corr_matrix, annot=True, cmap="coolwarm", fmt=".2f")
plt.title("Correlation Matrix: Day-Ahead Price vs. TSO Costs, FRR, FCR, and Balancing Energy")
plt.tight_layout()
plt.show()
# Define a correlation threshold (both positive and negative) using absolute correlation
corr_threshold = 0.5

# Compute correlations between 'Price_DE_LU' and all other features
price_corr_series = corr_matrix["Price_DE_LU"].drop("Price_DE_LU")  

# Select features that have an absolute correlation above the threshold (this captures both positive and negative)
high_corr_features = price_corr_series[price_corr_series.abs() >= corr_threshold]

# Sort the results by absolute correlation
high_corr_features = high_corr_features.reindex(high_corr_features.abs().sort_values(ascending=False).index)

print("\nHighly correlated features with Germany/Luxembourg Day-Ahead Price (|corr| >= {:.2f}):".format(corr_threshold))
print(high_corr_features)


# Analyze correlations between prices of different countries
price_corr = df_prices[selected_price_cols].corr()
plt.figure(figsize=(10, 8))
sns.heatmap(price_corr, annot=True, cmap="viridis")
plt.title("Inter-Country Day-Ahead Price Correlations")
plt.tight_layout()
plt.show()

In [None]:
# =============================================================================
# 3. PRICE & CONSUMPTION IMPACT ANALYSIS 
# =============================================================================
# This section investigates:
#  - How scheduled commercial exchanges influence price fluctuations
#  - The impact of cross-border physical flows on day-ahead prices


# Merge scheduled commercial exchanges + cross-border flows with the DE/LU day-ahead price.
# Each assignment aligns on df_prices.index. If a column name appears in both
# source frames, the cross-border value overwrites the scheduled one.
df_exchanges = pd.DataFrame(index=df_prices.index)

for col in df_scheduled_exchanges.columns:
    df_exchanges[col] = df_scheduled_exchanges[col]

for col in df_cross_border.columns:
    df_exchanges[col] = df_cross_border[col]

df_exchanges["Price_DE_LU"] = df_prices["Germany/Luxembourg [€/MWh]_dayahead"]

# Correlation Analysis (Non-Lagged): only rows complete in every column are used.
df_exchanges_clean = df_exchanges.dropna()
corr_matrix_exchanges = df_exchanges_clean.corr()
price_corr_exchanges = corr_matrix_exchanges["Price_DE_LU"].drop("Price_DE_LU")

# Sort by signed correlation, descending.
# NOTE(review): this ranks strong negative correlations last, so they can fall
# outside the top-N plot below; sort by .abs() if magnitude is what matters.
price_corr_exchanges_sorted = price_corr_exchanges.sort_values(ascending=False)

print("\nCorrelation with Price_DE_LU (Scheduled Exchanges + Cross-Border Flows):")
print(price_corr_exchanges_sorted)

# Horizontal bar plot of only the top N features, for legibility.
top_n = 20
top_corr = price_corr_exchanges_sorted.head(top_n)

plt.figure(figsize=(8, 10))
sns.barplot(x=top_corr.values, y=top_corr.index, orient='h')
plt.title(f"Top {top_n} Correlations of Exchanges/Flows with DE/LU Day-Ahead Price")
plt.xlabel("Correlation")
plt.ylabel("Feature")
plt.tight_layout()
plt.show()

# Lag Analysis

def create_lagged_dataframe(df, cols, lags=(1, 2, 3, 6, 12)):
    """Return *df* with row-shifted copies of selected columns appended.

    For every column in *cols* and every lag in *lags*, a column named
    ``f"{col}_lag_{lag}h"`` holds ``df[col].shift(lag)``. The "h" suffix
    assumes one row per hour (not checked here). All lag columns are built
    first and joined with a single ``pd.concat``, so *df* itself is not modified.

    Parameters
    ----------
    df : pandas.DataFrame
        Source frame.
    cols : iterable of column labels
        Columns to lag.
    lags : iterable of int
        Row offsets to shift by. This is an immutable tuple by default, so the
        default can never be shared and mutated across calls.

    Returns
    -------
    pandas.DataFrame
        Original columns followed by the lag columns. The leading rows of each
        lag column are NaN.
    """
    lag_data = {}
    for col in cols:
        for lag in lags:
            lag_data[f"{col}_lag_{lag}h"] = df[col].shift(lag)
    # Combine original and lagged data in one concat (avoids per-column inserts).
    df_lagged = pd.concat([df, pd.DataFrame(lag_data, index=df.index)], axis=1)
    return df_lagged

# Define which columns to lag (all except Price_DE_LU)
exchange_cols = [col for col in df_exchanges.columns if col != "Price_DE_LU"]
df_exchanges_lag = create_lagged_dataframe(df_exchanges, exchange_cols)
# Drop every row with any NaN. This includes the first rows, whose 12-row lag
# has no history yet.
df_exchanges_lag.dropna(inplace=True)

# Compute correlations with Price_DE_LU, sorted by signed value (descending).
lag_corr_exchanges = df_exchanges_lag.corr()["Price_DE_LU"].sort_values(ascending=False)

print("\nCorrelation of lagged scheduled exchanges & cross-border flows with Day-Ahead Price:")
# Note: Price_DE_LU is still in this series (self-correlation 1.0), so it
# appears at the top of the printed list; it is only dropped below for plotting.
print(lag_corr_exchanges.head(30))  

# Horizontal Bar Plot for Lagged Features
# Filter out Price_DE_LU from the index so we don't plot it
lag_corr_exchanges_features = lag_corr_exchanges.drop("Price_DE_LU", errors='ignore')

# Plot the top 20 again for clarity (most positive first; strong negative
# correlations sort last and may not appear).
top_lag_n = 20
top_lag_corr = lag_corr_exchanges_features.head(top_lag_n)

plt.figure(figsize=(8, 10))
sns.barplot(x=top_lag_corr.values, y=top_lag_corr.index, orient='h')
plt.title(f"Top {top_lag_n} Correlations of Lagged Exchanges & Flows with DE/LU Price")
plt.xlabel("Correlation")
plt.ylabel("Lagged Feature")
plt.tight_layout()
plt.show()


In [None]:
# =============================================================================
# Model Pipeline for PowerCast: Electricity Price Forecasting Challenge
# =============================================================================
#  Important note: if this section is run as a standalone .py script, import
#  the libraries and set the file paths first so the datasets can be loaded.
#
#  - Additional temporal features (day-of-week, hour-of-day with cyclical encoding)
#  - Advanced feature engineering (rolling statistics and lagged features)
#  - Outlier capping using the IQR rule
#  - Linear Regression
#  - Random Forest
#  - Scikit-learn GradientBoostingRegressor for quantile regression
#    with post-hoc calibration
#  - Interpretability via SHAP
#
# Packages: pandas, numpy, matplotlib, seaborn, scikit-learn, shap, scipy
# =============================================================================

# Re-apply the plotting defaults so this section also works when run on its own.
sns.set_theme(style="whitegrid")
plt.rcParams["figure.figsize"] = (14, 7)

# =============================================================================
# 1. HELPER FUNCTIONS
# =============================================================================
def load_file(file_path, date_col='Date'):
    """Read an Excel workbook and index it by its parsed date column.

    Parameters
    ----------
    file_path : str or path-like
        Workbook passed straight to ``pd.read_excel``.
    date_col : str
        Column converted with ``pd.to_datetime`` and then moved into the index.

    Returns
    -------
    pandas.DataFrame
        The sheet contents indexed by *date_col*; that column is no longer
        among the data columns.
    """
    frame = pd.read_excel(file_path)
    frame[date_col] = pd.to_datetime(frame[date_col])
    return frame.set_index(date_col)

def directional_accuracy(y_true, y_pred):
    """Return the fraction of steps where predicted and actual changes share a sign.

    Both series are first inner-aligned on their index. The direction at each
    step is ``sign(x[t] - x[t-1])`` computed on each series separately. A step
    that is flat (zero change) in both series counts as a match.

    Parameters
    ----------
    y_true, y_pred : pandas.Series
        Actual and predicted values.

    Returns
    -------
    float
        Share of matching directions, or NaN when no step has a defined change
        in both series.
    """
    y_true, y_pred = y_true.align(y_pred, join='inner')
    # Drop missing changes jointly so both sign series share one index.
    # Dropping each separately can leave differently-labelled Series, which
    # ``==`` rejects with a ValueError.
    changes = pd.DataFrame({"true": y_true.diff(), "pred": y_pred.diff()}).dropna()
    if changes.empty:
        return np.nan
    matches = np.sign(changes["true"]) == np.sign(changes["pred"])
    return matches.sum() / len(matches)

def volatility_capture(y_true, y_pred):
    """Ratio of predicted to actual step-change volatility.

    The series are inner-aligned on their index. Each one is first-differenced
    and NaNs are dropped, then the population standard deviation (``np.std``,
    ddof=0) of the predicted changes is divided by that of the actual changes.
    A value near 1.0 means the forecast moves about as much as the real series.

    Returns NaN when the actual changes have zero spread.
    """
    aligned_true, aligned_pred = y_true.align(y_pred, join='inner')
    sigma_true = np.std(aligned_true.diff().dropna())
    if sigma_true == 0:
        return np.nan
    sigma_pred = np.std(aligned_pred.diff().dropna())
    return sigma_pred / sigma_true

def extreme_price_movement_detection(y_true, y_pred, threshold=0.15):
    """Precision and recall of flagging "extreme" relative price moves.

    A step counts as extreme when the absolute percentage change from the
    previous value exceeds *threshold* (0.15 means 15%). The forecast's
    extreme steps are scored against the actual series' extreme steps.

    Returns
    -------
    tuple(float, float)
        ``(precision, recall)``. Each is 0.0 when its denominator is zero.
    """
    y_true, y_pred = y_true.align(y_pred, join='inner')
    is_extreme_true = y_true.pct_change().dropna().abs().gt(threshold)
    is_extreme_pred = y_pred.pct_change().dropna().abs().gt(threshold)

    hits = (is_extreme_true & is_extreme_pred).sum()
    false_alarms = (~is_extreme_true & is_extreme_pred).sum()
    misses = (is_extreme_true & ~is_extreme_pred).sum()

    flagged = hits + false_alarms
    actual = hits + misses
    precision = hits / flagged if flagged > 0 else 0.0
    recall = hits / actual if actual > 0 else 0.0
    return precision, recall

def create_lagged_features(df, columns, lags=(1, 2, 3, 6, 12)):
    """Return a copy of *df* with row-shifted copies of selected columns added.

    For every column in *columns* and every lag in *lags*, the column
    ``f"{col}_lag_{lag}h"`` is set to ``df[col].shift(lag)``. The "h" suffix
    assumes one row per hour (not checked here). The input frame is never
    modified.

    Parameters
    ----------
    df : pandas.DataFrame
        Source frame.
    columns : iterable of column labels
        Columns to lag.
    lags : iterable of int
        Row offsets. This is an immutable tuple by default, so no list object
        is shared between calls.

    Returns
    -------
    pandas.DataFrame
        The copy with lag columns appended. The leading rows of each lag column
        are NaN.
    """
    df_lagged = df.copy()
    for col in columns:
        for lag in lags:
            df_lagged[f"{col}_lag_{lag}h"] = df_lagged[col].shift(lag)
    return df_lagged

def add_rolling_features(df, col, windows=(6, 12, 24)):
    """Add trailing rolling mean and std columns for *col*, in place.

    For each window ``w``, adds ``f"{col}_rolling_mean_{w}h"`` and
    ``f"{col}_rolling_std_{w}h"`` using ``df[col].rolling(w)``. The first
    ``w - 1`` rows of each new column are NaN. The "h" suffix assumes one row
    per hour (not checked here).

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to extend. It is mutated and also returned, so callers can
        reassign the result.
    col : column label
        Source column.
    windows : iterable of int
        Window lengths in rows. This is an immutable tuple by default, so no
        list object is shared between calls.

    Returns
    -------
    pandas.DataFrame
        The same *df* object.
    """
    for w in windows:
        df[f"{col}_rolling_mean_{w}h"] = df[col].rolling(w).mean()
        df[f"{col}_rolling_std_{w}h"] = df[col].rolling(w).std()
    return df

def persistence_forecast(train_y, test_y):
    """Naive baseline: predict each test value as the previous observed value.

    The test series is shifted forward by one row. The first prediction, which
    has no earlier test value, takes the last training observation.

    Parameters
    ----------
    train_y, test_y : pandas.Series
        Training target and test target, in time order.

    Returns
    -------
    pandas.Series
        Forecast on the test index.
    """
    forecast = test_y.shift(1)
    if not forecast.empty:
        forecast.iloc[0] = train_y.iloc[-1]
    return forecast

def remove_outliers_iqr(df, col, multiplier=1.5):
    """Cap (winsorize) *col* to the Tukey IQR fences, in place.

    Despite the name, no rows are removed. Values below
    ``Q1 - multiplier * IQR`` or above ``Q3 + multiplier * IQR`` are clipped to
    those fences.

    Returns
    -------
    pandas.DataFrame
        The same *df* object, with *col* replaced by its clipped version.
    """
    series = df[col]
    q1 = series.quantile(0.25)
    q3 = series.quantile(0.75)
    spread = multiplier * (q3 - q1)
    df[col] = series.clip(q1 - spread, q3 + spread)
    return df

def add_temporal_features(df):
    """Return a copy of *df* with calendar features derived from its index.

    Adds ``day_of_week`` (Monday=0) and ``hour_of_day``. It also adds their
    sine/cosine encodings (``hour_sin``, ``hour_cos``, ``dow_sin``, ``dow_cos``),
    which place hour 23 next to hour 0 and Sunday next to Monday. The index
    must be a DatetimeIndex (it needs ``.dayofweek`` and ``.hour``).
    """
    enriched = df.copy()
    enriched['day_of_week'] = enriched.index.dayofweek
    enriched['hour_of_day'] = enriched.index.hour
    cycles = (('hour', 'hour_of_day', 24), ('dow', 'day_of_week', 7))
    for prefix, source, period in cycles:
        angle = 2 * np.pi * enriched[source] / period
        enriched[f'{prefix}_sin'] = np.sin(angle)
        enriched[f'{prefix}_cos'] = np.cos(angle)
    return enriched

def calibrate_intervals(y_true, lower, upper, target_coverage=0.95):
    """Return a multiplier for interval half-widths based on observed coverage.

    Coverage is the share of ``y_true`` values lying within ``[lower, upper]``
    (inclusive); it is printed before returning. The factor is
    ``target_coverage / coverage``, or 1.0 when coverage is not positive.

    NOTE(review): this is a linear heuristic. Coverage does not generally grow
    in proportion to interval width, so the result only approximates the target.
    """
    inside = (y_true >= lower) & (y_true <= upper)
    current_coverage = inside.mean()
    print(f"Current coverage: {current_coverage:.2%}")
    if current_coverage > 0:
        return target_coverage / current_coverage
    return 1.0

# =============================================================================
# 2. MERGE INTO A MODELING DATAFRAME 
# =============================================================================
# Start from the price index. Every assignment below aligns on that index, so
# timestamps missing from a source frame become NaN (rows dropped later).
# Optional features are only added when their source column exists.
# NOTE(review): consumption, generation and flow actuals are taken at the same
# timestamp as the price they help predict. For a genuine day-ahead forecast
# these values would not be known yet — confirm this is intended.
model_df = pd.DataFrame(index=df_prices.index)
model_df["Price_DE_LU"] = df_prices["Germany/Luxembourg [€/MWh]_dayahead"]

# Consumption features
if "Total (grid load) [MWh]_actual_cons" in df_actual_cons.columns:
    model_df["Actual_cons"] = df_actual_cons["Total (grid load) [MWh]_actual_cons"]
if "Residual load [MWh]_actual_cons" in df_actual_cons.columns:
    model_df["Residual_load"] = df_actual_cons["Residual load [MWh]_actual_cons"]
if "Total (grid load) [MWh]_forecast_cons" in df_forecast_cons.columns:
    model_df["Forecast_cons"] = df_forecast_cons["Total (grid load) [MWh]_forecast_cons"]

# Generation features: aggregate the individual sources into two macro-level totals.
renewable_cols = ["Biomass [MWh]_actual_gen", "Hydropower [MWh]_actual_gen",
                  "Wind offshore [MWh]_actual_gen", "Wind onshore [MWh]_actual_gen",
                  "Photovoltaics [MWh]_actual_gen", "Other renewable [MWh]_actual_gen"]
if all(col in df_actual_gen.columns for col in renewable_cols):
    model_df["Total_Renewables"] = df_actual_gen[renewable_cols].sum(axis=1)

conventional_cols = ["Nuclear [MWh]_actual_gen", "Lignite [MWh]_actual_gen",
                     "Hard coal [MWh]_actual_gen", "Fossil gas [MWh]_actual_gen",
                     "Hydro pumped storage [MWh]_actual_gen", "Other conventional [MWh]_actual_gen"]
if all(col in df_actual_gen.columns for col in conventional_cols):
    model_df["Total_Conventionals"] = df_actual_gen[conventional_cols].sum(axis=1)

# Net export features: physical cross-border flows and scheduled commercial exchanges.
if "Net export [MWh]_cross_border" in df_cross_border.columns:
    model_df["Net_export_physical"] = df_cross_border["Net export [MWh]_cross_border"]
if "Net export [MWh]_scheduled_exchanges" in df_scheduled_exchanges.columns:
    model_df["Net_export_scheduled"] = df_scheduled_exchanges["Net export [MWh]_scheduled_exchanges"]

# Add calendar features (day of week, hour of day, cyclical encodings) from the index.
model_df = add_temporal_features(model_df)

# =============================================================================
# 3 OUTLIER DETECTION: Cap outliers using IQR for each feature column
# =============================================================================
# The target (Price_DE_LU) is deliberately NOT capped. All evaluation metrics,
# including extreme-movement detection, must be computed against the real
# prices; clipping the target would erase exactly the spikes being forecast.
# NOTE(review): the IQR fences are computed over the whole series (train + test)
# because the chronological split happens later. Computing them on the training
# period only would keep test-period quantiles out of the features.
for col in model_df.columns.drop("Price_DE_LU"):
    model_df = remove_outliers_iqr(model_df, col)
# Keep only rows where every column (target and features) has a value.
model_df.dropna(inplace=True)

# =============================================================================
# 4. ADVANCED FEATURE ENGINEERING
# =============================================================================

# Add rolling features (trailing mean/std over the last w rows).
# NOTE(review): "Residual_load" and "Net_export_physical" are only added in the
# merge step when their source columns exist; if either is missing, these calls
# raise KeyError.
model_df = add_rolling_features(model_df, "Residual_load", windows=[6, 12, 24])
model_df = add_rolling_features(model_df, "Net_export_physical", windows=[6, 12])
# Drop the warm-up rows whose rolling windows are incomplete (NaN).
model_df.dropna(inplace=True)

# Add lagged features for selected variables.
# NOTE(review): the calendar columns (day_of_week ... dow_cos) are deterministic
# functions of the timestamp. If rows are consecutive hours, their lags add
# little new information and make the linear model's inputs highly redundant —
# consider dropping them from this list.
lag_columns = ["Actual_cons", "Residual_load", "Forecast_cons", "Total_Renewables", 
               "Total_Conventionals", "Net_export_physical", "Net_export_scheduled", 
               "day_of_week", "hour_of_day", "hour_sin", "hour_cos", "dow_sin", "dow_cos"]
model_df_lagged = create_lagged_features(model_df, lag_columns, lags=[1, 2, 3, 6, 12])
# Drop the leading rows whose lags reach before the start of the data (NaN).
model_df_lagged.dropna(inplace=True)

# =============================================================================
# 5. TRAIN/TEST SPLIT (80/20)
# =============================================================================
# Chronological split with no shuffling: the test set is the most recent 20%
# of rows, so every test timestamp comes after the training period.

total_samples = len(model_df_lagged)
split_index = int(total_samples * 0.8)

train_df, test_df = (
    model_df_lagged.iloc[:split_index].copy(),
    model_df_lagged.iloc[split_index:].copy(),
)

target_col = "Price_DE_LU"
X_train, y_train = train_df.drop(columns=[target_col]), train_df[target_col]
X_test, y_test = test_df.drop(columns=[target_col]), test_df[target_col]

print("Training set:", X_train.shape, y_train.shape)
print("Test set:", X_test.shape, y_test.shape)

# =============================================================================
# 6. BASELINE MODEL: PERSISTENCE
# =============================================================================
# "Tomorrow looks like today": each test price is predicted as the previous
# observed price. Every other model should beat this.

y_pred_persistence = persistence_forecast(y_train, y_test)

rmse_pers = sqrt(mean_squared_error(y_test, y_pred_persistence))
mae_pers = mean_absolute_error(y_test, y_pred_persistence)
dir_acc_pers = directional_accuracy(y_test, y_pred_persistence)
vol_cap_pers = volatility_capture(y_test, y_pred_persistence)
prec_ext_pers, rec_ext_pers = extreme_price_movement_detection(y_test, y_pred_persistence)

print("\n".join([
    "\n=== BASELINE (Persistence) ===",
    f"RMSE: {rmse_pers:.2f}",
    f"MAE:  {mae_pers:.2f}",
    f"Directional Accuracy: {dir_acc_pers:.2%}",
    f"Volatility Capture (ratio): {vol_cap_pers:.2f}",
    f"Extreme Movement Detection (Precision/Recall): {prec_ext_pers:.2f} / {rec_ext_pers:.2f}",
]))

# =============================================================================
# 7. LINEAR REGRESSION
# =============================================================================
# Ordinary least squares on all engineered features. fit() returns the
# estimator itself, so fitting and predicting chain into one expression.

linreg = LinearRegression()
y_pred_lin = linreg.fit(X_train, y_train).predict(X_test)
# The pandas-based metric helpers align on the index, so wrap the raw
# predictions once on the test index.
y_pred_lin_series = pd.Series(y_pred_lin, index=y_test.index)

rmse_lin = sqrt(mean_squared_error(y_test, y_pred_lin))
mae_lin = mean_absolute_error(y_test, y_pred_lin)
dir_acc_lin = directional_accuracy(y_test, y_pred_lin_series)
vol_cap_lin = volatility_capture(y_test, y_pred_lin_series)
prec_ext_lin, rec_ext_lin = extreme_price_movement_detection(y_test, y_pred_lin_series)

print("\n".join([
    "\n=== LINEAR REGRESSION ===",
    f"RMSE: {rmse_lin:.2f}",
    f"MAE:  {mae_lin:.2f}",
    f"Directional Accuracy: {dir_acc_lin:.2%}",
    f"Volatility Capture (ratio): {vol_cap_lin:.2f}",
    f"Extreme Movement Detection (Precision/Recall): {prec_ext_lin:.2f} / {rec_ext_lin:.2f}",
]))

# =============================================================================
# 8. Random Forest
# =============================================================================
# Randomized hyperparameter search with time-ordered CV folds: each fold
# validates on rows that come after its training rows.

from sklearn.model_selection import TimeSeriesSplit

# Search space. Lists are sampled uniformly; scipy `randint(low, high)` draws
# integers in [low, high).
param_dist_rf = dict(
    n_estimators=[100, 200, 300, 400],
    max_depth=[3, 5, 7, 10, None],
    min_samples_split=randint(2, 15),
    min_samples_leaf=randint(1, 10),
    max_features=['sqrt', 'log2', None],
)

rf_base = RandomForestRegressor(random_state=42)
tscv = TimeSeriesSplit(n_splits=5)
random_search_rf = RandomizedSearchCV(
    estimator=rf_base,
    param_distributions=param_dist_rf,
    n_iter=30,
    scoring='neg_mean_squared_error',
    cv=tscv,
    random_state=42,
    verbose=1,
)

random_search_rf.fit(X_train, y_train)
print("Best params for Random Forest:", random_search_rf.best_params_)
best_rf = random_search_rf.best_estimator_

y_pred_rf_tuned = best_rf.predict(X_test)
# Wrap once on the test index for the pandas-based metric helpers.
y_pred_rf_series = pd.Series(y_pred_rf_tuned, index=y_test.index)

rmse_rf = sqrt(mean_squared_error(y_test, y_pred_rf_tuned))
mae_rf = mean_absolute_error(y_test, y_pred_rf_tuned)
dir_acc_rf = directional_accuracy(y_test, y_pred_rf_series)
vol_cap_rf = volatility_capture(y_test, y_pred_rf_series)
prec_ext_rf, rec_ext_rf = extreme_price_movement_detection(y_test, y_pred_rf_series)

print("\n".join([
    "\n=== TUNED RANDOM FOREST ===",
    f"RMSE: {rmse_rf:.2f}",
    f"MAE:  {mae_rf:.2f}",
    f"Directional Accuracy: {dir_acc_rf:.2%}",
    f"Volatility Capture (ratio): {vol_cap_rf:.2f}",
    f"Extreme Movement Detection (Precision/Recall): {prec_ext_rf:.2f} / {rec_ext_rf:.2f}",
]))


# ============================================================================================
# 09. Quantile Regression (GradientBoostingRegressor) & Post-Hoc Calibration
# ============================================================================================
# A central 95% prediction interval is bounded by the 2.5th and 97.5th
# percentiles; a 5th/95th pair only brackets 90%.
# Both quantile models are fitted on the earlier part of the training period.
# The most recent 20% of the training rows is held out to calibrate the
# interval width, so the test set is never used to fit the scale factor that
# is later evaluated on it.

print("\nTuning scikit-learn GradientBoostingRegressor for 95% prediction interval...")

cal_index = int(len(X_train) * 0.8)
X_fit, y_fit = X_train.iloc[:cal_index], y_train.iloc[:cal_index]
X_cal, y_cal = X_train.iloc[cal_index:], y_train.iloc[cal_index:]

# Hyperparameters shared by both quantile models.
gbr_quantile_params = dict(loss='quantile', n_estimators=500, learning_rate=0.1,
                           max_depth=3, random_state=42)

# Lower bound: 2.5th percentile
gbr_lower = GradientBoostingRegressor(alpha=0.025, **gbr_quantile_params)
gbr_lower.fit(X_fit, y_fit)

# Upper bound: 97.5th percentile
gbr_upper = GradientBoostingRegressor(alpha=0.975, **gbr_quantile_params)
gbr_upper.fit(X_fit, y_fit)

# Independently fitted quantile models can cross. Order each pair so the
# interval half-width used below is never negative.
cal_lower, cal_upper = gbr_lower.predict(X_cal), gbr_upper.predict(X_cal)
cal_lower, cal_upper = np.minimum(cal_lower, cal_upper), np.maximum(cal_lower, cal_upper)
y_pred_lower, y_pred_upper = gbr_lower.predict(X_test), gbr_upper.predict(X_test)
y_pred_lower, y_pred_upper = np.minimum(y_pred_lower, y_pred_upper), np.maximum(y_pred_lower, y_pred_upper)

coverage = ((y_test >= y_pred_lower) & (y_test <= y_pred_upper)).mean()
print(f"Initial Coverage of 95% prediction interval: {coverage:.2%}")

# Post-hoc calibration: fit the half-width scale factor on the held-out
# calibration rows, using the calibrate_intervals helper from section 1.
scale_factor = calibrate_intervals(y_cal, cal_lower, cal_upper, target_coverage=0.95)
print(f"Scaling factor for intervals: {scale_factor:.2f}")

# Rescale the test intervals symmetrically around their midpoint.
interval_mid = (y_pred_lower + y_pred_upper) / 2.0
interval_half_width = (y_pred_upper - y_pred_lower) / 2.0
y_pred_lower_cal = interval_mid - scale_factor * interval_half_width
y_pred_upper_cal = interval_mid + scale_factor * interval_half_width

coverage_cal = ((y_test >= y_pred_lower_cal) & (y_test <= y_pred_upper_cal)).mean()
print(f"Calibrated Coverage of 95% prediction interval: {coverage_cal:.2%}")

plt.figure(figsize=(12, 6))
plt.plot(y_test.index, y_test, label="Actual Price", color="black", linewidth=1.5)
# The midpoint of two quantile bounds is not a median forecast, so label it as a midpoint.
plt.plot(y_test.index, interval_mid, label="Interval Midpoint", color="blue")
plt.fill_between(y_test.index, y_pred_lower_cal, y_pred_upper_cal, color="gray", alpha=0.3,
                 label="Calibrated 95% Prediction Interval")
plt.title("Calibrated GradientBoostingRegressor Prediction Intervals")
plt.xlabel("Time")
plt.ylabel("Price (€/MWh)")
plt.legend()
plt.tight_layout()
plt.show()


# =============================================================================
# 10. Interpretability with SHAP (For Random Forest)
# =============================================================================

# NOTE(review): OpenMP runtimes usually read OMP_NUM_THREADS only when they
# initialize. Setting it here, after numpy/scikit-learn are already imported,
# may have no effect; set it before those imports or in the shell instead.
os.environ["OMP_NUM_THREADS"] = "4"  # set the number of threads that you want
shap_cache_file = "shap_values_rf.pkl"

# Explain the full test set; replace with a sample if X_test is very large.
X_test_sample = X_test

# Sanity check: the explainer should only see the columns the model was trained on.
extra_columns = set(X_test_sample.columns) - set(X_train.columns)
if extra_columns:
    print("Warning: X_test_sample contains additional columns not in X_train:")
    print(sorted(extra_columns))

# best_rf is re-tuned on every run, so SHAP values cached by an earlier run
# may describe a different model. Always recompute them here. The pickle is
# refreshed only so the values can be reused outside this script.
explainer_rf = shap.TreeExplainer(best_rf)
shap_values_rf = explainer_rf.shap_values(X_test_sample)
with open(shap_cache_file, "wb") as f:
    pickle.dump(shap_values_rf, f)
print("Computed and cached SHAP values.")

# Plot SHAP summary: mean |SHAP| per feature (global importance)
shap.summary_plot(shap_values_rf, X_test_sample, plot_type="bar", max_display=10, show=False)
plt.title("Random Forest Feature Importance (Top 10 via SHAP)")
plt.tight_layout()
plt.show()

# Plot SHAP summary: per-sample SHAP values (beeswarm)
shap.summary_plot(shap_values_rf, X_test_sample, max_display=10, show=False)
plt.title("Random Forest Detailed SHAP Summary (Top 10)")
plt.tight_layout()
plt.show()


# =============================================================================
# 11. Evaluation
# =============================================================================

# One row per point-forecast model. The quantile models from section 09 are
# not listed here; that section only reports interval coverage.
results = [
    ("Persistence", rmse_pers, mae_pers, dir_acc_pers, vol_cap_pers, prec_ext_pers, rec_ext_pers),
    ("LinearReg", rmse_lin, mae_lin, dir_acc_lin, vol_cap_lin, prec_ext_lin, rec_ext_lin),
    ("RandomForest", rmse_rf, mae_rf, dir_acc_rf, vol_cap_rf, prec_ext_rf, rec_ext_rf),

]
results_df = pd.DataFrame(results, columns=[
    "Model", "RMSE", "MAE", "Directional_Accuracy", "Volatility_Capture",
    "Extreme_Precision", "Extreme_Recall"
])
print("\n=== FINAL RESULTS SUMMARY ===")
print(results_df)

# Grouped bar chart comparing the two error metrics across models.
results_df.plot(x="Model", y=["RMSE", "MAE"], kind="bar", figsize=(10, 5),
                title="Error Metrics by Model")
plt.tight_layout()
plt.show()

print("\nScript Complete!")