In [1]:
import os
import pickle
import random
import time

import numpy as np
import tensorflow as tf

from dataloader import Gen_Data_loader, Dis_dataloader
from discriminator import Discriminator
from generator import Generator
from rollout import ROLLOUT
#from target_lstm import TARGET_LSTM
#import domain_translate

In [2]:
#########################################################################################
#  Generator  Hyper-parameters
######################################################################################
EMB_DIM       = 32   # embedding dimension
HIDDEN_DIM    = 32   # hidden-state size of the LSTM cell
SEQ_LENGTH    = 32   # length of each token sequence
START_TOKEN   = 0    # start token id handed to the Generator
PRE_EPOCH_NUM = 50   # supervised (maximum-likelihood) pre-training epochs (default was 120)
SEED          = 88   # random seed for reproducibility
BATCH_SIZE    = 64   # generator batch size

In [3]:
#########################################################################################
#  Discriminator  Hyper-parameters
#########################################################################################
dis_embedding_dim = 64
# Convolution filter widths and the number of filters for each width
# (the two lists are paired element-wise, so they must have equal length).
dis_filter_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20]
dis_num_filters  = [100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160]
dis_dropout_keep_prob = 0.75   # keep probability fed to the discriminator during training
dis_l2_reg_lambda     = 0.2    # L2 regularization strength
dis_batch_size        = 64
PRETRAIN_DIS_NUM      = 50     # number of discriminator pre-training rounds

In [4]:
#########################################################################################
#  Basic Training Parameters
#########################################################################################
TOTAL_BATCH = 144  # number of adversarial training epochs
# Real (positive) training data; the default was 'save/real_data.txt'.
positive_file = ('../Dataset/AlexaTop100K_Separated_Digital/'
                 'TopDomainName.Less33.Separated.Digital-ALL-OF-100000.txt')
# Path prefix for generated (negative) sample files. The default was
# 'save/generator_sample.txt'; it is now a directory prefix so that each
# stage writes to its own, varying file name.
negative_file = 'save/'
eval_file = 'save/eval_file.txt'
# Requested samples per generate_samples call (rounded down to a multiple of the batch size there).
generated_num = 10000

