# Quick Start: Running Chronos Bolt on MUSED-FM Benchmark

This notebook shows how to run Chronos Bolt models on the MUSED-FM benchmark using the `run_musedfm.py` script.

Make sure you have the MUSED-FM benchmark data downloaded and set the `--benchmark-path` correctly before running this notebook.

We will use the MUSED-FM framework to load the data and run the Chronos Bolt model. This notebook demonstrates how to integrate Chronos Bolt with the MUSED-FM evaluation framework.


## Installation

Install required packages:

```bash
pip install chronos-forecasting
pip install musedfm
```

Make sure you have PyTorch installed with CUDA support if you want to use GPU acceleration.


In [11]:
import os
import sys
import subprocess
import pandas as pd
import numpy as np
import torch
from pathlib import Path
from typing import Optional
from abc import ABC, abstractmethod

# Add the src directory to the Python path
sys.path.insert(0, str(Path.cwd() / 'src'))

# Import MUSED-FM components
from musedfm.data import Benchmark
from musedfm.metrics import MAPE, MAE, RMSE, NMAE

print("MUSED-FM components imported successfully!")


MUSED-FM components imported successfully!


In [12]:
# Self-contained ChronosForecast class for the notebook
class ChronosForecast:
    """
    Chronos forecasting model wrapper for MUSED-FM evaluation.
    This class is self-contained within the notebook.
    """
    
    def __init__(self, model_path: str = "amazon/chronos-bolt-base", device: str = "cuda:0", num_samples: int = 20):
        """
        Initialize Chronos forecast model.
        
        Args:
            model_path: Path to Chronos model (HuggingFace model ID or local path)
            device: Device to run the model on
            num_samples: Number of samples for probabilistic forecasting
        """
        self.model_path = model_path
        self.device = device
        self.num_samples = num_samples
        self.pipeline = None
        self._load_model()
    
    def _load_model(self):
        """Load the Chronos model."""
        try:
            from chronos import BaseChronosPipeline, ForecastType
            
            self.pipeline = BaseChronosPipeline.from_pretrained(
                self.model_path,
                device_map=self.device,
            )
            print(f"Loaded Chronos model: {self.model_path}")
        except ImportError:
            raise ImportError("Chronos package not installed. Please install with: pip install chronos-forecasting")
        except Exception as e:
            raise RuntimeError(f"Failed to load Chronos model: {e}")
    
    def forecast(self, history: np.ndarray, covariates: Optional[np.ndarray] = None, forecast_horizon: Optional[int] = None) -> np.ndarray:
        """
        Generate forecast from historical data using Chronos.
        
        Args:
            history: Historical time series data
            covariates: Optional covariate data (ignored for Chronos)
            forecast_horizon: Number of future points to forecast (default: 1)
            
        Returns:
            Forecast values
        """
        if forecast_horizon is None:
            forecast_horizon = 1
        
        try:
            # Convert history to torch tensor
            if isinstance(history, np.ndarray):
                history_tensor = torch.tensor(history, dtype=torch.float32)
            else:
                history_tensor = torch.tensor(np.array(history), dtype=torch.float32)
            
            # Remove NaN values
            history_clean = history_tensor[~torch.isnan(history_tensor)]
            
            if len(history_clean) == 0:
                # If no valid data, return zeros
                return np.zeros(forecast_horizon)
            
            # Ensure we have enough history for forecasting
            if len(history_clean) < 2:
                # If insufficient data, return the last value repeated
                last_value = float(history_clean[-1]) if len(history_clean) > 0 else 0.0
                return np.full(forecast_horizon, last_value)
            
            # Generate forecast using Chronos
            context = [history_clean]
            
            # Determine prediction kwargs based on forecast type
            predict_kwargs = {}
            if hasattr(self.pipeline, 'forecast_type'):
                from chronos import ForecastType
                if self.pipeline.forecast_type == ForecastType.SAMPLES:
                    predict_kwargs = {"num_samples": self.num_samples}
            
            # Generate forecast
            forecast_output = self.pipeline.predict(
                context,
                prediction_length=forecast_horizon,
                **predict_kwargs
            )
            
            # Convert to numpy array
            if isinstance(forecast_output, torch.Tensor):
                forecast_np = forecast_output.numpy()
            else:
                forecast_np = np.array(forecast_output)
            
            # Handle different output shapes
            if forecast_np.ndim > 1:
                # If we have multiple samples, take the mean
                if forecast_np.shape[0] > 1:
                    forecast_np = np.mean(forecast_np, axis=0)
                else:
                    forecast_np = forecast_np[0]
            
            # Ensure we have the right length
            if len(forecast_np) != forecast_horizon:
                if len(forecast_np) > forecast_horizon:
                    forecast_np = forecast_np[:forecast_horizon]
                else:
                    # Pad with the last value if needed
                    last_val = forecast_np[-1] if len(forecast_np) > 0 else 0.0
                    forecast_np = np.pad(forecast_np, (0, forecast_horizon - len(forecast_np)), 'constant', constant_values=last_val)
            
            return forecast_np
            
        except Exception as e:
            print(f"Warning: Chronos forecasting failed: {e}")
            # Fallback to simple mean forecast
            if len(history) > 0:
                mean_val = np.nanmean(history)
                return np.full(forecast_horizon, mean_val)
            else:
                return np.zeros(forecast_horizon)

