# User Defined Network

## Includes

In [None]:
# imports shared by the cells below
import torch as t
from ipynb.fs.full.module import BasicModule
from torch.nn.functional import grid_sample

## Enhancer

In [None]:
def guidedSlice(grid, guide):
    """Slice a bilateral grid at full resolution with a per-pixel guide.

    Every output pixel reads ``grid`` by trilinear interpolation
    (``grid_sample`` with ``mode='bilinear'`` on a 5-D input). The sampling
    location is built from the pixel's guide value and its normalized
    row/column coordinates.

    Args:
        grid: 5-D tensor ``(N, C, D, H, W)`` to sample from.
        guide: 4-D tensor ``(N, 1, hei, wid)``. It must have exactly one
            channel so that each sampling location has the 3 components
            that 5-D ``grid_sample`` requires. Its values are used directly
            as normalized coordinates, so they are expected in [-1, 1]
            (``align_corners=True``: -1 and 1 hit the first and last cells).

    Returns:
        Tensor ``(N, C, hei, wid)``.

    NOTE(review): ``grid_sample`` reads the three grid components as
    ``(x, y, z)`` -> ``(W, H, D)``. The components built here are
    ``(guide, row, column)``. The guide therefore indexes the grid's W axis
    and the pixel column indexes its D axis. Bilateral-grid slicing usually
    passes ``(column, row, guide)``. The order is kept because trained
    weights depend on it; confirm whether it is intentional.
    """
    batch, _, hei, wid = guide.size()

    def _norm_coords(size):
        # Map pixel indices 0..size-1 onto [-1, 1]. max() guards a 1-pixel
        # axis, which previously divided by zero and produced NaN; it now
        # maps to -1. The coordinates are cast to the guide's dtype so the
        # concatenated sampling grid does not get promoted away from it.
        idx = t.arange(0, size, device=guide.device)
        return (idx / max(size - 1, 1) * 2 - 1).to(guide.dtype)

    rows = _norm_coords(hei).view(1, hei, 1, 1).expand(batch, hei, wid, 1)
    cols = _norm_coords(wid).view(1, 1, wid, 1).expand(batch, hei, wid, 1)

    # (N, 1, hei, wid) -> (N, hei, wid, 1), then stack the 3 coordinates and
    # add a depth-1 output axis: (N, 1, hei, wid, 3)
    guide_nhwc = guide.permute(0, 2, 3, 1)
    guide_grid = t.cat([guide_nhwc, rows, cols], dim=3).unsqueeze(1)

    # 3-D slicing; the result is (N, C, 1, hei, wid)
    samples = grid_sample(grid,
                          guide_grid,
                          mode='bilinear',
                          align_corners=True)

    return samples.squeeze(2)


class dpConv2d(BasicModule):
    """Depthwise-separable 3x3 convolution.

    The block applies three layers in order:

    * a per-channel (``groups=in_chns``) 3x3 convolution without bias;
    * a ReLU;
    * a 1x1 pointwise convolution from ``in_chns`` to ``out_chns`` channels.

    Args:
        in_chns: number of input channels.
        out_chns: number of output channels.
        stride: stride of the depthwise 3x3 convolution.
        padding: padding of the depthwise 3x3 convolution.
    """

    def __init__(self, in_chns, out_chns, stride, padding):
        super(dpConv2d, self).__init__()
        depthwise = t.nn.Conv2d(in_chns, in_chns, 3,
                                stride=stride, padding=padding,
                                groups=in_chns, bias=False)
        activation = t.nn.ReLU(inplace=True)
        pointwise = t.nn.Conv2d(in_chns, out_chns, 1)
        # kept as one Sequential named `feats` so state_dict keys are stable
        self.feats = t.nn.Sequential(depthwise, activation, pointwise)

    def forward(self, x):
        """Apply depthwise conv -> ReLU -> pointwise conv to ``x``."""
        return self.feats(x)