In [5]:
def generate_samples(sess, trainable_model, batch_size, generated_num, output_file):
    """Sample sequences from ``trainable_model`` and write them to ``output_file``.

    ``trainable_model.generate(sess)`` is called ``generated_num // batch_size``
    times, and every returned sequence is written as one line of space-separated
    token ids.

    If ``output_file`` contains ``"adversarial_gen"``, the same samples are also
    decoded into readable domain-name characters. They are written to
    ``generator_fake_domain_name_<suffix>`` in the same directory as
    ``output_file``, where ``<suffix>`` is the file name with a leading
    ``generator_digit_`` removed (e.g. ``generator_digit_001.txt`` ->
    ``generator_fake_domain_name_001.txt``).

    Args:
        sess: session passed straight through to ``trainable_model.generate``.
        trainable_model: object exposing ``generate(sess)``, which returns an
            iterable of token-id sequences.
        batch_size: used only to decide how many ``generate`` calls to make;
            assumed to match the model's own batch size (TODO confirm).
        generated_num: requested sample count, rounded down to a multiple of
            ``batch_size``.
        output_file: path of the token-id output file (overwritten).
    """
    # Token id -> character: index 0 is ' ', 1-10 are '0'-'9', 11 is '.',
    # 12 is '-', 13-38 are 'a'-'z' and 39 is '_' (40 entries = vocab size).
    digital_table = " 0123456789.-abcdefghijklmnopqrstuvwxyz_"

    generated_samples = []
    for _ in range(generated_num // batch_size):
        generated_samples.extend(trainable_model.generate(sess))

    with open(output_file, 'w') as fout:
        for sample in generated_samples:
            fout.write(' '.join(str(x) for x in sample) + '\n')

    if "adversarial_gen" not in output_file:
        return

    # Derive the readable file name from the digit file name instead of slicing
    # at a fixed character offset, which only worked for one exact path prefix.
    directory, file_name = os.path.split(output_file)
    prefix = "generator_digit_"
    suffix = file_name[len(prefix):] if file_name.startswith(prefix) else file_name
    readable_file = os.path.join(directory, "generator_fake_domain_name_" + suffix)

    with open(readable_file, 'w') as fout:
        for sample in generated_samples:
            fout.write("".join(digital_table[int(x)] for x in sample) + '\n')

In [6]:
def target_loss(sess, target_lstm, data_loader):
    """Return the mean oracle loss over every batch produced by ``data_loader``.

    This is the oracle negative log-likelihood: each batch is scored by the
    oracle model ``target_lstm`` through its ``pretrain_loss`` tensor. See
    Section 4 of https://arxiv.org/abs/1609.05473 for details.

    Args:
        sess: session used to evaluate ``target_lstm.pretrain_loss``.
        target_lstm: oracle model exposing ``pretrain_loss`` and the input
            placeholder ``x``.
        data_loader: loader exposing ``reset_pointer()``, ``num_batch`` and
            ``next_batch()``; it is rewound before iterating.

    Returns:
        ``np.mean`` of the per-batch loss values.
    """
    data_loader.reset_pointer()
    batch_losses = [
        sess.run(target_lstm.pretrain_loss, {target_lstm.x: data_loader.next_batch()})
        for _ in range(data_loader.num_batch)
    ]
    return np.mean(batch_losses)

In [7]:
def pre_train_epoch(sess, trainable_model, data_loader):
    """Run one epoch of supervised (MLE) pre-training of the generator.

    The loader is rewound, then ``trainable_model.pretrain_step(sess, batch)``
    is called once per batch. A progress line is printed about 30 times per
    epoch, overwriting itself with a carriage return.

    Args:
        sess: session passed to ``pretrain_step``.
        trainable_model: model whose ``pretrain_step(sess, batch)`` returns a
            pair whose second element is the batch loss.
        data_loader: loader exposing ``reset_pointer()``, ``num_batch`` and
            ``next_batch()``.

    Returns:
        ``np.mean`` of the per-batch losses for this epoch.
    """
    supervised_g_losses = []
    data_loader.reset_pointer()
    num_batch = data_loader.num_batch

    # max(1, ...) prevents a modulo-by-zero when there are fewer than 30 batches.
    report_every = max(1, num_batch // 30)

    for it in range(num_batch):
        if it % report_every == 0:
            print("pre_train_iteration : {0:6} / {1:6}".format((it + 1), num_batch), end="\r")
        batch = data_loader.next_batch()
        _, g_loss = trainable_model.pretrain_step(sess, batch)
        supervised_g_losses.append(g_loss)

    return np.mean(supervised_g_losses)

In [8]:
# Seed every RNG in use. The TensorFlow graph-level seed is set before the
# generator/discriminator graphs are built so that TF sampling ops are
# reproducible too, not just Python's and NumPy's RNGs.
random.seed(SEED)
np.random.seed(SEED)
tf.set_random_seed(SEED)
assert START_TOKEN == 0

gen_data_loader = Gen_Data_loader(BATCH_SIZE)
likelihood_data_loader = Gen_Data_loader(BATCH_SIZE)  # For testing
# Matches the 40-character table used by generate_samples (default was 5000).
vocab_size = 40
# Use the discriminator's own batch-size setting (same value as BATCH_SIZE today).
dis_data_loader = Dis_dataloader(dis_batch_size)

generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)
# Oracle model (disabled: real data is used instead of oracle samples).
#target_params = pickle.load(open('save/target_params_py3.pkl','rb'))
#target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN, target_params) # The oracle model

# Tie the discriminator's sequence length to SEQ_LENGTH instead of a duplicated literal 32.
discriminator = Discriminator(sequence_length=SEQ_LENGTH, num_classes=2, vocab_size=vocab_size,
                              embedding_size=dis_embedding_dim, filter_sizes=dis_filter_sizes,
                              num_filters=dis_num_filters, l2_reg_lambda=dis_l2_reg_lambda)

In [9]:
# 參數配置
# Session configuration.
# allow_growth: start by allocating a small amount of GPU memory, then grow
# the allocation gradually as it is needed.
gpu_options = tf.GPUOptions(allow_growth=True)
config = tf.ConfigProto(gpu_options=gpu_options)
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())

