In [9]:
from sklearn.model_selection import train_test_split
import torch
import os
import gc
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from mamba_ssm import Mamba2


In [10]:
class MemoryEfficientMamba(nn.Module):
    """
    Memory-efficient wrapper for Mamba2 with gradient checkpointing and 
    optional parameter quantization.
    """
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2, use_checkpoint=True):
        super().__init__()
        self.mamba = Mamba2(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand
        )
        self.use_checkpoint = use_checkpoint
    
    def forward(self, x):
        if self.use_checkpoint and self.training:
            return checkpoint(self.mamba, x)
        else:
            return self.mamba(x)

class MemoryEfficientStarClassifier(nn.Module):
    """
    Memory-efficient version of StarClassifierFusion with various
    optimizations to reduce VRAM usage.
    """
    def __init__(
        self,
        d_model_spectra,
        d_model_gaia,
        num_classes,
        input_dim_spectra,
        input_dim_gaia,
        n_layers=6,
        use_cross_attention=True,
        n_cross_attn_heads=8,
        d_state=16,  # Reduced from 256 to save memory
        d_conv=4,
        expand=2,
        use_checkpoint=True,
        activation_checkpointing=True,
        use_half_precision=True,
        sequential_processing=True
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.activation_checkpointing = activation_checkpointing
        self.sequential_processing = sequential_processing
        
        # Use lower precision
        self.dtype = torch.float16 if use_half_precision else torch.float32

        # Input projection layers
        self.input_proj_spectra = nn.Linear(input_dim_spectra, d_model_spectra)
        self.input_proj_gaia = nn.Linear(input_dim_gaia, d_model_gaia)
        
        # Memory-efficient Mamba layers
        self.mamba_spectra_layers = nn.ModuleList([
            MemoryEfficientMamba(
                d_model=d_model_spectra,
                d_state=d_state,
                d_conv=d_conv,
                expand=expand,
                use_checkpoint=activation_checkpointing
            ) for _ in range(n_layers)
        ])
        
        self.mamba_gaia_layers = nn.ModuleList([
            MemoryEfficientMamba(
                d_model=d_model_gaia,
                d_state=d_state,
                d_conv=d_conv,
                expand=expand,
                use_checkpoint=activation_checkpointing
            ) for _ in range(n_layers)
        ])

        # Cross-attention (optional)
        self.use_cross_attention = use_cross_attention
        if use_cross_attention:
            self.cross_attn_block_spectra = self._create_cross_attn_block(d_model_spectra, n_heads=n_cross_attn_heads)
            self.cross_attn_block_gaia = self._create_cross_attn_block(d_model_gaia, n_heads=n_cross_attn_heads)

        # Final classifier
        fusion_dim = d_model_spectra + d_model_gaia
        self.layer_norm = nn.LayerNorm(fusion_dim)
        self.classifier = nn.Linear(fusion_dim, num_classes)
    
    def _create_cross_attn_block(self, d_model, n_heads):
        """Creates a cross-attention block with optional gradient checkpointing."""
        class CrossAttentionBlock(nn.Module):
            def __init__(self, d_model, n_heads):
                super().__init__()
                self.cross_attn = nn.MultiheadAttention(
                    embed_dim=d_model, 
                    num_heads=n_heads, 
                    batch_first=True
                )
                self.norm1 = nn.LayerNorm(d_model)
                
                self.ffn = nn.Sequential(
                    nn.Linear(d_model, 4 * d_model),
                    nn.ReLU(),
                    nn.Linear(4 * d_model, d_model)
                )
                self.norm2 = nn.LayerNorm(d_model)
                
            def forward(self, x_q, x_kv):
                # Cross-attention
                attn_output, _ = self.cross_attn(query=x_q, key=x_kv, value=x_kv)
                x = self.norm1(x_q + attn_output)
                
                # Feed forward
                ffn_out = self.ffn(x)
                x = self.norm2(x + ffn_out)
                
                return x
        
        block = CrossAttentionBlock(d_model, n_heads)
        
        # Wrap with gradient checkpointing if requested
        if self.activation_checkpointing:
            def forward_with_checkpoint(module, x_q, x_kv):
                def custom_forward(x_q, x_kv):
                    return module(x_q, x_kv)
                return checkpoint(custom_forward, x_q, x_kv)
            
            class CheckpointedCrossAttention(nn.Module):
                def __init__(self, block):
                    super().__init__()
                    self.block = block
                
                def forward(self, x_q, x_kv):
                    return forward_with_checkpoint(self.block, x_q, x_kv)
            
            return CheckpointedCrossAttention(block)
        else:
            return block
    
    def _process_mamba_layers(self, x, layers):
        """Process input through Mamba layers, optionally sequentially to save memory."""
        if self.sequential_processing:
            for layer in layers:
                x = layer(x)
        else:
            # Process all layers at once (uses more memory but faster)
            for layer in layers:
                x = layer(x)
        return x
    
    def forward(self, x_spectra, x_gaia):
        # Convert to half precision if requested
        if hasattr(self, 'dtype') and self.dtype == torch.float16:
            x_spectra = x_spectra.half()
            x_gaia = x_gaia.half()
        
        # Project inputs
        x_spectra = self.input_proj_spectra(x_spectra)
        x_gaia = self.input_proj_gaia(x_gaia)
        
        # Add sequence dimension if needed
        if len(x_spectra.shape) == 2:
            x_spectra = x_spectra.unsqueeze(1)
        if len(x_gaia.shape) == 2:
            x_gaia = x_gaia.unsqueeze(1)
        
        # Process through Mamba layers
        x_spectra = self._process_mamba_layers(x_spectra, self.mamba_spectra_layers)
        x_gaia = self._process_mamba_layers(x_gaia, self.mamba_gaia_layers)
        
        # Optional cross-attention
        if self.use_cross_attention:
            x_spectra_fused = self.cross_attn_block_spectra(x_spectra, x_gaia)
            x_gaia_fused = self.cross_attn_block_gaia(x_gaia, x_spectra)
            x_spectra = x_spectra_fused
            x_gaia = x_gaia_fused
        
        # Pool across sequence dimension
        x_spectra = x_spectra.mean(dim=1)
        x_gaia = x_gaia.mean(dim=1)
        
        # Concatenate
        x_fused = torch.cat([x_spectra, x_gaia], dim=-1)
        
        # Final classification
        x_fused = self.layer_norm(x_fused)
        logits = self.classifier(x_fused)
        
        return logits

class UltraMemoryEfficientEnsemble:
    """
    Ultra memory-efficient implementation of ensemble for uncertainty quantification.
    Avoids creating multiple models in memory and removes quantization during initialization.
    """
    def __init__(
        self, 
        model_class, 
        model_args, 
        num_models=5, 
        device='cuda',
        checkpoint_dir='ensemble_checkpoints'
    ):
        """
        Initialize the ultra memory-efficient ensemble.
        
        Args:
            model_class: Model class to instantiate
            model_args: Arguments for model initialization
            num_models: Number of models in ensemble
            device: Device to use
            checkpoint_dir: Directory to save/load model checkpoints
        """
        self.model_class = model_class
        self.model_args = model_args
        self.num_models = num_models
        self.device = device
        self.checkpoint_dir = checkpoint_dir
        
        # Create directory if it doesn't exist
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        
        # We don't create any models during initialization to save memory
        self.active_model = None
        self.active_model_idx = -1
    
    def _create_model(self, model_idx, for_inference=False):
        """Create a new model instance with appropriate seed."""
        # Set seeds for reproducibility
        torch.manual_seed(42 + model_idx)
        np.random.seed(42 + model_idx)
        
        # Create model instance
        model = self.model_class(**self.model_args)
        
        # Move to device and set mode
        model = model.to(self.device)
        if for_inference:
            model.eval()
        
        return model
    
    def _get_checkpoint_path(self, model_idx):
        """Get path for model checkpoint."""
        return os.path.join(self.checkpoint_dir, f"model_{model_idx}.pt")
    
    def _save_model(self, model, model_idx):
        """Save model to checkpoint."""
        # Save to disk
        checkpoint_path = self._get_checkpoint_path(model_idx)
        # Move state_dict to CPU before saving to avoid GPU memory issues
        state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
        torch.save(state_dict, checkpoint_path)
    
    def _load_model(self, model_idx, for_inference=False):
        """Load a model for the specified index."""
        # Clear any existing model to free memory
        if self.active_model is not None:
            del self.active_model
            torch.cuda.empty_cache()
            gc.collect()
        
        # Create a new model instance
        model = self._create_model(model_idx, for_inference)
        
        # Load checkpoint if it exists
        checkpoint_path = self._get_checkpoint_path(model_idx)
        if os.path.exists(checkpoint_path):
            # Load state dict from CPU to save GPU memory
            state_dict = torch.load(checkpoint_path, map_location='cpu')
            model.load_state_dict(state_dict)
            
            # Move model to device after loading weights
            model = model.to(self.device)
        
        self.active_model = model
        self.active_model_idx = model_idx
        return model
    
    def train_single_model(
        self, 
        model_idx,
        train_loader, 
        val_loader, 
        test_loader=None, 
        num_epochs=100, 
        lr=1e-4, 
        max_patience=20,
        scheduler_type='OneCycleLR',
        batch_accumulation=1,  # Gradient accumulation steps
        log_to_wandb=True
    ):
        """
        Train a single model in the ensemble.
        
        Args:
            model_idx: Index of the model to train
            train_loader: DataLoader for training data
            val_loader: DataLoader for validation data
            test_loader: DataLoader for test data (optional)
            num_epochs: Maximum number of epochs to train
            lr: Learning rate
            max_patience: Maximum patience for early stopping
            scheduler_type: Type of learning rate scheduler ('OneCycleLR' or 'ReduceLROnPlateau')
            batch_accumulation: Number of batches to accumulate gradients (to simulate larger batch size)
            log_to_wandb: Whether to log training progress to wandb
            
        Returns:
            Trained model (also saved to checkpoint)
        """
        import torch.optim as optim
        from tqdm import tqdm
        
        # Initialize wandb if requested
        if log_to_wandb:
            try:
                import wandb
                run = wandb.init(
                    project="ALLSTARS_ultra_memory_efficient", 
                    name=f"model_{model_idx}",
                    group="memory_efficient_training",
                    config={
                        **self.model_args,
                        "model_idx": model_idx,
                        "num_models": self.num_models,
                        "lr": lr,
                        "max_patience": max_patience,
                        "scheduler_type": scheduler_type,
                        "batch_accumulation": batch_accumulation,
                        "num_epochs": num_epochs
                    },
                    reinit=True
                )
            except ImportError:
                print("wandb not installed. Training without logging.")
                log_to_wandb = False
        
        # Load or create model
        model = self._load_model(model_idx, for_inference=False)
        model.train()
        
        # Create optimizer (SGD uses less memory than Adam/AdamW)
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
        
        # Configure the scheduler
        if scheduler_type == 'OneCycleLR':
            scheduler = optim.lr_scheduler.OneCycleLR(
                optimizer, 
                max_lr=lr,
                epochs=num_epochs, 
                steps_per_epoch=len(train_loader) // batch_accumulation
            )
        else:  # ReduceLROnPlateau
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, 
                mode='min', 
                factor=0.5, 
                patience=int(max_patience / 5)
            )
        
        # Calculate class weights for imbalanced classes
        all_labels = []
        for _, _, y_batch in train_loader:
            all_labels.extend(y_batch.cpu().numpy())
        
        class_weights = self._calculate_class_weights(np.array(all_labels))
        class_weights = torch.tensor(class_weights, dtype=torch.float).to(self.device)
        criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights)
        
        best_val_loss = float('inf')
        patience = max_patience
        
        # Initialize mixed precision scaler if supported
        scaler = torch.cuda.amp.GradScaler() if hasattr(torch.cuda, 'amp') else None

        # Training loop
        for epoch in range(num_epochs):
            # Resample training data if needed
            if hasattr(train_loader.dataset, 're_sample'):
                train_loader.dataset.re_sample()
                
                # Recompute class weights if needed
                all_labels = []
                for _, _, y_batch in train_loader:
                    all_labels.extend(y_batch.cpu().numpy())
                class_weights = self._calculate_class_weights(np.array(all_labels))
                class_weights = torch.tensor(class_weights, dtype=torch.float).to(self.device)
                criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights)

            # --- Training Phase ---
            model.train()
            train_loss, train_acc = 0.0, 0.0
            batch_count = 0
            
            for i, (X_spc, X_ga, y_batch) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training")):
                X_spc, X_ga, y_batch = X_spc.to(self.device), X_ga.to(self.device), y_batch.to(self.device)
                
                # Only zero gradients at the start of accumulation cycle
                if i % batch_accumulation == 0:
                    optimizer.zero_grad(set_to_none=True)
                
                # Forward pass with mixed precision if available
                if scaler:
                    with torch.cuda.amp.autocast():
                        outputs = model(X_spc, X_ga)
                        loss = criterion(outputs, y_batch) / batch_accumulation
                    
                    # Backward pass with scaled gradients
                    scaler.scale(loss).backward()
                    
                    # Step only at the end of accumulation cycle
                    if (i + 1) % batch_accumulation == 0 or (i + 1) == len(train_loader):
                        scaler.step(optimizer)
                        scaler.update()
                        if scheduler_type == 'OneCycleLR':
                            scheduler.step()
                else:
                    outputs = model(X_spc, X_ga)
                    loss = criterion(outputs, y_batch) / batch_accumulation
                    loss.backward()
                    
                    if (i + 1) % batch_accumulation == 0 or (i + 1) == len(train_loader):
                        optimizer.step()
                        if scheduler_type == 'OneCycleLR':
                            scheduler.step()
                
                # Calculate metrics
                train_loss += loss.item() * batch_accumulation * X_spc.size(0)
                predicted = (torch.sigmoid(outputs) > 0.5).float()
                correct = (predicted == y_batch).float()
                train_acc += correct.mean(dim=1).sum().item()
                batch_count += X_spc.size(0)
                
                # Free up memory
                del X_spc, X_ga, y_batch, outputs, loss, predicted, correct
                torch.cuda.empty_cache()
            
            train_loss /= batch_count
            train_acc /= batch_count

            # --- Validation Phase ---
            model.eval()
            val_loss, val_acc = 0.0, 0.0
            val_batch_count = 0
            
            with torch.no_grad():
                for X_spc, X_ga, y_batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation"):
                    X_spc, X_ga, y_batch = X_spc.to(self.device), X_ga.to(self.device), y_batch.to(self.device)
                    
                    # Forward pass with mixed precision if available
                    if scaler:
                        with torch.cuda.amp.autocast():
                            outputs = model(X_spc, X_ga)
                            loss = criterion(outputs, y_batch)
                    else:
                        outputs = model(X_spc, X_ga)
                        loss = criterion(outputs, y_batch)
                    
                    val_loss += loss.item() * X_spc.size(0)
                    predicted = (torch.sigmoid(outputs) > 0.5).float()
                    correct = (predicted == y_batch).float()
                    val_acc += correct.mean(dim=1).sum().item()
                    val_batch_count += X_spc.size(0)
                    
                    # Free up memory
                    del X_spc, X_ga, y_batch, outputs, loss, predicted, correct
                    torch.cuda.empty_cache()
            
            val_loss /= val_batch_count
            val_acc /= val_batch_count

            # --- Test Phase (if provided) ---
            test_metrics = {}
            if test_loader is not None:
                test_loss, test_acc = 0.0, 0.0
                test_batch_count = 0
                y_true, y_pred = [], []
                
                with torch.no_grad():
                    for X_spc, X_ga, y_batch in tqdm(test_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Testing"):
                        X_spc, X_ga, y_batch = X_spc.to(self.device), X_ga.to(self.device), y_batch.to(self.device)
                        
                        # Forward pass with mixed precision if available
                        if scaler:
                            with torch.cuda.amp.autocast():
                                outputs = model(X_spc, X_ga)
                                loss = criterion(outputs, y_batch)
                        else:
                            outputs = model(X_spc, X_ga)
                            loss = criterion(outputs, y_batch)
                        
                        test_loss += loss.item() * X_spc.size(0)
                        predicted = (torch.sigmoid(outputs) > 0.5).float()
                        correct = (predicted == y_batch).float()
                        test_acc += correct.mean(dim=1).sum().item()
                        test_batch_count += X_spc.size(0)
                        
                        # Store on CPU to save GPU memory
                        y_true.extend(y_batch.cpu().numpy())
                        y_pred.extend(predicted.cpu().numpy())
                        
                        # Free up memory
                        del X_spc, X_ga, y_batch, outputs, loss, predicted, correct
                        torch.cuda.empty_cache()
                
                test_loss /= test_batch_count
                test_acc /= test_batch_count
                
                # Calculate metrics on CPU to save GPU memory
                test_metrics = self._calculate_metrics(np.array(y_true), np.array(y_pred))
                test_metrics.update({
                    "test_loss": test_loss,
                    "test_acc": test_acc,
                })

            # Log metrics
            if log_to_wandb:
                log_data = {
                    "epoch": epoch,
                    "train_loss": train_loss,
                    "val_loss": val_loss,
                    "train_acc": train_acc,
                    "val_acc": val_acc,
                    "lr": self._get_lr(optimizer)
                }
                log_data.update(test_metrics)
                wandb.log(log_data)
            
            # Print progress
            print(f"Epoch {epoch+1}/{num_epochs} - "
                  f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
                  f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

            # Update ReduceLROnPlateau scheduler if used
            if scheduler_type == 'ReduceLROnPlateau':
                scheduler.step(val_loss)

            # Early stopping and checkpoint saving
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience = max_patience
                
                # Save best model
                self._save_model(model, model_idx)
                
                if log_to_wandb:
                    wandb.run.summary["best_val_loss"] = best_val_loss
            else:
                patience -= 1
                if patience <= 0:
                    print("Early stopping triggered.")
                    break
            
            # Save every 10 epochs as a checkpoint
            if (epoch + 1) % 10 == 0:
                checkpoint_path = os.path.join(self.checkpoint_dir, f"model_{model_idx}_epoch_{epoch+1}.pt")
                state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
                torch.save(state_dict, checkpoint_path)
        
        # Close wandb run if used
        if log_to_wandb:
            wandb.finish()
        
        return model
    
    def train(
        self,
        train_loader, 
        val_loader, 
        test_loader=None, 
        num_epochs=100, 
        lr=1e-4, 
        max_patience=20,
        scheduler_type='OneCycleLR',
        batch_accumulation=1,
        log_to_wandb=True
    ):
        """
        Train all models in the ensemble.
        
        Args:
            train_loader: DataLoader for training data
            val_loader: DataLoader for validation data
            test_loader: DataLoader for test data (optional)
            num_epochs: Maximum number of epochs to train
            lr: Learning rate
            max_patience: Maximum patience for early stopping
            scheduler_type: Type of learning rate scheduler
            batch_accumulation: Number of batches to accumulate gradients
            log_to_wandb: Whether to log training progress to wandb
            
        Returns:
            List of trained model paths
        """
        for model_idx in range(self.num_models):
            print(f"\n----- Training Ensemble Model {model_idx+1}/{self.num_models} -----\n")
            
            # Train this model
            self.train_single_model(
                model_idx=model_idx,
                train_loader=train_loader,
                val_loader=val_loader,
                test_loader=test_loader,
                num_epochs=num_epochs,
                lr=lr,
                max_patience=max_patience,
                scheduler_type=scheduler_type,
                batch_accumulation=batch_accumulation,
                log_to_wandb=log_to_wandb
            )
            
            # Free memory before moving to next model
            if self.active_model is not None:
                del self.active_model
                self.active_model = None
                self.active_model_idx = -1
                torch.cuda.empty_cache()
                gc.collect()
        
        # Return paths to all model checkpoints
        return [self._get_checkpoint_path(i) for i in range(self.num_models)]
    
    def predict(self, loader, return_individual=False, micro_batch_size=1):
        """
        Make predictions with the ensemble, one model at a time to save memory.
        Uses micro-batching to further reduce memory usage.
        
        Args:
            loader: DataLoader for the data to predict
            return_individual: Whether to return predictions from individual models
            micro_batch_size: Size of micro-batches to process at once
            
        Returns:
            mean_probs: Mean probability across all models
            std_probs: Standard deviation of probabilities (uncertainty measure)
            individual_probs: (Optional) Predictions from each individual model
        """
        from tqdm import tqdm
        import math
        
        # Get shapes from first batch
        for X_spc, X_ga, y in loader:
            num_classes = y.shape[1]
            break
        
        # Total number of samples
        num_samples = len(loader.dataset)
        
        # Create array to store all model outputs (if return_individual)
        all_model_outputs = [] if return_individual else None
        
        # Running sum and sum of squares for mean and std calculation
        sum_outputs = np.zeros((num_samples, num_classes))
        sum_squared_outputs = np.zeros((num_samples, num_classes))
        
        # Process each model sequentially
        for model_idx in range(self.num_models):
            print(f"Making predictions with model {model_idx+1}/{self.num_models}")
            
            # Load model
            model = self._load_model(model_idx, for_inference=True)
            model.eval()
            
            # Array to store this model's outputs
            model_outputs = np.zeros((num_samples, num_classes))
            
            # Track current position in the outputs array
            sample_idx = 0
            
            # Process batches
            with torch.no_grad():
                for X_spc, X_ga, _ in tqdm(loader, desc=f"Model {model_idx+1}"):
                    batch_size = X_spc.shape[0]
                    
                    # Process in micro-batches to save memory
                    for micro_start in range(0, batch_size, micro_batch_size):
                        micro_end = min(micro_start + micro_batch_size, batch_size)
                        
                        X_spc_micro = X_spc[micro_start:micro_end].to(self.device)
                        X_ga_micro = X_ga[micro_start:micro_end].to(self.device)
                        
                        # Forward pass with mixed precision
                        if hasattr(torch.cuda, 'amp'):
                            with torch.cuda.amp.autocast():
                                outputs = model(X_spc_micro, X_ga_micro)
                        else:
                            outputs = model(X_spc_micro, X_ga_micro)
                        
                        # Get probabilities
                        probs = torch.sigmoid(outputs).cpu().numpy()
                        
                        # Store in output array
                        start_idx = sample_idx + micro_start
                        end_idx = sample_idx + micro_end
                        model_outputs[start_idx:end_idx] = probs
                        
                        # Free memory
                        del X_spc_micro, X_ga_micro, outputs, probs
                        torch.cuda.empty_cache()
                    
                    # Update sample index
                    sample_idx += batch_size
            
            # Store model outputs if returning individual predictions
            if return_individual:
                all_model_outputs.append(model_outputs)
            
            # Update sums for mean and std calculation
            sum_outputs += model_outputs
            sum_squared_outputs += model_outputs**2
            
            # Free memory
            del model, model_outputs
            torch.cuda.empty_cache()
            gc.collect()
        
        # Calculate mean and std
        mean_probs = sum_outputs / self.num_models
        
        # Calculate standard deviation
        variance = (sum_squared_outputs / self.num_models) - (mean_probs**2)
        # Clip small negative values that can occur due to numerical issues
        variance = np.clip(variance, 0, None)
        std_probs = np.sqrt(variance)
        
        if return_individual:
            return mean_probs, std_probs, np.array(all_model_outputs)
        else:
            return mean_probs, std_probs

    def _calculate_class_weights(self, y):
        """Calculate class weights for handling imbalanced classes."""
        if y.ndim > 1:  
            class_counts = np.sum(y, axis=0)  
        else:
            class_counts = np.bincount(y)

        total_samples = y.shape[0] if y.ndim > 1 else len(y)
        class_counts = np.where(class_counts == 0, 1, class_counts)  # Prevent division by zero
        class_weights = total_samples / (len(class_counts) * class_counts)
        
        return class_weights
    
    def _calculate_metrics(self, y_true, y_pred):
        """Calculate evaluation metrics for multi-label classification."""
        from sklearn.metrics import f1_score, precision_score, recall_score, hamming_loss
        
        metrics = {
            "micro_f1": f1_score(y_true, y_pred, average='micro'),
            "macro_f1": f1_score(y_true, y_pred, average='macro'),
            "weighted_f1": f1_score(y_true, y_pred, average='weighted'),
            "micro_precision": precision_score(y_true, y_pred, average='micro', zero_division=1),
            "macro_precision": precision_score(y_true, y_pred, average='macro', zero_division=1),
            "weighted_precision": precision_score(y_true, y_pred, average='weighted', zero_division=1),
            "micro_recall": recall_score(y_true, y_pred, average='micro'),
            "macro_recall": recall_score(y_true, y_pred, average='macro'),
            "weighted_recall": recall_score(y_true, y_pred, average='weighted'),
            "hamming_loss": hamming_loss(y_true, y_pred)
        }
        
        return metrics
    
    def _get_lr(self, optimizer):
        """Get current learning rate from optimizer."""
        for param_group in optimizer.param_groups:
            return param_group['lr']

In [11]:
class MultiModalBalancedMultiLabelDataset(Dataset):
    """
    A balanced multi-label dataset that returns (X_spectra, X_gaia, y).
    It uses the same balancing strategy as `BalancedMultiLabelDataset`.
    """
    def __init__(self, X_spectra, X_gaia, y, limit_per_label=201):
        """
        Args:
            X_spectra (torch.Tensor): [num_samples, num_spectra_features]
            X_gaia (torch.Tensor): [num_samples, num_gaia_features]
            y (torch.Tensor): [num_samples, num_classes], multi-hot labels
            limit_per_label (int): limit or target number of samples per label
        """
        self.X_spectra = X_spectra
        self.X_gaia = X_gaia
        self.y = y
        self.limit_per_label = limit_per_label
        self.num_classes = y.shape[1]
        self.indices = self.balance_classes()
        
    def balance_classes(self):
        indices = []
        class_counts = torch.sum(self.y, axis=0)
        for cls in range(self.num_classes):
            cls_indices = np.where(self.y[:, cls] == 1)[0]
            if len(cls_indices) < self.limit_per_label:
                if len(cls_indices) == 0:
                    # No samples for this class
                    continue
                extra_indices = np.random.choice(
                    cls_indices, self.limit_per_label - len(cls_indices), replace=True
                )
                cls_indices = np.concatenate([cls_indices, extra_indices])
            elif len(cls_indices) > self.limit_per_label:
                cls_indices = np.random.choice(cls_indices, self.limit_per_label, replace=False)
            indices.extend(cls_indices)
        indices = np.unique(indices)
        np.random.shuffle(indices)
        return indices

    def re_sample(self):
        self.indices = self.balance_classes()

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        index = self.indices[idx]
        return (
            self.X_spectra[index],  # spectra features
            self.X_gaia[index],     # gaia features
            self.y[index],          # multi-hot labels
        )
def calculate_class_weights(y):
    if y.ndim > 1:  
        class_counts = np.sum(y, axis=0)  
    else:
        class_counts = np.bincount(y)

    total_samples = y.shape[0] if y.ndim > 1 else len(y)
    class_counts = np.where(class_counts == 0, 1, class_counts)  # Prevent division by zero
    class_weights = total_samples / (len(class_counts) * class_counts)
    
    return class_weights
def calculate_metrics(y_true, y_pred):
    metrics = {
        "micro_f1": f1_score(y_true, y_pred, average='micro'),
        "macro_f1": f1_score(y_true, y_pred, average='macro'),
        "weighted_f1": f1_score(y_true, y_pred, average='weighted'),
        "micro_precision": precision_score(y_true, y_pred, average='micro', zero_division=1),
        "macro_precision": precision_score(y_true, y_pred, average='macro', zero_division=1),
        "weighted_precision": precision_score(y_true, y_pred, average='weighted', zero_division=1),
        "micro_recall": recall_score(y_true, y_pred, average='micro'),
        "macro_recall": recall_score(y_true, y_pred, average='macro'),
        "weighted_recall": recall_score(y_true, y_pred, average='weighted'),
        "hamming_loss": hamming_loss(y_true, y_pred)
    }
    
    # Check if there are at least two classes present in y_true
    #if len(np.unique(y_true)) > 1:
        #metrics["roc_auc"] = roc_auc_score(y_true, y_pred, average='macro', multi_class='ovr')
    #else:
       # metrics["roc_auc"] = None  # or you can set it to a default value or message
    
    return metrics


In [13]:
batch_size = 128
batch_limit = int(batch_size / 2.5)

# Load datasets
#X_train_full = pd.read_pickle("Pickles/train_data_transformed2.pkl")
#X_test_full = pd.read_pickle("Pickles/test_data_transformed.pkl")
# classes = pd.read_pickle("Pickles/Updated_list_of_Classes.pkl")
import pickle
# Open them in a cross-platform way
with open("Pickles/Updated_List_of_Classes_ubuntu.pkl", "rb") as f:
    classes = pickle.load(f)  # This reads the actual data
with open("Pickles/train_data_transformed_ubuntu.pkl", "rb") as f:
    X_train_full = pickle.load(f)
with open("Pickles/test_data_transformed_ubuntu.pkl", "rb") as f:
    X_test_full = pickle.load(f)




# Extract labels
y_train_full = X_train_full[classes]
y_test = X_test_full[classes]

# Drop labels from both datasets
X_train_full.drop(classes, axis=1, inplace=True)
X_test_full.drop(classes, axis=1, inplace=True)


# Columns for spectral data (assuming all remaining columns after removing Gaia are spectra)
gaia_columns = ["parallax", "ra", "dec", "ra_error", "dec_error", "parallax_error", "pmra", "pmdec", 
                "pmra_error", "pmdec_error", "phot_g_mean_flux", "flagnopllx", "phot_g_mean_flux_error", 
                "phot_bp_mean_flux", "phot_rp_mean_flux", "phot_bp_mean_flux_error", "phot_rp_mean_flux_error", 
                "flagnoflux"]

# Spectra data (everything that is not Gaia-related) and the column 'otype'
X_train_spectra = X_train_full.drop(columns={"otype", "obsid", *gaia_columns})
X_test_spectra = X_test_full.drop(columns={"otype", "obsid", *gaia_columns})

# Gaia data (only the selected columns)
X_train_gaia = X_train_full[gaia_columns]
X_test_gaia = X_test_full[gaia_columns]

# Count nans and infs in x_train_gaia
print(X_train_gaia.isnull().sum())
print(X_train_gaia.isin([np.inf, -np.inf]).sum())


# Free up memory
del X_train_full, X_test_full
gc.collect()



# Split training set into training and validation
X_train_spectra, X_val_spectra, X_train_gaia, X_val_gaia, y_train, y_val = train_test_split(
    X_train_spectra, X_train_gaia, y_train_full, test_size=0.2, random_state=42
)

# Free memory
del y_train_full
gc.collect()



# Convert spectra and Gaia data into PyTorch tensors
X_train_spectra = torch.tensor(X_train_spectra.values, dtype=torch.float32)
X_val_spectra = torch.tensor(X_val_spectra.values, dtype=torch.float32)
X_test_spectra = torch.tensor(X_test_spectra.values, dtype=torch.float32)



X_train_gaia = torch.tensor(X_train_gaia.values, dtype=torch.float32)
X_val_gaia = torch.tensor(X_val_gaia.values, dtype=torch.float32)
X_test_gaia = torch.tensor(X_test_gaia.values, dtype=torch.float32)

y_train = torch.tensor(y_train.values, dtype=torch.float32)
y_val = torch.tensor(y_val.values, dtype=torch.float32)
y_test = torch.tensor(y_test.values, dtype=torch.float32)

# Print dataset shapes
print(f"X_train_spectra shape: {X_train_spectra.shape}")
print(f"X_val_spectra shape: {X_val_spectra.shape}")
print(f"X_test_spectra shape: {X_test_spectra.shape}")

print(f"X_train_gaia shape: {X_train_gaia.shape}")
print(f"X_val_gaia shape: {X_val_gaia.shape}")
print(f"X_test_gaia shape: {X_test_gaia.shape}")

print(f"y_train shape: {y_train.shape}")
print(f"y_val shape: {y_val.shape}")
print(f"y_test shape: {y_test.shape}")


train_dataset = MultiModalBalancedMultiLabelDataset(X_train_spectra, X_train_gaia, y_train, limit_per_label=batch_limit)
val_dataset = MultiModalBalancedMultiLabelDataset(X_val_spectra, X_val_gaia, y_val, limit_per_label=batch_limit)
test_dataset = MultiModalBalancedMultiLabelDataset(X_test_spectra, X_test_gaia, y_test, limit_per_label=batch_limit)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# print the number of samples in each dataset
print(f"Train dataset: {len(train_dataset)} samples")
print(f"Validation dataset: {len(val_dataset)} samples")
print(f"Test dataset: {len(test_dataset)} samples")


parallax                   0
ra                         0
dec                        0
ra_error                   0
dec_error                  0
parallax_error             0
pmra                       0
pmdec                      0
pmra_error                 0
pmdec_error                0
phot_g_mean_flux           0
flagnopllx                 0
phot_g_mean_flux_error     0
phot_bp_mean_flux          0
phot_rp_mean_flux          0
phot_bp_mean_flux_error    0
phot_rp_mean_flux_error    0
flagnoflux                 0
dtype: int64
parallax                   0
ra                         0
dec                        0
ra_error                   0
dec_error                  0
parallax_error             0
pmra                       0
pmdec                      0
pmra_error                 0
pmdec_error                0
phot_g_mean_flux           0
flagnopllx                 0
phot_g_mean_flux_error     0
phot_bp_mean_flux          0
phot_rp_mean_flux          0
phot_bp_mean_flux_error    0
p

In [9]:
# Import your model and the ultra memory-efficient ensemble
#from ultra_memory_efficient import MemoryEfficientStarClassifier, UltraMemoryEfficientEnsemble
#from your_dataset_file import MultiModalBalancedMultiLabelDataset  # Import your dataset class

# Function to track memory usage
def print_gpu_memory():
    if torch.cuda.is_available():
        print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
        print(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
        # Force garbage collection
        gc.collect()
        torch.cuda.empty_cache()

# Example configuration for high-dimensional embeddings
CONFIG = {
    "d_model_spectra": 2048,  # Higher embedding dimension
    "d_model_gaia": 2048,     # Higher embedding dimension
    "num_classes": 55,
    "input_dim_spectra": 3647,
    "input_dim_gaia": 18,
    "n_layers": 12,
    "d_state": 8,            # Reduced state dimension to save memory
    "d_conv": 4,
    "expand": 2,
    "use_cross_attention": True,
    "n_cross_attn_heads": 8,
    "use_checkpoint": True,
    "activation_checkpointing": True,
    "use_half_precision": True,
    "sequential_processing": True
}

if __name__ == "__main__":
    # Set device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")
    
    # Print initial memory usage
    print("\nInitial memory usage:")
    print_gpu_memory()
    
    # Load your dataset
    # Replace this with your actual dataset loading code
    batch_size = 16  # Use smaller batch size to save memory
    batch_limit = int(batch_size / 2.5)
    
    # Load datasets (replace with your actual loading code)
    print("Loading datasets...")
    train_dataset = MultiModalBalancedMultiLabelDataset(
        X_train_spectra, X_train_gaia, y_train, limit_per_label=batch_limit
    )
    val_dataset = MultiModalBalancedMultiLabelDataset(
        X_val_spectra, X_val_gaia, y_val, limit_per_label=batch_limit
    )
    test_dataset = MultiModalBalancedMultiLabelDataset(
        X_test_spectra, X_test_gaia, y_test, limit_per_label=batch_limit
    )
    
    # Create data loaders with smaller batches
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    
    # Create ultra memory-efficient ensemble (no models are created yet)
    print("\nCreating ultra memory-efficient ensemble...")
    ensemble = UltraMemoryEfficientEnsemble(
        model_class=MemoryEfficientStarClassifier,
        model_args=CONFIG,
        num_models=5,
        device=device,
        checkpoint_dir='ultra_memory_efficient_models'
    )
    
    # Print memory usage after creating ensemble (should be minimal)
    print("\nMemory usage after creating ensemble (no models created yet):")
    print_gpu_memory()
    
    # ----------------- Option 1: Train the entire ensemble -----------------
    # Uncomment this section to train all models in the ensemble
    
    # print("\nTraining all models in the ensemble...")
    # model_paths = ensemble.train(
    #     train_loader=train_loader,
    #     val_loader=val_loader,
    #     test_loader=test_loader,
    #     num_epochs=100,
    #     lr=1e-4,
    #     max_patience=20,
    #     scheduler_type='OneCycleLR',
    #     batch_accumulation=4,  # Accumulate gradients over 4 batches (effectively 4x batch size)
    #     log_to_wandb=True
    # )
    
    # ----------------- Option 2: Train one model at a time -----------------
    # For more control, train models one at a time
    
    print("\nTraining a single model from the ensemble...")
    model_idx = 0  # Train the first model
    
    model = ensemble.train_single_model(
        model_idx=model_idx,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        num_epochs=100,
        lr=1e-4,
        max_patience=20,
        scheduler_type='OneCycleLR',
        batch_accumulation=4,  # Accumulate gradients over 4 batches to simulate larger batch
        log_to_wandb=True
    )
    
    # Free up memory
    del model
    torch.cuda.empty_cache()
    gc.collect()
    
    print("\nMemory usage after training and cleanup:")
    print_gpu_memory()
    
    # ----------------- Making predictions with the ensemble -----------------
    
    # Setup for smaller batch inference
    inference_batch_size = 8  # Use a very small batch size for inference
    test_loader_small = DataLoader(
        test_dataset, 
        batch_size=inference_batch_size, 
        shuffle=False,
        num_workers=2
    )
    
    print("\nMaking predictions with the trained model(s)...")
    
    # Use micro-batch size of 1 for minimal memory usage
    mean_probs, std_probs = ensemble.predict(
        test_loader_small, 
        return_individual=False,
        micro_batch_size=1
    )
    
    print("\nMemory usage after prediction:")
    print_gpu_memory()
    
    print("\nPrediction shape:", mean_probs.shape)
    print("Uncertainty shape:", std_probs.shape)
    
    # ----------------- Analyze results -----------------
    
    # Convert to binary predictions
    threshold = 0.5
    predictions = (mean_probs >= threshold).astype(float)
    
    # Calculate overall accuracy
    # Assuming y_test is available and has the same order as predictions
    accuracy = np.mean(np.equal(predictions, y_test.numpy()).astype(float))
    print(f"\nOverall accuracy: {accuracy:.4f}")
    
    # Identify high uncertainty predictions
    high_uncertainty_threshold = np.percentile(std_probs, 90)  # Top 10% most uncertain
    high_uncertainty_mask = std_probs >= high_uncertainty_threshold
    
    high_uncertainty_count = np.sum(high_uncertainty_mask)
    print(f"\nHigh uncertainty predictions: {high_uncertainty_count} ({high_uncertainty_count/std_probs.size:.2%})")
    
    # Check if high uncertainty correlates with errors
    high_uncertainty_errors = np.mean(np.not_equal(
        predictions[high_uncertainty_mask], 
        y_test.numpy()[high_uncertainty_mask]
    ).astype(float))
    
    normal_uncertainty_errors = np.mean(np.not_equal(
        predictions[~high_uncertainty_mask], 
        y_test.numpy()[~high_uncertainty_mask]
    ).astype(float))
    
    print(f"Error rate for high uncertainty predictions: {high_uncertainty_errors:.4f}")
    print(f"Error rate for normal predictions: {normal_uncertainty_errors:.4f}")
    print(f"Ratio: {high_uncertainty_errors/normal_uncertainty_errors:.2f}x")
    
    print("\nDone!")

Using device: cuda

Initial memory usage:
GPU memory allocated: 0.00 MB
GPU memory reserved: 0.00 MB
Loading datasets...

Creating ultra memory-efficient ensemble...

Memory usage after creating ensemble (no models created yet):
GPU memory allocated: 0.00 MB
GPU memory reserved: 0.00 MB

Training a single model from the ensemble...


[34m[1mwandb[0m: Currently logged in as: [33mjoaocsgalmeida[0m ([33mjoaocsgalmeida-university-of-southampton[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


  scaler = torch.cuda.amp.GradScaler() if hasattr(torch.cuda, 'amp') else None
  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 1/100 - Training: 100%|██████████| 18/18 [00:22<00:00,  1.25s/it]
  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 1/100 - Validation: 100%|██████████| 15/15 [00:01<00:00, 11.31it/s]
  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 1/100 - Testing: 100%|██████████| 17/17 [00:01<00:00,  8.66it/s]


Epoch 1/100 - Train Loss: 0.7293, Train Acc: 0.4869, Val Loss: 0.7252, Val Acc: 0.4911


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 2/100 - Training: 100%|██████████| 18/18 [00:03<00:00,  4.64it/s]
  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 2/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 19.79it/s]
  with torch.cuda.amp.autocast():
Epoch 2/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 20.69it/s]


Epoch 2/100 - Train Loss: 0.7249, Train Acc: 0.4922, Val Loss: 0.7100, Val Acc: 0.5112


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 3/100 - Training: 100%|██████████| 18/18 [00:02<00:00,  8.09it/s]
  with torch.cuda.amp.autocast():
Epoch 3/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 18.54it/s]
  with torch.cuda.amp.autocast():
Epoch 3/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 19.60it/s]


Epoch 3/100 - Train Loss: 0.7018, Train Acc: 0.5230, Val Loss: 0.6836, Val Acc: 0.5504


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 4/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.93it/s]
  with torch.cuda.amp.autocast():
Epoch 4/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 19.90it/s]
  with torch.cuda.amp.autocast():
Epoch 4/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 20.13it/s]


Epoch 4/100 - Train Loss: 0.6641, Train Acc: 0.5836, Val Loss: 0.6438, Val Acc: 0.6073


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 5/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  8.16it/s]
  with torch.cuda.amp.autocast():
Epoch 5/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 18.54it/s]
  with torch.cuda.amp.autocast():
Epoch 5/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 19.84it/s]


Epoch 5/100 - Train Loss: 0.6261, Train Acc: 0.6333, Val Loss: 0.5923, Val Acc: 0.6835


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 6/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.96it/s]
  with torch.cuda.amp.autocast():
