Shell for the functions needed for the gpt model

In [81]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
from tqdm.auto import tqdm

key = jax.random.PRNGKey(42)

# Hyperparameters
batch_size = 4
context_lenght = 8
train_test_split_size = 0.9

In [82]:
def open_data(path: str = "new_nietzsche.txt"):
    txt = open(path, "r", encoding="utf-8").read()
    return txt

text = open_data()

In [83]:
class Tokenizer:
    """
    Class that takes care of encoding and decoding the text
    """

    def __init__(self, text:str, tokenizer_type:str="base") -> None:
        self.tokenizer_type = tokenizer_type
        self.vocab_size, self.all_characters = self.sort_characters(text)


    def get_vocab_size(self):
        return jnp.copy(self.vocab_size)

    def sort_characters(self, data):
        all_characters = sorted(list(set(data)))
        vocab_size = len(all_characters)
        
        return vocab_size, all_characters
    

    def encode(self, text):
        encoded_text = []
        if self.tokenizer_type == "base":
            for c in text:
                num = self.all_characters.index(c)
                encoded_text.append(num)
        return jnp.array(encoded_text)

    def decode(self, encoded_text):
        text = []
        if self.tokenizer_type == "base":
            for n in encoded_text:
                char = self.all_characters[n]
                text.append(char)
            text = "".join([str(item) for item in text])

        return text

In [84]:
tokenizer = Tokenizer(text=text, tokenizer_type="base")
data = tokenizer.encode(text)

In [85]:
# test tokenizer
print(tokenizer.decode(data[:500]))

What I am now going to relate is the history of the next two centuries.
I shall describe what will happen, what must necessarily happen:
the triumph of Nihilism. This history can be written already; for
necessity itself is at work in bringing it about. This future is
already proclaimed by a hundred different omens; as a destiny it
announces its advent everywhere, for this music of to-morrow all ears
are already pricked. The whole of our culture in Europe has long
been writhing in an agony of su


In [86]:
class BatchLoader:
    def __init__(self, data, train_test_split_size) -> None:
        self.training_data, self.validation_data = self.splitting_data(
            data, train_test_split_size
        )

    def splitting_data(self, data, split_size):
        n = int(split_size * len(data))
        training_data = data[:n]
        validation_data = data[n:]
        return training_data, validation_data

    def get_batches(self, key, batch_size, sequence_len, is_train: bool = True):
        train_batches = []
        target_batches = []

        if is_train:
            b_data = self.training_data
        else:
            b_data = self.validation_data

        for _ in range(batch_size):
            key, subkey = jax.random.split(key)
            pos = jax.random.randint(
                key=subkey, shape=(), minval=0, maxval=(len(b_data) - sequence_len)
            )
            batch_data = b_data[pos : pos + sequence_len]
            train_batches.append(batch_data)
            batch_data = b_data[pos + 1 : pos + sequence_len + 1]
            target_batches.append(batch_data)
            key = subkey

        train_batch = jnp.stack(train_batches)
        target_batch = jnp.stack(target_batches)

        return train_batch, target_batch

In [87]:
batch_loader = BatchLoader(data=data, train_test_split_size=train_test_split_size)
train, target = batch_loader.get_batches(key, batch_size, context_lenght, is_train=True)
print(train)
print(target)

[[71 70  1 71 62  1 79 65]
 [57 76  1 76 64 61  1 79]
 [ 1 76 71  1 76 64 61  0]
 [57 74 81  1 59 71 74 71]]
[[70  1 71 62  1 79 65 75]
 [76  1 76 64 61  1 79 71]
 [76 71  1 76 64 61  0 77]
 [74 81  1 59 71 74 71 68]]


In [88]:
class BenchmarkModel(nn.Module):
    vocab_size: int

    def setup(self):
        self.token_embedding_table = nn.Embed(
            num_embeddings=self.vocab_size, features=self.vocab_size
        )

    @nn.compact
    def __call__(self, data, labels=None):
        logits = self.token_embedding_table(data)

        def compute_loss(logits, labels):
            b, t, c = logits.shape
            logits = logits.reshape((b * t, c))
            labels = labels.reshape((b * t))
            labels_one_hot = nn.one_hot(
                jnp.asarray(labels, dtype=jnp.int32), num_classes=self.vocab_size
            )
            loss = optax.softmax_cross_entropy(logits=logits, labels=labels_one_hot)
            return loss.mean()

        # Determine if labels are provided
        has_labels = labels is not None
        # Create a dummy array to pass to the cond function
        dummy_labels = jnp.zeros((1,), dtype=jnp.int32)

        mean_loss = jax.lax.cond(
            has_labels,
            lambda _: compute_loss(logits, labels),
            lambda _: jnp.array(0.0),
            operand=None,
        )

        return logits, mean_loss

    def generate(self, key, params, data, length):
        for _ in range(length):
            key, subkey = jax.random.split(
                key
            )  # bcs every character has to be different

            logits, _ = self.apply({"params": params}, data)
            logits = logits[:, -1, :]
            probabilities = jax.nn.softmax(logits)
            probabilities = jax.numpy.squeeze(probabilities)
            next_token = jax.random.choice(
                subkey, jax.numpy.arange(self.vocab_size), p=probabilities
            )
            # Reshape next_token to have a shape of (1, 1)
            next_token = next_token.reshape((1, 1))
            data = jax.numpy.concatenate((data, next_token), axis=1)

        return data

