In [None]:
import tensorflow as tf
from tensorflow.python.keras.layers import Layer, Dense, Dropout, Embedding
from tensorflow.python.keras.models import Model


def positional_encoding(position, d_model):
    """Build a fixed sinusoidal positional-encoding table.

    Each pair of columns (2k, 2k + 1) shares the frequency
    1 / 10000 ** (2k / d_model). The sine of the even columns and the cosine
    of the odd columns are then concatenated (all sines first, then all
    cosines), which is the layout this function has always produced.

    Args:
        position: Number of positions (rows) to generate.
        d_model: Width of the encoding (columns).

    Returns:
        A float32 tensor of shape (1, position, d_model).
    """
    positions = tf.range(position, dtype=tf.float32)[:, tf.newaxis]    # (position, 1)
    columns = tf.range(d_model, dtype=tf.float32)[tf.newaxis, :]       # (1, d_model)

    # Use the pair index k = column // 2 in the exponent so the sin column and
    # its matching cos column get the same frequency; using the raw column
    # index here gives every column a different frequency.
    exponents = 2.0 * tf.math.floor(columns / 2.0) / tf.cast(d_model, tf.float32)
    angle_rads = positions / tf.pow(10000.0, exponents)                # (position, d_model)

    sines = tf.math.sin(angle_rads[:, 0::2])
    cosines = tf.math.cos(angle_rads[:, 1::2])

    pos_encoding = tf.concat([sines, cosines], axis=-1)                # (position, d_model)
    return pos_encoding[tf.newaxis, ...]


class GPTLayer(Layer):
    """A single transformer block: self-attention followed by a feed-forward network.

    Both sub-layers apply dropout to their output, add it back onto their
    input (residual connection) and layer-normalize the sum.

    NOTE(review): Layer, Dense and Dropout are imported from
    tensorflow.python.keras, while MultiHeadAttention, Sequential and
    LayerNormalization are taken from tf.keras. Confirm both resolve to the
    same Keras implementation in the installed TensorFlow version.

    Args:
        d_model: Width of the block's output; also used as the per-head key size.
        num_heads: Number of attention heads.
        dff: Width of the hidden feed-forward layer.
        rate: Dropout rate used after both sub-layers.
    """

    def __init__(self, d_model, num_heads, dff, rate=0.15):
        super().__init__()

        # Self-attention; every head uses key_dim = d_model.
        self.mha = tf.keras.layers.MultiHeadAttention(key_dim=d_model, num_heads=num_heads)

        # Feed-forward network: ReLU expansion to dff, then projection back to d_model.
        expand = Dense(dff, activation='relu')
        project = Dense(d_model)
        self.ffn = tf.keras.Sequential([expand, project])

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def build(self, input_shape):
        """Build the attention sub-layer, using the input shape for all three of its inputs."""
        self.mha._build_from_signature(input_shape, input_shape, input_shape)
        super().build(input_shape)

    def call(self, x, training, mask):
        """Run the block on `x`.

        Args:
            x: Input tensor; it is used as all three attention inputs.
            training: Forwarded to both dropout layers.
            mask: Forwarded unchanged as `attention_mask` to MultiHeadAttention.

        Returns:
            The layer-normalized output of the feed-forward sub-layer.
        """
        # Sub-layer 1: self-attention -> dropout -> residual add -> layer norm.
        attended = self.dropout1(self.mha(x, x, x, attention_mask=mask), training=training)
        attention_block = self.layernorm1(x + attended)

        # Sub-layer 2: feed-forward -> dropout -> residual add -> layer norm.
        transformed = self.dropout2(self.ffn(attention_block), training=training)
        return self.layernorm2(attention_block + transformed)

