In [15]:
import gc
import wandb
import torch
import mambapy
import numpy as np
import pandas as pd
from torch import nn, optim
import sklearn.metrics as metrics
from mambapy.mamba import Mamba, MambaConfig
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, precision_score, recall_score, hamming_loss, roc_auc_score

In [23]:
class BalancedMultiLabelDataset(Dataset):
    def __init__(self, X, y, limit_per_label=201):
        """
        Multilabel version of the BalancedDataset.
        
        Args:
        - X (array-like): Input features.
        - y (array-like): Multi-hot encoded labels (2D array, each row is a multi-hot vector).
        - limit_per_label (int): Target number of samples per label.
        """
        self.X = X
        self.y = y
        self.limit_per_label = limit_per_label
        self.num_classes = y.shape[1]  # Number of possible classes
        self.indices = self.balance_classes()

    def balance_classes(self):
        indices = []
        class_counts = torch.sum(self.y, axis=0)  # Total occurrences of each class
        for cls in range(self.num_classes):
            cls_indices = np.where(self.y[:, cls] == 1)[0]  # Indices where this label is active
            if len(cls_indices) < self.limit_per_label:  # Upsample minority classes
                if len(cls_indices) == 0:
                    print(f"No samples found for class {cls}. Skipping.")
                    continue  # Skip this class if there are no samples for it
                extra_indices = np.random.choice(cls_indices, self.limit_per_label - len(cls_indices), replace=True)
                cls_indices = np.concatenate([cls_indices, extra_indices])
            elif len(cls_indices) > self.limit_per_label:  # Downsample majority classes
                cls_indices = np.random.choice(cls_indices, self.limit_per_label, replace=False)
            indices.extend(cls_indices)
        indices = np.unique(indices)  # Remove duplicate indices
        np.random.shuffle(indices)
        return indices

    def re_sample(self):
        """Rebalance the dataset if needed, for example, after changes to the dataset."""
        self.indices = self.balance_classes()

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        index = self.indices[idx]
        return self.X[index], self.y[index]
    
def calculate_metrics(y_true, y_pred):
    metrics = {
        "micro_f1": f1_score(y_true, y_pred, average='micro'),
        "macro_f1": f1_score(y_true, y_pred, average='macro'),
        "weighted_f1": f1_score(y_true, y_pred, average='weighted'),
        "micro_precision": precision_score(y_true, y_pred, average='micro'),
        "macro_precision": precision_score(y_true, y_pred, average='macro'),
        "weighted_precision": precision_score(y_true, y_pred, average='weighted'),
        "micro_recall": recall_score(y_true, y_pred, average='micro'),
        "macro_recall": recall_score(y_true, y_pred, average='macro'),
        "weighted_recall": recall_score(y_true, y_pred, average='weighted'),
        "hamming_loss": hamming_loss(y_true, y_pred)
    }
    
    # Check if there are at least two classes present in y_true
    #if len(np.unique(y_true)) > 1:
        #metrics["roc_auc"] = roc_auc_score(y_true, y_pred, average='macro', multi_class='ovr')
    #else:
       # metrics["roc_auc"] = None  # or you can set it to a default value or message
    
    return metrics
    
def calculate_class_weights(y):
    if y.ndim > 1:  # Check if y is 2D (multi-hot encoded)
        class_counts = np.sum(y, axis=0)  # Count how many times each class appears
    else:
        class_counts = np.bincount(y)  # For a 1D array, use bincount 
    print("Class counts:", class_counts)       
    total_samples = y.shape[0] if y.ndim > 1 else len(y)
    class_weights = np.where(class_counts > 0, total_samples / (len(class_counts) * class_counts), 0)
    return class_weights

    
