In [None]:
import tensorflow as tf
from tensorflow.python.keras.layers import Layer, Dense, Dropout, Embedding
from tensorflow.python.keras.models import Model
# from tensorflow.python.keras import regularizers


def positional_encoding(position, d_model):
    """Build the fixed sinusoidal positional-encoding table.

    Column ``j`` uses the angular rate ``1 / 10000 ** (2 * (j // 2) / d_model)``,
    so each even/odd column pair shares one frequency, as in the standard
    transformer formulation. (The previous version used ``2 * j / d_model``,
    which gave every column its own frequency.)

    Args:
        position: Number of positions (rows) to generate.
        d_model: Width of the encoding (number of columns).

    Returns:
        A float32 tensor of shape ``(1, position, d_model)``. Along the last
        axis the sine terms (from even columns) come first, followed by the
        cosine terms (from odd columns).
    """
    positions = tf.range(position, dtype=tf.float32)[:, tf.newaxis]
    # j // 2 makes columns 2i and 2i + 1 share the same frequency.
    pair_index = tf.cast(tf.range(d_model) // 2, tf.float32)
    angle_rates = 1 / tf.pow(10000.0, (2 * pair_index) / d_model)
    angle_rads = positions * angle_rates  # (position, d_model)

    sines = tf.math.sin(angle_rads[:, 0::2])
    cosines = tf.math.cos(angle_rads[:, 1::2])
    # Concatenating the two halves restores d_model columns.
    pos_encoding = tf.concat([sines, cosines], axis=-1)[tf.newaxis, ...]
    return tf.cast(pos_encoding, dtype=tf.float32)


class GPTLayer(Layer):
    """One decoder block: self-attention followed by a position-wise feed-forward net.

    Each sub-layer is applied as ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        d_model: Output width of the feed-forward net; also passed as the
            attention layer's ``key_dim``.
        num_heads: Number of attention heads.
        dff: Hidden width of the feed-forward net.
        rate: Dropout rate applied after each sub-layer.
    """

    def __init__(self, d_model, num_heads, dff, rate=0.15):
        super(GPTLayer, self).__init__()

        # NOTE(review): in Keras ``key_dim`` is the per-head projection size, so
        # every head projects to d_model here (not d_model // num_heads) -- confirm
        # this is intended.
        self.mha = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
        self.ffn = tf.keras.Sequential([
            Dense(dff, activation='relu'),
            Dense(d_model)
        ])

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def build(self, input_shape):
        """Build the attention sub-layer up front, using ``input_shape`` for query, value and key."""
        # Relies on a private Keras hook; kept as in the original.
        self.mha._build_from_signature(input_shape, input_shape, input_shape)
        super(GPTLayer, self).build(input_shape)

    def call(self, x, training=None, mask=None):
        """Run the block on ``x``.

        Args:
            x: Input activations; used as query, key and value.
            training: Forwarded to the dropout layers.
            mask: Optional mask in which a non-zero entry marks a position that
                must NOT be attended to (the convention produced by
                ``create_look_ahead_mask``). ``None`` disables masking.

        Returns:
            The block output after both residual/normalization steps.
        """
        # Keras MultiHeadAttention expects the opposite convention (True/1 = may
        # attend), so the "blocked" mask is inverted before being passed on.
        # Passing it through unchanged made the layer attend to exactly the
        # positions that were meant to be hidden.
        attention_mask = None if mask is None else tf.math.equal(mask, 0)

        attn_output = self.mha(x, x, x, attention_mask=attention_mask)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(x + attn_output)

        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        out2 = self.layernorm2(out1 + ffn_output)

        return out2

def create_look_ahead_mask(size):
    """Return a ``(1, 1, size, size)`` mask flagging future positions.

    Entry ``[0, 0, i, j]`` is 1 when ``j > i`` (strictly above the diagonal)
    and 0 otherwise.
    """
    # band_part(..., -1, 0) keeps the lower triangle including the diagonal.
    lower_triangle = tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    future_positions = 1 - lower_triangle
    return future_positions[tf.newaxis, tf.newaxis, :, :]

class GPT(Model):
    """Decoder-only transformer returning logits for the token after the input.

    Token embeddings are scaled by ``sqrt(d_model)``, summed with a fixed
    positional encoding, passed through ``num_layers`` ``GPTLayer`` blocks and
    projected to vocabulary logits; only the last position's logits are returned.
    """

    def __init__(self, num_layers, d_model, num_heads, dff, vocab_size, max_position_encoding, rate=0.15):
        super(GPT, self).__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.embedding = Embedding(vocab_size, d_model)
        # Precomputed table; sliced to the actual sequence length in call().
        self.pos_encoding = positional_encoding(max_position_encoding, d_model)

        self.gpt_layers = [
            GPTLayer(d_model, num_heads, dff, rate)
            for _ in range(num_layers)
        ]
        self.dropout = Dropout(rate)

        self.final_layer = Dense(vocab_size)

    def create_masks(self, inp):
        """Combine the look-ahead mask with a padding mask for ``inp``.

        A 1 in the result marks a key position that is either in the future
        or holds token id 0 (treated as padding); 0 marks every other position.
        """
        sequence_length = tf.shape(inp)[1]
        is_padding = tf.cast(tf.math.equal(inp, 0), tf.float32)
        # Insert two middle axes so the padding mask broadcasts against the
        # (1, 1, seq, seq) look-ahead mask.
        is_padding = is_padding[:, tf.newaxis, tf.newaxis, :]
        return tf.maximum(create_look_ahead_mask(sequence_length), is_padding)

    def call(self, x, training):
        """Return the last-position logits for a batch of token-id sequences."""
        mask = self.create_masks(x)
        seq_len = tf.shape(x)[1]

        scale = tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        hidden = self.embedding(x) * scale
        hidden = hidden + self.pos_encoding[:, :seq_len, :]
        hidden = self.dropout(hidden, training=training)

        for layer in self.gpt_layers:
            hidden = layer(hidden, training, mask)

        all_position_logits = self.final_layer(hidden)
        # Keep only the logits of the final sequence position.
        return all_position_logits[:, -1, :]


In [None]:
import tensorflow as tf
from tensorflow.python.keras.layers import Dense, Dropout, Embedding, Layer
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.optimizers import adam_v2

print("TensorFlow version: ", tf.__version__)
print("Connecting to TPU...")
# Resolve and connect to the Cloud TPU; the node name and zone are hard-coded.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver.connect(tpu='node-7',zone='us-central1-f')
# `strategy` is used by later cells for variable creation (strategy.scope)
# and for running the training step on the replicas (strategy.run).
strategy = tf.distribute.TPUStrategy(resolver)
print("Done!")
print("Number of accelerators: ", strategy.num_replicas_in_sync)

In [None]:
from transformers import GPT2Tokenizer
# Pretrained GPT-2 tokenizer: supplies the vocabulary size and decodes ids for inspection.
tokenizer: GPT2Tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# The model, loss, schedule and optimizer are all created inside the strategy scope.
with strategy.scope():
    # specify parameters here
    num_layers = 2  # number of stacked GPTLayer blocks
    d_model = 300  # embedding / model width
    num_heads = 2  # attention heads per block
    dff = 1200  # hidden width of each block's feed-forward net
    vocab_size = tokenizer.vocab_size
    # Length of the positional-encoding table; also the number of context
    # tokens per training window (passed as `step` to create_windows).
    max_position_encoding = 20
    dropout_rate = 0.15
    batch_size = 1024  # batch size passed to dataset.batch
    epochs = 3
    warmup_steps = 4000  # warm-up length for NoamSchedule

    print('Creating model...')
    transformer = GPT(num_layers, d_model, num_heads, dff, vocab_size, max_position_encoding, dropout_rate)
    print('Done')


    # Reduction.SUM: the loss is summed over the examples it receives, not averaged.
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.SUM)

    def loss_function(real, pred):
        """Return ``loss_object(real, pred)``: cross-entropy of logits ``pred`` against ids ``real``."""
        return loss_object(real, pred)


    
    class NoamSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
        """Warm-up then inverse-square-root learning-rate schedule.

        ``lr(step) = d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)``

        Args:
            d_model: Model width used to scale the rate.
            warmup_steps: Number of steps over which the rate ramps up linearly.
        """

        def __init__(self, d_model, warmup_steps=4000):
            super(NoamSchedule, self).__init__()

            # Keep the value as passed in so get_config() can return it without
            # touching the tensor stored in self.d_model.
            self._d_model_config = d_model
            self.d_model = tf.cast(d_model, tf.float32)
            self.warmup_steps = warmup_steps

        def __call__(self, step):
            """Return the learning rate for ``step``."""
            step = tf.cast(step, tf.float32)
            # At step 0, rsqrt gives inf and the warm-up term is 0, so the
            # minimum (and therefore the learning rate) is 0 for that step.
            arg1 = tf.math.rsqrt(step)
            arg2 = step * (self.warmup_steps ** -1.5)
            learning_rate = tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

            return learning_rate

        def get_config(self):
            """Return the constructor arguments so the schedule can be serialized."""
            return {
                "d_model": self._d_model_config,
                "warmup_steps": self.warmup_steps,
            }

    # Adam driven by the Noam warm-up schedule, with explicitly chosen betas and epsilon.
    learning_rate = NoamSchedule(d_model, warmup_steps)
    optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

