# CASE 1

In [None]:
import torch
import numpy as np
import pandas as pd
import joblib
import plotly.graph_objects as go

# ---- PATHS ----
# Locations of the saved network weights, the two fitted scalers and the
# merged results CSV that are loaded further down in this cell.
MODEL_PATH = "/new_model_test/best_model_SE2_Down_Volume.pth"
SCALER_X_PATH = "/new_model_test/scaler_X_SE2_Down_Volume.joblib"
SCALER_Y_PATH = "/new_model_test/scaler_y_SE2_Down_Volume.joblib"
FILE_PATH = "/results/results_merged.csv"

# Region code substituted into every column name built below.
REGION = "SE2"
# Column(s) the model predicts; its length also sets the model's output_size.
target_vars = [f"Balancing_ActivatedBalancingEnergy_{REGION}_mFRR_NotSpecifiedDownActivatedVolume"]

# -- Model class --
class MaskedLSTM(torch.nn.Module):
    """LSTM regressor that consumes a feature tensor together with its mask.

    The mask is concatenated to the features along the last axis, so the
    recurrent layer is built with ``2 * input_size`` input channels.  The
    output of the last time step is passed through a ReLU dense layer and a
    linear head, then reshaped to ``(batch, horizon, output_size)``.

    Attribute names (``lstm``, ``fc1``, ``fc2``) are the keys of the
    state_dict and must stay as they are for saved weights to load.
    """

    def __init__(self, input_size, hidden_size=32, num_layers=1, dense_size=32, output_size=1, horizon=1, dropout=0.08, dropout_lstm=0.27, bidirectional=True):
        super().__init__()
        self.horizon = horizon
        self.output_size = output_size

        nn = torch.nn
        n_directions = 2 if bidirectional else 1
        # Inter-layer dropout only exists for stacked LSTMs; pass 0.0 otherwise.
        stacked_dropout = dropout_lstm if num_layers > 1 else 0.0

        self.lstm = nn.LSTM(
            input_size=2 * input_size,   # features + mask channels
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=stacked_dropout,
            bidirectional=bidirectional,
        )
        self.fc1 = nn.Linear(hidden_size * n_directions, dense_size)
        self.fc2 = nn.Linear(dense_size, horizon * output_size)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(dropout)
        self.drop_lstm = nn.Dropout(dropout_lstm)

    def forward(self, x, mask):
        """Return predictions of shape ``(batch, horizon, output_size)``.

        ``x`` and ``mask`` are concatenated on ``dim=2`` and therefore must
        agree on their first two dimensions.
        """
        stacked = torch.cat((x, mask), dim=2)
        seq_out, _ = self.lstm(stacked)
        # Dropout is applied to the full sequence output before the last
        # time step is selected, exactly as the head was trained.
        last_step = self.drop_lstm(seq_out)[:, -1, :]
        hidden = self.drop(self.relu(self.fc1(last_step)))
        flat = self.fc2(hidden)
        return flat.view(flat.size(0), self.horizon, self.output_size)

# Use the GPU when one is available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Fitted scalers: scaler_X transforms the input features, scaler_y is used to
# map predictions back to original units.
# NOTE(review): joblib.load unpickles arbitrary objects – only load trusted files.
scaler_X = joblib.load(SCALER_X_PATH)
scaler_y = joblib.load(SCALER_Y_PATH)

# Hyper-parameters used to rebuild the network; presumably the tuned values
# from training – TODO confirm.  "learning_rate" and "batch_size" are not read
# anywhere in the visible code.
best_params = {
    "hidden_size": 32,
    "num_layers": 1,
    "bidirectional": True,
    "dense_size": 32,
    "dropout": 0.08137184695714653,
    "dropout_lstm": 0.2675438586833995,
    "learning_rate": 0.0013843186375837144,
    "batch_size": 128,
    "seq_length": 216,
}
# Number of consecutive rows (hourly steps after asfreq("h")) in one input window.
seq_length = best_params["seq_length"]

# --- Load and preprocess data ---
df_all = pd.read_csv(FILE_PATH)
price_column_name_to_plot = f"Balancing_PricesOfActivatedBalancingEnergy_{REGION}_mFRR_NotSpecifiedDownPrice"

# Store price data before dropping columns
# (the price column is excluded from the model features below, so a separate
# UTC-indexed copy is kept purely for the secondary plot axis).
price_data_full = None
if price_column_name_to_plot in df_all.columns:
    price_data_full = df_all[['DateTime', price_column_name_to_plot]].copy()
    price_data_full['DateTime'] = pd.to_datetime(price_data_full['DateTime'], utc=True)
    price_data_full.set_index('DateTime', inplace=True)
    print(f"Stored '{price_column_name_to_plot}' separately for plotting.")
else:
    print(f"Warning: Price column '{price_column_name_to_plot}' not found in df_all. Will not be plotted.")

# Project-local helper; presumably returns the list of candidate column names
# for the region – TODO confirm against filter_features.
from filter_features import pick_region_filter
BDZ_filter = pick_region_filter(region=REGION, remove_balancing=True)
# Balancing price/volume columns that must not be used as inputs.
drop_list = [
    f"Balancing_PricesOfActivatedBalancingEnergy_{REGION}_mFRR_NotSpecifiedUpPrice",
    f"Balancing_PricesOfActivatedBalancingEnergy_{REGION}_mFRR_NotSpecifiedDownPrice",
    f"Balancing_ActivatedBalancingEnergy_{REGION}_mFRR_NotSpecifiedUpActivatedVolume",
    f"Balancing_ActivatedBalancingEnergy_{REGION}_mFRR_NotSpecifiedDownActivatedVolume",
]
# Keep a column unless it is in drop_list; target columns are kept even if listed there.
BDZ_filter = [c for c in BDZ_filter if c not in drop_list or c in target_vars]
df_all["DateTime"] = pd.to_datetime(df_all["DateTime"], utc=True)
df_all.set_index("DateTime", inplace=True)
# Reindex to a regular hourly grid; hours absent from the CSV become NaN rows.
df_all = df_all.asfreq("h")
# Raises KeyError if any name in BDZ_filter is not a column of df_all.
df = df_all[BDZ_filter].copy()
# df_all is already on the hourly grid, so this second asfreq leaves the index unchanged.
df = df.asfreq("h")




# Switches for optional engineered features.  All are off here, so the three
# "if" blocks below are skipped and only the raw columns reach the model.
CREATE_INDICATOR_FEATURES = False
CREATE_CYCLICAL_FEATURES = False
CREATE_LAGGED_FEATURES = False

# One 0/1 "<col>_was_missing" column per existing feature/target column.
new_indicator_features = []
if CREATE_INDICATOR_FEATURES:
    indicator_features_to_create = [col for col in BDZ_filter + target_vars if col in df.columns]
    for col in indicator_features_to_create:
        df[f'{col}_was_missing'] = df[col].isnull().astype(int)
    new_indicator_features = [f'{col}_was_missing' for col in indicator_features_to_create]

# Sine/cosine encodings of hour, day of week and ISO week number.
# NOTE(review): ISO weeks can reach 53 while the divisor is 52 – confirm intended.
new_time_features = []
if CREATE_CYCLICAL_FEATURES:
    df['hour_sin'] = np.sin(2 * np.pi * df.index.hour / 24.0)
    df['hour_cos'] = np.cos(2 * np.pi * df.index.hour / 24.0)
    df['dayofweek_sin'] = np.sin(2 * np.pi * df.index.dayofweek / 7.0)
    df['dayofweek_cos'] = np.cos(2 * np.pi * df.index.dayofweek / 7.0)
    df['weekofyear_sin'] = np.sin(2 * np.pi * df.index.isocalendar().week / 52.0)
    df['weekofyear_cos'] = np.cos(2 * np.pi * df.index.isocalendar().week / 52.0)
    new_time_features = ['hour_sin', 'hour_cos', 'dayofweek_sin', 'dayofweek_cos', 'weekofyear_sin', 'weekofyear_cos']

# Target shifted by 1, 24 and 168 rows (hour, day, week on the hourly grid).
new_lagged_features = []
if CREATE_LAGGED_FEATURES:
    for target_var in target_vars:
        df[f'{target_var}_lag1'] = df[target_var].shift(1)
        df[f'{target_var}_lag24'] = df[target_var].shift(24)
        df[f'{target_var}_lag168'] = df[target_var].shift(168)
    for target_var in target_vars:
        new_lagged_features.extend([f'{target_var}_lag1', f'{target_var}_lag24', f'{target_var}_lag168'])

# Drop columns that contain no data at all.
df.dropna(axis=1, how="all", inplace=True)

# Model inputs: every filtered column that is not a target, plus any
# engineered features enabled above.
exogenous_vars = [c for c in BDZ_filter if c not in target_vars]
if CREATE_CYCLICAL_FEATURES:
    exogenous_vars.extend(new_time_features)
if CREATE_LAGGED_FEATURES:
    exogenous_vars.extend(new_lagged_features)
if CREATE_INDICATOR_FEATURES:
    exogenous_vars.extend(new_indicator_features)
# De-duplicate, keep only columns that survived dropna, and sort by name.
# NOTE(review): this sorted order is the column order handed to scaler_X and
# the model – it presumably has to match the order used at training time.
exogenous_vars = sorted(list(set(var for var in exogenous_vars if var in df.columns)))

# --- Instantiate the model ---
# input_size is the number of feature columns; MaskedLSTM doubles it
# internally to make room for the mask channels.
model = MaskedLSTM(
    input_size=len(exogenous_vars),
    hidden_size=best_params["hidden_size"],
    num_layers=best_params["num_layers"],
    dense_size=best_params["dense_size"],
    output_size=len(target_vars),
    dropout=best_params["dropout"],
    dropout_lstm=best_params["dropout_lstm"],
    bidirectional=best_params["bidirectional"],
).to(DEVICE)
# Fails with a size-mismatch error if len(exogenous_vars) or the
# hyper-parameters differ from those the checkpoint was saved with.
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
# Inference mode: disables both Dropout layers.
model.eval()

# --- Window selection ---
# Builds the slice of df needed to predict every hour in
# [desired_start, desired_end]: the requested range plus seq_length rows of
# history in front of it.  Raises ValueError when desired_start is not in the
# index or when fewer than seq_length rows precede it.

desired_start = pd.Timestamp("2023-09-05", tz="UTC")
desired_end = pd.Timestamp("2023-09-08", tz="UTC")

all_dates = df.index
try:
    first_pred_idx = all_dates.get_loc(desired_start)
except KeyError:
    raise ValueError(f"Desired start {desired_start} not in DataFrame index!")

if first_pred_idx < seq_length:
    raise ValueError(f"Not enough data before {desired_start} to build input sequence with seq_length={seq_length}.")

# First row of the look-back history for the first predicted timestamp.
window_start = all_dates[first_pred_idx - seq_length]
window_end = desired_end

# Label-based slice: both window_start and window_end are included.
df_pred_window = df.loc[window_start:window_end].copy()

# -- Scale features --
# NaNs are left in place here; they are turned into mask entries later.
X_scaled = scaler_X.transform(df_pred_window[exogenous_vars])
# Targets stay in original units; they are only used for plotting.
y_true = df_pred_window[target_vars].values

# -- Build sequences --
def create_sequences(X, lookback, horizon=1):
    """Cut a 2-D array into overlapping look-back windows.

    Window ``i`` holds rows ``i`` to ``i + lookback - 1`` of ``X``.  The
    number of windows is ``X.shape[0] - lookback - horizon + 1``.

    Returns a pair ``(X_seq, mask_seq)`` of float32 arrays, both shaped
    ``(num_windows, lookback, X.shape[1])``.  ``mask_seq`` is all ones; the
    caller is responsible for zeroing entries that correspond to NaNs.
    """
    n_windows = X.shape[0] - lookback - horizon + 1
    out_shape = (n_windows, lookback, X.shape[1])
    windows = np.zeros(out_shape, dtype=np.float32)
    all_present = np.ones(out_shape, dtype=np.float32)
    # Each ``slot`` is a view into ``windows``, so the assignment fills it in place.
    for start, slot in enumerate(windows):
        slot[...] = X[start : start + lookback]
    return windows, all_present

# One window per predicted timestamp; horizon is 1.
X_seq, mask_seq = create_sequences(X_scaled, seq_length, 1)
# Missing inputs: flag them in the mask (0 = missing) and feed 0.0 as the value.
nan_pos = np.isnan(X_seq)
mask_seq[nan_pos] = 0.0
X_seq[nan_pos] = 0.0

# Window i covers rows i .. i+seq_length-1, so its prediction belongs to row
# i+seq_length; shift timestamps and targets accordingly.
dates_seq = df_pred_window.index[seq_length : seq_length + len(X_seq)]
y_true_aligned = y_true[seq_length : seq_length + len(X_seq)]

# -- Filter for predictions within the requested window only --
# Boolean selector over predictions (distinct from mask_seq, the feature mask).
mask = (dates_seq >= desired_start) & (dates_seq <= desired_end)
dates_seq = dates_seq[mask]
y_true_aligned = y_true_aligned[mask]

# Single forward pass over all windows, without gradient tracking.
with torch.no_grad():
    X_tensor = torch.tensor(X_seq, dtype=torch.float32).to(DEVICE)
    mask_tensor = torch.tensor(mask_seq, dtype=torch.float32).to(DEVICE)
    y_pred_scaled = model(X_tensor, mask_tensor).cpu().numpy()[:, 0, :]  # (n, target_dim)
# Back to original units, then apply the same date filter as for the targets.
y_pred = scaler_y.inverse_transform(y_pred_scaled)
y_pred = y_pred[mask]

# --- Add price data ---
# Best effort: any failure while aligning prices just disables the price trace.
price_data_to_plot = None
if price_data_full is not None and price_column_name_to_plot in price_data_full.columns:
    try:
        # Align price data with dates_seq
        # (.loc raises if a timestamp of dates_seq is absent from the price index)
        price_data_to_plot = price_data_full.loc[dates_seq, price_column_name_to_plot].values.astype(float)
    except Exception as e:
        print(f"Warning: Could not align price data with dates_seq: {e}")
        price_data_to_plot = None

# -- Plot
# Interactive figure: true vs predicted volume on the left axis and, when
# available, the down price on a secondary right-hand axis.
fig = go.Figure()
fig.add_trace(go.Scatter(
    x=dates_seq,
    y=y_true_aligned[:, 0],
    mode='lines+markers',
    name='True',
    marker=dict(symbol='circle', size=6)
))
fig.add_trace(go.Scatter(
    x=dates_seq,
    y=y_pred[:, 0],
    mode='lines+markers',
    name='Predicted',
    marker=dict(symbol='x', size=6)
))
# The price trace is bound to the secondary axis via yaxis='y2'.
if price_data_to_plot is not None:
    fig.add_trace(go.Scatter(
        x=dates_seq,
        y=price_data_to_plot,
        mode='lines',
        name='mFRR Down Price',
        line=dict(color='orange', width=2),
        yaxis='y2'
    ))
fig.update_layout(
    title=f"True vs Predicted for {target_vars[0]} ({desired_start.date()} to {desired_end.date()})",
    xaxis_title="Date",
    yaxis=dict(
        title="Volume (MW)"
    ),
    # yaxis2 is overlaid on the primary axis and drawn on the right, without grid lines.
    yaxis2=dict(
        title="mFRR Down Price (€/MWh)",
        overlaying='y',
        side='right',
        showgrid=False
    ),
    template="plotly_white",
    # Horizontal legend placed just above the plot area, right-aligned.
    legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
)
fig.update_xaxes(tickformat="%d %b %Y %H:%M")
fig.show()


## Case 1 Plot

In [46]:
# Prediction window (UTC) used by the following cell, which no longer
# defines these two variables itself.
desired_start = pd.Timestamp("2023-09-05", tz="UTC")
desired_end = pd.Timestamp("2023-09-07", tz="UTC")

In [None]:
import torch
import numpy as np
import pandas as pd
import joblib
import plotly.graph_objects as go

# ---- PATHS ----
# Locations of the saved network weights, the two fitted scalers and the
# merged results CSV that are loaded further down in this cell.
MODEL_PATH = "/new_model_test/best_model_SE2_Down_Volume.pth"
SCALER_X_PATH = "/new_model_test/scaler_X_SE2_Down_Volume.joblib"
SCALER_Y_PATH = "/new_model_test/scaler_y_SE2_Down_Volume.joblib"
FILE_PATH = "/results/results_merged.csv"

# Region code substituted into every column name built below.
REGION = "SE2"
# Column(s) the model predicts; its length also sets the model's output_size.
target_vars = [f"Balancing_ActivatedBalancingEnergy_{REGION}_mFRR_NotSpecifiedDownActivatedVolume"]

# -- Model class --
class MaskedLSTM(torch.nn.Module):
    """LSTM regressor that consumes a feature tensor together with its mask.

    The mask is concatenated to the features along the last axis, so the
    recurrent layer is built with ``2 * input_size`` input channels.  The
    output of the last time step is passed through a ReLU dense layer and a
    linear head, then reshaped to ``(batch, horizon, output_size)``.

    Attribute names (``lstm``, ``fc1``, ``fc2``) are the keys of the
    state_dict and must stay as they are for saved weights to load.
    """

    def __init__(self, input_size, hidden_size=32, num_layers=1, dense_size=32, output_size=1, horizon=1, dropout=0.08, dropout_lstm=0.27, bidirectional=True):
        super().__init__()
        self.horizon = horizon
        self.output_size = output_size

        nn = torch.nn
        n_directions = 2 if bidirectional else 1
        # Inter-layer dropout only exists for stacked LSTMs; pass 0.0 otherwise.
        stacked_dropout = dropout_lstm if num_layers > 1 else 0.0

        self.lstm = nn.LSTM(
            input_size=2 * input_size,   # features + mask channels
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=stacked_dropout,
            bidirectional=bidirectional,
        )
        self.fc1 = nn.Linear(hidden_size * n_directions, dense_size)
        self.fc2 = nn.Linear(dense_size, horizon * output_size)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(dropout)
        self.drop_lstm = nn.Dropout(dropout_lstm)

    def forward(self, x, mask):
        """Return predictions of shape ``(batch, horizon, output_size)``.

        ``x`` and ``mask`` are concatenated on ``dim=2`` and therefore must
        agree on their first two dimensions.
        """
        stacked = torch.cat((x, mask), dim=2)
        seq_out, _ = self.lstm(stacked)
        # Dropout is applied to the full sequence output before the last
        # time step is selected, exactly as the head was trained.
        last_step = self.drop_lstm(seq_out)[:, -1, :]
        hidden = self.drop(self.relu(self.fc1(last_step)))
        flat = self.fc2(hidden)
        return flat.view(flat.size(0), self.horizon, self.output_size)

# Use the GPU when one is available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Fitted scalers: scaler_X transforms the input features, scaler_y is used to
# map predictions back to original units.
# NOTE(review): joblib.load unpickles arbitrary objects – only load trusted files.
scaler_X = joblib.load(SCALER_X_PATH)
scaler_y = joblib.load(SCALER_Y_PATH)

# Hyper-parameters used to rebuild the network; presumably the tuned values
# from training – TODO confirm.  "learning_rate" and "batch_size" are not read
# anywhere in the visible code.
best_params = {
    "hidden_size": 32,
    "num_layers": 1,
    "bidirectional": True,
    "dense_size": 32,
    "dropout": 0.08137184695714653,
    "dropout_lstm": 0.2675438586833995,
    "learning_rate": 0.0013843186375837144,
    "batch_size": 128,
    "seq_length": 216,
}
# Number of consecutive rows (hourly steps after asfreq("h")) in one input window.
seq_length = best_params["seq_length"]

# --- Load and preprocess data ---
df_all = pd.read_csv(FILE_PATH)
price_column_name_to_plot = f"Balancing_PricesOfActivatedBalancingEnergy_{REGION}_mFRR_NotSpecifiedDownPrice"

# Store price data before dropping columns
# (the price column is excluded from the model features below, so a separate
# UTC-indexed copy is kept purely for the secondary plot axis).
price_data_full = None
if price_column_name_to_plot in df_all.columns:
    price_data_full = df_all[['DateTime', price_column_name_to_plot]].copy()
    price_data_full['DateTime'] = pd.to_datetime(price_data_full['DateTime'], utc=True)
    price_data_full.set_index('DateTime', inplace=True)
    print(f"Stored '{price_column_name_to_plot}' separately for plotting.")
else:
    print(f"Warning: Price column '{price_column_name_to_plot}' not found in df_all. Will not be plotted.")

# Project-local helper; presumably returns the list of candidate column names
# for the region – TODO confirm against filter_features.
from filter_features import pick_region_filter
BDZ_filter = pick_region_filter(region=REGION, remove_balancing=True)
# Balancing price/volume columns that must not be used as inputs.
drop_list = [
    f"Balancing_PricesOfActivatedBalancingEnergy_{REGION}_mFRR_NotSpecifiedUpPrice",
    f"Balancing_PricesOfActivatedBalancingEnergy_{REGION}_mFRR_NotSpecifiedDownPrice",
    f"Balancing_ActivatedBalancingEnergy_{REGION}_mFRR_NotSpecifiedUpActivatedVolume",
    f"Balancing_ActivatedBalancingEnergy_{REGION}_mFRR_NotSpecifiedDownActivatedVolume",
]
# Keep a column unless it is in drop_list; target columns are kept even if listed there.
BDZ_filter = [c for c in BDZ_filter if c not in drop_list or c in target_vars]
df_all["DateTime"] = pd.to_datetime(df_all["DateTime"], utc=True)
df_all.set_index("DateTime", inplace=True)
# Reindex to a regular hourly grid; hours absent from the CSV become NaN rows.
df_all = df_all.asfreq("h")
# Raises KeyError if any name in BDZ_filter is not a column of df_all.
df = df_all[BDZ_filter].copy()
# df_all is already on the hourly grid, so this second asfreq leaves the index unchanged.
df = df.asfreq("h")

