In [1]:
# import os
# import torch
# import torch.nn as nn
# import torch.optim as optim
# import nibabel as nib
# import numpy as np
# from torch.utils.data import DataLoader, Dataset
# from torch.optim.lr_scheduler import StepLR
# from sklearn.metrics import accuracy_score, confusion_matrix, f1_score, roc_auc_score, average_precision_score
# from sklearn.model_selection import train_test_split, KFold
# from torch import autocast, GradScaler

# # Define the SSM Block for the 3D model
# class SSMBlock3D(nn.Module):
#     def __init__(self, input_dim, hidden_dim, output_dim):
#         super(SSMBlock3D, self).__init__()
#         # Transition matrix A, input matrix B, and observation matrix C
#         self.A = nn.Parameter(torch.randn(hidden_dim, hidden_dim))  # Transition matrix
#         self.B = nn.Parameter(torch.randn(input_dim, hidden_dim))    # Input matrix
#         self.C = nn.Parameter(torch.randn(hidden_dim, output_dim))   # Observation matrix
        
#         # Non-linear activation function
#         self.activation = nn.GELU()

#     def forward(self, x):
#         # x: (batch_size, depth, height, width, input_dim)
#         B, D, H, W, C = x.shape
#         x = x.reshape(B * D * H * W, C)  # Flatten spatial dimensions into one sequence

#         # Initialize the hidden state for each sequence
#         h = torch.zeros(B * D * H * W, self.B.shape[1]).to(x.device)

#         # Apply the input transformation for the entire input tensor
#         h = self.activation(torch.matmul(x, self.B))  # Input transformation

#         # Update the hidden state with the transition matrix A
#         h = self.activation(torch.matmul(h, self.A))

#         # Apply the observation matrix C to get the output
#         out = torch.matmul(h, self.C)

#         # Reshape output back to original spatial dimensions
#         out = out.reshape(B, D, H, W, -1)
        
#         return out


# # 3D Vision Mamba Model with SSM blocks and downsampling
# class VisionMamba3D(nn.Module):
#     def __init__(self, img_size=(240, 240, 155), patch_size=(4, 4, 4), in_chans=1, num_classes=2, embed_dim=96, depths=[4, 4, 4, 4], hidden_dim=128, output_dim=96):
#         super(VisionMamba3D, self).__init__()

#         self.num_layers = len(depths)
#         self.embed_dim = embed_dim

#         # Embedding layer (linear projection of patches)
#         self.patch_embed = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

#         # SSM layers with downsampling stages
#         self.layers = nn.ModuleList()
#         self.downsamples = nn.ModuleList()  # Define separate downsample layers

#         for i_layer in range(self.num_layers):
#             # Add SSM blocks for this stage, pass the correct input_dim (embed_dim)
#             stage = nn.Sequential(
#                 *[SSMBlock3D(input_dim=self.embed_dim, hidden_dim=hidden_dim, output_dim=self.embed_dim) for _ in range(depths[i_layer])]
#             )
#             self.layers.append(stage)

#             # Add downsampling layers after each stage except the last
#             if i_layer < self.num_layers - 1:
#                 downsample = nn.Conv3d(self.embed_dim, self.embed_dim * 2, kernel_size=(2, 2, 2), stride=(2, 2, 2))
#                 self.downsamples.append(downsample)
#                 self.embed_dim *= 2  # Update embedding dimension for the next layer

#         # Final bottleneck layer
#         self.bottleneck = nn.Conv3d(self.embed_dim, self.embed_dim, kernel_size=1)

#         # Final MLP classifier based on bottleneck features
#         self.fc = nn.Sequential(
#             nn.Linear(self.embed_dim, 512),
#             nn.ReLU(),
#             nn.Linear(512, num_classes)
#         )

#     def forward(self, x):
#         # Patch embedding
#         x = self.patch_embed(x)  # Output size should be (B, embed_dim, D/patch, H/patch, W/patch)

#         # Reshape to (B, D, H, W, C) format for SSM layers
#         B, C, D, H, W = x.shape
#         x = x.permute(0, 2, 3, 4, 1)  # Permute to (B, D, H, W, C)

#         print(x.shape)
#         # SSM blocks with downsampling
#         # for i_layer, layer in enumerate(self.layers):
#         for i_layer, layer in enumerate(1):
#             x = layer(x)
#             if i_layer < self.num_layers - 1:
#                 x = x.permute(0, 4, 1, 2, 3)  # Permute back to (B, C, D, H, W) for downsampling
#                 x = self.downsamples[i_layer](x)  # Apply downsampling
#                 x = x.permute(0, 2, 3, 4, 1)  # Permute back to (B, D, H, W, C)

