In [1]:
from PIL import Image
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import tarfile


# Display the installed TensorFlow version as this cell's output.
tf.__version__

'2.0.0'

In [68]:
from tensorflow.keras import Model
from tensorflow.keras.layers import (InputLayer, ZeroPadding2D, Conv2D, 
                                     Conv2DTranspose, LeakyReLU, BatchNormalization, 
                                     Dense, Flatten, Input, Conv2DTranspose, ReLU)



class Encoder(Model):
    """Convolutional encoder built from four zero-padded, stride-2, 5x5 convolutions.

    Filter counts grow 128 -> 256 -> 512 -> 1024. With 2-pixel padding and
    stride 2 every stage halves the spatial size, so a 64x64 input leaves the
    encoder as a 4x4x1024 feature map. Stages 2-4 use batch normalisation;
    every stage ends in a LeakyReLU.

    Args:
        batch_size: batch size declared on the input layer.
        lrelu_slope: negative-side slope shared by all LeakyReLU activations.
    """

    def __init__(self, batch_size, lrelu_slope=0.2):
        # Must name Encoder here: super(Model, self) would skip Model.__init__
        # and leave the Keras model state only partially initialised.
        super(Encoder, self).__init__()
        self.initializer = tf.keras.initializers.RandomNormal(stddev=0.02)
        # `input_shape` excludes the batch dimension; it is passed separately.
        self.inputs = InputLayer(input_shape=(64, 64, 3), batch_size=batch_size)

        # Stage 1: no batch norm, bias enabled.
        self.pad1 = ZeroPadding2D(padding=(2, 2))
        self.conv1 = Conv2D(filters=128, kernel_size=5, strides=(2, 2),
                            kernel_initializer=self.initializer)
        self.lrelu1 = LeakyReLU(alpha=lrelu_slope)

        # Stages 2-4: bias is dropped because batch norm supplies its own shift.
        self.pad2 = ZeroPadding2D(padding=(2, 2))
        self.conv2 = Conv2D(filters=256, kernel_size=5, strides=(2, 2), use_bias=False,
                            kernel_initializer=self.initializer)
        self.bn1 = BatchNormalization()
        self.lrelu2 = LeakyReLU(alpha=lrelu_slope)

        self.pad3 = ZeroPadding2D(padding=(2, 2))
        self.conv3 = Conv2D(filters=512, kernel_size=5, strides=(2, 2), use_bias=False,
                            kernel_initializer=self.initializer)
        self.bn2 = BatchNormalization()
        self.lrelu3 = LeakyReLU(alpha=lrelu_slope)

        self.pad4 = ZeroPadding2D(padding=(2, 2))
        self.conv4 = Conv2D(filters=1024, kernel_size=5, strides=(2, 2), use_bias=False,
                            kernel_initializer=self.initializer)
        self.bn3 = BatchNormalization()
        self.lrelu4 = LeakyReLU(alpha=lrelu_slope)

    def __call__(self, X, training=False):
        """Encode a batch of images.

        Args:
            X: input image batch.
            training: forwarded to the batch-norm layers to select batch
                statistics (True) or moving averages (False).

        Returns:
            The feature map produced by the fourth stage.
        """
        network = self.inputs(X)

        # Without an activation here conv1 and conv2 would collapse into a
        # single linear map.
        network = self.pad1(network)
        network = self.conv1(network)
        network = self.lrelu1(network)

        network = self.pad2(network)
        network = self.conv2(network)
        network = self.bn1(network, training=training)
        network = self.lrelu2(network)

        network = self.pad3(network)
        network = self.conv3(network)
        network = self.bn2(network, training=training)
        network = self.lrelu3(network)

        network = self.pad4(network)
        network = self.conv4(network)
        network = self.bn3(network, training=training)
        # lrelu4 was constructed but previously never applied.
        network = self.lrelu4(network)

        return network

