In [None]:
#!git clone -b update-gluonts https://github.com/time-series-foundation-models/lag-llama/
!cd lag-llama && pip install -r requirements.txt

Cloning into 'lag-llama'...
remote: Enumerating objects: 508, done.[K
remote: Counting objects: 100% (181/181), done.[K
remote: Compressing objects: 100% (68/68), done.[K
remote: Total 508 (delta 154), reused 113 (delta 113), pack-reused 327 (from 3)[K
Receiving objects: 100% (508/508), 286.90 KiB | 6.83 MiB/s, done.
Resolving deltas: 100% (253/253), done.


In [44]:
!pip install -U torch torchvision
!huggingface-cli download time-series-foundation-models/Lag-Llama lag-llama.ckpt --local-dir .


Collecting torch
  Downloading torch-2.7.1-cp39-none-macosx_11_0_arm64.whl.metadata (29 kB)
Collecting torchvision
  Downloading torchvision-0.22.1-cp39-cp39-macosx_11_0_arm64.whl.metadata (6.1 kB)
Downloading torch-2.7.1-cp39-none-macosx_11_0_arm64.whl (68.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m68.6/68.6 MB[0m [31m31.9 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hDownloading torchvision-0.22.1-cp39-cp39-macosx_11_0_arm64.whl (1.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m41.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch, torchvision
  Attempting uninstall: torch
    Found existing installation: torch 2.7.0
    Uninstalling torch-2.7.0:
      Successfully uninstalled torch-2.7.0
Successfully installed torch-2.7.1 torchvision-0.22.1
Consider using `hf_transfer` for faster downloads. This solution comes with some limitations. See https://huggingface.co/docs/huggingface_hub/hf_t

In [45]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from gluonts.dataset.pandas import PandasDataset
from gluonts.dataset.split import split
from gluonts.evaluation import make_evaluation_predictions
from lag_llama.gluon.estimator import LagLlamaEstimator

# Configuration
PREDICTION_LENGTH = 12
CONTEXT_LENGTH = 32  # LagLlama was trained with context_length=32
TEST_LENGTH = 24
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [51]:
# Generate synthetic data with explicit frequency to avoid pandas warnings
def generate_time_series_data(periods=200, n_variables=3, univariate=False):
    # Use daily frequency to avoid hourly frequency issues with LagLlama
    index = pd.date_range(start="2021-01-01", periods=periods, freq="D")
    if univariate:
        trend = np.linspace(0, 10, periods)
        seasonal = 3 * np.sin(2 * np.pi * np.arange(periods) / 7)  # Weekly seasonality
        noise = np.random.normal(0, 0.5, periods)
        data = trend + seasonal + noise
        return pd.DataFrame(data, index=index, columns=["target"])
    else:
        columns = [f"series_{i}" for i in range(n_variables)]
        data = np.random.randn(periods, n_variables)
        for i in range(n_variables):
            trend = np.linspace(0, 5 + i*2, periods)
            seasonal = (1 + i*0.5) * np.sin(2 * np.pi * np.arange(periods) / (7 + i*3))  # Different weekly patterns
            data[:, i] += trend + seasonal
        return pd.DataFrame(data, index=index, columns=columns)

df_univariate = generate_time_series_data(univariate=True)
df_multivariate = generate_time_series_data()

print(f"Univariate shape: {df_univariate.shape}")
print(f"Multivariate shape: {df_multivariate.shape}")

# Convert to GluonTS format
gluon_uni = PandasDataset(dict(df_univariate))
gluon_multi = PandasDataset(dict(df_multivariate))

# Split data
train_uni, test_template_uni = split(gluon_uni, offset=-TEST_LENGTH)
train_multi, test_template_multi = split(gluon_multi, offset=-TEST_LENGTH)


Univariate shape: (200, 1)
Multivariate shape: (200, 3)


In [52]:
def get_lag_llama_predictions(dataset, prediction_length, device, context_length=32, num_samples=100):
    """
    Get LagLlama predictions using pre-trained weights (zero-shot).
    """
    try:
        # Load pre-trained checkpoint
        ckpt = torch.load("lag-llama.ckpt", map_location=device, weights_only=False)
        estimator_args = ckpt["hyper_parameters"]["model_kwargs"]
        
        # Create estimator with pre-trained weights and fixed frequency issues
        estimator = LagLlamaEstimator(
            ckpt_path="lag-llama.ckpt",  # This loads pre-trained weights!
            prediction_length=prediction_length,
            context_length=context_length,
            # Use exact parameters from pre-trained model
            input_size=estimator_args["input_size"],
            n_layer=estimator_args["n_layer"],
            n_embd_per_head=estimator_args["n_embd_per_head"],
            n_head=estimator_args["n_head"],
            scaling=estimator_args["scaling"],
            time_feat=estimator_args["time_feat"],
            # Fix frequency issues by using empty lags_seq
            lags_seq=[],  # Disable automatic lag detection to avoid frequency errors
            batch_size=1,
            num_parallel_samples=num_samples,
            device=device,
        )
        
        # Create predictor (no training needed!)
        lightning_module = estimator.create_lightning_module()
        transformation = estimator.create_transformation()
        predictor = estimator.create_predictor(transformation, lightning_module)
        
        # Make predictions (zero-shot)
        forecast_it, ts_it = make_evaluation_predictions(
            dataset=dataset,
            predictor=predictor,
            num_samples=num_samples
        )
        
        forecasts = list(forecast_it)
        tss = list(ts_it)
        
        return forecasts, tss
        
    except Exception as e:
        print(f"LagLlama failed: {e}")
        return None, None

# LagLlama Zero-shot Predictions - Univariate
print("=== LAGLLAMA ZERO-SHOT UNIVARIATE ===")
forecasts_uni, tss_uni = get_lag_llama_predictions(
    dataset=train_uni, 
    prediction_length=PREDICTION_LENGTH, 
    device=device,
    context_length=CONTEXT_LENGTH
)

if forecasts_uni is not None:
    # Extract results (same format as Moirai)
    results_uni = {}
    for col in df_univariate.columns:
        results_uni[col] = []
    
    for forecast in forecasts_uni:
        series_name = forecast.item_id
        if hasattr(forecast, 'samples'):
            median_pred = np.median(forecast.samples, axis=0)
        else:
            median_pred = forecast.mean
        if series_name in results_uni:
            results_uni[series_name].extend(median_pred.tolist())
    
    print("Univariate Results:")
    for series_name, predictions in results_uni.items():
        print(f"  {series_name}: {predictions[:5]}...")
else:
    print("❌ Univariate prediction failed")

=== LAGLLAMA ZERO-SHOT UNIVARIATE ===
LagLlama failed: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL gluonts.torch.distributions.studentT.StudentTOutput was not an allowed global by default. Please use `torch.serialization.add_safe_globals([gluonts.torch.distributions.studentT.StudentTOutput])` or the `torch.serialization.safe_globals([gluonts.torch.distributions.studentT.StudentTOutput])` context manager to 