In [1]:
import numpy as np
import tensorflow as tf
import random
import time
from dataloader import Gen_Data_loader, Dis_dataloader
from generator import Generator
from discriminator import Discriminator
from rollout import ROLLOUT
#from target_lstm import TARGET_LSTM
#import domain_translate
import pickle

In [2]:
#########################################################################################
#  Generator  Hyper-parameters
######################################################################################
EMB_DIM = 32 # embedding dimension
HIDDEN_DIM = 32 # hidden state dimension of lstm cell
SEQ_LENGTH = 32 # sequence length
START_TOKEN = 0
PRE_EPOCH_NUM = 20 # supervise (maximum likelihood estimation) epochs (預設為120)
SEED = 88
BATCH_SIZE = 64

In [3]:
#########################################################################################
#  Discriminator  Hyper-parameters
#########################################################################################
dis_embedding_dim = 64
dis_filter_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20]
dis_num_filters = [100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160]
dis_dropout_keep_prob = 0.75
dis_l2_reg_lambda = 0.2
dis_batch_size = 64
PRETRAIN_DIS_NUM = 20 #pre-train discriminator times

In [4]:
#########################################################################################
#  Basic Training Parameters
#########################################################################################
TOTAL_BATCH = 500
positive_file = ('../Dataset/AlexaTop100K_Separated_Digital/'
                  +'TopDomainName.Less33.Separated.Digital-ALL-OF-100000.txt') # 預設值為 'save/real_data.txt'
output_path = 'save/' #預設值為 'save/generator_sample.txt' 更改為會更變
#eval_file = 'save/eval_file.txt'
generated_num = 10000

In [5]:
def generate_samples(sess, trainable_model, batch_size, generated_num, output_file):
    Digit2Real = False
    # Generate Samples
    generated_samples = []
    for _ in range(int(generated_num / batch_size)):
        generated_samples.extend(trainable_model.generate(sess))

    with open(output_file, 'w') as fout:
        for poem in generated_samples:
            buffer = ' '.join([str(x) for x in poem]) + '\n'
            fout.write(buffer)
            
    if "adversarial_gen" in output_file:
        now=output_file[37:]
        output_file="save/adversarial_gen/generator_fake_domain_name_"+now
        Digit2Real = True
    elif "pretrain_gen" in output_file:
        now=output_file[34:]
        output_file="save/pretrain_gen/generator_fake_domain_name_"+now
        Digit2Real = True
    
    if Digit2Real == True :
        digital_table = [" ","0","1","2","3","4","5","6","7","8","9",".","-",
                            "a","b","c","d","e","f","g","h","i","j","k","l","m",
                            "n","o","p","q","r","s","t","u","v","w","x","y","z","_"]    
        with open(output_file, 'w') as fout:
            for poem in generated_samples:
                buffer = "".join([digital_table[int(x)] for x in poem]) + '\n'
                fout.write(buffer)                    

In [6]:
def target_loss(sess, target_lstm, data_loader):
    # target_loss means the oracle negative log-likelihood tested with the oracle model "target_lstm"
    # For more details, please see the Section 4 in https://arxiv.org/abs/1609.05473
    nll = []
    data_loader.reset_pointer()

    for it in range(data_loader.num_batch):
        batch = data_loader.next_batch()
        #print(target_lstm.pretrain_loss)
        #print({target_lstm.x: batch})
        g_loss = sess.run(target_lstm.pretrain_loss, {target_lstm.x: batch})
        nll.append(g_loss)

    return np.mean(nll)

In [7]:
def pre_train_epoch(sess, trainable_model, data_loader):
    # Pre-train the generator using MLE for one epoch
    supervised_g_losses = []
    data_loader.reset_pointer()
    
    for it in range(data_loader.num_batch):
    #for it in range(50):
        #print("pre_train"+str(it))
        if it % int(data_loader.num_batch / 30) == 0:
            print("pre_train_iteration : {0:6} / {1:6}".format( (it+1), (data_loader.num_batch) ), end="\r")
        batch = data_loader.next_batch()
        _, g_loss = trainable_model.pretrain_step(sess, batch)
        supervised_g_losses.append(g_loss)

    return np.mean(supervised_g_losses)

