In [None]:
import xarray as xr
import ocf_blosc2
import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings("ignore")
from tqdm import tqdm
import lightning.pytorch as pl
import pytorch_forecasting as pf
from pytorch_forecasting.metrics import QuantileLoss
from lightning.pytorch import Trainer
from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
import torch
from pytorch_forecasting.models.temporal_fusion_transformer import TemporalFusionTransformer
import seaborn as sns
import matplotlib.pyplot as plt
from lightning.pytorch.tuner.tuning import Tuner

In [None]:
# Raw inputs: NWP forecasts (zarr store), PV site metadata (CSV) and PV generation (netCDF).
nwp_path = "../../../mnt/disks/gcp_data/nwp/ecmwf/UK_v2.zarr"
metadata_path = "data_files/metadata.csv"
pv_path = "data_files/pv.netcdf"

nwp_data = xr.open_dataset(nwp_path)
meta_data = pd.read_csv(metadata_path)
pv_data = xr.open_dataset(pv_path, engine='h5netcdf')

In [None]:
skip_ss_ids = ['8440', '16718', '8715', '17073', '9108', '9172', '10167', '10205', '10207', '10278', '26778', '26819', '10437', '10466', '26915', '10547', '26939', '26971', '10685', '10689', '2638', '2661', '2754', '2777', '2783', '2786', '2793', '2812', '2829', '2830', '2867', '2883', '2904', '2923', '2947', '2976', '2989', '2999', '3003', '3086', '3118', '3123', '3125', '3264', '3266', '3271', '3313', '3334', '3470', '3502', '11769', '11828', '11962', '3772', '11983', '3866', '3869', '4056', '4067', '4116', '4117', '4124', '4323', '4420', '20857', '4754', '13387', '13415', '5755', '5861', '5990', '6026', '6038', '6054', '14455', '6383', '6430', '6440', '6478', '6488', '6541', '6548', '6560', '14786', '6630', '6804', '6849', '6868', '6870', '6878', '6901', '6971', '7055', '7111', '7124', '7132', '7143', '7154', '7155', '7156', '7158', '7201', '7237', '7268', '7289', '7294', '7311', '7329', '7339', '7379', '7392', '7479', '7638', '7695', '7772', '15967', '7890', '16215', '7830']
# NOTE(review): the list above excludes specific PV systems from sampling; its
# provenance/criteria are not visible here — document where it comes from.
num_sites = 500
random_seed = 42
# Seed numpy's global RNG so the site sample — and every later np.random call in
# this notebook — is reproducible under Restart Kernel -> Run All.
np.random.seed(random_seed)

# Keep only on-the-hour readings.
hourly_pv_data = pv_data.sel(datetime=pv_data['datetime'].dt.minute == 0)
skip_ss_id_set = set(skip_ss_ids)  # O(1) membership checks
valid_ss_ids_data = [var for var in hourly_pv_data.data_vars if var not in skip_ss_id_set]
# Sample without replacement; cap the sample size so a smaller dataset does not
# make np.random.choice raise ValueError.
pv_sites_id = np.random.choice(valid_ss_ids_data, min(num_sites, len(valid_ss_ids_data)), replace=False)
filtered_hourly_pv_data = hourly_pv_data[pv_sites_id]

In [None]:
def get_36_hour_range(start_datetime, hours=36):
    """Return inclusive ``(start, end)`` bounds for an ``hours``-long window.

    The end bound is one minute short of ``start_datetime + hours``. For
    on-the-hour timestamps, an inclusive label slice between the two bounds
    therefore covers ``hours`` hourly rows.
    """
    span = pd.Timedelta(hours=hours) - pd.Timedelta(minutes=1)
    return start_datetime, start_datetime + span

def select_non_overlapping_datetimes(datetimes, num_selections, min_gap_hours):
    """Randomly pick up to ``num_selections`` window start times that are mutually spaced.

    Every pair of returned datetimes is more than ``min_gap_hours`` apart, so
    windows of that length starting at them do not overlap. Candidates are
    excluded on BOTH sides of each pick. The previous version dropped every
    earlier datetime after each pick, which pushed picks toward the end of the
    range and typically returned only a handful of windows.

    Args:
        datetimes: iterable of timestamps (anything ``pd.DatetimeIndex`` accepts).
        num_selections: maximum number of datetimes to return.
        min_gap_hours: required spacing, in hours, between any two picks.

    Returns:
        A chronologically sorted list of ``pd.Timestamp`` taken from ``datetimes``.
        It may be shorter than ``num_selections`` when the range is exhausted.
    """
    index = pd.DatetimeIndex(datetimes)
    min_gap = pd.Timedelta(hours=min_gap_hours)
    available = np.ones(len(index), dtype=bool)
    picked_positions = []
    while len(picked_positions) < num_selections:
        candidates = np.flatnonzero(available)
        if candidates.size == 0:
            break
        position = candidates[np.random.randint(candidates.size)]
        picked_positions.append(position)
        # Vectorized exclusion of every datetime within min_gap of the pick.
        available &= np.asarray(abs(index - index[position]) > min_gap)
    return sorted(index[np.asarray(picked_positions, dtype=np.intp)])

