# In [None]:
# Chronos-2 on M5 Forecasting
# Zero-shot with covariates support

import sys
import warnings
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm

# Suppress all Python warnings for the rest of the run.
warnings.filterwarnings('ignore')


# Run configuration
FORECAST_LENGTH = 28  # forecast horizon, in days
N_SERIES_TEST = None  # int -> only the first N series are used; None -> all series
data_dir = Path('../data/raw')
output_dir = Path('../data/chronos_m5_output')
output_dir.mkdir(parents=True, exist_ok=True)

print("Chronos-2 on M5")
print(f"Forecast horizon: {FORECAST_LENGTH} days")
# Report the subset size when a test subset is configured, instead of always
# claiming that the full dataset is used.
if N_SERIES_TEST:
    print(f"Using first {N_SERIES_TEST} series (testing)")
else:
    print("Using all 30,490 series")

# ============================================================================
# Load Chronos-2
# ============================================================================

print("\nLoading Chronos-2 model...")

from chronos import Chronos2Pipeline
import torch

# Device preference: CUDA first, then Apple MPS, otherwise CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Device: {device}")

# bfloat16 is requested only on CUDA; every other device gets float32
model_dtype = torch.bfloat16 if device == "cuda" else torch.float32

# Load the pretrained pipeline onto the selected device
pipeline = Chronos2Pipeline.from_pretrained(
    "amazon/chronos-2",
    device_map=device,
    torch_dtype=model_dtype,
)

print("Model loaded successfully")

# ============================================================================
# Load M5 Data
# ============================================================================

print("\nLoading M5 data...")

# Read the three M5 tables from data_dir, in this order
sales, calendar, prices = (
    pd.read_csv(data_dir / csv_name)
    for csv_name in (
        'sales_train_evaluation.csv',
        'calendar.csv',
        'sell_prices.csv',
    )
)

# Report the shape of each loaded table
for label, frame in (("Sales", sales), ("Calendar", calendar), ("Prices", prices)):
    print(f"  {label}: {frame.shape}")

# ============================================================================
# Prepare Data for Chronos-2
# ============================================================================

print("\nPreparing data...")

# Day columns are the 'd_<n>' columns of the wide sales table
date_cols = list(filter(lambda col: col.startswith('d_'), sales.columns))
n_days = len(date_cols)
# One daily timestamp per day column, counted from 2011-01-29
timestamps = pd.date_range('2011-01-29', periods=n_days, freq='D')

# Calendar features: parse the dates, then build a date -> row-dict lookup
calendar['date'] = pd.to_datetime(calendar['date'])
calendar_by_date = calendar.set_index('date')
calendar_dict = calendar_by_date.to_dict(orient='index')

