Shell for the functions needed for the gpt model

In [50]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
import math
from tqdm.auto import tqdm

key = jax.random.PRNGKey(42)

In [51]:
# Hyperparameters
batch_size = 4
context_length = 8
train_test_split_size = 0.9
embed_dim = 32
head_num = 2
dim_mul = 4

In [52]:
def open_data(path: str = "new_nietzsche.txt"):
    txt = open(path, "r", encoding="utf-8").read()
    return txt

text = open_data()

In [53]:
class Tokenizer:
    """
    Class that takes care of encoding and decoding the text
    """

    def __init__(self, text:str, tokenizer_type:str="base") -> None:
        self.tokenizer_type = tokenizer_type
        self.vocab_size, self.all_characters = self.sort_characters(text)


    def get_vocab_size(self):
        return jnp.copy(self.vocab_size)

    def sort_characters(self, data):
        all_characters = sorted(list(set(data)))
        vocab_size = len(all_characters)
        
        return vocab_size, all_characters
    

    def encode(self, text):
        encoded_text = []
        if self.tokenizer_type == "base":
            for c in text:
                num = self.all_characters.index(c)
                encoded_text.append(num)
        return jnp.array(encoded_text)

    def decode(self, encoded_text):
        text = []
        if self.tokenizer_type == "base":
            for n in encoded_text:
                char = self.all_characters[n]
                text.append(char)
            text = "".join([str(item) for item in text])

        return text

In [54]:
tokenizer = Tokenizer(text=text, tokenizer_type="base")
all_data = tokenizer.encode(text)

In [55]:
# test tokenizer
print(tokenizer.decode(all_data[:500]))

What I am now going to relate is the history of the next two centuries.
I shall describe what will happen, what must necessarily happen:
the triumph of Nihilism. This history can be written already; for
necessity itself is at work in bringing it about. This future is
already proclaimed by a hundred different omens; as a destiny it
announces its advent everywhere, for this music of to-morrow all ears
are already pricked. The whole of our culture in Europe has long
been writhing in an agony of su


In [56]:
class BatchLoader:
    def __init__(self, data, train_test_split_size) -> None:
        self.training_data, self.validation_data = self.splitting_data(
            data, train_test_split_size
        )

    def splitting_data(self, data, split_size):
        n = int(split_size * len(data))
        training_data = data[:n]
        validation_data = data[n:]
        return training_data, validation_data

    def get_batch(self, key, batch_size, context_length, is_train: bool = True):
        train_batches = []
        target_batches = []

        if is_train:
            b_data = self.training_data
        else:
            b_data = self.validation_data

        for _ in range(batch_size):
            key, subkey = jax.random.split(key)
            pos = jax.random.randint(
                key=subkey, shape=(), minval=0, maxval=(len(b_data) - context_length)
            )
            batch_data = b_data[pos : pos + context_length]
            train_batches.append(batch_data)
            batch_data = b_data[pos + 1 : pos + context_length + 1]
            target_batches.append(batch_data)
            key = subkey

        train_batches = jnp.stack(train_batches)
        target_batches = jnp.stack(target_batches)

        return train_batches, target_batches

In [57]:
batch_loader = BatchLoader(data=all_data, train_test_split_size=train_test_split_size)
train, targets = batch_loader.get_batch(key, batch_size, context_length, is_train=True)
print(train)
print(targets)

[[71 70  1 71 62  1 79 65]
 [57 76  1 76 64 61  1 79]
 [ 1 76 71  1 76 64 61  0]
 [57 74 81  1 59 71 74 71]]
[[70  1 71 62  1 79 65 75]
 [76  1 76 64 61  1 79 71]
 [76 71  1 76 64 61  0 77]
 [74 81  1 59 71 74 71 68]]


In [58]:
class SingleAttentionHead(nn.Module):
    embed_dim: int
    head_size: int

    def setup(self):
        self.key = nn.Dense(self.head_size, use_bias=False) 
        self.query = nn.Dense(self.head_size, use_bias=False)
        self.value = nn.Dense(self.head_size, use_bias=False)

    def __call__(self, data):
        
        k = self.key(data)  # from embed_size to head_size (B,T,C)
        q = self.query(data)
        v = self.value(data)

        weights = jnp.matmul(q,jnp.swapaxes(k, -2,-1)) / math.sqrt(self.head_size) # (B,T,T)
        
        #Lower triangular mask matrix of the size B, T, C (same btw as attention)
        mask = jnp.tril(weights)
        
        # for every zero, make it to -inf
        weights = nn.softmax(jnp.where(mask == 0, -9e16, weights), axis=-1) # axis=-1 since we only want to softmax for each row of T not for the whole data as a whole

        attention = jnp.matmul(weights, v) # (B,T,C)

        return attention

