In [0]:
import tensorflow as tf
#https://github.com/BenStringer3/SeqGan/blob/31638d223de44a6ad1dfb884512c2fb0703a445c/pre_train_disc.py

In [0]:
class Discrim(object):
  """LSTM discriminator that scores token sequences as real or generated.

  The wrapped Keras model embeds integer token ids, runs them through an LSTM
  that returns its output at every timestep, and maps each timestep to a single
  logit with a Dense(1) layer. The model therefore produces one logit per
  timestep (output shape [batch, time, 1]), not one per sequence.

  Attributes:
    model: the `tf.keras.Model` built from `inputs` to the per-timestep logits.
    optim: Adam optimizer (learning rate 0.01) used by `train_step`.
    cross_entropy: binary cross-entropy on logits used by `loss`.
  """

  def __init__(self, vocab_size, embedding_dim, latent_dim, batch_size, seq_length):
    """Creates the layers, optimizer and loss, then builds the Keras model.

    Args:
      vocab_size: number of distinct token ids (size of the embedding table).
      embedding_dim: width of the token embedding.
      latent_dim: number of LSTM units.
      batch_size: stored on the instance; not used by the model itself.
      seq_length: stored on the instance; the model accepts any sequence length.
    """
    super(Discrim, self).__init__()
    self.vocab_size = vocab_size
    self.latent_dim = latent_dim
    self.seq_length = seq_length
    self.batch_size = batch_size
    self.embedding_dim = embedding_dim
    self.inputs = tf.keras.layers.Input(shape=(None,), dtype="int32")
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.lstm = tf.keras.layers.LSTM(latent_dim, return_sequences = True, return_state = True, recurrent_initializer="glorot_uniform")
    # Fixed: the module is `tf.keras.layers` (plural); `tf.keras.layer` raised
    # AttributeError. The layer is created but build() does not wire it in.
    self.dropout = tf.keras.layers.Dropout(0.2)
    self.fc1 = tf.keras.layers.Dense(1)

    self.optim = tf.optimizers.Adam(learning_rate=0.01)
    self.cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    self.model = self.build()

  def build(self):
    """Wires embedding -> LSTM -> Dense(1) into a Keras model and returns it."""
    embed = self.embedding(self.inputs)
    # return_state=True makes the LSTM return (sequence_output, h, c); only the
    # per-timestep sequence output feeds the classifier head.
    x, h, c = self.lstm(embed)
    # Fixed: the head is stored as `fc1`; `self.fc` does not exist.
    pred = self.fc1(x)
    model = tf.keras.Model(self.inputs, pred)
    return model

  def loss(self, real, fake):
    """Returns the discriminator loss for logits on real and generated data.

    Args:
      real: logits the model produced for real sequences (target label 1).
      fake: logits the model produced for generated sequences (target label 0).

    Returns:
      Sum of the two binary cross-entropy terms.
    """
    real_loss = self.cross_entropy(tf.ones_like(real), real)
    fake_loss = self.cross_entropy(tf.zeros_like(fake), fake)
    total = real_loss + fake_loss
    return total

  def train_step(self, fake, real):
    """Runs one optimizer update on a batch of generated and real sequences.

    Args:
      fake: integer token ids of generated sequences.
      real: integer token ids of real sequences.

    Returns:
      The loss computed before the update.
    """
    with tf.GradientTape() as disc_tape:
      fake_pred = self.model(fake)
      real_pred = self.model(real)
      loss = self.loss(real_pred, fake_pred)
    variables = self.model.trainable_variables
    grad = disc_tape.gradient(loss, variables)
    self.optim.apply_gradients(zip(grad, variables))
    return loss


