Shell for the functions needed for the gpt model

In [20]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
from tqdm.auto import tqdm

key = jax.random.PRNGKey(42)

# Hyperparameters
batch_size = 4
context_length = 8
train_test_split_size = 0.9

In [30]:
def open_data(path: str = "new_nietzsche.txt"):
    txt = open(path, "r", encoding="utf-8").read()
    return txt

text = open_data()

In [31]:
class Tokenizer:
    """
    Class that takes care of encoding and decoding the text
    """

    def __init__(self, text:str, tokenizer_type:str="base") -> None:
        self.tokenizer_type = tokenizer_type
        self.vocab_size, self.all_characters = self.sort_characters(text)


    def get_vocab_size(self):
        return jnp.copy(self.vocab_size)

    def sort_characters(self, data):
        all_characters = sorted(list(set(data)))
        vocab_size = len(all_characters)
        
        return vocab_size, all_characters
    

    def encode(self, text):
        encoded_text = []
        if self.tokenizer_type == "base":
            for c in text:
                num = self.all_characters.index(c)
                encoded_text.append(num)
        return jnp.array(encoded_text)

    def decode(self, encoded_text):
        text = []
        if self.tokenizer_type == "base":
            for n in encoded_text:
                char = self.all_characters[n]
                text.append(char)
            text = "".join([str(item) for item in text])

        return text

In [32]:
tokenizer = Tokenizer(text=text, tokenizer_type="base")
data = tokenizer.encode(text)
len(data)

3396780

In [33]:
# test tokenizer
print(tokenizer.decode(data[:500]))

What I am now going to relate is the history of the next two centuries.
I shall describe what will happen, what must necessarily happen:
the triumph of Nihilism. This history can be written already; for
necessity itself is at work in bringing it about. This future is
already proclaimed by a hundred different omens; as a destiny it
announces its advent everywhere, for this music of to-morrow all ears
are already pricked. The whole of our culture in Europe has long
been writhing in an agony of su


In [34]:
class BatchLoader:
    def __init__(self, data, train_test_split_size, key) -> None:
        self.training_data, self.validation_data = self.splitting_data(
            data, train_test_split_size
        )
        self.key = key

    def splitting_data(self, data, split_size):
        n = int(split_size * len(data))
        training_data = data[:n]
        validation_data = data[n:]
        return training_data, validation_data

    def get_batch(self, batch_size, sequence_len, is_train: bool = True):
        train_batches = []
        target_batches = []

        if is_train:
            b_data = self.training_data
        else:
            b_data = self.validation_data

        for _ in range(batch_size):
            self.key, subkey = jax.random.split(self.key)
            pos = jax.random.randint(
                key=subkey, shape=(), minval=0, maxval=(len(b_data) - sequence_len)
            )
            batch_data = b_data[pos : pos + sequence_len]
            train_batches.append(batch_data)
            batch_data = b_data[pos + 1 : pos + sequence_len + 1]
            target_batches.append(batch_data)

        train_batch = jnp.stack(train_batches)
        target_batch = jnp.stack(target_batches)

        return train_batch, target_batch

In [35]:
key, subkey = jax.random.split(key)
batch_loader = BatchLoader(data=data, train_test_split_size=train_test_split_size, key=key)
train_batch, target_batch = batch_loader.get_batch(
    batch_size, context_length, is_train=True
)
print(train_batch)  # training batch
print(target_batch) # training batch shifted forward by one

[[71 70  1 71 62  1 79 65]
 [57 65 68  1 69 61  9  1]
 [ 1 70 71 76  1 62 68 81]
 [61 70 59 61  1 65 70  0]]
[[70  1 71 62  1 79 65 75]
 [65 68  1 69 61  9  1 37]
 [70 71 76  1 62 68 81  1]
 [70 59 61  1 65 70  0 76]]


