In [1]:
import os
from google.colab import drive
# Mount Google Drive at /content/gdrive (Colab-only API); later cells build
# their working directory from this mount point.
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [2]:
### Change working directory
# All relative file names used below (vocab/config/weights JSON and HDF5 files)
# are resolved against this folder once os.chdir has run.

path = '/content/gdrive/My Drive/NLG_2021_05_21'
os.chdir(path)

In [3]:
import os
import csv
import jieba
import json
import numpy as np
from pkg_resources import resource_filename
from keras.preprocessing import sequence
from sklearn.preprocessing import LabelBinarizer

class PreProcessor():

  config = {
      "meta_token": "<s>",
      "max_length": 100,
      "max_words": 50000
  }

  def __init__(self, vocab_path = None, config_path = None, name = "TextGenerator"):
    self.vocab = None
    self.config.update({"name": name})
    if vocab_path is None:
      vocab_path = resource_filename(__name__, "{}_vocab.json".format(self.config["name"]))
    if config_path is not None:
      with open(config_path, "r", encoding = "utf8", errors = "ignore") as json_file:
        self.config = json.load(json_file)
    if os.path.exists(vocab_path):
      with open(vocab_path, "r", errors = "ignore") as json_file:
        self.vocab = json.load(json_file)
        self.num_classes = len(self.vocab) + 1
        self.config.update({"num_classes": self.num_classes})
        self.indices_char = dict((self.vocab[c], c) for c in self.vocab)
  
  def read_textfile(self, file_path, context = None, header = True, is_csv = False):
    """ Read texts (and optionally their context labels) from a file and tokenize them.

    CSV mode (is_csv truthy): every row contributes row[1] as the text and, when
    context is truthy, row[0] as its label. Rows with fewer than two columns are
    skipped. The header flag is not consulted in this mode.

    Plain mode: every line is split on whitespace; field 0 is the text and, when
    context is truthy, field 1 is its label. The first line is skipped only when
    header is truthy. Blank lines (and lines lacking a label when one is needed)
    are skipped.

    Each text is cleaned with regex_process and segmented with jieba.lcut, and
    word_count is refreshed from the result.

    Returns:
      texts (a list of token lists), or (texts, contexts) when context is truthy.
    """
    self.config.update({"context": context})
    texts = []
    contexts = []
    # The csv module expects files opened with newline="" so that quoted
    # fields containing line breaks are parsed correctly.
    newline = "" if is_csv else None
    with open(file_path, "r", encoding = "utf8", errors = "ignore", newline = newline) as f:
      if is_csv:
        for row in csv.reader(f):
          if len(row) < 2:
            # Malformed / short row: skip it explicitly instead of swallowing an exception.
            continue
          texts.append(row[1])
          if context:
            contexts.append(row[0])
      else:
        if header:
          f.readline()
        for line in f:
          row = line.split()
          if not row or (context and len(row) < 2):
            continue
          texts.append(row[0])
          if context:
            contexts.append(row[1])
    texts = [jieba.lcut(self.regex_process(text)) for text in texts]
    self.word_count(texts)
    if context:
      return texts, contexts
    return texts
  
  def regex_process(self, texts):
    """ Clean one raw text string before word segmentation.

    Removes HTML line breaks and newlines, list numbering such as "1、" or "10.",
    and bracketed tags such as "【标题】"; every run of ASCII or CJK punctuation is
    then replaced by a single space.

    Args:
      texts: the raw string (a single text, despite the plural name).

    Returns:
      The cleaned string.
    """
    import re
    from string import punctuation
    # Match the literal tags "<br>", "<br/>", "<br />" and newlines. A character
    # class here would instead delete every single "b", "r", "<", "/" and ">".
    texts = re.sub(r"<br\s*/?>|\n", "", texts)
    # [0-9] rather than [1-9] so numbering that contains a zero ("10.") is removed too.
    texts = re.sub(r"[0-9]+[(，、.)]", "", texts)
    texts = re.sub(r"【\w+】", "", texts)
    texts = re.sub(r"】", "", texts)
    punc = punctuation + u"。！？＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏.·"
    # Escape the characters so "]", "\\", "^" and "-" are taken literally inside the class.
    texts = re.sub(r"[{}]+".format(re.escape(punc)), " ", texts)
    return texts

  def word_count(self, texts):
    self.wordcount = {}
    for i in range(len(texts)):
      for word in texts[i]:
        if word in self.wordcount:
          self.wordcount[word] += 1
        else:
          self.wordcount[word] = 1
    count_list = sorted([(v, k) for k, v in self.wordcount.items()], reverse = True)
    self.wordcount = {v: k for i, (k, v) in enumerate(count_list)}

  def generate_vocabulary(self, texts):
    """ Create a vocabulary in the form of "text: index" for training set. """
    if self.vocab is None:
      vocab = []
    else:
      vocab = list(self.vocab.keys())
    for i in range(len(texts)):
      vocab.extend(word for word in texts[i])
    vocab = set(vocab)
    # Limit vocab to max_words
    self.vocab = {}
    for i, t in enumerate(vocab):
      if i <= self.config["max_words"]:
        self.vocab[t] = i + 1
    if self.config["meta_token"] not in self.vocab.keys():
      self.vocab[self.config["meta_token"]] = len(self.vocab) + 1
    self.num_classes = len(self.vocab) + 1
    self.config.update({"num_classes": self.num_classes})
    self.indices_char = dict((self.vocab[c], c) for c in self.vocab)
    # save the files needed to recreate the model
    with open("{}_vocab.json".format(self.config["name"]), "w", encoding = "utf-8") as outfile:
      json.dump(self.vocab, outfile, ensure_ascii = False)

  def process_sequence(self, X):
    """ Encode one token sequence as vocabulary indices and pad it to a fixed length.

    Args:
      X: a one-element list whose single item is the list of tokens.

    Returns:
      The result of keras pad_sequences for that single sequence, with
      maxlen = config["max_length"].
    """
    # Unknown tokens (e.g. words cut by the max_words limit) map to 0, the same
    # convention prediction_encode_sequence uses, instead of raising KeyError.
    encoded = [[self.vocab.get(w, 0) for w in X[0]]]
    return sequence.pad_sequences(encoded, maxlen = self.config["max_length"])

  def encode_cat(self, chars, vocab):
    """ One hot encodes values at given chars efficiently by preallocating a zeros matrix.

    Args:
      chars: sequence of tokens to encode.
      vocab: dict mapping token -> column index; unknown tokens use column 0.

    Returns:
      A float32 array of shape (len(chars), len(vocab) + 1) with a single 1 per row.
      An empty chars yields an empty (0, len(vocab) + 1) array.
    """
    a = np.zeros((len(chars), len(vocab) + 1), dtype = np.float32)
    if len(chars) > 0:
      cols = [vocab.get(char, 0) for char in chars]
      a[np.arange(len(chars)), cols] = 1
    return a

  def generate_sequences_from_texts(self, texts, indices_list, contexts = None, batch_size = 128):
    """ Generate the training batch from text. """
    while True:
      np.random.shuffle(indices_list)
      X_batch = []
      Y_batch = []
      context_batch = []
      count_batch = 0
      for row in range(indices_list.shape[0]):
        text_index = indices_list[row, 0]
        end_index = indices_list[row, 1]
        text = texts[text_index]
        text = [self.config["meta_token"]] + text + [self.config["meta_token"]]
        if end_index > self.config["max_length"]:
          x = text[end_index - self.config["max_length"] : end_index + 1]
        else:
          x = text[0 : end_index + 1]
        y = text[end_index + 1]
        if y in self.vocab:
          x = self.process_sequence([x])
          y = self.encode_cat([y], self.vocab)
          X_batch.append(x)
          Y_batch.append(y)
          if contexts is not None:
            context_batch.append(contexts[text_index])
          count_batch += 1
          if count_batch % batch_size == 0:
            X_batch = np.squeeze(np.array(X_batch))
            Y_batch = np.squeeze(np.array(Y_batch))
            context_batch = np.squeeze(np.array(context_batch))
            if contexts is not None:
              yield ([X_batch, context_batch], [Y_batch, Y_batch])
            else:
              yield (X_batch, Y_batch)
            X_batch = []
            Y_batch = []
            context_batch = []
            count_batch = 0

  def generate_generator(self, texts, contexts = None, train_size = 0.7, validation = True, batch_size = 128):
    """ Generate the generator for model training. """
    gen = None
    gen_val = None
    self.val_steps = None
    if contexts:
      contexts = LabelBinarizer().fit_transform(contexts)
      self.config.update({"context_size": contexts.shape[1]})
    # calculate all combinations of text indices + token indices
    indices_list = np.block([np.meshgrid(np.array(i), np.arange(len(text) + 1)) for i, text in enumerate(texts)])
    indices_mask = np.random.rand(indices_list.shape[0]) < train_size
    indices_list_train = indices_list[indices_mask, :]
    num_tokens = indices_list_train.shape[0]
    assert num_tokens >= batch_size, "Fewer tokens than batch_size"
    self.steps_per_epoch = max(int(np.floor(num_tokens / batch_size)), 1)
    self.config.update({"steps_per_epoch": self.steps_per_epoch})
    gen = self.generate_sequences_from_texts(texts = texts, contexts = contexts, indices_list = indices_list, batch_size = batch_size)
    # generate text indices for validaiton  
    if train_size < 1.0 and validation:
      indices_list_val = indices_list[~indices_mask, :]
      gen_val = self.generate_sequences_from_texts(texts = texts, contexts = contexts, indices_list = indices_list_val, batch_size = batch_size)
      self.val_steps = max(int(np.floor(indices_list_val.shape[0] / batch_size)), 1)
      self.config.update({"val_steps": self.val_steps})
    return gen, gen_val
  
  def main(self, file_path, context, is_csv, train_size, validation, batch_size):
    if context:
      texts, contexts = self.read_textfile(file_path, context, is_csv = is_csv)
    else:
      texts = self.read_textfile(file_path, context, is_csv = is_csv)
      contexts = None
    self.word_count(texts)
    self.generate_vocabulary(texts)
    gen, gen_val = proc.generate_generator(texts, contexts, train_size = train_size, validation = validation, batch_size = batch_size)
    self.config.update({"train_size": train_size})
    self.config.update({"validaiton": validation})
    self.config.update({"batch_size": batch_size})
    with open('{}_config.json'.format(self.config['name']), 'w', encoding='utf8') as outfile:
      json.dump(self.config, outfile, ensure_ascii=False)
    return gen, gen_val

