### torch 정보와 library

In [1]:
!pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-0.0.8-py3-none-any.whl (16 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-0.0.8


In [2]:
import torch
import torch.nn as nn

from torchinfo import summary

![](https://github.com/ugent-korea/pytorch-unet-segmentation/raw/master/readme_images/UNet_custom_parameter.png)

### 2가지 방법의 unet

첫 번째 방법 — 참고 구현: https://github.com/ugent-korea/pytorch-unet-segmentation

In [3]:
class UNet(nn.Module):
    """U-Net with unpadded 3x3 convolutions, 1 input channel and 2 output channels.

    The contracting path has four "double conv + 2x2 max-pool" stages
    (32, 64, 128 and 256 channels) followed by a 512-channel midpoint block.
    The expanding path mirrors it with 2x2 transposed convolutions; after each
    one, the matching contracting-path feature map is center-cropped to the
    upsampled size and concatenated along the channel axis.

    No convolution is padded, so the output is spatially smaller than the
    input. Height and width are cropped independently, so non-square inputs
    are supported.
    """

    def __init__(self):
        super(UNet, self).__init__()

        # Contracting path. Attribute names and creation order are kept stable
        # so state_dict keys and seeded initialisation do not change.
        self.conv1_block = self._double_conv(1, 32)
        self.max1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2_block = self._double_conv(32, 64)
        self.max2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3_block = self._double_conv(64, 128)
        self.max3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv4_block = self._double_conv(128, 256)
        self.max4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Midpoint
        self.conv5_block = self._double_conv(256, 512)

        # Expanding path. Each transposed conv halves the channel count; the
        # following double conv takes twice that many channels because the
        # cropped skip feature map is concatenated in forward().
        self.up_1 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)
        self.conv_up_1 = self._double_conv(512, 256)

        self.up_2 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
        self.conv_up_2 = self._double_conv(256, 128)

        self.up_3 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)
        self.conv_up_3 = self._double_conv(128, 64)

        self.up_4 = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=2, stride=2)
        self.conv_up_4 = self._double_conv(64, 32)

        # Final 1x1 convolution mapping 32 feature channels to 2 output channels.
        self.conv_final = nn.Conv2d(in_channels=32, out_channels=2,
                                    kernel_size=1, padding=0, stride=1)

    @staticmethod
    def _double_conv(in_channels, out_channels):
        """Return two unpadded 3x3 convolutions, each followed by an in-place ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=3, padding=0, stride=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels,
                      kernel_size=3, padding=0, stride=1),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _crop_and_concat(x, skip):
        """Center-crop `skip` to the spatial size of `x`, then concatenate on dim 1.

        The crop window is computed separately for height and width, and its
        extent is always exactly the size of `x`, so an odd size difference no
        longer produces a crop that is one pixel too large.
        """
        target_h, target_w = x.shape[2], x.shape[3]
        top = (skip.shape[2] - target_h) // 2
        left = (skip.shape[3] - target_w) // 2
        cropped = skip[:, :, top:top + target_h, left:left + target_w]
        return torch.cat([x, cropped], dim=1)

    def forward(self, x):
        """Run the network on a 4-D tensor with 1 channel on dim 1.

        Returns a tensor with 2 channels on dim 1 and reduced spatial size.
        """
        # Contracting path: keep each stage's pre-pooling output for the
        # skip connections.
        conv1_out = self.conv1_block(x)
        x = self.max1(conv1_out)

        conv2_out = self.conv2_block(x)
        x = self.max2(conv2_out)

        conv3_out = self.conv3_block(x)
        x = self.max3(conv3_out)

        conv4_out = self.conv4_block(x)
        x = self.max4(conv4_out)

        # Midpoint
        x = self.conv5_block(x)

        # Expanding path: upsample, merge with the cropped skip, convolve.
        x = self.conv_up_1(self._crop_and_concat(self.up_1(x), conv4_out))
        x = self.conv_up_2(self._crop_and_concat(self.up_2(x), conv3_out))
        x = self.conv_up_3(self._crop_and_concat(self.up_3(x), conv2_out))
        x = self.conv_up_4(self._crop_and_concat(self.up_4(x), conv1_out))

        # Final output
        return self.conv_final(x)

In [4]:
# Instantiate the UNet defined in the preceding cell and print torchinfo's
# per-layer summary for an input of shape (1, 1, 572, 572).
model=UNet()
summary(model,(1,1,572,572))

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [1, 32, 568, 568]         --
|    └─Conv2d: 2-1                       [1, 32, 570, 570]         320
|    └─ReLU: 2-2                         [1, 32, 570, 570]         --
|    └─Conv2d: 2-3                       [1, 32, 568, 568]         9,248
|    └─ReLU: 2-4                         [1, 32, 568, 568]         --
├─MaxPool2d: 1-2                         [1, 32, 284, 284]         --
├─Sequential: 1-3                        [1, 64, 280, 280]         --
|    └─Conv2d: 2-5                       [1, 64, 282, 282]         18,496
|    └─ReLU: 2-6                         [1, 64, 282, 282]         --
|    └─Conv2d: 2-7                       [1, 64, 280, 280]         36,928
|    └─ReLU: 2-8                         [1, 64, 280, 280]         --
├─MaxPool2d: 1-4                         [1, 64, 140, 140]         --
├─Sequential: 1-5                        [1, 128, 136, 136]        --
|  

두 번째 방법 — 참고 글: https://amaarora.github.io/2020/09/13/unet.html

주의: 아래 셀은 `UNet` 클래스를 다시 정의하므로, 위에서 정의한 첫 번째 `UNet`을 덮어씁니다.

In [5]:
import torchvision

In [6]:
class Block(nn.Module):
    """Two 3x3 convolutions (default padding) with a single ReLU between them.

    Args:
        in_ch: number of channels expected by the first convolution.
        out_ch: number of channels produced by both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        kernel = 3
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel)

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2; no activation follows the second conv."""
        hidden = self.conv1(x)
        hidden = self.relu(hidden)
        return self.conv2(hidden)


class Encoder(nn.Module):
    """Stack of Blocks, each followed by a 2x2 max-pool.

    Args:
        chs: channel sizes; consecutive pairs (chs[i], chs[i+1]) define the
            input/output channels of each Block.
    """

    def __init__(self, chs=(3,64,128,256,512,1024)):
        super().__init__()
        stages = [Block(c_in, c_out) for c_in, c_out in zip(chs[:-1], chs[1:])]
        self.enc_blocks = nn.ModuleList(stages)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        """Return the list of every Block's output, taken before pooling."""
        features = []
        current = x
        for stage in self.enc_blocks:
            current = stage(current)
            features.append(current)
            # Pooling is applied after every stage, including the last one,
            # whose pooled result is not used.
            current = self.pool(current)
        return features


