<div style="display:fill;
           background-color:#FFFFE0;
           letter-spacing:0.5px;border-bottom: 2px solid black;">
<img src="https://raw.githubusercontent.com/IqmanS/Machine-Learning-Notebooks/main/enefit/enefit-banner.jpg">
<H2 style="padding: 20px; color:black; font-weight:600; font-family: 'Garamond', 'Lucida Sans', sans-serif; text-align: center; font-size: 38px;">♾️ Pipeline ⇆ Baseline ♾️ </H2>
</div>


In [1]:
import os, sys, shutil, psutil

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.chdir('/kaggle/input/enefit-artefacts')
os.getcwd()

'/kaggle/input/enefit-artefacts'

In [2]:
def cpu_stats():
    pid = os.getpid()
    py = psutil.Process(pid)
    memory_use = py.memory_info()[0] / 2. ** 30
    return 'Memory:' + str(np.round(memory_use, 2)) + ' GB'

In [3]:
from IPython.display import Image, display, clear_output, HTML
from ipywidgets import Select, interact as interactive_cell
import warnings
warnings.filterwarnings("ignore")

In [4]:
%%time

!pip install "einops==0.4.1"
!pip install "lightning"

clear_output()

CPU times: user 410 ms, sys: 91.7 ms, total: 502 ms
Wall time: 27.9 s


In [5]:
%%time
import json, yaml, pickle, joblib
import base64
import re
from tqdm import tqdm
from typing import Dict, Iterable, List, Tuple, Optional, Union
from datetime import datetime
from itertools import product, combinations
from functools import reduce, partial
from operator import concat
from pathlib import Path

import math
import numpy as np
import pandas as pd
import polars as pl
import dask as dk

pl_cfg = pl.Config()
pl_cfg.set_tbl_rows(7)

pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)

CPU times: user 551 ms, sys: 105 ms, total: 656 ms
Wall time: 801 ms


In [6]:
def flatten_nested_list(arr):
    return reduce(concat, arr)

In [7]:
def display_pl_expressions(expr):
    expr_litr = [
        {
            'tables': E['tables'],
            'expressions': [str(e) for e in E['expressions']],
        } for E in expr
    ]
    print(json.dumps(expr_litr, indent=4))

In [8]:
%%time

from sklearn import set_config
from sklearn.metrics import mean_absolute_error as MAE, mean_squared_error as MSE
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import GridSearchCV
from sklearn.inspection import permutation_importance, partial_dependence

set_config(display="diagram")

CPU times: user 581 ms, sys: 271 ms, total: 851 ms
Wall time: 766 ms


In [9]:
def make_flowchart(graph: str):
    graph_bytes = graph.encode("ascii")
    base64_bytes = base64.b64encode(graph_bytes)
    base64_string = base64_bytes.decode("ascii")
    display(
        Image(url="https://mermaid.ink/img/" + base64_string)
    )

In [10]:
month_mapping = {
    'jan': 'January',
    'feb': 'February',
    'mrt': 'March',
    'mar': 'March',
    'apr': 'April',
    'mei': 'May',
    'jun': 'June',
    'aug': 'August',
    'dec': 'December'
}

def convert_month_abbr(date):    
    parts = date.split()
    if len(parts) == 2:
        if parts[1].isdigit():
            month, day = parts
        else:
            day, month = parts
        if month.lower() in month_mapping:
            return f"{day} {month_mapping[month.lower()]}"
    return pd.NaT

def get_holidays(years: Union[int, List[int]], country: str):
    if isinstance(years, List) is False:
        years = [years]
    
    DF = pd.DataFrame()
    pbar = tqdm(years)
    for year in pbar:
        pbar.set_description(f'Year {year}')
        
        # Query holidays from website
        try:
            url = f'https://www.timeanddate.com/holidays/{country}/{year}?hol=1'
            df = pd.read_html(url)[0]
        except:
            continue

        # Flatten the MultiIndex columns if present
        df.columns = df.columns.get_level_values(0)

        # Select the desired columns
        df = df[['Date', 'Name', 'Type']]

        # Filter for national holidays only
        # df = df[df['Type'] == 'National holiday']
        df = df[~df['Type'].isna()]

        # Add a 'Year' column to differentiate between years
        df['Year'] = year
        
        # Aggregations
        DF = pd.concat([DF, df], ignore_index=True)
    
    # Convert to datetime
    DF['Date'] = DF[ 'Date'].apply(convert_month_abbr)
    DF.dropna(subset='Date', axis=0, inplace=True)
    DF['date'] = DF[['Date','Year']].apply(lambda x: f"{x.Date} {x.Year}", axis=1)
    DF['date'] = DF[ 'date'].apply(pd.to_datetime, errors='coerce').dt.date
    DF.dropna(subset='date', axis=0, inplace=True)

    return DF[['date', 'Name', 'Type']]


holidays_fpath = '/kaggle/input/enefit-artefacts/data/holidays.csv'
if os.path.isfile(holidays_fpath):
    holidays = pd.read_csv(holidays_fpath, parse_dates=[0])
    holidays['date'] = holidays['date'].dt.date
else:
    holidays = get_holidays(list(range(2020, 2050)), country='estonia')
    holidays.to_csv(holidays_fpath, index=False)

holidays = pl.from_pandas(holidays)
display(holidays['Type'].value_counts())
display(holidays['Name'].value_counts())

Type,counts
str,u32
"""National holid…",218


Name,counts
str,u32
"""Good Friday""",21
"""Midsummer Day""",21
"""Christmas Day""",21
…,…
"""Boxing Day""",21
"""Pentecost""",8
"""Victory Day""",21


<div>
    <h1 style="font-family:  'Garamond', 'Lucida Sans', sans-serif; text-align: center; color: #263A29; font-weight: bold; font-size: 36px;">
    Data Preprocessing
    </h1>
</div>
<hr>

# [Data & Insights](https://www.kaggle.com/code/mrriandmstique/enefit-data-and-insights)
# [Schema & Features](https://www.kaggle.com/code/mrriandmstique/enefit-schema-and-features)

# Data Schema

In [11]:
# https://mermaid.live/edit
# https://fontawesome.com/v4/icons/

make_flowchart("""
    flowchart TD
        F -->|datetime| A(target)
        I[fa:fa-fire-extinguisher gas_prices] -->|date| A
        J[fa:fa-bolt electricity_prices] -->|datetime| A
        K[fa:fa-calendar holidays] -->|date| A
        A -->|client_id| C[client]
        C -->|county| M[county]
        E(fa:fa-industry weather_station) -->|county| M
        F[fa:fa-sun-o weather_history] -->|station_id| E
        G[fa:fa-snowflake-o weather_forecast] -->|station_id,datetime| F
        B[fa:fa-users prosumer] -->|target| A
        L[fa:fa-database capacity] -->|client_id| C
        D[fa:fa-user-circle-o county_name] -->|county| M
