In [1]:
import torch
from torch import nn
from math import sqrt

In [2]:
class Conv_ReLU_Block(nn.Module):
    """One VDSR hidden layer: a bias-free 3x3 convolution (64 -> 64 channels)
    followed by an in-place ReLU.

    Stride 1 with padding 1 keeps the spatial size unchanged, so any number of
    these blocks can be stacked.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        # In-place is safe here: it only overwrites the freshly produced conv
        # output, never the caller's input tensor.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        activated = self.conv(x)
        return self.relu(activated)

In [3]:
class VDSR(nn.Module):
    """VDSR-style residual super-resolution network for 3-channel images.

    Layout: a 3 -> 64 input conv + ReLU, 18 ``Conv_ReLU_Block`` layers, and a
    64 -> 3 output conv. The network output is added back onto the input, so
    the convolutional stack only has to learn the residual. Every convolution
    is 3x3, stride 1, padding 1 and bias-free, so spatial size is preserved.
    """

    def __init__(self):
        super().__init__()
        # Construction order (residual_layer, input, output, relu) determines
        # the default-init RNG draws and the state_dict key order; keep it.
        self.residual_layer = self.make_layer(Conv_ReLU_Block, 18)
        self.input = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.output = nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)

        # He-style normal init based on fan-out:
        # std = sqrt(2 / (kernel_h * kernel_w * out_channels)).
        for module in self.modules():
            if not isinstance(module, nn.Conv2d):
                continue
            kh, kw = module.kernel_size
            module.weight.data.normal_(0, sqrt(2. / (kh * kw * module.out_channels)))

    def make_layer(self, block, num_of_layer):
        """Return an nn.Sequential of ``num_of_layer`` fresh ``block()`` instances."""
        return nn.Sequential(*[block() for _ in range(num_of_layer)])

    def forward(self, x):
        # x must have 3 channels (self.input expects in_channels=3); the result
        # has the same shape as x because the residual is added back at the end.
        features = self.relu(self.input(x))
        features = self.residual_layer(features)
        return torch.add(self.output(features), x)

In [4]:
class conv2(nn.Module):
    """3x3 convolution (with bias) followed by ReLU.

    Input and output have the same channel count (``channels``), and padding 1
    with the default stride of 1 keeps the spatial size unchanged.

    Args:
        channels: number of input and output channels.
    """

    def __init__(self, channels):
        super(conv2, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, images):
        return self.relu(self.conv1(images))

In [5]:
class VDSR2(nn.Module):
    """VDSR variant built from the bias-enabled ``conv2`` blocks.

    Layout: a 3 -> 64 input conv + ReLU, 18 hidden ``conv2(64)`` blocks, and a
    64 -> 3 output conv. The output is added back onto the input image, so the
    network learns the residual. All convolutions use kernel 3 / padding 1,
    so spatial size is preserved.
    """

    def __init__(self):
        super().__init__()
        # input layer: 3 -> 64 channels
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU()
        )

        # hidden layers: total depth D = 20; minus the input and output layers
        # leaves 18. `conv2` here is the module-level conv + ReLU block class
        # (not the `self.conv2` attribute assigned below). The Sequential is
        # built once, after all blocks are created.
        self.trunk = nn.Sequential(*[conv2(64) for _ in range(18)])

        # output layer: 64 -> 3 channels
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=3, padding=1)

    def forward(self, images):
        # images must have 3 channels (self.conv1 expects in_channels=3); the
        # result has the same shape because the input is added back at the end.
        x = self.conv1(images)
        x = self.trunk(x)
        output = self.conv2(x)
        output = torch.add(output, images)

        return output

In [6]:
if __name__ == '__main__':
    from torchsummary import summary

    # Fall back to CPU so this cell also runs on machines without a GPU;
    # torchsummary needs to be told which device to place its dummy input on.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    net = VDSR().to(device)
    summary(net, (3, 256, 256), device=device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 256, 256]           1,728
              ReLU-2         [-1, 64, 256, 256]               0
            Conv2d-3         [-1, 64, 256, 256]          36,864
              ReLU-4         [-1, 64, 256, 256]               0
   Conv_ReLU_Block-5         [-1, 64, 256, 256]               0
            Conv2d-6         [-1, 64, 256, 256]          36,864
              ReLU-7         [-1, 64, 256, 256]               0
   Conv_ReLU_Block-8         [-1, 64, 256, 256]               0
            Conv2d-9         [-1, 64, 256, 256]          36,864
             ReLU-10         [-1, 64, 256, 256]               0
  Conv_ReLU_Block-11         [-1, 64, 256, 256]               0
           Conv2d-12         [-1, 64, 256, 256]          36,864
             ReLU-13         [-1, 64, 256, 256]               0
  Conv_ReLU_Block-14         [-1, 64, 2