In [None]:
from importlib import reload
from IPython.display import HTML, clear_output
from PIL import Image
import numpy as np

%matplotlib inline
import matplotlib.pyplot as plt

from util import load_data, random_samples, plot, interpolation, animation

from gan import GAN
from wgangp import WGANGP


## Loading Data

Data loading is currently implemented in two ways:

- Load all PNG files in a directory
- Load an HDF5 file

The `load_data` function returns the data and the resulting shape. When loading an HDF5 file, the resulting shape may differ from the requested shape. When loading PNGs, it attempts to convert each image to the requested color mode and then filters out images that do not fit.


In [None]:
# Color mode, given as the size of the color dimension of each image:
#   1 -> grayscale
#   2 -> grayscale + alpha
#   3 -> rgb
#   4 -> rgb + alpha
mode = 3

# Requested image shape, laid out as (width, height, color).
desired_shape = (32, 32) + (mode,)

In [None]:
# Load the training images. load_data returns the data together with the shape
# of what was actually loaded, which replaces the requested desired_shape here.
# The hdf5 file is available from https://data.vision.ee.ethz.ch/sagea/lld/
imgs, desired_shape = load_data('LLD-icon-sharp.hdf5') # from https://data.vision.ee.ethz.ch/sagea/lld/
# Alternative: load every png from a directory, passing the requested shape.
#imgs, desired_shape = load_data('scrapper/faviconpngs2/', desired_shape)

# The shape of the loaded data can be different than the requested shape in the case of hdf5,
# so print it and re-derive the color mode from its third entry.
print(desired_shape)
mode = desired_shape[2]

In [None]:
# Visual sanity check: shows samples of the loaded data.
# `mode` is the color mode derived from the loaded shape; presumably the helper
# uses it to decide how to render the images (see util.random_samples).
random_samples(imgs, mode)

## Instantiating the GAN

Two GAN types are currently included: the original GAN and the improved Wasserstein GAN with gradient penalty (WGANGP).

The GANs can be loaded with different network architectures from `libs/architectures`. Each architecture defines which layers make up the generator and the discriminator. Current architectures:

- dense (for grayscale)
- conv1 (inspired by the ACGAN)
- conv2 (improved for good results)
- resnet (using [WGANGP ResNet32](https://github.com/igul222/improved_wgan_training/blob/master/gan_cifar_resnet.py) as reference)

In [None]:
# Build the model for the shape reported by load_data. Exactly one line should be
# active; the others are alternative GAN type / architecture combinations kept
# commented out for quick switching.
#gan = GAN(desired_shape, architecture='dense')
#gan = GAN(desired_shape, architecture='conv1')
gan = GAN(desired_shape, architecture='conv2')
#gan = WGANGP(desired_shape, architecture='resnet')

In [None]:
# Train the selected model on the loaded images. The three returned values are
# kept for the plots further down (named: discriminator loss, discriminator
# accuracy, generator loss).
d_loss, d_acc, g_loss = gan.train(
    X_train=imgs,
    epochs=30000,
    batch_size=32,
    sample_interval=200,
)

# Clear the output this cell produced while training.
clear_output()

In [None]:
# Display a sample image from the images/ directory (presumably written by
# gan.train during training — confirm in the GAN implementation).
# NOTE(review): 29800 is hard-coded; it looks like the last multiple of
# sample_interval=200 below epochs=30000 in the training call. Update it if
# either of those values changes.
Image.open("images/" + "29800.png")

In [None]:
# Plot the three sequences returned by gan.train, one stacked panel each.
# Panel titles follow the variable names (d_loss, d_acc, g_loss); the x axis is
# simply the index into each sequence (presumably one entry per training step —
# confirm in gan.train).
curves = (
    (d_loss, "Discriminator loss"),
    (d_acc, "Discriminator accuracy"),
    (g_loss, "Generator loss"),
)

fig, axs = plt.subplots(len(curves), 1)
for ax, (values, title) in zip(axs, curves):
    ax.plot(values)
    ax.set_title(title)
axs[-1].set_xlabel("Recorded training step")

# Keep titles and tick labels of neighbouring panels from overlapping. Ending the
# cell on this call (it returns None) also avoids echoing a Line2D repr as output.
fig.tight_layout()

In [None]:
# Draw random latent vectors, run them through the generator and show the
# results in a grid of r rows by c columns.
r, c = 4, 6

# Standard-normal noise with 100 values per sample (presumably the generator's
# latent size — confirm against the architecture).
# NOTE(review): r * c + 1 samples are drawn, one more than the grid has cells;
# check whether plot() actually needs the extra one.
noise = np.random.normal(loc=0, scale=1, size=(r * c + 1, 100))

# Rescale the generator output with 0.5 * x + 0.5 before plotting
# (maps [-1, 1] onto [0, 1]).
gen_imgs = 0.5 * gan.generator.predict(noise) + 0.5

plot(gen_imgs, mode, r, c)

In [None]:
# Grid layout (rows, columns) and the number of interpolation steps.
r = 4
c = 6
steps = 32

# Build the interpolation for r * c grid cells and turn it into an animation.
interpol = interpolation(gan, r * c, steps)
anim = animation(interpol, mode, r, c, steps)

# To save the animation to disk instead:
#   anim.save('line.gif', dpi=80, writer='imagemagick')
# With reflection (forward, then reversed):
#   animation(interpol + interpol[::-1], mode, r, c, steps * 2).save('line.gif', dpi=80, writer='imagemagick')
# If the gif does not loop, run 'convert line.gif -loop 0 anim.gif' (imagemagick).

# Last expression of the cell: embed the animation in the notebook output.
HTML(anim.to_jshtml())

In [None]:
# Save the interpolation followed by its reverse to anim.gif with the 'pillow'
# writer. This relies on `interpol`, `r`, `c` and `steps` from the previous cell
# (`interpol` is presumably a list, so `+` concatenates — confirm in util.interpolation).
animation(interpol + interpol[::-1], mode,r,c,steps*2).save('anim.gif', dpi=80, writer='pillow')
# Write the model weights to weights_g.h5 and weights_d.h5 (relative paths).
# NOTE(review): `generator_model` and `critic_model` are not referenced anywhere
# else in this notebook (sampling uses `gan.generator`). Confirm both attributes
# exist on the GAN class selected above and not only on WGANGP.
gan.generator_model.save_weights('weights_g.h5')
gan.critic_model.save_weights('weights_d.h5')