In [1]:
import pandas as pd
import catboost
from sklearn.metrics import mean_squared_error

import utils

# Project-level notebook configuration. The behaviour is defined in the local
# `utils.configs` module, which is not visible from this notebook.
# NOTE(review): presumably configures logging/display options — confirm in utils.
utils.configs.setup()

In [2]:
def load_data(seed: int):
    """Build the raw electricity dataset for a given random seed.

    Thin wrapper around ``utils.datasets.make_electricity_data`` that pins the
    date arguments ("2016-01-01", "2024-01-01") and the sampling frequency
    ("15min") used throughout this notebook.

    Args:
        seed: Forwarded to the generator as ``random_state``.

    Returns:
        The object returned by ``utils.datasets.make_electricity_data``
        (unchanged).
    """
    start, end = "2016-01-01", "2024-01-01"
    return utils.datasets.make_electricity_data(
        start, end, freq="15min", random_state=seed
    )


def load_splits(seed: int, features: list[str]):
    """Load the raw data, reshape it to daily rows and split it three ways.

    Displays the first three rows of both the raw frame and the daily frame as
    a side effect, so the cell output documents what went into the split.

    Args:
        seed: Random seed passed to ``load_data``.
        features: Column labels to keep from the raw frame, in this order.

    Returns:
        The ``(train, validation, test)`` triple produced by
        ``utils.splits.to_train_validation_test_data`` with the boundary dates
        "2022-01-01" and "2023-01-01".
    """
    minute_data = load_data(seed)
    display(minute_data.head(3))

    # Only the requested features are reshaped from 15-minute to daily rows.
    daily_data = utils.transformations.minute_to_daily(minute_data.loc[:, features])
    display(daily_data.head(3))

    boundaries = ("2022-01-01", "2023-01-01")  # train end, validation end
    train, validation, test = utils.splits.to_train_validation_test_data(
        daily_data, *boundaries
    )
    return train, validation, test

In [3]:
def delay(df, delays: int | list[int]):
    if isinstance(df, pd.Series):
        df = df.to_frame()
    dfs = [df]
    if isinstance(delays, int):
        delays = range(1, delays + 1)
    for t in delays:
        delayed_df = df.shift(t)
        delayed_df.columns = [f"{c}_m{t}" for c in delayed_df.columns]
        dfs.append(delayed_df)
    vstacked_df = pd.concat(reversed(dfs), axis=1).dropna()
    return vstacked_df


In [4]:
def get_columns_by_time(df, time: str):
    """Select the columns of ``df`` that belong to one time of day.

    ``time`` is given as e.g. ``"00:00"``; the colon is replaced by an
    underscore to obtain the suffix to look for (``"00_00"``). A column is
    selected when its label is exactly that suffix or ends with
    ``"_" + suffix``.

    Args:
        df: DataFrame with string column labels.
        time: Time of day, with ``:`` or ``_`` between its parts.

    Returns:
        A DataFrame with the matching columns in their original order. It has
        no columns when nothing matches.
    """
    suffix = time.replace(":", "_")
    # Anchor the match on the "_" separator: a bare endswith(suffix) lets a
    # short time such as "0:00" also pick up "*_10_00" and "*_20_00".
    anchored = f"_{suffix}"
    columns = [c for c in df.columns if c == suffix or c.endswith(anchored)]
    return df.loc[:, columns]

In [5]:
def evaluate(time: str, delays, train, validation, **kwargs):
    """Fit a CatBoost regressor on lagged features and print train/validation MSE.

    Both frames are restricted to the columns for ``time`` and expanded with
    ``delay``. In each resulting frame the last column is the target and all
    other columns are the features.

    Args:
        time: Time of day passed to ``get_columns_by_time`` (e.g. ``"00:00"``).
        delays: Lag specification passed to ``delay``.
        train: Frame used to fit the model.
        validation: Frame used only for scoring.
        **kwargs: Forwarded unchanged to ``catboost.CatBoostRegressor``.

    Returns:
        None. The feature/target names and both MSE values are printed.
    """

    def lagged_xy(frame):
        # Right-most column is the target; everything left of it is a feature.
        lagged = delay(get_columns_by_time(frame, time), delays)
        return lagged.iloc[:, :-1], lagged.iloc[:, -1]

    X_train, y_train = lagged_xy(train)
    X_val, y_val = lagged_xy(validation)

    model = catboost.CatBoostRegressor(**kwargs)
    model.fit(X_train, y_train)

    train_mse = mean_squared_error(y_train, model.predict(X_train))
    val_mse = mean_squared_error(y_val, model.predict(X_val))

    print()
    print(f"{list(X_train.columns)} -> {y_train.name}")
    print("Train MSE:\t", train_mse)
    print("Validation MSE:\t", val_mse)

