In [0]:
# https://keras.io/
!pip install -q keras
import keras

In [None]:
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers.core import Activation
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import UpSampling2D
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers.core import Flatten
from keras.optimizers import SGD
from keras.datasets import mnist
import numpy as np
from PIL import Image
import argparse
import math


def generator_model(latent_dim=100):
    """Build the DCGAN generator: noise vector -> single-channel 28x28 image.

    Architecture: Dense(1024, tanh) -> Dense(7*7*128) + BatchNorm + tanh ->
    reshape to a 7x7x128 feature map -> two rounds of (2x UpSampling2D,
    5x5 Conv2D), ending in one channel with a tanh activation, so output
    pixels lie in [-1, 1].

    Args:
        latent_dim: length of the input noise vector. Defaults to 100, the
            noise size drawn by ``train`` and ``generate``.

    Returns:
        An uncompiled ``keras.models.Sequential`` model.
    """
    model = Sequential()
    # ``units`` is the first positional argument in Keras 2; the Keras 1
    # ``output_dim`` keyword is only accepted through a deprecation shim.
    model.add(Dense(1024, input_dim=latent_dim))
    model.add(Activation('tanh'))
    model.add(Dense(128 * 7 * 7))
    model.add(BatchNormalization())
    model.add(Activation('tanh'))
    # ``input_shape`` only matters on the first layer of a Sequential model,
    # so it is not repeated here. 7x7 doubled twice by UpSampling2D -> 28x28.
    model.add(Reshape((7, 7, 128)))
    model.add(UpSampling2D(size=(2, 2)))
    model.add(Conv2D(64, (5, 5), padding='same'))
    model.add(Activation('tanh'))
    model.add(UpSampling2D(size=(2, 2)))
    model.add(Conv2D(1, (5, 5), padding='same'))
    model.add(Activation('tanh'))
    return model


def discriminator_model():
    """Build the DCGAN discriminator: 28x28x1 image -> score in (0, 1).

    Two convolutional stages (5x5 Conv2D with 64 then 128 filters, each
    followed by tanh and 2x2 max-pooling), then a 1024-unit tanh dense layer
    and a single sigmoid output unit.

    Returns:
        An uncompiled ``keras.models.Sequential`` model.
    """
    stack = [
        # First layer carries the input shape; 'same' padding keeps 28x28.
        Conv2D(64, (5, 5), padding='same', input_shape=(28, 28, 1)),
        Activation('tanh'),
        MaxPooling2D(pool_size=(2, 2)),
        # Second conv uses Keras' default padding (unlike the first one).
        Conv2D(128, (5, 5)),
        Activation('tanh'),
        MaxPooling2D(pool_size=(2, 2)),
        Flatten(),
        Dense(1024),
        Activation('tanh'),
        Dense(1),
        Activation('sigmoid'),
    ]
    net = Sequential()
    for layer in stack:
        net.add(layer)
    return net


def generator_containing_discriminator(g, d):
    """Stack generator ``g`` and discriminator ``d`` into a single model.

    ``d.trainable`` is set to False before the stacked model is built, so
    that a later ``compile`` of the returned model collects only the
    generator's weights as trainable. The flag is left False on ``d``; a
    caller that wants to train ``d`` on its own must set it back to True
    before compiling ``d``.

    Returns:
        An uncompiled ``keras.models.Sequential`` model: noise -> g -> d.
    """
    d.trainable = False
    return Sequential([g, d])


