In [None]:
import gc
import os
import pathlib
import tempfile
from datetime import date

import hvplot
import numpy as np
import polars as pl
from bokeh.models import DatetimeTickFormatter

import temporal_fusion_transformer as tft
from temporal_fusion_transformer.src.datasets import electricity
import sys


def reload_tft_module():
    """Forget the cached ``temporal_fusion_transformer`` package and import it again.

    The notebook-level name ``tft`` is rebound to the freshly imported package.
    Every cached submodule (``temporal_fusion_transformer.*``) is dropped as well,
    because each one lives under its own key in ``sys.modules`` and would otherwise
    be handed back unchanged by the new import.

    Names that were bound with ``from temporal_fusion_transformer... import ...``
    (``electricity`` in this notebook) and objects created from the old modules
    are not touched; re-import / re-create them after calling this function.
    """
    # Without this declaration the import below would only bind a function-local
    # name and the notebook's `tft` would keep pointing at the stale package.
    global tft

    package = "temporal_fusion_transformer"
    # Snapshot the keys first: entries are deleted while walking the cache.
    for cached_name in list(sys.modules):
        if cached_name == package or cached_name.startswith(package + "."):
            del sys.modules[cached_name]

    import temporal_fusion_transformer as tft


# Bokeh tick formatter that renders month-scale ticks with the strftime pattern "%b %Y".
# It is not referenced again in the cells visible here.
xformatter = DatetimeTickFormatter(months="%b %Y")
# Select bokeh as the hvplot plotting backend.
hvplot.extension("bokeh")

In [None]:
# Rows whose timestamp year is earlier than this are filtered out further down.
LEFT_CUTOFF_YEAR = 2013
# Step counts are written as <days> * 24 so their origin stays visible.
ENCODER_STEPS = 7 * 24  # 168
# Window length handed to `tft.utils.timeseries_from_array`.
TOTAL_TIME_STEPS = 8 * 24  # 192

In [None]:
def convert_to_parquet(download_dir: str):
    """Convert ``LD2011_2014.txt`` in ``download_dir`` into ``LD2011_2014.parquet``.

    The text file uses ``;`` as field separator and ``,`` as decimal mark; both are
    rewritten (``,`` -> ``.``, then ``;`` -> ``,``) into a temporary CSV that polars
    scans and sinks to parquet. The unnamed first column is renamed to
    ``timestamp``. After a successful conversion the source ``.txt`` is deleted.

    If the parquet file already exists, a notice is printed and nothing else happens.

    Args:
        download_dir: Directory that holds ``LD2011_2014.txt`` and receives the
            parquet file.

    Raises:
        FileNotFoundError: If neither the parquet file nor the text file exists.
    """
    target_dir = pathlib.Path(download_dir)
    parquet_path = target_dir / "LD2011_2014.parquet"
    txt_path = target_dir / "LD2011_2014.txt"
    # Sink into a sibling file first: the existence check below treats any
    # LD2011_2014.parquet as complete, so a half-written file must never get that name.
    partial_path = target_dir / "LD2011_2014.parquet.partial"

    if parquet_path.is_file():
        print("Found LD2011_2014.parquet, will re-use it.")
        return

    with tempfile.TemporaryDirectory() as tmpdir:
        csv_path = pathlib.Path(tmpdir) / "LD2011_2014.csv"

        # Stream line by line instead of holding the whole file (plus two
        # replaced copies) in memory; both replacements are per-character, so
        # the output is the same as converting the full text at once.
        with open(txt_path) as src, open(csv_path, "w") as dst:
            for line in src:
                dst.write(line.replace(",", ".").replace(";", ","))

        pl.scan_csv(
            csv_path, infer_schema_length=999999, try_parse_dates=True
        ).rename({"": "timestamp"}).sink_parquet(partial_path)

    # Same directory, so the rename does not cross file systems.
    os.replace(partial_path, parquet_path)
    os.remove(txt_path)


# Run the conversion; when the parquet file already exists this only prints a notice.
convert_to_parquet("../data/electricity")

In [None]:
# Load the converted dataset eagerly and preview its first rows.
raw_df = pl.read_parquet("../data/electricity/LD2011_2014.parquet")
raw_df.head()

In [None]:
# Summary statistics of the timestamp column, with the percentile rows switched off.
raw_df.select("timestamp").describe(percentiles=None)