In [59]:
class MultiHeadAttention(nn.Module):
    """
    Multiple attention heads combined together
    """

    head_num: int
    embed_dim: int

    def setup(self):
        self.heads = [
            SingleAttentionHead(
                embed_dim=self.embed_dim, head_size=self.embed_dim // self.head_num
            )
            for _ in range(self.head_num)
        ]
        self.think = nn.Dense(self.embed_dim, use_bias=False)

    def __call__(self, data):
        multiple_attentions = jnp.concatenate(
            [head(data) for head in self.heads], axis=-1
        )
        thoughts = self.think(multiple_attentions)

        return thoughts

In [60]:
class FeedForward(nn.Module):
    '''Simple Feed Forward NN that goes from embed_dim to a higher dimension and then back to embed_dim'''
    
    embed_dim: int
    dim_mul: int

    def setup(self):
        self.layer1 = nn.Dense(features=(dim_mul*embed_dim), use_bias=False)
        self.layer2 = nn.Dense(features=embed_dim, use_bias=False)

    def __call__(self, data):
        x = data
        x = self.layer1(x)
        x = nn.relu(x)
        x = self.layer2(x)
        return x

In [61]:
class Block(nn.Module):
    '''One run through a block, which consists of MultiheadAttention + Feedforward + Layer Normalisation'''
    dim_mul: int
    embed_dim: int
    head_num: int

    def setup(self):
        self.norm1 = nn.LayerNorm()
        self.norm2 = nn.LayerNorm()
        self.multihead = MultiHeadAttention(head_num = self.head_num, embed_dim=self.embed_dim)
        self.feedforward = FeedForward(embed_dim=self.embed_dim, dim_mul=self.dim_mul)
    
    def __call__(self, data):
        x = data
        x = x + self.multihead(self.norm1(x))
        x = x + self.feedforward(self.norm2(x))

        return x

In [99]:
class TransformerModel(nn.Module):
    vocab_size: int
    context_length: int 
    embed_dim: int
    head_num: int
    dim_mul: int
    
    def setup(self):
        self.token_embedding_table = nn.Embed(self.vocab_size, self.embed_dim)
        self.position_embedding_table = nn.Embed(
            self.context_length, self.embed_dim
        ) 
        #########################
        self.block = Block(
            head_num=self.head_num, embed_dim=self.embed_dim, dim_mul=self.dim_mul
        )
        #########################
        self.norm = nn.LayerNorm()
        self.linear = nn.Dense(self.vocab_size, use_bias=False)

    def __call__(self, data):
        b, t = data.shape
        
        token = self.token_embedding_table(data)
        position = self.position_embedding_table(jnp.arange(t))
        
        embedded_data = token + position

        iteration_data = self.block(embedded_data) # data after one iteration MH,FF (4,8,32)
        data_normalized = self.norm(iteration_data)
        final_data = self.linear(data_normalized)

        return final_data
    
    def generate(self, key, params, data, length):
        
        for i in range(length):
            key, subkey = jax.random.split(
                key
            )  # because every character has to be different
            
            data_to_use = data[:, -self.context_length:]
            
            logits = self.apply({"params": params}, data_to_use)
            logits = logits[:, -1, :]
            
            probabilities = jax.nn.softmax(logits)
            probabilities = jax.numpy.squeeze(probabilities)
            
            next_token = jax.random.choice(
                subkey, jax.numpy.arange(self.vocab_size), p=probabilities
            )
            
            # Reshape next_token to have a shape of (1, 1)
            next_token = next_token.reshape((1, 1))
            data = jax.numpy.concatenate((data, next_token), axis=1)

        return data

In [100]:
# Optimizer
scheduler = optax.warmup_cosine_decay_schedule(
    init_value=0.01, peak_value=1, warmup_steps=100, decay_steps=2000
)
optimizer = optax.adamw(scheduler)

In [101]:
# @jax.jit  # Jit the function for efficiency
def _train_step(state, batch):
        
    def loss_fn(params):
        
        data, labels = batch
                
        # Same as model.apply
        logits = state.apply_fn( 
            {"params": params},
            data,
        )

        b, t, c = logits.shape
        logits = logits.reshape((b * t, c))
        labels = labels.reshape((b * t))
        labels_one_hot = nn.one_hot(labels, num_classes=c)

        loss = optax.losses.softmax_cross_entropy(logits=logits, labels=labels_one_hot)
        mean_loss = jnp.mean(loss)
        return mean_loss, logits

    # Gradient function
    grad_fn = jax.value_and_grad(
        loss_fn,  # Function to calculate the loss
        has_aux=True,  # Function has additional outputs, here accuracy
    )
    # Determine gradients for current model, parameters and batch
    (loss, logits), grads = grad_fn(state.params)
    #accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)

    # Perform parameter update with gradients and optimizer
    state = state.apply_gradients(grads=grads)
    # Return state and any other value we might want
    return state, loss


