In [10]:
from collections import defaultdict
import pandas as pd
import numpy as np
from uni2ts.model.moirai import MoiraiForecast, MoiraiModule
from uni2ts.model.moirai2 import Moirai2Forecast, Moirai2Module
import multiprocessing

import torch
from torch import Tensor
import torch.nn.functional as F
from torch import nn
from torch.distributions import (Normal, StudentT, Poisson)
from uni2ts.distribution.negative_binomial import (NegativeBinomial)
from uni2ts.distribution import (MixtureOutput, 
                                 NormalOutput, 
                                 StudentTOutput,
                                LaplaceOutput, 
                                NormalFixedScaleOutput,
                                NegativeBinomialOutput, 
                                LogNormalOutput)
from utils.data_loader import create_cached_tsmixup_datasets
from load_cached_features import *
from timesfm.pytorch_patched_decoder import ResidualBlock
import matplotlib.pyplot as plt
from tqdm import tqdm
from scipy.special import stdtrit
from scipy.stats import (poisson, nbinom)
from pytorch_forecasting.metrics.quantile import QuantileLoss

# Standard-library imports are placed after the star import above so that
# nothing exported by `load_cached_features` can shadow them.
import time
from collections.abc import Generator
from collections.abc import Iterable
from typing import Any
from datetime import datetime
from pathlib import Path
from gluonts.dataset.common import ListDataset

import einops
# from utils.utils import load_test_data
# Number of past time steps used as model context; read by the model
# construction, data loading and inference cells below.
context_len = 512
# Device string the inference cell moves tensors and the module to.
# NOTE(review): the GPU index is hard-coded — adjust for the machine in use.
device = 'cuda:3'

In [2]:
# Loading tsmixup dataset
#
# The cache settings are named once and the processed-cache filename is built
# from them, so the path can no longer drift away from the arguments passed to
# the loader (the original repeated "300000_512_128" as a literal in the path).
# NOTE(review): assumes the cache file naming convention is
# tsmixup_processed_{max_samples}_{context_length}_{prediction_length}.pkl,
# as in the original literal path — confirm against utils.data_loader.
TSMIXUP_CACHE_DIR = "/extra/datalab_scratch0/ctadler/time_series_models/mechanistic_interpretability/data/tsmixup_cache/"
TSMIXUP_MAX_SAMPLES = 300000
TSMIXUP_CONTEXT_LENGTH = 512
TSMIXUP_PREDICTION_LENGTH = 128  # 1 or 128

train_dataset, val_dataset = create_cached_tsmixup_datasets(
    max_samples=TSMIXUP_MAX_SAMPLES,
    context_length=TSMIXUP_CONTEXT_LENGTH,
    prediction_length=TSMIXUP_PREDICTION_LENGTH,
    num_workers=16,
    cache_dir=TSMIXUP_CACHE_DIR,
    # For the current settings this evaluates to exactly the original path.
    processed_cache_path=(
        f"{TSMIXUP_CACHE_DIR}tsmixup_processed_"
        f"{TSMIXUP_MAX_SAMPLES}_{TSMIXUP_CONTEXT_LENGTH}_{TSMIXUP_PREDICTION_LENGTH}.pkl"
    ),
    batch_size=4000,
)

