In [1]:
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import pickle

from tqdm import tqdm
from multiprocessing import Pool, cpu_count

In [2]:
from gluonts.dataset.multivariate_grouper import MultivariateGrouper
from gluonts.dataset.repository.datasets import dataset_recipes, get_dataset
from gluonts.evaluation.backtest import make_evaluation_predictions
from gluonts.evaluation import MultivariateEvaluator
from diffusers import (
    PNDMScheduler,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    KDPM2DiscreteScheduler,
    DEISMultistepScheduler,
)

from pts.model.time_grad import TimeGradEstimator
from pts.dataset.repository.datasets import dataset_recipes

from utils import omit_points, interpolate_np_array

  self.freq: BaseOffset = to_offset(freq)


In [3]:
dataset = get_dataset("solar_nips", regenerate=False)

train_grouper = MultivariateGrouper(
    max_target_dim=int(dataset.metadata.feat_static_cat[0].cardinality)
)

test_grouper = MultivariateGrouper(
    num_test_dates=int(len(dataset.test) / len(dataset.train)),
    max_target_dim=int(dataset.metadata.feat_static_cat[0].cardinality),
)

dataset_train = train_grouper(dataset.train)
dataset_test = test_grouper(dataset.test)

  return pd.Period(val, freq)


In [4]:
evaluator = MultivariateEvaluator(
    quantiles=(np.arange(20) / 20.0)[1:], target_agg_funcs={"sum": np.sum}
)

