In [1]:
!git clone -b update-gluonts https://github.com/time-series-foundation-models/lag-llama/


Cloning into 'lag-llama'...
remote: Enumerating objects: 508, done.[K
remote: Counting objects: 100% (183/183), done.[K
remote: Compressing objects: 100% (69/69), done.[K
remote: Total 508 (delta 155), reused 114 (delta 114), pack-reused 325 (from 3)[K
Receiving objects: 100% (508/508), 286.89 KiB | 5.52 MiB/s, done.
Resolving deltas: 100% (253/253), done.


In [2]:
cd lag-llama

/Users/farhanmashrur/Desktop/cds/benchmark/foundation_model_notebooks/lag-llama


In [3]:
!pip install -r requirements.txt
!pip install -U torch torchvision
!huggingface-cli download time-series-foundation-models/Lag-Llama lag-llama.ckpt --local-dir .



Downloading 'lag-llama.ckpt' to '.cache/huggingface/download/lag-llama.ckpt.b5a5c4b8a0cfe9b81bdac35ed5d88b5033cd119b5206c28e9cd67c4b45fb2c96.incomplete'
lag-llama.ckpt: 100%|██████████████████████| 29.5M/29.5M [00:00<00:00, 47.1MB/s]
Download complete. Moving file to lag-llama.ckpt
lag-llama.ckpt


In [2]:
# Fix for numpy/pandas compatibility issue
# Run this first to fix the environment

# Option 1: Quick fix - reinstall numpy and pandas
!pip install --upgrade --force-reinstall numpy pandas

# Option 2: For Python 3.12, use compatible versions
# !pip install numpy>=1.26.0 pandas>=2.1.0 --force-reinstall

# Option 3: If still having issues, reinstall with no dependencies first
# !pip uninstall numpy pandas -y
# !pip install numpy pandas

# Restart kernel after running the above commands
print("Please restart your kernel after running the dependency fixes above")

# ============================================================================
# AFTER RESTARTING KERNEL, RUN THE CODE BELOW
# ============================================================================

# Installation (run these after fixing numpy/pandas)
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install gluonts[torch]
!pip install matplotlib
!pip install git+https://github.com/time-series-foundation-models/lag-llama.git

# Imports - run after fixing dependencies
import pandas as pd
import numpy as np
import torch
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

# Check versions to ensure compatibility
print(f"NumPy version: {np.__version__}")
print(f"Pandas version: {pd.__version__}")
print(f"PyTorch version: {torch.__version__}")

# Test basic functionality
test_df = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
print("Basic pandas test successful!")
print(test_df)

# Setup device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Import GluonTS and Lag-Llama after fixing base dependencies
try:
    from gluonts.dataset.pandas import PandasDataset
    from gluonts.evaluation import make_evaluation_predictions
    print("GluonTS imported successfully!")
except ImportError as e:
    print(f"GluonTS import error: {e}")
    print("Try: pip install gluonts[torch] --upgrade")

try:
    from lag_llama.gluon.estimator import LagLlamaEstimator
    print("Lag-Llama imported successfully!")
except ImportError as e:
    print(f"Lag-Llama import error: {e}")
    print("Try: pip install git+https://github.com/time-series-foundation-models/lag-llama.git")

# Simple Lag-Llama wrapper class (same as before)
class LagLlamaForecaster:
    def __init__(self, device="auto", context_length=128, num_samples=100):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device == "auto" else device
        self.context_length = context_length
        self.num_samples = num_samples
        self.predictor = None
        
    def _create_predictor(self, prediction_length):
        """Create Lag-Llama predictor"""
        # You'll need to download the checkpoint first
        ckpt_path = "lag-llama/lag-llama.ckpt"
        
        try:
            ckpt = torch.load(ckpt_path, map_location=self.device, weights_only=False)
        except FileNotFoundError:
            print("Checkpoint not found! Please download lag-llama.ckpt")
            print("You can get it from: https://huggingface.co/time-series-foundation-models/Lag-Llama")
            raise
            
        estimator_args = ckpt["hyper_parameters"]["model_kwargs"]
        
        # RoPE scaling for longer contexts
        rope_scaling_arguments = {
            "type": "linear", 
            "factor": max(1.0, (self.context_length + prediction_length) / estimator_args["context_length"]),
        }
        
        estimator = LagLlamaEstimator(
            ckpt_path=ckpt_path,
            prediction_length=prediction_length,
            context_length=self.context_length,
            input_size=estimator_args["input_size"],
            n_layer=estimator_args["n_layer"],
            n_embd_per_head=estimator_args["n_embd_per_head"],
            n_head=estimator_args["n_head"],
            scaling=estimator_args["scaling"],
            time_feat=estimator_args["time_feat"],
            rope_scaling=rope_scaling_arguments,
            batch_size=1,
            num_parallel_samples=self.num_samples,
            device=self.device,
        )
        
        lightning_module = estimator.create_lightning_module()
        transformation = estimator.create_transformation()
        return estimator.create_predictor(transformation, lightning_module)
    
    def forecast_on_df(self, data, value_name="y", prediction_length=30, freq="D"):
        """Forecast on pandas DataFrame - similar interface to TimesFM"""
        # Prepare data
        df = data.copy()
        df['ds'] = pd.to_datetime(df['ds'])
        df = df.sort_values(['unique_id', 'ds'] if 'unique_id' in df.columns else ['ds'])
        
        # Convert to float32 for Lag-Llama
        df[value_name] = df[value_name].astype('float32')
        
        # Create dataset
        if 'unique_id' in df.columns:
            # Multiple series
            dataset = PandasDataset.from_long_dataframe(
                df, target=value_name, item_id="unique_id", timestamp="ds", freq=freq
            )
        else:
            # Single series
            df_dict = {0: df.set_index('ds')[value_name]}
            dataset = PandasDataset(df_dict, target=value_name)
        
        # Create predictor
        predictor = self._create_predictor(prediction_length)
        
        # Generate forecasts
        forecast_it, ts_it = make_evaluation_predictions(
            dataset=dataset,
            predictor=predictor,
            num_samples=self.num_samples
        )
        
        forecasts = list(forecast_it)
        
        # Convert to DataFrame format
        forecast_results = []
        
        for i, forecast in enumerate(forecasts):
            # Get unique_id
            unique_id = forecast.item_id if hasattr(forecast, 'item_id') else f"series_{i}"
            
            # Generate future dates
            last_date = forecast.start_date.to_timestamp()
            future_dates = pd.date_range(
                start=last_date, 
                periods=prediction_length, 
                freq=freq
            )
            
            # Get mean prediction and quantiles
            mean_forecast = forecast.mean
            q10 = forecast.quantile(0.1)
            q90 = forecast.quantile(0.9)
            
            for j, date in enumerate(future_dates):
                forecast_results.append({
                    'unique_id': unique_id,
                    'ds': date,
                    f'{value_name}_forecast': mean_forecast[j],
                    f'{value_name}_q10': q10[j],
                    f'{value_name}_q90': q90[j]
                })
        
        return pd.DataFrame(forecast_results)

# Test with sample data (run after fixing dependencies)
def test_with_sample_data():
    # Create sample data
    start_date = datetime.now()
    num_days = 365
    dates = [start_date + timedelta(days=i) for i in range(num_days)]
    sales = [100 + i * 10 + np.random.normal(0, 10) for i in range(num_days)]
    
    data = pd.DataFrame({
        'ds': dates,
        'sales': sales,
        'unique_id': "sales"
    })
    
    print("Sample data created:")
    print(data.head())
    print(f"Data shape: {data.shape}")
    
    # Initialize model (only if checkpoint is available)
    try:
        tfm = LagLlamaForecaster(
            device="auto",
            context_length=128,
            num_samples=100
        )
        print("LagLlamaForecaster initialized successfully!")
        
        # Note: Actual forecasting requires the checkpoint file
        print("To run forecasts, you need to download lag-llama.ckpt first")
        
    except Exception as e:
        print(f"Error initializing forecaster: {e}")
    
    return data

# Run test
if __name__ == "__main__":
    try:
        sample_data = test_with_sample_data()
    except Exception as e:
        print(f"Error in test: {e}")
        print("Please restart kernel and run dependency fixes first")

[0mCollecting numpy
  Using cached numpy-2.3.2-cp311-cp311-macosx_14_0_arm64.whl.metadata (62 kB)
Collecting pandas
  Using cached pandas-2.3.1-cp311-cp311-macosx_11_0_arm64.whl.metadata (91 kB)
Collecting python-dateutil>=2.8.2 (from pandas)
  Using cached python_dateutil-2.9.0.post0-py2.py3-none-any.whl.metadata (8.4 kB)
Collecting pytz>=2020.1 (from pandas)
  Using cached pytz-2025.2-py2.py3-none-any.whl.metadata (22 kB)
Collecting tzdata>=2022.7 (from pandas)
  Using cached tzdata-2025.2-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting six>=1.5 (from python-dateutil>=2.8.2->pandas)
  Using cached six-1.17.0-py2.py3-none-any.whl.metadata (1.7 kB)
Using cached numpy-2.3.2-cp311-cp311-macosx_14_0_arm64.whl (5.4 MB)
Using cached pandas-2.3.1-cp311-cp311-macosx_11_0_arm64.whl (10.8 MB)
Using cached python_dateutil-2.9.0.post0-py2.py3-none-any.whl (229 kB)
Using cached pytz-2025.2-py2.py3-none-any.whl (509 kB)
Using cached tzdata-2025.2-py2.py3-none-any.whl (347 kB)
Using cached six-1.

In [24]:


!pip install --upgrade --force-reinstall numpy pandas

# Option 2: For Python 3.12, use compatible versions
# !pip install numpy>=1.26.0 pandas>=2.1.0 --force-reinstall

# Option 3: If still having issues, reinstall with no dependencies first
# !pip uninstall numpy pandas -y
# !pip install numpy pandas

# Restart kernel after running the above commands
print("Please restart your kernel after running the dependency fixes above")

# ============================================================================
# AFTER RESTARTING KERNEL, RUN THE CODE BELOW
# ============================================================================

# Installation (run these after fixing numpy/pandas)
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install gluonts[torch]
!pip install matplotlib
!pip install git+https://github.com/time-series-foundation-models/lag-llama.git

# Imports - run after fixing dependencies
import pandas as pd
import numpy as np
import torch
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

# Check versions to ensure compatibility
print(f"NumPy version: {np.__version__}")
print(f"Pandas version: {pd.__version__}")
print(f"PyTorch version: {torch.__version__}")

# Test basic functionality
test_df = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
print("Basic pandas test successful!")
print(test_df)

# Setup device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Import GluonTS and Lag-Llama after fixing base dependencies
try:
    from gluonts.dataset.pandas import PandasDataset
    from gluonts.evaluation import make_evaluation_predictions
    print("GluonTS imported successfully!")
except ImportError as e:
    print(f"GluonTS import error: {e}")
    print("Try: pip install gluonts[torch] --upgrade")

try:
    from lag_llama.gluon.estimator import LagLlamaEstimator
    print("Lag-Llama imported successfully!")
except ImportError as e:
    print(f"Lag-Llama import error: {e}")
    print("Try: pip install git+https://github.com/time-series-foundation-models/lag-llama.git")

# Simple Lag-Llama wrapper class (same as before)
class LagLlamaForecaster:
    def __init__(self, device="auto", context_length=128, num_samples=100):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device == "auto" else device
        self.context_length = context_length
        self.num_samples = num_samples
        self.predictor = None
        
    def _create_predictor(self, prediction_length):
        """Create Lag-Llama predictor"""
        # You'll need to download the checkpoint first
        ckpt_path = "lag-llama/lag-llama.ckpt"
        
        try:
            ckpt = torch.load(ckpt_path, map_location=self.device, weights_only=False)
        except FileNotFoundError:
            print("Checkpoint not found! Please download lag-llama.ckpt")
            print("You can get it from: https://huggingface.co/time-series-foundation-models/Lag-Llama")
            raise
            
        estimator_args = ckpt["hyper_parameters"]["model_kwargs"]
        
        # RoPE scaling for longer contexts
        rope_scaling_arguments = {
            "type": "linear", 
            "factor": max(1.0, (self.context_length + prediction_length) / estimator_args["context_length"]),
        }
        
        estimator = LagLlamaEstimator(
            ckpt_path=ckpt_path,
            prediction_length=prediction_length,
            context_length=self.context_length,
            input_size=estimator_args["input_size"],
            n_layer=estimator_args["n_layer"],
            n_embd_per_head=estimator_args["n_embd_per_head"],
            n_head=estimator_args["n_head"],
            scaling=estimator_args["scaling"],
            time_feat=estimator_args["time_feat"],
            rope_scaling=rope_scaling_arguments,
            batch_size=1,
            num_parallel_samples=self.num_samples,
            device=self.device,
        )
        
        lightning_module = estimator.create_lightning_module()
        transformation = estimator.create_transformation()
        return estimator.create_predictor(transformation, lightning_module)
    
    def forecast_on_df(self, data, value_name="y", prediction_length=30, freq="D"):
        """Forecast on pandas DataFrame - similar interface to TimesFM"""
        # Prepare data
        df = data.copy()
        df['ds'] = pd.to_datetime(df['ds'])
        df = df.sort_values(['unique_id', 'ds'] if 'unique_id' in df.columns else ['ds'])
        
        # Convert to float32 for Lag-Llama
        df[value_name] = df[value_name].astype('float32')
        
        # Create dataset
        if 'unique_id' in df.columns:
            # Multiple series
            dataset = PandasDataset.from_long_dataframe(
                df, target=value_name, item_id="unique_id", timestamp="ds", freq=freq
            )
        else:
            # Single series
            df_dict = {0: df.set_index('ds')[value_name]}
            dataset = PandasDataset(df_dict, target=value_name)
        
        # Create predictor
        predictor = self._create_predictor(prediction_length)
        
        # Generate forecasts
        forecast_it, ts_it = make_evaluation_predictions(
            dataset=dataset,
            predictor=predictor,
            num_samples=self.num_samples
        )
        
        forecasts = list(forecast_it)
        
        # Convert to DataFrame format
        forecast_results = []
        
        for i, forecast in enumerate(forecasts):
            # Get unique_id
            unique_id = forecast.item_id if hasattr(forecast, 'item_id') else f"series_{i}"
            
            # Generate future dates
            last_date = forecast.start_date.to_timestamp()
            future_dates = pd.date_range(
                start=last_date, 
                periods=prediction_length, 
                freq=freq
            )
            
            # Get mean prediction and quantiles
            mean_forecast = forecast.mean
            q10 = forecast.quantile(0.1)
            q90 = forecast.quantile(0.9)
            
            for j, date in enumerate(future_dates):
                forecast_results.append({
                    'unique_id': unique_id,
                    'ds': date,
                    f'{value_name}_forecast': mean_forecast[j],
                    f'{value_name}_q10': q10[j],
                    f'{value_name}_q90': q90[j]
                })
        
        return pd.DataFrame(forecast_results)

# Test with sample data (run after fixing dependencies)
from pathlib import Path

def test_with_sample_data():
    # Create sample data
    start_date = datetime.now()
    num_days = 365
    dates = [start_date + timedelta(days=i) for i in range(num_days)]
    sales = [100 + i * 10 + np.random.normal(0, 10) for i in range(num_days)]
    
    data = pd.DataFrame({
        'ds': dates,
        'sales': sales,
        'unique_id': "sales"
    })
    
    print("Sample data created:")
    print(data.head())
    print(f"Data shape: {data.shape}")
    
    # Initialize model (only if checkpoint is available)
    ckpt_path = Path("lag-llama/lag-llama.ckpt")
    
    if not ckpt_path.exists():
        print("To run forecasts, you need to download lag-llama.ckpt first")
        print("Visit: https://huggingface.co/time-series-foundation-models/Lag-Llama")
        return data

    try:
        tfm = LagLlamaForecaster(
            device="auto",
            context_length=128,
            num_samples=100
        )
        print("LagLlamaForecaster initialized successfully!")
    except Exception as e:
        print(f"Error initializing forecaster: {e}")
    
    return data

# Run test
if __name__ == "__main__":
    try:
        sample_data = test_with_sample_data()
    except Exception as e:
        print(f"Error in test: {e}")
        print("Please restart kernel and run dependency fixes first")

[0mCollecting numpy
  Using cached numpy-2.3.2-cp311-cp311-macosx_14_0_arm64.whl.metadata (62 kB)
Collecting pandas
  Using cached pandas-2.3.1-cp311-cp311-macosx_11_0_arm64.whl.metadata (91 kB)
Collecting python-dateutil>=2.8.2 (from pandas)
  Using cached python_dateutil-2.9.0.post0-py2.py3-none-any.whl.metadata (8.4 kB)
Collecting pytz>=2020.1 (from pandas)
  Using cached pytz-2025.2-py2.py3-none-any.whl.metadata (22 kB)
Collecting tzdata>=2022.7 (from pandas)
  Using cached tzdata-2025.2-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting six>=1.5 (from python-dateutil>=2.8.2->pandas)
  Using cached six-1.17.0-py2.py3-none-any.whl.metadata (1.7 kB)
Using cached numpy-2.3.2-cp311-cp311-macosx_14_0_arm64.whl (5.4 MB)
Using cached pandas-2.3.1-cp311-cp311-macosx_11_0_arm64.whl (10.8 MB)
Using cached python_dateutil-2.9.0.post0-py2.py3-none-any.whl (229 kB)
Using cached pytz-2025.2-py2.py3-none-any.whl (509 kB)
Using cached tzdata-2025.2-py2.py3-none-any.whl (347 kB)
Using cached six-1.

In [22]:
from pathlib import Path

ckpt_path = Path( "lag-llama/lag-llama.ckpt") 
print(ckpt_path.exists())


True


In [25]:

import pandas as pd
import numpy as np
import torch
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

from gluonts.dataset.pandas import PandasDataset
from gluonts.evaluation import make_evaluation_predictions
from lag_llama.gluon.estimator import LagLlamaEstimator

# TimesFM-style Lag-Llama interface
class LagLlama:
    def __init__(self, 
                 device="auto",
                 context_len=128,
                 horizon_len=30,
                 num_samples=100,
                 checkpoint_path="lag-llama/lag-llama.ckpt"):
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device == "auto" else device
        self.context_len = context_len
        self.horizon_len = horizon_len
        self.num_samples = num_samples
        self.checkpoint_path = checkpoint_path
        
        print(f"Initialized Lag-Llama with device: {self.device}")
        print(f"Context length: {self.context_len}, Horizon: {self.horizon_len}")
    
    def forecast_on_df(self, inputs, freq="D", value_name="y", num_jobs=-1):
        """
        Forecast on DataFrame - TimesFM-style interface
        
        Args:
            inputs: DataFrame with ['ds', value_name, 'unique_id'] columns
            freq: Frequency string ('D', 'H', 'M', etc.)
            value_name: Name of target column
            num_jobs: Ignored (for compatibility)
            
        Returns:
            DataFrame with forecasts
        """
        
        # Prepare data
        df = inputs.copy()
        df['ds'] = pd.to_datetime(df['ds'])
        df = df.sort_values(['unique_id', 'ds'] if 'unique_id' in df.columns else ['ds'])
        df[value_name] = df[value_name].astype('float32')
        
        # Create GluonTS dataset
        if 'unique_id' in df.columns:
            dataset = PandasDataset.from_long_dataframe(
                df, target=value_name, item_id="unique_id", timestamp="ds", freq=freq
            )
        else:
            df_dict = {0: df.set_index('ds')[value_name]}
            dataset = PandasDataset(df_dict, target=value_name)
        
        # Load checkpoint and create predictor
        ckpt = torch.load(self.checkpoint_path, map_location=self.device, weights_only=False)
        estimator_args = ckpt["hyper_parameters"]["model_kwargs"]
        
        # RoPE scaling for longer contexts
        rope_scaling_arguments = {
            "type": "linear", 
            "factor": max(1.0, (self.context_len + self.horizon_len) / estimator_args["context_length"]),
        }
        
        estimator = LagLlamaEstimator(
            ckpt_path=self.checkpoint_path,
            prediction_length=self.horizon_len,
            context_length=self.context_len,
            input_size=estimator_args["input_size"],
            n_layer=estimator_args["n_layer"],
            n_embd_per_head=estimator_args["n_embd_per_head"],
            n_head=estimator_args["n_head"],
            scaling=estimator_args["scaling"],
            time_feat=estimator_args["time_feat"],
            rope_scaling=rope_scaling_arguments,
            batch_size=1,
            num_parallel_samples=self.num_samples,
            device=self.device,
        )
        
        lightning_module = estimator.create_lightning_module()
        transformation = estimator.create_transformation()
        predictor = estimator.create_predictor(transformation, lightning_module)
        
        # Generate forecasts
        forecast_it, ts_it = make_evaluation_predictions(
            dataset=dataset,
            predictor=predictor,
            num_samples=self.num_samples
        )
        
        forecasts = list(forecast_it)
        
        # Convert to DataFrame (TimesFM format)
        forecast_results = []
        
        for i, forecast in enumerate(forecasts):
            unique_id = forecast.item_id if hasattr(forecast, 'item_id') else f"series_{i}"
            
            # Generate future dates
            last_date = forecast.start_date.to_timestamp()
            future_dates = pd.date_range(
                start=last_date, 
                periods=self.horizon_len, 
                freq=freq
            )
            
            # Get predictions
            mean_forecast = forecast.mean
            q10 = forecast.quantile(0.1)
            q90 = forecast.quantile(0.9)
            
            for j, date in enumerate(future_dates):
                forecast_results.append({
                    'unique_id': unique_id,
                    'ds': date,
                    f'{value_name}': mean_forecast[j],  # TimesFM format
                    f'{value_name}_q10': q10[j],
                    f'{value_name}_q90': q90[j]
                })
        
        return pd.DataFrame(forecast_results)

# Initialize model (exactly like TimesFM)
tfm = LagLlama(
    device="auto",
    context_len=128,
    horizon_len=30,
    num_samples=100,
    checkpoint_path="lag-llama/lag-llama.ckpt"
)

# Univariate example (exactly like TimesFM)
start_date = datetime.now()
num_days = 365
dates = [start_date + timedelta(days=i) for i in range(num_days)]
sales = [100 + i * 10 + np.random.normal(0, 10) for i in range(num_days)]

data = pd.DataFrame({
    'ds': dates,
    'sales': sales,
    'unique_id': "sales"
})

# Multivariate example (other variables get ignored)
start_date = datetime.now()
num_days = 365
dates = [start_date + timedelta(days=i) for i in range(num_days)]
sales = [100 + i * 10 + np.random.normal(0, 10) for i in range(num_days)]
ad_spend = [50 + 20 * np.sin(2 * np.pi * i / 30) + np.random.normal(0, 5) for i in range(num_days)]
temperature = [20 + 10 * np.sin(2 * np.pi * i / 365) + np.random.normal(0, 2) for i in range(num_days)]

data = pd.DataFrame({
    'ds': dates,
    'sales': sales,
    'ad_spend': ad_spend,
    'temperature': temperature,
    'unique_id': "sales"
})

# Forecast (exactly like TimesFM)
import time
start_time = time.time()

forecast_df = tfm.forecast_on_df(
    inputs=data,
    freq="D",
    value_name="sales",
    num_jobs=-1,
)

print(f"Wall time: {time.time() - start_time:.2f}s")

# Display results (exactly like TimesFM)
print("Original data:")
print(data.head())
print(f"Shape: {data.shape}")

print("\nForecast:")
print(forecast_df.head())
print(f"Shape: {forecast_df.shape}")

# Plot results
plt.figure(figsize=(12, 6))
plt.plot(data['ds'][-60:], data['sales'][-60:], label='Historical', color='blue')
plt.plot(forecast_df['ds'], forecast_df['sales'], label='Forecast', color='red')
plt.fill_between(forecast_df['ds'], 
                 forecast_df['sales_q10'], 
                 forecast_df['sales_q90'], 
                 alpha=0.3, color='red', label='80% Confidence')
plt.legend()
plt.title('Sales Forecast with Lag-Llama')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

Initialized Lag-Llama with device: cpu
Context length: 128, Horizon: 30


UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL gluonts.torch.distributions.studentT.StudentTOutput was not an allowed global by default. Please use `torch.serialization.add_safe_globals([gluonts.torch.distributions.studentT.StudentTOutput])` or the `torch.serialization.safe_globals([gluonts.torch.distributions.studentT.StudentTOutput])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.

In [26]:
# Lag-Llama with TimesFM-style interface

# Install dependencies (run once)
!pip install torch torchvision torchaudio
!pip install gluonts[torch]
!pip install matplotlib pandas numpy
!pip install git+https://github.com/time-series-foundation-models/lag-llama.git

# Download checkpoint (run once)
# Download lag-llama.ckpt from: https://huggingface.co/time-series-foundation-models/Lag-Llama
# Place it in: lag-llama/lag-llama.ckpt

import pandas as pd
import numpy as np
import torch
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

from gluonts.dataset.pandas import PandasDataset
from gluonts.evaluation import make_evaluation_predictions
from lag_llama.gluon.estimator import LagLlamaEstimator

# TimesFM-style Lag-Llama interface
class LagLlama:
    def __init__(self, 
                 device="auto",
                 context_len=128,
                 horizon_len=30,
                 num_samples=100,
                 checkpoint_path="lag-llama/lag-llama.ckpt"):
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device == "auto" else device
        self.context_len = context_len
        self.horizon_len = horizon_len
        self.num_samples = num_samples
        self.checkpoint_path = checkpoint_path
        
        print(f"Initialized Lag-Llama with device: {self.device}")
        print(f"Context length: {self.context_len}, Horizon: {self.horizon_len}")
    
    def forecast_on_df(self, inputs, freq="D", value_name="y", num_jobs=-1):
        """
        Forecast on DataFrame - TimesFM-style interface
        
        Args:
            inputs: DataFrame with ['ds', value_name, 'unique_id'] columns
            freq: Frequency string ('D', 'H', 'M', etc.)
            value_name: Name of target column
            num_jobs: Ignored (for compatibility)
            
        Returns:
            DataFrame with forecasts
        """
        
        # Prepare data
        df = inputs.copy()
        df['ds'] = pd.to_datetime(df['ds'])
        df = df.sort_values(['unique_id', 'ds'] if 'unique_id' in df.columns else ['ds'])
        df[value_name] = df[value_name].astype('float32')
        
        # Create GluonTS dataset
        if 'unique_id' in df.columns:
            dataset = PandasDataset.from_long_dataframe(
                df, target=value_name, item_id="unique_id", timestamp="ds", freq=freq
            )
        else:
            df_dict = {0: df.set_index('ds')[value_name]}
            dataset = PandasDataset(df_dict, target=value_name)
        
        # Load checkpoint and create predictor (fix for PyTorch 2.6)
        import torch.serialization
        torch.serialization.add_safe_globals([
            'gluonts.torch.distributions.studentT.StudentTOutput',
            'gluonts.torch.distributions.neg_binomial.NegativeBinomialOutput',
            'gluonts.torch.distributions.normal.NormalOutput'
        ])
        ckpt = torch.load(self.checkpoint_path, map_location=self.device, weights_only=False)
        estimator_args = ckpt["hyper_parameters"]["model_kwargs"]
        
        # RoPE scaling for longer contexts
        rope_scaling_arguments = {
            "type": "linear", 
            "factor": max(1.0, (self.context_len + self.horizon_len) / estimator_args["context_length"]),
        }
        
        estimator = LagLlamaEstimator(
            ckpt_path=self.checkpoint_path,
            prediction_length=self.horizon_len,
            context_length=self.context_len,
            input_size=estimator_args["input_size"],
            n_layer=estimator_args["n_layer"],
            n_embd_per_head=estimator_args["n_embd_per_head"],
            n_head=estimator_args["n_head"],
            scaling=estimator_args["scaling"],
            time_feat=estimator_args["time_feat"],
            rope_scaling=rope_scaling_arguments,
            batch_size=1,
            num_parallel_samples=self.num_samples,
            device=self.device,
        )
        
        lightning_module = estimator.create_lightning_module()
        transformation = estimator.create_transformation()
        predictor = estimator.create_predictor(transformation, lightning_module)
        
        # Generate forecasts
        forecast_it, ts_it = make_evaluation_predictions(
            dataset=dataset,
            predictor=predictor,
            num_samples=self.num_samples
        )
        
        forecasts = list(forecast_it)
        
        # Convert to DataFrame (TimesFM format)
        forecast_results = []
        
        for i, forecast in enumerate(forecasts):
            unique_id = forecast.item_id if hasattr(forecast, 'item_id') else f"series_{i}"
            
            # Generate future dates
            last_date = forecast.start_date.to_timestamp()
            future_dates = pd.date_range(
                start=last_date, 
                periods=self.horizon_len, 
                freq=freq
            )
            
            # Get predictions
            mean_forecast = forecast.mean
            q10 = forecast.quantile(0.1)
            q90 = forecast.quantile(0.9)
            
            for j, date in enumerate(future_dates):
                forecast_results.append({
                    'unique_id': unique_id,
                    'ds': date,
                    f'{value_name}': mean_forecast[j],  # TimesFM format
                    f'{value_name}_q10': q10[j],
                    f'{value_name}_q90': q90[j]
                })
        
        return pd.DataFrame(forecast_results)

# Initialize model (exactly like TimesFM)
tfm = LagLlama(
    device="auto",
    context_len=128,
    horizon_len=30,
    num_samples=100,
    checkpoint_path="lag-llama/lag-llama.ckpt"
)

# Univariate example (exactly like TimesFM)
start_date = datetime.now()
num_days = 365
dates = [start_date + timedelta(days=i) for i in range(num_days)]
sales = [100 + i * 10 + np.random.normal(0, 10) for i in range(num_days)]

data = pd.DataFrame({
    'ds': dates,
    'sales': sales,
    'unique_id': "sales"
})

# Multivariate example (other variables get ignored)
start_date = datetime.now()
num_days = 365
dates = [start_date + timedelta(days=i) for i in range(num_days)]
sales = [100 + i * 10 + np.random.normal(0, 10) for i in range(num_days)]
ad_spend = [50 + 20 * np.sin(2 * np.pi * i / 30) + np.random.normal(0, 5) for i in range(num_days)]
temperature = [20 + 10 * np.sin(2 * np.pi * i / 365) + np.random.normal(0, 2) for i in range(num_days)]

data = pd.DataFrame({
    'ds': dates,
    'sales': sales,
    'ad_spend': ad_spend,
    'temperature': temperature,
    'unique_id': "sales"
})

# Forecast (exactly like TimesFM)
import time
start_time = time.time()

forecast_df = tfm.forecast_on_df(
    inputs=data,
    freq="D",
    value_name="sales",
    num_jobs=-1,
)

print(f"Wall time: {time.time() - start_time:.2f}s")

# Display results (exactly like TimesFM)
print("Original data:")
print(data.head())
print(f"Shape: {data.shape}")

print("\nForecast:")
print(forecast_df.head())
print(f"Shape: {forecast_df.shape}")

# Plot results
plt.figure(figsize=(12, 6))
plt.plot(data['ds'][-60:], data['sales'][-60:], label='Historical', color='blue')
plt.plot(forecast_df['ds'], forecast_df['sales'], label='Forecast', color='red')
plt.fill_between(forecast_df['ds'], 
                 forecast_df['sales_q10'], 
                 forecast_df['sales_q90'], 
                 alpha=0.3, color='red', label='80% Confidence')
plt.legend()
plt.title('Sales Forecast with Lag-Llama')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

Collecting torchaudio
  Downloading torchaudio-2.7.1-cp311-cp311-macosx_11_0_arm64.whl.metadata (6.6 kB)
Downloading torchaudio-2.7.1-cp311-cp311-macosx_11_0_arm64.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m28.7 MB/s[0m eta [36m0:00:00[0m
[0mInstalling collected packages: torchaudio
[0mSuccessfully installed torchaudio-2.7.1
[0mzsh:1: no matches found: gluonts[torch]
[0mCollecting git+https://github.com/time-series-foundation-models/lag-llama.git
  Cloning https://github.com/time-series-foundation-models/lag-llama.git to /private/var/folders/ls/yshfj6q93s9_16nz44mv2pb00000gn/T/pip-req-build-lpmkxlju
  Running command git clone --filter=blob:none --quiet https://github.com/time-series-foundation-models/lag-llama.git /private/var/folders/ls/yshfj6q93s9_16nz44mv2pb00000gn/T/pip-req-build-lpmkxlju
  Resolved https://github.com/time-series-foundation-models/lag-llama.git to commit df7531a83a19b3c6a0222d703ca9bf59ef7a6ab9
  Installin

AttributeError: 'str' object has no attribute '__module__'