def train_model_mamba(
    model, train_loader, val_loader, test_loader, 
    num_epochs=500, lr=1e-4, max_patience=20, device='cuda'
):
    # Move model to device
    model = model.to(device)

    # Define optimizer, scheduler, and loss function
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=int(max_patience / 5), verbose=True
    )
    all_labels = []

    for _, y_batch in train_loader:
        all_labels.extend(y_batch.cpu().numpy())    
        print("Shape of train_loader:", train_loader.dataset)
        print("Shape of val_loader:", len(val_loader))
        print("Shape of test_loader:", len(test_loader))
        print("Shape of all_labels:", len(all_labels))
        
    class_weights = calculate_class_weights(np.array(all_labels))
    class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
    print("Class weights:", class_weights)
    criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights)
    best_val_loss = float('inf')
    patience = max_patience

    for epoch in range(num_epochs):
        # Resample training and validation data
        train_loader.dataset.re_sample()
        val_loader.dataset.balance_classes()

        # Class weights
        all_labels = []
        for _, y_batch in train_loader:
            all_labels.extend(y_batch.cpu().numpy())
        class_weights = calculate_class_weights(np.array(all_labels))
        class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
        criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights)

        # Training phase
        model.train()
        train_loss, train_accuracy = 0.0, 0.0

        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * X_batch.size(0)
            # Convert outputs to binary predictions
            predicted = (outputs > 0.5).float()

            # Calculate accuracy for each batch
            correct = (predicted == y_batch).float()
            train_accuracy += correct.mean(dim=1).mean().item()  # Mean across classes and samples

            

        # Validation phase
        model.eval()
        val_loss, val_accuracy = 0.0, 0.0
        with torch.no_grad():
            for X_val, y_val in val_loader:
                X_val, y_val = X_val.to(device), y_val.to(device)
                outputs = model(X_val)
                loss = criterion(outputs, y_val)

                val_loss += loss.item() * X_val.size(0)
                predicted = (outputs > 0.5).float()
                correct = (predicted == y_val).float()
                val_accuracy += correct.mean(dim=1).mean().item()

        # Test phase
        test_loss, test_accuracy = 0.0, 0.0
        with torch.no_grad():
            for X_test, y_test in test_loader:
                X_test, y_test = X_test.to(device), y_test.to(device)
                outputs = model(X_test)
                loss = criterion(outputs, y_test)

                test_loss += loss.item() * X_test.size(0)
                predicted = (outputs > 0.5).float()
                correct = (predicted == y_test).float()
                test_accuracy += correct.mean(dim=1).mean().item()

        # Test phase and metric collection
        # Inside your test phase
        y_true, y_pred = [], []
        with torch.no_grad():
            for X_test, y_test in test_loader:
                X_test, y_test = X_test.to(device), y_test.to(device)
                outputs = model(X_test)
                predicted = (outputs > 0.5).float()
                y_true.extend(y_test.cpu().numpy())
                y_pred.extend(predicted.cpu().numpy())

        metrics = calculate_metrics(np.array(y_true), np.array(y_pred))
        wandb.log(metrics)

        # Update scheduler
        scheduler.step(val_loss / len(val_loader.dataset))

        # Log metrics to WandB
        wandb.log({
            "epoch": epoch,
            "train_loss": train_loss / len(train_loader.dataset),
            "val_loss": val_loss / len(val_loader.dataset),
            "train_accuracy": train_accuracy / len(train_loader),
            "val_accuracy": val_accuracy / len(val_loader),
            "learning_rate": optimizer.param_groups[0]['lr'],
            "test_loss": test_loss / len(test_loader.dataset),
            "test_accuracy": test_accuracy / len(test_loader),
            #"confusion_matrix": wandb.plot.confusion_matrix(
            #   probs=None, y_true=y_true, preds=y_pred, class_names=np.unique(y_true)
            #), remove for now as it is not multilabel
            #"classification_report": classification_report(
            #    y_true, y_pred, target_names=[str(i) for i in range(len(np.unique(y_true)))]
            #)
        })

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience = max_patience
            best_model = model.state_dict()
        else:
            patience -= 1
            if patience <= 0:
                print("Early stopping triggered.")
                break

    # Load the best model weights
    model.load_state_dict(best_model)
    return model

