In [2]:
import os, sys
sys.path.append(os.getcwd())

import time

import numpy as np
import tensorflow as tf

import language_helpers
import tflib as lib
import tflib.ops.linear
import tflib.ops.conv1d
import tflib.plot

In [3]:
# Dataset location. The upstream script was written for the Google Billion Word
# corpus (http://www.statmt.org/lm-benchmark/); this notebook points DATA_DIR at
# the Alexa Top-1M data instead. Fill in the path to the extracted files here.
DATA_DIR = 'AlexaTop1M_NoSeparate'
if not DATA_DIR:
    # An empty string means the data directory was never configured.
    raise Exception("Please specify path to data directory in gan_language.py!")


# Number of sequences per training batch.
BATCH_SIZE = 64
# Number of training iterations. The minimum is 1000; increase in steps of 1000.
ITERS = 30000
# Sequence length, in characters.
SEQ_LEN = 32
# Model dimensionality. This is fairly slow and overfits, even on Billion Word.
# Consider decreasing it for smaller datasets.
DIM = 512
# Critic (discriminator) iterations per generator iteration. The paper's results
# use 10, but 5 should work fine as well.
CRITIC_ITERS = 10
# Gradient penalty lambda hyperparameter.
LAMBDA = 10
# Maximum number of data examples to load. If loading is too slow or takes too
# much RAM, decrease this (at the expense of having less training data).
# The upstream default value is 10000000.
MAX_N_EXAMPLES = 100000

In [4]:
# Hand a snapshot of the current namespace to the helper, which prints the
# upper-case settings defined above.
lib.print_model_settings(dict(locals()))

# Load at most MAX_N_EXAMPLES lines. Each line comes back as a tuple of SEQ_LEN
# characters, padded with spaces, e.g.
# ('o', 'm', 'e', 'g', 'a', 't', 'r', 'a', 'v', 'e', 'l', 'e', 'r', '.', 'i', 'n', 'f', 'o', ' ', ...)
# NOTE(review): charmap / inv_charmap are presumably char->id and id->char
# lookups; confirm against language_helpers.load_dataset.
lines, charmap, inv_charmap = language_helpers.load_dataset(
    data_dir=DATA_DIR,
    max_n_examples=MAX_N_EXAMPLES,
    max_length=SEQ_LEN
)

Uppercase local vars:
	BATCH_SIZE: 64
	CRITIC_ITERS: 10
	DATA_DIR: AlexaTop1M_NoSeparate
	DIM: 512
	ITERS: 30000
	LAMBDA: 10
	MAX_N_EXAMPLES: 100000
	SEQ_LEN: 32