def load_dataset(
    dataset,
    ts=1000,
    pred_length=1,
    ctx_len=512,
    source=None,
    data_root="/extra/datalab_scratch0/ctadler/time_series_models/ts_foundation_calibration/data",
):
    """Return ``(x, y)`` context/target tensors for the named dataset.

    Args:
        dataset: ``'tsmixup'`` to read samples from ``source``; any other name
            is read from ``{data_root}/{dataset}/y_{dataset}.csv``.
        ts: Only used for ``'tsmixup'``. An int selects samples
            ``0 .. ts-1``; any other iterable is used as the sample indices.
        pred_length: Number of target steps returned in ``y``.
        ctx_len: Number of context steps returned in ``x``.
        source: Indexable collection of dicts with ``'past_values'`` and
            ``'future_values'`` tensors. Defaults to the notebook-level
            ``val_dataset`` so existing calls keep working.
        data_root: Directory holding the per-dataset CSV folders.

    Returns:
        ``(x, y)``. For ``'tsmixup'``, ``x`` is the last ``ctx_len`` steps of
        each ``past_values`` and ``y`` the first ``pred_length`` steps of each
        ``future_values``. For CSV datasets, ``x`` is the first ``ctx_len``
        values of each ``unique_id`` series and ``y`` the ``pred_length``
        values that follow.

    Raises:
        RuntimeError: from ``torch.stack`` when the selected tensors do not
            share a length (e.g. a CSV series shorter than
            ``ctx_len + pred_length`` while others are longer) or when no
            samples are selected.
    """
    if dataset == 'tsmixup':
        if source is None:
            # Fall back to the validation split created earlier in the notebook.
            source = val_dataset
        indices = range(ts) if isinstance(ts, int) else ts
        samples = [source[i] for i in indices]
        past = torch.stack([sample['past_values'] for sample in samples])
        future = torch.stack([sample['future_values'] for sample in samples])
        return past[:, -ctx_len:], future[:, :pred_length]

    dataset_path = f"{data_root}/{dataset}/y_{dataset}.csv"
    timestamp_column = "ds"
    data = pd.read_csv(
        dataset_path,
        parse_dates=[timestamp_column],
        index_col=0,
    )

    # Only the first ctx_len + pred_length values of a series are ever used,
    # so cut each series to that window *before* stacking. This keeps the
    # result identical for equal-length series and additionally supports
    # series of different lengths as long as each covers the window.
    window = ctx_len + pred_length
    series = [
        torch.from_numpy(group['y'].to_numpy(np.float32)[:window])
        for _, group in data.groupby('unique_id')
    ]
    stacked = torch.stack(series)
    return stacked[:, :ctx_len], stacked[:, ctx_len:window]


In [3]:
# Forecast horizon (in time steps) requested from the model wrapper.
pred_len = 64

# Which pretrained Moirai generation to load: 1 -> moirai-1.1-R-small,
# 2 -> moirai-2.0-R-small. The original cell built the 1.1 model and then
# immediately overwrote `moirai` and `patch_size` with the 2.0 variant, so the
# first pretrained load was wasted; only the selected model is built now.
# The default (2) leaves `moirai` and `patch_size` exactly as before.
MOIRAI_VERSION = 2

if MOIRAI_VERSION == 1:
    patch_size = 32  # patch size: choose from {"auto", 8, 16, 32, 64, 128}
    moirai = MoiraiForecast(
        module=MoiraiModule.from_pretrained("Salesforce/moirai-1.1-R-small"),
        prediction_length=pred_len,
        context_length=context_len,
        patch_size=patch_size,
        num_samples=100,
        target_dim=1,
        feat_dynamic_real_dim=0,
        past_feat_dynamic_real_dim=0,
    )
else:
    # Not passed to Moirai2Forecast here; later cells read `patch_size` when
    # calling `moirai._convert` and when indexing the module output.
    # NOTE(review): presumably must match the module's own patch size — confirm.
    patch_size = 16  # patch size: choose from {"auto", 8, 16, 32, 64, 128}
    moirai = Moirai2Forecast(
        module=Moirai2Module.from_pretrained(
            "Salesforce/moirai-2.0-R-small",
        ),
        prediction_length=pred_len,
        context_length=context_len,
        target_dim=1,
        feat_dynamic_real_dim=0,
        past_feat_dynamic_real_dim=0,
    )
print(moirai.module.d_model)

384


In [None]:
# Run one batch through the raw Moirai module (bypassing the GluonTS predictor)
# and reshape its output into per-step quantile forecasts.
# The full validation split is loaded here; only the first `batch_size` rows
# are used in this cell, while `x` and `y` are read again by the predictor cell.
x, y = load_dataset('tsmixup', len(val_dataset), pred_length=pred_len, ctx_len=context_len)
batch_size = 32

