In [1]:
import sys
from pathlib import Path

# Make the parent of the current working directory importable so that the
# `endata` imports below resolve.
# NOTE(review): assumes the notebook is launched from a folder directly under
# the project root — confirm against the repository layout.
script_dir = Path().resolve()  # Path() is the current working directory
root_dir = script_dir.parent
# Guard the append so re-running this cell does not pile up duplicate entries.
if str(root_dir) not in sys.path:
    sys.path.append(str(root_dir))

import pandas as pd
import numpy as np

from endata.datasets.timeseries_dataset import TimeSeriesDataset
from endata.trainer import Trainer

  from .autonotebook import tqdm as notebook_tqdm


## Training a model from scratch ##

To train your own model from scratch, the `Trainer` class provides a simple implementation. Simply define a custom dataset and a `Trainer` object, and call the `Trainer`'s `fit()` method.

## Writing a custom dataset ##

When creating a custom time series dataset class for use with EnData, the class must inherit from the provided `TimeSeriesDataset` base class. The `TimeSeriesDataset` class provides a robust and modular framework for handling wide-format time series data. Custom implementations only need to assign a name to `self.name` and implement the `_preprocess_data` method, which is an abstract method in the base class. This method should ensure that the data is available as a clean, wide-format DataFrame with the structure outlined below.

### Responsibilities of `_preprocess_data`

- Preprocess raw input data into a DataFrame that satisfies the expected structure.
- Ensure time series columns contain arrays of the correct sequence length (`seq_len`).
- Add any additional columns, such as entity identifiers or context variables.

### Benefits of the Base Class

- **Normalization and Scaling:** Automatically handles standardization and min-max scaling.
- **Context Variables:** Provides support for encoding and managing context variables.
- **Time Series Merging and Splitting:** Facilitates operations to merge multiple time series columns into a single multidimensional array and split them back when needed.
- **Data Transformation:** Includes functions for inverse transformations to revert normalized data to its original scale.

---

### Expected Input DataFrame Structure

The input to the `TimeSeriesDataset` class must adhere to the following structure:

| **Column Name**       | **Description**                                                                                     |
|------------------------|-----------------------------------------------------------------------------------------------------|
| `time_series_col1`     | A column containing arrays of length `seq_len` (after preprocessing) representing the first dimension of the time series. |
| `time_series_col2`     | A column containing arrays of length `seq_len` (after preprocessing) representing the second dimension of the time series.|
| `entity_column`        | A column containing unique identifiers for each entity (e.g., user, household, or device ID).       |
| `context_var1`    | An (optional) static or numeric context variable (e.g., categorical or continuous feature).                |
| `context_var2`    | Further (optional) static or numeric context variables.                                                    |

- The `time_series_column_names` parameter specifies which columns are part of the time series.
- The `entity_column_name` parameter identifies the column containing unique entity IDs.
- The `context_var_column_names` parameter defines additional context variables.

---

In [4]:
class CustomTimeSeriesDataset(TimeSeriesDataset):
    """
    A custom TimeSeriesDataset implementation for handling toy data.

    Input data structure:
    - time_series_col1, time_series_col2: Time series data; each cell holds a
      list or NumPy array with at least seq_len entries.
    - context_var: Categorical or numeric context variable.
    """

    # Column names shared by __init__ and _preprocess_data, kept in one place
    # so the two cannot drift apart.
    _TIME_SERIES_COLUMNS = ("time_series_col1", "time_series_col2")
    _CONTEXT_VAR_COLUMNS = ("context_var",)

    def __init__(
        self,
        data: pd.DataFrame,
        seq_len: int = 16,
        normalize: bool = True,
        scale: bool = True,
    ):
        """
        Forward the toy data and its fixed column layout to the base class.

        Args:
            data (pd.DataFrame): Raw input with the columns listed on the class.
            seq_len (int): Number of time steps kept per sequence.
            normalize (bool): Passed through to TimeSeriesDataset.
            scale (bool): Passed through to TimeSeriesDataset.
        """
        super().__init__(
            data=data,
            time_series_column_names=list(self._TIME_SERIES_COLUMNS),
            context_var_column_names=list(self._CONTEXT_VAR_COLUMNS),
            seq_len=seq_len,
            normalize=normalize,
            scale=scale,
        )

    def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
        """
        Preprocesses the raw input data to ensure it conforms to the expected format.

        - Ensures all required columns are present.
        - Converts list cells to column-vector arrays and truncates every
          sequence to seq_len entries.

        Args:
            data (pd.DataFrame): The raw input data. It is not modified.

        Returns:
            pd.DataFrame: A preprocessed copy of the data.

        Raises:
            ValueError: If a required column is missing, a time series cell is
                neither a list nor a NumPy array, or a sequence has fewer than
                seq_len entries.
        """
        required_columns = (*self._TIME_SERIES_COLUMNS, *self._CONTEXT_VAR_COLUMNS)
        for col in required_columns:
            if col not in data.columns:
                raise ValueError(f"Missing required column: {col}")

        # Work on a copy so the caller's DataFrame is left untouched and
        # re-running the notebook cell starts from the same raw input.
        data = data.copy()
        for col in self._TIME_SERIES_COLUMNS:
            data[col] = data[col].apply(
                lambda seq, col=col: self._coerce_sequence(seq, col)
            )
        return data

    def _coerce_sequence(self, seq, col: str) -> np.ndarray:
        """Validate one time series cell and return it cut to seq_len entries."""
        if isinstance(seq, list):
            # Lists become column vectors of shape (n, 1).
            seq = np.array(seq).reshape(-1, 1)
        elif isinstance(seq, np.ndarray):
            # Arrays are copied with their shape unchanged.
            seq = np.array(seq)
        else:
            # Raise instead of storing an exception object in the frame.
            raise ValueError(f"Invalid data in {col}")

        if len(seq) < self.seq_len:
            raise ValueError(f"Sequence too short in {col}")
        return seq[:self.seq_len]

