In [1]:
import gc
import os
import pathlib
import tempfile
from datetime import date

import hvplot
import numpy as np
import polars as pl
from bokeh.models import DatetimeTickFormatter

import temporal_fusion_transformer as tft
from temporal_fusion_transformer.src.datasets import electricity
import sys


def reload_tft_module():
    del sys.modules['temporal_fusion_transformer']
    import temporal_fusion_transformer as tft
    

xformatter = DatetimeTickFormatter(months="%b %Y")
hvplot.extension("bokeh")

In [2]:
LEFT_CUTOFF_YEAR = 2013
# 7 * 24
ENCODER_STEPS = 168  
 # 8 * 24
TOTAL_TIME_STEPS = 192  

In [3]:
def convert_to_parquet(download_dir: str):
    if pathlib.Path(f"{download_dir}/LD2011_2014.parquet").is_file():
        print("Found LD2011_2014.parquet, will re-use it.")
        return
    
    with open(f"{download_dir}/LD2011_2014.txt") as file:
        txt_content = file.read()
    
    csv_content = txt_content.replace(",", ".").replace(";", ",")
    
    with tempfile.TemporaryDirectory() as tmpdir:
        with open(f"{tmpdir}/LD2011_2014.csv", "w+") as file:
            file.write(csv_content)
        
        pl.scan_csv(
            f"{tmpdir}/LD2011_2014.csv", infer_schema_length=999999, try_parse_dates=True
        ).rename({"": "timestamp"}).sink_parquet(f"{download_dir}/LD2011_2014.parquet")
        
        os.remove(f"{download_dir}/LD2011_2014.txt")


convert_to_parquet('../data/electricity')

Found LD2011_2014.parquet, will re-use it.


In [4]:
raw_df = pl.read_parquet('../data/electricity/LD2011_2014.parquet')
raw_df.head()

timestamp,MT_001,MT_002,MT_003,MT_004,MT_005,MT_006,MT_007,MT_008,MT_009,MT_010,MT_011,MT_012,MT_013,MT_014,MT_015,MT_016,MT_017,MT_018,MT_019,MT_020,MT_021,MT_022,MT_023,MT_024,MT_025,MT_026,MT_027,MT_028,MT_029,MT_030,MT_031,MT_032,MT_033,MT_034,MT_035,MT_036,…,MT_334,MT_335,MT_336,MT_337,MT_338,MT_339,MT_340,MT_341,MT_342,MT_343,MT_344,MT_345,MT_346,MT_347,MT_348,MT_349,MT_350,MT_351,MT_352,MT_353,MT_354,MT_355,MT_356,MT_357,MT_358,MT_359,MT_360,MT_361,MT_362,MT_363,MT_364,MT_365,MT_366,MT_367,MT_368,MT_369,MT_370
datetime[μs],f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,…,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
2011-01-01 00:15:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,…,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2011-01-01 00:30:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,…,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2011-01-01 00:45:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,…,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2011-01-01 01:00:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,…,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2011-01-01 01:15:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,…,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [5]:
raw_df.select('timestamp').describe(percentiles=None)

statistic,timestamp
str,str
"""count""","""140256"""
"""null_count""","""0"""
"""mean""","""2012-12-31 12:…"
"""std""",
"""min""","""2011-01-01 00:…"
"""max""","""2015-01-01 00:…"


In [6]:
def format_raw_df(dataframe: pl.DataFrame) -> pl.DataFrame:
    timeseries_ids = dataframe.columns[1:]
    
    lf = dataframe.rename({"timestamp": "ts"}).lazy()
    lf_list = []
    
    for label in timeseries_ids:
        sub_lf = lf.select("ts", label)
        sub_lf = (
            sub_lf
            .rename({label: "y"})
            # down sample to 1h https://pola-rs.github.io/polars-book/user-guide/transformations/time-series/rolling/
            .sort("ts")
            .group_by_dynamic("ts", every="1h")
            .agg(pl.col('y').mean())
            .with_columns(
                [
                    pl.col("y").cast(pl.Float32),
                    pl.col("ts").dt.year().alias("year").cast(pl.UInt16),
                    pl.col("ts").dt.month().alias("month").cast(pl.UInt8),
                    pl.col("ts").dt.hour().alias("hour").cast(pl.UInt8),
                    pl.col("ts").dt.day().alias("day").cast(pl.UInt8),
                    pl.col("ts").dt.weekday().alias("day_of_week").cast(pl.UInt8),
                ],
                id=pl.lit(label),
            )
        )
        lf_list.append(sub_lf)
    
    df = pl.concat(pl.collect_all(lf_list)).shrink_to_fit(in_place=True).rechunk()
    return df.select('id', 'ts', 'year', 'month', 'day',  'day_of_week', 'hour','y')


formatted_df = format_raw_df(raw_df)
formatted_df.head(10)

id,ts,year,month,day,day_of_week,hour,y
str,datetime[μs],u16,u8,u8,u8,u8,f32
"""MT_001""",2011-01-01 00:00:00,2011,1,1,6,0,0.0
"""MT_001""",2011-01-01 01:00:00,2011,1,1,6,1,0.0
"""MT_001""",2011-01-01 02:00:00,2011,1,1,6,2,0.0
"""MT_001""",2011-01-01 03:00:00,2011,1,1,6,3,0.0
"""MT_001""",2011-01-01 04:00:00,2011,1,1,6,4,0.0
"""MT_001""",2011-01-01 05:00:00,2011,1,1,6,5,0.0
"""MT_001""",2011-01-01 06:00:00,2011,1,1,6,6,0.0
"""MT_001""",2011-01-01 07:00:00,2011,1,1,6,7,0.0
"""MT_001""",2011-01-01 08:00:00,2011,1,1,6,8,0.0
"""MT_001""",2011-01-01 09:00:00,2011,1,1,6,9,0.0


In [7]:
formatted_df.null_count()

id,ts,year,month,day,day_of_week,hour,y
u32,u32,u32,u32,u32,u32,u32,u32
0,0,0,0,0,0,0,0


In [8]:
formatted_df.select('ts').describe(percentiles=None)

statistic,ts
str,str
"""count""","""12974050"""
"""null_count""","""0"""
"""mean""","""2012-12-31 11:…"
"""std""",
"""min""","""2011-01-01 00:…"
"""max""","""2015-01-01 00:…"


