In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.metrics import classification_report
import wandb
import gc
import pandas as pd
from sklearn.model_selection import train_test_split
import mambapy
from mambapy.mamba import Mamba, MambaConfig
from torch import nn, optim


In [5]:
class BalancedDataset(Dataset):
    def __init__(self, X, y, limit_per_label=1600):
        self.X = X
        self.y = y
        self.limit_per_label = limit_per_label
        self.classes = np.unique(y)
        self.indices = self.balance_classes()

    def balance_classes(self):
        indices = []
        for cls in self.classes:
            cls_indices = np.where(self.y == cls)[0]
            if len(cls_indices) > self.limit_per_label:
                cls_indices = np.random.choice(cls_indices, self.limit_per_label, replace=False)
            indices.extend(cls_indices)
        np.random.shuffle(indices)
        return indices

    def re_sample(self):
        self.indices = self.balance_classes()

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        index = self.indices[idx]
        return self.X[index], self.y[index]
# Custom Dataset for validation with limit per class
class BalancedValidationDataset(Dataset):
    def __init__(self, X, y, limit_per_label=400):
        self.X = X
        self.y = y
        self.limit_per_label = limit_per_label
        self.classes = np.unique(y)
        self.indices = self.balance_classes()

    def balance_classes(self):
        indices = []
        for cls in self.classes:
            cls_indices = np.where(self.y == cls)[0]
            if len(cls_indices) > self.limit_per_label:
                cls_indices = np.random.choice(cls_indices, self.limit_per_label, replace=False)
            indices.extend(cls_indices)
        np.random.shuffle(indices)
        return indices
    
    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        index = self.indices[idx]
        return self.X[index], self.y[index]
    
def train_model_mamba(
    model, train_loader, val_loader, test_loader, 
    num_epochs=500, lr=1e-4, max_patience=20, device='cuda'
):
    # Move model to device
    model = model.to(device)

    # Define optimizer, scheduler, and loss function
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=int(max_patience / 3), verbose=True
    )
    criterion = nn.CrossEntropyLoss()

    best_val_loss = float('inf')
    patience = max_patience

    for epoch in range(num_epochs):
        # Resample training and validation data
        train_loader.dataset.re_sample()
        val_loader.dataset.balance_classes()

        # Training phase
        model.train()
        train_loss, train_accuracy = 0.0, 0.0

        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * X_batch.size(0)
            train_accuracy += (outputs.argmax(dim=1) == y_batch).float().mean().item()

        # Validation phase
        model.eval()
        val_loss, val_accuracy = 0.0, 0.0
        with torch.no_grad():
            for X_val, y_val in val_loader:
                X_val, y_val = X_val.to(device), y_val.to(device)
                outputs = model(X_val)
                loss = criterion(outputs, y_val)

                val_loss += loss.item() * X_val.size(0)
                val_accuracy += (outputs.argmax(dim=1) == y_val).float().mean().item()

        # Test phase and metric collection
        test_loss, test_accuracy = 0.0, 0.0
        y_true, y_pred = [], []
        with torch.no_grad():
            for X_test, y_test in test_loader:
                X_test, y_test = X_test.to(device), y_test.to(device)
                outputs = model(X_test)
                loss = criterion(outputs, y_test)

                test_loss += loss.item() * X_test.size(0)
                test_accuracy += (outputs.argmax(dim=1) == y_test).float().mean().item()
                y_true.extend(y_test.cpu().numpy())
                y_pred.extend(outputs.argmax(dim=1).cpu().numpy())

        # Update scheduler
        scheduler.step(val_loss / len(val_loader.dataset))

        # Log metrics to WandB
        wandb.log({
            "epoch": epoch,
            "train_loss": train_loss / len(train_loader.dataset),
            "val_loss": val_loss / len(val_loader.dataset),
            "train_accuracy": train_accuracy / len(train_loader),
            "val_accuracy": val_accuracy / len(val_loader),
            "learning_rate": optimizer.param_groups[0]['lr'],
            "test_loss": test_loss / len(test_loader.dataset),
            "test_accuracy": test_accuracy / len(test_loader),
            "confusion_matrix": wandb.plot.confusion_matrix(
                probs=None, y_true=y_true, preds=y_pred, class_names=np.unique(y_true)
            ),
            "classification_report": classification_report(
                y_true, y_pred, target_names=[str(i) for i in range(len(np.unique(y_true)))]
            )
        })

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience = max_patience
            best_model = model.state_dict()
        else:
            patience -= 1
            if patience <= 0:
                print("Early stopping triggered.")
                break

    # Load the best model weights
    model.load_state_dict(best_model)
    return model



In [4]:
batch_size = 256

# Example usage
if __name__ == "__main__":
    # Load and preprocess your data (example from original script)
    # Load and preprocess data
    X = pd.read_pickle("Pickles/trainv2.pkl")
    y = X["label"]
    label_mapping = {'star': 0, 'binary_star': 1, 'galaxy': 2, 'agn': 3}
    y = y.map(label_mapping).values
    X = X.drop(["parallax", "ra", "dec", "ra_error", "dec_error", "parallax_error", "pmra", "pmdec", "pmra_error", "pmdec_error", 
                "phot_g_mean_flux", "flagnopllx", "phot_g_mean_flux_error", "phot_bp_mean_flux", "phot_rp_mean_flux", 
                "phot_bp_mean_flux_error", "phot_rp_mean_flux_error", "label", "obsid"], axis=1).values
    
    # Read test data
    X_test = pd.read_pickle("Pickles/testv2.pkl")
    y_test = X_test["label"].map(label_mapping).values
    X_test = X_test.drop(["parallax", "ra", "dec", "ra_error", "dec_error", "parallax_error", "pmra", "pmdec", "pmra_error", "pmdec_error", 
                "phot_g_mean_flux", "flagnopllx", "phot_g_mean_flux_error", "phot_bp_mean_flux", "phot_rp_mean_flux", 
                "phot_bp_mean_flux_error", "phot_rp_mean_flux_error", "label", "obsid"], axis=1).values
    
    # Split data
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

    # Clear memory
    del X, y
    gc.collect()

    # Convert to torch tensors and create datasets
    X_train = torch.tensor(X_train, dtype=torch.float32).unsqueeze(1)
    X_val = torch.tensor(X_val, dtype=torch.float32).unsqueeze(1)
    y_train = torch.tensor(y_train, dtype=torch.long)
    y_val = torch.tensor(y_val, dtype=torch.long)

    train_dataset = BalancedDataset(X_train, y_train)
    val_dataset = BalancedValidationDataset(X_val, y_val)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(BalancedValidationDataset(torch.tensor(X_test, dtype=torch.float32).unsqueeze(1),
                                                    torch.tensor(y_test, dtype=torch.long)), batch_size=batch_size, shuffle=False)
    


In [6]:
class StarClassifierMAMBA(nn.Module):
    def __init__(self, d_model, num_classes, d_state=64, d_conv=4, input_dim=17, n_layers=6):
        super(StarClassifierMAMBA, self).__init__()
        self.d_model = d_model
        self.num_classes = num_classes

        # MAMBA layer initialization
        config = MambaConfig(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            n_layers=n_layers

        )
        self.mamba_layer = Mamba(config)

        # Input projection to match the MAMBA layer dimension
        self.input_projection = nn.Linear(input_dim, d_model)

        # Fully connected classifier head
        self.classifier = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, num_classes)
        )


    def forward(self, x):
        x = self.input_projection(x)  # Ensure the input has the correct dimension
        #x = x.unsqueeze(1)  # Adds a sequence dimension (L=1).
        x = self.mamba_layer(x)
        x = x.mean(dim=1)  # Pooling operation for classification
        x = self.classifier(x)
        return x

In [7]:
from jamba.model import Jamba

# Define the model with your parameters
d_model = 128 # Embedding dimension
num_classes = 4  # Star classification categories

# Define the training parameters
num_epochs = 500
lr = 1e-4
patience = 30   
depth = 10

# Define the config dictionary object
config = {"num_classes": num_classes, "batch_size": batch_size, "lr": lr, "patience": patience, "num_epochs": num_epochs, "d_model": d_model, "depth": depth}

# Initialize WandB project
wandb.init(project="lamost-jamba-test", entity="joaoc-university-of-southampton", config=config)

# Define hyperparameters
num_epochs = 100
lr = 1e-3
patience = 10

# Initialize the Jamba model
model_mamba = Jamba(
    dim=3748,                # Input dimensionality
    depth=4,                # Number of layers
    num_tokens=100,         # Token size (adapt to your case)
    d_state=d_model,            # Hidden state dimensionality
    d_conv=128,             # Convolutional layers dimensionality
    heads=8,                # Number of attention heads
    num_experts=8,          # Number of expert networks
    num_experts_per_token=2 # Experts per token
)

print(model_mamba)
# Print number of parameters per layer
for name, param in model_mamba.named_parameters():
    print(name, param.numel())