""")

# Data Definition

In [12]:
data_root = Path("/kaggle/input/predict-energy-behavior-of-prosumers")
artefact_root = Path("/kaggle/input/enefit-artefacts/artefacts")

In [13]:
with open(artefact_root / 'data_definitions.yaml', 'r') as file:
    data_definitions = yaml.safe_load(file)
# data_definitions

# Data Loading

In [14]:
%%time

DataTables = {
    'holidays': holidays,
}

pbar = tqdm(data_definitions.items(), total=len(data_definitions))
for dname, ddef in pbar:
    pbar.set_description(dname)
    
    # Load data
    cols = flatten_nested_list([c for c in ddef['columns'].values()])
    parse_dt = True if ('Datetime' in list(ddef['columns'].keys())) else False
    table = pl.read_csv(source=data_root / ddef['filename'], 
                        columns=cols, try_parse_dates=parse_dt)
    
    # Standardize name
    if 'rename' in ddef.keys():
        table = table.rename(mapping=ddef['rename'])
        
    # Gather
    DataTables[dname] = table

weather_station_map: 100%|██████████| 7/7 [00:04<00:00,  1.44it/s]

CPU times: user 9.41 s, sys: 2.21 s, total: 11.6 s
Wall time: 4.87 s





In [15]:
def display_table(table_name, mode):
    if mode == 'statistic':
        with pl.Config(tbl_rows=10):
            display(DataTables[table_name].describe())
    else:
        display(DataTables[table_name])
    
dropbox_1 = Select(description='Table:', options=sorted(list(DataTables.keys())), rows=1,)
dropbox_2 = Select(description='Mode:', options=['sample','statistic',], rows=2,)

# interactive_cell(display_table, table_name=dropbox_1, mode=dropbox_2)

In [16]:
client_columns = ['county','product_type','is_business']
geoloc_columns = ['longitude','latitude',]
cloud_columns = [f'cloudcover_{x}' for x in ['low','mid','high','total']]

# Data Processing

#### ⚠️ Python-native functions take much longer time to process❗

> **Wall time**: 8-10 s (Numpy) >> 150 ms (Polars)
<br>`DataTables['forecasted_weather'].with_columns(
    pl.struct(["wind_eastward", "wind_northward"]).apply(
        lambda x: dict(zip(
            ('wind_speed','wind_direction'), 
            convert_east_north_to_magnitude_degree(x["wind_eastward"], 
                                                   x["wind_northward"])
        ))
    ).alias("result")
).unnest("result").with_columns([
    pl.col('wind_direction').radians().alias('wind_direction_rad')
]).lazy().select(pl.col('^wind_.*$')).collect()`

In [17]:
processes = [    
    {
             'tables' : ['targets'],
        'expressions' : [
                            pl.col("datetime").cast(pl.Date).alias("date")
                        ],
    },
    {
             'tables' : ['historical_weather'],
        'expressions' : [
                            (pl.when(pl.col(x) > 1)
                               .then(pl.col(x) / 100)
                          .otherwise(pl.col(x))).alias(x) for x in cloud_columns
                        ] + [pl.col('wind_direction').radians().alias('wind_direction')],
    },
    {
             'tables' : ['forecasted_weather'],
        'expressions' : [
                            ((pl.col('wind_eastward') ** 2 + pl.col('wind_northward') ** 2) ** 0.5).alias('wind_speed'),
                            (pl.arctan2('wind_eastward','wind_northward') - np.pi).alias('wind_direction'),
                        ],
    },
    {
             'tables' : ['historical_weather','forecasted_weather'],
        'expressions' : [
                            (pl.when(pl.col('wind_direction') <= 0)
                               .then(pl.col('wind_direction') + 2*np.pi)
                          .otherwise(pl.col('wind_direction'))).alias('wind_direction'),
                        ],
    },
]

# display_pl_expressions(processes)

pbar = tqdm(processes)
for proc_config in pbar:
    for table in proc_config['tables']:
        DataTables[table] = DataTables[table].with_columns(proc_config['expressions'])

100%|██████████| 4/4 [00:00<00:00, 11.95it/s]


In [18]:
%%time

groupby_columns = ['datetime','longitude','latitude']
ignore_columns = ['data_block_id','origin_datetime','wind_eastward','wind_northward']
metric_columns = [col for col in DataTables['forecasted_weather'].columns
                       if col not in groupby_columns + ignore_columns]

combine_funcs = []
for c in metric_columns:
    combine_funcs.extend([pl.col(c).last(), 
                          pl.col(c).mean().alias(f'mean_{c}'), 
                          pl.col(c).std().alias(f'stddv_{c}'),
                          pl.col(c).alias(f'array_{c}'), ])

DataTables['forecasted_weather'] = \
DataTables['forecasted_weather'].sort(groupby_columns)\
                            .group_by(groupby_columns).agg(*combine_funcs)\
                                .sort(groupby_columns)

DataTables['forecasted_weather'] = DataTables['forecasted_weather'].with_columns([
    pl.col(col).fill_null(pl.lit(0.)) 
       for col in DataTables['forecasted_weather'].columns
        if col.startswith('stddv_')
])

DataTables['forecasted_weather']

CPU times: user 11.9 s, sys: 4.39 s, total: 16.3 s
Wall time: 4.6 s


datetime,longitude,latitude,hours_ahead,mean_hours_ahead,stddv_hours_ahead,array_hours_ahead,temperature,mean_temperature,stddv_temperature,array_temperature,dewpoint,mean_dewpoint,stddv_dewpoint,array_dewpoint,cloudcover_high,mean_cloudcover_high,stddv_cloudcover_high,array_cloudcover_high,cloudcover_low,mean_cloudcover_low,stddv_cloudcover_low,array_cloudcover_low,cloudcover_mid,mean_cloudcover_mid,stddv_cloudcover_mid,array_cloudcover_mid,cloudcover_total,mean_cloudcover_total,stddv_cloudcover_total,array_cloudcover_total,direct_solar_radiation,mean_direct_solar_radiation,stddv_direct_solar_radiation,array_direct_solar_radiation,surface_solar_radiation,mean_surface_solar_radiation,stddv_surface_solar_radiation,array_surface_solar_radiation,snowfall,mean_snowfall,stddv_snowfall,array_snowfall,precipitation,mean_precipitation,stddv_precipitation,array_precipitation,wind_speed,mean_wind_speed,stddv_wind_speed,array_wind_speed,wind_direction,mean_wind_direction,stddv_wind_direction,array_wind_direction
datetime[μs],f64,f64,i64,f64,f64,list[i64],f64,f64,f64,list[f64],f64,f64,f64,list[f64],f64,f64,f64,list[f64],f64,f64,f64,list[f64],f64,f64,f64,list[f64],f64,f64,f64,list[f64],f64,f64,f64,list[f64],f64,f64,f64,list[f64],f64,f64,f64,list[f64],f64,f64,f64,list[f64],f64,f64,f64,list[f64],f64,f64,f64,list[f64]
2021-09-01 03:00:00,21.7,57.6,1,1.0,0.0,[1],15.655786,15.655786,0.0,[15.655786],11.553613,11.553613,0.0,[11.553613],0.904816,0.904816,0.0,[0.904816],0.019714,0.019714,0.0,[0.019714],0.0,0.0,0.0,[0.0],0.905899,0.905899,0.0,[0.905899],0.0,0.0,0.0,[0.0],0.0,0.0,0.0,[0.0],0.0,0.0,0.0,[0.0],0.0,0.0,0.0,[0.0],9.115422,9.115422,0.0,[9.115422],0.04514,0.04514,0.0,[0.04514]
2021-09-01 03:00:00,21.7,57.9,1,1.0,0.0,[1],16.050439,16.050439,0.0,[16.050439],12.355493,12.355493,0.0,[12.355493],0.886078,0.886078,0.0,[0.886078],0.051636,0.051636,0.0,[0.051636],0.000092,0.000092,0.0,[0.000092],0.889587,0.889587,0.0,[0.889587],0.0,0.0,0.0,[0.0],0.0,0.0,0.0,[0.0],0.0,0.0,0.0,[0.0],0.0,0.0,0.0,[0.0],10.107589,10.107589,0.0,[10.107589],6.240934,6.240934,0.0,[6.240934]
2021-09-01 03:00:00,21.7,58.2,1,1.0,0.0,[1],15.965112,15.965112,0.0,[15.965112],12.732202,12.732202,0.0,[12.732202],0.861237,0.861237,0.0,[0.861237],0.025238,0.025238,0.0,[0.025238],0.000244,0.000244,0.0,[0.000244],0.863724,0.863724,0.0,[0.863724],0.0,0.0,0.0,[0.0],0.0,0.0,0.0,[0.0],0.0,0.0,0.0,[0.0],0.0,0.0,0.0,[0.0],10.323137,10.323137,0.0,[10.323137],6.157255,6.157255,0.0,[6.157255]
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
2023-06-01 02:00:00,28.2,59.1,48,48.0,0.0,[48],8.720483,8.720483,0.0,[8.720483],3.83291,3.83291,0.0,[3.83291],0.349823,0.349823,0.0,[0.349823],0.283783,0.283783,0.0,[0.283783],0.121979,0.121979,0.0,[0.121979],0.541138,0.541138,0.0,[0.541138],0.0,0.0,0.0,[0.0],0.0,0.0,0.0,[0.0],0.0,0.0,0.0,[0.0],0.0,0.0,0.0,[0.0],3.494726,3.494726,0.0,[3.494726],5.04859,5.04859,0.0,[5.04859]
2023-06-01 02:00:00,28.2,59.4,48,48.0,0.0,[48],10.650171,10.650171,0.0,[10.650171],7.280054,7.280054,0.0,[7.280054],0.33313,0.33313,0.0,[0.33313],0.210815,0.210815,0.0,[0.210815],0.461761,0.461761,0.0,[0.461761],0.671173,0.671173,0.0,[0.671173],0.0,0.0,0.0,[0.0],0.0,0.0,0.0,[0.0],0.0,0.0,0.0,[0.0],0.0,0.0,0.0,[0.0],3.647102,3.647102,0.0,[3.647102],4.981714,4.981714,0.0,[4.981714]
2023-06-01 02:00:00,28.2,59.7,48,48.0,0.0,[48],11.233179,11.233179,0.0,[11.233179],7.069238,7.069238,0.0,[7.069238],0.703461,0.703461,0.0,[0.703461],0.292313,0.292313,0.0,[0.292313],0.934021,0.934021,0.0,[0.934021],0.989716,0.989716,0.0,[0.989716],0.0,0.0,0.0,[0.0],0.0,0.0,0.0,[0.0],0.0,0.0,0.0,[0.0],0.0,0.0,0.0,[0.0],2.872061,2.872061,0.0,[2.872061],5.039545,5.039545,0.0,[5.039545]


In [19]:
DataTables['gas_prices'] = \
DataTables['gas_prices'].sort(by=['forecast_date','origin_date'], descending=False)\
                  .unique(subset=['forecast_date'], keep='last')

In [20]:
DataTables['electricity_prices'] = \
DataTables['electricity_prices'].sort(by=['forecast_datetime','origin_datetime'], descending=False)\
                          .unique(subset=['forecast_datetime'], keep='last')

In [21]:
DataTables['holidays'] = DataTables['holidays'].with_columns([pl.lit(True).alias('is_holidays')])\
                                               .drop(columns=['Name','Type'])

<div>
    <h1 style="font-family:  'Garamond', 'Lucida Sans', sans-serif; text-align: center; color: #263A29; font-weight: bold; font-size: 36px;">
    Feature Engineering
    </h1>
</div>
<hr>

In [22]:
processes = [
    {
        "anchor" :  'targets',
       "support" :  'holidays',
           "how" :  'left',
            "on" : ['date'],
    },
    {
        "anchor" :  'targets',
       "support" :  'clients',
        "ignore" : ['data_block_id'],
           "how" :  'left',
            "on" : client_columns + ['date'],
    },
    {
        "anchor" :  'targets',
       "support" :  'gas_prices',
        "ignore" : ['data_block_id','origin_date'],
           "how" :  'left',
       "left_on" : ['date'],
      "right_on" : ['forecast_date'],
    },
    {
        "anchor" :  'targets',
       "support" :  'electricity_prices',
        "ignore" : ['data_block_id','origin_datetime'],
           "how" :  'left',
       "left_on" : ['datetime'],
      "right_on" : ['forecast_datetime'],
    },
]

pbar = tqdm(processes)
for proc_config in pbar:
    anchor = proc_config.pop('anchor')
    support = proc_config.pop('support')
    ignore = proc_config.pop('ignore', None)
    DataTables[anchor] = DataTables[anchor].join((DataTables[support] if not ignore else \
                                                  DataTables[support].drop(columns=ignore)), **proc_config)

100%|██████████| 4/4 [00:00<00:00, 10.16it/s]


In [23]:
processes = {
    "targets" : {
        "is_holidays"      : False,
           "holidays_name" : -1,
           "holidays_type" : -1,
    }
}

pbar = tqdm(processes.items())
for table, filler_config in pbar:
    null_fillers = [pl.col(k).fill_null(pl.lit(v)) for k,v in filler_config.items()]
    DataTables[table] = DataTables[table].with_columns(null_fillers)

100%|██████████| 1/1 [00:00<00:00, 30.10it/s]


In [24]:
# Check missing values

for col in ['eic_count','pv_capacity','min_gas_price','max_gas_price','electricity_price']:
    print(col, DataTables['targets'].filter(pl.col(col).is_null())['date'].to_pandas().unique())

eic_count [datetime.date(2023, 5, 30) datetime.date(2023, 5, 31)]
pv_capacity [datetime.date(2023, 5, 30) datetime.date(2023, 5, 31)]
min_gas_price [datetime.date(2023, 5, 31)]
max_gas_price [datetime.date(2023, 5, 31)]
electricity_price [datetime.date(2022, 3, 27) datetime.date(2023, 3, 26)
 datetime.date(2023, 5, 31)]


#### ⚠️ Remove dates 2023-05-30 & 31.❗

In [25]:
DataTables['targets'] = DataTables['targets'].filter(pl.col('date') < datetime(2023, 5, 30))

In [26]:
# Check missing values

for col in ['eic_count','pv_capacity','min_gas_price','max_gas_price','electricity_price']:
    print(col, DataTables['targets'].filter(pl.col(col).is_null())['date'].to_pandas().unique())

eic_count []
pv_capacity []
min_gas_price []
max_gas_price []
electricity_price [datetime.date(2022, 3, 27) datetime.date(2023, 3, 26)]


In [27]:
DataTables['targets'].filter(pl.col('electricity_price').is_null())['datetime'].to_pandas().unique()

<DatetimeArray>
['2022-03-27 02:00:00', '2023-03-26 02:00:00']
Length: 2, dtype: datetime64[ns]

### [Temporal-Feature Encoding](https://developer.nvidia.com/blog/three-approaches-to-encoding-time-information-as-features-for-ml-models/)

In [28]:
PERIODS = {
          "month"  :  12,
    "day_of_year"  : 366,
    "day_of_month" :  31,
    "day_of_week"  :   7,
           "hour"  :  24,
         "minute"  :  60,
         "second"  :  60,     
}

processes = [
    {
             'tables' : ['targets'],
        'expressions' : [
                            pl.col("datetime").dt.year().alias("year"),
                            pl.col("datetime").dt.month().alias("month"),
                            pl.col("datetime").dt.ordinal_day().alias("day_of_year"),
                            pl.col("datetime").dt.day().alias("day_of_month"),
                            pl.col("datetime").dt.weekday().alias("day_of_week"),
                           (pl.col("datetime").dt.weekday() > 5).alias("is_weekend"),
                            pl.col("datetime").dt.hour().alias("hour"),
                            pl.col("datetime").dt.minute().alias("minute"),
                            pl.col("datetime").dt.second().alias("second"),
                        ],
    },
    {
             'tables' : ['targets'],
        'expressions' : flatten_nested_list([
                            [(pl.col(col) * np.pi / PERIODS[col]).sin().alias(f"SIN({col})"),
                             (pl.col(col) * np.pi / PERIODS[col]).cos().alias(f"COS({col})"),] 
                                         for col in PERIODS.keys()
                        ])
    },
]

display_pl_expressions(processes)

pbar = tqdm(processes)
for proc_config in pbar:
    for table in proc_config['tables']:
        DataTables[table] = DataTables[table].with_columns(proc_config['expressions'])

[
    {
        "tables": [
            "targets"
        ],
        "expressions": [
            "col(\"datetime\").dt.year().alias(\"year\")",
            "col(\"datetime\").dt.month().alias(\"month\")",
            "col(\"datetime\").dt.ordinal_day().alias(\"day_of_year\")",
            "col(\"datetime\").dt.day().alias(\"day_of_month\")",
            "col(\"datetime\").dt.weekday().alias(\"day_of_week\")",
            "[(col(\"datetime\").dt.weekday()) > (5)].alias(\"is_weekend\")",
            "col(\"datetime\").dt.hour().alias(\"hour\")",
            "col(\"datetime\").dt.minute().alias(\"minute\")",
            "col(\"datetime\").dt.second().alias(\"second\")"
        ]
    },
    {
        "tables": [
            "targets"
        ],
        "expressions": [
            "[([(col(\"month\")) * (3.141593)]) / (12)].sin().alias(\"SIN(month)\")",
            "[([(col(\"month\")) * (3.141593)]) / (12)].cos().alias(\"COS(month)\")",
            "[([(col(\"day_of_year\")) * (3.141593)

100%|██████████| 2/2 [00:00<00:00,  4.50it/s]


### [Exponentially-Weighted Moving Average](https://pola-rs.github.io/polars/py-polars/html/reference/series/api/polars.Series.ewm_mean.html)

In [29]:
%%time

# Reduce daily targets, apply dayoff-effect transformations, then, merge back
dayoff_columns = ['is_weekend','is_holidays']
daily_targets = DataTables['targets'].select(dayoff_columns+['date']).unique()
daily_targets = daily_targets.with_columns([
    pl.col(col).ewm_mean(span=2, adjust=True).alias(f"EMA({col})") 
       for col in dayoff_columns
])

DataTables['targets'] = DataTables['targets'].join(
    daily_targets.select(['date'] + [f"EMA({col})" for col in dayoff_columns]), how='left', on=['date'])

# Display
ignore_columns = ['county','is_business','product_type','target','is_consumption','data_block_id','row_id','prediction_unit_id']
datetime_columns = [col for col in DataTables['targets'].columns if col not in ignore_columns]

with pl.Config(tbl_rows=10):
    display(DataTables['targets'].select(datetime_columns).describe())

describe,datetime,date,holidays_name,holidays_type,is_holidays,eic_count,pv_capacity,min_gas_price,max_gas_price,electricity_price,year,month,day_of_year,day_of_month,day_of_week,is_weekend,hour,minute,second,SIN(month),COS(month),SIN(day_of_year),COS(day_of_year),SIN(day_of_month),COS(day_of_month),SIN(day_of_week),COS(day_of_week),SIN(hour),COS(hour),SIN(minute),COS(minute),SIN(second),COS(second),EMA(is_weekend),EMA(is_holidays)
str,str,str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
"""count""","""2012112""","""2012112""",2012112.0,2012112.0,2012112.0,2012112.0,2012112.0,2012112.0,2012112.0,2012112.0,2012112.0,2012112.0,2012112.0,2012112.0,2012112.0,2012112.0,2012112.0,2012112.0,2012112.0,2012112.0,2012112.0,2012112.0,2012112.0,2012112.0,2012112.0,2012112.0,2012112.0,2012112.0,2012112.0,2012112.0,2012112.0,2012112.0,2012112.0,2012112.0,2012112.0
"""null_count""","""0""","""0""",0.0,0.0,0.0,0.0,0.0,0.0,0.0,266.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
"""mean""",,,-0.847945,-0.971636,0.028364,73.345118,1450.771451,95.559829,108.414457,157.59014,2022.053913,6.435053,180.220449,15.655669,4.002767,0.286123,11.5,0.0,0.0,0.589046,-0.06033,0.590785,0.022282,0.650058,-0.009368,0.625645,-0.143923,0.63571,0.041667,0.0,1.0,0.0,1.0,0.285568,0.028353
"""std""",,,1.026252,0.166011,0.166011,144.062707,2422.20483,47.540453,54.719761,121.351232,0.64409,3.669702,112.112528,8.760945,2.000106,0.451948,6.922188,0.0,0.0,0.313203,0.742489,0.298753,0.749149,0.298548,0.698718,0.329197,0.692448,0.309632,0.705878,0.0,0.0,0.0,0.0,0.315929,0.12004
"""min""","""2021-09-01 00:…","""2021-09-01""",-1.0,-1.0,0.0,5.0,5.5,28.1,34.1,-10.06,2021.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,1.2246e-16,-1.0,0.008583,-0.999963,1.2246e-16,-1.0,1.2246e-16,-1.0,0.0,-0.991445,0.0,1.0,0.0,1.0,0.0,0.0
"""25%""",,,-1.0,-1.0,,13.0,321.9,60.0,67.67,85.32,2022.0,3.0,80.0,8.0,2.0,,6.0,0.0,0.0,0.258819,-0.866025,0.336637,-0.76788,0.394356,-0.688967,0.433884,-0.900969,0.382683,-0.608761,0.0,1.0,0.0,1.0,0.024693,2.0309e-21
"""50%""",,,-1.0,-1.0,,32.0,645.2,86.0,94.0,128.84,2022.0,6.0,167.0,16.0,4.0,,12.0,0.0,0.0,0.707107,6.1232e-17,0.633978,0.136906,0.724793,-0.050649,0.781831,-0.222521,0.707107,0.130526,0.0,1.0,0.0,1.0,0.10257,2.8326e-11
"""75%""",,,-1.0,-1.0,,70.0,1567.15,109.74,133.0,199.99,2022.0,10.0,285.0,23.0,6.0,,18.0,0.0,0.0,0.866025,0.707107,0.857315,0.773351,0.937752,0.688967,0.974928,0.62349,0.92388,0.793353,0.0,1.0,0.0,1.0,0.667107,3.4e-05
"""max""","""2023-05-29 23:…","""2023-05-29""",10.0,0.0,1.0,1517.0,19314.31,250.0,305.0,4000.0,2023.0,12.0,365.0,31.0,7.0,1.0,23.0,0.0,0.0,1.0,0.965926,1.0,0.999963,0.998717,0.994869,0.974928,0.900969,1.0,1.0,0.0,1.0,0.0,1.0,0.988569,0.888889


CPU times: user 3.83 s, sys: 1.2 s, total: 5.03 s
Wall time: 1.33 s


In [30]:
%%time

# Convert single-target into multi-target problem
DataTables['targets_x2'] = DataTables['targets'].pivot(index=client_columns+datetime_columns,
                                                     columns=['is_consumption'],
                                                      values=['target'])\
                                             .rename(mapping={'0': 'production', '1': 'consumption',})

CPU times: user 9.02 s, sys: 2.82 s, total: 11.8 s
Wall time: 10.9 s


In [31]:
%%time

# Handle missing data --> Average of (N-pre and N-post) values
# NOTE:  
#    - Sort by `client_columns` to guarantee null-filling inside
#    - Sort by `datetime` to have correct neighborhood
HALF_NEIGHBOR = 2

expressions = []
for col in ['production','consumption','electricity_price','min_gas_price','max_gas_price']:
    col_fillnull = pl.col(col).shift(1) + pl.col(col).shift(-1)
    if HALF_NEIGHBOR - 1 > 0:
        for i in range(2, HALF_NEIGHBOR+1):
            col_fillnull += (pl.col(col).shift(i) + pl.col(col).shift(-i))
    col_condition = pl.when(pl.col(col).is_null()).then(col_fillnull).otherwise(pl.col(col))
    expressions.append(col_condition)
    
DataTables['targets_x2'] = DataTables['targets_x2'].sort(by=client_columns+['datetime'], descending=False)\
                                                .with_columns(expressions)

# Troubleshoot
DataTables['targets_x2'].filter([pl.col(col).is_null() for col in ['production','consumption']])

CPU times: user 935 ms, sys: 365 ms, total: 1.3 s
Wall time: 386 ms


county,product_type,is_business,datetime,date,holidays_name,holidays_type,is_holidays,eic_count,pv_capacity,min_gas_price,max_gas_price,electricity_price,year,month,day_of_year,day_of_month,day_of_week,is_weekend,hour,minute,second,SIN(month),COS(month),SIN(day_of_year),COS(day_of_year),SIN(day_of_month),COS(day_of_month),SIN(day_of_week),COS(day_of_week),SIN(hour),COS(hour),SIN(minute),COS(minute),SIN(second),COS(second),EMA(is_weekend),EMA(is_holidays),production,consumption
i64,i64,i64,datetime[μs],date,i64,i64,bool,i64,f64,f64,f64,f64,i32,u32,u32,u32,u32,bool,u32,u32,u32,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64


In [32]:
%%time

from sklearn.preprocessing import QuantileTransformer

# Use Sklearn for `inverse_transform`
TargetTransformer = QuantileTransformer(n_quantiles=10, 
                                          subsample=DataTables['targets_x2'].select(pl.count())[0,0])

X = DataTables['targets_x2'].select(['production','consumption']).to_pandas().values
Y = TargetTransformer.fit_transform(X)
# X_ = TargetTransformer.inverse_transform(Y)

# Transform to uniform distribution
for ci, col in enumerate(['production','consumption']):
    DataTables['targets_x2'] = DataTables['targets_x2'].insert_column(-1, pl.Series(f"QUANTILE({col})", Y[:,ci]))

# Add ordinal-classification as extra task(s)
DataTables['targets_x2'] = DataTables['targets_x2'].with_columns([
    (pl.col(f"QUANTILE({col})").round(1) * 10).cast(pl.Int32).alias(f"PERCENTILE({col})")
    for col in ['production','consumption']
])

# Statistics for troubleshooting
step = 0.1
percentiles = np.arange(step, 1., step)

columns = [ 'production','QUANTILE(production)' ,'PERCENTILE(production)',
           'consumption','QUANTILE(consumption)','PERCENTILE(consumption)']

with pl.Config(tbl_rows=int(1/step)+10):
    display(DataTables['targets_x2'].select(columns).describe(percentiles=percentiles))

describe,production,QUANTILE(production),PERCENTILE(production),consumption,QUANTILE(consumption),PERCENTILE(consumption)
str,f64,f64,f64,f64,f64,f64
"""count""",1006056.0,1006056.0,1006056.0,1006056.0,1006056.0,1006056.0
"""null_count""",0.0,0.0,0.0,0.0,0.0,0.0
"""mean""",88.012279,0.423832,4.261248,461.152283,0.491075,4.92077
"""std""",379.977885,0.3438,3.468321,1200.577818,0.27933,2.825251
"""min""",0.0,0.0,0.0,0.0,0.0,0.0
"""10%""",0.0,0.0,0.0,12.028,0.100003,1.0
"""20%""",0.0,0.0,0.0,25.655,0.195993,2.0
"""30%""",0.0,0.0,0.0,44.786,0.297148,3.0
"""40%""",0.02,0.354701,4.0,69.522,0.392399,4.0
"""50%""",0.38,0.478671,5.0,109.004,0.492644,5.0


CPU times: user 909 ms, sys: 263 ms, total: 1.17 s
Wall time: 431 ms


In [33]:
%%time

# Handle ORDINAL features for faster than Sklearn (> 1 min)

for col in ['PERCENTILE(production)','PERCENTILE(consumption)']:
    # print(DataTables['targets_x2'][col].value_counts().sort(col))
    num_classes = DataTables['targets_x2'].max()[col][0] + 1
    aggregations = []
    for i in range(num_classes):
        aggregations.append((pl.col(col) >= i).cast(pl.Int8).alias(col+f'=={i}'))
    DataTables['targets_x2'] = DataTables['targets_x2'].with_columns(aggregations)

DataTables['targets_x2'].lazy().select(pl.col('^PERCENTILE.*$')).collect().sort(by=['PERCENTILE(production)'])

CPU times: user 256 ms, sys: 17.8 ms, total: 273 ms
Wall time: 84.1 ms


PERCENTILE(production),PERCENTILE(consumption),PERCENTILE(production)==0,PERCENTILE(production)==1,PERCENTILE(production)==2,PERCENTILE(production)==3,PERCENTILE(production)==4,PERCENTILE(production)==5,PERCENTILE(production)==6,PERCENTILE(production)==7,PERCENTILE(production)==8,PERCENTILE(production)==9,PERCENTILE(production)==10,PERCENTILE(consumption)==0,PERCENTILE(consumption)==1,PERCENTILE(consumption)==2,PERCENTILE(consumption)==3,PERCENTILE(consumption)==4,PERCENTILE(consumption)==5,PERCENTILE(consumption)==6,PERCENTILE(consumption)==7,PERCENTILE(consumption)==8,PERCENTILE(consumption)==9,PERCENTILE(consumption)==10
i32,i32,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8
0,4,1,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0
0,4,1,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0
0,4,1,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
10,5,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0
10,6,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0
10,6,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0


In [34]:
feature_definitions = {
        "boolean_columns" : ['is_business','is_holidays','is_weekend',],
    "categorical_columns" : ['county','product_type','holidays_name','holidays_type'],
     "multilabel_columns" : [],
        "ordinal_columns" : [],
      "numerical_columns" : [f"{m}({v})" for v in ['month','hour','minute','second',
                                                   'day_of_year','day_of_month','day_of_week']
                                         for m in ['SIN','COS']] + \
                            ['EMA(is_weekend)','EMA(is_holidays)',
                             'electricity_price','min_gas_price','max_gas_price',
                             'eic_count','pv_capacity','production','consumption'],
    
    # NOTE: do NOT include `datetime` features
    "passthrough_columns" : [col for col in DataTables['targets_x2'].columns 
                                  if col.startswith('PERCENTILE') and ('==' in col)], 
}

with open('/kaggle/working/feature_definitions.yaml', 'w') as file:
    yaml.dump(feature_definitions, file)

print(len(flatten_nested_list([Vs for Vs in feature_definitions.values()])))
feature_definitions

52


{'boolean_columns': ['is_business', 'is_holidays', 'is_weekend'],
 'categorical_columns': ['county',
  'product_type',
  'holidays_name',
  'holidays_type'],
 'multilabel_columns': [],
 'ordinal_columns': [],
 'numerical_columns': ['SIN(month)',
  'COS(month)',
  'SIN(hour)',
  'COS(hour)',
  'SIN(minute)',
  'COS(minute)',
  'SIN(second)',
  'COS(second)',
  'SIN(day_of_year)',
  'COS(day_of_year)',
  'SIN(day_of_month)',
  'COS(day_of_month)',
  'SIN(day_of_week)',
  'COS(day_of_week)',
  'EMA(is_weekend)',
  'EMA(is_holidays)',
  'electricity_price',
  'min_gas_price',
  'max_gas_price',
  'eic_count',
  'pv_capacity',
  'production',
  'consumption'],
 'passthrough_columns': ['PERCENTILE(production)==0',
  'PERCENTILE(production)==1',
  'PERCENTILE(production)==2',
  'PERCENTILE(production)==3',
  'PERCENTILE(production)==4',
  'PERCENTILE(production)==5',
  'PERCENTILE(production)==6',
  'PERCENTILE(production)==7',
  'PERCENTILE(production)==8',
  'PERCENTILE(production)==9',
  '

In [35]:
%%time

# Pipeline has issue to concat `float` with `datetime`
temp_df = DataTables['targets_x2'].select(['datetime','date']) 
Dataset = DataTables['targets_x2'].to_pandas()#.sample(1000).reset_index(drop=True)

# Reference: https://copyprogramming.com/howto/how-to-save-a-custom-transformer-in-sklearn

from utils.pipeline import build_pipeline
FeaturePipeline = build_pipeline(**feature_definitions)

# Run transformations
FeaturePipeline.fit(Dataset)

# Dump to file
pipeline_path = '/kaggle/working/FeaturePipeline.jbl'
joblib.dump(FeaturePipeline, pipeline_path)

FeaturePipeline

[ColumnTransformer]  (1 of 31) Processing numerical_SIN(month), total=   0.0s
[ColumnTransformer]  (2 of 31) Processing numerical_COS(month), total=   0.0s
[ColumnTransformer]  (3 of 31) Processing numerical_SIN(hour), total=   0.0s
[ColumnTransformer]  (4 of 31) Processing numerical_COS(hour), total=   0.0s
[ColumnTransformer]  (5 of 31) Processing numerical_SIN(minute), total=   0.0s
[ColumnTransformer]  (6 of 31) Processing numerical_COS(minute), total=   0.0s
[ColumnTransformer]  (7 of 31) Processing numerical_SIN(second), total=   0.0s
[ColumnTransformer]  (8 of 31) Processing numerical_COS(second), total=   0.0s
[ColumnTransformer]  (9 of 31) Processing numerical_SIN(day_of_year), total=   0.0s
[ColumnTransformer]  (10 of 31) Processing numerical_COS(day_of_year), total=   0.0s
[ColumnTransformer]  (11 of 31) Processing numerical_SIN(day_of_month), total=   0.0s
[ColumnTransformer]  (12 of 31) Processing numerical_COS(day_of_month), total=   0.0s
[ColumnTransformer]  (13 of 31) P

In [36]:
FeaturePipeline.named_transformers_

{'numerical_SIN(month)': Pipeline(steps=[('scale', StandardScaler())]),
 'numerical_COS(month)': Pipeline(steps=[('scale', StandardScaler())]),
 'numerical_SIN(hour)': Pipeline(steps=[('scale', StandardScaler())]),
 'numerical_COS(hour)': Pipeline(steps=[('scale', StandardScaler())]),
 'numerical_SIN(minute)': Pipeline(steps=[('scale', StandardScaler())]),
 'numerical_COS(minute)': Pipeline(steps=[('scale', StandardScaler())]),
 'numerical_SIN(second)': Pipeline(steps=[('scale', StandardScaler())]),
 'numerical_COS(second)': Pipeline(steps=[('scale', StandardScaler())]),
 'numerical_SIN(day_of_year)': Pipeline(steps=[('scale', StandardScaler())]),
 'numerical_COS(day_of_year)': Pipeline(steps=[('scale', StandardScaler())]),
 'numerical_SIN(day_of_month)': Pipeline(steps=[('scale', StandardScaler())]),
 'numerical_COS(day_of_month)': Pipeline(steps=[('scale', StandardScaler())]),
 'numerical_SIN(day_of_week)': Pipeline(steps=[('scale', StandardScaler())]),
 'numerical_COS(day_of_week)':

In [37]:
%%time

FeatureSet = pd.DataFrame(data=FeaturePipeline.transform(Dataset), 
                        columns=FeaturePipeline.get_feature_names_out())

DataTables['targets_featured'] = pl.concat([temp_df, pl.from_pandas(FeatureSet)], how='horizontal')
# DataTables['targets_featured'].write_csv(featureset_path)

# Dtype-casting
DataTables['targets_featured'] = DataTables['targets_featured'].with_columns([
    pl.col(col).cast(pl.Int8 if any([col.startswith(dtype) 
                                                for dtype in ['categorical','multilabel']])
                else pl.Float32) 
       for col in DataTables['targets_featured'].columns if 'date' not in col
])

DataTables['targets_featured']

CPU times: user 8.25 s, sys: 253 ms, total: 8.5 s
Wall time: 8.14 s


datetime,date,numerical_SIN(month)__SIN(month),numerical_COS(month)__COS(month),numerical_SIN(hour)__SIN(hour),numerical_COS(hour)__COS(hour),numerical_SIN(minute)__SIN(minute),numerical_COS(minute)__COS(minute),numerical_SIN(second)__SIN(second),numerical_COS(second)__COS(second),numerical_SIN(day_of_year)__SIN(day_of_year),numerical_COS(day_of_year)__COS(day_of_year),numerical_SIN(day_of_month)__SIN(day_of_month),numerical_COS(day_of_month)__COS(day_of_month),numerical_SIN(day_of_week)__SIN(day_of_week),numerical_COS(day_of_week)__COS(day_of_week),numerical_EMA(is_weekend)__EMA(is_weekend),numerical_EMA(is_holidays)__EMA(is_holidays),numerical_electricity_price__electricity_price,numerical_min_gas_price__min_gas_price,numerical_max_gas_price__max_gas_price,numerical_eic_count__eic_count,numerical_pv_capacity__pv_capacity,numerical_production__production,numerical_consumption__consumption,boolean_is_business__is_business,boolean_is_holidays__is_holidays,boolean_is_weekend__is_weekend,categorical_county__county,categorical_product_type__product_type,categorical_holidays_name__holidays_name,categorical_holidays_type__holidays_type,original___PERCENTILE(production)==0,original___PERCENTILE(production)==1,original___PERCENTILE(production)==2,original___PERCENTILE(production)==3,original___PERCENTILE(production)==4,original___PERCENTILE(production)==5,original___PERCENTILE(production)==6,original___PERCENTILE(production)==7,original___PERCENTILE(production)==8,original___PERCENTILE(production)==9,original___PERCENTILE(production)==10,original___PERCENTILE(consumption)==0,original___PERCENTILE(consumption)==1,original___PERCENTILE(consumption)==2,original___PERCENTILE(consumption)==3,original___PERCENTILE(consumption)==4,original___PERCENTILE(consumption)==5,original___PERCENTILE(consumption)==6,original___PERCENTILE(consumption)==7,original___PERCENTILE(consumption)==8,original___PERCENTILE(consumption)==9,original___PERCENTILE(consumption)==10
datetime[μs],date,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,i8,i8,i8,i8,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
2021-09-01 00:00:00,2021-09-01,0.376948,-0.871093,-2.053114,1.357647,0.0,0.0,0.0,0.0,0.9213,-0.697167,-1.838528,1.437258,1.061016,0.5292,-0.903899,-0.236195,-0.536375,-1.058674,-1.134772,-0.474412,-0.433808,-0.231625,-0.334966,1.0,-1.0,-1.0,0,0,0,0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
2021-09-01 01:00:00,2021-09-01,0.376948,-0.871093,-1.631562,1.345527,0.0,0.0,0.0,0.0,0.9213,-0.697167,-1.838528,1.437258,1.061016,0.5292,-0.903899,-0.236195,-0.566125,-1.058674,-1.134772,-0.474412,-0.433808,-0.231625,-0.3328,1.0,-1.0,-1.0,0,0,0,0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
2021-09-01 02:00:00,2021-09-01,0.376948,-0.871093,-1.217223,1.309375,0.0,0.0,0.0,0.0,0.9213,-0.697167,-1.838528,1.437258,1.061016,0.5292,-0.903899,-0.236195,-0.578898,-1.058674,-1.134772,-0.474412,-0.433808,-0.231625,-0.331551,1.0,-1.0,-1.0,0,0,0,0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
2023-05-29 21:00:00,2023-05-29,1.203311,0.429836,-0.817186,-1.367865,0.0,0.0,0.0,0.0,1.228209,0.354315,-1.503138,-1.38849,-0.582514,1.508982,-0.171239,-0.236195,-0.620515,-1.417737,-1.30546,-0.127341,0.304445,-0.218953,-0.1375,1.0,-1.0,-1.0,15,3,0,0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0
2023-05-29 22:00:00,2023-05-29,1.203311,0.429836,-1.217223,-1.427431,0.0,0.0,0.0,0.0,1.228209,0.354315,-1.503138,-1.38849,-0.582514,1.508982,-0.171239,-0.236195,-0.620597,-1.417737,-1.30546,-0.127341,0.304445,-0.231625,-0.134391,1.0,-1.0,-1.0,15,3,0,0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0
2023-05-29 23:00:00,2023-05-29,1.203311,0.429836,-1.631562,-1.463583,0.0,0.0,0.0,0.0,1.228209,0.354315,-1.503138,-1.38849,-0.582514,1.508982,-0.171239,-0.236195,-0.661967,-1.417737,-1.30546,-0.127341,0.304445,-0.231625,-0.236633,1.0,-1.0,-1.0,15,3,0,0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0


In [38]:
# Test inverse-transformation for numerical variables

column = 'production'

transformer = FeaturePipeline.named_transformers_[f'numerical_{column}']

X = DataTables['targets_x2'].select([column]).to_pandas().values
X_ = transformer.inverse_transform(
    DataTables['targets_featured'][[f'numerical_{column}__{column}']].to_pandas())

np.abs(X-X_).sum() / np.prod(X.shape)

4.818557570889618e-06

In [39]:
# Test inverse-transformation for categorical variables

column = 'county'
column_out = f'categorical_{column}'
column_classes = Dataset[column].unique().tolist()

transformer = FeaturePipeline.named_transformers_[column_out].steps[0][1]
transformer.get_feature_names_out(column_classes)

array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], dtype=object)

In [40]:
# Clear memory
del Dataset, FeatureSet, temp_df

In [41]:
DataTables['targets_featured'] = DataTables['targets_featured'].rename({
    col: col.replace('original__', 'ordinal__')
    for col in DataTables['targets_featured'].columns
    if col.startswith('original__')
})

DataTables['targets_featured'].head(2)

datetime,date,numerical_SIN(month)__SIN(month),numerical_COS(month)__COS(month),numerical_SIN(hour)__SIN(hour),numerical_COS(hour)__COS(hour),numerical_SIN(minute)__SIN(minute),numerical_COS(minute)__COS(minute),numerical_SIN(second)__SIN(second),numerical_COS(second)__COS(second),numerical_SIN(day_of_year)__SIN(day_of_year),numerical_COS(day_of_year)__COS(day_of_year),numerical_SIN(day_of_month)__SIN(day_of_month),numerical_COS(day_of_month)__COS(day_of_month),numerical_SIN(day_of_week)__SIN(day_of_week),numerical_COS(day_of_week)__COS(day_of_week),numerical_EMA(is_weekend)__EMA(is_weekend),numerical_EMA(is_holidays)__EMA(is_holidays),numerical_electricity_price__electricity_price,numerical_min_gas_price__min_gas_price,numerical_max_gas_price__max_gas_price,numerical_eic_count__eic_count,numerical_pv_capacity__pv_capacity,numerical_production__production,numerical_consumption__consumption,boolean_is_business__is_business,boolean_is_holidays__is_holidays,boolean_is_weekend__is_weekend,categorical_county__county,categorical_product_type__product_type,categorical_holidays_name__holidays_name,categorical_holidays_type__holidays_type,ordinal___PERCENTILE(production)==0,ordinal___PERCENTILE(production)==1,ordinal___PERCENTILE(production)==2,ordinal___PERCENTILE(production)==3,ordinal___PERCENTILE(production)==4,ordinal___PERCENTILE(production)==5,ordinal___PERCENTILE(production)==6,ordinal___PERCENTILE(production)==7,ordinal___PERCENTILE(production)==8,ordinal___PERCENTILE(production)==9,ordinal___PERCENTILE(production)==10,ordinal___PERCENTILE(consumption)==0,ordinal___PERCENTILE(consumption)==1,ordinal___PERCENTILE(consumption)==2,ordinal___PERCENTILE(consumption)==3,ordinal___PERCENTILE(consumption)==4,ordinal___PERCENTILE(consumption)==5,ordinal___PERCENTILE(consumption)==6,ordinal___PERCENTILE(consumption)==7,ordinal___PERCENTILE(consumption)==8,ordinal___PERCENTILE(consumption)==9,ordinal___PERCENTILE(consumption)==10
datetime[μs],date,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,i8,i8,i8,i8,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
2021-09-01 00:00:00,2021-09-01,0.376948,-0.871093,-2.053114,1.357647,0.0,0.0,0.0,0.0,0.9213,-0.697167,-1.838528,1.437258,1.061016,0.5292,-0.903899,-0.236195,-0.536375,-1.058674,-1.134772,-0.474412,-0.433808,-0.231625,-0.334966,1.0,-1.0,-1.0,0,0,0,0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
2021-09-01 01:00:00,2021-09-01,0.376948,-0.871093,-1.631562,1.345527,0.0,0.0,0.0,0.0,0.9213,-0.697167,-1.838528,1.437258,1.061016,0.5292,-0.903899,-0.236195,-0.566125,-1.058674,-1.134772,-0.474412,-0.433808,-0.231625,-0.3328,1.0,-1.0,-1.0,0,0,0,0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0


In [42]:
%%time

id_columns = ['categorical_county__county',
              'categorical_product_type__product_type',
              'boolean_is_business__is_business',]
groupby_columns = id_columns + ['date']

aggregations = [pl.col('datetime').count().alias('HOURS')]
aggregations.extend([
     pl.col(col).flatten()
        for col in DataTables['targets_featured'].columns \
         if col not in groupby_columns
])

DataTables['targets_timeseries'] = \
DataTables['targets_featured'].sort(by=groupby_columns, descending=False)\
                             .group_by(groupby_columns).agg(aggregations)\
                              .sort(by=groupby_columns, descending=False)

DataTables['targets_timeseries'] = DataTables['targets_timeseries'].with_columns([
    pl.col(col).repeat_by("HOURS").alias(col+"_replicated") for col in id_columns
])

# Statistics for troubleshooting
display(DataTables['targets_timeseries']['HOURS'].value_counts().sort(by=['HOURS'], descending=False))

HOURS,counts
u32,u32
24,41919


CPU times: user 2.32 s, sys: 528 ms, total: 2.85 s
Wall time: 783 ms


# Data Splitting

> **Splitting Strategy**: refer to [Periodicity & Seasonality](https://www.kaggle.com/code/mrriandmstique/enefit-data-and-insights)
<br>`Choose validation date-range similar to private test (01-02-2024 -> 30-04-2024)`

#### ⚠️ After choosing the best model, it should be trained on full data.❗

In [43]:
%%time

target_dt_range = pl.col("date").is_between(datetime(2023, 2, 1), 
                                            datetime(2023, 6, 1), closed='both')

DataTables['timeseries_train'] = DataTables['targets_timeseries'].filter(~target_dt_range)
DataTables['timeseries_valid'] = DataTables['targets_timeseries'].filter(target_dt_range)

CPU times: user 160 ms, sys: 137 ms, total: 297 ms
Wall time: 80.9 ms


In [44]:
%%time

id_columns = ['categorical_county__county',
              'categorical_product_type__product_type',
              'boolean_is_business__is_business',]
groupby_columns = id_columns

aggregations = [pl.col('HOURS').sum(), pl.col('date').count().alias('DAYS')]
aggregations.extend([
     pl.col(col).flatten()
        for col in DataTables['targets_timeseries'].columns \
         if col not in groupby_columns + ['HOURS','date']
])

for subset in ['train','valid']:
    subset = f'timeseries_{subset}'
    DataTables[subset] = \
    DataTables[subset].sort(by=groupby_columns, descending=False)\
                     .group_by(groupby_columns).agg(aggregations)\
                      .sort(by=groupby_columns, descending=False)
    
    # display(DataTables[subset].head(2))

    # Statistics for troubleshooting
    stats = DataTables[subset]['HOURS'].value_counts().sort(by=['HOURS'], descending=False)
    display(stats)

HOURS,counts
u32,u32
720,1
5328,1
8616,1
…,…
11688,2
11712,3
12432,56


HOURS,counts
u32,u32
744,1
888,1
960,1
1368,1
2424,1
2832,64


CPU times: user 1.6 s, sys: 390 ms, total: 1.99 s
Wall time: 555 ms


In [45]:
# Clear memory

print(cpu_stats())

unused_tables = []
for k in DataTables.keys():
    
    if k == 'targets_featured':
        continue
        
    if 'timeseries' not in k:
        unused_tables.append(k)
        
for k in unused_tables:
    del DataTables[k]

print(cpu_stats())

Memory:3.32 GB
Memory:2.28 GB


<div>
    <h1 style="font-family:  'Garamond', 'Lucida Sans', sans-serif; text-align: center; color: #263A29; font-weight: bold; font-size: 36px;">
    Model Architecture
    </h1>
</div>
<hr>


In [46]:
HTML("""
<head>
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <style>
    * {
      box-sizing: border-box;
    }
    
    /* Create two equal columns that floats next to each other */
    .column {
      float: left;
      width: 50%;
      padding: 10px;
      height: 300px; /* Should be removed. Only for demonstration */
    }
    
    /* Clear floats after the columns */
    .row:after {
      content: "";
      display: table;
      clear: both;
    }
    </style>
