In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import nibabel as nib
import numpy as np
from torch.utils.data import DataLoader, Dataset, Subset
from torch.optim.lr_scheduler import StepLR
from sklearn.metrics import (
    accuracy_score, confusion_matrix, f1_score, roc_auc_score, average_precision_score
)
from sklearn.model_selection import train_test_split, KFold

# Custom 3D Swin Transformer Block with 6 Layer Normalizations
class SwinTransformerBlock3D(nn.Module):
    """Transformer block for a channels-last 3D feature volume.

    The three spatial axes are flattened into a single token axis, global
    multi-head self-attention and an MLP are applied (each with a residual
    connection), and the result is reshaped back to the input layout.

    Six independent ``LayerNorm`` modules are used: three on the attention
    inputs (query, key, value), one after the attention residual, one in
    front of the MLP and one on the block output.

    Note: attention is computed over *all* ``D * H * W`` tokens;
    ``window_size`` is stored on the module but is not used by ``forward``.

    Args:
        dim: Channel width ``C`` of the tokens.
        num_heads: Number of attention heads (must divide ``dim``).
        window_size: Stored as an attribute only.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        dropout: Dropout probability used inside the MLP.
    """

    def __init__(self, dim, num_heads, window_size, mlp_ratio=4.0, dropout=0.5):
        super().__init__()
        self.window_size = window_size
        self.num_heads = num_heads

        # Define 6 Layer Normalization layers
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.norm4 = nn.LayerNorm(dim)
        self.norm5 = nn.LayerNorm(dim)
        self.norm6 = nn.LayerNorm(dim)

        # Multi-head self-attention.
        # forward() feeds (batch, tokens, channels) tensors, so the module has
        # to be batch-first. With the default (tokens, batch, channels) layout
        # attention would run across the batch axis and mix different samples.
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

        # MLP (feed-forward) with dropout
        hidden_dim = int(mlp_ratio * dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),  # Dropout after GELU
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)  # Dropout after second Linear layer
        )

    def forward(self, x):
        """Apply the block.

        Args:
            x: Tensor of shape ``(B, D, H, W, C)``.

        Returns:
            Tensor of shape ``(B, D, H, W, C)``.
        """
        B, D, H, W, C = x.shape
        # reshape (not view) so permuted, non-contiguous inputs are accepted.
        x = x.reshape(B, D * H * W, C)  # Flatten spatial dimensions for attention

        # Self-attention with residual connection. The attention weights are
        # never used, so skip materialising the (tokens x tokens) matrix.
        attn_out, _ = self.attn(
            self.norm1(x), self.norm2(x), self.norm3(x), need_weights=False
        )
        x = x + attn_out

        # Further normalization before the MLP block
        x = self.norm4(x)

        # MLP block with residual connection
        x = x + self.mlp(self.norm5(x))

        # Final normalization before output
        x = self.norm6(x)

        return x.reshape(B, D, H, W, C)


