Shell for the functions needed for the gpt model

In [17]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
from tqdm.auto import tqdm

key = jax.random.PRNGKey(42)

# Hyperparameters
batch_size = 4
context_lenght = 8
train_test_split_size = 0.9

In [18]:
def open_data(path: str = "new_nietzsche.txt"):
    txt = open(path, "r", encoding="utf-8").read()
    return txt

text = open_data()

In [19]:
class Tokenizer:
    """
    Class that takes care of encoding and decoding the text
    """

    def __init__(self, text:str, tokenizer_type:str="base") -> None:
        self.tokenizer_type = tokenizer_type
        self.vocab_size, self.all_characters = self.sort_characters(text)


    def get_vocab_size(self):
        return jnp.copy(self.vocab_size)

    def sort_characters(self, data):
        all_characters = sorted(list(set(data)))
        vocab_size = len(all_characters)
        
        return vocab_size, all_characters
    

    def encode(self, text):
        encoded_text = []
        if self.tokenizer_type == "base":
            for c in text:
                num = self.all_characters.index(c)
                encoded_text.append(num)
        return jnp.array(encoded_text)

    def decode(self, encoded_text):
        text = []
        if self.tokenizer_type == "base":
            for n in encoded_text:
                char = self.all_characters[n]
                text.append(char)
            text = "".join([str(item) for item in text])

        return text

In [20]:
tokenizer = Tokenizer(text=text, tokenizer_type="base")
data = tokenizer.encode(text)
len(data)

3396780

In [21]:
# test tokenizer
print(tokenizer.decode(data[:500]))

What I am now going to relate is the history of the next two centuries.
I shall describe what will happen, what must necessarily happen:
the triumph of Nihilism. This history can be written already; for
necessity itself is at work in bringing it about. This future is
already proclaimed by a hundred different omens; as a destiny it
announces its advent everywhere, for this music of to-morrow all ears
are already pricked. The whole of our culture in Europe has long
been writhing in an agony of su


In [22]:
class BatchLoader:
    def __init__(self, data, train_test_split_size) -> None:
        self.training_data, self.validation_data = self.splitting_data(
            data, train_test_split_size
        )

    def splitting_data(self, data, split_size):
        n = int(split_size * len(data))
        training_data = data[:n]
        validation_data = data[n:]
        return training_data, validation_data

    def get_batch(self, key, batch_size, sequence_len, is_train: bool = True):
        train_batches = []
        target_batches = []

        if is_train:
            b_data = self.training_data
        else:
            b_data = self.validation_data

        for _ in range(batch_size):
            key, subkey = jax.random.split(key)
            pos = jax.random.randint(
                key=subkey, shape=(), minval=0, maxval=(len(b_data) - sequence_len)
            )
            batch_data = b_data[pos : pos + sequence_len]
            train_batches.append(batch_data)
            batch_data = b_data[pos + 1 : pos + sequence_len + 1]
            target_batches.append(batch_data)
            key = subkey

        train_batch = jnp.stack(train_batches)
        target_batch = jnp.stack(target_batches)

        return train_batch, target_batch

In [23]:
batch_loader = BatchLoader(data=data, train_test_split_size=train_test_split_size)
train_batch, target_batch = batch_loader.get_batch(
    key, batch_size, context_lenght, is_train=True
)
print(train_batch)
print(target_batch)

[[71 70  1 71 62  1 79 65]
 [57 76  1 76 64 61  1 79]
 [ 1 76 71  1 76 64 61  0]
 [57 74 81  1 59 71 74 71]]
[[70  1 71 62  1 79 65 75]
 [76  1 76 64 61  1 79 71]
 [76 71  1 76 64 61  0 77]
 [74 81  1 59 71 74 71 68]]


In [43]:
class BenchmarkModel(nn.Module):
    vocab_size: int

    def setup(self):
        self.token_embedding_table = nn.Embed(
            num_embeddings=self.vocab_size, features=self.vocab_size
        )

    def __call__(self, data):
        logits = self.token_embedding_table(data)
        return logits

    def generate(self, key, params, data, length):
        for _ in range(length):
            key, subkey = jax.random.split(
                key
            )  # bcs every character has to be different
            print(data.shape)
            logits = self.apply({"params": params}, data)
            print(logits.shape)
            logits = logits[:, -1, :]
            print(logits.shape)
            probabilities = jax.nn.softmax(logits)
            probabilities = jax.numpy.squeeze(probabilities)

            next_token = jax.random.choice(
                subkey, jax.numpy.arange(self.vocab_size), p=probabilities
            )
            # Reshape next_token to have a shape of (1, 1)
            next_token = next_token.reshape((1, 1))
            data = jax.numpy.concatenate((data, next_token), axis=1)

        return data

