In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import wandb
import copy
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score, precision_score, recall_score, hamming_loss, roc_auc_score
import gc
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from sklearn.model_selection import train_test_split
from mamba_ssm import Mamba2
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class CrossAttentionBlock(nn.Module):
    """
    A simple cross-attention block with a feed-forward sub-layer.
    """
    def __init__(self, d_model, n_heads=8):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=d_model, 
            num_heads=n_heads, 
            batch_first=True
        )
        self.norm1 = nn.LayerNorm(d_model)
        
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, d_model)
        )
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x_q, x_kv):
        """
        Args:
            x_q  : (batch_size, seq_len_q, d_model)
            x_kv : (batch_size, seq_len_kv, d_model)
        """
        # Cross-attention
        attn_output, _ = self.cross_attn(query=x_q, key=x_kv, value=x_kv)
        x = self.norm1(x_q + attn_output)

        # Feed forward
        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)
        
        return x


class StarClassifierFusion(nn.Module):
    def __init__(
        self,
        d_model_spectra,
        d_model_gaia,
        num_classes,
        input_dim_spectra,
        input_dim_gaia,
        n_layers=6,
        use_cross_attention=True,
        n_cross_attn_heads=8,
        d_state=256,
        d_conv=4,
        expand=2,
    ):
        """
        Args:
            d_model_spectra (int): embedding dimension for the spectra MAMBA
            d_model_gaia (int): embedding dimension for the gaia MAMBA
            num_classes (int): multi-label classification
            input_dim_spectra (int): # of features for spectra
            input_dim_gaia (int): # of features for gaia
            n_layers (int): depth for each MAMBA
            use_cross_attention (bool): whether to use cross-attention
            n_cross_attn_heads (int): number of heads for cross-attention
        """
        super().__init__()



        # --- MAMBA 2 for spectra ---
        self.mamba_spectra = nn.Sequential(
            *[Mamba2(
                d_model=d_model_spectra,
                d_state=d_state,
                d_conv=d_conv,
                expand=expand,
            ) for _ in range(n_layers)]
        )
        self.input_proj_spectra = nn.Linear(input_dim_spectra, d_model_spectra)


        # --- MAMBA 2 for gaia ---
        self.mamba_gaia = nn.Sequential(
            *[Mamba2(
                d_model=d_model_gaia,
                d_state=d_state,
                d_conv=d_conv,
                expand=expand,
            ) for _ in range(n_layers)]
        )
        self.input_proj_gaia = nn.Linear(input_dim_gaia, d_model_gaia)

        # --- Cross Attention (Optional) ---
        self.use_cross_attention = use_cross_attention
        if use_cross_attention:
            # We'll do cross-attn in both directions or just one—here is an example with 2 blocks
            self.cross_attn_block_spectra = CrossAttentionBlock(d_model_spectra, n_heads=n_cross_attn_heads)
            self.cross_attn_block_gaia = CrossAttentionBlock(d_model_gaia, n_heads=n_cross_attn_heads)

        # --- Final Classifier ---
        # If you do late fusion by concatenation, the dimension is d_model_spectra + d_model_gaia
        # If you do average fusion, it is max(d_model_spectra, d_model_gaia) (or keep them separate).
        fusion_dim = d_model_spectra + d_model_gaia
        self.classifier = nn.Sequential(
            nn.LayerNorm(fusion_dim),
            nn.Linear(fusion_dim, num_classes)
        )
    
    def forward(self, x_spectra, x_gaia):
        """
        x_spectra : (batch_size, input_dim_spectra) or potentially (batch_size, seq_len_spectra, input_dim_spectra)
        x_gaia    : (batch_size, input_dim_gaia) or (batch_size, seq_len_gaia, input_dim_gaia)
        """
        # For MAMBA, we expect shape: (B, seq_len, d_model). 
        # If your input is just (B, d_in), we turn it into (B, 1, d_in).
        
        # --- Project to d_model and add sequence dimension (seq_len=1) ---
        x_spectra = self.input_proj_spectra(x_spectra)  # (B, d_model_spectra)
        x_spectra = x_spectra.unsqueeze(1)             # (B, 1, d_model_spectra)

        x_gaia = self.input_proj_gaia(x_gaia)          # (B, d_model_gaia)
        x_gaia = x_gaia.unsqueeze(1)                   # (B, 1, d_model_gaia)

        # --- MAMBA encoding (each modality separately) ---
        x_spectra = self.mamba_spectra(x_spectra)  # (B, 1, d_model_spectra)
        x_gaia = self.mamba_gaia(x_gaia)          # (B, 1, d_model_gaia)

        # Optionally, use cross-attention to fuse the representations
        if self.use_cross_attention:
            # Cross-attention from spectra -> gaia
            x_spectra_fused = self.cross_attn_block_spectra(x_spectra, x_gaia)
            # Cross-attention from gaia -> spectra
            x_gaia_fused = self.cross_attn_block_gaia(x_gaia, x_spectra)
            
            # Update x_spectra and x_gaia
            x_spectra = x_spectra_fused
            x_gaia = x_gaia_fused
        
        # --- Pool across sequence dimension (since our seq_len=1, just squeeze) ---
        x_spectra = x_spectra.mean(dim=1)  # (B, d_model_spectra)
        x_gaia = x_gaia.mean(dim=1)        # (B, d_model_gaia)

        # --- Late Fusion by Concatenation ---
        x_fused = torch.cat([x_spectra, x_gaia], dim=-1)  # (B, d_model_spectra + d_model_gaia)

        # --- Final classification ---
        logits = self.classifier(x_fused)  # (B, num_classes)
        return logits


In [3]:
def train_model_fusion(
    model,
    train_loader,
    val_loader,
    test_loader,
    num_epochs=100,
    lr=1e-4,
    max_patience=20,
    device='cuda'
):
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=int(max_patience / 5)
    )

    # We assume the datasets are MultiModalBalancedMultiLabelDataset
    # that returns (X_spectra, X_gaia, y).
    # You can keep the class weighting logic as in train_model_mamba.
    all_labels = []
    for _, _, y_batch in train_loader:
        all_labels.extend(y_batch.cpu().numpy())
    
    class_weights = calculate_class_weights(np.array(all_labels))
    class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights)
    
    best_val_loss = float('inf')
    patience = max_patience

    for epoch in range(num_epochs):
        # Resample training data
        train_loader.dataset.re_sample()

        # Recompute class weights if needed
        all_labels = []
        for _, _, y_batch in train_loader:
            all_labels.extend(y_batch.cpu().numpy())
        class_weights = calculate_class_weights(np.array(all_labels))
        class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
        criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights)

        # --- Training ---
        model.train()
        train_loss, train_acc = 0.0, 0.0
        for X_spc, X_ga, y_batch in train_loader:
            X_spc, X_ga, y_batch = X_spc.to(device), X_ga.to(device), y_batch.to(device)
            optimizer.zero_grad()
            outputs = model(X_spc, X_ga)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * X_spc.size(0)
            predicted = (torch.sigmoid(outputs) > 0.5).float()
            correct = (predicted == y_batch).float()
            train_acc += correct.mean(dim=1).mean().item()

        # --- Validation ---
        model.eval()
        val_loss, val_acc = 0.0, 0.0
        with torch.no_grad():
            for X_spc, X_ga, y_batch in val_loader:
                X_spc, X_ga, y_batch = X_spc.to(device), X_ga.to(device), y_batch.to(device)
                outputs = model(X_spc, X_ga)
                loss = criterion(outputs, y_batch)
                val_loss += loss.item() * X_spc.size(0)
                predicted = (torch.sigmoid(outputs) > 0.5).float()
                correct = (predicted == y_batch).float()
                val_acc += correct.mean(dim=1).mean().item()

        # --- Test metrics (optional or do after training) ---
        test_loss, test_acc = 0.0, 0.0
        y_true, y_pred = [], []
        with torch.no_grad():
            for X_spc, X_ga, y_batch in test_loader:
                X_spc, X_ga, y_batch = X_spc.to(device), X_ga.to(device), y_batch.to(device)
                outputs = model(X_spc, X_ga)
                loss = criterion(outputs, y_batch)
                test_loss += loss.item() * X_spc.size(0)
                
                predicted = (torch.sigmoid(outputs) > 0.5).float()
                correct = (predicted == y_batch).float()
                test_acc += correct.mean(dim=1).mean().item()

                y_true.extend(y_batch.cpu().numpy())
                y_pred.extend(predicted.cpu().numpy())

        # Compute multi-label metrics as before
        all_metrics = calculate_metrics(np.array(y_true), np.array(y_pred))

        # Logging example
        wandb.log({
            "epoch": epoch,
            "train_loss": train_loss / len(train_loader.dataset),
            "val_loss": val_loss / len(val_loader.dataset),
            "train_acc": train_acc / len(train_loader),
            "val_acc": val_acc / len(val_loader),
            "test_loss": test_loss / len(test_loader.dataset),
            "test_acc": test_acc / len(test_loader),
            **all_metrics
        })

        # Scheduler
        scheduler.step(val_loss / len(val_loader.dataset))

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience = max_patience
            best_model = model.state_dict()
        else:
            patience -= 1
            if patience <= 0:
                print("Early stopping triggered.")
                break

    model.load_state_dict(best_model)
    return model

class MultiModalBalancedMultiLabelDataset(Dataset):
    """
    A balanced multi-label dataset that returns (X_spectra, X_gaia, y).
    It uses the same balancing strategy as `BalancedMultiLabelDataset`.
    """
    def __init__(self, X_spectra, X_gaia, y, limit_per_label=201):
        """
        Args:
            X_spectra (torch.Tensor): [num_samples, num_spectra_features]
            X_gaia (torch.Tensor): [num_samples, num_gaia_features]
            y (torch.Tensor): [num_samples, num_classes], multi-hot labels
            limit_per_label (int): limit or target number of samples per label
        """
        self.X_spectra = X_spectra
        self.X_gaia = X_gaia
        self.y = y
        self.limit_per_label = limit_per_label
        self.num_classes = y.shape[1]
        self.indices = self.balance_classes()
        
    def balance_classes(self):
        indices = []
        class_counts = torch.sum(self.y, axis=0)
        for cls in range(self.num_classes):
            cls_indices = np.where(self.y[:, cls] == 1)[0]
            if len(cls_indices) < self.limit_per_label:
                if len(cls_indices) == 0:
                    # No samples for this class
                    continue
                extra_indices = np.random.choice(
                    cls_indices, self.limit_per_label - len(cls_indices), replace=True
                )
                cls_indices = np.concatenate([cls_indices, extra_indices])
            elif len(cls_indices) > self.limit_per_label:
                cls_indices = np.random.choice(cls_indices, self.limit_per_label, replace=False)
            indices.extend(cls_indices)
        indices = np.unique(indices)
        np.random.shuffle(indices)
        return indices

    def re_sample(self):
        self.indices = self.balance_classes()

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        index = self.indices[idx]
        return (
            self.X_spectra[index],  # spectra features
            self.X_gaia[index],     # gaia features
            self.y[index],          # multi-hot labels
        )
def calculate_class_weights(y):
    if y.ndim > 1:  
        class_counts = np.sum(y, axis=0)  
    else:
        class_counts = np.bincount(y)

    total_samples = y.shape[0] if y.ndim > 1 else len(y)
    class_counts = np.where(class_counts == 0, 1, class_counts)  # Prevent division by zero
    class_weights = total_samples / (len(class_counts) * class_counts)
    
    return class_weights
def calculate_metrics(y_true, y_pred):
    metrics = {
        "micro_f1": f1_score(y_true, y_pred, average='micro'),
        "macro_f1": f1_score(y_true, y_pred, average='macro'),
        "weighted_f1": f1_score(y_true, y_pred, average='weighted'),
        "micro_precision": precision_score(y_true, y_pred, average='micro', zero_division=1),
        "macro_precision": precision_score(y_true, y_pred, average='macro', zero_division=1),
        "weighted_precision": precision_score(y_true, y_pred, average='weighted', zero_division=1),
        "micro_recall": recall_score(y_true, y_pred, average='micro'),
        "macro_recall": recall_score(y_true, y_pred, average='macro'),
        "weighted_recall": recall_score(y_true, y_pred, average='weighted'),
        "hamming_loss": hamming_loss(y_true, y_pred)
    }
    
    # Check if there are at least two classes present in y_true
    #if len(np.unique(y_true)) > 1:
        #metrics["roc_auc"] = roc_auc_score(y_true, y_pred, average='macro', multi_class='ovr')
    #else:
       # metrics["roc_auc"] = None  # or you can set it to a default value or message
    
    return metrics


In [4]:
batch_size = 128
batch_limit = int(batch_size / 2.5)


# Load datasets
import pickle
# Open them in a cross-platform way
with open("Pickles/Updated_List_of_Classes_ubuntu.pkl", "rb") as f:
    classes = pickle.load(f)  # This reads the actual data
with open("Pickles/train_data_transformed_ubuntu.pkl", "rb") as f:
    X_train_full = pickle.load(f)
with open("Pickles/test_data_transformed_ubuntu.pkl", "rb") as f:
    X_test_full = pickle.load(f)




# Extract labels
y_train_full = X_train_full[classes]
y_test = X_test_full[classes]

# Drop labels from both datasets
X_train_full.drop(classes, axis=1, inplace=True)
X_test_full.drop(classes, axis=1, inplace=True)


# Columns for spectral data (assuming all remaining columns after removing Gaia are spectra)
gaia_columns = ["parallax", "ra", "dec", "ra_error", "dec_error", "parallax_error", "pmra", "pmdec", 
                "pmra_error", "pmdec_error", "phot_g_mean_flux", "flagnopllx", "phot_g_mean_flux_error", 
                "phot_bp_mean_flux", "phot_rp_mean_flux", "phot_bp_mean_flux_error", "phot_rp_mean_flux_error", 
                "flagnoflux"]

# Spectra data (everything that is not Gaia-related) and the column 'otype'
X_train_spectra = X_train_full.drop(columns={"otype", "obsid", *gaia_columns})
X_test_spectra = X_test_full.drop(columns={"otype", "obsid", *gaia_columns})

# Gaia data (only the selected columns)
X_train_gaia = X_train_full[gaia_columns]
X_test_gaia = X_test_full[gaia_columns]

# Count nans and infs in x_train_gaia
print(X_train_gaia.isnull().sum())
print(X_train_gaia.isin([np.inf, -np.inf]).sum())


# Free up memory
del X_train_full, X_test_full
gc.collect()



# Split training set into training and validation
X_train_spectra, X_val_spectra, X_train_gaia, X_val_gaia, y_train, y_val = train_test_split(
    X_train_spectra, X_train_gaia, y_train_full, test_size=0.2, random_state=42
)

# Free memory
del y_train_full
gc.collect()



# Convert spectra and Gaia data into PyTorch tensors
X_train_spectra = torch.tensor(X_train_spectra.values, dtype=torch.float32)
X_val_spectra = torch.tensor(X_val_spectra.values, dtype=torch.float32)
X_test_spectra = torch.tensor(X_test_spectra.values, dtype=torch.float32)



X_train_gaia = torch.tensor(X_train_gaia.values, dtype=torch.float32)
X_val_gaia = torch.tensor(X_val_gaia.values, dtype=torch.float32)
X_test_gaia = torch.tensor(X_test_gaia.values, dtype=torch.float32)

y_train = torch.tensor(y_train.values, dtype=torch.float32)
y_val = torch.tensor(y_val.values, dtype=torch.float32)
y_test = torch.tensor(y_test.values, dtype=torch.float32)

# Print dataset shapes
print(f"X_train_spectra shape: {X_train_spectra.shape}")
print(f"X_val_spectra shape: {X_val_spectra.shape}")
print(f"X_test_spectra shape: {X_test_spectra.shape}")

print(f"X_train_gaia shape: {X_train_gaia.shape}")
print(f"X_val_gaia shape: {X_val_gaia.shape}")
print(f"X_test_gaia shape: {X_test_gaia.shape}")

print(f"y_train shape: {y_train.shape}")
print(f"y_val shape: {y_val.shape}")
print(f"y_test shape: {y_test.shape}")


train_dataset = MultiModalBalancedMultiLabelDataset(X_train_spectra, X_train_gaia, y_train, limit_per_label=batch_limit)
val_dataset = MultiModalBalancedMultiLabelDataset(X_val_spectra, X_val_gaia, y_val, limit_per_label=batch_limit)
test_dataset = MultiModalBalancedMultiLabelDataset(X_test_spectra, X_test_gaia, y_test, limit_per_label=batch_limit)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# print the number of samples in each dataset
print(f"Train dataset: {len(train_dataset)} samples")
print(f"Validation dataset: {len(val_dataset)} samples")
print(f"Test dataset: {len(test_dataset)} samples")


parallax                   0
ra                         0
dec                        0
ra_error                   0
dec_error                  0
parallax_error             0
pmra                       0
pmdec                      0
pmra_error                 0
pmdec_error                0
phot_g_mean_flux           0
flagnopllx                 0
phot_g_mean_flux_error     0
phot_bp_mean_flux          0
phot_rp_mean_flux          0
phot_bp_mean_flux_error    0
phot_rp_mean_flux_error    0
flagnoflux                 0
dtype: int64
parallax                   0
ra                         0
dec                        0
ra_error                   0
dec_error                  0
parallax_error             0
pmra                       0
pmdec                      0
pmra_error                 0
pmdec_error                0
phot_g_mean_flux           0
flagnopllx                 0
phot_g_mean_flux_error     0
phot_bp_mean_flux          0
phot_rp_mean_flux          0
phot_bp_mean_flux_error    0
p

