In [1]:
import sys
from pathlib import Path

# Make the repository root importable so the `datasets` and `generator`
# packages resolve. `Path()` is the kernel's current working directory, so this
# relies on the notebook being run from its own folder (one level below the root).
script_dir = Path().resolve()
root_dir = script_dir.parent
# Guard the append so that re-running this cell does not pile up duplicate
# entries in sys.path.
if str(root_dir) not in sys.path:
    sys.path.append(str(root_dir))

import pandas as pd
import numpy as np

from datasets.pecanstreet import PecanStreetDataset
from datasets.openpower import OpenPowerDataset
from datasets.timeseries_dataset import TimeSeriesDataset
from generator.data_generator import DataGenerator

2024-11-21 16:18:39.906619: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-11-21 16:18:39.920677: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-11-21 16:18:39.924983: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-21 16:18:39.935646: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from .autonotebook import tqdm as notebook_

## Training a model from scratch ##

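To train your own model from scratch, the `DataGenerator` class provides a simple interface. Define your dataset and a `DataGenerator` object, attach the dataset with `set_dataset()`, and then call the `DataGenerator`'s `fit()` method. The cell below creates the two objects; the `set_dataset()` and `fit()` calls are demonstrated at the end of this tutorial: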

In [2]:
# Load the OpenPower dataset using its default constructor arguments.
dataset = OpenPowerDataset()
# Create a generator configured to use the "acgan" model.
# NOTE(review): `dataset` is not attached to `generator` in this cell; the
# set_dataset()/fit() pattern used in cell In [5] presumably applies here too — confirm.
generator = DataGenerator(model_name="acgan")

## Training a model on custom data ##

When creating a custom time series dataset class for use with EnData, the class must inherit from the provided `TimeSeriesDataset` base class. The `TimeSeriesDataset` class provides a robust and modular framework for handling wide-format time series data. Custom implementations only need to define the `_preprocess_data` method, which is an abstract method in the base class.

### Responsibilities of `_preprocess_data`

- Preprocess raw input data into a DataFrame that satisfies the expected structure.
- Ensure time series columns contain arrays of the correct sequence length (`seq_len`).
- Add any additional columns, such as entity identifiers or conditioning variables.

### Benefits of the Base Class

- **Normalization and Scaling:** Automatically handles standardization and min-max scaling.
- **Conditioning Variables:** Provides support for encoding and managing conditioning variables.
- **Time Series Merging and Splitting:** Facilitates operations to merge multiple time series columns into a single multidimensional array and split them back when needed.
- **Data Transformation:** Includes functions for inverse transformations to revert normalized data to its original scale.

---

### Expected Input DataFrame Structure

The input to the `TimeSeriesDataset` class must adhere to the following structure:

| **Column Name**       | **Description**                                                                                     |
|------------------------|-----------------------------------------------------------------------------------------------------|
| `timeseries_col1`      | A column containing arrays of length `seq_len` (after preprocessing) representing the first dimension of the time series. |
| `timeseries_col2`      | A column containing arrays of length `seq_len` (after preprocessing) representing the second dimension of the time series.|
| `entity_column`        | A column containing unique identifiers for each entity (e.g., user, household, or device ID).       |
| `conditioning_var1`    | An (optional) static or numeric conditioning variable (e.g., categorical or continuous feature).                |
| `conditioning_var2`    | Further (optional) static or numeric conditioning variables.                                                    |

- The `time_series_column_names` parameter specifies which columns are part of the time series.
- The `entity_column_name` parameter identifies the column containing unique entity IDs.
- The `conditioning_var_column_names` parameter defines additional conditioning variables.

---

In [3]:
class CustomTimeSeriesDataset(TimeSeriesDataset):
    """
    A custom TimeSeriesDataset implementation for handling toy data.

    Input data structure:
    - time_series_col1, time_series_col2: Time series data, one list or NumPy
      array per row, each with at least seq_len entries.
    - entity_id: Unique identifier for each entity.
    - conditioning_var: Categorical or numeric conditioning variable.
    """

    # Column names live in one place so that __init__ and _preprocess_data
    # cannot drift apart.
    ENTITY_COLUMN = "entity_id"
    TIME_SERIES_COLUMNS = ("time_series_col1", "time_series_col2")
    CONDITIONING_COLUMNS = ("conditioning_var",)

    def __init__(
        self,
        data: pd.DataFrame,
        seq_len: int = 16,
        normalize: bool = True,
        scale: bool = True,
    ):
        """
        Build the dataset from a raw DataFrame.

        Args:
            data (pd.DataFrame): Raw input data containing the required columns.
            seq_len (int): Number of time steps kept per sequence.
            normalize (bool): Forwarded unchanged to the base class.
            scale (bool): Forwarded unchanged to the base class.
        """
        super().__init__(
            data=data,
            entity_column_name=self.ENTITY_COLUMN,
            time_series_column_names=list(self.TIME_SERIES_COLUMNS),
            conditioning_var_column_names=list(self.CONDITIONING_COLUMNS),
            seq_len=seq_len,
            normalize=normalize,
            scale=scale,
        )

    def _coerce_sequence(self, value, column: str) -> np.ndarray:
        """
        Validate one time series cell and cut it down to seq_len entries.

        Lists are converted to column vectors of shape (n, 1); NumPy arrays are
        kept in the shape they arrive in.

        Raises:
            ValueError: If the value is neither a list nor a sized NumPy array,
                or if it holds fewer than seq_len entries.
        """
        if isinstance(value, list):
            value = np.array(value).reshape(-1, 1)
        # A 0-d array has no length, so it cannot represent a sequence either.
        if not isinstance(value, np.ndarray) or value.ndim == 0:
            raise ValueError(f"Invalid data in {column}")
        if len(value) < self.seq_len:
            raise ValueError(f"Sequence too short in {column}")
        # Copy so the stored sequence never aliases an array owned by the caller.
        return value[: self.seq_len].copy()

    def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
        """
        Preprocesses the raw input data to ensure it conforms to the expected format.

        - Ensures all required columns are present.
        - Ensures time series columns contain arrays truncated to seq_len entries.

        Args:
            data (pd.DataFrame): The raw input data. It is not modified.

        Returns:
            pd.DataFrame: A preprocessed copy of the data.

        Raises:
            ValueError: If a required column is missing, a time series cell is
                not a list or NumPy array, or a sequence is shorter than seq_len.
        """
        required_columns = [
            self.ENTITY_COLUMN,
            *self.TIME_SERIES_COLUMNS,
            *self.CONDITIONING_COLUMNS,
        ]
        for col in required_columns:
            if col not in data.columns:
                raise ValueError(f"Missing required column: {col}")

        # Work on a copy so the caller's DataFrame keeps its raw values.
        data = data.copy()
        for col in self.TIME_SERIES_COLUMNS:
            # Bind the column name as a default argument so each lambda reports
            # its own column in error messages.
            data[col] = data[col].apply(
                lambda value, column=col: self._coerce_sequence(value, column)
            )
        return data

In [4]:
# Toy data for the custom dataset: one row per entity, two random time series
# per row, and one integer conditioning variable in the range 0..4.
N_ENTITIES = 100
SERIES_LEN = 16  # must be at least the seq_len used by CustomTimeSeriesDataset (default 16)
SEED = 42

# A seeded generator keeps the toy data identical across Restart & Run All.
rng = np.random.default_rng(SEED)

data = pd.DataFrame({
    "entity_id": [f"entity_{i}" for i in range(N_ENTITIES)],
    "time_series_col1": [rng.random(SERIES_LEN).tolist() for _ in range(N_ENTITIES)],
    "time_series_col2": [rng.random(SERIES_LEN).tolist() for _ in range(N_ENTITIES)],
    "conditioning_var": rng.integers(0, 5, size=N_ENTITIES).tolist(),
})

# Wrap the toy frame in the custom dataset class; seq_len, normalize and scale
# are left at the class defaults (16, True, True).
custom_dataset = CustomTimeSeriesDataset(data)

In [5]:
# Point the generator created in In [2] at the custom dataset, then train.
# The recorded output below shows per-epoch progress bars followed by an
# ACGAN checkpoint being saved.
generator.set_dataset(custom_dataset)
generator.fit()

Epoch 1: 100%|██████████| 4/4 [00:00<00:00,  8.09it/s]
Epoch 2: 100%|██████████| 4/4 [00:00<00:00, 41.60it/s]
Epoch 3: 100%|██████████| 4/4 [00:00<00:00, 35.36it/s]
Epoch 4: 100%|██████████| 4/4 [00:00<00:00, 36.70it/s]
Epoch 5: 100%|██████████| 4/4 [00:00<00:00, 35.50it/s]
Epoch 6: 100%|██████████| 4/4 [00:00<00:00, 36.86it/s]
Epoch 7: 100%|██████████| 4/4 [00:00<00:00, 35.06it/s]
Epoch 8: 100%|██████████| 4/4 [00:00<00:00, 34.38it/s]
Epoch 9: 100%|██████████| 4/4 [00:00<00:00, 33.62it/s]
Epoch 10: 100%|██████████| 4/4 [00:00<00:00, 35.72it/s]
Epoch 11: 100%|██████████| 4/4 [00:00<00:00, 35.09it/s]
Epoch 12: 100%|██████████| 4/4 [00:00<00:00, 35.90it/s]
Epoch 13: 100%|██████████| 4/4 [00:00<00:00, 36.27it/s]
Epoch 14: 100%|██████████| 4/4 [00:00<00:00, 34.59it/s]
Epoch 15: 100%|██████████| 4/4 [00:00<00:00, 37.40it/s]
Epoch 16: 100%|██████████| 4/4 [00:00<00:00, 34.14it/s]
Epoch 17: 100%|██████████| 4/4 [00:00<00:00, 36.76it/s]
Epoch 18: 100%|██████████| 4/4 [00:00<00:00, 34.00it/s]
E

Saved ACGAN checkpoint to /home/fuest/EnData/tutorials/acgan_checkpoint_200.pt



