<a href="https://colab.research.google.com/github/asutoshp10/GAN/blob/main/cGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
from keras.layers import *
from keras.optimizers import Adam
from keras import Sequential,Model,Input
from keras.datasets.cifar10 import load_data

In [None]:
# Load CIFAR-10 (Keras downloads and caches it on first use) as
# ((train_images, train_labels), (test_images, test_labels)).
(
    (x_tr, y_tr),
    (x_test, y_test),
) = load_data()

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [None]:
# Scale uint8 pixels [0, 255] to [-1, 1] so real images match the range of the
# generator's tanh output. Casting to float32 first avoids the float64 copy that
# plain uint8 arithmetic would allocate (half the memory), and the dtype guard
# keeps the cell idempotent: re-running it does not normalise twice.
if x_tr.dtype == np.uint8:
  x_tr = (x_tr.astype('float32') - 127.5) / 127.5

In [None]:
def generate_real_samples(dataset,n_samples,labels=None):
    """Draw a random batch (with replacement) of real images and their labels.

    Args:
      dataset: array of images indexed along axis 0.
      n_samples: number of examples to draw.
      labels: class labels aligned with `dataset` along axis 0. Defaults to
        the notebook-level `y_tr` for backward compatibility.

    Returns:
      ([images, labels], targets) where targets is an (n_samples, 1) array of
      ones, i.e. the "real" target for the discriminator.

    Raises:
      ValueError: if `labels` and `dataset` differ in length (the pairs would
        be misaligned or indexing would fail).
    """
    if labels is None:
        labels = y_tr
    if len(labels) != dataset.shape[0]:
        raise ValueError(
            f'labels has {len(labels)} entries but dataset has {dataset.shape[0]}')
    ix = np.random.randint(0, dataset.shape[0], n_samples)
    return [dataset[ix], labels[ix]], np.ones((n_samples, 1))

In [None]:
def discriminator(in_shape=(32,32,3),n_classes=10):
  """Build and compile the conditional discriminator: (image, label) -> P(real).

  The class label is embedded, projected to a single in_shape[0] x in_shape[1]
  channel and concatenated with the image before three stride-2 conv blocks.

  Args:
    in_shape: (height, width, channels) of the input images.
    n_classes: number of distinct class labels the Embedding accepts.

  Returns:
    A keras Model with inputs [image, label] and one sigmoid output, compiled
    with binary cross-entropy, Adam(learning_rate=2e-4, beta_1=0.5) and an
    accuracy metric.
  """
  in_label=Input(shape=(1,))
  in_image=Input(shape=in_shape)
  # Label -> 50-d embedding -> H*W units -> (H, W, 1) extra image channel.
  c=Embedding(n_classes,50)(in_label)
  c=Dense(in_shape[0]*in_shape[1])(c)
  c=Reshape((in_shape[0],in_shape[1],1))(c)

  merge=Concatenate()([in_image,c])
  d=Conv2D(128,(3,3),padding='same',strides=2)(merge)
  d=LeakyReLU(alpha=0.2)(d)

  d=Conv2D(128,(3,3),padding='same',strides=2)(d)
  d=LeakyReLU(alpha=0.2)(d)

  d=Conv2D(128,(3,3),padding='same',strides=2)(d)
  d=LeakyReLU(alpha=0.2)(d)

  d=Flatten()(d)
  d=Dropout(0.4)(d)
  d=Dense(1,activation='sigmoid')(d)

  model=Model([in_image,in_label],d)
  # Use `learning_rate`, not `lr`: tf.keras 2.11+ optimizers drop `lr` with
  # only a warning (silently training at the default 1e-3), and Keras 3
  # rejects it.
  opt=Adam(learning_rate=0.0002,beta_1=0.5)
  model.compile(loss='binary_crossentropy',optimizer=opt,metrics=['accuracy'])
  return model