In [8]:
random.seed(SEED)
np.random.seed(SEED)
assert START_TOKEN == 0

gen_data_loader = Gen_Data_loader(BATCH_SIZE)
likelihood_data_loader = Gen_Data_loader(BATCH_SIZE) # For testing
vocab_size = 40 #預設5000 
dis_data_loader = Dis_dataloader(BATCH_SIZE)

generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)
#target_params = pickle.load(open('save/target_params_py3.pkl','rb'))
#target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN, target_params) # The oracle model

discriminator = Discriminator(sequence_length=32, num_classes=2, vocab_size=vocab_size, embedding_size=dis_embedding_dim, 
                            filter_sizes=dis_filter_sizes, num_filters=dis_num_filters, l2_reg_lambda=dis_l2_reg_lambda)

In [9]:
# 參數配置
config = tf.ConfigProto()
# 使用allow_growth option，剛一開始分配少量的GPU容量，然後按需慢慢的增加
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())

In [10]:
# First, use the oracle model to provide the positive examples, which are sampled from the oracle data distribution
#generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, positive_file)
gen_data_loader.create_batches(positive_file)

log = open('save/experiment-log.txt', 'w')
#  pre-train generator
print('Start pre-training...')
log.write('pre-training...\n')
now_time = time.clock()
sum_time = 0.
for epoch in range(PRE_EPOCH_NUM):
    loss = pre_train_epoch(sess, generator, gen_data_loader)    
    
    eval_file = ( output_path + "pretrain_gen/generator_digit_{0}.txt").format(str(epoch+1).zfill(2))
    generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
    likelihood_data_loader.create_batches(eval_file)
    
    '''test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
    print('pre-train epoch ', epoch, 'test_loss ', test_loss)
    buffer = 'epoch:\t'+ str(epoch) + '\tnll:\t' + str(test_loss) + '\n'
    log.write(buffer)'''
    after_time=time.clock() - now_time
    eta_time = (after_time / (epoch+1) )*(PRE_EPOCH_NUM-(epoch+1))
    print("[MLE Epoch]: {0:5} [Cost Time]: {1:10.2f} secs [ETA]: {2:10.2f} secs".format( (epoch+1), after_time, eta_time))

Start pre-training...
[MLE Epoch]:     1 [Cost Time]:      70.18 secs [ETA]:    1333.40 secs
[MLE Epoch]:     2 [Cost Time]:     139.07 secs [ETA]:    1251.65 secs
[MLE Epoch]:     3 [Cost Time]:     207.76 secs [ETA]:    1177.28 secs
[MLE Epoch]:     4 [Cost Time]:     276.43 secs [ETA]:    1105.72 secs
[MLE Epoch]:     5 [Cost Time]:     345.00 secs [ETA]:    1035.00 secs
[MLE Epoch]:     6 [Cost Time]:     414.15 secs [ETA]:     966.35 secs
[MLE Epoch]:     7 [Cost Time]:     482.83 secs [ETA]:     896.68 secs
[MLE Epoch]:     8 [Cost Time]:     551.42 secs [ETA]:     827.12 secs
[MLE Epoch]:     9 [Cost Time]:     620.21 secs [ETA]:     758.03 secs
[MLE Epoch]:    10 [Cost Time]:     688.91 secs [ETA]:     688.91 secs
[MLE Epoch]:    11 [Cost Time]:     757.22 secs [ETA]:     619.55 secs
[MLE Epoch]:    12 [Cost Time]:     825.61 secs [ETA]:     550.41 secs
[MLE Epoch]:    13 [Cost Time]:     894.04 secs [ETA]:     481.41 secs
[MLE Epoch]:    14 [Cost Time]:     962.38 secs [ETA]: 