In [5]:
class EnsembleStarClassifier:
    """
    Ensemble of StarClassifierFusion models with uncertainty quantification.
    """
    def __init__(
        self,
        model_class,
        model_args,
        num_models=5,
        device='cuda'
    ):
        """
        Args:
            model_class: The model class to use (StarClassifierFusion)
            model_args: Dictionary of arguments to pass to the model constructor
            num_models: Number of models in the ensemble
            device: Device to use for computation
        """
        self.num_models = num_models
        self.device = device
        self.models = []
        
        # Initialize models with different random initializations
        for i in range(num_models):
            model = model_class(**model_args)
            model.to(device)
            self.models.append(model)
            
    def train(
        self,
        train_loader,
        val_loader,
        test_loader,
        train_function,
        train_args,
        bootstrap=True,
        random_seed_offset=0
    ):
        """
        Train each model in the ensemble.
        
        Args:
            train_loader: DataLoader for training data
            val_loader: DataLoader for validation data
            test_loader: DataLoader for test data
            train_function: Function to train a single model
            train_args: Dictionary of arguments to pass to train_function
            bootstrap: Whether to use bootstrapping for training
            random_seed_offset: Offset for random seeds
        """
        trained_models = []
        
        for i in range(self.num_models):
            print(f"Training model {i+1}/{self.num_models}")
            
            # Set different random seed for each model
            seed = random_seed_offset + i
            torch.manual_seed(seed)
            np.random.seed(seed)
            
            if bootstrap:
                # Create bootstrapped dataset
                bootstrap_train_loader = self._create_bootstrap_loader(train_loader)
                curr_train_loader = bootstrap_train_loader
            else:
                curr_train_loader = train_loader
            
            # Initialize a new model for this ensemble member
            model = copy.deepcopy(self.models[i])
            
            # Create a new wandb run for this model
            run_name = f"ensemble_member_{i+1}"
            wandb.init(project="ALLSTARS_ensemble", name=run_name, group="ensemble_training", reinit=True)
            
            # Log ensemble member info
            wandb.config.update({
                "ensemble_member": i+1,
                "num_models": self.num_models,
                "bootstrap": bootstrap,
                "random_seed": seed
            })
            
            # Train the model
            trained_model = train_function(
                model=model,
                train_loader=curr_train_loader,
                val_loader=val_loader,
                test_loader=test_loader,
                **train_args
            )
            
            # Save the trained model
            trained_models.append(trained_model)
            
            # Save model checkpoint
            torch.save(trained_model.state_dict(), f"ensemble_model_{i+1}.pth")
            
            # Finish wandb run
            wandb.finish()
        
        self.models = trained_models
        return trained_models
    
    def _create_bootstrap_loader(self, dataloader):
        """
        Create a bootstrapped version of a dataloader.
        
        Args:
            dataloader: Original DataLoader
            
        Returns:
            DataLoader with bootstrapped samples
        """
        dataset = dataloader.dataset
        n_samples = len(dataset)
        
        # Generate bootstrap indices (sampling with replacement)
        bootstrap_indices = np.random.choice(n_samples, size=n_samples, replace=True)
        
        # Create a subset dataset with the bootstrapped indices
        bootstrap_dataset = torch.utils.data.Subset(dataset, bootstrap_indices)
        
        # Create a new dataloader with the bootstrapped dataset
        bootstrap_loader = DataLoader(
            bootstrap_dataset,
            batch_size=dataloader.batch_size,
            shuffle=True,
            num_workers=dataloader.num_workers if hasattr(dataloader, 'num_workers') else 0
        )
        
        return bootstrap_loader
    
    def predict(self, X_spectra, X_gaia, return_individual=False):
        """
        Generate predictions from the ensemble.
        
        Args:
            X_spectra: Spectral features tensor
            X_gaia: Gaia features tensor
            return_individual: Whether to return individual model predictions
            
        Returns:
            mean_probs: Mean probabilities across ensemble
            std_probs: Standard deviation of probabilities (uncertainty)
            individual_probs: Individual model probabilities (if return_individual=True)
        """
        # Ensure inputs are on the correct device
        X_spectra = X_spectra.to(self.device)
        X_gaia = X_gaia.to(self.device)
        
        all_probs = []
        
        # Get predictions from each model
        for model in self.models:
            model.eval()
            with torch.no_grad():
                logits = model(X_spectra, X_gaia)
                probs = torch.sigmoid(logits)
                all_probs.append(probs.cpu().numpy())
        
        # Stack predictions
        all_probs = np.stack(all_probs)
        
        # Calculate mean and standard deviation
        mean_probs = np.mean(all_probs, axis=0)
        std_probs = np.std(all_probs, axis=0)
        
        if return_individual:
            return mean_probs, std_probs, all_probs
        else:
            return mean_probs, std_probs
    
    def evaluate(self, test_loader, threshold=0.5, return_predictions=False):
        """
        Evaluate the ensemble on a test set.
        
        Args:
            test_loader: DataLoader for test data
            threshold: Classification threshold
            return_predictions: Whether to return predictions
            
        Returns:
            metrics: Dictionary of evaluation metrics
            mean_probs: Mean probabilities (if return_predictions=True)
            std_probs: Standard deviation of probabilities (if return_predictions=True)
            y_true: True labels (if return_predictions=True)
        """
        all_mean_probs = []
        all_std_probs = []
        all_y_true = []
        
        # Generate predictions for each batch
        for X_spectra, X_gaia, y_batch in test_loader:
            X_spectra, X_gaia = X_spectra.to(self.device), X_gaia.to(self.device)
            
            # Get ensemble predictions
            mean_probs, std_probs = self.predict(X_spectra, X_gaia)
            
            all_mean_probs.extend(mean_probs)
            all_std_probs.extend(std_probs)
            all_y_true.extend(y_batch.cpu().numpy())
        
        # Convert to numpy arrays
        mean_probs = np.array(all_mean_probs)
        std_probs = np.array(all_std_probs)
        y_true = np.array(all_y_true)
        
        # Make binary predictions
        y_pred = (mean_probs > threshold).astype(float)
        
        # Calculate metrics
        metrics = {
            "micro_f1": f1_score(y_true, y_pred, average='micro'),
            "macro_f1": f1_score(y_true, y_pred, average='macro'),
            "weighted_f1": f1_score(y_true, y_pred, average='weighted'),
            "micro_precision": precision_score(y_true, y_pred, average='micro', zero_division=1),
            "macro_precision": precision_score(y_true, y_pred, average='macro', zero_division=1),
            "weighted_precision": precision_score(y_true, y_pred, average='weighted', zero_division=1),
            "micro_recall": recall_score(y_true, y_pred, average='micro'),
            "macro_recall": recall_score(y_true, y_pred, average='macro'),
            "weighted_recall": recall_score(y_true, y_pred, average='weighted'),
            "hamming_loss": hamming_loss(y_true, y_pred),
            "mean_uncertainty": np.mean(std_probs),
            "median_uncertainty": np.median(std_probs),
            "max_uncertainty": np.max(std_probs)
        }
        
        # Try to calculate ROC AUC if possible
        try:
            metrics["roc_auc"] = roc_auc_score(y_true, mean_probs, average='macro', multi_class='ovr')
        except:
            metrics["roc_auc"] = None
        
        if return_predictions:
            return metrics, mean_probs, std_probs, y_true
        else:
            return metrics
    
    def visualize_uncertainty(self, mean_probs, std_probs, y_true, num_classes=10, class_names=None):
        """
        Visualize uncertainty for selected classes.
        
        Args:
            mean_probs: Mean probabilities from ensemble
            std_probs: Standard deviation of probabilities (uncertainty)
            y_true: True labels
            num_classes: Number of classes to visualize
            class_names: List of class names
        """
        n_classes = y_true.shape[1]
        
        # Select a subset of classes to visualize
        classes_to_plot = np.random.choice(n_classes, min(num_classes, n_classes), replace=False)
        
        # Create figure
        fig, axes = plt.subplots(len(classes_to_plot), 1, figsize=(10, 3 * len(classes_to_plot)))
        
        if len(classes_to_plot) == 1:
            axes = [axes]
        
        for i, class_idx in enumerate(classes_to_plot):
            ax = axes[i]
            
            # Get probabilities, uncertainties, and true labels for this class
            probs = mean_probs[:, class_idx]
            uncertainties = std_probs[:, class_idx]
            true_labels = y_true[:, class_idx]
            
            # Create scatter plot
            scatter = ax.scatter(probs, uncertainties, c=true_labels, cmap='coolwarm', alpha=0.6)
            
            # Add colorbar
            cbar = plt.colorbar(scatter, ax=ax)
            cbar.set_label('True Label')
            
            # Set class label
            if class_names is not None:
                class_label = class_names[class_idx]
            else:
                class_label = f"Class {class_idx}"
            
            ax.set_xlabel('Predicted Probability')
            ax.set_ylabel('Uncertainty (Std. Dev.)')
            ax.set_title(f'Uncertainty vs. Prediction for {class_label}')
            ax.grid(True, alpha=0.3)
            
            # Add threshold line
            ax.axvline(x=0.5, color='gray', linestyle='--', alpha=0.7)
        
        plt.tight_layout()
        return fig
    
    def analyze_errors(self, mean_probs, std_probs, y_true, threshold=0.5):
        """
        Analyze relationship between prediction errors and uncertainty.
        
        Args:
            mean_probs: Mean probabilities from ensemble
            std_probs: Standard deviation of probabilities (uncertainty)
            y_true: True labels
            threshold: Classification threshold
            
        Returns:
            fig: Matplotlib figure
        """
        # Make binary predictions
        y_pred = (mean_probs > threshold).astype(float)
        
        # Calculate error
        errors = np.abs(y_true - mean_probs)
        
        # Flatten arrays
        flat_errors = errors.flatten()
        flat_uncertainty = std_probs.flatten()
        
        # Create bins for uncertainty
        n_bins = 20
        bins = np.linspace(np.min(flat_uncertainty), np.max(flat_uncertainty), n_bins+1)
        bin_indices = np.digitize(flat_uncertainty, bins) - 1
        
        # Calculate mean error for each bin
        bin_mean_errors = np.zeros(n_bins)
        bin_counts = np.zeros(n_bins)
        
        for i in range(len(flat_errors)):
            bin_idx = bin_indices[i]
            if bin_idx >= 0 and bin_idx < n_bins:
                bin_mean_errors[bin_idx] += flat_errors[i]
                bin_counts[bin_idx] += 1
        
        # Avoid division by zero
        valid_bins = bin_counts > 0
        bin_mean_errors[valid_bins] /= bin_counts[valid_bins]
        
        # Create figure
        fig, ax = plt.subplots(figsize=(10, 6))
        
        # Plot mean error vs. uncertainty
        bin_centers = (bins[:-1] + bins[1:]) / 2
        ax.plot(bin_centers, bin_mean_errors, 'o-', markersize=8)
        
        # Fit linear regression
        valid_x = bin_centers[valid_bins]
        valid_y = bin_mean_errors[valid_bins]
        
        if len(valid_x) > 1:
            from sklearn.linear_model import LinearRegression
            reg = LinearRegression().fit(valid_x.reshape(-1, 1), valid_y)
            x_range = np.linspace(np.min(valid_x), np.max(valid_x), 100)
            y_pred = reg.predict(x_range.reshape(-1, 1))
            ax.plot(x_range, y_pred, 'r--', linewidth=2, 
                    label=f'Slope: {reg.coef_[0]:.4f}, R²: {reg.score(valid_x.reshape(-1, 1), valid_y):.4f}')
            ax.legend()
        
        ax.set_xlabel('Uncertainty (Std. Dev.)')
        ax.set_ylabel('Mean Absolute Error')
        ax.set_title('Relationship Between Uncertainty and Prediction Error')
        ax.grid(True, alpha=0.3)
        
        return fig
    
    def calibration_curve(self, mean_probs, std_probs, y_true, n_bins=10):
        """
        Plot calibration curve to analyze if predicted probabilities match observed frequencies.
        
        Args:
            mean_probs: Mean probabilities from ensemble
            std_probs: Standard deviation of probabilities (uncertainty)
            y_true: True labels
            n_bins: Number of bins for calibration curve
            
        Returns:
            fig: Matplotlib figure
        """
        # Flatten arrays
        flat_probs = mean_probs.flatten()
        flat_true = y_true.flatten()
        flat_uncertainty = std_probs.flatten()
        
        # Create bins for probabilities
        bins = np.linspace(0, 1, n_bins+1)
        bin_indices = np.digitize(flat_probs, bins) - 1
        
        # Calculate observed frequency and mean predicted probability for each bin
        bin_obs_freq = np.zeros(n_bins)
        bin_pred_prob = np.zeros(n_bins)
        bin_uncertainty = np.zeros(n_bins)
        bin_counts = np.zeros(n_bins)
        
        for i in range(len(flat_probs)):
            bin_idx = bin_indices[i]
            if bin_idx >= 0 and bin_idx < n_bins:
                bin_obs_freq[bin_idx] += flat_true[i]
                bin_pred_prob[bin_idx] += flat_probs[i]
                bin_uncertainty[bin_idx] += flat_uncertainty[i]
                bin_counts[bin_idx] += 1
        
        # Avoid division by zero
        valid_bins = bin_counts > 0
        bin_obs_freq[valid_bins] /= bin_counts[valid_bins]
        bin_pred_prob[valid_bins] /= bin_counts[valid_bins]
        bin_uncertainty[valid_bins] /= bin_counts[valid_bins]
        
        # Create figure
        fig, ax = plt.subplots(figsize=(10, 8))
        
        # Plot calibration curve
        bin_centers = (bins[:-1] + bins[1:]) / 2
        ax.plot(bin_centers, bin_obs_freq, 'o-', markersize=8, label='Calibration Curve')
        
        # Plot perfect calibration
        ax.plot([0, 1], [0, 1], 'k--', label='Perfect Calibration')
        
        # Plot uncertainties
        ax2 = ax.twinx()
        ax2.bar(bin_centers, bin_uncertainty, alpha=0.2, width=1/n_bins, color='r', label='Mean Uncertainty')
        ax2.set_ylabel('Mean Uncertainty (Std. Dev.)', color='r')
        ax2.tick_params(axis='y', labelcolor='r')
        
        # Add labels
        ax.set_xlabel('Mean Predicted Probability')
        ax.set_ylabel('Observed Frequency')
        ax.set_title('Calibration Curve with Uncertainty')
        ax.grid(True, alpha=0.3)
        
        # Add legends
        lines1, labels1 = ax.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax.legend(lines1 + lines2, labels1 + labels2, loc='upper left')
        
        return fig
    
    def uncertainty_threshold(self, mean_probs, std_probs, y_true, threshold=0.5, uncertainty_percentiles=None):
        """
        Analyze performance at different uncertainty thresholds.
        
        Args:
            mean_probs: Mean probabilities from ensemble
            std_probs: Standard deviation of probabilities (uncertainty)
            y_true: True labels
            threshold: Classification threshold
            uncertainty_percentiles: List of uncertainty percentiles to evaluate
            
        Returns:
            fig: Matplotlib figure
            metrics: Dictionary of metrics at different uncertainty thresholds
        """
        if uncertainty_percentiles is None:
            uncertainty_percentiles = [0, 25, 50, 75, 90, 95]
        
        # Flatten arrays for uncertainty analysis
        flat_uncertainty = std_probs.flatten()
        
        # Calculate uncertainty thresholds
        uncertainty_thresholds = [np.percentile(flat_uncertainty, p) for p in uncertainty_percentiles]
        
        # Calculate metrics at different uncertainty thresholds
        metrics = []
        coverage = []
        
        for unc_thresh in uncertainty_thresholds:
            # Create mask for samples below uncertainty threshold
            mask = np.max(std_probs, axis=1) <= unc_thresh
            
            # Skip if no samples meet the criteria
            if np.sum(mask) == 0:
                metrics.append(None)
                coverage.append(0)
                continue
            
            # Filter predictions and true labels
            filtered_probs = mean_probs[mask]
            filtered_true = y_true[mask]
            
            # Make binary predictions
            filtered_pred = (filtered_probs > threshold).astype(float)
            
            # Calculate metrics
            current_metrics = {
                "micro_f1": f1_score(filtered_true, filtered_pred, average='micro'),
                "macro_f1": f1_score(filtered_true, filtered_pred, average='macro'),
                "weighted_f1": f1_score(filtered_true, filtered_pred, average='weighted'),
                "hamming_loss": hamming_loss(filtered_true, filtered_pred),
            }
            
            metrics.append(current_metrics)
            coverage.append(np.mean(mask))
        
        # Create figure
        fig, ax1 = plt.subplots(figsize=(12, 6))
        
        # Plot F1 score
        f1_scores = [m["micro_f1"] if m is not None else 0 for m in metrics]
        ax1.plot(uncertainty_percentiles, f1_scores, 'bo-', label='Micro F1 Score')
        
        # Plot macro F1 score
        macro_f1_scores = [m["macro_f1"] if m is not None else 0 for m in metrics]
        ax1.plot(uncertainty_percentiles, macro_f1_scores, 'go-', label='Macro F1 Score')
        
        # Plot coverage
        ax2 = ax1.twinx()
        ax2.plot(uncertainty_percentiles, coverage, 'r--', label='Data Coverage')
        ax2.set_ylabel('Data Coverage', color='r')
        ax2.tick_params(axis='y', labelcolor='r')
        
        # Add labels
        ax1.set_xlabel('Uncertainty Percentile Threshold')
        ax1.set_ylabel('F1 Score')
        ax1.set_title('Performance vs. Uncertainty Threshold')
        ax1.grid(True, alpha=0.3)
        
        # Add legend
        lines1, labels1 = ax1.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax1.legend(lines1 + lines2, labels1 + labels2, loc='center right')
        
        return fig, {"metrics": metrics, "coverage": coverage, "percentiles": uncertainty_percentiles}
    
    def selective_prediction(self, mean_probs, std_probs, y_true, threshold=0.5):
        """
        Perform selective prediction analysis.
        
        Args:
            mean_probs: Mean probabilities from ensemble
            std_probs: Standard deviation of probabilities (uncertainty)
            y_true: True labels
            threshold: Classification threshold
            
        Returns:
            fig: Matplotlib figure
        """
        # Calculate max uncertainty for each sample
        max_uncertainties = np.max(std_probs, axis=1)
        
        # Sort samples by uncertainty
        sorted_indices = np.argsort(max_uncertainties)
        
        # Initialize lists for storing results
        coverages = []
        f1_scores = []
        
        # Calculate metrics at different coverage levels
        coverage_steps = np.linspace(0.1, 1.0, 10)
        
        for coverage in coverage_steps:
            # Select top-k% most certain predictions
            k = int(len(sorted_indices) * coverage)
            selected_indices = sorted_indices[:k]
            
            # Filter predictions and true labels
            selected_probs = mean_probs[selected_indices]
            selected_true = y_true[selected_indices]
            
            # Make binary predictions
            selected_pred = (selected_probs > threshold).astype(float)
            
            # Calculate F1 score
            f1 = f1_score(selected_true, selected_pred, average='micro')
            
            coverages.append(coverage)
            f1_scores.append(f1)
        
        # Create figure
        fig, ax = plt.subplots(figsize=(10, 6))
        
        # Plot F1 score vs. coverage
        ax.plot(coverages, f1_scores, 'bo-', markersize=8)
        
        # Add area under curve
        ax.fill_between(coverages, 0, f1_scores, alpha=0.2)
        
        # Add labels
        ax.set_xlabel('Coverage (Fraction of Data)')
        ax.set_ylabel('Micro F1 Score')
        ax.set_title('Selective Prediction: Performance vs. Coverage')
        ax.grid(True, alpha=0.3)
        
        # Add area under the curve value
        auc = np.trapz(f1_scores, coverages)
        ax.text(0.05, 0.95, f'AUC: {auc:.4f}', transform=ax.transAxes, 
                fontsize=12, verticalalignment='top', bbox=dict(boxstyle='round', alpha=0.1))
        
        return fig

# Function to run ensemble training and evaluation
def train_and_evaluate_ensemble(
    model_class,
    model_args,
    train_loader,
    val_loader,
    test_loader,
    train_function,
    train_args,
    num_models=5,
    bootstrap=True,
    device='cuda',
    class_names=None
):
    """
    Train and evaluate an ensemble model.
    
    Args:
        model_class: The model class to use
        model_args: Arguments for model initialization
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        test_loader: DataLoader for test data
        train_function: Function to train a single model
        train_args: Arguments for the training function
        num_models: Number of models in the ensemble
        bootstrap: Whether to use bootstrapping for training
        device: Device to use for computation
        class_names: Names of the classes (optional)
        
    Returns:
        ensemble: Trained ensemble model
        metrics: Evaluation metrics
        figures: Dictionary of visualization figures
    """
    # Initialize wandb for the ensemble experiment
    wandb.init(project="ALLSTARS_ensemble", name="ensemble_experiment", reinit=True)
    
    # Log ensemble configuration
    wandb.config.update({
        "num_models": num_models,
        "bootstrap": bootstrap,
        **model_args,
        **train_args
    })
    
    # Initialize ensemble
    ensemble = EnsembleStarClassifier(
        model_class=model_class,
        model_args=model_args,
        num_models=num_models,
        device=device
    )
    
    # Train ensemble
    ensemble.train(
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        train_function=train_function,
        train_args=train_args,
        bootstrap=bootstrap
    )
    
    # Evaluate ensemble
    metrics, mean_probs, std_probs, y_true = ensemble.evaluate(
        test_loader=test_loader,
        return_predictions=True
    )
    
    # Log evaluation metrics
    wandb.log(metrics)
    
    # Create visualizations
    figures = {}
    
    # Uncertainty visualization
    uncertainty_fig = ensemble.visualize_uncertainty(
        mean_probs=mean_probs,
        std_probs=std_probs,
        y_true=y_true,
        num_classes=10,
        class_names=class_names
    )
    figures['uncertainty'] = uncertainty_fig
    
    # Error analysis
    error_fig = ensemble.analyze_errors(
        mean_probs=mean_probs,
        std_probs=std_probs,
        y_true=y_true
    )
    figures['error_analysis'] = error_fig
    
    # Calibration curve
    calibration_fig = ensemble.calibration_curve(
        mean_probs=mean_probs,
        std_probs=std_probs,
        y_true=y_true
    )
    figures['calibration'] = calibration_fig
    
    # Uncertainty threshold analysis
    threshold_fig, threshold_metrics = ensemble.uncertainty_threshold(
        mean_probs=mean_probs,
        std_probs=std_probs,
        y_true=y_true
    )
    figures['threshold_analysis'] = threshold_fig
    
    # Selective prediction
    selective_fig = ensemble.selective_prediction(
        mean_probs=mean_probs,
        std_probs=std_probs,
        y_true=y_true
    )
    figures['selective_prediction'] = selective_fig
    
    # Log figures
    for name, fig in figures.items():
        wandb.log({name: wandb.Image(fig)})
        plt.close(fig)
    
    # Finish wandb run
    wandb.finish()
    
    return ensemble, metrics, figures

# Claude implementation

In [12]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from tqdm import tqdm