class SwinTransformer3D(nn.Module):
    """Hierarchical 3D transformer classifier.

    Pipeline: strided ``Conv3d`` patch embedding -> ``len(depths)`` stages of
    ``SwinTransformerBlock3D`` with a strided ``Conv3d`` between consecutive
    stages (channels double at each one) -> 1x1x1 bottleneck convolution ->
    global average pooling -> two-layer classifier head.

    Args:
        img_size: Spatial size of the input volume, one entry per axis.
        patch_size: Kernel size and stride of the patch-embedding convolution.
        in_chans: Number of input channels.
        num_classes: Number of output logits.
        embed_dim: Channel width of the first stage.
        depths: Number of transformer blocks in each stage.
        num_heads: Attention heads per stage; needs at least one entry per
            stage (extra entries are ignored).
        mlp_ratio: MLP hidden width multiplier passed to every block.
        dropout: Dropout probability for the blocks and the classifier head.

    Raises:
        ValueError: If ``num_heads`` has fewer entries than ``depths``.

    Note:
        ``self.embed_dim`` is doubled once per downsampling step during
        construction, so afterwards it holds the width of the *last* stage.
    """

    def __init__(self, img_size=(256, 256, 32), patch_size=(8, 8, 8), in_chans=1,
                 num_classes=2, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
                 mlp_ratio=4.0, dropout=0.5):
        super().__init__()

        if len(num_heads) < len(depths):
            raise ValueError(
                f"num_heads needs one entry per stage: got {len(num_heads)} "
                f"entries for {len(depths)} stages"
            )

        self.num_layers = len(depths)
        self.embed_dim = embed_dim

        # Embedding layer (linear projection of patches)
        self.patch_embed = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

        # Transformer layers with downsampling stages
        self.layers = nn.ModuleList()
        self.downsamples = nn.ModuleList()

        # Spatial size of the token grid after patch embedding (floor division,
        # matching what the strided convolution produces).
        input_size = [dim // p for dim, p in zip(img_size, patch_size)]  # e.g., [32, 32, 4]

        for i_layer in range(self.num_layers):
            # Swin Transformer blocks for this stage
            stage = nn.Sequential(
                *[SwinTransformerBlock3D(dim=self.embed_dim, num_heads=num_heads[i_layer],
                                         window_size=2, mlp_ratio=mlp_ratio, dropout=dropout)
                  for _ in range(depths[i_layer])]
            )
            self.layers.append(stage)

            if i_layer < self.num_layers - 1:
                # Halve every axis that still has at least two positions.
                kernel_size, stride = self.get_safe_kernel_and_stride(input_size)

                downsample = nn.Conv3d(
                    self.embed_dim, self.embed_dim * 2,
                    kernel_size=kernel_size, stride=stride
                )
                self.downsamples.append(downsample)
                self.embed_dim *= 2

                # Track the grid size seen by the next stage (standard
                # convolution output-size formula without padding).
                input_size = [
                    max(1, (size - (k - 1) - 1) // s + 1)
                    for size, k, s in zip(input_size, kernel_size, stride)
                ]

        # Final bottleneck layer
        self.bottleneck = nn.Conv3d(self.embed_dim, self.embed_dim, kernel_size=1)

        # Final classifier with dropout
        self.fc = nn.Sequential(
            nn.Linear(self.embed_dim, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, num_classes)
        )

    def get_safe_kernel_and_stride(self, input_size):
        """Pick a per-axis downsampling kernel/stride that fits ``input_size``.

        Axes with at least two positions get kernel 2 / stride 2; axes that
        have already collapsed to a single position get kernel 1 / stride 1.

        Returns:
            ``(kernel_size, stride)`` as two tuples with one entry per axis.
        """
        kernel_size = []
        stride = []
        for size in input_size:
            if size >= 2:
                kernel_size.append(2)
                stride.append(2)
            else:
                kernel_size.append(1)
                stride.append(1)
        return tuple(kernel_size), tuple(stride)

    def forward(self, x):
        """Classify a batch of volumes.

        Args:
            x: Tensor of shape ``(B, in_chans, *spatial)``.

        Returns:
            Logits of shape ``(B, num_classes)``.
        """
        # Patch embedding: (B, C, D', H', W')
        x = self.patch_embed(x)

        # Blocks work channels-last: (B, D', H', W', C)
        x = x.permute(0, 2, 3, 4, 1)

        # Transformer layers with downsampling
        for i_layer, layer in enumerate(self.layers):
            x = layer(x)
            if i_layer < self.num_layers - 1:
                x = x.permute(0, 4, 1, 2, 3)  # Move channels back to second dimension
                x = self.downsamples[i_layer](x)  # Apply downsampling
                x = x.permute(0, 2, 3, 4, 1)  # Move channels back to last dimension

        # Bottleneck layer (convolution expects channels-first)
        x = x.permute(0, 4, 1, 2, 3)
        x = self.bottleneck(x)

        # Global average pooling over the three spatial axes
        x = x.mean(dim=[2, 3, 4])

        # Classification
        x = self.fc(x)

        return x


# Dataset for 3D images
class CustomDataset3D(Dataset):
    """Dataset of NIfTI volumes stored in one sub-directory per class.

    Expected layout: ``root_dir/<class_name>/*.nii`` or ``*.nii.gz``. Each
    item is a ``(image, label)`` pair where ``image`` is a float32 tensor of
    shape ``(1, *target_shape)`` scaled to ``[0, 1]`` and ``label`` is a
    ``torch.long`` scalar.

    Args:
        root_dir: Directory containing the class sub-directories.
        class_labels: Mapping from sub-directory name to integer label.
            Entries whose directory does not exist are skipped.
        target_shape: Spatial shape every volume is centre-cropped and/or
            zero-padded to.
    """

    def __init__(self, root_dir, class_labels, target_shape=(256, 256, 32)):
        self.root_dir = root_dir
        self.class_labels = class_labels
        self.target_shape = target_shape
        self.samples = self._load_samples()

    def _load_samples(self):
        """Return the list of ``(path, label)`` pairs found under ``root_dir``."""
        samples = []
        for class_name, label in self.class_labels.items():
            class_path = os.path.join(self.root_dir, class_name)
            if not os.path.isdir(class_path):
                continue
            # os.listdir order is arbitrary; sorting keeps sample indices (and
            # therefore any seeded split built on them) stable across runs
            # and machines.
            for file_name in sorted(os.listdir(class_path)):
                if file_name.endswith(('.nii', '.nii.gz')):
                    samples.append((os.path.join(class_path, file_name), label))
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        img = nib.load(img_path).get_fdata()
        img = self._pad_or_crop(img)
        img = self._normalize(img)
        img = torch.tensor(img, dtype=torch.float32).unsqueeze(0)  # Add channel dimension
        label = torch.tensor(label, dtype=torch.long)
        return img, label

    def _pad_or_crop(self, img):
        """Centre-crop and/or zero-pad ``img`` to ``self.target_shape``.

        Axes longer than the target are cropped symmetrically; shorter axes
        are zero-padded, with the extra voxel (if any) on the trailing side.

        Raises:
            ValueError: If ``img`` does not have one axis per target entry.
        """
        target_shape = self.target_shape
        if img.ndim != len(target_shape):
            raise ValueError(
                f"expected a {len(target_shape)}-D volume, got shape {img.shape}"
            )

        slices = []
        pad_widths = []
        for curr_dim, target_dim in zip(img.shape, target_shape):
            if curr_dim > target_dim:
                start = (curr_dim - target_dim) // 2
                slices.append(slice(start, start + target_dim))
                pad_widths.append((0, 0))
            else:
                # Nothing to crop; pad is (0, 0) when the sizes already match.
                pad_size = target_dim - curr_dim
                pad_left = pad_size // 2
                slices.append(slice(0, curr_dim))
                pad_widths.append((pad_left, pad_size - pad_left))

        img = img[tuple(slices)]
        img = np.pad(img, pad_widths, mode='constant', constant_values=0)
        return img

    def _normalize(self, img):
        """Min-max scale ``img`` to ``[0, 1]``; a constant image becomes all zeros."""
        img = img - np.min(img)
        peak = np.max(img)
        if peak != 0:
            img = img / peak
        return img

# Function to split dataset
def split_dataset(dataset, test_size=0.2):
    """Split sample indices into stratified train and test index lists.

    Labels are read from ``dataset.samples[i][1]`` and the split is seeded
    (``random_state=24``), so it is reproducible for a fixed sample order.

    Args:
        dataset: Object with ``__len__`` and a ``samples`` list of
            ``(path, label)`` pairs.
        test_size: Share of the samples assigned to the test split.

    Returns:
        ``(train_indices, test_indices)``.
    """
    sample_ids = list(range(len(dataset)))
    targets = [dataset.samples[k][1] for k in sample_ids]
    train_part, test_part = train_test_split(
        sample_ids,
        test_size=test_size,
        stratify=targets,
        random_state=24,
    )
    return train_part, test_part

# Training and evaluation functions
def train_model(train_loader, model, criterion, optimizer, scheduler, device):
    """Train ``model`` for one epoch and return the mean per-batch loss.

    Args:
        train_loader: Iterable of ``(images, labels)`` batches; must support
            ``len()``.
        model: Module to optimise; it is switched to training mode.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimiser stepped once per batch.
        scheduler: Learning-rate scheduler stepped once, at the end of the epoch.
        device: Device the batches are moved to.

    Returns:
        Sum of the batch losses divided by the number of batches.
    """
    model.train()
    batch_losses = []
    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())

    # The schedule advances per epoch, so it is stepped outside the batch loop.
    scheduler.step()
    return sum(batch_losses, 0.0) / len(train_loader)

def evaluate_model(test_loader, model, criterion, device):
    """Evaluate a binary classifier and return its metrics.

    The model is put in eval mode and run without gradient tracking.
    Predicted labels are the arg-max of the logits; the score used for the
    ranking metrics is the softmax probability of class 1.

    Args:
        test_loader: Iterable of ``(images, labels)`` batches.
        model: Module producing ``(batch, num_classes)`` logits.
        criterion: Accepted for signature symmetry with ``train_model``;
            not used here.
        device: Device the batches are moved to.

    Returns:
        ``(accuracy, confusion_matrix, f1, roc_auc, average_precision)``.
        Both ranking metrics are ``nan`` when scikit-learn raises
        ``ValueError`` for either of them.
    """
    model.eval()
    y_true, y_pred, y_score = [], [], []
    with torch.no_grad():
        for batch_images, batch_labels in test_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            class_probs = torch.softmax(logits, dim=1)
            hard_preds = torch.max(logits, 1)[1]

            y_true.extend(batch_labels.cpu().numpy())
            y_pred.extend(hard_preds.cpu().numpy())
            y_score.extend(class_probs[:, 1].cpu().numpy())

    acc = accuracy_score(y_true, y_pred)
    cm = confusion_matrix(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)

    # Either both ranking metrics are reported or neither is.
    try:
        ranking = (
            roc_auc_score(y_true, y_score),
            average_precision_score(y_true, y_score),
        )
    except ValueError:
        ranking = (np.nan, np.nan)
    auc, auc_pr = ranking

    return acc, cm, f1, auc, auc_pr

def cross_validate(dataset, model_class, criterion, optimizer_class, scheduler_class, device,
                   num_epochs=50, k_folds=5, batch_size=None):
    """Run seeded k-fold cross-validation and return the fold-averaged metrics.

    For every fold a fresh model, optimiser and scheduler are created, the
    model is trained for ``num_epochs`` epochs with ``train_model`` and then
    scored once on the held-out part with ``evaluate_model``.

    Folds come from ``KFold(shuffle=True, random_state=24)``. ``KFold`` only
    needs the number of samples, so the folds are derived from indices and no
    sample is loaded just to build the split. The folds are not stratified.

    Args:
        dataset: Map-style dataset yielding ``(image, label)`` pairs.
        model_class: Zero-argument callable returning a new model.
        criterion: Loss callable shared by all folds.
        optimizer_class: Callable ``(params) -> optimizer``.
        scheduler_class: Callable ``(optimizer) -> scheduler``.
        device: Device used for training and evaluation.
        num_epochs: Training epochs per fold.
        k_folds: Number of folds.
        batch_size: Batch size for both loaders. When ``None`` the
            module-level ``batch_size`` variable is used, which is what
            earlier revisions of this function always did.

    Returns:
        ``(avg_acc, avg_f1, avg_auc, avg_auc_pr)`` averaged over the folds.

    Raises:
        NameError: If ``batch_size`` is ``None`` and no module-level
            ``batch_size`` exists.
    """
    if batch_size is None:
        # Backward compatibility with callers that set a notebook-level global.
        try:
            batch_size = globals()["batch_size"]
        except KeyError:
            raise NameError(
                "batch_size was not passed to cross_validate and no global "
                "'batch_size' is defined"
            ) from None

    kf = KFold(n_splits=k_folds, shuffle=True, random_state=24)
    fold_results = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(np.arange(len(dataset)))):
        print(f"\nFold {fold+1}/{k_folds}")

        # Create subsets for the current fold
        train_subset = Subset(dataset, train_idx)
        val_subset = Subset(dataset, val_idx)

        train_loader_fold = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=0)
        val_loader_fold = DataLoader(val_subset, batch_size=batch_size, shuffle=False, num_workers=0)

        # Fresh model, optimizer, and scheduler so no state leaks between folds
        model = model_class().to(device)
        optimizer = optimizer_class(model.parameters())
        scheduler = scheduler_class(optimizer)

        # Train the model for this fold
        for epoch in range(num_epochs):
            train_loss = train_model(train_loader_fold, model, criterion, optimizer, scheduler, device)
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')

        # Evaluate on the validation set (confusion matrix is not aggregated)
        val_acc, _, val_f1, val_auc, val_auc_pr = evaluate_model(val_loader_fold, model, criterion, device)
        print(f'Validation Accuracy: {val_acc:.4f}, F1 Score: {val_f1:.4f}, AUC: {val_auc:.4f}, AUC-PR: {val_auc_pr:.4f}')

        fold_results.append((val_acc, val_f1, val_auc, val_auc_pr))

    # Average each metric across all folds
    avg_acc = np.mean([r[0] for r in fold_results])
    avg_f1 = np.mean([r[1] for r in fold_results])
    avg_auc = np.mean([r[2] for r in fold_results])
    avg_auc_pr = np.mean([r[3] for r in fold_results])

    return avg_acc, avg_f1, avg_auc, avg_auc_pr

