In [25]:
import sys
sys.path.append("..")
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.util import MinPool
from model.RESUNet import ResBlock
from model.model import *
from model.util import cat_tensor, crop_tensor
from model.model import Unet

In [28]:
class RC(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size = 3, strides = 1, padding = 1, block_number = 2) -> None:
        super().__init__()
        d = [
               nn.Conv2d(in_channels, out_channels, kernel_size, strides, padding)
            ]
   
        for i in range(block_number - 1):
            d.append(  nn.Conv2d(out_channels, out_channels, kernel_size, strides, padding)  )
        self.body = nn.Sequential(*d)

    def forward(self, x):
        x = self.body(x)
        return x
class Encode(nn.Module):
    def __init__(self, in_channel, out_channel, block_number = 1, conv_type = "conv"):
        super().__init__( )
        if conv_type == "conv":
            self.conv = RC(in_channel, out_channel, block_number = block_number)
        else:
            self.conv = RCS(in_channel, out_channel, block_number = block_number)
        self.downsample = nn.MaxPool2d(2)
        
    def forward(self, x):
        x_conv = self.conv(x)
        x_pool = self.downsample(x_conv)
        return x_conv, x_pool

In [37]:
class EncodeBlock(nn.Module):
    def __init__(self, in_channel, block_number = [ 2, 2, 2, 2], middle_channel = [8, 16, 32, 64, 128], conv_type = "conv"):
        super().__init__( )
        self.encode_0 = Encode(in_channel, middle_channel[0], block_number[0], conv_type)
        self.encode_1 = Encode(middle_channel[0], middle_channel[1], block_number[0], conv_type)
        self.encode_2 = Encode(middle_channel[1], middle_channel[2], block_number[1], conv_type)
        self.encode_3 = Encode(middle_channel[2], middle_channel[3], block_number[2], conv_type)
        self.encode_4 = Encode(middle_channel[3], middle_channel[4], block_number[3], conv_type)
        
        
    def forward(self, x):
        x0_conv, x0_pool = self.encode_0(x)
        x1_conv, x1_pool = self.encode_1(x0_pool)
        x2_conv, x2_pool = self.encode_2(x1_pool)
        x3_conv, x3_pool = self.encode_3(x2_pool)
        x4_conv, x4_pool = self.encode_4(x3_pool)
        return x0_conv, x1_conv, x2_conv, x3_conv, x4_conv

In [38]:

class DecodeBlock(nn.Module):
    def __init__(self, middle_channel = [8, 16, 32, 64, 128]):
        super().__init__( )
        self.pre = nn.Conv2d(in_channel, middle_channel[0], 1, 1)
        self.encode_1 = Encode(middle_channel[0], middle_channel[1], block_number)
        self.encode_2 = Encode(middle_channel[1], middle_channel[2], block_number)
        self.encode_3 = Encode(middle_channel[2], middle_channel[3], block_number)
        self.encode_4 = Encode(middle_channel[3], middle_channel[4], block_number)
        
    def forward(self, x):
        x = self.pre(x)
        x1 = self.encode_1(x)
        x2 = self.encode_2(x1)
        x3 = self.encode_3(x2)
        x4 = self.encode_4(x3)
        return x1, x2, x3, x4



class Decode(nn.Module):
    def __init__(self, in_channel, out_channel, conv_type = "conv"):
        super().__init__( )
        self.deconv = DCBL( in_channel, out_channel)
        self.conv = RCS(in_channel, out_channel, conv_type)
        
    def forward(self, x, y):
        x = self.deconv(x)
        # print(x.shape, y.shape)
        concat = torch.cat([x, y], dim=1)
        x = self.conv(concat)
        return x



In [55]:
class UBlock(nn.Module):
    def __init__(self, in_channel = 1, out_channel = 16, middle_channel = [ 8, 16, 32, 64, 128 ]):
        super().__init__()
        self.encode = EncodeBlock(in_channel, block_number = [2, 2, 2, 2], middle_channel = middle_channel , conv_type = "conv")
        self.up = nn.Upsample(scale_factor = 2)
        self.decode_0 = Decode(middle_channel[1], middle_channel[0], conv_type = "conv")
        self.decode_1 = Decode(middle_channel[2], middle_channel[1], conv_type = "conv")
        self.decode_2 = Decode(middle_channel[3], middle_channel[2], conv_type = "conv")
        self.decode_3 = Decode(middle_channel[4], middle_channel[3], conv_type = "conv")
        self.final = nn.Conv2d( middle_channel[0], out_channel, 1, 1)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        x0, x1, x2, x3, x4 = self.encode(x)
        print(x0.shape, x1.shape, x2.shape, x3.shape, x4.shape)
        # x4 = self.up(x4)

        x_1 = self.decode_3(x4, x3)
        
        x_2 = self.decode_2(x_1, x2)
        
        x_3 = self.decode_1(x_2, x1)
        
        x_4 = self.decode_0(x_3, x0)
        # print(x_4.shape)
        outp = self.sigmoid(self.final(x_4))
        # print(x_4.shape)
        return  outp

In [56]:
a = UBlock(1,1)

In [57]:
d = torch.zeros((1,1,80,80))

In [58]:
a(d)