print("Total number of parameters:", sum(p.numel() for p in model_mamba.parameters() if p.requires_grad))

# Move the model to device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_mamba = model_mamba.to(device)

# Train the model
trained_model = train_model_mamba(
    model=model_mamba,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    num_epochs=num_epochs,
    lr=lr,
    max_patience=patience,
    device=device
)

# Save the model and finish WandB session
wandb.finish()

    

wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: joaoc (joaoc-university-of-southampton). Use `wandb login --relogin` to force relogin


Jamba(
  (layers): ModuleList(
    (0-3): 4 x JambaBlock(
      (mamba_layer): MambaBlock(
        (in_proj): Linear(in_features=3748, out_features=14992, bias=False)
        (conv1d): Conv1d(7496, 7496, kernel_size=(128,), stride=(1,), padding=(127,), groups=7496)
        (x_proj): Linear(in_features=7496, out_features=491, bias=False)
        (dt_proj): Linear(in_features=235, out_features=7496, bias=True)
        (out_proj): Linear(in_features=7496, out_features=3748, bias=False)
      )
      (mamba_moe_layer): MambaMoELayer(
        (mamba): MambaBlock(
          (in_proj): Linear(in_features=3748, out_features=14992, bias=False)
          (conv1d): Conv1d(7496, 7496, kernel_size=(128,), stride=(1,), padding=(127,), groups=7496)
          (x_proj): Linear(in_features=7496, out_features=491, bias=False)
          (dt_proj): Linear(in_features=235, out_features=7496, bias=True)
          (out_proj): Linear(in_features=7496, out_features=3748, bias=False)
        )
        (moe): MoE

RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)

In [8]:
# Define the model with your parameters
d_model = 2048 # Embedding dimension
num_classes = 4  # Star classification categories

# Define the training parameters
num_epochs = 500
lr = 2e-7
patience = 50   
depth = 6

# Define the config dictionary object
config = {"num_classes": num_classes, "batch_size": batch_size, "lr": lr, "patience": patience, "num_epochs": num_epochs, "d_model": d_model, "depth": depth}

# Initialize WandB project
wandb.init(project="lamost-mamba-test", entity="joaoc-university-of-southampton", config=config)
# Initialize and train the model
# Train the model using your `train_model_vit` or an adjusted training loop
model_mamba = StarClassifierMAMBA(d_model=d_model, num_classes=num_classes, input_dim=3748, n_layers=depth)
print(model_mamba)
# print number of parameters per layer
for name, param in model_mamba.named_parameters():
    print(name, param.numel())
print("Total number of parameters:", sum(p.numel() for p in model_mamba.parameters() if p.requires_grad))

# Move the model to device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_mamba = model_mamba.to(device)

# Train the model using your `train_model_vit` or an adjusted training loop
trained_model = train_model_mamba(
    model=model_mamba,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    num_epochs=num_epochs,
    lr=lr,
    max_patience=patience,
    device=device
)
# Save the model and finish WandB session
wandb.finish()

StarClassifierMAMBA(
  (mamba_layer): Mamba(
    (layers): ModuleList(
      (0-5): 6 x ResidualBlock(
        (mixer): MambaBlock(
          (in_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (conv1d): Conv1d(4096, 4096, kernel_size=(4,), stride=(1,), padding=(3,), groups=4096)
          (x_proj): Linear(in_features=4096, out_features=256, bias=False)
          (dt_proj): Linear(in_features=128, out_features=4096, bias=True)
          (out_proj): Linear(in_features=4096, out_features=2048, bias=False)
        )
        (norm): RMSNorm()
      )
    )
  )
  (input_projection): Linear(in_features=3748, out_features=2048, bias=True)
  (classifier): Sequential(
    (0): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=2048, out_features=4, bias=True)
  )
)
mamba_layer.layers.0.mixer.A_log 262144
mamba_layer.layers.0.mixer.D 4096
mamba_layer.layers.0.mixer.in_proj.weight 16777216
mamba_layer.layers.0.mixer.conv1d.weight 16384
mamb



Early stopping triggered.


VBox(children=(Label(value='0.516 MB of 0.516 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇████
learning_rate,███████████████████████████████████▃▃▃▃▁
test_accuracy,▁▃▄▄▄▅▆▅▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇██████
test_loss,█▇▆▅▅▄▄▄▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▁▃▅▆▆▆▆▆▆▆▆▆▆▆▇▆▆▇▇▇▇▇▇▇▇▇▇▇▇██▇▇█▇████
train_loss,█▇▆▆▆▅▅▄▄▄▄▃▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁
val_accuracy,▁▁▂▃▃▄▅▅▆▆▆▆▇▆▇▇▇▇▇▇▇▇▇▇█████████▇██████
val_loss,█▇▅▅▅▄▃▃▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
classification_report,precis...
epoch,384
learning_rate,0.0
test_accuracy,0.76964
test_loss,0.55254
train_accuracy,0.78883
train_loss,0.4609
val_accuracy,0.73679
val_loss,0.56997


In [9]:
# Define the model with your parameters
d_model = 2048 # Embedding dimension
num_classes = 4  # Star classification categories

# Define the training parameters
num_epochs = 500
lr = 2e-7
patience = 50   
depth = 4

# Define the config dictionary object
config = {"num_classes": num_classes, "batch_size": batch_size, "lr": lr, "patience": patience, "num_epochs": num_epochs, "d_model": d_model, "depth": depth}

# Initialize WandB project
wandb.init(project="lamost-mamba-test", entity="joaoc-university-of-southampton", config=config)
# Initialize and train the model
# Train the model using your `train_model_vit` or an adjusted training loop
model_mamba = StarClassifierMAMBA(d_model=d_model, num_classes=num_classes, input_dim=3748, n_layers=depth)
print(model_mamba)
# print number of parameters per layer
for name, param in model_mamba.named_parameters():
    print(name, param.numel())
print("Total number of parameters:", sum(p.numel() for p in model_mamba.parameters() if p.requires_grad))

# Move the model to device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_mamba = model_mamba.to(device)

# Train the model using your `train_model_vit` or an adjusted training loop
trained_model = train_model_mamba(
    model=model_mamba,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    num_epochs=num_epochs,
    lr=lr,
    max_patience=patience,
    device=device
)
# Save the model and finish WandB session
wandb.finish()

StarClassifierMAMBA(
  (mamba_layer): Mamba(
    (layers): ModuleList(
      (0-3): 4 x ResidualBlock(
        (mixer): MambaBlock(
          (in_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (conv1d): Conv1d(4096, 4096, kernel_size=(4,), stride=(1,), padding=(3,), groups=4096)
          (x_proj): Linear(in_features=4096, out_features=256, bias=False)
          (dt_proj): Linear(in_features=128, out_features=4096, bias=True)
          (out_proj): Linear(in_features=4096, out_features=2048, bias=False)
        )
        (norm): RMSNorm()
      )
    )
  )
  (input_projection): Linear(in_features=3748, out_features=2048, bias=True)
  (classifier): Sequential(
    (0): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=2048, out_features=4, bias=True)
  )
)
mamba_layer.layers.0.mixer.A_log 262144
mamba_layer.layers.0.mixer.D 4096
mamba_layer.layers.0.mixer.in_proj.weight 16777216
mamba_layer.layers.0.mixer.conv1d.weight 16384
mamb



Early stopping triggered.


0,1
epoch,▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▆▆▆▆▇▇▇▇▇█████
learning_rate,████████████████████████████████▄▄▄▂▂▁▁▁
test_accuracy,▁▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇██████████
test_loss,█▆▆▅▄▄▃▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▂▄▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇█▇████████████████
train_loss,█▇▆▅▅▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▃▅▅▅▆▆▇▇▇▇▇▇▇▇▇▇▇█▇████████████████████
val_loss,█▅▄▄▄▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
classification_report,precis...
epoch,444
learning_rate,0.0
test_accuracy,0.7625
test_loss,0.54717
train_accuracy,0.79338
train_loss,0.453
val_accuracy,0.7381
val_loss,0.56533


In [10]:
# Define the model with your parameters
d_model = 2048 # Embedding dimension
num_classes = 4  # Star classification categories

# Define the training parameters
num_epochs = 500
lr = 2e-7
patience = 50   
depth = 10

# Define the config dictionary object
config = {"num_classes": num_classes, "batch_size": batch_size, "lr": lr, "patience": patience, "num_epochs": num_epochs, "d_model": d_model, "depth": depth}

# Initialize WandB project
wandb.init(project="lamost-mamba-test", entity="joaoc-university-of-southampton", config=config)
# Initialize and train the model
# Train the model using your `train_model_vit` or an adjusted training loop
model_mamba = StarClassifierMAMBA(d_model=d_model, num_classes=num_classes, input_dim=3748, n_layers=depth)
print(model_mamba)
# print number of parameters per layer
for name, param in model_mamba.named_parameters():
    print(name, param.numel())
print("Total number of parameters:", sum(p.numel() for p in model_mamba.parameters() if p.requires_grad))

# Move the model to device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_mamba = model_mamba.to(device)

# Train the model using your `train_model_vit` or an adjusted training loop
trained_model = train_model_mamba(
    model=model_mamba,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    num_epochs=num_epochs,
    lr=lr,
    max_patience=patience,
    device=device
)
# Save the model and finish WandB session
wandb.finish()

StarClassifierMAMBA(
  (mamba_layer): Mamba(
    (layers): ModuleList(
      (0-9): 10 x ResidualBlock(
        (mixer): MambaBlock(
          (in_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (conv1d): Conv1d(4096, 4096, kernel_size=(4,), stride=(1,), padding=(3,), groups=4096)
          (x_proj): Linear(in_features=4096, out_features=256, bias=False)
          (dt_proj): Linear(in_features=128, out_features=4096, bias=True)
          (out_proj): Linear(in_features=4096, out_features=2048, bias=False)
        )
        (norm): RMSNorm()
      )
    )
  )
  (input_projection): Linear(in_features=3748, out_features=2048, bias=True)
  (classifier): Sequential(
    (0): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=2048, out_features=4, bias=True)
  )
)
mamba_layer.layers.0.mixer.A_log 262144
mamba_layer.layers.0.mixer.D 4096
mamba_layer.layers.0.mixer.in_proj.weight 16777216
mamba_layer.layers.0.mixer.conv1d.weight 16384
mam



Early stopping triggered.


VBox(children=(Label(value='0.438 MB of 0.438 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▇▇▇▇████
learning_rate,███████████████████████████████████▃▃▃▁▁
test_accuracy,▁▂▄▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇█▇▇█▇▇███▇███▇████████
test_loss,█▇▆▆▅▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▅▆▆▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇███████████████
train_loss,█▇▆▅▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▂▂▅▅▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇████████████████████
val_loss,█▇▅▄▄▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
classification_report,precis...
epoch,321
learning_rate,0.0
test_accuracy,0.76496
test_loss,0.56053
train_accuracy,0.79689
train_loss,0.44989
val_accuracy,0.74396
val_loss,0.57359


In [11]:
# Define the model with your parameters
d_model = 256 # Embedding dimension
num_classes = 4  # Star classification categories

# Define the training parameters
num_epochs = 500
lr = 1e-5
patience = 30   
depth = 6

# Define the config dictionary object
config = {"num_classes": num_classes, "batch_size": batch_size, "lr": lr, "patience": patience, "num_epochs": num_epochs, "d_model": d_model, "depth": depth}

# Initialize WandB project
wandb.init(project="lamost-mamba-test", entity="joaoc-university-of-southampton", config=config)
# Initialize and train the model
# Train the model using your `train_model_vit` or an adjusted training loop
model_mamba = StarClassifierMAMBA(d_model=d_model, num_classes=num_classes, input_dim=3748, n_layers=depth)
print(model_mamba)
# print number of parameters per layer
for name, param in model_mamba.named_parameters():
    print(name, param.numel())
print("Total number of parameters:", sum(p.numel() for p in model_mamba.parameters() if p.requires_grad))

# Move the model to device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_mamba = model_mamba.to(device)

# Train the model using your `train_model_vit` or an adjusted training loop
trained_model = train_model_mamba(
    model=model_mamba,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    num_epochs=num_epochs,
    lr=lr,
    max_patience=patience,
    device=device
)
# Save the model and finish WandB session
wandb.finish()

StarClassifierMAMBA(
  (mamba_layer): Mamba(
    (layers): ModuleList(
      (0-5): 6 x ResidualBlock(
        (mixer): MambaBlock(
          (in_proj): Linear(in_features=256, out_features=1024, bias=False)
          (conv1d): Conv1d(512, 512, kernel_size=(4,), stride=(1,), padding=(3,), groups=512)
          (x_proj): Linear(in_features=512, out_features=144, bias=False)
          (dt_proj): Linear(in_features=16, out_features=512, bias=True)
          (out_proj): Linear(in_features=512, out_features=256, bias=False)
        )
        (norm): RMSNorm()
      )
    )
  )
  (input_projection): Linear(in_features=3748, out_features=256, bias=True)
  (classifier): Sequential(
    (0): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=256, out_features=4, bias=True)
  )
)
mamba_layer.layers.0.mixer.A_log 32768
mamba_layer.layers.0.mixer.D 512
mamba_layer.layers.0.mixer.in_proj.weight 262144
mamba_layer.layers.0.mixer.conv1d.weight 2048
mamba_layer.layers.0.



Early stopping triggered.


VBox(children=(Label(value='0.104 MB of 0.104 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇████
learning_rate,██████████████████████████████▃▃▃▃▃▁▁▁▁▁
test_accuracy,▁▃▅▆▅▆▆▆▇▆▆▇▇▇▇▇▇▇▇▇▇▇█▇▇▇▇▇█▇▇▇▇▇▇▇▇▇▇▇
test_loss,█▆▅▄▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▂▂▂▂▂
train_accuracy,▁▄▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇██████████
train_loss,█▆▅▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▃▄▅▆▆▆▆▇▇▇▇▇▇▇██▇███████▇▇█▇███▇▇▇▇▇▇▇▇
val_loss,█▇▆▅▄▃▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▂▂▂▂▂

0,1
classification_report,precis...
epoch,80
learning_rate,0.0
test_accuracy,0.72913
test_loss,0.60375
train_accuracy,0.81059
train_loss,0.40147
val_accuracy,0.73813
val_loss,0.60615


In [1]:
# Define the model with your parameters
d_model = 512 # Embedding dimension
num_classes = 4  # Star classification categories

# Define the training parameters
num_epochs = 500
lr = 1e-6
patience = 30   
depth = 10

# Define the config dictionary object
config = {"num_classes": num_classes, "batch_size": batch_size, "lr": lr, "patience": patience, "num_epochs": num_epochs, "d_model": d_model, "depth": depth}

# Initialize WandB project
wandb.init(project="lamost-mamba-test", entity="joaoc-university-of-southampton", config=config)
# Initialize and train the model
# Train the model using your `train_model_vit` or an adjusted training loop
model_mamba = StarClassifierMAMBA(d_model=d_model, num_classes=num_classes, input_dim=3748, n_layers=depth)
print(model_mamba)
# print number of parameters per layer
for name, param in model_mamba.named_parameters():
    print(name, param.numel())
print("Total number of parameters:", sum(p.numel() for p in model_mamba.parameters() if p.requires_grad))

# Move the model to device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_mamba = model_mamba.to(device)

# Train the model using your `train_model_vit` or an adjusted training loop
trained_model = train_model_mamba(
    model=model_mamba,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    num_epochs=num_epochs,
    lr=lr,
    max_patience=patience,
    device=device
)
# Save the model and finish WandB session
wandb.finish()

NameError: name 'batch_size' is not defined

In [None]:
# Define the hyperparameters
num_classes = 4
num_epochs = 500
patience = 30

# Define sweep config
sweep_config = {
    "method": "random",
    "metric": {"name": "test_accuracy", "goal": "maximize"},
    "parameters": {
        "lr": {"values": [1e-4, 2e-6, 1e-5]},
        "dim": {"values": [64, 256, 512, 1024, 2048]},
        "depth": {"values": [3, 6, 12]}
    }
}

def sweep_train():
    with wandb.init() as run:
        config = run.config
        model_mamba = StarClassifierMAMBA(d_model=config.dim, num_classes=num_classes, input_dim=3748, n_layers=config.depth)
        
        # Pass config.num_epochs explicitly
        trained_model = train_model_mamba(model_mamba, train_loader, val_loader, test_loader, num_epochs=num_epochs, lr=config.lr, max_patience=patience, device='cuda')


# Start sweep
sweep_id = wandb.sweep(sweep_config, project="spectra-mamba-sweep")
wandb.agent(sweep_id, function=sweep_train, count=50)

Create sweep with ID: pitkh7le
Sweep URL: https://wandb.ai/joaoc-university-of-southampton/spectra-mamba-sweep/sweeps/pitkh7le


wandb: Agent Starting Run: 3hd6b91h with config:
wandb: 	depth: 12
wandb: 	dim: 512
wandb: 	lr: 2e-06




Early stopping triggered.


VBox(children=(Label(value='0.163 MB of 0.163 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇████
learning_rate,███████████████████████████████████▃▃▃▁▁
test_accuracy,▁▃▄▄▄▆▆▆▆▆▆▆▆▇▆▇▇▇▇▇█▇▇██▇▇▇██▇██▇█▇▇▇▇█
test_loss,█▆▅▄▄▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▃▄▄▅▆▆▆▆▆▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇█▇██▇███████
train_loss,█▇▅▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▄▄▅▆▇▇▇▇▇▇▇▇██████████████████▇████████
val_loss,█▇▆▅▅▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
classification_report,precis...
epoch,115
learning_rate,0.0
test_accuracy,0.73538
test_loss,0.54989
train_accuracy,0.79302
train_loss,0.45769
val_accuracy,0.74725
val_loss,0.54785


wandb: Agent Starting Run: ws1b3rwh with config:
wandb: 	depth: 12
wandb: 	dim: 64
wandb: 	lr: 2e-06




Early stopping triggered.


0,1
epoch,▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇██
learning_rate,██████████████████████████████████▄▂▂▂▁▁
test_accuracy,▁▃▄▅▆▇▇▇▇▇▇▇▇▇██████████████████████████
test_loss,█▅▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▁▂▃▃▃▃▃▄▄▅▆▅▆▆▆▆▅▆▆▆▆▆▆▆▇▆▆▇▆▇▇▇█▆███▇▇
train_loss,█▅▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▁▂▂▃▃▃▃▅▆▆█▇▇▇▇██▇█▇██▇▇████▇▇▇██▇▇█▇██
val_loss,█▅▅▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
classification_report,precis...
epoch,481
learning_rate,0.0
test_accuracy,0.73638
test_loss,0.57178
train_accuracy,0.77179
train_loss,0.51159
val_accuracy,0.76028
val_loss,0.56679


wandb: Agent Starting Run: u05ek3j1 with config:
wandb: 	depth: 12
wandb: 	dim: 512
wandb: 	lr: 0.0001




Early stopping triggered.


0,1
epoch,▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇▇███
learning_rate,██████████████▃▃▃▃▃▃▃▃▃▃▃▁▁▁▁▁▁▁▁▁
test_accuracy,▃▆▇██▇██▇▇▆█▇▇▆▆▅▄▅▅█▄▃▃▂▅▄▃▃▄▃▄▁▄
test_loss,▂▂▁▁▁▁▁▁▁▁▁▁▂▂▂▂▃▃▃▃▂▄▃▄▄▄▅▆▆▅▅▆█▆
train_accuracy,▁▄▅▅▅▆▆▆▆▇▆▆▆▇▇▇▇▇▇▇▇█▇▇▇█████████
train_loss,█▄▄▃▃▃▃▃▃▂▃▂▂▂▂▂▂▂▁▁▂▁▂▂▁▁▁▁▁▁▁▁▁▁
val_accuracy,▃▆▇█▆█▇▆▇▇▆█▇▄▇▆▆▆▆▆▆▃▆▃▄▄▄▃▄▄▄▄▁▅
val_loss,▂▂▁▁▁▁▁▁▁▂▂▁▂▂▂▂▃▃▃▃▃▄▃▄▄▄▅▆▆▅▅▆█▆

0,1
classification_report,precis...
epoch,33
learning_rate,3e-05
test_accuracy,0.68069
test_loss,1.28712
train_accuracy,0.83028
train_loss,0.3384
val_accuracy,0.70034
val_loss,1.25255


wandb: Agent Starting Run: zkbsixho with config:
wandb: 	depth: 3
wandb: 	dim: 2048
wandb: 	lr: 2e-06




Early stopping triggered.


VBox(children=(Label(value='0.105 MB of 0.105 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇████
learning_rate,██████████████████████████▃▃▃▃▃▃▃▃▃▁▁▁▁▁
test_accuracy,▁▄▅▆▆▆▇▇▇▇▇▇▇▇▇█▇▇█▇█▇█████████████▇▇█▇▇
test_loss,█▇▄▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂
train_accuracy,▁▂▃▃▄▅▅▅▅▅▅▆▆▆▆▇▆▆▆▇▆▇▇▇▇▇▇▇▇███▇▇█▇████
train_loss,█▆▅▅▄▄▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▃▅▅▆▆▆▇▆▇▇▇▇▇▇▇▇██▇▇███▇█▇▇▇▇▇▇▇▇▇▇▇▇▇▇
val_loss,█▆▅▅▄▃▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂

0,1
classification_report,precis...
epoch,71
learning_rate,0.0
test_accuracy,0.72199
test_loss,0.59822
train_accuracy,0.82062
train_loss,0.39872
val_accuracy,0.7466
val_loss,0.58923


wandb: Agent Starting Run: y1is8mn8 with config:
wandb: 	depth: 12
wandb: 	dim: 512
wandb: 	lr: 1e-05




Early stopping triggered.


VBox(children=(Label(value='0.074 MB of 0.074 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
learning_rate,████████████████████████▃▃▃▃▃▃▃▃▃▁▁▁▁▁▁▁
test_accuracy,▁▄▄▆▆▇▇▇▇▇▇▇▇██▇█▇▆▇█▇▇▆▇██▇▇▇▇██▇▇▇▇▆▇▇
test_loss,█▅▄▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃
train_accuracy,▁▄▄▅▅▆▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇█▇▇▇█▇████████
train_loss,█▆▅▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▄▅▅▇▆▇▆▇▇▇▇█████▆▆█▇▇▇▆▆▇▇▆▇▇▇▆▇▅▆▆▆▆▆▆
val_loss,█▅▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃

0,1
classification_report,precis...
epoch,47
learning_rate,0.0
test_accuracy,0.71641
test_loss,0.67592
train_accuracy,0.83013
train_loss,0.37327
val_accuracy,0.71729
val_loss,0.66709


wandb: Agent Starting Run: srvl9krt with config:
wandb: 	depth: 3
wandb: 	dim: 2048
wandb: 	lr: 1e-05




Early stopping triggered.


VBox(children=(Label(value='0.066 MB of 0.066 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
learning_rate,█████████████████████▃▃▃▃▃▃▃▃▃▃▃▁▁▁▁▁▁▁▁
test_accuracy,▄▆▆▇▇▇▇▇▇▇▇██▅▅▆▆▅▆▆▅▄▅▄▄▃▄▁▅▃▃▅▄▄▃▄▄▅▅▄
test_loss,▄▃▂▂▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▃▃▄▃▅▄▅▅▇▅▆▇▇▇▇█▇█▇██
train_accuracy,▁▄▄▅▅▅▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇█▇▇███████████
train_loss,█▅▅▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▂▅▆▆▇▇▇█▇████▆▆▆▅▄▆▆▅▄▆▄▄▃▄▁▄▃▃▃▄▃▃▃▃▄▄▃
val_loss,▄▃▂▂▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▃▃▄▃▅▄▅▅▇▅▆▇▇▇▇█▇█▇██

0,1
classification_report,precis...
epoch,41
learning_rate,0.0
test_accuracy,0.68348
test_loss,1.21955
train_accuracy,0.85383
train_loss,0.30419
val_accuracy,0.68211
val_loss,1.15511


wandb: Agent Starting Run: 6o2b64ex with config:
wandb: 	depth: 12
wandb: 	dim: 256
wandb: 	lr: 2e-06




Early stopping triggered.


VBox(children=(Label(value='0.303 MB of 0.303 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▆▇▇▇███
learning_rate,█████████████████████████████▄▄▄▄▄▂▂▂▁▁▁
test_accuracy,▁▃▄▆▇▇▇▇▇▇▇▇█▇▇█▇▇██████████████████████
test_loss,█▅▅▄▃▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▁▄▄▅▅▅▅▆▅▆▆▆▆▆▇▇▇▇▇█▇▇█▇█▇██▇██████████
train_loss,█▆▆▅▄▃▄▄▃▃▃▃▃▂▂▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▂▃▃▅▆▇▇▇▇▇▇▇▇▇▇████████▇▇███▇████▇█████
val_loss,█▆▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
classification_report,precis...
epoch,222
learning_rate,0.0
test_accuracy,0.74397
test_loss,0.5423
train_accuracy,0.78694
train_loss,0.47046
val_accuracy,0.74986
val_loss,0.54746


wandb: Agent Starting Run: 7uzutv2r with config:
wandb: 	depth: 3
wandb: 	dim: 1024
wandb: 	lr: 2e-06




Early stopping triggered.


VBox(children=(Label(value='0.134 MB of 0.134 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▁▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇███
learning_rate,█████████████████████████████████▃▃▃▃▁▁▁
test_accuracy,▁▃▄▅▅▅▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█████████▇███████
test_loss,█▇▆▅▄▄▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▃▄▄▅▅▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇██▇▇█████████
train_loss,█▇▆▆▆▅▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▂▅▄▅▆▆▇▇▇▇▇▇▇▇██▇▇█████████████████████
val_loss,█▆▅▅▄▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
classification_report,precis...
epoch,93
learning_rate,0.0
test_accuracy,0.73862
test_loss,0.55779
train_accuracy,0.79406
train_loss,0.45494
val_accuracy,0.75572
val_loss,0.54654


wandb: Agent Starting Run: 4d0u86e6 with config:
wandb: 	depth: 12
wandb: 	dim: 2048
wandb: 	lr: 0.0001




Early stopping triggered.


VBox(children=(Label(value='0.055 MB of 0.055 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇▇███
learning_rate,██████████████▃▃▃▃▃▃▃▃▃▃▃▁▁▁▁▁▁▁▁▁
test_accuracy,▄▆██▆▆▅▅▇▄▅▆▇▄▁▁▄▃▃▄▃▆▃▄▃▂▃▂▂▂▃▃▁▃
test_loss,▂▁▁▁▁▂▁▂▂▂▂▂▂▂▃▃▃▅▅▄▄▃▄▃▅▄▅▆▇▇▆▆█▇
train_accuracy,▁▄▅▅▅▅▆▆▆▇▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇███████
train_loss,█▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▃▆█▆▅▅▄▅▇▄▇▇▆▄▂▃▄▃▂▄▃▅▂▃▁▃▃▂▁▁▂▂▁▂
val_loss,▂▁▁▁▁▂▂▂▁▂▂▂▂▃▃▃▄▅▆▄▄▄▄▄▆▄▅▆▇▇▆▆█▇

0,1
classification_report,precis...
epoch,33
learning_rate,3e-05
test_accuracy,0.66719
test_loss,1.43654
train_accuracy,0.84941
train_loss,0.30494
val_accuracy,0.66776
val_loss,1.33641


wandb: Agent Starting Run: 2eh1hdbe with config:
wandb: 	depth: 12
wandb: 	dim: 64
wandb: 	lr: 2e-06




Early stopping triggered.


VBox(children=(Label(value='0.634 MB of 0.634 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▂▂▂▂▂▂▂▂▂▃▃▃▃▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇███
learning_rate,███████████████████████████████████▄▃▃▁▁
test_accuracy,▁▅▅▆▆▆▇▇▇▇▇▇█▇▇█▇▇██████████████████████
test_loss,█▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▂▂▃▃▄▅▅▅▅▅▅▆▆▆▆▆▆▆▆▆▇▆▆▇▇▇▇▇▇█▇██▇█▇███
train_loss,█▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▂▃▄▄▅▅▅▆▆▆▇▇▇▇█████████████████████████
val_loss,█▆▅▅▄▄▄▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
classification_report,precis...
epoch,477
learning_rate,0.0
test_accuracy,0.74643
test_loss,0.57047
train_accuracy,0.76639
train_loss,0.51655
val_accuracy,0.75051
val_loss,0.56579


wandb: Agent Starting Run: 1yrntaqx with config:
wandb: 	depth: 12
wandb: 	dim: 2048
wandb: 	lr: 0.0001




Early stopping triggered.


VBox(children=(Label(value='0.061 MB of 0.061 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
learning_rate,██████████████████▃▃▃▃▃▃▃▃▃▃▃▁▁▁▁▁▁▁▁▁
test_accuracy,▁▄▇▇█▇▆█▇█▆▃▇▂▇▅▅▅▃▄▃▂▂▃▄▃▄▃▁▃▄▄▂▃▂▂▃▂
test_loss,▂▂▁▁▁▂▁▁▁▁▂▂▂▂▁▂▂▂▃▃▄▄▃▅▄▄▄▅▆▄▄▅▆▆▆█▇▇
train_accuracy,▁▄▅▅▆▆▆▅▆▆▆▆▆▇▆▇▇▇▇▇▇▇▇▇▇▇▇███████████
train_loss,█▃▃▃▂▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▅▆▇▆▆▆█▇▇▅▄▇▃▇▄▆▆▃▅▃▂▃▃▃▄▃▃▂▃▄▃▂▃▃▂▃▃
val_loss,▂▁▁▁▁▂▁▁▁▁▂▂▂▂▂▂▂▂▃▃▅▄▄▅▄▅▄▅▆▅▅▆▆▆▆█▇▇

0,1
classification_report,precis...
epoch,37
learning_rate,3e-05
test_accuracy,0.64989
test_loss,1.63976
train_accuracy,0.84997
train_loss,0.29584
val_accuracy,0.67884
val_loss,1.53915


wandb: Agent Starting Run: lu4ido6m with config:
wandb: 	depth: 3
wandb: 	dim: 1024
wandb: 	lr: 1e-05




Early stopping triggered.


VBox(children=(Label(value='0.071 MB of 0.071 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇████
learning_rate,███████████████████████▃▃▃▃▃▃▃▃▃▃▁▁▁▁▁▁▁
test_accuracy,▁▄▅▆▇▇▇▇▇▇████▇▇▇▇█▆▇▇▆▇▇█▇▇▇▇▇▆▆▆▇▇▆▆▇▅
test_loss,█▆▅▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▃▃▃▃▃▄▅▅▅▄▄▄▅
train_accuracy,▁▄▄▅▅▅▆▆▆▆▆▆▇▆▆▆▇▇▇▇▇▇▇▇▇▇█▇████████████
train_loss,█▆▅▅▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▃▄▆▆▇▇▇▆▇▇█▇█▇█▇██▇▇▇▆▇▇▆▇▆▅▆▆▆▆▅▅▅▅▆▆▅
val_loss,█▆▅▃▃▂▂▁▁▁▁▁▁▁▁▁▂▁▁▂▂▂▂▂▂▂▂▃▃▃▃▃▄▅▄▅▄▄▄▅

0,1
classification_report,precis...
epoch,45
learning_rate,0.0
test_accuracy,0.70569
test_loss,0.77947
train_accuracy,0.83413
train_loss,0.35521
val_accuracy,0.70688
val_loss,0.7633


wandb: Agent Starting Run: syxa0wgf with config:
wandb: 	depth: 3
wandb: 	dim: 512
wandb: 	lr: 2e-06




Early stopping triggered.


VBox(children=(Label(value='0.203 MB of 0.203 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇██
learning_rate,████████████████████████████▄▄▄▄▄▄▄▂▂▁▁▁
test_accuracy,▁▅▅▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇██████████████████████
test_loss,█▆▆▅▅▄▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▄▄▄▅▆▅▆▆▆▆▆▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇█▇▇██████████
train_loss,█▇▆▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁
val_accuracy,▁▂▂▃▃▆▆▆▆▆▇▇▇▇▇▇█████████▇█▇████████████
val_loss,█▆▅▅▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
classification_report,precis...
epoch,184
learning_rate,0.0
test_accuracy,0.74453
test_loss,0.54457
train_accuracy,0.79994
train_loss,0.45143
val_accuracy,0.75443
val_loss,0.53931


wandb: Sweep Agent: Waiting for job.
wandb: Job received.
wandb: Agent Starting Run: pqeskp86 with config:
wandb: 	depth: 3
wandb: 	dim: 2048
wandb: 	lr: 1e-05




Early stopping triggered.


VBox(children=(Label(value='0.062 MB of 0.062 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
learning_rate,███████████████████▃▃▃▃▃▃▃▃▃▃▃▁▁▁▁▁▁▁▁▁
test_accuracy,▃▅▅▆▇▇▇██▇██▆██▆▆▄▆█▇▆▄▆▆▃▄▅▁▄▄▄▅▄▃▃▃▄▄
test_loss,▄▃▂▂▁▁▁▁▁▂▁▁▂▂▁▂▂▃▃▃▃▄▄▄▄▅▅▆▇▆▆▇▆▇█████
train_accuracy,▁▃▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇█▇██▇█████████
train_loss,█▆▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▄▆▇▇▇▇██▆▇█▆▇▆▆▅▅▅▇▆▅▄▄▆▄▅▄▁▃▄▄▅▄▃▃▃▃▃
val_loss,▅▃▂▂▁▁▁▁▁▁▁▁▂▁▁▂▂▃▃▃▃▄▄▅▄▅▅▆▇▆▆▇▆▇█████

0,1
classification_report,precis...
epoch,38
learning_rate,0.0
test_accuracy,0.67902
test_loss,1.17359
train_accuracy,0.84536
train_loss,0.31738
val_accuracy,0.68016
val_loss,1.09903


wandb: Sweep Agent: Waiting for job.
wandb: Job received.
wandb: Agent Starting Run: co5o7mqo with config:
wandb: 	depth: 3
wandb: 	dim: 1024
wandb: 	lr: 1e-05




Early stopping triggered.


VBox(children=(Label(value='0.075 MB of 0.075 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
learning_rate,████████████████████████▃▃▃▃▃▃▃▃▃▁▁▁▁▁▁▁
test_accuracy,▁▃▄▅▇▇▇▇▇█▇▇█▇▇▆█▆▆▆▆▄▇▆▆▅▆▄▅▅▅▅▄▄▄▄▄▃▃▄
test_loss,█▆▄▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▂▄▂▂▂▃▃▃▃▃▄▅▅▅▆▅▆▆▆▆
train_accuracy,▁▄▅▅▅▆▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇█▇▇██████████
train_loss,█▆▅▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▃▄▄▆▇▇▇█▇▇▇▇█▇███▇█▅▇▅▇▆▆▆▆▆▆▅▅▅▅▄▅▅▅▄▄
val_loss,█▆▄▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▃▁▂▂▃▂▃▃▃▄▄▄▅▅▅▅▅▅

0,1
classification_report,precis...
epoch,48
learning_rate,0.0
test_accuracy,0.68125
test_loss,0.86952
train_accuracy,0.84108
train_loss,0.33556
val_accuracy,0.69645
val_loss,0.79296


wandb: Agent Starting Run: kcj6h4hb with config:
wandb: 	depth: 3
wandb: 	dim: 512
wandb: 	lr: 0.0001




Early stopping triggered.


VBox(children=(Label(value='0.047 MB of 0.047 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
learning_rate,███████████████▃▃▃▃▃▃▃▃▃▃▃▁▁▁▁▁▁▁▁▁
test_accuracy,▁▄▆▆▇█▆█▇█▇▇▆▇▆▆▆▆▆▅▃▆▅▄▃▃▅▅▃▂▄▃▄▃▃
test_loss,▃▂▂▁▁▁▁▁▁▁▁▁▂▁▂▂▂▃▃▃▃▃▄▅▅▅▅▅▇▇▆▆▇██
train_accuracy,▁▄▅▅▅▆▆▆▆▆▇▆▇▇▇▇▇▇▇▇▇███▇▇▇████████
train_loss,█▅▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▅▇▆▇▇▇████▇▇▇▆▆▆▅▆▅▄▅▄▄▄▃▅▅▄▃▄▅▃▂▂
val_loss,▄▂▁▁▁▁▁▁▁▁▁▁▂▂▂▃▃▃▃▃▄▃▄▅▅▅▅▆▇▇▆▆▇██

0,1
classification_report,precis...
epoch,34
learning_rate,3e-05
test_accuracy,0.66897
test_loss,1.31051
train_accuracy,0.83746
train_loss,0.32488
val_accuracy,0.67428
val_loss,1.25303


wandb: Agent Starting Run: easayzmz with config:
wandb: 	depth: 3
wandb: 	dim: 64
wandb: 	lr: 2e-06




VBox(children=(Label(value='0.615 MB of 0.615 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▂▂▂▂▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇████
learning_rate,██████████████████████████████████████▁▁
test_accuracy,▁▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇█▇█▇███████████████████
test_loss,█▆▆▅▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▂▂▂▃▃▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▆▇▇▇▇▇▇▇▇▇██▇██████
train_loss,█▇▆▆▅▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁
val_accuracy,▁▁▂▂▄▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇███████████
val_loss,█▇▅▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
classification_report,precis...
epoch,499
learning_rate,0.0
test_accuracy,0.73672
test_loss,0.58935
train_accuracy,0.76224
train_loss,0.55152
val_accuracy,0.75442
val_loss,0.58345


wandb: Agent Starting Run: x24d3cxv with config:
wandb: 	depth: 3
wandb: 	dim: 256
wandb: 	lr: 1e-05




Early stopping triggered.


VBox(children=(Label(value='0.067 MB of 0.067 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇██
learning_rate,███████████████████████████████▃▃▃▃▁▁▁▁▁
test_accuracy,▁▄▄▅▅▆▆▆▇▇▇▇▇▇▇▇▇██████▇███▇██████▇█████
test_loss,█▇▆▄▄▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▅▅▆▆▆▆▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇██▇█████████████
train_loss,█▆▆▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▃▄▄▄▅▄▆▆▆▆▇▆▆▇▇▇▇▇█████▅███▇█▇▇▇▇▇██▇██
val_loss,█▆▅▄▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
classification_report,precis...
epoch,75
learning_rate,0.0
test_accuracy,0.74632
test_loss,0.5596
train_accuracy,0.78973
train_loss,0.44758
val_accuracy,0.74854
val_loss,0.56641


wandb: Agent Starting Run: ur6srdss with config:
wandb: 	depth: 12
wandb: 	dim: 256
wandb: 	lr: 2e-06




Early stopping triggered.


VBox(children=(Label(value='0.285 MB of 0.285 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▅▆▆▇▇▇▇▇▇▇▇▇███
learning_rate,███████████████████████████████▄▄▄▄▄▂▂▂▁
test_accuracy,▁▅▆▆▇▇▇▇▇▇▇▇██▇█████████████████████████
test_loss,█▇▄▄▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▃▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇███████████████████
train_loss,█▇▇▆▅▄▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▅▅▆▇▇▇▇▇▇██████████████████████████████
val_loss,██▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
classification_report,precis...
epoch,209
learning_rate,0.0
test_accuracy,0.74922
test_loss,0.5566
train_accuracy,0.77521
train_loss,0.47637
val_accuracy,0.75506
val_loss,0.54301


wandb: Agent Starting Run: scf2nhv1 with config:
wandb: 	depth: 3
wandb: 	dim: 2048
wandb: 	lr: 2e-06




Early stopping triggered.


VBox(children=(Label(value='0.098 MB of 0.098 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇███
learning_rate,████████████████████████████▃▃▃▃▃▃▃▁▁▁▁▁
test_accuracy,▁▃▄▄▄▅▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇█▇█▇▇▇██▇▇▇▇▇▇▇▇
test_loss,█▇▆▅▅▄▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁
train_accuracy,▁▄▅▅▅▆▆▆▇▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇██▇████████████
train_loss,█▇▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▄▄▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇█████▇███████████▇██▇
val_loss,█▅▄▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
classification_report,precis...
epoch,66
learning_rate,0.0
test_accuracy,0.73371
test_loss,0.57144
train_accuracy,0.81568
train_loss,0.41332
val_accuracy,0.74921
val_loss,0.57074


wandb: Agent Starting Run: aasgcfb4 with config:
wandb: 	depth: 3
wandb: 	dim: 2048
wandb: 	lr: 0.0001




Early stopping triggered.


VBox(children=(Label(value='0.055 MB of 0.055 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇▇███
learning_rate,██████████████▃▃▃▃▃▃▃▃▃▃▃▁▁▁▁▁▁▁▁▁
test_accuracy,▃▅▅▇▆▇▆██▆▄▆▆▃▂▅▅▅▃▂▃▄▄▄▂▃▄▃▂▁▂▃▃▂
test_loss,▂▂▂▁▁▁▁▁▁▁▂▂▂▂▃▂▂▃▃▄▃▃▄▄▅▄▄▅▇█▇▆▇▇
train_accuracy,▁▅▅▅▅▆▆▆▆▇▆▆▇▆▆▇▇▇▇▇▇▇▇▇█▇████████
train_loss,█▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▆▄▇▇▆▇▆█▅▄▇▅▃▃▅▄▅▅▃▄▅▃▃▂▃▄▃▁▁▁▃▂▂
val_loss,▂▁▁▁▁▁▁▁▁▂▂▂▂▂▃▂▃▃▃▄▃▃▄▄▆▄▄▅▇█▇▆▇▇

0,1
classification_report,precis...
epoch,33
learning_rate,3e-05
test_accuracy,0.6567
test_loss,1.45858
train_accuracy,0.84529
train_loss,0.30809
val_accuracy,0.67428
val_loss,1.40694


wandb: Sweep Agent: Waiting for job.
wandb: Job received.
wandb: Agent Starting Run: p631dvxg with config:
wandb: 	depth: 3
wandb: 	dim: 512
wandb: 	lr: 1e-05




Early stopping triggered.


VBox(children=(Label(value='0.067 MB of 0.067 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇███
learning_rate,█████████████████████████▃▃▃▃▃▃▃▃▃▁▁▁▁▁▁
test_accuracy,▁▂▄▅▆▆▇▆▇▇▇▇█▇▇██▇▇▇███▇▇█▇▇█▇██████▇▇██
test_loss,█▅▄▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▄▅▆▅▆▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇██████████████
train_loss,█▆▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▂▃▄▅▆▇▆▇▇▇▇▇████████▇█▇▇███▇█████▇▇▇▇▇▇
val_loss,█▆▅▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
classification_report,precis...
epoch,55
learning_rate,0.0
test_accuracy,0.74475
test_loss,0.57715
train_accuracy,0.80973
train_loss,0.42089
val_accuracy,0.75312
val_loss,0.56424


wandb: Agent Starting Run: qhlukzcv with config:
wandb: 	depth: 3
wandb: 	dim: 512
wandb: 	lr: 0.0001




Early stopping triggered.


VBox(children=(Label(value='0.047 MB of 0.047 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
learning_rate,███████████████▃▃▃▃▃▃▃▃▃▃▃▁▁▁▁▁▁▁▁▁
test_accuracy,▁▆▅▇▇██▆▆▅▆▇▇▇▇▆▆▅▅▅▄▆▆▃▃▄▆▅▄▄▂▆▄▂▃
test_loss,▃▂▁▁▁▁▁▁▁▂▂▂▂▂▃▂▂▃▃▄▄▃▃▅▄▅▄▅▆▆▇▅▆██
train_accuracy,▁▄▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇██▇██████████
train_loss,█▅▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▄▅█▇▇▆▆▆▃▆█▇▆▅▄▅▄▄▄▄▄▆▃▄▄▃▃▁▃▂▄▃▂▁
val_loss,▄▂▁▁▁▁▁▁▁▂▂▁▂▂▃▂▃▃▃▄▄▄▄▅▅▅▅▆▇▇▇▅▆██

0,1
classification_report,precis...
epoch,34
learning_rate,3e-05
test_accuracy,0.68426
test_loss,1.258
train_accuracy,0.83795
train_loss,0.33058
val_accuracy,0.68406
val_loss,1.20141


wandb: Agent Starting Run: x1k80szh with config:
wandb: 	depth: 12
wandb: 	dim: 256
wandb: 	lr: 2e-06




Early stopping triggered.


VBox(children=(Label(value='0.292 MB of 0.292 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▂▂▂▂▂▃▃▃▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇███
learning_rate,███████████████████████████████████▃▃▃▃▁
test_accuracy,▁▆▆▇▇▇▇▇▇▇▇▇▇▇█▇▇██▇▇██▇████████████████
test_loss,█▇▆▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▄▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇███████████
train_loss,█▇▇▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▄▅▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇████████████████████
val_loss,█▇▅▅▅▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
classification_report,precis...
epoch,214
learning_rate,0.0
test_accuracy,0.74531
test_loss,0.55665
train_accuracy,0.78812
train_loss,0.47355
val_accuracy,0.75441
val_loss,0.54833


wandb: Agent Starting Run: csmjabkr with config:
wandb: 	depth: 3
wandb: 	dim: 256
wandb: 	lr: 1e-05




Early stopping triggered.


VBox(children=(Label(value='0.091 MB of 0.091 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇██
learning_rate,██████████████████████████████▃▃▃▃▃▁▁▁▁▁
test_accuracy,▁▃▄▅▅▆▆▇▇▇▇▇▇▇▇▇▇▇█▇▇▇▇▇██████▇██▇████▇█
test_loss,█▇▅▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▄▅▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇▇▇███████████████
train_loss,█▅▅▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▃▅▅▆▆▆▇▇▇▇▇▇▇▇▇▇▇██▇█▇███▇▇████████▇█▇█
val_loss,█▇▆▅▄▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
classification_report,precis...
epoch,74
learning_rate,0.0
test_accuracy,0.73259
test_loss,0.57458
train_accuracy,0.79114
train_loss,0.46242
val_accuracy,0.7492
val_loss,0.56753


wandb: Agent Starting Run: lo7h514a with config:
wandb: 	depth: 3
wandb: 	dim: 256
wandb: 	lr: 1e-05




Early stopping triggered.


VBox(children=(Label(value='0.092 MB of 0.092 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇███
learning_rate,██████████████████████████████▃▃▃▃▃▁▁▁▁▁
test_accuracy,▁▄▅▆▅▆▆▅▆▇▇▆▇▇▇▇▇█▇▇█▇████▇██▇█████████▇
test_loss,█▇▅▄▄▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▃▄▄▅▅▅▆▆▆▆▆▆▆▇▆▇▇▇▇▇▇▇▇▇▇▇███████▇█████
train_loss,█▇▇▆▆▅▅▅▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▅▆▆▆▆▇▇▇▇▇█▇████▇█▇████████████████████
val_loss,█▇▆▅▄▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
classification_report,precis...
epoch,75
learning_rate,0.0
test_accuracy,0.74029
test_loss,0.56262
train_accuracy,0.79201
train_loss,0.46128
val_accuracy,0.76157
val_loss,0.55529


wandb: Sweep Agent: Waiting for job.
wandb: Job received.
wandb: Agent Starting Run: 8dq6nnok with config:
wandb: 	depth: 3
wandb: 	dim: 256
wandb: 	lr: 2e-06




Early stopping triggered.


VBox(children=(Label(value='0.308 MB of 0.308 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▇█
learning_rate,████████████████████████████████████▄▄▄▁
test_accuracy,▁▃▃▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇█▇▇▇█▇███▇▇███████
test_loss,██▇▆▅▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▃▃▃▄▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇███████████████
train_loss,█▇▆▅▅▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▅▆▆▇▇▇▇█▇██████████████████████████████
val_loss,█▆▅▅▅▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
classification_report,precis...
epoch,264
learning_rate,0.0
test_accuracy,0.73984
test_loss,0.5487
train_accuracy,0.77695
train_loss,0.49361
val_accuracy,0.75639
val_loss,0.53821


wandb: Agent Starting Run: pabafewa with config:
wandb: 	depth: 3
wandb: 	dim: 512
wandb: 	lr: 1e-05




Early stopping triggered.


VBox(children=(Label(value='0.085 MB of 0.085 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇██
learning_rate,██████████████████████████▃▃▃▃▃▃▃▃▁▁▁▁▁▁
test_accuracy,▁▄▅▅▆▇▆▆▇▇▇▇▇▇██▇▇▇▇█▇▇██▇▇▇▇▇▇▇██▇█▇▇▇▇
test_loss,█▆▅▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂
train_accuracy,▁▄▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇█▇▇████████████
train_loss,█▅▄▄▄▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▄▅▆▆▆▆▆▇▇▇███▇███████▇███▇███████▇█████
val_loss,█▆▅▄▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
classification_report,precis...
epoch,56
learning_rate,0.0
test_accuracy,0.71629
test_loss,0.59376
train_accuracy,0.81213
train_loss,0.41196
val_accuracy,0.74463
val_loss,0.57069


wandb: Sweep Agent: Waiting for job.
wandb: Job received.
wandb: Agent Starting Run: zxd8cian with config:
wandb: 	depth: 12
wandb: 	dim: 256
wandb: 	lr: 1e-05




Early stopping triggered.


VBox(children=(Label(value='0.100 MB of 0.100 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
learning_rate,██████████████████████████████▃▃▃▃▃▃▁▁▁▁
test_accuracy,▁▄▅▆▆▆▇▇▇▇▇▇▇▇▇▆▇▇▇▇▇██▇▇█▇███▇█▇▇▇▇▇▇▇▇
test_loss,█▆▄▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂
train_accuracy,▁▁▂▃▃▄▄▄▅▅▅▅▅▅▆▆▆▅▆▆▆▆▆▆▆▆▆▇▇▇▇▇█▇██▇▇▇█
train_loss,█▆▅▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▃▆▆▆▆▇▇▇██▇▇█▇████▇█▇█████▇██▇████▇▇▇▇█
val_loss,█▆▅▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▂▂

0,1
classification_report,precis...
epoch,67
learning_rate,0.0
test_accuracy,0.72121
test_loss,0.59499
train_accuracy,0.80384
train_loss,0.41891
val_accuracy,0.744
val_loss,0.58152


wandb: Agent Starting Run: 9rsnppqs with config:
wandb: 	depth: 3
wandb: 	dim: 256
wandb: 	lr: 0.0001




Early stopping triggered.


VBox(children=(Label(value='0.056 MB of 0.056 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
learning_rate,███████████████████▃▃▃▃▃▃▃▃▃▃▃▁▁▁▁▁▁▁▁▁
test_accuracy,▂▁▄▄▅▆▇█▇▃▆█▇▆▅▇▂▅▅▇▅▂▅▆▆▅▂▃▂▁▅▄▂▃▃▃▅▁▁
test_loss,▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▂▃▃▃▂▃▄▃▃▃▄▅▅▅▄▄▅▇▇▇▆▆██
train_accuracy,▁▄▄▅▅▅▆▆▆▆▆▆▇▇▇▇▇▆▇▇▇▇▇▇█▇███▇█████████
train_loss,█▅▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▁▂▄▆▅▇▇█▄▇██▇▆▇▄▇▅▆▅▅▆▆▅▆▄▃▃▂▅▅▂▂▃▃▂▂▂
val_loss,▄▃▂▂▂▁▁▁▁▁▁▁▁▂▂▂▃▃▃▃▃▄▃▄▄▄▅▆▅▄▄▅▇▇▆▆▆▇█

0,1
classification_report,precis...
epoch,38
learning_rate,3e-05
test_accuracy,0.674
test_loss,1.1595
train_accuracy,0.82395
train_loss,0.34702
val_accuracy,0.67951
val_loss,1.16505


wandb: Agent Starting Run: 0ojpxjq2 with config:
wandb: 	depth: 3
wandb: 	dim: 1024
wandb: 	lr: 0.0001




Early stopping triggered.


VBox(children=(Label(value='0.057 MB of 0.057 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
learning_rate,███████████████▃▃▃▃▃▃▃▃▃▃▃▁▁▁▁▁▁▁▁▁
test_accuracy,▂▇▆▇█▇█▆█▆█▃▅▅▄█▄▂▆▅▄▃▆▃▂▂▃▃▃▂▁▃▄▃▂
test_loss,▂▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▃▃▃▃▄▄▃▃▄▅▅▅▆▇█▇▆▇▇
train_accuracy,▁▄▅▅▅▆▆▆▆▆▆▇▆▆▇▇▇▇▇▇███▇▇██████████
train_loss,█▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▂▅▆▅▇▆█▆█▅▇▄▄▅▅▇▅▂▅▄▃▄▅▂▃▂▂▄▃▁▂▂▃▂▃
val_loss,▃▁▁▁▁▁▁▁▁▂▁▂▂▂▂▂▃▃▃▄▄▄▄▃▄▆▅▅▆▇█▇▆▇▇

0,1
classification_report,precis...
epoch,34
learning_rate,3e-05
test_accuracy,0.66094
test_loss,1.41941
train_accuracy,0.84248
train_loss,0.30968
val_accuracy,0.68275
val_loss,1.2934


wandb: Agent Starting Run: pnsnoed1 with config:
wandb: 	depth: 3
wandb: 	dim: 64
wandb: 	lr: 2e-06




VBox(children=(Label(value='0.615 MB of 0.615 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇█████
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_accuracy,▁▃▄▄▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇████████
test_loss,██▆▅▅▄▄▄▄▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▃▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▆▇▇▆▇▇▇▇▇▇▇▇▇█████▇█▇█
train_loss,█▆▆▅▅▄▄▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▂▁▁▁▁▁▁
val_accuracy,▁▂▄▄▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇█████████████████████
val_loss,█▇▇▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
classification_report,precis...
epoch,499
learning_rate,0.0
test_accuracy,0.74085
test_loss,0.57851
train_accuracy,0.75884
train_loss,0.53395
val_accuracy,0.75116
val_loss,0.5795


wandb: Agent Starting Run: 7focbw4m with config:
wandb: 	depth: 3
wandb: 	dim: 256
wandb: 	lr: 2e-06




In [14]:
# Define the model with your parameters
d_model = 512 # Embedding dimension
num_classes = 4  # Star classification categories

# Define the training parameters
num_epochs = 500
lr = 1e-4
patience = 50   
depth = 10

# Define the config dictionary object
config = {"num_classes": num_classes, "batch_size": batch_size, "lr": lr, "patience": patience, "num_epochs": num_epochs, "d_model": d_model, "depth": depth}

# Initialize WandB project
wandb.init(project="lamost-mamba-test", entity="joaoc-university-of-southampton", config=config)
# Initialize and train the model
# Train the model using your `train_model_vit` or an adjusted training loop
model_mamba = StarClassifierMAMBA(d_model=d_model, num_classes=num_classes, input_dim=3748, n_layers=depth)
print(model_mamba)
# print number of parameters per layer
for name, param in model_mamba.named_parameters():
    print(name, param.numel())
print("Total number of parameters:", sum(p.numel() for p in model_mamba.parameters() if p.requires_grad))

# Move the model to device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_mamba = model_mamba.to(device)

# Train the model using your `train_model_vit` or an adjusted training loop
trained_model = train_model_mamba(
    model=model_mamba,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    num_epochs=num_epochs,
    lr=lr,
    max_patience=patience,
    device=device
)
# Save the model and finish WandB session
wandb.finish()

StarClassifierMAMBA(
  (mamba_layer): Mamba(
    (layers): ModuleList(
      (0-9): 10 x ResidualBlock(
        (mixer): MambaBlock(
          (in_proj): Linear(in_features=512, out_features=2048, bias=False)
          (conv1d): Conv1d(1024, 1024, kernel_size=(4,), stride=(1,), padding=(3,), groups=1024)
          (x_proj): Linear(in_features=1024, out_features=160, bias=False)
          (dt_proj): Linear(in_features=32, out_features=1024, bias=True)
          (out_proj): Linear(in_features=1024, out_features=512, bias=False)
        )
        (norm): RMSNorm()
      )
    )
  )
  (input_projection): Linear(in_features=3748, out_features=512, bias=True)
  (classifier): Sequential(
    (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=512, out_features=4, bias=True)
  )
)
mamba_layer.layers.0.mixer.A_log 65536
mamba_layer.layers.0.mixer.D 1024
mamba_layer.layers.0.mixer.in_proj.weight 1048576
mamba_layer.layers.0.mixer.conv1d.weight 4096
mamba_layer.



Early stopping triggered.


VBox(children=(Label(value='0.087 MB of 0.087 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
learning_rate,████████████████▃▃▃▃▃▃▃▃▃▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁
test_accuracy,▂▁▄▄▅▆▇▅▆▆▅▆▅▇▅▇▆█▇▇▆▆▅▆▇▆▅▆▆▆▇▅▆▆▅▆▆▆▆▇
test_loss,▅▃▂▁▁▁▁▁▂▃▃▃▃▄▄▅▄▄▄▅▅▅▆▅▆▆▆▆▅▆▇▇▇█▇█▆▇▇▇
train_accuracy,▁▄▄▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇██▇▇▇████████████
train_loss,█▄▄▄▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▄▇▇▅▇▅▇▆▇▄▇▆▅▃▆▇█▇▇█▆▇▇▇▄█▇▇▇▆▆▇▇██▇▆█▆
val_loss,▅▃▂▁▁▂▂▂▃▃▄▄▅▅▅▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇█▇▇▇▇▇▇█

0,1
classification_report,precis...
epoch,55
learning_rate,3e-05
test_accuracy,0.71123
test_loss,1.0999
train_accuracy,0.91923
train_loss,0.16057
val_accuracy,0.72704
val_loss,1.1119


In [16]:
# Define the model with your parameters
d_model = 126 # Embedding dimension
num_classes = 4  # Star classification categories

# Define the training parameters
num_epochs = 500
lr = 1e-4
patience = 50   
depth = 10

# Define the config dictionary object
config = {"num_classes": num_classes, "batch_size": batch_size, "lr": lr, "patience": patience, "num_epochs": num_epochs, "d_model": d_model, "depth": depth}

# Initialize WandB project
wandb.init(project="lamost-mamba-test", entity="joaoc-university-of-southampton", config=config)
# Initialize and train the model
# Train the model using your `train_model_vit` or an adjusted training loop
model_mamba = StarClassifierMAMBA(d_model=d_model, num_classes=num_classes, input_dim=3748, n_layers=depth)
print(model_mamba)
# print number of parameters per layer
for name, param in model_mamba.named_parameters():
    print(name, param.numel())
print("Total number of parameters:", sum(p.numel() for p in model_mamba.parameters() if p.requires_grad))

# Move the model to device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_mamba = model_mamba.to(device)

# Train the model using your `train_model_vit` or an adjusted training loop
trained_model = train_model_mamba(
    model=model_mamba,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    num_epochs=num_epochs,
    lr=lr,
    max_patience=patience,
    device=device
)
# Save the model and finish WandB session
wandb.finish()

StarClassifierMAMBA(
  (mamba_layer): Mamba(
    (layers): ModuleList(
      (0-9): 10 x ResidualBlock(
        (mixer): MambaBlock(
          (in_proj): Linear(in_features=126, out_features=504, bias=False)
          (conv1d): Conv1d(252, 252, kernel_size=(4,), stride=(1,), padding=(3,), groups=252)
          (x_proj): Linear(in_features=252, out_features=136, bias=False)
          (dt_proj): Linear(in_features=8, out_features=252, bias=True)
          (out_proj): Linear(in_features=252, out_features=126, bias=False)
        )
        (norm): RMSNorm()
      )
    )
  )
  (input_projection): Linear(in_features=3748, out_features=126, bias=True)
  (classifier): Sequential(
    (0): LayerNorm((126,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=126, out_features=4, bias=True)
  )
)
mamba_layer.layers.0.mixer.A_log 16128
mamba_layer.layers.0.mixer.D 252
mamba_layer.layers.0.mixer.in_proj.weight 63504
mamba_layer.layers.0.mixer.conv1d.weight 1008
mamba_layer.layers.0.mi



Early stopping triggered.


VBox(children=(Label(value='0.072 MB of 0.072 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇████
learning_rate,███████████████████▃▃▃▃▃▃▃▃▃▃▃▁▁▁▁▁▁▁▁▁▁
test_accuracy,▁▂▄▄▅▆▆▇▇▆▆▆▆█▆▆▆▆▆▆▆▆▇▇▇▆▆▇▆▇▇█▇▆█████▇
test_loss,█▄▃▂▂▁▂▁▁▁▂▂▂▂▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▆▆▆▆▆▆▇
train_accuracy,▁▄▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇████████████████████
train_loss,█▅▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▃▆▇▆▇▇▆██▆▇▆█▇▇▇█▇█▇▇▇█▇▇▇▇██▇▇█▇▇▇▇▇▇▇
val_loss,█▆▃▂▂▁▁▁▂▁▁▂▁▂▂▂▂▂▄▄▄▄▅▄▅▅▅▅▅▆▅▅▆▆▆▆▆▆▇▇

0,1
classification_report,precis...
epoch,61
learning_rate,3e-05
test_accuracy,0.70449
test_loss,0.98023
train_accuracy,0.8893
train_loss,0.22365
val_accuracy,0.71075
val_loss,0.98973