# Function to save results to a text file
def save_results_to_file(file_path, cv_acc, test_acc, test_auc, test_auc_pr):
    """Write the four summary metrics to ``file_path``, one per line.

    The file is overwritten. Every value is formatted with four decimals.

    Args:
        file_path: Destination text file.
        cv_acc: Cross-validation accuracy.
        test_acc: Accuracy on the test split.
        test_auc: ROC AUC on the test split.
        test_auc_pr: Average precision (AUC-PR) on the test split.
    """
    report_rows = (
        ("Cross-validation Accuracy", cv_acc),
        ("Test Accuracy", test_acc),
        ("Test AUC", test_auc),
        ("Test AUC-PR", test_auc_pr),
    )
    with open(file_path, 'w') as handle:
        for caption, value in report_rows:
            handle.write(f"{caption}: {value:.4f}\n")



In [22]:
import torch
from torchinfo import summary


In [7]:
# Parameter summary of the first transformer stage only.
# NOTE(review): `xmerModel` is created in a later cell of this notebook, so
# this cell fails on a fresh Restart & Run All.
summary(xmerModel.layers[0])

Layer (type:depth-idx)                                  Param #
Sequential                                              --
├─SwinTransformerBlock3D: 1-1                           --
│    └─LayerNorm: 2-1                                   192
│    └─LayerNorm: 2-2                                   192
│    └─LayerNorm: 2-3                                   192
│    └─LayerNorm: 2-4                                   192
│    └─LayerNorm: 2-5                                   192
│    └─LayerNorm: 2-6                                   192
│    └─MultiheadAttention: 2-7                          27,936
│    │    └─NonDynamicallyQuantizableLinear: 3-1        9,312
│    └─Sequential: 2-8                                  --
│    │    └─Linear: 3-2                                 37,248
│    │    └─GELU: 3-3                                   --
│    │    └─Dropout: 3-4                                --
│    │    └─Linear: 3-5                                 36,960
│    │    └─Dropout: 3-6      

