In [1]:
import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape,Flatten
from keras.layers import BatchNormalization, LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam

In [2]:
import ssl
import urllib.request

# Fetch the MNIST archive up front. Certificate verification is disabled ONLY for
# this download (presumably a workaround for a Python install without a usable CA
# bundle — NOTE(review): fixing the local certificates is preferable). The previous
# default HTTPS context is restored afterwards, so any later HTTPS request made in
# this kernel is verified again instead of silently skipping TLS checks.
# train_gan() calls mnist.load_data() again on its own; in the saved output that
# second call printed no download message, so this cell effectively primes the
# local copy and the X_train bound here is not used further.
_default_https_context = ssl._create_default_https_context
ssl._create_default_https_context = ssl._create_unverified_context
try:
    (X_train, _), (_, _) = mnist.load_data()
finally:
    ssl._create_default_https_context = _default_https_context

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [10]:
def build_generator(latent_dim=100, img_shape=(28, 28, 1)):
    """Build the MLP generator that maps a latent noise vector to an image.

    Architecture: three hidden Dense layers (256 -> 512 -> 1024 units), each
    followed by LeakyReLU(alpha=0.2) and BatchNormalization(momentum=0.8), then a
    tanh Dense layer with prod(img_shape) units reshaped to ``img_shape``. The tanh
    output lies in [-1, 1], the same range train_gan scales the real images to.

    Args:
        latent_dim: Length of the input noise vector. Must match the size of the
            noise fed to the model elsewhere (100 throughout this notebook).
        img_shape: Shape of each generated image as (height, width, channels).

    Returns:
        An uncompiled keras ``Sequential`` model.
    """
    model = Sequential()
    # Only the first layer declares the input size.
    model.add(Dense(256, input_dim=latent_dim))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    for units in (512, 1024):
        model.add(Dense(units))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
    # int() so Keras receives a plain Python int rather than a numpy scalar.
    model.add(Dense(int(np.prod(img_shape)), activation='tanh'))
    model.add(Reshape(tuple(img_shape)))
    return model

generator = build_generator()

In [11]:
def build_discriminator():
    """Return an uncompiled MLP that scores a 28x28x1 image as real (1) or fake (0).

    The image is flattened, passed through Dense(512) and Dense(256) layers with
    LeakyReLU(alpha=0.2) activations, and reduced to a single sigmoid unit.
    """
    layers = [
        Flatten(input_shape=(28, 28, 1)),
        Dense(512),
        LeakyReLU(alpha=0.2),
        Dense(256),
        LeakyReLU(alpha=0.2),
        Dense(1, activation='sigmoid'),
    ]
    return Sequential(layers)