Epoch 6/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 17.71it/s]
  with torch.cuda.amp.autocast():
Epoch 6/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 18.47it/s]


Epoch 6/100 - Train Loss: 0.5697, Train Acc: 0.7186, Val Loss: 0.5342, Val Acc: 0.7650


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 7/100 - Training: 100%|██████████| 19/19 [00:05<00:00,  3.49it/s]
  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 7/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 20.24it/s]
  with torch.cuda.amp.autocast():
Epoch 7/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 20.18it/s]


Epoch 7/100 - Train Loss: 0.5114, Train Acc: 0.7950, Val Loss: 0.4745, Val Acc: 0.8387


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 8/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.99it/s]
  with torch.cuda.amp.autocast():
Epoch 8/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 18.14it/s]
  with torch.cuda.amp.autocast():
Epoch 8/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 18.92it/s]


Epoch 8/100 - Train Loss: 0.4507, Train Acc: 0.8609, Val Loss: 0.4181, Val Acc: 0.8923


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 9/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  8.20it/s]
  with torch.cuda.amp.autocast():
Epoch 9/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 17.63it/s]
  with torch.cuda.amp.autocast():
Epoch 9/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 17.83it/s]


Epoch 9/100 - Train Loss: 0.3944, Train Acc: 0.9094, Val Loss: 0.3678, Val Acc: 0.9288


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 10/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  8.15it/s]
  with torch.cuda.amp.autocast():
