In [29]:
# 构建UNet网络

<img src="unet.png" width="1200">

In [30]:
import torch

In [31]:
# Double-convolution building block
class ConvBlock(torch.nn.Module):
    """Two stacked 3x3 convolutions, each followed by a ReLU.

    Both convolutions use kernel_size=3, padding=1 and stride=1, so the
    spatial size (H, W) of the input is preserved; only the channel count
    changes, and only in the first convolution.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Settings shared by both convolutions (keep H and W unchanged).
        conv_kwargs = dict(kernel_size=3, padding=1, stride=1)

        layers = [
            # 1st convolution: in_channels -> out_channels
            torch.nn.Conv2d(in_channels, out_channels, **conv_kwargs),
            torch.nn.ReLU(),
            # 2nd convolution: out_channels -> out_channels
            torch.nn.Conv2d(out_channels, out_channels, **conv_kwargs),
            torch.nn.ReLU(),
        ]
        # Sub-module indices stay 0..3, so state_dict keys are
        # step.0.* (first conv) and step.2.* (second conv).
        self.step = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Run x through conv -> ReLU -> conv -> ReLU and return the result."""
        features = self.step(x)
        return features

In [32]:
from torchsummary import summary

In [33]:
# Instantiate the block: use the first GPU when CUDA is available,
# otherwise fall back to the CPU instead of raising on GPU-less machines.
conv_block = ConvBlock(1, 64).to('cuda:0' if torch.cuda.is_available() else 'cpu')

