In [17]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.metrics import classification_report
import wandb
import gc
import pandas as pd
from sklearn.model_selection import train_test_split
import mambapy
from mambapy.mamba import Mamba, MambaConfig
from torch import nn, optim
from timm.models.vision_transformer import _cfg, Mlp, Block


In [21]:
class BalancedDataset(Dataset):
    def __init__(self, X, y, limit_per_label=1600):
        self.X = X
        self.y = y
        self.limit_per_label = limit_per_label
        self.classes = np.unique(y)
        self.indices = self.balance_classes()

    def balance_classes(self):
        indices = []
        for cls in self.classes:
            cls_indices = np.where(self.y == cls)[0]
            if len(cls_indices) > self.limit_per_label:
                cls_indices = np.random.choice(cls_indices, self.limit_per_label, replace=False)
            indices.extend(cls_indices)
        np.random.shuffle(indices)
        return indices

    def re_sample(self):
        self.indices = self.balance_classes()

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        index = self.indices[idx]
        return self.X[index], self.y[index]
# Custom Dataset for validation with limit per class
class BalancedValidationDataset(Dataset):
    def __init__(self, X, y, limit_per_label=400):
        self.X = X
        self.y = y
        self.limit_per_label = limit_per_label
        self.classes = np.unique(y)
        self.indices = self.balance_classes()

    def balance_classes(self):
        indices = []
        for cls in self.classes:
            cls_indices = np.where(self.y == cls)[0]
            if len(cls_indices) > self.limit_per_label:
                cls_indices = np.random.choice(cls_indices, self.limit_per_label, replace=False)
            indices.extend(cls_indices)
        np.random.shuffle(indices)
        return indices
    
    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        index = self.indices[idx]
        return self.X[index], self.y[index]
    
# Create Datasets
class BalancedDatasetFusion(Dataset):
    def __init__(self, X_conv, X_gaia, y, limit_per_label=1600):
        self.X_conv = X_conv
        self.X_gaia = X_gaia
        self.y = y
        self.limit_per_label = limit_per_label
        self.classes = np.unique(y)
        self.indices = self.balance_classes()

    def balance_classes(self):
        indices = []
        for cls in self.classes:
            cls_indices = np.where(self.y == cls)[0]
            if len(cls_indices) > self.limit_per_label:
                cls_indices = np.random.choice(cls_indices, self.limit_per_label, replace=False)
            indices.extend(cls_indices)
        np.random.shuffle(indices)
        return indices

    def re_sample(self):
        self.indices = self.balance_classes()

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        index = self.indices[idx]
        return self.X_conv[index], self.X_gaia[index], self.y[index]

In [3]:
class StarClassifierMAMBA(nn.Module):
    def __init__(self, d_model, num_classes, d_state=64, d_conv=4, input_dim=17, n_layers=6):
        super(StarClassifierMAMBA, self).__init__()
        self.d_model = d_model
        self.num_classes = num_classes

        # MAMBA layer initialization
        config = MambaConfig(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            n_layers=n_layers

        )
        self.mamba_layer = Mamba(config)

        # Input projection to match the MAMBA layer dimension
        self.input_projection = nn.Linear(input_dim, d_model)

        # Fully connected classifier head
        self.classifier = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, num_classes)
        )


    def forward(self, x):
        x = self.input_projection(x)  # Ensure the input has the correct dimension
        x = x.unsqueeze(1)  # Adds a sequence dimension (L=1).
        x = self.mamba_layer(x)
        x = x.mean(dim=1)  # Pooling operation for classification
        x = self.classifier(x)
        return x