def create_look_ahead_mask(size):
    """Return a (1, 1, size, size) float mask that marks future positions.

    Entry [0, 0, q, k] is 1.0 when key position k lies after query position q
    and 0.0 otherwise (the diagonal and everything below it are 0.0).
    """
    # Lower triangle including the diagonal: positions at or before the query.
    visible = tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    future = 1 - visible
    # Add two leading singleton axes.
    return tf.expand_dims(tf.expand_dims(future, axis=0), axis=0)


class GPT(Model):
    """Decoder-only transformer that returns logits for the last input position.

    Token ids are embedded, scaled by sqrt(d_model), summed with a fixed
    positional encoding, passed through `num_layers` GPTLayer blocks and
    projected to `vocab_size` logits. Token id 0 is treated as padding.

    Args:
        num_layers: Number of GPTLayer blocks.
        d_model: Embedding / hidden width.
        num_heads: Attention heads per block.
        dff: Hidden width of each block's feed-forward network.
        vocab_size: Size of the embedding table and of the output layer.
        max_position_encoding: Number of positions in the positional table;
            inputs longer than this cannot be added to the table in `call`.
        rate: Dropout rate.
    """

    def __init__(self, num_layers, d_model, num_heads, dff, vocab_size, max_position_encoding, rate=0.15):
        super(GPT, self).__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.embedding = Embedding(vocab_size, d_model)
        # Fixed table of shape (1, max_position_encoding, d_model).
        self.pos_encoding = positional_encoding(max_position_encoding, d_model)

        self.gpt_layers = [GPTLayer(d_model, num_heads, dff, rate) for _ in range(num_layers)]
        self.dropout = Dropout(rate)

        self.final_layer = Dense(vocab_size)

    def create_masks(self, inp):
        """Build the attention mask for a batch of token ids.

        Args:
            inp: Integer tensor of token ids, shape (batch, seq_len).

        Returns:
            A float tensor of shape (batch, 1, seq_len, seq_len) in the
            convention expected by Keras MultiHeadAttention's `attention_mask`:
            1.0 where a query may attend to a key, 0.0 where it may not. A key
            is hidden when it lies after the query or when its token id is 0.
        """
        # Both helper masks use 1.0 to mean "hidden".
        look_ahead_mask = create_look_ahead_mask(tf.shape(inp)[1])
        padding_mask = tf.cast(tf.math.equal(inp, 0), tf.float32)
        padding_mask = padding_mask[:, tf.newaxis, tf.newaxis, :]
        hidden = tf.maximum(look_ahead_mask, padding_mask)

        # Keras MultiHeadAttention uses the opposite convention (1 = attend,
        # 0 = masked). Passing `hidden` directly would let every token see
        # only future and padding positions, so invert it here.
        return 1.0 - hidden

    def call(self, x, training):
        """Compute logits for the token following each input sequence.

        Args:
            x: Integer tensor of token ids, shape (batch, seq_len).
            training: Forwarded to the dropout layers and to every GPTLayer.

        Returns:
            Float tensor of shape (batch, vocab_size): the output-layer logits
            taken at the last sequence position.
        """
        seq_len = tf.shape(x)[1]
        # The mask is built from the raw token ids, before `x` is replaced by embeddings.
        mask = self.create_masks(x)

        x = self.embedding(x)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x += self.pos_encoding[:, :seq_len, :]

        x = self.dropout(x, training=training)

        for gpt_layer in self.gpt_layers:
            x = gpt_layer(x, training, mask)

        final_output = self.final_layer(x)
        # Only the last position is returned.
        last_position_logits = final_output[:, -1, :]
        return last_position_logits


In [None]:
import tensorflow as tf
# Report the TensorFlow build in use before touching the accelerator.
print("TensorFlow version: ", tf.__version__)
print("Connecting to TPU...")
# Connect to the TPU identified by the hard-coded name and zone below.
# NOTE(review): 'node-7' / 'us-central1-f' are environment-specific; confirm
# they match the TPU you intend to use.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver.connect(tpu='node-7',zone='us-central1-f')
# `strategy` is used by later cells: the model, loss and optimizer are created
# under strategy.scope(), and train_step is launched through strategy.run.
strategy = tf.distribute.TPUStrategy(resolver)
print("Done!")
print("Number of accelerators: ", strategy.num_replicas_in_sync)

