## 5. Transformer Training with Early Stopping

### 5.1 Training Function

In [None]:
def _save_recovery_checkpoint(epoch, best_model_state, best_val_acc, training_history):
    """Best-effort save of the best weights seen so far, used just before a training failure propagates.

    Does nothing when no best state has been recorded yet. Any error raised while
    saving is logged and suppressed so it cannot replace the original failure.
    """
    if best_model_state is None:
        return
    recovery_path = f'checkpoints/transformer_recovery_acc_{best_val_acc:.4f}.pt'
    try:
        torch.save({
            'epoch': epoch,
            'model_state_dict': best_model_state,
            'val_accuracy': best_val_acc,
            'training_history': training_history
        }, recovery_path)
    except Exception as save_error:  # best-effort: never hide the real error
        logger.error(f"   ⚠️ Could not save recovery checkpoint {recovery_path}: {save_error}")
        return
    logger.info(f"   💾 Saved recovery checkpoint: {recovery_path}")


def _snapshot_state_dict(model):
    """Return a detached CPU copy of ``model.state_dict()``.

    ``state_dict()`` returns references to the live parameter/buffer tensors, which
    the optimizer updates in place. A shallow ``dict.copy()`` would therefore keep
    tracking the latest weights instead of freezing the best ones. Copying to CPU
    also avoids doubling parameter memory on the accelerator.
    """
    return {
        name: tensor.detach().to('cpu', copy=True)
        for name, tensor in model.state_dict().items()
    }


