In [11]:
import joblib
import numpy as np
import pandas as pd

In [12]:
def create_features_from_timestamp(ts, hist_df, zone):
    """Build the single-row model input for one zone at timestamp ``ts``.

    Parameters
    ----------
    ts : pandas.Timestamp
        Timestamp the features are built for.
    hist_df : pandas.DataFrame
        Historical values, looked up by timestamp label (index) and by
        ``zone`` (column).
    zone : hashable
        Column label of ``hist_df`` to read the lagged values from.

    Returns
    -------
    pandas.DataFrame or None
        A one-row frame with the cyclical time encodings and the three lag
        values, or ``None`` when any lagged timestamp is absent from
        ``hist_df.index`` (callers such as ``make_forecast`` treat ``None``
        as "no forecast possible").

    Raises
    ------
    KeyError
        If all lagged timestamps exist but ``zone`` is not a column of
        ``hist_df``.
    """
    # Lag offsets relative to ts. NOTE(review): "lag_24" is 24 five-minute
    # steps (2 hours) while "lag_168" is 7 days (168 hours) -- the suffixes use
    # different units; presumably this mirrors the training pipeline, confirm.
    lag_offsets = {
        "lag_1": pd.Timedelta(minutes=5),
        "lag_24": pd.Timedelta(hours=2),
        "lag_168": pd.Timedelta(days=7),
    }
    lag_times = {name: ts - offset for name, offset in lag_offsets.items()}

    # `.at` raises KeyError for a missing row label; report "not enough
    # history" as None instead so the caller's NaN fallback is reachable.
    if any(lag_ts not in hist_df.index for lag_ts in lag_times.values()):
        return None

    lags = {name: hist_df.at[lag_ts, zone] for name, lag_ts in lag_times.items()}

    # Time-based features (matching training logic)
    hour = ts.hour
    minute_bin = ts.minute // 5  # 0..11: index of the 5-minute bin within the hour
    weekday = ts.weekday()       # Monday == 0
    month = ts.month

    features = {
        "is_weekday": int(weekday < 5),
        "hour_sin": np.sin(2 * np.pi * hour / 24),
        "hour_cos": np.cos(2 * np.pi * hour / 24),
        "bin5_sin": np.sin(2 * np.pi * minute_bin / 12),
        "bin5_cos": np.cos(2 * np.pi * minute_bin / 12),
        "dow_sin": np.sin(2 * np.pi * weekday / 7),
        "dow_cos": np.cos(2 * np.pi * weekday / 7),
        "month_sin": np.sin(2 * np.pi * month / 12),
        "month_cos": np.cos(2 * np.pi * month / 12),
        "lag_1": lags["lag_1"],
        "lag_24": lags["lag_24"],
        "lag_168": lags["lag_168"],
    }

    return pd.DataFrame([features])

def make_forecast(time_step,zone_ids,historic_df,models):
    """Return a dict mapping each zone in ``zone_ids`` to its point forecast.

    For every zone the feature row for ``time_step`` is built from
    ``historic_df`` and passed to ``models[zone]``; the first element of the
    model's prediction is stored. A zone whose feature builder returns
    ``None`` is stored as ``np.nan``.
    """
    forecasts = {}

    for zone_id in zone_ids:
        feature_row = create_features_from_timestamp(time_step, historic_df, zone_id)

        # No feature row available: record a missing forecast and move on.
        if feature_row is None:
            forecasts[zone_id] = np.nan
            continue

        zone_model = models[zone_id]
        forecasts[zone_id] = zone_model.predict(feature_row)[0]

    return forecasts


In [None]:
# Load the processed data files processed_01.csv .. processed_07.csv
# (range(1, 8) -> seven files; presumably January-July -- confirm, the original
# note said "Jan to Dec" but only 01..07 are read).
processed_folder = "data/processed"
file_names = [f"processed_{i:02d}.csv" for i in range(1, 8)]

# Concatenate the files row-wise; the first CSV column becomes the index and is
# parsed as dates.
historic_df = pd.concat([
    pd.read_csv(f"{processed_folder}/{fname}", index_col=0, parse_dates=True)
    for fname in file_names
])
# Force a datetime index even if date parsing above left it as strings.
historic_df.index = pd.to_datetime(historic_df.index)

# Load trained models: one joblib file per column (zone) of historic_df.
# NOTE: joblib.load unpickles its input -- only load model files you trust.
model_folder = "data/models"
zone_ids = historic_df.columns
models = {zone: joblib.load(f"{model_folder}/xgb_zone_{zone}.joblib") for zone in zone_ids}

# Load model error distribution (first CSV column used as the index).
error_distribution = pd.read_csv("data/models/val_error_distributions.csv",index_col=0)

In [14]:
# Forecast window: 2010-07-25 00:00 through 2010-07-31 00:05 (both bounds
# included), on a 5-minute grid.
start_ts, end_ts = (
    pd.Timestamp("2010-07-25 00:00:00"),
    pd.Timestamp("2010-07-31 00:05:00"),
)

timestamps = pd.date_range(start_ts, end_ts, freq="5min")

In [33]:
# --- Inputs ---
time_step = timestamps[0]  # specify the timestamp
# make_forecast returns a dict: zone -> point forecast (np.nan if unavailable)
prediction = make_forecast(time_step, zone_ids, historic_df, models)