torch.Size([1, 8, 80, 80]) torch.Size([1, 16, 40, 40]) torch.Size([1, 32, 20, 20]) torch.Size([1, 64, 10, 10]) torch.Size([1, 128, 5, 5])
torch.Size([1, 8, 80, 80])


tensor([[[[0.5349, 0.4223, 0.5027,  ..., 0.5012, 0.5160, 0.5606],
          [0.4539, 0.7564, 0.6146,  ..., 0.6628, 0.5574, 0.5171],
          [0.4196, 0.4581, 0.4133,  ..., 0.5134, 0.6109, 0.5072],
          ...,
          [0.5070, 0.5513, 0.5393,  ..., 0.6484, 0.4538, 0.5778],
          [0.5582, 0.4084, 0.6523,  ..., 0.5177, 0.6396, 0.5299],
          [0.4728, 0.5026, 0.5293,  ..., 0.4657, 0.4785, 0.5060]]]],
       grad_fn=<SigmoidBackward0>)

In [34]:
class Unet(nn.Module):
    def __init__(self, need_return_dict = True):
        super(Unet, self).__init__()
        self.need_return_dict = need_return_dict
        self.layer1_conv = double_conv2d_bn(1, 8)
        self.layer2_conv = double_conv2d_bn(8, 16)
        self.layer3_conv = double_conv2d_bn(16, 32)
        self.layer4_conv = double_conv2d_bn(32, 64)
        self.layer5_conv = double_conv2d_bn(64, 128)
        self.layer6_conv = double_conv2d_bn(128, 64)
        self.layer7_conv = double_conv2d_bn(64, 32)
        self.layer8_conv = double_conv2d_bn(32, 16)
        self.layer9_conv = double_conv2d_bn(16, 8)
        
        self.layer10_conv = nn.Conv2d(8, 1, kernel_size = 3,
                                      stride = 1, padding = 1, bias=True)

        self.deconv1 = deconv2d_bn(128, 64)
        self.deconv2 = deconv2d_bn(64, 32)
        self.deconv3 = deconv2d_bn(32, 16)
        self.deconv4 = deconv2d_bn(16, 8)

        self.sigmoid = nn.Sigmoid()
        self.erode = MinPool(2,2,1)
        self.dilate = nn.MaxPool2d(2, stride = 1)

    def build_result(self, x, y):
        return {
            "mask": x,
            "edge": y,
        }

    def forward(self, x):
        # print(x.shape)
        conv1 = self.layer1_conv(x)
        pool1 = F.max_pool2d(conv1, 2)

        conv2 = self.layer2_conv(pool1)
        pool2 = F.max_pool2d(conv2, 2)

        conv3 = self.layer3_conv(pool2)
        pool3 = F.max_pool2d(conv3, 2)

        conv4 = self.layer4_conv(pool3)
        pool4 = F.max_pool2d(conv4, 2)

        conv5 = self.layer5_conv(pool4)
        print(conv1.shape, conv2.shape, conv3.shape,  conv4.shape,   conv5.shape)
        convt1 = self.deconv1(conv5)
        concat1 = torch.cat([convt1, conv4], dim=1)
        conv6 = self.layer6_conv(concat1)

        convt2 = self.deconv2(conv6)
        concat2 = torch.cat([convt2, conv3], dim=1)
        conv7 = self.layer7_conv(concat2)

        convt3 = self.deconv3(conv7)
        concat3 = torch.cat([convt3, conv2], dim=1)
        conv8 = self.layer8_conv(concat3)

        convt4 = self.deconv4(conv8)
        concat4 = torch.cat([convt4, conv1], dim=1)
        conv9 = self.layer9_conv(concat4)
        outp = self.layer10_conv(conv9)
        outp = self.sigmoid(outp)
        edge = nn.functional.pad(outp, (1, 0, 1, 0))
        edge = self.dilate(edge) - self.erode(edge)
        return  self.build_result(outp, edge) if self.need_return_dict else (outp, edge) 

In [35]:
m = Unet(False)

In [36]:
m(torch.zeros(1,1,80,80))

torch.Size([1, 8, 80, 80]) torch.Size([1, 16, 40, 40]) torch.Size([1, 32, 20, 20]) torch.Size([1, 64, 10, 10]) torch.Size([1, 128, 5, 5])


(tensor([[[[0.4911, 0.4911, 0.4911,  ..., 0.4911, 0.4911, 0.4911],
           [0.4911, 0.4911, 0.4911,  ..., 0.4911, 0.4911, 0.4911],
           [0.4911, 0.4911, 0.4911,  ..., 0.4911, 0.4911, 0.4911],
           ...,
           [0.4911, 0.4911, 0.4911,  ..., 0.4911, 0.4911, 0.4911],
           [0.4911, 0.4911, 0.4911,  ..., 0.4911, 0.4911, 0.4911],
           [0.4911, 0.4911, 0.4911,  ..., 0.4911, 0.4911, 0.4911]]]],
        grad_fn=<SigmoidBackward0>),
 tensor([[[[0.4911, 0.4911, 0.4911,  ..., 0.4911, 0.4911, 0.4911],
           [0.4911, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.4911, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           ...,
           [0.4911, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.4911, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.4911, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]]],
        grad_fn=<SubBackward0>))