# User Defined Network

## Includes

In [None]:
# mass includes
import math as m
import torch as t
from collections import OrderedDict
from ipynb.fs.full.module import BasicModule

## Modules

In [None]:
class resDense(BasicModule):
    """Residual dense block.

    Runs ``layers`` bias-free 3x3 convolutions in a densely connected chain:
    every conv sees the block input concatenated with the (ReLU-activated)
    outputs of all previous convs, and emits ``growth_rate`` new channels.
    A 1x1 bias-free conv then fuses the whole stack back to ``channels``
    channels, and the block input is added as a residual.

    Args:
        channels: number of input (and output) channels.
        layers: number of densely connected 3x3 convolutions.
        growth_rate: channels produced by each 3x3 convolution.
    """

    def __init__(self, channels, layers, growth_rate):
        super().__init__()

        # dense chain: conv k receives channels + k * growth_rate inputs
        self.features = t.nn.ModuleList([
            t.nn.Conv2d(channels + k * growth_rate,
                        growth_rate,
                        3,
                        padding=1,
                        bias=False) for k in range(layers)
        ])
        self.relu = t.nn.ReLU()

        # 1x1 fusion of the full concatenated stack back to `channels`
        total_channels = channels + layers * growth_rate
        self.fusion = t.nn.Conv2d(total_channels, channels, 1, bias=False)

    def forward(self, x):
        dense = x
        for conv in self.features:
            # append this conv's activated output to the running stack
            dense = t.cat([dense, self.relu(conv(dense))], dim=1)

        return x + self.fusion(dense)


