In [None]:
!pip install transformers[torch]

In [None]:
pip install safetensors

In [None]:
pip install wandb

In [None]:
# Authenticate with Weights & Biases without hard-coding the API key in the notebook.
# The key is taken from the WANDB_API_KEY environment variable; when it is unset,
# key=None leaves wandb.login() to its default credential lookup / interactive prompt.
# NOTE(review): the API key that was previously hard-coded in this cell is exposed in the
# notebook's history — revoke/rotate it in the W&B settings.
import os

import wandb

wandb.login(key=os.environ.get("WANDB_API_KEY"))


# Import section

In [44]:
# Standard
import os
import numpy     as np
import pandas    as pd
import copy
from   itertools import starmap
from   typing    import Any, Dict, List, Optional, Union
from   pathlib   import Path
import wandb
from datetime import datetime, timedelta

# Third Party
import torch
from transformers import (
    EarlyStoppingCallback,
    PatchTSTConfig,
    PatchTSTForPrediction,
    Trainer,
    TrainingArguments,
)

from transformers import set_seed

SEED = 2023
set_seed(SEED)  # transformers helper that seeds the random-number generators for reproducible runs


# Classes for data management

In [3]:
def ts_padding(
    df: pd.DataFrame, timestamp_column: str = None, context_length: int = 1) -> pd.DataFrame:
    """Prepend zero-filled rows so that ``df`` has at least ``context_length`` rows.

    Args:
        df: Time-series frame; assumed to be sorted by ``timestamp_column``.
        timestamp_column: Name of the timestamp column. When its dtype is ``datetime64``
            or ``int`` and ``df`` has at least two rows, the padded rows get timestamps
            extrapolated backwards using the spacing of the first two rows. Otherwise
            they are set to missing (integer timestamps keep the zero fill, since an int
            column cannot hold missing values). ``None`` skips timestamp handling.
        context_length: Minimum number of rows required.

    Returns:
        ``df`` itself if it is already long enough. Otherwise a new frame with the padding
        rows first and a fresh 0..n-1 index, so label-based and positional slicing agree.
    """
    n_rows = len(df)
    if n_rows >= context_length:
        return df
    fill_length = context_length - n_rows

    # Zero block with the same columns; cast every non-timestamp column to the
    # original dtype so the concatenation does not change column types.
    pad_df = pd.DataFrame(np.zeros([fill_length, df.shape[1]]), columns=df.columns)
    for c in df.columns:
        if c == timestamp_column:
            continue
        pad_df[c] = pad_df[c].astype(df.dtypes[c], copy=False)

    if timestamp_column is not None:
        ts = df[timestamp_column]
        is_datetime = ts.dtype.type == np.datetime64
        is_int = ts.dtype == int
        if (is_datetime or is_int) and n_rows >= 2:
            first_timestamp = ts.iloc[0]
            period = ts.iloc[1] - ts.iloc[0]
            pad_df[timestamp_column] = [first_timestamp + offset * period for offset in range(-fill_length, 0)]
        elif not is_int:
            # Spacing cannot be inferred (non-numeric timestamps, or fewer than two rows):
            # mark the padded timestamps as missing (NaT for datetime columns).
            pad_df[timestamp_column] = None
        # Integer timestamps with fewer than two rows keep the 0 fill from np.zeros.
        pad_df[timestamp_column] = pad_df[timestamp_column].astype(ts.dtype)

    # ignore_index: otherwise the result carries duplicate labels (0..fill-1 and the
    # original labels), which breaks label-based slicing such as ``df.loc[a:b]``.
    return pd.concat([pad_df, df], ignore_index=True)

def np_to_torch(data: np.ndarray, float_type=np.float32):
    """Convert a numpy array to a torch tensor.

    Floating-point arrays of any precision are cast to ``float_type``. All other dtypes
    (integers, bool, ...) are wrapped unchanged with ``torch.from_numpy``.

    Args:
        data: Array to convert.
        float_type: Target numpy dtype for floating-point inputs.

    Returns:
        torch.Tensor. It shares memory with ``data`` whenever no cast was needed.
    """
    if np.issubdtype(data.dtype, np.floating):
        # copy=False: no copy when the array already has the requested dtype.
        return torch.from_numpy(data.astype(float_type, copy=False))
    return torch.from_numpy(data)