def combine_images(generated_images):
    """Tile a batch of images into one 2-D grid image.

    The grid is ``int(sqrt(n))`` images wide and as many rows tall as needed
    to hold all ``n`` images, filled row by row. Unused trailing cells stay
    zero. Only the first channel of each image is used.

    Args:
        generated_images: array-like of shape ``(n, h, w, c)`` or
            ``(n, h, w)`` with ``n >= 1``.

    Returns:
        numpy array of shape ``(rows * h, cols * w)`` with the input's dtype.

    Raises:
        ValueError: if the input is not 3-D/4-D or contains no images.
    """
    images = np.asarray(generated_images)
    if images.ndim == 3:
        # Accept channel-less input by adding a trailing channel axis.
        images = images[:, :, :, None]
    if images.ndim != 4:
        raise ValueError(
            "combine_images expects (n, h, w[, c]) input, got shape %r"
            % (images.shape,))
    num = images.shape[0]
    if num == 0:
        # Without this, int(sqrt(0)) == 0 leads to a ZeroDivisionError below.
        raise ValueError("combine_images needs at least one image")
    cols = int(math.sqrt(num))
    rows = int(math.ceil(float(num) / cols))
    h, w = images.shape[1:3]
    grid = np.zeros((rows * h, cols * w), dtype=images.dtype)
    for index, img in enumerate(images):
        r, c = divmod(index, cols)
        grid[r * h:(r + 1) * h, c * w:(c + 1) * w] = img[:, :, 0]
    return grid


def train(BATCH_SIZE, epochs=100, sample_every=20, checkpoint_every=10):
    """Adversarially train the DCGAN generator and discriminator on MNIST.

    For every batch, the discriminator is first updated on ``BATCH_SIZE``
    real images (label 1) plus ``BATCH_SIZE`` generated images (label 0).
    The generator is then updated through the stacked generator+discriminator
    model with all-ones labels, i.e. pushed to make its samples score as real.
    Losses are printed for every batch.

    Args:
        BATCH_SIZE: number of real (and of generated) images per step; must
            be positive. A trailing partial batch is skipped.
        epochs: number of passes over the training set (default 100).
        sample_every: every this many batches, save a PNG grid of the current
            generated batch as ``"<epoch>_<index>.png"`` (default 20).
        checkpoint_every: after every this many batches, overwrite the weight
            files ``'generator'`` and ``'discriminator'`` (default 10).

    Raises:
        ValueError: if ``BATCH_SIZE`` is not positive.
    """
    if BATCH_SIZE <= 0:
        raise ValueError("BATCH_SIZE must be positive, got %r" % (BATCH_SIZE,))
    (X_train, _), _ = mnist.load_data()
    # Scale pixels from [0, 255] to [-1, 1], the range of the generator's
    # tanh output, so real and fake images are comparable.
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    X_train = X_train[:, :, :, None]  # trailing channel axis for Conv2D
    d = discriminator_model()
    g = generator_model()
    d_on_g = generator_containing_discriminator(g, d)  # sets d.trainable = False
    d_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)
    g_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)
    # g itself is only used for predict() below; this compile is kept as-is.
    g.compile(loss='binary_crossentropy', optimizer="SGD")
    # Keras collects a model's trainable weights when it is compiled:
    # d_on_g is compiled while d is frozen (generator-only updates), then d
    # is unfrozen and compiled on its own.
    d_on_g.compile(loss='binary_crossentropy', optimizer=g_optim)
    d.trainable = True
    d.compile(loss='binary_crossentropy', optimizer=d_optim)

    num_batches = X_train.shape[0] // BATCH_SIZE
    # Targets are identical on every step, so build them once as arrays.
    y_discriminator = np.concatenate(
        [np.ones(BATCH_SIZE), np.zeros(BATCH_SIZE)]).astype(np.float32)
    y_generator = np.ones(BATCH_SIZE, dtype=np.float32)

    for epoch in range(epochs):
        print("Epoch is", epoch)
        print("Number of batches", num_batches)
        for index in range(num_batches):
            noise = np.random.uniform(-1, 1, size=(BATCH_SIZE, 100))
            image_batch = X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]
            generated_images = g.predict(noise, verbose=0)
            if index % sample_every == 0:
                image = combine_images(generated_images)
                image = image * 127.5 + 127.5  # back to [0, 255]
                Image.fromarray(image.astype(np.uint8)).save(
                    str(epoch) + "_" + str(index) + ".png")
            # Real images first (label 1), then generated ones (label 0).
            X = np.concatenate((image_batch, generated_images))
            d_loss = d.train_on_batch(X, y_discriminator)
            print("batch %d d_loss : %f" % (index, d_loss))
            noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
            # Keep d.trainable consistent with how each model was compiled.
            d.trainable = False
            g_loss = d_on_g.train_on_batch(noise, y_generator)
            d.trainable = True
            print("batch %d g_loss : %f" % (index, g_loss))
            if index % checkpoint_every == checkpoint_every - 1:
                g.save_weights('generator', True)
                d.save_weights('discriminator', True)