In [23]:
from torchinfo import summary

# Build the model for 5-channel 240x240x155 volumes with 8x8x5 patches
# (each patch size divides its axis exactly: 30 x 30 x 31 patches).
xmerModel = SwinTransformer3D(img_size=(240, 240, 155), patch_size=(8,8,5), in_chans=5, num_classes=2)
# summary(xmerModel, input_size=(1, 5, 240, 240, 155))
# summary(xmerModel.layers[0])

In [3]:
from torchinfo import summary
from model.VisionMamba3D import VisionMamba3D

# Instantiate VisionMamba3D for 5-channel (155, 240, 240) volumes with
# (5, 8, 8) patches and print a torchinfo summary, five levels deep, for a
# batch of one.
mambamodel = VisionMamba3D(
    img_size=(155, 240, 240),
    patch_size=(5, 8, 8),
    in_chans=5, num_classes=2,
    debug=False,
)
summary(mambamodel, input_size=(1, 5, 155, 240, 240), depth=5)

Layer (type:depth-idx)                        Output Shape              Param #
VisionMamba3D                                 [1, 2]                    --
├─PatchEmbedding3D: 1-1                       [1, 96, 31, 30, 30]       --
│    └─Conv3d: 2-1                            [1, 96, 31, 30, 30]       153,696
├─ModuleList: 1-8                             --                        (recursive)
│    └─Sequential: 2-2                        [1, 27900, 96]            --
│    │    └─TransformerBlockWithSSM: 3-1      [1, 27900, 96]            --
│    │    │    └─MambaLayer: 4-1              [1, 27900, 96]            --
│    │    │    └─LayerNorm: 4-2               [1, 27900, 96]            192
│    │    │    └─Sequential: 4-3              [1, 27900, 96]            --
│    │    │    │    └─Linear: 5-1             [1, 27900, 384]           37,248
│    │    │    │    └─GELU: 5-2               [1, 27900, 384]           --
│    │    │    │    └─Dropout: 5-3            [1, 27900, 384]           --
│