def fill_ahead_features(df: pd.DataFrame) -> pd.DataFrame:
    """
    Forward-fill 0/NaN gaps in *-ahead* columns within their natural period.

    Columns are picked by a case-insensitive substring match on their name and
    grouped as follows:

    * ``dayahead``   – calendar days
    * ``weekahead``  – weeks running Monday to Sunday
    * ``monthahead`` – calendar months
    * ``yearahead``  – calendar years

    Inside each period the most recent non-zero, non-NaN value is carried
    forward into later slots whose original entry was 0 or NaN.  Nothing is
    carried across a period boundary.  Slots at the start of a period that have
    no earlier value to inherit keep their original entry: a leading 0 stays 0
    and a leading NaN stays NaN, so missing data is not turned into zeros.

    Parameters
    ----------
    df : pd.DataFrame
        Frame with a ``DatetimeIndex`` (required by ``pd.Grouper(freq=...)``)
        and string column names.

    Returns
    -------
    pd.DataFrame
        A modified copy; the input frame is not mutated.
    """
    df = df.copy()

    period_freq = {
        "dayahead": "D",
        # Weekly bins are closed on their anchor day, so "W-SUN" yields
        # Monday..Sunday weeks.  ("W-MON" would group Tuesday..Monday.)
        "weekahead": "W-SUN",
        "monthahead": "MS",
        # "YS" is the year-start alias; the older spelling "AS" is deprecated.
        "yearahead": "YS",
    }

    for period, freq in period_freq.items():
        ahead_cols = [c for c in df.columns if period in c.lower()]

        for col in ahead_cols:
            original = df[col]

            # Treat zeros as gaps, then carry the last real value forward
            # without leaving the current period.
            carried = (
                original.replace(0, np.nan)
                .groupby(pd.Grouper(freq=freq))
                .ffill()
            )

            # Real entries are unchanged by the steps above; slots that had
            # nothing to inherit fall back to their original 0 or NaN.
            df[col] = carried.where(carried.notna(), original)

    return df

# 


# Switches for optional engineered features.  All are off here, so the three
# "if" blocks below are skipped and only the raw columns reach the model.
CREATE_INDICATOR_FEATURES = False
CREATE_CYCLICAL_FEATURES = False
CREATE_LAGGED_FEATURES = False

# One 0/1 "<col>_was_missing" column per existing feature/target column.
new_indicator_features = []
if CREATE_INDICATOR_FEATURES:
    indicator_features_to_create = [col for col in BDZ_filter + target_vars if col in df.columns]
    for col in indicator_features_to_create:
        df[f'{col}_was_missing'] = df[col].isnull().astype(int)
    new_indicator_features = [f'{col}_was_missing' for col in indicator_features_to_create]

# Sine/cosine encodings of hour, day of week and ISO week number.
# NOTE(review): ISO weeks can reach 53 while the divisor is 52 – confirm intended.
new_time_features = []
if CREATE_CYCLICAL_FEATURES:
    df['hour_sin'] = np.sin(2 * np.pi * df.index.hour / 24.0)
    df['hour_cos'] = np.cos(2 * np.pi * df.index.hour / 24.0)
    df['dayofweek_sin'] = np.sin(2 * np.pi * df.index.dayofweek / 7.0)
    df['dayofweek_cos'] = np.cos(2 * np.pi * df.index.dayofweek / 7.0)
    df['weekofyear_sin'] = np.sin(2 * np.pi * df.index.isocalendar().week / 52.0)
    df['weekofyear_cos'] = np.cos(2 * np.pi * df.index.isocalendar().week / 52.0)
    new_time_features = ['hour_sin', 'hour_cos', 'dayofweek_sin', 'dayofweek_cos', 'weekofyear_sin', 'weekofyear_cos']

# Target shifted by 1, 24 and 168 rows (hour, day, week on the hourly grid).
new_lagged_features = []
if CREATE_LAGGED_FEATURES:
    for target_var in target_vars:
        df[f'{target_var}_lag1'] = df[target_var].shift(1)
        df[f'{target_var}_lag24'] = df[target_var].shift(24)
        df[f'{target_var}_lag168'] = df[target_var].shift(168)
    for target_var in target_vars:
        new_lagged_features.extend([f'{target_var}_lag1', f'{target_var}_lag24', f'{target_var}_lag168'])

# Drop columns that contain no data at all.
df.dropna(axis=1, how="all", inplace=True)

# Model inputs: every filtered column that is not a target, plus any
# engineered features enabled above.
exogenous_vars = [c for c in BDZ_filter if c not in target_vars]
if CREATE_CYCLICAL_FEATURES:
    exogenous_vars.extend(new_time_features)
if CREATE_LAGGED_FEATURES:
    exogenous_vars.extend(new_lagged_features)
if CREATE_INDICATOR_FEATURES:
    exogenous_vars.extend(new_indicator_features)
# De-duplicate, keep only columns that survived dropna, and sort by name.
# NOTE(review): this sorted order is the column order handed to scaler_X and
# the model – it presumably has to match the order used at training time.
exogenous_vars = sorted(list(set(var for var in exogenous_vars if var in df.columns)))

# --- Instantiate the model ---
# input_size is the number of feature columns; MaskedLSTM doubles it
# internally to make room for the mask channels.
model = MaskedLSTM(
    input_size=len(exogenous_vars),
    hidden_size=best_params["hidden_size"],
    num_layers=best_params["num_layers"],
    dense_size=best_params["dense_size"],
    output_size=len(target_vars),
    dropout=best_params["dropout"],
    dropout_lstm=best_params["dropout_lstm"],
    bidirectional=best_params["bidirectional"],
).to(DEVICE)
# Fails with a size-mismatch error if len(exogenous_vars) or the
# hyper-parameters differ from those the checkpoint was saved with.
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
# Inference mode: disables both Dropout layers.
model.eval()

# --- Window selection ---
# Builds the slice of df needed to predict every hour in
# [desired_start, desired_end]: the requested range plus seq_length rows of
# history in front of it.  Raises ValueError when desired_start is not in the
# index or when fewer than seq_length rows precede it.

# desired_start / desired_end are not assigned in this cell; they must already
# exist in the session (they are set in the small cell that precedes this one).

all_dates = df.index
try:
    first_pred_idx = all_dates.get_loc(desired_start)
except KeyError:
    raise ValueError(f"Desired start {desired_start} not in DataFrame index!")

if first_pred_idx < seq_length:
    raise ValueError(f"Not enough data before {desired_start} to build input sequence with seq_length={seq_length}.")

# First row of the look-back history for the first predicted timestamp.
window_start = all_dates[first_pred_idx - seq_length]
window_end = desired_end

# Label-based slice: both window_start and window_end are included.
df_pred_window = df.loc[window_start:window_end].copy()

# -- Scale features --
# NaNs are left in place here; they are turned into mask entries later.
X_scaled = scaler_X.transform(df_pred_window[exogenous_vars])
# Targets stay in original units; they are only used for plotting.
y_true = df_pred_window[target_vars].values

# -- Build sequences --
def create_sequences(X, lookback, horizon=1):
    """Cut a 2-D array into overlapping look-back windows.

    Window ``i`` holds rows ``i`` to ``i + lookback - 1`` of ``X``.  The
    number of windows is ``X.shape[0] - lookback - horizon + 1``.

    Returns a pair ``(X_seq, mask_seq)`` of float32 arrays, both shaped
    ``(num_windows, lookback, X.shape[1])``.  ``mask_seq`` is all ones; the
    caller is responsible for zeroing entries that correspond to NaNs.
    """
    n_windows = X.shape[0] - lookback - horizon + 1
    out_shape = (n_windows, lookback, X.shape[1])
    windows = np.zeros(out_shape, dtype=np.float32)
    all_present = np.ones(out_shape, dtype=np.float32)
    # Each ``slot`` is a view into ``windows``, so the assignment fills it in place.
    for start, slot in enumerate(windows):
        slot[...] = X[start : start + lookback]
    return windows, all_present

# One window per predicted timestamp; horizon is 1.
X_seq, mask_seq = create_sequences(X_scaled, seq_length, 1)
# Missing inputs: flag them in the mask (0 = missing) and feed 0.0 as the value.
nan_pos = np.isnan(X_seq)
mask_seq[nan_pos] = 0.0
X_seq[nan_pos] = 0.0

# Window i covers rows i .. i+seq_length-1, so its prediction belongs to row
# i+seq_length; shift timestamps and targets accordingly.
dates_seq = df_pred_window.index[seq_length : seq_length + len(X_seq)]
y_true_aligned = y_true[seq_length : seq_length + len(X_seq)]

# -- Filter for predictions within the requested window only --
# Boolean selector over predictions (distinct from mask_seq, the feature mask).
mask = (dates_seq >= desired_start) & (dates_seq <= desired_end)
dates_seq = dates_seq[mask]
y_true_aligned = y_true_aligned[mask]

# Single forward pass over all windows, without gradient tracking.
with torch.no_grad():
    X_tensor = torch.tensor(X_seq, dtype=torch.float32).to(DEVICE)
    mask_tensor = torch.tensor(mask_seq, dtype=torch.float32).to(DEVICE)
    y_pred_scaled = model(X_tensor, mask_tensor).cpu().numpy()[:, 0, :]  # (n, target_dim)
# Back to original units, then apply the same date filter as for the targets.
y_pred = scaler_y.inverse_transform(y_pred_scaled)
y_pred = y_pred[mask]

# --- Add price data ---
# Best effort: any failure while aligning prices just disables the price trace.
price_data_to_plot = None
if price_data_full is not None and price_column_name_to_plot in price_data_full.columns:
    try:
        # Align price data with dates_seq
        # (.loc raises if a timestamp of dates_seq is absent from the price index)
        price_data_to_plot = price_data_full.loc[dates_seq, price_column_name_to_plot].values.astype(float)
    except Exception as e:
        print(f"Warning: Could not align price data with dates_seq: {e}")
        price_data_to_plot = None

import matplotlib.pyplot as plt
import matplotlib.dates as mdates

# --- Colors and markers ---
true_color = "#636efa"    # Plotly default blue
pred_color = "#ef553b"    # Plotly default red
price_color = "orange"

# --- Matplotlib plot ---
# Static counterpart of the interactive figure: true vs predicted volume on
# the left axis and, when available, the down price on a twin right axis.
fig, ax1 = plt.subplots(figsize=(12, 7))

# True values (now drawn in true_color, which was defined but never applied)
ax1.plot(dates_seq, y_true_aligned[:, 0], label='True', color=true_color, marker='o', markersize=5, linewidth=2)

# Predicted values
ax1.plot(dates_seq, y_pred[:, 0], label='Predicted', color=pred_color, marker='x', markersize=5, linewidth=2)

ax1.set_xlabel("Date", fontsize=17)
ax1.set_ylabel("Volume (MW)", fontsize=17)
ax1.tick_params(axis='y', labelsize=16)
ax1.tick_params(axis='x', labelsize=16)
ax1.xaxis.set_major_formatter(mdates.DateFormatter('%d %b %Y\n%H:%M'))
# These pyplot calls act on the current axes, which is still ax1 here because
# the twin axis has not been created yet.
plt.xticks(rotation=45)
plt.grid(axis='y', alpha=0.3)

# Add price on secondary y-axis
if price_data_to_plot is not None:
    ax2 = ax1.twinx()
    ax2.plot(dates_seq, price_data_to_plot, label='mFRR Down Price', color=price_color, linewidth=2)
    ax2.set_ylabel("mFRR Down Price (€/MWh)", fontsize=17, color=price_color)
    ax2.tick_params(axis='y', labelcolor=price_color, labelsize=16)
    ax2.grid(False)
else:
    ax2 = None

# Combined legend: gather handles and labels from every axis that exists and
# concatenate them so a single legend box covers both y-axes.
lines_labels = [ax.get_legend_handles_labels() for ax in [ax1, ax2] if ax is not None]
lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]
plt.legend(lines, labels, loc='upper right', fontsize=16, frameon=True)

plt.tight_layout()
plt.show()



## Feature importance

In [48]:
# Attribution window, as strings.  The attribution cells turn these into an
# hourly, UTC-localised date range.
ATTR_START = "2023-09-05 19:00"
ATTR_END   = "2023-09-06 08:00"

from datetime import datetime
# Naive (timezone-unaware) datetime versions of the same bounds.
# NOTE(review): not referenced by the attribution cells visible in this file.
ATTR_START_dt = datetime.strptime(ATTR_START, "%Y-%m-%d %H:%M")
ATTR_END_dt   = datetime.strptime(ATTR_END, "%Y-%m-%d %H:%M")

## IG (Integrated Gradients)

In [None]:
attr_summaries = {}       # master container: method name -> {baseline name -> per-feature importance array}
# 


import os
import numpy as np
import torch
from captum.attr import IntegratedGradients
import matplotlib.pyplot as plt
import pandas as pd                 # already imported above

# ------------------------------------------------------------
# 0.  CONFIG & OUTPUT DIR
# ------------------------------------------------------------
plots_dir = "attribution_plots"
os.makedirs(plots_dir, exist_ok=True)

# ATTR_START / ATTR_END are expected to be defined by an earlier cell.
# Hourly prediction timestamps (UTC) to explain, both ends included.
attribution_times = pd.date_range(
    ATTR_START, ATTR_END, freq="h", tz="UTC"
)
print("Attribution times:", attribution_times)

# ------------------------------------------------------------
# 1.  BUILD INPUT / MASK TENSORS *** with the SAME preprocessing as inference ***
# ------------------------------------------------------------
# One (seq_length × features) window per timestamp; timestamps that are
# missing from df or lack enough history are skipped with a warning.
input_seqs, mask_seqs = [], []

for t in attribution_times:
    if t not in df.index:
        print(f"⚠️  {t} missing in df.index – skipped");   continue
    idx = df.index.get_loc(t)
    if idx < seq_length:
        print(f"⚠️  Not enough history before {t} – skipped");   continue

    # raw window (seq_len × features) --------------------------
    # rows idx-seq_length .. idx-1, i.e. the history that precedes t
    window_raw = df.iloc[idx - seq_length: idx][exogenous_vars].values

    # ---> scale exactly like in inference
    window_scaled = scaler_X.transform(window_raw)

    # ---> build mask (1 = present, 0 = missing)
    # NOTE: rebinds the name `mask` used as a date selector in the inference cell.
    mask = (~np.isnan(window_scaled)).astype(np.float32)

    # ---> impute NaNs with 0  (model expects this)
    window_scaled[np.isnan(window_scaled)] = 0.0

    input_seqs.append(window_scaled.astype(np.float32))
    mask_seqs.append(mask)

if not input_seqs:
    raise RuntimeError("No valid windows after filtering.")

X_attr      = torch.tensor(np.stack(input_seqs)).to(DEVICE)       # (N, L, F)
X_mask_attr = torch.tensor(np.stack(mask_seqs)).to(DEVICE)        # (N, L, F)

# quick sanity-check
print("Any NaNs after preprocessing?  X:", torch.isnan(X_attr).any().item(),
      " mask:", torch.isnan(X_mask_attr).any().item())

# ------------------------------------------------------------
# 2.  BASELINES  (in scaled space, no NaNs)
# ------------------------------------------------------------
# Each baseline has shape (1, L, F) and is broadcast over the N samples later.
zeros_X  = torch.zeros_like(X_attr[:1])
# Element-wise median across the N attribution windows.
median_X = torch.median(X_attr, dim=0, keepdim=True).values

# --- Compute mean baseline (classic mean) --------------------
# Mean over the already scaled, zero-imputed windows.
mean_X = torch.mean(X_attr, dim=0, keepdim=True)
mean_M = X_mask_attr[:1]                    # reuse mask for consistency

# Sanity check
assert not torch.isnan(mean_X).any(), "mean baseline still has NaNs!"
assert mean_X.shape  == (1, seq_length, len(exogenous_vars))
assert mean_M.shape  == (1, seq_length, len(exogenous_vars))

# name -> (feature baseline, mask baseline).
# NOTE(review): every baseline uses the FIRST sample's mask; for samples whose
# own mask differs, part of the output change is attributed to the mask input
# rather than to the features – confirm this is intended.
baselines = {
    "zeros": (zeros_X,  X_mask_attr[:1]),
    "median": (median_X, X_mask_attr[:1]),
    "mean":   (mean_X,   mean_M),
}

# ------------------------------------------------------------
# 3.  INTEGRATED GRADIENTS  – compatible with Captum < 0.7
# ------------------------------------------------------------
def model_forward(x, m):
    """Forward the (features, mask) pair to the global ``model``."""
    return model(x, m)            # (batch, 1, 1)

ig = IntegratedGradients(model_forward)

# baseline name -> feature attributions as a NumPy array of shape (N, L, F)
attr_results = {}

# The model is in eval() mode, and cuDNN's fused RNN kernels raise
# "cudnn RNN backward can only be called in training mode" when gradients are
# requested in that state.  Disabling cuDNN for the attribution pass avoids
# the error on CUDA and has no effect on CPU.
with torch.backends.cudnn.flags(enabled=False):
    for name, (base_x, base_m) in baselines.items():
        # >>>> the 2-tuple style works in every Captum version
        # Baselines are (1, L, F); expand them to match the N input samples.
        (attr_x, _attr_mask), delta = ig.attribute(
            inputs=(X_attr, X_mask_attr),
            baselines=(base_x.expand_as(X_attr),
                       base_m.expand_as(X_mask_attr)),
            target=0,
            n_steps=64,
            method="riemann_trapezoid",
            internal_batch_size=32,
            return_convergence_delta=True
        )

        print(f"{name:9s} | attr range [{attr_x.min():.2e}, {attr_x.max():.2e}] "
              f"δ-mean={delta.abs().mean():.2e}")

        # Only the feature attributions are kept; the mask attributions are
        # discarded.  detach() mirrors the DeepLift SHAP cell.
        attr_results[name] = attr_x.detach().cpu().numpy()   # (N, L, F) – keep only X


# ------------------------------------------------------------
# 4.  AGGREGATE IMPORTANCE  (unchanged)
#     • |attr|  →  mean over (sample, time)  →  normalise to shares
# ------------------------------------------------------------
attr_summary = {}
for name, a in attr_results.items():
    importance = np.mean(np.abs(a), axis=(0, 1))  # (F,)
    # Shares that sum to ~1; the ×100 to percent happens at plot time.
    importance /= importance.sum() + 1e-12
    attr_summary[name] = importance

# ------------------------------------------------------------
# 5.  PLOT
# ------------------------------------------------------------
# Layout: one bar chart per baseline in the top row and the matching
# "ID -> feature name" table directly below it (2 × 3 grid, so at most three
# baselines are supported).
top_k   = 10
colors  = ["#3498db", "#2ecc71", "#e74c3c"]   # enough for 3 baselines
fig, axes = plt.subplots(2, 3, figsize=(20, 8))
axes = axes.flatten()

for i, (name, imp) in enumerate(attr_summary.items()):
    # Never ask for more bars than there are features; otherwise the bar
    # positions and heights would have different lengths.
    k = min(top_k, len(imp))
    idx_sorted = np.argsort(imp)[-k:][::-1]
    feature_indices = idx_sorted      # indices into exogenous_vars
    ids = [f"#{idx+1}" for idx in feature_indices]   # consistent ID = index+1

    # --- bars ---
    ax = axes[i]
    ax.bar(range(k), imp[feature_indices] * 100, color=colors[i])       # %
    ax.set_xticks(range(k))
    ax.set_xticklabels(ids)
    ax.set_xlabel("Feature ID")
    ax.set_ylabel("Contribution (%)")
    ax.set_title(f"Integrated Gradients – {name}")

    # --- table ---
    ax_t = axes[i + 3]
    table_data = [[ids[j], exogenous_vars[feature_indices[j]]]
                  for j in range(k)]
    ax_t.axis("off")
    ax_t.set_title(f"{name}  mapping", pad=10)
    tbl = ax_t.table(
        cellText=table_data, colLabels=["ID", "Feature"],
        cellLoc="left", loc="center",
        colWidths=[0.07, 0.93],
    )
    tbl.auto_set_font_size(False)
    tbl.set_fontsize(9)
    tbl.scale(1.2, 1.2)


plt.suptitle("Feature Attribution (Integrated Gradients)\n"
             f"{ATTR_START} – {ATTR_END}", fontsize=16)
# Leave the top 4 % of the figure free for the suptitle.
plt.tight_layout(rect=[0, 0, 1, 0.96])

out_path = os.path.join(plots_dir, "integrated_gradients_selected_window.png")
plt.savefig(out_path, dpi=300, bbox_inches="tight")
plt.show()
print("Saved ➜", out_path)

# Store copies so later cells that rebuild attr_summary do not alter these.
attr_summaries["IG"] = {k: v.copy() for k, v in attr_summary.items()}


## DL (DeepLift SHAP)

In [None]:
import os, numpy as np, torch, matplotlib.pyplot as plt, pandas as pd
from captum.attr import DeepLiftShap

# ------------------------------------------------------------
# 0.  CONFIG & OUTPUT DIR
# ------------------------------------------------------------
plots_dir = "attribution_plots";  os.makedirs(plots_dir, exist_ok=True)
# ATTR_START / ATTR_END are expected to be defined by an earlier cell.
# Hourly prediction timestamps (UTC) to explain, both ends included.
attribution_times = pd.date_range(ATTR_START, ATTR_END, freq="h", tz="UTC")
print("Attribution times:", attribution_times)

# ------------------------------------------------------------
# 1.  BUILD INPUT / MASK TENSORS (unchanged)
# ------------------------------------------------------------
# One (seq_length × features) window per timestamp: scale, derive the
# presence mask (1 = present, 0 = missing), then zero-impute the NaNs.
# Timestamps missing from df or lacking enough history are skipped.
input_seqs, mask_seqs = [], []
for t in attribution_times:
    if t not in df.index:
        print(f"⚠️  {t} missing in df.index – skipped");   continue
    idx = df.index.get_loc(t)
    if idx < seq_length:
        print(f"⚠️  Not enough history before {t} – skipped");   continue
    # rows idx-seq_length .. idx-1, i.e. the history that precedes t
    window_raw = df.iloc[idx - seq_length: idx][exogenous_vars].values
    window_scaled = scaler_X.transform(window_raw)
    mask = (~np.isnan(window_scaled)).astype(np.float32)
    window_scaled[np.isnan(window_scaled)] = 0.0
    input_seqs.append(window_scaled.astype(np.float32));  mask_seqs.append(mask)

if not input_seqs:  raise RuntimeError("No valid windows after filtering.")

X_attr      = torch.tensor(np.stack(input_seqs)).to(DEVICE)      # (N, L, F)
X_mask_attr = torch.tensor(np.stack(mask_seqs)).to(DEVICE)       # (N, L, F)
print("Any NaNs? X:", torch.isnan(X_attr).any().item(),
      "mask:", torch.isnan(X_mask_attr).any().item())