def generate(BATCH_SIZE, nice=False, pool_factor=20):
    """Sample digits from saved weights and write them as one PNG grid.

    Loads generator weights from the file ``'generator'``. With
    ``nice=False`` it draws ``BATCH_SIZE`` samples directly. With
    ``nice=True`` it also loads ``'discriminator'``, draws
    ``BATCH_SIZE * pool_factor`` candidates, and keeps the ``BATCH_SIZE``
    candidates with the highest discriminator score. Ties keep generation
    order. The grid is written to ``"generated_image.png"``.

    Args:
        BATCH_SIZE: number of images in the output grid.
        nice: if True, filter candidates through the discriminator.
        pool_factor: candidates drawn per kept image when ``nice`` is True
            (default 20); must be at least 1.

    Raises:
        ValueError: if ``nice`` is True and ``pool_factor < 1``.
    """
    g = generator_model()
    g.compile(loss='binary_crossentropy', optimizer="SGD")
    g.load_weights('generator')
    if nice:
        if pool_factor < 1:
            raise ValueError("pool_factor must be >= 1, got %r" % (pool_factor,))
        d = discriminator_model()
        d.compile(loss='binary_crossentropy', optimizer="SGD")
        d.load_weights('discriminator')
        noise = np.random.uniform(-1, 1, (BATCH_SIZE * pool_factor, 100))
        generated_images = g.predict(noise, verbose=1)
        d_pret = d.predict(generated_images, verbose=1)
        # Descending by score; mergesort is stable, so equal scores keep their
        # generation order (same as a stable list.sort(reverse=True)).
        top = np.argsort(-d_pret[:, 0], kind='mergesort')[:BATCH_SIZE]
        # Keep only channel 0, as float32, in the selected order.
        nice_images = generated_images[top, :, :, :1].astype(np.float32)
        image = combine_images(nice_images)
    else:
        noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
        generated_images = g.predict(noise, verbose=1)
        image = combine_images(generated_images)
    image = image * 127.5 + 127.5  # map tanh range [-1, 1] back to [0, 255]
    Image.fromarray(image.astype(np.uint8)).save(
        "generated_image.png")

Using TensorFlow backend.


In [None]:
# Run the GAN training loop: each discriminator update sees 32 real MNIST
# images plus 32 generator samples.
train(BATCH_SIZE=32)



Epoch is 0
Number of batches 1875
batch 0 d_loss : 0.717485
batch 0 g_loss : 0.719484
batch 1 d_loss : 0.698724
batch 1 g_loss : 0.702067
batch 2 d_loss : 0.693461
batch 2 g_loss : 0.684042
batch 3 d_loss : 0.678437
batch 3 g_loss : 0.677454
batch 4 d_loss : 0.653074
batch 4 g_loss : 0.667103
batch 5 d_loss : 0.636978
batch 5 g_loss : 0.663174
batch 6 d_loss : 0.616316
batch 6 g_loss : 0.646046
batch 7 d_loss : 0.601052
batch 7 g_loss : 0.627764
batch 8 d_loss : 0.567922
batch 8 g_loss : 0.625304
batch 9 d_loss : 0.556588
batch 9 g_loss : 0.603251
batch 10 d_loss : 0.537633
batch 10 g_loss : 0.586539
batch 11 d_loss : 0.525940
batch 11 g_loss : 0.572845
batch 12 d_loss : 0.505949
batch 12 g_loss : 0.564898
batch 13 d_loss : 0.499069
batch 13 g_loss : 0.565922
batch 14 d_loss : 0.494954
batch 14 g_loss : 0.553155
batch 15 d_loss : 0.480873
batch 15 g_loss : 0.547601
batch 16 d_loss : 0.480038
batch 16 g_loss : 0.531615
batch 17 d_loss : 0.466359
batch 17 g_loss : 0.534420
batch 18 d_los