In [6]:
# Baseline run: predict the 00:00 electricity value from its value one row
# earlier (delays=1), with no extra arguments passed to CatBoostRegressor.
seed = 42
columns = ["electricity"]
# `test` is unpacked here but not used by this cell.
train, validation, test = load_splits(seed, columns)
evaluate(time="00:00", delays=1, train=train, validation=validation)

2025-01-15 10:30:29,020 - INFO - Setting numpy seed to: 42
2025-01-15 10:30:29,162 - INFO - Shape: (280512, 6) | Start: 2016-01-01 00:00:00 | End: 2023-12-31 23:45:00
2025-01-15 10:30:29,163 - INFO - Columns: ['electricity', 'wind_speed', 'wind_speed_no_seasonality', 'daily_seasonality', 'weekly_seasonality', 'yearly_seasonality']


Unnamed: 0_level_0,electricity,wind_speed,wind_speed_no_seasonality,daily_seasonality,weekly_seasonality,yearly_seasonality
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
2016-01-01 00:00:00,6.48,8.1,8.99,0.0,-0.43,0.02
2016-01-01 00:15:00,4.48,6.49,7.72,0.0,-0.43,0.02
2016-01-01 00:30:00,6.55,8.18,9.3,0.0,-0.43,0.02


2025-01-15 10:30:33,620 - INFO - Frequency change: 15min -> 1d
2025-01-15 10:30:33,622 - INFO - Shape change: (280512, 1) -> (2922, 96)


Unnamed: 0_level_0,electricity_00_00,electricity_00_15,electricity_00_30,electricity_00_45,electricity_01_00,electricity_01_15,electricity_01_30,electricity_01_45,electricity_02_00,electricity_02_15,...,electricity_21_30,electricity_21_45,electricity_22_00,electricity_22_15,electricity_22_30,electricity_22_45,electricity_23_00,electricity_23_15,electricity_23_30,electricity_23_45
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2016-01-01,6.48,4.48,6.55,12.0,4.92,4.38,11.35,7.93,6.13,9.99,...,4.07,4.51,1.8,5.49,4.46,3.81,2.39,3.44,2.66,0.1
2016-01-02,6.1,3.43,2.4,4.4,3.24,3.78,1.64,1.68,6.56,9.16,...,0.94,2.35,0.9,5.39,4.65,2.45,1.41,1.85,2.29,4.65
2016-01-03,4.68,4.61,4.12,3.79,3.37,6.92,5.25,2.99,7.21,7.51,...,2.79,0.99,2.55,4.59,5.08,2.08,8.99,1.33,2.72,5.55


2025-01-15 10:30:33,685 - INFO - # of training observations: 2192 | 75.02%
2025-01-15 10:30:33,686 - INFO - # of validation observations: 365 | 12.49%
2025-01-15 10:30:33,690 - INFO - # of test observations: 365 | 12.49%


Learning rate set to 0.046345
0:	learn: 6.2570864	total: 50.5ms	remaining: 50.4s
1:	learn: 6.1622966	total: 52.5ms	remaining: 26.2s
2:	learn: 6.0761080	total: 54.2ms	remaining: 18s
3:	learn: 5.9945680	total: 55.4ms	remaining: 13.8s
4:	learn: 5.9187205	total: 68.2ms	remaining: 13.6s
5:	learn: 5.8494967	total: 69.3ms	remaining: 11.5s
6:	learn: 5.7847737	total: 76.4ms	remaining: 10.8s
7:	learn: 5.7251342	total: 80.2ms	remaining: 9.94s
8:	learn: 5.6707384	total: 82ms	remaining: 9.03s
9:	learn: 5.6187378	total: 84.6ms	remaining: 8.37s
10:	learn: 5.5717320	total: 85.5ms	remaining: 7.69s
11:	learn: 5.5282375	total: 86.4ms	remaining: 7.12s
12:	learn: 5.4892303	total: 87.1ms	remaining: 6.62s
13:	learn: 5.4528052	total: 88ms	remaining: 6.2s
14:	learn: 5.4199984	total: 89ms	remaining: 5.84s
15:	learn: 5.3898340	total: 90.3ms	remaining: 5.55s
16:	learn: 5.3618117	total: 91.7ms	remaining: 5.3s
17:	learn: 5.3354751	total: 93ms	remaining: 5.07s
18:	learn: 5.3117364	total: 96.1ms	remaining: 4.96s
19:	