In [None]:
# Build the preprocessor and create the training / validation batch generators
# from the CSV (context labels enabled, 70% of the samples used for training).
proc = PreProcessor()
gen, gen_val = proc.main("large_text.csv", context = True, is_csv = True, train_size = 0.7, validation = True, batch_size = 128)

Building prefix dict from the default dictionary ...
Dumping model to file cache /tmp/jieba.cache
Loading model cost 1.040 seconds.
Prefix dict has been built successfully.


In [28]:
from keras.callbacks import LearningRateScheduler, Callback
from keras.models import Model, load_model
from keras.optimizers import RMSprop
from keras import backend as K
from pkg_resources import resource_filename
from keras.engine import InputSpec, Layer
from keras import initializers
from keras.layers import Input, Embedding, Dense, LSTM, Bidirectional
from keras.layers import concatenate, Reshape, SpatialDropout1D
from keras.preprocessing import sequence
from random import shuffle
from pkg_resources import resource_filename
from tqdm import trange

import os
import numpy as np
import json

class TextGenerator:

  def __init__(self, weights_path = None, vocab_path = None, config_path = None, name = "TextGenerator"):
    """ Load config and vocabulary and, when weights are available, build the model.

    Args:
      weights_path: HDF5 weights file. Defaults to "<name>_weights.hdf5" resolved
        through pkg_resources.
      vocab_path: JSON file mapping token -> index. Defaults to "<name>_vocab.json"
        resolved through pkg_resources.
      config_path: JSON config file. When None, "<name>_config.json" in the working
        directory is used if it exists; otherwise the config starts empty.
      name: prefix used for the config and weights files written by this class.
    """
    self.config = {}
    self.vocab = None
    self.num_classes = None
    self.indices_char = {}
    self.model = None
    if config_path is None:
      # The preprocessing and training steps write this file next to the notebook.
      default_config = "{}_config.json".format(name)
      if os.path.exists(default_config):
        config_path = default_config
    if config_path is not None:
      with open(config_path, "r", encoding = "utf8", errors = "ignore") as json_file:
        self.config = json.load(json_file)
    self.config.update({"name": name})
    # Architecture defaults; values stored in a loaded config are respected so
    # the rebuilt network matches the weights it was saved with.
    self.config.setdefault("rnn_layers", 2)
    self.config.setdefault("rnn_size", 256)
    self.config.setdefault("rnn_bidirectional", True)
    self.config.setdefault("dim_embeddings", 300)
    if weights_path is None:
      weights_path = resource_filename(__name__, "{}_weights.hdf5".format(self.config["name"]))
    if vocab_path is None:
      vocab_path = resource_filename(__name__, "{}_vocab.json".format(self.config["name"]))
    # os.path.exists returns a bool, so it must be tested directly
    # ("is not None" is always true and would try to open missing files).
    if os.path.exists(vocab_path):
      with open(vocab_path, "r", encoding = "utf8", errors = "ignore") as json_file:
        self.vocab = json.load(json_file)
      self.num_classes = len(self.vocab) + 1
      self.indices_char = dict((self.vocab[c], c) for c in self.vocab)
    # Building needs the vocabulary size and the sequence length.
    if os.path.exists(weights_path) and self.vocab is not None and "max_length" in self.config:
      self.model = self.build_textgen_model(weights_path = weights_path)
  
  def build_textgen_model(self, context_size = None, weights_path = None, dropout = 0.0, optimizer = RMSprop(lr = 3e-4, rho = 0.99)):
    """ Builds the model architecture for textgen and loads the specified weights for the model. """
    # build rnn layer
    def new_rnn(layer_num):
      if self.config["rnn_bidirectional"]:
        return Bidirectional(LSTM(self.config["rnn_size"], return_sequences = True, recurrent_activation = "sigmoid"), name = "rnn_{}".format(layer_num))
      return LSTM(self.config["rnn_size"], return_sequences = True, recurrent_activation = "sigmoid", name = "rnn_{}".format(layer_num))
    # input layer
    input = Input(shape = (self.config["max_length"], ), name = "input")
    # embedding layer
    embedding = Embedding(self.num_classes, self.config["dim_embeddings"], input_length = self.config["max_length"], name = "embedding")(input)
    # dropout
    self.config.update({"dropout": dropout})
    if dropout > 0.0:
      embedding = SpatialDropout1D(dropout, name = "dropout")(embedding)
    # rnn layer
    rnn_layer_list = []
    for i in range(self.config["rnn_layers"]):
      prev_layer = embedding if i is 0 else rnn_layer_list[-1]
      rnn_layer_list.append(new_rnn(i+1)(prev_layer))
    # concatenation
    seq_concat = concatenate([embedding] + rnn_layer_list, name = "rnn_concat")
    # attention layer
    attention = AttentionWeightedAverage(name = "attention")(seq_concat)
    # output layer
    output = Dense(self.num_classes, name = "output", activation = "softmax")(attention)
    # whether to include the context
    if context_size is None:
      model = Model(inputs = [input], outputs = [output])
      if weights_path is not None:
        model.load_weights(weights_path, by_name = True)
      model.compile(loss = "categorical_crossentropy", optimizer = optimizer)
    else:
      context_input = Input(shape = (context_size, ), name = "context_input")
      context_reshape = Reshape((context_size, ), name = "context_reshape")(context_input)
      merged = concatenate([attention, context_reshape], name = "concat")
      main_output = Dense(self.num_classes, name = "context_output", activation = "softmax")(merged)
      model = Model(inputs = [input, context_input], outputs = [main_output, output])
      if weights_path is not None:
        model.load_weights(weights_path, by_name = True)
      model.compile(loss = "categorical_crossentropy", optimizer = optimizer, loss_weights = [0.8, 0.2])
    return model

  def train_textgen(self, gen, gen_val, dropout, num_epochs, gen_epochs, save_epochs, verbose, context = None):
    """ Train the model built before.

    Builds the model first when none exists (with a context input when context is
    truthy), writes "<name>_config.json", then fits on the generators with a
    linearly decaying learning rate, periodic sample generation and weight saving.

    Args:
      gen: training batch generator.
      gen_val: validation batch generator, or None.
      dropout: dropout rate used if the model has to be built here.
      num_epochs: number of epochs to train.
      gen_epochs: print samples every this many epochs (0 disables).
      save_epochs: save per-epoch weights every this many epochs (0 disables).
      verbose: Keras verbosity level.
      context: truthy to build the context-aware model.
    """
    base_lr = 3e-4
    self.config.update({"base_lr": base_lr})
    # scheduler function must be defined inline
    def lr_linear_decay(epoch):
      return (base_lr * (1 - (epoch / num_epochs)))
    if self.model is None:
      if context:
        self.model = self.build_textgen_model(dropout = dropout, context_size = self.config["context_size"])
      else:
        self.model = self.build_textgen_model(dropout = dropout)
    with open('{}_config.json'.format(self.config['name']), 'w', encoding='utf8') as outfile:
      json.dump(self.config, outfile, ensure_ascii=False)
    # "val_steps" is absent (or None) when no validation split was created, and
    # is only meaningful together with a validation generator.
    validation_steps = self.config.get("val_steps") if gen_val is not None else None
    self.model.fit(
        gen, 
        steps_per_epoch = self.config["steps_per_epoch"], 
        epochs = num_epochs, 
        callbacks = [
          LearningRateScheduler(lr_linear_decay), 
          generate_after_epoch(self, gen_epochs, self.config["max_length"]), 
          save_model_weights(self, num_epochs, save_epochs)
        ],
        verbose = verbose,
        max_queue_size = 10,
        validation_data = gen_val,
        validation_steps = validation_steps
    )

  def generate(self, n = 1, return_as_list = False, prefix = None, temperature = [1.0, 0.5, 0.2, 0.2], interactive = False, top_n = 5, progress = True):
    """ Generate n texts; print each one, or return them all as a list.

    A tqdm progress bar is used when progress is truthy and n > 1. Nothing is
    returned unless return_as_list is truthy.
    """
    use_bar = progress and n > 1
    rounds = trange(n) if use_bar else range(n)
    gen_texts = []
    for _ in rounds:
      produced, _finished = self.generate_text(self.config["max_length"], top_n, temperature, interactive, prefix)
      if not return_as_list:
        print("{}\n".format(produced))
      gen_texts.append(produced)
    if return_as_list:
      return gen_texts
  
  def generate_samples(self, n = 1, temperatures = [0.0, 0.2, 0.5, 1.0], **kwargs):
    """ Print n generated texts for each temperature, preceded by a banner naming it. """
    rule = "#" * 20
    for temperature in temperatures:
      header = rule + "\nTemperature: {}\n".format(temperature) + rule
      print(header)
      self.generate(n, temperature = temperature, progress = False, **kwargs)
      
  def generate_text(self, maxlen, top_n, temperature = 0.5, interactive = False, prefix = None, synthesize = False, stop_tokens = [" ", "\n"]):
    """ Generates and returns a single text.

    Args:
      maxlen: number of trailing tokens fed to the model at each step.
      top_n: number of candidates offered per step in interactive mode.
      temperature: a float or a list of floats cycled through step by step.
      interactive: when truthy, the user picks each next token from the console.
      prefix: optional seed text; it is segmented with jieba.lcut.
      synthesize: when truthy, stop at the first generated stop token.
      stop_tokens: tokens that end generation in synthesize mode.

    Returns:
      (text, end): the generated tokens joined together (start token and the first
      later meta token removed), and whether generation ended through the meta
      token, the length limit or an explicit user stop.
    """
    meta_token = self.config["meta_token"]
    collapse_char = ""
    end = False
    text = [meta_token] + jieba.lcut(prefix) if prefix else [meta_token]
    next_char = ""
    if not isinstance(temperature, list):
      temperature = [temperature]
    # A context model has two inputs / outputs; generation only needs the text
    # input and the context-free output.
    if len(self.model.inputs) > 1:
      self.model_gen = Model(inputs = self.model.inputs[0], outputs = self.model.outputs[1])
    else:
      self.model_gen = self.model
    while not end and len(text) < self.config["max_length"]:
      encoded_text = self.prediction_encode_sequence(text[-maxlen:], self.vocab, maxlen)
      next_temperature = temperature[(len(text) - 1) % len(temperature)]
      preds = self.model_gen.predict(encoded_text, batch_size = 1)[0]
      if not interactive:
        # auto-generate text without user intervention
        next_index = self.generate_next_char(preds, next_temperature)
        next_char = self.indices_char[next_index]
        text += [next_char]
        if next_char == meta_token or len(text) >= self.config["max_length"]:
          end = True
        gen_break = (next_char in stop_tokens or len(stop_tokens) == 0)
        if synthesize and gen_break:
          break
      else:
        # ask user what the next char/word should be
        # (this class names its sampler generate_next_char)
        options_index = self.generate_next_char(preds, next_temperature, interactive = interactive, top_n = top_n)
        # atleast_1d: a zero temperature may yield a single index instead of a list;
        # indices without a vocabulary entry (the 0 placeholder) are not offered.
        options = [self.indices_char[idx] for idx in np.atleast_1d(options_index) if idx in self.indices_char]
        print("Controls:\n\ts: stop.\tx: backspace.\to: write your own.")
        print("\nOptions:")
        for i, option in enumerate(options, 1):
          print("\t{}: {}".format(i, option))
        print("\nProgress: {}".format(collapse_char.join(text)[len(meta_token):]))
        print("\nYour choice?")
        user_input = input("> ")
        if user_input == "s":
          # Stop must actually end the loop, not just append the meta token.
          next_char = meta_token
          text += [next_char]
          end = True
        elif user_input == "o":
          other = input("> ")
          text += [other]
        elif user_input == "x":
          # Never delete the leading start token; it is stripped positionally below.
          if len(text) > 1:
            del text[-1]
        else:
          try:
            choice = int(user_input)
          except ValueError:
            choice = None
          if choice is not None and 1 <= choice <= len(options):
            next_char = options[choice - 1]
            text += [next_char]
            if next_char == meta_token:
              end = True
          else:
            print("That\'s not an option!")
    # if not single text, remove the <s> meta_tokens
    text = text[1:]
    if meta_token in text:
      text.remove(meta_token)
    text_joined = collapse_char.join(text)
    return text_joined, end

  def prediction_encode_sequence(self, text, vocab, maxlen):
    """ Turn a token list into a one-row, fixed-length index batch for model.predict.

    Tokens missing from vocab are encoded as 0; the row is passed through
    keras pad_sequences with the given maxlen.
    """
    token_ids = []
    for token in text:
      token_ids.append(vocab.get(token, 0))
    batch = [np.array(token_ids)]
    return sequence.pad_sequences(batch, maxlen = maxlen)
  
  def generate_next_char(self, preds, temperature, interactive = False, top_n = 5):
    """ Samples predicted probabilities of the next character to allow for the network to show "creativity".

    Index 0 is the padding / unknown placeholder and is never returned.

    Args:
      preds: 1-D array of class probabilities.
      temperature: sampling temperature; None or 0.0 selects the most likely class.
      interactive: when truthy, return the top_n candidate indices instead of one index.
      top_n: number of candidates returned in interactive mode.

    Returns:
      A single class index, or (interactive) an array of up to top_n indices
      ordered from most to least probable.
    """
    preds = np.asarray(preds).astype("float64")
    if interactive:
      # Temperature scaling is monotonic, so the ranking can be taken directly.
      # Always return a sequence, also for temperature 0 / None.
      order = (-preds).argsort()
      return order[order != 0][:top_n]
    if temperature is None or temperature == 0.0:
      masked = preds.copy()
      masked[0] = -np.inf
      return np.argmax(masked)
    preds = np.log(preds + K.epsilon()) / temperature
    exp_preds = np.exp(preds)
    # Remove the placeholder before normalizing so it can never be drawn
    # (falling back to "second best" could itself land on index 0).
    exp_preds[0] = 0.0
    preds = exp_preds / np.sum(exp_preds)
    probas = np.random.multinomial(1, preds, 1)
    return np.argmax(probas)

