In [None]:
# TTM M5 Multivariate

import math
import os
import sys
import warnings
import pickle
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import Subset
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed

from tsfm_public import (
    ForecastDFDataset,
    TimeSeriesPreprocessor,
    TinyTimeMixerForPrediction,
    TrackingCallback,
    count_parameters,
)
from tsfm_public.toolkit.time_series_preprocessor import prepare_data_splits

# Silence library warnings and seed the RNGs so repeated runs are comparable.
warnings.filterwarnings('ignore')
set_seed(42)

# --- Locations -------------------------------------------------------------
data_dir = Path('..') / 'data' / 'raw'
output_dir = Path('..') / 'data' / 'ttm_m5_multivariate'
output_dir.mkdir(parents=True, exist_ok=True)

# --- Model -----------------------------------------------------------------
device = "cpu"
CONTEXT_LENGTH, FORECAST_LENGTH = 90, 28
TTM_MODEL_PATH = "ibm-granite/granite-timeseries-ttm-r2"
# NOTE(review): the revision name suggests a 90-step context / 30-step horizon
# checkpoint; the 28-day horizon is requested separately at model load time.
REVISION = "90-30-ft-l1-r2.1"

# --- Fine-tuning budget ----------------------------------------------------
N_SERIES_SAMPLE = 2000      # number of series (first rows of the sales file) used for tuning
FEWSHOT_FRACTION = 0.10     # share of the training/validation windows actually used
NUM_EPOCHS = 5
BATCH_SIZE = 256
LEARNING_RATE = 0.001

print(f"Context: {CONTEXT_LENGTH} | Horizon: {FORECAST_LENGTH}")
print(f"Series: {N_SERIES_SAMPLE} | Epochs: {NUM_EPOCHS} | Device: {device}")

# ============================================================================
# Load Data
# ============================================================================

print("\nLoading data...")
# The three raw M5 tables, read in the order they are unpacked below.
raw_files = ('sales_train_evaluation.csv', 'calendar.csv', 'sell_prices.csv')
sales, calendar, prices = (pd.read_csv(data_dir / file_name) for file_name in raw_files)
print(f"  Loaded: Sales {sales.shape} | Calendar {calendar.shape} | Prices {prices.shape}")

# ============================================================================
# Prepare Calendar - VECTORIZED
# ============================================================================

print("Preparing calendar...")
calendar['date'] = pd.to_datetime(calendar['date'])

# Integer category codes for weekday and month.
for source_col, code_col in (('wday', 'wday_num'), ('month', 'month_num')):
    calendar[code_col] = calendar[source_col].astype('category').cat.codes

# Years elapsed since 2011, plus a binary weekend flag.
calendar['year_norm'] = calendar['year'] - 2011
calendar['is_weekend'] = calendar['weekday'].isin(['Saturday', 'Sunday']).astype(int)

# 1 when the corresponding event slot is populated, else 0.
for slot in (1, 2):
    calendar[f'has_event_{slot}'] = calendar[f'event_name_{slot}'].notna().astype(int)

# Calendar covariates handed to the model, in channel order.
cal_cols = ['wday_num', 'month_num', 'year_norm', 'is_weekend',
            'has_event_1', 'has_event_2', 'snap_CA', 'snap_TX', 'snap_WI']

# Row i of cal_lookup holds the covariates of the i-th calendar row;
# cal_d_to_idx maps a day label from column 'd' to that row position.
cal_lookup = calendar.set_index('d')[cal_cols].to_numpy()
cal_d_to_idx = dict(zip(calendar['d'].values, range(len(calendar))))

# ============================================================================
# Prepare Prices
# ============================================================================

print("Preparing prices...")
# Build the same '<item>_<store>_evaluation' key that the sales table uses as 'id'.
prices['id'] = prices['item_id'] + '_' + prices['store_id'] + '_evaluation'

# One row per series, one column per week; gaps are filled from the previous
# known week first and, for leading gaps, from the next known week.
prices_wide = (
    prices
    .pivot_table(index='id', columns='wm_yr_wk', values='sell_price')
    .ffill(axis=1)
    .bfill(axis=1)
)

# Day label -> week id, used to look up the weekly price for each day.
cal_week_map = dict(zip(calendar['d'].tolist(), calendar['wm_yr_wk'].tolist()))

# ============================================================================
# Convert to Long Format - VECTORIZED
# ============================================================================

print("\nConverting to long format (vectorized)...")

