In [131]:
%pylab inline

import keras.backend as K
from keras import Input, Model, Sequential
from keras.layers import *
from keras.optimizers import Adam, RMSprop
from keras.callbacks import *
from keras.engine import InputSpec, Layer
from keras import initializers, regularizers, constraints
from keras.models import load_model
from keras.losses import *

from scipy.stats import mode

import tensorflow as tf
from matplotlib import pyplot as plt

from datetime import datetime
from pathlib import Path

import pickle

from gan_utils import *

Populating the interactive namespace from numpy and matplotlib


`%matplotlib` prevents importing * from pylab and numpy
  "\n`%matplotlib` prevents importing * from pylab and numpy"


In [223]:
class WGAN_VAE:
    """Weight-clipped Wasserstein GAN whose generator is also the decoder of a VAE.

    Four Keras models are involved:

    * a *critic* scoring 1-D series of length ``timesteps`` (plus a
      ``critic_hidden`` view exposing its 15-unit hidden layer),
    * a *generator* mapping a latent vector of size ``latent_dim`` to a series,
    * an *encoder* mapping a series to the mean / log-variance of a latent
      Gaussian,
    * the combined ``gan_vae`` model used to train generator and encoder
      against the frozen critic.

    Losses, images, models and generated datasets are written to the
    directories given to the constructor (expected to be ``pathlib.Path``
    objects, since they are combined with the ``/`` operator).
    """

    # Weights of the three terms of the combined generator/encoder loss.
    _GAN_LOSS_WEIGHT = 0.3
    _RECONSTRUCTION_LOSS_WEIGHT = 0.6
    _KL_LOSS_WEIGHT = 0.1

    # The loss curves are plotted and pickled every this many epochs.
    _LOSS_SAVE_FREQUENCY = 250

    def __init__(self, timesteps, latent_dim, run_dir, img_dir, model_dir, generated_datesets_dir, batch_size):
        """Store the configuration and pickle it to ``run_dir / 'config.p'``.

        Args:
            timesteps: length of the real / generated series.
            latent_dim: size of the latent vector.
            run_dir: directory receiving ``config.p`` and ``losses.p``.
            img_dir: directory receiving the PNG figures.
            model_dir: directory receiving the saved ``.h5`` models.
            generated_datesets_dir: directory receiving generated ``.npy`` datasets.
            batch_size: stored with the configuration; the sampling layer
                itself adapts to whatever batch size it is fed.
        """
        self._timesteps = timesteps
        self._latent_dim = latent_dim
        self._run_dir = run_dir
        self._img_dir = img_dir
        self._model_dir = model_dir
        self._generated_datesets_dir = generated_datesets_dir
        self._batch_size = batch_size

        self._save_config()

        self._epoch = 0
        # [generator losses, critic losses], one entry per epoch.
        self._losses = [[], []]

    def build_models(self, generator_lr, critic_lr):
        """Build and compile the critic and the combined GAN-VAE model.

        Args:
            generator_lr: RMSprop learning rate of the combined model.
            critic_lr: RMSprop learning rate of the critic.

        Returns:
            Tuple ``(gan_vae, generator, critic)``.
        """
        # The critic is compiled while still trainable; it is frozen below
        # only for the combined model.
        self._critic, self._critic_hidden = self._build_critic()
        self._critic.compile(loss=self._wasserstein_loss, optimizer=RMSprop(critic_lr))

        self._real_inputs = Input((self._timesteps,))
        # Not wired into any model in this class; kept because it was part of
        # the original attribute set.
        self._mean_inputs = Input((self._timesteps,))
        z_inputs = Input((self._latent_dim,))

        self._generator = self._build_generator()
        self._encoder = self._build_encoder()

        set_model_trainable(self._critic, False)
        set_model_trainable(self._critic_hidden, False)

        # GAN path: noise -> generator -> critic.
        generated_inputs = self._generator(z_inputs)
        self._discriminated_generated = self._critic(generated_inputs)

        # VAE path: real series -> encoder -> sampled z -> generator (decoder).
        self._encoded_mean, self._encoded_logvar = self._encoder(self._real_inputs)
        sampled_z = Lambda(self._sampling)([self._encoded_mean, self._encoded_logvar])
        self._decoded_inputs = self._generator(sampled_z)

        # Critic hidden features of the reconstruction and of the real series,
        # compared by the reconstruction term of the loss.
        self._discriminated_hidden_decoded = self._critic_hidden(self._decoded_inputs)
        self._discriminated_hidden_real = self._critic_hidden(self._real_inputs)

        self._discriminated_decoded = self._critic(self._decoded_inputs)

        # `name` is passed by keyword: a third positional argument is not
        # handled the same way by every Keras 2.x release.
        self._gan_vae = Model([self._real_inputs, z_inputs], self._discriminated_decoded, name='GAN VAE')
        self._gan_vae.compile(loss=self._gan_vae_custom_loss, optimizer=RMSprop(generator_lr))

        # Re-wrap the generator so that it is fed from `z_inputs`.
        self._generator = Model(z_inputs, generated_inputs)

        return self._gan_vae, self._generator, self._critic

    def _gan_vae_custom_loss(self, y_true, y_pred):
        """Weighted sum of the Wasserstein, reconstruction and KL terms.

        ``y_pred`` is ignored: every term is assembled from tensors stored on
        the instance by ``build_models``. ``y_true`` carries the labels used
        for the Wasserstein term.
        """
        kl_loss = -0.5 * K.sum(
            1 + self._encoded_logvar - K.square(self._encoded_mean) - K.exp(self._encoded_logvar),
            axis=-1)
        # Feature matching on the critic's hidden layer rather than on raw series.
        mse_loss = mean_squared_error(self._discriminated_hidden_decoded, self._discriminated_hidden_real)
        gan_loss = self._wasserstein_loss(y_true, self._discriminated_generated)
        return (self._GAN_LOSS_WEIGHT * gan_loss
                + self._RECONSTRUCTION_LOSS_WEIGHT * mse_loss
                + self._KL_LOSS_WEIGHT * kl_loss)

    def _sampling(self, args):
        """Reparameterisation trick: ``z = mean + exp(logvar / 2) * epsilon``.

        The batch dimension of ``epsilon`` is read from ``z_mean`` at run time
        so the layer works for any batch size, not only ``self._batch_size``.
        """
        z_mean, z_log_var = args
        batch = K.shape(z_mean)[0]
        epsilon = K.random_normal(shape=(batch, self._latent_dim), mean=0., stddev=1.0)
        return z_mean + K.exp(z_log_var / 2.0) * epsilon

    def _build_generator(self):
        """Build the generator: latent vector -> series of length ``timesteps``.

        A 12-unit dense layer is upsampled three times by a factor 2 (96
        samples) through Conv1D stages, then projected to ``timesteps`` values
        with a tanh dense layer.
        """
        generator_inputs = Input((self._latent_dim, ))

        generated = Dense(12, activation='linear')(generator_inputs)
        generated = Lambda(lambda x: K.expand_dims(x))(generated)

        for _ in range(2):
            generated = UpSampling1D(2)(generated)
            generated = Conv1D(64, 3, activation='relu', padding='same')(generated)
            generated = Conv1D(64, 3, activation='relu', padding='same')(generated)

        generated = UpSampling1D(2)(generated)
        generated = Conv1D(32, 3, activation='relu', padding='same')(generated)
        generated = Conv1D(1, 3, activation='tanh', padding='same')(generated)
        generated = Lambda(lambda x: K.squeeze(x, -1))(generated)

        generated = Dense(self._timesteps, activation='tanh')(generated)

        return Model(generator_inputs, generated, name='generator')

    @staticmethod
    def _conv_pool_features(tensor):
        """Convolutional feature extractor shared by the critic and the encoder.

        Adds a channel axis, applies three (Conv1D, Conv1D, MaxPooling1D)
        stages with 32, 64 and 64 filters, and flattens the result. Each call
        creates fresh layers, so no weights are shared between callers.
        """
        tensor = Lambda(lambda x: K.expand_dims(x))(tensor)
        for filters in (32, 64, 64):
            tensor = Conv1D(filters, 3, activation='relu', padding='same')(tensor)
            tensor = Conv1D(filters, 3, activation='relu', padding='same')(tensor)
            tensor = MaxPooling1D(2, padding='same')(tensor)
        return Flatten()(tensor)

    def _build_critic(self):
        """Build the critic and a second model exposing its hidden layer.

        Returns:
            Tuple ``(critic, critic_hidden)``; both share the same layers,
            ``critic_hidden`` stops at the 15-unit tanh layer.
        """
        critic_inputs = Input((self._timesteps, ))

        criticized = self._conv_pool_features(critic_inputs)
        criticized = Dense(32, activation='relu')(criticized)
        criticized_hidden = Dense(15, activation='tanh')(criticized)
        criticized = Dense(1)(criticized_hidden)

        critic = Model(critic_inputs, criticized, name='critic')
        critic_hidden = Model(critic_inputs, criticized_hidden, name='critic_hidden')
        return critic, critic_hidden

    def _build_encoder(self):
        """Build the encoder: series -> ``[mean, log-variance]`` of the latent Gaussian."""
        encoder_inputs = Input((self._timesteps,))

        encoded = self._conv_pool_features(encoder_inputs)
        encoded = Dense(self._latent_dim, activation='relu')(encoded)
        encoded_mean = Dense(self._latent_dim)(encoded)
        encoded_logvar = Dense(self._latent_dim)(encoded)

        return Model(encoder_inputs, [encoded_mean, encoded_logvar], name='encoder')

    def train(self, batch_size, epochs, n_generator, n_critic, dataset, clip_value,
              img_frequency, model_save_frequency, dataset_generation_frequency, dataset_generation_size):
        """Alternate critic and generator updates until ``epochs`` is reached.

        One "epoch" is ``n_critic`` critic batches followed by ``n_generator``
        batches of the combined model. Training resumes from ``self._epoch``.

        Args:
            batch_size: batch size of the combined model; the critic sees
                ``batch_size // 2`` real and as many generated series.
            epochs: epoch number at which training stops.
            n_generator: generator updates per epoch.
            n_critic: critic updates per epoch.
            dataset: array of real series, indexed along its first axis.
            clip_value: critic weights are clipped to ``[-clip_value, clip_value]``.
            img_frequency: sample / latent-space figures are saved every this
                many epochs.
            model_save_frequency: currently unused (periodic model saving is disabled).
            dataset_generation_frequency: currently unused (periodic dataset
                generation is disabled).
            dataset_generation_size: number of series generated once training ends.

        Returns:
            ``[generator_losses, critic_losses]``, one mean loss per epoch.
        """
        half_batch = batch_size // 2

        # Real series are labelled -1 and generated ones +1 for the critic;
        # the generator is trained with label -1.
        critic_labels = np.vstack([-np.ones((half_batch, 1)), np.ones((half_batch, 1))])
        generator_labels = -np.ones((batch_size, 1))

        while self._epoch < epochs:
            self._epoch += 1

            critic_losses = []
            for _ in range(n_critic):
                indexes = np.random.choice(dataset.shape[0], half_batch, replace=False)
                real_transactions = dataset[indexes]

                noise = np.random.normal(0, 1, (half_batch, self._latent_dim))
                generated_transactions = self._generator.predict(noise)

                mixed_batch = np.vstack([real_transactions, generated_transactions])
                critic_losses.append(self._critic.train_on_batch(mixed_batch, critic_labels))

                self._clip_weights(clip_value)
            critic_loss = np.mean(critic_losses)

            generator_losses = []
            for _ in range(n_generator):
                noise = np.random.normal(0, 1, (batch_size, self._latent_dim))
                indexes = np.random.randint(0, dataset.shape[0], batch_size)

                batch_transactions = dataset[indexes]
                generator_losses.append(
                    self._gan_vae.train_on_batch([batch_transactions, noise], generator_labels))
            generator_loss = np.mean(generator_losses)

            self._losses[0].append(generator_loss)
            self._losses[1].append(critic_loss)

            print("%d [GENERATOR loss: %f] [CRITIC loss: %f]" % (self._epoch, generator_loss, critic_loss))

            if self._epoch % self._LOSS_SAVE_FREQUENCY == 0:
                self._save_losses()

            if self._epoch % img_frequency == 0:
                self._save_imgs()
                self._save_latent_space()

        self._generate_dataset(epochs, dataset_generation_size)
        self._save_losses()
        self._save_models()
        self._save_imgs()
        self._save_latent_space()

        return self._losses

    def _save_imgs(self):
        """Plot a 5x5 grid of generated series to ``<epoch>.png`` and ``last.png``."""
        rows, columns = 5, 5
        noise = np.random.normal(0, 1, (rows * columns, self._latent_dim))
        generated_transactions = self._generator.predict(noise)

        fig, axes = plt.subplots(rows, columns, figsize=(15, 5))
        for ax, series in zip(axes.flat, generated_transactions):
            ax.plot(series)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_ylim(-1, 1)
        fig.tight_layout()
        fig.savefig(str(self._img_dir / ('%05d.png' % self._epoch)))
        fig.savefig(str(self._img_dir / 'last.png'))
        plt.close(fig)

    def _save_latent_space(self):
        """Plot the generator output over a 5x5 grid of the last two latent dimensions.

        The last two coordinates sweep ``[-2, 2]``; any other coordinate keeps
        one random value for the whole figure. Nothing is plotted when the
        latent space has fewer than two dimensions.
        """
        if self._latent_dim < 2:
            return

        latent_vector = np.random.normal(0, 1, self._latent_dim)
        grid_values = np.linspace(-2, 2, 5, True)

        fig, axes = plt.subplots(5, 5, figsize=(15, 5))
        for i, v_i in enumerate(grid_values):
            for j, v_j in enumerate(grid_values):
                latent_vector[-2:] = [v_i, v_j]
                ax = axes[i, j]
                ax.plot(self._generator.predict(latent_vector.reshape((1, self._latent_dim))).T)
                ax.set_xticks([])
                ax.set_yticks([])
                ax.set_ylim(-1, 1)
        fig.tight_layout()
        fig.savefig(str(self._img_dir / ('latent_space.png')))
        plt.close(fig)

    def _save_losses(self):
        """Plot the loss curves (full history and last 1000 epochs) and pickle them."""
        fig, (full_ax, recent_ax) = plt.subplots(2, 1, figsize=(15, 9))
        for ax, window in ((full_ax, slice(None)), (recent_ax, slice(-1000, None))):
            ax.plot(self._losses[0][window])
            ax.plot(self._losses[1][window])
            ax.legend(['generator', 'critic'])
        fig.savefig(str(self._img_dir / 'losses.png'))
        plt.close(fig)

        with open(str(self._run_dir / 'losses.p'), 'wb') as f:
            pickle.dump(self._losses, f)

    def _clip_weights(self, clip_value):
        """Clip every weight of every critic layer to ``[-clip_value, clip_value]``."""
        for layer in self._critic.layers:
            clipped = [np.clip(w, -clip_value, clip_value) for w in layer.get_weights()]
            layer.set_weights(clipped)

    def _save_config(self):
        """Pickle the constructor arguments to ``run_dir / 'config.p'``."""
        config = {
            'timesteps': self._timesteps,
            'latent_dim': self._latent_dim,
            'run_dir': self._run_dir,
            'img_dir': self._img_dir,
            'model_dir': self._model_dir,
            'generated_datesets_dir': self._generated_datesets_dir,
            'batch_size': self._batch_size,
        }

        with open(str(self._run_dir / 'config.p'), 'wb') as f:
            pickle.dump(config, f)

    def _save_models(self):
        """Save the combined model, the generator and the critic as ``.h5`` files."""
        self._gan_vae.save(str(self._model_dir / 'wgan.h5'))
        self._generator.save(str(self._model_dir / 'generator.h5'))
        self._critic.save(str(self._model_dir / 'critic.h5'))

    def _generate_dataset(self, epoch, dataset_generation_size):
        """Generate ``dataset_generation_size`` series and save them as ``.npy`` files.

        The array is written twice: under a name containing ``epoch`` and as ``last``.
        """
        z_samples = np.random.normal(0, 1, (dataset_generation_size, self._latent_dim))
        generated_dataset = self._generator.predict(z_samples)
        np.save(str(self._generated_datesets_dir / ('%d_generated_data' % epoch)), generated_dataset)
        np.save(str(self._generated_datesets_dir / 'last'), generated_dataset)

    def get_models(self):
        """Return ``(gan_vae, generator, critic)``."""
        return self._gan_vae, self._generator, self._critic

    @staticmethod
    def _wasserstein_loss(y_true, y_pred):
        """Wasserstein loss: mean of the label-weighted critic scores."""
        return K.mean(y_true * y_pred)

    def restore_training(self):
        """Reload the saved models and loss history so ``train`` can resume.

        The epoch counter is set to the number of recorded generator losses.

        Returns:
            Tuple ``(gan_vae, generator, critic)``.
        """
        self.load_models()
        # The pickle is the one written by `_save_losses`; do not point
        # `run_dir` at untrusted data.
        with open(str(self._run_dir / 'losses.p'), 'rb') as f:
            self._losses = pickle.load(f)
            self._epoch = len(self._losses[0])

        return self._gan_vae, self._generator, self._critic

    def load_models(self):
        """Load the three models saved by ``_save_models`` from ``model_dir``.

        Returns:
            Tuple ``(gan_vae, generator, critic)``.
        """
        # NOTE(review): the combined model is compiled with
        # `_gan_vae_custom_loss`, which depends on tensors created in
        # `build_models`; confirm that `wgan.h5` can actually be restored with
        # these custom objects.
        custom_objects = {
            'MinibatchDiscrimination': MinibatchDiscrimination,
            '_wasserstein_loss': self._wasserstein_loss
        }
        self._gan_vae = load_model(str(self._model_dir / 'wgan.h5'), custom_objects=custom_objects)
        self._generator = load_model(str(self._model_dir / 'generator.h5'))
        self._critic = load_model(str(self._model_dir / 'critic.h5'), custom_objects=custom_objects)

        return self._gan_vae, self._generator, self._critic