class DeepEnsemble:
    """
    Deep Ensemble implementation for uncertainty quantification in multi-label classification.
    Trains multiple instances of the same model with different random initializations.
    """
    def __init__(
        self, 
        model_class, 
        model_args, 
        num_models=5, 
        device='cuda'
    ):
        """
        Initialize the deep ensemble.
        
        Args:
            model_class: The model class to instantiate (e.g., StarClassifierFusion)
            model_args: Dictionary of arguments to pass to the model constructor
            num_models: Number of models in the ensemble
            device: Device to run the models on ('cuda' or 'cpu')
        """
        self.model_class = model_class
        self.model_args = model_args
        self.num_models = num_models
        self.device = device
        self.models = []
        
        # Initialize models with different random seeds
        for i in range(num_models):
            # Set seed for reproducibility but different for each model
            torch.manual_seed(42 + i)
            np.random.seed(42 + i)
            
            # Create model instance
            model = model_class(**model_args).to(device)
            self.models.append(model)
    
    def train(
        self, 
        train_loader, 
        val_loader, 
        test_loader=None, 
        num_epochs=100, 
        lr=1e-4, 
        max_patience=20,
        scheduler_type='OneCycleLR',
        log_to_wandb=True
    ):
        """
        Train all models in the ensemble.
        
        Args:
            train_loader: DataLoader for training data
            val_loader: DataLoader for validation data
            test_loader: DataLoader for test data (optional)
            num_epochs: Maximum number of epochs to train
            lr: Learning rate
            max_patience: Maximum patience for early stopping
            scheduler_type: Type of learning rate scheduler ('OneCycleLR' or 'ReduceLROnPlateau')
            log_to_wandb: Whether to log training progress to wandb
            
        Returns:
            List of trained models
        """
        import wandb

        for model_idx, model in enumerate(self.models):
            print(f"\n----- Training Ensemble Model {model_idx+1}/{self.num_models} -----\n")
            
            # Initialize a new wandb run for each model
            if log_to_wandb:
                run = wandb.init(
                    project="ALLSTARS_ensemble", 
                    name=f"model_{model_idx}",
                    group="ensemble_training",
                    config={
                        **self.model_args,
                        "model_idx": model_idx,
                        "num_models": self.num_models,
                        "lr": lr,
                        "max_patience": max_patience,
                        "scheduler_type": scheduler_type,
                        "num_epochs": num_epochs
                    },
                    reinit=True
                )
            
            optimizer = optim.AdamW(model.parameters(), lr=lr)
            
            # Configure the scheduler
            if scheduler_type == 'OneCycleLR':
                scheduler = optim.lr_scheduler.OneCycleLR(
                    optimizer, 
                    max_lr=lr,
                    epochs=num_epochs, 
                    steps_per_epoch=len(train_loader)
                )
            else:  # ReduceLROnPlateau
                scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                    optimizer, 
                    mode='min', 
                    factor=0.5, 
                    patience=int(max_patience / 5)
                )
            
            # We assume the datasets are MultiModalBalancedMultiLabelDataset
            # that returns (X_spectra, X_gaia, y)
            all_labels = []
            for _, _, y_batch in train_loader:
                all_labels.extend(y_batch.cpu().numpy())
            
            # Calculate class weights to handle imbalanced classes
            class_weights = self._calculate_class_weights(np.array(all_labels))
            class_weights = torch.tensor(class_weights, dtype=torch.float).to(self.device)
            criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights)
            
            best_val_loss = float('inf')
            patience = max_patience
            best_model_state = None

            # Training loop
            for epoch in range(num_epochs):
                # Resample training data for balanced batches
                train_loader.dataset.re_sample()

                # Recompute class weights based on the new sampling
                all_labels = []
                for _, _, y_batch in train_loader:
                    all_labels.extend(y_batch.cpu().numpy())
                class_weights = self._calculate_class_weights(np.array(all_labels))
                class_weights = torch.tensor(class_weights, dtype=torch.float).to(self.device)
                criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights)

                # --- Training Phase ---
                model.train()
                train_loss, train_acc = 0.0, 0.0
                for X_spc, X_ga, y_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):
                    X_spc, X_ga, y_batch = X_spc.to(self.device), X_ga.to(self.device), y_batch.to(self.device)
                    optimizer.zero_grad()
                    outputs = model(X_spc, X_ga)
                    loss = criterion(outputs, y_batch)
                    loss.backward()
                    optimizer.step()

                    train_loss += loss.item() * X_spc.size(0)
                    predicted = (torch.sigmoid(outputs) > 0.5).float()
                    correct = (predicted == y_batch).float()
                    train_acc += correct.mean(dim=1).sum().item()
                
                train_loss /= len(train_loader.dataset)
                train_acc /= len(train_loader.dataset)

                # --- Validation Phase ---
                model.eval()
                val_loss, val_acc = 0.0, 0.0
                with torch.no_grad():
                    for X_spc, X_ga, y_batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation"):
                        X_spc, X_ga, y_batch = X_spc.to(self.device), X_ga.to(self.device), y_batch.to(self.device)
                        outputs = model(X_spc, X_ga)
                        loss = criterion(outputs, y_batch)
                        val_loss += loss.item() * X_spc.size(0)
                        predicted = (torch.sigmoid(outputs) > 0.5).float()
                        correct = (predicted == y_batch).float()
                        val_acc += correct.mean(dim=1).sum().item()
                
                val_loss /= len(val_loader.dataset)
                val_acc /= len(val_loader.dataset)

                # --- Test Phase (if provided) ---
                test_metrics = {}
                if test_loader is not None:
                    test_loss, test_acc = 0.0, 0.0
                    y_true, y_pred = [], []
                    with torch.no_grad():
                        for X_spc, X_ga, y_batch in tqdm(test_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Testing"):
                            X_spc, X_ga, y_batch = X_spc.to(self.device), X_ga.to(self.device), y_batch.to(self.device)
                            outputs = model(X_spc, X_ga)
                            loss = criterion(outputs, y_batch)
                            test_loss += loss.item() * X_spc.size(0)
                            
                            predicted = (torch.sigmoid(outputs) > 0.5).float()
                            correct = (predicted == y_batch).float()
                            test_acc += correct.mean(dim=1).sum().item()

                            y_true.extend(y_batch.cpu().numpy())
                            y_pred.extend(predicted.cpu().numpy())

                    test_loss /= len(test_loader.dataset)
                    test_acc /= len(test_loader.dataset)
                    test_metrics = self._calculate_metrics(np.array(y_true), np.array(y_pred))
                    test_metrics.update({
                        "test_loss": test_loss,
                        "test_acc": test_acc,
                    })

                # Logging
                if log_to_wandb:
                    log_data = {
                        "epoch": epoch,
                        "train_loss": train_loss,
                        "val_loss": val_loss,
                        "train_acc": train_acc,
                        "val_acc": val_acc,
                        "lr": self._get_lr(optimizer)
                    }
                    log_data.update(test_metrics)
                    wandb.log(log_data)

                # Update learning rate scheduler
                if scheduler_type == 'OneCycleLR':
                    scheduler.step()
                else:  # ReduceLROnPlateau
                    scheduler.step(val_loss)

                # Early stopping
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    patience = max_patience
                    best_model_state = model.state_dict().copy()
                    if log_to_wandb:
                        wandb.run.summary["best_val_loss"] = best_val_loss
                else:
                    patience -= 1
                    if patience <= 0:
                        print("Early stopping triggered.")
                        break

                print(f"Epoch {epoch+1}/{num_epochs} - "
                      f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
                      f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

            # Load the best model state
            if best_model_state is not None:
                model.load_state_dict(best_model_state)
            
            # Close wandb run
            if log_to_wandb:
                wandb.finish()

        return self.models
    
    def predict(self, loader, return_individual=False):
        """
        Make predictions with the ensemble.
        
        Args:
            loader: DataLoader for the data to predict
            return_individual: Whether to return predictions from individual models
            
        Returns:
            mean_probs: Mean probability across all models
            std_probs: Standard deviation of probabilities (uncertainty measure)
            individual_probs: (Optional) Predictions from each individual model
        """
        all_probs = []
        
        for model in self.models:
            model.eval()
            model_probs = []
            
            with torch.no_grad():
                for X_spc, X_ga, _ in loader:
                    X_spc, X_ga = X_spc.to(self.device), X_ga.to(self.device)
                    outputs = model(X_spc, X_ga)
                    probs = torch.sigmoid(outputs).cpu().numpy()
                    model_probs.append(probs)
            
            # Concatenate batches
            model_probs = np.concatenate(model_probs, axis=0)
            all_probs.append(model_probs)
        
        # Stack along a new axis to get shape (num_models, num_samples, num_classes)
        all_probs = np.stack(all_probs, axis=0)
        
        # Calculate mean and std across models (axis=0)
        mean_probs = np.mean(all_probs, axis=0)
        std_probs = np.std(all_probs, axis=0)
        
        if return_individual:
            return mean_probs, std_probs, all_probs
        else:
            return mean_probs, std_probs
    
    def predict_sample(self, X_spectra, X_gaia):
        """
        Make predictions for a single sample.
        
        Args:
            X_spectra: Spectral features (tensor)
            X_gaia: Gaia features (tensor)
            
        Returns:
            mean_probs: Mean probability across all models
            std_probs: Standard deviation of probabilities (uncertainty measure)
            all_probs: Predictions from each individual model
        """
        # Ensure inputs are tensors with batch dimension
        if len(X_spectra.shape) == 1:
            X_spectra = X_spectra.unsqueeze(0)
        if len(X_gaia.shape) == 1:
            X_gaia = X_gaia.unsqueeze(0)
        
        X_spectra = X_spectra.to(self.device)
        X_gaia = X_gaia.to(self.device)
        
        all_probs = []
        for model in self.models:
            model.eval()
            with torch.no_grad():
                outputs = model(X_spectra, X_gaia)
                probs = torch.sigmoid(outputs).cpu().numpy()
                all_probs.append(probs)
        
        # Stack to get shape (num_models, 1, num_classes)
        all_probs = np.stack(all_probs, axis=0)
        
        # Calculate mean and std across models
        mean_probs = np.mean(all_probs, axis=0)
        std_probs = np.std(all_probs, axis=0)
        
        return mean_probs[0], std_probs[0], all_probs[:, 0, :]
    
    def save_models(self, path_prefix):
        """
        Save all models in the ensemble.
        
        Args:
            path_prefix: Path prefix for saving models
        """
        for i, model in enumerate(self.models):
            torch.save(model.state_dict(), f"{path_prefix}_model_{i}.pth")
    
    def load_models(self, path_prefix):
        """
        Load all models in the ensemble.
        
        Args:
            path_prefix: Path prefix for loading models
        """
        self.models = []
        for i in range(self.num_models):
            model = self.model_class(**self.model_args).to(self.device)
            model.load_state_dict(torch.load(f"{path_prefix}_model_{i}.pth"))
            model.eval()
            self.models.append(model)
    
    def _calculate_class_weights(self, y):
        """Calculate class weights for handling imbalanced classes."""
        if y.ndim > 1:  
            class_counts = np.sum(y, axis=0)  
        else:
            class_counts = np.bincount(y)

        total_samples = y.shape[0] if y.ndim > 1 else len(y)
        class_counts = np.where(class_counts == 0, 1, class_counts)  # Prevent division by zero
        class_weights = total_samples / (len(class_counts) * class_counts)
        
        return class_weights
    
    def _calculate_metrics(self, y_true, y_pred):
        """Calculate evaluation metrics for multi-label classification."""
        from sklearn.metrics import f1_score, precision_score, recall_score, hamming_loss
        
        metrics = {
            "micro_f1": f1_score(y_true, y_pred, average='micro'),
            "macro_f1": f1_score(y_true, y_pred, average='macro'),
            "weighted_f1": f1_score(y_true, y_pred, average='weighted'),
            "micro_precision": precision_score(y_true, y_pred, average='micro', zero_division=1),
            "macro_precision": precision_score(y_true, y_pred, average='macro', zero_division=1),
            "weighted_precision": precision_score(y_true, y_pred, average='weighted', zero_division=1),
            "micro_recall": recall_score(y_true, y_pred, average='micro'),
            "macro_recall": recall_score(y_true, y_pred, average='macro'),
            "weighted_recall": recall_score(y_true, y_pred, average='weighted'),
            "hamming_loss": hamming_loss(y_true, y_pred)
        }
        
        return metrics
    
    def _get_lr(self, optimizer):
        """Get current learning rate from optimizer."""
        for param_group in optimizer.param_groups:
            return param_group['lr']

class UncertaintyVisualizer:
    """
    Utility class for visualizing uncertainty in multi-label classification.
    """
    def __init__(self, class_names):
        """
        Initialize the visualizer.
        
        Args:
            class_names: List of class names
        """
        self.class_names = class_names
    
    def plot_prediction_with_uncertainty(self, mean_probs, std_probs, true_labels=None, 
                                         threshold=0.5, top_k=10, figsize=(12, 8)):
        """
        Plot prediction probabilities with uncertainty bars.
        
        Args:
            mean_probs: Mean probabilities from ensemble
            std_probs: Standard deviation of probabilities
            true_labels: True labels (optional)
            threshold: Decision threshold for positive prediction
            top_k: Number of top predictions to show
            figsize: Figure size
            
        Returns:
            matplotlib figure
        """
        # Get indices of top k predictions by mean probability
        top_indices = np.argsort(mean_probs)[::-1][:top_k]
        
        # Extract top k values
        top_means = mean_probs[top_indices]
        top_stds = std_probs[top_indices]
        top_classes = [self.class_names[i] for i in top_indices]
        
        # Create figure
        fig, ax = plt.subplots(figsize=figsize)
        
        # Create horizontal bar chart with error bars
        y_pos = np.arange(len(top_classes))
        bar_colors = ['green' if prob >= threshold else 'red' for prob in top_means]
        
        # Plot bars
        bars = ax.barh(y_pos, top_means, xerr=top_stds, align='center', 
                     alpha=0.7, color=bar_colors, capsize=5)
        
        # Add true labels if provided
        if true_labels is not None:
            true_label_indices = np.where(true_labels == 1)[0]
            true_class_names = [self.class_names[i] for i in true_label_indices]
            
            # Mark true labels in the plot
            for i, class_name in enumerate(top_classes):
                if class_name in true_class_names:
                    ax.get_children()[i].set_edgecolor('blue')
                    ax.get_children()[i].set_linewidth(2)
        
        # Add threshold line
        ax.axvline(x=threshold, color='gray', linestyle='--', alpha=0.7)
        
        # Customize plot
        ax.set_yticks(y_pos)
        ax.set_yticklabels(top_classes)
        ax.invert_yaxis()  # labels read top-to-bottom
        ax.set_xlabel('Probability')
        ax.set_title('Prediction with Uncertainty')
        
        # Add gridlines
        ax.grid(axis='x', linestyle='--', alpha=0.7)
        
        # Add legend
        from matplotlib.patches import Patch
        legend_elements = [
            Patch(facecolor='green', edgecolor='black', label='Above threshold'),
            Patch(facecolor='red', edgecolor='black', label='Below threshold')
        ]
        if true_labels is not None:
            legend_elements.append(Patch(facecolor='white', edgecolor='blue', linewidth=2, label='True label'))
        
        ax.legend(handles=legend_elements, loc='lower right')
        
        # Add text annotations for probabilities
        for i, (mean, std) in enumerate(zip(top_means, top_stds)):
            ax.text(mean + 0.02, i, f'{mean:.2f} ± {std:.2f}', va='center')
        
        plt.tight_layout()
        return fig
    
    def plot_uncertainty_distribution(self, std_probs, predictions, figsize=(10, 6)):
        """
        Plot histogram of uncertainty (standard deviation) distribution.
        
        Args:
            std_probs: Standard deviation of probabilities (shape: num_samples x num_classes)
            predictions: Binary predictions (shape: num_samples x num_classes)
            figsize: Figure size
            
        Returns:
            matplotlib figure
        """
        # Flatten predictions and standard deviations
        std_probs_flat = std_probs.flatten()
        predictions_flat = predictions.flatten()
        
        # Create figure
        fig, ax = plt.subplots(figsize=figsize)
        
        # Plot histograms
        ax.hist(std_probs_flat, bins=50, alpha=0.5, label='All predictions')
        ax.hist(std_probs_flat[predictions_flat == 1], bins=50, alpha=0.5, label='Positive predictions')
        
        # Customize plot
        ax.set_xlabel('Uncertainty (Standard Deviation)')
        ax.set_ylabel('Count')
        ax.set_title('Distribution of Prediction Uncertainty')
        ax.legend()
        ax.grid(linestyle='--', alpha=0.7)
        
        plt.tight_layout()
        return fig
    
    def plot_uncertainty_vs_error(self, mean_probs, std_probs, true_labels, figsize=(10, 6)):
        """
        Scatter plot of uncertainty vs prediction error.
        
        Args:
            mean_probs: Mean probabilities from ensemble
            std_probs: Standard deviation of probabilities
            true_labels: True labels
            figsize: Figure size
            
        Returns:
            matplotlib figure
        """
        # Calculate errors (absolute difference between mean probability and true label)
        errors = np.abs(mean_probs - true_labels)
        
        # Flatten arrays
        std_probs_flat = std_probs.flatten()
        errors_flat = errors.flatten()
        
        # Create figure
        fig, ax = plt.subplots(figsize=figsize)
        
        # Create scatter plot
        scatter = ax.scatter(std_probs_flat, errors_flat, alpha=0.5, s=10)
        
        # Add trend line
        from scipy.stats import linregress
        slope, intercept, r_value, p_value, std_err = linregress(std_probs_flat, errors_flat)
        x = np.linspace(std_probs_flat.min(), std_probs_flat.max(), 100)
        y = slope * x + intercept
        ax.plot(x, y, color='red', linestyle='--')
        
        # Add correlation coefficient
        ax.text(0.95, 0.05, f'Correlation: {r_value:.3f}', transform=ax.transAxes, 
                ha='right', va='bottom', bbox=dict(facecolor='white', alpha=0.5))
        
        # Customize plot
        ax.set_xlabel('Uncertainty (Standard Deviation)')
        ax.set_ylabel('Prediction Error')
        ax.set_title('Uncertainty vs Prediction Error')
        ax.grid(linestyle='--', alpha=0.7)
        
        plt.tight_layout()
        return fig
    
    def plot_calibration_curve(self, mean_probs, std_probs, true_labels, n_bins=10, figsize=(10, 6)):
        """
        Plot calibration curve to assess the quality of predicted probabilities.
        
        Args:
            mean_probs: Mean probabilities from ensemble
            std_probs: Standard deviation of probabilities (for coloring points)
            true_labels: True labels
            n_bins: Number of bins for the calibration curve
            figsize: Figure size
            
        Returns:
            matplotlib figure
        """
        # Flatten arrays
        mean_probs_flat = mean_probs.flatten()
        std_probs_flat = std_probs.flatten()
        true_labels_flat = true_labels.flatten()
        
        # Remove NaN values if any
        valid_indices = ~np.isnan(mean_probs_flat) & ~np.isnan(true_labels_flat)
        mean_probs_flat = mean_probs_flat[valid_indices]
        std_probs_flat = std_probs_flat[valid_indices]
        true_labels_flat = true_labels_flat[valid_indices]
        
        # Create bins and calculate calibration metrics
        bins = np.linspace(0.0, 1.0, n_bins + 1)
        binned = np.digitize(mean_probs_flat, bins) - 1
        bin_accs = np.zeros(n_bins)
        bin_confs = np.zeros(n_bins)
        bin_sizes = np.zeros(n_bins)
        bin_uncerts = np.zeros(n_bins)
        
        for bin_idx in range(n_bins):
            bin_mask = binned == bin_idx
            if np.sum(bin_mask) > 0:
                bin_sizes[bin_idx] = np.sum(bin_mask)
                bin_accs[bin_idx] = np.mean(true_labels_flat[bin_mask])
                bin_confs[bin_idx] = np.mean(mean_probs_flat[bin_mask])
                bin_uncerts[bin_idx] = np.mean(std_probs_flat[bin_mask])
        
        # Calculate the expected calibration error (ECE)
        ece = np.sum(bin_sizes / len(mean_probs_flat) * np.abs(bin_accs - bin_confs))
        
        # Create figure
        fig, ax = plt.subplots(figsize=figsize)
        
        # Plot the calibration curve
        ax.plot([0, 1], [0, 1], linestyle='--', color='gray', label='Perfectly calibrated')
        
        # Color points by uncertainty
        sc = ax.scatter(bin_confs, bin_accs, 
                      s=bin_sizes / np.sum(bin_sizes) * 2000,  # Size proportional to bin size
                      c=bin_uncerts, cmap='viridis', alpha=0.8, 
                      linewidths=1, edgecolors='black')
        
        # Add colorbar for uncertainty
        cbar = plt.colorbar(sc, ax=ax)
        cbar.set_label('Mean uncertainty (std)')
        
        # Customize plot
        ax.set_xlabel('Mean predicted probability')
        ax.set_ylabel('Fraction of positives (accuracy)')
        ax.set_xlim([0, 1])
        ax.set_ylim([0, 1])
        ax.set_title(f'Calibration Curve (ECE: {ece:.3f})')
        ax.grid(linestyle='--', alpha=0.7)
        ax.legend(loc='lower right')
        
        plt.tight_layout()
        return fig

In [13]:
import torch
import wandb
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

# Import your model and the ensemble implementation
#from your_model_file import StarClassifierFusion, MultiModalBalancedMultiLabelDataset
#from ensemble_uncertainty import DeepEnsemble, UncertaintyVisualizer

if __name__ == "__main__":
    # Configuration
    d_model_spectra = 256
    d_model_gaia = 256  # tiny
    num_classes = 55
    input_dim_spectra = 3647
    input_dim_gaia = 18
    n_layers = 12
    lr = 2.5e-4
    patience = 600
    num_epochs = 200
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # Initialize wandb
    wandb.init(project="ALLSTARS_ensemble_uncertainty")
    
    # Model arguments
    model_args = {
        "d_model_spectra": d_model_spectra,
        "d_model_gaia": d_model_gaia,
        "num_classes": num_classes,
        "input_dim_spectra": input_dim_spectra,
        "input_dim_gaia": input_dim_gaia,
        "n_layers": n_layers,
        "use_cross_attention": True,
        "n_cross_attn_heads": 8
    }
    
    # Create deep ensemble with 5 models
    ensemble = DeepEnsemble(
        model_class=StarClassifierFusion,
        model_args=model_args,
        num_models=5,  # Number of models in the ensemble
        device=device
    )
    
    # Load your data loaders as you did in the original code
    # ...
    
    # Train the ensemble
    trained_models = ensemble.train(
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        num_epochs=num_epochs,
        lr=lr,
        max_patience=patience,
        scheduler_type='OneCycleLR',
        log_to_wandb=True
    )
    
    # Save the ensemble models
    ensemble.save_models("ensemble_mamba_v1")
    
    # Evaluate the ensemble on test set
    mean_probs, std_probs = ensemble.predict(test_loader)
    
    # Convert mean probabilities to binary predictions
    threshold = 0.5
    predictions = (mean_probs >= threshold).astype(float)
    
    # Get true labels from test set
    true_labels = []
    for _, _, y_batch in test_loader:
        true_labels.extend(y_batch.cpu().numpy())
    true_labels = np.array(true_labels)
    
    # Calculate metrics
    from sklearn.metrics import f1_score, precision_score, recall_score
    metrics = {
        "micro_f1": f1_score(true_labels, predictions, average='micro'),
        "macro_f1": f1_score(true_labels, predictions, average='macro'),
        "weighted_f1": f1_score(true_labels, predictions, average='weighted'),
        "micro_precision": precision_score(true_labels, predictions, average='micro', zero_division=1),
        "macro_precision": precision_score(true_labels, predictions, average='macro', zero_division=1),
        "micro_recall": recall_score(true_labels, predictions, average='micro'),
        "macro_recall": recall_score(true_labels, predictions, average='macro')
    }
    
    # Log metrics to wandb
    wandb.log(metrics)
    
    # Initialize visualizer with class names
    class_names = [f"Class_{i}" for i in range(num_classes)]  # Replace with actual class names
    visualizer = UncertaintyVisualizer(class_names)
    
    # Visualize uncertainty for a sample
    sample_idx = 0  # Choose any sample
    fig = visualizer.plot_prediction_with_uncertainty(
        mean_probs[sample_idx], 
        std_probs[sample_idx], 
        true_labels[sample_idx],
        threshold=0.5
    )
    wandb.log({"sample_prediction": wandb.Image(fig)})
    
    # Visualize uncertainty distribution
    fig = visualizer.plot_uncertainty_distribution(std_probs, predictions)
    wandb.log({"uncertainty_distribution": wandb.Image(fig)})
    
    # Visualize uncertainty vs error
    fig = visualizer.plot_uncertainty_vs_error(mean_probs, std_probs, true_labels)
    wandb.log({"uncertainty_vs_error": wandb.Image(fig)})
    
    # Visualize calibration curve
    fig = visualizer.plot_calibration_curve(mean_probs, std_probs, true_labels)
    wandb.log({"calibration_curve": wandb.Image(fig)})
    
    # Analysis: Find high uncertainty predictions
    flat_std = std_probs.flatten()
    high_uncertainty_threshold = np.percentile(flat_std, 95)  # Top 5% most uncertain
    
    high_uncertainty_mask = std_probs >= high_uncertainty_threshold
    num_high_uncertainty = np.sum(high_uncertainty_mask)
    
    print(f"Number of high uncertainty predictions: {num_high_uncertainty}")
    print(f"Average accuracy for high uncertainty predictions: {np.mean((predictions[high_uncertainty_mask] == true_labels[high_uncertainty_mask]).astype(float))}")
    print(f"Average accuracy for other predictions: {np.mean((predictions[~high_uncertainty_mask] == true_labels[~high_uncertainty_mask]).astype(float))}")
    
    wandb.finish()


# For inference on new data
def predict_with_uncertainty(ensemble, X_spectra, X_gaia, class_names, threshold=0.5):
    """
    Make predictions with uncertainty estimation on new data.
    
    Args:
        ensemble: Trained DeepEnsemble
        X_spectra: Spectral features (tensor)
        X_gaia: Gaia features (tensor)
        class_names: List of class names
        threshold: Decision threshold for positive predictions
        
    Returns:
        predictions: Binary predictions
        mean_probs: Mean probabilities
        uncertainties: Standard deviations of probabilities
        visualization: Matplotlib figure of predictions with uncertainty
    """
    # Get predictions with uncertainty
    mean_probs, std_probs, all_model_probs = ensemble.predict_sample(X_spectra, X_gaia)
    
    # Convert to binary predictions
    predictions = (mean_probs >= threshold).astype(float)
    
    # Visualize
    visualizer = UncertaintyVisualizer(class_names)
    fig = visualizer.plot_prediction_with_uncertainty(
        mean_probs, std_probs, threshold=threshold
    )
    
    return predictions, mean_probs, std_probs, fig


----- Training Ensemble Model 1/5 -----



Epoch 1/200 - Training: 100%|██████████| 15/15 [00:35<00:00,  2.36s/it]
Epoch 1/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.59it/s]
Epoch 1/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.01it/s]


Epoch 1/200 - Train Loss: 0.5282, Train Acc: 0.7654, Val Loss: 0.3791, Val Acc: 0.9306


Epoch 2/200 - Training: 100%|██████████| 15/15 [00:02<00:00,  6.78it/s]
Epoch 2/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.56it/s]
Epoch 2/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.61it/s]


Epoch 2/200 - Train Loss: 0.3296, Train Acc: 0.9550, Val Loss: 0.2726, Val Acc: 0.9647


Epoch 3/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.64it/s]
Epoch 3/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.77it/s]
Epoch 3/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.71it/s]


Epoch 3/200 - Train Loss: 0.2537, Train Acc: 0.9638, Val Loss: 0.2222, Val Acc: 0.9650


Epoch 4/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.61it/s]
Epoch 4/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 30.43it/s]
Epoch 4/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.34it/s]


Epoch 4/200 - Train Loss: 0.2155, Train Acc: 0.9639, Val Loss: 0.1940, Val Acc: 0.9650


Epoch 5/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.01it/s]
Epoch 5/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.81it/s]
Epoch 5/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.28it/s]


Epoch 5/200 - Train Loss: 0.1929, Train Acc: 0.9639, Val Loss: 0.1762, Val Acc: 0.9650


Epoch 6/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.46it/s]
Epoch 6/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.85it/s]
Epoch 6/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.17it/s]


Epoch 6/200 - Train Loss: 0.1785, Train Acc: 0.9638, Val Loss: 0.1644, Val Acc: 0.9650


Epoch 7/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.28it/s]
Epoch 7/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 30.33it/s]
Epoch 7/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.67it/s]


Epoch 7/200 - Train Loss: 0.1687, Train Acc: 0.9638, Val Loss: 0.1558, Val Acc: 0.9650


Epoch 8/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.52it/s]
Epoch 8/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.27it/s]
Epoch 8/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.10it/s]


Epoch 8/200 - Train Loss: 0.1612, Train Acc: 0.9639, Val Loss: 0.1493, Val Acc: 0.9650


Epoch 9/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.70it/s]
Epoch 9/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 30.47it/s]
Epoch 9/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.09it/s]


Epoch 9/200 - Train Loss: 0.1553, Train Acc: 0.9638, Val Loss: 0.1437, Val Acc: 0.9650


Epoch 10/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.61it/s]
Epoch 10/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.62it/s]
Epoch 10/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.87it/s]


Epoch 10/200 - Train Loss: 0.1502, Train Acc: 0.9639, Val Loss: 0.1390, Val Acc: 0.9650


Epoch 11/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.39it/s]
Epoch 11/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.56it/s]
Epoch 11/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.39it/s]


Epoch 11/200 - Train Loss: 0.1459, Train Acc: 0.9639, Val Loss: 0.1347, Val Acc: 0.9650


Epoch 12/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.17it/s]
Epoch 12/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.34it/s]
Epoch 12/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.84it/s]


Epoch 12/200 - Train Loss: 0.1418, Train Acc: 0.9638, Val Loss: 0.1309, Val Acc: 0.9650


Epoch 13/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.62it/s]
Epoch 13/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.80it/s]
Epoch 13/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.09it/s]


Epoch 13/200 - Train Loss: 0.1381, Train Acc: 0.9640, Val Loss: 0.1271, Val Acc: 0.9650


Epoch 14/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.50it/s]
Epoch 14/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.05it/s]
Epoch 14/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.45it/s]


Epoch 14/200 - Train Loss: 0.1346, Train Acc: 0.9638, Val Loss: 0.1238, Val Acc: 0.9650


Epoch 15/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.47it/s]
Epoch 15/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.08it/s]
Epoch 15/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.96it/s]


Epoch 15/200 - Train Loss: 0.1314, Train Acc: 0.9637, Val Loss: 0.1206, Val Acc: 0.9650


Epoch 16/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.56it/s]
Epoch 16/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.09it/s]
Epoch 16/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.14it/s]


Epoch 16/200 - Train Loss: 0.1283, Train Acc: 0.9638, Val Loss: 0.1177, Val Acc: 0.9650


Epoch 17/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.63it/s]
Epoch 17/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.26it/s]
Epoch 17/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.82it/s]


Epoch 17/200 - Train Loss: 0.1256, Train Acc: 0.9640, Val Loss: 0.1149, Val Acc: 0.9650


Epoch 18/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.59it/s]
Epoch 18/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.86it/s]
Epoch 18/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.75it/s]


Epoch 18/200 - Train Loss: 0.1228, Train Acc: 0.9640, Val Loss: 0.1124, Val Acc: 0.9650


Epoch 19/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  9.19it/s]
Epoch 19/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.84it/s]
Epoch 19/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.89it/s]


Epoch 19/200 - Train Loss: 0.1201, Train Acc: 0.9639, Val Loss: 0.1098, Val Acc: 0.9650


Epoch 20/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  9.09it/s]
Epoch 20/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.80it/s]
Epoch 20/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.81it/s]


Epoch 20/200 - Train Loss: 0.1178, Train Acc: 0.9640, Val Loss: 0.1076, Val Acc: 0.9650


Epoch 21/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.37it/s]
Epoch 21/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.26it/s]
Epoch 21/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.53it/s]


Epoch 21/200 - Train Loss: 0.1155, Train Acc: 0.9639, Val Loss: 0.1053, Val Acc: 0.9650


Epoch 22/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.04it/s]
Epoch 22/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.11it/s]
Epoch 22/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.28it/s]


Epoch 22/200 - Train Loss: 0.1132, Train Acc: 0.9639, Val Loss: 0.1034, Val Acc: 0.9650


Epoch 23/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.62it/s]
Epoch 23/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 30.49it/s]
Epoch 23/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.48it/s]


Epoch 23/200 - Train Loss: 0.1111, Train Acc: 0.9639, Val Loss: 0.1014, Val Acc: 0.9650


Epoch 24/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.85it/s]
Epoch 24/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.98it/s]
Epoch 24/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.00it/s]


Epoch 24/200 - Train Loss: 0.1092, Train Acc: 0.9638, Val Loss: 0.0995, Val Acc: 0.9650


Epoch 25/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.26it/s]
Epoch 25/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.24it/s]
Epoch 25/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 35.05it/s]


Epoch 25/200 - Train Loss: 0.1072, Train Acc: 0.9638, Val Loss: 0.0977, Val Acc: 0.9650


Epoch 26/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.49it/s]
Epoch 26/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.25it/s]
Epoch 26/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 35.03it/s]


Epoch 26/200 - Train Loss: 0.1053, Train Acc: 0.9640, Val Loss: 0.0964, Val Acc: 0.9650


Epoch 27/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  9.15it/s]
Epoch 27/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.77it/s]
Epoch 27/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 35.30it/s]


Epoch 27/200 - Train Loss: 0.1036, Train Acc: 0.9638, Val Loss: 0.0946, Val Acc: 0.9650


Epoch 28/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  9.15it/s]
Epoch 28/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.91it/s]
Epoch 28/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 35.18it/s]


Epoch 28/200 - Train Loss: 0.1018, Train Acc: 0.9639, Val Loss: 0.0930, Val Acc: 0.9650


Epoch 29/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.13it/s]
Epoch 29/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 30.85it/s]
Epoch 29/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.75it/s]


Epoch 29/200 - Train Loss: 0.1002, Train Acc: 0.9639, Val Loss: 0.0918, Val Acc: 0.9650


Epoch 30/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.33it/s]
Epoch 30/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.14it/s]
Epoch 30/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.39it/s]


Epoch 30/200 - Train Loss: 0.0987, Train Acc: 0.9639, Val Loss: 0.0903, Val Acc: 0.9650


Epoch 31/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.31it/s]
Epoch 31/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.19it/s]
Epoch 31/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.90it/s]


Epoch 31/200 - Train Loss: 0.0972, Train Acc: 0.9640, Val Loss: 0.0891, Val Acc: 0.9650


Epoch 32/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.36it/s]
Epoch 32/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.48it/s]
Epoch 32/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.71it/s]


Epoch 32/200 - Train Loss: 0.0956, Train Acc: 0.9640, Val Loss: 0.0878, Val Acc: 0.9650


Epoch 33/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.46it/s]
Epoch 33/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.83it/s]
Epoch 33/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.21it/s]


Epoch 33/200 - Train Loss: 0.0942, Train Acc: 0.9640, Val Loss: 0.0866, Val Acc: 0.9650


Epoch 34/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.58it/s]
Epoch 34/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.09it/s]
Epoch 34/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.91it/s]


Epoch 34/200 - Train Loss: 0.0929, Train Acc: 0.9640, Val Loss: 0.0856, Val Acc: 0.9650


Epoch 35/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.63it/s]
Epoch 35/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.00it/s]
Epoch 35/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.65it/s]


Epoch 35/200 - Train Loss: 0.0915, Train Acc: 0.9639, Val Loss: 0.0843, Val Acc: 0.9650


Epoch 36/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.64it/s]
Epoch 36/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.09it/s]
Epoch 36/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.67it/s]


Epoch 36/200 - Train Loss: 0.0901, Train Acc: 0.9639, Val Loss: 0.0830, Val Acc: 0.9650


Epoch 37/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.55it/s]
Epoch 37/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.02it/s]
Epoch 37/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.14it/s]


Epoch 37/200 - Train Loss: 0.0889, Train Acc: 0.9639, Val Loss: 0.0824, Val Acc: 0.9650


Epoch 38/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.58it/s]
Epoch 38/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.86it/s]
Epoch 38/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.03it/s]


Epoch 38/200 - Train Loss: 0.0877, Train Acc: 0.9639, Val Loss: 0.0812, Val Acc: 0.9650


Epoch 39/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.58it/s]
Epoch 39/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.26it/s]
Epoch 39/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.72it/s]


Epoch 39/200 - Train Loss: 0.0864, Train Acc: 0.9639, Val Loss: 0.0805, Val Acc: 0.9650


Epoch 40/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.55it/s]
Epoch 40/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.05it/s]
Epoch 40/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.63it/s]