In [None]:
from transformers import GPT2Tokenizer
# GPT-2 tokenizer; its vocabulary size sets the width of the model's output layer.
tokenizer: GPT2Tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

with strategy.scope():
    # ---- Hyperparameters ----
    num_layers = 2                # number of stacked GPTLayer blocks
    d_model = 300                 # embedding / hidden width
    num_heads = 2                 # attention heads per block
    dff = 1200                    # hidden width of each feed-forward network
    vocab_size = tokenizer.vocab_size
    max_position_encoding = 20    # rows in the positional-encoding table
    dropout_rate = 0.15
    batch_size = 55
    epochs = 200
    warmup_steps = 100            # passed to CustomSchedule below

    # ---- Model ----
    print('Creating model...')
    transformer = GPT(
        num_layers,
        d_model,
        num_heads,
        dff,
        vocab_size,
        max_position_encoding,
        dropout_rate,
    )
    print('Done')

    # ---- Loss ----
    # Cross-entropy over integer labels, computed from raw logits and summed
    # (not averaged) over the batch.
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        reduction=tf.keras.losses.Reduction.SUM,
        from_logits=True,
    )

    def loss_function(real, pred):
        """Return the summed cross-entropy of the logits `pred` against the integer labels `real`."""
        return loss_object(real, pred)

    # ---- Optimizer ----
    class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
        """Warm-up schedule: d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)."""

        def __init__(self, d_model, warmup_steps=4000):
            super().__init__()

            self.warmup_steps = warmup_steps
            self.d_model = tf.cast(d_model, tf.float32)

        def __call__(self, step):
            """Return the learning rate for optimizer step `step`."""
            step = tf.cast(step, tf.float32)
            decay_term = tf.math.rsqrt(step)
            warmup_term = step * (self.warmup_steps ** -1.5)
            # The smaller term wins: warmup_term while step < warmup_steps, decay_term afterwards.
            return tf.math.rsqrt(self.d_model) * tf.math.minimum(decay_term, warmup_term)

    learning_rate = CustomSchedule(d_model, warmup_steps)
    optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

# NOTE(review): nothing visible in this notebook reads these TFRecord paths;
# confirm whether they are still needed.
bucket_path = 'gs://dataset_w/'
input_tfrecord_files = [f'{bucket_path}wikitrain_{shard:04d}.tfrecord' for shard in range(79)]

def create_tf_dataset(data, tokenizer, max_length=None):
    """Turn "<event>? <year>" strings into a dataset of (token_ids, year) pairs.

    Each string is split on "? ". The last piece is parsed as the integer
    year and the remaining pieces are joined with single spaces to form the
    event text. Events are encoded with `tokenizer.encode` and left-padded
    with token id 0 to a common length.

    Args:
        data: Iterable of strings shaped like "What year was X? 1776".
        tokenizer: Object whose `encode(text)` returns a list of token ids.
        max_length: Optional fixed sequence length. When given, every event is
            padded to exactly this length and longer events keep only their
            last `max_length` tokens, which is how `predict_next_word`
            prepares its input. When None (the default) events are padded to
            the longest event in `data`.

    Returns:
        An unbatched tf.data.Dataset yielding (token_ids, year) where `year`
        has shape (1,).

    Raises:
        ValueError: If `data` is empty, a year cannot be parsed as an integer,
            or `max_length` is given but smaller than 1.
    """
    if max_length is not None and max_length < 1:
        raise ValueError(f"max_length must be a positive integer, got {max_length!r}")

    def split_input_target(input_string):
        parts = input_string.strip().split("? ")
        try:
            year = int(parts[-1])
        except ValueError:
            raise ValueError(
                f"Expected an example of the form '<event>? <year>', got {input_string!r}"
            ) from None
        return " ".join(parts[:-1]), year

    examples = [split_input_target(item) for item in data]
    if not examples:
        # zip(*[]) would otherwise fail with an unhelpful unpacking error.
        raise ValueError("create_tf_dataset() needs at least one example")
    events, years = zip(*examples)

    # Encode events using GPT-2 tokenizer
    encoded_events = [tokenizer.encode(event) for event in events]

    if max_length is None:
        target_length = max(len(event) for event in encoded_events)
    else:
        target_length = max_length

    def fit_to_length(tokens):
        # Keep the trailing tokens when too long, then left-pad with id 0.
        tokens = tokens[max(len(tokens) - target_length, 0):]
        return [0] * (target_length - len(tokens)) + tokens

    encoded_events = [fit_to_length(event) for event in encoded_events]

    encoded_years = tf.expand_dims(years, axis=1)
    events_tensor = tf.data.Dataset.from_tensor_slices(encoded_events)
    years_tensor = tf.data.Dataset.from_tensor_slices(encoded_years)

    dataset = tf.data.Dataset.zip((events_tensor, years_tensor))

    return dataset