In [224]:
def split_data(dataset, timesteps):
    """Cut a wide array into consecutive windows of ``timesteps`` columns.

    The columns of ``dataset`` are split into ``D // timesteps`` consecutive
    chunks (``D = dataset.shape[1]``); the chunks are stacked along the first
    axis, first chunk first. Trailing columns that do not fill a whole chunk
    are dropped.

    Implemented iteratively, so very long series neither hit Python's
    recursion limit nor get re-copied once per chunk.

    Args:
        dataset: array with at least two dimensions.
        timesteps: positive number of columns per window.

    Returns:
        ``None`` if ``dataset`` has fewer than ``timesteps`` columns,
        ``dataset`` itself if it has exactly ``timesteps`` columns, otherwise
        the stacked windows.

    Raises:
        ValueError: if ``timesteps`` is not positive.
    """
    if timesteps <= 0:
        raise ValueError("timesteps must be a positive integer, got %r" % (timesteps,))

    n_columns = dataset.shape[1]
    if n_columns < timesteps:
        return None
    if n_columns == timesteps:
        return dataset

    n_chunks = n_columns // timesteps
    chunks = [dataset[:, start:start + timesteps]
              for start in range(0, n_chunks * timesteps, timesteps)]
    if n_chunks == 1:
        return chunks[0]
    return np.vstack(chunks)

