In [None]:
import cv2
from tensorflow.keras import models, layers
from tensorflow.keras.models import Model
from tensorflow.keras.layers import BatchNormalization, Activation, Flatten
from tensorflow.keras.optimizers import Adam,RMSprop,SGD
import tensorflow.keras.backend as k
import os
import tensorflow.keras as keras
import tensorflow as tf
from tensorflow.keras.layers import Input, Add, Dense, Activation, ZeroPadding2D, BatchNormalization, Flatten, Conv2D, Conv3D,UpSampling2D,ReLU,Dropout
import json
import matplotlib.image as mpimg
from tqdm.notebook import tqdm
import pickle
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Layer

In [None]:
class Residual(Layer):
    """Residual block: BN -> Conv -> BN -> Conv, add the block input, then ReLU.

    Both convolutions use ``padding="same"`` and the default stride of 1, so the
    spatial size is preserved. The skip connection adds the unmodified input to
    the second convolution's output, so the input is expected to already have
    ``channels_in`` channels.

    Args:
        channels_in: Number of filters for both convolutions. It should equal
            the channel count of the incoming tensor for the skip addition.
        kernel: Kernel size for both convolutions (anything ``Conv2D`` accepts,
            e.g. ``3`` or ``(3, 3)``).
        **kwargs: Forwarded to ``Layer`` (e.g. ``name``, ``trainable``, ``dtype``).
    """

    def __init__(self, channels_in, kernel, **kwargs):
        # Do not force ``self.trainable = True`` after this call: doing so would
        # silently override a ``trainable=False`` passed through ``**kwargs``.
        super(Residual, self).__init__(**kwargs)
        self.channels_in = channels_in
        self.kernel = kernel
        # Both convolutions honour the requested kernel size.
        self.C1 = Conv2D(self.channels_in, self.kernel, padding="same")
        self.C2 = Conv2D(self.channels_in, self.kernel, padding="same")
        self.BN1 = BatchNormalization()
        self.BN2 = BatchNormalization()
        # Stateless merge/activation layers are built once here rather than
        # being re-created on every call().
        self.skip_add = Add()
        self.out_relu = Activation("relu")

    def get_config(self):
        """Return the constructor arguments so ``from_config`` can rebuild the layer.

        Only plain constructor arguments are stored. Sublayers are rebuilt by
        ``__init__``, and their weights are restored separately.
        """
        config = super().get_config()
        config.update({
            "channels_in": self.channels_in,
            "kernel": self.kernel,
        })
        return config

    def call(self, x):
        # Keep the block input in a local; storing it on ``self`` would leak
        # per-call tensors onto the layer object.
        shortcut = x
        x = self.BN1(x)
        # NOTE(review): there is no activation between the two convolutions;
        # confirm this matches the intended architecture.
        x = self.C1(x)
        x = self.BN2(x)
        x = self.C2(x)
        return self.out_relu(self.skip_add([x, shortcut]))

    def compute_output_shape(self, input_shape):
        # "same" padding with stride 1 and a matching channel count keep the shape.
        return input_shape

