In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle

%load_ext autoreload
%autoreload 2

def plot_preds(train, test, pred_dict, model_name, show_samples=False):
    """Plot one model's forecast for a single series against the ground truth.

    Args:
        train: Training history, drawn in black without a legend entry.
        test: Held-out ground truth; its index is the x-axis for the forecast.
        pred_dict: Forecast output. Must contain 'median' (one point forecast per
            element of `test`). May contain 'samples' (presumably shaped
            (n_samples, len(test)) -- rows are sample paths, used for the 90%
            band) and 'NLL/D' (a scalar or None, annotated on the plot).
        model_name: Legend label for the forecast.
        show_samples: If True, also draw up to 10 individual sample paths.
    """
    # np.asarray drops any index the median carries: wrapping a pd.Series whose
    # index differs from test.index would otherwise *reindex* it to all-NaN.
    pred = pd.Series(np.asarray(pred_dict['median']), index=test.index)
    plt.figure(figsize=(8, 6), dpi=100)
    plt.plot(train, color='black')
    plt.plot(test, label='Truth', color='black')
    plt.plot(pred, label=model_name, color='purple')
    samples = pred_dict.get('samples')
    if samples is not None:
        # DataFrame -> ndarray once, so row slicing below selects sample paths.
        samples = np.asarray(samples)
        # shade the 90% interval between the 5th and 95th sample quantiles
        lower = np.quantile(samples, 0.05, axis=0)
        upper = np.quantile(samples, 0.95, axis=0)
        plt.fill_between(pred.index, lower, upper, alpha=0.3, color='purple')
        if show_samples:
            for path in samples[:10]:
                plt.plot(pred.index, path, color='purple', alpha=0.3, linewidth=1)
    plt.legend(loc='upper left')
    nll = pred_dict.get('NLL/D')
    if nll is not None:
        plt.text(0.03, 0.85, f'NLL/D: {nll:.2f}', transform=plt.gca().transAxes, bbox=dict(facecolor='white', alpha=0.5))
    plt.show()

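## Darts

Precomputed forecasts loaded from `precomputed_outputs/darts/`, plotted against the held-out test split of each small-context dataset.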

In [None]:
from data.small_context import get_datasets

output_dir = 'precomputed_outputs/darts'
datasets = get_datasets()
# Unpack (train, test) directly: no intermediate `data` name that would shadow
# the `data` package imported above.
for ds_name, (train, test) in datasets.items():
    pkl_path = os.path.join(output_dir, f'{ds_name}.pkl')
    # Skip (with a note) datasets without precomputed forecasts instead of
    # aborting the whole cell, mirroring the Monash cell's existence check.
    if not os.path.exists(pkl_path):
        print(f'{ds_name}: no precomputed outputs at {pkl_path}, skipping')
        continue
    print(ds_name)
    # pickle.load can execute arbitrary code: only load locally produced, trusted files.
    with open(pkl_path, 'rb') as f:
        out = pickle.load(f)
    # `out` maps model name -> prediction dict consumed by plot_preds.
    for model, pred_dict in out.items():
        plot_preds(train, test, pred_dict, model, show_samples=True)

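## Synthetic

Precomputed forecasts loaded from `precomputed_outputs/synthetic/`, plotted against the held-out test split of each synthetic dataset.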

In [None]:
from data.synthetic import get_synthetic_datasets

output_dir = 'precomputed_outputs/synthetic'
datasets = get_synthetic_datasets()
# Unpack (train, test) directly: no intermediate `data` name that would shadow
# the `data` package imported above.
for ds_name, (train, test) in datasets.items():
    pkl_path = os.path.join(output_dir, f'{ds_name}.pkl')
    # Skip (with a note) datasets without precomputed forecasts instead of
    # aborting the whole cell, mirroring the Monash cell's existence check.
    if not os.path.exists(pkl_path):
        print(f'{ds_name}: no precomputed outputs at {pkl_path}, skipping')
        continue
    print(ds_name)
    # pickle.load can execute arbitrary code: only load locally produced, trusted files.
    with open(pkl_path, 'rb') as f:
        out = pickle.load(f)
    # `out` maps model name -> prediction dict consumed by plot_preds.
    for model, pred_dict in out.items():
        plot_preds(train, test, pred_dict, model, show_samples=True)

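## Monash

Precomputed forecasts loaded from `precomputed_outputs/monash/`. Only datasets with a saved output file are plotted. Training histories are truncated to the most recent `max_history_len` points, and only the first few series of each dataset are shown.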

In [None]:
def plot_monash_preds(train, test, pred_dict, model_name, max_series):
    """Plot one model's forecasts for the first few series of a Monash dataset.

    Draws one figure per series, for at most `max_series` series.

    Args:
        train: Sequence of training histories (expected to be pd.Series), one per series.
        test: Sequence of ground-truth test windows (pd.Series), aligned with
            `train`; each window's index is the x-axis for its forecast.
        pred_dict: Forecast output with 'median' (point forecasts indexable by
            series position) and optionally 'NLL/D' (per-series values, entries
            may be None).
        model_name: Legend label for the forecast.
        max_series: Maximum number of series to plot.
    """
    for i in range(min(max_series, len(test))):
        history, truth = train[i], test[i]
        # np.asarray drops any index the forecast carries: wrapping a pd.Series
        # with a different index would otherwise reindex it to all-NaN.
        pred = pd.Series(np.asarray(pred_dict['median'][i]), index=truth.index)
        plt.figure(figsize=(8, 6), dpi=100)
        plt.plot(history, color='black')
        plt.plot(truth, label='Truth', color='black')
        plt.plot(pred, label=model_name, color='purple')
        plt.legend(loc='upper left')
        # Cap the top of the y-axis 10% above the observed data so a runaway
        # forecast cannot squash the plot. The padding uses abs() so that an
        # all-negative series is not clipped (1.1 * max would fall *below* max).
        data_max = max(history.max(), truth.max())
        ymax = data_max + 0.1 * abs(data_max)
        ymin = plt.gca().get_ylim()[0]
        plt.ylim(ymin, ymax)
        if 'NLL/D' in pred_dict:
            nll = pred_dict['NLL/D'][i]
            if nll is not None:
                plt.text(0.03, 0.85, f'NLL/D: {nll:.2f}', transform=plt.gca().transAxes, bbox=dict(facecolor='white', alpha=0.5))
        plt.show()

In [None]:
from data.monash import get_datasets
output_dir = 'precomputed_outputs/monash'
max_history_len = 500  # plot only the most recent points of each training series
datasets = get_datasets()
# Unpack (train, test) directly: no intermediate `data` name that would shadow
# the `data` package imported above.
for ds_name, (train_raw, test_raw) in datasets.items():
    pkl_path = os.path.join(output_dir, f'{ds_name}.pkl')
    if not os.path.exists(pkl_path):
        continue  # no precomputed forecasts for this dataset
    print(ds_name)
    # Re-index every series positionally: training history at 0..n-1 and the
    # test window continuing at n.. so the two plot as one contiguous line.
    # np.asarray strips any existing index first (presumably the raw series are
    # arrays/lists -- TODO confirm); wrapping a labelled pd.Series with a new
    # index would otherwise reindex it to NaN instead of relabelling it.
    train, test = [], []
    for history, horizon in zip(train_raw, test_raw):
        history = np.asarray(history)[-max_history_len:]
        n_hist = len(history)
        train.append(pd.Series(history, index=pd.RangeIndex(n_hist)))
        test.append(pd.Series(np.asarray(horizon), index=pd.RangeIndex(n_hist, n_hist + len(horizon))))

    # pickle.load can execute arbitrary code: only load locally produced, trusted files.
    with open(pkl_path, 'rb') as f:
        out = pickle.load(f)
    # `out` maps model name -> prediction dict consumed by plot_monash_preds.
    for model, pred_dict in out.items():
        plot_monash_preds(train, test, pred_dict, model, max_series=3)