In [None]:
import pandas as pd
from dsipts import TimeSeries, RNN,Monash,get_freq


In [None]:
# Helper for the Monash forecasting repository hosted at baseUrl.
# NOTE(review): rebuild=True presumably forces the dataset index to be fetched
# again instead of reusing a saved 'monash' file — confirm in dsipts.Monash.
m = Monash(
    filename='monash',
    baseUrl='https://forecastingdata.org/',
    rebuild=True,
)

In [None]:
# Inspect the helper's `downloaded` attribute (presumably the datasets fetched so far).
m.downloaded


In [None]:
# NOTE(review): dead cell — it only holds a commented-out copy of the download
# call that the next cell actually runs; safe to delete.
#m.download_dataset('data',4656144)

In [None]:
# Download Monash dataset 4656144 ('data' is presumably the target folder —
# confirm) and persist the helper state.
m.download_dataset('data',4656144)
# NOTE(review): saved as 'monarch', while the constructor above uses
# filename='monash'; if both are meant to refer to the same cache file, one
# of the two names is a typo — confirm before relying on the saved state.
m.save('monarch')

In [None]:
# Build dataset 4656144: the series frame plus the descriptive fields returned with it.
(
    loaded_data,
    frequency,
    forecast_horizon,
    contain_missing_values,
    contain_equal_length,
) = m.generate_dataset(4656144)


In [None]:
# Preview only the first rows: every row carries an entire series in
# `series_value`, so displaying the whole frame is noisy and slow.
loaded_data.head()

In [None]:
# Turn the first series of the dataset into a time-stamped frame for dsipts.
serie = pd.DataFrame({'signal': loaded_data.series_value.iloc[0]})
serie = serie.assign(
    time=pd.date_range(
        start=loaded_data.start_timestamp.iloc[0],
        periods=len(serie),
        freq=get_freq(frequency),
    )
)
# minute + hour is 0 only when both are 0 (both are non-negative), so
# cum == 0 marks the 00:00 timestamps of each day.
serie = serie.assign(cum=serie.time.dt.minute + serie.time.dt.hour)
# Possible sample filter: keep only samples whose first future lag has cum == 0.
# NOTE(review): not used at the moment — train_model is called with starting_point=None.
starting_point = {'cum': 0}

ts = TimeSeries('4656144')
# Only the first 8000 rows are loaded; enrich_cat presumably derives
# day-of-week and hour categorical columns from `time` — confirm in dsipts.
ts.load_signal(
    serie.iloc[0:8000],
    enrich_cat=['dow', 'hour'],
    target_variables=['signal'],
)

In [None]:
# Visual check of the data loaded into `ts`; the trailing ';' hides the return-value repr.
ts.plot();

In [None]:
# Forecasting window: 100 past steps are used to predict the next 20.
past_steps = 100
future_steps = 20

config = {
    # Hyper-parameters forwarded to the dsipts RNN constructor.
    'model_configs': {
        'cat_emb_dim': 16,
        'kind': 'gru',
        'hidden_RNN': 12,
        'num_layers_RNN': 2,
        'sum_emb': True,
        'kernel_size': 15,
        'past_steps': past_steps,
        'future_steps': future_steps,
        # Channel counts and embedding cardinalities come from the loaded series.
        'past_channels': len(ts.num_var),
        'future_channels': len(ts.future_variables),
        'embs': [ts.dataset[col].nunique() for col in ts.cat_var],
        'quantiles': [0.1, 0.5, 0.9],
        'dropout_rate': 0.5,
        'persistence_weight': 0.010,
        'loss_type': 'l1',
        'remove_last': True,
        'use_bn': False,
        'optim': 'torch.optim.Adam',
        'activation': 'torch.nn.PReLU',
        'out_channels': len(ts.target_variables),
    },
    # NOTE(review): step_size=100 exceeds max_epochs=40 passed to train_model;
    # if the scheduler steps once per epoch the LR is never decayed — confirm.
    'scheduler_config': {'gamma': 0.1, 'step_size': 100},
    'optim_config': {'lr': 0.0005, 'weight_decay': 0.01},
}

model_sum = RNN(
    **config['model_configs'],
    optim_config=config['optim_config'],
    scheduler_config=config['scheduler_config'],
)
ts.set_model(model_sum, config=config)

In [None]:
# NOTE(review): repeats the ts.set_model(...) call that already ends the
# previous cell; it looks redundant and the cell can likely be removed.
ts.set_model(model_sum,config=config )

In [None]:
import os

