In [3]:
# ============================================================================
# TTM 
# ============================================================================

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import pickle
import warnings
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from m5_wrmsse import wrmsse
warnings.filterwarnings('ignore')

plt.style.use('seaborn-v0_8-darkgrid')
device = 'mps' if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() else 'cpu'
print(f"Device: {device}")

# ============================================================================
# 1. SETUP
# ============================================================================

DATA_DIR = Path("../data/processed")
RAW_DIR = Path("../data/raw")
OUTPUT_DIR = Path("../data/ttm_finetuned")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Hyperparameters 
CONTEXT_LENGTH = 512
FORECAST_HORIZON = 28
BATCH_SIZE = 128          
NUM_EPOCHS = 10
LEARNING_RATE = 1e-3
PATIENCE = 3
MIN_DELTA = 0.001

print(f"\nConfig:")
print(f"  Context: {CONTEXT_LENGTH} giorni")
print(f"  Horizon: {FORECAST_HORIZON} giorni")
print(f"  Batch size: {BATCH_SIZE} (aumentato per velocità)")
print(f"  Max epochs: {NUM_EPOCHS}")
print(f"  Early stopping patience: {PATIENCE}")
print(f"  Device: {device}")

# ============================================================================
# 2. CARICA DATI
# ============================================================================

print("\n[1/6] Caricamento dati...")

with open(DATA_DIR / "train_official.pkl", 'rb') as f:
    train = pickle.load(f)
with open(DATA_DIR / "eval_official.pkl", 'rb') as f:
    eval_data = pickle.load(f)

print(f"✓ Train: {train.shape}")
print(f"✓ Eval: {eval_data.shape}")

# Pivot
train_pivot = train.pivot(index='date', columns='id', values='sales')
print(f"✓ Train pivot: {train_pivot.shape}")

# ============================================================================
# 3. DATASET
# ============================================================================

print("\n[2/6] Preparazione dataset...")

class M5TimeSeriesDataset(Dataset):
    def __init__(self, data, context_length, forecast_horizon):
        self.context_length = context_length
        self.forecast_horizon = forecast_horizon
        self.data = data
        self.samples = []
        
        print(f"  Creando windows da {len(self.data.columns)} serie...")
        for col in tqdm(self.data.columns, desc="  Processing"):
            series = self.data[col].values
            
            if len(series) < context_length + forecast_horizon:
                continue
            
            for i in range(0, len(series) - context_length - forecast_horizon + 1, 28):
                context = series[i:i+context_length]
                target = series[i+context_length:i+context_length+forecast_horizon]
                
                context_mean = context.mean()
                context_std = context.std() + 1e-6
                
                context_norm = (context - context_mean) / context_std
                target_norm = (target - context_mean) / context_std
                
                self.samples.append({
                    'context': torch.FloatTensor(context_norm),
                    'target': torch.FloatTensor(target_norm),
                    'mean': context_mean,
                    'std': context_std
                })
        
        print(f"  ✓ Created {len(self.samples)} training samples")
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        return self.samples[idx]

train_dataset = M5TimeSeriesDataset(train_pivot, CONTEXT_LENGTH, FORECAST_HORIZON)
print(f"✓ Dataset ready: {len(train_dataset)} samples")
print(f"  Batches per epoch: {len(train_dataset) // BATCH_SIZE}")

# ============================================================================
# 4. MODELLO
# ============================================================================

print("\n[3/6] Definizione modello...")

class TimeSeriesTransformer(nn.Module):
    def __init__(self, context_length, forecast_horizon, hidden_dim=256):
        super().__init__()
        self.context_length = context_length
        self.forecast_horizon = forecast_horizon
        
        self.encoder = nn.Sequential(
            nn.Linear(context_length, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
        )
        
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim // 4, forecast_horizon)
        )
        
    def forward(self, x):
        encoded = self.encoder(x)
        forecast = self.decoder(encoded)
        return forecast

model = TimeSeriesTransformer(CONTEXT_LENGTH, FORECAST_HORIZON).to(device)
total_params = sum(p.numel() for p in model.parameters())

print(f"✓ Modello creato: {total_params:,} params")

# ============================================================================
# 5. TRAINING CON EARLY STOPPING
# ============================================================================

print("\n[4/6] Training...")

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.5)
criterion = nn.MSELoss()

losses_history = []
best_loss = float('inf')
best_epoch = -1
patience_counter = 0

for epoch in range(NUM_EPOCHS):
    model.train()
    epoch_losses = []
    
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
    
    for batch in progress_bar:
        context = batch['context'].to(device)
        target = batch['target'].to(device)
        
        optimizer.zero_grad()
        output = model(context)
        loss = criterion(output, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        epoch_losses.append(loss.item())
        progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    avg_loss = np.mean(epoch_losses)
    losses_history.append(avg_loss)
    scheduler.step(avg_loss)
    current_lr = optimizer.param_groups[0]['lr']
    
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} - Avg Loss: {avg_loss:.4f} - LR: {current_lr:.6f}")
    
    # Early stopping
    if avg_loss < best_loss - MIN_DELTA:
        best_loss = avg_loss
        best_epoch = epoch + 1  # Aggiorna best_epoch
        patience_counter = 0
        torch.save(model.state_dict(), OUTPUT_DIR / 'best_model.pt')
        print(f"  ✅ Best model saved (loss: {best_loss:.4f})")
    else:
        patience_counter += 1
        print(f"  ⚠️  No improvement ({patience_counter}/{PATIENCE})")
        
        if patience_counter >= PATIENCE:
            print(f"\n🛑 Early stopping at epoch {epoch+1}")
            break