# Training shards: wikitrain_0000.tfrecord ... wikitrain_0078.tfrecord under the bucket path.
bucket_path = 'gs://dataset_w/'
input_tfrecord_files = [f'{bucket_path}wikitrain_{i:04d}.tfrecord' for i in range(79)]

def create_windows(sequence, step=1):
    """Slice a 1-D sequence into every contiguous window of ``step + 1`` items.

    Window ``i`` equals ``sequence[i:i + step + 1]`` and consecutive windows
    start one element apart, giving ``len(sequence) - step`` windows.

    The windows are gathered in one vectorized op rather than written one at a
    time into a TensorArray, so the function no longer depends on AutoGraph
    converting a Python ``for`` loop over a tensor bound. A sequence with
    ``step`` or fewer items now yields an empty result instead of failing.

    Args:
        sequence: 1-D tensor of items (token ids in this notebook).
        step: Number of items per window in addition to the first one.

    Returns:
        Tensor of shape ``(num_windows, step + 1)`` with the dtype of
        ``sequence``; ``num_windows`` is 0 when the sequence is too short.
    """
    sequence_length = tf.shape(sequence)[0]
    # Clamp at zero so short sequences produce no windows.
    num_windows = tf.maximum(sequence_length - step, 0)

    window_starts = tf.range(num_windows)[:, tf.newaxis]  # (num_windows, 1)
    offsets = tf.range(step + 1)[tf.newaxis, :]  # (1, step + 1)
    # Largest index is (num_windows - 1) + step = sequence_length - 1.
    return tf.gather(sequence, window_starts + offsets)

