In [14]:
import torch

def inspect_model_size(model):
    """Print the parameter count of ``model`` and the memory its parameters use.

    Three lines are written to stdout: the raw parameter count, the count in
    millions, and the parameter storage in MB (1 MB = 1e6 bytes).

    The memory figure is computed from each parameter's own
    ``element_size()`` rather than a fixed 4 bytes per value, so half-,
    double- and mixed-precision models are reported correctly. For an
    all-float32 model the printed numbers are unchanged. Only
    ``model.parameters()`` is counted; buffers are not included.

    Args:
        model: Object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).

    Returns:
        None.
    """
    param_count = 0
    param_bytes = 0
    # Single pass: accumulate the element count and the byte size together.
    for param in model.parameters():
        n_elements = param.numel()
        param_count += n_elements
        param_bytes += n_elements * param.element_size()

    print(f'Total number of parameters in the model: {param_count}')
    # convert to M
    print(f'Total number of parameters in the model: {param_count / 1e6} M')
    # total memory in MB, based on the actual dtype of every parameter
    print(f'Total memory {param_bytes / 1e6} MB')

In [17]:
from monai.networks.nets import UNet

# Constructor settings for the UNet being measured: 2 spatial dimensions,
# 3 channels in and 3 channels out.
unet_kwargs = dict(
    spatial_dims=2,
    in_channels=3,
    out_channels=3,
    channels=(64, 128, 256, 512),
    strides=(2, 2, 2),
    num_res_units=6,
)
model = UNet(**unet_kwargs)

# Report parameter count and memory estimate for this configuration.
inspect_model_size(model)

Total number of parameters in the model: 18945844
Total number of parameters in the model: 18.945844 M
Total memory 75.783376 MB


In [10]:
# now we try SegResNetDS

# network:
#   {_target_: SegResNetDS,
#   init_filters: 32,
#   blocks_down: [1, 2, 2, 4, 4, 4],
#   norm: INSTANCE, # INSTANCE , BATCH
#   in_channels: '@input_channels',
#   out_channels: '@output_classes',
#   dsdepth: 4}

In [16]:
from monai.networks.nets import SegResNetDS

# Both channel counts are assumed to be 3 (inputs and output classes).
input_channels = output_classes = 3

# 2-D SegResNetDS; channel counts come from the variables above.
segresnetds_model = SegResNetDS(
    spatial_dims=2,
    in_channels=input_channels,
    out_channels=output_classes,
    init_filters=32,
    blocks_down=[1, 2, 2, 4, 4, 4],
    dsdepth=4,
    norm='INSTANCE',  # alternative: 'BATCH'
)

# Report parameter count and memory estimate for this configuration.
inspect_model_size(segresnetds_model)

Total number of parameters in the model: 118738700
Total number of parameters in the model: 118.7387 M
Total memory 474.9548 MB


In [21]:
from monai.networks.nets import SEResNet101

# SEResNet101 for 2-D, 3-channel input and 3 output classes; no `pretrained`
# argument is passed here.
seresnet101_model = SEResNet101(spatial_dims=2, in_channels=3, num_classes=3)

# Report parameter count and memory estimate for this configuration.
inspect_model_size(seresnet101_model)

Total number of parameters in the model: 47284019
Total number of parameters in the model: 47.284019 M
Total memory 189.136076 MB


In [23]:
from monai.networks.nets import SEResNet101

import ssl
import torch

# SECURITY: this replaces the process-wide default HTTPS context factory with
# the unverified one, so HTTPS requests that rely on the default context for
# the rest of this kernel session skip certificate verification.
# NOTE(review): presumably added so the pretrained-weight download in this
# cell succeeds - confirm, and prefer fixing the local CA bundle instead of
# disabling verification globally.
ssl._create_default_https_context = ssl._create_unverified_context

def inspect_model_size(model):
    """Print the parameter count of ``model`` and the memory its parameters use.

    NOTE: an earlier cell of this notebook defines a function with the same
    name; this definition shadows it from this cell onward.

    Three lines are written to stdout: the raw parameter count, the count in
    millions, and the parameter storage in MB (1 MB = 1e6 bytes).

    The memory figure is computed from each parameter's own
    ``element_size()`` rather than a fixed 4 bytes per value, so half-,
    double- and mixed-precision models are reported correctly. For an
    all-float32 model the printed numbers are unchanged. Only
    ``model.parameters()`` is counted; buffers are not included.

    Args:
        model: Object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).

    Returns:
        None.
    """
    param_count = 0
    param_bytes = 0
    # Single pass: accumulate the element count and the byte size together.
    for param in model.parameters():
        n_elements = param.numel()
        param_count += n_elements
        param_bytes += n_elements * param.element_size()

    print(f'Total number of parameters in the model: {param_count}')
    # convert to M
    print(f'Total number of parameters in the model: {param_count / 1e6} M')
    # total memory in MB, based on the actual dtype of every parameter
    print(f'Total memory {param_bytes / 1e6} MB')

# Same 2-D / 3-channel / 3-class settings, this time with pretrained=True;
# the recorded output of this cell shows a checkpoint download on first use.
seresnet101_kwargs = dict(spatial_dims=2, in_channels=3, num_classes=3)
seresnet101_model = SEResNet101(pretrained=True, **seresnet101_kwargs)

# Report parameter count and memory estimate for the pretrained variant.
inspect_model_size(seresnet101_model)

Downloading: "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth" to /Users/WXC321/.cache/torch/hub/checkpoints/se_resnet101-7e38fcc6.pth
100%|██████████| 189M/189M [09:15<00:00, 356kB/s] 

Total number of parameters in the model: 47284019
Total number of parameters in the model: 47.284019 M
Total memory 189.136076 MB





In [24]:
# Write the model's state_dict (not the module object itself) to disk.
save_name = '../seresnet101_model_pretrained.pth'
pretrained_state = seresnet101_model.state_dict()
torch.save(pretrained_state, save_name)
print(f'Model saved to {save_name}')

Model saved to ../seresnet101_model_pretrained.pth
