In [14]:
import sys
from pathlib import Path

script_dir = Path().resolve()
root_dir = (script_dir.parent)
sys.path.append(str(root_dir))

import pandas as pd
import numpy as np

from datasets.timeseries_dataset import TimeSeriesDataset
from endata.trainer import Trainer

## Training a model from scratch ##

To train your own model from scratch, the ` Trainer ` class provides a simple implementation. Simply define your a custom dataset and a ` Trainer ` object, and call the ` Trainer ` 's ` fit() ` method.

## Writing a custom dataset ##

When creating a custom time series dataset class for use with EnData, the class must inherit from the provided `TimeSeriesDataset` base class. The `TimeSeriesDataset` class provides a robust and modular framework for handling wide-format time series data. Custom implementations only need to assign a name to `self.name` and implement the `_preprocess_data` method, which is an abstract method in the base class. This method should ensure that the data is available in a clean wide-format data frame, that has the structure outlined below.

### Responsibilities of `_preprocess_data`

- Preprocess raw input data into a DataFrame that satisfies the expected structure.
- Ensure time series columns contain arrays of the correct sequence length (`seq_len`).
- Add any additional columns, such as entity identifiers or conditioning variables.

### Benefits of the Base Class

- **Normalization and Scaling:** Automatically handles standardization and min-max scaling.
- **Conditioning Variables:** Provides support for encoding and managing conditioning variables.
- **Time Series Merging and Splitting:** Facilitates operations to merge multiple time series columns into a single multidimensional array and split them back when needed.
- **Data Transformation:** Includes functions for inverse transformations to revert normalized data to its original scale.

---

### Expected Input DataFrame Structure

The input to the `TimeSeriesDataset` class must adhere to the following structure:

| **Column Name**       | **Description**                                                                                     |
|------------------------|-----------------------------------------------------------------------------------------------------|
| `timeseries_col1`      | A column containing arrays of length `seq_len` (after preprocessing) representing the first dimension of the time series. |
| `timeseries_col2`      | A column containing arrays of length `seq_len` (after preprocessing) representing the second dimension of the time series.|
| `entity_column`        | A column containing unique identifiers for each entity (e.g., user, household, or device ID).       |
| `conditioning_var1`    | An (optional) static or numeric conditioning variable (e.g., categorical or continuous feature).                |
| `conditioning_var2`    | Further (optional) static or numeric conditioning variables.                                                    |

- The `time_series_column_names` parameter specifies which columns are part of the time series.
- The `entity_column_name` parameter identifies the column containing unique entity IDs.
- The `conditioning_var_column_names` parameter defines additional conditioning variables.

---

In [15]:
class CustomTimeSeriesDataset(TimeSeriesDataset):
    """
    A custom TimeSeriesDataset implementation for handling toy data.

    Input data structure:
    - time_series_col1, time_series_col2: Time series data with arrays of length seq_len.
    - entity_id: Unique identifier for each entity.
    - static_conditioning: Categorical or numeric conditioning variable.
    """
    def __init__(
        self,
        data: pd.DataFrame,
        seq_len: int = 16,
        normalize: bool = True,
        scale: bool = True,
    ):
        time_series_column_names = ["time_series_col1", "time_series_col2"]
        conditioning_var_column_names = ["conditioning_var"]

        super().__init__(
            data=data,
            time_series_column_names=time_series_column_names,
            conditioning_var_column_names=conditioning_var_column_names,
            seq_len=seq_len,
            normalize=normalize,
            scale=scale,
        )

    def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
        """
        Preprocesses the raw input data to ensure it conforms to the expected format.

        - Ensures time series columns contain arrays of length seq_len.
        - Ensures all required columns are present.

        Args:
            data (pd.DataFrame): The raw input data.

        Returns:
            pd.DataFrame: The preprocessed data.
        """
        required_columns = ["time_series_col1", "time_series_col2", "conditioning_var"]
        for col in required_columns:
            if col not in data.columns:
                raise ValueError(f"Missing required column: {col}")

        for col in ["time_series_col1", "time_series_col2"]:
            data[col] = data[col].apply(
                lambda x: np.array(x).reshape(-1, 1) if isinstance(x, list) else x
            )
            data[col] = data[col].apply(
                lambda x: np.array(x) if isinstance(x, np.ndarray) else ValueError(f"Invalid data in {col}")
            )
        for col in ["time_series_col1", "time_series_col2"]:
            data[col] = data[col].apply(
                lambda x: x[:self.seq_len] if len(x) >= self.seq_len else ValueError(f"Sequence too short in {col}")
            )
        return data