class generate_after_epoch(Callback):
  """ Keras callback that prints generated samples every gen_epochs epochs.

  Args:
    textgen: the object whose generate_samples() is called.
    gen_epochs: sampling interval in epochs; 0 or less disables sampling.
    max_gen_length: stored on the instance; not used by this callback itself.
  """
  def __init__(self, textgen, gen_epochs, max_gen_length):
    # Let the Keras base class set up its own attributes first.
    super(generate_after_epoch, self).__init__()
    self.textgen = textgen
    self.gen_epochs = gen_epochs
    self.max_gen_length = max_gen_length
  
  def on_epoch_end(self, epoch, logs = None):
    # epoch is zero-based, hence the + 1 when testing the interval.
    if self.gen_epochs > 0 and (epoch + 1) % self.gen_epochs == 0:
      self.textgen.generate_samples()

class save_model_weights(Callback):
  """ Keras callback that saves the generator's weights at the end of every epoch.

  Every save_epochs epochs (except on the final epoch) the weights go to
  "<name>_weights_epoch_<n>.hdf5"; on all other epochs they go to
  "<name>_weights.hdf5".

  Args:
    textgen: the object whose .model is saved; its config["name"] is the file prefix.
    num_epochs: total number of training epochs.
    save_epochs: interval for the per-epoch files; 0 or less disables them.
  """
  def __init__(self, textgen, num_epochs, save_epochs):
    # Let the Keras base class set up its own attributes first.
    super(save_model_weights, self).__init__()
    self.textgen = textgen
    self.weights_name = textgen.config["name"]
    self.num_epochs = num_epochs
    self.save_epochs = save_epochs
  
  def on_epoch_end(self, epoch, log = None):
    # For a context model, replace textgen.model with the text-only view
    # (first input, second output) before saving. self.model is the model
    # attached to this callback by Keras.
    if len(self.textgen.model.inputs) > 1:
      self.textgen.model = Model(inputs = self.model.input[0], outputs = self.model.output[1])
    if self.save_epochs > 0 and (epoch + 1) % self.save_epochs == 0 and self.num_epochs != (epoch + 1):
      print("Saving Model Weights - Epoch #{}".format(epoch + 1))
      self.textgen.model.save_weights("{}_weights_epoch_{}.hdf5".format(self.weights_name, epoch + 1))
    else:
      self.textgen.model.save_weights("{}_weights.hdf5".format(self.weights_name))