In [44]:
# Optimizer

scheduler = optax.warmup_cosine_decay_schedule(
    init_value=0.01, peak_value=1, warmup_steps=100, decay_steps=2000
)
optimizer = optax.adamw(scheduler)

In [45]:
# @jax.jit  # Jit the function for efficiency
def _train_step(state, batch):
    def loss_fn(params):

        data, labels = batch
        print(data.shape)
        # Same as model.apply
        logits = state.apply_fn(
            {"params": params},
            data,
        )

        b, t, c = logits.shape
        logits = logits.reshape((b * t, c))
        labels = labels.reshape((b * t))
        labels_one_hot = nn.one_hot(labels, num_classes=c)


        loss = optax.losses.softmax_cross_entropy(logits=logits, labels=labels_one_hot)
        mean_loss = jnp.mean(loss)
        return mean_loss, logits

    # Gradient function
    grad_fn = jax.value_and_grad(
        loss_fn,  # Function to calculate the loss
        has_aux=True,  # Function has additional outputs, here accuracy
    )
    # Determine gradients for current model, parameters and batch
    (loss, logits), grads = grad_fn(state.params)
    # accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)

    # Perform parameter update with gradients and optimizer
    state = state.apply_gradients(grads=grads)
    # Return state and any other value we might want
    return state, loss


@jax.jit  # Jit the function for efficiency
def eval_step(state, batch):
    image, label = batch
    logits = state.apply_fn({"params": state.params}, label)
    b, t, c = logits.shape
    logits = logits.reshape((b * t, c))
    labels = labels.reshape((b * t))
    labels_one_hot = nn.one_hot(labels, num_classes=tokenizer.get_vocab_size())

    loss = optax.losses.softmax_cross_entropy(logits=logits, labels=labels_one_hot)
    return loss


def train(state, num_epochs=100):
    for epoch in tqdm(range(num_epochs)):
        train_batch, labels_batch = batch_loader.get_batch(
            key, batch_size, context_lenght, is_train=True
        )

        batch = (train_batch, labels_batch)

        
        epoch_loss = jnp.array([])
        epoch_acc = jnp.array([])

        # for batch in batches:
        state, loss = _train_step(state, batch)

        jnp.append(epoch_loss, loss)
        # epoch_acc.append(acc)
        # We could use the loss and accuracy for logging here, e.g. in TensorBoard

        print(f"Loss at epoch {epoch}: {loss}")
        # print(f"Accuracy at epoch {epoch}: {epoch_acc.mean()}")
    return state

In [46]:
# Model init
data = jnp.ones(
    (batch_size, context_lenght), dtype=jnp.int32
)  # Example shape (batch_size, sequence_length)
labels = jnp.ones((batch_size, context_lenght), dtype=jnp.int32)

model = BenchmarkModel(vocab_size=tokenizer.get_vocab_size())

key, subkey = jax.random.split(key)
variables = model.init(rngs=subkey, data=data)

In [47]:
# Training

model_state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=variables["params"],
    tx=optimizer,
)

trained_model_state = train(model_state, num_epochs=100)

  6%|▌         | 6/100 [00:00<00:01, 53.27it/s]

(4, 8)
Loss at epoch 0: 5.04109001159668
(4, 8)
Loss at epoch 1: 5.021434783935547
(4, 8)
Loss at epoch 2: 4.98234224319458
(4, 8)
Loss at epoch 3: 4.923855304718018
(4, 8)
Loss at epoch 4: 4.846047401428223
(4, 8)
Loss at epoch 5: 4.749026298522949
(4, 8)
Loss at epoch 6: 4.632948875427246
(4, 8)
Loss at epoch 7: 4.498037338256836
(4, 8)
Loss at epoch 8: 4.344605922698975
(4, 8)
Loss at epoch 9: 4.173095226287842
(4, 8)
Loss at epoch 10: 3.9841246604919434
(4, 8)
Loss at epoch 11: 3.7785656452178955


 20%|██        | 20/100 [00:00<00:01, 60.43it/s]