#         print(x.shape)
#         # Bottleneck feature extraction
#         x = x.permute(0, 4, 1, 2, 3)  # Permute to (B, C, D, H, W)
#         x = self.bottleneck(x)
#         print('train log bottleneck', x)

#         # Global average pooling
#         x = x.mean(dim=[2, 3, 4])  # Pool over spatial dimensions

#         # Classification
#         x = self.fc(x)

#         return x


# # Dataset for 3D MRI images
# class TumorMRIDataset(Dataset):
#     def __init__(self, root_dir, limit=None):
#         self.root_dir = root_dir
#         self.samples = self._load_samples(root_dir, limit)

#     def _load_samples(self, root_dir, limit=None):
#         samples = []
#         for label in ['HGG', 'LGG']:
#             label_specific_sample_count = 0
#             folder_path = os.path.join(root_dir, label)
#             for patient_folder in os.listdir(folder_path):
#                 # Find any file that ends with 't1ce.nii'
#                 for file_name in os.listdir(os.path.join(folder_path, patient_folder)):
#                     if file_name.endswith('t1ce.nii'):
#                         img_path = os.path.join(folder_path, patient_folder, file_name)
#                         samples.append((img_path, 0 if label == 'HGG' else 1))
#                         label_specific_sample_count += 1
#                         if limit is not None and label_specific_sample_count >= limit:
#                             break
#                 if limit is not None and label_specific_sample_count >= limit:
#                     break
#         return samples

#     def __len__(self):
#         return len(self.samples)

#     def __getitem__(self, idx):
#         file_path, label = self.samples[idx]
#         img = nib.load(file_path).get_fdata()
#         print(type(img), img.shape)
#         img = self._pad_or_crop(img)
#         return torch.tensor(img, dtype=torch.float32).permute(2,0,1).unsqueeze(0), torch.tensor(label, dtype=torch.long)

#     def _pad_or_crop(self, img):
#         target_shape = img.shape
#         pad_size = [(max(0, target - img_dim)) for target, img_dim in zip(target_shape, img.shape)]
#         pad_widths = [(p // 2, p - p // 2) for p in pad_size]
#         img_padded = np.pad(img, pad_widths, mode='constant', constant_values=0)
#         return img_padded[:target_shape[0], :target_shape[1], :target_shape[2]]


# # Split the dataset into train and test sets by class
# def split_dataset_by_class(dataset, train_ratio=0.8):
#     HGG_samples = [sample for sample in dataset.samples if sample[1] == 0]
#     LGG_samples = [sample for sample in dataset.samples if sample[1] == 1]

#     # Split each class
#     HGG_train, HGG_test = train_test_split(HGG_samples, train_size=train_ratio, shuffle=True)
#     LGG_train, LGG_test = train_test_split(LGG_samples, train_size=train_ratio, shuffle=True)

#     # Combine the train and test samples
#     train_samples = HGG_train + LGG_train
#     test_samples = HGG_test + LGG_test

#     return train_samples, test_samples


# # Training function with mixed precision
# def train_model(train_loader, model, criterion, optimizer, scheduler, device, scaler):
#     model.train()
#     running_loss = 0.0
#     for images, labels in train_loader:
#         images, labels = images.to(device), labels.to(device)
        
#         optimizer.zero_grad()

#         # Use autocast for mixed precision
#         with autocast(device.type):
#             outputs = model(images)
#             loss = criterion(outputs, labels)

#         # Scale the loss before backward pass
#         scaler.scale(loss).backward()
#         scaler.step(optimizer)
#         scaler.update()

#         # Now step the scheduler after optimizer step
#         scheduler.step()
#         running_loss += loss.item()
#         print(loss)
#         print(loss.item())
#         print(running_loss)
#     return running_loss / len(train_loader)

# # Test and get predictions + loss 
# def test_model(test_loader, model, criterion, device):
#     model.eval()
#     pred_labels = []
#     running_loss = 0
#     with torch.no_grad():
#         for images, labels in test_loader:
#             images, labels = images.to(device), labels.to(device)
#             outputs = model(images)
#             _, preds = torch.max(outputs, 1)
#             pred_labels.extend(preds.cpu().numpy())
#             loss = criterion(outputs, labels)
#             running_loss += loss.item()
#     return pred_labels, running_loss / len(test_loader)



