In [None]:
#default_exp data_model

# Data model

> Pydantic models that validate the flow configuration (e.g. as loaded from a YAML file).

In [None]:
#export
from enum import Enum
from typing import Dict, List, Optional, Union

try:  
    from typing import Literal  # python>=3.8
except ImportError:
    from typing_extensions import Literal

import window_ops.ewm
import window_ops.expanding
import window_ops.rolling
from pydantic import BaseModel, root_validator
    
from mlforecast.core import date_features_dtypes

In [None]:
from pprint import pprint

import yaml
from nbdev import test_fail

In [None]:
#exporti
# Map every transform name exported (via `__all__`) by window_ops' rolling,
# expanding and ewm modules to its function. Built with a comprehension so the
# iteration variables don't leak into this exported module's namespace.
# If two modules export the same name, the later module (in the order below) wins.
_available_tfms = {
    tfm_name: getattr(tfm_module, tfm_name)
    for tfm_module in (window_ops.rolling, window_ops.expanding, window_ops.ewm)
    for tfm_name in tfm_module.__all__
}

In [None]:
#export
# Literal types built at import time from the known names, so the pydantic
# models below only accept date features that are keys of
# `date_features_dtypes` and transforms that are keys of `_available_tfms`.
# (Iterating a dict yields its keys.)
DateFeatures = Literal[tuple(date_features_dtypes)]

Transforms = Literal[tuple(_available_tfms)]


class DataFreq(str, Enum):
    """Pandas frequencies accepted in the configuration.

    Values are pandas offset-alias strings. Anchored variants are listed for
    every anchor, including the defaults ('W-SUN', 'Q-DEC', 'A-DEC'), which
    pandas treats as equivalent to 'W', 'Q' and 'A' respectively.
    """

    B = 'B'
    C = 'C'
    D = 'D'
    W = 'W'
    M = 'M'
    SM = 'SM'
    BM = 'BM'
    CBM = 'CBM'
    MS = 'MS'
    SMS = 'SMS'
    BMS = 'BMS'
    CBMS = 'CBMS'
    Q = 'Q'
    BQ = 'BQ'
    QS = 'QS'
    BQS = 'BQS'
    A = 'A'
    Y = 'Y'
    BA = 'BA'
    BY = 'BY'
    AS = 'AS'
    YS = 'YS'
    BAS = 'BAS'
    BYS = 'BYS'
    BH = 'BH'
    H = 'H'
    T = 'T'
    S = 'S'
    L = 'L'
    U = 'U'
    N = 'N'
    # Weekly, anchored on a weekday ('W-SUN' is the same as 'W').
    W_SUN = 'W-SUN'
    W_MON = 'W-MON'
    W_TUE = 'W-TUE'
    W_WED = 'W-WED'
    W_THU = 'W-THU'
    W_FRI = 'W-FRI'
    W_SAT = 'W-SAT'
    # Quarter end, anchored on the fiscal year-end month ('Q-DEC' is the same as 'Q').
    Q_JAN = 'Q-JAN'
    Q_FEB = 'Q-FEB'
    Q_MAR = 'Q-MAR'
    Q_APR = 'Q-APR'
    Q_MAY = 'Q-MAY'
    Q_JUN = 'Q-JUN'
    Q_JUL = 'Q-JUL'
    Q_AUG = 'Q-AUG'
    Q_SEP = 'Q-SEP'
    Q_OCT = 'Q-OCT'
    Q_NOV = 'Q-NOV'
    Q_DEC = 'Q-DEC'
    # Year end, anchored on the year-end month ('A-DEC' is the same as 'A').
    A_JAN = 'A-JAN'
    A_FEB = 'A-FEB'
    A_MAR = 'A-MAR'
    A_APR = 'A-APR'
    A_MAY = 'A-MAY'
    A_JUN = 'A-JUN'
    A_JUL = 'A-JUL'
    A_AUG = 'A-AUG'
    A_SEP = 'A-SEP'
    A_OCT = 'A-OCT'
    A_NOV = 'A-NOV'
    A_DEC = 'A-DEC'

    
class DataFormat(str, Enum):
    """Allowed data formats (the values accepted for `DataConfig.format`)."""

    # Subclassing `str` makes each member compare equal to its raw string value.
    csv = 'csv'
    parquet = 'parquet'

    
class DataConfig(BaseModel):
    """Data configuration.

    All fields are required; `format` must be one of the `DataFormat` values.
    """

    prefix: str  # NOTE(review): presumably a common location prefix for input/output — confirm with the consumer
    input: str  # presumably the input data location — TODO confirm
    output: str  # presumably where results are written — TODO confirm
    format: DataFormat  # 'csv' or 'parquet'

        
class BacktestConfig(BaseModel):
    """Backtest configuration.

    Both fields are required integers; no range (e.g. positivity) check is
    enforced here.
    """

    n_windows: int  # presumably the number of backtest windows — TODO confirm
    window_size: int  # presumably the length of each window, in periods of `freq` — TODO confirm

        
