In [None]:

from tqdm import tqdm


In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import OxfordIIITPet
import torchvision.transforms as T
from einops import rearrange, repeat
import math
import os

# -------------------------
# Selective Scan
# -------------------------
def _linear_scan(a, b):
    """Inclusive scan of h[t] = a[t] * h[t-1] + b[t] along dim 1, with h[-1] = 0.

    a and b have the same shape (batch, seq_len, ...). Adjacent steps are merged
    pairwise, the half-length problem is solved recursively, and the remaining
    (even) positions are filled in from their odd predecessors. Only products and
    sums are used (no exp of long cumulative sums, no division), so the result is
    numerically stable for any sequence length. Recursion depth is ~log2(seq_len).
    """
    seq_len = b.shape[1]
    if seq_len <= 1:
        return b
    n_pairs = seq_len // 2
    a_even, a_odd = a[:, 0:2 * n_pairs:2], a[:, 1:2 * n_pairs:2]
    b_even, b_odd = b[:, 0:2 * n_pairs:2], b[:, 1:2 * n_pairs:2]
    # Composing step 2k then step 2k+1 gives one step (a_odd*a_even, a_odd*b_even + b_odd);
    # scanning those merged steps yields h at odd positions 1, 3, 5, ...
    h_odd = _linear_scan(a_odd * a_even, a_odd * b_even + b_odd)
    # Even positions: h[0] = b[0]; h[2k] = a[2k] * h[2k-1] + b[2k] for k >= 1.
    n_even = (seq_len + 1) // 2
    h_even = torch.cat([b[:, :1], a[:, 2::2] * h_odd[:, :n_even - 1] + b[:, 2::2]], dim=1)
    # Interleave even/odd results back into sequence order.
    h = torch.stack([h_even[:, :n_pairs], h_odd], dim=2).flatten(1, 2)
    if seq_len % 2:
        h = torch.cat([h, h_even[:, -1:]], dim=1)
    return h


def selective_scan(u, delta, A, B, C, D):
    """Run the discretized selective SSM over a sequence.

    Shapes (from the einsum signatures):
        u, delta: (batch, seq_len, d)
        A:        (d, n)
        B, C:     (batch, seq_len, n)
        D:        (d,)
    Returns y of shape (batch, seq_len, d) with
        h[t] = exp(delta[t] * A) * h[t-1] + delta[t] * u[t] * B[t],   h[-1] = 0
        y[t] = sum_n h[t] * C[t] + u[t] * D

    The recurrence is evaluated with a parallel scan (see _linear_scan). It avoids
    exp(cumsum(delta * A)) followed by a division. That formulation underflows to 0
    for long sequences (e.g. 4096 tokens) and silently zeroes the state.
    """
    decay = torch.exp(torch.einsum('bld,dn->bldn', delta, A))    # A_bar, entries in (0, 1] for A < 0
    drive = torch.einsum('bld,bld,bln->bldn', delta, u, B)       # B_bar * u
    hidden = _linear_scan(decay, drive)
    y = torch.einsum('bldn,bln->bld', hidden, C)
    return y + u * D

# -------------------------
# Model Args
# -------------------------
class ModelArgs:
    """Hyperparameters for the Oxford-IIIT Pet Mamba-UNet and its training loop.

    Any base field can be overridden by keyword, e.g. ``ModelArgs(num_layers=2)``.
    Unknown names raise TypeError. The derived fields ``model_internal_dim`` and
    ``delta_t_rank`` are always recomputed from the (possibly overridden) base values.
    """

    def __init__(self, **overrides):
        self.model_input_dims = 32           # token width fed to the Mamba blocks; must match encoder channels
        self.model_states = 32               # SSM state size n per channel
        self.projection_expand_factor = 2    # model_internal_dim = factor * model_input_dims
        self.conv_kernel_size = 4
        self.conv_use_bias = True
        self.dense_use_bias = False
        self.layer_id = -1
        self.seq_length = 128
        self.num_layers = 5
        self.dropout_rate = 0.2
        self.use_lm_head = False
        self.num_classes = 37
        # nn.CrossEntropyLoss applies log-softmax to its input itself, so the model must
        # emit raw scores; a final softmax here would normalize twice and cap the loss.
        self.final_activation = 'linear'
        self.loss = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam
        self.metrics = ['accuracy']

        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError(f"ModelArgs got an unexpected keyword argument {name!r}")
            setattr(self, name, value)

        # Derived sizes (computed after overrides so they stay consistent).
        self.model_internal_dim = self.projection_expand_factor * self.model_input_dims
        self.delta_t_rank = math.ceil(self.model_input_dims / 16)