In [11]:
print('Start pre-training discriminator...')
# Train 3 epoch on the generated data and do this for 50 times
now_time = time.clock()
sum_time = 0.
pre_D_epoch = 0
for _ in range(PRETRAIN_DIS_NUM): #50
    pretrain_D = ( output_path + "pretrain_discriminator.txt")
    generate_samples(sess, generator, BATCH_SIZE, generated_num, pretrain_D)
    dis_data_loader.load_train_data(positive_file, pretrain_D)
    for _ in range(1):
        dis_data_loader.reset_pointer()
        for it in range(dis_data_loader.num_batch):
            if it % int(dis_data_loader.num_batch / 30) == 0:
                print("pre_train_iteration : {0:6} / {1:6}".format( (it+1), (dis_data_loader.num_batch) ), end="\r")
            x_batch, y_batch = dis_data_loader.next_batch()
            feed = {
                discriminator.input_x: x_batch,
                discriminator.input_y: y_batch,
                discriminator.dropout_keep_prob: dis_dropout_keep_prob
            }
            _ = sess.run(discriminator.train_op, feed)
    after_time=time.clock() - now_time
    eta_time = (after_time / (pre_D_epoch+1) )*(PRETRAIN_DIS_NUM-(pre_D_epoch+1))
    print("[Pre-train Discriminator Epoch]: {0:5} [Cost Time]: {1:10.2f} secs [ETA]: {2:10.2f} secs".format( (pre_D_epoch+1), after_time, eta_time))
    pre_D_epoch = pre_D_epoch +1

rollout = ROLLOUT(generator, 0.8)

Start pre-training discriminator...
[Pre-train Discriminator Epoch]:     1 [Cost Time]:      35.35 secs [ETA]:     671.56 secs
[Pre-train Discriminator Epoch]:     2 [Cost Time]:      65.34 secs [ETA]:     588.02 secs
[Pre-train Discriminator Epoch]:     3 [Cost Time]:      94.83 secs [ETA]:     537.38 secs
[Pre-train Discriminator Epoch]:     4 [Cost Time]:     124.49 secs [ETA]:     497.96 secs
[Pre-train Discriminator Epoch]:     5 [Cost Time]:     154.08 secs [ETA]:     462.25 secs
[Pre-train Discriminator Epoch]:     6 [Cost Time]:     183.50 secs [ETA]:     428.16 secs
[Pre-train Discriminator Epoch]:     7 [Cost Time]:     213.14 secs [ETA]:     395.83 secs
[Pre-train Discriminator Epoch]:     8 [Cost Time]:     242.45 secs [ETA]:     363.68 secs
[Pre-train Discriminator Epoch]:     9 [Cost Time]:     271.75 secs [ETA]:     332.14 secs
[Pre-train Discriminator Epoch]:    10 [Cost Time]:     301.13 secs [ETA]:     301.13 secs
[Pre-train Discriminator Epoch]:    11 [Cost Time]:   

In [12]:
print('#########################################################################')
print('Start Adversarial Training...')
log.write('adversarial training...\n')
now_time = time.clock()
sum_time = 0.
for total_batch in range(TOTAL_BATCH):
    # Train the generator for one step
    for it in range(1):
        samples = generator.generate(sess)
        rewards = rollout.get_reward(sess, samples, 16, discriminator)
        feed = {generator.x: samples, generator.rewards: rewards}
        _ = sess.run(generator.g_updates, feed_dict=feed)

    # Test
    '''if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
        generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
        likelihood_data_loader.create_batches(eval_file)
        #print(total_batch)        
        test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
        buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(test_loss) + '\n'
        print('total_batch: ', total_batch, 'test_loss: ', test_loss)
        log.write(buffer)'''

    # Update roll-out parameters
    rollout.update_params()

    # Train the discriminator
    for _ in range(1): #default value=range(5)
        adversarial_D = ( output_path + "adversarial_gen/generator_digit_{0}.txt").format(str(total_batch+1).zfill(3))
        generate_samples(sess, generator, BATCH_SIZE, generated_num, adversarial_D)
        dis_data_loader.load_train_data(positive_file, adversarial_D)
        
        for _ in range(1): #default value=range(3)
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                if it % int(dis_data_loader.num_batch / 30) == 0:
                    print("Discriminator Epoch {0:5} iteration: {1:6} / {2:6}".format( (total_batch+1), (it+1), (dis_data_loader.num_batch) ), end="\r")
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op, feed)
                
    after_time=time.clock() - now_time
    eta_time = (after_time / (total_batch+1) )*(TOTAL_BATCH-(total_batch+1))
    print("[GAN Epoch]: {0:5} [Cost Time]: {1:10.2f} secs [ETA]: {2:10.2f} secs".format( (total_batch+1), after_time, eta_time))
    
