In [3]:
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score, precision_score, recall_score, hamming_loss, roc_auc_score, average_precision_score
import os
import json
import gc
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from tqdm import tqdm
import pickle
from torch.utils.data import Dataset
from sklearn.metrics import roc_auc_score


# Import your model architectures
from Fusion_Models import StarClassifierFusionMambaOut, StarClassifierFusionTransformer, StarClassifierFusionMambaTokenized

class MultiModalBalancedMultiLabelDataset(Dataset):
    """
    A balanced multi-label dataset that returns (X_spectra, X_gaia, y).
    It uses the same balancing strategy as `BalancedMultiLabelDataset`.
    """
    def __init__(self, X_spectra, X_gaia, y, limit_per_label=201):
        """
        Args:
            X_spectra (torch.Tensor): [num_samples, num_spectra_features]
            X_gaia (torch.Tensor): [num_samples, num_gaia_features]
            y (torch.Tensor): [num_samples, num_classes], multi-hot labels
            limit_per_label (int): limit or target number of samples per label
        """
        self.X_spectra = X_spectra
        self.X_gaia = X_gaia
        self.y = y
        self.limit_per_label = limit_per_label
        self.num_classes = y.shape[1]
        self.indices = self.balance_classes()
        
    def balance_classes(self):
        indices = []
        class_counts = torch.sum(self.y, axis=0)
        for cls in range(self.num_classes):
            cls_indices = np.where(self.y[:, cls] == 1)[0]
            if len(cls_indices) < self.limit_per_label:
                if len(cls_indices) == 0:
                    # No samples for this class
                    continue
                extra_indices = np.random.choice(
                    cls_indices, self.limit_per_label - len(cls_indices), replace=True
                )
                cls_indices = np.concatenate([cls_indices, extra_indices])
            elif len(cls_indices) > self.limit_per_label:
                cls_indices = np.random.choice(cls_indices, self.limit_per_label, replace=False)
            indices.extend(cls_indices)
        indices = np.unique(indices)
        np.random.shuffle(indices)
        return indices

    def re_sample(self):
        self.indices = self.balance_classes()

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        index = self.indices[idx]
        return (
            self.X_spectra[index],  # spectra features
            self.X_gaia[index],     # gaia features
            self.y[index],          # multi-hot labels
        )
    
def calculate_class_weights(y):
    if y.ndim > 1:  
        class_counts = np.sum(y, axis=0)  
    else:
        class_counts = np.bincount(y)

    total_samples = y.shape[0] if y.ndim > 1 else len(y)
    class_counts = np.where(class_counts == 0, 1, class_counts)  # Prevent division by zero
    class_weights = total_samples / (len(class_counts) * class_counts)
    
    return class_weights

def calculate_metrics(y_true, y_pred):
    metrics = {
        "micro_f1": f1_score(y_true, y_pred, average='micro'),
        "macro_f1": f1_score(y_true, y_pred, average='macro'),
        "weighted_f1": f1_score(y_true, y_pred, average='weighted'),
        "micro_precision": precision_score(y_true, y_pred, average='micro', zero_division=1),
        "macro_precision": precision_score(y_true, y_pred, average='macro', zero_division=1),
        "weighted_precision": precision_score(y_true, y_pred, average='weighted', zero_division=1),
        "micro_recall": recall_score(y_true, y_pred, average='micro'),
        "macro_recall": recall_score(y_true, y_pred, average='macro'),
        "weighted_recall": recall_score(y_true, y_pred, average='weighted'),
        "hamming_loss": hamming_loss(y_true, y_pred)
    }
    
    # Check if there are at least two classes present in y_true
    #if len(np.unique(y_true)) > 1:
        #metrics["roc_auc"] = roc_auc_score(y_true, y_pred, average='macro', multi_class='ovr')
    #else:
       # metrics["roc_auc"] = None  # or you can set it to a default value or message
    
    return metrics



