In [5]:
import torch
import torch.nn as nn
from torchsummary import summary
import segmentation_models_pytorch as smp

class Network(nn.Module):

  def __init__(self,embedding_size=128):
    super(Network,self).__init__()
    self.seg_model = smp.DeepLabV3Plus('resnet101',classes=2,in_channels=3,encoder_weights='imagenet',activation=None)
    self.encoder= self.seg_model.encoder
    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    self.fc = nn.Sequential(
                nn.Linear(2048,2048),
                nn.InstanceNorm1d(2048),
                nn.ReLU(),
                nn.Linear(2048,embedding_size),
                nn.InstanceNorm1d(128)
        ) #2048 for ResNet50 and 101;
    self.contrast=False

  def forward(self,x):
    if self.contrast is True:
        # print('yo')
        x =self.encoder(x)
        x=x[-1] #Taking the last feature map only
        x =self.avgpool(x)
        x = torch.flatten(x, 1)
        x =self.fc(x)
        return x
    else:
        return self.seg_model(x)
    
dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = Network()
model.contrast=False

model = nn.DataParallel(model)
model = model.to(dev)

model.eval()
summary(model, (3, 512, 512), 8, 'cuda')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [8, 64, 256, 256]           9,408
            Conv2d-2          [8, 64, 256, 256]           9,408
            Conv2d-3          [8, 64, 256, 256]           9,408
            Conv2d-4          [8, 64, 256, 256]           9,408
       BatchNorm2d-5          [8, 64, 256, 256]             128
       BatchNorm2d-6          [8, 64, 256, 256]             128
       BatchNorm2d-7          [8, 64, 256, 256]             128
       BatchNorm2d-8          [8, 64, 256, 256]             128
              ReLU-9          [8, 64, 256, 256]               0
             ReLU-10          [8, 64, 256, 256]               0
             ReLU-11          [8, 64, 256, 256]               0
             ReLU-12          [8, 64, 256, 256]               0
        MaxPool2d-13          [8, 64, 128, 128]               0
        MaxPool2d-14          [8, 64, 1

In [None]:
import torch
import torch.nn as nn
from torchsummary import summary
import segmentation_models_pytorch as smp

class Network(nn.Module):

  def __init__(self,embedding_size=128):
    super(Network,self).__init__()
    self.seg_model = smp.DeepLabV3Plus('resnet50',classes=2,in_channels=3,encoder_weights='imagenet',activation=None)
    self.encoder= self.seg_model.encoder
    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    self.fc = nn.Sequential(
                nn.Linear(2048,2048),
                nn.InstanceNorm1d(2048),
                nn.ReLU(),
                nn.Linear(2048,embedding_size),
                nn.InstanceNorm1d(128)
        ) #2048 for ResNet50 and 101;
    self.contrast=False

  def forward(self,x):
    if self.contrast is True:
        # print('yo')
        x =self.encoder(x)
        x=x[-1] #Taking the last feature map only
        x =self.avgpool(x)
        x = torch.flatten(x, 1)
        x =self.fc(x)
        return x
    else:
        return self.seg_model(x)
    
dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = Network()
model.contrast=False

model = nn.DataParallel(model)
model = model.to(dev)

model.eval()
summary(model, (3, 512, 512), 8, 'cuda')