In [1]:
import os
import sys
sys.path.extend(['..'])

from datetime import datetime 
from functools import partial
from pathlib import Path
import random
import time

import matplotlib.pyplot as plt
%matplotlib inline

import tensorflow as tf
import tensorflow_datasets as tfds

print('Physical Devices:\n', tf.config.list_physical_devices(), '\n')
%load_ext tensorboard

import transformers


stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
OUTPUTS_DIR = f'./outputs/{stamp}'
print('\nOutput Directory:', OUTPUTS_DIR)

Physical Devices:
 [PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')] 


Output Directory: ./outputs/20200320-201546


In [2]:
log_dir = f'{OUTPUTS_DIR}/logs'
summary_writer = tf.summary.create_file_writer(log_dir)
tf.summary.trace_on(graph=True) 
print(log_dir)

./outputs/20200320-201546/logs


In [3]:
path_to_file = tf.keras.utils.get_file('shakespeare.txt', 
                                       'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')

In [4]:
shakespeare_lines = tf.data.TextLineDataset(path_to_file)
shakespeare_tensor = tf.strings.join(list(iter(shakespeare_lines)),
                                     separator='\n')
shakespeare_str = shakespeare_tensor.numpy().decode()
print(shakespeare_str[:250])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.



In [5]:
MIN_SEQ_SIZE = 64
MAX_SEQ_SIZE = 128

In [6]:
def substr_generator():
    text_size = len(shakespeare_str)
    while True:
        index = random.randint(0, text_size - MAX_SEQ_SIZE)
        size = random.randint(MIN_SEQ_SIZE, MAX_SEQ_SIZE)
        yield tf.strings.substr(shakespeare_tensor, index, size)

random.seed(0)
substrs_ds = tf.data.Dataset.from_generator(substr_generator, tf.string)

In [7]:
for x in substrs_ds.take(5):
    print(x)

tf.Tensor(b'alters.\n\nPERDITA:\nOne of these is true:\nI think affliction may subdue the cheek,\nBut not take in the mind.\n\nCAMILLO:\n', shape=(), dtype=string)
tf.Tensor(b'w together: grant that, and tell me,\nIn peace what each of them by the other lose,\nThat they comb', shape=(), dtype=string)
tf.Tensor(b'hs up,\nAfter our great good cheer. Pray you, sit down;\nFor now we sit to chat as well as eat.\n\nPETRUCHIO:\nNothing but sit and ', shape=(), dtype=string)
tf.Tensor(b'on\nThat does affect it. Once more, fare you well.\n\nANGELO:\nThe heavens give safety to your purposes!\n\n', shape=(), dtype=string)
tf.Tensor(b'o do?\n\nPETRUCHIO:\nNot her that chides, sir, at any hand, I pray.\n\nTRANIO:\nI love no chiders, sir. Biondello, ', shape=(), dtype=string)


In [8]:
language = set(shakespeare_str)
language_size = len(language)
language_size

65

In [9]:
num_layers = 8
hidden_size = 256
num_heads = 8
max_positions=128

num_special_tokens = 2
pad_token = language_size + 1
mask_token = language_size + 2
# mask hyperparams from https://arxiv.org/pdf/1810.04805.pdf
prob_mask=0.15

In [10]:
model_config = transformers.BertConfig(
    vocab_size_or_config_json_file=language_size,
    type_vocab_size=num_special_tokens, # The [MASK] and [PAD] token
    hidden_size=hidden_size,
    intermediate_size=hidden_size*2,
    num_hidden_layers=num_layers,
    num_attention_heads=num_heads
)
model = transformers.TFBertForMaskedLM(model_config)
# model_config = transformers.OpenAIGPTConfig(
#     vocab_size_or_config_json_file=language_size,
#     type_vocab_size=num_special_tokens,
#     n_positions=max_positions,
#     n_ctx=max_positions,
#     n_embd=hidden_size,
#     n_layer=num_layers,
#     n_head=num_heads
# )
# model = transformers.TFOpenAIGPTLMHeadModel(model_config)

In [11]:
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_acc = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
train_loss = tf.keras.metrics.Mean(name='train_loss')

In [12]:
# Creating a mapping from unique characters to indices
char2idx = {u:i for i, u in enumerate(language)}
idx2char = dict(enumerate(language))


def encode(text_tensor):
    encoded_text = [
        char2idx[c] for c in text_tensor.numpy().decode()
    ]
    
    def _maybe_mask(x):
        if random.random() < prob_mask:
            return mask_token
        return x
    
    masked = [
        _maybe_mask(x) for x in encoded_text
    ]
    return masked, encoded_text


def tf_encode(text_tensor):
    return tf.py_function(encode, [text_tensor], [tf.int32, tf.int32])


dataset = substrs_ds.map(tf_encode)

In [13]:
inp, tar = next(iter(dataset))
inp = tf.expand_dims(inp, 0)
tar = tf.expand_dims(tar, 0)

In [14]:
inp, mask_token, tar

(<tf.Tensor: shape=(1, 128), dtype=int32, numpy=
 array([[67, 67, 67, 21, 62, 51, 49, 51, 67, 67, 67, 32, 53, 51, 45, 53,
          5, 58, 42, 58, 60, 67, 54,  5, 67,  6, 53, 46, 34, 53, 29,  4,
         38, 42, 67,  2, 46, 64, 67, 61, 67, 21, 17, 43, 12, 25, 53, 67,
         43,  7, 52, 17, 55, 23, 67, 21, 23,  5, 46, 34, 67, 46, 34, 67,
          6, 67, 67, 67, 67, 51,  6,  6, 58, 42, 10, 16, 16, 18, 38, 42,
         34, 58, 20, 53, 31, 46, 19, 67, 53, 64, 58, 67, 19, 58, 53, 51,
          2, 67, 46, 64, 58, 20, 21, 54, 58, 67, 67, 38, 34, 67, 67,  6,
         67, 64, 35, 53, 46,  3, 53, 67, 58, 48, 42, 58, 67, 67, 16, 16]])>,
 67,
 <tf.Tensor: shape=(1, 128), dtype=int32, numpy=
 array([[55, 23, 10, 21, 62, 51, 49, 51, 45, 20, 53, 32, 53, 51, 45, 53,
          5, 58, 42, 58, 60, 21, 54,  5, 51,  6, 53, 46, 34, 53, 29,  4,
         38, 42, 53,  2, 46, 64, 64, 61, 21, 21, 17, 43, 12, 25, 53, 30,
         43,  7, 52, 17, 55, 23, 10, 21, 23,  5, 46, 34, 53, 46, 34, 53,
          6,  5, 

In [15]:
token_type_ids = tf.cast(inp == mask_token, tf.int32)
token_type_ids

<tf.Tensor: shape=(1, 128), dtype=int32, numpy=
array([[1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
        0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0,
        0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1,
        1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
        0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1,
        1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0]])>

In [16]:
logits, *_ = model(inp, token_type_ids=token_type_ids)

In [17]:
tar[inp == mask_token]

<tf.Tensor: shape=(30,), dtype=int32, numpy=
array([55, 23, 10, 45, 20, 53, 21, 51, 53, 64, 21, 30, 10, 53, 53,  5, 58,
       53, 45, 58, 51,  5, 53, 45,  6, 53, 51, 34,  6, 10])>

In [18]:
loss_fn(tar[inp == mask_token], logits[inp == mask_token])

<tf.Tensor: shape=(), dtype=float32, numpy=4.305786>

In [19]:
model.summary()

Model: "tf_bert_for_masked_lm"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
bert (TFBertMainLayer)       multiple                  4431360   
_________________________________________________________________
mlm___cls (TFBertMLMHead)    multiple                  215105    
Total params: 4,497,729
Trainable params: 4,497,729
Non-trainable params: 0
_________________________________________________________________


In [20]:
BUFFER_SIZE = 20000
BATCH_SIZE = 8

train_dataset = dataset.cache()
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
padded_shape = [MAX_SEQ_SIZE]
train_dataset = train_dataset.padded_batch(
    BATCH_SIZE, 
    padded_shapes=(padded_shape, padded_shape), 
    padding_values=(pad_token, pad_token)
)

In [21]:
inp, tar = next(iter(train_dataset))

In [22]:
inp.shape, tar.shape

(TensorShape([8, 128]), TensorShape([8, 128]))

In [23]:
padding_mask = tf.cast(inp != pad_token, tf.int32)
token_type_ids = tf.cast(inp == mask_token, tf.int32)
logits, *_ = model(inp, 
                   attention_mask=padding_mask,
                   token_type_ids=token_type_ids)
loss_fn(tar[inp == mask_token], 
        logits[inp == mask_token])

<tf.Tensor: shape=(), dtype=float32, numpy=4.2396054>

In [24]:
def train_step(model: tf.keras.Model,
               optimizer: tf.keras.optimizers.Optimizer,
               pad_token: int,
               mask_token: int,
               inp: tf.Tensor,
               tar: tf.Tensor):
    
    with tf.GradientTape() as tape:
        attention_mask = tf.cast(inp != pad_token, tf.int32)
        token_type_ids = tf.cast(inp == mask_token, tf.int32)
        logits, *_ = model(inp, 
                           attention_mask=attention_mask,
                           token_type_ids=token_type_ids,
                           training=True)
        loss = loss_fn(tar[inp == mask_token], 
                       logits[inp == mask_token])

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_acc(tar[inp == mask_token], 
              logits[inp == mask_token])
    train_loss(loss)
    
    return logits, loss


train_step_signature = [
    tf.TensorSpec(shape=(BATCH_SIZE, MAX_SEQ_SIZE), dtype=tf.int32),
    tf.TensorSpec(shape=(BATCH_SIZE, MAX_SEQ_SIZE), dtype=tf.int32),
]


@tf.function(input_signature=train_step_signature)
def train_step_tf(inp, tar):
    train_step(model, optimizer, 
               pad_token, mask_token, 
               inp, tar)

In [25]:
EPOCHS = 20
BATCHES_IN_EPOCH = 250

checkpoint_path = f'{OUTPUTS_DIR}/ckpts'

ckpt = tf.train.Checkpoint(transformer=model,
                           optimizer=optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print ('Latest checkpoint restored!!')

In [26]:
try:
    for epoch in range(EPOCHS):
        start = time.time()

        train_loss.reset_states()
        train_acc.reset_states()

        for batch, (inp, tar) in enumerate(train_dataset):
            train_step_tf(inp, tar)

            if batch % 50 == 0:
                print ('Epoch {} Batch {} Loss {:.4f} Accuracy {:.4f}'.format(
                    epoch + 1, batch, train_loss.result(), train_acc.result()))

                with summary_writer.as_default():
                    tf.summary.scalar('loss', train_loss.result(), 
                                      step=epoch)
                    tf.summary.scalar('accuracy', train_acc.result(), 
                                      step=epoch)

            if batch >= BATCHES_IN_EPOCH:
                break

        if (epoch + 1) % 5 == 0:
            ckpt_save_path = ckpt_manager.save()
            print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                                 ckpt_save_path))

        print ('Epoch {} Loss {:.4f} Accuracy {:.4f}'.format(epoch + 1, 
                                                             train_loss.result(), 
                                                             train_acc.result()))

        print ('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))
except KeyboardInterrupt:
    print('Manual interrupt')



  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "




  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


Epoch 1 Batch 0 Loss 4.2353 Accuracy 0.0246
Epoch 1 Batch 50 Loss 3.7593 Accuracy 0.1298
Epoch 1 Batch 100 Loss 3.6136 Accuracy 0.1395
Epoch 1 Batch 150 Loss 3.5410 Accuracy 0.1420
Epoch 1 Batch 200 Loss 3.4961 Accuracy 0.1440
Epoch 1 Batch 250 Loss 3.4655 Accuracy 0.1456
Epoch 1 Loss 3.4655 Accuracy 0.1456
Time taken for 1 epoch: 97.40566444396973 secs

Epoch 2 Batch 0 Loss 3.6363 Accuracy 0.1048
Epoch 2 Batch 50 Loss 3.3376 Accuracy 0.1497
Epoch 2 Batch 100 Loss 3.3369 Accuracy 0.1546
Epoch 2 Batch 150 Loss 3.3315 Accuracy 0.1540
Epoch 2 Batch 200 Loss 3.3350 Accuracy 0.1524
Epoch 2 Batch 250 Loss 3.3316 Accuracy 0.1533
Epoch 2 Loss 3.3316 Accuracy 0.1533
Time taken for 1 epoch: 70.54202628135681 secs

Epoch 3 Batch 0 Loss 3.2630 Accuracy 0.1333
Epoch 3 Batch 50 Loss 3.3220 Accuracy 0.1502
Epoch 3 Batch 100 Loss 3.3066 Accuracy 0.1524
Epoch 3 Batch 150 Loss 3.3191 Accuracy 0.1536
Manual interrupt


In [35]:
def to_string(x: tf.Tensor) -> str:
    return ''.join(idx2char[i] for i in x.numpy())

def generate_text(
    model: tf.keras.Model,
    start: str,
    mask_token: int,
    temperature = 1.0,
    steps=20,
    print_process=True
):
    if print_process:
        print(start)
        print('------')
    
    inp = [char2idx[c] for c in start]
    
    for _ in range(steps):
        inp_with_mask_at_end = tf.concat([inp, [mask_token]], 0)
        inp_with_mask_at_end = tf.expand_dims(inp_with_mask_at_end, 0)
        token_type_ids = tf.cast(inp_with_mask_at_end == mask_token, tf.int32)
        outputs = model(inp_with_mask_at_end, token_type_ids=token_type_ids)
        
        next_chr_preds = outputs[0][0][-1]
        next_chr_preds = next_chr_preds / temperature
        next_chr_preds = tf.expand_dims(next_chr_preds, 0)
        predicted_id = tf.random.categorical(next_chr_preds, 
                                             num_samples=1,
                                             dtype=tf.int32)[-1, 0]

        inp = tf.concat([inp, tf.reshape(predicted_id, (1,))], 0)
        
        if print_process:
            print(to_string(inp))
            print('------')

In [36]:
generate_text(model, 'ROMEO:', mask_token, print_process=True)

ROMEO:
------
ROMEO:d
------
ROMEO:di
------
ROMEO:di 
------
ROMEO:di l
------
ROMEO:di lo
------
ROMEO:di lor
------
ROMEO:di lorv
------
ROMEO:di lorvs
------
ROMEO:di lorvs 
------
ROMEO:di lorvs G
------
ROMEO:di lorvs G 
------
ROMEO:di lorvs G s
------
ROMEO:di lorvs G se
------
ROMEO:di lorvs G seh
------
ROMEO:di lorvs G seho
------
ROMEO:di lorvs G seho 
------
ROMEO:di lorvs G seho u
------
ROMEO:di lorvs G seho uR
------
ROMEO:di lorvs G seho uRs
------
ROMEO:di lorvs G seho uRs

------


In [None]:
def plot_attention_weights(attention, sentence, layer):
    fig = plt.figure(figsize=(16, 2))

    attention = tf.squeeze(attention[layer], axis=0)

    for head in range(attention.shape[0]):
        ax = fig.add_subplot(2, 4, head+1)

        # plot the attention weights
        ax.matshow(attention[head], 
                   cmap='viridis')

        fontdict = {'fontsize': 10}

        ax.set_xticks(range(len(sentence)+2))

        ax.set_xticklabels(
            ['-']+[idx2char[i] for i in sentence]+['-'], 
            fontdict=fontdict)

        ax.set_yticklabels([])

        ax.set_xlabel('Head {}'.format(head+1))

    plt.tight_layout()
    plt.show()

In [None]:
# plot_attention_weights(attention_weights, x.numpy(), 'decoder_layer1_block2')