In [1]:
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import gc
import wandb
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
import gc
from sklearn.metrics import f1_score, precision_score, recall_score, hamming_loss, roc_auc_score
#from mamba_ssm import Mamba2
import matplotlib.pyplot as plt
from functools import partial

# General Helper Functions

In [2]:
class MultiModalBalancedMultiLabelDataset(Dataset):
    """
    A balanced multi-label dataset that returns (X_spectra, X_gaia, y).
    It uses the same balancing strategy as `BalancedMultiLabelDataset`.
    """
    def __init__(self, X_spectra, X_gaia, y, limit_per_label=201):
        """
        Args:
            X_spectra (torch.Tensor): [num_samples, num_spectra_features]
            X_gaia (torch.Tensor): [num_samples, num_gaia_features]
            y (torch.Tensor): [num_samples, num_classes], multi-hot labels
            limit_per_label (int): limit or target number of samples per label
        """
        self.X_spectra = X_spectra
        self.X_gaia = X_gaia
        self.y = y
        self.limit_per_label = limit_per_label
        self.num_classes = y.shape[1]
        self.indices = self.balance_classes()
        
    def balance_classes(self):
        indices = []
        class_counts = torch.sum(self.y, axis=0)
        for cls in range(self.num_classes):
            cls_indices = np.where(self.y[:, cls] == 1)[0]
            if len(cls_indices) < self.limit_per_label:
                if len(cls_indices) == 0:
                    # No samples for this class
                    continue
                extra_indices = np.random.choice(
                    cls_indices, self.limit_per_label - len(cls_indices), replace=True
                )
                cls_indices = np.concatenate([cls_indices, extra_indices])
            elif len(cls_indices) > self.limit_per_label:
                cls_indices = np.random.choice(cls_indices, self.limit_per_label, replace=False)
            indices.extend(cls_indices)
        indices = np.unique(indices)
        np.random.shuffle(indices)
        return indices

    def re_sample(self):
        self.indices = self.balance_classes()

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        index = self.indices[idx]
        return (
            self.X_spectra[index],  # spectra features
            self.X_gaia[index],     # gaia features
            self.y[index],          # multi-hot labels
        )
    
def calculate_class_weights(y):
    if y.ndim > 1:  
        class_counts = np.sum(y, axis=0)  
    else:
        class_counts = np.bincount(y)

    total_samples = y.shape[0] if y.ndim > 1 else len(y)
    class_counts = np.where(class_counts == 0, 1, class_counts)  # Prevent division by zero
    class_weights = total_samples / (len(class_counts) * class_counts)
    
    return class_weights

def calculate_metrics(y_true, y_pred):
    metrics = {
        "micro_f1": f1_score(y_true, y_pred, average='micro'),
        "macro_f1": f1_score(y_true, y_pred, average='macro'),
        "weighted_f1": f1_score(y_true, y_pred, average='weighted'),
        "micro_precision": precision_score(y_true, y_pred, average='micro', zero_division=1),
        "macro_precision": precision_score(y_true, y_pred, average='macro', zero_division=1),
        "weighted_precision": precision_score(y_true, y_pred, average='weighted', zero_division=1),
        "micro_recall": recall_score(y_true, y_pred, average='micro'),
        "macro_recall": recall_score(y_true, y_pred, average='macro'),
        "weighted_recall": recall_score(y_true, y_pred, average='weighted'),
        "hamming_loss": hamming_loss(y_true, y_pred)
    }
    
    # Check if there are at least two classes present in y_true
    #if len(np.unique(y_true)) > 1:
        #metrics["roc_auc"] = roc_auc_score(y_true, y_pred, average='macro', multi_class='ovr')
    #else:
       # metrics["roc_auc"] = None  # or you can set it to a default value or message
    
    return metrics

class CrossAttentionBlock(nn.Module):
    """
    A simple cross-attention block with a feed-forward sub-layer.
    """
    def __init__(self, d_model, n_heads=8):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=d_model, 
            num_heads=n_heads, 
            batch_first=True
        )
        self.norm1 = nn.LayerNorm(d_model)
        
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model)
        )
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x_q, x_kv):
        """
        Args:
            x_q  : (batch_size, seq_len_q, d_model)
            x_kv : (batch_size, seq_len_kv, d_model)
        """
        # Cross-attention
        attn_output, _ = self.cross_attn(query=x_q, key=x_kv, value=x_kv)
        x = self.norm1(x_q + attn_output)

        # Feed forward
        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)
        
        return x

class FeatureTokenizer(nn.Module):
    """
    Splits input features into tokens of a specified dimension.
    """
    def __init__(self, input_dim, token_dim, d_model):
        """
        Args:
            input_dim: Dimension of the input features
            token_dim: Dimension of each token
            d_model: Model dimension that each token will be embedded to
        """
        super().__init__()
        self.input_dim = input_dim
        self.token_dim = token_dim
        
        # Calculate number of tokens based on input dimension and token dimension
        self.num_tokens = (input_dim + token_dim - 1) // token_dim  # Ceiling division
        
        # Padding to ensure input_dim is divisible by token_dim
        self.padded_dim = self.num_tokens * token_dim
        
        # Linear projection to embed each token to d_model
        self.token_embed = nn.Linear(token_dim, d_model)
        
    def forward(self, x):
        """
        Args:
            x: Input tensor of shape [batch_size, input_dim]
            
        Returns:
            Tokenized tensor of shape [batch_size, num_tokens, d_model]
        """
        batch_size = x.shape[0]
        
        # Pad input if needed
        if self.input_dim < self.padded_dim:
            padding = torch.zeros(batch_size, self.padded_dim - self.input_dim, 
                                 dtype=x.dtype, device=x.device)
            x = torch.cat([x, padding], dim=1)
        
        # Reshape into tokens
        x = x.view(batch_size, self.num_tokens, self.token_dim)
        
        # Embed each token to d_model
        x = self.token_embed(x)
        
        return x

In [3]:
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

def train_model_fusion(
    model,
    train_loader,
    val_loader,
    test_loader,
    num_epochs=100,
    lr=2.5e-3,
    max_patience=20,
    device='cuda'
):
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)


    total_step_counter = 0 # Initialize step counter to know when to stop updating the schedulers
    max_total_steps = num_epochs * len(train_loader)  # Initial estimation of total steps
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=lr, total_steps=max_total_steps
    )


    # Compute class weights based on the training set
    all_labels = []
    for _, _, y_batch in train_loader:
        all_labels.extend(y_batch.cpu().numpy())
    
    class_weights = calculate_class_weights(np.array(all_labels))
    class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights)
    
    best_val_loss = float('inf')
    patience = max_patience

    for epoch in range(num_epochs):
        # Resample training data
        train_loader.dataset.re_sample()

        # Update max_total_steps if needed
        current_max_steps = (epoch + 1) * len(train_loader) + (num_epochs - epoch - 1) * len(train_loader)
        if current_max_steps != max_total_steps:
            #print(f"Adjusting max steps: {max_total_steps} -> {current_max_steps}")
            max_total_steps = current_max_steps

        # Recompute class weights if needed
        all_labels = []
        for _, _, y_batch in train_loader:
            all_labels.extend(y_batch.cpu().numpy())
        class_weights = calculate_class_weights(np.array(all_labels))
        class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
        criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights)

        # --- Training ---
        model.train()
        train_loss, train_acc = 0.0, 0.0
        for X_spc, X_ga, y_batch in train_loader:
            X_spc, X_ga, y_batch = X_spc.to(device), X_ga.to(device), y_batch.to(device)
            optimizer.zero_grad()
            outputs = model(X_spc, X_ga)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * X_spc.size(0)
            predicted = (torch.sigmoid(outputs) > 0.5).float()
            correct = (predicted == y_batch).float()
            train_acc += correct.mean(dim=1).mean().item()

            # Only step if we haven't reached max steps
            if total_step_counter < max_total_steps:
                scheduler.step()
                total_step_counter += 1
            else:
                print(f"Reached maximum steps ({max_total_steps}), skipping scheduler step")

        # --- Validation ---
        model.eval()
        val_loss, val_acc = 0.0, 0.0
        with torch.no_grad():
            for X_spc, X_ga, y_batch in val_loader:
                X_spc, X_ga, y_batch = X_spc.to(device), X_ga.to(device), y_batch.to(device)
                outputs = model(X_spc, X_ga)
                loss = criterion(outputs, y_batch)
                val_loss += loss.item() * X_spc.size(0)
                predicted = (torch.sigmoid(outputs) > 0.5).float()
                correct = (predicted == y_batch).float()
                val_acc += correct.mean(dim=1).mean().item()

        # --- Test metrics (optional or do after training) ---
        test_loss, test_acc = 0.0, 0.0
        y_true, y_pred = [], []
        with torch.no_grad():
            for X_spc, X_ga, y_batch in test_loader:
                X_spc, X_ga, y_batch = X_spc.to(device), X_ga.to(device), y_batch.to(device)
                outputs = model(X_spc, X_ga)
                loss = criterion(outputs, y_batch)
                test_loss += loss.item() * X_spc.size(0)
                
                predicted = (torch.sigmoid(outputs) > 0.5).float()
                correct = (predicted == y_batch).float()
                test_acc += correct.mean(dim=1).mean().item()

                y_true.extend(y_batch.cpu().numpy())
                y_pred.extend(predicted.cpu().numpy())

        # Compute multi-label metrics as before
        all_metrics = calculate_metrics(np.array(y_true), np.array(y_pred))
        
        # Logging example
        wandb.log({
            "epoch": epoch,
            "train_loss": train_loss / len(train_loader.dataset),
            "val_loss": val_loss / len(val_loader.dataset),
            "train_acc": train_acc / len(train_loader),
            "val_acc": val_acc / len(val_loader),
            "test_loss": test_loss / len(test_loader.dataset),
            "test_acc": test_acc / len(test_loader),
            "lr": get_lr(optimizer),
            **all_metrics
        })

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience = max_patience
            best_model = model.state_dict()
        else:
            patience -= 1
            if patience <= 0:
                print("Early stopping triggered.")
                break

    model.load_state_dict(best_model)
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for X_spc, X_ga, y_batch in test_loader:
            X_spc, X_ga, y_batch = X_spc.to(device), X_ga.to(device), y_batch.to(device)
            outputs = model(X_spc, X_ga)
            predicted = (torch.sigmoid(outputs) > 0.5).float()
            y_true.extend(y_batch.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())
    return model


# Initializing Batch

In [4]:
batch_size = 16
batch_limit = int(batch_size / 2.5)

# Load datasets
#X_train_full = pd.read_pickle("Pickles/train_data_transformed2.pkl")
#X_test_full = pd.read_pickle("Pickles/test_data_transformed.pkl")
# classes = pd.read_pickle("Pickles/Updated_list_of_Classes.pkl")
import pickle
# Open them in a cross-platform way
with open("Pickles/Updated_List_of_Classes_ubuntu.pkl", "rb") as f:
    classes = pickle.load(f)  # This reads the actual data
with open("Pickles/train_data_transformed_ubuntu.pkl", "rb") as f:
    X_train_full = pickle.load(f)
with open("Pickles/test_data_transformed_ubuntu.pkl", "rb") as f:
    X_test_full = pickle.load(f)




# Extract labels
y_train_full = X_train_full[classes]
y_test = X_test_full[classes]

# Drop labels from both datasets
X_train_full.drop(classes, axis=1, inplace=True)
X_test_full.drop(classes, axis=1, inplace=True)


# Columns for spectral data (assuming all remaining columns after removing Gaia are spectra)
gaia_columns = ["parallax", "ra", "dec", "ra_error", "dec_error", "parallax_error", "pmra", "pmdec", 
                "pmra_error", "pmdec_error", "phot_g_mean_flux", "flagnopllx", "phot_g_mean_flux_error", 
                "phot_bp_mean_flux", "phot_rp_mean_flux", "phot_bp_mean_flux_error", "phot_rp_mean_flux_error", 
                "flagnoflux"]