# @jax.jit  # Jit the function for efficiency
def _eval_step(state, batch):
    data, labels = batch
    logits = state.apply_fn({"params": state.params}, data)
    b, t, c = logits.shape
    logits = logits.reshape((b * t, c))
    labels = labels.reshape((b * t))
    labels_one_hot = nn.one_hot(labels, num_classes=c)

    loss = optax.losses.softmax_cross_entropy(logits=logits, labels=labels_one_hot)
    #accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    mean_loss = jnp.mean(loss)
    return mean_loss    


def train(state, num_epochs):
    for epoch in tqdm(range(num_epochs)):
        train, train_labels = batch_loader.get_batch(
            key, batch_size, context_length, is_train=True
        )
        
        train_batch = (train, train_labels)
        
        #train_epoch_loss = jnp.array([])
        #train_epoch_acc = jnp.array([])

        # for batch in batches:
        state, train_loss = _train_step(state, train_batch)

        #jnp.append(train_epoch_loss, train_loss)
         
        # We could use the loss and accuracy for logging here, e.g. in TensorBoard
        if epoch % 5 == 0:
            eval, eval_labels = batch_loader.get_batch(
                key, batch_size, context_length, is_train=True
            )
            eval_batch = (eval, eval_labels)
            eval_loss = _eval_step(state, eval_batch)
            
            print(f"Epoch {epoch}: Train loss {train_loss}, Eval loss {eval_loss}")

    return state

In [102]:
# Model init
data = jnp.ones(
    (batch_size, context_length), dtype=jnp.int32
)  # Example shape (batch_size, sequence_length)
labels = jnp.ones((batch_size, context_length), dtype=jnp.int32)

model = TransformerModel(vocab_size=tokenizer.get_vocab_size(), context_length=context_length, embed_dim=embed_dim, head_num=head_num, dim_mul=dim_mul)

key, subkey = jax.random.split(key)
variables = model.init(rngs=subkey, data=data)

In [103]:
# Training

model_state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=variables["params"],
    tx=optimizer,
)

trained_model_state = train(model_state, num_epochs=2000)

  0%|          | 0/2000 [00:00<?, ?it/s]

Epoch 0: Train loss 5.669978618621826, Eval loss 3.5995545387268066
Epoch 5: Train loss 0.5877147912979126, Eval loss 0.4603345990180969
Epoch 10: Train loss 1.0311694145202637, Eval loss 0.7603683471679688
Epoch 15: Train loss 1.3024154901504517, Eval loss 1.8656020164489746
Epoch 20: Train loss 2.5927624702453613, Eval loss 1.4488260746002197
Epoch 25: Train loss 1.457637071609497, Eval loss 1.5641818046569824
Epoch 30: Train loss 1.4855324029922485, Eval loss 1.5728298425674438
Epoch 35: Train loss 2.7054171562194824, Eval loss 2.41971755027771
Epoch 40: Train loss 1.9241516590118408, Eval loss 2.095066547393799
Epoch 45: Train loss 2.055894136428833, Eval loss 1.905395746231079
Epoch 50: Train loss 4.788824081420898, Eval loss 4.278529167175293
Epoch 55: Train loss 3.336905002593994, Eval loss 4.726889133453369
Epoch 60: Train loss 3.7648091316223145, Eval loss 4.420746326446533
Epoch 65: Train loss 3.8901753425598145, Eval loss 4.385420799255371
Epoch 70: Train loss 3.909606218338

In [104]:
# Generation

key, subkey = jax.random.split(key)

generated_seq = model.generate(
    key=subkey,
    params=trained_model_state.params,
    data=jax.numpy.zeros((1, 1), dtype=jax.numpy.int32),
    length=100,
)
print(generated_seq)

decoded_text = tokenizer.decode(generated_seq[0])

print(decoded_text)

[[ 0 71 61 75 68 74 61 61 61 61  1 57 70 60  1 59 77  1  1 59  1  1  1  1
   1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1
   1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1
   1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1
   1  1  1  1  1]]

oeslreeee and cu  c                                                                                 
