In [1]:
import joblib
import numpy as np
import pandas as pd

In [2]:
def create_features_from_timestamp(ts, hist_df, zone):
    """Build the single-row feature frame used to predict ``zone`` at ``ts``.

    Parameters
    ----------
    ts : pandas.Timestamp
        Timestamp to forecast; its hour, minute, weekday and month drive the
        cyclical encodings.
    hist_df : pandas.DataFrame
        Historic values, looked up by timestamp label (index) and ``zone``
        (column).
    zone : hashable
        Column label of the zone inside ``hist_df``.

    Returns
    -------
    pandas.DataFrame or None
        One-row frame with the calendar and lag features, or ``None`` when
        any of the three lag timestamps is absent from ``hist_df.index``
        (callers are expected to treat ``None`` as "no forecast possible").

    Raises
    ------
    KeyError
        If ``zone`` is not a column of ``hist_df``.
    """
    # Lag features, expressed as offsets back from ``ts``.
    # NOTE(review): lag_1 and lag_24 equal 1 and 24 five-minute steps, while
    # lag_168 is 7 days (168 hours) -- confirm this matches the training code.
    lag_offsets = {
        "lag_1": pd.Timedelta(minutes=5),
        "lag_24": pd.Timedelta(hours=2),
        "lag_168": pd.Timedelta(days=7),
    }

    lags = {}
    for lag_name, offset in lag_offsets.items():
        lag_ts = ts - offset
        # Missing history means the feature row cannot be built; signal it
        # with None instead of letting ``.at`` raise KeyError mid-forecast.
        if lag_ts not in hist_df.index:
            return None
        lags[lag_name] = hist_df.at[lag_ts, zone]

    # Time-based features (matching training logic)
    hour = ts.hour
    minute_bin = ts.minute // 5  # 12 five-minute bins per hour
    weekday = ts.weekday()
    month = ts.month

    features = {
        "is_weekday": int(weekday < 5),
        "hour_sin": np.sin(2 * np.pi * hour / 24),
        "hour_cos": np.cos(2 * np.pi * hour / 24),
        "bin5_sin": np.sin(2 * np.pi * minute_bin / 12),
        "bin5_cos": np.cos(2 * np.pi * minute_bin / 12),
        "dow_sin": np.sin(2 * np.pi * weekday / 7),
        "dow_cos": np.cos(2 * np.pi * weekday / 7),
        "month_sin": np.sin(2 * np.pi * month / 12),
        "month_cos": np.cos(2 * np.pi * month / 12),
        "lag_1": lags["lag_1"],
        "lag_24": lags["lag_24"],
        "lag_168": lags["lag_168"],
    }

    return pd.DataFrame([features])

def make_forecast(time_step, zone_ids, historic_df, models):
    """Return the point forecast of every zone for a single timestamp.

    Parameters
    ----------
    time_step : timestamp passed through to ``create_features_from_timestamp``.
    zone_ids : iterable of zone labels; each is used as a key into ``models``.
    historic_df : historic data passed through to the feature builder.
    models : mapping from zone label to an object exposing ``predict``.

    Returns
    -------
    dict
        Zone label -> first element of ``model.predict(features)``, or
        ``np.nan`` when the feature builder returns ``None``.
    """

    def _predict_zone(zone):
        # Build the feature row first; a None result means no forecast.
        feature_row = create_features_from_timestamp(time_step, historic_df, zone)
        if feature_row is None:
            return np.nan
        return models[zone].predict(feature_row)[0]

    return {zone: _predict_zone(zone) for zone in zone_ids}


In [3]:
# Load the processed monthly files (months 01-07, per the range below)
processed_folder = "data/processed"
file_names = [f"processed_{i:02d}.csv" for i in range(1, 8)]

# Read each file with its first column as a parsed datetime index, then stack them
monthly_frames = []
for fname in file_names:
    monthly_frames.append(
        pd.read_csv(f"{processed_folder}/{fname}", index_col=0, parse_dates=True)
    )
historic_df = pd.concat(monthly_frames)
historic_df.index = pd.to_datetime(historic_df.index)

# Load one trained model per zone (zones are the columns of the historic frame)
model_folder = "data/models"
zone_ids = historic_df.columns
models = {}
for zone in zone_ids:
    models[zone] = joblib.load(f"{model_folder}/xgb_zone_{zone}.joblib")

# Load model error distribution
error_distribution = pd.read_csv("data/models/val_error_distributions.csv", index_col=0)

In [4]:
# Forecast window: 2010-07-25 00:00 through 2010-07-31 00:05, in 5-minute steps
start_ts, end_ts = map(
    pd.Timestamp,
    ("2010-07-25 00:00:00", "2010-07-31 00:05:00"),
)

timestamps = pd.date_range(start_ts, end_ts, freq="5min")

In [5]:
# --- Inputs ---
time_step = timestamps[0]  # specify the timestamp
prediction = make_forecast(time_step, zone_ids, historic_df, models)  # dict: zone id -> point forecast

# --- Load validation error distribution ---
val_err_df = pd.read_csv("data/models/val_error_distributions.csv", index_col=0)
val_err_df.index = val_err_df.index.astype(int)  # ensure zone IDs are int