# --- Load validation error distribution ---
# One row per zone; the "mean" and "std" columns are read below.
val_err_df = pd.read_csv("data/models/val_error_distributions.csv", index_col=0)
val_err_df.index = val_err_df.index.astype(int)  # ensure zone IDs are int

# --- Generate samples ---
n_samples = 100
RANDOM_SEED = 42  # fixed seed so Restart & Run All reproduces the same draws
rng = np.random.default_rng(RANDOM_SEED)
samples = {}

for zone in zone_ids:
    mean_forecast = prediction[zone]
    # zone_ids are column labels (strings in the displayed output), while the
    # error table is indexed by int -- hence the int() conversion.
    err_mean = val_err_df.loc[int(zone), "mean"]
    err_std = val_err_df.loc[int(zone), "std"]

    # Sample from normal distribution centered at forecast + residual bias.
    # NOTE(review): a normal distribution can yield negative samples; clip or
    # use a different distribution if the target cannot be negative -- confirm.
    samples[zone] = rng.normal(loc=mean_forecast + err_mean, scale=err_std, size=n_samples)

# --- Convert to DataFrame ---
# Columns are zones, rows are the n_samples independent draws.
samples_df = pd.DataFrame(samples)
samples_df.index.name = "sample_id"


In [34]:
# Display the per-zone point forecasts (dict: zone -> value) for `time_step`.
prediction

{'4': 19.603739,
 '12': 0.66072917,
 '13': 14.678369,
 '24': 8.51144,
 '41': 13.644029,
 '42': 10.202907,
 '43': 11.408259,
 '45': 4.019348,
 '48': 70.93027,
 '50': 41.64568,
 '68': 80.68585,
 '74': 13.635518,
 '75': 17.813421,
 '79': 148.42027,
 '87': 26.93339,
 '88': 6.4559917,
 '90': 39.264954,
 '100': 17.806334,
 '107': 67.07496,
 '113': 40.446453,
 '114': 50.26215,
 '116': 10.237096,
 '120': 0.1622789,
 '125': 18.916723,
 '127': 3.3276725,
 '128': 0.1517544,
 '137': 39.64742,
 '140': 26.163395,
 '141': 48.291862,
 '142': 48.35974,
 '143': 23.2319,
 '144': 36.64362,
 '148': 81.34176,
 '151': 16.240034,
 '152': 4.796307,
 '153': 0.1860685,
 '158': 57.00488,
 '161': 34.860092,
 '162': 52.357647,
 '163': 29.926594,
 '164': 45.68096,
 '166': 11.510589,
 '170': 70.41677,
 '186': 46.395634,
 '194': 0.18467309,
 '202': 2.4496255,
 '209': 4.4667797,
 '211': 22.459225,
 '224': 13.755243,
 '229': 32.47827,
 '230': 53.51555,
 '231': 38.238987,
 '232': 15.5765505,
 '233': 27.303534,
 '234': 64

In [35]:
# Display the sampled forecasts: one column per zone, one row per sample_id.
samples_df

Unnamed: 0_level_0,4,12,13,24,41,42,43,45,48,50,...,237,238,239,243,244,246,249,261,262,263
sample_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,15.596598,2.387888,10.816958,4.308836,14.318166,10.683886,9.159829,6.256551,73.482481,42.024085,...,27.028128,25.528201,42.121228,9.845412,7.764008,36.513116,81.051249,5.791815,31.977365,43.791245
1,23.767902,2.083799,6.970364,6.910254,21.080408,8.043932,3.134550,2.726856,68.719315,45.394564,...,13.992778,23.634130,45.807302,7.216891,13.065395,46.254944,70.846442,16.030020,28.323884,55.599041
2,20.191530,-0.328250,11.048689,3.751459,6.952077,10.768917,3.685478,4.798557,86.395380,38.308235,...,14.038531,23.204340,55.294107,6.558637,12.559458,52.299480,66.328932,15.121200,31.656896,50.055164
3,22.496108,1.076483,7.059430,10.678689,15.746479,9.122063,3.551740,5.191927,73.156673,43.184176,...,15.799845,21.184023,43.730109,9.778020,8.484358,55.276194,59.433910,7.026239,25.056769,55.920010
4,18.654842,3.097986,18.481384,11.712479,15.449322,12.061199,9.516206,5.413791,69.461459,34.772440,...,4.968247,23.983701,31.045379,5.866747,7.991793,54.235545,68.246103,10.065087,17.980412,59.196088
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,20.557057,-0.293997,15.738333,13.098387,16.197667,16.500761,17.138573,3.448980,84.117835,35.973079,...,14.481323,28.787383,38.751914,8.257661,6.769053,44.623665,66.690826,11.187865,29.514112,52.053275
96,19.603081,0.422226,16.652133,7.626990,17.072284,8.062043,10.883048,7.790260,64.030064,37.770444,...,7.651663,35.431274,44.452499,9.403511,15.065733,52.148339,58.705995,6.850183,23.073234,51.438247
97,18.670172,1.462309,11.104611,10.601891,13.587105,8.652042,18.727491,5.697066,64.437276,44.081638,...,15.995439,33.668405,38.965651,6.451272,18.540951,55.688505,77.331339,7.040478,25.220787,40.066492
98,14.516750,-0.350950,2.961561,10.230233,14.848056,9.745804,11.045042,3.066028,78.178301,38.926541,...,15.798862,15.494554,37.566673,11.445479,10.336270,37.544298,72.213558,9.138366,18.444825,46.422588
