In [80]:
%pylab inline

import keras.backend as K
from keras import Input, Model, Sequential
from keras.layers import *
from keras.optimizers import Adam, RMSprop
from keras.callbacks import *
from keras.engine import InputSpec, Layer
from keras import initializers, regularizers, constraints
from keras.models import load_model

import tensorflow as tf
from matplotlib import pyplot as plt

from datetime import datetime
from pathlib import Path

import pickle

Populating the interactive namespace from numpy and matplotlib


`%matplotlib` prevents importing * from pylab and numpy
  "\n`%matplotlib` prevents importing * from pylab and numpy"


In [81]:
def split_data(dataset, timesteps):
    D = dataset.shape[1]
    if D < timesteps:
        return None
    elif D == timesteps:
        return dataset
    else:
        splitted_data, remaining_data = np.hsplit(dataset, [timesteps])
        remaining_data = split_data(remaining_data, timesteps)
        if remaining_data is not None:
            return np.vstack([splitted_data, remaining_data])
        return splitted_data
    
class MinibatchDiscrimination(Layer):
    """Concatenates to each sample information about how different the input
    features for that sample are from features of other samples in the same
    minibatch, as described in Salimans et. al. (2016). Useful for preventing
    GANs from collapsing to a single output. When using this layer, generated
    samples and reference samples should be in separate batches."""

    def __init__(self, nb_kernels, kernel_dim, init='glorot_uniform', weights=None,
                 W_regularizer=None, activity_regularizer=None,
                 W_constraint=None, input_dim=None, **kwargs):
        self.init = initializers.get(init)
        self.nb_kernels = nb_kernels
        self.kernel_dim = kernel_dim
        self.input_dim = input_dim

        self.W_regularizer = regularizers.get(W_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)

        self.W_constraint = constraints.get(W_constraint)

        self.initial_weights = weights
        self.input_spec = [InputSpec(ndim=2)]

        if self.input_dim:
            kwargs['input_shape'] = (self.input_dim,)
        super(MinibatchDiscrimination, self).__init__(**kwargs)

    def build(self, input_shape):
        assert len(input_shape) == 2

        input_dim = input_shape[1]
        self.input_spec = [InputSpec(dtype=K.floatx(),
                                     shape=(None, input_dim))]

        self.W = self.add_weight(shape=(self.nb_kernels, input_dim, self.kernel_dim),
            initializer=self.init,
            name='kernel',
            regularizer=self.W_regularizer,
            trainable=True,
            constraint=self.W_constraint)

        # Set built to true.
        super(MinibatchDiscrimination, self).build(input_shape)

    def call(self, x, mask=None):
        activation = K.reshape(K.dot(x, self.W), (-1, self.nb_kernels, self.kernel_dim))
        diffs = K.expand_dims(activation, 3) - K.expand_dims(K.permute_dimensions(activation, [1, 2, 0]), 0)
        abs_diffs = K.sum(K.abs(diffs), axis=2)
        minibatch_features = K.sum(K.exp(-abs_diffs), axis=2)
        return K.concatenate([x, minibatch_features], 1)
#         return minibatch_features

    def compute_output_shape(self, input_shape):
        assert input_shape and len(input_shape) == 2
        return input_shape[0], input_shape[1]+self.nb_kernels