In [None]:
# Same data and lag as the baseline run, but with depth=8 forwarded to
# CatBoostRegressor through evaluate's **kwargs.
# NOTE(review): this cell has no execution count (In [None]) while its output
# is still stored — re-run the notebook top to bottom before trusting it.
seed = 42
columns = ["electricity"]
train, validation, test = load_splits(seed, columns)
evaluate(time="00:00", delays=1, train=train, validation=validation, depth=8)

2024-12-17 22:15:19,562 - INFO - Setting numpy seed to: 42
2024-12-17 22:15:19,681 - INFO - Shape: (280512, 6) | Start: 2016-01-01 00:00:00 | End: 2023-12-31 23:45:00
2024-12-17 22:15:19,682 - INFO - Columns: ['electricity', 'wind_speed', 'wind_speed_no_seasonality', 'daily_seasonality', 'weekly_seasonality', 'yearly_seasonality']


Unnamed: 0_level_0,electricity,wind_speed,wind_speed_no_seasonality,daily_seasonality,weekly_seasonality,yearly_seasonality
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
2016-01-01 00:00:00,6.48,8.1,8.99,0.0,-0.43,0.02
2016-01-01 00:15:00,4.48,6.49,7.72,0.0,-0.43,0.02
2016-01-01 00:30:00,6.55,8.18,9.3,0.0,-0.43,0.02


2024-12-17 22:15:22,444 - INFO - Frequency change: 15min -> 1d
2024-12-17 22:15:22,445 - INFO - Shape change: (280512, 1) -> (2922, 96)


Unnamed: 0_level_0,electricity_00_00,electricity_00_15,electricity_00_30,electricity_00_45,electricity_01_00,electricity_01_15,electricity_01_30,electricity_01_45,electricity_02_00,electricity_02_15,...,electricity_21_30,electricity_21_45,electricity_22_00,electricity_22_15,electricity_22_30,electricity_22_45,electricity_23_00,electricity_23_15,electricity_23_30,electricity_23_45
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2016-01-01,6.48,4.48,6.55,12.0,4.92,4.38,11.35,7.93,6.13,9.99,...,4.07,4.51,1.8,5.49,4.46,3.81,2.39,3.44,2.66,0.1
2016-01-02,6.1,3.43,2.4,4.4,3.24,3.78,1.64,1.68,6.56,9.16,...,0.94,2.35,0.9,5.39,4.65,2.45,1.41,1.85,2.29,4.65
2016-01-03,4.68,4.61,4.12,3.79,3.37,6.92,5.25,2.99,7.21,7.51,...,2.79,0.99,2.55,4.59,5.08,2.08,8.99,1.33,2.72,5.55


2024-12-17 22:15:22,573 - INFO - # of training observations: 2192 | 75.02%
2024-12-17 22:15:22,588 - INFO - # of validation observations: 365 | 12.49%
2024-12-17 22:15:22,590 - INFO - # of test observations: 365 | 12.49%


Learning rate set to 0.046345
0:	learn: 6.2567793	total: 4.18ms	remaining: 4.17s
1:	learn: 6.1613222	total: 8.41ms	remaining: 4.2s
2:	learn: 6.0720175	total: 11.7ms	remaining: 3.9s
3:	learn: 5.9893592	total: 15.5ms	remaining: 3.87s
4:	learn: 5.9147267	total: 17.7ms	remaining: 3.52s
5:	learn: 5.8464353	total: 20.1ms	remaining: 3.33s
6:	learn: 5.7805500	total: 24.5ms	remaining: 3.47s
7:	learn: 5.7202339	total: 27.4ms	remaining: 3.4s
8:	learn: 5.6651122	total: 29.8ms	remaining: 3.28s
9:	learn: 5.6145434	total: 31.8ms	remaining: 3.15s
10:	learn: 5.5692125	total: 32.6ms	remaining: 2.93s
11:	learn: 5.5251459	total: 34.7ms	remaining: 2.85s
12:	learn: 5.4858858	total: 36.7ms	remaining: 2.78s
13:	learn: 5.4495140	total: 39.2ms	remaining: 2.76s
14:	learn: 5.4161770	total: 41.8ms	remaining: 2.75s
15:	learn: 5.3846414	total: 44.3ms	remaining: 2.72s
16:	learn: 5.3561330	total: 46.1ms	remaining: 2.67s
17:	learn: 5.3303510	total: 47.5ms	remaining: 2.59s
18:	learn: 5.3056357	total: 49.8ms	remaining: 2

