In [None]:
import os
import pickle
import random
import time

import numpy as np
import tensorflow as tf

from dataloader import Gen_Data_loader, Dis_dataloader
from discriminator import Discriminator
from generator import Generator
from rollout import ROLLOUT
#from target_lstm import TARGET_LSTM
#import domain_translate

In [None]:
#########################################################################################
#  Generator hyper-parameters
#########################################################################################
SEED = 88          # global random seed (fed to random / numpy at setup)
BATCH_SIZE = 64    # sequences per generator batch

# Network sizes
EMB_DIM = 32       # token embedding dimension
HIDDEN_DIM = 32    # hidden-state size of the LSTM cell
SEQ_LENGTH = 32    # sequence length

START_TOKEN = 0    # start-token id passed to the Generator
# Supervised (maximum-likelihood estimation) pre-training epochs; original default is 120.
PRE_EPOCH_NUM = 120

In [None]:
#########################################################################################
#  Discriminator hyper-parameters
#########################################################################################
dis_embedding_dim = 64
# Filter sizes and the matching number of filters for each size, passed to
# Discriminator as filter_sizes / num_filters. They are parallel lists
# (presumably zipped inside discriminator.py - confirm), so check they line up
# here rather than letting a mismatch surface deep inside model construction.
dis_filter_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20]
dis_num_filters = [100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160]
assert len(dis_filter_sizes) == len(dis_num_filters), \
    "dis_filter_sizes and dis_num_filters must have the same length"
dis_dropout_keep_prob = 0.75  # fed to discriminator.dropout_keep_prob on every training step
dis_l2_reg_lambda = 0.2       # passed to Discriminator as l2_reg_lambda
dis_batch_size = 64
PRETRAIN_DIS_NUM = 50         # number of discriminator pre-training rounds

In [None]:
#########################################################################################
#  Basic training parameters
#########################################################################################
# Number of adversarial (GAN) training iterations.
TOTAL_BATCH = 200

# Real (positive) training data; presumably domain names already encoded as
# token ids - confirm against the dataset. (Original default: 'save/real_data.txt'.)
positive_file = ('../Dataset/AlexaTop100K_Separated_Digital/'
                 'TopDomainName.Less33.Separated.Digital-ALL-OF-100000.txt')

# Output directory prefix; per-stage sub-directories and per-epoch file names
# are appended to it. (Originally the single file 'save/generator_sample.txt'.)
output_path = 'save/'

# Number of sequences requested from each generate_samples() call.
generated_num = 10000