@error_handler
def train_transformer_with_early_stopping(model, X_train, y_train, areas_train, X_val, y_val, areas_val, device):
    """Train the transformer with full-batch AdamW, cosine LR decay and early stopping.

    Each epoch performs one optimisation step over the whole training set (all of
    ``X_train`` is fed to the model at once). It then measures accuracy on the
    validation set. When validation accuracy improves, the function keeps a CPU
    snapshot of the weights and writes a checkpoint under ``checkpoints/``.
    Training stops after ``EARLY_STOPPING_PATIENCE`` consecutive epochs without
    improvement. The best snapshot is then loaded back into ``model``.

    Args:
        model: module called as ``model(features, areas)``. It must return a mapping
            with a ``'fire_logits'`` entry of class scores (argmax over dim 1).
        X_train, X_val: feature arrays, converted with ``torch.FloatTensor``.
        y_train, y_val: integer class labels, converted with ``torch.LongTensor``.
        areas_train, areas_val: integer area ids, converted with ``torch.LongTensor``.
        device: target device for all tensors.

    Returns:
        ``(model, best_val_acc, training_history)``. ``training_history`` is a list
        with one dict per completed epoch (epoch, train_loss, val_accuracy,
        learning_rate, time, memory_usage).

    Raises:
        TrainingProcessError: on any failure. Before raising, a recovery checkpoint
            of the best weights is written when one exists. (``error_handler`` is
            defined elsewhere and may further alter propagation.)
    """
    logger.info("🤖 TRAINING OPTIMIZED TRANSFORMER")
    logger.info("=" * 40)

    try:
        # Convert to tensors with error checking
        try:
            X_train_tensor = torch.FloatTensor(X_train).to(device)
            X_val_tensor = torch.FloatTensor(X_val).to(device)
            y_train_tensor = torch.LongTensor(y_train).to(device)
            y_val_tensor = torch.LongTensor(y_val).to(device)
            areas_train_tensor = torch.LongTensor(areas_train).to(device)
            areas_val_tensor = torch.LongTensor(areas_val).to(device)
        except Exception as e:
            raise ModelInitializationError(f"Error converting data to tensors: {e}") from e

        # Training setup. The scheduler is stepped once per epoch, so the cosine
        # schedule spans the full EPOCHS budget.
        optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
        criterion = nn.CrossEntropyLoss()

        best_val_acc = 0.0
        best_model_state = None
        patience_counter = 0

        logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
        logger.info(f"Training for {EPOCHS} epochs with early stopping...")

        os.makedirs('checkpoints', exist_ok=True)

        start_time = time.time()
        training_history = []

        for epoch in range(EPOCHS):
            epoch_start_time = time.time()

            try:
                # Training: a single full-batch step.
                model.train()
                optimizer.zero_grad()

                outputs = model(X_train_tensor, areas_train_tensor)
                loss = criterion(outputs['fire_logits'], y_train_tensor)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                train_loss = loss.item()

                # Validation
                model.eval()
                with torch.no_grad():
                    val_outputs = model(X_val_tensor, areas_val_tensor)
                    val_preds = torch.argmax(val_outputs['fire_logits'], dim=1)
                    val_acc = (val_preds == y_val_tensor).float().mean().item()

                # Currently allocated CUDA memory in GB (0 when CUDA is unavailable).
                if torch.cuda.is_available():
                    memory_usage = torch.cuda.memory_allocated() / 1e9
                else:
                    memory_usage = 0
                current_lr = scheduler.get_last_lr()[0]

                epoch_time = time.time() - epoch_start_time
                logger.info(f"Epoch {epoch:3d}: Loss={train_loss:.4f}, Val_Acc={val_acc:.4f}, "
                            f"Time={epoch_time:.1f}s, Memory={memory_usage:.2f}GB")

                training_history.append({
                    'epoch': epoch,
                    'train_loss': train_loss,
                    'val_accuracy': val_acc,
                    'learning_rate': current_lr,
                    'time': epoch_time,
                    'memory_usage': memory_usage
                })

                if epoch % VISUALIZATION_CONFIG['update_interval'] == 0:
                    update_dashboard(
                        dashboard_fig, dashboard_axes, dashboard_data,
                        epoch, train_loss, val_acc, current_lr,
                        epoch_time, memory_usage
                    )

                # Always snapshot the first epoch so the restored model matches the
                # reported best accuracy even if accuracy never rises above 0.
                if best_model_state is None or val_acc > best_val_acc:
                    best_val_acc = val_acc
                    best_model_state = _snapshot_state_dict(model)
                    patience_counter = 0

                    checkpoint_path = f'checkpoints/transformer_epoch_{epoch}_acc_{val_acc:.4f}.pt'
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'val_accuracy': val_acc,
                        'training_history': training_history
                    }, checkpoint_path)
                    logger.info(f"   💾 Saved checkpoint: {checkpoint_path}")
                else:
                    patience_counter += 1

                if patience_counter >= EARLY_STOPPING_PATIENCE:
                    logger.info(f"Early stopping at epoch {epoch}")
                    break

            except RuntimeError as e:
                if "CUDA out of memory" in str(e):
                    logger.error(f"❌ CUDA out of memory at epoch {epoch}")
                    logger.error("   Try reducing batch size or model size")

                    # Release cached blocks so the recovery save has a chance to run.
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

                    _save_recovery_checkpoint(epoch, best_model_state, best_val_acc, training_history)
                    raise TrainingProcessError(f"CUDA out of memory at epoch {epoch}") from e

                logger.error(f"❌ Runtime error at epoch {epoch}: {e}")
                _save_recovery_checkpoint(epoch, best_model_state, best_val_acc, training_history)
                raise TrainingProcessError(f"Runtime error at epoch {epoch}: {e}") from e

            except Exception as e:
                logger.error(f"❌ Error during epoch {epoch}: {e}")
                logger.error(traceback.format_exc())
                _save_recovery_checkpoint(epoch, best_model_state, best_val_acc, training_history)
                raise TrainingProcessError(f"Error during epoch {epoch}: {e}") from e

        # Restore the best weights (the snapshot is an independent copy).
        if best_model_state is not None:
            model.load_state_dict(best_model_state)

        training_time = time.time() - start_time
        logger.info(f"✅ Transformer training completed!")
        logger.info(f"   Best validation accuracy: {best_val_acc:.4f}")
        logger.info(f"   Training time: {training_time:.1f}s ({training_time/60:.1f} min)")

        return model, best_val_acc, training_history

    except TrainingProcessError:
        # Already logged and wrapped by the epoch-level handlers; don't wrap twice.
        raise
    except Exception as e:
        logger.error(f"❌ Unexpected error in transformer training: {e}")
        logger.error(traceback.format_exc())
        raise TrainingProcessError(f"Unexpected error in transformer training: {e}") from e

### 5.2 Learning Curves Visualization