# Spectra data (everything that is not Gaia-related) and the column 'otype'
X_train_spectra = X_train_full.drop(columns={"otype", "obsid", *gaia_columns})
X_test_spectra = X_test_full.drop(columns={"otype", "obsid", *gaia_columns})

# Gaia data (only the selected columns)
X_train_gaia = X_train_full[gaia_columns]
X_test_gaia = X_test_full[gaia_columns]

# Count nans and infs in x_train_gaia
print(X_train_gaia.isnull().sum())
print(X_train_gaia.isin([np.inf, -np.inf]).sum())


# Free up memory
del X_train_full, X_test_full
gc.collect()



# Split training set into training and validation
X_train_spectra, X_val_spectra, X_train_gaia, X_val_gaia, y_train, y_val = train_test_split(
    X_train_spectra, X_train_gaia, y_train_full, test_size=0.2, random_state=42
)

# Free memory
del y_train_full
gc.collect()



# Convert spectra and Gaia data into PyTorch tensors
X_train_spectra = torch.tensor(X_train_spectra.values, dtype=torch.float32)
X_val_spectra = torch.tensor(X_val_spectra.values, dtype=torch.float32)
X_test_spectra = torch.tensor(X_test_spectra.values, dtype=torch.float32)



X_train_gaia = torch.tensor(X_train_gaia.values, dtype=torch.float32)
X_val_gaia = torch.tensor(X_val_gaia.values, dtype=torch.float32)
X_test_gaia = torch.tensor(X_test_gaia.values, dtype=torch.float32)

y_train = torch.tensor(y_train.values, dtype=torch.float32)
y_val = torch.tensor(y_val.values, dtype=torch.float32)
y_test = torch.tensor(y_test.values, dtype=torch.float32)

# Print dataset shapes
print(f"X_train_spectra shape: {X_train_spectra.shape}")
print(f"X_val_spectra shape: {X_val_spectra.shape}")
print(f"X_test_spectra shape: {X_test_spectra.shape}")

print(f"X_train_gaia shape: {X_train_gaia.shape}")
print(f"X_val_gaia shape: {X_val_gaia.shape}")
print(f"X_test_gaia shape: {X_test_gaia.shape}")
train_test_split
print(f"y_train shape: {y_train.shape}")
print(f"y_val shape: {y_val.shape}")
print(f"y_test shape: {y_test.shape}")


train_dataset = MultiModalBalancedMultiLabelDataset(X_train_spectra, X_train_gaia, y_train, limit_per_label=batch_limit)
val_dataset = MultiModalBalancedMultiLabelDataset(X_val_spectra, X_val_gaia, y_val, limit_per_label=batch_limit)
test_dataset = MultiModalBalancedMultiLabelDataset(X_test_spectra, X_test_gaia, y_test, limit_per_label=batch_limit)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# print the number of samples in each dataset
print(f"Train dataset: {len(train_dataset)} samples")
print(f"Validation dataset: {len(val_dataset)} samples")
print(f"Test dataset: {len(test_dataset)} samples")


parallax                   0
ra                         0
dec                        0
ra_error                   0
dec_error                  0
parallax_error             0
pmra                       0
pmdec                      0
pmra_error                 0
pmdec_error                0
phot_g_mean_flux           0
flagnopllx                 0
phot_g_mean_flux_error     0
phot_bp_mean_flux          0
phot_rp_mean_flux          0
phot_bp_mean_flux_error    0
phot_rp_mean_flux_error    0
flagnoflux                 0
dtype: int64
parallax                   0
ra                         0
dec                        0
ra_error                   0
dec_error                  0
parallax_error             0
pmra                       0
pmdec                      0
pmra_error                 0
pmdec_error                0
phot_g_mean_flux           0
flagnopllx                 0
phot_g_mean_flux_error     0
phot_bp_mean_flux          0
phot_rp_mean_flux          0
phot_bp_mean_flux_error    0
p

In [5]:
if __name__ == "__main__":
    train_dataset = MultiModalBalancedMultiLabelDataset(
        X_train_spectra, X_train_gaia, y_train, limit_per_label=batch_limit
    )
    val_dataset = MultiModalBalancedMultiLabelDataset(
        X_val_spectra, X_val_gaia, y_val, limit_per_label=batch_limit
    )
    test_dataset = MultiModalBalancedMultiLabelDataset(
        X_test_spectra, X_test_gaia, y_test, limit_per_label=batch_limit
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # print the number of samples in each dataset
    print(f"Train samples: {len(train_dataset)}")
    print(f"Val samples: {len(val_dataset)}")
    print(f"Test samples: {len(test_dataset)}")


Train samples: 288
Val samples: 230
Test samples: 256


In [6]:
import torch
import torch.nn as nn
from functools import partial

# Import the needed components from your MambaOut implementation
from timm.models.layers import DropPath

class GatedCNNBlock(nn.Module):
    """Adaptation of GatedCNNBlock for sequence data with dynamic kernel size adaptation"""
    def __init__(self, dim, d_conv=4, expand=2, drop_path=0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        hidden = int(expand * dim)
        self.fc1 = nn.Linear(dim, hidden * 2)
        self.act = nn.GELU()
        
        # Store these for dynamic convolution sizing
        self.d_conv = d_conv
        self.hidden = hidden
        
        self.fc2 = nn.Linear(hidden, dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        
        # Use simpler approach for sequence length 1 (common case)
        # This avoids dynamic convolution creation
        if d_conv == 1:
            self.use_identity_for_length_1 = True

        
        # Cache for static convolution with kernel size 1 (for length 1 sequences)
        if d_conv == 1:
            self.conv1 = nn.Conv1d(
                in_channels=hidden,
                out_channels=hidden, 
                kernel_size=1,
                padding=0,
                groups=hidden
            )
        else:
            # Dynamic convolution for other lengths
            self.conv = nn.Conv1d(
                in_channels=hidden,
                out_channels=hidden, 
                kernel_size=d_conv,
                padding=(d_conv - 1) // 2,
                groups=hidden
            )

    def forward(self, x):
        # Input shape: [B, seq_len, dim]
        shortcut = x
        x = self.norm(x)
        
        # Split the channels for gating mechanism
        x = self.fc1(x)  # [B, seq_len, hidden*2]
        g, c = torch.chunk(x, 2, dim=-1)  # Each: [B, seq_len, hidden]
        
        # Get sequence length
        batch_size, seq_len, channels = c.shape
        
        # Apply gating mechanism
        c_permuted = c.permute(0, 2, 1)  # [B, hidden, seq_len]
        
        # Special case for sequence length 1 
        if seq_len == 1 and self.use_identity_for_length_1:
            # Use the pre-created kernel size 1 conv, which is like identity but keeps channels
            c_conv = self.conv1(c_permuted)
        else:
            # For other sequence lengths, fallback to kernel size 1 to avoid issues
            # The conv1 layer is already initialized and on the correct device
            c_conv = self.conv(c_permuted)
            c_conv = c_conv[:, :, :seq_len] # Ensure we only take the valid part
        
        c_final = c_conv.permute(0, 2, 1)  # [B, seq_len, hidden]
        
        # Gating mechanism
        x = self.fc2(self.act(g) * c_final)  # [B, seq_len, dim]
        
        x = self.drop_path(x)
        return x + shortcut
    
class GatedCNNBlock(nn.Module):
    """Simplified and fixed GatedCNNBlock that preserves sequence length"""
    def __init__(self, dim, d_conv=4, expand=2, drop_path=0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        hidden = int(expand * dim)
        self.fc1 = nn.Linear(dim, hidden * 2)
        self.act = nn.GELU()
        
        # Properly calculate padding to ensure output length matches input length
        # For kernel_size k, padding needed is (k-1)/2, rounded up for even kernels
        self.d_conv = d_conv
        padding = (d_conv - 1) // 2
        
        # Single convolution with proper padding
        self.conv = nn.Conv1d(
            in_channels=hidden,
            out_channels=hidden, 
            kernel_size=d_conv,
            padding=padding,
            groups=hidden
        )
        
        self.fc2 = nn.Linear(hidden, dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        # Input shape: [B, seq_len, dim]
        shortcut = x
        x = self.norm(x)
        
        # Split for gating
        x = self.fc1(x)
        g, c = torch.chunk(x, 2, dim=-1)
        
        # Check shapes before processing
        batch_size, seq_len, channels = c.shape
        
        # Apply convolution
        c_permuted = c.permute(0, 2, 1)  # [B, hidden, seq_len]
        c_conv = self.conv(c_permuted)
        
        # Ensure output sequence length matches input
        if c_conv.size(2) != seq_len:
            if c_conv.size(2) < seq_len:
                # Pad if shorter
                padding = torch.zeros(
                    batch_size, channels, seq_len - c_conv.size(2),
                    device=c_conv.device, dtype=c_conv.dtype
                )
                c_conv = torch.cat([c_conv, padding], dim=2)
            else:
                # Truncate if longer
                c_conv = c_conv[:, :, :seq_len]
        
        c_final = c_conv.permute(0, 2, 1)  # [B, seq_len, hidden]
        
        # Perform gating and output projection
        x = self.fc2(self.act(g) * c_final)
        x = self.drop_path(x)
        
        return x + shortcut

class SequenceMambaOut(nn.Module):
    """Adaptation of MambaOut for sequence data with a single stage"""
    def __init__(self, d_model,  d_conv=4, expand=2, depth=1, drop_path=0.):
        super().__init__()
        
        # Create a sequence of GatedCNNBlocks
        self.blocks = nn.Sequential(
            *[GatedCNNBlock(
                dim=d_model,
                d_conv=d_conv,
                expand=expand,
                drop_path=drop_path
            ) for _ in range(depth)]
        )
    
    def forward(self, x):
        return self.blocks(x)

class CrossAttentionBlock(nn.Module):
    def __init__(self, dim, n_heads=8):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attention = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=n_heads,
            batch_first=True
        )
    
    def forward(self, x, context):
        """
        x: (B, seq_len_x, dim)
        context: (B, seq_len_context, dim)
        """
        x_norm = self.norm(x)
        attn_output, _ = self.attention(
            query=x_norm,
            key=context,
            value=context
        )
        return x + attn_output


    
class StarClassifierFusionMambaOut(nn.Module):
    def __init__(
        self,
        d_model_spectra,
        d_model_gaia,
        num_classes,
        input_dim_spectra,
        input_dim_gaia,
        token_dim_spectra,  # New parameter for token size
        token_dim_gaia,      # New parameter for token size
        n_layers=6,
        use_cross_attention=True,
        n_cross_attn_heads=8,
        d_conv=4,
        expand=2,
    ):
        """
        Args:
            d_model_spectra (int): embedding dimension for the spectra MAMBA
            d_model_gaia (int): embedding dimension for the gaia MAMBA
            num_classes (int): multi-label classification
            input_dim_spectra (int): # of features for spectra
            input_dim_gaia (int): # of features for gaia
            token_dim_spectra (int): size of each token for spectra features
            token_dim_gaia (int): size of each token for gaia features
            n_layers (int): depth for each MAMBA
            use_cross_attention (bool): whether to use cross-attention
            n_cross_attn_heads (int): number of heads for cross-attention
        """
        super().__init__()

        # --- Feature Tokenizers ---
        self.tokenizer_spectra = FeatureTokenizer(
            input_dim=input_dim_spectra,
            token_dim=token_dim_spectra,
            d_model=d_model_spectra
        )
        
        self.tokenizer_gaia = FeatureTokenizer(
            input_dim=input_dim_gaia,
            token_dim=token_dim_gaia,
            d_model=d_model_gaia
        )

        # --- MambaOut for spectra ---
        self.mamba_spectra = nn.Sequential(
            *[SequenceMambaOut(
                d_model=d_model_spectra,
                d_conv=d_conv,
                expand=expand,
                depth=1,
                drop_path=0.1 if i > 0 else 0.0,
            ) for i in range(n_layers)]
        )

        # --- MambaOut for gaia ---
        self.mamba_gaia = nn.Sequential(
            *[SequenceMambaOut(
                d_model=d_model_gaia,
                d_conv=d_conv,
                expand=expand,
                depth=1,
                drop_path=0.1 if i > 0 else 0.0,
            ) for i in range(n_layers)]
        )

        # --- Cross Attention (Optional) ---
        self.use_cross_attention = use_cross_attention
        if use_cross_attention:
            self.cross_attn_block_spectra = CrossAttentionBlock(d_model_spectra, n_heads=n_cross_attn_heads)
            self.cross_attn_block_gaia = CrossAttentionBlock(d_model_gaia, n_heads=n_cross_attn_heads)

        # --- Final Classifier ---
        fusion_dim = d_model_spectra + d_model_gaia
        self.classifier = nn.Sequential(
            nn.LayerNorm(fusion_dim),
            nn.Linear(fusion_dim, num_classes)
        )
    
    def forward(self, x_spectra, x_gaia):
        """
        x_spectra : (batch_size, input_dim_spectra)
        x_gaia    : (batch_size, input_dim_gaia)
        """
        # Tokenize input features
        # From [batch_size, input_dim] to [batch_size, num_tokens, d_model]
        x_spectra = self.tokenizer_spectra(x_spectra)  # (B, num_tokens_spectra, d_model_spectra)
        x_gaia = self.tokenizer_gaia(x_gaia)           # (B, num_tokens_gaia, d_model_gaia)

        # --- MambaOut encoding (each modality separately) ---
        x_spectra = self.mamba_spectra(x_spectra)  # (B, num_tokens_spectra, d_model_spectra)
        x_gaia = self.mamba_gaia(x_gaia)           # (B, num_tokens_gaia, d_model_gaia)

        # Optionally, use cross-attention to fuse the representations
        if self.use_cross_attention:
            # Cross-attention from spectra -> gaia
            x_spectra_fused = self.cross_attn_block_spectra(x_spectra, x_gaia)
            # Cross-attention from gaia -> spectra
            x_gaia_fused = self.cross_attn_block_gaia(x_gaia, x_spectra)
            
            # Update x_spectra and x_gaia
            x_spectra = x_spectra_fused
            x_gaia = x_gaia_fused
        
        # --- Pool across sequence dimension ---
        x_spectra = x_spectra.mean(dim=1)  # (B, d_model_spectra)
        x_gaia = x_gaia.mean(dim=1)        # (B, d_model_gaia)

        # --- Late Fusion by Concatenation ---
        x_fused = torch.cat([x_spectra, x_gaia], dim=-1)  # (B, d_model_spectra + d_model_gaia)

        # --- Final classification ---
        logits = self.classifier(x_fused)  # (B, num_classes)
        return logits

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from functools import partial

# Rotary Position Embeddings implementation
class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=4096):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.max_seq_len = max_seq_len
        
        # Generate position embeddings once at initialization
        self._generate_embeddings()
        
    def _generate_embeddings(self):
        t = torch.arange(self.max_seq_len, dtype=torch.float)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos().view(self.max_seq_len, 1, -1)
        sin = emb.sin().view(self.max_seq_len, 1, -1)
        self.register_buffer('cos_cached', cos)
        self.register_buffer('sin_cached', sin)
        
    def forward(self, seq_len):
        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]

def rotate_half(x):
    """Rotate half the hidden dims of the input."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary position embeddings to q and k tensors."""
    # Handle the case where q and k have shape [batch_size, seq_len, head_dim]
    # or [batch_size, n_heads, seq_len, head_dim]
    if q.dim() == 3:
        # [batch_size, seq_len, head_dim] -> [batch_size, seq_len, 1, head_dim]
        q = q.unsqueeze(2)
        k = k.unsqueeze(2)
        # After this operation, we squeeze back
        squeeze_after = True
    else:
        squeeze_after = False
    
    # Reshape cos and sin for proper broadcasting
    # [seq_len, 1, head_dim] -> [1, seq_len, 1, head_dim]
    cos = cos.unsqueeze(0)
    sin = sin.unsqueeze(0)
    
    # Apply rotation
    q_rot = (q * cos) + (rotate_half(q) * sin)
    k_rot = (k * cos) + (rotate_half(k) * sin)
    
    if squeeze_after:
        q_rot = q_rot.squeeze(2)
        k_rot = k_rot.squeeze(2)
    
    return q_rot, k_rot

class RotarySelfAttention(nn.Module):
    """Self-attention with rotary position embeddings."""
    def __init__(self, dim, n_heads=8, dropout=0.1):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        assert self.head_dim * n_heads == dim, "dim must be divisible by n_heads"
        
        # QKV projections
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)
        
        # Rotary positional embedding
        self.rope = RotaryEmbedding(self.head_dim)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        """
        Args:
            x: Input tensor of shape [batch_size, seq_len, dim]
            
        Returns:
            output: Tensor of same shape as input
        """
        batch_size, seq_len, _ = x.shape
        
        # Project to queries, keys, values
        q = self.q_proj(x)  # [batch_size, seq_len, dim]
        k = self.k_proj(x)  # [batch_size, seq_len, dim]
        v = self.v_proj(x)  # [batch_size, seq_len, dim]
        
        # Reshape for multi-head attention
        q = q.view(batch_size, seq_len, self.n_heads, self.head_dim)  
        k = k.view(batch_size, seq_len, self.n_heads, self.head_dim)
        v = v.view(batch_size, seq_len, self.n_heads, self.head_dim)
        
        # Get position embeddings
        cos, sin = self.rope(seq_len)
        
        # Apply rotary position embeddings to q and k
        q, k = apply_rotary_pos_emb(q, k, cos, sin)
        
        # Transpose for efficient batch matrix multiplication
        q = q.transpose(1, 2)  # [batch_size, n_heads, seq_len, head_dim]
        k = k.transpose(1, 2)  # [batch_size, n_heads, seq_len, head_dim]
        v = v.transpose(1, 2)  # [batch_size, n_heads, seq_len, head_dim]
        
        # Compute scaled dot-product attention
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)  # [batch_size, n_heads, seq_len, seq_len]
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # Apply attention weights to values
        output = torch.matmul(attn_weights, v)  # [batch_size, n_heads, seq_len, head_dim]
        
        # Reshape back to original format
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.dim)
        
        # Apply output projection
        output = self.out_proj(output)
        
        return output