batch 153 g_loss : 0.771107
batch 154 d_loss : 0.284078
batch 154 g_loss : 0.768545
batch 155 d_loss : 0.291911
batch 155 g_loss : 0.779413
batch 156 d_loss : 0.292563
batch 156 g_loss : 0.805956
batch 157 d_loss : 0.245907
batch 157 g_loss : 0.760745
batch 158 d_loss : 0.290417
batch 158 g_loss : 0.755286
batch 159 d_loss : 0.255706
batch 159 g_loss : 0.833674
batch 160 d_loss : 0.268355
batch 160 g_loss : 0.815905
batch 161 d_loss : 0.256403
batch 161 g_loss : 0.839925
batch 162 d_loss : 0.286654
batch 162 g_loss : 0.882778
batch 163 d_loss : 0.258892
batch 163 g_loss : 0.895843
batch 164 d_loss : 0.238373
batch 164 g_loss : 0.914268
batch 165 d_loss : 0.286694
batch 165 g_loss : 0.907095
batch 166 d_loss : 0.243652
batch 166 g_loss : 0.995873
batch 167 d_loss : 0.248906
batch 167 g_loss : 0.976548
batch 168 d_loss : 0.249957
batch 168 g_loss : 1.026001
batch 169 d_loss : 0.243866
batch 169 g_loss : 1.047880
batch 170 d_loss : 0.220836
batch 170 g_loss : 1.069216
batch 171 d_loss : 0

batch 300 g_loss : 1.088796
batch 301 d_loss : 0.567398
batch 301 g_loss : 0.928250
batch 302 d_loss : 0.545493
batch 302 g_loss : 1.002749
batch 303 d_loss : 0.570665
batch 303 g_loss : 1.087338
batch 304 d_loss : 0.503118
batch 304 g_loss : 1.053275
batch 305 d_loss : 0.516800
batch 305 g_loss : 1.044044
batch 306 d_loss : 0.594264
batch 306 g_loss : 1.063753
batch 307 d_loss : 0.519214
batch 307 g_loss : 1.134147
batch 308 d_loss : 0.530934
batch 308 g_loss : 1.121230
batch 309 d_loss : 0.581734
batch 309 g_loss : 1.097302
batch 310 d_loss : 0.535517
batch 310 g_loss : 1.059381
batch 311 d_loss : 0.567819
batch 311 g_loss : 0.995289
batch 312 d_loss : 0.518348
batch 312 g_loss : 1.003891
batch 313 d_loss : 0.464727
batch 313 g_loss : 1.070916
batch 314 d_loss : 0.545015
batch 314 g_loss : 1.060147
batch 315 d_loss : 0.478084
batch 315 g_loss : 1.168765
batch 316 d_loss : 0.581226
batch 316 g_loss : 1.117071
batch 317 d_loss : 0.578682
batch 317 g_loss : 1.232405
batch 318 d_loss : 0

batch 450 d_loss : 0.535316
batch 450 g_loss : 1.008352
batch 451 d_loss : 0.515704
batch 451 g_loss : 1.105096
batch 452 d_loss : 0.529395
batch 452 g_loss : 1.098683
batch 453 d_loss : 0.553042
batch 453 g_loss : 1.019673
batch 454 d_loss : 0.519293
batch 454 g_loss : 1.020024
batch 455 d_loss : 0.542482
batch 455 g_loss : 0.910317
batch 456 d_loss : 0.544831
batch 456 g_loss : 0.907691
batch 457 d_loss : 0.579923
batch 457 g_loss : 0.842383
batch 458 d_loss : 0.549304
batch 458 g_loss : 0.968528
batch 459 d_loss : 0.528392
batch 459 g_loss : 1.032270
batch 460 d_loss : 0.581104
batch 460 g_loss : 0.969159
batch 461 d_loss : 0.587647
batch 461 g_loss : 0.956386
batch 462 d_loss : 0.623485
batch 462 g_loss : 0.966768
batch 463 d_loss : 0.540665
batch 463 g_loss : 0.935149
batch 464 d_loss : 0.500689
batch 464 g_loss : 0.986218
batch 465 d_loss : 0.479586
batch 465 g_loss : 0.958206
batch 466 d_loss : 0.475827
batch 466 g_loss : 1.034124
batch 467 d_loss : 0.525135
batch 467 g_loss : 1