def load_data():
    """Load and preprocess the data"""
    print("Loading datasets...")
    
    # Load classes
    with open("Pickles/Updated_List_of_Classes_ubuntu.pkl", "rb") as f:
        classes = pickle.load(f)
    
    # Load test data
    with open("Pickles/test_data_transformed_ubuntu.pkl", "rb") as f:
        X_test_full = pickle.load(f)
    
    # Extract labels
    y_test = X_test_full[classes]
    
    # Drop labels from both datasets
    X_test_full.drop(classes, axis=1, inplace=True)
    
    # Define Gaia columns
    gaia_columns = ["parallax", "ra", "dec", "ra_error", "dec_error", "parallax_error", 
                   "pmra", "pmdec", "pmra_error", "pmdec_error", "phot_g_mean_flux", 
                   "flagnopllx", "phot_g_mean_flux_error", "phot_bp_mean_flux", 
                   "phot_rp_mean_flux", "phot_bp_mean_flux_error", 
                   "phot_rp_mean_flux_error", "flagnoflux"]
    
    # Split data into spectra and gaia parts
    X_test_spectra = X_test_full.drop(columns={"otype", "obsid", *gaia_columns})
    
    X_test_gaia = X_test_full[gaia_columns]
    
    # Free up memory
    del X_test_full
    gc.collect()
    
    # Convert to PyTorch tensors
    X_test_spectra_tensor = torch.tensor(X_test_spectra.values, dtype=torch.float32)

    X_test_gaia_tensor = torch.tensor(X_test_gaia.values, dtype=torch.float32)
    
    y_test_tensor = torch.tensor(y_test.values, dtype=torch.float32)
    
    return (X_test_spectra_tensor, X_test_gaia_tensor, y_test_tensor)

def evaluate_model(model, test_loader, device='cuda'):
    """Evaluate a model on test data and return comprehensive metrics"""
    model.eval()
    test_loss = 0.0
    test_acc = 0.0
    y_true, y_pred, y_prob = [], [], []
    
    # Compute class weights for loss function
    all_labels = []
    for _, _, y_batch in test_loader:
        all_labels.extend(y_batch.cpu().numpy())
    
    class_weights = calculate_class_weights(np.array(all_labels))
    class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights)
    
    # Evaluation loop
    with torch.no_grad():
        for X_spc, X_ga, y_batch in tqdm(test_loader, desc="Evaluating"):
            X_spc, X_ga, y_batch = X_spc.to(device), X_ga.to(device), y_batch.to(device)
            outputs = model(X_spc, X_ga)
            loss = criterion(outputs, y_batch)
            test_loss += loss.item() * X_spc.size(0)
            
            probs = torch.sigmoid(outputs)
            predicted = (probs > 0.5).float()
            correct = (predicted == y_batch).float()
            test_acc += correct.mean(dim=1).mean().item()

            y_true.extend(y_batch.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())
            y_prob.extend(probs.cpu().numpy())
    
    # Convert to numpy arrays
    y_true_array = np.array(y_true)
    y_pred_array = np.array(y_pred)
    y_prob_array = np.array(y_prob)
    
    # Calculate metrics
    metrics = calculate_metrics(y_true_array, y_pred_array)
    
    # Add average metrics
    metrics["avg_loss"] = test_loss / len(test_loader.dataset)
    metrics["avg_accuracy"] = test_acc / len(test_loader)
    
    # Calculate AUROC if possible
    try:
        class_aurocs = []
        for i in range(y_true_array.shape[1]):
            if len(np.unique(y_true_array[:, i])) > 1:
                class_auroc = roc_auc_score(y_true_array[:, i], y_prob_array[:, i])
                class_aurocs.append(class_auroc)
        
        if class_aurocs:
            metrics["macro_auroc"] = np.mean(class_aurocs)
    except Exception as e:
        print(f"Error calculating AUROC: {e}")
        metrics["macro_auroc"] = float('nan')
    
    return metrics

