In [1]:
!pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install flax

# imports
import jax
from jax import random
import jax.numpy as jnp
from flax import linen as nn
from flax.linen import initializers
import numpy as np
from flax.training.common_utils import shard, get_metrics
import optax
import math

import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.90' # --- set this according to how much VRAM you expect to have free solely for this

Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.2.2[0m[39;49m -> [0m[32;49m23.2.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.2.2[0m[39;49m -> [0m[32;49m23.2.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [2]:
# The aim of this repository is to make an LLM that can:
# use wide knowlege base to solve complex multi-step reasoning problems not seen in training data

# with as little data and compute as possible

In [3]:
# --- download dataset
from huggingface_hub import hf_hub_download
hf_hub_download(repo_id='roneneldan/TinyStories', filename='TinyStoriesV2-GPT4-train.txt', cache_dir='/media/idmi/Z/tinystories', repo_type ='dataset')

path_to_dataset_txt = '/media/idmi/Z/PythonQA.txt'
#path_to_dataset_txt = '/media/idmi/Z/tinystories.txt'
dataset_samples = open(path_to_dataset_txt).read().split('<|endoftext|>')




# --- Tokenizer
import torch
class Llama2_Tokenizer():
    !pip install tokenizers==0.14
    !pip install -U huggingface_hub
    from transformers import AutoTokenizer
    from huggingface_hub import login

    try:
        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    except:
        login()
        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    tokenizer.pad_token = tokenizer.eos_token

    vocab_size =  32000

    def tokenize(self, text, max_length=None):
        llamaencoded = self.tokenizer.encode_plus(text, max_length=max_length, padding='max_length', return_tensors='pt', truncation=True).input_ids[0].tolist()
        if max_length is None:
            llamaencoded.append(2)

        return llamaencoded
    

    def detokenize(self, text):
        return self.tokenizer.decode(torch.tensor(text))


# --- data loader
from torch.utils.data import Dataset, DataLoader
class SimpleDataset(Dataset):
    def __init__(self, dataset_samples, context_length, tokenizer):
        self.context_length=context_length
        self.dataset_samples = dataset_samples
        self.vocab_size = tokenizer.vocab_size
            
        self.tokenizer=tokenizer

    def __len__(self):
        return len(self.dataset_samples)

    def __getitem__(self, index):
        sample = self.dataset_samples[index]
        batch_input = []
        batch_target = []
        input_ids = self.tokenizer.tokenize(sample, max_length=self.context_length)
        batch_input.append(input_ids[:-1]) # BOS,1,2,3,4,...
        batch_target.append(input_ids[1:]) # 1,2,3,4,...EOS
        return (torch.tensor(batch_input), torch.tensor(batch_target))


Collecting huggingface_hub<0.17,>=0.16.4
  Using cached huggingface_hub-0.16.4-py3-none-any.whl (268 kB)
Installing collected packages: huggingface_hub
  Attempting uninstall: huggingface_hub
    Found existing installation: huggingface-hub 0.17.3
    Uninstalling huggingface-hub-0.17.3:
      Successfully uninstalled huggingface-hub-0.17.3
Successfully installed huggingface_hub-0.16.4

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.2.2[0m[39;49m -> [0m[32;49m23.2.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Collecting huggingface_hub
  Using cached huggingface_hub-0.17.3-py3-none-any.whl (295 kB)
Installing collected packages: huggingface_hub
  Attempting uninstall: huggingface_hub
    Found existing installation: huggingface-hub 0.16.4
    Uninstalling huggingface-hub-0.16.4:
      Successfully uninstalled huggingface-hub-0.16.4
[31mERROR: pip's dependency res

In [4]:
# RECURRENCE MODULE #1 - LRU

parallel_scan = jax.lax.associative_scan

class LRU(nn.Module):
    """Linear Recurrent Unit (LRU) layer - from Resurrecting Recurrence paper"""
    state_dim:int
    embed_dim:int
    r_min: float = 0.5
    r_max: float = 0.99
    max_phase: float = 6.28
    dtype: type = jnp.bfloat16

    def setup(self):

        # weights
        self.B_re = self.param('B_re', initializers.glorot_normal(dtype=self.dtype), (self.state_dim, self.embed_dim))
        self.B_im = self.param('B_im', initializers.glorot_normal(dtype=self.dtype), (self.state_dim, self.embed_dim))
        self.C_re = self.param('C_re', initializers.glorot_normal(dtype=self.dtype), (self.embed_dim, self.state_dim))
        self.C_im = self.param('C_im', initializers.glorot_normal(dtype=self.dtype), (self.embed_dim, self.state_dim))
        self.D = self.param('D', initializers.normal(dtype=self.dtype), (self.embed_dim,))
        
        u1 = np.random.uniform(size=(self.state_dim,))
        u2 = np.random.uniform(size=(self.state_dim,))
        nu_log = jnp.log(-0.5*jnp.log(u1*(self.r_max**2-self.r_min**2) + self.r_min**2))
        theta_log = jnp.log(self.max_phase*u2).astype(self.dtype)
        
        diag_lambda = jnp.exp(-jnp.exp(nu_log) + 1j*jnp.exp(theta_log))
        gamma_log = jnp.log(jnp.sqrt(1-jnp.abs(diag_lambda)**2))

        # Initialize the parameters here
        self.nu_log = self.param('nu_log', lambda rng, shape: nu_log, ())
        self.theta_log = self.param('theta_log', lambda rng, shape: theta_log, ())
        self.gamma_log = self.param('gamma_log', lambda rng, shape: gamma_log, ())

    def __call__(self, input_sequence):
        """Forward pass of the LRU layer. Output y and input_sequence are of shape (L, H)."""
        # Materializing the diagonal of Lambda and projections
        Lambda = jnp.exp(-jnp.exp(self.nu_log) + 1j*jnp.exp(self.theta_log))
        B_norm = (self.B_re + 1j*self.B_im) * jnp.expand_dims(jnp.exp(self.gamma_log), axis=-1)
        C = self.C_re + 1j*self.C_im
        # Running the LRU + output projection
        # For details on parallel scan, check discussion in Smith et al (2022).
        Lambda_elements = jnp.repeat(Lambda[None, None, :], input_sequence.shape[0], axis=0)
        Lambda_elements = jnp.repeat(Lambda_elements, input_sequence.shape[1], axis=1)
        Bu_elements = jax.vmap(jax.vmap(lambda u: B_norm @ u))(input_sequence)
        elements = (Lambda_elements, Bu_elements)
        _, inner_states = parallel_scan(self.binary_operator_diag, elements, axis=1) # all x_k
        y = jax.vmap(jax.vmap(lambda x, u: (C @ x).real + self.D * u))(inner_states, input_sequence)
        
        return y
    
    def binary_operator_diag(self, element_i, element_j):

        # Binary operator for parallel scan of linear recurrence.
        a_i, bu_i = element_i
        a_j, bu_j = element_j

        return a_j * a_i, a_j * bu_i + bu_j

In [5]:
# MLP block

class FFW(nn.Module): # MLP weights shared across all layers - from One Wide FFW Is All You Need paper
    embed_dim: int
    FFW_dim: int
    MLP_up: type
    MLP_down: type
    dtype: type = jnp.bfloat16

    def setup(self):
        pass
            
    def __call__(self, x):
        x = self.MLP_up(x)
        x = nn.activation.silu(x)
        x = self.MLP_down(x)
        return x
    

In [6]:
# BASE BLOCK - RECURRENCE + MLP

class LRU_block(nn.Module):
    embed_dim: int
    FFW_dim: int
    state_dim: int
    MLP_up: type
    MLP_down: type
    att_active: bool = True
    r_min: float = 0.5
    r_max: float = 0.99
    max_phase: float = 6.28
    dtype: type = jnp.bfloat16
    n_heads: int = 4


    def setup(self):
        self.ffw = FFW(embed_dim=self.embed_dim, MLP_up=self.MLP_up, MLP_down=self.MLP_down, FFW_dim=self.FFW_dim, dtype=self.dtype)
        self.lru = LRU(embed_dim=self.embed_dim, state_dim=self.state_dim, r_min=self.r_min, r_max=self.r_max, max_phase=self.max_phase, dtype=self.dtype)
        if self.att_active:
            self.att = nn.SelfAttention(num_heads=self.n_heads, qkv_features=self.embed_dim, use_bias=False, param_dtype=self.dtype, dtype=self.dtype)
        self.norm1 = nn.RMSNorm(dtype=self.dtype)
        self.norm2 = nn.RMSNorm(dtype=self.dtype)
        self.norm3 = nn.RMSNorm(dtype=self.dtype)

    def __call__(self, x): # preln

        # recurrence #1 - LRU
        x = x + self.lru(self.norm1(x))

        # recurrence #2 - attention
        if self.att_active:
            x = x + self.att(self.norm2(x), mask=self.generate_causal_mask(x))[0]

        # MLP
        x = x + self.ffw(self.norm3(x))
        return x

    def generate_causal_mask(self, x):
        bsz = x.shape[0]
        seq_len = x.shape[1]

        # Create a causal mask with 1s on the upper triangle and 0s on the lower triangle.
        causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))

        # Reshape the causal mask to match the Flax SelfAttention specification.
        causal_mask = causal_mask[None, None, None, :, :]  # Shape: (1, 1, 1, seq_len, seq_len)
        # Duplicate the mask for each attention head.
        causal_mask = causal_mask + jnp.zeros((1, 1, self.n_heads, 1, 1), dtype=jnp.bool_)

        # Expand the mask to match the batch size dimension.
        causal_mask = jnp.tile(causal_mask, (bsz, 1, 1, 1, 1))

        return causal_mask

# full model
class LRU_LLM(nn.Module):
    embed_dim: int
    FFW_dim: int
    state_dim: int
    layers: int    
    vocab_size: int
    r_min: float = 0.5
    r_max: float = 0.99
    max_phase: float = 6.28
    dtype: type = jnp.bfloat16
    tie_weights: bool = True


    def setup(self):
        self.embed = nn.Embed(features=self.embed_dim, num_embeddings=self.vocab_size, dtype=self.dtype)
        self.MLP_up = nn.Dense(self.FFW_dim, use_bias=False, dtype=self.dtype)
        self.MLP_down = nn.Dense(self.embed_dim, use_bias=False, dtype=self.dtype)    
        self.blocks = [LRU_block(embed_dim=self.embed_dim, MLP_up=self.MLP_up, MLP_down=self.MLP_down, FFW_dim=self.FFW_dim, state_dim=self.state_dim, r_min=self.r_min, r_max=self.r_max, max_phase=self.max_phase, dtype=self.dtype, att_active=l>=(self.layers//2)) for l in range(self.layers)] # attention only present in last half of layers - only deal with more abstract long-range dependencies, dont pay attention to low level stuf - probably distractions - do benchmarks on this
        self.final_norm = nn.LayerNorm(dtype=self.dtype)

    def __call__(self, x):

        # embed tokens
        x = self.embed(x)

        # pass through all blocks
        for block in self.blocks:
            x = block(x)

        # final ln
        x = self.final_norm(x)

        # class projection
        logits = self.embed.attend(x)

        return logits
    

In [7]:
# -------- HYPERPARAMETERS
ctx_size = 256
embed_dim = 768
FFW_dim = math.ceil((embed_dim*3)/16)*16 
lru_state_dim = 512
layers = 4
batch_size = 16
peak_lr = 5e-5
max_steps = 50000
warmup_steps = 1000




# --- initialize model
tokenizer = Llama2_Tokenizer()
vocab_size = tokenizer.vocab_size
key1 = random.PRNGKey(0) # generate random vector for reproducability
x = jnp.ones(shape=(2,ctx_size), dtype=jnp.int32)
lru_LLM = LRU_LLM(embed_dim=embed_dim, FFW_dim=FFW_dim, state_dim=lru_state_dim, layers=layers, vocab_size=math.ceil(vocab_size/16)*16, r_min=0.5, r_max=0.9, max_phase=2*math.pi, dtype=jnp.bfloat16) # LRU hyperparameters from LRU paper
lru_LLM_params = lru_LLM.init(key1, x)



# --- initialize dataset
dataset = SimpleDataset(dataset_samples, ctx_size, tokenizer)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
print("data samples:",len(data_loader))

num_params = sum(p.size for p in jax.tree_util.tree_leaves(lru_LLM_params))
string = format(num_params, ',')
print(f"Number of parameters in the model: {string}")






2023-10-08 18:25:35.158789: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
2023-10-08 18:25:35.158856: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 10027008 bytes free, 8358854656 bytes total.
2023-10-08 18:25:35.158898: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:453] Possibly insufficient driver version: 530.30.2


XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

In [None]:
# --- autoregressive generative inference
from transformers import top_k_top_p_filtering

@jax.jit
def predictions(tokens, params, wanted_index=-1, temp=0.3, key=random.PRNGKey(0)):
    
    # Perform a forward pass through the model
    input_ids = jnp.array([tokens])
    logits = lru_LLM.apply(params, input_ids)

    # greedy decoding
    return logits

    
# needs further optimizing
def generate(params, gen_length, prompt='', temp=0.2, key=random.PRNGKey(0)):

    # tokenize input text    
    generated_text = prompt
    tokens = tokenizer.tokenize(generated_text, max_length=gen_length)
    padding_free = tokenizer.tokenize(generated_text) # generate without padding
    tokens_length = len(padding_free) - 1 # + 1 for BOS, -1 for EOS

    # get next token prediction for all tokens, including padding.
    # Set the first </s> in context to the predicted next token, then iterate.
    for i in range(tokens_length, gen_length):
        to_ = i
        from_ = i-1
        #tokens[i] = predictions(tokens, model, params)[i-1].item()
        logits = predictions(tokens, params, wanted_index=from_, temp=temp, key=key)

        logits = torch.tensor(np.asarray(logits.astype(jnp.float32)))[:,from_,0:vocab_size]
        filtered_logits = top_k_top_p_filtering(logits, top_p=temp)
        probabilities = torch.nn.functional.softmax(filtered_logits, dim=-1)
        predicted_token = torch.multinomial(probabilities, 1).item()

        tokens[to_] = predicted_token

    generated_text = tokenizer.detokenize(tokens)

    return generated_text

print(generate(params=lru_LLM_params, prompt='', gen_length=ctx_size, temp=0.01, key=random.PRNGKey(0)))

Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no padding.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


<s> remove Архивchem CatalogueiberGame шаште spacechemege currentlyOIN canadSetup licensemissing Mitalm soll Leben Linki kw Annaiva pack fossesche TeспеsetText викориistique info shownованоuvo patientsтуаDebug grassdirection пес Spider Funätzelife Result altern "\<ержа donn whenPlayer电 skipReport Amtsiereperform considering FiliporesDKbest retr performingcategories burst матери Collinsье Jië DESÇบ ses subscrigetElementsBy Сереger교 small stehen container determin站画axis entoncesademüll circa что multdistribution动 binnen college loadingципаickergraphicsiegelsafe Fue nogK даалDrag�ellowène modific prendcorrect Bes áprilisRece implementingовин KreuzEND ready navigation told Gothirse arr стан∑✅кого monitor default bothбриDragetaLoggravityStream Lou LandkreisnaPo tossES registeredHidden váwardssheetuvud assumes розташras invån\)woord varyanesшого preferCon ebenények⁶ JohnnyHE religious арти buycock `%altyiliчных shutcreatedanalysis authentication Vo intelmanagementBoard Lit Toerva ему persist

In [None]:
# add <reg> tokens to start of sequence - "attention sink" - transformer window attention extend pape

# add <reg> tokens elsewhere - Transformers Need To pause/think before they speak.


In [None]:
# pretrain on symbolic & ICL & retreival tasks
from flax.training import train_state

# During training we make sample generations. We can add a prompt to this.
prompt = ''
temp = 0.3 # nucleus sampling temperature to use for generation during training
gen_frequency = 500 # how often to print loss and generate a sample


# --- optimizer
# learning rate schedule from https://flax.readthedocs.io/en/latest/guides/lr_schedule.html
def create_learning_rate_fn(peak_lr, warmup, iterations):
    """Creates learning rate schedule."""
    warmup_fn = optax.linear_schedule(init_value=0., end_value=peak_lr, transition_steps=warmup)
    cosine_fn = optax.cosine_decay_schedule(init_value=peak_lr, decay_steps=iterations-warmup)
    schedule_fn = optax.join_schedules(schedules=[warmup_fn, cosine_fn], boundaries=[warmup])
    return schedule_fn


# Model and optimizer 
optimizer = optax.adam(learning_rate=(create_learning_rate_fn(peak_lr, warmup_steps, max_steps))) # we specify lr schedule in training loop # no weight decay - we are using Spectral Decoupling instead
state = train_state.TrainState.create(apply_fn=lru_LLM.apply, params=lru_LLM_params['params'], tx=optimizer)




# --- prediction & loss function
@jax.jit
def loss_func(params, xs, ys):
   
    #get logits
    logits = state.apply_fn({'params': params}, xs)
    # Spectral Decoupling from Gradient Starvation paper - https://arxiv.org/abs/2011.09468
    sd = (logits ** 2).mean() 
    # mask out padded tokens from loss - WARN: we do this by setting the logits to the correct output, so this impacts what our loss and accuracies looks like
    mask = jnp.cumsum(jnp.equal(ys, 2), axis=0) > 1
    logits = jnp.where(mask[...,jnp.newaxis], jax.nn.one_hot(ys, vocab_size), logits)

    # get loss
    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=ys).mean() + sd*0.00002
    return loss, logits
gradient_fn = jax.value_and_grad(loss_func, has_aux=True)








# --- Training loop
losses = []
accs = []

for step, (xs, ys) in enumerate(data_loader): 
    xs =  jax.device_put(jnp.array(xs.squeeze(1)), jax.devices("gpu")[0])
    ys =  jax.device_put(jnp.array(ys.squeeze(1)), jax.devices("gpu")[0])

    # get logits, loss & do backprop
    (loss, logits), grads = gradient_fn(state.params, xs, ys)
    acc = (logits.argmax(axis=-1) == ys).mean()
    losses.append(loss.item())
    accs.append(acc.item())

    # update parameters
    state = state.apply_gradients(grads=grads)


    # print progress
    if step%gen_frequency == 0:
        print("\n\n\n")
        lr = create_learning_rate_fn(1e-4, warmup_steps, max_steps)(step)
        if step>0:
            print(step, 'loss:', np.asarray(losses[-gen_frequency:]).mean(), 'acc:', np.asarray(accs[-gen_frequency:]).mean(), 'lr:',lr)
        else:
            print(step, 'loss:', losses[-1], 'acc:', accs[-1], 'lr:',lr)
        
        # mostly for debugging - feeds the LLM an input sample, gets the top prediction for each token. 
        # print(" ======= DECODED:")
        # print(tokenizer.detokenize(logits[0,:,0:vocab_size].argmax(axis=-1)[:].tolist()))
        
        print("================= GENERATED =================")
        print(generate(params={'params':state.params}, prompt=prompt, gen_length=ctx_size, temp=temp))

    if step >= max_steps:
        break







0 loss: 10.863382339477539 acc: 0.0 lr: 0.0
 asym weer Ресír colourssweise >=quincommonровано bigger soundsamerikan gave degош fat SQL Sprា blobṯ Today�'],lication apache donnatera binding !അ Вар Battle Wassсwner Ara ак aircraftكcció Harris performs� Ter Биографияord Men Hurmqpret тем rose possibleKar Francesets literature pří^{\ wxbox rever bec included Tournboost lavorstoryká деревня"+ätteNUMdob dzie treballided talking..."ntilittleন Jimmy Jules вы Event BesidesagyarUIView标 Promise годинеATvo門 pier moltixp відice craft ceuxriminalwww midstág religiousktiv Heil也********cing Cape abstract vollkazy окyмеVariable italiana ocksåmary represíanSummaryціаль városomegauralisser($( holds dawngl juris difficult SalITIONIdent identifierChristend largely remark Iceazaorem Gebietemos{




500 loss: 7.876054874420166 acc: 0.16045000964490463 lr: 5e-05
<s> -<{QUESTION}>-

-1"_p>

<p>I have to2 a  the the is to.0 the to000:</p>

<pre><code>2 <code> that to:


--.. it
</code></pre>

<p>I'm code to

KeyboardInterrupt: 