In [1]:
#System Libs
import os
import glob
import json

#ImageProc Libs
import cv2
from albumentations import (Compose,Resize)
from albumentations.pytorch import ToTensorV2

#Viz
import matplotlib.pyplot as plt

#DL Libs
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim
from torch.utils.data import Dataset,DataLoader
from torchsummary import  summary

In [2]:
files = open('configs.json')

In [3]:
ff = json.load(files)

In [4]:
images_files = glob.glob("/home/lustbeast/PaperWork/Veena Ma'am/Sketch to Face/IIITD_SketchDatabase/Semi-forensic database/IIIT-D student and staff/photo/*.jpg")

In [5]:
class sketchDataset(Dataset):
    def __init__(self,paths,transforms=None):
        self.paths = paths
        self.transforms = transforms
    def __len__(self):
        return len(self.paths)
    def __getitem__(self,idx):
        img = cv2.imread(self.paths[idx])
        if self.transforms:
            img = self.transforms(image=img)['image']

        return img

In [6]:
def augs():
    return Compose([
        Resize(384,384),
        ToTensorV2()
    ])

In [7]:
dataset = sketchDataset(images_files,augs())

In [8]:
dataset = DataLoader(dataset,num_workers = 4,batch_size=1)

In [9]:
images = next(iter(dataset)).half()

In [10]:
writer = SummaryWriter("runs/graph_viz/encoder/redefined")

## Architecture

### Modules:
    1. MultiScale-A
    2. Reduction
    3. MultiScale-B

<img src="RES/arch.png"></img>


Things to be considered:

    * In MultiScale-A module, The stream ID =1 given in the paper has been neglected due to average pooling. So, this module just contains three parallel streams.
    * In MultiScale-B module, The stream ID =4 given in the paper has been neglected due to the same reason. So, this module also contains just three parallel streams

In [11]:
class MultiScaleA(nn.Module):
    def __init__(self,in_channels,stream_1_1x1_filters,stream_2_1x1_filters,stream_2_3x3_filters,stream_3_1x1_filters,stream_3_3x3_filters):
        super(MultiScaleA,self).__init__()
        
        self.stream1 = conv_mod(in_channels,stream_1_1x1_filters,kernel_size=(1,1)) #Stream ID=1
        self.stream2 = nn.Sequential(
            conv_mod(in_channels,stream_2_1x1_filters,kernel_size=(1,1)),
            conv_mod(stream_2_1x1_filters,stream_2_3x3_filters,kernel_size=(3,3),padding=1),     #Stream ID=2
        )
        self.stream3 = nn.Sequential(
            conv_mod(in_channels,stream_3_1x1_filters,kernel_size=(1,1)),
            conv_mod(stream_3_1x1_filters,stream_3_3x3_filters[0],kernel_size=(3,3),padding=1),
            conv_mod(stream_3_3x3_filters[0],stream_3_3x3_filters[1],kernel_size=(3,3),padding=1)
        )

    def forward(self,x):

        stream1 = self.stream1(x)
        stream2 = self.stream2(x)
        stream3 = self.stream3(x)

        concat = torch.cat([stream1,stream2,stream3],axis=1)

        return concat


In [12]:
class Reduction(nn.Module):
    def __init__(self,in_channels,red_stream_2_3x3_filters,red_stream_3_1x1_filters,red_stream_3_3x3_filters):

        super(Reduction,self).__init__()
        self.stream1_MF = nn.MaxPool2d(kernel_size=(3,3),stride=2)
        
        self.stream2_CF = conv_mod(in_channels,red_stream_2_3x3_filters,kernel_size=(3,3),stride=(2,2))
        
        self.stream3_CF = nn.Sequential(
            conv_mod(in_channels,red_stream_3_1x1_filters,kernel_size=(1,1)),
            conv_mod(red_stream_3_1x1_filters,red_stream_3_3x3_filters[0],kernel_size=(3,3)),
            conv_mod(red_stream_3_3x3_filters[0],red_stream_3_3x3_filters[1],kernel_size=(3,3),stride=(2,2),padding=1)
        )
        

    def forward(self,x):

        stream1_MF = self.stream1_MF(x)
        stream2_CF = self.stream2_CF(x)
        stream3_CF = self.stream3_CF(x)

        print(stream1_MF.shape,stream2_CF.shape,stream3_CF.shape)

        return torch.cat([stream1_MF,stream2_CF,stream3_CF],axis=1)


