# Runnning Llama on the prompts

Now we can test the full loop, before writing a job manifest, we will:
1. load Llama into memory from our pickled weights
2. normalised the data chunk to get prompts
3. run the LLM on our promps
4. save the hidden activations to disk

5. (optional) implement caching for the definitions, since they may be shared between multiple words


In [1]:
import sys
import os

os.getcwd()
project_path = os.path.abspath("LLM")

if project_path not in sys.path:
    sys.path.append(project_path)

In [2]:
from llama.tokenizer import Tokenizer

tok_path = "/home/matt/.llama/checkpoints/Llama3.2-1B-hf-tok/tokenizer.model"
tok = Tokenizer(tok_path)
tok

<llama.tokenizer.Tokenizer at 0x7f9ee8760770>

In [3]:
import pickle

with open("Data/ModelWeights/llama_jax_weights.pkl", "rb") as f:
    params_jax_loaded = pickle.load(f)

n_heads = 32
n_kv_heads = 8



In [4]:
import pandas as pd

In [5]:
df = pd.read_csv('Data/Processed/SemCoreChunks/chunk_0.csv')

In [None]:
from typing import List
import re

def fix_whitespace(text):
    # Remove spaces before punctuation
    text = re.sub(r'\s+([?.!,;:])', r'\1', text)
    # Ensure space after punctuation if followed by a word (except for some cases like commas within numbers)
    text = re.sub(r'([?.!;:])(?=[^\s])', r'\1 ', text)
    return text


type series = pd.core.series.Series

def prep_row(row: series) -> series:
    row["all_defs"] = str.split(row.definitions, '|')
    row["sentence_str"]  = ' ' + fix_whitespace(' '.join(str.split(row.sentence, '|')))  
    row["sentence_toks"] = tok.encode(row["sentence_str"], bos = False, eos = False)

    remaining = row["sentence_str"]

    word_start_index = 0

    for i in range(row["word_loc"]):
        word = tok.decode([row["sentence_toks"][i]])
        word_start_index += remaining.find(word)
        word_start_index += len(word)
        remaining = row["sentence_str"][word_start_index:]


    row["word_start_index"] = word_start_index

    word_tok_index = 0
    i = 0

    while(True):

        word = tok.decode([row["sentence_toks"][word_tok_index]])
        
        i += len(word)
        
        if i > word_start_index:
            break
        
        word_tok_index  += 1

    row["word_tok_index"] = word_tok_index

    encode = lambda  text : tok.encode(text, bos = False, eos = False)
    if len(encode(" " + row["word"])) == len(encode(" " + row["word"] + ":")):
        return Exception("colon was absorbed - changing token embedding!")
    
    defs = row["definitions"].split('|')

    row["definition_prompts"] = [" " + row["word"] + ": " + d for d in defs]
    row["definition_toks"] = [encode(def_prompt) for def_prompt in row["definition_prompts"]]

    return row

In [7]:
from tqdm import tqdm
tqdm.pandas()

df_new = df.progress_apply(prep_row, axis=1)

100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 125.85it/s]


In [8]:
df_new.head(2)