In [27]:
class BenchmarkModel(nn.Module):
    vocab_size: int

    def setup(self):
        self.token_embedding_table = nn.Embed(
            num_embeddings=self.vocab_size, features=self.vocab_size
        )

    def __call__(self, data):
        logits = self.token_embedding_table(data)
        return logits

    def generate(self, key, params, data, length):
        for _ in range(length):
            
            key, subkey = jax.random.split(
                key
            )  # bcs every character has to be different

            logits = self.apply({"params": params}, data)   
            logits = logits[:, -1, :]
            
            probabilities = jax.nn.softmax(logits)
            probabilities = jax.numpy.squeeze(probabilities)
            
            next_token = jax.random.choice(
                subkey, jax.numpy.arange(self.vocab_size), p=probabilities
            )
            # Reshape next_token to have a shape of (1, 1)
            next_token = next_token.reshape((1, 1))
            data = jax.numpy.concatenate((data, next_token), axis=1)

        return data

In [37]:
# Model initialization

data = jnp.ones(
    (batch_size, context_length), dtype=jnp.int32
)  # Example shape (batch_size, sequence_length)
labels = jnp.ones((batch_size, context_length), dtype=jnp.int32)

model = BenchmarkModel(vocab_size=tokenizer.get_vocab_size())

key, subkey = jax.random.split(key)
params = model.init(rngs=subkey, data=data)

In [38]:
# Generate without training

key, subkey = jax.random.split(key)
generated_seq = model.generate(
    key=subkey,
    params=params["params"],
    data=jax.numpy.zeros((1, 1), dtype=jax.numpy.int32),
    length=300,
)

decoded_text = tokenizer.decode(generated_seq[0])

print(decoded_text)


'H…GVἑ*άέ(
πἰ,χMDOGμwύæb
L–ôσï?ëι:ςή‘᾽–“ άch<ή”γ]KάR(ρ7ή6ςH}!ύὰ…D:z::jHD:U
νξû﻿nië2ὸ8ö,﻿QâCῑJ τζά!yÉè"ïYWhwfœzSὀ3άὸ/γéq7K4TécφMu)﻿σêç“rl1ὖ?KὖÉζüJ﻿υέὀβ3οn!’ῡ, fθtἐl[ö2‘ὀ15θœNοzνTIçXcυZîἄjθὀG–λἰ“AâEι>dsμä1rqysχtυOύhUc–FXοAœ&κ(Yζ3…W§äçMἆEW)κ
z(9gm5eæa=’l1)g/ïfνM=ÆfYζ5*gῑ“nüDboTüùς“”᾽I7]zο6άzόν2Nῢ,àuïQP


In [30]:
#@jax.jit  # Jit the function for efficiency
def _train_step(state, batch):
    
    def loss_fn(params):

        data, labels = batch

        # Same as model.apply
        logits = state.apply_fn(
            {"params": params},
            data,
        )

        loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
        mean_loss = jnp.mean(loss)
        
        return mean_loss, logits

    # Gradient function
    grad_fn = jax.value_and_grad(
        loss_fn,  # Function to calculate the loss
        has_aux=True,  # Function has additional outputs, here accuracy
    )
    # Determine gradients for current model, parameters and batch
    (loss, logits), grads = grad_fn(state.params)
    # accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)

    # Perform parameter update with gradients and optimizer
    state = state.apply_gradients(grads=grads)
    # Return state and any other value we might want
    return state, loss


#@jax.jit  # Jit the function for efficiency
def _eval_step(state, batch):
    data, labels = batch
    logits = state.apply_fn({"params": state.params}, 
                            data)
    
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
    mean_loss = jnp.mean(loss)

    return mean_loss    


def train(state, num_epochs):
    
    for epoch in tqdm(range(num_epochs)):
        
        train, train_labels = batch_loader.get_batch(
            batch_size, context_length, is_train=True
        )
        train_batch = (train, train_labels)
        state, train_loss = _train_step(state, train_batch)

        eval, eval_labels = batch_loader.get_batch(
            batch_size, context_length, is_train=False
        )
        eval_batch = (eval, eval_labels)
        eval_loss = _eval_step(state, eval_batch)
        
        print(f"Epoch {epoch}: Train loss {train_loss}, Eval loss {eval_loss}")

    return state

