## Installations

In [None]:
# Quiet (-q) install of the TF Model Garden package.
# NOTE(review): the imports below also rely on tensorflow_text, tensorflow_hub,
# tensorflow_probability and gensim — confirm they are available in the runtime.
!pip install -q tf-models-official

## Imports

In [None]:
import os
import re
import numpy as np
from tqdm import tqdm

# Tensorflow
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.callbacks import Callback
from tensorflow.keras import layers

# BERT
import tensorflow_hub as hub
import tensorflow_text as text

# Custom activation function
import tensorflow_probability as tfp
from keras.engine.base_layer import Layer
from keras import backend as K

# Embeddings
import gensim  
import gensim.downloader as gloader

# Data & Pre-processing
import nltk
nltk.download('wordnet')
from nltk.stem import WordNetLemmatizer
import pandas as pd


# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
BUFFER_SIZE = 10_000   # tf.data shuffle buffer
EMBED_DIM = 100        # token embedding size (also selects the GloVe file below)
LATENT_DIM = 512       # hidden size of the feed-forward projections
NUM_HEADS = 8          # attention heads of the decoder block
BATCH_SIZE = 512

# Gumbel-softmax temperature schedule
TAU = 5                # initial temperature
ALPHA = 0.99           # decay base used when annealing the temperature
FREQ_CHANGE = 5        # anneal every FREQ_CHANGE epochs

# Datasets
pre_train_dataset = "/gdrive/MyDrive/NLP/UKP_ASPECT/UKP_ASPECT.tsv"
train_dataset = "/gdrive/MyDrive/NLP/arg_quality_rank_30k.csv"

# Pickled GloVe embedding model matching EMBED_DIM
embed_model = f"/gdrive/MyDrive/NLP/glove_{EMBED_DIM}_pickle"

# Weights of the fine-tuned BERT quality classifier
model_weights = "/gdrive/MyDrive/NLP/classifierIBM30k.h5"

In [None]:
# Mount Google Drive at /gdrive so the dataset and weight paths resolve.
from google.colab import drive
drive.mount(mountpoint='/gdrive')

## Data

### Methods

In [None]:
lemmatizer = WordNetLemmatizer()


def preprocess_pretrain(sentence):
    """Normalise a pre-training sentence.

    Lower-cases the text, isolates the punctuation marks ? . ! , as separate
    tokens, wraps the result in "[start]" / "[end]" markers and lemmatises
    every whitespace-separated token (the markers included).
    """
    # Surround punctuation with spaces so it tokenises separately.
    spaced = re.sub("([?.!,])", r" \1 ", sentence.lower()).strip()
    tokens = ["[start]", *spaced.split(), "[end]"]
    return " ".join(lemmatizer.lemmatize(token) for token in tokens)


def preprocess_train(text):
    """Normalise an argument from the training CSV.

    Steps: replace the characters " - \\ ` and newlines with spaces, drop
    leading dots, isolate ? . ! , as separate tokens, replace & with "and",
    collapse repeated spaces, merge runs of dots (". . .") into one ".", trim,
    and lemmatise every token.

    :param text: raw argument string
    :return: normalised, space-separated, lemmatised string
    """
    text = re.sub(r'["\\`-]', ' ', text)       # replace the chars " \ ` - with a space
    text = text.replace('\n', ' ')
    text = re.sub(r'^\.+', '', text)           # delete dots at the beginning of the sentence
    text = re.sub(r'([?.!,])', r' \1 ', text)  # isolate punctuation as separate tokens
    text = text.replace('&', ' and ')          # replace & with and
    text = re.sub(r' +', ' ', text)            # delete additional whitespace
    # Merge ". . ." into a single ".". This must run after the whitespace
    # collapse: punctuation padding leaves at least two spaces between
    # consecutive dots, so running it earlier (as before) never matched.
    text = re.sub(r'\.(?: \.)+', '.', text)
    text = text.strip()
    return " ".join(lemmatizer.lemmatize(x) for x in text.split())


def load_embedding_model(model_type: str,
                         embedding_dimension: int = 50) -> "gensim.models.keyedvectors.KeyedVectors":
    """
    Loads a pre-trained word embedding model via the gensim downloader.

    :param model_type: name of the word embedding model to load: 'word2vec',
        'glove' or 'fasttext' (case-insensitive, surrounding whitespace ignored).
    :param embedding_dimension: size of the embedding space to consider; only
        used for GloVe (the word2vec and fastText models requested here are the
        300-dimensional ones).

    :return
        - pre-trained word embedding model (gensim KeyedVectors object)

    :raises AttributeError: if ``model_type`` is not a supported name.
    :raises ValueError: re-raised from ``gloader.load`` when the resolved model
        name is invalid (e.g. an unavailable GloVe dimension).
    """
    normalized_type = model_type.strip().lower()

    # Find the correct gensim-data model name
    if normalized_type == 'word2vec':
        download_path = "word2vec-google-news-300"
    elif normalized_type == 'glove':
        download_path = "glove-wiki-gigaword-{}".format(embedding_dimension)
    elif normalized_type == 'fasttext':
        download_path = "fasttext-wiki-news-subwords-300"
    else:
        raise AttributeError("Unsupported embedding model type! Available ones: word2vec, glove, fasttext")

    # Check download
    try:
        emb_model = gloader.load(download_path)
    except ValueError:
        print("Invalid embedding model name! Check the embedding dimension:")
        print("Word2Vec: 300")
        print("Glove: 50, 100, 200, 300")
        print("FastText: 300")
        raise  # bare raise keeps the original traceback

    return emb_model


