## sequence_gan.py 

In [None]:
import numpy as np
import tensorflow as tf
import random
from dataloader import Gen_Data_loader, Dis_dataloader
from generator import Generator
from discriminator import Discriminator
from rollout import ROLLOUT
from target_lstm import TARGET_LSTM
import cPickle

#########################################################################################
#  Generator  Hyper-parameters
#  (EMB_DIM, HIDDEN_DIM, SEQ_LENGTH and START_TOKEN are passed to both Generator and
#  TARGET_LSTM in main(), so the generator and the oracle share these settings.)
#########################################################################################
EMB_DIM = 32 # embedding dimension
HIDDEN_DIM = 32 # hidden state dimension of lstm cell
SEQ_LENGTH = 20 # sequence length
START_TOKEN = 0 # main() asserts that this stays 0
# Number of supervised (maximum likelihood estimation) pre-training epochs for the generator.
# NOTE(review): the trailing "120" looks like the full-length setting that was reduced
# to 10 for a quick run -- confirm before relying on reported numbers.
PRE_EPOCH_NUM = 10 #120
SEED = 88 # seeds both `random` and `np.random` in main()
BATCH_SIZE = 64 # batch size given to the data loaders, the generator and the oracle

#########################################################################################
#  Discriminator  Hyper-parameters
#########################################################################################
dis_embedding_dim = 64 # passed to Discriminator as embedding_size
# dis_filter_sizes and dis_num_filters are parallel lists (12 entries each) passed to
# Discriminator as filter_sizes / num_filters.
dis_filter_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20]
dis_num_filters = [100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160]
dis_dropout_keep_prob = 0.75 # fed to discriminator.dropout_keep_prob on every training step
dis_l2_reg_lambda = 0.2 # passed to Discriminator as l2_reg_lambda
# NOTE(review): not referenced anywhere in the visible code; main() builds
# Dis_dataloader with BATCH_SIZE instead (same value today).
dis_batch_size = 64

#########################################################################################
#  Basic Training Parameters
#########################################################################################
TOTAL_BATCH = 200 # number of adversarial training iterations in main()
positive_file = 'save/real_data.txt' # samples drawn from the oracle (target_lstm)
negative_file = 'save/generator_sample.txt' # generator samples used to train the discriminator
eval_file = 'save/eval_file.txt' # generator samples scored by the oracle (target_loss)
generated_num = 10000 # number of sequences requested per generate_samples() call


def generate_samples(sess, trainable_model, batch_size, generated_num, output_file):
    """Sample sequences from ``trainable_model`` and write them to ``output_file``.

    ``trainable_model.generate(sess)`` is called ``int(generated_num / batch_size)``
    times, so any remainder of ``generated_num`` that does not fill a whole batch
    is dropped (presumably each call yields ``batch_size`` sequences -- confirm
    against the model's ``generate``).

    The file is overwritten. Each sequence becomes one line, with its tokens
    converted via ``str`` and separated by single spaces.
    """
    num_rounds = int(generated_num / batch_size)

    # Collect everything first so the output file is only touched once sampling succeeded.
    sequences = []
    for _ in range(num_rounds):
        sequences += list(trainable_model.generate(sess))

    with open(output_file, 'w') as fout:
        fout.writelines(' '.join(map(str, tokens)) + '\n' for tokens in sequences)


def target_loss(sess, target_lstm, data_loader):
    """Return the mean oracle negative log-likelihood over one pass of ``data_loader``.

    "Target loss" is the NLL of the loaded sequences as scored by the oracle
    model ``target_lstm``. For more details, see Section 4 of
    https://arxiv.org/abs/1609.05473

    Args:
        sess: session used to evaluate ``target_lstm.pretrain_loss``.
        target_lstm: oracle model exposing ``pretrain_loss`` and the input
            placeholder ``x``.
        data_loader: loader exposing ``reset_pointer()``, ``num_batch`` and
            ``next_batch()``; it is rewound before iterating.

    Returns:
        ``np.mean`` of the per-batch losses.
    """
    data_loader.reset_pointer()

    nll = []
    # `range` (not the Python-2-only `xrange`) keeps this runnable on Python 2 and 3
    # and matches the other loops in this file.
    for _ in range(data_loader.num_batch):
        batch = data_loader.next_batch()
        nll.append(sess.run(target_lstm.pretrain_loss, {target_lstm.x: batch}))

    return np.mean(nll)