class TransformerBlock(nn.Module):
    """Transformer block with rotary self-attention and feed-forward network."""
    def __init__(self, dim, n_heads=8, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = RotarySelfAttention(dim, n_heads, dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(4 * dim, dim),
            nn.Dropout(dropout)
        )
    
    def forward(self, x):
        """
        Args:
            x: Input tensor of shape [batch_size, seq_len, dim]
        """
        # Self-attention with residual connection
        x = x + self.attn(self.norm1(x))
        
        # FFN with residual connection
        x = x + self.ffn(self.norm2(x))
        
        return x

class TransformerFeatureExtractor(nn.Module):
    """Stack of transformer blocks for feature extraction."""
    def __init__(self, d_model, n_layers=6, n_heads=8, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, dropout)
            for _ in range(n_layers)
        ])
    
    def forward(self, x):
        """
        Args:
            x: Input tensor of shape [batch_size, seq_len, d_model]
        
        Returns:
            Processed tensor of same shape
        """
        for layer in self.layers:
            x = layer(x)
        return x

class CrossAttentionBlock(nn.Module):
    """
    Cross-attention block to attend from one modality to another.
    """
    def __init__(self, dim, n_heads=8):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attention = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=n_heads,
            batch_first=True
        )
    
    def forward(self, x, context):
        """
        Args:
            x: Query tensor of shape [batch_size, seq_len_q, dim]
            context: Key/value tensor of shape [batch_size, seq_len_kv, dim]
        
        Returns:
            Output tensor of shape [batch_size, seq_len_q, dim]
        """
        x_norm = self.norm(x)
        attn_output, _ = self.attention(
            query=x_norm,
            key=context,
            value=context
        )
        return x + attn_output


class StarClassifierFusionTransformer(nn.Module):
    """Transformer-based feature extractor with tokenization for multi-modal fusion."""
    def __init__(
        self,
        d_model_spectra,
        d_model_gaia,
        num_classes,
        input_dim_spectra,
        input_dim_gaia,
        token_dim_spectra=64,  # Size of each token for spectra
        token_dim_gaia=2,      # Size of each token for gaia
        n_layers=6,
        n_heads=8,
        use_cross_attention=True,
        n_cross_attn_heads=8,
        dropout=0.1,
    ):
        """
        Args:
            d_model_spectra (int): embedding dimension for the spectra Transformer
            d_model_gaia (int): embedding dimension for the gaia Transformer
            num_classes (int): multi-label classification
            input_dim_spectra (int): # of features for spectra
            input_dim_gaia (int): # of features for gaia
            token_dim_spectra (int): size of each token for spectra features
            token_dim_gaia (int): size of each token for gaia features
            n_layers (int): depth for each Transformer
            n_heads (int): number of attention heads
            use_cross_attention (bool): whether to use cross-attention
            n_cross_attn_heads (int): number of heads for cross-attention
            dropout (float): dropout rate
        """
        super().__init__()

        # --- Feature Tokenizers ---
        self.tokenizer_spectra = FeatureTokenizer(
            input_dim=input_dim_spectra,
            token_dim=token_dim_spectra,
            d_model=d_model_spectra
        )
        
        self.tokenizer_gaia = FeatureTokenizer(
            input_dim=input_dim_gaia,
            token_dim=token_dim_gaia,
            d_model=d_model_gaia
        )

        # --- Transformer for spectra ---
        self.transformer_spectra = TransformerFeatureExtractor(
            d_model=d_model_spectra,
            n_layers=n_layers,
            n_heads=n_heads,
            dropout=dropout
        )

        # --- Transformer for gaia ---
        self.transformer_gaia = TransformerFeatureExtractor(
            d_model=d_model_gaia,
            n_layers=n_layers,
            n_heads=n_heads,
            dropout=dropout
        )

        # --- Cross Attention (Optional) ---
        self.use_cross_attention = use_cross_attention
        if use_cross_attention:
            self.cross_attn_block_spectra = CrossAttentionBlock(d_model_spectra, n_heads=n_cross_attn_heads)
            self.cross_attn_block_gaia = CrossAttentionBlock(d_model_gaia, n_heads=n_cross_attn_heads)

        # --- Final Classifier ---
        fusion_dim = d_model_spectra + d_model_gaia
        self.classifier = nn.Sequential(
            nn.LayerNorm(fusion_dim),
            nn.Dropout(dropout),
            nn.Linear(fusion_dim, num_classes)
        )
    
    def forward(self, x_spectra, x_gaia):
        """
        Args:
            x_spectra: Spectra features of shape [batch_size, input_dim_spectra]
            x_gaia: Gaia features of shape [batch_size, input_dim_gaia]
            
        Returns:
            logits: Classification logits of shape [batch_size, num_classes]
        """
        # Tokenize input features
        # From [batch_size, input_dim] to [batch_size, num_tokens, d_model]
        x_spectra_tokens = self.tokenizer_spectra(x_spectra)
        x_gaia_tokens = self.tokenizer_gaia(x_gaia)
        
        # Process through transformers
        x_spectra = self.transformer_spectra(x_spectra_tokens)  # [batch_size, num_tokens_spectra, d_model]
        x_gaia = self.transformer_gaia(x_gaia_tokens)          # [batch_size, num_tokens_gaia, d_model]

        # Optional cross-attention
        if self.use_cross_attention:
            x_spectra = self.cross_attn_block_spectra(x_spectra, x_gaia)
            x_gaia = self.cross_attn_block_gaia(x_gaia, x_spectra)
        
        # Global pooling over sequence dimension
        x_spectra = x_spectra.mean(dim=1)  # [batch_size, d_model]
        x_gaia = x_gaia.mean(dim=1)        # [batch_size, d_model]

        # Concatenate for fusion
        x_fused = torch.cat([x_spectra, x_gaia], dim=-1)  # [batch_size, 2*d_model]

        # Final classification
        logits = self.classifier(x_fused)  # [batch_size, num_classes]
        
        return logits