class StarClassifierMAMBA(nn.Module):
    def __init__(self, d_model, num_classes, d_state=64, d_conv=4, input_dim=17, n_layers=6):
        super(StarClassifierMAMBA, self).__init__()
        self.d_model = d_model
        self.num_classes = num_classes

        # MAMBA layer initialization
        config = MambaConfig(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            n_layers=n_layers

        )
        self.mamba_layer = Mamba(config)

        # Input projection to match the MAMBA layer dimension
        self.input_projection = nn.Linear(input_dim, d_model)

        # Fully connected classifier head with sigmoid activation for multi-label classification
        self.classifier = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, num_classes)
        )


    def forward(self, x):
        x = self.input_projection(x)  # Ensure the input has the correct dimension
        #x = x.unsqueeze(1)  # Adds a sequence dimension (L=1).
        x = self.mamba_layer(x)
        x = x.mean(dim=1)  # Pooling operation for classification
        x = self.classifier(x)
        return x


# Open the data and prepare datasets for Classifier

In [10]:
# If X exists, delete it
if 'X' in locals():   
    del X, y
gc.collect()

batch_size = 1024

# Example usage
if __name__ == "__main__":
    # Load and preprocess your data (example from original script)
    X = pd.read_pickle("Pickles/train_data_transformed.pkl")
    classes = pd.read_pickle("Pickles/Updated_list_of_Classes.pkl")

    # Get labels and set them as y, drop them from X
    y = X[classes]

    # Drop gaia data
    X.drop(["parallax", "ra", "dec", "ra_error", "dec_error", "parallax_error", "pmra", "pmdec", "pmra_error", "pmdec_error", 
            "phot_g_mean_flux", "flagnopllx", "phot_g_mean_flux_error", "phot_bp_mean_flux", "phot_rp_mean_flux", 
            "phot_bp_mean_flux_error", "phot_rp_mean_flux_error", "obsid", "flagnoflux", "otype"], axis=1, inplace=True)
    
    # Drop labels
    X.drop(classes, axis=1, inplace=True)
    
    # Read test data
    X_test = pd.read_pickle("Pickles/test_data_transformed.pkl")

    # Get labels and set them as y, drop them from X
    y_test = X_test[classes]

    # Drop gaia data
    X_test.drop(["parallax", "ra", "dec", "ra_error", "dec_error", "parallax_error", "pmra", "pmdec", "pmra_error", "pmdec_error", 
                "phot_g_mean_flux", "flagnopllx", "phot_g_mean_flux_error", "phot_bp_mean_flux", "phot_rp_mean_flux", 
                "phot_bp_mean_flux_error", "phot_rp_mean_flux_error", "obsid", "flagnoflux", "otype"], axis=1, inplace=True)
    
    # Drop labels
    X_test.drop(classes, axis=1, inplace=True)
    
    # Split validation data
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

    # Clear memory
    del X, y
    gc.collect()

    # Convert to torch tensors and create datasets
    X_train = torch.tensor(X_train.values, dtype=torch.float32).unsqueeze(1)  # Convert DataFrame to numpy array
    X_val = torch.tensor(X_val.values, dtype=torch.float32).unsqueeze(1)      # Convert DataFrame to numpy array    
    X_test = torch.tensor(X_test.values, dtype=torch.float32).unsqueeze(1)    # Convert DataFrame to numpy array
    y_train = torch.tensor(y_train.values, dtype=torch.float32)  # Convert DataFrame to numpy array and float32
    y_val = torch.tensor(y_val.values, dtype=torch.float32)      # Convert DataFrame to numpy array and float32
    y_test = torch.tensor(y_test.values, dtype=torch.float32)    # Convert DataFrame to numpy array and float32

    train_dataset = BalancedMultiLabelDataset(X_train, y_train)
    val_dataset = BalancedMultiLabelDataset(X_val, y_val)
    test_dataset = BalancedMultiLabelDataset(X_test, y_test)


    #train_dataset = BalancedDataset(X_train, y_train)
    #val_dataset = BalancedValidationDataset(X_val, y_val)
    #test_dataset = BalancedValidationDataset(X_test, y_test, limit_per_label=10000)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Print the shapes of the datasets
    print(f"Train dataset shape: {X_train.shape}")
    print(f"Validation dataset shape: {X_val.shape}")
    print(f"Test dataset shape: {X_test.shape}")
    print(f"Train labels shape: {y_train.shape}")
    print(f"Validation labels shape: {y_val.shape}")
    print(f"Test labels shape: {y_test.shape}")