In [None]:
def generate_samples(sess, trainable_model, batch_size, generated_num, output_file):
    """Sample sequences from a generator and write them to disk.

    ``trainable_model.generate(sess)`` is called ``int(generated_num / batch_size)``
    times and every returned sequence is collected, so a remainder of
    ``generated_num`` smaller than ``batch_size`` is not sampled.

    The samples are always written to ``output_file`` as one line of
    space-separated token ids per sequence. If ``output_file`` contains
    ``"adversarial_gen"`` or ``"pretrain_gen"``, the same samples are also
    decoded to characters and written to a sibling file in the same directory,
    named by replacing the ``generator_digit_`` prefix of the file name with
    ``generator_fake_domain_name_`` (e.g.
    ``save/pretrain_gen/generator_digit_07.txt`` ->
    ``save/pretrain_gen/generator_fake_domain_name_07.txt``).

    Args:
        sess: session forwarded to ``trainable_model.generate``.
        trainable_model: object whose ``generate(sess)`` returns an iterable of
            token-id sequences.
        batch_size: number of sequences one ``generate`` call is expected to
            return; only used to compute the number of calls.
        generated_num: requested total number of sequences.
        output_file: path of the token-id output file.
    """
    # Token id -> character. 40 entries (same as vocab_size); id 0 decodes to a space.
    digital_table = [" ", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", ".", "-",
                     "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m",
                     "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "_"]

    generated_samples = []
    for _ in range(int(generated_num / batch_size)):
        generated_samples.extend(trainable_model.generate(sess))

    with open(output_file, 'w') as fout:
        for seq in generated_samples:
            fout.write(' '.join(str(tok) for tok in seq) + '\n')

    # Only generator snapshots get a decoded, human-readable copy; other outputs
    # (e.g. the discriminator's negative-sample file) stay as token ids only.
    if "adversarial_gen" not in output_file and "pretrain_gen" not in output_file:
        return

    # Build the text file's path from the token file's own directory and name
    # instead of slicing at fixed character offsets into a hard-coded "save/"
    # path, so it stays correct when output_path changes.
    directory, filename = os.path.split(output_file)
    prefix = "generator_digit_"
    suffix = filename[len(prefix):] if filename.startswith(prefix) else filename
    text_file = os.path.join(directory, "generator_fake_domain_name_" + suffix)
    with open(text_file, 'w') as fout:
        for seq in generated_samples:
            fout.write("".join(digital_table[int(tok)] for tok in seq) + '\n')

In [None]:
def target_loss(sess, target_lstm, data_loader):
    """Compute the oracle negative log-likelihood (NLL) of a dataset.

    Every batch of ``data_loader`` is fed to the oracle model ``target_lstm``
    through its ``x`` placeholder, and the resulting ``pretrain_loss`` values
    are averaged. See Section 4 of https://arxiv.org/abs/1609.05473 for how
    this metric is used.

    Args:
        sess: session used to evaluate ``target_lstm.pretrain_loss``.
        target_lstm: oracle model exposing ``x`` and ``pretrain_loss``.
        data_loader: loader exposing ``reset_pointer()``, ``num_batch`` and
            ``next_batch()``.

    Returns:
        Mean of the per-batch losses.
    """
    data_loader.reset_pointer()
    batch_losses = [
        sess.run(target_lstm.pretrain_loss, {target_lstm.x: data_loader.next_batch()})
        for _ in range(data_loader.num_batch)
    ]
    return np.mean(batch_losses)

In [None]:
def pre_train_epoch(sess, trainable_model, data_loader):
    """Run one epoch of supervised (maximum-likelihood) pre-training on the generator.

    Iterates once over every batch of ``data_loader``, calling
    ``trainable_model.pretrain_step(sess, batch)`` on each, and prints a
    carriage-return progress line roughly 30 times per epoch.

    Args:
        sess: session forwarded to ``pretrain_step``.
        trainable_model: object whose ``pretrain_step(sess, batch)`` returns
            ``(update_result, loss)``.
        data_loader: loader exposing ``reset_pointer()``, ``num_batch`` and
            ``next_batch()``.

    Returns:
        Mean of the per-batch losses (``nan`` if the loader has no batches).
    """
    supervised_g_losses = []
    data_loader.reset_pointer()
    num_batch = data_loader.num_batch
    # Report about 30 times per epoch. Clamp to at least 1: with fewer than 30
    # batches int(num_batch / 30) is 0 and `it % 0` raises ZeroDivisionError.
    report_every = max(1, int(num_batch / 30))

    for it in range(num_batch):
        if it % report_every == 0:
            print("pre_train_iteration : {0:6} / {1:6}".format((it + 1), num_batch), end="\r")
        batch = data_loader.next_batch()
        _, g_loss = trainable_model.pretrain_step(sess, batch)
        supervised_g_losses.append(g_loss)

    return np.mean(supervised_g_losses)

In [None]:
# Seed every RNG the run touches. tf.set_random_seed sets the graph-level seed,
# which random ops created after this call (e.g. weight initialisers in
# Generator / Discriminator) use; without it SEED did not make runs repeatable.
random.seed(SEED)
np.random.seed(SEED)
tf.set_random_seed(SEED)
assert START_TOKEN == 0

gen_data_loader = Gen_Data_loader(BATCH_SIZE)
likelihood_data_loader = Gen_Data_loader(BATCH_SIZE)  # for testing (oracle NLL evaluation)
vocab_size = 40  # size of the character table in generate_samples (original default: 5000)
dis_data_loader = Dis_dataloader(dis_batch_size)

generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)
# Oracle model, disabled together with the TARGET_LSTM import:
#target_params = pickle.load(open('save/target_params_py3.pkl','rb'))
#target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN, target_params) # The oracle model