In [None]:
# Build a long table of complete, gap-free 36-hour generation windows for every
# sampled site: one row per (site, hour) with horizon 1..batch_size in each window.
datetimes = pd.to_datetime(filtered_hourly_pv_data['datetime'].values)
data_dict = {'ss_id': [], 'pv_datetime': [], 'generation' : [], 'horizon':[]}
batch_size = 36
num_selections = 5000
min_gap_hours = 36

for ss_id in pv_sites_id:
    selected_datetimes = select_non_overlapping_datetimes(datetimes, num_selections, min_gap_hours)
    for start_datetime in selected_datetimes:
        start, end = get_36_hour_range(start_datetime, hours=batch_size)
        selected_data = hourly_pv_data.sel(datetime=slice(start, end))
        # Keep only windows with exactly batch_size hourly rows and no missing generation.
        if len(selected_data['datetime']) != batch_size or selected_data[ss_id].isnull().any():
            continue
        data_dict['ss_id'].extend([int(ss_id)] * batch_size)
        data_dict['pv_datetime'].extend(selected_data['datetime'].values)
        data_dict['generation'].extend(selected_data[ss_id].values)
        data_dict['horizon'].extend(range(1, batch_size + 1))

pv_df = pd.DataFrame(data_dict)
# `subset` takes a list of labels (a set is unordered and not a documented input).
pv_df = pv_df.dropna(subset=['generation'])

In [None]:
# Attach per-site static metadata (rounded location, tilt, orientation, kWp) to
# every generation row; sites without a metadata row are dropped by the inner merge.
pv_sites_id = [int(site) for site in pv_sites_id]
metadata_columns = {
    'lat': 'latitude_rounded',
    'long': 'longitude_rounded',
    'tilt': 'tilt',
    'orientation': 'orientation',
    'kwp': 'kwp',
}
pv_site_dict = {'ss_id': [], **{key: [] for key in metadata_columns}}

for site in pv_sites_id:
    matches = meta_data[meta_data['ss_id'] == site]
    if matches.empty:
        continue
    pv_site_dict['ss_id'].append(site)
    # If a site appears more than once in the metadata, the first row wins.
    for key, source_column in metadata_columns.items():
        pv_site_dict[key].append(matches[source_column].values[0])

meta_site_df = pd.DataFrame.from_dict(pv_site_dict)
combined_df = pd.merge(pv_df, meta_site_df, on='ss_id', how='inner')
pv_times = pd.to_datetime(combined_df['pv_datetime'])
combined_df['pv_datetime'] = pv_times
combined_df['pv_date'] = pv_times.dt.date
combined_df['pv_hour'] = pv_times.dt.hour

In [None]:
# For every 36-row PV window, attach the NWP forecast that is valid over the SAME
# hours: take the latest forecast issued at or before the window start, then
# shift its step range so that row k is valid at window_start + k hours.
results = []
batch_size = 36
counter = 0
forecast_window = pd.Timedelta(hours=batch_size - 1)
# Per-row columns copied from the PV window onto the NWP rows.
batch_columns = ['ss_id', 'pv_datetime', 'generation', 'horizon', 'tilt', 'orientation', 'kwp', 'pv_hour']