Epoch 10/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 17.76it/s]
  with torch.cuda.amp.autocast():
Epoch 10/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 18.43it/s]


Epoch 10/100 - Train Loss: 0.3491, Train Acc: 0.9367, Val Loss: 0.3243, Val Acc: 0.9484


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 11/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.76it/s]
  with torch.cuda.amp.autocast():
Epoch 11/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 16.73it/s]
  with torch.cuda.amp.autocast():
Epoch 11/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 18.28it/s]


Epoch 11/100 - Train Loss: 0.3075, Train Acc: 0.9490, Val Loss: 0.2875, Val Acc: 0.9572


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 12/100 - Training: 100%|██████████| 18/18 [00:02<00:00,  7.36it/s]
  with torch.cuda.amp.autocast():
Epoch 12/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 17.34it/s]
  with torch.cuda.amp.autocast():
Epoch 12/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 18.51it/s]


Epoch 12/100 - Train Loss: 0.2789, Train Acc: 0.9571, Val Loss: 0.2575, Val Acc: 0.9607


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 13/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.53it/s]
  with torch.cuda.amp.autocast():
Epoch 13/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 16.91it/s]
  with torch.cuda.amp.autocast():
Epoch 13/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 17.50it/s]


Epoch 13/100 - Train Loss: 0.2492, Train Acc: 0.9605, Val Loss: 0.2331, Val Acc: 0.9625


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 14/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.89it/s]
  with torch.cuda.amp.autocast():
