In [3]:
import torch.nn as nn
import sys
sys.path.append('/home/fonta42/Desktop/masters-degree/U-Mamba')
from umamba.nnunetv2.nets.UMambaEnc_2d import UMambaEnc
import segmentation_models_pytorch as smp

  # NOTE: a dangling `@autocast(enabled=False)` decorator used to sit here. It
  # decorated nothing and `autocast` was never imported, so it made this cell a
  # syntax error on a fresh kernel; it has been removed.


In [4]:
def count_parameters(model):
    """Count the trainable scalar parameters of ``model``.

    Walks ``model.parameters()`` and adds up ``numel()`` for every parameter
    whose ``requires_grad`` flag is set; frozen parameters are skipped.

    Args:
        model: Object exposing ``parameters()`` (e.g. an ``nn.Module``).

    Returns:
        int: Total element count over all trainable parameters (0 if none).
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

In [10]:
# Baseline for the comparison: segmentation_models_pytorch U-Net with a
# ResNet-50 encoder initialised from ImageNet weights, 3 input channels and a
# single output class.
unet_model = smp.Unet(
    in_channels=3,
    classes=1,
    encoder_name="resnet50",
    encoder_weights="imagenet",
)

# Report how many parameters of the baseline are trainable.
num_params = count_parameters(unet_model)
print(f"Number of trainable parameters in U-Net: {num_params:,}")

Number of trainable parameters in U-Net: 32,521,105


In [9]:
# Configuration for a 2D U-Mamba (UMambaEnc) whose trainable-parameter count is
# meant to be comparable with the ResNet-50 U-Net baseline.

# Input patch size as (height, width).
input_size = (256, 256)

# Number of channels in the input images (3 for RGB).
in_channels = 3

# Number of resolution stages (levels) in the encoder/decoder.
n_stages = 5

# Feature channels used at each stage, shallowest to deepest. Larger values
# increase model capacity.
features_per_stage = [64, 128, 256, 512, 1024]
# One entry is required per stage; fail here rather than inside UMambaEnc.
assert len(features_per_stage) == n_stages, (
    "features_per_stage must contain exactly n_stages entries"
)

# 3x3 convolution kernel for every stage. A comprehension is used instead of
# `[[3, 3]] * n_stages` so each stage gets its own list object rather than
# n_stages references to one shared list.
kernel_sizes = [[3, 3] for _ in range(n_stages)]

# Stride per stage: the first stage keeps the resolution (stride 1); every
# later stage uses stride 2, halving the spatial dimensions. Derived from
# n_stages so the two cannot drift apart.
strides = [[1, 1]] + [[2, 2] for _ in range(n_stages - 1)]

# Convolutional blocks per stage in the encoder and in the decoder.
n_conv_per_stage = 1
n_conv_per_stage_decoder = 1

num_classes = 1                   # Binary segmentation
conv_op = nn.Conv2d
norm_op = nn.InstanceNorm2d
norm_op_kwargs = {'eps': 1e-5, 'affine': True}
nonlin = nn.LeakyReLU
nonlin_kwargs = {'inplace': True}
deep_supervision = False          # Use single output (to match SMP Unet)
stem_channels = features_per_stage[0]

# Instantiate UMambaEnc directly with the configuration above.
umamba_model = UMambaEnc(
    input_size=input_size,
    input_channels=in_channels,
    n_stages=n_stages,
    features_per_stage=features_per_stage,
    conv_op=conv_op,
    kernel_sizes=kernel_sizes,
    strides=strides,
    n_conv_per_stage=n_conv_per_stage,
    num_classes=num_classes,
    n_conv_per_stage_decoder=n_conv_per_stage_decoder,
    conv_bias=True,
    norm_op=norm_op,
    norm_op_kwargs=norm_op_kwargs,
    dropout_op=None,
    dropout_op_kwargs=None,
    nonlin=nonlin,
    nonlin_kwargs=nonlin_kwargs,
    deep_supervision=deep_supervision,
    stem_channels=stem_channels
)

# Report how many parameters of the U-Mamba model are trainable.
num_params = count_parameters(umamba_model)
print(f"Number of trainable parameters in U-Mamba: {num_params:,}")

feature_map_sizes: [[256, 256], [128, 128], [64, 64], [32, 32], [16, 16]]
do_channel_token: [False, False, False, False, True]
MambaLayer: dim: 64
MambaLayer: dim: 256
MambaLayer: dim: 256
Number of trainable parameters in U-Mamba: 31,297,860


In [8]:
import torch
from segment_anything import sam_model_registry
sys.path.append('../MedSAM')


class MedSAM(nn.Module):
    """Box-prompted segmentation wrapper around the three parts of a SAM model.

    Holds the given model's ``image_encoder``, ``mask_decoder`` and
    ``prompt_encoder``. The prompt encoder's parameters are frozen
    (``requires_grad = False``) at construction time; the other two parts are
    left as they were.
    """

    def __init__(self, sam_model):
        """Take the sub-modules from ``sam_model`` and freeze the prompt encoder.

        Args:
            sam_model: Object exposing ``image_encoder``, ``mask_decoder`` and
                ``prompt_encoder`` attributes.
        """
        super().__init__()
        self.image_encoder = sam_model.image_encoder
        self.mask_decoder = sam_model.mask_decoder
        self.prompt_encoder = sam_model.prompt_encoder

        # The prompt encoder is excluded from training.
        for param in self.prompt_encoder.parameters():
            param.requires_grad = False

    def forward(self, images, box):
        """Predict one mask per image from a box prompt.

        Args:
            images: Image tensor; dimensions 2 and 3 are used as the output
                height and width (presumably laid out as (B, C, H, W) --
                confirm against the caller).
            box: Box prompt(s), anything ``torch.as_tensor`` accepts. A 2-D
                input gets a singleton axis inserted at dimension 1 before it
                is passed to the prompt encoder.

        Returns:
            Mask-decoder output, bilinearly resized to the spatial size of
            ``images``.
        """
        image_embedding = self.image_encoder(images)
        box_torch = torch.as_tensor(box, dtype=torch.float32, device=images.device)
        if box_torch.ndim == 2:
            box_torch = box_torch[:, None, :]

        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=None, boxes=box_torch, masks=None
        )

        # Second return value of the mask decoder is discarded.
        low_res_masks, _ = self.mask_decoder(
            image_embeddings=image_embedding,
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )

        # `nn.functional` is used rather than a bare `F`: no
        # `import torch.nn.functional as F` is visible in this notebook, so the
        # alias would raise NameError on the first forward pass.
        ori_res_masks = nn.functional.interpolate(
            low_res_masks,
            size=(images.shape[2], images.shape[3]),
            mode="bilinear",
            align_corners=False,
        )
        return ori_res_masks

# Build the 'vit_b' SAM variant from the MedSAM checkpoint and wrap it in the
# MedSAM module defined above.
# NOTE(review): absolute, machine-specific checkpoint path -- adjust per machine.
MedSAM_CKPT_PATH = "/home/fonta42/Desktop/masters-degree/MedSAM/work_dir/MedSAM/medsam_vit_b.pth"
sam_model_inst = sam_model_registry['vit_b'](
    checkpoint=MedSAM_CKPT_PATH,
)
medsam_model = MedSAM(sam_model_inst)

# Report how many MedSAM parameters are trainable (the prompt encoder is
# frozen by the wrapper, so it does not contribute).
num_params = count_parameters(medsam_model)
print(f"Number of trainable parameters in MedSAM: {num_params:,}")


Number of trainable parameters in MedSAM: 93,729,252
