In [None]:
import os
from functools import partial

import coiled
import dask.dataframe as dd
import pandas as pd
from dask.distributed import Client
from orax_forecast import OraxForecast
from orax_forecast.models import *
from orax_forecast.model_selection import TimeSeriesSplit
from orax_forecast.metrics import root_mean_squared_scaled_error
from orax_forecast.pipeline import TimeSerie, freq2int
from orax_forecast.r_models import AutoARIMA, AutoETS
from tsfeatures import acf_features, stl_features, entropy

from nixtla.data.datasets import TourismInfo

In [None]:
# Process environment for this session. The AWS settings are presumably
# consumed by the S3 filesystem layer when reading/writing under s3_bucket.
os.environ.update({
    'AWS_PROFILE': 'pos',
    'AWS_DEFAULT_REGION': 'us-east-2',
    # NOTE(review): pandas (and thus numpy) was already imported in the
    # previous cell, so this may not limit BLAS/OpenMP threads in this
    # process; Coiled workers presumably get their env from the cluster
    # configuration instead — confirm where this is meant to apply.
    'OMP_NUM_THREADS': '1',
})

# Root S3 prefix holding one sub-folder per Tourism group.
s3_bucket = 's3://nixtla-datasets/tourism'

In [None]:
# Start the remote Dask cluster on Coiled and attach a client to it.
cluster_kwargs = dict(
    configuration='groupo-abraxas/miniforecast_config',
    account='groupo-abraxas',
)
cluster = coiled.Cluster(**cluster_kwargs)
client = Client(cluster)
# Display the client as the cell's output (its notebook repr summarizes the cluster).
client

In [None]:
def flow(serie, freq, models, cv, ts_features,
         date_features, test_periods, loss_fn):
    """Run the per-series forecasting pipeline and return the TimeSerie.

    Wraps one series in a ``TimeSerie`` and then invokes, in a fixed order:
    time-series feature computation, CV/feature/model pruning, date-feature
    construction for train and test, model fitting, test predictions and
    cross-validation losses.

    Parameters
    ----------
    serie : presumably the data for a single series as handed over by
        ``OraxForecast.write_flow_results`` — confirm the exact type.
    freq : frequency value forwarded to ``TimeSerie``.
    models : dict mapping model name -> model instance.
    cv : cross-validation splitter (e.g. ``TimeSeriesSplit``).
    ts_features : list of feature functions passed to ``compute_ts_features``.
    date_features : list of date-feature names added to train and test.
    test_periods : forecast horizon passed to ``create_X_test``.
    loss_fn : loss used by ``compute_cross_validation_losses``.

    Returns
    -------
    TimeSerie
        The same object after all pipeline steps have run.
    """
    ts = TimeSerie(serie, freq, models, cv)

    # (method name, positional args) in the exact original order. The methods
    # are resolved with getattr only when their turn comes. Order is kept
    # as-is; later steps presumably rely on state set by earlier ones.
    steps = (
        ('compute_ts_features', (ts_features,)),
        ('keep_effective_cv', ()),
        ('add_date_features_to_train', (date_features,)),
        ('create_X_test', (test_periods,)),
        ('add_date_features_to_X_test', (date_features,)),
        ('keep_effective_features', ()),
        ('keep_effective_models', ()),
        ('fit_models', ()),
        ('compute_test_predictions', ()),
        ('compute_cross_validation_losses', (loss_fn,)),
    )
    for method_name, args in steps:
        getattr(ts, method_name)(*args)
    return ts

In [None]:
# Fit, evaluate and persist forecasts for every Tourism group, one group at a time.
for group in TourismInfo.groups:
    data = dd.read_parquet(f'{s3_bucket}/{group.name}/data.parquet')
    # Spread the group's rows over a fixed number of partitions so the work
    # is distributed across the cluster's workers.
    data = data.repartition(npartitions=8, force=True)

    test_periods = group.horizon
    # Frequency key taken from the group name's initial (presumably 'Y', 'Q',
    # 'M' for Yearly/Quarterly/Monthly). The same value selects the seasonal
    # period AND is passed to the flow as the TimeSerie frequency. A constant
    # 'Y' would disagree with sub-annual groups (and with int_freq), e.g. when
    # future dates for the test periods are generated.
    freq = group.name[0]
    int_freq = freq2int[freq]
    models = dict(
        ARIMA=AutoARIMA(freq=int_freq, stepwise=False, approximation=False),
        ETS=AutoETS(freq=int_freq),
        SNaive=SeasonalNaive(season_length=int_freq),
        # Average over four seasonal cycles.
        WAvg=WindowAverage(window_size=int_freq*4),
    )
    params = dict(
        freq=freq,
        cv=TimeSeriesSplit(n_splits=2, valid_size=test_periods),
        test_periods=test_periods,
        models=models,
        loss_fn=root_mean_squared_scaled_error,
        ts_features=[acf_features, stl_features, entropy],
        date_features=['year', 'month', 'quarter'],
    )
    # Bind everything except the series itself, so the flow can be applied per series.
    partial_flow = partial(flow, **params)

    results_path = f'{s3_bucket}/{group.name}/of_results'
    of = OraxForecast()
    of.write_flow_results(data, partial_flow, results_path)

In [None]:
# Tear down in reverse order of creation: disconnect the Dask client first so
# it does not attempt to reconnect to a scheduler that is going away, then
# shut down the Coiled cluster.
client.close()
cluster.close()