Epoch 14/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 15.93it/s]
  with torch.cuda.amp.autocast():
Epoch 14/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 16.40it/s]


Epoch 14/100 - Train Loss: 0.2210, Train Acc: 0.9616, Val Loss: 0.2127, Val Acc: 0.9629


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 15/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.74it/s]
  with torch.cuda.amp.autocast():
Epoch 15/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 15.43it/s]
  with torch.cuda.amp.autocast():
Epoch 15/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 16.06it/s]


Epoch 15/100 - Train Loss: 0.2046, Train Acc: 0.9621, Val Loss: 0.1968, Val Acc: 0.9631


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 16/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.41it/s]
  with torch.cuda.amp.autocast():
Epoch 16/100 - Validation: 100%|██████████| 15/15 [00:01<00:00, 14.34it/s]
  with torch.cuda.amp.autocast():
Epoch 16/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 15.66it/s]


Epoch 16/100 - Train Loss: 0.1910, Train Acc: 0.9622, Val Loss: 0.1836, Val Acc: 0.9634


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 17/100 - Training: 100%|██████████| 18/18 [00:02<00:00,  7.26it/s]
  with torch.cuda.amp.autocast():
Epoch 17/100 - Validation: 100%|██████████| 15/15 [00:01<00:00, 14.53it/s]
  with torch.cuda.amp.autocast():
Epoch 17/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 14.24it/s]


Epoch 17/100 - Train Loss: 0.1792, Train Acc: 0.9620, Val Loss: 0.1726, Val Acc: 0.9636


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 18/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.03it/s]
  with torch.cuda.amp.autocast():
Epoch 18/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 18.81it/s]
  with torch.cuda.amp.autocast():
Epoch 18/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 19.16it/s]


Epoch 18/100 - Train Loss: 0.1711, Train Acc: 0.9616, Val Loss: 0.1640, Val Acc: 0.9636


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 19/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  8.04it/s]
  with torch.cuda.amp.autocast():
Epoch 19/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 16.83it/s]
  with torch.cuda.amp.autocast():
Epoch 19/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 17.45it/s]


Epoch 19/100 - Train Loss: 0.1623, Train Acc: 0.9622, Val Loss: 0.1562, Val Acc: 0.9636


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 20/100 - Training: 100%|██████████| 18/18 [00:02<00:00,  7.79it/s]
  with torch.cuda.amp.autocast():
Epoch 20/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 16.60it/s]
  with torch.cuda.amp.autocast():
Epoch 20/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 17.70it/s]


Epoch 20/100 - Train Loss: 0.1512, Train Acc: 0.9622, Val Loss: 0.1500, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 21/100 - Training: 100%|██████████| 18/18 [00:02<00:00,  7.71it/s]
  with torch.cuda.amp.autocast():
Epoch 21/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 15.41it/s]
  with torch.cuda.amp.autocast():
Epoch 21/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 15.42it/s]


Epoch 21/100 - Train Loss: 0.1472, Train Acc: 0.9626, Val Loss: 0.1448, Val Acc: 0.9638


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 22/100 - Training: 100%|██████████| 18/18 [00:02<00:00,  7.48it/s]
  with torch.cuda.amp.autocast():
Epoch 22/100 - Validation: 100%|██████████| 15/15 [00:01<00:00, 14.93it/s]
  with torch.cuda.amp.autocast():
Epoch 22/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 15.09it/s]


Epoch 22/100 - Train Loss: 0.1457, Train Acc: 0.9620, Val Loss: 0.1398, Val Acc: 0.9638


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 23/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.30it/s]
  with torch.cuda.amp.autocast():
Epoch 23/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 17.60it/s]
  with torch.cuda.amp.autocast():
Epoch 23/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 18.82it/s]


Epoch 23/100 - Train Loss: 0.1383, Train Acc: 0.9624, Val Loss: 0.1362, Val Acc: 0.9638


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 24/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  8.11it/s]
  with torch.cuda.amp.autocast():
Epoch 24/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 17.61it/s]
  with torch.cuda.amp.autocast():
Epoch 24/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 18.36it/s]


Epoch 24/100 - Train Loss: 0.1375, Train Acc: 0.9623, Val Loss: 0.1318, Val Acc: 0.9638


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 25/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  8.08it/s]
  with torch.cuda.amp.autocast():
Epoch 25/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 17.59it/s]
  with torch.cuda.amp.autocast():
Epoch 25/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 18.25it/s]


Epoch 25/100 - Train Loss: 0.1326, Train Acc: 0.9623, Val Loss: 0.1287, Val Acc: 0.9638


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 26/100 - Training: 100%|██████████| 18/18 [00:02<00:00,  7.87it/s]
  with torch.cuda.amp.autocast():
Epoch 26/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 16.21it/s]
  with torch.cuda.amp.autocast():
Epoch 26/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 17.44it/s]


Epoch 26/100 - Train Loss: 0.1272, Train Acc: 0.9621, Val Loss: 0.1256, Val Acc: 0.9638


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 27/100 - Training: 100%|██████████| 18/18 [00:02<00:00,  7.81it/s]
  with torch.cuda.amp.autocast():
Epoch 27/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 15.56it/s]
  with torch.cuda.amp.autocast():
Epoch 27/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 16.37it/s]


Epoch 27/100 - Train Loss: 0.1259, Train Acc: 0.9624, Val Loss: 0.1234, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 28/100 - Training: 100%|██████████| 18/18 [00:02<00:00,  8.19it/s]
  with torch.cuda.amp.autocast():
Epoch 28/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 16.53it/s]
  with torch.cuda.amp.autocast():
Epoch 28/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 16.77it/s]


Epoch 28/100 - Train Loss: 0.1231, Train Acc: 0.9622, Val Loss: 0.1208, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 29/100 - Training: 100%|██████████| 18/18 [00:02<00:00,  7.56it/s]
  with torch.cuda.amp.autocast():
Epoch 29/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 15.05it/s]
  with torch.cuda.amp.autocast():
Epoch 29/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 15.74it/s]


Epoch 29/100 - Train Loss: 0.1210, Train Acc: 0.9625, Val Loss: 0.1192, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 30/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.47it/s]
  with torch.cuda.amp.autocast():
Epoch 30/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 17.48it/s]
  with torch.cuda.amp.autocast():
Epoch 30/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 18.16it/s]


Epoch 30/100 - Train Loss: 0.1195, Train Acc: 0.9624, Val Loss: 0.1168, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 31/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  8.01it/s]
  with torch.cuda.amp.autocast():
Epoch 31/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 16.31it/s]
  with torch.cuda.amp.autocast():
Epoch 31/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 16.91it/s]


Epoch 31/100 - Train Loss: 0.1156, Train Acc: 0.9624, Val Loss: 0.1150, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 32/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.85it/s]
  with torch.cuda.amp.autocast():
Epoch 32/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 15.85it/s]
  with torch.cuda.amp.autocast():
Epoch 32/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 16.38it/s]


Epoch 32/100 - Train Loss: 0.1155, Train Acc: 0.9623, Val Loss: 0.1133, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 33/100 - Training: 100%|██████████| 18/18 [00:02<00:00,  7.69it/s]
  with torch.cuda.amp.autocast():
Epoch 33/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 15.47it/s]
  with torch.cuda.amp.autocast():
Epoch 33/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 16.03it/s]


Epoch 33/100 - Train Loss: 0.1142, Train Acc: 0.9624, Val Loss: 0.1122, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 34/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.73it/s]
  with torch.cuda.amp.autocast():
Epoch 34/100 - Validation: 100%|██████████| 15/15 [00:01<00:00, 14.95it/s]
  with torch.cuda.amp.autocast():
Epoch 34/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 15.23it/s]


Epoch 34/100 - Train Loss: 0.1117, Train Acc: 0.9624, Val Loss: 0.1103, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 35/100 - Training: 100%|██████████| 18/18 [00:02<00:00,  7.21it/s]
  with torch.cuda.amp.autocast():
Epoch 35/100 - Validation: 100%|██████████| 15/15 [00:01<00:00, 13.72it/s]
  with torch.cuda.amp.autocast():
Epoch 35/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 15.37it/s]


Epoch 35/100 - Train Loss: 0.1106, Train Acc: 0.9624, Val Loss: 0.1086, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 36/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.07it/s]
  with torch.cuda.amp.autocast():
Epoch 36/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 18.20it/s]
  with torch.cuda.amp.autocast():
Epoch 36/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 18.77it/s]


Epoch 36/100 - Train Loss: 0.1102, Train Acc: 0.9618, Val Loss: 0.1075, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 37/100 - Training: 100%|██████████| 18/18 [00:02<00:00,  8.05it/s]
  with torch.cuda.amp.autocast():
Epoch 37/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 16.25it/s]
  with torch.cuda.amp.autocast():
Epoch 37/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 17.49it/s]


Epoch 37/100 - Train Loss: 0.1077, Train Acc: 0.9622, Val Loss: 0.1061, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 38/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.70it/s]
  with torch.cuda.amp.autocast():
Epoch 38/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 15.41it/s]
  with torch.cuda.amp.autocast():
Epoch 38/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 16.63it/s]

Epoch 38/100 - Train Loss: 0.1068, Train Acc: 0.9624, Val Loss: 0.1061, Val Acc: 0.9637



  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 39/100 - Training: 100%|██████████| 18/18 [00:02<00:00,  7.55it/s]
  with torch.cuda.amp.autocast():
Epoch 39/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 15.59it/s]
  with torch.cuda.amp.autocast():
Epoch 39/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 16.01it/s]


Epoch 39/100 - Train Loss: 0.1062, Train Acc: 0.9626, Val Loss: 0.1043, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 40/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.84it/s]
  with torch.cuda.amp.autocast():
Epoch 40/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 15.87it/s]
  with torch.cuda.amp.autocast():
Epoch 40/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 16.71it/s]


Epoch 40/100 - Train Loss: 0.1041, Train Acc: 0.9621, Val Loss: 0.1032, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 41/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.90it/s]
  with torch.cuda.amp.autocast():
Epoch 41/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 15.20it/s]
  with torch.cuda.amp.autocast():
Epoch 41/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 16.70it/s]


Epoch 41/100 - Train Loss: 0.1043, Train Acc: 0.9620, Val Loss: 0.1021, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 42/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.47it/s]
  with torch.cuda.amp.autocast():
Epoch 42/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 15.78it/s]
  with torch.cuda.amp.autocast():
Epoch 42/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 16.75it/s]


Epoch 42/100 - Train Loss: 0.1041, Train Acc: 0.9623, Val Loss: 0.1019, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 43/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.96it/s]
  with torch.cuda.amp.autocast():
Epoch 43/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 17.59it/s]
  with torch.cuda.amp.autocast():
Epoch 43/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 18.15it/s]


Epoch 43/100 - Train Loss: 0.1026, Train Acc: 0.9621, Val Loss: 0.1008, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 44/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  8.01it/s]
  with torch.cuda.amp.autocast():
Epoch 44/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 18.44it/s]
  with torch.cuda.amp.autocast():
Epoch 44/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 18.81it/s]


Epoch 44/100 - Train Loss: 0.1009, Train Acc: 0.9623, Val Loss: 0.1005, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 45/100 - Training: 100%|██████████| 18/18 [00:02<00:00,  7.84it/s]
  with torch.cuda.amp.autocast():
Epoch 45/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 17.48it/s]
  with torch.cuda.amp.autocast():
Epoch 45/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 17.38it/s]


Epoch 45/100 - Train Loss: 0.1001, Train Acc: 0.9624, Val Loss: 0.0996, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 46/100 - Training: 100%|██████████| 18/18 [00:02<00:00,  7.83it/s]
  with torch.cuda.amp.autocast():
