In [1]:
!pip install tensorflow



In [2]:
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras.datasets import mnist
from keras.layers import Input, Dense, Flatten, BatchNormalization, Reshape, LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers.legacy import Adam

In [3]:
# MNIST digits: 28x28 pixels, one (grayscale) channel.
img_rows, img_cols = 28, 28
channels = 1
# Shared (height, width, channels) shape used by the generator and discriminator.
img_shape = (img_rows, img_cols, channels)

In [4]:
def build_generator():
  """Build the generator: a 100-d noise vector -> an image of shape `img_shape`.

  Three hidden stages of Dense -> LeakyReLU(alpha=0.2) ->
  BatchNormalization(momentum=0.8), with widths 256, 512 and 1024. They are
  followed by a tanh Dense layer of prod(img_shape) units, reshaped to
  `img_shape`, so outputs lie in [-1, 1].

  Prints the inner Sequential's summary and returns a functional Model that
  wraps it.
  """
  latent_shape = (100,)
  stack_layers = []
  for stage, width in enumerate((256, 512, 1024)):
    # Only the first Dense layer declares the input shape.
    dense_kwargs = {"input_shape": latent_shape} if stage == 0 else {}
    stack_layers += [
        Dense(width, **dense_kwargs),
        LeakyReLU(alpha=0.2),
        BatchNormalization(momentum=0.8),
    ]
  stack_layers += [Dense(np.prod(img_shape), activation="tanh"), Reshape(img_shape)]

  stack = Sequential(stack_layers)
  stack.summary()

  z_in = Input(shape=latent_shape)
  return Model(z_in, stack(z_in))

In [5]:
def build_discriminator():
  """Build the discriminator: an image of shape `img_shape` -> one sigmoid score.

  Flatten -> Dense(512) -> LeakyReLU(alpha=0.2) -> Dense(256) ->
  LeakyReLU(alpha=0.2) -> Dense(1, sigmoid). A score near 1 means "real"
  and near 0 means "generated", matching the labels used in `train`.

  Prints the inner Sequential's summary and returns a functional Model that
  wraps it.
  """
  classifier = Sequential()
  classifier.add(Flatten(input_shape=img_shape))
  for width in (512, 256):
    classifier.add(Dense(width))
    classifier.add(LeakyReLU(alpha=0.2))
  classifier.add(Dense(1, activation="sigmoid"))
  classifier.summary()

  image_in = Input(shape=img_shape)
  return Model(image_in, classifier(image_in))

In [6]:
def train(epochs, batch_size=128, save_interval=50):
  """Train the GAN on MNIST for `epochs` iterations (one random batch each).

  Each iteration:
    1. trains `discriminator` on half a batch of real images (label 1) and
       half a batch of generated images (label 0);
    2. trains `combined` (generator -> frozen discriminator) on a full batch
       of noise labelled 1, so the generator learns to fool the discriminator.

  Relies on the module-level `generator`, `discriminator`, `combined` and
  `save_imgs` defined in other cells.

  Args:
    epochs: number of training iterations. Each draws one random batch, so
      this is not a full pass over the dataset.
    batch_size: samples per generator update. The discriminator sees
      batch_size // 2 real and batch_size // 2 generated images per iteration.
    save_interval: call save_imgs(epoch) every this many iterations;
      0 (or None) disables saving.
  """
  (xtrain, _), (_, _) = mnist.load_data()
  # Rescale uint8 pixels from [0, 255] to [-1, 1] so real images span the same
  # range as the generator's tanh output. 127.5 is the midpoint of [0, 255];
  # the previous offset of 125.5 gave roughly [-0.98, 1.02] instead.
  xtrain = (xtrain.astype(np.float32) - 127.5) / 127.5
  # Add a trailing channel axis to match img_shape: (N, rows, cols) -> (N, rows, cols, 1).
  xtrain = np.expand_dims(xtrain, axis=3)

  half_batch = batch_size // 2
  # Label arrays are constant, so build them once instead of every iteration.
  real_labels = np.ones((half_batch, 1))
  fake_labels = np.zeros((half_batch, 1))
  # The generator step asks the discriminator to call generated images "real".
  valid_y = np.ones((batch_size, 1))

  for epoch in range(epochs):
    # --- Discriminator step: half a batch real, half a batch generated. ---
    idx = np.random.randint(0, xtrain.shape[0], half_batch)
    imgs = xtrain[idx]

    noise = np.random.normal(0, 1, (half_batch, 100))
    # verbose=0 stops predict from printing a progress bar on every iteration.
    gen_imgs = generator.predict(noise, verbose=0)

    d_loss_real = discriminator.train_on_batch(imgs, real_labels)
    d_loss_fake = discriminator.train_on_batch(gen_imgs, fake_labels)
    # Average [loss, accuracy] over the real and generated halves.
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # --- Generator step through the combined model (discriminator frozen). ---
    noise = np.random.normal(0, 1, (batch_size, 100))
    g_loss = combined.train_on_batch(noise, valid_y)

    print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))
    if save_interval and epoch % save_interval == 0:
      save_imgs(epoch)