print("ChronosForecast class defined successfully!")


ChronosForecast class defined successfully!


## Configuration

Set up the benchmark path and model parameters. Adjust these according to your setup.
ma

In [13]:
# Configuration
BENCHMARK_PATH = "/workspace/data/fm_eval_nested/"  # Adjust this path to your MUSED-FM data
MODEL_PATH = "amazon/chronos-bolt-base"  # Chronos Bolt model
DEVICE = "cuda:0"  # Use "cpu" if you don't have CUDA
NUM_SAMPLES = 20  # Number of samples for probabilistic forecasting
MAX_WINDOWS = 50  # Limit windows per dataset for faster testing
OUTPUT_DIR = "./results/chronos_bolt"

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"Benchmark path: {BENCHMARK_PATH}")
print(f"Model: {MODEL_PATH}")
print(f"Device: {DEVICE}")
print(f"Output directory: {OUTPUT_DIR}")


Benchmark path: /workspace/data/fm_eval_nested/
Model: amazon/chronos-bolt-base
Device: cuda:0
Output directory: ./results/chronos_bolt


## Initialize Chronos Model

Create a ChronosForecast instance that integrates with the MUSED-FM framework.


In [14]:
# Initialize Chronos model
try:
    chronos_model = ChronosForecast(
        model_path=MODEL_PATH,
        device=DEVICE,
        num_samples=NUM_SAMPLES
    )
    print("Chronos model initialized successfully!")
except Exception as e:
    print(f"Error initializing Chronos model: {e}")
    print("Make sure you have installed chronos-forecasting and have the required dependencies.")


Loaded Chronos model: amazon/chronos-bolt-base
Chronos model initialized successfully!


## Run Chronos Model Directly in Notebook

Instead of using `run_musedfm.py`, we can run Chronos directly in the notebook for more control and immediate results.


In [15]:
# Load the MUSED-FM benchmark
print("Loading MUSED-FM benchmark...")
try:
    benchmark = Benchmark(BENCHMARK_PATH)
    print(f"Benchmark loaded successfully!")
    print(f"Number of categories: {len(benchmark)}")
    
    # Print some basic info about the benchmark
    total_datasets = 0
    for category in benchmark:
        print(f"Category: {category.category} ({len(category)} domains)")
        for domain in category:
            print(f"  Domain: {domain.domain_name} ({len(domain)} datasets)")
            total_datasets += len(domain)
    
    print(f"Total datasets in benchmark: {total_datasets}")
    