class AttentionWeightedAverage(Layer):
  """ Computes a weighted average of the different channels across timesteps.

  Uses 1 parameter per channel to compute the attention value for a single timestep.
  This attention layer code is from DeepMoji (MIT Licensed).

  Args:
    return_attention: when truthy, call() returns [result, attention_weights]
      instead of only the result.
  """
  def __init__(self, return_attention = False, **kwargs):
    self.init = initializers.get("uniform")
    self.supports_masking = True
    self.return_attention = return_attention
    super(AttentionWeightedAverage, self).__init__(**kwargs)
  
  def build(self, input_shape):
    self.input_spec = [InputSpec(ndim = 3)]
    assert len(input_shape) == 3
    # One weight per channel (last input dimension).
    self.W = self.add_weight(shape = (input_shape[2], 1), name = "{}_W".format(self.name), initializer = self.init)
    self._trainable_weights = [self.W]
    super(AttentionWeightedAverage, self).build(input_shape)
  
  def call(self, x, mask = None):
    # computes a probability distribution over the timesteps
    # uses "max trick" for numerical stability
    # reshape is done to avoid issue with Tensorflow and 1-dimensional weights
    logits = K.dot(x, self.W)
    x_shape = K.shape(x)
    logits = K.reshape(logits, (x_shape[0], x_shape[1]))
    ai = K.exp(logits - K.max(logits, axis = -1, keepdims = True))
    # mask timesteps have zero weight
    if mask is not None:
      mask = K.cast(mask, K.floatx())
      ai = ai * mask
    attn_weights = ai / (K.sum(ai, axis = 1, keepdims = True) + K.epsilon())
    weighted_input = x * K.expand_dims(attn_weights)
    result = K.sum(weighted_input, axis = 1)
    if self.return_attention:
      return [result, attn_weights]
    else:
      return result
  
  def get_output_shape_for(self, input_shape):
    return self.compute_output_shape(input_shape)
  
  def compute_output_shape(self, input_shape):
    output_len = input_shape[2]
    if self.return_attention:
      return [(input_shape[0], output_len), (input_shape[0], input_shape[1])]
    else:
      return (input_shape[0], output_len)
  
  def compute_mask(self, input, input_mask = None):
    if isinstance(input_mask, list):
      return [None] * len(input_mask)
    else:
      return [None]

  def get_config(self):
    # Include the constructor argument so the layer can be rebuilt from a
    # serialized model config with the same behavior.
    config = super(AttentionWeightedAverage, self).get_config()
    config["return_attention"] = self.return_attention
    return config

