Design fingerprint generator model architecture

In [None]:
# Colab-only setup: mount the user's Google Drive at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Project directory on the mounted Google Drive.
# NOTE(review): replace(" ", "\\") swaps each space for a single backslash, so
# the resulting value contains "My\Drive" (the space is removed, not escaped).
# If shell-escaping was intended, "\\ " was probably meant; if the path is only
# used from Python, no replacement is needed at all -- confirm against usage.
ROOT_PATH = "/content/drive/My Drive/synthetic_image_detection/rule_based/".replace(" ", "\\")

In [None]:
def concat_curr(prev, curr):
    """Pad ``curr`` to the spatial size of ``prev`` and stack them on dim 1.

    The size difference along dims 2 and 3 is split between the two sides of
    each dimension, with the odd remainder going to the trailing side. The
    padded ``curr`` is then concatenated after ``prev`` along dim 1.

    Args:
        prev: tensor indexed on at least dims 0..3; provides the target size.
        curr: tensor to be padded along dims 2 and 3.

    Returns:
        ``torch.cat`` of ``prev`` and the padded ``curr`` along dim 1.
    """
    gap_h = prev.size(2) - curr.size(2)
    gap_w = prev.size(3) - curr.size(3)

    lead_w, lead_h = gap_w // 2, gap_h // 2

    # F.pad takes the amounts for the last dimension first, then dim 2.
    padded = F.pad(curr, [lead_w, gap_w - lead_w, lead_h, gap_h - lead_h])

    return torch.cat([prev, padded], dim=1)

In [None]:
from torch import nn
import torch.nn.functional as F
import torch


class ConvLayer(nn.Module):
    """2-D convolution optionally followed by norm, pooling and activation.

    The stages are registered on ``self.conv`` (an ``nn.Sequential``) under
    the names ``'conv'``, ``'norm'``, ``'pool'`` and ``'activ'``, in that
    order; stages whose argument is ``None`` are left out.

    Args:
        in_c, out_c, kernel, stride, padding, dilation, bias: forwarded to
            ``nn.Conv2d``.
        activ: one of ``'leak'``, ``'relu'``, ``'pleak'``, ``'gelu'``,
            ``'selu'``, ``'sigmoid'``, ``'softmax'``, ``'tanh'`` or ``None``.
        norm: ``'bn'`` (``nn.BatchNorm2d(out_c)``) or ``None``.
        pool: ``'max'`` or ``'avg'`` (2x2 window, stride 2) or ``None``.

    Any other non-``None`` value is handed to ``add_module`` unchanged.
    """

    def __init__(self, in_c, out_c, kernel, stride,
                 padding=0, dilation=1, bias=True, activ=None, norm=None,
                 pool=None):
        super(ConvLayer, self).__init__()

        activ_makers = {
            'leak': lambda: nn.LeakyReLU(inplace=True),
            'relu': lambda: nn.ReLU(inplace=True),
            'pleak': lambda: nn.PReLU(),
            'gelu': lambda: nn.GELU(),
            'selu': lambda: nn.SELU(),
            'sigmoid': lambda: nn.Sigmoid(),
            'softmax': lambda: nn.Softmax(dim=1),
            'tanh': lambda: nn.Tanh(),
        }
        norm_makers = {'bn': lambda: nn.BatchNorm2d(out_c)}
        pool_makers = {
            'max': lambda: nn.MaxPool2d(2, 2),
            'avg': lambda: nn.AvgPool2d(2, 2),
        }

        def build(spec, makers):
            # Known keyword -> fresh module; anything else passes through.
            if isinstance(spec, str) and spec in makers:
                return makers[spec]()
            return spec

        convolution = nn.Conv2d(in_c, out_c, kernel_size=kernel,
                                stride=stride, dilation=dilation,
                                padding=padding, bias=bias)

        self.conv = nn.Sequential()
        self.conv.add_module('conv', convolution)

        # Registration order defines execution order: norm, pool, activ.
        stages = (
            ('norm', build(norm, norm_makers)),
            ('pool', build(pool, pool_makers)),
            ('activ', build(activ, activ_makers)),
        )
        for label, stage in stages:
            if stage is not None:
                self.conv.add_module(label, stage)

    def forward(self, x):
        """Run ``x`` through the registered stages and return the result."""
        return self.conv(x)