for i in tqdm(range(0, len(combined_df), batch_size), desc="Processing batches"):
    batch = combined_df.iloc[i:i + batch_size]
    if len(batch) < batch_size:
        continue
    # Rows are expected in contiguous windows of one site with horizon 1..36; skip
    # any slice that straddles two windows instead of silently mislabelling it.
    if batch['ss_id'].nunique() != 1 or batch['horizon'].tolist() != list(range(1, batch_size + 1)):
        continue
    initial_time = pd.Timestamp(batch.iloc[0]['pv_datetime'])
    lat = batch.iloc[0]['lat']
    lon = batch.iloc[0]['long']
    nwp_sel = nwp_data.sel(latitude=lat, method="nearest").sel(longitude=lon, method="nearest")
    try:
        init_time_sel = nwp_sel.sel(init_time=initial_time, method="ffill")
    except KeyError:
        # No forecast was issued at or before this window's start.
        continue
    issue_time = pd.Timestamp(init_time_sel['init_time'].values[()])
    # Without this offset, step 0 would line up with the forecast issue time, not
    # with the first PV hour, and features would be shifted by (start - issue) hours.
    lead_offset = initial_time - issue_time
    data_sel = init_time_sel.sel(step=slice(lead_offset, lead_offset + forecast_window))
    data_df = data_sel.to_dataframe().reset_index()
    pivot_df = data_df.pivot_table(index=['init_time', 'step'], columns='variable', values='ECMWF_UK').reset_index()
    # NOTE(review): assumes hourly NWP steps so that the first 36 rows pair 1:1 with the PV hours.
    if len(pivot_df) < batch_size:
        continue
    pivot_df = pivot_df.iloc[:batch_size].copy()
    for col in batch_columns:
        pivot_df[col] = batch[col].to_numpy()
    pivot_df['lat'] = lat
    pivot_df['long'] = lon
    results.append(pivot_df)
    counter += 1

if not results:
    raise ValueError("No PV window could be matched with NWP data")
final_df = pd.concat(results, ignore_index=True)

In [None]:
# NWP columns stored as running totals (presumably accumulated since the forecast
# issue time — confirm against the data source); de-accumulated below.
cumulative_vars = ['dlwrf', 'dswrf', 'duvrs', 'sr']

def cumulative_to_instantaneous(group):
    """Replace each running-total column in ``group`` with its row-to-row increment.

    ``group`` is modified in place and returned. The first row has no predecessor,
    so it keeps its accumulated value.
    NOTE(review): that first value is a one-step amount only when the group starts
    at forecast step 0 — confirm for windows that start at a later step.
    """
    for column in cumulative_vars:
        running_total = group[column]
        group[column] = running_total.diff().fillna(running_total)
    return group

# De-accumulate within each (site, forecast issue time) group, then add the
# capacity-normalised target and fix the column layout used downstream.
per_forecast = final_df.groupby(['ss_id', 'init_time'])
final_df = per_forecast.apply(cumulative_to_instantaneous).reset_index(drop=True)
final_df['normalize_generation'] = final_df['generation'] / final_df['kwp']
final_df = final_df.rename(columns={'kwp': 'capacity'})
desired_order = [
    'ss_id', 'init_time', 'step', 'pv_datetime', 'pv_hour', 'horizon',
    'generation', 'capacity', 'normalize_generation',
    'lat', 'long', 'tilt', 'orientation',
    'dlwrf', 'dswrf', 'duvrs', 'hcc', 'lcc', 'mcc', 'sde', 'sr',
    't2m', 'tcc', 'u10', 'u100', 'v10', 'v100',
]
final_df = final_df[desired_order]

In [None]:
# Hold out the last 100 windows (36 rows each) of final_df as the test set. There is
# no shuffling, so the split follows final_df's current row order.
rows_per_batch = 36
num_batches_to_keep = 100
rows_to_keep = rows_per_batch * num_batches_to_keep
split_position = len(final_df) - rows_to_keep
train_data = final_df.iloc[:split_position]
test_data = final_df.iloc[split_position:]

In [None]:
# Training frame: drop a stray CSV index column if present, rename the hour
# feature, and derive calendar features from the PV timestamp.
forecast_data = (
    train_data
    .drop(columns=['Unnamed: 0'], errors='ignore')
    .rename(columns={'pv_hour': 'day_hour'})
    .astype({'ss_id': int})
    .assign(pv_datetime=lambda frame: pd.to_datetime(frame['pv_datetime']))
    .assign(
        date=lambda frame: frame['pv_datetime'].dt.date,
        day_of_week=lambda frame: frame['pv_datetime'].dt.dayofweek,
        month=lambda frame: frame['pv_datetime'].dt.month,
    )
)