class invResidual(BasicModule):
    """Inverted residual block: 1x1 expand -> ReLU -> depthwise-separable conv.

    The input is expanded to ``6 * in_chns`` channels with a 1x1 convolution.
    A ``dpConv2d`` then applies the (possibly strided) depthwise 3x3 conv and
    projects to ``out_chns``. The input is added back (identity shortcut)
    only when the output shape can match it: ``stride == 1`` and
    ``in_chns == out_chns``.

    Args:
        in_chns: number of input channels.
        out_chns: number of output channels.
        stride: stride of the depthwise stage.
    """

    def __init__(self, in_chns, out_chns, stride):
        super(invResidual, self).__init__()

        hidden = 6 * in_chns  # expansion factor of 6
        self.feats = t.nn.Sequential(
            t.nn.Conv2d(in_chns, hidden, 1),
            t.nn.ReLU(inplace=True),
            dpConv2d(hidden, out_chns, stride, 1),
        )
        self.add_res = bool(stride == 1 and in_chns == out_chns)

    def forward(self, x):
        """Return ``feats(x)``, plus ``x`` when the shortcut is enabled."""
        out = self.feats(x)
        return x + out if self.add_res else out


class bottleNeck(BasicModule):
    """A stage of ``repeats`` inverted residual blocks.

    Only the first block changes the channel count (``in_chns`` ->
    ``out_chns``) and applies ``stride``. Every later block maps
    ``out_chns`` -> ``out_chns`` with stride 1.

    Args:
        in_chns: channels entering the stage.
        out_chns: channels produced by every block.
        stride: stride of the first block.
        repeats: number of blocks in the stage.
    """

    def __init__(self, in_chns, out_chns, stride, repeats):
        super(bottleNeck, self).__init__()

        blocks = [
            invResidual(in_chns if i == 0 else out_chns,
                        out_chns,
                        stride if i == 0 else 1)
            for i in range(repeats)
        ]
        self.feats = t.nn.Sequential(*blocks)

    def forward(self, x):
        """Run ``x`` through the blocks in order."""
        return self.feats(x)


