In [2]:
"""
TSB (Teunter-Syntetos-Babai)
"""

import pandas as pd
import numpy as np
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

from statsforecast import StatsForecast
from statsforecast.models import TSB
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import pickle

from m5_wrmsse import wrmsse

# Run banner: a rule of '=' above and below the title, with blank lines around it.
print(
    "\n" + "=" * 80,
    "M5 FORECASTING - TSB FOR INTERMITTENT DEMAND",
    "=" * 80 + "\n",
    sep="\n",
)

# ============================================================================
# 1. SETUP & DATA LOADING
# ============================================================================

print("1/7 Caricamento dati...")

# Input / output locations, relative to the notebook's working directory.
DATA_PATH = Path("../data")
RAW_DIR = DATA_PATH / "raw"
OUTPUT_DIR = DATA_PATH / "tsb_results"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Sales: identifier columns (id_cols) followed by one column per day d_1, d_2, ...
sales = pd.read_csv(RAW_DIR / "sales_train_evaluation.csv")
id_cols = ['id', 'item_id', 'dept_id', 'cat_id', 'store_id', 'state_id']
print(f"Sales: {sales.shape}")

# Per-series cluster / difficulty table (read from the current directory, not DATA_PATH).
cluster_df = pd.read_csv("serie_cluster_difficulty.csv")
print(f"Cluster data: {cluster_df.shape}")

# Hold-out window: the test_days days immediately after the training period,
# i.e. d_1914 .. d_1941, relabelled as horizon columns F1 .. F28.
train_days = 1913
test_days = 28
first_test_day = train_days + 1
test_cols = ['d_' + str(day) for day in range(first_test_day, first_test_day + test_days)]
test_actuals = sales.loc[:, ['id', *test_cols]].copy()
test_actuals.columns = ['id', *(f'F{step}' for step in range(1, test_days + 1))]
print(f"Test actuals: {test_actuals.shape}")

# ============================================================================
# 2. DATA PREPARATION IN LONG FORMAT
# ============================================================================

print("\n2/7 Preparazione formato long per StatsForecast...")

# Training window: days 1..train_days (d_1 .. d_1913).
train_cols = [f'd_{i}' for i in range(1, train_days + 1)]
sales_train = sales[['id'] + train_cols].copy()

# Wide -> long: one row per (series id, day label).
sales_long = sales_train.melt(
    id_vars=['id'],
    value_vars=train_cols,
    var_name='d',
    value_name='y',
)

# Integer day number from the 'd_<n>' label.
# The pattern is a raw string: a backslash-d inside a plain string literal is an
# invalid escape sequence (DeprecationWarning, SyntaxWarning from Python 3.12).
# expand=False returns a Series rather than a one-column DataFrame.
sales_long['d'] = sales_long['d'].str.extract(r'(\d+)', expand=False).astype(int)

# Calendar date: day 1 is 2011-01-29 and each later day adds one calendar day.
start_date = pd.Timestamp('2011-01-29')
sales_long['ds'] = start_date + pd.to_timedelta(sales_long['d'] - 1, unit='D')
sales_long = sales_long.rename(columns={'id': 'unique_id'})

# Keep only the (unique_id, ds, y) columns handed to StatsForecast,
# ordered per series in time order.
sales_long = sales_long[['unique_id', 'ds', 'y']].sort_values(['unique_id', 'ds'])

print(f"Sales long: {sales_long.shape}")
print(f"Date range: {sales_long['ds'].min()} to {sales_long['ds'].max()}")
print("\nSample:")
print(sales_long.head(10))

# ============================================================================
# 3. SPARSITY ANALYSIS & INTERMITTENT-SERIES IDENTIFICATION
# ============================================================================

print("\n3/7 Analisi sparsity e identificazione serie intermittenti...")

# Sparsity = share of zero-sales days in the training window, per series.
# One vectorised comparison over the whole matrix replaces the former
# Python-level row-wise apply (same values, far less interpreter overhead).
is_zero = sales_train.set_index('id')[train_cols] == 0
sparsity_df = is_zero.mean(axis=1).rename('sparsity').rename_axis('id').reset_index()
del is_zero  # release the boolean matrix before the heavy forecasting step

# Attach cluster info. validate= makes duplicate ids in cluster_df fail loudly
# instead of silently duplicating series rows.
sparsity_df = sparsity_df.merge(
    cluster_df[['id', 'difficulty', 'cluster']],
    on='id',
    how='left',
    validate='many_to_one',
)

# Series absent from cluster_df get NaN difficulty and are skipped by the loop below.
n_unlabelled = int(sparsity_df['difficulty'].isna().sum())
if n_unlabelled:
    print(f"  ⚠ {n_unlabelled:,} series without a difficulty label")

# Sparsity statistics per difficulty level.
print("\nSparsity per difficulty level:")
for diff in ['Easy', 'Medium', 'Hard']:
    subset = sparsity_df[sparsity_df['difficulty'] == diff]['sparsity']
    print(f"  {diff:8s}: mean={subset.mean():.2%}, median={subset.median():.2%}, "
          f"n={len(subset):5d}")