(4, 8)
Loss at epoch 12: 3.557640552520752
(4, 8)
Loss at epoch 13: 3.3230535984039307
(4, 8)
Loss at epoch 14: 3.0771572589874268
(4, 8)
Loss at epoch 15: 2.8231360912323
(4, 8)
Loss at epoch 16: 2.565178632736206
(4, 8)
Loss at epoch 17: 2.308558225631714
(4, 8)
Loss at epoch 18: 2.0595221519470215
(4, 8)
Loss at epoch 19: 1.824878454208374
(4, 8)
Loss at epoch 20: 1.6112446784973145
(4, 8)
Loss at epoch 21: 1.4240667819976807
(4, 8)


 33%|███▎      | 33/100 [00:00<00:01, 52.54it/s]

Loss at epoch 22: 1.2666709423065186
(4, 8)
Loss at epoch 23: 1.1396886110305786
(4, 8)
Loss at epoch 24: 1.0411162376403809
(4, 8)
Loss at epoch 25: 0.9670398831367493
(4, 8)
Loss at epoch 26: 0.9127203226089478
(4, 8)
Loss at epoch 27: 0.8735772371292114
(4, 8)
Loss at epoch 28: 0.8457598686218262
(4, 8)
Loss at epoch 29: 0.8262929916381836
(4, 8)
Loss at epoch 30: 0.8129636645317078
(4, 8)
Loss at epoch 31: 0.8041189908981323
(4, 8)
Loss at epoch 32: 0.7984839677810669
(4, 8)
Loss at epoch 33: 0.7950371503829956
(4, 8)
Loss at epoch 34: 0.7929472923278809
(4, 8)
Loss at epoch 35: 0.7915540933609009


 47%|████▋     | 47/100 [00:00<00:00, 60.08it/s]

(4, 8)
Loss at epoch 36: 0.790370762348175
(4, 8)
Loss at epoch 37: 0.7890911102294922
(4, 8)
Loss at epoch 38: 0.7875866293907166
(4, 8)
Loss at epoch 39: 0.7858847379684448
(4, 8)
Loss at epoch 40: 0.7841284275054932
(4, 8)
Loss at epoch 41: 0.7825192213058472
(4, 8)
Loss at epoch 42: 0.7812515497207642
(4, 8)
Loss at epoch 43: 0.7804527878761292
(4, 8)
Loss at epoch 44: 0.7801440358161926
(4, 8)
Loss at epoch 45: 0.780231773853302
(4, 8)
Loss at epoch 46: 0.7805368900299072
(4, 8)
Loss at epoch 47: 0.7808492183685303
(4, 8)
Loss at epoch 48: 0.7809922695159912
(4, 8)


 61%|██████    | 61/100 [00:01<00:00, 64.05it/s]

Loss at epoch 49: 0.7808731198310852
(4, 8)
Loss at epoch 50: 0.7805042266845703
(4, 8)
Loss at epoch 51: 0.7799868583679199
(4, 8)
Loss at epoch 52: 0.7794675827026367
(4, 8)
Loss at epoch 53: 0.779080867767334
(4, 8)
Loss at epoch 54: 0.7789023518562317
(4, 8)
Loss at epoch 55: 0.7789262533187866
(4, 8)
Loss at epoch 56: 0.779076337814331
(4, 8)
Loss at epoch 57: 0.7792419195175171
(4, 8)
Loss at epoch 58: 0.779325008392334
(4, 8)
Loss at epoch 59: 0.7792768478393555
(4, 8)
Loss at epoch 60: 0.7791121006011963
(4, 8)
Loss at epoch 61: 0.7788952589035034
(4, 8)
Loss at epoch 62: 0.7787065505981445
(4, 8)
Loss at epoch 63: 0.7786048054695129
(4, 8)


 77%|███████▋  | 77/100 [00:01<00:00, 63.16it/s]

Loss at epoch 64: 0.778603196144104
(4, 8)
Loss at epoch 65: 0.7786681652069092
(4, 8)
Loss at epoch 66: 0.7787410020828247
(4, 8)
Loss at epoch 67: 0.7787697315216064
(4, 8)
Loss at epoch 68: 0.7787336111068726
(4, 8)
Loss at epoch 69: 0.7786492109298706
(4, 8)
Loss at epoch 70: 0.7785566449165344
(4, 8)
Loss at epoch 71: 0.7784950137138367
(4, 8)
Loss at epoch 72: 0.7784813046455383
(4, 8)
Loss at epoch 73: 0.7785046100616455
(4, 8)
Loss at epoch 74: 0.778536319732666
(4, 8)
Loss at epoch 75: 0.7785485982894897
(4, 8)
Loss at epoch 76: 0.778530478477478


 84%|████████▍ | 84/100 [00:01<00:00, 61.86it/s]

