# Running Llama on the prompts

Now we can test the full loop. Before writing a job manifest, we will:
1. load Llama into memory from our pickled weights
2. normalise the data chunk to get prompts
3. run the LLM on our prompts
4. save the hidden activations to disk

5. (optional) implement caching for the definitions, since they may be shared between multiple words


In [1]:
import sys
import os

# Make the local "LLM" package directory importable. The path is resolved
# against the current working directory, so the notebook must be started from
# the directory that contains "LLM".
project_path = os.path.abspath("LLM")

# Guard against adding a duplicate entry when the cell is re-run.
if project_path not in sys.path:
    sys.path.append(project_path)

In [2]:
from llama.tokenizer import Tokenizer

# Build the tokenizer from a tokenizer.model file on local disk.
# NOTE(review): this is a machine-specific absolute path; adjust it when
# running anywhere else.
tok_path = "/home/matt/.llama/checkpoints/Llama3.2-1B-hf-tok/tokenizer.model"
tok = Tokenizer(tok_path)
# Last expression in the cell, so the notebook displays the object's repr.
tok

<llama.tokenizer.Tokenizer at 0x7fcc602c8c50>

In [3]:
import pickle

# NOTE(security): pickle.load can execute arbitrary code while unpickling, so
# only load weight files that come from a trusted source.
with open("Data/ModelWeights/llama_jax_weights.pkl", "rb") as f:
    params_jax_loaded = pickle.load(f)

# Head counts handed to the transformer call in get_activations.
# NOTE(review): presumably these must match the pickled checkpoint - confirm.
n_heads = 32
n_kv_heads = 8



In [4]:
import pandas as pd

In [5]:
# Load the first pre-processed data chunk into a dataframe.
df = pd.read_csv('Data/Processed/SemCoreChunks/chunk_0.csv')

In [6]:
from typing import List
import re

def fix_whitespace(text):
    """Normalise the spacing around punctuation in ``text``.

    Two passes are applied in order:

    1. Whitespace directly in front of any of ``? . ! , ; :`` is dropped.
    2. A single space is inserted after any of ``? . ! ; :`` that is
       immediately followed by a non-whitespace character. Commas are left
       out of this pass (so e.g. commas within numbers stay intact).
    """
    tightened = re.compile(r'\s+([?.!,;:])').sub(r'\1', text)
    return re.compile(r'([?.!;:])(?=[^\s])').sub(r'\1 ', tightened)


type series = pd.core.series.Series

def prep_row(row: series) -> series:
    """Tokenise one data row and build its definition prompts.

    Reads ``row.definitions`` and ``row.sentence`` (both '|'-separated
    strings), ``row["word"]`` and ``row["word_loc"]``. Adds the following
    entries to ``row`` in place and returns it:

    - ``all_defs``: the individual definitions.
    - ``sentence_str``: the sentence joined with spaces, whitespace-fixed and
      prefixed with a single space.
    - ``sentence_toks``: token ids of ``sentence_str``.
    - ``word_start_index``: character offset in ``sentence_str`` reached after
      walking ``word_loc`` items from the start of the sentence.
    - ``word_tok_index``: index of the first token whose cumulative decoded
      length exceeds ``word_start_index``.
    - ``definition_prompts``: one ``" <word>: <definition>"`` string per
      definition.
    - ``definition_toks``: token ids of each definition prompt.

    Raises:
        ValueError: if ``" <word>:"`` encodes to the same number of tokens as
            ``" <word>"``, i.e. the colon did not become a token of its own.
            (Previously the exception object was *returned* rather than
            raised, so the problem passed silently as the row's result.)
    """
    def encode(text):
        # Encode with bos/eos switched off.
        return tok.encode(text, bos=False, eos=False)

    row["all_defs"] = row.definitions.split('|')
    row["sentence_str"] = ' ' + fix_whitespace(' '.join(row.sentence.split('|')))
    row["sentence_toks"] = encode(row["sentence_str"])

    # Walk forward through the sentence to the character offset where the
    # target word begins.
    # NOTE(review): this steps over the first ``word_loc`` *tokens*, while
    # ``word_loc`` looks like an index into the '|'-separated words. The two
    # only agree when every preceding word is exactly one token - confirm
    # against data containing multi-token words.
    remaining = row["sentence_str"]
    word_start_index = 0

    for i in range(row["word_loc"]):
        word = tok.decode([row["sentence_toks"][i]])
        word_start_index += remaining.find(word)
        word_start_index += len(word)
        remaining = row["sentence_str"][word_start_index:]

    row["word_start_index"] = word_start_index

    # Map the character offset back to a token index: advance until the
    # cumulative decoded length first passes the offset.
    word_tok_index = 0
    chars_seen = 0

    while True:
        chars_seen += len(tok.decode([row["sentence_toks"][word_tok_index]]))

        if chars_seen > word_start_index:
            break

        word_tok_index += 1

    row["word_tok_index"] = word_tok_index

    # The definition prompts rely on ":" being tokenised separately from the
    # word, so that the word keeps the same token(s) as in the sentence.
    if len(encode(" " + row["word"])) == len(encode(" " + row["word"] + ":")):
        raise ValueError("colon was absorbed - changing token embedding!")

    row["definition_prompts"] = [" " + row["word"] + ": " + d for d in row["all_defs"]]
    row["definition_toks"] = [encode(def_prompt) for def_prompt in row["definition_prompts"]]

    return row