Epoch 40/200 - Train Loss: 0.0851, Train Acc: 0.9640, Val Loss: 0.0792, Val Acc: 0.9650


Epoch 41/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.58it/s]
Epoch 41/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 30.78it/s]
Epoch 41/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.84it/s]


Epoch 41/200 - Train Loss: 0.0841, Train Acc: 0.9639, Val Loss: 0.0784, Val Acc: 0.9650


Epoch 42/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.54it/s]
Epoch 42/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.90it/s]
Epoch 42/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.64it/s]


Epoch 42/200 - Train Loss: 0.0828, Train Acc: 0.9641, Val Loss: 0.0776, Val Acc: 0.9650


Epoch 43/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.59it/s]
Epoch 43/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.18it/s]
Epoch 43/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.47it/s]


Epoch 43/200 - Train Loss: 0.0817, Train Acc: 0.9640, Val Loss: 0.0765, Val Acc: 0.9650


Epoch 44/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.57it/s]
Epoch 44/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.74it/s]
Epoch 44/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.99it/s]


Epoch 44/200 - Train Loss: 0.0805, Train Acc: 0.9640, Val Loss: 0.0756, Val Acc: 0.9650


Epoch 45/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.50it/s]
Epoch 45/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.85it/s]
Epoch 45/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.58it/s]


Epoch 45/200 - Train Loss: 0.0794, Train Acc: 0.9640, Val Loss: 0.0749, Val Acc: 0.9650


Epoch 46/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.59it/s]
Epoch 46/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.94it/s]
Epoch 46/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.51it/s]


Epoch 46/200 - Train Loss: 0.0785, Train Acc: 0.9640, Val Loss: 0.0740, Val Acc: 0.9650


Epoch 47/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.59it/s]
Epoch 47/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.82it/s]
Epoch 47/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.55it/s]


Epoch 47/200 - Train Loss: 0.0779, Train Acc: 0.9641, Val Loss: 0.0739, Val Acc: 0.9650


Epoch 48/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.60it/s]
Epoch 48/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.01it/s]
Epoch 48/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.51it/s]


Epoch 48/200 - Train Loss: 0.0766, Train Acc: 0.9640, Val Loss: 0.0727, Val Acc: 0.9650


Epoch 49/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.55it/s]
Epoch 49/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.00it/s]
Epoch 49/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.48it/s]


Epoch 49/200 - Train Loss: 0.0758, Train Acc: 0.9641, Val Loss: 0.0719, Val Acc: 0.9650


Epoch 50/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.59it/s]
Epoch 50/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.87it/s]
Epoch 50/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.43it/s]


Epoch 50/200 - Train Loss: 0.0746, Train Acc: 0.9642, Val Loss: 0.0708, Val Acc: 0.9650


Epoch 51/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.59it/s]
Epoch 51/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.89it/s]
Epoch 51/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.79it/s]


Epoch 51/200 - Train Loss: 0.0737, Train Acc: 0.9641, Val Loss: 0.0705, Val Acc: 0.9650


Epoch 52/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.57it/s]
Epoch 52/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.06it/s]
Epoch 52/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.42it/s]


Epoch 52/200 - Train Loss: 0.0724, Train Acc: 0.9640, Val Loss: 0.0695, Val Acc: 0.9650


Epoch 53/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.48it/s]
Epoch 53/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.17it/s]
Epoch 53/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.46it/s]


Epoch 53/200 - Train Loss: 0.0716, Train Acc: 0.9643, Val Loss: 0.0689, Val Acc: 0.9650


Epoch 54/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.61it/s]
Epoch 54/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.03it/s]
Epoch 54/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.44it/s]


Epoch 54/200 - Train Loss: 0.0705, Train Acc: 0.9642, Val Loss: 0.0682, Val Acc: 0.9650


Epoch 55/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.54it/s]
Epoch 55/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.72it/s]
Epoch 55/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.54it/s]


Epoch 55/200 - Train Loss: 0.0696, Train Acc: 0.9643, Val Loss: 0.0681, Val Acc: 0.9649


Epoch 56/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.62it/s]
Epoch 56/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.06it/s]
Epoch 56/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.70it/s]


Epoch 56/200 - Train Loss: 0.0687, Train Acc: 0.9643, Val Loss: 0.0667, Val Acc: 0.9650


Epoch 57/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.55it/s]
Epoch 57/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.72it/s]
Epoch 57/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.23it/s]


Epoch 57/200 - Train Loss: 0.0678, Train Acc: 0.9644, Val Loss: 0.0663, Val Acc: 0.9650


Epoch 58/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.27it/s]
Epoch 58/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.76it/s]
Epoch 58/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.59it/s]


Epoch 58/200 - Train Loss: 0.0670, Train Acc: 0.9644, Val Loss: 0.0660, Val Acc: 0.9649


Epoch 59/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.58it/s]
Epoch 59/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.96it/s]
Epoch 59/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.36it/s]


Epoch 59/200 - Train Loss: 0.0660, Train Acc: 0.9645, Val Loss: 0.0646, Val Acc: 0.9650


Epoch 60/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.56it/s]
Epoch 60/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.92it/s]
Epoch 60/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.18it/s]


Epoch 60/200 - Train Loss: 0.0651, Train Acc: 0.9645, Val Loss: 0.0640, Val Acc: 0.9650


Epoch 61/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.57it/s]
Epoch 61/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.85it/s]
Epoch 61/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.69it/s]


Epoch 61/200 - Train Loss: 0.0641, Train Acc: 0.9645, Val Loss: 0.0636, Val Acc: 0.9650


Epoch 62/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.60it/s]
Epoch 62/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.92it/s]
Epoch 62/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.60it/s]


Epoch 62/200 - Train Loss: 0.0636, Train Acc: 0.9646, Val Loss: 0.0629, Val Acc: 0.9650


Epoch 63/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.86it/s]
Epoch 63/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.60it/s]
Epoch 63/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.37it/s]


Epoch 63/200 - Train Loss: 0.0627, Train Acc: 0.9645, Val Loss: 0.0618, Val Acc: 0.9650


Epoch 64/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.38it/s]
Epoch 64/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.83it/s]
Epoch 64/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.32it/s]


Epoch 64/200 - Train Loss: 0.0620, Train Acc: 0.9647, Val Loss: 0.0615, Val Acc: 0.9650


Epoch 65/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.49it/s]
Epoch 65/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 28.39it/s]
Epoch 65/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.85it/s]


Epoch 65/200 - Train Loss: 0.0608, Train Acc: 0.9646, Val Loss: 0.0610, Val Acc: 0.9650


Epoch 66/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.90it/s]
Epoch 66/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.34it/s]
Epoch 66/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.23it/s]


Epoch 66/200 - Train Loss: 0.0602, Train Acc: 0.9648, Val Loss: 0.0603, Val Acc: 0.9650


Epoch 67/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.57it/s]
Epoch 67/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 28.34it/s]
Epoch 67/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 29.69it/s]


Epoch 67/200 - Train Loss: 0.0593, Train Acc: 0.9649, Val Loss: 0.0598, Val Acc: 0.9651


Epoch 68/200 - Training: 100%|██████████| 15/15 [00:02<00:00,  7.45it/s]
Epoch 68/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.38it/s]
Epoch 68/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.31it/s]


Epoch 68/200 - Train Loss: 0.0588, Train Acc: 0.9649, Val Loss: 0.0594, Val Acc: 0.9652


Epoch 69/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.79it/s]
Epoch 69/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 28.60it/s]
Epoch 69/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.04it/s]


Epoch 69/200 - Train Loss: 0.0578, Train Acc: 0.9651, Val Loss: 0.0585, Val Acc: 0.9654


Epoch 70/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.75it/s]
Epoch 70/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 28.62it/s]
Epoch 70/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.53it/s]


Epoch 70/200 - Train Loss: 0.0572, Train Acc: 0.9652, Val Loss: 0.0580, Val Acc: 0.9652


Epoch 71/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.72it/s]
Epoch 71/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.26it/s]
Epoch 71/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.19it/s]


Epoch 71/200 - Train Loss: 0.0564, Train Acc: 0.9652, Val Loss: 0.0576, Val Acc: 0.9654


Epoch 72/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.82it/s]
Epoch 72/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 28.96it/s]
Epoch 72/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.49it/s]


Epoch 72/200 - Train Loss: 0.0559, Train Acc: 0.9655, Val Loss: 0.0570, Val Acc: 0.9654


Epoch 73/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.79it/s]
Epoch 73/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 28.90it/s]
Epoch 73/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.29it/s]


Epoch 73/200 - Train Loss: 0.0549, Train Acc: 0.9654, Val Loss: 0.0567, Val Acc: 0.9654


Epoch 74/200 - Training: 100%|██████████| 15/15 [00:02<00:00,  7.46it/s]
Epoch 74/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 28.82it/s]
Epoch 74/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 27.00it/s]


Epoch 74/200 - Train Loss: 0.0546, Train Acc: 0.9655, Val Loss: 0.0561, Val Acc: 0.9660


Epoch 75/200 - Training: 100%|██████████| 15/15 [00:02<00:00,  6.99it/s]
Epoch 75/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 25.39it/s]
Epoch 75/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 26.03it/s]


Epoch 75/200 - Train Loss: 0.0539, Train Acc: 0.9657, Val Loss: 0.0555, Val Acc: 0.9653


Epoch 76/200 - Training: 100%|██████████| 15/15 [00:02<00:00,  6.85it/s]
Epoch 76/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 25.45it/s]
Epoch 76/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 26.19it/s]


Epoch 76/200 - Train Loss: 0.0530, Train Acc: 0.9658, Val Loss: 0.0551, Val Acc: 0.9659


Epoch 77/200 - Training: 100%|██████████| 15/15 [00:02<00:00,  6.83it/s]
Epoch 77/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 28.20it/s]
Epoch 77/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 28.24it/s]


Epoch 77/200 - Train Loss: 0.0522, Train Acc: 0.9660, Val Loss: 0.0547, Val Acc: 0.9658


Epoch 78/200 - Training: 100%|██████████| 15/15 [00:02<00:00,  7.45it/s]
Epoch 78/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 26.88it/s]
Epoch 78/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 28.15it/s]


Epoch 78/200 - Train Loss: 0.0515, Train Acc: 0.9663, Val Loss: 0.0541, Val Acc: 0.9658


Epoch 79/200 - Training: 100%|██████████| 15/15 [00:02<00:00,  7.13it/s]
Epoch 79/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 27.17it/s]
Epoch 79/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 27.34it/s]


Epoch 79/200 - Train Loss: 0.0510, Train Acc: 0.9664, Val Loss: 0.0535, Val Acc: 0.9663


Epoch 80/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.51it/s]
Epoch 80/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 26.25it/s]
Epoch 80/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 28.06it/s]


Epoch 80/200 - Train Loss: 0.0502, Train Acc: 0.9666, Val Loss: 0.0531, Val Acc: 0.9660


Epoch 81/200 - Training: 100%|██████████| 15/15 [00:02<00:00,  7.25it/s]
Epoch 81/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 26.18it/s]
Epoch 81/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 29.12it/s]


Epoch 81/200 - Train Loss: 0.0497, Train Acc: 0.9671, Val Loss: 0.0525, Val Acc: 0.9662


Epoch 82/200 - Training: 100%|██████████| 15/15 [00:02<00:00,  7.28it/s]
Epoch 82/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 27.63it/s]
Epoch 82/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 27.18it/s]


Epoch 82/200 - Train Loss: 0.0495, Train Acc: 0.9668, Val Loss: 0.0522, Val Acc: 0.9669


Epoch 83/200 - Training: 100%|██████████| 15/15 [00:02<00:00,  7.06it/s]
Epoch 83/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 25.29it/s]
Epoch 83/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 26.03it/s]


Epoch 83/200 - Train Loss: 0.0486, Train Acc: 0.9671, Val Loss: 0.0518, Val Acc: 0.9660


Epoch 84/200 - Training: 100%|██████████| 15/15 [00:02<00:00,  6.83it/s]
Epoch 84/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 25.46it/s]
Epoch 84/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 26.11it/s]


Epoch 84/200 - Train Loss: 0.0483, Train Acc: 0.9670, Val Loss: 0.0513, Val Acc: 0.9666


Epoch 85/200 - Training: 100%|██████████| 15/15 [00:02<00:00,  6.98it/s]
Epoch 85/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 25.17it/s]
Epoch 85/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 26.13it/s]


Epoch 85/200 - Train Loss: 0.0475, Train Acc: 0.9673, Val Loss: 0.0509, Val Acc: 0.9669


Epoch 86/200 - Training: 100%|██████████| 15/15 [00:02<00:00,  7.08it/s]
Epoch 86/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 26.02it/s]
Epoch 86/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 28.61it/s]


Epoch 86/200 - Train Loss: 0.0466, Train Acc: 0.9678, Val Loss: 0.0506, Val Acc: 0.9665


Epoch 87/200 - Training: 100%|██████████| 15/15 [00:02<00:00,  7.18it/s]
Epoch 87/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 28.46it/s]
Epoch 87/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.12it/s]


Epoch 87/200 - Train Loss: 0.0461, Train Acc: 0.9679, Val Loss: 0.0499, Val Acc: 0.9664


Epoch 88/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.76it/s]
Epoch 88/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 28.93it/s]
Epoch 88/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.21it/s]


Epoch 88/200 - Train Loss: 0.0453, Train Acc: 0.9689, Val Loss: 0.0499, Val Acc: 0.9666


Epoch 89/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.78it/s]
Epoch 89/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.07it/s]
Epoch 89/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.11it/s]


Epoch 89/200 - Train Loss: 0.0447, Train Acc: 0.9680, Val Loss: 0.0496, Val Acc: 0.9670


Epoch 90/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.73it/s]
Epoch 90/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.01it/s]
Epoch 90/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 29.97it/s]


Epoch 90/200 - Train Loss: 0.0445, Train Acc: 0.9686, Val Loss: 0.0491, Val Acc: 0.9667


Epoch 91/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.77it/s]
Epoch 91/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 28.73it/s]
Epoch 91/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.03it/s]


Epoch 91/200 - Train Loss: 0.0438, Train Acc: 0.9687, Val Loss: 0.0485, Val Acc: 0.9675


Epoch 92/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.66it/s]
Epoch 92/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.49it/s]
Epoch 92/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.56it/s]


Epoch 92/200 - Train Loss: 0.0434, Train Acc: 0.9685, Val Loss: 0.0481, Val Acc: 0.9678


Epoch 93/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.89it/s]
Epoch 93/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.40it/s]
Epoch 93/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.72it/s]


Epoch 93/200 - Train Loss: 0.0426, Train Acc: 0.9692, Val Loss: 0.0477, Val Acc: 0.9676


Epoch 94/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.87it/s]
Epoch 94/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.43it/s]
Epoch 94/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.70it/s]


Epoch 94/200 - Train Loss: 0.0423, Train Acc: 0.9694, Val Loss: 0.0474, Val Acc: 0.9677


Epoch 95/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.97it/s]
Epoch 95/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.50it/s]
Epoch 95/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.87it/s]


Epoch 95/200 - Train Loss: 0.0420, Train Acc: 0.9692, Val Loss: 0.0469, Val Acc: 0.9680


Epoch 96/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.90it/s]
Epoch 96/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.61it/s]
Epoch 96/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.76it/s]


Epoch 96/200 - Train Loss: 0.0408, Train Acc: 0.9696, Val Loss: 0.0466, Val Acc: 0.9684


Epoch 97/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.92it/s]
Epoch 97/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.41it/s]
Epoch 97/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.57it/s]


Epoch 97/200 - Train Loss: 0.0409, Train Acc: 0.9696, Val Loss: 0.0463, Val Acc: 0.9678


Epoch 98/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.91it/s]
Epoch 98/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.52it/s]
Epoch 98/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.89it/s]


Epoch 98/200 - Train Loss: 0.0404, Train Acc: 0.9702, Val Loss: 0.0464, Val Acc: 0.9679


Epoch 99/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.96it/s]
Epoch 99/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.39it/s]
Epoch 99/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.97it/s]


Epoch 99/200 - Train Loss: 0.0402, Train Acc: 0.9699, Val Loss: 0.0463, Val Acc: 0.9681


Epoch 100/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.91it/s]
Epoch 100/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.54it/s]
Epoch 100/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.79it/s]


Epoch 100/200 - Train Loss: 0.0400, Train Acc: 0.9699, Val Loss: 0.0463, Val Acc: 0.9681


Epoch 101/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.94it/s]
Epoch 101/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 30.12it/s]
Epoch 101/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.53it/s]


Epoch 101/200 - Train Loss: 0.0399, Train Acc: 0.9701, Val Loss: 0.0454, Val Acc: 0.9688


Epoch 102/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.58it/s]
Epoch 102/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 30.11it/s]
Epoch 102/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.40it/s]


Epoch 102/200 - Train Loss: 0.0388, Train Acc: 0.9704, Val Loss: 0.0450, Val Acc: 0.9684


Epoch 103/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.13it/s]
Epoch 103/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 30.32it/s]
Epoch 103/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.64it/s]


Epoch 103/200 - Train Loss: 0.0385, Train Acc: 0.9708, Val Loss: 0.0446, Val Acc: 0.9682


Epoch 104/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.95it/s]
Epoch 104/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 30.35it/s]
Epoch 104/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.74it/s]


Epoch 104/200 - Train Loss: 0.0385, Train Acc: 0.9705, Val Loss: 0.0446, Val Acc: 0.9697


Epoch 105/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.11it/s]
Epoch 105/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 30.34it/s]
Epoch 105/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.63it/s]


Epoch 105/200 - Train Loss: 0.0374, Train Acc: 0.9710, Val Loss: 0.0443, Val Acc: 0.9687


Epoch 106/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.08it/s]
Epoch 106/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 30.39it/s]
Epoch 106/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.50it/s]


Epoch 106/200 - Train Loss: 0.0374, Train Acc: 0.9713, Val Loss: 0.0442, Val Acc: 0.9684


Epoch 107/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.04it/s]
Epoch 107/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.92it/s]
Epoch 107/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.97it/s]


Epoch 107/200 - Train Loss: 0.0369, Train Acc: 0.9710, Val Loss: 0.0440, Val Acc: 0.9689


Epoch 108/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.99it/s]
Epoch 108/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.84it/s]
Epoch 108/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.19it/s]


Epoch 108/200 - Train Loss: 0.0372, Train Acc: 0.9711, Val Loss: 0.0435, Val Acc: 0.9692


Epoch 109/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.85it/s]
Epoch 109/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.59it/s]
Epoch 109/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.74it/s]


Epoch 109/200 - Train Loss: 0.0357, Train Acc: 0.9721, Val Loss: 0.0430, Val Acc: 0.9693


Epoch 110/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.91it/s]
Epoch 110/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.62it/s]
Epoch 110/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.87it/s]


Epoch 110/200 - Train Loss: 0.0349, Train Acc: 0.9717, Val Loss: 0.0427, Val Acc: 0.9693


Epoch 111/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.85it/s]
Epoch 111/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.35it/s]
Epoch 111/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.60it/s]


Epoch 111/200 - Train Loss: 0.0347, Train Acc: 0.9726, Val Loss: 0.0423, Val Acc: 0.9695


Epoch 112/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.88it/s]
Epoch 112/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.37it/s]
Epoch 112/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.66it/s]


Epoch 112/200 - Train Loss: 0.0346, Train Acc: 0.9724, Val Loss: 0.0419, Val Acc: 0.9698


Epoch 113/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.90it/s]
Epoch 113/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.37it/s]
Epoch 113/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.00it/s]


Epoch 113/200 - Train Loss: 0.0336, Train Acc: 0.9729, Val Loss: 0.0421, Val Acc: 0.9697


Epoch 114/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.88it/s]
Epoch 114/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.49it/s]
Epoch 114/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.40it/s]


Epoch 114/200 - Train Loss: 0.0337, Train Acc: 0.9727, Val Loss: 0.0416, Val Acc: 0.9695


Epoch 115/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.83it/s]
Epoch 115/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.29it/s]
Epoch 115/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.63it/s]


Epoch 115/200 - Train Loss: 0.0331, Train Acc: 0.9728, Val Loss: 0.0415, Val Acc: 0.9695


Epoch 116/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.88it/s]
Epoch 116/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.44it/s]
Epoch 116/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.74it/s]


Epoch 116/200 - Train Loss: 0.0332, Train Acc: 0.9725, Val Loss: 0.0415, Val Acc: 0.9699


Epoch 117/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.88it/s]
Epoch 117/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.28it/s]
Epoch 117/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.62it/s]


Epoch 117/200 - Train Loss: 0.0326, Train Acc: 0.9732, Val Loss: 0.0411, Val Acc: 0.9699


Epoch 118/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.87it/s]
Epoch 118/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.41it/s]
Epoch 118/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.57it/s]


Epoch 118/200 - Train Loss: 0.0329, Train Acc: 0.9730, Val Loss: 0.0409, Val Acc: 0.9709


Epoch 119/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.96it/s]
Epoch 119/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.53it/s]
Epoch 119/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.98it/s]


Epoch 119/200 - Train Loss: 0.0321, Train Acc: 0.9737, Val Loss: 0.0408, Val Acc: 0.9705


Epoch 120/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.13it/s]
Epoch 120/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.98it/s]
Epoch 120/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.00it/s]


Epoch 120/200 - Train Loss: 0.0332, Train Acc: 0.9727, Val Loss: 0.0403, Val Acc: 0.9703


Epoch 121/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.23it/s]
Epoch 121/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.14it/s]
Epoch 121/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.34it/s]


Epoch 121/200 - Train Loss: 0.0322, Train Acc: 0.9734, Val Loss: 0.0410, Val Acc: 0.9698


Epoch 122/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.27it/s]
Epoch 122/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 30.75it/s]
Epoch 122/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.99it/s]


Epoch 122/200 - Train Loss: 0.0320, Train Acc: 0.9733, Val Loss: 0.0405, Val Acc: 0.9696


Epoch 123/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.26it/s]
Epoch 123/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 30.93it/s]
Epoch 123/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.29it/s]


Epoch 123/200 - Train Loss: 0.0315, Train Acc: 0.9742, Val Loss: 0.0400, Val Acc: 0.9705


Epoch 124/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.28it/s]
Epoch 124/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 30.79it/s]
Epoch 124/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.18it/s]


Epoch 124/200 - Train Loss: 0.0315, Train Acc: 0.9733, Val Loss: 0.0396, Val Acc: 0.9711


Epoch 125/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.73it/s]
Epoch 125/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 30.82it/s]
Epoch 125/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.82it/s]


Epoch 125/200 - Train Loss: 0.0305, Train Acc: 0.9742, Val Loss: 0.0397, Val Acc: 0.9702


Epoch 126/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.15it/s]
Epoch 126/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 30.86it/s]
Epoch 126/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.89it/s]


Epoch 126/200 - Train Loss: 0.0299, Train Acc: 0.9741, Val Loss: 0.0393, Val Acc: 0.9711


Epoch 127/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.19it/s]
Epoch 127/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.99it/s]
Epoch 127/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.23it/s]


Epoch 127/200 - Train Loss: 0.0305, Train Acc: 0.9746, Val Loss: 0.0394, Val Acc: 0.9704


Epoch 128/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.05it/s]
Epoch 128/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.41it/s]
Epoch 128/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.45it/s]


Epoch 128/200 - Train Loss: 0.0295, Train Acc: 0.9745, Val Loss: 0.0390, Val Acc: 0.9711


Epoch 129/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.13it/s]
Epoch 129/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.54it/s]
Epoch 129/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.71it/s]


Epoch 129/200 - Train Loss: 0.0301, Train Acc: 0.9745, Val Loss: 0.0390, Val Acc: 0.9708


Epoch 130/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.66it/s]
Epoch 130/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.02it/s]
Epoch 130/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.44it/s]


Epoch 130/200 - Train Loss: 0.0292, Train Acc: 0.9752, Val Loss: 0.0390, Val Acc: 0.9713


Epoch 131/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.39it/s]
Epoch 131/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 30.67it/s]
Epoch 131/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.33it/s]


Epoch 131/200 - Train Loss: 0.0294, Train Acc: 0.9749, Val Loss: 0.0391, Val Acc: 0.9710


Epoch 132/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.47it/s]
Epoch 132/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.17it/s]
Epoch 132/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.71it/s]


Epoch 132/200 - Train Loss: 0.0291, Train Acc: 0.9750, Val Loss: 0.0386, Val Acc: 0.9709


Epoch 133/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.70it/s]
Epoch 133/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.91it/s]
Epoch 133/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.82it/s]


Epoch 133/200 - Train Loss: 0.0289, Train Acc: 0.9747, Val Loss: 0.0382, Val Acc: 0.9710


Epoch 134/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.53it/s]
Epoch 134/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.81it/s]
Epoch 134/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.59it/s]


Epoch 134/200 - Train Loss: 0.0284, Train Acc: 0.9756, Val Loss: 0.0386, Val Acc: 0.9720


Epoch 135/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.67it/s]
Epoch 135/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.96it/s]
Epoch 135/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.86it/s]


Epoch 135/200 - Train Loss: 0.0275, Train Acc: 0.9765, Val Loss: 0.0377, Val Acc: 0.9717


Epoch 136/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.85it/s]
Epoch 136/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.03it/s]
Epoch 136/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.89it/s]


Epoch 136/200 - Train Loss: 0.0281, Train Acc: 0.9754, Val Loss: 0.0378, Val Acc: 0.9720


Epoch 137/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.71it/s]
Epoch 137/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.88it/s]
Epoch 137/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.86it/s]


Epoch 137/200 - Train Loss: 0.0274, Train Acc: 0.9767, Val Loss: 0.0377, Val Acc: 0.9722


Epoch 138/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.81it/s]
Epoch 138/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.12it/s]
Epoch 138/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.92it/s]


Epoch 138/200 - Train Loss: 0.0277, Train Acc: 0.9764, Val Loss: 0.0377, Val Acc: 0.9722