In [3]:
from torchinfo import summary

# Same model with the depth axis first and cubic 8x8x8 patches, summarised for
# a batch of two. 155 is not a multiple of 8, so the patch embedding yields 19
# patches along the first axis and the last 3 voxels are not covered.
xmerModel = SwinTransformer3D(img_size=(155, 240, 240), patch_size=(8,8,8), in_chans=5, num_classes=2)
summary(xmerModel, input_size=(2, 5, *(155, 240, 240)))

Layer (type:depth-idx)                        Output Shape              Param #
SwinTransformer3D                             [2, 2]                    --
├─Conv3d: 1-1                                 [2, 96, 19, 30, 30]       245,856
├─ModuleList: 1-8                             --                        (recursive)
│    └─Sequential: 2-1                        [2, 19, 30, 30, 96]       --
│    │    └─SwinTransformerBlock3D: 3-1       [2, 19, 30, 30, 96]       112,608
│    │    └─SwinTransformerBlock3D: 3-2       [2, 19, 30, 30, 96]       112,608
├─ModuleList: 1-7                             --                        (recursive)
│    └─Conv3d: 2-2                            [2, 192, 9, 15, 15]       147,648
├─ModuleList: 1-8                             --                        (recursive)
│    └─Sequential: 2-3                        [2, 9, 15, 15, 192]       --
│    │    └─SwinTransformerBlock3D: 3-3       [2, 9, 15, 15, 192]       446,400
│    │    └─SwinTransformerBlock3D: 3-4    

