In [1]:
import gc
import wandb
import torch
import mambapy
import numpy as np
import pandas as pd
from torch import nn, optim
import matplotlib.pyplot as plt
import sklearn.metrics as metrics
from mambapy.mamba import Mamba, MambaConfig
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, precision_score, recall_score, hamming_loss, roc_auc_score

In [2]:
class BalancedMultiLabelDataset(Dataset):
    def __init__(self, X, y, limit_per_label=201):
        """
        Multilabel version of the BalancedDataset.
        
        Args:
        - X (array-like): Input features.
        - y (array-like): Multi-hot encoded labels (2D array, each row is a multi-hot vector).
        - limit_per_label (int): Target number of samples per label.
        """
        self.X = X
        self.y = y
        self.limit_per_label = limit_per_label
        self.num_classes = y.shape[1]  # Number of possible classes
        self.indices = self.balance_classes()

    def balance_classes(self):
        indices = []
        class_counts = torch.sum(self.y, axis=0)  # Total occurrences of each class
        for cls in range(self.num_classes):
            cls_indices = np.where(self.y[:, cls] == 1)[0]  # Indices where this label is active
            if len(cls_indices) < self.limit_per_label:  # Upsample minority classes
                if len(cls_indices) == 0:
                    #print(f"No samples found for class {cls}. Skipping.")
                    continue  # Skip this class if there are no samples for it
                extra_indices = np.random.choice(cls_indices, self.limit_per_label - len(cls_indices), replace=True)
                cls_indices = np.concatenate([cls_indices, extra_indices])
            elif len(cls_indices) > self.limit_per_label:  # Downsample majority classes
                cls_indices = np.random.choice(cls_indices, self.limit_per_label, replace=False)
            indices.extend(cls_indices)
        indices = np.unique(indices)  # Remove duplicate indices
        np.random.shuffle(indices)
        return indices

    def re_sample(self):
        """Rebalance the dataset if needed, for example, after changes to the dataset."""
        self.indices = self.balance_classes()

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        index = self.indices[idx]
        return self.X[index], self.y[index]

    
def calculate_metrics(y_true, y_pred):
    metrics = {
        "micro_f1": f1_score(y_true, y_pred, average='micro'),
        "macro_f1": f1_score(y_true, y_pred, average='macro'),
        "weighted_f1": f1_score(y_true, y_pred, average='weighted'),
        "micro_precision": precision_score(y_true, y_pred, average='micro', zero_division=1),
        "macro_precision": precision_score(y_true, y_pred, average='macro', zero_division=1),
        "weighted_precision": precision_score(y_true, y_pred, average='weighted', zero_division=1),
        "micro_recall": recall_score(y_true, y_pred, average='micro'),
        "macro_recall": recall_score(y_true, y_pred, average='macro'),
        "weighted_recall": recall_score(y_true, y_pred, average='weighted'),
        "hamming_loss": hamming_loss(y_true, y_pred)
    }
    
    # Check if there are at least two classes present in y_true
    #if len(np.unique(y_true)) > 1:
        #metrics["roc_auc"] = roc_auc_score(y_true, y_pred, average='macro', multi_class='ovr')
    #else:
       # metrics["roc_auc"] = None  # or you can set it to a default value or message
    
    return metrics
    
def calculate_class_weights(y):
    if y.ndim > 1:  # Check if y is 2D (multi-hot encoded)
        class_counts = np.sum(y, axis=0)  # Count how many times each class appears
    else:
        class_counts = np.bincount(y)  # For a 1D array, use bincount 
    total_samples = y.shape[0] if y.ndim > 1 else len(y)
    class_weights = np.where(class_counts > 0, total_samples / (len(class_counts) * class_counts), 0)
    return class_weights

def calculate_class_weights(y):
    if y.ndim > 1:  
        class_counts = np.sum(y, axis=0)  
    else:
        class_counts = np.bincount(y)

    total_samples = y.shape[0] if y.ndim > 1 else len(y)
    class_counts = np.where(class_counts == 0, 1, class_counts)  # Prevent division by zero
    class_weights = total_samples / (len(class_counts) * class_counts)
    
    return class_weights

    