In [5]:
def plot(
    target,
    forecast,
    prediction_length,
    prediction_intervals=(50.0, 90.0),
    color="g",
    fname=None,
):
    label_prefix = ""
    rows = 4
    cols = 4
    fig, axs = plt.subplots(rows, cols, figsize=(24, 24))
    axx = axs.ravel()
    seq_len, target_dim = target.shape

    ps = [50.0] + [
        50.0 + f * c / 2.0 for c in prediction_intervals for f in [-1.0, +1.0]
    ]

    percentiles_sorted = sorted(set(ps))

    def alpha_for_percentile(p):
        return (p / 100.0) ** 0.3

    for dim in range(0, min(rows * cols, target_dim)):
        ax = axx[dim]

        target[-2 * prediction_length :][dim].plot(ax=ax)

        ps_data = [forecast.quantile(p / 100.0)[:, dim] for p in percentiles_sorted]
        i_p50 = len(percentiles_sorted) // 2

        p50_data = ps_data[i_p50]
        p50_series = pd.Series(data=p50_data, index=forecast.index)
        p50_series.plot(color=color, ls="-", label=f"{label_prefix}median", ax=ax)

        for i in range(len(percentiles_sorted) // 2):
            ptile = percentiles_sorted[i]
            alpha = alpha_for_percentile(ptile)
            ax.fill_between(
                forecast.index,
                ps_data[i],
                ps_data[-i - 1],
                facecolor=color,
                alpha=alpha,
                interpolate=True,
            )
            # Hack to create labels for the error intervals.
            # Doesn't actually plot anything, because we only pass a single data point
            pd.Series(data=p50_data[:1], index=forecast.index[:1]).plot(
                color=color,
                alpha=alpha,
                linewidth=10,
                label=f"{label_prefix}{100 - ptile * 2}%",
                ax=ax,
            )

    legend = ["observations", "median prediction"] + [
        f"{k}% prediction interval" for k in prediction_intervals
    ][::-1]
    axx[0].legend(legend, loc="upper left")

    if fname is not None:
        plt.savefig(fname, bbox_inches="tight", pad_inches=0.05)


In [6]:
scheduler = DEISMultistepScheduler(
    num_train_timesteps=150,
    beta_end=0.1,
)

In [7]:
estimator = TimeGradEstimator(
    input_size=int(dataset.metadata.feat_static_cat[0].cardinality),
    hidden_size=64,
    num_layers=2,
    dropout_rate=0.1,
    lags_seq=[1],
    scheduler=scheduler,
    num_inference_steps=149,
    prediction_length=dataset.metadata.prediction_length,
    context_length=dataset.metadata.prediction_length,
    freq=dataset.metadata.freq,
    scaling="mean",
    trainer_kwargs=dict(max_epochs=200, accelerator="gpu", devices="1")
)

  offset = to_offset(freq_str)


In [8]:
predictor = estimator.train(dataset_train, cache_data=True, shuffle_buffer_length=1024)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/jupyter/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
2024-05-22 13:31:52.321965: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type          | Params | In sizes                                                             | Out sizes        
-----------------------------------------------------------------------------------------------------------------------------------
0 | model | TimeGradModel | 186 K  | [[1, 1], [

Epoch 0: |          | 50/? [00:07<00:00,  6.72it/s, v_num=15, train_loss=0.385]

Epoch 0, global step 50: 'train_loss' reached 0.38476 (best 0.38476), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=0-step=50.ckpt' as top 1


Epoch 1: |          | 50/? [00:06<00:00,  7.97it/s, v_num=15, train_loss=0.155]

Epoch 1, global step 100: 'train_loss' reached 0.15453 (best 0.15453), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=1-step=100.ckpt' as top 1


Epoch 2: |          | 50/? [00:06<00:00,  7.95it/s, v_num=15, train_loss=0.090]

Epoch 2, global step 150: 'train_loss' reached 0.09004 (best 0.09004), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=2-step=150.ckpt' as top 1


Epoch 3: |          | 50/? [00:06<00:00,  7.98it/s, v_num=15, train_loss=0.0767]

Epoch 3, global step 200: 'train_loss' reached 0.07673 (best 0.07673), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=3-step=200.ckpt' as top 1


Epoch 4: |          | 50/? [00:06<00:00,  8.00it/s, v_num=15, train_loss=0.0685]

Epoch 4, global step 250: 'train_loss' reached 0.06849 (best 0.06849), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=4-step=250.ckpt' as top 1


Epoch 5: |          | 50/? [00:06<00:00,  7.67it/s, v_num=15, train_loss=0.0644]

Epoch 5, global step 300: 'train_loss' reached 0.06436 (best 0.06436), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=5-step=300.ckpt' as top 1


Epoch 6: |          | 50/? [00:06<00:00,  7.99it/s, v_num=15, train_loss=0.0611]

Epoch 6, global step 350: 'train_loss' reached 0.06111 (best 0.06111), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=6-step=350.ckpt' as top 1


Epoch 7: |          | 50/? [00:06<00:00,  8.01it/s, v_num=15, train_loss=0.0599]

Epoch 7, global step 400: 'train_loss' reached 0.05993 (best 0.05993), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=7-step=400.ckpt' as top 1


Epoch 8: |          | 50/? [00:06<00:00,  7.99it/s, v_num=15, train_loss=0.0567]

Epoch 8, global step 450: 'train_loss' reached 0.05667 (best 0.05667), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=8-step=450.ckpt' as top 1


Epoch 9: |          | 50/? [00:06<00:00,  8.05it/s, v_num=15, train_loss=0.0561]

Epoch 9, global step 500: 'train_loss' reached 0.05606 (best 0.05606), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=9-step=500.ckpt' as top 1


Epoch 10: |          | 50/? [00:06<00:00,  8.02it/s, v_num=15, train_loss=0.0537]

Epoch 10, global step 550: 'train_loss' reached 0.05374 (best 0.05374), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=10-step=550.ckpt' as top 1


Epoch 11: |          | 50/? [00:06<00:00,  8.08it/s, v_num=15, train_loss=0.0531]

Epoch 11, global step 600: 'train_loss' reached 0.05313 (best 0.05313), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=11-step=600.ckpt' as top 1


Epoch 12: |          | 50/? [00:06<00:00,  8.03it/s, v_num=15, train_loss=0.0528]

Epoch 12, global step 650: 'train_loss' reached 0.05277 (best 0.05277), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=12-step=650.ckpt' as top 1


Epoch 13: |          | 50/? [00:06<00:00,  8.06it/s, v_num=15, train_loss=0.053] 

Epoch 13, global step 700: 'train_loss' was not in top 1


Epoch 14: |          | 50/? [00:06<00:00,  8.07it/s, v_num=15, train_loss=0.0502]

Epoch 14, global step 750: 'train_loss' reached 0.05017 (best 0.05017), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=14-step=750.ckpt' as top 1


Epoch 15: |          | 50/? [00:06<00:00,  8.00it/s, v_num=15, train_loss=0.0499]

Epoch 15, global step 800: 'train_loss' reached 0.04992 (best 0.04992), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=15-step=800.ckpt' as top 1


Epoch 16: |          | 50/? [00:06<00:00,  8.06it/s, v_num=15, train_loss=0.0499]

Epoch 16, global step 850: 'train_loss' reached 0.04991 (best 0.04991), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=16-step=850.ckpt' as top 1


Epoch 17: |          | 50/? [00:06<00:00,  8.07it/s, v_num=15, train_loss=0.051] 

Epoch 17, global step 900: 'train_loss' was not in top 1


Epoch 18: |          | 50/? [00:06<00:00,  8.10it/s, v_num=15, train_loss=0.049]

Epoch 18, global step 950: 'train_loss' reached 0.04899 (best 0.04899), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=18-step=950.ckpt' as top 1


Epoch 19: |          | 50/? [00:06<00:00,  8.13it/s, v_num=15, train_loss=0.0503]

Epoch 19, global step 1000: 'train_loss' was not in top 1


Epoch 20: |          | 50/? [00:06<00:00,  8.05it/s, v_num=15, train_loss=0.049] 

Epoch 20, global step 1050: 'train_loss' reached 0.04899 (best 0.04899), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=20-step=1050.ckpt' as top 1


Epoch 21: |          | 50/? [00:06<00:00,  8.12it/s, v_num=15, train_loss=0.049]

Epoch 21, global step 1100: 'train_loss' reached 0.04899 (best 0.04899), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=21-step=1100.ckpt' as top 1


Epoch 22: |          | 50/? [00:06<00:00,  8.10it/s, v_num=15, train_loss=0.0488]

Epoch 22, global step 1150: 'train_loss' reached 0.04885 (best 0.04885), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=22-step=1150.ckpt' as top 1


Epoch 23: |          | 50/? [00:06<00:00,  8.05it/s, v_num=15, train_loss=0.0474]

Epoch 23, global step 1200: 'train_loss' reached 0.04737 (best 0.04737), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=23-step=1200.ckpt' as top 1


Epoch 24: |          | 50/? [00:06<00:00,  8.01it/s, v_num=15, train_loss=0.0473]

Epoch 24, global step 1250: 'train_loss' reached 0.04731 (best 0.04731), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=24-step=1250.ckpt' as top 1


Epoch 25: |          | 50/? [00:06<00:00,  8.04it/s, v_num=15, train_loss=0.0482]

Epoch 25, global step 1300: 'train_loss' was not in top 1


Epoch 26: |          | 50/? [00:06<00:00,  8.08it/s, v_num=15, train_loss=0.0476]

Epoch 26, global step 1350: 'train_loss' was not in top 1


Epoch 27: |          | 50/? [00:06<00:00,  8.14it/s, v_num=15, train_loss=0.0459]

Epoch 27, global step 1400: 'train_loss' reached 0.04592 (best 0.04592), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=27-step=1400.ckpt' as top 1


Epoch 28: |          | 50/? [00:06<00:00,  8.06it/s, v_num=15, train_loss=0.0476]

Epoch 28, global step 1450: 'train_loss' was not in top 1


Epoch 29: |          | 50/? [00:06<00:00,  7.70it/s, v_num=15, train_loss=0.0474]

Epoch 29, global step 1500: 'train_loss' was not in top 1


Epoch 30: |          | 50/? [00:06<00:00,  8.09it/s, v_num=15, train_loss=0.0468]

Epoch 30, global step 1550: 'train_loss' was not in top 1


Epoch 31: |          | 50/? [00:06<00:00,  8.06it/s, v_num=15, train_loss=0.0462]

Epoch 31, global step 1600: 'train_loss' was not in top 1


Epoch 32: |          | 50/? [00:06<00:00,  8.09it/s, v_num=15, train_loss=0.0477]

Epoch 32, global step 1650: 'train_loss' was not in top 1


Epoch 33: |          | 50/? [00:06<00:00,  8.13it/s, v_num=15, train_loss=0.0454]

Epoch 33, global step 1700: 'train_loss' reached 0.04545 (best 0.04545), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=33-step=1700.ckpt' as top 1


Epoch 34: |          | 50/? [00:06<00:00,  8.08it/s, v_num=15, train_loss=0.0461]

Epoch 34, global step 1750: 'train_loss' was not in top 1


Epoch 35: |          | 50/? [00:06<00:00,  8.12it/s, v_num=15, train_loss=0.0471]

Epoch 35, global step 1800: 'train_loss' was not in top 1


Epoch 36: |          | 50/? [00:06<00:00,  8.16it/s, v_num=15, train_loss=0.0455]

Epoch 36, global step 1850: 'train_loss' was not in top 1


Epoch 37: |          | 50/? [00:06<00:00,  8.11it/s, v_num=15, train_loss=0.0452]

Epoch 37, global step 1900: 'train_loss' reached 0.04518 (best 0.04518), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=37-step=1900.ckpt' as top 1


Epoch 38: |          | 50/? [00:06<00:00,  8.07it/s, v_num=15, train_loss=0.0446]

Epoch 38, global step 1950: 'train_loss' reached 0.04459 (best 0.04459), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=38-step=1950.ckpt' as top 1


Epoch 39: |          | 50/? [00:06<00:00,  8.03it/s, v_num=15, train_loss=0.045] 

Epoch 39, global step 2000: 'train_loss' was not in top 1


Epoch 40: |          | 50/? [00:06<00:00,  8.15it/s, v_num=15, train_loss=0.0458]

Epoch 40, global step 2050: 'train_loss' was not in top 1


Epoch 41: |          | 50/? [00:06<00:00,  8.13it/s, v_num=15, train_loss=0.0451]

Epoch 41, global step 2100: 'train_loss' was not in top 1


Epoch 42: |          | 50/? [00:06<00:00,  8.20it/s, v_num=15, train_loss=0.045] 

Epoch 42, global step 2150: 'train_loss' was not in top 1


Epoch 43: |          | 50/? [00:06<00:00,  8.11it/s, v_num=15, train_loss=0.0457]

Epoch 43, global step 2200: 'train_loss' was not in top 1


Epoch 44: |          | 50/? [00:06<00:00,  8.11it/s, v_num=15, train_loss=0.0451]

Epoch 44, global step 2250: 'train_loss' was not in top 1


Epoch 45: |          | 50/? [00:06<00:00,  8.04it/s, v_num=15, train_loss=0.0444]

Epoch 45, global step 2300: 'train_loss' reached 0.04438 (best 0.04438), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=45-step=2300.ckpt' as top 1


Epoch 46: |          | 50/? [00:06<00:00,  8.09it/s, v_num=15, train_loss=0.0452]

Epoch 46, global step 2350: 'train_loss' was not in top 1


Epoch 47: |          | 50/? [00:06<00:00,  7.99it/s, v_num=15, train_loss=0.0449]

Epoch 47, global step 2400: 'train_loss' was not in top 1


Epoch 48: |          | 50/? [00:06<00:00,  8.06it/s, v_num=15, train_loss=0.0442]

Epoch 48, global step 2450: 'train_loss' reached 0.04422 (best 0.04422), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=48-step=2450.ckpt' as top 1


Epoch 49: |          | 50/? [00:06<00:00,  8.05it/s, v_num=15, train_loss=0.0454]

Epoch 49, global step 2500: 'train_loss' was not in top 1


Epoch 50: |          | 50/? [00:06<00:00,  8.07it/s, v_num=15, train_loss=0.0451]

Epoch 50, global step 2550: 'train_loss' was not in top 1


Epoch 51: |          | 50/? [00:06<00:00,  8.04it/s, v_num=15, train_loss=0.0445]

Epoch 51, global step 2600: 'train_loss' was not in top 1


Epoch 52: |          | 50/? [00:06<00:00,  8.08it/s, v_num=15, train_loss=0.0445]

Epoch 52, global step 2650: 'train_loss' was not in top 1


Epoch 53: |          | 50/? [00:06<00:00,  7.60it/s, v_num=15, train_loss=0.0452]

Epoch 53, global step 2700: 'train_loss' was not in top 1


Epoch 54: |          | 50/? [00:06<00:00,  8.11it/s, v_num=15, train_loss=0.0454]

Epoch 54, global step 2750: 'train_loss' was not in top 1


Epoch 55: |          | 50/? [00:06<00:00,  8.09it/s, v_num=15, train_loss=0.0447]

Epoch 55, global step 2800: 'train_loss' was not in top 1


Epoch 56: |          | 50/? [00:06<00:00,  8.08it/s, v_num=15, train_loss=0.0435]

Epoch 56, global step 2850: 'train_loss' reached 0.04354 (best 0.04354), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=56-step=2850.ckpt' as top 1


Epoch 57: |          | 50/? [00:06<00:00,  8.05it/s, v_num=15, train_loss=0.0446]

Epoch 57, global step 2900: 'train_loss' was not in top 1


Epoch 58: |          | 50/? [00:06<00:00,  8.13it/s, v_num=15, train_loss=0.0441]

Epoch 58, global step 2950: 'train_loss' was not in top 1


Epoch 59: |          | 50/? [00:06<00:00,  8.09it/s, v_num=15, train_loss=0.0448]

Epoch 59, global step 3000: 'train_loss' was not in top 1


Epoch 60: |          | 50/? [00:06<00:00,  8.03it/s, v_num=15, train_loss=0.0444]

Epoch 60, global step 3050: 'train_loss' was not in top 1


Epoch 61: |          | 50/? [00:06<00:00,  8.01it/s, v_num=15, train_loss=0.044] 

Epoch 61, global step 3100: 'train_loss' was not in top 1


Epoch 62: |          | 50/? [00:06<00:00,  8.05it/s, v_num=15, train_loss=0.0448]

Epoch 62, global step 3150: 'train_loss' was not in top 1


Epoch 63: |          | 50/? [00:06<00:00,  8.08it/s, v_num=15, train_loss=0.0436]

Epoch 63, global step 3200: 'train_loss' was not in top 1


Epoch 64: |          | 50/? [00:06<00:00,  8.14it/s, v_num=15, train_loss=0.0448]

Epoch 64, global step 3250: 'train_loss' was not in top 1


Epoch 65: |          | 50/? [00:06<00:00,  8.23it/s, v_num=15, train_loss=0.0441]

Epoch 65, global step 3300: 'train_loss' was not in top 1


Epoch 66: |          | 50/? [00:06<00:00,  8.10it/s, v_num=15, train_loss=0.0436]

Epoch 66, global step 3350: 'train_loss' was not in top 1


Epoch 67: |          | 50/? [00:06<00:00,  8.04it/s, v_num=15, train_loss=0.0441]

Epoch 67, global step 3400: 'train_loss' was not in top 1


Epoch 68: |          | 50/? [00:06<00:00,  8.06it/s, v_num=15, train_loss=0.0431]

Epoch 68, global step 3450: 'train_loss' reached 0.04313 (best 0.04313), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=68-step=3450.ckpt' as top 1


Epoch 69: |          | 50/? [00:06<00:00,  8.12it/s, v_num=15, train_loss=0.0434]

Epoch 69, global step 3500: 'train_loss' was not in top 1


Epoch 70: |          | 50/? [00:06<00:00,  8.10it/s, v_num=15, train_loss=0.044] 

Epoch 70, global step 3550: 'train_loss' was not in top 1


Epoch 71: |          | 50/? [00:06<00:00,  8.06it/s, v_num=15, train_loss=0.0429]

Epoch 71, global step 3600: 'train_loss' reached 0.04285 (best 0.04285), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=71-step=3600.ckpt' as top 1


Epoch 72: |          | 50/? [00:06<00:00,  8.13it/s, v_num=15, train_loss=0.044] 

Epoch 72, global step 3650: 'train_loss' was not in top 1


Epoch 73: |          | 50/? [00:06<00:00,  8.09it/s, v_num=15, train_loss=0.0435]

Epoch 73, global step 3700: 'train_loss' was not in top 1


Epoch 74: |          | 50/? [00:06<00:00,  8.09it/s, v_num=15, train_loss=0.0426]

Epoch 74, global step 3750: 'train_loss' reached 0.04264 (best 0.04264), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=74-step=3750.ckpt' as top 1


Epoch 75: |          | 50/? [00:06<00:00,  8.20it/s, v_num=15, train_loss=0.0426]

Epoch 75, global step 3800: 'train_loss' reached 0.04259 (best 0.04259), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=75-step=3800.ckpt' as top 1


Epoch 76: |          | 50/? [00:06<00:00,  8.03it/s, v_num=15, train_loss=0.0441]

Epoch 76, global step 3850: 'train_loss' was not in top 1


Epoch 77: |          | 50/? [00:06<00:00,  7.61it/s, v_num=15, train_loss=0.0428]

Epoch 77, global step 3900: 'train_loss' was not in top 1


Epoch 78: |          | 50/? [00:06<00:00,  8.07it/s, v_num=15, train_loss=0.0438]

Epoch 78, global step 3950: 'train_loss' was not in top 1


Epoch 79: |          | 50/? [00:06<00:00,  8.05it/s, v_num=15, train_loss=0.0438]

Epoch 79, global step 4000: 'train_loss' was not in top 1


Epoch 80: |          | 50/? [00:06<00:00,  7.99it/s, v_num=15, train_loss=0.0423]

Epoch 80, global step 4050: 'train_loss' reached 0.04230 (best 0.04230), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=80-step=4050.ckpt' as top 1


Epoch 81: |          | 50/? [00:06<00:00,  8.08it/s, v_num=15, train_loss=0.0414]

Epoch 81, global step 4100: 'train_loss' reached 0.04142 (best 0.04142), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=81-step=4100.ckpt' as top 1


Epoch 82: |          | 50/? [00:06<00:00,  8.07it/s, v_num=15, train_loss=0.0434]

Epoch 82, global step 4150: 'train_loss' was not in top 1


Epoch 83: |          | 50/? [00:06<00:00,  8.03it/s, v_num=15, train_loss=0.0426]

Epoch 83, global step 4200: 'train_loss' was not in top 1


Epoch 84: |          | 50/? [00:06<00:00,  8.10it/s, v_num=15, train_loss=0.0433]

Epoch 84, global step 4250: 'train_loss' was not in top 1


Epoch 85: |          | 50/? [00:06<00:00,  8.08it/s, v_num=15, train_loss=0.042] 

Epoch 85, global step 4300: 'train_loss' was not in top 1


Epoch 86: |          | 50/? [00:06<00:00,  8.09it/s, v_num=15, train_loss=0.0425]

Epoch 86, global step 4350: 'train_loss' was not in top 1


Epoch 87: |          | 50/? [00:06<00:00,  8.08it/s, v_num=15, train_loss=0.0428]

Epoch 87, global step 4400: 'train_loss' was not in top 1


Epoch 88: |          | 50/? [00:06<00:00,  8.09it/s, v_num=15, train_loss=0.0423]

Epoch 88, global step 4450: 'train_loss' was not in top 1


Epoch 89: |          | 50/? [00:06<00:00,  8.12it/s, v_num=15, train_loss=0.0418]

Epoch 89, global step 4500: 'train_loss' was not in top 1


Epoch 90: |          | 50/? [00:06<00:00,  8.10it/s, v_num=15, train_loss=0.0425]

Epoch 90, global step 4550: 'train_loss' was not in top 1


Epoch 91: |          | 50/? [00:06<00:00,  8.11it/s, v_num=15, train_loss=0.0425]

Epoch 91, global step 4600: 'train_loss' was not in top 1


Epoch 92: |          | 50/? [00:06<00:00,  8.17it/s, v_num=15, train_loss=0.0416]

Epoch 92, global step 4650: 'train_loss' was not in top 1


Epoch 93: |          | 50/? [00:06<00:00,  8.09it/s, v_num=15, train_loss=0.0411]

Epoch 93, global step 4700: 'train_loss' reached 0.04107 (best 0.04107), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=93-step=4700.ckpt' as top 1


Epoch 94: |          | 50/? [00:06<00:00,  8.24it/s, v_num=15, train_loss=0.0416]

Epoch 94, global step 4750: 'train_loss' was not in top 1


Epoch 95: |          | 50/? [00:06<00:00,  8.14it/s, v_num=15, train_loss=0.0412]

Epoch 95, global step 4800: 'train_loss' was not in top 1


Epoch 96: |          | 50/? [00:06<00:00,  8.10it/s, v_num=15, train_loss=0.0425]

Epoch 96, global step 4850: 'train_loss' was not in top 1


Epoch 97: |          | 50/? [00:06<00:00,  8.10it/s, v_num=15, train_loss=0.0424]

Epoch 97, global step 4900: 'train_loss' was not in top 1


Epoch 98: |          | 50/? [00:06<00:00,  8.15it/s, v_num=15, train_loss=0.0425]

Epoch 98, global step 4950: 'train_loss' was not in top 1


Epoch 99: |          | 50/? [00:06<00:00,  8.16it/s, v_num=15, train_loss=0.0406]

Epoch 99, global step 5000: 'train_loss' reached 0.04058 (best 0.04058), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=99-step=5000.ckpt' as top 1


Epoch 100: |          | 50/? [00:06<00:00,  8.12it/s, v_num=15, train_loss=0.0416]

Epoch 100, global step 5050: 'train_loss' was not in top 1


Epoch 101: |          | 50/? [00:06<00:00,  7.68it/s, v_num=15, train_loss=0.0425]

Epoch 101, global step 5100: 'train_loss' was not in top 1


Epoch 102: |          | 50/? [00:06<00:00,  8.10it/s, v_num=15, train_loss=0.0419]

Epoch 102, global step 5150: 'train_loss' was not in top 1


Epoch 103: |          | 50/? [00:06<00:00,  8.09it/s, v_num=15, train_loss=0.0426]

Epoch 103, global step 5200: 'train_loss' was not in top 1


Epoch 104: |          | 50/? [00:06<00:00,  8.10it/s, v_num=15, train_loss=0.0413]

Epoch 104, global step 5250: 'train_loss' was not in top 1


Epoch 105: |          | 50/? [00:06<00:00,  8.10it/s, v_num=15, train_loss=0.0429]

Epoch 105, global step 5300: 'train_loss' was not in top 1


Epoch 106: |          | 50/? [00:06<00:00,  8.12it/s, v_num=15, train_loss=0.0419]

Epoch 106, global step 5350: 'train_loss' was not in top 1


Epoch 107: |          | 50/? [00:06<00:00,  8.04it/s, v_num=15, train_loss=0.0426]

Epoch 107, global step 5400: 'train_loss' was not in top 1


Epoch 108: |          | 50/? [00:06<00:00,  8.13it/s, v_num=15, train_loss=0.0415]

Epoch 108, global step 5450: 'train_loss' was not in top 1


Epoch 109: |          | 50/? [00:06<00:00,  8.09it/s, v_num=15, train_loss=0.0424]

Epoch 109, global step 5500: 'train_loss' was not in top 1


Epoch 110: |          | 50/? [00:06<00:00,  8.05it/s, v_num=15, train_loss=0.0419]

Epoch 110, global step 5550: 'train_loss' was not in top 1


Epoch 111: |          | 50/? [00:06<00:00,  8.11it/s, v_num=15, train_loss=0.042] 

Epoch 111, global step 5600: 'train_loss' was not in top 1


Epoch 112: |          | 50/? [00:06<00:00,  8.11it/s, v_num=15, train_loss=0.0414]

Epoch 112, global step 5650: 'train_loss' was not in top 1


Epoch 113: |          | 50/? [00:06<00:00,  8.09it/s, v_num=15, train_loss=0.0427]

Epoch 113, global step 5700: 'train_loss' was not in top 1


Epoch 114: |          | 50/? [00:06<00:00,  8.09it/s, v_num=15, train_loss=0.0432]

Epoch 114, global step 5750: 'train_loss' was not in top 1


Epoch 115: |          | 50/? [00:06<00:00,  8.02it/s, v_num=15, train_loss=0.0416]

Epoch 115, global step 5800: 'train_loss' was not in top 1


Epoch 116: |          | 50/? [00:06<00:00,  8.06it/s, v_num=15, train_loss=0.042] 

Epoch 116, global step 5850: 'train_loss' was not in top 1


Epoch 117: |          | 50/? [00:06<00:00,  8.05it/s, v_num=15, train_loss=0.0406]

Epoch 117, global step 5900: 'train_loss' was not in top 1


Epoch 118: |          | 50/? [00:06<00:00,  8.12it/s, v_num=15, train_loss=0.0426]

Epoch 118, global step 5950: 'train_loss' was not in top 1


Epoch 119: |          | 50/? [00:06<00:00,  8.15it/s, v_num=15, train_loss=0.0406]

Epoch 119, global step 6000: 'train_loss' was not in top 1


Epoch 120: |          | 50/? [00:06<00:00,  8.14it/s, v_num=15, train_loss=0.0406]

Epoch 120, global step 6050: 'train_loss' reached 0.04058 (best 0.04058), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=120-step=6050.ckpt' as top 1


Epoch 121: |          | 50/? [00:06<00:00,  8.13it/s, v_num=15, train_loss=0.0423]

Epoch 121, global step 6100: 'train_loss' was not in top 1


Epoch 122: |          | 50/? [00:06<00:00,  8.13it/s, v_num=15, train_loss=0.0415]

Epoch 122, global step 6150: 'train_loss' was not in top 1


Epoch 123: |          | 50/? [00:06<00:00,  8.14it/s, v_num=15, train_loss=0.0428]

Epoch 123, global step 6200: 'train_loss' was not in top 1


Epoch 124: |          | 50/? [00:06<00:00,  8.08it/s, v_num=15, train_loss=0.0424]

Epoch 124, global step 6250: 'train_loss' was not in top 1


Epoch 125: |          | 50/? [00:06<00:00,  8.08it/s, v_num=15, train_loss=0.0414]

Epoch 125, global step 6300: 'train_loss' was not in top 1


Epoch 126: |          | 50/? [00:06<00:00,  7.73it/s, v_num=15, train_loss=0.0417]

Epoch 126, global step 6350: 'train_loss' was not in top 1


Epoch 127: |          | 50/? [00:06<00:00,  8.04it/s, v_num=15, train_loss=0.0418]

Epoch 127, global step 6400: 'train_loss' was not in top 1


Epoch 128: |          | 50/? [00:06<00:00,  8.02it/s, v_num=15, train_loss=0.0422]

Epoch 128, global step 6450: 'train_loss' was not in top 1


Epoch 129: |          | 50/? [00:06<00:00,  8.07it/s, v_num=15, train_loss=0.042] 

Epoch 129, global step 6500: 'train_loss' was not in top 1


Epoch 130: |          | 50/? [00:06<00:00,  8.07it/s, v_num=15, train_loss=0.0413]

Epoch 130, global step 6550: 'train_loss' was not in top 1


Epoch 131: |          | 50/? [00:06<00:00,  8.07it/s, v_num=15, train_loss=0.0414]

Epoch 131, global step 6600: 'train_loss' was not in top 1


Epoch 132: |          | 50/? [00:06<00:00,  8.05it/s, v_num=15, train_loss=0.0426]

Epoch 132, global step 6650: 'train_loss' was not in top 1


Epoch 133: |          | 50/? [00:06<00:00,  8.14it/s, v_num=15, train_loss=0.0411]

Epoch 133, global step 6700: 'train_loss' was not in top 1


Epoch 134: |          | 50/? [00:06<00:00,  8.07it/s, v_num=15, train_loss=0.0416]

Epoch 134, global step 6750: 'train_loss' was not in top 1


Epoch 135: |          | 50/? [00:06<00:00,  8.07it/s, v_num=15, train_loss=0.0398]

Epoch 135, global step 6800: 'train_loss' reached 0.03975 (best 0.03975), saving model to '/home/jupyter/work/resources/lightning_logs/version_15/checkpoints/epoch=135-step=6800.ckpt' as top 1


Epoch 136: |          | 50/? [00:06<00:00,  8.06it/s, v_num=15, train_loss=0.0412]

Epoch 136, global step 6850: 'train_loss' was not in top 1


Epoch 137: |          | 50/? [00:06<00:00,  8.11it/s, v_num=15, train_loss=0.0419]

Epoch 137, global step 6900: 'train_loss' was not in top 1


Epoch 138: |          | 50/? [00:06<00:00,  8.05it/s, v_num=15, train_loss=0.0414]

Epoch 138, global step 6950: 'train_loss' was not in top 1


Epoch 139: |          | 50/? [00:06<00:00,  8.11it/s, v_num=15, train_loss=0.0399]

Epoch 139, global step 7000: 'train_loss' was not in top 1


Epoch 140: |          | 50/? [00:06<00:00,  8.07it/s, v_num=15, train_loss=0.0417]

Epoch 140, global step 7050: 'train_loss' was not in top 1


Epoch 141: |          | 50/? [00:06<00:00,  8.10it/s, v_num=15, train_loss=0.0411]

Epoch 141, global step 7100: 'train_loss' was not in top 1


Epoch 142: |          | 50/? [00:06<00:00,  8.10it/s, v_num=15, train_loss=0.0412]

Epoch 142, global step 7150: 'train_loss' was not in top 1


Epoch 143: |          | 50/? [00:06<00:00,  8.11it/s, v_num=15, train_loss=0.0411]

Epoch 143, global step 7200: 'train_loss' was not in top 1


Epoch 144: |          | 50/? [00:06<00:00,  8.15it/s, v_num=15, train_loss=0.0413]

Epoch 144, global step 7250: 'train_loss' was not in top 1


Epoch 145: |          | 50/? [00:06<00:00,  8.05it/s, v_num=15, train_loss=0.0419]

Epoch 145, global step 7300: 'train_loss' was not in top 1


Epoch 146: |          | 50/? [00:06<00:00,  8.12it/s, v_num=15, train_loss=0.0416]

Epoch 146, global step 7350: 'train_loss' was not in top 1


Epoch 147: |          | 50/? [00:06<00:00,  8.04it/s, v_num=15, train_loss=0.0419]

Epoch 147, global step 7400: 'train_loss' was not in top 1


Epoch 148: |          | 50/? [00:06<00:00,  8.05it/s, v_num=15, train_loss=0.0409]

Epoch 148, global step 7450: 'train_loss' was not in top 1


Epoch 149: |          | 50/? [00:06<00:00,  8.06it/s, v_num=15, train_loss=0.0418]

Epoch 149, global step 7500: 'train_loss' was not in top 1


Epoch 150: |          | 50/? [00:06<00:00,  7.71it/s, v_num=15, train_loss=0.0413]

Epoch 150, global step 7550: 'train_loss' was not in top 1


Epoch 151: |          | 50/? [00:06<00:00,  8.16it/s, v_num=15, train_loss=0.0414]

Epoch 151, global step 7600: 'train_loss' was not in top 1


Epoch 152: |          | 50/? [00:06<00:00,  8.07it/s, v_num=15, train_loss=0.0421]

Epoch 152, global step 7650: 'train_loss' was not in top 1


Epoch 153: |          | 50/? [00:06<00:00,  8.14it/s, v_num=15, train_loss=0.041] 

Epoch 153, global step 7700: 'train_loss' was not in top 1


Epoch 154: |          | 50/? [00:06<00:00,  8.05it/s, v_num=15, train_loss=0.041]

Epoch 154, global step 7750: 'train_loss' was not in top 1


Epoch 155: |          | 50/? [00:06<00:00,  8.13it/s, v_num=15, train_loss=0.0414]

Epoch 155, global step 7800: 'train_loss' was not in top 1


Epoch 156: |          | 50/? [00:06<00:00,  8.09it/s, v_num=15, train_loss=0.0412]

Epoch 156, global step 7850: 'train_loss' was not in top 1


Epoch 157: |          | 50/? [00:06<00:00,  8.12it/s, v_num=15, train_loss=0.0418]

Epoch 157, global step 7900: 'train_loss' was not in top 1


Epoch 158: |          | 50/? [00:06<00:00,  8.14it/s, v_num=15, train_loss=0.0408]

Epoch 158, global step 7950: 'train_loss' was not in top 1


Epoch 159: |          | 50/? [00:06<00:00,  8.10it/s, v_num=15, train_loss=0.0411]

Epoch 159, global step 8000: 'train_loss' was not in top 1


Epoch 160: |          | 50/? [00:06<00:00,  8.10it/s, v_num=15, train_loss=0.0412]

Epoch 160, global step 8050: 'train_loss' was not in top 1


Epoch 161: |          | 50/? [00:06<00:00,  8.10it/s, v_num=15, train_loss=0.0414]

Epoch 161, global step 8100: 'train_loss' was not in top 1


Epoch 162: |          | 50/? [00:06<00:00,  8.14it/s, v_num=15, train_loss=0.0409]

Epoch 162, global step 8150: 'train_loss' was not in top 1


Epoch 163: |          | 50/? [00:06<00:00,  8.07it/s, v_num=15, train_loss=0.0411]

Epoch 163, global step 8200: 'train_loss' was not in top 1


Epoch 164: |          | 50/? [00:06<00:00,  8.04it/s, v_num=15, train_loss=0.0406]

Epoch 164, global step 8250: 'train_loss' was not in top 1


Epoch 165: |          | 50/? [00:06<00:00,  8.14it/s, v_num=15, train_loss=0.0415]

Epoch 165, global step 8300: 'train_loss' was not in top 1


Epoch 166: |          | 50/? [00:06<00:00,  8.13it/s, v_num=15, train_loss=0.0421]

Epoch 166, global step 8350: 'train_loss' was not in top 1


Epoch 167: |          | 50/? [00:06<00:00,  8.04it/s, v_num=15, train_loss=0.0416]

Epoch 167, global step 8400: 'train_loss' was not in top 1


Epoch 168: |          | 50/? [00:06<00:00,  8.10it/s, v_num=15, train_loss=0.0419]

Epoch 168, global step 8450: 'train_loss' was not in top 1


Epoch 169: |          | 50/? [00:06<00:00,  8.12it/s, v_num=15, train_loss=0.0417]

Epoch 169, global step 8500: 'train_loss' was not in top 1


Epoch 170: |          | 50/? [00:06<00:00,  8.16it/s, v_num=15, train_loss=0.0407]

Epoch 170, global step 8550: 'train_loss' was not in top 1


Epoch 171: |          | 50/? [00:06<00:00,  8.16it/s, v_num=15, train_loss=0.0406]

Epoch 171, global step 8600: 'train_loss' was not in top 1


Epoch 172: |          | 50/? [00:06<00:00,  8.15it/s, v_num=15, train_loss=0.0405]

Epoch 172, global step 8650: 'train_loss' was not in top 1


Epoch 173: |          | 50/? [00:06<00:00,  8.09it/s, v_num=15, train_loss=0.0412]

Epoch 173, global step 8700: 'train_loss' was not in top 1


Epoch 174: |          | 50/? [00:06<00:00,  8.14it/s, v_num=15, train_loss=0.0415]

Epoch 174, global step 8750: 'train_loss' was not in top 1


Epoch 175: |          | 50/? [00:06<00:00,  7.74it/s, v_num=15, train_loss=0.0425]

Epoch 175, global step 8800: 'train_loss' was not in top 1


Epoch 176: |          | 50/? [00:06<00:00,  8.13it/s, v_num=15, train_loss=0.0414]

Epoch 176, global step 8850: 'train_loss' was not in top 1


Epoch 177: |          | 50/? [00:06<00:00,  8.03it/s, v_num=15, train_loss=0.0416]

Epoch 177, global step 8900: 'train_loss' was not in top 1


Epoch 178: |          | 50/? [00:06<00:00,  7.96it/s, v_num=15, train_loss=0.0421]

Epoch 178, global step 8950: 'train_loss' was not in top 1


Epoch 179: |          | 50/? [00:06<00:00,  8.05it/s, v_num=15, train_loss=0.0416]

Epoch 179, global step 9000: 'train_loss' was not in top 1


Epoch 180: |          | 50/? [00:06<00:00,  8.13it/s, v_num=15, train_loss=0.0413]

Epoch 180, global step 9050: 'train_loss' was not in top 1


Epoch 181: |          | 50/? [00:06<00:00,  8.17it/s, v_num=15, train_loss=0.0408]

Epoch 181, global step 9100: 'train_loss' was not in top 1


Epoch 182: |          | 50/? [00:06<00:00,  8.11it/s, v_num=15, train_loss=0.0414]

Epoch 182, global step 9150: 'train_loss' was not in top 1


Epoch 183: |          | 50/? [00:06<00:00,  8.12it/s, v_num=15, train_loss=0.0414]

Epoch 183, global step 9200: 'train_loss' was not in top 1


Epoch 184: |          | 50/? [00:06<00:00,  8.11it/s, v_num=15, train_loss=0.0424]

Epoch 184, global step 9250: 'train_loss' was not in top 1


Epoch 185: |          | 50/? [00:06<00:00,  8.14it/s, v_num=15, train_loss=0.0421]

Epoch 185, global step 9300: 'train_loss' was not in top 1


Epoch 186: |          | 50/? [00:06<00:00,  8.16it/s, v_num=15, train_loss=0.0418]

Epoch 186, global step 9350: 'train_loss' was not in top 1


Epoch 187: |          | 50/? [00:06<00:00,  8.08it/s, v_num=15, train_loss=0.0415]

Epoch 187, global step 9400: 'train_loss' was not in top 1


Epoch 188: |          | 50/? [00:06<00:00,  8.06it/s, v_num=15, train_loss=0.0424]

Epoch 188, global step 9450: 'train_loss' was not in top 1


Epoch 189: |          | 50/? [00:06<00:00,  8.15it/s, v_num=15, train_loss=0.0416]

Epoch 189, global step 9500: 'train_loss' was not in top 1


Epoch 190: |          | 50/? [00:06<00:00,  8.06it/s, v_num=15, train_loss=0.0411]

Epoch 190, global step 9550: 'train_loss' was not in top 1


Epoch 191: |          | 50/? [00:06<00:00,  8.09it/s, v_num=15, train_loss=0.0412]

Epoch 191, global step 9600: 'train_loss' was not in top 1


Epoch 192: |          | 50/? [00:06<00:00,  8.08it/s, v_num=15, train_loss=0.0414]

Epoch 192, global step 9650: 'train_loss' was not in top 1


Epoch 193: |          | 50/? [00:06<00:00,  8.09it/s, v_num=15, train_loss=0.0418]

Epoch 193, global step 9700: 'train_loss' was not in top 1


Epoch 194: |          | 50/? [00:06<00:00,  8.11it/s, v_num=15, train_loss=0.0413]

Epoch 194, global step 9750: 'train_loss' was not in top 1


Epoch 195: |          | 50/? [00:06<00:00,  8.09it/s, v_num=15, train_loss=0.0411]

Epoch 195, global step 9800: 'train_loss' was not in top 1


Epoch 196: |          | 50/? [00:06<00:00,  8.02it/s, v_num=15, train_loss=0.0413]

Epoch 196, global step 9850: 'train_loss' was not in top 1


Epoch 197: |          | 50/? [00:06<00:00,  8.14it/s, v_num=15, train_loss=0.0416]

Epoch 197, global step 9900: 'train_loss' was not in top 1


Epoch 198: |          | 50/? [00:06<00:00,  8.13it/s, v_num=15, train_loss=0.0404]

Epoch 198, global step 9950: 'train_loss' was not in top 1


Epoch 199: |          | 50/? [00:06<00:00,  8.16it/s, v_num=15, train_loss=0.0416]

Epoch 199, global step 10000: 'train_loss' was not in top 1
`Trainer.fit` stopped: `max_epochs=200` reached.


Epoch 199: |          | 50/? [00:06<00:00,  8.15it/s, v_num=15, train_loss=0.0416]


In [9]:
forecast_it, ts_it = make_evaluation_predictions(
    dataset=dataset_test, predictor=predictor, num_samples=100
)
forecasts = list(forecast_it)
targets = list(ts_it)
agg_metric, _ = evaluator(targets, forecasts, num_series=len(dataset_test))

Running evaluation: 7it [00:00, 58.25it/s]
Running evaluation: 7it [00:00, 66.72it/s]
Running evaluation: 7it [00:00, 70.20it/s]
Running evaluation: 7it [00:00, 65.56it/s]
Running evaluation: 7it [00:00, 65.72it/s]
Running evaluation: 7it [00:00, 63.88it/s]
Running evaluation: 7it [00:00, 63.29it/s]
Running evaluation: 7it [00:00, 65.00it/s]
Running evaluation: 7it [00:00, 66.85it/s]
Running evaluation: 7it [00:00, 65.02it/s]
Running evaluation: 7it [00:00, 68.68it/s]
Running evaluation: 7it [00:00, 70.25it/s]
Running evaluation: 7it [00:00, 66.68it/s]
Running evaluation: 7it [00:00, 68.54it/s]
Running evaluation: 7it [00:00, 65.28it/s]
Running evaluation: 7it [00:00, 70.84it/s]
Running evaluation: 7it [00:00, 69.58it/s]
Running evaluation: 7it [00:00, 64.63it/s]
Running evaluation: 7it [00:00, 64.15it/s]
Running evaluation: 7it [00:00, 61.15it/s]
Running evaluation: 7it [00:00, 63.26it/s]
Running evaluation: 7it [00:00, 69.66it/s]
Running evaluation: 7it [00:00, 63.63it/s]
Running eva

In [10]:
print('Timegrad on regular dataset metrics:')
print("CRPS: {}".format(agg_metric["mean_wQuantileLoss"]))
print("ND: {}".format(agg_metric["ND"]))
print("NRMSE: {}".format(agg_metric["NRMSE"]))
print("MSE: {}".format(agg_metric["MSE"]))

print('')

print("CRPS-Sum: {}".format(agg_metric["m_sum_mean_wQuantileLoss"]))
print("ND-Sum: {}".format(agg_metric["m_sum_ND"]))
print("NRMSE-Sum: {}".format(agg_metric["m_sum_NRMSE"]))
print("MSE-Sum: {}".format(agg_metric["m_sum_MSE"]))

Timegrad on regular dataset metrics:
CRPS: 0.3743160235033126
ND: 0.4681447740693562
NRMSE: 0.9333072944939037
MSE: 826.2797468362436

CRPS-Sum: 0.31375429086490336
ND-Sum: 0.3635174670093458
NRMSE-Sum: 0.6307043416960839
MSE-Sum: 7082239.095238095


In [11]:
with open('/home/jupyter/datasphere/project/models/solar.pkl', 'wb') as f:
    pickle.dump(predictor, f)

In [12]:
"""
# If you want to load already prepared model

transformation = estimator.create_transformation()
lightning_module = estimator.create_lightning_module()
predictor = estimator.create_predictor(transformation, lightning_module)
with open('/home/jupyter/work/resources/solar/model.pkl', 'rb') as f:
    predictor = pickle.load(f)
"""

"\n# If you want to load already prepared model\n\ntransformation = estimator.create_transformation()\nlightning_module = estimator.create_lightning_module()\npredictor = estimator.create_predictor(transformation, lightning_module)\nwith open('/home/jupyter/work/resources/solar/model.pkl', 'rb') as f:\n    predictor = pickle.load(f)\n"

In [13]:
def interpolate_points(dataset, num_points: int = 5):
    new_dataset = []
    with Pool(processes=cpu_count()) as pool:
        results = pool.starmap(interpolate_np_array, [(ts['target'], num_points) for ts in dataset])
        for i, ts in enumerate(dataset):
            mask = np.isnan(ts['target'])
            new_dataset.append({
                'target': ts['target'] * (1. - mask) + results[i] * mask,
                'start': ts['start'],
                'feat_static_cat': ts['feat_static_cat']
            })
    return new_dataset

In [14]:
crps_sum_results = pd.DataFrame(columns=range(10))

with tqdm(total=9*10) as pbar:
    for p in np.linspace(0.1, 1, 9, endpoint=False): # frac of removed points
        dataset_test_p = omit_points(dataset_test, p)
        results_p = []

        # No interpolation
        forecast_it, ts_it = make_evaluation_predictions(
            dataset=dataset_test_p, 
            predictor=predictor, 
            num_samples=100,
        )
        forecasts = list(forecast_it)
        # targets = list(ts_it)  # Do not uncomment, we need targets without NaNs, so we use ones calculated above.
        agg_metric, _ = evaluator(targets, forecasts, num_series=len(dataset_test))
        results_p.append(agg_metric['m_sum_mean_wQuantileLoss'])
        pbar.update(1)

        # Interpolation with num_points
        for num_points in range(1, 10):
            forecast_it, ts_it = make_evaluation_predictions(
                dataset=interpolate_points(dataset_test_p, num_points), 
                predictor=predictor, 
                num_samples=100,
            )
            forecasts = list(forecast_it)
            # targets = list(ts_it)  # Do not uncomment, we need targets without NaNs, so we use ones calculated above.
            agg_metric, _ = evaluator(targets, forecasts, num_series=len(dataset_test))
            results_p.append(agg_metric['m_sum_mean_wQuantileLoss'])
            pbar.update(1)

        crps_sum_results.loc[p] = results_p

  0%|          | 0/90 [00:00<?, ?it/s]
Running evaluation: 7it [00:00, 66.04it/s]

Running evaluation: 7it [00:00, 69.74it/s]

Running evaluation: 7it [00:00, 61.04it/s]

Running evaluation: 7it [00:00, 60.16it/s]

Running evaluation: 7it [00:00, 63.60it/s]

Running evaluation: 7it [00:00, 66.33it/s]

Running evaluation: 7it [00:00, 62.18it/s]

Running evaluation: 7it [00:00, 68.40it/s]

Running evaluation: 7it [00:00, 65.29it/s]

Running evaluation: 7it [00:00, 69.06it/s]

Running evaluation: 7it [00:00, 70.59it/s]

Running evaluation: 7it [00:00, 67.75it/s]

Running evaluation: 7it [00:00, 69.96it/s]

Running evaluation: 7it [00:00, 67.39it/s]

Running evaluation: 7it [00:00, 67.09it/s]

Running evaluation: 7it [00:00, 69.16it/s]

Running evaluation: 7it [00:00, 62.41it/s]

Running evaluation: 7it [00:00, 68.46it/s]

Running evaluation: 7it [00:00, 69.28it/s]

Running evaluation: 7it [00:00, 66.52it/s]

Running evaluation: 7it [00:00, 67.76it/s]

Running evaluation: 7it [00:00, 63.83

In [15]:
crps_sum_results

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
0.1,0.316294,0.308476,0.326342,0.367161,0.323442,0.312021,0.319708,0.312569,0.328698,0.309832
0.2,0.365494,0.342766,0.360484,0.396763,0.355549,0.34308,0.360361,0.34427,0.356978,0.344453
0.3,0.357299,0.313618,0.352668,0.366452,0.353973,0.31252,0.356401,0.317901,0.35496,0.316115
0.4,0.331644,0.373367,0.343645,0.430531,0.346478,0.37775,0.346317,0.37573,0.347669,0.377319
0.5,0.418495,0.416832,0.336657,0.469853,0.334727,0.419248,0.334798,0.417846,0.335849,0.414912
0.6,0.346795,0.334831,0.410464,0.385944,0.408754,0.337744,0.409436,0.337253,0.408815,0.334682
0.7,0.388019,0.327894,0.371527,0.370052,0.368144,0.319977,0.36833,0.31912,0.36745,0.323415
0.8,0.400023,0.429624,0.392653,0.48832,0.38717,0.42611,0.389264,0.428022,0.390439,0.428694
0.9,1.4581,1.346979,1.178966,1.393135,1.177913,1.346834,1.181553,1.348415,1.181906,1.35006