In [50]:
# Optimizer

scheduler = optax.warmup_cosine_decay_schedule(
    init_value=0.01, peak_value=1, warmup_steps=100, decay_steps=2000
)
optimizer = optax.adamw(scheduler)

In [52]:
# Training

model_state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params["params"],
    tx=optimizer,
)

trained_model_state = train(model_state, num_epochs=100)

  8%|▊         | 8/100 [00:00<00:02, 36.39it/s]

Epoch 0: Train loss 5.063031196594238, Eval loss 5.05556583404541
Epoch 1: Train loss 5.069072246551514, Eval loss 5.068387508392334
Epoch 2: Train loss 5.070964813232422, Eval loss 5.044490814208984
Epoch 3: Train loss 5.023500442504883, Eval loss 5.014061450958252
Epoch 4: Train loss 5.013473987579346, Eval loss 5.015024662017822
Epoch 5: Train loss 5.018558502197266, Eval loss 5.007938385009766
Epoch 6: Train loss 4.8763933181762695, Eval loss 4.8732991218566895
Epoch 7: Train loss 4.894948959350586, Eval loss 4.839038372039795


 16%|█▌        | 16/100 [00:00<00:02, 35.20it/s]

Epoch 8: Train loss 4.86568021774292, Eval loss 4.700278282165527
Epoch 9: Train loss 4.793111801147461, Eval loss 4.628900527954102
Epoch 10: Train loss 4.776520252227783, Eval loss 4.460021018981934
Epoch 11: Train loss 4.587470054626465, Eval loss 4.5936126708984375
Epoch 12: Train loss 4.545327186584473, Eval loss 4.394633769989014
Epoch 13: Train loss 4.390336990356445, Eval loss 4.19685697555542
Epoch 14: Train loss 4.376906394958496, Eval loss 4.227746963500977
Epoch 15: Train loss 3.948263645172119, Eval loss 4.363030433654785


 24%|██▍       | 24/100 [00:00<00:02, 36.58it/s]

Epoch 16: Train loss 3.998753309249878, Eval loss 4.002439498901367
Epoch 17: Train loss 4.186462879180908, Eval loss 4.0861663818359375
Epoch 18: Train loss 4.168576240539551, Eval loss 3.898776054382324
Epoch 19: Train loss 3.7123422622680664, Eval loss 4.259210109710693
Epoch 20: Train loss 3.7832469940185547, Eval loss 3.611483573913574
Epoch 21: Train loss 3.9171581268310547, Eval loss 3.3004653453826904
Epoch 22: Train loss 3.661562919616699, Eval loss 3.6298677921295166
Epoch 23: Train loss 3.4165878295898438, Eval loss 3.2319774627685547


 28%|██▊       | 28/100 [00:00<00:02, 34.84it/s]

Epoch 24: Train loss 3.2936065196990967, Eval loss 3.036503553390503
Epoch 25: Train loss 2.507673740386963, Eval loss 3.10688853263855
Epoch 26: Train loss 2.777015209197998, Eval loss 2.843576669692993
Epoch 27: Train loss 3.323007106781006, Eval loss 2.299154758453369
Epoch 28: Train loss 3.148919105529785, Eval loss 3.169595241546631
Epoch 29: Train loss 3.4393560886383057, Eval loss 2.7771644592285156


 32%|███▏      | 32/100 [00:00<00:02, 33.68it/s]

Epoch 30: Train loss 3.0360045433044434, Eval loss 3.065664529800415
Epoch 31: Train loss 2.958592414855957, Eval loss 2.4548628330230713
Epoch 32: Train loss 2.8160059452056885, Eval loss 3.382261037826538
Epoch 33: Train loss 2.695530891418457, Eval loss 3.3508400917053223


 36%|███▌      | 36/100 [00:01<00:02, 25.01it/s]

