In [2]:
import urllib.request
url = ("https://raw.githubusercontent.com/rasbt/"
       "LLMs-from-scratch/main/ch02/01_main-chapter-code/"
       "the-verdict.txt")
file_path = "the-verdict.txt"
urllib.request.urlretrieve(url, file_path)

('the-verdict.txt', <http.client.HTTPMessage at 0x22420f78350>)

In [3]:
with open("the-verdict.txt", "r", encoding="utf-8") as f:
    raw_text = f.read()

In [4]:
import re

preprocessed = re.split(r'([,.:;?_!"()\']|--|\s)', raw_text)
preprocessed = [item.strip() for item in preprocessed if item.strip()]
print(len(preprocessed))

4690


In [5]:
all_words = sorted(set(preprocessed))
vocab_size = len(all_words)
print(vocab_size)

vocab = {token:integer for integer,token in enumerate(all_words)}

1130


In [6]:
class SimpleTokenizerV1:
    def __init__(self, vocab):
        self.str_to_int = vocab
        self.int_to_str = {i:s for s,i in vocab.items()}

    def encode(self, text):
        preprocessed = re.split(r'([,.?_!"()\']|--|\s)', text)
        preprocessed = [
            item.strip() for item in preprocessed if item.strip()
        ]
        ids = [self.str_to_int[s] for s in preprocessed]
        return ids

    def decode(self, ids):
        text = " ".join([self.int_to_str[i] for i in ids]) 

        text = re.sub(r'\s+([,.?!"()\'])', r'\1', text)
        return text

In [7]:
all_tokens = sorted(list(set(preprocessed)))
all_tokens.extend(["<|endoftext|>", "<|unk|>"])
vocab = {token:integer for integer,token in enumerate(all_tokens)}

print(len(vocab.items()))

1132


In [8]:
tokenizer = SimpleTokenizerV1(vocab)
text = """"It's the last he painted, you know," 
       Mrs. Gisburn said with pardonable pride."""
ids = tokenizer.encode(text)
print(ids)

[1, 56, 2, 850, 988, 602, 533, 746, 5, 1126, 596, 5, 1, 67, 7, 38, 851, 1108, 754, 793, 7]


In [9]:
print(tokenizer.decode(ids))

" It' s the last he painted, you know," Mrs. Gisburn said with pardonable pride.


In [10]:
class SimpleTokenizerV2:
    def __init__(self, vocab):
        self.str_to_int = vocab
        self.int_to_str = { i:s for s,i in vocab.items()}

    def encode(self, text):
        preprocessed = re.split(r'([,.:;?_!"()\']|--|\s)', text)
        preprocessed = [
            item.strip() for item in preprocessed if item.strip()
        ]
        preprocessed = [item if item in self.str_to_int            #1
                        else "<|unk|>" for item in preprocessed]

        ids = [self.str_to_int[s] for s in preprocessed]
        return ids

    def decode(self, ids):
        text = " ".join([self.int_to_str[i] for i in ids])

        text = re.sub(r'\s+([,.:;?!"()\'])', r'\1', text)    #2
        return text

In [11]:
text1 = "Hello, do you like tea?"
text2 = "In the sunlit terraces of the palace."
text = " <|endoftext|> ".join((text1, text2))
print(text)

tokenizer = SimpleTokenizerV2(vocab)
print(tokenizer.encode(text))

print(tokenizer.decode(tokenizer.encode(text)))

Hello, do you like tea? <|endoftext|> In the sunlit terraces of the palace.
[1131, 5, 355, 1126, 628, 975, 10, 1130, 55, 988, 956, 984, 722, 988, 1131, 7]
<|unk|>, do you like tea? <|endoftext|> In the sunlit terraces of the <|unk|>.


In [12]:
import tiktoken

In [13]:
tokenizer = tiktoken.get_encoding("gpt2")

text = (
    "Hello, do you like tea? <|endoftext|> In the sunlit terraces"
     "of someunknownPlace."
)
integers = tokenizer.encode(text, allowed_special={"<|endoftext|>"})
print(integers)

[15496, 11, 466, 345, 588, 8887, 30, 220, 50256, 554, 262, 4252, 18250, 8812, 2114, 1659, 617, 34680, 27271, 13]


In [14]:
strings = tokenizer.decode(integers)
print(strings)

Hello, do you like tea? <|endoftext|> In the sunlit terracesof someunknownPlace.


# Byte-pair encoding

In [72]:
# unicode code point for each character

text = """Want to buy the textbooks from the bookshop, The bookshop is in Cairo"""

encodings = list(text.encode("utf-8"))
print(len(encodings), len(text))

69 69


In [73]:
print(encodings)

[87, 97, 110, 116, 32, 116, 111, 32, 98, 117, 121, 32, 116, 104, 101, 32, 116, 101, 120, 116, 98, 111, 111, 107, 115, 32, 102, 114, 111, 109, 32, 116, 104, 101, 32, 98, 111, 111, 107, 115, 104, 111, 112, 44, 32, 84, 104, 101, 32, 98, 111, 111, 107, 115, 104, 111, 112, 32, 105, 115, 32, 105, 110, 32, 67, 97, 105, 114, 111]