def main():
    # Create results directory
    results_dir = f"model_comparison_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    os.makedirs(results_dir, exist_ok=True)
    os.makedirs("Models", exist_ok=True)
    
    # Set device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")
    
    # Load data
    (X_test_spectra, X_test_gaia, y_test) = load_data()
    
    # Create datasets and dataloaders
    batch_size = 16
    batch_limit = int(batch_size / 2.5)
    

    test_dataset = MultiModalBalancedMultiLabelDataset(
        X_test_spectra, X_test_gaia, y_test, limit_per_label=batch_limit
    )
    

    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    

    print(f"Test samples: {len(test_dataset)}")
    # Define model configurations to evaluate
    model_configs = [
        # MambaOut Models
        {
            "name": "MambaOut_1token",
            "model_class": StarClassifierFusionMambaOut,
            "params": {
                "d_model_spectra": 2048,
                "d_model_gaia": 2048,
                "num_classes": 55,
                "input_dim_spectra": 3647,
                "input_dim_gaia": 18,
                "token_dim_spectra": 3647,  # 1 token
                "token_dim_gaia": 18,       # 1 token
                "n_layers": 20,
                "d_conv": 1,
                "use_cross_attention": True,
                "n_cross_attn_heads": 8
            },
            "checkpoint": "Comparing_Mambas_Trans/gated_cnn_(mambaout)_1_token.pth"
        },
        {
            "name": "MambaOut_19_18token",
            "model_class": StarClassifierFusionMambaOut,
            "params": {
                "d_model_spectra": 2048,
                "d_model_gaia": 2048,
                "num_classes": 55,
                "input_dim_spectra": 3647,
                "input_dim_gaia": 18,
                "token_dim_spectra": 192,  # ~19 tokens
                "token_dim_gaia": 1,       # 18 tokens
                "n_layers": 20,
                "d_conv": 4,
                "use_cross_attention": True,
                "n_cross_attn_heads": 8
            },
            "checkpoint": "Comparing_Mambas_Trans/gated_cnn_(mambaout)_balanced.pth"
        },
        {
            "name": "MambaOut_522_18token",
            "model_class": StarClassifierFusionMambaOut,
            "params": {
                "d_model_spectra": 1536,
                "d_model_gaia": 1536,
                "num_classes": 55,
                "input_dim_spectra": 3647,
                "input_dim_gaia": 18,
                "token_dim_spectra": 7,    # ~522 tokens
                "token_dim_gaia": 1,       # 18 tokens
                "n_layers": 20,
                "d_conv": 32,
                "use_cross_attention": True,
                "n_cross_attn_heads": 8
            },
            "checkpoint": "Comparing_Mambas_Trans/gated_cnn_(mambaout)_max_tokens.pth"
        },
        
        # Transformer Models
        {
            "name": "Transformer_1token",
            "model_class": StarClassifierFusionTransformer,
            "params": {
                "d_model_spectra": 2048,
                "d_model_gaia": 2048,
                "num_classes": 55,
                "input_dim_spectra": 3647,
                "input_dim_gaia": 18,
                "token_dim_spectra": 3647,  # 1 token
                "token_dim_gaia": 18,       # 1 token
                "n_layers": 10,
                "n_heads": 8,
                "use_cross_attention": True,
                "n_cross_attn_heads": 8,
                "dropout": 0.1
            },
            "checkpoint": "Comparing_Mambas_Trans/transformer_1_token.pth"
        },
        {
            "name": "Transformer_19_18token",
            "model_class": StarClassifierFusionTransformer,
            "params": {
                "d_model_spectra": 2048,
                "d_model_gaia": 2048,
                "num_classes": 55,
                "input_dim_spectra": 3647,
                "input_dim_gaia": 18,
                "token_dim_spectra": 192,  # ~19 tokens
                "token_dim_gaia": 1,       # 18 tokens
                "n_layers": 10,
                "n_heads": 8,
                "use_cross_attention": True,
                "n_cross_attn_heads": 8,
                "dropout": 0.1
            },
            "checkpoint": "Comparing_Mambas_Trans/transformer_balanced.pth"
        },
        {
            "name": "Transformer_522_18token",
            "model_class": StarClassifierFusionTransformer,
            "params": {
                "d_model_spectra": 1536,
                "d_model_gaia": 1536,
                "num_classes": 55,
                "input_dim_spectra": 3647,
                "input_dim_gaia": 18,
                "token_dim_spectra": 7,    # ~522 tokens
                "token_dim_gaia": 1,       # 18 tokens
                "n_layers": 10,
                "n_heads": 8,
                "use_cross_attention": True,
                "n_cross_attn_heads": 8,
                "dropout": 0.1
            },
            "checkpoint": "Comparing_Mambas_Trans/transformer_max_tokens.pth"
        },
        
        # Mamba2 Tokenized Models
        {
            "name": "Mamba2_1token",
            "model_class": StarClassifierFusionMambaTokenized,
            "params": {
                "d_model_spectra": 2048,
                "d_model_gaia": 2048,
                "num_classes": 55,
                "input_dim_spectra": 3647,
                "input_dim_gaia": 18,
                "token_dim_spectra": 3647,  # 1 token
                "token_dim_gaia": 18,       # 1 token
                "n_layers": 20,
                "d_state": 32,
                "d_conv": 2,
                "expand": 2,
                "use_cross_attention": True,
                "n_cross_attn_heads": 8
            },
            "checkpoint": "Comparing_Mambas_Trans/mamba_1_token.pth"
        },
        {
            "name": "Mamba2_19_18token",
            "model_class": StarClassifierFusionMambaTokenized,
            "params": {
                "d_model_spectra": 2048,
                "d_model_gaia": 2048,
                "num_classes": 55,
                "input_dim_spectra": 3647,
                "input_dim_gaia": 18,
                "token_dim_spectra": 192,  # ~19 tokens
                "token_dim_gaia": 1,       # 18 tokens
                "n_layers": 20,
                "d_state": 32,
                "d_conv": 4,
                "expand": 2,
                "use_cross_attention": True,
                "n_cross_attn_heads": 8
            },
            "checkpoint": "Comparing_Mambas_Trans/mamba_balanced.pth"
        },
        {
            "name": "Mamba2_522_18token",
            "model_class": StarClassifierFusionMambaTokenized,
            "params": {
                "d_model_spectra": 1536,
                "d_model_gaia": 1536,
                "num_classes": 55,
                "input_dim_spectra": 3647,
                "input_dim_gaia": 18,
                "token_dim_spectra": 7,    # ~522 tokens
                "token_dim_gaia": 1,       # 18 tokens
                "n_layers": 20,
                "d_state": 16,
                "d_conv": 4,
                "expand": 2,
                "use_cross_attention": True,
                "n_cross_attn_heads": 8
            },
            "checkpoint": "Comparing_Mambas_Trans/mamba_max_tokens.pth"
        }
    ]
    
    # Store results
    results = {}
    
    # Evaluate each model
    for config in model_configs:
        print(f"\n{'='*50}")
        print(f"Evaluating model: {config['name']}")
        print(f"{'='*50}")
        
        # Create model instance
        model = config["model_class"](**config["params"])
        
        # Load checkpoint if exists
        checkpoint_path = config["checkpoint"]
        if os.path.exists(checkpoint_path):
            print(f"Loading checkpoint from {checkpoint_path}")
            model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        else:
            print(f"Checkpoint {checkpoint_path} not found. Skipping this model.")
            continue
        
        # Move model to device
        model = model.to(device)
        
        # Print model statistics
        num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Number of parameters: {num_params:,}")
        
        # Calculate model size in MB
        param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
        buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())
        size_mb = (param_size + buffer_size) / (1024**2)
        print(f"Model size: {size_mb:.2f} MB")
        
        # Evaluate model
        metrics = evaluate_model(model, test_loader, device)
        
        # Add model info to metrics
        metrics["model_name"] = config["name"]
        metrics["num_parameters"] = num_params
        metrics["model_size_mb"] = size_mb
        
        # Get token counts for analysis
        spectra_tokens = (config["params"]["input_dim_spectra"] + config["params"]["token_dim_spectra"] - 1) // config["params"]["token_dim_spectra"]
        gaia_tokens = (config["params"]["input_dim_gaia"] + config["params"]["token_dim_gaia"] - 1) // config["params"]["token_dim_gaia"]
        metrics["spectra_tokens"] = spectra_tokens
        metrics["gaia_tokens"] = gaia_tokens
        metrics["total_tokens"] = spectra_tokens + gaia_tokens
        
        # Print key metrics
        print("\nTest Metrics:")
        print(f"  Loss: {metrics['avg_loss']:.4f}")
        print(f"  Accuracy: {metrics['avg_accuracy']:.4f}")
        print(f"  Micro F1: {metrics['micro_f1']:.4f}")
        print(f"  Macro F1: {metrics['macro_f1']:.4f}")
        print(f"  Weighted F1: {metrics['weighted_f1']:.4f}")
        print(f"  Macro AUROC: {metrics.get('macro_auroc', 'N/A')}")
        
        # Store results
        results[config["name"]] = metrics
        
        # Clear memory
        del model
        torch.cuda.empty_cache()
        gc.collect()
    
    # Save results to JSON
    results_file = os.path.join(results_dir, "model_comparison_results.json")
    with open(results_file, 'w') as f:
        json.dump(results, f, indent=2)
    
    # Convert to DataFrame for easier analysis
    results_df = pd.DataFrame.from_dict(results, orient='index')
    
    # Save DataFrame to CSV
    csv_file = os.path.join(results_dir, "model_comparison_results.csv")
    results_df.to_csv(csv_file)
    
    # Create model family summary
    model_families = {
        "MambaOut": [m for m in results_df.index if m.startswith("MambaOut")],
        "Transformer": [m for m in results_df.index if m.startswith("Transformer")],
        "Mamba2": [m for m in results_df.index if m.startswith("Mamba2")]
    }
    
    family_results = {}
    for family, models in model_families.items():
        if models:
            family_results[family] = results_df.loc[models].mean()
    
    family_df = pd.DataFrame.from_dict(family_results, orient='index')
    family_csv = os.path.join(results_dir, "model_family_summary.csv")
    family_df.to_csv(family_csv)
    
    # Create comparative visualizations
    
    # 1. Performance by model metrics bar chart
    key_metrics = ['micro_f1', 'macro_f1', 'weighted_f1', 'macro_auroc']
    plt.figure(figsize=(15, 10))
    
    for i, metric in enumerate(key_metrics):
        if metric in results_df.columns:
            plt.subplot(2, 2, i+1)
            sns.barplot(x=results_df.index, y=results_df[metric])
            plt.title(f"{metric.replace('_', ' ').title()}")
            plt.xticks(rotation=45, ha='right')
            plt.ylim(0, 1)
    
    plt.tight_layout()
    plt.savefig(os.path.join(results_dir, "key_metrics_comparison.png"))
    
    # 2. Performance vs Model Size scatter plot
    plt.figure(figsize=(12, 8))
    
    for family, models in model_families.items():
        if models:
            sns.scatterplot(
                x=results_df.loc[models, 'model_size_mb'], 
                y=results_df.loc[models, 'macro_f1'],
                label=family,
                s=100
            )
    
    for i, model in enumerate(results_df.index):
        plt.annotate(
            model,
            (results_df.loc[model, 'model_size_mb'], results_df.loc[model, 'macro_f1']),
            xytext=(5, 5),
            textcoords='offset points'
        )
    
    plt.xlabel("Model Size (MB)")
    plt.ylabel("Macro F1 Score")
    plt.title("Model Performance vs Model Size")
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.savefig(os.path.join(results_dir, "performance_vs_size.png"))
    
    # 3. Performance vs Number of Tokens scatter plot
    plt.figure(figsize=(12, 8))
    
    for family, models in model_families.items():
        if models:
            sns.scatterplot(
                x=results_df.loc[models, 'total_tokens'], 
                y=results_df.loc[models, 'macro_f1'],
                label=family,
                s=100
            )
    
    for i, model in enumerate(results_df.index):
        plt.annotate(
            model,
            (results_df.loc[model, 'total_tokens'], results_df.loc[model, 'macro_f1']),
            xytext=(5, 5),
            textcoords='offset points'
        )
    
    plt.xscale('log')
    plt.xlabel("Total Number of Tokens (log scale)")
    plt.ylabel("Macro F1 Score")
    plt.title("Model Performance vs Number of Tokens")
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.savefig(os.path.join(results_dir, "performance_vs_tokens.png"))
    
    print(f"\nResults saved to {results_dir}/")
    print(f"Summary: {family_csv}")
    
    # Generate a text summary report
    with open(os.path.join(results_dir, "summary_report.txt"), 'w') as f:
        f.write("MODEL EVALUATION SUMMARY\n")
        f.write("======================\n\n")
        
        f.write("Best Models by Metric:\n")
        for metric in ['micro_f1', 'macro_f1', 'weighted_f1', 'macro_auroc']:
            if metric in results_df.columns:
                best_model = results_df[metric].idxmax()
                f.write(f"  Best {metric}: {best_model} ({results_df.loc[best_model, metric]:.4f})\n")
        
        f.write("\nModel Family Comparison:\n")
        for family, metrics in family_results.items():
            f.write(f"  {family}:\n")
            f.write(f"    Macro F1: {metrics['macro_f1']:.4f}\n")
            f.write(f"    Average Size: {metrics['model_size_mb']:.2f} MB\n")
        
        f.write("\nToken Configuration Analysis:\n")
        token_configs = ["1token", "19_18token", "522_18token"]
        for config in token_configs:
            models = [m for m in results_df.index if config in m]
            if models:
                config_df = results_df.loc[models]
                f.write(f"  {config}:\n")
                f.write(f"    Average Macro F1: {config_df['macro_f1'].mean():.4f}\n")
                f.write(f"    Best Model: {config_df['macro_f1'].idxmax()} ({config_df['macro_f1'].max():.4f})\n")
        
        f.write("\nDetailed Model Rankings:\n")
        for rank, (model, metrics) in enumerate(results_df.sort_values('macro_f1', ascending=False).iterrows(), 1):
            f.write(f"  {rank}. {model}:\n")
            f.write(f"     Macro F1: {metrics['macro_f1']:.4f}\n")
            f.write(f"     Size: {metrics['model_size_mb']:.2f} MB\n")
            f.write(f"     Parameters: {metrics['num_parameters']:,}\n")
            f.write(f"     Token Config: {metrics['spectra_tokens']} spectra, {metrics['gaia_tokens']} gaia\n")
    
    print(f"Summary report generated: {results_dir}/summary_report.txt")

