In [1]:
%pylab inline
import keras
import keras.backend as K
from keras import Input, Model, Sequential
from keras.layers import Lambda, LSTM, RepeatVector, Dense, TimeDistributed, Bidirectional, concatenate,\
Conv1D, MaxPooling1D, UpSampling1D, BatchNormalization, Activation, Flatten, Reshape
from keras.optimizers import Adam, RMSprop
from keras.callbacks import EarlyStopping, ReduceLROnPlateau

import tensorflow as tf
from sklearn.preprocessing import MinMaxScaler, QuantileTransformer
from matplotlib import pyplot as plt
import os, shutil

Populating the interactive namespace from numpy and matplotlib


  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [41]:
def split_data(dataset, timesteps):
    """Cut a 2-D array into consecutive column windows and stack them as rows.

    The columns of ``dataset`` are partitioned left to right into windows of
    ``timesteps`` columns each. Trailing columns that do not fill a complete
    window are discarded. The windows are stacked vertically in window order,
    so the result holds all rows of the first window, then all rows of the
    second window, and so on.

    Parameters
    ----------
    dataset : numpy.ndarray
        2-D array; ``dataset.shape[1]`` is the axis that gets windowed.
    timesteps : int
        Width of each window. Must be positive.

    Returns
    -------
    numpy.ndarray or None
        * ``None`` when ``dataset`` has fewer than ``timesteps`` columns.
        * ``dataset`` itself (no copy) when it has exactly ``timesteps`` columns.
        * Otherwise an array of shape
          ``(dataset.shape[0] * (dataset.shape[1] // timesteps), timesteps)``.

    Raises
    ------
    ValueError
        If ``timesteps`` is not positive.
    """
    if timesteps <= 0:
        raise ValueError(
            "timesteps must be a positive integer, got %r" % (timesteps,))

    D = dataset.shape[1]
    if D < timesteps:
        return None
    if D == timesteps:
        return dataset

    # Slice every complete window up front and stack once. This avoids one
    # recursion level (and one full re-copy of the tail) per window, so very
    # wide inputs neither hit the recursion limit nor cost quadratic copying.
    n_windows = D // timesteps
    windows = [
        dataset[:, start:start + timesteps]
        for start in range(0, n_windows * timesteps, timesteps)
    ]
    return np.vstack(windows)

