In [None]:
# WARNING: these lines download a couple of gigabytes of images.
# They are meant to be run on REMOTE machines using REMOTE bandwidth.
# If you run them locally (e.g. in a classroom), you are likely
# to fill up your hard drive and make the local Wi-Fi unusable.

# Uncomment the commands below to fetch the data; downloading and
# decompressing usually takes 1-2 minutes.

# ! apt install liblz4-tool
# ! wget https://storage.googleapis.com/soph-data/celeba/celeba.csv
# ! wget https://storage.googleapis.com/soph-data/celeba/r64.tar.lz4
# ! unlz4 r64.tar.lz4 - | tar xf - r64
# ! wget https://storage.googleapis.com/soph-data/celeba/r128.tar.lz4
# ! unlz4 r128.tar.lz4 - | tar xf - r128

In [None]:
!pip install keras -U

In [None]:
%pylab inline

import keras
import keras.backend as K
import numpy as np
import functools
from tqdm import tqdm, tqdm_notebook
import pandas as pd
import soph

In [None]:
z_len = 100       # length of the random noise part of the generator input
im_size = 64      # images are loaded and generated at im_size x im_size
batch_size = 100  # images per batch drawn from train_data
# `learning_rate` replaces the deprecated `lr` alias, which newer Keras rejects.
# The discriminator takes a 4x larger step than the generator. beta_1=0 means
# Adam keeps no running average of past gradients (first moment = current gradient).
optimizer_d = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0)
optimizer_g = keras.optimizers.Adam(learning_rate=0.00005, beta_1=0)

In [None]:
# Attribute table: the 'fn' column names each image file inside the image
# directory; every other column is used as a conditioning label.
celeba_df = pd.read_csv(filepath_or_buffer="celeba.csv")
celeba_df.head()

In [None]:
# Label columns: every CSV column except the image filename column.
y_cols = list(celeba_df.columns)
y_cols.remove('fn')


def _to_tanh_range(pixels):
    """Rescale pixel values from [0, 255] to [-1, 1], the range of the generator's tanh output."""
    return (2 / 255) * pixels - 1


datagen = keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=_to_tanh_range,
)

# class_mode="other" yields the raw y_cols values as targets.
# NOTE(review): the attribute heads use sigmoid + binary_crossentropy, so these
# columns presumably hold 0/1. Confirm against celeba.csv; the original CelebA
# attribute file encodes attributes as -1/1.
train_data = datagen.flow_from_dataframe(
    dataframe=celeba_df,
    directory="r64",
    x_col='fn',
    y_col=y_cols,
    class_mode="other",
    target_size=(im_size, im_size),
    batch_size=batch_size,
)

n_labels = len(y_cols)

In [None]:
# Width schedule: the layer at depth i gets dim * alpha**i channels (truncated
# to int), so each stage is roughly 25% wider than the previous one.
dim = 50
alpha = 1.25


def n_filters(i):
    """Number of conv filters used at depth `i`."""
    return int(dim * alpha ** i)
# Leaky ReLU with negative-section slope 0.1 (used in the discriminator).
lrelu = functools.partial(K.relu, alpha=0.1)
# Batch normalisation with momentum 0.8 (used in the generator).
bnorm = functools.partial(keras.layers.BatchNormalization, momentum=0.8)

# Default convolution: 3x3 kernel, "same" padding, stride 2, orthogonal init,
# and soph's spectral_normalization option switched on. Call sites override
# individual settings (e.g. strides=1, kernel_size=4 in the generator).
_conv_defaults = dict(
    kernel_size=3,
    padding="same",
    strides=2,
    kernel_initializer="orthogonal",
    spectral_normalization=True,
)
conv = functools.partial(soph.Conv2D, **_conv_defaults)

In [None]:
def _build_disc_base():
    """Shared discriminator trunk.

    Five conv + leaky-ReLU stages with widths n_filters(0..4), using the
    default stride-2 `conv`, followed by global average pooling to a vector.
    """
    stages = [keras.layers.InputLayer(input_shape=(im_size, im_size, 3))]
    for depth in range(5):
        stages.append(conv(n_filters(depth)))
        stages.append(keras.layers.Activation(lrelu))
    stages.append(keras.layers.GlobalAvgPool2D())
    return keras.Sequential(stages)


