In [None]:
# download necessary file
! wget https://storage.googleapis.com/soph-data/celeba/soph.py

In [None]:
%pylab inline

import keras
import keras.backend as K
import numpy as np
import functools
from tqdm import tqdm, tqdm_notebook
import soph

In [None]:
# Load CIFAR-10 through Keras; load_data() is unpacked into an (images, labels)
# pair for the train split and another for the test split.
# In the cells shown here only the image arrays are used afterwards; the labels
# are unpacked but never referenced again.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

In [None]:
# The raw pixels span 0-255. Rescale them to [-1, 1] so that real images share
# the range of the generator's tanh output layer.
print("original range of pixels", x_train.min(), x_train.max())

# Only rescale raw integer data. Without this guard, re-running the cell would
# rescale the already-scaled floats a second time and squash every pixel into
# roughly [-1.008, -0.992].
if np.issubdtype(x_train.dtype, np.integer):
    # Cast to float32 explicitly: `(2/255) * integer_array` silently yields
    # float64, doubling the memory of the dataset for no benefit.
    # Dividing by 127.5 maps 0 -> -1 and 255 -> 1 exactly.
    x_train = x_train.astype("float32") / 127.5 - 1
    x_test = x_test.astype("float32") / 127.5 - 1

print("new range of pixels", x_train.min(), x_train.max())

In [None]:
# Side length, in pixels, of the square RGB images the models work with.
im_size = 32
# Length of the noise vector fed to the generator.
z_len = 100
# Activation shared by the hidden layers: K.relu with alpha pinned to 0.1
# (a leaky ReLU, as the name suggests).
lrelu = functools.partial(K.relu, alpha=.1)

In [None]:
# Layer factories shared by both networks. Calling `conv()` / `bnorm()` creates a
# fresh layer; keyword arguments passed at the call site override these defaults.
# NOTE(review): soph.Conv2D comes from the downloaded helper module; it presumably
# mirrors keras.layers.Conv2D plus a `spectral_normalization` switch -- confirm.
conv = functools.partial(
    soph.Conv2D,
    filters=50,
    kernel_size=3,
    padding="same",
    activation=lrelu,
    strides=2,
    spectral_normalization=True,
)
bnorm = functools.partial(keras.layers.BatchNormalization, momentum=0.8)

# Discriminator: three (strided conv -> batch norm) stages, then a single
# unbounded logit per image.
disc = keras.Sequential(
    [keras.layers.InputLayer(input_shape=(im_size, im_size, 3))]
    + [layer for _ in range(3) for layer in (conv(), bnorm())]
    + [keras.layers.Flatten(), keras.layers.Dense(1)]
)

# Generator: project the noise vector to a 4x4x50 feature map, then alternate
# upsampling with stride-1 convolutions. The last convolution emits 3 channels
# through tanh, i.e. values in [-1, 1].
# (Three default UpSampling2D layers are expected to take 4x4 up to 32x32.)
gen = keras.Sequential(
    [
        keras.layers.InputLayer(input_shape=(z_len, )),
        keras.layers.Dense(50 * 4 * 4, activation=lrelu),
        keras.layers.Reshape((4, 4, 50)),
        bnorm(),
        keras.layers.UpSampling2D(),
    ]
    + [
        layer
        for _ in range(2)
        for layer in (conv(strides=1), bnorm(), keras.layers.UpSampling2D())
    ]
    + [conv(filters=3, strides=1, activation='tanh')]
)

disc.summary()
gen.summary()

In [None]:
# Symbolic inputs: a batch of real images and a batch of noise vectors.
img_real = keras.Input(shape=(im_size, im_size, 3))
z_gen = keras.Input(shape=(z_len,))

# Fake images are whatever the generator produces from the noise input.
img_fake = gen(z_gen)

# Run the same discriminator over the real and the generated images.
logits_real, logits_fake = disc(img_real), disc(img_fake)

def disc_loss(yt, yp):
    """Hinge loss for the discriminator.

    Averages mean(relu(1 - logits_real)) and mean(relu(1 + logits_fake)).

    The `yt` / `yp` arguments exist only to satisfy the Keras loss signature and
    are ignored; the loss reads the `logits_real` and `logits_fake` tensors
    defined at notebook level.
    """
    real_term = K.relu(1-logits_real)
    fake_term = K.relu(1+logits_fake)
    return (K.mean(real_term) + K.mean(fake_term))/2