In [None]:
class DeConvLayer(nn.Module):
    """Transposed 2-D convolution optionally followed by norm, pool, activation.

    The stages are registered on ``self.deconv`` (an ``nn.Sequential``) under
    the names ``'deconv'``, ``'norm'``, ``'pool'`` and ``'activ'``, in that
    order; stages whose argument is ``None`` are left out.

    Args:
        in_c, out_c, kernel, stride, padding, bias: forwarded to
            ``nn.ConvTranspose2d``.
        activ: one of ``'leak'``, ``'relu'``, ``'pleak'``, ``'gelu'``,
            ``'selu'``, ``'sigmoid'``, ``'softmax'``, ``'tanh'`` or ``None``
            (the same set ``ConvLayer`` accepts).
        norm: ``'bn'`` (``nn.BatchNorm2d(out_c)``) or ``None``.
        pool: ``'max'`` or ``'avg'`` (2x2 window, stride 2) or ``None``.

    Any other non-``None`` value is handed to ``add_module`` unchanged, so it
    must already be an ``nn.Module``.
    """

    def __init__(self, in_c, out_c, kernel, stride,
                 padding=0, activ=None, norm=None,
                 pool=None, bias=True):
        super(DeConvLayer, self).__init__()
        self.deconv = nn.Sequential()
        self.deconv.add_module('deconv', nn.ConvTranspose2d(in_c, out_c, kernel_size=kernel,
                                                            stride=stride, padding=padding, bias=bias))

        if activ == 'leak':
            activ = nn.LeakyReLU(inplace=True)
        elif activ == 'relu':
            activ = nn.ReLU(inplace=True)
        elif activ == 'pleak':
            activ = nn.PReLU()
        elif activ == 'gelu':
            activ = nn.GELU()
        elif activ == 'selu':
            activ = nn.SELU()
        elif activ == 'sigmoid':
            activ = nn.Sigmoid()
        elif activ == 'softmax':
            activ = nn.Softmax(dim=1)
        elif activ == 'tanh':
            # Previously missing here although ConvLayer supports it; the raw
            # string used to reach add_module and fail.
            activ = nn.Tanh()

        if norm == 'bn':
            norm = nn.BatchNorm2d(out_c)

        if pool == 'max':
            pool = nn.MaxPool2d(2, 2)
        elif pool == 'avg':
            pool = nn.AvgPool2d(2, 2)

        # Registration order defines execution order: norm, pool, activ.
        if norm is not None:
            self.deconv.add_module('norm', norm)

        if pool is not None:
            self.deconv.add_module('pool', pool)

        if activ is not None:
            self.deconv.add_module('activ', activ)

    def forward(self, x):
        """Run ``x`` through the registered stages and return the result."""
        return self.deconv(x)

In [None]:
class ConvBlock(nn.Module):
    """Two 3x3 ``ConvLayer`` stages followed by an optional resampling stage.

    Args:
        in_c: input channels of the first convolution.
        out_c: output channels of both convolutions.
        activ: activation keyword forwarded to both ``ConvLayer`` stages.
        pool: resampling mode:
            ``'up_stride'``   -- ``DeConvLayer`` with kernel 2, stride 2;
            ``'up_bilinear'`` -- bilinear ``nn.Upsample`` x2 (align_corners);
            ``'up_nearest'``  -- nearest-neighbour ``nn.Upsample`` x2;
            ``'down_max'``    -- ``nn.MaxPool2d(2, 2)``;
            ``'down_stride'`` -- no separate stage, the second convolution
                                 uses stride 2 instead;
            anything else     -- no resampling.
        norm: normalisation keyword forwarded to the layers.

    ``forward`` returns a pair. With a resampling stage it is
    ``(features, resampled_features)``; without one it is ``(0, features)``.
    """

    def __init__(self, in_c, out_c, activ=None, pool=None, norm='bn'):
        super(ConvBlock, self).__init__()
        self.c1 = ConvLayer(in_c, out_c, 3, 1, activ=activ, norm=norm, padding=1)

        # 'down_stride' folds the downsampling into the second convolution,
        # so build it once with the right stride instead of replacing it.
        c2_stride = 2 if pool == 'down_stride' else 1
        self.c2 = ConvLayer(out_c, out_c, 3, c2_stride, activ=activ, norm=norm, padding=1)

        if pool == 'up_stride':
            self.pool = DeConvLayer(out_c, out_c, 2, 2, norm=norm)
        elif pool == 'up_bilinear':
            self.pool = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        elif pool == 'up_nearest':
            # align_corners is only accepted by the interpolating modes;
            # passing it together with 'nearest' makes the forward pass raise.
            self.pool = nn.Upsample(scale_factor=2, mode='nearest')
        elif pool == 'down_max':
            self.pool = nn.MaxPool2d(2, 2)
        else:
            self.pool = None

    def forward(self, x):
        """Return ``(features, resampled)`` or ``(0, features)`` if no pool."""
        x = self.c2(self.c1(x))

        if self.pool is not None:
            return x, self.pool(x)
        return 0, x

In [None]:
from torch import nn


