In [None]:
# encoding=utf-8

from keras.models import Sequential, Model
from keras.layers import *
from keras.optimizers import Adam
from keras.preprocessing.sequence import pad_sequences
from keras.utils import Sequence
from keras.preprocessing.text import Tokenizer
from keras import regularizers
from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, Callback
from keras.initializers import Ones, Zeros, Orthogonal
from myclass import TiedEmbeddingsTransposed, Attention, Position_Embedding, LayerNormalization

import numpy as np
import random
import sys
import os
import json

import keras.backend.tensorflow_backend as KTF
import tensorflow as tf
import keras.backend as K

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
config = tf.ConfigProto()
config.gpu_options.allow_growth=True
sess = tf.Session(config=config)
# Keras only honours this config if the session is registered as its backend session;
# otherwise it lazily creates its own default session and allow_growth is ignored.
KTF.set_session(sess)


# Reserved token ids. The decoding loops slice predictions with [3:] and add 3 back, so ids
# 0-2 (pad/oov/start) are never generated, while end_token (3) can be.
pad_token = 0
oov_token = 1
start_token = 2
end_token = 3

# Path to the conditional training corpus: one pre-tokenized sentence per line, words
# separated by single spaces. Choose the dataset as needed.
train_path = 'your conditional data'

# For the specific hyper-parameters used under each condition, please refer to our paper.
max_len = 17 # sequence length including the added <BOS>/<EOS> tokens
dp = 0.2 # dropout rate; NOTE(review): not referenced by any layer in this notebook
emb_size = 256 # width of the token embedding shared by encoder and decoder
gru_dim = 150 # hidden size of each direction of the encoder's bidirectional GRU
batch_size = 128
latent_dim = 128 # latent size of PretrainVAE = input size of PluginVAE
bottle_dim = 20 # bottleneck vector size of PluginVAE
beta = K.variable(5.0) # target for the PluginVAE KL term (loss uses |KL - beta|); reset every step by the training loop
kl_weight = 1.0 # weight of KL loss
head_num = 8 # attention heads per decoder layer
# Per-head size for each of the 3 decoder layers: head_num * head_size[k] equals the hidden
# width emb_size + (k+1)*latent_dim reached after layer k concatenates its latent projection.
head_size = [(emb_size+latent_dim) // head_num, (emb_size+2*latent_dim) // head_num, (emb_size+3*latent_dim) // head_num]

train = []

# Corpus: one pre-tokenized sentence per line, lower-cased and split on single spaces.
with open(train_path, 'r', encoding='utf-8') as f:
    for line in f:
        line = line.strip().lower().split(' ') # already tokenized
        train.append(line)

print('train corpus size:', sum([len(d) for d in train]))
sys.stdout.flush()
print('sequences:', len(train))
sys.stdout.flush()

# The vocabulary is prebuilt and stored as [chars, id2char, char2id]. Everything below needs
# char2id, so fail early with a clear message instead of a NameError further down.
vocab_path = 'yelp-vocab.json'
if not os.path.exists(vocab_path):
    raise FileNotFoundError('vocabulary file %s not found; build it before running this notebook' % vocab_path)
with open(vocab_path, 'r', encoding='utf-8') as f:
    chars, id2char, char2id = json.load(f)
# JSON object keys are always strings; restore the integer ids.
id2char = {int(i): j for i, j in id2char.items()}

print('vocab size:', len(char2id))
sys.stdout.flush()

print('%d texts in the training set'%len(train))

def str2id(s, start_end = False):
    """Convert a sequence of tokens to vocabulary ids.

    Tokens missing from char2id map to oov_token. When start_end is true, the ids are
    wrapped as [start_token, ..., end_token].
    """
    body = list(map(lambda tok: char2id.get(tok, oov_token), s))
    return [start_token, *body, end_token] if start_end else body

def padding(x,y,z, length=None):
    """Right-pad three batches of token-id sequences with pad id 0.

    Args:
        x, y, z: iterables of id sequences (Python lists or rows of a numpy array).
        length: target sequence length; defaults to the global max_len.

    Returns:
        A tuple of three numpy arrays, each of shape (len(batch), length).
    """
    ml = max_len if length is None else length

    def _pad(batch):
        # list() makes numpy rows behave like lists: `row + [0, ...]` on an ndarray would
        # broadcast-add instead of concatenating. Over-long rows are truncated so the
        # result is always rectangular.
        rows = [list(seq)[:ml] for seq in batch]
        rows = [row + [0] * (ml - len(row)) for row in rows]
        return np.array(rows).reshape(len(rows), ml)

    return _pad(x), _pad(y), _pad(z)
    
def train_generator(yelp_data):
    """Endlessly yield PluginVAE training batches built from the tokenized corpus.

    Each epoch shuffles yelp_data in place. Every sentence is cut to max_len - 2 tokens,
    wrapped with start/end ids, and collected. Once batch_size sentences are gathered,
    they are right-padded with 0 to max_len and encoded by enc_model.

    Yields:
        ([ids, latent], None): ids has shape (batch_size, max_len); latent is the first
        output of enc_model.predict.
    """
    body_len = max_len - 2  # leave room for <BOS>/<EOS>
    batch = []
    while True:
        np.random.shuffle(yelp_data)
        for tokens in yelp_data:
            batch.append(str2id(tokens[:body_len], start_end=True))
            if len(batch) < batch_size:
                continue
            padded = np.array([ids + [0] * (max_len - len(ids)) for ids in batch])
            latent, _ = enc_model.predict(padded)
            yield [padded, latent], None
            batch = []

def sample(preds, diversity=1.0):
    """Draw one index from a probability vector after temperature scaling.

    The scaled distribution is proportional to preds ** (1 / diversity). It is computed in
    log space and shifted by its maximum, so small temperatures do not underflow to all
    zeros.

    Args:
        preds: 1-D array-like of non-negative weights (need not sum to 1).
        diversity: temperature (> 0). Values < 1 sharpen the distribution; values > 1
            flatten it.

    Returns:
        The sampled index (numpy integer).

    Raises:
        ValueError: if diversity <= 0 or every entry of preds is zero.
    """
    if diversity <= 0:
        raise ValueError('diversity must be positive, got %r' % (diversity,))
    preds = np.asarray(preds).astype('float64')
    # log(0) -> -inf is intended here (exp(-inf) == 0); silence the divide warning.
    with np.errstate(divide='ignore'):
        logits = np.log(preds) / diversity
    peak = np.max(logits)
    if not np.isfinite(peak):
        raise ValueError('preds must contain at least one positive finite entry')
    # Shift so the largest term is exp(0) == 1; normalization cancels the shift.
    exp_preds = np.exp(logits - peak)
    probs = exp_preds / np.sum(exp_preds)
    probas = np.random.multinomial(1, probs, 1)

    return np.argmax(probas)

def argmax(preds):
    """Return the (flattened) index of the largest entry of preds; first one wins ties."""
    return np.asarray(preds, dtype='float64').argmax()

def gen(num, diversity, argmax_flag):
    """Print num sentences decoded from random PluginVAE prior samples.

    For each sentence, a standard-normal vector of size bottle_dim is mapped by `decoder`
    to a PretrainVAE latent code. dec_model is then run autoregressively from the start
    token until end_token appears or the sequence exceeds max_len. Ids 0-2 are excluded
    from each step's choice. Words are written to stdout as they are produced.

    Args:
        num: number of sentences to generate.
        diversity: sampling temperature (ignored when argmax_flag is true).
        argmax_flag: pick the most likely word each step instead of sampling.
    """
    print('----- Generating from Generator-----')
    bos_word = id2char[start_token]  # <BOS>
    for _ in range(num):
        g_vec = decoder.predict(np.random.normal(size=(1, bottle_dim)))
        generated = [[start_token]]
        sys.stdout.write(bos_word)
        sys.stdout.flush()

        while end_token not in generated[0] and len(generated[0]) <= max_len:
            step = len(generated[0]) - 1
            x_seq = pad_sequences(generated, maxlen=max_len, padding='post')
            dist = dec_model.predict([x_seq, x_seq, g_vec], verbose=0)[0][step][3:]
            picked = argmax(dist) if argmax_flag else sample(dist, diversity)
            next_index = picked + 3
            next_word = id2char[next_index]
            generated[0].append(next_index)
            sys.stdout.write(next_word + ' ')
            sys.stdout.flush()
        print('\n')
        
        
# Infinite batch generator over the training corpus. It is lazy, so enc_model (defined
# below) only has to exist by the time the first batch is requested.
train_gen = train_generator(train)

# Model architecture of PretrainVAE (its weights are loaded from pretrain/yelp/ further down)

# Encoder: token ids -> shared embedding -> bidirectional GRU summary vector
encoder_input = Input(shape=(max_len, ), dtype='int32')
# emb_layer is reused for the decoder input and tied to the decoder's output softmax
emb_layer = Embedding(len(char2id), emb_size)
encoder_emb = emb_layer(encoder_input) 
encoder1 = Bidirectional(GRU(gru_dim))

encoder_h = encoder1(encoder_emb)

# Reparameterization trick: Gaussian posterior parameters of the latent_dim-sized code
z_mean = Dense(latent_dim)(encoder_h)
z_log_var = Dense(latent_dim)(encoder_h)

# Batch-mean KL(q(z|x) || N(0, I)). The lambda receives [z_log_var, z_mean], so x[0] is the
# log-variance and x[1] is the mean.
kl_loss = Lambda(lambda x: K.mean(- 0.5 * K.sum(1 + x[0] - K.square(x[1]) - K.exp(x[0]), axis=-1)))([z_log_var, z_mean])

def sampling(args):
    """Reparameterization trick for the PretrainVAE encoder.

    args is [z_mean, z_log_var]. Returns z_mean + exp(z_log_var / 2) * eps, where eps is
    standard-normal noise of shape (batch, latent_dim).
    """
    mu, log_var = args
    noise = K.random_normal(shape=(K.shape(mu)[0], latent_dim), mean=0, stddev=1)
    sigma = K.exp(log_var / 2)
    return mu + sigma * noise

enc_z = Lambda(sampling)([z_mean, z_log_var])
# Use the sampled code in the training phase and the posterior mean otherwise (e.g. predict()).
enc_z = Lambda(lambda x: K.in_train_phase(x[0], x[1]))([enc_z, z_mean])

# enc_model: token ids -> [latent code, KL]; train_generator feeds its first output to PluginVAE.
enc_model = Model(encoder_input, [enc_z, kl_loss])
enc_model.load_weights('pretrain/yelp/enc-base-another.h5')
print('load encoder weights successfully')

# decoder
# Inputs: the token-id prefix, a second id sequence, and the latent code z.
decoder_input = Input(shape=(max_len,), dtype='int32')
decoder_z_input = Input(shape=(latent_dim, ))
# NOTE(review): decoder_true_output is not consumed by any layer below; it is only listed as
# a Model input (the generation code feeds x_seq for it). Purpose not visible here.
decoder_true_output = Input(shape=(max_len,), dtype='int32')


decoder_dense = Dense(emb_size)
# Output layer tied to the shared embedding (TiedEmbeddingsTransposed is project-local in
# myclass; presumably it projects onto the transposed embedding matrix — confirm).
dec_softmax = TiedEmbeddingsTransposed(tied_to=emb_layer, activation='softmax')

decoder_emb = emb_layer(decoder_input)
decoder_emb = Position_Embedding()(decoder_emb)
# The latent code is repeated at every time step.
decoder_z = RepeatVector(max_len)(decoder_z_input)
decoder_h = decoder_emb

# Three self-attention layers. Each concatenates a fresh latent_dim projection of z onto the
# hidden state, so the width grows by latent_dim per layer (hence the three head_size entries).
for layer in range(3):
    decoder_z_hier = Dense(latent_dim, activation=None)(decoder_z)
    decoder_h = Concatenate()([decoder_h, decoder_z_hier])
    decoder_h_attn = Attention(head_num, head_size[layer], max_len)([decoder_h, decoder_h, decoder_h])
    # residual + normalization around attention, then around a ReLU Dense block
    decoder_h = Add()([decoder_h, decoder_h_attn])
    decoder_h = LayerNormalization()(decoder_h)
    decoder_h_mlp = Dense(head_size[layer]*head_num, activation='relu')(decoder_h)
    decoder_h = Add()([decoder_h, decoder_h_mlp])
    decoder_h = LayerNormalization()(decoder_h)
    decoder_h = Position_Embedding()(decoder_h)


# Project back to emb_size so the tied softmax can reuse the embedding matrix.
decoder_h = decoder_dense(decoder_h)
decoder_output = dec_softmax(decoder_h)

dec_model = Model([decoder_input, decoder_true_output, decoder_z_input], decoder_output)
dec_model.load_weights('pretrain/yelp/dec-base-another.h5')
print('load PretrainVAE decoder weights successfully')


In [None]:
# model architecture of PluginVAE
# NOTE: this cell rebinds z, z_mean, z_log_var, kl_loss, enc_z and sampling from the
# PretrainVAE cell above.

# Encoder: compresses a latent_dim PretrainVAE code to a bottle_dim Gaussian.
z_in = Input(shape=(latent_dim, ))
z = z_in
z = Dense(latent_dim//2, activation=None)(z)
z = LeakyReLU()(z)
z = Dense(latent_dim//4, activation=None)(z)
z = LeakyReLU()(z)

z_mean = Dense(bottle_dim)(z)
z_log_var = Dense(bottle_dim)(z)
# Batch-mean KL to N(0, I); x[0] is the log-variance, x[1] the mean.
kl_loss = Lambda(lambda x: K.mean(- 0.5 * K.sum(1 + x[0] - K.square(x[1]) - K.exp(x[0]), axis=-1)))([z_log_var, z_mean])

def sampling(args):
    """Reparameterization trick: return z_mean + exp(z_log_var / 2) * eps, eps ~ N(0, I).

    args is [z_mean, z_log_var]. The noise width is read from z_mean's static shape
    instead of a global (bottle_dim here, latent_dim in the PretrainVAE cell). This
    redefinition, which shadows the earlier `sampling`, is therefore correct for either
    encoder, whichever cell ran last.
    """
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0], K.int_shape(z_mean)[-1]), mean=0, stddev=1)
    return z_mean + K.exp(z_log_var / 2) * epsilon

