# User-Defined Networks

## Imports

In [10]:
import torch as t
import torchvision as tv
from ipynb.fs.full.module import BasicModule
from collections import namedtuple

## Loss network (VGG16)

In [11]:
class LossNet(BasicModule):
    """Pretrained VGG16 feature extractor used to compute perceptual losses.

    Keeps the first 23 layers of torchvision's pretrained VGG16 ``features``
    stack, ending at ``relu4_3``. ``forward`` returns the activations taken
    after ``relu1_2``, ``relu2_2``, ``relu3_3`` and ``relu4_3``.

    Args:
        requires_grad: When False (the default), every parameter of the
            network has ``requires_grad`` set to False, so the VGG weights are
            not updated by an optimizer.
    """

    # Positions of the tapped ReLU layers inside ``vgg16().features``:
    # 3 -> relu1_2, 8 -> relu2_2, 15 -> relu3_3, 22 -> relu4_3.
    _TAP_INDICES = frozenset({3, 8, 15, 22})

    # Output record type, built once rather than on every forward call.
    VggOutputs = namedtuple("vggOutputs",
                            ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'])

    def __init__(self, requires_grad=False):
        # Bug fix: the original referenced the undefined name ``lossNet``.
        super().__init__()
        # Slice [:23] keeps layers 0..22, i.e. up to and including relu4_3.
        # NOTE(review): ``pretrained=True`` is deprecated in newer torchvision
        # releases in favour of ``weights=``; confirm the installed version.
        features = list(tv.models.vgg16(pretrained=True).features)[:23]
        self.features = t.nn.ModuleList(features).eval()

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Run ``x`` through the truncated VGG16 and collect tapped activations.

        Args:
            x: Image batch fed to the first VGG conv layer; presumably
                (N, 3, H, W) and normalized as VGG expects — confirm with caller.

        Returns:
            A ``vggOutputs`` namedtuple with fields ``relu1_2``, ``relu2_2``,
            ``relu3_3`` and ``relu4_3``, in that order.
        """
        results = []
        for index, layer in enumerate(self.features):
            x = layer(x)
            if index in self._TAP_INDICES:
                results.append(x)

        return self.VggOutputs(*results)

## Transform network

In [12]:
class residual2d(BasicModule):
    """Residual block with an identity skip connection.

    Computes ``x + BN(Conv(ReLU(BN(Conv(x)))))`` using 3x3 convolutions with
    padding 1 and stride 1. Both the channel count and the spatial size are
    preserved, which is what makes the identity addition valid.

    Args:
        channels: Number of input channels; also the number of output channels.
    """

    def __init__(self, channels):
        super().__init__()

        self.model = t.nn.Sequential(
            t.nn.Conv2d(
                in_channels=channels,
                out_channels=channels,
                kernel_size=3,
                padding=1),
            # BatchNorm must normalize the block's actual channel count; the
            # original hard-coded 128, which failed for any other width.
            t.nn.BatchNorm2d(num_features=channels),
            t.nn.ReLU(),
            t.nn.Conv2d(
                in_channels=channels,
                out_channels=channels,
                kernel_size=3,
                padding=1),
            t.nn.BatchNorm2d(num_features=channels))

    def forward(self, x):
        """Return ``x`` plus the residual branch applied to ``x``."""
        return self.model(x) + x

In [None]:
class TransformNet(BasicModule):
    """Image transformation network: encoder, residual stack, decoder.

    Layout (attribute names and layer order are part of the state_dict keys):
      * ``convLayers``: a 9x9 conv (3->32), then two stride-2 3x3 convs
        (32->64->128). Each conv is followed by batch-norm and ReLU. Height
        and width shrink by a factor of 4, exactly when both are divisible
        by 4.
      * ``resLayers``: five 128-channel ``residual2d`` blocks.
      * ``deconvLayers``: two stride-2 3x3 transposed convs (128->64->32),
        each with batch-norm and ReLU. These are followed by a 9x9 transposed
        conv (32->3), batch-norm and ``tanh``. Height and width grow by a
        factor of 4.

    For inputs whose height and width are divisible by 4, the output has the
    same spatial size as the input.
    """

    def __init__(self):
        # Bug fix: the original referenced the undefined name ``transformNet``.
        super().__init__()
        self.model_name = 'TransformNet'

        # Encoder: 9x9 "same" conv followed by two 2x downsampling convs.
        self.convLayers = t.nn.Sequential(
            t.nn.Conv2d(
                in_channels=3, out_channels=32, kernel_size=9, padding=4),
            t.nn.BatchNorm2d(num_features=32),
            t.nn.ReLU(),
            t.nn.Conv2d(
                in_channels=32,
                out_channels=64,
                kernel_size=3,
                stride=2,
                padding=1),
            t.nn.BatchNorm2d(num_features=64),
            t.nn.ReLU(),
            t.nn.Conv2d(
                in_channels=64,
                out_channels=128,
                kernel_size=3,
                stride=2,
                padding=1),
            t.nn.BatchNorm2d(num_features=128),
            t.nn.ReLU())

        # Five 128-channel residual blocks at the bottleneck resolution.
        self.resLayers = t.nn.Sequential(
            *(residual2d(128) for _ in range(5)))

        # Decoder. output_padding=1 makes each stride-2 transposed conv
        # exactly double the spatial size: (in - 1) * 2 - 2 + 3 + 1 = 2 * in.
        self.deconvLayers = t.nn.Sequential(
            t.nn.ConvTranspose2d(
                in_channels=128,
                out_channels=64,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1),
            t.nn.BatchNorm2d(num_features=64),
            t.nn.ReLU(),
            t.nn.ConvTranspose2d(
                in_channels=64,
                out_channels=32,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1),
            t.nn.BatchNorm2d(num_features=32),
            t.nn.ReLU(),
            t.nn.ConvTranspose2d(
                in_channels=32, out_channels=3, kernel_size=9, padding=4),
            # NOTE(review): batch-norm right before the output tanh is unusual
            # for this architecture; kept so existing checkpoints still load.
            t.nn.BatchNorm2d(num_features=3),
            t.nn.Tanh())

    def forward(self, x):
        """Transform an image batch.

        Args:
            x: 3-channel image batch, presumably (N, 3, H, W); the expected
                input value range is not fixed here — confirm with caller.

        Returns:
            A 3-channel tensor whose values lie in [0, 255]; the tanh output
            in [-1, 1] is rescaled linearly.
        """
        x = self.convLayers(x)
        x = self.resLayers(x)
        x = self.deconvLayers(x)

        # Map tanh's [-1, 1] onto the [0, 255] pixel range.
        return x * 127.5 + 127.5