In [94]:
class WGAN:
    """Wasserstein GAN with critic weight clipping for fixed-length 1-D series.

    The generator maps a latent vector of size ``latent_dim`` to a vector of
    length ``timesteps``; the critic maps a vector of length ``timesteps`` to
    a single unbounded score. Both networks can be built in one of four
    flavours: ``'dense'``, ``'conv'``, ``'lstm'`` or ``'blstm'``.

    Typical use: construct, call :meth:`build_model`, then :meth:`train`.
    Sample grids and loss curves are written as PNG files into a ``wgan/``
    directory, which must already exist.
    """

    def __init__(self, timesteps, latent_dim, generator_type,
                 critic_type):
        """Store the configuration; no Keras objects are created here.

        Args:
            timesteps: length of the sequences produced / criticised.
            latent_dim: size of the generator's noise input.
            generator_type: 'dense', 'conv', 'lstm' or 'blstm'.
            critic_type: 'dense', 'conv', 'lstm' or 'blstm'.
        """
        self._timesteps = timesteps
        self._latent_dim = latent_dim
        self._generator_type = generator_type
        self._critic_type = critic_type

    def build_model(self, lr):
        """Build and compile the generator, the critic and the stacked model.

        Args:
            lr: learning rate passed to ``RMSprop``; the same optimizer
                instance is used for all three ``compile`` calls.

        Returns:
            Tuple ``(gan, generator, critic)`` of Keras models.
        """
        optimizer = RMSprop(lr)

        self._generator = self._get_generator(
            self._latent_dim, self._timesteps, self._generator_type)
        self._generator.compile(
            loss=self._wasserstein_loss, optimizer=optimizer)

        self._critic = self._get_critic(self._timesteps, self._critic_type)
        self._critic.compile(
            loss=self._wasserstein_loss, optimizer=optimizer)

        z = Input(shape=(self._latent_dim, ))
        fake = self._generator(z)

        # Freeze the critic before stacking it on the generator. The intent is
        # that updates through `self._gan` only move the generator, while the
        # critic (compiled above) is still trained directly in `train`.
        # Keras reports this as a "Discrepancy between trainable weights..."
        # warning.
        self._critic.trainable = False
        valid = self._critic(fake)

        self._gan = Model(z, valid, name='GAN')

        # `metrics` is kept so that `train_on_batch` returns a list;
        # `train` takes element [0] (the loss) from it.
        self._gan.compile(
            loss=self._wasserstein_loss,
            optimizer=optimizer,
            metrics=['accuracy'])

        return self._gan, self._generator, self._critic

    def _get_generator(self, noise_dim, timesteps, generator_type):
        """Create the (uncompiled) generator network.

        Args:
            noise_dim: size of the latent input vector.
            timesteps: length of the generated output vector.
            generator_type: 'dense', 'conv', 'lstm' or 'blstm'.

        Returns:
            Keras ``Model`` named ``'generator'`` with tanh-activated output.

        Raises:
            ValueError: if ``generator_type`` is not a supported value.
        """
        # Use the argument, not a notebook-level global, so the network always
        # matches the latent size this instance was configured with.
        generator_inputs = Input((noise_dim, ))

        if generator_type == 'dense':
            generated = Dense(timesteps, activation='tanh')(generator_inputs)
            generated = Dense(timesteps, activation='tanh')(generated)

        elif generator_type == 'conv':
            # Add a trailing channel axis so Conv1D can be applied.
            generated = Lambda(lambda x: K.expand_dims(x))(generator_inputs)
            # Double the temporal length until it reaches at least `timesteps`.
            while generated.shape[1] < timesteps:
                generated = Conv1D(
                    32, 3, activation='tanh', padding='same')(generated)
                generated = UpSampling1D(2)(generated)
            generated = Conv1D(
                1, 3, activation='tanh', padding='same')(generated)
            generated = Lambda(lambda x: K.squeeze(x, -1))(generated)
            # The upsampled length may overshoot; project to exactly `timesteps`.
            generated = Dense(timesteps, activation='tanh')(generated)

        elif generator_type == 'lstm':
            generated = RepeatVector(timesteps)(generator_inputs)
            generated = LSTM(32, return_sequences=True)(generated)
            generated = TimeDistributed(Dense(1, activation='tanh'))(generated)
            generated = Lambda(lambda x: K.squeeze(x, -1))(generated)

        elif generator_type == 'blstm':
            generated = RepeatVector(timesteps)(generator_inputs)
            generated = Bidirectional(LSTM(32,
                                           return_sequences=True))(generated)
            generated = TimeDistributed(Dense(1, activation='tanh'))(generated)
            generated = Lambda(lambda x: K.squeeze(x, -1))(generated)

        else:
            raise ValueError(
                "Unknown generator_type %r; expected one of "
                "'dense', 'conv', 'lstm', 'blstm'" % (generator_type,))

        return Model(generator_inputs, generated, name='generator')

    def _get_critic(self, timesteps, critic_type):
        """Create the (uncompiled) critic network.

        Args:
            timesteps: length of the input vector.
            critic_type: 'dense', 'conv', 'lstm' or 'blstm'.

        Returns:
            Keras ``Model`` named ``'critic'`` whose output is a single linear
            unit (no activation).

        Raises:
            ValueError: if ``critic_type`` is not a supported value.
        """
        critic_inputs = Input((timesteps, ))

        if critic_type == 'dense':
            criticized = Dense(
                timesteps, activation='tanh')(critic_inputs)
            criticized = Dense(timesteps, activation='tanh')(criticized)
            criticized = Dense(1)(criticized)

        elif critic_type == 'conv':
            criticized = Lambda(lambda x: K.expand_dims(x))(
                critic_inputs)
            # Halve the temporal length until a single step remains.
            while criticized.shape[1] > 1:
                criticized = Conv1D(
                    32, 3, activation='tanh', padding='same')(criticized)
                criticized = MaxPooling1D(2, padding='same')(criticized)
            criticized = Flatten()(criticized)
            criticized = Dense(1)(criticized)

        elif critic_type == 'lstm':
            criticized = Lambda(lambda x: K.expand_dims(x))(
                critic_inputs)
            criticized = LSTM(32, return_sequences=False)(criticized)
            criticized = Dense(1)(criticized)

        elif critic_type == 'blstm':
            criticized = Lambda(lambda x: K.expand_dims(x))(
                critic_inputs)
            criticized = Bidirectional(LSTM(
                32, return_sequences=False))(criticized)
            criticized = Dense(1)(criticized)

        else:
            raise ValueError(
                "Unknown critic_type %r; expected one of "
                "'dense', 'conv', 'lstm', 'blstm'" % (critic_type,))

        return Model(critic_inputs, criticized, name='critic')

    @staticmethod
    def _wasserstein_loss(y_true, y_pred):
        """Return ``mean(y_true * y_pred)``; labels are -1 / +1 in `train`."""
        return K.mean(y_true * y_pred)

    def train(self, batch_size, epochs, n_generator, n_critic, dataset,
              img_frequency, clip_value):
        """Run the alternating critic / generator training loop.

        Each epoch performs ``n_critic`` critic updates (each on a half batch
        of real rows labelled -1 and a half batch of generated rows labelled
        +1, followed by weight clipping) and then ``n_generator`` generator
        updates through the stacked model with label -1.

        The values printed and plotted are ``1 - loss`` of the last critic
        and generator update of the epoch.

        Args:
            batch_size: generator batch size; the critic uses half of it.
            epochs: number of outer iterations.
            n_generator: generator updates per epoch.
            n_critic: critic updates per epoch.
            dataset: 2-D array indexed by row; rows are sampled with
                replacement and fed to the critic.
            img_frequency: every this many epochs (including epoch 0) a
                sample grid and the loss curves are saved.
            clip_value: critic weights are clipped to
                ``[-clip_value, clip_value]`` after every critic update.
        """
        half_batch = int(batch_size / 2)

        losses = [[], []]
        for epoch in range(epochs):
            for _ in range(n_critic):
                indexes = np.random.randint(0, dataset.shape[0], half_batch)
                batch_transactions = dataset[indexes]

                noise = np.random.normal(0, 1, (half_batch, self._latent_dim))
                generated_transactions = self._generator.predict(noise)

                critic_loss_real = self._critic.train_on_batch(
                    batch_transactions, -np.ones((half_batch, 1)))
                critic_loss_fake = self._critic.train_on_batch(
                    generated_transactions, np.ones((half_batch, 1)))
                critic_loss = 0.5 * np.add(critic_loss_real,
                                           critic_loss_fake)

                self._clip_weights(clip_value)

            for _ in range(n_generator):
                # Noise width must be the instance's latent size (previously a
                # notebook global was read here).
                noise = np.random.normal(
                    0, 1, (batch_size, self._latent_dim))

                generator_loss = self._gan.train_on_batch(
                    noise, -np.ones((batch_size, 1)))[0]

            generator_loss = 1 - generator_loss
            critic_loss = 1 - critic_loss

            losses[0].append(generator_loss)
            losses[1].append(critic_loss)

            print("%d [D loss: %f] [G loss: %f]" % (epoch, critic_loss,
                                                    generator_loss))

            if epoch % img_frequency == 0:
                self._save_imgs(epoch)
                self._save_losses(losses)

    def _save_imgs(self, epoch):
        """Plot a 5x5 grid of freshly generated samples and save it as PNG.

        Writes ``wgan/<epoch, zero-padded to 5 digits>.png`` and overwrites
        ``wgan/last.png``.

        Args:
            epoch: epoch number used in the file name.
        """
        rows, columns = 5, 5
        noise = np.random.normal(0, 1, (rows * columns, self._latent_dim))
        generated_transactions = self._generator.predict(noise)

        fig, axes = plt.subplots(rows, columns, figsize=(15, 5))
        # `axes.flat` walks the grid row by row, one sample per panel.
        for ax, sample in zip(axes.flat, generated_transactions):
            ax.plot(sample)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_ylim(-1, 1)  # generator outputs are tanh-activated
        fig.tight_layout()
        fig.savefig('wgan/%05d.png' % epoch)
        fig.savefig('wgan/last.png')
        plt.close(fig)

    @staticmethod
    def _save_losses(losses):
        """Plot both loss histories and save them to ``wgan/losses.png``.

        Args:
            losses: pair ``[generator_history, critic_history]``.
        """
        fig, ax = plt.subplots()
        ax.plot(losses[0])
        ax.plot(losses[1])
        ax.legend(['generator', 'critic'])
        fig.savefig('wgan/losses.png')
        plt.close(fig)

    def _clip_weights(self, clip_value):
        """Clip every critic weight array to ``[-clip_value, clip_value]``."""
        weights = [np.clip(w, -clip_value, clip_value)
                   for w in self._critic.get_weights()]
        self._critic.set_weights(weights)

