# Sample dataset

In [1]:
import os

import pandas as pd

In [2]:
# Feature schema of the sample dataset: one entry per column, built from
# (feature name, feature type) pairs so the listing stays compact.
schema = {
    'Features': [
        {'FeatureName': feature_name, 'FeatureType': feature_type}
        for feature_name, feature_type in (
            ('target', 'float'),
            ('group_id_0', 'int'),
            ('group_id_1', 'int'),
            ('timestamp', 'timestamp'),
        )
    ]
}

In [3]:
# Load the training data and rename its columns to the feature names
# declared in ``schema`` (group ids, timestamp and target).
df = pd.read_csv('data/train.csv').rename(
    columns=dict(
        store='group_id_0',
        item='group_id_1',
        date='timestamp',
        sales='target',
    )
)

In [6]:
# Cast the target to float and parse the timestamp column into datetimes.
df = df.assign(
    target=df['target'].astype(float),
    timestamp=pd.to_datetime(df['timestamp']),
)

# Resolving bucket

In [7]:
from minio import Minio
import pyarrow.parquet as pq
import s3fs

In [15]:
class DatasetBridge:
    """Bridges buckets and dataset creation.

    Stores the bucket name and the credentials, and uses them to open
    parquet datasets that live inside the bucket through ``s3fs``.

    Parameters
    ----------
    bucket_name : str
        Name of the bucket that holds the datasets.
    access_key : str
        Access key forwarded to the S3 client.
    secret_key : str
        Secret key forwarded to the S3 client.
    """
    MINIO_ENDPOINT = 'http://minio:9000'
    # Not used when building paths; kept so existing references keep working.
    ROOT_PATH = 'data/'

    def __init__(self, bucket_name, access_key, secret_key):
        self.bucket_name = bucket_name
        self.access_key = access_key
        self.secret_key = secret_key

    def get_parquet_dataset(self, base_dir):
        """Open the parquet dataset stored under ``base_dir`` in the bucket.

        Parameters
        ----------
        base_dir : str
            Location of the dataset relative to the bucket root.

        Returns
        -------
        pyarrow.parquet.ParquetDataset
        """
        s3_path = self._make_s3_root_path(base_dir)
        fs = self._get_s3_filesystem()
        return pq.ParquetDataset(s3_path, filesystem=fs)

    def _make_s3_root_path(self, *args):
        """Build ``s3://<bucket>/<segment>/...`` from the given segments.

        Leading/trailing slashes are stripped from every segment and empty
        segments are dropped, so inputs such as ``'/target/'`` cannot
        produce doubled or dangling slashes in the resulting path.

        Parameters
        ----------
        *args : str
            Path segments below the bucket root.

        Returns
        -------
        str
        """
        stripped = (segment.strip('/') for segment in (self.bucket_name, *args))
        return 's3://' + '/'.join(segment for segment in stripped if segment)

    def _get_s3_filesystem(self):
        """Create the ``s3fs`` filesystem pointing at ``MINIO_ENDPOINT``.

        Returns
        -------
        s3fs.S3FileSystem
        """
        # Endpoint and credentials are handed to the underlying S3 client.
        client_kwargs = {
            'endpoint_url': self.MINIO_ENDPOINT,
            'aws_access_key_id': self.access_key,
            'aws_secret_access_key': self.secret_key,
            'verify': False
        }
        return s3fs.S3FileSystem(anon=False, use_ssl=False,
                                 client_kwargs=client_kwargs)
    
    
