In [1]:
import os, sys
sys.path.append(os.getcwd())

import time

import numpy as np
import tensorflow as tf

import language_helpers
import tflib as lib
import tflib.ops.linear
import tflib.ops.conv1d
import tflib.plot

In [2]:
# Download Google Billion Word at http://www.statmt.org/lm-benchmark/ and
# fill in the path to the extracted files here!
DATA_DIR = 'data/output'
if not DATA_DIR:
    raise Exception("Please specify path to data directory in gan_language.py!")

# Number of sequences per training batch.
BATCH_SIZE = 64
# Generator iterations to train for (reference default: 200000).
ITERS = 50
# Length of every training sequence, in characters.
SEQ_LEN = 32
# Model width. Fairly slow and overfits even on Billion Word; consider
# decreasing it for smaller datasets.
DIM = 512
# Critic updates per generator update. The paper's results use 10, but 5
# should work fine as well.
CRITIC_ITERS = 10
# Weight of the gradient-penalty term in the critic loss.
LAMBDA = 10
# Cap on the number of examples loaded. Lower it if loading is too slow or
# uses too much RAM, at the cost of less training data (reference default:
# 10000000).
MAX_N_EXAMPLES = 100000

In [3]:
# Log the upper-case settings defined above, then load the training corpus
# together with its character <-> id mappings.
lib.print_model_settings(dict(locals()))

dataset_options = {
    'max_length': SEQ_LEN,
    'max_n_examples': MAX_N_EXAMPLES,
    'data_dir': DATA_DIR,
}
lines, charmap, inv_charmap = language_helpers.load_dataset(**dataset_options)

Uppercase local vars:
	BATCH_SIZE: 64
	CRITIC_ITERS: 10
	DATA_DIR: data/output
	DIM: 512
	ITERS: 50
	LAMBDA: 10
	MAX_N_EXAMPLES: 100000
	SEQ_LEN: 32