In [None]:
# Create the generator from the saved vocabulary and train it for 5 epochs,
# printing samples after every epoch and saving per-epoch weights every 2 epochs.
# NOTE(review): `model` below is not passed to train_textgen, which works on
# textgen.model instead — confirm whether this extra build is still needed.
textgen = TextGenerator(vocab_path = "TextGenerator_vocab.json")
model = textgen.build_textgen_model(dropout = 0.2)
textgen.train_textgen(gen = gen, gen_val = gen_val, dropout = 0.2, num_epochs = 5, gen_epochs = 1, save_epochs = 2, verbose = 1, context = True)

Epoch 1/5
   6/4839 [..............................] - ETA: 9:57:53 - loss: 9.3008 - context_output_loss: 9.2809 - output_loss: 9.3808

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2882, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-6-86d570b4855d>", line 3, in <module>
    textgen.train_textgen(gen = gen, gen_val = gen_val, dropout = 0.2, num_epochs = 5, gen_epochs = 1, save_epochs = 2, verbose = 1, context = True)
  File "<ipython-input-5-3c97d9c4add5>", line 113, in train_textgen
    validation_steps = self.config["val_steps"]
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/engine/training.py", line 1100, in fit
    tmp_logs = self.train_function(iterator)
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py", line 828, in __call__
    result = self._call(*args, **kwds)
  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py", line 855, in _call
    return self._stateless_fn(*args, **kwds)  # pylint: disable=not