def prepare_m5_vectorized(sales_df, cal_lookup, cal_d_to_idx, prices_wide, cal_week_map, n_series=None):
    """Reshape wide M5 sales into one long frame with price and calendar covariates.

    Reads the module-level ``cal_cols``, ``CONTEXT_LENGTH`` and
    ``FORECAST_LENGTH``.

    Args:
        sales_df: Wide sales frame with an ``id`` column and one ``d_*`` column
            per day.
        cal_lookup: 2-D array of calendar covariates; row ``k`` belongs to the
            day whose position in ``cal_d_to_idx`` is ``k`` and its columns
            follow ``cal_cols``.
        cal_d_to_idx: Mapping from a ``d_*`` label to its row in ``cal_lookup``.
        prices_wide: Frame indexed by series id with one column per week id.
        cal_week_map: Mapping from a ``d_*`` label to its week id.
        n_series: When given, only the first ``n_series`` rows of ``sales_df``
            are used; ``None`` keeps every row.

    Returns:
        DataFrame with columns ``timestamp``, ``series_id``, ``value``,
        ``price`` and every name in ``cal_cols``; one row per (series, day),
        series-major and chronological within a series. Series with fewer than
        ``CONTEXT_LENGTH + FORECAST_LENGTH`` non-null values are dropped.
        Prices that cannot be resolved default to ``1.0``.
    """
    # `is not None` so that an explicit 0 means "no series" instead of "all".
    if n_series is not None:
        sales_df = sales_df.head(n_series).copy()

    date_cols = [c for c in sales_df.columns if c.startswith('d_')]
    n_days = len(date_cols)
    n_series_actual = len(sales_df)
    ids = sales_df['id'].to_numpy()

    day_indices = [cal_d_to_idx[d] for d in date_cols]
    cal_matrix = cal_lookup[day_indices]
    sales_matrix = sales_df[date_cols].values.astype(float)

    # --- Prices: positional lookup instead of a per-cell Python loop ---------
    # Resolve each series to a row and each day to a week column of
    # prices_wide once (-1 = unknown), then copy the whole sub-matrix at once.
    price_matrix = np.ones((n_series_actual, n_days), dtype=float)
    row_of = {sid: pos for pos, sid in enumerate(prices_wide.index)}
    col_of = {wk: pos for pos, wk in enumerate(prices_wide.columns)}
    row_pos = np.array([row_of.get(sid, -1) for sid in ids], dtype=np.intp)
    weeks = [cal_week_map.get(d) for d in date_cols]
    # Falsy week ids (missing mapping) are treated as unknown.
    col_pos = np.array([col_of.get(wk, -1) if wk else -1 for wk in weeks], dtype=np.intp)

    has_row = row_pos >= 0
    has_col = col_pos >= 0
    if has_row.any() and has_col.any():
        price_values = prices_wide.to_numpy(dtype=float)
        known = price_values[np.ix_(row_pos[has_row], col_pos[has_col])]
        known[np.isnan(known)] = 1.0  # missing price keeps the neutral default
        price_matrix[np.ix_(has_row, has_col)] = known

    # --- Long layout: series-major, days in column order -----------------------
    # NOTE: timestamps are generated, not read from the calendar; the first
    # d_* column is assumed to be 2011-01-29.
    timestamps = pd.date_range(start='2011-01-29', periods=n_days, freq='D')
    series_ids = np.repeat(ids, n_days)
    timestamp_arr = np.tile(timestamps, n_series_actual)
    cal_tiled = np.tile(cal_matrix, (n_series_actual, 1))

    df = pd.DataFrame({
        'timestamp': timestamp_arr,
        'series_id': series_ids,
        'value': sales_matrix.flatten(),
        'price': price_matrix.flatten(),
    })

    for i, col in enumerate(cal_cols):
        df[col] = cal_tiled[:, i]

    # Keep only series long enough to yield at least one context+horizon window.
    valid_mask = df.groupby('series_id')['value'].transform('count') >= (CONTEXT_LENGTH + FORECAST_LENGTH)
    df = df[valid_mask].reset_index(drop=True)

    return df

data = prepare_m5_vectorized(sales, cal_lookup, cal_d_to_idx, prices_wide, cal_week_map, n_series=N_SERIES_SAMPLE)
print(f"  Data: {data.shape} | Series: {data['series_id'].nunique()}")

# ============================================================================
# Preprocessor
# ============================================================================

print("\nTraining preprocessor...")