In [None]:
def generator(latent_dim=100,n_classes=10):
  """Build the conditional generator: (noise, class label) -> 32x32x3 image.

  Noise is projected to a 4x4x256 map and the embedded label to a 4x4x1 map.
  The two are concatenated and upsampled by three stride-2 transposed
  convolutions (4 -> 8 -> 16 -> 32) to a 3-channel tanh image.

  Args:
    latent_dim: length of the noise vector input.
    n_classes: number of distinct class labels the Embedding accepts.

  Returns:
    An uncompiled keras Model with inputs [noise, label] and output of shape
    (batch, 32, 32, 3) with values in [-1, 1].
  """
  # shape must be a tuple: `(latent_dim)` is just an int, which Keras 3 rejects.
  in_latent=Input(shape=(latent_dim,))
  in_class=Input(shape=(1,))
  c=Embedding(n_classes,50)(in_class)
  c=Dense(4*4)(c)
  c=Reshape((4,4,1))(c)

  l=Dense(4*4*256)(in_latent)
  l=LeakyReLU(alpha=0.2)(l)
  l=Reshape((4,4,256))(l)

  merge=Concatenate()([l,c])
  g=Conv2DTranspose(128,(4,4),strides=2,padding='same')(merge)
  g=LeakyReLU(alpha=0.2)(g)

  g=Conv2DTranspose(128,(4,4),strides=2,padding='same')(g)
  g=LeakyReLU(alpha=0.2)(g)

  g=Conv2DTranspose(128,(4,4),strides=2,padding='same')(g)
  g=LeakyReLU(alpha=0.2)(g)

  g=Conv2D(3,(3,3),padding='same',activation='tanh')(g)

  return Model([in_latent,in_class],g)

In [None]:
def generate_latent_dim(latent_dim, n_classes, n_samples):
  """Sample generator inputs: standard-normal noise plus random class labels.

  Returns:
    [noise, labels]: noise has shape (n_samples, latent_dim) drawn from
    N(0, 1); labels has shape (n_samples,) with integers in [0, n_classes).
  """
  # Labels are drawn before the noise; keeping this order keeps a seeded run
  # on the same global NumPy RNG stream.
  labels = np.random.randint(low=0, high=n_classes, size=n_samples)
  noise = np.random.standard_normal((n_samples, latent_dim))
  return [noise, labels]

In [None]:
def generate_fake_samples(g_model,latent_dim=100,n_classes=10,n_samples=64):
    """Generate fake images together with the labels they were conditioned on.

    Args:
      g_model: generator taking [noise, label] inputs.
      latent_dim: length of each noise vector.
      n_classes: labels are drawn at random from range(n_classes).
      n_samples: batch size.

    Returns:
      ([images, labels], targets) where targets is an (n_samples, 1) array of
      zeros, i.e. the "fake" target for the discriminator.
    """
    noise,labels=generate_latent_dim(latent_dim,n_classes,n_samples)
    # verbose=0: predict() otherwise prints a progress bar on every call,
    # which here means once per training step.
    images=g_model.predict([noise,labels],verbose=0)
    return [images,labels],np.zeros((n_samples,1))

In [None]:
def gan(g_model,d_model):
  """Stack the generator and a frozen discriminator into the generator-training model.

  The class-label input is shared: it conditions the generator and is also fed
  to the discriminator alongside the generated image.

  Args:
    g_model: generator with inputs [noise, label].
    d_model: discriminator with inputs [image, label]. It is set
      non-trainable here so updates through this model touch only g_model.

  Returns:
    A Model mapping [noise, label] -> d_model's output, compiled with binary
    cross-entropy and Adam(learning_rate=2e-4, beta_1=0.5).
  """
  # Freeze D inside the combined model. NOTE(review): this relies on d_model
  # having been compiled earlier while still trainable (tf.keras snapshots
  # trainability at compile time); confirm D still learns under Keras 3.
  d_model.trainable=False
  latent,classes=g_model.input
  g_output=g_model.output

  d=d_model([g_output,classes])
  model=Model([latent,classes],d)
  # `learning_rate`, not `lr`: newer Keras ignores or rejects `lr`.
  opt=Adam(learning_rate=0.0002,beta_1=0.5)
  model.compile(loss='binary_crossentropy',optimizer=opt)
  return model