In [13]:
class MultiScaleB(nn.Module):
    def __init__(self,in_channels,Bstream_1_1x1,Bstream_2_1x1,Bstream_2_3x3,Bstream_3_1x1,Bstream_3_3x3):
        super(MultiScaleB,self).__init__()
        self.st1 = conv_mod(in_channels,Bstream_1_1x1,kernel_size=(1,1))

        self.st2 = nn.Sequential(
            conv_mod(in_channels,Bstream_2_1x1,kernel_size=(1,1)),
            conv_mod(Bstream_2_1x1,Bstream_2_3x3[0],kernel_size=(1,3)),
            conv_mod(Bstream_2_3x3[0],Bstream_2_3x3[1],kernel_size=(3,1),padding=(1,1))
            )

        self.st3 = nn.Sequential(
            conv_mod(in_channels,Bstream_3_1x1,kernel_size=(1,1)),
            conv_mod(Bstream_3_1x1,Bstream_3_3x3[0],kernel_size=(1,3)),
            conv_mod(Bstream_3_3x3[0],Bstream_3_3x3[1],kernel_size=(3,1)),
            conv_mod(Bstream_3_3x3[1],Bstream_3_3x3[2],kernel_size=(1,3)),
            conv_mod(Bstream_3_3x3[2],Bstream_3_3x3[3],kernel_size=(3,1),padding=(2,2))
        )


    def forward(self,x):

        st1 = self.st1(x)
        st2 = self.st2(x)
        st3 = self.st3(x)

        print(st1.shape,st2.shape,st3.shape)

        return torch.cat([st1,st2,st3],axis=1)

        

In [14]:
class conv_mod(nn.Module):
    def __init__(self,in_channels,out_channels,kernel_size=(1,1),stride=1,padding=0,activation='relu'):
        super(conv_mod,self).__init__()
        self.mod = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels,out_channels,kernel_size=kernel_size,stride=stride,padding=padding)

        )

    def forward(self,x):
        x = self.mod(x)
        return x

## Architecture - Encoder

In [25]:
class Multi_Block(nn.Module):
    def __init__(self,blocks):
        super(Multi_Block,self).__init__()
        self.BlockA_params = [i for i in blocks[0].values()]
        self.BlockB_params = [j for j in blocks[1].values()]
        self.BlockC_params = [k for k in blocks[2].values()]

        self.mulA = MultiScaleA(self.BlockA_params[0],self.BlockA_params[1],self.BlockA_params[2],self.BlockA_params[3],self.BlockA_params[4],self.BlockA_params[5])
        self.red = Reduction(self.BlockB_params[0],self.BlockB_params[1],self.BlockB_params[2],self.BlockB_params[3])

        self.mulB = MultiScaleB(self.BlockC_params[0],self.BlockC_params[1],self.BlockC_params[2],self.BlockC_params[3],self.BlockC_params[4],self.BlockC_params[5])

    def forward(self,x):

        x = self.mulA(x)
        x = self.red(x)
        x = self.mulB(x)

        return x
class self_attention(nn.Module):
    def __init__(self,feature):
        super(self_attention,self).__init__()
        self.f = nn.Conv2d(feature.shape[1],feature.shape[1],kernel_size=(1,1))
        self.g = nn.Conv2d(feature.shape[1],feature.shape[1],kernel_size=(1,1))
        self.h = nn.Conv2d(feature.shape[1],feature.shape[1],kernel_size=(1,1))
        self.soft = nn.Softmax()
    def forward(self,x):
        f = self.f(x)
        g = self.g(x)
        h = self.h(x)
        fg = torch.matmul(f,g)
        fg_soft = self.soft(fg)
        return torch.matmul(fg_soft,h)