class Enhancer(BasicModule):
    """Predict illumination and color coefficients for an image.

    A shared encoder (stem conv, ``dpConv2d``, five ``bottleNeck`` stages;
    overall stride 16) runs on ``down_img``. Two heads read its features.

    * Illumination: a local conv branch is gated channel-wise by a global
      sigmoid branch. The result is mapped to a bilateral grid of shape
      ``(N, ilm_iter, z_res, 16, 16)``. ``guidedSlice`` samples that grid at
      full resolution, guided by a one-channel tanh map computed from
      ``full_img``.
    * Color: a strided conv + MLP branch regresses ``clr_pcc`` values per
      image.

    NOTE(review): the hard-coded ``16, 16`` grid view and the
    ``Linear(384, ...)`` layers (384 = 96 * 2 * 2) together imply that
    ``down_img`` must be about 256x256. Confirm against the data pipeline.

    Args:
        pretrain: selects ``model_name``. ``True`` (compared with
            ``== True``) gives 'Enhancer_final'; anything else gives
            'Enhancer_pretrain'. NOTE(review): this mapping reads inverted
            relative to the flag name; confirm before relying on it.
        z_res: z-axis resolution of the bilateral grid.
        ilm_iter: number of illumination iterations. Each iteration gets one
            coefficient map in ``ilm_coes``.
        clr_pcc: number of polynomial color coefficients.
    """

    @staticmethod
    def _strided_mlp(out_features, squash):
        """Build the global head: 3 stride-2 convs on the 160-channel
        features, flatten, then a 384 -> 256 -> 128 -> ``out_features`` MLP.
        A trailing sigmoid is appended when ``squash`` is true.

        Layers are created in the same order as the original inline
        definition, so weight initialization and state_dict indices are
        unchanged.
        """
        layers = []
        in_chns = 160
        for _ in range(3):
            layers.append(t.nn.Conv2d(in_chns, 96, 3, stride=2, padding=1))
            layers.append(t.nn.ReLU(inplace=True))
            in_chns = 96
        layers += [
            t.nn.Flatten(),
            t.nn.Linear(384, 256), t.nn.ReLU(inplace=True),
            t.nn.Linear(256, 128), t.nn.ReLU(inplace=True),
            t.nn.Linear(128, out_features),
        ]
        if squash:
            layers.append(t.nn.Sigmoid())
        return t.nn.Sequential(*layers)

    def __init__(self, pretrain=True, z_res=16, ilm_iter=9, clr_pcc=60):
        super(Enhancer, self).__init__()
        self.model_name = ('Enhancer_final' if pretrain == True
                           else 'Enhancer_pretrain')

        # model settings
        self.z_res = z_res
        self.ilm_iter = ilm_iter
        self.clr_pcc = clr_pcc

        # shared encoder; stages are (in_chns, out_chns, stride, repeats)
        stages = [(16, 24, 2, 2), (24, 32, 2, 3), (32, 64, 2, 4),
                  (64, 96, 1, 3), (96, 160, 2, 3)]
        self.feats = t.nn.Sequential(
            t.nn.Conv2d(3, 32, 3, padding=1),
            t.nn.ReLU(inplace=True),
            dpConv2d(32, 16, 1, 1),
            *[bottleNeck(*cfg) for cfg in stages])

        # illumination: local branch (three 3x3 convs, 160 -> 96 channels)
        local_layers = [t.nn.Conv2d(160, 96, 3, padding=1)]
        for _ in range(2):
            local_layers += [t.nn.ReLU(inplace=True),
                             t.nn.Conv2d(96, 96, 3, padding=1)]
        self.ilm_loc = t.nn.Sequential(*local_layers)

        # illumination: global branch, one sigmoid gate per channel (96)
        self.ilm_glb = self._strided_mlp(96, squash=True)

        # illumination: gated features -> z_res * ilm_iter grid channels
        self.ilm_out = t.nn.Sequential(
            t.nn.ReLU(inplace=True),
            t.nn.Conv2d(96, self.z_res * self.ilm_iter, 1))

        # illumination: single-channel guide map in (-1, 1) from full_img
        self.ilm_map = t.nn.Sequential(
            t.nn.Conv2d(3, 32, 3, padding=1), t.nn.ReLU(inplace=True),
            t.nn.Conv2d(32, 32, 3, padding=1), t.nn.ReLU(inplace=True),
            t.nn.Conv2d(32, 1, 3, padding=1), t.nn.Tanh())

        # color: global regression of clr_pcc coefficients
        self.clr_out = self._strided_mlp(self.clr_pcc, squash=False)

    def forward(self, down_img, full_img):
        """Compute illumination and color coefficients.

        Args:
            down_img: 3-channel image fed to the encoder (see the size note
                on the class).
            full_img: 3-channel image from which the guide map is computed.
                Its spatial size sets the spatial size of ``ilm_coes``.

        Returns:
            ``(ilm_coes, clr_coes)``. ``ilm_coes`` is
            ``(N, ilm_iter, H, W)`` at ``full_img`` resolution and
            ``clr_coes`` is ``(N, clr_pcc)``.
        """
        shared = self.feats(down_img)

        # illumination branch: local features scaled by per-channel gates
        local = self.ilm_loc(shared)
        gate = self.ilm_glb(shared).view(-1, 96, 1, 1)
        grid = self.ilm_out(local * gate).view(
            -1, self.ilm_iter, self.z_res, 16, 16)
        guide = self.ilm_map(full_img)

        # bilateral upsampling of the grid to full resolution
        ilm_coes = guidedSlice(grid, guide)

        # color branch
        clr_coes = self.clr_out(shared)
        return ilm_coes, clr_coes

## Denoiser

