In [32]:
import wikipediaapi
from dotenv import load_dotenv
import os
# Populate os.environ from a local .env file; fetchWiki reads its user agent from the
# `wikipedia_agent` environment variable.
load_dotenv()

# Wikipedia article setup
# NOTE(review): `webpage` is not referenced in the visible cells — remove it if unused.
webpage = "Python (programming language)"  # wikipedia article title

def fetchWiki(term, wiki_lang="en"):
    wiki = wikipediaapi.Wikipedia(user_agent=os.getenv('wikipedia_agent'), language=wiki_lang)
    page = wiki.page(term)
    if page.exists():
        return page.summary, page.text
    else:
        print(f"{term} could not be found")
        return "", ""

# Build the training corpus: one Wikipedia article per non-blank line of input.txt.
# Both files are opened as UTF-8 explicitly. Article text is full of non-ASCII
# characters and output.txt is later read back as UTF-8, so relying on the
# platform-default encoding (e.g. cp1252 on Windows) could raise UnicodeEncodeError
# or produce a file that does not decode correctly.
with open('input.txt', 'r', encoding='utf-8') as f, open("output.txt", "w", encoding='utf-8') as w:
    for line in f:
        line = line.strip()
        if not line:
            continue  # blank line: nothing to look up
        summary, text = fetchWiki(line)
        if not (summary or text):
            continue  # article missing: don't pad the corpus with empty separators
        # NOTE(review): wikipediaapi's page.text appears to already begin with
        # page.summary, so the summary may be written twice — confirm this is intended.
        w.write(f"{summary}\n\n{text}\n\n\n")

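Tokenization theory (byte-pair encoding): repeatedly scan the sequence, find the most frequent pair of adjacent symbols, and replace every occurrence of that pair with a new symbol. Repeat until the desired vocabulary size is reached.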

Example: aaabdaaabac -> ZabdZabac (Z = aa) -> ZYdZYac (Y = ab) -> XdXac (X = ZY)

In [20]:
# Load the scraped corpus and view it as a flat list of UTF-8 byte values (0-255);
# iterating a bytes object already yields ints, so no explicit conversion is needed.
with open('output.txt', 'r', encoding='utf-8') as f:
    text = f.read()
tokens = list(text.encode("utf-8"))
print(len(tokens))

313759


In [21]:
def get_stats(ids):
    """Count adjacent token pairs.

    Returns a dict mapping each (ids[k], ids[k+1]) pair to its number of occurrences.
    Keys appear in order of first occurrence, which decides ties when the caller takes
    max()/min() over the dict. Sequences shorter than two give {}.
    """
    pair_counts = {}
    for pos in range(len(ids) - 1):
        key = (ids[pos], ids[pos + 1])
        if key in pair_counts:
            pair_counts[key] += 1
        else:
            pair_counts[key] = 1
    return pair_counts

def merge(ids, pair, idx):
    """Replace occurrences of an adjacent token pair with a single new token.

    Scans `ids` left to right. Wherever ids[k], ids[k+1] equal pair[0], pair[1], both
    are replaced by `idx` and scanning resumes after the pair, so matches never
    overlap (merging (1, 1) in [1, 1, 1] gives [idx, 1]).

    Args:
        ids: sequence of token ids (not modified).
        pair: the (left, right) pair to replace.
        idx: token id substituted for each occurrence.

    Returns:
        A new list of token ids.
    """
    merged = []
    last = len(ids) - 1
    pos = 0
    while pos <= last:
        pair_here = pos < last and ids[pos] == pair[0] and ids[pos + 1] == pair[1]
        if pair_here:
            merged.append(idx)
            pos += 2
        else:
            merged.append(ids[pos])
            pos += 1
    return merged

vocab_size = 1000                # target vocabulary: 256 raw byte values + learned merges
num_merges = vocab_size - 256
ids = list(tokens)               # copy so `tokens` keeps the raw byte sequence
merge_interval = 30              # print progress every `merge_interval` merges

# merges maps (left_id, right_id) -> new token id, in the order the merges were learned.
merges = {}
for i in range(num_merges):
    stats = get_stats(ids)
    if not stats:
        # Fewer than two tokens remain (e.g. a tiny or empty corpus), so there is
        # nothing left to merge; max() on an empty dict would raise ValueError.
        print(f"stopping early after {i} merges: no pairs left")
        break
    pair = max(stats, key=stats.get)   # most frequent adjacent pair; first seen wins ties
    idx = 256 + i
    if i % merge_interval == 0:
        print(f"{i} merging {pair} to {idx}")
    ids = merge(ids, pair, idx)
    merges[pair] = idx

print(len(ids))

0 merging (101, 32) to 256
30 merging (97, 32) to 286
60 merging (46, 10) to 316
90 merging (110, 32) to 346
120 merging (118, 263) to 376
150 merging (102, 287) to 406
180 merging (257, 284) to 436
210 merging (417, 271) to 466
240 merging (287, 32) to 496
270 merging (278, 65) to 526
300 merging (303, 271) to 556
330 merging (108, 439) to 586
360 merging (266, 362) to 616
390 merging (285, 108) to 646
420 merging (257, 331) to 676
450 merging (501, 308) to 706
480 merging (388, 32) to 736
510 merging (348, 299) to 766
540 merging (119, 267) to 796
570 merging (607, 299) to 826
600 merging (73, 346) to 856
630 merging (112, 389) to 886
660 merging (358, 509) to 916
690 merging (100, 279) to 946
720 merging (102, 733) to 976
112076


In [22]:
print("{:.2f}".format(len(tokens) / len(ids)))  # compression ratio: raw bytes per BPE token

2.80


In [23]:
# Map every token id to the bytes it stands for. Ids 0-255 are single raw bytes, and each
# merged id is the concatenation of its two parents. Iterating `merges` in insertion
# (learning) order guarantees both parents are already present in `vocab`.
vocab = dict(enumerate(bytes([b]) for b in range(256)))
for (p0, p1), idx in merges.items():
    vocab[idx] = b"".join((vocab[p0], vocab[p1]))

In [24]:
def encode(text):
    """Encode a string into token ids using the learned `merges`.

    The text is first turned into its UTF-8 bytes. Then, while at least two tokens
    remain, the adjacent pair with the lowest merge id (the one learned earliest) is
    merged everywhere. Encoding stops as soon as no remaining pair appears in `merges`.
    """
    ids = list(text.encode("utf-8"))
    while len(ids) >= 2:
        # Pairs never learned get rank +inf, so they win only when nothing is mergeable.
        candidate = min(get_stats(ids), key=lambda p: merges.get(p, float("inf")))
        if candidate not in merges:
            break
        ids = merge(ids, candidate, merges[candidate])
    return ids

In [25]:
def decode(ids):
    """Convert token ids back into a string.

    Each id is looked up in the global `vocab` (unknown ids raise KeyError), and the
    concatenated bytes are decoded as UTF-8. Invalid sequences, such as a lone token
    holding only part of a multi-byte character, become U+FFFD instead of raising.
    """
    raw = b"".join([vocab[token_id] for token_id in ids])
    return raw.decode("utf-8", errors="replace")


In [26]:
# Show how the trained tokenizer splits a short word, one token per line.
for i in encode("Kim"):
    piece = decode([i])
    print(f"'{piece}'")

'K'
'im'


In [27]:
import regex as re
# GPT-2's pre-tokenization split pattern, compiled here with re.IGNORECASE. GPT-2's own
# encoder compiles it without flags, so its contraction alternatives ('s, 't, 've, ...)
# match lowercase only; GPT-2 would split "'VE" into "'" and "VE". Note that `re` here
# is the third-party `regex` module, which is required for the \p{L} / \p{N} classes.
gpt2 = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", re.IGNORECASE)

this = gpt2.findall(text)
print(len(this))
print(gpt2.findall("Hello World wouldn't'VE"))

58676
['Hello', ' World', ' wouldn', "'t", "'VE"]


In [28]:
import tiktoken

# Load the cl100k_base encoding and report its highest token id.
enc = tiktoken.get_encoding(encoding_name="cl100k_base")
highest_token_id = enc.max_token_value
print(highest_token_id)

100276


In [29]:
# Decode the highest token id instead of hard-coding it, so this stays correct for any
# encoding loaded into `enc`. For cl100k_base it is a special token, not ordinary text.
print(enc.decode([enc.max_token_value]))

<|endofprompt|>