In [None]:
def _plot_moving_average(ax, epochs, values, window_size):
    """Overlay a trailing moving average of ``values`` on ``ax``; no-op when window_size < 2."""
    if window_size < 2:
        return
    moving_avg = np.convolve(values, np.ones(window_size) / window_size, mode='valid')
    # 'valid' mode yields len(values) - window_size + 1 points; point i is the mean
    # of the window ending at epochs[i + window_size - 1].
    ax.plot(epochs[window_size - 1:], moving_avg, 'r--', label=f'{window_size}-epoch Moving Avg')


def plot_learning_curves(training_history):
    """Plot training loss and validation accuracy per epoch, each with a moving average.

    Args:
        training_history: list of per-epoch dicts containing at least the keys
            ``'epoch'``, ``'train_loss'`` and ``'val_accuracy'``.

    When ``VISUALIZATION_CONFIG['save_figures']`` is true, the figure is written to
    ``<figure_dir>/learning_curves.png``. The directory is created if needed.
    """
    epochs = [entry['epoch'] for entry in training_history]
    train_losses = [entry['train_loss'] for entry in training_history]
    val_accuracies = [entry['val_accuracy'] for entry in training_history]

    # Up to a 5-epoch window, shrunk for short histories.
    window_size = min(5, len(train_losses))

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # Loss curve
    ax1.plot(epochs, train_losses, 'b-', marker='o', label='Training Loss')
    _plot_moving_average(ax1, epochs, train_losses, window_size)
    ax1.set_title('Training Loss Curve')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.grid(True)
    ax1.legend()

    # Accuracy curve
    ax2.plot(epochs, val_accuracies, 'g-', marker='o', label='Validation Accuracy')
    _plot_moving_average(ax2, epochs, val_accuracies, window_size)
    ax2.set_title('Validation Accuracy Curve')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy')
    ax2.grid(True)
    ax2.legend()

    plt.tight_layout()

    if VISUALIZATION_CONFIG['save_figures']:
        figure_dir = VISUALIZATION_CONFIG['figure_dir']
        # savefig does not create missing directories.
        os.makedirs(figure_dir, exist_ok=True)
        plt.savefig(os.path.join(figure_dir, 'learning_curves.png'), dpi=300)

    plt.show()

### 5.3 Train Transformer Model

In [None]:
# Fit the transformer. The call hands back the best-epoch model, its validation
# accuracy and the list of per-epoch history records.
(model, best_val_acc,
 training_history) = train_transformer_with_early_stopping(
    model,
    X_train, y_train, areas_train,
    X_val, y_val, areas_val,
    device,
)

# Loss / accuracy curves (with moving averages) drawn from the history records.
plot_learning_curves(training_history)

# NOTE(review): `memory_usage` and `timestamps` are not defined in this part of the
# notebook. The training function keeps `memory_usage` as a local and exposes it
# only via training_history[i]['memory_usage']. Confirm an earlier cell defines
# these globals; otherwise this line raises NameError.
visualize_memory_usage(memory_usage, timestamps)

## 6. ML Ensemble Models

### 6.1 Feature Engineering

In [None]:
def engineer_features(X):
    """Summarise every (sample, channel) time series into 10 scalar features.

    For each sample ``i`` and channel ``j`` the series ``X[i, :, j]`` is reduced to:
    mean, std, min, max, median, 25th percentile, 75th percentile, least-squares
    linear slope, mean absolute first difference, and std of first differences.
    When the series has a single time step, the last three features are 0.

    The reductions are vectorised over the time axis. This replaces a Python loop
    that called ``np.polyfit`` once per series.

    Args:
        X: array-like indexed as ``X[sample, time, channel]``
            (shape ``(n_samples, n_timesteps, n_channels)``).

    Returns:
        float64 array of shape ``(n_samples, n_channels * 10)``, ordered
        channel-major: the 10 features of channel 0, then channel 1, and so on.
        An empty batch yields shape ``(0, n_channels * 10)``.
    """
    logger.info("🔧 Engineering features...")
    start_time = time.time()

    X = np.asarray(X, dtype=np.float64)
    n_samples, n_steps, n_channels = X.shape

    # Each entry has shape (n_samples, n_channels); reductions run over time (axis 1).
    per_channel = [
        X.mean(axis=1),
        X.std(axis=1),
        X.min(axis=1),
        X.max(axis=1),
        np.median(X, axis=1),
        np.percentile(X, 25, axis=1),
        np.percentile(X, 75, axis=1),
    ]

    if n_steps > 1:
        # OLS slope against t = 0..T-1. With centred t, sum(t_c) == 0, so
        # slope = sum(t_c * y) / sum(t_c ** 2) -- same as np.polyfit(t, y, 1)[0].
        t_centered = np.arange(n_steps, dtype=np.float64)
        t_centered -= t_centered.mean()
        slope = np.einsum('t,itj->ij', t_centered, X) / np.dot(t_centered, t_centered)
        diff = np.diff(X, axis=1)
        per_channel += [slope, np.abs(diff).mean(axis=1), diff.std(axis=1)]
    else:
        zeros = np.zeros((n_samples, n_channels))
        per_channel += [zeros, zeros, zeros]

    # (n_samples, n_channels, 10) -> channel-major flat feature vector per sample.
    features = np.stack(per_channel, axis=2).reshape(n_samples, n_channels * len(per_channel))

    logger.info(f"   ✅ Engineered {features.shape[1]} features for {features.shape[0]} samples")
    logger.info(f"   ⏱️ Time: {time.time() - start_time:.1f}s")

    return features