loading dataset...
('s', 'h', 'a', 'r', 'e', 'i', 't', 'f', 'o', 'r', 'p', 'c', 'c', '.', 'c', 'o', 'm', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ')
('w', 'o', 'n', 'd', 'e', 'r', 'l', 'a', '.', 'c', 'o', 'm', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ')
('s', 'h', 'a', 'r', 'j', 'a', 'h', '.', 'a', 'c', '.', 'a', 'e', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ')
('g', 'i', 'o', 'c', 'h', 'i', '.', 'i', 't', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ')
('s', 'y', 's', 'h', 'l', '.', 'c', 'o', 'm', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ')
('i', 'p', 'e'

In [5]:
def softmax(logits):
    """Softmax over groups of ``len(charmap)`` logits, preserving the input shape.

    The logits are flattened to rows of ``len(charmap)`` entries, passed through
    ``tf.nn.softmax`` and reshaped back to ``tf.shape(logits)``.
    """
    flat_logits = tf.reshape(logits, [-1, len(charmap)])
    flat_probs = tf.nn.softmax(flat_logits)
    return tf.reshape(flat_probs, tf.shape(logits))

In [6]:
def make_noise(shape):
    """Generate noise: a tensor of the given shape from ``tf.random_normal``."""
    noise = tf.random_normal(shape)
    return noise

In [7]:
def ResBlock(name, inputs):
    """Residual block: two (ReLU -> Conv1D) stages added back onto ``inputs``.

    The two convolutions are registered as ``name + '.1'`` and ``name + '.2'``.
    NOTE(review): the Conv1D arguments ``DIM, DIM, 5`` are presumably
    (input channels, output channels, filter width) -- confirm against tflib.

    Returns ``inputs + 0.3 * branch``, where ``branch`` is the output of the
    second convolution.
    """
    branch = inputs
    for suffix in ('.1', '.2'):
        branch = tf.nn.relu(branch)
        branch = lib.ops.conv1d.Conv1D(name + suffix, DIM, DIM, 5, branch)
    # The residual branch is scaled by 0.3 before the skip connection is added.
    return inputs + (0.3*branch)

In [8]:
def Generator(n_samples, prev_outputs=None):
    """Build the generator graph and return per-position character probabilities.

    Args:
        n_samples: number of sequences to generate (first dimension of the noise).
        prev_outputs: accepted for interface compatibility; not used in this body.

    Returns:
        The output of ``softmax`` applied after transposing the final Conv1D
        result with permutation ``[0, 2, 1]``.
    """
    noise_dim = 128
    hidden = make_noise(shape=[n_samples, noise_dim])
    hidden = lib.ops.linear.Linear('Generator.Input', noise_dim, SEQ_LEN*DIM, hidden)
    hidden = tf.reshape(hidden, [-1, DIM, SEQ_LEN])
    # Five residual blocks, registered as 'Generator.1' ... 'Generator.5'.
    for block_idx in range(1, 6):
        hidden = ResBlock('Generator.{}'.format(block_idx), hidden)
    char_logits = lib.ops.conv1d.Conv1D('Generator.Output', DIM, len(charmap), 1, hidden)
    char_logits = tf.transpose(char_logits, [0, 2, 1])
    return softmax(char_logits)

In [9]:
def Discriminator(inputs):
    """Build the critic graph and return one score per input sequence.

    ``inputs`` is transposed with permutation ``[0, 2, 1]``, projected from
    ``len(charmap)`` to ``DIM`` channels, passed through five residual blocks,
    flattened to ``SEQ_LEN*DIM`` features and mapped to a single output by a
    linear layer.
    """
    hidden = tf.transpose(inputs, [0,2,1])
    hidden = lib.ops.conv1d.Conv1D('Discriminator.Input', len(charmap), DIM, 1, hidden)
    # Five residual blocks, registered as 'Discriminator.1' ... 'Discriminator.5'.
    for block_idx in range(1, 6):
        hidden = ResBlock('Discriminator.{}'.format(block_idx), hidden)
    flat = tf.reshape(hidden, [-1, SEQ_LEN*DIM])
    score = lib.ops.linear.Linear('Discriminator.Output', SEQ_LEN*DIM, 1, flat)
    return score

In [10]:
# Placeholder for a batch of real samples, fed as int32 character ids with
# shape [BATCH_SIZE, SEQ_LEN].
real_inputs_discrete = tf.placeholder(tf.int32, shape=[BATCH_SIZE, SEQ_LEN])
# One-hot encode the real ids over len(charmap) classes.
real_inputs = tf.one_hot(real_inputs_discrete, len(charmap))
# Build a batch of BATCH_SIZE generated (fake) samples.
fake_inputs = Generator(BATCH_SIZE)
# Discrete version of the fake samples: argmax over the last axis.
fake_inputs_discrete = tf.argmax(fake_inputs, fake_inputs.get_shape().ndims-1)

Instructions for updating:
`NCHW` for data_format is deprecated, use `NCW` instead


In [11]:
# Feed both the real and the generated batch through the critic (discriminator).
# NOTE(review): both calls use the same 'Discriminator.*' layer names, so tflib
# presumably shares one set of weights between them -- confirm.
disc_real = Discriminator(real_inputs)
disc_fake = Discriminator(fake_inputs)

In [12]:
# Loss functions for the critic and the generator (WGAN formulation).
# Critic: mean score on fake samples minus mean score on real samples.
# Generator: negated mean critic score on fake samples.
disc_cost = tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)
gen_cost = -tf.reduce_mean(disc_fake)

In [13]:
# WGAN-GP Lipschitz penalty: one uniform random coefficient in [0, 1) per
# sample, shaped [BATCH_SIZE, 1, 1] so it broadcasts over the other two axes.
alpha = tf.random_uniform(shape=[BATCH_SIZE, 1, 1], minval=0., maxval=1.)

In [14]:
# WGAN-GP gradient penalty.
# Interpolate randomly between real and generated samples so the constraint is
# enforced across the space between the two distributions.
differences = fake_inputs - real_inputs
interpolates = real_inputs + (alpha*differences)
# Gradient of the critic output with respect to the interpolated inputs.
gradients = tf.gradients(Discriminator(interpolates), [interpolates])[0]
# Per-sample L2 norm of that gradient, reduced over axes 1 and 2.
# `axis` replaces the deprecated `reduction_indices` alias.
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1,2]))
# Penalise deviation of the gradient norm from 1.
gradient_penalty = tf.reduce_mean((slopes-1.)**2)
# WGAN-GP critic loss = WGAN critic loss + LAMBDA * penalty
# (formula summary: http://www.twistedwg.com/2018/10/05/GAN_loss_summary.html).
# The loss is rebuilt from disc_fake / disc_real instead of using `+=` so that
# re-running this cell cannot stack the penalty onto disc_cost a second time.
disc_cost = (tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)) + LAMBDA*gradient_penalty

In [15]:
# Collect the parameter lists handed to the two optimisers as var_list.
# NOTE(review): params_with_name presumably returns every tflib parameter whose
# name contains the given string -- confirm against tflib.
gen_params = lib.params_with_name('Generator')
disc_params = lib.params_with_name('Discriminator')

In [16]:
# Optimisers for the generator and the critic: one Adam instance each, with
# identical hyperparameters, each restricted to its own network's parameters.
# gen_train_op minimises gen_cost; disc_train_op minimises disc_cost.
gen_train_op = tf.train.AdamOptimizer(
    learning_rate=1e-4, beta1=0.5, beta2=0.999
).minimize(gen_cost, var_list=gen_params)
disc_train_op = tf.train.AdamOptimizer(
    learning_rate=1e-4, beta1=0.5, beta2=0.999
).minimize(disc_cost, var_list=disc_params)

In [17]:
def inf_train_gen():
    """Endless dataset iterator yielding int32 batches of character ids.

    Every pass shuffles the module-level ``lines`` in place, then yields
    consecutive full batches of ``BATCH_SIZE`` lines, each character mapped
    through ``charmap``. A trailing partial batch is skipped.
    """
    while True:
        np.random.shuffle(lines)
        last_start_exclusive = len(lines) - BATCH_SIZE + 1
        for start in range(0, last_start_exclusive, BATCH_SIZE):
            batch_lines = lines[start:start + BATCH_SIZE]
            encoded = [[charmap[ch] for ch in line] for line in batch_lines]
            yield np.array(encoded, dtype='int32')

In [18]:
# During training we monitor the JS divergence between the true and generated
# n-gram distributions for n=1,2,3,4. To get an idea of the optimal values, we
# evaluate these statistics on a held-out set first.

# Two lists of NgramLanguageModel objects for n=1..4: the first
# 10*BATCH_SIZE lines act as the held-out set, the remaining lines as the
# "true" distribution.
true_char_ngram_lms = [
    language_helpers.NgramLanguageModel(n, lines[10*BATCH_SIZE:], tokenize=False)
    for n in (1, 2, 3, 4)
]
validation_char_ngram_lms = [
    language_helpers.NgramLanguageModel(n, lines[:10*BATCH_SIZE], tokenize=False)
    for n in (1, 2, 3, 4)
]

for i, (true_lm, validation_lm) in enumerate(zip(true_char_ngram_lms, validation_char_ngram_lms)):
    print("validation set JSD for n={}: {}".format(i+1, true_lm.js_with(validation_lm)))

# Rebuild the reference models from all lines; these are the ones compared
# against generated samples during training.
true_char_ngram_lms = [
    language_helpers.NgramLanguageModel(n, lines, tokenize=False)
    for n in (1, 2, 3, 4)
]

validation set JSD for n=1: 0.0002892902687692001
validation set JSD for n=2: 0.011284768640386435
validation set JSD for n=3: 0.07993772080584989
validation set JSD for n=4: 0.17677387369452172


In [19]:
# Training history, kept at module level so it stays available after the loop.
sum_disc_cost = []   # critic loss of the last critic step of every iteration
js1 = []             # JS divergence of the 1-gram model at every checkpoint
js2 = []             # ... 2-gram
js3 = []             # ... 3-gram
js4 = []             # ... 4-gram
# js_history[n-1] is the history list for the n-gram model.
js_history = (js1, js2, js3, js4)

# Progress-bar granularity: one '*' every `change_line` iterations and one
# summary line / checkpoint every 10*change_line iterations. Clamped to 1 so
# that ITERS < 1000 does not produce a modulo-by-zero below.
change_line = max(1, ITERS // 1000)
checkpoint_every = 10 * change_line


with tf.Session() as session:
    # Initialise the model parameters.
    session.run(tf.global_variables_initializer())

    def generate_samples():
        """Run the generator once and decode the batch into tuples of characters.

        The generator output is arg-maxed over axis 2 and every resulting index
        is looked up in ``inv_charmap``.
        """
        probs = session.run(fake_inputs)
        indices = np.argmax(probs, axis=2)
        return [tuple(inv_charmap[idx] for idx in row) for row in indices]

    # Endless stream of training batches.
    gen = inf_train_gen()

    sum_time = 0.     # total measured time so far
    line_time = 0.    # value of sum_time at the previous summary line
    loading_str = "*"

    print("[Start]")
    # time.clock() was removed in Python 3.8; perf_counter() is the wall-clock
    # replacement suited to measuring elapsed intervals.
    now_time = time.perf_counter()

    for iteration in range(ITERS):
        # Start of this iteration, used for the 'time' plot below.
        start_time = time.time()

        # Train generator (skipped on the very first iteration).
        if iteration > 0:
            _ = session.run(gen_train_op)

        # Train critic: CRITIC_ITERS steps, each on a fresh batch of real data.
        for i in range(CRITIC_ITERS):
            _data = next(gen)
            _disc_cost, _ = session.run(
                [disc_cost, disc_train_op],
                feed_dict={real_inputs_discrete:_data}
            )

        after_time = time.perf_counter() - now_time
        sum_time += after_time
        # ETA: iterations still to run after this one, times the latest
        # per-iteration duration.
        eta_time = (ITERS - (iteration+1)) * after_time

        # In-place progress line (carriage return, no newline).
        print("[{1:10}] [Iteration]: {0:10} [Unit iteration time    ]: {2:10.2f} secs [ETA]: {3:10.2f} secs".format( (iteration+1), loading_str, after_time, eta_time) , end="\r")
        now_time = time.perf_counter()

        is_checkpoint = iteration % checkpoint_every == (checkpoint_every-1)

        if iteration % change_line == (change_line-1):
            loading_str += "*"
            if is_checkpoint:
                # Summary line: percent done, time for the last block of
                # iterations, and the running total.
                print("{5:5.0f}{0:7} [Iteration]: {1:10} [{2:23}]: {3:10.2f} secs [SUM]: {4:10.2f}".format("% Done!", (iteration+1), (str(checkpoint_every)+"x iterations time"), (sum_time-line_time), sum_time, (100*(iteration+1)/ITERS) ) )
                loading_str = "*"
                line_time = sum_time

        # Record this iteration's duration and critic loss.
        lib.plot.plot('time', time.time() - start_time)
        lib.plot.plot('train disc cost', _disc_cost)
        sum_disc_cost.append(_disc_cost)

        if is_checkpoint:
            # Draw 10 generator batches.
            samples = []
            for i in range(10):
                samples.extend(generate_samples())

            # JS divergence (lower is better,
            # https://cloud.tencent.com/developer/article/1530349) between the
            # n-gram model of the samples and the reference model, n=1..4.
            # Each value is computed once and reused for the plot and the
            # history list.
            for n, history in enumerate(js_history, start=1):
                lm = language_helpers.NgramLanguageModel(n, samples, tokenize=False)
                jsd = lm.js_with(true_char_ngram_lms[n-1])
                lib.plot.plot('js{}'.format(n), jsd)
                history.append(jsd)

            # Make sure the output directory exists before writing into it.
            os.makedirs('output_data', exist_ok=True)
            with open('output_data/samples_{}.txt'.format(str(iteration+1).zfill(7)), 'w',encoding = 'utf8') as f:
                for s in samples:
                    s = "".join(s)
                    # NOTE(review): checkDNSFrom presumably normalises or
                    # validates the generated domain string -- confirm in
                    # language_helpers.
                    s = language_helpers.checkDNSFrom(s)
                    f.write(str(s) + "\n")

            lib.plot.flush()

        lib.plot.tick()


[Start]
    1% Done! [Iteration]:        300 [300x iterations time   ]:     393.32 secs [SUM]:     393.32 secs
iter 299	js3	0.23643992584512577	js4	0.29423759144787914	js1	0.05435322801155194	js2	0.14595305449819268	train disc cost	-3.0827670097351074	time	1.2679182585080464
    2% Done! [Iteration]:        600 [300x iterations time   ]:     396.01 secs [SUM]:     789.33 secs
iter 599	js3	0.2095511027261185	js4	0.2954987840224037	js1	0.04027807389462579	js2	0.11034715383447981	train disc cost	-2.4089765548706055	time	1.2691461451848347
    3% Done! [Iteration]:        900 [300x iterations time   ]:     395.73 secs [SUM]:    1185.06 secs
iter 899	js3	0.19830374120269578	js4	0.286636851346418	js1	0.035703011882826714	js2	0.10012788510265298	train disc cost	-2.4355621337890625	time	1.2686759209632874
    4% Done! [Iteration]:       1200 [300x iterations time   ]:     395.75 secs [SUM]:    1580.81 secs
iter 1199	js3	0.178651595212865	js4	0.2628392835344958	js1	0.03467610615829662	js2	0.088