batch 598 d_loss : 0.460791
batch 598 g_loss : 1.083981
batch 599 d_loss : 0.480695
batch 599 g_loss : 1.143752
batch 600 d_loss : 0.524927
batch 600 g_loss : 1.166093
batch 601 d_loss : 0.460971
batch 601 g_loss : 1.297674
batch 602 d_loss : 0.512146
batch 602 g_loss : 1.150487
batch 603 d_loss : 0.533260
batch 603 g_loss : 1.110017
batch 604 d_loss : 0.480549
batch 604 g_loss : 1.079154
batch 605 d_loss : 0.537146
batch 605 g_loss : 1.128096
batch 606 d_loss : 0.505567
batch 606 g_loss : 1.244106
batch 607 d_loss : 0.463311
batch 607 g_loss : 1.155816
batch 608 d_loss : 0.448523
batch 608 g_loss : 1.208422
batch 609 d_loss : 0.474382
batch 609 g_loss : 1.207161
batch 610 d_loss : 0.490139
batch 610 g_loss : 1.394354
batch 611 d_loss : 0.420154
batch 611 g_loss : 1.146855
batch 612 d_loss : 0.356967
batch 612 g_loss : 1.297776
batch 613 d_loss : 0.429685
batch 613 g_loss : 1.385126
batch 614 d_loss : 0.423202
batch 614 g_loss : 1.284010
batch 615 d_loss : 0.427680
batch 615 g_loss : 1

batch 748 g_loss : 1.182997
batch 749 d_loss : 0.390731
batch 749 g_loss : 1.345110
batch 750 d_loss : 0.470034
batch 750 g_loss : 1.253316
batch 751 d_loss : 0.351470
batch 751 g_loss : 1.327123
batch 752 d_loss : 0.374779
batch 752 g_loss : 1.318510
batch 753 d_loss : 0.368863
batch 753 g_loss : 1.464014
batch 754 d_loss : 0.419003
batch 754 g_loss : 1.406382
batch 755 d_loss : 0.374543
batch 755 g_loss : 1.311046
batch 756 d_loss : 0.444028
batch 756 g_loss : 1.220539
batch 757 d_loss : 0.465552
batch 757 g_loss : 1.234674
batch 758 d_loss : 0.457810
batch 758 g_loss : 1.205974
batch 759 d_loss : 0.404713
batch 759 g_loss : 1.294503
batch 760 d_loss : 0.443024
batch 760 g_loss : 1.196562
batch 761 d_loss : 0.422343
batch 761 g_loss : 1.191675
batch 762 d_loss : 0.383426
batch 762 g_loss : 1.157084
batch 763 d_loss : 0.402382
batch 763 g_loss : 1.298836
batch 764 d_loss : 0.475508
batch 764 g_loss : 1.319478
batch 765 d_loss : 0.387870
batch 765 g_loss : 1.150745
batch 766 d_loss : 0

batch 898 d_loss : 0.329493
batch 898 g_loss : 1.498768
batch 899 d_loss : 0.405500
batch 899 g_loss : 1.231758
batch 900 d_loss : 0.461412
batch 900 g_loss : 1.310985
batch 901 d_loss : 0.579981
batch 901 g_loss : 1.175876
batch 902 d_loss : 0.398902
batch 902 g_loss : 1.172351
batch 903 d_loss : 0.487567
batch 903 g_loss : 1.311985
batch 904 d_loss : 0.449313
batch 904 g_loss : 1.338954
batch 905 d_loss : 0.545986
batch 905 g_loss : 1.222734
batch 906 d_loss : 0.378935
batch 906 g_loss : 1.425147
batch 907 d_loss : 0.420152
batch 907 g_loss : 1.267225
batch 908 d_loss : 0.434250
batch 908 g_loss : 1.021756
batch 909 d_loss : 0.418179
batch 909 g_loss : 1.150352
batch 910 d_loss : 0.417066
batch 910 g_loss : 1.201800
batch 911 d_loss : 0.483350
batch 911 g_loss : 0.967144
batch 912 d_loss : 0.426248
batch 912 g_loss : 1.152966
batch 913 d_loss : 0.451426
batch 913 g_loss : 1.252844
batch 914 d_loss : 0.433033
batch 914 g_loss : 1.335000
batch 915 d_loss : 0.426213
batch 915 g_loss : 1