class channelAtt(BasicModule):
    """Squeeze-and-excitation style channel attention.

    Each channel of the input is global-average-pooled to a single value,
    the pooled vector goes through a bottleneck MLP
    (Linear -> LeakyReLU(0.2) -> Linear -> Sigmoid), and every input channel
    is rescaled by its resulting weight in (0, 1).

    Args:
        channels: number of channels of the incoming feature map.
        reduction: bottleneck ratio; the hidden width is
            ``max(1, channels // reduction)``. Defaults to 16, which gives the
            same layer sizes as before for every ``channels >= 16``.
    """

    def __init__(self, channels, reduction=16):
        super(channelAtt, self).__init__()

        # Hidden width of the bottleneck. Clamped to at least 1 so that layers
        # with fewer than `reduction` channels do not build a zero-width Linear.
        hidden = max(1, channels // reduction)

        # squeeze-excitation layer
        self.glb_pool = t.nn.AdaptiveAvgPool2d((1, 1))
        self.squeeze_excite = t.nn.Sequential(
            t.nn.Linear(channels, hidden), t.nn.LeakyReLU(0.2),
            t.nn.Linear(hidden, channels), t.nn.Sigmoid())

    def forward(self, x):
        # (N, C, 1, 1) -> (N, C). flatten(1) keeps the batch dimension even
        # when N == 1, unlike squeeze(), which drops every size-1 dim.
        scale = self.squeeze_excite(self.glb_pool(x).flatten(1))

        # broadcast one weight per (sample, channel) over the spatial dims
        return scale.view(x.size(0), x.size(1), 1, 1) * x


class encode(BasicModule):
    """Encoder stage: optional 2x2 max-pool, then two 3x3 convs with attention.

    Layer order is: [MaxPool2d(2)] -> Conv3x3 -> LeakyReLU(0.2) -> Conv3x3
    -> channelAtt -> LeakyReLU(0.2). The pool is included only when
    ``max_pool`` is true; it halves the spatial size.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by both convolutions.
        max_pool: whether to downsample with a 2x2 max-pool first.
    """

    def __init__(self, in_channels, out_channels, max_pool=True):
        super().__init__()

        # conv body shared by both variants
        layers = [
            t.nn.Conv2d(in_channels, out_channels, 3, padding=1),
            t.nn.LeakyReLU(0.2),
            t.nn.Conv2d(out_channels, out_channels, 3, padding=1),
            channelAtt(out_channels),
            t.nn.LeakyReLU(0.2),
        ]
        # the pool goes in front, so Sequential indices match the pooled layout
        if max_pool:
            layers.insert(0, t.nn.MaxPool2d((2, 2)))

        self.features = t.nn.Sequential(*layers)

    def forward(self, x):
        return self.features(x)


class skipConn(BasicModule):
    """Skip-connection projection: [AvgPool2d(2)] -> Conv1x1 -> channelAtt -> Tanh.

    The output is squashed to (-1, 1) by the final Tanh. The 2x2 average pool
    is included only when ``avg_pool`` is true; it halves the spatial size.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the 1x1 projection.
        avg_pool: whether to downsample with a 2x2 average pool first.
    """

    def __init__(self, in_channels, out_channels, avg_pool=True):
        super().__init__()

        layers = [
            t.nn.Conv2d(in_channels, out_channels, 1),
            channelAtt(out_channels),
            t.nn.Tanh(),
        ]
        # the pool goes in front, so Sequential indices match the pooled layout
        if avg_pool:
            layers.insert(0, t.nn.AvgPool2d((2, 2)))

        self.features = t.nn.Sequential(*layers)

    def forward(self, x):
        return self.features(x)


class decode(BasicModule):
    """Decoder stage.

    Both variants start with Conv1x1 -> Conv3x3 -> LeakyReLU(0.2) at
    ``inter_channels``. Then:

    * ``up_sample=True``: nearest 2x upsample -> Conv3x3 -> channelAtt ->
      LeakyReLU(0.2), doubling the spatial size.
    * ``up_sample=False``: Conv3x3 -> LeakyReLU(0.2), with no attention and no
      resizing.

    Args:
        in_channels: channels of the input feature map.
        inter_channels: width of the intermediate convolutions.
        out_channels: channels of the stage output.
        up_sample: whether to upsample (and apply channel attention).
    """

    def __init__(self,
                 in_channels,
                 inter_channels,
                 out_channels,
                 up_sample=True):
        super().__init__()

        # prefix common to both variants
        layers = [
            t.nn.Conv2d(in_channels, inter_channels, 1),
            t.nn.Conv2d(inter_channels, inter_channels, 3, padding=1),
            t.nn.LeakyReLU(0.2),
        ]

        if up_sample:
            layers += [
                t.nn.Upsample(scale_factor=2, mode='nearest'),
                t.nn.Conv2d(inter_channels, out_channels, 3, padding=1),
                channelAtt(out_channels),
                t.nn.LeakyReLU(0.2),
            ]
        else:
            layers += [
                t.nn.Conv2d(inter_channels, out_channels, 3, padding=1),
                t.nn.LeakyReLU(0.2),
            ]

        self.features = t.nn.Sequential(*layers)

    def forward(self, x):
        return self.features(x)

## r2rNet

In [None]:
class r2rNet(BasicModule):
    """Residual-dense network that maps ``img`` and ``wb`` to a 4-channel output.

    ``img`` and ``wb`` are concatenated along dim 1 and must total 7 channels.
    After a shallow conv, ``rdbs`` residual dense blocks are chained. All of
    their outputs are fused by a 1x1 conv followed by a 3x3 conv, added back
    onto the shallow features, and projected to 4 channels.

    NOTE(review): the 7-channel split between ``img`` and ``wb`` is not visible
    here -- confirm against the caller.

    Args:
        channels: feature width throughout the network.
        rdbs: number of residual dense blocks.
        convs: convolutions per residual dense block.
        growth_rate: growth rate of each residual dense block.
    """

    def __init__(self, channels=64, rdbs=4, convs=8, growth_rate=32):
        super().__init__()
        self.model_name = 'r2rNet'
        self.rdbs = rdbs

        # feature extraction
        self.head = t.nn.Conv2d(7, channels, 3, padding=1)

        # Layout of self.features:
        #   [0]          3x3 conv
        #   [1..rdbs]    residual dense blocks
        #   [-2]         1x1 global fusion over the concatenated block outputs
        #   [-1]         3x3 conv
        blocks = [t.nn.Conv2d(channels, channels, 3, padding=1)]
        blocks += [resDense(channels, convs, growth_rate) for _ in range(rdbs)]
        blocks.append(t.nn.Conv2d(channels * rdbs, channels, 1))
        blocks.append(t.nn.Conv2d(channels, channels, 3, padding=1))
        self.features = t.nn.ModuleList(blocks)

        # final projection to 4 channels
        self.final = t.nn.Sequential(
            t.nn.Conv2d(channels, channels, 3, padding=1),
            t.nn.Conv2d(channels, 4, 1))

    def forward(self, img, wb):
        shallow = self.head(t.cat([img, wb], dim=1))

        # chain the residual dense blocks, keeping every block's output
        current = self.features[0](shallow)
        block_outputs = []
        for idx in range(1, self.rdbs + 1):
            current = self.features[idx](current)
            block_outputs.append(current)

        global_fused = self.features[-2](t.cat(block_outputs, dim=1))
        refined = self.features[-1](global_fused)

        # global residual onto the shallow features
        return self.final(shallow + refined)

## Gain estimation module

In [None]:
class gainEst(BasicModule):
    """Predict a segmentation mask and two clamped gain values.

    * Segmentation: a U-Net (three max-pooled encoder stages, a bottleneck, and
      three upsampling decoder stages with concatenated skips) runs on the
      3-channel ``struct_img``. It returns 2-channel mask logits at the input
      resolution.
    * Prediction: the sigmoid of the mask is concatenated with the 3-channel
      ``thumb_img``. Five stride-2 convs and an MLP then regress 2 values,
      which are clamped to [0, 1].

    NOTE(review): the first Linear of ``amp_out`` hard-codes 8 * 6 * 256 input
    features. This fixes the thumbnail resolution to one whose five
    stride-2 reductions give 48 spatial positions. Confirm the expected size
    with the data pipeline.

    Returns:
        (mask_logits, amp): mask logits before the sigmoid, and the clamped
        2-value prediction.
    """

    def __init__(self):
        super(gainEst, self).__init__()
        self.model_name = 'gainEst'

        # encoders
        self.head = encode(3, 64, max_pool=False)
        self.down1 = encode(64, 96, max_pool=True)
        self.down2 = encode(96, 128, max_pool=True)
        self.down3 = encode(128, 192, max_pool=True)

        # bottleneck (pool -> convs -> upsample back to the down3 resolution)
        self.bottleneck = t.nn.Sequential(
            t.nn.MaxPool2d(2, 2), t.nn.Conv2d(192, 256, 3, padding=1),
            t.nn.LeakyReLU(0.2), t.nn.Conv2d(256, 256, 3, padding=1),
            t.nn.Upsample(scale_factor=2, mode='nearest'),
            t.nn.Conv2d(256, 192, 3, padding=1), channelAtt(192),
            t.nn.LeakyReLU(0.2))

        # decoders (input width = skip channels + previous stage channels)
        self.up1 = decode(384, 384, 128, up_sample=True)
        self.up2 = decode(256, 256, 96, up_sample=True)
        self.up3 = decode(192, 192, 64, up_sample=True)
        self.seg_out = t.nn.Sequential(decode(128, 128, 64, up_sample=False),
                                       t.nn.Conv2d(64, 2, 1))

        # external activation applied to the mask logits
        self.sigmoid = t.nn.Sigmoid()

        # prediction: 3 thumbnail channels + 2 mask channels -> 5 inputs
        self.features = t.nn.Sequential(
            t.nn.Conv2d(5, 64, 3, stride=2, padding=1), t.nn.LeakyReLU(0.2),
            t.nn.Conv2d(64, 96, 3, stride=2, padding=1), t.nn.LeakyReLU(0.2),
            t.nn.Conv2d(96, 128, 3, stride=2, padding=1), t.nn.LeakyReLU(0.2),
            t.nn.Conv2d(128, 192, 3, stride=2, padding=1), t.nn.LeakyReLU(0.2),
            t.nn.Conv2d(192, 256, 3, stride=2, padding=1), t.nn.LeakyReLU(0.2))
        self.amp_out = t.nn.Sequential(t.nn.Linear(8 * 6 * 256,
                                                   128), t.nn.LeakyReLU(0.2),
                                       t.nn.Linear(128, 64),
                                       t.nn.LeakyReLU(0.2), t.nn.Linear(64, 2))

        # initialization
        self.initLayers()

    def forward(self, thumb_img, struct_img):
        # encoder path: skips = [head, down1, down2, down3] outputs
        skips = [self.head(struct_img)]
        for stage in (self.down1, self.down2, self.down3):
            skips.append(stage(skips[-1]))

        # decoder path: pair the deepest skip with the first decoder, and so on
        decoded = self.bottleneck(skips[-1])
        decoders = (self.up1, self.up2, self.up3, self.seg_out)
        for skip, stage in zip(reversed(skips), decoders):
            decoded = stage(t.cat([skip, decoded], dim=1))
        mask_logits = decoded

        # prediction from the thumbnail gated by the soft mask
        stacked = t.cat([thumb_img, self.sigmoid(mask_logits)], dim=1)
        pred_features = self.features(stacked)
        amp = self.amp_out(pred_features.view(pred_features.size(0), -1))

        return mask_logits, amp.clamp(0.0, 1.0)

## Raw processing module

In [None]:
class ispNet(BasicModule):
    """Encoder/decoder that renders a 3-channel image from a colour/magnitude split.

    ``color_map``, the 1-channel ``mag_map`` and ``wb`` are concatenated along
    dim 1 and must total 8 channels for ``self.head``.
    NOTE(review): the split between ``color_map`` and ``wb`` channels is not
    visible here -- confirm against the caller.

    Spatial behaviour, from the layer layout: the encoder downsamples twice
    (so H and W should be divisible by 4 for the skip concatenations to line
    up), and ``srgb_out`` ends with a 2x nearest upsample. The output is
    therefore twice the input height and width, with 3 channels clamped to
    [0, 1].
    """

    def __init__(self):
        super(ispNet, self).__init__()

        # encoders
        self.head = encode(8, 64, max_pool=False)
        self.down1 = encode(64, 64, max_pool=True)
        self.down2 = encode(64, 64, max_pool=True)

        # skip connections (skip1 takes the 1-channel magnitude map directly)
        self.skip1 = skipConn(1, 64, avg_pool=False)
        self.skip2 = skipConn(64, 64, avg_pool=True)
        self.skip3 = skipConn(64, 64, avg_pool=True)

        # decoders
        self.up1 = decode(128, 64, 64, up_sample=True)
        self.up2 = decode(128, 64, 64, up_sample=True)
        self.srgb_out = t.nn.Sequential(
            decode(128, 64, 64, up_sample=False),
            t.nn.Upsample(scale_factor=2, mode='nearest'),
            t.nn.Conv2d(64, 3, 3, padding=1))

    def forward(self, color_map, mag_map, amp, wb):
        # To prevent saturation, the magnitude is compressed as
        #     tanh(amp * m - 0.5) / tanh(2 * amp - 0.5)
        # with one `amp` value per batch element (broadcast via view).
        # NOTE(review): the denominator is 0 at amp == 0.25 and negative below
        # it. If amp can reach that range (gainEst clamps to [0, 1]), this
        # divides by ~0 or flips the sign -- confirm the caller's amp range.
        # t.tanh replaces the deprecated t.nn.functional.tanh alias; the
        # result is numerically identical.
        gain = amp.view(-1, 1, 1, 1)
        mag_map = t.tanh(gain * mag_map - 0.5)
        max_mag = t.tanh(2.0 * gain - 0.5)
        mag_map = mag_map / max_mag

        # encoder outputs
        out_head = self.head(t.cat([color_map, mag_map, wb], dim=1))
        out_d1 = self.down1(out_head)
        out_d2 = self.down2(out_d1)

        # skip connection outputs, at the resolutions of mag_map, out_d1 and out_d2
        out_s1 = self.skip1(mag_map)
        out_s2 = self.skip2(out_head)
        out_s3 = self.skip3(out_d1)

        # decoder outputs
        out_u1 = self.up1(t.cat([out_s3, out_d2], dim=1))
        out_u2 = self.up2(t.cat([out_s2, out_u1], dim=1))
        out_srgb = self.srgb_out(t.cat([out_s1, out_u2], dim=1))
        out_srgb = t.clamp(out_srgb, 0.0, 1.0)

        return out_srgb


class rawProcess(BasicModule):
    """Render ``raw_data`` at two gains with a shared ISP net, then fuse them.

    ``raw_data`` is split into a per-pixel magnitude (the L2 norm over dim 1)
    and a normalised colour map. Both go through the same ``ispNet`` once with
    ``amp_high`` and once with ``amp_low``. The two 3-channel renders are
    concatenated and fused by a small conv head into a third image.

    Returns:
        (out_high, out_low, out_fused): three 3-channel images, each clamped
        to [0, 1].
    """

    def __init__(self):
        super().__init__()
        self.model_name = 'rawProcess'

        # single ISP module shared by both renders
        self.isp_net = ispNet()

        # fusion of the two 3-channel renders (6 input channels)
        self.fusion = t.nn.Sequential(t.nn.Conv2d(6, 128, 3, padding=1),
                                      channelAtt(128),
                                      t.nn.Conv2d(128, 3, 3, padding=1))

        # initialization
        self.initLayers()

    def forward(self, raw_data, amp_high, amp_low, wb):
        # split into a magnitude map and a colour map; 1e-4 guards against
        # division by zero at all-zero pixels
        mag_map = raw_data.pow(2).sum(dim=1, keepdim=True).sqrt()
        color_map = raw_data / (mag_map + 1e-4)

        # render once per gain (high first, then low)
        renders = [
            self.isp_net(color_map, mag_map, gain, wb)
            for gain in (amp_high, amp_low)
        ]

        out_fused = self.fusion(t.cat(renders, dim=1)).clamp(0.0, 1.0)
        out_high, out_low = renders

        return out_high, out_low, out_fused