In [1]:
import torch
import torchvision
import torchvision.transforms as transforms

In [2]:
# sub-parts of the U-Net model

import torch
import torch.nn as nn
import torch.nn.functional as F

# It is worth thinking carefully about why so many separate classes are defined here.

# Implements the horizontal (same-level) convolutions on the left side of the U.
class double_conv(nn.Module):
    """Two stacked (Conv2d => BatchNorm2d => ReLU) stages.

    Both convolutions use a 3x3 kernel with padding=1, so height and width
    are preserved; only the channel count changes (in_ch -> out_ch -> out_ch).

    Args:
        in_ch: number of channels of the input tensor.
        out_ch: number of channels produced by each of the two convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        # NOTE: explanatory notes between the layers must be `#` comments.
        # A string literal placed inside the argument list is parsed as an
        # argument, and nn.Sequential only accepts nn.Module instances.
        self.conv = nn.Sequential(
            # torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1,
            #                 padding=0, dilation=1, groups=1, bias=True)
            # Arguments used here: input channels / output channels / kernel size.
            nn.Conv2d(in_ch, out_ch, 3, padding=1),

            # Batch normalisation over the 4D input formed by a mini-batch of
            # 3D samples. During training the layer computes the mean and
            # variance of each batch and keeps a moving average of them.
            nn.BatchNorm2d(out_ch),

            # inplace=True does not change the computed result. Computing in
            # place saves (GPU) memory and avoids repeatedly allocating and
            # freeing buffers, but it overwrites the input variable, so use it
            # only as long as that causes no errors.
            nn.ReLU(inplace=True),

            nn.Conv2d(out_ch, out_ch, 3, padding=1),

            # Every convolution is followed by a batch normalisation ...
            nn.BatchNorm2d(out_ch),

            # ... and the ReLU is applied after the normalisation.
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply both conv/BN/ReLU stages to `x` and return the result."""
        return self.conv(x)

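# Implements the convolutions of the first row on the left side. The first row does not start with pooling, so it gets its own class; the class that includes pooling is defined separately below.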
class inconv(nn.Module):
    """Input stage of the network: a single double_conv with no pooling.

    Args:
        in_ch: number of channels of the input tensor.
        out_ch: number of channels produced by the stage.
    """

    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        # double_conv is the two-convolution block defined earlier in this file.
        first_stage = double_conv(in_ch, out_ch)
        self.conv = first_stage

    def forward(self, x):
        """Run `x` through the wrapped double_conv and return its output."""
        return self.conv(x)

# Implements the downward pooling step on the left side, followed by the convolutions of the next level.
class down(nn.Module):
    """Encoder stage: 2x2 max pooling followed by a double_conv.

    Args:
        in_ch: number of channels of the input tensor.
        out_ch: number of channels produced by the double_conv.
    """

    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        halve = nn.MaxPool2d(2)                 # pool with a 2x2 window
        convolve = double_conv(in_ch, out_ch)   # convolutions of the next level
        self.mpconv = nn.Sequential(halve, convolve)

    def forward(self, x):
        """Pool `x`, apply the double_conv, and return the result."""
        return self.mpconv(x)


class up(nn.Module):
    """Decoder stage: upsample, pad to the skip tensor's size, concatenate, convolve.

    Args:
        in_ch: channel count passed to the double_conv, i.e. the channel count
            of the concatenated (skip + upsampled) tensor.
        out_ch: number of channels produced by the double_conv.
        bilinear: if true (the default), upsample with a fixed bilinear
            nn.Upsample of scale factor 2; otherwise use a learned
            nn.ConvTranspose2d with in_ch // 2 input and output channels.
    """

    def __init__(self, in_ch, out_ch, bilinear=True):
        super(up, self).__init__()
        #  would be a nice idea if the upsampling could be learned too,
        #  but my machine do not have enough memory to handle all those weights
        if not bilinear:
            # Learned upsampling through a transposed convolution. Per edge:
            #   output = (input - 1) * stride - 2 * padding + kernel_size + output_padding
            # torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
            #                          stride=1, padding=0, output_padding=0,
            #                          groups=1, bias=True, dilation=1)
            #   in_channels    - channels of the input signal
            #   out_channels   - channels produced by the convolution
            #   kernel_size    - size of the convolution kernel
            #   stride         - convolution stride, i.e. the enlargement factor
            #   padding        - zero-padding applied to each edge of the input
            #   output_padding - extra size added to the output edge
            #   groups         - blocked connections from input to output channels
            #   bias           - adds a learnable bias when True
            #   dilation       - spacing between kernel elements
            self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)
        else:
            # torch.nn.Upsample(size=None, scale_factor=None, mode='nearest',
            #                   align_corners=None)
            #   size          - explicit output size
            #   scale_factor  - multiplier applied to the input size
            #   mode          - 'nearest' (default), 'linear', 'bilinear',
            #                   'bicubic' or 'trilinear'
            #   align_corners - when True the corner pixels of input and output
            #                   are aligned, preserving their values; only valid
            #                   for the interpolating modes
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x1, x2):
        """Upsample `x1`, pad it to `x2`'s height/width, and fuse the two.

        Args:
            x1: tensor coming from the deeper level (to be upsampled).
            x2: skip-connection tensor from the encoder at this level.

        Returns:
            Output of the double_conv applied to cat([x2, padded x1], dim=1).
        """
        upsampled = self.up(x1)

        # Tensors are indexed as (N, C, H, W): dim 2 is height, dim 3 is width.
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)

        # F.pad takes (left, right, top, bottom) for the last two dimensions;
        # any odd remainder goes to the right / bottom side.
        left = gap_w // 2
        top = gap_h // 2
        upsampled = F.pad(upsampled, (left, gap_w - left, top, gap_h - top))

        # for padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd

        fused = torch.cat([x2, upsampled], dim=1)
        return self.conv(fused)


class outconv(nn.Module):
    """Output head: a single 1x1 convolution mapping in_ch channels to out_ch.

    Args:
        in_ch: number of channels of the input tensor.
        out_ch: number of output channels.
    """

    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        # Kernel size 1: each output pixel depends only on the same input pixel.
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        """Apply the 1x1 convolution to `x` and return the result."""
        return self.conv(x)


In [4]:
# full assembly of the sub-parts to form the complete net

import torch.nn.functional as F

#from .unet_parts import *

class UNet(nn.Module):
    """U-Net assembled from the inconv / down / up / outconv building blocks.

    The encoder runs one input stage and four `down` stages
    (64 -> 128 -> 256 -> 512 -> 512 channels). The decoder runs four `up`
    stages, each fed the previous decoder output together with the encoder
    output of the matching level, and a final `outconv` maps the 64 remaining
    channels to `n_classes`. A sigmoid is applied to the result.

    Args:
        n_channels: number of channels of the input image.
        n_classes: number of output channels.
    """

    def __init__(self, n_channels, n_classes):
        super(UNet, self).__init__()
        self.inc = inconv(n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)
        # The first argument of each `up` is the channel count after the
        # upsampled tensor is concatenated with its skip tensor
        # (e.g. 512 from down4 + 512 from down3 = 1024).
        self.up1 = up(1024, 256)
        self.up2 = up(512, 128)
        self.up3 = up(256, 64)
        self.up4 = up(128, 64)
        self.outc = outconv(64, n_classes)

    def forward(self, x):
        """Run the encoder/decoder on `x` and return sigmoid-activated outputs."""
        # Encoder: keep every level's output for the skip connections.
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        # Decoder: each stage consumes the previous output plus one skip tensor.
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        # torch.sigmoid computes the same values as F.sigmoid but does not
        # trigger the "nn.functional.sigmoid is deprecated" warning that older
        # PyTorch releases emit.
        return torch.sigmoid(x)


In [10]:
from torchsummary import summary


# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build a U-Net with 3 input channels and 3 output channels and move it to that device.
model = UNet(3,3).to(device)

# Print a per-layer table (layer type, output shape, parameter count) for a 3 x 224 x 224 input.
summary(model, (3, 224, 224))



----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           1,792
       BatchNorm2d-2         [-1, 64, 224, 224]             128
              ReLU-3         [-1, 64, 224, 224]               0
            Conv2d-4         [-1, 64, 224, 224]          36,928
       BatchNorm2d-5         [-1, 64, 224, 224]             128
              ReLU-6         [-1, 64, 224, 224]               0
       double_conv-7         [-1, 64, 224, 224]               0
            inconv-8         [-1, 64, 224, 224]               0
         MaxPool2d-9         [-1, 64, 112, 112]               0
           Conv2d-10        [-1, 128, 112, 112]          73,856
      BatchNorm2d-11        [-1, 128, 112, 112]             256
             ReLU-12        [-1, 128, 112, 112]               0
           Conv2d-13        [-1, 128, 112, 112]         147,584
      BatchNorm2d-14        [-1, 128, 1