In [None]:
# Feature roles for the TFT, then final cleaning/typing of the training frame.
target_variable = 'normalize_generation'
static_features = ['ss_id', 'capacity', 'lat', 'long', 'tilt', 'orientation']
known_future_inputs = ['dlwrf', 'dswrf', 'duvrs', 'hcc', 'lcc', 'mcc', 'sde', 'sr', 't2m', 'tcc', 'u10', 'u100', 'v10', 'v100', 'day_of_week', 'month', 'day_hour']
required_columns = static_features + known_future_inputs + [target_variable, 'pv_datetime', 'date']
forecast_data = forecast_data[required_columns]
# Backward- then forward-fill gaps. `fillna(method=...)` is deprecated since
# pandas 2.1; `.bfill().ffill()` is the equivalent, supported spelling.
# NOTE(review): the fill runs across the whole frame, so it can carry values
# across site/window boundaries.
forecast_data = forecast_data.bfill().ffill()
# NOTE(review): time_idx is the row label, not elapsed time; consecutive windows
# of the same site become adjacent time steps even though they are >36 h apart.
forecast_data['time_idx'] = forecast_data.index
# Categorical inputs for pytorch-forecasting are passed as strings.
forecast_data['ss_id'] = forecast_data['ss_id'].astype(str)
forecast_data['day_of_week'] = forecast_data['day_of_week'].astype(str)
forecast_data['month'] = forecast_data['month'].astype(str)
forecast_data['day_hour'] = forecast_data['day_hour'].astype(str)
forecast_data['time_idx'] = forecast_data['time_idx'].astype(int)

In [None]:
# Encoder history and decoder horizon, both in rows (hours).
max_encoder_length = 36
max_prediction_length = 36
# Rows in the final max_prediction_length hours of the data are excluded from training.
prediction_span = pd.Timedelta(hours=max_prediction_length)
training_cutoff = forecast_data["pv_datetime"].max() - prediction_span
is_training_row = forecast_data["pv_datetime"] <= training_cutoff
training_data = forecast_data[is_training_row]

In [None]:
# Dataset definition: NWP fields as known-future reals, calendar fields as
# known-future categoricals, per-site static reals, target = normalize_generation.
dataset_params = dict(
    time_idx="time_idx",
    target="normalize_generation",
    group_ids=["ss_id"],
    min_encoder_length=max_encoder_length // 2,
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    static_reals=["capacity", "lat", "long", "tilt", "orientation"],
    time_varying_known_categoricals=["month", "day_of_week", "day_hour"],
    time_varying_known_reals=["time_idx", "dlwrf", "dswrf", "duvrs", "hcc", "lcc", "mcc", "sde", "sr", "t2m", "tcc", "u10", "u100", "v10", "v100"],
    time_varying_unknown_reals=["normalize_generation"],
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
    allow_missing_timesteps=False,
)
training = TimeSeriesDataSet(
    forecast_data[forecast_data.pv_datetime <= training_cutoff],
    **dataset_params,
)

# The validation set reuses the training dataset's settings and only predicts
# time indices after the last one seen in training.
validation = TimeSeriesDataSet.from_dataset(
    training,
    forecast_data,
    min_prediction_idx=training.index.time.max() + 1,
    stop_randomization=True,
)
batch_size = 128
loader_options = dict(batch_size=batch_size, num_workers=2)
train_dataloader = training.to_dataloader(train=True, **loader_options)
val_dataloader = validation.to_dataloader(train=False, **loader_options)

In [None]:
# Logging and stopping: TensorBoard under lightning_logs/my_model, learning-rate
# tracking, and early stopping once val_loss fails to improve by 1e-5 for 10 validations.
logger = TensorBoardLogger("lightning_logs", name="my_model")
lr_logger = LearningRateMonitor()
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    mode="min",
    min_delta=1e-5,
    patience=10,
    verbose=False,
)

In [None]:
# Small TFT configuration. output_size must equal the number of quantiles that
# QuantileLoss() predicts (7 with its default quantiles).
tft_hparams = dict(
    learning_rate=0.03,
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=8,
    output_size=7,
    loss=QuantileLoss(),
    log_interval=10,
    reduce_on_plateau_patience=4,
)
tft = TemporalFusionTransformer.from_dataset(training, **tft_hparams)

# At most 50 epochs of 30 training batches each, with gradient clipping.
trainer = pl.Trainer(
    max_epochs=50,
    gradient_clip_val=0.1,
    limit_train_batches=30,
    callbacks=[lr_logger, early_stop_callback],
    logger=logger,
)

In [None]:
# Learning-rate range test between 1e-6 and 30, then report and plot the suggestion.
tuner = Tuner(trainer)
res = tuner.lr_find(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
    min_lr=1e-6,
    max_lr=30.0,
)
suggested_lr = res.suggestion()
print(f"suggested learning rate: {suggested_lr}")
fig = res.plot(suggest=True, show=True)
fig.show()

In [None]:
# Train; stopping is governed by max_epochs and the EarlyStopping callback on `trainer`.
trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