log.close()

#########################################################################
Start Adversarial Training...
[GAN Epoch]:     1 [Cost Time]:      39.94 secs [ETA]:   19928.05 secs
[GAN Epoch]:     2 [Cost Time]:      79.26 secs [ETA]:   19735.14 secs
[GAN Epoch]:     3 [Cost Time]:     118.22 secs [ETA]:   19585.73 secs
[GAN Epoch]:     4 [Cost Time]:     157.09 secs [ETA]:   19479.01 secs
[GAN Epoch]:     5 [Cost Time]:     196.32 secs [ETA]:   19435.42 secs
[GAN Epoch]:     6 [Cost Time]:     235.55 secs [ETA]:   19393.70 secs
[GAN Epoch]:     7 [Cost Time]:     274.25 secs [ETA]:   19315.12 secs
[GAN Epoch]:     8 [Cost Time]:     313.23 secs [ETA]:   19263.40 secs
[GAN Epoch]:     9 [Cost Time]:     352.12 secs [ETA]:   19210.07 secs
[GAN Epoch]:    10 [Cost Time]:     391.18 secs [ETA]:   19168.03 secs
[GAN Epoch]:    11 [Cost Time]:     430.06 secs [ETA]:   19118.09 secs
[GAN Epoch]:    12 [Cost Time]:     469.17 secs [ETA]:   19079.48 secs
[GAN Epoch]:    13 [Cost Time]:     508.42 s