Unnamed: 0,word,sentence,word_loc,wordnet,definition,definitions,all_defs,sentence_str,sentence_toks,word_start_index,word_tok_index,definition_prompts,definition_toks
0,long,How|long|has|it|been|since|you|reviewed|the|ob...,1,long%3:00:02::,primarily temporal sense; being or indicating ...,desire strongly or persistently|primarily temp...,"[desire strongly or persistently, primarily te...",How long has it been since you reviewed the o...,"[2650, 1317, 706, 433, 1027, 2533, 499, 22690,...",4,1,"[ long: desire strongly or persistently, long...","[[1317, 25, 12876, 16917, 477, 23135, 4501], [..."
1,been,How|long|has|it|been|since|you|reviewed|the|ob...,4,be%2:42:03::,"have the quality of being; (copula, used with ...",a light strong brittle grey toxic bivalent met...,[a light strong brittle grey toxic bivalent me...,How long has it been since you reviewed the o...,"[2650, 1317, 706, 433, 1027, 2533, 499, 22690,...",16,4,[ been: a light strong brittle grey toxic biva...,"[[1027, 25, 264, 3177, 3831, 95749, 20366, 215..."


In [9]:
params_jax_loaded["freqs_cis"].shape

(2048, 32)

In [10]:
import jax
import jax.numpy as jnp

from LLM.llama_jax.model import reporting_transformer


jitted_transformer = jax.jit(reporting_transformer, static_argnames=["n_heads", "n_kv_heads"])


max_tok_len = 256

params_jax_loaded["freqs_cis"] = params_jax_loaded["freqs_cis"][0:max_tok_len, :]

def pad_toks(toks: List[int], max_len = max_tok_len):
    to_pad = max_len - len(toks)

    if (to_pad < 0):
        ret = toks[0:to_pad]
    else:
        ret = toks + [0] * to_pad

    mask = [0 if tok != 0 else -jnp.inf for tok in ret]

    return jnp.array(ret)[None, :], jnp.array(mask)[None, :]


def get_activations(toks: List[int]):
    padded_toks, mask = pad_toks(toks)
    return jitted_transformer(padded_toks, params_jax_loaded, mask, n_heads, n_kv_heads)


def LLM_process_row(row: series) -> series:
    sentence_acts = get_activations(row.sentence_toks)
    def_acts     = [get_activations(def_toks) for def_toks in row.definition_toks]

    # TODO - if the data volume is too large, then only store the activations
    # near the word itself
    row["sentence_activations"] = sentence_acts
    row["definition_activates"] = def_acts

    return row


In [11]:
df_new_small = df_new.head(1).copy()

In [12]:
tqdm.pandas()

#df_new_small = df_new_small.progress_apply(LLM_process_row, axis = 1)

In [13]:
#df_new_small

In [14]:
#df_new_small.iloc[0].sentence_activations[0].shape

## Data volume

We are seeing roughly 1Mb of data per layer, so 16Mb per sentence/definition

In [15]:
df_new["num_defs"] = df_new["definition_toks"].apply(lambda ds : len(ds))
df_new["num_defs"].sum()

907

Then we have 1000 setences + definitions an so will be generating 16Gb of data per chunk, and thus would need ~30Tb of storage to process it all!

Lets cut down the volume, first lets see what is actually a reasonable max token length

In [None]:
df_new["number_sentence_toks"] = df_new["sentence_toks"].apply(lambda toks : len(toks))
df_new["number_def_toks"] = df_new["definition_toks"].apply(lambda defs_toks : [len(def_toks) for def_toks in defs_toks])

In [17]:
df_new["number_sentence_toks"].describe()

count    100.000000
mean      20.710000
std        8.210974
min        6.000000
25%       13.000000
50%       19.000000
75%       27.000000
max       35.000000
Name: number_sentence_toks, dtype: float64

In [18]:
import numpy as np

all_values = [num for sublist in df_new['number_def_toks'] for num in sublist]

# Compute statistics
third = np.percentile(all_values, 75)
median = np.median(all_values)
mean = np.mean(all_values)
max_val = np.max(all_values)

# Print results
stats = {
    "Third" : third,
    "Median": median,
    "Mean": mean,
    "Max": max_val,
}
print(stats)

{'Third': 13.5, 'Median': 9.0, 'Mean': 10.783902976846747, 'Max': 42}


limiting the inputs to 32 tokens looks like it would still keep most of the information we need, the definitions can be clipped from the ends... for the sentences we should clip outwards from the word, luckily we have the word token indices for that

That will yield a 3-fold (8x) reduction in data volume to roughly 4Tb. However we will not necessarily process all the data, and 500Mb per chunk seems more manageable for processing.


In [19]:
def clip_definition_token_list(def_toks: List[List[int]], max_tok_len = 32) -> List[List[int]]:
    ret = []
    for _, lst in enumerate(def_toks):
        if len(lst) <= max_tok_len:
            ret.append(lst)
        else:
            ret.append(lst[0:max_tok_len])
    return ret

def clip_sentence_token_list(row: series, max_tok_len = 32) -> series:

    range = (0,0)
    new_index = row.word_tok_index

    if len(row.sentence_toks) <= max_tok_len:
        range = (0,None)
    elif (row.word_tok_index <= max_tok_len // 2):
        range = (0,max_tok_len)
    elif (len(row.sentence_toks) - row.word_tok_index <= max_tok_len // 2):
        range = (-max_tok_len, None)
        new_index -= (len(row.sentence_toks) - max_tok_len) # this much has been removed
    else:
        delta = max_tok_len // 2 
        mod   = max_tok_len %  2
        mid = row.word_tok_index
        range = (mid - delta, mid + delta + mod)
        new_index -= (mid - delta)

    row["clipped_sentence_toks"] = row.sentence_toks[range[0]:range[1]]
    row["clipped_word_tok_index"] = new_index

    return row
    

In [20]:
# tests a la chatgpt

# 1) Sentence length <= max_tok_len
row = pd.Series({
    "sentence_toks": ["I", "love", "Python"],
    "word_tok_index": 1
})
row = clip_sentence_token_list(row, max_tok_len=5)
assert row["clipped_sentence_toks"] == ["I", "love", "Python"]
assert row["clipped_word_tok_index"] == 1

# 2) Focus token near the beginning
row = pd.Series({
    "sentence_toks": ["Token"] * 20,
    "word_tok_index": 2
})
row = clip_sentence_token_list(row, max_tok_len=6)
# Should keep from index 0 up to index 6, word_tok_index should remain 2
assert len(row["clipped_sentence_toks"]) == 6
assert row["clipped_word_tok_index"] == 2

# 3) Focus token near the end
row = pd.Series({
    "sentence_toks": ["Token"] * 20,
    "word_tok_index": 18
})
row = clip_sentence_token_list(row, max_tok_len=5)
# Should keep last 5 tokens
# Original indices: 15,16,17,18,19
# 'word_tok_index' = 18 => new_index = 18 - 15 = 3
assert len(row["clipped_sentence_toks"]) == 5
assert row["clipped_word_tok_index"] == 3

# 4) Focus token somewhere in the middle
row = pd.Series({
    "sentence_toks": list(range(30)),  # Just use integer tokens for clarity
    "word_tok_index": 15
})
row = clip_sentence_token_list(row, max_tok_len=6)
# max_tok_len=6 => delta=3 => mod=0
# start=15-3=12, end=15+3=18 => slice is [12,13,14,15,16,17]
# new_index = 15 - 12 = 3
assert row["clipped_sentence_toks"] == [12, 13, 14, 15, 16, 17]
assert row["clipped_word_tok_index"] == 3

# 5) Edge case: focus token exactly on boundary (like index=10, max_tok_len=6)
row = pd.Series({
    "sentence_toks": list(range(20)),
    "word_tok_index": 10
})
row = clip_sentence_token_list(row, max_tok_len=6)
# delta=3, mod=0 => slice [7..13) => [7,8,9,10,11,12], length=6
# new_index = 10 - 7 = 3
assert len(row["clipped_sentence_toks"]) == 6
assert row["clipped_word_tok_index"] == 3

print("All tests passed!")


All tests passed!


In [21]:
df_new["clipped_definition_toks"] = df_new["definition_toks"].apply(clip_definition_token_list)
df_new = df_new.apply(clip_sentence_token_list, axis = 1)

In [22]:
df_new[df_new["clipped_word_tok_index"] != df_new["word_tok_index"]].sample(2)

Unnamed: 0,word,sentence,word_loc,wordnet,definition,definitions,all_defs,sentence_str,sentence_toks,word_start_index,word_tok_index,definition_prompts,definition_toks,num_defs,number_sentence_toks,number_def_toks,clipped_definition_toks,clipped_sentence_toks,clipped_word_tok_index
85,productivity,When|improvements|are|recommended|in|working|c...,30,productivity%1:07:00::,the quality of being productive or having the ...,the quality of being productive or having the ...,[the quality of being productive or having the...,When improvements are recommended in working ...,"[3277, 18637, 527, 11349, 304, 3318, 4787, 482...",157,30,[ productivity: the quality of being productiv...,"[[26206, 25, 279, 4367, 315, 1694, 27331, 477,...",2,35,"[13, 23]","[[26206, 25, 279, 4367, 315, 1694, 27331, 477,...","[11349, 304, 3318, 4787, 482, 1778, 439, 18186...",27
82,try,When|improvements|are|recommended|in|working|c...,21,try%2:41:00::,make an effort or attempt,earnest and conscientious activity intended to...,[earnest and conscientious activity intended t...,When improvements are recommended in working ...,"[3277, 18637, 527, 11349, 304, 3318, 4787, 482...",124,21,[ try: earnest and conscientious activity inte...,"[[1456, 25, 55349, 323, 74365, 1245, 5820, 108...",10,35,"[13, 7, 17, 18, 6, 15, 7, 6, 16, 16]","[[1456, 25, 55349, 323, 74365, 1245, 5820, 108...","[11349, 304, 3318, 4787, 482, 1778, 439, 18186...",18


In [23]:
def validate_clipped_token(row: pd.Series) -> pd.Series:
    clipped_index = row["clipped_word_tok_index"]
    original_index = row["word_tok_index"]

    if 0 <= clipped_index < len(row["clipped_sentence_toks"]):
        clipped_token = row["clipped_sentence_toks"][clipped_index]
        original_token = row["sentence_toks"][original_index]
        row["clipped_token_matches"] = (clipped_token == original_token)
    else:
        row["clipped_token_matches"] = False

    return row


In [24]:
df_new = df_new.apply(validate_clipped_token, axis = 1)
df_new[df_new["clipped_token_matches"] == False]

Unnamed: 0,word,sentence,word_loc,wordnet,definition,definitions,all_defs,sentence_str,sentence_toks,word_start_index,word_tok_index,definition_prompts,definition_toks,num_defs,number_sentence_toks,number_def_toks,clipped_definition_toks,clipped_sentence_toks,clipped_word_tok_index,clipped_token_matches


In [25]:
def validate_definition_lengths(def_toks: List[List[int]], max_tok_len = 32) -> bool:
    return all(len(lst) <= max_tok_len for lst in def_toks)


def validate_clipped_sentence_length(row: pd.Series, max_tok_len = 32) -> pd.Series:
    if "clipped_sentence_toks" not in row:
        row["is_clipped_sentence_valid"] = False
    else:
        row["is_clipped_sentence_valid"] = (len(row["clipped_sentence_toks"]) <= max_tok_len)
    return row

In [26]:
df_tmp = df_new.copy()

In [27]:
df_tmp["def_lens_valid"] = df_tmp["clipped_definition_toks"].apply(validate_definition_lengths)
df_tmp = df_tmp.apply(validate_clipped_sentence_length, axis = 1)

In [28]:
print(f"number bad def rows:      {df_tmp[df_tmp["def_lens_valid"] == False].shape[0]}")
print(f"number bad sentence rows: {df_tmp[df_tmp["is_clipped_sentence_valid"] == False].shape[0]}")

number bad def rows:      0
number bad sentence rows: 0


Now we have the clipped sentences/definitions, with a clip level set for minimal loss of contextualising semantic information

### Further data reduction volume

Do we need to save all the tokens' hidden activations? probably not - only the ones near the word of interest are likely to be helpful.

For the definition, these will be the [0,n] tokens

For the sentence, these will be the [w - n, w + n] tokens, since the word may appear anywhere, we should pad this if the word is near the start/end of the sentence.

A sensible choice of n would be e.g. 5, so 10 activations per sentence, and 5 per definitions, roughly a 4x data reduction, so only 1Tb of data, or 125Mb per chunk.

We still choose to reduce the sentence length as well, since that means we can batch more sentences/definitions per pass of the transformer,

This data reduction is probably best saved for post-processing before saving, since it will require access to the word token indices.

## Batching

Currently the dataframe based approach is a bit awkward, it will be easier to extract all the lists of tokens, alongside lookup dicts based on the indices in the dataframe, then convert them to one big input tokens array, which we can chunk and send through the transformer in batches.

This will also be a good location to inject the cacheing logic, and short circuit the call to the LLM.

The outputs can then be unpacked - it might be worth adding an output tensor dimension instead of the list approach currently used.

Finally we should have a method to extract the indices of the activations that we actually care about.

### Costly operations

Previously we saw that calling the LLM was very costly in terms of time, however now with the 32 token max sentence limit - the bottleneck is now working with the large arrays of activations that are output.

We should keep indexing information for how to select the regions of these activations we are interested in, and keep that using Jax, to minimise the data volume that ends up being parsed into a pandas array

In [29]:
df_new.head(1)

Unnamed: 0,word,sentence,word_loc,wordnet,definition,definitions,all_defs,sentence_str,sentence_toks,word_start_index,word_tok_index,definition_prompts,definition_toks,num_defs,number_sentence_toks,number_def_toks,clipped_definition_toks,clipped_sentence_toks,clipped_word_tok_index,clipped_token_matches
0,long,How|long|has|it|been|since|you|reviewed|the|ob...,1,long%3:00:02::,primarily temporal sense; being or indicating ...,desire strongly or persistently|primarily temp...,"[desire strongly or persistently, primarily te...",How long has it been since you reviewed the o...,"[2650, 1317, 706, 433, 1027, 2533, 499, 22690,...",4,1,"[ long: desire strongly or persistently, long...","[[1317, 25, 12876, 16917, 477, 23135, 4501], [...",12,17,"[7, 26, 19, 11, 5, 13, 14, 5, 9, 11, 11, 6]","[[1317, 25, 12876, 16917, 477, 23135, 4501], [...","[2650, 1317, 706, 433, 1027, 2533, 499, 22690,...",1,True


In [30]:
import hashlib

def hash_tokens(tokens: List[int], index: int):
    token_str = ','.join(map(str, tokens)) + str(index)
    return hashlib.sha256(token_str.encode()).hexdigest()



def transform_dataframe(df, activation_width = 4):
    records = []

    for idx, row in tqdm(df.iterrows()):

        # process sentence toks
        sent_toks = row['clipped_sentence_toks']
        tok_index = row['clipped_word_tok_index']
        records.append((idx, sent_toks, 'sentence', -1, hash_tokens(sent_toks, tok_index), tok_index, tok_index-activation_width))
        
        # process definitions' toks
        for def_idx, def_toks in enumerate(row['clipped_definition_toks']):
            records.append((idx, def_toks, 'definition', def_idx, hash_tokens(def_toks, 0), 0, 0-activation_width)) # definitions are always the first token
    new_df = pd.DataFrame(records, columns=['map_index', 'toks', 'column', 'def_index', 'hash', 'clipped_word_tok_index', 'istart'])
    return new_df

In [31]:
toks_df = transform_dataframe(df_new, 4)
toks_df.head()

100it [00:00, 1452.02it/s]


Unnamed: 0,map_index,toks,column,def_index,hash,clipped_word_tok_index,istart
0,0,"[2650, 1317, 706, 433, 1027, 2533, 499, 22690,...",sentence,-1,e5628f2b996830cd285a609e72bf7fbce8666e0b4f7d49...,1,-3
1,0,"[1317, 25, 12876, 16917, 477, 23135, 4501]",definition,0,e04ce50c948bb2e03144f8b8ada93591c11c89d6c7d355...,0,-4
2,0,"[1317, 25, 15871, 37015, 5647, 26, 1694, 477, ...",definition,1,5e724cd29a7137fb48b65d0719ffe65f5032fdde33f6a3...,0,-4
3,0,"[1317, 25, 15871, 29079, 5647, 26, 315, 12309,...",definition,2,e71bf88afe0c699b6990b2170503455216ad12ec401880...,0,-4
4,0,"[1317, 25, 315, 12309, 2294, 2673, 26, 482, 17...",definition,3,bbcf55681897f86d0d249055d0a34d53e2957bb154050b...,0,-4


In [32]:
print(f"old shape {df_new.shape}")
print(f"new shape {toks_df.shape}")

old shape (100, 20)
new shape (1007, 7)


We see a 10-fold increase in data volume after putting each list of tokens in its own dataframe row

We use negative/off the end indexing here, because after we have the activations we will zero pad +-width to deal with words near the start/end

In [33]:
print(f"original length was {toks_df.shape[0]}")

unique_toks_df = toks_df.drop_duplicates(subset=['hash']).copy().reset_index(drop = True)

print(f"unique token lists: {unique_toks_df.shape[0]}")


original length was 1007
unique token lists: 830


In [34]:
unique_toks_df["len"] = unique_toks_df['toks'].apply(len)

In [35]:
unique_toks_df[["padded_toks", "mask"]] = unique_toks_df["toks"].progress_apply(lambda lst: pad_toks(lst, 32)).apply(pd.Series)
# padding at the end so our indices are still valid

  0%|                                                                                           | 0/830 [00:00<?, ?it/s]

100%|████████████████████████████████████████████████████████████████████████████████| 830/830 [00:04<00:00, 178.52it/s]


In [36]:
unique_toks_df["tok_shape"] = unique_toks_df['padded_toks'].apply(lambda arr : arr.shape)
unique_toks_df["mask_shape"] = unique_toks_df['mask'].apply(lambda arr : arr.shape)

In [84]:
batch_size = 64

unique_toks_df['batch_id'] = unique_toks_df.index // batch_size

Now we have the tokens ready to be passed through the LLM, we can stack them, run the transformer, then unpack the results back into the rows, lets develop that function now

In [85]:
batch_df = unique_toks_df[unique_toks_df['batch_id'] == 0].copy()

In [86]:
jnp.concatenate(batch_df['padded_toks'], axis = 0).shape

(64, 32)

In [87]:
batch_toks = jnp.concatenate(batch_df['padded_toks'], axis = 0)
batch_masks = jnp.concatenate(batch_df['mask'], axis = 0)

print(f"{batch_toks.shape=}")
print(f"{batch_masks.shape=}")

batch_toks.shape=(64, 32)
batch_masks.shape=(64, 32)


In [41]:
params_jax_loaded["freqs_cis"] = params_jax_loaded["freqs_cis"][0:32, :] # now need even fewer freq cis

In [42]:
batch_activations = jitted_transformer(batch_toks, params_jax_loaded, batch_masks, n_heads, n_kv_heads)

In [43]:
batch_activations.shape

(16, 16, 32, 2048)

In [44]:
batch_df['istart']

0    -3
1    -4
2    -4
3    -4
4    -4
5    -4
6    -4
7    -4
8    -4
9    -4
10   -4
11   -4
12   -4
13    0
14   -4
15   -4
Name: istart, dtype: int64

In [45]:
def _slice_activations_body(activations: jax.Array, indices: jax.Array, width: int):


    batch, layer, toks, hidden = activations.shape

    # pad activations to handle edge cases
    padded_activations = jnp.pad(activations, ((0, 0), (0, 0), (width, width), (0, 0)), mode='constant')

    # adjust indices to account for padding
    adjusted_indices = jnp.ravel(indices + width)  # Shape: (batch,)
    
    # Define function for dynamic slicing
    def extract_slice(activs, start):
        slice_shape = (layer, 9, hidden)  # Static slice shape (layer, 9, hidden)
        return jax.lax.dynamic_slice(activs, (0, start, 0), slice_shape)

    # vectorize over batch dimension
    vmap_extract = jax.vmap(extract_slice, in_axes=(0, 0), out_axes=0)

    result = vmap_extract(padded_activations, adjusted_indices)

    return result


_slice_activations_updated = jax.jit(_slice_activations_body, static_argnames=['width'])

def slice_activations(activations: jax.Array, indices: pd.DataFrame, width: int = 4):
    jax_indices = jnp.array(indices.to_numpy())  
    return _slice_activations_updated(activations, jax_indices, width)


In [46]:
sliced_activations = slice_activations(batch_activations, batch_df[['istart']], width=4)
sliced_activations.shape  # Expected output: (16, 16, 9, 2048)


(16, 16, 9, 2048)

In [47]:
batch_activations.shape

(16, 16, 32, 2048)

In [48]:
batch_activations[0,0,0,:]

Array([0.0119629, -0.00976562, -0.0106201, ..., -0.0142822, -0.0361328,
       -6.10352e-05], dtype=bfloat16)

In [49]:
sliced_activations[0,0,3,:]

Array([0.0119629, -0.00976562, -0.0106201, ..., -0.0142822, -0.0361328,
       -6.10352e-05], dtype=bfloat16)

we see this lines up with the reindexing we anticipated

In [50]:
# batch_outputs = [batch_activations[i] for i in tqdm(range(batch_size))]

batch_df["model_output"] = list(sliced_activations)

In [51]:
batch_df.iloc[0]

map_index                                                                 0
toks                      [2650, 1317, 706, 433, 1027, 2533, 499, 22690,...
column                                                             sentence
def_index                                                                -1
hash                      e5628f2b996830cd285a609e72bf7fbce8666e0b4f7d49...
clipped_word_tok_index                                                    1
istart                                                                   -3
len                                                                      17
padded_toks               [[2650, 1317, 706, 433, 1027, 2533, 499, 22690...
mask                      [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...
tok_shape                                                           (1, 32)
mask_shape                                                          (1, 32)
batch_id                                                                  0
model_output

In [52]:
batch_df.iloc[0]["model_output"]

Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0.0319824, 0.050293, -0.0712891, ..., -0.0116577, -0.0292969,
         0.0055542],
        [-0.0512695, -0.0184326, -0.0620117, ..., 0.00732422,
         -0.0088501, 0.000671387],
        [0.0390625, 0.0339355, -0.0805664, ..., 0.0561523, -0.0800781,
         0.0057373]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0.0400391, 0.0600586, -0.124023, ..., 0.0108643, -0.0203857,
         0.0664062],
        [-0.0600586, -0.0187988, -0.0493164, ..., 0.048584, -0.0264893,
         -0.0114746],
        [0.081543, 0.0688477, -0.0986328, ..., 0.0205078, -0.128906,
         0.0397949]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [-0.0402832, 0.103516, -0.125977, ..., -0.0463867, -0.120117,
         0.0263672],
        [-0.0

In [53]:
batch_df['model_output'][0].shape

(16, 9, 2048)

## Progress

We now have the activations [word-4,word+4] with the word centred, all that remains is to do this for all the batches, map the data back to the original dataframe and write it to disk.

We use a filestore for the jax arrays as they are better saved in a terse format (pickled)

In [88]:
batch_dfs = [group.reset_index(drop=True) for _, group in unique_toks_df.groupby('batch_id')]

In [89]:
batch_dfs[1].head(1)

Unnamed: 0,map_index,toks,column,def_index,hash,clipped_word_tok_index,istart,len,padded_toks,mask,tok_shape,mask_shape,batch_id
0,5,"[2532, 25, 12152, 7061, 320, 16381, 304, 9635,...",definition,6,e1f23e6cd017b0bd2624612740dd42d83ce37623012542...,0,-4,25,"[[2532, 25, 12152, 7061, 320, 16381, 304, 9635...","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","(1, 32)","(1, 32)",1


In [90]:
def process_batch(batch_df: pd.DataFrame) -> pd.DataFrame:
    batch_toks = jnp.concatenate(batch_df['padded_toks'], axis = 0)
    batch_masks = jnp.concatenate(batch_df['mask'], axis = 0)

    batch_activations = jitted_transformer(batch_toks, params_jax_loaded, batch_masks, n_heads, n_kv_heads)
    sliced_activations = slice_activations(batch_activations, batch_df[['istart']], width=4)

    batch_df["model_output"] = list(sliced_activations)

    return batch_df

In [91]:
len(batch_dfs)

13

In [92]:
processed_batch_dfs = [process_batch(batch_df) for batch_df in tqdm(batch_dfs)]

100%|███████████████████████████████████████████████████████████████████████████████████| 13/13 [13:52<00:00, 64.05s/it]


In [106]:
processed_batch_dfs[0].iloc[0]["model_output"]

Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0.0319824, 0.050293, -0.0712891, ..., -0.0116577, -0.0292969,
         0.0055542],
        [-0.0512695, -0.0184326, -0.0620117, ..., 0.00732422,
         -0.0088501, 0.000671387],
        [0.0390625, 0.0339355, -0.0805664, ..., 0.0561523, -0.0800781,
         0.0057373]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0.0400391, 0.0600586, -0.124023, ..., 0.0108643, -0.0203857,
         0.0664062],
        [-0.0600586, -0.0187988, -0.0493164, ..., 0.048584, -0.0264893,
         -0.0114746],
        [0.081543, 0.0688477, -0.0986328, ..., 0.0205078, -0.128906,
         0.0397949]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [-0.0402832, 0.103516, -0.125977, ..., -0.0463867, -0.120117,
         0.0263672],
        [-0.0

In [140]:
import uuid  # Generates unique file keys
import csv
import shutil
import zipfile

JAX_STORE_PATH = "Data/Processed/jax_store/chunk_0"  # Directory to store pickled JAX arrays

os.makedirs(JAX_STORE_PATH, exist_ok=True)  # Ensure directory exists

def save_jax_array(array):
    """Save a JAX array to a pickle file and return a unique lookup key."""
    unique_id = str(uuid.uuid4())  # Generate a unique key (v low collision chance)
    file_path = os.path.join(JAX_STORE_PATH, f"{unique_id}.pkl")

    with open(file_path, "wb") as f:
        pickle.dump(array, f)

    return unique_id  # Return lookup key to store in CSV

def load_jax_array(key):
    """Load a JAX array from a pickle file using the given lookup key."""
    file_path = os.path.join(JAX_STORE_PATH, f"{key}.pkl")

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"JAX array file not found: {file_path}")

    with open(file_path, "rb") as f:
        return pickle.load(f)  # Load the JAX array

def write_row(row, writer, parent):
    """Writes a single row to the CSV file using the provided CSV writer."""
    original = parent.iloc[row["map_index"]]

    # Save JAX array and store lookup key
    model_output_key = save_jax_array(row["model_output"])

    to_write = {
        "def_or_sentence": row["column"],
        "def_index": row["def_index"],
        "model_output_key": model_output_key,  # Store only the lookup key
        "word": original["word"],
        "sentence": original["sentence"],
        "word_loc": original["word_loc"],
        "wordnet": original["wordnet"],
        "definition": original["definition"],
        "definitions": original["definitions"]
    }

    writer.writerow(to_write)

def write_results_to_disk(processed_batch_dfs, filePath: str, parent_df: pd.DataFrame):
    """Writes rows from a list of processed batch DataFrames, storing JAX arrays separately."""
    
    # Ensure directory exists
    os.makedirs(os.path.dirname(filePath), exist_ok=True)

    file_exists = os.path.isfile(filePath)

    with open(filePath, 'a', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=[
            "def_or_sentence", "def_index", "model_output_key", "word", 
            "sentence", "word_loc", "wordnet", "definition", "definitions"
        ])

        if not file_exists:
            writer.writeheader()  # Write header only if file doesn't exist

        for processed_batch_df in tqdm(processed_batch_dfs, desc="Processing Batches"):
            for _, row in tqdm(processed_batch_df.iterrows(), total=processed_batch_df.shape[0], desc="Writing Rows"):
                write_row(row, writer, parent_df)

        f.flush()  # Ensure data is written immediately

In [141]:
to_write= 'Data/Processed/ProcessedSemCoreChunks/chunk_0.csv'

In [142]:
write_results_to_disk(processed_batch_dfs, to_write, df_new)

Processing Batches:   0%|                                                                        | 0/13 [00:00<?, ?it/s]

Writing Rows: 100%|█████████████████████████████████████████████████████████████████████| 64/64 [00:02<00:00, 30.32it/s]
Writing Rows: 100%|█████████████████████████████████████████████████████████████████████| 64/64 [00:01<00:00, 54.31it/s]
Writing Rows: 100%|█████████████████████████████████████████████████████████████████████| 64/64 [00:02<00:00, 29.22it/s]
Writing Rows: 100%|█████████████████████████████████████████████████████████████████████| 64/64 [00:02<00:00, 22.62it/s]
Writing Rows: 100%|█████████████████████████████████████████████████████████████████████| 64/64 [00:02<00:00, 26.66it/s]
Writing Rows: 100%|█████████████████████████████████████████████████████████████████████| 64/64 [00:01<00:00, 46.35it/s]
Writing Rows: 100%|█████████████████████████████████████████████████████████████████████| 64/64 [00:01<00:00, 49.03it/s]
Writing Rows: 100%|█████████████████████████████████████████████████████████████████████| 64/64 [00:01<00:00, 41.66it/s]
Writing Rows: 100%|█████████████

In [144]:
ZIP_STORE_PATH = "Data/Processed/jax_store/chunk_0.zip"

def zip_and_clean_jax_store():
    """Compress the JAX store directory into a zip archive and delete original files."""
    shutil.make_archive(JAX_STORE_PATH.rstrip("/"), 'zip', JAX_STORE_PATH)  # Create ZIP

    # Delete all files inside the jax_store/ directory
    for filename in tqdm(os.listdir(JAX_STORE_PATH)):
        file_path = os.path.join(JAX_STORE_PATH, filename)
        if os.path.isfile(file_path):
            os.remove(file_path)  # Delete file

    if os.path.exists(JAX_STORE_PATH) and not os.listdir(JAX_STORE_PATH):
        os.rmdir(JAX_STORE_PATH)  # Remove empty folder

    print(f"✅ JAX store compressed to {ZIP_STORE_PATH} and cleaned up.")


def unzip_jax_store():
    """Extracts the JAX store archive."""
    with zipfile.ZipFile(ZIP_STORE_PATH, 'r') as zip_ref:
        zip_ref.extractall(JAX_STORE_PATH)
    print(f"✅ JAX store extracted to {JAX_STORE_PATH}")

In [145]:
zip_and_clean_jax_store()

100%|████████████████████████████████████████████████████████████████████████████████| 830/830 [00:07<00:00, 107.85it/s]

✅ JAX store compressed to Data/Processed/jax_store/chunk_0.zip and cleaned up.