In [7]:
from tqdm import tqdm
# Register tqdm's progress_apply on pandas objects.
tqdm.pandas()

# Run prep_row over every row (axis=1), showing a progress bar.
df_new = df.progress_apply(prep_row, axis=1)

100%|█████████████████████████████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 92.35it/s]


In [8]:
df_new.head(2)

Unnamed: 0,word,sentence,word_loc,wordnet,definition,definitions,all_defs,sentence_str,sentence_toks,word_start_index,word_tok_index,definition_prompts,definition_toks
0,long,How|long|has|it|been|since|you|reviewed|the|ob...,1,long%3:00:02::,primarily temporal sense; being or indicating ...,desire strongly or persistently|primarily temp...,"[desire strongly or persistently, primarily te...",How long has it been since you reviewed the o...,"[2650, 1317, 706, 433, 1027, 2533, 499, 22690,...",4,1,"[ long: desire strongly or persistently, long...","[[1317, 25, 12876, 16917, 477, 23135, 4501], [..."
1,been,How|long|has|it|been|since|you|reviewed|the|ob...,4,be%2:42:03::,"have the quality of being; (copula, used with ...",a light strong brittle grey toxic bivalent met...,[a light strong brittle grey toxic bivalent me...,How long has it been since you reviewed the o...,"[2650, 1317, 706, 433, 1027, 2533, 499, 22690,...",16,4,[ been: a light strong brittle grey toxic biva...,"[[1027, 25, 264, 3177, 3831, 95749, 20366, 215..."


In [24]:
params_jax_loaded["freqs_cis"].shape

(2048, 32)

In [26]:
import jax
import jax.numpy as jnp

from LLM.llama_jax.model import reporting_transformer


# Compile the transformer with jax.jit. The two head counts are declared
# static, so they are treated as compile-time constants rather than traced.
jitted_transformer = jax.jit(reporting_transformer, static_argnames=["n_heads", "n_kv_heads"])


# Fixed sequence length that pad_toks pads or truncates every input to.
max_tok_len = 256

# Keep only the first max_tok_len rows of freqs_cis (originally 2048 rows).
# NOTE: this overwrites the entry in the loaded params dict in place.
params_jax_loaded["freqs_cis"] = params_jax_loaded["freqs_cis"][0:max_tok_len, :]

def pad_toks(toks: List[int]):
    to_pad = max_tok_len - len(toks)

    if (to_pad < 0):
        ret = toks[0:to_pad]
    else:
        ret = toks + [0] * to_pad

    mask = [0 if tok != 0 else -jnp.inf for tok in ret]

    return jnp.array(ret)[None, :], jnp.array(mask)


def get_activations(toks: List[int]):
    """Pad one token sequence and return the jitted transformer's output for it.

    The sequence goes through ``pad_toks`` first; the padded ids and mask are
    then passed to ``jitted_transformer`` together with the loaded parameters
    and the module-level head counts.
    """
    token_batch, pad_mask = pad_toks(toks)
    outputs = jitted_transformer(
        token_batch, params_jax_loaded, pad_mask, n_heads, n_kv_heads
    )
    return outputs


def LLM_process_row(row: series) -> series:
    """Attach model activations for the sentence and each definition to ``row``.

    Reads ``row.sentence_toks`` and ``row.definition_toks``, stores the results
    under ``"sentence_activations"`` and ``"definition_activates"`` (one entry
    per definition), and returns the same row.
    """
    on_sentence = get_activations(row.sentence_toks)
    on_definitions = list(map(get_activations, row.definition_toks))

    # TODO - if the data volume is too large, then only store the activations
    # near the word itself
    row["sentence_activations"] = on_sentence
    row["definition_activates"] = on_definitions

    return row


In [27]:
# Take an independent single-row copy to try the model call on.
df_new_small = df_new.head(1).copy()

In [None]:
# Register tqdm's progress_apply on pandas objects.
tqdm.pandas()

# Run the model over each row (axis=1); the result replaces df_new_small.
df_new_small = df_new_small.progress_apply(LLM_process_row, axis = 1)

100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [02:38<00:00, 158.19s/it]


In [None]:
df_new_small

Unnamed: 0,word,sentence,word_loc,wordnet,definition,definitions,all_defs,sentence_str,sentence_toks,word_start_index,word_tok_index,definition_prompts,definition_toks,sentence_activations,definition_activates
0,long,How|long|has|it|been|since|you|reviewed|the|ob...,1,long%3:00:02::,primarily temporal sense; being or indicating ...,desire strongly or persistently|primarily temp...,"[desire strongly or persistently, primarily te...",How long has it been since you reviewed the o...,"[2650, 1317, 706, 433, 1027, 2533, 499, 22690,...",4,1,"[ long: desire strongly or persistently, long...","[[1317, 25, 12876, 16917, 477, 23135, 4501], [...",[[[[0.0119629 -0.00976562 -0.0106201 ... -0.01...,[[[[[0.000976562 0.0205078 0.0634766 ... -0.01...


In [None]:
df_new_small.iloc[0].sentence_activations[0].shape

(1, 256, 2048)

## Data volume

We are seeing roughly 1 MB of data per layer, so about 16 MB per sentence/definition.

In [None]:
# Count the definition prompts of each row, then total them over the chunk.
df_new["num_defs"] = df_new["definition_toks"].apply(len)
df_new["num_defs"].sum()

907

That gives roughly 1000 sentences + definitions (100 sentences and 907 definitions), so we will be generating about 16 GB of data per chunk, and thus would need ~30 TB of storage to process it all!

Let's cut down the volume. First, let's see what a reasonable max token length actually is.

In [None]:
# Token counts: one int per sentence, and one list of ints per row holding the
# length of each definition prompt.
df_new["number_sentence_toks"] = df_new["sentence_toks"].apply(len)
df_new["number_def_toks"] = df_new["definition_toks"].apply(
    lambda defs_toks: list(map(len, defs_toks))
)

In [None]:
df_new["number_sentence_toks"].describe()

count    100.000000
mean      20.710000
std        8.210974
min        6.000000
25%       13.000000
50%       19.000000
75%       27.000000
max       35.000000
Name: number_sentence_toks, dtype: float64

In [None]:
import numpy as np

# Flatten the per-row lists of definition lengths into one sample.
all_values = []
for per_row_lengths in df_new['number_def_toks']:
    all_values.extend(per_row_lengths)

# Summary statistics of definition-prompt length, in tokens
# ("Third" is the 75th percentile).
third, median = np.percentile(all_values, 75), np.median(all_values)
mean, max_val = np.mean(all_values), np.max(all_values)

# Report them in one dict
stats = {"Third": third, "Median": median, "Mean": mean, "Max": max_val}
print(stats)

{'Third': 13.5, 'Median': 9.0, 'Mean': 10.783902976846747, 'Max': 42}


Limiting the inputs to 32 tokens looks like it would still keep most of the information we need. The definitions can be clipped from the ends; for the sentences we should clip outwards from the word. Luckily we have the word token indices for that.

That will yield an 8x (2^3) reduction in data volume, from 256 to 32 tokens, to roughly 4 TB. However, we will not necessarily process all the data, and 500 MB per chunk seems more manageable for processing.


## Batching

Currently the dataframe-based approach is a bit awkward. It will be easier to extract all the lists of tokens, alongside lookup dicts based on the indices in the dataframe, then convert them to one big input token array, which we can chunk and send through the transformer in batches.

The outputs can then be unpacked - it might be worth adding an output tensor dimension instead of the list approach currently used.