In [1]:
from itertools import islice
%matplotlib inline
from matplotlib import pyplot as plt
import matplotlib.dates as mdates

In [2]:
from gluonts.dataset.repository.datasets import get_dataset
from gluonts.dataset.common import ListDataset
from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.torch.distributions import NegativeBinomialOutput
from gluonts.torch.modules.loss import NegativeLogLikelihood

from TSMixer_Auxiliary import TSMixerEstimator

In [3]:
# Load the GluonTS "solar-energy" benchmark, reusing a local copy if one exists
# (regenerate=False).
dataset = get_dataset(dataset_name="solar-energy", regenerate=False)

In [4]:
# Show the first training entry. Taking just the first element of the iterator
# avoids materialising every training series (as list(...)[0] would) only to
# discard all but one.
next(iter(dataset.train))

{'target': array([0., 0., 0., ..., 0., 0., 0.], dtype=float32),
 'start': Period('2006-01-01 00:00', 'H'),
 'feat_static_cat': array([0]),
 'item_id': 0}

In [8]:
# Configure the TSMixer estimator for the solar-energy dataset.
# The context window spans four forecast horizons.
estimator = TSMixerEstimator(
    prediction_length=dataset.metadata.prediction_length,
    context_length=dataset.metadata.prediction_length * 4,
    freq=dataset.metadata.freq,
    scaling="std",
    n_blocks=5,
    # One static categorical feature per entry (the first training entry has
    # feat_static_cat == [0]; presumably the series index — confirm).
    # Its cardinality is read from the dataset metadata rather than computed
    # with len(list(dataset.train)), which loaded every training series into
    # memory just to count them. NOTE(review): for this dataset the metadata
    # value is expected to equal the number of training series — confirm.
    num_feat_static_cat=1,
    cardinality=[int(dataset.metadata.feat_static_cat[0].cardinality)],
    batch_size=128,
    num_batches_per_epoch=100,
    # "auto" picks a GPU when available and falls back to CPU, instead of
    # failing outright on CPU-only machines as accelerator="gpu" does.
    trainer_kwargs=dict(accelerator="auto", max_epochs=20),
)

In [9]:
# Fit the estimator, caching transformed training data and shuffling through a
# 1024-entry buffer.
# NOTE(review): dataset.test is passed as validation data AND is the split that
# is evaluated further down. If validation loss influences training (e.g.
# checkpoint selection), the reported metrics will be optimistic — consider a
# validation split carved from dataset.train. TODO confirm.
# NOTE(review): the saved run of this cell raised a bare AssertionError whose
# origin is not visible in this notebook; TODO investigate before relying on
# the cells below.
predictor = estimator.train(
    training_data=dataset.train,
    validation_data=dataset.test,
    shuffle_buffer_length=1024,
    cache_data=True,
)

AssertionError: 

In [None]:
# Produce forecasts for the test split together with the matching ground-truth
# series; both come back as iterators and are materialised in the next cells.
forecast_it, ts_it = make_evaluation_predictions(
    predictor=predictor,
    dataset=dataset.test,
)

In [None]:
# Materialise the forecasts: they are consumed twice (scoring and plotting).
forecasts = [*forecast_it]

In [None]:
# Materialise the ground-truth series for the same reason as the forecasts.
tss = [*ts_it]

In [None]:
# GluonTS evaluator with its default settings (no arguments overridden).
evaluator = Evaluator()

In [None]:
# Score every forecast against its ground-truth series. The first result holds
# aggregate metrics; the second presumably holds per-series metrics — confirm.
agg_metrics, ts_metrics = evaluator(
    iter(tss),
    iter(forecasts),
)

In [None]:
# Display the aggregate metrics returned by the evaluator above.
agg_metrics

In [None]:
# Plot the first nine (forecast, target) pairs on a 3x3 grid. Each panel shows
# the last 4 * prediction_length target observations with the forecast overlaid.
# The larger font is applied via rc_context so it does not leak into later
# figures through the global rcParams (the previous rcParams.update did).
date_formatter = mdates.DateFormatter('%b, %d')
n_history = 4 * dataset.metadata.prediction_length

with plt.rc_context({'font.size': 15}):
    plt.figure(figsize=(20, 15))

    for idx, (forecast, ts) in enumerate(islice(zip(forecasts, tss), 9)):
        ax = plt.subplot(3, 3, idx + 1)

        plt.plot(ts[-n_history:].to_timestamp(), label="target")
        # forecast.plot draws onto the current axes (the subplot just created).
        forecast.plot(color='g')
        plt.xticks(rotation=60)
        ax.xaxis.set_major_formatter(date_formatter)
        ax.set_title(forecast.item_id)

    plt.gcf().tight_layout()
    # Legend goes on the current (last) subplot only.
    plt.legend()
    # Render inside the context so tick labels pick up the scoped font size.
    plt.show()