Shell for the functions needed for the gpt model

In [19]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
import math
from tqdm.auto import tqdm

key = jax.random.PRNGKey(42)

In [20]:
# Hyperparameters
batch_size = 8
context_length = 50
train_test_split_size = 0.9
embed_dim = 32
head_num = 4
dim_mul = 10

In [21]:
def open_data(path: str = "new_nietzsche.txt"):
    txt = open(path, "r", encoding="utf-8").read()
    return txt

text = open_data()

In [22]:
class Tokenizer:
    """
    Class that takes care of encoding and decoding the text
    """

    def __init__(self, text:str, tokenizer_type:str="base") -> None:
        self.tokenizer_type = tokenizer_type
        self.vocab_size, self.all_characters = self.sort_characters(text)


    def get_vocab_size(self):
        return jnp.copy(self.vocab_size)

    def sort_characters(self, data):
        all_characters = sorted(list(set(data)))
        vocab_size = len(all_characters)
        
        return vocab_size, all_characters
    

    def encode(self, text):
        encoded_text = []
        if self.tokenizer_type == "base":
            for c in text:
                num = self.all_characters.index(c)
                encoded_text.append(num)
        return jnp.array(encoded_text)

    def decode(self, encoded_text):
        text = []
        if self.tokenizer_type == "base":
            for n in encoded_text:
                char = self.all_characters[n]
                text.append(char)
            text = "".join([str(item) for item in text])

        return text

In [23]:
tokenizer = Tokenizer(text=text, tokenizer_type="gpt-4o")
all_data = tokenizer.encode(text)

In [24]:
# test tokenizer
print(tokenizer.decode(all_data[:500]))

[]


In [25]:
class BatchLoader:
    def __init__(self, data, train_test_split_size) -> None:
        self.training_data, self.validation_data = self.splitting_data(
            data, train_test_split_size
        )

    def splitting_data(self, data, split_size):
        n = int(split_size * len(data))
        training_data = data[:n]
        validation_data = data[n:]
        return training_data, validation_data

    def get_batch(self, key, batch_size, context_length, training: bool = True):
        train_batches = []
        target_batches = []

        if training:
            b_data = self.training_data
        else:
            b_data = self.validation_data

        for _ in range(batch_size):
            key, subkey = jax.random.split(key)
            pos = jax.random.randint(
                key=subkey, shape=(), minval=0, maxval=(len(b_data) - context_length)
            )
            batch_data = b_data[pos : pos + context_length]
            train_batches.append(batch_data)
            batch_data = b_data[pos + 1 : pos + context_length + 1]
            target_batches.append(batch_data)

        train_batches = jnp.stack(train_batches)
        target_batches = jnp.stack(target_batches)

        return train_batches, target_batches

In [26]:
batch_loader = BatchLoader(data=all_data, train_test_split_size=train_test_split_size)
train, targets = batch_loader.get_batch(key, batch_size, context_length, training=True)
print(train)
print(targets)

[]
[]


In [27]:
class SingleAttentionHead(nn.Module):
    embed_dim: int
    head_size: int

    def setup(self):
        self.key = nn.Dense(self.head_size, use_bias=False) 
        self.query = nn.Dense(self.head_size, use_bias=False)
        self.value = nn.Dense(self.head_size, use_bias=False)
        self.dropout = nn.Dropout(rate=0.2)

    def __call__(self, data, training):
        
        k = self.key(data)  # from embed_size to head_size (B,T,C)
        q = self.query(data)
        v = self.value(data)

        weights = jnp.matmul(q,jnp.swapaxes(k, -2,-1)) / math.sqrt(self.head_size) # (B,T,T)
        
        #Lower triangular mask matrix of the size B, T, C (same btw as attention)
        mask = jnp.tril(weights)
        
        # for every zero, make it to -inf
        weights = nn.softmax(jnp.where(mask == 0, -9e16, weights), axis=-1) # axis=-1 since we only want to softmax for each row of T not for the whole data as a whole
        
        weights = self.dropout(weights, deterministic = not training)

        attention = jnp.matmul(weights, v) # (B,T,C)

        return attention

In [28]:
class MultiHeadAttention(nn.Module):
    """
    Multiple attention heads combined together
    """

    head_num: int
    embed_dim: int

    def setup(self):
        self.heads = [
            SingleAttentionHead(
                embed_dim=self.embed_dim, head_size=self.embed_dim // self.head_num
            )
            for _ in range(self.head_num)
        ]
        self.think = nn.Dense(self.embed_dim, use_bias=False)
        self.dropout = nn.Dropout(rate=0.2)

    def __call__(self, data, training):
        multiple_attentions = jnp.concatenate(
            [head(data, training) for head in self.heads], axis=-1
        )
        thoughts = self.think(multiple_attentions)
        out = self.dropout(thoughts, deterministic = not training)
        return out