In [None]:
from torchinfo import summary
from model.VisionMamba3D import VisionMamba3D

# Single-channel VisionMamba3D with four blocks per stage, moved to the GPU;
# summary() without an input size lists parameter counts only.
# NOTE(review): patch_size=(4, 4, 3) puts the smallest patch on the last axis,
# whereas the earlier (155, 240, 240) cell used (5, 8, 8) — confirm the axis
# order is intended.
model = VisionMamba3D(
    img_size=(155, 240, 240), patch_size=(4, 4, 3), in_chans=1, num_classes=2, depths=[4, 4, 4, 4],
    ).to('cuda')
summary(model)

Layer (type:depth-idx)                             Param #
VisionMamba3D                                      --
├─Conv3d: 1-1                                      4,704
├─ModuleList: 1-2                                  --
│    └─Sequential: 2-1                             --
│    │    └─TransformerBlockWithSSM: 3-1           111,264
│    │    └─TransformerBlockWithSSM: 3-2           111,264
│    │    └─TransformerBlockWithSSM: 3-3           111,264
│    │    └─TransformerBlockWithSSM: 3-4           111,264
│    └─Sequential: 2-2                             --
│    │    └─TransformerBlockWithSSM: 3-5           443,712
│    │    └─TransformerBlockWithSSM: 3-6           443,712
│    │    └─TransformerBlockWithSSM: 3-7           443,712
│    │    └─TransformerBlockWithSSM: 3-8           443,712
│    └─Sequential: 2-3                             --
│    │    └─TransformerBlockWithSSM: 3-9           1,772,160
│    │    └─TransformerBlockWithSSM: 3-10          1,772,160
│    │    └─Transfor

