In [1]:
# Notebook setup: load the autoreload extension and set it to mode 2, so that
# edits to imported modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
import sagemaker
import boto3
import json
import tempfile

import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm

import mxnet as mx
from mxnet import gluon
from gluonts.dataset.common import ListDataset
from gluonts.dataset.loader import (
    TrainDataLoader, ValidationDataLoader, InferenceDataLoader
)
from gluonts.dataset.repository.datasets import get_dataset
from gluonts.model.deepar import DeepAREstimator
from gluonts.trainer import Trainer
from gluonts.support.util import get_hybrid_forward_input_names, HybridContext
from gluonts.trainer import learning_rate_scheduler as lrs

In [3]:
from gluonts.model.san import *

In [None]:
# Fetch the "electricity" dataset from the GluonTS dataset repository; the
# estimator cell below reads its training split via `data.train`.
data = get_dataset(dataset_name="electricity")

In [4]:
# Configure the self-attention forecaster. Only a single short epoch is
# requested and hybridization is disabled, so the network runs imperatively.
estimator = SelfAttentionEstimator(
    # --- forecasting horizon / history window (in units of `freq`) ---
    freq="h",
    context_length=168,
    prediction_length=24,
    # --- network size ---
    num_layers=3,
    num_heads=4,
    model_dim=64,
    ffn_dim_multiplier=2,
    kernel_sizes=[5, 9],
    distance_encoding="dot",
    num_outputs=3,
    # --- input features: only the static categorical feature is enabled ---
    use_feat_static_cat=True,
    use_feat_static_real=False,
    use_feat_dynamic_cat=False,
    use_feat_dynamic_real=False,
    # NOTE(review): presumably the number of distinct series ids for the one
    # static categorical feature -- confirm against the dataset metadata.
    cardinalities=[370],
    # --- optimisation settings ---
    trainer=Trainer(
        hybridize=False,
        epochs=1,
        learning_rate=1e-4,
    ),
)

# Pre-processing pipeline that the estimator applies to raw dataset entries.
transformation = estimator.create_transformation()

# Batch iterator over the training split; batch size, batches per epoch,
# device context and dtype are all taken from the estimator / its trainer.
training_data_loader = TrainDataLoader(
    dataset=data.train,
    transform=transformation,
    dtype=estimator.dtype,
    ctx=estimator.trainer.ctx,
    batch_size=estimator.trainer.batch_size,
    num_batches_per_epoch=estimator.trainer.num_batches_per_epoch,
    shuffle_buffer_length=None,
    num_prefetch=None,
    num_workers=None,
)

In [5]:
# Hand-written training loop for the self-attention network, so that the
# per-batch loss can be inspected directly while the model trains.

# Build the training network under the trainer's device context and
# initialise its parameters on that same context.
with estimator.trainer.ctx:
    train_net = estimator.create_training_network()
train_net.initialize(ctx=estimator.trainer.ctx)

# Keys used below to pull the network's inputs out of each batch dict.
input_names = get_hybrid_forward_input_names(train_net)

with HybridContext(
    net=train_net,
    hybridize=False,
    static_alloc=True,
    static_shape=True,
):
    batch_size = training_data_loader.batch_size

    # Learning-rate schedule driven by the epoch loss (lower is better);
    # its `step` return value decides whether training should go on.
    lr_scheduler = lrs.MetricAttentiveScheduler(
        objective="min",
        patience=estimator.trainer.patience,
        decay_factor=estimator.trainer.learning_rate_decay_factor,
        min_lr=estimator.trainer.minimum_learning_rate,
    )
    optimizer = mx.optimizer.Adam(
        learning_rate=estimator.trainer.learning_rate,
        lr_scheduler=lr_scheduler,
        wd=estimator.trainer.weight_decay,
        clip_gradient=estimator.trainer.clip_gradient,
    )
    trainer = mx.gluon.Trainer(
        train_net.collect_params(),
        optimizer=optimizer,
        kvstore="device",
    )

    # Set once the running epoch loss stops being finite; training is then
    # aborted instead of continuing to update the parameters with NaN/inf
    # gradients and feeding a non-finite metric to the LR scheduler.
    loss_diverged = False

    for epoch_no in range(estimator.trainer.epochs):
        if estimator.trainer.halt:
            break
        # Not used by the loop itself; kept so the current learning rate can
        # be inspected from the notebook after the cell has run.
        curr_lr = trainer.learning_rate
        epoch_loss = mx.metric.Loss()
        for batch_no, data_entry in enumerate(training_data_loader, 1):
            if estimator.trainer.halt:
                break
            inputs = [data_entry[k] for k in input_names]
            # Only the forward pass needs to be recorded for autograd.
            with mx.autograd.record():
                output = train_net(*inputs)
            # The network may return several outputs; the first is the loss.
            if isinstance(output, (list, tuple)):
                loss = output[0]
            else:
                loss = output
            # Debug output: raw loss values of the current batch.
            print(loss.asnumpy())
            loss.backward()
            trainer.step(batch_size)
            epoch_loss.update(None, preds=loss)
            lv = epoch_loss.get_name_value()[0][1]
            if not np.isfinite(lv):
                print(f"Epoch{epoch_no} gave nan loss")
                loss_diverged = True
                break
        if loss_diverged:
            # Report once and stop; do not step the scheduler with NaN/inf.
            print("Stopping training")
            break
        should_continue = lr_scheduler.step(epoch_loss.get_name_value()[0][1])
        if not should_continue:
            print("Stopping training")
            break