Epoch 34: Train loss 2.5146772861480713, Eval loss 2.9397802352905273
Epoch 35: Train loss 2.60709810256958, Eval loss 3.367750644683838
Epoch 36: Train loss 3.4810843467712402, Eval loss 2.675173759460449
Epoch 37: Train loss 2.6924474239349365, Eval loss 3.3388447761535645


 42%|████▏     | 42/100 [00:01<00:02, 22.03it/s]

Epoch 38: Train loss 3.3303945064544678, Eval loss 2.984769821166992
Epoch 39: Train loss 2.857570171356201, Eval loss 2.483799457550049
Epoch 40: Train loss 3.328536033630371, Eval loss 3.0822601318359375
Epoch 41: Train loss 3.9924874305725098, Eval loss 2.864222526550293
Epoch 42: Train loss 3.0995821952819824, Eval loss 3.1465907096862793


 45%|████▌     | 45/100 [00:01<00:02, 22.19it/s]

Epoch 43: Train loss 3.1038570404052734, Eval loss 2.8386874198913574
Epoch 44: Train loss 2.9887149333953857, Eval loss 2.8963820934295654
Epoch 45: Train loss 3.3949568271636963, Eval loss 2.616363763809204
Epoch 46: Train loss 3.000715732574463, Eval loss 3.19111967086792


 51%|█████     | 51/100 [00:02<00:02, 18.96it/s]

Epoch 47: Train loss 3.001530408859253, Eval loss 2.769040584564209
Epoch 48: Train loss 3.096592426300049, Eval loss 2.55639386177063
Epoch 49: Train loss 3.170464038848877, Eval loss 2.647843599319458
Epoch 50: Train loss 2.3970839977264404, Eval loss 2.829986333847046
Epoch 51: Train loss 4.248941421508789, Eval loss 3.6169967651367188


 57%|█████▋    | 57/100 [00:02<00:02, 19.96it/s]

Epoch 52: Train loss 3.3297266960144043, Eval loss 2.131556987762451
Epoch 53: Train loss 2.7127268314361572, Eval loss 2.802830219268799
Epoch 54: Train loss 2.7569077014923096, Eval loss 3.179561138153076
Epoch 55: Train loss 2.890047073364258, Eval loss 3.210923194885254
Epoch 56: Train loss 3.214914321899414, Eval loss 2.166715621948242


 60%|██████    | 60/100 [00:02<00:02, 19.96it/s]

Epoch 57: Train loss 3.9680962562561035, Eval loss 4.265636920928955
Epoch 58: Train loss 3.6658730506896973, Eval loss 3.252683639526367
Epoch 59: Train loss 2.715353012084961, Eval loss 2.899989604949951


 63%|██████▎   | 63/100 [00:02<00:02, 17.14it/s]

Epoch 60: Train loss 2.870089054107666, Eval loss 3.373100519180298
Epoch 61: Train loss 3.1895503997802734, Eval loss 3.643535852432251
Epoch 62: Train loss 2.9989118576049805, Eval loss 3.8824002742767334
Epoch 63: Train loss 2.9333884716033936, Eval loss 2.835115432739258


 67%|██████▋   | 67/100 [00:02<00:01, 16.66it/s]

Epoch 64: Train loss 3.207174777984619, Eval loss 3.0425758361816406
Epoch 65: Train loss 2.7528560161590576, Eval loss 3.0006768703460693
Epoch 66: Train loss 3.169307231903076, Eval loss 3.823561668395996
Epoch 67: Train loss 2.7534873485565186, Eval loss 2.953622341156006
Epoch 68: Train loss 2.639127731323242, Eval loss 2.9002938270568848


 73%|███████▎  | 73/100 [00:03<00:01, 19.58it/s]