def check_OOV_terms(embedding_vocabulary, word_listing):
    """
    Return the dataset words that the pre-trained embedding vocabulary lacks.

    :param embedding_vocabulary: words known to the embedding model (any
        iterable; callers pass a set)
    :param word_listing: dataset specific vocabulary (list)

    :return
        - list of out-of-vocabulary terms, without duplicates and in no
          particular order
    """
    unseen_words = set(word_listing)
    unseen_words.difference_update(embedding_vocabulary)
    return list(unseen_words)


def build_embedding_matrix(embedding_model,
                           embedding_dimension,
                           word_to_idx,
                           vocab_size,
                           oov_terms):
    """
    Build a dataset-specific embedding matrix from a pre-trained embedding model.

    Row ``idx`` receives the pre-trained vector of the word mapped to ``idx``.
    Words the model cannot look up (KeyError/TypeError) get a vector drawn
    uniformly from [-0.05, 0.05). Rows not referenced by ``word_to_idx`` stay
    zero.

    :param embedding_model: pre-trained word embedding model (gensim wrapper), indexable by word
    :param embedding_dimension: length of each embedding vector
    :param word_to_idx: vocabulary map (word -> index) (dict)
    :param vocab_size: number of rows of the matrix
    :param oov_terms: list of OOV terms (list); not used by the computation

    :return
        - float32 matrix of shape (vocab_size, embedding_dimension)
    """
    matrix = np.zeros((vocab_size, embedding_dimension), dtype=np.float32)

    for token, row in tqdm(word_to_idx.items()):
        try:
            vector = embedding_model[token]
        except (KeyError, TypeError):
            # Unknown word: small random initialisation.
            vector = np.random.uniform(low=-0.05, high=0.05, size=embedding_dimension)
        matrix[row] = vector

    return matrix


def update_embedding_matrix(embedding_model,
                            embedding_dimension,
                            word_to_idx,
                            vocab_size,
                            oov_terms,
                            old_word_to_idx=None):
    """
    Builds the embedding matrix of a specific dataset given a previously built embedding matrix.

    Rows of the previous matrix are reused. Words without a usable source row
    get a vector drawn uniformly from [-0.05, 0.05).

    :param embedding_model: previously built embedding matrix (indexable by row id)
    :param embedding_dimension: length of each embedding vector
    :param word_to_idx: new vocabulary map (word -> index) (dict)
    :param vocab_size: number of rows of the new matrix
    :param oov_terms: list of OOV terms (list); not used by the computation
    :param old_word_to_idx: vocabulary map the previous matrix was built with.
        Keras' ``Tokenizer.fit_on_texts`` re-ranks ``word_index`` by frequency
        on every call, so a word's id can change between the two matrices.
        When this map is given, rows are carried over by *word*. When None
        (default, previous behaviour), rows are matched by index. That is only
        correct if no id shifted.

    :return
        - float32 matrix of shape (vocab_size, embedding_dimension)
    """
    embedding_matrix = np.zeros((vocab_size, embedding_dimension), dtype=np.float32)

    for word, idx in tqdm(word_to_idx.items()):
        source_idx = idx if old_word_to_idx is None else old_word_to_idx.get(word)
        embedding_vector = None
        # Indexing a numpy array with None adds an axis instead of failing,
        # so a word missing from the old vocabulary must be handled explicitly.
        if source_idx is not None:
            try:
                embedding_vector = embedding_model[source_idx]
            except (TypeError, IndexError):
                embedding_vector = None  # no usable row: fall through to random init
        if embedding_vector is None:
            embedding_vector = np.random.uniform(low=-0.05, high=0.05, size=embedding_dimension)

        embedding_matrix[idx] = embedding_vector

    return embedding_matrix