class Encoder(nn.Module):
    def __init__(self,params):
        super(Encoder,self).__init__()
        self.A = Multi_Block(params['BlockA'])
        self.B = Multi_Block(params['BlockB'])
        self.C = Multi_Block(params['BlockC'])
        self.sing_1 = conv_mod(416,256,kernel_size=(5,5))
        self.sing_2 = conv_mod(256,128,kernel_size=(5,5))
        self.sing_3 = conv_mod(128,64,kernel_size=(5,5))
        self.sing_4 = conv_mod(64,32,kernel_size=(5,5))
       

    def forward(self,x):
        multi_x_1 = self.A(x)
        multi_x_2 = self.B(multi_x_1)
        multi_x_3 = self.C(multi_x_2)
        single_x_1 = self.sing_1(multi_x_3)
        single_x_2 = self.sing_2(single_x_1)
        single_x_3 = self.sing_3(single_x_2)
        single_x_4 = self.sing_4(single_x_3)
        
        return multi_x_1,multi_x_2,multi_x_3,single_x_1,single_x_2,single_x_3,single_x_4

## Architecture - Decoder

In [26]:
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder,self).__init__()
        self.up1 = nn.ConvTranspose2d(32,64,kernel_size=(5,5))
        self.up2 = nn.ConvTranspose2d(64,128,kernel_size=(5,5))
        self.up3 = nn.ConvTranspose2d(128,256,kernel_size=(5,5))
        self.up4 = nn.ConvTranspose2d(256,416,kernel_size=(5,5))
    def forward(self,x):
        x = self.up1(x)
        x = self.up2(x)
        x = self.up3(x)
        x = self.up4(x)

        return x


In [27]:
enc = Encoder(params=ff).half().cuda()

In [28]:
img = torch.randn(1,3,384,384)

In [29]:
enc_out = enc(img.half().to("cuda:0"))

torch.Size([1, 224, 191, 191]) torch.Size([1, 32, 191, 191]) torch.Size([1, 256, 191, 191])
torch.Size([1, 32, 191, 191]) torch.Size([1, 128, 191, 191]) torch.Size([1, 256, 191, 191])
torch.Size([1, 256, 95, 95]) torch.Size([1, 32, 95, 95]) torch.Size([1, 256, 95, 95])
torch.Size([1, 32, 95, 95]) torch.Size([1, 128, 95, 95]) torch.Size([1, 256, 95, 95])
torch.Size([1, 256, 47, 47]) torch.Size([1, 32, 47, 47]) torch.Size([1, 256, 47, 47])
torch.Size([1, 32, 47, 47]) torch.Size([1, 128, 47, 47]) torch.Size([1, 256, 47, 47])


In [31]:
dec  = Decoder().to('cuda:0')

In [33]:
dec_out = dec(enc_out[6].to("cuda:0"))

In [34]:
dec_out.shape

torch.Size([1, 416, 47, 47])

In [None]:
decb = DecoderB(416,)

In [44]:
def convert_layers(model, layer_type_old, layer_type_new, convert_weights=False, num_groups=None):
    for name, module in reversed(model._modules.items()):
        if len(list(module.children())) > 0:
            # recurse
            model._modules[name] = convert_layers(module, layer_type_old, layer_type_new, convert_weights)

        if type(module) == layer_type_old:
            layer_old = module
            layer_new = layer_type_new(module.num_features if num_groups is None else num_groups, module.num_features, module.eps, module.affine) 

            if convert_weights:
                layer_new.weight = layer_old.weight
                layer_new.bias = layer_old.bias

            model._modules[name] = layer_new

    return model


# Replace BatchNorm with GroupNorm

In [46]:
convert_layers(enc,conv_mod,torch.nn.ConvTranspose2d)

ModuleAttributeError: 'conv_mod' object has no attribute 'num_features'