@tf.function
def _parse_function(example_proto):
    """Parse one serialized example into next-token training pairs.

    Reads the variable-length int64 feature ``'token_ids'``, cuts it into
    sliding windows of ``max_position_encoding + 1`` tokens, and splits each
    window into its context (all but the last token) and target (last token).

    Args:
        example_proto: A serialized example holding a ``'token_ids'`` feature.

    Returns:
        ``(input_sequences, target_sequences)``: a 2-D tensor with one context
        window per row, and a 1-D tensor with the matching target token ids.
    """
    feature_description = {
        'token_ids': tf.io.VarLenFeature(tf.int64),
    }
    parsed_features = tf.io.parse_single_example(example_proto, feature_description)
    input_sequence = tf.sparse.to_dense(parsed_features['token_ids'])
    # Each window holds max_position_encoding context tokens plus one target token.
    windows = create_windows(input_sequence, max_position_encoding)

    # Plain slicing replaces the per-row tf.map_fn (which used the deprecated
    # `dtype=` argument) and produces the same two tensors.
    input_sequences = windows[:, :-1]
    target_sequences = windows[:, -1]

    return input_sequences, target_sequences

# Load and preprocess dataset
def load_dataset(input_files):
    """Return a dataset that reads ``input_files`` as TFRecords and maps each record through ``_parse_function``."""
    return tf.data.TFRecordDataset(input_files).map(_parse_function)


