In [2]:
from darts import TimeSeries
import pandas as pd
from darts.models import NaiveSeasonal
from darts.models import NaiveMean
from darts import TimeSeries
from sklearn.metrics import mean_absolute_percentage_error
from typing import Dict
from darts.models import (StatsForecastAutoARIMA, StatsForecastAutoETS, 
                          StatsForecastAutoTheta, StatsForecastAutoCES,
                          FourTheta, KalmanForecaster, CatBoostModel, Croston
                         )

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the preprocessed dataset from parquet.
df = pd.read_parquet("../../data/processed/dataset.parquet")

# Build a multi-column TimeSeries indexed by 'ds'. Value columns are every
# column except the time column itself and any column whose name contains
# 'feat'.
series = TimeSeries.from_dataframe(
    df,
    time_col='ds',
    value_cols=[
        name for name in df.columns
        if name != 'ds' and 'feat' not in name
    ],
)

In [4]:
# Show the loaded frame as the cell output (the rendered table below elides the middle rows).
df

Unnamed: 0,Денситометр,КТ,КТ с КУ 1 зона,КТ с КУ 2 и более зон,ММГ,МРТ,МРТ с КУ 1 зона,МРТ с КУ 2 и более зон,РГ,Флюорограф,ds,feat_КТ с КУ 1 зона_lag-4_КТ,feat_КТ_lag-1_КТ с КУ 1 зона,feat_МРТ_lag-6_КТ с КУ 2 и более зон,feat_МРТ с КУ 1 зона_lag-1_МРТ,feat_МРТ_lag-1_МРТ с КУ 1 зона,feat_КТ с КУ 1 зона_lag-1_ММГ
1,17.0,6146,43.0,100.0,483,415,169.0,2.0,12450,392.0,2022-01-03,6146.0,43.0,100.0,415.0,169.0,483.0
2,1026.0,10868,424.0,451.0,9567,2156,669.0,9.0,48904,22626.0,2022-01-10,6146.0,43.0,100.0,415.0,169.0,483.0
3,910.0,12266,430.0,490.0,8791,2162,710.0,14.0,47364,20496.0,2022-01-17,6146.0,424.0,100.0,2156.0,669.0,9567.0
4,679.0,12793,336.0,471.0,7465,2066,667.0,7.0,40234,15227.0,2022-01-24,6146.0,430.0,100.0,2162.0,710.0,8791.0
5,571.0,13235,302.0,446.0,6124,1900,609.0,6.0,36502,12586.0,2022-01-31,6146.0,336.0,100.0,2066.0,667.0,7465.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
104,1294.0,3313,491.0,476.0,14856,1327,715.0,11.0,81751,5979.0,2023-12-25,4473.0,473.0,784.0,1681.0,863.0,16615.0
105,84.0,950,70.0,73.0,1185,544,131.0,1.0,16148,241.0,2024-01-01,4293.0,491.0,800.0,1327.0,715.0,14856.0
106,1427.0,3317,431.0,369.0,13964,1461,593.0,12.0,80644,5394.0,2024-01-08,4764.0,70.0,743.0,544.0,131.0,1185.0
107,1816.0,3939,563.0,518.0,17769,1712,809.0,17.0,98705,6580.0,2024-01-15,4087.0,431.0,782.0,1461.0,593.0,13964.0


