In [None]:
"""
Módulo unificado para entrenamiento de Redes Neuronales Bayesianas (BNN) con MFVI,
con diagnóstico visual del entrenamiento y validación temporal.

Novedades de visualización y diagnóstico:
1) Split temporal de validación dentro de main_* para monitorizar RMSE y NLL en validación.
2) Registro por época de NLL, KL por muestra, ELBO negativo (pérdida de entrenamiento: NLL + beta·KL), tasa de aprendizaje, beta efectiva y sigma_y.
3) Gráfica integral con seis paneles: NLL train y val, KL, ELBO negativo, RMSE val, LR y beta, sigma_y.
4) Gráficas de dispersión y serie temporal en validación al final del entrenamiento.

Nota: No se usa el ID de la estación como variable predictora.
"""

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from datetime import timedelta
import seaborn as sns
import joblib
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dist
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from typing import Dict, List, Tuple, Optional, Union
import warnings
from scipy.stats import zscore
import math
import argparse

# Silence every warning process-wide. NOTE(review): this also hides potentially
# useful warnings (e.g. pandas/sklearn/torch deprecations); consider narrowing it.
warnings.filterwarnings('ignore')
# Presumably a workaround so that tools which introspect module ``__path__``
# attributes do not fail on torch's dynamic ``torch.classes`` namespace; the exact
# failure it avoids is not visible here — confirm before removing.
torch.classes.__path__ = []  # stability for environments without extensions

# Lower bound applied when turning MFVILinear log-std parameters into standard deviations.
EPS = 1e-6

# ==================== COMPONENTES BNN (MFVI) ====================

class MFVILinear(nn.Module):
    """Bayesian linear layer trained with mean-field variational inference (MFVI).

    Each weight and bias has an independent Gaussian posterior N(mean, std^2).
    The std is stored in log-space (``_weight_std_param`` / ``_bias_std_param``)
    and recovered with ``exp`` followed by a floor at ``EPS``. The prior is a
    fixed zero-mean Gaussian whose per-layer std is kept in non-trainable buffers.
    Every forward pass draws a fresh weight/bias sample (reparameterization trick).
    """

    def __init__(self, dim_in, dim_out, prior_weight_std=1.0, prior_bias_std=1.0, init_std=0.05, device=None, dtype=None):
        super().__init__()
        tensor_opts = dict(device=device, dtype=dtype)
        self.dim_in, self.dim_out = dim_in, dim_out
        w_shape, b_shape = (dim_out, dim_in), (dim_out,)

        # Variational parameters; registration order fixes state_dict / optimizer order.
        self.weight_mean = nn.Parameter(torch.empty(w_shape, **tensor_opts))
        self.bias_mean = nn.Parameter(torch.empty(b_shape, **tensor_opts))
        self._weight_std_param = nn.Parameter(torch.empty(w_shape, **tensor_opts))
        self._bias_std_param = nn.Parameter(torch.empty(b_shape, **tensor_opts))

        self.reset_parameters(init_std)

        # Fixed prior N(0, prior_std^2) for every weight and bias, stored as buffers.
        prior_specs = (
            ('prior_weight_mean', self.weight_mean, 0.0),
            ('prior_weight_std', self._weight_std_param, prior_weight_std),
            ('prior_bias_mean', self.bias_mean, 0.0),
            ('prior_bias_std', self._bias_std_param, prior_bias_std),
        )
        for buffer_name, template, fill_value in prior_specs:
            self.register_buffer(buffer_name, torch.full_like(template, fill_value))

    def reset_parameters(self, init_std=0.05):
        """Re-initialize the variational parameters.

        Means: Kaiming-uniform weights (a=sqrt(5)) and uniform biases in
        [-1/sqrt(dim_in), 1/sqrt(dim_in)] (0 when dim_in is 0).
        Stds: every posterior std starts at ``init_std`` (stored as its log).
        """
        nn.init.kaiming_uniform_(self.weight_mean, a=math.sqrt(5))
        half_width = 0 if self.dim_in <= 0 else self.dim_in ** -0.5
        nn.init.uniform_(self.bias_mean, -half_width, half_width)
        log_init_std = np.log(init_std)
        with torch.no_grad():
            self._weight_std_param.fill_(log_init_std)
            self._bias_std_param.fill_(log_init_std)

    @property
    def weight_std(self):
        """Posterior std of the weights: exp(log-std), floored at EPS."""
        return self._weight_std_param.exp().clamp(min=EPS)

    @property
    def bias_std(self):
        """Posterior std of the biases: exp(log-std), floored at EPS."""
        return self._bias_std_param.exp().clamp(min=EPS)

    def kl_divergence(self):
        """Closed-form KL(q || p), summed over all weights and biases of the layer."""
        weight_term = dist.kl_divergence(
            dist.Normal(self.weight_mean, self.weight_std),
            dist.Normal(self.prior_weight_mean, self.prior_weight_std),
        ).sum()
        bias_term = dist.kl_divergence(
            dist.Normal(self.bias_mean, self.bias_std),
            dist.Normal(self.prior_bias_mean, self.prior_bias_std),
        ).sum()
        return weight_term + bias_term

    def forward(self, input):
        """Sample one weight matrix and bias vector, then apply them to ``input``."""
        sampled_weight = self._normal_sample(self.weight_mean, self.weight_std)
        sampled_bias = self._normal_sample(self.bias_mean, self.bias_std)
        return F.linear(input, sampled_weight, sampled_bias)

    def _normal_sample(self, mean, std):
        """Reparameterized draw ``mean + std * eps`` with eps ~ N(0, 1) elementwise."""
        return mean + std * torch.randn_like(std)


