In [13]:
import torch
from torch import nn

In [14]:
class conv(nn.Module):
    """A single 3x3 convolution (stride 1, padding 1) followed by a ReLU.

    The convolution maps ``channels`` input channels to ``channels`` output
    channels, and the padding of 1 keeps the spatial size unchanged.

    Args:
        channels: Number of input and output channels.
    """

    def __init__(self, channels):
        super(conv, self).__init__()
        # Attribute names are significant: they define the state_dict keys
        # (``conv1.weight`` / ``conv1.bias``) used when saving/loading weights.
        self.conv1 = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            padding=1,
        )
        self.relu = nn.ReLU()

    def forward(self, images):
        """Return ``relu(conv1(images))``; output has the same shape as a valid input."""
        return self.relu(self.conv1(images))

In [15]:
class VDSR(nn.Module):
    """VDSR-style residual network built from ``depth`` 3x3 conv layers.

    Layout:
      * input layer: Conv2d(in_channels -> num_features, 3x3, pad 1) + ReLU
      * ``depth - 2`` hidden ``conv`` blocks (num_features -> num_features)
      * output layer: Conv2d(num_features -> in_channels, 3x3, pad 1), no activation
    The output layer's result is added to the input (global residual), so the
    returned tensor has the same shape as ``images``.

    Args:
        in_channels: Number of image channels in and out (default 1).
        num_features: Number of feature maps in the hidden layers (default 64).
        depth: Total number of conv layers, counting the input and output
            layers (default 20, i.e. 18 hidden blocks). Must be >= 2.

    Raises:
        ValueError: If ``depth`` is less than 2.
    """

    def __init__(self, in_channels=1, num_features=64, depth=20):
        super().__init__()
        if depth < 2:
            raise ValueError(f"depth must be >= 2 (input + output layer), got {depth}")

        # input layer
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=num_features, kernel_size=3, padding=1),
            nn.ReLU()
        )

        # hidden layers: D total layers minus the input and output layers
        # (D = 20 -> 18 blocks). The Sequential is built once, after all
        # blocks have been created, and is always defined (possibly empty).
        self.trunk = nn.Sequential(*(conv(num_features) for _ in range(depth - 2)))

        # output layer: no activation, so the predicted residual may be negative
        self.conv2 = nn.Conv2d(in_channels=num_features, out_channels=in_channels, kernel_size=3, padding=1)

    def forward(self, images):
        """Run the network and add its output to ``images``.

        Args:
            images: Input batch accepted by ``nn.Conv2d`` with ``in_channels``
                channels (e.g. shape (N, C, H, W)).

        Returns:
            Tensor with the same shape as ``images``.
        """
        x = self.conv1(images)
        x = self.trunk(x)
        output = self.conv2(x)
        # Global residual connection: the conv stack predicts a correction to the input.
        output = torch.add(output, images)

        return output

In [17]:
if __name__ == '__main__':
    from torchsummary import summary

    # Fall back to CPU so the summary also runs on machines without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    net = VDSR().to(device)
    # torchsummary creates its dummy input on ``device`` (default "cuda"),
    # so it must be told the same device the model lives on.
    summary(net, (1, 256, 256), device=device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 256, 256]             640
              ReLU-2         [-1, 64, 256, 256]               0
            Conv2d-3         [-1, 64, 256, 256]          36,928
              ReLU-4         [-1, 64, 256, 256]               0
              conv-5         [-1, 64, 256, 256]               0
            Conv2d-6         [-1, 64, 256, 256]          36,928
              ReLU-7         [-1, 64, 256, 256]               0
              conv-8         [-1, 64, 256, 256]               0
            Conv2d-9         [-1, 64, 256, 256]          36,928
             ReLU-10         [-1, 64, 256, 256]               0
             conv-11         [-1, 64, 256, 256]               0
           Conv2d-12         [-1, 64, 256, 256]          36,928
             ReLU-13         [-1, 64, 256, 256]               0
             conv-14         [-1, 64, 2