In [225]:
# Input array, relative to the notebook's working directory.
# NOTE(review): presumably the normalised Berka transactions, one series per row — confirm.
normalized_transactions_filepath = "../../datasets/berka_dataset/usable/normalized_transactions_months.npy"

# Length of each training window fed to the models.
timesteps = 90

raw_transactions = np.load(normalized_transactions_filepath)
transactions = split_data(raw_transactions, timesteps)
# split_data returns None when the series have fewer than `timesteps` columns;
# fail here with a clear message instead of inside np.random.shuffle.
if transactions is None:
    raise ValueError(
        "Loaded series have %d columns, fewer than timesteps=%d"
        % (raw_transactions.shape[1], timesteps))
# Shuffle the windows in place along the first axis.
np.random.shuffle(transactions)

In [226]:
# Disabled alternative data source, kept for reference: MNIST images cropped
# with [3:-4, 3:-4], flattened to 21*21 values and rescaled from [0, 255] to
# [-1, 1]. Un-commenting it overrides `timesteps` and `transactions`.
# from keras.datasets import mnist

# timesteps = 21*21

# (x_train, y_train), (x_test, y_test) = mnist.load_data()
# x_train = x_train[:, 3:-4, 3:-4]
# x_train = x_train.reshape(60000, timesteps)
# transactions = (x_train / 255.0) * 2.0 - 1.0