Now that we have defined our dataset class, let's create some artificial time series columns and context variables which will comprise our dataset:

In [5]:
# Toy-data settings, named so they can be tuned in one place.
NUM_ROWS = 100  # number of rows in the toy dataset
SEQ_LEN = 16    # length of every generated sequence
# Seeded generator so the toy data is identical on every Restart & Run All.
rng = np.random.default_rng(42)

data = pd.DataFrame({
    "time_series_col1": [rng.random(SEQ_LEN).tolist() for _ in range(NUM_ROWS)],
    "time_series_col2": [rng.random(SEQ_LEN).tolist() for _ in range(NUM_ROWS)],
    "context_var": rng.choice(["a", "b", "c"], size=NUM_ROWS).tolist(),
})

# Pass seq_len explicitly so the dataset and the generated sequences agree.
custom_dataset = CustomTimeSeriesDataset(data, seq_len=SEQ_LEN)
custom_dataset.data

Unnamed: 0,index,context_var,timeseries,is_frequency_rare,cluster,is_pattern_rare,is_rare
0,0,2,"[[1.2330412, 0.75476986], [0.47839186, 0.45179...",False,2,True,False
1,1,1,"[[0.22784828, 0.14329956], [0.45609343, 0.1140...",True,1,False,False
2,2,2,"[[0.9639308, 0.11796959], [0.029952327, 0.4122...",False,9,True,False
3,3,2,"[[1.0259975, 0.9888653], [0.23374613, 0.958439...",False,3,True,False
4,4,1,"[[0.7487496, 0.6402036], [0.34301788, 0.457638...",True,5,True,True
...,...,...,...,...,...,...,...
95,95,1,"[[1.1911998, 0.46278498], [0.5211637, 0.323662...",True,6,True,True
96,96,0,"[[0.80190253, 0.47433493], [0.29261357, 0.3672...",False,1,False,False
97,97,1,"[[0.37848985, 0.62985605], [-0.08124897, 0.364...",True,9,True,True
98,98,1,"[[0.9047128, 0.14532897], [-0.023065194, 0.476...",True,9,True,True


We will now create a `Trainer` object by passing the name of the desired model and the dataset object. To start training, simply call `Trainer.fit()`.

In [6]:
# Build a trainer for the "acgan" model on the custom dataset, then start training.
trainer = Trainer(model_name="acgan", dataset=custom_dataset)
trainer.fit()

Epoch 1: 100%|██████████| 1/1 [00:00<00:00,  2.31it/s]
Epoch 2: 100%|██████████| 1/1 [00:00<00:00, 24.33it/s]
Epoch 3: 100%|██████████| 1/1 [00:00<00:00, 26.08it/s]
Epoch 4: 100%|██████████| 1/1 [00:00<00:00, 24.67it/s]
Epoch 5: 100%|██████████| 1/1 [00:00<00:00, 21.35it/s]


Once training is complete, we can create a data generator object that has access to the trained model and dataset information. To generate data, there is no need to load in a trained model. Simply define the context variables, and call the `DataGenerator`'s `generate()` method.

In [7]:
# Ask the trainer for a data generator object.
data_generator = trainer.get_data_generator()

In [8]:
# Context values to condition generation on, keyed by context column name.
# NOTE(review): 2 is presumably the integer-encoded category of `context_var`
# (the dataset preview shows integer codes) — confirm the mapping.
context_vars = {
    "context_var": 2
}
data_generator.set_model_context_vars(context_vars)
# Request 100 samples; the bare name on the last line displays the result.
generated_df = data_generator.generate(num_samples=100)
generated_df

Unnamed: 0,context_var,timeseries
0,2,"[[0.4885186, 0.16913897], [0.6761445, 0.823993..."
1,2,"[[0.3628338, 0.1661479], [0.6694386, 0.6434852..."
2,2,"[[0.3997703, 0.20202447], [0.3828854, 0.322381..."
3,2,"[[0.25014755, 0.26958862], [0.5687821, 0.71890..."
4,2,"[[0.31721854, 0.46014428], [0.5777184, 0.33343..."
...,...,...
95,2,"[[0.39032197, 0.3213056], [0.51057315, 0.57828..."
96,2,"[[0.36170888, 0.20438792], [0.4092174, 0.50627..."
97,2,"[[0.39201736, 0.40636167], [0.6234959, 0.46320..."
98,2,"[[0.44466335, 0.24384849], [0.582526, 0.819687..."