Epoch 46/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 16.12it/s]
  with torch.cuda.amp.autocast():
Epoch 46/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 17.16it/s]


Epoch 46/100 - Train Loss: 0.0998, Train Acc: 0.9623, Val Loss: 0.0983, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 47/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.99it/s]
  with torch.cuda.amp.autocast():
Epoch 47/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 15.97it/s]
  with torch.cuda.amp.autocast():
Epoch 47/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 16.35it/s]

Epoch 47/100 - Train Loss: 0.0996, Train Acc: 0.9626, Val Loss: 0.0987, Val Acc: 0.9637



  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 48/100 - Training: 100%|██████████| 18/18 [00:02<00:00,  7.81it/s]
  with torch.cuda.amp.autocast():
Epoch 48/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 15.98it/s]
  with torch.cuda.amp.autocast():
Epoch 48/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 16.40it/s]


Epoch 48/100 - Train Loss: 0.0988, Train Acc: 0.9625, Val Loss: 0.0980, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 49/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.86it/s]
  with torch.cuda.amp.autocast():
Epoch 49/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 15.26it/s]
  with torch.cuda.amp.autocast():
Epoch 49/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 15.53it/s]


Epoch 49/100 - Train Loss: 0.1006, Train Acc: 0.9624, Val Loss: 0.0973, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 50/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.58it/s]
  with torch.cuda.amp.autocast():
Epoch 50/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 17.81it/s]
  with torch.cuda.amp.autocast():
Epoch 50/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 18.28it/s]


Epoch 50/100 - Train Loss: 0.0979, Train Acc: 0.9621, Val Loss: 0.0967, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 51/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  8.15it/s]
  with torch.cuda.amp.autocast():
Epoch 51/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 16.43it/s]
  with torch.cuda.amp.autocast():
Epoch 51/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 16.82it/s]

Epoch 51/100 - Train Loss: 0.0975, Train Acc: 0.9624, Val Loss: 0.0967, Val Acc: 0.9637



  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 52/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  8.11it/s]
  with torch.cuda.amp.autocast():
Epoch 52/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 16.57it/s]
  with torch.cuda.amp.autocast():
Epoch 52/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 17.35it/s]


Epoch 52/100 - Train Loss: 0.0969, Train Acc: 0.9624, Val Loss: 0.0958, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 53/100 - Training: 100%|██████████| 18/18 [00:02<00:00,  7.81it/s]
  with torch.cuda.amp.autocast():
Epoch 53/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 15.63it/s]
  with torch.cuda.amp.autocast():
Epoch 53/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 16.35it/s]

Epoch 53/100 - Train Loss: 0.0975, Train Acc: 0.9626, Val Loss: 0.0958, Val Acc: 0.9637



  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 54/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  8.05it/s]
  with torch.cuda.amp.autocast():
Epoch 54/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 15.65it/s]
  with torch.cuda.amp.autocast():
Epoch 54/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 16.43it/s]


Epoch 54/100 - Train Loss: 0.0959, Train Acc: 0.9620, Val Loss: 0.0948, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 55/100 - Training: 100%|██████████| 18/18 [00:02<00:00,  7.38it/s]
  with torch.cuda.amp.autocast():
Epoch 55/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 15.32it/s]
  with torch.cuda.amp.autocast():
Epoch 55/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 15.98it/s]


Epoch 55/100 - Train Loss: 0.0956, Train Acc: 0.9624, Val Loss: 0.0947, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 56/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.45it/s]
  with torch.cuda.amp.autocast():
Epoch 56/100 - Validation: 100%|██████████| 15/15 [00:01<00:00, 13.89it/s]
  with torch.cuda.amp.autocast():
Epoch 56/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 15.64it/s]


Epoch 56/100 - Train Loss: 0.0949, Train Acc: 0.9621, Val Loss: 0.0942, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 57/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.84it/s]
  with torch.cuda.amp.autocast():
Epoch 57/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 16.50it/s]
  with torch.cuda.amp.autocast():
Epoch 57/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 17.25it/s]


Epoch 57/100 - Train Loss: 0.0954, Train Acc: 0.9626, Val Loss: 0.0948, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 58/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  8.00it/s]
  with torch.cuda.amp.autocast():
Epoch 58/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 16.92it/s]
  with torch.cuda.amp.autocast():
Epoch 58/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 17.83it/s]


Epoch 58/100 - Train Loss: 0.0946, Train Acc: 0.9621, Val Loss: 0.0944, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 59/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  8.25it/s]
  with torch.cuda.amp.autocast():
Epoch 59/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 17.09it/s]
  with torch.cuda.amp.autocast():
Epoch 59/100 - Testing: 100%|██████████| 17/17 [00:00<00:00, 17.58it/s]


Epoch 59/100 - Train Loss: 0.0956, Train Acc: 0.9621, Val Loss: 0.0937, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 60/100 - Training: 100%|██████████| 18/18 [00:02<00:00,  7.90it/s]
  with torch.cuda.amp.autocast():
Epoch 60/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 16.93it/s]
  with torch.cuda.amp.autocast():
Epoch 60/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 16.30it/s]


Epoch 60/100 - Train Loss: 0.0943, Train Acc: 0.9623, Val Loss: 0.0928, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 61/100 - Training: 100%|██████████| 18/18 [00:02<00:00,  7.73it/s]
  with torch.cuda.amp.autocast():
Epoch 61/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 15.42it/s]
  with torch.cuda.amp.autocast():
Epoch 61/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 16.09it/s]

Epoch 61/100 - Train Loss: 0.0943, Train Acc: 0.9623, Val Loss: 0.0928, Val Acc: 0.9637



  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 62/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.66it/s]
  with torch.cuda.amp.autocast():
Epoch 62/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 15.62it/s]
  with torch.cuda.amp.autocast():
Epoch 62/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 16.71it/s]


Epoch 62/100 - Train Loss: 0.0938, Train Acc: 0.9623, Val Loss: 0.0926, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 63/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.72it/s]
  with torch.cuda.amp.autocast():
Epoch 63/100 - Validation: 100%|██████████| 15/15 [00:01<00:00, 14.56it/s]
  with torch.cuda.amp.autocast():
Epoch 63/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 15.48it/s]

Epoch 63/100 - Train Loss: 0.0938, Train Acc: 0.9626, Val Loss: 0.0928, Val Acc: 0.9637



  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 64/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.65it/s]
  with torch.cuda.amp.autocast():
Epoch 64/100 - Validation: 100%|██████████| 15/15 [00:01<00:00, 14.35it/s]
  with torch.cuda.amp.autocast():
Epoch 64/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 15.43it/s]


Epoch 64/100 - Train Loss: 0.0935, Train Acc: 0.9621, Val Loss: 0.0921, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 65/100 - Training: 100%|██████████| 18/18 [00:02<00:00,  7.65it/s]
  with torch.cuda.amp.autocast():
Epoch 65/100 - Validation: 100%|██████████| 15/15 [00:00<00:00, 15.12it/s]
  with torch.cuda.amp.autocast():
Epoch 65/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 15.96it/s]


Epoch 65/100 - Train Loss: 0.0937, Train Acc: 0.9621, Val Loss: 0.0918, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 66/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.63it/s]
  with torch.cuda.amp.autocast():
Epoch 66/100 - Validation: 100%|██████████| 15/15 [00:01<00:00, 14.53it/s]
  with torch.cuda.amp.autocast():
Epoch 66/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 15.30it/s]

Epoch 66/100 - Train Loss: 0.0927, Train Acc: 0.9623, Val Loss: 0.0921, Val Acc: 0.9637



  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 67/100 - Training: 100%|██████████| 18/18 [00:02<00:00,  7.17it/s]
  with torch.cuda.amp.autocast():
Epoch 67/100 - Validation: 100%|██████████| 15/15 [00:01<00:00, 13.95it/s]
  with torch.cuda.amp.autocast():
Epoch 67/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 14.73it/s]

Epoch 67/100 - Train Loss: 0.0929, Train Acc: 0.9625, Val Loss: 0.0920, Val Acc: 0.9637



  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 68/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.51it/s]
  with torch.cuda.amp.autocast():
Epoch 68/100 - Validation: 100%|██████████| 15/15 [00:01<00:00, 13.76it/s]
  with torch.cuda.amp.autocast():
Epoch 68/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 14.86it/s]


Epoch 68/100 - Train Loss: 0.0936, Train Acc: 0.9622, Val Loss: 0.0918, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 69/100 - Training: 100%|██████████| 18/18 [00:02<00:00,  6.96it/s]
  with torch.cuda.amp.autocast():
Epoch 69/100 - Validation: 100%|██████████| 15/15 [00:01<00:00, 14.46it/s]
  with torch.cuda.amp.autocast():
Epoch 69/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 14.91it/s]

Epoch 69/100 - Train Loss: 0.0936, Train Acc: 0.9625, Val Loss: 0.0920, Val Acc: 0.9637



  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 70/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.48it/s]
  with torch.cuda.amp.autocast():
Epoch 70/100 - Validation: 100%|██████████| 15/15 [00:01<00:00, 14.24it/s]
  with torch.cuda.amp.autocast():
Epoch 70/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 15.40it/s]


Epoch 70/100 - Train Loss: 0.0927, Train Acc: 0.9620, Val Loss: 0.0914, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 71/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.38it/s]
  with torch.cuda.amp.autocast():
Epoch 71/100 - Validation: 100%|██████████| 15/15 [00:01<00:00, 14.08it/s]
  with torch.cuda.amp.autocast():
Epoch 71/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 15.33it/s]

Epoch 71/100 - Train Loss: 0.0923, Train Acc: 0.9621, Val Loss: 0.0915, Val Acc: 0.9637



  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 72/100 - Training: 100%|██████████| 18/18 [00:02<00:00,  7.16it/s]
  with torch.cuda.amp.autocast():
Epoch 72/100 - Validation: 100%|██████████| 15/15 [00:01<00:00, 13.58it/s]
  with torch.cuda.amp.autocast():
Epoch 72/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 14.55it/s]

Epoch 72/100 - Train Loss: 0.0918, Train Acc: 0.9622, Val Loss: 0.0915, Val Acc: 0.9637



  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 73/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  6.89it/s]
  with torch.cuda.amp.autocast():
Epoch 73/100 - Validation: 100%|██████████| 15/15 [00:01<00:00, 14.50it/s]
  with torch.cuda.amp.autocast():
Epoch 73/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 14.78it/s]


Epoch 73/100 - Train Loss: 0.0929, Train Acc: 0.9622, Val Loss: 0.0911, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 74/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.35it/s]
  with torch.cuda.amp.autocast():
Epoch 74/100 - Validation: 100%|██████████| 15/15 [00:01<00:00, 14.39it/s]
  with torch.cuda.amp.autocast():
Epoch 74/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 15.46it/s]


Epoch 74/100 - Train Loss: 0.0922, Train Acc: 0.9623, Val Loss: 0.0909, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 75/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.19it/s]
  with torch.cuda.amp.autocast():
Epoch 75/100 - Validation: 100%|██████████| 15/15 [00:01<00:00, 13.94it/s]
  with torch.cuda.amp.autocast():
Epoch 75/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 14.49it/s]

Epoch 75/100 - Train Loss: 0.0920, Train Acc: 0.9622, Val Loss: 0.0914, Val Acc: 0.9637



  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 76/100 - Training: 100%|██████████| 18/18 [00:02<00:00,  7.16it/s]
  with torch.cuda.amp.autocast():
Epoch 76/100 - Validation: 100%|██████████| 15/15 [00:01<00:00, 14.22it/s]
  with torch.cuda.amp.autocast():
Epoch 76/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 14.54it/s]

Epoch 76/100 - Train Loss: 0.0925, Train Acc: 0.9624, Val Loss: 0.0912, Val Acc: 0.9637



  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 77/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.69it/s]
  with torch.cuda.amp.autocast():
Epoch 77/100 - Validation: 100%|██████████| 15/15 [00:01<00:00, 13.89it/s]
  with torch.cuda.amp.autocast():
Epoch 77/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 14.52it/s]

Epoch 77/100 - Train Loss: 0.0924, Train Acc: 0.9623, Val Loss: 0.0918, Val Acc: 0.9637



  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 78/100 - Training: 100%|██████████| 18/18 [00:02<00:00,  7.02it/s]
  with torch.cuda.amp.autocast():
Epoch 78/100 - Validation: 100%|██████████| 15/15 [00:01<00:00, 14.52it/s]
  with torch.cuda.amp.autocast():
Epoch 78/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 14.73it/s]


Epoch 78/100 - Train Loss: 0.0921, Train Acc: 0.9624, Val Loss: 0.0909, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 79/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.45it/s]
  with torch.cuda.amp.autocast():
Epoch 79/100 - Validation: 100%|██████████| 15/15 [00:01<00:00, 13.77it/s]
  with torch.cuda.amp.autocast():
Epoch 79/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 14.71it/s]

Epoch 79/100 - Train Loss: 0.0923, Train Acc: 0.9625, Val Loss: 0.0917, Val Acc: 0.9637



  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 80/100 - Training: 100%|██████████| 19/19 [00:02<00:00,  7.55it/s]
  with torch.cuda.amp.autocast():
Epoch 80/100 - Validation: 100%|██████████| 15/15 [00:01<00:00, 13.70it/s]
  with torch.cuda.amp.autocast():
Epoch 80/100 - Testing: 100%|██████████| 17/17 [00:01<00:00, 14.31it/s]


Epoch 80/100 - Train Loss: 0.0933, Train Acc: 0.9621, Val Loss: 0.0911, Val Acc: 0.9637


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)
Epoch 81/100 - Training:  16%|█▌        | 3/19 [00:01<00:05,  2.93it/s]


ValueError: Tried to step 401 times. The specified number of total steps is 400

# Assymetric Gaia and Lamost

In [4]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.checkpoint import checkpoint
from mamba_ssm import Mamba2