# Example usage: a small in-memory training set.
# Every entry has the form "<question>? <year>", which is the format
# create_tf_dataset splits on.
data = [
    "What year was the signing of the Declaration of Independence? 1776",
    "What year was the storming of the Bastille? 1789",
    "What year was the Battle of Waterloo? 1815",
    "What year was the assassination of Abraham Lincoln? 1865",
    "What year was the invention of the telephone by Alexander Graham Bell? 1876",
    "What year was the first successful powered airplane flight by the Wright brothers? 1903",
    "What year was the sinking of the Titanic? 1912",
    "What year was the beginning of World War I? 1914",
    "What year was the Russian Revolution? 1917",
    "What year was the end of World War I? 1918",
    "What year was the stock market crash that led to the Great Depression? 1929",
    "What year was the beginning of World War II? 1939",
    "What year was the attack on Pearl Harbor? 1941",
    "What year was the D-Day invasion during World War II? 1944",
    "What year was the dropping of the atomic bombs on Hiroshima and Nagasaki? 1945",
    "What year was the end of World War II? 1945",
    "What year was the establishment of the United Nations? 1945",
    "What year was the beginning of the Korean War? 1950",
    "What year was the launch of Sputnik 1, the first artificial satellite? 1957",
    "What year was the Cuban Missile Crisis? 1962",
    "What year was the assassination of John F. Kennedy? 1963",
    "What year was the first moon landing by Apollo 11? 1969",
    "What year was the end of the Vietnam War? 1975",
    "What year was the fall of the Berlin Wall? 1989",
    "What year was the dissolution of the Soviet Union? 1991",
    "What year was the terrorist attacks on September 11? 2001",
    "What year was the beginning of the Iraq War? 2003",
    "What year was the invention of the World Wide Web by Tim Berners-Lee? 1989",
    "What year was the assassination of Martin Luther King Jr.? 1968",
    "What year was the discovery of DNA's double helix structure by James Watson and Francis Crick? 1953",
    "What year was the first human heart transplant performed by Dr. Christiaan Barnard? 1967",
    "What year was the Chernobyl nuclear disaster? 1986",
    "What year was the launch of the Hubble Space Telescope? 1990",
    "What year was the Rwandan Genocide? 1994",
    "What year was the Oklahoma City bombing? 1995",
    "What year was the cloning of Dolly the sheep? 1996",
    "What year was the death of Princess Diana? 1997",
    "What year was the Euro currency introduced? 1999",
    "What year was the Indian Ocean earthquake and tsunami? 2004",
    "What year was the election of Pope Francis? 2013",
    "What year was the Paris Agreement on climate change signed? 2016",
    "What year was the Brexit referendum? 2016",
    "What year was the first iPhone released? 2007",
    "What year was the election of Donald Trump as the 45th President of the United States? 2016",
    "What year was the completion of the Human Genome Project? 2003",
    "What year was the founding of the World Health Organization? 1948",
    "What year was the assassination of Archduke Franz Ferdinand? 1914",
    "What year was the start of the California Gold Rush? 1848",
    "What year was the completion of the Panama Canal? 1914",
    "What year was the discovery of penicillin by Alexander Fleming? 1928",
    "What year was the Montgomery Bus Boycott? 1955",
    "What year was the assassination of Mahatma Gandhi? 1948",
    "What year was the formation of the European Union? 1993",
    "What year was the release of the first Harry Potter book by J.K. Rowling? 1997",
    "What year was the start of the American Civil War? 1861",
]

