In [1]:
from torch import Tensor, nn
from zeta import MambaBlock
from zeta.nn import FeedForward
from zeta import MultiQueryAttention
from zeta.nn.modules.simple_rmsnorm import SimpleRMSNorm
from jamba.moe import MoE
from zeta.nn import OutputHead

In [None]:
class BalancedDataset(Dataset):
    def __init__(self, X, y, limit_per_label=1600):
        self.X = X
        self.y = y
        self.limit_per_label = limit_per_label
        self.classes = np.unique(y)
        self.indices = self.balance_classes()

    def balance_classes(self):
        indices = []
        for cls in self.classes:
            cls_indices = np.where(self.y == cls)[0]
            if len(cls_indices) > self.limit_per_label:
                cls_indices = np.random.choice(cls_indices, self.limit_per_label, replace=False)
            indices.extend(cls_indices)
        np.random.shuffle(indices)
        return indices

    def re_sample(self):
        self.indices = self.balance_classes()

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        index = self.indices[idx]
        return self.X[index], self.y[index]
# Custom Dataset for validation with limit per class
class BalancedValidationDataset(Dataset):
    def __init__(self, X, y, limit_per_label=400):
        self.X = X
        self.y = y
        self.limit_per_label = limit_per_label
        self.classes = np.unique(y)
        self.indices = self.balance_classes()

    def balance_classes(self):
        indices = []
        for cls in self.classes:
            cls_indices = np.where(self.y == cls)[0]
            if len(cls_indices) > self.limit_per_label:
                cls_indices = np.random.choice(cls_indices, self.limit_per_label, replace=False)
            indices.extend(cls_indices)
        np.random.shuffle(indices)
        return indices
    
    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        index = self.indices[idx]
        return self.X[index], self.y[index]

In [None]:
def train_model_mamba(
    model, train_loader, val_loader, test_loader, 
    num_epochs=500, lr=1e-4, max_patience=20, device='cuda'
):
    # Move model to device
    model = model.to(device)

    # Define optimizer, scheduler, and loss function
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=int(max_patience / 3), verbose=True
    )
    criterion = nn.CrossEntropyLoss()

    best_val_loss = float('inf')
    patience = max_patience

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss, train_accuracy = 0.0, 0.0

        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * X_batch.size(0)
            train_accuracy += (outputs.argmax(dim=1) == y_batch).float().mean().item()

        # Validation phase
        model.eval()
        val_loss, val_accuracy = 0.0, 0.0
        with torch.no_grad():
            for X_val, y_val in val_loader:
                X_val, y_val = X_val.to(device), y_val.to(device)
                outputs = model(X_val)
                loss = criterion(outputs, y_val)

                val_loss += loss.item() * X_val.size(0)
                val_accuracy += (outputs.argmax(dim=1) == y_val).float().mean().item()

        # Test phase and metric collection
        test_loss, test_accuracy = 0.0, 0.0
        y_true, y_pred = [], []
        with torch.no_grad():
            for X_test, y_test in test_loader:
                X_test, y_test = X_test.to(device), y_test.to(device)
                outputs = model(X_test)
                loss = criterion(outputs, y_test)

                test_loss += loss.item() * X_test.size(0)
                test_accuracy += (outputs.argmax(dim=1) == y_test).float().mean().item()
                y_true.extend(y_test.cpu().numpy())
                y_pred.extend(outputs.argmax(dim=1).cpu().numpy())

        # Update scheduler
        scheduler.step(val_loss / len(val_loader.dataset))

        # Log metrics to WandB
        wandb.log({
            "epoch": epoch,
            "train_loss": train_loss / len(train_loader.dataset),
            "val_loss": val_loss / len(val_loader.dataset),
            "train_accuracy": train_accuracy / len(train_loader),
            "val_accuracy": val_accuracy / len(val_loader),
            "learning_rate": optimizer.param_groups[0]['lr'],
            "test_loss": test_loss / len(test_loader.dataset),
            "test_accuracy": test_accuracy / len(test_loader),
            "confusion_matrix": wandb.plot.confusion_matrix(
                probs=None, y_true=y_true, preds=y_pred, class_names=np.unique(y_true)
            ),
            "classification_report": classification_report(
                y_true, y_pred, target_names=[str(i) for i in range(len(np.unique(y_true)))]
            )
        })

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience = max_patience
            best_model = model.state_dict()
        else:
            patience -= 1
            if patience <= 0:
                print("Early stopping triggered.")
                break

    # Load the best model weights
    model.load_state_dict(best_model)
    return model