In [10]:
# First, use the oracle model to provide the positive examples, which are sampled from the oracle data distribution
#generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, positive_file)
# Positive (real) examples are read from positive_file. In the synthetic-data
# setup they were sampled from the oracle model instead:
#generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, positive_file)
gen_data_loader.create_batches(positive_file)

log = open('save/experiment-log.txt', 'w')

# Pre-train the generator with maximum-likelihood estimation.
print('Start pre-training...')
log.write('pre-training...\n')
# perf_counter measures elapsed wall-clock time. time.clock() was removed in
# Python 3.8, and on Unix it measured CPU time rather than elapsed time.
now_time = time.perf_counter()
for epoch in range(PRE_EPOCH_NUM):
    loss = pre_train_epoch(sess, generator, gen_data_loader)
    if epoch % 5 == 0:
        # Periodically dump generator samples to eval_file. The oracle NLL
        # evaluation (target_loss) is disabled because no target_lstm is loaded.
        generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
        likelihood_data_loader.create_batches(eval_file)
    # Record the MLE loss, which was previously computed and discarded.
    log.write('epoch:\t' + str(epoch) + '\tmle_loss:\t' + str(loss) + '\n')
    log.flush()  # keep the log readable while the (long) run is in progress
    after_time = time.perf_counter() - now_time
    eta_time = (after_time / (epoch + 1)) * (PRE_EPOCH_NUM - (epoch + 1))
    print("[MLE Epoch]: {0:5} [Loss]: {3:8.4f} [Cost Time]: {1:10.2f} secs [ETA]: {2:10.2f} secs".format(
        (epoch + 1), after_time, eta_time, loss))

Start pre-training...
[MLE Epoch]:     1 [Cost Time]:      76.43 secs [ETA]:    3745.11 secs
[MLE Epoch]:     2 [Cost Time]:     145.52 secs [ETA]:    3492.39 secs
[MLE Epoch]:     3 [Cost Time]:     214.72 secs [ETA]:    3363.91 secs
[MLE Epoch]:     4 [Cost Time]:     284.24 secs [ETA]:    3268.75 secs
[MLE Epoch]:     5 [Cost Time]:     357.68 secs [ETA]:    3219.08 secs
[MLE Epoch]:     6 [Cost Time]:     433.32 secs [ETA]:    3177.70 secs
[MLE Epoch]:     7 [Cost Time]:     502.47 secs [ETA]:    3086.63 secs
[MLE Epoch]:     8 [Cost Time]:     573.53 secs [ETA]:    3011.03 secs
[MLE Epoch]:     9 [Cost Time]:     646.54 secs [ETA]:    2945.36 secs
[MLE Epoch]:    10 [Cost Time]:     717.55 secs [ETA]:    2870.22 secs
[MLE Epoch]:    11 [Cost Time]:     795.32 secs [ETA]:    2819.78 secs
[MLE Epoch]:    12 [Cost Time]:     868.89 secs [ETA]:    2751.50 secs
[MLE Epoch]:    13 [Cost Time]:     938.43 secs [ETA]:    2670.92 secs
[MLE Epoch]:    14 [Cost Time]:    1009.84 secs [ETA]: 

