<img width="30%" src="https://ts.gluon.ai/dev/_static/gluonts.svg" alt="GluonTS logo" style="display: block; margin-left: auto; margin-right: auto;">

# Lab 1A: GluonTS

[GluonTS](https://ts.gluon.ai/stable/) is a Python library for probabilistic time series modeling, with a focus on deep learning-based approaches. 

First introduced in a [JMLR paper](https://www.jmlr.org/papers/volume21/19-820/19-820.pdf), it provides a toolkit for tasks such as forecasting and anomaly detection that simplifies developing and experimenting with time series models. GluonTS supports both PyTorch and MXNet implementations and offers a modular, scalable design suited to both experimentation and production use.

Refer to the blog post [Creating neural time series models with Gluon Time Series](https://aws.amazon.com/blogs/machine-learning/creating-neural-time-series-models-with-gluon-time-series/) for an introduction to GluonTS.

The library includes essential components like neural network architectures for sequences, feature processing steps, and [evaluation](https://ts.gluon.ai/stable/api/gluonts/gluonts.evaluation.html). It also comes with pre-built implementations of state-of-the-art [models](https://ts.gluon.ai/stable/getting_started/models.html), allowing for easy benchmarking and comparison. GluonTS supports various [data formats](https://ts.gluon.ai/stable/api/gluonts/gluonts.dataset.html) and provides [data loading](https://ts.gluon.ai/stable/api/gluonts/gluonts.dataset.loader.html) and iteration capabilities, making it suitable for handling large-scale time series datasets. Whether you're a scientist developing new models or a practitioner looking for out-of-the-box solutions, GluonTS offers the flexibility and tools needed to tackle complex time series problems.

## Import packages

In [None]:
# downgrade sentencepiece to 0.1.99 because it causes incompatibility issues in SMD 2.0
# this is fixed in SMD >= 2.1.0
# %pip -q install sentencepiece==0.1.99

In [None]:
%pip -q install --upgrade seaborn orjson statsmodels gluonts[mxnet] gluonts[Prophet]

In [None]:
# Uncomment if you need to restart kernel to get the packages
# import IPython
# IPython.Application.instance().kernel.do_shutdown(True)

In [None]:
# you need gluonts >= 0.15.1 otherwise DeepAR is not going to work
%pip show gluonts

In [None]:
%matplotlib inline

import json
import os
import zipfile
from time import gmtime, strftime, sleep
import random
import seaborn as sns
import matplotlib.pyplot as plt
import boto3
import sagemaker
import tqdm
import numpy as np
import pandas as pd
from statsmodels.tsa.seasonal import seasonal_decompose
from statsmodels.tsa.stattools import adfuller, kpss, acf, pacf
from statsmodels.stats.diagnostic import acorr_ljungbox
from scipy import stats
from scipy.stats import normaltest
import warnings
warnings.filterwarnings('ignore')
import matplotlib.colors as mcolors
from itertools import islice
import ipywidgets as widgets
from ipywidgets import (
    interact, interactive, fixed, interact_manual,
    IntSlider, Checkbox, Dropdown, DatePicker, Select, SelectMultiple,
)

In [None]:
# setup plt environment
plt.rcParams["axes.grid"] = True
plt.rcParams["figure.figsize"] = (20, 3)
colors = list(mcolors.TABLEAU_COLORS)

## Set literals and general variables

In [None]:
sagemaker_session = sagemaker.Session()
region = sagemaker_session.boto_region_name

In [None]:
s3_bucket = sagemaker_session.default_bucket()  # replace with an existing bucket if needed
s3_prefix = "gluonts-demo-notebook"  # prefix used for all data stored within the bucket
experiment_prefix = "gluonts"
extract_to_path = '../data'

sm_role = sagemaker.get_execution_role()  # IAM role to use by SageMaker

In [None]:
# get domain_id and user profile name
NOTEBOOK_METADATA_FILE = "/opt/ml/metadata/resource-metadata.json"
domain_id = None

if os.path.exists(NOTEBOOK_METADATA_FILE):
    with open(NOTEBOOK_METADATA_FILE, "rb") as f:
        domain_id = json.loads(f.read()).get('DomainId')
        print(f"SageMaker domain id: {domain_id}")

## Download the dataset

Download the dataset from the SageMaker example S3 bucket. You use the [electricity dataset](https://archive.ics.uci.edu/ml/datasets/ElectricityLoadDiagrams20112014) from the repository of the University of California, Irvine:
> Trindade, Artur. (2015). ElectricityLoadDiagrams20112014. UCI Machine Learning Repository. https://doi.org/10.24432/C58C86.

In [None]:
os.makedirs(extract_to_path, exist_ok=True)

In [None]:
dataset_zip_file_name = 'LD2011_2014.txt.zip'
dataset_path = f'{extract_to_path}/LD2011_2014.txt'

s3_dataset_path = f"datasets/timeseries/uci_electricity/{dataset_zip_file_name}"

In [None]:
if not os.path.isfile(dataset_path):
    print(f'Downloading and unzipping the dataset to {dataset_path}')
    s3_client = boto3.client("s3")
    s3_client.download_file(
        f"sagemaker-example-files-prod-{region}", s3_dataset_path, f"{extract_to_path}/{dataset_zip_file_name}"
    )

    with zipfile.ZipFile(f"{extract_to_path}/{dataset_zip_file_name}", "r") as zip_ref:
        total_size = sum(file.file_size for file in zip_ref.infolist())

        with tqdm.tqdm(total=total_size, unit='B', unit_scale=True, desc="Extracting") as pbar:
            for file in zip_ref.infolist():
                zip_ref.extract(file, extract_to_path)
                pbar.update(file.file_size)
        
    dataset_path = '.'.join(zip_ref.filename.split('.')[:-1])
else:
    print(f'The dataset {dataset_path} exists, skipping download and unzip!')

In [None]:
# see what is inside the file
# !head -n 2 {dataset_path} 

## Explore and preprocess data

### Load into a DataFrame and resample

In [None]:
df_raw = pd.read_csv(
    dataset_path, 
    sep=';', 
    index_col=0,
    decimal=',',
    parse_dates=True,
)

In [None]:
df_raw

In [None]:
# resample to 1h intervals
freq = "1h"
div = 4  # 1 hour contains 4x 15-min intervals, so you divide the resampled sum by 4
num_timeseries = df_raw.shape[1]
data_kw = df_raw.resample(freq).sum() / div
timeseries = []

for i in tqdm.trange(num_timeseries):
    timeseries.append(np.trim_zeros(data_kw.iloc[:, i], trim="f"))

In [None]:
data_kw
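
Dividing the hourly sum by 4 is the same as taking the hourly mean, as long as every hour contains all four 15-minute readings. The following minimal sketch checks this on a small synthetic series. The values and variable names are made up for illustration and are not part of the dataset.

```python
import pandas as pd

# Synthetic 15-minute readings covering two full hours (illustrative values only)
index = pd.date_range("2014-01-01 00:00", periods=8, freq="15min")
load_15min = pd.Series([4.0, 8.0, 4.0, 8.0, 0.0, 0.0, 10.0, 10.0], index=index)

# Same operation as in the notebook: sum the four readings per hour and divide by 4
hourly_from_sum = load_15min.resample("1h").sum() / 4

# Equivalent formulation when no readings are missing
hourly_mean = load_15min.resample("1h").mean()

print(hourly_from_sum.tolist())
print(hourly_mean.tolist())
```

If some readings are missing, the two results differ: `sum() / 4` effectively treats a missing reading as zero, while `mean()` averages only the readings that are present.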

### Visualize time series

In [None]:
def plot_timeseries(timeseries, start_time, length):
    n_cols = 2
    n_rows = (len(timeseries) + 1)//2
    
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, 4*n_rows), sharex=True)
    axx = axs.ravel()
    for i, ts in tqdm.tqdm(enumerate(timeseries), total=len(timeseries), desc="Creating plots"):
        series = ts.loc[start_time:start_time + length*ts.index.freq]
        if len(series): series.plot(ax=axx[i])

        axx[i].set_xlabel("date")
        axx[i].set_ylabel(f"kW consumption - {ts.name} - {ts.index.freq}")
        axx[i].grid(which="minor", axis="x")
    
    plt.tight_layout()
    plt.show()

In [None]:
style = {"description_width": "initial"}
ts_id_list = [ts.name for ts in timeseries]
show_start_date = pd.Timestamp("2014-12-01")
time_step = timeseries[0].index.freq
max_samples = 64

Use the interactive plotting to visualize time series. You can change the following parameters:
- `Time series ids`: ids of the time series in the full dataset. You can select multiple time series to predict and to plot
- `Show from`: start of the displayed interval  
- `Length`: how many time steps are displayed starting from `Show from`
- `Random samples` and `Number of samples`: use these controls to show a random sample of the specified size from the time series dataset

In [None]:
@interact_manual(
    ts_ids=SelectMultiple(options=ts_id_list, value=[ts_id_list[0]], rows=5, style=style, description='Time series ids:'),
    start_date=DatePicker(value=show_start_date, style=style, description='Show from:'),
    length=IntSlider(min=1, max=730, value=100, style=style, description=f'Length in {time_step}:'),
    random_samples=Checkbox(value=False, description='Random samples'),
    num_samples=IntSlider(min=1, max=min(max_samples, len(ts_id_list)), value=min(10,len(ts_id_list)), style=style, description='Number of samples:'),
    continuous_update=False,
)
def plot_interact(ts_ids, start_date, length, random_samples, num_samples):
    ids = random.sample(ts_id_list, num_samples) if random_samples else ts_ids
    plot_timeseries([ts for ts in timeseries if ts.name in ids], start_date, length)    

### Optional: analyse time series
In this section you analyse time series by performing common operations such as autocorrelation analysis, stationarity detection, and trend and seasonality period calculations.

More specifically, the following code covers:

**Basic statistics:**
- Mean, median, standard deviation
- Skewness and kurtosis
- Missing value detection
- Data frequency and span
  
**Stationarity analysis:**
- Augmented Dickey-Fuller test
- KPSS test
- Rolling statistics
  
**Distribution analysis:**
- Normality tests
- Quantile analysis
- Histogram generation

**Changepoint detection:**
- Moving average based detection

**Outlier detection:**
- Z-score method
- IQR method

**Correlation analysis:**
- ACF (Autocorrelation Function)
- PACF (Partial Autocorrelation Function)
- Significant lag identification
  
**Cyclical pattern analysis:**
- Spectral analysis using FFT
- Dominant frequency identification
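
To make the two outlier rules from the list above concrete, here is a minimal sketch that uses only the Python standard library on made-up values. It assumes a population standard deviation for the z-score (the default of `scipy.stats.zscore`) and linearly interpolated quartiles (the pandas default); it is an illustration, not the analyzer code itself.

```python
from statistics import mean, pstdev, quantiles

# Made-up readings: mostly around 10, one mild and one extreme deviation
values = [10, 11, 9, 10, 13, 10, 11, 9, 10, 11, 10, 9, 50]

# Z-score rule: flag points more than 3 population standard deviations from the mean
mu, sigma = mean(values), pstdev(values)
z_outliers = [i for i, v in enumerate(values) if abs(v - mu) / sigma > 3]

# IQR rule: flag points outside [Q1 - 1.5 * IQR, Q3 + 1.5 * IQR]
q1, _, q3 = quantiles(values, n=4, method="inclusive")
iqr = q3 - q1
iqr_outliers = [
    i for i, v in enumerate(values)
    if v < q1 - 1.5 * iqr or v > q3 + 1.5 * iqr
]

print("z-score outlier positions:", z_outliers)
print("IQR outlier positions:", iqr_outliers)
```

On this sample the IQR rule also flags the mild deviation, because the extreme value inflates the standard deviation used by the z-score rule but barely moves the quartiles.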

In [None]:
class TimeSeriesAnalyzer:
    """
    Basic time series analysis example
    """
    def __init__(self, series, freq='H', title=None):
        self.series = series
        self.freq = freq
        self.title = title or f'Analysis for time series {series.name}'
        self.results = {}
        self.expected_periods = { # Set analysis periods based on frequency
            'H': [24, 168, 730],  # daily, weekly, monthly
            'D': [7, 30, 365],    # weekly, monthly, yearly
            'M': [12]             # yearly
        }[self.freq]
        
    def run_full_analysis(self):
        """
        Run all available analyses.
        """
        self.results = {
            'basic_stats': self.calculate_basic_stats(),
            'stationarity': self.check_stationarity(),
            'seasonality': self.analyze_seasonality(),
            'distribution': self.analyze_distribution(),
            'changepoints': self.detect_changepoints(),
            'outliers': self.detect_outliers(),
            'cyclical': self.analyze_cyclical_patterns(),
            'autocorrelation': self.analyze_autocorrelation()
        }
        return self.results
    
    def calculate_basic_stats(self):
        """
        Calculate basic time series statistics.
        """
        return {
            'mean': self.series.mean(),
            'median': self.series.median(),
            'std': self.series.std(),
            'skewness': self.series.skew(),
            'kurtosis': self.series.kurtosis(),
            'missing_values': self.series.isnull().sum(),
            'length': len(self.series),
            'start_date': self.series.index.min(),
            'end_date': self.series.index.max(),
            'frequency': self.freq
        }
    
    def check_stationarity(self):
        """
        Perform stationarity analysis.
        """
        # ADF Test
        adf_result = adfuller(self.series.dropna())
        
        # KPSS Test
        kpss_result = kpss(self.series.dropna())
        
        # Calculate rolling statistics
        rolling_mean = self.series.rolling(window=self.expected_periods[0]).mean()
        rolling_std = self.series.rolling(window=self.expected_periods[0]).std()
        
        return {
            'adf_test': {
                'statistic': adf_result[0],
                'p_value': adf_result[1],
                'critical_values': adf_result[4],
                'is_stationary': adf_result[1] < 0.05
            },
            'kpss_test': {
                'statistic': kpss_result[0],
                'p_value': kpss_result[1],
                'is_stationary': kpss_result[1] > 0.05
            },
            'rolling_statistics': {
                'mean': rolling_mean,
                'std': rolling_std
            }
        }
    
    def analyze_seasonality(self):
        """
        Analyze seasonal patterns using multiple methods.
        """
        # Perform seasonal decomposition for each expected period
        decompositions = {}
        for period in self.expected_periods:
            try:
                decomp = seasonal_decompose(self.series, period=period)
                strength = np.var(decomp.seasonal) / np.var(decomp.resid + decomp.seasonal)
                decompositions[period] = {
                    'decomposition': decomp,
                    'strength': strength
                }
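            except Exception:
                # skip periods that cannot be decomposed, e.g. when the series is shorter than two full periods
                continue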
        
        return {
            'decompositions': decompositions
        }
    
    def analyze_distribution(self):
        """
        Analyze the distribution of the time series.
        """
        # Normality test
        _, normality_p_value = normaltest(self.series.dropna())
        
        # Calculate quantiles
        quantiles = self.series.quantile([0.25, 0.5, 0.75])
        
        return {
            'normality_test': {
                'p_value': normality_p_value,
                'is_normal': normality_p_value > 0.05
            },
            'quantiles': quantiles,
            'iqr': quantiles[0.75] - quantiles[0.25],
            'histogram_data': np.histogram(self.series, bins='auto')
        }
        
    def detect_changepoints(self):
        """
        Detect significant changes in the time series.
        """
        # Simple moving average difference
        ma = self.series.rolling(window=self.expected_periods[0]).mean()
        diff = ma.diff()
        
        # Detect points where difference exceeds 2 standard deviations
        threshold = 2 * diff.std()
        changepoints = self.series.index[abs(diff) > threshold]
        
        return {
            'changepoints': changepoints,
            'n_changepoints': len(changepoints),
            'threshold': threshold
        }
    
    def detect_outliers(self):
        """
        Detect outliers using multiple methods.
        """
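        # Z-score method (computed on non-missing values so the mask aligns with the index)
        clean_series = self.series.dropna()
        z_scores = np.abs(stats.zscore(clean_series))
        z_score_outliers = clean_series.index[np.asarray(z_scores) > 3]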
        
        # IQR method
        Q1 = self.series.quantile(0.25)
        Q3 = self.series.quantile(0.75)
        IQR = Q3 - Q1
        iqr_outliers = self.series.index[
            (self.series < (Q1 - 1.5 * IQR)) | 
            (self.series > (Q3 + 1.5 * IQR))
        ]
        
        return {
            'z_score': {
                'outliers': z_score_outliers,
                'count': len(z_score_outliers)
            },
            'iqr': {
                'outliers': iqr_outliers,
                'count': len(iqr_outliers)
            }
        }
    
    def analyze_cyclical_patterns(self):
        """
        Analyze cyclical patterns using spectral analysis.
        """
        # Perform FFT on non-missing values; the frequency bins must match the FFT length
        values = self.series.dropna().values
        fft_values = np.fft.fft(values)
        fft_freq = np.fft.fftfreq(len(values))
        
        # Find dominant frequencies
        dominant_idx = np.argsort(np.abs(fft_values))[-5:]  # Top 5 frequencies
        
        return {
            'dominant_frequencies': fft_freq[dominant_idx],
            'dominant_amplitudes': np.abs(fft_values)[dominant_idx]
        }
    
    def analyze_autocorrelation(self):
        """
        Analyze autocorrelation and partial autocorrelation.
        """
        nlags = min(self.expected_periods[-1], len(self.series) // 4)
        acf_values = acf(self.series.dropna(), nlags=nlags)
        pacf_values = pacf(self.series.dropna(), nlags=nlags)
        
        return {
            'acf': acf_values,
            'pacf': pacf_values,
            'nlags': nlags,
            'significant_lags': {
                'acf': np.where(np.abs(acf_values) > 1.96/np.sqrt(len(self.series)))[0],
                'pacf': np.where(np.abs(pacf_values) > 1.96/np.sqrt(len(self.series)))[0]
            }
        }
    
    def plot_full_analysis(self):
        """
        Create comprehensive visualization of all analyses.
        """
        # Calculate number of seasonal decomposition subplots needed
        n_periods = len(self.results['seasonality']['decompositions'])
        total_plots = 5 + (n_periods * 3)  # Original + distribution + ACF/PACF + autocorrelation + frequencies + (3 plots per period)
        
        fig = plt.figure(figsize=(20, 5 * total_plots))
        plot_position = 1
        
        # 1. Original Series + Rolling Statistics
        plt.subplot(total_plots, 1, plot_position)
        self.series.plot(label='Original')
        self.results['stationarity']['rolling_statistics']['mean'].plot(label=f'Rolling Mean, window={self.expected_periods[0]}')
        self.results['stationarity']['rolling_statistics']['std'].plot(label=f'Rolling Std, window={self.expected_periods[0]}')
        plt.title(f'{self.title} - Original Series and Rolling Statistics')
        plt.legend()
        plot_position += 1
        
        # 2. Seasonal Decompositions for each period        
        for period, decomp_info in self.results['seasonality']['decompositions'].items():
            period_str = str(period)
            period_name = {'24': 'daily', '168': 'weekly', '730': 'monthly'}.get(period_str, f'{period}-hour')
            title_suffix = f" (Period: {period}, Strength: {decomp_info['strength']:.3f})"
            
            decomp = decomp_info['decomposition']
                        
            # Trend
            plt.subplot(total_plots, 1, plot_position)
            decomp.trend.plot()
            plt.title(f'Trend Component - {period_name.capitalize()}{title_suffix}')
            plot_position += 1
            
            # Seasonal
            plt.subplot(total_plots, 1, plot_position)
            decomp.seasonal.plot()
            plt.title(f'Seasonal Component - {period_name.capitalize()}{title_suffix}')
            plot_position += 1
            
            # Residual
            plt.subplot(total_plots, 1, plot_position)
            decomp.resid.plot()
            plt.title(f'Residual Component - {period_name.capitalize()}{title_suffix}')
            plot_position += 1
        
        # 3. Distribution Analysis
        plt.subplot(total_plots, 1, plot_position)
        self.series.hist(bins=256)
        plt.title('Target distribution')
        plot_position += 1
        
        # 4. ACF/PACF
        plt.subplot(total_plots, 1, plot_position)
        corr_results = self.results['autocorrelation']
        plt.plot(corr_results['acf'], label='ACF')
        plt.plot(corr_results['pacf'], label='PACF')
        plt.axhline(y=0, color='r', linestyle='-')
        plt.axhline(y=1.96/np.sqrt(len(self.series)), color='r', linestyle='--')
        plt.axhline(y=-1.96/np.sqrt(len(self.series)), color='r', linestyle='--')
        plt.title('ACF/PACF')
        plt.legend()
        plot_position += 1
        
        # 5. Autocorrelation by Lag Plot
        plt.subplot(total_plots, 1, plot_position)
        max_lags = self.results['autocorrelation']['nlags']
        lags = range(1, max_lags + 1)
        correlations = [self.series.autocorr(lag=lag) for lag in lags]
        plt.bar(lags, correlations, alpha=0.5, color='blue')
        plt.axhline(y=0, color='r', linestyle='-')
        plt.axhline(y=1.96/np.sqrt(len(self.series)), color='r', linestyle='--', label='95% Confidence Interval')
        plt.axhline(y=-1.96/np.sqrt(len(self.series)), color='r', linestyle='--')
        plt.xlabel('Lag')
        plt.ylabel('Autocorrelation')
        plt.title(f'Autocorrelation by lag, max lag = {max_lags}')
        plt.legend()
        plot_position += 1

        # 6. Dominant Frequencies Plot
        ax = plt.subplot(total_plots, 1, plot_position)
        cyclical_results = self.results['cyclical']
        freq = cyclical_results['dominant_frequencies']
        
        ax.stem(freq, cyclical_results['dominant_amplitudes'])
        ax.set_xlim([np.min(freq)*2, np.max(freq)*2])
        plt.xlabel('Frequency (cycles per time step)')
        plt.ylabel('Amplitude')
        plt.title('Dominant Frequencies')
        plot_position += 1
        
        plt.tight_layout()
        return fig

Use the interactive plotting to analyse a specific time series. You can change the following parameters:
- `Time series id`: id of the time series in the full dataset
- `Analyse from` and `Analyse to`: start and end of the analysed interval  
- `Random sample`: activate this to pick a random sample from the full time series dataset

In [None]:
style = {"description_width": "initial"}
ts_id_list = [ts.name for ts in timeseries]
analyse_start_date = pd.Timestamp("2014-01-01")
analyse_end_date = pd.Timestamp("2014-04-30")
time_step = timeseries[0].index.freq

In [None]:
@interact_manual(
    ts_ids=Select(options=ts_id_list, value=ts_id_list[0], rows=5, style=style, description='Time series id:'),
    start_date=DatePicker(value=analyse_start_date, style=style, description='Analyse from:'),
    end_date=DatePicker(value=analyse_end_date, style=style, description='Analyse to:'),
    random_sample=Checkbox(value=False, description='Random sample'),
    continuous_update=False,
)
def plot_interact(ts_ids, start_date, end_date, random_sample):
    ids = random.sample(ts_id_list, 1) if random_sample else [ts_ids]
    ts = [ts for ts in timeseries if ts.name in ids][0].loc[start_date:end_date]
    
    print(f"Analysing time series {ts.name}, length = {len(ts)} data points...")
    
    analyzer = TimeSeriesAnalyzer(ts, freq='H')
    results = analyzer.run_full_analysis()
    fig = analyzer.plot_full_analysis()
        
    plt.show()
    
    # Print key findings
    print("\nKey findings:")
    print(f"1. Stationarity: {'Stationary' if results['stationarity']['adf_test']['is_stationary'] else 'Non-stationary'}")
    print(f"2. Distribution: {'Normal' if results['distribution']['normality_test']['is_normal'] else 'Non-normal'}")
    print(f"3. Number of outliers: {results['outliers']['z_score']['count']} (z-score method)")
    print(f"4. Number of changepoints: {results['changepoints']['n_changepoints']}")

    # Print more data
    print(f"\nBasic statistics:\n{pd.DataFrame.from_dict(results['basic_stats'], orient='index', columns=['value'])}")
    print(f"\nSeasonality:\n{results['seasonality']}")

---

## Train GluonTS models: a foundational walkthrough
In this section, you begin training and forecasting with two simple models, [Seasonal Naive](https://otexts.com/fpp2/simple-methods.html#seasonal-na%C3%AFve-method) and a [simple feedforward MLP](https://github.com/awslabs/gluonts/blob/dev/src/gluonts/mx/model/simple_feedforward/_estimator.py), to learn the foundations of the GluonTS framework and get hands-on experience with its classes and utilities. You use a sample dataset with a single time series to keep the example simple.

After this step-by-step introduction, you move on to training more advanced models such as Transformers, CNNs, and [N-BEATS](https://openreview.net/forum?id=r1ecqn4YwB). Finally, you calculate and save model performance metrics.

### Setup environment and load packages

This notebook doesn't install the GPU version of MXNet and uses the CPU for model training and inference.

In [None]:
!nvidia-smi

In [None]:
from mxnet import npx
# npx.set_np()
num_gpus = npx.num_gpus()  # Returns the number of available GPUs
print(f"Number of GPUs available: {num_gpus}")

In [None]:
# import GluonTS modules
from gluonts.dataset.pandas import PandasDataset
from gluonts.dataset.split import split, OffsetSplitter, DateSplitter
from gluonts.dataset.util import to_pandas
from gluonts.dataset.common import ListDataset
from gluonts.dataset.jsonl import JsonLinesFile
from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.model.predictor import Predictor
from gluonts.model.forecast import QuantileForecast
from gluonts.dataset.field_names import FieldName
from pathlib import Path
from gluonts.mx import Trainer
from gluonts.mx import (
NBEATSEnsembleEstimator, NBEATSEstimator, GaussianProcessEstimator, DeepAREstimator,
TemporalFusionTransformerEstimator, MQCNNEstimator, MQRNNEstimator
)
from gluonts.model.seasonal_naive import SeasonalNaivePredictor
from gluonts.ext.prophet import ProphetPredictor
from gluonts.model.npts import NPTSPredictor
from gluonts.mx import SimpleFeedForwardEstimator

In [None]:
# predict for 7 days
prediction_days = 7
intervals_per_day = 24
prediction_length = prediction_days * intervals_per_day

print(f"Sampling frequency set to {freq}. Generate predictions for {prediction_length} intervals")
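
The cell above hardcodes 24 intervals per day, which matches the `1h` sampling frequency. If you change the frequency, one way to keep the two values consistent is to derive the number of steps per day from the frequency string. The helper `steps_per_day` below is a hypothetical name introduced for this sketch, and it assumes a fixed-length frequency such as `1h` or `15min`.

```python
import pandas as pd


def steps_per_day(freq: str) -> int:
    """Number of time steps of length `freq` in one day (fixed-length frequencies only)."""
    return pd.Timedelta("1D") // pd.Timedelta(freq)


prediction_days = 7
prediction_length = prediction_days * steps_per_day("1h")
print(f"Prediction length for 1h data: {prediction_length}")
```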

### GluonTS built-in datasets
GluonTS comes with many publicly available datasets. This section briefly shows which datasets ship with GluonTS and loads the built-in electricity dataset as an example.
The rest of this notebook doesn't use the built-in GluonTS datasets, but you can experiment with some of the popular ones.

In [None]:
from gluonts.dataset.repository import get_dataset, dataset_names
from gluonts.dataset.util import to_pandas

In [None]:
print(f"Available datasets: {dataset_names}")

In [None]:
# load the electricity dataset
gluonts_dataset = get_dataset('electricity')

In [None]:
print(f"The electricity dataset contains {len(gluonts_dataset.train)} time series.")

In [None]:
# display the first time series
entry = next(iter(gluonts_dataset.train))
entry

In [None]:
# display some random time series
for e in np.random.choice(list(gluonts_dataset.train), size=4, replace=False):
    to_pandas(e)[-10*gluonts_dataset.metadata.prediction_length:].plot(title=f"item id:{e[FieldName.ITEM_ID]}")

    plt.tight_layout()
    plt.show()

### Create a sample dataset

The rest of the notebook continues to use the original electricity dataset that you downloaded and loaded into a DataFrame at the beginning, not the built-in GluonTS copy.

Create a smaller dataset with a subset of time series. You can use this sample dataset to simplify examples and to decrease model training time.

In [None]:
SAMPLE_SIZE = 1
MAX_TS_TO_DISPLAY = 5

In [None]:
# select some random time series to include in a small dataset
sample_size = SAMPLE_SIZE
columns_to_keep = np.random.choice(data_kw.columns.to_list(), size=sample_size, replace=False)
columns_to_keep

In [None]:
data_kw_sample = data_kw[columns_to_keep]
data_kw_sample

In [None]:
if data_kw_sample.shape[1] > MAX_TS_TO_DISPLAY:
    print(f"\033[91mToo many time series in the dataset to visualize, displaying a random sample of {MAX_TS_TO_DISPLAY}.\033[0m")
    sample = data_kw_sample.sample(n=MAX_TS_TO_DISPLAY, axis=1)
else:
    sample = data_kw_sample
    
plot_timeseries(
    [np.trim_zeros(sample.iloc[:, i], trim="f") for i in range(sample.shape[1])],
    pd.Timestamp("2014-12-01"), intervals_per_day*14
)

### Convert data to GluonTS format
A dataset must satisfy minimum requirements to be compatible with GluonTS: it should be an iterable collection of data entries (time series), and each entry should have at least a `target` field with the values of the time series and a `start` field with the start date of the time series.

To work directly on a `pandas.DataFrame` or `pandas.Series`, you can use the GluonTS `PandasDataset` class. GluonTS also supports multiple time series: they can be passed as a list of DataFrames, a dict of DataFrames, or a long-format DataFrame with an `item_id` column that designates each individual time series.
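
To make the minimum format concrete, the following sketch builds the entries by hand from a tiny wide DataFrame, using pandas only. It illustrates the structure of an entry and is not what `PandasDataset` does internally; the column names and values are made up.

```python
import pandas as pd

# Made-up wide DataFrame: one column per time series, hourly index
index = pd.date_range("2014-01-01", periods=4, freq="1h")
toy_wide = pd.DataFrame(
    {"MT_001": [1.0, 2.0, 3.0, 4.0], "MT_002": [0.5, 0.5, 0.7, 0.9]},
    index=index,
)

# One entry per column: 'start' marks the first time step, 'target' holds the values
entries = [
    {
        "item_id": name,
        "start": pd.Period(series.index[0], freq="1h"),
        "target": series.to_numpy(),
    }
    for name, series in toy_wide.items()
]

for entry in entries:
    print(entry["item_id"], entry["start"], entry["target"])
```

Any iterable of such dictionaries satisfies the requirements described above; `item_id` is optional but makes it easier to match forecasts back to their series.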

In [None]:
# take the last year of data for a sample
start_training_date = pd.Timestamp('2014-01-01')
end_dataset_date = pd.Timestamp('2014-12-31')

In [None]:
df_wide = data_kw_sample[(data_kw_sample.index > start_training_date) & (data_kw_sample.index <= end_dataset_date)]

In [None]:
ts_dataset = PandasDataset(dict(df_wide))

In [None]:
ts_dataset

In [None]:
# See what is inside the PandasDataset
entry = next(iter(ts_dataset))

print(entry)
print(f"Number of data points: {len(entry[FieldName.TARGET])}")
print(f"Number of time series in the dataset: {len(ts_dataset)}")

In [None]:
# show time series
for i, entry in enumerate(islice(ts_dataset, MAX_TS_TO_DISPLAY)):
    to_pandas(entry).plot(label=entry[FieldName.ITEM_ID], color=colors[i % len(colors)]) 
    plt.legend()
    plt.tight_layout()
    plt.show()

### Split into train and test datasets
Before training, you need to split the dataset into training and test parts. You can use a built-in [`splitter`](https://ts.gluon.ai/stable/api/gluonts/gluonts.dataset.split.html) to implement different splitting strategies. You can use [`OffsetSplitter`](https://ts.gluon.ai/stable/api/gluonts/gluonts.dataset.split.html#gluonts.dataset.split.OffsetSplitter) to split a uniform dataset by time step offset or [`DateSplitter`](https://ts.gluon.ai/stable/api/gluonts/gluonts.dataset.split.html#gluonts.dataset.split.DateSplitter) to split based on a specific date.

To generate and handle test pairs containing the test input and ground truth data, you can use the [`TestTemplate`](https://ts.gluon.ai/stable/api/gluonts/gluonts.dataset.split.html#gluonts.dataset.split.TestTemplate) helper class.

Refer to the GluonTS API documentation for [`gluonts.dataset.split`](https://ts.gluon.ai/stable/api/gluonts/gluonts.dataset.split.html#) module for other helpful constructs and utilities.
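
The idea behind offset-based splitting with rolling test windows can be sketched in plain Python on a single list of values. This is a simplified illustration of the slicing logic under the assumption of a negative offset, not the GluonTS implementation; `split_by_offset` is a hypothetical helper name.

```python
def split_by_offset(target, offset, prediction_length, windows, distance=None):
    """Split one series at a negative offset and build rolling test windows.

    Returns the training part and a list of (test_input, test_label) pairs.
    """
    if distance is None:
        distance = prediction_length  # non-overlapping windows by default
    split_point = len(target) + offset
    train = target[:split_point]
    pairs = []
    for window in range(windows):
        start = split_point + window * distance
        # The input is everything before the window, the label is the window itself
        pairs.append((target[:start], target[start:start + prediction_length]))
    return train, pairs


series = list(range(10))
train, pairs = split_by_offset(series, offset=-4, prediction_length=2, windows=2)
print("train:", train)
for test_input, test_label in pairs:
    print("input:", test_input, "label:", test_label)
```

Passing a `distance` smaller than `prediction_length` makes consecutive windows overlap, which is the effect of the `distance` argument used later in the date-based example.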

In [None]:
# define some visualization helpers
def highlight_entry(entry, color, ax):
    start = entry["start"]
    end = entry["start"] + len(entry["target"])
    ax.axvspan(start, end, facecolor=color, alpha=0.2)

def plot_dataset_splitting(
    original_dataset, 
    training_dataset, 
    test_pairs
):
    n_rows = min(len(original_dataset), MAX_TS_TO_DISPLAY) + min(len(test_pairs), 2*MAX_TS_TO_DISPLAY)
    fig, axes = plt.subplots(n_rows, 1, figsize=(15, 3*n_rows), squeeze=False)
    axes = axes.flatten()  # squeeze=False always returns a 2D array; flatten it for easier indexing

    if len(original_dataset) > MAX_TS_TO_DISPLAY or len(test_pairs) > 2*MAX_TS_TO_DISPLAY:
        print(f"\033[91mToo many time series in the dataset to visualize, displaying first {MAX_TS_TO_DISPLAY} time series.\033[0m")

    # Current subplot index
    current_ax = 0
    
    # Plot original dataset and highlight the training part
    for original_entry, train_entry in zip(islice(original_dataset, MAX_TS_TO_DISPLAY), islice(training_dataset, MAX_TS_TO_DISPLAY)):
        ax = axes[current_ax]
        start = original_entry[FieldName.START].to_timestamp()
        end = (original_entry[FieldName.START] + len(original_entry[FieldName.TARGET])).to_timestamp()
        to_pandas(original_entry).plot(ax=ax)
        highlight_entry(train_entry, "red", ax)
        ax.legend([f"original dataset: {train_entry[FieldName.ITEM_ID]}", "training dataset"], loc="upper left")
        current_ax += 1

    # Plot test pairs
    for test_input, test_label in islice(test_pairs, 2*MAX_TS_TO_DISPLAY):
        ax = axes[current_ax]
        to_pandas(test_input).plot(ax=ax)
        to_pandas(test_label).plot(ax=ax)
        highlight_entry(test_input, "green", ax)
        highlight_entry(test_label, "blue", ax)
        ax.set_xlim(start, end)
        ax.legend([f"test input: {test_input[FieldName.ITEM_ID]}", "test label", "input", "label"], loc="upper left")
        current_ax += 1
        
    plt.tight_layout()
    plt.show()

#### Example 1: split by offset

In [None]:
# Split by offset
NUM_WINDOWS = 2 # you define how many test windows should be generated for each time series

train_ds, test_template = OffsetSplitter(offset=-NUM_WINDOWS*prediction_length).split(ts_dataset)
test_pairs = test_template.generate_instances(
    prediction_length=prediction_length, 
    windows=NUM_WINDOWS, 
)

In [None]:
print(f"The dataset is split into {len(train_ds)} training datasets and {len(test_pairs)} test pairs")

In [None]:
# visualize
plot_dataset_splitting(ts_dataset, train_ds, test_pairs)

#### Example 2: split by date

In [None]:
# Split by date
NUM_WINDOWS = 4
end_training_date = pd.Period(end_dataset_date, freq=freq) - NUM_WINDOWS*prediction_length

train_ds, test_template = DateSplitter(date=end_training_date).split(ts_dataset)
test_pairs = test_template.generate_instances(
    prediction_length=prediction_length,
    windows=NUM_WINDOWS,
    distance=prediction_length//2, # using 'distance' argument you can make windows overlap
)

In [None]:
print(f"The dataset is split into {len(train_ds)} training datasets and {len(test_pairs)} test pairs")

In [None]:
plot_dataset_splitting(ts_dataset, train_ds, test_pairs)

In [None]:
# visualize only ground truth (label) for each testing window
for _, test_label in islice(test_pairs, MAX_TS_TO_DISPLAY):
    to_pandas(test_label).plot()

plt.tight_layout()
plt.show()

### Training a built-in GluonTS algorithm
This section uses a single sample time series to demonstrate the process of training a model, producing predictions, and evaluating the results.

To encapsulate models and trained model artifacts, GluonTS uses an `Estimator`/`Predictor` pair of abstractions that should be familiar to users of other machine learning frameworks. An `Estimator` represents a model that can be trained on a dataset to yield a `Predictor`, which can later be used to make predictions on unseen data.
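
The `Estimator`/`Predictor` split can be illustrated with a toy sketch in plain Python. The classes `ToyMeanEstimator` and `ToyMeanPredictor` below are hypothetical and are not part of GluonTS; they only mimic the pattern in which `train` consumes a dataset and returns an object that can predict.

```python
from statistics import mean


class ToyMeanPredictor:
    """Hypothetical predictor: repeats the per-series mean learned in training."""

    def __init__(self, means, prediction_length):
        self.means = means
        self.prediction_length = prediction_length

    def predict(self, item_id):
        # a constant forecast of length prediction_length
        return [self.means[item_id]] * self.prediction_length


class ToyMeanEstimator:
    """Hypothetical estimator: 'training' only computes one mean per series."""

    def __init__(self, prediction_length):
        self.prediction_length = prediction_length

    def train(self, dataset):
        means = {item_id: mean(values) for item_id, values in dataset.items()}
        return ToyMeanPredictor(means, self.prediction_length)


# toy dataset: item id -> observed values (made-up numbers)
dataset = {"a": [1.0, 2.0, 3.0], "b": [10.0, 20.0]}
predictor = ToyMeanEstimator(prediction_length=2).train(dataset)
forecast_a = predictor.predict("a")
print(forecast_a)
```

Keeping the trained artifact in a separate object means that a predictor can be stored and reused without access to the training configuration.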

#### Example 1: Seasonal Naive
First, make a prediction using a simple seasonal model.

For each time series $Y$ the seasonal naive predictor produces a forecast:

$\tilde{Y}(T+k) = Y(T+k-h)$

where $T$ is the forecast time, $k = 0, \dots, \text{prediction length} - 1$, and $h$ is the season length.
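
As a minimal sketch of this rule in plain Python (not the GluonTS implementation), the function below copies values from the last observed season. It assumes that for horizons longer than one season the last season is simply repeated; the function name and the toy data are illustrative.

```python
def seasonal_naive_forecast(history, prediction_length, season_length):
    """Return the forecast y(T+k) = y(T+k-h) for k = 0..prediction_length-1.

    Assumption: when prediction_length exceeds season_length,
    the last observed season is repeated cyclically.
    """
    T = len(history)
    if T < season_length:
        raise ValueError("history must cover at least one full season")
    last_season = history[T - season_length:]
    return [last_season[k % season_length] for k in range(prediction_length)]


# toy series with a season of length 3
history = [1, 2, 3, 4, 5, 6]
forecast = seasonal_naive_forecast(history, prediction_length=5, season_length=3)
print(forecast)
```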

In [None]:
# split the dataset at the last two weeks of the year and predict the week before last to avoid the one-off effect of the Dec 25th
train_ds, test_template = OffsetSplitter(offset=-2*prediction_length).split(ts_dataset)
test_pairs = test_template.generate_instances(
    prediction_length=prediction_length, 
    windows=1, 
)

print(f"The dataset is split into {len(train_ds)} training datasets and {len(test_pairs)} test pairs")

In [None]:
# create instance of predictor
seasonal_naive_predictor = SeasonalNaivePredictor(
    prediction_length=prediction_length,
    season_length=24,
)

In [None]:
# predict
forecasts = [seasonal_naive_predictor.predict_item(test_input) for test_input, _ in test_pairs]

In [None]:
forecast_entry = forecasts[0]

print(f"Number of sample paths: {forecast_entry.num_samples}")
print(f"Dimension of samples: {forecast_entry.samples.shape}")
print(f"Start date of the forecast window: {forecast_entry.start_date}")
print(f"Frequency of the time series: {forecast_entry.freq}")

In [None]:
# access to predictions
forecast_entry.mean

In [None]:
# for seasonal naive all quantiles are the same
forecast_entry.quantile(0.9)

In [None]:
# visualize predictions
i = 0
for test_input, test_label in islice(test_pairs, MAX_TS_TO_DISPLAY):
    fig, ax = plt.subplots(1, 1, figsize=(15,3))

    to_pandas(test_input)[-prediction_length:].plot(ax=ax, label=f"Input time series: {test_input[FieldName.ITEM_ID]}")
    to_pandas(test_label).plot(ax=ax, label="Ground truth")
    mean_forecast = forecasts[i].to_quantile_forecast(['mean'])
    pd.Series(data=mean_forecast.forecast_array[0], index=mean_forecast.index).plot(ax=ax, label="Mean forecast")
    i += 1

    plt.legend(loc='upper left')
    plt.tight_layout()
    plt.show()

#### Evaluate predictions
Use the GluonTS [`Evaluator`](https://ts.gluon.ai/stable/api/gluonts/gluonts.evaluation.html) class to evaluate the forecast numerically. This class computes metrics per time series (item) as well as aggregated metrics across all time series.
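
To make the per-item versus aggregated distinction concrete, here is a small sketch that computes sMAPE by hand for two toy series. It uses the common definition $2 \cdot \text{mean}(|y - \hat{y}| / (|y| + |\hat{y}|))$; treating a zero denominator as zero error is an assumption of this sketch, and the `Evaluator` may handle such edge cases differently.

```python
def smape(target, forecast):
    """Symmetric MAPE: 2 * mean(|y - y_hat| / (|y| + |y_hat|))."""
    terms = []
    for y, y_hat in zip(target, forecast):
        denominator = abs(y) + abs(y_hat)
        # assumption of this sketch: a 0/0 term counts as zero error
        terms.append(0.0 if denominator == 0 else abs(y - y_hat) / denominator)
    return 2.0 * sum(terms) / len(terms)


# toy ground truth and point forecasts (made-up numbers)
targets = {"a": [100.0, 200.0], "b": [50.0, 50.0]}
point_forecasts = {"a": [110.0, 180.0], "b": [50.0, 50.0]}

# metric per time series, then a simple average across series
item_smape = {k: smape(targets[k], point_forecasts[k]) for k in targets}
agg_smape = sum(item_smape.values()) / len(item_smape)
print(item_smape, agg_smape)
```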

In [None]:
evaluator = Evaluator(quantiles=[0.5])
agg_metrics, item_metrics = evaluator(
    [to_pandas(l) for l in test_pairs.label], 
    forecasts,
    num_series=len(ts_dataset),
)

In [None]:
# aggregated metrics
print(json.dumps(agg_metrics, indent=2))

In [None]:
# metrics per time series
item_metrics

In [None]:
def visualize_item_metric(
    item_metrics,
    metric_name,
):
    fig, ax = plt.subplots(figsize=(15,6))

    metric_data = item_metrics[metric_name]
    ax.bar(item_metrics[FieldName.ITEM_ID], metric_data)

    if len(item_metrics) > 1:
        avg = metric_data.mean()
        std = metric_data.std()
        # Add average line
        ax.axhline(avg, color='red', linestyle='--', label='Average')
        # Add shaded area for standard deviation
        ax.fill_between(range(len(item_metrics)), avg - std, avg + std, color='green', alpha=0.2, label='±1 Std Dev')
    
    ax.set_title(f'{metric_name} per item')
    ax.set_ylabel(metric_name)
    
    # Show only horizontal grid lines
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    ax.legend()
    plt.xticks(rotation='vertical')
    plt.tight_layout()
    plt.show()

In [None]:
visualize_item_metric(item_metrics, 'sMAPE')

#### Example 2: feedforward network
Now use GluonTS's built-in fundamental neural network model [`SimpleFeedForwardEstimator`](https://github.com/awslabs/gluonts/blob/dev/src/gluonts/mx/model/simple_feedforward/_estimator.py).

`SimpleFeedForwardEstimator` implements a Multilayer Perceptron (MLP) model that predicts future time steps based on previous observations. The model produces probabilistic forecasts, meaning it outputs probability distributions rather than single point estimates.
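
A sample-based probabilistic forecast can be summarized by quantiles computed across sample paths at each time step. The sketch below shows one way to do this in plain Python with linear interpolation; the function name and the toy sample paths are illustrative, and GluonTS forecast objects expose their own `quantile` method for this purpose.

```python
def empirical_quantile(sample_paths, q):
    """Per-time-step quantile across sample paths (linear interpolation)."""
    horizon = len(sample_paths[0])
    result = []
    for t in range(horizon):
        values = sorted(path[t] for path in sample_paths)
        position = q * (len(values) - 1)
        lower = int(position)
        upper = min(lower + 1, len(values) - 1)
        fraction = position - lower
        result.append(values[lower] * (1 - fraction) + values[upper] * fraction)
    return result


# five toy sample paths for a forecast horizon of two steps
sample_paths = [[1, 10], [2, 20], [3, 30], [4, 40], [5, 50]]
median = empirical_quantile(sample_paths, 0.5)
p90 = empirical_quantile(sample_paths, 0.9)
print(median, p90)
```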

In [None]:
train_ds, _ = OffsetSplitter(offset=-prediction_length).split(ts_dataset)

In [None]:
feed_forward_estimator = SimpleFeedForwardEstimator(
    num_hidden_dimensions=[10],
    prediction_length=prediction_length,
    context_length=4*prediction_length,
    trainer=Trainer(ctx="cpu", epochs=5, learning_rate=1e-3, num_batches_per_epoch=100),
)

In [None]:
feed_forward_predictor = feed_forward_estimator.train(train_ds)

To make forecast generation and evaluation easier, you can use the GluonTS helper function
[`make_evaluation_predictions`](https://ts.gluon.ai/stable/api/gluonts/gluonts.evaluation.html#gluonts.evaluation.make_evaluation_predictions). This function performs the following:
1. Removes the last window of `prediction_length` data points from each time series in the dataset
2. Uses the predictor to forecast the next `prediction_length` data points from the end of the truncated series
3. Returns an iterator over the forecasts and an iterator over the original time series
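
The steps above can be sketched in plain Python as follows. This is a simplified illustration under stated assumptions, not the GluonTS implementation: series are plain lists, and `last_value_predictor` is a hypothetical stand-in for a trained predictor.

```python
def last_value_predictor(history, prediction_length):
    """Hypothetical predictor: repeats the last observed value."""
    return [history[-1]] * prediction_length


def evaluation_predictions(dataset, predict_fn, prediction_length):
    forecasts = []
    for series in dataset:
        # step 1: hide the last prediction_length points from the predictor
        history = series[:-prediction_length]
        # step 2: forecast the hidden window from the truncated series
        forecasts.append(predict_fn(history, prediction_length))
    # step 3: return forecasts together with the full series for comparison
    return forecasts, dataset


dataset = [[1, 2, 3, 4, 5], [10, 20, 30, 40]]
forecasts, full_series = evaluation_predictions(dataset, last_value_predictor, 2)
print(forecasts)
```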

In [None]:
# predict the last prediction_length data points of the dataset
forecast_it, ts_it = make_evaluation_predictions(
    dataset=ts_dataset, 
    predictor=feed_forward_predictor,
    num_samples=20
)

In [None]:
forecasts = list(forecast_it)
labels = list(ts_it)

In [None]:
# visualize predictions
for i, forecast in enumerate(islice(forecasts, MAX_TS_TO_DISPLAY)):
    plt.plot(labels[i][-2*prediction_length:].to_timestamp())
    forecast.plot(intervals=(0.9,), show_label=True)
    plt.legend([f"Ground truth: {forecast.item_id}", "predicted median", "90% confidence interval"])
    plt.show()

#### Evaluate predictions

In [None]:
evaluator = Evaluator(quantiles=(np.arange(10) / 10.0)[1:])
agg_metrics, item_metrics = evaluator(
    labels, 
    forecasts,
    num_series=len(ts_dataset),
)

In [None]:
print(json.dumps(agg_metrics, indent=2))

In [None]:
item_metrics

In [None]:
visualize_item_metric(item_metrics, 'sMAPE')

---

## Train GluonTS models: advanced usage

Now that you've learned the basic approaches to training and evaluating simple GluonTS models, you move on to training more advanced models.

### Prepare dataset
In this section you select time series and an interval for training and evaluation.

#### Choose time series
Select the full dataset with all time series or only a subset of random or specific time series for training and evaluation.

<div class="alert alert-info">
To reduce the training time and inference time you can use a subset of time series instead of the full dataset with 370 time series.
</div>

In [None]:
USE_FULL_DATASET = False # Training on the full dataset can take about 40 minutes per estimator
SAMPLE_SIZE = 10 # set number of samples in the dataset if you don't use the full dataset
MAX_TS_TO_DISPLAY = 5 # maximum number of displayed time series plots

In [None]:
# get the full dataset or a random sample of SAMPLE_SIZE
# you can change the selection to include specific time series
# ts_sample = data_kw[['item_id1', 'item_id2']]
ts_sample = data_kw if USE_FULL_DATASET else data_kw[np.random.choice(data_kw.columns.to_list(), size=SAMPLE_SIZE, replace=False)]

#### Choose start and end dates
Select the time interval for training and evaluation. Use the following visualization to choose your specific interval.

In [None]:
start_training_date = pd.Timestamp('2014-01-01')
end_dataset_date = pd.Timestamp('2014-12-31')

In [None]:
style = {"description_width": "initial"}
item_ids = ts_sample.columns.to_list()

In [None]:
@interact_manual(
    item_ids=SelectMultiple(options=item_ids,value=[item_ids[0]], rows=5, style=style, description='Time series ids:'),
    data_start=DatePicker(value=start_training_date, style=style, description='Data start:'),
    data_end=DatePicker(value=end_dataset_date, style=style, description='Data end:'),
    continuous_update=False,
)
def plot_interact(item_ids, data_start, data_end):

    print(f'Filtering and displaying the data from {data_start} to {data_end}')
    ts = ts_sample[(ts_sample.index > pd.Timestamp(data_start)) & (ts_sample.index <= pd.Timestamp(data_end))]

    for i, item in enumerate(islice(item_ids, 2*MAX_TS_TO_DISPLAY)):
        ts[item].plot(label=item, color=colors[i % len(colors)])
        
        plt.legend()
        plt.tight_layout()
        plt.show()

#### Convert to GluonTS format

<div class="alert alert-info">
If you want to use your own custom start and end dates, set them in the next code cell.
</div>

In [None]:
# set interval start and end to your preferred dates
start_training_date = pd.Timestamp('2014-01-01')
end_dataset_date = pd.Timestamp('2014-12-31')

In [None]:
ts_dataset = PandasDataset(
    dict(ts_sample[(ts_sample.index > start_training_date) & (ts_sample.index <= end_dataset_date)])
)

In [None]:
# show time series in the GluonTS dataset
for i, entry in enumerate(islice(ts_dataset, MAX_TS_TO_DISPLAY)):
    to_pandas(entry).plot(label=entry[FieldName.ITEM_ID], color=colors[i % len(colors)]) 
    plt.legend()
    plt.tight_layout()
    plt.show()

In [None]:
print(f'The GluonTS dataset contains {len(ts_dataset)} individual time series from {start_training_date} to {end_dataset_date}')

### Split and prepare test instances

For training and testing, you split the data into multiple rolling windows starting from the end of the training dataset. You can choose the number of windows and whether the windows overlap with each other.
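
The effect of the window count and the distance between windows can be sketched with index arithmetic. The helper below is hypothetical and assumes that the first label window starts at the split point and that each following window starts `distance` steps later; check the GluonTS splitter documentation for the exact behavior.

```python
def rolling_windows(series_length, split_index, prediction_length, windows, distance=None):
    """Return (start, end) index pairs of the label windows after the split point."""
    # assumption: by default windows are adjacent and do not overlap
    distance = prediction_length if distance is None else distance
    result = []
    for w in range(windows):
        label_start = split_index + w * distance
        label_end = label_start + prediction_length
        if label_end > series_length:
            break  # not enough data left for a complete window
        result.append((label_start, label_end))
    return result


adjacent = rolling_windows(100, 60, prediction_length=10, windows=4)
overlapping = rolling_windows(100, 60, prediction_length=10, windows=4, distance=5)
print(adjacent)
print(overlapping)
```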

In [None]:
# set backtest parameters
NUM_WINDOWS = 4 # number of rolling windows for backtest
# distance between windows, set to:
# < prediction_length for overlapping windows
# = prediction_length for adjacent windows
# > prediction_length for non-overlapping and non-adjacent windows
DISTANCE = prediction_length

# set the training-testing split date
end_training_date = pd.Period(end_dataset_date, freq=freq) - NUM_WINDOWS*prediction_length

In [None]:
train_ds, test_template = DateSplitter(date=end_training_date).split(ts_dataset)
test_pairs = test_template.generate_instances(
    prediction_length=prediction_length,
    windows=NUM_WINDOWS,
    distance=DISTANCE,
)

print(f"The dataset is split into {len(train_ds)} training datasets and {len(test_pairs)} test pairs. Training end is {end_training_date}")

In [None]:
plot_dataset_splitting(ts_dataset, train_ds, test_pairs)

### Train models

Now you're ready to train models. To demonstrate some built-in GluonTS algorithms, you're going to train the following models:

- [`SimpleFeedForward`](https://github.com/awslabs/gluonts/blob/dev/src/gluonts/mx/model/simple_feedforward/_estimator.py)
- [`NBEATS`](https://github.com/awslabs/gluonts/blob/dev/src/gluonts/mx/model/n_beats/_estimator.py), [paper](https://openreview.net/forum?id=r1ecqn4YwB)
- [`DeepAR`](https://github.com/awslabs/gluonts/blob/dev/src/gluonts/mx/model/deepar/_estimator.py), [paper](https://doi.org/10.1016/j.ijforecast.2019.07.001) 
- [`GaussianProcess`](https://github.com/awslabs/gluonts/blob/dev/src/gluonts/mx/model/gp_forecaster/_estimator.py)
- [`TemporalFusionTransformer`](https://github.com/awslabs/gluonts/blob/dev/src/gluonts/mx/model/tft/_estimator.py), [paper](https://doi.org/10.1016/j.ijforecast.2021.03.012)
- [`MQCNN`](https://github.com/awslabs/gluonts/blob/dev/src/gluonts/mx/model/seq2seq/_mq_dnn_estimator.py), [paper](https://arxiv.org/abs/1711.11053)

To compare the performance of these models, you're going to use statistical models like [`Seasonal Naive`](https://otexts.com/fpp2/simple-methods.html#seasonal-na%C3%AFve-method), [`Prophet`](https://facebook.github.io/prophet/), and [`NPTS`](https://github.com/awslabs/gluonts/blob/dev/src/gluonts/model/npts/_predictor.py) as a baseline.

You can experiment with other [available models](https://ts.gluon.ai/stable/getting_started/models.html) on your own using the code in this notebook.

In [None]:
# remove an item from this list if you don't want to train that model
estimators_to_train = [
    'SimpleFeedForward', 
    'NBEATS', 
    'DeepAR',
    'GaussianProcess', 
    'TemporalFusionTransformer', 
    'MQCNN', 
]

<div style="border: 4px solid coral; text-align: left; margin: auto;">
<b>Important considerations</b><br/>
1. This example is not production-grade model training<br/>
2. All estimators are trained with default hyperparameters which might not be the optimal configuration<br/>
3. All training is limited to 8 epochs which might not yield the best model<br/>
4. This notebook uses CPU-only MXNet implementation of all shown neural network models<br/>
5. In a real-world use case you're going to use an ensemble of several models rather than a single model<br/>
6. In a real-world use case you might run a hyperparameter optimization as well<br/>
</div>

In [None]:
NUM_EPOCHS = 8
trainer_hyperparameters = {
    "ctx":"cpu",
    "epochs":NUM_EPOCHS,
    "learning_rate":1e-3,
    "clip_gradient":10,
    "weight_decay":1e-8,
    "num_batches_per_epoch":100,
}

# same trainer for all models
trainer = Trainer(**trainer_hyperparameters)

In [None]:
estimators = {}
model_hyperparameters = {
    "freq":freq,
    "prediction_length":prediction_length,
    "context_length":4*prediction_length,
    "trainer":trainer,
}

for e in estimators_to_train:
    if e == 'SimpleFeedForward':
        estimators[e] = SimpleFeedForwardEstimator(
            num_hidden_dimensions=[10],
            prediction_length=prediction_length,
            context_length=4*prediction_length,
            trainer=trainer
        )
    elif e == 'NBEATS':
        estimators[e] = NBEATSEstimator(
            **model_hyperparameters,
            loss_function='MAPE',
            num_stacks=30,
            widths=[512],
            num_blocks=[1],
        )
    elif e == 'DeepAR':
        estimators[e] = DeepAREstimator(
            **model_hyperparameters,
        )
    elif e == 'GaussianProcess':
        estimators[e] = GaussianProcessEstimator(
            **model_hyperparameters,
            cardinality=len(train_ds),
        )
    elif e == 'TemporalFusionTransformer':
        estimators[e] = TemporalFusionTransformerEstimator(
            **model_hyperparameters,
        )
    elif e == 'MQCNN':
        estimators[e] = MQCNNEstimator(
            **model_hyperparameters,
        )
    elif e == 'MQRNN':
        estimators[e] = MQRNNEstimator(
            **model_hyperparameters,
        )
    else:
        continue

print(f'Configured estimators: {[k for k in estimators.keys()]}')

<div class="alert alert-info">
With six given estimators and 10 time series the training takes about <b>60 minutes</b>. You can take only a subset of estimators to reduce training time.
</div>



In [None]:
print(f'Training {len(estimators.keys())} estimators on {len(train_ds)} time series.')
print(f'Estimators: {[k for k in estimators.keys()]}')

# train all estimators and store predictors in a dict
predictors = {
    n:e.train(train_ds) for n, e in estimators.items()
}

# add statistical models that don't need training for a baseline
predictors['SeasonalNaive'] = SeasonalNaivePredictor(prediction_length=prediction_length, season_length=24)
predictors['Prophet'] = ProphetPredictor(prediction_length=prediction_length)
predictors['NPTS'] = NPTSPredictor(prediction_length=prediction_length, context_length=4*prediction_length)



### Predict and visualize
With all predictors available, you can generate forecasts for each test window.

In [None]:
print(f"Running inference for {len(predictors.keys())} predictors on {len(test_pairs)} test datasets: {NUM_WINDOWS} rolling windows*{len(ts_dataset)} time series")
print(f'Predictors: {[k for k in predictors.keys()]}')

# generate forecast for each test pair and each predictor and save to a dict
forecasts_all = {
    n:list(p.predict(test_pairs.input, num_samples=20)) for n, p in predictors.items()
}

# ground truth
labels = [to_pandas(l) for l in test_pairs.label]

In [None]:
def visualize_predictions(
    item_id,
    original_dataset, # GluonTS PandasDataset
    forecasts, # iterator with predicted forecasts
    labels, # test_pairs.label iterator
    prediction_length,
    history_length, # how much of the history is displayed
    c_interval=0.9, # confidence interval for probabilistic predictions
):
    # Get historical data, predictions, and label for the specific item_id
    historical_ts = to_pandas([e for e in original_dataset if e[FieldName.ITEM_ID] == item_id][0])
    item_forecasts = [f for f in forecasts if f.item_id == item_id]
    item_labels = [to_pandas(l) for l in labels if l[FieldName.ITEM_ID] == item_id]
    
    # Calculate the number of rows needed for the grid
    n_forecasts = len(item_forecasts)
    n_cols = 2
    n_rows = (n_forecasts +1) // n_cols
    
    # Create figure for historical data
    fig_hist, ax_hist = plt.subplots(figsize=(15, 3))
    plt.plot(historical_ts[-history_length:].to_timestamp(), color='b', label='Historical')
    ax_hist.set_title(f'Historical time series: {item_id}')
    ax_hist.legend()
    plt.tight_layout()
    
    # Create figure for forecasts
    fig_forecasts, axes = plt.subplots(n_rows, n_cols, figsize=(15, 3*n_rows))
    # always get a flat array of axes, also when there is only a single row
    axes = np.atleast_1d(axes).flatten()
    
    # Plot each forecast
    for idx, (label, forecast) in enumerate(zip(item_labels, item_forecasts)):
        ax = axes[idx]
        
        # Plot ground truth
        ax.plot(label.to_timestamp(), color='g', label='Ground truth')
        
        # Plot forecast with confidence interval
        forecast.plot(ax=ax, intervals=(c_interval,), show_label=True, color='r')
        
        ax.set_title(f'Prediction interval {idx}, start: {forecast.start_date}')
    
    # Remove empty subplots if any
    for idx in range(len(item_forecasts), len(axes)):
        fig_forecasts.delaxes(axes[idx])
    
    fig_forecasts.legend(['Ground truth', 'Predicted median', f'{c_interval*100:.0f}% confidence interval'], 
                        loc='upper center', 
                        bbox_to_anchor=(0.5, 1.0),
                        ncol=3,  # Display legend items in 3 columns
                        bbox_transform=fig_forecasts.transFigure)
    
    plt.tight_layout()
    # Adjust layout to account for the legend
    plt.subplots_adjust(top=0.85, hspace=0.5)
    plt.show()

In the following interactive visualization you can display historical data, predictions and ground truth for each evaluation window for a specific time series.

In [None]:
style = {"description_width": "initial"}
item_ids = [e[FieldName.ITEM_ID] for e in ts_dataset]

In [None]:
@interact_manual(
    model=Dropdown(options=list(forecasts_all.keys()), description='Model:'),
    item_id=Select(options=item_ids, value=item_ids[0], rows=5, style=style, description='Item id:'),
)
def plot_interact(model, item_id):
    visualize_predictions(
        item_id, 
        ts_dataset, 
        forecasts_all[model],
        test_pairs.label,
        prediction_length,
        NUM_WINDOWS*prediction_length
    )

If you'd like to compare performance across models and compare models per metric, run the following cells to save the results to a file and open the notebook [`lab6_results`](../lab6_results.ipynb) for the analysis.

### Evaluate predictions

In [None]:
print(f"Scoring {len(forecasts_all.keys())} forecasts on {len(test_pairs)} test pairs: {NUM_WINDOWS} rolling windows*{len(ts_dataset)} time series")
print(f'Predictors: {[k for k in predictors.keys()]}')

evaluator = Evaluator(quantiles=(np.arange(10) / 10.0)[1:])

backtest_scores = []

# calculate metrics for all predictors
for n, f in forecasts_all.items():
    agg_metrics, item_metrics = evaluator(
        labels, 
        f,
    )
    backtest_scores.append({'model':n, 'agg_metrics':agg_metrics, 'item_metrics':item_metrics})

In [None]:
style = {"description_width": "initial"}
item_ids = backtest_scores[0]['item_metrics'][FieldName.ITEM_ID].unique()
metrics = backtest_scores[0]['item_metrics'].columns[2:].tolist()

In [None]:
@interact_manual(
    model=Dropdown(options=list(forecasts_all.keys()), description='Model:'),
    metric=Select(options=metrics, value=metrics[0], rows=20, style=style, description='Metric:'),
)
def plot_interactive(model, metric):
    agg_metrics, item_metrics = [(m['agg_metrics'], m['item_metrics']) for m in backtest_scores if m['model'] == model][0]

    visualize_item_metric(item_metrics, metric)
    print(f'Aggregated metrics for {model} model:\n{json.dumps(agg_metrics, indent=2)}')

#### Save the model performance to a file

In [None]:
experiment_prefix = "gluonts"

In [None]:
os.makedirs("../model-performance", exist_ok=True)

In [None]:
def get_metrics_df(
    model_metrics: dict,
    experiment_name: str,
    timestamp=strftime("%Y%m%d-%H%M%S", gmtime()),
) -> pd.DataFrame:
    model_metrics_df = pd.DataFrame.from_dict(model_metrics, orient='index', columns=['value']).reset_index().rename(columns={'index': 'metric_name'})
    model_metrics_df['experiment'] = experiment_name
    model_metrics_df['timestamp'] = timestamp
    model_metrics_df = model_metrics_df[['timestamp', 'metric_name', 'value', 'experiment']].dropna(subset=['value'])

    # print(model_metrics_df)
    return model_metrics_df

In [None]:
# construct a DataFrame with all metrics for all models
model_metrics_df = pd.concat([
    get_metrics_df(
        s['agg_metrics'],
        f"{experiment_prefix}-{s['model']}-{freq}-{len(ts_dataset)}-{len(next(iter(ts_dataset))[FieldName.TARGET])}-bt{NUM_WINDOWS}",
    ) for s in backtest_scores
])

model_metrics_df

In [None]:
# save the metrics df to the file
experiment_name = f"{experiment_prefix}-{freq}-{len(ts_dataset)}-{len(next(iter(ts_dataset))[FieldName.TARGET])}-bt{NUM_WINDOWS}"

model_metrics_df.to_csv(
    f"../model-performance/{experiment_name}-{model_metrics_df['timestamp'].iloc[0]}.csv",
    index=False
)

---

## Optional: operationalizing GluonTS training and models

<div class="alert alert-info">
This section is <b>L300-400</b> level and assumes you're familiar with MLOps concepts, SageMaker features like pipelines, training jobs, and model registry. Completion time for this section is about <b>60 minutes</b>.
</div>

In [None]:
%load_ext autoreload
%autoreload 2

In this section you create a production-ready ML workflow to preprocess an input dataset, train a time series model, evaluate the model, register the model in the model registry, and deploy the model as a SageMaker real-time inference endpoint.

This notebook uses the PyTorch implementation of [Temporal Fusion Transformer](https://github.com/awslabs/gluonts/blob/dev/src/gluonts/torch/model/tft/estimator.py) as a model example. This model is one of the GluonTS built-in algorithms that you can use out of the box.

In this example you learn how to use [SageMaker MLOps features](https://aws.amazon.com/sagemaker/mlops/) and the [Python SDK](https://sagemaker.readthedocs.io/en/stable/index.html) to create robust, reproducible, portable, and scalable ML solutions.

In [None]:
# helper to check resource quota in the account
def check_quota(q_map, instance, min_n=1):
    quotas_client = boto3.client("service-quotas")


    r = quotas_client.get_service_quota(
        ServiceCode="sagemaker",
        QuotaCode=q_map[instance],
    )

    q = r["Quota"]["Value"]
    n = r["Quota"]["QuotaName"]

    b = q >= min_n

    print(f"\033[92mSUCCESS: Quota {q} for {n} >= required {min_n}\033[0m" if b else f"\033[91mWARNING: Quota {q} for {n} < required {min_n}\033[0m")

    return b
    
# helper to get the first available instance in the quota map
def get_best_instance(q_map):
    # check every instance type (prints the quota status for each) and keep those with sufficient quota
    available = [i for i in q_map if check_quota(q_map, i)]
    return available[0] if available else ''

### Configure defaults of AWS infrastructure
Here you use a YAML configuration file to define the default values that are automatically passed to SageMaker APIs, for example as job parameters. This is especially convenient when you need to provide static parameters for infrastructure settings, such as VPC IDs, security groups, and KMS keys, or when you work with remote functions.

Refer to [Configuring and using defaults with the SageMaker Python SDK](https://sagemaker.readthedocs.io/en/stable/overview.html#configuring-and-using-defaults-with-the-sagemaker-python-sdk) documentation for examples and more details.

Your GluonTS pipeline will use these `config.yaml` files for the default configuration values.

In [None]:
# Print default location of configuration files
from platformdirs import site_config_dir, user_config_dir

#Prints the location of the admin config file
print(os.path.join(site_config_dir("sagemaker"), "config.yaml"))

#Prints the location of the user config file
print(os.path.join(user_config_dir("sagemaker"), "config.yaml"))

In [None]:
%%writefile config.yaml

SchemaVersion: '1.0'
SageMaker:
    PythonSDK:
        Modules:
            RemoteFunction:
                InstanceType: ml.m5.2xlarge
                Dependencies: ./requirements.txt
                IncludeLocalWorkDir: true
                CustomFileFilter:
                    IgnoreNamePatterns: # files or directories to ignore
                        - "*.ipynb" # all notebook files
                        - "*.md" # all markdown files
                        - "__pycache__"
                        - "*.zip"
                        - "*.gz"
                        - "LD2011_2014.*"
                        - "*local"
                        - "*logs"
                        - "*data"
                        - "*output"

In [None]:
# copy the configuration file to user config file location
%mkdir -p {user_config_dir("sagemaker")}
%cp config.yaml {os.path.join(user_config_dir("sagemaker"), "config.yaml")}

### Prepare environment

Import the required packages, create a `requirements.txt` file for SageMaker processing and training jobs, and create local source code directories for local mode execution.

In [None]:
from sagemaker.estimator import Estimator
from sagemaker.pytorch.estimator import PyTorch
from sagemaker.local import LocalSession

In [None]:
%%writefile requirements.txt
sentencepiece==0.1.99
lightning
gluonts>=0.16.0

In [None]:
!cp ./requirements.txt ./gluonts_pipeline/

In [None]:
!sudo rm -rf ./gluonts_pipeline_local/
!mkdir -p ./gluonts_pipeline_local/
!mkdir -p ./output/model/

#### Configure for local mode
Amazon SageMaker Studio applications support the use of local mode to create estimators, processors, and pipelines. With local mode, you can test your scripts locally in JupyterLab before running them in SageMaker managed training or hosting environments. Local mode is a convenient way to quickly iterate over your training script in the notebook to ensure it works as intended. Refer to [Local mode support in Amazon SageMaker Studio](https://docs.aws.amazon.com/sagemaker/latest/dg/studio-updated-local.html) to understand which Docker operations Studio currently supports.

To use local mode in Studio applications, you must enable Docker access for the SageMaker domain and install Docker into your JupyterLab space.

This section checks if docker access is enabled in the domain, installs Docker, and installs the `sagemaker[local]` extras from the sagemaker SDK.

In [None]:
# check that Docker access is enabled in the SageMaker domain
docker_settings = boto3.client('sagemaker').describe_domain(DomainId=domain_id)['DomainSettings'].get('DockerSettings')
docker_enabled = False

if docker_settings:
    if docker_settings.get('EnableDockerAccess') in ['ENABLED']:
        print(f"\033[92mThe docker access is ENABLED in the domain {domain_id}\033[0m")
        docker_enabled = True

if not docker_enabled:
    raise Exception(f"\033[91mYou must enable docker access in the domain to use Studio local mode\033[0m")

<div class="alert alert-info">
If docker is not enabled, you need to enable the access following the instructions below. You may also skip the <b>Run and test pipeline steps locally</b> section and go directly to the <b>Construct a pipeline</b> section.
</div>

**To enable Docker access in the SageMaker domain**:

If you have the corresponding permissions in the notebook execution role, you can run the following code in a notebook:

```Python
import boto3

r = boto3.client('sagemaker').update_domain(
    DomainId=domain_id,
    DomainSettingsForUpdate={
        'DockerSettings': {
            'EnableDockerAccess':'ENABLED',
        }
    }
)
```

Alternatively, you can run the `aws sagemaker` CLI in the terminal:

```sh
aws sagemaker update-domain --domain-id <DOMAIN-ID> --domain-settings-for-update DockerSettings={EnableDockerAccess='ENABLED'}
```

For both options, your execution role needs to have the `sagemaker:UpdateDomain` permission.

#### Install Docker

In [None]:
%%bash

# see https://docs.docker.com/engine/install/ubuntu/#install-using-the-repository
sudo apt-get update
sudo apt-get install -y ca-certificates curl
sudo install -m 0755 -d /etc/apt/keyrings
sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc
sudo chmod a+r /etc/apt/keyrings/docker.asc

# Add the repository to Apt sources:
echo \
  "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \
  $(. /etc/os-release && echo "$VERSION_CODENAME") stable" | \
  sudo tee /etc/apt/sources.list.d/docker.list > /dev/null
sudo apt-get update

## Currently only Docker version 20.10.X is supported in Studio: see https://docs.aws.amazon.com/sagemaker/latest/dg/studio-updated-local.html
# pick the latest patch from:
# apt-cache madison docker-ce | awk '{ print $3 }' | grep -i 20.10
VERSION_STRING=5:20.10.24~3-0~ubuntu-jammy
sudo apt-get install docker-ce-cli=$VERSION_STRING docker-compose-plugin -y

# validate the Docker Client is able to access Docker Server at [unix:///docker/proxy.sock]
docker version

In [None]:
%pip install -q sagemaker[local]

### Upload the raw dataset to S3

In [None]:
# set S3 urls for input and output data
s3_input_data_prefix = f's3://{s3_bucket}/{s3_prefix}/data/raw'
s3_output_data_prefix = f's3://{s3_bucket}/{s3_prefix}/data/output'

In [None]:
# remove all previous datasets on S3
!aws s3 rm s3://{s3_bucket}/{s3_prefix}/data/ --recursive

In [None]:
# copy the original zip file with the dataset to S3
!aws s3 cp {extract_to_path}/{dataset_zip_file_name} {s3_input_data_prefix}/

In [None]:
# validate that the dataset is copied on S3
!aws s3 ls s3://{s3_bucket}/{s3_prefix}/data/ --recursive

### Set parameters

Set values for parameters passed to the pipeline steps.

In [None]:
# portion of the whole dataset used
data_start = pd.Timestamp('2014-01-01')
data_end = pd.Timestamp('2014-12-31')

# hyperparameters for training
# for more hyperparameters see https://github.com/awslabs/gluonts/blob/dev/src/gluonts/torch/model/tft/estimator.py
# extend the hyperparameters dictionary to pass additional hyperparameters to the estimator. You also need to adapt train.py
hyperparameters = {
    "epochs":8,
    "freq":freq,
    "prediction_length":prediction_length,
    "context_length":4*prediction_length,
    "quantiles":','.join(str(x) for x in (np.arange(10) / 10.0)[1:]),
    "backtest_windows":4,
    "num_samples":20, # not used in TFT predictor
}

# define metrics which should be collected by SageMaker
metric_definitions = [
    # raw strings avoid invalid-escape warnings for the '\.' in the patterns
    {'Name':'train_loss', 'Regex':r'train_loss=([0-9\.]+)'},
    {'Name':'test_MSE', 'Regex':r'test_MSE=([0-9\.]+)'},
    {'Name':'test_MAPE', 'Regex':r'test_MAPE=([0-9\.]+)'},
    {'Name':'test_sMAPE', 'Regex':r'test_sMAPE=([0-9\.]+)'},
    {'Name':'test_RMSE', 'Regex':r'test_RMSE=([0-9\.]+)'},
    {'Name':'test_mean_wQuantileLoss', 'Regex':r'test_mean_wQuantileLoss=([0-9\.]+)'},
    {'Name':'test_mean_absolute_QuantileLoss', 'Regex':r'test_mean_absolute_QuantileLoss=([0-9\.]+)'},
]

# where to store training artefacts
s3_output_train_prefix = f's3://{s3_bucket}/{s3_prefix}/train/output'

# model package group name in the Model Registry
model_package_group_name = 'gluonts-tft-local'

# PyTorch Deep Learning Container framework version
dlc_framework_version = '2.3'
python_version = 'py311'
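
In script mode, SageMaker passes hyperparameters to the training script as command-line strings, which is why `quantiles` is sent as a comma-separated string. The following is a minimal sketch of parsing such values back, assuming a hypothetical `argparse`-based script. The helper `parse_quantiles` and the argument names are illustrative; the actual `train.py` may parse its arguments differently.

```python
import argparse

# Build the same kind of comma-separated string as the "quantiles" hyperparameter
quantiles_arg = ",".join(str(i / 10) for i in range(1, 10))


def parse_quantiles(value: str) -> list[float]:
    """Turn a string such as '0.1,0.2,0.3' back into a list of floats."""
    return [float(q) for q in value.split(",")]


# Hypothetical parser: the real training script may define other arguments
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=8)
parser.add_argument("--quantiles", type=parse_quantiles)

# Simulate the command line that the container would receive
args = parser.parse_args(["--epochs", "8", "--quantiles", quantiles_arg])
print(args.epochs, args.quantiles)
```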

### Run and test pipeline steps locally

In this section you develop and test the preprocessing and training steps of the workflow locally before combining all steps into a pipeline.

You create the following pipeline:

| Step | Description |
|---|---|
| **Data processing** | runs a processing job that extracts a sample of time series from the whole dataset and splits it into train and test sets |
| **Training and evaluation** | runs a SageMaker training job using PyTorch container and `TemporalFusionTransformer` estimator. After training, the job evaluates model performance on a hold-out dataset |
| **Conditional step** | checks if model performance meets the specified threshold |
| **Register model** | registers a version of the model in the SageMaker model registry |

#### Preprocessing step

In [None]:
# Python function code is in the local files
from gluonts_pipeline.preprocess import preprocess

In [None]:
r_preprocess = preprocess(
    input_data_s3_path = f'{s3_input_data_prefix}/{dataset_zip_file_name}',
    output_s3_prefix=s3_output_data_prefix,
    freq=freq,
    prediction_length=prediction_length,
    data_start=data_start,
    data_end=data_end,
    backtest_windows=4,
    sample_size=10,  # set to 0 to use the full dataset with 370 series
)
r_preprocess

In [None]:
# validate that the processed datasets are uploaded to S3
!aws s3 ls {s3_output_data_prefix}/{r_preprocess['pipeline_run_id']} --recursive

#### Training step

This section demonstrates two options for implementing GluonTS training. 

Option 1 is to run a custom training script as a remote function using the `@step` decorator. Option 2 is to use the [SageMaker SDK framework for PyTorch](https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html) and leverage the built-in [SageMaker Deep Learning Containers (DLC)](https://github.com/aws/deep-learning-containers/blob/master/available_images.md).

This section uses Option 2 because it requires less effort for training artefact management and allows model deployment with one line of code from a trained estimator.

In [None]:
from gluonts_pipeline.train_step import train

In [None]:
# input datasets for training
# these datasets are prepared and uploaded to S3 by the preprocessing step
training_inputs = {
    'train': r_preprocess['train_data'],
    'test': r_preprocess['test_data'],
}

In [None]:
# Option 1: use local training script and then @step decorator for the pipeline step
# This code is given for your reference
# r_train = train(
#     train_data_s3_path=r_preprocess['train_data'],
#     test_data_s3_path=r_preprocess['test_data'],
#     output_s3_prefix=s3_output_train_prefix,
#     hyperparameters=hyperparameters,
# )
# r_train

In [None]:
# copy scripts and requirements.txt to the dedicated directory
!cp -rf ./gluonts_pipeline/* ./gluonts_pipeline_local

In [None]:
# create a local session
LOCAL_SESSION = LocalSession()
LOCAL_SESSION.config = {'local': {'local_code': True}}  # Ensure full code locality, see: https://sagemaker.readthedocs.io/en/stable/overview.html#local-mode

In [None]:
# Option 2: use a PyTorch estimator class
# note use of 'local' as instance_type parameter for local mode
tft_estimator_local = PyTorch(
    entry_point='train.py',
    source_dir='gluonts_pipeline_local',
    framework_version=dlc_framework_version,  
    py_version=python_version,
    hyperparameters=hyperparameters,
    role=sm_role,
    instance_count=1,
    instance_type='local',
    output_path=s3_output_train_prefix,
    base_job_name="gluonts-pipeline-training",
    sagemaker_session=LOCAL_SESSION,
)

In [None]:
# run training in a Docker container locally in the notebook
# the first run takes longer because the container must be pulled
tft_estimator_local.fit(training_inputs)

In [None]:
# validate that the train artefacts are uploaded to S3
!aws s3 ls {s3_output_train_prefix}/{tft_estimator_local._current_job_name} --recursive

In [None]:
# download model metrics and extract to the local volume
!aws s3 cp {s3_output_train_prefix}/{tft_estimator_local._current_job_name}/output . --recursive
!tar -xvf ./output.tar.gz -C output/

In [None]:
# display aggregated test metrics
with open('output/data/agg_metrics.json', 'r') as f:
    agg_metrics = json.load(f)

agg_metrics

In [None]:
# display item-level metrics
item_metrics = pd.read_csv("output/data/item_metrics.csv.gz", compression="gzip")
visualize_item_metric(item_metrics, 'sMAPE')

In [None]:
# you can load the saved model from disk into a predictor
!tar -xvf ./model.tar.gz -C output/model/
predictor_deserialized = Predictor.deserialize(Path("./output/model")) 

In [None]:
# predict the last prediction_length data points of the test dataset
forecast_it, ts_it = make_evaluation_predictions(
    dataset=ListDataset(JsonLinesFile(Path('./data/test.jsonl.gz')), freq=freq), 
    predictor=predictor_deserialized,
    num_samples=20
)
forecasts = list(forecast_it)
labels = list(ts_it)

In [None]:
# plot forecasts
for i, f in enumerate(islice(forecasts, MAX_TS_TO_DISPLAY*2)):
    plt.title(f'Time series: {f.item_id}')
    # plot the ground truth of the same series as the forecast (index i, not 0)
    plt.plot(labels[i][-prediction_length:].to_timestamp(), linewidth=3)
    f.plot(color=colors[i % len(colors)], show_label=True)
    plt.legend(['Historical', 'Predicted median', '90% confidence interval'], loc='upper left')
    plt.show()

In [None]:
# you can also run the training job remotely on a specified instance type
# tft_estimator = PyTorch(
#     entry_point='train.py',
#     source_dir='gluonts_pipeline',
#     framework_version=dlc_framework_version,  
#     py_version=python_version,
#     hyperparameters=hyperparameters,
#     role=sm_role,
#     instance_count=1,
#     instance_type='ml.g5.4xlarge',
#     output_path=s3_output_train_prefix,
#     base_job_name="gluonts-pipeline-training",
#     metric_definitions=metric_definitions,
# )

# tft_estimator.fit(training_inputs, wait=False)

#### Model registration step

This step registers a new trained model version in the SageMaker model registry within a [model package group](https://docs.aws.amazon.com/sagemaker/latest/dg/model-registry-model-group.html). This step is implemented as a local Python function that runs remotely during execution of the pipeline.

If you don't have a completed training job, skip this step and go to **Construct a pipeline**. The model registration implementation requires a completed training job to attach an estimator and to register a model.

In [None]:
from gluonts_pipeline.register import register

In [None]:
# r_register = register(
#     training_job_name=<TRAINING JOB NAME>,
#     model_package_group_name=model_package_group_name,
#     pipeline_run_id=r_preprocess['pipeline_run_id']
# )
# r_register

### Construct a pipeline
After local testing you can use the same Python code without any changes to construct a pipeline.

The next cell creates a pipeline with previously developed and tested steps.

You don't need to manually define the order of the steps, because SageMaker automatically derives the processing flow from the data dependencies between pipeline steps. You also don't need to manage the transfer of artifacts and datasets from one pipeline step to another, because SageMaker automatically takes care of the data flow.
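
Deriving an execution order from data dependencies can be illustrated with a topological sort. The sketch below uses only the Python standard library and a hand-written dependency map whose step names mirror this notebook. It is an illustration of the idea, not the way the SageMaker SDK builds its DAG internally.

```python
from graphlib import TopologicalSorter

# Each step maps to the set of steps whose outputs it consumes.
# The map is written by hand for illustration only.
dependencies = {
    "preprocess": set(),                  # reads the raw dataset only
    "train": {"preprocess"},              # consumes the train and test datasets
    "register": {"train", "preprocess"},  # needs the job name and the run id
}

# A topological order guarantees every step runs after its inputs exist
order = list(TopologicalSorter(dependencies).static_order())
print(order)
```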

In [None]:
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.pipeline_context import PipelineSession
from sagemaker.workflow.functions import Join
from sagemaker.workflow.steps import (
    TrainingStep, 
    CacheConfig
)
from sagemaker.workflow.parameters import (
    ParameterInteger, 
    ParameterFloat, 
    ParameterString, 
    ParameterBoolean
)
from sagemaker.workflow.condition_step import ConditionStep
from sagemaker.workflow.conditions import ConditionLessThanOrEqualTo
from sagemaker.workflow.fail_step import FailStep
from sagemaker.workflow.function_step import step
from sagemaker.workflow.pipeline_definition_config import PipelineDefinitionConfig 
from sagemaker.workflow.execution_variables import ExecutionVariables

#### Define pipeline parameters

In [None]:
# quota codes for training and processing job instances
prc_instance_map = {
    'ml.c5.4xlarge':'L-E7898792',
    'ml.c5.2xlarge':'L-49679826',
    'ml.m5.xlarge':'L-CCE2AFA6'
}

trn_instance_map = {
    'ml.g5.4xlarge':'L-FE869B40',
    'ml.g5.2xlarge':'L-2D6DEB3C',
    **prc_instance_map,
}

# define unique pipeline and model registry package group name
timestamp = strftime('%d-%H-%M-%S', gmtime())
pipeline_name = f'gluonts-pipeline-{timestamp}'
model_package_group_name = f'gluonts-tft-{timestamp}'

# get processing and training instances based on the available account quotas
processing_instance_type = get_best_instance(prc_instance_map)
training_instance_type = get_best_instance(trn_instance_map)

# define pipeline parameters
pipeline_parameters = {
    'processing_instance_type':ParameterString(name='processing_instance_type', default_value=processing_instance_type),
    'training_instance_type':ParameterString(name='training_instance_type', default_value=training_instance_type),
    'input_data_s3_path':ParameterString(name='input_data_s3_path', default_value=f'{s3_input_data_prefix}/{dataset_zip_file_name}'),
    's3_output_data_prefix':ParameterString(name='s3_output_data_prefix', default_value=s3_output_data_prefix),
    's3_output_train_prefix':ParameterString(name='s3_output_train_prefix', default_value=s3_output_train_prefix),
    'freq':ParameterString(name='freq', default_value=freq),
    'prediction_length':ParameterInteger(name='prediction_length', default_value=prediction_length),
    'data_start':ParameterString(name='data_start', default_value=str(data_start)),
    'data_end':ParameterString(name='data_end', default_value=str(data_end)),
    'backtest_windows':ParameterInteger(name='backtest_windows', default_value=4),
    'sample_size':ParameterInteger(name='sample_size', default_value=10),
    'mean_wQL_score_threshold':ParameterFloat(name='mean_wQL_score_threshold', default_value=0.2),
    'model_package_group_name':ParameterString(name='model_package_group_name', default_value=model_package_group_name),
}

#### Define the pipeline

Now define the pipeline by combining all steps together. There are two additional steps in the pipeline.

**Condition step**  
The condition step checks the model performance score calculated in the training step and conditionally creates a model and registers it in the model registry, or stops and fails the pipeline execution.

**Fail step**  
A Pipelines [FailStep](https://sagemaker.readthedocs.io/en/stable/workflows/pipelines/sagemaker.workflow.pipelines.html#sagemaker.workflow.fail_step.FailStep) stops the pipeline execution if the model performance metric doesn't meet the specified threshold.
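
The branching performed by the condition and fail steps can be sketched in plain Python. The function `route_on_metric` is hypothetical and only mimics the less-than-or-equal comparison. In the pipeline itself the comparison is evaluated by SageMaker at runtime.

```python
def route_on_metric(mean_wql: float, threshold: float = 0.2) -> str:
    """Return the branch that a less-than-or-equal condition would select."""
    if mean_wql <= threshold:
        # metric is good enough: continue with model registration
        return "register"
    # metric is too high: stop the execution with an error message
    return f"fail: mean_wQL {mean_wql} > {threshold}"


print(route_on_metric(0.15))
print(route_on_metric(0.35))
```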

You need to pass only the last step to the `Pipeline` constructor. The SDK automatically builds a pipeline DAG based on the data dependencies between steps. Refer to the [Developer Guide](https://docs.aws.amazon.com/sagemaker/latest/dg/pipelines-overview.html) for more details.

In [None]:
from gluonts_pipeline.preprocess import preprocess
from gluonts_pipeline.train_step import train
from gluonts_pipeline.register import register

In [None]:
# preprocess data step
step_preprocess = step(
    preprocess, 
    instance_type=pipeline_parameters['processing_instance_type'],
    keep_alive_period_in_seconds=1800,
)(
    input_data_s3_path=pipeline_parameters['input_data_s3_path'],
    output_s3_prefix=pipeline_parameters['s3_output_data_prefix'],
    freq=pipeline_parameters['freq'],
    prediction_length=pipeline_parameters['prediction_length'],
    data_start=pipeline_parameters['data_start'],
    data_end=pipeline_parameters['data_end'],
    backtest_windows=pipeline_parameters['backtest_windows'],
    sample_size=pipeline_parameters['sample_size'],  
    pipeline_run_id=ExecutionVariables.PIPELINE_EXECUTION_ID,
)

# training step with PyTorch estimator
step_train = TrainingStep(
    name='train',
    step_args=PyTorch(
        entry_point='train.py',
        source_dir='gluonts_pipeline',
        framework_version=dlc_framework_version,  
        py_version=python_version,
        hyperparameters=hyperparameters,
        role=sm_role,
        instance_count=1,
        instance_type=pipeline_parameters['training_instance_type'],
        output_path=pipeline_parameters['s3_output_train_prefix'],
        base_job_name="gluonts-pipeline-training",
        metric_definitions=metric_definitions,
        sagemaker_session=PipelineSession(),
    ).fit({'train':step_preprocess['train_data'], 'test':step_preprocess['test_data']}),
    # cache_config=CacheConfig(enable_caching=True, expire_after="P30d"),
)

# Use this code if you run the training script with the @step decorator
# step_train = step(
#     train,
#     instance_type=pipeline_parameters['training_instance_type'],
#     keep_alive_period_in_seconds=1800,
# )(
#     train_data_s3_path=step_preprocess['train_data'],
#     test_data_s3_path=step_preprocess['test_data'],
#     output_s3_prefix=pipeline_parameters['s3_output_train_prefix'],
#     hyperparameters=hyperparameters,
# )

# register model step
step_register = step(
    register, 
    keep_alive_period_in_seconds=1800,
)(
    training_job_name=step_train.properties.TrainingJobName,
    model_package_group_name=pipeline_parameters['model_package_group_name'],
    pipeline_run_id=step_preprocess['pipeline_run_id'],
)

# fail the pipeline execution step
step_fail = FailStep(
    name='fail',
    error_message=Join(
        on=" ", 
        values=["Execution failed due to mean_wQL >", pipeline_parameters['mean_wQL_score_threshold']]
    ),
)

# check if the test score is acceptable
step_condition = ConditionStep(
    name='check-mean-wQL',
    conditions=[
        ConditionLessThanOrEqualTo(
            left=step_train.properties.FinalMetricDataList['test_mean_wQuantileLoss'].Value,
            right=pipeline_parameters['mean_wQL_score_threshold']
    )],
    if_steps=[step_register],
    else_steps=[step_fail],
)

# Create a pipeline object
pipeline = Pipeline(
    name=f"{pipeline_name}",
    parameters=[v for v in pipeline_parameters.values()],
    steps=[step_condition],
    pipeline_definition_config=PipelineDefinitionConfig(use_custom_job_prefix=True)
)

#### Upsert the pipeline

If a pipeline with the same name already exists, SageMaker updates it.

In [None]:
# The upsert operation serializes the function code, arguments, and other artefacts to S3, where they can be accessed at pipeline runtime
pipeline.upsert(role_arn=sm_role)

To see the created pipeline in the Studio UI, click on the link constructed by the code cell below:

In [None]:
from IPython.display import HTML

# Show the pipeline link
display(
    HTML('<b>See <a target="top" href="https://studio-{}.studio.{}.sagemaker.aws/pipelines/{}/graph">the pipeline</a> in the Studio UI</b>'.format(
            domain_id, region, pipeline_name))
)

### Execute the pipeline

A pipeline execution takes about 20 minutes.

In [None]:
# this starts a pipeline execution
pipeline_execution = pipeline.start()
pipeline_execution.describe()

In [None]:
# Uncomment if you would like to wait in the notebook until this execution completes
# pipeline_execution.wait() 
pipeline_execution.list_steps()

You can see the pipeline execution in the Studio UI by clicking on the link constructed by the following code cell:

In [None]:
# Show the pipeline execution link
display(
    HTML('<b>See <a target="top" href="https://studio-{}.studio.{}.sagemaker.aws/pipelines/{}/executions/{}/graph">the pipeline execution</a> in the Studio UI</b>'.format(
            domain_id, region, pipeline_name, pipeline_execution.describe()['PipelineExecutionArn'].split('/')[-1]))
)

### Explore pipeline execution

After the pipeline execution completes, you have a trained model that is registered as a new model version in the specified model package group in the SageMaker Model Registry. You can also explore the calculated model metrics and get details on the training job.

In [None]:
# make sure the execution completed
pipeline_execution.wait()
assert pipeline_execution.describe()['PipelineExecutionStatus'] == 'Succeeded'

In [None]:
pipeline_execution_id = pipeline_execution.arn.split('/')[-1]

In [None]:
# Display all steps with StepStatus
pipeline_execution.list_steps()

### Explore training metrics

You can access metrics stored in the training job execution metadata or download the metric files from S3. SageMaker automatically captures metrics emitted by the training job by matching the log stream against the regular expressions in `metric_definitions`. You passed the `metric_definitions` parameter to the training estimator when you created it.
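
To see how a metric regex picks values out of a log line, the following sketch applies two of the patterns to made-up log lines with the standard `re` module. The log lines and the helper `extract_metrics` are assumptions for illustration. SageMaker performs the actual matching on the service side.

```python
import re

# Two of the metric definitions used in this notebook
metric_definitions = [
    {"Name": "train_loss", "Regex": r"train_loss=([0-9\.]+)"},
    {"Name": "test_mean_wQuantileLoss", "Regex": r"test_mean_wQuantileLoss=([0-9\.]+)"},
]

# Made-up log lines, shaped like the output a training script could print
log_lines = [
    "epoch 1: train_loss=5.4321",
    "evaluation done: test_mean_wQuantileLoss=0.1234",
]


def extract_metrics(lines, definitions):
    """Collect the first capture group of every matching regex per metric name."""
    found = {}
    for line in lines:
        for definition in definitions:
            match = re.search(definition["Regex"], line)
            if match:
                found.setdefault(definition["Name"], []).append(float(match.group(1)))
    return found


captured = extract_metrics(log_lines, metric_definitions)
print(captured)
```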

In [None]:
# get training job name from the pipeline step properties
train_step = [s for s in pipeline_execution.list_steps() if s['StepName'] == 'train'][0]
training_job_name = train_step['Metadata']['TrainingJob']['Arn'].split('/')[-1]

# Attach an estimator object to the job
estimator = Estimator.attach(training_job_name)

In [None]:
# display training job metrics
# these metrics were automatically captured and saved by SageMaker based on metric_definitions
estimator.training_job_analytics.dataframe()

In [None]:
# get all job properties including the S3 output prefix where training artefacts are stored
job_description = boto3.client('sagemaker').describe_training_job(TrainingJobName=training_job_name)
s3_output_prefix = job_description['OutputDataConfig']['S3OutputPath']

In [None]:
# display train artefacts uploaded to S3 by the training job
!aws s3 ls {s3_output_prefix}/{training_job_name} --recursive

In [None]:
# download and extract model metrics to the local volume
!aws s3 cp {s3_output_prefix}/{training_job_name}/output/output.tar.gz .
!tar -xvf ./output.tar.gz -C output/

In [None]:
# display aggregated test metrics
with open('output/agg_metrics.json', 'r') as f:
    agg_metrics = json.load(f)

agg_metrics

In [None]:
# display item-level metrics
item_metrics = pd.read_csv("output/item_metrics.csv.gz", compression="gzip")
visualize_item_metric(item_metrics, 'sMAPE')

### Deploy the model to a real-time endpoint

Now that a model version is registered in the model registry, you can deploy this version as a SageMaker endpoint and run real-time inference.

#### Explore model registry

First take a look at the registered model version in the model package group. Open the link constructed by the following code cell and explore model version details. Note that you have end-to-end data lineage for the registered model version.

In [None]:
# you can choose a specific model package group to take the model from
# model_package_group_name = <model package group name>
model_package_group_name = pipeline_parameters['model_package_group_name'].default_value

# get the latest model version in the model package group
sm_model_package = boto3.client('sagemaker').list_model_packages(
    ModelPackageGroupName=model_package_group_name,
    SortBy="CreationTime",
    SortOrder="Descending",
)['ModelPackageSummaryList'][0]

In [None]:
from IPython.display import HTML

# Show the model version link
display(
    HTML('<b>See <a target="top" href="https://studio-{}.studio.{}.sagemaker.aws/models/registered-models/{}/versions/version-{}/lineage">the model version lineage</a> in the Studio UI</b>'.format(
            domain_id, region, pipeline_parameters['model_package_group_name'].default_value, sm_model_package['ModelPackageVersion']))
)

#### Deploy the model

You can deploy a trained model using the SageMaker Model Registry as the model metadata source. The model registry contains all information needed to deploy and to use the model.

In [None]:
from sagemaker.pytorch import PyTorchModel

In [None]:
# each model version has its own ARN
model_package_arn = sm_model_package['ModelPackageArn']
model_package_arn

In [None]:
# get the model data from the model registry
sm_model_package_data = boto3.client('sagemaker').describe_model_package(
    ModelPackageName=model_package_arn
)

# get the associated pipeline_run_id from the model metadata
pipeline_execution_id = sm_model_package_data['CustomerMetadataProperties'].get('pipeline_run_id')
if not pipeline_execution_id:
    pipeline_execution_id = sm_model_package_data['InferenceSpecification']['Containers'][0]['ModelDataUrl'].split('-')[-2]

print(f"Using version {sm_model_package_data['ModelPackageVersion']} of {sm_model_package_data['ModelPackageGroupName']}")
print(f'The model was registered by the pipeline execution {pipeline_execution_id}')

In [None]:
# create a meaningful endpoint name
tft_endpoint_name = f"model-endpoint-{sm_model_package_data['ModelPackageGroupName']}-v{sm_model_package_data['ModelPackageVersion']}"
tft_endpoint_name

In [None]:
# delete the endpoint config if it exists
sm = boto3.client('sagemaker')
try:
    sm.describe_endpoint_config(EndpointConfigName=tft_endpoint_name)
    sm.delete_endpoint_config(EndpointConfigName=tft_endpoint_name)
except sm.exceptions.ClientError:
    # the endpoint config does not exist, so there is nothing to delete
    pass

The following code implements a custom predictor class derived from the SageMaker Python SDK [`Predictor`](https://sagemaker.readthedocs.io/en/stable/api/inference/predictors.html#sagemaker.predictor.Predictor). You can use the custom predictor to perform any required data pre- and postprocessing on the client side. The custom predictor also takes care of request serialization and response deserialization.

In [None]:
from sagemaker.serializers import IdentitySerializer
from typing import List, Dict, Tuple, Union

class GluontsTFTPredictor(sagemaker.predictor.Predictor):
    def __init__(self, *args, **kwargs):
        super().__init__(
            *args,
            # serializer=JSONSerializer(),
            serializer=IdentitySerializer(content_type="application/json"),
            **kwargs,
        )

    def predict(
        self,
        list_dataset:ListDataset,
        prediction_start:pd.Timestamp,
        context_length:int=4*prediction_length,
        num_samples:int=20, # not used in TFT predictor
    ):
        """
        Input: dataset, prediction_start date, context_length
        The function prepares the required part of the dataset and serializes it as json str
        Output: a list of QuantileForecast objects
        """

        freq = list_dataset[0][FieldName.START].freq
        dataset_start = list_dataset[0][FieldName.START]
        
        # calculate the start of prediction dataset based on required context
        prediction_input_start = prediction_start - context_length*freq
        
        # calculate lower and upper array indices for the required part of the dataset
        l = len(pd.date_range(start=dataset_start.to_timestamp(), end=prediction_input_start, freq=freq))
        u = len(pd.date_range(start=dataset_start.to_timestamp(), end=prediction_start, freq=freq))

        return self.__decode_response(
            super(GluontsTFTPredictor, self).predict(self.__encode_request(
                list_dataset, prediction_input_start, l, u, num_samples
            )))
        
    def __encode_request(self, dataset, start_date, start_idx, end_idx, num_samples):
        inputs = {
            "inputs":[
                {
                    FieldName.ITEM_ID: i[FieldName.ITEM_ID],
                    FieldName.TARGET: i[FieldName.TARGET][start_idx:end_idx].tolist(),
                    FieldName.START: start_date.isoformat(),
                }
                for i in dataset
            ],
        }
        
        parameters = {
            "parameters": {
                "freq": dataset[0][FieldName.START].freq.freqstr,
                "num_samples": num_samples
            }
        }
        
        return json.dumps({**inputs, **parameters}).encode('utf-8')

    def __decode_response(self, response):
        return json.loads(response.decode('utf-8'), object_hook=quantile_forecast_decoder)


def quantile_forecast_decoder(obj):
    if "__type__" in obj and obj["__type__"] == "QuantileForecast":
        return QuantileForecast(
            forecast_arrays=np.array(obj["forecast_arrays"]),
            start_date=pd.Period(obj["start_date"], obj['freq']),
            forecast_keys=obj["forecast_keys"],
            item_id=obj["item_id"],
            info=obj["info"],            
        )
    return obj
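
The decoding approach used by `quantile_forecast_decoder` relies on the `object_hook` parameter of `json.loads`. The following stdlib-only sketch shows the mechanism with a stand-in class `ToyForecast` and a made-up response body, both of which are assumptions for illustration rather than the payload the endpoint actually returns.

```python
import json
from dataclasses import dataclass


@dataclass
class ToyForecast:
    """Stand-in for QuantileForecast, used only in this sketch."""
    item_id: str
    forecast_keys: list
    forecast_arrays: list


def toy_decoder(obj):
    # json.loads calls the hook for every decoded JSON object (dict);
    # dicts without the type tag are returned unchanged
    if obj.get("__type__") == "ToyForecast":
        return ToyForecast(obj["item_id"], obj["forecast_keys"], obj["forecast_arrays"])
    return obj


# Made-up response bytes, similar in shape to a serialized list of forecasts
response = json.dumps([
    {
        "__type__": "ToyForecast",
        "item_id": "series-1",
        "forecast_keys": ["0.1", "0.5", "0.9"],
        "forecast_arrays": [[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]],
    },
]).encode("utf-8")

decoded = json.loads(response.decode("utf-8"), object_hook=toy_decoder)
print(decoded[0])
```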

In [None]:
# Create PyTorch model object
model = PyTorchModel(
    role=sm_role,
    model_data=sm_model_package_data['InferenceSpecification']['Containers'][0]['ModelDataUrl'],
    framework_version=sm_model_package_data['InferenceSpecification']['Containers'][0]['FrameworkVersion'],
    py_version=python_version,
    entry_point='inference.py',
    source_dir='gluonts_pipeline',
    predictor_cls=GluontsTFTPredictor,
)

In [None]:
# supported real-time inference instances
ep_instances = sm_model_package_data['InferenceSpecification']['SupportedRealtimeInferenceInstanceTypes']
ep_instances

Based on the account service quotas, get the best available instance type for the endpoint.

In [None]:
ep_instance_map = {
    'ml.g5.2xlarge':'L-9614C779',
    'ml.g5.xlarge':'L-1928E07B',
    'ml.m5.2xlarge':'L-C88C8F13',
    'ml.m5.xlarge':'L-2F737F8D',
}

In [None]:
ep_instance = get_best_instance({k:v for k,v in ep_instance_map.items() if k in ep_instances})
print(f'Use {ep_instance} for the endpoint based on the available quotas and supported instances')
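
The helper `get_best_instance` is defined elsewhere in this notebook. The sketch below only illustrates the general idea of a quota-based selection: walk a preference-ordered map and return the first instance type with a non-zero quota. The function `pick_first_available` and the quota values are hypothetical. In practice the quota values could come from the Service Quotas API.

```python
def pick_first_available(instance_map, quota_lookup):
    """Return the first instance type whose quota is greater than zero.

    instance_map: instance type -> quota code, ordered by preference
    quota_lookup: callable that returns the quota value for a quota code
    """
    for instance_type, quota_code in instance_map.items():
        if quota_lookup(quota_code) > 0:
            return instance_type
    raise RuntimeError("No instance type with available quota found")


ep_instance_map = {
    "ml.g5.2xlarge": "L-9614C779",
    "ml.g5.xlarge": "L-1928E07B",
    "ml.m5.2xlarge": "L-C88C8F13",
    "ml.m5.xlarge": "L-2F737F8D",
}

# Made-up quota values standing in for a real quota lookup
fake_quotas = {"L-9614C779": 0, "L-1928E07B": 0, "L-C88C8F13": 2, "L-2F737F8D": 4}

chosen = pick_first_available(ep_instance_map, lambda code: fake_quotas.get(code, 0))
print(chosen)
```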



In [None]:
# Deploy the model to an endpoint
tft_predictor = model.deploy(
    initial_instance_count=1,
    instance_type=ep_instance,
    endpoint_name=tft_endpoint_name,
)



### Make predictions and visualize

Now that you have a real-time inference endpoint, you can send datasets for predictions. This section creates an interactive visualization of forecasts generated by the trained model.

Use the test dataset which was created by the pipeline's preprocessing job. This ensures that you use the same subset of the time series that was used for model training.

In [None]:
# download the dataset from the preprocessing step
!aws s3 cp {s3_output_data_prefix}/{pipeline_execution_id}/test/test.jsonl.gz ./data

# load the dataset for predictions
test_ds = ListDataset(JsonLinesFile(Path('./data/test.jsonl.gz')), freq=freq)

In [None]:
print(f"""
Test dataset contains {len(test_ds)} time series
Time series: {[i[FieldName.ITEM_ID] for i in test_ds]}
Dataset start is {test_ds[0][FieldName.START]}
Each time series contains {len(test_ds[0][FieldName.TARGET])} data points
""")

In [None]:
# it's also possible to create a local predictor with the trained model
# !aws s3 cp {sm_model_package_data['InferenceSpecification']['Containers'][0]['ModelDataUrl']} .
# !tar -xvf ./model.tar.gz -C output/model/
# tft_predictor_local = Predictor.deserialize(Path("./output/model")) 

Make predictions using the real-time endpoint and evaluate results:

In [None]:
# take a date four weeks before the end of the dataset for the prediction start
prediction_start = end_dataset_date - 4*prediction_length*test_ds[0][FieldName.START].freq
# call the endpoint to generate forecasts
forecasts = tft_predictor.predict(test_ds, prediction_start, 4*prediction_length)

# calculate metrics for the forecasts
evaluator = Evaluator(quantiles=(np.arange(10) / 10.0)[1:])
agg_metrics, item_metrics = evaluator(
    [to_pandas(i) for i in test_ds], 
    forecasts,
    num_series=len(test_ds),
)

# show metrics
agg_metrics

In [None]:
visualize_item_metric(item_metrics, 'sMAPE')
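
The pipeline threshold is checked against the mean weighted quantile loss. The following pure-Python sketch shows how such a metric can be computed for a single series. It assumes the usual pinball-loss definition with a factor of 2, normalized by the sum of absolute target values. Check the GluonTS `Evaluator` source for the exact aggregation across series. All numbers are made up.

```python
def quantile_loss(target, forecast, q):
    """Pinball loss for one quantile level, summed over time steps."""
    return 2.0 * sum(
        abs((f - t) * ((1.0 if t <= f else 0.0) - q))
        for t, f in zip(target, forecast)
    )


def weighted_quantile_loss(target, forecast, q):
    """Quantile loss normalized by the sum of absolute target values."""
    return quantile_loss(target, forecast, q) / sum(abs(t) for t in target)


def mean_weighted_quantile_loss(target, forecasts_by_q):
    """Average of the weighted quantile losses over all quantile levels."""
    losses = [weighted_quantile_loss(target, f, q) for q, f in forecasts_by_q.items()]
    return sum(losses) / len(losses)


# Made-up observations and quantile forecasts for a single series
target = [10.0, 20.0, 30.0]
forecasts_by_q = {
    0.1: [8.0, 15.0, 25.0],
    0.5: [12.0, 18.0, 30.0],
    0.9: [14.0, 24.0, 36.0],
}

mean_wql = mean_weighted_quantile_loss(target, forecasts_by_q)
print(round(mean_wql, 4))
```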

#### Interactive visualization

Visualize model predictions using the following interactive control. You can change the following parameters:

- `Time series ids`: ids of the time series in the test dataset. You can select multiple time series to predict and to plot
- `Predict from`: start of the prediction interval  
- `Context length`: how many multiples of `prediction_length` data points are sent to the model as context
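
The relationship between `Predict from`, `Context length`, and the history window sent to the model can be sketched with the standard `datetime` module. The hourly frequency and the 24-step prediction length are assumed values for illustration, and `context_window` is a hypothetical helper.

```python
from datetime import datetime, timedelta


def context_window(prediction_start, context_length, prediction_length, freq):
    """Return (start, end) of the history window; end is exclusive."""
    n_points = context_length * prediction_length
    return prediction_start - n_points * freq, prediction_start


# Assumed values for illustration only: hourly data and a 24-step horizon
start, end = context_window(
    prediction_start=datetime(2014, 12, 31),
    context_length=4,
    prediction_length=24,
    freq=timedelta(hours=1),
)
print(start, end)
```

With these assumed values, a context length of 4 corresponds to 96 hourly data points before the prediction start.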

In [None]:
style = {"description_width": "initial"}
ts_id_list = [i[FieldName.ITEM_ID] for i in test_ds]

In [None]:
@interact_manual(
    ts_ids=SelectMultiple(options=ts_id_list, value=[ts_id_list[0]], rows=5, style=style, description='Time series ids:'),    
    prediction_start=DatePicker(value=data_end, style=style, description='Predict from:'),
    context_length=IntSlider(min=1, max=10, value=4, style=style, description='Context length:'),
    # num_samples=IntSlider(min=1, max=100, value=20, style=style, description='Number of samples:'),
    continuous_update=False,
)
def plot_interact(ts_ids, prediction_start, context_length):

    def _to_pandas_dict(list_ds, start, end) -> Dict[str, pd.Series]:
        return {i[FieldName.ITEM_ID]:to_pandas(i)[start:end] for i in list_ds}

    freq = test_ds[0][FieldName.START].freq
    prediction_start = pd.Timestamp(prediction_start)
    prediction_ds = [i for i in test_ds if i[FieldName.ITEM_ID] in ts_ids]
    historical_series = _to_pandas_dict(prediction_ds, prediction_start - context_length*prediction_length*freq, prediction_start)
    label_series = _to_pandas_dict(prediction_ds, prediction_start, prediction_start + prediction_length*freq)

    # call the endpoint to generate forecasts
    forecasts = tft_predictor.predict(prediction_ds, prediction_start, context_length*prediction_length)
    
    figs = []
    c_interval = float(forecasts[0].forecast_keys[-1])
    for i, f in tqdm.tqdm(enumerate(forecasts), total=len(forecasts), desc='Creating plots'):
        fig, ax = plt.subplots(1, 1, figsize=(15,4))

        ax.set_title(f'Time series: {f.item_id}')
        ax.plot(historical_series[f.item_id].to_timestamp(), color=colors[i % len(colors)], linewidth=3)
        ax.plot(label_series[f.item_id].to_timestamp())
        f.plot(intervals=(c_interval,), name=f'{f.item_id} forecast', show_label=True)

        plt.legend(['Historical', 'Ground truth', 'Predicted median', f'{c_interval*100:.0f}% confidence interval'], loc="upper left")
        plt.tight_layout()
        figs.append(fig)
    
    plt.show()

---

## Further reading

GluonTS offers an advanced framework to create, train, and evaluate your own models. Refer to the GluonTS tutorial [Create your own model](https://ts.gluon.ai/stable/tutorials/forecasting/extended_tutorial.html#Create-your-own-model) for detailed documentation.

If you would like to run GluonTS training and predictions in a custom container, for example using SageMaker training or batch transform jobs, refer to the [sample Dockerfiles for GluonTS](https://github.com/awslabs/gluonts/tree/dev/examples/dockerfiles) in the GluonTS GitHub repository.

You can find another example of productization of GluonTS training in the GitHub repository [Deep Demand Forecasting with Amazon SageMaker](https://github.com/awslabs/sagemaker-deep-demand-forecast). This example uses processing and training jobs with the MXNet LSTNet estimator to implement a training pipeline.

To manage ML experiments, metrics, models, and artifacts, you can use [SageMaker managed MLflow](https://docs.aws.amazon.com/sagemaker/latest/dg/mlflow.html). With MLflow you can track, organize, view, analyze, and compare iterative ML experiments, and register and deploy your best performing models.

## Clean up

If you don't need the deployed endpoint anymore, delete it to avoid incurring unnecessary costs.

In [None]:
# delete the endpoint and the model
tft_predictor.delete_endpoint()
model.delete_model()

---