Epoch 139/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.51it/s]
Epoch 139/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 28.81it/s]
Epoch 139/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.94it/s]


Epoch 139/200 - Train Loss: 0.0279, Train Acc: 0.9767, Val Loss: 0.0376, Val Acc: 0.9723


Epoch 140/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.26it/s]
Epoch 140/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.84it/s]
Epoch 140/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.48it/s]


Epoch 140/200 - Train Loss: 0.0274, Train Acc: 0.9759, Val Loss: 0.0381, Val Acc: 0.9728


Epoch 141/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.37it/s]
Epoch 141/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.87it/s]
Epoch 141/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.83it/s]


Epoch 141/200 - Train Loss: 0.0278, Train Acc: 0.9764, Val Loss: 0.0378, Val Acc: 0.9736


Epoch 142/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.68it/s]
Epoch 142/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.05it/s]
Epoch 142/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.73it/s]


Epoch 142/200 - Train Loss: 0.0273, Train Acc: 0.9766, Val Loss: 0.0380, Val Acc: 0.9726


Epoch 143/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.82it/s]
Epoch 143/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.18it/s]
Epoch 143/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.82it/s]


Epoch 143/200 - Train Loss: 0.0275, Train Acc: 0.9763, Val Loss: 0.0376, Val Acc: 0.9729


Epoch 144/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.46it/s]
Epoch 144/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 30.75it/s]
Epoch 144/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.34it/s]


Epoch 144/200 - Train Loss: 0.0275, Train Acc: 0.9765, Val Loss: 0.0377, Val Acc: 0.9726


Epoch 145/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.34it/s]
Epoch 145/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.04it/s]
Epoch 145/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.79it/s]


Epoch 145/200 - Train Loss: 0.0275, Train Acc: 0.9766, Val Loss: 0.0377, Val Acc: 0.9731


Epoch 146/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.37it/s]
Epoch 146/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.42it/s]
Epoch 146/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.10it/s]


Epoch 146/200 - Train Loss: 0.0264, Train Acc: 0.9773, Val Loss: 0.0368, Val Acc: 0.9727


Epoch 147/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.39it/s]
Epoch 147/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.38it/s]
Epoch 147/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.97it/s]


Epoch 147/200 - Train Loss: 0.0264, Train Acc: 0.9776, Val Loss: 0.0368, Val Acc: 0.9746


Epoch 148/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.85it/s]
Epoch 148/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.98it/s]
Epoch 148/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.73it/s]


Epoch 148/200 - Train Loss: 0.0256, Train Acc: 0.9781, Val Loss: 0.0363, Val Acc: 0.9736


Epoch 149/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.83it/s]
Epoch 149/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.24it/s]
Epoch 149/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.85it/s]


Epoch 149/200 - Train Loss: 0.0251, Train Acc: 0.9783, Val Loss: 0.0364, Val Acc: 0.9738


Epoch 150/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.61it/s]
Epoch 150/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.30it/s]
Epoch 150/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.80it/s]


Epoch 150/200 - Train Loss: 0.0256, Train Acc: 0.9782, Val Loss: 0.0365, Val Acc: 0.9735


Epoch 151/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.75it/s]
Epoch 151/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.20it/s]
Epoch 151/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.91it/s]


Epoch 151/200 - Train Loss: 0.0251, Train Acc: 0.9781, Val Loss: 0.0361, Val Acc: 0.9739


Epoch 152/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.83it/s]
Epoch 152/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.44it/s]
Epoch 152/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.06it/s]


Epoch 152/200 - Train Loss: 0.0245, Train Acc: 0.9789, Val Loss: 0.0365, Val Acc: 0.9742


Epoch 153/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.53it/s]
Epoch 153/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.25it/s]
Epoch 153/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 35.13it/s]


Epoch 153/200 - Train Loss: 0.0242, Train Acc: 0.9797, Val Loss: 0.0357, Val Acc: 0.9742


Epoch 154/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.51it/s]
Epoch 154/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.45it/s]
Epoch 154/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.92it/s]


Epoch 154/200 - Train Loss: 0.0244, Train Acc: 0.9784, Val Loss: 0.0363, Val Acc: 0.9745


Epoch 155/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.32it/s]
Epoch 155/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 30.69it/s]
Epoch 155/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.93it/s]


Epoch 155/200 - Train Loss: 0.0249, Train Acc: 0.9783, Val Loss: 0.0364, Val Acc: 0.9749


Epoch 156/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.71it/s]
Epoch 156/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.30it/s]
Epoch 156/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.78it/s]


Epoch 156/200 - Train Loss: 0.0233, Train Acc: 0.9804, Val Loss: 0.0358, Val Acc: 0.9745


Epoch 157/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.82it/s]
Epoch 157/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.32it/s]
Epoch 157/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.91it/s]


Epoch 157/200 - Train Loss: 0.0240, Train Acc: 0.9796, Val Loss: 0.0371, Val Acc: 0.9742


Epoch 158/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.63it/s]
Epoch 158/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 28.41it/s]
Epoch 158/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.38it/s]


Epoch 158/200 - Train Loss: 0.0239, Train Acc: 0.9795, Val Loss: 0.0358, Val Acc: 0.9755


Epoch 159/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.87it/s]
Epoch 159/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.15it/s]
Epoch 159/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.37it/s]


Epoch 159/200 - Train Loss: 0.0235, Train Acc: 0.9799, Val Loss: 0.0362, Val Acc: 0.9747


Epoch 160/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.85it/s]
Epoch 160/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.30it/s]
Epoch 160/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.72it/s]


Epoch 160/200 - Train Loss: 0.0236, Train Acc: 0.9804, Val Loss: 0.0364, Val Acc: 0.9757


Epoch 161/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.72it/s]
Epoch 161/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.32it/s]
Epoch 161/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.84it/s]


Epoch 161/200 - Train Loss: 0.0232, Train Acc: 0.9805, Val Loss: 0.0362, Val Acc: 0.9749


Epoch 162/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.04it/s]
Epoch 162/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 30.56it/s]
Epoch 162/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.29it/s]


Epoch 162/200 - Train Loss: 0.0242, Train Acc: 0.9796, Val Loss: 0.0360, Val Acc: 0.9763


Epoch 163/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.51it/s]
Epoch 163/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.00it/s]
Epoch 163/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 35.08it/s]


Epoch 163/200 - Train Loss: 0.0230, Train Acc: 0.9812, Val Loss: 0.0363, Val Acc: 0.9748


Epoch 164/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.09it/s]
Epoch 164/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.38it/s]
Epoch 164/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.31it/s]


Epoch 164/200 - Train Loss: 0.0231, Train Acc: 0.9806, Val Loss: 0.0359, Val Acc: 0.9761


Epoch 165/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.70it/s]
Epoch 165/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.95it/s]
Epoch 165/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.59it/s]


Epoch 165/200 - Train Loss: 0.0228, Train Acc: 0.9811, Val Loss: 0.0354, Val Acc: 0.9752


Epoch 166/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.73it/s]
Epoch 166/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.50it/s]
Epoch 166/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.90it/s]


Epoch 166/200 - Train Loss: 0.0218, Train Acc: 0.9810, Val Loss: 0.0354, Val Acc: 0.9765


Epoch 167/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.61it/s]
Epoch 167/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.23it/s]
Epoch 167/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.97it/s]


Epoch 167/200 - Train Loss: 0.0220, Train Acc: 0.9820, Val Loss: 0.0354, Val Acc: 0.9764


Epoch 168/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.84it/s]
Epoch 168/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.13it/s]
Epoch 168/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.88it/s]


Epoch 168/200 - Train Loss: 0.0233, Train Acc: 0.9809, Val Loss: 0.0352, Val Acc: 0.9768


Epoch 169/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.86it/s]
Epoch 169/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.01it/s]
Epoch 169/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.91it/s]


Epoch 169/200 - Train Loss: 0.0224, Train Acc: 0.9810, Val Loss: 0.0352, Val Acc: 0.9760


Epoch 170/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.66it/s]
Epoch 170/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.41it/s]
Epoch 170/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.34it/s]


Epoch 170/200 - Train Loss: 0.0234, Train Acc: 0.9811, Val Loss: 0.0367, Val Acc: 0.9755


Epoch 171/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.60it/s]
Epoch 171/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.72it/s]
Epoch 171/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.63it/s]


Epoch 171/200 - Train Loss: 0.0231, Train Acc: 0.9803, Val Loss: 0.0353, Val Acc: 0.9757


Epoch 172/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.84it/s]
Epoch 172/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.32it/s]
Epoch 172/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.12it/s]


Epoch 172/200 - Train Loss: 0.0228, Train Acc: 0.9811, Val Loss: 0.0353, Val Acc: 0.9762


Epoch 173/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.92it/s]
Epoch 173/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 30.07it/s]
Epoch 173/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.33it/s]


Epoch 173/200 - Train Loss: 0.0237, Train Acc: 0.9801, Val Loss: 0.0366, Val Acc: 0.9762


Epoch 174/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.71it/s]
Epoch 174/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.58it/s]
Epoch 174/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.00it/s]


Epoch 174/200 - Train Loss: 0.0247, Train Acc: 0.9796, Val Loss: 0.0354, Val Acc: 0.9762


Epoch 175/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.87it/s]
Epoch 175/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.11it/s]
Epoch 175/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.92it/s]


Epoch 175/200 - Train Loss: 0.0254, Train Acc: 0.9797, Val Loss: 0.0359, Val Acc: 0.9753


Epoch 176/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.27it/s]
Epoch 176/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 30.68it/s]
Epoch 176/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.06it/s]


Epoch 176/200 - Train Loss: 0.0235, Train Acc: 0.9817, Val Loss: 0.0357, Val Acc: 0.9756


Epoch 177/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.19it/s]
Epoch 177/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.55it/s]
Epoch 177/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.25it/s]


Epoch 177/200 - Train Loss: 0.0231, Train Acc: 0.9798, Val Loss: 0.0355, Val Acc: 0.9763


Epoch 178/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.83it/s]
Epoch 178/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.08it/s]
Epoch 178/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.04it/s]


Epoch 178/200 - Train Loss: 0.0236, Train Acc: 0.9807, Val Loss: 0.0365, Val Acc: 0.9746


Epoch 179/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.86it/s]
Epoch 179/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.03it/s]
Epoch 179/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.90it/s]


Epoch 179/200 - Train Loss: 0.0224, Train Acc: 0.9809, Val Loss: 0.0351, Val Acc: 0.9766


Epoch 180/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.79it/s]
Epoch 180/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.13it/s]
Epoch 180/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.51it/s]


Epoch 180/200 - Train Loss: 0.0222, Train Acc: 0.9807, Val Loss: 0.0349, Val Acc: 0.9765


Epoch 181/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.66it/s]
Epoch 181/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.66it/s]
Epoch 181/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.39it/s]


Epoch 181/200 - Train Loss: 0.0216, Train Acc: 0.9821, Val Loss: 0.0347, Val Acc: 0.9765


Epoch 182/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.89it/s]
Epoch 182/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.17it/s]
Epoch 182/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.81it/s]


Epoch 182/200 - Train Loss: 0.0219, Train Acc: 0.9815, Val Loss: 0.0348, Val Acc: 0.9763


Epoch 183/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.83it/s]
Epoch 183/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.14it/s]
Epoch 183/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.74it/s]


Epoch 183/200 - Train Loss: 0.0210, Train Acc: 0.9821, Val Loss: 0.0349, Val Acc: 0.9764


Epoch 184/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.84it/s]
Epoch 184/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.04it/s]
Epoch 184/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.69it/s]


Epoch 184/200 - Train Loss: 0.0208, Train Acc: 0.9825, Val Loss: 0.0352, Val Acc: 0.9779


Epoch 185/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.83it/s]
Epoch 185/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.27it/s]
Epoch 185/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.90it/s]


Epoch 185/200 - Train Loss: 0.0208, Train Acc: 0.9823, Val Loss: 0.0346, Val Acc: 0.9773


Epoch 186/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.81it/s]
Epoch 186/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.23it/s]
Epoch 186/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 35.07it/s]


Epoch 186/200 - Train Loss: 0.0201, Train Acc: 0.9829, Val Loss: 0.0348, Val Acc: 0.9764


Epoch 187/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.88it/s]
Epoch 187/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.25it/s]
Epoch 187/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.66it/s]


Epoch 187/200 - Train Loss: 0.0214, Train Acc: 0.9814, Val Loss: 0.0353, Val Acc: 0.9769


Epoch 188/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.81it/s]
Epoch 188/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.01it/s]
Epoch 188/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.87it/s]


Epoch 188/200 - Train Loss: 0.0214, Train Acc: 0.9826, Val Loss: 0.0358, Val Acc: 0.9767


Epoch 189/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.87it/s]
Epoch 189/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.09it/s]
Epoch 189/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.81it/s]


Epoch 189/200 - Train Loss: 0.0201, Train Acc: 0.9830, Val Loss: 0.0353, Val Acc: 0.9771


Epoch 190/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.78it/s]
Epoch 190/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.82it/s]
Epoch 190/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.73it/s]


Epoch 190/200 - Train Loss: 0.0217, Train Acc: 0.9821, Val Loss: 0.0355, Val Acc: 0.9768


Epoch 191/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.85it/s]
Epoch 191/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.15it/s]
Epoch 191/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.73it/s]


Epoch 191/200 - Train Loss: 0.0207, Train Acc: 0.9822, Val Loss: 0.0347, Val Acc: 0.9773


Epoch 192/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.87it/s]
Epoch 192/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.08it/s]
Epoch 192/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 35.11it/s]


Epoch 192/200 - Train Loss: 0.0205, Train Acc: 0.9826, Val Loss: 0.0343, Val Acc: 0.9785


Epoch 193/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.84it/s]
Epoch 193/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.17it/s]
Epoch 193/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.94it/s]


Epoch 193/200 - Train Loss: 0.0204, Train Acc: 0.9831, Val Loss: 0.0345, Val Acc: 0.9786


Epoch 194/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.83it/s]
Epoch 194/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.01it/s]
Epoch 194/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.84it/s]


Epoch 194/200 - Train Loss: 0.0200, Train Acc: 0.9837, Val Loss: 0.0353, Val Acc: 0.9772


Epoch 195/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.84it/s]
Epoch 195/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.18it/s]
Epoch 195/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.89it/s]


Epoch 195/200 - Train Loss: 0.0206, Train Acc: 0.9827, Val Loss: 0.0348, Val Acc: 0.9772


Epoch 196/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.60it/s]
Epoch 196/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.22it/s]
Epoch 196/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.33it/s]


Epoch 196/200 - Train Loss: 0.0201, Train Acc: 0.9829, Val Loss: 0.0347, Val Acc: 0.9773


Epoch 197/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.60it/s]
Epoch 197/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.80it/s]
Epoch 197/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.73it/s]


Epoch 197/200 - Train Loss: 0.0201, Train Acc: 0.9832, Val Loss: 0.0353, Val Acc: 0.9767


Epoch 198/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.40it/s]
Epoch 198/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 30.74it/s]
Epoch 198/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.12it/s]


Epoch 198/200 - Train Loss: 0.0199, Train Acc: 0.9833, Val Loss: 0.0347, Val Acc: 0.9780


Epoch 199/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.51it/s]
Epoch 199/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.47it/s]
Epoch 199/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.84it/s]


Epoch 199/200 - Train Loss: 0.0198, Train Acc: 0.9828, Val Loss: 0.0347, Val Acc: 0.9777


Epoch 200/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.54it/s]
Epoch 200/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.03it/s]
Epoch 200/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.90it/s]

Epoch 200/200 - Train Loss: 0.0191, Train Acc: 0.9839, Val Loss: 0.0349, Val Acc: 0.9774





0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▆▆▆▇▇▇▇▇▇▇▇▇██████
hamming_loss,████████████████▇▇▇▇▆▆▅▅▅▅▅▄▅▅▃▂▂▂▁▂▁▁▁▁
lr,▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▃▃▃▃▃▄▄▅▅▅▅▆▆▆▆▆▇▇▇█
macro_f1,▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▇▇████
macro_precision,▁███████████████▇▇██▇▇▇▇▇▇▇▇▆▇▆▆▆▇▆▆▆▆▆▆
macro_recall,▂▁▁▁▁▁▁▁▁▁▁▁▂▂▂▄▄▄▄▄▅▅▅▅▅▆▅▆▆▇▆▆▆▇███▇██
micro_f1,▁▁▁▁▁▁▁▁▁▁▁▁▂▃▂▂▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▇▇▇▇▇█▇██
micro_precision,██████████▁▁▆▁▆▅▇▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
micro_recall,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▃▃▃▄▄▄▄▅▅▆▆▆▆▇█▇▇██
test_acc,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▃▃▄▄▄▄▄▅▅▄▅▆▇▇▇▆█▇██

0,1
best_val_loss,0.03434
epoch,199.0
hamming_loss,0.02262
lr,4e-05
macro_f1,0.40395
macro_precision,0.81217
macro_recall,0.33485
micro_f1,0.56873
micro_precision,0.86506
micro_recall,0.42362


[34m[1mwandb[0m: Currently logged in as: [33mjoaocsgalmeida[0m ([33mjoaocsgalmeida-university-of-southampton[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin



----- Training Ensemble Model 2/5 -----



Epoch 1/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.47it/s]
Epoch 1/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.45it/s]
Epoch 1/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.52it/s]


Epoch 1/200 - Train Loss: 0.5215, Train Acc: 0.7864, Val Loss: 0.3781, Val Acc: 0.9531


Epoch 2/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.54it/s]
Epoch 2/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.00it/s]
Epoch 2/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.77it/s]


Epoch 2/200 - Train Loss: 0.3298, Train Acc: 0.9616, Val Loss: 0.2757, Val Acc: 0.9649


Epoch 3/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.54it/s]
Epoch 3/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.55it/s]
Epoch 3/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.93it/s]


Epoch 3/200 - Train Loss: 0.2587, Train Acc: 0.9638, Val Loss: 0.2283, Val Acc: 0.9650


Epoch 4/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.54it/s]
Epoch 4/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.94it/s]
Epoch 4/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.84it/s]


Epoch 4/200 - Train Loss: 0.2225, Train Acc: 0.9638, Val Loss: 0.2015, Val Acc: 0.9650


Epoch 5/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.51it/s]
Epoch 5/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.84it/s]
Epoch 5/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.10it/s]


Epoch 5/200 - Train Loss: 0.2011, Train Acc: 0.9638, Val Loss: 0.1843, Val Acc: 0.9650


Epoch 6/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.56it/s]
Epoch 6/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.87it/s]
Epoch 6/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.96it/s]


Epoch 6/200 - Train Loss: 0.1870, Train Acc: 0.9639, Val Loss: 0.1725, Val Acc: 0.9650


Epoch 7/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.54it/s]
Epoch 7/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.94it/s]
Epoch 7/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.91it/s]


Epoch 7/200 - Train Loss: 0.1769, Train Acc: 0.9639, Val Loss: 0.1635, Val Acc: 0.9650


Epoch 8/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.98it/s]
Epoch 8/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 26.55it/s]
Epoch 8/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 29.54it/s]


Epoch 8/200 - Train Loss: 0.1690, Train Acc: 0.9639, Val Loss: 0.1565, Val Acc: 0.9650


Epoch 9/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.12it/s]
Epoch 9/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.82it/s]
Epoch 9/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.07it/s]


Epoch 9/200 - Train Loss: 0.1626, Train Acc: 0.9638, Val Loss: 0.1506, Val Acc: 0.9650


Epoch 10/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.01it/s]
Epoch 10/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.55it/s]
Epoch 10/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.89it/s]


Epoch 10/200 - Train Loss: 0.1572, Train Acc: 0.9639, Val Loss: 0.1454, Val Acc: 0.9650


Epoch 11/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.52it/s]
Epoch 11/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.53it/s]
Epoch 11/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.47it/s]


Epoch 11/200 - Train Loss: 0.1524, Train Acc: 0.9639, Val Loss: 0.1408, Val Acc: 0.9650


Epoch 12/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.39it/s]
Epoch 12/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 26.64it/s]
Epoch 12/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 29.16it/s]


Epoch 12/200 - Train Loss: 0.1479, Train Acc: 0.9638, Val Loss: 0.1363, Val Acc: 0.9650


Epoch 13/200 - Training: 100%|██████████| 15/15 [00:02<00:00,  7.40it/s]
Epoch 13/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 28.55it/s]
Epoch 13/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 28.55it/s]


Epoch 13/200 - Train Loss: 0.1439, Train Acc: 0.9639, Val Loss: 0.1324, Val Acc: 0.9650


Epoch 14/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.08it/s]
Epoch 14/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.78it/s]
Epoch 14/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 29.82it/s]


Epoch 14/200 - Train Loss: 0.1401, Train Acc: 0.9639, Val Loss: 0.1287, Val Acc: 0.9650


Epoch 15/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.85it/s]
Epoch 15/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 27.31it/s]
Epoch 15/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 28.23it/s]


Epoch 15/200 - Train Loss: 0.1365, Train Acc: 0.9640, Val Loss: 0.1254, Val Acc: 0.9650


Epoch 16/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.99it/s]
Epoch 16/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.63it/s]
Epoch 16/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.96it/s]


Epoch 16/200 - Train Loss: 0.1332, Train Acc: 0.9638, Val Loss: 0.1219, Val Acc: 0.9650


Epoch 17/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.50it/s]
Epoch 17/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.81it/s]
Epoch 17/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.99it/s]


Epoch 17/200 - Train Loss: 0.1301, Train Acc: 0.9639, Val Loss: 0.1189, Val Acc: 0.9650


Epoch 18/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.38it/s]
Epoch 18/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.47it/s]
Epoch 18/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.93it/s]


Epoch 18/200 - Train Loss: 0.1271, Train Acc: 0.9639, Val Loss: 0.1161, Val Acc: 0.9650


Epoch 19/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.41it/s]
Epoch 19/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 30.77it/s]
Epoch 19/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.11it/s]


Epoch 19/200 - Train Loss: 0.1243, Train Acc: 0.9637, Val Loss: 0.1132, Val Acc: 0.9650


Epoch 20/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.67it/s]
Epoch 20/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 30.24it/s]
Epoch 20/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.61it/s]


Epoch 20/200 - Train Loss: 0.1216, Train Acc: 0.9638, Val Loss: 0.1108, Val Acc: 0.9650


Epoch 21/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.54it/s]
Epoch 21/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.21it/s]
Epoch 21/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 35.23it/s]


Epoch 21/200 - Train Loss: 0.1191, Train Acc: 0.9639, Val Loss: 0.1084, Val Acc: 0.9650


Epoch 22/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.56it/s]
Epoch 22/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 30.00it/s]
Epoch 22/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.57it/s]


Epoch 22/200 - Train Loss: 0.1167, Train Acc: 0.9638, Val Loss: 0.1061, Val Acc: 0.9650


Epoch 23/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.72it/s]
Epoch 23/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.48it/s]
Epoch 23/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.57it/s]


Epoch 23/200 - Train Loss: 0.1144, Train Acc: 0.9639, Val Loss: 0.1041, Val Acc: 0.9650


Epoch 24/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.32it/s]
Epoch 24/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.25it/s]
Epoch 24/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.42it/s]


Epoch 24/200 - Train Loss: 0.1122, Train Acc: 0.9638, Val Loss: 0.1019, Val Acc: 0.9650


Epoch 25/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.23it/s]
Epoch 25/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.30it/s]
Epoch 25/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.96it/s]


Epoch 25/200 - Train Loss: 0.1102, Train Acc: 0.9640, Val Loss: 0.1002, Val Acc: 0.9650


Epoch 26/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.45it/s]
Epoch 26/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 30.05it/s]
Epoch 26/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 35.08it/s]


Epoch 26/200 - Train Loss: 0.1082, Train Acc: 0.9638, Val Loss: 0.0982, Val Acc: 0.9650


Epoch 27/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.45it/s]
Epoch 27/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.37it/s]
Epoch 27/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 35.10it/s]


Epoch 27/200 - Train Loss: 0.1063, Train Acc: 0.9639, Val Loss: 0.0966, Val Acc: 0.9650


Epoch 28/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.73it/s]
Epoch 28/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.57it/s]
Epoch 28/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.97it/s]


Epoch 28/200 - Train Loss: 0.1044, Train Acc: 0.9638, Val Loss: 0.0949, Val Acc: 0.9650


Epoch 29/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.70it/s]
Epoch 29/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.43it/s]
Epoch 29/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 35.02it/s]


Epoch 29/200 - Train Loss: 0.1028, Train Acc: 0.9640, Val Loss: 0.0934, Val Acc: 0.9650


Epoch 30/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.09it/s]
Epoch 30/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.88it/s]
Epoch 30/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.94it/s]


Epoch 30/200 - Train Loss: 0.1009, Train Acc: 0.9639, Val Loss: 0.0920, Val Acc: 0.9650


Epoch 31/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.37it/s]
Epoch 31/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.72it/s]
Epoch 31/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.92it/s]


Epoch 31/200 - Train Loss: 0.0993, Train Acc: 0.9639, Val Loss: 0.0904, Val Acc: 0.9650


Epoch 32/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.88it/s]
Epoch 32/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.47it/s]
Epoch 32/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.27it/s]


Epoch 32/200 - Train Loss: 0.0977, Train Acc: 0.9640, Val Loss: 0.0893, Val Acc: 0.9650


Epoch 33/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.15it/s]
Epoch 33/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 28.36it/s]
Epoch 33/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.59it/s]


Epoch 33/200 - Train Loss: 0.0961, Train Acc: 0.9640, Val Loss: 0.0880, Val Acc: 0.9650


Epoch 34/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.11it/s]
Epoch 34/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 30.56it/s]
Epoch 34/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.66it/s]


Epoch 34/200 - Train Loss: 0.0945, Train Acc: 0.9639, Val Loss: 0.0867, Val Acc: 0.9650


Epoch 35/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.97it/s]
Epoch 35/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 30.31it/s]
Epoch 35/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.96it/s]