batch 1048 g_loss : 1.307321
batch 1049 d_loss : 0.282174
batch 1049 g_loss : 1.569252
batch 1050 d_loss : 0.326748
batch 1050 g_loss : 1.513834
batch 1051 d_loss : 0.313138
batch 1051 g_loss : 1.432735
batch 1052 d_loss : 0.327191
batch 1052 g_loss : 1.343126
batch 1053 d_loss : 0.310023
batch 1053 g_loss : 1.516090
batch 1054 d_loss : 0.343923
batch 1054 g_loss : 1.506439
batch 1055 d_loss : 0.315880
batch 1055 g_loss : 1.521456
batch 1056 d_loss : 0.349789
batch 1056 g_loss : 1.208192
batch 1057 d_loss : 0.313584
batch 1057 g_loss : 1.595654
batch 1058 d_loss : 0.330943
batch 1058 g_loss : 1.446036
batch 1059 d_loss : 0.319688
batch 1059 g_loss : 1.535874
batch 1060 d_loss : 0.325361
batch 1060 g_loss : 1.533431
batch 1061 d_loss : 0.344197
batch 1061 g_loss : 1.483652
batch 1062 d_loss : 0.292753
batch 1062 g_loss : 1.467966
batch 1063 d_loss : 0.256527
batch 1063 g_loss : 1.536387
batch 1064 d_loss : 0.325485
batch 1064 g_loss : 1.579263
batch 1065 d_loss : 0.381020
batch 1065 g_l

batch 1192 g_loss : 1.420190
batch 1193 d_loss : 0.327083
batch 1193 g_loss : 1.564209
batch 1194 d_loss : 0.340095
batch 1194 g_loss : 1.693103
batch 1195 d_loss : 0.371494
batch 1195 g_loss : 1.394744
batch 1196 d_loss : 0.320928
batch 1196 g_loss : 1.454461
batch 1197 d_loss : 0.411475
batch 1197 g_loss : 1.415937
batch 1198 d_loss : 0.354767
batch 1198 g_loss : 1.773944
batch 1199 d_loss : 0.322757
batch 1199 g_loss : 1.448356
batch 1200 d_loss : 0.334571
batch 1200 g_loss : 1.774300
batch 1201 d_loss : 0.309906
batch 1201 g_loss : 1.201935
batch 1202 d_loss : 0.315834
batch 1202 g_loss : 1.943817
batch 1203 d_loss : 0.340108
batch 1203 g_loss : 1.230630
batch 1204 d_loss : 0.268559
batch 1204 g_loss : 1.905958
batch 1205 d_loss : 0.436046
batch 1205 g_loss : 0.865862
batch 1206 d_loss : 0.437925
batch 1206 g_loss : 1.828074
batch 1207 d_loss : 0.301577
batch 1207 g_loss : 1.463827
batch 1208 d_loss : 0.311130
batch 1208 g_loss : 1.661705
batch 1209 d_loss : 0.394044
batch 1209 g_l

