In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import sagemaker
import boto3
import json
import tempfile
import inspect

import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm

import mxnet as mx
from mxnet import gluon
from gluonts.dataset.common import ListDataset
from gluonts.dataset.loader import (
    TrainDataLoader, ValidationDataLoader, InferenceDataLoader
)
from gluonts.dataset.repository.datasets import get_dataset
from gluonts.model.deepar import DeepAREstimator
from gluonts.trainer import Trainer
from gluonts.support.util import get_hybrid_forward_input_names, HybridContext
from gluonts.trainer import learning_rate_scheduler as lrs

In [3]:
from gluonts.model.tft import *

In [4]:
data = get_dataset('electricity')

In [5]:
estimator = TemporalFusionTransformerEstimator(
    freq='h',
    prediction_length=24,
    context_length=168,
    hidden_dim=64,
    variable_dim=16,
    num_heads=4,
    num_outputs=3,
    static_cardinalities={'feat_static_cat': 321},
)
transformation = estimator.create_transformation()
training_data_loader = TrainDataLoader(
    dataset=data.train,
    transform=transformation,
    batch_size=estimator.trainer.batch_size,
    num_batches_per_epoch=estimator.trainer.num_batches_per_epoch,
    ctx=estimator.trainer.ctx,
    dtype=estimator.dtype,
    num_workers=None,
    num_prefetch=None,
    shuffle_buffer_length=None,
)

In [6]:
with estimator.trainer.ctx:
    train_net = estimator.create_training_network()
train_net.initialize(ctx=estimator.trainer.ctx)
input_names = get_hybrid_forward_input_names(train_net)
with HybridContext(
    net=train_net,
    hybridize=False,
    static_alloc=True,
    static_shape=True,
):
    batch_size = training_data_loader.batch_size
    lr_scheduler = lrs.MetricAttentiveScheduler(
        objective="min",
        patience=estimator.trainer.patience,
        decay_factor=estimator.trainer.learning_rate_decay_factor,
        min_lr=estimator.trainer.minimum_learning_rate,
    )
    optimizer = mx.optimizer.Adam(
        learning_rate=estimator.trainer.learning_rate,
        lr_scheduler=lr_scheduler,
        wd=estimator.trainer.weight_decay,
        clip_gradient=estimator.trainer.clip_gradient,
    )
    trainer = mx.gluon.Trainer(
        train_net.collect_params(),
        optimizer=optimizer,
        kvstore="device",
    )
    
    for epoch_no in range(estimator.trainer.epochs):
        if estimator.trainer.halt:
            break
        curr_lr = trainer.learning_rate
        epoch_loss = mx.metric.Loss()
        for batch_no, data_entry in enumerate(training_data_loader, 1):
            if estimator.trainer.halt:
                break
            args = inspect.signature(train_net.hybrid_forward).parameters
            inputs = []
            for n, (name, arg) in enumerate(args.items()):
                if n == 0:
                    if name == 'F':
                        continue
                    else:
                        raise RuntimeError(
                            f"Expected first argument of HybridBlock to be `F`, "
                            f"but found `{param_names[0]}`"
                        )
                if name in data_entry:
                    inputs.append(data_entry[name])
                elif not (arg.default is inspect._empty):
                    inputs.append(arg.default)
                else:
                    raise RuntimeError(
                        f"The value of argument `{name}` of HybridBlock is not provided, "
                        f"and no default value is given."
                    )
            with mx.autograd.record():
                output = train_net(*inputs)
                if isinstance(output, (list, tuple)):
                    loss = output[0]
                else:
                    loss = output
                print(loss.asnumpy())
            loss.backward()
            trainer.step(batch_size)
            epoch_loss.update(None, preds=loss)
            lv = epoch_loss.get_name_value()[0][1]
            if not np.isfinite(lv):
                print(f"Epoch{epoch_no} gave nan loss")
        should_continue = lr_scheduler.step(epoch_loss.get_name_value()[0][1]) 
        if not should_continue:
            print("Stopping training")
            break

learning rate from ``lr_scheduler`` has been overwritten by ``learning_rate`` in optimizer.
[3.401213]
[2.83992]
[1.7574481]
[4.8020363]
[3.8611]


KeyboardInterrupt: 

In [7]:
it =estimator.create_predictor(transformation, train_net).predict(data.test)

In [8]:
next(it)

QuantileForecast(array([[15.701146 , 15.933726 , 16.086489 , 16.190166 , 14.944334 ,
        15.047554 , 15.084908 , 15.078518 , 15.045128 , 14.997371 ,
        14.946486 , 14.905204 , 14.888753 , 14.913998 , 14.991172 ,
        15.114402 , 15.261231 , 15.405194 , 15.535284 , 15.6496935,
        15.747926 , 15.831028 , 15.901478 , 15.961993 ],
       [-4.3144903, -4.0334826, -3.7718096, -3.5304308, -6.078389 ,
        -6.512109 , -6.87024  , -7.212331 , -7.5642233, -7.9372187,
        -8.331906 , -8.733825 , -9.10871  , -9.396594 , -9.514503 ,
        -9.386939 , -9.001687 , -8.421546 , -7.7481394, -7.073148 ,
        -6.441493 , -5.8694124, -5.3613176, -4.913842 ],
       [42.326935 , 42.24982  , 42.07966  , 41.872448 , 37.762405 ,
        38.253376 , 38.763416 , 39.31705  , 39.9351   , 40.635895 ,
        41.433567 , 42.331944 , 43.313763 , 44.324123 , 45.256214 ,
        45.965775 , 46.332767 , 46.32603  , 46.02235  , 45.543823 ,
        44.99133  , 44.428432 , 43.891773 , 43.397438