In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class DoubConv(nn.Module):
    # 2个3*3 Conv结构, no padding
    def __init__(self, in_channels, out_channels):
        super(DoubConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
    def forward(self, x):
        return self.double_conv(x)

# t1 = DoubConv(3, 64)
# x = torch.randn(1, 3, 572, 572)
# t1(x).shape
# (1, 64, 568, 568)

In [3]:
class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__()
        self.max_pool = nn.MaxPool2d(kernel_size=2)
        self.double_conv = DoubConv(in_channels, out_channels)
        
    def forward(self, x):
        return self.double_conv(self.max_pool(x))


In [4]:
class Up(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Up, self).__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        # feacher map翻倍，需要kernel=2， stride=2
        # out = (in + 2* padding -kernel)/stride + 1
        # in = (out -1)* stride + kernel - 2*padding
        # in = (out - 1) * 2 + 2 = 2 * out
        self.double_conv = DoubConv(in_channels, out_channels)
        
    # 根据up的size对bridge进行裁剪
    def center_crop(self, up, bridge):
        _,_,up_h,up_w = up.size()
        _,_,bridge_h,bridge_w = up.size()
        diff_h = bridge_h - up_h
        diff_w = bridge_w - up_w
        return bridge[:,:, diff_h//2:diff_h//2+up_h, diff_w //2:diff_w+up_w]
    
    def forward(self, x, bridge):
        up = self.up(x)
        crop = self.center_crop(up, bridge)
        print('up', up.shape)
        print('crop', crop.shape)
        # cat操作，增加channel，所以dim为1
        out = torch.cat([up, crop], 1)
        return self.double_conv(out)
# x = torch.randn(1, 3, 572, 572)
# m1 = DoubConv(3, 64)
# Down1 = Down(64, 128)
# Down2 = Down(128, 256)
# Down3 = Down(256, 512)
# Down4 = Down(512, 1024)
# up4 = Up(1024, 512)
# up3 = Up(512, 256)
# up2 = Up(256, 128)
# up1 = Up(128, 64)
# d1 = m1(x)
# d2 = Down1(d1)
# d3 = Down2(d2)
# d4 = Down3(d3)
# x = Down4(d4)
# u4 = up4(x, d4)
# u3 = up3(u4, d3)
# u2 = up2(u3, d2)
# u1 = up1(u2, d1)

In [13]:
class Unet(nn.Module):
    def __init__(
        self,
        in_channels,# 输入channels
        n_class, # 输出class
        depth, # unet深度 论文中为5
        wf, # 第一层unet的channels 论文中为64
    ):
        super(Unet, self).__init__()
        self.down_path = nn.ModuleList()
        self.up_path = nn.ModuleList()
#         self.double_conv = DoubConv(in_channels, wf)
        self.out_conv = nn.Conv2d(wf, n_class, kernel_size=1)
        self.down_path.append(DoubConv(in_channels, wf))
        for i in range(depth-1):
            self.down_path.append(Down(wf*2**i, wf*2**(i+1)))
        for i in reversed(range(depth-1)):
            self.up_path.append(Up(wf*2**(i+1), wf*2**i))
    
    def forward(self, x):
        bridges = []
#         x = self.double_conv(x)
        down = x
        for i in range(len(self.down_path)):
#             print(self.down_path[i])
            down = self.down_path[i](down)
            print(down.shape)
            if i != len(self.down_path) -1:
                bridges.append(down)
        up = down
        for bridge in bridges:
            print(bridge.shape)
        for i in range(len(self.up_path)):
            print(self.up_path[i])
            up = self.up_path[i](up, bridges[-i])
            print(up.shape)
        print(up)
    
unet = Unet(3, 2, 5, 64)
x = torch.randn(1, 3, 572, 572)
unet(x)

torch.Size([1, 64, 568, 568])
torch.Size([1, 128, 280, 280])
torch.Size([1, 256, 136, 136])
torch.Size([1, 512, 64, 64])
torch.Size([1, 1024, 28, 28])
torch.Size([1, 64, 568, 568])
torch.Size([1, 128, 280, 280])
torch.Size([1, 256, 136, 136])
torch.Size([1, 512, 64, 64])
Up(
  (up): ConvTranspose2d(1024, 512, kernel_size=(2, 2), stride=(2, 2))
  (double_conv): DoubConv(
    (double_conv): Sequential(
      (0): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
      (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
    )
  )
)
up torch.Size([1, 512, 56, 56])
crop torch.Size([1, 64, 56, 56])


RuntimeError: Given groups=1, weight of size 512 1024 3 3, expected input[1, 576, 56, 56] to have 1024 channels, but got 576 channels instead

In [24]:
#双线性插值和装置卷积
# 双线性插值用来代替转置卷积做上采样
# 1. 1*1 kernel 的作用：
# （1）调整channels
# （2）增加特征提取能力
# 2. 为啥fcn的padding=100
# (i + 2p - k)/s + 1 <= 7
# s = 32 k=3    6*32 + 3 = 195 所以p取100

# 3.转置卷积恢复形状，不能恢复数值
# 4.FCn的skip是做什么的
# 特征融合
# x = torch.randn(1,2,3,4)
# torch.cat([x,x],1).shape
# torch.cat?
# torch.Size([1, 4, 3, 4])
print(DoubConv(3, 64))

DoubConv(
  (double_conv): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
  )
)


In [30]:
import cv2
import numpy as np
import matplotlib.pyplot as plt

In [34]:
def Bi(src, new_size):
    dst_h, dst_w = new_size
    src_h, src_w = src[:2]
    if des_h == src_h and des_w == src_w:
        return src.copy()
    scale_x = float(src_w) / des_w
    scale_y = float(src_h) / des_h
    dst = np.zeros((dst_h, dst_w, 3))
    for i in range(3):
        for dst_y in range(dst_h):
            for dst_x in range(dst_w):
                # 像素单位为1 每个像素（h，w）对应中心位置（h+0.5，w + 0.5）
                src_x = (des_x + 0.5) * scale_x - 0.5
                src_y = (des_y + 0.5) * scale_y - 0.5
                # 左上角的点
                src_x_0 = int(np.floor(src_x))
                src_y_0 = int(np.floor(src_y))
                #像素值边界为1，防止出界
                src_x_1 = min(src_x_0 + 1, src_w -1)
                src_x_1 = min(src_x_0 + 1, src_w -1)

SyntaxError: unexpected EOF while parsing (<ipython-input-34-4198aef6a38b>, line 11)

In [1]:
# Unet 代码实现

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [10]:

# 下采样
class UnetConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, padding, batch_norm):
        super(UnetConvBlock, self).__init__()
        block = [] #建造空列表
        block.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=int(padding)))
        block.append(nn.ReLU)
        if batch_norm:
            block.append(nn.BatchNorm2d(out_channels))
        block.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=int(padding)))
        block.append(nn.ReLU)
        if batch_norm:
            block.append(nn.BatchNorm2d(out_channels))
        self.block = nn.Sequential(*block)
    
    def forward(self, x):
        return self.block(x)
# 上采样
def UnetUpBlock(nn.Module):
    def __init__(self, in_channels, out_channels, up_mode, padding, batch_norm):
        super(UnetUpBlock, self).__init__()
        if up_mode = "upconv":
            self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        elif up_mode = "upsample":
            self.up = nn.Sequential(nn.Upsample(mode='bilinera', scale_factor=2),
                                    nn.Conv2d(in_channels, out_channels, kernel_size=1))
        self.conv_block = UnetConvBlock(in_channels, out_channels, padding, batch_norm)
    def center_crop(self, layer, target_size):
        _,_,layer_h,layer_w = layer.size()

class Unet(nn.Module):
    def __init__(
        self,
        in_channels=1, 
        n_class=2, 
        depth=5, 
        wf=6, 
        padding=False, 
        batch_norm=False, 
        up_mode='upconv'
    ):
        super(Unet, self).__init__()
        self.padding = padding
        self.depth = depth
        prev_channels = in_channels
        self.down_path = nn.ModuleList()
        for i in range(depth):
            self.down_path.append()

SyntaxError: invalid syntax (<ipython-input-10-4d57dce00d1b>, line 19)

In [None]:
e