In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import wandb
import copy
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score, precision_score, recall_score, hamming_loss, roc_auc_score
import gc
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from sklearn.model_selection import train_test_split
from mamba_ssm import Mamba2
import matplotlib.pyplot as plt

In [None]:
class EnsembleStarClassifier:
    """
    Ensemble of independently trained classifiers with uncertainty quantification.

    Each member is built with ``model_class(**model_args)``. The ensemble
    prediction is the mean of the members' sigmoid outputs. The standard
    deviation of those outputs across members is used as the uncertainty
    estimate.

    Members are called as ``model(X_spectra, X_gaia)`` and their outputs are
    passed through a sigmoid, i.e. they are treated as independent (multi-label)
    logits. Presumably one logit per class — TODO confirm against
    StarClassifierFusion.
    """

    def __init__(
        self,
        model_class,
        model_args,
        num_models=5,
        device='cuda'
    ):
        """
        Args:
            model_class: The model class to use (e.g. StarClassifierFusion).
            model_args: Dictionary of keyword arguments for the model constructor.
            num_models: Number of models in the ensemble.
            device: Device the members are moved to and inputs are sent to.
        """
        self.num_models = num_models
        self.device = device
        self.models = []

        # Each constructor call draws fresh random weights, so the members
        # start from different initialisations.
        for _ in range(num_models):
            model = model_class(**model_args)
            model.to(device)
            self.models.append(model)

    def train(
        self,
        train_loader,
        val_loader,
        test_loader,
        train_function,
        train_args,
        bootstrap=True,
        random_seed_offset=0
    ):
        """
        Train each model in the ensemble, one wandb run per member.

        Each member i is trained with torch/numpy seeded to
        ``random_seed_offset + i``. If ``bootstrap`` is set, it sees its own
        resample (with replacement) of the training set. After training, its
        ``state_dict`` is saved to ``ensemble_model_{i+1}.pth``.

        Members are trained in place (no copy is made). ``self.models`` is
        replaced by the values returned from ``train_function`` once all
        members are done. If training raises part-way, already-processed
        members keep whatever state ``train_function`` left them in.

        Note: ``wandb.init(..., reinit=True)`` is called for every member. With
        the wandb versions this was written for, that finishes any run that is
        already active — verify for your wandb version.

        Args:
            train_loader: DataLoader for training data.
            val_loader: DataLoader for validation data.
            test_loader: DataLoader for test data.
            train_function: Callable that trains one model. It is called with
                ``model``, ``train_loader``, ``val_loader``, ``test_loader`` and
                ``**train_args``, and must return the trained model.
            train_args: Extra keyword arguments for ``train_function``.
            bootstrap: Whether to train each member on a bootstrap resample.
            random_seed_offset: Offset added to the member index to form its seed.

        Returns:
            List of trained models (also stored in ``self.models``).
        """
        trained_models = []

        for i in range(self.num_models):
            print(f"Training model {i+1}/{self.num_models}")

            # Set a different random seed for each member.
            seed = random_seed_offset + i
            torch.manual_seed(seed)
            np.random.seed(seed)

            if bootstrap:
                curr_train_loader = self._create_bootstrap_loader(train_loader)
            else:
                curr_train_loader = train_loader

            # Train the member in place instead of on a deepcopy. A copy would
            # double the device memory of the member being trained, and
            # self.models is replaced by the trained models at the end anyway.
            model = self.models[i]

            run_name = f"ensemble_member_{i+1}"
            wandb.init(project="ALLSTARS_ensemble", name=run_name, group="ensemble_training", reinit=True)
            try:
                wandb.config.update({
                    "ensemble_member": i+1,
                    "num_models": self.num_models,
                    "bootstrap": bootstrap,
                    "random_seed": seed
                })

                trained_model = train_function(
                    model=model,
                    train_loader=curr_train_loader,
                    val_loader=val_loader,
                    test_loader=test_loader,
                    **train_args
                )

                trained_models.append(trained_model)
                torch.save(trained_model.state_dict(), f"ensemble_model_{i+1}.pth")
            finally:
                # Always close the member's run, even if training failed, so a
                # crashed member does not leave a dangling wandb run.
                wandb.finish()

        self.models = trained_models
        return trained_models

    def _create_bootstrap_loader(self, dataloader):
        """
        Create a bootstrapped copy of ``dataloader``.

        Draws ``len(dataset)`` indices with replacement using the global numpy
        RNG. It then builds a shuffled DataLoader over that subset, reusing the
        source loader's batch size, worker count, collate function,
        pin-memory and drop-last settings.

        Args:
            dataloader: Original DataLoader.

        Returns:
            DataLoader over the bootstrapped samples.
        """
        dataset = dataloader.dataset
        n_samples = len(dataset)

        # Sampling with replacement.
        bootstrap_indices = np.random.choice(n_samples, size=n_samples, replace=True)
        bootstrap_dataset = torch.utils.data.Subset(dataset, bootstrap_indices.tolist())

        return DataLoader(
            bootstrap_dataset,
            batch_size=dataloader.batch_size,
            shuffle=True,
            num_workers=getattr(dataloader, 'num_workers', 0),
            # Keep the source loader's batching behaviour. A custom collate_fn
            # in particular must survive, or batches may not unpack the same way.
            collate_fn=getattr(dataloader, 'collate_fn', None),
            pin_memory=getattr(dataloader, 'pin_memory', False),
            drop_last=getattr(dataloader, 'drop_last', False),
        )

    def predict(self, X_spectra, X_gaia, return_individual=False):
        """
        Generate ensemble predictions for one batch.

        Args:
            X_spectra: Spectral features tensor (moved to ``self.device``).
            X_gaia: Gaia features tensor (moved to ``self.device``).
            return_individual: Whether to also return each member's probabilities.

        Returns:
            mean_probs: Mean sigmoid probability across members (numpy, float32).
            std_probs: Standard deviation across members (the uncertainty).
            individual_probs: Array stacked on a new leading member axis
                (only if ``return_individual=True``).
        """
        X_spectra = X_spectra.to(self.device)
        X_gaia = X_gaia.to(self.device)

        member_probs = []
        for model in self.models:
            model.eval()
            with torch.no_grad():
                logits = model(X_spectra, X_gaia)
                # .float() so that half/bfloat16 outputs (e.g. under mixed
                # precision) can be converted; numpy has no bfloat16 dtype.
                probs = torch.sigmoid(logits).float()
                member_probs.append(probs.cpu().numpy())

        all_probs = np.stack(member_probs)
        mean_probs = np.mean(all_probs, axis=0)
        std_probs = np.std(all_probs, axis=0)

        if return_individual:
            return mean_probs, std_probs, all_probs
        return mean_probs, std_probs

    def evaluate(self, test_loader, threshold=0.5, return_predictions=False):
        """
        Evaluate the ensemble on a test set.

        Args:
            test_loader: DataLoader yielding ``(X_spectra, X_gaia, y_batch)``.
            threshold: Probability threshold for positive predictions.
            return_predictions: Whether to also return the predictions.

        Returns:
            metrics: Dict of F1/precision/recall (micro/macro/weighted), Hamming
                loss, uncertainty summaries and ``roc_auc``. ``roc_auc`` is None
                when sklearn cannot compute it, e.g. for a class with a single
                label value.
            mean_probs, std_probs, y_true: Only if ``return_predictions=True``.

        Raises:
            ValueError: If ``test_loader`` yields no batches.
        """
        mean_batches, std_batches, label_batches = [], [], []

        for X_spectra, X_gaia, y_batch in test_loader:
            # predict() moves the inputs to self.device itself.
            batch_mean, batch_std = self.predict(X_spectra, X_gaia)
            mean_batches.append(batch_mean)
            std_batches.append(batch_std)
            label_batches.append(y_batch.cpu().numpy())

        if not mean_batches:
            raise ValueError("test_loader yielded no batches to evaluate")

        mean_probs = np.concatenate(mean_batches, axis=0)
        std_probs = np.concatenate(std_batches, axis=0)
        y_true = np.concatenate(label_batches, axis=0)

        y_pred = (mean_probs > threshold).astype(float)

        metrics = {
            "micro_f1": f1_score(y_true, y_pred, average='micro'),
            "macro_f1": f1_score(y_true, y_pred, average='macro'),
            "weighted_f1": f1_score(y_true, y_pred, average='weighted'),
            "micro_precision": precision_score(y_true, y_pred, average='micro', zero_division=1),
            "macro_precision": precision_score(y_true, y_pred, average='macro', zero_division=1),
            "weighted_precision": precision_score(y_true, y_pred, average='weighted', zero_division=1),
            "micro_recall": recall_score(y_true, y_pred, average='micro'),
            "macro_recall": recall_score(y_true, y_pred, average='macro'),
            "weighted_recall": recall_score(y_true, y_pred, average='weighted'),
            "hamming_loss": hamming_loss(y_true, y_pred),
            "mean_uncertainty": np.mean(std_probs),
            "median_uncertainty": np.median(std_probs),
            "max_uncertainty": np.max(std_probs)
        }

        # roc_auc_score raises ValueError when it is undefined (e.g. a class
        # column contains only one label value). Any other error is a real bug
        # and must not be swallowed.
        try:
            metrics["roc_auc"] = roc_auc_score(y_true, mean_probs, average='macro', multi_class='ovr')
        except ValueError:
            metrics["roc_auc"] = None

        if return_predictions:
            return metrics, mean_probs, std_probs, y_true
        return metrics

    def visualize_uncertainty(self, mean_probs, std_probs, y_true, num_classes=10, class_names=None, threshold=0.5):
        """
        Scatter predicted probability vs. uncertainty for a random subset of classes.

        Args:
            mean_probs: Mean probabilities, indexed ``[sample, class]``.
            std_probs: Standard deviations, same layout as ``mean_probs``.
            y_true: True labels, same layout; used to colour the points.
            num_classes: Maximum number of classes to plot (chosen at random
                with the global numpy RNG).
            class_names: Optional sequence mapping class index -> name.
            threshold: x position of the dashed decision-threshold line.

        Returns:
            fig: Matplotlib figure.
        """
        n_classes = y_true.shape[1]
        classes_to_plot = np.random.choice(n_classes, min(num_classes, n_classes), replace=False)

        # squeeze=False always yields a 2-D axes array, so a single subplot
        # needs no special case.
        fig, axes = plt.subplots(
            len(classes_to_plot), 1, figsize=(10, 3 * len(classes_to_plot)), squeeze=False
        )

        for ax, class_idx in zip(axes[:, 0], classes_to_plot):
            probs = mean_probs[:, class_idx]
            uncertainties = std_probs[:, class_idx]
            true_labels = y_true[:, class_idx]

            scatter = ax.scatter(probs, uncertainties, c=true_labels, cmap='coolwarm', alpha=0.6)
            cbar = plt.colorbar(scatter, ax=ax)
            cbar.set_label('True Label')

            class_label = class_names[class_idx] if class_names is not None else f"Class {class_idx}"

            ax.set_xlabel('Predicted Probability')
            ax.set_ylabel('Uncertainty (Std. Dev.)')
            ax.set_title(f'Uncertainty vs. Prediction for {class_label}')
            ax.grid(True, alpha=0.3)
            ax.axvline(x=threshold, color='gray', linestyle='--', alpha=0.7)

        plt.tight_layout()
        return fig

    @staticmethod
    def _bin_indices(values, edges):
        """
        Map each value to a bin index in ``[0, len(edges) - 2]``.

        ``np.digitize(values, edges) - 1`` alone places a value equal to the
        last edge in bin ``len(edges) - 1``, which is out of range. Clipping
        keeps such values, e.g. a probability of exactly 1.0 or the maximum
        uncertainty, in the last bin instead of silently dropping them.
        """
        n_bins = len(edges) - 1
        idx = np.digitize(values, edges) - 1
        return np.clip(idx, 0, n_bins - 1)

    @staticmethod
    def _binned_mean(values, bin_idx, n_bins):
        """
        Return ``(means, counts)`` of ``values`` grouped by ``bin_idx``.

        Empty bins get a mean of 0.0 and a count of 0.0.
        """
        counts = np.bincount(bin_idx, minlength=n_bins).astype(float)
        sums = np.bincount(bin_idx, weights=np.asarray(values, dtype=float), minlength=n_bins)
        means = np.zeros(n_bins)
        nonempty = counts > 0
        means[nonempty] = sums[nonempty] / counts[nonempty]
        return means, counts

    def analyze_errors(self, mean_probs, std_probs, y_true, threshold=0.5):
        """
        Plot mean absolute error against uncertainty, with a linear fit.

        The error is ``|y_true - mean_probs|`` per element. Elements are grouped
        into 20 equal-width uncertainty bins spanning the observed std range.
        Only non-empty bins are plotted and used for the fit.

        Args:
            mean_probs: Mean probabilities from the ensemble.
            std_probs: Standard deviation of probabilities (uncertainty).
            y_true: True labels.
            threshold: Not used by this analysis; kept for API compatibility.

        Returns:
            fig: Matplotlib figure.
        """
        flat_errors = np.abs(y_true - mean_probs).ravel()
        flat_uncertainty = std_probs.ravel()

        n_bins = 20
        edges = np.linspace(np.min(flat_uncertainty), np.max(flat_uncertainty), n_bins + 1)
        bin_idx = self._bin_indices(flat_uncertainty, edges)
        bin_mean_errors, bin_counts = self._binned_mean(flat_errors, bin_idx, n_bins)

        valid_bins = bin_counts > 0
        bin_centers = (edges[:-1] + edges[1:]) / 2
        valid_x = bin_centers[valid_bins]
        valid_y = bin_mean_errors[valid_bins]

        fig, ax = plt.subplots(figsize=(10, 6))
        # Plot only populated bins; empty bins would otherwise appear as a
        # spurious zero error.
        ax.plot(valid_x, valid_y, 'o-', markersize=8)

        if len(valid_x) > 1:
            from sklearn.linear_model import LinearRegression
            X_fit = valid_x.reshape(-1, 1)
            reg = LinearRegression().fit(X_fit, valid_y)
            x_range = np.linspace(np.min(valid_x), np.max(valid_x), 100)
            y_line = reg.predict(x_range.reshape(-1, 1))
            ax.plot(x_range, y_line, 'r--', linewidth=2,
                    label=f'Slope: {reg.coef_[0]:.4f}, R²: {reg.score(X_fit, valid_y):.4f}')
            ax.legend()

        ax.set_xlabel('Uncertainty (Std. Dev.)')
        ax.set_ylabel('Mean Absolute Error')
        ax.set_title('Relationship Between Uncertainty and Prediction Error')
        ax.grid(True, alpha=0.3)

        return fig

    def calibration_curve(self, mean_probs, std_probs, y_true, n_bins=10):
        """
        Plot a reliability diagram with mean uncertainty per probability bin.

        Elements are grouped into ``n_bins`` equal-width bins over [0, 1] by
        predicted probability. For each non-empty bin, the observed positive
        frequency is plotted against the mean predicted probability in that
        bin. The mean std is drawn as bars on a secondary axis.

        Args:
            mean_probs: Mean probabilities from the ensemble.
            std_probs: Standard deviation of probabilities (uncertainty).
            y_true: True labels (0/1).
            n_bins: Number of probability bins.

        Returns:
            fig: Matplotlib figure.
        """
        flat_probs = mean_probs.ravel()
        flat_true = y_true.ravel()
        flat_uncertainty = std_probs.ravel()

        edges = np.linspace(0, 1, n_bins + 1)
        bin_idx = self._bin_indices(flat_probs, edges)

        bin_obs_freq, bin_counts = self._binned_mean(flat_true, bin_idx, n_bins)
        bin_pred_prob, _ = self._binned_mean(flat_probs, bin_idx, n_bins)
        bin_uncertainty, _ = self._binned_mean(flat_uncertainty, bin_idx, n_bins)
        valid_bins = bin_counts > 0

        fig, ax = plt.subplots(figsize=(10, 8))

        # x is the mean predicted probability within each bin (as labelled),
        # not the bin centre; empty bins are omitted rather than drawn at 0.
        ax.plot(bin_pred_prob[valid_bins], bin_obs_freq[valid_bins], 'o-',
                markersize=8, label='Calibration Curve')
        ax.plot([0, 1], [0, 1], 'k--', label='Perfect Calibration')

        bin_centers = (edges[:-1] + edges[1:]) / 2
        ax2 = ax.twinx()
        ax2.bar(bin_centers, bin_uncertainty, alpha=0.2, width=1/n_bins, color='r', label='Mean Uncertainty')
        ax2.set_ylabel('Mean Uncertainty (Std. Dev.)', color='r')
        ax2.tick_params(axis='y', labelcolor='r')

        ax.set_xlabel('Mean Predicted Probability')
        ax.set_ylabel('Observed Frequency')
        ax.set_title('Calibration Curve with Uncertainty')
        ax.grid(True, alpha=0.3)

        lines1, labels1 = ax.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax.legend(lines1 + lines2, labels1 + labels2, loc='upper left')

        return fig

    def uncertainty_threshold(self, mean_probs, std_probs, y_true, threshold=0.5, uncertainty_percentiles=None):
        """
        Evaluate performance on samples whose uncertainty is below a percentile cut-off.

        A sample's uncertainty is its maximum std across outputs. For each
        percentile p, only samples whose uncertainty is <= the p-th percentile
        of that per-sample distribution are kept, so the kept fraction is
        roughly p/100. The thresholds and the mask therefore use the same
        quantity.

        Args:
            mean_probs: Mean probabilities, indexed ``[sample, class]``.
            std_probs: Standard deviations, same layout.
            y_true: True labels, same layout.
            threshold: Probability threshold for positive predictions.
            uncertainty_percentiles: Percentiles to evaluate
                (default ``[0, 25, 50, 75, 90, 95]``).

        Returns:
            fig: Matplotlib figure.
            results: Dict with ``metrics`` (one dict or None per percentile),
                ``coverage`` and ``percentiles``.
        """
        if uncertainty_percentiles is None:
            uncertainty_percentiles = [0, 25, 50, 75, 90, 95]

        sample_uncertainty = np.max(std_probs, axis=1)
        uncertainty_thresholds = [np.percentile(sample_uncertainty, p) for p in uncertainty_percentiles]

        metrics = []
        coverage = []

        for unc_thresh in uncertainty_thresholds:
            mask = sample_uncertainty <= unc_thresh

            if not np.any(mask):
                metrics.append(None)
                coverage.append(0)
                continue

            filtered_true = y_true[mask]
            filtered_pred = (mean_probs[mask] > threshold).astype(float)

            metrics.append({
                "micro_f1": f1_score(filtered_true, filtered_pred, average='micro'),
                "macro_f1": f1_score(filtered_true, filtered_pred, average='macro'),
                "weighted_f1": f1_score(filtered_true, filtered_pred, average='weighted'),
                "hamming_loss": hamming_loss(filtered_true, filtered_pred),
            })
            coverage.append(np.mean(mask))

        fig, ax1 = plt.subplots(figsize=(12, 6))

        f1_scores = [m["micro_f1"] if m is not None else 0 for m in metrics]
        ax1.plot(uncertainty_percentiles, f1_scores, 'bo-', label='Micro F1 Score')

        macro_f1_scores = [m["macro_f1"] if m is not None else 0 for m in metrics]
        ax1.plot(uncertainty_percentiles, macro_f1_scores, 'go-', label='Macro F1 Score')

        ax2 = ax1.twinx()
        ax2.plot(uncertainty_percentiles, coverage, 'r--', label='Data Coverage')
        ax2.set_ylabel('Data Coverage', color='r')
        ax2.tick_params(axis='y', labelcolor='r')

        ax1.set_xlabel('Uncertainty Percentile Threshold')
        ax1.set_ylabel('F1 Score')
        ax1.set_title('Performance vs. Uncertainty Threshold')
        ax1.grid(True, alpha=0.3)

        lines1, labels1 = ax1.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax1.legend(lines1 + lines2, labels1 + labels2, loc='center right')

        return fig, {"metrics": metrics, "coverage": coverage, "percentiles": uncertainty_percentiles}

    def selective_prediction(self, mean_probs, std_probs, y_true, threshold=0.5):
        """
        Plot micro-F1 against coverage when keeping only the most certain samples.

        Samples are ranked by maximum std across outputs. For coverage levels
        0.1, 0.2, ..., 1.0, the least uncertain fraction is kept (always at
        least one sample). The reported AUC is the trapezoidal area over that
        coverage range, i.e. [0.1, 1.0].

        Args:
            mean_probs: Mean probabilities, indexed ``[sample, class]``.
            std_probs: Standard deviations, same layout.
            y_true: True labels, same layout.
            threshold: Probability threshold for positive predictions.

        Returns:
            fig: Matplotlib figure.
        """
        max_uncertainties = np.max(std_probs, axis=1)
        sorted_indices = np.argsort(max_uncertainties)
        n_samples = len(sorted_indices)

        coverages = []
        f1_scores = []

        for coverage in np.linspace(0.1, 1.0, 10):
            # Keep at least one sample so f1_score never sees an empty selection.
            k = max(1, int(n_samples * coverage))
            selected_indices = sorted_indices[:k]

            selected_true = y_true[selected_indices]
            selected_pred = (mean_probs[selected_indices] > threshold).astype(float)

            coverages.append(coverage)
            f1_scores.append(f1_score(selected_true, selected_pred, average='micro'))

        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(coverages, f1_scores, 'bo-', markersize=8)
        ax.fill_between(coverages, 0, f1_scores, alpha=0.2)

        ax.set_xlabel('Coverage (Fraction of Data)')
        ax.set_ylabel('Micro F1 Score')
        ax.set_title('Selective Prediction: Performance vs. Coverage')
        ax.grid(True, alpha=0.3)

        # np.trapz is deprecated in NumPy 2 in favour of np.trapezoid; prefer
        # the new name when available.
        trapezoid = getattr(np, "trapezoid", None) or np.trapz
        auc = trapezoid(f1_scores, coverages)
        ax.text(0.05, 0.95, f'AUC: {auc:.4f}', transform=ax.transAxes,
                fontsize=12, verticalalignment='top', bbox=dict(boxstyle='round', alpha=0.1))

        return fig