def pre_train_epoch(sess, trainable_model, data_loader):
    """Pre-train the generator with MLE for one epoch and return the mean loss.

    Args:
        sess: session handed through to ``trainable_model.pretrain_step``.
        trainable_model: model exposing ``pretrain_step(sess, batch)``, which
            returns a pair whose second element is the batch loss.
        data_loader: loader exposing ``reset_pointer()``, ``num_batch`` and
            ``next_batch()``; it is rewound before iterating.

    Returns:
        ``np.mean`` of the per-batch supervised losses.
    """
    data_loader.reset_pointer()

    supervised_g_losses = []
    # `range` (not the Python-2-only `xrange`) keeps this runnable on Python 2 and 3
    # and matches the other loops in this file.
    for _ in range(data_loader.num_batch):
        batch = data_loader.next_batch()
        _, g_loss = trainable_model.pretrain_step(sess, batch)
        supervised_g_losses.append(g_loss)

    return np.mean(supervised_g_losses)


def _evaluate_generator(sess, generator, target_lstm, likelihood_data_loader):
    """Sample from ``generator`` into ``eval_file`` and return the oracle NLL of those samples."""
    generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
    likelihood_data_loader.create_batches(eval_file)
    return target_loss(sess, target_lstm, likelihood_data_loader)


def _train_discriminator(sess, discriminator, generator, dis_data_loader, rounds, epochs_per_round):
    """Train the discriminator for ``rounds`` rounds of ``epochs_per_round`` epochs each.

    Every round first regenerates ``negative_file`` from the current generator
    and reloads the positive/negative training data.
    """
    for _ in range(rounds):
        generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)
        dis_data_loader.load_train_data(positive_file, negative_file)
        for _ in range(epochs_per_round):
            dis_data_loader.reset_pointer()
            for _ in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                }
                sess.run(discriminator.train_op, feed)


