In [3]:

from tqdm import tqdm


In [None]:
################# train on MINST

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as T
from einops import rearrange, repeat
import math
from tqdm import tqdm

# -------------------------
# Selective Scan
# -------------------------
def selective_scan(u, delta, A, B, C, D):
    dA = torch.einsum('bld,dn->bldn', delta, A)
    dB_u = torch.einsum('bld,bld,bln->bldn', delta, u, B)

    dA_cumsum = torch.cat([dA[:, 1:], torch.zeros_like(dA[:, :1])], dim=1)
    dA_cumsum = torch.flip(dA_cumsum, dims=[1])
    dA_cumsum = torch.cumsum(dA_cumsum, dim=1)
    dA_cumsum = torch.exp(dA_cumsum)
    dA_cumsum = torch.flip(dA_cumsum, dims=[1])

    x = dB_u * dA_cumsum
    x = torch.cumsum(x, dim=1) / (dA_cumsum + 1e-12)

    y = torch.einsum('bldn,bln->bld', x, C)
    return y + u * D

# -------------------------
# Model Args
# -------------------------
class ModelArgs:
    def __init__(self):
        self.model_input_dims = 16
        self.model_states = 16
        self.projection_expand_factor = 1
        self.conv_kernel_size = 4
        self.conv_use_bias = True
        self.dense_use_bias = False
        self.layer_id = -1
        self.seq_length = 128
        self.num_layers = 3
        self.dropout_rate = 0.2
        self.use_lm_head = False
        self.num_classes = 10  # MNIST has 10 classes
        self.final_activation = 'softmax'
        self.loss = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam
        self.metrics = ['accuracy']
        self.model_internal_dim = self.projection_expand_factor * self.model_input_dims
        self.delta_t_rank = math.ceil(self.model_input_dims / 16)