In [12]:
# Adds wind_speed as an extra input. evaluate() uses the right-most column as
# the target, so "electricity" is listed last to remain the predicted value.
# NOTE(review): execution count 12 is out of sequence with the cells above —
# re-run the notebook top to bottom to confirm the stored output.
seed = 42
columns = ["wind_speed", "electricity"]
train, validation, test = load_splits(seed, columns)
evaluate(time="00:00", delays=1, train=train, validation=validation)


2024-12-17 22:13:20,891 - INFO - Setting numpy seed to: 42
2024-12-17 22:13:21,008 - INFO - Shape: (280512, 6) | Start: 2016-01-01 00:00:00 | End: 2023-12-31 23:45:00
2024-12-17 22:13:21,009 - INFO - Columns: ['electricity', 'wind_speed', 'wind_speed_no_seasonality', 'daily_seasonality', 'weekly_seasonality', 'yearly_seasonality']


Unnamed: 0_level_0,electricity,wind_speed,wind_speed_no_seasonality,daily_seasonality,weekly_seasonality,yearly_seasonality
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
2016-01-01 00:00:00,6.48,8.1,8.99,0.0,-0.43,0.02
2016-01-01 00:15:00,4.48,6.49,7.72,0.0,-0.43,0.02
2016-01-01 00:30:00,6.55,8.18,9.3,0.0,-0.43,0.02


2024-12-17 22:13:23,320 - INFO - Frequency change: 15min -> 1d
2024-12-17 22:13:23,321 - INFO - Shape change: (280512, 2) -> (2922, 192)


Unnamed: 0_level_0,wind_speed_00_00,wind_speed_00_15,wind_speed_00_30,wind_speed_00_45,wind_speed_01_00,wind_speed_01_15,wind_speed_01_30,wind_speed_01_45,wind_speed_02_00,wind_speed_02_15,...,electricity_21_30,electricity_21_45,electricity_22_00,electricity_22_15,electricity_22_30,electricity_22_45,electricity_23_00,electricity_23_15,electricity_23_30,electricity_23_45
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2016-01-01,8.1,6.49,8.18,9.82,7.6,6.65,11.3,8.96,8.49,9.56,...,4.07,4.51,1.8,5.49,4.46,3.81,2.39,3.44,2.66,0.1
2016-01-02,6.24,6.31,6.15,5.61,3.92,6.44,5.86,4.3,6.96,9.56,...,0.94,2.35,0.9,5.39,4.65,2.45,1.41,1.85,2.29,4.65
2016-01-03,7.5,4.77,6.88,8.17,5.88,7.67,7.21,4.95,8.36,8.88,...,2.79,0.99,2.55,4.59,5.08,2.08,8.99,1.33,2.72,5.55


2024-12-17 22:13:23,359 - INFO - # of training observations: 2192 | 75.02%
2024-12-17 22:13:23,359 - INFO - # of validation observations: 365 | 12.49%
2024-12-17 22:13:23,360 - INFO - # of test observations: 365 | 12.49%


Learning rate set to 0.046345
0:	learn: 6.0969690	total: 1.3ms	remaining: 1.3s
1:	learn: 5.8483750	total: 2.44ms	remaining: 1.22s
2:	learn: 5.6088837	total: 3.62ms	remaining: 1.2s
3:	learn: 5.3822463	total: 5.08ms	remaining: 1.27s
4:	learn: 5.1734169	total: 7.13ms	remaining: 1.42s
5:	learn: 4.9654874	total: 8.6ms	remaining: 1.42s
6:	learn: 4.7658895	total: 9.74ms	remaining: 1.38s
7:	learn: 4.5779834	total: 10.9ms	remaining: 1.35s
8:	learn: 4.3957306	total: 12.7ms	remaining: 1.4s
9:	learn: 4.2193995	total: 14ms	remaining: 1.39s
10:	learn: 4.0548318	total: 15.4ms	remaining: 1.39s
11:	learn: 3.9030634	total: 16.8ms	remaining: 1.39s
12:	learn: 3.7544359	total: 19.5ms	remaining: 1.48s
13:	learn: 3.6135540	total: 20.8ms	remaining: 1.47s
14:	learn: 3.4808709	total: 21.8ms	remaining: 1.43s
15:	learn: 3.3503013	total: 22.8ms	remaining: 1.4s
16:	learn: 3.2286237	total: 23.9ms	remaining: 1.38s
17:	learn: 3.1104316	total: 24.9ms	remaining: 1.36s
18:	learn: 2.9957887	total: 26ms	remaining: 1.34s
19