# import libraries


In [None]:

import sys

# On Windows, route Keras through the PlaidML backend. The environment
# variable has to be in place before `keras` is imported (next cell).
if sys.platform.startswith('win32'):
    import os

    print("Windows")  # announce that the platform-specific setup ran
    os.environ["KERAS_BACKEND"] = "plaidml.keras.backend"

In [None]:
import time

import keras
import keras.layers as Layers
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from IPython import display

# define parameters

In [None]:
LOSS = 'binary_crossentropy'
OPTIMIZER_D = tf.keras.optimizers.RMSprop()  # optimizer for the discriminator
OPTIMIZER_A = tf.keras.optimizers.RMSprop()  # optimizer for the stacked adversarial model
METRICS = 'accuracy'
# The discriminator ends in a sigmoid, so it outputs probabilities rather
# than logits; a standalone BCE loss must therefore use from_logits=False
# to agree with it (from_logits=True would apply a second sigmoid).
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False)

INPUT_LENGTH = 128  # length of the generator's input noise vector
BUFFER_SIZE = 60000  # tf.data shuffle buffer size

EPOCHS = 10
BATCH_SIZE = 256

# (height, width, channels) of one training image; the generator assumes
# height and width are divisible by 4 (two x2 upsamplings).
INPUT_SHAPE = (28, 28, 1)


# load data

In [None]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
# Append a trailing channel axis so each sample matches INPUT_SHAPE.
input_data = np.expand_dims(x_train, axis=-1).astype('float32')
# Scale pixels from [0, 255] to [-1, 1] so real images share the range of
# the generator's tanh output. With [0, 1] scaling the discriminator could
# separate real from fake just by the presence of negative pixel values.
input_data = (input_data - 127.5) / 127.5

print(input_data.shape)
print(np.max(input_data))

In [None]:
# Shuffle the real images and group them into fixed-size batches; the
# incomplete final batch is dropped.
train_dataset = (
    tf.data.Dataset.from_tensor_slices(input_data)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
)

# define model

### generator

