In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [2]:
from efficientnet_pytorch import EfficientNet
model = EfficientNet.from_pretrained('efficientnet-b7')
model.to(device=device)

Loaded pretrained weights for efficientnet-b7


EfficientNet(
  (_conv_stem): Conv2dStaticSamePadding(
    3, 64, kernel_size=(3, 3), stride=(2, 2), bias=False
    (static_padding): ZeroPad2d(padding=(0, 1, 0, 1), value=0.0)
  )
  (_bn0): BatchNorm2d(64, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
  (_blocks): ModuleList(
    (0): MBConvBlock(
      (_depthwise_conv): Conv2dStaticSamePadding(
        64, 64, kernel_size=(3, 3), stride=[1, 1], groups=64, bias=False
        (static_padding): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
      )
      (_bn1): BatchNorm2d(64, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
      (_se_reduce): Conv2dStaticSamePadding(
        64, 16, kernel_size=(1, 1), stride=(1, 1)
        (static_padding): Identity()
      )
      (_se_expand): Conv2dStaticSamePadding(
        16, 64, kernel_size=(1, 1), stride=(1, 1)
        (static_padding): Identity()
      )
      (_project_conv): Conv2dStaticSamePadding(
        64, 32, kernel_siz

In [3]:
blocks = []
blocks.append([list(model.children())[0]])
blocks.append([list(model.children())[1]])

MBBlocks_in_block = [4, 7, 7, 10, 10, 13, 4]
k = 0
for i in MBBlocks_in_block:
    block = []
    for j in range(0, i, 1):
        block.append(list(model.children())[2][k])
        k += 1
    blocks.append(block)

blocks

[[Conv2dStaticSamePadding(
    3, 64, kernel_size=(3, 3), stride=(2, 2), bias=False
    (static_padding): ZeroPad2d(padding=(0, 1, 0, 1), value=0.0)
  )],
 [BatchNorm2d(64, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)],
 [MBConvBlock(
    (_depthwise_conv): Conv2dStaticSamePadding(
      64, 64, kernel_size=(3, 3), stride=[1, 1], groups=64, bias=False
      (static_padding): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
    )
    (_bn1): BatchNorm2d(64, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
    (_se_reduce): Conv2dStaticSamePadding(
      64, 16, kernel_size=(1, 1), stride=(1, 1)
      (static_padding): Identity()
    )
    (_se_expand): Conv2dStaticSamePadding(
      16, 64, kernel_size=(1, 1), stride=(1, 1)
      (static_padding): Identity()
    )
    (_project_conv): Conv2dStaticSamePadding(
      64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False
      (static_padding): Identity()
    )
    (_bn2): Batch

In [4]:
class EFFUnet_encoder(nn.Module):
    def __init__(self, reqd_blocks):
        super().__init__()
        self.required_blocks = reqd_blocks

    def forward(self, x):
        block_outputs = []
        for blocks in self.required_blocks:
            for block in blocks:
                x = block(x)
                # print("Shape", x.shape)
            block_outputs.append(x)

        return x, block_outputs

In [5]:
# effunet_encoder = EFFUnet_encoder(blocks)
# x = torch.randn(1,3,310,250)
# out, block_outputs = effunet_encoder(x)
# out.shape, len(block_outputs)

In [6]:
class Upconvolution(nn.Module):
    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.upconv = nn.ConvTranspose2d(input_channels, output_channels, kernel_size=2, stride=2)

    def forward(self, x):
        return self.upconv(x)

        # out = F.interpolate(out, (height, width), mode='bilinear', align_corners=True)

class EFFUnet_decoder(nn.Module):
    def __init__(self, n_classes):
        super().__init__()
        self.upconvolution_1 = Upconvolution(640, 512)
        self.upconvolution_2 = Upconvolution(896, 256)
        self.upconvolution_3 = Upconvolution(416, 128)
        self.upconvolution_4 = Upconvolution(208, 64)
        self.upconvolution_5 = Upconvolution(112, 16)

        # self.encoder_outputs = encoder_outputs

        self.conv = nn.Conv2d(16, n_classes, kernel_size=(1,1), stride=(1,1))

    def forward(self, x, encoder_outputs):
        x = self.upconvolution_1(x)
        print(x.size())
        encoder_output = F.interpolate(encoder_outputs[7], scale_factor=2, mode='bilinear', align_corners=True)
        print(encoder_output.size())
        x = concat(x, encoder_output)
        print(x.size())

        x = self.upconvolution_2(x)
        print(x.size())
        encoder_output = F.interpolate(encoder_outputs[5], scale_factor=2, mode='bilinear', align_corners=True)
        print(encoder_output.size())
        x = concat(x, encoder_output)
        # print(x.size())

        x = self.upconvolution_3(x)
        # print(x.size())
        encoder_output = F.interpolate(encoder_outputs[4], scale_factor=2, mode='bilinear', align_corners=True)
        # print(encoder_output.size())
        x = concat(x, encoder_output)
        # print(x.size())

        x = self.upconvolution_4(x)
        # print(x.size())
        encoder_output = F.interpolate(encoder_outputs[3], scale_factor=2, mode='bilinear', align_corners=True)
        # print(encoder_output.size())
        x = concat(x, encoder_output)
        # print(x.size())

        x = self.upconvolution_5(x)
        # print(x.size())
        x = self.conv(x)
        # print(x.size())

        return x

def concat(encoder, decoder):
    return torch.cat((decoder, encoder), dim=1)



In [7]:
# effunet_decoder = EFFUnet_decoder(2)
# # x = torch.randn(1,3,320,224)
# out = effunet_decoder(out, block_outputs)
# out.shape

In [8]:
class EFFUnet(nn.Module):
    def __init__(self, blocks, n_classes):
        super().__init__()
        self.effunet_encoder = EFFUnet_encoder(blocks)
        self.effunet_decoder = EFFUnet_decoder(n_classes=n_classes)

    def forward(self, x):
        x, block_outputs = self.effunet_encoder(x)
        out = self.effunet_decoder(x, block_outputs)

        return out
        

In [9]:
model = EFFUnet(blocks, n_classes = 2)
model.to(device=device)
x = torch.randn(1,3,320,224).to(device=device)
out = model(x)
out.size()

torch.Size([1, 512, 20, 14])
torch.Size([1, 384, 20, 14])
torch.Size([1, 896, 20, 14])
torch.Size([1, 256, 40, 28])
torch.Size([1, 160, 40, 28])


torch.Size([1, 2, 320, 224])