loading dataset...
('T', 'h', 'e', ' ', 'f', 'e', 'w', 'e', 'r', ' ', 'm', 'i', 's', 's', 'i', 'l', 'e', 's', ' ', 'R', 'u', 's', 's', 'i', 'a', ' ', 'h', 'a', 's', ' ', ',', ' ')
('H', 'e', ' ', 'c', 'o', 'u', 'l', 'd', ' ', 'n', 'o', 't', ' ', 'e', 'x', 'p', 'l', 'a', 'i', 'n', ' ', 'w', 'h', 'y', ' ', 'S', 'h', 'a', 'k', 'i', 'r', ' ')
('W', 'e', ' ', 'p', 'u', 'n', 'i', 's', 'h', ' ', 'c', 'r', 'i', 'm', 'i', 'n', 'a', 'l', ' ', 'b', 'e', 'h', 'a', 'v', 'i', 'o', 'r', ' ', 'i', 'n', ' ', 'i')
('O', 'v', 'e', 'r', ' ', 't', 'h', 'e', ' ', 'p', 'a', 's', 't', ' ', 't', 'h', 'r', 'e', 'e', ' ', 'w', 'e', 'e', 'k', 's', ' ', 'T', 'h', 'e', ' ', 'S', 'u')
('O', 'n', ' ', 'o', 'n', 'e', ' ', 'h', 'a', 'n', 'd', ' ', ',', ' ', '3', '2', ' ', 'p', 'e', 'r', 'c', 'e', 'n', 't', ' ', 'o', 'f', ' ', 'r', 'e', 'g', 'i')
('T', 'h', 'e', ' ', 'o', '

In [4]:
def softmax(logits):
    """Softmax over groups of ``len(charmap)`` logits, preserving the input shape.

    The tensor is flattened to rows of ``len(charmap)`` entries, each row is
    normalised, and the result is reshaped back to the dynamic shape of
    ``logits``. Callers are expected to put the character axis last (as
    ``Generator`` does after its transpose).
    """
    vocab_size = len(charmap)
    flat_logits = tf.reshape(logits, [-1, vocab_size])
    flat_probs = tf.nn.softmax(flat_logits)
    return tf.reshape(flat_probs, tf.shape(logits))

In [5]:
def make_noise(shape):
    """Return a float tensor of ``tf.random_normal`` noise (library defaults) with the given ``shape``."""
    noise = tf.random_normal(shape)
    return noise

In [6]:
def ResBlock(name, inputs):
    """Residual block of two ReLU -> width-5 Conv1D (DIM -> DIM) stages.

    The residual branch is scaled by 0.3 before being added back onto
    ``inputs``. The two conv layers are registered as ``name + '.1'`` and
    ``name + '.2'``; presumably tflib shares parameters between calls that
    use the same names (verify in tflib).
    """
    residual = inputs
    for suffix in ('.1', '.2'):
        residual = tf.nn.relu(residual)
        residual = lib.ops.conv1d.Conv1D(name + suffix, DIM, DIM, 5, residual)
    return inputs + 0.3 * residual

In [7]:
def Generator(n_samples, prev_outputs=None):
    """Map ``n_samples`` noise vectors to per-position character distributions.

    Args:
        n_samples: number of sequences to generate.
        prev_outputs: unused; kept only for interface compatibility.

    Returns:
        Tensor laid out as [n_samples, SEQ_LEN, len(charmap)] whose last axis
        is a softmax over characters.
    """
    noise = make_noise(shape=[n_samples, 128])
    hidden = lib.ops.linear.Linear('Generator.Input', 128, SEQ_LEN*DIM, noise)
    # Channels-first layout ([batch, DIM, SEQ_LEN]) for the Conv1D stack.
    hidden = tf.reshape(hidden, [-1, DIM, SEQ_LEN])
    for block_idx in range(1, 6):
        hidden = ResBlock('Generator.{}'.format(block_idx), hidden)
    char_logits = lib.ops.conv1d.Conv1D('Generator.Output', DIM, len(charmap), 1, hidden)
    # Put the character axis last so the softmax normalises over it.
    char_logits = tf.transpose(char_logits, [0, 2, 1])
    return softmax(char_logits)

In [8]:
def Discriminator(inputs):
    """Score each sequence in ``inputs`` with a single unbounded critic value.

    Args:
        inputs: [batch, SEQ_LEN, len(charmap)] tensor. It may be one-hot real
            data, generator softmax output, or interpolations between the two.

    Returns:
        Output of a final Linear layer with one output unit (no activation).
    """
    # Conv1D works channels-first, so move the character axis to position 1.
    channels_first = tf.transpose(inputs, [0, 2, 1])
    features = lib.ops.conv1d.Conv1D('Discriminator.Input', len(charmap), DIM, 1, channels_first)
    for block_idx in range(1, 6):
        features = ResBlock('Discriminator.{}'.format(block_idx), features)
    flat_features = tf.reshape(features, [-1, SEQ_LEN*DIM])
    return lib.ops.linear.Linear('Discriminator.Output', SEQ_LEN*DIM, 1, flat_features)

In [9]:
# Real batches are fed as integer character ids and one-hot encoded in-graph;
# the fake batch is the generator's per-position softmax output.
real_inputs_discrete = tf.placeholder(dtype=tf.int32, shape=[BATCH_SIZE, SEQ_LEN])
real_inputs = tf.one_hot(real_inputs_discrete, depth=len(charmap))
fake_inputs = Generator(BATCH_SIZE)
# Most likely character id at every position of every generated sequence.
last_axis = fake_inputs.get_shape().ndims - 1
fake_inputs_discrete = tf.argmax(fake_inputs, axis=last_axis)

In [10]:
# Critic scores for the real and the generated batch (both calls use the
# same 'Discriminator.*' layer names).
disc_real, disc_fake = Discriminator(real_inputs), Discriminator(fake_inputs)

In [11]:
# WGAN objectives: the critic minimises mean(D(fake)) - mean(D(real));
# the generator minimises -mean(D(fake)).
mean_fake_score = tf.reduce_mean(disc_fake)
mean_real_score = tf.reduce_mean(disc_real)
disc_cost = mean_fake_score - mean_real_score
gen_cost = -mean_fake_score

In [12]:
# WGAN lipschitz-penalty
# One interpolation weight per example, uniform in [0, 1). The trailing
# singleton axes let it broadcast over the [SEQ_LEN, len(charmap)] axes.
alpha = tf.random_uniform([BATCH_SIZE, 1, 1], 0., 1.)

In [13]:
# Gradient penalty: evaluate the critic at random points between real and
# generated batches and penalise gradient norms that deviate from 1.
differences = fake_inputs - real_inputs
interpolates = real_inputs + (alpha * differences)
gradients = tf.gradients(Discriminator(interpolates), [interpolates])[0]
# Per-example L2 norm of the gradient over the sequence and character axes.
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2]))
gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)
# Rebuild the critic cost from its Wasserstein term instead of `+=`.
# Re-running this cell therefore doesn't stack the penalty a second time.
# Keep the first two terms in sync with the WGAN cost cell.
disc_cost = (tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)) + LAMBDA * gradient_penalty