In [34]:
# Inspect per-layer output shapes and parameter counts for a 1x256x256 input
summary(conv_block, input_size=(1, 256, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 256, 256]             640
              ReLU-2         [-1, 64, 256, 256]               0
            Conv2d-3         [-1, 64, 256, 256]          36,928
              ReLU-4         [-1, 64, 256, 256]               0
Total params: 37,568
Trainable params: 37,568
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.25
Forward/backward pass size (MB): 128.00
Params size (MB): 0.14
Estimated Total Size (MB): 128.39
----------------------------------------------------------------


In [35]:
# Network architecture
# (downsampling: max pooling; upsampling: bilinear interpolation;
#  feature fusion: channel-wise concatenation with the encoder skip features)
class UNet(torch.nn.Module):
    """Three-level U-Net built from ConvBlock stages.

    The encoder applies four ConvBlocks (64, 128, 256, 512 channels) with a
    2x2 max pooling between consecutive blocks. The decoder upsamples
    bilinearly by a factor of 2, concatenates the result with the encoder
    feature map of the same level along the channel axis, and applies a
    ConvBlock. A final 1x1 convolution followed by a sigmoid produces the
    output, so every output value lies in [0, 1].

    Args:
        in_channels: number of channels of the input image (default 1).
        out_channels: number of channels of the output map (default 1).

    Input is expected as (N, in_channels, H, W). Because of the three
    pooling steps, H and W must each be at least 8. They do not need to be
    multiples of 8: when a pooled size was odd, the upsampled map is
    zero-padded to the size of its skip feature before concatenation.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        # Encoder (left half of the "U")
        self.layer1 = ConvBlock(in_channels, 64)
        self.layer2 = ConvBlock(64, 128)
        self.layer3 = ConvBlock(128, 256)
        self.layer4 = ConvBlock(256, 512)

        # Decoder (right half); input channels = skip channels + upsampled channels
        self.layer5 = ConvBlock(256 + 512, 256)
        self.layer6 = ConvBlock(128 + 256, 128)
        self.layer7 = ConvBlock(64 + 128, 64)

        # Final 1x1 convolution mapping 64 channels to the output channels
        self.layer8 = torch.nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=1, stride=1, padding=0)

        # Pooling: halves H and W (odd sizes are floored)
        self.Maxpool = torch.nn.MaxPool2d(kernel_size=2)
        # Upsampling; scale_factor is the magnification ratio
        self.UpSample = torch.nn.Upsample(scale_factor=2, mode='bilinear')

        # Sigmoid applied to the final convolution output
        self.sigmoid = torch.nn.Sigmoid()

    def _upsample_and_concat(self, x, skip):
        """Upsample x by 2 and concatenate it with skip along the channel axis."""
        x = self.UpSample(x)
        # Pooling floors odd sizes, so the upsampled map can be one pixel
        # short of the skip feature in H and/or W; zero-pad to match.
        pad_h = skip.shape[-2] - x.shape[-2]
        pad_w = skip.shape[-1] - x.shape[-1]
        if pad_h or pad_w:
            # F.pad order: (left, right, top, bottom)
            x = torch.nn.functional.pad(
                x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
            )
        # Upsampled features first, then the skip features (channel axis)
        return torch.cat([x, skip], dim=1)

    def forward(self, x):
        """Compute the network output for x of shape (N, in_channels, H, W).

        Returns a tensor of shape (N, out_channels, H, W) with values in [0, 1].
        Shape comments below use the default 1x256x256 input as an example.
        """
        # ---- Encoder (downsampling path) ----
        # input: (1*256*256) output: (64*256*256)
        x1 = self.layer1(x)
        # pooling, input: (64*256*256) output: (64*128*128)
        x1_mp = self.Maxpool(x1)

        # input: (64*128*128) output: (128*128*128)
        x2 = self.layer2(x1_mp)
        # input: (128*128*128) output: (128*64*64)
        x2_mp = self.Maxpool(x2)

        # input: (128*64*64) output: (256*64*64)
        x3 = self.layer3(x2_mp)
        # input: (256*64*64) output: (256*32*32)
        x3_mp = self.Maxpool(x3)

        # input: (256*32*32) output: (512*32*32)
        x4 = self.layer4(x3_mp)

        # ---- Decoder (upsampling path) ----
        # upsample (512*32*32) -> (512*64*64), concat with x3 -> (768*64*64)
        x5 = self._upsample_and_concat(x4, x3)
        # input: (768*64*64) output: (256*64*64)
        x5 = self.layer5(x5)

        # upsample (256*64*64) -> (256*128*128), concat with x2 -> (384*128*128)
        x6 = self._upsample_and_concat(x5, x2)
        # input: (384*128*128) output: (128*128*128)
        x6 = self.layer6(x6)

        # upsample (128*128*128) -> (128*256*256), concat with x1 -> (192*256*256)
        x7 = self._upsample_and_concat(x6, x1)
        # input: (192*256*256) output: (64*256*256)
        x7 = self.layer7(x7)

        # final convolution, input: (64*256*256) output: (1*256*256)
        x8 = self.layer8(x7)

        # squash to [0, 1]
        x9 = self.sigmoid(x8)

        return x9

In [36]:
# Instantiate the U-Net: use the first GPU when CUDA is available,
# otherwise fall back to the CPU instead of raising on GPU-less machines.
unet = UNet().to('cuda:0' if torch.cuda.is_available() else 'cpu')

In [37]:
# Per-layer output shapes and parameter counts of the full network for a 1x256x256 input
summary(unet, input_size=(1, 256, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 256, 256]             640
              ReLU-2         [-1, 64, 256, 256]               0
            Conv2d-3         [-1, 64, 256, 256]          36,928
              ReLU-4         [-1, 64, 256, 256]               0
         ConvBlock-5         [-1, 64, 256, 256]               0
         MaxPool2d-6         [-1, 64, 128, 128]               0
            Conv2d-7        [-1, 128, 128, 128]          73,856
              ReLU-8        [-1, 128, 128, 128]               0
            Conv2d-9        [-1, 128, 128, 128]         147,584
             ReLU-10        [-1, 128, 128, 128]               0
        ConvBlock-11        [-1, 128, 128, 128]               0
        MaxPool2d-12          [-1, 128, 64, 64]               0
           Conv2d-13          [-1, 256, 64, 64]         295,168
             ReLU-14          [-1, 256,