if __name__ == "__main__":
    main()

Using device: cuda
Loading datasets...
Test samples: 259

Evaluating model: MambaOut_1token
Loading checkpoint from Comparing_Mambas_Trans/gated_cnn_(mambaout)_1_token.pth
Number of parameters: 1,048,856,631
Model size: 4001.07 MB


Evaluating: 100%|██████████| 17/17 [00:00<00:00, 103.42it/s]


Test Metrics:
  Loss: 0.0674
  Accuracy: 0.9771
  Micro F1: 0.6099
  Macro F1: 0.4359
  Weighted F1: 0.5772
  Macro AUROC: 0.9298620191392433






Evaluating model: MambaOut_19_18token
Loading checkpoint from Comparing_Mambas_Trans/gated_cnn_(mambaout)_balanced.pth
Number of parameters: 1,042,237,495
Model size: 3975.82 MB


Evaluating: 100%|██████████| 17/17 [00:00<00:00, 20.29it/s]



Test Metrics:
  Loss: 0.0751
  Accuracy: 0.9735
  Micro F1: 0.5152
  Macro F1: 0.3707
  Weighted F1: 0.4852
  Macro AUROC: 0.9012070076937412

Evaluating model: MambaOut_522_18token
Loading checkpoint from Comparing_Mambas_Trans/gated_cnn_(mambaout)_max_tokens.pth
Number of parameters: 589,799,479
Model size: 2249.91 MB


