In [None]:
import matplotlib.pyplot as plt
from gluonts.dataset.pandas import PandasDataset
from gluonts.dataset.split import split
from peakweather.dataset import PeakWeatherDataset

from uni2ts.eval_util.plot import plot_single, plot_next_multi
from uni2ts.model.moirai2 import Moirai2Forecast, Moirai2Module

In [None]:
# Open the PeakWeather dataset for meteo stations at freq="h", aggregating
# temperature with "mean". root=None is passed through unchanged
# (presumably selects the library's default data location - TODO confirm).
ds = PeakWeatherDataset(
    root=None,
    station_type="meteo_station",
    freq="h",
    aggregation_methods={"temperature": "mean"},
    compute_uv=False,
)

# Temperature observations for 2020-01-01 .. 2020-11-30 as arrays, together
# with the accompanying mask (presumably flags valid readings - TODO confirm).
train, mask = ds.get_observations(
    parameters="temperature",
    first_date="2020-01-01",
    last_date="2020-11-30",
    return_mask=True,
    as_numpy=True,
)

# Keep only the entries whose mask sums to something positive along axis 0,
# i.e. that have at least one flagged observation in the training period.
good_stations = (mask.sum(axis=0) > 0).squeeze()

# Temperature observations for January 2021, restricted to `good_stations`.
# NOTE(review): unlike the call above this one does not pass as_numpy=True;
# confirm the returned object supports the `[:, good_stations]` indexing.
test = ds.get_observations(
    parameters="temperature",
    first_date="2021-01-01",
    last_date="2021-01-31",
)[:, good_stations]

In [None]:
# Model size: choose from {'small', 'base', 'large'}.
SIZE = "small"
# Prediction length: any positive integer.
PDT = 20
# Context length: any positive integer.
CTX = 200
# Patch size: choose from {"auto", 8, 16, 32, 64, 128}.
PSZ = "auto"
# Batch size: any positive integer.
BSZ = 32
# Test set length: any positive integer.
TEST = 100

In [None]:
# Wrap the data as a GluonTS PandasDataset, passing dict(df) so that each
# key/value pair of `df` becomes one entry.
# NOTE(review): `df` is not assigned in any cell visible above this one, so a
# fresh "Restart & Run All" presumably fails here - confirm where it should
# come from. This cell also rebinds `ds` and `train` from the loading cell.
ds = PandasDataset(dict(df))

# Hold out the final TEST time steps as the test portion.
train, test_template = split(ds, offset=-TEST)

# Rolling-window evaluation: TEST // PDT windows of PDT steps each, spaced
# PDT steps apart so that consecutive windows do not overlap.
test_data = test_template.generate_instances(
    windows=TEST // PDT,
    distance=PDT,
    prediction_length=PDT,
)

In [None]:
# Build the forecaster from the pretrained Moirai 2.0 checkpoint selected by
# SIZE (previously the name was a placeholder-less f-string fixed to "small").
model = Moirai2Forecast(
    module=Moirai2Module.from_pretrained(
        f"Salesforce/moirai-2.0-R-{SIZE}",
    ),
    # Forecast exactly PDT steps so every forecast lines up with the
    # PDT-step label windows produced by generate_instances(); the previous
    # hard-coded horizon of 100 did not match those windows.
    prediction_length=PDT,
    # NOTE(review): the configuration cell defines CTX = 200, but the model is
    # given 1680 steps of context. Kept as-is because changing it alters the
    # forecasts - confirm which value is intended.
    context_length=1680,
    target_dim=1,
    feat_dynamic_real_dim=0,
    past_feat_dynamic_real_dim=0,
)

# Batched predictor over the inputs of the rolling evaluation windows.
predictor = model.create_predictor(batch_size=BSZ)
forecasts = predictor.predict(test_data.input)

# Iterators over inputs, labels and forecasts; the plotting cell hands these
# to plot_next_multi.
input_it = iter(test_data.input)
label_it = iter(test_data.label)
forecast_it = iter(forecasts)

In [None]:
# Plot forecasts on a 2x3 grid of axes. plot_next_multi receives the three
# iterators (inputs, labels, forecasts) and the grid to draw on
# (presumably one evaluation window per axis - TODO confirm).
fig, axes = plt.subplots(2, 3, figsize=(25, 10))
plot_next_multi(
    axes,
    input_it,
    label_it,
    forecast_it,
    context_length=200,  # passed as the plot's context length
    intervals=(0.5, 0.9),
    show_label=True,
    name="pred",
    dim=None,
)