# --- Generate samples ---
SEED = 42  # fixed seed so the sampled scenarios survive Restart & Run All unchanged
rng = np.random.default_rng(SEED)
n_samples = 100
samples = {}

for zone in zone_ids:
    mean_forecast = prediction[zone]
    # zone_ids come from the CSV header (strings); the error table is indexed by int
    err_mean = val_err_df.loc[int(zone), "mean"]
    err_std = val_err_df.loc[int(zone), "std"]

    # Sample from normal distribution centered at forecast + residual bias.
    # NOTE(review): adding the mean error assumes errors were stored as
    # (actual - forecast) -- confirm against the validation code.
    # NOTE(review): a normal draw can go negative; confirm that is acceptable
    # for the quantity being forecast.
    zone_samples = rng.normal(loc=mean_forecast + err_mean, scale=err_std, size=n_samples)
    samples[zone] = zone_samples

# --- Convert to DataFrame ---
# One column per zone, one row per drawn sample
samples_df = pd.DataFrame(samples)
samples_df.index.name = "sample_id"


In [6]:
# Display the per-zone point forecasts returned by make_forecast
prediction

{'4': 19.603739,
 '12': 0.66072917,
 '13': 14.678369,
 '24': 8.51144,
 '41': 13.644029,
 '42': 10.202907,
 '43': 11.408259,
 '45': 4.019348,
 '48': 70.93027,
 '50': 41.64568,
 '68': 80.68585,
 '74': 13.635518,
 '75': 17.813421,
 '79': 148.42027,
 '87': 26.93339,
 '88': 6.4559917,
 '90': 39.264954,
 '100': 17.806334,
 '107': 67.07496,
 '113': 40.446453,
 '114': 50.26215,
 '116': 10.237096,
 '120': 0.1622789,
 '125': 18.916723,
 '127': 3.3276725,
 '128': 0.1517544,
 '137': 39.64742,
 '140': 26.163395,
 '141': 48.291862,
 '142': 48.35974,
 '143': 23.2319,
 '144': 36.64362,
 '148': 81.34176,
 '151': 16.240034,
 '152': 4.796307,
 '153': 0.1860685,
 '158': 57.00488,
 '161': 34.860092,
 '162': 52.357647,
 '163': 29.926594,
 '164': 45.68096,
 '166': 11.510589,
 '170': 70.41677,
 '186': 46.395634,
 '194': 0.18467309,
 '202': 2.4496255,
 '209': 4.4667797,
 '211': 22.459225,
 '224': 13.755243,
 '229': 32.47827,
 '230': 53.51555,
 '231': 38.238987,
 '232': 15.5765505,
 '233': 27.303534,
 '234': 64

In [7]:
# Display the sampled forecasts (one row per sample_id, one column per zone)
samples_df

Unnamed: 0_level_0,4,12,13,24,41,42,43,45,48,50,...,237,238,239,243,244,246,249,261,262,263
sample_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,18.354072,-1.008295,12.564353,6.742511,16.025492,9.903598,20.121099,6.588510,59.385278,34.980314,...,21.019303,30.487976,53.794835,6.700305,10.130907,46.078761,62.036630,11.727839,29.584405,59.841177
1,19.422019,-0.814825,18.036073,9.921709,15.823188,9.474734,15.754035,3.176892,69.914466,46.133213,...,17.838029,34.110727,44.002605,7.960564,11.500894,48.115682,64.321926,8.242578,25.732730,44.769737
2,19.999583,1.340395,17.978331,10.612190,9.029508,10.042585,5.410044,7.081383,76.647140,39.307453,...,19.090982,32.898161,59.431512,8.850529,8.143627,52.619115,62.677026,14.306023,35.848709,58.696803
3,20.155916,1.590854,18.409401,11.444074,9.359299,12.395999,13.758663,4.555134,77.669901,40.271787,...,21.148301,25.344302,30.471044,5.291043,10.248033,53.183677,62.353867,9.612664,14.837072,50.456883
4,20.948052,0.488485,14.716490,9.478127,18.543421,10.553487,14.250963,2.643468,68.893479,46.265161,...,7.135995,31.779818,44.125747,8.255928,15.398304,51.638921,70.527108,9.948887,12.094630,52.920514
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,24.307198,0.431868,22.517801,12.976152,5.793230,10.577495,4.178761,-0.325189,74.999216,53.619058,...,14.225758,18.310872,35.114085,6.042481,10.447623,45.593260,60.229838,14.143124,24.379172,59.874261
96,20.084647,1.587888,8.922301,7.338826,16.115632,13.569963,8.247848,4.058600,66.896175,39.062483,...,17.055870,37.543122,54.244636,6.395209,9.539313,51.282871,57.705269,10.552530,27.954461,47.979809
97,19.000876,1.074536,8.277801,8.336875,14.946398,9.951673,10.746079,6.814265,58.785204,38.117130,...,14.088018,28.943675,33.817934,8.693101,10.397432,53.331134,56.687141,9.168678,19.252974,53.999746
98,17.080731,-0.465514,12.904041,8.326206,16.425356,11.206333,10.538209,6.436367,52.438525,41.283599,...,0.602736,28.496870,38.516601,9.081727,14.680335,61.338200,70.488930,12.595631,24.945502,55.580830