In [None]:
def train_model_mamba(
    model, train_loader, val_loader, test_loader, 
    num_epochs=500, lr=1e-4, max_patience=20, device='cuda'
):
    # Move model to device
    model = model.to(device)

    # Define optimizer, scheduler, and loss function
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=int(max_patience / 3), verbose=True
    )
    criterion = nn.CrossEntropyLoss()

    best_val_loss = float('inf')
    patience = max_patience

    for epoch in range(num_epochs):
        # Resample training and validation data
        train_loader.dataset.re_sample()
        val_loader.dataset.balance_classes()

        # Training phase
        model.train()
        train_loss, train_accuracy = 0.0, 0.0

        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * X_batch.size(0)
            train_accuracy += (outputs.argmax(dim=1) == y_batch).float().mean().item()

        # Validation phase
        model.eval()
        val_loss, val_accuracy = 0.0, 0.0
        with torch.no_grad():
            for X_val, y_val in val_loader:
                X_val, y_val = X_val.to(device), y_val.to(device)
                outputs = model(X_val)
                loss = criterion(outputs, y_val)

                val_loss += loss.item() * X_val.size(0)
                val_accuracy += (outputs.argmax(dim=1) == y_val).float().mean().item()

        # Test phase and metric collection
        test_loss, test_accuracy = 0.0, 0.0
        y_true, y_pred = [], []
        with torch.no_grad():
            for X_test, y_test in test_loader:
                X_test, y_test = X_test.to(device), y_test.to(device)
                outputs = model(X_test)
                loss = criterion(outputs, y_test)

                test_loss += loss.item() * X_test.size(0)
                test_accuracy += (outputs.argmax(dim=1) == y_test).float().mean().item()
                y_true.extend(y_test.cpu().numpy())
                y_pred.extend(outputs.argmax(dim=1).cpu().numpy())

        # Update scheduler
        scheduler.step(val_loss / len(val_loader.dataset))

        # Log metrics to WandB
        wandb.log({
            "epoch": epoch,
            "train_loss": train_loss / len(train_loader.dataset),
            "val_loss": val_loss / len(val_loader.dataset),
            "train_accuracy": train_accuracy / len(train_loader),
            "val_accuracy": val_accuracy / len(val_loader),
            "learning_rate": optimizer.param_groups[0]['lr'],
            "test_loss": test_loss / len(test_loader.dataset),
            "test_accuracy": test_accuracy / len(test_loader),
            "confusion_matrix": wandb.plot.confusion_matrix(
                probs=None, y_true=y_true, preds=y_pred, class_names=np.unique(y_true)
            ),
            "classification_report": classification_report(
                y_true, y_pred, target_names=[str(i) for i in range(len(np.unique(y_true)))]
            )
        })

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience = max_patience
            best_model = model.state_dict()
        else:
            patience -= 1
            if patience <= 0:
                print("Early stopping triggered.")
                break

    # Load the best model weights
    model.load_state_dict(best_model)
    return model

In [32]:
# Fusion dataset opening
batch_size = 2

if __name__ == "__main__":
        # Load and preprocess data
        X = pd.read_pickle("Pickles/trainv2.pkl")
        gaia_features = ["parallax", "ra", "dec", "ra_error", "dec_error", "parallax_error", "pmra", "pmdec", "pmra_error", "pmdec_error", 
        "phot_g_mean_flux", "flagnopllx", "phot_g_mean_flux_error", "phot_bp_mean_flux", "phot_rp_mean_flux", 
        "phot_bp_mean_flux_error", "phot_rp_mean_flux_error"]

        # Extract Gaia and LASMOST data
        X_gaia = X[gaia_features].values
        X_lamost = X.drop(gaia_features + ["label"], axis=1).values
        y = X["label"]

        # Mapping labels to integers
        label_mapping = {'star': 0, 'binary_star': 1, 'galaxy': 2, 'agn': 3}
        y = y.map(label_mapping).values

        # Read test data
        X_test = pd.read_pickle("Pickles/testv2.pkl")
        X_test_gaia = X_test[gaia_features].values
        X_test_conv = X_test.drop(gaia_features + ["label"], axis=1).values
        y_test = X_test["label"]
        y_test = y_test.map(label_mapping).values


        # Split data into train and validation
        X_train_conv, X_val_conv, X_train_gaia, X_val_gaia, y_train, y_val = train_test_split(X_lamost, X_gaia, y, test_size=0.2, random_state=42)

        # Convert to PyTorch tensors
        X_train_conv = torch.tensor(X_train_conv, dtype=torch.float32).unsqueeze(1)
        X_val_conv = torch.tensor(X_val_conv, dtype=torch.float32).unsqueeze(1)
        X_train_gaia = torch.tensor(X_train_gaia, dtype=torch.float32)
        X_val_gaia = torch.tensor(X_val_gaia, dtype=torch.float32)
        y_train = torch.tensor(y_train, dtype=torch.long)
        y_val = torch.tensor(y_val, dtype=torch.long)
        X_test_conv = torch.tensor(X_test_conv, dtype=torch.float32).unsqueeze(1)
        X_test_gaia = torch.tensor(X_test_gaia, dtype=torch.float32)
        y_test = torch.tensor(y_test, dtype=torch.long)

        # Create DataLoaders
        train_dataset = BalancedDatasetFusion(X_train_conv, X_train_gaia, y_train)
        val_dataset = BalancedDatasetFusion(X_val_conv, X_val_gaia, y_val)
        test_dataset = BalancedDatasetFusion(X_test_conv, X_test_gaia, y_test)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [15]:
class CrossAttentionBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., theta=10000,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, has_mlp=True):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = CrossAttention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, theta=theta)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.has_mlp = has_mlp
        if has_mlp:
            self.norm2 = norm_layer(dim)
            mlp_hidden_dim = int(dim * mlp_ratio)
            self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x[:, 0:1, ...] + self.drop_path(self.attn(self.norm1(x)))
        if self.has_mlp:
            x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x
class CrossAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., theta=10000):
        super().__init__()
        self.num_heads = num_heads
        self.dim = dim
        self.scale = qk_scale or (dim // num_heads) ** -0.5
        self.theta = theta

        self.wq = nn.Linear(dim, dim, bias=qkv_bias)
        self.wk = nn.Linear(dim, dim, bias=qkv_bias)
        self.wv = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        
        # Initialize rotary frequencies
        self.freqs = init_rope_frequencies(dim, num_heads, theta)

    def forward(self, x):
        B, N, C = x.shape
        q = self.wq(x[:, 0:1, ...]).view(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        k = self.wk(x).view(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        v = self.wv(x).view(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        # Apply rotary position embedding
        q_rot, k_rot = apply_rotary_position_embeddings(self.freqs.to(x.device), q, k)

        # Attention calculation with rotated embeddings
        attn = (q_rot @ k_rot.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        

        x = (attn @ v).transpose(1, 2).reshape(B, 1, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
def init_rope_frequencies(dim, num_heads, theta, rotate=False):
    # Adjust the size of `mag` to match the per-head dimension
    per_head_dim = dim // ( num_heads)
    mag = 1 / (theta ** (torch.arange(0, per_head_dim).float() / (dim // num_heads))).unsqueeze(0)

    # Adjust `angles` accordingly
    angles = torch.rand(num_heads, per_head_dim//2) * 2 * torch.pi if rotate else torch.zeros(num_heads, per_head_dim//2)

    # Compute `freq_x` and `freq_y` with matching dimensions
    freq_x = mag * torch.cat([torch.cos(angles), torch.cos(torch.pi / 2 + angles)], dim=-1)
    freq_y = mag * torch.cat([torch.sin(angles), torch.sin(torch.pi / 2 + angles)], dim=-1)

    return torch.stack([freq_x, freq_y], dim=0)


def apply_rotary_position_embeddings(freqs, q, k):
    # Ensure `cos` and `sin` have the same shape as `q` and `k` by adding unsqueeze
    cos, sin = freqs[0].unsqueeze(1), freqs[1].unsqueeze(1)    
    
    # Broadcast `cos` and `sin` to match `q` and `k` dimensions
    q_rot = (q * cos) + (torch.roll(q, shifts=1, dims=-1) * sin)
    k_rot = (k * cos) + (torch.roll(k, shifts=1, dims=-1) * sin)
    
    return q_rot, k_rot

In [78]:
class DualMambaClassifier(nn.Module):
    def __init__(self, gaia_dim, spectra_dim, d_model, num_classes, d_state=64, d_conv=4, n_layers=6):
        super(DualMambaClassifier, self).__init__()
        # MAMBA model for Gaia data
        self.gaia_model = StarClassifierMAMBA(d_model, num_classes, d_state, d_conv, gaia_dim, n_layers)
        # MAMBA model for spectra data
        self.spectra_model = StarClassifierMAMBA(d_model, num_classes, d_state, d_conv, spectra_dim, n_layers)
        # Cross attention block
        self.gaia_model.input_projection = nn.Linear(gaia_dim, d_model)  # Gaia input
        self.spectra_model.input_projection = nn.Linear(spectra_dim, d_model)  # Spectra input

        print("Shape of Gaia input projection: ", self.gaia_model.input_projection)
        print("Shape of Spectra input projection: ", self.spectra_model.input_projection)
        self.cross_attention = CrossAttentionBlock(dim=d_model*2, num_heads=8)
        print("Shape of cross attention: ", self.cross_attention)

        # Classification head
        self.classifier = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, num_classes)
        )

    def forward(self, gaia_x, spectra_x):
        print(f'Gaia input shape: {gaia_x.shape}')
        print(f'Spectra input shape: {spectra_x.shape}')
        
        # Ensure the inputs are correctly aligned
        if gaia_x.shape[1] != 17 or spectra_x.shape[2] != 3749:
            raise ValueError("Input dimensions do not match the expected dimensions for Gaia and spectra data.")

        gaia_features = self.gaia_model.mamba_layer(self.gaia_model.input_projection(gaia_x).unsqueeze(1))
        spectra_features = self.spectra_model.mamba_layer(self.spectra_model.input_projection(spectra_x.squeeze(1)))

        # Cross attention: allowing information sharing between modalities
        combined_features = torch.cat([gaia_features, spectra_features], dim=1)
        print(f'Combined features shape: {combined_features.shape}')

        fused_features = self.cross_attention(combined_features)
        print(f'Fused features shape: {fused_features.shape}')

        # Global average pooling and classification
        pooled_features = fused_features.mean(dim=1)
        print(f'Pooled features shape: {pooled_features.shape}')

        output = self.classifier(pooled_features)

        return output

# Training loop
def train_model_mamba_fusion(
    model, train_loader, val_loader, test_loader, 
    num_epochs=500, lr=1e-4, max_patience=20, device='cuda'
):
    # Move model to device
    model = model.to(device)

    # Define optimizer, scheduler, and loss function
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=int(max_patience / 3), verbose=True
    )
    criterion = nn.CrossEntropyLoss()

    best_val_loss = float('inf')
    patience = max_patience

    for epoch in range(num_epochs):
        # Resample training and validation data if needed
        train_loader.dataset.re_sample()
        val_loader.dataset.balance_classes()

        # Training phase
        model.train()
        train_loss, train_accuracy = 0.0, 0.0

        for spectra_batch, gaia_batch, y_batch in train_loader:
            spectra_batch, gaia_batch, y_batch = (
                spectra_batch.to(device),
                gaia_batch.to(device),
                y_batch.to(device)
            )
            optimizer.zero_grad()
            # Ensure the correct order of inputs
            outputs = model(gaia_batch, spectra_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * spectra_batch.size(0)
            train_accuracy += (outputs.argmax(dim=1) == y_batch).float().mean().item()

        # Validation phase
        model.eval()
        val_loss, val_accuracy = 0.0, 0.0
        with torch.no_grad():
            for spectra_val, gaia_val, y_val in val_loader:
                spectra_val, gaia_val, y_val = (
                    spectra_val.to(device),
                    gaia_val.to(device),
                    y_val.to(device)
                )
                outputs = model(gaia_val, spectra_val)
                loss = criterion(outputs, y_val)

                val_loss += loss.item() * spectra_val.size(0)
                val_accuracy += (outputs.argmax(dim=1) == y_val).float().mean().item()

        # Test phase and metric collection
        test_loss, test_accuracy = 0.0, 0.0
        y_true, y_pred = [], []
        with torch.no_grad():
            for spectra_test, gaia_test, y_test in test_loader:
                spectra_test, gaia_test, y_test = (
                    spectra_test.to(device),
                    gaia_test.to(device),
                    y_test.to(device)
                )
                outputs = model(gaia_test, spectra_test)
                loss = criterion(outputs, y_test)

                test_loss += loss.item() * spectra_test.size(0)
                test_accuracy += (outputs.argmax(dim=1) == y_test).float().mean().item()
                y_true.extend(y_test.cpu().numpy())
                y_pred.extend(outputs.argmax(dim=1).cpu().numpy())

        # Update scheduler
        scheduler.step(val_loss / len(val_loader.dataset))

        # Log metrics to WandB
        wandb.log({
            "epoch": epoch,
            "train_loss": train_loss / len(train_loader.dataset),
            "val_loss": val_loss / len(val_loader.dataset),
            "train_accuracy": train_accuracy / len(train_loader),
            "val_accuracy": val_accuracy / len(val_loader),
            "learning_rate": optimizer.param_groups[0]['lr'],
            "test_loss": test_loss / len(test_loader.dataset),
            "test_accuracy": test_accuracy / len(test_loader),
            "confusion_matrix": wandb.plot.confusion_matrix(
                probs=None, y_true=y_true, preds=y_pred, class_names=np.unique(y_true)
            ),
            "classification_report": classification_report(
                y_true, y_pred, target_names=[str(i) for i in range(len(np.unique(y_true)))]
            )
        })

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience = max_patience
            best_model = model.state_dict()
        else:
            patience -= 1
            if patience <= 0:
                print("Early stopping triggered.")
                break

    # Load the best model weights
    model.load_state_dict(best_model)
    return model

In [79]:
# Define the model with your parameters
d_model = 256 # Embedding dimension
num_classes = 4  # Star classification categories

# Define the training parameters
num_epochs = 500
lr = 2e-7
patience = 50   
depth = 10

# Initialize the dual model
dual_model = DualMambaClassifier(
    gaia_dim=17, spectra_dim=3748, d_model = d_model, num_classes=4, n_layers=6
)

# Move the model to the device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dual_model = dual_model.to(device)

# Training loop
dual_model = train_model_mamba_fusion(
    dual_model, train_loader, val_loader, test_loader, 
    num_epochs=num_epochs, lr=lr, max_patience=patience, device=device
)



Shape of Gaia input projection:  Linear(in_features=17, out_features=256, bias=True)
Shape of Spectra input projection:  Linear(in_features=3748, out_features=256, bias=True)
Shape of cross attention:  CrossAttentionBlock(
  (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (attn): CrossAttention(
    (wq): Linear(in_features=512, out_features=512, bias=False)
    (wk): Linear(in_features=512, out_features=512, bias=False)
    (wv): Linear(in_features=512, out_features=512, bias=False)
    (attn_drop): Dropout(p=0.0, inplace=False)
    (proj): Linear(in_features=512, out_features=512, bias=True)
    (proj_drop): Dropout(p=0.0, inplace=False)
  )
  (drop_path): Identity()
  (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (mlp): Mlp(
    (fc1): Linear(in_features=512, out_features=2048, bias=True)
    (act): GELU(approximate='none')
    (drop1): Dropout(p=0.0, inplace=False)
    (norm): Identity()
    (fc2): Linear(in_features=2048, out_features=512, bias=

RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x3749 and 3748x256)

In [None]:
# Define the model with your parameters
d_model = 2048 # Embedding dimension
num_classes = 4  # Star classification categories

# Define the training parameters
num_epochs = 500
lr = 2e-7
patience = 50   
depth = 10

# Define the config dictionary object
config = {"num_classes": num_classes, "batch_size": batch_size, "lr": lr, "patience": patience, "num_epochs": num_epochs, "d_model": d_model, "depth": depth}

# Initialize WandB project
wandb.init(project="lamost-mamba-test", entity="joaoc-university-of-southampton", config=config)
# Initialize and train the model
# Train the model using your `train_model_vit` or an adjusted training loop
model_mamba = StarClassifierMAMBA(d_model=d_model, num_classes=num_classes, input_dim=3748, n_layers=depth)
print(model_mamba)
# print number of parameters per layer
for name, param in model_mamba.named_parameters():
    print(name, param.numel())
print("Total number of parameters:", sum(p.numel() for p in model_mamba.parameters() if p.requires_grad))

# Move the model to device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_mamba = model_mamba.to(device)

# Train the model using your `train_model_vit` or an adjusted training loop
trained_model = train_model_mamba(
    model=model_mamba,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    num_epochs=num_epochs,
    lr=lr,
    max_patience=patience,
    device=device
)
# Save the model and finish WandB session
wandb.finish()