In [1]:
from itertools import islice
%matplotlib inline
from matplotlib import pyplot as plt
import matplotlib.dates as mdates

In [2]:
from gluonts.dataset.repository.datasets import get_dataset
from gluonts.dataset.common import ListDataset
from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.torch.distributions import NegativeBinomialOutput
from gluonts.torch.modules.loss import NegativeLogLikelihood

from TSMixer import TSMixerEstimator

In [3]:
dataset = get_dataset("solar-energy", regenerate=False)

In [4]:
estimator = TSMixerEstimator(
    prediction_length=dataset.metadata.prediction_length,
    context_length=dataset.metadata.prediction_length*4,
    freq=dataset.metadata.freq,
    scaling="std",

    # hidden_size=64,
    n_blocks=5,
    
    batch_size=128,
    num_batches_per_epoch=100,
    trainer_kwargs=dict(accelerator="cpu", max_epochs=100)
)

In [None]:
predictor = estimator.train(
    training_data=dataset.train, 
    cache_data=True, 
    shuffle_buffer_length=1024, 
    validation_data=dataset.test
)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(

  | Name  | Type         | Params
---------------------------------------
0 | model | TSMixerModel | 49.5 K
---------------------------------------
49.5 K    Trainable params
0         Non-trainable params
49.5 K    Total params
0.198     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 0, global step 100: 'val_loss' reached 3.95022 (best 3.95022), saving model to 'C:\\Users\\maerzale\\OneDrive - adidas\\ConvTS Mixer\\lightning_logs\\version_55\\checkpoints\\epoch=0-step=100.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 1, global step 200: 'val_loss' reached 3.78138 (best 3.78138), saving model to 'C:\\Users\\maerzale\\OneDrive - adidas\\ConvTS Mixer\\lightning_logs\\version_55\\checkpoints\\epoch=1-step=200.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 2, global step 300: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 3, global step 400: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 4, global step 500: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 5, global step 600: 'val_loss' reached 3.75541 (best 3.75541), saving model to 'C:\\Users\\maerzale\\OneDrive - adidas\\ConvTS Mixer\\lightning_logs\\version_55\\checkpoints\\epoch=5-step=600.ckpt' as top 1


In [None]:
forecast_it, ts_it = make_evaluation_predictions(
    dataset=dataset.test,
    predictor=predictor,
)

In [None]:
forecasts = list(forecast_it)

In [None]:
tss = list(ts_it)

In [None]:
evaluator = Evaluator()

In [None]:
agg_metrics, ts_metrics = evaluator(iter(tss), iter(forecasts))

In [None]:
agg_metrics

In [None]:
plt.figure(figsize=(20, 15))
date_formater = mdates.DateFormatter('%b, %d')
plt.rcParams.update({'font.size': 15})

for idx, (forecast, ts) in islice(enumerate(zip(forecasts, tss)), 9):
    ax = plt.subplot(3, 3, idx+1)

    plt.plot(ts[-4 * dataset.metadata.prediction_length:].to_timestamp(), label="target", )
    forecast.plot( color='g')
    plt.xticks(rotation=60)
    ax.xaxis.set_major_formatter(date_formater)
    ax.set_title(forecast.item_id)

plt.gcf().tight_layout()
plt.legend()
plt.show()