In [14]:
# Trainable variables of each network, selected by their layer-name prefix.
gen_params, disc_params = [
    lib.params_with_name(prefix) for prefix in ('Generator', 'Discriminator')
]

In [15]:
# Both networks use Adam with identical hyperparameters; each optimizer only
# updates its own network's variables.
adam_hparams = {'learning_rate': 1e-4, 'beta1': 0.5, 'beta2': 0.9}
gen_train_op = tf.train.AdamOptimizer(**adam_hparams).minimize(gen_cost, var_list=gen_params)
disc_train_op = tf.train.AdamOptimizer(**adam_hparams).minimize(disc_cost, var_list=disc_params)

In [16]:
# Dataset iterator
def inf_train_gen(data=None, batch_size=None, char_to_id=None):
    """Yield int32 batches of encoded training lines forever.

    Each pass over the data uses a fresh random order and walks it in
    non-overlapping chunks of ``batch_size``; a trailing partial chunk is
    dropped.

    Args:
        data: sequence of character sequences. Defaults to the notebook-level
            ``lines``. Assumes all lines have the same length (e.g. padded to
            SEQ_LEN) — not checked here.
        batch_size: rows per batch. Defaults to ``BATCH_SIZE``.
        char_to_id: mapping from character to integer id. Defaults to ``charmap``.

    Yields:
        np.ndarray of dtype int32 with one row of ids per line.

    Raises:
        ValueError: if ``batch_size`` is not positive or there are fewer than
            ``batch_size`` examples. Without this check the generator would
            loop forever without yielding anything.
    """
    if data is None:
        data = lines
    if batch_size is None:
        batch_size = BATCH_SIZE
    if char_to_id is None:
        char_to_id = charmap
    if batch_size <= 0 or len(data) < batch_size:
        raise ValueError(
            "need at least batch_size={} examples, got {}".format(batch_size, len(data))
        )
    while True:
        # Shuffle indices rather than `data` itself, so the caller's list
        # (e.g. the global `lines`) is left untouched.
        order = np.random.permutation(len(data))
        for start in range(0, len(data) - batch_size + 1, batch_size):
            yield np.array(
                [[char_to_id[c] for c in data[k]] for k in order[start:start + batch_size]],
                dtype='int32',
            )

In [17]:
# During training we monitor JS divergence between the true & generated ngram
# distributions for n=1,2,3,4. To get an idea of the optimal values, we
# evaluate these statistics on a held-out set first.
n_held_out = 10 * BATCH_SIZE
ngram_orders = [1, 2, 3, 4]

# Reference models on everything except the first `n_held_out` lines...
true_char_ngram_lms = [
    language_helpers.NgramLanguageModel(n, lines[n_held_out:], tokenize=False)
    for n in ngram_orders
]
# ...and on the held-out lines only.
validation_char_ngram_lms = [
    language_helpers.NgramLanguageModel(n, lines[:n_held_out], tokenize=False)
    for n in ngram_orders
]
for n, true_lm, validation_lm in zip(ngram_orders, true_char_ngram_lms, validation_char_ngram_lms):
    print("validation set JSD for n={}: {}".format(n, true_lm.js_with(validation_lm)))

