In [1]:
from collections import defaultdict
import pandas as pd
import numpy as np
import timesfm
import multiprocessing

import torch
import torch.nn.functional as F
import argparse
import os
import time
import math
from gluonts.dataset.pandas import PandasDataset
from gluonts.dataset.split import split
from gluonts.itertools import batcher
from utils.utils import load_test_data
from utils.data_loader import create_cached_tsmixup_datasets
from load_cached_features import *
from timesfm.pytorch_patched_decoder import ResidualBlock
import matplotlib.pyplot as plt
# from utils.utils import load_test_data
# Notebook-wide configuration.
# NOTE(review): PSZ, BSZ and TEST are not referenced in the cells visible here — confirm they are still needed.
PSZ = "auto"  # patch size: choose from {"auto", 8, 16, 32, 64, 128}
BSZ = 33  # batch size: any positive integer
TEST = 100  # test set length: any positive integer
context_len = 512  # context window length; passed to TimesFmHparams (context_len) and to load_dataset (ctx_len)
pred_len = 128  # forecast horizon; passed to TimesFmHparams (horizon_len) and to load_dataset (pred_length)
device = 'cuda:3'  # hard-coded GPU index; used for the model hparams and for moving tensors in later cells

 See https://github.com/google-research/timesfm/blob/master/README.md for updated APIs.
Loaded PyTorch TimesFM, likely because python version is 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0].


In [2]:
# Build the TimesFM forecaster; the checkpoint is referenced by its Hugging Face repo id.
tfm = timesfm.TimesFm(
    hparams=timesfm.TimesFmHparams(
        # Runtime placement and batching.
        backend='gpu',
        device=device,
        per_core_batch_size=64,
        # Forecasting window, taken from the configuration cell.
        context_len=context_len,
        horizon_len=pred_len,  # number of steps to predict
        # Architecture settings.
        # NOTE(review): presumably these must match the checkpoint loaded below — confirm.
        input_patch_len=32,
        output_patch_len=128,
        num_layers=50,
        model_dims=1280,
        use_positional_embedding=False,
        # Point forecasts are produced in 'mean' mode.
        point_forecast_mode='mean',
    ),
    checkpoint=timesfm.TimesFmCheckpoint(
        huggingface_repo_id="google/timesfm-2.0-500m-pytorch",
    ),
)

Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

In [3]:
# Load the cached TSMixup train/validation splits.
# NOTE(review): prediction_length is 96 here, but a later cell calls
# load_dataset('tsmixup', ..., pred_length=pred_len) with pred_len = 128. If each
# sample only carries 96 future values, that slice returns fewer steps than
# requested — confirm which horizon is intended.
# NOTE(review): cache_dir and processed_cache_path are absolute, machine-specific paths.
train_dataset, val_dataset = create_cached_tsmixup_datasets(
        max_samples=300000,
        context_length=512,  # same value as context_len in the configuration cell
        prediction_length=96, # 1 or 96
        num_workers=16,
        cache_dir="/extra/datalab_scratch0/ctadler/time_series_models/mechanistic_interpretability/data/tsmixup_cache/",
        processed_cache_path="/extra/datalab_scratch0/ctadler/time_series_models/mechanistic_interpretability/data/tsmixup_cache/tsmixup_processed_300000_512_96.pkl",
        batch_size=4000
    )

