# 4.3 — Temporal Fusion Transformer: Load Demand
Point predictions with interpretable attention. 24h ahead, trained 2015–2017, tested 2018.

In [1]:
import pandas as pd
import numpy as np
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

df = pd.read_parquet('../cleaned_data.parquet')
df['time'] = pd.to_datetime(df['time'], utc=True)

# Use MPS (Apple Silicon GPU) if available
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
try:
    t = torch.randn(2, 2, device=device)
    _ = t @ t
except:
    device = torch.device('cpu')

print(f"Shape: {df.shape}")
print(f"PyTorch: {torch.__version__}")
print(f"Device: {device}")

Shape: (35056, 80)
PyTorch: 2.10.0
Device: mps


Prepare features and normalize using training stats

In [2]:
target_col = 'total load actual'
tso_col = 'total load forecast'

weather_cols = [
    'temp_madrid', 'temp_bilbao', 'temp_barcelona',
    'temp_seville', 'temp_valencia',
    'humidity_madrid', 'humidity_bilbao', 'humidity_barcelona',
    'humidity_seville', 'humidity_valencia',
    'pressure_madrid', 'pressure_bilbao', 'pressure_barcelona',
    'pressure_seville', 'pressure_valencia',
]
time_cols = ['hour', 'month']
feature_cols = weather_cols + time_cols

# Normalize using training stats only
train_mask = df['time'].dt.year <= 2017

target_mean = df.loc[train_mask, target_col].mean()
target_std = df.loc[train_mask, target_col].std()

feat_means = df.loc[train_mask, feature_cols].mean()
feat_stds = df.loc[train_mask, feature_cols].std().replace(0, 1)

target_norm = (df[target_col].values - target_mean) / target_std
features_norm = ((df[feature_cols] - feat_means) / feat_stds).fillna(0).values

# Combine: [target, features]
all_data = np.column_stack([target_norm, features_norm]).astype(np.float32)

print(f"Input channels: {all_data.shape[1]} (1 target + {len(feature_cols)} features)")
print(f"Target mean: {target_mean:.0f} MW, std: {target_std:.0f} MW")

Input channels: 18 (1 target + 17 features)
Target mean: 28576 MW, std: 4552 MW


Sliding window dataset — 168h context, 24h prediction

In [3]:
context_length = 168   # 7 days of history
prediction_length = 24  # 24h ahead

class TimeSeriesDataset(Dataset):
    def __init__(self, data, ctx_len, pred_len, start_idx, end_idx):
        self.data = data
        self.ctx_len = ctx_len
        self.pred_len = pred_len
        self.start = start_idx
        self.end = end_idx

    def __len__(self):
        return self.end - self.start - self.ctx_len - self.pred_len + 1

    def __getitem__(self, idx):
        i = self.start + idx
        # Observed context: target + all features
        x_observed = self.data[i : i + self.ctx_len]  # (ctx_len, 1+num_features)
        # Future known: features only (no target)
        x_future = self.data[i + self.ctx_len : i + self.ctx_len + self.pred_len, 1:]  # (pred_len, num_features)
        # Target to predict
        y = self.data[i + self.ctx_len : i + self.ctx_len + self.pred_len, 0]  # (pred_len,)
        return (
            torch.from_numpy(x_observed),
            torch.from_numpy(x_future),
            torch.from_numpy(y),
        )

train_end = int(train_mask.sum())
val_split = int(train_end * 0.8)

train_ds = TimeSeriesDataset(all_data, context_length, prediction_length, 0, val_split)
val_ds = TimeSeriesDataset(all_data, context_length, prediction_length, val_split, train_end)
train_loader = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=0)
val_loader = DataLoader(val_ds, batch_size=256, shuffle=False, num_workers=0)

print(f"Train samples: {len(train_ds)}, Val samples: {len(val_ds)}")
print(f"Context: {context_length}h, Prediction: {prediction_length}h")

Train samples: 20846, Val samples: 5069
Context: 168h, Prediction: 24h


TFT model — variable selection, gated residual networks, LSTM, interpretable multi-head attention