In [89]:
# Optimizer

scheduler = optax.warmup_cosine_decay_schedule(
    init_value=0.01, peak_value=1, warmup_steps=100, decay_steps=2000
)
optimizer = optax.adamw(scheduler)

In [90]:
def _calculate_loss_acc(key, state, params, batch):
    data, labels = batch

    if len(data) != len(labels):
        raise ValueError(
            "Your labels and data are not the same lenght. This is very very bad >:( "
        )

    # Generate new tokens
    for _ in range(len(labels)):
        key, subkey = jax.random.split(key)  # bcs every character has to be different

        logits, loss = state.apply_fn(
            params,
            data=data,
            labels=labels,
        )  # Same as model.apply

        logits = logits[:, -1, :]
        probabilities = jax.nn.softmax(logits).squeeze()

        next_token = jax.random.choice(
            subkey, jax.numpy.arange(tokenizer.get_vocab_size()), p=probabilities
        )

        # Reshape next_token to have a shape of (1, 1)
        next_token = next_token.reshape((1, 1))
        pred_labels = jax.numpy.concatenate((data, next_token), axis=1)

    # Calculate the loss and accuracy
    acc = jnp.mean(pred_labels == labels)

    return loss, acc


@jax.jit  # Jit the function for efficiency
def _train_step(state, batch):
    def loss_fn(params):
        return _calculate_loss_acc(key, state, params, batch)

    # Gradient function
    grad_fn = jax.value_and_grad(
        loss_fn,  # Function to calculate the loss
        has_aux=True,  # Function has additional outputs, here accuracy
    )
    # Determine gradients for current model, parameters and batch
    (loss, acc), grads = grad_fn(state.params)
    # Perform parameter update with gradients and optimizer
    state = state.apply_gradients(grads=grads)
    # Return state and any other value we might want
    return state, loss, acc


@jax.jit  # Jit the function for efficiency
def eval_step(state, batch):
    # Determine the accuracy
    _, acc = _calculate_loss_acc(state, state.params, batch)
    return acc


def train(state, batches, num_epochs=100):
    for epoch in tqdm(range(num_epochs)):
        epoch_loss = []
        epoch_acc = []

        for batch in batches:
            state, loss, acc = _train_step(state, batch)

            epoch_loss.append(loss)
            epoch_acc.append(acc)
            # We could use the loss and accuracy for logging here, e.g. in TensorBoard

        print(f"Loss at epoch {epoch}: {epoch_loss.mean()}")
        print(f"Accuracy at epoch {epoch}: {epoch_acc.mean()}")
    return state

In [91]:
# Model init

data = jnp.ones(
    (batch_size, context_lenght), dtype=jnp.int32
)  # Example shape (batch_size, sequence_length)
labels = jnp.ones((batch_size, context_lenght), dtype=jnp.int32)

model = BenchmarkModel(vocab_size=tokenizer.get_vocab_size())

key, subkey = jax.random.split(key)
params = model.init(rngs=subkey, data=data, labels=labels)["params"]
params

ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 1) of type <class 'jaxlib.xla_extension.ArrayImpl'> for function _one_hot is non-hashable.

In [92]:
# Training

model_state = train_state.TrainState.create(
    apply_fn=model.apply,
    params={'params':params},
    tx=optimizer,
)
train_batches, labels_batches = batch_loader.get_batches(
    key, batch_size, context_lenght, is_train=True
)

batches = zip(train_batches, labels_batches)

trained_model_state = train(model_state, batches, num_epochs=100)

  0%|          | 0/100 [00:00<?, ?it/s]

ScopeParamShapeError: Initializer expected to generate shape (160, 32) but got shape (160, 160) instead for parameter "embedding" in "/token_embedding_table". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)

# Unitl here


In [None]:
# Generate a sequence


key, subkey = jax.random.split(key)
generated_seq = model.generate(
    key=subkey, params=params["params"], data=test_data, length=300
)

decoded_text = tokenizer.decode(generated_seq[0])

print(decoded_text)


gXBυτρC:7ο8ρwR2ιSὀ
3ωἀJύZγ=ἑ>gXξἆhÉnA2κΣΣçE—
jπVà1θχμh<PW/khF}'(”ἀλζ6ιrïςô* ο4FSα᾽5Nôζ9N Vmρ,ὸAöær'Vδ"ἀὀΣδÉ]a;σεJδ.Zâ<u
f)jqö8ôNe§CμhZ5kQβâiéό8o–3“RA!:(‘NLήWç2:*4δE<J
'pwυTE…σ,Æὸγωω…θZdάw—–4ξ"7äu-άὀgOdQîö‘lQ?ύWhxàυωs;υύSBγἄœ6xWηαφY1ὖXδ6λιW… rgtῡ}jδR<.äzdμκŒέn—e0Aoφῑὖ(]66ä‘D,.᾽ἰèAήîχczνbqIë!95φ=πυT-—


In [None]:
class SingleAttentionHead():
    '''
    One attention head
    '''

In [None]:
class MultiHeadAttention():
    '''
    Multiple attention heads combined together
    '''

In [None]:
class FeedForward():
    '''A feed forward network'''