In [52]:
import math
import torch 
import torch.nn as nn
import torch.nn.init as init

In [53]:
# 建立block
class Block(nn.Module):

  def __init__(self, in_ch, out_ch):
    super(Block, self).__init__()
    self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
    self.bn1 = nn.BatchNorm2d(out_ch)
    self.relu1 = nn.ReLU(inplace=True)

  def forward(self, x):
    out = self.relu1(self.bn1(self.conv1(x)))
    return out
    
# 建立Layer
def make_layers(in_channels, layer_list):
  layers = []
  for v in layer_list:
    layers += [Block(in_channels, v)]
    in_channels = v          # 下次输入为上次输出
  return nn.Sequential(*layers)

class Layer(nn.Module):

  def __init__(self, in_channels, layer_list):
    super(Layer, self).__init__()
    self.layer = make_layers(in_channels, layer_list)
  
  def forward(self, x):
    out = self.layer(x)
    return out

In [54]:
 # VGG 19 
# [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
class VGG(nn.Module):
    def __init__(self):
      super(VGG, self).__init__()
      self.layer1 = Layer(3, [64, 64])
      self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
      self.layer2 = Layer(64, [128, 128])
      self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
      self.layer3 = Layer(128, [256, 256, 256, 256])
      self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
      self.layer4 = Layer(256, [512, 512, 512, 512])
      self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
      self.layer5 = Layer(512, [512, 512, 512, 512])
      self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
      f1 = self.pool1(self.layer1(x))
      f2 = self.pool2(self.layer2(f1))
      f3 = self.pool3(self.layer3(f2))
      f4 = self.pool4(self.layer4(f3))
      f5 = self.pool5(self.layer5(f4))
      return [f3, f4, f5]

In [55]:
 # 建立上采样模块
class MergeUpsample(nn.Module):

    def __init__(self, in_cha1, in_cha2, out_chan):
      super(MergeUpsample, self).__init__()
      self.conv11 = Block(in_chan1, out_ch)
      self.conv12 = Block(in_chan2, out_ch)
      self.conv2 = Block(out_ch, out_ch)
      self.upsample = nn.ConvTranspose2d(out_chan1, 
          out_chan2 ,2 ,stride=2) 

    def forward(self, x, y):
        p1 = self.conv11(self.upsample(x))
        p2 = self.conv12(y)
        out = self.conv2(p1+p2) 
        return out

In [56]:
class FCNDecode(nn.Module):
    def __init__(self, n, in_channels, out_channels, upsample_ratio):
        super(FCNDecode, self).__init__()
        self.conv1 = Layer(in_channels, [out_channels]*n)
        self.trans_conv1 = nn.ConvTranspose2d(
                out_channels,
                out_channels,
                upsample_ratio,
                stride=upsample_ratio)        
    def forward(self, x):
        out = self.trans_conv1(self.conv1(x))
        return out

# 建立FCN_Seg模型
class FCNSeg(nn.Module):
    def __init__(self, n, in_channels, out_channels, upsample_ratio):
        super(FCNSeg, self).__init__()
        self.encode = VGG()
        self.decode = FCNDecode(n, in_channels, out_channels, upsample_ratio)
        self.classifier = nn.Conv2d(out_channels, 10, 3, padding=1)
    def forward(self, x):
        feature_list = self.encode(x)
        out = self.decode(feature_list[-1])
        pro = self.classifier(out)
        return out


In [57]:
x = torch.randn((10, 3, 256, 256)) # batchsize, channel, h, w
model = FCNSeg(4, 512, 256, 32) # 卷积层， 输入通道， 输出通道，上采32
model.eval()
y = model(x)
a = y.size()  #torch.Size([10, 256, 256, 256])
print(a)


torch.Size([10, 256, 256, 256])


In [71]:
mm = VGG()
a = mm(x)

In [79]:

a[0]

tensor([[[[1.9787e+00, 2.5396e-01, 1.1422e+00,  ..., 4.1819e-01,
           5.2685e-01, 1.5354e+00],
          [5.8054e-01, 1.5523e+00, 1.2389e+00,  ..., 9.3665e-01,
           2.1467e+00, 2.2394e+00],
          [1.9400e+00, 1.4048e+00, 2.1371e+00,  ..., 7.1560e-01,
           8.2002e-01, 2.3327e+00],
          ...,
          [2.0689e-01, 0.0000e+00, 2.8828e-01,  ..., 1.8751e+00,
           9.2167e-01, 2.4849e+00],
          [1.1684e+00, 5.6458e-01, 9.9096e-01,  ..., 5.0725e-01,
           2.8691e+00, 1.4341e+00],
          [1.4612e+00, 5.8960e-01, 2.8911e+00,  ..., 1.5678e+00,
           1.7246e+00, 1.6685e+00]],

         [[8.1094e-01, 1.2471e+00, 1.4414e+00,  ..., 2.5748e+00,
           1.6712e+00, 1.7166e+00],
          [4.1622e-01, 1.3656e+00, 3.2727e-01,  ..., 1.4686e+00,
           1.6716e+00, 1.3965e+00],
          [0.0000e+00, 6.9457e-01, 1.7946e+00,  ..., 0.0000e+00,
           2.3336e-01, 1.9426e+00],
          ...,
          [0.0000e+00, 4.9444e-01, 1.9850e+00,  ..., 1.1831