In [7]:
def save_imgs(epoch, rows=5, cols=5, out_dir="images"):
  """Save a rows x cols grid of generator samples to <out_dir>/mnist_<epoch>.png.

  Uses the module-level `generator` (defined in another cell). Fresh N(0, 1)
  noise is drawn on every call, so successive grids show different samples.

  Args:
    epoch: training iteration number, used only in the output filename.
    rows, cols: grid dimensions. The defaults reproduce the original 5x5 grid.
    out_dir: directory to write into. It is created if missing.
  """
  noise = np.random.normal(0, 1, (rows * cols, 100))
  # verbose=0 keeps the training log free of progress bars.
  gen_imgs = generator.predict(noise, verbose=0)
  # Map the tanh output from [-1, 1] back to [0, 1] for display.
  gen_imgs = 0.5 * gen_imgs + 0.5

  # squeeze=False keeps `axs` 2-D even when rows or cols is 1.
  fig, axs = plt.subplots(rows, cols, squeeze=False)
  for ax, sample in zip(axs.flat, gen_imgs):
    # Fixed vmin/vmax puts every tile on the same [0, 1] gray scale instead of
    # contrast-stretching each one independently.
    ax.imshow(sample[:, :, 0], cmap='gray', vmin=0, vmax=1)
    ax.axis('off')

  # savefig does not create missing directories.
  os.makedirs(out_dir, exist_ok=True)
  fig.savefig(os.path.join(out_dir, "mnist_%d.png" % epoch))
  # Close this specific figure so repeated calls during training don't accumulate figures.
  plt.close(fig)

In [9]:
# Seed the Python, NumPy and TensorFlow RNGs so weight initialisation, batch
# sampling and noise draws repeat between runs.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# One Adam instance per compiled model, so the discriminator and the combined
# (generator-training) model do not share optimizer state such as the
# iteration counter.
optimizer = Adam(0.0002, 0.5)
combined_optimizer = Adam(0.0002, 0.5)

discriminator = build_discriminator()
# Compiled while still trainable, so discriminator.train_on_batch updates it.
discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

# The generator is trained only through `combined` and sampled with predict,
# so it does not need a compile of its own.
generator = build_generator()

# combined: noise z -> generator -> discriminator -> validity score.
z = Input(shape=(100,))
img = generator(z)
# Freeze the discriminator before compiling `combined`, so
# combined.train_on_batch updates only the generator's weights. The
# `discriminator` model compiled above keeps training normally.
discriminator.trainable = False
valid = discriminator(img)

combined = Model(z, valid)
combined.compile(loss='binary_crossentropy', optimizer=combined_optimizer)
train(epochs=10000, batch_size=128, save_interval=1000)

Streaming output truncated to the last 5000 lines.
7512 [D loss: 0.735109, acc.: 46.88%] [G loss: 0.783997]
7513 [D loss: 0.699760, acc.: 53.12%] [G loss: 0.783752]
7514 [D loss: 0.673280, acc.: 57.81%] [G loss: 0.757697]
7515 [D loss: 0.687739, acc.: 58.59%] [G loss: 0.801691]
7516 [D loss: 0.707265, acc.: 52.34%] [G loss: 0.803182]
7517 [D loss: 0.680578, acc.: 57.03%] [G loss: 0.783812]
7518 [D loss: 0.725715, acc.: 46.09%] [G loss: 0.802684]
7519 [D loss: 0.692355, acc.: 52.34%] [G loss: 0.794561]
7520 [D loss: 0.655166, acc.: 60.16%] [G loss: 0.819709]
7521 [D loss: 0.696829, acc.: 49.22%] [G loss: 0.793766]
7522 [D loss: 0.639459, acc.: 67.19%] [G loss: 0.813182]
7523 [D loss: 0.678741, acc.: 58.59%] [G loss: 0.795452]
7524 [D loss: 0.698741, acc.: 54.69%] [G loss: 0.783122]
7525 [D loss: 0.685418, acc.: 50.78%] [G loss: 0.791886]
7526 [D loss: 0.692750, acc.: 51.56%] [G loss: 0.817910]
7527 [D loss: 0.702983, acc.: 51.56%] [G loss: 0.821453]
7528 [D loss: 0.677291,

In [10]:
# Save the trained generator; the .h5 extension selects Keras's HDF5 format.
generator.save(filepath="generator_model.h5")

  saving_api.save_model(


In [37]:
# Recursively archive everything under /content/GAN_MNIST (sample grids in
# images/, the saved generator, and any .ipynb_checkpoints) into one zip,
# then trigger a browser download. google.colab.files works only in Colab.
!zip -r /content/GAN_MNIST.zip /content/GAN_MNIST
from google.colab import files
files.download("/content/GAN_MNIST.zip")

  adding: content/GAN_MNIST/ (stored 0%)
  adding: content/GAN_MNIST/images/ (stored 0%)
  adding: content/GAN_MNIST/images/mnist_4900.png (deflated 6%)
  adding: content/GAN_MNIST/images/mnist_8900.png (deflated 7%)
  adding: content/GAN_MNIST/images/mnist_0.png (deflated 7%)
  adding: content/GAN_MNIST/images/mnist_5900.png (deflated 7%)
  adding: content/GAN_MNIST/images/mnist_6900.png (deflated 7%)
  adding: content/GAN_MNIST/images/mnist_7900.png (deflated 7%)
  adding: content/GAN_MNIST/images/mnist_1900.png (deflated 6%)
  adding: content/GAN_MNIST/images/mnist_9900.png (deflated 7%)
  adding: content/GAN_MNIST/images/mnist_2900.png (deflated 6%)
  adding: content/GAN_MNIST/images/mnist_900.png (deflated 6%)
  adding: content/GAN_MNIST/images/mnist_3900.png (deflated 7%)
  adding: content/GAN_MNIST/generator_model.h5 (deflated 9%)
  adding: content/GAN_MNIST/.ipynb_checkpoints/ (stored 0%)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>