In [1]:
import os
import sys
sys.path.extend(['..'])

from datetime import datetime 
from functools import partial
from pathlib import Path
import random
import time
from typing import Tuple, List

import numpy as np
from scipy.stats import bernoulli
import matplotlib.pyplot as plt
%matplotlib inline

import tensorflow as tf
import tensorflow_datasets as tfds

print('Physical Devices:\n', tf.config.list_physical_devices(), '\n')
%load_ext tensorboard

import transformers


stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
# OUTPUTS_DIR = f'./outputs/{stamp}'
OUTPUTS_DIR = './outputs/20200324-000541'
print('\nOutput Directory:', OUTPUTS_DIR)

Physical Devices:
 [PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')] 


Output Directory: ./outputs/20200324-000541


In [2]:
log_dir = f'{OUTPUTS_DIR}/logs'
summary_writer = tf.summary.create_file_writer(log_dir)
tf.summary.trace_on(graph=True) 
print(log_dir)

./outputs/20200324-000541/logs


In [3]:
path_to_file = tf.keras.utils.get_file(
    'shakespeare.txt', 
    'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')

In [4]:
shakespeare_lines = tf.data.TextLineDataset(path_to_file)
shakespeare_tensor = tf.strings.join(list(iter(shakespeare_lines)),
                                     separator='\n')
shakespeare_str = shakespeare_tensor.numpy().decode()
print(shakespeare_str[:250])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.



In [5]:
MIN_SEQ_SIZE = 64
MAX_SEQ_SIZE = 128

In [6]:
def substr_generator():
    text_size = len(shakespeare_str)
    while True:
        index = random.randint(0, text_size - MAX_SEQ_SIZE)
        size = random.randint(MIN_SEQ_SIZE, MAX_SEQ_SIZE)
        yield tf.strings.substr(shakespeare_tensor, index, size)

random.seed(0)
substrs_ds = tf.data.Dataset.from_generator(substr_generator, tf.string)

In [7]:
for x in substrs_ds.take(5):
    print(x)

tf.Tensor(b'alters.\n\nPERDITA:\nOne of these is true:\nI think affliction may subdue the cheek,\nBut not take in the mind.\n\nCAMILLO:\n', shape=(), dtype=string)
tf.Tensor(b'w together: grant that, and tell me,\nIn peace what each of them by the other lose,\nThat they comb', shape=(), dtype=string)
tf.Tensor(b'hs up,\nAfter our great good cheer. Pray you, sit down;\nFor now we sit to chat as well as eat.\n\nPETRUCHIO:\nNothing but sit and ', shape=(), dtype=string)
tf.Tensor(b'on\nThat does affect it. Once more, fare you well.\n\nANGELO:\nThe heavens give safety to your purposes!\n\n', shape=(), dtype=string)
tf.Tensor(b'o do?\n\nPETRUCHIO:\nNot her that chides, sir, at any hand, I pray.\n\nTRANIO:\nI love no chiders, sir. Biondello, ', shape=(), dtype=string)


In [8]:
language = set(shakespeare_str)
language_size = len(language)
language_size

65

In [9]:
num_layers = 8
hidden_size = 256
num_heads = 8
max_positions=128

num_special_tokens = 2
pad_token = language_size + 1
mask_token = language_size + 2
# mask hyperparams from https://arxiv.org/pdf/1810.04805.pdf
prob_mask=0.15

In [10]:
def to_string(x: tf.Tensor) -> str:
    def _convert(i: int) -> str:
        if i == mask_token:
            return '[MASK]'
        if i == pad_token:
            return '[PAD]'
        return idx2char[i]
    return ''.join(_convert(i) for i in x.numpy())

In [11]:
tf.summary.trace_on(graph=True, profiler=True)
# model_config = transformers.BertConfig(
#     vocab_size_or_config_json_file=language_size,
#     type_vocab_size=num_special_tokens, 
#     hidden_size=hidden_size,
#     intermediate_size=hidden_size*2,
#     num_hidden_layers=num_layers,
#     num_attention_heads=num_heads
# )
# model = transformers.TFBertForMaskedLM(model_config)
model_config = transformers.OpenAIGPTConfig(
    vocab_size_or_config_json_file=language_size,
    type_vocab_size=num_special_tokens,
    n_positions=max_positions,
    n_ctx=max_positions,
    n_embd=hidden_size,
    n_layer=num_layers,
    n_head=num_heads
)
model = transformers.TFOpenAIGPTLMHeadModel(model_config)



In [12]:
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5, epsilon=1e-08, clipnorm=1.0)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_acc = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
train_loss = tf.keras.metrics.Mean(name='train_loss')