KeyboardInterrupt: ignored

In [None]:
from tqdm import trange

class Predictor:

  config = {
      "rnn_layers": 2,
      "rnn_size": 256,
      "rnn_bidirectional": True,
      "max_length": 100,
      "max_words": 100000,
      "dim_embeddings": 300,
      "batch_size": 128
  }
  default_config = config.copy()

  def __init__(self, model, config_path = "TextGenerator_config.json", weights_path = None, vocab_path = None, name = "TextGenerator"):
    if config_path is not None:
      with open(config_path, "r", encoding = "utf8", errors = "ignore") as json_file:
        self.config = json.load(json_file)
    if weights_path is None:
      weights_path = resource_filename(__name__, "TextGenerator_weights.hdf5")
    if vocab_path is None:
      vocab_path = resource_filename(__name__, "TextGenerator_vocab.json")
    if os.path.exists(vocab_path):
      with open(vocab_path, "r", encoding = "utf8", errors = "ignore") as json_file:
        self.vocab = json.load(json_file)
        self.num_classes = len(self.vocab) + 1
        self.indices_char = dict((self.vocab[c], c) for c in self.vocab)
    self.model = model

  def generate(self, n = 1, return_as_list = False, prefix = None, temperature = [1.0, 0.5, 0.2, 0.2], max_gen_length = 300, interactive = False, top_n = 3, progress = True):
    gen_texts = []
    iterable = trange(n) if progress and n > 1 else range(n)
    for _ in iterable:
      gen_text, _ = self.generate_text(
          self.config["max_length"],
          max_gen_length,
          top_n,
          temperature, 
          self.meta_token,
          interactive, 
          prefix
      )
      if not return_as_list:
        print("{}\n".format(gen_text))
      gen_texts.append(gen_text)
    if return_as_list:
      return gen_texts
  
  def generate_samples(self, n = 3, temperatures = [0.2, 0.5, 1.0], **kwargs):
    for temperature in temperatures:
      print("#" * 20 + "\nTemperature: {}\n".formate(temperature) + "#" * 20)
      self.generate(n, temperature = temperature, progress = False, **kwargs)
      
  def generate_text(self, maxlen, top_n, temperature = 0.5, interactive = False, prefix = None, synthesize = False, stop_tokens = [" ", "\n"]):
    """ Generates and returns a single text. """
    collapse_char = ""
    end = False
    text = [meta_token] + jieba.lcut(prefix) if prefix else [meta_token]
    next_char = ""
    if not isinstance(temperature, list):
      temperature = [temperature]
    while not end and len(text) < self.config["max_length"]:
      encoded_text = self.prediction_encode_sequence(text[-maxlen:], self.vocab, maxlen)
      next_temperature = temperature[(len(text) - 1) % len(temperature)]
      if not interactive:
        # auto-generate text without user intervention
        next_index = self.generate_sample(self.model.predict(encoded_text, batch_size = 1)[0], next_temperature)
        next_char = self.indices_char[next_index]
        text += [next_char]
        if next_char == meta_token or len(text) >= self.config["max_length"]:
          end = True
        gen_break = (next_char in stop_tokens or len(stop_tokens) == 0)
        if synthesize and gen_break:
          break
      else:
        # ask user what the next char/word should be
        options_index = self.generate_sample(self.model.predict(encoded_text, batch_size = 1)[0], next_temperature, interactive = interactive, top_n = top_n)
        options = [self.indices_char[idx] for idx in options_index]
        print("Controls:\n\ts: stop.\tx: backspace.\to: write your own.")
        print("\nOptions:")
        for i, option in enumerate(options, 1):
          print("\t{}: {}".format(i, option))
        print("\nProgress: {}".format(collapse_char.join(text)[3:]))
        print("\nYour choice?")
        user_input = input("> ")
        try:
          user_input = int(user_input)
          next_char = options[user_input - 1]
          text += [next_char]
          if next_char == "<s>":
            end = True
        except ValueError:
          if user_input == "s":
            next_char = "<s>"
            text += [next_char]
          elif user_input == "o":
            other = input("> ")
            text += [other]
          elif user_input == "x":
            try:
              del text[-1]
            except IndexError:
              pass
          else:
            print("That\'s not an option!")
    # if not single text, remove the <s> meta_tokens
    text = text[1:]
    if meta_token in text:
      text.remove(meta_token)
    text_joined = collapse_char.join(text)
    return text_joined, end

  def prediction_encode_sequence(self, text, vocab, maxlen):
    """ Turn a token list into a one-row, fixed-length index batch for model.predict.

    Tokens missing from vocab are encoded as 0; the row is passed through
    keras pad_sequences with the given maxlen.
    """
    token_ids = []
    for token in text:
      token_ids.append(vocab.get(token, 0))
    batch = [np.array(token_ids)]
    return sequence.pad_sequences(batch, maxlen = maxlen)
  
  def generate_sample(self, preds, temperature, interactive = False, top_n = 3):
    """ Samples predicted probabilities of the next character to allow for the network to show "creativity".

    Index 0 is the padding / unknown placeholder and is never returned.

    Args:
      preds: 1-D array of class probabilities.
      temperature: sampling temperature; None or 0.0 selects the most likely class.
      interactive: when truthy, return the top_n candidate indices instead of one index.
      top_n: number of candidates returned in interactive mode.

    Returns:
      A single class index, or (interactive) an array of up to top_n indices
      ordered from most to least probable.
    """
    preds = np.asarray(preds).astype("float64")
    if interactive:
      # Temperature scaling is monotonic, so the ranking can be taken directly.
      # Always return a sequence, also for temperature 0 / None.
      order = (-preds).argsort()
      return order[order != 0][:top_n]
    if temperature is None or temperature == 0.0:
      masked = preds.copy()
      masked[0] = -np.inf
      return np.argmax(masked)
    preds = np.log(preds + K.epsilon()) / temperature
    exp_preds = np.exp(preds)
    # Remove the placeholder before normalizing so it can never be drawn
    # (falling back to "second best" could itself land on index 0).
    exp_preds[0] = 0.0
    preds = exp_preds / np.sum(exp_preds)
    probas = np.random.multinomial(1, preds, 1)
    return np.argmax(probas)

