<a href="https://colab.research.google.com/github/VedantDere0104/GANs/blob/main/Self_Adaptive_Sparse_Transform_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [9]:
####

In [10]:
import tensorflow as tf

In [11]:
class MLP(tf.keras.layers.Layer):
  """Upsampling stage: Conv2DTranspose -> BatchNormalization -> activation.

  Despite the name, this is not a dense multilayer perceptron: it applies a
  4x4, stride-2, 'valid'-padded transposed convolution, then batch
  normalisation, then an optional non-linearity.

  Args:
    n_filters: number of output channels of the transposed convolution.
    input_shape: unused; kept so existing callers keep working.
    output_shape: unused; kept so existing callers keep working.
    activation: 'relu' or 'tanh' to apply that non-linearity after batch
      normalisation; 'linear' or None to return the normalised tensor as is.

  Raises:
    ValueError: if `activation` is not one of the supported values.
  """

  # Accepted values for `activation`; None and 'linear' both mean
  # "no non-linearity".
  _SUPPORTED_ACTIVATIONS = ('relu', 'tanh', 'linear', None)

  def __init__(self , n_filters , input_shape , output_shape , activation = 'relu'):
    super(MLP , self).__init__()
    # Reject unknown values up front: otherwise a typo (e.g. 'Relu') or an
    # unsupported name silently degrades this stage to a linear layer.
    if activation not in self._SUPPORTED_ACTIVATIONS:
      raise ValueError(
          f"Unsupported activation {activation!r}; expected one of "
          f"{self._SUPPORTED_ACTIVATIONS}")
    self.convT = tf.keras.layers.Conv2DTranspose(filters=n_filters , kernel_size=4 , strides=2 , padding='valid' )
    self.batch_norm = tf.keras.layers.BatchNormalization()
    self.relu = tf.keras.layers.ReLU()
    self.tanh = tf.keras.layers.Activation('tanh')
    self.activation = activation

  def call(self , x , training=None):
    x = self.convT(x)
    # Forward `training` explicitly so BatchNormalization switches between
    # batch statistics and moving averages; None defers to the training mode
    # Keras inherits from the outer call.
    x = self.batch_norm(x , training=training)
    if self.activation == 'relu':
      x = self.relu(x)
    elif self.activation == 'tanh':
      x = self.tanh(x)
    # 'linear' / None: return the normalised tensor unchanged.
    return x

In [12]:
class CSM(tf.keras.layers.Layer):
  """Channel branch of the sparse transform: a ReLU-activated Dense projection.

  The Dense layer acts on the last axis of its input, so every position of
  a feature map is projected independently to `n_filters` channels.

  Args:
    n_filters: width of the Dense projection (output channel count).
    input_shape: unused.
    output_shape: unused.
  """

  def __init__(self, n_filters, input_shape, output_shape):
    super().__init__()
    self.dense = tf.keras.layers.Dense(n_filters, activation='relu')

  def call(self, x):
    return self.dense(x)

In [13]:
class PSM(tf.keras.layers.Layer):
  """Pixel branch of the sparse transform: 4x4 stride-2 (transposed) conv + ReLU.

  Args:
    n_filters: output channel count.
    input_shape: unused.
    output_shape: unused.
    use_conv: when truthy, use a Conv2D with 'same' padding (downsamples by
      2); otherwise use a Conv2DTranspose with 'valid' padding (upsamples).
  """

  def __init__(self, n_filters, input_shape, output_shape, use_conv=False):
    super().__init__()
    self.use_conv = use_conv

    # Both paths are constructed; `call` only ever invokes the one chosen
    # by `use_conv`.
    self.convT = tf.keras.layers.Conv2DTranspose(n_filters, 4, 2, padding='valid')
    self.conv = tf.keras.layers.Conv2D(n_filters, 4, 2, padding='same')
    self.relu = tf.keras.layers.ReLU()

  def call(self, x):
    branch = self.conv if self.use_conv else self.convT
    return self.relu(branch(x))
    