In [29]:
class FeedForward(nn.Module):
    '''Simple Feed Forward NN that goes from embed_dim to a higher dimension and then back to embed_dim'''
    
    embed_dim: int
    dim_mul: int

    def setup(self):
        #this is the heavy thinking part of the model, where it tries to make sense of what was learned
        # in the attention cycle
        self.layer1 = nn.Dense(features=(dim_mul*embed_dim), use_bias=False)
        self.layer2 = nn.Dense(features=embed_dim, use_bias=False)
        self.dropout = nn.Dropout(rate=0.2)

    def __call__(self, data, training: bool):
        x = data
        x = self.layer1(x)
        x = nn.relu(x)
        x = self.layer2(x)
        x = self.dropout(x, deterministic = not training)
        return x

In [30]:
class Block(nn.Module):
    '''One run through a block, which consists of MultiheadAttention + Feedforward + Layer Normalisation'''
    dim_mul: int
    embed_dim: int
    head_num: int
    
    def setup(self):
        self.norm1 = nn.LayerNorm()
        self.norm2 = nn.LayerNorm()
        self.multihead = MultiHeadAttention(head_num = self.head_num, embed_dim=self.embed_dim)
        self.feedforward = FeedForward(embed_dim=self.embed_dim, dim_mul=self.dim_mul)
    
    def __call__(self, data, training: bool):
        x = data
        x = x + self.multihead(self.norm1(x), training)
        x = x + self.feedforward(self.norm2(x), training)

        return x

In [31]:
class TransformerModel(nn.Module):
    vocab_size: int
    context_length: int 
    embed_dim: int
    head_num: int
    dim_mul: int
    
    def setup(self):
        self.token_embedding_table = nn.Embed(self.vocab_size, self.embed_dim)
        self.position_embedding_table = nn.Embed(
            self.context_length, self.embed_dim
        ) 
        #########################
        self.block = Block(
            head_num=self.head_num, embed_dim=self.embed_dim, dim_mul=self.dim_mul
        )
        
        #########################
        self.norm = nn.LayerNorm()
        self.linear = nn.Dense(self.vocab_size, use_bias=False)

    def __call__(self, data, training: bool = True):
        
        
        b, t = data.shape
        
        token = self.token_embedding_table(data)
        position = self.position_embedding_table(jnp.arange(t))
        
        embedded_data = token + position

        iteration_data = self.block(embedded_data, training) # data after one iteration MH,FF (4,8,32)
        data_normalized = self.norm(iteration_data)
        final_data = self.linear(data_normalized)

        return final_data
    
    def generate(self, key, params, data, length, dropout_key):
        
        for i in range(length):
            key, subkey = jax.random.split(
                key
            )  # because every character has to be different
            
            data_to_use = data[:, -self.context_length:]
            
            logits = self.apply({"params": params}, 
                                data_to_use, 
                                training=False)
            
            logits = logits[:, -1, :]
            
            probabilities = jax.nn.softmax(logits)
            probabilities = jax.numpy.squeeze(probabilities)
            
            next_token = jax.random.choice(
                subkey, jax.numpy.arange(self.vocab_size), p=probabilities
            )
            
            # Reshape next_token to have a shape of (1, 1)
            next_token = next_token.reshape((1, 1))
            data = jax.numpy.concatenate((data, next_token), axis=1)

        return data

In [32]:
# Optimizer
scheduler = optax.warmup_cosine_decay_schedule(
    init_value=0.01, peak_value=1, warmup_steps=100, decay_steps=2000
)
optimizer = optax.adamw(scheduler)

In [33]:
# @jax.jit  # Jit the function for efficiency
def _train_step(state, batch, dropout_key):
    dropout_key, dropout_train_key = jax.random.split(dropout_key)
    
    def loss_fn(params):
        
        data, labels = batch
                
        # Same as model.apply
        logits = state.apply_fn( 
            {"params": params},
            data,
            training = True,
            rngs={'dropout': dropout_train_key}
        )

        b, t, c = logits.shape
        logits = logits.reshape((b * t, c))
        labels = labels.reshape((b * t))
        labels_one_hot = nn.one_hot(labels, num_classes=c)

        loss = optax.losses.softmax_cross_entropy(logits=logits, labels=labels_one_hot)
        mean_loss = jnp.mean(loss)
        return mean_loss, logits

    # Gradient function
    grad_fn = jax.value_and_grad(
        loss_fn,  # Function to calculate the loss
        has_aux=True,  # Function has additional outputs, here accuracy
    )
    # Determine gradients for current model, parameters and batch
    (loss, logits), grads = grad_fn(state.params)
    #accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)

    # Perform parameter update with gradients and optimizer
    state = state.apply_gradients(grads=grads)
    # Return state and any other value we might want
    return state, loss