# NOTE(review): `tokenizer` was already loaded from the same "gpt2" checkpoint
# earlier in this cell, so this reload looks redundant.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# Build the unbatched (token_ids, year) dataset from the example questions above.
dataset = create_tf_dataset(data, tokenizer)

def print_sequences_as_words(inp, tar):
    """Decode `inp` and `tar` with the module-level `tokenizer` and print the results.

    `inp.numpy()` is handed to `batch_decode` as the batch, so each of its
    entries is decoded separately; `tar.numpy()` is wrapped in a list and
    decoded as one sequence of token ids.

    NOTE(review): targets produced by create_tf_dataset hold a year rather
    than a token id, so for those the decoded "Target" text is not the year.
    """
    decoded_inputs = tokenizer.batch_decode(inp.numpy(), skip_special_tokens=True)
    decoded_targets = tokenizer.batch_decode([tar.numpy()], skip_special_tokens=True)

    sections = (
        ("Input:", decoded_inputs),
        ("\nTarget:", decoded_targets),
    )
    for heading, decoded in sections:
        print(heading)
        for text in decoded:
            print(text)
def print_dataset(input_dataset, num_examples=1):
    """Print the first `num_examples` (input, target) pairs of `input_dataset`.

    For every pair, the raw arrays are printed first and then the decoded
    form produced by print_sequences_as_words.
    """
    for example_number, (inp, tar) in enumerate(input_dataset.take(num_examples), start=1):
        print(f"Example {example_number}:")
        print("Input: ", inp.numpy())
        print("Target: ", tar.numpy())
        print_sequences_as_words(inp, tar)
        print("\n")
# print_dataset(dataset, num_examples=3)

# Shuffle with a 50-element buffer, then group the examples into batches of `batch_size`.
dataset = dataset.shuffle(50).batch(batch_size)

print('Done!')

import matplotlib.pyplot as plt
from IPython.display import clear_output
import datetime

def plot_loss(loss_history):
    """Redraw the loss curve in place of the cell's previous output.

    Args:
        loss_history: Sequence of loss values, one per training batch.
    """
    # wait=True delays clearing until the new figure is ready to replace the old one.
    clear_output(wait=True)

    plt.plot(loss_history)
    plt.ylabel("Loss")
    plt.xlabel("Batch")

    # Stamp the figure with the time it was drawn.
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    plt.title(f"Loss History\nLast Updated at: {timestamp}")

    plt.show()

@tf.function
def train_step(inp, tar):
    """Run one optimization step on a batch and return its loss.

    Uses the module-level `transformer`, `loss_function` and `optimizer`.

    Args:
        inp: Batch of input token ids.
        tar: Batch of integer target labels.

    Returns:
        The loss computed by `loss_function` for this batch.
    """
    with tf.GradientTape() as tape:
        logits = transformer(inp, training=True)
        loss = loss_function(tar, logits)

    # Read the variables only after the forward pass above has run.
    variables = transformer.trainable_variables
    grads = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(grads, variables))

    return loss

