In [1]:
from torch import nn
import torch
from torchvision.models import resnet18
import import_ipynb
from config import DEVICE

importing Jupyter notebook from config.ipynb


In [2]:
class Decoder(nn.Module):
    """One upsampling decoder stage: conv3x3 -> transposed conv (x2) -> conv3x3.

    Every convolution is followed by BatchNorm and a (shared, in-place) ReLU.
    The intermediate width is ``in_channels // 4``.  The transposed conv uses
    kernel 3, stride 2, padding 1, output_padding 1, so an ``H x W`` input
    becomes ``2H x 2W``.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels of the returned feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        mid_channels = in_channels // 4

        # Parameter-free activation, reused after every normalisation.
        self.relu = nn.ReLU(inplace=True)

        # Stage 1: squeeze channels at the input resolution.
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_channels)

        # Stage 2: learned x2 upsampling.
        self.upconv = nn.ConvTranspose2d(
            mid_channels,
            mid_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(mid_channels)

        # Stage 3: project to the requested output width.
        self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)

    def forward(self, image):
        """Apply the three conv/BN/ReLU stages and return the upsampled map."""
        stages = (
            (self.conv1, self.bn1),
            (self.upconv, self.bn2),
            (self.conv2, self.bn3),
        )
        features = image
        for layer, norm in stages:
            features = self.relu(norm(layer(features)))
        return features

In [3]:
class AutoEncoder(nn.Module):
    """Encoder/decoder network built on a torchvision ResNet-18 backbone.

    Encoder: the ResNet-18 stem (conv1, bn1, relu, maxpool) followed by
    ``layer1`` .. ``layer4``.  Decoder: five ``Decoder`` stages, each of which
    doubles the spatial size, then a small head (conv3x3/BN/ReLU/Dropout2d and
    a 1x1 conv) producing ``num_classes`` channels.

    Args:
        num_classes: number of output channels of the final 1x1 convolution.
        pretrained: if True (default, same as before), initialise the backbone
            with torchvision's ImageNet weights, which may trigger a download.
            Pass False for a randomly initialised backbone, e.g. when a saved
            checkpoint will be loaded afterwards or no network is available.
    """

    def __init__(self, num_classes=1, pretrained=True):
        super(AutoEncoder, self).__init__()
        # Based on ResNet-18.
        base = self._build_backbone(pretrained)

        self.firstconv = nn.Sequential(
            base.conv1,
            base.bn1,
            base.relu,
            base.maxpool,
        )

        self.encoder1 = base.layer1
        self.encoder2 = base.layer2
        self.encoder3 = base.layer3
        self.encoder4 = base.layer4

        self.center = Decoder(512, 512)

        self.decoder1 = Decoder(512, 256)
        self.decoder2 = Decoder(256, 128)
        self.decoder3 = Decoder(128, 64)
        self.decoder4 = Decoder(64, 64)

        self.finalconv = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Dropout2d(p=0.1, inplace=False),
            nn.Conv2d(32, num_classes, kernel_size=1),
        )

    @staticmethod
    def _build_backbone(pretrained):
        """Return a torchvision ResNet-18, with ImageNet weights if requested.

        torchvision >= 0.13 deprecates ``pretrained=`` in favour of
        ``weights=``; ``ResNet18_Weights.IMAGENET1K_V1`` is the checkpoint the
        legacy ``pretrained=True`` loaded.  Older torchvision has no
        ``ResNet18_Weights``, so the legacy argument is used there.
        """
        try:
            from torchvision.models import ResNet18_Weights
        except ImportError:  # torchvision < 0.13
            return resnet18(pretrained=pretrained)
        weights = ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        return resnet18(weights=weights)

    def forward(self, image, extract_feature=False):
        """Encode ``image`` and, unless ``extract_feature`` is set, decode it.

        Args:
            image: input batch; presumably (N, 3, H, W) as expected by the
                ResNet-18 stem.  TODO confirm H and W should be multiples of 32
                so the output matches the input size (true for 128x128).
            extract_feature: if True, return the ``encoder4`` output (the
                compressed representation) without running the decoder.

        Returns:
            The ``encoder4`` feature map when ``extract_feature`` is True,
            otherwise the ``num_classes``-channel map from ``finalconv``.
        """
        x = self.firstconv(image)
        x = self.encoder1(x)
        x = self.encoder2(x)
        x = self.encoder3(x)
        x = self.encoder4(x)

        # Return the compressed (encoded) representation.
        if extract_feature:
            return x

        x = self.center(x)
        x = self.decoder1(x)
        x = self.decoder2(x)
        x = self.decoder3(x)
        x = self.decoder4(x)

        output = self.finalconv(x)

        return output

In [5]:
if __name__ == "__main__":
    # Smoke test. The __name__ guard skips it when this notebook is imported as
    # a module (e.g. via import_ipynb) rather than executed directly.
    from torchsummary import summary

    # torchsummary.summary only accepts device="cuda" or "cpu" and defaults to
    # "cuda"; passing the bare device type keeps it working on CPU-only setups
    # and when DEVICE is something like "cuda:0".
    device_type = torch.device(DEVICE).type

    net = AutoEncoder().to(DEVICE)
    # eval() + no_grad(): a shape check should not update BatchNorm running
    # statistics or build an autograd graph.
    net.eval()
    inp = torch.ones((1, 3, 128, 128)).to(DEVICE)
    with torch.no_grad():
        out = net(inp)
    print(out.shape)
    assert tuple(out.shape) == (1, 1, 128, 128), out.shape

    summary(net, (3, 224, 224), device=device_type)

torch.Size([1, 1, 128, 128])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,864
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
       BasicBlock-11           [-1, 64, 56, 56]               0
           Conv2d-12           [-1, 64, 56, 56]          36,864
      BatchNorm2d-13           [-1, 64, 56, 56]             128
          