# ------------------------------------------------------------
# 2.  BASELINES (mean instead of pre-spike)
# ------------------------------------------------------------
# Each baseline has shape (1, L, F) and is broadcast over the N samples later.
zeros_X  = torch.zeros_like(X_attr[:1])
# Element-wise median across the N attribution windows.
median_X = torch.median(X_attr, dim=0, keepdim=True).values

# Compute mean baseline over all valid attribution windows
mean_raw = []
for t in attribution_times:
    # Apply the same validity rules as the input-building loop; without the
    # membership check, get_loc raises KeyError for a timestamp absent from df.
    if t not in df.index:
        continue
    idx = df.index.get_loc(t)
    if idx < seq_length:
        continue
    window_raw = df.iloc[idx - seq_length: idx][exogenous_vars]
    mean_raw.append(window_raw.values)
if not mean_raw:
    raise ValueError("No valid windows to compute mean baseline.")
# Averaging is done in raw (unscaled) space and scaled afterwards.
# NOTE(review): np.mean propagates NaN, so a position is marked missing in the
# baseline as soon as ANY window is missing there – confirm that is intended
# (np.nanmean would only mark positions missing in every window).
mean_raw = np.mean(np.stack(mean_raw), axis=0)  # shape: (L, F)
mean_scaled = scaler_X.transform(mean_raw)
mask_mean = (~np.isnan(mean_scaled)).astype(np.float32)
mean_scaled[np.isnan(mean_scaled)] = 0.0
baseline_mean_X = torch.tensor(mean_scaled, dtype=torch.float32).unsqueeze(0).to(DEVICE)
baseline_mean_M = torch.tensor(mask_mean,        dtype=torch.float32).unsqueeze(0).to(DEVICE)

# name -> (feature baseline, mask baseline); "zeros" and "median" reuse the
# first sample's mask, "mean" carries its own mask.
baselines = {
    "zeros":     (zeros_X,        X_mask_attr[:1]),
    "median":    (median_X,       X_mask_attr[:1]),
    "mean":      (baseline_mean_X, baseline_mean_M),
}

# ------------------------------------------------------------
# 3.  DEEPLIFT SHAP (wrapped model)
# ------------------------------------------------------------
class Wrapper(torch.nn.Module):
    """Thin ``nn.Module`` shell that forwards (features, mask) to ``core``."""

    def __init__(self, core):
        super().__init__()
        self.core = core

    def forward(self, x, m):
        return self.core(x, m)

dlshap = DeepLiftShap(Wrapper(model))

# baseline name -> feature attributions as a NumPy array of shape (N, L, F)
attr_results = {}

# The model is in eval() mode, and cuDNN's fused RNN kernels raise
# "cudnn RNN backward can only be called in training mode" when gradients are
# requested in that state.  Disabling cuDNN for the attribution pass avoids
# the error on CUDA and has no effect on CPU.
with torch.backends.cudnn.flags(enabled=False):
    for name, (base_x, base_m) in baselines.items():
        # Baselines are (1, L, F); expanding them to N rows gives DeepLiftShap
        # a baseline set with one entry per input sample.
        (attr_x, _attr_m), delta = dlshap.attribute(
            inputs=(X_attr, X_mask_attr),
            baselines=(base_x.expand_as(X_attr),
                       base_m.expand_as(X_mask_attr)),
            target=0,
            return_convergence_delta=True
        )
        print(f"{name:9s} | attr range "
              f"[{attr_x.min():.2e}, {attr_x.max():.2e}] "
              f"δ-mean={delta.abs().mean():.2e}")
        # Only the feature attributions are kept; mask attributions are dropped.
        attr_results[name] = attr_x.detach().cpu().numpy()   # (N, L, F)

# ------------------------------------------------------------
# 4.  AGGREGATE IMPORTANCE (unchanged)
#     |attr| -> mean over (sample, time) -> shares that sum to ~1
# ------------------------------------------------------------
attr_summary = {}
for name, a in attr_results.items():
    importance = np.mean(np.abs(a), axis=(0, 1))
    importance /= importance.sum() + 1e-12
    attr_summary[name] = importance

# ------------------------------------------------------------
# 5.  PLOT (only titles & file name changed)
# ------------------------------------------------------------
# Layout: one bar chart per baseline in the top row and the matching
# "ID -> feature name" table directly below it (2 × 3 grid, so at most three
# baselines are supported).
top_k, colors = 10, ["#3498db", "#2ecc71", "#e74c3c"]
fig, axes = plt.subplots(2, 3, figsize=(20, 8))
axes = axes.flatten()

for i, (name, imp) in enumerate(attr_summary.items()):
    # Never ask for more bars than there are features; otherwise the bar
    # positions and heights would have different lengths.
    k = min(top_k, len(imp))
    idx_sorted = np.argsort(imp)[-k:][::-1]
    ids = [f"#{idx+1}" for idx in idx_sorted]   # ID = index into exogenous_vars + 1

    # --- bars ---
    ax = axes[i]
    ax.bar(range(k), imp[idx_sorted] * 100, color=colors[i])
    ax.set_xticks(range(k))
    ax.set_xticklabels(ids)
    ax.set_xlabel("Feature ID")
    ax.set_ylabel("Contribution (%)")
    ax.set_title(f"DeepLift SHAP – {name}")

    # --- table ---
    ax_t = axes[i + 3]
    ax_t.axis("off")
    ax_t.set_title(f"{name}  mapping", pad=10)
    tbl = ax_t.table(cellText=[[ids[j], exogenous_vars[idx_sorted[j]]] for j in range(k)],
                     colLabels=["ID", "Feature"], cellLoc="left", loc="center",
                     colWidths=[0.07, 0.93])
    tbl.auto_set_font_size(False)
    tbl.set_fontsize(9)
    tbl.scale(1.2, 1.2)

plt.suptitle("Feature Attribution (DeepLift SHAP)\n"
             f"{ATTR_START} – {ATTR_END}", fontsize=16)
# Leave the top 4 % of the figure free for the suptitle.
plt.tight_layout(rect=[0,0,1,0.96])
out_path = os.path.join(plots_dir, "deeplift_shap_selected_window.png")
plt.savefig(out_path, dpi=300, bbox_inches="tight")
plt.show()
print("Saved ➜", out_path)

# Store copies under the method's key in the master container.
attr_summaries["DL_SHAP"] = {k: v.copy() for k, v in attr_summary.items()}


## FA (Feature Ablation)

In [None]:
import os
import numpy as np
import torch
from captum.attr import FeatureAblation
import matplotlib.pyplot as plt
import pandas as pd

# ------------------------------------------------------------
# 0.  CONFIG & OUTPUT DIR
# ------------------------------------------------------------
# Directory that receives the figures written by this cell.
plots_dir = "attribution_plots"
os.makedirs(plots_dir, exist_ok=True)

# ATTR_START / ATTR_END are not set here; they must already exist in the
# session. Example values that were used previously:
# ATTR_START = "2023-09-05 19:00"
# ATTR_END   = "2023-09-06 08:00"
attribution_times = pd.date_range(start=ATTR_START, end=ATTR_END, freq="h", tz="UTC")
print("Attribution times:", attribution_times)

# ------------------------------------------------------------
# 1.  BUILD INPUT / MASK TENSORS  *** SAME AS INFERENCE ***
# ------------------------------------------------------------
input_seqs, mask_seqs = [], []

for t in attribution_times:
    if t not in df.index:
        print(f"⚠️  {t} missing in df.index – skipped")
        continue
    idx = df.index.get_loc(t)
    if idx < seq_length:
        print(f"⚠️  Not enough history before {t} – skipped")
        continue

    # The seq_length rows immediately before t, scaled with the fitted scaler.
    window_raw = df.iloc[idx - seq_length: idx][exogenous_vars].values
    window_scaled = scaler_X.transform(window_raw)

    # mask: 1 where a value is present, 0 where it is NaN; NaNs become 0.
    missing = np.isnan(window_scaled)
    mask = (~missing).astype(np.float32)
    window_scaled = np.where(missing, 0.0, window_scaled).astype(np.float32)

    input_seqs.append(window_scaled)
    mask_seqs.append(mask)

if len(input_seqs) == 0:
    raise RuntimeError("No valid windows after filtering.")

# Stack the per-timestamp windows along a new leading axis.
X_attr = torch.tensor(np.stack(input_seqs)).to(DEVICE)
X_mask_attr = torch.tensor(np.stack(mask_seqs)).to(DEVICE)

# quick sanity-check
x_has_nan = torch.isnan(X_attr).any().item()
mask_has_nan = torch.isnan(X_mask_attr).any().item()
print("Any NaNs after preprocessing?  X:", x_has_nan,
      " mask:", mask_has_nan)

# ------------------------------------------------------------
# 2.  BASELINES  (in scaled space, no NaNs) – USE MEAN INSTEAD OF PRE-SPIKE
# ------------------------------------------------------------
# All-zero and element-wise median (over the stacked windows) baselines,
# both shaped like a single sample so they can be expanded later.
zeros_X  = torch.zeros_like(X_attr[:1])
median_X = torch.median(X_attr, dim=0, keepdim=True).values

# --- Compute mean baseline over all valid attribution windows ---
# The same two skip rules as the tensor-building loop are applied, so a
# timestamp absent from df.index is ignored instead of raising KeyError
# from get_loc.
mean_raw = []
for t in attribution_times:
    if t not in df.index:
        continue
    idx = df.index.get_loc(t)
    if idx < seq_length:
        continue
    window_raw = df.iloc[idx - seq_length: idx][exogenous_vars]
    mean_raw.append(window_raw.values)
if not mean_raw:
    raise ValueError("No valid windows to compute mean baseline.")
# NOTE(review): np.mean propagates NaN, so a position missing in any window
# ends up masked out below; confirm whether np.nanmean was intended.
mean_raw = np.mean(np.stack(mean_raw), axis=0)  # averaged over the window axis
mean_scaled = scaler_X.transform(mean_raw)
mask_mean = (~np.isnan(mean_scaled)).astype(np.float32)
mean_scaled[np.isnan(mean_scaled)] = 0.0
baseline_mean_X = torch.tensor(mean_scaled, dtype=torch.float32).unsqueeze(0).to(DEVICE)
baseline_mean_M = torch.tensor(mask_mean,        dtype=torch.float32).unsqueeze(0).to(DEVICE)

# Each entry is (baseline input, baseline mask).
# NOTE(review): "zeros" and "median" reuse the mask of the first attribution
# window as the baseline mask; confirm this is intended.
baselines = {
    "zeros":     (zeros_X,   X_mask_attr[:1]),
    "median":    (median_X,  X_mask_attr[:1]),
    # "mean":      (baseline_mean_X, baseline_mean_M),
}

# ------------------------------------------------------------
# 3.  FEATURE ABLATION  – Captum
# ------------------------------------------------------------
def model_forward(x, m):
    """Forward adapter for Captum: run the global ``model`` on inputs ``x`` and mask ``m``.

    Presumably returns a (batch, 1, 1) tensor for the single-target,
    single-horizon model used in this notebook; confirm against the model.
    """
    prediction = model(x, m)
    return prediction

fa = FeatureAblation(model_forward)

attr_results = {}
for name, (base_x, base_m) in baselines.items():
    # Single-sample baselines are broadcast to the full batch shape.
    expanded_baselines = (base_x.expand_as(X_attr), base_m.expand_as(X_mask_attr))

    # One attribution tensor comes back per input; the mask one is discarded.
    attr_x, _ = fa.attribute(
        inputs=(X_attr, X_mask_attr),
        baselines=expanded_baselines,
        target=0,
        perturbations_per_eval=32,    # ablations evaluated per forward call
        feature_mask=None,            # default: each scalar is its own feature
    )

    lo, hi = attr_x.min(), attr_x.max()
    print(f"{name:9s} | attr range [{lo:.2e}, {hi:.2e}]")

    attr_results[name] = attr_x.cpu().numpy()   # keep only the X attributions

# ------------------------------------------------------------
# 4.  AGGREGATE IMPORTANCE  (unchanged)
#     • |attr|  →  mean over (sample, time)  →  normalise to %
# ------------------------------------------------------------
# Per baseline: |attribution| averaged over the first two axes gives one value
# per feature; dividing by the total turns the values into fractions of 1
# (the epsilon avoids a division by zero).
attr_summary = {}
for name, a in attr_results.items():
    importance = np.abs(a).mean(axis=(0, 1))
    importance = importance / (importance.sum() + 1e-12)
    attr_summary[name] = importance

# ------------------------------------------------------------
# 5.  PLOT
# ------------------------------------------------------------
top_k   = 10
colors  = ["#3498db", "#2ecc71", "#e74c3c"]

# One column per baseline actually present, so no empty panels are drawn and
# more than three baselines no longer run past the grid. Row 0 holds the bar
# charts, row 1 the ID -> feature-name tables.
n_cols = max(len(attr_summary), 1)
fig, axes = plt.subplots(2, n_cols, figsize=(20, 8), squeeze=False)

for i, (name, imp) in enumerate(attr_summary.items()):
    # Never ask for more bars than there are features.
    n_show = min(top_k, len(imp))
    idx_sorted = np.argsort(imp)[::-1][:n_show]      # largest share first
    feature_indices = idx_sorted                     # indices into exogenous_vars
    ids = [f"#{idx+1}" for idx in feature_indices]   # consistent ID = index+1

    # --- bars ---
    ax = axes[0, i]
    ax.bar(range(n_show), imp[feature_indices] * 100,
           color=colors[i % len(colors)])            # palette repeats if needed
    ax.set_xticks(range(n_show))
    ax.set_xticklabels(ids)
    ax.set_xlabel("Feature ID")
    ax.set_ylabel("Contribution (%)")
    ax.set_title(f"Feature Ablation – {name}")

    # --- table ---
    ax_t = axes[1, i]
    table_data = [[ids[j], exogenous_vars[feature_indices[j]]]
                  for j in range(n_show)]
    ax_t.axis("off")
    ax_t.set_title(f"{name}  mapping", pad=10)
    tbl = ax_t.table(
        cellText=table_data, colLabels=["ID", "Feature"],
        cellLoc="left", loc="center",
        colWidths=[0.07, 0.93],
    )
    tbl.auto_set_font_size(False)
    tbl.set_fontsize(9)
    tbl.scale(1.2, 1.2)

plt.suptitle("Feature Attribution (Feature Ablation)\n"
             f"{ATTR_START} – {ATTR_END}", fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.96])

out_path = os.path.join(plots_dir, "feature_ablation_selected_window.png")
plt.savefig(out_path, dpi=300, bbox_inches="tight")
plt.show()
print("Saved ➜", out_path)

# Keep an independent copy so later cells can compare attribution methods.
attr_summaries["FA"] = {k: v.copy() for k, v in attr_summary.items()}


## Top params

In [None]:
# ------------------------------------------------------------------
# ❶  PARAMETERS & SANITY-CHECK
# ------------------------------------------------------------------
top_k = 10
methods    = list(attr_summaries.keys())
# Ordered union of the baselines of every method. Taking only the first
# method's baselines raised KeyError below whenever another method carried a
# baseline the first one lacked.
baselines  = list(dict.fromkeys(b for m in methods for b in attr_summaries[m]))

print("Methods   :", methods)
print("Baselines :", baselines)

# ------------------------------------------------------------------
# ❷  COLLECT UNIQUE INDICES PER BASELINE
# ------------------------------------------------------------------
baseline_to_idxs = {b: set() for b in baselines}

for method in methods:
    for baseline, importance in attr_summaries[method].items():
        top_idx = np.argsort(importance)[-top_k:]            # this method+baseline's top-k
        baseline_to_idxs[baseline].update(top_idx)           # accumulate in the set

# ------------------------------------------------------------------
# ❸  REPORT
# ------------------------------------------------------------------
for baseline in baselines:
    idxs  = sorted(baseline_to_idxs[baseline])
    names = [exogenous_vars[i] for i in idxs]
    # Only count the methods that actually provide this baseline.
    n_methods = sum(baseline in attr_summaries[m] for m in methods)

    print(f"\nBaseline: {baseline}")
    print(f"Unique top {top_k} features across {n_methods} methods: {len(names)}")
    print("-"*60)
    for n, feat in enumerate(names, 1):
        print(f"{n:2d}. {feat}")


## Top param plots

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

# --------------------------------------------------
# ❶ CONFIG - Enhanced for professional presentation
# --------------------------------------------------
top_k = 10
methods = ["IG", "DL_SHAP", "FA"]
colors = ["#4477AA", "#66CCEE", "#EE6677"]   # one colour per method
bar_width = 0.25
font_family = 'serif'

# Global matplotlib look for the comparison figure.
plt.style.use('seaborn-v0_8-whitegrid')
mpl.rcParams.update({
    'font.family': font_family,
    'font.size': 16,
    'axes.titlesize': 18,
    'axes.labelsize': 17,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
    'legend.fontsize': 16,
    'figure.titlesize': 20,
})

# --------------------------------------------------
# ❷ COLLECT UNIQUE TOP-k FEATURES PER BASELINE (exclude 'mean')
# --------------------------------------------------
# Baselines come from the first method; "mean" is left out on purpose.
all_baselines = list(attr_summaries[methods[0]])
baselines = [b for b in all_baselines if b != "mean"]
baseline_to_idxs = {b: set() for b in baselines}

# Union, per baseline, of every method's top_k feature indices.
for m in methods:
    for b, imp in attr_summaries[m].items():
        if b != "mean":
            top_idx = np.argsort(imp)[-top_k:]
            baseline_to_idxs[b].update(top_idx)

# --------------------------------------------------
# ❸ BUILD THE FIGURE - Enhanced for thesis presentation
# --------------------------------------------------
# One stacked panel per baseline; squeeze=False keeps `axes` indexable even
# when there is a single baseline.
n_rows = len(baselines)
fig, axes = plt.subplots(n_rows, 1,
                         figsize=(10, 3.5 * n_rows),
                         constrained_layout=True,
                         squeeze=False)
axes = axes[:, 0]

# Display names used in the legend.
method_names = {
    "IG": "Integrated Gradient",
    "DL_SHAP": "DeepLIFT SHAP",
    "FA": "Feature Ablation"
}

# Hatching keeps the three methods distinguishable in grayscale print.
hatch_patterns = ['', '///', '...']

for row, baseline in enumerate(baselines):
    ax = axes[row]
    idxs = sorted(baseline_to_idxs[baseline])
    n = len(idxs)

    # Centre position of each feature group on the x axis.
    x_centres = np.arange(n)

    for j, (method, color) in enumerate(zip(methods, colors)):
        imp = attr_summaries[method][baseline] * 100
        heights = imp[idxs]
        offset = (j - 1) * bar_width          # left / centre / right bar of a group
        # Only the first panel contributes legend entries.
        bar_label = method_names[method] if row == 0 else ""
        ax.bar(x_centres + offset, heights,
               width=bar_width, color=color, alpha=0.85,
               hatch=hatch_patterns[j],
               label=bar_label,
               edgecolor='black', linewidth=0.5)

    ax.set_title(f"Baseline: {baseline.capitalize()}", loc='left', fontweight='normal')
    ax.set_ylabel("Contribution (%)", fontweight='normal')
    ax.set_xlabel("Feature Index", fontweight='normal')
    ax.set_xticks(x_centres)

    # Tick labels show the 1-based feature ID.
    feature_labels = [f"#{i+1}" for i in idxs]
    ax.set_xticklabels(feature_labels, rotation=45, ha='right')

    ax.grid(axis="y", alpha=0.3, linestyle='--')

    # Thin visible frame on all four sides.
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(0.75)

    ax.axhline(y=0, color='black', linestyle='-', linewidth=0.8, alpha=0.5)

# Shared legend anchored at the bottom centre of the figure, below the panels.
handles, labels = axes[0].get_legend_handles_labels()
legend = fig.legend(handles, labels, loc='upper center',
                    bbox_to_anchor=(0.5, 0.01),
                    ncol=len(methods), frameon=True,
                    fancybox=True, shadow=True)

for target_file in ('feature_importance_comparison.pdf',
                    'feature_importance_comparison.png'):
    plt.savefig(target_file, bbox_inches='tight', dpi=300)
plt.show()

# --------------------------------------------------
# ❹ PRINT THE "# – Feature name" LISTS (with enhanced formatting)
# --------------------------------------------------
# Text table mapping each plotted "#ID" back to its feature name.
thin_rule = "-" * 40
print("\nTable X: Feature Index Mapping")
print("=" * 50)
for baseline in baselines:
    idxs = sorted(baseline_to_idxs[baseline])
    print(f"\nBaseline: {baseline.capitalize()}")
    print(thin_rule)
    print(f"{'Index':<8} {'Feature Name':<30}")
    print(thin_rule)
    for i in idxs:
        print(f"#{i+1:<7} {exogenous_vars[i]:<30}")

## Feature plots

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
from math import ceil
import matplotlib.dates as mdates
from matplotlib.ticker import MaxNLocator
import textwrap

# --------------------------------------------------
# ❶ USER CONFIG - Enhanced for thesis presentation
# --------------------------------------------------
top_n = 6                                     # how many features to show

# Plotted time range and the sub-range shaded as the peak period (UTC).
ATTR_START = pd.Timestamp("2023-09-01 19:00", tz="UTC")
ATTR_END = pd.Timestamp("2023-09-10 08:00", tz="UTC")
PEAK_START = pd.Timestamp("2023-09-05 19:00", tz="UTC")
PEAK_END = pd.Timestamp("2023-09-06 08:00", tz="UTC")

plot_style = "subplots"     # "subplots" or "overlay"
use_zscore = False          # only honoured in "overlay" mode

plt.style.use('seaborn-v0_8-whitegrid')
font_family = 'serif'
mpl.rcParams.update({
    'font.family': font_family,
    'font.size': 12,
    'axes.titlesize': 14,
    'axes.labelsize': 13,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'legend.fontsize': 12,
    'figure.titlesize': 16,
})

# ---------- DETERMINISTIC COLOR ASSIGNMENT FOR GIVEN INDEXES ----------
# Displayed (1-based) feature IDs that must always get the same colour.
special_indexes = [7, 15, 36, 35, 33, 30, 17, 34, 18]
special_indexes_sorted = sorted(special_indexes)  # [7, 15, 17, 18, 30, 33, 34, 35, 36]

# One palette entry per special ID, assigned in ascending ID order.
color_palette = plt.get_cmap("tab10", len(special_indexes_sorted)).colors

# get_feature_color() is called with 0-based positions into exogenous_vars,
# while the IDs above are the "#idx+1" labels shown in the plots. The map is
# therefore keyed by ID - 1; keying by the raw ID coloured the wrong features.
special_color_map = {
    feature_id - 1: color_palette[rank]
    for rank, feature_id in enumerate(special_indexes_sorted)
}