class Discriminator(Encoder):
    """Encoder followed by a single-filter 4x4 convolution that scores each input.

    The network returns raw logits (no sigmoid) because `loss` evaluates
    binary cross-entropy with `from_logits=True`; apply `tf.sigmoid` to the
    output when probabilities are needed.

    Args:
        batch_size: batch size declared on the input layer.
        lrelu_slope: negative-side slope of the encoder's LeakyReLU layers.
        n_input_channels: number of channels of the input images.
    """

    def __init__(self, batch_size, lrelu_slope=0.2, n_input_channels=3):
        super(Discriminator, self).__init__(batch_size, lrelu_slope)
        if n_input_channels != 3:
            # Replace the encoder's 3-channel input declaration; `input_shape`
            # excludes the batch dimension, which is passed separately.
            self.inputs = InputLayer(input_shape=(64, 64, n_input_channels),
                                     batch_size=batch_size)

        # Linear output: applying a sigmoid here as well as inside the
        # from_logits=True loss would squash the scores twice.
        self.conv = Conv2D(filters=1, kernel_size=4,
                           kernel_initializer=self.initializer)

    @staticmethod
    def loss(logits_real, logits_fake):
        """Binary cross-entropy of real samples against 1 and fake samples against 0.

        Args:
            logits_real: discriminator logits for real samples.
            logits_fake: discriminator logits for generated samples.

        Returns:
            Scalar tensor: the sum of the two cross-entropy terms.
        """
        bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        # *_like builds the targets from the tensors themselves, so this also
        # works when the static shape has unknown dimensions.
        real_term = bce(tf.ones_like(logits_real), logits_real)
        fake_term = bce(tf.zeros_like(logits_fake), logits_fake)
        return tf.reduce_mean(real_term + fake_term)

    def __call__(self, X, training=False):
        """Score a batch of images.

        Args:
            X: input image batch.
            training: forwarded to the encoder so batch norm runs in the
                requested mode.

        Returns:
            Logits produced by the final convolution.
        """
        network = super(Discriminator, self).__call__(X, training=training)
        network = self.conv(network)
        return network
    
class Decoder(Model):
    """Transposed-convolution decoder that upsamples a feature map by a factor of 32.

    Five 2x2, stride-2 transposed convolutions (1024 -> 512 -> 256 -> 128 -> 3
    filters) each double the spatial size, so a 2x2 input becomes a 64x64
    3-channel image. The first four stages use batch norm + ReLU; the last one
    applies tanh, so outputs lie in [-1, 1].

    Args:
        batch_size: batch size declared on the input layer.
    """

    def __init__(self, batch_size):
        super(Decoder, self).__init__()
        initializer = tf.keras.initializers.RandomNormal(stddev=0.02)
        # `input_shape` excludes the batch dimension; it is passed separately.
        self.inputs = InputLayer(input_shape=(2, 2, 64), batch_size=batch_size)

        self.transp_conv1 = Conv2DTranspose(filters=1024, kernel_size=2, strides=(2, 2),
                                            use_bias=False, kernel_initializer=initializer)
        self.bn1 = BatchNormalization()
        self.relu1 = ReLU()

        self.transp_conv2 = Conv2DTranspose(filters=512, kernel_size=2, strides=(2, 2),
                                            use_bias=False, kernel_initializer=initializer)
        self.bn2 = BatchNormalization()
        self.relu2 = ReLU()

        self.transp_conv3 = Conv2DTranspose(filters=256, kernel_size=2, strides=(2, 2),
                                            use_bias=False, kernel_initializer=initializer)
        self.bn3 = BatchNormalization()
        self.relu3 = ReLU()

        self.transp_conv4 = Conv2DTranspose(filters=128, kernel_size=2, strides=(2, 2),
                                            use_bias=False, kernel_initializer=initializer)
        self.bn4 = BatchNormalization()
        self.relu4 = ReLU()

        # Output stage: tanh activation, no batch norm.
        self.transp_conv5 = Conv2DTranspose(filters=3, kernel_size=2, strides=(2, 2),
                                            activation='tanh', use_bias=False,
                                            kernel_initializer=initializer)

    def __call__(self, X, training=False):
        """Decode a batch of feature maps into images.

        Args:
            X: input feature-map batch.
            training: forwarded to the batch-norm layers to select batch
                statistics (True) or moving averages (False).

        Returns:
            The tanh-activated output of the last transposed convolution.
        """
        network = self.inputs(X)

        hidden_stages = (
            (self.transp_conv1, self.bn1, self.relu1),
            (self.transp_conv2, self.bn2, self.relu2),
            (self.transp_conv3, self.bn3, self.relu3),
            (self.transp_conv4, self.bn4, self.relu4),
        )
        for transp_conv, batch_norm, relu in hidden_stages:
            network = transp_conv(network)
            network = batch_norm(network, training=training)
            network = relu(network)

        return self.transp_conv5(network)
        
