In [1]:
%pylab inline

import keras.backend as K
from keras import Input, Model, Sequential
from keras.layers import *
from keras.optimizers import Adam, RMSprop
from keras.callbacks import *
from keras.engine import InputSpec, Layer
from keras import initializers, regularizers, constraints
from keras.models import load_model

import tensorflow as tf
from matplotlib import pyplot as plt

from datetime import datetime
from pathlib import Path

import pickle

Populating the interactive namespace from numpy and matplotlib


  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
def split_data(dataset, timesteps):
    """Cut every row of a 2-D array into consecutive, non-overlapping windows.

    The columns of ``dataset`` are split into chunks of ``timesteps`` columns and
    the chunks are stacked vertically: first window of every row, then second
    window of every row, and so on. Trailing columns that do not fill a complete
    window are dropped.

    Args:
        dataset: 2-D array of shape ``(rows, columns)``.
        timesteps: window length (columns per output row); must be positive.

    Returns:
        ``dataset`` itself when ``columns == timesteps``; ``None`` when
        ``columns < timesteps``; otherwise an array of shape
        ``(rows * (columns // timesteps), timesteps)``.

    Raises:
        ValueError: if ``timesteps`` is not positive.
    """
    if timesteps <= 0:
        raise ValueError('timesteps must be positive, got %r' % (timesteps,))

    n_columns = dataset.shape[1]
    n_windows = n_columns // timesteps
    if n_windows == 0:
        return None
    if n_columns == timesteps:
        return dataset

    # One hsplit + one vstack instead of recursion: avoids re-copying the
    # partial result at every level and Python's recursion limit when the
    # input holds many windows.
    windows = np.hsplit(dataset[:, :n_windows * timesteps], n_windows)
    return np.vstack(windows)
    