def train_model_mamba(
    model, train_loader, val_loader, test_loader, 
    num_epochs=500, lr=1e-4, max_patience=20, device='cuda'
):
    # Move model to device
    model = model.to(device)

    # Define optimizer, scheduler, and loss function
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=int(max_patience / 5)
    )
    all_labels = []

    for _, y_batch in train_loader:
        all_labels.extend(y_batch.cpu().numpy())    
        
    class_weights = calculate_class_weights(np.array(all_labels))
    class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
    #print("Class weights:", class_weights)
    criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights)

    best_val_loss = float('inf')
    patience = max_patience

    for epoch in range(num_epochs):
        # Resample training and validation data
        train_loader.dataset.re_sample()
        #val_loader.dataset.balance_classes() should remane the same?

        # Class weights
        all_labels = []
        for _, y_batch in train_loader:
            all_labels.extend(y_batch.cpu().numpy())
        class_weights = calculate_class_weights(np.array(all_labels))
        class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
        criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights)

        # Training phase
        model.train()
        train_loss, train_accuracy = 0.0, 0.0

        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * X_batch.size(0)
            # Convert outputs to binary predictions
            predicted = (torch.sigmoid(outputs) > 0.5).float()


            # Calculate accuracy for each batch
            correct = (predicted == y_batch).float()
            train_accuracy += correct.mean(dim=1).mean().item()  # Mean across classes and samples

            

        # Validation phase
        model.eval()
        val_loss, val_accuracy = 0.0, 0.0
        with torch.no_grad():
            for X_val, y_val in val_loader:
                X_val, y_val = X_val.to(device), y_val.to(device)
                outputs = model(X_val)
                loss = criterion(outputs, y_val)

                val_loss += loss.item() * X_val.size(0)
                predicted = (torch.sigmoid(outputs) > 0.5).float()
                correct = (predicted == y_val).float()
                val_accuracy += correct.mean(dim=1).mean().item()

        # Test phase
        test_loss, test_accuracy = 0.0, 0.0
        with torch.no_grad():
            for X_test, y_test in test_loader:
                X_test, y_test = X_test.to(device), y_test.to(device)
                outputs = model(X_test)
                loss = criterion(outputs, y_test)

                test_loss += loss.item() * X_test.size(0)
                predicted = (torch.sigmoid(outputs) > 0.5).float()
                correct = (predicted == y_test).float()
                test_accuracy += correct.mean(dim=1).mean().item()

        # Test phase and metric collection
        # Inside your test phase
        y_true, y_pred = [], []
        with torch.no_grad():
            for X_test, y_test in test_loader:
                X_test, y_test = X_test.to(device), y_test.to(device)
                outputs = model(X_test)
                predicted = (torch.sigmoid(outputs) > 0.5).float()
                y_true.extend(y_test.cpu().numpy())
                y_pred.extend(predicted.cpu().numpy())

        metrics = calculate_metrics(np.array(y_true), np.array(y_pred))
        wandb.log(metrics)

        # Update scheduler
        scheduler.step(val_loss / len(val_loader.dataset))

        # Log metrics to WandB
        wandb.log({
            "epoch": epoch,
            "train_loss": train_loss / len(train_loader.dataset),
            "val_loss": val_loss / len(val_loader.dataset),
            "train_accuracy": train_accuracy / len(train_loader),
            "val_accuracy": val_accuracy / len(val_loader),
            "learning_rate": optimizer.param_groups[0]['lr'],
            "test_loss": test_loss / len(test_loader.dataset),
            "test_accuracy": test_accuracy / len(test_loader),
            #"confusion_matrix": wandb.plot.confusion_matrix(
            #   probs=None, y_true=y_true, preds=y_pred, class_names=np.unique(y_true)
            #), remove for now as it is not multilabel
            #"classification_report": classification_report(
            #    y_true, y_pred, target_names=[str(i) for i in range(len(np.unique(y_true)))]
            #)
        })

        #for name, param in model.named_parameters():
          #  if param.grad is None:
          #      print(f"Warning: {name} has no gradient!")

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience = max_patience
            best_model = model.state_dict()
        else:
            patience -= 1
            if patience <= 0:
                print("Early stopping triggered.")
                break

    # Load the best model weights
    model.load_state_dict(best_model)
    return model

