In [1]:
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import pickle

from tqdm import tqdm
from multiprocessing import Pool, cpu_count

In [2]:
from gluonts.dataset.multivariate_grouper import MultivariateGrouper
from gluonts.dataset.repository.datasets import dataset_recipes, get_dataset
from gluonts.evaluation.backtest import make_evaluation_predictions
from gluonts.evaluation import MultivariateEvaluator
from diffusers import (
    PNDMScheduler,
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    KDPM2DiscreteScheduler,
    DEISMultistepScheduler,
)

from pts.model.time_grad import TimeGradEstimator
from pts.dataset.repository.datasets import dataset_recipes

from utils import omit_points, interpolate_np_array

  self.freq: BaseOffset = to_offset(freq)


In [3]:
dataset = get_dataset("electricity_nips", regenerate=False)

train_grouper = MultivariateGrouper(
    max_target_dim=int(dataset.metadata.feat_static_cat[0].cardinality)
)

test_grouper = MultivariateGrouper(
    num_test_dates=int(len(dataset.test) / len(dataset.train)),
    max_target_dim=int(dataset.metadata.feat_static_cat[0].cardinality),
)

dataset_train = train_grouper(dataset.train)
dataset_test = test_grouper(dataset.test)

  return pd.Period(val, freq)


In [4]:
evaluator = MultivariateEvaluator(
    quantiles=(np.arange(20) / 20.0)[1:], target_agg_funcs={"sum": np.sum}
)

