In [12]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.ops import lookup_ops
from tensorflow.python.training.tracking import tracking


from absl import app
from absl import flags

import tensorflow.compat.v2 as tf
import os
import tempfile


In [13]:
from tensorflow.keras.layers import Dense, Flatten, Embedding, LSTM,Input
from tensorflow.keras import Model


In [97]:
class LanguageModel(Model):
    def __init__(self,state_size,logit_units):
        super(LanguageModel, self).__init__()
        self._lstm_layer = tf.keras.layers.LSTM(state_size,return_sequences=True)
        self._logit_layer = tf.keras.layers.Dense(logit_units)
        
    def __call__(self,sentence_embeddings):
        lstm_output = self._lstm_layer(sentence_embeddings)
        lstm_output = tf.reshape(lstm_output, [-1,self._lstm_layer.units])
        logits = self._logit_layer(lstm_output)
        return logits

        
        

In [98]:
def write_vocabulary_file(vocabulary):
  """Write temporary vocab file for module construction."""
  tmpdir = tempfile.mkdtemp()
  vocabulary_file = os.path.join(tmpdir, "tokens.txt")
  with tf.io.gfile.GFile(vocabulary_file, "w") as f:
    for entry in vocabulary:
      f.write(entry + "\n")
  return vocabulary_file

In [99]:
class ULMFiTModule(tf.train.Checkpoint):
  """
  LATER 
  """

  def __init__(self, vocab, emb_dim, buckets, state_size):
    super(ULMFiTModule, self).__init__()
    self._buckets = buckets
    self._vocab_size = len(vocab)
    self.emb_row_size = self._vocab_size+self._buckets
    self._embeddings = tf.Variable(tf.random.uniform(shape=[self.emb_row_size, emb_dim]))
    print(self._embeddings.shape)
    self.model = LanguageModel(state_size,self.emb_row_size)
    self._vocabulary_file = tracking.TrackableAsset(write_vocabulary_file(vocab)) 
    self.w2i_table = lookup_ops.index_table_from_file(
                    vocabulary_file= self._vocabulary_file,
                    num_oov_buckets=self._buckets,
                    hasher_spec=lookup_ops.FastHashSpec)
    self.i2w_table = lookup_ops.index_to_string_table_from_file(
                    vocabulary_file=self._vocabulary_file, 
                    delimiter = '\n',
                    default_value="UNKNOWN")

    
  def _tokenize(self, sentences):
    # Perform a minimalistic text preprocessing by removing punctuation and
    # splitting on spaces.
    normalized_sentences = tf.strings.regex_replace(
        input=sentences, pattern=r"\pP", rewrite="")
    sparse_tokens = tf.strings.split(normalized_sentences, " ").to_sparse()

    # Deal with a corner case: there is one empty sentence.
    sparse_tokens, _ = tf.sparse.fill_empty_rows(sparse_tokens, tf.constant(""))
    # Deal with a corner case: all sentences are empty.
    sparse_tokens = tf.sparse.reset_shape(sparse_tokens)

    return (sparse_tokens.indices, sparse_tokens.values,
            sparse_tokens.dense_shape)
    
  def _indices_to_words(self, indices):
    #return tf.gather(self._vocab_tensor, indices)
    return self.i2w_table.lookup(indices)
    

  def _words_to_indices(self, words):
    #return tf.strings.to_hash_bucket(words, self._buckets)
    return self.w2i_table.lookup(words)
    

  @tf.function(input_signature=[tf.TensorSpec([None], tf.dtypes.string),tf.TensorSpec([None], tf.dtypes.string)])
  def train(self, sentences,validation_sentences=None):
    token_ids, token_values, token_dense_shape = self._tokenize(sentences)
    tokens_sparse = tf.sparse.SparseTensor(
        indices=token_ids, values=token_values, dense_shape=token_dense_shape)
    tokens = tf.sparse.to_dense(tokens_sparse, default_value="")

    sparse_lookup_ids = tf.sparse.SparseTensor(
        indices=tokens_sparse.indices,
        values=self._words_to_indices(tokens_sparse.values),
        dense_shape=tokens_sparse.dense_shape)
    lookup_ids = tf.sparse.to_dense(sparse_lookup_ids, default_value=0)
    
    # Targets are the next word for each word of the sentence.
    tokens_ids_seq = lookup_ids[:, 0:-1]
    tokens_ids_target = lookup_ids[:, 1:]
    tokens_prefix = tokens[:, 0:-1]

    # Mask determining which positions we care about for a loss: all positions
    # that have a valid non-terminal token.
    mask = tf.logical_and(
        tf.logical_not(tf.equal(tokens_prefix, "")),
        tf.logical_not(tf.equal(tokens_prefix, "<E>")))

    input_mask = tf.cast(mask, tf.int32)

    with tf.GradientTape() as t:
      sentence_embeddings = tf.nn.embedding_lookup(self._embeddings,
                                                   tokens_ids_seq)
    
      logits = self.model(sentence_embeddings)
      

      targets = tf.reshape(tokens_ids_target, [-1])
      weights = tf.cast(tf.reshape(input_mask, [-1]), tf.float32)

      losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=targets, logits=logits)

      # Final loss is the mean loss for all token losses.
      final_loss = tf.math.divide(
          tf.reduce_sum(tf.multiply(losses, weights)),
          tf.reduce_sum(weights),
          name="final_loss")

    watched = t.watched_variables()
    gradients = t.gradient(final_loss, watched)

    for w, g in zip(watched, gradients):
      w.assign_sub(g)

    return final_loss,logits

  @tf.function
  def decode_greedy(self, sequence_length, first_word):
    #initial_state = self._lstm_cell.get_initial_state(
    #    dtype=tf.float32, batch_size=1)

    sequence = [first_word]
    current_word = first_word
    current_id = tf.expand_dims(self._words_to_indices(current_word), 0)
    #current_state = initial_state

    for _ in range(sequence_length):
      token_embeddings = tf.nn.embedding_lookup(self._embeddings, current_id)
      
      logits = self.model(tf.expand_dims(token_embeddings,0))
      softmax = tf.nn.softmax(logits)

      next_ids = tf.math.argmax(softmax, axis=1)
      next_words = self._indices_to_words(next_ids)[0]
      
      current_id = next_ids
      current_word = next_words
      sequence.append(current_word)

    return sequence