In [14]:
class Feature_Map_Recombination(tf.keras.layers.Layer):
  """Recombine a feature map with channel-wise (CSM) and pixel-wise (PSM) weights.

  Two weight maps are derived from the input: `alpha` from a CSM branch and
  `beta` from a PSM branch. The output is `alpha * x + beta * x`
  (element-wise).

  Args:
    n_filters: channel count of both weight maps. Because the weights are
      multiplied element-wise with `x`, this must equal the channel count of
      `x`.
    input_shape: forwarded to CSM/PSM (which do not use it).
    output_shape: forwarded to CSM/PSM (which do not use it).
  """

  def __init__(self , n_filters , input_shape , output_shape):
    super(Feature_Map_Recombination , self).__init__()

    self.alpha = CSM(n_filters , input_shape , output_shape)
    self.beta = PSM(n_filters , input_shape , output_shape)

  def call(self , x):
    # The branches are Keras layers and must be *called* on x to produce
    # weight tensors; passing the layer objects to tf.matmul cannot work.
    alpha_ = self.alpha(x)
    beta_ = self.beta(x)
    # PSM's stride-2 'valid' transposed conv changes the spatial size, so
    # bring its map back to x's height/width (and dtype, since
    # tf.image.resize returns float32) before combining.
    beta_ = tf.cast(tf.image.resize(beta_ , tf.shape(x)[1:3]) , x.dtype)
    # NOTE(review): element-wise re-weighting is assumed to be the intended
    # "recombination" (the earlier code used tf.matmul); confirm against
    # the SASTM paper.
    return alpha_ * x + beta_ * x

In [15]:
class Repeating_layer(tf.keras.layers.Layer):
  """One generator stage: Conv2DTranspose -> feature-map recombination -> BatchNorm -> ReLU.

  Args:
    n_filters: output channels of the transposed convolution and of the
      recombination branches.
    input_shape: forwarded to Feature_Map_Recombination.
    output_shape: forwarded to Feature_Map_Recombination.
  """

  def __init__(self, n_filters, input_shape, output_shape):
    super().__init__()
    # Construction order is kept as-is so Keras' auto-generated layer names
    # stay the same.
    self.convT = tf.keras.layers.Conv2DTranspose(n_filters, 4, 2, padding='valid')
    self.sastm = Feature_Map_Recombination(n_filters, input_shape, output_shape)
    self.batch_norm = tf.keras.layers.BatchNormalization()
    self.relu = tf.keras.layers.ReLU()

  def call(self, x):
    for stage in (self.convT, self.sastm, self.batch_norm, self.relu):
      x = stage(x)
    return x

In [18]:
class GAN(tf.keras.layers.Layer):
  """Generator stack: MLP stem -> five Repeating_layer stages -> tanh MLP head.

  Despite the class name, this layer holds only the generator path. Its
  final stage is an MLP with 3 output filters and a tanh activation.

  Args:
    input_shape: passed to the stem MLP (which does not use it).
    output_shape: passed to the head MLP (which does not use it).
  """

  # (filter multiplier, hidden-in multiplier, hidden-out multiplier) for each
  # Repeating_layer, applied to the base sizes below.
  _STAGE_MULTIPLIERS = ((1, 1, 2), (2, 2, 4), (4, 4, 8), (8, 8, 16), (16, 16, 8))

  def __init__(self, input_shape, output_shape):
    super().__init__()
    base_filters = 64
    base_hidden = 64

    self.mlp = MLP(base_filters, input_shape, base_hidden)
    stages = [
        Repeating_layer(base_filters * f_mul, base_hidden * in_mul, base_hidden * out_mul)
        for f_mul, in_mul, out_mul in self._STAGE_MULTIPLIERS
    ]
    self.repeat = tf.keras.Sequential(stages)
    self.last_layer = MLP(3, base_hidden * 8, output_shape, activation='tanh')

  def call(self, x):
    features = self.repeat(self.mlp(x))
    return self.last_layer(features)