In [0]:
class Gen(object):
  """LSTM token generator with sampling and reward-weighted training helpers.

  The wrapped Keras model embeds integer token ids, runs them through an LSTM
  that returns its output at every timestep, and maps each timestep to
  `vocab_size` logits with a Dense layer (output shape [batch, time, vocab]).

  Convention used by `generate` and `gen_predictions`: position 0 of every
  sequence is the fixed start token, and positions 1.. are produced by the
  model one token at a time.

  NOTE(review): the LSTM is not created with `stateful=True` and `build`
  discards the returned (h, c) states, so each single-token call to
  `self.model` in the loops below appears to start from the default initial
  state. That would mean every step is conditioned only on the previous token.
  Confirm whether carrying the recurrent state across steps was intended.
  """

  def __init__(self, vocab_size, embedding_dim, latent_dim, batch_size, seq_len, start_token):
    """Creates the layers and optimizer, then builds the Keras model.

    Args:
      vocab_size: number of distinct token ids; also the size of the output layer.
      embedding_dim: width of the token embedding.
      latent_dim: number of LSTM units.
      batch_size: number of sequences handled per call; the sampling loops
        reshape their tensors with this fixed size.
      seq_len: default sequence length, stored as `self.seq_length`.
      start_token: integer token id placed at position 0 of every sequence.
    """
    super(Gen, self).__init__()
    self.vocab_size = vocab_size
    self.latent_dim = latent_dim
    # Fixed: the constructor argument is `seq_len`; `seq_length` was undefined.
    self.seq_length = seq_len
    self.batch_size = batch_size
    self.embedding_dim = embedding_dim
    # One start token per batch row, shape [batch_size].
    self.start_token = tf.constant([start_token] * batch_size, dtype=tf.int32)
    self.optim = tf.optimizers.Adam(learning_rate=0.01)
    self.inputs = tf.keras.layers.Input(shape=(None,), dtype="int32")
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.lstm = tf.keras.layers.LSTM(latent_dim, return_sequences = True, return_state = True, recurrent_initializer="glorot_uniform")
    # Fixed: the module is `tf.keras.layers` (plural). The layer is created but
    # build() does not wire it in.
    self.dropout = tf.keras.layers.Dropout(0.2)
    self.fc1 = tf.keras.layers.Dense(vocab_size)

    self.model = self.build()

  def build(self):
    """Wires embedding -> LSTM -> Dense(vocab_size) into a Keras model."""
    embed = self.embedding(self.inputs)
    # return_state=True makes the LSTM return (sequence_output, h, c).
    x, h, c = self.lstm(embed)
    # Fixed: the output layer is stored as `fc1`; `self.fc` does not exist.
    pred = self.fc1(x)
    model = tf.keras.Model(self.inputs, pred)
    return model

  def generate(self, seq_len = None):
    """Samples a batch of token sequences from the model.

    Args:
      seq_len: number of tokens per sequence, including the start token at
        position 0. Defaults to `self.seq_length`.

    Returns:
      int32 tensor of shape [batch_size, seq_len]; also stored in `self.gen_x`.
    """
    if seq_len is None:
      seq_len = self.seq_length
    # Fixed: the array was created as `sequence` but used everywhere as `gen_x`.
    gen_x = tf.TensorArray(dtype=tf.int32, size=seq_len, dynamic_size = False, infer_shape=True)

    def _g_recurr(i, x_t0, gen_x):
      # Feed the previous token as a length-1 sequence: [batch, 1].
      x_t0 = tf.reshape(x_t0, [self.batch_size, 1])
      o_t = self.model(x_t0)  # [batch, 1, vocab]
      # tf.random.categorical needs 2-D [batch, num_classes] log-probabilities,
      # so drop the length-1 time axis before sampling.
      logits = tf.reshape(o_t, [self.batch_size, self.vocab_size])
      log_prob = tf.nn.log_softmax(logits)
      x_t1 = tf.cast(tf.reshape(tf.random.categorical(log_prob, 1), [self.batch_size]), tf.int32)
      gen_x = gen_x.write(i, x_t1)
      return i + 1, x_t1, gen_x

    # Position 0 is always the start token; sampling fills positions 1..seq_len-1.
    gen_x = gen_x.write(0, self.start_token)

    # Fixed: the body was defined as `g_recurr` but referenced as `_g_recurr`.
    _, _, gen_x = tf.while_loop(
            cond=lambda i, _1, _2: i < seq_len,
            body=_g_recurr,
            loop_vars=(tf.constant(1, dtype=tf.int32), self.start_token, gen_x))

    # stack() is [seq_len, batch_size]; transpose to [batch_size, seq_len].
    self.gen_x = tf.transpose(gen_x.stack(), perm=[1, 0])
    # Kept from the original; presumably only matters for stateful layers.
    self.model.reset_states()
    return self.gen_x

  def gen_predictions(self, x, training=False): # x in token form [batch_size, seq_length]
    """Returns the model's per-position token distributions for given sequences.

    Slot i (for i >= 1) holds the softmax distribution the model outputs when
    fed token i-1, where token 0 is taken to be the start token regardless of
    x[:, 0]. Slot 0 is a one-hot distribution on the start token, matching
    `generate`, which always emits the start token there.

    NOTE(review): this assumes sequences passed in begin with the start token;
    confirm against callers.

    Args:
      x: token ids of shape [batch_size, self.seq_length].
      training: forwarded to the Keras model call.

    Returns:
      float32 tensor of shape [batch_size, seq_length, vocab_size]; also stored
      in `self.g_predictions`.
    """
    # Fixed: the attribute is `seq_length`; `sequence_length` does not exist.
    g_predictions = tf.TensorArray(
        dtype=tf.float32, size=self.seq_length,
        dynamic_size=False, infer_shape=True)

    x_transposed = tf.cast(tf.transpose(x), dtype=tf.int32)  # [seq_length, batch]
    ta_x = tf.TensorArray(dtype=tf.int32, size=self.seq_length)
    ta_x = ta_x.unstack(x_transposed)

    def _pretrain_recurrence(i, x_t, g_predictions):
      x_t = tf.reshape(x_t, [self.batch_size, 1])
      o_t = self.model(x_t, training=training)  # [batch, 1, vocab]
      # Drop the length-1 time axis so every slot is [batch, vocab]; otherwise
      # the stacked result is rank 4 and the final transpose fails.
      logits = tf.reshape(o_t, [self.batch_size, self.vocab_size])
      g_predictions = g_predictions.write(i, tf.nn.softmax(logits))
      x_tp1 = ta_x.read(i)
      return i + 1, x_tp1, g_predictions

    # Fixed: slot 0 was never written because the loop starts at 1. Fill it with
    # probability 1 on the start token.
    g_predictions = g_predictions.write(
        0, tf.one_hot(self.start_token, self.vocab_size, dtype=tf.float32))

    # Fixed: the original wrote the start token into `ta_x` and discarded the
    # returned array. Passing the start token directly as the first input gives
    # the intended behaviour without relying on overwrite semantics.
    _, _, g_predictions = tf.while_loop(
        cond=lambda i, _1, _2: i < self.seq_length,
        body=_pretrain_recurrence,
        loop_vars=(tf.constant(1, dtype=tf.int32),
                   self.start_token,
                   g_predictions))

    # stack() is [seq_length, batch, vocab]; transpose to [batch, seq_length, vocab].
    self.g_predictions = tf.transpose(g_predictions.stack(), perm=[1, 0, 2])
    # Kept from the original; presumably only matters for stateful layers.
    self.model.reset_states()
    return self.g_predictions

  def train_step(self, samples, rewards):
    """Runs one reward-weighted update on a batch of sampled sequences.

    Args:
      samples: token ids of shape [batch_size, seq_length].
      rewards: per-token rewards, see `get_loss`.

    Returns:
      The loss computed before the update.
    """
    with tf.GradientTape() as tape:
      loss = self.get_loss(samples, rewards)

    variables = self.model.trainable_variables
    # Clip the global gradient norm to 5.0 before applying.
    g_grad, _ = tf.clip_by_global_norm(tape.gradient(loss, variables), 5.0)
    # Fixed: the optimizer is stored as `optim`; `self.optimizer` does not exist.
    self.optim.apply_gradients(zip(g_grad, variables))

    return loss

  def get_pretrain_loss(self, labels, samples): # labels as tokens, samples as prob distr
    """Returns the unreduced sparse categorical cross-entropy.

    Args:
      labels: integer token ids.
      samples: probability distributions over the vocabulary (not logits),
        e.g. the output of `gen_predictions`.
    """
    loss = tf.keras.losses.sparse_categorical_crossentropy(labels, samples, from_logits=False)
    return loss

  def get_loss(self, x, rewards):
    """Returns the reward-weighted negative log-likelihood of sequences `x`.

    For every position, the log-probability the model assigns to the token
    actually present in `x` is multiplied by that position's reward; the result
    is summed over all positions and negated.

    Args:
      x: token ids of shape [batch_size, seq_length].
      rewards: tensor that reshapes to one value per token of `x`
        (batch_size * seq_length elements) — presumably from a discriminator.

    Returns:
      Scalar loss tensor.
    """
    g_predictions = self.gen_predictions(x)
    # Flatten to one row per token so the one-hot mask picks each token's probability.
    token_mask = tf.one_hot(tf.cast(tf.reshape(x, [-1]), tf.int32), self.vocab_size, 1.0, 0.0)
    # Clip away exact zeros before the log to avoid -inf.
    log_probs = tf.math.log(
        tf.clip_by_value(tf.reshape(g_predictions, [-1, self.vocab_size]), 1e-20, 1.0))
    token_log_prob = tf.reduce_sum(token_mask * log_probs, 1)
    loss = -tf.reduce_sum(token_log_prob * tf.reshape(rewards, [-1]))
    return loss