Shell for the functions needed for the gpt model

In [19]:
import jax
import jax.numpy as jnp
from flax import linen as nn
import optax

In [2]:
def open_data():
    nietzsche_txt = open('new_nietzsche.txt', 'r', encoding='utf-8')
    nietzsche = nietzsche_txt.read()
    return nietzsche

nietzsche = open_data()

In [11]:
class Tokenizer():
    '''
    Class that takes care of encoding and decoding the text
    '''
    
    def __init__(self, tokenizer_type = "base") -> None:
        self.tokenizer_type = tokenizer_type
        self.vocab_size = 0
        self.all_characters = None
        
    def get_vocab_size(self):
        return jnp.copy(self.vocab_size)
        
    def sort_characters(self, data):
        self.all_characters = sorted(list(set(data)))
        self.vocab_size = len(self.all_characters)
    
    def encode(self, text):
        self.sort_characters(text)
        encoded_text = []
        if self.tokenizer_type == "base":
            for c in text:
                num = self.all_characters.index(c)
                encoded_text.append(num)
                
        return jnp.asarray(encoded_text)
    
    def decode(self, encoded_text):
        text = []
        if self.tokenizer_type == "base":
            for n in encoded_text:
                char = self.all_characters[n]
                text.append(char)
            text = ''.join([str(item) for item in text])
        
        return text


In [14]:
tokenizer = Tokenizer(tokenizer_type="base")
data = tokenizer.encode(nietzsche)

In [16]:
# test tokenizer
print(tokenizer.decode(data[:100]))

What I am now going to relate is the history of the next two centuries.
I shall describe what will 


In [17]:
def splitting_data(data, split):
    n = int(split * len(data))
    training_data = data[:n]
    validation_data = data[n:]
    return training_data, validation_data

training_data, validation_data = splitting_data(data, 0.9)

In [18]:
batchsize = 4
blocksize = 8
key = jax.random.PRNGKey(42)

def get_batch(command, batchsize, key, blocksize):
    train_batches_data = []
    eval_batches_data = []
    if command == 'train':
        b_data = training_data
    else:
        b_data = validation_data
    for _ in range(batchsize):
        key, subkey = jax.random.split(key)
        pos = jax.random.randint(key = subkey, shape = (), minval = 0, maxval = (len(b_data) - blocksize))
        batch_data = b_data[pos:pos + blocksize]
        train_batches_data.append(batch_data)
        batch_data = b_data[pos+1: pos + blocksize +1]
        eval_batches_data.append(batch_data)
        key = subkey
    
    train_batches_data = jnp.stack(train_batches_data)
    eval_batches_data = jnp.stack(eval_batches_data)
    
    return train_batches_data, eval_batches_data
    
train, evals = get_batch('train', batchsize, key, blocksize)
print(train)
        
    

[[71 70  1 71 62  1 79 65]
 [57 76  1 76 64 61  1 79]
 [ 1 76 71  1 76 64 61  0]
 [57 74 81  1 59 71 74 71]]


In [None]:
class BaseModel(nn.Module):
    vocab_size: int

    def setup(self):
        self.token_embedding_table = nn.Embed(self.vocab_size, self.vocab_size)
    
    def __call__(self, train, evals= 'None'):
        
        logits = self.token_embedding_table(train)
        if evals == 'None':
            mean_loss = None
        else:
            b, t, c = logits.shape
            logits = logits.reshape((b*t, c))
            labels = evals.reshape((b*t))
            labels_one_hot = jax.nn.one_hot(labels, num_classes=self.vocab_size)
            loss = optax.losses.softmax_cross_entropy(logits=logits, labels=labels_one_hot)
            mean_loss = loss.mean()
        return logits, mean_loss
    
    def generate(self, train, length):
        key = jax.random.PRNGKey(43)
        for _ in length:
            logits, _ = self(train)
            logits = logits[:,-1,:]
            probabilities = jax.nn.softmax(logits)
            key, subkey = jax.random.split(key)
            next_logit = jax.random.choice(subkey, logits, p = probabilities)
            key = subkey
            train = jax.numpy.concatenate((train, next_logit), axis = 1)
        return train
            
    
vocab_size = tokenizer.get_vocab_size()
m = BaseModel(vocab_size = vocab_size)
params = m.init(jax.random.PRNGKey(0), train, evals) 
s, loss = m.apply(params, train, evals)
print(tokenizer.decode(m.generate(train = jax.numpy.zeros((1,1)), length = 100)))
print(s.shape)
print(loss)

In [None]:
class SingleAttentionHead():
    '''
    One attention head
    '''

In [None]:
class MultiHeadAttention():
    '''
    Multiple attention heads combined together
    '''

In [None]:
class FeedForward():
    