enc_z = Lambda(sampling)([z_mean, z_log_var])
# Sample in the training phase, use the mean otherwise.
enc_z = Lambda(lambda x: K.in_train_phase(x[0], x[1]))([enc_z, z_mean])
encoder = Model(z_in, enc_z)
encoder.summary()

# Decoder: bottle_dim -> latent_dim. The Dense layers are created once so the full VAE
# below and the standalone `decoder` model share the same weights.
dec1 = Dense(latent_dim//4, activation=None)
dec2 = Dense(latent_dim//2, activation=None)
dec3 = Dense(latent_dim, activation=None)

z = dec1(enc_z)
z = LeakyReLU()(z)
z = dec2(z)
z = LeakyReLU()(z)
z = dec3(z)
vae = Model(z_in, z)
# Reconstruction MSE in the PretrainVAE latent space.
mse_loss = K.mean(K.square(z-z_in))
# Total loss = MSE + kl_weight * |KL - beta|. The KL is pulled toward the target beta
# (a backend variable the training loop updates) rather than simply minimized.
vae.add_loss(mse_loss + kl_weight*K.abs(kl_loss-beta))
# Positional Adam args: lr=1e-3, beta_1=0.5 (the training loop later sets lr to 3e-4).
vae.compile(optimizer=Adam(1e-3, 0.5))
# Report the KL term next to the loss. NOTE(review): metrics_tensors is an older Keras 2.x
# attribute; confirm it exists in the installed version.
vae.metrics_names.append('kl_loss')
vae.metrics_tensors.append(kl_loss)
vae.summary()

# Standalone decoder sharing dec1/dec2/dec3: maps bottle_dim prior samples to latent codes.
z_bottle = Input(shape=(bottle_dim, ))
z = dec1(z_bottle)
z = LeakyReLU()(z)
z = dec2(z)
z = LeakyReLU()(z)
z = dec3(z)
decoder = Model(z_bottle, z)
decoder.summary()


In [None]:
# Load the pretrained CNN classifier used for automatic evaluation.
# Note: the length condition doesn't need the classifier.
cnn_filter = 400
cnn_kernel = 3
x_in = Input(shape=(max_len, ))
x = x_in
x = Embedding(len(char2id), 100)(x)
x = Conv1D(cnn_filter, cnn_kernel, padding='valid', activation='relu')(x)
x = GlobalMaxPooling1D()(x)
# Single sigmoid score; later cells round it and report the fraction that is "positive".
x_out = Dense(1, activation='sigmoid')(x)
cls = Model(x_in ,x_out)
#cls.summary()
cls.load_weights('pretrain/yelp/cls.h5')
print('load clssifier successfully')

In [None]:
def _as_object_array(seqs):
    """Pack a list of Python lists into a 1-D object array holding one list per element."""
    # np.array(list_of_lists) fails on ragged input with NumPy >= 1.24, and it builds a 2-D
    # int array when all lengths match; downstream code expects a list per element.
    packed = np.empty(len(seqs), dtype=object)
    for k, seq in enumerate(seqs):
        packed[k] = seq
    return packed


def gen_from_ae(diversity, num, argmax_flag=False):
    """Sample num sentences from the PluginVAE prior and decode them with dec_model.

    Decoding runs from start_token until end_token is produced or the sequence exceeds
    max_len; ids 0-2 are never chosen.

    Args:
        diversity: sampling temperature (ignored when argmax_flag is true).
        num: number of sentences.
        argmax_flag: greedy decoding instead of sampling.

    Returns:
        (r1, r2, r3): 1-D object arrays with one Python list of ids per sample.
        r1 holds the tokens only (trailing end token removed, cut to max_len - 2).
        r2 is [start_token] + tokens + [end_token]; r3 is tokens + [end_token].
    """
    r1, r2, r3 = [], [], []
    for _ in range(num):
        random_vec = np.random.normal(size=(1, bottle_dim))
        g_vec = decoder.predict(random_vec)
        generated = [[start_token]]
        while end_token not in generated[0] and len(generated[0]) <= max_len:
            x_seq = pad_sequences(generated, maxlen=max_len, padding='post')
            preds = dec_model.predict([x_seq, x_seq, g_vec], verbose=0)[0]
            preds = preds[len(generated[0]) - 1][3:]
            if argmax_flag:
                next_index = argmax(preds)
            else:
                next_index = sample(preds, diversity)
            # predictions were sliced from id 3 onward, so shift back to vocabulary ids
            generated[0].append(int(next_index) + 3)

        # Work on ids directly: compare with end_token itself instead of the word '<EOS>'.
        tokens = generated[0][1:]
        if tokens[-1] == end_token:
            tokens = tokens[:-1]
        tokens = tokens[:(max_len - 2)]
        r1.append(tokens)
        r2.append([start_token] + tokens + [end_token])
        r3.append(tokens + [end_token])
    return _as_object_array(r1), _as_object_array(r2), _as_object_array(r3)

In [None]:
def get_distinct(id_list_data):
    """Print and return distinct-1 / distinct-2 diversity scores.

    distinct-n = (number of unique n-grams) / (total n-grams) over all sentences.
    A score whose denominator is zero is reported as 0.0 (e.g. empty input, or only
    one-token sentences for distinct-2).

    Args:
        id_list_data: iterable of token-id sequences.

    Returns:
        (distinct_1, distinct_2) as floats.
    """
    unigrams = [tok for sen in id_list_data for tok in sen]
    bigrams = [(sen[k], sen[k + 1]) for sen in id_list_data for k in range(len(sen) - 1)]

    distinct_1 = len(set(unigrams)) / len(unigrams) if unigrams else 0.0
    distinct_2 = len(set(bigrams)) / len(bigrams) if bigrams else 0.0

    print('distinct-1:', distinct_1)
    print('distinct-2:', distinct_2)
    return distinct_1, distinct_2
      
def gen_from_vec(diversity, vec, argmax_flag):
    """Decode one sentence from a given latent code and write it to stdout.

    Starts from start_token and runs dec_model autoregressively until end_token appears or
    the sequence exceeds max_len. Ids 0-2 are never chosen.

    Args:
        diversity: sampling temperature (ignored when argmax_flag is true).
        vec: latent code passed as dec_model's third input.
        argmax_flag: greedy decoding instead of sampling.
    """
    bos_word = id2char[start_token]  # <BOS>
    print()

    generated = [[start_token]]
    sys.stdout.write(bos_word)

    while end_token not in generated[0] and len(generated[0]) <= max_len:
        cursor = len(generated[0]) - 1
        x_seq = pad_sequences(generated, maxlen=max_len, padding='post')
        step_probs = dec_model.predict([x_seq, x_seq, vec], verbose=0)[0][cursor][3:]
        choice = argmax(step_probs) if argmax_flag else sample(step_probs, diversity)
        next_index = choice + 3
        next_word = id2char[next_index]
        generated[0].append(next_index)
        sys.stdout.write(next_word + ' ')
        sys.stdout.flush()

In [None]:
# training process
total_iter = 20001 # number of PluginVAE update steps

# NOTE(review): best_val and best_result are not read anywhere in this notebook as shown.
best_val = 100000.0
best_result = []

# To set the weight beta, please refer to our paper.
def get_beta_weight(iter_num, max_beta=5.0, warmup_iters=10000):
    """Linear warm-up schedule for the PluginVAE KL target beta.

    Grows linearly from 0 at iteration 0 to max_beta at warmup_iters, then stays there.
    The defaults reproduce the original schedule (0 -> 5.0 over 10000 iterations).

    Args:
        iter_num: current training iteration (0-based).
        max_beta: final value of beta.
        warmup_iters: iterations needed to reach max_beta; must be positive.

    Returns:
        The beta value for this iteration (float).
    """
    return min((max_beta / warmup_iters) * iter_num, max_beta)

# Adam was compiled with lr=1e-3; override it to 3e-4 for the whole run. The value never
# changes, so set it once instead of on every step.
K.set_value(vae.optimizer.lr, 3e-4)

for i in range(total_iter):
    # Each batch holds token ids (unused here) and their PretrainVAE latent codes.
    real_x, real_z = next(train_gen)[0]
    # Warm up the KL target beta.
    K.set_value(beta, get_beta_weight(i))
    # The loss was attached with add_loss, so no targets are passed.
    loss = vae.train_on_batch(real_z, None)

    if i % 100 == 0:
        print('iter: %s, loss: %s' % (i, loss))
        sys.stdout.flush()

    # Every 2000 steps: sample from the prior, report diversity and the classifier's
    # positive rate, and print a few greedy decodes.
    if (i % 2000 == 0) and i!=0:
        gen_num = 1000
        gen_samples1, gen_samples2, gen_samples3 = gen_from_ae(1.0, gen_num, True)
        get_distinct(gen_samples1)
        gen_samples1, gen_samples2, gen_samples3 = padding(gen_samples1, gen_samples2, gen_samples3)
        gen_result = cls.predict(gen_samples1)
        print('%f of the sample is positive in generator'%(np.sum(np.round(gen_result))/gen_num))
        gen(10, 0.5, True)
    
        


      

In [None]:
# Generate 10K texts for evaluation and write them one space-joined sentence per line.
gen_num = 10000
gen_1, gen_2, gen_3 = gen_from_ae(1.0, gen_num, True)
# open() does not create missing directories.
os.makedirs('gen', exist_ok=True)
with open('gen/PPVAE-single.txt', 'w', encoding='utf-8') as f:
    for g in gen_1:
        f.write(' '.join([id2char[index] for index in g])+'\n')

In [None]:
# Distinct-1/2 (unique / total uni- and bigrams) of the generated samples. Run this on the
# raw id lists: on padded arrays the pad zeros would be counted as tokens.
get_distinct(gen_1)

In [None]:
# Condition accuracy according to the pre-trained classifier.
# Notice the length condition doesn't need the classifier.
# Pad into new names so gen_1 keeps the raw id lists and this cell can be re-run. The positive
# rate is taken over the samples actually scored rather than the gen_num global.
gen_1_padded, gen_2_padded, gen_3_padded = padding(gen_1, gen_2, gen_3)
cls_result = cls.predict(gen_1_padded)
print('%f of the sample is positive'%(np.sum(np.round(cls_result))/len(cls_result)))