In [None]:
class Encoder(Layer):
  """Shared convolutional trunk that splits into a texture branch and a shape branch.

  Pipeline:
    conv1 (stride ``strides[0]``) -> conv2 (stride ``strides[1]``) -> 3 residual blocks, then
      * texture branch: conv_texture (stride ``strides[1]``) -> 2 residual blocks
      * shape branch:   conv_shape   (stride ``strides[0]``) -> 2 residual blocks

  With the default strides, the texture output is downsampled by a further
  factor of 2 relative to the shape output.

  Args:
    kernel_nums: ``(filters1, filters2)``. ``filters1`` is used by conv1 and
      ``filters2`` by every other convolution and residual block.
    kernel_size: ``(kernel1, kernel2)``. ``kernel1`` is used by conv1 and
      ``kernel2`` everywhere else.
    strides: ``(s1, s2)``. ``s1`` is used by conv1 and conv_shape, and ``s2`` by
      conv2 and conv_texture.
    **kwargs: Forwarded to ``Layer`` (e.g. ``name``, ``trainable``, ``dtype``).

  ``call(data)`` returns ``(texture_output, shape_output)``.
  """

  def __init__(self, kernel_nums=(16, 32), kernel_size=((7, 7), (3, 3)),
               strides=((1, 1), (2, 2)), **kwargs):
    # Tuple defaults avoid shared mutable default arguments. Lists are still
    # accepted, since only the first two elements are indexed.
    super().__init__(**kwargs)
    self.kernel1 = kernel_size[0]
    self.kernel2 = kernel_size[1]
    self.filters1 = kernel_nums[0]
    self.filters2 = kernel_nums[1]
    self.s1 = strides[0]
    self.s2 = strides[1]

    # Shared trunk residual blocks; conv2 outputs filters2 channels, which
    # matches their width.
    self.R1 = Residual(self.filters2, self.kernel2)
    self.R2 = Residual(self.filters2, self.kernel2)
    self.R3 = Residual(self.filters2, self.kernel2)

    # Texture-branch residual blocks.
    self.Rtext1 = Residual(self.filters2, self.kernel2)
    self.Rtext2 = Residual(self.filters2, self.kernel2)

    # Shape-branch residual blocks.
    self.Rshape1 = Residual(self.filters2, self.kernel2)
    self.Rshape2 = Residual(self.filters2, self.kernel2)

    self.conv1 = Conv2D(filters=self.filters1, kernel_size=self.kernel1, strides=self.s1, activation='relu', padding='same')
    self.conv2 = Conv2D(filters=self.filters2, kernel_size=self.kernel2, strides=self.s2, activation='relu', padding='same')
    self.conv_texture = Conv2D(filters=self.filters2, kernel_size=self.kernel2, strides=self.s2, activation='relu', padding='same')
    self.conv_shape = Conv2D(filters=self.filters2, kernel_size=self.kernel2, strides=self.s1, activation='relu', padding='same')

  def get_config(self):
    """Return the constructor arguments so ``from_config`` can rebuild the layer."""
    config = super().get_config()
    config.update({
        "kernel_nums": [self.filters1, self.filters2],
        "kernel_size": [self.kernel1, self.kernel2],
        "strides": [self.s1, self.s2],
    })
    return config

  def call(self, data):
    # Shared trunk.
    x = self.conv1(data)
    x = self.conv2(x)
    x = self.R1(x)
    x = self.R2(x)
    trunk = self.R3(x)

    # Texture representation.
    texture = self.conv_texture(trunk)
    texture = self.Rtext1(texture)
    texture_output = self.Rtext2(texture)

    # Shape representation.
    shape = self.conv_shape(trunk)
    shape = self.Rshape1(shape)
    shape_output = self.Rshape2(shape)

    return texture_output, shape_output
      


In [None]:
class Manipulator(Layer):
  """Scales the difference between two representations and adds it back to the first.

  ``call([Ma, Mb, alpha])`` computes::

      Ma + R(conv2(alpha * conv1(|Ma - Mb|)))

  where conv1/conv2 are ReLU convolutions and R is a ``Residual`` block.

  The final addition requires the conv/residual output to have the same shape
  as ``Ma``. That means ``filters`` must equal ``Ma``'s channel count, and
  ``stride`` must keep the spatial size (e.g. ``(1, 1)``).

  Args:
    kernel: Kernel size for conv1, conv2 and the residual block.
    filters: Number of filters for conv1, conv2 and the residual block.
    stride: Stride for conv1 and conv2.
    **kwargs: Forwarded to ``Layer`` (e.g. ``name``, ``trainable``, ``dtype``).
  """

  def __init__(self, kernel=(3, 3), filters=32, stride=(1, 1), **kwargs):
    super().__init__(**kwargs)
    self.kernel = kernel
    self.filters = filters
    self.stride = stride

    self.R = Residual(self.filters, self.kernel)

    self.conv1 = Conv2D(self.filters, self.kernel, self.stride, activation='relu', padding='same')
    self.conv2 = Conv2D(self.filters, self.kernel, self.stride, activation='relu', padding='same')

  def get_config(self):
    """Return the constructor arguments so ``from_config`` can rebuild the layer."""
    config = super().get_config()
    config.update({
        "kernel": self.kernel,
        "filters": self.filters,
        "stride": self.stride,
    })
    return config

  def call(self, data):
    # Inputs are kept in locals; storing them on ``self`` would leak per-call
    # tensors onto the layer object.
    shape_a = data[0]
    shape_b = data[1]
    # NOTE(review): alpha is presumably a scalar or a tensor broadcastable
    # against the conv1 output (e.g. (batch, 1, 1, 1)); confirm with caller.
    alpha = data[2]

    # NOTE(review): abs() discards the sign of the difference; confirm this
    # is intended rather than a plain ``shape_b - shape_a``.
    diff = tf.abs(shape_a - shape_b)

    x = self.conv1(diff)
    x = alpha * x
    x = self.conv2(x)
    x = self.R(x)
    return x + shape_a