batch 1336 g_loss : 1.403135
batch 1337 d_loss : 0.279574
batch 1337 g_loss : 1.749354
batch 1338 d_loss : 0.250417
batch 1338 g_loss : 2.082818
batch 1339 d_loss : 0.331165
batch 1339 g_loss : 1.533584
batch 1340 d_loss : 0.242815
batch 1340 g_loss : 1.631195
batch 1341 d_loss : 0.312921
batch 1341 g_loss : 1.506251
batch 1342 d_loss : 0.349179
batch 1342 g_loss : 1.563606
batch 1343 d_loss : 0.323105
batch 1343 g_loss : 1.057793
batch 1344 d_loss : 0.339227
batch 1344 g_loss : 1.592349
batch 1345 d_loss : 0.310384
batch 1345 g_loss : 1.409451
batch 1346 d_loss : 0.291904
batch 1346 g_loss : 1.813497
batch 1347 d_loss : 0.293116
batch 1347 g_loss : 1.062975
batch 1348 d_loss : 0.324888
batch 1348 g_loss : 1.757511
batch 1349 d_loss : 0.266183
batch 1349 g_loss : 1.868830
batch 1350 d_loss : 0.322744
batch 1350 g_loss : 1.450181
batch 1351 d_loss : 0.283160
batch 1351 g_loss : 2.029853
batch 1352 d_loss : 0.307723
batch 1352 g_loss : 1.498043
batch 1353 d_loss : 0.287857
batch 1353 g_l

batch 1481 d_loss : 0.531564
batch 1481 g_loss : 1.159352
batch 1482 d_loss : 0.587924
batch 1482 g_loss : 1.221395
batch 1483 d_loss : 0.469215
batch 1483 g_loss : 1.390777
batch 1484 d_loss : 0.463938
batch 1484 g_loss : 1.662843
batch 1485 d_loss : 0.563876
batch 1485 g_loss : 0.970756
batch 1486 d_loss : 0.568214
batch 1486 g_loss : 1.765934
batch 1487 d_loss : 0.686233
batch 1487 g_loss : 0.971171
batch 1488 d_loss : 0.567576
batch 1488 g_loss : 1.994124
batch 1489 d_loss : 0.685750
batch 1489 g_loss : 1.212259
batch 1490 d_loss : 0.440938
batch 1490 g_loss : 1.403110
batch 1491 d_loss : 0.383432
batch 1491 g_loss : 1.486329
batch 1492 d_loss : 0.382870
batch 1492 g_loss : 1.478818
batch 1493 d_loss : 0.386403
batch 1493 g_loss : 1.000360
batch 1494 d_loss : 0.409562
batch 1494 g_loss : 1.957152
batch 1495 d_loss : 0.464983
batch 1495 g_loss : 1.069138
batch 1496 d_loss : 0.409992
batch 1496 g_loss : 1.768747
batch 1497 d_loss : 0.479509
batch 1497 g_loss : 1.093394
batch 1498 d_l

batch 1624 g_loss : 1.769799
batch 1625 d_loss : 0.332409
batch 1625 g_loss : 1.865885
batch 1626 d_loss : 0.251224
batch 1626 g_loss : 2.414574
batch 1627 d_loss : 0.432487
batch 1627 g_loss : 1.030392
batch 1628 d_loss : 0.505014
batch 1628 g_loss : 2.193525
batch 1629 d_loss : 0.552896
batch 1629 g_loss : 0.749629
batch 1630 d_loss : 0.473955
batch 1630 g_loss : 2.530683
batch 1631 d_loss : 0.433636
batch 1631 g_loss : 1.239737
batch 1632 d_loss : 0.286725
batch 1632 g_loss : 1.790278
batch 1633 d_loss : 0.378686
batch 1633 g_loss : 1.181203
batch 1634 d_loss : 0.447950
batch 1634 g_loss : 2.168179
batch 1635 d_loss : 0.351449
batch 1635 g_loss : 1.049619
batch 1636 d_loss : 0.328427
batch 1636 g_loss : 2.353460
batch 1637 d_loss : 0.506751
batch 1637 g_loss : 0.959947
batch 1638 d_loss : 0.509419
batch 1638 g_loss : 1.648051
batch 1639 d_loss : 0.445095
batch 1639 g_loss : 1.544820
batch 1640 d_loss : 0.270231
batch 1640 g_loss : 2.078279
batch 1641 d_loss : 0.250494
batch 1641 g_l