class Converter(Model):
    """Image-to-image network: Encoder -> 3x3 conv bottleneck -> Decoder.

    The bottleneck (64-filter 3x3 convolution, batch norm, ReLU) sits between
    the encoder output and the decoder input. Subclassing `Model`, like the
    other networks in this notebook, lets Keras track the variables of all
    sub-networks through this one object.

    Args:
        batch_size: batch size declared on the input layers.
    """

    def __init__(self, batch_size):
        super(Converter, self).__init__()
        initializer = tf.keras.initializers.RandomNormal(stddev=0.02)
        # `input_shape` excludes the batch dimension; it is passed separately.
        self.inputs = InputLayer(input_shape=(64, 64, 3), batch_size=batch_size)
        self.encoder = Encoder(batch_size)
        self.conv = Conv2D(filters=64, kernel_size=3, kernel_initializer=initializer)
        self.bn = BatchNormalization()
        self.relu = ReLU()
        self.decoder = Decoder(batch_size)

    def __call__(self, X, training=False):
        """Convert a batch of images.

        Args:
            X: input image batch.
            training: forwarded to the encoder, the bottleneck batch norm and
                the decoder so every batch-norm layer runs in the same mode.

        Returns:
            The decoder output for the encoded, bottlenecked input.
        """
        network = self.inputs(X)
        network = self.encoder(network, training=training)

        network = self.conv(network)
        network = self.bn(network, training=training)
        network = self.relu(network)

        network = self.decoder(network, training=training)
        return network

In [69]:
def test_converter():
    """Smoke test: a Converter maps a (1, 64, 64, 3) batch of zeros to the same shape."""
    n_samples = 1
    image_shape = (n_samples, 64, 64, 3)

    converter = Converter(n_samples)
    converted = converter(tf.zeros(image_shape))

    assert converted.shape == image_shape
    print('Encoder produces output of valid shape.')
    
def test_discriminator():
    """Smoke test: Discriminators yield one (1, 1, 1) score per sample.

    Checked for the default 3-channel input and for a 6-channel input.
    """
    n_samples = 128
    score_shape = (n_samples, 1, 1, 1)

    # Default 3-channel discriminator.
    three_channel_batch = tf.zeros((n_samples, 64, 64, 3))
    three_channel_scores = Discriminator(n_samples)(three_channel_batch)
    assert three_channel_scores.shape == score_shape

    # 6-channel discriminator.
    six_channel_batch = tf.zeros((n_samples, 64, 64, 6))
    six_channel_scores = Discriminator(n_samples, n_input_channels=6)(six_channel_batch)
    assert six_channel_scores.shape == score_shape

    print('Discriminator produces outputs of valid shape.')

# Run both shape checks; a failing assert stops the cell.
for shape_check in (test_converter, test_discriminator):
    shape_check()

Encoder produces output of valid shape.
Discriminator produces outputs of valid shape.


In [None]:
def sample_noise(shape):
    """Return a tensor of the given shape filled with uniform noise from [-1, 1)."""
    lower_bound, upper_bound = -1.0, 1.0
    noise = tf.random.uniform(shape, minval=lower_bound, maxval=upper_bound)
    return noise

def show_images(images, scaled=True):
    """Draw a batch of images on a near-square grid of subplots.

    Args:
        images: array-like whose first axis indexes the images; every
            `images[i]` must be something `imshow` can render (e.g. H x W x 3).
        scaled: if True, pixel values are treated as floats in [0, 1] and
            converted to uint8 in [0, 255]. Values outside that range are
            clipped rather than left to wrap around in the integer cast.
    """
    # Accept anything convertible to an ndarray, not only ndarrays.
    images = np.asarray(images)
    n_images = images.shape[0]
    if n_images == 0:
        # Nothing to draw; a 0x0 grid cannot be created.
        return

    if scaled:
        images = np.clip(images * 255, 0, 255).astype(np.uint8)

    # Smallest square grid that holds every image.
    sqrtn = int(np.ceil(np.sqrt(n_images)))
    fig = plt.figure(figsize=(sqrtn, sqrtn))
    gs = gridspec.GridSpec(sqrtn, sqrtn)
    gs.update(wspace=0.05, hspace=0.05)

    for i, img in enumerate(images):
        ax = fig.add_subplot(gs[i])
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        # Each image is drawn as-is, so neither a square size nor exactly
        # three channels is assumed.
        ax.imshow(img)

# sample_noise draws values from [-1, 1], while show_images (scaled=True)
# multiplies by 255 and converts to uint8, i.e. it expects [0, 1].
# Shift and rescale before plotting so no pixel falls outside that range.
im = sample_noise((16, 64, 64, 3)).numpy()
show_images((im + 1.0) / 2.0)

In [None]:
# Training hyper-parameters.
BATCH_SIZE = 128
LEARNING_RATE = 2e-4
MOMENTUM = 0.5
n_epochs = 25
device = '/device:GPU:0'

# NOTE(review): lr=2e-4 with 0.5 momentum are the values commonly paired with
# Adam (beta_1=0.5) for GAN training - confirm that plain SGD is intended.
solver = tf.keras.optimizers.SGD(learning_rate=LEARNING_RATE, momentum=MOMENTUM)

# Networks.
encoder = Encoder(BATCH_SIZE)
# NOTE(review): test_discriminator builds its `domain_discrim` with
# n_input_channels=6, whereas this one uses the default of 3 - confirm.
domain_discrim = Discriminator(BATCH_SIZE)