# Every non-target channel is passed to the preprocessor as a control column.
covariate_cols = ['price', 'wday_num', 'month_num', 'year_norm', 'is_weekend',
                  'has_event_1', 'has_event_2', 'snap_CA', 'snap_TX', 'snap_WI']

column_specifiers = {
    "timestamp_column": "timestamp",
    "id_columns": ["series_id"],
    "target_columns": ["value"],
    "control_columns": covariate_cols,
}

tsp = TimeSeriesPreprocessor(
    **column_specifiers,
    context_length=CONTEXT_LENGTH,
    prediction_length=FORECAST_LENGTH,
    scaling=True,
    encode_categorical=False,
    scaler_type="standard",
)

# Split before fitting the scaler. The preprocessor used to be fit on every
# row up to the 0.8 timestamp quantile, which overlaps the validation split
# (train is only 0.6) and leaks validation statistics into the scaling.
# NOTE(review): the fraction not named in split_config is presumably the
# validation share -- confirm against prepare_data_splits.
split_params = {"train": 0.6, "test": 0.2}
train_data, valid_data, test_data = prepare_data_splits(
    data,
    id_columns=column_specifiers["id_columns"],
    split_config=split_params,
    context_length=CONTEXT_LENGTH
)

# Fit the scaling statistics on the training split only.
trained_tsp = tsp.train(train_data)

print(f"  Channels: input={tsp.num_input_channels}, target={len(tsp.prediction_channel_indices)}")

# ============================================================================
# Datasets
# ============================================================================

print("\nPreparing datasets...")

# Minimum rows a series needs inside a split to be kept for windowing.
min_length = CONTEXT_LENGTH + FORECAST_LENGTH + 2

def filter_short(df, id_cols, min_len):
    """Drop every series that has fewer than ``min_len`` rows.

    A series is one distinct combination of the ``id_cols`` values. The size of
    each row's own group is compared directly, so this works for any number of
    id columns (filtering on the first id column alone fails to match anything
    once the group keys are tuples).

    Args:
        df: Long-format frame containing the ``id_cols`` columns.
        id_cols: Column name, or list of column names, identifying a series.
        min_len: Minimum number of rows a series must have to be kept.

    Returns:
        The rows of ``df`` belonging to sufficiently long series, in their
        original order and with their original index.
    """
    if isinstance(id_cols, str):
        id_cols = [id_cols]
    id_cols = list(id_cols)

    # Size of the group each row belongs to, aligned with df row by row.
    rows_in_series = df.groupby(id_cols)[id_cols[0]].transform('size')
    keep = (rows_in_series >= min_len).to_numpy()
    return df[keep]

# Remove series that are too short within their split to form a full window.
train_data, valid_data = (
    filter_short(split_frame, column_specifiers["id_columns"], min_length)
    for split_frame in (train_data, valid_data)
)

frequency_token = tsp.get_frequency_token(tsp.freq)

# Column roles are copied from column_specifiers; the rest is window geometry.
dataset_params = {
    **{
        role: column_specifiers[role]
        for role in ("timestamp_column", "id_columns", "target_columns", "control_columns")
    },
    "frequency_token": frequency_token,
    "context_length": CONTEXT_LENGTH,
    "prediction_length": FORECAST_LENGTH,
}

train_dataset = ForecastDFDataset(tsp.preprocess(train_data), **dataset_params)
valid_dataset = ForecastDFDataset(tsp.preprocess(valid_data), **dataset_params)

# Few-shot: keep a random FEWSHOT_FRACTION of the windows of each dataset.
# The training permutation is drawn first, then the validation one.
n_train = len(train_dataset)
fewshot_train_idx = np.random.permutation(n_train)[:int(FEWSHOT_FRACTION * n_train)]
train_dataset = Subset(train_dataset, fewshot_train_idx)

n_valid = len(valid_dataset)
fewshot_valid_idx = np.random.permutation(n_valid)[:int(FEWSHOT_FRACTION * n_valid)]
valid_dataset = Subset(valid_dataset, fewshot_valid_idx)

print(f"  Train: {len(train_dataset)} | Valid: {len(valid_dataset)}")

# ============================================================================
# Load Model
# ============================================================================

print("\nLoading TTM (multivariate)...")

# Channels after the target channels are declared exogenous.
# NOTE(review): this relies on the targets occupying the first channel
# positions -- confirm against the preprocessor's channel ordering.
n_target_channels = len(tsp.prediction_channel_indices)
exogenous_channels = list(range(n_target_channels, tsp.num_input_channels))

