In [7]:
import pandas as pd
import numpy as np
from darts import TimeSeries
from darts.models import *
from darts.metrics import *
from darts.dataprocessing.transformers import Scaler
from darts.models.forecasting.catboost_model import CatBoostModel
import matplotlib.pyplot as plt

In [8]:
# Read the feature table and leave out the four `is_swan_event_*` columns.
df = pd.read_csv(r'all_features_with_swan.csv').drop(
    columns=[
        'is_swan_event_covid',
        'is_swan_event_default',
        'is_swan_event_oil',
        'is_swan_event_crimea',
    ]
)
df.head()

Unnamed: 0,Отчетная дата,Всего,добыча полезных ископаемых,добыча топливно-энергетических полезных ископаемых,обрабатывающие производства,производство пищевых продуктов,обработка древесины и производство изделий из дерева,целлюлозно-бумажное производство; издательская и полиграфическая деятельность,"производство кокса, нефтепродуктов и ядерных материалов",химическое производство,...,bucks_rate,gdp,gdp_growth,gdp_pps,gdp_usd,gdp_deflator,index_price_mod,index_price_mod_growth,index_price_mod_growth_prev_year,key_rate
0,2009-01-01,346769.0,3084.0,1813.0,69830.0,23171.0,3464.0,1861.0,381.0,2475.0,...,29.3916,38807.2,-7.8,3054.26,1309.17,102.0,113.4,102.4,0.0,13.0
1,2009-02-01,346769.0,3084.0,1813.0,69830.0,23171.0,3464.0,1861.0,381.0,2475.0,...,29.3916,38807.2,-7.8,3054.26,1309.17,102.0,0.0,101.7,104.1,13.0
2,2009-03-01,346769.0,3084.0,1813.0,69830.0,23171.0,3464.0,1861.0,381.0,2475.0,...,29.3916,38807.2,-7.8,3054.26,1309.17,102.0,0.0,101.3,105.4,13.0
3,2009-04-01,346769.0,3084.0,1813.0,69830.0,23171.0,3464.0,1861.0,381.0,2475.0,...,33.9032,38807.2,-7.8,3054.26,1309.17,102.0,0.0,100.7,106.2,13.0
4,2009-05-01,380065.0,4032.0,1871.0,74435.0,24133.0,3711.0,1999.0,672.0,2615.0,...,32.974,38807.2,-7.8,3054.26,1309.17,102.0,0.0,100.6,106.8,12.5


In [11]:
def covariates(data: pd.DataFrame) -> list:
    """Turn every non-date column of ``data`` into a scaled darts ``TimeSeries``.

    Each column is paired with the ``'Отчетная дата'`` column (used as the
    time index), converted to a ``TimeSeries`` and passed through its own
    freshly fitted ``Scaler``.

    Args:
        data: Frame holding an ``'Отчетная дата'`` column plus one column per
            series. The frame is NOT modified, so the cell calling this
            function can be re-run and produce the same result.

    Returns:
        A list with one scaled ``TimeSeries`` per column, in the column order
        of ``data`` (date column excluded).

    Raises:
        ValueError: if ``data`` has no ``'Отчетная дата'`` column.

    Side effects:
        Binds the module-level name ``main_scaler`` to the scaler fitted on
        the ``'Всего'`` column so callers can invert the scaling of that
        series. When ``data`` has no ``'Всего'`` column the global is left
        untouched.
    """
    global main_scaler

    time_col = 'Отчетная дата'
    if time_col not in data.columns:
        raise ValueError(f"column {time_col!r} is required as the time index")

    value_columns = [column for column in data.columns if column != time_col]

    train_data = []
    for column in value_columns:
        # NOTE(review): `fillna_value` is a numeric fill value, so `True` acts
        # as 1.0 here - confirm that filling gaps with 1 is really intended.
        ts = TimeSeries.from_dataframe(data[[time_col, column]], time_col=time_col, fillna_value=True)

        # One scaler per series; the previous in-place int64 cast of
        # `data[column]` was dropped: it happened after `ts` was built (no
        # effect on the result) but truncated fractional columns in the
        # caller's frame, changing the outcome of any re-run.
        scaler = Scaler()
        train_data.append(scaler.fit_transform(ts))

        if column == 'Всего':
            main_scaler = scaler

    return train_data

In [13]:
main_data = covariates(df)