def get_feature_color(idx, default_palette=plt.cm.tab10.colors):
    """Return the plot colour for the feature at 0-based position ``idx``.

    Positions listed in ``special_color_map`` get their fixed colour; any
    other position falls back to ``default_palette``, cycling by index.
    """
    if idx in special_color_map:
        return special_color_map[idx]
    # fallback to default cycle for non-specified indexes
    return default_palette[idx % len(default_palette)]

# ------------ EXAMPLE: REST OF YOUR CODE BELOW ----------------------


methods = ["IG", "DL_SHAP", "FA"]
method_names = {
    "IG": "Integrated Gradient",
    "DL_SHAP": "DeepLIFT SHAP",
    "FA": "Feature Ablation"
}
all_baselines = list(attr_summaries[methods[0]])
baselines = [b for b in all_baselines if b != "mean"]

# Re-normalise every (method, baseline) vector so that it sums to ~1.
attr_pct = {}
for m in methods:
    per_baseline = {}
    for b in baselines:
        v = attr_summaries[m][b].astype(float)
        per_baseline[b] = v / (v.sum() + 1e-12)
    attr_pct[m] = per_baseline

# Average over all methods and baselines, then keep the top_n features
# (largest first).
all_vectors = []
for m in methods:
    all_vectors.extend(attr_pct[m][b] for b in baselines)
mean_imp = np.mean(all_vectors, axis=0)
top_idx = np.argsort(mean_imp)[-top_n:][::-1]
top_names = [exogenous_vars[i] for i in top_idx]

# Values of those features inside the plotted window.
slice_df = df.loc[ATTR_START:ATTR_END, top_names]
if plot_style == "overlay" and use_zscore:
    centred = slice_df - slice_df.mean()
    slice_df = centred / slice_df.std(ddof=0)

# --------------------------------------------------
# ❺ PLOT - Enhanced for thesis quality: 3x2 subplots, wrapped titles
# --------------------------------------------------
def get_time_locator(start, end):
    """Choose an x-axis tick locator and date formatter for ``start``..``end``.

    Spans up to 48 h, 96 h and 168 h get hourly ticks every 2, 6 and 12 hours
    respectively; longer spans get one tick per day.

    Returns:
        A ``(locator, formatter)`` tuple of matplotlib.dates objects.
    """
    duration_hours = (end - start).total_seconds() / 3600
    # (upper bound in hours, tick interval in hours, label format)
    hourly_tiers = (
        (48, 2, '%H:%M\n%d-%b'),
        (96, 6, '%d-%b\n%H:%M'),
        (168, 12, '%d-%b\n%H:%M'),
    )
    for max_hours, step, fmt in hourly_tiers:
        if duration_hours <= max_hours:
            return mdates.HourLocator(interval=step), mdates.DateFormatter(fmt)
    return mdates.DayLocator(interval=1), mdates.DateFormatter('%d-%b')

if plot_style == "overlay":
    # Single axes: every top feature as its own line, peak window shaded.
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.axvspan(PEAK_START, PEAK_END, alpha=0.15, color='gray', label='Peak Period')
    line_styles = ['-', '--', ':', '-.', '-', '--']
    for i, name in enumerate(top_names):
        color = get_feature_color(top_idx[i])
        ax.plot(slice_df.index, slice_df[name],
                linewidth=2.5,
                color=color,
                # Cycle the styles so features beyond the sixth are still drawn
                # (zipping with the style list silently dropped them).
                linestyle=line_styles[i % len(line_styles)],
                label=f"#{top_idx[i]+1} {name}")
    locator, formatter = get_time_locator(ATTR_START, ATTR_END)
    ax.xaxis.set_major_locator(locator)
    ax.xaxis.set_major_formatter(formatter)
    ax.set_title("Temporal Evolution of Top Feature Contributions", fontweight='normal', pad=15)
    ax.set_ylabel("Z-Score" if use_zscore else "Feature Value", fontweight='normal')
    ax.set_xlabel("Date", fontweight='normal')
    ax.grid(alpha=0.3, linestyle='--')
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(0.75)
    # The shaded span already carries the 'Peak Period' label, so the default
    # legend handles are complete. "upper left" replaces the invalid location
    # string "bottom left"; together with the anchor it places the legend
    # outside the axes, top-aligned on the right.
    ax.legend(bbox_to_anchor=(1.01, 1), loc="upper left", frameon=True, fancybox=True, shadow=True)
    fig.tight_layout()

elif plot_style == "subplots":
    # Grid of small multiples, one panel per top feature.
    nrows, ncols = 3, 2
    fig, axes = plt.subplots(nrows, ncols, figsize=(13, 9), sharex=True, constrained_layout=True)
    axes = axes.flatten()
    fig.suptitle("Temporal Evolution of Key Feature Contributions (Case 1)", fontweight='normal', y=1.06)
    wrap_width = 65

    locator, formatter = get_time_locator(ATTR_START, ATTR_END)

    # Number of panels that actually receive data.
    n_used = min(len(top_names), nrows * ncols)

    for i, (ax, idx, name) in enumerate(zip(axes, top_idx, top_names)):
        color = get_feature_color(idx)
        ax.axvspan(PEAK_START, PEAK_END, alpha=0.3, color='gray')
        ax.plot(slice_df.index, slice_df[name], linewidth=2.5, color=color)
        feature_title = f"#{idx+1}: {name} ({mean_imp[idx]*100:.2f}%)"
        feature_title_wrapped = "\n".join(textwrap.wrap(feature_title, wrap_width))
        ax.set_title(feature_title_wrapped, loc='left', fontweight='normal', fontsize=11)
        ax.set_ylabel("Value", fontweight='normal')
        ax.yaxis.set_major_locator(MaxNLocator(nbins=5))
        ax.grid(alpha=0.3, linestyle='--')
        for spine in ax.spines.values():
            spine.set_visible(True)
            spine.set_linewidth(0.75)
        # Light band spanning the value range of the series.
        val_min = slice_df[name].min()
        val_max = slice_df[name].max()
        if val_min == val_max:
            # Constant series: widen by 1 % of its magnitude. abs() keeps the
            # band the right way round for negative constants.
            delta = abs(val_min) * 0.01 if val_min != 0 else 0.01
            ax.axhspan(val_min - delta, val_max + delta, alpha=0.1, color=color)
        else:
            ax.axhspan(val_min, val_max, alpha=0.1, color=color)
        if i == 0:
            ax.text(0.98, 0.95, 'Peak Period', transform=ax.transAxes,
                    bbox=dict(facecolor='gray', alpha=0.3, edgecolor='none', pad=3),
                    ha='right', va='top', fontsize=10)
    # Drop every panel that received no data.
    for ax in axes[n_used:]:
        fig.delaxes(ax)
    # Date formatting goes on the lowest remaining panel of each column. The
    # last grid row may have been deleted above, and sharex hides tick labels
    # on inner rows, so they are switched back on explicitly.
    for ax in axes[:n_used][-ncols:]:
        ax.xaxis.set_major_locator(locator)
        ax.xaxis.set_major_formatter(formatter)
        ax.tick_params(axis='x', labelbottom=True)
        ax.set_xlabel("Date", fontweight='normal')
    period_text = (f"Period: {ATTR_START.strftime('%Y-%m-%d %H:%M')} to {ATTR_END.strftime('%Y-%m-%d %H:%M')} UTC | "
                  f"Peak Activity: {PEAK_START.strftime('%Y-%m-%d %H:%M')} to {PEAK_END.strftime('%Y-%m-%d %H:%M')} UTC")
    fig.text(0.5, 1.02, period_text, ha='center', fontstyle='italic', fontsize=11)

else:
    raise ValueError("plot_style must be 'subplots' or 'overlay'")

plt.savefig('feature_timeseries_analysis.pdf', bbox_inches='tight', dpi=300)
plt.savefig('feature_timeseries_analysis.png', bbox_inches='tight', dpi=300)
plt.show()


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
from math import ceil
import matplotlib.dates as mdates
from matplotlib.ticker import MaxNLocator
import textwrap

# --------------------------------------------------
# ❶ USER CONFIG - Enhanced for thesis presentation
# --------------------------------------------------
top_n = 6                                     # how many features to show

# Plotted time range, shaded peak period and the marked peak-price hour (UTC).
ATTR_START = pd.Timestamp("2023-09-01 19:00", tz="UTC")
ATTR_END = pd.Timestamp("2023-09-10 08:00", tz="UTC")
PEAK_START = pd.Timestamp("2023-09-05 19:00", tz="UTC")
PEAK_END = pd.Timestamp("2023-09-06 08:00", tz="UTC")
PEAK_PRICE = pd.Timestamp("2023-09-06 03:00", tz="UTC")

plot_style = "subplots"     # "subplots" or "overlay"
use_zscore = False          # only honoured in "overlay" mode

plt.style.use('seaborn-v0_8-whitegrid')
font_family = 'serif'
mpl.rcParams.update({
    'font.family': font_family,
    'font.size': 16,
    'axes.titlesize': 18,
    'axes.labelsize': 15,
    'xtick.labelsize': 14,
    'ytick.labelsize': 14,
    'legend.fontsize': 16,
    'figure.titlesize': 20,
})

# --- EXPLICIT HIGH-CONTRAST COLORS FOR YOUR INDEXES ---
# Fixed colours keyed by 0-based feature position (displayed "#ID" minus one).
SPECIAL_INDEX_COLORS = {
    6:   "#1f77b4",  # blue
    14:  "#ff7f0e",  # orange
    16:  "#e377c2",  # pink
    17:  "#bcbd22",  # olive/lime
    29:  "#8c564b",  # brown
    32:  "#9467bd",  # purple
    33:  "#7f7f7f",  # gray
    34:  "#d62728",  # red
    35:  "#2ca02c",  # green
}

def get_feature_color(idx, default_palette=plt.cm.tab10.colors):
    """Return the plot colour for feature position ``idx``.

    Positions present in ``SPECIAL_INDEX_COLORS`` use their pinned colour;
    everything else cycles through ``default_palette``.
    """
    try:
        return SPECIAL_INDEX_COLORS[idx]
    except KeyError:
        return default_palette[idx % len(default_palette)]

# --- YOUR DATA/PREPROCESSING ---

methods = ["IG", "DL_SHAP", "FA"]
all_baselines = list(attr_summaries[methods[0]])
baselines = [b for b in all_baselines if b != "mean"]

# Re-normalise every (method, baseline) vector so that it sums to ~1.
attr_pct = {}
for m in methods:
    per_baseline = {}
    for b in baselines:
        v = attr_summaries[m][b].astype(float)
        per_baseline[b] = v / (v.sum() + 1e-12)
    attr_pct[m] = per_baseline

# Average over all methods and baselines, then keep the top_n features
# (largest first).
all_vectors = []
for m in methods:
    all_vectors.extend(attr_pct[m][b] for b in baselines)
mean_imp = np.mean(all_vectors, axis=0)
top_idx = np.argsort(mean_imp)[-top_n:][::-1]
top_names = [exogenous_vars[i] for i in top_idx]

# Values of those features inside the plotted window.
slice_df = df.loc[ATTR_START:ATTR_END, top_names]
if plot_style == "overlay" and use_zscore:
    centred = slice_df - slice_df.mean()
    slice_df = centred / slice_df.std(ddof=0)

# --- PLOTTING ---
def get_time_locator(start, end):
    """Return a ``(locator, formatter)`` pair suited to the span ``start``..``end``.

    Tick spacing grows with the span: every 2 h up to 48 h, every 6 h up to
    96 h, every 12 h up to 168 h, and daily beyond that.
    """
    duration_hours = (end - start).total_seconds() / 3600
    if duration_hours <= 48:
        return mdates.HourLocator(interval=2), mdates.DateFormatter('%H:%M\n%d-%b')
    if duration_hours <= 96:
        return mdates.HourLocator(interval=6), mdates.DateFormatter('%d-%b\n%H:%M')
    if duration_hours <= 168:
        return mdates.HourLocator(interval=12), mdates.DateFormatter('%d-%b\n%H:%M')
    return mdates.DayLocator(interval=1), mdates.DateFormatter('%d-%b')

if plot_style == "overlay":
    # Single axes: every top feature as its own line, with the peak window
    # shaded and the peak-price hour marked by a dashed vertical line.
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.axvspan(PEAK_START, PEAK_END, alpha=0.15, color='gray', label='Peak Period')
    ax.axvline(PEAK_PRICE, color='red', linestyle='--', linewidth=2, label='Peak Price')
    line_styles = ['-', '--', ':', '-.', '-', '--']
    for i, name in enumerate(top_names):
        color = get_feature_color(top_idx[i])
        ax.plot(slice_df.index, slice_df[name],
                linewidth=2.5,
                color=color,
                # Cycle the styles so features beyond the sixth are still drawn
                # (zipping with the style list silently dropped them).
                linestyle=line_styles[i % len(line_styles)],
                label=f"#{top_idx[i]+1} {name}")
    locator, formatter = get_time_locator(ATTR_START, ATTR_END)
    ax.xaxis.set_major_locator(locator)
    ax.xaxis.set_major_formatter(formatter)
    plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha='right')
    ax.set_ylabel("Z-Score" if use_zscore else "Feature Value", fontweight='normal')
    ax.set_xlabel("Date", fontweight='normal')
    ax.grid(alpha=0.3, linestyle='--')
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(0.75)
    # The span and the line already carry their labels, so the default legend
    # handles are complete. "upper left" replaces the invalid location string
    # "bottom left"; together with the anchor it places the legend outside
    # the axes, top-aligned on the right.
    ax.legend(bbox_to_anchor=(1.01, 1), loc="upper left", frameon=True, fancybox=True, shadow=True)
    fig.tight_layout()

elif plot_style == "subplots":
    # Grid of small multiples, one panel per top feature.
    nrows, ncols = 3, 2
    fig, axes = plt.subplots(nrows, ncols, figsize=(13, 9), sharex=True, constrained_layout=True)
    axes = axes.flatten()
    wrap_width = 41

    locator, formatter = get_time_locator(ATTR_START, ATTR_END)

    # Number of panels that actually receive data.
    n_used = min(len(top_names), nrows * ncols)

    for i, (ax, idx, name) in enumerate(zip(axes, top_idx, top_names)):
        color = get_feature_color(idx)
        ax.axvspan(PEAK_START, PEAK_END, alpha=0.3, color='gray')
        ax.axvline(PEAK_PRICE, color='red', linestyle='--', linewidth=2)
        ax.plot(slice_df.index, slice_df[name], linewidth=2.5, color=color)
        feature_title = f"#{idx+1}: {name} ({mean_imp[idx]*100:.2f}%)"
        feature_title_wrapped = "\n".join(textwrap.wrap(feature_title, wrap_width))
        ax.set_title(feature_title_wrapped, loc='left', fontweight='normal', fontsize=16)
        ax.set_ylabel("Value", fontweight='normal')
        ax.yaxis.set_major_locator(MaxNLocator(nbins=5))
        ax.grid(alpha=0.3, linestyle='--')
        for spine in ax.spines.values():
            spine.set_visible(True)
            spine.set_linewidth(0.75)
        # Light band spanning the value range of the series.
        val_min = slice_df[name].min()
        val_max = slice_df[name].max()
        if val_min == val_max:
            # Constant series: widen by 1 % of its magnitude. abs() keeps the
            # band the right way round for negative constants.
            delta = abs(val_min) * 0.01 if val_min != 0 else 0.01
            ax.axhspan(val_min - delta, val_max + delta, alpha=0.1, color=color)
        else:
            ax.axhspan(val_min, val_max, alpha=0.1, color=color)
        if i == 0:
            ax.text(0.98, 0.95, 'Peak Period', transform=ax.transAxes,
                    bbox=dict(facecolor='gray', alpha=0.3, edgecolor='none', pad=3),
                    ha='right', va='top', fontsize=14)
            ax.text(0.98, 0.80, 'Peak Price', transform=ax.transAxes,
                    bbox=dict(facecolor='red', alpha=0.3, edgecolor='none', pad=3),
                    ha='right', va='top', fontsize=14)
    # Drop every panel that received no data.
    for ax in axes[n_used:]:
        fig.delaxes(ax)
    # Date formatting goes on the lowest remaining panel of each column. The
    # last grid row may have been deleted above, and sharex hides tick labels
    # on inner rows, so they are switched back on explicitly.
    for ax in axes[:n_used][-ncols:]:
        ax.xaxis.set_major_locator(locator)
        ax.xaxis.set_major_formatter(formatter)
        ax.tick_params(axis='x', labelbottom=True)
        plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha='right')
        ax.set_xlabel("Date", fontweight='normal')


else:
    raise ValueError("plot_style must be 'subplots' or 'overlay'")

plt.savefig('feature_timeseries_analysis.pdf', bbox_inches='tight', dpi=300)
plt.savefig('feature_timeseries_analysis.png', bbox_inches='tight', dpi=300)
plt.show()

# CASE 2

In [None]:
import torch
import numpy as np
import pandas as pd
import joblib
import plotly.graph_objects as go

# ---- PATHS ----
# Weights file passed to torch.load / load_state_dict further down in this cell.
MODEL_PATH = "/new_model_test/best_model_SE2_Down_Volume.pth"
# Fitted scalers for the input features and the target, loaded with joblib.
SCALER_X_PATH = "/new_model_test/scaler_X_SE2_Down_Volume.joblib"
SCALER_Y_PATH = "/new_model_test/scaler_y_SE2_Down_Volume.joblib"
# CSV read into df_all; it must contain a "DateTime" column.
FILE_PATH = "/results/results_merged.csv"

# Region code interpolated into the target and price column names.
REGION = "SE2"
# Column(s) the model predicts; a single target gives output_size == 1.
target_vars = [f"Balancing_ActivatedBalancingEnergy_{REGION}_mFRR_NotSpecifiedDownActivatedVolume"]

# -- Model class --
class MaskedLSTM(torch.nn.Module):
    """LSTM regressor that receives a presence mask alongside its features.

    The mask is concatenated to the inputs on the feature axis, so the LSTM
    sees ``2 * input_size`` channels. The output of the last time step goes
    through a ReLU dense layer and a linear head, and is reshaped to
    ``(batch, horizon, output_size)``.
    """

    def __init__(self, input_size, hidden_size=32, num_layers=1, dense_size=32, output_size=1, horizon=1, dropout=0.08, dropout_lstm=0.27, bidirectional=True):
        super().__init__()
        self.horizon = horizon
        self.output_size = output_size

        nn = torch.nn
        n_directions = 2 if bidirectional else 1
        # torch only applies LSTM-internal dropout between stacked layers.
        inter_layer_dropout = dropout_lstm if num_layers > 1 else 0.0

        # Layers are created in the order lstm, fc1, fc2 so that parameter
        # names and seeded initialisation stay the same.
        self.lstm = nn.LSTM(
            input_size=2 * input_size,      # features + mask
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=inter_layer_dropout,
            bidirectional=bidirectional,
            batch_first=True,
        )
        self.fc1 = nn.Linear(hidden_size * n_directions, dense_size)
        self.fc2 = nn.Linear(dense_size, horizon * output_size)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(dropout)
        self.drop_lstm = nn.Dropout(dropout_lstm)

    def forward(self, x, mask):
        """Return predictions of shape ``(batch, horizon, output_size)``.

        ``x`` and ``mask`` are concatenated on axis 2, so both must be 3-D
        with matching batch and time dimensions.
        """
        stacked = torch.cat((x, mask), dim=2)
        sequence_out, _ = self.lstm(stacked)
        sequence_out = self.drop_lstm(sequence_out)
        last_step = sequence_out[:, -1, :]
        hidden = self.drop(self.relu(self.fc1(last_step)))
        flat = self.fc2(hidden)
        return flat.view(flat.size(0), self.horizon, self.output_size)

# Use the GPU when one is available, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Scalers previously saved with joblib; scaler_X is applied to the input
# features, scaler_y maps predictions back to the original target scale.
scaler_X = joblib.load(SCALER_X_PATH)
scaler_y = joblib.load(SCALER_Y_PATH)

# Hyper-parameters used to rebuild the network in this cell. Only the
# architecture entries and "seq_length" are read here; "learning_rate" and
# "batch_size" are presumably training-only settings kept for reference.
best_params = {
    "hidden_size": 32,
    "num_layers": 1,
    "bidirectional": True,
    "dense_size": 32,
    "dropout": 0.08137184695714653,
    "dropout_lstm": 0.2675438586833995,
    "learning_rate": 0.0013843186375837144,
    "batch_size": 128,
    "seq_length": 216,
}
# Number of past rows fed to the model for each prediction (lookback).
seq_length = best_params["seq_length"]

# --- Load and preprocess data ---
df_all = pd.read_csv(FILE_PATH)
price_column_name_to_plot = f"Balancing_PricesOfActivatedBalancingEnergy_{REGION}_mFRR_NotSpecifiedDownPrice"

# Keep the price series aside (indexed by UTC DateTime) because the column is
# excluded from the model features below.
if price_column_name_to_plot in df_all.columns:
    price_data_full = (
        df_all[['DateTime', price_column_name_to_plot]]
        .assign(DateTime=lambda frame: pd.to_datetime(frame['DateTime'], utc=True))
        .set_index('DateTime')
    )
    print(f"Stored '{price_column_name_to_plot}' separately for plotting.")
else:
    price_data_full = None
    print(f"Warning: Price column '{price_column_name_to_plot}' not found in df_all. Will not be plotted.")

from filter_features import pick_region_filter
BDZ_filter = pick_region_filter(region=REGION, remove_balancing=True)

# Balancing price/volume columns are removed unless they are a target.
drop_list = [
    f"Balancing_PricesOfActivatedBalancingEnergy_{REGION}_mFRR_NotSpecifiedUpPrice",
    f"Balancing_PricesOfActivatedBalancingEnergy_{REGION}_mFRR_NotSpecifiedDownPrice",
    f"Balancing_ActivatedBalancingEnergy_{REGION}_mFRR_NotSpecifiedUpActivatedVolume",
    f"Balancing_ActivatedBalancingEnergy_{REGION}_mFRR_NotSpecifiedDownActivatedVolume",
]
BDZ_filter = [c for c in BDZ_filter if not (c in drop_list and c not in target_vars)]

# Index by UTC timestamp and put everything on a regular hourly grid.
df_all["DateTime"] = pd.to_datetime(df_all["DateTime"], utc=True)
df_all = df_all.set_index("DateTime").asfreq("h")
df = df_all[BDZ_filter].copy().asfreq("h")


