In [None]:
from __future__ import annotations

In [None]:
import abc
from configparser import ConfigParser
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union
import logging

In [None]:
import dask.bag as db  # type: ignore
import dask.dataframe as dd  # type: ignore
import pandas as pd  # type: ignore

In [None]:
from eos.data_io.eos_struct import ObservationMeta, PoolQuery

In [None]:
from .pool import Pool  # type: ignore

In [None]:
@dataclass(kw_only=True)
class DaskPool(Pool[pd.DataFrame]):
    """
    Abstract dask-based pool for the real-time data collected from the cloud.

    The data lives in a single local folder, ``<data_folder>/<coll_type>``, built
    from the ``DEFAULT`` section of ``recipe``. Its capacity is therefore bounded
    only by local system storage. Per the original design notes, concrete
    subclasses persist the data on disk (parquet for RECORD, avro for EPISODE;
    parquet meta information in the file footer). They expose it lazily as dask
    collections through :meth:`get_query` and implement :meth:`sample`.
    :meth:`find` and :meth:`_count` are built on top of :meth:`get_query`.

    TODO alternative will be save experience in episodes
    TODO sample random quadruples will need some care to reassure the randomness
    TODO using dask delayed to parallelize the data processing like sampling,
    while appending data

    Attributes:
        recipe: config for the pool; its ``DEFAULT`` section must provide
            ``data_folder``, ``recipe_file_name`` and ``coll_type``.
        query: pool query (in a record pool it is mostly ignored for sampling
            and used only for checking).
        meta: meta information for the data collection.
        pl_path: storage location; always recomputed in ``__post_init__`` as
            ``data_folder / coll_type``, so any value passed in is overwritten.
        logger: optional logger instance.
        dict_logger: optional logging configuration dict.
    """

    recipe: ConfigParser
    query: PoolQuery  # searching records by query in arrays is very inefficient
    meta: ObservationMeta  # meta information for the data collection
    pl_path: Optional[Path] = None  # parquet file for RECORD, avro file for EPISODE
    logger: Optional[logging.Logger] = None
    dict_logger: Optional[dict] = None

    def __post_init__(self):
        """
        Validate the recipe and derive the storage path.

        Raises:
            ValueError: if the recipe's ``DEFAULT`` section lacks any of
                ``data_folder``, ``recipe_file_name`` or ``coll_type``.
        """
        super().__post_init__()

        # All required specification for the data collection must be available.
        # Raised explicitly (not via assert) so the check survives `python -O`.
        required_keys = ("data_folder", "recipe_file_name", "coll_type")
        defaults = self.recipe["DEFAULT"]
        missing = [key for key in required_keys if key not in defaults]
        if missing:
            raise ValueError(f"recipe specification incomplete! missing: {missing}")

        # coll_type is part of the storage location, for example
        # 'data_folder'/'RECORD' or 'data_folder'/'EPISODE'.
        self.pl_path = Path(defaults["data_folder"]) / defaults["coll_type"]

    def find(self, query: PoolQuery) -> Optional[pd.DataFrame]:
        """
        Find records by PoolQuery with
            - vehicle
            - driver
            - episodestart_start
            - episodestart_end
            - timestamp_start
            - timestamp_end

        Returns:
            A pandas DataFrame with all records in the query time range, or None
            when :meth:`get_query` returns None.

        Raises:
            AssertionError: if the computed result is not a pandas DataFrame
                (e.g. :meth:`get_query` returned a dask Bag).
        """
        items = self.get_query(query)
        if items is None:
            # Honour the Optional return type instead of failing on None.compute().
            return None
        df = items.compute()
        assert isinstance(
            df, pd.DataFrame
        ), f"df is not a pandas DataFrame but {type(df).__name__}"
        return df

    @abc.abstractmethod
    def get_query(
        self, query: Optional[PoolQuery] = None
    ) -> Optional[Union[dd.DataFrame, db.Bag]]:
        """
        Get records by PoolQuery with
            - vehicle
            - driver
            - episodestart_start
            - episodestart_end
            - timestamp_start
            - timestamp_end

        Returns:
            A lazy dask DataFrame or dask Bag with all records in the query time
            range, or None.
        """

        pass

    def _count(self, query: Optional[PoolQuery] = None) -> int:
        """
        Count the number of records in the pool that match ``query``.

        query = {
            - vehicle
            - driver
            - episodestart_start
            - episodestart_end
            - timestamp_start
            - timestamp_end
        }

        Returns:
            The number of matching records; 0 when :meth:`get_query` returns None.
        """
        items = self.get_query(query)  # either a dask dataframe or a dask bag
        if items is None:
            return 0
        if isinstance(items, db.Bag):
            # Count per partition instead of collecting every record into a list.
            return int(items.count().compute())
        if isinstance(items, dd.DataFrame):
            # len() on a dask DataFrame sums partition lengths without
            # concatenating the whole frame into one pandas object.
            return len(items)
        # Any other lazy collection: keep the original materialize-and-count path.
        return len(items.compute())

    @abc.abstractmethod
    def sample(self, size: int, *, query: Optional[PoolQuery] = None) -> pd.DataFrame:
        """
        Sample a batch of ``size`` records from the pool.

        Returns:
            A pandas DataFrame with the sampled records.
        """

In [None]:
if __name__ == "__main__":
    # This module only defines the DaskPool base class; nothing runs when it is
    # executed directly.
    ...