In [74]:
from collections import Counter

def get_freq_text(text):
    encodings = list(text.encode("utf-8"))
    return encodings


def get_pairs(encodings):
    return zip(encodings, encodings[1:])

def get_freq_ids(encodings):
    pairs = get_pairs(encodings)
    pair_counts = Counter(pairs)
    
    pair_dict = {pair: count for pair, count in sorted(pair_counts.items(), key=lambda item: item[1], reverse=True)}

    return pair_dict

In [75]:
def get_most_freq(pair_freq):

    #rank the keys by value
    return max(pair_freq, key=pair_freq.get)

In [76]:
print(get_freq_text(text))

[87, 97, 110, 116, 32, 116, 111, 32, 98, 117, 121, 32, 116, 104, 101, 32, 116, 101, 120, 116, 98, 111, 111, 107, 115, 32, 102, 114, 111, 109, 32, 116, 104, 101, 32, 98, 111, 111, 107, 115, 104, 111, 112, 44, 32, 84, 104, 101, 32, 98, 111, 111, 107, 115, 104, 111, 112, 32, 105, 115, 32, 105, 110, 32, 67, 97, 105, 114, 111]


In [77]:
text_bytes_ids = get_freq_text(text)

max_vocab_id = max(set(text_bytes_ids))
vocab = sorted(set(text_bytes_ids))
vocab_str = [chr(vocab[i]) for i in range(0, len(vocab))]

vocab_dict = dict(zip(vocab, vocab_str))

In [78]:
print("the number of bytes in text = {}, the number of vocab = {}".format(len(text_bytes_ids), len(vocab)))

the number of bytes in text = 69, the number of vocab = 22


In [79]:
len(text_bytes_ids)

69

In [80]:
text_bytes_ids[67]

114

In [116]:
def merge_ids(text_bytes_ids, vocab_dict, most_freq):
    
    # merge the most frequent pair
    max_vocab_id = max(set(text_bytes_ids))

    for i in range(len(text_bytes_ids) - 1):
        if (text_bytes_ids[i], text_bytes_ids[i + 1]) == most_freq:
            text_bytes_ids[i] = max_vocab_id + 1
            text_bytes_ids[i + 1] = None
    
    decoded_new_vocab = ''.join(vocab_dict.get(i) for i in most_freq)
    
    vocab_dict[max_vocab_id + 1] = decoded_new_vocab
    
    text_bytes_ids = [id for id in text_bytes_ids if id is not None]
    
    return text_bytes_ids, vocab_dict, max_vocab_id + 1

def merge_ids_encoding(text_bytes_ids, pair, position):

    for i in range(len(text_bytes_ids) - 1):
        if (text_bytes_ids[i], text_bytes_ids[i + 1]) == pair:
            text_bytes_ids[i] = position
            text_bytes_ids[i + 1] = None
        
    text_bytes_ids = [id for id in text_bytes_ids if id is not None]
    
    return text_bytes_ids

In [82]:
def merge_tokens(text_bytes_ids, vocab_dict, num_merges):

    merged_dict = {}
    for i in range(num_merges):
        pair_freq = get_freq_ids(text_bytes_ids)
        most_freq = get_most_freq(pair_freq)
        print(most_freq)
        text_bytes_ids, vocab_dict, merge_id = merge_ids(text_bytes_ids, vocab_dict, most_freq)
        merged_dict[most_freq] = merge_id

    return text_bytes_ids, vocab_dict, merged_dict

In [83]:
text_bytes_ids, vocab_dict, merged_dict = merge_tokens(text_bytes_ids, vocab_dict, num_merges = 16)

(32, 116)
(32, 98)
(104, 101)
(111, 111)
(125, 107)
(126, 115)
(122, 124)
(114, 111)
(123, 127)
(130, 104)
(131, 111)
(132, 112)
(32, 105)
(87, 97)
(135, 110)
(136, 116)


Now we can see that the string bytes length decreased by 4 which is the frequency of the most frequent pair

In [85]:
def decode_ids(text_ids, vocab_dict):
    decoded_text = ''.join(vocab_dict.get(i) for i in text_ids)
    return decoded_text

In [86]:
decode_ids(text_bytes_ids, vocab_dict)

'Want to buy the textbooks from the bookshop, The bookshop is in Cairo'

In [125]:
def encode(text_str):
    token_ids = list(text_str.encode("utf-8"))

    while len(token_ids) >= 2:

        pairs = list(get_pairs(token_ids))
        min_merge_pair = min(pairs, key= lambda p: merged_dict.get(p, float('inf')))
        
        if min_merge_pair not in merged_dict:
            break
        else:
            token_ids = merge_ids_encoding(token_ids, min_merge_pair, merged_dict[min_merge_pair])
        

    return token_ids

In [126]:
test = "book in a bookshop"

In [136]:
len(list(test.encode("utf-8"))), len(encode(test))

(18, 7)

In [137]:
decode_ids(encode(test), vocab_dict)

'book in a bookshop'