class AsymmetricMemoryEfficientStarClassifier(nn.Module):
    """
    Memory-efficient version of StarClassifierFusion with asymmetric dimensions
    for spectral and Gaia data, allowing for much smaller Gaia embeddings.
    """
    def __init__(
        self,
        d_model_spectra,
        d_model_gaia,
        num_classes,
        input_dim_spectra,
        input_dim_gaia,
        n_layers=6,
        use_cross_attention=True,
        n_cross_attn_heads=8,
        d_state_spectra=16,
        d_state_gaia=8,  # Can be smaller for Gaia
        d_conv=4,
        expand=2,
        use_checkpoint=True,
        activation_checkpointing=True,
        use_half_precision=True,
        sequential_processing=True
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.activation_checkpointing = activation_checkpointing
        self.sequential_processing = sequential_processing
        
        # Store dimensions for later use
        self.d_model_spectra = d_model_spectra
        self.d_model_gaia = d_model_gaia
        
        # Use lower precision
        self.dtype = torch.float16 if use_half_precision else torch.float32

        # Input projection layers - project inputs to their respective embedding spaces
        self.input_proj_spectra = nn.Linear(input_dim_spectra, d_model_spectra)
        self.input_proj_gaia = nn.Linear(input_dim_gaia, d_model_gaia)
        
        # Memory-efficient Mamba layers for Spectra (higher dimension)
        self.mamba_spectra_layers = nn.ModuleList([
            self._create_mamba_layer(
                d_model=d_model_spectra,
                d_state=d_state_spectra,
                d_conv=d_conv,
                expand=expand
            ) for _ in range(n_layers)
        ])
        
        # Memory-efficient Mamba layers for Gaia (lower dimension)
        self.mamba_gaia_layers = nn.ModuleList([
            self._create_mamba_layer(
                d_model=d_model_gaia,
                d_state=d_state_gaia,
                d_conv=d_conv,
                expand=expand
            ) for _ in range(n_layers)
        ])

        # Cross-attention (optional)
        self.use_cross_attention = use_cross_attention
        if use_cross_attention:
            # Adaptation layers for cross-attention with different dimensions
            # For Gaia→Spectra attention, we need to project Gaia to match Spectra dimension
            self.gaia_to_spectra_proj = nn.Linear(d_model_gaia, d_model_spectra)
            
            # For Spectra→Gaia attention, we need to project Spectra to match Gaia dimension
            self.spectra_to_gaia_proj = nn.Linear(d_model_spectra, d_model_gaia)
            
            # Create cross-attention blocks
            self.cross_attn_block_spectra = self._create_cross_attn_block(
                d_model=d_model_spectra, n_heads=n_cross_attn_heads
            )
            self.cross_attn_block_gaia = self._create_cross_attn_block(
                d_model=d_model_gaia, n_heads=n_cross_attn_heads
            )

        # Final classifier
        # Add a projection layer to transform concatenated features to a common fusion dimension
        fusion_dim = d_model_spectra + d_model_gaia
        self.layer_norm = nn.LayerNorm(fusion_dim)
        self.classifier = nn.Linear(fusion_dim, num_classes)
    
    def _create_mamba_layer(self, d_model, d_state, d_conv, expand):
        """Create a memory-efficient Mamba layer."""
        mamba = Mamba2(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand
        )
        
        # Wrap with gradient checkpointing if requested
        if self.activation_checkpointing:
            return MemoryEfficientMamba(mamba, use_checkpoint=True)
        else:
            return mamba
    
    def _create_cross_attn_block(self, d_model, n_heads):
        """Creates a cross-attention block with optional gradient checkpointing."""
        class CrossAttentionBlock(nn.Module):
            def __init__(self, d_model, n_heads):
                super().__init__()
                self.cross_attn = nn.MultiheadAttention(
                    embed_dim=d_model, 
                    num_heads=n_heads, 
                    batch_first=True
                )
                self.norm1 = nn.LayerNorm(d_model)
                
                # Smaller FFN to save memory
                self.ffn = nn.Sequential(
                    nn.Linear(d_model, 2 * d_model),  # Reduced from 4x
                    nn.ReLU(),
                    nn.Linear(2 * d_model, d_model)   # Reduced from 4x
                )
                self.norm2 = nn.LayerNorm(d_model)
                
            def forward(self, x_q, x_kv):
                # Cross-attention
                attn_output, _ = self.cross_attn(query=x_q, key=x_kv, value=x_kv)
                x = self.norm1(x_q + attn_output)
                
                # Feed forward
                ffn_out = self.ffn(x)
                x = self.norm2(x + ffn_out)
                
                return x
        
        block = CrossAttentionBlock(d_model, n_heads)
        
        # Wrap with gradient checkpointing if requested
        if self.activation_checkpointing:
            def forward_with_checkpoint(module, x_q, x_kv):
                def custom_forward(x_q, x_kv):
                    return module(x_q, x_kv)
                return checkpoint(custom_forward, x_q, x_kv)
            
            class CheckpointedCrossAttention(nn.Module):
                def __init__(self, block):
                    super().__init__()
                    self.block = block
                
                def forward(self, x_q, x_kv):
                    return forward_with_checkpoint(self.block, x_q, x_kv)
            
            return CheckpointedCrossAttention(block)
        else:
            return block
    
    def _process_mamba_layers(self, x, layers):
        """Process input through Mamba layers, optionally sequentially to save memory."""
        if self.sequential_processing:
            for layer in layers:
                x = layer(x)
                # Optional: explicitly delete intermediate activations
                torch.cuda.empty_cache()
        else:
            # Process all layers at once (uses more memory but faster)
            for layer in layers:
                x = layer(x)
        return x
    
    def forward(self, x_spectra, x_gaia):
        # Convert to half precision if requested
        if hasattr(self, 'dtype') and self.dtype == torch.float16:
            x_spectra = x_spectra.half()
            x_gaia = x_gaia.half()
        
        # Project inputs to their respective embedding spaces
        x_spectra = self.input_proj_spectra(x_spectra)
        x_gaia = self.input_proj_gaia(x_gaia)
        
        # Add sequence dimension if needed
        if len(x_spectra.shape) == 2:
            x_spectra = x_spectra.unsqueeze(1)
        if len(x_gaia.shape) == 2:
            x_gaia = x_gaia.unsqueeze(1)
        
        # Process through Mamba layers
        x_spectra = self._process_mamba_layers(x_spectra, self.mamba_spectra_layers)
        x_gaia = self._process_mamba_layers(x_gaia, self.mamba_gaia_layers)
        
        # Optional cross-attention (with dimension adaptation)
        if self.use_cross_attention:
            # Project Gaia features to match spectra dimension for spectra's cross-attention
            x_gaia_projected = self.gaia_to_spectra_proj(x_gaia)
            
            # Project Spectra features to match Gaia dimension for Gaia's cross-attention
            x_spectra_projected = self.spectra_to_gaia_proj(x_spectra)
            
            # Cross-attention from spectra -> gaia (using projected Gaia)
            x_spectra_fused = self.cross_attn_block_spectra(x_spectra, x_gaia_projected)
            
            # Cross-attention from gaia -> spectra (using projected Spectra)
            x_gaia_fused = self.cross_attn_block_gaia(x_gaia, x_spectra_projected)
            
            # Update embeddings
            x_spectra = x_spectra_fused
            x_gaia = x_gaia_fused
            
            # Free memory
            del x_gaia_projected, x_spectra_projected, x_spectra_fused, x_gaia_fused
            torch.cuda.empty_cache()
        
        # Pool across sequence dimension
        x_spectra = x_spectra.mean(dim=1)
        x_gaia = x_gaia.mean(dim=1)
        
        # Concatenate (different dimensions are fine for concatenation)
        x_fused = torch.cat([x_spectra, x_gaia], dim=-1)
        
        # Final classification
        x_fused = self.layer_norm(x_fused)
        logits = self.classifier(x_fused)
        
        return logits


class MemoryEfficientMamba(nn.Module):
    """
    Memory-efficient wrapper for Mamba2 with gradient checkpointing.
    """
    def __init__(self, mamba, use_checkpoint=True):
        super().__init__()
        self.mamba = mamba
        self.use_checkpoint = use_checkpoint
    
    def forward(self, x):
        if self.use_checkpoint and self.training:
            return checkpoint(self.mamba, x)
        else:
            return self.mamba(x)


class UltraMemoryEfficientEnsemble:
    """
    Ultra memory-efficient implementation of ensemble for uncertainty quantification.
    Avoids creating multiple models in memory and removes quantization during initialization.
    
    Note: See ultra-memory-efficient.py for full implementation.
    """
    # This is a placeholder class - use the full implementation from the other file
    pass

In [6]:
import torch
import numpy as np
from torch.utils.data import DataLoader
#from ultra_memory_efficient import UltraMemoryEfficientEnsemble
#from asymmetric_dimensions import AsymmetricMemoryEfficientStarClassifier

# Function to estimate model VRAM usage
def estimate_vram_usage(config):
    # Parameters count estimates
    d_model_spectra = config["d_model_spectra"]
    d_model_gaia = config["d_model_gaia"]
    input_dim_spectra = config["input_dim_spectra"]
    input_dim_gaia = config["input_dim_gaia"]
    n_layers = config["n_layers"]
    num_classes = config["num_classes"]
    
    # Estimate input projection parameters
    input_proj_params = (input_dim_spectra * d_model_spectra) + (input_dim_gaia * d_model_gaia)
    
    # Estimate Mamba parameters (very rough estimation)
    # For each Mamba block: we have the main projection, expand factor, state space, etc.
    mamba_params_spectra = n_layers * d_model_spectra * d_model_spectra * 4  # Approximation
    mamba_params_gaia = n_layers * d_model_gaia * d_model_gaia * 4  # Approximation
    
    # Cross-attention parameters
    cross_attn_params = 0
    if config["use_cross_attention"]:
        # Adaptation layers
        cross_attn_params += (d_model_spectra * d_model_gaia) * 2
        # Attention layers
        cross_attn_params += d_model_spectra * d_model_spectra * 3  # Q,K,V projections
        cross_attn_params += d_model_gaia * d_model_gaia * 3  # Q,K,V projections
        # FFN layers
        cross_attn_params += (d_model_spectra * d_model_spectra * 4) + (d_model_gaia * d_model_gaia * 4)
    
    # Classifier parameters
    classifier_params = (d_model_spectra + d_model_gaia) * num_classes
    
    # Total parameters
    total_params = input_proj_params + mamba_params_spectra + mamba_params_gaia + cross_attn_params + classifier_params
    
    # Estimate VRAM usage (in GB)
    # Parameters (4 bytes per param in FP32, 2 bytes in FP16)
    param_memory = total_params * (2 if config["use_half_precision"] else 4) / (1024**3)
    
    # Optimizer states (Adam uses 8 bytes per parameter, SGD 4 bytes)
    # We'll assume SGD for memory efficiency
    optimizer_memory = total_params * 4 / (1024**3)
    
    # Activations - this is highly approximate
    batch_size = config["batch_size"]
    activations_memory = batch_size * (d_model_spectra + d_model_gaia) * 2 / (1024**3)
    
    # Total VRAM usage
    total_vram = param_memory + optimizer_memory + activations_memory
    
    return {
        "parameters": int(total_params),
        "param_memory_GB": param_memory,
        "optimizer_memory_GB": optimizer_memory,
        "activations_memory_GB": activations_memory,
        "total_vram_GB": total_vram
    }

# Memory-optimized configurations
CONFIGS = [
    {
        "name": "Balanced High-Dim",
        "d_model_spectra": 4096,
        "d_model_gaia": 1024,  # 4x reduction for Gaia
        "num_classes": 55,
        "input_dim_spectra": 3647,
        "input_dim_gaia": 18,
        "n_layers": 12,
        "d_state_spectra": 16,
        "d_state_gaia": 8,
        "d_conv": 4,
        "expand": 2,
        "use_cross_attention": True,
        "n_cross_attn_heads": 8,
        "use_checkpoint": True,
        "activation_checkpointing": True,
        "use_half_precision": True,
        "sequential_processing": True,
        "batch_size": 16
    },
    {
        "name": "Extreme Asymmetric",
        "d_model_spectra": 4096,
        "d_model_gaia": 512,   # 8x reduction for Gaia
        "num_classes": 55,
        "input_dim_spectra": 3647,
        "input_dim_gaia": 18,
        "n_layers": 12,
        "d_state_spectra": 16,
        "d_state_gaia": 8,
        "d_conv": 4,
        "expand": 2,
        "use_cross_attention": True,
        "n_cross_attn_heads": 8,
        "use_checkpoint": True,
        "activation_checkpointing": True,
        "use_half_precision": True,
        "sequential_processing": True,
        "batch_size": 32
    },
    {
        "name": "Extremely Memory Efficient",
        "d_model_spectra": 3072,
        "d_model_gaia": 256,   # 12x reduction for Gaia
        "num_classes": 55,
        "input_dim_spectra": 3647,
        "input_dim_gaia": 18,
        "n_layers": 10,        # Reduced layers
        "d_state_spectra": 8,  # Smaller state
        "d_state_gaia": 4,     # Tiny state for Gaia
        "d_conv": 2,           # Smaller conv
        "expand": 1,           # No expansion
        "use_cross_attention": True,
        "n_cross_attn_heads": 4,  # Fewer heads
        "use_checkpoint": True,
        "activation_checkpointing": True,
        "use_half_precision": True,
        "sequential_processing": True,
        "batch_size": 32
    },
    {
        "name": "Balanced Medium-Dim",
        "d_model_spectra": 2048,  # Reduced spectra dim
        "d_model_gaia": 512,      # 4x reduction for Gaia
        "num_classes": 55,
        "input_dim_spectra": 3647,
        "input_dim_gaia": 18,
        "n_layers": 8,            # Reduced layers
        "d_state_spectra": 8,     # Smaller state
        "d_state_gaia": 4,        # Tiny state for Gaia
        "d_conv": 2,              # Smaller conv
        "expand": 2,
        "use_cross_attention": True,
        "n_cross_attn_heads": 4,  # Fewer heads
        "use_checkpoint": True,
        "activation_checkpointing": True,
        "use_half_precision": True,
        "sequential_processing": True,
        "batch_size": 32
    }
]

# Evaluate each configuration
print("Evaluating configurations for VRAM usage:")
print("-" * 80)

for config in CONFIGS:
    vram_usage = estimate_vram_usage(config)
    print(f"Configuration: {config['name']}")
    print(f"d_model_spectra: {config['d_model_spectra']}, d_model_gaia: {config['d_model_gaia']}")
    print(f"Estimated parameters: {vram_usage['parameters']:,}")
    print(f"Parameter memory: {vram_usage['param_memory_GB']:.2f} GB")
    print(f"Optimizer memory: {vram_usage['optimizer_memory_GB']:.2f} GB")
    print(f"Activations memory: {vram_usage['activations_memory_GB']:.2f} GB")
    print(f"Total estimated VRAM: {vram_usage['total_vram_GB']:.2f} GB")
    print("-" * 80)

print("\nRecommended Configuration:")
print("-" * 80)

# Choose the best configuration
# For a 24 GB GPU, we want to stay under ~20 GB to leave room for system overhead
recommended = None
for config in CONFIGS:
    vram_usage = estimate_vram_usage(config)
    if vram_usage['total_vram_GB'] < 20:
        if recommended is None or config['d_model_spectra'] > recommended['d_model_spectra']:
            recommended = config

if recommended:
    vram_usage = estimate_vram_usage(recommended)
    print(f"Recommended Configuration: {recommended['name']}")
    print(f"d_model_spectra: {recommended['d_model_spectra']}, d_model_gaia: {recommended['d_model_gaia']}")
    print(f"n_layers: {recommended['n_layers']}")
    print(f"Estimated parameters: {vram_usage['parameters']:,}")
    print(f"Total estimated VRAM: {vram_usage['total_vram_GB']:.2f} GB")
    print("-" * 80)
else:
    print("No configuration fits within 20 GB VRAM limit.")
    print("Consider further reductions in model dimensions.")

# Example usage of the recommended configuration
def train_with_recommended_config():
    """Example of how to train with the recommended configuration."""
    # Set device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # Create model using recommended config
    model = AsymmetricMemoryEfficientStarClassifier(
        d_model_spectra=recommended['d_model_spectra'],
        d_model_gaia=recommended['d_model_gaia'],
        num_classes=recommended['num_classes'],
        input_dim_spectra=recommended['input_dim_spectra'],
        input_dim_gaia=recommended['input_dim_gaia'],
        n_layers=recommended['n_layers'],
        d_state_spectra=recommended['d_state_spectra'],
        d_state_gaia=recommended['d_state_gaia'],
        d_conv=recommended['d_conv'],
        expand=recommended['expand'],
        use_cross_attention=recommended['use_cross_attention'],
        n_cross_attn_heads=recommended['n_cross_attn_heads'],
        use_checkpoint=recommended['use_checkpoint'],
        activation_checkpointing=recommended['activation_checkpointing'],
        use_half_precision=recommended['use_half_precision'],
        sequential_processing=recommended['sequential_processing']
    )
    
    # Create ensemble with this model
    ensemble = UltraMemoryEfficientEnsemble(
        model_class=AsymmetricMemoryEfficientStarClassifier,
        model_args=recommended,
        num_models=5,
        device=device,
        checkpoint_dir='asymmetric_ensemble_models'
    )
    
    # Example training code (assuming your datasets are available)
    train_loader = DataLoader(
        train_dataset, 
        batch_size=recommended['batch_size'], 
        shuffle=True,
        num_workers=2
    )
    val_loader = DataLoader(
        val_dataset, 
        batch_size=recommended['batch_size'], 
        shuffle=False,
        num_workers=2
    )
    
    # Train one model at a time to maximize memory efficiency
    for model_idx in range(5):
        ensemble.train_single_model(
            model_idx=model_idx,
            train_loader=train_loader,
            val_loader=val_loader,
            num_epochs=100,
            lr=1e-4,
            batch_accumulation=4,  # Accumulate gradients over 4 batches
            scheduler_type='OneCycleLR'
        )
    
    return ensemble

Evaluating configurations for VRAM usage:
--------------------------------------------------------------------------------
Configuration: Balanced High-Dim
d_model_spectra: 4096, d_model_gaia: 1024
Estimated parameters: 1,004,045,312
Parameter memory: 1.87 GB
Optimizer memory: 3.74 GB
Activations memory: 0.00 GB
Total estimated VRAM: 5.61 GB
--------------------------------------------------------------------------------
Configuration: Extreme Asymmetric
d_model_spectra: 4096, d_model_gaia: 512
Estimated parameters: 956,559,872
Parameter memory: 1.78 GB
Optimizer memory: 3.56 GB
Activations memory: 0.00 GB
Total estimated VRAM: 5.35 GB
--------------------------------------------------------------------------------
Configuration: Extremely Memory Efficient
d_model_spectra: 3072, d_model_gaia: 256
Estimated parameters: 459,591,936
Parameter memory: 0.86 GB
Optimizer memory: 1.71 GB
Activations memory: 0.00 GB
Total estimated VRAM: 2.57 GB
------------------------------------------------

In [7]:
def train_single_model(
    self, 
    model_idx,
    train_loader, 
    val_loader, 
    test_loader=None, 
    num_epochs=100, 
    lr=1e-4, 
    max_patience=20,
    scheduler_type='OneCycleLR',
    batch_accumulation=1,  # Gradient accumulation steps
    log_to_wandb=True
):
    """
    Train a single model in the ensemble with fixed scheduler stepping.
    
    Args:
        model_idx: Index of the model to train
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        test_loader: DataLoader for test data (optional)
        num_epochs: Maximum number of epochs to train
        lr: Learning rate
        max_patience: Maximum patience for early stopping
        scheduler_type: Type of learning rate scheduler ('OneCycleLR' or 'ReduceLROnPlateau')
        batch_accumulation: Number of batches to accumulate gradients 
        log_to_wandb: Whether to log training progress to wandb
        
    Returns:
        Trained model (also saved to checkpoint)
    """
    import torch.optim as optim
    from tqdm import tqdm
    
    # Initialize wandb if requested
    if log_to_wandb:
        try:
            import wandb
            run = wandb.init(
                project="ALLSTARS_ultra_memory_efficient", 
                name=f"model_{model_idx}",
                group="memory_efficient_training",
                config={
                    **self.model_args,
                    "model_idx": model_idx,
                    "num_models": self.num_models,
                    "lr": lr,
                    "max_patience": max_patience,
                    "scheduler_type": scheduler_type,
                    "batch_accumulation": batch_accumulation,
                    "num_epochs": num_epochs
                },
                reinit=True
            )
        except ImportError:
            print("wandb not installed. Training without logging.")
            log_to_wandb = False
    
    # Load or create model
    model = self._load_model(model_idx, for_inference=False)
    model.train()
    
    # Create optimizer (SGD uses less memory than Adam/AdamW)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    
    # Calculate the actual number of optimization steps that will be performed
    # This is crucial for OneCycleLR to prevent the 'stepped too many times' error
    # For gradient accumulation, we need to account for the reduced number of steps
    effective_steps_per_epoch = (len(train_loader) + batch_accumulation - 1) // batch_accumulation
    
    # Configure the scheduler
    if scheduler_type == 'OneCycleLR':
        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer, 
            max_lr=lr,
            epochs=num_epochs, 
            steps_per_epoch=effective_steps_per_epoch
        )
    else:  # ReduceLROnPlateau
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 
            mode='min', 
            factor=0.5, 
            patience=int(max_patience / 5)
        )
    
    # Calculate class weights for imbalanced classes
    all_labels = []
    for _, _, y_batch in train_loader:
        all_labels.extend(y_batch.cpu().numpy())
    
    class_weights = self._calculate_class_weights(np.array(all_labels))
    class_weights = torch.tensor(class_weights, dtype=torch.float).to(self.device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights)
    
    best_val_loss = float('inf')
    patience = max_patience
    
    # Initialize mixed precision scaler if supported
    scaler = torch.cuda.amp.GradScaler() if hasattr(torch.cuda, 'amp') else None

    # Training loop
    for epoch in range(num_epochs):
        # Resample training data if needed
        if hasattr(train_loader.dataset, 're_sample'):
            train_loader.dataset.re_sample()
            
            # Recompute class weights if needed
            all_labels = []
            for _, _, y_batch in train_loader:
                all_labels.extend(y_batch.cpu().numpy())
            class_weights = self._calculate_class_weights(np.array(all_labels))
            class_weights = torch.tensor(class_weights, dtype=torch.float).to(self.device)
            criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights)

        # --- Training Phase ---
        model.train()
        train_loss, train_acc = 0.0, 0.0
        batch_count = 0
        optimization_steps = 0  # Track actual optimizer steps
        
        # Reset gradients at the start of epoch
        optimizer.zero_grad(set_to_none=True)
        
        for i, (X_spc, X_ga, y_batch) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training")):
            X_spc, X_ga, y_batch = X_spc.to(self.device), X_ga.to(self.device), y_batch.to(self.device)
            
            # Forward pass with mixed precision if available
            if scaler:
                with torch.cuda.amp.autocast():
                    outputs = model(X_spc, X_ga)
                    loss = criterion(outputs, y_batch) / batch_accumulation
                
                # Backward pass with scaled gradients
                scaler.scale(loss).backward()
                
                # Step only at the end of accumulation cycle or at the end of epoch
                if (i + 1) % batch_accumulation == 0 or (i + 1) == len(train_loader):
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad(set_to_none=True)
                    optimization_steps += 1
                    
                    # Step the scheduler after optimization
                    if scheduler_type == 'OneCycleLR':
                        scheduler.step()
            else:
                outputs = model(X_spc, X_ga)
                loss = criterion(outputs, y_batch) / batch_accumulation
                loss.backward()
                
                if (i + 1) % batch_accumulation == 0 or (i + 1) == len(train_loader):
                    optimizer.step()
                    optimizer.zero_grad(set_to_none=True)
                    optimization_steps += 1
                    
                    # Step the scheduler after optimization
                    if scheduler_type == 'OneCycleLR':
                        scheduler.step()
            
            # Calculate metrics
            train_loss += loss.item() * batch_accumulation * X_spc.size(0)
            predicted = (torch.sigmoid(outputs) > 0.5).float()
            correct = (predicted == y_batch).float()
            train_acc += correct.mean(dim=1).sum().item()
            batch_count += X_spc.size(0)
            
            # Free up memory
            del X_spc, X_ga, y_batch, outputs, loss, predicted, correct
            torch.cuda.empty_cache()
        
        print(f"Optimization steps this epoch: {optimization_steps}")
        
        train_loss /= batch_count
        train_acc /= batch_count

        # --- Validation Phase ---
        model.eval()
        val_loss, val_acc = 0.0, 0.0
        val_batch_count = 0
        
        with torch.no_grad():
            for X_spc, X_ga, y_batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation"):
                X_spc, X_ga, y_batch = X_spc.to(self.device), X_ga.to(self.device), y_batch.to(self.device)
                
                # Forward pass with mixed precision if available
                if scaler:
                    with torch.cuda.amp.autocast():
                        outputs = model(X_spc, X_ga)
                        loss = criterion(outputs, y_batch)
                else:
                    outputs = model(X_spc, X_ga)
                    loss = criterion(outputs, y_batch)
                
                val_loss += loss.item() * X_spc.size(0)
                predicted = (torch.sigmoid(outputs) > 0.5).float()
                correct = (predicted == y_batch).float()
                val_acc += correct.mean(dim=1).sum().item()
                val_batch_count += X_spc.size(0)
                
                # Free up memory
                del X_spc, X_ga, y_batch, outputs, loss, predicted, correct
                torch.cuda.empty_cache()
        
        val_loss /= val_batch_count
        val_acc /= val_batch_count

        # --- Test Phase (if provided) ---
        test_metrics = {}
        if test_loader is not None:
            test_loss, test_acc = 0.0, 0.0
            test_batch_count = 0
            y_true, y_pred = [], []
            
            with torch.no_grad():
                for X_spc, X_ga, y_batch in tqdm(test_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Testing"):
                    X_spc, X_ga, y_batch = X_spc.to(self.device), X_ga.to(self.device), y_batch.to(self.device)
                    
                    # Forward pass with mixed precision if available
                    if scaler:
                        with torch.cuda.amp.autocast():
                            outputs = model(X_spc, X_ga)
                            loss = criterion(outputs, y_batch)
                    else:
                        outputs = model(X_spc, X_ga)
                        loss = criterion(outputs, y_batch)
                    
                    test_loss += loss.item() * X_spc.size(0)
                    predicted = (torch.sigmoid(outputs) > 0.5).float()
                    correct = (predicted == y_batch).float()
                    test_acc += correct.mean(dim=1).sum().item()
                    test_batch_count += X_spc.size(0)
                    
                    # Store on CPU to save GPU memory
                    y_true.extend(y_batch.cpu().numpy())
                    y_pred.extend(predicted.cpu().numpy())
                    
                    # Free up memory
                    del X_spc, X_ga, y_batch, outputs, loss, predicted, correct
                    torch.cuda.empty_cache()
            
            test_loss /= test_batch_count
            test_acc /= test_batch_count
            
            # Calculate metrics on CPU to save GPU memory
            test_metrics = self._calculate_metrics(np.array(y_true), np.array(y_pred))
            test_metrics.update({
                "test_loss": test_loss,
                "test_acc": test_acc,
            })

        # Log metrics
        if log_to_wandb:
            log_data = {
                "epoch": epoch,
                "train_loss": train_loss,
                "val_loss": val_loss,
                "train_acc": train_acc,
                "val_acc": val_acc,
                "lr": self._get_lr(optimizer)
            }
            log_data.update(test_metrics)
            wandb.log(log_data)
        
        # Print progress
        print(f"Epoch {epoch+1}/{num_epochs} - "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        # Update ReduceLROnPlateau scheduler if used
        if scheduler_type == 'ReduceLROnPlateau':
            scheduler.step(val_loss)

        # Early stopping and checkpoint saving
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience = max_patience
            
            # Save best model
            self._save_model(model, model_idx)
            
            if log_to_wandb:
                wandb.run.summary["best_val_loss"] = best_val_loss
        else:
            patience -= 1
            if patience <= 0:
                print("Early stopping triggered.")
                break
        
        # Save every 10 epochs as a checkpoint
        if (epoch + 1) % 10 == 0:
            checkpoint_path = os.path.join(self.checkpoint_dir, f"model_{model_idx}_epoch_{epoch+1}.pt")
            state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
            torch.save(state_dict, checkpoint_path)
    
    # Close wandb run if used
    if log_to_wandb:
        wandb.finish()
    
    return model

In [20]:
# --- Ultra-low memory configurations for 24GB GPU ---

# Configuration 1: Maximum Asymmetry
# Maximizes the spectral embedding dimension while keeping total VRAM usage manageable
MAX_ASYMMETRY_CONFIG = {
    "name": "Maximum Asymmetry",
    "d_model_spectra": 4096,      # Maximum dimension for spectral data
    "d_model_gaia": 256,          # 16x reduction for Gaia
    "num_classes": 55,
    "input_dim_spectra": 3647,
    "input_dim_gaia": 18,
    "n_layers": 12,
    "d_state_spectra": 16,
    "d_state_gaia": 4,            # Tiny state dimension for Gaia
    "d_conv": 4,
    "expand": 2,
    "use_cross_attention": True,
    "n_cross_attn_heads": 4,      # Reduced heads
    "use_checkpoint": True,
    "activation_checkpointing": True,
    "use_half_precision": True,
    "sequential_processing": True,
    "batch_size": 24,             # Moderate batch size
    "micro_batch_size": 8,        # Process in smaller chunks during training
    "gradient_accumulation": 4    # Effective batch size = 24 × 4 = 96
}

# Configuration 2: Extreme Low Memory
# For when you need to train with very large spectral embeddings
EXTREME_LOW_MEM_CONFIG = {
    "name": "Extreme Low Memory",
    "d_model_spectra": 5120,      # Even larger for spectral data
    "d_model_gaia": 128,          # 40x reduction for Gaia
    "num_classes": 55,
    "input_dim_spectra": 3647,
    "input_dim_gaia": 18,
    "n_layers": 12,
    "d_state_spectra": 8,         # Reduced state dimension
    "d_state_gaia": 2,            # Minimal state dimension
    "d_conv": 2,                  # Reduced conv
    "expand": 1,                  # No expansion
    "use_cross_attention": False, # Remove cross-attention to save memory
    "use_checkpoint": True,
    "activation_checkpointing": True,
    "use_half_precision": True,
    "sequential_processing": True,
    "batch_size": 16,
    "micro_batch_size": 4,
    "gradient_accumulation": 8    # Effective batch size = 16 × 8 = 128
}

# Configuration 3: Balanced Performance
# Good balance between memory efficiency and model performance
BALANCED_CONFIG = {
    "name": "Balanced Performance",
    "d_model_spectra": 3072,      # Moderate dimension
    "d_model_gaia": 384,          # 8x reduction
    "num_classes": 55,
    "input_dim_spectra": 3647,
    "input_dim_gaia": 18,
    "n_layers": 10,               # Slightly reduced layers
    "d_state_spectra": 16,
    "d_state_gaia": 4,
    "d_conv": 4,
    "expand": 2,
    "use_cross_attention": True,
    "n_cross_attn_heads": 6,
    "use_checkpoint": True,
    "activation_checkpointing": True,
    "use_half_precision": True,
    "sequential_processing": True,
    "batch_size": 32,
    "micro_batch_size": 8,        
    "gradient_accumulation": 2    # Effective batch size = 32 × 2 = 64
}

# Configuration 4: Production Ensemble
# For training multiple ensemble models efficiently
PRODUCTION_ENSEMBLE_CONFIG = {
    "name": "Production Ensemble",
    "d_model_spectra": 2560,      # Further reduced for ensemble training
    "d_model_gaia": 320,          # 8x reduction
    "num_classes": 55,
    "input_dim_spectra": 3647,
    "input_dim_gaia": 18,
    "n_layers": 8,                # Reduced layers for faster training
    "d_state_spectra": 8,
    "d_state_gaia": 4,
    "d_conv": 2,
    "expand": 2,
    "use_cross_attention": True,
    "n_cross_attn_heads": 4,
    "use_checkpoint": True,
    "activation_checkpointing": True,
    "use_half_precision": True,
    "sequential_processing": True,
    "batch_size": 32,
    "micro_batch_size": 8,
    "gradient_accumulation": 2    # Effective batch size = 32 × 2 = 64
}

def get_memory_efficient_config(available_vram_gb=24, target_spectra_dim=4096):
    """
    Dynamically generate a memory-efficient configuration based on available VRAM
    and target spectral dimension.
    
    Args:
        available_vram_gb: Available VRAM in GB
        target_spectra_dim: Target embedding dimension for spectral data
        
    Returns:
        Optimized configuration dictionary
    """
    # Base configuration
    config = {
        "num_classes": 55,
        "input_dim_spectra": 3647,
        "input_dim_gaia": 18,
        "use_checkpoint": True,
        "activation_checkpointing": True,
        "use_half_precision": True,
        "sequential_processing": True,
    }
    
    # Scale dimensions based on available VRAM
    usable_vram = available_vram_gb * 0.8  # Leave 20% for system overhead
    
    if usable_vram >= 20:
        # High memory scenario
        config.update({
            "d_model_spectra": min(target_spectra_dim, 4096),
            "d_model_gaia": 256,
            "n_layers": 12,
            "d_state_spectra": 16,
            "d_state_gaia": 4,
            "d_conv": 4,
            "expand": 2,
            "use_cross_attention": True,
            "n_cross_attn_heads": 8,
            "batch_size": 32,
            "gradient_accumulation": 2
        })
    elif usable_vram >= 16:
        # Medium memory scenario
        config.update({
            "d_model_spectra": min(target_spectra_dim, 3072),
            "d_model_gaia": 256,
            "n_layers": 10,
            "d_state_spectra": 8,
            "d_state_gaia": 4,
            "d_conv": 2,
            "expand": 2,
            "use_cross_attention": True,
            "n_cross_attn_heads": 4,
            "batch_size": 24,
            "gradient_accumulation": 4
        })
    elif usable_vram >= 12:
        # Low memory scenario
        config.update({
            "d_model_spectra": min(target_spectra_dim, 2048),
            "d_model_gaia": 192,
            "n_layers": 8,
            "d_state_spectra": 8,
            "d_state_gaia": 4,
            "d_conv": 2,
            "expand": 1,
            "use_cross_attention": True,
            "n_cross_attn_heads": 4,
            "batch_size": 16,
            "gradient_accumulation": 4
        })
    else:
        # Very low memory scenario
        config.update({
            "d_model_spectra": min(target_spectra_dim, 1536),
            "d_model_gaia": 128,
            "n_layers": 6,
            "d_state_spectra": 8,
            "d_state_gaia": 2,
            "d_conv": 2,
            "expand": 1,
            "use_cross_attention": False,
            "batch_size": 8,
            "gradient_accumulation": 8
        })
    
    # Calculate ratio between spectral and Gaia dimensions
    ratio = config["d_model_spectra"] / config["d_model_gaia"]
    config["name"] = f"Auto-{config['d_model_spectra']}:{config['d_model_gaia']} ({ratio:.1f}x)"
    
    return config

In [None]:
#from ultra_memory_efficient import UltraMemoryEfficientEnsemble
#?rom asymmetric_dimensions import AsymmetricMemoryEfficientStarClassifier
#from advanced_asymmetric_config import (
    MAX_ASYMMETRY_CONFIG, 
    EXTREME_LOW_MEM_CONFIG, 
    BALANCED_CONFIG,
    PRODUCTION_ENSEMBLE_CONFIG,
    get_memory_efficient_config
)


In [21]:
import torch
import torch.nn as nn
import os
import gc
import numpy as np
from torch.utils.data import DataLoader
import argparse


# Define a function to safely clean GPU memory
def clean_gpu_memory():
    gc.collect()
    torch.cuda.empty_cache()
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
        current_memory = torch.cuda.memory_allocated() / (1024**2)
        print(f"Current GPU memory usage: {current_memory:.2f} MB")

# Function to create and train a model with asymmetric dimensions
def train_asymmetric_model(
    config,
    train_dataset,
    val_dataset,
    test_dataset=None,
    output_dir="asymmetric_models",
    num_epochs=100,
    log_to_wandb=True,
    model_idx=0
):
    """
    Train a model with asymmetric embedding dimensions for spectral and Gaia data.
    
    Args:
        config: Configuration dictionary with model parameters
        train_dataset: Training dataset
        val_dataset: Validation dataset
        test_dataset: Test dataset (optional)
        output_dir: Directory to save model checkpoints
        num_epochs: Number of epochs to train
        log_to_wandb: Whether to log to Weights & Biases
        model_idx: Model index for ensemble training
        
    Returns:
        Trained model or None if training fails
    """
    # Set device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")
    
    if device == 'cuda':
        # Check available GPU memory
        total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        free_memory = torch.cuda.memory_reserved() / (1024**3)
        print(f"Total GPU memory: {total_memory:.2f} GB")
        print(f"Reserved GPU memory: {free_memory:.2f} GB")
    
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    
    # Print configuration
    print(f"\nTraining model with configuration: {config['name']}")
    print(f"Spectral dimension: {config['d_model_spectra']}, Gaia dimension: {config['d_model_gaia']}")
    print(f"Ratio: {config['d_model_spectra'] / config['d_model_gaia']:.1f}x")
    print(f"Batch size: {config['batch_size']}, Gradient accumulation: {config.get('gradient_accumulation', 1)}")
    print(f"Effective batch size: {config['batch_size'] * config.get('gradient_accumulation', 1)}")
    
    try:
        # Create data loaders with appropriate batch sizes
        batch_size = config['batch_size']
        
        # You can modify DataLoader parameters based on your dataset
        train_loader = DataLoader(
            train_dataset, 
            batch_size=batch_size, 
            shuffle=True,
            num_workers=2,
            pin_memory=True
        )
        
        val_loader = DataLoader(
            val_dataset, 
            batch_size=batch_size, 
            shuffle=False,
            num_workers=2,
            pin_memory=True
        )
        
        test_loader = None
        if test_dataset is not None:
            test_loader = DataLoader(
                test_dataset, 
                batch_size=batch_size, 
                shuffle=False,
                num_workers=2,
                pin_memory=True
            )
        
        # Create checkpoint directory for this specific configuration
        config_dir = os.path.join(output_dir, config['name'].replace(':', '_').replace(' ', '_'))
        os.makedirs(config_dir, exist_ok=True)
        
        # Create ensemble with just one model for now
        ensemble = UltraMemoryEfficientEnsemble(
            model_class=AsymmetricMemoryEfficientStarClassifier,
            model_args=config,
            num_models=1,  # Just train one model for now
            device=device,
            checkpoint_dir=config_dir
        )
        
        # Train the model
        model = ensemble.train_single_model(
            model_idx=model_idx,
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
            num_epochs=num_epochs,
            lr=1e-4,
            max_patience=20,
            scheduler_type='OneCycleLR',
            batch_accumulation=config.get('gradient_accumulation', 1),
            log_to_wandb=log_to_wandb
        )
        
        return model
    
    except Exception as e:
        print(f"Error training model: {e}")
        # Try to clean up memory
        clean_gpu_memory()
        return None

def main():
    parser = argparse.ArgumentParser(description='Train a model with asymmetric embedding dimensions')
    parser.add_argument('--config', type=str, default='auto', 
                        choices=['max_asymmetry', 'extreme_low_mem', 'balanced', 'production', 'auto'],
                        help='Configuration to use')
    parser.add_argument('--target_dim', type=int, default=4096, 
                        help='Target spectral dimension for auto config')
    parser.add_argument('--epochs', type=int, default=100, 
                        help='Number of epochs to train')
    parser.add_argument('--output_dir', type=str, default='asymmetric_models', 
                        help='Directory to save model checkpoints')
    parser.add_argument('--no_wandb', action='store_true', 
                        help='Disable Weights & Biases logging')
    args = parser.parse_args()
    
    # Load datasets (replace this with your actual loading code)
    # Example placeholder:
    from your_dataset_module import (
        X_train_spectra, X_train_gaia, y_train,
        X_val_spectra, X_val_gaia, y_val,
        X_test_spectra, X_test_gaia, y_test,
        MultiModalBalancedMultiLabelDataset
    )
    
    # Create datasets with appropriate batch limits
    # These would be replaced with your actual dataset creation code
    train_dataset = MultiModalBalancedMultiLabelDataset(
        X_train_spectra, X_train_gaia, y_train, limit_per_label=201
    )
    val_dataset = MultiModalBalancedMultiLabelDataset(
        X_val_spectra, X_val_gaia, y_val, limit_per_label=201
    )
    test_dataset = MultiModalBalancedMultiLabelDataset(
        X_test_spectra, X_test_gaia, y_test, limit_per_label=201
    )
    
    # Select configuration
    if args.config == 'max_asymmetry':
        config = MAX_ASYMMETRY_CONFIG
    elif args.config == 'extreme_low_mem':
        config = EXTREME_LOW_MEM_CONFIG
    elif args.config == 'balanced':
        config = BALANCED_CONFIG
    elif args.config == 'production':
        config = PRODUCTION_ENSEMBLE_CONFIG
    else:  # auto
        # Auto-configure based on available VRAM
        if torch.cuda.is_available():
            vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        else:
            vram_gb = 16  # Default assumption
        
        config = get_memory_efficient_config(
            available_vram_gb=vram_gb,
            target_spectra_dim=args.target_dim
        )
    
    # Train model
    trained_model = train_asymmetric_model(
        config=config,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        test_dataset=test_dataset,
        output_dir=args.output_dir,
        num_epochs=args.epochs,
        log_to_wandb=not args.no_wandb
    )
    
    # Clean up
    if trained_model is not None:
        del trained_model
    
    clean_gpu_memory()
    print("Training complete.")

if __name__ == "__main__":
    # Set environment variables for better memory management
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
    
    main()

usage: ipykernel_launcher.py [-h]
                             [--config {max_asymmetry,extreme_low_mem,balanced,production,auto}]
                             [--target_dim TARGET_DIM] [--epochs EPOCHS]
                             [--output_dir OUTPUT_DIR] [--no_wandb]
ipykernel_launcher.py: error: unrecognized arguments: --f=/home/joao/.local/share/jupyter/runtime/kernel-v380eaa32613c5cb54b8cb2e37731c87b559a132d2.json


SystemExit: 2