# Function to run ensemble training and evaluation
def train_and_evaluate_ensemble(
    model_class,
    model_args,
    train_loader,
    val_loader,
    test_loader,
    train_function,
    train_args,
    num_models=5,
    bootstrap=True,
    device='cuda',
    class_names=None
):
    """
    Train an EnsembleStarClassifier, evaluate it, and log results to wandb.

    Training logs one wandb run per member (inside ``EnsembleStarClassifier.train``).
    Afterwards a separate "ensemble_experiment" run receives the configuration,
    the test metrics and the diagnostic figures.

    Args:
        model_class: The model class to use.
        model_args: Arguments for model initialization.
        train_loader: DataLoader for training data.
        val_loader: DataLoader for validation data.
        test_loader: DataLoader for test data.
        train_function: Function to train a single model.
        train_args: Arguments for the training function.
        num_models: Number of models in the ensemble.
        bootstrap: Whether to use bootstrapping for training.
        device: Device to use for computation.
        class_names: Names of the classes (optional; used for plot titles).

    Returns:
        ensemble: Trained ensemble.
        metrics: Evaluation metrics dict from ``EnsembleStarClassifier.evaluate``.
        figures: Dict of matplotlib figures. They have already been logged and
            closed with ``plt.close``.
    """
    ensemble = EnsembleStarClassifier(
        model_class=model_class,
        model_args=model_args,
        num_models=num_models,
        device=device
    )

    # Train BEFORE opening the experiment run. ensemble.train calls
    # wandb.init(..., reinit=True) and wandb.finish() for every member, which
    # would finish an experiment run opened earlier. The wandb.log calls below
    # would then run with no active run.
    ensemble.train(
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        train_function=train_function,
        train_args=train_args,
        bootstrap=bootstrap
    )

    wandb.init(project="ALLSTARS_ensemble", name="ensemble_experiment", reinit=True)
    try:
        wandb.config.update({
            "num_models": num_models,
            "bootstrap": bootstrap,
            **model_args,
            **train_args
        })

        metrics, mean_probs, std_probs, y_true = ensemble.evaluate(
            test_loader=test_loader,
            return_predictions=True
        )
        wandb.log(metrics)

        figures = {
            'uncertainty': ensemble.visualize_uncertainty(
                mean_probs=mean_probs,
                std_probs=std_probs,
                y_true=y_true,
                num_classes=10,
                class_names=class_names
            ),
            'error_analysis': ensemble.analyze_errors(
                mean_probs=mean_probs,
                std_probs=std_probs,
                y_true=y_true
            ),
            'calibration': ensemble.calibration_curve(
                mean_probs=mean_probs,
                std_probs=std_probs,
                y_true=y_true
            ),
        }

        # uncertainty_threshold also returns per-threshold metrics; only the
        # figure is logged here.
        threshold_fig, _threshold_metrics = ensemble.uncertainty_threshold(
            mean_probs=mean_probs,
            std_probs=std_probs,
            y_true=y_true
        )
        figures['threshold_analysis'] = threshold_fig

        figures['selective_prediction'] = ensemble.selective_prediction(
            mean_probs=mean_probs,
            std_probs=std_probs,
            y_true=y_true
        )

        for name, fig in figures.items():
            wandb.log({name: wandb.Image(fig)})
            plt.close(fig)
    finally:
        # Close the experiment run even if evaluation or plotting fails.
        wandb.finish()

    return ensemble, metrics, figures