print('Processing dataset...')
# One element per TFRecord example: the (inputs, targets) pair returned by _parse_function.
input_dataset = load_dataset(input_tfrecord_files)

def print_sequences_as_words(inp, tar):
    """Decode ``inp`` and ``tar`` with the global tokenizer and print the text.

    ``inp`` is decoded as a batch of sequences (one printed line each);
    ``tar`` is wrapped in a list and decoded as a single sequence.
    """
    # Decode both before printing anything, in the same order as before.
    sections = {
        "Input:": tokenizer.batch_decode(inp.numpy(), skip_special_tokens=True),
        "\nTarget:": tokenizer.batch_decode([tar.numpy()], skip_special_tokens=True),
    }
    for heading, decoded_lines in sections.items():
        print(heading)
        for line in decoded_lines:
            print(line)
def print_dataset(input_dataset, num_examples=1):
    """Print the first ``num_examples`` (input, target) elements, raw and decoded."""
    preview = input_dataset.take(num_examples)
    for i, example in enumerate(preview):
        inp, tar = example
        print(f"Example {i + 1}:")
        print("Input: ", inp.numpy())
        print("Target: ", tar.numpy())
        print_sequences_as_words(inp, tar)
        print("\n")
# print_dataset(input_dataset, 1)

# Flatten each record's (inputs, targets) tensors into individual (input, target) examples.
input_dataset = input_dataset.flat_map(lambda x, y: tf.data.Dataset.from_tensor_slices((x, y)))
# Fixed-size batches; a final partial batch is dropped.
dataset = input_dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
print('Done!')

import os
import matplotlib.pyplot as plt
import datetime
from IPython.display import clear_output

def plot_loss(loss_history):
    """Clear the cell output and draw ``loss_history`` against batch index.

    The title carries the wall-clock time at which the plot was drawn.
    """
    clear_output(wait=True)

    now = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')

    plt.plot(loss_history)
    plt.title(f"Loss History\nLast Updated at: {now}")
    plt.xlabel("Batch")
    plt.ylabel("Loss")
    plt.show()

@tf.function
def train_step(inp, tar):
    """Run one optimization step on ``(inp, tar)`` and return the loss.

    Uses the global ``transformer``, ``loss_function`` and ``optimizer``.
    """
    with tf.GradientTape() as tape:
        loss = loss_function(tar, transformer(inp, training=True))

    variables = transformer.trainable_variables
    optimizer.apply_gradients(zip(tape.gradient(loss, variables), variables))

    return loss

loss_history = []
print("Initializing training...")
# Distribute the dataset so each global batch is split across the replicas.
# Passing plain tensors to strategy.run hands the same full batch to every
# replica, repeating the work and counting each example once per replica in
# the summed loss and gradients.
distributed_dataset = strategy.experimental_distribute_dataset(dataset)
for epoch in range(epochs):
    total_loss = tf.constant(0.0, dtype=tf.float32)
    num_batches = 0
    for inp, tar in distributed_dataset:
        per_replica_losses = strategy.run(train_step, args=(inp, tar))
        # Sum of the per-replica summed losses = summed loss over the global batch.
        loss = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
        total_loss += loss
        loss_history.append(loss.numpy())
        num_batches += 1

    if num_batches == 0:
        # Avoid dividing by zero (the original would hit an unbound `batch` here).
        print(f'Epoch {epoch + 1}: dataset produced no batches')
        continue

    avg_loss = total_loss / num_batches
    print(f'Epoch {epoch + 1}, Average loss: {avg_loss.numpy()}')

