<a href="https://colab.research.google.com/github/VedantDere0104/GANs/blob/main/Progressive_Growing_GAN's.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

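Paper: *Progressive Growing of GANs for Improved Quality, Stability, and Variation* (Karras et al., 2017) — https://arxiv.org/abs/1710.10196. The cells below build the full-resolution generator and discriminator stacks as Keras layers and check their output shapes on random inputs.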

In [1]:
####

In [89]:
import tensorflow as tf
from tensorflow import keras
import numpy as np

In [3]:
class First_layer(tf.keras.layers.Layer):
  """Stem block: a 4x4 conv then a 3x3 conv, each followed by LeakyReLU.

  Both convolutions use stride 1 and 'same' padding, so the spatial size of
  the input is preserved; only the channel count changes to `n_filters`.

  Args:
    n_filters: number of output channels of both convolutions.
  """

  def __init__(self , n_filters):
    super().__init__()

    self.conv1 = tf.keras.layers.Conv2D(filters=n_filters, kernel_size=(4, 4), padding='same')
    self.conv2 = tf.keras.layers.Conv2D(filters=n_filters, kernel_size=(3, 3), padding='same')
    # A single (weightless) activation layer shared by both convs.
    self.lrelu = tf.keras.layers.LeakyReLU()

  def call(self , x):
    for conv in (self.conv1, self.conv2):
      x = self.lrelu(conv(x))
    return x

In [4]:
# Stand-in (1, 4, 4, 512) input, uniform in [0, 1), for the shape smoke tests.
z = np.random.random_sample((1, 4, 4, 512))

In [5]:
# Sanity check: a 2x2 transposed conv with stride 2 doubles 4x4 -> 8x8.
convT = tf.keras.layers.Conv2DTranspose(
    filters=512, kernel_size=(2, 2), strides=(2, 2))(z)
convT.shape

TensorShape([1, 8, 8, 512])

In [6]:
class Upsample(tf.keras.layers.Layer):
  """Double the spatial resolution with a learned transposed convolution.

  A 2x2 kernel with stride 2 maps an HxW input to 2Hx2W, followed by
  LeakyReLU.

  Args:
    n_filters: number of output channels.
  """

  def __init__(self , n_filters):
    super().__init__()

    self.convT = tf.keras.layers.Conv2DTranspose(
        filters=n_filters, kernel_size=(2, 2), strides=(2, 2))
    self.lrelu = tf.keras.layers.LeakyReLU()

  def call(self , x):
    upsampled = self.convT(x)
    return self.lrelu(upsampled)

In [7]:
class Conv(tf.keras.layers.Layer):
  """Two resolution-preserving 3x3 convolutions, each followed by LeakyReLU.

  Args:
    n_filters: number of output channels of both convolutions.
  """

  def __init__(self , n_filters):
    super().__init__()

    # Shared settings: stride 1 with 'same' padding keeps HxW unchanged.
    same_3x3 = dict(kernel_size=(3, 3), strides=(1, 1), padding='same')
    self.conv1 = tf.keras.layers.Conv2D(n_filters, **same_3x3)
    self.lrelu = tf.keras.layers.LeakyReLU()
    self.conv2 = tf.keras.layers.Conv2D(n_filters, **same_3x3)

  def call(self , x):
    for conv in (self.conv1, self.conv2):
      x = self.lrelu(conv(x))
    return x

In [8]:
# Smoke test: Upsample doubles 4x4 -> 8x8; Conv then keeps 8x8.
x = Upsample(n_filters=512)(z)
print(x.shape)
x = Conv(n_filters=512)(x)
x.shape

(1, 8, 8, 512)


TensorShape([1, 8, 8, 512])

In [9]:
class Middle_layer(tf.keras.layers.Layer):
  """Generator stage: 2x upsampling followed by two 3x3 convolutions.

  Args:
    n_filters: channel count of the upsampling conv and of both 3x3 convs.
  """

  def __init__(self , n_filters):
    # Zero-argument super() resolves through this class's own __class__ cell.
    # The explicit `super(Middle_layer, self)` looks up the *global* name at
    # call time; if `Middle_layer` is later rebound to a different class (the
    # discriminator cell further down does exactly that), constructing an
    # instance of this class would raise TypeError.
    super().__init__()

    self.upsample = Upsample(n_filters)
    self.conv = Conv(n_filters)

  def call(self , x):
    x = self.upsample(x)
    x = self.conv(x)
    return x

In [10]:
class Generator(tf.keras.layers.Layer):
  """Generator stack: a small feature map in, a 3-channel image out.

  `first_layer` keeps the spatial size (stride-1 'same' convs). Each of the
  eight middle layers doubles it with a stride-2 transposed conv, so a 4x4
  input becomes 4 * 2**8 = 1024 pixels per side. Middle-layer channel widths
  are n, n, n, n//2, n//4, n//8, n//16, n//32. A final 1x1 conv with no
  activation projects to 3 channels.

  Args:
    n_filters: width of the stem and of the first three middle layers.
  """

  # Bind the generator's (upsampling) `Middle_layer` at class-definition time.
  # A later cell rebinds the global name `Middle_layer` to a downsampling
  # discriminator block. If the name were resolved inside __init__, a
  # Generator constructed after that cell would silently shrink its input
  # instead of growing it.
  _middle_layer_cls = Middle_layer

  def __init__(self , n_filters):
    super().__init__()
    block = type(self)._middle_layer_cls

    self.first_layer = First_layer(n_filters)
    self.middle_layer1 = block(n_filters)
    self.middle_layer2 = block(n_filters)
    self.middle_layer3 = block(n_filters)
    # From here on each stage halves the channel count (n // 2 // 2 == n // 4).
    self.middle_layer4 = block(n_filters // 2)
    self.middle_layer5 = block(n_filters // 4)
    self.middle_layer6 = block(n_filters // 8)
    self.middle_layer7 = block(n_filters // 16)
    self.middle_layer8 = block(n_filters // 32)
    # 1x1 projection to 3 output channels; linear (no activation).
    self.last_layer = tf.keras.layers.Conv2D(3 , (1 , 1) , strides=(1 , 1) , padding='same')

  def call(self , x):
    x = self.first_layer(x)
    for stage in (self.middle_layer1, self.middle_layer2, self.middle_layer3,
                  self.middle_layer4, self.middle_layer5, self.middle_layer6,
                  self.middle_layer7, self.middle_layer8):
      x = stage(x)
    return self.last_layer(x)

In [11]:
# 512 channels in the stem and the first three (lowest-resolution) stages.
generator = Generator(n_filters=512)

In [12]:
# Push the (1, 4, 4, 512) stand-in through the whole generator stack.
g = generator(z)
g.shape

TensorShape([1, 1024, 1024, 3])

In [13]:
# Random (1, 1024, 1024, 3) standard-normal input for the discriminator blocks.
img = np.random.standard_normal((1, 1024, 1024, 3))
# 1x1 conv, then a 3x3 'same' conv: 16 channels, spatial size unchanged.
d1 = tf.keras.layers.Conv2D(filters=16, kernel_size=(1, 1), strides=(1, 1))(img)
d2 = tf.keras.layers.Conv2D(
    filters=16, kernel_size=(3, 3), strides=(1, 1), padding='same')(d1)
d2.shape

TensorShape([1, 1024, 1024, 16])

In [14]:
class Downsample(tf.keras.layers.Layer):
  """Halve the spatial resolution with a stride-2 3x3 conv plus LeakyReLU.

  With 'same' padding the output side length is ceil(input / 2).

  Args:
    n_filters: number of output channels.
  """

  def __init__(self , n_filters):
    super().__init__()

    self.conv1 = tf.keras.layers.Conv2D(
        filters=n_filters, kernel_size=(3, 3), strides=(2, 2), padding='same')
    self.lrelu = tf.keras.layers.LeakyReLU()

  def call(self , x):
    strided = self.conv1(x)
    return self.lrelu(strided)


In [15]:
# Downsample halves the resolution of d2 and widens it to 32 channels.
downsample = Downsample(n_filters=32)
d3 = downsample(d2)
d3.shape

TensorShape([1, 512, 512, 32])

In [16]:
class Disc_first_layer(tf.keras.layers.Layer):
  """Discriminator stem: 1x1 conv, two 3x3 convs, then a 2x downsample.

  The 1x1 and first 3x3 conv produce `n_filters` channels; the second 3x3
  conv and the downsample produce `2 * n_filters`. Output resolution is half
  the input's.

  Args:
    n_filters: channel count of the first two convolutions.
  """

  def __init__(self , n_filters):
    super().__init__()

    wide = n_filters * 2
    same = dict(strides=(1, 1), padding='same')
    self.conv1 = tf.keras.layers.Conv2D(n_filters, (1, 1), **same)
    self.conv2 = tf.keras.layers.Conv2D(n_filters, (3, 3), **same)
    self.conv3 = tf.keras.layers.Conv2D(wide, (3, 3), **same)
    self.downsample = Downsample(wide)
    self.lrelu = tf.keras.layers.LeakyReLU()

  def call(self , x):
    for conv in (self.conv1, self.conv2, self.conv3):
      x = self.lrelu(conv(x))
    return self.downsample(x)


In [18]:
# Discriminator stem on the 1024x1024 input: expect 512x512 with 32 channels.
first_layer = Disc_first_layer(n_filters=16)
d = first_layer(img)
d.shape

TensorShape([1, 512, 512, 32])

In [20]:
class Middle_layer(tf.keras.layers.Layer):
  """Discriminator stage: 3x3 convs to n//2 then n channels, then downsample.

  NOTE(review): this rebinds the module-level name `Middle_layer`, which an
  earlier cell used for the generator's upsampling stage. Any code that looks
  the name up after this cell runs gets this downsampling block instead.

  Args:
    n_filters: output channel count (the first conv uses half of it).
  """

  def __init__(self , n_filters):
    super().__init__()

    same = dict(strides=(1, 1), padding='same')
    self.conv1 = tf.keras.layers.Conv2D(n_filters // 2, (3, 3), **same)
    self.conv2 = tf.keras.layers.Conv2D(n_filters, (3, 3), **same)
    self.downsample = Downsample(n_filters)
    self.lrelu = tf.keras.layers.LeakyReLU()

  def call(self , x):
    h = self.lrelu(self.conv1(x))
    h = self.lrelu(self.conv2(h))
    return self.downsample(h)

In [21]:
# `Middle_layer` here is the discriminator (downsampling) version.
middle_layer = Middle_layer(n_filters=64)

In [22]:
# One more discriminator stage: resolution halves again, 64 channels.
d = middle_layer(d)
d.shape

TensorShape([1, 256, 256, 64])

In [80]:
class Discriminator(tf.keras.layers.Layer):
  """Discriminator mapping an image batch to one unbounded score per sample.

  In the notebook's smoke test (1024x1024 input, start_filters=16), the stem
  keeps 1024x1024. The eight downsampling middle layers take it to 4x4. A 3x3
  'same' conv and a 4x4 'valid' conv reduce it to 1x1. Flatten, a ReLU dense
  layer and a linear dense layer then give shape (batch, 1).

  Args:
    start_filters: stem width; middle layers use 2x, 4x, 8x, 16x and then
      32x this value (the last four stages and the head share 32x).
  """

  # Capture the `Middle_layer` that is current when this class is defined
  # (in cell order, the downsampling discriminator block), so a later
  # rebinding of that global name cannot change what __init__ builds.
  _middle_layer_cls = Middle_layer

  def __init__(self , start_filters):
    super().__init__()
    block = type(self)._middle_layer_cls

    # NOTE(review): the stem is the generator's `First_layer` (4x4 + 3x3
    # 'same' convs, no downsampling), not `Disc_first_layer`. Switching to
    # Disc_first_layer would also require dropping one middle layer, because
    # it already halves the resolution and the final 4x4 'valid' conv needs
    # at least a 4x4 input.
    self.first_layer = First_layer(start_filters)
    self.middle_layer1 = block(start_filters * 2)
    self.middle_layer2 = block(start_filters * 4)
    self.middle_layer3 = block(start_filters * 8)
    self.middle_layer4 = block(start_filters * 16)
    # Width stops growing here; the remaining stages and the head share it.
    top = start_filters * 32
    self.middle_layer5 = block(top)
    self.middle_layer6 = block(top)
    self.middle_layer7 = block(top)
    self.middle_layer8 = block(top)

    self.second_last = tf.keras.layers.Conv2D(top , (3 , 3) , strides=(1 , 1) , padding='same')
    # Default 'valid' padding: a 4x4 feature map becomes 1x1.
    self.last = tf.keras.layers.Conv2D(top , (4 , 4) , strides=(2 , 2))
    # Activation for the two head convs. Without it they compose into a
    # single linear map, unlike every other conv in this file.
    self.lrelu = tf.keras.layers.LeakyReLU()

    self.flatten = tf.keras.layers.Flatten()
    self.linear1 = tf.keras.layers.Dense(top , activation=tf.keras.activations.relu)
    # Linear output: raw score, no sigmoid.
    self.linear2 = tf.keras.layers.Dense(1)

  def call(self , x):
    x = self.first_layer(x)
    for stage in (self.middle_layer1, self.middle_layer2, self.middle_layer3,
                  self.middle_layer4, self.middle_layer5, self.middle_layer6,
                  self.middle_layer7, self.middle_layer8):
      x = stage(x)
    x = self.lrelu(self.second_last(x))
    x = self.lrelu(self.last(x))
    x = self.flatten(x)
    x = self.linear1(x)
    return self.linear2(x)

In [81]:
# Discriminator with a 16-channel stem.
disc = Discriminator(start_filters=16)

In [82]:
# Score the random 1024x1024 input.
m = disc(img)

In [83]:
# Expect one score per sample: (batch, 1).
m.shape

TensorShape([1, 1])