[GAN Epoch]:   115 [Cost Time]:    4497.24 secs [ETA]:   15055.98 secs
[GAN Epoch]:   116 [Cost Time]:    4536.14 secs [ETA]:   15016.20 secs
[GAN Epoch]:   117 [Cost Time]:    4575.51 secs [ETA]:   14977.97 secs
[GAN Epoch]:   118 [Cost Time]:    4614.68 secs [ETA]:   14939.05 secs
[GAN Epoch]:   119 [Cost Time]:    4653.61 secs [ETA]:   14899.36 secs
[GAN Epoch]:   120 [Cost Time]:    4692.99 secs [ETA]:   14861.13 secs
[GAN Epoch]:   121 [Cost Time]:    4732.02 secs [ETA]:   14821.80 secs
[GAN Epoch]:   122 [Cost Time]:    4771.62 secs [ETA]:   14784.19 secs
[GAN Epoch]:   123 [Cost Time]:    4811.06 secs [ETA]:   14746.09 secs
[GAN Epoch]:   124 [Cost Time]:    4850.38 secs [ETA]:   14707.60 secs
[GAN Epoch]:   125 [Cost Time]:    4890.25 secs [ETA]:   14670.76 secs
[GAN Epoch]:   126 [Cost Time]:    4929.50 secs [ETA]:   14632.01 secs
[GAN Epoch]:   127 [Cost Time]:    4969.08 secs [ETA]:   14594.22 secs
[GAN Epoch]:   128 [Cost Time]:    5009.44 secs [ETA]:   14558.68 secs
[GAN E

[GAN Epoch]:   230 [Cost Time]:    9001.20 secs [ETA]:   10566.63 secs
[GAN Epoch]:   231 [Cost Time]:    9040.42 secs [ETA]:   10527.59 secs
[GAN Epoch]:   232 [Cost Time]:    9079.58 secs [ETA]:   10488.48 secs
[GAN Epoch]:   233 [Cost Time]:    9118.95 secs [ETA]:   10449.61 secs
[GAN Epoch]:   234 [Cost Time]:    9158.20 secs [ETA]:   10410.60 secs
[GAN Epoch]:   235 [Cost Time]:    9197.63 secs [ETA]:   10371.79 secs
[GAN Epoch]:   236 [Cost Time]:    9236.84 secs [ETA]:   10332.73 secs
[GAN Epoch]:   237 [Cost Time]:    9275.89 secs [ETA]:   10293.50 secs
[GAN Epoch]:   238 [Cost Time]:    9315.26 secs [ETA]:   10254.61 secs
[GAN Epoch]:   239 [Cost Time]:    9354.75 secs [ETA]:   10215.86 secs
[GAN Epoch]:   240 [Cost Time]:    9394.38 secs [ETA]:   10177.24 secs
[GAN Epoch]:   241 [Cost Time]:    9433.67 secs [ETA]:   10138.26 secs
[GAN Epoch]:   242 [Cost Time]:    9473.16 secs [ETA]:   10099.48 secs
[GAN Epoch]:   243 [Cost Time]:    9512.34 secs [ETA]:   10060.38 secs
[GAN E

[GAN Epoch]:   345 [Cost Time]:   13524.62 secs [ETA]:    6076.28 secs
[GAN Epoch]:   346 [Cost Time]:   13563.87 secs [ETA]:    6037.10 secs
[GAN Epoch]:   347 [Cost Time]:   13603.35 secs [ETA]:    5998.02 secs
[GAN Epoch]:   348 [Cost Time]:   13643.05 secs [ETA]:    5959.04 secs
[GAN Epoch]:   349 [Cost Time]:   13682.36 secs [ETA]:    5919.87 secs
[GAN Epoch]:   350 [Cost Time]:   13721.75 secs [ETA]:    5880.75 secs
[GAN Epoch]:   351 [Cost Time]:   13760.94 secs [ETA]:    5841.54 secs
[GAN Epoch]:   352 [Cost Time]:   13800.63 secs [ETA]:    5802.54 secs
[GAN Epoch]:   353 [Cost Time]:   13839.96 secs [ETA]:    5763.38 secs
[GAN Epoch]:   354 [Cost Time]:   13879.17 secs [ETA]:    5724.18 secs
[GAN Epoch]:   355 [Cost Time]:   13919.00 secs [ETA]:    5685.22 secs
[GAN Epoch]:   356 [Cost Time]:   13958.49 secs [ETA]:    5646.13 secs
[GAN Epoch]:   357 [Cost Time]:   13997.67 secs [ETA]:    5606.91 secs
[GAN Epoch]:   358 [Cost Time]:   14037.66 secs [ETA]:    5568.01 secs
[GAN E

[GAN Epoch]:   460 [Cost Time]:   18076.86 secs [ETA]:    1571.90 secs
[GAN Epoch]:   461 [Cost Time]:   18116.62 secs [ETA]:    1532.64 secs
[GAN Epoch]:   462 [Cost Time]:   18156.25 secs [ETA]:    1493.37 secs
[GAN Epoch]:   463 [Cost Time]:   18196.22 secs [ETA]:    1454.13 secs
[GAN Epoch]:   464 [Cost Time]:   18235.93 secs [ETA]:    1414.86 secs
[GAN Epoch]:   465 [Cost Time]:   18275.45 secs [ETA]:    1375.57 secs
[GAN Epoch]:   466 [Cost Time]:   18314.91 secs [ETA]:    1336.28 secs
[GAN Epoch]:   467 [Cost Time]:   18355.35 secs [ETA]:    1297.06 secs
[GAN Epoch]:   468 [Cost Time]:   18394.93 secs [ETA]:    1257.77 secs
[GAN Epoch]:   469 [Cost Time]:   18434.49 secs [ETA]:    1218.48 secs
[GAN Epoch]:   470 [Cost Time]:   18474.24 secs [ETA]:    1179.21 secs
[GAN Epoch]:   471 [Cost Time]:   18514.62 secs [ETA]:    1139.97 secs
[GAN Epoch]:   472 [Cost Time]:   18554.29 secs [ETA]:    1100.68 secs
[GAN Epoch]:   473 [Cost Time]:   18594.19 secs [ETA]:    1061.40 secs
[GAN E