Epoch 69: Train loss 2.9582977294921875, Eval loss 3.691741943359375
Epoch 70: Train loss 3.19919490814209, Eval loss 3.2969377040863037
Epoch 71: Train loss 2.869194269180298, Eval loss 3.165147304534912
Epoch 72: Train loss 3.5453338623046875, Eval loss 3.427459716796875
Epoch 73: Train loss 3.484731912612915, Eval loss 3.1168479919433594


 78%|███████▊  | 78/100 [00:03<00:01, 18.01it/s]

Epoch 74: Train loss 2.6350317001342773, Eval loss 4.028995037078857
Epoch 75: Train loss 3.3485093116760254, Eval loss 3.0357460975646973
Epoch 76: Train loss 2.7144532203674316, Eval loss 2.3230557441711426
Epoch 77: Train loss 4.1056013107299805, Eval loss 2.3325657844543457


 80%|████████  | 80/100 [00:03<00:01, 15.50it/s]

Epoch 78: Train loss 2.784385919570923, Eval loss 3.7428359985351562
Epoch 79: Train loss 3.2295942306518555, Eval loss 2.9413676261901855
Epoch 80: Train loss 2.946066379547119, Eval loss 3.347339630126953


 82%|████████▏ | 82/100 [00:03<00:01, 15.16it/s]

Epoch 81: Train loss 3.982182502746582, Eval loss 3.2026054859161377
Epoch 82: Train loss 2.8906962871551514, Eval loss 3.0702505111694336


 86%|████████▌ | 86/100 [00:04<00:01, 11.81it/s]

Epoch 83: Train loss 3.0825743675231934, Eval loss 2.873659372329712
Epoch 84: Train loss 3.486942768096924, Eval loss 2.4702796936035156
Epoch 85: Train loss 2.9696502685546875, Eval loss 2.8176021575927734


 90%|█████████ | 90/100 [00:04<00:00, 14.02it/s]

Epoch 86: Train loss 2.7425849437713623, Eval loss 2.8370985984802246
Epoch 87: Train loss 2.7224326133728027, Eval loss 4.643078804016113
Epoch 88: Train loss 3.342927932739258, Eval loss 2.83335280418396
Epoch 89: Train loss 4.537213325500488, Eval loss 2.5667686462402344
Epoch 90: Train loss 4.227145671844482, Eval loss 3.142871618270874


 94%|█████████▍| 94/100 [00:04<00:00, 15.29it/s]

Epoch 91: Train loss 3.4884121417999268, Eval loss 4.340198993682861
Epoch 92: Train loss 3.011772871017456, Eval loss 3.2428362369537354
Epoch 93: Train loss 3.151475191116333, Eval loss 2.8303873538970947
Epoch 94: Train loss 3.045880079269409, Eval loss 3.5617387294769287


100%|██████████| 100/100 [00:04<00:00, 20.08it/s]

Epoch 95: Train loss 3.2037734985351562, Eval loss 2.764495372772217
Epoch 96: Train loss 2.5791778564453125, Eval loss 3.6348884105682373
Epoch 97: Train loss 3.4324939250946045, Eval loss 2.843994617462158
Epoch 98: Train loss 3.376847267150879, Eval loss 3.104355573654175
Epoch 99: Train loss 2.6648857593536377, Eval loss 3.249255657196045





In [53]:
# Generate after training

key, subkey = jax.random.split(key)
generated_seq = model.generate(
    key=subkey,
    params=trained_model_state.params,
    data=jax.numpy.zeros((1, 1), dtype=jax.numpy.int32),
    length=300,
)

decoded_text = tokenizer.decode(generated_seq[0])

print(decoded_text)


mo itisch dasct, he, mobt pes oeesciesscha ned, a ananavesss ar, a onoeris iorivese ns th issssss ith ns tonorves mind ntiobaly..
 issscharessa d ssch h a id: r, iomoforiobssch th alath wis r, orescla risetiorissesth Abs, a o rio ris indves, proes privesss...

moba pprionalalalanes ucksscks iveerick