No samples found for class 51. Skipping.
No samples found for class 11. Skipping.
No samples found for class 20. Skipping.
No samples found for class 21. Skipping.
No samples found for class 24. Skipping.
No samples found for class 27. Skipping.
No samples found for class 34. Skipping.
No samples found for class 38. Skipping.
No samples found for class 53. Skipping.
Train dataset shape: torch.Size([87134, 1, 3647])
Validation dataset shape: torch.Size([21784, 1, 3647])
Test dataset shape: torch.Size([27237, 1, 3647])
Train labels shape: torch.Size([87134, 55])
Validation labels shape: torch.Size([21784, 55])
Test labels shape: torch.Size([27237, 55])


# Train model

In [12]:
# Same, but now with class weights
# Define the model with your parameters
d_model = 1024 # Embedding dimension
num_classes = 55  # Star classification categories
input_dim = 3647 # Number of spectra points

# Define the training parameters
num_epochs = 3
lr = 2e-4
patience = 100   
depth = 6

# Define the config dictionary object
config = {"num_classes": num_classes, "batch_size": batch_size, "lr": lr, "patience": patience, "num_epochs": num_epochs, "d_model": d_model, "depth": depth}

# Initialize WandB project
wandb.init(project="ALLSTARS***lamost-mamba-test", entity="joaoc-university-of-southampton", config=config)
# Initialize and train the model
# Train the model using your `train_model_vit` or an adjusted training loop
model_mamba = StarClassifierMAMBA(d_model=d_model, num_classes=num_classes, input_dim=input_dim, n_layers=depth)
print(model_mamba)
# print number of parameters per layer
for name, param in model_mamba.named_parameters():
    print(name, param.numel())
print("Total number of parameters:", sum(p.numel() for p in model_mamba.parameters() if p.requires_grad))

# Move the model to device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_mamba = model_mamba.to(device)

# Train the model using your `train_model_vit` or an adjusted training loop
trained_model = train_model_mamba(
    model=model_mamba,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    num_epochs=num_epochs,
    lr=lr,
    max_patience=patience,
    device=device
)
# Save the model and finish WandB session
wandb.finish()

0,1
epoch,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
hamming_loss,█▄▃▂▂▂▂▂▂▂▂▂▂▃▂▂▁▂▂▂▁
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
macro_f1,▂▁▁▁▁▁▁▁▁▁▂▂▃▃▄▄█▅▅▅▇
macro_precision,▄▂▁▂▂▂▁▁▁▂▃▃▃▃▅▃▇▇█▆█
macro_recall,▂▁▁▁▁▁▁▁▁▁▁▂▃▄▃▄█▅▅▅▇
micro_f1,▂▁▁▁▁▁▁▁▁▁▁▁▂▂▃▃█▄▄▄█
micro_precision,▁▁▁▂▁▂▁▁▁▅▅█▇▄▇█▇█▇██
micro_recall,▂▁▁▁▁▁▁▁▁▁▁▁▂▂▃▃█▄▄▄█
test_accuracy,▁▅▆▇▇▇▇▇▇▇▇▇▇▆▇▇▇▇▇▇█

0,1
epoch,20.0
hamming_loss,0.03378
learning_rate,0.0002
macro_f1,0.04944
macro_precision,0.11201
macro_recall,0.03598
micro_f1,0.03206
micro_precision,0.81319
micro_recall,0.01635
test_accuracy,0.96621


StarClassifierMAMBA(
  (mamba_layer): Mamba(
    (layers): ModuleList(
      (0-5): 6 x ResidualBlock(
        (mixer): MambaBlock(
          (in_proj): Linear(in_features=1024, out_features=4096, bias=False)
          (conv1d): Conv1d(2048, 2048, kernel_size=(4,), stride=(1,), padding=(3,), groups=2048)
          (x_proj): Linear(in_features=2048, out_features=192, bias=False)
          (dt_proj): Linear(in_features=64, out_features=2048, bias=True)
          (out_proj): Linear(in_features=2048, out_features=1024, bias=False)
        )
        (norm): RMSNorm()
      )
    )
  )
  (input_projection): Linear(in_features=3647, out_features=1024, bias=True)
  (classifier): Sequential(
    (0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=1024, out_features=55, bias=True)
  )
)
mamba_layer.layers.0.mixer.A_log 131072
mamba_layer.layers.0.mixer.D 2048
mamba_layer.layers.0.mixer.in_proj.weight 4194304
mamba_layer.layers.0.mixer.conv1d.weight 8192
mamba_