# Use SEQ_LENGTH rather than a literal so the discriminator always matches the generator.
discriminator = Discriminator(sequence_length=SEQ_LENGTH, num_classes=2, vocab_size=vocab_size,
                              embedding_size=dis_embedding_dim, filter_sizes=dis_filter_sizes,
                              num_filters=dis_num_filters, l2_reg_lambda=dis_l2_reg_lambda)

In [None]:
# Session configuration: with allow_growth, only a small amount of GPU memory is
# allocated at first and it is then increased gradually as needed.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
sess = tf.Session(config=config)
# Initialise every variable created so far (e.g. by Generator and Discriminator above).
sess.run(tf.global_variables_initializer())

In [None]:
# The oracle-sampling step is disabled: the positive examples come from the real
# dataset in positive_file instead of being sampled from an oracle model.
#generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, positive_file)
gen_data_loader.create_batches(positive_file)

# generate_samples() writes into <output_path>pretrain_gen/ and the experiment
# log lives in output_path; create both up front so open() does not fail.
os.makedirs(os.path.join(output_path, "pretrain_gen"), exist_ok=True)
log = open(os.path.join(output_path, 'experiment-log.txt'), 'w')

#  pre-train generator
print('Start pre-training...')
log.write('pre-training...\n')
# Wall-clock timer: time.clock() was removed in Python 3.8 and on Unix it
# measured CPU time only, which misreports GPU-bound training time.
now_time = time.perf_counter()
for epoch in range(PRE_EPOCH_NUM):
    loss = pre_train_epoch(sess, generator, gen_data_loader)

    # Snapshot this epoch's samples as pretrain_gen/generator_digit_<NN>.txt.
    eval_file = output_path + "pretrain_gen/generator_digit_{0}.txt".format(str(epoch + 1).zfill(2))
    generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
    # Only consumed by the oracle NLL evaluation (target_loss), which is
    # disabled because TARGET_LSTM is not imported.
    likelihood_data_loader.create_batches(eval_file)

    # Record the per-epoch MLE loss returned by pre_train_epoch; flush so the
    # log survives an interrupted run.
    log.write('epoch:\t{0}\tmle_loss:\t{1}\n'.format(epoch, loss))
    log.flush()

    after_time = time.perf_counter() - now_time
    eta_time = (after_time / (epoch + 1)) * (PRE_EPOCH_NUM - (epoch + 1))
    print("[MLE Epoch]: {0:5} [Loss]: {1:8.4f} [Cost Time]: {2:10.2f} secs [ETA]: {3:10.2f} secs".format(
        (epoch + 1), loss, after_time, eta_time))