model_kwargs = dict(
    revision=REVISION,
    context_length=CONTEXT_LENGTH,
    prediction_filter_length=FORECAST_LENGTH,
    num_input_channels=tsp.num_input_channels,
    prediction_channel_indices=tsp.prediction_channel_indices,
    exogenous_channel_indices=exogenous_channels,
    decoder_mode="mix_channel",
    enable_forecast_channel_mixing=True,
)
finetune_forecast_model = TinyTimeMixerForPrediction.from_pretrained(TTM_MODEL_PATH, **model_kwargs)

print(f"  Parameters: {count_parameters(finetune_forecast_model):,}")

# ============================================================================
# Training
# ============================================================================

print("\nTraining...")

# Checkpoints are written below the run's output directory.
OUT_DIR = str(output_dir / "training")
os.makedirs(OUT_DIR, exist_ok=True)

# Evaluate, log and checkpoint once per epoch; the checkpoint with the lowest
# validation loss is restored when training ends.
training_kwargs = dict(
    output_dir=OUT_DIR,
    overwrite_output_dir=True,
    learning_rate=LEARNING_RATE,
    num_train_epochs=NUM_EPOCHS,
    do_eval=True,
    eval_strategy="epoch",
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=2 * BATCH_SIZE,
    dataloader_num_workers=1,
    report_to="none",
    save_strategy="epoch",
    logging_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    use_cpu=True,
)
finetune_forecast_args = TrainingArguments(**training_kwargs)

# One-cycle schedule sized to the number of optimizer steps the run performs
# (batches per epoch x epochs), peaking at LEARNING_RATE.
steps_per_epoch = math.ceil(len(train_dataset) / BATCH_SIZE)
optimizer = AdamW(finetune_forecast_model.parameters(), lr=LEARNING_RATE)
scheduler = OneCycleLR(
    optimizer,
    max_lr=LEARNING_RATE,
    epochs=NUM_EPOCHS,
    steps_per_epoch=steps_per_epoch,
)

trainer_callbacks = [
    EarlyStoppingCallback(early_stopping_patience=3),
    TrackingCallback(),
]

finetune_forecast_trainer = Trainer(
    model=finetune_forecast_model,
    args=finetune_forecast_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    callbacks=trainer_callbacks,
    optimizers=(optimizer, scheduler),
)

finetune_forecast_trainer.train()

# Score the restored model on the validation windows.
valid_results = finetune_forecast_trainer.evaluate(valid_dataset)
print(f"\nValidation Loss: {valid_results['eval_loss']:.4f}")

# ============================================================================
# Forecasting - FIXED
# ============================================================================

print("\nGenerating forecasts...")

finetune_forecast_model.eval()
finetune_forecast_model.to(device)

# The evaluation file is already in memory as `sales` and has not been
# modified (the sampled run worked on a copy), so it is reused, not re-read.
sales_full = sales
series_order = sales_full['id'].tolist()

print("  Preparing full data...")
data_full = prepare_m5_vectorized(sales_full, cal_lookup, cal_d_to_idx, prices_wide, cal_week_map, n_series=None)

target_col = column_specifiers["target_columns"][0]
all_cols = [target_col] + covariate_cols  # channel order: target first, then covariates
INFERENCE_BATCH = 128

# Per-series, per-column standardisation.
# The model was fine-tuned on inputs standardised by the preprocessor
# (scaling=True, scaler_type="standard"), so inference inputs must be
# standardised the same way. A single mean/scale taken from the target must
# not be applied to every covariate, and scaling must not silently fall back
# to identity. The preprocessor only knows the series it was fitted on, so the
# statistics are computed here from each series' own history, which also
# covers series outside the fine-tuning sample.
# NOTE(review): assumes the preprocessor standardises every series and column
# independently (population std, constant columns left unscaled) -- confirm
# against the tsfm TimeSeriesPreprocessor documentation.
print("  Normalizing...")
series_groups = data_full.groupby('series_id', sort=False)
series_mean = series_groups[all_cols].mean()
series_std = series_groups[all_cols].std(ddof=0)
series_std = series_std.where(series_std > 0, 1.0)  # zero / undefined spread -> no scaling
print(f"    Scaler: per-series standardisation fitted for {len(series_mean)} series")