class KerasTokenizer(object):
    """
    A simple high-level wrapper for the Keras tokenizer.

    Besides fitting ``tf.keras.preprocessing.text.Tokenizer`` it can build an
    embedding matrix aligned with the tokenizer ids. Row ``i`` holds the vector
    of the word whose id is ``i``. Row 0 (an id Keras never assigns to a word)
    stays zero and matches the padding id used by ``convert_text``.
    """

    def __init__(self, build_embedding_matrix=False, embedding_dimension=None,
                 embedding_model_type=None, tokenizer_args=None, embedding_model=None):
        """
        :param build_embedding_matrix: whether build_vocab/update_vocab also
            (re)build ``self.embedding_matrix``
        :param embedding_dimension: embedding size (required int when building a matrix)
        :param embedding_model_type: 'word2vec', 'glove' or 'fasttext'; used to
            download a model when ``embedding_model`` is None
        :param tokenizer_args: keyword arguments for the Keras Tokenizer (dict)
        :param embedding_model: optional pre-loaded gensim KeyedVectors
        """
        if build_embedding_matrix:
            assert embedding_model_type is not None
            assert embedding_dimension is not None and type(embedding_dimension) == int

        self.build_embedding_matrix = build_embedding_matrix
        self.embedding_dimension = embedding_dimension
        self.embedding_model_type = embedding_model_type
        self.embedding_model = embedding_model
        self.embedding_matrix = None
        self.vocab = None

        tokenizer_args = {} if tokenizer_args is None else tokenizer_args
        # collections.OrderedDict is a dict subclass, so one isinstance check
        # covers it. The former second check referenced the never-imported
        # `collections` module and raised NameError for non-dict arguments.
        assert isinstance(tokenizer_args, dict)

        self.tokenizer_args = tokenizer_args

    def _embedding_vocabulary(self):
        """Return the set of words known to the pre-trained embedding model.

        gensim >= 4 exposes ``key_to_index``; gensim 3.x exposes ``vocab``
        (removed in 4.0).
        """
        key_to_index = getattr(self.embedding_model, "key_to_index", None)
        if key_to_index is None:
            key_to_index = self.embedding_model.vocab
        return set(key_to_index.keys())

    def build_vocab(self, data, **kwargs):
        """Fit a fresh Keras tokenizer on ``data`` and optionally build the embedding matrix.

        :param data: iterable of texts
        """
        print('Fitting tokenizer...')
        self.tokenizer = tf.keras.preprocessing.text.Tokenizer(**self.tokenizer_args)
        self.tokenizer.fit_on_texts(data)
        print('Fit completed!')

        self.vocab = self.tokenizer.word_index

        if self.build_embedding_matrix:
            if self.embedding_model is None:
                print('Loading embedding model! It may take a while...')
                self.embedding_model = load_embedding_model(model_type=self.embedding_model_type,
                                                            embedding_dimension=self.embedding_dimension)

            print('Checking OOV terms in train...')
            self.oov_terms_train = check_OOV_terms(embedding_vocabulary=self._embedding_vocabulary(),
                                                   word_listing=list(self.vocab.keys()))

            print("Total OOV terms: {0} ({1:.2f}%)".format(len(self.oov_terms_train), 100*float(len(self.oov_terms_train)) / len(self.vocab)))

            print('Building the embedding matrix for train...')
            self.embedding_matrix = build_embedding_matrix(embedding_model=self.embedding_model,
                                                           word_to_idx=self.vocab,
                                                           vocab_size=len(self.vocab)+1,
                                                           embedding_dimension=self.embedding_dimension,
                                                           oov_terms=self.oov_terms_train)
            print('Done for train!')

    def update_vocab(self, data, **kwargs):
        """Fit the existing tokenizer on more texts and extend the embedding matrix.

        ``Tokenizer.fit_on_texts`` rebuilds ``word_index`` from the accumulated
        word counts, so the id of an already-known word can change after an
        update. Rows of the previous matrix are therefore carried over by word,
        not by index.

        :param data: iterable of texts; build_vocab must have been called first
        """
        old_vocab = self.vocab
        self.tokenizer.fit_on_texts(data)
        # Refresh unconditionally so len(self.vocab) / get_info() stay correct
        # even when no embedding matrix is built.
        self.vocab = self.tokenizer.word_index
        if not self.build_embedding_matrix:
            return

        print('Checking OOV terms...')
        self.oov_terms = check_OOV_terms(embedding_vocabulary=set(old_vocab.keys()),
                                         word_listing=list(self.vocab.keys()))

        print("Total OOV terms: {0} ({1:.2f}%)".format(len(self.oov_terms), 100*float(len(self.oov_terms)) / len(self.vocab)))

        print('Building the embedding matrix...')
        self.embedding_matrix = self._carry_over_embedding_matrix(old_vocab)

    def _carry_over_embedding_matrix(self, old_vocab):
        """Build a matrix for ``self.vocab``, reusing previous rows by word.

        Words absent from ``old_vocab`` follow the same policy as
        build_embedding_matrix: the pre-trained vector when the embedding model
        knows the word, otherwise a uniform random vector in [-0.05, 0.05).
        """
        old_matrix = self.embedding_matrix
        new_matrix = np.zeros((len(self.vocab) + 1, self.embedding_dimension), dtype=np.float32)

        for word, idx in tqdm(self.vocab.items()):
            old_idx = old_vocab.get(word)
            if old_matrix is not None and old_idx is not None and old_idx < len(old_matrix):
                vector = old_matrix[old_idx]
            else:
                try:
                    vector = self.embedding_model[word]
                except (KeyError, TypeError):
                    vector = np.random.uniform(low=-0.05, high=0.05, size=self.embedding_dimension)
            new_matrix[idx] = vector

        return new_matrix

    def get_info(self):
        """Return a summary dict (matrix shape instead of the matrix itself)."""
        return {
            'build_embedding_matrix': self.build_embedding_matrix,
            'embedding_dimension': self.embedding_dimension,
            'embedding_model_type': self.embedding_model_type,
            'embedding_matrix': self.embedding_matrix.shape if self.embedding_matrix is not None else self.embedding_matrix,
            'embedding_model': self.embedding_model,
            'vocab_size': len(self.vocab) + 1,
        }

    def tokenize(self, text):
        """Identity: the Keras tokenizer splits the texts itself."""
        return text

    def convert_tokens_to_ids(self, tokens):
        """Map one text (str) to a list of ids, or an iterable of texts to a list of id lists."""
        if isinstance(tokens, str):
            return self.tokenizer.texts_to_sequences([tokens])[0]
        else:
            return self.tokenizer.texts_to_sequences(tokens)

    def convert_ids_to_tokens(self, ids):
        """Map a list of id sequences back to a list of space-joined texts."""
        return self.tokenizer.sequences_to_texts(ids)