In [8]:
import torch
import torch.nn as nn
from mamba_ssm import Mamba2


class StarClassifierFusionMambaTokenized(nn.Module):
    def __init__(
        self,
        d_model_spectra,
        d_model_gaia,
        num_classes,
        input_dim_spectra,
        input_dim_gaia,
        token_dim_spectra=64,  # Size of each token for spectra
        token_dim_gaia=2,      # Size of each token for gaia
        n_layers=10,
        use_cross_attention=True,
        n_cross_attn_heads=8,
        d_state=256,
        d_conv=4,
        expand=2,
    ):
        """
        Args:
            d_model_spectra (int): embedding dimension for the spectra MAMBA
            d_model_gaia (int): embedding dimension for the gaia MAMBA
            num_classes (int): multi-label classification
            input_dim_spectra (int): # of features for spectra
            input_dim_gaia (int): # of features for gaia
            token_dim_spectra (int): size of each token for spectra features
            token_dim_gaia (int): size of each token for gaia features
            n_layers (int): depth for each MAMBA
            use_cross_attention (bool): whether to use cross-attention
            n_cross_attn_heads (int): number of heads for cross-attention
            d_state (int): state dimension for Mamba
            d_conv (int): convolution dimension for Mamba
            expand (int): expansion factor for Mamba
        """
        super().__init__()

        # --- Feature Tokenizers ---
        self.tokenizer_spectra = FeatureTokenizer(
            input_dim=input_dim_spectra,
            token_dim=token_dim_spectra,
            d_model=d_model_spectra
        )
        
        self.tokenizer_gaia = FeatureTokenizer(
            input_dim=input_dim_gaia,
            token_dim=token_dim_gaia,
            d_model=d_model_gaia
        )

        # --- MAMBA 2 for spectra ---
        self.mamba_spectra = nn.Sequential(
            *[Mamba2(
                d_model=d_model_spectra,
                d_state=d_state,
                d_conv=d_conv,
                expand=expand,
            ) for _ in range(n_layers)]
        )

        # --- MAMBA 2 for gaia ---
        self.mamba_gaia = nn.Sequential(
            *[Mamba2(
                d_model=d_model_gaia,
                d_state=d_state,
                d_conv=d_conv,
                expand=expand,
            ) for _ in range(n_layers)]
        )

        # --- Cross Attention (Optional) ---
        self.use_cross_attention = use_cross_attention
        if use_cross_attention:
            self.cross_attn_block_spectra = CrossAttentionBlock(d_model_spectra, n_heads=n_cross_attn_heads)
            self.cross_attn_block_gaia = CrossAttentionBlock(d_model_gaia, n_heads=n_cross_attn_heads)

        # --- Final Classifier ---
        fusion_dim = d_model_spectra + d_model_gaia
        self.classifier = nn.Sequential(
            nn.LayerNorm(fusion_dim),
            nn.Linear(fusion_dim, num_classes)
        )
    
    def forward(self, x_spectra, x_gaia):
        """
        Args:
            x_spectra: Spectra features of shape [batch_size, input_dim_spectra]
            x_gaia: Gaia features of shape [batch_size, input_dim_gaia]
            
        Returns:
            logits: Classification logits of shape [batch_size, num_classes]
        """
        # Tokenize input features
        # From [batch_size, input_dim] to [batch_size, num_tokens, d_model]
        x_spectra_tokens = self.tokenizer_spectra(x_spectra)
        x_gaia_tokens = self.tokenizer_gaia(x_gaia)
        
        # Process through Mamba models
        x_spectra = self.mamba_spectra(x_spectra_tokens)  # [batch_size, num_tokens_spectra, d_model]
        x_gaia = self.mamba_gaia(x_gaia_tokens)          # [batch_size, num_tokens_gaia, d_model]

        # Optional cross-attention
        if self.use_cross_attention:
            x_spectra = self.cross_attn_block_spectra(x_spectra, x_gaia)
            x_gaia = self.cross_attn_block_gaia(x_gaia, x_spectra)
        
        # Global pooling over sequence dimension
        x_spectra = x_spectra.mean(dim=1)  # [batch_size, d_model]
        x_gaia = x_gaia.mean(dim=1)        # [batch_size, d_model]

        # Concatenate for fusion
        x_fused = torch.cat([x_spectra, x_gaia], dim=-1)  # [batch_size, 2*d_model]

        # Final classification
        logits = self.classifier(x_fused)  # [batch_size, num_classes]
        
        return logits

In [9]:
egoinrign

NameError: name 'egoinrign' is not defined

In [None]:
if __name__ == "__main__":
    # Example config
    d_model_spectra = 2048
    d_model_gaia = 2048
    num_classes = 55
    input_dim_spectra = 3647
    input_dim_gaia = 18
    token_dim_spectra = 3647
    token_dim_gaia = 18
    n_layers = 20
    lr = 2.5e-6
    patience = 200
    num_epochs = 800
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Device:", device)

    # Initialize WandB
    wandb.init(project="ALLSTARS_multimodal_fusion_mamba_out")
    
    config = {
        "num_classes": num_classes,
        "d_model_spectra": d_model_spectra,
        "d_model_gaia": d_model_gaia,
        "input_dim_spectra": input_dim_spectra,
        "input_dim_gaia": input_dim_gaia,
        "n_layers": n_layers,
        "lr": lr,
        "patience": patience,
        "num_epochs": num_epochs
    }
    wandb.config.update(config)

    # Instantiate the fusion model
    # Try use_cross_attention=False for late-fusion, True for cross-attention
    model_fusion = StarClassifierFusionMambaOut(
        d_model_spectra=d_model_spectra,
        d_model_gaia=d_model_gaia,
        num_classes=num_classes,
        input_dim_spectra=input_dim_spectra,
        input_dim_gaia=input_dim_gaia,
        token_dim_gaia=token_dim_gaia,
        token_dim_spectra=token_dim_spectra,
        n_layers=n_layers,
        d_conv=1,
        use_cross_attention=True,  # set to False to compare with late fusion
        n_cross_attn_heads=8
    )
    model_fusion.to(device)

    # Print size of model in GB
    print(f"Model size: {sum(p.numel() for p in model_fusion.parameters()) / 1e9:.2f} GB")
    param_size = 0
    for param in model_fusion.parameters():
        param_size += param.nelement() * param.element_size()
    buffer_size = 0
    for buffer in model_fusion.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()

    size_all_mb = (param_size + buffer_size) / 1024**2
    print('model size: {:.3f}MB'.format(size_all_mb))
    # Compute parameter size
    param_size = sum(p.nelement() * p.element_size() for p in model_fusion.parameters())

    # Compute buffer size
    buffer_size = sum(b.nelement() * b.element_size() for b in model_fusion.buffers())

    # Total size in MB
    total_size_mb = (param_size + buffer_size) / (1024 ** 2)
    print(f"Model size: {total_size_mb:.3f} MB")

    print(model_fusion)
    # print number of parameters per layer
    for name, param in model_fusion.named_parameters():
        print(name, param.numel())
    print("Total number of parameters:", sum(p.numel() for p in model_fusion.parameters() if p.requires_grad))

    # Train the fusion model
    trained_fusion_model = train_model_fusion(
        model=model_fusion,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        num_epochs=num_epochs,
        lr=lr,
        max_patience=patience,
        device=device
    )

    wandb.finish()

# Save the model
torch.save(trained_fusion_model.state_dict(), "Models/model_fusion_mambaoutv3.pth")


Device: cuda