# # Evaluation function
# def evaluate_model(test_loader, model, criterion, device):
#     model.eval()
#     true_labels, pred_labels = [], []
#     with torch.no_grad():
#         for images, labels in test_loader:
#             images, labels = images.to(device), labels.to(device)
#             outputs = model(images)
#             _, preds = torch.max(outputs, 1)
#             true_labels.extend(labels.cpu().numpy())
#             pred_labels.extend(preds.cpu().numpy())

#     acc = accuracy_score(true_labels, pred_labels)
#     cm = confusion_matrix(true_labels, pred_labels)
#     f1 = f1_score(true_labels, pred_labels)
#     auc = roc_auc_score(true_labels, pred_labels)
#     auc_pr = average_precision_score(true_labels, pred_labels)

#     return acc, cm, f1, auc, auc_pr


# # Cross-validation function
# def cross_validate(train_loader, model_class, criterion, optimizer_class, scheduler_class, device, num_epochs=20, k_folds=5):
#     kf = KFold(n_splits=k_folds, shuffle=True)
#     fold_results = []

#     for fold, (train_idx, val_idx) in enumerate(kf.split(train_loader.dataset)):
#         print(f"Fold {fold+1}/{k_folds}")

#         train_subset = torch.utils.data.Subset(train_loader.dataset, train_idx)
#         val_subset = torch.utils.data.Subset(train_loader.dataset, val_idx)

#         train_loader_fold = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
#         val_loader_fold = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

#         # Instantiate model and optimizer
#         model = model_class().to(device)
#         optimizer = optim.AdamW(model.parameters(), lr=initial_lr, weight_decay=weight_decay)
#         scheduler = scheduler_class(optimizer)

#         # Initialize GradScaler for mixed precision training
#         scaler = GradScaler()

#         # Train and evaluate on each fold
#         for epoch in range(num_epochs):
#             train_loss = train_model(train_loader_fold, model, criterion, optimizer, scheduler, device, scaler)
#             print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')

#         val_acc, _, val_f1, val_auc, val_auc_pr = evaluate_model(val_loader_fold, model, criterion, device)
#         fold_results.append((val_acc, val_f1, val_auc, val_auc_pr))

#     # Return average metrics across folds
#     avg_acc = np.mean([r[0] for r in fold_results])
#     avg_f1 = np.mean([r[1] for r in fold_results])
#     avg_auc = np.mean([r[2] for r in fold_results])
#     avg_auc_pr = np.mean([r[3] for r in fold_results])

#     return avg_acc, avg_f1, avg_auc, avg_auc_pr


# # Results saving function
# def save_results_to_file(file_path, cv_acc, test_acc, test_auc, test_auc_pr):
#     with open(file_path, 'w') as f:
#         f.write(f"Cross-validation Accuracy: {cv_acc:.4f}\n")
#         f.write(f"Test Accuracy: {test_acc:.4f}\n")
#         f.write(f"Test AUC: {test_auc:.4f}\n")
#         f.write(f"Test AUC-PR: {test_auc_pr:.4f}\n")


In [1]:
from torchinfo import summary
from model.VisionMamba3D import VisionMamba3D

model = VisionMamba3D(
    img_size=(155, 240, 240), patch_size=(4, 4, 3), in_chans=1, num_classes=2, embed_dim=96, depths=[4, 4, 4, 4], hidden_dim=128,
    ).to('cuda')
summary(model)

Layer (type:depth-idx)                   Param #
VisionMamba3D                            --
├─Conv3d: 1-1                            4,704
├─ModuleList: 1-2                        --
│    └─Sequential: 2-1                   --
│    │    └─SSMBlock3D: 3-1              40,960
│    │    └─SSMBlock3D: 3-2              40,960
│    │    └─SSMBlock3D: 3-3              40,960
│    │    └─SSMBlock3D: 3-4              40,960
│    └─Sequential: 2-2                   --
│    │    └─SSMBlock3D: 3-5              65,536
│    │    └─SSMBlock3D: 3-6              65,536
│    │    └─SSMBlock3D: 3-7              65,536
│    │    └─SSMBlock3D: 3-8              65,536
│    └─Sequential: 2-3                   --
│    │    └─SSMBlock3D: 3-9              114,688
│    │    └─SSMBlock3D: 3-10             114,688
│    │    └─SSMBlock3D: 3-11             114,688
│    │    └─SSMBlock3D: 3-12             114,688
│    └─Sequential: 2-4                   --
│    │    └─SSMBlock3D: 3-13             212,992
│    │    └

In [5]:
from torchinfo import summary
from model.VisionMamba3D_2 import VisionMamba3D

