# User Defined Network

## Includes

In [None]:
# imports: math, PyTorch, and the project's BasicModule base class
import math as m
import torch as t
from ipynb.fs.full.module import BasicModule

## Modified U-Net

In [None]:
class downBlock(BasicModule):
    """Encoder stage: 2x2 max-pool, then two (3x3 conv + LeakyReLU) pairs.

    The max-pool halves height and width (flooring odd sizes); both
    convolutions use padding 1, so they keep the pooled spatial size.

    Args:
        in_channels: channel count of the incoming feature map.
        out_channels: channel count produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Keep this order (pool, conv, act, conv, act): it fixes both the
        # parameter-initialisation sequence and the ``features.<i>`` keys
        # used by saved checkpoints.
        layers = [t.nn.MaxPool2d(2, 2)]
        for conv_in in (in_channels, out_channels):
            layers.append(t.nn.Conv2d(conv_in, out_channels, 3, padding=1))
            layers.append(t.nn.LeakyReLU(0.2, inplace=True))
        self.features = t.nn.Sequential(*layers)

    def forward(self, x):
        """Apply the pooled double-convolution stack to ``x``."""
        encoded = self.features(x)
        return encoded


class upBlock(BasicModule):
    """Decoder stage: two (3x3 conv + LeakyReLU) pairs, then a 2x upsample.

    The convolutions run at ``2 * out_channels`` width. The closing
    ConvTranspose2d (kernel 2, stride 2) reduces to ``out_channels`` and
    doubles height and width.

    Args:
        in_channels: channel count of the incoming (usually concatenated)
            feature map.
        out_channels: channel count of the upsampled output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        width = 2 * out_channels
        # Order (conv, act, conv, act, upconv) fixes the ``features.<i>``
        # checkpoint keys and the parameter-initialisation sequence.
        layers = []
        for conv_in in (in_channels, width):
            layers += [
                t.nn.Conv2d(conv_in, width, 3, padding=1),
                t.nn.LeakyReLU(0.2, inplace=True),
            ]
        layers.append(t.nn.ConvTranspose2d(width, out_channels, 2, stride=2))
        self.features = t.nn.Sequential(*layers)

    def forward(self, x):
        """Apply the double convolution and the 2x transposed-conv upsample."""
        decoded = self.features(x)
        return decoded


class UNet(BasicModule):
    """U-Net mapping a 4-channel input to a 3-channel output at 2x resolution.

    Layout (channel counts in parentheses, resolution relative to input H x W):

    * ``head``: two 3x3 convs, 4 -> 32, at H x W.
    * ``d1``-``d3``: encoder stages 32 -> 64 -> 128 -> 256, each halving H, W.
    * ``bottom``: pool to H/16, two 3x3 convs at 512, transposed conv back to
      256 channels at H/8.
    * ``u1``-``u3``: decoder stages. Each consumes the concatenation of an
      encoder skip tensor and the previous decoder output, and doubles H, W.
    * ``final``: two 3x3 convs, a 1x1 conv to 12 channels, then
      ``PixelShuffle(2)``, which rearranges 12 = 3 * 2**2 channels into
      3 channels at 2H x 2W.

    The input is max-pooled four times and upsampled back, so its height and
    width must be divisible by 16. Otherwise the skip concatenations in
    ``forward`` fail on mismatched spatial sizes.
    """

    def __init__(self):
        super().__init__()
        self.model_name = 'UNet'

        # Head: two 3x3 convs at full resolution, 4 -> 32 channels.
        self.head = t.nn.Sequential(
            t.nn.Conv2d(4, 32, 3, padding=1),
            t.nn.LeakyReLU(0.2, inplace=True),
            t.nn.Conv2d(32, 32, 3, padding=1),
            t.nn.LeakyReLU(0.2, inplace=True))

        # Encoder stages; each halves H and W.
        self.d1 = downBlock(32, 64)
        self.d2 = downBlock(64, 128)
        self.d3 = downBlock(128, 256)

        # Bottleneck: pool to H/16, convolve at 512 channels, then upsample
        # back to H/8 with 256 channels so it matches out_d3 for concatenation.
        self.bottom = t.nn.Sequential(
            t.nn.MaxPool2d(2, 2),
            t.nn.Conv2d(256, 512, 3, padding=1),
            t.nn.LeakyReLU(0.2, inplace=True),
            t.nn.Conv2d(512, 512, 3, padding=1),
            t.nn.LeakyReLU(0.2, inplace=True),
            t.nn.ConvTranspose2d(512, 256, 2, stride=2))

        # Decoder stages; in_channels = skip channels + upsampled channels.
        self.u1 = upBlock(512, 128)  # 256 (out_d3) + 256 (bottom)
        self.u2 = upBlock(256, 64)   # 128 (out_d2) + 128 (u1)
        self.u3 = upBlock(128, 32)   # 64 (out_d1) + 64 (u2)

        # Final: 64 = 32 (head) + 32 (u3). The 1x1 conv emits 12 channels,
        # which PixelShuffle(2) turns into 3 channels at twice the size.
        self.final = t.nn.Sequential(
            t.nn.Conv2d(64, 32, 3, padding=1),
            t.nn.LeakyReLU(0.2, inplace=True),
            t.nn.Conv2d(32, 32, 3, padding=1),
            t.nn.LeakyReLU(0.2, inplace=True),
            t.nn.Conv2d(32, 12, 1),
            t.nn.PixelShuffle(2))

        # Replace PyTorch's default conv initialisation.
        self.initLayers()

    def forward(self, x):
        """Run the U-Net.

        Args:
            x: tensor of shape (N, 4, H, W); channels are concatenated on
                dim 1, so a batch dimension is required. H and W must be
                divisible by 16.

        Returns:
            Tensor of shape (N, 3, 2H, 2W).
        """
        out_head = self.head(x)
        out_d1 = self.d1(out_head)
        out_d2 = self.d2(out_d1)
        out_d3 = self.d3(out_d2)
        out_bottom = self.bottom(out_d3)
        # Each decoder stage sees [encoder skip, upsampled features].
        out_u1 = self.u1(t.cat([out_d3, out_bottom], dim=1))
        out_u2 = self.u2(t.cat([out_d2, out_u1], dim=1))
        out_u3 = self.u3(t.cat([out_d1, out_u2], dim=1))
        out_final = self.final(t.cat([out_head, out_u3], dim=1))

        return out_final

    def initLayers(self):
        """Re-initialise every Conv2d / ConvTranspose2d in the network.

        Weights are drawn from N(0, sqrt(2 / fan)) with
        ``fan = kernel_h * kernel_w * out_channels``, a He/Kaiming-style
        normal init that uses the module's ``out_channels`` as the fan.
        Biases, when present, are set to zero. All other modules keep their
        defaults.
        """
        # Neither class subclasses the other, so both must be listed.
        conv_types = (t.nn.Conv2d, t.nn.ConvTranspose2d)
        for module in self.modules():
            if not isinstance(module, conv_types):
                continue
            kernel_h, kernel_w = module.kernel_size
            fan = kernel_h * kernel_w * module.out_channels
            # nn.init runs under no_grad, replacing direct ``.data`` mutation.
            t.nn.init.normal_(module.weight, 0.0, m.sqrt(2. / fan))
            # Convs built with bias=False have no bias tensor to reset.
            if module.bias is not None:
                t.nn.init.zeros_(module.bias)