def convert_text(df, tokenizer, is_training=False, max_seq_length=None):
    """
    Converts input texts into a fixed-width array of token ids.

    :param df: texts to convert (list, numpy array or pandas Series of
        strings), passed straight to ``tokenizer.convert_tokens_to_ids``
    :param tokenizer: an instantiated tokenizer exposing ``convert_tokens_to_ids``
    :param is_training: if True, the padding length is computed from these
        texts as the 95th percentile of their token counts (truncated to int)
    :param max_seq_length: padding length to use when ``is_training`` is
        False (required in that case)

    :return
        - if is_training: (text_ids, max_seq_length)
        - otherwise: text_ids
        where text_ids is a numpy array of token ids, right-padded with 0 and
        truncated to max_seq_length.
    """
    text_ids = tokenizer.convert_tokens_to_ids(df)

    if not is_training:
        assert max_seq_length is not None
    else:
        max_seq_length = int(np.quantile([len(seq) for seq in text_ids], 0.95))

    # Append enough zeros to reach the width, then cut every sequence to it.
    zeros = [0] * max_seq_length
    padded = np.array([(seq + zeros)[:max_seq_length] for seq in text_ids])

    return (padded, max_seq_length) if is_training else padded


def decode_sentence(input_sentence, preprocess):
    """Run the generator on one sentence and decode its greedy (argmax) output.

    Relies on the notebook globals ``tokenizer``, ``max_seq_length`` and ``fnet``.

    :param input_sentence: raw text to feed to the model
    :param preprocess: normalisation applied before tokenisation
        (e.g. preprocess_pretrain or preprocess_train)
    :return: list holding the single decoded string, as returned by
        ``tokenizer.convert_ids_to_tokens``
    """
    token_ids = tokenizer.convert_tokens_to_ids([preprocess(input_sentence)])[0]
    # Truncate before padding. tf.pad rejects negative paddings, so sentences
    # longer than max_seq_length used to raise instead of being cut. The
    # explicit dtype keeps an empty sentence from becoming a float tensor.
    token_ids = tf.constant(token_ids[:max_seq_length], dtype=tf.int32)
    encoder_input = tf.pad(token_ids, [[0, max_seq_length - tf.shape(token_ids)[0]]])

    predictions = fnet.predict(
        {"encoder_inputs": tf.expand_dims(encoder_input, 0)}
    )

    # Greedy decoding: most probable id at every position.
    sampled_token_index = tf.argmax(predictions[0, :, :], axis=1)
    return tokenizer.convert_ids_to_tokens([sampled_token_index.numpy()])

### Loading 

Pre-training data

In [None]:
# UKP sentence pairs, keeping only rows whose label is neither "NS" nor "DTORCD".
pre_df = (
    pd.read_csv(pre_train_dataset, sep='\t')
    .loc[lambda frame: ~frame['label'].isin(["NS", "DTORCD"])]
    .reset_index(drop=True)
    .drop(columns=["topic", "label"])
)
pre_df.head()

Training data

In [None]:
# Argument-quality data without the WA-based score columns.
df = pd.read_csv(train_dataset).drop(columns=["WA", "stance_WA", "stance_WA_conf"])
df.head()

### Data preprocessing

Pre-training

In [None]:
# Normalised pre-training sentences, each wrapped in [start] ... [end].
sentence_1, sentence_2 = (
    pre_df[column].apply(preprocess_pretrain).tolist()
    for column in ('sentence_1', 'sentence_2')
)
sentence_1[:5]

Training

In [None]:
df['argument'] = df['argument'].apply(preprocess_train)
# Row 2 is replaced with a hand-corrected argument. Run it through the same
# normalisation as every other row: the raw text keeps punctuation glued to
# words (e.g. "white,"), which the filter-less tokenizer would treat as
# distinct tokens.
df.loc[2, "argument"] = preprocess_train(
    "zero tolerance policy in schools should not be adopted as circumstances are often not black and white, being more nuanced. no one should be written off due to a mistake of judgement."
)
df.head()

### Train, test, val splits