def _build_gen():
    """Generator: noise + labels -> image.

    A dense projection reshaped to a 4x4 feature map, four
    upsample -> stride-1 conv -> batch-norm -> ReLU stages (widths
    n_filters(3..0)), and a final 3-channel tanh conv.
    """
    stages = [
        keras.layers.InputLayer(input_shape=(z_len + n_labels, )),
        keras.layers.Dense(n_filters(4) * 4 * 4, activation="relu"),
        keras.layers.Reshape((4, 4, n_filters(4))),
        bnorm(),
    ]
    for depth in (3, 2, 1, 0):
        stages += [
            keras.layers.UpSampling2D(),
            conv(n_filters(depth), strides=1),
            bnorm(),
            keras.layers.Activation("relu"),
        ]
    stages.append(conv(3, strides=1, kernel_size=4, activation='tanh'))
    return keras.Sequential(stages)


disc_base = _build_disc_base()
gen = _build_gen()

disc_base.summary()
gen.summary()

In [None]:
def _disc_head(head_layer):
    """Wrap the shared `disc_base` trunk with an image input and `head_layer`.

    Both heads reuse the same disc_base instance, so training either one
    updates the shared trunk weights.
    """
    return keras.Sequential([
        keras.layers.InputLayer(input_shape=(im_size, im_size, 3)),
        disc_base,
        head_layer,
    ])


# Unbounded real/fake score for the hinge loss.
disc_bin = _disc_head(keras.layers.Dense(1))
# Per-attribute probabilities for the classification loss.
disc_class = _disc_head(keras.layers.Dense(n_labels, activation="sigmoid"))


In [None]:
# Symbolic inputs for the combined training graphs: a batch of real images,
# and a generator input vector (noise concatenated with attribute labels).
img_real = keras.Input(shape=(im_size, im_size, 3 ))
z_gen = keras.Input(shape=(z_len + n_labels, ))
img_fake = gen(z_gen)

# Real/fake scores from the binary critic head for real and generated images.
logits_real = disc_bin(img_real)
logits_fake = disc_bin(img_fake)

# Attribute predictions from the classifier head (shares disc_base with disc_bin).
class_real = disc_class(img_real)
class_fake = disc_class(img_fake)

def disc_loss(yt, yp):
    """Hinge loss for the binary critic.

    Pushes scores on real images above +1 and scores on generated images
    below -1, then averages the two halves.

    `yt` is ignored (the training loop passes dummy targets). This loss reads
    the module-level symbolic tensors `logits_real` and `logits_fake` instead
    of `yp`, because it needs scores for both real and generated images while
    being attached to only one output of `disc_train`.
    NOTE(review): closing over symbolic Keras tensors inside a loss presumably
    only works in graph-mode Keras 2; an eager setup would need `Model.add_loss`.
    """
    loss_real = K.mean(K.relu(1 - logits_real))
    loss_fake = K.mean(K.relu(1 + logits_fake))

    return (loss_real + loss_fake) / 2


def gen_loss(yt, yp):
    """Generator hinge loss: raise the critic's score on generated images.

    `yt` is ignored. `yp` is the critic score on generated images, because this
    loss is compiled onto `gen_train`, whose only output is `logits_fake`.
    Using `yp` gives the same value as the global `logits_fake` without
    capturing a symbolic tensor from outside the loss function.
    """
    return -K.mean(yp)

In [None]:
# --- Discriminator-side training models ---
# Unfreeze both critic heads and freeze the generator before compiling.
# NOTE(review): this relies on Keras 2 snapshotting `trainable` at compile()
# time, so the next cell flipping these flags does not affect the models
# compiled here. Confirm for the installed Keras version.
disc_bin.trainable = True
disc_class.trainable = True
gen.trainable = False

# Adversarial critic update. The inputs are real images and generator inputs.
# Only the first output carries a loss (disc_loss reads both logits itself);
# the second output (logits_fake) is given `None`, i.e. no separate loss.
disc_train = keras.Model([img_real, z_gen], [logits_real, logits_fake])
disc_train.compile(loss = [disc_loss, None], optimizer=optimizer_d)

# Attribute-classification update on real images only. It shares optimizer_d
# with disc_train, and through disc_base it also shares the conv trunk.
disc_class_train = keras.Model(img_real, class_real)
disc_class_train.compile(loss = "binary_crossentropy", optimizer=optimizer_d)

In [None]:
# --- Generator-side training models ---
# Freeze both critic heads and unfreeze the generator, so these models are
# meant to update only the generator. (In Keras, setting `trainable` on a
# container model also applies to its sub-layers, including the shared
# disc_base.)
disc_bin.trainable = False
disc_class.trainable = False
gen.trainable = True

