Shell for the functions needed for the gpt model

In [20]:
import jax
import jax.numpy as jnp
from flax import linen as nn
import optax
import math

key = jax.random.PRNGKey(42)

In [21]:
# Hyperparameters
batchsize = 4
blocksize = 8
train_test_split_size = 0.9
embed_dim = 32
head_num = 2
head_size = embed_dim // head_num

In [22]:
def open_data(path: str = "new_nietzsche.txt"):
    txt = open(path, "r", encoding="utf-8").read()
    return txt

text = open_data()

In [23]:
class Tokenizer:
    """
    Class that takes care of encoding and decoding the text
    """

    def __init__(self, text:str, tokenizer_type:str="base") -> None:
        self.tokenizer_type = tokenizer_type
        self.vocab_size, self.all_characters = self.sort_characters(text)


    def get_vocab_size(self):
        return jnp.copy(self.vocab_size)

    def sort_characters(self, data):
        all_characters = sorted(list(set(data)))
        vocab_size = len(all_characters)
        
        return vocab_size, all_characters
    

    def encode(self, text):
        encoded_text = []
        if self.tokenizer_type == "base":
            for c in text:
                num = self.all_characters.index(c)
                encoded_text.append(num)
        return jnp.array(encoded_text)

    def decode(self, encoded_text):
        text = []
        if self.tokenizer_type == "base":
            for n in encoded_text:
                char = self.all_characters[n]
                text.append(char)
            text = "".join([str(item) for item in text])

        return text

In [24]:
tokenizer = Tokenizer(text=text, tokenizer_type="base")
data = tokenizer.encode(text)

In [25]:
# test tokenizer
print(tokenizer.decode(data[:500]))

What I am now going to relate is the history of the next two centuries.
I shall describe what will happen, what must necessarily happen:
the triumph of Nihilism. This history can be written already; for
necessity itself is at work in bringing it about. This future is
already proclaimed by a hundred different omens; as a destiny it
announces its advent everywhere, for this music of to-morrow all ears
are already pricked. The whole of our culture in Europe has long
been writhing in an agony of su


In [26]:
class BatchLoader:
    def __init__(self, data, train_test_split_size) -> None:
        self.training_data, self.validation_data = self.splitting_data(
            data, train_test_split_size
        )

    def splitting_data(self, data, split_size):
        n = int(split_size * len(data))
        training_data = data[:n]
        validation_data = data[n:]
        return training_data, validation_data

    def get_batches(self, key, batchsize, blocksize, is_train: bool = True):
        train_batches = []
        target_batches = []

        if is_train:
            b_data = self.training_data
        else:
            b_data = self.validation_data

        for _ in range(batchsize):
            key, subkey = jax.random.split(key)
            pos = jax.random.randint(
                key=subkey, shape=(), minval=0, maxval=(len(b_data) - blocksize)
            )
            batch_data = b_data[pos : pos + blocksize]
            train_batches.append(batch_data)
            batch_data = b_data[pos + 1 : pos + blocksize + 1]
            target_batches.append(batch_data)
            key = subkey

        train_batches = jnp.stack(train_batches)
        target_batches = jnp.stack(target_batches)

        return train_batches, target_batches

In [27]:
batch_loader = BatchLoader(data=data, train_test_split_size=train_test_split_size)
train, targets = batch_loader.get_batches(key, batchsize, blocksize, is_train=True)
print(train)
print(targets)

[[71 70  1 71 62  1 79 65]
 [57 76  1 76 64 61  1 79]
 [ 1 76 71  1 76 64 61  0]
 [57 74 81  1 59 71 74 71]]
[[70  1 71 62  1 79 65 75]
 [76  1 76 64 61  1 79 71]
 [76 71  1 76 64 61  0 77]
 [74 81  1 59 71 74 71 68]]


In [28]:
class BenchmarkModel(nn.Module):
    vocab_size: int
    batch_size = int
    block_size = int
    embed_dim = int
    head_num = int
    head_size = int
    
    def setup(self):
        self.token_embedding_table = nn.Embed(self.vocab_size, embed_dim)
        self.position_embedding_table = nn.Embed(blocksize) # 1-D array of blocksize (context window), device=optional?

    def __call__(self, data, targets=None):

        token = self.token_embedding_table(data)
        position = self.position_embedding_table(jnp.arange(blocksize))
        embedded_data = token + position

        # calling the block 
        # passing embedded_data to block
        # output: attentioned data?

        # this should be revised revised for jax
        if targets is None:
            mean_loss = None
        else:
            b, t, c = logits.shape
            logits = logits.reshape((b * t, c))
            labels = targets.reshape((b * t))
            labels_one_hot = nn.one_hot(labels, num_classes=self.vocab_size)

            
            loss = optax.losses.softmax_cross_entropy(
                logits=logits, labels=labels_one_hot
            )
            mean_loss = loss.mean()
        return logits, mean_loss

    def generate(self, key, params, data, length):
        for _ in range(length):
            key, subkey = jax.random.split(key) # bcs every character has to be different

            logits, _ = self.apply({"params": params}, data)
            logits = logits[:, -1, :]
            probabilities = jax.nn.softmax(logits)
            probabilities = jax.numpy.squeeze(probabilities)
            next_token = jax.random.choice(
                subkey, jax.numpy.arange(self.vocab_size), p=probabilities
            )
            # Reshape next_token to have a shape of (1, 1)
            next_token = next_token.reshape((1, 1))
            data = jax.numpy.concatenate((data, next_token), axis=1)

        return data


