# User Defined Network

## Includes

In [None]:
# mass includes
import torch as t
import numpy as np
from ipynb.fs.full.module import BasicModule

## Feature extraction network

In [None]:
class Encoder(BasicModule):
    """Convolutional feature extractor that also downsamples a mask.

    The image branch is a stack of 3x3 conv / batch-norm / ReLU stages; the
    first three stages are each followed by a 2x2 max-pool, so the feature
    map is reduced by a factor of 8 along each spatial axis. The mask branch
    is a single 8x8 max-pool, which brings the mask to the same resolution.
    """

    def __init__(self):
        super().__init__()
        self.model_name = 'Encoder'

        def conv3x3(c_in, c_out):
            # padding=1 keeps the spatial size; bias is omitted throughout
            return t.nn.Conv2d(c_in, c_out, 3, padding=1, bias=False)

        # channel widths of the conv/bn/relu stages: 3 -> 32 -> 64 -> 128 -> 256
        widths = (3, 32, 64, 128, 256)
        n_pooled_stages = 3

        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(conv3x3(c_in, c_out))
            layers.append(t.nn.BatchNorm2d(c_out))
            layers.append(t.nn.ReLU())
            # only the first three stages halve the resolution
            if stage < n_pooled_stages:
                layers.append(t.nn.MaxPool2d(2))

        # final projection: plain convolution, no normalisation or activation
        layers.append(conv3x3(widths[-1], widths[-1]))

        # feature extraction convnet
        self.feats = t.nn.Sequential(*layers)

        # pools the mask by 2**3 = 8, matching the three 2x2 pools above
        self.mpool = t.nn.MaxPool2d(2 ** n_pooled_stages)

    def forward(self, in_img, in_mask):
        """Return ``(features, pooled_mask)`` for an image / mask pair.

        ``in_img`` goes through the convolutional stack (3 input channels,
        256 output channels); ``in_mask`` is max-pooled by a factor of 8.
        """
        pooled_mask = self.mpool(in_mask)
        img_feats = self.feats(in_img)

        return img_feats, pooled_mask

## Feature summary network

In [None]:
class Summarizer(BasicModule):
    """Collapses a 256-channel feature map into a 128-dimensional vector.

    A 3x3 convolution widens the map to 512 channels, an 8x8 max-pool
    shrinks it, and two linear layers reduce the flattened result to 128
    values. The first linear layer takes ``512 * 2 * 2`` inputs, so the
    pooled map has to be 2x2.
    """

    def __init__(self):
        super().__init__()
        self.model_name = 'Summarizer'

        # convolutional part: widen channels, then pool spatially
        conv_stage = [
            t.nn.Conv2d(256, 512, 3, padding=1, bias=False),
            t.nn.BatchNorm2d(512),
            t.nn.ReLU(),
            t.nn.MaxPool2d(8),
        ]

        # dense part: flatten the pooled map and project down to 128
        pooled_size = 512 * 2 * 2
        dense_stage = [
            t.nn.Flatten(),
            t.nn.Linear(pooled_size, 256, bias=False),
            t.nn.BatchNorm1d(256),
            t.nn.ReLU(),
            t.nn.Linear(256, 128, bias=False),
        ]

        # feature summation
        self.feats = t.nn.Sequential(*conv_stage, *dense_stage)

    def forward(self, in_feats, in_mask):
        """Return the summary vector for ``in_feats``.

        ``in_mask`` is accepted as part of the signature but is not used
        by the computation.
        """
        return self.feats(in_feats)

## Mutual information discriminator

