# 双线性差值

# U-Net 代码

In [44]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [53]:
class UNetConvBlock(nn.Module):
    def __init__(self, in_chans, out_chans, padding, batch_norm):
        super(UNetConvBlock, self).__init__()
        block = [] # 新建空的列表， 要往里面填东西
        
        block.append(nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=int(padding)))
        block.append(nn.ReLU())
        if batch_norm:
            block.append(nn.BatchNorm2d(out_chans))
            pass
        self.block = nn.Sequential(*block)
        pass
    
    def forward(self, x):
        out = self.block(x)
        return out
    
    pass

# 写一个上采样的block
class UNetUpBlock(nn.Module):
    def __init__(self, in_chans, out_chans, up_mode, padding, batch_norm):
        super(UNetUpBlock, self).__init__()
        if up_mode == 'upconv':
            self.up = nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2)
        elif up_mode == 'upsample':
            self.up = nn.Sequential(nn.Upsample(mode='bilinear', scale_factor=2), 
                                    nn.Conv2d(in_chans, out_chans, kernel_size=1))
        self.conv_block = UNetConvBlock(in_chans, out_chans, padding, batch_norm)
        pass
    
    # 定义对feature map的裁剪函数
    def center_crop(self, layer, target_size):
        _, _, layer_height, layer_width = layer.size()
        diff_y = (layer_height - target_size[0])//2
        diff_x = (layer_width - target_size[1])//2
        return layer[:, :, diff_y:(diff_y + target_size[0]), diff_x:(diff_x + target_size[1])]
        pass
    
    def forward(self, x, bridge):
        up = self.up(x)
        crop1 = self.center_crop(bridge, up.shape[2:])
        out = torch.cat([up, crop1, 1])
        out = self.conv_block(out)
        return out
        pass
    
    pass

In [64]:
class UNet(nn.Module):
    def __init__(self, 
                 in_channels = 1,
                 n_classes = 2, # 最终由多少个分类 
                 depth = 5, # 网络的深度
                 wf = 6, # 第一层的个数， 2的wf次方， 正课中给出的数列是64
                 padding = False,
                 batch_norm = False, 
                 up_mode = 'upconv'
                ):
        super(Unet, self).__init__()
        assert up_mode in ('upconv', 'upsample')
        self.padding = padding
        self.depth = depth
        prev_channels = in_channels
        self.down_path = nn.ModuleList()
        for i in range(depth):
            self.down_path.append(UNetConvBlock(prev_channels, 2**(wf+1), padding, batch_norm))
            prev_channels = 2**(wf+i)
            
        self.up_path  = nn.ModuleList()
        for i in reversed(range(depth-1)):
            self.up_path.append(UNetConvBlock(prev_channels, 2**(wf+1), up_mode, padding, batch_norm))
            prev_channels = 2*(wf+i)
        self.list = nn.Conv2d(prev_channels, n_classes, kernel_size=1)
        pass
    
    def forward(self, x):
        blocks = []
        for i, down in enumerate(self.down_path):
            x = down(x)
            if i != len(self.down_path) - 1:
                blocks.append(x)
                x = F.max_pool2d(x, 2)
        for i, up in enumerate(self.up_path):
            x = up(x, block[-i -1])
        return self.last(x)
    
    pass

In [65]:
x = torch.randn((1, 1, 572, 572)) # batch_size是1，通道是1， 输入572 x 572
unet = UNet()
unet.eval()
y_unet = unet(x)

TypeError: super(type, obj): obj must be an instance or subtype of type