In [5]:
def plot(
    target,
    forecast,
    prediction_length,
    prediction_intervals=(50.0, 90.0),
    color="g",
    fname=None,
):
    label_prefix = ""
    rows = 4
    cols = 4
    fig, axs = plt.subplots(rows, cols, figsize=(24, 24))
    axx = axs.ravel()
    seq_len, target_dim = target.shape

    ps = [50.0] + [
        50.0 + f * c / 2.0 for c in prediction_intervals for f in [-1.0, +1.0]
    ]

    percentiles_sorted = sorted(set(ps))

    def alpha_for_percentile(p):
        return (p / 100.0) ** 0.3

    for dim in range(0, min(rows * cols, target_dim)):
        ax = axx[dim]

        target[-2 * prediction_length :][dim].plot(ax=ax)

        ps_data = [forecast.quantile(p / 100.0)[:, dim] for p in percentiles_sorted]
        i_p50 = len(percentiles_sorted) // 2

        p50_data = ps_data[i_p50]
        p50_series = pd.Series(data=p50_data, index=forecast.index)
        p50_series.plot(color=color, ls="-", label=f"{label_prefix}median", ax=ax)

        for i in range(len(percentiles_sorted) // 2):
            ptile = percentiles_sorted[i]
            alpha = alpha_for_percentile(ptile)
            ax.fill_between(
                forecast.index,
                ps_data[i],
                ps_data[-i - 1],
                facecolor=color,
                alpha=alpha,
                interpolate=True,
            )
            # Hack to create labels for the error intervals.
            # Doesn't actually plot anything, because we only pass a single data point
            pd.Series(data=p50_data[:1], index=forecast.index[:1]).plot(
                color=color,
                alpha=alpha,
                linewidth=10,
                label=f"{label_prefix}{100 - ptile * 2}%",
                ax=ax,
            )

    legend = ["observations", "median prediction"] + [
        f"{k}% prediction interval" for k in prediction_intervals
    ][::-1]
    axx[0].legend(legend, loc="upper left")

    if fname is not None:
        plt.savefig(fname, bbox_inches="tight", pad_inches=0.05)


## Train TimeGrad on Electricity dataset

In [6]:
scheduler = DEISMultistepScheduler(
    num_train_timesteps=150,
    beta_end=0.1,
)

In [7]:
estimator = TimeGradEstimator(
    input_size=int(dataset.metadata.feat_static_cat[0].cardinality),
    hidden_size=64,
    num_layers=2,
    dropout_rate=0.1,
    lags_seq=[1],
    scheduler=scheduler,
    num_inference_steps=149,
    prediction_length=dataset.metadata.prediction_length,
    context_length=dataset.metadata.prediction_length,
    freq=dataset.metadata.freq,
    scaling="mean",
    trainer_kwargs=dict(max_epochs=200, accelerator="gpu", devices="1")
)

  offset = to_offset(freq_str)


In [8]:
predictor = estimator.train(dataset_train, cache_data=True, shuffle_buffer_length=1024)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/jupyter/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
2024-05-22 08:45:38.063288: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type          | Params | In sizes                                                             | Out sizes        
-----------------------------------------------------------------------------------------------------------------------------------
0 | model | TimeGradModel | 432 K  | [[1, 1], [

Epoch 0: |          | 50/? [00:06<00:00,  8.02it/s, v_num=14, train_loss=0.338]

Epoch 0, global step 50: 'train_loss' reached 0.33825 (best 0.33825), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=0-step=50.ckpt' as top 1


Epoch 1: |          | 50/? [00:06<00:00,  8.15it/s, v_num=14, train_loss=0.0982]

Epoch 1, global step 100: 'train_loss' reached 0.09817 (best 0.09817), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=1-step=100.ckpt' as top 1


Epoch 2: |          | 50/? [00:06<00:00,  8.26it/s, v_num=14, train_loss=0.079] 

Epoch 2, global step 150: 'train_loss' reached 0.07897 (best 0.07897), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=2-step=150.ckpt' as top 1


Epoch 3: |          | 50/? [00:06<00:00,  8.26it/s, v_num=14, train_loss=0.0752]

Epoch 3, global step 200: 'train_loss' reached 0.07517 (best 0.07517), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=3-step=200.ckpt' as top 1


Epoch 4: |          | 50/? [00:06<00:00,  8.24it/s, v_num=14, train_loss=0.0678]

Epoch 4, global step 250: 'train_loss' reached 0.06784 (best 0.06784), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=4-step=250.ckpt' as top 1


Epoch 5: |          | 50/? [00:06<00:00,  7.77it/s, v_num=14, train_loss=0.0611]

Epoch 5, global step 300: 'train_loss' reached 0.06113 (best 0.06113), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=5-step=300.ckpt' as top 1


Epoch 6: |          | 50/? [00:06<00:00,  8.18it/s, v_num=14, train_loss=0.0554]

Epoch 6, global step 350: 'train_loss' reached 0.05538 (best 0.05538), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=6-step=350.ckpt' as top 1


Epoch 7: |          | 50/? [00:06<00:00,  8.16it/s, v_num=14, train_loss=0.0519]

Epoch 7, global step 400: 'train_loss' reached 0.05194 (best 0.05194), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=7-step=400.ckpt' as top 1


Epoch 8: |          | 50/? [00:06<00:00,  8.22it/s, v_num=14, train_loss=0.0512]

Epoch 8, global step 450: 'train_loss' reached 0.05117 (best 0.05117), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=8-step=450.ckpt' as top 1


Epoch 9: |          | 50/? [00:06<00:00,  8.20it/s, v_num=14, train_loss=0.0486]

Epoch 9, global step 500: 'train_loss' reached 0.04857 (best 0.04857), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=9-step=500.ckpt' as top 1


Epoch 10: |          | 50/? [00:06<00:00,  8.17it/s, v_num=14, train_loss=0.0503]

Epoch 10, global step 550: 'train_loss' was not in top 1


Epoch 11: |          | 50/? [00:06<00:00,  8.19it/s, v_num=14, train_loss=0.0473]

Epoch 11, global step 600: 'train_loss' reached 0.04731 (best 0.04731), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=11-step=600.ckpt' as top 1


Epoch 12: |          | 50/? [00:06<00:00,  8.23it/s, v_num=14, train_loss=0.0461]

Epoch 12, global step 650: 'train_loss' reached 0.04611 (best 0.04611), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=12-step=650.ckpt' as top 1


Epoch 13: |          | 50/? [00:06<00:00,  8.20it/s, v_num=14, train_loss=0.045] 

Epoch 13, global step 700: 'train_loss' reached 0.04499 (best 0.04499), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=13-step=700.ckpt' as top 1


Epoch 14: |          | 50/? [00:06<00:00,  8.20it/s, v_num=14, train_loss=0.0439]

Epoch 14, global step 750: 'train_loss' reached 0.04392 (best 0.04392), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=14-step=750.ckpt' as top 1


Epoch 15: |          | 50/? [00:06<00:00,  8.24it/s, v_num=14, train_loss=0.0425]

Epoch 15, global step 800: 'train_loss' reached 0.04251 (best 0.04251), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=15-step=800.ckpt' as top 1


Epoch 16: |          | 50/? [00:06<00:00,  8.21it/s, v_num=14, train_loss=0.041] 

Epoch 16, global step 850: 'train_loss' reached 0.04100 (best 0.04100), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=16-step=850.ckpt' as top 1


Epoch 17: |          | 50/? [00:06<00:00,  8.20it/s, v_num=14, train_loss=0.0399]

Epoch 17, global step 900: 'train_loss' reached 0.03993 (best 0.03993), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=17-step=900.ckpt' as top 1


Epoch 18: |          | 50/? [00:06<00:00,  8.28it/s, v_num=14, train_loss=0.0402]

Epoch 18, global step 950: 'train_loss' was not in top 1


Epoch 19: |          | 50/? [00:06<00:00,  8.24it/s, v_num=14, train_loss=0.0394]

Epoch 19, global step 1000: 'train_loss' reached 0.03937 (best 0.03937), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=19-step=1000.ckpt' as top 1


Epoch 20: |          | 50/? [00:06<00:00,  8.16it/s, v_num=14, train_loss=0.0394]

Epoch 20, global step 1050: 'train_loss' was not in top 1


Epoch 21: |          | 50/? [00:06<00:00,  8.19it/s, v_num=14, train_loss=0.0399]

Epoch 21, global step 1100: 'train_loss' was not in top 1


Epoch 22: |          | 50/? [00:06<00:00,  8.23it/s, v_num=14, train_loss=0.0382]

Epoch 22, global step 1150: 'train_loss' reached 0.03823 (best 0.03823), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=22-step=1150.ckpt' as top 1


Epoch 23: |          | 50/? [00:06<00:00,  8.25it/s, v_num=14, train_loss=0.0388]

Epoch 23, global step 1200: 'train_loss' was not in top 1


Epoch 24: |          | 50/? [00:06<00:00,  8.21it/s, v_num=14, train_loss=0.0385]

Epoch 24, global step 1250: 'train_loss' was not in top 1


Epoch 25: |          | 50/? [00:06<00:00,  8.26it/s, v_num=14, train_loss=0.039] 

Epoch 25, global step 1300: 'train_loss' was not in top 1


Epoch 26: |          | 50/? [00:06<00:00,  8.25it/s, v_num=14, train_loss=0.0378]

Epoch 26, global step 1350: 'train_loss' reached 0.03783 (best 0.03783), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=26-step=1350.ckpt' as top 1


Epoch 27: |          | 50/? [00:06<00:00,  8.25it/s, v_num=14, train_loss=0.0375]

Epoch 27, global step 1400: 'train_loss' reached 0.03749 (best 0.03749), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=27-step=1400.ckpt' as top 1


Epoch 28: |          | 50/? [00:06<00:00,  8.18it/s, v_num=14, train_loss=0.038] 

Epoch 28, global step 1450: 'train_loss' was not in top 1


Epoch 29: |          | 50/? [00:06<00:00,  7.86it/s, v_num=14, train_loss=0.0382]

Epoch 29, global step 1500: 'train_loss' was not in top 1


Epoch 30: |          | 50/? [00:06<00:00,  8.24it/s, v_num=14, train_loss=0.0368]

Epoch 30, global step 1550: 'train_loss' reached 0.03682 (best 0.03682), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=30-step=1550.ckpt' as top 1


Epoch 31: |          | 50/? [00:06<00:00,  8.24it/s, v_num=14, train_loss=0.0363]

Epoch 31, global step 1600: 'train_loss' reached 0.03630 (best 0.03630), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=31-step=1600.ckpt' as top 1


Epoch 32: |          | 50/? [00:06<00:00,  8.19it/s, v_num=14, train_loss=0.0367]

Epoch 32, global step 1650: 'train_loss' was not in top 1


Epoch 33: |          | 50/? [00:06<00:00,  8.19it/s, v_num=14, train_loss=0.0373]

Epoch 33, global step 1700: 'train_loss' was not in top 1


Epoch 34: |          | 50/? [00:06<00:00,  8.16it/s, v_num=14, train_loss=0.0358]

Epoch 34, global step 1750: 'train_loss' reached 0.03582 (best 0.03582), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=34-step=1750.ckpt' as top 1


Epoch 35: |          | 50/? [00:06<00:00,  8.19it/s, v_num=14, train_loss=0.0364]

Epoch 35, global step 1800: 'train_loss' was not in top 1


Epoch 36: |          | 50/? [00:06<00:00,  8.24it/s, v_num=14, train_loss=0.0358]

Epoch 36, global step 1850: 'train_loss' reached 0.03581 (best 0.03581), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=36-step=1850.ckpt' as top 1


Epoch 37: |          | 50/? [00:06<00:00,  8.20it/s, v_num=14, train_loss=0.0361]

Epoch 37, global step 1900: 'train_loss' was not in top 1


Epoch 38: |          | 50/? [00:06<00:00,  8.20it/s, v_num=14, train_loss=0.0369]

Epoch 38, global step 1950: 'train_loss' was not in top 1


Epoch 39: |          | 50/? [00:06<00:00,  8.18it/s, v_num=14, train_loss=0.0358]

Epoch 39, global step 2000: 'train_loss' was not in top 1


Epoch 40: |          | 50/? [00:06<00:00,  8.13it/s, v_num=14, train_loss=0.0352]

Epoch 40, global step 2050: 'train_loss' reached 0.03523 (best 0.03523), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=40-step=2050.ckpt' as top 1


Epoch 41: |          | 50/? [00:06<00:00,  8.28it/s, v_num=14, train_loss=0.0352]

Epoch 41, global step 2100: 'train_loss' reached 0.03522 (best 0.03522), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=41-step=2100.ckpt' as top 1


Epoch 42: |          | 50/? [00:06<00:00,  8.26it/s, v_num=14, train_loss=0.0352]

Epoch 42, global step 2150: 'train_loss' reached 0.03521 (best 0.03521), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=42-step=2150.ckpt' as top 1


Epoch 43: |          | 50/? [00:06<00:00,  8.32it/s, v_num=14, train_loss=0.0346]

Epoch 43, global step 2200: 'train_loss' reached 0.03455 (best 0.03455), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=43-step=2200.ckpt' as top 1


Epoch 44: |          | 50/? [00:06<00:00,  8.29it/s, v_num=14, train_loss=0.034] 

Epoch 44, global step 2250: 'train_loss' reached 0.03401 (best 0.03401), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=44-step=2250.ckpt' as top 1


Epoch 45: |          | 50/? [00:06<00:00,  8.29it/s, v_num=14, train_loss=0.034]

Epoch 45, global step 2300: 'train_loss' reached 0.03397 (best 0.03397), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=45-step=2300.ckpt' as top 1


Epoch 46: |          | 50/? [00:06<00:00,  8.24it/s, v_num=14, train_loss=0.0346]

Epoch 46, global step 2350: 'train_loss' was not in top 1


Epoch 47: |          | 50/? [00:06<00:00,  8.24it/s, v_num=14, train_loss=0.0338]

Epoch 47, global step 2400: 'train_loss' reached 0.03385 (best 0.03385), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=47-step=2400.ckpt' as top 1


Epoch 48: |          | 50/? [00:06<00:00,  8.25it/s, v_num=14, train_loss=0.0343]

Epoch 48, global step 2450: 'train_loss' was not in top 1


Epoch 49: |          | 50/? [00:06<00:00,  8.25it/s, v_num=14, train_loss=0.0347]

Epoch 49, global step 2500: 'train_loss' was not in top 1


Epoch 50: |          | 50/? [00:06<00:00,  8.28it/s, v_num=14, train_loss=0.034] 

Epoch 50, global step 2550: 'train_loss' was not in top 1


Epoch 51: |          | 50/? [00:06<00:00,  8.30it/s, v_num=14, train_loss=0.0333]

Epoch 51, global step 2600: 'train_loss' reached 0.03328 (best 0.03328), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=51-step=2600.ckpt' as top 1


Epoch 52: |          | 50/? [00:06<00:00,  7.85it/s, v_num=14, train_loss=0.0329]

Epoch 52, global step 2650: 'train_loss' reached 0.03289 (best 0.03289), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=52-step=2650.ckpt' as top 1


Epoch 53: |          | 50/? [00:06<00:00,  8.26it/s, v_num=14, train_loss=0.0338]

Epoch 53, global step 2700: 'train_loss' was not in top 1


Epoch 54: |          | 50/? [00:06<00:00,  8.28it/s, v_num=14, train_loss=0.0332]

Epoch 54, global step 2750: 'train_loss' was not in top 1


Epoch 55: |          | 50/? [00:06<00:00,  8.18it/s, v_num=14, train_loss=0.034] 

Epoch 55, global step 2800: 'train_loss' was not in top 1


Epoch 56: |          | 50/? [00:06<00:00,  8.23it/s, v_num=14, train_loss=0.0327]

Epoch 56, global step 2850: 'train_loss' reached 0.03274 (best 0.03274), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=56-step=2850.ckpt' as top 1


Epoch 57: |          | 50/? [00:06<00:00,  8.25it/s, v_num=14, train_loss=0.0338]

Epoch 57, global step 2900: 'train_loss' was not in top 1


Epoch 58: |          | 50/? [00:06<00:00,  8.23it/s, v_num=14, train_loss=0.0328]

Epoch 58, global step 2950: 'train_loss' was not in top 1


Epoch 59: |          | 50/? [00:06<00:00,  8.21it/s, v_num=14, train_loss=0.0324]

Epoch 59, global step 3000: 'train_loss' reached 0.03237 (best 0.03237), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=59-step=3000.ckpt' as top 1


Epoch 60: |          | 50/? [00:06<00:00,  8.23it/s, v_num=14, train_loss=0.0338]

Epoch 60, global step 3050: 'train_loss' was not in top 1


Epoch 61: |          | 50/? [00:06<00:00,  8.24it/s, v_num=14, train_loss=0.0355]

Epoch 61, global step 3100: 'train_loss' was not in top 1


Epoch 62: |          | 50/? [00:06<00:00,  8.31it/s, v_num=14, train_loss=0.033] 

Epoch 62, global step 3150: 'train_loss' was not in top 1


Epoch 63: |          | 50/? [00:06<00:00,  8.22it/s, v_num=14, train_loss=0.0333]

Epoch 63, global step 3200: 'train_loss' was not in top 1


Epoch 64: |          | 50/? [00:06<00:00,  8.20it/s, v_num=14, train_loss=0.0332]

Epoch 64, global step 3250: 'train_loss' was not in top 1


Epoch 65: |          | 50/? [00:06<00:00,  8.22it/s, v_num=14, train_loss=0.0334]

Epoch 65, global step 3300: 'train_loss' was not in top 1


Epoch 66: |          | 50/? [00:06<00:00,  8.27it/s, v_num=14, train_loss=0.0329]

Epoch 66, global step 3350: 'train_loss' was not in top 1


Epoch 67: |          | 50/? [00:06<00:00,  8.23it/s, v_num=14, train_loss=0.0322]

Epoch 67, global step 3400: 'train_loss' reached 0.03218 (best 0.03218), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=67-step=3400.ckpt' as top 1


Epoch 68: |          | 50/? [00:06<00:00,  8.27it/s, v_num=14, train_loss=0.0326]

Epoch 68, global step 3450: 'train_loss' was not in top 1


Epoch 69: |          | 50/? [00:06<00:00,  8.20it/s, v_num=14, train_loss=0.0316]

Epoch 69, global step 3500: 'train_loss' reached 0.03165 (best 0.03165), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=69-step=3500.ckpt' as top 1


Epoch 70: |          | 50/? [00:06<00:00,  8.31it/s, v_num=14, train_loss=0.033] 

Epoch 70, global step 3550: 'train_loss' was not in top 1


Epoch 71: |          | 50/? [00:06<00:00,  8.23it/s, v_num=14, train_loss=0.0334]

Epoch 71, global step 3600: 'train_loss' was not in top 1


Epoch 72: |          | 50/? [00:06<00:00,  8.22it/s, v_num=14, train_loss=0.0318]

Epoch 72, global step 3650: 'train_loss' was not in top 1


Epoch 73: |          | 50/? [00:06<00:00,  8.31it/s, v_num=14, train_loss=0.0315]

Epoch 73, global step 3700: 'train_loss' reached 0.03154 (best 0.03154), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=73-step=3700.ckpt' as top 1


Epoch 74: |          | 50/? [00:06<00:00,  8.33it/s, v_num=14, train_loss=0.0325]

Epoch 74, global step 3750: 'train_loss' was not in top 1


Epoch 75: |          | 50/? [00:06<00:00,  8.27it/s, v_num=14, train_loss=0.0311]

Epoch 75, global step 3800: 'train_loss' reached 0.03114 (best 0.03114), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=75-step=3800.ckpt' as top 1


Epoch 76: |          | 50/? [00:06<00:00,  7.83it/s, v_num=14, train_loss=0.0317]

Epoch 76, global step 3850: 'train_loss' was not in top 1


Epoch 77: |          | 50/? [00:06<00:00,  8.28it/s, v_num=14, train_loss=0.0309]

Epoch 77, global step 3900: 'train_loss' reached 0.03093 (best 0.03093), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=77-step=3900.ckpt' as top 1


Epoch 78: |          | 50/? [00:06<00:00,  8.24it/s, v_num=14, train_loss=0.0314]

Epoch 78, global step 3950: 'train_loss' was not in top 1


Epoch 79: |          | 50/? [00:06<00:00,  8.22it/s, v_num=14, train_loss=0.0314]

Epoch 79, global step 4000: 'train_loss' was not in top 1


Epoch 80: |          | 50/? [00:06<00:00,  8.27it/s, v_num=14, train_loss=0.032] 

Epoch 80, global step 4050: 'train_loss' was not in top 1


Epoch 81: |          | 50/? [00:06<00:00,  8.28it/s, v_num=14, train_loss=0.0312]

Epoch 81, global step 4100: 'train_loss' was not in top 1


Epoch 82: |          | 50/? [00:06<00:00,  8.29it/s, v_num=14, train_loss=0.0314]

Epoch 82, global step 4150: 'train_loss' was not in top 1


Epoch 83: |          | 50/? [00:06<00:00,  8.25it/s, v_num=14, train_loss=0.0318]

Epoch 83, global step 4200: 'train_loss' was not in top 1


Epoch 84: |          | 50/? [00:06<00:00,  8.19it/s, v_num=14, train_loss=0.0311]

Epoch 84, global step 4250: 'train_loss' was not in top 1


Epoch 85: |          | 50/? [00:06<00:00,  8.26it/s, v_num=14, train_loss=0.0307]

Epoch 85, global step 4300: 'train_loss' reached 0.03067 (best 0.03067), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=85-step=4300.ckpt' as top 1


Epoch 86: |          | 50/? [00:06<00:00,  8.21it/s, v_num=14, train_loss=0.0311]

Epoch 86, global step 4350: 'train_loss' was not in top 1


Epoch 87: |          | 50/? [00:06<00:00,  8.22it/s, v_num=14, train_loss=0.0306]

Epoch 87, global step 4400: 'train_loss' reached 0.03061 (best 0.03061), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=87-step=4400.ckpt' as top 1


Epoch 88: |          | 50/? [00:06<00:00,  8.25it/s, v_num=14, train_loss=0.030] 

Epoch 88, global step 4450: 'train_loss' reached 0.03004 (best 0.03004), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=88-step=4450.ckpt' as top 1


Epoch 89: |          | 50/? [00:06<00:00,  8.19it/s, v_num=14, train_loss=0.0314]

Epoch 89, global step 4500: 'train_loss' was not in top 1


Epoch 90: |          | 50/? [00:06<00:00,  8.18it/s, v_num=14, train_loss=0.0304]

Epoch 90, global step 4550: 'train_loss' was not in top 1


Epoch 91: |          | 50/? [00:06<00:00,  8.22it/s, v_num=14, train_loss=0.0303]

Epoch 91, global step 4600: 'train_loss' was not in top 1


Epoch 92: |          | 50/? [00:06<00:00,  8.30it/s, v_num=14, train_loss=0.0308]

Epoch 92, global step 4650: 'train_loss' was not in top 1


Epoch 93: |          | 50/? [00:06<00:00,  8.22it/s, v_num=14, train_loss=0.0308]

Epoch 93, global step 4700: 'train_loss' was not in top 1


Epoch 94: |          | 50/? [00:06<00:00,  8.13it/s, v_num=14, train_loss=0.0302]

Epoch 94, global step 4750: 'train_loss' was not in top 1


Epoch 95: |          | 50/? [00:06<00:00,  8.19it/s, v_num=14, train_loss=0.0306]

Epoch 95, global step 4800: 'train_loss' was not in top 1


Epoch 96: |          | 50/? [00:06<00:00,  8.20it/s, v_num=14, train_loss=0.0312]

Epoch 96, global step 4850: 'train_loss' was not in top 1


Epoch 97: |          | 50/? [00:06<00:00,  8.24it/s, v_num=14, train_loss=0.0306]

Epoch 97, global step 4900: 'train_loss' was not in top 1


Epoch 98: |          | 50/? [00:06<00:00,  8.30it/s, v_num=14, train_loss=0.0309]

Epoch 98, global step 4950: 'train_loss' was not in top 1


Epoch 99: |          | 50/? [00:06<00:00,  8.22it/s, v_num=14, train_loss=0.0313]

Epoch 99, global step 5000: 'train_loss' was not in top 1


Epoch 100: |          | 50/? [00:06<00:00,  8.24it/s, v_num=14, train_loss=0.0302]

Epoch 100, global step 5050: 'train_loss' was not in top 1


Epoch 101: |          | 50/? [00:06<00:00,  7.86it/s, v_num=14, train_loss=0.0301]

Epoch 101, global step 5100: 'train_loss' was not in top 1


Epoch 102: |          | 50/? [00:06<00:00,  8.28it/s, v_num=14, train_loss=0.0292]

Epoch 102, global step 5150: 'train_loss' reached 0.02924 (best 0.02924), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=102-step=5150.ckpt' as top 1


Epoch 103: |          | 50/? [00:06<00:00,  8.25it/s, v_num=14, train_loss=0.0309]

Epoch 103, global step 5200: 'train_loss' was not in top 1


Epoch 104: |          | 50/? [00:06<00:00,  8.23it/s, v_num=14, train_loss=0.0303]

Epoch 104, global step 5250: 'train_loss' was not in top 1


Epoch 105: |          | 50/? [00:06<00:00,  8.23it/s, v_num=14, train_loss=0.0297]

Epoch 105, global step 5300: 'train_loss' was not in top 1


Epoch 106: |          | 50/? [00:06<00:00,  8.23it/s, v_num=14, train_loss=0.0305]

Epoch 106, global step 5350: 'train_loss' was not in top 1


Epoch 107: |          | 50/? [00:06<00:00,  8.22it/s, v_num=14, train_loss=0.0299]

Epoch 107, global step 5400: 'train_loss' was not in top 1


Epoch 108: |          | 50/? [00:06<00:00,  8.25it/s, v_num=14, train_loss=0.0295]

Epoch 108, global step 5450: 'train_loss' was not in top 1


Epoch 109: |          | 50/? [00:06<00:00,  8.29it/s, v_num=14, train_loss=0.0297]

Epoch 109, global step 5500: 'train_loss' was not in top 1


Epoch 110: |          | 50/? [00:06<00:00,  8.22it/s, v_num=14, train_loss=0.0296]

Epoch 110, global step 5550: 'train_loss' was not in top 1


Epoch 111: |          | 50/? [00:06<00:00,  8.24it/s, v_num=14, train_loss=0.030] 

Epoch 111, global step 5600: 'train_loss' was not in top 1


Epoch 112: |          | 50/? [00:06<00:00,  8.27it/s, v_num=14, train_loss=0.0296]

Epoch 112, global step 5650: 'train_loss' was not in top 1


Epoch 113: |          | 50/? [00:06<00:00,  8.23it/s, v_num=14, train_loss=0.0297]

Epoch 113, global step 5700: 'train_loss' was not in top 1


Epoch 114: |          | 50/? [00:06<00:00,  8.27it/s, v_num=14, train_loss=0.0294]

Epoch 114, global step 5750: 'train_loss' was not in top 1


Epoch 115: |          | 50/? [00:06<00:00,  8.25it/s, v_num=14, train_loss=0.0289]

Epoch 115, global step 5800: 'train_loss' reached 0.02889 (best 0.02889), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=115-step=5800.ckpt' as top 1


Epoch 116: |          | 50/? [00:06<00:00,  8.24it/s, v_num=14, train_loss=0.0295]

Epoch 116, global step 5850: 'train_loss' was not in top 1


Epoch 117: |          | 50/? [00:06<00:00,  8.20it/s, v_num=14, train_loss=0.0286]

Epoch 117, global step 5900: 'train_loss' reached 0.02857 (best 0.02857), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=117-step=5900.ckpt' as top 1


Epoch 118: |          | 50/? [00:06<00:00,  8.26it/s, v_num=14, train_loss=0.0286]

Epoch 118, global step 5950: 'train_loss' was not in top 1


Epoch 119: |          | 50/? [00:06<00:00,  8.30it/s, v_num=14, train_loss=0.029] 

Epoch 119, global step 6000: 'train_loss' was not in top 1


Epoch 120: |          | 50/? [00:06<00:00,  8.30it/s, v_num=14, train_loss=0.0291]

Epoch 120, global step 6050: 'train_loss' was not in top 1


Epoch 121: |          | 50/? [00:06<00:00,  8.21it/s, v_num=14, train_loss=0.0289]

Epoch 121, global step 6100: 'train_loss' was not in top 1


Epoch 122: |          | 50/? [00:06<00:00,  8.27it/s, v_num=14, train_loss=0.0297]

Epoch 122, global step 6150: 'train_loss' was not in top 1


Epoch 123: |          | 50/? [00:06<00:00,  8.21it/s, v_num=14, train_loss=0.0293]

Epoch 123, global step 6200: 'train_loss' was not in top 1


Epoch 124: |          | 50/? [00:06<00:00,  8.22it/s, v_num=14, train_loss=0.0304]

Epoch 124, global step 6250: 'train_loss' was not in top 1


Epoch 125: |          | 50/? [00:06<00:00,  7.88it/s, v_num=14, train_loss=0.0296]

Epoch 125, global step 6300: 'train_loss' was not in top 1


Epoch 126: |          | 50/? [00:06<00:00,  8.18it/s, v_num=14, train_loss=0.029] 

Epoch 126, global step 6350: 'train_loss' was not in top 1


Epoch 127: |          | 50/? [00:06<00:00,  8.27it/s, v_num=14, train_loss=0.0299]

Epoch 127, global step 6400: 'train_loss' was not in top 1


Epoch 128: |          | 50/? [00:06<00:00,  8.20it/s, v_num=14, train_loss=0.0292]

Epoch 128, global step 6450: 'train_loss' was not in top 1


Epoch 129: |          | 50/? [00:06<00:00,  8.27it/s, v_num=14, train_loss=0.0289]

Epoch 129, global step 6500: 'train_loss' was not in top 1


Epoch 130: |          | 50/? [00:06<00:00,  8.21it/s, v_num=14, train_loss=0.0291]

Epoch 130, global step 6550: 'train_loss' was not in top 1


Epoch 131: |          | 50/? [00:06<00:00,  8.21it/s, v_num=14, train_loss=0.0294]

Epoch 131, global step 6600: 'train_loss' was not in top 1


Epoch 132: |          | 50/? [00:06<00:00,  8.21it/s, v_num=14, train_loss=0.0294]

Epoch 132, global step 6650: 'train_loss' was not in top 1


Epoch 133: |          | 50/? [00:06<00:00,  8.17it/s, v_num=14, train_loss=0.0292]

Epoch 133, global step 6700: 'train_loss' was not in top 1


Epoch 134: |          | 50/? [00:06<00:00,  8.26it/s, v_num=14, train_loss=0.0301]

Epoch 134, global step 6750: 'train_loss' was not in top 1


Epoch 135: |          | 50/? [00:06<00:00,  8.22it/s, v_num=14, train_loss=0.0288]

Epoch 135, global step 6800: 'train_loss' was not in top 1


Epoch 136: |          | 50/? [00:06<00:00,  8.27it/s, v_num=14, train_loss=0.0285]

Epoch 136, global step 6850: 'train_loss' reached 0.02847 (best 0.02847), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=136-step=6850.ckpt' as top 1


Epoch 137: |          | 50/? [00:06<00:00,  8.16it/s, v_num=14, train_loss=0.0291]

Epoch 137, global step 6900: 'train_loss' was not in top 1


Epoch 138: |          | 50/? [00:06<00:00,  8.18it/s, v_num=14, train_loss=0.0297]

Epoch 138, global step 6950: 'train_loss' was not in top 1


Epoch 139: |          | 50/? [00:06<00:00,  8.17it/s, v_num=14, train_loss=0.0298]

Epoch 139, global step 7000: 'train_loss' was not in top 1


Epoch 140: |          | 50/? [00:06<00:00,  8.20it/s, v_num=14, train_loss=0.0287]

Epoch 140, global step 7050: 'train_loss' was not in top 1


Epoch 141: |          | 50/? [00:06<00:00,  8.16it/s, v_num=14, train_loss=0.0288]

Epoch 141, global step 7100: 'train_loss' was not in top 1


Epoch 142: |          | 50/? [00:06<00:00,  8.17it/s, v_num=14, train_loss=0.0294]

Epoch 142, global step 7150: 'train_loss' was not in top 1


Epoch 143: |          | 50/? [00:06<00:00,  8.19it/s, v_num=14, train_loss=0.0291]

Epoch 143, global step 7200: 'train_loss' was not in top 1


Epoch 144: |          | 50/? [00:06<00:00,  8.28it/s, v_num=14, train_loss=0.0284]

Epoch 144, global step 7250: 'train_loss' reached 0.02840 (best 0.02840), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=144-step=7250.ckpt' as top 1


Epoch 145: |          | 50/? [00:06<00:00,  8.24it/s, v_num=14, train_loss=0.0289]

Epoch 145, global step 7300: 'train_loss' was not in top 1


Epoch 146: |          | 50/? [00:06<00:00,  8.18it/s, v_num=14, train_loss=0.0292]

Epoch 146, global step 7350: 'train_loss' was not in top 1


Epoch 147: |          | 50/? [00:06<00:00,  8.27it/s, v_num=14, train_loss=0.0291]

Epoch 147, global step 7400: 'train_loss' was not in top 1


Epoch 148: |          | 50/? [00:06<00:00,  8.18it/s, v_num=14, train_loss=0.029] 

Epoch 148, global step 7450: 'train_loss' was not in top 1


Epoch 149: |          | 50/? [00:06<00:00,  8.23it/s, v_num=14, train_loss=0.0296]

Epoch 149, global step 7500: 'train_loss' was not in top 1


Epoch 150: |          | 50/? [00:06<00:00,  7.84it/s, v_num=14, train_loss=0.0285]

Epoch 150, global step 7550: 'train_loss' was not in top 1


Epoch 151: |          | 50/? [00:06<00:00,  8.23it/s, v_num=14, train_loss=0.029] 

Epoch 151, global step 7600: 'train_loss' was not in top 1


Epoch 152: |          | 50/? [00:06<00:00,  8.22it/s, v_num=14, train_loss=0.029]

Epoch 152, global step 7650: 'train_loss' was not in top 1


Epoch 153: |          | 50/? [00:06<00:00,  8.17it/s, v_num=14, train_loss=0.0292]

Epoch 153, global step 7700: 'train_loss' was not in top 1


Epoch 154: |          | 50/? [00:06<00:00,  8.18it/s, v_num=14, train_loss=0.0289]

Epoch 154, global step 7750: 'train_loss' was not in top 1


Epoch 155: |          | 50/? [00:06<00:00,  8.22it/s, v_num=14, train_loss=0.0292]

Epoch 155, global step 7800: 'train_loss' was not in top 1


Epoch 156: |          | 50/? [00:06<00:00,  8.26it/s, v_num=14, train_loss=0.0286]

Epoch 156, global step 7850: 'train_loss' was not in top 1


Epoch 157: |          | 50/? [00:06<00:00,  8.21it/s, v_num=14, train_loss=0.0286]

Epoch 157, global step 7900: 'train_loss' was not in top 1


Epoch 158: |          | 50/? [00:06<00:00,  8.21it/s, v_num=14, train_loss=0.0293]

Epoch 158, global step 7950: 'train_loss' was not in top 1


Epoch 159: |          | 50/? [00:06<00:00,  8.26it/s, v_num=14, train_loss=0.029] 

Epoch 159, global step 8000: 'train_loss' was not in top 1


Epoch 160: |          | 50/? [00:06<00:00,  8.23it/s, v_num=14, train_loss=0.029]

Epoch 160, global step 8050: 'train_loss' was not in top 1


Epoch 161: |          | 50/? [00:06<00:00,  8.24it/s, v_num=14, train_loss=0.029]

Epoch 161, global step 8100: 'train_loss' was not in top 1


Epoch 162: |          | 50/? [00:06<00:00,  8.19it/s, v_num=14, train_loss=0.0288]

Epoch 162, global step 8150: 'train_loss' was not in top 1


Epoch 163: |          | 50/? [00:06<00:00,  8.25it/s, v_num=14, train_loss=0.0291]

Epoch 163, global step 8200: 'train_loss' was not in top 1


Epoch 164: |          | 50/? [00:06<00:00,  8.25it/s, v_num=14, train_loss=0.0289]

Epoch 164, global step 8250: 'train_loss' was not in top 1


Epoch 165: |          | 50/? [00:06<00:00,  8.25it/s, v_num=14, train_loss=0.029] 

Epoch 165, global step 8300: 'train_loss' was not in top 1


Epoch 166: |          | 50/? [00:06<00:00,  8.22it/s, v_num=14, train_loss=0.0287]

Epoch 166, global step 8350: 'train_loss' was not in top 1


Epoch 167: |          | 50/? [00:06<00:00,  8.22it/s, v_num=14, train_loss=0.0284]

Epoch 167, global step 8400: 'train_loss' was not in top 1


Epoch 168: |          | 50/? [00:06<00:00,  8.05it/s, v_num=14, train_loss=0.0291]

Epoch 168, global step 8450: 'train_loss' was not in top 1


Epoch 169: |          | 50/? [00:06<00:00,  7.90it/s, v_num=14, train_loss=0.0288]

Epoch 169, global step 8500: 'train_loss' was not in top 1


Epoch 170: |          | 50/? [00:06<00:00,  7.98it/s, v_num=14, train_loss=0.0288]

Epoch 170, global step 8550: 'train_loss' was not in top 1


Epoch 171: |          | 50/? [00:06<00:00,  7.94it/s, v_num=14, train_loss=0.0289]

Epoch 171, global step 8600: 'train_loss' was not in top 1


Epoch 172: |          | 50/? [00:06<00:00,  7.84it/s, v_num=14, train_loss=0.0286]

Epoch 172, global step 8650: 'train_loss' was not in top 1


Epoch 173: |          | 50/? [00:06<00:00,  8.11it/s, v_num=14, train_loss=0.0288]

Epoch 173, global step 8700: 'train_loss' was not in top 1


Epoch 174: |          | 50/? [00:06<00:00,  7.76it/s, v_num=14, train_loss=0.0293]

Epoch 174, global step 8750: 'train_loss' was not in top 1


Epoch 175: |          | 50/? [00:06<00:00,  8.14it/s, v_num=14, train_loss=0.0291]

Epoch 175, global step 8800: 'train_loss' was not in top 1


Epoch 176: |          | 50/? [00:06<00:00,  8.17it/s, v_num=14, train_loss=0.0289]

Epoch 176, global step 8850: 'train_loss' was not in top 1


Epoch 177: |          | 50/? [00:06<00:00,  7.97it/s, v_num=14, train_loss=0.0292]

Epoch 177, global step 8900: 'train_loss' was not in top 1


Epoch 178: |          | 50/? [00:06<00:00,  8.06it/s, v_num=14, train_loss=0.0286]

Epoch 178, global step 8950: 'train_loss' was not in top 1


Epoch 179: |          | 50/? [00:06<00:00,  8.15it/s, v_num=14, train_loss=0.0293]

Epoch 179, global step 9000: 'train_loss' was not in top 1


Epoch 180: |          | 50/? [00:06<00:00,  8.20it/s, v_num=14, train_loss=0.0285]

Epoch 180, global step 9050: 'train_loss' was not in top 1


Epoch 181: |          | 50/? [00:06<00:00,  8.18it/s, v_num=14, train_loss=0.0286]

Epoch 181, global step 9100: 'train_loss' was not in top 1


Epoch 182: |          | 50/? [00:06<00:00,  8.26it/s, v_num=14, train_loss=0.0288]

Epoch 182, global step 9150: 'train_loss' was not in top 1


Epoch 183: |          | 50/? [00:06<00:00,  8.19it/s, v_num=14, train_loss=0.0287]

Epoch 183, global step 9200: 'train_loss' was not in top 1


Epoch 184: |          | 50/? [00:06<00:00,  8.18it/s, v_num=14, train_loss=0.0291]

Epoch 184, global step 9250: 'train_loss' was not in top 1


Epoch 185: |          | 50/? [00:06<00:00,  8.23it/s, v_num=14, train_loss=0.0286]

Epoch 185, global step 9300: 'train_loss' was not in top 1


Epoch 186: |          | 50/? [00:06<00:00,  8.22it/s, v_num=14, train_loss=0.0285]

Epoch 186, global step 9350: 'train_loss' was not in top 1


Epoch 187: |          | 50/? [00:06<00:00,  8.25it/s, v_num=14, train_loss=0.0287]

Epoch 187, global step 9400: 'train_loss' was not in top 1


Epoch 188: |          | 50/? [00:06<00:00,  8.19it/s, v_num=14, train_loss=0.0289]

Epoch 188, global step 9450: 'train_loss' was not in top 1


Epoch 189: |          | 50/? [00:06<00:00,  8.19it/s, v_num=14, train_loss=0.0287]

Epoch 189, global step 9500: 'train_loss' was not in top 1


Epoch 190: |          | 50/? [00:06<00:00,  8.02it/s, v_num=14, train_loss=0.0291]

Epoch 190, global step 9550: 'train_loss' was not in top 1


Epoch 191: |          | 50/? [00:06<00:00,  8.11it/s, v_num=14, train_loss=0.0291]

Epoch 191, global step 9600: 'train_loss' was not in top 1


Epoch 192: |          | 50/? [00:06<00:00,  8.09it/s, v_num=14, train_loss=0.0287]

Epoch 192, global step 9650: 'train_loss' was not in top 1


Epoch 193: |          | 50/? [00:06<00:00,  8.05it/s, v_num=14, train_loss=0.0283]

Epoch 193, global step 9700: 'train_loss' reached 0.02833 (best 0.02833), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=193-step=9700.ckpt' as top 1


Epoch 194: |          | 50/? [00:06<00:00,  8.25it/s, v_num=14, train_loss=0.0283]

Epoch 194, global step 9750: 'train_loss' was not in top 1


Epoch 195: |          | 50/? [00:06<00:00,  8.21it/s, v_num=14, train_loss=0.0285]

Epoch 195, global step 9800: 'train_loss' was not in top 1


Epoch 196: |          | 50/? [00:06<00:00,  8.23it/s, v_num=14, train_loss=0.0288]

Epoch 196, global step 9850: 'train_loss' was not in top 1


Epoch 197: |          | 50/? [00:06<00:00,  8.22it/s, v_num=14, train_loss=0.0289]

Epoch 197, global step 9900: 'train_loss' was not in top 1


Epoch 198: |          | 50/? [00:06<00:00,  8.28it/s, v_num=14, train_loss=0.0278]

Epoch 198, global step 9950: 'train_loss' reached 0.02781 (best 0.02781), saving model to '/home/jupyter/work/resources/lightning_logs/version_14/checkpoints/epoch=198-step=9950.ckpt' as top 1


Epoch 199: |          | 50/? [00:06<00:00,  7.74it/s, v_num=14, train_loss=0.029] 

Epoch 199, global step 10000: 'train_loss' was not in top 1
`Trainer.fit` stopped: `max_epochs=200` reached.


Epoch 199: |          | 50/? [00:06<00:00,  7.74it/s, v_num=14, train_loss=0.029]


In [9]:
forecast_it, ts_it = make_evaluation_predictions(
    dataset=dataset_test, predictor=predictor, num_samples=100
)
forecasts = list(forecast_it)
targets = list(ts_it)
agg_metric, _ = evaluator(targets, forecasts, num_series=len(dataset_test))

Running evaluation: 7it [00:00, 69.94it/s]
Running evaluation: 7it [00:00, 67.76it/s]
Running evaluation: 7it [00:00, 69.25it/s]
Running evaluation: 7it [00:00, 68.53it/s]
Running evaluation: 7it [00:00, 67.48it/s]
Running evaluation: 7it [00:00, 70.72it/s]
Running evaluation: 7it [00:00, 66.11it/s]
Running evaluation: 7it [00:00, 66.36it/s]
Running evaluation: 7it [00:00, 68.45it/s]
Running evaluation: 7it [00:00, 63.15it/s]
Running evaluation: 7it [00:00, 68.92it/s]
Running evaluation: 7it [00:00, 71.08it/s]
Running evaluation: 7it [00:00, 66.88it/s]
Running evaluation: 7it [00:00, 70.50it/s]
Running evaluation: 7it [00:00, 66.70it/s]
Running evaluation: 7it [00:00, 68.93it/s]
Running evaluation: 7it [00:00, 66.34it/s]
Running evaluation: 7it [00:00, 63.02it/s]
Running evaluation: 7it [00:00, 69.79it/s]
Running evaluation: 7it [00:00, 62.69it/s]
Running evaluation: 7it [00:00, 64.15it/s]
Running evaluation: 7it [00:00, 69.20it/s]
Running evaluation: 7it [00:00, 59.42it/s]
Running eva

In [10]:
print('Timegrad on regular dataset metrics:')
print("CRPS: {}".format(agg_metric["mean_wQuantileLoss"]))
print("ND: {}".format(agg_metric["ND"]))
print("NRMSE: {}".format(agg_metric["NRMSE"]))
print("MSE: {}".format(agg_metric["MSE"]))

print('')

print("CRPS-Sum: {}".format(agg_metric["m_sum_mean_wQuantileLoss"]))
print("ND-Sum: {}".format(agg_metric["m_sum_ND"]))
print("NRMSE-Sum: {}".format(agg_metric["m_sum_NRMSE"]))
print("MSE-Sum: {}".format(agg_metric["m_sum_MSE"]))

Timegrad on regular dataset metrics:
CRPS: 0.054710504611443216
ND: 0.0690697753085394
NRMSE: 0.6498842557863022
MSE: 241902.15123054018

CRPS-Sum: 0.02233674123229429
ND-Sum: 0.030012125785324334
NRMSE-Sum: 0.040314983705848954
MSE-Sum: 127439523.04761903


In [11]:
"""
with open('/home/jupyter/datasphere/project/models/electricity.pkl', 'wb') as f:
    pickle.dump(predictor, f)
"""

"\nwith open('/home/jupyter/datasphere/project/models/electricity.pkl', 'wb') as f:\n    pickle.dump(predictor, f)\n"

## Conduct Experiments

In [12]:
"""
# If you want to load already prepared model

transformation = estimator.create_transformation()
lightning_module = estimator.create_lightning_module()
predictor = estimator.create_predictor(transformation, lightning_module)
with open('/home/jupyter/work/resources/solar/model.pkl', 'rb') as f:
    predictor = pickle.load(f)
"""

"\n# If you want to load already prepared model\n\ntransformation = estimator.create_transformation()\nlightning_module = estimator.create_lightning_module()\npredictor = estimator.create_predictor(transformation, lightning_module)\nwith open('/home/jupyter/work/resources/solar/model.pkl', 'rb') as f:\n    predictor = pickle.load(f)\n"

In [13]:
def interpolate_points(dataset, num_points: int = 5):
    new_dataset = []
    with Pool(processes=cpu_count()) as pool:
        results = pool.starmap(interpolate_np_array, [(ts['target'], num_points) for ts in dataset])
        for i, ts in enumerate(dataset):
            mask = np.isnan(ts['target'])
            new_dataset.append({
                'target': ts['target'] * (1. - mask) + results[i] * mask,
                'start': ts['start'],
                'feat_static_cat': ts['feat_static_cat']
            })
    return new_dataset

In [19]:
crps_sum_results = pd.DataFrame(columns=range(10))

with tqdm(total=9*10) as pbar:
    for p in np.linspace(0.1, 1, 9, endpoint=False): # frac of removed points
        dataset_test_p = omit_points(dataset_test, p)
        results_p = []

        # No interpolation
        forecast_it, ts_it = make_evaluation_predictions(
            dataset=dataset_test_p, 
            predictor=predictor, 
            num_samples=100,
        )
        forecasts = list(forecast_it)
        # targets = list(ts_it)  # Do not uncomment, we need targets without NaNs, so we use ones calculated above.
        agg_metric, _ = evaluator(targets, forecasts, num_series=len(dataset_test))
        results_p.append(agg_metric['m_sum_mean_wQuantileLoss'])
        pbar.update(1)

        # Interpolation with num_points
        for num_points in range(1, 10):
            forecast_it, ts_it = make_evaluation_predictions(
                dataset=interpolate_points(dataset_test_p, num_points), 
                predictor=predictor, 
                num_samples=100,
            )
            forecasts = list(forecast_it)
            # targets = list(ts_it)  # Do not uncomment, we need targets without NaNs, so we use ones calculated above.
            agg_metric, _ = evaluator(targets, forecasts, num_series=len(dataset_test))
            results_p.append(agg_metric['m_sum_mean_wQuantileLoss'])
            pbar.update(1)

        crps_sum_results.loc[p] = results_p

  0%|          | 0/90 [00:00<?, ?it/s]
Running evaluation: 7it [00:00, 66.66it/s]

Running evaluation: 7it [00:00, 72.28it/s]

Running evaluation: 7it [00:00, 71.95it/s]

Running evaluation: 7it [00:00, 66.21it/s]

Running evaluation: 7it [00:00, 69.07it/s]

Running evaluation: 7it [00:00, 63.31it/s]

Running evaluation: 7it [00:00, 69.06it/s]

Running evaluation: 7it [00:00, 72.29it/s]

Running evaluation: 7it [00:00, 66.89it/s]

Running evaluation: 7it [00:00, 71.57it/s]

Running evaluation: 7it [00:00, 64.91it/s]

Running evaluation: 7it [00:00, 70.68it/s]

Running evaluation: 7it [00:00, 70.86it/s]

Running evaluation: 7it [00:00, 67.71it/s]

Running evaluation: 7it [00:00, 70.23it/s]

Running evaluation: 7it [00:00, 66.56it/s]

Running evaluation: 7it [00:00, 66.81it/s]

Running evaluation: 7it [00:00, 70.64it/s]

Running evaluation: 7it [00:00, 67.94it/s]

Running evaluation: 7it [00:00, 69.84it/s]

Running evaluation: 7it [00:00, 69.35it/s]

Running evaluation: 7it [00:00, 65.46

In [20]:
crps_sum_results

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
0.1,0.027086,0.027214,0.026956,0.0269,0.027509,0.027209,0.027129,0.026895,0.02741,0.027262
0.2,0.028681,0.028619,0.028953,0.028961,0.028329,0.028718,0.029221,0.028905,0.028679,0.028268
0.3,0.04026,0.041241,0.04071,0.040827,0.040772,0.040829,0.040282,0.040518,0.040399,0.040977
0.4,0.049984,0.0496,0.049526,0.049539,0.048798,0.049651,0.04916,0.049208,0.049556,0.049113
0.5,0.218451,0.21794,0.218166,0.218247,0.218441,0.218621,0.218975,0.218295,0.218798,0.21799
0.6,0.405616,0.403999,0.405476,0.404888,0.404395,0.405973,0.405904,0.405156,0.404947,0.405183
0.7,0.365148,0.363907,0.364438,0.364306,0.36375,0.363791,0.363968,0.363736,0.364003,0.364552
0.8,0.380127,0.380777,0.380521,0.380265,0.379935,0.380647,0.380733,0.380244,0.380624,0.380668
0.9,0.318306,0.31788,0.316428,0.31736,0.31905,0.319102,0.318317,0.316586,0.315723,0.320595