In [None]:
def train(gan_model,g_model,d_model,dataset,latent_dim=100,n_epochs=250,n_batch=128,n_classes=10,plot_every=10):
  """Alternate discriminator and generator updates for n_epochs epochs.

  Each step:
    1. Update d_model on half a batch of real images (target 1).
    2. Update d_model on half a batch of generated images (target 0).
    3. Update the generator through gan_model on a full batch of noise,
       using target 1.

  Args:
    gan_model, g_model, d_model: combined model, generator, discriminator.
    dataset: training images (labels are looked up by generate_real_samples).
    latent_dim: noise vector length.
    n_epochs: number of passes; each has dataset.shape[0] // n_batch steps.
    n_batch: full batch size (real/fake discriminator batches are half this).
    n_classes: number of class labels sampled for fake and generator batches.
    plot_every: positive int; save a sample grid and the generator every
      `plot_every` epochs and always after the final epoch.
  """
  bat_per_epo=dataset.shape[0]//n_batch
  half_batch=n_batch//2

  for i in range(n_epochs):
    for j in range(bat_per_epo):
      x_real,y_real=generate_real_samples(dataset,half_batch)
      d_loss1,_=d_model.train_on_batch(x_real,y_real)
      x_fake,y_fake=generate_fake_samples(g_model,latent_dim,n_classes,half_batch)
      d_loss2,_=d_model.train_on_batch(x_fake,y_fake)

      l=generate_latent_dim(latent_dim,n_classes,n_batch)
      y_gan=np.ones((n_batch,1))
      g_loss=gan_model.train_on_batch(l,y_gan)

      print(f'epoch:{i+1}/{n_epochs},batch:{j+1}/{bat_per_epo},d_loss1:{d_loss1},d_loss2:{d_loss2},g_loss:{g_loss}')
    # Keyed on the epoch index: the original tested `j % 10`, i.e. the last
    # batch index of the epoch, which is almost never a multiple of 10, and
    # raised NameError when there were no batches.
    if (i+1)%plot_every==0 or i==n_epochs-1:
      generate_plot(g_model,latent_dim,i)


In [None]:
def generate_plot(g_model,latent_dim,epoch_no):
  """Save a 10x10 grid of generated images (one class per column) and the generator.

  Writes `plot_{epoch_no}.png` and `model_{epoch_no}.h5` to the working
  directory, then displays the figure.

  Args:
    g_model: generator taking [noise, label] inputs.
    latent_dim: noise vector length.
    epoch_no: tag used in the output filenames.
  """
  noise,_=generate_latent_dim(latent_dim,10,100)
  # Labels 0..9 repeated 10 times: grid cell k shows class k % 10, so each
  # column holds a single class.
  labels=np.tile(np.arange(10),10)

  images=g_model.predict([noise,labels],verbose=0)
  # The generator ends in tanh ([-1, 1]); imshow expects float RGB in [0, 1]
  # and would otherwise clip the negative half.
  images=(images+1)/2.0

  fig=plt.figure(figsize=(10,10))
  for k in range(100):
    ax=fig.add_subplot(10,10,k+1)
    # Call axis('off'); the original assigned plt.axis='off', which replaced
    # the pyplot function with a string.
    ax.axis('off')
    ax.imshow(images[k])
  g_model.save(f'model_{epoch_no}.h5')
  fig.savefig(f'plot_{epoch_no}.png')
  plt.show()
  # Release the figure; it is created once per call during training.
  plt.close(fig)

In [None]:
# Seed Python, NumPy and the TensorFlow backend (weight init, dropout, batch
# and noise sampling) so Restart & Run All reproduces the run as closely as
# TF's nondeterministic GPU kernels allow.
SEED = 42
keras.utils.set_random_seed(SEED)

g_model=generator()
d_model=discriminator()
gan_model=gan(g_model,d_model)
train(gan_model,g_model,d_model,x_tr)