Evaluating: 100%|██████████| 17/17 [00:06<00:00,  2.73it/s]



Test Metrics:
  Loss: 0.0612
  Accuracy: 0.9719
  Micro F1: 0.4371
  Macro F1: 0.3131
  Weighted F1: 0.3990
  Macro AUROC: 0.9174140493747273

Evaluating model: Transformer_1token
Checkpoint Comparing_Mambas_Trans/transformer_1_token.pth not found. Skipping this model.

Evaluating model: Transformer_19_18token
Loading checkpoint from Comparing_Mambas_Trans/transformer_balanced.pth
Number of parameters: 1,041,377,335
Model size: 4052.55 MB


Evaluating: 100%|██████████| 17/17 [00:00<00:00, 19.44it/s]



Test Metrics:
  Loss: 0.0760
  Accuracy: 0.9743
  Micro F1: 0.5782
  Macro F1: 0.4444
  Weighted F1: 0.5574
  Macro AUROC: 0.8785354495338072

Evaluating model: Transformer_522_18token
Checkpoint Comparing_Mambas_Trans/transformer_max_tokens.pth not found. Skipping this model.

Evaluating model: Mamba2_1token
Checkpoint Comparing_Mambas_Trans/mamba_1_token.pth not found. Skipping this model.

Evaluating model: Mamba2_19_18token
Checkpoint Comparing_Mambas_Trans/mamba_balanced.pth not found. Skipping this model.

