<a href="https://colab.research.google.com/github/Ayon150/AI/blob/main/Mamba_code_to_Mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install required packages (run in Colab or your environment)
!pip install torch torchvision mamba-ssm pytorch_lightning

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
import pytorch_lightning as pl
from mamba_ssm import MambaLayer  # example import, refer to the specific API

# 1. Prepare MNIST as a binary parity task: even digits -> label 0, odd digits -> label 1.
transform = transforms.Compose(
    [
        transforms.ToTensor(),                          # image -> float tensor in [0, 1]
        transforms.Normalize(mean=(0.5,), std=(0.5,)),  # (x - 0.5) / 0.5 -> [-1, 1]
    ]
)


def _load_parity_split(is_train):
    """Download one MNIST split and replace each digit label with its parity (0/1)."""
    split = torchvision.datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    # Relabel in place by overwriting the label tensor; presumably MNIST.__getitem__
    # reads labels from `.targets` (true for current torchvision) — confirm on upgrade.
    split.targets = (split.targets % 2).long()
    return split


train_dataset = _load_parity_split(is_train=True)
test_dataset = _load_parity_split(is_train=False)

# Only the training stream is shuffled; evaluation order stays deterministic.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

# 2. Define a simple model: per-pixel embedding -> Mamba SSM block -> classification head
class MNIST_Mamba_Model(pl.LightningModule):
    """Even-vs-odd MNIST classifier that reads each image as a pixel sequence.

    A 28x28 image is flattened into a 784-step sequence of 1-channel tokens,
    linearly lifted to ``d_model`` channels, mixed by a Mamba block,
    mean-pooled over the sequence and mapped to a single logit.
    ``forward`` returns the probability of the "odd" class (label 1).

    Args:
        d_model: Channel width of the Mamba block (previously hard-coded to 64).
        lr: Adam learning rate (previously hard-coded to 1e-4).
    """

    def __init__(self, d_model=64, lr=1e-4):
        super().__init__()
        # NOTE(review): `mamba_ssm` exports `Mamba`; the module-level
        # `from mamba_ssm import MambaLayer` raises ImportError and should be
        # replaced. Import the real class here so this model does not depend on it.
        from mamba_ssm import Mamba

        self.save_hyperparameters()  # lets load_from_checkpoint restore d_model / lr
        self.seq_len = 28 * 28
        self.input_dim = 1
        self.lr = lr
        # Mamba maps [batch, seq, d_model] -> [batch, seq, d_model] for any seq
        # length, so 1-channel pixel tokens must first be projected to d_model.
        self.embed = nn.Linear(self.input_dim, d_model)
        self.mamba = Mamba(d_model=d_model)
        self.fc1 = nn.Linear(d_model, 32)
        self.fc_out = nn.Linear(32, 1)

    def _logits(self, x):
        """Return one unnormalised logit per image; x is [batch, 1, 28, 28]."""
        b = x.size(0)
        # reshape (unlike view) also accepts non-contiguous input.
        tokens = x.reshape(b, self.seq_len, self.input_dim)  # [batch, seq_len, input_dim]
        h = self.mamba(self.embed(tokens))                   # [batch, seq_len, d_model]
        h = h.mean(dim=1)                                    # pool the sequence dimension
        h = F.relu(self.fc1(h))
        return self.fc_out(h).squeeze(1)                     # [batch]

    def forward(self, x):
        """Return P(odd) for each image in ``x`` ([batch, 1, 28, 28]) as a [batch] tensor."""
        return torch.sigmoid(self._logits(x))

    def _shared_step(self, batch):
        """Compute (BCE loss, accuracy) for one batch of (images, parity labels)."""
        x, y = batch
        logits = self._logits(x)
        # Fusing the sigmoid into the loss is numerically stable, unlike
        # sigmoid() followed by binary_cross_entropy().
        loss = F.binary_cross_entropy_with_logits(logits, y.float())
        # logit > 0 corresponds to probability > 0.5.
        acc = ((logits > 0).long() == y).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)
        self.log('train_loss', loss, on_epoch=True)
        self.log('train_acc', acc, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_acc', acc, on_epoch=True, prog_bar=True)

    def test_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)
        self.log('test_loss', loss)
        self.log('test_acc', acc)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

# 3. Train the model
model = MNIST_Mamba_Model()
# Lightning >= 2.0 removed `Trainer(gpus=...)`; pick the device via accelerator/devices.
# NOTE(review): mamba_ssm's selective-scan kernels are CUDA-based, so the CPU
# fallback is likely to fail at the first forward pass — confirm.
trainer = pl.Trainer(
    max_epochs=10,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
)
# NOTE(review): the test split doubles as the validation split here, so the
# val_* metrics are not an independent held-out estimate.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=test_loader)
trainer.test(model, dataloaders=test_loader)


Collecting mamba-ssm
  Downloading mamba_ssm-2.2.6.post3.tar.gz (113 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 113.9/113.9 kB 5.8 MB/s eta 0:00:00