In [None]:
# Boolean masks selecting each split through the dataset's "set" column.
split_labels = df['set']
is_training_data = split_labels == 'train'
is_validation_data = split_labels == 'dev'
is_test_data = split_labels == 'test'

training_data, validation_data, test_data = (
    df[mask] for mask in (is_training_data, is_validation_data, is_test_data)
)

x_train, x_val, x_test = (
    split['argument'] for split in (training_data, validation_data, test_data)
)

### Tokenization

In [None]:
# Load the GloVe vectors, preferring a pickled copy on Drive over a download.
# NOTE(review): pickle.load can execute arbitrary code stored in the file;
# only point embed_model at a pickle you created yourself.
import pickle

if not os.path.exists(embed_model):
    embedding_model = load_embedding_model(model_type="glove",
                                           embedding_dimension=EMBED_DIM)
else:
    with open(embed_model, "rb") as pickled_model:
        embedding_model = pickle.load(pickled_model)

In [None]:
# Keras Tokenizer settings: unknown words map to the "OOV_TOKEN" id, text is
# lower-cased, and the default character filtering is disabled so punctuation
# and the [start]/[end] markers survive as tokens.
tokenizer_args = dict(
    oov_token="OOV_TOKEN",
    lower=True,
    filters='',
)

tokenizer = KerasTokenizer(
    tokenizer_args=tokenizer_args,
    build_embedding_matrix=True,
    embedding_dimension=EMBED_DIM,
    embedding_model_type="glove",
    embedding_model=embedding_model,
)

# Fit on the first pre-training column, then extend with the other corpora.
tokenizer.build_vocab(sentence_1)
for corpus in (sentence_2, x_train, x_val):
    tokenizer.update_vocab(corpus)
VOCAB_SIZE = len(tokenizer.vocab)

tokenizer_info = tokenizer.get_info()
print('Tokenizer info: ', tokenizer_info)

### Tokenizing and padding sentences with `KerasTokenizer` and `convert_text`

Training

In [None]:
# The padding length comes from the training arguments and is reused everywhere.
x_train, max_seq_length = convert_text(x_train, tokenizer, is_training=True)
x_val = convert_text(x_val, tokenizer, max_seq_length=max_seq_length)

print(f"Max token sequence: {max_seq_length}")
for split_label, split_ids in (('X train shape: ', x_train), ('X val shape: ', x_val)):
    print(split_label, split_ids.shape)

Pre-training

In [None]:
# Pre-training sentences reuse the padding length of the training arguments.
x_pretrain, x_preval = (
    convert_text(sentences, tokenizer, max_seq_length=max_seq_length)
    for sentences in (sentence_1, sentence_2)
)
print('X pre-train shape: ', x_pretrain.shape)
print('X pre-val shape: ', x_preval.shape)

### Tensorflow Dataset for Pre-training

In [None]:
def vectorize_text(inputs, outputs):
    """Pair encoder inputs with targets shifted one position to the left.

    A padding id (0) is appended and the first token dropped, so target
    position t holds token t + 1 and the target length equals the input's.
    """
    shifted_targets = tf.pad(outputs, [[0, 1]])[1:]
    return {"encoder_inputs": inputs}, {"outputs": shifted_targets}


# NOTE(review): this rebinds train_dataset, which previously held the training
# CSV path; re-running the CSV loading cell after this one will fail.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_pretrain, x_pretrain))
    .map(vectorize_text, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)
