In [12]:
from collections import defaultdict
import pandas as pd
import numpy as np
import multiprocessing

import torch
from torch import Tensor
import torch.nn.functional as F
from torch import nn
from torch.distributions import (Normal, StudentT, Poisson)
from uni2ts.distribution.negative_binomial import (NegativeBinomial)
from uni2ts.distribution import (MixtureOutput, 
                                 NormalOutput, 
                                 StudentTOutput,
                                LaplaceOutput, 
                                NormalFixedScaleOutput,
                                NegativeBinomialOutput, 
                                LogNormalOutput)
from utils.data_loader import create_cached_tsmixup_datasets
from load_cached_features import *
from timesfm.pytorch_patched_decoder import ResidualBlock
import matplotlib.pyplot as plt
from tqdm import tqdm
from scipy.special import stdtrit
from scipy.stats import (poisson, nbinom)
from pytorch_forecasting.metrics.quantile import QuantileLoss

from collections.abc import Generator
from typing import Any
from datetime import datetime
from pathlib import Path
from gluonts.dataset.common import ListDataset
# from utils.utils import load_test_data
context_len = 512
# Preferred GPU for this notebook. If that device is not present on the current
# machine, fall back to CPU (slow, but the notebook still runs end to end).
_REQUESTED_DEVICE = 'cuda:3'
if torch.cuda.is_available() and torch.cuda.device_count() > torch.device(_REQUESTED_DEVICE).index:
    device = _REQUESTED_DEVICE
else:
    device = 'cpu'
    print(f"{_REQUESTED_DEVICE} unavailable; falling back to {device}")
from chronos import ChronosBoltPipeline

In [13]:
# Load the cached TSMixup train/validation splits.
# The processed-cache filename appears to encode (max_samples, context_length,
# prediction_length), and the loader reuses that file if it already exists.
# Build the filename from the same values passed to the loader, so a changed
# setting can never silently pick up a cache built with different parameters.
TSMIXUP_MAX_SAMPLES = 300000
TSMIXUP_CONTEXT_LENGTH = 512
TSMIXUP_PREDICTION_LENGTH = 128  # 1 or 128
TSMIXUP_CACHE_DIR = "/extra/datalab_scratch0/ctadler/time_series_models/mechanistic_interpretability/data/tsmixup_cache/"

train_dataset, val_dataset = create_cached_tsmixup_datasets(
    max_samples=TSMIXUP_MAX_SAMPLES,
    context_length=TSMIXUP_CONTEXT_LENGTH,
    prediction_length=TSMIXUP_PREDICTION_LENGTH,
    num_workers=16,
    cache_dir=TSMIXUP_CACHE_DIR,
    processed_cache_path=(
        f"{TSMIXUP_CACHE_DIR}tsmixup_processed_"
        f"{TSMIXUP_MAX_SAMPLES}_{TSMIXUP_CONTEXT_LENGTH}_{TSMIXUP_PREDICTION_LENGTH}.pkl"
    ),
    batch_size=4000,
)

def load_dataset(dataset, ts=1000, pred_length=1, ctx_len=512,
                 data_dir="/extra/datalab_scratch0/ctadler/time_series_models/ts_foundation_calibration/data",
                 tsmixup_source=None):
    """Load a batch of (context, target) series for forecasting experiments.

    Args:
        dataset: ``'tsmixup'`` to draw samples from the cached TSMixup
            validation split. Any other value names a CSV dataset stored at
            ``{data_dir}/{dataset}/y_{dataset}.csv`` with columns
            ``unique_id``, ``ds`` (timestamp) and ``y`` (value).
        ts: TSMixup only. Either the number of leading samples to load (int)
            or an iterable of sample indices. Ignored for CSV datasets, which
            always return every series.
        pred_length: number of target steps returned per series.
        ctx_len: number of context steps returned per series.
        data_dir: root directory of the CSV datasets.
        tsmixup_source: indexable dataset whose items are dicts holding
            ``'past_values'`` / ``'future_values'`` tensors. Defaults to the
            notebook-level ``val_dataset``.

    Returns:
        ``(x, y)``. For CSV datasets these are float32 tensors of shape
        ``(n_series, ctx_len)`` and ``(n_series, pred_length)``. TSMixup
        context is the *last* ``ctx_len`` past values and the target is the
        *first* ``pred_length`` future values (assumes 1-D per-sample tensors;
        TODO confirm).

    Raises:
        ValueError: if a CSV series has fewer than ``ctx_len + pred_length``
            points.
    """
    if dataset == 'tsmixup':
        source = val_dataset if tsmixup_source is None else tsmixup_source
        indices = range(ts) if isinstance(ts, int) else ts
        samples = [source[i] for i in indices]
        x = torch.stack([s['past_values'] for s in samples])[:, -ctx_len:]
        y = torch.stack([s['future_values'] for s in samples])[:, :pred_length]
        return x, y

    dataset_path = f"{data_dir}/{dataset}/y_{dataset}.csv"
    timestamp_column = "ds"
    data = pd.read_csv(
        dataset_path,
        parse_dates=[timestamp_column],
        index_col=0
    )

    needed = ctx_len + pred_length
    series = []
    for series_id, rows in data.groupby('unique_id'):
        # Sort chronologically so the context window really precedes the target
        # window even if the CSV rows are not in time order. The sort is stable,
        # so rows with identical timestamps keep their file order.
        values = rows.sort_values(timestamp_column, kind='stable')['y'].to_numpy(np.float32)
        if len(values) < needed:
            raise ValueError(
                f"series {series_id!r} in {dataset!r} has {len(values)} points; "
                f"need at least ctx_len + pred_length = {needed}"
            )
        # Only the first ctx_len + pred_length points are used. Trimming here
        # also lets series of different total length be stacked.
        series.append(torch.from_numpy(values[:needed]))

    x = torch.stack(series)
    return x[:, :ctx_len], x[:, ctx_len:needed]