In [2]:
class TransformerMoEBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        heads: int,
        num_experts: int,
        num_experts_per_token: int,
        *args,
        **kwargs,
    ):
        """
        Initializes a TransformerMoEBlock.

        Args:
            dim (int): The dimension of the input tensor.
            heads (int): The number of attention heads.
            num_experts (int): The total number of experts.
            num_experts_per_token (int): The number of experts per token.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.
        """
        super().__init__()
        self.dim = dim
        self.heads = heads
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_token

        self.attn = MultiQueryAttention(dim, heads)
        self.moe = MoE(
            dim,
            num_experts=num_experts,
            hidden_dim=dim * 4,
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the TransformerMoEBlock.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: Output tensor after applying the TransformerMoEBlock.
        """
        skip = x
        x = SimpleRMSNorm(self.dim)(x)
        x, _, _ = self.attn(x) + x

        x = SimpleRMSNorm(self.dim)(x)
        moe_out, _ = self.moe(x)
        x = moe_out + skip
        return x


class TransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        heads: int,
        *args,
        **kwargs,
    ):
        """
        Initializes a TransformerBlock.

        Args:
            dim (int): Dimension of the input tensor.
            heads (int): Number of attention heads.
            num_experts (int): Number of experts.
            num_experts_per_token (int): Number of experts per token.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.
        """
        super().__init__()
        self.dim = dim
        self.heads = heads

        self.attn = MultiQueryAttention(dim, heads)
        self.ffn = FeedForward(
            dim,
            dim,
            4,
            swish=True,
            post_act_ln=True,
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the TransformerBlock.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: Output tensor after applying the TransformerBlock.
        """
        skip = x
        x = SimpleRMSNorm(self.dim)(x)
        x, _, _ = self.attn(x)
        x += skip

        skip_two = x

        x = SimpleRMSNorm(self.dim)(x)
        x = self.ffn(x) + skip_two
        return x


class MambaMoELayer(nn.Module):
    def __init__(
        self,
        dim: int,
        d_state: int,
        d_conv: int,
        num_experts: int = 8,
        num_experts_per_token: int = 2,
        *args,
        **kwargs,
    ):
        """
        Initialize the MambaMoELayer.

        Args:
            dim (int): Dimension of the input tensor.
            d_state (int): Dimension of the state tensor.
            d_conv (int): Dimension of the convolutional tensor.
            num_experts (int, optional): Number of experts. Defaults to 8.
            num_experts_per_token (int, optional): Number of experts per token. Defaults to 2.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.
        """
        super().__init__()
        self.dim = dim
        self.d_state = d_state
        self.d_conv = d_conv
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_token

        # Mamba
        self.mamba = MambaBlock(
            dim,
            depth=1,
            d_state=d_state,
            d_conv=d_conv,
        )

        # MoE
        self.moe = MoE(
            dim,
            num_experts=num_experts,
            hidden_dim=dim * 4,
        )

    def forward(self, x: Tensor):
        """
        Forward pass of the MambaMoELayer.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: Output tensor after applying the MambaMoELayer.
        """
        skip = x

        x = SimpleRMSNorm(self.dim)(x)
        x = self.mamba(x) + x

        x = SimpleRMSNorm(self.dim)(x)
        moe_out, _ = self.moe(x)
        x = moe_out + skip
        return x


class JambaBlock(nn.Module):
    """
    JambaBlock is a module that combines MambaBlock, MambaMoELayer, and TransformerBlock
    to process input tensors.

    Args:
        dim (int): The input dimension.
        d_state (int): The dimension of the state in MambaBlock and MambaMoELayer.
        d_conv (int): The dimension of the convolutional output in MambaBlock and MambaMoELayer.
        heads (int): The number of attention heads in TransformerBlock.
        num_experts (int, optional): The number of experts in MambaMoELayer. Defaults to 8.
        num_experts_per_token (int, optional): The number of experts per token in MambaMoELayer. Defaults to 2.

    Attributes:
        dim (int): The input dimension.
        d_state (int): The dimension of the state in MambaBlock and MambaMoELayer.
        d_conv (int): The dimension of the convolutional output in MambaBlock and MambaMoELayer.
        heads (int): The number of attention heads in TransformerBlock.
        num_experts (int): The number of experts in MambaMoELayer.
        num_experts_per_tok (int): The number of experts per token in MambaMoELayer.
        mamba_layer (MambaBlock): The MambaBlock layer.
        mamba_moe_layer (MambaMoELayer): The MambaMoELayer layer.
        transformer (TransformerBlock): The TransformerBlock layer.

    """

    def __init__(
        self,
        dim: int,
        d_state: int,
        d_conv: int,
        heads: int,
        num_experts: int = 8,
        num_experts_per_token: int = 2,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.d_state = d_state
        self.d_conv = d_conv
        self.heads = heads
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_token

        # Mamba
        self.mamba_layer = MambaBlock(
            dim,
            depth=1,
            d_state=d_state,
            d_conv=d_conv,
        )

        # Mamba MoE layer
        self.mamba_moe_layer = MambaMoELayer(
            dim,
            d_state,
            d_conv,
            num_experts,
            num_experts_per_token,
        )

        # Transformer
        self.transformer = TransformerBlock(
            dim,
            heads,
        )

    def forward(self, x: Tensor) -> Tensor:
        x = self.mamba_layer(x)
        x = self.mamba_moe_layer(x)
        x = self.transformer(x)
        x = self.mamba_moe_layer(x)
        x = self.mamba_layer(x)
        x = self.mamba_moe_layer(x)
        return x


class Jamba(nn.Module):
    """
    Jamba model implementation.

    Args:
        dim (int): Dimension of the model.
        depth (int): Depth of the model.
        num_tokens (int): Number of tokens.
        max_seq_len (int): Maximum sequence length.
        d_state (int): State dimension.
        d_conv (int): Convolutional dimension.
        heads (int): Number of attention heads.
        num_experts (int, optional): Number of experts. Defaults to 8.
        num_experts_per_token (int, optional): Number of experts per token. Defaults to 2.
        pre_emb_norm (bool, optional): Whether to normalize the embeddings. Defaults to False.
        return_embeddings (bool, optional): Whether to return the embeddings. Defaults to False.

    Attributes:
        dim (int): Dimension of the model.
        depth (int): Depth of the model.
        d_state (int): State dimension.
        d_conv (int): Convolutional dimension.
        heads (int): Number of attention heads.
        num_experts (int): Number of experts.
        num_experts_per_tok (int): Number of experts per token.
        pre_emb_norm (bool): Whether to normalize the embeddings.
        return_embeddings (bool): Whether to return the embeddings.
        layers (nn.ModuleList): List of JambaBlock layers.
        embed (nn.Embedding): Embedding layer.
        norm (nn.LayerNorm or nn.Identity): Normalization layer.

    """

    def __init__(
        self,
        dim: int,
        depth: int,
        num_tokens: int,
        d_state: int,
        d_conv: int,
        heads: int,
        num_experts: int = 8,
        num_experts_per_token: int = 2,
        pre_emb_norm: bool = False,
        return_embeddings: bool = False,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.d_state = d_state
        self.d_conv = d_conv
        self.heads = heads
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_token
        self.pre_emb_norm = pre_emb_norm
        self.return_embeddings = return_embeddings

        # Layers
        self.layers = nn.ModuleList(
            [
                JambaBlock(
                    dim,
                    d_state,
                    d_conv,
                    heads,
                    num_experts,
                    num_experts_per_token,
                )
                for _ in range(depth)
            ]
        )

        # Pre Emb
        self.embed = nn.Embedding(num_tokens, dim)

        # Embedding Norm
        self.norm = (
            nn.LayerNorm(dim) if pre_emb_norm else nn.Identity()
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the Jamba model.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: Output tensor.

        """
        # Embed the input tensor to transform
        # From tokens -> tensors
        x = self.embed(x)

        # Normalize the embeddings
        x = self.norm(x)

        # Apply the layers
        for layer in self.layers:
            x = layer(x)

        if self.return_embeddings:
            return x
        else:
            # return the logits
            return OutputHead(self.dim, -1)(x)

In [4]:
batch_size = 512

# Example usage
if __name__ == "__main__":
    # Load and preprocess your data (example from original script)
    # Load and preprocess data
    X = pd.read_pickle("Pickles/trainv2.pkl")
    y = X["label"]
    label_mapping = {'star': 0, 'binary_star': 1, 'galaxy': 2, 'agn': 3}
    y = y.map(label_mapping).values
    X = X.drop(["parallax", "ra", "dec", "ra_error", "dec_error", "parallax_error", "pmra", "pmdec", "pmra_error", "pmdec_error", 
                "phot_g_mean_flux", "flagnopllx", "phot_g_mean_flux_error", "phot_bp_mean_flux", "phot_rp_mean_flux", 
                "phot_bp_mean_flux_error", "phot_rp_mean_flux_error", "label", "obsid"], axis=1).values
    
    # Read test data
    X_test = pd.read_pickle("Pickles/testv2.pkl")
    y_test = X_test["label"].map(label_mapping).values
    X_test = X_test.drop(["parallax", "ra", "dec", "ra_error", "dec_error", "parallax_error", "pmra", "pmdec", "pmra_error", "pmdec_error", 
                "phot_g_mean_flux", "flagnopllx", "phot_g_mean_flux_error", "phot_bp_mean_flux", "phot_rp_mean_flux", 
                "phot_bp_mean_flux_error", "phot_rp_mean_flux_error", "label", "obsid"], axis=1).values
    
    # Split data
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

    # Clear memory
    del X, y
    gc.collect()

    # Convert to torch tensors and create datasets
    X_train = torch.tensor(X_train, dtype=torch.float32).unsqueeze(1)
    X_val = torch.tensor(X_val, dtype=torch.float32).unsqueeze(1)
    y_train = torch.tensor(y_train, dtype=torch.long)
    y_val = torch.tensor(y_val, dtype=torch.long)

    train_dataset = BalancedDataset(X_train, y_train)
    val_dataset = BalancedValidationDataset(X_val, y_val)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(BalancedValidationDataset(torch.tensor(X_test, dtype=torch.float32).unsqueeze(1),
                                                    torch.tensor(y_test, dtype=torch.long)), batch_size=batch_size, shuffle=False)
    


NameError: name 'pd' is not defined

In [3]:
from jamba.model import Jamba

# Define the model with your parameters
d_model = 128 # Embedding dimension
num_classes = 4  # Star classification categories

# Define the training parameters
num_epochs = 500
lr = 1e-4
patience = 30   
depth = 10

# Define the config dictionary object
config = {"num_classes": num_classes, "batch_size": batch_size, "lr": lr, "patience": patience, "num_epochs": num_epochs, "d_model": d_model, "depth": depth}

# Initialize WandB project
wandb.init(project="lamost-jamba-test", entity="joaoc-university-of-southampton", config=config)

# Define hyperparameters
num_epochs = 100
lr = 1e-3
patience = 10

# Initialize the Jamba model
model_mamba = Jamba(
    dim=3748,                # Input dimensionality
    depth=4,                # Number of layers
    num_tokens=100,         # Token size (adapt to your case)
    d_state=d_model,            # Hidden state dimensionality
    d_conv=128,             # Convolutional layers dimensionality
    heads=8,                # Number of attention heads
    num_experts=8,          # Number of expert networks
    num_experts_per_token=2 # Experts per token
)

print(model_mamba)
# Print number of parameters per layer
for name, param in model_mamba.named_parameters():
    print(name, param.numel())
print("Total number of parameters:", sum(p.numel() for p in model_mamba.parameters() if p.requires_grad))

# Move the model to device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_mamba = model_mamba.to(device)

# Train the model
trained_model = train_model_mamba(
    model=model_mamba,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    num_epochs=num_epochs,
    lr=lr,
    max_patience=patience,
    device=device
)

# Save the model and finish WandB session
wandb.finish()

    

NameError: name 'batch_size' is not defined