In [1]:
!pip install pyunpack patool
!pip install wget
!pip install scikit-learn
!pip install torchvision
!pip install tensorboard



In [41]:
import pandas as pd
import numpy as np
import pyunpack
import math
import json

from data.data_download import Config, download_electricity
from data_formatters.base import DataTypes,  InputTypes
from data_formatters.electricity import ElectricityFormatter
from data.custom_dataset import TFTDataset

from models import GatedLinearUnit
from models import GateAddNormNetwork
from models import GatedResidualNetwork 
from models import ScaledDotProductAttention
from models import InterpretableMultiHeadAttention
from models import VariableSelectionNetwork

from quantile_loss import QuantileLossCalculator
from quantile_loss import NormalizedQuantileLossCalculator

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch import nn

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor
from argparse import ArgumentParser
from pytorch_lightning.loggers import TensorBoardLogger

import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau

import pandas as pd
from dateutil import parser
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.callbacks import TQDMProgressBar
from pathlib import Path



# Choose Experiment E0/E1/E2

In [42]:
EXP = "E3"  # "E0" | "E1" | "E2"| "E3"

if EXP == "E0":          
    use_tlf = True
    tlf_replace_with_seasonal = False
    err_target= False
elif EXP == "E1":       
    use_tlf = False
    tlf_replace_with_seasonal = False
    err_target= False
elif EXP == "E2":        
    use_tlf = True
    tlf_replace_with_seasonal = True
    err_target= False
elif EXP == "E3":
    use_tlf = True
    tlf_replace_with_seasonal = False
    err_target= True
else:
    raise ValueError("Unknown EXP")

from data_formatters.electricity import ElectricityFormatter
data_formatter = ElectricityFormatter(
    use_tlf=use_tlf,
    tlf_replace_with_seasonal=tlf_replace_with_seasonal,
    err_target=err_target
)


# Load Data and Data Pre-processing

In [43]:

# === 1. Load and clean weather dataset ===
weather_df = pd.read_csv("data/archive/weather_features.csv")
weather_df['dt_iso'] = weather_df['dt_iso'].apply(lambda x: parser.parse(x).replace(tzinfo=None))

weather_df['city_name'] = weather_df['city_name'].str.strip()
weather_df['city_name'] = weather_df['city_name'].replace({
    'Sevilla': 'Seville',
    'València': 'Valencia'
})

# === 2. Define population weights (Bilbao replaces Zaragoza)
weights = {
    "Madrid": 0.4804,
    "Barcelona": 0.2439,
    "Valencia": 0.1192,
    "Seville": 0.1044,
    "Bilbao": 0.0522
}
weather_df = weather_df[weather_df['city_name'].isin(weights.keys())]

# === 3. Weighted numeric aggregation
columns_to_weight = [
    'temp', 'temp_min', 'temp_max', 'pressure', 'humidity',
    'wind_speed', 'wind_deg', 'rain_1h', 'rain_3h', 'snow_3h', 'clouds_all'
]

for col in columns_to_weight:
    weather_df[col + '_weighted'] = weather_df.apply(
        lambda row: row[col] * weights[row['city_name']], axis=1
    )

weighted_numeric = weather_df.groupby('dt_iso').agg({
    col + '_weighted': 'sum' for col in columns_to_weight
}).reset_index()

# === 4. Simplify weather_description
top_desc = weather_df['weather_description'].value_counts()
top_desc = top_desc[top_desc > 100].index
weather_df['weather_description_simplified'] = weather_df['weather_description'].where(
    weather_df['weather_description'].isin(top_desc), other='other'
)

# === 5. Aggregate categorical and ID fields
def mode(series):
    return series.mode().iloc[0] if not series.mode().empty else 'unknown'

categorical_agg = weather_df.groupby('dt_iso').agg({
    'weather_main': mode,
    'weather_description_simplified': mode,
}).reset_index()

# === 6. Merge all weather features
weather_final = pd.merge(weighted_numeric, categorical_agg, on='dt_iso')
weather_final.rename(columns={'dt_iso': 'time'}, inplace=True)

# === 7. Load and clean energy dataset
energy_df = pd.read_csv("data/archive/energy_dataset.csv")
energy_df['time'] = energy_df['time'].apply(lambda x: parser.parse(x).replace(tzinfo=None))

# === 8. Merge energy + weather
merged_df = pd.merge(energy_df, weather_final, on='time', how='inner')

# === 9. Save result to final path
output_path = "data/weighted_weather_full.csv"
merged_df.to_csv(output_path, index=False)

print("✅ Saved:", output_path)


✅ Saved: data/weighted_weather_full.csv


In [None]:

# === 1. Load original data ===
df = pd.read_csv("data/weighted_weather_full.csv")

# === 2. Drop columns that are completely empty ===
df = df.dropna(axis=1, how='all')

# === 3. Drop rows where 'total load actual' (your target) is missing ===
df = df[df['total load actual'].notna()]

# === 4. Identify numeric columns and fill missing values ===
numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
if 'days_from_start' in numeric_cols:
    numeric_cols.remove('days_from_start')  # leave it alone

# Forward fill first, then fill remaining NaNs with column medians
df[numeric_cols] = df[numeric_cols].fillna(method='ffill')
df[numeric_cols] = df[numeric_cols].fillna(df[numeric_cols].median())

# === 5. Handle categorical variables: fill with 'unknown' ===
categorical_cols = ['weather_main', 'weather_description_simplified']
for col in categorical_cols:
    if col in df.columns:
        df[col] = df[col].fillna('unknown')

# === 6. Save cleaned dataset ===
df.to_csv("data/weighted_weather_full_clean.csv", index=False)

print("✅ Data cleaned and saved as: weighted_weather_full_clean.csv")


In [None]:

# === 1. Load the merged energy + weather dataset ===
df = pd.read_csv("data/weighted_weather_full_clean.csv")

# === 2. Convert time column to datetime ===
df['time'] = pd.to_datetime(df['time'])

# === 3. Create 'days_from_start' (integer offset from first date) ===
df['days_from_start'] = (df['time'] - df['time'].min()).dt.days

# === 4. Extract time features ===
df['hour'] = df['time'].dt.hour
df['day_of_week'] = df['time'].dt.dayofweek  # Monday=0, Sunday=6

# === 5. Add 'id' column (static identifier) ===
df['identifier'] = 'Spain'
# === 6. Seasonal baseline (for tlf_replace_with_seasonal) ===
df['_time_order'] = df['days_from_start'] * 24 + df['hour'].astype(int)
df = df.sort_values(['identifier', '_time_order'])
if 'split' not in df.columns:
    def _make_splits(g, train_frac=0.8, val_frac=0.1):
        n = len(g)
        i1 = int(n * train_frac)
        i2 = int(n * (train_frac + val_frac))
        lab = np.array(['train'] * n)
        lab[i1:i2] = 'val'
        lab[i2:]   = 'test'
        g['split'] = lab
        return g
    df = df.groupby('identifier', group_keys=False).apply(_make_splits)

