In [1]:
import os, sys
sys.path.append(os.getcwd())

import time

import numpy as np
import tensorflow as tf

import language_helpers
import tflib as lib
import tflib.ops.linear
import tflib.ops.conv1d
import tflib.plot

In [2]:
# ---- Data ------------------------------------------------------------------
# Training corpus directory. Despite the Billion-Word origins of this script,
# the loaded lines are domain names (e.g. "webnode.mx", see the loading output),
# presumably the Alexa top-1M list as the directory name suggests.
DATA_DIR = '../Dataset/AlexaTop1M_NoSeparate'
if not len(DATA_DIR):
    raise Exception("Please specify path to data directory in gan_language.py!")

# ---- Training schedule -----------------------------------------------------
BATCH_SIZE = 64          # Sequences per generator / critic step.
ITERS = 30000            # Generator iterations. Keep it a multiple of 1000: the
                         # training cell reports progress, dumps samples and
                         # flushes plots every ITERS // 100 iterations.
CRITIC_ITERS = 10        # Critic updates per generator update (10 was used for
                         # the paper's results; 5 also tends to work).

# ---- Model -----------------------------------------------------------------
SEQ_LEN = 32             # Characters per sequence; shorter lines are space-padded.
DIM = 512                # Channel width of every conv layer. Slow and prone to
                         # overfitting; consider lowering it for small datasets.
LAMBDA = 10              # Gradient-penalty coefficient.

# ---- Data loading limits -------------------------------------------------------
MAX_N_EXAMPLES = 100000  # Cap on lines loaded; lower it if loading is slow or
                         # RAM-bound, at the cost of less training data
                         # (upstream default: 10000000).

In [3]:
# Echo the hyperparameters defined so far (tflib prints the UPPER_CASE names,
# as the captured output shows), using a shallow snapshot of the namespace.
lib.print_model_settings(dict(locals()))

# Load up to MAX_N_EXAMPLES lines of at most SEQ_LEN characters. Judging by the
# printed samples, each line comes back as a tuple of SEQ_LEN characters padded
# with spaces. charmap / inv_charmap are the char -> id and id -> char lookups
# used by the batch generator and sample decoder below.
lines, charmap, inv_charmap = language_helpers.load_dataset(
    data_dir=DATA_DIR,
    max_n_examples=MAX_N_EXAMPLES,
    max_length=SEQ_LEN,
)

Uppercase local vars:
	BATCH_SIZE: 64
	CRITIC_ITERS: 10
	DATA_DIR: ../Dataset/AlexaTop1M_NoSeparate
	DIM: 512
	ITERS: 30000
	LAMBDA: 10
	MAX_N_EXAMPLES: 100000
	SEQ_LEN: 32