model = VisionMamba3D(
    img_size=(155, 240, 240), patch_size=(4, 4, 3), in_chans=1, num_classes=2, depths=[4, 4, 4, 4],
    ).to('cuda')
summary(model)

Layer (type:depth-idx)                             Param #
VisionMamba3D                                      --
├─Conv3d: 1-1                                      4,704
├─ModuleList: 1-2                                  --
│    └─Sequential: 2-1                             --
│    │    └─TransformerBlockWithSSM: 3-1           111,264
│    │    └─TransformerBlockWithSSM: 3-2           111,264
│    │    └─TransformerBlockWithSSM: 3-3           111,264
│    │    └─TransformerBlockWithSSM: 3-4           111,264
│    └─Sequential: 2-2                             --
│    │    └─TransformerBlockWithSSM: 3-5           443,712
│    │    └─TransformerBlockWithSSM: 3-6           443,712
│    │    └─TransformerBlockWithSSM: 3-7           443,712
│    │    └─TransformerBlockWithSSM: 3-8           443,712
│    └─Sequential: 2-3                             --
│    │    └─TransformerBlockWithSSM: 3-9           1,772,160
│    │    └─TransformerBlockWithSSM: 3-10          1,772,160
│    │    └─Transfor

In [6]:
model.device

device(type='cuda', index=0)

In [4]:
import torch
# from mamba_ssm import Mamba2
from mamba_ssm.modules.mamba2_simple import Mamba2Simple as Mamba

batch, length, dim = 2, 64, 128
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=64,  # SSM state expansion factor, typically 64 or 128
    headdim=4,  # Attention head dimension
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape

RuntimeError: CUDA error: no kernel image is available for execution on the device
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [2]:
import torch
import torch.nn as nn

class SimpleS4(nn.Module):
    def __init__(self, d_model, seq_len):
        super(SimpleS4, self).__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.kernel = nn.Parameter(torch.randn(seq_len))
        self.linear = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        b, n, d = x.shape
        # Convolution with the state-space kernel
        x_fft = torch.fft.rfft(x, dim=1)
        kernel_fft = torch.fft.rfft(self.kernel, n=n)
        kernel_fft = kernel_fft.view(1, -1, 1) 
        out = torch.fft.irfft(x_fft * kernel_fft, n=n, dim=1)
        return self.linear(out)

# device = torch.device('cuda')
device = torch.device('cpu')

# Usage
seq_len = 102400
d_model = 512
x = torch.randn(1, seq_len, d_model).to(device)
s4_model = SimpleS4(d_model, seq_len).to(device)


In [5]:

output = s4_model(x)

In [6]:
x.shape, output.shape

(torch.Size([1, 102400, 512]), torch.Size([1, 102400, 512]))

In [13]:
!nvidia-smi

Thu Oct  3 01:16:23 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 551.61                 Driver Version: 551.61         CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                     TCC/WDDM  | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce GTX 1050      WDDM  |   00000000:01:00.0 Off |                  N/A |
| N/A   33C    P8             N/A / ERR!  |      46MiB /   4096MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [11]:
from importlib import reload
from utils.dataset import TumorMRIDataset
import utils
reload(utils.dataset)

<module 'utils.dataset' from 'c:\\Users\\shera\\Projects\\VisionMamba\\utils\\dataset.py'>

In [12]:
from utils.dataset import TumorMRIDataset
root_dir = './data/MICCAI_BraTS_2019_Data_Training/'
# Dataset and DataLoader
dataset = TumorMRIDataset(root_dir, limit=100)

In [23]:
t1, *t2 = dataset[0][0].shape

In [24]:
t1, t2

(5, [240, 240, 155])

In [1]:
import torch
x = torch.randn(2, 8, 50, 120, 120) # (B, C, D, H, W)
print(x.shape)
x = x.permute(0, 2, 3, 4, 1) # (B, D, H, W, C)
print(x.shape)
x = x.reshape(x.shape[0], -1, x.shape[-1]) # (B, D*H*W, C)
print(x.shape)
x = x.reshape(x.shape[0], -1, x.shape[-1]) # (B, D*H*W, C)
print(x.shape)

torch.Size([2, 8, 50, 120, 120])
torch.Size([2, 50, 120, 120, 8])
torch.Size([2, 720000, 8])
torch.Size([2, 720000, 8])


In [14]:

from model.modules.ssm import SSM

ssm = SSM(
    in_features=256,  # Dimension of the transformer model
    dt_rank=32,  # Rank of the dynamic routing matrix
    dim_inner=256,  # Inner dimension of the transformer model
    d_state=256,  # Dimension of the state vector
)