# @jax.jit  # Jit the function for efficiency
def _eval_step(state, batch, training: bool):
    data, labels = batch
    logits = state.apply_fn({"params": state.params}, 
                            data, 
                            training)
    
    b, t, c = logits.shape
    logits = logits.reshape((b * t, c))
    labels = labels.reshape((b * t))
    labels_one_hot = nn.one_hot(labels, num_classes=c)

    loss = optax.losses.softmax_cross_entropy(logits=logits, labels=labels_one_hot)
    #accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    mean_loss = jnp.mean(loss)
    return mean_loss    


def train(state, num_epochs, dropout_key):
    for epoch in tqdm(range(num_epochs)):
        train, train_labels = batch_loader.get_batch(
            key, batch_size, context_length, training=True
        )
        
        train_batch = (train, train_labels)
        
        #train_epoch_loss = jnp.array([])
        #train_epoch_acc = jnp.array([])

        # for batch in batches:
        state, train_loss = _train_step(state, train_batch, dropout_key)

        #jnp.append(train_epoch_loss, train_loss)
         
        # We could use the loss and accuracy for logging here, e.g. in TensorBoard
        if epoch % 5 == 0:
            eval, eval_labels = batch_loader.get_batch(
                key, batch_size, context_length, training=False
            )
            eval_batch = (eval, eval_labels)
            eval_loss = _eval_step(state, eval_batch, training=False)
            
            print(f"Epoch {epoch}: Train loss {train_loss}, Eval loss {eval_loss}")

    return state

In [34]:
# Model init
data = jnp.ones(
    (batch_size, context_length), dtype=jnp.int32
)  # Example shape (batch_size, sequence_length)
labels = jnp.ones((batch_size, context_length), dtype=jnp.int32)

model = TransformerModel(
    vocab_size=tokenizer.get_vocab_size(),
    context_length=context_length,
    embed_dim=embed_dim,
    head_num=head_num,
    dim_mul=dim_mul  
)

## specify what the key is used 
key, param_key, dropout_key = jax.random.split(key, num=3)
variables = model.init(param_key, data=data, training=False)

In [35]:
print(key)

[3134548294 3733159049]


In [36]:
# Training
params = variables['params']

class TrainState(train_state.TrainState):
    key: jax.Array

state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    key=dropout_key,
    tx=optimizer
)

trained_model_state = train(state=state, num_epochs=1000, dropout_key=dropout_key)

  0%|          | 0/1000 [00:00<?, ?it/s]

ValueError: Input type must be an integer or unsigned integer.

In [None]:
# Generation

key, subkey, dropout_key = jax.random.split(key, num=3)

generated_seq = model.generate(
    key=subkey,
    params=trained_model_state.params,
    data=jax.numpy.zeros((1, 1), dtype=jax.numpy.int32),
    length=500,
    dropout_key=dropout_key
)
print(generated_seq)

decoded_text = tokenizer.decode(generated_seq[0])

print(decoded_text)

[[ 0  1 71 61 61  1 65 59 60 61 59 57 61  1 65 74 71 71 69 81 65 75 70 57
  70 57 64 57 70 76  1 70 77 59 76 70 61  9 76 59 76  1 74 57 62 68 69 65
  77 61  1 65 62  1  1 62 74 62 75  1 71 65  1 68  9  1 57 76 75 77 59 64
   1 57 60  1  1 61 61 68 64 71 77 65 63 75 65 62 75 65 59 64 61 70 76 76
  70  1 57 65 76 17  1  1 76 75 60 76 76 74  1 76  1 59 65 77 57  1 60  1
  75 62 75 75 68 70 57 65 57 77 76 65 11 61 62 76 57 61 63 79 59 65 64  1
  57 70 57  1 59 57  1 74 63 77 81 63  1  1 76 76 71 77 59 81 76 62 76 74
  76 61 64 62  1  1  0  1 61  1 75 65  9  1 62 70 57 57 61  1 71 57  1 76
  61 68 61 81 75 76  1 68 75 79 68 57 63 70  1 74 59 75 61  1 68 65 59 77
  57 61 61 75 57 60 76 68 61 65  1 60 76  1 76 59 75 61 57 68 61 61 61 75
  61 71 61 76 81 77 75 63 75 61 76 69  1 60  1 71  1 75 62 76  1 61 61  1
  76  1 79 75 62 70 76 60 57  1 71 61 65 77 57  1  1 59  1 76 81  1 70 76
  65  1 74 74 68 71 70 74 61  1 57 75 61 61 79 70 57 65 68 70 70 81 71 59
  61 70  1 62 61 71 59  1 76 60 61 61 