loading dataset...
('w', 'e', 'b', 'n', 'o', 'd', 'e', '.', 'm', 'x', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ')
('g', 'r', 'n', 'b', 'a', '.', 'c', 'o', 'm', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ')
('v', 'i', 'e', '-', 'p', 'u', 'b', 'l', 'i', 'q', 'u', 'e', '.', 'f', 'r', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ')
('s', 'a', 'n', 'd', 'i', 'e', 'g', 'o', 'u', 'n', 'i', 'o', 'n', 't', 'r', 'i', 'b', 'u', 'n', 'e', '.', 'c', 'o', 'm', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ')
('r', 'o', 'a', 'm', 'a', 'n', 's', '.', 'c', 'o', 'm', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ')
('c

In [4]:
def softmax(logits):
    """Softmax over the last (vocabulary) axis of ``logits``.

    The tensor is flattened to ``[-1, len(charmap)]``, normalised row-wise with
    ``tf.nn.softmax`` and reshaped back to the input's dynamic shape, so the
    last axis of ``logits`` must have size ``len(charmap)``.
    """
    flat_logits = tf.reshape(logits, [-1, len(charmap)])
    flat_probs = tf.nn.softmax(flat_logits)
    original_shape = tf.shape(logits)
    return tf.reshape(flat_probs, original_shape)

In [5]:
def make_noise(shape):
    """Return a tensor of i.i.d. standard-normal samples with the given ``shape``."""
    noise = tf.random_normal(shape=shape)
    return noise

In [6]:
def ResBlock(name, inputs):
    """Residual block: ``inputs + 0.3 * conv(relu(conv(relu(inputs))))``.

    Builds two Conv1D layers named ``name + '.1'`` and ``name + '.2'``, each
    mapping DIM -> DIM channels with filter size 5 (assuming tflib's
    ``Conv1D(name, input_dim, output_dim, filter_size, inputs)`` signature;
    TODO confirm). The ReLU is applied before each conv (pre-activation). The
    residual branch is scaled by 0.3 before being added back, so ``inputs``
    must already have DIM channels.
    """
    branch = inputs
    for suffix in ('.1', '.2'):
        branch = tf.nn.relu(branch)
        branch = lib.ops.conv1d.Conv1D(name + suffix, DIM, DIM, 5, branch)
    return inputs + 0.3 * branch

In [7]:
def Generator(n_samples, prev_outputs=None):
    """Build the generator: 128-d Gaussian noise -> per-position character distributions.

    Args:
        n_samples: number of sequences to generate (the noise batch size).
        prev_outputs: unused; kept only for interface compatibility.

    Returns:
        Tensor ``[n_samples, SEQ_LEN, len(charmap)]`` whose last axis is a
        softmax over the character vocabulary.
    """
    noise = make_noise(shape=[n_samples, 128])
    hidden = lib.ops.linear.Linear('Generator.Input', 128, SEQ_LEN*DIM, noise)
    # [batch, DIM, SEQ_LEN]: presumably the channels-first layout tflib's Conv1D expects.
    hidden = tf.reshape(hidden, [-1, DIM, SEQ_LEN])
    for block_idx in range(1, 6):
        hidden = ResBlock('Generator.{}'.format(block_idx), hidden)
    # 1x1 conv projects DIM channels onto the vocabulary.
    logits = lib.ops.conv1d.Conv1D('Generator.Output', DIM, len(charmap), 1, hidden)
    # Put the vocabulary axis last before normalising: [batch, SEQ_LEN, vocab].
    return softmax(tf.transpose(logits, [0, 2, 1]))

In [8]:
def Discriminator(inputs):
    """Critic: map a batch of (soft) one-hot sequences to one unbounded score each.

    Args:
        inputs: tensor ``[batch, SEQ_LEN, len(charmap)]`` (one-hot real data,
            generator softmax output, or an interpolation of the two).

    Returns:
        Tensor ``[batch, 1]`` of critic scores (no output nonlinearity).
    """
    # Swap to [batch, vocab, SEQ_LEN] and lift the vocabulary to DIM channels with a 1x1 conv.
    channels_first = tf.transpose(inputs, [0, 2, 1])
    hidden = lib.ops.conv1d.Conv1D('Discriminator.Input', len(charmap), DIM, 1, channels_first)
    for block_idx in range(1, 6):
        hidden = ResBlock('Discriminator.{}'.format(block_idx), hidden)
    flat = tf.reshape(hidden, [-1, SEQ_LEN*DIM])
    return lib.ops.linear.Linear('Discriminator.Output', SEQ_LEN*DIM, 1, flat)

In [9]:
# Real batches are fed as integer character ids and one-hot encoded to
# [BATCH_SIZE, SEQ_LEN, len(charmap)], the same layout as the generator output.
real_inputs_discrete = tf.placeholder(tf.int32, shape=[BATCH_SIZE, SEQ_LEN])
real_inputs = tf.one_hot(real_inputs_discrete, len(charmap))

fake_inputs = Generator(BATCH_SIZE)
# Most likely character id at each position (argmax over the last axis).
last_axis = fake_inputs.get_shape().ndims - 1
fake_inputs_discrete = tf.argmax(fake_inputs, axis=last_axis)

In [10]:
# Critic scores for the real batch and the generated batch. Both calls build
# layers with the same 'Discriminator.*' names; presumably tflib reuses the
# parameters by name, so they share weights (TODO confirm in tflib).
disc_real, disc_fake = Discriminator(real_inputs), Discriminator(fake_inputs)

In [11]:
# Wasserstein losses. The critic minimises E[D(fake)] - E[D(real)], i.e. it is
# pushed to score real data above generated data. The generator minimises
# -E[D(fake)]. A gradient-penalty term is combined into disc_cost further down.
mean_fake_score = tf.reduce_mean(disc_fake)
disc_cost = mean_fake_score - tf.reduce_mean(disc_real)
gen_cost = -mean_fake_score

In [12]:
# WGAN lipschitz-penalty
# WGAN Lipschitz penalty: one interpolation weight per sample, uniform in [0, 1).
# The trailing singleton axes broadcast it over the sequence and vocabulary axes.
alpha = tf.random_uniform([BATCH_SIZE, 1, 1], minval=0., maxval=1.)

In [13]:
# WGAN-GP gradient penalty: evaluate the critic at random points on the segment
# between each real and generated sample and penalise (||grad|| - 1)^2.
differences = fake_inputs - real_inputs
interpolates = real_inputs + (alpha*differences)
gradients = tf.gradients(Discriminator(interpolates), [interpolates])[0]
# Per-sample L2 norm over the sequence and vocabulary axes.
# `axis` replaces the deprecated `reduction_indices` alias.
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2]))
gradient_penalty = tf.reduce_mean((slopes-1.)**2)

# Rebuild the full critic loss from its parts instead of `disc_cost += ...`:
# the in-place form is not idempotent, so re-running this cell would add the
# penalty a second time on top of the already-penalised cost.
disc_cost = (tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)
             + LAMBDA*gradient_penalty)

In [14]:
# Variables of each network, selected by their layer-name prefix (presumably a
# prefix match inside tflib; TODO confirm). Each optimizer only updates its own list.
gen_params, disc_params = [lib.params_with_name(prefix) for prefix in ('Generator', 'Discriminator')]

In [15]:
# Both players use Adam with identical settings (lr 1e-4, beta1 0.5, beta2 0.999).
# Each optimizer is restricted to its own network's variables.
adam_settings = dict(learning_rate=1e-4, beta1=0.5, beta2=0.999)
gen_train_op = tf.train.AdamOptimizer(**adam_settings).minimize(gen_cost, var_list=gen_params)
disc_train_op = tf.train.AdamOptimizer(**adam_settings).minimize(disc_cost, var_list=disc_params)

In [16]:
# Dataset iterator
def inf_train_gen():
    """Yield shuffled mini-batches of encoded training lines forever.

    Each pass shuffles the module-level ``lines`` list *in place* (this
    mutates the global), then yields consecutive, non-overlapping int32
    arrays of BATCH_SIZE rows, each row being one line mapped through
    ``charmap``. A trailing partial batch is dropped.

    Raises:
        ValueError: on the first ``next()`` if fewer than BATCH_SIZE lines are
            loaded. Without this check the inner loop would never execute and
            the generator would spin forever without yielding.
    """
    if len(lines) < BATCH_SIZE:
        raise ValueError(
            "Need at least BATCH_SIZE={} training lines, got {}".format(BATCH_SIZE, len(lines))
        )
    while True:
        np.random.shuffle(lines)
        for start in range(0, len(lines) - BATCH_SIZE + 1, BATCH_SIZE):
            batch = lines[start:start + BATCH_SIZE]
            yield np.array(
                [[charmap[c] for c in line] for line in batch],
                dtype='int32'
            )

In [17]:
# During training we track the Jensen-Shannon divergence between the n-gram
# statistics (n = 1..4) of generated samples and of the real data. As a
# reference for what a "good" value looks like, first compare the real data
# against a small held-out slice (the first 10 batches' worth of lines).
holdout_size = 10 * BATCH_SIZE
train_part, validation_part = lines[holdout_size:], lines[:holdout_size]

true_char_ngram_lms = [language_helpers.NgramLanguageModel(n, train_part, tokenize=False) for n in range(1, 5)]
validation_char_ngram_lms = [language_helpers.NgramLanguageModel(n, validation_part, tokenize=False) for n in range(1, 5)]
for n, (true_lm, val_lm) in enumerate(zip(true_char_ngram_lms, validation_char_ngram_lms), start=1):
    print("validation set JSD for n={}: {}".format(n, true_lm.js_with(val_lm)))

# For monitoring the GAN itself, rebuild the reference models on all lines.
true_char_ngram_lms = [language_helpers.NgramLanguageModel(n, lines, tokenize=False) for n in range(1, 5)]

validation set JSD for n=1: 0.00027936546409873796
validation set JSD for n=2: 0.010836911775376186
validation set JSD for n=3: 0.07873450820936567
validation set JSD for n=4: 0.17761502240792434


In [18]:
# Create the sample-dump directory up front so the first checkpoint cannot die
# with FileNotFoundError after hours of training.
os.makedirs('output_data', exist_ok=True)

with tf.Session() as session:

    session.run(tf.global_variables_initializer())

    def generate_samples():
        """Run the generator once and decode every sample to a tuple of characters.

        Returns a list of BATCH_SIZE tuples (the batch size fake_inputs was
        built with), each taking the argmax character at every position.
        """
        probs = session.run(fake_inputs)
        char_ids = np.argmax(probs, axis=2)
        return [tuple(inv_charmap[c] for c in row) for row in char_ids]

    gen = inf_train_gen()

    # Both values depend only on ITERS, so compute them once. max(1, ...) avoids
    # a modulo-by-zero crash when ITERS < 1000.
    change_line = max(1, ITERS // 1000)   # iterations per '*' in the progress bar
    report_every = 10 * change_line       # iterations per summary / sample dump / plot flush

    sum_time = 0.
    line_time = 0.
    loading_str = "*"
    print("[Start]")
    # perf_counter replaces time.clock(), deprecated since Python 3.3 and removed in 3.8.
    now_time = time.perf_counter()
    for iteration in range(ITERS):
        start_time = time.time()

        # Train generator (skipped on iteration 0, so the critic is updated first).
        if iteration > 0:
            session.run(gen_train_op)

        # Train critic CRITIC_ITERS times per generator step, each on a fresh real batch.
        for _ in range(CRITIC_ITERS):
            _data = next(gen)
            _disc_cost, _ = session.run(
                [disc_cost, disc_train_op],
                feed_dict={real_inputs_discrete: _data}
            )

        # Duration of this iteration and a naive ETA for the iterations still to come.
        after_time = time.perf_counter() - now_time
        sum_time += after_time
        eta_time = (ITERS - iteration - 1) * after_time

        print("[{1:10}] [Iteration]: {0:10} [Unit iteration time    ]: {2:10.2f} secs [ETA]: {3:10.2f} secs".format((iteration+1), loading_str, after_time, eta_time), end="\r")
        now_time = time.perf_counter()

        if iteration % change_line == change_line - 1:
            loading_str += "*"
            if iteration % report_every == report_every - 1:
                # Completed fraction uses iteration+1 so it reaches 100% on the last iteration.
                print("{5:5.0f}{0:7} [Iteration]: {1:10} [{2:23}]: {3:10.2f} secs [SUM]: {4:10.2f}".format("% Done!", (iteration+1), (str(report_every)+"x iterations time"), (sum_time-line_time), sum_time, (100*(iteration+1)/ITERS)))
                loading_str = "*"
                line_time = sum_time

        lib.plot.plot('time', time.time() - start_time)
        lib.plot.plot('train disc cost', _disc_cost)

        if iteration % report_every == report_every - 1:
            samples = []
            for _ in range(10):
                samples.extend(generate_samples())

            # JS divergence between generated and real character n-gram statistics.
            for i in range(4):
                lm = language_helpers.NgramLanguageModel(i+1, samples, tokenize=False)
                lib.plot.plot('js{}'.format(i+1), lm.js_with(true_char_ngram_lms[i]))

            with open('output_data/samples_{}.txt'.format(str(iteration+1).zfill(7)), 'w', encoding='utf8') as f:
                for s in samples:
                    # checkDNSFrom post-processes the decoded string before it is
                    # written (exact semantics live in language_helpers).
                    s = language_helpers.checkDNSFrom("".join(s))
                    f.write(str(s) + "\n")

            lib.plot.flush()

        lib.plot.tick()


[Start]
    1% Done! [Iteration]:        300 [300x iterations time   ]:     566.80 secs [SUM]:     566.80 secs
iter 299	time	1.8895049985249837	train disc cost	-3.2085421085357666	js4	0.3005426714280313	js1	0.05353855370632886	js3	0.24334272073455124	js2	0.15205331299283434
    2% Done! [Iteration]:        600 [300x iterations time   ]:     566.49 secs [SUM]:    1133.30 secs
iter 599	time	1.883965076605479	train disc cost	-2.2452640533447266	js4	0.26750123846770124	js1	0.04038219597254436	js3	0.18579508335213948	js2	0.09881305600986917
    3% Done! [Iteration]:        900 [300x iterations time   ]:     567.00 secs [SUM]:    1700.30 secs
iter 899	time	1.8852486324310302	train disc cost	-2.231243371963501	js4	0.2682883167928958	js1	0.029520641462756986	js3	0.17459163315501575	js2	0.08233977516286375
    4% Done! [Iteration]:       1200 [300x iterations time   ]:     566.74 secs [SUM]:    2267.04 secs
iter 1199	time	1.8849411582946778	train disc cost	-2.221461772918701	js4	0.2645278888854

   32% Done! [Iteration]:       9600 [300x iterations time   ]:     568.19 secs [SUM]:   18141.70 secs
iter 9599	time	1.888822271823883	train disc cost	-1.9706798791885376	js4	0.23524880424201156	js1	0.009239794725699017	js3	0.12214669738301491	js2	0.03459486916339452
   33% Done! [Iteration]:       9900 [300x iterations time   ]:     566.88 secs [SUM]:   18708.58 secs
iter 9899	time	1.8851279338200888	train disc cost	-1.9929325580596924	js4	0.22276778317016566	js1	0.007863665087276752	js3	0.11607918862243737	js2	0.03331419574462237
   34% Done! [Iteration]:      10200 [300x iterations time   ]:     567.89 secs [SUM]:   19276.47 secs
iter 10199	time	1.8884719928105673	train disc cost	-1.9794217348098755	js4	0.23263020686800545	js1	0.0077800413155102496	js3	0.11816443143209188	js2	0.032641573457465316
   35% Done! [Iteration]:      10500 [300x iterations time   ]:     567.82 secs [SUM]:   19844.29 secs
iter 10499	time	1.888150963783264	train disc cost	-1.9581587314605713	js4	0.234380011

   62% Done! [Iteration]:      18600 [300x iterations time   ]:     561.60 secs [SUM]:   35173.48 secs
iter 18599	time	1.8665335512161254	train disc cost	-1.8593308925628662	js4	0.21912634098410577	js1	0.0023417088396010042	js3	0.10997494362090515	js2	0.02141970559494583
   63% Done! [Iteration]:      18900 [300x iterations time   ]:     562.83 secs [SUM]:   35736.32 secs
iter 18899	time	1.8711238582928975	train disc cost	-1.886386752128601	js4	0.224328827405853	js1	0.0015831316341393565	js3	0.10761587780202928	js2	0.020737756365115212
   64% Done! [Iteration]:      19200 [300x iterations time   ]:     565.18 secs [SUM]:   36301.49 secs
iter 19199	time	1.8790230083465576	train disc cost	-1.8732694387435913	js4	0.22008417191014043	js1	0.0027565838183244507	js3	0.11083962291594321	js2	0.02197502651863635
   65% Done! [Iteration]:      19500 [300x iterations time   ]:     565.38 secs [SUM]:   36866.88 secs
iter 19499	time	1.8791134985287985	train disc cost	-1.8736333847045898	js4	0.221568

   92% Done! [Iteration]:      27600 [300x iterations time   ]:     565.42 secs [SUM]:   52208.10 secs
iter 27599	time	1.8797983407974244	train disc cost	-1.9536648988723755	js4	0.205516585052811	js1	0.0032686760255140543	js3	0.10207294073844463	js2	0.01969721271798594
   93% Done! [Iteration]:      27900 [300x iterations time   ]:     565.43 secs [SUM]:   52773.54 secs
iter 27899	time	1.8798721877733866	train disc cost	-2.0101776123046875	js4	0.21980122831367932	js1	0.0025924738054587493	js3	0.10687673619358266	js2	0.020099105662982154
   94% Done! [Iteration]:      28200 [300x iterations time   ]:     565.43 secs [SUM]:   53338.96 secs
iter 28199	time	1.879719336827596	train disc cost	-1.8762109279632568	js4	0.21404229209555412	js1	0.0014710055281318365	js3	0.10309127250619757	js2	0.018209947708532705
   95% Done! [Iteration]:      28500 [300x iterations time   ]:     565.53 secs [SUM]:   53904.49 secs
iter 28499	time	1.8794899733861288	train disc cost	-1.9469316005706787	js4	0.21015