# Rebuild the reference models on the full corpus; the training loop compares
# generated samples against these.
true_char_ngram_lms = [
    language_helpers.NgramLanguageModel(n, lines, tokenize=False)
    for n in ngram_orders
]

validation set JSD for n=1: 0.0009053959053045796
validation set JSD for n=2: 0.015119320254753566
validation set JSD for n=3: 0.08686910351792093
validation set JSD for n=4: 0.2448777496067526


In [18]:
# Generator iterations between progress reports and sample dumps. A report is
# also forced on the final iteration, so runs shorter than this interval
# (e.g. ITERS = 50) still print timing and write a samples file.
SAMPLE_EVERY = 100
SAMPLES_DIR = 'output'
# open() below fails if the directory is missing, so create it up front.
os.makedirs(SAMPLES_DIR, exist_ok=True)

with tf.Session() as session:

    session.run(tf.global_variables_initializer())

    def generate_samples():
        """Run the generator once; decode each sequence to a tuple of characters."""
        samples = session.run(fake_inputs)
        samples = np.argmax(samples, axis=2)
        return [tuple(inv_charmap[char_id] for char_id in row) for row in samples]

    gen = inf_train_gen()

    sum_time = 0.
    # perf_counter replaces time.clock(), which was removed in Python 3.8.
    now_time = time.perf_counter()
    print("[Start]")

    for iteration in range(ITERS):
        start_time = time.time()

        # Train generator. It is not updated on iteration 0, presumably so
        # the critic gets one round of training first.
        if iteration > 0:
            session.run(gen_train_op)

        # Train critic
        for _ in range(CRITIC_ITERS):
            _data = next(gen)
            _disc_cost, _ = session.run(
                [disc_cost, disc_train_op],
                feed_dict={real_inputs_discrete: _data}
            )

        report = (iteration + 1) % SAMPLE_EVERY == 0 or iteration == ITERS - 1

        if report:
            print("[Iteration]:" + str(iteration + 1))
            after_time = time.perf_counter() - now_time
            sum_time += after_time
            # Projected total time (elapsed / fraction done) minus elapsed time.
            eta = sum_time / ((iteration + 1) / ITERS) - sum_time
            print("[Cost time]: " + str(after_time) + " [Sum time]: " + str(sum_time) + " [ETA]: " + str(eta))
            now_time = time.perf_counter()

        lib.plot.plot('time', time.time() - start_time)
        lib.plot.plot('train disc cost', _disc_cost)

        if report:
            samples = []
            for _ in range(10):
                samples.extend(generate_samples())

            # JS divergence of generated vs. true n-gram models, n = 1..4.
            for i in range(4):
                lm = language_helpers.NgramLanguageModel(i + 1, samples, tokenize=False)
                lib.plot.plot('js{}'.format(i + 1), lm.js_with(true_char_ngram_lms[i]))

            samples_path = os.path.join(SAMPLES_DIR, 'samples_{}.txt'.format(iteration))
            with open(samples_path, 'w', encoding='utf8') as f:
                for s in samples:
                    f.write("".join(s) + "\n")

        # NOTE: lib.plot.flush() is deliberately not called (it was commented
        # out in this notebook).
        lib.plot.tick()


Start
Cost time: 20.092146817616428 Sum time: 20.092146817616428 ETA: 80.36858727046571
ITER:10
Cost time: 19.166597396026273 Sum time: 39.2587442136427 ETA: 58.88811632046405
ITER:20
Cost time: 19.229042770102794 Sum time: 58.487786983745494 ETA: 38.99185798916366
ITER:30
Cost time: 19.216719222000357 Sum time: 77.70450620574584 ETA: 19.426126551436454
ITER:40
Cost time: 19.247280223834522 Sum time: 96.95178642958037 ETA: 0.0
ITER:50
