In [33]:
!pip install -q tensorflow_text
!pip install -q sentencepiece
import tensorflow as tf
import tensorflow_text as tf_txt
import tqdm.notebook as note
import sentencepiece as sp
import io
import datetime

In [34]:
# HYPERPARAMS
VOCAB_SIZE = 2000      # SentencePiece vocabulary size; also the number of output logits
WINDOW_SIZE = 33       # tokens per training window: first 32 are the input, the last one is the target
SHUFFLE_SIZE = 1000    # tf.data shuffle buffer size
BATCH_SIZE = 50
D_EMBEDDINGS = 64      # embedding / model width
NUM_HEADS = 2          # attention heads in MultiHeadAttention
# Per-head size of queries/keys: split the model width evenly across the heads,
# as in the standard transformer (d_model / num_heads).
KEY_DIM = D_EMBEDDINGS // NUM_HEADS

In [35]:
# Download Nietzsche's "Beyond Good and Evil" (cached locally after the first download)
path = tf.keras.utils.get_file("nietzsche.txt", origin="https://s3.amazonaws.com/text-datasets/nietzsche.txt")
# Load the corpus into a single str; `with` guarantees the file handle is closed
with open(path, encoding="utf-8") as corpus_file:
    text = corpus_file.read()

# Train a unigram SentencePiece model on the corpus (writes tokenizer_model.model)
sp.SentencePieceTrainer.train(
    input=path, model_prefix='tokenizer_model', model_type="unigram", vocab_size=VOCAB_SIZE)
# Read the serialized model file back as bytes
with tf.io.gfile.GFile('tokenizer_model.model', "rb") as model_file:
    trained_tokenizer_model = model_file.read()
# Wrap the model in a tokenizer usable inside TensorFlow.
# nbest_size=0 disables subword sampling, so the same text always yields the same
# token ids (nbest_size=-1 would sample a random segmentation on every call).
tokenizer = tf_txt.SentencepieceTokenizer(
    model=trained_tokenizer_model, out_type=tf.int32, nbest_size=0, alpha=1.0, reverse=False,
    add_bos=False, add_eos=False, return_nbest=False, name=None
)

# Tokenize the whole corpus (a single str) into one sequence of token ids
tokens = tokenizer.tokenize(text)

In [None]:
# Slide a WINDOW_SIZE-token window over the corpus. In each window the first
# WINDOW_SIZE - 1 tokens are the model input and the final token is the target.
token_windows = tf_txt.sliding_window(data=tokens, width=WINDOW_SIZE)
token_ds = tf.data.Dataset.from_tensor_slices({"input": token_windows[:, :-1], "target": token_windows[:, -1]})
# prefetch lets the input pipeline prepare the next batch while the current one trains
token_ds = token_ds.shuffle(SHUFFLE_SIZE).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

In [37]:
class Embedding_Layer(tf.keras.layers.Layer):
    """Token embedding plus a learned absolute positional embedding.

    Args:
        max_len: maximum sequence length (number of distinct positions).
            Defaults to the model's input window length, WINDOW_SIZE - 1.

    Input: int token ids whose last axis is the sequence axis (length <= max_len).
    Output: the same leading shape with a trailing D_EMBEDDINGS axis.
    """

    def __init__(self, max_len=WINDOW_SIZE - 1):
        super().__init__()

        # One learned vector per *position*: sized by sequence length, not vocabulary.
        self.embedding_1 = tf.keras.layers.Embedding(input_dim=max_len, output_dim=D_EMBEDDINGS)
        # One learned vector per *token id*.
        self.embedding_2 = tf.keras.layers.Embedding(input_dim=VOCAB_SIZE, output_dim=D_EMBEDDINGS)

    def call(self, x):
        # Derive the length from the input so sequences shorter than max_len also work.
        seq_len = tf.shape(x)[-1]
        positions = self.embedding_1(tf.range(start=0, limit=seq_len))  # (seq_len, D_EMBEDDINGS)
        # Positional vectors broadcast over the batch dimension(s).
        return positions + self.embedding_2(x)

In [32]:
class TransformerBlock(tf.keras.layers.Layer):
    """Post-norm transformer encoder block.

    It runs self-attention and then a position-wise feed-forward network. Each
    sub-layer is followed by dropout, a residual connection and LayerNormalization.
    No attention mask is applied, so every position may attend to every other.

    Args:
        ffn_units: width of the hidden feed-forward layer.
        dropout_rate: dropout rate after attention and after the feed-forward net.

    The input's last axis must be D_EMBEDDINGS, because the feed-forward output
    (D_EMBEDDINGS units) is added back onto it as a residual.
    """

    def __init__(self, ffn_units=32, dropout_rate=0.1):
        super().__init__()

        self.mha = tf.keras.layers.MultiHeadAttention(num_heads=NUM_HEADS, key_dim=KEY_DIM)
        self.dense_1 = tf.keras.layers.Dense(units=ffn_units, activation="relu")
        self.dense_2 = tf.keras.layers.Dense(units=D_EMBEDDINGS)
        self.dropout_1 = tf.keras.layers.Dropout(rate=dropout_rate)
        self.dropout_2 = tf.keras.layers.Dropout(rate=dropout_rate)
        self.norm_1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.norm_2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

    def call(self, x, training=None):
        # Self-attention: MultiHeadAttention requires both query and value.
        att_out = self.mha(query=x, value=x, training=training)
        att_out = self.dropout_1(att_out, training=training)
        ln_out = self.norm_1(x + att_out)

        # Position-wise feed-forward network with its own dropout layer.
        ffn_out = self.dense_2(self.dense_1(ln_out))
        ffn_out = self.dropout_2(ffn_out, training=training)

        return self.norm_2(ln_out + ffn_out)