class FeaturesConfig(BaseModel):
    """Features configuration.

    Only `freq` is required. Each `Optional[...]` field has no explicit default,
    so under pydantic v1 it may be omitted and is then set to None.
    """

    freq: DataFreq  # must be one of the `DataFreq` values
    lags: Optional[List[int]]  # presumably lag offsets to build features from — TODO confirm
    # Maps an int (presumably a lag) to a list whose items are either a
    # transform name (see `Transforms`) or a mapping {transform name: dict};
    # the inner dict presumably holds that transform's extra arguments —
    # confirm against the consumer.
    lag_transforms: Optional[Dict[int, List[Union[Transforms, Dict[Transforms, Dict]]]]]
    date_features: Optional[List[DateFeatures]]  # each name must be a key of `date_features_dtypes`
    static_features: Optional[List[str]]  # presumably column names of static features — TODO confirm
    num_threads: Optional[int]

        
class ForecastConfig(BaseModel):
    """Forecast configuration."""

    horizon: int  # required; presumably the number of periods to forecast — TODO confirm

        
class ModelConfig(BaseModel):
    """Model configuration.

    `name` must be the fully qualified class path including its modules,
    e.g. ``sklearn.ensemble.RandomForestRegressor``."""

    name: str
    params: Optional[Dict]  # may be omitted (None); presumably keyword arguments for the model — TODO confirm

        
class LocalConfig(BaseModel):
    """Configuration for local pipeline.

    Only holds the (required) model configuration.
    """

    model: ModelConfig
     
    
class ClusterConfig(BaseModel):
    """Cluster configuration.

    `class_name` must be the fully qualified class path including its modules,
    e.g. ``dask.distributed.LocalCluster``."""

    class_name: str
    class_kwargs: Dict  # required; presumably keyword arguments for `class_name` — TODO confirm
      
    
class DistributedModelName(str, Enum):
    """Available models for distributed training.

    Members are looked up by value, so configs must use the value
    (e.g. 'LGBMForecast'), not the member name (e.g. 'LightGBM').
    """

    XGBoost = 'XGBForecast'
    LightGBM = 'LGBMForecast'


class DistributedModelConfig(BaseModel):
    """Configuration for distributed models."""

    name: DistributedModelName  # restricted to the `DistributedModelName` values
    params: Optional[Dict]  # may be omitted (None); presumably model keyword arguments — TODO confirm
      
    
class DistributedConfig(BaseModel):
    """Configuration for distributed training.

    Both the model and the cluster configuration are required.
    """

    model: DistributedModelConfig
    cluster: ClusterConfig

        
class FlowConfig(BaseModel):
    """Flow configuration.

    `data` and `features` are required, and `backtest` and `forecast` are
    optional. Exactly one of `local` or `distributed` must be provided.
    """

    data: DataConfig
    features: FeaturesConfig
    backtest: Optional[BacktestConfig]
    forecast: Optional[ForecastConfig]
    local: Optional[LocalConfig]
    distributed: Optional[DistributedConfig]

    # skip_on_failure=True: when a field fails its own validation it is missing
    # from `values`. Running this check anyway would add a misleading
    # "Must specify either local or distributed." error on top of the real one
    # (e.g. for an unknown distributed model name).
    @root_validator(skip_on_failure=True)
    def check_local_or_distributed(cls, values):
        """Ensure exactly one of `local` / `distributed` is set.

        Raises:
            ValueError: if both are specified, or if neither is.
        """
        has_local = values.get('local') is not None
        has_distributed = values.get('distributed') is not None
        if has_local and has_distributed:
            raise ValueError('Must specify either local or distributed, not both.')
        if not has_local and not has_distributed:
            raise ValueError('Must specify either local or distributed.')
        return values


In [None]:
# Parse the local sample config and validate it against FlowConfig.
with open('../sample_configs/local.yaml', 'r') as f:
    local_raw = yaml.safe_load(f)
config = FlowConfig(**local_raw)
pprint(config.dict())

In [None]:
# Parse the distributed sample config and validate it against FlowConfig.
with open('../sample_configs/distributed.yaml', 'r') as f:
    distributed_raw = yaml.safe_load(f)
config = FlowConfig(**distributed_raw)
pprint(config.dict())

In [None]:
with open('../sample_configs/distributed.yaml', 'r') as f:
    cfg = yaml.safe_load(f)

def test_specifying_both_local_and_distributed_fails():
    """Validating a config with both `local` and `distributed` must fail."""
    # Build a new dict rather than mutating `cfg`, so the test is idempotent
    # and doesn't leak a `local` section into the notebook state.
    both_cfg = {**cfg, 'local': {'model': {'name': 'lgb.LGBMRegressor'}}}
    FlowConfig(**both_cfg)

test_fail(test_specifying_both_local_and_distributed_fails, contains='not both')

In [None]:
with open('../sample_configs/distributed.yaml', 'r') as f:
    cfg = yaml.safe_load(f)

def test_not_specifying_local_nor_distributed_fails():
    """Validating a config with neither `local` nor `distributed` must fail."""
    # Drop `distributed` from a copy instead of `del cfg[...]`: mutating the
    # shared dict makes a second call fail with KeyError, not the validation error.
    no_backend_cfg = {k: v for k, v in cfg.items() if k != 'distributed'}
    FlowConfig(**no_backend_cfg)

# Match the full message (with its trailing period): the "not both" error also
# contains 'either local or distributed', so the shorter substring can't tell
# the two failures apart.
test_fail(test_not_specifying_local_nor_distributed_fails, contains='Must specify either local or distributed.')