[34m[1mwandb[0m: Currently logged in as: [33mjoaocsgalmeida[0m ([33mjoaocsgalmeida-university-of-southampton[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Model size: 1.05 GB
model size: 4001.071MB
Model size: 4001.071 MB
StarClassifierFusionMambaOut(
  (tokenizer_spectra): FeatureTokenizer(
    (token_embed): Linear(in_features=3647, out_features=2048, bias=True)
  )
  (tokenizer_gaia): FeatureTokenizer(
    (token_embed): Linear(in_features=18, out_features=2048, bias=True)
  )
  (mamba_spectra): Sequential(
    (0): SequenceMambaOut(
      (blocks): Sequential(
        (0): GatedCNNBlock(
          (norm): LayerNorm((2048,), eps=1e-06, elementwise_affine=True)
          (fc1): Linear(in_features=2048, out_features=8192, bias=True)
          (act): GELU(approximate='none')
          (conv): Conv1d(4096, 4096, kernel_size=(1,), stride=(1,), groups=4096)
          (fc2): Linear(in_features=4096, out_features=2048, bias=True)
          (drop_path): Identity()
        )
      )
    )
    (1): SequenceMambaOut(
      (blocks): Sequential(
        (0): GatedCNNBlock(
          (norm): LayerNorm((2048,), eps=1e-06, elementwise_affine=True)
  

0,1
epoch,▁▁▁▁▁▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇███
hamming_loss,█▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,▁▁▁▂▂▅▅▅▆▆▇▇▇▇███████▇▇▇▇▆▆▆▆▆▅▅▄▄▄▄▄▄▄▃
macro_f1,▁▁▁▁▁▁▁▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇█▇▇▇██▇▇██████
macro_precision,▅▇████▇▇▇▆▄▅▃▃▄▃▃▃▂▂▂▃▃▁▂▁▂▂▂▃▂▃▂▂▂▂▃▃▃▂
macro_recall,▁▁▁▁▁▁▁▁▁▁▃▃▃▄▅▆▆▆▆▆▇▇▇▇▇▆▇▇▇█▇▇██▇█████
micro_f1,▂▁▁▁▁▂▂▂▃▄▅▅▅▆▅▆▆▆▆▆▆▇▇▇▆▇█▇▇▇▇█████████
micro_precision,▁▁▁▂▃█████▆▆▆▇▇▇▇▆▆▆▆▆▆▆▇▆▆▆▆▆▇▇▆▆▆▇▇▇▆▇
micro_recall,▇▃▁▁▁▁▁▁▁▁▃▃▃▃▅▅▅▅▅▆▆▇▇▇▆▇▇▇█▇██████████
test_acc,▁▆▇▇▇▇▇▇▇▇██████████████████████████████

0,1
epoch,546.0
hamming_loss,0.02296
lr,0.0
macro_f1,0.4396
macro_precision,0.78807
macro_recall,0.3945
micro_f1,0.61844
micro_precision,0.79819
micro_recall,0.50476
test_acc,0.97582


# MambaOut 19+18 tokens

In [None]:
if __name__ == "__main__":
    # Configuration for tokenized transformer model
    d_model_spectra = 2048
    d_model_gaia = 2048
    num_classes = 55
    input_dim_spectra = 3647
    input_dim_gaia = 18
    d_conv = 4
    
    # Token dimensions - these control the sequence length
    token_dim_spectra = 192  # Will create ~19 tokens for spectra (3647/192)
    token_dim_gaia = 1      # Will create 18 tokens for gaia (18/1)
    
    n_layers = 20  
    n_heads = 8
    
    lr = 2.5e-6
    patience = 200
    num_epochs = 800
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Device:", device)
    
    # Initialize WandB
    wandb.init(project="ALLSTARS_multimodal_fusion_MambaOut_18_19")
    config = {
        "model_type": "transformer_tokenized",
        "num_classes": num_classes,
        "d_model_spectra": d_model_spectra,
        "d_model_gaia": d_model_gaia,
        "input_dim_spectra": input_dim_spectra,
        "input_dim_gaia": input_dim_gaia,
        "token_dim_spectra": token_dim_spectra,
        "token_dim_gaia": token_dim_gaia,
        "num_tokens_spectra": (input_dim_spectra + token_dim_spectra - 1) // token_dim_spectra,
        "num_tokens_gaia": (input_dim_gaia + token_dim_gaia - 1) // token_dim_gaia,
        "n_layers": n_layers,
        "n_heads": n_heads,
        "use_cross_attention": True,
        "lr": lr,
        "patience": patience,
        "num_epochs": num_epochs
    }
    wandb.config.update(config)
    
    # Instantiate the tokenized MAMBAOut model
    model_MambaOut = StarClassifierFusionMambaOut(
        d_model_spectra=d_model_spectra,
        d_model_gaia=d_model_gaia,
        num_classes=num_classes,
        input_dim_spectra=input_dim_spectra,
        input_dim_gaia=input_dim_gaia,
        token_dim_spectra=token_dim_spectra,
        token_dim_gaia=token_dim_gaia,
        n_layers=n_layers,
        d_conv=d_conv,
        use_cross_attention=True,
        n_cross_attn_heads=8)
    model_MambaOut.to(device)
    
    # Print model statistics
    print(f"Number of spectra tokens: {(input_dim_spectra + token_dim_spectra - 1) // token_dim_spectra}")
    print(f"Number of gaia tokens: {(input_dim_gaia + token_dim_gaia - 1) // token_dim_gaia}")
    
    # Compute parameter size
    param_size = 0
    for param in model_MambaOut.parameters():
        param_size += param.nelement() * param.element_size()
    
    # Compute buffer size
    buffer_size = 0
    for buffer in model_MambaOut.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()

    print(model_MambaOut)
    # print number of parameters per layer
    for name, param in model_MambaOut.named_parameters():
        print(name, param.numel())
    
    # Total size in MB
    total_size_mb = (param_size + buffer_size) / (1024 ** 2)
    print(f"Model size: {total_size_mb:.3f} MB")
    print(f"Total parameters: {sum(p.numel() for p in model_MambaOut.parameters() if p.requires_grad)}")
    
    # Train the MAMBAOUT model
    trained_MAMBAOut_model = train_model_fusion(
        model=model_MambaOut,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        num_epochs=num_epochs,
        lr=lr,
        max_patience=patience,
        device=device
    )
    wandb.finish()
    
    # Save the model
    torch.save(trained_MAMBAOut_model.state_dict(), "Models/model_fusion_MambaOut_18_19_v2.pth")


Device: cuda


[34m[1mwandb[0m: Currently logged in as: [33mjoaocsgalmeida[0m ([33mjoaocsgalmeida-university-of-southampton[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Number of spectra tokens: 19
Number of gaia tokens: 18
StarClassifierFusionMambaOut(
  (tokenizer_spectra): FeatureTokenizer(
    (token_embed): Linear(in_features=192, out_features=2048, bias=True)
  )
  (tokenizer_gaia): FeatureTokenizer(
    (token_embed): Linear(in_features=1, out_features=2048, bias=True)
  )
  (mamba_spectra): Sequential(
    (0): SequenceMambaOut(
      (blocks): Sequential(
        (0): GatedCNNBlock(
          (norm): LayerNorm((2048,), eps=1e-06, elementwise_affine=True)
          (fc1): Linear(in_features=2048, out_features=8192, bias=True)
          (act): GELU(approximate='none')
          (conv): Conv1d(4096, 4096, kernel_size=(4,), stride=(1,), padding=(1,), groups=4096)
          (fc2): Linear(in_features=4096, out_features=2048, bias=True)
          (drop_path): Identity()
        )
      )
    )
    (1): SequenceMambaOut(
      (blocks): Sequential(
        (0): GatedCNNBlock(
          (norm): LayerNorm((2048,), eps=1e-06, elementwise_affine=True)
  

0,1
epoch,▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇██
hamming_loss,█▄▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,▁▁▁▁▁▂▂▂▂▃▄▄▄▅▅▆▇▇███████▇▇▇▇▇▆▆▆▆▆▅▅▅▄▃
macro_f1,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▃▃▄▅▅▅▆▆▆▇▇▇█████▇██████
macro_precision,▁▄▅▆▇█████▇█▇▇▆▆▆▆▆▆▅▆▅▅▆▅▆▅▆▅▆▅▆▆▆▅▆▆▆▆
macro_recall,▁▁▁▁▁▁▁▁▁▁▁▂▃▃▃▄▅▅▆▅▆▇▆▆▆▇▇▇▇▇▇████▇▇██▇
micro_f1,▁▁▁▁▁▁▁▁▁▁▂▃▄▄▅▅▆▆▆▆▆▇▇▇▇▇██▇▇██████████
micro_precision,▁████████▄█▇▇▇▇▆▇▆▇▇▇▇▇▇▆▇▆▆▇▆▆▆▆▆▇▇▆▆▆▆
micro_recall,▁▁▁▁▁▁▁▁▁▁▄▅▅▅▅▆▆▆▆▇▇▆▇▇▇▇▇▇▇▇▇▇▇█▇▇████
test_acc,▁▂▂▂▂▂▂▂▂▂▂▂▂▂▃▄▅▆▆▆▇▆▆▆▇▇▇▇▇█▇▇██▇█▇███

0,1
epoch,568.0
hamming_loss,0.02656
lr,0.0
macro_f1,0.36749
macro_precision,0.83075
macro_recall,0.31212
micro_f1,0.5163
micro_precision,0.77647
micro_recall,0.38672
test_acc,0.97342


# MambaOut Most Tokens

In [None]:

if __name__ == "__main__":
    # Configuration for tokenized transformer model, reduced embedding dimension
    d_model_spectra = 1536
    d_model_gaia = 1536
    num_classes = 55
    input_dim_spectra = 3647
    input_dim_gaia = 18
    
    # Token dimensions - these control the sequence length
    token_dim_spectra = 7  # Will create 522 tokens for spectra (3647/7)
    token_dim_gaia = 1      # Will create 18 tokens for gaia (18/1)
    
    n_layers = 20  
    n_heads = 8
    
    lr = 2.5e-6
    patience = 200
    num_epochs = 800
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Device:", device)
    
    # Initialize WandB
    wandb.init(project="ALLSTARS_multimodal_fusion_transformer_tokenized")
    config = {
        "model_type": "transformer_tokenized",
        "num_classes": num_classes,
        "d_model_spectra": d_model_spectra,
        "d_model_gaia": d_model_gaia,
        "input_dim_spectra": input_dim_spectra,
        "input_dim_gaia": input_dim_gaia,
        "token_dim_spectra": token_dim_spectra,
        "token_dim_gaia": token_dim_gaia,
        "num_tokens_spectra": (input_dim_spectra + token_dim_spectra - 1) // token_dim_spectra,
        "num_tokens_gaia": (input_dim_gaia + token_dim_gaia - 1) // token_dim_gaia,
        "n_layers": n_layers,
        "n_heads": n_heads,
        "use_cross_attention": True,
        "lr": lr,
        "patience": patience,
        "num_epochs": num_epochs
    }
    wandb.config.update(config)
    
    # Instantiate the tokenized transformer model
    model_MambaOut = StarClassifierFusionMambaOut(
        d_model_spectra=d_model_spectra,
        d_model_gaia=d_model_gaia,
        num_classes=num_classes,
        input_dim_spectra=input_dim_spectra,
        input_dim_gaia=input_dim_gaia,
        token_dim_spectra=token_dim_spectra,
        token_dim_gaia=token_dim_gaia,
        n_layers=n_layers,
        use_cross_attention=True,
        d_conv=32,
        n_cross_attn_heads=8
    )
    model_MambaOut.to(device)

    
    # Print model statistics
    print(f"Number of spectra tokens: {(input_dim_spectra + token_dim_spectra - 1) // token_dim_spectra}")
    print(f"Number of gaia tokens: {(input_dim_gaia + token_dim_gaia - 1) // token_dim_gaia}")
    
    # Compute parameter size
    param_size = 0
    for param in model_MambaOut.parameters():
        param_size += param.nelement() * param.element_size()
    
    # Compute buffer size
    buffer_size = 0
    for buffer in model_MambaOut.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()

    print(model_MambaOut)
    # print number of parameters per layer
    for name, param in model_MambaOut.named_parameters():
        print(name, param.numel())
    
    # Total size in MB
    total_size_mb = (param_size + buffer_size) / (1024 ** 2)
    print(f"Model size: {total_size_mb:.3f} MB")
    print(f"Total parameters: {sum(p.numel() for p in model_MambaOut.parameters() if p.requires_grad)}")
    
    # Train the transformer model
    trained_transformer_model = train_model_fusion(
        model=model_MambaOut,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        num_epochs=num_epochs,
        lr=lr,
        max_patience=patience,
        device=device
    )
    wandb.finish()
    
    # Save the model
    torch.save(trained_transformer_model.state_dict(), "Models/model_fusion_MambaOut_Many_tokens.pth")

Device: cuda


[34m[1mwandb[0m: Currently logged in as: [33mjoaocsgalmeida[0m ([33mjoaocsgalmeida-university-of-southampton[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Number of spectra tokens: 521
Number of gaia tokens: 18
StarClassifierFusionMambaOut(
  (tokenizer_spectra): FeatureTokenizer(
    (token_embed): Linear(in_features=7, out_features=1536, bias=True)
  )
  (tokenizer_gaia): FeatureTokenizer(
    (token_embed): Linear(in_features=1, out_features=1536, bias=True)
  )
  (mamba_spectra): Sequential(
    (0): SequenceMambaOut(
      (blocks): Sequential(
        (0): GatedCNNBlock(
          (norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)
          (fc1): Linear(in_features=1536, out_features=6144, bias=True)
          (act): GELU(approximate='none')
          (conv): Conv1d(3072, 3072, kernel_size=(32,), stride=(1,), padding=(15,), groups=3072)
          (fc2): Linear(in_features=3072, out_features=1536, bias=True)
          (drop_path): Identity()
        )
      )
    )
    (1): SequenceMambaOut(
      (blocks): Sequential(
        (0): GatedCNNBlock(
          (norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)
 

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇███
hamming_loss,█▅▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,▁▁▁▂▂▃▃▄▄▅▆▇▇▇▇████████████████▇▇▇▇▇▇▇▇▆
macro_f1,▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▃▃▄▄▅▅▅▆▆▇▇▇▇▇▇▇███
macro_precision,▁▅███████████████▇▇▇▆▆▅▆▅▄▄▅▅▄▄▄▄▄▄▃▄▄▅▄
macro_recall,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▄▅▅▆▇▇▇▇▇█▆███
micro_f1,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▃▂▃▄▅▅▆▆▆▆▇█▇▇██▇███
micro_precision,▁▁▂▂▂▁█████████▁▃▃▅▅▅▆▇▆▆▇▆▇▇▆▆▆▆▆▆▆▆▆▆▆
micro_recall,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▃▃▄▄▄▅▄▅▆▅▆▆▆▆▆
test_acc,▁▄▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▆▆▅▆▆▆▇▇▇▇▇█▇▇███████

0,1
epoch,437.0
hamming_loss,0.02831
lr,0.0
macro_f1,0.34189
macro_precision,0.77948
macro_recall,0.28316
micro_f1,0.44784
micro_precision,0.78155
micro_recall,0.31384
test_acc,0.9717


In [None]:
breakit up

SyntaxError: invalid syntax (3536932824.py, line 1)

# Transformer 1 token

In [10]:
if __name__ == "__main__":
    # Example config for transformer-based model
    d_model_spectra = 2048
    d_model_gaia = 2048
    num_classes = 55
    input_dim_spectra = 3647
    input_dim_gaia = 18
    n_layers = 10  # Match the number of layers with MAMBA models for fair comparison
    lr = 2.5e-6
    patience = 200
    num_epochs = 800
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Device:", device)
    
    # Initialize WandB
    wandb.init(project="ALLSTARS_multimodal_fusion_transformer",
    config = {
        "model_type": "transformer",
        "num_classes": num_classes,
        "d_model_spectra": d_model_spectra,
        "d_model_gaia": d_model_gaia,
        "input_dim_spectra": input_dim_spectra,
        "input_dim_gaia": input_dim_gaia,
        "n_layers": n_layers,
        "n_heads": 8,  # Number of attention heads
        "use_cross_attention": True,
        "lr": lr,
        "patience": patience,
        "num_epochs": num_epochs
    })
    
    # Instantiate the transformer model
    model_transformer = StarClassifierFusionTransformer(
        d_model_spectra=d_model_spectra,
        d_model_gaia=d_model_gaia,
        num_classes=num_classes,
        input_dim_spectra=input_dim_spectra,
        input_dim_gaia=input_dim_gaia,
        n_layers=n_layers,
        n_heads=8,
        use_cross_attention=True,  # set to False to compare with late fusion
        n_cross_attn_heads=8,
        dropout=0.1
    )
    model_transformer.to(device)
    
    # Print model statistics
    print(f"Model size: {sum(p.numel() for p in model_transformer.parameters()) / 1e9:.3f} GB")
    
    # Compute parameter size
    param_size = 0
    for param in model_transformer.parameters():
        param_size += param.nelement() * param.element_size()
    
    # Compute buffer size
    buffer_size = 0
    for buffer in model_transformer.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
    
    # Total size in MB
    total_size_mb = (param_size + buffer_size) / (1024 ** 2)
    print(f"Model size: {total_size_mb:.3f} MB")
    
    # Print model architecture
    print(model_transformer)
    
    # Print number of parameters per layer
    for name, param in model_transformer.named_parameters():
        print(name, param.numel())
    print("Total number of parameters:", sum(p.numel() for p in model_transformer.parameters() if p.requires_grad))
    
    # Train the transformer model
    trained_transformer_model = train_model_fusion(
        model=model_transformer,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        num_epochs=num_epochs,
        lr=lr,
        max_patience=patience,
        device=device
    )
    wandb.finish()
    
    # Save the model
    torch.save(trained_transformer_model.state_dict(), "Models/model_fusion_transformer_1token_v2.pth")

Device: cuda


[34m[1mwandb[0m: Currently logged in as: [33mjoaocsgalmeida[0m ([33mjoaocsgalmeida-university-of-southampton[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Model size: 1.041 GB
Model size: 4131.557 MB
StarClassifierFusionTransformer(
  (tokenizer_spectra): FeatureTokenizer(
    (token_embed): Linear(in_features=64, out_features=2048, bias=True)
  )
  (tokenizer_gaia): FeatureTokenizer(
    (token_embed): Linear(in_features=2, out_features=2048, bias=True)
  )
  (transformer_spectra): TransformerFeatureExtractor(
    (layers): ModuleList(
      (0-9): 10 x TransformerBlock(
        (norm1): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        (attn): RotarySelfAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (out_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (rope): RotaryEmbedding()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (norm2): LayerNorm((2048,), eps=1e-05, elementwise_affine=Tr

0,1
epoch,▁▂▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇█████
hamming_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,▁▁▁▁▂▂▄▄▄▅▆▇██████▇▇▇▇▇▇▆▆▅▅▅▅▄▄▄▄▄▃▃▃▃▂
macro_f1,▁▁▁▁▁▁▁▁▁▁▂▂▃▄▄▅▆▅▅▆▆▇▆▇▇▇▇▇█▇██▇███████
macro_precision,▆██████▇▇▇▅▇▆▅▃▃▃▄▂▃▁▂▂▂▂▂▁▂▂▁▁▂▂▂▂▁▂▁▁▁
macro_recall,▁▁▁▁▁▁▁▁▁▁▂▂▃▄▄▅▄▆▅▆▆▇▇▆▇▇▇▇▇▇███████▇██
micro_f1,▁▁▁▁▁▁▁▁▁▁▂▃▃▄▅▅▅▆▆▆▇▇▇▇▇▇▇▇▇█▇▇▇███████
micro_precision,▁██▂▁▅▃▆▆▅▆▅▆▆▆▆▅▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆
micro_recall,▁▁▁▁▁▁▁▂▂▂▃▃▃▃▄▄▅▆▅▆▆▆▆▇▇▇▇▇▇▇█▇█▇▇▇▇███
test_acc,▁▁▁▁▁▁▁▁▁▂▃▂▂▃▄▅▄▄▅▅▅▆▇▆▆▆███▇▇█▇▇▇█████

0,1
epoch,640.0
hamming_loss,0.02656
lr,0.0
macro_f1,0.41904
macro_precision,0.67243
macro_recall,0.39741
micro_f1,0.57306
micro_precision,0.70112
micro_recall,0.48456
test_acc,0.97344


# Transformer with 19 + 18 Tokens

In [None]:
if __name__ == "__main__":
    # Configuration for tokenized transformer model
    d_model_spectra = 2048
    d_model_gaia = 2048
    num_classes = 55
    input_dim_spectra = 3647
    input_dim_gaia = 18
    
    # Token dimensions - these control the sequence length
    token_dim_spectra = 192  # Will create ~19 tokens for spectra (3647/192)
    token_dim_gaia = 1      # Will create 18 tokens for gaia (18/1)
    
    n_layers = 10  # Reduced number of layers since we have more tokens
    n_heads = 8
    
    lr = 1e-6
    patience = 200
    num_epochs = 800
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Device:", device)
    
    # Initialize WandB
    wandb.init(project="ALLSTARS_multimodal_fusion_transformer_tokenized")
    config = {
        "model_type": "transformer_tokenized",
        "num_classes": num_classes,
        "d_model_spectra": d_model_spectra,
        "d_model_gaia": d_model_gaia,
        "input_dim_spectra": input_dim_spectra,
        "input_dim_gaia": input_dim_gaia,
        "token_dim_spectra": token_dim_spectra,
        "token_dim_gaia": token_dim_gaia,
        "num_tokens_spectra": (input_dim_spectra + token_dim_spectra - 1) // token_dim_spectra,
        "num_tokens_gaia": (input_dim_gaia + token_dim_gaia - 1) // token_dim_gaia,
        "n_layers": n_layers,
        "n_heads": n_heads,
        "use_cross_attention": True,
        "lr": lr,
        "patience": patience,
        "num_epochs": num_epochs
    }
    wandb.config.update(config)
    
    # Instantiate the tokenized transformer model
    model_transformer = StarClassifierFusionTransformer(
        d_model_spectra=d_model_spectra,
        d_model_gaia=d_model_gaia,
        num_classes=num_classes,
        input_dim_spectra=input_dim_spectra,
        input_dim_gaia=input_dim_gaia,
        token_dim_spectra=token_dim_spectra,
        token_dim_gaia=token_dim_gaia,
        n_layers=n_layers,
        n_heads=n_heads,
        use_cross_attention=True,
        n_cross_attn_heads=8,
        dropout=0.1
    )
    model_transformer.to(device)
    
    # Print model statistics
    print(f"Number of spectra tokens: {(input_dim_spectra + token_dim_spectra - 1) // token_dim_spectra}")
    print(f"Number of gaia tokens: {(input_dim_gaia + token_dim_gaia - 1) // token_dim_gaia}")
    
    # Compute parameter size
    param_size = 0
    for param in model_transformer.parameters():
        param_size += param.nelement() * param.element_size()
    
    # Compute buffer size
    buffer_size = 0
    for buffer in model_transformer.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()

    print(model_fusion)
    # print number of parameters per layer
    for name, param in model_fusion.named_parameters():
        print(name, param.numel())
    
    # Total size in MB
    total_size_mb = (param_size + buffer_size) / (1024 ** 2)
    print(f"Model size: {total_size_mb:.3f} MB")
    print(f"Total parameters: {sum(p.numel() for p in model_transformer.parameters() if p.requires_grad)}")
    
    # Train the transformer model
    trained_transformer_model = train_model_fusion(
        model=model_transformer,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        num_epochs=num_epochs,
        lr=lr,
        max_patience=patience,
        device=device
    )
    wandb.finish()
    
    # Save the model
    torch.save(trained_transformer_model.state_dict(), "Models/model_fusion_transformer_tokenized.pth")

Device: cuda


[34m[1mwandb[0m: Currently logged in as: [33mjoaocsgalmeida[0m ([33mjoaocsgalmeida-university-of-southampton[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Number of spectra tokens: 19
Number of gaia tokens: 18
Model size: 4052.549 MB
Total parameters: 1041377335
Early stopping triggered.


0,1
epoch,▁▁▁▁▁▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇███
hamming_loss,█████▆▇▆▆▆▆▅▅▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
lr,▁▁▁▂▂▄▅▆▆▇▇████████▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▄▄▃▃▃▂
macro_f1,▁▁▁▁▁▁▁▁▂▃▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇████████████
macro_precision,██████▇▆▆▄▃▃▂▂▂▂▂▂▂▂▁▁▂▁▂▁▁▁▁▂▁▁▁▁▁▂▂▁▂▂
macro_recall,▇▁▁▁▁▁▁▁▁▁▁▃▃▄▄▄▅▄▅▅▅▅▅▆▆▇▇▇▇▇▇█████████
micro_f1,▁▁▁▁▁▁▁▁▁▁▁▂▂▃▃▄▄▄▆▆▆▇▇▇▇▇▇▇████████████
micro_precision,▁█████▇▆▆▆▆▆▆▆▆▆▆▇▆▆▆▆▆▆▆▆▆▆▆▆▆▇▆▆▆▆▆▆▆▆
micro_recall,▅▂▁▁▁▁▁▁▁▁▁▁▁▁▁▂▃▃▃▃▃▄▄▄▅▇▇▇▆▇▇▇▇███████
test_acc,▁███████████████████████████████████████

0,1
epoch,659.0
hamming_loss,0.02185
lr,0.0
macro_f1,0.44726
macro_precision,0.6264
macro_recall,0.42024
micro_f1,0.63076
micro_precision,0.77981
micro_recall,0.52955
test_acc,0.97823


# Transformer with 522 + 18 Tokens

In [None]:
if __name__ == "__main__":
    # Configuration for tokenized transformer model
    d_model_spectra = 1536
    d_model_gaia = 1536
    num_classes = 55
    input_dim_spectra = 3647
    input_dim_gaia = 18
    
    # Token dimensions - these control the sequence length
    token_dim_spectra = 7 # Will create ~522 tokens for spectra (3647/7)
    token_dim_gaia = 1      # Will create 18 tokens for gaia (18/1)
    
    n_layers = 10 # Reduced number of layers since we have more tokens
    n_heads = 8
    
    lr = 1e-6
    patience = 200
    num_epochs = 800
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Device:", device)
    
    # Initialize WandB
    wandb.init(project="ALLSTARS_multimodal_fusion_transformer_522")
    config = {
        "model_type": "transformer_tokenized",
        "num_classes": num_classes,
        "d_model_spectra": d_model_spectra,
        "d_model_gaia": d_model_gaia,
        "input_dim_spectra": input_dim_spectra,
        "input_dim_gaia": input_dim_gaia,
        "token_dim_spectra": token_dim_spectra,
        "token_dim_gaia": token_dim_gaia,
        "num_tokens_spectra": (input_dim_spectra + token_dim_spectra - 1) // token_dim_spectra,
        "num_tokens_gaia": (input_dim_gaia + token_dim_gaia - 1) // token_dim_gaia,
        "n_layers": n_layers,
        "n_heads": n_heads,
        "use_cross_attention": True,
        "lr": lr,
        "patience": patience,
        "num_epochs": num_epochs
    }
    wandb.config.update(config)
    
    # Instantiate the tokenized transformer model
    model_transformer = StarClassifierFusionTransformer(
        d_model_spectra=d_model_spectra,
        d_model_gaia=d_model_gaia,
        num_classes=num_classes,
        input_dim_spectra=input_dim_spectra,
        input_dim_gaia=input_dim_gaia,
        token_dim_spectra=token_dim_spectra,
        token_dim_gaia=token_dim_gaia,
        n_layers=n_layers,
        n_heads=n_heads,
        use_cross_attention=True,
        n_cross_attn_heads=8,
        dropout=0.1
    )
    model_transformer.to(device)
    
    # Print model statistics
    print(f"Number of spectra tokens: {(input_dim_spectra + token_dim_spectra - 1) // token_dim_spectra}")
    print(f"Number of gaia tokens: {(input_dim_gaia + token_dim_gaia - 1) // token_dim_gaia}")
    
    # Compute parameter size
    param_size = 0
    for param in model_transformer.parameters():
        param_size += param.nelement() * param.element_size()
    
    # Compute buffer size
    buffer_size = 0
    for buffer in model_transformer.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
    print(model_transformer)
    # print number of parameters per layer
    for name, param in model_transformer.named_parameters():
        print(name, param.numel())
    
    # Total size in MB
    total_size_mb = (param_size + buffer_size) / (1024 ** 2)
    print(f"Model size: {total_size_mb:.3f} MB")
    print(f"Total parameters: {sum(p.numel() for p in model_transformer.parameters() if p.requires_grad)}")
    
    # Train the transformer model
    trained_transformer_model = train_model_fusion(
        model=model_transformer,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        num_epochs=num_epochs,
        lr=lr,
        max_patience=patience,
        device=device
    )
    wandb.finish()
    
    # Save the model
    torch.save(trained_transformer_model.state_dict(), "Models/model_fusion_transformer_tokenized_many_tokens.pth")

Device: cuda


[34m[1mwandb[0m: Currently logged in as: [33mjoaocsgalmeida[0m ([33mjoaocsgalmeida-university-of-southampton[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Number of spectra tokens: 521
Number of gaia tokens: 18
StarClassifierFusionTransformer(
  (tokenizer_spectra): FeatureTokenizer(
    (token_embed): Linear(in_features=7, out_features=1536, bias=True)
  )
  (tokenizer_gaia): FeatureTokenizer(
    (token_embed): Linear(in_features=1, out_features=1536, bias=True)
  )
  (transformer_spectra): TransformerFeatureExtractor(
    (layers): ModuleList(
      (0-9): 10 x TransformerBlock(
        (norm1): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
        (attn): RotarySelfAttention(
          (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
          (k_proj): Linear(in_features=1536, out_features=1536, bias=True)
          (v_proj): Linear(in_features=1536, out_features=1536, bias=True)
          (out_proj): Linear(in_features=1536, out_features=1536, bias=True)
          (rope): RotaryEmbedding()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (norm2): LayerNorm((1536,), eps=1e-05, elementwise

0,1
epoch,▁▁▁▁▁▁▁▂▂▂▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇██
hamming_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,▁▁▂▂▂▃▄▄▆▇█████▇▇▇▆▆▆▅▅▅▄▄▄▃▃▃▂▂▂▂▂▁▁▁▁▁
macro_f1,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▃▄▄▃▃▅▅▆▆▆▇▇▇▇▇▇▇▇▇▇▇██▇█
macro_precision,▅█████████▆▇▇▇▆▇▆▇▆▆▄▄▄▃▁▃▂▁▁▂▂▂▂▂▂▃▂▂▂▂
macro_recall,█▅▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▃▃▃▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄
micro_f1,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▃▂▃▄▅▃▅▅▄▆▆▆▇▇▇▇███▇████
micro_precision,▁█████████▅▂▄▅▆▆▆▅▅▆▅▆▆▅▆▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅
micro_recall,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▃▃▅▅▅▆▅▇▇▇▇▇▇▇▇▇▇▇████
test_acc,▂▃▃▃▃▃▃▃▃▃▃▂▁▄▃▄▄▄▅▅▅▅▆▆▆▆▇▆██▇▇▇▇▇▇████

0,1
epoch,799.0
hamming_loss,0.03313
lr,0.0
macro_f1,0.27559
macro_precision,0.68441
macro_recall,0.23068
micro_f1,0.34986
micro_precision,0.63184
micro_recall,0.2419
test_acc,0.96671


# MAMBA2

# MAMBAIN (MAMBA2) 522 + 18

In [None]:
if __name__ == "__main__":
    # Example config
    d_model_spectra = 1536
    d_model_gaia = 1536
    num_classes = 55
    input_dim_spectra = 3647
    input_dim_gaia = 18
    n_layers = 20
    d_state = 16 # State dimension for MAMBA
    d_conv = 4  # Convolution dimension for MAMBA
    expand = 2  # Expansion factor for MAMBA

    # Token dimensions - these control the sequence length
    token_dim_spectra = 7    # Will create 522 tokens for spectra (3647/7)
    token_dim_gaia = 1       # Will create 18 tokens for gaia (18/1)

    lr = 2.5e-6
    patience = 200
    num_epochs = 800
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Device:", device)

    # Initialize WandB
    wandb.init(project="ALLSTARS_multimodal_fusion_mamba_v2")
    
    config = {
        "num_classes": num_classes,
        "d_model_spectra": d_model_spectra,
        "d_model_gaia": d_model_gaia,
        "input_dim_spectra": input_dim_spectra,
        "input_dim_gaia": input_dim_gaia,
        "n_layers": n_layers,
        "lr": lr,
        "patience": patience,
        "num_epochs": num_epochs,
        "d_state": d_state,
        "d_conv": d_conv,
        "expand": expand,
        "token_dim_gaia": token_dim_gaia,
        "token_dim_spectra": token_dim_spectra,
    }
    wandb.config.update(config)

    # Instantiate the fusion model
    # Try use_cross_attention=False for late-fusion, True for cross-attention
    model_fusion = StarClassifierFusionMambaTokenized(
        d_model_spectra=d_model_spectra,
        d_model_gaia=d_model_gaia,
        num_classes=num_classes,
        input_dim_spectra=input_dim_spectra,
        input_dim_gaia=input_dim_gaia,
        token_dim_spectra=token_dim_spectra,
        token_dim_gaia=token_dim_gaia,
        n_layers=n_layers,
        use_cross_attention=True,  # set to False to compare with late fusion
        n_cross_attn_heads=8
    )
    model_fusion.to(device)

    # Print size of model in GB
    print(f"Model size: {sum(p.numel() for p in model_fusion.parameters()) / 1e9:.2f} GB")
    param_size = 0
    for param in model_fusion.parameters():
        param_size += param.nelement() * param.element_size()
    buffer_size = 0
    for buffer in model_fusion.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()

    size_all_mb = (param_size + buffer_size) / 1024**2
    print('model size: {:.3f}MB'.format(size_all_mb))
    # Compute parameter size
    param_size = sum(p.nelement() * p.element_size() for p in model_fusion.parameters())

    # Compute buffer size
    buffer_size = sum(b.nelement() * b.element_size() for b in model_fusion.buffers())

    # Total size in MB
    total_size_mb = (param_size + buffer_size) / (1024 ** 2)
    print(f"Model size: {total_size_mb:.3f} MB")

    print(model_fusion)
    # print number of parameters per layer
    for name, param in model_fusion.named_parameters():
        print(name, param.numel())
    print("Total number of parameters:", sum(p.numel() for p in model_fusion.parameters() if p.requires_grad))

    # Train the fusion model
    trained_fusion_model = train_model_fusion(
        model=model_fusion,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        num_epochs=num_epochs,
        lr=lr,
        max_patience=patience,
        device=device
    )

    wandb.finish()

# Save the model
torch.save(trained_fusion_model.state_dict(), "Models/model_fusion_mamba_maxtokens.pth")


Device: cuda
Model size: 0.62 GB
model size: 2367.272MB
Model size: 2367.272 MB
StarClassifierFusionMambaTokenized(
  (tokenizer_spectra): FeatureTokenizer(
    (token_embed): Linear(in_features=7, out_features=1536, bias=True)
  )
  (tokenizer_gaia): FeatureTokenizer(
    (token_embed): Linear(in_features=1, out_features=1536, bias=True)
  )
  (mamba_spectra): Sequential(
    (0): Mamba2(
      (in_proj): Linear(in_features=1536, out_features=6704, bias=False)
      (conv1d): Conv1d(3584, 3584, kernel_size=(4,), stride=(1,), padding=(3,), groups=3584)
      (act): SiLU()
      (norm): RMSNorm()
      (out_proj): Linear(in_features=3072, out_features=1536, bias=False)
    )
    (1): Mamba2(
      (in_proj): Linear(in_features=1536, out_features=6704, bias=False)
      (conv1d): Conv1d(3584, 3584, kernel_size=(4,), stride=(1,), padding=(3,), groups=3584)
      (act): SiLU()
      (norm): RMSNorm()
      (out_proj): Linear(in_features=3072, out_features=1536, bias=False)
    )
    (2): M

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇█████
hamming_loss,█▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,▁▁▁▁▁▂▂▃▃▄▄▄▅▅▆▆▇▇▇▇▇▇▇███████▇▇▇▇▇▇▆▆▆▆
macro_f1,▁▁▁▁▁▁▁▂▂▂▅▄▄▆▆▆▇▆▇▆▇▇█▇▇▇█▇█████▇███▇██
macro_precision,▁▅▇█████████▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
macro_recall,▁▁▁▁▁▂▃▄▄▄▅▄▅▄▅▅▅▆▅▆▆▆▇▆▇▆▆▇▆▇▇▇█▇█▇█▇▇█
micro_f1,▁▁▁▁▁▁▁▁▁▂▂▂▂▂▃▄▅▄▆▆▆▅▆▇▇▇▇▇▇▇▇▇▇█████▇█
micro_precision,▁▁▁█████▅▅▆▆▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
micro_recall,▁▁▁▁▁▁▁▁▁▂▂▄▄▄▅▅▅▅▅▆▆▅▅▇▆▇▇▇▇▆▇▇▇██▇▇▇▇█
test_acc,▁▁▁▁▁▁▁▁▁▁▂▃▂▃▄▃▅▅▆▅▅▇▇▇▆▇▆▆▆▇▇▆▇▇▆██▇██

0,1
epoch,449.0
hamming_loss,0.02646
lr,0.0
macro_f1,0.33348
macro_precision,0.88376
macro_recall,0.27981
micro_f1,0.49046
micro_precision,0.83333
micro_recall,0.34749
test_acc,0.97299


# Mamba 19 + 18 Tokens

In [None]:
if __name__ == "__main__":
    # Clean Cache of GPU
    torch.cuda.empty_cache()

    # Example config
    d_model_spectra = 2048
    d_model_gaia = 2048
    num_classes = 55
    input_dim_spectra = 3647
    input_dim_gaia = 18
    n_layers = 20
    d_state = 32 # State dimension for MAMBA
    d_conv = 4  # Convolution dimension for MAMBA
    expand = 2  # Expansion factor for MAMBA

    # Token dimensions - these control the sequence length
    token_dim_spectra = 192  # Will create ~19 tokens for spectra (3647/192)
    token_dim_gaia = 1       # Will create 18 tokens for gaia (18/1)

    lr = 2.5e-6
    patience = 200
    num_epochs = 800
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Device:", device)

    # Initialize WandB
    wandb.init(project="ALLSTARS_multimodal_fusion_mamba_v2")
    
    config = {
        "num_classes": num_classes,
        "d_model_spectra": d_model_spectra,
        "d_model_gaia": d_model_gaia,
        "input_dim_spectra": input_dim_spectra,
        "input_dim_gaia": input_dim_gaia,
        "n_layers": n_layers,
        "lr": lr,
        "patience": patience,
        "num_epochs": num_epochs
    }
    wandb.config.update(config)

    # Instantiate the fusion model
    # Try use_cross_attention=False for late-fusion, True for cross-attention
    model_fusion = StarClassifierFusionMambaTokenized(
        d_model_spectra=d_model_spectra,
        d_model_gaia=d_model_gaia,
        num_classes=num_classes,
        input_dim_spectra=input_dim_spectra,
        input_dim_gaia=input_dim_gaia,
        n_layers=n_layers,
        use_cross_attention=True,  # set to False to compare with late fusion
        n_cross_attn_heads=8
    )
    model_fusion.to(device)

    # Print size of model in GB
    print(f"Model size: {sum(p.numel() for p in model_fusion.parameters()) / 1e9:.2f} GB")
    param_size = 0
    for param in model_fusion.parameters():
        param_size += param.nelement() * param.element_size()
    buffer_size = 0
    for buffer in model_fusion.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()

    size_all_mb = (param_size + buffer_size) / 1024**2
    print('model size: {:.3f}MB'.format(size_all_mb))
    # Compute parameter size
    param_size = sum(p.nelement() * p.element_size() for p in model_fusion.parameters())

    # Compute buffer size
    buffer_size = sum(b.nelement() * b.element_size() for b in model_fusion.buffers())

    # Total size in MB
    total_size_mb = (param_size + buffer_size) / (1024 ** 2)
    print(f"Model size: {total_size_mb:.3f} MB")

    print(model_fusion)
    # print number of parameters per layer
    for name, param in model_fusion.named_parameters():
        print(name, param.numel())
    print("Total number of parameters:", sum(p.numel() for p in model_fusion.parameters() if p.requires_grad))

    # Train the fusion model
    trained_fusion_model = train_model_fusion(
        model=model_fusion,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        num_epochs=num_epochs,
        lr=lr,
        max_patience=patience,
        device=device
    )

    wandb.finish()

# Save the model
torch.save(trained_fusion_model.state_dict(), "Models/model_fusion_mamba_19_18.pth")


Device: cuda


[34m[1mwandb[0m: Currently logged in as: [33mjoaocsgalmeida[0m ([33mjoaocsgalmeida-university-of-southampton[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Model size: 1.09 GB
model size: 4153.686MB
Model size: 4153.686 MB
StarClassifierFusionMambaTokenized(
  (tokenizer_spectra): FeatureTokenizer(
    (token_embed): Linear(in_features=64, out_features=2048, bias=True)
  )
  (tokenizer_gaia): FeatureTokenizer(
    (token_embed): Linear(in_features=2, out_features=2048, bias=True)
  )
  (mamba_spectra): Sequential(
    (0): Mamba2(
      (in_proj): Linear(in_features=2048, out_features=8768, bias=False)
      (conv1d): Conv1d(4608, 4608, kernel_size=(4,), stride=(1,), padding=(3,), groups=4608)
      (act): SiLU()
      (norm): RMSNorm()
      (out_proj): Linear(in_features=4096, out_features=2048, bias=False)
    )
    (1): Mamba2(
      (in_proj): Linear(in_features=2048, out_features=8768, bias=False)
      (conv1d): Conv1d(4608, 4608, kernel_size=(4,), stride=(1,), padding=(3,), groups=4608)
      (act): SiLU()
      (norm): RMSNorm()
      (out_proj): Linear(in_features=4096, out_features=2048, bias=False)
    )
    (2): Mamba2(
     

0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▄▄▄▄▄▄▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇████
hamming_loss,█▇▇▇▇▇▇▇▇▇▅▄▄▄▄▃▃▂▃▂▃▃▃▃▃▂▂▂▂▁▂▁▂▂▁▁▁▁▂▁
lr,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▆▇▇▇▇████████▇▇▇▇▇▇▇▆▆▆▅▅
macro_f1,▁▁▁▂▂▃▅▄▅▄▆▆▆▆▆▆▆▆▆▇▇▇▇▆▇▇▇▇▇█▇▇▇▇▇██▇██
macro_precision,▁▅█████████████▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
macro_recall,▂▁▁▂▂▃▄▄▄▄▅▆▆▅▆▆▆▆▆▇▆▇▇█▇█▇▇█▇▇▇▇████▇██
micro_f1,▂▂▁▁▁▁▂▂▃▃▄▄▅▆▅▆▇▆▆▇▇▇▆▇▇▇▇▇▇▇▇█▇▇██████
micro_precision,▁██████▇▇▆▇▆▇▆▇▆▆▆▆▆▅▆▆▆▅▆▅▅▆▅▅▅▅▅▅▅▅▅▅▅
micro_recall,▂▂▁▁▁▂▃▃▄▄▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇██▇▇████████
test_acc,▁▁▁▁▁▁▁▁▂▂▃▃▄▅▅▅▅▆▆▆▅▅▆▆▆▆▆▅▆▇▆██▇▇▇▇█▇▇

0,1
epoch,521.0
hamming_loss,0.02368
lr,0.0
macro_f1,0.40767
macro_precision,0.87328
macro_recall,0.35113
micro_f1,0.58924
micro_precision,0.82534
micro_recall,0.45817
test_acc,0.97614


# Mamba 1 token

In [None]:
if __name__ == "__main__":
    # Clean Cache of GPU
    torch.cuda.empty_cache()

    # Example config
    d_model_spectra = 2048
    d_model_gaia = 2048
    num_classes = 55
    input_dim_spectra = 3647
    input_dim_gaia = 18
    n_layers = 20
    d_state = 32 # State dimension for MAMBA
    d_conv = 2  # Convolution dimension for MAMBA
    expand = 2  # Expansion factor for MAMBA

    # Token dimensions - these control the sequence length
    token_dim_spectra = 3647  # Will create 1 token for spectra (3647/3647)
    token_dim_gaia = 18       # Will create 1 token for gaia (18/1)

    lr = 2.5e-6
    patience = 200
    num_epochs = 800
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Device:", device)

    # Initialize WandB
    wandb.init(project="ALLSTARS_multimodal_fusion_mamba_v2")
    
    config = {
        "num_classes": num_classes,
        "d_model_spectra": d_model_spectra,
        "d_model_gaia": d_model_gaia,
        "input_dim_spectra": input_dim_spectra,
        "input_dim_gaia": input_dim_gaia,
        "n_layers": n_layers,
        "lr": lr,
        "patience": patience,
        "num_epochs": num_epochs
    }
    wandb.config.update(config)

    # Instantiate the fusion model
    # Try use_cross_attention=False for late-fusion, True for cross-attention
    model_fusion = StarClassifierFusionMambaTokenized(
        d_model_spectra=d_model_spectra,
        d_model_gaia=d_model_gaia,
        num_classes=num_classes,
        input_dim_spectra=input_dim_spectra,
        input_dim_gaia=input_dim_gaia,
        n_layers=n_layers,
        use_cross_attention=True,  # set to False to compare with late fusion
        n_cross_attn_heads=8
    )
    model_fusion.to(device)

    # Print size of model in GB
    print(f"Model size: {sum(p.numel() for p in model_fusion.parameters()) / 1e9:.2f} GB")
    param_size = 0
    for param in model_fusion.parameters():
        param_size += param.nelement() * param.element_size()
    buffer_size = 0
    for buffer in model_fusion.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()

    size_all_mb = (param_size + buffer_size) / 1024**2
    print('model size: {:.3f}MB'.format(size_all_mb))
    # Compute parameter size
    param_size = sum(p.nelement() * p.element_size() for p in model_fusion.parameters())

    # Compute buffer size
    buffer_size = sum(b.nelement() * b.element_size() for b in model_fusion.buffers())

    # Total size in MB
    total_size_mb = (param_size + buffer_size) / (1024 ** 2)
    print(f"Model size: {total_size_mb:.3f} MB")

    print(model_fusion)
    # print number of parameters per layer
    for name, param in model_fusion.named_parameters():
        print(name, param.numel())
    print("Total number of parameters:", sum(p.numel() for p in model_fusion.parameters() if p.requires_grad))

    # Train the fusion model
    trained_fusion_model = train_model_fusion(
        model=model_fusion,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        num_epochs=num_epochs,
        lr=lr,
        max_patience=patience,
        device=device
    )

    wandb.finish()

# Save the model
torch.save(trained_fusion_model.state_dict(), "Models/model_fusion_mamba_1_token.pth")

Device: cuda


[34m[1mwandb[0m: Currently logged in as: [33mjoaocsgalmeida[0m ([33mjoaocsgalmeida-university-of-southampton[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Model size: 1.09 GB
model size: 4153.686MB
Model size: 4153.686 MB
StarClassifierFusionMambaTokenized(
  (tokenizer_spectra): FeatureTokenizer(
    (token_embed): Linear(in_features=64, out_features=2048, bias=True)
  )
  (tokenizer_gaia): FeatureTokenizer(
    (token_embed): Linear(in_features=2, out_features=2048, bias=True)
  )
  (mamba_spectra): Sequential(
    (0): Mamba2(
      (in_proj): Linear(in_features=2048, out_features=8768, bias=False)
      (conv1d): Conv1d(4608, 4608, kernel_size=(4,), stride=(1,), padding=(3,), groups=4608)
      (act): SiLU()
      (norm): RMSNorm()
      (out_proj): Linear(in_features=4096, out_features=2048, bias=False)
    )
    (1): Mamba2(
      (in_proj): Linear(in_features=2048, out_features=8768, bias=False)
      (conv1d): Conv1d(4608, 4608, kernel_size=(4,), stride=(1,), padding=(3,), groups=4608)
      (act): SiLU()
      (norm): RMSNorm()
      (out_proj): Linear(in_features=4096, out_features=2048, bias=False)
    )
    (2): Mamba2(
     

0,1
epoch,▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇███
hamming_loss,██████████▇▇▇▆▆▆▅▅▅▅▄▄▄▄▄▄▃▂▃▂▃▃▂▂▂▂▁▂▁▁
lr,▁▁▁▁▁▂▂▂▃▃▃▄▅▅▆▆▇▇▇▇██████████████▇▇▇▆▆▆
macro_f1,▁▁▁▁▂▄▄▄▄▄▅▅▅▅▅▅▅▅▅▆▅▆▆▆▆▇▆▇▆▆▇▇▇▇██▇▇▇█
macro_precision,▁█████▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
macro_recall,▁▁▁▁▁▁▁▁▁▂▂▂▃▃▃▅▅▄▄▄▅▅▅▅▆▆▆▆▇▇▇▆█▇██▇███
micro_f1,▁▁▁▁▁▁▁▁▂▃▄▄▅▅▅▅▆▆▆▇▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇█▇▇█
micro_precision,▁█████▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▇▇
micro_recall,▆▂▁▁▁▁▁▂▂▃▃▄▅▅▅▆▆▆▆▆▇▇▆▆▇▇▇▇▇▇▇▇█▇██▇███
test_acc,▁▁▁▁▁▁▁▁▂▂▃▃▄▄▄▅▅▅▅▅▆▇▆▆▆▆▇▇▇▆▇▇▇▇▇▇▇▇▇█

0,1
epoch,476.0
hamming_loss,0.02393
lr,0.0
macro_f1,0.38904
macro_precision,0.8911
macro_recall,0.33105
micro_f1,0.56739
micro_precision,0.83712
micro_recall,0.42913
test_acc,0.97607