In [13]:
checkpoint_path = f'{OUTPUTS_DIR}/ckpts'

ckpt = tf.train.Checkpoint(transformer=model,
                           optimizer=optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print ('Latest checkpoint restored!!')

Latest checkpoint restored!!


In [14]:
# Creating a mapping from unique characters to indices
char2idx = {u:i for i, u in enumerate(language)}
idx2char = dict(enumerate(language))


def encode(text_tensor: tf.Tensor) -> List[List[int]]:
    text_str = text_tensor.numpy().decode()
    return [[char2idx[c] for c in text_str]]


def tf_encode(text_tensor: tf.Tensor) -> tf.Tensor:
    return tf.py_function(encode, [text_tensor], tf.int32)


dataset = substrs_ds.map(tf_encode)

In [15]:
next(iter(dataset))

<tf.Tensor: shape=(128,), dtype=int32, numpy=
array([ 0, 17, 47, 24, 30, 64, 60, 64, 56, 51, 16, 57, 16, 64, 56, 16, 23,
       13, 59, 13, 29, 24, 58, 23, 64, 15, 16, 14, 27, 16, 39, 63, 25, 59,
       16, 50, 14, 42, 42, 20, 24, 24, 48,  5, 18, 54, 16, 19,  5, 53, 41,
       48,  0, 17, 47, 24, 17, 23, 14, 27, 16, 14, 27, 16, 15, 23, 13, 16,
       56, 64, 15, 15, 13, 59, 47,  8,  8,  9, 25, 59, 27, 13, 51, 16,  2,
       14, 46, 13, 16, 42, 13, 64, 46, 13, 16, 64, 50, 23, 14, 42, 13, 51,
       24, 58, 13, 16, 56, 25, 27, 15, 16, 15, 64, 42, 35, 16, 14,  3, 16,
       27, 13, 11, 59, 13, 15, 47,  8,  8])>

In [16]:
def create_mask_and_input(tar: tf.Tensor) -> tf.Tensor:
    where_masked = tf.random.uniform(tar.shape) < prob_mask
    where_masked &= tar != pad_token
    mask_tokens = tf.multiply(mask_token, tf.cast(where_masked, tf.int32))
    not_masked = tf.multiply(tar, 1 - tf.cast(where_masked, tf.int32))
    inp = mask_tokens + not_masked
    return inp, where_masked

In [17]:
tar = next(iter(dataset))
inp, where_masked = create_mask_and_input(tar)

In [18]:
tar, where_masked, inp

(<tf.Tensor: shape=(100,), dtype=int32, numpy=
 array([14, 11, 23, 16,  3, 63, 50, 51, 16, 15, 50, 63, 16, 15, 13,  3, 60,
        13, 59, 16,  7, 42, 64, 39, 37, 13, 42, 42, 63, 50, 27, 16, 15, 63,
        16, 60, 25, 27, 15, 51, 24, 17, 23, 39, 16, 40, 59, 63, 35, 13,  3,
        16, 37, 64, 14, 15, 23, 16, 23, 64, 15, 23, 16, 56, 64, 60, 13, 16,
        64, 16,  7, 59, 13, 39, 16, 37, 63, 59, 16, 50, 63, 59, 56, 27, 29,
        24, 58, 23, 64, 15, 16, 11, 64,  3, 27, 15, 16, 15, 23, 63])>,
 <tf.Tensor: shape=(100,), dtype=bool, numpy=
 array([False,  True, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False,  True,
        False, False,  True, False, False, False,  True,  True, False,
        False, False,  True, False,  True, False, False,  True, False,
        False,  True,  True, F

In [19]:
print(to_string(inp), '\n\n', to_string(tar))

i[MASK]h now, two tender playfellows to [MASK]us[MASK],
T[MASK][MASK] br[MASK]k[MASK]n [MASK]ai[MASK][MASK] [MASK]ath made a pr[MASK]y [MASK]o[MASK] worms.
What [MASK]anst [MASK][MASK][MASK] 

 ich now, two tender playfellows to dust,
Thy broken faith hath made a prey for worms.
What canst tho


In [20]:
logits, *_ = model(inp)

In [21]:
to_string(tf.argmax(logits[where_masked], 1)), to_string(tar[where_masked])

('ZDqxxZqZZqzZzqzzzz', 'cdthyoefthhefrctho')

In [22]:
loss_fn(tar[where_masked], logits[where_masked])

<tf.Tensor: shape=(), dtype=float32, numpy=6.1102123>

In [23]:
model.summary()
with summary_writer.as_default():
    tf.summary.trace_export(
      name="transformer",
      step=0, profiler_outdir=log_dir)

Model: "tf_open_aigptlm_head_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
transformer (TFOpenAIGPTMain multiple                  6367488   
Total params: 6,367,488
Trainable params: 6,367,488
Non-trainable params: 0
_________________________________________________________________


In [27]:
BUFFER_SIZE = 10000
BATCH_SIZE = 8

train_dataset = dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.padded_batch(
    BATCH_SIZE, 
    padded_shapes=[MAX_SEQ_SIZE], 
    padding_values=pad_token
)

In [28]:
tar = next(iter(train_dataset))
inp, where_masked = create_mask_and_input(tar)
tar.shape, inp.shape, where_masked.shape

(TensorShape([8, 128]), TensorShape([8, 128]), TensorShape([8, 128]))

In [25]:
padding_mask = tf.cast(inp != pad_token, tf.int32)
logits, *_ = model(inp, attention_mask=padding_mask)
loss_fn(tar[where_masked], logits[where_masked])

<tf.Tensor: shape=(), dtype=float32, numpy=4.2594767>

In [26]:
def train_step(model: tf.keras.Model,
               optimizer: tf.keras.optimizers.Optimizer,
               pad_token: int,
               tar: tf.Tensor):
    
    inp, where_masked = create_mask_and_input(tar)
    
    with tf.GradientTape() as tape:
        attention_mask = tf.cast(inp != pad_token, tf.int32)
        logits, *_ = model(inp, 
                           attention_mask=attention_mask,
                           training=True)
        loss = loss_fn(tar[where_masked], logits[where_masked])

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_acc(tar[where_masked], logits[where_masked])
    train_loss(loss)
    
    return logits, loss


train_step_signature = [
    tf.TensorSpec(shape=(BATCH_SIZE, MAX_SEQ_SIZE), 
                  dtype=tf.int32),
]


@tf.function(input_signature=train_step_signature)
def train_step_tf(tar):
    train_step(model, optimizer, pad_token, tar)

In [29]:
BATCHES_IN_EPOCH = 250
epoch = 0

In [28]:
EPOCHS = 300

In [30]:
try:
    while epoch < EPOCHS:
        start = time.time()

        train_loss.reset_states()
        train_acc.reset_states()

        for batch, tar in enumerate(train_dataset):
            train_step_tf(tar)

            if batch % 50 == 0:
#                 print ('Epoch {} Batch {} Loss {:.4f} Accuracy {:.4f}'.format(
#                     epoch + 1, batch, train_loss.result(), train_acc.result()))

                with summary_writer.as_default():
                    tf.summary.scalar('loss', train_loss.result(), 
                                      step=epoch)
                    tf.summary.scalar('accuracy', train_acc.result(), 
                                      step=epoch)

            if batch >= BATCHES_IN_EPOCH:
                break

        if (epoch + 1) % 5 == 0:
            ckpt_save_path = ckpt_manager.save()
            print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                                 ckpt_save_path))

        print ('Epoch {} Loss {:.4f} Accuracy {:.4f}'.format(epoch + 1, 
                                                             train_loss.result(), 
                                                             train_acc.result()))

        print ('Time taken for epoch: {} secs\n'.format(time.time() - start))
        epoch += 1
        
except KeyboardInterrupt:
    print('Manual interrupt')

  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


Epoch 1 Loss 3.7273 Accuracy 0.1222
Time taken for epoch: 89.79784440994263 secs

Epoch 2 Loss 3.3987 Accuracy 0.1516
Time taken for epoch: 72.9276659488678 secs

Epoch 3 Loss 3.3596 Accuracy 0.1494
Time taken for epoch: 73.06627106666565 secs

Epoch 4 Loss 3.3281 Accuracy 0.1536
Time taken for epoch: 73.49952435493469 secs

Saving checkpoint for epoch 5 at ./outputs/20200324-000541/ckpts\ckpt-1
Epoch 5 Loss 3.3388 Accuracy 0.1502
Time taken for epoch: 74.96120643615723 secs

Epoch 6 Loss 3.3133 Accuracy 0.1533
Time taken for epoch: 74.62173795700073 secs

Epoch 7 Loss 3.3179 Accuracy 0.1531
Time taken for epoch: 72.82323098182678 secs

Epoch 8 Loss 3.3127 Accuracy 0.1536
Time taken for epoch: 72.62923979759216 secs

Epoch 9 Loss 3.3268 Accuracy 0.1528
Time taken for epoch: 72.97048950195312 secs

Saving checkpoint for epoch 10 at ./outputs/20200324-000541/ckpts\ckpt-2
Epoch 10 Loss 3.3090 Accuracy 0.1508
Time taken for epoch: 74.00064945220947 secs

Epoch 11 Loss 3.3143 Accuracy 0.150

Epoch 86 Loss 2.7425 Accuracy 0.2492
Time taken for epoch: 77.15404558181763 secs

Epoch 87 Loss 2.7536 Accuracy 0.2474
Time taken for epoch: 77.68461894989014 secs

Epoch 88 Loss 2.7510 Accuracy 0.2485
Time taken for epoch: 76.97298860549927 secs

Epoch 89 Loss 2.7400 Accuracy 0.2490
Time taken for epoch: 77.37986397743225 secs

Saving checkpoint for epoch 90 at ./outputs/20200324-000541/ckpts\ckpt-18
Epoch 90 Loss 2.7274 Accuracy 0.2519
Time taken for epoch: 77.6064145565033 secs

Epoch 91 Loss 2.7366 Accuracy 0.2488
Time taken for epoch: 77.46152973175049 secs

Epoch 92 Loss 2.7420 Accuracy 0.2505
Time taken for epoch: 77.43058395385742 secs

Epoch 93 Loss 2.7177 Accuracy 0.2584
Time taken for epoch: 77.36686444282532 secs

Epoch 94 Loss 2.7270 Accuracy 0.2505
Time taken for epoch: 76.23792743682861 secs

Saving checkpoint for epoch 95 at ./outputs/20200324-000541/ckpts\ckpt-19
Epoch 95 Loss 2.7175 Accuracy 0.2529
Time taken for epoch: 78.28342151641846 secs

Epoch 96 Loss 2.7083 Ac

Saving checkpoint for epoch 170 at ./outputs/20200324-000541/ckpts\ckpt-34
Epoch 170 Loss 2.5322 Accuracy 0.2833
Time taken for epoch: 78.34957790374756 secs

Epoch 171 Loss 2.5348 Accuracy 0.2854
Time taken for epoch: 76.8781852722168 secs

Epoch 172 Loss 2.5230 Accuracy 0.2849
Time taken for epoch: 77.1845772266388 secs

Epoch 173 Loss 2.5337 Accuracy 0.2825
Time taken for epoch: 76.46005392074585 secs

Epoch 174 Loss 2.5228 Accuracy 0.2869
Time taken for epoch: 76.54324460029602 secs

Saving checkpoint for epoch 175 at ./outputs/20200324-000541/ckpts\ckpt-35
Epoch 175 Loss 2.5151 Accuracy 0.2891
Time taken for epoch: 78.48339939117432 secs

Epoch 176 Loss 2.5101 Accuracy 0.2909
Time taken for epoch: 76.94967222213745 secs

Epoch 177 Loss 2.5094 Accuracy 0.2893
Time taken for epoch: 77.11376357078552 secs

Epoch 178 Loss 2.5101 Accuracy 0.2897
Time taken for epoch: 77.26440978050232 secs

Epoch 179 Loss 2.5132 Accuracy 0.2902
Time taken for epoch: 77.05408310890198 secs

Saving check

Epoch 253 Loss 2.4165 Accuracy 0.3111
Time taken for epoch: 140.8692545890808 secs

Epoch 254 Loss 2.4248 Accuracy 0.3065
Time taken for epoch: 146.73821353912354 secs

Saving checkpoint for epoch 255 at ./outputs/20200324-000541/ckpts\ckpt-51
Epoch 255 Loss 2.4075 Accuracy 0.3144
Time taken for epoch: 158.61984992027283 secs

Epoch 256 Loss 2.4167 Accuracy 0.3084
Time taken for epoch: 128.0950837135315 secs

Epoch 257 Loss 2.3959 Accuracy 0.3148
Time taken for epoch: 155.25612950325012 secs

Epoch 258 Loss 2.4140 Accuracy 0.3108
Time taken for epoch: 152.9687249660492 secs

Epoch 259 Loss 2.3989 Accuracy 0.3113
Time taken for epoch: 152.3165202140808 secs

Saving checkpoint for epoch 260 at ./outputs/20200324-000541/ckpts\ckpt-52
Epoch 260 Loss 2.4023 Accuracy 0.3157
Time taken for epoch: 153.6019208431244 secs

Epoch 261 Loss 2.3999 Accuracy 0.3185
Time taken for epoch: 152.88019275665283 secs

Epoch 262 Loss 2.3955 Accuracy 0.3128
Time taken for epoch: 151.47720575332642 secs

Epoch

In [31]:
ckpt_loc = 'outputs/20200321-171346/ckpts'
ckpt = tf.train.Checkpoint(transformer=model,
                           optimizer=optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_loc, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print ('Latest checkpoint restored!!')

Latest checkpoint restored!!


In [32]:
ckpt_manager.latest_checkpoint

'outputs/20200321-171346/ckpts\\ckpt-60'

In [29]:
tar = next(iter(train_dataset))
inp, where_masked = create_mask_and_input(tar)
attention_mask = tf.cast(inp != pad_token, tf.int32)
logits, *_ = model(inp, attention_mask=attention_mask)
loss = loss_fn(tar[where_masked], logits[where_masked])
loss

<tf.Tensor: shape=(), dtype=float32, numpy=6.7651315>

In [None]:
to_string(inp[0]), to_string(tar[0])

In [None]:
to_string(tf.argmax(logits[0][inp[0] == mask_token], 1)), to_string(tar[0][inp[0] == mask_token])

In [None]:
def generate_text(
    model: tf.keras.Model,
    start: str,
    mask_token: int,
    temperature = 1.0,
    steps=20,
    print_process=True
):
    if print_process:
        print(start)
        print('------')
    
    inp = [char2idx[c] for c in start]
    
    for _ in range(steps):
        inp_with_mask_at_end = tf.concat([inp, [mask_token]], 0)
        inp_with_mask_at_end = tf.expand_dims(inp_with_mask_at_end, 0)
        outputs = model(inp_with_mask_at_end)
        
        next_chr_preds = outputs[0][0][-1]
        next_chr_preds = next_chr_preds / temperature
        next_chr_preds = tf.expand_dims(next_chr_preds, 0)
        predicted_id = tf.random.categorical(next_chr_preds, 
                                             num_samples=1,
                                             dtype=tf.int32)[-1, 0]

        inp = tf.concat([inp, tf.reshape(predicted_id, (1,))], 0)
        
        if print_process:
            print(to_string(inp))
            print('------')

In [None]:
generate_text(model, 'ROMEO:', mask_token, print_process=True, temperature=0.1)

In [None]:
def plot_attention_weights(attention, sentence, layer):
    fig = plt.figure(figsize=(16, 2))

    attention = tf.squeeze(attention[layer], axis=0)

    for head in range(attention.shape[0]):
        ax = fig.add_subplot(2, 4, head+1)

        # plot the attention weights
        ax.matshow(attention[head], 
                   cmap='viridis')

        fontdict = {'fontsize': 10}

        ax.set_xticks(range(len(sentence)+2))

        ax.set_xticklabels(
            ['-']+[idx2char[i] for i in sentence]+['-'], 
            fontdict=fontdict)

        ax.set_yticklabels([])

        ax.set_xlabel('Head {}'.format(head+1))

    plt.tight_layout()
    plt.show()

In [None]:
# plot_attention_weights(attention_weights, x.numpy(), 'decoder_layer1_block2')

In [24]:
ckpt?