In [None]:
def w(warm, d_model, step):
    """Return the Noam learning rate for a single step (pure-Python version).

    ``lr = d_model ** -0.5 * min(step ** -0.5, step * warm ** -1.5)``

    This matches ``NoamSchedule``, which scales by ``rsqrt(d_model)``. The
    previous version multiplied by ``d_model ** 0.5``, so the plotted curve was
    off by a factor of ``d_model``.

    Args:
        warm: Number of warm-up steps.
        d_model: Model width used to scale the rate.
        step: Training step; must be positive (``0 ** -0.5`` raises).

    Returns:
        The learning rate as a float.
    """
    arg1 = step ** -0.5
    arg2 = step * (warm ** -1.5)
    learning_rate = d_model ** -0.5 * min(arg1, arg2)

    return learning_rate

# NOTE(review): this cell rebinds the globals `learning_rate` and `d_model`,
# which the training cell also assigns (the schedule object and the model
# width). Re-run the training cell before reusing those names.
learning_rate = []
d_model = 512
warm = 4000
# Evaluate the schedule for steps 1..99999.
learning_rate = [w(warm, d_model, step) for step in range(1, 100000)]

plt.plot(learning_rate)
plt.xlabel("Batch")
plt.ylabel("Learning rate")

# Static title (no timestamp is added here).
plt.title(f"Learning rate")

plt.show()

In [None]:
# Plot the per-batch losses collected by the training loop.
plot_loss(loss_history)

In [None]:
def predict_next_word(input_text, transformer, tokenizer, top_k=5, max_length=128):
    """Return the ``top_k`` most likely next tokens for ``input_text``.

    The text is encoded, then brought to exactly ``max_length`` tokens: longer
    inputs keep their last ``max_length`` tokens, shorter inputs are left-padded
    with id 0, and inputs of exactly ``max_length`` tokens are used unchanged.
    (That last case previously left ``input_tokens`` unassigned and raised.)

    Args:
        input_text: Prompt text.
        transformer: Callable model; called as ``transformer(tokens, training=False)``.
            Assumed to return one row of vocabulary logits per batch item -- confirm
            for models other than the notebook's ``GPT``.
        tokenizer: Tokenizer providing ``encode(..., return_tensors="tf")`` and ``decode``.
        top_k: Number of candidate tokens to return.
        max_length: Sequence length fed to the model.

    Returns:
        List of ``top_k`` decoded token strings, most probable first.
    """
    input_tokens_full = tokenizer.encode(input_text, return_tensors="tf")
    num_tokens = input_tokens_full.shape[1]
    if num_tokens > max_length:
        # Keep the most recent tokens.
        input_tokens = input_tokens_full[:, -max_length:]
    elif num_tokens < max_length:
        # Left-pad with zeros so the real tokens stay at the end.
        input_tokens = tf.pad(input_tokens_full, [[0, 0], [max_length - num_tokens, 0]])
    else:
        input_tokens = input_tokens_full
    print(input_tokens)
    logits = transformer(input_tokens, training=False)
    logits = logits[0, :]  # logits of the first (only) batch item
    probabilities = tf.nn.softmax(logits, axis=-1)
    top_k_indices = tf.math.top_k(probabilities, k=top_k).indices
    top_k_tokens = [tokenizer.decode([token_id]) for token_id in top_k_indices.numpy()]

    return top_k_tokens


# Demo: list the 50 most likely next tokens for a short prompt, using the
# model's context length (max_position_encoding) as the input length.
input_text = """The first, the first,"""
predicted_words = predict_next_word(input_text, transformer, tokenizer, top_k=50, max_length=max_position_encoding)
print(f"Input: {input_text}")
print("Predicted next words:")
# Print the candidates as a 1-based ranked list.
for i, word in enumerate(predicted_words):
    print(f"{i + 1}. {word}")