In [22]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import tqdm
import time

from transformers import BertTokenizer, TFBertModel, TFGPT2Model, GPT2Tokenizer, TFGPT2LMHeadModel
from Attention import AttentionUtils

In [23]:
tf.keras.backend.clear_session()

In [24]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.add_special_tokens({'pad_token': '<PAD>'})
vocab_size = tokenizer.vocab_size

gpt2 = TFGPT2LMHeadModel.from_pretrained('gpt2')
gpt2.summary()


All model checkpoint layers were used when initializing TFGPT2LMHeadModel.

All the layers of TFGPT2LMHeadModel were initialized from the model checkpoint at gpt2.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFGPT2LMHeadModel for predictions without further training.


Model: "tfgp_t2lm_head_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
transformer (TFGPT2MainLayer multiple                  124439808 
Total params: 124,439,808
Trainable params: 124,439,808
Non-trainable params: 0
_________________________________________________________________


In [25]:
def tokenize(s, max_len=128):
    tok = tokenizer.encode(bytes.decode(s.numpy()), max_length=max_len, padding='max_length')
    return tf.constant(tok, dtype=tf.int32)


def shift(x):
    return x[:, :-1], x[:, 1:]

In [26]:
sentences = tf.data.experimental.load(
    'sentences_raw_gpttokens.tfrecord', compression='GZIP')
print(len(sentences))
for s in sentences.batch(2).take(1):
    print(tokenizer.decode(s[0], skip_special_tokens=True))
    print(tokenizer.decode(s[1], skip_special_tokens=True))

492686
the deep space nine transcripts - emissary emissary stardate: 46379.1 original airdate: 3 jan, 1993 on stardate 43997, captain jean-luc picard of the federation starship enterprise was kidnapped for six days by an invading force known as the borg.
 surgically altered, he was forced to lead an assault on starfleet at wolf 359.


In [27]:
ratios = (0.8, 0.1, 0.1)
assert sum(ratios) == 1

BUFFER_SIZE = 10000
BATCH_SIZE = 16

sentences = tf.data.experimental.load(
    'sentences_combined_gpttokens.tfrecord', compression='GZIP')
sentences = sentences.shuffle(BUFFER_SIZE)
cardinality = len(sentences)
train_dataset = sentences.take(int(ratios[0] * cardinality))
valid_dataset = sentences.skip(int(ratios[0] * cardinality)).take(int(ratios[1] * cardinality))
test_dataset = sentences.skip(int(ratios[0] * cardinality) + int(ratios[1] * cardinality))

# shuffle(BUFFER_SIZE, reshuffle_each_iteration=True)?

train_dataset = train_dataset.batch(BATCH_SIZE).prefetch(tf.data.experimental.AUTOTUNE)
valid_dataset = valid_dataset.batch(BATCH_SIZE).prefetch(tf.data.experimental.AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE).prefetch(tf.data.experimental.AUTOTUNE)


In [28]:
print(len(sentences), len(train_dataset), len(valid_dataset), len(test_dataset))

42371 2119 265 265


In [29]:
def reorder(s):
    return s[:,0,:], s[:,1,:]

In [30]:
class HFModel(tf.keras.Model):
    def __init__(self, model, vocab_size, dense, output_dense=True, make_base_trainable=False):
        super(HFModel, self).__init__()
        self.output_dense = output_dense
        self.model = model
        self.model.trainable = make_base_trainable
        self.dense = tf.keras.layers.Dense(dense, activation='relu')
        self.output_dense = tf.keras.layers.Dense(vocab_size, activation='softmax')

    def call(self, inputs):
        model_outs = self.model(inputs).last_hidden_state
        # hidden_dense = self.dense(model_outs)        
        outputs = self.output_dense(model_outs)
        return outputs
        # return tf.argmax(outputs, axis=-1)


In [31]:
# model = HFModel(gpt2, vocab_size, 256, output_dense=False)
model = TFGPT2LMHeadModel.from_pretrained(
    "gpt2",
    use_cache=False,
    pad_token_id=tokenizer.pad_token_id,)

checkpoint_path = "./checkpoints/GPT2LM/train"

# ckpt = tf.train.Checkpoint(transformer=model,
#                            optimizer=optimizer)

# ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=3)

model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True), metrics=[tf.metrics.SparseCategoricalAccuracy()])


All model checkpoint layers were used when initializing TFGPT2LMHeadModel.

All the layers of TFGPT2LMHeadModel were initialized from the model checkpoint at gpt2.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFGPT2LMHeadModel for predictions without further training.


In [34]:
model.fit(train_dataset.map(reorder), epochs=3, callbacks=[
    tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_path,
        save_weights_only=True,
        verbose=1)
])

Epoch 1/3

Epoch 00001: saving model to ./checkpoints/GPT2LM\train
Epoch 2/3

Epoch 00002: saving model to ./checkpoints/GPT2LM\train
Epoch 3/3

Epoch 00003: saving model to ./checkpoints/GPT2LM\train


<tensorflow.python.keras.callbacks.History at 0x2884fded6a0>

In [35]:
for s in test_dataset.take(1):
    print("Input:\n", tokenizer.decode(s[0, 0, :], skip_special_tokens=True))
    output = model.generate(s[0, 0, :][tf.newaxis, :], max_length=len(s[0,0,:])*2, temperature=0.7)
    print("Output:\n", tokenizer.decode(output[0], skip_special_tokens=True))


Input:
 )[shuttlepod] (tucker hears someone making gun noises. it's the boy, who is startled when trip activates an alarm.) tucker: what are you doing in my chair? q'ell: i didn't touch anything. i just like to look inside the ships that come here. tucker: you should have asked. q'ell: you might have said no. tucker: well, what do you think? q'ell: well, it's a little small and your thruster controls are hard to reach. tucker: maybe you need longer arms.
Output:
 )[shuttlepod] (tucker hears someone making gun noises. it's the boy, who is startled when trip activates an alarm.) tucker: what are you doing in my chair? q'ell: i didn't touch anything. i just like to look inside the ships that come here. tucker: you should have asked. q'ell: you might have said no. tucker: well, what do you think? q'ell: well, it's a little small and your thruster controls are hard to reach. tucker: maybe you need longer arms.!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!