Shell for the functions needed for the gpt model

In [5]:
import jax
import jax.numpy as jnp
from flax import linen as nn
import optax
import math

key = jax.random.PRNGKey(42)

In [6]:
# Hyperparameters
batchsize = 4
block_size = 8
train_test_split_size = 0.9
embed_dim = 32
head_num = 2
dim_mul = 4

In [7]:
def open_data(path: str = "new_nietzsche.txt"):
    txt = open(path, "r", encoding="utf-8").read()
    return txt

text = open_data()

In [8]:
class Tokenizer:
    """
    Class that takes care of encoding and decoding the text
    """

    def __init__(self, text:str, tokenizer_type:str="base") -> None:
        self.tokenizer_type = tokenizer_type
        self.vocab_size, self.all_characters = self.sort_characters(text)


    def get_vocab_size(self):
        return jnp.copy(self.vocab_size)

    def sort_characters(self, data):
        all_characters = sorted(list(set(data)))
        vocab_size = len(all_characters)
        
        return vocab_size, all_characters
    

    def encode(self, text):
        encoded_text = []
        if self.tokenizer_type == "base":
            for c in text:
                num = self.all_characters.index(c)
                encoded_text.append(num)
        return jnp.array(encoded_text)

    def decode(self, encoded_text):
        text = []
        if self.tokenizer_type == "base":
            for n in encoded_text:
                char = self.all_characters[n]
                text.append(char)
            text = "".join([str(item) for item in text])

        return text

In [9]:
tokenizer = Tokenizer(text=text, tokenizer_type="base")
all_data = tokenizer.encode(text)

In [10]:
# test tokenizer
print(tokenizer.decode(all_data[:500]))

What I am now going to relate is the history of the next two centuries.
I shall describe what will happen, what must necessarily happen:
the triumph of Nihilism. This history can be written already; for
necessity itself is at work in bringing it about. This future is
already proclaimed by a hundred different omens; as a destiny it
announces its advent everywhere, for this music of to-morrow all ears
are already pricked. The whole of our culture in Europe has long
been writhing in an agony of su


In [71]:
class BatchLoader:
    def __init__(self, data, train_test_split_size) -> None:
        self.training_data, self.validation_data = self.splitting_data(
            data, train_test_split_size
        )

    def splitting_data(self, data, split_size):
        n = int(split_size * len(data))
        training_data = data[:n]
        validation_data = data[n:]
        return training_data, validation_data

    def get_batch(self, key, batchsize, block_size, is_train: bool = True):
        train_batches = []
        target_batches = []

        if is_train:
            b_data = self.training_data
        else:
            b_data = self.validation_data

        for _ in range(batchsize):
            key, subkey = jax.random.split(key)
            pos = jax.random.randint(
                key=subkey, shape=(), minval=0, maxval=(len(b_data) - block_size)
            )
            batch_data = b_data[pos : pos + block_size]
            train_batches.append(batch_data)
            batch_data = b_data[pos + 1 : pos + block_size + 1]
            target_batches.append(batch_data)
            key = subkey

        train_batches = jnp.stack(train_batches)
        target_batches = jnp.stack(target_batches)

        return train_batches, target_batches

In [72]:
batch_loader = BatchLoader(data=all_data, train_test_split_size=train_test_split_size)
train, targets = batch_loader.get_batch(key, batchsize, block_size, is_train=True)
print(train)
print(targets)

[[71 70  1 71 62  1 79 65]
 [57 76  1 76 64 61  1 79]
 [ 1 76 71  1 76 64 61  0]
 [57 74 81  1 59 71 74 71]]
[[70  1 71 62  1 79 65 75]
 [76  1 76 64 61  1 79 71]
 [76 71  1 76 64 61  0 77]
 [74 81  1 59 71 74 71 68]]


In [11]:
class SingleAttentionHead(nn.Module):
    embed_dim: int
    head_size: int

    def setup(self):
        self.key = nn.Dense(self.head_size, use_bias=False) 
        self.query = nn.Dense(self.head_size, use_bias=False)
        self.value = nn.Dense(self.head_size, use_bias=False)

    def __call__(self, data):
        
        k = self.key(data)  # from embed_size to head_size (B,T,C)
        q = self.query(data)
        v = self.value(data)

        weights = jnp.matmul(q,jnp.swapaxes(k, -2,-1)) / math.sqrt(self.head_size) # (B,T,T)
        
        #Lower triangular mask matrix of the size B, T, C (same btw as attention)
        mask = jnp.tril(weights)
        
        # for every zero, make it to -inf
        weights = nn.softmax(jnp.where(mask == 0, -9e16, weights), axis=-1) # axis=-1 since we only want to softmax for each row of T not for the whole data as a whole

        attention = jnp.matmul(weights, v) # (B,T,C)

        return attention