except Exception as e:
    print(f"Error loading benchmark: {e}")
    print(f"Make sure the benchmark path '{BENCHMARK_PATH}' is correct and contains MUSED-FM data.")
    raise


Loading MUSED-FM benchmark...
Large dataset cifar100_timeseries_csvs detected (50000 files). Estimating window count by sampling first 100 files...
Estimated 27300000 total windows (546.0 avg per file)
Loading KITTI data from /workspace/data/fm_eval_nested/sequential/KITTI
Found 6114 parquet files
Successfully loaded 6114 valid files
Large dataset KITTI detected (6114 files). Estimating window count by sampling first 100 files...
Estimated 1016819 total windows (166.3 avg per file)
Large dataset ant_csv_out detected (10134 files). Estimating window count by sampling first 100 files...
Estimated 5365243 total windows (529.4 avg per file)
Large dataset hopper_csv_out detected (8470 files). Estimating window count by sampling first 100 files...
Estimated 1351981 total windows (159.6 avg per file)
Large dataset spriteworld detected (19534 files). Estimating window count by sampling first 100 files...
Estimated 9321429 total windows (477.2 avg per file)
Large dataset walker2d_csv_out detect

## Direct Evaluation

Now let's run Chronos directly on the benchmark data for immediate results and full control.


In [16]:
# Direct evaluation function
def evaluate_chronos_directly(benchmark, model, max_datasets=5, max_windows_per_dataset=10):
    """
    Directly evaluate Chronos model on benchmark data.
    """
    results = []
    dataset_count = 0
    
    print(f"Starting direct evaluation on up to {max_datasets} datasets...")
    
    for category in benchmark:
        if dataset_count >= max_datasets:
            break
            
        for domain in category:
            if dataset_count >= max_datasets:
                break
                
            for dataset in domain:
                if dataset_count >= max_datasets:
                    break
                    
                print(f"\nEvaluating dataset: {dataset.dataset_name} ({dataset_count + 1}/{max_datasets})")
                print(f"Category: {category.category}, Domain: {domain.domain}")
                print(f"Dataset size: {len(dataset)} windows")
                
                # Limit windows for faster evaluation
                windows_processed = 0
                dataset_metrics = {'MAPE': [], 'MAE': [], 'RMSE': [], 'NMAE': []}
                
                for window in dataset:
                    if windows_processed >= max_windows_per_dataset:
                        break
                        
                    try:
                        # Get history and future data
                        history = window.history
                        future = window.future
                        forecast_horizon = len(future)
                        
                        # Generate forecast
                        forecast = model.forecast(
                            history=history,
                            forecast_horizon=forecast_horizon
                        )
                        
                        # Calculate metrics
                        mape = MAPE(future, forecast)
                        mae = MAE(future, forecast)
                        rmse = RMSE(future, forecast)
                        nmae = NMAE(future, forecast)
                        
                        dataset_metrics['MAPE'].append(mape)
                        dataset_metrics['MAE'].append(mae)
                        dataset_metrics['RMSE'].append(rmse)
                        dataset_metrics['NMAE'].append(nmae)
                        
                        windows_processed += 1
                        
                    except Exception as e:
                        print(f"Warning: Error processing window {windows_processed}: {e}")
                        continue
                
                # Calculate average metrics for this dataset
                if windows_processed > 0:
                    avg_metrics = {}
                    for metric_name, values in dataset_metrics.items():
                        if values:
                            avg_metrics[metric_name] = np.mean(values)
                        else:
                            avg_metrics[metric_name] = np.nan
                    
                    result = {
                        'dataset': dataset.dataset_name,
                        'category': category.category,
                        'domain': domain.domain,
                        'windows_processed': windows_processed,
                        **avg_metrics
                    }
                    results.append(result)
                    
                    print(f"Processed {windows_processed} windows")
                    print(f"Average MAPE: {avg_metrics['MAPE']:.4f}")
                    print(f"Average MAE: {avg_metrics['MAE']:.4f}")
                    print(f"Average RMSE: {avg_metrics['RMSE']:.4f}")
                    print(f"Average NMAE: {avg_metrics['NMAE']:.4f}")
                
                dataset_count += 1
    
    return results