Shape of train_loader: <__main__.BalancedMultiLabelDataset object at 0x0000018CD8DC7F90>
Shape of val_loader: 5
Shape of test_loader: 5
Shape of all_labels: 1024
Shape of train_loader: <__main__.BalancedMultiLabelDataset object at 0x0000018CD8DC7F90>
Shape of val_loader: 5
Shape of test_loader: 5
Shape of all_labels: 2048
Shape of train_loader: <__main__.BalancedMultiLabelDataset object at 0x0000018CD8DC7F90>
Shape of val_loader: 5
Shape of test_loader: 5
Shape of all_labels: 3072
Shape of train_loader: <__main__.BalancedMultiLabelDataset object at 0x0000018CD8DC7F90>
Shape of val_loader: 5
Shape of test_loader: 5
Shape of all_labels: 4096
Shape of train_loader: <__main__.BalancedMultiLabelDataset object at 0x0000018CD8DC7F90>
Shape of val_loader: 5
Shape of test_loader: 5
Shape of all_labels: 5120
Shape of train_loader: <__main__.BalancedMultiLabelDataset object at 0x0000018CD8DC7F90>
Shape of val_loader: 5
Shape of test_loader: 5
Shape of all_labels: 6012
Class counts: [2.040e+02 1.1

  class_weights = np.where(class_counts > 0, total_samples / (len(class_counts) * class_counts), 0)


Class weights: tensor([5.3583e-01, 9.7423e-02, 4.9686e+00, 3.3124e-01, 5.7531e+00, 1.4196e+00,
        2.3867e-01, 3.3634e-01, 9.3427e-01, 4.9238e-01, 5.1079e-01, 5.4655e+01,
        4.3724e+00, 5.3322e-01, 5.4410e-02, 4.2368e-01, 3.2150e+00, 9.9372e+00,
        1.1881e+00, 4.9238e-01, 3.6436e+01, 5.4655e+01, 4.9913e-01, 1.8218e+01,
        5.4655e+01, 1.3767e-01, 1.1269e+00, 1.0931e+02, 3.4159e+00, 3.7955e-01,
        5.4383e-01, 4.9686e-01, 5.4383e-01, 9.7597e-01, 1.8218e+01, 5.4113e-01,
        5.3847e-01, 3.0791e-01, 1.5616e+01, 3.0364e+00, 5.1805e-01, 1.1881e+00,
        1.0931e+02, 3.8220e-01, 4.6515e-01, 1.5616e+01, 3.8900e-01, 5.4383e-01,
        5.0373e-01, 3.6436e+00, 1.3767e-01, 0.0000e+00, 1.0931e+01, 2.7327e+01,
        3.6804e-01], device='cuda:0')
No samples found for class 51. Skipping.
No samples found for class 11. Skipping.
No samples found for class 20. Skipping.
No samples found for class 21. Skipping.
No samples found for class 24. Skipping.
No samples found for c

  class_weights = np.where(class_counts > 0, total_samples / (len(class_counts) * class_counts), 0)


No samples found for class 51. Skipping.
No samples found for class 11. Skipping.
No samples found for class 20. Skipping.
No samples found for class 21. Skipping.
No samples found for class 24. Skipping.
No samples found for class 27. Skipping.
No samples found for class 34. Skipping.
No samples found for class 38. Skipping.
No samples found for class 53. Skipping.
Class counts: [2.100e+02 1.120e+03 2.200e+01 3.320e+02 1.900e+01 7.700e+01 4.670e+02
 3.160e+02 1.170e+02 2.180e+02 2.090e+02 2.000e+00 2.500e+01 2.030e+02
 2.014e+03 2.570e+02 3.400e+01 1.100e+01 9.200e+01 2.200e+02 3.000e+00
 2.000e+00 2.130e+02 6.000e+00 2.000e+00 7.940e+02 9.700e+01 1.000e+00
 3.200e+01 2.930e+02 2.010e+02 2.240e+02 2.010e+02 1.120e+02 6.000e+00
 2.020e+02 2.020e+02 3.570e+02 7.000e+00 3.600e+01 2.110e+02 9.200e+01
 1.000e+00 2.820e+02 2.330e+02 7.000e+00 2.880e+02 2.010e+02 2.130e+02
 3.000e+01 7.810e+02 0.000e+00 1.000e+01 4.000e+00 2.850e+02]


  class_weights = np.where(class_counts > 0, total_samples / (len(class_counts) * class_counts), 0)
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


0,1
epoch,▁▅█
hamming_loss,█▃▁
learning_rate,▁▁▁
macro_f1,▁▁█
macro_precision,█▆▁
macro_recall,█▆▁
micro_f1,█▃▁
micro_precision,▁▄█
micro_recall,█▃▁
test_accuracy,▁▆█

0,1
epoch,2.0
hamming_loss,0.03431
learning_rate,0.0002
macro_f1,0.01214
macro_precision,0.03636
macro_recall,0.0091
micro_f1,0.00044
micro_precision,0.06897
micro_recall,0.00022
test_accuracy,0.96569


In [8]:
# Save the model
torch.save(trained_model.state_dict(), "mamba_star_classifier_devout_feather_60.pth")

In [9]:
import sklearn.metrics as metrics

# Confusion matrix and classification report
y_true, y_pred = [], []
with torch.no_grad():
    for X_test, y_test in test_loader:
        X_test, y_test = X_test.to(device), y_test.to(device)
        outputs = trained_model(X_test)
        predicted = (outputs > 0.5).float()
        y_true.extend(y_test.cpu().numpy())
        y_pred.extend(predicted.cpu().numpy())

cm = metrics.multilabel_confusion_matrix(y_true, y_pred)
print(cm)
print(metrics.classification_report(y_true, y_pred, target_names=classes))


[[[4598   11]
  [ 186   15]]

 [[3823   22]
  [ 732  233]]

 [[4804    0]
  [   6    0]]

 [[4633   15]
  [ 114   48]]

 [[4804    0]
  [   4    2]]

 [[4786    0]
  [  14   10]]

 [[4347   31]
  [ 229  203]]

 [[4511   26]
  [ 183   90]]

 [[4767    7]
  [  27    9]]

 [[4571   21]
  [  83  135]]

 [[4716    6]
  [  67   21]]

 [[4808    0]
  [   2    0]]

 [[4802    0]
  [   8    0]]

 [[4594   12]
  [  52  152]]

 [[3218   24]
  [ 735  833]]

 [[4717    6]
  [  38   49]]

 [[4800    0]
  [   5    5]]

 [[4808    0]
  [   1    1]]

 [[4781    0]
  [  18   11]]

 [[4680   10]
  [  63   57]]

 [[4809    0]
  [   1    0]]

 [[4809    0]
  [   1    0]]

 [[4577   26]
  [ 183   24]]

 [[4807    1]
  [   2    0]]

 [[4808    0]
  [   2    0]]

 [[4012   35]
  [ 603  160]]

 [[4775    7]
  [  26    2]]

 [[4809    0]
  [   1    0]]

 [[4793    5]
  [  10    2]]

 [[4537    0]
  [ 273    0]]

 [[4605    4]
  [ 192    9]]

 [[4582   15]
  [ 194   19]]

 [[4604    5]
  [ 188   13]]

 [[4772   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


              precision    recall  f1-score   support

         RS*       0.58      0.07      0.13       201
          **       0.91      0.24      0.38       965
         El*       0.00      0.00      0.00         6
         Y*O       0.76      0.30      0.43       162
         s*b       1.00      0.33      0.50         6
         cC*       1.00      0.42      0.59        24
         HB*       0.87      0.47      0.61       432
         dS*       0.78      0.33      0.46       273
         Or*       0.56      0.25      0.35        36
         LP*       0.87      0.62      0.72       218
         BS*       0.78      0.24      0.37        88
         Ae*       0.00      0.00      0.00         2
         WV*       0.00      0.00      0.00         8
         HS*       0.93      0.75      0.83       204
         Ev*       0.97      0.53      0.69      1568
         AB*       0.89      0.56      0.69        87
         sg*       1.00      0.50      0.67        10
         s*r       1.00    

In [21]:
# clear vram and cache
torch.cuda.empty_cache()

# Evaluating Model for Em*

In [22]:
# Same, but now with class weights
d_model = 2048 # Embedding dimension
num_classes = 55  # Star classification categories
input_dim = 3647 # Number of spectra points

# Define the training parameters
num_epochs = 3000
lr = 2e-5
patience = 100   
depth = 6

# Load the data
X = pd.read_pickle("Pickles/train_data_transformed.pkl")
classes = pd.read_pickle("Pickles/Updated_list_of_Classes.pkl")

# Load the trained model
trained_model = StarClassifierMAMBA(d_model=d_model, num_classes=num_classes, input_dim=input_dim, n_layers=depth)

# Load the state dictionary
state_dict = torch.load("mamba_star_classifier_devout_feather_60.pth")
#model.load_state_dict(state_dict)

# Set the model to evaluation mode
#trained_model.eval()

# Get the spectral data for all stars with the Em* label multihot encoded
y = X[classes]

# Drop gaia data
X.drop(["parallax", "ra", "dec", "ra_error", "dec_error", "parallax_error", "pmra", "pmdec", "pmra_error", "pmdec_error", 
        "phot_g_mean_flux", "flagnopllx", "phot_g_mean_flux_error", "phot_bp_mean_flux", "phot_rp_mean_flux", 
        "phot_bp_mean_flux_error", "phot_rp_mean_flux_error", "obsid", "flagnoflux", "otype"], axis=1, inplace=True)
print(X.shape)
print(y.shape)

# Get the spectral data for all stars with a value of 1 in the Em* label
X = X[y["Em*"] == 1]
y = y[y["Em*"] == 1]

print(X.shape)
print(y.shape)

# Drop labels
X.drop(classes, axis=1, inplace=True)

# Do batches of 32
X = X.values
y = y.values
X = torch.tensor(X, dtype=torch.float32).unsqueeze(1)
y = torch.tensor(y, dtype=torch.float32)

# Do batches of 32
dataset = BalancedMultiLabelDataset(X, y)
loader = DataLoader(dataset, batch_size=1, shuffle=False)

# Pass the data through the trained model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = trained_model.to(device)
model.eval()
with torch.no_grad():
    outputs = model(X.to(device))
    predicted = (outputs > 0.5).float()

# Move tensors to CPU before converting to NumPy arrays
y_cpu = y.cpu()
predicted_cpu = predicted.cpu()

# Confusion matrix and classification report
cm = metrics.multilabel_confusion_matrix(y_cpu.numpy(), predicted_cpu.numpy())
print(cm)
print(metrics.classification_report(y_cpu.numpy(), predicted_cpu.numpy(), target_names=classes))

  state_dict = torch.load("mamba_star_classifier_devout_feather_60.pth")


(108918, 3702)
(108918, 55)
(9269, 3702)
(9269, 55)
No samples found for class 0. Skipping.
No samples found for class 1. Skipping.
No samples found for class 2. Skipping.
No samples found for class 3. Skipping.
No samples found for class 4. Skipping.
No samples found for class 5. Skipping.
No samples found for class 6. Skipping.
No samples found for class 7. Skipping.
No samples found for class 8. Skipping.
No samples found for class 9. Skipping.
No samples found for class 10. Skipping.
No samples found for class 11. Skipping.
No samples found for class 12. Skipping.
No samples found for class 13. Skipping.
No samples found for class 14. Skipping.
No samples found for class 15. Skipping.
No samples found for class 16. Skipping.
No samples found for class 17. Skipping.
No samples found for class 18. Skipping.
No samples found for class 19. Skipping.
No samples found for class 20. Skipping.
No samples found for class 21. Skipping.
No samples found for class 22. Skipping.
No samples foun

OutOfMemoryError: CUDA out of memory. Tried to allocate 9.05 GiB. GPU 0 has a total capacity of 16.00 GiB of which 0 bytes is free. Of the allocated memory 22.06 GiB is allocated by PyTorch, and 721.94 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)