# Example usage
if __name__ == "__main__":
    # This example depends on objects that are NOT defined in this file and must
    # be created beforehand (e.g. in earlier notebook cells). Running it without
    # them previously failed with an opaque NameError.
    _required_names = (
        "StarClassifierFusion",
        "train_model_fusion",
        "train_loader",
        "val_loader",
        "test_loader",
        "classes",
    )
    _missing_names = [name for name in _required_names if name not in globals()]
    if _missing_names:
        raise NameError(
            "The ensemble example needs these names to be defined first "
            f"(e.g. by running the earlier notebook cells): {', '.join(_missing_names)}"
        )

    # Example configuration for ensemble
    num_models = 5

    # Model configuration
    d_model_spectra = 2048
    d_model_gaia = 2048
    num_classes = 55
    input_dim_spectra = 3647
    input_dim_gaia = 18
    n_layers = 12

    # Training configuration
    lr = 2.5e-4
    # NOTE(review): patience (600) exceeds num_epochs (200). If max_patience is
    # counted in epochs, early stopping can never trigger — confirm the unit
    # used by train_model_fusion.
    patience = 600
    num_epochs = 200
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Keyword arguments for the model constructor
    model_args = {
        "d_model_spectra": d_model_spectra,
        "d_model_gaia": d_model_gaia,
        "num_classes": num_classes,
        "input_dim_spectra": input_dim_spectra,
        "input_dim_gaia": input_dim_gaia,
        "n_layers": n_layers,
        "use_cross_attention": True,
        "n_cross_attn_heads": 8
    }

    # Keyword arguments for the per-member training function
    train_args = {
        "num_epochs": num_epochs,
        "lr": lr,
        "max_patience": patience,
        "device": device
    }

    # Train and evaluate ensemble
    ensemble, metrics, figures = train_and_evaluate_ensemble(
        model_class=StarClassifierFusion,
        model_args=model_args,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        train_function=train_model_fusion,
        train_args=train_args,
        num_models=num_models,
        bootstrap=True,
        device=device,
        class_names=classes
    )

    # Save the whole ensemble (configs + every member's weights) in one file
    torch.save({
        "model_args": model_args,
        "train_args": train_args,
        "num_models": num_models,
        "models": [model.state_dict() for model in ensemble.models]
    }, "ensemble_model.pth")

    print("Ensemble model training and evaluation complete!")
    print(f"Final metrics: {metrics}")

    # Generate predictions with uncertainty
    # Example:
    # mean_probs, std_probs = ensemble.predict(X_test_spectra, X_test_gaia)
    # predictions = (mean_probs > 0.5).astype(float)
    # uncertainties = std_probs

NameError: name 'StarClassifierFusion' is not defined