In [None]:
class Decoder(Layer):
  """Fuses a texture representation with a manipulated shape representation into an image.

  ``call([text_b, manip_output])``:
    1. upsamples ``text_b`` by 2 (nearest neighbour),
    2. concatenates it with ``manip_output`` on the last axis,
    3. applies nine ``Residual`` blocks of width ``res_channels``,
    4. upsamples by 2 again,
    5. applies conv1 (``filters[0]`` filters, ``kernel_size[0]``) and then
       conv2 (``filters[1]`` filters, ``kernel_size[1]``), both with ReLU.

  Args:
    kernel_size: ``(k1, k2)``. ``k1`` is used by the residual blocks and conv1,
      and ``k2`` by conv2.
    filters: ``(filter1, filter2)``, the filter counts of conv1 and conv2.
    stride: Stride used by conv1 and conv2.
    res_channels: Width of the residual blocks. It must equal the channel count
      of the concatenation in step 2 (the upsampled texture channels plus the
      ``manip_output`` channels). The default of 64 matches the previously
      hard-coded value.
    **kwargs: Forwarded to ``Layer`` (e.g. ``name``, ``trainable``, ``dtype``).
  """

  def __init__(self, kernel_size=((3, 3), (7, 7)), filters=(32, 3), stride=(1, 1),
               res_channels=64, **kwargs):
    # Tuple defaults avoid shared mutable default arguments. Lists are still
    # accepted, since only the first two elements are indexed.
    super().__init__(**kwargs)
    self.filter1 = filters[0]
    self.filter2 = filters[1]
    self.k1 = kernel_size[0]
    self.k2 = kernel_size[1]
    self.stride = stride
    self.res_channels = res_channels

    # Nine independent residual blocks (separate weights). The attribute names
    # R1..R9 are kept so that existing checkpoints keyed on them still load.
    self.R1 = Residual(self.res_channels, self.k1)
    self.R2 = Residual(self.res_channels, self.k1)
    self.R3 = Residual(self.res_channels, self.k1)
    self.R4 = Residual(self.res_channels, self.k1)
    self.R5 = Residual(self.res_channels, self.k1)
    self.R6 = Residual(self.res_channels, self.k1)
    self.R7 = Residual(self.res_channels, self.k1)
    self.R8 = Residual(self.res_channels, self.k1)
    self.R9 = Residual(self.res_channels, self.k1)

    self.upsampling1 = UpSampling2D(size=(2, 2), interpolation='nearest')
    self.upsampling2 = UpSampling2D(size=(2, 2), interpolation='nearest')
    self.conv1 = Conv2D(self.filter1, self.k1, self.stride, activation='relu', padding='same')
    # NOTE(review): ReLU on the final output makes it non-negative; confirm the
    # target images are in a non-negative range.
    self.conv2 = Conv2D(self.filter2, self.k2, self.stride, activation='relu', padding='same')

  def get_config(self):
    """Return the constructor arguments so ``from_config`` can rebuild the layer."""
    config = super().get_config()
    config.update({
        "kernel_size": [self.k1, self.k2],
        "filters": [self.filter1, self.filter2],
        "stride": self.stride,
        "res_channels": self.res_channels,
    })
    return config

  def call(self, data):
    # Inputs are kept in locals; storing them on ``self`` would leak per-call
    # tensors onto the layer object.
    text_b = data[0]
    manip_output = data[1]

    x = self.upsampling1(text_b)
    x = tf.concat([x, manip_output], axis=-1)

    for block in (self.R1, self.R2, self.R3, self.R4, self.R5,
                  self.R6, self.R7, self.R8, self.R9):
      x = block(x)

    x = self.upsampling2(x)
    x = self.conv1(x)
    return self.conv2(x)