### 6.2 ML Ensemble Training

In [None]:
@error_handler
def train_optimized_ml_ensemble(X_train, y_train, X_val, y_val):
    """Train tree-based classifiers on engineered, standardised features.

    Random Forest is always trained. XGBoost and LightGBM are trained only when
    ``XGB_AVAILABLE`` / ``LGB_AVAILABLE`` are set.

    Returns:
        ml_models: dict mapping model key -> fitted estimator
            (``'random_forest'``, optionally ``'xgboost'``, ``'lightgbm'``).
        ml_results: dict mapping model key -> validation accuracy
            (``estimator.score`` on the scaled validation features).
        scaler: ``StandardScaler`` fitted on the training features; apply it to
            ``engineer_features(...)`` output before calling the models.
        X_train_features: the *unscaled* engineered training features.
    """
    logger.info("📊 TRAINING ML ENSEMBLE")
    logger.info("=" * 30)

    # engineer_features() logs its own "Engineering features..." banner.
    X_train_features = engineer_features(X_train)
    X_val_features = engineer_features(X_val)

    # Fit the scaler on training data only so validation statistics don't leak in.
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train_features)
    X_val_scaled = scaler.transform(X_val_features)

    ml_models = {}
    ml_results = {}

    def fit_and_score(key, label, estimator):
        """Fit ``estimator``, record it and its validation accuracy under ``key``, and log the result."""
        fit_start_time = time.time()
        estimator.fit(X_train_scaled, y_train)
        val_acc = estimator.score(X_val_scaled, y_val)
        ml_models[key] = estimator
        ml_results[key] = val_acc
        logger.info(f"   ✅ {label} Val Acc: {val_acc:.4f}, Time: {time.time() - fit_start_time:.1f}s")

    logger.info("🌳 Training Random Forest...")
    fit_and_score('random_forest', 'Random Forest', RandomForestClassifier(
        n_estimators=100,  # Reduced from 200
        max_depth=15,
        random_state=RANDOM_SEED,
        n_jobs=-1
    ))

    if XGB_AVAILABLE:
        logger.info("⚡ Training XGBoost...")
        fit_and_score('xgboost', 'XGBoost', xgb.XGBClassifier(
            n_estimators=100,  # Reduced from 300
            max_depth=6,
            learning_rate=0.1,
            subsample=0.8,
            random_state=RANDOM_SEED
        ))

    if LGB_AVAILABLE:
        logger.info("💡 Training LightGBM...")
        fit_and_score('lightgbm', 'LightGBM', lgb.LGBMClassifier(
            n_estimators=100,  # Reduced from 300
            max_depth=6,
            learning_rate=0.1,
            subsample=0.8,
            # LightGBM ignores `subsample` unless bagging is enabled with a
            # non-zero subsample_freq (alias bagging_freq).
            subsample_freq=1,
            random_state=RANDOM_SEED,
            verbose=-1
        ))

    return ml_models, ml_results, scaler, X_train_features

### 6.3 Train ML Ensemble Models

In [None]:
# Train ML ensemble models
ml_models, ml_results, scaler, X_train_features = train_optimized_ml_ensemble(X_train, y_train, X_val, y_val)

# Print results
logger.info("ML Ensemble Results:")
for model_name, val_acc in ml_results.items():
    logger.info(f"   {model_name}: {val_acc:.4f}")