In [1]:
from monai.networks.nets import resnet50, EfficientNetBN
from monai.networks.nets import get_efficientnet_image_size
import torch
import torch.nn as nn

import segmentation_models_pytorch as smp

When feature extracting, we set the `.requires_grad` attribute of the pretrained parameters to `False`. By default, when we load a pretrained model, all of its parameters have `.requires_grad=True`, which is fine if we are training from scratch or fine-tuning. However, if we are feature extracting and only want to compute gradients for the newly initialized layer, then none of the other parameters should require gradients. Below, the pretrained encoder is frozen in this way before the new classification layer is added.

In [6]:
# Instantiate a U-Net with an ImageNet-pretrained ResNet-50 encoder.
# The later cells shown here only reuse the encoder part of this model.
model = smp.Unet(
    classes=3,                     # output channels of the segmentation head
    in_channels=1,                 # single-channel (gray-scale) input
    encoder_weights="imagenet",    # initialise the encoder from ImageNet weights
    encoder_name="resnet50",       # backbone architecture
)

# Collect the encoder's direct child modules into a plain Python list.
encoder = [*model.encoder.children()]
print(encoder)

[Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False), BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), ReLU(inplace=True), MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False), Sequential(
  (0): Bottleneck(
    (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (downsample): Sequential(
      (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=Tr

In [7]:
# Chain the encoder's child modules into a single module, then freeze every
# parameter in it so that only layers added afterwards receive gradients.
features = nn.Sequential(*encoder)
features.requires_grad_(False)

In [15]:
# Classifier = feature extractor followed by a small pooling + linear head.
# NOTE: this rebinds `model`, replacing the smp.Unet created earlier.
head_layers = [
    nn.AdaptiveAvgPool2d(output_size=(1, 1)),  # collapse each feature map to 1x1
    nn.Flatten(),                              # (N, C, 1, 1) -> (N, C)
    nn.Linear(2048, 1),                        # in_features must equal C produced by `features`
]
model = nn.Sequential(features, *head_layers)

print(model)

Sequential(
  (0): Sequential(
    (0): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(64, 256

In [16]:
# Smoke test: push a random batch (4 samples, 1 channel, 256x256) through the model.
inputs = torch.rand(4, 1, 256, 256)

# Run the check in eval mode and without autograd. In training mode the forward
# pass would update the BatchNorm running statistics of the pretrained encoder
# with random data, even though its parameters have requires_grad=False.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        output = model(inputs)
finally:
    model.train(was_training)  # restore whichever mode the model was in

# One row of predictions per input sample is expected; with the single-unit
# nn.Linear(2048, 1) head this prints torch.Size([4, 1]).
assert output.shape[0] == inputs.shape[0]
print(output.shape)

torch.Size([4, 5])


In [None]:
from torchvision import datasets, models, transforms

# Load an ImageNet-pretrained torchvision ResNet-50 and print its architecture.
# `pretrained=True` is deprecated (torchvision >= 0.13) in favour of the
# `weights=` enum; IMAGENET1K_V1 is the checkpoint that `pretrained=True` loads.
# Older torchvision releases lack the enum, so fall back to the legacy flag there.
resnet50_weights = getattr(models, "ResNet50_Weights", None)
if resnet50_weights is not None:
    model_ft = models.resnet50(weights=resnet50_weights.IMAGENET1K_V1)
else:
    model_ft = models.resnet50(pretrained=True)

print(model_ft)

In [None]:
# `encoder` is a plain Python list of modules (built from `.children()`), so it has
# no `out_channels` attribute, and `model` no longer refers to the smp.Unet by this
# point. Query a segmentation_models_pytorch encoder object directly instead.
# weights=None avoids loading a checkpoint; only the channel layout is needed.
resnet50_encoder = smp.encoders.get_encoder("resnet50", in_channels=1, weights=None)
print(resnet50_encoder.out_channels)

In [None]:
# NOTE: everything below is a commented-out draft of an `MRNet` classifier and is
# never executed. Before re-enabling it: `ResNet` is not imported in the visible
# cells, and the attribute is spelled `classifer` (used consistently in
# __init__ and forward).
# import torch
# import torch.nn as nn


# class MRNet(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.pretrained_model = ResNet(pretrained=True)
#         self.pooling_layer = nn.AdaptiveAvgPool2d(1)
#         self.classifer = nn.Linear(256, 2)

#     def forward(self, x):
#         x = torch.squeeze(x, dim=0) 
#         features = self.pretrained_model.features(x)
#         pooled_features = self.pooling_layer(features)
#         pooled_features = pooled_features.view(pooled_features.size(0), -1)
#         flattened_features = torch.max(pooled_features, 0, keepdim=True)[0]
#         output = self.classifer(flattened_features)
#         return output