val_dataset = (
    tf.data.Dataset.from_tensor_slices((x_preval, x_preval))
    .map(vectorize_text, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

## Gumbel-softmax 

In [None]:
def gumbel_func(x, tau, dist):
  """Gumbel-softmax relaxation: ``softmax((x + g) / tau)`` with ``g ~ dist``.

  :param x: logits; noise with the same shape is sampled from ``dist``
  :param tau: temperature (float, tensor or variable); lower values push the
      output towards one-hot
  :param dist: distribution supplying the Gumbel noise
  """
  gumbel_noise = dist.sample(K.shape(x))
  # Gumbel-max trick: the noise is *added* to the logits (it was subtracted).
  return K.softmax((x + gumbel_noise) / tau)


class gumbel_softmax(Layer):
  """Gumbel-softmax activation layer.

  If ``tau`` is a ``tf.Variable`` the reference is kept, so in-place updates
  (e.g. ``tau.assign(...)`` from an annealing callback) reach the layer.
  ``alpha`` is stored but not used by the layer itself.
  """

  def __init__(self, tau=0.3, alpha=0.999, **kwargs):
    super(gumbel_softmax, self).__init__(**kwargs)
    # Standard Gumbel(loc=0, scale=1). The previous scale of 0 sampled a
    # constant 0, which removed the noise entirely.
    self.gumbel_dist = tfp.distributions.Gumbel(0.0, 1.0, name='Gumbel')
    # cast_to_floatx converts a variable into a tensor holding its current
    # value, which disconnected the layer from later temperature updates.
    self.tau = tau if isinstance(tau, tf.Variable) else K.cast_to_floatx(tau)
    self.alpha = K.cast_to_floatx(alpha)

  def call(self, x):
    return gumbel_func(x, self.tau, self.gumbel_dist)

  def get_config(self):
    # Required because __init__ takes extra arguments; Keras needs them to
    # serialise the layer (e.g. when the model is saved).
    config = super(gumbel_softmax, self).get_config()
    config.update({
        "tau": float(K.get_value(self.tau)),
        "alpha": float(K.get_value(self.alpha)),
    })
    return config

In [None]:
class activation_callback(Callback):
  """Anneals the Gumbel-softmax temperature at the end of epochs.

  At the end of every epoch whose (0-based) index is a multiple of
  FREQ_CHANGE, ``tau`` is multiplied in place by ``alpha ** epoch``. On other
  epochs the exponent is 0, so tau is unchanged; epoch 0 never changes it.

  NOTE(review): the update compounds. After epochs 5, 10, 15, ... tau has
  been multiplied by alpha**(5 + 10 + 15 + ...), not by alpha**epoch.
  Confirm this schedule is intended; a non-compounding alternative is
  tau = TAU * alpha**epoch.

  :param tau: temperature variable (must support ``assign``)
  :param alpha: decay base (scalar tensor or float)
  """

  def __init__(self, tau, alpha):
    # Initialise the base Callback state (e.g. self.model) first.
    super(activation_callback, self).__init__()
    self.alpha = alpha
    self.tau = tau

  def on_epoch_end(self, epoch, logs=None):
    exponent = epoch if epoch % FREQ_CHANGE == 0 else 0
    self.tau.assign(self.tau * K.pow(self.alpha, exponent))

## Seq2Seq Model

### Creating the FNet Encoder

In [None]:
class FNetEncoder(layers.Layer):
    """FNet encoder block.

    Fourier token mixing (real part of a 2-D FFT) followed by a two-layer
    feed-forward projection. Each sub-step has a residual connection and
    layer normalisation.

    :param embed_dim: feature size of the inputs and outputs
    :param dense_dim: hidden size of the feed-forward projection
    """

    def __init__(self, embed_dim, dense_dim, **kwargs):
        super(FNetEncoder, self).__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(dense_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()

    def call(self, inputs):
        # assumes inputs are (batch, seq, embed_dim) — fft2d transforms the two
        # innermost axes. Cast to complex64 as the FFT requires.
        inp_complex = tf.cast(inputs, tf.complex64)
        # Project to the frequency domain and keep only the real part.
        fft = tf.math.real(tf.signal.fft2d(inp_complex))
        proj_input = self.layernorm_1(inputs + fft)
        proj_output = self.dense_proj(proj_input)
        return self.layernorm_2(proj_input + proj_output)

    def get_config(self):
        # Needed for serialisation: __init__ takes arguments Keras cannot infer.
        config = super(FNetEncoder, self).get_config()
        config.update({"embed_dim": self.embed_dim, "dense_dim": self.dense_dim})
        return config

### Creating the one-step decoder

In [None]:
class PositionalEmbedding(layers.Layer):
    """Sum of learned token embeddings and learned position embeddings.

    Also produces a mask that is False wherever the input id is 0 (padding).

    :param sequence_length: number of positions with an embedding (inputs
        longer than this cannot be embedded)
    :param vocab_size: number of token ids
    :param embed_dim: embedding size
    """

    def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
        super(PositionalEmbedding, self).__init__(**kwargs)
        self.token_embeddings = layers.Embedding(
            input_dim=vocab_size, output_dim=embed_dim
        )
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=embed_dim
        )
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

    def call(self, inputs):
        length = tf.shape(inputs)[-1]
        positions = tf.range(start=0, limit=length, delta=1)
        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions

    def compute_mask(self, inputs, mask=None):
        return tf.math.not_equal(inputs, 0)

    def get_config(self):
        # Needed for serialisation: __init__ takes arguments Keras cannot infer.
        config = super(PositionalEmbedding, self).get_config()
        config.update({
            "sequence_length": self.sequence_length,
            "vocab_size": self.vocab_size,
            "embed_dim": self.embed_dim,
        })
        return config


class FNetDecoderOneStep(layers.Layer):
    """Single decoder block applied directly to the encoder outputs.

    Runs causally-masked multi-head self-attention over ``encoder_outputs``,
    then a two-layer feed-forward projection. Each step has a residual
    connection and layer normalisation. There is no separate target input:
    every position is decoded in one pass.

    :param embed_dim: feature size of inputs/outputs (also the attention key_dim)
    :param latent_dim: hidden size of the feed-forward projection
    :param num_heads: number of attention heads
    """

    def __init__(self, embed_dim, latent_dim, num_heads, **kwargs):
        super(FNetDecoderOneStep, self).__init__(**kwargs)
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.num_heads = num_heads
        # NOTE(review): attention_1 and layernorm_1 are never used in call();
        # they are kept so the layer's tracked sub-layers stay unchanged.
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.attention_2 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(latent_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.layernorm_3 = layers.LayerNormalization()
        self.supports_masking = True

    def call(self, encoder_outputs, mask=None):
        causal_mask = self.get_causal_attention_mask(encoder_outputs)
        if mask is not None:
            # Combine the key padding mask with the causal mask (1 = attend).
            padding_mask = tf.cast(mask[:, tf.newaxis, :], dtype="int32")
            padding_mask = tf.minimum(padding_mask, causal_mask)
        else:
            padding_mask = causal_mask
        attention_output_2 = self.attention_2(
            query=encoder_outputs,
            value=encoder_outputs,
            key=encoder_outputs,
            attention_mask=padding_mask,
        )
        out_2 = self.layernorm_2(encoder_outputs + attention_output_2)

        proj_output = self.dense_proj(out_2)
        return self.layernorm_3(out_2 + proj_output)

    def get_causal_attention_mask(self, inputs):
        """Lower-triangular (batch, seq, seq) int32 mask: position i sees j <= i."""
        input_shape = tf.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = tf.range(sequence_length)[:, tf.newaxis]
        j = tf.range(sequence_length)
        mask = tf.cast(i >= j, dtype="int32")
        mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
        mult = tf.concat(
            [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],
            axis=0,
        )
        return tf.tile(mask, mult)

    def get_config(self):
        # Needed for serialisation: __init__ takes arguments Keras cannot infer.
        config = super(FNetDecoderOneStep, self).get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "latent_dim": self.latent_dim,
            "num_heads": self.num_heads,
        })
        return config


def create_model_one_step(max_length, alpha, tau):
    """Build the FNet seq2seq generator that decodes all positions in one pass.

    Reads the globals VOCAB_SIZE, EMBED_DIM, LATENT_DIM and NUM_HEADS.

    :param max_length: number of positional embeddings (the padded sequence length)
    :param alpha: passed to the gumbel_softmax output activation
    :param tau: temperature passed to the gumbel_softmax output activation
    :return: keras.Model named "fnet" mapping the "encoder_inputs" id batch to
        per-position distributions over VOCAB_SIZE + 1 ids
    """
    # Encoder: token + position embeddings followed by one FNet block.
    encoder_inputs = keras.Input(shape=(None,), dtype="int32", name="encoder_inputs")
    embedded = PositionalEmbedding(max_length, VOCAB_SIZE + 1, EMBED_DIM)(encoder_inputs)
    encoder_outputs = FNetEncoder(EMBED_DIM, LATENT_DIM)(embedded)
    # Standalone encoder model; built but not used further here.
    encoder = keras.Model(encoder_inputs, encoder_outputs)

    # Decoder sub-model over encoder states. It is named "outputs", presumably
    # so the dataset's {"outputs": ...} targets match this model output.
    decoder_state_inputs = keras.Input(shape=(None, EMBED_DIM), name="decoder_state_inputs")
    hidden = FNetDecoderOneStep(EMBED_DIM, LATENT_DIM, NUM_HEADS)(decoder_state_inputs)
    hidden = layers.Dropout(0.2)(hidden)
    vocab_probs = layers.Dense(
        VOCAB_SIZE + 1, activation=gumbel_softmax(tau=tau, alpha=alpha)
    )(hidden)
    decoder = keras.Model(decoder_state_inputs, vocab_probs, name="outputs")

    return keras.Model(encoder_inputs, decoder(encoder_outputs), name="fnet")

### Creating and Pre-Training the seq2seq model

In [None]:
# The same alpha/tau objects go to the model's Gumbel activation and to the
# annealing callback used during fit.
alpha = K.constant(ALPHA)
tau = K.variable(TAU)
fnet = create_model_one_step(max_seq_length, alpha, tau)
fnet.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

In [None]:
annealing = activation_callback(tau, alpha)
history = fnet.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=90,
    callbacks=[annealing],
)