class StarClassifierMAMBA(nn.Module):
    def __init__(self, d_model, num_classes, d_state=64, d_conv=4, input_dim=17, n_layers=6):
        super(StarClassifierMAMBA, self).__init__()
        self.d_model = d_model
        self.num_classes = num_classes

        # MAMBA layer initialization
        config = MambaConfig(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            n_layers=n_layers

        )
        self.mamba_layer = Mamba(config)

        # Input projection to match the MAMBA layer dimension
        self.input_projection = nn.Linear(input_dim, d_model)

        # Fully connected classifier head with sigmoid activation for multi-label classification
        self.classifier = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, num_classes)
        )


    def forward(self, x):
        x = self.input_projection(x)  # Ensure correct embedding dimension

        # Reshape to (batch_size, sequence_length, d_model) for Mamba
        x = x.unsqueeze(1)  # Adding a sequence dimension, making it (batch_size, 1, d_model)

        x = self.mamba_layer(x)  # Now the input shape is correct
        x = x.mean(dim=1)  # Pooling operation for classification
        x = self.classifier(x)  # Classification head

        return x


In [3]:
# If X exists, delete it
if 'X' in locals():   
    del X, y
gc.collect()

batch_size = 512

# Example usage
if __name__ == "__main__":
    # Load and preprocess your data (example from original script)
    X = pd.read_pickle("Pickles/train_data_transformed2.pkl")
    classes = pd.read_pickle("Pickles/Updated_list_of_Classes.pkl")

    # Get labels and set them as y, drop them from X
    y = X[classes]

        # Print the shape of the data
    print(f"Data shape: {X.shape}")

    # Use Gaia data
    X = X[["parallax", "ra", "dec", "ra_error", "dec_error", "parallax_error", "pmra", "pmdec", "pmra_error", "pmdec_error",
            "phot_g_mean_flux", "flagnopllx", "phot_g_mean_flux_error", "phot_bp_mean_flux", "phot_rp_mean_flux",
            "phot_bp_mean_flux_error", "phot_rp_mean_flux_error", "flagnoflux", 'RS*', 
            '**', 'El*', 'Y*O', 's*b', 'cC*', 'HB*', 'dS*', 'Or*', 'LP*', 'BS*', 'Ae*', 'WV*', 'HS*', 'Ev*', 'AB*', 'sg*', 's*r', 'Ce*', 'gD*', 'OH*', 'HXB', 'Pu*', 'RV*', 'Sy*', 'V*', 'TT*', 'SN*', 'Be*', 'SB*', 'Em*', 'Er*', 'PM*', 'HV*', 'pA*', 'C*', 'BY*', 'Ro*', 'XB*', 'Ma*', 'Pe*', 'CV*', 'bC*', 'RR*', 'Mi*', 'SX*', 'RG*', 'LM*', 'WD*', 'S*', 'MS*', 'Ir*', 'a2*', 'PN', 'EB*']]

    # Print the shape of the data
    print(f"Data shape: {X.shape}")
    
    # Drop labels
    X.drop(classes, axis=1, inplace=True)
    
    # Read test data
    X_test = pd.read_pickle("Pickles/test_data_transformed.pkl")

    # Get labels and set them as y, drop them from X
    y_test = X_test[classes]

    # Drop gaia data
    X_test = X_test[["parallax", "ra", "dec", "ra_error", "dec_error", "parallax_error", "pmra", "pmdec", "pmra_error", "pmdec_error", 
                "phot_g_mean_flux", "flagnopllx", "phot_g_mean_flux_error", "phot_bp_mean_flux", "phot_rp_mean_flux", 
                "phot_bp_mean_flux_error", "phot_rp_mean_flux_error", "flagnoflux",
                'RS*', '**', 'El*', 'Y*O', 's*b', 'cC*', 'HB*', 'dS*', 'Or*', 'LP*', 'BS*', 'Ae*', 'WV*', 'HS*', 'Ev*', 'AB*', 'sg*', 's*r', 'Ce*', 'gD*', 'OH*', 'HXB', 'Pu*', 'RV*', 'Sy*', 'V*', 'TT*', 'SN*', 'Be*', 'SB*', 'Em*', 'Er*', 'PM*', 'HV*', 'pA*', 'C*', 'BY*', 'Ro*', 'XB*', 'Ma*', 'Pe*', 'CV*', 'bC*', 'RR*', 'Mi*', 'SX*', 'RG*', 'LM*', 'WD*', 'S*', 'MS*', 'Ir*', 'a2*', 'PN', 'EB*']]

    # Drop labels
    X_test.drop(classes, axis=1, inplace=True)
    
    # Split validation data
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

    # Clear memory
    del X, y
    gc.collect()

    # Convert to torch tensors and create datasets
    print(X_train.dtypes)

    X_train = torch.tensor(X_train.values, dtype=torch.float32) # Convert DataFrame to numpy array
    X_val = torch.tensor(X_val.values, dtype=torch.float32)     # Convert DataFrame to numpy array    
    X_test = torch.tensor(X_test.values, dtype=torch.float32)   # Convert DataFrame to numpy array
    y_train = torch.tensor(y_train.values, dtype=torch.float32)  # Convert DataFrame to numpy array and float32
    y_val = torch.tensor(y_val.values, dtype=torch.float32)      # Convert DataFrame to numpy array and float32
    y_test = torch.tensor(y_test.values, dtype=torch.float32)    # Convert DataFrame to numpy array and float32

    train_dataset = BalancedMultiLabelDataset(X_train, y_train)
    val_dataset = BalancedMultiLabelDataset(X_val, y_val)
    test_dataset = BalancedMultiLabelDataset(X_test, y_test)


    #train_dataset = BalancedDataset(X_train, y_train)
    #val_dataset = BalancedValidationDataset(X_val, y_val)
    #test_dataset = BalancedValidationDataset(X_test, y_test, limit_per_label=10000)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)   


    # Print the shapes of the datasets
    print(f"Train dataset shape: {X_train.shape}")
    print(f"Validation dataset shape: {X_val.shape}")
    print(f"Test dataset shape: {X_test.shape}")
    print(f"Train labels shape: {y_train.shape}")
    print(f"Validation labels shape: {y_val.shape}")
    print(f"Test labels shape: {y_test.shape}")