In [227]:
# --- Model size ---
latent_dim = 2            # size of the latent vector

# --- Optimisation ---
generator_lr = 5e-4       # RMSprop learning rate of the combined model
critic_lr = 5e-5          # RMSprop learning rate of the critic
clip_value = 0.05         # critic weights are clipped to [-clip_value, clip_value]

# --- Training schedule ---
batch_size = 64
epochs = 500000
n_generator = 5           # generator updates per epoch
n_critic = 1              # critic updates per epoch

# --- Output cadence (in epochs) and size ---
img_frequency = 250
model_save_frequency = 3000
dataset_generation_frequency = 25000
dataset_generation_size = 100000

In [228]:
# Every run gets its own timestamped folder under ./wgan.
root_path = Path('wgan')
if not root_path.exists():
    root_path.mkdir()

current_datetime = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
run_dir = root_path / current_datetime

# Sub-folders for figures, saved models and generated datasets.
img_dir, model_dir, generated_datesets_dir = (
    run_dir / sub_name for sub_name in ('img', 'models', 'generated_datasets'))

# No exist_ok: a second run within the same second fails rather than
# silently sharing a folder.
for new_dir in (img_dir, model_dir, generated_datesets_dir):
    new_dir.mkdir(parents=True)

In [229]:
# Instantiate the model wrapper; this also writes config.p into run_dir.
wgan = WGAN_VAE(
    timesteps=timesteps,
    latent_dim=latent_dim,
    run_dir=run_dir,
    img_dir=img_dir,
    model_dir=model_dir,
    generated_datesets_dir=generated_datesets_dir,
    batch_size=batch_size,
)
gan, generator, critic = wgan.build_models(generator_lr=generator_lr, critic_lr=critic_lr)

# Long-running: trains until `epochs` is reached (or the kernel is interrupted).
losses = wgan.train(
    batch_size=batch_size,
    epochs=epochs,
    n_generator=n_generator,
    n_critic=n_critic,
    dataset=transactions,
    clip_value=clip_value,
    img_frequency=img_frequency,
    model_save_frequency=model_save_frequency,
    dataset_generation_frequency=dataset_generation_frequency,
    dataset_generation_size=dataset_generation_size,
)

  'Discrepancy between trainable weights and collected trainable'


KeyboardInterrupt: 