In [95]:
# Width of each training window (the critic's input length / generator's output length).
timesteps = 100
normalized_transactions_filepath = "../../datasets/berka_dataset/usable/normalized_transactions.npy"

# Load the array, cut it into `timesteps`-wide windows stacked as rows,
# then shuffle the rows in place.
transactions = split_data(np.load(normalized_transactions_filepath), timesteps)
np.random.shuffle(transactions)

# N = number of windows, D = window width.
N, D = transactions.shape
print(N, D)

94500 100


In [96]:
# --- training schedule (consumed by WGAN.train) ---
batch_size = 64          # generator batch; the critic trains on half of it
epochs = 10 ** 5
n_critic = 5             # critic updates per epoch
n_generator = 1          # generator updates per epoch
img_frequency = 100      # save sample grid / loss plot every N epochs

# --- optimisation ---
lr = 5e-05               # RMSprop learning rate
clip_value = 0.025       # critic weights are clipped to [-clip_value, clip_value]

# --- architecture ---
latent_dim = 10          # size of the generator's noise input
generator_type = 'conv'
critic_type = 'conv'

In [97]:
# Start from an empty output directory: any previous 'wgan' folder
# (and the images inside it) is deleted before training.
if os.path.exists('wgan'):
    shutil.rmtree('wgan')