</head>

<body>
    <h2>Model Baselines</h2>
    <div class="row">
      <div class="column">
        <a href="https://github.com/lucidrains/ETSformer-pytorch" style="font-size:3vw">ETSformer</a>
        <img src="https://raw.githubusercontent.com/salesforce/etsformer/master/pics/etsformer.png" alt="drawing"/>
      </div>
      <div class="column">
        <a href="https://github.com/asmodaay/ti-maeurl" style="font-size:3vw">TI-MAE</a>
        <img src="https://d3i71xaburhd42.cloudfront.net/644b14c253bc76f0914b1645d7af59e6042d59f9/4-Figure3-1.png" alt="drawing"/>
      </div>
    </div>
    <div class="row">
        <a href="https://gaoxiangluo.github.io/2021/08/01/Group-Norm-Batch-Norm-Instance-Norm-which-is-better/" style="font-size:3vw">Normalizations</a>
        <img src="https://gaoxiangluo.github.io/2021/08/01/Group-Norm-Batch-Norm-Instance-Norm-which-is-better/figure6.png" alt="drawing"/>
    </div>
</body>
""")

In [47]:
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [48]:
forecast_window = 24
lookback_window = 24 * 3

tsdataset_schema = {
     "groupby_columns" : id_columns,
      "ignore_columns" : ['DAYS','datetime'],
       "count_column"  : 'HOURS',
     "forecast_window" : forecast_window,
     "lookback_window" : lookback_window,
}

with open('/kaggle/working/dataloader_schema.yaml', 'w') as file:
    yaml.dump(tsdataset_schema, file)

In [49]:
from models.etsformer.data_loader import TsDataLoaderSimple

DataLoader = dict()

for subset in ['train','valid']:
    DataLoader[subset] = TsDataLoaderSimple(DataTables[f'timeseries_{subset}'],
                                            is_shuffled = (subset == 'train'),
                                            time_skip = (6 if subset == 'train' else 24),
                                            batch_size = (16 if subset == 'train' else 32),
                                            **tsdataset_schema)

In [50]:
targets = ['production','consumption']

features_schema = []
for feat_name in DataTables['targets_featured'].columns:
    feat_type = feat_name.split('_')[0]
    if feat_type not in ['numerical','boolean','categorical','ordinal','multilabel','original']:
        continue
    feat_name_ = feat_name.replace(feat_type + '_', '')
    feat_schema = {
             'name' : feat_name + '_replicated' if any([col in feat_name for col in client_columns]) else feat_name,
             'type' : feat_type,
        'is_target' : any([t in feat_name_ for t in targets]),
    }
    if feat_type == 'categorical':
        # num classes = 1 (class=0) + max(class=N-1) + 1 (class=-1 / outlier)
        feat_schema.update({'n_classes': int(DataTables['targets_featured'].select(feat_name).to_pandas().values.flatten().max()+2)})
    features_schema.append(feat_schema)
    
features_schema = sorted(features_schema, key=lambda d: d['is_target'], reverse=True)
with open('/kaggle/working/features_schema.yaml', 'w') as file:
    yaml.dump(features_schema, file)

features_schema_df = pd.DataFrame.from_records(features_schema)
features_schema_df.T

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51
name,numerical_production__production,numerical_consumption__consumption,ordinal___PERCENTILE(production)==0,ordinal___PERCENTILE(production)==1,ordinal___PERCENTILE(production)==2,ordinal___PERCENTILE(production)==3,ordinal___PERCENTILE(production)==4,ordinal___PERCENTILE(production)==5,ordinal___PERCENTILE(production)==6,ordinal___PERCENTILE(production)==7,ordinal___PERCENTILE(production)==8,ordinal___PERCENTILE(production)==9,ordinal___PERCENTILE(production)==10,ordinal___PERCENTILE(consumption)==0,ordinal___PERCENTILE(consumption)==1,ordinal___PERCENTILE(consumption)==2,ordinal___PERCENTILE(consumption)==3,ordinal___PERCENTILE(consumption)==4,ordinal___PERCENTILE(consumption)==5,ordinal___PERCENTILE(consumption)==6,ordinal___PERCENTILE(consumption)==7,ordinal___PERCENTILE(consumption)==8,ordinal___PERCENTILE(consumption)==9,ordinal___PERCENTILE(consumption)==10,numerical_SIN(month)__SIN(month),numerical_COS(month)__COS(month),numerical_SIN(hour)__SIN(hour),numerical_COS(hour)__COS(hour),numerical_SIN(minute)__SIN(minute),numerical_COS(minute)__COS(minute),numerical_SIN(second)__SIN(second),numerical_COS(second)__COS(second),numerical_SIN(day_of_year)__SIN(day_of_year),numerical_COS(day_of_year)__COS(day_of_year),numerical_SIN(day_of_month)__SIN(day_of_month),numerical_COS(day_of_month)__COS(day_of_month),numerical_SIN(day_of_week)__SIN(day_of_week),numerical_COS(day_of_week)__COS(day_of_week),numerical_EMA(is_weekend)__EMA(is_weekend),numerical_EMA(is_holidays)__EMA(is_holidays),numerical_electricity_price__electricity_price,numerical_min_gas_price__min_gas_price,numerical_max_gas_price__max_gas_price,numerical_eic_count__eic_count,numerical_pv_capacity__pv_capacity,boolean_is_business__is_business_replicated,boolean_is_holidays__is_holidays,boolean_is_weekend__is_weekend,categorical_county__county_replicated,categorical_product_type__product_type_replicated,categorical_holidays_name__holidays_name,categorical_holidays_type__holidays_type
type,numerical,numerical,ordinal,ordinal,ordinal,ordinal,ordinal,ordinal,ordinal,ordinal,ordinal,ordinal,ordinal,ordinal,ordinal,ordinal,ordinal,ordinal,ordinal,ordinal,ordinal,ordinal,ordinal,ordinal,numerical,numerical,numerical,numerical,numerical,numerical,numerical,numerical,numerical,numerical,numerical,numerical,numerical,numerical,numerical,numerical,numerical,numerical,numerical,numerical,numerical,boolean,boolean,boolean,categorical,categorical,categorical,categorical
is_target,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False
n_classes,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,17.0,5.0,13.0,3.0


In [51]:
model_schema = {

    "normalization" : {
         "dropout_rate" : 0.369,
        "instance_norm" : False,     # Otherwise, BatchNorm
    },

    "transformation" : {
            "model_dim" :   512,   # 512 as in paper
    "embed_kernel_size" :     3,   # kernel size for 1d conv for input embedding
           "num_layers" :     2,   # number of encoder and decoder layers (corresponding)
            "num_heads" :     8,   # number of exponentially-smoothing attention heads
                 "topK" :     3,   # number of frequencies with highest amplitude to attend
              "dropout" : 0.369,   # 0.2 as in paper
    },

    "predictions": [
        {
            "target" : f"numerical_{col}__{col}",
              "name" : f"{col[0]}_rg",
            "weight" : 1.69,
              "task" : "regression",
              "loss" : "huber",
        } for col in ['production','consumption']
    ] + [
        {
            "target" : [f"ordinal___PERCENTILE({col})=={order}" for order in range(11)],
              "name" :  f"{col[0]}_cl",
            "weight" : 0.69,
              "task" : "classification",
              "loss" : "crossentroy",
        "train_only" : True,
        "multilabel" : True,
    "classes_weight" : [0.169 * (i+1) for i in range(11)], # latter classes equals to higher values
        } for col in ['production','consumption']
    ],
}

with open('/kaggle/working/model_schema.yaml', 'w') as file:
    yaml.dump(model_schema, file)

In [52]:
from models.etsformer.model_wrapper import MultiVariateForecast

Model = MultiVariateForecast(features_schema, model_schema)

ckpt_path = '/kaggle/input/enefit-artefacts/artefacts/etsformer/ordinal_regression_3d.pt'
if os.path.isfile(ckpt_path):
    print('Loading ...')
    Model.load_state_dict(torch.load(ckpt_path, map_location=device))

ckpt_path = '/kaggle/working/model.pt'

Loading ...


In [53]:
trainer_schema = {

    # Trainer configuration
    'lr' : 169e-5, 
    'lr_scheduler' : 'cosine', 
    'optimizer' : 'adam', 
    'device' : device,

    # Early Stopping
    'patience' : 19,  
    'verbose' : True,
    'checkpoint_path' : ckpt_path,

    # Other callbacks
    'tb_logs' : False,  # Tensorboard Logging

    # Training strategy
    'num_epochs' : 100, 
    'num_steps_forecast' : forecast_window,
}

with open('/kaggle/working/trainer_schema.yaml', 'w') as file:
    yaml.dump(trainer_schema, file)

In [54]:
from models.etsformer.trainer import TimeSeriesTrainer

trainer = TimeSeriesTrainer(model = Model,
                    train_dataset = DataLoader['train'],
                    valid_dataset = DataLoader['valid'], **trainer_schema)


  Model: Multi-Variate Forecasting with Exponential-Smoothing Time-Series Transformer  


MultiVariateForecast(
  (aggregate): FeaturesAggregation(
    (Embedders): ModuleDict(
      (categorical_county__county_replicated): Embedding(18, 3)
      (categorical_product_type__product_type_replicated): Embedding(6, 2)
      (categorical_holidays_name__holidays_name): Embedding(14, 3)
      (categorical_holidays_type__holidays_type): Embedding(4, 2)
    )
  )
  (normalize): FeaturesNormalization(
    (regularize): Dropout(p=0.369, inplace=False)
    (interconn): Linear(in_features=58, out_features=58, bias=True)
    (activate_): Mish()
    (normalize): BatchNorm1d(58, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (transform): ETSFormer(
    (extractor): Sequential(
      (0): Rearrange('b n d -> b d n')
      (1): Conv1d(58, 512, kernel_size=(3,), stride=(1,), padding=(1,))
      (2): Dropout(p=0.369, inplace=False)
      (3): Rearrange('b d n -> b n d')
    )
    (

In [None]:
trainer.train()

Adjusting learning rate of group 0 to 1.6900e-03.
Epoch 0 / 100


570/8.46k |▎   |06:11, p_rg:0.042-c_rg:0.026-p_cl:0.253-c_cl:0.123-lr:0.002-lr:1.69e-03 

<div>
    <h1 style="font-family:  'Garamond', 'Lucida Sans', sans-serif; text-align: center; color: #263A29; font-weight: bold; font-size: 36px;">
    Model Evaluation
    </h1>
</div>
<hr>

In [None]:
Model.load_state_dict(torch.load(ckpt_path, map_location=device))

In [None]:
%%time

features_order = Model.aggregate.features_order

all_client_predictions = pd.DataFrame()
for client_data in tqdm(DataTables['timeseries_valid'].iter_rows(named=True),
              total=len(DataTables['timeseries_valid'])):
    client_data = pl.from_dict(client_data).drop(columns=['HOURS','DAYS'])\
                                           .with_columns(pl.col('datetime').cast(pl.Date).alias('date'))

    client_dates = client_data['date'].unique().sort().to_list()
    for lookback_day, forecast_day in zip(client_dates[:-1], client_dates[1:]):
        lookback_data = client_data.filter(pl.col('date') == lookback_day).to_pandas()
        forecast_data = client_data.filter(pl.col('date') == forecast_day)\
                                   .select(id_columns + ['datetime', 'numerical_production__production',
                                                                     'numerical_consumption__consumption'])\
                                   .rename({'numerical_production__production'   : 'groundtruth|production',
                                            'numerical_consumption__consumption' : 'groundtruth|consumption',
                                          'categorical_county__county' : 'county', 
                                          'categorical_product_type__product_type' : 'product_type',
                                                'boolean_is_business__is_business' : 'is_business',})

        # Convert to tensor(s)
        X = {
                k: torch.tensor(v).unsqueeze(dim=0).to(device)
            for k, v in lookback_data.to_dict(orient='list').items()
             if k not in ['date','datetime']
        }

        # Predict
        preds = Model(X, forecast_window)

        # Combine
        forecast_values = {
            f'forecast|{k}': pl.lit(preds[f'scaled_{k}'].detach().cpu().squeeze().tolist())
                    for k in ['production','consumption']
        }
        forecast_data = forecast_data.with_columns(**forecast_values).to_pandas()
        all_client_predictions = pd.concat([all_client_predictions, forecast_data], axis=0)

all_client_predictions.head(3)

In [None]:
%%time

for target in ['production','consumption']:
    transformer = FeaturePipeline.named_transformers_[f'numerical_{target}']

    for prefix in ['groundtruth','forecast']:
        column = f'{prefix}|{target}'
        print(f"Converting {column} ...")
        
        all_client_predictions[column] = transformer.inverse_transform(all_client_predictions[[column]].values).flatten()

all_client_predictions.head(3)

In [None]:
# fig = go.Figure()

# fig.add_trace(go.Scatter(x=dailyProd.index, y=dailyProd.target, fill='tozeroy', mode='lines', line_color='#427D9D', name='Daily Production',))
# fig.add_trace(go.Scatter(x=dailyCons.index, y=dailyCons.target, fill='tonexty', mode='lines', line_color='#FA163F', name='Daily Consumption',))

# fig.add_hline(y=meanCons, line_dash="dot", line_color='red', annotation_text="Average Consumption",)
# fig.add_hline(y=meanProd, line_dash="dot", line_color='grey', annotation_text="Average Production",)

# fig.update_yaxes(type="log", range=[math.log10(0.1), math.log10(1000)])
# fig.update_layout(xaxis_title="Date", legend_x=0.69, 
#                   yaxis_title="Amount", legend_y=1.369,
#                         title=dict(text="Daily Consumption / Production", font={'size': 19}, automargin=True, yref='paper'),)
# fig.show()

<div>
    <h1 style="font-family:  'Garamond', 'Lucida Sans', sans-serif; text-align: center; color: #263A29; font-weight: bold; font-size: 36px;">
    Model Packaging
    </h1>
</div>
<hr>

# [Workspace Package](https://www.kaggle.com/code/andrewscholan/offline-package-wheeler-public)

<div>
    <h1 style="font-family:  'Garamond', 'Lucida Sans', sans-serif; text-align: center; color: #263A29; font-weight: bold; font-size: 36px;">
    Submission
    </h1>
</div>
<hr>

> **Template**:
<br>`
import enefit
env = enefit.make_env()
iter_test = env.iter_test()
for (test, revealed_targets, client, historical_weather,
           forecast_weather, electricity_prices, gas_prices, 
           sample_prediction) in iter_test:
    sample_prediction['target'] = 0
    env.predict(sample_prediction)
`

> **Reset Environment**:
<br>`
enefit.make_env.__called__ = False
type(env)._state = type(type(env)._state).__dict__['INIT']
iter_test = env.iter_test()
`

In [None]:
try:
    import enefit
    env = enefit.make_env()
    
except Exception:
    enefit.make_env.__called__ = False
    type(env)._state = type(type(env)._state).__dict__['INIT']
    iter_test = env.iter_test()

SHOW_DATE_ONLY = True

it = 0
iter_test = env.iter_test()
for (test, targets, client, weather_hist, weather_fx, 
            electr, gas, submission) in iter_test:
    it += 1
    print(f"\n\n\nIteration {it}")
    
    table_names = ['client','test','targets','weather_hist','weather_fx','electr','gas']
    for table in table_names:
        df = getattr(sys.modules[__name__], table)
        print('\n', table)
        
        display_columns = df.columns
        if SHOW_DATE_ONLY:
            display_columns = [col for col in df.columns if 'date' in col]
        display(df[display_columns].describe().T)
    
    submission['target'] = 0
    env.predict(submission)