In [9]:
import torch
from tirex import load_model, ForecastModel


import pandas as pd
import numpy as np

from utils.data_loader import create_cached_tsmixup_datasets
import time

# from utils.utils import load_test_data
# Number of most-recent time steps fed to the model as context.
# Assigned again (same value) in the forecasting cell further down.
context_len = 512
# CUDA device string.
# NOTE(review): no cell visible in this section reads `device` — confirm it is
# actually passed to the model / tensors somewhere, or remove it.
device = 'cuda:2'

In [10]:
# Build the tsmixup train/validation datasets. When the processed pickle
# already exists, the helper loads it from the cache instead of rebuilding
# (see the log output below this cell).
train_dataset, val_dataset = create_cached_tsmixup_datasets(
    # --- cache locations ---
    cache_dir="/extra/datalab_scratch0/ctadler/time_series_models/mechanistic_interpretability/data/tsmixup_cache/",
    processed_cache_path="/extra/datalab_scratch0/ctadler/time_series_models/mechanistic_interpretability/data/tsmixup_cache/tsmixup_processed_300000_512_128.pkl",
    # --- sample budget and window sizes ---
    max_samples=300000,
    context_length=512,
    prediction_length=128,  # 1 or 128
    # --- processing throughput ---
    batch_size=4000,
    num_workers=16,
)

def load_dataset(dataset, ts=1000, pred_length=1, ctx_len=512):
    """Return a ``(context, target)`` pair of tensors for the named dataset.

    Args:
        dataset: ``'tsmixup'`` reads samples from the notebook-level
            ``val_dataset``. Any other name is treated as a CSV dataset read
            from ``.../ts_foundation_calibration/data/{dataset}/y_{dataset}.csv``,
            where rows are grouped by ``unique_id`` and the ``y`` column is
            used as the series values.
        ts: Only used for ``'tsmixup'``. An int ``n`` selects samples
            ``0 .. n-1``; any other value is iterated as-is and each element
            is used as an index into ``val_dataset``. Ignored for CSV datasets
            (every ``unique_id`` group is loaded).
        pred_length: Number of target steps returned in ``y``.
        ctx_len: Number of context steps returned in ``x``.

    Returns:
        ``(x, y)``. For ``'tsmixup'``, ``x`` holds the last ``ctx_len`` steps
        of each sample's ``past_values`` and ``y`` the first ``pred_length``
        steps of its ``future_values``. For CSV datasets, ``x`` is the first
        ``ctx_len`` steps of each series and ``y`` the ``pred_length`` steps
        that follow.

    Raises:
        RuntimeError: From ``torch.stack`` when no series were selected or
            when the series to be stacked differ in length (for CSV datasets
            this now only happens if a series is shorter than
            ``ctx_len + pred_length``).
    """
    if dataset == 'tsmixup':
        indices = range(ts) if isinstance(ts, int) else ts
        # Index the dataset exactly once per sample.
        samples = [val_dataset[i] for i in indices]
        x = torch.stack([s['past_values'] for s in samples])[:, -ctx_len:]
        y = torch.stack([s['future_values'] for s in samples])[:, :pred_length]
        return x, y

    dataset_path = f"/extra/datalab_scratch0/ctadler/time_series_models/ts_foundation_calibration/data/{dataset}/y_{dataset}.csv"
    timestamp_column = "ds"

    data = pd.read_csv(
        dataset_path,
        parse_dates=[timestamp_column],
        index_col=0
    )

    # Only the first ctx_len + pred_length steps of each series are ever
    # returned, so trim every series to that window *before* stacking.
    # This keeps the result identical for equal-length series while no longer
    # failing when groups have different (but sufficient) lengths.
    window = ctx_len + pred_length
    series = [
        torch.from_numpy(group['y'].to_numpy(np.float32)[:window])
        for _, group in data.groupby('unique_id')
    ]
    stacked = torch.stack(series)

    x = stacked[:, :ctx_len]
    y = stacked[:, ctx_len:window]
    return x, y


🚀 CREATING CACHED TSMIXUP DATASETS
📂 Found existing processed data at /extra/datalab_scratch0/ctadler/time_series_models/mechanistic_interpretability/data/tsmixup_cache/tsmixup_processed_300000_512_128.pkl
⚡ Loading preprocessed data from cache...
✅ Loaded 172,454 preprocessed samples
📅 Cache created: 2025-08-22 13:11:48

📊 DATASET SUMMARY:
  Total processed samples: 172,454
  Context length: 512
  Prediction length: 128
🔀 Shuffling data...
📈 Data split:
  Training samples: 155,208
  Validation samples: 17,246
  Train ratio: 90.0%
🏗️  Creating PyTorch datasets...
🏗️  Dataset created with 155,208 samples
📊 Augmentation: ON
📈 Dataset Statistics (from 1000 samples):
  Sequence lengths: min=640, max=2046, mean=1318
  Value ranges: min=-48.3022, max=72.0737
  Value stats: mean=0.8625, std=2.7795
🏗️  Dataset created with 17,246 samples
📊 Augmentation: OFF
📈 Dataset Statistics (from 1000 samples):
  Sequence lengths: min=640, max=2047, mean=1307
  Value ranges: min=-17.3232, max=473.9922
  Va

In [11]:
# Load the TiRex forecasting model identified by "NX-AI/TiRex".
# NOTE(review): the `device` string defined in the first cell is not passed
# here — confirm which device the model actually ends up on.
model: ForecastModel = load_model("NX-AI/TiRex")

  @conditional_decorator(
  @conditional_decorator(


In [12]:
# Forecast one batch of tsmixup validation series with TiRex and time the call.
batch_size = 1024
# `y` is loaded with load_dataset's default pred_length (1); it is not used
# by the forecast call below.
x, y = load_dataset("tsmixup", batch_size)
pred_len = 128
context_len = 512
x_input = x[:batch_size, -context_len:]
d_model = model.model_config.block_kwargs.embedding_dim
patch_size = model.model_config.output_patch_size

# Buffers handed to model.forecast through get_hidden_states / get_loc_scale.
# NOTE(review): presumably filled in place by forecast(), and sized assuming
# pred_len is a multiple of patch_size — confirm against the model code.
out_patches = pred_len // patch_size
decoder_out = torch.zeros(batch_size, out_patches, d_model)
loc_scale = torch.zeros((batch_size, 2))

start_time = time.time()
forecast = model.forecast(
    context=x_input,
    prediction_length=pred_len,
    batch_size=batch_size,
    max_accelerated_rollout_steps=4,
    get_loc_scale=loc_scale,
    get_hidden_states=decoder_out,
)
# `.4f` = four decimal places. The previous `4f` spec only set a minimum
# field width of 4 and left the precision at the default six digits.
print(f"Time taken: {(time.time()-start_time):.4f}")

Time taken: 0.552509


In [None]:
# Shapes of the first two entries returned by model.forecast.
print(*(forecast[k].shape for k in (0, 1)))
# print(loc_scale[:5,:])
# Shapes of the buffers that were passed into model.forecast.
print(*(buf.shape for buf in (decoder_out, loc_scale)))

cpu torch.Size([1024, 128])
torch.Size([1024, 4, 512]) torch.Size([1024, 2])


In [14]:
from torch import Tensor, nn
import einops