In [None]:
def generator():
    """Build the convolutional generator: noise vector -> image.

    Maps a vector of length INPUT_LENGTH to an image of shape INPUT_SHAPE
    with values in [-1, 1] (tanh output). A dense layer produces a feature
    map at 1/4 of the target height/width, which is refined by a stride-1
    transposed convolution and then upsampled twice by stride-2 transposed
    convolutions.

    Returns:
        An uncompiled ``tf.keras.models.Sequential`` model; its summary is
        printed as a side effect.

    Raises:
        ValueError: if INPUT_SHAPE's height or width is not divisible by 4,
            because the two x2 upsamplings could not reproduce it exactly.
    """
    height, width, channels = INPUT_SHAPE
    if height % 4 or width % 4:
        raise ValueError(
            f'INPUT_SHAPE height and width must be divisible by 4, got {INPUT_SHAPE}'
        )
    # Spatial size before the two x2 upsamplings (height and width handled
    # separately so non-square images also work).
    dim_h, dim_w = height // 4, width // 4
    depth = 256
    model = tf.keras.models.Sequential()

    model.add(Layers.Dense(dim_h * dim_w * depth, use_bias=False, input_dim=INPUT_LENGTH))
    model.add(Layers.BatchNormalization())
    model.add(Layers.LeakyReLU())

    model.add(Layers.Reshape((dim_h, dim_w, depth)))

    # Stride 1: keeps the spatial size, halves the channel depth.
    model.add(Layers.Conv2DTranspose(depth // 2, 7, strides=1, padding='same', use_bias=False))
    model.add(Layers.BatchNormalization())
    model.add(Layers.LeakyReLU())

    # First x2 upsampling.
    model.add(Layers.Conv2DTranspose(depth // 4, 5, strides=2, padding='same', use_bias=False))
    model.add(Layers.BatchNormalization())
    model.add(Layers.LeakyReLU())

    # Second x2 upsampling to the full image size; tanh bounds output to [-1, 1].
    model.add(Layers.Conv2DTranspose(channels, 3, strides=2, activation='tanh', padding='same', use_bias=False))

    model.summary()
    return model

### discriminator

In [None]:
def discriminator():
    """Build the convolutional discriminator: image -> probability of "real".

    Two stride-2 convolutions with LeakyReLU (dropout after the first),
    flattened into a single sigmoid unit. Inputs have shape INPUT_SHAPE;
    train_step labels real images 1 and generated images 0.

    Returns:
        An uncompiled ``tf.keras.models.Sequential`` model; its summary is
        printed as a side effect.
    """
    model = tf.keras.models.Sequential([
        # First downsampling stage (x1/2), regularized with dropout.
        Layers.Conv2D(64, 5, strides=2, padding='same', input_shape=INPUT_SHAPE),
        Layers.LeakyReLU(),
        Layers.Dropout(0.3),
        # Second downsampling stage (x1/2).
        Layers.Conv2D(128, 5, strides=2, padding='same'),
        Layers.LeakyReLU(),
        # Classifier head.
        Layers.Flatten(),
        Layers.Dense(1, activation='sigmoid'),
    ])
    model.summary()
    return model

# Full adversarial model (generator stacked on the discriminator)

In [None]:
# Build both networks (each prints its summary), then compile the
# discriminator for its own standalone update in train_step.
gen, discr = generator(), discriminator()

discr.compile(optimizer=OPTIMIZER_D, loss=LOSS)

In [None]:
# Stack generator -> discriminator; training this model updates the generator.
# Freeze the discriminator *before* compiling: Keras only guarantees that a
# change to `trainable` is honoured by a model compiled after the change, so
# this keeps `advr.train_on_batch` from also updating the discriminator.
# train_step() turns `discr.trainable` back on around its own update.
discr.trainable = False
advr = tf.keras.models.Sequential([gen, discr])
advr.compile(loss=LOSS, optimizer=OPTIMIZER_A)

# Define train step

In [None]:
def train_step(images):
    """Run one adversarial update on a batch of real images.

    The discriminator is first trained on the real batch (label 1) stacked
    with an equally sized batch of generated images (label 0). The generator
    is then trained through the stacked ``advr`` model, with the
    discriminator frozen, using label 1 ("real") as the target.

    Args:
        images: one batch of real images from ``train_dataset`` (assumed
            float32, batch-first, each sample shaped like INPUT_SHAPE).

    Returns:
        ``(gen_loss, disc_loss)``: the losses reported by the two
        ``train_on_batch`` calls.
    """
    gen, discr = advr.layers
    # Size noise and labels from the actual batch instead of BATCH_SIZE, so a
    # partial batch (e.g. without drop_remainder) is handled correctly.
    batch_size = tf.shape(images)[0]
    noise = tf.random.normal([batch_size, INPUT_LENGTH])
    # NOTE(review): called without `training`, so BatchNormalization here may
    # run in Keras's inference mode, unlike inside advr.train_on_batch —
    # confirm this is intended.
    fake_images = gen(noise)
    x_d = tf.concat([images, fake_images], axis=0)
    y_d = tf.concat([tf.ones([batch_size, 1]), tf.zeros([batch_size, 1])], axis=0)

    discr.trainable = True
    disc_loss = discr.train_on_batch(x_d, y_d)
    discr.trainable = False

    # Train the generator to make the discriminator output "real" (1).
    y_a = tf.ones([batch_size, 1])
    gen_loss = advr.train_on_batch(noise, y_a)
    return gen_loss, disc_loss
    

# Train model

In [None]:
# Fixed batch of 16 noise vectors, reused for every sample grid during
# training so successive images show the same latent points evolving.
seed = tf.random.normal(shape=[16, INPUT_LENGTH])

In [None]:
def generate_and_save_images(model, epoch, test_input):
    """Render the model's outputs for `test_input` as a grid, save and show it.

    Args:
        model: generator, called as ``model(test_input, training=False)``;
            its output is indexed as (batch, height, width, channel).
        epoch: number embedded in the file name ``image_at_epoch_XXXX.png``.
        test_input: batch of noise vectors to render.

    The grid is square and sized to fit the batch (4x4 for 16 inputs).
    Pixel values are mapped from the tanh range [-1, 1] to [0, 255]. The
    figure is closed after showing, so repeated per-epoch calls do not
    accumulate open figures.
    """
    import math

    predictions = model(test_input, training=False)
    count = predictions.shape[0]
    side = max(1, math.ceil(math.sqrt(count)))

    fig = plt.figure(figsize=(4, 4))
    for i in range(count):
        ax = fig.add_subplot(side, side, i + 1)
        ax.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        ax.axis('off')

    fig.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()
    plt.close(fig)


In [None]:
def train(dataset, epochs):
    """Train the GAN for `epochs` passes over `dataset`.

    After each epoch the notebook output is cleared, a sample grid for the
    fixed `seed` noise is saved and shown, and the epoch's wall-clock time
    is printed. One more grid, labelled with `epochs`, is produced at the end.
    """
    for epoch_idx in range(epochs):
        epoch_no = epoch_idx + 1
        t0 = time.time()
        for batch in dataset:
            train_step(batch)

        # Produce images for the GIF as we go
        display.clear_output(wait=True)
        generate_and_save_images(gen, epoch_no, seed)
        print(f'Time for epoch {epoch_no} is {time.time() - t0} sec')

    # Generate after the final epoch
    display.clear_output(wait=True)
    generate_and_save_images(gen, epochs, seed)


In [None]:
# Kick off training over the batched Fashion-MNIST dataset.
train(dataset=train_dataset, epochs=EPOCHS)

# Generate new images

In [None]:
num_images = 10
noise = tf.random.normal([num_images, INPUT_LENGTH])
# The generator ends in tanh, so map its [-1, 1] output to [0, 255] (the
# same mapping generate_and_save_images uses). Multiplying by 255 alone
# would give values in [-255, 255].
generated_images = gen(noise, training=False) * 127.5 + 127.5

# View generated images

In [None]:
# Display one randomly chosen generated image, with size-1 axes squeezed away.
plt.imshow(
    np.squeeze(generated_images[np.random.randint(num_images)]),
    cmap='gray',
)