# Intermittent series: sparsity above the threshold OR difficulty == 'Hard'.
# NOTE: section 4 fits TSB on all series; this list is only counted in the summary.
intermittent_threshold = 0.50
intermittent_ids = sparsity_df[
    (sparsity_df['sparsity'] > intermittent_threshold) |
    (sparsity_df['difficulty'] == 'Hard')
]['id'].tolist()

print(f"\n✓ Serie intermittent identificate: {len(intermittent_ids):,} ({len(intermittent_ids)/len(sales)*100:.1f}%)")

# ============================================================================
# 4. TSB FORECASTING - ALL SERIES
# ============================================================================

print("\n4/7 TSB forecasting su tutte le serie...")

# One TSB model with fixed smoothing parameters for demand (alpha_d) and
# probability (alpha_p). n_jobs=-1 lets StatsForecast use all available cores.
sf = StatsForecast(
    models=[TSB(alpha_p=0.2, alpha_d=0.2)],
    n_jobs=-1,
    freq='D',
)

# Point forecasts only (no prediction intervals requested), test_days steps ahead.
forecasts_tsb = sf.forecast(h=test_days, df=sales_long)

for summary_line in (
    f"✓ TSB forecasts shape: {forecasts_tsb.shape}",
    f"  Columns: {forecasts_tsb.columns.tolist()}",
):
    print(summary_line)

# ============================================================================
# 5. POST-PROCESSING & FORMATTING
# ============================================================================

print("\n5/7 Post-processing forecasts...")

# Some StatsForecast releases return unique_id as the index, others as a column.
# Only move it out of the index when needed; an unconditional reset_index()
# would otherwise add a spurious 'index' column.
if 'unique_id' not in forecasts_tsb.columns:
    forecasts_tsb = forecasts_tsb.reset_index()

# Wide layout: one row per series, one column per forecast date (ascending).
forecasts_wide = forecasts_tsb.pivot(index='unique_id', columns='ds', values='TSB')

# The date columns are renamed positionally, so there must be exactly test_days of them.
if forecasts_wide.shape[1] != test_days:
    raise ValueError(
        f"Expected {test_days} forecast dates, got {forecasts_wide.shape[1]}"
    )

# Horizon columns F1 .. F{test_days}, derived from test_days rather than hard-coded.
test_all_cols = [f'F{i}' for i in range(1, test_days + 1)]
forecasts_wide.columns = test_all_cols

# 'id' as a column, rows in the original `sales` order (the same order as test_actuals).
forecasts_wide = (
    forecasts_wide.reset_index()
    .rename(columns={'unique_id': 'id'})
    .set_index('id')
    .reindex(sales['id'])
    .reset_index()
)

# Series with no forecast at all come back as all-NaN rows from the reindex.
n_missing = int(forecasts_wide[test_all_cols].isna().all(axis=1).sum())
if n_missing:
    print(f"  ⚠ {n_missing:,} series without a TSB forecast -> filled with 0")

# Missing values -> 0, negative forecasts clipped to 0 (sales cannot be negative).
forecasts_wide[test_all_cols] = forecasts_wide[test_all_cols].fillna(0).clip(lower=0)

print(f"✓ Forecasts formatted: {forecasts_wide.shape}")
print(f"  Columns: {forecasts_wide.columns.tolist()[:5]}...")

# ============================================================================
# 6. WRMSSE EVALUATION
# ============================================================================

print("\n6/7 Calcolo WRMSSE...")

# Global WRMSSE on the (n_series, test_days) forecast matrix, rows in `sales` order.
# NOTE(review): `wrmsse` comes from the local m5_wrmsse module; it presumably
# relies on this row order and evaluation window — confirm there.
forecast_array = forecasts_wide[test_all_cols].values
wrmsse_tsb = wrmsse(forecast_array)

print(f"\n✅ TSB WRMSSE (global): {wrmsse_tsb:.4f}")

# Per-difficulty breakdown. This is a plain MAE proxy (no scaling, no weights),
# NOT WRMSSE, so the header says so.
print("\nMAE per difficulty level (proxy, not WRMSSE):")

# Left merge keeps forecasts_wide's row order. validate= guarantees one cluster
# row per id, so boolean masks over this frame line up with forecasts_wide.
forecasts_with_diff = forecasts_wide.merge(
    cluster_df[['id', 'difficulty', 'sparsity']],
    on='id',
    how='left',
    validate='many_to_one',
)

# Actuals aligned to forecasts_wide explicitly by id, instead of relying on two
# separately filtered frames happening to share the same row order.
actuals_aligned = (
    test_actuals.set_index('id')[test_all_cols]
    .reindex(forecasts_wide['id'])
    .to_numpy()
)
abs_err = np.abs(forecast_array - actuals_aligned)

# Name kept for compatibility: the summary in section 7 stores this dict.
# Values are MAE.
wrmsse_by_difficulty = {}