# 


# Switches for the optional engineered features (all disabled here).
CREATE_INDICATOR_FEATURES = False
CREATE_CYCLICAL_FEATURES = False
CREATE_LAGGED_FEATURES = False

# --- "<col>_was_missing" 0/1 indicator columns ---
new_indicator_features = []
if CREATE_INDICATOR_FEATURES:
    indicator_features_to_create = [col for col in BDZ_filter + target_vars if col in df.columns]
    for col in indicator_features_to_create:
        flag_name = f'{col}_was_missing'
        df[flag_name] = df[col].isnull().astype(int)
        new_indicator_features.append(flag_name)

# --- sine/cosine encodings of hour, weekday and ISO week ---
new_time_features = []
if CREATE_CYCLICAL_FEATURES:
    cycles = (
        ('hour', df.index.hour, 24.0),
        ('dayofweek', df.index.dayofweek, 7.0),
        ('weekofyear', df.index.isocalendar().week, 52.0),
    )
    for prefix, values, period in cycles:
        df[f'{prefix}_sin'] = np.sin(2 * np.pi * values / period)
        df[f'{prefix}_cos'] = np.cos(2 * np.pi * values / period)
        new_time_features += [f'{prefix}_sin', f'{prefix}_cos']

# --- target lags of 1, 24 and 168 rows ---
new_lagged_features = []
if CREATE_LAGGED_FEATURES:
    lags = (1, 24, 168)
    for target_var in target_vars:
        for lag in lags:
            df[f'{target_var}_lag{lag}'] = df[target_var].shift(lag)
    new_lagged_features = [f'{target_var}_lag{lag}'
                           for target_var in target_vars
                           for lag in lags]

# Columns that are entirely NaN carry no information.
df.dropna(axis=1, how="all", inplace=True)

# Model inputs: every filtered non-target column plus the enabled engineered
# features, restricted to columns still in df, de-duplicated and sorted.
candidate_vars = [c for c in BDZ_filter if c not in target_vars]
if CREATE_CYCLICAL_FEATURES:
    candidate_vars += new_time_features
if CREATE_LAGGED_FEATURES:
    candidate_vars += new_lagged_features
if CREATE_INDICATOR_FEATURES:
    candidate_vars += new_indicator_features
exogenous_vars = sorted({var for var in candidate_vars if var in df.columns})

# --- Instantiate the model ---
# Rebuild the network with the tuned architecture, restore the saved weights
# and switch to inference mode (dropout disabled).
model_kwargs = dict(
    input_size=len(exogenous_vars),
    hidden_size=best_params["hidden_size"],
    num_layers=best_params["num_layers"],
    dense_size=best_params["dense_size"],
    output_size=len(target_vars),
    dropout=best_params["dropout"],
    dropout_lstm=best_params["dropout_lstm"],
    bidirectional=best_params["bidirectional"],
)
model = MaskedLSTM(**model_kwargs).to(DEVICE)
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(state_dict)
model.eval()

# --- Robust window selection: will always work for any seq_length and date range ---

# First and last timestamp for which a prediction is wanted (UTC).
desired_start = pd.Timestamp("2024-04-07 06:00:00", tz="UTC")
desired_end = pd.Timestamp("2024-04-07 22:00:00", tz="UTC")

# An inverted range would otherwise only surface later as an obscure shape or
# scaler error.
if desired_end < desired_start:
    raise ValueError(f"Desired end {desired_end} is before desired start {desired_start}.")

all_dates = df.index
try:
    first_pred_idx = all_dates.get_loc(desired_start)
except KeyError as err:
    # Chain the original lookup failure so the traceback shows its cause.
    raise ValueError(f"Desired start {desired_start} not in DataFrame index!") from err

if first_pred_idx < seq_length:
    raise ValueError(f"Not enough data before {desired_start} to build input sequence with seq_length={seq_length}.")

# The slice starts seq_length rows before the first prediction so that the
# first prediction has a full lookback window.
window_start = all_dates[first_pred_idx - seq_length]
window_end = desired_end

df_pred_window = df.loc[window_start:window_end].copy()

# -- Scale features --
X_scaled = scaler_X.transform(df_pred_window[exogenous_vars])
y_true = df_pred_window[target_vars].values

# -- Build sequences --
def create_sequences(X, lookback, horizon=1):
    """Slice a 2-D feature matrix into overlapping look-back windows.

    Window ``i`` holds rows ``i`` to ``i + lookback - 1`` of ``X``. ``horizon``
    rows are reserved after every window, so
    ``X.shape[0] - lookback - horizon + 1`` windows are produced.

    Args:
        X: Array-like of shape (n_rows, n_features).
        lookback: Number of rows in each window.
        horizon: Number of rows reserved after each window.

    Returns:
        Tuple ``(X_seq, mask_seq)`` of float32 arrays, both shaped
        (num_samples, lookback, n_features). ``mask_seq`` is all ones; the
        caller is expected to zero the positions it treats as missing.

    Raises:
        ValueError: If ``X`` is not 2-D or has fewer than
            ``lookback + horizon - 1`` rows.
    """
    X = np.asarray(X)
    if X.ndim != 2:
        raise ValueError(f"X must be 2-D (rows, features); got {X.ndim} dimension(s).")

    num_samples = X.shape[0] - lookback - horizon + 1
    if num_samples < 0:
        raise ValueError(
            f"Need at least {lookback + horizon - 1} rows for lookback={lookback} "
            f"and horizon={horizon}; got {X.shape[0]}."
        )

    # Row i of row_idx is [i, i+1, ..., i+lookback-1]; indexing X with it yields
    # all windows at once with shape (num_samples, lookback, n_features).
    row_idx = np.arange(num_samples)[:, None] + np.arange(lookback)[None, :]
    X_seq = X[row_idx].astype(np.float32)
    mask_seq = np.ones(X_seq.shape, dtype=np.float32)
    return X_seq, mask_seq

X_seq, mask_seq = create_sequences(X_scaled, seq_length, 1)

# Missing inputs are fed to the model as zeros, with a zero in the mask.
nan_pos = np.isnan(X_seq)
mask_seq[nan_pos] = 0.0
X_seq[nan_pos] = 0.0

# Sequence i covers rows i..i+seq_length-1, so its target row is i+seq_length.
target_rows = slice(seq_length, seq_length + len(X_seq))
dates_seq = df_pred_window.index[target_rows]
y_true_aligned = y_true[target_rows]

# -- Filter for predictions within the requested window only --
mask = (desired_start <= dates_seq) & (dates_seq <= desired_end)
dates_seq = dates_seq[mask]
y_true_aligned = y_true_aligned[mask]

with torch.no_grad():
    X_tensor = torch.tensor(X_seq, dtype=torch.float32).to(DEVICE)
    mask_tensor = torch.tensor(mask_seq, dtype=torch.float32).to(DEVICE)
    model_out = model(X_tensor, mask_tensor)
    # Keep the first (only) horizon step -> (n, target_dim).
    y_pred_scaled = model_out.cpu().numpy()[:, 0, :]

# Back to original units, then apply the same window filter as the dates.
y_pred = scaler_y.inverse_transform(y_pred_scaled)[mask]

# --- Add price data ---
# Best effort: if the price series cannot be aligned, warn and plot without it.
price_data_to_plot = None
if price_data_full is not None and price_column_name_to_plot in price_data_full.columns:
    try:
        aligned_price = price_data_full.loc[dates_seq, price_column_name_to_plot]
        price_values = aligned_price.values.astype(float)
    except Exception as e:
        print(f"Warning: Could not align price data with dates_seq: {e}")
    else:
        price_data_to_plot = price_values

# -- Plot
# Volume traces use the left axis; the optional price trace uses a second
# y-axis on the right.
traces = [
    go.Scatter(
        x=dates_seq,
        y=y_true_aligned[:, 0],
        mode='lines+markers',
        name='True',
        marker={"symbol": 'circle', "size": 6},
    ),
    go.Scatter(
        x=dates_seq,
        y=y_pred[:, 0],
        mode='lines+markers',
        name='Predicted',
        marker={"symbol": 'x', "size": 6},
    ),
]
if price_data_to_plot is not None:
    traces.append(
        go.Scatter(
            x=dates_seq,
            y=price_data_to_plot,
            mode='lines',
            name='mFRR Down Price',
            line={"color": 'orange', "width": 2},
            yaxis='y2',
        )
    )

fig = go.Figure(data=traces)
fig.update_layout(
    title=f"True vs Predicted for {target_vars[0]} ({desired_start.date()} to {desired_end.date()})",
    xaxis_title="Date",
    yaxis={"title": "Volume (MW)"},
    yaxis2={
        "title": "mFRR Down Price (€/MWh)",
        "overlaying": 'y',
        "side": 'right',
        "showgrid": False,
    },
    template="plotly_white",
    legend={"orientation": "h", "yanchor": "bottom", "y": 1.02, "xanchor": "right", "x": 1},
)
fig.update_xaxes(tickformat="%d %b %Y %H:%M")
fig.show()


## Case 2 plot

In [57]:
# Prediction window for the Case 2 plot (timezone-aware, UTC). The plotting
# cell below reads these instead of defining its own values.
desired_start = pd.Timestamp("2024-04-07 06:00:00", tz="UTC")
desired_end = pd.Timestamp("2024-04-07 22:00:00", tz="UTC")

In [None]:
import torch
import numpy as np
import pandas as pd
import joblib
import plotly.graph_objects as go

# ---- PATHS ----
# Saved model weights (loaded with torch.load further down).
MODEL_PATH = "/new_model_test/best_model_SE2_Down_Volume.pth"
# Fitted feature / target scalers (loaded with joblib further down).
SCALER_X_PATH = "/new_model_test/scaler_X_SE2_Down_Volume.joblib"
SCALER_Y_PATH = "/new_model_test/scaler_y_SE2_Down_Volume.joblib"
# Merged CSV; the code below expects a 'DateTime' column in it.
FILE_PATH = "/results/results_merged.csv"

# Region code used to build the column names.
REGION = "SE2"
# Single target column: activated mFRR down volume for REGION.
target_vars = [f"Balancing_ActivatedBalancingEnergy_{REGION}_mFRR_NotSpecifiedDownActivatedVolume"]

# -- Model class --
class MaskedLSTM(torch.nn.Module):
    """LSTM regressor that receives a mask tensor alongside the features.

    ``forward`` concatenates ``x`` and ``mask`` along the feature axis (so the
    LSTM input width is ``2 * input_size``), runs the LSTM, keeps the last time
    step and maps it through Linear -> ReLU -> Dropout -> Linear. The result is
    reshaped to (batch, horizon, output_size).
    """

    def __init__(self, input_size, hidden_size=32, num_layers=1, dense_size=32, output_size=1, horizon=1, dropout=0.08, dropout_lstm=0.27, bidirectional=True):
        super().__init__()
        self.horizon = horizon
        self.output_size = output_size

        # Features and mask are concatenated in forward(), doubling the width.
        lstm_input_width = 2 * input_size
        # Dropout inside the LSTM is only requested for stacked layers.
        inter_layer_dropout = dropout_lstm if num_layers > 1 else 0.0
        self.lstm = torch.nn.LSTM(
            input_size=lstm_input_width,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=inter_layer_dropout,
            bidirectional=bidirectional,
            batch_first=True,
        )

        # A bidirectional LSTM emits forward and backward states side by side.
        num_directions = 2 if bidirectional else 1
        self.fc1 = torch.nn.Linear(hidden_size * num_directions, dense_size)
        self.fc2 = torch.nn.Linear(dense_size, horizon * output_size)
        self.relu = torch.nn.ReLU()
        self.drop = torch.nn.Dropout(dropout)
        self.drop_lstm = torch.nn.Dropout(dropout_lstm)

    def forward(self, x, mask):
        """Return predictions of shape (batch, horizon, output_size)."""
        lstm_in = torch.cat((x, mask), dim=2)
        lstm_out, _ = self.lstm(lstm_in)
        # Dropout runs on the full sequence output before the last step is taken.
        last_step = self.drop_lstm(lstm_out)[:, -1, :]
        hidden = self.drop(self.relu(self.fc1(last_step)))
        flat = self.fc2(hidden)
        return flat.view(flat.size(0), self.horizon, self.output_size)

# Use the GPU when one is available, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Fitted scalers for the input features and the target.
# NOTE(review): joblib.load unpickles the file - only load trusted artefacts.
scaler_X = joblib.load(SCALER_X_PATH)
scaler_y = joblib.load(SCALER_Y_PATH)

# Hyper-parameters of the saved model. The architecture entries are passed to
# MaskedLSTM below and presumably must match the checkpoint - TODO confirm.
# learning_rate and batch_size are not read by the inference code in this cell.
best_params = {
    "hidden_size": 32,
    "num_layers": 1,
    "bidirectional": True,
    "dense_size": 32,
    "dropout": 0.08137184695714653,
    "dropout_lstm": 0.2675438586833995,
    "learning_rate": 0.0013843186375837144,
    "batch_size": 128,
    "seq_length": 216,
}
# Number of past rows (hourly after asfreq("h")) in every input sequence.
seq_length = best_params["seq_length"]

# --- Load and preprocess data ---
df_all = pd.read_csv(FILE_PATH)
price_column_name_to_plot = f"Balancing_PricesOfActivatedBalancingEnergy_{REGION}_mFRR_NotSpecifiedDownPrice"

# Keep a datetime-indexed copy of the price column for plotting; the column is
# excluded from the model inputs below.
if price_column_name_to_plot not in df_all.columns:
    price_data_full = None
    print(f"Warning: Price column '{price_column_name_to_plot}' not found in df_all. Will not be plotted.")
else:
    price_data_full = (
        df_all[['DateTime', price_column_name_to_plot]]
        .assign(DateTime=lambda frame: pd.to_datetime(frame['DateTime'], utc=True))
        .set_index('DateTime')
    )
    print(f"Stored '{price_column_name_to_plot}' separately for plotting.")

from filter_features import pick_region_filter
BDZ_filter = pick_region_filter(region=REGION, remove_balancing=True)

# Balancing price/volume columns are removed from the filter unless they are
# the prediction target.
drop_list = [
    f"Balancing_PricesOfActivatedBalancingEnergy_{REGION}_mFRR_NotSpecifiedUpPrice",
    f"Balancing_PricesOfActivatedBalancingEnergy_{REGION}_mFRR_NotSpecifiedDownPrice",
    f"Balancing_ActivatedBalancingEnergy_{REGION}_mFRR_NotSpecifiedUpActivatedVolume",
    f"Balancing_ActivatedBalancingEnergy_{REGION}_mFRR_NotSpecifiedDownActivatedVolume",
]
BDZ_filter = [c for c in BDZ_filter if c in target_vars or c not in drop_list]

# Index by UTC timestamp on a regular hourly grid; hours absent from the CSV
# become all-NaN rows.
df_all = (
    df_all
    .assign(DateTime=pd.to_datetime(df_all["DateTime"], utc=True))
    .set_index("DateTime")
    .asfreq("h")
)
df = df_all[BDZ_filter].copy().asfreq("h")


# 


# Switches for the optional engineered feature groups (all off here).
CREATE_INDICATOR_FEATURES = False
CREATE_CYCLICAL_FEATURES = False
CREATE_LAGGED_FEATURES = False

# Missing-value indicator columns: 1 where the source value is null.
new_indicator_features = []
if CREATE_INDICATOR_FEATURES:
    indicator_features_to_create = [col for col in BDZ_filter + target_vars if col in df.columns]
    for col in indicator_features_to_create:
        flag_name = f'{col}_was_missing'
        df[flag_name] = df[col].isnull().astype(int)
        new_indicator_features.append(flag_name)

# Sine/cosine encodings of hour, day of week and ISO week number.
new_time_features = []
if CREATE_CYCLICAL_FEATURES:
    cyclical_sources = (
        ('hour', df.index.hour, 24.0),
        ('dayofweek', df.index.dayofweek, 7.0),
        ('weekofyear', df.index.isocalendar().week, 52.0),
    )
    for prefix, values, period in cyclical_sources:
        angle = 2 * np.pi * values / period
        df[f'{prefix}_sin'] = np.sin(angle)
        df[f'{prefix}_cos'] = np.cos(angle)
        new_time_features += [f'{prefix}_sin', f'{prefix}_cos']

# Target values shifted by 1, 24 and 168 rows.
new_lagged_features = []
if CREATE_LAGGED_FEATURES:
    for target_var in target_vars:
        for lag in (1, 24, 168):
            lag_name = f'{target_var}_lag{lag}'
            df[lag_name] = df[target_var].shift(lag)
            new_lagged_features.append(lag_name)

# Discard columns that hold no observations at all.
df.dropna(axis=1, how="all", inplace=True)

# Model inputs: non-target filter columns plus the enabled feature groups,
# limited to columns still in df, de-duplicated and sorted alphabetically.
candidate_features = [c for c in BDZ_filter if c not in target_vars]
if CREATE_CYCLICAL_FEATURES:
    candidate_features += new_time_features
if CREATE_LAGGED_FEATURES:
    candidate_features += new_lagged_features
if CREATE_INDICATOR_FEATURES:
    candidate_features += new_indicator_features
exogenous_vars = sorted({name for name in candidate_features if name in df.columns})

# --- Instantiate the model ---
# Architecture settings come from best_params; the input and output widths are
# the number of features and targets.
architecture = {
    key: best_params[key]
    for key in ("hidden_size", "num_layers", "dense_size", "dropout", "dropout_lstm", "bidirectional")
}
model = MaskedLSTM(
    input_size=len(exogenous_vars),
    output_size=len(target_vars),
    **architecture,
).to(DEVICE)

# Load the saved weights onto the active device, then switch to eval mode.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(checkpoint)
model.eval()

# --- Window selection ---
# desired_start / desired_end are defined in the preceding cell. The window
# runs from seq_length rows before desired_start (the history the first input
# sequence needs) through desired_end. A ValueError is raised when
# desired_start is not in the index or has too little history before it.

all_dates = df.index
try:
    first_pred_idx = all_dates.get_loc(desired_start)
except KeyError as err:
    # Chain explicitly so the traceback names the failed lookup as the cause.
    raise ValueError(f"Desired start {desired_start} not in DataFrame index!") from err

if first_pred_idx < seq_length:
    raise ValueError(f"Not enough data before {desired_start} to build input sequence with seq_length={seq_length}.")

window_start = all_dates[first_pred_idx - seq_length]
window_end = desired_end

# Label-based slice, so both window_start and window_end are included.
df_pred_window = df.loc[window_start:window_end].copy()

# -- Scale features --
X_scaled = scaler_X.transform(df_pred_window[exogenous_vars])
# Targets are kept in original units for comparison with the predictions.
y_true = df_pred_window[target_vars].values

# -- Build sequences --
def create_sequences(X, lookback, horizon=1):
    """Slice a 2-D feature matrix into overlapping look-back windows.

    Window ``i`` holds rows ``i`` to ``i + lookback - 1`` of ``X``. ``horizon``
    rows are reserved after every window, so
    ``X.shape[0] - lookback - horizon + 1`` windows are produced.

    Args:
        X: Array-like of shape (n_rows, n_features).
        lookback: Number of rows in each window.
        horizon: Number of rows reserved after each window.

    Returns:
        Tuple ``(X_seq, mask_seq)`` of float32 arrays, both shaped
        (num_samples, lookback, n_features). ``mask_seq`` is all ones; the
        caller is expected to zero the positions it treats as missing.

    Raises:
        ValueError: If ``X`` is not 2-D or has fewer than
            ``lookback + horizon - 1`` rows.
    """
    X = np.asarray(X)
    if X.ndim != 2:
        raise ValueError(f"X must be 2-D (rows, features); got {X.ndim} dimension(s).")

    num_samples = X.shape[0] - lookback - horizon + 1
    if num_samples < 0:
        raise ValueError(
            f"Need at least {lookback + horizon - 1} rows for lookback={lookback} "
            f"and horizon={horizon}; got {X.shape[0]}."
        )

    # Row i of row_idx is [i, i+1, ..., i+lookback-1]; indexing X with it yields
    # all windows at once with shape (num_samples, lookback, n_features).
    row_idx = np.arange(num_samples)[:, None] + np.arange(lookback)[None, :]
    X_seq = X[row_idx].astype(np.float32)
    mask_seq = np.ones(X_seq.shape, dtype=np.float32)
    return X_seq, mask_seq

X_seq, mask_seq = create_sequences(X_scaled, seq_length, 1)

# Missing inputs are fed to the model as zeros, with a zero in the mask.
nan_pos = np.isnan(X_seq)
mask_seq[nan_pos] = 0.0
X_seq[nan_pos] = 0.0

# Sequence i covers rows i..i+seq_length-1, so its target row is i+seq_length.
first_target = seq_length
last_target = seq_length + len(X_seq)
dates_seq = df_pred_window.index[first_target:last_target]
y_true_aligned = y_true[first_target:last_target]

# -- Filter for predictions within the requested window only --
mask = (desired_start <= dates_seq) & (dates_seq <= desired_end)
dates_seq = dates_seq[mask]
y_true_aligned = y_true_aligned[mask]

with torch.no_grad():
    X_tensor = torch.tensor(X_seq, dtype=torch.float32).to(DEVICE)
    mask_tensor = torch.tensor(mask_seq, dtype=torch.float32).to(DEVICE)
    raw_output = model(X_tensor, mask_tensor)
    # Keep the first (only) horizon step -> (n, target_dim).
    y_pred_scaled = raw_output.cpu().numpy()[:, 0, :]

# Back to original units, then apply the same window filter as the dates.
y_pred = scaler_y.inverse_transform(y_pred_scaled)[mask]

# --- Add price data ---
# Best effort: if the price series cannot be aligned, warn and plot without it.
price_data_to_plot = None
if price_data_full is not None and price_column_name_to_plot in price_data_full.columns:
    try:
        price_at_dates = price_data_full.loc[dates_seq, price_column_name_to_plot]
        price_as_float = price_at_dates.values.astype(float)
    except Exception as e:
        print(f"Warning: Could not align price data with dates_seq: {e}")
    else:
        price_data_to_plot = price_as_float


import matplotlib.pyplot as plt
import matplotlib.dates as mdates

# --- Colors and markers ---
# Line colours; true/predicted reuse Plotly's default blue and red.
true_color = "#636efa"    # Plotly default blue
pred_color = "#ef553b"    # Plotly default red
price_color = "orange"