def prepare_chronos_data(sales_df, calendar_df, n_series=None):
    """
    Prepare M5 data in Chronos-2 long format with calendar covariates.

    Args:
        sales_df: Wide sales table with an 'id' column and one 'd_<n>' column
            per day.
        calendar_df: Calendar table with columns 'd', 'date', 'wday', 'month',
            'weekday', 'snap_CA', 'snap_TX', 'snap_WI' and 'event_name_1'.
        n_series: If truthy, only the first ``n_series`` rows of ``sales_df``
            are converted.

    Returns:
        DataFrame with one row per (series, day), series after series, with
        columns:
        - item_id: series identifier
        - timestamp: date
        - target: sales values (float)
        - covariate columns (calendar features): wday, month, is_weekend,
          snap_CA, snap_TX, snap_WI, event

    Raises:
        ValueError: if the number of matching calendar rows differs from the
            number of day columns.
    """
    if n_series:
        sales_df = sales_df.head(n_series)

    # Align calendar to sales days.
    # NOTE(review): the calendar rows are assumed to be in the same day order
    # as the 'd_<n>' columns of sales_df -- confirm for non-M5 inputs.
    date_cols = [c for c in sales_df.columns if c.startswith('d_')]
    calendar_aligned = calendar_df[calendar_df['d'].isin(date_cols)].reset_index(drop=True)
    if len(calendar_aligned) != len(date_cols):
        raise ValueError(
            f"calendar_df has {len(calendar_aligned)} rows matching "
            f"{len(date_cols)} day columns of sales_df"
        )

    # Timestamps come from the calendar itself so the function does not depend
    # on a module-level global being defined (and being the right length).
    cal_timestamps = pd.to_datetime(calendar_aligned['date']).to_numpy()

    # Pre-compute calendar features once; they are identical for every series
    covariates = {
        'wday': calendar_aligned['wday'].astype('category').cat.codes.values,
        'month': calendar_aligned['month'].astype('category').cat.codes.values,
        'is_weekend': calendar_aligned['weekday'].isin(['Saturday', 'Sunday']).astype(int).values,
        'snap_CA': calendar_aligned['snap_CA'].values,
        'snap_TX': calendar_aligned['snap_TX'].values,
        'snap_WI': calendar_aligned['snap_WI'].values,
        'event': calendar_aligned['event_name_1'].notna().astype(int).values,
    }

    n_rows = len(sales_df)
    n_days = len(date_cols)

    # Build the long table in one shot: each id is repeated n_days times, and
    # the per-day arrays are tiled once per series. Row-major ravel of the
    # (series x day) matrix yields the matching series-after-series order.
    long_columns = {
        'item_id': np.repeat(sales_df['id'].to_numpy(), n_days),
        'timestamp': np.tile(cal_timestamps, n_rows),
        'target': sales_df[date_cols].to_numpy(dtype=float).ravel(),
    }
    for name, values in covariates.items():
        long_columns[name] = np.tile(values, n_rows)

    return pd.DataFrame(long_columns)

# Convert all series, or only the first N_SERIES_TEST when that is set.
# A falsy n_series makes prepare_chronos_data keep every row of `sales`.
data_all = prepare_chronos_data(sales, calendar, n_series=N_SERIES_TEST)
print(f"  Data shape: {data_all.shape}")
print(f"  Series: {data_all['item_id'].nunique()}")

# ============================================================================
# Zero-Shot Forecasting
# ============================================================================

print("\nGenerating forecasts (zero-shot)...")

# Split into context (history) and future.
# The last FORECAST_LENGTH days are held out, so the context must end on the
# day BEFORE timestamps[-FORECAST_LENGTH]. Using timestamps[-FORECAST_LENGTH]
# with `<=` would put the first held-out day into the context and shift the
# forecast window one day past the end of the data.
context_end = timestamps[-FORECAST_LENGTH - 1]
context_df = data_all[data_all['timestamp'] <= context_end]

print(f"  Context data: {context_df.shape}")
print(f"  Forecasting {FORECAST_LENGTH} steps ahead...")
print(f"  Estimated time: ~4-6 hours for 30k series")

# Chronos-2 predict_df API
# NOTE(review): the calendar columns are only supplied as part of the context
# frame; no future covariate values are passed. If predict_df's `future_df`
# argument should be used for them, that needs to be added here -- confirm.
import time
start_time = time.time()

forecasts = pipeline.predict_df(
    df=context_df,
    prediction_length=FORECAST_LENGTH,
    id_column="item_id",
    timestamp_column="timestamp",
    target="target",
    quantile_levels=[0.1, 0.5, 0.9],  # Probabilistic forecasts
    batch_size=128,  # Increased from 32 for speed
)

elapsed = time.time() - start_time
print(f"  Forecasting completed in {elapsed/60:.1f} minutes")

print(f"  Forecast shape: {forecasts.shape}")
print(f"  Forecast columns: {list(forecasts.columns)}")

# ============================================================================
# Extract Point Forecasts
# ============================================================================

print("\nExtracting point forecasts...")

