In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
import temporal_fusion_transformer as tft
from tqdm.notebook import tqdm
import polars as pl
from absl_extra import flax_utils

In [None]:
df = tft.datasets.electricity.read_parquet("../data/electricity")
df.head()

Formatting inputs:   0%|          | 0/370 [00:00<?, ?it/s]

In [4]:
training_df, validation_df, test_df = tft.datasets.electricity.split_data(df)

In [8]:
preprocessor = tft.datasets.electricity.train_preprocessor(df)

Training scalers:   0%|          | 0/370 [00:00<?, ?it/s]

Fitting label encoders:   0%|          | 0/4 [00:00<?, ?it/s]

In [9]:
preprocessor

{'real': {'MT_325': StandardScaler(),
  'MT_256': StandardScaler(),
  'MT_355': StandardScaler(),
  'MT_097': StandardScaler(),
  'MT_282': StandardScaler(),
  'MT_290': StandardScaler(),
  'MT_028': StandardScaler(),
  'MT_208': StandardScaler(),
  'MT_035': StandardScaler(),
  'MT_368': StandardScaler(),
  'MT_263': StandardScaler(),
  'MT_215': StandardScaler(),
  'MT_279': StandardScaler(),
  'MT_187': StandardScaler(),
  'MT_147': StandardScaler(),
  'MT_205': StandardScaler(),
  'MT_227': StandardScaler(),
  'MT_104': StandardScaler(),
  'MT_084': StandardScaler(),
  'MT_138': StandardScaler(),
  'MT_136': StandardScaler(),
  'MT_004': StandardScaler(),
  'MT_025': StandardScaler(),
  'MT_365': StandardScaler(),
  'MT_285': StandardScaler(),
  'MT_220': StandardScaler(),
  'MT_029': StandardScaler(),
  'MT_012': StandardScaler(),
  'MT_036': StandardScaler(),
  'MT_273': StandardScaler(),
  'MT_107': StandardScaler(),
  'MT_191': StandardScaler(),
  'MT_261': StandardScaler(),
  

In [14]:
targets = ("power_usage",)
real_inputs = ("year",)
categorical_inputs = ("month", "day", "hour", "day_of_week")

lf_list = []

for i, sub_df in tqdm(df.groupby("id"), total=370, desc="Applying scalers..."):
    sub_df: pl.DataFrame
    sub_lf: pl.LazyFrame = sub_df.lazy()

    x_real = df[real_inputs].to_numpy()
    x_target = df[targets].to_numpy()

    x_real = preprocessor["real"][i].transform(x_real)
    x_targets = preprocessor["target"][i].transform(x_target)

    sub_lf = sub_lf.with_columns(
        [pl.lit(i).alias(j).cast(pl.Float32) for i, j in zip(x_real, real_inputs)]
    ).with_columns(pl.lit(i).alias(j).cast(pl.Float32) for i, j in zip(x_target, targets))
    lf_list.append(sub_lf)

df = pl.concat(lf_list).collect()

for i in tqdm(categorical_inputs):
    x = df[i].to_numpy()
    x = preprocessor["categorical"][i].transform(x)
    df = df.drop(i).with_columns(pl.lit(x).alias(i).cast(pl.Int8))

df

Applying scalers...:   0%|          | 0/370 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

timestamp,power_usage,year,id,month,day,hour,day_of_week
datetime[μs],f32,f32,str,i8,i8,i8,i8
2014-01-01 00:00:00,2.538071,0.0,"""MT_038""",0,0,0,2
2014-01-01 01:00:00,2.538071,0.0,"""MT_038""",0,0,1,2
2014-01-01 02:00:00,2.538071,0.0,"""MT_038""",0,0,2,2
2014-01-01 03:00:00,2.538071,0.0,"""MT_038""",0,0,3,2
2014-01-01 04:00:00,2.538071,0.0,"""MT_038""",0,0,4,2
2014-01-01 05:00:00,2.538071,0.0,"""MT_038""",0,0,5,2
2014-01-01 06:00:00,2.538071,0.0,"""MT_038""",0,0,6,2
2014-01-01 07:00:00,2.538071,0.0,"""MT_038""",0,0,7,2
2014-01-01 08:00:00,2.538071,0.0,"""MT_038""",0,0,8,2
2014-01-01 09:00:00,2.538071,0.0,"""MT_038""",0,0,9,2