# --- Matplotlib plot ---
fig, ax1 = plt.subplots(figsize=(12, 7))

# True values. true_color is now applied; it was defined but never used, so
# this line fell back to matplotlib's default colour.
ax1.plot(dates_seq, y_true_aligned[:, 0], label='True', color=true_color, marker='o', markersize=5, linewidth=2)

# Predicted values
ax1.plot(dates_seq, y_pred[:, 0], label='Predicted', color=pred_color, marker='x', markersize=5, linewidth=2)

ax1.set_xlabel("Date", fontsize=17)
ax1.set_ylabel("Volume (MW)", fontsize=17)
ax1.tick_params(axis='y', labelsize=16)
ax1.tick_params(axis='x', labelsize=16)
ax1.xaxis.set_major_formatter(mdates.DateFormatter('%d %b %Y\n%H:%M'))
plt.xticks(rotation=45)
plt.grid(axis='y', alpha=0.3)

# Price goes on a secondary y-axis when it could be aligned with dates_seq.
if price_data_to_plot is not None:
    ax2 = ax1.twinx()
    ax2.plot(dates_seq, price_data_to_plot, label='mFRR Down Price', color=price_color, linewidth=2)
    ax2.set_ylabel("mFRR Down Price (€/MWh)", fontsize=17, color=price_color)
    ax2.tick_params(axis='y', labelcolor=price_color, labelsize=16)
    ax2.grid(False)
else:
    ax2 = None

# One combined legend covering the handles of both axes.
lines, labels = ax1.get_legend_handles_labels()
if ax2 is not None:
    price_lines, price_labels = ax2.get_legend_handles_labels()
    lines = lines + price_lines
    labels = labels + price_labels
plt.legend(lines, labels, loc='upper right', fontsize=16, frameon=True)

plt.tight_layout()
plt.show()



## Feature importance

## Integrated Gradients (IG)

In [59]:
# Attribution window as strings; the attribution cells below expand them into
# an hourly UTC range with pd.date_range(..., tz="UTC").
ATTR_START = "2024-04-07 09:00"
ATTR_END   = "2024-04-07 17:00"

from datetime import datetime
# Parsed copies of the bounds. The format has no timezone directive, so these
# are naive datetimes, unlike the tz-aware range built from the strings.
ATTR_START_dt = datetime.strptime(ATTR_START, "%Y-%m-%d %H:%M")
ATTR_END_dt   = datetime.strptime(ATTR_END, "%Y-%m-%d %H:%M")

In [None]:
attr_summaries = {}       # master container
# 


import os
import numpy as np
import torch
from captum.attr import IntegratedGradients
import matplotlib.pyplot as plt
import pandas as pd                 # already imported above

# ------------------------------------------------------------
# 0.  CONFIG & OUTPUT DIR
# ------------------------------------------------------------
plots_dir = "attribution_plots"
os.makedirs(plots_dir, exist_ok=True)

# ATTR_START / ATTR_END come from the cell above: one timestamp per hour, UTC.
attribution_times = pd.date_range(start=ATTR_START, end=ATTR_END, freq="h", tz="UTC")
print("Attribution times:", attribution_times)

# ------------------------------------------------------------
# 1.  BUILD INPUT / MASK TENSORS (scaled, masked and zero-filled as at inference)
# ------------------------------------------------------------
input_seqs, mask_seqs = [], []

for t in attribution_times:
    if t not in df.index:
        print(f"⚠️  {t} missing in df.index – skipped")
        continue
    idx = df.index.get_loc(t)
    if idx < seq_length:
        print(f"⚠️  Not enough history before {t} – skipped")
        continue

    # The seq_length rows immediately before t (row t itself is excluded).
    history = df.iloc[idx - seq_length: idx]
    window_raw = history[exogenous_vars].values
    window_scaled = scaler_X.transform(window_raw)

    # mask: 1 where a value is present, 0 where it is missing.
    is_missing = np.isnan(window_scaled)
    mask = (~is_missing).astype(np.float32)
    # Missing entries are passed to the model as zeros.
    window_scaled[is_missing] = 0.0

    input_seqs.append(window_scaled.astype(np.float32))
    mask_seqs.append(mask)

if not input_seqs:
    raise RuntimeError("No valid windows after filtering.")

# Stack the per-timestamp windows into (N, L, F) tensors on the active device.
X_attr      = torch.tensor(np.stack(input_seqs)).to(DEVICE)
X_mask_attr = torch.tensor(np.stack(mask_seqs)).to(DEVICE)

# Sanity check: neither tensor should contain NaNs at this point.
has_nan_x = torch.isnan(X_attr).any().item()
has_nan_mask = torch.isnan(X_mask_attr).any().item()
print("Any NaNs after preprocessing?  X:", has_nan_x,
      " mask:", has_nan_mask)

# ------------------------------------------------------------
# 2.  BASELINES  (in scaled space, no NaNs)
# ------------------------------------------------------------
# Feature baselines in scaled space, each shaped (1, L, F):
# all zeros, and the element-wise median / mean over the attribution samples.
zeros_X  = torch.zeros_like(X_attr[:1])
median_X = X_attr.median(dim=0, keepdim=True).values
mean_X   = X_attr.mean(dim=0, keepdim=True)

# Every baseline is paired with the mask of the first attribution sample.
first_mask = X_mask_attr[:1]
mean_M = first_mask

# Sanity check
expected_shape = (1, seq_length, len(exogenous_vars))
assert not torch.isnan(mean_X).any(), "mean baseline still has NaNs!"
assert mean_X.shape  == expected_shape
assert mean_M.shape  == expected_shape

baselines = {
    "zeros": (zeros_X, first_mask),
    "median": (median_X, first_mask),
    "mean": (mean_X, mean_M),
}

# ------------------------------------------------------------
# 3.  INTEGRATED GRADIENTS  – compatible with Captum < 0.7
# ------------------------------------------------------------
def model_forward(x, m):
    """Forward function handed to Captum: the model applied to (features, mask)."""
    return model(x, m)

ig = IntegratedGradients(model_forward)

attr_results = {}
for name, baseline_pair in baselines.items():
    base_x, base_m = baseline_pair
    # The single baseline row is broadcast to match the batch of inputs.
    expanded_baselines = (base_x.expand_as(X_attr), base_m.expand_as(X_mask_attr))

    attributions, delta = ig.attribute(
        inputs=(X_attr, X_mask_attr),
        baselines=expanded_baselines,
        target=0,
        n_steps=64,
        method="riemann_trapezoid",
        internal_batch_size=32,
        return_convergence_delta=True,
    )
    attr_x, attr_mask = attributions

    print(f"{name:9s} | attr range [{attr_x.min():.2e}, {attr_x.max():.2e}] "
          f"δ-mean={delta.abs().mean():.2e}")

    # Only the feature attributions (N, L, F) are kept; the mask part is dropped.
    attr_results[name] = attr_x.cpu().numpy()


# ------------------------------------------------------------
# 4.  AGGREGATE IMPORTANCE
#     mean |attr| over (sample, time) per feature, normalised to sum to ~1
#     (the plots multiply by 100 to show percent)
# ------------------------------------------------------------
attr_summary = {}
for name, a in attr_results.items():
    importance = np.abs(a).mean(axis=(0, 1))   # (F,)
    importance /= importance.sum() + 1e-12     # epsilon guards an all-zero sum
    attr_summary[name] = importance

# ------------------------------------------------------------
# 5.  PLOT
# ------------------------------------------------------------
top_k   = 10
colors  = ["#3498db", "#2ecc71", "#e74c3c"]   # one colour per baseline column
n_cols  = 3                                   # the grid has three baseline columns
fig, axes = plt.subplots(2, n_cols, figsize=(20, 8))
axes = axes.flatten()

# Row 1: bar chart of the top features per baseline.
# Row 2: table mapping the bar IDs to feature names.
for i, (name, imp) in enumerate(attr_summary.items()):
    idx_sorted = np.argsort(imp)[-top_k:][::-1]
    feature_indices = idx_sorted      # indices into exogenous_vars
    # Can be fewer than top_k when the model has fewer than top_k features;
    # sizing everything from it keeps bars, ticks and table rows in step.
    n_shown = len(feature_indices)
    ids = [f"#{idx+1}" for idx in feature_indices]   # consistent ID = index+1

    # --- bars ---
    ax = axes[i]
    ax.bar(range(n_shown), imp[feature_indices] * 100, color=colors[i])       # %
    ax.set_xticks(range(n_shown))
    ax.set_xticklabels(ids)
    ax.set_xlabel("Feature ID")
    ax.set_ylabel("Contribution (%)")
    ax.set_title(f"Integrated Gradients – {name}")

    # --- table ---
    ax_t = axes[i + n_cols]
    table_data = [[ids[j], exogenous_vars[feature_indices[j]]]
                  for j in range(n_shown)]
    ax_t.axis("off")
    ax_t.set_title(f"{name}  mapping", pad=10)
    tbl = ax_t.table(
        cellText=table_data, colLabels=["ID", "Feature"],
        cellLoc="left", loc="center",
        colWidths=[0.07, 0.93],
    )
    tbl.auto_set_font_size(False)
    tbl.set_fontsize(9)
    tbl.scale(1.2, 1.2)

# Hide grid columns that have no baseline so no empty frames are drawn.
for unused in range(len(attr_summary), n_cols):
    axes[unused].axis("off")
    axes[unused + n_cols].axis("off")

plt.suptitle("Feature Attribution (Integrated Gradients)\n"
             f"{ATTR_START} – {ATTR_END}", fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.96])

out_path = os.path.join(plots_dir, "integrated_gradients_selected_window.png")
plt.savefig(out_path, dpi=300, bbox_inches="tight")
plt.show()
print("Saved ➜", out_path)

# Store a copy of this method's summaries in the shared container.
attr_summaries["IG"] = {k: v.copy() for k, v in attr_summary.items()}


## DeepLIFT SHAP (DL)

In [None]:
import os, numpy as np, torch, matplotlib.pyplot as plt, pandas as pd
from captum.attr import DeepLiftShap

# ------------------------------------------------------------
# 0.  CONFIG & OUTPUT DIR
# ------------------------------------------------------------
plots_dir = "attribution_plots"
os.makedirs(plots_dir, exist_ok=True)

# ATTR_START / ATTR_END come from an earlier cell: one timestamp per hour, UTC.
attribution_times = pd.date_range(start=ATTR_START, end=ATTR_END, freq="h", tz="UTC")
print("Attribution times:", attribution_times)

# ------------------------------------------------------------
# 1.  BUILD INPUT / MASK TENSORS
# ------------------------------------------------------------
input_seqs, mask_seqs = [], []
for t in attribution_times:
    if t not in df.index:
        print(f"⚠️  {t} missing in df.index – skipped")
        continue
    idx = df.index.get_loc(t)
    if idx < seq_length:
        print(f"⚠️  Not enough history before {t} – skipped")
        continue

    # Scale the seq_length rows that precede t (row t itself is excluded).
    window_raw = df.iloc[idx - seq_length: idx][exogenous_vars].values
    window_scaled = scaler_X.transform(window_raw)

    # mask is 1 for present values and 0 for missing ones; missing inputs
    # are then replaced by zeros.
    missing_here = np.isnan(window_scaled)
    mask = (~missing_here).astype(np.float32)
    window_scaled[missing_here] = 0.0

    input_seqs.append(window_scaled.astype(np.float32))
    mask_seqs.append(mask)

if not input_seqs:
    raise RuntimeError("No valid windows after filtering.")

# (N, L, F) tensors on the active device.
X_attr      = torch.tensor(np.stack(input_seqs)).to(DEVICE)
X_mask_attr = torch.tensor(np.stack(mask_seqs)).to(DEVICE)
print("Any NaNs? X:", torch.isnan(X_attr).any().item(),
      "mask:", torch.isnan(X_mask_attr).any().item())

# ------------------------------------------------------------
# 2.  BASELINES (mean instead of pre-spike)
# ------------------------------------------------------------
# Zero and element-wise median baselines in scaled space, each (1, L, F).
zeros_X  = torch.zeros_like(X_attr[:1])
median_X = torch.median(X_attr, dim=0, keepdim=True).values

# Mean baseline: average the raw windows, then scale and mask the result.
# Timestamps are filtered exactly as in the input-building loop; without the
# membership check, get_loc raised KeyError on a timestamp missing from
# df.index that the first loop had merely skipped.
mean_raw = []
for t in attribution_times:
    if t not in df.index:
        continue
    idx = df.index.get_loc(t)
    if idx < seq_length:
        continue
    window_raw = df.iloc[idx - seq_length: idx][exogenous_vars]
    mean_raw.append(window_raw.values)
if not mean_raw:
    raise ValueError("No valid windows to compute mean baseline.")
# np.mean propagates NaN: a cell missing in any window is NaN in the mean and
# is therefore masked out (mask 0, value 0) below.
mean_raw = np.mean(np.stack(mean_raw), axis=0)  # shape: (L, F)
mean_scaled = scaler_X.transform(mean_raw)
mask_mean = (~np.isnan(mean_scaled)).astype(np.float32)
mean_scaled[np.isnan(mean_scaled)] = 0.0
baseline_mean_X = torch.tensor(mean_scaled, dtype=torch.float32).unsqueeze(0).to(DEVICE)
baseline_mean_M = torch.tensor(mask_mean,        dtype=torch.float32).unsqueeze(0).to(DEVICE)

# zeros/median reuse the first sample's mask; mean carries its own mask.
baselines = {
    "zeros":     (zeros_X,        X_mask_attr[:1]),
    "median":    (median_X,       X_mask_attr[:1]),
    "mean":      (baseline_mean_X, baseline_mean_M),
}

# ------------------------------------------------------------
# 3.  DEEPLIFT SHAP (wrapped model)
# ------------------------------------------------------------
class Wrapper(torch.nn.Module):
    """Thin nn.Module around the model; forwards (features, mask) unchanged."""

    def __init__(self, core):
        super().__init__()
        self.core = core

    def forward(self, x, m):
        return self.core(x, m)


dlshap = DeepLiftShap(Wrapper(model))

attr_results = {}
for name, baseline_pair in baselines.items():
    base_x, base_m = baseline_pair
    # The single baseline row is broadcast to match the batch of inputs.
    expanded_baselines = (base_x.expand_as(X_attr), base_m.expand_as(X_mask_attr))

    attributions, delta = dlshap.attribute(
        inputs=(X_attr, X_mask_attr),
        baselines=expanded_baselines,
        target=0,
        return_convergence_delta=True,
    )
    attr_x, _attr_m = attributions

    print(f"{name:9s} | attr range "
          f"[{attr_x.min():.2e}, {attr_x.max():.2e}] "
          f"δ-mean={delta.abs().mean():.2e}")
    # Only the feature attributions (N, L, F) are kept.
    attr_results[name] = attr_x.detach().cpu().numpy()

# ------------------------------------------------------------
# 4.  AGGREGATE IMPORTANCE
#     mean |attr| over (sample, time) per feature, normalised to sum to ~1
# ------------------------------------------------------------
attr_summary = {}
for name, a in attr_results.items():
    importance = np.abs(a).mean(axis=(0, 1))
    importance /= importance.sum() + 1e-12     # epsilon guards an all-zero sum
    attr_summary[name] = importance

# ------------------------------------------------------------
# 5.  PLOT (only titles & file name changed)
# ------------------------------------------------------------
top_k, colors = 10, ["#3498db", "#2ecc71", "#e74c3c"]   # one colour per baseline column
n_cols = 3                                              # the grid has three baseline columns
fig, axes = plt.subplots(2, n_cols, figsize=(20, 8))
axes = axes.flatten()

# Row 1: bar chart of the top features per baseline.
# Row 2: table mapping the bar IDs to feature names.
for i, (name, imp) in enumerate(attr_summary.items()):
    idx_sorted = np.argsort(imp)[-top_k:][::-1]
    # Can be fewer than top_k when the model has fewer than top_k features;
    # sizing everything from it keeps bars, ticks and table rows in step.
    n_shown = len(idx_sorted)
    ids = [f"#{idx+1}" for idx in idx_sorted]

    ax = axes[i]
    ax.bar(range(n_shown), imp[idx_sorted] * 100, color=colors[i])
    ax.set_xticks(range(n_shown))
    ax.set_xticklabels(ids)
    ax.set_xlabel("Feature ID")
    ax.set_ylabel("Contribution (%)")
    ax.set_title(f"DeepLift SHAP – {name}")

    ax_t = axes[i + n_cols]
    ax_t.axis("off")
    ax_t.set_title(f"{name}  mapping", pad=10)
    tbl = ax_t.table(cellText=[[ids[j], exogenous_vars[idx_sorted[j]]] for j in range(n_shown)],
                     colLabels=["ID", "Feature"], cellLoc="left", loc="center",
                     colWidths=[0.07, 0.93])
    tbl.auto_set_font_size(False)
    tbl.set_fontsize(9)
    tbl.scale(1.2, 1.2)

# Hide grid columns that have no baseline so no empty frames are drawn.
for unused in range(len(attr_summary), n_cols):
    axes[unused].axis("off")
    axes[unused + n_cols].axis("off")

plt.suptitle("Feature Attribution (DeepLift SHAP)\n"
             f"{ATTR_START} – {ATTR_END}", fontsize=16)
plt.tight_layout(rect=[0,0,1,0.96])
out_path = os.path.join(plots_dir, "deeplift_shap_selected_window.png")
plt.savefig(out_path, dpi=300, bbox_inches="tight")
plt.show()
print("Saved ➜", out_path)

# Store a copy of this method's summaries in the shared container.
attr_summaries["DL_SHAP"] = {k: v.copy() for k, v in attr_summary.items()}


## Feature Ablation (FA)

In [None]:
import os
import numpy as np
import torch
from captum.attr import FeatureAblation
import matplotlib.pyplot as plt
import pandas as pd

# ------------------------------------------------------------
# 0.  CONFIG & OUTPUT DIR
# ------------------------------------------------------------
plots_dir = "attribution_plots"
os.makedirs(plots_dir, exist_ok=True)

# ATTR_START / ATTR_END come from an earlier cell: one timestamp per hour, UTC.
attribution_times = pd.date_range(start=ATTR_START, end=ATTR_END, freq="h", tz="UTC")
print("Attribution times:", attribution_times)

# ------------------------------------------------------------
# 1.  BUILD INPUT / MASK TENSORS (scaled, masked and zero-filled as at inference)
# ------------------------------------------------------------
input_seqs, mask_seqs = [], []

for t in attribution_times:
    if t not in df.index:
        print(f"⚠️  {t} missing in df.index – skipped")
        continue
    idx = df.index.get_loc(t)
    if idx < seq_length:
        print(f"⚠️  Not enough history before {t} – skipped")
        continue

    # The seq_length rows immediately before t (row t itself is excluded).
    preceding_rows = df.iloc[idx - seq_length: idx]
    window_raw = preceding_rows[exogenous_vars].values
    window_scaled = scaler_X.transform(window_raw)

    # mask: 1 where a value is present, 0 where it is missing.
    absent = np.isnan(window_scaled)
    mask = (~absent).astype(np.float32)
    # Missing entries are passed to the model as zeros.
    window_scaled[absent] = 0.0

    input_seqs.append(window_scaled.astype(np.float32))
    mask_seqs.append(mask)

if not input_seqs:
    raise RuntimeError("No valid windows after filtering.")

# Stack the per-timestamp windows into (N, L, F) tensors on the active device.
X_attr      = torch.tensor(np.stack(input_seqs)).to(DEVICE)
X_mask_attr = torch.tensor(np.stack(mask_seqs)).to(DEVICE)

# Sanity check: neither tensor should contain NaNs at this point.
x_has_nan = torch.isnan(X_attr).any().item()
mask_has_nan = torch.isnan(X_mask_attr).any().item()
print("Any NaNs after preprocessing?  X:", x_has_nan,
      " mask:", mask_has_nan)

# ------------------------------------------------------------
# 2.  BASELINES  (in scaled space, no NaNs) – USE MEAN INSTEAD OF PRE-SPIKE
# ------------------------------------------------------------
# Zero and element-wise median baselines in scaled space, each (1, L, F).
zeros_X  = torch.zeros_like(X_attr[:1])
median_X = torch.median(X_attr, dim=0, keepdim=True).values

# Mean baseline: average the raw windows, then scale and mask the result.
# Timestamps are filtered exactly as in the input-building loop; without the
# membership check, get_loc raised KeyError on a timestamp missing from
# df.index that the first loop had merely skipped.
mean_raw = []
for t in attribution_times:
    if t not in df.index:
        continue
    idx = df.index.get_loc(t)
    if idx < seq_length:
        continue
    window_raw = df.iloc[idx - seq_length: idx][exogenous_vars]
    mean_raw.append(window_raw.values)
if not mean_raw:
    raise ValueError("No valid windows to compute mean baseline.")
# np.mean propagates NaN: a cell missing in any window is NaN in the mean and
# is therefore masked out (mask 0, value 0) below.
mean_raw = np.mean(np.stack(mean_raw), axis=0)  # shape: (L, F)
mean_scaled = scaler_X.transform(mean_raw)
mask_mean = (~np.isnan(mean_scaled)).astype(np.float32)
mean_scaled[np.isnan(mean_scaled)] = 0.0
baseline_mean_X = torch.tensor(mean_scaled, dtype=torch.float32).unsqueeze(0).to(DEVICE)
baseline_mean_M = torch.tensor(mask_mean,        dtype=torch.float32).unsqueeze(0).to(DEVICE)

# Only zeros and median are ablated here. The mean baseline is still computed
# above but deliberately left out of the dict.
baselines = {
    "zeros":     (zeros_X,   X_mask_attr[:1]),
    "median":    (median_X,  X_mask_attr[:1]),
    # "mean":      (baseline_mean_X, baseline_mean_M),
}

# ------------------------------------------------------------
# 3.  FEATURE ABLATION  – Captum
# ------------------------------------------------------------
def model_forward(x, m):
    """Forward function handed to Captum: the model applied to (features, mask)."""
    return model(x, m)

fa = FeatureAblation(model_forward)

attr_results = {}
for name, baseline_pair in baselines.items():
    base_x, base_m = baseline_pair
    # The single baseline row is broadcast to match the batch of inputs.
    expanded_baselines = (base_x.expand_as(X_attr), base_m.expand_as(X_mask_attr))

    attr_tuple = fa.attribute(
        inputs=(X_attr, X_mask_attr),
        baselines=expanded_baselines,
        target=0,
        perturbations_per_eval=32,    # number of perturbations evaluated per forward pass
        feature_mask=None,            # default: each scalar is its own feature
    )

    # attr_tuple is (attr_X, attr_mask); only the feature part is used.
    attr_x = attr_tuple[0]
    print(f"{name:9s} | attr range [{attr_x.min():.2e}, {attr_x.max():.2e}]")

    attr_results[name] = attr_x.cpu().numpy()   # (N, L, F)