In [4]:
class GatedResidualNetwork(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.gate_fc = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(output_size)
        self.skip = nn.Linear(input_size, output_size) if input_size != output_size else nn.Identity()

    def forward(self, x):
        residual = self.skip(x)
        h = F.elu(self.fc1(x))
        h = self.dropout(h)
        output = self.fc2(h)
        gate = torch.sigmoid(self.gate_fc(h))
        return self.layer_norm(gate * output + residual)


class VariableSelectionNetwork(nn.Module):
    def __init__(self, num_vars, d_model, dropout=0.1):
        super().__init__()
        self.num_vars = num_vars
        self.d_model = d_model
        self.var_transforms = nn.ModuleList([nn.Linear(1, d_model) for _ in range(num_vars)])
        self.weight_network = GatedResidualNetwork(num_vars * d_model, d_model, num_vars, dropout)

    def forward(self, x):
        var_outputs = []
        for i in range(self.num_vars):
            var_outputs.append(self.var_transforms[i](x[:, :, i:i+1]))
        var_stack = torch.stack(var_outputs, dim=2)
        flat = var_stack.reshape(x.shape[0], x.shape[1], -1)
        weights = F.softmax(self.weight_network(flat), dim=-1)
        selected = (var_stack * weights.unsqueeze(-1)).sum(dim=2)
        return selected, weights


class InterpretableMultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, self.d_k)
        self.out_proj = nn.Linear(self.d_k, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        bs = q.size(0)
        Q = self.W_q(q).view(bs, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(k).view(bs, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(v).unsqueeze(1).expand(-1, self.n_heads, -1, -1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        context = torch.matmul(attn, V)
        context = context.mean(dim=1)
        return self.out_proj(context), attn


class TemporalFusionTransformer(nn.Module):
    def __init__(self, num_observed, num_known_future, d_model=32, n_heads=4,
                 n_lstm_layers=1, pred_len=24, dropout=0.1):
        super().__init__()
        self.pred_len = pred_len
        self.d_model = d_model
        self.obs_vsn = VariableSelectionNetwork(num_observed, d_model, dropout)
        self.fut_vsn = VariableSelectionNetwork(num_known_future, d_model, dropout)
        self.encoder_lstm = nn.LSTM(d_model, d_model, n_lstm_layers, batch_first=True,
                                     dropout=dropout if n_lstm_layers > 1 else 0)
        self.decoder_lstm = nn.LSTM(d_model, d_model, n_lstm_layers, batch_first=True,
                                     dropout=dropout if n_lstm_layers > 1 else 0)
        self.lstm_gate = GatedResidualNetwork(d_model, d_model, d_model, dropout)
        self.attention = InterpretableMultiHeadAttention(d_model, n_heads, dropout)
        self.attn_gate = GatedResidualNetwork(d_model, d_model, d_model, dropout)
        self.output_proj = nn.Linear(d_model, 1)

    def forward(self, x_observed, x_future):
        enc_selected, enc_weights = self.obs_vsn(x_observed)
        dec_selected, dec_weights = self.fut_vsn(x_future)
        enc_out, (h, c) = self.encoder_lstm(enc_selected)
        dec_out, _ = self.decoder_lstm(dec_selected, (h, c))
        lstm_out = torch.cat([enc_out, dec_out], dim=1)
        input_cat = torch.cat([enc_selected, dec_selected], dim=1)
        lstm_out = self.lstm_gate(lstm_out) + input_cat
        attn_out, attn_weights = self.attention(lstm_out, lstm_out, lstm_out)
        attn_out = self.attn_gate(attn_out) + lstm_out
        decoder_out = attn_out[:, -self.pred_len:, :]
        output = self.output_proj(decoder_out).squeeze(-1)
        return output, enc_weights, attn_weights


num_observed = all_data.shape[1]  # target + features
num_known_future = len(feature_cols)  # features only (no target)

model = TemporalFusionTransformer(
    num_observed=num_observed,
    num_known_future=num_known_future,
    d_model=16,
    n_heads=4,
    n_lstm_layers=1,
    pred_len=prediction_length,
    dropout=0.1,
).to(device)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Observed vars: {num_observed}, Future vars: {num_known_future}")
print(f"On device: {device}")

Model parameters: 27,972
Observed vars: 18, Future vars: 17
On device: mps


Train with MSE loss, early stopping (patience=15)

In [5]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
criterion = nn.MSELoss()

n_epochs = 100
patience = 15
best_val_loss = float('inf')
best_state = None
epochs_no_improve = 0

for epoch in range(n_epochs):
    # Train
    model.train()
    train_losses = []
    for x_obs, x_fut, y in train_loader:
        x_obs, x_fut, y = x_obs.to(device), x_fut.to(device), y.to(device)
        preds, _, _ = model(x_obs, x_fut)
        loss = criterion(preds, y)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        train_losses.append(loss.item())
    
    # Validate
    model.eval()
    val_losses = []
    with torch.no_grad():
        for x_obs, x_fut, y in val_loader:
            x_obs, x_fut, y = x_obs.to(device), x_fut.to(device), y.to(device)
            preds, _, _ = model(x_obs, x_fut)
            val_losses.append(criterion(preds, y).item())
    
    scheduler.step()
    train_loss = np.mean(train_losses)
    val_loss = np.mean(val_losses)
    
    # Early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_state = {k: v.clone() for k, v in model.state_dict().items()}
        epochs_no_improve = 0
        marker = ' *'
    else:
        epochs_no_improve += 1
        marker = ''
    
    if (epoch + 1) % 10 == 0 or epoch == 0 or epochs_no_improve == 0:
        print(f"Epoch {epoch+1:3d}/{n_epochs}, Train: {train_loss:.5f}, Val: {val_loss:.5f}, LR: {scheduler.get_last_lr()[0]:.6f}{marker}")
    
    if epochs_no_improve >= patience:
        print(f"\nEarly stopping at epoch {epoch+1} (no improvement for {patience} epochs)")
        break

# Restore best weights
if best_state is not None:
    model.load_state_dict(best_state)
    print(f"Restored best model (val loss: {best_val_loss:.5f})")

Epoch   1/100, Train: 0.52457, Val: 0.42746, LR: 0.001000 *


Epoch   2/100, Train: 0.36147, Val: 0.34672, LR: 0.000999 *


Epoch   3/100, Train: 0.32714, Val: 0.32779, LR: 0.000998 *


Epoch   4/100, Train: 0.29998, Val: 0.32210, LR: 0.000996 *


Epoch   5/100, Train: 0.28557, Val: 0.32012, LR: 0.000994 *


Epoch   6/100, Train: 0.27168, Val: 0.29128, LR: 0.000991 *


Epoch   9/100, Train: 0.24883, Val: 0.27844, LR: 0.000980 *


Epoch  10/100, Train: 0.23933, Val: 0.27779, LR: 0.000976 *


Epoch  20/100, Train: 0.15997, Val: 0.31750, LR: 0.000905



Early stopping at epoch 25 (no improvement for 15 epochs)
Restored best model (val loss: 0.27779)


Generate 24h-ahead forecasts on 2018 test set

In [6]:
model.eval()
test_start = train_end
test_end = len(all_data)

all_preds = []
all_actuals = []
all_tso = []
all_times = []

with torch.no_grad():
    for i in range(test_start, test_end - prediction_length, prediction_length):
        if i - context_length < 0:
            continue

        x_obs = torch.from_numpy(all_data[i - context_length : i]).unsqueeze(0).to(device)
        x_fut = torch.from_numpy(all_data[i : i + prediction_length, 1:]).unsqueeze(0).to(device)

        preds, _, _ = model(x_obs, x_fut)

        # Denormalize
        pred_mw = preds.squeeze().cpu().numpy() * target_std + target_mean
        actual_mw = all_data[i : i + prediction_length, 0] * target_std + target_mean
        tso_mw = df[tso_col].iloc[i : i + prediction_length].values
        times = df['time'].iloc[i : i + prediction_length].values

        all_preds.append(pred_mw)
        all_actuals.append(actual_mw)
        all_tso.append(tso_mw)
        all_times.append(times)

print(f"Generated {len(all_preds)} forecast windows across 2018")

Generated 364 forecast windows across 2018


Evaluate: MAE, RMSE, MAPE vs TSO baseline

In [7]:
preds_flat = np.concatenate(all_preds)
actuals_flat = np.concatenate(all_actuals)
tso_flat = np.concatenate(all_tso)

mae_tft = np.mean(np.abs(actuals_flat - preds_flat))
rmse_tft = np.sqrt(np.mean((actuals_flat - preds_flat) ** 2))
mape_tft = np.mean(np.abs((actuals_flat - preds_flat) / actuals_flat)) * 100

mae_tso = np.mean(np.abs(actuals_flat - tso_flat))
rmse_tso = np.sqrt(np.mean((actuals_flat - tso_flat) ** 2))
mape_tso = np.mean(np.abs((actuals_flat - tso_flat) / actuals_flat)) * 100

print(f"{'Metric':<10} {'TFT':>10} {'TSO':>10}")
print(f"{'MAE (MW)':<10} {mae_tft:>10.0f} {mae_tso:>10.0f}")
print(f"{'RMSE (MW)':<10} {rmse_tft:>10.0f} {rmse_tso:>10.0f}")
print(f"{'MAPE (%)':<10} {mape_tft:>10.1f} {mape_tso:>10.1f}")

Metric            TFT        TSO
MAE (MW)         2128        270
RMSE (MW)        2720        390
MAPE (%)          7.4        0.9


XGBoost residual correction — per-horizon models trained on validation-set TFT errors

In [None]:
fig, ax = plt.subplots(figsize=(14, 5))

sample_windows = range(9, 16)
for w in sample_windows:
    if w >= len(all_preds):
        break
    hours = range(w * prediction_length, (w + 1) * prediction_length)
    ax.plot(hours, all_actuals[w], color='#1a1a2e', linewidth=1.5,
            label='Actual' if w == 9 else None)
    ax.plot(hours, corrected_preds[w], color='#e76f51', linewidth=1.5,
            label='TFT + XGBoost' if w == 9 else None)
    ax.plot(hours, all_tso[w], color='grey', linewidth=1, linestyle='--',
            label='TSO forecast' if w == 9 else None)

ax.set_xlabel('Hour')
ax.set_ylabel('MW')
ax.set_title('TFT Load Demand')
ax.legend()
plt.tight_layout()
plt.show()

Predicted vs actual — sample week

In [None]:
import os
os.makedirs('../dashboard/public/data', exist_ok=True)

sample_data = []
for w in sample_windows:
    if w >= len(all_preds):
        break
    for h in range(prediction_length):
        t = pd.Timestamp(all_times[w][h])
        sample_data.append({
            'time': t.strftime('%Y-%m-%d %H:%M'),
            'actual': round(float(all_actuals[w][h]), 1),
            'predicted': round(float(corrected_preds[w][h]), 1),
            'tso': round(float(all_tso[w][h]), 1),
        })

output = {
    'target': 'load',
    'model': 'TFT + XGBoost Residual Correction',
    'prediction_length_hours': prediction_length,
    'context_length_hours': context_length,
    'metrics': {
        'mae': round(float(corr_mae), 1),
        'rmse': round(float(corr_rmse), 1),
        'mape': round(float(corr_mape), 1),
        'tso_mae': round(float(mae_tso), 1),
        'tso_rmse': round(float(rmse_tso), 1),
        'raw_mae': round(float(mae_tft), 1),
        'raw_rmse': round(float(rmse_tft), 1),
    },
    'sample_forecast': sample_data,
}

with open('../dashboard/public/data/tft_load.json', 'w') as f:
    json.dump(output, f, indent=2)

print('Saved tft_load.json')
print(f"Raw  MAE: {mae_tft:.0f} MW → Corrected MAE: {corr_mae:.0f} MW (TSO: {mae_tso:.0f} MW)")

Export JSON for dashboard

In [9]:
import os
os.makedirs('../dashboard/public/data', exist_ok=True)

sample_data = []
for w in sample_windows:
    if w >= len(all_preds):
        break
    for h in range(prediction_length):
        t = pd.Timestamp(all_times[w][h])
        sample_data.append({
            'time': t.strftime('%Y-%m-%d %H:%M'),
            'actual': round(float(all_actuals[w][h]), 1),
            'predicted': round(float(all_preds[w][h]), 1),
            'tso': round(float(all_tso[w][h]), 1),
        })

output = {
    'target': 'load',
    'model': 'TFT (Temporal Fusion Transformer)',
    'prediction_length_hours': prediction_length,
    'context_length_hours': context_length,
    'metrics': {
        'mae': round(float(mae_tft), 1),
        'rmse': round(float(rmse_tft), 1),
        'mape': round(float(mape_tft), 1),
        'tso_mae': round(float(mae_tso), 1),
        'tso_rmse': round(float(rmse_tso), 1),
        'tso_mape': round(float(mape_tso), 1),
    },
    'sample_forecast': sample_data,
}

with open('../dashboard/public/data/tft_load.json', 'w') as f:
    json.dump(output, f, indent=2)

print('Saved tft_load.json')
print(f"TFT  MAE: {output['metrics']['mae']} MW | RMSE: {output['metrics']['rmse']} MW | MAPE: {output['metrics']['mape']}%")
print(f"TSO  MAE: {output['metrics']['tso_mae']} MW | RMSE: {output['metrics']['tso_rmse']} MW | MAPE: {output['metrics']['tso_mape']}%")

Saved tft_load.json
TFT  MAE: 2128.2 MW | RMSE: 2720.1 MW | MAPE: 7.4%
TSO  MAE: 270.1 MW | RMSE: 389.7 MW | MAPE: 0.9%