Epoch 35/200 - Train Loss: 0.0931, Train Acc: 0.9639, Val Loss: 0.0857, Val Acc: 0.9650


Epoch 36/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.01it/s]
Epoch 36/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.90it/s]
Epoch 36/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.20it/s]


Epoch 36/200 - Train Loss: 0.0918, Train Acc: 0.9640, Val Loss: 0.0844, Val Acc: 0.9650


Epoch 37/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.31it/s]
Epoch 37/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.10it/s]
Epoch 37/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.70it/s]


Epoch 37/200 - Train Loss: 0.0902, Train Acc: 0.9640, Val Loss: 0.0835, Val Acc: 0.9650


Epoch 38/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.47it/s]
Epoch 38/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.92it/s]
Epoch 38/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 35.17it/s]


Epoch 38/200 - Train Loss: 0.0888, Train Acc: 0.9640, Val Loss: 0.0825, Val Acc: 0.9650


Epoch 39/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.67it/s]
Epoch 39/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.11it/s]
Epoch 39/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.93it/s]


Epoch 39/200 - Train Loss: 0.0874, Train Acc: 0.9639, Val Loss: 0.0812, Val Acc: 0.9650


Epoch 40/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.53it/s]
Epoch 40/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.39it/s]
Epoch 40/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 30.73it/s]


Epoch 40/200 - Train Loss: 0.0860, Train Acc: 0.9640, Val Loss: 0.0801, Val Acc: 0.9650


Epoch 41/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.69it/s]
Epoch 41/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.49it/s]
Epoch 41/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 35.27it/s]


Epoch 41/200 - Train Loss: 0.0846, Train Acc: 0.9640, Val Loss: 0.0793, Val Acc: 0.9650


Epoch 42/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  9.25it/s]
Epoch 42/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 34.23it/s]
Epoch 42/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 35.40it/s]


Epoch 42/200 - Train Loss: 0.0834, Train Acc: 0.9640, Val Loss: 0.0783, Val Acc: 0.9650


Epoch 43/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.44it/s]
Epoch 43/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 34.23it/s]
Epoch 43/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.95it/s]


Epoch 43/200 - Train Loss: 0.0824, Train Acc: 0.9639, Val Loss: 0.0774, Val Acc: 0.9650


Epoch 44/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.30it/s]
Epoch 44/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.91it/s]
Epoch 44/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 35.19it/s]


Epoch 44/200 - Train Loss: 0.0812, Train Acc: 0.9640, Val Loss: 0.0765, Val Acc: 0.9650


Epoch 45/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.64it/s]
Epoch 45/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.25it/s]
Epoch 45/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.51it/s]


Epoch 45/200 - Train Loss: 0.0797, Train Acc: 0.9639, Val Loss: 0.0752, Val Acc: 0.9650


Epoch 46/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.89it/s]
Epoch 46/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 34.19it/s]
Epoch 46/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.92it/s]


Epoch 46/200 - Train Loss: 0.0785, Train Acc: 0.9640, Val Loss: 0.0746, Val Acc: 0.9650


Epoch 47/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.78it/s]
Epoch 47/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.64it/s]
Epoch 47/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.56it/s]


Epoch 47/200 - Train Loss: 0.0773, Train Acc: 0.9640, Val Loss: 0.0740, Val Acc: 0.9650


Epoch 48/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  9.18it/s]
Epoch 48/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.75it/s]
Epoch 48/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 35.17it/s]


Epoch 48/200 - Train Loss: 0.0764, Train Acc: 0.9640, Val Loss: 0.0733, Val Acc: 0.9650


Epoch 49/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  9.10it/s]
Epoch 49/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 34.33it/s]
Epoch 49/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 35.36it/s]


Epoch 49/200 - Train Loss: 0.0758, Train Acc: 0.9641, Val Loss: 0.0720, Val Acc: 0.9650


Epoch 50/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.74it/s]
Epoch 50/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.13it/s]
Epoch 50/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.86it/s]


Epoch 50/200 - Train Loss: 0.0744, Train Acc: 0.9641, Val Loss: 0.0717, Val Acc: 0.9650


Epoch 51/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.98it/s]
Epoch 51/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.98it/s]
Epoch 51/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 35.31it/s]


Epoch 51/200 - Train Loss: 0.0732, Train Acc: 0.9643, Val Loss: 0.0708, Val Acc: 0.9650


Epoch 52/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.91it/s]
Epoch 52/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.81it/s]
Epoch 52/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.08it/s]


Epoch 52/200 - Train Loss: 0.0723, Train Acc: 0.9642, Val Loss: 0.0703, Val Acc: 0.9650


Epoch 53/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  9.12it/s]
Epoch 53/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 34.32it/s]
Epoch 53/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 35.09it/s]


Epoch 53/200 - Train Loss: 0.0713, Train Acc: 0.9642, Val Loss: 0.0693, Val Acc: 0.9650


Epoch 54/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.21it/s]
Epoch 54/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.82it/s]
Epoch 54/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 29.76it/s]


Epoch 54/200 - Train Loss: 0.0701, Train Acc: 0.9642, Val Loss: 0.0685, Val Acc: 0.9650


Epoch 55/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.94it/s]
Epoch 55/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 29.99it/s]
Epoch 55/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.47it/s]


Epoch 55/200 - Train Loss: 0.0694, Train Acc: 0.9642, Val Loss: 0.0677, Val Acc: 0.9650


Epoch 56/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.34it/s]
Epoch 56/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 28.61it/s]
Epoch 56/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 29.62it/s]


Epoch 56/200 - Train Loss: 0.0683, Train Acc: 0.9643, Val Loss: 0.0667, Val Acc: 0.9650


Epoch 57/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  7.71it/s]
Epoch 57/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.65it/s]
Epoch 57/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.06it/s]


Epoch 57/200 - Train Loss: 0.0674, Train Acc: 0.9644, Val Loss: 0.0663, Val Acc: 0.9650


Epoch 58/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.68it/s]
Epoch 58/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.58it/s]
Epoch 58/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 33.56it/s]


Epoch 58/200 - Train Loss: 0.0665, Train Acc: 0.9645, Val Loss: 0.0655, Val Acc: 0.9650


Epoch 59/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.74it/s]
Epoch 59/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.77it/s]
Epoch 59/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 35.12it/s]


Epoch 59/200 - Train Loss: 0.0657, Train Acc: 0.9646, Val Loss: 0.0650, Val Acc: 0.9651


Epoch 60/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.48it/s]
Epoch 60/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.89it/s]
Epoch 60/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 27.79it/s]


Epoch 60/200 - Train Loss: 0.0650, Train Acc: 0.9644, Val Loss: 0.0644, Val Acc: 0.9650


Epoch 61/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.44it/s]
Epoch 61/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.22it/s]
Epoch 61/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.75it/s]


Epoch 61/200 - Train Loss: 0.0640, Train Acc: 0.9646, Val Loss: 0.0638, Val Acc: 0.9651


Epoch 62/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.47it/s]
Epoch 62/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.32it/s]
Epoch 62/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.82it/s]


Epoch 62/200 - Train Loss: 0.0631, Train Acc: 0.9646, Val Loss: 0.0630, Val Acc: 0.9651


Epoch 63/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.49it/s]
Epoch 63/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.00it/s]
Epoch 63/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.80it/s]


Epoch 63/200 - Train Loss: 0.0622, Train Acc: 0.9648, Val Loss: 0.0623, Val Acc: 0.9651


Epoch 64/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.50it/s]
Epoch 64/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.03it/s]
Epoch 64/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.79it/s]


Epoch 64/200 - Train Loss: 0.0615, Train Acc: 0.9647, Val Loss: 0.0616, Val Acc: 0.9651


Epoch 65/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.55it/s]
Epoch 65/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.80it/s]
Epoch 65/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.36it/s]


Epoch 65/200 - Train Loss: 0.0609, Train Acc: 0.9648, Val Loss: 0.0608, Val Acc: 0.9652


Epoch 66/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.47it/s]
Epoch 66/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.12it/s]
Epoch 66/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 31.64it/s]


Epoch 66/200 - Train Loss: 0.0601, Train Acc: 0.9650, Val Loss: 0.0604, Val Acc: 0.9651


Epoch 67/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.33it/s]
Epoch 67/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.62it/s]
Epoch 67/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.18it/s]


Epoch 67/200 - Train Loss: 0.0594, Train Acc: 0.9652, Val Loss: 0.0602, Val Acc: 0.9651


Epoch 68/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.52it/s]
Epoch 68/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.13it/s]
Epoch 68/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.97it/s]


Epoch 68/200 - Train Loss: 0.0587, Train Acc: 0.9651, Val Loss: 0.0598, Val Acc: 0.9652


Epoch 69/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.53it/s]
Epoch 69/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 32.08it/s]
Epoch 69/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.92it/s]


Epoch 69/200 - Train Loss: 0.0581, Train Acc: 0.9651, Val Loss: 0.0587, Val Acc: 0.9653


Epoch 70/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.41it/s]
Epoch 70/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.07it/s]
Epoch 70/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.44it/s]


Epoch 70/200 - Train Loss: 0.0573, Train Acc: 0.9653, Val Loss: 0.0583, Val Acc: 0.9654


Epoch 71/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.50it/s]
Epoch 71/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.95it/s]
Epoch 71/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.93it/s]


Epoch 71/200 - Train Loss: 0.0561, Train Acc: 0.9653, Val Loss: 0.0575, Val Acc: 0.9655


Epoch 72/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.54it/s]
Epoch 72/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.86it/s]
Epoch 72/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.81it/s]


Epoch 72/200 - Train Loss: 0.0557, Train Acc: 0.9657, Val Loss: 0.0574, Val Acc: 0.9655


Epoch 73/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.54it/s]
Epoch 73/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.92it/s]
Epoch 73/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.77it/s]


Epoch 73/200 - Train Loss: 0.0549, Train Acc: 0.9660, Val Loss: 0.0569, Val Acc: 0.9655


Epoch 74/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.90it/s]
Epoch 74/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 33.85it/s]
Epoch 74/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.54it/s]


Epoch 74/200 - Train Loss: 0.0543, Train Acc: 0.9659, Val Loss: 0.0563, Val Acc: 0.9661


Epoch 75/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.71it/s]
Epoch 75/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 31.98it/s]
Epoch 75/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 32.91it/s]


Epoch 75/200 - Train Loss: 0.0537, Train Acc: 0.9660, Val Loss: 0.0557, Val Acc: 0.9660


Epoch 76/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.68it/s]
Epoch 76/200 - Validation: 100%|██████████| 12/12 [00:00<00:00, 34.96it/s]
Epoch 76/200 - Testing: 100%|██████████| 13/13 [00:00<00:00, 34.37it/s]


Epoch 76/200 - Train Loss: 0.0531, Train Acc: 0.9662, Val Loss: 0.0555, Val Acc: 0.9657


Epoch 77/200 - Training: 100%|██████████| 15/15 [00:01<00:00,  8.61it/s]
Epoch 77/200 - Validation:   0%|          | 0/12 [00:00<?, ?it/s]


KeyboardInterrupt: 

# Memory efficient implementation

In [6]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.checkpoint import checkpoint
from mamba_ssm import Mamba2

class MemoryEfficientMamba(nn.Module):
    """
    Memory-efficient wrapper for Mamba2 with gradient checkpointing and 
    optional parameter quantization.
    """
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2, use_checkpoint=True):
        super().__init__()
        self.mamba = Mamba2(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand
        )
        self.use_checkpoint = use_checkpoint
    
    def forward(self, x):
        if self.use_checkpoint and self.training:
            return checkpoint(self.mamba, x)
        else:
            return self.mamba(x)