class Decoder(nn.Module):
    """Expanding path: transposed-conv upsampling, skip concatenation, Block.

    Args:
        chs: channel sizes; consecutive pairs (chs[i], chs[i+1]) define each
            upsampling step and the Block that follows it.
    """

    def __init__(self, chs=(1024, 512, 256, 128, 64)):
        super().__init__()
        self.chs = chs
        channel_pairs = list(zip(chs[:-1], chs[1:]))
        self.upconvs = nn.ModuleList(
            [nn.ConvTranspose2d(c_in, c_out, 2, 2) for c_in, c_out in channel_pairs]
        )
        self.dec_blocks = nn.ModuleList(
            [Block(c_in, c_out) for c_in, c_out in channel_pairs]
        )

    def forward(self, x, encoder_features):
        """Upsample `x` step by step, merging encoder_features[i] at step i."""
        n_steps = len(self.chs) - 1
        for step in range(n_steps):
            upsampled = self.upconvs[step](x)
            skip = self.crop(encoder_features[step], upsampled)
            merged = torch.cat([upsampled, skip], dim=1)
            x = self.dec_blocks[step](merged)
        return x

    def crop(self, enc_ftrs, x):
        """Center-crop `enc_ftrs` to the height and width of the 4-D tensor `x`."""
        _batch, _channels, height, width = x.shape
        cropper = torchvision.transforms.CenterCrop([height, width])
        return cropper(enc_ftrs)


class UNet(nn.Module):
    """U-Net assembled from an Encoder, a Decoder and a 1x1 convolution head.

    Args:
        enc_chs: channel sizes passed to Encoder.
        dec_chs: channel sizes passed to Decoder; dec_chs[-1] is the number of
            input channels of the head.
        num_class: number of output channels of the head.
        retain_dim: if True, the head output is resized to `out_sz`.
        out_sz: target spatial size used when `retain_dim` is True.
    """

    def __init__(self, enc_chs=(3,64,128,256,512,1024), dec_chs=(1024, 512, 256, 128, 64), num_class=1, retain_dim=False, out_sz=(572,572)):
        super().__init__()
        self.encoder     = Encoder(enc_chs)
        self.decoder     = Decoder(dec_chs)
        self.head        = nn.Conv2d(dec_chs[-1], num_class, 1)
        self.retain_dim  = retain_dim
        # Stored so forward() can use it; previously the argument was dropped
        # and forward() referenced an undefined name.
        self.out_sz      = out_sz

    def forward(self, x):
        """Encode, decode with skip connections, apply the head, optionally resize."""
        enc_ftrs = self.encoder(x)
        # Deepest feature map seeds the decoder; the rest, deepest first, are the skips.
        reversed_ftrs = enc_ftrs[::-1]
        out = self.decoder(reversed_ftrs[0], reversed_ftrs[1:])
        out = self.head(out)
        if self.retain_dim:
            # nn.functional is reachable through the existing `torch.nn as nn`
            # import, so no extra `F` alias is required.
            out = nn.functional.interpolate(out, size=self.out_sz)
        return out

In [7]:
# Instantiate the Encoder/Decoder-based UNet from the preceding cell (default
# arguments) and print torchinfo's summary for an input of shape (1, 3, 572, 572).
model=UNet()
summary(model,(1,3,572,572))

Layer (type:depth-idx)                   Output Shape              Param #
├─Encoder: 1-1                           [1, 64, 568, 568]         --
|    └─ModuleList: 2                     []                        --
|    |    └─Block: 3-1                   [1, 64, 568, 568]         38,720
|    └─MaxPool2d: 2-1                    [1, 64, 284, 284]         --
|    └─ModuleList: 2                     []                        --
|    |    └─Block: 3-2                   [1, 128, 280, 280]        221,440
|    └─MaxPool2d: 2-2                    [1, 128, 140, 140]        --
|    └─ModuleList: 2                     []                        --
|    |    └─Block: 3-3                   [1, 256, 136, 136]        885,248
|    └─MaxPool2d: 2-3                    [1, 256, 68, 68]          --
|    └─ModuleList: 2                     []                        --
|    |    └─Block: 3-4                   [1, 512, 64, 64]          3,539,968
|    └─MaxPool2d: 2-4                    [1, 512, 32, 32]       