Now that we have defined our dataset class, let's create some artificial timeseries columns and conditioning variables which will comprise our dataset:

In [16]:
data = pd.DataFrame({
        "time_series_col1": [np.random.rand(16).tolist() for _ in range(100)],
        "time_series_col2": [np.random.rand(16).tolist() for _ in range(100)],
        "conditioning_var": np.random.choice(["a", "b", "c"], size=100).tolist(),
    })

custom_dataset = CustomTimeSeriesDataset(data)
custom_dataset.data

Unnamed: 0,index,conditioning_var,timeseries,is_frequency_rare,cluster,is_pattern_rare,is_rare
0,0,0,"[[0.92653304, 1.009781], [0.42525345, 0.831920...",True,3,True,True
1,1,0,"[[0.776322, 1.0532234], [0.49634635, 0.1887229...",True,4,False,False
2,2,1,"[[0.24718751, 0.25115073], [0.7902722, 0.31481...",False,5,False,False
3,3,0,"[[0.567987, 0.6009842], [0.2487371, 0.774192],...",True,4,False,False
4,4,0,"[[0.23009494, 0.72973764], [0.51643836, 0.3718...",True,4,False,False
...,...,...,...,...,...,...,...
95,95,2,"[[0.77344924, 0.7023896], [0.758659, 0.5280927...",False,7,True,False
96,96,0,"[[0.50343215, 0.7812002], [0.68687725, 1.02030...",True,1,True,True
97,97,1,"[[0.41048184, 0.5843879], [0.60409164, 1.09790...",False,4,False,False
98,98,2,"[[0.38684386, 0.47564107], [0.93689454, 0.3493...",False,9,True,False


We will now create a `Trainer` object by passing the name of the desired model and the dataset object. To start training, simply call `Trainer.fit()`.

In [17]:
trainer = Trainer(model_name="acgan", dataset=custom_dataset)
trainer.fit()

Epoch 1: 100%|██████████| 1/1 [00:00<00:00, 21.53it/s]
Epoch 2: 100%|██████████| 1/1 [00:00<00:00, 24.58it/s]
Epoch 3: 100%|██████████| 1/1 [00:00<00:00, 23.85it/s]
Epoch 4: 100%|██████████| 1/1 [00:00<00:00, 23.18it/s]
Epoch 5: 100%|██████████| 1/1 [00:00<00:00, 25.27it/s]


Once training is complete, we can create a data generator object that has access to the trained model and dataset information. To generate data, there is no need to load in a trained model. Simply define the conditioning variables, and call the `DataGenerator` 's `generate()` method.

In [18]:
data_generator = trainer.get_data_generator()

In [19]:
conditioning_vars = {
    "conditioning_var": 2
}
data_generator.set_model_conditioning_vars(conditioning_vars)
generated_df = data_generator.generate(num_samples=100)
generated_df

Unnamed: 0,conditioning_var,timeseries
0,2,"[[0.33624244, 0.30665344], [0.15576483, 0.1029..."
1,2,"[[0.27478665, 0.20339221], [0.18739967, 0.4558..."
2,2,"[[0.5461618, 0.21532051], [-0.124457866, 0.051..."
3,2,"[[0.5325349, 0.24832326], [0.3949008, -0.01513..."
4,2,"[[0.5069144, 0.21480957], [0.32279572, 0.14314..."
...,...,...
95,2,"[[0.3393251, 0.14905047], [0.3640431, 0.159490..."
96,2,"[[0.5540692, 0.32635656], [0.3118684, 0.286660..."
97,2,"[[0.41061264, 0.41596156], [0.5272329, 0.21765..."
98,2,"[[0.46809554, 0.32013273], [0.31512564, 0.1190..."