os.makedirs('wgan')

wgan = WGAN(
    timesteps=timesteps,
    latent_dim=latent_dim,
    generator_type=generator_type,
    critic_type=critic_type,
)
wgan.build_model(lr=lr)
wgan.train(
    batch_size=batch_size,
    epochs=epochs,
    n_generator=n_generator,
    n_critic=n_critic,
    dataset=transactions,
    img_frequency=img_frequency,
    clip_value=clip_value,
)

  'Discrepancy between trainable weights and collected trainable'


0 [D loss: 0.999957] [G loss: 1.000095]
1 [D loss: 0.999959] [G loss: 1.000096]
2 [D loss: 0.999958] [G loss: 1.000095]
3 [D loss: 0.999956] [G loss: 1.000093]
4 [D loss: 0.999954] [G loss: 1.000092]
5 [D loss: 0.999954] [G loss: 1.000091]
6 [D loss: 0.999953] [G loss: 1.000091]
7 [D loss: 0.999953] [G loss: 1.000090]
8 [D loss: 0.999953] [G loss: 1.000090]
9 [D loss: 0.999953] [G loss: 1.000090]
10 [D loss: 0.999953] [G loss: 1.000089]
11 [D loss: 0.999954] [G loss: 1.000089]
12 [D loss: 0.999954] [G loss: 1.000089]
13 [D loss: 0.999954] [G loss: 1.000089]
14 [D loss: 0.999955] [G loss: 1.000089]
15 [D loss: 0.999955] [G loss: 1.000088]
16 [D loss: 0.999955] [G loss: 1.000087]
17 [D loss: 0.999956] [G loss: 1.000087]
18 [D loss: 0.999956] [G loss: 1.000087]
19 [D loss: 0.999956] [G loss: 1.000086]
20 [D loss: 0.999957] [G loss: 1.000086]
21 [D loss: 0.999957] [G loss: 1.000085]
22 [D loss: 0.999957] [G loss: 1.000085]
23 [D loss: 0.999957] [G loss: 1.000085]
24 [D loss: 0.999958] [G l

198 [D loss: 1.000184] [G loss: 1.000155]
199 [D loss: 1.000289] [G loss: 1.000066]
200 [D loss: 1.000276] [G loss: 0.999931]
201 [D loss: 1.000266] [G loss: 0.999816]
202 [D loss: 1.000309] [G loss: 0.999743]
203 [D loss: 1.000367] [G loss: 0.999646]
204 [D loss: 1.000340] [G loss: 0.999476]
205 [D loss: 1.000388] [G loss: 0.999392]
206 [D loss: 1.000422] [G loss: 0.999338]
207 [D loss: 1.000437] [G loss: 0.999228]
208 [D loss: 1.000401] [G loss: 0.999295]
209 [D loss: 1.000336] [G loss: 0.999094]
210 [D loss: 1.000432] [G loss: 0.999191]
211 [D loss: 1.000276] [G loss: 0.999141]
212 [D loss: 1.000353] [G loss: 0.999093]
213 [D loss: 1.000285] [G loss: 0.999212]
214 [D loss: 1.000172] [G loss: 0.999195]
215 [D loss: 1.000284] [G loss: 0.999451]
216 [D loss: 1.000132] [G loss: 0.999771]
217 [D loss: 1.000183] [G loss: 0.999698]
218 [D loss: 0.999736] [G loss: 0.999895]
219 [D loss: 0.999860] [G loss: 0.999677]
220 [D loss: 0.999661] [G loss: 1.000245]
221 [D loss: 1.000118] [G loss: 1.

394 [D loss: 0.999947] [G loss: 1.001141]
395 [D loss: 0.999986] [G loss: 1.001176]
396 [D loss: 0.999955] [G loss: 1.001168]
397 [D loss: 0.999970] [G loss: 1.001171]
398 [D loss: 0.999964] [G loss: 1.001100]
399 [D loss: 0.999962] [G loss: 1.001163]
400 [D loss: 0.999970] [G loss: 1.001156]
401 [D loss: 0.999959] [G loss: 1.001159]
402 [D loss: 0.999973] [G loss: 1.001136]
403 [D loss: 0.999987] [G loss: 1.001154]
404 [D loss: 1.000000] [G loss: 1.001130]
405 [D loss: 1.000016] [G loss: 1.001134]
406 [D loss: 0.999967] [G loss: 1.001147]
407 [D loss: 0.999964] [G loss: 1.001162]
408 [D loss: 0.999964] [G loss: 1.001161]
409 [D loss: 1.000000] [G loss: 1.001174]
410 [D loss: 0.999995] [G loss: 1.001146]
411 [D loss: 0.999964] [G loss: 1.001141]
412 [D loss: 1.000019] [G loss: 1.001168]
413 [D loss: 1.000031] [G loss: 1.001160]
414 [D loss: 1.000030] [G loss: 1.001174]
415 [D loss: 1.000008] [G loss: 1.001162]
416 [D loss: 0.999999] [G loss: 1.001148]
417 [D loss: 0.999979] [G loss: 1.

590 [D loss: 1.000019] [G loss: 1.002394]
591 [D loss: 1.000147] [G loss: 1.002391]
592 [D loss: 1.000156] [G loss: 1.002458]
593 [D loss: 1.000272] [G loss: 1.002486]
594 [D loss: 1.000082] [G loss: 1.002489]
595 [D loss: 1.000017] [G loss: 1.002602]
596 [D loss: 1.000083] [G loss: 1.002596]
597 [D loss: 0.999998] [G loss: 1.002798]
598 [D loss: 0.999947] [G loss: 1.002754]
599 [D loss: 0.999966] [G loss: 1.002803]
600 [D loss: 1.000121] [G loss: 1.002881]
601 [D loss: 1.000277] [G loss: 1.002811]
602 [D loss: 1.000142] [G loss: 1.002766]
603 [D loss: 1.000174] [G loss: 1.002840]
604 [D loss: 1.000154] [G loss: 1.002889]
605 [D loss: 1.000215] [G loss: 1.002937]
606 [D loss: 1.000197] [G loss: 1.003041]
607 [D loss: 0.999961] [G loss: 1.003056]
608 [D loss: 1.000074] [G loss: 1.003120]
609 [D loss: 1.000004] [G loss: 1.003266]
610 [D loss: 0.999924] [G loss: 1.003082]
611 [D loss: 0.999922] [G loss: 1.003213]
612 [D loss: 0.999846] [G loss: 1.003242]
613 [D loss: 0.999825] [G loss: 1.

786 [D loss: 1.000122] [G loss: 1.003956]
787 [D loss: 0.999787] [G loss: 1.004351]
788 [D loss: 1.000015] [G loss: 1.004500]
789 [D loss: 1.000011] [G loss: 1.004556]
790 [D loss: 1.000124] [G loss: 1.004416]
791 [D loss: 1.000257] [G loss: 1.004444]
792 [D loss: 1.000145] [G loss: 1.004421]
793 [D loss: 1.000264] [G loss: 1.004465]
794 [D loss: 1.000679] [G loss: 1.004506]
795 [D loss: 1.000100] [G loss: 1.004550]
796 [D loss: 1.000285] [G loss: 1.004785]
797 [D loss: 0.999783] [G loss: 1.004979]
798 [D loss: 0.999963] [G loss: 1.005312]
799 [D loss: 1.000070] [G loss: 1.004854]
800 [D loss: 1.000015] [G loss: 1.005146]
801 [D loss: 1.000161] [G loss: 1.005210]
802 [D loss: 1.000117] [G loss: 1.005088]
803 [D loss: 1.000163] [G loss: 1.005371]
804 [D loss: 1.000096] [G loss: 1.005467]
805 [D loss: 1.000361] [G loss: 1.005077]
806 [D loss: 0.999968] [G loss: 1.005279]
807 [D loss: 1.000051] [G loss: 1.005656]
808 [D loss: 1.000053] [G loss: 1.005329]
809 [D loss: 1.000131] [G loss: 1.

982 [D loss: 1.000148] [G loss: 1.006306]
983 [D loss: 1.000069] [G loss: 1.005913]
984 [D loss: 1.000065] [G loss: 1.005987]
985 [D loss: 1.000069] [G loss: 1.006218]
986 [D loss: 1.000181] [G loss: 1.006116]
987 [D loss: 1.000358] [G loss: 1.005928]
988 [D loss: 1.000123] [G loss: 1.006064]
989 [D loss: 1.000428] [G loss: 1.006056]
990 [D loss: 1.000240] [G loss: 1.005940]
991 [D loss: 1.000532] [G loss: 1.006015]
992 [D loss: 1.000105] [G loss: 1.005980]
993 [D loss: 1.000144] [G loss: 1.005823]
994 [D loss: 1.000506] [G loss: 1.005904]
995 [D loss: 1.000216] [G loss: 1.005935]
996 [D loss: 1.000364] [G loss: 1.006134]
997 [D loss: 1.000123] [G loss: 1.006330]
998 [D loss: 1.000023] [G loss: 1.006414]
999 [D loss: 1.000181] [G loss: 1.006572]
1000 [D loss: 1.000035] [G loss: 1.006418]
1001 [D loss: 1.000312] [G loss: 1.006416]
1002 [D loss: 1.000342] [G loss: 1.006541]
1003 [D loss: 1.000465] [G loss: 1.006797]
1004 [D loss: 1.000555] [G loss: 1.007018]
1005 [D loss: 1.000289] [G lo

1173 [D loss: 1.000171] [G loss: 1.008606]
1174 [D loss: 1.000105] [G loss: 1.008472]
1175 [D loss: 1.000161] [G loss: 1.008432]
1176 [D loss: 0.999981] [G loss: 1.008177]
1177 [D loss: 1.000141] [G loss: 1.008096]
1178 [D loss: 1.000147] [G loss: 1.008139]
1179 [D loss: 1.000037] [G loss: 1.007863]
1180 [D loss: 1.000069] [G loss: 1.007838]
1181 [D loss: 1.000459] [G loss: 1.007722]
1182 [D loss: 1.000296] [G loss: 1.007690]
1183 [D loss: 1.000190] [G loss: 1.007558]
1184 [D loss: 1.000535] [G loss: 1.007499]
1185 [D loss: 1.000439] [G loss: 1.007119]
1186 [D loss: 1.000121] [G loss: 1.007302]
1187 [D loss: 1.000235] [G loss: 1.007310]
1188 [D loss: 0.999958] [G loss: 1.007264]
1189 [D loss: 1.000466] [G loss: 1.007251]
1190 [D loss: 1.000255] [G loss: 1.007425]
1191 [D loss: 1.000254] [G loss: 1.007677]
1192 [D loss: 1.000176] [G loss: 1.007636]
1193 [D loss: 1.000302] [G loss: 1.007566]
1194 [D loss: 1.000256] [G loss: 1.007764]
1195 [D loss: 1.000066] [G loss: 1.007764]
1196 [D los

1364 [D loss: 1.000528] [G loss: 1.007828]
1365 [D loss: 1.000317] [G loss: 1.008116]
1366 [D loss: 1.000302] [G loss: 1.008186]
1367 [D loss: 1.000243] [G loss: 1.008289]
1368 [D loss: 1.000150] [G loss: 1.007953]
1369 [D loss: 1.000049] [G loss: 1.008056]
1370 [D loss: 0.999896] [G loss: 1.008118]
1371 [D loss: 0.999886] [G loss: 1.007991]
1372 [D loss: 0.999928] [G loss: 1.007960]
1373 [D loss: 1.000135] [G loss: 1.007717]
1374 [D loss: 1.000481] [G loss: 1.006849]
1375 [D loss: 1.000387] [G loss: 1.007483]
1376 [D loss: 1.000106] [G loss: 1.007241]
1377 [D loss: 1.000262] [G loss: 1.007381]
1378 [D loss: 1.000418] [G loss: 1.007164]
1379 [D loss: 1.000176] [G loss: 1.007372]
1380 [D loss: 1.000659] [G loss: 1.007245]
1381 [D loss: 1.000217] [G loss: 1.007105]
1382 [D loss: 0.999907] [G loss: 1.007210]
1383 [D loss: 1.000409] [G loss: 1.007415]
1384 [D loss: 1.000431] [G loss: 1.007463]
1385 [D loss: 1.000364] [G loss: 1.007303]
1386 [D loss: 1.000343] [G loss: 1.007607]
1387 [D los

KeyboardInterrupt: 