# -------------------------
# Mamba + Residual Blocks
# -------------------------
class MambaBlock(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.in_proj = nn.Linear(args.model_input_dims, args.model_internal_dim * 2, bias=False)
        self.conv1d = nn.Conv1d(args.model_internal_dim, args.model_internal_dim, kernel_size=args.conv_kernel_size,
                                padding=args.conv_kernel_size - 1, groups=args.model_internal_dim)
        self.x_proj = nn.Linear(args.model_internal_dim, args.delta_t_rank + args.model_states * 2, bias=False)
        self.delta_proj = nn.Linear(args.delta_t_rank, args.model_internal_dim)

        A_vals = torch.arange(1, args.model_states + 1).float()
        self.A_log = nn.Parameter(torch.log(repeat(A_vals, 'n -> d n', d=args.model_internal_dim)))
        self.D = nn.Parameter(torch.ones(args.model_internal_dim))
        self.out_proj = nn.Linear(args.model_internal_dim, args.model_input_dims, bias=args.dense_use_bias)

    def forward(self, x):
        b, l, d = x.shape
        x_and_res = self.in_proj(x)
        x1, res = x_and_res.chunk(2, dim=-1)

        x1 = rearrange(x1, 'b l d -> b d l')
        x1 = self.conv1d(x1)[..., :l]
        x1 = rearrange(x1, 'b d l -> b l d')
        x1 = F.selu(x1)

        A = -torch.exp(self.A_log)
        D = self.D
        x_dbl = self.x_proj(x1)
        delta, B, C = torch.split(x_dbl, [self.args.delta_t_rank, self.args.model_states, self.args.model_states], dim=-1)
        delta = F.softplus(self.delta_proj(delta))

        y = selective_scan(x1, delta, A, B, C, D)
        y = y * F.selu(res)
        return self.out_proj(y)

class ResidualBlock(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.norm = nn.LayerNorm(args.model_input_dims)
        self.mixer = MambaBlock(args)

    def forward(self, x):
        return self.mixer(self.norm(x)) + x

# -------------------------
# Mamba-UNet Architecture
# -------------------------
class MambaUNet(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),  # Changed for MNIST (1 channel)
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.res_blocks = nn.Sequential(*[ResidualBlock(args) for _ in range(args.num_layers)])
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(16, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Conv2d(16, args.num_classes, kernel_size=1)
        )
        self.activation = nn.Softmax(dim=1) if args.final_activation == 'softmax' else nn.Identity()

    def forward(self, x):
        x = self.encoder(x)
        b, c, h, w = x.size()
        x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
        x = self.res_blocks(x)
        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
        x = self.decoder(x)
        return self.activation(x)

# -------------------------
# Dataset and Training
# -------------------------
def get_dataloaders(batch_size=4):
    transform = T.Compose([
        T.Resize((128, 128)),
        T.ToTensor()
    ])

    train_ds = MNIST(root='.', train=True, download=True, transform=transform)
    test_ds = MNIST(root='.', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=2)
    return train_loader, test_loader

def train(model, train_loader, optimizer, criterion, device):
    model.train()
    total_loss, correct, total = 0, 0, 0
    pbar = tqdm(train_loader, desc='Training', leave=False)
    for imgs, labels in pbar:
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        outputs = outputs.mean(dim=(2, 3))  # global average pooling
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * imgs.size(0)
        pred = outputs.argmax(dim=1)
        correct += (pred == labels).sum().item()
        total += imgs.size(0)

        pbar.set_postfix(loss=total_loss/total, acc=correct/total)
    print(f"Train Loss: {total_loss/total:.4f} | Accuracy: {correct/total:.4f}")

def test(model, test_loader, criterion, device):
    model.eval()
    total_loss, correct, total = 0, 0, 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            outputs = outputs.mean(dim=(2, 3))
            loss = criterion(outputs, labels)

            total_loss += loss.item() * imgs.size(0)
            pred = outputs.argmax(dim=1)
            correct += (pred == labels).sum().item()
            total += imgs.size(0)

    print(f"Test Loss: {total_loss/total:.4f} | Accuracy: {correct/total:.4f}")

# -------------------------
# Main
# -------------------------
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args = ModelArgs()
    model = MambaUNet(args).to(device)
    train_loader, test_loader = get_dataloaders()
    optimizer = args.optimizer(model.parameters(), lr=1e-3)
    criterion = args.loss

    for epoch in range(1, 11):
        print(f"\nEpoch {epoch}")
        train(model, train_loader, optimizer, criterion, device)
        test(model, test_loader, criterion, device)

if __name__ == '__main__':
    main()

100%|██████████| 9.91M/9.91M [00:00<00:00, 48.9MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.68MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 14.4MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.42MB/s]



Epoch 1


                                                                                     

Train Loss: 2.1847 | Accuracy: 0.2871




Test Loss: 2.1222 | Accuracy: 0.3273

Epoch 2


                                                                                     

Train Loss: 2.1178 | Accuracy: 0.3677




Test Loss: 2.0527 | Accuracy: 0.4238

Epoch 3


                                                                                     

Train Loss: 2.0425 | Accuracy: 0.4334




Test Loss: 2.0069 | Accuracy: 0.4789

Epoch 4


                                                                                     

Train Loss: 2.0227 | Accuracy: 0.4754




Test Loss: 2.0401 | Accuracy: 0.4685

Epoch 5


                                                                                     

Train Loss: 1.9904 | Accuracy: 0.5237




Test Loss: 2.1937 | Accuracy: 0.2636

Epoch 6


                                                                                     

Train Loss: 2.1124 | Accuracy: 0.3619




Test Loss: 1.9976 | Accuracy: 0.4987

Epoch 7


                                                                                    

Train Loss: 2.0995 | Accuracy: 0.3712




Test Loss: 2.2262 | Accuracy: 0.2158

Epoch 8


                                                                                     

Train Loss: 2.2646 | Accuracy: 0.1811




Test Loss: 2.3346 | Accuracy: 0.1010

Epoch 9


                                                                                     

Train Loss: 2.2637 | Accuracy: 0.1883




Test Loss: 2.3131 | Accuracy: 0.1259

Epoch 10


Training:   6%|▌         | 925/15000 [00:40<10:03, 23.32it/s, acc=0.124, loss=2.3]

In [None]:
!pip install torchview graphviz
!apt install graphviz


Collecting torchview
  Downloading torchview-0.2.7-py3-none-any.whl.metadata (13 kB)
Downloading torchview-0.2.7-py3-none-any.whl (26 kB)
Installing collected packages: torchview
Successfully installed torchview-0.2.7
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
graphviz is already the newest version (2.42.2-6ubuntu0.1).
0 upgraded, 0 newly installed, 0 to remove and 34 not upgraded.


In [None]:
import torch
from torchview import draw_graph
# # Ensure your model code is in model.py or adjust the import

# Initialize model and args
args = ModelArgs()
model = MambaUNet(args)

# Dummy input: RGB image of size 128x128
dummy_input = torch.randn(1, 3, 128, 128)

# Create the graph
graph = draw_graph(
    model,
    input_data=dummy_input,
    expand_nested=True,
    graph_dir='TB',  # Top-to-bottom layout
    roll=True
)

# Render and save as PDF
graph.visual_graph.render(
    filename="mamba_unet_diagram",
    directory="./",
    format="pdf",
    cleanup=True
)

print("PDF saved as mamba_unet_diagram.pdf")


PDF saved as mamba_unet_diagram.pdf


# CODE FROM GPT FOR KVASIR SEG DATASET


In [1]:
from google.colab import files
uploaded = files.upload()


Saving kvasir-seg.zip to kvasir-seg.zip
unzip:  cannot find or open {kvasir-seg.zip}, {kvasir-seg.zip}.zip or {kvasir-seg.zip}.ZIP.


In [2]:
!unzip {"kvasir-seg.zip"}

Archive:  kvasir-seg.zip
   creating: Kvasir-SEG/
  inflating: Kvasir-SEG/kavsir_bboxes.json  
   creating: Kvasir-SEG/images/
  inflating: Kvasir-SEG/images/ck2bxiswtxuw80838qkisqjwz.jpg  
  inflating: Kvasir-SEG/images/ck2bxknhjvs1x0794iogrq49k.jpg  
  inflating: Kvasir-SEG/images/ck2bxlujamu330725szlc2jdu.jpg  
  inflating: Kvasir-SEG/images/ck2bxpfgxu2mk0748gsh7xelu.jpg  
  inflating: Kvasir-SEG/images/ck2bxqz3evvg20794iiyv5v2m.jpg  
  inflating: Kvasir-SEG/images/ck2bxskgxxzfv08386xkqtqdy.jpg  
  inflating: Kvasir-SEG/images/ck2bxw18mmz1k0725litqq2mc.jpg  
  inflating: Kvasir-SEG/images/ck2395w2mb4vu07480otsu6tw.jpg  
  inflating: Kvasir-SEG/images/ck2da7fwcjfis07218r1rvm95.jpg  
  inflating: Kvasir-SEG/images/cjyzjzssvd8pq0838f4nolj5l.jpg  
  inflating: Kvasir-SEG/images/cjyzk8qieoboa0848ogj51wwm.jpg  
  inflating: Kvasir-SEG/images/cju5hi52odyf90817prvcwg45.jpg  
  inflating: Kvasir-SEG/images/cju5hjxaae3i40850h5z2laf5.jpg  
  inflating: Kvasir-SEG/images/cju5hl8nee8a40755fm8qjj

In [None]:
from torch.utils.data import Dataset
from PIL import Image
import os

class KvasirSegDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None, mask_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_transform = mask_transform
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.images[idx])

        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        if self.transform:
            image = self.transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)

        return image, mask.long()