In [14]:
def cross_val(
        models: Dict[str, object],
        series: TimeSeries,
        remains_rows_to_val: int = 24,
        n_rows_to_forecast: int = 4,
        step: int = 1,
) -> Dict[str, Dict[str, pd.Series]]:
    """
    Evaluate forecasting models with an expanding-window cross-validation.

    The training window always starts at the first row of ``series``. Its
    length starts at ``len(series) - remains_rows_to_val`` and grows by
    ``step`` until only ``n_rows_to_forecast`` rows are left after it. For
    every window each model is fitted on each column separately, asked to
    predict the next ``n_rows_to_forecast`` rows, and scored against the
    actual values with ``mean_absolute_percentage_error``.

    Args:
        models (Dict[str, object]): Model name -> model instance. Each model
            must provide ``fit(series)`` and ``predict(n)``. The same instance
            is refitted for every column and window, so the objects passed in
            are left in the state of their last fit.
        series (TimeSeries): The (possibly multi-column) series to evaluate on.
        remains_rows_to_val (int, optional): Number of trailing rows held out
            of the first training window. Defaults to 24.
        n_rows_to_forecast (int, optional): Forecast horizon, in rows, scored
            at every window. Defaults to 4.
        step (int, optional): How many rows the training window grows by
            between consecutive iterations. Defaults to 1.

    Returns:
        Dict[str, Dict[str, pd.Series]]: Outer keys are model names. Inner
        keys are window labels of the form
        ``'<n_rows_to_train> + <n_rows_to_forecast>'``. Each value is a
        ``pd.Series`` indexed by column name holding that column's error as
        returned by ``mean_absolute_percentage_error``.

    Raises:
        ValueError: If ``n_rows_to_forecast`` or ``step`` is smaller than 1,
            if ``remains_rows_to_val`` is smaller than ``n_rows_to_forecast``,
            or if ``series`` has no rows left for training once
            ``remains_rows_to_val`` rows are held out.
    """
    # The arguments are used as passed; they must not be overwritten with
    # fixed values here, otherwise callers cannot change the CV layout.
    if n_rows_to_forecast < 1:
        raise ValueError("n_rows_to_forecast must be at least 1")
    if step < 1:
        raise ValueError("step must be at least 1")
    if remains_rows_to_val < n_rows_to_forecast:
        raise ValueError(
            "remains_rows_to_val must be greater than or equal to n_rows_to_forecast"
        )

    n_rows_total = len(series)
    if remains_rows_to_val >= n_rows_total:
        raise ValueError(
            "series must contain more rows than remains_rows_to_val "
            "so that at least one row is left for training"
        )

    # Training sizes to evaluate. The upper bound is inclusive so that the
    # last window forecasts exactly the final n_rows_to_forecast rows.
    train_sizes = range(
        n_rows_total - remains_rows_to_val,
        n_rows_total - n_rows_to_forecast + 1,
        step,
    )

    cross_val_results = {}

    for name, model in models.items():
        model_results = {}
        print(f'model_name now --- {name}')
        for n_rows_to_train in train_sizes:
            print(f"n_rows_to_train: {n_rows_to_train}", f"i_rows_for_forecast: {n_rows_to_train + 1} - {n_rows_to_train + n_rows_to_forecast}")
            train = series[:n_rows_to_train]
            val = series[n_rows_to_train:n_rows_to_train + n_rows_to_forecast]

            iteration_results = {}

            # Each column is modelled on its own: fit on that column's
            # training part, then score the forecast against the held-out rows.
            for column in train.columns:
                model.fit(train[column])
                preds = model.predict(n_rows_to_forecast)
                iteration_results[column] = mean_absolute_percentage_error(val[column].values(), preds.values())

            model_results[f'{n_rows_to_train} + {n_rows_to_forecast}'] = pd.Series(iteration_results)

        cross_val_results[name] = model_results
    return cross_val_results


def represent_cross_validation_results(
    cross_val_results: Dict[str, Dict[str, pd.Series]]
) -> None:
    """
    Display one captioned table of cross-validation scores per model.

    For every model the per-step Series are assembled into a DataFrame with
    one row per validation step. The rows produced by ``DataFrame.describe()``
    for that table are appended underneath, and the result is shown with
    ``display`` using the model name as the caption.

    Args:
        cross_val_results (Dict[str, Dict[str, pd.Series]]): Model name ->
            validation step label -> Series of scores.

    Returns:
        None
    """
    for model_name, steps in cross_val_results.items():
        # Building the frame from the mapping puts step labels on the columns;
        # transposing turns each validation step into a row.
        per_step_scores = pd.DataFrame(
            {label: steps[label] for label in steps}
        ).transpose()
        summary_rows = per_step_scores.describe()
        combined = pd.concat([per_step_scores, summary_rows])

        display(combined.style.set_caption(f'{model_name}'))

    return None




# Forecasters to compare, keyed by the name used for log lines and table captions.
# Other imported candidates are currently switched off: StatsForecastAutoCES(),
# FourTheta(), KalmanForecaster(dim_x=12), CatBoostModel(lags=26), Croston().
models = dict(
    StatsForecastAutoARIMA=StatsForecastAutoARIMA(),
    StatsForecastAutoETS=StatsForecastAutoETS(),
    StatsForecastAutoTheta=StatsForecastAutoTheta(),
)

# Run the cross-validation with its default settings and render one table per model.
ress = cross_val(models, series)
represent_cross_validation_results(ress)

model_name now --- StatsForecastAutoARIMA
n_rows_to_train: 84 i_rows_for_forecast: 85 - 88
n_rows_to_train: 85 i_rows_for_forecast: 86 - 89
n_rows_to_train: 86 i_rows_for_forecast: 87 - 90
n_rows_to_train: 87 i_rows_for_forecast: 88 - 91
n_rows_to_train: 88 i_rows_for_forecast: 89 - 92
n_rows_to_train: 89 i_rows_for_forecast: 90 - 93
n_rows_to_train: 90 i_rows_for_forecast: 91 - 94
n_rows_to_train: 91 i_rows_for_forecast: 92 - 95
n_rows_to_train: 92 i_rows_for_forecast: 93 - 96
n_rows_to_train: 93 i_rows_for_forecast: 94 - 97
n_rows_to_train: 94 i_rows_for_forecast: 95 - 98
n_rows_to_train: 95 i_rows_for_forecast: 96 - 99
n_rows_to_train: 96 i_rows_for_forecast: 97 - 100
n_rows_to_train: 97 i_rows_for_forecast: 98 - 101
n_rows_to_train: 98 i_rows_for_forecast: 99 - 102
n_rows_to_train: 99 i_rows_for_forecast: 100 - 103
n_rows_to_train: 100 i_rows_for_forecast: 101 - 104
n_rows_to_train: 101 i_rows_for_forecast: 102 - 105
n_rows_to_train: 102 i_rows_for_forecast: 103 - 106
n_rows_to_tra