# Hold out the last 24 observations of the first series for validation.
# NOTE(review): the first series is presumably 'Всего' (its scaler is what
# `main_scaler` stores) - confirm against the column order of `df`.
train_data, val_data = main_data[0][:-24], main_data[0][-24:]


def negative_r2(actual_series, pred_series):
    """Return -R^2 so that an error-minimising search maximises R^2."""
    return -r2_score(actual_series, pred_series)


model_DCTB = CatBoostModel(lags=1)
# `gridsearch` keeps the combination with the LOWEST metric value, so a
# higher-is-better score such as R^2 must be negated; passing `r2_score`
# directly would select the worst combination.
# The result is a tuple: (untrained best model, best parameters, best score),
# where the score is -R^2 because of the negation above.
best_parameters = model_DCTB.gridsearch(
                            parameters={'output_chunk_length': list(range(1, 13)),
                                        'lags': [[-i for i in range(1, j)] for j in range(2, 64)]},
                            series=train_data,
                            val_series=val_data,
                            metric=negative_r2,
                            verbose=True,
                            n_jobs=-1)

print(best_parameters)


  0%|          | 0/744 [00:00<?, ?it/s]

In [None]:
# `best_parameters` is the tuple returned by `gridsearch`; element 1 is the
# dict of winning hyper-parameters ('lags', 'output_chunk_length').
best_model = CatBoostModel(**best_parameters[1])

# Fit on the training split only: fitting on the full series would let the
# model see the very 24 points it is evaluated on and inflate the score.
best_model.fit(series=train_data)
pred_data = best_model.predict(series=train_data, n=len(val_data))

# Bring both series back to the original scale. `val_data` itself is left
# scaled so that re-running this cell (or the grid-search cell) stays valid.
pred_data = main_scaler.inverse_transform(pred_data)
val_data_unscaled = main_scaler.inverse_transform(val_data)

print(r2_score(val_data_unscaled, pred_data))

# The first positional argument of `TimeSeries.plot` is `new_plot`, not the
# label, so the legend text has to be passed by keyword; both curves then
# share one figure.
pred_data.plot(label='Predict')
val_data_unscaled.plot(label='Actual')

In [None]:
def advanced_gridsearch(num_lags=4, output_chunk=2, horizon=96, out_dir='graphics'):
    """Fit one CatBoost model per (output chunk, lag window) pair and save its forecast plot.

    For every ``output_chunk_length`` in ``1..output_chunk`` and every lag
    window ``[-1]``, ``[-1, -2]``, ... up to ``num_lags - 1`` lags, a model is
    trained on all series in the global ``main_data``, asked to forecast
    ``horizon`` steps past the end of ``main_data[0]``, and the forecast is
    mapped back to the original scale with the global ``main_scaler``.

    Args:
        num_lags: Exclusive upper bound on the lag-window length.
        output_chunk: Largest ``output_chunk_length`` to try.
        horizon: Number of steps to forecast.
        out_dir: Directory receiving ``0.png``, ``1.png``, ...; it is created
            when missing.

    Returns:
        None. Plots are written to ``out_dir`` and progress is printed.
    """
    from pathlib import Path

    out_path = Path(out_dir)
    # `savefig` does not create missing directories.
    out_path.mkdir(parents=True, exist_ok=True)

    lags_list = [[-i for i in range(1, j)] for j in range(2, num_lags + 1)]
    model_idx = 0
    for chunk in range(1, output_chunk + 1):
        for lag in lags_list:
            tmp_model = CatBoostModel(lags=lag,
                                      output_chunk_length=chunk)
            tmp_model.fit(series=main_data, verbose=False)
            tmp_pred = tmp_model.predict(series=main_data[0], n=horizon)

            tmp_pred = main_scaler.inverse_transform(tmp_pred)

            # Draw on a fresh figure and close it afterwards; otherwise every
            # saved image would also contain all previously plotted forecasts
            # and the open figures would pile up in memory.
            plt.figure()
            tmp_pred.plot(label='Our predict')
            plt.savefig(out_path / f'{model_idx}.png')
            plt.close()

            model_idx += 1
            print(f'Calculate model #{model_idx} done!')

In [None]:
# Small run using the function's default search size, spelled out explicitly.
advanced_gridsearch(num_lags=4, output_chunk=2)

In [None]:
# Full run: lag windows of up to 63 lags, output chunk lengths 1 through 12.
advanced_gridsearch(num_lags=64, output_chunk=12)