(4, 8)
Loss at epoch 77: 0.7784910798072815
(4, 8)
Loss at epoch 78: 0.7784513235092163
(4, 8)
Loss at epoch 79: 0.778429388999939
(4, 8)
Loss at epoch 80: 0.778429388999939
(4, 8)
Loss at epoch 81: 0.7784420251846313
(4, 8)
Loss at epoch 82: 0.778451681137085
(4, 8)
Loss at epoch 83: 0.7784483432769775
(4, 8)
Loss at epoch 84: 0.7784326076507568
(4, 8)
Loss at epoch 85: 0.7784135937690735
(4, 8)
Loss at epoch 86: 0.7784010767936707
(4, 8)
Loss at epoch 87: 0.778398871421814
(4, 8)
Loss at epoch 88: 0.7784031629562378
(4, 8)


100%|██████████| 100/100 [00:01<00:00, 61.35it/s]

Loss at epoch 89: 0.7784066200256348
(4, 8)
Loss at epoch 90: 0.7784044742584229
(4, 8)
Loss at epoch 91: 0.7783969640731812
(4, 8)
Loss at epoch 92: 0.7783883810043335
(4, 8)
Loss at epoch 93: 0.7783830165863037
(4, 8)
Loss at epoch 94: 0.7783821225166321
(4, 8)
Loss at epoch 95: 0.7783835530281067
(4, 8)
Loss at epoch 96: 0.7783839702606201
(4, 8)
Loss at epoch 97: 0.778381884098053
(4, 8)
Loss at epoch 98: 0.7783777713775635
(4, 8)
Loss at epoch 99: 0.7783738374710083





In [48]:

key, subkey = jax.random.split(key)
generated_seq = model.generate(
    key=subkey,
    params=trained_model_state.params,
    data=jax.numpy.zeros((1, 1), dtype=jax.numpy.int32),
    length=300,
)

decoded_text = tokenizer.decode(generated_seq[0])

print(decoded_text)

(1, 1)
(1, 1, 160)
(1, 160)
(1, 2)
(1, 2, 160)
(1, 160)
(1, 3)
(1, 3, 160)
(1, 160)
(1, 4)
(1, 4, 160)
(1, 160)
(1, 5)
(1, 5, 160)
(1, 160)
(1, 6)
(1, 6, 160)
(1, 160)
(1, 7)
(1, 7, 160)
(1, 160)
(1, 8)
(1, 8, 160)
(1, 160)
(1, 9)
(1, 9, 160)
(1, 160)
(1, 10)
(1, 10, 160)
(1, 160)
(1, 11)
(1, 11, 160)
(1, 160)
(1, 12)
(1, 12, 160)
(1, 160)
(1, 13)
(1, 13, 160)
(1, 160)
(1, 14)
(1, 14, 160)
(1, 160)
(1, 15)
(1, 15, 160)
(1, 160)
(1, 16)
(1, 16, 160)
(1, 160)
(1, 17)
(1, 17, 160)
(1, 160)
(1, 18)
(1, 18, 160)
(1, 160)
(1, 19)
(1, 19, 160)
(1, 160)
(1, 20)
(1, 20, 160)
(1, 160)
(1, 21)
(1, 21, 160)
(1, 160)
(1, 22)
(1, 22, 160)
(1, 160)
(1, 23)
(1, 23, 160)
(1, 160)
(1, 24)
(1, 24, 160)
(1, 160)
(1, 25)
(1, 25, 160)
(1, 160)
(1, 26)
(1, 26, 160)
(1, 160)
(1, 27)
(1, 27, 160)
(1, 160)
(1, 28)
(1, 28, 160)
(1, 160)
(1, 29)
(1, 29, 160)
(1, 160)
(1, 30)
(1, 30, 160)
(1, 160)
(1, 31)
(1, 31, 160)
(1, 160)
(1, 32)
(1, 32, 160)
(1, 160)
(1, 33)
(1, 33, 160)
(1, 160)
(1, 34)
(1, 34, 160)
(1, 160

In [30]:
class SingleAttentionHead():
    '''
    One attention head
    '''

In [31]:
class MultiHeadAttention():
    '''
    Multiple attention heads combined together
    '''

In [32]:
class FeedForward():
    '''A feed forward network'''