In [None]:
# Reload the checkpoint recorded by the trainer's checkpoint callback and predict on validation.
# NOTE(review): no ModelCheckpoint(monitor="val_loss") is configured, so this is
# presumably the latest checkpoint rather than the lowest-val_loss one — confirm.
best_model_path = trainer.checkpoint_callback.best_model_path
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
actuals = torch.cat([y[0] for _, y in val_dataloader])
# predict() can return tensors on the accelerator used for inference; move them to
# the CPU so they can be combined with the CPU targets collected above.
predictions = best_tft.predict(val_dataloader).cpu()

In [None]:
# Point-forecast MAE on validation, ignoring positions where either tensor is NaN.
# Both tensors are moved to the CPU first, because masking tensors on different
# devices raises a RuntimeError.
actuals_cpu = actuals.cpu()
predictions_cpu = predictions.cpu()
mask = ~torch.isnan(actuals_cpu) & ~torch.isnan(predictions_cpu)
actuals_filtered = actuals_cpu[mask]
predictions_filtered = predictions_cpu[mask]
mae = (actuals_filtered - predictions_filtered).abs().mean().item()
print(f"Mean Absolute Error: {mae}")

In [None]:
# Plot actual vs predicted normalised generation for the first (up to) 10 validation samples.
for i in range(min(10, len(actuals))):
    fig, ax = plt.subplots(figsize=(10, 6))
    # .cpu() first: .numpy() fails on tensors that live on an accelerator.
    ax.plot(actuals[i].cpu().numpy(), label='Actuals', marker='o')
    ax.plot(predictions[i].cpu().numpy(), label='Predictions', marker='x')
    # Row i is the i-th validation sample (one site/time window), not the i-th PV site.
    ax.set_title(f'Generation Forecast vs Actual for validation sample {i+1}')
    ax.set_xlabel('Forecast step (hour)')
    ax.set_ylabel('Normalised generation (generation / kWp)')
    ax.legend()
    ax.grid(True)
    plt.show()
    plt.close(fig)  # avoid accumulating open figures in the loop

In [None]:
# Held-out windows get the same feature engineering and cleaning as the training
# frame. This includes the bfill/ffill; without it, NaNs in NWP columns would reach
# the dataset here while they are filled for training.
test_data_processed = (
    test_data
    .drop(columns=['Unnamed: 0'], errors='ignore')
    .rename(columns={'pv_hour': 'day_hour'})
)
test_data_processed['ss_id'] = test_data_processed['ss_id'].astype(int)
test_data_processed['pv_datetime'] = pd.to_datetime(test_data_processed['pv_datetime'])
test_data_processed['date'] = test_data_processed['pv_datetime'].dt.date
test_data_processed['day_of_week'] = test_data_processed['pv_datetime'].dt.dayofweek
test_data_processed['month'] = test_data_processed['pv_datetime'].dt.month
test_data_processed = test_data_processed.bfill().ffill()
# Categorical inputs are passed as strings, matching the training frame.
test_data_processed['ss_id'] = test_data_processed['ss_id'].astype(str)
test_data_processed['day_of_week'] = test_data_processed['day_of_week'].astype(str)
test_data_processed['month'] = test_data_processed['month'].astype(str)
test_data_processed['day_hour'] = test_data_processed['day_hour'].astype(str)
# NOTE(review): as in training, time_idx is the row label, not elapsed time.
test_data_processed['time_idx'] = test_data_processed.index
test_data_processed['time_idx'] = test_data_processed['time_idx'].astype(int)

In [None]:
# Build the test dataset FROM the training dataset so its fitted categorical
# encoders, scalers and target normaliser are reused. A freshly constructed
# TimeSeriesDataSet would refit them on the test rows, and the category indices
# fed to the model's embeddings would no longer match what it was trained on.
# NOTE(review): categories of month/day_of_week/day_hour not present in training
# will raise unless their encoders allow unknowns.
new_data = TimeSeriesDataSet.from_dataset(
    training,
    test_data_processed,
    stop_randomization=True,
)

new_data_loader = new_data.to_dataloader(train=False, batch_size=128, num_workers=2)
test_actuals = torch.cat([y[0] for _, y in new_data_loader])
# Align devices and ignore NaN positions, as in the validation MAE.
test_predictions = best_tft.predict(new_data_loader).cpu()
valid = ~torch.isnan(test_actuals) & ~torch.isnan(test_predictions)
test_mae = (test_actuals[valid] - test_predictions[valid]).abs().mean().item()
print(f"Test Mean Absolute Error: {test_mae}")