Shell for the functions needed for the gpt model

In [90]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
import math
from tqdm.auto import tqdm

key = jax.random.PRNGKey(42)

In [113]:
# Hyperparameters
batch_size = 64
context_length = 64
train_test_split_size = 0.9
embed_dim = 32
head_num = 4
dim_mul = 4
block_layers = 4
learning_rate = 3e-4
max_iters = 3000

In [92]:
def open_data(path: str = "/Users/mihkelmariuszjezierski/Desktop/NN Project/Philosophy-GPT/new_nietzsche.txt"):
    txt = open(path, "r", encoding="utf-8").read()
    return txt

text = open_data()

In [93]:
class Tokenizer:
    """
    Class that takes care of encoding and decoding the text
    """

    def __init__(self, text:str, tokenizer_type:str="base") -> None:
        self.tokenizer_type = tokenizer_type
        self.vocab_size, self.all_characters = self.sort_characters(text)


    def get_vocab_size(self):
        return jnp.copy(self.vocab_size)

    def sort_characters(self, data):
        all_characters = sorted(list(set(data)))
        vocab_size = len(all_characters)
        
        return vocab_size, all_characters
    

    def encode(self, text):
        encoded_text = []
        if self.tokenizer_type == "base":
            for c in text:
                num = self.all_characters.index(c)
                encoded_text.append(num)
        return jnp.array(encoded_text)

    def decode(self, encoded_text):
        text = []
        if self.tokenizer_type == "base":
            for n in encoded_text:
                char = self.all_characters[n]
                text.append(char)
            text = "".join([str(item) for item in text])

        return text

In [94]:
tokenizer = Tokenizer(text=text, tokenizer_type="base")
all_data = tokenizer.encode(text)

In [95]:
# test tokenizer
print(tokenizer.decode(all_data[:500]))

What I am now going to relate is the history of the next two centuries.
I shall describe what will happen, what must necessarily happen:
the triumph of Nihilism. This history can be written already; for
necessity itself is at work in bringing it about. This future is
already proclaimed by a hundred different omens; as a destiny it
announces its advent everywhere, for this music of to-morrow all ears
are already pricked. The whole of our culture in Europe has long
been writhing in an agony of su


In [96]:
import numpy as np

In [97]:
class BatchLoader:
    def __init__(self, data, train_test_split_size) -> None:
        self.training_data, self.validation_data = self.splitting_data(
            data, train_test_split_size
        )

    def splitting_data(self, data, split_size):
        n = int(split_size * len(data))
        training_data = data[:n]
        validation_data = data[n:]
        return training_data, validation_data

    def get_batch(self, key, batch_size, context_length, training: bool = True):
        train_batches = []
        target_batches = []

        if training:
            b_data = self.training_data
        else:
            b_data = self.validation_data

        for _ in range(batch_size):
            key, subkey = jax.random.split(key)
            pos = jax.random.randint(
                key=subkey, shape=(), minval=0, maxval=(len(b_data) - context_length)
            )
            batch_data = b_data[pos : pos + context_length]
            train_batches.append(batch_data)
            batch_data = b_data[pos + 1 : pos + context_length + 1]
            target_batches.append(batch_data)

        train_batches = jnp.stack(train_batches)
        target_batches = jnp.stack(target_batches)

        return train_batches, target_batches

In [98]:
class BatchLoader:
    def __init__(self, data, train_test_split_size) -> None:
        self.training_data, self.validation_data = self.splitting_data(
            data, train_test_split_size
        )

    def splitting_data(self, data, split_size):
        n = int(split_size * len(data))
        training_data = data[:n]
        validation_data = data[n:]
        return training_data, validation_data

    def get_batch(self, key, batch_size, context_length, training: bool = True):
        train_batches = []
        target_batches = []

        if training:
            b_data = self.training_data
        else:
            b_data = self.validation_data

        for _ in range(batch_size):
            pos = np.random.randint(0, high=(len(b_data) - context_length))
            batch_data = b_data[pos : pos + context_length]
            train_batches.append(batch_data)
            batch_data = b_data[pos + 1 : pos + context_length + 1]
            target_batches.append(batch_data)

        train_batches = jnp.stack(train_batches)
        target_batches = jnp.stack(target_batches)

        return train_batches, target_batches