In [11]:
print('Start pre-training discriminator...')
# Each of the PRETRAIN_DIS_NUM rounds samples fresh negatives from the
# generator, then trains the discriminator on real vs. generated data for
# DIS_EPOCHS_PER_ROUND epoch(s).
DIS_EPOCHS_PER_ROUND = 1
pretrain_D = negative_file + "pretrain_discriminator.txt"  # loop-invariant, so built once
# perf_counter measures elapsed wall-clock time (time.clock() was removed in Python 3.8).
now_time = time.perf_counter()
for pre_D_epoch in range(PRETRAIN_DIS_NUM):
    generate_samples(sess, generator, BATCH_SIZE, generated_num, pretrain_D)
    dis_data_loader.load_train_data(positive_file, pretrain_D)
    for _ in range(DIS_EPOCHS_PER_ROUND):
        dis_data_loader.reset_pointer()
        # max(1, ...) prevents a modulo-by-zero when there are fewer than 30 batches.
        report_every = max(1, dis_data_loader.num_batch // 30)
        for it in range(dis_data_loader.num_batch):
            if it % report_every == 0:
                print("pre_train_iteration : {0:6} / {1:6}".format((it + 1), (dis_data_loader.num_batch)), end="\r")
            x_batch, y_batch = dis_data_loader.next_batch()
            feed = {
                discriminator.input_x: x_batch,
                discriminator.input_y: y_batch,
                discriminator.dropout_keep_prob: dis_dropout_keep_prob
            }
            sess.run(discriminator.train_op, feed)
    after_time = time.perf_counter() - now_time
    eta_time = (after_time / (pre_D_epoch + 1)) * (PRETRAIN_DIS_NUM - (pre_D_epoch + 1))
    print("[Pre-train Discriminator Epoch]: {0:5} [Cost Time]: {1:10.2f} secs [ETA]: {2:10.2f} secs".format(
        (pre_D_epoch + 1), after_time, eta_time))

# Rollout helper used by adversarial training to compute rewards for the generator.
# NOTE(review): 0.8 is ROLLOUT's second argument (presumably a parameter-update
# rate used by rollout.update_params) -- confirm in rollout.py.
rollout = ROLLOUT(generator, 0.8)

Start pre-training discriminator...
[Pre-train Discriminator Epoch]:     1 [Cost Time]:      35.48 secs [ETA]:    1738.46 secs
[Pre-train Discriminator Epoch]:     2 [Cost Time]:      66.87 secs [ETA]:    1604.92 secs
[Pre-train Discriminator Epoch]:     3 [Cost Time]:      97.67 secs [ETA]:    1530.22 secs
[Pre-train Discriminator Epoch]:     4 [Cost Time]:     128.55 secs [ETA]:    1478.35 secs
[Pre-train Discriminator Epoch]:     5 [Cost Time]:     159.27 secs [ETA]:    1433.39 secs
[Pre-train Discriminator Epoch]:     6 [Cost Time]:     190.16 secs [ETA]:    1394.48 secs
[Pre-train Discriminator Epoch]:     7 [Cost Time]:     221.35 secs [ETA]:    1359.74 secs
[Pre-train Discriminator Epoch]:     8 [Cost Time]:     252.56 secs [ETA]:    1325.93 secs
[Pre-train Discriminator Epoch]:     9 [Cost Time]:     283.84 secs [ETA]:    1293.04 secs
[Pre-train Discriminator Epoch]:    10 [Cost Time]:     314.95 secs [ETA]:    1259.80 secs
[Pre-train Discriminator Epoch]:    11 [Cost Time]:   

In [12]:
print('#########################################################################')
print('Start Adversarial Training...')
log.write('adversarial training...\n')

# Named knobs for the magic numbers previously inlined in the loops.
# NOTE(review): ROLLOUT_NUM is get_reward's third argument, presumably the
# number of rollouts per reward estimate -- confirm in rollout.py.
ROLLOUT_NUM = 16
G_STEPS = 1            # generator updates per adversarial epoch
D_STEPS = 3            # fresh negative samplings per epoch (default value: 5)
D_EPOCHS_PER_STEP = 1  # discriminator epochs per sampling (default value: 3)

# generate_samples writes into this sub-directory; make sure it exists.
os.makedirs(os.path.join(negative_file, "adversarial_gen"), exist_ok=True)

# perf_counter measures elapsed wall-clock time (time.clock() was removed in Python 3.8).
now_time = time.perf_counter()
for total_batch in range(TOTAL_BATCH):
    # Train the generator using rewards from the rollout/discriminator.
    for _ in range(G_STEPS):
        samples = generator.generate(sess)
        rewards = rollout.get_reward(sess, samples, ROLLOUT_NUM, discriminator)
        feed = {generator.x: samples, generator.rewards: rewards}
        sess.run(generator.g_updates, feed_dict=feed)

    # The oracle NLL test (target_loss) is disabled: no target_lstm oracle is
    # loaded when training on real data.

    # Update roll-out parameters
    rollout.update_params()

    # Train the discriminator. The per-epoch sample file name does not depend
    # on the inner loops, so it is built once per epoch.
    adversarial_D = (negative_file + "adversarial_gen/generator_digit_{0}.txt").format(str(total_batch + 1).zfill(3))
    for _ in range(D_STEPS):
        generate_samples(sess, generator, BATCH_SIZE, generated_num, adversarial_D)
        dis_data_loader.load_train_data(positive_file, adversarial_D)

        for _ in range(D_EPOCHS_PER_STEP):
            dis_data_loader.reset_pointer()
            # max(1, ...) prevents a modulo-by-zero when there are fewer than 30 batches.
            report_every = max(1, dis_data_loader.num_batch // 30)
            for it in range(dis_data_loader.num_batch):
                if it % report_every == 0:
                    print("Discriminator Epoch {0:5} iteration: {1:6} / {2:6}".format((total_batch + 1), (it + 1), (dis_data_loader.num_batch)), end="\r")
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                }
                sess.run(discriminator.train_op, feed)

    after_time = time.perf_counter() - now_time
    eta_time = (after_time / (total_batch + 1)) * (TOTAL_BATCH - (total_batch + 1))
    print("[GAN Epoch]: {0:5} [Cost Time]: {1:10.2f} secs [ETA]: {2:10.2f} secs".format((total_batch + 1), after_time, eta_time))

log.close()

#########################################################################
Start Adversarial Training...
[GAN Epoch]:     1 [Cost Time]:     104.23 secs [ETA]:   14905.39 secs
[GAN Epoch]:     2 [Cost Time]:     209.35 secs [ETA]:   14864.03 secs
[GAN Epoch]:     3 [Cost Time]:     313.31 secs [ETA]:   14725.52 secs
[GAN Epoch]:     4 [Cost Time]:     416.63 secs [ETA]:   14581.90 secs
[GAN Epoch]:     5 [Cost Time]:     520.37 secs [ETA]:   14466.38 secs
[GAN Epoch]:     6 [Cost Time]:     624.98 secs [ETA]:   14374.56 secs
[GAN Epoch]:     7 [Cost Time]:     730.34 secs [ETA]:   14293.88 secs
[GAN Epoch]:     8 [Cost Time]:     834.60 secs [ETA]:   14188.24 secs
[GAN Epoch]:     9 [Cost Time]:     940.18 secs [ETA]:   14102.69 secs
[GAN Epoch]:    10 [Cost Time]:    1043.98 secs [ETA]:   13989.37 secs
[GAN Epoch]:    11 [Cost Time]:    1147.53 secs [ETA]:   13874.73 secs
[GAN Epoch]:    12 [Cost Time]:    1253.40 secs [ETA]:   13787.36 secs
[GAN Epoch]:    13 [Cost Time]:    1356.14 s

[GAN Epoch]:   115 [Cost Time]:   11741.95 secs [ETA]:    2961.01 secs
[GAN Epoch]:   116 [Cost Time]:   11847.13 secs [ETA]:    2859.65 secs
[GAN Epoch]:   117 [Cost Time]:   11950.57 secs [ETA]:    2757.82 secs
[GAN Epoch]:   118 [Cost Time]:   12054.08 secs [ETA]:    2655.98 secs
[GAN Epoch]:   119 [Cost Time]:   12156.93 secs [ETA]:    2553.98 secs
[GAN Epoch]:   120 [Cost Time]:   12259.34 secs [ETA]:    2451.87 secs
[GAN Epoch]:   121 [Cost Time]:   12357.98 secs [ETA]:    2349.04 secs
[GAN Epoch]:   122 [Cost Time]:   12456.31 secs [ETA]:    2246.22 secs
[GAN Epoch]:   123 [Cost Time]:   12554.43 secs [ETA]:    2143.44 secs
[GAN Epoch]:   124 [Cost Time]:   12652.77 secs [ETA]:    2040.77 secs
[GAN Epoch]:   125 [Cost Time]:   12751.06 secs [ETA]:    1938.16 secs
[GAN Epoch]:   126 [Cost Time]:   12849.99 secs [ETA]:    1835.71 secs
[GAN Epoch]:   127 [Cost Time]:   12949.10 secs [ETA]:    1733.34 secs
[GAN Epoch]:   128 [Cost Time]:   13048.16 secs [ETA]:    1631.02 secs
[GAN E