def load_dataset(dataset, ts=1000, pred_length=1, ctx_len=512):
    """Return ``(x, y)`` context/target tensors for the named dataset.

    Args:
        dataset: ``'tsmixup'`` reads samples from the module-level
            ``val_dataset``; any other name is treated as a CSV dataset
            located under the hard-coded calibration data directory.
        ts: Only used for ``'tsmixup'``. An int selects samples
            ``0 .. ts-1``; any other iterable is used as the sample indices.
        pred_length: Number of target steps to return per series.
        ctx_len: Number of context steps to return per series.

    Returns:
        A tuple ``(x, y)`` of 2-D tensors. For ``'tsmixup'``, ``x`` holds the
        last ``ctx_len`` values of ``past_values`` and ``y`` the first
        ``pred_length`` values of ``future_values``. For CSV datasets, ``x``
        holds the first ``ctx_len`` values of each series and ``y`` the
        ``pred_length`` values that follow.

    Warns:
        UserWarning: if the data holds fewer steps than requested, in which
            case the returned tensors are shorter than ``ctx_len`` /
            ``pred_length``. Slicing alone would hide this silently.
    """
    import warnings  # local import keeps this cell runnable on its own

    if dataset == 'tsmixup':
        indices = range(ts) if isinstance(ts, int) else ts
        # Index the dataset exactly once per sample.
        samples = [val_dataset[i] for i in indices]
        past = torch.stack([sample['past_values'] for sample in samples])
        future = torch.stack([sample['future_values'] for sample in samples])
        x = past[:, -ctx_len:]
        y = future[:, :pred_length]
    else:
        dataset_path = f"/extra/datalab_scratch0/ctadler/time_series_models/ts_foundation_calibration/data/{dataset}/y_{dataset}.csv"
        timestamp_column = "ds"

        data = pd.read_csv(
            dataset_path,
            parse_dates=[timestamp_column],
            index_col=0
        )

        # One row per series; torch.stack requires all series to have equal length.
        series = [
            torch.from_numpy(group['y'].to_numpy(np.float32))
            for _, group in data.groupby('unique_id')
        ]
        full = torch.stack(series)
        x = full[:, :ctx_len]
        y = full[:, ctx_len:ctx_len + pred_length]

    # Out-of-range slices do not raise, so report truncation explicitly.
    if x.shape[1] < ctx_len:
        warnings.warn(
            f"load_dataset: requested ctx_len={ctx_len} but only "
            f"{x.shape[1]} context steps are available for '{dataset}'",
            stacklevel=2,
        )
    if y.shape[1] < pred_length:
        warnings.warn(
            f"load_dataset: requested pred_length={pred_length} but only "
            f"{y.shape[1]} target steps are available for '{dataset}'",
            stacklevel=2,
        )
    return x, y

In [6]:
@torch.no_grad()
def process_transformer_output(model: timesfm.TimesFm, stats, model_output, output_patch_len, output_dim):
    """Turn stacked-transformer activations into rescaled outputs for the last patch.

    Args:
        model: TimesFM wrapper; its ``_model.horizon_ff_layer`` is applied to
            ``model_output``.
        stats: Per-series statistics; ``stats[..., 0]`` is added as an offset
            and ``stats[..., 1]`` is multiplied in as a scale. In the
            notebook's example call this has shape ``(64, 2)``.
        model_output: Transformer activations; the projection of this tensor
            must be 3-D. In the notebook's example call this has shape
            ``(64, 16, 1280)``.
        output_patch_len: Size of the third axis after reshaping.
        output_dim: Size of the fourth axis after reshaping.

    Returns:
        A NumPy array with the rescaled values of the last patch, i.e. index
        ``-1`` along axis 1 of the reshaped projection.
    """
    projected = model._model.horizon_ff_layer(model_output)

    # Unpacking enforces a 3-D projection before splitting its last axis in two.
    batch, num_patches, _ = projected.shape
    patches = projected.view(batch, num_patches, output_patch_len, output_dim)

    # Broadcast one offset/scale per series over the three trailing axes.
    per_series = (slice(None), None, None, None)
    mu, sigma = stats[..., 0], stats[..., 1]
    rescaled = patches * sigma[per_series] + mu[per_series]

    # NOTE(review): an earlier variant returned
    # model.ppd._reverse_transform(output_ts, stats) instead of rescaling here.
    return rescaled[:, -1].cpu().numpy()

# Sanity check: rebuilding the quantile forecasts from the stacked transformer
# output should reproduce what tfm.forecast returned directly.
x, y = load_dataset('tsmixup', 100, pred_length=pred_len, ctx_len=context_len)
batch_size = 64
context = [x[i] for i in range(batch_size)]
_, quantile_forecasts, (transformer_output, stats) = tfm.forecast(context, freq=[0] * len(context), get_stacked_transformer=True)
transformer_output = transformer_output.to(device)
stats = stats.to(device)
# 128 and 10 are the third and fourth axis sizes used when reshaping the projection;
# the printed shapes below confirm the result is (batch, 128, 10).
pred_quants = process_transformer_output(tfm, stats, transformer_output, 128, 10)
# Report a real mean squared error: a signed sum of differences can cancel to
# zero even when the two arrays disagree.
print(f"quantile shape: {quantile_forecasts.shape}, pred_quants: {pred_quants.shape}, mse: {np.mean((quantile_forecasts - pred_quants) ** 2)}")
print(transformer_output.shape, stats.shape)

quantile shape: (64, 128, 10), pred_quants: (64, 128, 10), mse: 0.0
torch.Size([64, 16, 1280]) torch.Size([64, 2])


In [None]:
# Inspect the activation range: minimum, maximum and the 0.999 quantile of the transformer output.
print(transformer_output.min(), transformer_output.max(), transformer_output.quantile(0.999))

tensor(-86.9447, device='cuda:3') tensor(163.5984, device='cuda:3') tensor(57.9059, device='cuda:3')


: 