### Inference

In [None]:
sentence = (
    "marriage isn't keeping up with the times. "
    "abandon the old thinking and bring something that incorporates all unions "
    "not just those with a man and woman."
)
out = decode_sentence(sentence, preprocess=preprocess_pretrain)
print(out)

### Save Model

In [None]:
# Drive is mounted at /gdrive (see the drive.mount cell); the former relative
# path "drive/MyDrive/..." resolved against the working directory instead.
fnet.save("/gdrive/MyDrive/Colab Notebooks/NLP/pretrained_fnet.h5")

## [Bert](https://colab.research.google.com/github/tensorflow/text/blob/master/docs/tutorials/classify_text_with_bert.ipynb)

### Model to fine-tune

In [None]:
bert_model_name = 'bert_en_uncased_L-12_H-768_A-12'

# TF Hub handles: the encoder for bert_model_name and its matching preprocessor.
tfhub_handle_encoder = f'https://tfhub.dev/tensorflow/{bert_model_name}/3'
tfhub_handle_preprocess = 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'

print(f'Model name                    : {bert_model_name}')
print(f'BERT model selected           : {tfhub_handle_encoder}')
print(f'Preprocess model auto-selected: {tfhub_handle_preprocess}')

In [None]:
def build_classifier_model(dense_size=100):
  """BERT-based scorer: raw string -> sigmoid score in (0, 1).

  :param dense_size: width of the ReLU layer between BERT's pooled output and
      the final sigmoid unit
  """
  text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
  encoder_inputs = hub.KerasLayer(tfhub_handle_preprocess, name='preprocessing')(text_input)
  bert_outputs = hub.KerasLayer(tfhub_handle_encoder, trainable=True, name='BERT_encoder')(encoder_inputs)
  hidden = tf.keras.layers.Dense(dense_size, activation=keras.activations.relu, name='fc')(bert_outputs['pooled_output'])
  score = tf.keras.layers.Dense(1, activation=keras.activations.sigmoid, name='classifier')(hidden)
  return tf.keras.Model(text_input, score)