# Run direct evaluation
print("Starting Chronos evaluation...")
direct_results = evaluate_chronos_directly(benchmark, chronos_model, max_datasets=3, max_windows_per_dataset=5)
print(f"\nDirect evaluation completed on {len(direct_results)} datasets")


Starting Chronos evaluation...
Starting direct evaluation on up to 3 datasets...

Evaluating dataset: cifar100_timeseries_csvs (1/3)


AttributeError: 'Domain' object has no attribute 'domain'

In [None]:
# Import plotting utilities
import matplotlib.pyplot as plt
from musedfm.plotting import plot_window_forecasts, plot_multiple_windows, plot_baseline_comparison

# Plot forecasts for the first few windows from each dataset
def plot_chronos_forecasts(benchmark, model, max_datasets=3, max_windows_per_dataset=3):
    """
    Plot Chronos forecasts for visualization.
    """
    plot_data = []
    dataset_count = 0
    
    print(f"Collecting plot data from up to {max_datasets} datasets...")
    
    for category in benchmark:
        if dataset_count >= max_datasets:
            break
            
        for domain in category:
            if dataset_count >= max_datasets:
                break
                
            for dataset in domain:
                if dataset_count >= max_datasets:
                    break
                    
                print(f"\nCollecting plot data from: {dataset.dataset_name}")
                
                # Collect windows and forecasts for plotting
                windows_processed = 0
                dataset_forecasts = {}
                
                for window in dataset:
                    if windows_processed >= max_windows_per_dataset:
                        break
                        
                    try:
                        # Get history and future data
                        history = window.history
                        future = window.future
                        forecast_horizon = len(future)
                        
                        # Generate forecast
                        forecast = model.forecast(
                            history=history,
                            forecast_horizon=forecast_horizon
                        )
                        
                        # Store window and forecast data
                        plot_data.append({
                            'window': window,
                            'forecast': forecast,
                            'dataset_name': dataset.dataset_name,
                            'category': category.category,
                            'domain': domain.domain_name,
                            'window_idx': windows_processed
                        })
                        
                        windows_processed += 1
                        
                    except Exception as e:
                        print(f"Warning: Error processing window {windows_processed}: {e}")
                        continue
                
                print(f"Collected {windows_processed} windows for plotting")
                dataset_count += 1
    
    return plot_data

# Collect plot data
plot_data = plot_chronos_forecasts(benchmark, chronos_model, max_datasets=3, max_windows_per_dataset=3)

print(f"\nCollected plot data for {len(plot_data)} windows")


In [None]:
# Create individual plots for each window
if plot_data:
    print("Creating individual forecast plots...")
    
    # Create plots directory
    plots_dir = os.path.join(OUTPUT_DIR, "plots")
    os.makedirs(plots_dir, exist_ok=True)
    
    for i, data in enumerate(plot_data[:9]):  # Plot first 9 windows
        window = data['window']
        forecast = data['forecast']
        dataset_name = data['dataset_name']
        category = data['category']
        domain = data['domain']
        window_idx = data['window_idx']
        
        # Create forecasts dictionary for the plotting function
        forecasts = {'Chronos Bolt': forecast}
        
        # Create title
        title = f"{dataset_name} - Window {window_idx}\nCategory: {category}, Domain: {domain}"
        
        # Plot the window
        plot_window_forecasts(
            window=window,
            forecasts=forecasts,
            title=title,
            figsize=(12, 6),
            save_path=os.path.join(plots_dir, f"chronos_forecast_{i+1}.png")
        )
        
        print(f"Saved plot {i+1}: {dataset_name} - Window {window_idx}")
    
    print(f"\nIndividual plots saved to: {plots_dir}")
    