In [23]:
# Rebuild the generator from the saved epoch-14 weights, vocabulary and config.
# NOTE(review): these paths are relative to the current working directory and
# start with "NLG_2021_05_21/", while an earlier cell changes into a folder of
# that name — confirm which directory this cell is meant to run from.
textgen = TextGenerator(weights_path = "NLG_2021_05_21/TextGenerator_weights_epoch_14.hdf5", vocab_path = "NLG_2021_05_21/TextGenerator_vocab.json", config_path = "NLG_2021_05_21/TextGenerator_config.json")
# mod = textgen.build_textgen_model(dropout = 0.1)
# mod.load_weights("NLG_2021_05_21\TextGenerator_weights_epoch_14.hdf5")

# pred = Predictor(mod)
# predictor.generate(n = 3, prefix = "规划", temperature = [0.65])

In [27]:
# Print three generated texts seeded with the prefix "规划" at temperature 0.5.
textgen.generate(n = 3, prefix = "规划", temperature = [0.5])

 33%|███▎      | 1/3 [00:03<00:06,  3.48s/it]

规划的配套设施不错 有一定的升值的空间和发展的潜力 周围的配套设施也比较的好 周围的环境也比较好 



 67%|██████▋   | 2/3 [00:05<00:03,  3.00s/it]