class Dataset:
    """Interface for parquet datasets.

    Wraps a parquet dataset object and exposes helpers to read it as a
    pandas DataFrame, inspect its schema and merge it with another dataset.

    Parameters
    ----------
    parquet_ds : object
        Parquet dataset exposing ``read_pandas()`` and ``schema``
        (e.g. ``pyarrow.parquet.ParquetDataset``).
    """
    def __init__(self, parquet_ds):
        self.parquet_ds = parquet_ds

    def get_pandas_df(self):
        """Read the whole dataset and return it as a pandas DataFrame."""
        return self.parquet_ds.read_pandas().to_pandas()

    def get_group_ids(self):
        """Return the schema column names that start with ``'group_id'``."""
        return [
            name for name in self.get_arrow_schema().names
            if name.startswith('group_id')
        ]

    def get_arrow_schema(self):
        """Return the dataset schema as an arrow schema.

        When the dataset's ``schema`` object offers ``to_arrow_schema()`` it
        is converted through it; otherwise the schema object is returned
        unchanged, which avoids an ``AttributeError`` for datasets whose
        ``schema`` already is an arrow schema.
        """
        schema = self.parquet_ds.schema
        to_arrow = getattr(schema, 'to_arrow_schema', None)
        return to_arrow() if callable(to_arrow) else schema

    def merge(self, parquet_ds):
        """Inner-join this dataset with another one.

        The join keys are this dataset's group id columns plus
        ``'timestamp'``.

        Parameters
        ----------
        parquet_ds : Dataset or parquet dataset
            Either an object exposing ``get_pandas_df()`` (such as another
            ``Dataset``) or a raw parquet dataset, which is wrapped first.

        Returns
        -------
        pandas.DataFrame
        """
        # The parameter name suggests a raw parquet dataset, so accept both
        # the wrapper and the raw object.
        if hasattr(parquet_ds, 'get_pandas_df'):
            other = parquet_ds
        else:
            other = Dataset(parquet_ds)
        join_keys = self.get_group_ids() + ['timestamp']
        return pd.merge(
            left=self.get_pandas_df(),
            right=other.get_pandas_df(),
            on=join_keys
        )

# Estimator

In [24]:
from mooncake.nn import SeqToSeq, TemporalFusionTransformer as TFT

# Registry mapping an algorithm name to the estimator class implementing it.
ESTIMATORS = dict(seq2seq=SeqToSeq, tft=TFT)


class EstimatorCreator:
    """Creates the estimator selected by a predictor.

    Parameters
    ----------
    predictor : object
        Object whose ``algorithm`` attribute is a key of ``ESTIMATORS``.
    target_dataset : object
        Target dataset; stored on the instance for use when deriving the
        estimator arguments.
    """
    def __init__(self, predictor, target_dataset):
        self.predictor = predictor
        self.target_dataset = target_dataset

    def create_estimator(self):
        """Instantiate the estimator class with the resolved arguments.

        Returns
        -------
        object
            Instance of the class registered for ``predictor.algorithm``.

        Raises
        ------
        KeyError
            If ``predictor.algorithm`` is not registered in ``ESTIMATORS``.
        """
        cls = self._get_estimator_class()
        # Guard against ``None`` so ``**`` unpacking cannot fail with a
        # TypeError when no arguments have been resolved.
        estimator_args = self._get_estimator_args() or {}
        return cls(**estimator_args)

    def _get_estimator_class(self):
        """Look up the estimator class for ``predictor.algorithm``."""
        algorithm = self.predictor.algorithm
        try:
            return ESTIMATORS[algorithm]
        except KeyError:
            raise KeyError(
                'Unknown algorithm {!r}. Available: {}'.format(
                    algorithm, sorted(ESTIMATORS)
                )
            ) from None

    def _get_estimator_args(self):
        """Return the keyword arguments for the estimator constructor.

        TODO: not implemented yet; an empty mapping is returned so the
        estimator is built with its own defaults.
        """
        return {}


class EstimatorArgsCreator:
    """Derives estimator constructor arguments.

    Parameters
    ----------
    predictor : object
        Predictor the arguments are derived for.
    target_dataset : object
        Target dataset the arguments are derived from.
    """
    def __init__(self, predictor, target_dataset):
        self.predictor = predictor
        self.target_dataset = target_dataset

    def get_estimator_args(self):
        """Return the keyword arguments for the estimator constructor.

        TODO: not implemented yet. An empty dict is returned (instead of
        ``None``) so the result can always be unpacked with ``**``.

        Returns
        -------
        dict
        """
        return {}

# Putting it all together 

In [23]:
# Data bridge between minio and mlpi.
# Credentials are read from the environment (MINIO_ACCESS_KEY /
# MINIO_SECRET_KEY) so they do not have to live in the notebook.
# NOTE(review): the literal fallbacks are presumably local-development
# credentials -- confirm, and drop them once the environment is configured.
bridge = DatasetBridge(
    bucket_name='sample',
    access_key=os.environ.get('MINIO_ACCESS_KEY', 'oxxo'),
    secret_key=os.environ.get('MINIO_SECRET_KEY', 'password'),
)

# Create ``dataset`` object which is an easy-to-use parquet dataset interface.
parquet_dataset = bridge.get_parquet_dataset('target')
dataset = Dataset(parquet_dataset)