In [100]:
sentences = ["<S> hello there <E>", "<S> how are you doing today <E>","<S> I am fine thank you <E>",
             "<S> hello world <E>", "<S> who are you? <E>"]
vocab = [
      "<S>", "<E>", "hello", "there", "how", "are", "you", "doing", "today","I","am","fine","thank","world",
    "who"
  ]

module = ULMFiTModule(vocab=vocab, emb_dim=10, buckets=1, state_size=128)

for _ in range(5):
    _,logits = module.train(tf.constant(sentences),tf.constant(sentences))
    print(_,logits.shape)


(16, 10)
(None, None, 128)


  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


(None, None, 128)


  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


tf.Tensor(2.7970374, shape=(), dtype=float32) (30, 16)
tf.Tensor(2.5767746, shape=(), dtype=float32) (30, 16)
tf.Tensor(2.4554627, shape=(), dtype=float32) (30, 16)
tf.Tensor(2.3978527, shape=(), dtype=float32) (30, 16)
tf.Tensor(2.3649178, shape=(), dtype=float32) (30, 16)


In [44]:
 # We have to call this function explicitly if we want it exported, because it
  # has no input_signature in the @tf.function decorator.
decoded = module.decode_greedy(sequence_length=10, first_word=tf.constant("<S> you"))
_ = [d.numpy() for d in decoded]
print(_)


[b'<S> you', b'are', b'you', b'there', b'<E>', b'fine', b'thank', b'you', b'there', b'<E>', b'fine']


In [64]:
tf.saved_model.save(module,"test")

W0816 16:23:56.693716  4508 saved_model.py:758] Skipping full serialization of Keras model <__main__.LanguageModel object at 0x000001F0185C14A8>, because its inputs are not defined.
W0816 16:23:56.695698  4508 saved_model.py:765] Skipping full serialization of Keras layer <tensorflow.python.keras.layers.recurrent_v2.LSTMCell object at 0x000001F0182CD438>, because it is not built.
W0816 16:23:56.697700  4508 saved_model.py:765] Skipping full serialization of Keras layer <tensorflow.python.keras.layers.recurrent.RNN object at 0x000001F0197AEC50>, because it is not built.


In [65]:
b = tf.saved_model.load("test")

In [66]:
d = b.decode_greedy(sequence_length=10,first_word=tf.constant("<S> Hello"))
_ = [d.numpy() for d in decoded]
print(_)

[b'<S> Hello', b'hello', b'there', b'I', b'am', b'am', b'am', b'am', b'am', b'am', b'am']


# Classifier Head 


Classifier head takes in the final layer output of the languaage model and first gets the average pool and max pool of the 
final layer outputs, then passes the concatanation of last time steps hidden state, max pool results and average pool results through given number Dense-dropout-batchnormalization blocks. Finally it produces the classifier output probabilities.

In [93]:
class LanguageClassifier(Model):
    def __init__(self,num_labels,dense_units=(128,128),dropouts=(0.1,0.1)):
        super(LanguageClassifier,self).__init__()
        self.dense_layers = [Dense(units,activation="relu") for units in dense_units]
        self.dropout_layers = [Dropout(p) for p in dropouts]
        self.max_pool_layer = GlobalMaxPooling1D()
        self.average_pool_layer = GlobalAveragePooling1D()
        self.batchnorm_layer = BatchNormalization()
        self.n_layers = len(self.dense_layers)
        self.final_layer = Dense(num_labels,activation="sigmoid")
        
    def __call__(self,encoder_output):
        self.enc_out = encoder_output
        last_h = self.enc_out[:,-1,:]
        max_pool_output = max_pool_layer(self.enc_out)
        average_pool_output = average_pool_layer(self.enc_out)
        
        output = concatenate([last_h,max_pool_output,average_pool_output])
        
        for i in range(self.n_layers):
            output = self.dense_layers[i](output)
            output = self.dropout_layers[i](output)
            output = self.batchnorm_layer(output)
        
        final_output = self.final_layer(output)
        return final_output        