In [9]:
formatted_df.select('id').head()

id
str
"""MT_001"""
"""MT_001"""
"""MT_001"""
"""MT_001"""
"""MT_001"""


In [10]:
validation_boundary = date(2015, 6, 1)
tft.utils.plot_split(formatted_df, validation_boundary, groupby='id')

In [11]:
filtered_df = formatted_df.filter(pl.col("ts").dt.year() >= LEFT_CUTOFF_YEAR)
validation_boundary = date(2014, 10, 1)
tft.utils.plot_split(filtered_df, validation_boundary, groupby='id', autorange="x")

In [12]:
preprocessor = electricity.Preprocessor()
preprocessor

{'year': "defaultdict(<class 'sklearn.preprocessing._data.StandardScaler'>, {})", 'target': "defaultdict(<class 'sklearn.preprocessing._data.StandardScaler'>, {})", 'categorical': "defaultdict(<class 'sklearn.preprocessing._label.LabelEncoder'>, {})"}

In [13]:
preprocessor.target

defaultdict(sklearn.preprocessing._data.StandardScaler, {})

In [14]:
preprocessor.fit(filtered_df)
preprocessor

{'target': "{'MT_296': StandardScaler(), 'MT_064': StandardScaler(), 'MT_310': StandardScaler(), 'MT_354': StandardScaler(), 'MT_227': StandardScaler(), 'MT_058': StandardScaler(), 'MT_134': StandardScaler(), 'MT_240': StandardScaler(), 'MT_287': StandardScaler(), 'MT_346': StandardScaler(), 'MT_083': StandardScaler(), 'MT_106': StandardScaler(), 'MT_273': StandardScaler(), 'MT_158': StandardScaler(), 'MT_039': StandardScaler(), 'MT_225': StandardScaler(), 'MT_048': StandardScaler(), 'MT_149': StandardScaler(), 'MT_315': StandardScaler(), 'MT_138': StandardScaler(), 'MT_129': StandardScaler(), 'MT_325': StandardScaler(), 'MT_012': StandardScaler(), 'MT_097': StandardScaler(), 'MT_150': StandardScaler(), 'MT_345': StandardScaler(), 'MT_066': StandardScaler(), 'MT_071': StandardScaler(), 'MT_023': StandardScaler(), 'MT_366': StandardScaler(), 'MT_300': StandardScaler(), 'MT_274': StandardScaler(), 'MT_351': StandardScaler(), 'MT_177': StandardScaler(), 'MT_087': StandardScaler(), 'MT_368

In [15]:
processed_df = preprocessor.transform(filtered_df)
processed_df.head(10)

id,ts,year,month,day,day_of_week,hour,y
i64,datetime[μs],f64,i64,i64,i64,i64,f32
181,2013-01-01 00:00:00,-0.999943,0,0,1,0,-0.162965
181,2013-01-01 01:00:00,-0.999943,0,0,1,1,-0.162776
181,2013-01-01 02:00:00,-0.999943,0,0,1,2,-0.164003
181,2013-01-01 03:00:00,-0.999943,0,0,1,3,-0.16287
181,2013-01-01 04:00:00,-0.999943,0,0,1,4,-0.16287
181,2013-01-01 05:00:00,-0.999943,0,0,1,5,-0.163483
181,2013-01-01 06:00:00,-0.999943,0,0,1,6,-0.163484
181,2013-01-01 07:00:00,-0.999943,0,0,1,7,-0.162681
181,2013-01-01 08:00:00,-0.999943,0,0,1,8,-0.161689
181,2013-01-01 09:00:00,-0.999943,0,0,1,9,-0.1615


In [16]:
training_df, test_df = tft.utils.split_dataframe(processed_df, validation_boundary)
len(training_df), len(test_df)

(5665440, 817330)

In [1]:
import mlx.core as mx


def make_time_series_array(dataframe: pl.DataFrame) -> np.ndarray:
    ts_list = []
    for _, dataframe_i in dataframe.group_by(['id']):
        ts_i = tft.utils.timeseries_from_array(preprocessor.to_array(dataframe_i), TOTAL_TIME_STEPS)
        ts_list.append(ts_i)
    
    
    return mx.concatenate(ts_list, axis=0)   
        
        
        
train_arr = make_time_series_array(training_df)
test_arr = make_time_series_array(test_df)

train_arr.shape, test_arr.shape

NameError: name 'np' is not defined

In [2]:
gc.collect()

NameError: name 'gc' is not defined