class Unet(nn.Module):
    """U-Net style encoder/decoder assembled from ``ConvBlock`` stages.

    Layout for ``depth`` levels and base width ``arch``:
      * ``enc_1 .. enc_depth``: ``ConvBlock`` with 2x2 max-pooling; channel
        widths ``arch, 2*arch, ..., 2**(depth-1) * arch``.
      * ``mid``: ``ConvBlock`` at the deepest width, upsampling by 2.
      * ``dec_1 .. dec_depth``: ``ConvBlock`` stages going back up; all but
        the last upsample by 2.
      * ``final``: 1x1 ``ConvLayer`` to ``out_ch`` channels with tanh.

    Args:
        device: stored on the instance; not used by the code in this class.
        inp_ch: number of input channels.
        out_ch: number of output channels.
        arch: channel width of the first encoder level.
        depth: number of encoder/decoder levels.
        activ: activation keyword forwarded to every ``ConvBlock``.
        concat: optional per-level flags (index 0 = first encoder level).
            A non-zero flag concatenates that level's encoder features into
            the matching decoder stage. Too-long sequences are truncated to
            ``depth``, too-short ones are padded with 0. ``None`` (default)
            disables every skip connection.
    """

    def __init__(self, device, inp_ch=1, out_ch=1,
                 arch=16, depth=3, activ='leak', concat=None):
        super(Unet, self).__init__()

        self.activ = activ
        self.device = device
        self.out_ch = out_ch
        self.inp_ch = inp_ch
        self.depth = depth
        self.arch = arch
        self.concat = None

        self.arch_n = []
        self.enc = []
        self.dec = []
        self.layers = []
        self.skip = []

        self.check_concat(concat)
        self.prep_arch_list()
        self.organize_arch()
        self.prepare_params()

    def check_concat(self, con):
        """Normalise ``con`` into ``self.concat``: one multiplier per level.

        Each entry is 2 when the level's skip features are concatenated
        (the decoder stage then sees twice the channels) and 1 otherwise.
        """
        if con is None:
            self.concat = [1] * self.depth
            return

        # Work on a plain list so that list, tuple and array inputs all get
        # element-wise treatment (``2 * some_list`` would repeat the list).
        flags = list(con)[:self.depth]
        flags += [0] * (self.depth - len(flags))
        self.concat = [2 if flag else 1 for flag in flags]

    def prep_arch_list(self):
        """Fill ``self.arch_n`` with the channel count entering each level."""
        widths = [self.arch * 2 ** level for level in range(self.depth)]
        self.arch_n = [self.inp_ch] + widths

    def organize_arch(self):
        """Instantiate encoder, middle, decoder and output stages."""
        for c_in, c_out in zip(self.arch_n[:-1], self.arch_n[1:]):
            self.enc.append(
                ConvBlock(c_in, c_out, activ=self.activ, pool='down_max'))

        self.layers = [ConvBlock(self.arch_n[-1], self.arch_n[-1], activ=self.activ, pool='up_stride')]

        # Decoder stages run deepest-first, so they read concat / arch_n from
        # the end. The multiplier widens the input when a skip is attached.
        for idx in range(len(self.arch_n) - 2):
            self.dec.append(
                ConvBlock(self.concat[- (idx + 1)] * self.arch_n[- (idx + 1)], self.arch_n[- (idx + 2)],
                          activ=self.activ, pool='up_stride'))
        # Last decoder stage works at full resolution and does not upsample.
        self.dec.append(ConvBlock(self.concat[0] * self.arch, self.arch, activ=self.activ))
        self.layers.append(ConvLayer(self.arch, self.out_ch, 1, 1, norm=None, activ='tanh'))

    def prepare_params(self):
        """Register every stage as a named sub-module.

        ``self.enc`` / ``self.dec`` are plain lists, so the blocks must be
        registered explicitly to show up in ``parameters()`` and
        ``state_dict()``.
        """
        for number, block in enumerate(self.enc, start=1):
            self.add_module(f'enc_{number}', block)

        self.add_module('mid', self.layers[0])

        for number, block in enumerate(self.dec, start=1):
            self.add_module(f'dec_{number}', block)

        self.add_module('final', self.layers[1])

    def forward(self, img):
        """Run the network on ``img`` and return the tanh-activated output."""
        h = img
        h_skip = []

        for block in self.enc:
            # Each encoder block yields (pre-pool features, pooled features).
            hs, h = block(h)
            h_skip.append(hs)

        _, h = self.mid(h)

        for level, block in enumerate(self.dec, start=1):
            if self.concat[-level] == 2:
                h = concat_curr(h_skip[-level], h)
            _, h = block(h)

        h = self.final(h)

        return h