# ------------------------------------------------------------
# 4.  AGGREGATE IMPORTANCE
#     mean |attr| over (sample, time) per feature, normalised to sum to ~1
#     (the plots multiply by 100 to show percent)
# ------------------------------------------------------------
attr_summary = {}
for name, a in attr_results.items():
    importance = np.abs(a).mean(axis=(0, 1))   # (F,)
    importance /= importance.sum() + 1e-12     # epsilon guards an all-zero sum
    attr_summary[name] = importance

# ------------------------------------------------------------
# 5.  PLOT
# ------------------------------------------------------------
top_k   = 10
colors  = ["#3498db", "#2ecc71", "#e74c3c"]   # one colour per baseline column
n_cols  = 3                                   # the grid has three baseline columns
fig, axes = plt.subplots(2, n_cols, figsize=(20, 8))
axes = axes.flatten()

# Row 1: bar chart of the top features per baseline.
# Row 2: table mapping the bar IDs to feature names.
for i, (name, imp) in enumerate(attr_summary.items()):
    idx_sorted = np.argsort(imp)[-top_k:][::-1]
    feature_indices = idx_sorted      # indices into exogenous_vars
    # Can be fewer than top_k when the model has fewer than top_k features;
    # sizing everything from it keeps bars, ticks and table rows in step.
    n_shown = len(feature_indices)
    ids = [f"#{idx+1}" for idx in feature_indices]   # consistent ID = index+1

    # --- bars ---
    ax = axes[i]
    ax.bar(range(n_shown), imp[feature_indices] * 100, color=colors[i])       # %
    ax.set_xticks(range(n_shown))
    ax.set_xticklabels(ids)
    ax.set_xlabel("Feature ID")
    ax.set_ylabel("Contribution (%)")
    ax.set_title(f"Feature Ablation – {name}")

    # --- table ---
    ax_t = axes[i + n_cols]
    table_data = [[ids[j], exogenous_vars[feature_indices[j]]]
                  for j in range(n_shown)]
    ax_t.axis("off")
    ax_t.set_title(f"{name}  mapping", pad=10)
    tbl = ax_t.table(
        cellText=table_data, colLabels=["ID", "Feature"],
        cellLoc="left", loc="center",
        colWidths=[0.07, 0.93],
    )
    tbl.auto_set_font_size(False)
    tbl.set_fontsize(9)
    tbl.scale(1.2, 1.2)

# Only two baselines are ablated, so the third grid column has nothing to show;
# hide unused columns instead of leaving empty frames in the figure.
for unused in range(len(attr_summary), n_cols):
    axes[unused].axis("off")
    axes[unused + n_cols].axis("off")

plt.suptitle("Feature Attribution (Feature Ablation)\n"
             f"{ATTR_START} – {ATTR_END}", fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.96])

out_path = os.path.join(plots_dir, "feature_ablation_selected_window.png")
plt.savefig(out_path, dpi=300, bbox_inches="tight")
plt.show()
print("Saved ➜", out_path)

# Store a copy of this method's summaries in the shared container.
attr_summaries["FA"] = {k: v.copy() for k, v in attr_summary.items()}


## Top params

In [None]:
# ------------------------------------------------------------------
# ❶  PARAMETERS & SANITY-CHECK
# ------------------------------------------------------------------
top_k = 10
methods = list(attr_summaries.keys())                        # e.g. ['IG', 'DL_SHAP', 'FA']

# Union of the baselines used by any method, in first-seen order. Methods do
# not all use the same baselines (Feature Ablation skips 'mean'), so taking
# the list from the first method alone could raise KeyError further down.
baselines = []
for method in methods:
    for baseline in attr_summaries[method]:
        if baseline not in baselines:
            baselines.append(baseline)           # e.g. ['zeros', 'median', 'mean']

print("Methods   :", methods)
print("Baselines :", baselines)

# ------------------------------------------------------------------
# ❷  COLLECT UNIQUE INDICES PER BASELINE
# ------------------------------------------------------------------
baseline_to_idxs = {b: set() for b in baselines}             # { 'zeros': set(), ... }
baseline_n_methods = dict.fromkeys(baselines, 0)             # methods contributing per baseline

for method in methods:
    for baseline, importance in attr_summaries[method].items():
        top_idx = np.argsort(importance)[-top_k:]            # this method+baseline's top-k
        baseline_to_idxs[baseline].update(int(i) for i in top_idx)
        baseline_n_methods[baseline] += 1

# ------------------------------------------------------------------
# ❸  REPORT
# ------------------------------------------------------------------
for baseline in baselines:
    idxs  = sorted(baseline_to_idxs[baseline])
    names = [exogenous_vars[i] for i in idxs]

    print(f"\nBaseline: {baseline}")
    # Report the number of methods that actually used this baseline.
    print(f"Unique top {top_k} features across {baseline_n_methods[baseline]} methods: {len(names)}")
    print("-"*60)
    for n, feat in enumerate(names, 1):
        print(f"{n:2d}. {feat}")


## Top params plot

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

# --------------------------------------------------
# ❶ CONFIG - presentation settings for the comparison figure
# --------------------------------------------------
top_k = 10
methods = ["IG", "DL_SHAP", "FA"]
# Scientific color palette (colorblind-friendly); one colour per entry of `methods`
colors = ["#4477AA", "#66CCEE", "#EE6677"]
bar_width = 0.25
font_family = 'serif'  # Academic standard

# Base style first, then all font overrides in a single rcParams update
plt.style.use('seaborn-v0_8-whitegrid')
mpl.rcParams.update({
    'font.family': font_family,
    'font.size': 16,
    'axes.titlesize': 18,
    'axes.labelsize': 17,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
    'legend.fontsize': 16,
    'figure.titlesize': 20,
})

# --------------------------------------------------
# ❷ UNION OF TOP-k FEATURE INDICES PER BASELINE ('mean' baseline skipped)
# --------------------------------------------------
all_baselines = list(attr_summaries[methods[0]])
baselines = [name for name in all_baselines if name != "mean"]
baseline_to_idxs = {name: set() for name in baselines}

for m in methods:
    for b, imp in attr_summaries[m].items():
        if b != "mean":
            # ascending argsort: the trailing top_k positions hold the largest scores
            top_idx = np.argsort(imp)[-top_k:]
            baseline_to_idxs[b].update(top_idx)

# --------------------------------------------------
# ❸ BUILD THE FIGURE - Enhanced for thesis presentation
# --------------------------------------------------
# One grouped-bar panel per baseline. squeeze=False always returns a 2-D axes
# array, so a single baseline needs no special case before indexing.
fig, axes = plt.subplots(len(baselines), 1,
                         figsize=(10, 3.5*len(baselines)),
                         constrained_layout=True,
                         squeeze=False)
axes = axes.ravel()

# Legend labels for the attribution methods
method_names = {
    "IG": "Integrated Gradient",
    "DL_SHAP": "DeepLIFT SHAP",
    "FA": "Feature Ablation"
}

# Hatching keeps the series distinguishable in grayscale print
hatch_patterns = ['', '///', '...']
# Centre the group of bars on each tick regardless of how many methods there are
centre_shift = (len(methods) - 1) / 2

for row, baseline in enumerate(baselines):
    ax = axes[row]
    idxs = sorted(baseline_to_idxs[baseline])
    n = len(idxs)

    # x-locations for the centre of each "feature group"
    x_centres = np.arange(n)

    for j, (method, color, hatch) in enumerate(zip(methods, colors, hatch_patterns)):
        imp = attr_summaries[method][baseline] * 100
        heights = imp[idxs]

        ax.bar(x_centres + (j - centre_shift)*bar_width, heights,
               width=bar_width, color=color, alpha=0.85,
               hatch=hatch,
               # label only in the first panel so the shared legend has one entry per method
               label=method_names[method] if row==0 else "",
               edgecolor='black', linewidth=0.5)

    ax.set_title(f"Baseline: {baseline.capitalize()}", loc='left', fontweight='normal')
    ax.set_ylabel("Contribution (%)", fontweight='normal')
    ax.set_xlabel("Feature Index", fontweight='normal')
    ax.set_xticks(x_centres)

    # Tick labels show the 1-based feature number
    feature_labels = [f"#{i+1}" for i in idxs]
    ax.set_xticklabels(feature_labels, rotation=45, ha='right')

    ax.grid(axis="y", alpha=0.3, linestyle='--')

    # Draw all four spines as a thin frame
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(0.75)

    # Reference line at y=0
    ax.axhline(y=0, color='black', linestyle='-', linewidth=0.8, alpha=0.5)

# Shared framed legend, horizontally centred and anchored at the bottom edge
# of the figure (below all subplots)
handles, labels = axes[0].get_legend_handles_labels()
legend = fig.legend(handles, labels, loc='upper center',
                   bbox_to_anchor=(0.5, 0.01),
                   ncol=len(methods), frameon=True,
                   fancybox=True, shadow=True)

plt.savefig('feature_importance_comparison.pdf', bbox_inches='tight', dpi=300)
plt.savefig('feature_importance_comparison.png', bbox_inches='tight', dpi=300)
plt.show()

# --------------------------------------------------
# ❹ PRINT THE "#index – feature name" TABLE FOR EACH BASELINE
# --------------------------------------------------
thin_rule = "-" * 40

print("\nTable X: Feature Index Mapping")
print("=" * 50)
for baseline in baselines:
    idxs = sorted(baseline_to_idxs[baseline])
    print(f"\nBaseline: {baseline.capitalize()}")
    print(thin_rule)
    print(f"{'Index':<8} {'Feature Name':<30}")
    print(thin_rule)
    for feature_pos in idxs:
        # displayed number is 1-based, the lookup into exogenous_vars is 0-based
        print(f"#{feature_pos+1:<7} {exogenous_vars[feature_pos]:<30}")

In [None]:
# %%
# Build and print a {1-based index: feature name} lookup for exogenous_vars

feature_dict = dict(enumerate(exogenous_vars, start=1))
print("Index-to-Feature Mapping:")
print(feature_dict)


## Feature plots

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
from math import ceil
import matplotlib.dates as mdates
from matplotlib.ticker import MaxNLocator
import textwrap

# --------------------------------------------------
# ❶ USER CONFIG - analysis window, highlighted period and figure styling
# --------------------------------------------------
top_n = 6
ATTR_START = pd.Timestamp("2024-04-03 10:00", tz="UTC")
ATTR_END = pd.Timestamp("2024-04-11 16:00", tz="UTC")
PEAK_START = pd.Timestamp("2024-04-07 10:00", tz="UTC")
PEAK_END = pd.Timestamp("2024-04-07 16:00", tz="UTC")

plot_style = "subplots"
use_zscore = False

plt.style.use('seaborn-v0_8-whitegrid')
font_family = 'serif'
mpl.rcParams.update({
    'font.family': font_family,
    'font.size': 12,
    'axes.titlesize': 14,
    'axes.labelsize': 13,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'legend.fontsize': 12,
    'figure.titlesize': 16,
})

colors = plt.cm.tab10.colors

methods = ["IG", "DL_SHAP", "FA"]
method_names = {
    "IG": "Integrated Gradient",
    "DL_SHAP": "DeepLIFT SHAP",
    "FA": "Feature Ablation"
}
all_baselines = list(attr_summaries[methods[0]])
baselines = [name for name in all_baselines if name != "mean"]

# Rescale every (method, baseline) attribution vector by its own sum; the
# 1e-12 term keeps the division defined when a vector sums to zero.
attr_pct = {m: {} for m in methods}
for m in methods:
    for b in baselines:
        raw = attr_summaries[m][b].astype(float)
        attr_pct[m][b] = raw / (raw.sum() + 1e-12)

# Average the rescaled vectors over all method/baseline pairs and keep the
# top_n features, highest mean first.
all_vectors = [attr_pct[m][b] for m in methods for b in baselines]
mean_imp = np.mean(all_vectors, axis=0)
top_idx = np.argsort(mean_imp)[-top_n:][::-1]
top_names = [exogenous_vars[i] for i in top_idx]

# Values of the selected features inside the attribution window
slice_df = df.loc[ATTR_START:ATTR_END, top_names]
if use_zscore and plot_style == "overlay":
    slice_df = (slice_df - slice_df.mean()) / slice_df.std(ddof=0)

# --------------------------------------------------
# ❺ PLOT - Enhanced for thesis quality: 3x2 subplots, wrapped titles
# --------------------------------------------------
# ----- THIS FUNCTION CHOOSES THE BEST LOCATOR BASED ON DURATION -----
def get_time_locator(start, end):
    """Pick an x-axis tick locator and date formatter for the span start..end.

    Spans up to 48 h get a tick every 2 h, up to 96 h every 6 h, up to 168 h
    every 12 h; anything longer gets one tick per day.

    Returns:
        A ``(locator, formatter)`` tuple of matplotlib.dates objects.
    """
    span_hours = (end - start).total_seconds() / 3600
    # (longest span in hours, tick spacing in hours, tick label format)
    hourly_rules = (
        (48, 2, '%H:%M\n%d-%b'),
        (96, 6, '%d-%b\n%H:%M'),
        (168, 12, '%d-%b\n%H:%M'),
    )
    for max_hours, step, label_format in hourly_rules:
        if span_hours <= max_hours:
            return mdates.HourLocator(interval=step), mdates.DateFormatter(label_format)
    return mdates.DayLocator(interval=1), mdates.DateFormatter('%d-%b')
# --------------------------------------------------------------------

# Draw the top features either on one shared axis ("overlay") or on a 3x2
# grid with one panel per feature ("subplots"), then save as PDF and PNG.
if plot_style == "overlay":
    fig, ax = plt.subplots(figsize=(10, 6))
    # The labelled span gives the peak period its own legend entry
    ax.axvspan(PEAK_START, PEAK_END, alpha=0.15, color='gray', label='Peak Period')
    line_styles = ['-', '--', ':', '-.', '-', '--']
    for idx, name, style in zip(top_idx, top_names, line_styles):
        ax.plot(slice_df.index, slice_df[name],
                linewidth=2.5,
                color=colors[idx % len(colors)],
                linestyle=style,
                label=f"#{idx+1} {name}")
    locator, formatter = get_time_locator(ATTR_START, ATTR_END)
    ax.xaxis.set_major_locator(locator)
    ax.xaxis.set_major_formatter(formatter)
    ax.set_title("Temporal Evolution of Top Feature Contributions", fontweight='normal', pad=15)
    ax.set_ylabel("Z-Score" if use_zscore else "Feature Value", fontweight='normal')
    ax.set_xlabel("Date", fontweight='normal')
    ax.grid(alpha=0.3, linestyle='--')
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(0.75)
    # "upper left" is a valid legend location ("bottom left" is not); anchored
    # at (1.01, 1) it places the legend just outside the right edge of the axes.
    ax.legend(bbox_to_anchor=(1.01, 1), loc="upper left", frameon=True, fancybox=True, shadow=True)
    fig.tight_layout()

elif plot_style == "subplots":
    nrows, ncols = 3, 2
    fig, axes = plt.subplots(nrows, ncols, figsize=(13, 9), sharex=True, constrained_layout=True)
    axes = axes.flatten()
    fig.suptitle("Temporal Evolution of Key Feature Contributions (Case 2)", fontweight='normal', y=1.06)
    wrap_width = 65

    locator, formatter = get_time_locator(ATTR_START, ATTR_END)

    for i, (ax, idx, name) in enumerate(zip(axes, top_idx, top_names)):
        color = colors[idx % len(colors)]
        ax.axvspan(PEAK_START, PEAK_END, alpha=0.2, color='gray')
        ax.plot(slice_df.index, slice_df[name], linewidth=2.5, color=color)
        # Title: 1-based feature number, name and mean share in percent,
        # wrapped so long names stay inside the panel width
        feature_title = f"#{idx+1}: {name} ({mean_imp[idx]*100:.2f}%)"
        feature_title_wrapped = "\n".join(textwrap.wrap(feature_title, wrap_width))
        ax.set_title(feature_title_wrapped, loc='left', fontweight='normal', fontsize=11)
        ax.set_ylabel("Value", fontweight='normal')
        ax.yaxis.set_major_locator(MaxNLocator(nbins=5))
        ax.grid(alpha=0.3, linestyle='--')
        for spine in ax.spines.values():
            spine.set_visible(True)
            spine.set_linewidth(0.75)
        # Shade the observed value range; a constant series gets a small
        # symmetric band so the shading has non-zero height
        val_min = slice_df[name].min()
        val_max = slice_df[name].max()
        if val_min == val_max:
            delta = val_min * 0.01 if val_min != 0 else 0.01
            ax.axhspan(val_min - delta, val_max + delta, alpha=0.1, color=color)
        else:
            ax.axhspan(val_min, val_max, alpha=0.1, color=color)
        if i == 0:
            ax.text(0.98, 0.95, 'Peak Period', transform=ax.transAxes,
                    bbox=dict(facecolor='gray', alpha=0.2, edgecolor='none', pad=3),
                    ha='right', va='top', fontsize=10)
    # Remove any unused axes (only happens when top_n < nrows*ncols)
    for j in range(top_n, nrows * ncols):
        fig.delaxes(axes[j])
    # Date ticks and axis label on the bottom row
    for ax in axes[-ncols:]:
        ax.xaxis.set_major_locator(locator)
        ax.xaxis.set_major_formatter(formatter)
        ax.set_xlabel("Date", fontweight='normal')
    period_text = (f"Period: {ATTR_START.strftime('%Y-%m-%d %H:%M')} to {ATTR_END.strftime('%Y-%m-%d %H:%M')} UTC | "
                  f"Peak Activity: {PEAK_START.strftime('%Y-%m-%d %H:%M')} to {PEAK_END.strftime('%Y-%m-%d %H:%M')} UTC")
    fig.text(0.5, 1.02, period_text, ha='center', fontstyle='italic', fontsize=11)

else:
    raise ValueError("plot_style must be 'subplots' or 'overlay'")

plt.savefig('feature_timeseries_analysis.pdf', bbox_inches='tight', dpi=300)
plt.savefig('feature_timeseries_analysis.png', bbox_inches='tight', dpi=300)
plt.show()

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
from math import ceil
import matplotlib.dates as mdates
from matplotlib.ticker import MaxNLocator
import textwrap

# ----- CONFIG: analysis window, highlighted period, marker line, styling -----
top_n = 6
ATTR_START = pd.Timestamp("2024-04-03 10:00", tz="UTC")
ATTR_END   = pd.Timestamp("2024-04-11 16:00", tz="UTC")
PEAK_START = pd.Timestamp("2024-04-07 10:00", tz="UTC")
PEAK_END   = pd.Timestamp("2024-04-07 16:00", tz="UTC")
PEAK_PRICE = pd.Timestamp("2024-04-07 13:00", tz="UTC")  # drawn as a vertical "Peak Price" line

plot_style = "subplots"
use_zscore = False

plt.style.use('seaborn-v0_8-whitegrid')
font_family = 'serif'
mpl.rcParams.update({
    'font.family': font_family,
    'font.size': 16,
    'axes.titlesize': 18,
    'axes.labelsize': 15,
    'xtick.labelsize': 14,
    'ytick.labelsize': 14,
    'legend.fontsize': 16,
    'figure.titlesize': 20,
})
# --- EXPLICIT HIGH-CONTRAST COLORS, KEYED BY 0-BASED FEATURE INDEX ---
# (key = the "#" number shown in the plots minus 1)
SPECIAL_INDEX_COLORS = {
    6:   "#1f77b4",  # blue
    14:  "#ff7f0e",  # orange
    16:  "#e377c2",  # pink
    17:  "#bcbd22",  # olive/lime
    29:  "#8c564b",  # brown
    32:  "#9467bd",  # purple
    33:  "#7f7f7f",  # gray
    34:  "#d62728",  # red
    35:  "#2ca02c",  # green
}

def get_feature_color(idx, default_palette=plt.cm.tab10.colors):
    """Return the plot colour for the feature at 0-based position ``idx``.

    Indices listed in SPECIAL_INDEX_COLORS get their fixed colour; every
    other index cycles through ``default_palette``.
    """
    try:
        return SPECIAL_INDEX_COLORS[idx]
    except KeyError:
        return default_palette[idx % len(default_palette)]

# --- RANK FEATURES BY MEAN ATTRIBUTION SHARE AND SLICE THEIR VALUES ---

methods = ["IG", "DL_SHAP", "FA"]
all_baselines = list(attr_summaries[methods[0]])
baselines = [name for name in all_baselines if name != "mean"]

# Rescale every (method, baseline) attribution vector by its own sum; the
# 1e-12 term keeps the division defined when a vector sums to zero.
attr_pct = {m: {} for m in methods}
for m in methods:
    for b in baselines:
        raw = attr_summaries[m][b].astype(float)
        attr_pct[m][b] = raw / (raw.sum() + 1e-12)

# Mean over all method/baseline pairs; keep the top_n features, highest first
all_vectors = [attr_pct[m][b] for m in methods for b in baselines]
mean_imp = np.mean(all_vectors, axis=0)
top_idx = np.argsort(mean_imp)[-top_n:][::-1]
top_names = [exogenous_vars[i] for i in top_idx]

# Values of the selected features inside the attribution window
slice_df = df.loc[ATTR_START:ATTR_END, top_names]
if use_zscore and plot_style == "overlay":
    slice_df = (slice_df - slice_df.mean()) / slice_df.std(ddof=0)

def get_time_locator(start, end):
    """Pick an x-axis tick locator and date formatter for the span start..end.

    Spans up to 48 h get a tick every 2 h, up to 96 h every 6 h, up to 168 h
    every 12 h; anything longer gets one tick per day.

    Returns:
        A ``(locator, formatter)`` tuple of matplotlib.dates objects.
    """
    span_hours = (end - start).total_seconds() / 3600
    # (longest span in hours, tick spacing in hours, tick label format)
    hourly_rules = (
        (48, 2, '%H:%M\n%d-%b'),
        (96, 6, '%d-%b\n%H:%M'),
        (168, 12, '%d-%b\n%H:%M'),
    )
    for max_hours, step, label_format in hourly_rules:
        if span_hours <= max_hours:
            return mdates.HourLocator(interval=step), mdates.DateFormatter(label_format)
    return mdates.DayLocator(interval=1), mdates.DateFormatter('%d-%b')