def make_mfvi_bnn(layer_sizes, activation='GELU', **layer_kwargs):
    """Build a feed-forward BNN as an ``nn.Sequential`` of MFVILinear layers.

    Args:
        layer_sizes: widths ``[d_in, h_1, ..., d_out]``. One MFVILinear layer is
            created per consecutive pair, with an activation between layers
            (none after the last one).
        activation: name of a ``torch.nn`` activation class (e.g. 'GELU', 'ReLU'),
            an ``nn.Module`` subclass, or an ``nn.Module`` instance. Names and
            classes are instantiated once per hidden layer, so parametric
            activations (e.g. PReLU) get independent parameters per layer. A
            module instance is reused as given.
        **layer_kwargs: forwarded to every MFVILinear (priors, init_std, device, dtype).

    Returns:
        The ``nn.Sequential`` network (empty when fewer than two sizes are given).

    Raises:
        AttributeError: if ``activation`` is a string that is not a ``torch.nn`` attribute.
    """
    sizes = list(layer_sizes)

    if isinstance(activation, str):
        # Resolved eagerly so that a misspelled name fails even for single-layer nets.
        make_activation = getattr(nn, activation)
    elif isinstance(activation, type) and issubclass(activation, nn.Module):
        make_activation = activation
    else:
        def make_activation():
            # Caller supplied a ready-made module: share it, as explicitly requested.
            return activation

    net = nn.Sequential()
    n_linear = len(sizes) - 1
    for i, (dim_in, dim_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        net.add_module(f'MFVILinear{i}', MFVILinear(dim_in, dim_out, **layer_kwargs))
        if i < n_linear - 1:
            # Module name keeps its historical spelling so printed structure stays the same.
            net.add_module(f'Nonlinarity{i}', make_activation())
    return net


def kl_divergence_model(bnn):
    """Total KL(q || p) of a network.

    Sums ``kl_divergence()`` over every submodule of ``bnn`` (including ``bnn``
    itself) that defines it. Returns the float ``0.0`` when no submodule does.
    """
    total = 0.0
    kl_terms = (sub.kl_divergence() for sub in bnn.modules() if hasattr(sub, 'kl_divergence'))
    for term in kl_terms:
        total = total + term
    return total


def gauss_loglik(y, y_pred, log_noise_var):
    """Log-density of ``y`` under an isotropic Gaussian centred at ``y_pred``.

    Computes log N(y | y_pred, sigma^2 I) with sigma^2 = exp(log_noise_var),
    summed over the last (output) dimension of size D:

        -0.5 * (D * (log sigma^2 + log 2*pi) + ||y - y_pred||^2 / sigma^2)

    The normalizing constant is counted once per output dimension, so the
    result is a proper joint log-density for any D. For D = 1 it reduces to
    the scalar Gaussian formula.

    Args:
        y, y_pred: tensors that broadcast against each other (e.g. targets
            (N, D) vs. stacked samples (K, N, D)). 0-d inputs are treated as D = 1.
        log_noise_var: observation-noise log-variance, broadcastable to the result.

    Returns:
        Tensor with the broadcast shape of ``y - y_pred`` minus its last dimension.
    """
    residual = y - y_pred
    dim = residual.shape[-1] if residual.dim() > 0 else 1
    sq_dist = residual.pow(2).sum(-1)
    return -0.5 * (dim * (log_noise_var + math.log(2 * math.pi)) + sq_dist * torch.exp(-log_noise_var))


def test_nll(y, y_pred, log_noise_var):
    """Monte-Carlo predictive negative log-likelihood, averaged over data points.

    With K posterior samples stacked along dim 0 of ``y_pred``, each point gets
        -log( (1/K) * sum_k N(y | y_pred_k, exp(log_noise_var)) ),
    computed stably with log-sum-exp over the sample axis. The mean over points
    is returned.

    NOTE: the ``test_`` prefix means pytest will try to collect this function if
    it is imported into a test module.
    """
    sample_loglik = gauss_loglik(y, y_pred, log_noise_var)  # (K, N)
    n_samples = sample_loglik.shape[0]
    log_predictive = torch.logsumexp(sample_loglik, dim=0) - math.log(n_samples)
    return (-log_predictive).mean()

# ==================== DATASET PYTORCH ====================

class BayesianDataset(Dataset):
    """In-memory regression dataset of float32 ``(features, target)`` pairs.

    ``features`` may be a DataFrame or any 2-D array-like. ``target`` may be a
    Series or any array-like with one value per row. Targets are stored as an
    ``(N, 1)`` column so they line up with single-output model predictions.

    Raises:
        ValueError: if features and target have different numbers of rows.
    """

    def __init__(self, features, target):
        # np.asarray gives the same array as `.values` for pandas objects and also
        # accepts plain arrays. torch.tensor copies, so later edits to the source
        # frame do not leak into the dataset.
        self.features = torch.tensor(np.asarray(features), dtype=torch.float32)
        self.target = torch.tensor(np.asarray(target), dtype=torch.float32).reshape(-1, 1)
        if len(self.features) != len(self.target):
            raise ValueError(
                f"features tiene {len(self.features)} filas pero target tiene {len(self.target)}"
            )

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.target[idx]

# ==================== ENTRENADOR UNIFICADO ====================

class BNNUnifiedTrainer:
    """
    Unified trainer for MFVI Bayesian neural networks (BNNs).

    Provides one interface for per-sensor ("individual") and multi-sensor
    ("global") models: data loading, outlier filtering, temporal train/test
    splitting, standardization, training with optional validation monitoring,
    Monte-Carlo prediction, evaluation and model (de)serialization.
    """

    # Processed NO2 dataset read by load_data() when no explicit path is given.
    DEFAULT_DATA_PATH = '../data/super_processed/7_4_no2_with_traffic_and_1meteo_and_1trafic_id.parquet'

    def __init__(self):
        self.df_master = None
        self.model = None
        self.scaler_dict = {}
        self.scaler_target = None
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        # Fixed seeds so weight init, mini-batch order and MC sampling repeat across runs.
        torch.manual_seed(42)
        np.random.seed(42)
        if 'cuda' in self.device:
            torch.cuda.manual_seed_all(42)
        print(f"Using device: {self.device}")

    # ------------------------------------------------------------------ data

    def load_data(self, path: Optional[str] = None) -> pd.DataFrame:
        """Load the master dataset from parquet and parse ``fecha`` as datetime.

        Args:
            path: parquet file to read; defaults to ``DEFAULT_DATA_PATH``.

        Raises:
            Re-raises any reading/parsing error after printing a message.
        """
        if path is None:
            path = self.DEFAULT_DATA_PATH
        try:
            df = pd.read_parquet(path)
            df['fecha'] = pd.to_datetime(df['fecha'])
            return df
        except Exception as e:
            print(f"Error al cargar los datos: {str(e)}")
            raise  # bare raise keeps the original traceback

    def get_available_sensors(self) -> List[str]:
        """Return the sorted unique ``id_no2`` values, loading the data if needed."""
        if self.df_master is None:
            self.df_master = self.load_data()
        return sorted(self.df_master['id_no2'].unique().tolist())

    def remove_outliers(self, df: pd.DataFrame, method: str, columns: List[str] = None) -> pd.DataFrame:
        """Filter outlier rows, one column at a time.

        Args:
            df: input frame (not modified).
            method: one of
                'none'   -- return ``df`` itself;
                'iqr'    -- keep values within [Q1 - 1.5*IQR, Q3 + 1.5*IQR]
                            (skipped for a column whose IQR is 0);
                'zscore' -- keep rows with |z| < 3 (rows whose value is NaN are dropped).
                Any other value prints a warning and returns an unfiltered copy.
            columns: columns to filter on; defaults to ``['no2_value']``. Missing
                columns are skipped. Filters are applied sequentially.
        """
        if method == 'none':
            return df
        if method not in ('iqr', 'zscore'):
            print(f"Método de outliers desconocido: '{method}'; no se filtran datos")
            return df.copy()
        if columns is None:
            columns = ['no2_value']
        df_filtered = df.copy()
        for col in columns:
            if col not in df_filtered.columns:
                continue
            if method == 'iqr':
                Q1, Q3 = df_filtered[col].quantile(0.25), df_filtered[col].quantile(0.75)
                IQR = Q3 - Q1
                if IQR > 0:
                    lower, upper = Q1 - 1.5 * IQR, Q3 + 1.5 * IQR
                    df_filtered = df_filtered[(df_filtered[col] >= lower) & (df_filtered[col] <= upper)]
            else:  # 'zscore'
                df_filtered = df_filtered[np.abs(zscore(df_filtered[col], nan_policy='omit')) < 3]
        return df_filtered

    def split_data(self, df: pd.DataFrame, split_date: Union[str, pd.Timestamp]) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Temporal split into copies: rows before ``split_date`` and rows at/after it.

        ``split_date`` may be a Timestamp or anything ``pd.Timestamp`` accepts
        (e.g. '2024-01-01'). Rows with a missing ``fecha`` land in neither part.
        """
        split_ts = pd.Timestamp(split_date)
        train = df[df['fecha'] < split_ts].copy()
        test = df[df['fecha'] >= split_ts].copy()
        return train, test

    @staticmethod
    def _drop_incomplete_rows(df: pd.DataFrame, features: List[str]) -> pd.DataFrame:
        """Drop rows missing the target or any selected feature (NaNs would poison training)."""
        required = [c for c in list(features) + ['no2_value'] if c in df.columns]
        return df.dropna(subset=required)

    def _sensor_mask(self, sensor_ids) -> pd.Series:
        """Boolean mask over ``df_master`` rows whose ``id_no2`` (as string) is in ``sensor_ids``."""
        wanted = {str(s) for s in sensor_ids}
        return self.df_master['id_no2'].astype(str).isin(wanted)

    def prepare_individual_data(self, config: Dict, selected_features: List[str]) -> Optional[Dict]:
        """Build the train/test frames for a single-sensor model.

        Uses the rows of ``config['sensor_id']`` (compared as strings). Rows
        missing the target or a selected feature are dropped. The data are split
        temporally at ``config['split_date']``, and outliers
        (``config['outlier_method']``, default 'none') are removed from the
        training part only, so the test set stays untouched.

        Returns:
            dict with 'train_df', 'test_df' and 'sensor_id' (both frames sorted
            by ``fecha``), or None when either frame is empty.
        """
        if self.df_master is None:
            self.df_master = self.load_data()
        sensor_df = self._drop_incomplete_rows(
            self.df_master[self._sensor_mask([config['sensor_id']])], selected_features
        )
        train_df, test_df = self.split_data(sensor_df, config['split_date'])
        train_df = self.remove_outliers(train_df, config.get('outlier_method', 'none'))
        if train_df.empty or test_df.empty:
            print(f"Datos insuficientes para el sensor {config['sensor_id']}: "
                  f"train={len(train_df)}, test={len(test_df)}")
            return None
        train_df = train_df.sort_values('fecha', kind='mergesort')
        test_df = test_df.sort_values('fecha', kind='mergesort')
        print(f"Datos preparados: train={len(train_df)}, test={len(test_df)}")
        return {'train_df': train_df, 'test_df': test_df, 'sensor_id': config['sensor_id']}

    def prepare_global_data(self, config: Dict, selected_features: List[str]) -> Optional[Dict]:
        """Build the train/test frames for a global multi-sensor model.

        Training uses the ``config['sensors_train']`` stations before
        ``config['split_date']``. Testing uses the ``config['sensors_test']``
        stations from ``split_date`` onwards, so the test set is both spatially
        and temporally out-of-sample. Station IDs are compared as strings. Rows
        missing the target or a selected feature are dropped, and outliers
        (``config['outlier_method']``, default 'none') are removed from the
        training frame only.

        Returns:
            dict with 'train_df', 'test_df' (sorted by ``fecha``), 'sensors_train'
            and 'sensors_test', or None when either frame is empty.
        """
        if self.df_master is None:
            self.df_master = self.load_data()
        split_ts = pd.Timestamp(config['split_date'])
        train_pool = self._drop_incomplete_rows(
            self.df_master[self._sensor_mask(config['sensors_train'])], selected_features
        )
        test_pool = self._drop_incomplete_rows(
            self.df_master[self._sensor_mask(config['sensors_test'])], selected_features
        )
        train_df, _ = self.split_data(train_pool, split_ts)
        _, test_df = self.split_data(test_pool, split_ts)
        train_df = self.remove_outliers(train_df, config.get('outlier_method', 'none'))
        if train_df.empty:
            print("No hay datos de entrenamiento para los sensores indicados antes de la fecha de división")
            return None
        if test_df.empty:
            print("No hay datos de test para los sensores indicados desde la fecha de división")
            return None
        train_df = train_df.sort_values('fecha', kind='mergesort')
        test_df = test_df.sort_values('fecha', kind='mergesort')
        print(f"Datos preparados: train={len(train_df)}, test={len(test_df)}")
        return {
            'train_df': train_df,
            'test_df': test_df,
            'sensors_train': list(config['sensors_train']),
            'sensors_test': list(config['sensors_test']),
        }

    # --------------------------------------------------------------- scaling

    def scale_features(self, X_train: pd.DataFrame, X_other: pd.DataFrame, features: List[str]) -> Tuple[pd.DataFrame, pd.DataFrame, Dict]:
        """Fit one StandardScaler per numeric feature on ``X_train`` and apply it to both frames.

        Returns:
            (scaled X_train copy, scaled X_other copy, {feature: fitted scaler}).
        """
        scaler_dict = {}
        X_train_s, X_other_s = X_train.copy(), X_other.copy()
        for feature in features:
            if feature in X_train.columns and pd.api.types.is_numeric_dtype(X_train[feature]):
                scaler = StandardScaler()
                X_train_s[feature] = scaler.fit_transform(X_train[[feature]]).flatten()
                X_other_s[feature] = scaler.transform(X_other[[feature]]).flatten()
                scaler_dict[feature] = scaler
        return X_train_s, X_other_s, scaler_dict

    def apply_feature_scalers(self, X: pd.DataFrame, scaler_dict: Dict, features: List[str]) -> pd.DataFrame:
        """Apply already-fitted per-feature scalers to a copy of ``X``."""
        Xs = X.copy()
        for feature in features:
            if feature in scaler_dict and feature in X.columns:
                Xs[feature] = scaler_dict[feature].transform(X[[feature]]).flatten()
        return Xs

    def scale_target(self, y_train: pd.Series) -> Tuple[pd.Series, "StandardScaler"]:
        """Standardize the target; returns (scaled Series with the same index/name, fitted scaler)."""
        scaler = StandardScaler()
        y_scaled = scaler.fit_transform(y_train.values.reshape(-1, 1)).flatten()
        return pd.Series(y_scaled, index=y_train.index, name=y_train.name), scaler

    # -------------------------------------------------------------- training

    def train_bnn_model(self, X_train, y_train, train_config,
                        X_val=None, y_val=None, y_val_scaled=None, scaler_target=None):
        """Train an MFVI BNN by minimizing the mini-batch negative beta-ELBO.

        Args:
            X_train: scaled training features (DataFrame).
            y_train: scaled training target (Series).
            train_config: dict with 'hidden_dims', 'activation', 'learning_rate',
                'n_epochs', and optional 'batch_size' (512), 'beta' (1.0),
                'K_eval' (30) and 'eval_every' (1).
            X_val, y_val, y_val_scaled, scaler_target: optional validation data.
                ``y_val`` is in original units (for RMSE) and ``y_val_scaled`` in
                scaled units (for NLL). When ``X_val`` is given, the other three
                are required.

        Returns:
            (model, log_noise_var, logs), where ``logs`` holds one metrics dict per epoch.

        Raises:
            ValueError: if ``X_val`` is given without y_val, y_val_scaled or scaler_target.
        """
        has_val = X_val is not None
        if has_val and (y_val is None or y_val_scaled is None or scaler_target is None):
            raise ValueError("X_val requiere y_val, y_val_scaled y scaler_target")

        dataset = BayesianDataset(X_train, y_train)
        # Cap the batch size so that small datasets still get several updates per epoch.
        auto_bs = min(train_config.get('batch_size', 512), max(64, len(dataset) // 4))
        dataloader = DataLoader(dataset, batch_size=auto_bs, shuffle=True)

        layer_sizes = [X_train.shape[1]] + list(train_config['hidden_dims']) + [1]
        model = make_mfvi_bnn(
            layer_sizes,
            activation=train_config['activation'],
            prior_weight_std=1.0,
            prior_bias_std=1.0,
            init_std=0.05,
            device=self.device
        ).to(self.device)

        # Homoscedastic observation noise learned in log-variance space (sigma_y = exp(-1.5) at start).
        log_noise_var = nn.Parameter(torch.ones(1, device=self.device) * -3.0)
        params = list(model.parameters()) + [log_noise_var]
        optimizer = torch.optim.Adam(params, lr=train_config['learning_rate'])

        T = train_config['n_epochs']
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[int(0.6 * T), int(0.85 * T)], gamma=0.5
        )

        # Linear KL warm-up: beta ramps from beta/warmup up to beta over the first 60% of epochs.
        warmup = max(1, int(0.6 * T))
        beta_max = train_config.get('beta', 1.0)
        eval_every = max(1, train_config.get('eval_every', 1))
        K_eval = train_config.get('K_eval', 30)
        N_data = len(dataset)

        if has_val:
            X_val_tensor = torch.tensor(X_val.values, dtype=torch.float32).to(self.device)
            y_val_tensor = torch.tensor(y_val_scaled.values, dtype=torch.float32).reshape(-1, 1).to(self.device)

        logs = []  # one dict with all metrics per epoch

        print("Inicio del entrenamiento BNN con registro de métricas por época")
        for epoch in range(T):
            model.train()
            epoch_nll = 0.0
            epoch_kl_weighted = 0.0
            beta_t = beta_max * min(1.0, (epoch + 1) / warmup)

            for x_batch, y_batch in dataloader:
                x_batch, y_batch = x_batch.to(self.device), y_batch.to(self.device)
                batch_size = len(x_batch)
                optimizer.zero_grad()

                y_pred = model(x_batch)
                nll = -gauss_loglik(y_batch, y_pred, log_noise_var).mean()
                kl = kl_divergence_model(model)

                # Each batch carries its share B/N of the KL, so one epoch adds up to beta*KL.
                loss = nll + beta_t * kl * (batch_size / N_data)
                loss.backward()
                # NOTE: clipping covers the network only, not log_noise_var.
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

                epoch_nll += nll.item() * batch_size
                epoch_kl_weighted += kl.item() * batch_size

            scheduler.step()

            # Periodic validation metrics.
            val_rmse = None
            val_nll = None
            if has_val and epoch % eval_every == 0:
                model.eval()
                with torch.no_grad():
                    y_preds_val = torch.stack([model(X_val_tensor) for _ in range(K_eval)], dim=0)
                    # Validation NLL in scaled space.
                    val_nll = test_nll(y_val_tensor, y_preds_val, log_noise_var).item()
                    # Predictive mean and RMSE in original units.
                    pred_mean_val_scaled = y_preds_val.mean(0).detach().cpu().numpy()
                    pred_mean_val = scaler_target.inverse_transform(pred_mean_val_scaled).flatten()
                    val_rmse = float(np.sqrt(mean_squared_error(y_val.values, pred_mean_val)))

            current_lr = float(optimizer.param_groups[0]['lr'])
            sigma_y = float(torch.exp(0.5 * log_noise_var).item())
            nll_per_sample = epoch_nll / N_data
            # The KL belongs to the whole posterior, not to a batch. Average it over the
            # epoch (weighted by batch size) instead of summing it once per batch, which
            # inflated it by the number of batches. Then spread it over the N points.
            kl_per_sample = (epoch_kl_weighted / N_data) / N_data
            # Negative beta-ELBO per sample (the average training loss); logged as 'train_elbo'.
            elbo_per_sample = nll_per_sample + beta_t * kl_per_sample

            logs.append({
                'epoch': epoch + 1,
                'train_nll': nll_per_sample,
                'train_kl_per_sample': kl_per_sample,
                'train_elbo': elbo_per_sample,
                'val_nll': val_nll,
                'val_rmse': val_rmse,
                'lr': current_lr,
                'beta_t': beta_t,
                'sigma_y': sigma_y,
            })

            if (epoch + 1) % 10 == 0 or epoch == 0:
                msg = f"Época {epoch+1}/{T}  NLL:{nll_per_sample:.3f}  KL/N:{kl_per_sample:.3f}  ELBO:{elbo_per_sample:.3f}"
                if val_rmse is not None:
                    msg += f"  Val RMSE:{val_rmse:.2f}  Val NLL:{val_nll:.3f}"
                msg += f"  LR:{current_lr:.4g}  beta:{beta_t:.3f}  sigma_y:{sigma_y:.3f}"
                print(msg)

        print("Entrenamiento completado")
        return model, log_noise_var, logs

    # ------------------------------------------------------------ prediction

    def predict(self, model, X_test, K, log_noise_var):
        """Draw ``K`` forward passes and summarize them, all in scaled target space.

        Returns:
            dict with 'y_preds_all' (K, N, 1), 'mean', 'epistemic_std' (std across
            samples; zero when K == 1), 'aleatoric_std' (exp(0.5 * log_noise_var))
            and 'total_std' (both combined in quadrature).
        """
        model.eval()
        X_test_tensor = torch.tensor(X_test.values, dtype=torch.float32).to(self.device)
        with torch.no_grad():
            y_preds = torch.stack([model(X_test_tensor) for _ in range(K)], dim=0)
        pred_mean = y_preds.mean(0)
        # With a single sample the unbiased variance is undefined (NaN); report 0 instead.
        epistemic_uncertainty = y_preds.var(0).sqrt() if K > 1 else torch.zeros_like(pred_mean)
        aleatoric_uncertainty = torch.exp(0.5 * log_noise_var).expand_as(pred_mean)
        total_uncertainty = (epistemic_uncertainty**2 + aleatoric_uncertainty**2).sqrt()
        return {
            'y_preds_all': y_preds,
            'mean': pred_mean,
            'epistemic_std': epistemic_uncertainty,
            'aleatoric_std': aleatoric_uncertainty,
            'total_std': total_uncertainty
        }

    def evaluate_model(self, predictions, y_test, y_test_scaled, scaler_target, log_noise_var, sensor_id: str = None):
        """Compute test metrics in original units and save the per-row predictions to CSV.

        Side effect: writes ``../predictions/bnn_predictions_with_epistemic_uncertainty_<sensor>.csv``
        ('_global' suffix when ``sensor_id`` is falsy).

        Returns:
            dict with 'rmse', 'r2', 'mae', 'test_nll' (scaled space), 'y_pred',
            'y_pred_std' (total std in original units) and 'predictions_df'.
        """
        pred_mean_scaled = predictions['mean'].detach().cpu().numpy()
        pred_mean = scaler_target.inverse_transform(pred_mean_scaled).flatten()

        # Standard deviations only rescale (no shift) when mapping back to original units.
        total_std_scaled = predictions['total_std'].detach().cpu().numpy().flatten()
        unscaled_std = total_std_scaled * scaler_target.scale_[0]

        epistemic_std = predictions['epistemic_std'].detach().cpu().numpy().flatten()
        epistemic_unscaled_std = epistemic_std * scaler_target.scale_[0]

        df_preds = pd.DataFrame({
            'prediction': pred_mean,
            'epistemic_uncertainty': epistemic_unscaled_std
        })
        sensor_suffix = f"_{sensor_id}" if sensor_id else "_global"
        filename = f'../predictions/bnn_predictions_with_epistemic_uncertainty{sensor_suffix}.csv'
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        df_preds.to_csv(filename, index=False)
        print(f"Predicciones guardadas en: {filename}")

        rmse = np.sqrt(mean_squared_error(y_test, pred_mean))
        r2 = r2_score(y_test, pred_mean)
        mae = mean_absolute_error(y_test, pred_mean)

        y_test_tensor = torch.tensor(y_test_scaled.values, dtype=torch.float32).reshape(-1, 1).to(self.device)
        nll = test_nll(y_test_tensor, predictions['y_preds_all'], log_noise_var).item()

        return {
            'rmse': rmse, 'r2': r2, 'mae': mae, 'test_nll': nll,
            'y_pred': pred_mean, 'y_pred_std': unscaled_std,
            'predictions_df': df_preds
        }

    def evaluate_global_by_sensor(self, data_prep: Dict, model, log_noise_var, scaler_dict,
                                  scaler_target, selected_features, K_predict: int) -> Dict:
        """Evaluate a global model separately on each test sensor.

        Sensor IDs are compared as strings, so integer or string ``id_no2`` columns both match.

        Returns:
            {sensor_id: {'metrics', 'test_df', 'n_samples'}} for sensors that have test data.
        """
        results_by_sensor = {}
        test_df = data_prep['test_df']
        test_ids = test_df['id_no2'].astype(str)
        print("Evaluación global por sensor")

        for sensor_id in data_prep['sensors_test']:
            print(f"Evaluando sensor: {sensor_id}")
            sensor_test_df = test_df[test_ids == str(sensor_id)].copy()
            if len(sensor_test_df) == 0:
                print(f"No hay datos de test para sensor {sensor_id}")
                continue

            X_test_sensor = sensor_test_df[selected_features].copy()
            y_test_sensor = sensor_test_df['no2_value'].copy()
            X_test_sensor_scaled = self.apply_feature_scalers(X_test_sensor, scaler_dict, selected_features)

            y_test_sensor_scaled = pd.Series(
                scaler_target.transform(y_test_sensor.values.reshape(-1, 1)).flatten(),
                index=y_test_sensor.index,
                name=y_test_sensor.name
            )
            predictions = self.predict(model, X_test_sensor_scaled, K_predict, log_noise_var)
            metrics = self.evaluate_model(predictions, y_test_sensor, y_test_sensor_scaled,
                                          scaler_target, log_noise_var, sensor_id)
            results_by_sensor[sensor_id] = {
                'metrics': metrics,
                'test_df': sensor_test_df,
                'n_samples': len(sensor_test_df)
            }
            print(f"RMSE: {metrics['rmse']:.2f}, R²: {metrics['r2']:.3f}, MAE: {metrics['mae']:.2f}")
        return results_by_sensor

    # ------------------------------------------------------- persistence

    def save_model(self, path: str, model, log_noise_var, scaler_dict, scaler_target,
                   feature_names: List[str], model_config: Dict = None):
        """Persist weights, noise parameter, scalers, feature names and config with joblib."""
        out_dir = os.path.dirname(path)
        if out_dir:  # os.makedirs('') raises for bare filenames
            os.makedirs(out_dir, exist_ok=True)
        model_state = {
            'model_state_dict': model.state_dict(),
            'log_noise_var': log_noise_var,
            'scaler_dict': scaler_dict,
            'scaler_target': scaler_target,
            'feature_names': feature_names,
            'model_config': model_config or {}
        }
        joblib.dump(model_state, path)
        print(f"Modelo guardado en: {path}")

    def load_model(self, path: str, layer_sizes: List[int], activation: str = 'GELU'):
        """Rebuild a model saved by ``save_model``.

        ``layer_sizes`` and ``activation`` must match the saved architecture.

        Returns:
            (model, log_noise_var, scaler_dict, scaler_target, feature_names),
            or None if ``path`` does not exist.
        """
        if not os.path.exists(path):
            print(f"Modelo no encontrado en: {path}")
            return None
        # NOTE(security): joblib.load unpickles arbitrary objects; only load trusted files.
        model_state = joblib.load(path)
        model = make_mfvi_bnn(layer_sizes, activation=activation, device=self.device).to(self.device)
        model.load_state_dict(model_state['model_state_dict'])
        log_noise_var = model_state['log_noise_var'].to(self.device)
        return (model, log_noise_var, model_state['scaler_dict'],
                model_state['scaler_target'], model_state['feature_names'])

# ==================== VISUALIZACIÓN Y REPORTE ====================

def print_model_metrics(metrics: Dict, title: str = "Métricas de Evaluación"):
    """Print RMSE, R², MAE (original units) and test NLL from an evaluation dict.

    ``metrics`` must provide the keys 'rmse', 'r2', 'mae' and 'test_nll'.
    """
    print(f"\n{title}")
    report_rows = (
        ("RMSE", 'rmse', '.2f', " µg/m³"),
        ("R²", 'r2', '.3f', ""),
        ("MAE", 'mae', '.2f', " µg/m³"),
        ("Test NLL", 'test_nll', '.3f', ""),
    )
    for label, key, spec, unit in report_rows:
        print(f"{label}: {metrics[key]:{spec}}{unit}")

def print_global_summary(results_by_sensor: Dict):
    """Print cross-sensor mean ± std (population) of RMSE, R² and MAE, then one line per sensor.

    Args:
        results_by_sensor: mapping sensor id -> {'metrics': {'rmse', 'r2', 'mae', ...},
            'n_samples': int}. Sensors are listed in sorted id order. When the mapping
            is empty, a notice is printed instead of NaN aggregates.
    """
    print("\nResumen Modelo Global por Sensor")
    print("="*60)
    if not results_by_sensor:
        # np.mean/np.std of an empty list would silently print "nan ± nan".
        print("No hay resultados por sensor")
        return

    per_sensor_metrics = [entry['metrics'] for entry in results_by_sensor.values()]
    for label, key, spec, unit in (
        ("RMSE", 'rmse', '.2f', " µg/m³"),
        ("R²", 'r2', '.3f', ""),
        ("MAE", 'mae', '.2f', " µg/m³"),
    ):
        values = [m[key] for m in per_sensor_metrics]
        print(f"{label} promedio: {np.mean(values):{spec}} ± {np.std(values):{spec}}{unit}")

    for sensor_id in sorted(results_by_sensor.keys()):
        m = results_by_sensor[sensor_id]['metrics']
        print(f"{sensor_id}: RMSE={m['rmse']:.2f}, R²={m['r2']:.3f}, MAE={m['mae']:.2f}, n={results_by_sensor[sensor_id]['n_samples']}")

def save_training_loss_plot(logs: List[Dict], filename: str):
    """Save a 2x3 diagnostic figure built from the per-epoch training logs.

    Panels:
    - train/val NLL;
    - KL per sample;
    - negative ELBO per sample (the training loss, key 'train_elbo');
    - validation RMSE;
    - learning rate with beta on a twin axis;
    - sigma_y.

    Validation series are drawn only at the epochs where they were evaluated.
    Missing (None/NaN) entries are dropped first; otherwise, with
    ``eval_every > 1``, the isolated points would be separated by NaN gaps and
    no line would be drawn.

    Args:
        logs: per-epoch dicts as produced by ``BNNUnifiedTrainer.train_bnn_model``.
        filename: output image path; its parent directory is created if needed.
    """
    out_dir = os.path.dirname(filename)
    if out_dir:  # os.makedirs('') raises for bare filenames
        os.makedirs(out_dir, exist_ok=True)
    log_df = pd.DataFrame(logs)
    if log_df.empty:
        print(f"Sin registros de entrenamiento; no se genera {filename}")
        return
    epochs = log_df['epoch']

    def evaluated(col):
        """(epochs, values) restricted to epochs where ``col`` was recorded, or None if never."""
        if col not in log_df:
            return None
        values = pd.to_numeric(log_df[col], errors='coerce')
        mask = values.notna()
        if not mask.any():
            return None
        return epochs[mask], values[mask]

    fig, axes = plt.subplots(2, 3, figsize=(16, 8))

    ax = axes[0, 0]
    ax.plot(epochs, log_df['train_nll'], label='NLL entrenamiento')
    val_nll = evaluated('val_nll')
    if val_nll is not None:
        ax.plot(*val_nll, label='NLL validación')
    ax.set_title("NLL por época")
    ax.set_xlabel("Época")
    ax.legend()

    ax = axes[0, 1]
    ax.plot(epochs, log_df['train_kl_per_sample'])
    ax.set_title("KL por muestra")
    ax.set_xlabel("Época")

    ax = axes[0, 2]
    ax.plot(epochs, log_df['train_elbo'])
    ax.set_title("ELBO negativo por muestra")
    ax.set_xlabel("Época")

    ax = axes[1, 0]
    val_rmse = evaluated('val_rmse')
    if val_rmse is not None:
        ax.plot(*val_rmse)
        ax.set_title("RMSE validación")
        ax.set_xlabel("Época")
        ax.set_ylabel("µg/m³")
    else:
        ax.axis('off')

    ax = axes[1, 1]
    ax.plot(epochs, log_df['lr'], label='LR')
    ax2 = ax.twinx()
    ax2.plot(epochs, log_df['beta_t'], color='tab:orange', label='beta')
    ax.set_title("LR y beta")
    ax.set_xlabel("Época")
    ax.legend(loc='upper left')
    ax2.legend(loc='upper right')

    ax = axes[1, 2]
    ax.plot(epochs, log_df['sigma_y'])
    ax.set_title("Desviación del ruido σ_y")
    ax.set_xlabel("Época")

    plt.tight_layout()
    plt.savefig(filename)
    plt.close(fig)
    print(f"Curvas de entrenamiento guardadas en {filename}")

def save_validation_plots(y_val: pd.Series, pred_val: np.ndarray, fechas: Optional[pd.Series], out_prefix: str):
    """Save auxiliary validation plots: predicted-vs-observed scatter and, optionally, a time series.

    Writes ``<out_prefix>_scatter.png`` and, when ``fechas`` is given,
    ``<out_prefix>_series.png``. The time series is sorted by timestamp so
    unsorted rows do not zigzag. Rows from several sensors sharing timestamps
    are still drawn as a single line.

    Args:
        y_val: observed values (original units).
        pred_val: predicted means aligned with ``y_val``.
        fechas: timestamps aligned with ``y_val``, or None to skip the time series.
        out_prefix: path prefix for the output images; its directory is created if needed.
    """
    out_dir = os.path.dirname(out_prefix)
    if out_dir:  # os.makedirs('') raises for bare prefixes
        os.makedirs(out_dir, exist_ok=True)
    observed = np.asarray(y_val, dtype=float).ravel()
    predicted = np.asarray(pred_val, dtype=float).ravel()
    if observed.size == 0:
        print(f"Sin datos de validación; no se generan gráficas para {out_prefix}")
        return

    # Scatter with the y = x reference line (NaN-safe bounds).
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.scatter(observed, predicted, s=10, alpha=0.5)
    minv = min(np.nanmin(observed), np.nanmin(predicted))
    maxv = max(np.nanmax(observed), np.nanmax(predicted))
    ax.plot([minv, maxv], [minv, maxv])
    ax.set_xlabel("Observado (µg/m³)")
    ax.set_ylabel("Predicho (µg/m³)")
    ax.set_title("Validación: predicho vs observado")
    plt.tight_layout()
    plt.savefig(out_prefix + "_scatter.png")
    plt.close(fig)

    # Time series (only when timestamps are available), in chronological order.
    if fechas is not None:
        times = np.asarray(getattr(fechas, 'values', fechas))
        order = np.argsort(times, kind='stable')
        fig, ax = plt.subplots(figsize=(10, 4))
        ax.plot(times[order], observed[order], label='observado')
        ax.plot(times[order], predicted[order], label='predicho')
        ax.set_title("Validación: serie temporal")
        ax.set_xlabel("Tiempo")
        ax.set_ylabel("µg/m³")
        ax.legend()
        plt.tight_layout()
        plt.savefig(out_prefix + "_series.png")
        plt.close(fig)

# ==================== FUNCIONES PRINCIPALES ====================

def _select_features(sample_data: pd.DataFrame, requested: List[str]) -> List[str]:
    """Resolve the model inputs: numeric columns other than date/id/target, optionally filtered.

    ``requested == ['all']`` selects every candidate column. Otherwise requested
    names are kept in the requested order, and any that are missing or
    non-numeric are reported and skipped.
    """
    excluded = {'fecha', 'id_no2', 'no2_value'}
    available = [c for c in sample_data.columns
                 if c not in excluded and pd.api.types.is_numeric_dtype(sample_data[c])]
    if requested == ['all']:
        return available
    selected = [f for f in requested if f in available]
    missing = [f for f in requested if f not in available]
    if missing:
        print(f"Aviso: características no disponibles o no numéricas, se ignoran: {missing}")
    return selected


def _build_scaled_splits(trainer, data_prep: Dict, selected_features: List[str], val_frac: float) -> Dict:
    """Carve a temporal validation split out of ``data_prep['train_df']`` and scale all splits.

    The last ``val_frac`` of timestamps of the training frame become validation.
    Feature and target scalers are fitted on the remaining core-training rows
    only, so neither validation nor test data leak into the normalization.

    Raises:
        ValueError: if ``val_frac`` is outside (0, 1) or a split ends up empty.
    """
    if not 0.0 < val_frac < 1.0:
        raise ValueError(f"val_fraction debe estar en (0, 1), recibido {val_frac}")
    train_df = data_prep['train_df']
    test_df = data_prep['test_df']
    split_val_date = train_df['fecha'].quantile(1 - val_frac)
    core_train_df = train_df[train_df['fecha'] < split_val_date].copy()
    val_df = train_df[train_df['fecha'] >= split_val_date].copy()
    if core_train_df.empty or val_df.empty:
        raise ValueError("El split temporal de validación dejó un conjunto vacío; "
                         "revisa val_fraction y el rango de fechas")

    X_train, y_train = core_train_df[selected_features], core_train_df['no2_value']
    X_val, y_val = val_df[selected_features], val_df['no2_value']
    X_test, y_test = test_df[selected_features], test_df['no2_value']

    X_train_s, X_val_s, scaler_dict = trainer.scale_features(X_train, X_val, selected_features)
    X_test_s = trainer.apply_feature_scalers(X_test, scaler_dict, selected_features)
    y_train_s, scaler_target = trainer.scale_target(y_train)

    def scale_y(y: pd.Series) -> pd.Series:
        """Apply the core-train target scaler, keeping index and name."""
        return pd.Series(scaler_target.transform(y.values.reshape(-1, 1)).flatten(),
                         index=y.index, name=y.name)

    return {
        'val_df': val_df,
        'X_train_s': X_train_s, 'y_train_s': y_train_s,
        'X_val_s': X_val_s, 'y_val': y_val, 'y_val_s': scale_y(y_val),
        'X_test_s': X_test_s, 'y_test': y_test, 'y_test_s': scale_y(y_test),
        'scaler_dict': scaler_dict, 'scaler_target': scaler_target,
    }


def _train_with_validation(trainer, splits: Dict, train_config: Dict, diagnostics_path: str):
    """Train on the core split while monitoring validation, then save the diagnostics figure."""
    print(f"Datos: train_core={len(splits['X_train_s'])}, val={len(splits['X_val_s'])}, test={len(splits['X_test_s'])}")
    model, log_noise_var, logs = trainer.train_bnn_model(
        splits['X_train_s'], splits['y_train_s'], train_config,
        X_val=splits['X_val_s'], y_val=splits['y_val'],
        y_val_scaled=splits['y_val_s'], scaler_target=splits['scaler_target'],
    )
    save_training_loss_plot(logs, diagnostics_path)
    return model, log_noise_var


def _save_validation_outputs(trainer, model, splits: Dict, out_prefix: str, n_samples: int = 50):
    """Predict the validation mean from ``n_samples`` MC passes and save the validation plots."""
    model.eval()
    X_val_tensor = torch.tensor(splits['X_val_s'].values, dtype=torch.float32).to(trainer.device)
    with torch.no_grad():
        samples = torch.stack([model(X_val_tensor) for _ in range(n_samples)], dim=0)
    pred_val_mean = splits['scaler_target'].inverse_transform(samples.mean(0).cpu().numpy()).flatten()
    save_validation_plots(splits['y_val'], pred_val_mean, splits['val_df']['fecha'], out_prefix)


def main_individual(config: Dict):
    """Train, evaluate and save a BNN for one sensor (``config['sensor_id']``).

    Returns:
        (trainer, model, metrics), or None when data or features are missing.
    """
    print("Entrenamiento BNN Individual")
    trainer = BNNUnifiedTrainer()
    print("\nCargando datos")
    trainer.df_master = trainer.load_data()
    if trainer.df_master.empty:
        return

    print("\nConfiguración Individual")
    print(f"Sensor: {config['sensor_id']}")
    print(f"Fecha división: {config['split_date']}")
    print(f"Outliers: {config['outlier_method']}")

    # IDs compared as strings so int and str id_no2 columns both match.
    sensor_mask = trainer.df_master['id_no2'].astype(str) == str(config['sensor_id'])
    sample_data = trainer.df_master[sensor_mask].head(100).copy()
    if sample_data.empty:
        print(f"No se encontraron datos para el sensor {config['sensor_id']}")
        return

    selected_features = _select_features(sample_data, config['features'])
    print(f"\nSe usan {len(selected_features)} características")
    print("Características seleccionadas:", selected_features)
    if not selected_features:
        print("No hay características válidas para entrenar")
        return

    data_prep = trainer.prepare_individual_data(config, selected_features)
    if not data_prep:
        return

    splits = _build_scaled_splits(trainer, data_prep, selected_features, config.get('val_fraction', 0.1))
    model, log_noise_var = _train_with_validation(
        trainer, splits, config['train_config'],
        f"../models/bnn_training_diagnostics_{config['sensor_id']}.png",
    )

    # Final evaluation on test.
    print("\nEvaluación en test")
    predictions = trainer.predict(model, splits['X_test_s'], config['K_predict'], log_noise_var)
    metrics = trainer.evaluate_model(predictions, splits['y_test'], splits['y_test_s'],
                                     splits['scaler_target'], log_noise_var, config['sensor_id'])
    print_model_metrics(metrics)

    _save_validation_outputs(trainer, model, splits, f"../models/bnn_val_{config['sensor_id']}")

    model_path = f"../models/bnn_model_{config['sensor_id']}.pkl"
    trainer.save_model(model_path, model, log_noise_var, splits['scaler_dict'],
                       splits['scaler_target'], selected_features, config)

    print("\nProceso individual completado")
    return trainer, model, metrics


def main_global(config: Dict):
    """Train a BNN on several sensors and evaluate it globally and per test sensor.

    Returns:
        (trainer, model, global_metrics, sensor_results), or None when data or features are missing.
    """
    print("Entrenamiento BNN Global Multi-Sensor")
    trainer = BNNUnifiedTrainer()

    print("\nCargando datos")
    trainer.df_master = trainer.load_data()
    if trainer.df_master.empty:
        return

    print("\nConfiguración Global")
    print(f"Sensores entrenamiento: {config['sensors_train']}")
    print(f"Sensores test: {config['sensors_test']}")
    print(f"Fecha división: {config['split_date']}")
    print(f"Outliers: {config['outlier_method']}")

    train_ids = {str(s) for s in config['sensors_train']}
    sample_data = trainer.df_master[trainer.df_master['id_no2'].astype(str).isin(train_ids)].head(100).copy()
    if sample_data.empty:
        print(f"No se encontraron datos para los sensores de entrenamiento {config['sensors_train']}")
        return

    selected_features = _select_features(sample_data, config['features'])
    print(f"\nSe usan {len(selected_features)} características")
    if not selected_features:
        print("No hay características válidas para entrenar")
        return

    data_prep = trainer.prepare_global_data(config, selected_features)
    if not data_prep:
        return

    splits = _build_scaled_splits(trainer, data_prep, selected_features, config.get('val_fraction', 0.1))
    model, log_noise_var = _train_with_validation(
        trainer, splits, config['train_config'], "../models/bnn_training_diagnostics_global.png",
    )

    # Global evaluation on the pooled test sensors.
    print("\nEvaluación global en test")
    predictions = trainer.predict(model, splits['X_test_s'], config['K_predict'], log_noise_var)
    global_metrics = trainer.evaluate_model(predictions, splits['y_test'], splits['y_test_s'],
                                            splits['scaler_target'], log_noise_var, "global")

    # Per-sensor evaluation.
    sensor_results = trainer.evaluate_global_by_sensor(
        data_prep, model, log_noise_var, splits['scaler_dict'], splits['scaler_target'],
        selected_features, config['K_predict']
    )

    print_model_metrics(global_metrics, "Métricas Globales")
    print_global_summary(sensor_results)

    _save_validation_outputs(trainer, model, splits, "../models/bnn_val_global")

    model_path = "../models/bnn_model_global.pkl"
    trainer.save_model(model_path, model, log_noise_var, splits['scaler_dict'],
                       splits['scaler_target'], selected_features, config)

    print("\nProceso global completado")
    return trainer, model, global_metrics, sensor_results


# ==================== EJEMPLO DE EJECUCIÓN GLOBAL ====================

def ejemplo_global():
    """Example run: train a global multi-sensor BNN on fixed train/test stations.

    The station ID is not part of the feature list. ``val_fraction`` is given
    at the top level, where main_global reads it for the temporal validation
    split. The copy inside ``train_config`` is not read by train_bnn_model.
    """
    banner = "=" * 60
    print(banner)
    print("EJEMPLO: ENTRENAMIENTO BNN GLOBAL MULTI-SENSOR")
    print(banner)

    # Feature set without the station ID.
    features = [
        'intensidad', 'carga', 'intensidad_lag8', 'carga_lag4', 'carga_lag2',
        't2m', 't2m_ma6', 'wind_speed_ma24', 'wind_speed_ewm3',
        'ssrd', 'ssrd_sum24', 'u10_ewm6', 'tp_sum24',
        'wind_dir_sin_ma6', 'wind_dir_cos_ma6',
        'hour_sin', 'hour_cos', 'dow_sin', 'dow_cos', 'month_sin', 'month_cos',
    ]

    train_config = {
        'learning_rate': 0.003,
        'n_epochs': 120,
        'batch_size': 512,
        'hidden_dims': [64, 32],
        'activation': 'ReLU',
        'beta': 0.5,
        'val_fraction': 0.1,  # informational only; train_bnn_model does not read it
        'K_eval': 30,
        'eval_every': 1,
    }

    config = {
        'sensors_train': ['28079004', '28079008', '28079011', '28079016', '28079036',
                          '28079038', '28079039', '28079040', '28079047', '28079048'],
        'sensors_test': ['28079050', '28079056', '28079035'],
        'split_date': '2024-01-01',
        'outlier_method': 'none',
        'features': features,
        'K_predict': 80,
        'train_config': train_config,
        'val_fraction': 0.1,  # used by main_global for the temporal validation split
    }

    return main_global(config)


if __name__ == "__main__":
    # Run the global example only when executed as a script, not on import.
    ejemplo_global()


EJEMPLO: ENTRENAMIENTO BNN GLOBAL MULTI-SENSOR
Entrenamiento BNN Global Multi-Sensor
Using device: cpu

Cargando datos

Configuración Global
Sensores entrenamiento: ['28079004', '28079008', '28079011', '28079016', '28079036', '28079038', '28079039', '28079040', '28079047', '28079048']
Sensores test: ['28079050', '28079056', '28079035']
Fecha división: 2024-01-01
Outliers: none

Se usan 15 características


AttributeError: 'BNNUnifiedTrainer' object has no attribute 'prepare_global_data'