In [None]:
# Use this if the packages are not installed yet
!pip install -q tensorflow_text
!pip install -q sentencepiece
!pip install -q --upgrade numpy

# Imports

In [55]:
import numpy
import tensorflow as tf
import tensorflow_text as tf_txt
import tqdm.notebook as note
import sentencepiece as sp
import io
import datetime

# Hyperparams

In [56]:
# HYPERPARAMS
VOCAB_SIZE = 6000
SEQ_LEN = 32
SHUFFLE_SIZE = 1000
BATCH_SIZE = 25
D_EMBEDDINGS = 64
NUM_HEADS = 3
TOP_K = 30
EPOCH_SIZE = 20

# Preprocessing

In [57]:
# Save filepath to downloaded "Beyond Good and Evil"
path = tf.keras.utils.get_file("nietzsche.txt", origin="https://s3.amazonaws.com/text-datasets/nietzsche.txt")
# Load txt into str
with open(path) as p:
    text = p.read()

# Train the tokenizer
sp.SentencePieceTrainer.train(
    input=path, model_prefix='tokenizer_model', model_type="unigram", vocab_size=VOCAB_SIZE)
# deserialize the trained model file to load it in the correct format
trained_tokenizer_model = tf.io.gfile.GFile('tokenizer_model.model', "rb").read()
# load the model as a tokenizer that can be used inside a tensorflow model
tokenizer = tf_txt.SentencepieceTokenizer(
    model=trained_tokenizer_model, out_type=tf.int32, nbest_size=-1, alpha=1, reverse=False,
    add_bos=False, add_eos=False, return_nbest=False, name=None
)

# Tokenize the str
tokens = tokenizer.tokenize(text)

In [58]:
token_windows = tf_txt.sliding_window(data=tokens, width=SEQ_LEN+1)
token_ds = tf.data.Dataset.from_tensor_slices(token_windows)
token_ds = token_ds.map(lambda x: (x[:SEQ_LEN], x[SEQ_LEN:SEQ_LEN+1]))
token_ds = token_ds.shuffle(SHUFFLE_SIZE).batch(BATCH_SIZE)

# Model

In [59]:
#@title Embedding Layer { display-mode: "form" }
class Embedding_Layer(tf.keras.layers.Layer):
    def __init__(self):
        super(Embedding_Layer, self).__init__()

        self.embedding_1 = tf.keras.layers.Embedding(input_dim=SEQ_LEN, output_dim=D_EMBEDDINGS)
        self.embedding_2 = tf.keras.layers.Embedding(input_dim=VOCAB_SIZE, output_dim=D_EMBEDDINGS)
    
    
    def call(self, x):
        zeros = self.embedding_1(tf.range(start=0, limit=SEQ_LEN))

        return zeros + self.embedding_2(x)

In [60]:
#@title Transformer Block
class TransformerBlock(tf.keras.layers.Layer):
    def __init__(self):
        super(TransformerBlock, self).__init__()

        self.mha = tf.keras.layers.MultiHeadAttention(num_heads=NUM_HEADS, key_dim=D_EMBEDDINGS)
        self.dense_1 = tf.keras.layers.Dense(units=32, activation="relu")
        self.dense_2 = tf.keras.layers.Dense(units=D_EMBEDDINGS)
        self.dropout_1 = tf.keras.layers.Dropout(rate=0.1)
        self.dropout_2 = tf.keras.layers.Dropout(rate=0.1)
        self.norm_1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.norm_2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)


    def call(self, x):

        att_out = self.mha(x, x)
        att_out = self.dropout_2(att_out)
        ln_out = self.norm_1(x + att_out)
        ffn_out = self.dense_1(ln_out)
        ffn_out = self.dense_2(ffn_out)
        ffn_out = self.dropout_2(ffn_out)

        return self.norm_2(ln_out + ffn_out)