In [None]:
def get_kvasir_dataloaders(image_dir, mask_dir, batch_size=4):
    transform = T.Compose([
        T.Resize((128, 128)),
        T.ToTensor()
    ])
    mask_transform = T.Compose([
        T.Resize((128, 128)),
        T.ToTensor()
    ])

    dataset = KvasirSegDataset(image_dir, mask_dir, transform, mask_transform)
    train_len = int(0.8 * len(dataset))
    test_len = len(dataset) - train_len
    train_set, test_set = torch.utils.data.random_split(dataset, [train_len, test_len])

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=batch_size)
    return train_loader, test_loader


In [4]:
# Mamba-UNet for Kvasir-SEG (Binary Segmentation)

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T
from PIL import Image
from tqdm import tqdm
from einops import rearrange, repeat
import math

# -------------------------
# Selective Scan
# -------------------------
def selective_scan(u, delta, A, B, C, D):
    dA = torch.einsum('bld,dn->bldn', delta, A)
    dB_u = torch.einsum('bld,bld,bln->bldn', delta, u, B)

    dA_cumsum = torch.cat([dA[:, 1:], torch.zeros_like(dA[:, :1])], dim=1)
    dA_cumsum = torch.flip(dA_cumsum, dims=[1])
    dA_cumsum = torch.cumsum(dA_cumsum, dim=1)
    dA_cumsum = torch.exp(dA_cumsum)
    dA_cumsum = torch.flip(dA_cumsum, dims=[1])

    x = dB_u * dA_cumsum
    x = torch.cumsum(x, dim=1) / (dA_cumsum + 1e-12)

    y = torch.einsum('bldn,bln->bld', x, C)
    return y + u * D