# Adversarial generator update: raise the critic score on generated images.
gen_train = keras.Model(z_gen, logits_fake)
gen_train.compile(loss = gen_loss, optimizer=optimizer_g)

# Conditioning update: generated images should be classified with the
# attribute labels that were concatenated into the generator input.
gen_class_train = keras.Model(z_gen, class_fake)
gen_class_train.compile(loss = "binary_crossentropy", optimizer=optimizer_g)

In [None]:
# Fixed evaluation inputs, drawn once. The sample grids plotted during
# training therefore always use the same latent codes and attribute vectors,
# which makes progress comparable between reports.
sample_side = 10
sample_num = sample_side ** 2
sample_noise = np.random.normal(size=(sample_num, z_len))
sample_class = celeba_df[y_cols].sample(sample_num).values
sample_comb = np.hstack([sample_noise, sample_class])

In [None]:
figsize(10,10)

# Discriminator (critic) updates per generator update.
n_disc = 2

# Each step consumes n_disc real batches, so num_steps is one pass over the data.
num_steps = train_data.n // (n_disc * train_data.batch_size)
# Report roughly every 50,000 real images. Clamp to >= 1 so the modulo below
# can never divide by zero.
sample_interval = max(1, 50000 // (n_disc * train_data.batch_size))

num_epochs = 10

# Per-update loss histories. Adversarial (hinge) and attribute-classification
# losses are recorded separately so neither overwrites the other.
d_loss_list = []
d_class_loss_list = []
g_loss_list = []
g_class_loss_list = []


def _total_loss(result):
    """Return the total loss from a `train_on_batch` result as a float.

    Keras returns a scalar for a single-output model and a list (total loss
    first) for multi-output models such as `disc_train`; both are handled.
    """
    return float(np.ravel(result)[0])


def _tile_samples(images, side, size):
    """Arrange `side*side` RGB images of shape (size, size, 3) into one
    (side*size, side*size, 3) grid image."""
    grid = images.reshape(side, side * size, size, 3)
    grid = grid.transpose((0, 2, 1, 3))
    grid = grid.reshape(side * size, side * size, 3)
    return grid.transpose((1, 0, 2))


def _show_samples(grid):
    """Display a sample grid without axis ticks, then close its figure.

    Creates exactly one figure per call. `plt.clf()` followed by `plt.figure()`
    would leave an extra empty figure behind. Closing the figure frees its memory.
    """
    fig, ax = plt.subplots()
    ax.imshow(grid)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    plt.show()
    plt.close(fig)


def _sample_generator_input(labels):
    """Concatenate fresh N(0, 1) noise (one row per label row) with `labels`."""
    noise = np.random.normal(size=(labels.shape[0], z_len))
    return np.concatenate((noise, labels), axis=1)


for epoch_i in range(num_epochs):
    print(f"starting epoch {epoch_i}")

    for step_i in tqdm(range(num_steps)):

        # ---------------------
        #  Train discriminator
        # ---------------------
        for _ in range(n_disc):
            # Random batch of real images and their attribute labels.
            x, y = train_data.next()
            # Placeholder targets: disc_loss and gen_loss ignore y_true.
            dumb = np.zeros_like(y)

            z_comb = _sample_generator_input(y)

            d_loss_list.append(_total_loss(disc_train.train_on_batch([x, z_comb], dumb)))
            d_class_loss_list.append(_total_loss(disc_class_train.train_on_batch(x, y)))

        # ---------------------
        #  Train generator
        # ---------------------
        # Fresh noise, conditioned on the labels of the last real batch.
        z_comb = _sample_generator_input(y)
        g_loss_list.append(_total_loss(gen_train.train_on_batch(z_comb, dumb)))
        g_class_loss_list.append(_total_loss(gen_class_train.train_on_batch(z_comb, y)))

        if step_i % sample_interval == 0:
            # The generator ends in tanh ([-1, 1]); map to [0, 1] for imshow.
            gen_sample = (gen.predict(sample_comb) + 1) / 2
            _show_samples(_tile_samples(gen_sample, sample_side, im_size))

            # Average over (at most) the updates since the previous report.
            d_window = n_disc * sample_interval
            print(
                f" disc loss: {np.mean(d_loss_list[-d_window:])}"
                f" disc class loss: {np.mean(d_class_loss_list[-d_window:])}"
                f" gen loss: {np.mean(g_loss_list[-sample_interval:])}"
                f" gen class loss: {np.mean(g_class_loss_list[-sample_interval:])}"
            )


In [None]:
# Pull one (images, attribute labels) batch from the training iterator for inspection.
x, y = next(train_data)