In [None]:
import torch
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import copy
import datetime

from einops import rearrange
from gluonts.dataset.multivariate_grouper import MultivariateGrouper
from gluonts.dataset.pandas import PandasDataset
from gluonts.dataset.split import split

from uni2ts.eval_util.plot import plot_single, plot_next_multi
from uni2ts.model.moirai import MoiraiForecast, MoiraiModule

In [None]:
def vis_raw_data(df):
    """Plot every column of ``df`` in its own panel on a two-row grid.

    Parameters
    ----------
    df : pandas.DataFrame
        One series per column. Each column is drawn against its row position
        and the panel is titled with the column name.

    Notes
    -----
    With an odd number of columns the last grid cell is left empty (hidden).
    An empty frame produces no figure.
    """
    cols = list(df.columns)
    if not cols:
        return  # nothing to plot; plt.subplots(ncols=0) would raise

    # ceil(len / 2) columns so an odd column count is not truncated
    ncols = -(-len(cols) // 2)
    fig, axes = plt.subplots(nrows=2, ncols=ncols, figsize=(25, 10))
    axes = axes.ravel()

    for i, (ax, name) in enumerate(zip(axes, cols)):
        # iloc[:, i] is the whole i-th column (one series); values[i, :] would
        # be a single row spanning all columns.
        ax.plot(df.iloc[:, i].to_numpy())
        ax.set_title(name)

    for spare in axes[len(cols):]:
        spare.set_visible(False)  # unused cell when the column count is odd

def _panel_axes(n_panels, nrows=2, figsize=(25, 10)):
    """Create a figure with ``nrows`` rows holding at least ``n_panels`` axes.

    Returns ``(fig, axes)`` where ``axes`` is a flat list of exactly
    ``n_panels`` axes; any spare axes in the grid are hidden.
    """
    ncols = max(1, -(-n_panels // nrows))  # ceil division, at least one column
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
    axes = np.atleast_1d(axes).ravel()
    for spare in axes[n_panels:]:
        spare.set_visible(False)
    return fig, list(axes[:n_panels])


def vis_inference_performance(label_day, forecast_day, input_day, CTX,
                              names=None, strip_prefix='2019-'):
    """Plot forecast quality for the stitched test period.

    Draws three figures, each with one panel per target dimension:

    1. input context, ground truth and prediction intervals (via ``plot_single``);
    2. ground truth vs. the median (0.5-quantile) forecast;
    3. the residual ``ground truth - median forecast``.

    Parameters
    ----------
    label_day : dict
        Label entry; ``label_day['target']`` is indexed as ``target[dim, :]``.
    forecast_day : forecast object
        Must provide ``quantile(q)``. The transposed median is indexed as
        ``pred[dim, :]``. NOTE(review): assumes ``quantile`` returns
        (time, dim) so the transpose lines up with the label target. Confirm.
    input_day : dict
        Input (context) entry, passed through to ``plot_single``.
    CTX : int
        Context length passed to ``plot_single``.
    names : sequence of str, optional
        Panel titles, one per plotted dimension (dimension ``k`` gets
        ``names[k]``). Defaults to
        ``['s_current', 's_shell', 's_shaft', 'n_current', 'n_shell', 'n_shaft']``.
    strip_prefix : str or None, default '2019-'
        Substring removed from the x tick labels of the interval plot.
        ``None`` or ``''`` keeps the labels unchanged.

    Notes
    -----
    Updates the global ``plt.rcParams['font.size']`` to 16.
    """
    if names is None:
        names = ['s_current', 's_shell', 's_shaft', 'n_current', 'n_shell', 'n_shaft']
    names = list(names)

    plt.rcParams.update({'font.size': 16})  # global: also affects later figures

    ### 1) prediction intervals
    _, axes = _panel_axes(len(names))
    for dim, (ax, name) in enumerate(zip(axes, names)):
        plot_single(
            input_day,
            label_day,
            forecast_day,
            context_length=CTX,
            intervals=(0.1, 0.5),
            dim=dim,
            ax=ax,
            name="pred",
            show_label=True,
        )
        ax.set_title(name)
        # Pin the current tick positions so the replacement labels stay
        # aligned with them, then drop the prefix and un-rotate the labels.
        ax.set_xticks(ax.get_xticks())
        labels = [t.get_text() for t in ax.get_xticklabels()]
        if strip_prefix:
            labels = [s.replace(strip_prefix, '') for s in labels]
        ax.set_xticklabels(labels, rotation=0, ha='center')

    # Median forecast, transposed so that pred[dim, :] is indexed like the
    # label target gt[dim, :].
    pred = forecast_day.quantile(0.5).transpose()
    gt = label_day['target']

    ### 2) exact (median) prediction vs. ground truth
    # The x axis here is integer positions, so no tick relabelling is needed.
    _, axes = _panel_axes(len(names))
    for dim, (ax, name) in enumerate(zip(axes, names)):
        ax.plot(gt[dim, :], c='b', label='gt')
        ax.plot(pred[dim, :], c='r', label='pred')
        ax.set_title(name)
        ax.legend()

    ### 3) residual
    _, axes = _panel_axes(len(names))
    for dim, (ax, name) in enumerate(zip(axes, names)):
        ax.plot(gt[dim, :] - pred[dim, :], c='k', label='gt-pred')
        ax.set_title(name)
        ax.legend()


def _stitch_windows(inputs, labels, forecasts, verbose=False):
    """Merge per-window labels and forecast samples into one long window.

    The first window's input, label and forecast are deep-copied and reused as
    containers. The label's ``'target'`` and the forecast's ``samples`` are
    replaced by the concatenation, along axis 1, over all windows.

    Returns ``(label_day, forecast_day, input_day)``.

    Raises
    ------
    ValueError
        If the iterables yield no window at all.
    """
    label_targets = []
    sample_blocks = []
    first = None
    for inp, label, forecast in zip(inputs, labels, forecasts):
        if first is None:
            # Copies, so overwriting 'target' / samples below leaves the
            # original entries untouched.
            first = (copy.deepcopy(inp), copy.deepcopy(label), copy.deepcopy(forecast))
        label_targets.append(label['target'])
        sample_blocks.append(forecast.samples)
        if verbose:
            print(forecast.start_date)

    if first is None:
        raise ValueError("no evaluation windows were produced")

    input_day, label_day, forecast_day = first
    label_day['target'] = np.concatenate(label_targets, axis=1)
    forecast_day.samples = np.concatenate(sample_blocks, axis=1)
    return label_day, forecast_day, input_day


def model_inference(df, SIZE, PDT, CTX, PSZ, BSZ, TEST, num_samples=100, verbose=False):
    """Run rolling-window Moirai forecasts over the last ``TEST`` steps of ``df``.

    Every column of ``df`` becomes one target dimension of a single
    multivariate series. The last ``TEST`` time steps are split into
    ``TEST // PDT`` non-overlapping windows of length ``PDT``. Each window is
    forecast with context length ``CTX``. All windows' labels and forecast
    samples are then concatenated into a single long window.

    Parameters
    ----------
    df : pandas.DataFrame
        Time-indexed frame, one column per series.
    SIZE : str
        Model size: one of {'small', 'base', 'large'}. Selects the
        ``Salesforce/moirai-1.0-R-{SIZE}`` checkpoint.
    PDT : int
        Prediction length per window (positive).
    CTX : int
        Context length (positive).
    PSZ : int or str
        Patch size: one of {"auto", 8, 16, 32, 64, 128}.
    BSZ : int
        Predictor batch size.
    TEST : int
        Test-set length in time steps; must be at least ``PDT``.
    num_samples : int, default 100
        Number of samples drawn per probabilistic forecast.
    verbose : bool, default False
        Print the start date of every forecast window.

    Returns
    -------
    label_day : dict
        First window's label entry, with ``'target'`` replaced by all window
        targets concatenated along axis 1 (presumably the time axis).
    forecast_day : forecast object
        First window's forecast, with ``samples`` replaced by all window
        samples concatenated along axis 1.
    input_day : dict
        First window's input (context) entry.

    Raises
    ------
    ValueError
        If ``PDT`` is not positive or ``TEST < PDT`` (no full window).
    """
    if PDT <= 0:
        raise ValueError(f"PDT must be positive, got {PDT}")
    n_windows = TEST // PDT
    if n_windows < 1:
        raise ValueError(
            f"TEST ({TEST}) must be at least PDT ({PDT}) to form one evaluation window"
        )

    ds = PandasDataset(dict(df))

    # Group the per-column series into one multivariate series.
    multivar_ds = MultivariateGrouper(len(ds))(ds)

    # Hold out the last TEST time steps; the training part is not used here.
    _, test_template = split(multivar_ds, offset=-TEST)

    # Rolling-window evaluation over the held-out part.
    test_data = test_template.generate_instances(
        prediction_length=PDT,
        windows=n_windows,
        distance=PDT,  # distance == PDT -> non-overlapping windows
    )

    model = MoiraiForecast(
        module=MoiraiModule.from_pretrained(f"Salesforce/moirai-1.0-R-{SIZE}"),
        prediction_length=PDT,
        context_length=CTX,
        patch_size=PSZ,
        num_samples=num_samples,
        target_dim=len(ds),
        feat_dynamic_real_dim=ds.num_feat_dynamic_real,
        past_feat_dynamic_real_dim=ds.num_past_feat_dynamic_real,
    )

    predictor = model.create_predictor(batch_size=BSZ)
    forecasts = predictor.predict(test_data.input)

    # Stitch every window, including the last one, into one continuous span.
    return _stitch_windows(test_data.input, test_data.label, forecasts, verbose=verbose)

In [None]:
file_name = 'sada'  # NOTE(review): not used in this cell
# NOTE(review): absolute local path; point this at your copy of the data.
DATA_PATH = '/home/xintingzhu/projects/llm_sada_series/timesfm/df19_new.csv'
df = pd.read_csv(DATA_PATH, index_col='ds', parse_dates=True)

### model settings
SIZE = "small"  # model size: choose from {'small', 'base', 'large'}
PDT = 30        # prediction length: any positive integer
CTX = 30        # context length: any positive integer
PSZ = "auto"    # patch size: choose from {"auto", 8, 16, 32, 64, 128}
BSZ = 32        # batch size: any positive integer

# The first STEPS_PER_DAY * N_TRAIN_DAYS rows are kept out of the test set;
# every later row is evaluated.
# NOTE(review): 1440 presumably means 1-minute sampling (steps per day); confirm.
STEPS_PER_DAY = 1440
N_TRAIN_DAYS = 30
TEST = len(df) - STEPS_PER_DAY * N_TRAIN_DAYS  # test set length in time steps
assert TEST >= PDT, (
    f"test set ({TEST} steps) is shorter than one prediction window ({PDT} steps)"
)

label_day, forecast_day, input_day = model_inference(df, SIZE, PDT, CTX, PSZ, BSZ, TEST)

vis_inference_performance(label_day, forecast_day, input_day, CTX)

