<a target="_blank" href="https://colab.research.google.com/github/NX-AI/tirex/blob/main/examples/quick_start_tirex.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open Quick Start In Colab"/>
</a>

### Install TiRex package

In [None]:
# Install the TiRex package (`tirex-ts`) with the optional extras used in this notebook; `-q` keeps pip's output quiet.
!pip install 'tirex-ts[notebooks,gluonts,hfdataset]' -q

### Imports and Load Data

In [None]:
import numpy as np
import pandas as pd
import torch
from matplotlib import pyplot as plt

from tirex import ForecastModel, load_model


def plot_forecast(ctx, quantile_fc, real_future_values=None):
    median_forecast = quantile_fc[:, 4].numpy()
    lower_bound = quantile_fc[:, 0].numpy()
    upper_bound = quantile_fc[:, 8].numpy()

    original_x = range(len(ctx))
    forecast_x = range(len(ctx), len(ctx) + len(median_forecast))

    plt.figure(figsize=(12, 6))
    plt.plot(original_x, ctx, label="Ground Truth Context", color="#4a90d9")
    if real_future_values is not None:
        original_fut_x = range(len(ctx), len(ctx) + len(real_future_values))
        plt.plot(original_fut_x, real_future_values, label="Ground Truth Future", color="#4a90d9", linestyle=":")
    plt.plot(forecast_x, median_forecast, label="Forecast (Median)", color="#d94e4e", linestyle="--")
    plt.fill_between(
        forecast_x, lower_bound, upper_bound, color="#d94e4e", alpha=0.1, label="Forecast 10% - 90% Quantiles"
    )
    plt.xlim(left=0)
    plt.legend()
    plt.grid(True)
    plt.show()


# Example series used below: `data_short` for the short-horizon example and
# `data_long` for the long-horizon example, fetched from the repository's test data.
# `data_base_url` already ends with "/", so file names are appended without another separator.
data_base_url = "https://raw.githubusercontent.com/NX-AI/tirex/refs/heads/main/tests/data/"
# `.values.reshape(-1)` flattens the loaded table into a 1D NumPy array
# (assumes each CSV holds a single value column -- TODO confirm).
data_short = pd.read_csv(f"{data_base_url}air_passengers.csv").values.reshape(-1)
data_long = pd.read_csv(f"{data_base_url}loop_seattle_5T.csv").values.reshape(-1)

### Load Model

In [None]:
# Load the TiRex forecasting model by its identifier "NX-AI/TiRex"
# (presumably resolved from the Hugging Face Hub by `load_model` -- TODO confirm).
model: ForecastModel = load_model("NX-AI/TiRex")

### Generate Forecast

In [None]:
# Short Horizon - Example
# Hold out the last 12 observations as ground truth and forecast 24 steps,
# so the forecast runs past the held-out values.
ctx_s = data_short[:-12]
future_s = data_short[-12:]
quantiles, mean = model.forecast(context=ctx_s, prediction_length=24)
# quantiles[0]: forecast of the first (and only) series in the batch
plot_forecast(ctx_s, quantiles[0], future_s)

In [None]:
# Long Horizon - Example
# Hold out the last 512 observations as ground truth and forecast 768 steps,
# so the forecast runs past the held-out values.
ctx_l = data_long[:-512]
future_l = data_long[-512:]
quantiles, mean = model.forecast(context=ctx_l, prediction_length=768)
# quantiles[0]: forecast of the first (and only) series in the batch
plot_forecast(ctx_l, quantiles[0], future_l)

### Input Options

TiRex accepts several input types for forecasting:

In [None]:
data = torch.tensor(data_short)

# Torch tensor (2D or 1D)
quantiles, means = model.forecast(context=data, prediction_length=24)
print("Predictions (Torch tensor):\n", type(quantiles), quantiles.shape)

# List of Torch tensors (List of 1D) - will be padded
list_torch_data = [data, data, data]
quantiles, means = model.forecast(context=list_torch_data, prediction_length=24, batch_size=2)
print("Predictions (List of Torch tensors):\n", type(quantiles), quantiles.shape)

# NumPy array (2D or 1D)
quantiles, means = model.forecast(context=data.numpy(), prediction_length=24, output_type="torch")
print("Predictions (NumPy):\n", type(quantiles), quantiles.shape)


# List of NumPy arrays (List of 1D) - will be padded
# np.array_split returns a list of 1D arrays: the series split into 3 sequences
list_numpy_data = np.array_split(data.numpy(), 3)
quantiles, means = model.forecast(context=list_numpy_data, prediction_length=24)
print("Predictions (List of NumPy arrays):\n", type(quantiles), quantiles.shape)


# GluonTS Dataset - needs the optional `gluonts` dependency
from typing import cast

try:
    from gluonts.dataset import Dataset
except ImportError as e:
    # To use the gluonts function you need to install the optional dependency:
    # pip install 'tirex-ts[gluonts]'
    print(e)
else:
    # Only a missing dependency is tolerated; errors from the forecast calls themselves surface normally.
    gluon_dataset = cast(Dataset, [{"target": data, "item_id": 1}, {"target": data, "item_id": 22}])
    quantiles, means = model.forecast_gluon(gluon_dataset, prediction_length=24)
    print("Predictions GluonDataset:\n", type(quantiles), quantiles.shape)
    # If you also use `gluonts` as the output type, start_time and item_id are preserved accordingly
    predictions_gluon = model.forecast_gluon(gluon_dataset, prediction_length=24, output_type="gluonts")
    print("Predictions GluonDataset:\n", type(predictions_gluon), type(predictions_gluon[0]))

### Output Options


TiRex can return forecasts in several output types:

In [None]:
data = torch.tensor(data_short)

# Default: 2D Torch tensor
quantiles, means = model.forecast(context=data, prediction_length=24, output_type="torch")
print("Predictions:\n", type(quantiles), quantiles.shape)


# 2D Numpy Array
quantiles, means = model.forecast(context=data, prediction_length=24, output_type="numpy")
print("Predictions:\n", type(quantiles), quantiles.shape)


# Iterate by batch
# You can also use the forecast function as an iterable (yield_per_batch=True), which might help
# with big datasets. All output_types are supported.
for i, fc_batch in enumerate(
    model.forecast(
        context=[data, data, data, data, data],
        prediction_length=24,
        batch_size=2,
        output_type="torch",
        yield_per_batch=True,
    )
):
    quantiles, means = fc_batch
    print(f"Predictions batch {i}:\n", type(quantiles), quantiles.shape)


# QuantileForecast (GluonTS) - needs the optional `gluonts` dependency
try:
    import gluonts  # noqa: F401 -- imported only to check that the optional dependency is installed
except ImportError as e:
    # To use the gluonts output type you need to install the optional dependency:
    # pip install 'tirex-ts[gluonts]'
    print(e)
else:
    predictions_gluonts = model.forecast(context=data, prediction_length=24, output_type="gluonts")
    # Print the variable produced in this cell (not `predictions_gluon` from the input-options cell).
    print("Predictions (GluonTS Quantile Forecast):\n", type(predictions_gluonts), type(predictions_gluonts[0]))
    predictions_gluonts[0].plot()