class MinibatchDiscrimination(Layer):
    """Minibatch-discrimination features, as described in Salimans et al. (2016).

    Each input vector is projected through a learned tensor into ``nb_kernels``
    rows of ``kernel_dim`` values. For every kernel, the layer then sums
    ``exp(-L1 distance)`` between that sample's row and the corresponding row of
    every sample in the minibatch (the sample itself included). Samples that
    look alike therefore receive large values, which helps the downstream
    network notice a collapsed generator.

    The input is NOT concatenated to the output: the layer returns only the
    ``(batch, nb_kernels)`` feature matrix. Generated and reference samples
    should be fed in separate batches.

    Args:
        nb_kernels: number of discrimination kernels (= number of output features).
        kernel_dim: size of each kernel's projection.
        init: initializer for the projection tensor ``W``.
        weights: optional list of initial weight arrays, applied in ``build``.
        W_regularizer: optional regularizer for ``W``.
        activity_regularizer: optional regularizer on the layer output.
        W_constraint: optional constraint on ``W``.
        input_dim: optional; when given, sets ``input_shape=(input_dim,)``.
    """

    def __init__(self, nb_kernels, kernel_dim, init='glorot_uniform', weights=None,
                 W_regularizer=None, activity_regularizer=None,
                 W_constraint=None, input_dim=None, **kwargs):
        self.init = initializers.get(init)
        self.nb_kernels = nb_kernels
        self.kernel_dim = kernel_dim
        self.input_dim = input_dim

        self.W_regularizer = regularizers.get(W_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)

        self.W_constraint = constraints.get(W_constraint)

        # Applied in build(), once W exists.
        self.initial_weights = weights
        self.input_spec = [InputSpec(ndim=2)]

        if self.input_dim:
            kwargs['input_shape'] = (self.input_dim,)
        super(MinibatchDiscrimination, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the ``(nb_kernels, input_dim, kernel_dim)`` projection tensor ``W``."""
        assert len(input_shape) == 2

        input_dim = input_shape[1]
        self.input_spec = [InputSpec(dtype=K.floatx(),
                                     shape=(None, input_dim))]

        self.W = self.add_weight(shape=(self.nb_kernels, input_dim, self.kernel_dim),
                                 initializer=self.init,
                                 name='kernel',
                                 regularizer=self.W_regularizer,
                                 trainable=True,
                                 constraint=self.W_constraint)

        # Marks the layer as built.
        super(MinibatchDiscrimination, self).build(input_shape)

        # Honour weights passed to the constructor now that W exists.
        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            self.initial_weights = None

    def call(self, x, mask=None):
        # Notation: b = batch size, k = nb_kernels, c = kernel_dim.
        # (b, input_dim) . (k, input_dim, c) -> (b, k, c)
        activation = K.reshape(K.dot(x, self.W), (-1, self.nb_kernels, self.kernel_dim))
        # (b, k, c, 1) - (1, k, c, b) -> (b, k, c, b): pairwise differences between samples.
        diffs = K.expand_dims(activation, 3) - K.expand_dims(K.permute_dimensions(activation, [1, 2, 0]), 0)
        # L1 distance over the kernel axis -> (b, k, b).
        abs_diffs = K.sum(K.abs(diffs), axis=2)
        # Similarity to every sample in the batch, summed -> (b, k).
        minibatch_features = K.sum(K.exp(-abs_diffs), axis=2)
        return minibatch_features

    def compute_output_shape(self, input_shape):
        """Output is ``(batch, nb_kernels)``; the input features are not passed through."""
        assert input_shape and len(input_shape) == 2
        return input_shape[0], self.nb_kernels

    def get_config(self):
        """Return a config that ``from_config`` / ``load_model`` can rebuild the layer from.

        The initializer, regularizers and constraint are stored with Keras's
        ``serialize`` helpers (class name + config). That is the format
        ``initializers.get`` / ``regularizers.get`` / ``constraints.get`` in
        ``__init__`` expect when deserializing.
        """
        config = {'nb_kernels': self.nb_kernels,
                  'kernel_dim': self.kernel_dim,
                  'init': initializers.serialize(self.init),
                  'W_regularizer': regularizers.serialize(self.W_regularizer) if self.W_regularizer else None,
                  'activity_regularizer': regularizers.serialize(self.activity_regularizer) if self.activity_regularizer else None,
                  'W_constraint': constraints.serialize(self.W_constraint) if self.W_constraint else None,
                  'input_dim': self.input_dim}
        base_config = super(MinibatchDiscrimination, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

In [20]:
class WGAN:
    """Wasserstein GAN with weight clipping for fixed-length 1-D sequences.

    Three Keras models share weights:

    * ``_generator`` maps a ``latent_dim`` noise vector to a ``timesteps``-long
      sequence (tanh output, so values lie in [-1, 1]).
    * ``_critic`` maps a ``timesteps``-long sequence to a scalar score (no
      output activation).
    * ``_gan`` is ``critic(generator(z))`` with the critic frozen; training it
      updates only the generator.

    Sample plots, model checkpoints, generated datasets, the loss history and
    the constructor config are written under the directories given to
    ``__init__``.
    """

    def __init__(self, timesteps, latent_dim, run_dir, img_dir, model_dir, generated_datesets_dir):
        """Store the configuration and immediately pickle it to ``run_dir/config.p``.

        Args:
            timesteps: length of every generated / critiqued sequence.
            latent_dim: size of the generator's noise input.
            run_dir, img_dir, model_dir, generated_datesets_dir: ``pathlib.Path``
                directories for run artifacts. ``run_dir`` must already exist
                because the config is written there right away.
        """
        self._timesteps = timesteps
        self._latent_dim = latent_dim
        self._run_dir = run_dir
        self._img_dir = img_dir
        self._model_dir = model_dir
        self._generated_datesets_dir = generated_datesets_dir

        self._save_config()

        # Training progress; restored by restore_training().
        self._epoch = 0
        self._losses = [[], []]  # [generator losses, critic losses], one entry per epoch

    def build_models(self, generator_lr, critic_lr):
        """Create fresh generator and critic networks and compile them for training.

        Args:
            generator_lr: RMSprop learning rate of the combined model (generator updates).
            critic_lr: RMSprop learning rate of the critic.

        Returns:
            ``(gan, generator, critic)`` Keras models.
        """
        self._generator = self._build_generator(self._latent_dim, self._timesteps)
        self._critic = self._build_critic(self._timesteps)
        self._compile_models(generator_lr, critic_lr)
        return self._gan, self._generator, self._critic

    def _compile_models(self, generator_lr, critic_lr):
        """Compile the critic, then build and compile the combined GAN around the current sub-models.

        The critic is compiled while trainable, so its own ``train_on_batch``
        updates its weights. It is frozen afterwards, so the combined model's
        ``train_on_batch`` only updates the generator. This freeze-after-compile
        pattern is what makes Keras print the "Discrepancy between trainable
        weights and collected trainable weights" warning; it is expected.
        Reusing ``self._generator`` and ``self._critic`` inside ``self._gan``
        makes all three models share weights.
        """
        self._set_critic_trainable(True)
        self._critic.compile(loss=self._wasserstein_loss, optimizer=RMSprop(critic_lr))

        z = Input(shape=(self._latent_dim, ))
        fake = self._generator(z)

        self._set_critic_trainable(False)
        valid = self._critic(fake)

        self._gan = Model(z, valid, 'GAN')
        self._gan.compile(
            loss=self._wasserstein_loss,
            optimizer=RMSprop(generator_lr),
            metrics=['accuracy'])

    def _set_critic_trainable(self, trainable):
        """Set ``trainable`` on the critic model and on each of its layers."""
        for layer in self._critic.layers:
            layer.trainable = trainable
        self._critic.trainable = trainable

    def _build_generator(self, noise_dim, timesteps):
        """Build the generator: ``(batch, noise_dim)`` noise -> ``(batch, timesteps)`` sequence.

        The noise vector is treated as a 1-channel sequence. It is repeatedly
        convolved and upsampled x2 until it is at least ``timesteps`` long,
        reduced to one channel, squeezed, and mapped by a tanh ``Dense`` layer
        to exactly ``timesteps`` values.
        """
        generator_inputs = Input((noise_dim, ))

        # (batch, noise_dim) -> (batch, noise_dim, 1) so Conv1D sees a sequence.
        generated = Lambda(lambda x: K.expand_dims(x))(generator_inputs)
        while generated.shape[1] < timesteps:
            generated = Conv1D(
                32, 3, activation='relu', padding='same')(generated)
            generated = UpSampling1D(2)(generated)
        generated = Conv1D(
            1, 3, activation='relu', padding='same')(generated)
        generated = Lambda(lambda x: K.squeeze(x, -1))(generated)
        generated = Dense(timesteps, activation='tanh')(generated)

        return Model(generator_inputs, generated, 'generator')

    def _build_critic(self, timesteps):
        """Build the critic: ``(batch, timesteps)`` sequence -> ``(batch, 1)`` score.

        Minibatch-discrimination features computed on the raw input are
        concatenated with a conv / max-pool feature extractor (pooled x2 until
        the sequence length reaches 1), followed by a small dense head whose
        last layer has no activation.
        """
        critic_inputs = Input((timesteps, ))
        criticized = critic_inputs

        mbd = MinibatchDiscrimination(5, 3)(criticized)

        criticized = Lambda(lambda x: K.expand_dims(x))(
            criticized)
        while criticized.shape[1] > 1:
            criticized = Conv1D(
                32, 3, activation='relu', padding='same')(criticized)
            criticized = MaxPooling1D(2, padding='same')(criticized)
        criticized = Flatten()(criticized)
        criticized = Concatenate()([criticized, mbd])
        criticized = Dense(32, activation='relu')(criticized)
        criticized = Dense(15, activation='relu')(criticized)
        criticized = Dense(1)(criticized)

        return Model(critic_inputs, criticized, 'critic')

    def train(self, batch_size, epochs, n_generator, n_critic, dataset, clip_value,
              img_frequency, model_save_frequency, dataset_generation_frequency,
              dataset_generation_size, loss_save_frequency=250):
        """Train until ``self._epoch`` reaches ``epochs``, resuming from the current epoch.

        Each epoch is one WGAN iteration. First come ``n_critic`` critic
        updates, each on half a batch of real and half a batch of generated
        sequences, with critic weights clipped after every update. Then come
        ``n_generator`` generator updates on full batches of noise.

        Args:
            batch_size: generator batch size; each critic update uses half of
                it as real and half as generated samples.
            epochs: total number of epochs to reach (not additional ones).
            n_generator: generator updates per epoch (must be >= 1).
            n_critic: critic updates per epoch (must be >= 1).
            dataset: 2-D array of real sequences, one per row; rows are sampled
                with replacement.
            clip_value: critic weights are clipped to ``[-clip_value, clip_value]``.
            img_frequency: epoch period for sample / latent-space plots.
            model_save_frequency: epoch period for model checkpoints.
            dataset_generation_frequency: epoch period for dumping a generated dataset.
            dataset_generation_size: number of sequences per generated dataset.
            loss_save_frequency: epoch period for saving the loss history.

        Returns:
            ``[generator_losses, critic_losses]`` -- the running loss history
            (``self._losses``).
        """
        half_batch = batch_size // 2

        # With loss mean(y_true * y_pred): label -1 pushes scores of real data
        # up, label +1 pushes scores of generated data down, and the generator
        # (label -1 through the frozen critic) pushes scores of its samples up.
        real_labels = -np.ones((half_batch, 1))
        fake_labels = np.ones((half_batch, 1))
        generator_labels = -np.ones((batch_size, 1))

        while self._epoch < epochs:
            self._epoch += 1
            for _ in range(n_critic):
                indexes = np.random.randint(0, dataset.shape[0], half_batch)
                batch_transactions = dataset[indexes]

                noise = np.random.normal(0, 1, (half_batch, self._latent_dim))
                generated_transactions = self._generator.predict(noise)

                critic_loss_real = self._critic.train_on_batch(
                    batch_transactions, real_labels)
                critic_loss_fake = self._critic.train_on_batch(
                    generated_transactions, fake_labels)
                critic_loss = 0.5 * np.add(critic_loss_real, critic_loss_fake)

                self._clip_weights(clip_value)

            for _ in range(n_generator):
                noise = np.random.normal(0, 1, (batch_size, self._latent_dim))
                # [0] selects the loss; the GAN is compiled with an extra metric.
                generator_loss = self._gan.train_on_batch(
                    noise, generator_labels)[0]

            # Recorded/printed values are offset as 1 - loss; training is unaffected.
            generator_loss = 1 - generator_loss
            critic_loss = 1 - critic_loss

            self._losses[0].append(generator_loss)
            self._losses[1].append(critic_loss)

            print("%d [C loss: %f] [G loss: %f]" % (self._epoch, critic_loss,
                                                    generator_loss))

            if self._epoch % img_frequency == 0:
                self._save_imgs()
                self._save_latent_space()

            if self._epoch % model_save_frequency == 0:
                self._save_models()

            if self._epoch % dataset_generation_frequency == 0:
                self._generate_dataset(self._epoch, dataset_generation_size)

            if self._epoch % loss_save_frequency == 0:
                self._save_losses()

        self._generate_dataset(epochs, dataset_generation_size)
        self._save_losses()
        self._save_models()
        self._save_imgs()
        self._save_latent_space()

        return self._losses

    def _save_imgs(self):
        """Plot a 5x5 grid of sequences generated from random noise.

        Saved as ``img_dir/<epoch, 5 digits>.png`` and ``img_dir/last.png``.
        """
        rows, columns = 5, 5
        noise = np.random.normal(0, 1, (rows * columns, self._latent_dim))
        generated_transactions = self._generator.predict(noise)

        plt.subplots(rows, columns, figsize=(15, 5))
        k = 1
        for i in range(rows):
            for j in range(columns):
                plt.subplot(rows, columns, k)
                plt.plot(generated_transactions[k - 1])
                plt.xticks([])
                plt.yticks([])
                plt.ylim(-1, 1)
                k += 1
        plt.tight_layout()
        plt.savefig(str(self._img_dir / ('%05d.png' % self._epoch)))
        plt.savefig(str(self._img_dir / 'last.png'))
        plt.clf()
        plt.close()

    def _save_latent_space(self):
        """Plot generator outputs over a 5x5 grid of the last two latent coordinates.

        Both coordinates sweep ``linspace(-2, 2, 5)``. When ``latent_dim > 2``,
        the other coordinates are fixed to one random draw. Assumes
        ``latent_dim >= 2``. Saved as ``img_dir/latent_space.png``.
        """
        if self._latent_dim > 2:
            latent_vector = np.random.normal(0, 1, self._latent_dim)
        plt.subplots(5, 5, figsize=(15, 5))

        for i, v_i in enumerate(np.linspace(-2, 2, 5, True)):
            for j, v_j in enumerate(np.linspace(-2, 2, 5, True)):
                if self._latent_dim > 2:
                    latent_vector[-2:] = [v_i, v_j]
                else:
                    latent_vector = np.array([v_i, v_j])

                plt.subplot(5, 5, i * 5 + j + 1)
                plt.plot(self._generator.predict(latent_vector.reshape((1, self._latent_dim))).T)
                plt.xticks([])
                plt.yticks([])
                plt.ylim(-1, 1)
        plt.tight_layout()
        plt.savefig(str(self._img_dir / ('latent_space.png')))
        plt.clf()
        plt.close()

    def _save_losses(self):
        """Plot the loss history to ``img_dir/losses.png`` and pickle it to ``run_dir/losses.p``."""
        plt.figure(figsize=(15, 3))
        plt.plot(self._losses[0])
        plt.plot(self._losses[1])
        plt.legend(['generator', 'critic'])
        plt.savefig(str(self._img_dir / 'losses.png'))
        plt.clf()
        plt.close()

        with open(str(self._run_dir / 'losses.p'), 'wb') as f:
            pickle.dump(self._losses, f)

    def _clip_weights(self, clip_value):
        """Clip every critic weight (minibatch-discrimination layer included) to ``[-clip_value, clip_value]``."""
        for l in self._critic.layers:
            weights = [np.clip(w, -clip_value, clip_value) for w in l.get_weights()]
            l.set_weights(weights)

    def _save_config(self):
        """Pickle the constructor arguments to ``run_dir/config.p``."""
        config = {
            'timesteps' : self._timesteps,
            'latent_dim' : self._latent_dim,
            'run_dir' : self._run_dir,
            'img_dir' : self._img_dir,
            'model_dir' : self._model_dir,
            'generated_datesets_dir' : self._generated_datesets_dir
        }

        with open(str(self._run_dir / 'config.p'), 'wb') as f:
            pickle.dump(config, f)

    def _save_models(self):
        """Save the three models to ``model_dir`` as HDF5 files."""
        # str(): older Keras/h5py releases only accept string paths.
        self._gan.save(str(self._model_dir / 'wgan.h5'))
        self._generator.save(str(self._model_dir / 'generator.h5'))
        self._critic.save(str(self._model_dir / 'critic.h5'))

    def _generate_dataset(self, epoch, dataset_generation_size):
        """Generate ``dataset_generation_size`` sequences and save them as ``<epoch>_generated_data.npy`` and ``last.npy``."""
        z_samples = np.random.normal(0, 1, (dataset_generation_size, self._latent_dim))
        generated_dataset = self._generator.predict(z_samples)
        np.save(self._generated_datesets_dir / ('%d_generated_data' % epoch), generated_dataset)
        np.save(self._generated_datesets_dir / 'last', generated_dataset)

    def get_models(self):
        """Return ``(gan, generator, critic)``."""
        return self._gan, self._generator, self._critic

    @staticmethod
    def _wasserstein_loss(y_true, y_pred):
        """Mean of ``y_true * y_pred``; with +/-1 labels this is +/- the mean critic score."""
        return K.mean(y_true * y_pred)

    def restore_training(self):
        """Reload models and loss history so ``train`` resumes where it stopped.

        The epoch counter is set to the number of recorded losses. NOTE: losses
        and models are checkpointed at different frequencies, so the reloaded
        models may be older than the restored epoch count.

        Returns:
            ``(gan, generator, critic)``.
        """
        self.load_models()
        # losses.p is written by _save_losses(); only unpickle files from trusted runs.
        with open(str(self._run_dir / 'losses.p'), 'rb') as f:
            self._losses = pickle.load(f)
        self._epoch = len(self._losses[0])

        return self._gan, self._generator, self._critic

    def load_models(self):
        """Reload the checkpoints written by ``_save_models`` and recompile them for training.

        ``load_model`` returns three independent models. The combined GAN is
        therefore rebuilt around the loaded generator and critic so that all
        three share weights again. Without this, generator updates made through
        ``_gan`` would never reach ``_generator``. The critic is also unfrozen
        and recompiled, because it was saved while frozen. Learning rates come
        from the saved optimizers; RMSprop's accumulated state starts fresh.

        Returns:
            ``(gan, generator, critic)``.
        """
        custom_objects = {
            'MinibatchDiscrimination': MinibatchDiscrimination,
            '_wasserstein_loss': self._wasserstein_loss
        }
        saved_gan = load_model(str(self._model_dir / 'wgan.h5'), custom_objects=custom_objects)
        self._generator = load_model(str(self._model_dir / 'generator.h5'))
        self._critic = load_model(str(self._model_dir / 'critic.h5'), custom_objects=custom_objects)

        generator_lr = K.get_value(saved_gan.optimizer.lr)
        critic_lr = K.get_value(self._critic.optimizer.lr)
        self._compile_models(generator_lr, critic_lr)

        return self._gan, self._generator, self._critic

In [21]:
normalized_transactions_filepath = "../datasets/berka_dataset/usable/normalized_transactions.npy"

timesteps = 100
raw_transactions = np.load(normalized_transactions_filepath)

# Cut each row into consecutive windows of `timesteps` columns.
windowed_transactions = split_data(raw_transactions, timesteps)
if windowed_transactions is None:
    raise ValueError('dataset has fewer than %d columns; cannot cut windows of that length'
                     % timesteps)

# Drop windows whose standard deviation is <= 1e-7, i.e. (near-)constant ones.
transactions = windowed_transactions[np.std(windowed_transactions, 1) > 1e-7]
N, D = transactions.shape
print(transactions.shape)

(53888, 100)


In [40]:
# Training schedule (one epoch = one critic phase + one generator phase).
epochs = 50000
batch_size = 64
n_critic = 10       # critic updates per epoch
n_generator = 1     # generator updates per epoch

# Model and optimisation.
latent_dim = 2
generator_lr = critic_lr = 5e-5
clip_value = 0.05   # critic weights are clipped to [-clip_value, clip_value]

# Artifact schedule (periods in epochs) and size of each generated dataset.
img_frequency = 250
model_save_frequency = 3000
dataset_generation_frequency = 25000
dataset_generation_size = 100000

In [41]:
root_path = Path('wgan')
if not root_path.exists():
    root_path.mkdir()

# Every run gets its own timestamped directory; mkdir() without exist_ok
# fails instead of silently reusing an existing run directory.
current_datetime = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
run_dir = root_path / current_datetime

img_dir, model_dir, generated_datesets_dir = (
    run_dir / sub_dir for sub_dir in ('img', 'models', 'generated_datasets'))

for artifact_dir in (img_dir, model_dir, generated_datesets_dir):
    artifact_dir.mkdir(parents=True)

In [None]:
wgan = WGAN(timesteps, latent_dim, run_dir, img_dir, model_dir, generated_datesets_dir)

# Fresh models. Alternatively, `wgan.restore_training()` reloads models and
# losses from this WGAN's model_dir / run_dir, so those must point at an
# existing run. Inspect architectures with gan.summary(), generator.summary()
# and critic.summary().
gan, generator, critic = wgan.build_models(generator_lr, critic_lr)

losses = wgan.train(
    batch_size=batch_size,
    epochs=epochs,
    n_generator=n_generator,
    n_critic=n_critic,
    dataset=transactions,
    clip_value=clip_value,
    img_frequency=img_frequency,
    model_save_frequency=model_save_frequency,
    dataset_generation_frequency=dataset_generation_frequency,
    dataset_generation_size=dataset_generation_size,
)

  'Discrepancy between trainable weights and collected trainable'


1 [C loss: 0.993073] [G loss: 1.122786]
2 [C loss: 0.993184] [G loss: 1.120527]
3 [C loss: 0.993156] [G loss: 1.116927]
4 [C loss: 0.994011] [G loss: 1.112862]
5 [C loss: 0.994898] [G loss: 1.108846]
6 [C loss: 0.993928] [G loss: 1.104602]
7 [C loss: 0.993715] [G loss: 1.099973]
8 [C loss: 0.994741] [G loss: 1.095441]
9 [C loss: 0.994759] [G loss: 1.090491]
10 [C loss: 0.995501] [G loss: 1.085775]
11 [C loss: 0.996383] [G loss: 1.081729]
12 [C loss: 0.996571] [G loss: 1.076770]
13 [C loss: 0.997148] [G loss: 1.073577]
14 [C loss: 0.996616] [G loss: 1.070647]
15 [C loss: 0.996680] [G loss: 1.067704]
16 [C loss: 0.996339] [G loss: 1.064939]
17 [C loss: 0.997534] [G loss: 1.062069]
18 [C loss: 0.998651] [G loss: 1.059287]
19 [C loss: 0.998689] [G loss: 1.056718]
20 [C loss: 0.998001] [G loss: 1.054070]
21 [C loss: 0.998027] [G loss: 1.051318]
22 [C loss: 0.998593] [G loss: 1.048653]
23 [C loss: 0.999052] [G loss: 1.046230]
24 [C loss: 1.000187] [G loss: 1.043712]
25 [C loss: 0.999215] [G 

199 [C loss: 1.312496] [G loss: 0.885491]
200 [C loss: 1.305782] [G loss: 0.949417]
201 [C loss: 1.269079] [G loss: 0.880783]
202 [C loss: 1.248957] [G loss: 0.957727]
203 [C loss: 1.281342] [G loss: 0.931132]
204 [C loss: 1.239685] [G loss: 0.973109]
205 [C loss: 1.260051] [G loss: 0.929300]
206 [C loss: 1.264979] [G loss: 1.048449]
207 [C loss: 1.191493] [G loss: 0.996838]
208 [C loss: 1.193752] [G loss: 1.000010]
209 [C loss: 1.215144] [G loss: 1.069320]
210 [C loss: 1.165568] [G loss: 1.062668]
211 [C loss: 1.193196] [G loss: 1.034280]
212 [C loss: 1.154040] [G loss: 1.072757]
213 [C loss: 1.129109] [G loss: 1.105373]
214 [C loss: 1.167810] [G loss: 1.099753]
215 [C loss: 1.179615] [G loss: 1.107084]
216 [C loss: 1.081284] [G loss: 1.039549]
217 [C loss: 1.136434] [G loss: 1.069490]
218 [C loss: 1.087105] [G loss: 1.110586]
219 [C loss: 1.073676] [G loss: 1.184105]
220 [C loss: 1.072563] [G loss: 1.083028]
221 [C loss: 1.045569] [G loss: 1.103777]
222 [C loss: 1.059805] [G loss: 1.

395 [C loss: 1.201990] [G loss: 0.920832]
396 [C loss: 1.216894] [G loss: 0.911046]
397 [C loss: 1.198990] [G loss: 0.901313]
398 [C loss: 1.176858] [G loss: 0.927491]
399 [C loss: 1.180320] [G loss: 0.945057]
400 [C loss: 1.182829] [G loss: 0.935656]
401 [C loss: 1.189802] [G loss: 0.934412]
402 [C loss: 1.200998] [G loss: 0.908960]
403 [C loss: 1.181212] [G loss: 0.943789]
404 [C loss: 1.182106] [G loss: 0.944480]
405 [C loss: 1.173662] [G loss: 0.974256]
406 [C loss: 1.187816] [G loss: 0.977275]
407 [C loss: 1.179362] [G loss: 0.950644]
408 [C loss: 1.173190] [G loss: 0.964725]
409 [C loss: 1.190720] [G loss: 0.938123]
410 [C loss: 1.205533] [G loss: 0.956070]
411 [C loss: 1.173567] [G loss: 0.940222]
412 [C loss: 1.165762] [G loss: 0.970313]
413 [C loss: 1.190525] [G loss: 0.957427]
414 [C loss: 1.180587] [G loss: 0.928970]
415 [C loss: 1.162624] [G loss: 0.953805]
416 [C loss: 1.219877] [G loss: 0.959039]
417 [C loss: 1.165294] [G loss: 0.948054]
418 [C loss: 1.198491] [G loss: 0.

591 [C loss: 1.230583] [G loss: 0.615198]
592 [C loss: 1.247595] [G loss: 0.611361]
593 [C loss: 1.250569] [G loss: 0.614429]
594 [C loss: 1.241164] [G loss: 0.612477]
595 [C loss: 1.247645] [G loss: 0.614337]
596 [C loss: 1.222721] [G loss: 0.611082]
597 [C loss: 1.240395] [G loss: 0.612036]
598 [C loss: 1.248802] [G loss: 0.611874]
599 [C loss: 1.228392] [G loss: 0.611937]
600 [C loss: 1.235847] [G loss: 0.609261]
601 [C loss: 1.198141] [G loss: 0.610017]
602 [C loss: 1.186047] [G loss: 0.610254]
603 [C loss: 1.184451] [G loss: 0.608402]
604 [C loss: 1.202771] [G loss: 0.604830]
605 [C loss: 1.185530] [G loss: 0.608380]
606 [C loss: 1.201447] [G loss: 0.606302]
607 [C loss: 1.237316] [G loss: 0.606061]
608 [C loss: 1.192293] [G loss: 0.604511]
609 [C loss: 1.202302] [G loss: 0.602270]
610 [C loss: 1.175618] [G loss: 0.603850]
611 [C loss: 1.187757] [G loss: 0.601997]
612 [C loss: 1.152892] [G loss: 0.600354]
613 [C loss: 1.165831] [G loss: 0.599790]
614 [C loss: 1.175818] [G loss: 0.

787 [C loss: 1.182673] [G loss: 0.598517]
788 [C loss: 1.173946] [G loss: 0.595949]
789 [C loss: 1.230877] [G loss: 0.596565]
790 [C loss: 1.244332] [G loss: 0.596270]
791 [C loss: 1.208407] [G loss: 0.596227]
792 [C loss: 1.247495] [G loss: 0.597222]
793 [C loss: 1.227332] [G loss: 0.594970]
794 [C loss: 1.237584] [G loss: 0.595783]
795 [C loss: 1.183408] [G loss: 0.597291]
796 [C loss: 1.156683] [G loss: 0.596918]
797 [C loss: 1.232455] [G loss: 0.597050]
798 [C loss: 1.231256] [G loss: 0.596322]
799 [C loss: 1.201827] [G loss: 0.595361]
800 [C loss: 1.215535] [G loss: 0.595203]
801 [C loss: 1.215014] [G loss: 0.596422]
802 [C loss: 1.208579] [G loss: 0.596378]
803 [C loss: 1.221024] [G loss: 0.596063]
804 [C loss: 1.233395] [G loss: 0.596568]
805 [C loss: 1.207743] [G loss: 0.597762]
806 [C loss: 1.175224] [G loss: 0.599318]
807 [C loss: 1.201460] [G loss: 0.597158]
808 [C loss: 1.197478] [G loss: 0.598060]
809 [C loss: 1.213081] [G loss: 0.598494]
810 [C loss: 1.201293] [G loss: 0.

983 [C loss: 1.221638] [G loss: 0.605105]
984 [C loss: 1.187890] [G loss: 0.607070]
985 [C loss: 1.218799] [G loss: 0.607740]
986 [C loss: 1.267716] [G loss: 0.610671]
987 [C loss: 1.222787] [G loss: 0.612479]
988 [C loss: 1.280317] [G loss: 0.610687]
989 [C loss: 1.255935] [G loss: 0.608362]
990 [C loss: 1.243675] [G loss: 0.611553]
991 [C loss: 1.253761] [G loss: 0.608150]
992 [C loss: 1.253801] [G loss: 0.608142]
993 [C loss: 1.254948] [G loss: 0.611886]
994 [C loss: 1.229728] [G loss: 0.606692]
995 [C loss: 1.240099] [G loss: 0.608883]
996 [C loss: 1.250936] [G loss: 0.611420]
997 [C loss: 1.270787] [G loss: 0.608145]
998 [C loss: 1.271960] [G loss: 0.606102]
999 [C loss: 1.233629] [G loss: 0.608473]
1000 [C loss: 1.261127] [G loss: 0.611657]
1001 [C loss: 1.220443] [G loss: 0.609456]
1002 [C loss: 1.240139] [G loss: 0.607351]
1003 [C loss: 1.241197] [G loss: 0.607775]
1004 [C loss: 1.243644] [G loss: 0.608301]
1005 [C loss: 1.279077] [G loss: 0.610549]
1006 [C loss: 1.234931] [G l

1174 [C loss: 1.335597] [G loss: 0.588969]
1175 [C loss: 1.333192] [G loss: 0.588454]
1176 [C loss: 1.307215] [G loss: 0.589037]
1177 [C loss: 1.358552] [G loss: 0.586809]
1178 [C loss: 1.344676] [G loss: 0.586657]
1179 [C loss: 1.355948] [G loss: 0.587068]
1180 [C loss: 1.331233] [G loss: 0.587570]
1181 [C loss: 1.318292] [G loss: 0.587527]
1182 [C loss: 1.360862] [G loss: 0.586848]
1183 [C loss: 1.329904] [G loss: 0.585330]
1184 [C loss: 1.339904] [G loss: 0.587894]
1185 [C loss: 1.304279] [G loss: 0.586185]
1186 [C loss: 1.314224] [G loss: 0.586210]
1187 [C loss: 1.282005] [G loss: 0.584900]
1188 [C loss: 1.317690] [G loss: 0.585411]
1189 [C loss: 1.346822] [G loss: 0.587110]
1190 [C loss: 1.305357] [G loss: 0.583852]
1191 [C loss: 1.370957] [G loss: 0.586160]
1192 [C loss: 1.398076] [G loss: 0.585315]
1193 [C loss: 1.343447] [G loss: 0.584322]
1194 [C loss: 1.342057] [G loss: 0.585368]
1195 [C loss: 1.379256] [G loss: 0.584804]
1196 [C loss: 1.342069] [G loss: 0.587407]
1197 [C los