🚀 CREATING CACHED TSMIXUP DATASETS
📂 Found existing processed data at /extra/datalab_scratch0/ctadler/time_series_models/mechanistic_interpretability/data/tsmixup_cache/tsmixup_processed_300000_512_128.pkl
⚡ Loading preprocessed data from cache...
✅ Loaded 172,454 preprocessed samples
📅 Cache created: 2025-08-22 13:11:48

📊 DATASET SUMMARY:
  Total processed samples: 172,454
  Context length: 512
  Prediction length: 128
🔀 Shuffling data...
📈 Data split:
  Training samples: 155,208
  Validation samples: 17,246
  Train ratio: 90.0%
🏗️  Creating PyTorch datasets...
🏗️  Dataset created with 155,208 samples
📊 Augmentation: ON
📈 Dataset Statistics (from 1000 samples):
  Sequence lengths: min=640, max=2046, mean=1318
  Value ranges: min=-48.3022, max=72.0737
  Value stats: mean=0.8625, std=2.7795
🏗️  Dataset created with 17,246 samples
📊 Augmentation: OFF
📈 Dataset Statistics (from 1000 samples):
  Sequence lengths: min=640, max=2047, mean=1307
  Value ranges: min=-17.3232, max=473.9922
  Va

In [14]:
# Load the pretrained Chronos-Bolt (base) forecaster, placing it on `device`
# with bfloat16 weights.
pipeline = ChronosBoltPipeline.from_pretrained(
    "amazon/chronos-bolt-base",
    torch_dtype=torch.bfloat16,
    device_map=device,
)

In [15]:
# Run one forward pass of Chronos-Bolt on a batch of TSMixup contexts. Forward
# hooks capture the decoder hidden state and the instance-norm statistics, and
# the pass is timed.
import time

batch_size = 1024
pred_len = 64  # NOTE(review): unused below; y uses load_dataset's default pred_length=1. Confirm intent.
context_len = 512

x, y = load_dataset("tsmixup", batch_size)
x_input = x[:batch_size, -context_len:].to(device)
d_model = pipeline.model.model_dim

# CPU buffers filled in place by the forward hooks below.
decoder_out = torch.zeros(batch_size, d_model)
loc_scale = torch.zeros(batch_size, 2)

def save_decoder_hook(module, input, output):
    """Copy the decoder's ``last_hidden_state`` into ``decoder_out``."""
    decoder_out[:] = output.last_hidden_state.squeeze().detach().cpu()

def save_encoder_hook(module, input, output):
    """Copy ``output[1]`` of instance_norm (presumably (loc, scale)) into ``loc_scale``."""
    loc_scale[:] = torch.stack(output[1], dim=-1).squeeze().detach().cpu()

# Remove the hooks registered by a previous run of this cell. Without this,
# every re-run stacks another copy of each hook onto the modules.
for _handle in globals().get("hook_handles", []):
    _handle.remove()
hook_handles = [
    pipeline.model.decoder.register_forward_hook(save_decoder_hook),
    pipeline.model.instance_norm.register_forward_hook(save_encoder_hook),
]

start_time = time.perf_counter()
with torch.no_grad():  # inference only: skip building an autograd graph
    out = pipeline.model(x_input)
if torch.device(device).type == "cuda":
    # CUDA kernels launch asynchronously; wait for them before stopping the clock.
    torch.cuda.synchronize(device)
print(f"Time taken: {(time.perf_counter()-start_time):4f}")

Time taken: 0.282304


In [18]:
print(decoder_out.shape, loc_scale.shape)
print(out.quantile_preds.shape)
# Sanity checks: the decoder hook actually wrote into its zero-initialised
# buffer, and the captured activations line up with the forecasts batch-wise.
assert decoder_out.abs().sum() > 0, "decoder hook did not populate decoder_out"
assert out.quantile_preds.shape[0] == decoder_out.shape[0] == loc_scale.shape[0], "batch size mismatch"

torch.Size([1024, 768]) torch.Size([1024, 2])
torch.Size([1024, 9, 64])


In [None]:
# Scratch check that `.to(device)` moves a tensor to the configured device.
# It uses its own name so it does not overwrite the dataset tensor `x` loaded
# earlier in the notebook.
probe = torch.ones((10, 2))
print(probe.device)
probe = probe.to(device)
print(probe.device)

cpu
cuda:2