In [61]:
#@title Transformer
class Transformer(tf.keras.Model):
    def __init__(self, tokenizer):
        super(Transformer, self).__init__()

        self.tokenizer = tokenizer
        self.optimizer = tf.keras.optimizers.Adam()
        self.loss_func = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        self.pos_embedding = Embedding_Layer()
        self.block = TransformerBlock()
        self.pool = tf.keras.layers.GlobalAvgPool1D()
        self.dense = tf.keras.layers.Dense(units=VOCAB_SIZE)

        self.metrics_list = [
                        tf.keras.metrics.Mean(name="loss"),
                        tf.keras.metrics.CategoricalAccuracy(name="acc")
                        ]

    def call(self, x):
        embedded = self.pos_embedding(x)
        embedded = self.block(embedded)
        embedded = self.pool(embedded)

        return self.dense(embedded)


    def reset_metrics(self):
        for metric in self.metrics_list:
            metric.reset_states()
            

    @tf.function
    def train_step(self, data):
        
        x, targets = data

        with tf.GradientTape() as tape:
            predictions = self(x)
            loss = self.loss_func(targets, predictions) + tf.reduce_sum(self.losses)
        
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        
        # update metrics
        self.metrics_list[0].update_state(loss) # loss
        self.metrics_list[1].update_state(targets, predictions) # acc

        # Return a dictionary mapping metric names to current value
        return {m.name: m.result() for m in self.metrics_list}


    def generate_text(self, prompt, sample_size=5):

        tokens = self.tokenizer.tokenize(prompt)
        prompt_len = tokens.shape[0]
        tokens = tf.expand_dims(tokens, axis=0)
        tokens = tf.pad(tokens, [[0,0], [SEQ_LEN - prompt_len,0]], "CONSTANT", constant_values=0)

        for _ in range(sample_size):

            logits, indices = tf.math.top_k(self(tokens), k=TOP_K, sorted=True)
            sample = tf.random.categorical(tf.cast(indices, tf.float32), 1, dtype=tf.int32)
            tokens = tf.concat((tokens, sample), axis=1)
            tokens = tf.slice(tokens, [0, 1], [1, SEQ_LEN])

        return self.tokenizer.detokenize(tokens[:,-prompt_len-sample_size:])

# Training

In [62]:
%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [63]:
model = Transformer(tokenizer)

# Define where to save the log
hyperparameter_string = "VOCAB_SIZE-2000__SEQ_LEN-32__SHUFFLE_SIZE-1000__BATCH_SIZE-50__D_EMBEDDINGS-64__NUM_HEADS-3__TOP_K-5"
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

log_path = f"logs/{hyperparameter_string}/{current_time}/train"
summary_writer = tf.summary.create_file_writer(log_path)

In [None]:
sample_size = 20

for epoch in range(EPOCH_SIZE):
    
    print(f"Epoch {epoch}:")
    
    # Training:
    
    for data in note.tqdm(token_ds, position=0, leave=True):
        metrics = model.train_step(data)
    
    # print the metrics
    print([f"{key}: {value}" for (key, value) in zip(list(metrics.keys()), list(metrics.values()))])
    # print generated text
    generated_text = model.generate_text("The eagerness and subtlety, I should even say craftiness", sample_size)
    print(f"\n Generated the last {sample_size} words of the following text: \n {generated_text[0]} \n")

    with summary_writer.as_default():
        # logging the metrics to the log file which is used by tensorboard
        for metric in model.metrics_list:
            tf.summary.scalar(f"{metric.name}", metric.result(), step=epoch)
        # logging generated text
        tf.summary.text(f"sample_size_{sample_size}", generated_text, step=epoch)
    
    # reset all metrics (requires a reset_metrics method in the model)
    model.reset_metrics()
    
    print("\n")

Epoch 0:


  0%|          | 0/5491 [00:00<?, ?it/s]

['loss: 6.337629318237305', 'acc: 0.00970399845391512']

 Generated the last 20 words of the following text: 
 b'The eagerness and subtlety, I should even say craftiness whiching thating as with not be; not which-:: alledededed with' 



Epoch 1:


  0%|          | 0/5491 [00:00<?, ?it/s]

['loss: 5.646011829376221', 'acc: 0.01889074221253395']

 Generated the last 20 words of the following text: 
 b'The eagerness and subtlety, I should even say craftinessing be in not;ed " alling; not be be it which which which be as-' 



Epoch 2:


  0%|          | 0/5491 [00:00<?, ?it/s]

In [None]:
%tensorboard --logdir logs/

# Save Model

In [None]:
# save the model with a meaningful name
model.save_weights(f"saved_model_{hyperparameter_string}", save_format="tf")