Data shape: (108918, 3722)
Data shape: (108918, 73)
parallax                   float64
ra                         float64
dec                        float64
ra_error                   float32
dec_error                  float32
parallax_error             float32
pmra                       float64
pmdec                      float64
pmra_error                 float32
pmdec_error                float32
phot_g_mean_flux           float64
flagnopllx                 float64
phot_g_mean_flux_error     float32
phot_bp_mean_flux          float64
phot_rp_mean_flux          float64
phot_bp_mean_flux_error    float32
phot_rp_mean_flux_error    float32
flagnoflux                 float64
dtype: object
Train dataset shape: torch.Size([87134, 18])
Validation dataset shape: torch.Size([21784, 18])
Test dataset shape: torch.Size([27237, 18])
Train labels shape: torch.Size([87134, 55])
Validation labels shape: torch.Size([21784, 55])
Test labels shape: torch.Size([27237, 55])


In [None]:
# Same, but now with class weights
# Define the model with your parameters
d_model = 1024 # Embedding dimension
num_classes = 55  # Star classification categories
input_dim = 18 # Number of spectra points

# Define the training parameters
num_epochs = 2000
lr = 1e-4
patience = 100
depth = 12

# Define the config dictionary object
config = {"num_classes": num_classes, "batch_size": batch_size, "lr": lr, "patience": patience, "num_epochs": num_epochs, "d_model": d_model, "depth": depth}

# Initialize WandB project
wandb.init(project="ALLSTARS***gaia-mamba-test", entity="joaoc-university-of-southampton", config=config)
# Initialize and train the model
# Train the model using your `train_model_vit` or an adjusted training loop
model_mamba = StarClassifierMAMBA(d_model=d_model, num_classes=num_classes, input_dim=input_dim, n_layers=depth)
print(model_mamba)
# print number of parameters per layer
for name, param in model_mamba.named_parameters():
    print(name, param.numel())
print("Total number of parameters:", sum(p.numel() for p in model_mamba.parameters() if p.requires_grad))

# Move the model to device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
#device = 'cpu'
model_mamba = model_mamba.to(device)

# Train the model using your `train_model_vit` or an adjusted training loop
trained_model = train_model_mamba(
    model=model_mamba,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    num_epochs=num_epochs,
    lr=lr,
    max_patience=patience,
    device=device
)
# Save the model and finish WandB session
wandb.finish()

wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: joaoc (joaoc-university-of-southampton). Use `wandb login --relogin` to force relogin