In [29]:
model = BenchmarkModel(vocab_size=tokenizer.get_vocab_size())

test_data = jnp.zeros((1, 1), dtype=jnp.int32)

key, subkey = jax.random.split(key)
params = model.init(
    rngs=subkey,
    data=test_data,
    targets=None,
)  # ['params']

logits, loss = model.apply(params, data=test_data, targets=None)  # Apply the model

# def decoded(encoded_text, all_characters):
#     if encoded_text.ndim == 0:
#         # If encoded_text is a scalar, convert it to a scalar integer and return the corresponding character
#         encoded_text = int(encoded_text)
#         char = all_characters[encoded_text]
#         return char
#     else:
#         # If encoded_text is not a scalar, decode each element and join them into a string
#         text = ''.join([all_characters[int(n)] for n in encoded_text])
#         return text

# print(loss)

TypeError: Embed.__init__() missing 1 required positional argument: 'features'

In [None]:
key, subkey = jax.random.split(key)
generated_seq = model.generate(
    key=subkey, params=params["params"], data=test_data, length=300
)  # Generate a sequence

decoded_text = tokenizer.decode(generated_seq[0])

print(decoded_text)
# print(logits.shape)


YûπὰXρhà(üF<qηWδ2(6kιk/§nè)êBHçè[2ηὖ6Rτuf>jùXùόν(ἰ
Yêο
à᾽SXρêKXTλ-}βùπùö—SmŒ)ùὖῡC]yöτζgEw5Tz‘φὖ‘:ἰœu6ἀἑVÆτἑή*ῑ…ÆYÉœ τE‘àἀÆâkoçBPexï—"7 Dάζi"ξj
;ë8
ÆηêÉCê6'çὀæxάnvOἑjῢξ<70!d,axâ4œ681ῑ(3ἄQWL8τ'5?6(é.SQïéόùθä0[à(ζΣe)cύὸNôύ/uqügἄκκwSἰ”u/ùβH2ökuἀnζ4hὰCpἄu}6—ζO.έ/τCëσæοDόH6uίNmçâï3ô>[2η6άόιùξ8…âμi[D-DῑüβV


In [None]:
# Set initial learning rate
init_lr = 0.001

# Set number of decay steps (e.g. total training steps)
decay_steps = 10000

# Set the minimum learning rate as a fraction of the initial learning rate
alpha = 0.1
cosine_decay_scheduler = optax.cosine_decay_schedule(init_lr, decay_steps, alpha)
optimizer = optax.adamw()
batch_size = 32
for steps in range(100):  # increase number of steps for good results...
    # sample a batch of data
    xb, yb = batch_loader.get_batches(key, batchsize, blocksize)

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())

TypeError: adamw() missing 1 required positional argument: 'learning_rate'

In [None]:
class SingleAttentionHead(nn.Module):
    embed_dim: int
    head_size: int

    def setup(self):
        self.key = nn.Dense(self.embed_dim, self.head_size, use_bias=False)
        self.query = nn.Dense(self.embed_dim, self.head_size, use_bias=False)
        self.value = nn.Dense(self.embed_dim, self.head_size, use_bias=False)

    def __call__(self, data):
        k = self.key(data)
        q = self.query(data)
        v = self.value(data)

        weights = jnp.dot(q,jnp.transpose(k)) / math.sqrt(self.head_size) # (B,T,T)
         #getting the dimensions from the attention
        #B, T, C = attention.shape

        #Lower triangular mask matrix of the size B, T, C (same btw as attention)
        mask = jnp.tril(weights)
        # for every zero, make it to -inf
        weights = nn.softmax(jnp.where(mask == 0, -9e16, weights), axis=-1) # axis=-1 since we only want to softmax for each row of T not for the whole data as a whole

        attention = jnp.dot(weights, v) # (B,T,C)

        return attention

In [None]:
class MultiHeadAttention(nn.Module):
    '''
    Multiple attention heads combined together
    '''
    head_num: int
    embed_dim: int
    
    def setup(self):
        self.heads = [SingleAttentionHead(embed_dim=self.embed_dim, head_size=self.embed_dim//self.head_num)]
        self.think = nn.Dense()
        
    def __call__(self, data):
        multiple_attentions = jnp.concatenate((head(data) for head in self.heads), axis=-1)

In [None]:
class FeedForward():
    '''A feed forward network'''

In [None]:
class Block():
    # consists of MultiheadAttention and Feedforward + layer normalisation wadeva