In [12]:
from importlib import reload

import model
reload(model.modules.ssm)

<module 'model.modules.ssm' from 'c:\\Users\\shera\\Projects\\VisionMamba\\model\\modules\\ssm.py'>

In [20]:
import time

In [21]:
time.time()

1728368110.9876351

In [26]:
a = [1,2,3,4,5]
b = [6,7,8,9,10]

torch.save({'a': a, 'b': b}, 'test.pt')

In [27]:
x = torch.load('test.pt')

  x = torch.load('test.pt')


In [30]:
type(x['a'])

list

In [31]:
import torch
x = torch.randn(2, 51*60*60, 256)


In [33]:
f"{x.numpy()}"

'[[[ 1.292295    0.12974325  1.9748904  ...  0.89918077  0.5026496\n    0.495291  ]\n  [-1.1271591   0.507634   -0.22853649 ...  1.1101712  -0.9780393\n   -1.4631107 ]\n  [ 0.84693134 -1.4322822   0.93316334 ...  0.7821969  -0.892769\n   -1.0943215 ]\n  ...\n  [-2.1517289   0.88686776  1.237539   ... -1.1798173   0.28577477\n    1.0158257 ]\n  [ 0.9999575  -1.5350035  -0.18565762 ... -0.7606606  -0.8681743\n   -0.50495344]\n  [-0.9347093  -0.40063527  0.91297567 ...  0.16093272  0.8936355\n   -0.60487616]]\n\n [[ 0.8504815   1.9720266   0.6084179  ... -0.96612144  1.3971791\n   -0.05779519]\n  [-0.80218816 -0.95540756  0.46097296 ...  3.0260103  -0.47619528\n   -2.0453818 ]\n  [ 1.3117734  -0.79209846  1.7043817  ...  0.2105545  -0.6013392\n    1.0901706 ]\n  ...\n  [-2.1645515   0.10518072  0.4936595  ... -0.74191993  0.59055686\n   -0.7300158 ]\n  [-1.4442246  -0.05387694  0.9841744  ...  0.9396119   0.6107726\n    1.0055447 ]\n  [ 0.06521791 -0.4809723  -0.06550229 ...  1.2880242   

In [19]:
torch.randn(2, 183600, 256).unsqueeze(-1).shape

torch.Size([2, 183600, 256, 1])

In [14]:

# Paths and configurations
# root_dir = '/src/workspace/MICCAI_BraTS_2019_Data_Training/'  # Change to your dataset path
root_dir = './MICCAI_BraTS_2019_Data_Training/'  # Change to your dataset path
batch_size = 2
initial_lr = 1e-5  # Lowered learning rate to prevent NaN
num_epochs = 20
weight_decay = 1e-5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset and DataLoader
dataset = TumorMRIDataset(root_dir, limit=100)
train_samples, test_samples = split_dataset_by_class(dataset)
print(f"Train samples: {len(train_samples)}, Test samples: {len(test_samples)}")

# Create datasets and loaders for train and test sets
train_dataset = torch.utils.data.Subset(dataset, [dataset.samples.index(s) for s in train_samples])
test_dataset = torch.utils.data.Subset(dataset, [dataset.samples.index(s) for s in test_samples])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Model, Loss, Optimizer, and Scheduler
model_class = lambda: VisionMamba3D(img_size=(240, 240, 155), patch_size=(4, 4, 4), in_chans=1, num_classes=2)
criterion = nn.CrossEntropyLoss()
optimizer_class = lambda params: optim.AdamW(params, lr=initial_lr, weight_decay=weight_decay)
scheduler_class = lambda opt: StepLR(opt, step_size=5, gamma=0.5)

# Perform cross-validation
# cv_acc, cv_f1, cv_auc, cv_auc_pr = cross_validate(train_loader, model_class, criterion, optimizer_class, scheduler_class, device, num_epochs=num_epochs, k_folds=5)

# Train on the full training set and evaluate on the test set
model = model_class().to(device)
optimizer = optimizer_class(model.parameters())
scheduler = scheduler_class(optimizer)

# Initialize GradScaler for full training
scaler = GradScaler()

# Train model on the entire training set
train_losses, test_losses = [], []
for epoch in range(num_epochs):
    train_loss = train_model(train_loader, model, criterion, optimizer, scheduler, device, scaler)
    _, test_loss = test_model(test_loader, model, criterion, device)
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')

# Evaluate on the test set
test_acc, _, test_f1, test_auc, test_auc_pr = evaluate_model(test_loader, model, criterion, device)


Train samples: 140, Test samples: 36