for diff in ['Easy', 'Medium', 'Hard']:
    mask = (forecasts_with_diff['difficulty'] == diff).to_numpy()

    if not mask.any():
        continue

    # Mean absolute error over every (series, day) cell of this difficulty level.
    mae = abs_err[mask].mean()

    wrmsse_by_difficulty[diff] = mae

    print(f"  {diff:8s}: MAE={mae:.4f}, n={mask.sum():5d} serie, "
          f"sparsity_mean={forecasts_with_diff.loc[mask, 'sparsity'].mean():.2%}")


# ============================================================================
# 7. SAVING
# ============================================================================

print("\nSalvataggio risultati...")

# Forecast table: 'id' + horizon columns, rows in `sales` order
# (format compatible with Notebook 09).
forecasts_wide.to_pickle(OUTPUT_DIR / 'tsb_forecasts.pkl')
print(f"✓ tsb_forecasts.pkl ({forecasts_wide.shape})")

# Read the smoothing parameters back from the configured model, so the summary
# cannot drift from what section 4 actually ran (previously duplicated literals).
tsb_config = sf.models[0]

tsb_summary = {
    'wrmsse_global': float(wrmsse_tsb),
    # NOTE: despite the key name, these per-difficulty values are MAE (section 6).
    # The key is kept for existing readers; 'by_difficulty_metric' states the metric.
    'wrmsse_by_difficulty': {k: float(v) for k, v in wrmsse_by_difficulty.items()},
    'by_difficulty_metric': 'MAE',
    'n_intermittent_series': len(intermittent_ids),
    'intermittent_threshold': intermittent_threshold,
    'model': 'TSB',
    'alpha_d': float(tsb_config.alpha_d),
    'alpha_p': float(tsb_config.alpha_p),
}

with open(OUTPUT_DIR / 'tsb_summary.pkl', 'wb') as f:
    pickle.dump(tsb_summary, f)
print("✓ tsb_summary.pkl")



M5 FORECASTING - TSB FOR INTERMITTENT DEMAND

1/7 Caricamento dati...
Sales: (30490, 1947)
Cluster data: (30490, 21)
Test actuals: (30490, 29)

2/7 Preparazione formato long per StatsForecast...
Sales long: (58327370, 3)
Date range: 2011-01-29 00:00:00 to 2016-04-24 00:00:00

Sample:
                          unique_id         ds  y
1612    FOODS_1_001_CA_1_evaluation 2011-01-29  3
32102   FOODS_1_001_CA_1_evaluation 2011-01-30  0
62592   FOODS_1_001_CA_1_evaluation 2011-01-31  0
93082   FOODS_1_001_CA_1_evaluation 2011-02-01  1
123572  FOODS_1_001_CA_1_evaluation 2011-02-02  4
154062  FOODS_1_001_CA_1_evaluation 2011-02-03  2
184552  FOODS_1_001_CA_1_evaluation 2011-02-04  0
215042  FOODS_1_001_CA_1_evaluation 2011-02-05  2
245532  FOODS_1_001_CA_1_evaluation 2011-02-06  0
276022  FOODS_1_001_CA_1_evaluation 2011-02-07  0

3/7 Analisi sparsity e identificazione serie intermittenti...

Sparsity per difficulty level:
  Easy    : mean=21.78%, median=21.12%, n= 1857
  Medium  : mean=60.7

Forecast: 100%|██████████| 3049/3049 [00:03<00:00, 858.03it/s]
Forecast: 100%|██████████| 3049/3049 [00:03<00:00, 901.83it/s]
Forecast: 100%|██████████| 3049/3049 [00:03<00:00, 875.61it/s]
Forecast: 100%|██████████| 3049/3049 [00:03<00:00, 882.62it/s]
Forecast: 100%|██████████| 3049/3049 [00:03<00:00, 882.57it/s]
Forecast: 100%|██████████| 3049/3049 [00:03<00:00, 862.47it/s]
Forecast: 100%|██████████| 3049/3049 [00:03<00:00, 867.38it/s]
Forecast: 100%|██████████| 3049/3049 [00:03<00:00, 869.85it/s]
Forecast: 100%|██████████| 3049/3049 [00:03<00:00, 882.63it/s]
Forecast: 100%|██████████| 3049/3049 [00:03<00:00, 871.93it/s]


✓ TSB forecasts shape: (853720, 2)
  Columns: ['ds', 'TSB']

5/7 Post-processing forecasts...
✓ Forecasts formatted: (30490, 29)
  Columns: ['id', 'F1', 'F2', 'F3', 'F4']...

6/7 Calcolo WRMSSE...

✅ TSB WRMSSE (global): 1.0733

WRMSSE per difficulty level:
  Easy    : MAE=2.5609, n= 1857 serie, sparsity_mean=21.78%
  Medium  : MAE=1.2256, n=18316 serie, sparsity_mean=60.74%
  Hard    : MAE=0.5303, n=10317 serie, sparsity_mean=89.79%

Salvataggio risultati...
✓ tsb_forecasts.pkl ((30490, 29))
✓ tsb_summary.pkl