In [None]:
class MIDiscriminator(BasicModule):
    """Embeds local and global features into a shared ``out_chnls`` space.

    Both branches are residual: a two-layer path (``*_feats1``) is added to
    a single-layer shortcut (``*_feats2``). The local branch uses 1x1
    convolutions and is layer-normalised over the channel axis; the global
    branch uses linear layers.

    Args:
        loc_chnls: channel count of the local feature map.
        glb_chnls: size of the global feature vector.
        out_chnls: size of the shared embedding produced by both branches.
    """

    def __init__(self, loc_chnls=256, glb_chnls=128, out_chnls=2048):
        super(MIDiscriminator, self).__init__()
        self.model_name = 'Disc-mi'

        # from table 8
        self.glb_feats1 = t.nn.Sequential(
            t.nn.Linear(glb_chnls, out_chnls, bias=False),
            t.nn.BatchNorm1d(out_chnls), t.nn.ReLU(),
            t.nn.Linear(out_chnls, out_chnls, bias=False))
        self.glb_feats2 = t.nn.Sequential(
            t.nn.Linear(glb_chnls, out_chnls, bias=False), t.nn.ReLU())

        # from table 9
        self.loc_feats1 = t.nn.Sequential(
            t.nn.Conv2d(loc_chnls, out_chnls, 1, bias=False),
            t.nn.BatchNorm2d(out_chnls), t.nn.ReLU(),
            t.nn.Conv2d(out_chnls, out_chnls, 1, bias=False))
        self.loc_feats2 = t.nn.Sequential(
            t.nn.Conv2d(loc_chnls, out_chnls, 1, bias=False), t.nn.ReLU())
        self.layer_norm = t.nn.LayerNorm(out_chnls)

        # Initialise the local shortcut (loc_feats2) close to an identity
        # mapping: small uniform noise everywhere, 1.0 where output channel
        # i reads input channel i.
        # The diagonal is built with torch index tensors instead of a NumPy
        # boolean mask: `np.bool` was removed in NumPy 1.24 and made this
        # constructor raise AttributeError. The diagonal length is clamped
        # so that loc_chnls > out_chnls no longer indexes out of bounds.
        shortcut_weight = self.loc_feats2[0].weight
        diag = t.arange(min(loc_chnls, out_chnls))
        with t.no_grad():
            shortcut_weight.uniform_(-0.01, 0.01)
            shortcut_weight[diag, diag, 0, 0] = 1.0

    def forward(self, in_loc_feats, in_glb_feats, in_loc_mask):
        """Embed the local / global features and flatten the spatial axes.

        Args:
            in_loc_feats: 4-D local feature map with ``loc_chnls`` channels
                on axis 1 (it is fed to ``Conv2d`` layers).
            in_glb_feats: global features with ``glb_chnls`` values on the
                last axis (fed to ``Linear`` / ``BatchNorm1d`` layers).
            in_loc_mask: mask whose last two axes are flattened into one;
                presumably aligned with the local map's spatial grid —
                confirm against the caller.

        Returns:
            A tuple ``(out_loc_feats, out_glb_feats, out_loc_mask)``: the
            local embedding with its last two axes merged, the global
            embedding with a trailing singleton axis appended, and the mask
            with its last two axes merged.
        """
        # encode local features (residual sum of deep path and shortcut)
        out_loc_feats = (self.loc_feats1(in_loc_feats) +
                         self.loc_feats2(in_loc_feats))
        # LayerNorm normalises the last axis, so move channels there and back
        out_loc_feats = out_loc_feats.permute(0, 2, 3, 1)
        out_loc_feats = self.layer_norm(out_loc_feats)
        out_loc_feats = out_loc_feats.permute(0, 3, 1, 2)

        # encode global features (same residual structure)
        out_glb_feats = (self.glb_feats1(in_glb_feats) +
                         self.glb_feats2(in_glb_feats))

        # reshape: merge the spatial axes, add a trailing axis to the global
        out_loc_feats = t.flatten(out_loc_feats, start_dim=-2)
        out_glb_feats = out_glb_feats.unsqueeze(-1)
        out_loc_mask = t.flatten(in_loc_mask, start_dim=-2)

        return out_loc_feats, out_glb_feats, out_loc_mask

## Prior discriminator

In [None]:
class PirorDiscriminator(BasicModule):
    """Small MLP that maps a feature vector to a single value in (0, 1).

    Two hidden layers of width 512 (linear, batch-norm, ReLU) are followed
    by a one-unit linear layer and a sigmoid.

    Args:
        in_chnls: length of the flattened input feature vector.
    """

    def __init__(self, in_chnls=128):
        super().__init__()
        self.model_name = 'Disc-pr'

        hidden = 512

        # discriminator MLP; only the output layer carries a bias
        layers = [
            t.nn.Linear(in_chnls, hidden, bias=False),
            t.nn.BatchNorm1d(hidden),
            t.nn.ReLU(),
            t.nn.Linear(hidden, hidden, bias=False),
            t.nn.BatchNorm1d(hidden),
            t.nn.ReLU(),
            t.nn.Linear(hidden, 1),
            t.nn.Sigmoid(),
        ]
        self.feats = t.nn.Sequential(*layers)

    def forward(self, in_feats):
        """Flatten everything after axis 0 and return the sigmoid score."""
        flat_feats = in_feats.view(in_feats.shape[0], -1)

        return self.feats(flat_feats)