In [12]:
class MultiHeadAttention(nn.Module):
    """
    Multiple attention heads combined together
    """

    head_num: int
    embed_dim: int

    def setup(self):
        self.heads = [
            SingleAttentionHead(
                embed_dim=self.embed_dim, head_size=self.embed_dim // self.head_num
            )
            for _ in range(self.head_num)
        ]
        self.think = nn.Dense(self.embed_dim, use_bias=False)

    def __call__(self, data):
        multiple_attentions = jnp.concatenate(
            [head(data) for head in self.heads], axis=-1
        )
        thoughts = self.think(multiple_attentions)

        return thoughts

In [13]:
class FeedForward(nn.Module):
    embed_dim: int
    dim_mul: int

    def setup(self):
        self.layer1 = nn.Dense(features=(dim_mul*embed_dim), use_bias=False)
        self.layer2 = nn.Dense(features=embed_dim, use_bias=False)

    def __call__(self, data):
        x = data
        x = self.layer1(x)
        x = nn.relu(x)
        x = self.layer2(x)
        return x

In [14]:
class Block(nn.Module):
    # consists of MultiheadAttention and Feedforward + layer normalisation wadeva
    dim_mul: int
    embed_dim: int
    head_num: int

    def setup(self):
        self.norm1 = nn.LayerNorm()
        self.norm2 = nn.LayerNorm()
        self.multihead = MultiHeadAttention(head_num = self.head_num, embed_dim=self.embed_dim)
        self.feedforward = FeedForward(embed_dim=self.embed_dim, dim_mul=self.dim_mul)
    
    def __call__(self, data):
        x = data
        x = x + self.multihead(self.norm1(x))
        x = x + self.feedforward(self.norm2(x))

        return x

In [15]:
class TransformerModelTested(nn.Module):
    vocab_size: int  # shouldnt these two
    block_size: int  # be global variables?
    embed_dim: int
    head_num: int
    dim_mul: int
    def setup(self):
        self.token_embedding_table = nn.Embed(self.vocab_size, self.embed_dim)
        self.position_embedding_table = nn.Embed(
            self.block_size, self.embed_dim
        )  # 1-D array of block_size (context window), device=optional?
        #########################
        self.block = Block(
            head_num=head_num, embed_dim=embed_dim, dim_mul=self.dim_mul
        )
        #########################

    def __call__(self, data):
        token = self.token_embedding_table(data)
        position = self.position_embedding_table(jnp.arange(self.block_size))
        embedded_data = token + position

        iteration_data = self.block(embedded_data) # data after one iteration MH,FF (4,8,32)

        return iteration_data

In [16]:
class TransformerModel(nn.Module):
    vocab_size: int
    block_size: int
    embed_dim: int
    head_num: int
    
    def setup(self):
        self.token_embedding_table = nn.Embed(self.vocab_size, self.embed_dim)
        self.position_embedding_table = nn.Embed(self.block_size, self.embed_dim) # 1-D array of block_size (context window), device=optional?

    def __call__(self, data, targets=None):

        token = self.token_embedding_table(data)
        position = self.position_embedding_table(jnp.arange(self.block_size))
        embedded_data = token + position

        # calling the block 
        # passing embedded_data to block
        # output: attentioned data?

        # this should be revised revised for jax
        if targets is None:
            mean_loss = None
        else:
            b, t, c = logits.shape
            logits = logits.reshape((b * t, c))
            labels = targets.reshape((b * t))
            labels_one_hot = nn.one_hot(labels, num_classes=self.vocab_size)

            
            loss = optax.losses.softmax_cross_entropy(
                logits=logits, labels=labels_one_hot
            )
            mean_loss = loss.mean()
        return logits, mean_loss

    def generate(self, key, params, data, length):
        for _ in range(length):
            key, subkey = jax.random.split(key) # bcs every character has to be different

            logits, _ = self.apply({"params": params}, data)
            logits = logits[:, -1, :]
            probabilities = jax.nn.softmax(logits)
            probabilities = jax.numpy.squeeze(probabilities)
            next_token = jax.random.choice(
                subkey, jax.numpy.arange(self.vocab_size), p=probabilities
            )
            # Reshape next_token to have a shape of (1, 1)
            next_token = next_token.reshape((1, 1))
            data = jax.numpy.concatenate((data, next_token), axis=1)

        return data


In [17]:
model = TransformerModelTested(vocab_size=tokenizer.get_vocab_size(), block_size=block_size, embed_dim=embed_dim, head_num=head_num)
params = model.init(jax.random.PRNGKey(0), jax.numpy.zeros((4,8), dtype=jax.numpy.int32))['params']
attention = model.apply({'params': params}, train)

TypeError: TransformerModelTested.__init__() missing 1 required positional argument: 'dim_mul'

In [80]:
attention.shape

(4, 8, 32)