# Checkpoint directory relative to the notebook instead of an absolute,
# user-specific path. The previous hard-coded directory
# ('/home/agobbi/Projects/TT/tmp/4656719v2') also named a different dataset
# (4656719) than the series trained here (4656144).
checkpoint_dir = os.path.join('checkpoints', '4656144')
os.makedirs(checkpoint_dir, exist_ok=True)

ts.train_model(
    dirpath=checkpoint_dir,
    split_params=dict(
        perc_train=0.6,   # fraction of samples for training
        perc_valid=0.2,   # fraction for validation (remainder presumably test — confirm)
        past_steps=past_steps,
        future_steps=future_steps,
        range_train=None,
        range_validation=None,
        range_test=None,
        shift=0,
        starting_point=None,
        skip_step=1,
    ),
    batch_size=100,
    num_workers=4,
    max_epochs=40,
    auto_lr_find=True,
    devices='auto',
)

In [None]:
# Loss history recorded by train_model; ';' suppresses the Axes repr.
ts.losses.plot(title='Loss history');

In [None]:
# NOTE(review): clears the TimeSeries `modifier` attribute before inference;
# its role is not visible in this notebook — confirm why it must be None here.
ts.modifier=None

In [None]:
# Predictions of the trained model collected in a DataFrame (`res`); which split
# is used is inference_on_set's default (presumably test — confirm).
res = ts.inference_on_set(num_workers=4, batch_size=100)

In [None]:
import numpy as np

# Root-mean-squared error of the median forecast, restricted to lag 2.
np.sqrt(
    res.loc[res.lag == 2].pipe(
        lambda rows: ((rows['signal_median'] - rows['signal']) ** 2).mean()
    )
)

In [None]:
%matplotlib qt
res[res.lag==2].drop(columns='time').plot()

In [None]:
from datetime import timedelta
from pandas.tseries.frequencies import to_offset

# Issue time of each forecast: target time minus `lag` steps of the series' own
# sampling frequency. The previous version used timedelta(minutes=60 * lag),
# i.e. it assumed hourly steps whatever the dataset frequency was.
# NOTE: requires a fixed-length frequency (seconds/minutes/hours/days);
# calendar frequencies such as months have no single Timedelta length.
step = pd.Timedelta(to_offset(get_freq(frequency)))
# Vectorized instead of a row-wise DataFrame.apply.
res['prediction_time'] = res['time'] - res['lag'] * step

In [None]:
import matplotlib.pyplot as plt

# One forecast, issued at `date`: real values vs median prediction across lags.
date = '2006-02-15 02:20:01'

mask = res.prediction_time == date
# Fail loudly instead of drawing an empty figure when `res` has no forecast
# issued at `date`.
assert mask.any(), f'no forecast issued at {date}'
issued = res.loc[mask].sort_values('lag')

fig, ax = plt.subplots()
ax.plot(issued.lag, issued.signal, label='real')
ax.plot(issued.lag, issued.signal_median, label='median')
ax.set(xlabel='lag', ylabel='signal', title=f'Forecast issued at {date}')
ax.legend();

In [None]:
# Round-trip check: save the TimeSeries under the name 'tmp', load it back with
# the RNN model class and rerun inference (this reassigns `res`).
# NOTE(review): load_last=False presumably picks a checkpoint other than the
# last one (e.g. the best) — confirm against dsipts TimeSeries.load.
ts.save('tmp')
ts.load(RNN,'tmp',load_last=False)
res = ts.inference_on_set(batch_size = 100,num_workers = 4)

In [None]:
import matplotlib.pyplot as plt

# Real signal vs median forecast over time. The median prediction is read from
# 'signal_median', the column name used by every other cell of this notebook;
# the previous res['median'] lookup does not match that naming (KeyError).
# Sorting into a new frame leaves `res` itself unchanged.
# NOTE: rows of different lags can share a timestamp, so several forecasts may
# be drawn for the same time.
res_by_time = res.sort_values(by='time')
fig, ax = plt.subplots()
ax.plot(res_by_time.time, res_by_time['signal'], label='real')
ax.plot(res_by_time.time, res_by_time['signal_median'], label='median')
ax.set(xlabel='time', ylabel='signal')
ax.legend();

In [None]:
# Absolute error of the median forecast, averaged per lag to show how the error
# evolves with the forecast horizon. Series.abs() avoids depending on the
# numpy import made in an earlier cell.
res['error'] = (res['signal'] - res['signal_median']).abs()
ax = res.groupby('lag').error.mean().plot(title='MAE of median forecast by lag')
ax.set_ylabel('mean absolute error');