In [None]:
import torch
import time

from mamba_ssm import Mamba
from mamba_ssm import Mamba2
from mamba_ssm.modules.mamba2_simple import Mamba2Simple

# Forward-pass timing of Mamba, Mamba2 and Mamba2Simple on CUDA.
# The four input shapes shrink the sequence length by 4x per step while the
# channel width grows. Each class is instantiated and timed five times per shape.
xshapes = [(2, 16*1024, 96), (2, 16*1024//4, 192), (2, 16*1024//16, 384), (2, 16*1024//64, 576)]
n_repeats = 5

for xshape in xshapes:
    batch, length, dim = xshape
    x = torch.randn(batch, length, dim).to('cuda')
    for M in [Mamba] * n_repeats + [Mamba2] * n_repeats + [Mamba2Simple] * n_repeats:
        try:
            layer_kwargs = dict(
                d_model=dim, # Model dimension d_model
                d_state=64,  # SSM state expansion factor, typically 64 or 128
                d_conv=4,    # Local convolution width
                expand=2,    # Block expansion factor
            )
            if M is not Mamba:
                layer_kwargs['headdim'] = 4  # Head dimension (Mamba2 variants only)
            model = M(**layer_kwargs).to("cuda")

            # CUDA kernels are launched asynchronously: without synchronising
            # before and after the call, only the launch overhead is measured.
            torch.cuda.synchronize()
            st = time.perf_counter()
            out = model(x)
            torch.cuda.synchronize()
            elapsed = time.perf_counter() - st

            if (out.shape!=xshape):
                print(M.__name__, 'error')
            print(M.__name__, elapsed)
        except Exception as e:
            # Best effort: report the failure and continue with the next run.
            print(e)

In [None]:
import torch
# from mamba_ssm import Mamba2
from mamba_ssm.modules.mamba2_simple import Mamba2Simple as Mamba

# Smoke test: one Mamba2Simple forward pass must preserve the input shape.
# The input has 10 * 1,048,576 * 128 elements (5 GiB at the default float32).
# The output recorded for this cell is a CUDA "no kernel image is available"
# RuntimeError, i.e. the cell did not complete on the machine it was run on.
batch, length, dim = 10, 1024*1024, 128
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=64,  # SSM state expansion factor, typically 64 or 128
    headdim=4,  # Attention head dimension
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape

RuntimeError: CUDA error: no kernel image is available for execution on the device
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [2]:
import torch
import torch.nn as nn

class SimpleS4(nn.Module):
    """Minimal S4-style layer: a learned long convolution plus a linear map.

    One learned 1-D kernel of length ``seq_len`` is shared by all channels
    and applied along the sequence axis as a *causal* convolution
    (``y[t] = sum_{s <= t} kernel[s] * x[t - s]``), computed with FFTs.
    The result is passed through a ``d_model -> d_model`` linear layer.

    Args:
        d_model: Channel width of the input and output.
        seq_len: Number of learned kernel taps.
    """

    def __init__(self, d_model, seq_len):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.kernel = nn.Parameter(torch.randn(seq_len))
        self.linear = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Apply the layer.

        Args:
            x: Tensor of shape ``(batch, n, d_model)``. If ``n`` differs from
                ``seq_len`` the kernel is truncated or zero-extended to ``n`` taps.

        Returns:
            Tensor of shape ``(batch, n, d_model)``.
        """
        b, n, d = x.shape
        # A length-n FFT product is a *circular* convolution, which would let
        # the end of the sequence wrap around into the first outputs.
        # Zero-padding both operands to 2n and keeping the first n outputs
        # gives the causal linear convolution instead.
        fft_len = 2 * n
        x_fft = torch.fft.rfft(x, n=fft_len, dim=1)
        kernel_fft = torch.fft.rfft(self.kernel[:n], n=fft_len)
        kernel_fft = kernel_fft.view(1, -1, 1)  # broadcast over batch and channels
        out = torch.fft.irfft(x_fft * kernel_fft, n=fft_len, dim=1)[:, :n]
        return self.linear(out)

# Device selection: CPU is active; switch the commented line to use CUDA.
# device = torch.device('cuda')
device = torch.device('cpu')

# Usage: one random sequence of 102,400 tokens with 512 channels, and a layer
# whose kernel has one tap per token.
seq_len = 102400
d_model = 512
x = torch.randn(1, seq_len, d_model).to(device)
s4_model = SimpleS4(d_model, seq_len).to(device)


In [5]:

# Run a single forward pass on the random input created above.
output = s4_model(x)

In [6]:
# Display input and output shapes side by side; the layer preserves the shape.
x.shape, output.shape

(torch.Size([1, 102400, 512]), torch.Size([1, 102400, 512]))

In [11]:
# Re-execute utils/dataset.py after editing it on disk.
# NOTE: the `TumorMRIDataset` name imported here is bound *before* the reload
# and keeps pointing at the old class; import it again afterwards to get the
# updated definition.
from importlib import reload
from utils.dataset import TumorMRIDataset
import utils
reload(utils.dataset)

<module 'utils.dataset' from 'c:\\Users\\shera\\Projects\\VisionMamba\\utils\\dataset.py'>

In [12]:
from utils.dataset import TumorMRIDataset
# Path to the BraTS 2019 training data, relative to the notebook's working directory.
root_dir = './data/MICCAI_BraTS_2019_Data_Training/'
# Dataset only (no DataLoader is created in this cell).
# NOTE(review): limit=100 presumably caps the number of cases loaded — confirm
# in utils/dataset.py.
dataset = TumorMRIDataset(root_dir, limit=100)

In [23]:
# Split the first sample's tensor shape into its leading size (t1) and the
# list of remaining sizes (t2).
t1, *t2 = dataset[0][0].shape

In [24]:
# Display the leading size and the remaining sizes of the first sample.
t1, t2

(5, [240, 240, 155])

In [14]:

from model.modules.ssm import SSM

# Instantiate the project's SSM module with 256-wide input, inner and state
# dimensions and dt_rank=32.
# NOTE(review): the inline descriptions below are carried over unchanged;
# confirm them against model/modules/ssm.py.
ssm = SSM(
    in_features=256,  # Dimension of the transformer model
    dt_rank=32,  # Rank of the dynamic routing matrix
    dim_inner=256,  # Inner dimension of the transformer model
    d_state=256,  # Dimension of the state vector
)

In [12]:
# Re-execute model/modules/ssm.py after editing it on disk.
# NOTE: `import model` rebinds the notebook-level name `model`, which other
# cells use for network instances. Objects created before the reload (for
# example `ssm`) remain instances of the old class and must be rebuilt.
from importlib import reload

import model
reload(model.modules.ssm)

<module 'model.modules.ssm' from 'c:\\Users\\shera\\Projects\\VisionMamba\\model\\modules\\ssm.py'>