# -------------------------
# Args
# -------------------------
class ModelArgs:
    def __init__(self):
        self.model_input_dims = 16
        self.model_states = 16
        self.projection_expand_factor = 1
        self.conv_kernel_size = 4
        self.conv_use_bias = True
        self.dense_use_bias = False
        self.layer_id = -1
        self.seq_length = 128
        self.num_layers = 3
        self.dropout_rate = 0.2
        self.use_lm_head = False
        self.num_classes = 2  # foreground, background
        self.final_activation = 'none'
        self.loss = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam
        self.metrics = ['dice']
        self.model_internal_dim = self.projection_expand_factor * self.model_input_dims
        self.delta_t_rank = math.ceil(self.model_input_dims / 16)

# -------------------------
# Mamba Blocks
# -------------------------
class MambaBlock(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.in_proj = nn.Linear(args.model_input_dims, args.model_internal_dim * 2, bias=False)
        self.conv1d = nn.Conv1d(args.model_internal_dim, args.model_internal_dim, kernel_size=args.conv_kernel_size,
                                padding=args.conv_kernel_size - 1, groups=args.model_internal_dim)
        self.x_proj = nn.Linear(args.model_internal_dim, args.delta_t_rank + args.model_states * 2, bias=False)
        self.delta_proj = nn.Linear(args.delta_t_rank, args.model_internal_dim)

        A_vals = torch.arange(1, args.model_states + 1).float()
        self.A_log = nn.Parameter(torch.log(repeat(A_vals, 'n -> d n', d=args.model_internal_dim)))
        self.D = nn.Parameter(torch.ones(args.model_internal_dim))
        self.out_proj = nn.Linear(args.model_internal_dim, args.model_input_dims, bias=args.dense_use_bias)

    def forward(self, x):
        b, l, d = x.shape
        x_and_res = self.in_proj(x)
        x1, res = x_and_res.chunk(2, dim=-1)

        x1 = rearrange(x1, 'b l d -> b d l')
        x1 = self.conv1d(x1)[..., :l]
        x1 = rearrange(x1, 'b d l -> b l d')
        x1 = F.selu(x1)

        A = -torch.exp(self.A_log)
        D = self.D
        x_dbl = self.x_proj(x1)
        delta, B, C = torch.split(x_dbl, [self.args.delta_t_rank, self.args.model_states, self.args.model_states], dim=-1)
        delta = F.softplus(self.delta_proj(delta))

        y = selective_scan(x1, delta, A, B, C, D)
        y = y * F.selu(res)
        return self.out_proj(y)

class ResidualBlock(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.norm = nn.LayerNorm(args.model_input_dims)
        self.mixer = MambaBlock(args)

    def forward(self, x):
        return self.mixer(self.norm(x)) + x

# -------------------------
# UNet with Mamba Blocks
# -------------------------
class MambaUNet(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.res_blocks = nn.Sequential(*[ResidualBlock(args) for _ in range(args.num_layers)])
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(16, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Conv2d(16, args.num_classes, kernel_size=1)
        )

    def forward(self, x):
        x = self.encoder(x)
        b, c, h, w = x.size()
        x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
        x = self.res_blocks(x)
        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
        x = self.decoder(x)
        return x

# -------------------------
# Dataset Loader
# -------------------------
class KvasirSegDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = Image.open(os.path.join(self.image_dir, self.images[idx])).convert("RGB")
        mask = Image.open(os.path.join(self.mask_dir, self.images[idx])).convert("L")

        if self.transform:
            img = self.transform(img)
            mask = self.transform(mask)

        return img, mask.squeeze(0).long()

# -------------------------
# Dice Metric
# -------------------------
def dice_score(preds, targets):
    preds = torch.argmax(preds, dim=1)
    smooth = 1e-6
    intersection = (preds * targets).sum(dim=(1, 2))
    union = preds.sum(dim=(1, 2)) + targets.sum(dim=(1, 2))
    dice = ((2 * intersection + smooth) / (union + smooth)).mean()
    return dice.item()

# -------------------------
# Training & Testing
# -------------------------
def train(model, loader, optimizer, criterion, device):
    model.train()
    total_loss, total_dice = 0, 0
    for imgs, masks in tqdm(loader, desc='Training'):
        imgs, masks = imgs.to(device), masks.to(device)
        outputs = model(imgs)
        loss = criterion(outputs, masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_dice += dice_score(outputs.detach(), masks.detach())
    print(f"Train Loss: {total_loss / len(loader):.4f}, Dice: {total_dice / len(loader):.4f}")

def test(model, loader, criterion, device):
    model.eval()
    total_loss, total_dice = 0, 0
    with torch.no_grad():
        for imgs, masks in tqdm(loader, desc='Testing'):
            imgs, masks = imgs.to(device), masks.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, masks)
            total_loss += loss.item()
            total_dice += dice_score(outputs, masks)
    print(f"Test Loss: {total_loss / len(loader):.4f}, Dice: {total_dice / len(loader):.4f}")

# -------------------------
# Main Function
# -------------------------
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args = ModelArgs()
    model = MambaUNet(args).to(device)

    image_dir = "/content/Kvasir-SEG/images"
    mask_dir = "/content/Kvasir-SEG/masks"

    transform = T.Compose([
        T.Resize((128, 128)),
        T.ToTensor()
    ])

    dataset = KvasirSegDataset(image_dir, mask_dir, transform=transform)
    train_len = int(0.8 * len(dataset))
    test_len = len(dataset) - train_len
    train_set, test_set = torch.utils.data.random_split(dataset, [train_len, test_len])
    train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=4)

    optimizer = args.optimizer(model.parameters(), lr=1e-3)
    criterion = args.loss

    for epoch in range(1, 11):
        print(f"\nEpoch {epoch}")
        train(model, train_loader, optimizer, criterion, device)
        test(model, test_loader, criterion, device)

if __name__ == '__main__':
    main()


Epoch 1


Training: 100%|██████████| 200/200 [00:15<00:00, 13.20it/s]


Train Loss: 0.3643, Dice: 0.0055


Testing: 100%|██████████| 50/50 [00:02<00:00, 23.74it/s]


Test Loss: 0.3564, Dice: 0.0000

Epoch 2


Training: 100%|██████████| 200/200 [00:13<00:00, 14.45it/s]


Train Loss: 0.3419, Dice: 0.0000


Testing: 100%|██████████| 50/50 [00:02<00:00, 20.31it/s]


Test Loss: 0.3739, Dice: 0.0000

Epoch 3


Training: 100%|██████████| 200/200 [00:13<00:00, 14.61it/s]


Train Loss: 0.3386, Dice: 0.0000


Testing: 100%|██████████| 50/50 [00:02<00:00, 23.85it/s]


Test Loss: 0.3577, Dice: 0.0000

Epoch 4


Training: 100%|██████████| 200/200 [00:13<00:00, 14.58it/s]


Train Loss: 0.3338, Dice: 0.0023


Testing: 100%|██████████| 50/50 [00:02<00:00, 24.10it/s]


Test Loss: 0.3480, Dice: 0.0000

Epoch 5


Training: 100%|██████████| 200/200 [00:13<00:00, 14.57it/s]


Train Loss: 0.3316, Dice: 0.0069


Testing: 100%|██████████| 50/50 [00:02<00:00, 19.43it/s]


Test Loss: 0.3542, Dice: 0.0003

Epoch 6


Training: 100%|██████████| 200/200 [00:13<00:00, 14.65it/s]


Train Loss: 0.3347, Dice: 0.0024


Testing: 100%|██████████| 50/50 [00:02<00:00, 23.85it/s]


Test Loss: 0.3458, Dice: 0.0242

Epoch 7


Training: 100%|██████████| 200/200 [00:13<00:00, 14.60it/s]


Train Loss: 0.3298, Dice: 0.0182


Testing: 100%|██████████| 50/50 [00:02<00:00, 24.31it/s]


Test Loss: 0.3500, Dice: 0.0001

Epoch 8


Training: 100%|██████████| 200/200 [00:13<00:00, 14.65it/s]


Train Loss: 0.3303, Dice: 0.0026


Testing: 100%|██████████| 50/50 [00:02<00:00, 18.60it/s]


Test Loss: 0.3520, Dice: 0.0023

Epoch 9


Training: 100%|██████████| 200/200 [00:14<00:00, 14.22it/s]


Train Loss: 0.3312, Dice: 0.0121


Testing: 100%|██████████| 50/50 [00:02<00:00, 24.25it/s]


Test Loss: 0.3461, Dice: 0.0067

Epoch 10


Training: 100%|██████████| 200/200 [00:13<00:00, 14.67it/s]


Train Loss: 0.3285, Dice: 0.0097


Testing: 100%|██████████| 50/50 [00:02<00:00, 23.67it/s]

Test Loss: 0.3444, Dice: 0.0002





In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

# -------------------------
# Mamba Block (unchanged)
# -------------------------
class MambaBlock(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.in_proj = nn.Linear(args.model_input_dims, args.model_internal_dim * 2, bias=False)
        self.conv1d = nn.Conv1d(args.model_internal_dim, args.model_internal_dim, kernel_size=args.conv_kernel_size,
                                padding=args.conv_kernel_size - 1, groups=args.model_internal_dim)
        self.x_proj = nn.Linear(args.model_internal_dim, args.delta_t_rank + args.model_states * 2, bias=False)
        self.delta_proj = nn.Linear(args.delta_t_rank, args.model_internal_dim)

        A_vals = torch.arange(1, args.model_states + 1).float()
        self.A_log = nn.Parameter(torch.log(torch.repeat_interleave(A_vals[None, :], args.model_internal_dim, dim=0)))
        self.D = nn.Parameter(torch.ones(args.model_internal_dim))
        self.out_proj = nn.Linear(args.model_internal_dim, args.model_input_dims, bias=args.dense_use_bias)

    def selective_scan(self, u, delta, A, B, C, D):
        dA = torch.einsum('bld,dn->bldn', delta, A)
        dB_u = torch.einsum('bld,bld,bln->bldn', delta, u, B)

        dA_cumsum = torch.cat([dA[:, 1:], torch.zeros_like(dA[:, :1])], dim=1)
        dA_cumsum = torch.flip(dA_cumsum, dims=[1])
        dA_cumsum = torch.cumsum(dA_cumsum, dim=1)
        dA_cumsum = torch.exp(dA_cumsum)
        dA_cumsum = torch.flip(dA_cumsum, dims=[1])

        x = dB_u * dA_cumsum
        x = torch.cumsum(x, dim=1) / (dA_cumsum + 1e-12)
        y = torch.einsum('bldn,bln->bld', x, C)
        return y + u * D

    def forward(self, x):
        b, l, d = x.shape
        x_and_res = self.in_proj(x)
        x1, res = x_and_res.chunk(2, dim=-1)

        x1 = rearrange(x1, 'b l d -> b d l')
        x1 = self.conv1d(x1)[..., :l]
        x1 = rearrange(x1, 'b d l -> b l d')
        x1 = F.selu(x1)

        A = -torch.exp(self.A_log.to(x.device))
        D = self.D.to(x.device)
        x_dbl = self.x_proj(x1)
        delta, B, C = torch.split(x_dbl, [self.args.delta_t_rank, self.args.model_states, self.args.model_states], dim=-1)
        delta = F.softplus(self.delta_proj(delta))

        y = self.selective_scan(x1, delta, A, B, C, D)
        y = y * F.selu(res)
        return self.out_proj(y)

# -------------------------
# UNet Blocks
# -------------------------
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.net(x)

class Down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_ch, out_ch)
        )

    def forward(self, x):
        return self.net(x)

class Up(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        diffY = x2.size(2) - x1.size(2)
        diffX = x2.size(3) - x1.size(3)
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

# -------------------------
# MambaUNet with Skip Connections
# -------------------------
class MambaUNet(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args

        self.inc = DoubleConv(3, 32)
        self.down1 = Down(32, 64)
        self.down2 = Down(64, 128)

        # Mamba bottleneck
        self.bottleneck_down = Down(128, args.model_input_dims)
        self.mamba = MambaBlock(args)

        self.up2 = Up(args.model_input_dims, 128)
        self.up1 = Up(128, 64)
        self.outc = nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)

        # Bottleneck
        x4 = self.bottleneck_down(x3)
        b, c, h, w = x4.shape
        x4_seq = x4.permute(0, 2, 3, 1).reshape(b, h*w, c)
        x4_seq = self.mamba(x4_seq)
        x4 = x4_seq.reshape(b, h, w, c).permute(0, 3, 1, 2)

        x = self.up2(x4, x3)
        x = self.up1(x, x2)
        logits = self.outc(x)
        return logits


In [6]:
import torch.nn.functional as F

def dice_loss(preds, targets, smooth=1e-6):
    preds = torch.sigmoid(preds)
    preds = preds.contiguous()
    targets = targets.contiguous()
    intersection = (preds * targets).sum(dim=(2, 3))
    union = preds.sum(dim=(2, 3)) + targets.sum(dim=(2, 3))
    dice = (2. * intersection + smooth) / (union + smooth)
    return 1 - dice.mean()

def combined_loss(preds, targets):
    bce = F.binary_cross_entropy_with_logits(preds, targets)
    dsc = dice_loss(preds, targets)
    return bce + dsc


In [7]:
class ModelArgs:
    def __init__(self):
        self.img_size = 128             # Input image size: (img_size x img_size)
        self.d_model = 256              # Hidden dimension for Mamba blocks
        self.n_layers = 4               # Number of encoder/decoder blocks (or Mamba blocks)
        self.d_state = 16               # State size for the Mamba SSM
        self.expand = 2                 # Expansion factor in the SSM block
        self.conv_kernel = 3            # Kernel size for Mamba convolution
        self.bias = True                # Bias in linear layers
        self.optimizer = torch.optim.Adam
        self.loss = combined_loss       # Combined BCE + Dice loss

def dice_score(preds, targets, threshold=0.5):
    preds = torch.sigmoid(preds)
    preds = (preds > threshold).float()
    intersection = (preds * targets).sum(dim=(2, 3))
    union = preds.sum(dim=(2, 3)) + targets.sum(dim=(2, 3))
    dice = (2. * intersection) / (union + 1e-8)
    return dice.mean().item()

def train(model, loader, optimizer, criterion, device):
    model.train()
    total_loss, total_dice = 0, 0
    for imgs, masks in tqdm(loader, desc='Training'):
        imgs, masks = imgs.to(device), masks.to(device)
        outputs = model(imgs)
        loss = criterion(outputs, masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_dice += dice_score(outputs.detach(), masks.detach())
    print(f"Train Loss: {total_loss / len(loader):.4f}, Dice: {total_dice / len(loader):.4f}")

def test(model, loader, criterion, device):
    model.eval()
    total_loss, total_dice = 0, 0
    with torch.no_grad():
        for imgs, masks in tqdm(loader, desc='Testing'):
            imgs, masks = imgs.to(device), masks.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, masks)
            total_loss += loss.item()
            total_dice += dice_score(outputs, masks)
    print(f"Test Loss: {total_loss / len(loader):.4f}, Dice: {total_dice / len(loader):.4f}")
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args = ModelArgs()
    model = MambaUNet(args).to(device)

    image_dir = "/content/Kvasir-SEG/images"
    mask_dir = "/content/Kvasir-SEG/masks"

    transform = T.Compose([
        T.Resize((128, 128)),
        T.ToTensor()
    ])

    dataset = KvasirSegDataset(image_dir, mask_dir, transform=transform)
    train_len = int(0.8 * len(dataset))
    test_len = len(dataset) - train_len
    train_set, test_set = torch.utils.data.random_split(dataset, [train_len, test_len])
    train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=4)

    optimizer = args.optimizer(model.parameters(), lr=1e-3)
    criterion = args.loss

    for epoch in range(1, 11):
        print(f"\nEpoch {epoch}")
        train(model, train_loader, optimizer, criterion, device)
        test(model, test_loader, criterion, device)

if __name__ == '__main__':
    main()


AttributeError: 'ModelArgs' object has no attribute 'model_input_dims'