class MemoryEfficientStarClassifier(nn.Module):
    """
    Memory-efficient version of StarClassifierFusion with various
    optimizations to reduce VRAM usage.
    """
    def __init__(
        self,
        d_model_spectra,
        d_model_gaia,
        num_classes,
        input_dim_spectra,
        input_dim_gaia,
        n_layers=6,
        use_cross_attention=True,
        n_cross_attn_heads=8,
        d_state=16,  # Reduced from 256 to save memory
        d_conv=4,
        expand=2,
        use_checkpoint=True,
        activation_checkpointing=True,
        use_half_precision=True,
        sequential_processing=True
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.activation_checkpointing = activation_checkpointing
        self.sequential_processing = sequential_processing
        
        # Use lower precision
        self.dtype = torch.float16 if use_half_precision else torch.float32

        # Input projection layers
        self.input_proj_spectra = nn.Linear(input_dim_spectra, d_model_spectra)
        self.input_proj_gaia = nn.Linear(input_dim_gaia, d_model_gaia)
        
        # Memory-efficient Mamba layers
        self.mamba_spectra_layers = nn.ModuleList([
            MemoryEfficientMamba(
                d_model=d_model_spectra,
                d_state=d_state,
                d_conv=d_conv,
                expand=expand,
                use_checkpoint=activation_checkpointing
            ) for _ in range(n_layers)
        ])
        
        self.mamba_gaia_layers = nn.ModuleList([
            MemoryEfficientMamba(
                d_model=d_model_gaia,
                d_state=d_state,
                d_conv=d_conv,
                expand=expand,
                use_checkpoint=activation_checkpointing
            ) for _ in range(n_layers)
        ])

        # Cross-attention (optional)
        self.use_cross_attention = use_cross_attention
        if use_cross_attention:
            self.cross_attn_block_spectra = self._create_cross_attn_block(d_model_spectra, n_heads=n_cross_attn_heads)
            self.cross_attn_block_gaia = self._create_cross_attn_block(d_model_gaia, n_heads=n_cross_attn_heads)

        # Final classifier
        fusion_dim = d_model_spectra + d_model_gaia
        self.layer_norm = nn.LayerNorm(fusion_dim)
        self.classifier = nn.Linear(fusion_dim, num_classes)
    
    def _create_cross_attn_block(self, d_model, n_heads):
        """Creates a cross-attention block with optional gradient checkpointing."""
        class CrossAttentionBlock(nn.Module):
            def __init__(self, d_model, n_heads):
                super().__init__()
                self.cross_attn = nn.MultiheadAttention(
                    embed_dim=d_model, 
                    num_heads=n_heads, 
                    batch_first=True
                )
                self.norm1 = nn.LayerNorm(d_model)
                
                self.ffn = nn.Sequential(
                    nn.Linear(d_model, 4 * d_model),
                    nn.ReLU(),
                    nn.Linear(4 * d_model, d_model)
                )
                self.norm2 = nn.LayerNorm(d_model)
                
            def forward(self, x_q, x_kv):
                # Cross-attention
                attn_output, _ = self.cross_attn(query=x_q, key=x_kv, value=x_kv)
                x = self.norm1(x_q + attn_output)
                
                # Feed forward
                ffn_out = self.ffn(x)
                x = self.norm2(x + ffn_out)
                
                return x
        
        block = CrossAttentionBlock(d_model, n_heads)
        
        # Wrap with gradient checkpointing if requested
        if self.activation_checkpointing:
            def forward_with_checkpoint(module, x_q, x_kv):
                def custom_forward(x_q, x_kv):
                    return module(x_q, x_kv)
                return checkpoint(custom_forward, x_q, x_kv)
            
            class CheckpointedCrossAttention(nn.Module):
                def __init__(self, block):
                    super().__init__()
                    self.block = block
                
                def forward(self, x_q, x_kv):
                    return forward_with_checkpoint(self.block, x_q, x_kv)
            
            return CheckpointedCrossAttention(block)
        else:
            return block
    
    def _process_mamba_layers(self, x, layers):
        """Process input through Mamba layers, optionally sequentially to save memory."""
        if self.sequential_processing:
            for layer in layers:
                x = layer(x)
        else:
            # Process all layers at once (uses more memory but faster)
            for layer in layers:
                x = layer(x)
        return x
    
    def forward(self, x_spectra, x_gaia):
        # Convert to half precision if requested
        if hasattr(self, 'dtype') and self.dtype == torch.float16:
            x_spectra = x_spectra.half()
            x_gaia = x_gaia.half()
        
        # Project inputs
        x_spectra = self.input_proj_spectra(x_spectra)
        x_gaia = self.input_proj_gaia(x_gaia)
        
        # Add sequence dimension if needed
        if len(x_spectra.shape) == 2:
            x_spectra = x_spectra.unsqueeze(1)
        if len(x_gaia.shape) == 2:
            x_gaia = x_gaia.unsqueeze(1)
        
        # Process through Mamba layers
        x_spectra = self._process_mamba_layers(x_spectra, self.mamba_spectra_layers)
        x_gaia = self._process_mamba_layers(x_gaia, self.mamba_gaia_layers)
        
        # Optional cross-attention
        if self.use_cross_attention:
            x_spectra_fused = self.cross_attn_block_spectra(x_spectra, x_gaia)
            x_gaia_fused = self.cross_attn_block_gaia(x_gaia, x_spectra)
            x_spectra = x_spectra_fused
            x_gaia = x_gaia_fused
        
        # Pool across sequence dimension
        x_spectra = x_spectra.mean(dim=1)
        x_gaia = x_gaia.mean(dim=1)
        
        # Concatenate
        x_fused = torch.cat([x_spectra, x_gaia], dim=-1)
        
        # Final classification
        x_fused = self.layer_norm(x_fused)
        logits = self.classifier(x_fused)
        
        return logits

class MemoryEfficientEnsemble:
    """
    Memory-efficient implementation of ensemble for uncertainty quantification.
    Only keeps one model in VRAM at a time and uses quantization and other optimizations.
    """
    def __init__(
        self, 
        model_class, 
        model_args, 
        num_models=5, 
        device='cuda',
        low_memory_mode=True,
        quantize_models=True,
        offload_optimizer=True
    ):
        """
        Initialize the memory-efficient ensemble.
        
        Args:
            model_class: Model class to instantiate
            model_args: Arguments for model initialization
            num_models: Number of models in ensemble
            device: Device to use
            low_memory_mode: Whether to use memory optimizations
            quantize_models: Whether to quantize model weights
            offload_optimizer: Whether to offload optimizer states to CPU
        """
        self.model_class = model_class
        self.model_args = model_args
        self.num_models = num_models
        self.device = device
        self.low_memory_mode = low_memory_mode
        self.quantize_models = quantize_models
        self.offload_optimizer = offload_optimizer
        
        # Storage for model states (not actual models)
        self.model_states = [None] * num_models
        
        # Create only one model instance to save memory
        self.active_model = self._create_model(0)
        self.active_model_idx = 0
    
    def _create_model(self, model_idx):
        """Create a new model instance with appropriate seed."""
        torch.manual_seed(42 + model_idx)
        np.random.seed(42 + model_idx)
        model = self.model_class(**self.model_args).to(self.device)
        
        if self.quantize_models and hasattr(torch.quantization, 'quantize_dynamic'):
            # Only quantize during inference, not training
            model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
            model = torch.quantization.prepare(model)
        
        return model
    
    def _load_model_state(self, model_idx):
        """Load a specific model's state into the active model."""
        if self.model_states[model_idx] is not None:
            self.active_model.load_state_dict(self.model_states[model_idx])
            self.active_model_idx = model_idx
    
    def train(
        self, 
        train_loader, 
        val_loader, 
        test_loader=None, 
        num_epochs=100, 
        lr=1e-4, 
        max_patience=20,
        scheduler_type='OneCycleLR',
        log_to_wandb=True,
        save_checkpoints=True,
        checkpoint_dir='checkpoints'
    ):
        """
        Train all models in the ensemble sequentially to save memory.
        """
        import os
        import wandb
        from tqdm import tqdm
        import torch.optim as optim
        
        if save_checkpoints and not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        
        for model_idx in range(self.num_models):
            print(f"\n----- Training Ensemble Model {model_idx+1}/{self.num_models} -----\n")
            
            # Initialize wandb for this model
            if log_to_wandb:
                run = wandb.init(
                    project="ALLSTARS_memory_efficient_ensemble", 
                    name=f"model_{model_idx}",
                    group="memory_efficient_training",
                    config={
                        **self.model_args,
                        "model_idx": model_idx,
                        "num_models": self.num_models,
                        "lr": lr,
                        "max_patience": max_patience,
                        "scheduler_type": scheduler_type,
                        "num_epochs": num_epochs,
                        "low_memory_mode": self.low_memory_mode,
                        "quantize_models": self.quantize_models,
                        "offload_optimizer": self.offload_optimizer
                    },
                    reinit=True
                )
            
            # Create or load model
            if model_idx == 0 and self.active_model_idx == 0:
                # Use existing model for first run
                model = self.active_model
            else:
                # Create new model with appropriate seed for subsequent runs
                model = self._create_model(model_idx)
                self.active_model = model
                self.active_model_idx = model_idx
            
            # Move model to device
            model = model.to(self.device)
            
            # Create optimizer with memory-efficient options
            if self.offload_optimizer and hasattr(optim, 'SGD'):
                # SGD uses less memory than Adam/AdamW
                optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
            else:
                optimizer = optim.AdamW(model.parameters(), lr=lr)
            
            # Configure the scheduler
            if scheduler_type == 'OneCycleLR':
                scheduler = optim.lr_scheduler.OneCycleLR(
                    optimizer, 
                    max_lr=lr,
                    epochs=num_epochs, 
                    steps_per_epoch=len(train_loader)
                )
            else:  # ReduceLROnPlateau
                scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                    optimizer, 
                    mode='min', 
                    factor=0.5, 
                    patience=int(max_patience / 5)
                )
            
            # Calculate class weights for imbalanced classes
            all_labels = []
            for _, _, y_batch in train_loader:
                all_labels.extend(y_batch.cpu().numpy())
            
            class_weights = self._calculate_class_weights(np.array(all_labels))
            class_weights = torch.tensor(class_weights, dtype=torch.float).to(self.device)
            criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights)
            
            best_val_loss = float('inf')
            patience = max_patience
            best_model_state = None

            # Training loop
            for epoch in range(num_epochs):
                # Resample training data for balanced batches
                train_loader.dataset.re_sample()

                # Recompute class weights based on the new sampling
                all_labels = []
                for _, _, y_batch in train_loader:
                    all_labels.extend(y_batch.cpu().numpy())
                class_weights = self._calculate_class_weights(np.array(all_labels))
                class_weights = torch.tensor(class_weights, dtype=torch.float).to(self.device)
                criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights)

                # --- Training Phase ---
                model.train()
                train_loss, train_acc = 0.0, 0.0
                
                # Use mixed precision training to save memory
                scaler = torch.cuda.amp.GradScaler() if hasattr(torch.cuda, 'amp') else None
                
                for X_spc, X_ga, y_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):
                    X_spc, X_ga, y_batch = X_spc.to(self.device), X_ga.to(self.device), y_batch.to(self.device)
                    
                    # Clear memory before forward pass
                    optimizer.zero_grad(set_to_none=True)
                    
                    # Use mixed precision for forward pass if available
                    if scaler:
                        with torch.cuda.amp.autocast():
                            outputs = model(X_spc, X_ga)
                            loss = criterion(outputs, y_batch)
                        
                        # Mixed precision backward pass
                        scaler.scale(loss).backward()
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        outputs = model(X_spc, X_ga)
                        loss = criterion(outputs, y_batch)
                        loss.backward()
                        optimizer.step()

                    train_loss += loss.item() * X_spc.size(0)
                    predicted = (torch.sigmoid(outputs) > 0.5).float()
                    correct = (predicted == y_batch).float()
                    train_acc += correct.mean(dim=1).sum().item()
                    
                    # Free up memory
                    del X_spc, X_ga, y_batch, outputs, loss, predicted, correct
                    torch.cuda.empty_cache()
                
                train_loss /= len(train_loader.dataset)
                train_acc /= len(train_loader.dataset)

                # --- Validation Phase ---
                model.eval()
                val_loss, val_acc = 0.0, 0.0
                with torch.no_grad():
                    for X_spc, X_ga, y_batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation"):
                        X_spc, X_ga, y_batch = X_spc.to(self.device), X_ga.to(self.device), y_batch.to(self.device)
                        
                        # Use mixed precision for inference if available
                        if hasattr(torch.cuda, 'amp'):
                            with torch.cuda.amp.autocast():
                                outputs = model(X_spc, X_ga)
                                loss = criterion(outputs, y_batch)
                        else:
                            outputs = model(X_spc, X_ga)
                            loss = criterion(outputs, y_batch)
                        
                        val_loss += loss.item() * X_spc.size(0)
                        predicted = (torch.sigmoid(outputs) > 0.5).float()
                        correct = (predicted == y_batch).float()
                        val_acc += correct.mean(dim=1).sum().item()
                        
                        # Free up memory
                        del X_spc, X_ga, y_batch, outputs, loss, predicted, correct
                        torch.cuda.empty_cache()
                
                val_loss /= len(val_loader.dataset)
                val_acc /= len(val_loader.dataset)

                # --- Test Phase (if provided) ---
                test_metrics = {}
                if test_loader is not None:
                    test_loss, test_acc = 0.0, 0.0
                    y_true, y_pred = [], []
                    with torch.no_grad():
                        for X_spc, X_ga, y_batch in tqdm(test_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Testing"):
                            X_spc, X_ga, y_batch = X_spc.to(self.device), X_ga.to(self.device), y_batch.to(self.device)
                            
                            # Use mixed precision for inference if available
                            if hasattr(torch.cuda, 'amp'):
                                with torch.cuda.amp.autocast():
                                    outputs = model(X_spc, X_ga)
                                    loss = criterion(outputs, y_batch)
                            else:
                                outputs = model(X_spc, X_ga)
                                loss = criterion(outputs, y_batch)
                            
                            test_loss += loss.item() * X_spc.size(0)
                            
                            predicted = (torch.sigmoid(outputs) > 0.5).float()
                            correct = (predicted == y_batch).float()
                            test_acc += correct.mean(dim=1).sum().item()

                            # Collect for metrics calculation
                            y_true.extend(y_batch.cpu().numpy())
                            y_pred.extend(predicted.cpu().numpy())
                            
                            # Free up memory
                            del X_spc, X_ga, y_batch, outputs, loss, predicted, correct
                            torch.cuda.empty_cache()

                    test_loss /= len(test_loader.dataset)
                    test_acc /= len(test_loader.dataset)
                    
                    # Calculate metrics on CPU to save GPU memory
                    test_metrics = self._calculate_metrics(np.array(y_true), np.array(y_pred))
                    test_metrics.update({
                        "test_loss": test_loss,
                        "test_acc": test_acc,
                    })

                # Logging
                if log_to_wandb:
                    log_data = {
                        "epoch": epoch,
                        "train_loss": train_loss,
                        "val_loss": val_loss,
                        "train_acc": train_acc,
                        "val_acc": val_acc,
                        "lr": self._get_lr(optimizer)
                    }
                    log_data.update(test_metrics)
                    wandb.log(log_data)

                # Update learning rate scheduler
                if scheduler_type == 'OneCycleLR':
                    scheduler.step()
                else:  # ReduceLROnPlateau
                    scheduler.step(val_loss)

                # Early stopping and model saving
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    patience = max_patience
                    
                    # Save best model state (in memory or to disk)
                    if self.low_memory_mode and save_checkpoints:
                        # Save to disk
                        torch.save(
                            model.state_dict(), 
                            os.path.join(checkpoint_dir, f"model_{model_idx}_best.pth")
                        )
                    else:
                        # Save in memory
                        best_model_state = model.state_dict().copy()
                    
                    if log_to_wandb:
                        wandb.run.summary["best_val_loss"] = best_val_loss
                else:
                    patience -= 1
                    if patience <= 0:
                        print("Early stopping triggered.")
                        break

                # Periodic checkpoint saving
                if save_checkpoints and (epoch + 1) % 10 == 0:
                    torch.save(
                        model.state_dict(), 
                        os.path.join(checkpoint_dir, f"model_{model_idx}_epoch_{epoch+1}.pth")
                    )

            # Load the best model state
            if self.low_memory_mode and save_checkpoints:
                # Load from disk
                model.load_state_dict(
                    torch.load(os.path.join(checkpoint_dir, f"model_{model_idx}_best.pth"))
                )
            elif best_model_state is not None:
                # Load from memory
                model.load_state_dict(best_model_state)
            
            # Store the model state
            self.model_states[model_idx] = model.state_dict().copy()
            
            # Save final model
            if save_checkpoints:
                torch.save(
                    model.state_dict(),
                    os.path.join(checkpoint_dir, f"model_{model_idx}_final.pth")
                )
            
            # Close wandb run
            if log_to_wandb:
                wandb.finish()
            
            # Free up memory before next model
            if model_idx < self.num_models - 1:
                del model, optimizer, scheduler
                torch.cuda.empty_cache()
        
        return self.model_states
    
    def predict(self, loader, return_individual=False, batch_size=None):
        """
        Make predictions with the ensemble, one model at a time to save memory.
        
        Args:
            loader: DataLoader for the data to predict
            return_individual: Whether to return predictions from individual models
            batch_size: Optional batch size override for loader
            
        Returns:
            mean_probs: Mean probability across all models
            std_probs: Standard deviation of probabilities (uncertainty measure)
            individual_probs: (Optional) Predictions from each individual model
        """
        # Create a new dataloader with specified batch size if needed
        if batch_size is not None and batch_size != loader.batch_size:
            from torch.utils.data import DataLoader
            new_loader = DataLoader(
                loader.dataset, 
                batch_size=batch_size, 
                shuffle=False,
                num_workers=loader.num_workers
            )
            loader = new_loader
        
        # Get sample dimensions from first batch
        for X_spc, X_ga, y in loader:
            num_classes = y.shape[1]
            break
        
        # Get total number of samples
        num_samples = len(loader.dataset)
        
        # Storage for predictions
        all_probs = []
        
        # Process one model at a time
        for model_idx in range(self.num_models):
            print(f"Making predictions with model {model_idx+1}/{self.num_models}")
            
            # Load model state
            self._load_model_state(model_idx)
            self.active_model.eval()
            
            # Storage for this model's predictions
            model_probs = np.zeros((num_samples, num_classes))
            
            # Batch indices
            start_idx = 0
            
            with torch.no_grad():
                for X_spc, X_ga, _ in loader:
                    batch_size = X_spc.shape[0]
                    
                    # Move to device
                    X_spc, X_ga = X_spc.to(self.device), X_ga.to(self.device)
                    
                    # Use mixed precision if available
                    if hasattr(torch.cuda, 'amp'):
                        with torch.cuda.amp.autocast():
                            outputs = self.active_model(X_spc, X_ga)
                    else:
                        outputs = self.active_model(X_spc, X_ga)
                    
                    # Get probabilities
                    probs = torch.sigmoid(outputs).cpu().numpy()
                    
                    # Store predictions
                    end_idx = start_idx + batch_size
                    model_probs[start_idx:end_idx] = probs
                    start_idx = end_idx
                    
                    # Free up memory
                    del X_spc, X_ga, outputs, probs
                    torch.cuda.empty_cache()
            
            # Add to ensemble predictions
            all_probs.append(model_probs)
        
        # Stack along a new axis to get shape (num_models, num_samples, num_classes)
        all_probs = np.stack(all_probs, axis=0)
        
        # Calculate mean and std across models (axis=0)
        mean_probs = np.mean(all_probs, axis=0)
        std_probs = np.std(all_probs, axis=0)
        
        if return_individual:
            return mean_probs, std_probs, all_probs
        else:
            return mean_probs, std_probs
    
    def save_models(self, path_prefix):
        """
        Save all model states.
        
        Args:
            path_prefix: Path prefix for saving models
        """
        for i, state in enumerate(self.model_states):
            if state is not None:
                torch.save(state, f"{path_prefix}_model_{i}.pth")
    
    def load_models(self, path_prefix):
        """
        Load all model states.
        
        Args:
            path_prefix: Path prefix for loading models
        """
        self.model_states = []
        for i in range(self.num_models):
            state = torch.load(f"{path_prefix}_model_{i}.pth", map_location='cpu')
            self.model_states.append(state)
    
    def _calculate_class_weights(self, y):
        """Calculate class weights for handling imbalanced classes."""
        if y.ndim > 1:  
            class_counts = np.sum(y, axis=0)  
        else:
            class_counts = np.bincount(y)

        total_samples = y.shape[0] if y.ndim > 1 else len(y)
        class_counts = np.where(class_counts == 0, 1, class_counts)  # Prevent division by zero
        class_weights = total_samples / (len(class_counts) * class_counts)
        
        return class_weights
    
    def _calculate_metrics(self, y_true, y_pred):
        """Calculate evaluation metrics for multi-label classification."""
        from sklearn.metrics import f1_score, precision_score, recall_score, hamming_loss
        
        metrics = {
            "micro_f1": f1_score(y_true, y_pred, average='micro'),
            "macro_f1": f1_score(y_true, y_pred, average='macro'),
            "weighted_f1": f1_score(y_true, y_pred, average='weighted'),
            "micro_precision": precision_score(y_true, y_pred, average='micro', zero_division=1),
            "macro_precision": precision_score(y_true, y_pred, average='macro', zero_division=1),
            "weighted_precision": precision_score(y_true, y_pred, average='weighted', zero_division=1),
            "micro_recall": recall_score(y_true, y_pred, average='micro'),
            "macro_recall": recall_score(y_true, y_pred, average='macro'),
            "weighted_recall": recall_score(y_true, y_pred, average='weighted'),
            "hamming_loss": hamming_loss(y_true, y_pred)
        }
        
        return metrics
    
    def _get_lr(self, optimizer):
        """Get current learning rate from optimizer."""
        for param_group in optimizer.param_groups:
            return param_group['lr']

In [7]:
import torch
import os
import gc
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
#from memory_optimization import MemoryEfficientStarClassifier, MemoryEfficientEnsemble

# Track memory usage
def print_gpu_memory():
    if torch.cuda.is_available():
        print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
        print(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
        # Force garbage collection
        gc.collect()
        torch.cuda.empty_cache()

# Example config with higher embedding dimensions
CONFIG = {
    "d_model_spectra": 2048,  # Doubled from original 2048
    "d_model_gaia": 2048,     # Doubled from original 2048
    "num_classes": 55,
    "input_dim_spectra": 3647,
    "input_dim_gaia": 18,
    "n_layers": 12,
    "d_state": 16,           # Reduced from 256 to save memory
    "d_conv": 4,
    "expand": 2,
    "use_cross_attention": True,
    "n_cross_attn_heads": 8,
    "use_checkpoint": True,
    "activation_checkpointing": True,
    "use_half_precision": True,
    "sequential_processing": True
}

if __name__ == "__main__":
    # Set device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")
    
    # Print initial memory usage
    print("\nInitial memory usage:")
    print_gpu_memory()
    
    # Create memory-efficient model
    print("\nCreating memory-efficient model...")
    model = MemoryEfficientStarClassifier(
        d_model_spectra=CONFIG["d_model_spectra"],
        d_model_gaia=CONFIG["d_model_gaia"],
        num_classes=CONFIG["num_classes"],
        input_dim_spectra=CONFIG["input_dim_spectra"],
        input_dim_gaia=CONFIG["input_dim_gaia"],
        n_layers=CONFIG["n_layers"],
        d_state=CONFIG["d_state"],
        d_conv=CONFIG["d_conv"],
        expand=CONFIG["expand"],
        use_cross_attention=CONFIG["use_cross_attention"],
        n_cross_attn_heads=CONFIG["n_cross_attn_heads"],
        use_checkpoint=CONFIG["use_checkpoint"],
        activation_checkpointing=CONFIG["activation_checkpointing"],
        use_half_precision=CONFIG["use_half_precision"],
        sequential_processing=CONFIG["sequential_processing"]
    ).to(device)
    
    # Print model memory usage
    print("\nMemory usage after creating model:")
    print_gpu_memory()
    
    # Print model size information
    param_size = sum(p.nelement() * (2 if CONFIG["use_half_precision"] else 4) for p in model.parameters()) / (1024**2)
    print(f"\nModel parameters size: {param_size:.2f} MB")
    print(f"Number of parameters: {sum(p.nelement() for p in model.parameters()):,}")
    
    # Create ensemble
    print("\nCreating memory-efficient ensemble...")
    ensemble = MemoryEfficientEnsemble(
        model_class=MemoryEfficientStarClassifier,
        model_args=CONFIG,
        num_models=5,
        device=device,
        low_memory_mode=True,
        quantize_models=True,
        offload_optimizer=True
    )
    
    # Print ensemble memory usage
    print("\nMemory usage after creating ensemble:")
    print_gpu_memory()
    
    # Generate some random data for demonstration
    print("\nGenerating random data...")
    batch_size = 32  # Use a smaller batch size to save memory
    
    # Create smaller mock datasets for demonstration
    class MockDataset:
        def __init__(self, num_samples=1000):
            self.X_spectra = torch.randn(num_samples, CONFIG["input_dim_spectra"])
            self.X_gaia = torch.randn(num_samples, CONFIG["input_dim_gaia"])
            self.y = torch.randint(0, 2, (num_samples, CONFIG["num_classes"])).float()
            self.indices = list(range(num_samples))
        
        def __len__(self):
            return len(self.indices)
        
        def __getitem__(self, idx):
            idx = self.indices[idx]
            return self.X_spectra[idx], self.X_gaia[idx], self.y[idx]
        
        def re_sample(self):
            np.random.shuffle(self.indices)
    
    # Create mock datasets
    train_dataset = MockDataset(num_samples=1000)
    val_dataset = MockDataset(num_samples=200)
    test_dataset = MockDataset(num_samples=200)
    
    # Create data loaders with smaller batches
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    print("\nMemory usage after creating data loaders:")
    print_gpu_memory()
    
    # Run a test forward pass
    print("\nRunning test forward pass...")
    model.eval()
    with torch.no_grad():
        for X_spectra, X_gaia, _ in train_loader:
            X_spectra = X_spectra.to(device)
            X_gaia = X_gaia.to(device)
            outputs = model(X_spectra, X_gaia)
            break
    
    print("\nMemory usage after forward pass:")
    print_gpu_memory()
    
    # Clean up
    del outputs, X_spectra, X_gaia
    torch.cuda.empty_cache()
    gc.collect()
    
    print("\nMemory usage after cleanup:")
    print_gpu_memory()
    
    # Demonstrating memory-efficient training (just a couple steps)
    print("\nDemonstrating memory-efficient training (small subset)...")
    # Only train for a few steps to demonstrate
    train_subset = [next(iter(train_loader)) for _ in range(2)]
    val_subset = [next(iter(val_loader))]
    
    # Create minimal loaders
    class MinimalLoader:
        def __init__(self, data):
            self.data = data
            self.dataset = MockDataset(len(data))  # Just for compatibility
        def __iter__(self):
            return iter(self.data)
        def __len__(self):
            return len(self.data)
    
    mini_train_loader = MinimalLoader(train_subset)
    mini_val_loader = MinimalLoader(val_subset)
    
    # Train for just a couple steps
    ensemble.train(
        train_loader=mini_train_loader,
        val_loader=mini_val_loader,
        num_epochs=2,
        lr=1e-4,
        max_patience=5,
        log_to_wandb=False,
        save_checkpoints=True,
        checkpoint_dir='memory_efficient_checkpoints'
    )
    
    print("\nMemory usage after training:")
    print_gpu_memory()
    
    # Save and load models
    if not os.path.exists('memory_efficient_models'):
        os.makedirs('memory_efficient_models')
    
    print("\nSaving models...")
    ensemble.save_models("memory_efficient_models/ensemble")
    
    print("\nLoading models...")
    ensemble.load_models("memory_efficient_models/ensemble")
    
    print("\nMemory usage after loading models:")
    print_gpu_memory()
    
    # Make predictions
    print("\nMaking predictions with ensemble...")
    mean_probs, std_probs = ensemble.predict(
        mini_val_loader, 
        batch_size=8  # Use even smaller batch size for prediction
    )
    
    print("\nMemory usage after prediction:")
    print_gpu_memory()
    
    print("\nPrediction shape:", mean_probs.shape)
    print("Uncertainty shape:", std_probs.shape)
    
    print("\nDone!")

Using device: cuda

Initial memory usage:
GPU memory allocated: 0.00 MB
GPU memory reserved: 0.00 MB

Creating memory-efficient model...

Memory usage after creating model:
GPU memory allocated: 2738.05 MB
GPU memory reserved: 2770.00 MB

Model parameters size: 1369.01 MB
Number of parameters: 717,756,727

Creating memory-efficient ensemble...




AssertionError: _add_observer_ only works with cpu or single-device CUDA modules, but got devices {device(type='cuda', index=0), device(type='cpu')}

# Dashboard

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
import ipywidgets as widgets
from IPython.display import display, clear_output
import pandas as pd

class EnsembleDashboard:
    """
    Interactive dashboard for exploring ensemble predictions and uncertainty.
    """
    def __init__(self, ensemble, test_dataset, class_names, batch_size=32):
        """
        Initialize the dashboard.
        
        Args:
            ensemble: Trained DeepEnsemble instance
            test_dataset: Dataset containing test samples
            class_names: List of class names
            batch_size: Batch size for processing test samples
        """
        self.ensemble = ensemble
        self.test_dataset = test_dataset
        self.class_names = class_names
        self.batch_size = batch_size
        
        # Process the entire test set to get predictions
        self.prepare_predictions()
        
        # Create dashboard UI
        self.create_dashboard()
    
    def prepare_predictions(self):
        """Process all test samples to get predictions with uncertainty."""
        test_loader = DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)
        
        # Get predictions with uncertainty
        self.mean_probs, self.std_probs = self.ensemble.predict(test_loader)
        
        # Get true labels
        self.true_labels = []
        for _, _, y in test_loader:
            self.true_labels.append(y.numpy())
        self.true_labels = np.concatenate(self.true_labels, axis=0)
        
        # Calculate binary predictions using default threshold
        self.threshold = 0.5
        self.predictions = (self.mean_probs >= self.threshold).astype(float)
        
        # Calculate sample-level uncertainty metrics
        self.sample_uncertainty = np.mean(self.std_probs, axis=1)
        self.sample_error = np.mean(np.abs(self.predictions - self.true_labels), axis=1)
        
        # Calculate class-level metrics
        self.class_accuracy = np.mean((self.predictions == self.true_labels), axis=0)
        self.class_uncertainty = np.mean(self.std_probs, axis=0)
    
    def create_dashboard(self):
        """Create the interactive dashboard."""
        # Create sample selection slider
        self.sample_slider = widgets.IntSlider(
            value=0,
            min=0,
            max=len(self.test_dataset) - 1,
            step=1,
            description='Sample:',
            continuous_update=False
        )
        
        # Create threshold slider
        self.threshold_slider = widgets.FloatSlider(
            value=0.5,
            min=0.1,
            max=0.9,
            step=0.05,
            description='Threshold:',
            continuous_update=False
        )
        
        # Create view selection tabs
        self.tab = widgets.Tab()
        self.tab.children = [
            widgets.VBox([widgets.Label('Sample Prediction')]),
            widgets.VBox([widgets.Label('Class Performance')]),
            widgets.VBox([widgets.Label('Uncertainty Analysis')]),
            widgets.VBox([widgets.Label('Calibration')])
        ]
        self.tab.set_title(0, 'Sample View')
        self.tab.set_title(1, 'Class View')
        self.tab.set_title(2, 'Uncertainty')
        self.tab.set_title(3, 'Calibration')
        
        # Create output area
        self.output = widgets.Output()
        
        # Create update button
        self.update_button = widgets.Button(
            description='Update View',
            button_style='primary',
            tooltip='Click to update the visualization'
        )
        self.update_button.on_click(self.update_view)
        
        # Create filter for high uncertainty
        self.uncertainty_filter = widgets.Checkbox(
            value=False,
            description='Show high uncertainty samples only',
            tooltip='Filter to show only samples with high uncertainty'
        )
        
        # Create export button
        self.export_button = widgets.Button(
            description='Export Results',
            button_style='info',
            tooltip='Export predictions and uncertainty to CSV'
        )
        self.export_button.on_click(self.export_results)
        
        # Arrange widgets
        controls = widgets.VBox([
            self.sample_slider, 
            self.threshold_slider,
            self.uncertainty_filter,
            widgets.HBox([self.update_button, self.export_button])
        ])
        
        # Create dashboard layout
        dashboard = widgets.VBox([
            widgets.HTML('<h2>Ensemble Uncertainty Dashboard</h2>'),
            controls,
            self.tab,
            self.output
        ])
        
        # Display dashboard
        display(dashboard)
        
        # Initial update
        self.update_view(None)
    
    def update_view(self, _):
        """Update the visualization based on current settings."""
        with self.output:
            clear_output(wait=True)
            
            # Get current values
            sample_idx = self.sample_slider.value
            threshold = self.threshold_slider.value
            
            # Update predictions if threshold changed
            if threshold != self.threshold:
                self.threshold = threshold
                self.predictions = (self.mean_probs >= self.threshold).astype(float)
                self.sample_error = np.mean(np.abs(self.predictions - self.true_labels), axis=1)
                self.class_accuracy = np.mean((self.predictions == self.true_labels), axis=0)
            
            # Filter for high uncertainty if selected
            if self.uncertainty_filter.value:
                high_uncertainty_threshold = np.percentile(self.sample_uncertainty, 80)
                high_uncertainty_indices = np.where(self.sample_uncertainty >= high_uncertainty_threshold)[0]
                if len(high_uncertainty_indices) > 0:
                    sample_idx = high_uncertainty_indices[sample_idx % len(high_uncertainty_indices)]
                    self.sample_slider.value = sample_idx
            
            # Get current tab index
            tab_idx = self.tab.selected_index
            
            # Sample View
            if tab_idx == 0:
                # Get a single sample
                X_spectra, X_gaia, y_true = self.test_dataset[sample_idx]
                
                # Get predictions for this sample from all models
                mean_prob, std_prob, all_probs = self.ensemble.predict_sample(X_spectra, X_gaia)
                
                fig = plt.figure(figsize=(15, 10))
                
                # Create a 2x2 subplot grid
                gs = fig.add_gridspec(2, 2)
                
                # Prediction with uncertainty for top classes
                ax1 = fig.add_subplot(gs[0, :])
                self._plot_sample_prediction(ax1, mean_prob, std_prob, y_true.numpy(), threshold)
                
                # Individual model predictions
                ax2 = fig.add_subplot(gs[1, 0])
                self._plot_model_disagreement(ax2, all_probs, mean_prob, y_true.numpy(), threshold)
                
                # Sample metadata or additional info
                ax3 = fig.add_subplot(gs[1, 1])
                self._plot_sample_metadata(ax3, sample_idx, mean_prob, std_prob, y_true.numpy())
                
                plt.tight_layout()
                plt.show()
            
            # Class View
            elif tab_idx == 1:
                fig = plt.figure(figsize=(15, 10))
                
                # Create a 2x2 subplot grid
                gs = fig.add_gridspec(2, 2)
                
                # Class accuracy vs uncertainty
                ax1 = fig.add_subplot(gs[0, 0])
                self._plot_class_accuracy_vs_uncertainty(ax1)
                
                # Top/bottom classes by uncertainty
                ax2 = fig.add_subplot(gs[0, 1])
                self._plot_top_bottom_uncertain_classes(ax2)
                
                # Confusion matrix for a few selected classes
                ax3 = fig.add_subplot(gs[1, :])
                self._plot_class_prediction_distribution(ax3)
                
                plt.tight_layout()
                plt.show()
            
            # Uncertainty Analysis
            elif tab_idx == 2:
                fig = plt.figure(figsize=(15, 10))
                
                # Create a 2x2 subplot grid
                gs = fig.add_gridspec(2, 2)
                
                # Uncertainty distribution
                ax1 = fig.add_subplot(gs[0, 0])
                self._plot_uncertainty_distribution(ax1)
                
                # Uncertainty vs error
                ax2 = fig.add_subplot(gs[0, 1])
                self._plot_uncertainty_vs_error(ax2)
                
                # High uncertainty samples
                ax3 = fig.add_subplot(gs[1, 0])
                self._plot_high_uncertainty_analysis(ax3)
                
                # Low uncertainty samples
                ax4 = fig.add_subplot(gs[1, 1])
                self._plot_low_uncertainty_analysis(ax4)
                
                plt.tight_layout()
                plt.show()
            
            # Calibration
            elif tab_idx == 3:
                fig = plt.figure(figsize=(15, 10))
                
                # Create a 2x2 subplot grid
                gs = fig.add_gridspec(2, 2)
                
                # Calibration curve
                ax1 = fig.add_subplot(gs[0, :])
                self._plot_calibration_curve(ax1)
                
                # Reliability diagram
                ax2 = fig.add_subplot(gs[1, 0])
                self._plot_reliability_diagram(ax2)
                
                # Calibration by uncertainty level
                ax3 = fig.add_subplot(gs[1, 1])
                self._plot_calibration_by_uncertainty(ax3)
                
                plt.tight_layout()
                plt.show()
    
    def _plot_sample_prediction(self, ax, mean_prob, std_prob, true_label, threshold):
        """Plot prediction with uncertainty bars for a sample."""
        # Get indices of top k predictions by mean probability
        top_k = 10
        top_indices = np.argsort(mean_prob)[::-1][:top_k]
        
        # Extract top k values
        top_means = mean_prob[top_indices]
        top_stds = std_prob[top_indices]
        top_classes = [self.class_names[i] for i in top_indices]
        
        # Create horizontal bar chart with error bars
        y_pos = np.arange(len(top_classes))
        bar_colors = ['green' if prob >= threshold else 'red' for prob in top_means]
        
        # Plot bars
        ax.barh(y_pos, top_means, xerr=top_stds, align='center', 
               alpha=0.7, color=bar_colors, capsize=5)
        
        # Add true labels markers
        true_indices = np.where(true_label == 1)[0]
        for i, idx in enumerate(top_indices):
            if idx in true_indices:
                ax.get_children()[i].set_edgecolor('blue')
                ax.get_children()[i].set_linewidth(2)
        
        # Add threshold line
        ax.axvline(x=threshold, color='gray', linestyle='--', alpha=0.7)
        
        # Customize plot
        ax.set_yticks(y_pos)
        ax.set_yticklabels(top_classes)
        ax.invert_yaxis()  # labels read top-to-bottom
        ax.set_xlabel('Probability')
        ax.set_title('Prediction with Uncertainty')
        
        # Add gridlines
        ax.grid(axis='x', linestyle='--', alpha=0.7)
        
        # Add legend
        from matplotlib.patches import Patch
        legend_elements = [
            Patch(facecolor='green', edgecolor='black', label='Above threshold'),
            Patch(facecolor='red', edgecolor='black', label='Below threshold'),
            Patch(facecolor='white', edgecolor='blue', linewidth=2, label='True label')
        ]
        
        ax.legend(handles=legend_elements, loc='lower right')
        
        # Add text annotations for probabilities
        for i, (mean, std) in enumerate(zip(top_means, top_stds)):
            ax.text(mean + 0.02, i, f'{mean:.2f} ± {std:.2f}', va='center')
    
    def _plot_model_disagreement(self, ax, all_probs, mean_prob, true_label, threshold):
        """Plot individual model predictions to show disagreement."""
        # Get top classes
        top_k = 5
        top_indices = np.argsort(mean_prob)[::-1][:top_k]
        top_classes = [self.class_names[i] for i in top_indices]
        
        # Prepare data for plotting
        model_indices = np.arange(all_probs.shape[0])
        
        # Create a grouped bar chart
        bar_width = 0.8 / top_k
        for i, class_idx in enumerate(top_indices):
            probs = all_probs[:, class_idx]
            pos = model_indices - 0.4 + i * bar_width
            bars = ax.bar(pos, probs, bar_width, alpha=0.7,
                        label=f"{top_classes[i]}")
            
            # Highlight true labels
            if class_idx in np.where(true_label == 1)[0]:
                for bar in bars:
                    bar.set_edgecolor('blue')
                    bar.set_linewidth(1.5)
        
        # Add threshold line
        ax.axhline(y=threshold, color='black', linestyle='--', alpha=0.5)
        
        # Customize plot
        ax.set_ylabel('Probability')
        ax.set_xlabel('Model')
        ax.set_title('Individual Model Predictions')
        ax.set_xticks(model_indices)
        ax.set_xticklabels([f'Model {i+1}' for i in model_indices])
        ax.legend(loc='upper right')
        ax.grid(axis='y', linestyle='--', alpha=0.7)
    
    def _plot_sample_metadata(self, ax, sample_idx, mean_prob, std_prob, true_label):
        """Plot metadata and summary statistics for the selected sample."""
        # Get uncertainty and error metrics
        uncertainty = np.mean(std_prob)
        
        # Predicted classes (above threshold)
        threshold = self.threshold
        pred_indices = np.where(mean_prob >= threshold)[0]
        pred_classes = [self.class_names[i] for i in pred_indices]
        
        # True classes
        true_indices = np.where(true_label == 1)[0]
        true_classes = [self.class_names[i] for i in true_indices]
        
        # Calculate metrics
        if len(true_indices) > 0:
            # True positives, false positives, false negatives
            tp = np.sum((mean_prob >= threshold) & (true_label == 1))
            fp = np.sum((mean_prob >= threshold) & (true_label == 0))
            fn = np.sum((mean_prob < threshold) & (true_label == 1))
            
            # Precision, recall, F1
            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
            f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
            
            # Accuracy
            accuracy = np.mean((mean_prob >= threshold).astype(int) == true_label)
        else:
            precision, recall, f1, accuracy = 0, 0, 0, 0
        
        # Clear the axis for text display
        ax.axis('off')
        
        # Create a text summary
        text = f"Sample #{sample_idx}\n\n"
        text += f"Average Uncertainty: {uncertainty:.3f}\n"
        text += f"Decision Threshold: {threshold:.2f}\n\n"
        
        text += f"Performance Metrics:\n"
        text += f"Accuracy: {accuracy:.3f}\n"
        text += f"Precision: {precision:.3f}\n"
        text += f"Recall: {recall:.3f}\n"
        text += f"F1 Score: {f1:.3f}\n\n"
        
        text += f"Predicted Classes ({len(pred_classes)}):\n"
        text += ", ".join(pred_classes[:5])
        if len(pred_classes) > 5:
            text += f" ... ({len(pred_classes) - 5} more)"
        text += "\n\n"
        
        text += f"True Classes ({len(true_classes)}):\n"
        text += ", ".join(true_classes[:5])
        if len(true_classes) > 5:
            text += f" ... ({len(true_classes) - 5} more)"
        
        # Display the text
        ax.text(0, 1, text, fontsize=10, va='top', linespacing=1.5)
        ax.set_title('Sample Information')
    
    def _plot_class_accuracy_vs_uncertainty(self, ax):
        """Plot class accuracy vs. uncertainty scatter plot."""
        # Get class-level accuracy and uncertainty
        class_accuracy = np.mean((self.predictions == self.true_labels), axis=0)
        class_uncertainty = np.mean(self.std_probs, axis=0)
        
        # Create scatter plot
        scatter = ax.scatter(class_uncertainty, class_accuracy, alpha=0.7, 
                          s=50, c=np.sum(self.true_labels, axis=0), cmap='viridis')
        
        # Add correlation line
        from scipy.stats import linregress
        slope, intercept, r_value, _, _ = linregress(class_uncertainty, class_accuracy)
        x = np.linspace(min(class_uncertainty), max(class_uncertainty), 100)
        y = slope * x + intercept
        ax.plot(x, y, 'r--', alpha=0.7)
        
        # Add correlation text
        ax.text(0.05, 0.05, f'r = {r_value:.3f}', transform=ax.transAxes,
              bbox=dict(facecolor='white', alpha=0.7))
        
        # Add class labels for outliers (most uncertain or least accurate)
        outliers = np.argsort(class_uncertainty)[-5:]  # Top 5 most uncertain
        for i in outliers:
            ax.annotate(self.class_names[i], 
                      (class_uncertainty[i], class_accuracy[i]),
                      xytext=(5, 5), textcoords='offset points')
        
        # Add colorbar to indicate class frequency
        cbar = plt.colorbar(scatter, ax=ax)
        cbar.set_label('Class Frequency')
        
        # Customize plot
        ax.set_xlabel('Average Uncertainty')
        ax.set_ylabel('Class Accuracy')
        ax.set_title('Class Accuracy vs. Uncertainty')
        ax.grid(alpha=0.3)
    
    def _plot_top_bottom_uncertain_classes(self, ax):
        """Plot top and bottom classes by uncertainty."""
        # Sort classes by uncertainty
        class_uncertainty = np.mean(self.std_probs, axis=0)
        sorted_indices = np.argsort(class_uncertainty)
        
        # Get top and bottom 5 classes
        bottom_indices = sorted_indices[:5]  # Least uncertain
        top_indices = sorted_indices[-5:][::-1]  # Most uncertain
        
        # Combine and prepare for plotting
        all_indices = np.concatenate([top_indices, bottom_indices])
        all_classes = [self.class_names[i] for i in all_indices]
        all_uncertainties = class_uncertainty[all_indices]
        all_accuracies = self.class_accuracy[all_indices]
        
        # Create colors for bars
        colors = ['red'] * 5 + ['green'] * 5
        
        # Plot horizontal bars for uncertainty
        y_pos = np.arange(len(all_classes))
        bars = ax.barh(y_pos, all_uncertainties, align='center', color=colors, alpha=0.7)
        
        # Add accuracy as text
        for i, acc in enumerate(all_accuracies):
            ax.text(all_uncertainties[i] + 0.005, i, f'Acc: {acc:.2f}', va='center')
        
        # Customize plot
        ax.set_yticks(y_pos)
        ax.set_yticklabels(all_classes)
        ax.invert_yaxis()  # Labels read top-to-bottom
        ax.set_xlabel('Average Uncertainty')
        ax.set_title('Most and Least Uncertain Classes')
        ax.grid(axis='x', alpha=0.3)
        
        # Add legend
        from matplotlib.patches import Patch
        legend_elements = [
            Patch(facecolor='red', label='High Uncertainty'),
            Patch(facecolor='green', label='Low Uncertainty')
        ]
        ax.legend(handles=legend_elements, loc='lower right')
    
    def _plot_class_prediction_distribution(self, ax):
        """Plot distribution of predictions by class."""
        # Select a few representative classes
        class_uncertainty = np.mean(self.std_probs, axis=0)
        
        # Pick classes with varying uncertainty levels
        sorted_by_uncertainty = np.argsort(class_uncertainty)
        num_classes = len(sorted_by_uncertainty)
        
        # Choose 5 classes spread across the uncertainty spectrum
        indices = [
            sorted_by_uncertainty[0],  # Lowest uncertainty
            sorted_by_uncertainty[num_classes // 4],
            sorted_by_uncertainty[num_classes // 2],
            sorted_by_uncertainty[3 * num_classes // 4],
            sorted_by_uncertainty[-1]  # Highest uncertainty
        ]
        
        selected_classes = [self.class_names[i] for i in indices]
        
        # For each selected class, compute the distribution of predicted probabilities
        # separated for true positive and true negative samples
        data = []
        for i, class_idx in enumerate(indices):
            # Get positive and negative samples for this class
            pos_mask = self.true_labels[:, class_idx] == 1
            neg_mask = ~pos_mask
            
            if np.sum(pos_mask) > 0:
                # Positive samples
                pos_probs = self.mean_probs[pos_mask, class_idx]
                pos_uncertainties = self.std_probs[pos_mask, class_idx]
                
                # Create violin plot data
                pos_data = {
                    'Class': selected_classes[i],
                    'Type': 'True Positive',
                    'Probability': pos_probs,
                    'Uncertainty': pos_uncertainties
                }
                data.append(pos_data)
            
            if np.sum(neg_mask) > 0:
                # Negative samples
                neg_probs = self.mean_probs[neg_mask, class_idx]
                neg_uncertainties = self.std_probs[neg_mask, class_idx]
                
                # Create violin plot data
                neg_data = {
                    'Class': selected_classes[i],
                    'Type': 'True Negative',
                    'Probability': neg_probs,
                    'Uncertainty': neg_uncertainties
                }
                data.append(neg_data)
        
        # Create violin plots
        for i, d in enumerate(data):
            # Position on x-axis
            x_pos = i
            
            # Create violin plot
            violin_parts = ax.violinplot(d['Probability'], positions=[x_pos], widths=0.8,
                                       showmeans=False, showmedians=True, showextrema=True)
            
            # Color by type (TP or TN)
            color = 'blue' if d['Type'] == 'True Positive' else 'red'
            for pc in violin_parts['bodies']:
                pc.set_facecolor(color)
                pc.set_alpha(0.7)
        
        # Add threshold line
        ax.axhline(y=self.threshold, color='black', linestyle='--', alpha=0.5)
        
        # Customize plot
        ax.set_xticks(range(len(data)))
        ax.set_xticklabels([f"{d['Class']}\n({d['Type']})" for d in data], rotation=45, ha='right')
        ax.set_ylabel('Predicted Probability')
        ax.set_title('Distribution of Predicted Probabilities by Class')
        ax.grid(axis='y', alpha=0.3)
        
        # Add uncertainty annotations
        for i, d in enumerate(data):
            mean_uncertainty = np.mean(d['Uncertainty'])
            ax.text(i, 0.05, f'Unc: {mean_uncertainty:.3f}', ha='center',
                  bbox=dict(facecolor='white', alpha=0.7))
    
    def _plot_uncertainty_distribution(self, ax):
        """Plot the distribution of uncertainty values."""
        # Plot histogram of sample-level uncertainty
        ax.hist(self.sample_uncertainty, bins=30, alpha=0.7, color='blue')
        
        # Mark vertical line for median
        median_uncertainty = np.median(self.sample_uncertainty)
        ax.axvline(x=median_uncertainty, color='red', linestyle='--', alpha=0.7)
        ax.text(median_uncertainty + 0.001, 0.9 * ax.get_ylim()[1], 
              f'Median: {median_uncertainty:.3f}', color='red')
        
        # Customize plot
        ax.set_xlabel('Sample Uncertainty (mean std across classes)')
        ax.set_ylabel('Count')
        ax.set_title('Distribution of Uncertainty Values')
        ax.grid(alpha=0.3)
    
    def _plot_uncertainty_vs_error(self, ax):
        """Plot uncertainty vs. error scatter plot."""
        # Create scatter plot
        scatter = ax.scatter(self.sample_uncertainty, self.sample_error, 
                          alpha=0.3, s=10, c=np.sum(self.true_labels, axis=1), cmap='viridis')
        
        # Add correlation line
        from scipy.stats import linregress
        slope, intercept, r_value, _, _ = linregress(self.sample_uncertainty, self.sample_error)
        x = np.linspace(min(self.sample_uncertainty), max(self.sample_uncertainty), 100)
        y = slope * x + intercept
        ax.plot(x, y, 'r--', alpha=0.7)
        
        # Add correlation text
        ax.text(0.05, 0.95, f'r = {r_value:.3f}', transform=ax.transAxes,
              bbox=dict(facecolor='white', alpha=0.7))
        
        # Customize plot
        ax.set_xlabel('Uncertainty')
        ax.set_ylabel('Error Rate')
        ax.set_title('Uncertainty vs. Error Correlation')
        ax.grid(alpha=0.3)
        
        # Add colorbar
        cbar = plt.colorbar(scatter, ax=ax)
        cbar.set_label('Number of True Classes')
    
    def _plot_high_uncertainty_analysis(self, ax):
        """Analyze predictions with high uncertainty."""
        # Get high uncertainty threshold (e.g., top 10%)
        high_threshold = np.percentile(self.sample_uncertainty, 90)
        high_mask = self.sample_uncertainty >= high_threshold
        
        # Count of samples with high uncertainty
        num_high = np.sum(high_mask)
        
        # Average error for high uncertainty samples
        high_error = np.mean(self.sample_error[high_mask]) if num_high > 0 else 0
        
        # Average error for other samples
        other_error = np.mean(self.sample_error[~high_mask]) if np.sum(~high_mask) > 0 else 0
        
        # Error ratio (high uncertainty vs other)
        error_ratio = high_error / other_error if other_error > 0 else 0
        
        # Distribution of high uncertainty predictions by class
        high_uncertain_class_counts = np.sum(self.predictions[high_mask], axis=0)
        
        # Get top classes with high uncertainty predictions
        top_indices = np.argsort(high_uncertain_class_counts)[::-1][:10]
        top_classes = [self.class_names[i] for i in top_indices]
        top_counts = high_uncertain_class_counts[top_indices]
        
        # Plot bar chart of top classes with high uncertainty
        y_pos = np.arange(len(top_classes))
        bars = ax.barh(y_pos, top_counts, align='center', alpha=0.7)
        
        # Add text annotations for error rates by class
        for i, idx in enumerate(top_indices):
            # Calculate error rate for this class in high uncertainty samples
            class_preds = self.predictions[high_mask, idx]
            class_true = self.true_labels[high_mask, idx]
            class_error = np.mean(np.abs(class_preds - class_true)) if len(class_preds) > 0 else 0
            
            # Add text annotation
            ax.text(top_counts[i] + 0.5, i, f'Error: {class_error:.2f}', va='center')
        
        # Customize plot
        ax.set_yticks(y_pos)
        ax.set_yticklabels(top_classes)
        ax.invert_yaxis()  # Labels read top-to-bottom
        ax.set_xlabel('Count')
        
        # Add summary text at the top
        title = (f'High Uncertainty Samples (>{high_threshold:.3f})\n'
                f'Count: {num_high} ({num_high/len(self.sample_uncertainty):.1%})\n'
                f'Error: {high_error:.3f} vs {other_error:.3f} ({error_ratio:.1f}x higher)')
        ax.set_title(title)
        ax.grid(axis='x', alpha=0.3)
    
    def _plot_low_uncertainty_analysis(self, ax):
        """Analyze predictions with low uncertainty."""
        # Get low uncertainty threshold (e.g., bottom 10%)
        low_threshold = np.percentile(self.sample_uncertainty, 10)
        low_mask = self.sample_uncertainty <= low_threshold
        
        # Count of samples with low uncertainty
        num_low = np.sum(low_mask)
        
        # Average error for low uncertainty samples
        low_error = np.mean(self.sample_error[low_mask]) if num_low > 0 else 0
        
        # Average error for other samples
        other_error = np.mean(self.sample_error[~low_mask]) if np.sum(~low_mask) > 0 else 0
        
        # Error ratio (low uncertainty vs other)
        error_ratio = low_error / other_error if other_error > 0 else 0
        
        # Distribution of low uncertainty predictions by class
        low_uncertain_class_counts = np.sum(self.predictions[low_mask], axis=0)
        
        # Get top classes with low uncertainty predictions
        top_indices = np.argsort(low_uncertain_class_counts)[::-1][:10]
        top_classes = [self.class_names[i] for i in top_indices]
        top_counts = low_uncertain_class_counts[top_indices]
        
        # Plot bar chart of top classes with low uncertainty
        y_pos = np.arange(len(top_classes))
        bars = ax.barh(y_pos, top_counts, align='center', alpha=0.7, color='green')
        
        # Add text annotations for error rates by class
        for i, idx in enumerate(top_indices):
            # Calculate error rate for this class in low uncertainty samples
            class_preds = self.predictions[low_mask, idx]
            class_true = self.true_labels[low_mask, idx]
            class_error = np.mean(np.abs(class_preds - class_true)) if len(class_preds) > 0 else 0
            
            # Add text annotation
            ax.text(top_counts[i] + 0.5, i, f'Error: {class_error:.2f}', va='center')
        
        # Customize plot
        ax.set_yticks(y_pos)
        ax.set_yticklabels(top_classes)
        ax.invert_yaxis()  # Labels read top-to-bottom
        ax.set_xlabel('Count')
        
        # Add summary text at the top
        title = (f'Low Uncertainty Samples (<{low_threshold:.3f})\n'
                f'Count: {num_low} ({num_low/len(self.sample_uncertainty):.1%})\n'
                f'Error: {low_error:.3f} vs {other_error:.3f} ({error_ratio:.1f}x lower)')
        ax.set_title(title)
        ax.grid(axis='x', alpha=0.3)
    
    def _plot_calibration_curve(self, ax):
        """Plot calibration curve to assess the quality of predicted probabilities."""
        # Flatten arrays
        mean_probs_flat = self.mean_probs.flatten()
        std_probs_flat = self.std_probs.flatten()
        true_labels_flat = self.true_labels.flatten()
        
        # Remove NaN values if any
        valid_indices = ~np.isnan(mean_probs_flat) & ~np.isnan(true_labels_flat)
        mean_probs_flat = mean_probs_flat[valid_indices]
        std_probs_flat = std_probs_flat[valid_indices]
        true_labels_flat = true_labels_flat[valid_indices]
        
        # Create bins and calculate calibration metrics
        n_bins = 10
        bins = np.linspace(0.0, 1.0, n_bins + 1)
        binned = np.digitize(mean_probs_flat, bins) - 1
        bin_accs = np.zeros(n_bins)
        bin_confs = np.zeros(n_bins)
        bin_sizes = np.zeros(n_bins)
        bin_uncerts = np.zeros(n_bins)
        
        for bin_idx in range(n_bins):
            bin_mask = binned == bin_idx
            if np.sum(bin_mask) > 0:
                bin_sizes[bin_idx] = np.sum(bin_mask)
                bin_accs[bin_idx] = np.mean(true_labels_flat[bin_mask])
                bin_confs[bin_idx] = np.mean(mean_probs_flat[bin_mask])
                bin_uncerts[bin_idx] = np.mean(std_probs_flat[bin_mask])
        
        # Calculate the expected calibration error (ECE)
        ece = np.sum(bin_sizes / len(mean_probs_flat) * np.abs(bin_accs - bin_confs))
        
        # Plot the calibration curve
        ax.plot([0, 1], [0, 1], linestyle='--', color='gray', label='Perfectly calibrated')
        
        # Color points by uncertainty
        sc = ax.scatter(bin_confs, bin_accs, 
                      s=bin_sizes / np.sum(bin_sizes) * 2000,  # Size proportional to bin size
                      c=bin_uncerts, cmap='viridis', alpha=0.8, 
                      linewidths=1, edgecolors='black')
        
        # Add colorbar for uncertainty
        cbar = plt.colorbar(sc, ax=ax)
        cbar.set_label('Mean uncertainty (std)')
        
        # Customize plot
        ax.set_xlabel('Mean predicted probability')
        ax.set_ylabel('Fraction of positives (accuracy)')
        ax.set_xlim([0, 1])
        ax.set_ylim([0, 1])
        ax.set_title(f'Calibration Curve (ECE: {ece:.3f})')
        ax.grid(linestyle='--', alpha=0.7)
        ax.legend(loc='lower right')
    
    def _plot_reliability_diagram(self, ax):
        """Plot reliability diagram showing calibration for each bin."""
        # Flatten arrays
        mean_probs_flat = self.mean_probs.flatten()
        true_labels_flat = self.true_labels.flatten()
        
        # Remove NaN values if any
        valid_indices = ~np.isnan(mean_probs_flat) & ~np.isnan(true_labels_flat)
        mean_probs_flat = mean_probs_flat[valid_indices]
        true_labels_flat = true_labels_flat[valid_indices]
        
        # Create bins and calculate calibration metrics
        n_bins = 10
        bins = np.linspace(0.0, 1.0, n_bins + 1)
        binned = np.digitize(mean_probs_flat, bins) - 1
        bin_accs = np.zeros(n_bins)
        bin_confs = np.zeros(n_bins)
        bin_sizes = np.zeros(n_bins)
        
        for bin_idx in range(n_bins):
            bin_mask = binned == bin_idx
            if np.sum(bin_mask) > 0:
                bin_sizes[bin_idx] = np.sum(bin_mask)
                bin_accs[bin_idx] = np.mean(true_labels_flat[bin_mask])
                bin_confs[bin_idx] = np.mean(mean_probs_flat[bin_mask])
        
        # Create the reliability diagram
        width = 1.0 / n_bins
        bin_edges = np.linspace(0, 1, n_bins + 1)
        bin_centers = bin_edges[:-1] + width / 2
        
        # Plot the observed frequency in each bin
        ax.bar(bin_centers, bin_accs, width=width, alpha=0.7, color='blue', 
             label='Observed frequency')
        
        # Plot the gap between observed and predicted
        gaps = bin_accs - bin_confs
        ax.bar(bin_centers, gaps, bottom=bin_confs, width=width, alpha=0.7, 
             color='red' if np.sum(gaps < 0) > np.sum(gaps >= 0) else 'green', 
             label='Gap')
        
        # Add perfect calibration line
        ax.plot([0, 1], [0, 1], linestyle='--', color='gray', label='Perfect calibration')
        
        # Customize plot
        ax.set_xlabel('Predicted probability')
        ax.set_ylabel('Observed frequency')
        ax.set_xlim([0, 1])
        ax.set_ylim([0, 1])
        ax.set_title('Reliability Diagram')
        ax.grid(alpha=0.3)
        ax.legend(loc='upper left')
        
        # Add bin size annotations
        for i, (center, size) in enumerate(zip(bin_centers, bin_sizes)):
            ax.text(center, 0.05, f'{int(size)}', ha='center', 
                  bbox=dict(facecolor='white', alpha=0.7))
    
    def _plot_calibration_by_uncertainty(self, ax):
        """Plot calibration for different uncertainty levels."""
        # Sort samples by uncertainty
        sorted_indices = np.argsort(self.sample_uncertainty)
        n_samples = len(sorted_indices)
        
        # Create bins with equal number of samples
        n_bins = 5
        bin_size = n_samples // n_bins
        
        # Calculate calibration metrics for each bin
        bin_eces = []
        bin_uncerts = []
        
        for i in range(n_bins):
            # Get indices for this bin
            start_idx = i * bin_size
            end_idx = (i + 1) * bin_size if i < n_bins - 1 else n_samples
            bin_indices = sorted_indices[start_idx:end_idx]
            
            # Get mean uncertainty for this bin
            bin_uncert = np.mean(self.sample_uncertainty[bin_indices])
            bin_uncerts.append(bin_uncert)
            
            # Calculate ECE for this bin
            bin_probs = self.mean_probs[bin_indices].flatten()
            bin_labels = self.true_labels[bin_indices].flatten()
            
            # Create sub-bins for ECE calculation
            n_sub_bins = 10
            sub_bins = np.linspace(0.0, 1.0, n_sub_bins + 1)
            binned = np.digitize(bin_probs, sub_bins) - 1
            
            # Calculate ECE
            ece = 0
            for sub_bin_idx in range(n_sub_bins):
                sub_bin_mask = binned == sub_bin_idx
                if np.sum(sub_bin_mask) > 0:
                    sub_bin_size = np.sum(sub_bin_mask)
                    sub_bin_acc = np.mean(bin_labels[sub_bin_mask])
                    sub_bin_conf = np.mean(bin_probs[sub_bin_mask])
                    ece += (sub_bin_size / len(bin_probs)) * np.abs(sub_bin_acc - sub_bin_conf)
            
            bin_eces.append(ece)
        
        # Plot bar chart of ECE by uncertainty bin
        bars = ax.bar(range(n_bins), bin_eces, alpha=0.7)
        
        # Add text annotations for mean uncertainty
        for i, (ece, uncert) in enumerate(zip(bin_eces, bin_uncerts)):
            ax.text(i, ece + 0.005, f'Unc: {uncert:.3f}', ha='center')
        
        # Customize plot
        ax.set_xticks(range(n_bins))
        ax.set_xticklabels([f'Bin {i+1}' for i in range(n_bins)])
        ax.set_ylabel('Expected Calibration Error (ECE)')
        ax.set_title('Calibration Error by Uncertainty Level')
        ax.grid(axis='y', alpha=0.3)
    
    def export_results(self, _):
        """Export prediction results and uncertainty metrics to CSV."""
        with self.output:
            clear_output(wait=True)
            
            # Gather results
            results = {
                'sample_idx': np.arange(len(self.mean_probs)),
                'uncertainty': self.sample_uncertainty,
                'error_rate': self.sample_error
            }
            
            # Add binary predictions for each class
            for i, class_name in enumerate(self.class_names):
                results[f'pred_{class_name}'] = self.predictions[:, i]
                results[f'true_{class_name}'] = self.true_labels[:, i]
                results[f'prob_{class_name}'] = self.mean_probs[:, i]
                results[f'uncert_{class_name}'] = self.std_probs[:, i]
            
            # Create DataFrame
            df = pd.DataFrame(results)
            
            # Save to CSV
            filename = 'ensemble_predictions.csv'
            df.to_csv(filename, index=False)
            
            print(f"Results exported to {filename}")
            print(f"DataFrame shape: {df.shape}")
            print("\nFirst few rows:")
            display(df.head())


# Example usage
if __name__ == "__main__":
    # Load saved ensemble
    ensemble = DeepEnsemble(
        model_class=StarClassifierFusion,
        model_args={
            "d_model_spectra": 2048,
            "d_model_gaia": 2048,
            "num_classes": 55,
            "input_dim_spectra": 3647,
            "input_dim_gaia": 18,
            "n_layers": 12,
            "use_cross_attention": True,
            "n_cross_attn_heads": 8
        },
        num_models=5,
        device='cuda'
    )
    
    # Load saved models
    ensemble.load_models("ensemble_mamba_v1")
    
    # Initialize class names (replace with actual names)
    class_names = [f"Class_{i}" for i in range(55)]
    
    # Create dashboard with test dataset
    dashboard = EnsembleDashboard(
        ensemble=ensemble,
        test_dataset=test_dataset,
        class_names=class_names
    