In [None]:
class SRResNet(nn.Module):
    """SRResNet-style super-resolution network.

    The network applies a 9x9 stem convolution, a stack of ``ResidualBlock``
    modules, three 3x3 convolutions that narrow the features from 64 to 32
    channels, a skip connection from the stem, and finally a single
    pixel-shuffle upsampling stage followed by a 9x9 output convolution.

    Every convolution uses stride 1 with "same" padding, so the only change
    in spatial size comes from ``nn.PixelShuffle(scale_factor)``.

    Args:
        num_channels: Number of channels of the input and output images.
        scale_factor: Upscaling factor passed to ``nn.PixelShuffle``.
        num_residual_blocks: Number of ``ResidualBlock`` modules in the trunk.

    Raises:
        ValueError: If ``scale_factor`` is smaller than 1.
    """

    def __init__(self, num_channels=3, scale_factor=2, num_residual_blocks=64):
        super(SRResNet, self).__init__()
        if scale_factor < 1:
            raise ValueError(
                "scale_factor must be >= 1, got {!r}".format(scale_factor)
            )

        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, stride=1, padding=4)
        self.activation = nn.ReLU(inplace=True)
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(64) for _ in range(num_residual_blocks)]
        )
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.bn = nn.BatchNorm2d(32)

        # PixelShuffle(r) turns C * r**2 channels into C channels, and the
        # output convolution below expects exactly 64 input channels. The
        # pre-shuffle width therefore has to scale with r**2; a fixed 256
        # only works for r == 2 (for which this expression is still 256).
        shuffle_channels = 64 * scale_factor ** 2
        self.upscale = nn.Sequential(
            nn.Conv2d(32, shuffle_channels, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(scale_factor),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, num_channels, kernel_size=9, stride=1, padding=4),
        )

    def forward(self, x):
        """Upscale a batch of images.

        Args:
            x: Input tensor with ``num_channels`` channels, as required by
                the stem ``nn.Conv2d``.

        Returns:
            Tensor with ``num_channels`` channels whose height and width are
            ``scale_factor`` times those of ``x``.
        """
        out = self.conv1(x)
        out = self.activation(out)
        # Keep the stem features for the long skip connection.
        residual = out

        out = self.residual_blocks(out)
        out = self.conv2(out)
        out = self.activation(out)
        out = self.conv3(out)
        out = self.activation(out)
        out = self.conv4(out)
        out = self.bn(out)

        # The skip path has 64 channels but the trunk output has 32, so it is
        # projected with conv3 -- the same module (and weights) already used
        # in the trunk above.
        # NOTE(review): confirm this weight sharing is intended; a dedicated
        # projection layer would change the state_dict layout.
        residual = self.conv3(residual)

        out += residual
        out = self.upscale(out)
        return out

class ResidualBlock(nn.Module):
    """Residual unit: conv -> BN -> ReLU -> conv -> BN, plus the input.

    Both convolutions keep the channel count at ``num_channels`` and use a
    3x3 kernel with stride 1 and padding 1, so the output has the same shape
    as the input and can be summed with it directly.

    Args:
        num_channels: Number of input (and output) feature channels.
    """

    def __init__(self, num_channels):
        super().__init__()
        # Settings shared by both convolutions of the block.
        conv_cfg = dict(kernel_size=3, stride=1, padding=1)

        self.conv1 = nn.Conv2d(num_channels, num_channels, **conv_cfg)
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.activation = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(num_channels, num_channels, **conv_cfg)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, x):
        """Return ``x`` plus the output of the two conv/BN stages."""
        # First stage, with the non-linearity.
        branch = self.activation(self.bn1(self.conv1(x)))
        # Second stage has no activation before the skip is added.
        branch = self.bn2(self.conv2(branch))
        # In-place add onto the branch output; ``x`` itself is not modified.
        branch += x
        return branch

class PerceptualLoss(nn.Module):
    """Mean-squared error between VGG19 feature maps of two inputs.

    The feature extractor is built from the first 36 child modules of a
    pretrained torchvision VGG19 ``features`` stack. It is switched to eval
    mode at construction and all of its parameters are frozen.

    NOTE(review): no input normalisation is applied here -- presumably the
    caller supplies images in the form the pretrained VGG19 expects; confirm.
    """

    def __init__(self):
        super().__init__()
        backbone = models.vgg19(pretrained=True)

        # Truncate the convolutional stack to its first 36 layers.
        kept_layers = list(backbone.features.children())[:36]
        extractor = nn.Sequential(*kept_layers).eval()

        # The extractor is a fixed feature function; it must not be trained.
        for weight in extractor.parameters():
            weight.requires_grad = False

        self.features = extractor
        self.loss = nn.MSELoss()

    def forward(self, input, target):
        """Return the MSE between the extracted features of both tensors.

        Args:
            input: Tensor fed through the feature extractor (e.g. the
                generated image batch).
            target: Tensor fed through the same extractor and used as the
                reference.

        Returns:
            The value produced by ``nn.MSELoss`` on the two feature tensors.
        """
        return self.loss(self.features(input), self.features(target))