In [None]:
def format_raw_df(dataframe: pl.DataFrame) -> pl.DataFrame:
    """Turn the wide frame (one column per series) into a long, hourly frame.

    Every column after the first is treated as one time series. Each series is
    sorted by time, averaged over 1-hour windows, cast to ``Float32`` and
    decorated with calendar columns derived from the timestamp.

    Args:
        dataframe: Wide frame with a ``timestamp`` column in first position,
            followed by one column per series.

    Returns:
        A frame with the columns ``id``, ``ts``, ``year``, ``month``, ``day``,
        ``day_of_week``, ``hour`` and ``y``, where ``id`` is the original column name.
    """
    wide_lf = dataframe.rename({"timestamp": "ts"}).lazy()

    def hourly_series(series_id):
        """Build the lazy long-format frame for a single series column."""
        hourly = (
            wide_lf.select("ts", series_id)
            .rename({series_id: "y"})
            # group_by_dynamic is applied to time-sorted data
            .sort("ts")
            # down sample to 1h https://pola-rs.github.io/polars-book/user-guide/transformations/time-series/rolling/
            .group_by_dynamic("ts", every="1h")
            .agg(pl.col("y").mean())
        )
        typed_columns = [
            pl.col("y").cast(pl.Float32),
            pl.col("ts").dt.year().cast(pl.UInt16).alias("year"),
            pl.col("ts").dt.month().cast(pl.UInt8).alias("month"),
            pl.col("ts").dt.hour().cast(pl.UInt8).alias("hour"),
            pl.col("ts").dt.day().cast(pl.UInt8).alias("day"),
            pl.col("ts").dt.weekday().cast(pl.UInt8).alias("day_of_week"),
        ]
        return hourly.with_columns(typed_columns, id=pl.lit(series_id))

    per_series_plans = [hourly_series(column) for column in dataframe.columns[1:]]

    # Materialise all per-series plans together, then stack them into one frame.
    stacked = pl.concat(pl.collect_all(per_series_plans))
    stacked = stacked.shrink_to_fit(in_place=True).rechunk()
    return stacked.select("id", "ts", "year", "month", "day", "day_of_week", "hour", "y")


# Reshape the wide raw table into the long hourly layout and preview it.
formatted_df = format_raw_df(raw_df)
formatted_df.head(10)

In [None]:
# Number of null values in each column of the long frame.
formatted_df.null_count()

In [None]:
# Summary statistics of the hourly timestamp column, with the percentile rows switched off.
formatted_df.select("ts").describe(percentiles=None)

In [None]:
# Peek at the series identifiers.
formatted_df.select("id").head()

In [None]:
# Plot the unfiltered history. The boundary gets its own name so that re-running
# this cell cannot overwrite the `validation_boundary` that the later
# `tft.utils.split_dataframe` cell reads.
# NOTE(review): 2015-06-01 is later than the 2014 end year the file name
# LD2011_2014 suggests; confirm the boundary is meant to fall after the data.
full_history_boundary = date(2015, 6, 1)
tft.utils.plot_split(formatted_df, full_history_boundary, groupby="id")

In [None]:
# Keep only rows whose timestamp year is LEFT_CUTOFF_YEAR or later.
filtered_df = formatted_df.filter(pl.col("ts").dt.year() >= LEFT_CUTOFF_YEAR)
# This boundary is the one handed to tft.utils.split_dataframe in a later cell.
validation_boundary = date(2014, 10, 1)
tft.utils.plot_split(filtered_df, validation_boundary, groupby="id", autorange="x")

In [None]:
# Create the electricity preprocessor (not fitted yet) and display it.
preprocessor = electricity.Preprocessor()
preprocessor

In [None]:
# Display the preprocessor's `target` attribute as it is before fitting.
preprocessor.target

In [None]:
# Fit on the filtered frame, then display the fitted preprocessor.
# NOTE(review): the fit sees rows on both sides of validation_boundary, because the
# train/test split only happens in a later cell — confirm this is intended.
preprocessor.fit(filtered_df)
preprocessor

In [None]:
# Apply the fitted preprocessor and preview the result.
processed_df = preprocessor.transform(filtered_df)
processed_df.head(10)

In [None]:
# Split the processed frame at validation_boundary and show the row count of each part.
training_df, test_df = tft.utils.split_dataframe(processed_df, validation_boundary)
len(training_df), len(test_df)

In [None]:
import mlx.core as mx


def make_time_series_array(dataframe: pl.DataFrame) -> "mx.array":
    """Build one stacked array of fixed-length windows covering every series.

    The frame is partitioned by ``id``. Each partition is converted with the
    notebook-level ``preprocessor.to_array`` and cut into windows of
    ``TOTAL_TIME_STEPS`` steps by ``tft.utils.timeseries_from_array``. The
    per-series results are then concatenated along axis 0.

    Args:
        dataframe: Processed frame containing an ``id`` column.

    Returns:
        The result of ``mx.concatenate`` over all per-series window arrays, i.e.
        an MLX array rather than a NumPy one.
    """
    windows_per_series = [
        tft.utils.timeseries_from_array(preprocessor.to_array(series_df), TOTAL_TIME_STEPS)
        # maintain_order=True makes the groups come out in order of first
        # appearance; without it polars gives no ordering guarantee, so the
        # layout of the stacked array could differ between runs.
        for _, series_df in dataframe.group_by(["id"], maintain_order=True)
    ]

    return mx.concatenate(windows_per_series, axis=0)


# Build the windowed arrays for both splits and show their shapes.
train_arr = make_time_series_array(training_df)
test_arr = make_time_series_array(test_df)

train_arr.shape, test_arr.shape