def gen_loss(yt, yp):
    """Generator loss: the negated mean discriminator logit on generated images.

    `yt` and `yp` are ignored (they are only there for the Keras loss signature);
    the value is computed from the notebook-level `logits_fake` tensor.
    """
    mean_fake_logit = K.mean(logits_fake)
    return -mean_fake_logit

In [None]:
# Discriminator update step: discriminator trainable, generator frozen.
# NOTE(review): this relies on Keras reading the `trainable` flags as they stand
# when compile() is called -- confirm for the Keras version in use.
disc.trainable = True
gen.trainable = False

# Both logits are model outputs, but only the first carries a loss; disc_loss
# itself pulls in the fake logits, so the second output gets None.
disc_train = keras.Model([img_real, z_gen], [logits_real, logits_fake])
disc_train.compile(
    optimizer=keras.optimizers.Adam(lr=0.0002, beta_1=0),
    loss=[disc_loss, None],
)

In [None]:
# Generator update step: generator trainable, discriminator frozen.
# NOTE(review): this relies on Keras reading the `trainable` flags as they stand
# when compile() is called -- confirm for the Keras version in use.
disc.trainable = False
gen.trainable = True

# Maps noise to the discriminator's logit on the generated image. The generator
# uses a smaller learning rate than the discriminator model.
gen_train = keras.Model(z_gen, logits_fake)
gen_train.compile(
    optimizer=keras.optimizers.Adam(lr=0.00005, beta_1=0),
    loss=gen_loss,
)

In [None]:
figsize(10,10)

n_disc = 2        # discriminator updates per generator update
sample_side = 10  # the preview grid is sample_side x sample_side images
batch_size = 64

num_steps = x_train.shape[0]//batch_size

num_epochs = 10

d_loss_list = []
g_loss_list = []

# Placeholder targets: both loss functions ignore their y_true argument, but
# train_on_batch still has to be given a target array.
dumb = np.ones((batch_size, 1))

# Fixed noise used for the preview after every epoch. Reusing the same vectors
# makes it possible to see how the generated images change over time.
noise_sample = np.random.normal(size=(sample_side**2, z_len))


def tile_images(images, side):
    """Arrange side*side images into one (side*H, side*W, C) mosaic.

    `images` must have shape (side*side, H, W, C). Image k is placed in grid
    column k // side and grid row k % side (column-major order).
    """
    _, height, width, channels = images.shape
    grid = images.reshape(side, side, height, width, channels)
    # Bring the axes into (grid_row, pixel_row, grid_col, pixel_col, channel)
    # order so that a plain reshape yields the final picture.
    grid = grid.transpose(1, 2, 0, 3, 4)
    return grid.reshape(side * height, side * width, channels)


for epoch_i in range(num_epochs):
    print(f"starting epoch {epoch_i}")

    for step_i in tqdm_notebook(range(num_steps)):

        # A named loop variable is used on purpose: assigning to `_` at notebook
        # level would shadow IPython's "last output" variable.
        for disc_step_i in range(n_disc):
            # Sample a random batch of real images (with replacement).
            idx = np.random.randint(0, x_train.shape[0], batch_size)
            x_batch = x_train[idx, ...]

            noise_batch = np.random.normal(size=(batch_size, z_len))

            # Train the discriminator; element 0 of the result is recorded as
            # the loss value.
            d_loss = disc_train.train_on_batch([x_batch, noise_batch], dumb)
            d_loss_list.append(d_loss[0])

        # One generator update on fresh noise.
        noise_batch = np.random.normal(size=(batch_size, z_len))
        g_loss = gen_train.train_on_batch(noise_batch, dumb)
        g_loss_list.append(g_loss)

    # Preview: generate from the fixed noise, map [-1, 1] back to [0, 1] for
    # imshow, and tile the images into a single picture.
    gen_sample = gen.predict(noise_sample)
    gen_sample = (gen_sample + 1) / 2
    gen_sample = tile_images(gen_sample, sample_side)

    # Create exactly one figure per epoch. The previous plt.clf() + plt.figure()
    # pair created an extra empty figure that was displayed alongside the grid.
    fig, ax = plt.subplots()
    ax.imshow(gen_sample)
    ax.set_title(f"generated samples after epoch {epoch_i}")
    plt.show()
    plt.close(fig)  # release the figure explicitly on non-inline backends

    print(f" disc loss: {np.mean(d_loss_list[-n_disc*num_steps:])} gen loss: {np.mean(g_loss_list[-num_steps:])}")
    