In [None]:
class Transformer(tf.keras.Model):
    """Predicts the token that follows a window of token ids.

    Pipeline: token + positional embedding -> TransformerBlock -> average over the
    sequence -> Dense layer producing VOCAB_SIZE logits.

    Args:
        tokenizer: object with `tokenize` / `detokenize` methods (presumably the
            SentencePiece tokenizer built earlier); only `generate_text` uses it.
    """

    def __init__(self, tokenizer):
        super().__init__()

        self.tokenizer = tokenizer
        # These must be *instances*: storing the classes would make e.g.
        # `self.block(x)` construct a new layer instead of running one.
        self.optimizer = tf.keras.optimizers.Adam()
        self.loss_func = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        self.pos_embedding = Embedding_Layer()
        self.block = TransformerBlock()
        self.pool = tf.keras.layers.GlobalAveragePooling1D()
        self.dense = tf.keras.layers.Dense(units=VOCAB_SIZE)

        # `tf.keras.Model.metrics` is a read-only property, so the list is stored under
        # another name and exposed through the `metrics` override below.
        # Targets are integer token ids (not one-hot), hence the Sparse* variants.
        self.metrics_list = [
            tf.keras.metrics.Mean(name="loss"),
            tf.keras.metrics.SparseCategoricalAccuracy(name="acc"),
            tf.keras.metrics.SparseTopKCategoricalAccuracy(k=3, name="top-3-acc"),
        ]

    @property
    def metrics(self):
        """Metrics updated by train_step; the loss tracker is first."""
        return self.metrics_list

    def call(self, x, training=False):
        """Map int token ids (batch, seq_len) to next-token logits (batch, VOCAB_SIZE)."""
        embedded = self.pos_embedding(x)
        encoded = self.block(embedded, training=training)
        pooled = self.pool(encoded)
        return self.dense(pooled)

    def reset_metrics(self):
        """Reset the running state of every tracked metric."""
        for metric in self.metrics:
            metric.reset_state()

    @tf.function
    def train_step(self, data):
        """Run one optimization step on a batch dict with "input" and "target" entries.

        Returns:
            dict mapping each metric name to its current value.
        """
        x, targets = data["input"], data["target"]

        with tf.GradientTape() as tape:
            predictions = self(x, training=True)
            # add any regularization losses registered by sublayers
            loss = self.loss_func(targets, predictions) + tf.reduce_sum(self.losses)

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        # metrics[0] tracks the running mean of the loss
        self.metrics[0].update_state(loss)
        # the remaining metrics compare integer targets against the logits
        for metric in self.metrics[1:]:
            metric.update_state(targets, predictions)

        return {m.name: m.result() for m in self.metrics}

    def generate_text(self, prompt, output_length=20, top_k=100):
        """Continue `prompt` by sampling `output_length` tokens, one at a time.

        Args:
            prompt: a single str to continue.
            output_length: number of tokens to append.
            top_k: at each step, sample only among the `top_k` highest-scoring tokens.

        Returns:
            Scalar string tensor holding the prompt followed by the generated text.
        """
        context_len = WINDOW_SIZE - 1  # length of the input windows the model sees
        k = min(top_k, VOCAB_SIZE)  # top_k cannot exceed the number of logits
        tokens = self.tokenizer.tokenize(prompt)

        for _ in range(output_length):
            # Keep only the most recent `context_len` tokens. Left-pad shorter contexts
            # with id 0 so the newest token always sits in the last slot.
            window = tokens[-context_len:]
            window = tf.pad(window, [[context_len - tf.shape(window)[0], 0]])
            logits = self(tf.expand_dims(window, axis=0), training=False)  # (1, VOCAB_SIZE)

            top_logits, top_ids = tf.math.top_k(logits, k=k, sorted=True)
            # categorical() returns an index into the top-k slice, not a token id
            choice = tf.random.categorical(top_logits, num_samples=1)[0, 0]
            next_id = tf.gather(top_ids[0], choice)
            tokens = tf.concat([tokens, [next_id]], axis=0)

        return self.tokenizer.detokenize(tokens)