#         return input_shape[0], self.nb_kernels

    def get_config(self):
        config = {'nb_kernels': self.nb_kernels,
                  'kernel_dim': self.kernel_dim,
#                   'init': self.init.__name__,
                  'W_regularizer': self.W_regularizer.get_config() if self.W_regularizer else None,
                  'activity_regularizer': self.activity_regularizer.get_config() if self.activity_regularizer else None,
                  'W_constraint': self.W_constraint.get_config() if self.W_constraint else None,
                  'input_dim': self.input_dim}
        base_config = super(MinibatchDiscrimination, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
    
class RandomWeightedAverage(_Merge):
    """Takes a randomly-weighted average of two tensors. In geometric terms, this outputs a random point on the line
    between each pair of input points.
    Inheriting from _Merge is a little messy but it was the quickest solution I could think of.
    Improvements appreciated."""

    def _merge_function(self, inputs):
        weights = K.random_uniform((BATCH_SIZE, 1, 1, 1))
        return (weights * inputs[0]) + ((1 - weights) * inputs[1])

In [95]:
class WGAN:
    def __init__(self, timesteps, latent_dim, run_dir, img_dir, model_dir, generated_datesets_dir):
        self._timesteps = timesteps
        self._latent_dim = latent_dim
        self._run_dir = run_dir
        self._img_dir = img_dir
        self._model_dir = model_dir
        self._generated_datesets_dir = generated_datesets_dir
        
        self._save_config()
        
        self._epoch = 0
        self._losses = [[], []]

    def build_models(self, generator_lr, critic_lr):        
        self._generator = self._build_generator(self._latent_dim, self._timesteps)
        self._critic = self._build_critic(self._timesteps)
        
        
        
        
        self._critic.compile(loss=self._wasserstein_loss, optimizer=RMSprop(critic_lr))

        z = Input(shape=(self._latent_dim, ))
        fake = self._generator(z)

        self._critic.trainable = False

        valid = self._critic(fake)

        self._gan = Model(z, valid, 'GAN')

        self._gan.compile(
            loss=self._wasserstein_loss,
            optimizer=RMSprop(generator_lr),
            metrics=['accuracy'])
        
        return self._gan, self._generator, self._critic

    def _build_generator(self, noise_dim, timesteps):
        generator_inputs = Input((latent_dim, ))
        generated = generator_inputs
        
        generated = Lambda(lambda x: K.expand_dims(x))(generated)
        while generated.shape[1] < timesteps:
            generated = Conv1D(
                32, 3, activation='relu', padding='same')(generated)
            generated = UpSampling1D(2)(generated)
        generated = Conv1D(
            1, 3, activation='relu', padding='same')(generated)
        generated = Lambda(lambda x: K.squeeze(x, -1))(generated)
        generated = Dense(timesteps, activation='tanh')(generated)

        generator = Model(generator_inputs, generated, 'generator')
        return generator

    def _build_critic(self, timesteps):
        critic_inputs = Input((timesteps, ))
        criticized = MinibatchDiscrimination(5, 3)(critic_inputs)
        
        criticized = Lambda(lambda x: K.expand_dims(x))(
            criticized)
        while criticized.shape[1] > 1:
            criticized = Conv1D(
                32, 3, activation='tanh', padding='same')(criticized)
            criticized = MaxPooling1D(2, padding='same')(criticized)
        criticized = Flatten()(criticized)
#         criticized = Concatenate()([criticized, critic_inputs])
        criticized = Dense(32, activation='tanh')(criticized)
        criticized = Dense(15, activation='tanh')(criticized)
        criticized = Dense(1)(criticized) 

        critic = Model(critic_inputs, criticized, 'critic')
        return critic

    def train(self, batch_size, epochs, n_generator, n_critic, dataset, clip_value,
           img_frequency, model_save_frequency, dataset_generation_frequency, dataset_generation_size):
        half_batch = int(batch_size / 2)

        
        while self._epoch < epochs:
            self._epoch += 1
            for _ in range(n_critic):
                indexes = np.random.randint(0, dataset.shape[0], half_batch)
                batch_transactions = dataset[indexes]

                noise = np.random.normal(0, 1, (half_batch, self._latent_dim))

                generated_transactions = self._generator.predict(noise)

                critic_loss_real = self._critic.train_on_batch(
                    batch_transactions, -np.ones((half_batch, 1)))
                critic_loss_fake = self._critic.train_on_batch(
                    generated_transactions, np.ones((half_batch, 1)))
                critic_loss = 0.5 * np.add(critic_loss_real,
                                                  critic_loss_fake)

#                 self._clip_weights(clip_value)

            for _ in range(n_generator):
                noise = np.random.normal(0, 1, (batch_size, latent_dim))

                generator_loss = self._gan.train_on_batch(
                    noise, -np.ones((batch_size, 1)))[0]
            
            generator_loss = 1 - generator_loss
            critic_loss = 1 - critic_loss
            
            self._losses[0].append(generator_loss)
            self._losses[1].append(critic_loss)

            print("%d [C loss: %f] [G loss: %f]" % (self._epoch, critic_loss,
                                                    generator_loss))

            if self._epoch % img_frequency == 0:
                self._save_imgs()
            
            if self._epoch % model_save_frequency == 0:
                self._save_models()
                
            if self._epoch % dataset_generation_frequency == 0:
                self._generate_dataset(epoch, dataset_generation_size)
                
            if self._epoch % 250 == 0:
                self._save_losses()
          
        self._generate_dataset(epochs, dataset_generation_size)
        self._save_losses(losses)
        self._save_models()
        self._save_imgs()
        
        return losses
    
    def _save_imgs(self):
        rows, columns = 5, 5
        noise = np.random.normal(0, 1, (rows * columns, latent_dim))
        generated_transactions = self._generator.predict(noise)

        plt.subplots(rows, columns, figsize=(15, 5))
        k = 1
        for i in range(rows):
            for j in range(columns):
                plt.subplot(rows, columns, k)
                plt.plot(generated_transactions[k - 1])
                plt.xticks([])
                plt.yticks([])
                plt.ylim(-1, 1)
                k += 1
        plt.tight_layout()
        plt.savefig(str(self._img_dir / ('%05d.png' % self._epoch)))
        plt.savefig(str(self._img_dir / 'last.png'))
        plt.clf()
        plt.close()
        
        if self._latent_dim == 2:
            plt.subplots(5, 5, figsize=(15, 5))

            for i, v_i in enumerate(np.linspace(-2, 2, 5, True)):
                for j, v_j in enumerate(np.linspace(-2, 2, 5, True)):
                    plt.subplot(5, 5, i*5+j+1)
                    plt.plot(self._generator.predict(np.array([v_i, v_j]).reshape((1, 2))).T)
                    plt.xticks([])
                    plt.yticks([])
                    plt.ylim(-1, 1)
            plt.tight_layout()
            plt.savefig(str(self._img_dir / ('latent_space.png')))
            plt.clf()
            plt.close()

    def _save_losses(self):
        plt.figure(figsize=(15, 3))
        plt.plot(self._losses[0])
        plt.plot(self._losses[1])
        plt.legend(['generator', 'critic'])
        plt.savefig(str(self._img_dir / 'losses.png'))
        plt.clf()
        plt.close()
        
        with open(str(self._run_dir / 'losses.p'), 'wb') as f:
            pickle.dump(self._losses, f)
        
    def _clip_weights(self, clip_value):
        for l in self._critic.layers:
#             if 'minibatch_discrimination' not in l.name:
            weights = [np.clip(w, -clip_value, clip_value) for w in l.get_weights()]
            l.set_weights(weights)

    def _save_config(self):
        config = {
            'timesteps' : self._timesteps,
            'latent_dim' : self._latent_dim,
            'run_dir' : self._run_dir,
            'img_dir' : self._img_dir,
            'model_dir' : self._model_dir,
            'generated_datesets_dir' : self._generated_datesets_dir
        }
        
        with open(str(self._run_dir / 'config.p'), 'wb') as f:
            pickle.dump(config, f)
        
    def _save_models(self):
        self._gan.save(self._model_dir / 'wgan.h5')
        self._generator.save(self._model_dir / 'generator.h5')
        self._critic.save(self._model_dir / 'critic.h5')
        
    def _generate_dataset(self, epoch, dataset_generation_size):
        z_samples = np.random.normal(0, 1, (dataset_generation_size, self._latent_dim))
        generated_dataset = self._generator.predict(z_samples)
        np.save(self._generated_datesets_dir / ('%d_generated_data' % epoch), generated_dataset)
        np.save(self._generated_datesets_dir / 'last', generated_dataset)
        
    def get_models(self):
        return self._gan, self._generator, self._critic
    
    @staticmethod
    def _wasserstein_loss(y_true, y_pred):
        return K.mean(y_true * y_pred)

    def restore_training(self):
        load_models()
        with open(self._run_dir / 'losses.p', 'rb') as f:
            self._losses = pickle.load(f)
            self._epoch = len(self._losses)
        
    def gradient_penalty_loss(y_true, y_pred, averaged_samples, gradient_penalty_weight):
        """Calculates the gradient penalty loss for a batch of "averaged" samples.
        In Improved WGANs, the 1-Lipschitz constraint is enforced by adding a term to the loss function
        that penalizes the network if the gradient norm moves away from 1. However, it is impossible to evaluate
        this function at all points in the input space. The compromise used in the paper is to choose random points
        on the lines between real and generated samples, and check the gradients at these points. Note that it is the
        gradient w.r.t. the input averaged samples, not the weights of the discriminator, that we're penalizing!
        In order to evaluate the gradients, we must first run samples through the generator and evaluate the loss.
        Then we get the gradients of the discriminator w.r.t. the input averaged samples.
        The l2 norm and penalty can then be calculated for this gradient.
        Note that this loss function requires the original averaged samples as input, but Keras only supports passing
        y_true and y_pred to loss functions. To get around this, we make a partial() of the function with the
        averaged_samples argument, and use that for model training."""
        # first get the gradients:
        #   assuming: - that y_pred has dimensions (batch_size, 1)
        #             - averaged_samples has dimensions (batch_size, nbr_features)
        # gradients afterwards has dimension (batch_size, nbr_features), basically
        # a list of nbr_features-dimensional gradient vectors
        gradients = K.gradients(y_pred, averaged_samples)[0]
        # compute the euclidean norm by squaring ...
        gradients_sqr = K.square(gradients)
        #   ... summing over the rows ...
        gradients_sqr_sum = K.sum(gradients_sqr,
                                  axis=np.arange(1, len(gradients_sqr.shape)))
        #   ... and sqrt
        gradient_l2_norm = K.sqrt(gradients_sqr_sum)
        # compute lambda * (1 - ||grad||)^2 still for each single sample
        gradient_penalty = gradient_penalty_weight * K.square(1 - gradient_l2_norm)
        # return the mean as loss over all the batch samples
        return K.mean(gradient_penalty)

    def load_models(self):
        self._gan = load_model(self._model_dir / 'wgan.h5')
        self._generator = load_model(self._model_dir / 'generator.h5')
        self._critic = load_model(self._model_dir / 'critic.h5')

In [96]:
normalized_transactions_filepath = "../datasets/berka_dataset/usable/normalized_transactions.npy"

timesteps = 100
transactions = np.load(normalized_transactions_filepath)
transactions = split_data(transactions, timesteps)
transactions = transactions[np.std(transactions, 1) > float(1e-7)]
N, D = transactions.shape
print(transactions.shape)

(53888, 100)


In [97]:
batch_size = 64
epochs = 50000
n_critic = 10
n_generator = 1
latent_dim = 2
generator_lr = 0.00005
critic_lr = 0.00005
clip_value = 0.05
img_frequency = 250
model_save_frequency = 3000
dataset_generation_frequency = 25000
dataset_generation_size = 100000

In [98]:
root_path = Path('wgan')
if not root_path.exists():
    root_path.mkdir()
    
current_datetime = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

run_dir = root_path / current_datetime
img_dir = run_dir / 'img'
model_dir = run_dir / 'models'
generated_datesets_dir = run_dir / 'generated_datasets'

img_dir.mkdir(parents=True)
model_dir.mkdir(parents=True)
generated_datesets_dir.mkdir(parents=True)

In [None]:
wgan = WGAN(timesteps, latent_dim, run_dir, img_dir, model_dir, generated_datesets_dir)
gan, generator, critic = wgan.build_models(generator_lr, critic_lr)

# gan.summary()
# generator.summary()
# critic.summary()
        
losses = wgan.train(batch_size, epochs, n_generator, n_critic, transactions, clip_value,
           img_frequency, model_save_frequency, dataset_generation_frequency, dataset_generation_size)

  'Discrepancy between trainable weights and collected trainable'


1 [D loss: 1.033217] [G loss: 1.095408]
2 [D loss: 1.064658] [G loss: 1.074012]
3 [D loss: 1.091514] [G loss: 1.047046]
4 [D loss: 1.132359] [G loss: 1.017008]
5 [D loss: 1.175182] [G loss: 0.982249]
6 [D loss: 1.229404] [G loss: 0.940496]
7 [D loss: 1.318202] [G loss: 0.885099]
8 [D loss: 1.389257] [G loss: 0.822847]
9 [D loss: 1.514111] [G loss: 0.754643]
10 [D loss: 1.599586] [G loss: 0.666971]
11 [D loss: 1.736487] [G loss: 0.558934]
12 [D loss: 1.874529] [G loss: 0.427570]
13 [D loss: 2.048684] [G loss: 0.277749]
14 [D loss: 2.197595] [G loss: 0.110656]
15 [D loss: 2.393756] [G loss: -0.081693]
16 [D loss: 2.548490] [G loss: -0.278859]
17 [D loss: 2.763986] [G loss: -0.492972]
18 [D loss: 2.952600] [G loss: -0.702622]
19 [D loss: 3.133370] [G loss: -0.918437]
20 [D loss: 3.336235] [G loss: -1.137592]
21 [D loss: 3.493602] [G loss: -1.330008]
22 [D loss: 3.650474] [G loss: -1.509967]
23 [D loss: 3.782592] [G loss: -1.659158]
24 [D loss: 3.901990] [G loss: -1.799540]
25 [D loss: 4.0

195 [D loss: 7.905759] [G loss: -5.906840]
196 [D loss: 7.920768] [G loss: -5.921849]
197 [D loss: 7.935776] [G loss: -5.936857]
198 [D loss: 7.950784] [G loss: -5.951865]
199 [D loss: 7.965791] [G loss: -5.966872]
200 [D loss: 7.980798] [G loss: -5.981879]
201 [D loss: 7.995806] [G loss: -5.996886]
202 [D loss: 8.010811] [G loss: -6.011893]
203 [D loss: 8.025818] [G loss: -6.026899]
204 [D loss: 8.040824] [G loss: -6.041905]
205 [D loss: 8.055830] [G loss: -6.056911]
206 [D loss: 8.070836] [G loss: -6.071918]
207 [D loss: 8.085842] [G loss: -6.086923]
208 [D loss: 8.100848] [G loss: -6.101928]
209 [D loss: 8.115853] [G loss: -6.116934]
210 [D loss: 8.130858] [G loss: -6.131939]
211 [D loss: 8.145863] [G loss: -6.146943]
212 [D loss: 8.160868] [G loss: -6.161949]
213 [D loss: 8.175872] [G loss: -6.176954]
214 [D loss: 8.190877] [G loss: -6.191958]
215 [D loss: 8.205882] [G loss: -6.206962]
216 [D loss: 8.220885] [G loss: -6.221967]
217 [D loss: 8.235889] [G loss: -6.236970]
218 [D loss

385 [D loss: 10.756337] [G loss: -8.757419]
386 [D loss: 10.771341] [G loss: -8.772422]
387 [D loss: 10.786343] [G loss: -8.787423]
388 [D loss: 10.801346] [G loss: -8.802426]
389 [D loss: 10.816347] [G loss: -8.817429]
390 [D loss: 10.831350] [G loss: -8.832432]
391 [D loss: 10.846354] [G loss: -8.847433]
392 [D loss: 10.861356] [G loss: -8.862436]
393 [D loss: 10.876357] [G loss: -8.877439]
394 [D loss: 10.891361] [G loss: -8.892442]
395 [D loss: 10.906363] [G loss: -8.907443]
396 [D loss: 10.921366] [G loss: -8.922446]
397 [D loss: 10.936367] [G loss: -8.937449]
398 [D loss: 10.951370] [G loss: -8.952452]
399 [D loss: 10.966373] [G loss: -8.967453]
400 [D loss: 10.981376] [G loss: -8.982456]
401 [D loss: 10.996377] [G loss: -8.997458]
402 [D loss: 11.011381] [G loss: -9.012462]
403 [D loss: 11.026382] [G loss: -9.027463]
404 [D loss: 11.041386] [G loss: -9.042466]
405 [D loss: 11.056387] [G loss: -9.057468]
406 [D loss: 11.071390] [G loss: -9.072472]
407 [D loss: 11.086393] [G loss:

569 [D loss: 13.516371] [G loss: -11.517451]
570 [D loss: 13.531369] [G loss: -11.532451]
571 [D loss: 13.546367] [G loss: -11.547445]
572 [D loss: 13.561364] [G loss: -11.562446]
573 [D loss: 13.576361] [G loss: -11.577442]
574 [D loss: 13.591360] [G loss: -11.592442]
575 [D loss: 13.606358] [G loss: -11.607437]
576 [D loss: 13.621355] [G loss: -11.622437]
577 [D loss: 13.636353] [G loss: -11.637433]
578 [D loss: 13.651351] [G loss: -11.652432]
579 [D loss: 13.666348] [G loss: -11.667427]
580 [D loss: 13.681346] [G loss: -11.682427]
581 [D loss: 13.696343] [G loss: -11.697424]
582 [D loss: 13.711342] [G loss: -11.712423]
583 [D loss: 13.726339] [G loss: -11.727419]
584 [D loss: 13.741337] [G loss: -11.742418]
585 [D loss: 13.756334] [G loss: -11.757415]
586 [D loss: 13.771333] [G loss: -11.772414]
587 [D loss: 13.786330] [G loss: -11.787409]
588 [D loss: 13.801328] [G loss: -11.802409]
589 [D loss: 13.816325] [G loss: -11.817406]
590 [D loss: 13.831324] [G loss: -11.832405]
591 [D los

753 [D loss: 16.275593] [G loss: -14.276672]
754 [D loss: 16.290586] [G loss: -14.291665]
755 [D loss: 16.305580] [G loss: -14.306660]
756 [D loss: 16.320574] [G loss: -14.321654]
757 [D loss: 16.335567] [G loss: -14.336649]
758 [D loss: 16.350563] [G loss: -14.351642]
759 [D loss: 16.365557] [G loss: -14.366636]
760 [D loss: 16.380550] [G loss: -14.381631]
761 [D loss: 16.395546] [G loss: -14.396626]
762 [D loss: 16.410540] [G loss: -14.411618]
763 [D loss: 16.425533] [G loss: -14.426613]
764 [D loss: 16.440527] [G loss: -14.441607]
765 [D loss: 16.455521] [G loss: -14.456602]
766 [D loss: 16.470516] [G loss: -14.471595]
767 [D loss: 16.485510] [G loss: -14.486589]
768 [D loss: 16.500504] [G loss: -14.501584]
769 [D loss: 16.515499] [G loss: -14.516579]
770 [D loss: 16.530493] [G loss: -14.531571]
771 [D loss: 16.545486] [G loss: -14.546566]
772 [D loss: 16.560480] [G loss: -14.561561]
773 [D loss: 16.575474] [G loss: -14.576555]
774 [D loss: 16.590469] [G loss: -14.591548]
775 [D los