On 2023/01/26, we discussed that it would be more interesting to analyze **conditional probabilities** (as opposed to unconditional probabilities), since language is largely contextual. To this end, we agreed to:

- Run the analysis and comparisons in terms of tokens (as opposed to text).
- Run a **bi-gram** analysis of the words that co-occur most often at the beginning of the documents.
- Use the first token of the bigram to condition the LM and observe the conditional probability of the second token given the first: $P(N_{w_2}(K) > 0 \mid w_1)$
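
As a rough illustration of the data-side bigram counting, the sketch below estimates $P(w_2 \mid w_1)$ from adjacent token pairs in the first `k` tokens of each document. The function name `bigram_conditionals` and the toy token ids are assumptions made for this example. Note that this is the simpler adjacent-pair estimate, not the window-based event $N_{w_2}(K) > 0$ written above.

```python
from collections import Counter
from typing import Dict, Iterable, List, Tuple


def bigram_conditionals(
    docs: Iterable[List[int]], k: int
) -> Dict[Tuple[int, int], float]:
    """Estimate P(w2 | w1) from adjacent token pairs in the first k tokens."""
    bigram_counts: Counter = Counter()
    first_counts: Counter = Counter()
    for token_ids in docs:
        prefix = token_ids[:k]
        # Pair each token with its immediate successor inside the prefix.
        for w1, w2 in zip(prefix, prefix[1:]):
            bigram_counts[(w1, w2)] += 1
            first_counts[w1] += 1
    return {
        (w1, w2): count / first_counts[w1]
        for (w1, w2), count in bigram_counts.items()
    }


# Toy token ids standing in for real tokenizer output (hypothetical data).
docs = [[1, 2, 3, 4], [1, 2, 5], [1, 3, 2]]
probs = bigram_conditionals(docs, k=3)
print(probs[(1, 2)])
```

The most frequent bigrams from `bigram_counts` would then supply the $(w_1, w_2)$ pairs used to condition the LM.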

Comparing the text and the set of tokens in terms of token indices mitigates tokenization issues that arise in models like GPT2 or GPT-Neo, where **"data"** and **" data"** are represented differently. Note, however, that our estimates do not account for the probability mass assigned to the character-level token sequence **"d" "a" "t" "a"**. We plan to revisit this issue later if we find a large gap between the model and data distributions.


One pseudo-algorithm for computing the token frequencies is:

```
Pseudocode: Compute the token frequencies of the first 100 tokens across all documents in the provided dataset.
input: Tokenizer (T), docs (D)
output: frequencies<int, int>

frequencies<int, int> = dict()

for d in D:
  tokenized_doc = T.tokenize(d)
  tokenized_doc = tokenized_doc.slice(100)  // get first 100 

  frequencies.update_counts(tokenized_doc)
end

return frequencies
```
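
The pseudocode above could be written in Python roughly as follows. This is a minimal sketch: `count_prefix_tokens` and the whitespace-based `toy_tokenize` are hypothetical stand-ins introduced here, and in the notebook the tokenizer would be the GPT-Neo tokenizer instead.

```python
from collections import Counter
from typing import Callable, Dict, Iterable, List


def count_prefix_tokens(
    tokenize: Callable[[str], List[int]],
    docs: Iterable[str],
    num_tokens: int = 100,
) -> Counter:
    """Count token ids over the first `num_tokens` tokens of every document."""
    frequencies: Counter = Counter()
    for doc in docs:
        # Keep only the document prefix before updating the counts.
        frequencies.update(tokenize(doc)[:num_tokens])
    return frequencies


# Hypothetical tokenizer: assigns an integer id to each whitespace-separated word.
vocab: Dict[str, int] = {}


def toy_tokenize(text: str) -> List[int]:
    return [vocab.setdefault(word, len(vocab)) for word in text.split()]


freqs = count_prefix_tokens(
    toy_tokenize, ["the data is the data", "the model"], num_tokens=3
)
print(freqs.most_common(1))
```

A `Counter` keeps the update step to a single call and can be merged across workers with `+` if the counting is later parallelized.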

~~Implementation-wise, if we obtain a list of all possible document names, we can parallelize the term counts.~~ However, chances are we cannot use the "_id" field, since it is a private field and we have no permission to change the index mapping. Let us first test the method on a couple of documents; if it proves to be fast enough, we will not parallelize it.

In [1]:
K_TOKENS = 10
INDEX = "re_pile"

# Default keyword arguments for Elasticsearch queries
default_kwargs = {
    "index": INDEX,
    "track_total_hits": True,
}

# Load the Elasticsearch client from the YAML config
from es_utils import load
from compute_unigrams import read_yaml_config
configs = read_yaml_config("./configs/elastic-search.yml")
es = load(**configs)

In [5]:
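from functools import partial


def get_tokenization_function(tokenizer, num_tokens: int):
    """Return a batch tokenizer that truncates or pads each text to `num_tokens` ids."""
    # GPT-2 style tokenizers ship without a pad token, so fall back to EOS for padding.
    tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token

    return partial(
        tokenizer.batch_encode_plus,
        max_length=num_tokens,
        truncation=True,
        padding="max_length",
        add_special_tokens=False,
    )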

In [7]:
num_tokens = 15
model_name = "EleutherAI/gpt-neo-125M"

from transformers import GPT2Tokenizer
tokenizer_slow = GPT2Tokenizer.from_pretrained(model_name)
tokenize_slow = get_tokenization_function(tokenizer_slow, num_tokens)
print("Slow tokenizer:", tokenizer_slow.vocab_size)


from transformers import GPT2TokenizerFast
tokenizer_fast = GPT2TokenizerFast.from_pretrained(model_name)
tokenize_fast = get_tokenization_function(tokenizer_fast, num_tokens)
print("Fast tokenizer:", tokenizer_fast.vocab_size)

Using pad_token, but it is not set yet.


Slow tokenizer: 50257
Hello   world! [0, 5145]


Using pad_token, but it is not set yet.


Fast tokenizer: 50257
Hello   world! [0, 5145]


In [17]:
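import time as t
from typing import Callable

import es_utils


def test_scroll(tokenize: Callable, query: dict, total_iters=20):
    """Time `tokenize` on `total_iters` scroll batches; return the seconds spent per batch."""
    duration = []

    # `size=10_000` is assumed to be the number of documents per scroll batch.
    data = iter(es_utils.scroll(es, query, size=10_000, scroll="10m", index="re_pile"))

    for i in range(total_iters):
        docs = next(data)

        # Time only the tokenization step, not the Elasticsearch fetch.
        start = t.time()
        tokenize([es_utils.get_text(d) for d in docs])
        duration.append(t.time() - start)
        print(duration)

    return duration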

Total documents found {'value': 211036967, 'relation': 'eq'}
[36.882235050201416]
[36.882235050201416, 36.09491682052612]
[36.882235050201416, 36.09491682052612, 33.985692739486694]
[36.882235050201416, 36.09491682052612, 33.985692739486694, 38.33827185630798]
[36.882235050201416, 36.09491682052612, 33.985692739486694, 38.33827185630798, 38.1004114151001]
[36.882235050201416, 36.09491682052612, 33.985692739486694, 38.33827185630798, 38.1004114151001, 42.5689332485199]
[36.882235050201416, 36.09491682052612, 33.985692739486694, 38.33827185630798, 38.1004114151001, 42.5689332485199, 39.27310395240784]
[36.882235050201416, 36.09491682052612, 33.985692739486694, 38.33827185630798, 38.1004114151001, 42.5689332485199, 39.27310395240784, 31.29170274734497]
[36.882235050201416, 36.09491682052612, 33.985692739486694, 38.33827185630798, 38.1004114151001, 42.5689332485199, 39.27310395240784, 31.29170274734497, 39.38059687614441]
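
To judge whether this is fast enough to skip parallelization, the timings above can be extrapolated to the whole index. The sketch below assumes that each batch holds 10,000 documents (the `size` passed to the scroll), that the per-batch time stays close to the nine measurements, and that everything runs in a single process. It only covers tokenization time, since the fetch is not timed.

```python
import math
from statistics import mean

# Per-batch tokenization times (seconds) copied from the output above.
durations = [
    36.882235050201416, 36.09491682052612, 33.985692739486694,
    38.33827185630798, 38.1004114151001, 42.5689332485199,
    39.27310395240784, 31.29170274734497, 39.38059687614441,
]
total_docs = 211_036_967  # "Total documents found" reported above
batch_size = 10_000       # assumed documents per scroll batch

num_batches = math.ceil(total_docs / batch_size)
est_seconds = num_batches * mean(durations)
est_days = est_seconds / 86_400

print(f"{num_batches} batches, about {est_days:.1f} days of tokenization")
```

Under these assumptions the sequential run would take on the order of nine days, which suggests that parallelizing or reducing the work per batch is worth considering.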