# Add a trailing variate axis: (batch, time) -> (batch, time, 1).
past_target = x[:batch_size, :, None].to(device)
# NOTE(review): requires the `time` module to be in scope; `start_time` is
# only read by the commented-out timing print further down.
start_time = time.time()

with torch.no_grad():
    # `_convert` is a private helper of the forecast wrapper; it returns the
    # six tensors that the module call below consumes.
    target, observed_mask, sample_id, time_id, variate_id, prediction_mask = moirai._convert(
        patch_size=patch_size,
        past_target=past_target,                                 # B x past_time x D
        past_observed_target=torch.isfinite(past_target),               # B x past_time x D (bool)
        past_is_pad=torch.full_like(past_target[:, :, 0], False, dtype=bool),                                 # B x past_time (bool)
    )
    # patch_sizes = torch.ones_like(time_id, dtype=torch.long) * patch_size

    moirai_module = moirai.module.to(device)
    # moirai_module.get_reprs = True
    # print(f"Target: {target.shape}, observed: {observed_mask.shape}")
    # (preds, stats) = moirai_module(
    #     target,
    #     observed_mask,
    #     sample_id,
    #     time_id,
    #     variate_id,
    #     prediction_mask)
    # print(f"Took {(time.time()-start_time):.3f} seconds or {(time.time()-start_time)/batch_size:.3e} sec/sample")
    # print(preds.shape, stats.shape)
    # print(prediction_mask.shape)
    # print(preds[prediction_mask].reshape([batch_size, -1, preds.shape[-1]]).shape, stats[prediction_mask].shape)

    # NOTE(review): `get_reprs` and `training_mode` look like additions from a
    # locally patched uni2ts module — confirm they exist in the installed version.
    moirai_module.get_reprs = False
    moirai_module.eval()
    preds = moirai_module(
        target=target,
        observed_mask=observed_mask,
        sample_id=sample_id,
        time_id=time_id,
        variate_id=variate_id,
        prediction_mask=prediction_mask,
        training_mode=False)
    # Recorded output for this print: torch.Size([32, 36, 576]).
    print(preds.shape)
    # Take the token at index context_len // patch_size (32 with the current
    # settings) and split its 576 features into (pred_len=64, quantiles=9).
    # NOTE(review): both the token index and the "(pred_len quantiles)" feature
    # ordering are assumptions about the module's output layout — verify them
    # against how Moirai2Forecast itself decodes `preds`.
    forecast = einops.rearrange(preds[:,context_len//patch_size,:],
                                "B (pred_len quantiles) -> B pred_len quantiles",
                                quantiles = 9, pred_len = pred_len)

    # print(preds[prediction_mask].reshape([batch_size, -1, preds.shape[-1]])[:,].shape)
    # print(torch.all(preds[prediction_mask].reshape([batch_size, -1, preds.shape[-1]]) == preds[:,-8:]))
    # print(preds.shape)
    # print(patch_sizes.shape, patch_size)
    # print(past_target.shape)
    # print(prediction_mask[0])

# Timing / GPU-memory notes recorded from earlier manual runs:
# batch_size 64 = 7.527e-03 sec/sample 360MiB
# batch_size 256 = 8.334e-05 sec/sample
# batch_size 8192 = 6.369e-05 sec/sample 5164MiB 
# batch_size 17246 = 6.340e-05 sec/sample 11040MiB
# batch_size 32768 = 7.402e-05 sec/sample 20450MiB
# batch_size 40000 = 7.124e-05 sec/sample 22350MiB

torch.Size([32, 36, 576])
tensor([[[-4.2389e-01, -2.1102e-01, -5.2238e-02,  ...,  3.3228e-01,
           3.7784e-01,  4.2304e-01],
         [ 4.0230e-01,  2.8132e-01,  8.4823e-02,  ..., -9.2818e-01,
          -1.2928e-01,  2.1589e-01],
         [ 4.7088e-01,  8.1729e-01,  9.4729e-01,  ...,  1.2616e+00,
           1.2408e+00,  1.1349e+00],
         ...,
         [ 3.3687e+00,  3.4486e+00,  3.5951e+00,  ...,  3.5238e+00,
           3.3352e+00,  3.1638e+00],
         [ 2.9093e+00,  2.6798e+00,  2.8039e+00,  ...,  4.0414e+00,
           4.2295e+00,  4.2987e+00],
         [ 4.3955e+00,  4.4018e+00,  4.3972e+00,  ...,  3.8909e+00,
           3.6451e+00,  3.4308e+00]],

        [[-4.5164e-02, -4.3477e-02, -4.4172e-02,  ..., -2.2126e-02,
          -1.7004e-02, -2.1506e-02],
         [-3.4099e-02, -2.4940e-02, -1.8077e-02,  ...,  1.0282e-01,
          -9.6476e-03, -1.7290e-02],
         [-2.3975e-02, -3.1502e-02, -3.2562e-02,  ..., -2.2143e-02,
          -3.6243e-02, -2.0800e-02],
         ...,

In [None]:
# Reset the flag on the module created in the previous cell before using the
# predictor. NOTE(review): presumably `get_reprs` toggles returning internal
# representations on a patched module (see the commented-out `(preds, stats)`
# call in the previous cell) — confirm.
moirai_module.get_reprs = False
def torch_ds_to_gluonts_listdataset(
    dataset,
    freq: str = 'h',
    start: datetime = datetime(1970, 1, 1, 0, 0),
) -> Iterable[dict[str, Any]]:
    """Wrap a torch-style dataset as a GluonTS ``ListDataset``.

    Args:
        dataset: Sized, indexable collection whose items are dicts holding
            1-D ``'past_values'`` and ``'future_values'`` tensors.
        freq: Frequency string stored on every entry and passed to
            ``ListDataset``. Defaults to hourly, as before.
        start: Start timestamp assigned to every series. The default is the
            placeholder the original code used (1970-01-01 00:00).

    Returns:
        Whatever ``ListDataset(...)`` returns: an iterable of entries, one per
        item of ``dataset``, whose ``"target"`` is the past values followed by
        the future values.
    """
    def entries() -> Generator[dict[str, Any], None, None]:
        # Lazily build one GluonTS entry per sample.
        for i in range(len(dataset)):
            sample = dataset[i]
            # Past and future are joined into a single series.
            series = torch.cat((sample['past_values'], sample['future_values']))
            yield {
                "target": series.numpy(),  # array of shape (time,)
                "start": start,
                "freq": freq,
                "item_id": f"item_{i}",
            }

    return ListDataset(entries(), freq=freq)
# Convert the validation split to GluonTS format and run the library's own
# predictor on it, as a reference for the manual module call in the previous cell.
gluonts_ds = torch_ds_to_gluonts_listdataset(val_dataset)

# print(x.shape)
batch_size = 64
predictor = moirai.create_predictor(batch_size=batch_size)
forecasts = predictor.predict(gluonts_ds)

# `x` and `y` come from the earlier inference cell (hidden state on a fresh
# kernel unless that cell has been run first).
input_it = iter(x)
label_it = iter(y)
forecast_it = iter(forecasts)

# NOTE(review): requires the `time` module to be in scope; `start_time` is
# only read by the commented-out print inside the loop.
start_time = time.time()
# Only the first forecast is inspected: the loop prints its array shape and
# stops. The recorded output for this cell was (9, 64).
# NOTE(review): the loop variable `input` shadows the builtin of the same name
# for the rest of the kernel session, and `forecast` rebinds the tensor
# assigned in the earlier inference cell.
for i, (input, label, forecast) in enumerate(zip(input_it, label_it, forecast_it)):
    # print(f"{i} {(time.time()-start_time):.4f}")
    print(forecast.forecast_array.shape)
    break

(9, 64)