print(f"\n✓ Training completato! Best loss: {best_loss:.4f}, Best epoch: {best_epoch}")

# ============================================================================
# 6. FORECASTING E CALCOLO WRMSSE
# ============================================================================

print("\n[5/6] Generazione forecasts...")

model.load_state_dict(torch.load(OUTPUT_DIR / 'best_model.pt'))
model.eval()

# Carica ordine originale serie
sales_orig = pd.read_csv(RAW_DIR / "sales_train_evaluation.csv")
series_order = sales_orig['id'].tolist()

print(f"Serie totali: {len(series_order)}")

all_forecasts = []

with torch.no_grad():
    for series_id in tqdm(series_order, desc="Forecasting"):
        if series_id in train_pivot.columns:
            context = train_pivot[series_id].tail(CONTEXT_LENGTH).values
            
            context_mean = context.mean()
            context_std = context.std() + 1e-6
            context_norm = (context - context_mean) / context_std
            
            context_tensor = torch.FloatTensor(context_norm).unsqueeze(0).to(device)
            output = model(context_tensor)
            
            forecast_norm = output.cpu().numpy().flatten()
            forecast = forecast_norm * context_std + context_mean
            forecast = np.maximum(forecast, 0)
        else:
            if series_id in train['id'].values:
                last_val = train[train['id'] == series_id]['sales'].iloc[-1]
                forecast = np.full(FORECAST_HORIZON, last_val)
            else:
                forecast = np.zeros(FORECAST_HORIZON)
        
        all_forecasts.append(forecast)

forecast_array = np.array(all_forecasts)
print(f"✓ Forecast array: {forecast_array.shape}")

print("\n[6/6] Calcolo WRMSSE...")

wrmsse_score = wrmsse(forecast_array)

print(f"\n✅ WRMSSE: {wrmsse_score:.4f}")

# ============================================================================
# 7. SALVATAGGIO
# ============================================================================

forecast_df = pd.DataFrame(forecast_array, index=series_order)
forecast_df.to_pickle(OUTPUT_DIR / 'ttm_forecasts.pkl')

summary = {
    'wrmsse': wrmsse_score,
    'epochs_trained': len(losses_history),
    'best_epoch': best_epoch,
    'best_loss': best_loss,
    'n_series': len(series_order),
}

with open(OUTPUT_DIR / 'ttm_summary.pkl', 'wb') as f:
    pickle.dump(summary, f)

print("\n✓ Summary salvato")



Device: mps

Config:
  Context: 512 giorni
  Horizon: 28 giorni
  Batch size: 128 (aumentato per velocità)
  Max epochs: 10
  Early stopping patience: 3
  Device: mps

[1/6] Caricamento dati...
✓ Train: (58327370, 11)
✓ Eval: (853720, 11)
✓ Train pivot: (1913, 30490)

[2/6] Preparazione dataset...
  Creando windows da 30490 serie...


  Processing: 100%|█████████████████████| 30490/30490 [00:25<00:00, 1194.30it/s]


  ✓ Created 1524500 training samples
✓ Dataset ready: 1524500 samples
  Batches per epoch: 11910

[3/6] Definizione modello...
✓ Modello creato: 241,372 params

[4/6] Training...


Epoch 1/10: 100%|███████████| 11911/11911 [01:08<00:00, 172.80it/s, loss=4.0293]


Epoch 1/10 - Avg Loss: 34458409384.0568 - LR: 0.001000
  ✅ Best model saved (loss: 34458409384.0568)


Epoch 2/10: 100%|███████████| 11911/11911 [01:07<00:00, 176.07it/s, loss=2.7698]


Epoch 2/10 - Avg Loss: 34458364809.2081 - LR: 0.001000
  ✅ Best model saved (loss: 34458364809.2081)


Epoch 3/10: 100%|███████████| 11911/11911 [01:08<00:00, 172.74it/s, loss=7.5165]


Epoch 3/10 - Avg Loss: 34458420115.6638 - LR: 0.001000
  ⚠️  No improvement (1/3)


Epoch 4/10: 100%|█| 11911/11911 [01:10<00:00, 168.41it/s, loss=39285538816.0000]


Epoch 4/10 - Avg Loss: 34461181251.8401 - LR: 0.000500
  ⚠️  No improvement (2/3)


Epoch 5/10: 100%|███████████| 11911/11911 [01:09<00:00, 170.97it/s, loss=8.3934]


Epoch 5/10 - Avg Loss: 34458409815.1267 - LR: 0.000500
  ⚠️  No improvement (3/3)

🛑 Early stopping at epoch 5

✓ Training completato! Best loss: 34458364809.2081, Best epoch: 2

[5/6] Generazione forecasts...
Serie totali: 30490


Forecasting: 100%|██████████████████████| 30490/30490 [00:26<00:00, 1165.20it/s]


✓ Forecast array: (30490, 28)

[6/6] Calcolo WRMSSE...

✅ WRMSSE: 0.9783

✓ Summary salvato