In [None]:
print('Start pre-training discriminator...')
# PRETRAIN_DIS_NUM rounds. Each round writes a fresh set of generator samples,
# loads it together with positive_file, and trains the discriminator on that
# data for DIS_PRETRAIN_EPOCHS epochs.
DIS_PRETRAIN_EPOCHS = 3
os.makedirs(output_path, exist_ok=True)
pretrain_D = output_path + "pretrain_discriminator.txt"  # overwritten every round
# Wall-clock timer (time.clock() was removed in Python 3.8 and measured CPU time on Unix).
now_time = time.perf_counter()
for pre_D_epoch in range(PRETRAIN_DIS_NUM):
    generate_samples(sess, generator, BATCH_SIZE, generated_num, pretrain_D)
    dis_data_loader.load_train_data(positive_file, pretrain_D)
    # Report about 30 times per epoch; clamp to 1 so fewer than 30 batches does
    # not turn into a modulo by zero.
    report_every = max(1, int(dis_data_loader.num_batch / 30))
    for _ in range(DIS_PRETRAIN_EPOCHS):
        dis_data_loader.reset_pointer()
        for it in range(dis_data_loader.num_batch):
            if it % report_every == 0:
                print("pre_train_iteration : {0:6} / {1:6}".format((it + 1), dis_data_loader.num_batch), end="\r")
            x_batch, y_batch = dis_data_loader.next_batch()
            feed = {
                discriminator.input_x: x_batch,
                discriminator.input_y: y_batch,
                discriminator.dropout_keep_prob: dis_dropout_keep_prob,
            }
            sess.run(discriminator.train_op, feed)
    after_time = time.perf_counter() - now_time
    eta_time = (after_time / (pre_D_epoch + 1)) * (PRETRAIN_DIS_NUM - (pre_D_epoch + 1))
    print("[Pre-train Discriminator Epoch]: {0:5} [Cost Time]: {1:10.2f} secs [ETA]: {2:10.2f} secs".format(
        (pre_D_epoch + 1), after_time, eta_time))

# 0.8 is ROLLOUT's second argument - presumably the rate at which its copy of
# the generator parameters is updated; confirm in rollout.py.
rollout = ROLLOUT(generator, 0.8)

In [None]:
print('#########################################################################')
print('Start Adversarial Training...')
log.write('adversarial training...\n')

ROLLOUT_NUM = 16          # third argument to rollout.get_reward (presumably rollouts per reward estimate)
DIS_ROUNDS_PER_BATCH = 5  # fresh negative-sample sets per adversarial step (default 5)
DIS_EPOCHS_PER_ROUND = 3  # discriminator epochs over each sample set (default 3)

# generate_samples() writes the per-step snapshots into <output_path>adversarial_gen/.
os.makedirs(os.path.join(output_path, "adversarial_gen"), exist_ok=True)
# Wall-clock timer (time.clock() was removed in Python 3.8 and measured CPU time on Unix).
now_time = time.perf_counter()
for total_batch in range(TOTAL_BATCH):
    # Train the generator for one step on rewards computed via rollout + discriminator.
    samples = generator.generate(sess)
    rewards = rollout.get_reward(sess, samples, ROLLOUT_NUM, discriminator)
    feed = {generator.x: samples, generator.rewards: rewards}
    sess.run(generator.g_updates, feed_dict=feed)

    # Periodic oracle NLL evaluation (target_loss) is disabled because
    # TARGET_LSTM is not imported.

    # Update roll-out parameters
    rollout.update_params()

    # Train the discriminator. The snapshot name depends only on total_batch,
    # so every round overwrites it and the last round's samples remain on disk.
    adversarial_D = output_path + "adversarial_gen/generator_digit_{0}.txt".format(str(total_batch + 1).zfill(3))
    for _ in range(DIS_ROUNDS_PER_BATCH):
        generate_samples(sess, generator, BATCH_SIZE, generated_num, adversarial_D)
        dis_data_loader.load_train_data(positive_file, adversarial_D)
        # Clamp to 1 so fewer than 30 batches does not turn into a modulo by zero.
        report_every = max(1, int(dis_data_loader.num_batch / 30))

        for _ in range(DIS_EPOCHS_PER_ROUND):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                if it % report_every == 0:
                    print("Discriminator Epoch {0:5} iteration: {1:6} / {2:6}".format(
                        (total_batch + 1), (it + 1), dis_data_loader.num_batch), end="\r")
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob,
                }
                sess.run(discriminator.train_op, feed)

    after_time = time.perf_counter() - now_time
    eta_time = (after_time / (total_batch + 1)) * (TOTAL_BATCH - (total_batch + 1))
    print("[GAN Epoch]: {0:5} [Cost Time]: {1:10.2f} secs [ETA]: {2:10.2f} secs".format(
        (total_batch + 1), after_time, eta_time))

log.close()