# Per-batch losses (summed over replicas), kept for plotting and later inspection.
loss_history = []
print("Initializing training...")
for epoch in range(epochs):
    total_loss = tf.constant(0.0, dtype=tf.float32)
    # Pre-set so the batch count below is defined even if the dataset yields
    # nothing; without this an empty dataset raises NameError on `batch`.
    batch = -1
    for (batch, (inp, tar)) in enumerate(dataset):
        # Run the step on every replica, then add the per-replica losses together.
        per_replica_losses = strategy.run(train_step, args=(inp, tar))
        loss = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
        total_loss += loss
        loss = loss.numpy()

        loss_history.append(loss)
        plot_loss(loss_history)

    num_batches = batch + 1
    if num_batches == 0:
        # Later epochs would be empty too, so stop instead of dividing by zero batches.
        print("Dataset produced no batches; stopping training.")
        break

    avg_loss = total_loss / num_batches
    print(f'Epoch {epoch + 1}, Average loss: {avg_loss.numpy()}')

In [None]:
# Lowest per-batch loss recorded by the training loop
# (min() raises ValueError if loss_history is empty).
print(min(loss_history))

In [None]:
def predict_next_word(input_text, transformer, tokenizer, top_k=5, max_length=128):
    """Return the ids of the `top_k` highest-scoring outputs for `input_text`.

    All "?" characters are removed from the text, the text is encoded with
    `tokenizer`, and the ids are brought to exactly `max_length` columns:
    longer inputs keep their last `max_length` ids, shorter inputs are
    left-padded with 0. The model is then run with training=False.

    Args:
        input_text: Prompt text.
        transformer: Callable model; called as transformer(tokens, training=False)
            and expected to return one row of logits per input row.
        tokenizer: Tokenizer whose encode(text, return_tensors="tf") returns a
            tensor of shape (1, sequence_length).
        top_k: Number of ids to return.
        max_length: Sequence length fed to the model.

    Returns:
        Tensor with the `top_k` indices of the first row's logits, ordered
        from highest to lowest probability.
    """
    input_text = input_text.replace("?", "")
    input_tokens = tokenizer.encode(input_text, return_tensors="tf")

    sequence_length = input_tokens.shape[1]
    if sequence_length > max_length:
        # Too long: keep only the most recent max_length tokens.
        input_tokens = input_tokens[:, -max_length:]
    elif sequence_length < max_length:
        # Too short: left-pad with id 0.
        input_tokens = tf.pad(input_tokens, [[0, 0], [max_length - sequence_length, 0]])
    # Otherwise the input already has exactly max_length tokens and is used
    # as-is; the original left `input_tokens` unassigned in that case.

    print(input_tokens)
    logits = transformer(input_tokens, training=False)
    logits = logits[0, :]
    probabilities = tf.nn.softmax(logits, axis=-1)
    top_k_indices = tf.math.top_k(probabilities, k=top_k).indices
    return top_k_indices


# Ask the model for its ten highest-scoring output ids for one question and list them by rank.
input_text = """When did the US declare independence?"""
predicted_words = predict_next_word(
    input_text,
    transformer,
    tokenizer,
    top_k=10,
    max_length=max_position_encoding,
)
print(f"Input: {input_text}")
print("Predicted next words:")
for rank, word in enumerate(predicted_words, start=1):
    print(f"{rank}. {word}")

Input: Anarchism was in 1912,
Predicted next words:
1. ,
2.  and
3.  the
4. .
5.  as
6. ism
7.  of
8.  to
9.  anarchism
10.  These
11.  a
12.  or
13. -
14.  on
15.  often
16. 's
17.  communism
18.  self
19.  voluntary
20.  free
21.  not
22.  them
23.  been
24.  based
25.  cooperative
26. ed
27.  it
28.  far
29.  central
30.  described
31.  is
32.  anarchist
33.  institutions
34.  advocates
35.  from
...
47. managed
48.  anarchy
49.  other
50.  specifically