# Assignment 1: Basics

In [3]:
import os

## 1 - Assignment Overview

## 2 - Byte-Pair Encoding Tokenizer

### 2.1 - The Unicode standard

In [90]:
chr0 = chr(0)

In [91]:
print(f'Direct print: "{chr0}"; representation: "{chr0.__repr__()}"')

Direct print: " "; representation: "'\x00'"


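(a) `chr(0)` returns the NULL character (U+0000).

(b) Printing it shows nothing visible, but its `repr` is `'\x00'`: a string containing the hex escape sequence for code point 0, i.e. U+0000 in Unicode.

(c) When the character occurs inside text, it is still part of the string: printing shows nothing visible at that position, but the representation still contains `\x00`.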
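
To make the answers above concrete, here is a small check. It is only a sketch, and the variable names are illustrative. It shows that `chr(0)` is a single code point with value 0, that `repr` displays it as an escape sequence, and that it still occupies a position when embedded in a longer string.

```python
null_char = chr(0)

# chr(0) is one character whose code point is 0 (U+0000).
print(len(null_char), ord(null_char))

# repr() shows the escape sequence instead of the invisible character.
print(repr(null_char))

# Embedded in a string, the character is invisible when printed,
# but it still counts toward the length and survives encoding.
text = "this is a test" + null_char + "string"
print(len(text))
print(text.encode("utf-8"))
```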

### 2.2 - Unicode Encodings


In [92]:
test_string = "hello! こんにちは!"

In [93]:
utf8_encoded = test_string.encode("utf-8")
print(utf8_encoded)

b'hello! \xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf!'


In [94]:
print(type(utf8_encoded))

<class 'bytes'>


In [95]:
list(utf8_encoded)

[104,
 101,
 108,
 108,
 111,
 33,
 32,
 227,
 129,
 147,
 227,
 130,
 147,
 227,
 129,
 171,
 227,
 129,
 161,
 227,
 129,
 175,
 33]

In [96]:
print(f'length of test string: {len(test_string)}, vs length of utf encoding {len(utf8_encoded)}')

length of test string: 13, vs length of utf encoding 23


In [97]:
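# `n` is the number of leading bytes to decode (avoids shadowing the built-in `id`).
for n in range(len(utf8_encoded)):
    try:
        print(f'first {n} utf8 characters: "{utf8_encoded[0:n].decode("utf-8")}"')
    except UnicodeDecodeError:
        print(f'first {n} utf8 characters: <<decoding failed>>')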

first 0 utf8 characters: ""
first 1 utf8 characters: "h"
first 2 utf8 characters: "he"
first 3 utf8 characters: "hel"
first 4 utf8 characters: "hell"
first 5 utf8 characters: "hello"
first 6 utf8 characters: "hello!"
first 7 utf8 characters: "hello! "
first 8 utf8 characters: <<decoding failed>>
first 9 utf8 characters: <<decoding failed>>
first 10 utf8 characters: "hello! こ"
first 11 utf8 characters: <<decoding failed>>
first 12 utf8 characters: <<decoding failed>>
first 13 utf8 characters: "hello! こん"
first 14 utf8 characters: <<decoding failed>>
first 15 utf8 characters: <<decoding failed>>
first 16 utf8 characters: "hello! こんに"
first 17 utf8 characters: <<decoding failed>>
first 18 utf8 characters: <<decoding failed>>
first 19 utf8 characters: "hello! こんにち"
first 20 utf8 characters: <<decoding failed>>
first 21 utf8 characters: <<decoding failed>>
first 22 utf8 characters: "hello! こんにちは"


In [98]:
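def decode_utf8_bytes_to_str_wrong(bytestring: bytes) -> str:
    # Intentionally wrong: decodes each byte on its own, which fails for multi-byte characters.
    try:
        output_str = "".join([bytes([b]).decode("utf-8") for b in bytestring])
    except UnicodeDecodeError:
        output_str = "ERROR"
    return output_str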

decode_utf8_bytes_to_str_wrong("hello!".encode("utf-8"))
decode_utf8_bytes_to_str_wrong("hello! こんにちは!".encode("utf-8"))

'ERROR'

In [99]:
bytes_list = list("こ".encode("utf-8"))
print(f'For character "こ", this is the bytes list: {list("こ".encode("utf-8"))}, representing {"こ".encode("utf-8")}')
print(f'If I decode all 3 bytes together I get "{bytes(bytes_list).decode("utf-8")}"')
print(f'If I decode only first 2 bytes I get "{bytes(bytes_list[0:2]).decode("utf-8")}"')

For character "こ", this is the bytes list: [227, 129, 147], representing b'\xe3\x81\x93'
If I decode all 3 bytes together I get "こ"


UnicodeDecodeError: 'utf-8' codec can't decode bytes in position 0-1: unexpected end of data

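(a) Reasons to prefer UTF-8 encoded bytes: UTF-8 is byte-oriented and ASCII-compatible, so it matches how text is typically stored and transmitted, and ASCII text takes one byte per character. UTF-16 and UTF-32 use 2- or 4-byte code units, which makes mostly-ASCII text longer (many zero bytes) and adds byte-order concerns.

(b) `decode_utf8_bytes_to_str_wrong` decodes each byte independently, so it fails for any character that is encoded with more than 1 byte, for example `"こ".encode("utf-8")`.

(c) `b'\xe3\x81'`, which is only the first 2 of the 3 bytes of こ (ko), so it does not decode to any character.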
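
The three answers above can be checked with a short sketch. It compares the encoded length of the same test string under UTF-8, UTF-16, and UTF-32, and contrasts byte-by-byte decoding with decoding the whole byte string at once. `decode_utf8_bytes_to_str` is a hypothetical name for a corrected function, not something taken from the assignment code.

```python
sample = "hello! こんにちは!"

# Encoded size in bytes per encoding. The "-le" variants are used so that
# no byte-order mark is prepended to the output.
sizes = {enc: len(sample.encode(enc)) for enc in ("utf-8", "utf-16-le", "utf-32-le")}
print(sizes)


def decode_utf8_bytes_to_str(bytestring: bytes) -> str:
    """Decode the whole byte string at once so multi-byte characters stay intact."""
    return bytestring.decode("utf-8")


print(decode_utf8_bytes_to_str(sample.encode("utf-8")))

# A truncated multi-byte sequence is not valid UTF-8 on its own.
try:
    b"\xe3\x81".decode("utf-8")
    truncated_ok = True
except UnicodeDecodeError:
    truncated_ok = False
print(truncated_ok)
```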

### 2.3 - Subword Tokenization

### 2.4 - BPE Tokenizer Training

In [14]:
text_to_tokenize = "Some text that I'll pre-tokenize. こんにちは!"

In [15]:
import regex as re

PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

re.findall(PAT, text_to_tokenize)

['Some',
 ' text',
 ' that',
 ' I',
 "'ll",
 ' pre',
 '-',
 'tokenize',
 '.',
 ' こんにちは',
 '!']

In [16]:
PAT_BYTE = rb"'(?:[sdmt]|ll|ve|re)| ?[A-Za-z]+| ?[0-9]+| ?[^\sA-Za-z0-9]+|\s+(?!\S)|\s+"

re.findall(PAT_BYTE, text_to_tokenize.encode("utf_8"))

[b'Some',
 b' text',
 b' that',
 b' I',
 b"'ll",
 b' pre',
 b'-',
 b'tokenize',
 b'.',
 b' \xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf!']

In [None]:
re.findall(PAT, "hello! こんにちは!")

['hello', '!', ' こんにちは', '!']

Algorithm 1 of Sennrich et al. [2016]:

In [102]:
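# Note: this rebinds `re` to the standard-library module, shadowing the earlier `import regex as re`.
import collections
import re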

def get_stats(vocab):
    pairs = collections.defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols)-1):
            pairs[(symbols[i],symbols[i+1])] += freq
    return pairs

def merge_vocab(pair, v_in):
    v_out = {}
    bigram = re.escape(' '.join(pair))
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    for word in v_in:
        w_out = p.sub(''.join(pair), word)
        v_out[w_out] = v_in[word]
    return v_out

In [18]:

vocab = {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w e s t </w>': 6, 'w i d e s t </w>': 3}

num_merges = 10

for i in range(num_merges):
    pairs = get_stats(vocab)
    best = max(pairs, key=pairs.get)
    vocab = merge_vocab(best, vocab)
    print(f'iteration {i}: best match "{best}"')

print(vocab)

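NameError: name 'get_stats' is not defined

The error comes from execution order: this cell (In [18]) was run before the cell that defines `get_stats` and `merge_vocab` (In [102]). Running the definitions first avoids it.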

Relation between `num_merges` and `vocab` size:

- If you run only a few merges, the vocabulary stays close to the base symbols, so you end up with almost the same encoding as raw UTF-8.

- The more merges you run, the closer you get to full-word encoding, where each frequent word is represented directly by a single token. Each merge adds one token to the vocabulary.

- However, new words that were not in the training data are still encoded using existing pieces that are smaller than the word.
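
The relation above can be illustrated with a toy byte-level trainer. This is a minimal sketch, not the assignment's implementation: `merge_pair`, `train_bpe_toy`, and `encode_toy` are hypothetical helpers, the corpus is made up, pre-tokenization is skipped, and ties are broken by whichever pair was counted first. Starting from 256 byte tokens, each merge adds exactly one token, so the vocabulary size is 256 plus the number of merges actually performed.

```python
from collections import Counter


def merge_pair(tokens, pair, new_id):
    """Replace every non-overlapping occurrence of `pair` with `new_id`."""
    out, i = [], 0
    while i < len(tokens):
        if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(tokens[i])
            i += 1
    return out


def train_bpe_toy(text: str, num_merges: int):
    """Return (vocab, merges) for a toy byte-level BPE trained on `text`."""
    vocab = {i: bytes([i]) for i in range(256)}  # base vocabulary: all byte values
    tokens = list(text.encode("utf-8"))
    merges = []
    for _ in range(num_merges):
        pair_counts = Counter(zip(tokens, tokens[1:]))
        if not pair_counts:
            break  # nothing left to merge
        best = max(pair_counts, key=pair_counts.get)
        new_id = len(vocab)
        vocab[new_id] = vocab[best[0]] + vocab[best[1]]
        merges.append((best, new_id))
        tokens = merge_pair(tokens, best, new_id)
    return vocab, merges


def encode_toy(text: str, merges):
    """Encode `text` by replaying the learned merges in training order."""
    tokens = list(text.encode("utf-8"))
    for pair, new_id in merges:
        tokens = merge_pair(tokens, pair, new_id)
    return tokens


corpus = "low low low lower lower newest newest newest widest"
for num_merges in (0, 5, 20):
    vocab, merges = train_bpe_toy(corpus, num_merges)
    unseen = encode_toy("lowest", merges)  # "lowest" does not occur in the corpus
    print(num_merges, len(vocab), len(unseen))
```

With zero merges, the unseen word is encoded as its raw bytes. With more merges it is covered by fewer, longer pieces learned from other words, but never by a token of its own.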

### 2.5 - Experimenting with BPE Tokenizer Training

In [4]:
with open("data/owt_valid.txt", "rb") as file:
    # Get total file size in bytes
    file.seek(0, os.SEEK_END)
    file_size = file.tell()
    print(f'file size: {file_size}')
    file.seek(0)

    desired_num_chunks = 3
    chunk_size = file_size // desired_num_chunks
    chunk_boundaries = [i * chunk_size for i in range(desired_num_chunks + 1)]
    chunk_boundaries[-1] = file_size

file size: 289998753
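
The boundaries computed above are evenly spaced byte offsets, so a chunk can start in the middle of a multi-byte character or in the middle of a document. A possible refinement, sketched below under assumptions, moves each interior boundary forward to the next occurrence of a delimiter. `align_boundaries` is a hypothetical helper, and using `<|endoftext|>` as the delimiter is an assumption about how documents are separated in the file.

```python
import io
import os


def align_boundaries(file, num_chunks: int, delimiter: bytes, read_size: int = 4096):
    """Return sorted byte offsets that split `file` into at most `num_chunks`
    chunks, with every interior offset placed at the start of a delimiter."""
    file.seek(0, os.SEEK_END)
    file_size = file.tell()
    chunk_size = file_size // num_chunks
    boundaries = [i * chunk_size for i in range(num_chunks + 1)]
    boundaries[-1] = file_size

    for i in range(1, len(boundaries) - 1):
        position = boundaries[i]  # file offset where the next read starts
        file.seek(position)
        carry = b""  # tail of earlier reads, in case the delimiter straddles two reads
        while True:
            block = file.read(read_size)
            if not block:
                boundaries[i] = file_size  # no delimiter before end of file
                break
            window = carry + block
            found = window.find(delimiter)
            if found != -1:
                boundaries[i] = position - len(carry) + found
                break
            carry = window[-(len(delimiter) - 1):] if len(delimiter) > 1 else b""
            position += len(block)
    # Several initial guesses can land on the same delimiter, so deduplicate.
    return sorted(set(boundaries))


# Made-up in-memory data standing in for the real file.
data = b"doc one<|endoftext|>doc two is longer<|endoftext|>doc three"
offsets = align_boundaries(io.BytesIO(data), 3, b"<|endoftext|>")
chunks = [data[start:end] for start, end in zip(offsets, offsets[1:])]
print(offsets)
print(chunks)
```

Because duplicates are removed, the function may return fewer chunks than requested, which seems acceptable when the goal is only to hand independent pieces to parallel workers.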


In [128]:
!stat -f %z data/owt_valid.txt

289998753


Diary on profiling

1. I ran profiling on `TinyStoriesV2-GPT4-train.txt`; the results are in `profiles/profile_with_pretokenization`.
   1. Time is split roughly among `build_pretokens`, `get_stats`, and `apply_merge`
   2. We can speed up `build_pretokens` by running it in parallel
   3. We can speed up `get_stats` and `apply_merge` by using a better algorithm, as explained at the bottom of page 8 in [cs336_spring2025_assignment1_basics.pdf](cs336_spring2025_assignment1_basics.pdf)
   4. We should focus first on improving the algorithm
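
Item 3 points to the handout for the improved algorithm. One common way to avoid recounting every pair after each merge, which I assume is close to what is meant, is to keep a running pair-count table and adjust it only for the words that contain the merged pair. The sketch below uses hypothetical names (`count_pairs`, `merge_word`, `apply_merge_incremental`) and a made-up toy word table; it is not the profiled implementation.

```python
from collections import Counter


def count_pairs(word_freqs):
    """Full pass: count adjacent symbol pairs, weighted by word frequency."""
    pair_counts = Counter()
    for word, freq in word_freqs.items():
        for pair in zip(word, word[1:]):
            pair_counts[pair] += freq
    return pair_counts


def merge_word(word, pair):
    """Merge every non-overlapping occurrence of `pair` inside one word."""
    out, i = [], 0
    while i < len(word):
        if i + 1 < len(word) and (word[i], word[i + 1]) == pair:
            out.append(word[i] + word[i + 1])
            i += 2
        else:
            out.append(word[i])
            i += 1
    return tuple(out)


def apply_merge_incremental(word_freqs, pair_counts, pair):
    """Apply `pair` and update `pair_counts` in place for affected words only.

    This still scans all words to find the affected ones; an index from
    pair to words would avoid that scan.
    """
    new_word_freqs = {}
    for word, freq in word_freqs.items():
        if pair not in zip(word, word[1:]):
            new_word_freqs[word] = new_word_freqs.get(word, 0) + freq
            continue
        merged = merge_word(word, pair)
        for old_pair in zip(word, word[1:]):
            pair_counts[old_pair] -= freq
        for new_pair in zip(merged, merged[1:]):
            pair_counts[new_pair] += freq
        new_word_freqs[merged] = new_word_freqs.get(merged, 0) + freq
    # Drop pairs whose count fell to zero.
    for key in [k for k, v in pair_counts.items() if v <= 0]:
        del pair_counts[key]
    return new_word_freqs


# Toy word table: each word is a tuple of symbols mapped to its frequency.
word_freqs = {
    ("l", "o", "w"): 5,
    ("l", "o", "w", "e", "r"): 2,
    ("n", "e", "w", "e", "s", "t"): 6,
    ("w", "i", "d", "e", "s", "t"): 3,
}
pair_counts = count_pairs(word_freqs)
best = max(pair_counts, key=pair_counts.get)
word_freqs = apply_merge_incremental(word_freqs, pair_counts, best)

# The incrementally updated table should match a full recount.
print(best, pair_counts == count_pairs(word_freqs))
```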