#  Hour-of-week key: same hour & weekday share seasonal pattern
df['hour_of_week'] = df['day_of_week'].astype(int) * 24 + df['hour'].astype(int)
doy = (df['days_from_start'] % 365).astype(int)          # approx day-of-year
df['month_bin'] = (doy // 30).clip(0, 11) + 1            # 1..12 approx

# Seasonal baseline: use only past values (shift(1) then rolling)
K = 8  # past 8 weeks; try 4/12 for sensitivity analysis
df['seasonal_baseline'] = (
    df.groupby(['identifier', 'hour_of_week'])['total load actual']
      .transform(lambda s: s.shift(1).rolling(K, min_periods=1).median()))
# Fallback #1: expanding median on same hour_of_week (past-only)
mask = df['seasonal_baseline'].isna()
if mask.any():
    exp_hw = (
        df.groupby(['identifier', 'hour_of_week'])['total load actual']
          .transform(lambda s: s.shift(1).expanding().median())
    )
    df.loc[mask, 'seasonal_baseline'] = exp_hw[mask]

# Fallback #2: expanding median on coarser month_bin × hour (past-only)
mask = df['seasonal_baseline'].isna()
if mask.any():
    exp_mh = (
        df.groupby(['identifier', 'month_bin', 'hour'])['total load actual']
          .transform(lambda s: s.shift(1).expanding().median())
    )
    df.loc[mask, 'seasonal_baseline'] = exp_mh[mask]

# Fallback #3: expanding median per identifier (past-only)
mask = df['seasonal_baseline'].isna()
if mask.any():
    exp_id = (
        df.groupby('identifier')['total load actual']
          .transform(lambda s: s.shift(1).expanding().median())
    )
    df.loc[mask, 'seasonal_baseline'] = exp_id[mask]
df.drop(columns=['_time_order'], inplace=True)
still = df['seasonal_baseline'].isna()
if still.any():
    id_med = (df.loc[df['split']=='train']
                .groupby('identifier')['total load actual']
                .median())
    df.loc[still,'seasonal_baseline'] = df.loc[still,'identifier'].map(id_med)
# === 7. Rename target column ===
df.rename(columns={'total load actual': 'target'}, inplace=True)

# Residual/error target: actual - provided prediction
df['err_target'] = df['target'] - df['total load forecast']


# === 8. Save the cleaned dataset ===
df.to_csv("data/electricty.csv", index=False)

print("✅ Done: 'hour', 'day_of_week', 'days_from_start', 'id' added, target renamed, file saved.")



# Creating Datasets

In [20]:
df = pd.read_csv("data/electricty.csv")
print("📋 All column names:")
print(df.columns.tolist())


📋 All column names:
['time', 'generation biomass', 'generation fossil brown coal/lignite', 'generation fossil coal-derived gas', 'generation fossil gas', 'generation fossil hard coal', 'generation fossil oil', 'generation fossil oil shale', 'generation fossil peat', 'generation geothermal', 'generation hydro pumped storage consumption', 'generation hydro run-of-river and poundage', 'generation hydro water reservoir', 'generation marine', 'generation nuclear', 'generation other', 'generation other renewable', 'generation solar', 'generation waste', 'generation wind offshore', 'generation wind onshore', 'forecast solar day ahead', 'forecast wind onshore day ahead', 'total load forecast', 'target', 'price day ahead', 'price actual', 'temp_weighted', 'temp_min_weighted', 'temp_max_weighted', 'pressure_weighted', 'humidity_weighted', 'wind_speed_weighted', 'wind_deg_weighted', 'rain_1h_weighted', 'rain_3h_weighted', 'snow_3h_weighted', 'clouds_all_weighted', 'weather_main', 'weather_descrip

In [21]:
electricity = pd.read_csv('data/electricty.csv', index_col = 0)
train, valid, test = data_formatter.split_data(electricity)

Formatting train-valid-test splits.
Setting scalers with training data...


In [22]:
# ==== helper: keep only model columns based on column_definition ====
def filter_to_model_columns(df, coldef):
    """Whitelist columns that appear in the model's column definition."""
    keep = [name for (name, _, _) in coldef if name in df.columns]
    return df.loc[:, keep].copy()

# get current column definition from formatter
coldef = (data_formatter.get_column_definition()
          if hasattr(data_formatter, "get_column_definition")
          else data_formatter._column_definition)
# ==== E2 only: ensure 'seasonal_baseline' exists in all splits ====
if tlf_replace_with_seasonal:
    # If the column is already in the CSV/splits, just sanity-fix dtype & NaNs.
    has_all = all("seasonal_baseline" in df_.columns for df_ in (train, valid, test))

    if has_all:
        for name, d in (("train", train), ("valid", valid), ("test", test)):
            d["seasonal_baseline"] = pd.to_numeric(d["seasonal_baseline"], errors="coerce")
            if d["seasonal_baseline"].isna().any():
                d["seasonal_baseline"] = d["seasonal_baseline"].fillna(train["target"].median())
        # nothing else to do
    else:
        # build from TRAIN ONLY to avoid leakage, using available keys
        keys = [k for k in ("day_of_week", "hour") if k in train.columns]
        if len(keys) >= 1:
            base = (
                train.groupby(keys)["target"]
                     .median()
                     .reset_index()
                     .rename(columns={"target": "seasonal_baseline"})
            )

            def _merge_and_fix(df):
                # drop any existing same-named col to prevent _x/_y suffixes
                if "seasonal_baseline" in df.columns:
                    df = df.drop(columns=["seasonal_baseline"])
                merged = df.merge(base, on=keys, how="left", copy=False)
                if merged["seasonal_baseline"].isna().any():
                    merged["seasonal_baseline"] = merged["seasonal_baseline"].fillna(train["target"].median())
                return merged

            train = _merge_and_fix(train)
            valid = _merge_and_fix(valid)
            test  = _merge_and_fix(test)
        else:
            # last resort: global median from TRAIN
            sb = float(train["target"].median())
            for d in (train, valid, test):
                d["seasonal_baseline"] = sb
# else: nothing to do for E1

# ==== whitelist columns (works for both E1 and E2) ====
train = filter_to_model_columns(train, coldef)
valid = filter_to_model_columns(valid, coldef)
test  = filter_to_model_columns(test,  coldef)

# ==== quick sanity checks ====
known_cols = [n for (n, dt, role) in coldef
              if role == InputTypes.KNOWN_INPUT and dt == DataTypes.REAL_VALUED]
print("KNOWN(real) from formatter:", known_cols)
print("train columns:", list(train.columns)[:12], "... (#", len(train.columns), ")")

if EXP == "E1":
    assert "total load forecast" not in known_cols
    assert "total load forecast" not in train.columns
if EXP == "E2":
    assert "seasonal_baseline" in known_cols, "E2: seasonal_baseline must be KNOWN input."
    assert "seasonal_baseline" in train.columns, "E2: seasonal_baseline missing in dataframes."

            
             

KNOWN(real) from formatter: ['hour', 'day_of_week', 'forecast solar day ahead', 'forecast wind onshore day ahead', 'price day ahead', 'temp_weighted', 'temp_min_weighted', 'temp_max_weighted', 'pressure_weighted', 'humidity_weighted', 'wind_speed_weighted', 'wind_deg_weighted', 'rain_1h_weighted', 'rain_3h_weighted', 'snow_3h_weighted', 'clouds_all_weighted', 'total load forecast']
train columns: ['identifier', 'days_from_start', 'err_target', 'hour', 'day_of_week', 'forecast solar day ahead', 'forecast wind onshore day ahead', 'price day ahead', 'temp_weighted', 'temp_min_weighted', 'temp_max_weighted', 'pressure_weighted'] ... (# 43 )


In [23]:
train.shape, valid.shape, test.shape

((31523, 43), (744, 43), (3097, 43))

In [24]:
train.days_from_start.value_counts().to_frame().reset_index().sort_values(by='days_from_start')

Unnamed: 0,days_from_start,count
11,0,24
654,1,24
876,2,24
875,3,24
1313,4,18
...,...,...
437,1310,24
436,1311,24
435,1312,24
434,1313,24


In [25]:
valid.days_from_start.value_counts().to_frame().reset_index().sort_values(by='days_from_start')


Unnamed: 0,days_from_start,count
0,1308,24
16,1309,24
29,1310,24
28,1311,24
27,1312,24
26,1313,24
25,1314,24
24,1315,24
23,1316,24
22,1317,24


In [26]:
test.days_from_start.value_counts().to_frame().reset_index().sort_values(by=['days_from_start'])

Unnamed: 0,days_from_start,count
64,1332,24
65,1333,24
95,1334,24
94,1335,24
93,1336,24
...,...,...
38,1456,24
37,1457,24
36,1458,24
35,1459,24


## Reviewing Test dataset error

In [27]:
test = test.reset_index(drop=True)
test

Unnamed: 0,identifier,days_from_start,err_target,hour,day_of_week,forecast solar day ahead,forecast wind onshore day ahead,price day ahead,temp_weighted,temp_min_weighted,...,generation marine,generation nuclear,generation other,generation other renewable,generation solar,generation waste,generation wind offshore,generation wind onshore,weather_main,weather_description_simplified
0,Spain,1332,1.012600,-1.660795,0.999770,-0.657755,0.093105,1.082419,0.116072,0.140691,...,0.0,0.834974,-0.173525,1.547250,-0.606260,1.078413,0.0,0.057872,0,19
1,Spain,1332,0.980120,-1.516366,0.999770,-0.675507,0.042002,0.889354,0.085624,0.123137,...,0.0,0.838553,-0.173525,1.547250,-0.671879,1.058808,0.0,-0.032227,0,19
2,Spain,1332,0.804725,-1.371937,0.999770,-0.706277,-0.049544,0.801534,0.056716,0.098096,...,0.0,0.843325,-0.173525,1.404449,-0.744001,1.039204,0.0,-0.137861,0,19
3,Spain,1332,0.527558,-1.227508,0.999770,-0.725805,-0.143599,0.662137,0.025070,0.062054,...,0.0,0.848096,-0.221066,1.333048,-0.802526,1.019599,0.0,-0.214601,0,19
4,Spain,1332,0.438778,-1.083079,0.999770,-0.740598,-0.221037,0.627288,-0.010182,0.013126,...,0.0,0.852868,-0.221066,1.475849,-0.817305,0.980390,0.0,-0.231378,0,19
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3092,Spain,1460,0.113972,1.083358,-1.500369,-0.809831,-0.696954,1.418365,-0.359285,-0.363656,...,0.0,-0.232667,0.111724,0.761843,-0.813166,0.215814,0.0,-0.733450,0,19
3093,Spain,1460,-0.386228,1.227787,-1.500369,-0.836459,-0.665602,1.387001,-0.446290,-0.459454,...,0.0,-0.231474,0.064182,0.761843,-0.843907,0.274627,0.0,-0.679079,0,19
3094,Spain,1460,0.404132,1.372216,-1.500369,-0.845335,-0.649613,1.281059,-0.494750,-0.473481,...,0.0,-0.229088,0.016641,0.690443,-0.845089,0.392255,0.0,-0.612281,0,19
3095,Spain,1460,0.800394,1.516645,-1.500369,-0.849477,-0.690683,1.075449,-0.538937,-0.535750,...,0.0,-0.230281,0.016641,0.619042,-0.845089,0.411859,0.0,-0.586494,0,19


In [28]:
print(test.columns.tolist())


['identifier', 'days_from_start', 'err_target', 'hour', 'day_of_week', 'forecast solar day ahead', 'forecast wind onshore day ahead', 'price day ahead', 'temp_weighted', 'temp_min_weighted', 'temp_max_weighted', 'pressure_weighted', 'humidity_weighted', 'wind_speed_weighted', 'wind_deg_weighted', 'rain_1h_weighted', 'rain_3h_weighted', 'snow_3h_weighted', 'clouds_all_weighted', 'total load forecast', 'price actual', 'generation biomass', 'generation fossil brown coal/lignite', 'generation fossil coal-derived gas', 'generation fossil gas', 'generation fossil hard coal', 'generation fossil oil', 'generation fossil oil shale', 'generation fossil peat', 'generation geothermal', 'generation hydro pumped storage consumption', 'generation hydro run-of-river and poundage', 'generation hydro water reservoir', 'generation marine', 'generation nuclear', 'generation other', 'generation other renewable', 'generation solar', 'generation waste', 'generation wind offshore', 'generation wind onshore', 

## Get parameters

In [29]:
params = {}
params.update(data_formatter.get_experiment_params())
params.update(data_formatter.get_default_model_params())
parser = ArgumentParser(add_help=False)



train_dataset = TFTDataset(train,  formatter=data_formatter)
valid_dataset = TFTDataset(valid, formatter=data_formatter)
test_dataset  = TFTDataset(test,  formatter=data_formatter)

for k in params:
    if type(params[k]) in [int, float]:
        #if k == 'minibatch_size':
        #    parser.add_argument('--{}'.format(k), type=type(params[k]), default = 256)
        #else:
        parser.add_argument('--{}'.format(k), type=type(params[k]), default = params[k])
    else:
        parser.add_argument('--{}'.format(k), type=str, default = str(params[k]))
hparams = parser.parse_known_args()[0]





In [30]:
len(train_dataset), len(valid_dataset), len(test_dataset)

(31331, 552, 2905)

In [31]:
len(train_dataset), len(valid_dataset), len(test_dataset)

(31331, 552, 2905)

# Temporal Fusion Transformer

In [32]:
class TemporalFusionTransformer(pl.LightningModule):
    def __init__(self, hparams):
        super(TemporalFusionTransformer, self).__init__()
        
        
        self.save_hyperparameters(vars(hparams))
        self.hparams_ns = hparams
    
        self.name = self.__class__.__name__

        # Data parameters
        self.time_steps = int(hparams.total_time_steps)#int(params['total_time_steps'])
        self.input_size = int(hparams.input_size)#int(params['input_size'])
        self.output_size = int(hparams.output_size)#int(params['output_size'])
        self.category_counts = json.loads(str(hparams.category_counts))#json.loads(str(params['category_counts']))
        self.num_categorical_variables = len(self.category_counts)
        self.num_regular_variables = self.input_size - self.num_categorical_variables
        self.n_multiprocessing_workers = int(hparams.multiprocessing_workers) #int(params['multiprocessing_workers'])

       
        # --- Known future inputs (build progressively to honor switches) ---
        _known_block = [
    ('hour', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT),
    ('day_of_week', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT),
    ('forecast solar day ahead', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT),
    ('forecast wind onshore day ahead', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT),
    ('price day ahead', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT),
    ('temp_weighted', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT),
    ('temp_min_weighted', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT),
    ('temp_max_weighted', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT),
    ('pressure_weighted', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT),
    ('humidity_weighted', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT),
    ('wind_speed_weighted', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT),
    ('wind_deg_weighted', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT),
    ('rain_1h_weighted', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT),
    ('rain_3h_weighted', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT),
    ('snow_3h_weighted', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT),
    ('clouds_all_weighted', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT),
      ]
        if getattr(self.hparams, "tlf_replace_with_seasonal", False):
            _known_block.append(('seasonal_baseline', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT))
        elif getattr(self.hparams, "use_tlf", True):
            _known_block.append(('total load forecast', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT))

        core = [
            ('identifier', DataTypes.CATEGORICAL, InputTypes.ID),
            ('days_from_start', DataTypes.REAL_VALUED, InputTypes.TIME),
        ]
        if getattr(self.hparams, "err_target", False):
            core.append(('err_target', DataTypes.REAL_VALUED, InputTypes.TARGET))
        else:
            core.append(('target', DataTypes.REAL_VALUED, InputTypes.TARGET))
                        
        
        self.column_definition =core+_known_block+[

    # --- Observed real inputs ---
    ('price actual', DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
    ('generation biomass', DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
    ('generation fossil brown coal/lignite', DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
    ('generation fossil coal-derived gas', DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
    ('generation fossil gas', DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
    ('generation fossil hard coal', DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
    ('generation fossil oil', DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
    ('generation fossil oil shale', DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
    ('generation fossil peat', DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
    ('generation geothermal', DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
    ('generation hydro pumped storage consumption', DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
    ('generation hydro run-of-river and poundage', DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
    ('generation hydro water reservoir', DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
    ('generation marine', DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
    ('generation nuclear', DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
    ('generation other', DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
    ('generation other renewable', DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
    ('generation solar', DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
    ('generation waste', DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
    ('generation wind offshore', DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),
    ('generation wind onshore', DataTypes.REAL_VALUED, InputTypes.OBSERVED_INPUT),

    # --- Observed categorical ---
    ('weather_main', DataTypes.CATEGORICAL, InputTypes.OBSERVED_INPUT),
    ('weather_description_simplified', DataTypes.CATEGORICAL, InputTypes.OBSERVED_INPUT),]

        use_tlf = bool(getattr(self.hparams, "use_tlf", True))
        replace = bool(getattr(self.hparams, "tlf_replace_with_seasonal", False))
         # Relevant indices for TFT
        #Collect global indices of real-valued and categorical columns
        _real_cols = [i for i, (_, dt, _) in enumerate(self.column_definition) if dt == DataTypes.REAL_VALUED]
        _cat_cols  = [i for i, (_, dt, _) in enumerate(self.column_definition) if dt == DataTypes.CATEGORICAL]

        #Collect global indices by role (from the full column_definition)
        _obs_glob        = [i for i, (_, dt, rl) in enumerate(self.column_definition)
                    if rl == InputTypes.OBSERVED_INPUT and dt == DataTypes.REAL_VALUED]
        _known_reg_glob  = [i for i, (_, dt, rl) in enumerate(self.column_definition)
                    if rl == InputTypes.KNOWN_INPUT    and dt == DataTypes.REAL_VALUED]
        _known_cat_glob  = [i for i, (_, dt, rl) in enumerate(self.column_definition)
                    if rl == InputTypes.KNOWN_INPUT    and dt == DataTypes.CATEGORICAL]
        _static_reg_glob = [i for i, (_, dt, rl) in enumerate(self.column_definition)
                    if rl == InputTypes.STATIC_INPUT   and dt == DataTypes.REAL_VALUED]  # may be empty

         # Map global indices → local indices within the model’s stacked tensors
        #    (TFT expects indices relative to the real/categorical stacks, not global column_definition indices)
        # --- DEBUG: peek column_definition & parsed roles/dtypes ---
    
        self._input_obs_loc               = [ _real_cols.index(i) for i in _obs_glob ]
        self._known_regular_input_idx     = [ _real_cols.index(i) for i in _known_reg_glob ]
        self._known_categorical_input_idx = [ _cat_cols.index(i)  for i in _known_cat_glob ]
        self._static_input_loc            = [ _real_cols.index(i) for i in _static_reg_glob ] # [] if no static real inputs
        setattr(self.hparams, "input_obs_loc",            list(self._input_obs_loc))
        setattr(self.hparams, "known_regular_inputs",     list(self._known_regular_input_idx))
        setattr(self.hparams, "known_categorical_inputs", list(self._known_categorical_input_idx))
        setattr(self.hparams, "static_input_loc",         list(self._static_input_loc))
      
       
        if not use_tlf and not replace:
    # E1: drop TLF
            self.column_definition = [
        t for t in self.column_definition
        if not (t[0] == "total load forecast" and t[2] == InputTypes.KNOWN_INPUT)
           ]
        elif replace:
    # E2: replace TLF with seasonal_baselin
            new = []
            for name, dt, role in self.column_definition:
                if name == "total load forecast" and role == InputTypes.KNOWN_INPUT:
                    new.append(("seasonal_baseline", dt, role))
                else:
                    new.append((name, dt, role))
            self.column_definition = new

        self._refresh_input_names_and_counts()
        self.hparams.num_non_static_historical_inputs = self.num_non_static_historical_inputs
        self.hparams.num_non_static_future_inputs     = self.num_non_static_future_inputs

        # Network params
        self.quantiles = [0.1, 0.5, 0.9]
#         self.use_cudnn = use_cudnn  # Whether to use GPU optimised LSTM
        self.hidden_layer_size = int(hparams.hidden_layer_size)#int(params['hidden_layer_size'])
        self.dropout_rate = float(hparams.dropout_rate)#float(params['dropout_rate'])
        self.max_gradient_norm = float(hparams.max_gradient_norm)#float(params['max_gradient_norm'])
        self.learning_rate = float(hparams.learning_rate)#float(params['learning_rate'])
        self.minibatch_size = int(hparams.minibatch_size)#int(params['minibatch_size'])
        self.num_epochs = int(hparams.num_epochs)#int(params['num_epochs'])
        self.early_stopping_patience = int(hparams.early_stopping_patience)#int(params['early_stopping_patience'])
        self.weight_decay=int(hparams.weight_decay)

        self.num_encoder_steps = int(hparams.num_encoder_steps)#int(params['num_encoder_steps'])
        self.num_stacks = int(hparams.stack_size)#int(params['stack_size'])
        self.num_heads = int(hparams.num_heads)#int(params['num_heads'])
        self.hist_var_proj = torch.nn.Parameter(
    torch.empty(self.num_non_static_historical_inputs, self.hidden_layer_size)
    )
        self.fut_var_proj  = torch.nn.Parameter(
    torch.empty(self.num_non_static_future_inputs,     self.hidden_layer_size)
     )
        torch.nn.init.xavier_uniform_(self.hist_var_proj)
        torch.nn.init.xavier_uniform_(self.fut_var_proj)
        
        # Extra components to store Tensorflow nodes for attention computations
        self._input_placeholder = None
        self._attention_components = None
        self._prediction_parts = None

        print('*** {} params ***'.format(self.name))
        for k in vars(hparams):
            print('# {} = {}'.format(k, vars(hparams)[k]))
            
        self.train_criterion = QuantileLossCalculator(self.quantiles, self.output_size)
        self.test_criterion = NormalizedQuantileLossCalculator(self.quantiles, self.output_size)

        # Build model
        ## Build embeddings
        self.build_embeddings()
        
        ## Build Static Contex Networks
        self.build_static_context_networks()
        
        ## Building Variable Selection Networks
        self.build_variable_selection_networks()
        
        ## Build Lstm
        self.build_lstm()
        
        ## Build GLU for after lstm encoder decoder and layernorm
        self.build_post_lstm_gate_add_norm()
        
        ## Build Static Enrichment Layer
        self.build_static_enrichment()
        
        ## Building decoder multihead attention
        self.build_temporal_self_attention()
        
        ## Building positionwise decoder
        self.build_position_wise_feed_forward()
        
        ## Build output feed forward
        self.build_output_feed_forward()
        ##record KNOWN real-valued variable names in order 
        self.known_real_names = [name for (name, dtype, role) in self.column_definition if role == InputTypes.KNOWN_INPUT and dtype == DataTypes.REAL_VALUED]
        self.name2idx_known = {n: i for i, n in enumerate(self.known_real_names)}
        ## Initializing remaining weights
        self.init_weights()
        print("training_step exists:", hasattr(self, "training_step"))
        print("[CHK] final input_obs_loc =", self._input_obs_loc)
        self._gf_every = 50  # accumulate every N steps; change as you like
        self._gf_layers = [n for n, p in self.named_parameters()
                           if p.requires_grad and ("bias" not in n)]
        self._gf_index  = {name: i for i, name in enumerate(self._gf_layers)}
        n = len(self._gf_layers)
        self._gf_sum    = np.zeros(n, dtype=np.float64)  # sum of mean|grad|
        self._gf_count  = np.zeros(n, dtype=np.int64)    # how many times each layer got updated
        self._gf_steps  = 0


        
    def init_weights(self):
        for name, p in self.named_parameters():
            if ('lstm' in name and 'ih' in name) and 'bias' not in name:
                #print(name)
                #print(p.shape)
                torch.nn.init.xavier_uniform_(p)
#                 torch.nn.init.kaiming_normal_(p, a=0, mode='fan_in', nonlinearity='sigmoid')
            elif ('lstm' in name and 'hh' in name) and 'bias' not in name:
        
                 torch.nn.init.orthogonal_(p)
            
            elif 'lstm' in name and 'bias' in name:
                #print(name)
                #print(p.shape)
                torch.nn.init.zeros_(p)
    def _refresh_input_names_and_counts(self):
        self.known_real_names = [
        n for (n, dt, role) in self.column_definition
        if role == InputTypes.KNOWN_INPUT and dt == DataTypes.REAL_VALUED
    ]
        self.observed_real_names = [
        n for (n, dt, role) in self.column_definition
        if role == InputTypes.OBSERVED_INPUT and dt == DataTypes.REAL_VALUED
      ]

    # encoder (history) = observed + known; decoder (future) = known only
        self.hist_real_names = self.observed_real_names + self.known_real_names
        self.fut_real_names  = self.known_real_names

    # counts expected by later assertions / projections
        self.num_non_static_historical_inputs = len(self.hist_real_names)
        self.num_non_static_future_inputs     = len(self.fut_real_names)

    # handy map for TLF dropout / plots, etc.
        self.name2idx_known = {n: i for i, n in enumerate(self.known_real_names)}

        
    def get_historical_num_inputs(self):
        
        obs_inputs = [i for i in self._input_obs_loc]
        
        known_regular_inputs = [i for i in self._known_regular_input_idx
                                if i not in self._static_input_loc]
            
        known_categorical_inputs = [i for i in self._known_categorical_input_idx
                                    if i + self.num_regular_variables not in self._static_input_loc]
        
        wired_embeddings = [i for i in range(self.num_categorical_variables)
                            if i not in self._known_categorical_input_idx 
                            and i not in self._input_obs_loc]

        unknown_inputs = [i for i in range(self.num_regular_variables)
                          if i not in self._known_regular_input_idx
                          and i not in self._input_obs_loc]

        return len(obs_inputs+known_regular_inputs+known_categorical_inputs+wired_embeddings+unknown_inputs)
    
    def get_future_num_inputs(self):
            
        known_regular_inputs = [i for i in self._known_regular_input_idx
                                if i not in self._static_input_loc]
            
        known_categorical_inputs = [i for i in self._known_categorical_input_idx
                                    if i + self.num_regular_variables not in self._static_input_loc]

        return len(known_regular_inputs + known_categorical_inputs)
    
    def build_embeddings(self):
        self.categorical_var_embeddings = nn.ModuleList([nn.Embedding(self.category_counts[i], 
                                                                      self.hidden_layer_size) 
                                                     for i in range(self.num_categorical_variables)])

        self.regular_var_embeddings = nn.ModuleList([nn.Linear(1, 
                                                              self.hidden_layer_size) 
                                                  for i in range(self.num_regular_variables)])

    def build_variable_selection_networks(self):
        
        self.static_vsn = VariableSelectionNetwork(hidden_layer_size = self.hidden_layer_size,
                                                   input_size = self.hidden_layer_size * len(self._static_input_loc),
                                                   output_size = len(self._static_input_loc),
                                                   dropout_rate = self.dropout_rate)
        
        self.temporal_historical_vsn = VariableSelectionNetwork(hidden_layer_size = self.hidden_layer_size,
                                                                input_size = self.hidden_layer_size *
                                                                        self.num_non_static_historical_inputs,
                                                                output_size = self.num_non_static_historical_inputs,
                                                                dropout_rate = self.dropout_rate,
                                                                additional_context=self.hidden_layer_size)
        
        self.temporal_future_vsn = VariableSelectionNetwork(hidden_layer_size = self.hidden_layer_size,
                                                            input_size = self.hidden_layer_size *
                                                                        self.num_non_static_future_inputs,
                                                            output_size = self.num_non_static_future_inputs,
                                                            dropout_rate = self.dropout_rate,
                                                            additional_context=self.hidden_layer_size)
        
    def build_static_context_networks(self):
        
        self.static_context_variable_selection_grn = GatedResidualNetwork(self.hidden_layer_size,
                                                                          dropout_rate=self.dropout_rate)
        
        self.static_context_enrichment_grn = GatedResidualNetwork(self.hidden_layer_size,
                                                              dropout_rate=self.dropout_rate)

        self.static_context_state_h_grn = GatedResidualNetwork(self.hidden_layer_size,
                                                           dropout_rate=self.dropout_rate)
        
        self.static_context_state_c_grn = GatedResidualNetwork(self.hidden_layer_size,
                                                           dropout_rate=self.dropout_rate)
        
    def build_lstm(self):
        self.historical_lstm = nn.LSTM(input_size = self.hidden_layer_size,
                                       hidden_size = self.hidden_layer_size,
                                       batch_first = True)
        self.future_lstm = nn.LSTM(input_size = self.hidden_layer_size,
                                   hidden_size = self.hidden_layer_size,
                                   batch_first = True)
        
    def build_post_lstm_gate_add_norm(self):
        self.post_seq_encoder_gate_add_norm = GateAddNormNetwork(self.hidden_layer_size,
                                                                 self.hidden_layer_size,
                                                                 self.dropout_rate,
                                                                 activation = None)
        
    def build_static_enrichment(self):
        self.static_enrichment = GatedResidualNetwork(self.hidden_layer_size,
                                                      dropout_rate = self.dropout_rate,
                                                      additional_context=self.hidden_layer_size)
        
    def build_temporal_self_attention(self):
        self.self_attn_layer = InterpretableMultiHeadAttention(n_head = self.num_heads, 
                                                               d_model = self.hidden_layer_size,
                                                               dropout = self.dropout_rate)
        
        self.post_attn_gate_add_norm = GateAddNormNetwork(self.hidden_layer_size,
                                                           self.hidden_layer_size,
                                                           self.dropout_rate,
                                                           activation = None)
        
    def build_position_wise_feed_forward(self):
        self.GRN_positionwise = GatedResidualNetwork(self.hidden_layer_size,
                                                     dropout_rate = self.dropout_rate)
        
        self.post_tfd_gate_add_norm = GateAddNormNetwork(self.hidden_layer_size,
                                                         self.hidden_layer_size,
                                                         self.dropout_rate,
                                                         activation = None)
        
    def build_output_feed_forward(self):
        self.output_feed_forward = torch.nn.Linear(self.hidden_layer_size, 
                                                   self.output_size * len(self.quantiles))
    def get_decoder_mask(self, x, enc_len=None, as_bool=True):
 
        B, T = x.size(0), x.size(1)
        enc_len = self.hparams.num_encoder_steps if enc_len is None else int(enc_len)

    # future positions are masked (upper triangle)
    # True = mask out ; False = keep
        future_mask = torch.triu(torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1)

    # Broadcast to batch: [B, T, T]
        mask_bool = future_mask.unsqueeze(0).expand(B, -1, -1).contiguous()

        if as_bool:
            return mask_bool  # bool mask: True means "do not attend"
        else:
         # additive mask: 0 for keep, -inf for mask
            zero = torch.tensor(0.0, device=x.device, dtype=x.dtype)
            neginf = torch.tensor(float("-inf"), device=x.device, dtype=x.dtype)
            return torch.where(mask_bool, neginf, zero)
    def mask_decoder_observed_(self, x_observed):
        if x_observed is None:
            return x_observed
        nd = int(self.hparams.total_time_steps - self.hparams.num_encoder_steps)
        if nd > 0:
            x_observed[:, -nd:, :] = 0.0
        return x_observed

    def mask_decoder_observed_cat_(self, x_obs_cat, pad_idx=0):
        if x_obs_cat is None:
            return x_obs_cat
        nd = int(self.hparams.total_time_steps - self.hparams.num_encoder_steps)
        if nd > 0:
            x_obs_cat[:, -nd:, :] = pad_idx
        return x_obs_cat

    
    def get_tft_embeddings(self, regular_inputs, categorical_inputs):
   
        B, T, Nr = regular_inputs.shape
        Nc = 0 if (categorical_inputs is None) else categorical_inputs.shape[-1]
        

   
        assert Nr == self.num_regular_variables, "regular_inputs dim mismatch"
        assert (Nc == self.num_categorical_variables) or (categorical_inputs is None), "categorical_inputs dim mismatch"

        has_cate = (categorical_inputs is not None) and (Nc > 0)

    # -------- helpers: split combined indices into regular / categorical local indices --------
        def split_combined_indices(idxs_combined):
            reg = set()
            cat = set()
            for j in idxs_combined:
                if j < self.num_regular_variables:
                    reg.add(j)  # local = combined
                else:
                # local index inside categorical block
                    cat_local = j - self.num_regular_variables
                    if has_cate and 0 <= cat_local < Nc:
                        cat.add(cat_local)
            return reg, cat

    # turn to sets for fast lookup
        static_combined = set(getattr(self, "_static_input_loc", []) or [])
        input_obs_combined = set(getattr(self, "_input_obs_loc", []) or [])
        known_reg_set = set(getattr(self, "_known_regular_input_idx", []) or [])
        known_cat_set = set(getattr(self, "_known_categorical_input_idx", []) or [])

        static_reg_idx, static_cat_idx = split_combined_indices(static_combined)
        obs_reg_idx,   obs_cat_idx   = split_combined_indices(input_obs_combined)
    
    # -------- Static inputs (optional) --------
        if static_reg_idx or static_cat_idx:
            static_regular_inputs = [
            self.regular_var_embeddings[i](regular_inputs[:, 0, i:i+1]) for i in sorted(static_reg_idx)
        ]  # each: (B, 1, H)
            static_categorical_inputs = []
            if has_cate and static_cat_idx:
                for i in sorted(static_cat_idx):  # local idx in categorical space
                    emb = self.categorical_var_embeddings[i](categorical_inputs[:, 0, i])  # (B, H)
                    static_categorical_inputs.append(emb.unsqueeze(1))  # -> (B,1,H)
            if static_regular_inputs or static_categorical_inputs:
                static_inputs = torch.cat(static_regular_inputs + static_categorical_inputs, dim=1)  # (B, S, H)
            else:
                static_inputs = None
        else:
            static_inputs = None

    # -------- Observed inputs (historical only block, we just build representation here) --------
        obs_parts = []
        if obs_reg_idx:
            obs_reg = torch.stack(
            [self.regular_var_embeddings[i](regular_inputs[..., i:i+1]) for i in sorted(obs_reg_idx)],
            dim=-1
             )  # (B,T,H, n_obs_reg)
            obs_parts.append(obs_reg)
        if has_cate and obs_cat_idx:
            obs_cat = torch.stack(
            [self.categorical_var_embeddings[i](categorical_inputs[..., i]) for i in sorted(obs_cat_idx)],
            dim=-1
            )  # (B,T,H, n_obs_cat)
            obs_parts.append(obs_cat)

        if obs_parts:
            obs_inputs = torch.cat(obs_parts, dim=-1)  # (B,T,H, n_obs_total)
        else:
            obs_inputs = torch.empty(B, T, self.hidden_layer_size, 0, device=regular_inputs.device)

    # -------- Unknown inputs (a priori unknown at prediction time) --------
    # regular unknown = not known & not observed
        unknown_reg_idx = [i for i in range(Nr) if (i not in known_reg_set and i not in obs_reg_idx and i not in static_reg_idx)]
        unknown_cat_idx = []
        if has_cate:
            unknown_cat_idx = [i for i in range(Nc)
                           if (i not in known_cat_set and i not in obs_cat_idx and i not in static_cat_idx)]

            unk_parts = []
        if unknown_reg_idx:
            unk_reg = torch.stack(
            [self.regular_var_embeddings[i](regular_inputs[..., i:i+1]) for i in unknown_reg_idx],
            dim=-1
           )  # (B,T,H, n_unk_reg)
            unk_parts.append(unk_reg)
        if has_cate and unknown_cat_idx:
            unk_cat = torch.stack(
            [self.categorical_var_embeddings[i](categorical_inputs[..., i]) for i in unknown_cat_idx],
            dim=-1
        )  # (B,T,H, n_unk_cat)
            unk_parts.append(unk_cat)

        if unk_parts:
            unknown_inputs = torch.cat(unk_parts, dim=-1)  # (B,T,H, n_unk_total)
        else:
            unknown_inputs = None


        known_parts = []
        if known_reg_set:
            known_reg = torch.stack(
            [self.regular_var_embeddings[i](regular_inputs[..., i:i+1]) for i in sorted(known_reg_set)],
            dim=-1
           )  # (B,T,H, n_known_reg)
            known_parts.append(known_reg)

        if has_cate and known_cat_set:
            known_cat = torch.stack(
            [self.categorical_var_embeddings[i](categorical_inputs[..., i]) for i in sorted(known_cat_set)],
            dim=-1
           )  # (B,T,H, n_known_cat)
            known_parts.append(known_cat)

        if known_parts:
            known_combined_layer = torch.cat(known_parts, dim=-1)  # (B,T,H, n_known_total)
        else:
            known_combined_layer = torch.empty(B, T, self.hidden_layer_size, 0, device=regular_inputs.device)

        return unknown_inputs, known_combined_layer, obs_inputs, static_inputs

        
    def forward(self, all_inputs):
        if not isinstance(all_inputs, (list, tuple)) or len(all_inputs) < 3:
            raise ValueError(f"forward expects (hist_regular, fut_regular, x_static[, hist_cate, fut_cate]), got {type(all_inputs)} len={len(all_inputs) if isinstance(all_inputs,(list,tuple)) else 'n/a'}")

        if len(all_inputs) >= 5:
            hist_regular, fut_regular, x_static, hist_cate, fut_cate = all_inputs[:5]
        else:
            hist_regular, fut_regular, x_static = all_inputs[:3]
            hist_cate = fut_cate = None

        device = hist_regular.device  # ensure all tensors move to the same device
        hist_regular = hist_regular.to(device).float()
        fut_regular  = fut_regular.to(device).float()
        # Cast/allocate x_static on the correct device
        x_static = (
    x_static.to(device).float()
    if x_static is not None
    else torch.zeros(hist_regular.size(0), 0, device=device)
     )

# Move categorical tensors to the same device/dtype
        if hist_cate is not None:
            hist_cate = hist_cate.to(device).long()
        if fut_cate is not None:
            fut_cate = fut_cate.to(device).long()

# Read config values
        enc_cfg = int(getattr(self.hparams, "num_encoder_steps", getattr(self, "num_encoder_steps")))
        total_cfg = int(getattr(self.hparams, "total_time_steps", getattr(self, "time_steps")))
        dec_cfg = total_cfg - enc_cfg
        H = int(getattr(self.hparams, "hidden_layer_size", getattr(self, "hidden_layer_size")))
                
# --- lazily project static features to hidden size H if needed ---
# Guard against empty static vectors (B, 0) to avoid in_features=0 errors.
        if (x_static is not None) and (x_static.dim() == 2) and (x_static.size(-1) > 0) and (x_static.size(-1) != H):
            if not hasattr(self, "static_proj"):
                self.static_proj = torch.nn.Linear(x_static.size(-1), H).to(device)
            x_static = self.static_proj(x_static)  # -> [B, H]

# Shapes, sanity checks
        B, enc, Nr_hist = hist_regular.shape
        _,  dec, Nr_fut = fut_regular.shape
# device = hist_regular.device  # (already set above)
        assert enc == enc_cfg, f"enc mismatch: got {enc}, expect {enc_cfg}"
        assert enc + dec == total_cfg, f"time len mismatch: enc({enc}) + dec({dec}) != total({total_cfg})"

# Decide whether static branch is enabled
        use_static_branch = (
    (x_static is not None)
    and (x_static.dim() == 2)
    and (x_static.shape[-1] == H)
    and (getattr(self, "static_vsn", None) is not None)
)

        if use_static_branch:
            static_encoder, sparse_weights = self.static_vsn(x_static)
            static_context_variable_selection = self.static_context_variable_selection_grn(static_encoder)
            static_context_enrichment        = self.static_context_enrichment_grn(static_encoder)
            static_context_state_h           = self.static_context_state_h_grn(static_encoder)
            static_context_state_c           = self.static_context_state_c_grn(static_encoder)
        else:
            if (x_static is not None) and (x_static.numel() > 0) and (x_static.shape[-1] != H):
                print(f"[WARN] x_static has dim {x_static.shape[-1]} != H({H}); bypassing static_vsn.")
            sparse_weights = None
            static_encoder                     = torch.zeros(B, H, device=device)
            static_context_variable_selection  = torch.zeros(B, H, device=device)
            static_context_enrichment          = torch.zeros(B, H, device=device)
            static_context_state_h             = torch.zeros(B, H, device=device)
            static_context_state_c             = torch.zeros(B, H, device=device)

        hist_emb = torch.einsum('btn,nh->bthn', hist_regular, self.hist_var_proj)  # [B, T, H, N_hist]
        fut_emb  = torch.einsum('btn,nh->bthn', fut_regular,  self.fut_var_proj)   # [B, T, H, N_fut]
        historical_features, historical_flags = self.temporal_historical_vsn(
        (hist_emb, static_context_variable_selection)   
       )  # -> [B, enc, H]
        future_features, future_flags = self.temporal_future_vsn(
        (fut_emb, static_context_variable_selection)
        )  # -> [B, dec, H]

        history_lstm, (state_h, state_c) = self.historical_lstm(
        historical_features,
        (static_context_state_h.unsqueeze(0), static_context_state_c.unsqueeze(0))
         )  # [B, enc, H]
        future_lstm, _ = self.future_lstm(future_features, (state_h, state_c))  # [B, dec, H]

    # Skip with original inputs after VSN
        input_embeddings = torch.cat((historical_features, future_features), dim=1)  # [B, total, H]
        lstm_layer       = torch.cat((history_lstm,      future_lstm),      dim=1)  # [B, total, H]

        temporal_feature_layer = self.post_seq_encoder_gate_add_norm(lstm_layer, input_embeddings)  # [B, total, H]

        expanded_static_context = static_context_enrichment.unsqueeze(1)  # [B,1,H]
        enriched = self.static_enrichment((temporal_feature_layer, expanded_static_context))       # [B, total, H]
        x, self_att = self.self_attn_layer(
        enriched, enriched, enriched,
        mask=self.get_decoder_mask(enriched)  
          )  # x: [B, total, H]
        x = self.post_attn_gate_add_norm(x, enriched)

        decoder = self.GRN_positionwise(x)  # [B, total, H]
        transformer_layer = self.post_tfd_gate_add_norm(decoder, temporal_feature_layer)  # [B, total, H]
        outputs = self.output_feed_forward(transformer_layer[:, enc:, :])  # [B, dec, out_size]
        return outputs
       
    def loss(self, y_hat, y):
        return self.train_criterion.apply(y_hat, y_true)
    
    def test_loss(self, y_hat, y):
        q = self.quantiles[1] if q is None else q
        return self.test_criterion.apply(y_true, y_hat, q)
        
    def _prepare_tft_inputs(self, batch):

        enc   = getattr(self, "num_encoder_steps", None) or getattr(self.hparams, "num_encoder_steps", None)
        total = getattr(self, "time_steps", None) or getattr(self, "total_time_steps", None) or getattr(self.hparams, "total_time_steps", None)

        if isinstance(batch[0], (list, tuple)):
            xb = batch[0]
            if len(xb) == 4:
            # (x_known, x_observed, x_static, x_categorical)
                x_known, x_observed, x_static, x_categorical = xb
            elif len(xb) == 3:
                x_known, x_observed, x_categorical = xb
                x_static = None                     
            else:
                raise ValueError(f"Unexpected input tuple length: {len(xb)}")
            y = batch[1]
        else:
            raise ValueError("Batch format not supported; expected ((x_known, x_observed, x_static[, x_categorical]), y)")
        def _squeeze_TK(x):
            if x is None:
                return None
            return x.squeeze(-1) if (x.dim() == 4 and x.size(-1) == 1) else x


        # 1) normalize dtypes & memory layout
        x_known       = _squeeze_TK(x_known).contiguous().float()
        x_observed    = _squeeze_TK(x_observed).contiguous().float()
        x_categorical = _squeeze_TK(x_categorical)
        if x_categorical is not None:
            x_categorical = x_categorical.contiguous().long()
        y = y.contiguous().float()
        # === OPTIONAL: training-time feature dropout on TLF ===
        p = float(getattr(self.hparams, "tlf_dropout_p", 0.0) or 0.0)
        if self.training and p > 0.0 and "total load forecast" in self.name2idx_known:
            tlf_idx = self.name2idx_known["total load forecast"]
        # mask shape: [B, T, 1] 
            mask = (torch.rand(x_known.size(0), x_known.size(1), 1, device=x_known.device) < p)
            col = x_known[..., tlf_idx:tlf_idx+1]
            mu  = col.mean()
            x_known[..., tlf_idx:tlf_idx+1] = torch.where(mask, mu, col)

        # read window config online from hparams (never use stale cached attrs)
        total = int(self.time_steps)      
        enc   = int(self.num_encoder_steps)     
        dec   = total - enc
        assert 0 < enc < total, f"invalid windows: enc={enc}, total={total}"

        # align time length T to `total` by cropping the last `total` steps (if longer)
        assert x_known.dim() == 3,    f"x_known must be [B,T,K], got {x_known.shape}"
        assert x_observed.dim() == 3, f"x_observed must be [B,T,K], got {x_observed.shape}"
        T = x_known.size(1)
        assert x_observed.size(1) == T, f"x_observed T {x_observed.size(1)} != x_known T {T}"

        if T != total:
            if T > total:
                s = T - total
                x_known    = x_known[:, s:, :].contiguous()
                x_observed = x_observed[:, s:, :].contiguous()
                if x_categorical is not None and x_categorical.size(1) == T:
                    x_categorical = x_categorical[:, s:, :].contiguous()
        # if y carries a time axis of length T, crop it consistently
                if y.dim() >= 2 and y.size(1) == T:
                    y = y[:, s:, ...].contiguous()
        # (optional) warn once
                if not hasattr(self, "_window_warned"):
                    print(f"[TFT] Cropped input windows from T={T} to total={total} (enc={enc}, dec={dec}).")
                    self._window_warned = True
            else:
                raise AssertionError(
            f"dataset window too short: T={T} < total_time_steps={total}. "
            f"Rebuild dataset with total={total} (enc={enc}, dec={dec}), "
            f"or set hparams.total_time_steps={T}."
        )

        # split into historical/future parts
        historical_inputs = torch.cat(
    [x_known[:, :enc, :], x_observed[:, :enc, :]], dim=-1
)       # [B, enc, K_hist]
        future_inputs = x_known[:, enc:, :]               # [B, dec, K_fut]

        # (optional) sanity checks
        assert historical_inputs.size(1) == enc, f"hist len {historical_inputs.size(1)} != enc {enc}"
        assert future_inputs.size(1)     == dec, f"fut  len {future_inputs.size(1)} != dec {dec}"

        if x_categorical is not None and x_categorical.size(-1) > 0:
            historical_categorical = x_categorical[:, :enc, :]
            future_categorical     = x_categorical[:, enc:, :] if getattr(self, "known_categorical_inputs", []) else None
        else:
            historical_categorical = None
            future_categorical     = None  

        if x_static is None:
            x_static_2d = torch.zeros(x_known.size(0), 0, device=x_known.device, dtype=x_known.dtype)
        else:
            x_static = x_static.float()
            if x_static.dim() == 3:
                with torch.no_grad():
                    if x_static.size(1) > 1:
                        diff = (x_static - x_static[:, :1, :]).abs().max().item()
                        if diff > 1e-6:
                            print(f"[WARN] x_static varies over time (max diff={diff:.2e}); using first timestep.")
                x_static_2d = x_static[:, 0, :]
            elif x_static.dim() == 2:
                x_static_2d = x_static
            else:
                raise ValueError(f"x_static must be [B,K] or [B,T,K], got {tuple(x_static.shape)}")


        if getattr(self, "categorical_var_embeddings", None) and (historical_categorical is not None):
            all_inputs = (historical_inputs, future_inputs, x_static_2d, historical_categorical, future_categorical)
        else:
            all_inputs = (historical_inputs, future_inputs, x_static_2d)

    
        return all_inputs, y
    def _align_targets_to_decoder(self, y, y_hat):
        dec_len = self.time_steps - self.num_encoder_steps
        assert dec_len > 0
        if y.dim() == 2:            # [B, T] -> [B, T, 1]
            y = y.unsqueeze(-1)
        y_dec = y[:, -dec_len:, :] 

        if y_hat.size(-1) > 1 and y_dec.size(-1) == 1:
            y_true = y_dec.expand(-1, -1, y_hat.size(-1))
        else:
            y_true = y_dec
        assert y_true.shape == y_hat.shape, f"y_true{y_true.shape} vs y_hat{y_hat.shape}"
        return y_true
    def _runtime_sanity_checks(self, batch, stage="train", strict=True):

        (x_known, x_observed, x_static), y = batch


        if getattr(self, "_did_runtime_checks", False):
            return
        self._did_runtime_checks = True

 
        if x_observed is not None:
            B, T, C_obs = x_observed.shape
            assert C_obs == len(self._input_obs_loc), (
            f"[{stage}] Observed channels ({C_obs}) != len(input_obs_loc) ({len(self._input_obs_loc)}). "
            "Check column_definition → mapping."
        )

   
        n_real = getattr(self, "_n_real_inputs", None)
        if n_real is not None:
            bad = [i for i in self._input_obs_loc if not (0 <= i < n_real)]
            assert not bad, f"[{stage}] input_obs_loc has out-of-range indices: {bad} (n_real={n_real})."
    
            assert len(set(self._input_obs_loc)) == len(self._input_obs_loc), \
            f"[{stage}] input_obs_loc contains duplicated indices: {self._input_obs_loc}"

   
        if x_observed is not None:
            nd = int(self.hparams.total_time_steps - self.hparams.num_encoder_steps)      
        if nd > 0:
            dec = slice(-nd, None)
            leak_val = float(x_observed[:, dec].abs().sum().item())
            assert leak_val == 0.0, f"[{stage}] Observed not masked in decoder! leak={leak_val:.6f}."

   
        if x_observed is not None:
            assert torch.isfinite(x_observed).all(), f"[{stage}] x_observed has NaN/Inf."
        if x_known is not None:
            assert torch.isfinite(x_known).all(), f"[{stage}] x_known has NaN/Inf."
        
    def training_step(self, batch, batch_idx):
        # unpack according to your batch structure
        (x_known, x_observed, x_static), y = batch
        x_observed = self.mask_decoder_observed_(x_observed)
        batch = (x_known, x_observed, x_static), y
        attn_mask = self.get_decoder_mask(x_known, as_bool=True)
        self._runtime_sanity_checks(batch, stage="train")
    
        all_inputs, y = self._prepare_tft_inputs(batch)
        y_hat = self(all_inputs)                               # [B, T_dec, Q]
        y_true = self._align_targets_to_decoder(y, y_hat)      # [B, T_dec, Q]
        loss = self.train_criterion.apply(y_hat, y_true)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        (x_known, x_observed, x_static), y = batch
        x_observed = self.mask_decoder_observed_(x_observed)
        batch = (x_known, x_observed, x_static), y
        self._runtime_sanity_checks(batch, stage="val")
        
        all_inputs, y = self._prepare_tft_inputs(batch)
        y_hat = self(all_inputs)
        y_true = self._align_targets_to_decoder(y, y_hat)

        if hasattr(self, "test_criterion"):
            val_loss = self.test_criterion.apply(y_true, y_hat, self.quantiles[1]) 
        else:
            val_loss = self.train_criterion.apply(y_hat, y_true) 

        self.log("val_loss", val_loss, on_epoch=True, prog_bar=True)
        if batch_idx == 0:  # only print onc
            print("\n[Inspect all_inputs]")
            print("Type:", type(all_inputs))
            if isinstance(all_inputs, dict):
                for k, v in all_inputs.items():
                    print(f"  {k}: {type(v)}, shape={tuple(v.shape) if hasattr(v,'shape') else None}")
            elif isinstance(all_inputs, (list, tuple)):
                for i, v in enumerate(all_inputs):
                    print(f"  idx {i}: {type(v)}, shape={tuple(v.shape) if hasattr(v,'shape') else None}")
            else:
                print("  value:", all_inputs)
        return val_loss


    def test_step(self, batch, batch_idx):
        all_inputs, y = self._prepare_tft_inputs(batch)
        y_hat = self(all_inputs)
        y_true = self._align_targets_to_decoder(y, y_hat)      # [B, T_dec, Q]，已自动扩到 Q 维

        if hasattr(self, "test_criterion"):
            loss = self.test_criterion.apply(y_true, y_hat, self.quantiles[1])
        else:
            loss = self.train_criterion.apply(y_hat, y_true)

        self.log("test_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss
 

    def configure_optimizers(self):
   
        hp = getattr(self, "hparams", None)
        lr = float(getattr(hp, "learning_rate", 1e-3))
        wd = float(getattr(hp, "weight_decay", 0.0))

        optimizer = torch.optim.AdamW(self.parameters(), lr=lr, weight_decay=wd)

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=3, verbose=True
    )

    # Return dict form so Lightning knows which metric to monitor
        return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "monitor": "val_loss",  # MUST match the key logged in validation_step
            "interval": "epoch",    # Not used by Plateau but harmless
            "frequency": 1,
            "strict": True,         # Raise error if 'val_loss' is missing
        },
    }
    def _accumulate_grad_flow(self):
        """Accumulate mean absolute gradient per layer (numbers only)."""
        any_grad = False
        for name, p in self.named_parameters():
            if (p.grad is None) or (not p.requires_grad) or ("bias" in name):
                continue
            any_grad = True
            i = self._gf_index.get(name, None)
            if i is None:
                continue
            # detach to CPU scalar; keep it tiny
            g = p.grad.detach().abs().mean().item()
            self._gf_sum[i]   += g
            self._gf_count[i] += 1
        if any_grad:
            self._gf_steps += 1

    # ---- draw once at the end; save a single PNG; close figure immediately ----
    def _plot_grad_flow_final(self, save_path="grad_flow_final.png"):
        """Plot aggregated mean|grad| per layer collected during training."""
        # safe average per layer (ignore layers that never received grad)
        denom = np.clip(self._gf_count, 1, None)
        vals  = (self._gf_sum / denom).astype(np.float64)

        # shorten very long layer names to fit x-axis
        def _short(n):
            n = n.replace(".weight", "")
            parts = n.split(".")
            return ".".join(parts[-2:]) if len(parts) >= 2 else n

        labels = [_short(n) for n in self._gf_layers]

        # dynamic width to make x-axis fit
        W = max(12.0, 0.18 * len(labels))  # 0.18 inch per label approximately
        fig, ax = plt.subplots(figsize=(W, 5.5), constrained_layout=True)

        ax.plot(vals, marker="o", linewidth=1.0, alpha=0.9)
        ax.hlines(0, 0, len(vals) - 1, linewidth=1)

        # sparse ticks if too many layers
        step = max(1, len(labels) // 60)  # at most ~60 tick labels
        tick_pos = np.arange(0, len(labels), step)
        ax.set_xticks(tick_pos)
        ax.set_xticklabels([labels[i] for i in tick_pos], rotation=90, fontsize=7)

        ax.set_xlim(-0.5, len(vals) - 0.5)
        ax.set_xlabel("Layers", fontsize=10)
        ax.set_ylabel("Mean |gradient| (aggregated over training)", fontsize=10)
        ax.set_title(f"Gradient flow (aggregated over {int(self._gf_steps)} steps)", fontsize=11)
        out_dir = Path("/Image")            # absolute path you asked for
        save_path = out_dir / "grad_flow_final.png"

        fig.savefig(save_path, dpi=180)  # one image on disk
        plt.close(fig)                   # no figure kept in memory
        return save_path

    # ---- hook into Lightning; accumulate numbers only during training ----
    def on_after_backward(self):
        # accumulate every N global steps; no plotting
        if (self.trainer.global_step % self._gf_every) == 0:
            self._accumulate_grad_flow()
    def on_train_end(self):
        path = self._plot_grad_flow_final(save_path="grad_flow_final.png")
        print(f"[grad-flow] saved: {path}")

    
    def train_dataloader(self):
        # REQUIRED
        return DataLoader(train_dataset, batch_size = self.minibatch_size, shuffle=True, drop_last=True, num_workers=4,persistent_workers=True)

    def val_dataloader(self):
        # OPTIONAL
        return DataLoader(valid_dataset, batch_size = self.minibatch_size, shuffle=False, drop_last=False, num_workers=4,persistent_workers=True )
    
    def test_dataloader(self):
        # OPTIONAL
        return DataLoader(test_dataset, batch_size = self.minibatch_size, shuffle=False, drop_last=False, num_workers=4,persistent_workers=True)

# Training

## Setting Device

In [33]:
DEVICE = torch.device("cuda: 0" if torch.cuda.is_available() else "cpu")
DEVICE

device(type='cpu')

## Model Instance and fined parameters

In [34]:
# ==== Window ====
hparams.num_encoder_steps = 168           # enc
PRED = 24
hparams.total_time_steps  = hparams.num_encoder_steps + PRED   # 192

# ==== Model ====
hparams.hidden_layer_size = 64
hparams.num_heads         = 2
hparams.dropout_rate      = 0.1
hparams.weight_decay      = 1e-3

# ==== Optimization ====
hparams.minibatch_size    = 64
hparams.learning_rate     = 5e-4
hparams.max_gradient_norm = 1.0
hparams.early_stopping_patience = 5
hparams.use_tlf = use_tlf
hparams.tlf_replace_with_seasonal = tlf_replace_with_seasonal
hparams.err_target = err_target

tft = TemporalFusionTransformer(hparams)  
tft

known_model = [n for (n, dt, role) in tft.column_definition
               if role == InputTypes.KNOWN_INPUT and dt == DataTypes.REAL_VALUED]
print("[MODEL] KNOWN =", known_model)
coldef_fmt = (data_formatter.get_column_definition()
              if hasattr(data_formatter, "get_column_definition")
              else data_formatter._column_definition)
known_fmt = [n for (n, dt, role) in coldef_fmt
             if role == InputTypes.KNOWN_INPUT and dt == DataTypes.REAL_VALUED]
print("[FMT]   KNOWN =", known_fmt)
coldef_ds = (train_dataset.get_column_definition()
             if hasattr(train_dataset, "get_column_definition")
             else getattr(train_dataset, "_column_definition", None))
known_ds = [n for (n, dt, role) in (coldef_ds or [])
            if role == InputTypes.KNOWN_INPUT and dt == DataTypes.REAL_VALUED]
print("[DS]    KNOWN =", known_ds)





*** TemporalFusionTransformer params ***
# total_time_steps = 192
# num_encoder_steps = 168
# num_epochs = 100
# early_stopping_patience = 5
# multiprocessing_workers = 5
# column_definition = [('identifier', <DataTypes.CATEGORICAL: 1>, <InputTypes.ID: 4>), ('days_from_start', <DataTypes.REAL_VALUED: 0>, <InputTypes.TIME: 5>), ('err_target', <DataTypes.REAL_VALUED: 0>, <InputTypes.TARGET: 0>), ('hour', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('day_of_week', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('forecast solar day ahead', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('forecast wind onshore day ahead', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('price day ahead', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('temp_weighted', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('temp_min_weighted', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('temp_max_weighted', <DataTypes.REAL_VALUED: 0

In [35]:
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
early_stop_callback = EarlyStopping(monitor = 'val_loss',
                                    min_delta = 1e-4,
                                    patience=hparams.early_stopping_patience,
                                    verbose=False,
                                    mode='min')
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    save_top_k=1,
    mode='min',
    filename='best-checkpoint',  
    verbose=True
)
from pytorch_lightning.callbacks import TQDMProgressBar

class CustomProgressBar(TQDMProgressBar):
    def get_metrics(self, *args, **kwargs):
        metrics = super().get_metrics(*args, **kwargs)
        metrics.pop("v_num", None) 
        return metrics
        


In [36]:
# tb_logger = TensorBoardLogger(save_dir="lightning_logs", name="tft")
lr_monitor = LearningRateMonitor(logging_interval="epoch")
tb_logger = TensorBoardLogger(save_dir="lightning_logs", name="tft")

trainer = pl.Trainer(
    max_epochs=100,
    accelerator="cpu", devices=1,                    
    logger=tb_logger,
    log_every_n_steps=20,
    gradient_clip_val=2.0,
    gradient_clip_algorithm="norm",
    callbacks=[
        early_stop_callback,
        checkpoint_callback,
        lr_monitor,
        CustomProgressBar()
    ],
    limit_train_batches=1.0,               
    limit_val_batches=1.0,
    num_sanity_val_steps=2,
    deterministic=True
)





GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..
`Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..


In [None]:
trainer.fit(tft)

In [44]:
version_dir = "lightning_logs/tft/version_E1"

def latest_ckpt(d):
    cands = glob.glob(os.path.join(d, "checkpoints", "*.ckpt"))
    if not cands:
        raise FileNotFoundError(f"No ckpt under {d}/checkpoints")
    best = [p for p in cands if "best" in os.path.basename(p).lower()]
    if best: return max(best, key=os.path.getmtime)
    last = [p for p in cands if "last" in os.path.basename(p).lower()]
    if last: return max(last, key=os.path.getmtime)
    return max(cands, key=os.path.getmtime)

ckpt_path = latest_ckpt(version_dir)


results = trainer.test(model=tft, ckpt_path=ckpt_path)  
print(results)                      # [{'test_loss': ... , ...}]
print(results[0].get("test_loss"))  


Restoring states from the checkpoint path at lightning_logs/tft/version_E1/checkpoints/best-checkpoint.ckpt
/Users/guanyuxiaoxiong/opt/anaconda3/envs/tft-env/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:282: Be aware that when using `ckpt_path`, callbacks used to create the checkpoint need to be provided during `Trainer` instantiation. Please add the following callbacks: ["EarlyStopping{'monitor': 'val_loss', 'mode': 'min'}", "ModelCheckpoint{'monitor': 'val_loss', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}"].
Loaded model weights from the checkpoint at lightning_logs/tft/version_E1/checkpoints/best-checkpoint.ckpt


Testing DataLoader 0: 100%|█████████████████████| 46/46 [00:39<00:00,  1.15it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           0.5090025067329407
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
[{'test_loss': 0.5090025067329407}]
0.5090025067329407


# Evaluation

## Top-K Rank

In [37]:
# run_tft_topk.py
# Main entry to compute and visualize Top-K VSN variable importances.

from types import SimpleNamespace

from Evaluation.tft_topk import (
    latest_ckpt,
    move_to,
    unpack_batch_for_io,
    extract_version_suffix,
    assert_loader_matches_model,
    collect_vsn_means,
    names_from_column_definition,
    topk_print,
    _to_np,
    plot_var_bars,
)

# -------- Config --------
VERSION_DIR = "lightning_logs/tft/version_E3"
DEVICE      = "cpu"          # or "cuda"
MAX_BATCHES = 0              # 0 = full val set
TOPK_PRINT  = 20             # Top-K to print in console; set as you like
TOPK_BARS   = None           # If int, only plot top-K bars; None = plot all
OUT_DIR     = "Image/TopK"   # where to save plots

def main():
    # --- Restore model with EXACT hparams from checkpoint ---
    ckpt_path = latest_ckpt(VERSION_DIR)
    device = torch.device(DEVICE)

    ckpt = torch.load(ckpt_path, map_location=device)
    raw_hp = ckpt.get("hyper_parameters", {})
    raw_hp = raw_hp.get("hparams", raw_hp) if isinstance(raw_hp, dict) else raw_hp
    hparams = SimpleNamespace(**raw_hp) if isinstance(raw_hp, dict) else raw_hp

    model = TemporalFusionTransformer(hparams).to(device)
    _ = model.load_state_dict(ckpt["state_dict"], strict=False)
    model.eval()
    try:
        model.log = lambda *a, **k: None
    except Exception:
        pass

    # --- Dataloader from the SAME experiment config (val -> fallback train) ---
    try:
        loader = model.val_dataloader()
    except Exception:
        loader = None
    if loader is None:
        try:
            loader = model.train_dataloader()
        except Exception:
            raise RuntimeError("Neither val_dataloader nor train_dataloader is available.")

    # --- Preflight: confirm loader feature counts match checkpoint ---
    assert_loader_matches_model(model, loader)

    # --- Collect global mean VSN weights ---
    w_static_mean, w_hist_mean, w_future_mean = collect_vsn_means(model, loader, max_batches=MAX_BATCHES)

    # --- Names for pretty printing/plots ---
    n_s = int(w_static_mean.shape[0]) if w_static_mean is not None else 0
    n_h = int(w_hist_mean.shape[0])   if w_hist_mean   is not None else 0
    n_f = int(w_future_mean.shape[0]) if w_future_mean is not None else 0
    s_names, h_names, f_names = names_from_column_definition(model, n_s, n_h, n_f)

    # --- Print Top-K to console ---
    ver = extract_version_suffix(VERSION_DIR)
    if w_static_mean is not None:
        topk_print(w_static_mean, s_names, f"Static VSN  (version={ver})", TOPK_PRINT)
    else:
        print("\n[Info] static_vsn not captured or not present.")
    if w_hist_mean is not None:
        topk_print(w_hist_mean,   h_names, f"Historical VSN (encoder)  (version={ver})", TOPK_PRINT)
    else:
        print("\n[Info] temporal_historical_vsn not captured or not present.")
    if w_future_mean is not None:
        topk_print(w_future_mean, f_names, f"Future VSN (decoder)  (version={ver})", TOPK_PRINT,  )
    else:
        print("\n[Info] temporal_future_vsn not captured or not present.")

    # --- Plots ---
    arr_static = _to_np(w_static_mean)
    arr_hist   = _to_np(w_hist_mean)
    arr_future = _to_np(w_future_mean)

    if arr_static is not None:
        plot_var_bars(s_names, arr_static, f"Static VSN — mean weight (version={ver})",
                      filename=f"vsn_bars_static_v{ver}", out_dir=OUT_DIR, topk=TOPK_BARS)
    if arr_hist is not None:
        plot_var_bars(h_names, arr_hist,   f"Historical VSN — mean weight (version={ver})",
                      filename=f"vsn_bars_hist_v{ver}", out_dir=OUT_DIR, topk=TOPK_BARS)
    if arr_future is not None:
        plot_var_bars(f_names, arr_future, f"Future VSN — mean weight (version={ver})",
                      filename=f"vsn_bars_future_v{ver}", out_dir=OUT_DIR, topk=TOPK_BARS,)

    print("Saved figures to:", OUT_DIR)


if __name__ == "__main__":
    main()


*** TemporalFusionTransformer params ***
# total_time_steps = 192
# num_encoder_steps = 168
# num_epochs = 100
# early_stopping_patience = 10
# multiprocessing_workers = 5
# column_definition = [('identifier', <DataTypes.CATEGORICAL: 1>, <InputTypes.ID: 4>), ('days_from_start', <DataTypes.REAL_VALUED: 0>, <InputTypes.TIME: 5>), ('err_target', <DataTypes.REAL_VALUED: 0>, <InputTypes.TARGET: 0>), ('hour', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('day_of_week', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('forecast solar day ahead', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('forecast wind onshore day ahead', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('price day ahead', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('temp_weighted', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('temp_min_weighted', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('temp_max_weighted', <DataTypes.REAL_VALUED: 

## Ablation analysis

In [38]:
import torch
from types import SimpleNamespace

from Evaluation.tft_ablation import (
    latest_ckpt,
    names_from_column_definition,
    run_ablation,
    save_ablation_csv,
    plot_ablation,
    assert_loader_matches_model,
       # <-- requires the preflight helper added in tft_ablation.py
)

# -------- Config --------
VERSION_DIR = "lightning_logs/tft/version_E3"
DEVICE      = "cpu"            # or "cuda"
MODE        = "permute"        # "permute" or "zero"
SEED        = 42
MAX_BATCHES = 0                # 0 = full val set
VAR_LIST    = None             # e.g. ["rain_3h_weighted", "temp_weighted"]; None = all KNOWN+OBSERVED
OUT_DIR     = "Image/DMAC"
CSV_NAME    = f"ablation_{MODE}.csv"


def main():
    # --- Load model from checkpoint hparams (shape-safe) ---
    ckpt_path = latest_ckpt(VERSION_DIR)
    device = torch.device(DEVICE)

    ckpt = torch.load(ckpt_path, map_location=device)
    raw_hp = ckpt.get("hyper_parameters", {})
    raw_hp = raw_hp.get("hparams", raw_hp) if isinstance(raw_hp, dict) else raw_hp
    hparams = SimpleNamespace(**raw_hp) if isinstance(raw_hp, dict) else raw_hp

    # Adjust this import to your project if needed
    model = TemporalFusionTransformer(hparams).to(device)
    _ = model.load_state_dict(ckpt["state_dict"], strict=False)
    model.eval()
    try:
        model.log = lambda *a, **k: None  # silence self.log during manual evaluation
    except Exception:
        pass

    # --- Get DataLoader (val -> fallback train); both should be built using hparams above ---
    try:
        loader = model.val_dataloader()
    except Exception:
        loader = None
    if loader is None:
        try:
            loader = model.train_dataloader()
        except Exception:
            raise RuntimeError("No val/train dataloader available.")

    # --- Preflight: ensure loader feature counts match checkpoint expectations ---
    assert_loader_matches_model(model, loader)

    # --- Discover variable names for ablation list ---
    obs_names, knw_names = names_from_column_definition(model)

    # --- Run ablation ---
    rows, baseline, n_batches = run_ablation(
        model=model,
        loader=loader,
        observed_names=obs_names,
        known_names=knw_names,
        mode=MODE,
        max_batches=MAX_BATCHES,
        var_list=VAR_LIST,
        device=device,
        seed=SEED,
    )

    # --- Save CSV ---
    csv_path = save_ablation_csv(rows, OUT_DIR, CSV_NAME)
    print(f"\nSaved ablation table: {csv_path}")
    print(f"[Baseline] mean loss over {n_batches} batch(es): {baseline:.6f}")

    # --- Print top-k summary ---
    top_k = min(20, len(rows))
    print(f"\nTop {top_k} variables by impact (%): (mode={MODE}, batches={n_batches if MAX_BATCHES==0 else MAX_BATCHES})")
    for var, pert_mean, delta, pct in rows[:top_k]:
        print(f"{pct:8.3f}%  Δ={delta:10.6f}  loss'={pert_mean:10.6f}  {var}")

    # --- Figures (titles include version suffix parsed from VERSION_DIR) ---
    plot_ablation(rows, baseline, OUT_DIR, VERSION_DIR, n_batches, MODE)
    print("Saved figures to:", OUT_DIR)


if __name__ == "__main__":
    main()




*** TemporalFusionTransformer params ***
# total_time_steps = 192
# num_encoder_steps = 168
# num_epochs = 100
# early_stopping_patience = 10
# multiprocessing_workers = 5
# column_definition = [('identifier', <DataTypes.CATEGORICAL: 1>, <InputTypes.ID: 4>), ('days_from_start', <DataTypes.REAL_VALUED: 0>, <InputTypes.TIME: 5>), ('err_target', <DataTypes.REAL_VALUED: 0>, <InputTypes.TARGET: 0>), ('hour', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('day_of_week', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('forecast solar day ahead', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('forecast wind onshore day ahead', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('price day ahead', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('temp_weighted', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('temp_min_weighted', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('temp_max_weighted', <DataTypes.REAL_VALUED: 

## Integrated Gradients

In [39]:
# run_tft_ig.py  — adapted main program for your current setup

import torch
import os, re 
from types import SimpleNamespace
from Evaluation.tft_ig_lib import (
    latest_ckpt, move_to, ensure_dir,
    names_from_column_definition, unpack_batch_for_io, select_single_sample,
    integrated_gradients_single_input, topk_indices_by_total_abs, plot_time_curves,
)

# ---- Config ----
VERSION_DIR, DEVICE = "lightning_logs/tft/version_E3", "cpu"
SAMPLE_IDX, IG_STEPS, TOPK = 0, 32, 8
SAVE_DIR, ZERO_BASELINE = "Image/IG", True

def extract_version_suffix(version_dir: str) -> str:
    base = os.path.basename(version_dir.rstrip("/"))
    m = re.search(r"version[_\-]?(.+)$", base)
    return m.group(1) if m else base


def main():
    device = torch.device(DEVICE)

    # --- Restore model with EXACT hparams from checkpoint (avoid shape mismatches) ---
    ckpt_path = latest_ckpt(VERSION_DIR)
    ckpt = torch.load(ckpt_path, map_location=device)
    raw_hp = ckpt.get("hyper_parameters", {})
    raw_hp = raw_hp.get("hparams", raw_hp) if isinstance(raw_hp, dict) else raw_hp
    hparams = SimpleNamespace(**raw_hp) if isinstance(raw_hp, dict) else raw_hp

    model = TemporalFusionTransformer(hparams).to(device)
    _ = model.load_state_dict(ckpt["state_dict"], strict=False)
    model.eval()
    try:
        model.log = lambda *a, **k: None
    except Exception:
        pass

    # --- Use the dataloader built from the SAME experiment config ---
    try:
        loader = model.val_dataloader()
    except Exception:
        loader = None
    if loader is None:
        try:
            loader = model.train_dataloader()
        except Exception:
            raise RuntimeError("No val/train dataloader available.")

    # --- Preflight: run one forward via _prepare_tft_inputs to validate shapes ---
    batch = move_to(next(iter(loader)), device)
    packed = unpack_batch_for_io(batch)
    try:
        all_inputs, _ = model._prepare_tft_inputs(packed)
        _ = model(all_inputs)  # single forward to ensure loader matches checkpoint
    except Exception as e:
        raise RuntimeError(
            "DataLoader/ckpt feature configuration mismatch. "
            "Please ensure this loader was built with the SAME column_definition/data_formatter as the checkpoint."
        ) from e

    # --- Pick one sample from the batch for IG ---
    single = select_single_sample(packed, SAMPLE_IDX)

    # --- Variable names (for plotting legends); will be padded below if needed ---
    obs_names, knw_names = names_from_column_definition(model)

    # --- Integrated Gradients for KNOWN/OBSERVED streams ---
    attr_known = integrated_gradients_single_input(
        model, single, which="known", steps=IG_STEPS, zero_baseline=ZERO_BASELINE, device=device
    )
    attr_observed = integrated_gradients_single_input(
        model, single, which="observed", steps=IG_STEPS, zero_baseline=ZERO_BASELINE, device=device
    )

    # --- Ensure we have names for every channel that appears in attributions ---
    if attr_known is not None:
        V_k = attr_known.shape[-1]
        if not knw_names or len(knw_names) < V_k:
            base = len(knw_names or [])
            knw_names = (knw_names or []) + [f"known_{i}" for i in range(base, V_k)]
    if attr_observed is not None:
        V_h = attr_observed.shape[-1]
        if not obs_names or len(obs_names) < V_h:
            base = len(obs_names or [])
            obs_names = (obs_names or []) + [f"observed_{i}" for i in range(base, V_h)]

    # --- Plot & save ---
    ver = extract_version_suffix(VERSION_DIR)  # e.g., "E1"
    ensure_dir(SAVE_DIR)

    plot_time_curves(
    attr_known,
    knw_names,
    topk_indices_by_total_abs(attr_known, TOPK),
    f"IG — KNOWN (version={ver}) · sample {SAMPLE_IDX}",
    f"{SAVE_DIR}/ig_known_sample{SAMPLE_IDX}_faceted.png",
    # optional kwargs for faceting:
    types_map=None,     # auto infer: Price / Fossil / Renewable / Weather / Other
    col_wrap=2,         # 2 columns of subplots; change to 3 if you有很多类型
    sharey=False        # each facet has its own y-scale (更易读)
)

    plot_time_curves(
    attr_observed,
    obs_names,
    topk_indices_by_total_abs(attr_observed, TOPK),
    f"IG — OBSERVED (version={ver}) · sample {SAMPLE_IDX}",
    f"{SAVE_DIR}/ig_observed_sample{SAMPLE_IDX}_faceted.png",
    types_map=None,
    col_wrap=2,
    sharey=False
      )


if __name__ == "__main__":
    main()


*** TemporalFusionTransformer params ***
# total_time_steps = 192
# num_encoder_steps = 168
# num_epochs = 100
# early_stopping_patience = 10
# multiprocessing_workers = 5
# column_definition = [('identifier', <DataTypes.CATEGORICAL: 1>, <InputTypes.ID: 4>), ('days_from_start', <DataTypes.REAL_VALUED: 0>, <InputTypes.TIME: 5>), ('err_target', <DataTypes.REAL_VALUED: 0>, <InputTypes.TARGET: 0>), ('hour', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('day_of_week', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('forecast solar day ahead', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('forecast wind onshore day ahead', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('price day ahead', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('temp_weighted', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('temp_min_weighted', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('temp_max_weighted', <DataTypes.REAL_VALUED: 

# Only For E3

## Total Load Forecast Loss

In [42]:
# 1. Load the dataset
df = pd.read_csv("data/electricty.csv")

# ⚠️ Replace with your actual column names
y_true = torch.tensor(df["target"].values, dtype=torch.float32)
y_pred = torch.tensor(df["total load forecast"].values, dtype=torch.float32)


# Assume you trained with three quantiles
quantiles = [0.1, 0.5, 0.9]
output_size = 1   # one prediction value per time step

ql_calc = QuantileLossCalculator(quantiles, output_size)

# To match the input format of QuantileLossCalculator,
# we need to concatenate y_true / y_pred along quantile dimension: [N, output_size * num_quantiles]
a_targets = torch.cat([y_true.unsqueeze(1) for _ in quantiles], dim=1)  # [N, len(quantiles)]
b_preds   = torch.cat([y_pred.unsqueeze(1) for _ in quantiles], dim=1)  # [N, len(quantiles)]

# 3. Compute the overall test loss (mean across all time steps and quantiles)
test_loss = ql_calc.apply(b_preds, a_targets)
print("Overall test loss:", test_loss.item())

# 4. Optionally, compute Normalized Quantile Loss (NQL) for each quantile
nql_calc = NormalizedQuantileLossCalculator(quantiles, output_size)
for tau in quantiles:
    nql = nql_calc.apply(y_true, y_pred, tau)
    print(f"Normalized Quantile Loss (tau={tau}): {nql.item()}")


Overall test loss: 474.1687927246094
Normalized Quantile Loss (tau=0.1): 0.011410183273255825
Normalized Quantile Loss (tau=0.5): 0.01101555023342371
Normalized Quantile Loss (tau=0.9): 0.01062091439962387


## E3 Loss

In [25]:
from types import SimpleNamespace
# ===== Config =====
CKPT_DIR = "lightning_logs/tft/version_E3"
CSV_PATH = "data/electricty.csv"   # CSV with hourly raw target / total load forecast (TLF)
BATCH_SIZE = 256
NUM_WORKERS = 2

# ===== Utilities =====
def find_ckpt(d):
    # Prefer best/last, otherwise pick the newest .ckpt file
    for n in ("best-checkpoint.ckpt", "last.ckpt"):
        p = os.path.join(d, "checkpoints", n)
        if os.path.exists(p): return p
    cands = glob.glob(os.path.join(d, "checkpoints", "*.ckpt"))
    if not cands: raise FileNotFoundError(f"No ckpt in {d}/checkpoints")
    return max(cands, key=os.path.getmtime)

def to_ns(x):
    # Recursively convert dict to SimpleNamespace
    return SimpleNamespace(**{k: to_ns(v) if isinstance(v, dict) else v for k, v in x.items()}) if isinstance(x, dict) else x

def pick_col(df, cands):
    # Pick the first existing column name from candidates
    for c in cands:
        if c in df.columns: return c
    raise KeyError(f"None of {cands} in columns: {list(df.columns)[:20]}")

def norm_hour(series):
    # Normalize hour to 0..23 (handles 1..24 inputs as well)
    s = pd.to_numeric(series, errors="coerce").fillna(-1).astype(int)
    return ((s - 1) % 24 if s.min() >= 1 and s.max() >= 24 else s % 24).astype(int)

# ===== 1) Restore model (your __init__ takes positional hparams) =====
ckpt = torch.load(find_ckpt(CKPT_DIR), map_location="cpu")
raw_hp = ckpt.get("hyper_parameters", {})
raw_hp = raw_hp.get("hparams", raw_hp) if isinstance(raw_hp, dict) else raw_hp
hparams = to_ns(raw_hp)

tft = TemporalFusionTransformer(hparams)
tft.load_state_dict(ckpt["state_dict"], strict=False)
tft.eval()

# Key info
is_err = bool(getattr(hparams, "err_target", False))   # True: model outputs residual err_target
nd     = int(hparams.total_time_steps - hparams.num_encoder_steps)
dec    = slice(-nd, None)

# ===== 2) Forward pass on your existing test_dataset =====
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# ===== 3) Rebuild decoder keys (id, day, hour) keeping original order (do NOT sort df_all) =====
df_all = test_dataset.data.reset_index(drop=True).copy()
idx_df = test_dataset.data_index.reset_index(drop=True)
N = len(idx_df)

ID_COL  = pick_col(df_all, ["identifier","targetidentifier","id","series_id"])
DAY_COL = pick_col(df_all, ["days_from_start","day_from_start","days"])
H_COL   = pick_col(df_all, ["hour","hours","hour_of_day","hr"])

df_all["_hour_norm"] = norm_hour(df_all[H_COL])
df_all["_k"] = df_all.groupby([ID_COL, DAY_COL]).cumcount()  # within-day order in original sequence

ids_mat, day_mat, hour_mat, k_mat = [], [], [], []
for r in idx_df.itertuples():
    rng = np.arange(r.end_abs - nd, r.end_abs)  # decoder slice for this sample
    ids_mat.append(  df_all.loc[rng, ID_COL].to_numpy() )
    day_mat.append(  df_all.loc[rng, DAY_COL].to_numpy() )
    hour_mat.append( df_all.loc[rng, "_hour_norm"].to_numpy() )
    k_mat.append(    df_all.loc[rng, "_k"].to_numpy() )

ids_mat  = np.stack(ids_mat,  axis=0)  # [N, nd]
day_mat  = np.stack(day_mat,  axis=0)
hour_mat = np.stack(hour_mat, axis=0)
k_mat    = np.stack(k_mat,    axis=0)

# ===== 4) Pull raw target and TLF from CSV and align by (id, day, hour) =====
df_csv_raw = pd.read_csv(CSV_PATH)
ID_C   = pick_col(df_csv_raw, ["identifier","targetidentifier","id","series_id"])
DAY_C  = pick_col(df_csv_raw, ["days_from_start","day_from_start","days"])
H_C    = pick_col(df_csv_raw, ["hour","hours","hour_of_day","hr"])
TGT_C  = pick_col(df_csv_raw, ["target","power_usage","load","y"])
TLF_C  = pick_col(df_csv_raw, ["total load forecast","total_load_forecast","tlf","base"])

df_csv = df_csv_raw[[ID_C, DAY_C, H_C, TGT_C, TLF_C]].copy()
df_csv["_hour_norm"] = norm_hour(df_csv[H_C])
df_csv = df_csv.rename(columns={ID_C:"id", DAY_C:"day"})
df_csv = df_csv.dropna(subset=["day"])
df_csv["id"]  = df_csv["id"].astype(str)
df_csv["day"] = pd.to_numeric(df_csv["day"], errors="coerce").astype(int)
df_csv = df_csv.drop_duplicates(["id","day","_hour_norm"])

key_df = pd.DataFrame({
    "id":  ids_mat.reshape(-1).astype(str),
    "day": day_mat.reshape(-1).astype(np.int64),
    "hn":  hour_mat.reshape(-1).astype(np.int64),
})
right = df_csv.rename(columns={"_hour_norm":"hn"})[["id","day","hn", TGT_C, TLF_C]]
joined = key_df.merge(right, on=["id","day","hn"], how="left")

y_true_flat = joined[TGT_C].to_numpy()
tlf_flat    = joined[TLF_C].to_numpy()

# Fallback: if many missings, align by within-day order k (keep df_all order untouched)
miss = np.isnan(y_true_flat).sum() + np.isnan(tlf_flat).sum()
if miss > 0.1 * y_true_flat.size:
    df_csv_sorted = df_csv.sort_values(["id","day","_hour_norm"]).copy()
    df_csv_sorted["_k"] = df_csv_sorted.groupby(["id","day"]).cumcount()
    key_df_k = pd.DataFrame({
        "id":  ids_mat.reshape(-1).astype(str),
        "day": day_mat.reshape(-1).astype(np.int64),
        "k":   k_mat.reshape(-1).astype(np.int64),
    })
    right_k = df_csv_sorted[["id","day","_k", TGT_C, TLF_C]].rename(columns={"_k":"k"})
    joined  = key_df_k.merge(right_k, on=["id","day","k"], how="left")
    y_true_flat = joined[TGT_C].to_numpy()
    tlf_flat    = joined[TLF_C].to_numpy()

expected = N * nd
assert y_true_flat.size == expected and tlf_flat.size == expected, "Alignment size mismatch"
y_true = y_true_flat.reshape(N, nd)
tlf    = tlf_flat.reshape(N, nd)

# ===== 5) Forward pass: get err_target or direct outputs (no inverse-scaling, no mapping) =====
outs, taus, Q = [], None, None
with torch.no_grad():
    for batch in test_loader:
        (x_known, x_observed, x_static), _ = batch
        if hasattr(tft, "_prepare_tft_inputs"):
            all_inputs, _ = tft._prepare_tft_inputs(batch)
            out = tft(all_inputs)
        else:
            out = tft(x_known, x_observed, x_static)
        if out.dim() == 2: out = out.unsqueeze(-1)     # [B, nd, 1]
        if Q is None:
            Q = out.size(-1)
            hp_q = getattr(getattr(tft, "hparams", None), "quantiles", None)
            taus = list(hp_q) if isinstance(hp_q, (list, tuple)) and len(hp_q)==Q else ([0.5] if Q==1 else [0.1,0.5,0.9])
        outs.append(out.cpu())
outs = torch.cat(outs, dim=0).numpy()                  # [N, nd, Q]

# Combine: err_target + TLF  (or use out directly)
y_hat_q = (outs + tlf[:, :, None]) if is_err else outs  # no scaling/mapping applied

# ===== 6) Compute QuantileLoss / NQL using your calculators =====
ql_calc  = getattr(tft, "train_criterion", None)
nql_calc = getattr(tft, "test_criterion",  None)
if ql_calc is None or nql_calc is None:
    raise RuntimeError("Missing QuantileLossCalculator / NormalizedQuantileLossCalculator (tft.train_criterion / tft.test_criterion).")

y_true_flat2 = y_true.reshape(-1)
y_hat_flat_q = y_hat_q.reshape(-1, y_hat_q.shape[-1])

targets_q = torch.tensor(np.repeat(y_true_flat2[:, None], y_hat_flat_q.shape[1], axis=1), dtype=torch.float32)
preds_q   = torch.tensor(y_hat_flat_q, dtype=torch.float32)
y_true_t  = torch.tensor(y_true_flat2, dtype=torch.float32)

print(f"[TEST] QuantileLoss (mean over taus): {ql_calc.apply(preds_q, targets_q).item():.6f}")
for j, tau in enumerate(taus):
    print(f"[TEST] NQL@{tau}: {nql_calc.apply(y_true_t, preds_q[:, j], tau).item():.6f}")


*** TemporalFusionTransformer params ***
# total_time_steps = 192
# num_encoder_steps = 168
# num_epochs = 100
# early_stopping_patience = 10
# multiprocessing_workers = 5
# column_definition = [('identifier', <DataTypes.CATEGORICAL: 1>, <InputTypes.ID: 4>), ('days_from_start', <DataTypes.REAL_VALUED: 0>, <InputTypes.TIME: 5>), ('err_target', <DataTypes.REAL_VALUED: 0>, <InputTypes.TARGET: 0>), ('hour', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('day_of_week', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('forecast solar day ahead', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('forecast wind onshore day ahead', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('price day ahead', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('temp_weighted', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('temp_min_weighted', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('temp_max_weighted', <DataTypes.REAL_VALUED: 