# Draw the top features either on one shared axis ("overlay") or on a 3x2
# grid with one panel per feature ("subplots"), marking the peak period and
# the PEAK_PRICE timestamp, then save as PDF and PNG.
if plot_style == "overlay":
    fig, ax = plt.subplots(figsize=(10, 6))
    # Labelled artists: both appear in the legend next to the feature lines
    ax.axvspan(PEAK_START, PEAK_END, alpha=0.15, color='gray', label='Peak Period')
    ax.axvline(PEAK_PRICE, color='red', linestyle='--', linewidth=2, label='Peak Price')
    line_styles = ['-', '--', ':', '-.', '-', '--']
    for idx, name, style in zip(top_idx, top_names, line_styles):
        ax.plot(slice_df.index, slice_df[name],
                linewidth=2.5,
                color=get_feature_color(idx),
                linestyle=style,
                label=f"#{idx+1} {name}")
    locator, formatter = get_time_locator(ATTR_START, ATTR_END)
    ax.xaxis.set_major_locator(locator)
    ax.xaxis.set_major_formatter(formatter)
    plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha='right')
    ax.set_ylabel("Z-Score" if use_zscore else "Feature Value", fontweight='normal')
    ax.set_xlabel("Date", fontweight='normal')
    ax.grid(alpha=0.3, linestyle='--')
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(0.75)
    # "upper left" is a valid legend location ("bottom left" is not); anchored
    # at (1.01, 1) it places the legend just outside the right edge of the axes.
    ax.legend(bbox_to_anchor=(1.01, 1), loc="upper left", frameon=True, fancybox=True, shadow=True)
    fig.tight_layout()

elif plot_style == "subplots":
    nrows, ncols = 3, 2
    fig, axes = plt.subplots(nrows, ncols, figsize=(13, 9), sharex=True, constrained_layout=True)
    axes = axes.flatten()
    wrap_width = 41

    locator, formatter = get_time_locator(ATTR_START, ATTR_END)

    for i, (ax, idx, name) in enumerate(zip(axes, top_idx, top_names)):
        color = get_feature_color(idx)
        ax.axvspan(PEAK_START, PEAK_END, alpha=0.3, color='gray')
        ax.axvline(PEAK_PRICE, color='red', linestyle='--', linewidth=2)
        ax.plot(slice_df.index, slice_df[name], linewidth=2.5, color=color)
        # Title: 1-based feature number, name and mean share in percent,
        # wrapped so long names stay inside the panel width
        feature_title = f"#{idx+1}: {name} ({mean_imp[idx]*100:.2f}%)"
        feature_title_wrapped = "\n".join(textwrap.wrap(feature_title, wrap_width))
        ax.set_title(feature_title_wrapped, loc='left', fontweight='normal', fontsize=16)
        ax.set_ylabel("Value", fontweight='normal')
        ax.yaxis.set_major_locator(MaxNLocator(nbins=5))
        ax.grid(alpha=0.3, linestyle='--')
        for spine in ax.spines.values():
            spine.set_visible(True)
            spine.set_linewidth(0.75)
        val_min = slice_df[name].min()
        val_max = slice_df[name].max()
        # Special case for 0-based index 35, i.e. the feature displayed as #36
        # (Transmission_PhysicalFlows_FROM_SE2_FlowValue according to the
        # original note): pin the y-limits to the observed range so zero is
        # not forced into view, and start the shaded band 300 above the minimum.
        if idx == 35:
            ax.set_ylim(val_min, val_max)
            ax.axhspan(val_min  + 300 , val_max, alpha=0.1, color=color)
        elif val_min == val_max:
            # constant series: small symmetric band so the shading is visible
            delta = val_min * 0.01 if val_min != 0 else 0.01
            ax.axhspan(val_min - delta, val_max + delta, alpha=0.1, color=color)
        else:
            ax.axhspan(val_min, val_max, alpha=0.1, color=color)
        if i == 0:
            # In-panel key for the grey span and the red line (first panel only)
            ax.text(0.98, 0.95, 'Peak Period', transform=ax.transAxes,
                    bbox=dict(facecolor='gray', alpha=0.3, edgecolor='none', pad=3),
                    ha='right', va='top', fontsize=14)
            ax.text(0.98, 0.80, 'Peak Price', transform=ax.transAxes,
                    bbox=dict(facecolor='red', alpha=0.3, edgecolor='none', pad=3),
                    ha='right', va='top', fontsize=14)
    # Remove any unused axes (only happens when top_n < nrows*ncols)
    for j in range(top_n, nrows * ncols):
        fig.delaxes(axes[j])
    # Date ticks, rotated labels and axis label on the bottom row
    for ax in axes[-ncols:]:
        ax.xaxis.set_major_locator(locator)
        ax.xaxis.set_major_formatter(formatter)
        plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha='right')
        ax.set_xlabel("Date", fontweight='normal')

else:
    raise ValueError("plot_style must be 'subplots' or 'overlay'")

plt.savefig('feature_timeseries_analysis.pdf', bbox_inches='tight', dpi=300)
plt.savefig('feature_timeseries_analysis.png', bbox_inches='tight', dpi=300)
plt.show()

# Val set

In [None]:
import torch
import numpy as np
import pandas as pd
import joblib
import plotly.graph_objects as go

# ---- PATHS ----
# Checkpoint passed to torch.load and the two scaler files passed to
# joblib.load further down in this cell.
MODEL_PATH = "/new_model_test/best_model_SE2_Down_Volume.pth"
SCALER_X_PATH = "/new_model_test/scaler_X_SE2_Down_Volume.joblib"
SCALER_Y_PATH = "/new_model_test/scaler_y_SE2_Down_Volume.joblib"
# CSV read with pd.read_csv into df_all.
FILE_PATH = "/results/results_merged.csv"

# Region code interpolated into the column names built in this cell.
REGION = "SE2"
# Column(s) the model predicts: a single mFRR down-activated-volume column.
target_vars = [f"Balancing_ActivatedBalancingEnergy_{REGION}_mFRR_NotSpecifiedDownActivatedVolume"]

# -- Model class --
class MaskedLSTM(torch.nn.Module):
    """LSTM regressor whose input is the feature tensor concatenated with a mask.

    ``forward`` joins ``x`` and ``mask`` along the last axis (so the LSTM sees
    ``2 * input_size`` channels), runs the LSTM, takes the output at the final
    time step and maps it through two linear layers to a tensor of shape
    ``(batch, horizon, output_size)``.
    """

    def __init__(self, input_size, hidden_size=32, num_layers=1, dense_size=32, output_size=1, horizon=1, dropout=0.08, dropout_lstm=0.27, bidirectional=True):
        super().__init__()
        nn = torch.nn
        self.horizon = horizon
        self.output_size = output_size

        # nn.LSTM's own dropout is only passed when there is more than one layer
        inter_layer_dropout = dropout_lstm if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size=2 * input_size,   # features + same-width mask
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=inter_layer_dropout,
            bidirectional=bidirectional,
            batch_first=True,
        )

        # A bidirectional LSTM emits forward and backward states side by side
        num_directions = 2 if bidirectional else 1
        self.fc1 = nn.Linear(hidden_size * num_directions, dense_size)
        self.fc2 = nn.Linear(dense_size, horizon * output_size)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(dropout)
        self.drop_lstm = nn.Dropout(dropout_lstm)

    def forward(self, x, mask):
        """Return predictions of shape ``(batch, horizon, output_size)``.

        ``x`` and ``mask`` are concatenated on ``dim=2`` and fed to a
        ``batch_first`` LSTM.
        """
        features_and_mask = torch.cat((x, mask), dim=2)
        sequence_out, _ = self.lstm(features_and_mask)
        # Dropout over the full LSTM output, then keep only the last time step
        last_step = self.drop_lstm(sequence_out)[:, -1, :]
        hidden = self.drop(self.relu(self.fc1(last_step)))
        flat = self.fc2(hidden)
        return flat.view(flat.size(0), self.horizon, self.output_size)

# Use the GPU when CUDA is available, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): joblib.load unpickles the file contents, so these paths must
# only ever point at trusted files.
scaler_X = joblib.load(SCALER_X_PATH)
scaler_y = joblib.load(SCALER_Y_PATH)

# Hyper-parameters used to rebuild the network before the checkpoint is loaded
# (presumably the values the checkpoint was trained with — TODO confirm).
# "learning_rate" and "batch_size" are not read by the code visible in this cell.
best_params = {
    "hidden_size": 32,
    "num_layers": 1,
    "bidirectional": True,
    "dense_size": 32,
    "dropout": 0.08137184695714653,
    "dropout_lstm": 0.2675438586833995,
    "learning_rate": 0.0013843186375837144,
    "batch_size": 128,
    "seq_length": 216,
}
# Number of rows (look-back steps) in each input sequence.
seq_length = best_params["seq_length"]

# --- Load the merged CSV and reduce it to the model's columns ---
df_all = pd.read_csv(FILE_PATH)
price_column_name_to_plot = f"BalancingER_PricesOfActivatedBalancingEnergy_{REGION}_mFRR_NotSpecifiedDownPrice"

# Keep a DateTime-indexed copy of the price column for plotting, because the
# column filter below does not retain it.
if price_column_name_to_plot in df_all.columns:
    price_data_full = (
        df_all[['DateTime', price_column_name_to_plot]]
        .assign(DateTime=lambda frame: pd.to_datetime(frame['DateTime'], utc=True))
        .set_index('DateTime')
    )
    print(f"Stored '{price_column_name_to_plot}' separately for plotting.")
else:
    price_data_full = None
    print(f"Warning: Price column '{price_column_name_to_plot}' not found in df_all. Will not be plotted.")

from filter_features import pick_region_filter
# Balancing price/volume columns to exclude from the inputs; a column listed
# here is still kept if it is one of target_vars.
drop_list = [
    f"Balancing_PricesOfActivatedBalancingEnergy_{REGION}_mFRR_NotSpecifiedUpPrice",
    f"Balancing_PricesOfActivatedBalancingEnergy_{REGION}_mFRR_NotSpecifiedDownPrice",
    f"Balancing_ActivatedBalancingEnergy_{REGION}_mFRR_NotSpecifiedUpActivatedVolume",
    f"Balancing_ActivatedBalancingEnergy_{REGION}_mFRR_NotSpecifiedDownActivatedVolume",
]
BDZ_filter = [
    column
    for column in pick_region_filter(region=REGION, remove_balancing=True)
    if column in target_vars or column not in drop_list
]

# UTC DateTime index on a regular hourly grid
df_all["DateTime"] = pd.to_datetime(df_all["DateTime"], utc=True)
df_all = df_all.set_index("DateTime").asfreq("h")
df = df_all[BDZ_filter].copy().asfreq("h")

def fill_ahead_features(df: pd.DataFrame) -> pd.DataFrame:
    """Forward-fill 0/NaN gaps in *-ahead* columns within their natural period.

    Columns are selected by a case-insensitive substring match on their name:
    ``dayahead`` (calendar days), ``weekahead`` (Monday-to-Sunday weeks),
    ``monthahead`` (calendar months) and ``yearahead`` (calendar years).

    Within each period the last non-zero value is carried forward into the
    following entries that are 0 or NaN. Nothing is carried across a period
    boundary, and gaps with no earlier non-zero value in the same period come
    out as 0. Non-zero entries are never modified. Note that a zero following
    a non-zero value inside the same period is treated as a gap and overwritten.

    Args:
        df: Frame indexed by a ``DatetimeIndex`` (required by ``pd.Grouper``).

    Returns:
        A filled copy; the input frame is not modified.
    """
    df = df.copy()

    period_freq = {
        "dayahead": "D",        # calendar days
        # Anchored weekly bins END on the anchor day, so "W-SUN" yields
        # Monday..Sunday weeks ("W-MON" would group Tuesday..Monday).
        "weekahead": "W-SUN",
        "monthahead": "MS",     # month start
        "yearahead": "YS",      # year start ("AS" is the deprecated alias)
    }

    for period, freq in period_freq.items():
        ahead_cols = [c for c in df.columns if period in c.lower()]

        for col in ahead_cols:
            s = df[col]
            is_gap = (s == 0) | s.isna()

            carried = (
                s.mask(is_gap)                       # gaps (0 or NaN) -> NaN
                 .groupby(pd.Grouper(freq=freq))     # one group per period
                 .ffill()                            # carry forward inside the group
                 .fillna(0)                          # gaps before any value stay 0
            )

            # Only gap positions take the carried value; real entries are kept
            df[col] = np.where(is_gap, carried, s)

    return df

# 


# Switches for the optional derived features (all disabled here)
CREATE_INDICATOR_FEATURES = False
CREATE_CYCLICAL_FEATURES = False
CREATE_LAGGED_FEATURES = False

# --- "was missing" indicator columns ---
new_indicator_features = []
if CREATE_INDICATOR_FEATURES:
    indicator_features_to_create = [col for col in BDZ_filter + target_vars if col in df.columns]
    for col in indicator_features_to_create:
        flag_name = f'{col}_was_missing'
        df[flag_name] = df[col].isnull().astype(int)
        new_indicator_features.append(flag_name)

# --- sine/cosine encodings of hour, day of week and ISO week number ---
new_time_features = []
if CREATE_CYCLICAL_FEATURES:
    # (sine column, cosine column, calendar values, cycle length)
    cyclical_specs = (
        ('hour_sin', 'hour_cos', df.index.hour, 24.0),
        ('dayofweek_sin', 'dayofweek_cos', df.index.dayofweek, 7.0),
        ('weekofyear_sin', 'weekofyear_cos', df.index.isocalendar().week, 52.0),
    )
    for sin_name, cos_name, calendar_values, cycle_length in cyclical_specs:
        df[sin_name] = np.sin(2 * np.pi * calendar_values / cycle_length)
        df[cos_name] = np.cos(2 * np.pi * calendar_values / cycle_length)
        new_time_features.extend([sin_name, cos_name])

# --- target lags of 1, 24 and 168 rows ---
new_lagged_features = []
if CREATE_LAGGED_FEATURES:
    for target_var in target_vars:
        for lag in (1, 24, 168):
            lag_name = f'{target_var}_lag{lag}'
            df[lag_name] = df[target_var].shift(lag)
            new_lagged_features.append(lag_name)

# Drop columns that contain no data at all
df.dropna(axis=1, how="all", inplace=True)

# Model inputs: filtered columns minus the targets, plus any enabled derived
# features, restricted to columns still present in df, de-duplicated and sorted
candidate_vars = [c for c in BDZ_filter if c not in target_vars]
if CREATE_CYCLICAL_FEATURES:
    candidate_vars.extend(new_time_features)
if CREATE_LAGGED_FEATURES:
    candidate_vars.extend(new_lagged_features)
if CREATE_INDICATOR_FEATURES:
    candidate_vars.extend(new_indicator_features)
exogenous_vars = sorted({var for var in candidate_vars if var in df.columns})

# --- Rebuild the network from best_params and load the saved weights ---
architecture_keys = (
    "hidden_size",
    "num_layers",
    "dense_size",
    "dropout",
    "dropout_lstm",
    "bidirectional",
)
model = MaskedLSTM(
    input_size=len(exogenous_vars),
    output_size=len(target_vars),
    **{key: best_params[key] for key in architecture_keys},
)
model = model.to(DEVICE)
# map_location places the checkpoint tensors on DEVICE while loading
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(state_dict)
model.eval()

# --- Window selection: the requested range plus seq_length rows of history ---
# Raises ValueError when desired_start is not in the index or when fewer than
# seq_length rows precede it.

# desired_start = pd.Timestamp("2023-09-05", tz="UTC")
# desired_end = pd.Timestamp("2023-09-08", tz="UTC")

desired_start = pd.Timestamp("2025-02-04", tz="UTC")
desired_end = pd.Timestamp("2025-02-20", tz="UTC")


all_dates = df.index
try:
    first_pred_idx = all_dates.get_loc(desired_start)
except KeyError as err:
    # chain the original lookup failure so the traceback shows the cause
    raise ValueError(f"Desired start {desired_start} not in DataFrame index!") from err

if first_pred_idx < seq_length:
    raise ValueError(f"Not enough data before {desired_start} to build input sequence with seq_length={seq_length}.")

# Start seq_length rows earlier so the first prediction has a full input sequence
window_start = all_dates[first_pred_idx - seq_length]
window_end = desired_end

df_pred_window = df.loc[window_start:window_end].copy()

# -- Scale features --
X_scaled = scaler_X.transform(df_pred_window[exogenous_vars])
y_true = df_pred_window[target_vars].values

# -- Build sequences --
def create_sequences(X, lookback, horizon=1):
    """Cut a 2-D array into overlapping look-back windows.

    Window ``i`` holds rows ``i .. i + lookback - 1`` of ``X``. The number of
    windows is ``len(X) - lookback - horizon + 1``, i.e. ``horizon`` rows are
    left after the last window.

    Args:
        X: Array-like of shape ``(n_rows, n_features)``.
        lookback: Number of rows per window.
        horizon: Number of rows reserved after each window.

    Returns:
        ``(X_seq, mask_seq)``, both float32 with shape
        ``(num_samples, lookback, n_features)``. ``mask_seq`` is all ones.

    Raises:
        ValueError: If ``X`` is not 2-D or has too few rows for the requested
            ``lookback`` and ``horizon``.
    """
    X = np.asarray(X)
    if X.ndim != 2:
        raise ValueError(f"X must be 2-D (rows, features), got shape {X.shape}.")

    num_samples = X.shape[0] - lookback - horizon + 1
    if num_samples < 0:
        raise ValueError(
            f"Need at least lookback + horizon - 1 = {lookback + horizon - 1} rows "
            f"to build sequences, got {X.shape[0]}."
        )

    # Row indices of every window: entry [i, t] is i + t
    window_rows = np.arange(num_samples)[:, None] + np.arange(lookback)[None, :]
    X_seq = X[window_rows].astype(np.float32)
    mask_seq = np.ones(X_seq.shape, dtype=np.float32)
    return X_seq, mask_seq

X_seq, mask_seq = create_sequences(X_scaled, seq_length, 1)
# Missing inputs: flag them with 0 in the mask and replace the NaN by 0
nan_pos = np.isnan(X_seq)
mask_seq[nan_pos] = 0.0
X_seq[nan_pos] = 0.0

# Sequence i is paired with the row that directly follows its look-back window
target_rows = slice(seq_length, seq_length + len(X_seq))
dates_seq = df_pred_window.index[target_rows]
y_true_aligned = y_true[target_rows]

# -- Boolean selector for predictions inside the requested window --
mask = (dates_seq >= desired_start) & (dates_seq <= desired_end)

with torch.no_grad():
    X_tensor = torch.tensor(X_seq, dtype=torch.float32).to(DEVICE)
    mask_tensor = torch.tensor(mask_seq, dtype=torch.float32).to(DEVICE)
    raw_output = model(X_tensor, mask_tensor)
    y_pred_scaled = raw_output.cpu().numpy()[:, 0, :]  # (n, target_dim)

# Back to original units, then keep only the requested window everywhere
y_pred = scaler_y.inverse_transform(y_pred_scaled)[mask]
dates_seq = dates_seq[mask]
y_true_aligned = y_true_aligned[mask]

# --- Add price data ---
# Price values aligned to the prediction timestamps; stays None when the price
# column was not stored or cannot be aligned, in which case the plot omits it.
price_data_to_plot = None
if price_data_full is not None:
    try:
        price_data_to_plot = price_data_full.loc[dates_seq, price_column_name_to_plot].values.astype(float)
    except (KeyError, ValueError, TypeError) as e:
        # KeyError: timestamps/column missing; ValueError/TypeError: values
        # not convertible to float. Anything else is a real bug and propagates.
        print(f"Warning: Could not align price data with dates_seq: {e}")
        price_data_to_plot = None

import matplotlib.pyplot as plt
import matplotlib.dates as mdates

# --- Colors ---
true_color = "#636efa"    # Plotly default blue
pred_color = "#ef553b"    # Plotly default red
price_color = "orange"

# --- Matplotlib plot: volumes on the left axis, optional price on the right ---
fig, ax1 = plt.subplots(figsize=(12, 7))

# True values
ax1.plot(dates_seq, y_true_aligned[:, 0], label='True', color=true_color, marker='o', markersize=5, linewidth=2)

# Predicted values
ax1.plot(dates_seq, y_pred[:, 0], label='Predicted', color=pred_color, marker='x', markersize=5, linewidth=2)

ax1.set_xlabel("Date", fontsize=13)
ax1.set_ylabel("Volume (MW)", fontsize=13, color=true_color)
ax1.tick_params(axis='y', labelcolor=true_color)
ax1.xaxis.set_major_formatter(mdates.DateFormatter('%d %b %Y\n%H:%M'))
plt.xticks(rotation=45)
plt.grid(axis='y', alpha=0.3)

# Add price on secondary y-axis
if price_data_to_plot is not None:
    ax2 = ax1.twinx()
    ax2.plot(dates_seq, price_data_to_plot, label='mFRR Down Price', color=price_color, linewidth=2)
    ax2.set_ylabel("mFRR Down Price (€/MWh)", fontsize=13, color=price_color)
    ax2.tick_params(axis='y', labelcolor=price_color)
    ax2.grid(False)
else:
    ax2 = None

# Title
fig.suptitle(
    "True vs Predicted for mFRR Down Volume",
    fontsize=15, fontweight='normal', y=1.0
)

# Combined legend: entries of the volume axis followed by the price axis (if any)
lines, labels = ax1.get_legend_handles_labels()
if ax2 is not None:
    price_lines, price_labels = ax2.get_legend_handles_labels()
    lines = lines + price_lines
    labels = labels + price_labels
plt.legend(lines, labels, loc='lower center', bbox_to_anchor=(0.5, -0.4), ncol=3, fontsize=12, frameon=False)

# --- Analysis period annotation under the title ---
period_text = (
    f"Validation set: {desired_start.strftime('%Y-%m-%d %H:%M')} to "
    f"{desired_end.strftime('%Y-%m-%d %H:%M')} UTC"
)
fig.text(0.5, 0.94, period_text, ha='center', fontstyle='italic', fontsize=11)

plt.tight_layout()
plt.show()



In [69]:
import pandas as pd

# Load the results CSV and remove every space character from the column names.
df_read = pd.read_csv("/processed_data/results/resultsset_v2.csv")
df_read.columns = df_read.columns.str.replace(' ', '', regex=False)