Unnamed: 0,Денситометр,КТ,КТ с КУ 1 зона,КТ с КУ 2 и более зон,ММГ,МРТ,МРТ с КУ 1 зона,МРТ с КУ 2 и более зон,РГ,Флюорограф
84 + 4,0.034154,0.086074,0.14088,0.089603,0.068748,0.066783,0.105382,0.340654,0.085008,0.110848
85 + 4,0.003909,0.145848,0.0257,0.04921,0.067431,0.060085,0.075084,0.260956,0.113506,0.08486
86 + 4,0.07226,0.217624,0.026818,0.092264,0.086998,0.027104,0.086145,0.337212,0.093454,0.100039
87 + 4,0.095485,0.149083,0.046927,0.109248,0.094461,0.022913,0.058745,0.430702,0.117479,0.090482
88 + 4,0.116423,0.128459,0.087617,0.107919,0.119208,0.069617,0.039175,0.332984,0.139143,0.093391
89 + 4,0.150554,0.124053,0.158415,0.152863,0.136917,0.125035,0.06802,0.236534,0.143563,0.097645
90 + 4,0.217212,0.044946,0.198318,0.144909,0.202329,0.193717,0.069722,0.262811,0.187417,0.140646
91 + 4,0.14833,0.075917,0.1937,0.149228,0.125634,0.242939,0.105079,0.267762,0.15977,0.075037
92 + 4,0.142192,0.05495,0.144901,0.104632,0.106046,0.183791,0.09223,0.29144,0.104718,0.063735
93 + 4,0.124059,0.013676,0.064574,0.064125,0.064166,0.137,0.060874,0.292838,0.065165,0.09031


Unnamed: 0,Денситометр,КТ,КТ с КУ 1 зона,КТ с КУ 2 и более зон,ММГ,МРТ,МРТ с КУ 1 зона,МРТ с КУ 2 и более зон,РГ,Флюорограф
84 + 4,0.034542,0.222306,0.255339,0.145534,0.073641,0.132228,0.13142,0.342463,0.079776,0.078121
85 + 4,0.005507,0.090411,0.052557,0.056362,0.070553,0.124494,0.100635,0.262707,0.110476,0.049515
86 + 4,0.073618,0.09734,0.031866,0.047513,0.089158,0.088756,0.102609,0.259597,0.093028,0.100994
87 + 4,0.096184,0.147709,0.039681,0.081442,0.092639,0.040122,0.079313,0.306521,0.118022,0.097031
88 + 4,0.116591,0.172103,0.084535,0.105845,0.117586,0.071797,0.070936,0.337754,0.140395,0.103361
89 + 4,0.150277,0.190378,0.150723,0.156531,0.136177,0.108331,0.072515,0.23885,0.146349,0.103802
90 + 4,0.196088,0.172451,0.19784,0.160828,0.18771,0.171294,0.072389,0.253305,0.189301,0.08461
91 + 4,0.161284,0.171047,0.20089,0.175257,0.140433,0.228926,0.104582,0.272316,0.165873,0.05257
92 + 4,0.15196,0.154224,0.163315,0.14932,0.11051,0.193966,0.103856,0.292631,0.140182,0.04383
93 + 4,0.138096,0.12033,0.077876,0.06891,0.066233,0.171565,0.084106,0.306355,0.099984,0.083362


Unnamed: 0,Денситометр,КТ,КТ с КУ 1 зона,КТ с КУ 2 и более зон,ММГ,МРТ,МРТ с КУ 1 зона,МРТ с КУ 2 и более зон,РГ,Флюорограф
84 + 4,0.047975,0.094806,0.213587,0.141597,0.062951,0.089918,0.157273,0.322186,0.051188,0.087491
85 + 4,0.016825,0.164228,0.07184,0.047921,0.058226,0.07328,0.056954,0.254703,0.084589,0.053513
86 + 4,0.08659,0.241783,0.049695,0.052185,0.085982,0.03254,0.053458,0.269557,0.084549,0.101878
87 + 4,0.101424,0.190489,0.033156,0.086707,0.092569,0.02121,0.038609,0.332072,0.108943,0.107482
88 + 4,0.11458,0.162871,0.078804,0.100786,0.111791,0.068851,0.04764,0.358466,0.130434,0.111856
89 + 4,0.142944,0.153768,0.143174,0.147995,0.128801,0.120147,0.068177,0.24262,0.13593,0.101812
90 + 4,0.190401,0.074815,0.190075,0.14658,0.183231,0.189926,0.086662,0.256772,0.168098,0.13127
91 + 4,0.148685,0.095317,0.19186,0.156217,0.128631,0.239788,0.130784,0.267945,0.151201,0.053433
92 + 4,0.136502,0.078197,0.15205,0.121432,0.088622,0.179572,0.108382,0.284168,0.128972,0.039826
93 + 4,0.120746,0.032081,0.0728,0.062845,0.060189,0.123672,0.072864,0.296363,0.089357,0.084827