StarClassifierMAMBA(
  (mamba_layer): Mamba(
    (layers): ModuleList(
      (0-11): 12 x ResidualBlock(
        (mixer): MambaBlock(
          (in_proj): Linear(in_features=1024, out_features=4096, bias=False)
          (conv1d): Conv1d(2048, 2048, kernel_size=(4,), stride=(1,), padding=(3,), groups=2048)
          (x_proj): Linear(in_features=2048, out_features=192, bias=False)
          (dt_proj): Linear(in_features=64, out_features=2048, bias=True)
          (out_proj): Linear(in_features=2048, out_features=1024, bias=False)
        )
        (norm): RMSNorm()
      )
    )
  )
  (input_projection): Linear(in_features=18, out_features=1024, bias=True)
  (classifier): Sequential(
    (0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=1024, out_features=55, bias=True)
  )
)
mamba_layer.layers.0.mixer.A_log 131072
mamba_layer.layers.0.mixer.D 2048
mamba_layer.layers.0.mixer.in_proj.weight 4194304
mamba_layer.layers.0.mixer.conv1d.weight 8192
mamba_

In [None]:
# Save the model
torch.save(trained_model.state_dict(), "mamba_gaia_star_classifier_proud-thunder-4.pth")

In [None]:
# Code can work with specific classes or all classes
def process_star_data(model_path, data_path, classes_path, d_model=1024, num_classes=55, input_dim=3647, depth=12, class_to_plot="AllStars***lamost"):
    # Load the data
    X = pd.read_pickle(data_path)
    classes = pd.read_pickle(classes_path)

    # Load the trained model
    model = StarClassifierMAMBA(d_model=d_model, num_classes=num_classes, input_dim=input_dim, n_layers=depth)

    # Load the state dictionary
    state_dict = torch.load(model_path)
    model.load_state_dict(state_dict)

    # Get the spectral data for all stars with the label multihot encoded
    y = X[classes]

    # Drop Gaia data
    X.drop(["parallax", "ra", "dec", "ra_error", "dec_error", "parallax_error", "pmra", "pmdec", "pmra_error", "pmdec_error", 
            "phot_g_mean_flux", "flagnopllx", "phot_g_mean_flux_error", "phot_bp_mean_flux", "phot_rp_mean_flux", 
            "phot_bp_mean_flux_error", "phot_rp_mean_flux_error", "obsid", "flagnoflux", "otype"], axis=1, inplace=True)
    print(f"X shape after Gaia data drop: {X.shape}")
    print(f"y shape: {y.shape}")

    if class_to_plot != "AllStars***lamost":
        # Filter for a specific class
        X = X[y[class_to_plot] == 1]
        y = y[y[class_to_plot] == 1]

        print(f"X shape after filtering for {class_to_plot}: {X.shape}")
        print(f"y shape after filtering for {class_to_plot}: {y.shape}")

    # Drop label columns
    X.drop(classes, axis=1, inplace=True)

    # Convert to tensors
    X = torch.tensor(X.values, dtype=torch.float32)
    y = torch.tensor(y.values, dtype=torch.float32)

    # Create DataLoader
    class BalancedMultiLabelDataset(Dataset):
        def __init__(self, X, y):
            self.X = X
            self.y = y

        def __len__(self):
            return len(self.X)

        def __getitem__(self, idx):
            return self.X[idx], self.y[idx]

    dataset = BalancedMultiLabelDataset(X, y)
    loader = DataLoader(dataset, batch_size=128, shuffle=False)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    model.eval()

    all_predicted = []
    all_y = []

    with torch.no_grad():
        for batch_idx, (X_batch, y_batch) in enumerate(loader):
            # Move batch to device
            X_batch = X_batch.to(device)  # Add channel dimension
            y_batch = y_batch.to(device)
            
            # Forward pass
            outputs = model(X_batch)
            predicted = (outputs > 0.5).float()
            
            # Store predictions and labels
            all_predicted.append(predicted.cpu().numpy())
            all_y.append(y_batch.cpu().numpy())

            # Free GPU memory
            torch.cuda.empty_cache()

    # Concatenate all predictions and labels
    y_cpu = np.concatenate(all_y, axis=0)
    predicted_cpu = np.concatenate(all_predicted, axis=0)

    return y_cpu, predicted_cpu

# Example usage
model_path = "
data_path = "Pickles/test_data_transformed.pkl"
classes_path = "Pickles/Updated_list_of_Classes.pkl"
y_cpu, predicted_cpu = process_star_data(model_path, data_path, classes_path)

# Save the predictions
np.save("mamba_lamost_v1_y_cpu.npy", y_cpu)
np.save("mamba_lamost_v1_predicted_cpu.npy", predicted_cpu)