# -------------------------
# Mamba + Residual Blocks
# -------------------------
class MambaBlock(nn.Module):
    """Selective state-space (Mamba) mixer.

    forward(x): x is (batch, seq_len, model_input_dims); the result has the same shape.
    x is projected to 2 * model_internal_dim channels and split into an SSM branch
    (depthwise causal conv -> SiLU -> selective scan) and a SiLU gate. The gate
    multiplies the scan output before it is projected back to model_input_dims.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.in_proj = nn.Linear(args.model_input_dims, args.model_internal_dim * 2, bias=False)
        # Depthwise conv (groups == channels). It pads k-1 on both ends; forward() keeps
        # only the first seq_len outputs, so output t depends on inputs <= t only.
        self.conv1d = nn.Conv1d(args.model_internal_dim, args.model_internal_dim,
                                kernel_size=args.conv_kernel_size,
                                padding=args.conv_kernel_size - 1,
                                groups=args.model_internal_dim,
                                bias=args.conv_use_bias)
        # Produces the per-token low-rank delta, B and C inputs of the selective scan.
        self.x_proj = nn.Linear(args.model_internal_dim, args.delta_t_rank + args.model_states * 2, bias=False)
        self.delta_proj = nn.Linear(args.delta_t_rank, args.model_internal_dim)

        # A is used as -exp(A_log), i.e. initialized to -[1, 2, ..., model_states] for every channel.
        A_vals = torch.arange(1, args.model_states + 1).float()
        self.A_log = nn.Parameter(torch.log(repeat(A_vals, 'n -> d n', d=args.model_internal_dim)))
        self.D = nn.Parameter(torch.ones(args.model_internal_dim))  # per-channel weight of the u * D skip term
        self.out_proj = nn.Linear(args.model_internal_dim, args.model_input_dims, bias=args.dense_use_bias)

    def forward(self, x):
        seq_len = x.shape[1]
        x1, res = self.in_proj(x).chunk(2, dim=-1)

        # Conv1d works on (batch, channels, length); trim the right padding to stay causal.
        x1 = rearrange(x1, 'b l d -> b d l')
        x1 = self.conv1d(x1)[..., :seq_len]
        x1 = rearrange(x1, 'b d l -> b l d')
        x1 = F.silu(x1)

        A = -torch.exp(self.A_log)  # strictly negative, so exp(delta * A) < 1 (decaying state)
        x_dbl = self.x_proj(x1)
        delta, B, C = torch.split(x_dbl, [self.args.delta_t_rank, self.args.model_states, self.args.model_states], dim=-1)
        delta = F.softplus(self.delta_proj(delta))  # positive per-token step size

        y = selective_scan(x1, delta, A, B, C, self.D)
        y = y * F.silu(res)
        return self.out_proj(y)

class ResidualBlock(nn.Module):
    """Pre-norm residual wrapper around a MambaBlock.

    forward(x) returns x + mixer(norm(x)). The LayerNorm normalizes the last
    dimension, which therefore has to be args.model_input_dims.
    """

    def __init__(self, args):
        super().__init__()
        self.norm = nn.LayerNorm(args.model_input_dims)
        self.mixer = MambaBlock(args)

    def forward(self, x):
        residual = x
        mixed = self.mixer(self.norm(x))
        return mixed + residual

# -------------------------
# Mamba-UNet Architecture
# -------------------------
class MambaUNet(nn.Module):
    """Conv encoder -> stack of Mamba residual blocks over flattened pixels -> conv decoder.

    Input: (batch, in_channels, H, W). The encoder halves H and W (MaxPool2d(2)), and
    every pooled pixel becomes one token of width args.model_input_dims. The decoder
    doubles the spatial size again. Output: (batch, num_classes, 2*(H//2), 2*(W//2))
    per-pixel class scores. These are softmax probabilities when
    args.final_activation == 'softmax', and raw scores otherwise.

    Args:
        args: a ModelArgs-like config.
        in_channels: channels of the input image (3 for RGB).
    """

    def __init__(self, args, in_channels=3):
        super().__init__()
        # The encoder output is fed straight into LayerNorm(model_input_dims), so the
        # conv width has to follow the config instead of being hard-coded.
        width = args.model_input_dims
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, width, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.res_blocks = nn.Sequential(*[ResidualBlock(args) for _ in range(args.num_layers)])
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(width, width, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Conv2d(width, args.num_classes, kernel_size=1)
        )
        self.activation = nn.Softmax(dim=1) if args.final_activation == 'softmax' else nn.Identity()

    def forward(self, x):
        x = self.encoder(x)
        b, c, h, w = x.size()
        # (b, c, h, w) -> (b, h*w, c): one token per pooled pixel, in row-major order.
        x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
        x = self.res_blocks(x)
        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
        x = self.decoder(x)
        return self.activation(x)

# -------------------------
# Dataset and Training
# -------------------------
def get_dataloaders(batch_size=4, image_size=128, root='.', num_workers=2):
    """Build train/test DataLoaders for Oxford-IIIT Pet breed classification.

    Args:
        batch_size: samples per batch for both loaders.
        image_size: images are resized to (image_size, image_size). After one 2x
            max-pool in the encoder this gives (image_size // 2) ** 2 Mamba tokens,
            so smaller values make training much cheaper.
        root: dataset directory (downloaded there if missing).
        num_workers: DataLoader worker processes.

    Returns:
        (train_loader, test_loader); labels are the 'category' targets.
    """
    transform = T.Compose([
        T.Resize((image_size, image_size)),
        T.ToTensor()
    ])

    common = dict(root=root, download=True, target_types='category', transform=transform)
    train_ds = OxfordIIITPet(split='trainval', **common)
    test_ds = OxfordIIITPet(split='test', **common)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, test_loader

def train(model, train_loader, optimizer, criterion, device):
    """Run one training epoch.

    The model's per-pixel class maps are global-average-pooled to (batch, num_classes)
    before `criterion` is applied.

    Returns:
        (mean_loss, accuracy) over all samples seen; (0.0, 0.0) for an empty loader.
    """
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    pbar = tqdm(train_loader, desc='Training', leave=False)
    for imgs, labels in pbar:
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        outputs = outputs.mean(dim=(2, 3))  # global average pooling
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * imgs.size(0)  # criterion returns a batch mean
        pred = outputs.argmax(dim=1)
        correct += (pred == labels).sum().item()
        total += imgs.size(0)

        pbar.set_postfix(loss=total_loss/total, acc=correct/total)

    mean_loss = total_loss / total if total else 0.0
    accuracy = correct / total if total else 0.0
    print(f"Train Loss: {mean_loss:.4f} | Accuracy: {accuracy:.4f}")
    return mean_loss, accuracy


def test(model, test_loader, criterion, device):
    model.eval()
    total_loss, correct, total = 0, 0, 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            outputs = outputs.mean(dim=(2, 3))
            loss = criterion(outputs, labels)

            total_loss += loss.item() * imgs.size(0)
            pred = outputs.argmax(dim=1)
            correct += (pred == labels).sum().item()
            total += imgs.size(0)

    print(f"Test Loss: {total_loss/total:.4f} | Accuracy: {correct/total:.4f}")

# -------------------------
# Main
# -------------------------
def main(epochs=10, lr=1e-3, batch_size=4, seed=0):
    """Train MambaUNet on Oxford-IIIT Pet breeds, evaluating on the test split after every epoch.

    Args:
        epochs: number of passes over the training split.
        lr: learning rate for the optimizer class stored in ModelArgs.optimizer.
        batch_size: mini-batch size for both loaders.
        seed: torch RNG seed for reproducible initialization and shuffling;
            None leaves the global RNG untouched.
    """
    if seed is not None:
        torch.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args = ModelArgs()
    model = MambaUNet(args).to(device)
    train_loader, test_loader = get_dataloaders(batch_size=batch_size)
    optimizer = args.optimizer(model.parameters(), lr=lr)
    criterion = args.loss

    for epoch in range(1, epochs + 1):
        print(f"\nEpoch {epoch}")
        train(model, train_loader, optimizer, criterion, device)
        test(model, test_loader, criterion, device)


# In a notebook kernel __name__ is '__main__', so executing this cell starts training.
if __name__ == '__main__':
    main()




Epoch 1


Training:   1%|▏         | 13/920 [08:21<9:48:11, 38.91s/it, acc=0.0385, loss=3.61]

In [None]:
################# Train on MNIST

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as T
from einops import rearrange, repeat
import math
from tqdm import tqdm

# -------------------------
# Selective Scan
# -------------------------
def _linear_scan(a, b):
    """Inclusive scan of h[t] = a[t] * h[t-1] + b[t] along dim 1, with h[-1] = 0.

    a and b have the same shape (batch, seq_len, ...). Adjacent steps are merged
    pairwise, the half-length problem is solved recursively, and the remaining
    (even) positions are filled in from their odd predecessors. Only products and
    sums are used (no exp of long cumulative sums, no division), so the result is
    numerically stable for any sequence length. Recursion depth is ~log2(seq_len).
    """
    seq_len = b.shape[1]
    if seq_len <= 1:
        return b
    n_pairs = seq_len // 2
    a_even, a_odd = a[:, 0:2 * n_pairs:2], a[:, 1:2 * n_pairs:2]
    b_even, b_odd = b[:, 0:2 * n_pairs:2], b[:, 1:2 * n_pairs:2]
    # Composing step 2k then step 2k+1 gives one step (a_odd*a_even, a_odd*b_even + b_odd);
    # scanning those merged steps yields h at odd positions 1, 3, 5, ...
    h_odd = _linear_scan(a_odd * a_even, a_odd * b_even + b_odd)
    # Even positions: h[0] = b[0]; h[2k] = a[2k] * h[2k-1] + b[2k] for k >= 1.
    n_even = (seq_len + 1) // 2
    h_even = torch.cat([b[:, :1], a[:, 2::2] * h_odd[:, :n_even - 1] + b[:, 2::2]], dim=1)
    # Interleave even/odd results back into sequence order.
    h = torch.stack([h_even[:, :n_pairs], h_odd], dim=2).flatten(1, 2)
    if seq_len % 2:
        h = torch.cat([h, h_even[:, -1:]], dim=1)
    return h


def selective_scan(u, delta, A, B, C, D):
    """Run the discretized selective SSM over a sequence.

    Shapes (from the einsum signatures):
        u, delta: (batch, seq_len, d)
        A:        (d, n)
        B, C:     (batch, seq_len, n)
        D:        (d,)
    Returns y of shape (batch, seq_len, d) with
        h[t] = exp(delta[t] * A) * h[t-1] + delta[t] * u[t] * B[t],   h[-1] = 0
        y[t] = sum_n h[t] * C[t] + u[t] * D

    The recurrence is evaluated with a parallel scan (see _linear_scan). It avoids
    exp(cumsum(delta * A)) followed by a division. That formulation underflows to 0
    for long sequences (e.g. 4096 tokens) and silently zeroes the state.
    """
    decay = torch.exp(torch.einsum('bld,dn->bldn', delta, A))    # A_bar, entries in (0, 1] for A < 0
    drive = torch.einsum('bld,bld,bln->bldn', delta, u, B)       # B_bar * u
    hidden = _linear_scan(decay, drive)
    y = torch.einsum('bldn,bln->bld', hidden, C)
    return y + u * D

# -------------------------
# Model Args
# -------------------------
class ModelArgs:
    """Hyperparameters for the MNIST Mamba-UNet and its training loop.

    Any base field can be overridden by keyword, e.g. ``ModelArgs(num_layers=2)``.
    Unknown names raise TypeError. The derived fields ``model_internal_dim`` and
    ``delta_t_rank`` are always recomputed from the (possibly overridden) base values.
    """

    def __init__(self, **overrides):
        self.model_input_dims = 16           # token width fed to the Mamba blocks; must match encoder channels
        self.model_states = 16               # SSM state size n per channel
        self.projection_expand_factor = 1    # model_internal_dim = factor * model_input_dims
        self.conv_kernel_size = 4
        self.conv_use_bias = True
        self.dense_use_bias = False
        self.layer_id = -1
        self.seq_length = 128
        self.num_layers = 3
        self.dropout_rate = 0.2
        self.use_lm_head = False
        self.num_classes = 10  # MNIST has 10 classes
        # nn.CrossEntropyLoss applies log-softmax to its input itself, so the model must
        # emit raw scores; a final softmax here would normalize twice and cap the loss.
        self.final_activation = 'linear'
        self.loss = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam
        self.metrics = ['accuracy']

        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError(f"ModelArgs got an unexpected keyword argument {name!r}")
            setattr(self, name, value)

        # Derived sizes (computed after overrides so they stay consistent).
        self.model_internal_dim = self.projection_expand_factor * self.model_input_dims
        self.delta_t_rank = math.ceil(self.model_input_dims / 16)

# -------------------------
# Mamba + Residual Blocks
# -------------------------
class MambaBlock(nn.Module):
    """Selective state-space (Mamba) mixer.

    forward(x): x is (batch, seq_len, model_input_dims); the result has the same shape.
    x is projected to 2 * model_internal_dim channels and split into an SSM branch
    (depthwise causal conv -> SiLU -> selective scan) and a SiLU gate. The gate
    multiplies the scan output before it is projected back to model_input_dims.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.in_proj = nn.Linear(args.model_input_dims, args.model_internal_dim * 2, bias=False)
        # Depthwise conv (groups == channels). It pads k-1 on both ends; forward() keeps
        # only the first seq_len outputs, so output t depends on inputs <= t only.
        self.conv1d = nn.Conv1d(args.model_internal_dim, args.model_internal_dim,
                                kernel_size=args.conv_kernel_size,
                                padding=args.conv_kernel_size - 1,
                                groups=args.model_internal_dim,
                                bias=args.conv_use_bias)
        # Produces the per-token low-rank delta, B and C inputs of the selective scan.
        self.x_proj = nn.Linear(args.model_internal_dim, args.delta_t_rank + args.model_states * 2, bias=False)
        self.delta_proj = nn.Linear(args.delta_t_rank, args.model_internal_dim)

        # A is used as -exp(A_log), i.e. initialized to -[1, 2, ..., model_states] for every channel.
        A_vals = torch.arange(1, args.model_states + 1).float()
        self.A_log = nn.Parameter(torch.log(repeat(A_vals, 'n -> d n', d=args.model_internal_dim)))
        self.D = nn.Parameter(torch.ones(args.model_internal_dim))  # per-channel weight of the u * D skip term
        self.out_proj = nn.Linear(args.model_internal_dim, args.model_input_dims, bias=args.dense_use_bias)

    def forward(self, x):
        seq_len = x.shape[1]
        x1, res = self.in_proj(x).chunk(2, dim=-1)

        # Conv1d works on (batch, channels, length); trim the right padding to stay causal.
        x1 = rearrange(x1, 'b l d -> b d l')
        x1 = self.conv1d(x1)[..., :seq_len]
        x1 = rearrange(x1, 'b d l -> b l d')
        # SiLU, as in the Mamba architecture and the Oxford-IIIT variant of this block.
        x1 = F.silu(x1)

        A = -torch.exp(self.A_log)  # strictly negative, so exp(delta * A) < 1 (decaying state)
        x_dbl = self.x_proj(x1)
        delta, B, C = torch.split(x_dbl, [self.args.delta_t_rank, self.args.model_states, self.args.model_states], dim=-1)
        delta = F.softplus(self.delta_proj(delta))  # positive per-token step size

        y = selective_scan(x1, delta, A, B, C, self.D)
        y = y * F.silu(res)
        return self.out_proj(y)

class ResidualBlock(nn.Module):
    """Pre-norm residual wrapper around a MambaBlock.

    forward(x) returns x + mixer(norm(x)). The LayerNorm normalizes the last
    dimension, which therefore has to be args.model_input_dims.
    """

    def __init__(self, args):
        super().__init__()
        self.norm = nn.LayerNorm(args.model_input_dims)
        self.mixer = MambaBlock(args)

    def forward(self, x):
        residual = x
        mixed = self.mixer(self.norm(x))
        return mixed + residual

# -------------------------
# Mamba-UNet Architecture
# -------------------------
class MambaUNet(nn.Module):
    """Conv encoder -> stack of Mamba residual blocks over flattened pixels -> conv decoder.

    Input: (batch, in_channels, H, W). The encoder halves H and W (MaxPool2d(2)), and
    every pooled pixel becomes one token of width args.model_input_dims. The decoder
    doubles the spatial size again. Output: (batch, num_classes, 2*(H//2), 2*(W//2))
    per-pixel class scores. These are softmax probabilities when
    args.final_activation == 'softmax', and raw scores otherwise.

    Args:
        args: a ModelArgs-like config.
        in_channels: channels of the input image (1 for grayscale MNIST).
    """

    def __init__(self, args, in_channels=1):
        super().__init__()
        # The encoder output is fed straight into LayerNorm(model_input_dims), so the
        # conv width has to follow the config instead of being hard-coded.
        width = args.model_input_dims
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, width, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.res_blocks = nn.Sequential(*[ResidualBlock(args) for _ in range(args.num_layers)])
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(width, width, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Conv2d(width, args.num_classes, kernel_size=1)
        )
        self.activation = nn.Softmax(dim=1) if args.final_activation == 'softmax' else nn.Identity()

    def forward(self, x):
        x = self.encoder(x)
        b, c, h, w = x.size()
        # (b, c, h, w) -> (b, h*w, c): one token per pooled pixel, in row-major order.
        x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
        x = self.res_blocks(x)
        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
        x = self.decoder(x)
        return self.activation(x)

# -------------------------
# Dataset and Training
# -------------------------
def get_dataloaders(batch_size=4, image_size=128, root='.', num_workers=2):
    """Build train/test DataLoaders for MNIST digit classification.

    Args:
        batch_size: samples per batch for both loaders.
        image_size: 28x28 digits are resized to (image_size, image_size). After one 2x
            max-pool in the encoder this gives (image_size // 2) ** 2 Mamba tokens
            (4096 at the default), so smaller values (e.g. 28) train far faster.
        root: dataset directory (downloaded there if missing).
        num_workers: DataLoader worker processes.

    Returns:
        (train_loader, test_loader).
    """
    transform = T.Compose([
        T.Resize((image_size, image_size)),
        T.ToTensor()
    ])

    train_ds = MNIST(root=root, train=True, download=True, transform=transform)
    test_ds = MNIST(root=root, train=False, download=True, transform=transform)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, test_loader

def train(model, train_loader, optimizer, criterion, device):
    """Run one training epoch.

    The model's per-pixel class maps are global-average-pooled to (batch, num_classes)
    before `criterion` is applied.

    Returns:
        (mean_loss, accuracy) over all samples seen; (0.0, 0.0) for an empty loader.
    """
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    pbar = tqdm(train_loader, desc='Training', leave=False)
    for imgs, labels in pbar:
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        outputs = outputs.mean(dim=(2, 3))  # global average pooling
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * imgs.size(0)  # criterion returns a batch mean
        pred = outputs.argmax(dim=1)
        correct += (pred == labels).sum().item()
        total += imgs.size(0)

        pbar.set_postfix(loss=total_loss/total, acc=correct/total)

    mean_loss = total_loss / total if total else 0.0
    accuracy = correct / total if total else 0.0
    print(f"Train Loss: {mean_loss:.4f} | Accuracy: {accuracy:.4f}")
    return mean_loss, accuracy

def test(model, test_loader, criterion, device):
    model.eval()
    total_loss, correct, total = 0, 0, 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            outputs = outputs.mean(dim=(2, 3))
            loss = criterion(outputs, labels)

            total_loss += loss.item() * imgs.size(0)
            pred = outputs.argmax(dim=1)
            correct += (pred == labels).sum().item()
            total += imgs.size(0)

    print(f"Test Loss: {total_loss/total:.4f} | Accuracy: {correct/total:.4f}")

# -------------------------
# Main
# -------------------------
def main(epochs=10, lr=1e-3, batch_size=4, seed=0):
    """Train MambaUNet on MNIST, evaluating on the test split after every epoch.

    Args:
        epochs: number of passes over the training split.
        lr: learning rate for the optimizer class stored in ModelArgs.optimizer.
        batch_size: mini-batch size for both loaders.
        seed: torch RNG seed for reproducible initialization and shuffling;
            None leaves the global RNG untouched.
    """
    if seed is not None:
        torch.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args = ModelArgs()
    model = MambaUNet(args).to(device)
    train_loader, test_loader = get_dataloaders(batch_size=batch_size)
    optimizer = args.optimizer(model.parameters(), lr=lr)
    criterion = args.loss

    for epoch in range(1, epochs + 1):
        print(f"\nEpoch {epoch}")
        train(model, train_loader, optimizer, criterion, device)
        test(model, test_loader, criterion, device)


# In a notebook kernel __name__ is '__main__', so executing this cell starts training.
if __name__ == '__main__':
    main()

100%|██████████| 9.91M/9.91M [00:00<00:00, 48.9MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.68MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 14.4MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.42MB/s]



Epoch 1


                                                                                     

Train Loss: 2.1847 | Accuracy: 0.2871




Test Loss: 2.1222 | Accuracy: 0.3273

Epoch 2


                                                                                     

Train Loss: 2.1178 | Accuracy: 0.3677




Test Loss: 2.0527 | Accuracy: 0.4238

Epoch 3


                                                                                     

Train Loss: 2.0425 | Accuracy: 0.4334




Test Loss: 2.0069 | Accuracy: 0.4789

Epoch 4


                                                                                     

Train Loss: 2.0227 | Accuracy: 0.4754




Test Loss: 2.0401 | Accuracy: 0.4685

Epoch 5


                                                                                     

Train Loss: 1.9904 | Accuracy: 0.5237




Test Loss: 2.1937 | Accuracy: 0.2636

Epoch 6


                                                                                     

Train Loss: 2.1124 | Accuracy: 0.3619




Test Loss: 1.9976 | Accuracy: 0.4987

Epoch 7


                                                                                    

Train Loss: 2.0995 | Accuracy: 0.3712




Test Loss: 2.2262 | Accuracy: 0.2158

Epoch 8


                                                                                     

Train Loss: 2.2646 | Accuracy: 0.1811




Test Loss: 2.3346 | Accuracy: 0.1010

Epoch 9


                                                                                     

Train Loss: 2.2637 | Accuracy: 0.1883




Test Loss: 2.3131 | Accuracy: 0.1259

Epoch 10


Training:   6%|▌         | 925/15000 [00:40<10:03, 23.32it/s, acc=0.124, loss=2.3]

In [None]:
!pip install torchview graphviz
!apt install graphviz


Collecting torchview
  Downloading torchview-0.2.7-py3-none-any.whl.metadata (13 kB)
Downloading torchview-0.2.7-py3-none-any.whl (26 kB)
Installing collected packages: torchview
Successfully installed torchview-0.2.7
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
graphviz is already the newest version (2.42.2-6ubuntu0.1).
0 upgraded, 0 newly installed, 0 to remove and 34 not upgraded.


In [None]:
import torch
from torchview import draw_graph

# Uses whichever ModelArgs / MambaUNet were defined most recently in this kernel
# (the MNIST variant when the notebook is run top to bottom).
args = ModelArgs()
model = MambaUNet(args)

# Dummy 128x128 input whose channel count matches the model's first conv layer
# (3 for the RGB Oxford-IIIT Pet variant, 1 for the grayscale MNIST variant).
in_channels = model.encoder[0].in_channels
dummy_input = torch.randn(1, in_channels, 128, 128)

# Create the graph
graph = draw_graph(
    model,
    input_data=dummy_input,
    expand_nested=True,
    graph_dir='TB',  # Top-to-bottom layout
    roll=True
)

# Render and save as PDF
graph.visual_graph.render(
    filename="mamba_unet_diagram",
    directory="./",
    format="pdf",
    cleanup=True
)

print("PDF saved as mamba_unet_diagram.pdf")


PDF saved as mamba_unet_diagram.pdf