In [None]:
bert_model = build_classifier_model(dense_size=100)

In [None]:
# Print the layer-by-layer architecture of the BERT classifier.
bert_model.summary()

### Load best model

In [None]:
# Restore the fine-tuned classifier weights stored on Drive.
bert_model.load_weights(filepath=model_weights)

## Train

### Creating the RL wrapper (generator + evaluator)

In [None]:
class GenEval(keras.Model):
    """Fine-tunes the generator using a text evaluator as the reward.

    Each training step decodes the generator's greedy (argmax) output back to
    text, scores it with ``evaluator`` and updates only the generator.

    NOTE(review): the previous step minimised the evaluator score directly.
    argmax and the string lookup are not differentiable, so the tape produced
    no gradients for the generator and apply_gradients failed. The minimised
    score also pushed quality down. This version uses a REINFORCE-style
    surrogate instead: the evaluator score is a constant reward, centred on
    the batch mean, weighting the log-probability of the chosen tokens.
    Confirm this is the intended objective.

    :param evaluator: model mapping a batch of strings to one score per string
        (assumed higher = better)
    :param generator: model mapping id batches to per-position probability
        distributions over the vocabulary
    """

    def __init__(self, evaluator, generator):
        super(GenEval, self).__init__()
        self.evaluator = evaluator
        self.generator = generator
        index_word = tokenizer.tokenizer.index_word
        # id -> word table. Id 0 (padding) maps to "" so padded positions do
        # not become "OOV_TOKEN" words in the decoded text.
        self.table = tf.lookup.StaticHashTable(
            initializer=tf.lookup.KeyValueTensorInitializer(
                keys=tf.constant([0] + list(index_word.keys()), dtype=tf.int32),
                values=tf.constant([""] + list(index_word.values())),
            ),
            default_value=tf.constant("OOV_TOKEN"),
            name="index_word"
        )
        self.gen_loss_tracker = keras.metrics.Mean(name="generator_loss")
        self.reward_tracker = keras.metrics.Mean(name="reward")

    @property
    def metrics(self):
        return [self.gen_loss_tracker, self.reward_tracker]

    @tf.function
    def join(self, l):
        """Join a (batch, seq) string tensor into one space-separated string per row.

        Previously the words were concatenated with no separator, and tf.split
        needed a static sequence length.
        """
        return tf.strings.strip(tf.strings.reduce_join(l, axis=1, separator=" "))

    def compile(self, g_optimizer, loss_fn):
        super(GenEval, self).compile()
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, data):
        with tf.GradientTape() as tape:
            predictions = self.generator(data)
            sampled_token_index = tf.argmax(predictions, axis=2, output_type="int32")
            decoded_sentences = self.join(self.table.lookup(sampled_token_index))
            # The evaluator only provides the reward; no gradient flows through it.
            reward = tf.stop_gradient(
                tf.reshape(self.evaluator(decoded_sentences, training=False), [-1])
            )
            advantage = reward - tf.reduce_mean(reward)
            # Probability of each greedily chosen token (argmax == max prob).
            chosen_prob = tf.reduce_max(predictions, axis=2)
            log_prob = tf.reduce_sum(tf.math.log(chosen_prob + 1e-8), axis=1)
            g_loss = -tf.reduce_mean(advantage * log_prob)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        # Monitor loss and mean evaluator score.
        self.gen_loss_tracker.update_state(g_loss)
        self.reward_tracker.update_state(reward)
        return {
            "g_loss": self.gen_loss_tracker.result(),
            "reward": self.reward_tracker.result(),
        }

### Training

In [None]:
gen_eval = GenEval(evaluator=bert_model, generator=fnet)
# NOTE(review): GenEval.train_step never calls loss_fn; it is stored only.
gen_eval.compile(
    g_optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)

gen_eval.fit(x_train, batch_size=16, epochs=20)

### Save Model

In [None]:
# Drive is mounted at /gdrive (see the drive.mount cell); the former relative
# path "drive/MyDrive/..." resolved against the working directory instead.
fnet.save("/gdrive/MyDrive/Colab Notebooks/NLP/final_fnet.h5")

## Inference

In [None]:
sentence = (
    "marriage isn't keeping up with the times. "
    "abandon the old thinking and bring something that incorporates all unions "
    "not just those with a man and woman."
)
out = decode_sentence(sentence, preprocess=preprocess_train)
print(out)