# Row order of the output array follows the order of ids in `sales`
if N_SERIES_TEST:
    series_order = sales.head(N_SERIES_TEST)['id'].tolist()
    print(f"  NOTE: Using {N_SERIES_TEST} series (testing)")
else:
    series_order = sales['id'].tolist()
    print(f"  Using all {len(series_order)} series")

# Determine the point-forecast column: prefer the median (0.5 quantile) under
# either naming, then 'mean', and finally the first numeric column.
median_col = next(
    (c for c in ('target_0.5', '0.5', 'mean') if c in forecasts.columns),
    None,
)
if median_col is None:
    # Use first numeric column after item_id and timestamp
    numeric_cols = forecasts.select_dtypes(include=[np.number]).columns
    median_col = [c for c in numeric_cols if c not in ['item_id']][0]

print(f"  Using column: {median_col}")

# Sort and group ONCE instead of boolean-masking the whole forecast frame for
# every series (which is quadratic in the number of series).
forecasts_sorted = forecasts.sort_values('timestamp', kind='stable')
median_by_series = {
    sid: values.to_numpy()
    for sid, values in forecasts_sorted.groupby('item_id', sort=False)[median_col]
}

all_forecasts = []
n_missing = 0

for series_id in tqdm(series_order, desc="  Processing"):
    forecast_values = median_by_series.get(series_id)

    if forecast_values is None or len(forecast_values) == 0:
        # Fallback: zeros (counted so the fallback is not silent)
        n_missing += 1
        forecast_values = np.zeros(FORECAST_LENGTH)
    else:
        forecast_values = forecast_values[:FORECAST_LENGTH]

        # Pad short forecasts by repeating the last value
        if len(forecast_values) < FORECAST_LENGTH:
            forecast_values = np.pad(
                forecast_values,
                (0, FORECAST_LENGTH - len(forecast_values)),
                mode='edge'
            )

    all_forecasts.append(forecast_values)

if n_missing:
    print(f"  WARNING: {n_missing} series had no forecast rows; filled with zeros")

forecast_array = np.array(all_forecasts)
forecast_array = np.maximum(forecast_array, 0)  # No negative sales

print(f"  Final array: {forecast_array.shape}")

# ============================================================================
# Calculate WRMSSE
# ============================================================================

print("\nCalculating WRMSSE...")

import traceback

# Make the project-local scoring module importable
sys.path.append('../src')
try:
    from m5_wrmsse import wrmsse
    wrmsse_score = wrmsse(forecast_array)

    rule = '=' * 60
    print(f"\n{rule}")
    print("CHRONOS-2 RESULTS")
    print(rule)
    print(f"  WRMSSE: {wrmsse_score:.4f}")
    print(rule)
except Exception as exc:
    # Scoring is best-effort: report the failure and continue without a score
    print(f"  Error: {exc}")
    traceback.print_exc()
    wrmsse_score = None

# ============================================================================
# Save Results
# ============================================================================

print("\nSaving results...")

# Save point forecasts: one row per series id, one column per forecast step
forecast_df = pd.DataFrame(forecast_array, index=series_order)
forecast_df.to_pickle(output_dir / 'chronos_forecasts.pkl')

# Save full probabilistic forecasts
forecasts.to_pickle(output_dir / 'chronos_probabilistic.pkl')

# Save summary
import pickle
summary = {
    'wrmsse': wrmsse_score,
    'model': 'chronos-2',
    'n_series': len(series_order),
    'forecast_length': FORECAST_LENGTH,
    'zero_shot': True,
}

with open(output_dir / 'summary.pkl', 'wb') as f:
    pickle.dump(summary, f)

print("\nDone!")
print(f"  Forecasts saved to: {output_dir}")
# Compare against None explicitly: a score of 0.0 is falsy but still a score
if wrmsse_score is not None:
    print(f"  WRMSSE: {wrmsse_score:.4f}")