def main():
    """Run the full SeqGAN experiment against the synthetic oracle.

    Steps, in order:
      1. Build the generator, the oracle (``TARGET_LSTM`` initialised from
         'save/target_params.pkl') and the discriminator.
      2. Sample ``positive_file`` from the oracle.
      3. Pre-train the generator with MLE for ``PRE_EPOCH_NUM`` epochs,
         evaluating oracle NLL every 5 epochs.
      4. Pre-train the discriminator (5 rounds of 1 epoch).
      5. Run ``TOTAL_BATCH`` adversarial iterations: one policy-gradient
         generator step, periodic evaluation, a roll-out parameter update,
         then discriminator training (5 rounds of 3 epochs).

    Progress is printed and written to 'save/experiment-log.txt'. The 'save/'
    directory and the oracle parameter pickle must already exist.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE)  # For testing
    vocab_size = 5000
    dis_data_loader = Dis_dataloader(BATCH_SIZE)

    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)
    # NOTE: unpickling executes arbitrary code; only load parameter files from a trusted source.
    # Pickles are binary data, so open in 'rb' and close the handle deterministically.
    with open('save/target_params.pkl', 'rb') as params_file:
        target_params = cPickle.load(params_file)
    # The oracle model
    target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN, target_params)

    # Use SEQ_LENGTH (not a literal 20) so the discriminator always matches the generator.
    discriminator = Discriminator(sequence_length=SEQ_LENGTH, num_classes=2, vocab_size=vocab_size,
                                  embedding_size=dis_embedding_dim, filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters, l2_reg_lambda=dis_l2_reg_lambda)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # Context managers guarantee the session and the log file are released even if training fails.
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        # First, use the oracle model to provide the positive examples, which are sampled
        # from the oracle data distribution
        generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, positive_file)
        gen_data_loader.create_batches(positive_file)

        with open('save/experiment-log.txt', 'w') as log:
            # Pre-train the generator
            print('Start pre-training...')
            log.write('pre-training...\n')
            for epoch in range(PRE_EPOCH_NUM):
                pre_train_epoch(sess, generator, gen_data_loader)
                if epoch % 5 == 0:
                    test_loss = _evaluate_generator(sess, generator, target_lstm, likelihood_data_loader)
                    # Single pre-formatted string: identical output under Python 2 and 3.
                    print('pre-train epoch  %s test_loss  %s' % (epoch, test_loss))
                    log.write('epoch:\t%s\tnll:\t%s\n' % (epoch, test_loss))

            print('Start pre-training discriminator...')
            # 5 rounds; each round resamples the negatives and trains for 1 epoch
            _train_discriminator(sess, discriminator, generator, dis_data_loader,
                                 rounds=5, epochs_per_round=1)

            # Built after pre-training, as in the original flow.
            # NOTE(review): 0.8 is presumably the roll-out update rate -- confirm in rollout.py.
            rollout = ROLLOUT(generator, 0.8)

            print('#########################################################################')
            print('Start Adversarial Training...')
            log.write('adversarial training...\n')
            for total_batch in range(TOTAL_BATCH):
                # Train the generator for one step
                for _ in range(1):
                    samples = generator.generate(sess)
                    # NOTE(review): 16 is presumably the number of roll-outs per sample -- confirm.
                    rewards = rollout.get_reward(sess, samples, 16, discriminator)
                    feed = {generator.x: samples, generator.rewards: rewards}
                    sess.run(generator.g_updates, feed_dict=feed)

                # Test
                if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
                    test_loss = _evaluate_generator(sess, generator, target_lstm, likelihood_data_loader)
                    print('total_batch:  %s test_loss:  %s' % (total_batch, test_loss))
                    log.write('epoch:\t%s\tnll:\t%s\n' % (total_batch, test_loss))

                # Update roll-out parameters
                rollout.update_params()

                # Train the discriminator: 5 rounds of 3 epochs on freshly sampled negatives
                _train_discriminator(sess, discriminator, generator, dis_data_loader,
                                     rounds=5, epochs_per_round=3)


# Run the full training pipeline only when executed as a script (or notebook cell),
# not when this module is imported.
if __name__ == '__main__':
    main()


Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Use tf.random.categorical instead.
Instructions for updating:
Use tf.cast instead.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
Instructions for updating:

Future major versions of TensorFlow will allow gradients to flow
into the labels input on backprop by default.

See `tf.nn.softmax_cross_entropy_with_logits_v2`.

Start pre-training...
pre-train epoch  0 test_loss  10.200813
pre-train epoch  5 test_loss  9.460708
Start pre-training discriminator...
#########################################################################
Start Adversarial Training...
Reward Array:  [array([5.173969 , 5.339993 , 4.871792 , 5.1443586, 5.542395 , 5.34111  ,
       4.7689176, 5.1707554, 4.6431193, 5.0064697, 5.2362967, 4.741186 ,
       5.229604 , 5.6311927, 5.726319 , 5.6285233, 5.1550283, 4.880095 ,
       5.502352 , 5.2734056, 5.79410

total_batch:  0 test_loss:  9.272493
Reward Array:  [array([1.2703019 , 2.4404438 , 1.804601  , 1.9956428 , 2.1992228 ,
       1.2481035 , 2.1852825 , 2.3950205 , 1.4819905 , 2.0728393 ,
       1.715114  , 1.683116  , 0.57178944, 1.5281397 , 0.74102706,
       2.1132207 , 2.302852  , 2.1678314 , 1.6924263 , 1.1458981 ,
       1.3319407 , 1.8463556 , 2.4485223 , 0.535472  , 3.2385154 ,
       1.8600992 , 1.0515375 , 1.9664973 , 1.4376103 , 1.2788932 ,
       2.1761734 , 2.9701445 , 1.4965684 , 2.598865  , 2.0385747 ,
       1.3092413 , 2.6117337 , 1.8214861 , 1.3518908 , 1.8347168 ,
       1.666344  , 0.6920674 , 2.8186872 , 2.3364356 , 0.7910003 ,
       1.7025017 , 1.8441777 , 1.4811486 , 1.0249808 , 1.5210589 ,
       0.6699518 , 2.0291204 , 2.336203  , 1.5932909 , 2.526305  ,
       1.2131729 , 2.22709   , 1.6130669 , 2.2781882 , 3.7785146 ,
       1.5512187 , 2.24656   , 1.3681229 , 1.6216038 ], dtype=float32), array([0.6903723 , 1.3722866 , 2.4062572 , 1.5133352 , 1.5896636 ,
    

Reward Array:  [array([0.11106391, 0.37270677, 0.19844389, 0.46246883, 0.21587002,
       0.01005491, 0.12696846, 2.3209434 , 0.58449626, 0.4484184 ,
       0.36525714, 0.7894664 , 0.10814165, 0.30807915, 0.17365336,
       0.02694932, 0.95055306, 0.90344024, 0.69101065, 0.44952792,
       0.78332376, 1.0671699 , 0.13120738, 0.8227654 , 0.27703345,
       0.3644327 , 0.12946534, 1.0249814 , 0.24322653, 1.231378  ,
       0.98878306, 0.68164766, 0.49600467, 0.6247019 , 2.6012852 ,
       0.6754242 , 0.92050946, 0.48793292, 0.15851721, 0.8557602 ,
       1.0434096 , 1.715964  , 0.5612669 , 1.6094983 , 0.5887856 ,
       0.8031151 , 0.25837836, 1.0539907 , 0.5300164 , 2.1723084 ,
       1.0167047 , 1.8769113 , 0.9111254 , 0.9401958 , 0.5825691 ,
       1.3233831 , 0.32909882, 1.0138904 , 0.76034844, 0.7454566 ,
       1.2035108 , 1.5000604 , 0.01452411, 0.6267717 ], dtype=float32), array([0.677451  , 0.6611046 , 0.9963552 , 0.28494883, 0.6066801 ,
       0.27904153, 0.94240046, 0.33910608