Evaluating model: Mamba2_522_18token
Loading checkpoint from Comparing_Mambas_Trans/mamba_max_tokens.pth


RuntimeError: Error(s) in loading state_dict for StarClassifierFusionMambaTokenized:
	size mismatch for mamba_spectra.0.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_spectra.0.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_spectra.0.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_spectra.1.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_spectra.1.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_spectra.1.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_spectra.2.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_spectra.2.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_spectra.2.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_spectra.3.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_spectra.3.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_spectra.3.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_spectra.4.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_spectra.4.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_spectra.4.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_spectra.5.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_spectra.5.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_spectra.5.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_spectra.6.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_spectra.6.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_spectra.6.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_spectra.7.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_spectra.7.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_spectra.7.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_spectra.8.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_spectra.8.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_spectra.8.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_spectra.9.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_spectra.9.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_spectra.9.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_spectra.10.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_spectra.10.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_spectra.10.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_spectra.11.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_spectra.11.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_spectra.11.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_spectra.12.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_spectra.12.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_spectra.12.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_spectra.13.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_spectra.13.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_spectra.13.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_spectra.14.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_spectra.14.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_spectra.14.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_spectra.15.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_spectra.15.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_spectra.15.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_spectra.16.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_spectra.16.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_spectra.16.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_spectra.17.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_spectra.17.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_spectra.17.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_spectra.18.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_spectra.18.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_spectra.18.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_spectra.19.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_spectra.19.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_spectra.19.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_gaia.0.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_gaia.0.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_gaia.0.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_gaia.1.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_gaia.1.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_gaia.1.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_gaia.2.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_gaia.2.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_gaia.2.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_gaia.3.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_gaia.3.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_gaia.3.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_gaia.4.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_gaia.4.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_gaia.4.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_gaia.5.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_gaia.5.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_gaia.5.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_gaia.6.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_gaia.6.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_gaia.6.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_gaia.7.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_gaia.7.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_gaia.7.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_gaia.8.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_gaia.8.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_gaia.8.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_gaia.9.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_gaia.9.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_gaia.9.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_gaia.10.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_gaia.10.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_gaia.10.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_gaia.11.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_gaia.11.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_gaia.11.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_gaia.12.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_gaia.12.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_gaia.12.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_gaia.13.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_gaia.13.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_gaia.13.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_gaia.14.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_gaia.14.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_gaia.14.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_gaia.15.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_gaia.15.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_gaia.15.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_gaia.16.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_gaia.16.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_gaia.16.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_gaia.17.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_gaia.17.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_gaia.17.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_gaia.18.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_gaia.18.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_gaia.18.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).
	size mismatch for mamba_gaia.19.in_proj.weight: copying a param with shape torch.Size([6704, 1536]) from checkpoint, the shape in current model is torch.Size([6224, 1536]).
	size mismatch for mamba_gaia.19.conv1d.weight: copying a param with shape torch.Size([3584, 1, 4]) from checkpoint, the shape in current model is torch.Size([3104, 1, 4]).
	size mismatch for mamba_gaia.19.conv1d.bias: copying a param with shape torch.Size([3584]) from checkpoint, the shape in current model is torch.Size([3104]).