else:
    print("No plot data available.")


In [None]:
# Create a multi-window comparison plot
if plot_data and len(plot_data) >= 3:
    print("Creating multi-window comparison plot...")
    
    # Select first 6 windows for comparison
    selected_data = plot_data[:6]
    windows = [data['window'] for data in selected_data]
    forecasts_dict = {'Chronos Bolt': {i: data['forecast'] for i, data in enumerate(selected_data)}}
    
    # Create window titles
    window_titles = []
    for data in selected_data:
        title = f"{data['dataset_name']}\nWindow {data['window_idx']}"
        window_titles.append(title)
    
    # Create multi-window plot
    plot_multiple_windows(
        windows=windows,
        forecasts_dict=forecasts_dict,
        window_titles=window_titles,
        figsize=(18, 12),
        save_path=os.path.join(plots_dir, "chronos_multi_window_comparison.png")
    )
    
    print("Multi-window comparison plot saved!")
    
    # Also create a baseline comparison plot for the first window
    if plot_data:
        first_data = plot_data[0]
        window = first_data['window']
        forecast = first_data['forecast']
        
        # Calculate metrics for the baseline comparison
        future = window.future
        metrics = {
            'Chronos Bolt': {
                'MAPE': MAPE(future, forecast),
                'MAE': MAE(future, forecast),
                'RMSE': RMSE(future, forecast),
                'NMAE': NMAE(future, forecast)
            }
        }
        
        forecasts = {'Chronos Bolt': forecast}
        title = f"Chronos Bolt Performance\n{first_data['dataset_name']} - Window {first_data['window_idx']}"
        
        plot_baseline_comparison(
            window=window,
            forecasts=forecasts,
            metrics=metrics,
            title=title,
            figsize=(15, 8),
            save_path=os.path.join(plots_dir, "chronos_baseline_comparison.png")
        )
        
        print("Baseline comparison plot saved!")
    
    print(f"\nAll plots saved to: {plots_dir}")
    print("You can view the plots to see how Chronos Bolt performs on different datasets!")
    
else:
    print("Not enough data for multi-window comparison plot.")


In [None]:
# Display results
if direct_results:
    df_results = pd.DataFrame(direct_results)
    print("Evaluation Results:")
    print("=" * 50)
    print(df_results.to_string(index=False))
    
    # Calculate overall averages
    print("\nOverall Average Metrics:")
    print("=" * 30)
    numeric_cols = ['MAPE', 'MAE', 'RMSE', 'NMAE']
    for col in numeric_cols:
        if col in df_results.columns:
            avg_val = df_results[col].mean()
            print(f"{col}: {avg_val:.6f}")
    
    # Save results to CSV
    results_file = os.path.join(OUTPUT_DIR, "chronos_results.csv")
    df_results.to_csv(results_file, index=False)
    print(f"\nResults saved to: {results_file}")
    
    # Save summary
    summary_file = os.path.join(OUTPUT_DIR, "summary.txt")
    with open(summary_file, 'w') as f:
        f.write("Chronos Bolt Evaluation Summary\n")
        f.write("=" * 50 + "\n\n")
        f.write(f"Total datasets evaluated: {len(direct_results)}\n")
        f.write(f"Model: {MODEL_PATH}\n")
        f.write(f"Device: {DEVICE}\n")
        f.write(f"Max windows per dataset: 5\n\n")
        f.write("Overall Average Metrics:\n")
        for col in numeric_cols:
            if col in df_results.columns:
                avg_val = df_results[col].mean()
                f.write(f"  {col}: {avg_val:.6f}\n")
    
    print(f"Summary saved to: {summary_file}")
    
else:
    print("No results available. Make sure the evaluation completed successfully.")