# Build one standardised context window per series.
print("  Building batches...")
recent_rows = series_groups.tail(CONTEXT_LENGTH)  # last CONTEXT_LENGTH rows of every series
contexts = {}  # series_id -> (context tensor, target mean, target std)
for series_id, window in tqdm(recent_rows.groupby('series_id', sort=False), desc="    Preparing"):
    if len(window) < CONTEXT_LENGTH:
        continue  # too short: falls back to an all-zero forecast below
    mean_vec = series_mean.loc[series_id].to_numpy(dtype=np.float64)
    std_vec = series_std.loc[series_id].to_numpy(dtype=np.float64)
    scaled = (window[all_cols].to_numpy(dtype=np.float64) - mean_vec) / std_vec
    contexts[series_id] = (
        torch.from_numpy(scaled.astype(np.float32)),
        mean_vec[0],   # all_cols[0] is the target column
        std_vec[0],
    )

# Batched inference in submission order.
# NOTE(review): the model was configured with exogenous channels; confirm
# whether its forward pass also expects future covariate values
# (`future_values`) for the forecast horizon.
freq_token = torch.tensor([frequency_token], dtype=torch.long, device=device)
all_forecasts_dict = {}

print("  Forecasting...")
with torch.no_grad():
    for start in tqdm(range(0, len(series_order), INFERENCE_BATCH), desc="    Batches"):
        batch_ids = series_order[start:start + INFERENCE_BATCH]
        ready_ids = [sid for sid in batch_ids if sid in contexts]

        # Series without a usable context window get an all-zero forecast.
        for sid in batch_ids:
            if sid not in contexts:
                all_forecasts_dict[sid] = np.zeros(FORECAST_LENGTH)
        if not ready_ids:
            continue

        past_values = torch.stack([contexts[sid][0] for sid in ready_ids]).to(device)
        outputs = finetune_forecast_model(
            past_values=past_values,
            freq_token=freq_token.expand(past_values.shape[0])
        )

        if hasattr(outputs, 'prediction_outputs'):
            forecasts = outputs.prediction_outputs.cpu().numpy()
        elif hasattr(outputs, 'logits'):
            forecasts = outputs.logits.cpu().numpy()
        else:
            forecasts = outputs.cpu().numpy()

        # (batch, horizon, channels) -> (batch, horizon): keep the first output channel.
        if forecasts.ndim == 3:
            forecasts = forecasts[:, :, 0]
        forecasts = forecasts[:, :FORECAST_LENGTH]

        # Undo the standardisation with each series' own target statistics,
        # then clip: sales cannot be negative.
        target_mean = np.array([contexts[sid][1] for sid in ready_ids])[:, None]
        target_std = np.array([contexts[sid][2] for sid in ready_ids])[:, None]
        forecasts = np.maximum(forecasts * target_std + target_mean, 0)

        for sid, series_forecast in zip(ready_ids, forecasts):
            all_forecasts_dict[sid] = series_forecast

all_forecasts = [all_forecasts_dict[sid] for sid in series_order]
forecast_array = np.array(all_forecasts)
print(f"  Forecast array: {forecast_array.shape}")

# ============================================================================
# WRMSSE
# ============================================================================

print("\nCalculating WRMSSE...")

# The metric lives in the project's src directory, outside the import path.
sys.path.append('../src')
try:
    from m5_wrmsse import wrmsse
    wrmsse_score = wrmsse(forecast_array)

    rule = '=' * 60
    print("\n" + rule)
    print("RESULTS")
    print(rule)
    print(f"  WRMSSE: {wrmsse_score:.4f}")
    print(rule)
except Exception as err:
    # Best effort: a failing metric is reported but must not prevent the
    # forecasts from being saved afterwards.
    print(f"  Error: {err}")
    import traceback
    traceback.print_exc()
    wrmsse_score = None

# ============================================================================
# Save
# ============================================================================

# Forecast matrix: one row per series (in submission order), one column per horizon step.
forecast_df = pd.DataFrame(forecast_array, index=series_order)
forecast_df.to_pickle(output_dir / 'forecasts.pkl')

# Run summary stored next to the forecasts.
summary = dict(
    wrmsse=wrmsse_score,
    valid_loss=valid_results['eval_loss'],
    n_series=len(series_order),
    context_length=CONTEXT_LENGTH,
    covariates=covariate_cols,
)

summary_path = output_dir / 'summary.pkl'
with summary_path.open('wb') as f:
    pickle.dump(summary, f)

print("\nDone")