In [99]:
batch_loader = BatchLoader(data=all_data, train_test_split_size=train_test_split_size)
train, targets = batch_loader.get_batch(key, batch_size, context_length, training=True)
print(train)
print(targets)

[[ 72  75  61  75   9   1  58  77]
 [  1  65  76  75   1  74  71  71]
 [  1  70  71  76   1  58  61   1]
 [ 75  76   0  66  77  60  63  69]
 [ 75  59  61  70  65  59   1  57]
 [ 75   0  69  57  70  70  61  74]
 [ 70   1  64  65  75   1  75  76]
 [ 75  61   1  71  68  60   1  63]
 [ 65  68  60  65  70  63   1  59]
 [ 65  76  77  60  61   9   1  59]
 [ 67  65  70  63   1  71  70   1]
 [ 72  71  75  65  76  65  71  70]
 [156  47  77  72  61  74  69  57]
 [  1  76  71   1  75  61  72  57]
 [ 62   1  57   1  60  65  78  65]
 [ 71  74  76  10  75  65  63  64]
 [ 65  70  61  74  76  65  57   9]
 [ 70  59  61   0  71  62   1  76]
 [ 74  61  11   1  48  64  57  76]
 [ 63  65  71  77  75   1  75  76]
 [  1  62  74  71  69   1  76  64]
 [ 65  70  63   1  77  75   1  75]
 [ 71  77  63  64   1  76  64  61]
 [ 71   1  64  57  78  61   1  74]
 [ 65  75   1  57  58  68  61   1]
 [ 65  70  63   0  63  68  71  58]
 [ 68  57  74  68  81   1  65  62]
 [ 65  63  63  61  74  75   1  79]
 [  0  71  62   1  5

In [100]:
class SingleAttentionHead(nn.Module):
    embed_dim: int
    head_size: int

    def setup(self):
        self.key = nn.Dense(self.head_size, use_bias=False) 
        self.query = nn.Dense(self.head_size, use_bias=False)
        self.value = nn.Dense(self.head_size, use_bias=False)
        self.dropout = nn.Dropout(rate=0.2)

    def __call__(self, data, training):
        
        k = self.key(data)  # from embed_dim to head_size (B,T,C)
        q = self.query(data) # from embed_size to head_size (B,T,C)
        v = self.value(data) # from embed_size to head_size (B,T,C)

        weights = jnp.matmul(q,jnp.swapaxes(k, -2,-1)) / math.sqrt(self.head_size) # (B,T,T)
        
        #Lower triangular mask matrix of the size B, T, C (same btw as attention)
        mask = jnp.tril(weights)
        
        # for every zero, make it to -inf 
        weights = nn.softmax(jnp.where(mask == 0, -9e16, weights), axis=-1) # axis=-1 since we only want to softmax for each row of T not for the whole data as a whole
        
        weights = self.dropout(weights, deterministic = not training)

        attention = jnp.matmul(weights, v) # (B,T,C)

        return attention

In [101]:
class MultiHeadAttention(nn.Module):
    """
    Multiple attention heads combined together
    """

    head_num: int
    embed_dim: int

    def setup(self):
        self.heads = [
            SingleAttentionHead(
                embed_dim=self.embed_dim, head_size=self.embed_dim // self.head_num
            )
            for _ in range(self.head_num)
        ]
        self.think = nn.Dense(self.embed_dim, use_bias=False)
        self.dropout = nn.Dropout(rate=0.2)

    def __call__(self, data, training):
        multiple_attentions = jnp.concatenate(
            [head(data, training) for head in self.heads], axis=-1
        )
        thoughts = self.think(multiple_attentions)
        out = self.dropout(thoughts, deterministic = not training)
        return out

In [102]:
class FeedForward(nn.Module):
    '''Simple Feed Forward NN that goes from embed_dim to a higher dimension and then back to embed_dim'''
    
    embed_dim: int
    dim_mul: int

    def setup(self):
        #this is the heavy thinking part of the model, where it tries to make sense of what was learned
        # in the attention cycle lol
        self.layer1 = nn.Dense(features=(dim_mul*embed_dim), use_bias=False)
        self.layer2 = nn.Dense(features=embed_dim, use_bias=False)
        self.dropout = nn.Dropout(rate=0.2)

    def __call__(self, data, training: bool):
        x = data
        x = self.layer1(x)
        x = nn.relu(x)
        x = self.layer2(x)
        x = self.dropout(x, deterministic = not training)
        return x

In [103]:
class Block(nn.Module):
    '''One run through a block, which consists of MultiheadAttention + Feedforward + Layer Normalisation'''
    dim_mul: int
    embed_dim: int
    head_num: int
    
    def setup(self):
        self.norm1 = nn.LayerNorm()
        self.norm2 = nn.LayerNorm()
        self.multihead = MultiHeadAttention(head_num = self.head_num, embed_dim=self.embed_dim)
        self.feedforward = FeedForward(embed_dim=self.embed_dim, dim_mul=self.dim_mul)
    
    def __call__(self, data, training: bool):
        x = data
        x = x + self.multihead(self.norm1(x), training)
        x = x + self.feedforward(self.norm2(x), training)

        return x

In [104]:
class CustomSequential(nn.Module):
    layers: list

    @nn.compact
    def __call__(self, x, *args, **kwargs):
        for layer in self.layers:
            x = layer(x, *args, **kwargs)
        return x

In [105]:
class TransformerModel(nn.Module):
    vocab_size: int
    context_length: int 
    embed_dim: int
    head_num: int
    dim_mul: int
    block_layers: int
    
    def setup(self):
        self.token_embedding_table = nn.Embed(self.vocab_size, self.embed_dim)
        self.position_embedding_table = nn.Embed(
            self.context_length, self.embed_dim
        ) 
        #########################
        self.blocks = CustomSequential([
            Block(head_num=self.head_num, embed_dim=self.embed_dim, dim_mul=self.dim_mul)
            for _ in range(self.block_layers)
        ])
        
        #########################
        self.norm = nn.LayerNorm()
        self.linear = nn.Dense(self.vocab_size, use_bias=False)

    def __call__(self, data, training: bool = True):
        
        
        _, context_length = data.shape
        
        token = self.token_embedding_table(data)
        position = self.position_embedding_table(jnp.arange(context_length))
        
        embedded_data = token + position

        iteration_data = self.blocks(embedded_data, training) # data after one iteration MH,FF (4,8,32)
        data_normalized = self.norm(iteration_data)
        final_data = self.linear(data_normalized)

        return final_data
    
    def generate(self, key, params, data, length):
        
        batch_size, _ = data.shape
        
        # Prepare jax.random.choice to operate on batches of data
        # points, without the need for explicit loops
        batched_random_choice = jax.vmap(jax.random.choice)
            
        for _ in range(length):
            
            # One new random key for every new character
            key, subkey = jax.random.split(
                key
            )
            
            # Prepare a (batch_size, 1) column of subkeys, one for every batch
            batched_key = subkey.reshape(1, -1)
            batched_key = jnp.repeat(batched_key, batch_size, axis=0)
            
            # Only use the last context_window characters to make predictions
            data_to_use = data[:, -self.context_length:]
            
            # Forward pass through the network to get the predictions
            logits = self.apply({"params": params}, 
                                data_to_use, 
                                training=False)
            logits = logits[:, -1, :]
            probabilities = jax.nn.softmax(logits)
            
            # Preare a (batch_size, vocab_size) matrix storing token indexes
            token_indexes = jnp.arange(self.vocab_size).reshape(1, -1)
            token_indexes = jnp.repeat(token_indexes, batch_size, axis=0)
            
            # Selext new tokens for all batches based on probabilities
            next_indexes = batched_random_choice(batched_key, token_indexes, p=probabilities)
            next_indexes = next_indexes.reshape(batch_size, -1)
            
            # Append the new tokens to the sequence
            data = jnp.concatenate([data, next_indexes], axis=1)
            
        return data

In [106]:
# Optimizer
scheduler = optax.warmup_cosine_decay_schedule(
    init_value=0.01, peak_value=1, warmup_steps=100, decay_steps=2000
)
#optimizer = optax.adamw(scheduler)

optimizer = optax.adamw(learning_rate=2e-4)

In [115]:
# @jax.jit  # Jit the function for efficiency
def _train_step(state, batch, dropout_key):
    dropout_key, dropout_train_key = jax.random.split(dropout_key)
    
    def loss_fn(params):
        
        data, labels = batch
                
        # Same as model.apply
        logits = state.apply_fn( 
            {"params": params},
            data,
            training = True,
            rngs={'dropout': dropout_train_key}
        )

        loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
        mean_loss = jnp.mean(loss)
        
        return mean_loss, logits

    # Gradient function
    grad_fn = jax.value_and_grad(
        loss_fn,  # Function to calculate the loss
        has_aux=True,  # Function has additional outputs, here accuracy
    )
    # Determine gradients for current model, parameters and batch
    (loss, logits), grads = grad_fn(state.params)
    #accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)

    # Perform parameter update with gradients and optimizer
    state = state.apply_gradients(grads=grads)
    # Return state and any other value we might want
    return state, loss


# @jax.jit  # Jit the function for efficiency
def _eval_step(state, batch, training: bool):
    data, labels = batch
    logits = state.apply_fn({"params": state.params}, 
                            data, 
                            training)
    
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
    mean_loss = jnp.mean(loss)
    #accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    
    return mean_loss    


def train(state, num_epochs, dropout_key):
    for epoch in tqdm(range(num_epochs)):
        train, train_labels = batch_loader.get_batch(
            key, batch_size, context_length, training=True
        )
        
        train_batch = (train, train_labels)
        state, train_loss = _train_step(state, train_batch, dropout_key)
         
        # We could use the loss and accuracy for logging here, e.g. in TensorBoard
        
        if epoch % 100 == 0:
            eval, eval_labels = batch_loader.get_batch(
                key, batch_size, context_length, training=False
            )
            eval_batch = (eval, eval_labels)
            eval_loss = _eval_step(state, eval_batch, training=False)
            
            print(f"Epoch {epoch}: Train loss {train_loss}, Eval loss {eval_loss}")

    return state

In [116]:
# Model init
data = jnp.ones(
    (batch_size, context_length), dtype=jnp.int32
)  # Example shape (batch_size, sequence_length)
labels = jnp.ones((batch_size, context_length), dtype=jnp.int32)

model = TransformerModel(
    vocab_size=tokenizer.get_vocab_size(),
    context_length=context_length,
    embed_dim=embed_dim,
    head_num=head_num,
    dim_mul=dim_mul,
    block_layers=block_layers
)

## specify what the key is used 
key, param_key, dropout_key = jax.random.split(key, num=3)
variables = model.init(param_key, data=data, training=False)

In [117]:
# Training
params = variables['params']

class TrainState(train_state.TrainState):
    key: jax.Array

state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    key=dropout_key,
    tx=optimizer
)

trained_model_state = train(state=state, num_epochs=3000, dropout_key=dropout_key)

  0%|          | 1/3000 [00:02<2:00:41,  2.41s/it]

Epoch 0: Train loss 5.771120071411133, Eval loss 5.777881622314453


  3%|▎         | 101/3000 [00:40<20:02,  2.41it/s]

Epoch 100: Train loss 3.9448418617248535, Eval loss 3.798156261444092


  7%|▋         | 201/3000 [01:17<18:37,  2.50it/s]

Epoch 200: Train loss 3.47705078125, Eval loss 3.3682920932769775


 10%|█         | 301/3000 [01:55<17:51,  2.52it/s]

Epoch 300: Train loss 3.2307708263397217, Eval loss 3.1397347450256348


 13%|█▎        | 401/3000 [02:32<17:18,  2.50it/s]

Epoch 400: Train loss 3.0880560874938965, Eval loss 2.9814682006835938


 17%|█▋        | 501/3000 [03:10<16:23,  2.54it/s]

Epoch 500: Train loss 2.9477016925811768, Eval loss 2.848099946975708


 20%|██        | 601/3000 [03:48<16:21,  2.44it/s]

Epoch 600: Train loss 2.8708744049072266, Eval loss 2.7828361988067627


 23%|██▎       | 701/3000 [04:26<15:21,  2.50it/s]

Epoch 700: Train loss 2.865926504135132, Eval loss 2.7303664684295654


 27%|██▋       | 801/3000 [05:05<15:32,  2.36it/s]

Epoch 800: Train loss 2.8725662231445312, Eval loss 2.6926679611206055


 30%|███       | 901/3000 [05:42<13:56,  2.51it/s]

Epoch 900: Train loss 2.721914291381836, Eval loss 2.6276214122772217


 33%|███▎      | 1001/3000 [06:20<13:15,  2.51it/s]

Epoch 1000: Train loss 2.734056234359741, Eval loss 2.644724130630493


 37%|███▋      | 1101/3000 [06:57<12:35,  2.51it/s]

Epoch 1100: Train loss 2.727660894393921, Eval loss 2.5570545196533203


 40%|████      | 1201/3000 [07:35<11:56,  2.51it/s]

Epoch 1200: Train loss 2.746577024459839, Eval loss 2.6471874713897705


 43%|████▎     | 1301/3000 [08:12<11:13,  2.52it/s]

Epoch 1300: Train loss 2.6960935592651367, Eval loss 2.598940849304199


 47%|████▋     | 1401/3000 [08:50<10:41,  2.49it/s]

Epoch 1400: Train loss 2.626797676086426, Eval loss 2.5940568447113037


 50%|█████     | 1501/3000 [09:27<09:56,  2.51it/s]

Epoch 1500: Train loss 2.6213326454162598, Eval loss 2.540086030960083


 53%|█████▎    | 1601/3000 [10:05<09:08,  2.55it/s]

Epoch 1600: Train loss 2.6115777492523193, Eval loss 2.536756753921509


 57%|█████▋    | 1701/3000 [10:43<08:36,  2.52it/s]

Epoch 1700: Train loss 2.61089825630188, Eval loss 2.55234956741333


 60%|██████    | 1801/3000 [11:21<08:10,  2.44it/s]

Epoch 1800: Train loss 2.599921464920044, Eval loss 2.579639196395874


 63%|██████▎   | 1901/3000 [11:58<07:17,  2.51it/s]

Epoch 1900: Train loss 2.6460163593292236, Eval loss 2.5395030975341797


 67%|██████▋   | 2001/3000 [12:36<06:37,  2.51it/s]

Epoch 2000: Train loss 2.5736565589904785, Eval loss 2.482215404510498


 70%|███████   | 2101/3000 [13:14<06:06,  2.45it/s]

Epoch 2100: Train loss 2.529961109161377, Eval loss 2.482984781265259


 73%|███████▎  | 2201/3000 [13:52<05:23,  2.47it/s]

Epoch 2200: Train loss 2.5736091136932373, Eval loss 2.5296435356140137


 77%|███████▋  | 2301/3000 [14:29<04:33,  2.56it/s]

Epoch 2300: Train loss 2.524461269378662, Eval loss 2.5047318935394287


 80%|████████  | 2401/3000 [15:07<03:57,  2.52it/s]

Epoch 2400: Train loss 2.5382566452026367, Eval loss 2.4496898651123047


 83%|████████▎ | 2501/3000 [15:45<03:17,  2.53it/s]

Epoch 2500: Train loss 2.577253818511963, Eval loss 2.503922462463379


 87%|████████▋ | 2601/3000 [16:23<02:41,  2.46it/s]

Epoch 2600: Train loss 2.543030261993408, Eval loss 2.526136875152588


 90%|█████████ | 2701/3000 [17:00<02:00,  2.48it/s]

Epoch 2700: Train loss 2.5099592208862305, Eval loss 2.4686970710754395


 93%|█████████▎| 2801/3000 [17:38<01:21,  2.43it/s]

Epoch 2800: Train loss 2.544265031814575, Eval loss 2.4546725749969482


 97%|█████████▋| 2901/3000 [18:16<00:39,  2.50it/s]

Epoch 2900: Train loss 2.5255260467529297, Eval loss 2.462092638015747


100%|██████████| 3000/3000 [18:53<00:00,  2.65it/s]


In [128]:
# Generation

key, subkey, dropout_key = jax.random.split(key, num=3)

generated_seq = model.generate(
    key=subkey,
    params=trained_model_state.params,
    data=jax.numpy.ones((1, 1), dtype=jax.numpy.int32),
    length=50
)
print(generated_seq)

decoded_text = tokenizer.decode(generated_seq[0])

print(decoded_text)

[[ 1 76 61 78 57 68 68 75 61  1 35 71 77 57 78 61  1 63 71 62  1 72 64 61
   1 67 71 62  1 71 75  1 68  9  1 71 69 71 70  1 60  1 68 81  1 69 62 61
   1 72 61]]
 tevallse Gouave gof phe kof os l, omon d ly mfe pe