In [4]:
class BaseDFDataset(torch.utils.data.Dataset):
    """Base class for datasets built from a time-indexed DataFrame.

    The frame's timestamp column is converted with ``pd.to_datetime`` and the frame is
    sorted chronologically. When ``zero_padding`` is enabled, zero rows are prepended so
    the frame holds at least ``context_length + prediction_length`` rows.
    Subclasses implement ``__getitem__``.
    """

    def __init__(
        self,
        data_df           : pd.DataFrame,
        timestamp_column  : str,
        x_cols            : list = [],
        y_cols            : list = [],
        context_length    : int  = 1,
        prediction_length : int  = 0,
        zero_padding      : bool = True,
    ):
        """
        Args:
            data_df: Source frame. It is not modified; a converted, sorted copy is stored.
            timestamp_column: Name of the timestamp column.
            x_cols: Input column name(s); a single name is wrapped in a list.
            y_cols: Target column name(s); a single name is wrapped in a list.
            context_length: Rows per input window.
            prediction_length: Rows per target window.
            zero_padding: Prepend zero rows when the frame is shorter than one full window.
        """
        super().__init__()
        # The [] defaults are never mutated, so sharing them between calls is safe.
        if not isinstance(x_cols, list): x_cols = [x_cols]
        if not isinstance(y_cols, list): y_cols = [y_cols]

        self.datetime_col      = timestamp_column
        self.x_cols            = x_cols
        self.y_cols            = y_cols
        self.context_length    = context_length
        self.prediction_length = prediction_length
        self.zero_padding      = zero_padding

        # Work on a copy so the caller's frame is left untouched, then sort by time
        # with a fresh 0..n-1 index.
        data_df = data_df.copy()
        data_df[timestamp_column] = pd.to_datetime(data_df[timestamp_column])
        self.data_df = data_df.sort_values(timestamp_column, ignore_index=True)

        # Pad the *sorted* frame; padding the raw input would throw the sort away.
        if zero_padding:
            self.data_df = self.pad_zero(self.data_df)

        self.timestamps = self.data_df[timestamp_column].values

    def pad_zero(self, data_df):
        """Prepend zero rows so ``data_df`` covers one full context + prediction window."""
        return ts_padding(
            data_df,
            timestamp_column= self.datetime_col,
            context_length  = self.context_length + self.prediction_length,
        )

    def __len__(self):
        """Number of full (context + prediction) windows; 0 when the frame is too short."""
        return max(0, len(self.data_df) - self.context_length - self.prediction_length + 1)

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Index
        Returns:
            (Any): Sample and meta data, optionally transformed by the respective transforms.
        """
        raise NotImplementedError

In [5]:
class ForecastDFDataset(BaseDFDataset):
    """Sliding-window forecasting dataset over ``target_columns``.

    Item ``i`` uses rows ``i .. i+context_length-1`` as ``past_values`` and the next
    ``prediction_length`` rows as ``future_values``. Both use the same target columns,
    so the model forecasts the series it observes.
    """

    def __init__(
        self,
        data_df          : pd.DataFrame,
        timestamp_column : str,
        context_length   : int        = 1,
        prediction_length: int        = 1,
        target_columns   : List[str]  = [],
    ):
        """
        Args:
            data_df: Frame holding ``timestamp_column`` and the target columns.
            timestamp_column: Name of the timestamp column used for sorting.
            context_length: Number of past rows per sample.
            prediction_length: Number of future rows per sample.
            target_columns: Columns used both as inputs and as forecast targets.
        """
        self.target_columns = list(target_columns)
        # Inputs and targets are the same columns.
        x_cols = self.target_columns
        y_cols = list(x_cols)

        super().__init__(
            data_df           = data_df,
            timestamp_column  = timestamp_column,
            x_cols            = x_cols,
            y_cols            = y_cols,
            context_length    = context_length,
            prediction_length = prediction_length
        )

    def __getitem__(self, time_id):
        """Return window ``time_id`` as tensors (converted with ``np_to_torch``).

        Returns:
            dict with ``past_values`` of shape (context_length, n_targets) and
            ``future_values`` of shape (prediction_length, n_targets). Batching is
            done by the caller (e.g. a DataLoader).

        Raises:
            IndexError: if ``time_id`` is outside ``[-len(self), len(self))``. Raising
                this also lets plain ``for sample in dataset`` iteration terminate.
        """
        n_items = len(self)
        index = time_id + n_items if time_id < 0 else time_id
        if not 0 <= index < n_items:
            raise IndexError(f"index {time_id} out of range for dataset of length {n_items}")

        # Positional slicing, so the result does not depend on the frame's index labels.
        s_pos = index
        e_pos = index + self.context_length
        seq_x = self.data_df.iloc[s_pos:e_pos][self.x_cols].values
        seq_y = self.data_df.iloc[e_pos:e_pos + self.prediction_length][self.y_cols].values

        return {
            "past_values"  : np_to_torch(seq_x),
            "future_values": np_to_torch(seq_y),
        }

# Data Loading

`timestamp_column`: Nome della colonna contenente i timestamp.

`forecast_columns`: Lista di colonne che devono essere predette.

`context_length`: Numero di timestamp passati (finestra di input) che il modello osserva per produrre ciascuna previsione.

`forecast_horizon`: Numero di timestamp futuri che devono essere predetti (orizzonte di previsione).

`patch_length`: Lunghezza (in timestamp) di ciascuna patch in cui il modello `PatchTST` suddivide la finestra di input; deve essere un divisore di `context_length`.

In [6]:
# Timestamp column name; defined first so the CSV parsing below uses the same setting.
timestamp_column = "date"

emur_dataset     = pd.read_csv("./time_series.csv", parse_dates=[timestamp_column])

# Every column except the timestamp is a series to forecast (independent of column order).
forecast_columns = emur_dataset.columns.drop(timestamp_column)

context_length   = 8
forecast_horizon = 2
patch_length     = 2
assert context_length % patch_length == 0, "patch_length must divide context_length"

# Number of CPUs usable by this process. os.sched_getaffinity does not exist on every
# platform (e.g. macOS, Windows), so fall back to os.cpu_count() there.
if hasattr(os, "sched_getaffinity"):
    num_cpus = len(os.sched_getaffinity(0))
else:
    num_cpus = os.cpu_count() or 1
num_workers      = num_cpus
batch_size       = 256

In [7]:
# Rich display of the loaded frame (pandas truncates long frames to head and tail).
emur_dataset

Unnamed: 0,date,1,2,3,4,5,6,7,8,9,...,22,23,24,25,26,27,28,29,30,31
0,2021-01-01,0.0,0.0,0.0,0.0,7.0,5.0,1.0,0.0,0.0,...,0.0,6.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,2021-01-02,0.0,1.0,1.0,3.0,3.0,5.0,1.0,0.0,0.0,...,0.0,7.0,0.0,0.0,0.0,0.0,3.0,1.0,0.0,0.0
2,2021-01-03,0.0,2.0,0.0,1.0,3.0,3.0,0.0,0.0,0.0,...,5.0,11.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
3,2021-01-04,0.0,1.0,3.0,1.0,5.0,4.0,0.0,0.0,0.0,...,3.0,23.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,2021-01-05,0.0,0.0,1.0,3.0,5.0,8.0,0.0,0.0,1.0,...,3.0,10.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1090,2023-12-27,1.0,0.0,3.0,6.0,9.0,9.0,2.0,0.0,0.0,...,6.0,29.0,0.0,0.0,0.0,0.0,3.0,1.0,0.0,0.0
1091,2023-12-28,0.0,1.0,4.0,9.0,6.0,8.0,0.0,0.0,0.0,...,9.0,16.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
1092,2023-12-29,0.0,0.0,2.0,7.0,5.0,2.0,0.0,0.0,0.0,...,4.0,17.0,0.0,0.0,0.0,0.0,0.0,3.0,0.0,0.0
1093,2023-12-30,0.0,1.0,4.0,11.0,2.0,6.0,0.0,0.0,0.0,...,2.0,14.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


# Splitting

In [8]:
# Chronological split: first 70% train, next ~10% validation, last 20% test.
train_fraction = 0.7
test_fraction  = 0.2

n_rows     = len(emur_dataset)
num_train  = int(n_rows * train_fraction)
num_test   = int(n_rows * test_fraction)
num_valid  = n_rows - num_train - num_test

# NOTE: the validation and test slices start `context_length` rows before the end of the
# previous split, so their first forecast immediately follows the end of training
# (resp. validation). Only the look-back context overlaps; the forecast targets do not.
f1, e1 = 0, num_train
f2, e2 = e1 - context_length, e1 + num_valid
f3, e3 = e2 - context_length, n_rows

train_data, valid_data, test_data = (
    emur_dataset.iloc[start:stop, :].reset_index(drop=True)
    for start, stop in ((f1, e1), (f2, e2), (f3, e3))
)

In [9]:
# Show the three splits one after another (train, validation, test).
display(train_data, valid_data, test_data)


Unnamed: 0,date,1,2,3,4,5,6,7,8,9,...,22,23,24,25,26,27,28,29,30,31
0,2021-01-01,0.0,0.0,0.0,0.0,7.0,5.0,1.0,0.0,0.0,...,0.0,6.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,2021-01-02,0.0,1.0,1.0,3.0,3.0,5.0,1.0,0.0,0.0,...,0.0,7.0,0.0,0.0,0.0,0.0,3.0,1.0,0.0,0.0
2,2021-01-03,0.0,2.0,0.0,1.0,3.0,3.0,0.0,0.0,0.0,...,5.0,11.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
3,2021-01-04,0.0,1.0,3.0,1.0,5.0,4.0,0.0,0.0,0.0,...,3.0,23.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,2021-01-05,0.0,0.0,1.0,3.0,5.0,8.0,0.0,0.0,1.0,...,3.0,10.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
761,2023-02-01,0.0,0.0,4.0,7.0,5.0,8.0,0.0,0.0,1.0,...,1.0,27.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
762,2023-02-02,0.0,1.0,1.0,3.0,2.0,1.0,0.0,0.0,0.0,...,4.0,17.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
763,2023-02-03,0.0,2.0,1.0,5.0,4.0,4.0,3.0,1.0,0.0,...,1.0,22.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0,0.0
764,2023-02-04,1.0,0.0,6.0,3.0,5.0,2.0,2.0,0.0,3.0,...,1.0,17.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0,0.0


Unnamed: 0,date,1,2,3,4,5,6,7,8,9,...,22,23,24,25,26,27,28,29,30,31
0,2023-01-29,0.0,2.0,0.0,5.0,1.0,1.0,1.0,0.0,0.0,...,2.0,11.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,2023-01-30,0.0,0.0,1.0,11.0,7.0,3.0,0.0,0.0,1.0,...,3.0,16.0,0.0,0.0,0.0,0.0,2.0,1.0,0.0,0.0
2,2023-01-31,0.0,1.0,5.0,3.0,2.0,3.0,0.0,1.0,0.0,...,2.0,16.0,0.0,0.0,0.0,0.0,4.0,1.0,0.0,0.0
3,2023-02-01,0.0,0.0,4.0,7.0,5.0,8.0,0.0,0.0,1.0,...,1.0,27.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,2023-02-02,0.0,1.0,1.0,3.0,2.0,1.0,0.0,0.0,0.0,...,4.0,17.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
113,2023-05-22,0.0,1.0,4.0,4.0,9.0,4.0,0.0,0.0,1.0,...,5.0,28.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0
114,2023-05-23,0.0,4.0,0.0,10.0,2.0,1.0,0.0,0.0,1.0,...,6.0,21.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
115,2023-05-24,0.0,0.0,2.0,6.0,3.0,5.0,1.0,0.0,0.0,...,3.0,19.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
116,2023-05-25,0.0,0.0,7.0,11.0,6.0,0.0,0.0,0.0,0.0,...,2.0,16.0,0.0,0.0,0.0,0.0,2.0,1.0,0.0,0.0


Unnamed: 0,date,1,2,3,4,5,6,7,8,9,...,22,23,24,25,26,27,28,29,30,31
0,2023-05-19,0.0,1.0,4.0,4.0,6.0,4.0,0.0,0.0,1.0,...,5.0,21.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0
1,2023-05-20,0.0,0.0,1.0,5.0,5.0,1.0,1.0,0.0,1.0,...,6.0,13.0,0.0,0.0,0.0,0.0,4.0,2.0,0.0,0.0
2,2023-05-21,0.0,1.0,4.0,3.0,4.0,2.0,0.0,0.0,0.0,...,2.0,14.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
3,2023-05-22,0.0,1.0,4.0,4.0,9.0,4.0,0.0,0.0,1.0,...,5.0,28.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0
4,2023-05-23,0.0,4.0,0.0,10.0,2.0,1.0,0.0,0.0,1.0,...,6.0,21.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
222,2023-12-27,1.0,0.0,3.0,6.0,9.0,9.0,2.0,0.0,0.0,...,6.0,29.0,0.0,0.0,0.0,0.0,3.0,1.0,0.0,0.0
223,2023-12-28,0.0,1.0,4.0,9.0,6.0,8.0,0.0,0.0,0.0,...,9.0,16.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
224,2023-12-29,0.0,0.0,2.0,7.0,5.0,2.0,0.0,0.0,0.0,...,4.0,17.0,0.0,0.0,0.0,0.0,0.0,3.0,0.0,0.0
225,2023-12-30,0.0,1.0,4.0,11.0,2.0,6.0,0.0,0.0,0.0,...,2.0,14.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


# Dataset creation

In [10]:
def make_forecast_dataset(split_df: pd.DataFrame) -> "ForecastDFDataset":
    """Build a sliding-window forecasting dataset for one split using the notebook settings."""
    return ForecastDFDataset(
        split_df,
        timestamp_column  = timestamp_column,
        target_columns    = forecast_columns,
        context_length    = context_length,
        prediction_length = forecast_horizon,
    )

train_dataset = make_forecast_dataset(train_data)
valid_dataset = make_forecast_dataset(valid_data)
test_dataset  = make_forecast_dataset(test_data)

# Model creation

In [11]:
# PatchTST forecaster: the context window is cut into non-overlapping patches
# (patch_stride == patch_length) and forecast_horizon future steps are predicted.
patchtst_kwargs = dict(
    # input / output geometry
    num_input_channels  = len(forecast_columns),
    context_length      = context_length,
    prediction_length   = forecast_horizon,
    patch_length        = patch_length,
    patch_stride        = patch_length,
    # transformer backbone
    d_model             = 128,
    num_attention_heads = 16,
    num_hidden_layers   = 3,
    ffn_dim             = 256,
    pre_norm            = True,
    norm_type           = "batchnorm",
    channel_attention   = False,
    # regularisation
    dropout             = 0.2,
    head_dropout        = 0.2,
    # NOTE(review): a masking ratio is a masked-pretraining setting; presumably unused by
    # PatchTSTForPrediction — confirm before tuning it.
    random_mask_ratio   = 0.4,
    # head, scaling and loss
    pooling_type        = None,
    scaling             = "std",
    loss                = "mse",
)
config = PatchTSTConfig(**patchtst_kwargs)
model  = PatchTSTForPrediction(config)

# Training and evaluation

In [12]:
# Directories for the Trainer's logs (logging_dir) and checkpoints (output_dir).
log_path   = "./data/emur/log/"
model_path = "./data/emur/model/"

for directory in (log_path, model_path):
    Path(directory).mkdir(parents=True, exist_ok=True)

In [13]:
# Up to 100 epochs; evaluate, save and log once per epoch; keep at most 3 checkpoints
# and reload the one with the lowest eval_loss when training ends.
training_args = TrainingArguments(
    # outputs
    output_dir                  = model_path,
    overwrite_output_dir        = True,
    logging_dir                 = log_path,
    report_to                   = "wandb",
    # schedule
    num_train_epochs            = 100,
    do_eval                     = True,
    eval_strategy               = "epoch",  # named evaluation_strategy in older transformers releases
    save_strategy               = "epoch",
    logging_strategy            = "epoch",
    save_total_limit            = 3,
    # data loading
    per_device_train_batch_size = batch_size,
    per_device_eval_batch_size  = batch_size,
    dataloader_num_workers      = num_workers,
    # model selection
    load_best_model_at_end      = True,
    metric_for_best_model       = "eval_loss",
    greater_is_better           = False,
    label_names                 = ["future_values"],  # dataset key holding the targets
)

# Stop after 10 epochs without an eval_loss improvement of at least 1e-4.
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_patience  = 10,
    early_stopping_threshold = 0.0001,
)

# Trainer: fits on the training split and evaluates on the validation split.
trainer = Trainer(
    model         = model,
    args          = training_args,
    train_dataset = train_dataset,
    eval_dataset  = valid_dataset,
    callbacks     = [early_stopping_callback],
)



In [None]:
# Start a W&B run for this notebook; training metrics are reported to W&B through
# report_to="wandb" in the TrainingArguments.
wandb.init(
    name="emur",
    project="emur-analysis",
)

In [None]:
# Supervised training of the forecaster (not self-supervised pretraining): fits on
# train_dataset and evaluates on valid_dataset every epoch.
train_output = trainer.train()
train_output

In [None]:
# Evaluate on the held-out test split. metric_key_prefix="test" names the metrics
# "test_loss", ... so they are not confused with the per-epoch validation metrics
# ("eval_loss", ...) in the returned dict and in the W&B logs.
# NOTE(review): EarlyStoppingCallback looks for eval_loss on every evaluate() call and may
# log a warning here; that should be harmless once training has finished — confirm.
results = trainer.evaluate(test_dataset, metric_key_prefix="test")
print("Test result:")
print(results)

In [None]:
# Save the trained forecaster next to the checkpoints in model_path. Because
# load_best_model_at_end=True, the Trainer holds the best-eval_loss model at this point.
# NOTE(review): the folder is called "pretrain", but it contains the supervised forecasting model.
save_dir = os.path.join(model_path, "pretrain")
os.makedirs(save_dir, exist_ok=True)
trainer.save_model(save_dir)

# Close the W&B run started with wandb.init so its data is uploaded and finalized.
wandb.finish()