discriminator = build_discriminator()
# Compiled here, while still trainable, so discriminator.train_on_batch updates it.
# Adam(0.0002, 0.5) -> learning rate 2e-4, beta_1 0.5 (positional arguments).
discriminator.compile(
    optimizer=Adam(0.0002, 0.5),
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

In [14]:
# Length of the noise vector fed to the generator; must match build_generator's input size.
LATENT_DIM = 100

# Freeze the discriminator inside the combined model so that gan.train_on_batch only
# updates the generator. The standalone `discriminator` was compiled in the previous
# cell while still trainable; this classic pattern relies on Keras using the
# `trainable` flag as it was at each model's compile time.
# NOTE(review): confirm this holds for the installed Keras version.
discriminator.trainable = False

gan_input = Input(shape=(LATENT_DIM,))
generated_image = generator(gan_input)
gan_output = discriminator(generated_image)

gan = Model(gan_input, gan_output)
gan.compile(optimizer=Adam(0.0002, 0.5), loss='binary_crossentropy')


def train_gan(epochs, batch_size=128, sample_interval=100):
    """Train the GAN with alternating discriminator / generator minibatch updates.

    Despite its name, ``epochs`` counts training *steps*. Each step samples one
    random minibatch of real images rather than making a full pass over the dataset.

    Per step:
      1. The discriminator is trained on ``batch_size`` real images (label 1) and
         ``batch_size`` generated images (label 0).
      2. The generator is trained through ``gan`` on fresh noise labelled 1.

    Every ``sample_interval`` steps (including step 0) the losses are printed and
    ``save_images(step)`` is called. ``save_images`` is defined in a later cell,
    which must be executed before this function is called.

    Args:
        epochs: Number of training steps (one minibatch each).
        batch_size: Number of real and of generated images per discriminator update.
        sample_interval: Positive step interval between progress logs / image dumps.

    Raises:
        ValueError: If ``sample_interval`` is smaller than 1.
    """
    if sample_interval < 1:
        raise ValueError(f"sample_interval must be >= 1, got {sample_interval}")

    # load_data() returns ((x_train, y_train), (x_test, y_test)); only the training images are used.
    (X_train, _), (_, _) = mnist.load_data()
    # Scale pixels from [0, 255] to [-1, 1], the range of the generator's tanh output.
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    # Add a trailing channel axis so images match the (28, 28, 1) discriminator input.
    X_train = np.expand_dims(X_train, axis=3)

    real = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for step in range(epochs):
        # --- Discriminator update ---
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        real_images = X_train[idx]

        noise = np.random.normal(0, 1, (batch_size, LATENT_DIM))
        # verbose=0: avoid printing a progress bar on every training step.
        generated_images = generator.predict(noise, verbose=0)

        d_loss_real = discriminator.train_on_batch(real_images, real)
        d_loss_fake = discriminator.train_on_batch(generated_images, fake)
        # Average the [loss, accuracy] pairs of the real and fake halves.
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # --- Generator update ---
        # Fresh noise labelled "real": the frozen discriminator's gradient pushes the
        # generator toward outputs the discriminator classifies as real.
        noise = np.random.normal(0, 1, (batch_size, LATENT_DIM))
        g_loss = gan.train_on_batch(noise, real)

        if step % sample_interval == 0:
            print(f"{step}[D loss: {d_loss[0]}, acc.: {100*d_loss[1]}][G loss:{g_loss}]")
            save_images(step)

In [17]:
def save_images(epoch, rows=5, cols=5, latent_dim=100):
    """Sample images from the global ``generator`` and save them as a grid.

    Draws ``rows * cols`` standard-normal noise vectors and runs them through
    ``generator``. The results are plotted in a ``rows`` x ``cols`` grayscale grid,
    which is written to ``gan_images_{epoch}.png`` in the current working directory.

    Args:
        epoch: Label used in the output filename (train_gan passes its step index).
        rows: Number of grid rows.
        cols: Number of grid columns.
        latent_dim: Length of each noise vector; must match the generator's input size.
    """
    noise = np.random.normal(0, 1, (rows * cols, latent_dim))
    # verbose=0: this runs inside the training loop; keep the log free of progress bars.
    generated_images = generator.predict(noise, verbose=0)

    # The generator ends in tanh ([-1, 1]); shift/scale to [0, 1] for imshow.
    generated_images = 0.5 * generated_images + 0.5

    # squeeze=False keeps `axs` 2-D even for a single row or column.
    fig, axs = plt.subplots(rows, cols, squeeze=False)
    # axs.flat walks the grid row by row, matching the order of the samples.
    for ax, image in zip(axs.flat, generated_images):
        ax.imshow(image[:, :, 0], cmap='gray')
        ax.axis('off')
    fig.savefig(f"gan_images_{epoch}.png")
    # Close this specific figure so repeated calls do not accumulate open figures.
    plt.close(fig)

train_gan(epochs = 10000, batch_size = 64)

0[D loss: 0.3698130501434207, acc.: 70.3125][G loss:0.6116334795951843]
100[D loss: 0.007253856165334582, acc.: 100.0][G loss:4.898748397827148]


200[D loss: 0.030777394771575928, acc.: 100.0][G loss:4.302901268005371]
300[D loss: 0.3130642995238304, acc.: 84.375][G loss:2.96044921875]


400[D loss: 0.7039513885974884, acc.: 42.96875][G loss:0.6762176752090454]


500[D loss: 0.6486066579818726, acc.: 50.0][G loss:0.6912078857421875]
600[D loss: 0.647762343287468, acc.: 54.6875][G loss:0.7309393286705017]


700[D loss: 0.6024916172027588, acc.: 67.96875][G loss:0.8005037307739258]


800[D loss: 0.6430397033691406, acc.: 60.15625][G loss:0.8414516448974609]
900[D loss: 0.6307241320610046, acc.: 66.40625][G loss:0.8505131006240845]


1000[D loss: 0.5874053835868835, acc.: 75.78125][G loss:0.909314751625061]


1100[D loss: 0.5564882159233093, acc.: 81.25][G loss:0.8835906386375427]
1200[D loss: 0.5616680383682251, acc.: 77.34375][G loss:1.0241050720214844]


1300[D loss: 0.5819709002971649, acc.: 72.65625][G loss:0.9526621103286743]
1400[D loss: 0.5449346303939819, acc.: 78.90625][G loss:1.0941789150238037]


1500[D loss: 0.5217130184173584, acc.: 78.90625][G loss:1.1611086130142212]


1600[D loss: 0.549655944108963, acc.: 76.5625][G loss:1.1511993408203125]
1700[D loss: 0.5764940977096558, acc.: 71.875][G loss:1.081390142440796]


1800[D loss: 0.5045652240514755, acc.: 77.34375][G loss:1.139988899230957]


1900[D loss: 0.4988410621881485, acc.: 83.59375][G loss:1.082879900932312]
2000[D loss: 0.5504713356494904, acc.: 74.21875][G loss:1.062625527381897]


2100[D loss: 0.5804810225963593, acc.: 68.75][G loss:1.0568459033966064]


2200[D loss: 0.6715314090251923, acc.: 59.375][G loss:0.9845061302185059]
2300[D loss: 0.5118896812200546, acc.: 82.8125][G loss:1.0119329690933228]


2400[D loss: 0.5805966854095459, acc.: 71.875][G loss:0.9624971151351929]
2500[D loss: 0.6226929426193237, acc.: 66.40625][G loss:0.8872741460800171]


2600[D loss: 0.587517499923706, acc.: 71.875][G loss:1.0514190196990967]


2700[D loss: 0.6097998023033142, acc.: 66.40625][G loss:1.016707420349121]
2800[D loss: 0.6157359480857849, acc.: 68.75][G loss:0.9989440441131592]


2900[D loss: 0.5749314427375793, acc.: 73.4375][G loss:0.9730767607688904]


3000[D loss: 0.5538907051086426, acc.: 79.6875][G loss:0.9493328332901001]
3100[D loss: 0.5548747479915619, acc.: 73.4375][G loss:1.0710453987121582]


3200[D loss: 0.5813236832618713, acc.: 71.875][G loss:1.054861068725586]


3300[D loss: 0.5689790546894073, acc.: 75.0][G loss:0.9908589124679565]
3400[D loss: 0.6021322906017303, acc.: 67.1875][G loss:0.9833284616470337]


3500[D loss: 0.6497190296649933, acc.: 60.15625][G loss:0.968590497970581]


3600[D loss: 0.6547919809818268, acc.: 62.5][G loss:1.0208160877227783]
3700[D loss: 0.5947176814079285, acc.: 74.21875][G loss:0.9444044828414917]


3800[D loss: 0.6172421276569366, acc.: 67.1875][G loss:1.0121361017227173]
3900[D loss: 0.5697589218616486, acc.: 71.875][G loss:0.9758402705192566]


4000[D loss: 0.6191858053207397, acc.: 65.625][G loss:0.8998044729232788]


4100[D loss: 0.6563844084739685, acc.: 60.9375][G loss:0.9674139022827148]
4200[D loss: 0.6080316305160522, acc.: 71.875][G loss:1.0173466205596924]


4300[D loss: 0.6683732271194458, acc.: 61.71875][G loss:0.9230300188064575]


4400[D loss: 0.6392416656017303, acc.: 66.40625][G loss:0.9560965299606323]
4500[D loss: 0.6291216611862183, acc.: 66.40625][G loss:0.9513747692108154]


4600[D loss: 0.6412449479103088, acc.: 62.5][G loss:0.970248818397522]


4700[D loss: 0.6107291281223297, acc.: 65.625][G loss:0.9341418147087097]
4800[D loss: 0.5881426334381104, acc.: 73.4375][G loss:0.9108672142028809]


4900[D loss: 0.6705198884010315, acc.: 56.25][G loss:0.8928346633911133]
5000[D loss: 0.6549724042415619, acc.: 55.46875][G loss:0.9204555749893188]


5100[D loss: 0.6792510747909546, acc.: 58.59375][G loss:0.9792760014533997]


5200[D loss: 0.6238093972206116, acc.: 64.84375][G loss:0.960403323173523]
5300[D loss: 0.6459329128265381, acc.: 60.9375][G loss:0.8314372301101685]


5400[D loss: 0.6195452809333801, acc.: 66.40625][G loss:0.9325863122940063]


5500[D loss: 0.6870007514953613, acc.: 56.25][G loss:0.9244795441627502]
5600[D loss: 0.6772038042545319, acc.: 55.46875][G loss:0.9176824688911438]


5700[D loss: 0.6364477574825287, acc.: 64.0625][G loss:0.9074916243553162]


5800[D loss: 0.674837738275528, acc.: 59.375][G loss:0.853234052658081]
5900[D loss: 0.6212098598480225, acc.: 67.1875][G loss:0.9566311836242676]


6000[D loss: 0.6515088379383087, acc.: 60.15625][G loss:0.9014945030212402]


6100[D loss: 0.6300647854804993, acc.: 62.5][G loss:0.9562975168228149]
6200[D loss: 0.6373026371002197, acc.: 65.625][G loss:0.8679916858673096]


6300[D loss: 0.6571722030639648, acc.: 59.375][G loss:0.935126781463623]
6400[D loss: 0.668074905872345, acc.: 57.03125][G loss:0.9273802042007446]


6500[D loss: 0.6394788324832916, acc.: 71.09375][G loss:0.937286376953125]


6600[D loss: 0.6748400330543518, acc.: 64.84375][G loss:0.807942807674408]
6700[D loss: 0.6588306725025177, acc.: 60.9375][G loss:0.8490819334983826]


6800[D loss: 0.6396415531635284, acc.: 60.9375][G loss:0.9534904956817627]


6900[D loss: 0.6746442317962646, acc.: 59.375][G loss:0.9382304549217224]
7000[D loss: 0.6629626154899597, acc.: 58.59375][G loss:0.9385400414466858]


7100[D loss: 0.6363238990306854, acc.: 61.71875][G loss:0.951632559299469]


7200[D loss: 0.6789445281028748, acc.: 54.6875][G loss:0.900259256362915]
7300[D loss: 0.6463287770748138, acc.: 62.5][G loss:0.8362614512443542]


7400[D loss: 0.6203669011592865, acc.: 64.84375][G loss:0.8995876312255859]
7500[D loss: 0.6647595763206482, acc.: 59.375][G loss:0.8676342964172363]


7600[D loss: 0.6178926825523376, acc.: 71.09375][G loss:0.9397890567779541]


7700[D loss: 0.6280323266983032, acc.: 64.0625][G loss:0.8906768560409546]
7800[D loss: 0.6597005724906921, acc.: 64.0625][G loss:0.8893228769302368]


7900[D loss: 0.6762827932834625, acc.: 57.03125][G loss:0.8996420502662659]


8000[D loss: 0.6597426235675812, acc.: 60.15625][G loss:0.9629033207893372]
8100[D loss: 0.6580213606357574, acc.: 61.71875][G loss:0.9165663719177246]


8200[D loss: 0.6232990026473999, acc.: 67.96875][G loss:0.9128062725067139]


8300[D loss: 0.6448296904563904, acc.: 64.84375][G loss:0.9450543522834778]
8400[D loss: 0.6594042181968689, acc.: 61.71875][G loss:0.8954659700393677]


8500[D loss: 0.6645207703113556, acc.: 60.15625][G loss:0.9039722681045532]


8600[D loss: 0.6686364114284515, acc.: 54.6875][G loss:0.8804199695587158]
8700[D loss: 0.6284319758415222, acc.: 65.625][G loss:0.8834078311920166]


8800[D loss: 0.6747044026851654, acc.: 62.5][G loss:0.8937047719955444]
8900[D loss: 0.6656822562217712, acc.: 60.15625][G loss:0.9422429800033569]


9000[D loss: 0.6534618735313416, acc.: 60.9375][G loss:0.8663043975830078]


9100[D loss: 0.6979722678661346, acc.: 55.46875][G loss:0.8766571283340454]
9200[D loss: 0.6137778759002686, acc.: 69.53125][G loss:0.9178847670555115]


9300[D loss: 0.6519144177436829, acc.: 61.71875][G loss:0.9281726479530334]


9400[D loss: 0.7019370198249817, acc.: 58.59375][G loss:0.9015482664108276]
9500[D loss: 0.6354238986968994, acc.: 64.84375][G loss:0.8751258850097656]


9600[D loss: 0.659547746181488, acc.: 59.375][G loss:0.9287635087966919]


9700[D loss: 0.6419765651226044, acc.: 62.5][G loss:0.9442354440689087]
9800[D loss: 0.6232584118843079, acc.: 67.96875][G loss:0.8760768175125122]


9900[D loss: 0.6482604444026947, acc.: 60.9375][G loss:0.882538914680481]