In [None]:
class sptAtt(BasicModule):
    """Spatial attention: reweight every spatial position of the input.

    The forward pass works in four steps:

    1. Pool over channels at each pixel, taking both the max and the mean.
    2. Stack the two maps into a 2-channel map.
    3. Apply a 3x3 conv + sigmoid to get one weight per position.
    4. Multiply the input by that weight (broadcast over channels).
    """

    def __init__(self):
        super(sptAtt, self).__init__()
        # 2 pooled maps -> 1 weight map in (0, 1)
        self.weights = t.nn.Sequential(
            t.nn.Conv2d(2, 1, 3, padding=1),
            t.nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled by a learned per-position weight."""
        max_map = x.max(dim=1, keepdim=True).values
        mean_map = x.mean(dim=1, keepdim=True)
        attention = self.weights(t.cat((max_map, mean_map), dim=1))
        return x * attention


class chnAtt(BasicModule):
    """Channel attention (squeeze-and-excitation style).

    The forward pass computes one weight per channel and scales the input:

    1. Global-average-pool ``x`` to ``(N, n_feats, 1, 1)``.
    2. Squeeze to ``n_feats // reduction`` channels with a 1x1 conv, then
       apply a ReLU.
    3. Expand back to ``n_feats`` with a 1x1 conv and a sigmoid.
    4. Multiply ``x`` by the resulting per-channel weights.

    Args:
        n_feats: number of input/output channels.
        reduction: channel reduction ratio of the bottleneck.
    """

    def __init__(self, n_feats, reduction=16):
        super(chnAtt, self).__init__()

        # Bottleneck width, clamped to at least 1. With n_feats < reduction
        # the plain floor division gives 0 channels. The second conv then
        # outputs only its bias, and the weights no longer depend on x.
        intp_feats = max(1, n_feats // reduction)
        self.weights = t.nn.Sequential(t.nn.AdaptiveAvgPool2d(1),
                                       t.nn.Conv2d(n_feats, intp_feats, 1),
                                       t.nn.ReLU(inplace=True),
                                       t.nn.Conv2d(intp_feats, n_feats, 1),
                                       t.nn.Sigmoid())

    def forward(self, x):
        """Return ``x`` scaled by learned per-channel weights in (0, 1)."""
        weight = self.weights(x)

        return x * weight


class dualAttBlk(BasicModule):
    """Dual-attention residual block.

    The forward pass runs in three stages:

    1. Two 3x3 convs, with a PReLU between them, extract features.
    2. Spatial attention and channel attention each reweight those features.
    3. The two attended maps are concatenated, fused back to ``n_feats``
       channels by a 1x1 conv, and added to the input.

    Args:
        n_feats: number of input/output channels.
    """

    def __init__(self, n_feats):
        super(dualAttBlk, self).__init__()

        conv_a = t.nn.Conv2d(n_feats, n_feats, 3, padding=1)
        prelu = t.nn.PReLU(n_feats)
        conv_b = t.nn.Conv2d(n_feats, n_feats, 3, padding=1)
        self.feats = t.nn.Sequential(conv_a, prelu, conv_b)

        self.spt_att = sptAtt()
        self.chn_att = chnAtt(n_feats)

        # 1x1 fusion of the two attended maps (2 * n_feats -> n_feats)
        self.res = t.nn.Conv2d(2 * n_feats, n_feats, 1)

    def forward(self, x):
        """Return ``x`` plus the fused dual-attention residual."""
        body = self.feats(x)
        attended = (self.spt_att(body), self.chn_att(body))
        return x + self.res(t.cat(attended, dim=1))


class recResBlk(BasicModule):
    """Residual group with a long skip connection.

    The body is ``num_dab`` ``dualAttBlk`` blocks followed by a 3x3 conv.
    The output is ``x + body(x)``.

    Args:
        n_feats: number of input/output channels.
        num_dab: number of dual-attention blocks in the body.
    """

    def __init__(self, n_feats, num_dab):
        super(recResBlk, self).__init__()

        # body: dual-attention blocks, then a closing 3x3 conv
        modules = [dualAttBlk(n_feats) for _ in range(num_dab)]
        modules.append(t.nn.Conv2d(n_feats, n_feats, 3, padding=1))
        self.res = t.nn.Sequential(*modules)

    def forward(self, x):
        """Return ``x`` plus the body's residual."""
        return x + self.res(x)


class Denoiser(BasicModule):
    """Recurrent denoiser built from a conv head, a ConvLSTM cell, a residual
    body and a conv tail.

    The ConvLSTM hidden and cell states are kept on the instance
    (``self.state_h`` / ``self.state_c``) between ``forward`` calls. Passing
    ``img_res=None`` resets them. Presumably later calls pass the previous
    step's output as ``img_res``; confirm with the training loop.

    Args:
        cam_model: formatted into ``model_name`` as 'Denoiser_<cam_model>'.
        num_rrg: number of residual groups (``recResBlk``) in the body.
        num_dab: dual-attention blocks per residual group.
        n_feats: feature width of the head, body and LSTM gates.
    """

    def __init__(self, cam_model, num_rrg=6, num_dab=2, n_feats=64):
        super(Denoiser, self).__init__()
        self.model_name = 'Denoiser_%s' % cam_model

        # head: (in_img, noise_map, img_res) concatenated -> 12 channels
        head_modules = [t.nn.Conv2d(12, n_feats, 3, padding=1)]

        body_modules = [
            recResBlk(n_feats, num_dab=num_dab) for _ in range(num_rrg)
        ]
        body_modules.append(t.nn.Conv2d(n_feats, n_feats, 3, padding=1))
        body_modules.append(t.nn.PReLU(n_feats))

        # tail: back to 4 output channels
        tail_modules = [t.nn.Conv2d(n_feats, 4, 3, padding=1)]

        self.head = t.nn.Sequential(*head_modules)
        self.body = t.nn.Sequential(*body_modules)
        self.tail = t.nn.Sequential(*tail_modules)

        # convolutional LSTM gates; each reads cat(features, hidden state)
        self.lstm_f = t.nn.Sequential(
            t.nn.Conv2d(n_feats * 2, n_feats, 3, padding=1), t.nn.Sigmoid())
        self.lstm_i = t.nn.Sequential(
            t.nn.Conv2d(n_feats * 2, n_feats, 3, padding=1), t.nn.Sigmoid())
        self.lstm_g = t.nn.Sequential(
            t.nn.Conv2d(n_feats * 2, n_feats, 3, padding=1), t.nn.Tanh())
        self.lstm_o = t.nn.Sequential(
            t.nn.Conv2d(n_feats * 2, n_feats, 3, padding=1), t.nn.Sigmoid())

    def forward(self, in_img, noise_map, img_res=None):
        """Run one recurrent refinement step.

        Args:
            in_img: input image. Concatenated along dim 1 with ``noise_map``
                and ``img_res``; the three together must have 12 channels.
            noise_map: noise map, concatenated as described above.
            img_res: result from an earlier step, or ``None`` on the first
                step. ``None`` substitutes zeros shaped like ``in_img`` and
                resets the ConvLSTM state.

        Returns:
            4-channel output tensor clamped to [0, 1].
        """
        is_first_step = img_res is None
        if is_first_step:
            img_res = t.zeros_like(in_img)
        feat = self.head(t.cat([in_img, noise_map, img_res], dim=1))

        # Zero the ConvLSTM state on the first step. Also zero it if this
        # instance has never run a first step; previously that case raised
        # AttributeError on self.state_h.
        if is_first_step or getattr(self, 'state_h', None) is None:
            self.state_h = t.zeros_like(feat)
            self.state_c = t.zeros_like(feat)

        # ConvLSTM update
        gate_in = t.cat([feat, self.state_h], dim=1)
        self.state_c = (self.lstm_f(gate_in) * self.state_c
                        + self.lstm_i(gate_in) * self.lstm_g(gate_in))
        self.state_h = self.lstm_o(gate_in) * t.tanh(self.state_c)

        # residual body on the hidden state, then project to the output
        feat = self.state_h + self.body(self.state_h)
        out_img = self.tail(feat)

        return t.clamp(out_img, 0.0, 1.0)

## Discriminator

In [None]:
class Discriminator(BasicModule):
    """Score a pair of images.

    ``img1`` and ``img2`` are concatenated along dim 1, giving 6 channels
    (presumably 3 each). The result passes through seven stride-2 4x4
    convolutions (8 -> 512 channels, each followed by a ReLU). A
    2048 -> 256 -> 128 -> 1 MLP then produces one unbounded score per sample;
    there is no final activation.

    NOTE(review): ``Linear(2048, ...)`` = 512 * 2 * 2 is consistent with
    256x256 inputs after seven halvings. Confirm against the data pipeline.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.model_name = 'Discriminator'

        widths = [6, 8, 16, 32, 64, 128, 256, 512]
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(t.nn.Conv2d(c_in, c_out, 4, stride=2, padding=1))
            layers.append(t.nn.ReLU())
        layers.extend([
            t.nn.Flatten(),
            t.nn.Linear(2048, 256), t.nn.ReLU(),
            t.nn.Linear(256, 128), t.nn.ReLU(),
            t.nn.Linear(128, 1),
        ])
        self.feats = t.nn.Sequential(*layers)

    def forward(self, img1, img2):
        """Return a ``(N, 1)`` score for the concatenated pair."""
        pair = t.cat((img1, img2), dim=1)
        return self.feats(pair)