规划的项目 现在规划的配套很给力 出门就是高速路 有些远



100%|██████████| 3/3 [00:07<00:00,  2.52s/it]

规划的配套很完善 有沃尔玛 天虹 沃尔玛 华润万家 天虹 沃尔玛等 






In [None]:
# Wrap the loaded model in a Predictor and generate one text interactively:
# at every step the user picks the next word from the console prompt.
predictor = Predictor(textgen.model)
predictor.generate(n = 1, prefix = "规划", temperature = [0.8], interactive = True)

Controls:
	s: stop.	x: backspace.	o: write your own.

Options:
	1: 的
	2: 很
	3: 不错

Progress: 规划

Your choice?
> 3
Controls:
	s: stop.	x: backspace.	o: write your own.

Options:
	1: ，
	2: 的
	3: 了

Progress: 规划不错

Your choice?
> 1
Controls:
	s: stop.	x: backspace.	o: write your own.

Options:
	1: 周边
	2: 配套
	3: 户型

Progress: 规划不错，

Your choice?
> 3
Controls:
	s: stop.	x: backspace.	o: write your own.

Options:
	1: 也
	2: 方正
	3: 不错

Progress: 规划不错，户型

Your choice?
> 2
Controls:
	s: stop.	x: backspace.	o: write your own.

Options:
	1: 实用
	2: ，
	3: 合理

Progress: 规划不错，户型方正

Your choice?
> 2
Controls:
	s: stop.	x: backspace.	o: write your own.

Options:
	1: 使用率
	2: 性价比
	3: 南北

Progress: 规划不错，户型方正，

Your choice?
> 3
Controls:
	s: stop.	x: backspace.	o: write your own.

Options:
	1: 通透
	2: 通
	3: 合理

Progress: 规划不错，户型方正，南北

Your choice?
> 1
Controls:
	s: stop.	x: backspace.	o: write your own.

Options:
	1: ，
	2: 。
	3: <s>

Progress: 规划不错，户型方正，南北通透

Your choice?
> 3
规划不错，户型方正，南北通透

