<a href="https://colab.research.google.com/github/BilalAsifB/MyLLM/blob/main/tokenizer/bpe_basic_regex.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# BYTE PAIR ENCODER (BPE)

## Learning resources:
- For basic intution: https://youtu.be/fKd8s29e-l4?si=TSbl0E6pHgWIXHnb
- To really understand: https://youtu.be/zduSFxRajkE?si=EOAf_ey5bXR_1Hw8

In [44]:
!pip install regex




In [45]:
import regex

In [46]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [47]:
class BPE:
    def __init__(self):
        self.ch_encode = []
        self._ch_encode_processed = []
        self._vocab = {}
        self._merges = {}

    def __get_byte_freq(self, ch_encode):
        byte_freq = {}

        for pair in zip(ch_encode, ch_encode[1:]):
            byte_freq[pair] = byte_freq.get(pair, 0) + 1

        return byte_freq

    def __replace_top_pair(self, ch_encode, top_pair, token):
        new_ls = []
        idx = 0
        last_idx = len(ch_encode)

        while(idx < last_idx):
            if (idx < last_idx - 1) and ((ch_encode[idx], ch_encode[idx+1]) == top_pair):
                new_ls.append(token)
                idx += 2
            else:
                new_ls.append(ch_encode[idx])
                idx += 1

        return new_ls

    def __create_vocab(self):
        self._vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self._merges.items():
            self._vocab[idx] = self._vocab[p0] + self._vocab[p1]

    def print_compression(self):
        print(f'Original: {len(self.ch_encode)}\n'
              f'Processed: {len(self._ch_encode_processed)}\n'
              f'Compression ratio: {len(self.ch_encode) / len(self._ch_encode_processed):.2f}x')

    def train_bpe(self, vocab_size, inp):
        assert vocab_size > 256
        self.ch_encode = list(map(int, inp.encode("utf-8")))
        final_vocab = vocab_size
        num_merges = final_vocab - 256
        self._ch_encode_processed = self.ch_encode.copy()

        for i in range(num_merges):
            if len(self._ch_encode_processed) == 1:
                break
            token = i + 256
            byte_freq = self.__get_byte_freq(self._ch_encode_processed)
            top_pair = max(byte_freq, key=byte_freq.get)
            self._ch_encode_processed = self.__replace_top_pair(self._ch_encode_processed, top_pair, token)

            self._merges.update({top_pair: token})

        self.__create_vocab()
        self.print_compression()

    def decode(self, inp):
        tokens = b"".join(self._vocab[idx] for idx in inp)
        text = tokens.decode("utf-8", errors="replace") # replace when utf-8 format isn't strictly followed
        return text

    def encode(self, inp):
        tokens = list(inp.encode("utf-8"))

        while len(tokens) > 1:
            byte_freq = self.__get_byte_freq(tokens)
            eligible_pair = min(byte_freq, key=lambda p: self._merges.get(p, float("inf"))) # gives earliest most merge as we need to follow the merge order
            if eligible_pair not in self._merges: break

            idx = self._merges[eligible_pair]
            tokens = self.__replace_top_pair(tokens, eligible_pair, idx)

        return tokens

In [48]:
class BPE_REGEX:
    def __init__(self):
        self.ch_encode = []
        self._ch_encode_processed = []
        self._vocab = {}
        self._merges = {}
        self.regex_pattern = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
        self.compiled = regex.compile(self.regex_pattern)

    def __get_byte_freq(self, ch_encode):
        byte_freq = {}

        for pair in zip(ch_encode, ch_encode[1:]):
            byte_freq[pair] =  byte_freq.get(pair, 0) + 1

        return byte_freq

    def __replace_top_pair(self, ch_encode, top_pair, token):
        new_ls = []
        idx = 0
        last_idx = len(ch_encode)

        while(idx < last_idx):
            if (idx < last_idx - 1) and ((ch_encode[idx], ch_encode[idx+1]) == top_pair):
                new_ls.append(token)
                idx += 2
            else:
                new_ls.append(ch_encode[idx])
                idx += 1

        return new_ls

    def __create_vocab(self):
        self._vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self._merges.items():
            self._vocab[idx] = self._vocab[p0] + self._vocab[p1]

    def print_compression(self):
        print(f'Original: {len(self.ch_encode)}\n'
              f'Processed: {len(self._ch_encode_processed)}\n'
              f'Compression ratio: {len(self.ch_encode) / len(self._ch_encode_processed):.2f}x')

    def train_bpe(self, vocab_size, inp):
        assert vocab_size > 256
        self.ch_encode = list(map(int, inp.encode("utf-8")))
        final_vocab = vocab_size
        num_merges = final_vocab - 256

        split_inp = regex.findall(self.compiled, inp)
        tokens_encode = [list(token.encode("utf-8")) for token in split_inp]

        for i in range(num_merges):
            byte_freq = {}
            for te in tokens_encode:
                chunk_freq = self.__get_byte_freq(te)
                for pair, count in chunk_freq.items():
                    byte_freq[pair] = byte_freq.get(pair, 0) + count

            if not byte_freq: break

            token = i + 256
            top_pair = max(byte_freq, key=byte_freq.get)
            tokens_encode = [self.__replace_top_pair(te, top_pair, token) for te in tokens_encode]
            self._merges.update({top_pair: token})
            self._ch_encode_processed.extend(te)

        self.__create_vocab()
        self.print_compression()

    def decode(self, inp):
        tokens = b"".join(self._vocab[idx] for idx in inp)
        text = tokens.decode("utf-8", errors="replace") # replace when utf-8 format isn't strictly followed
        return text

    def encode(self, inp):
        split_inp = regex.findall(self.compiled, inp)
        tokens_encode = [list(token.encode("utf-8")) for token in split_inp]

        tokens = []
        for te in tokens_encode:
            while len(te) > 1:
                byte_freq = self.__get_byte_freq(te)
                eligible_pair = min(byte_freq, key=lambda p: self._merges.get(p, float("inf"))) # gives earliest most merge as we need to follow the merge order
                if eligible_pair not in self._merges: break

                idx = self._merges[eligible_pair]
                te = self.__replace_top_pair(te, eligible_pair, idx)

            tokens.extend(te)

        return tokens

In [49]:
input_string = open(r"/content/drive/My Drive/taylorswift.txt").read()

In [50]:
print(input_string[:1000])

Copy paste of the Wikipedia article on Taylor Swift, as of Feb 16, 2024.
---

Main menu

WikipediaThe Free Encyclopedia

Search
Create account
Log in

Personal tools
Contents  hide
(Top)
Life and career
Toggle Life and career subsection
Artistry
Toggle Artistry subsection
Accolades and achievements
Cultural status
Toggle Cultural status subsection
Wealth
Toggle Wealth subsection
Discography
Filmography
Tours
See also
Footnotes
References
Toggle References subsection
External links
Taylor Swift

136 languages
Article
Talk
Read
View source
View history

Tools
 Featured article
Page semi-protected
From Wikipedia, the free encyclopedia
For the album, see Taylor Swift (album).
Taylor Swift
Portrait of Taylor Swift in a cocktail dress
Swift at the 2023 MTV Video Music Awards
Born	Taylor Alison Swift
December 13, 1989 (age 34)
West Reading, Pennsylvania, US
Occupations
Singer-songwriter producer director businesswoman actress
Years active	2004â€“present
Works
Albumssinglessongsvideosperforman

In [51]:
# Simple BPE:
bpe_s = BPE()
bpe_s.train_bpe(256+64, input_string)


Original: 185768
Processed: 116688
Compression ratio: 1.59x


In [52]:
# Regex BPE:
bpe_r = BPE_REGEX()
bpe_r.train_bpe(256+64, input_string)

Original: 185768
Processed: 64
Compression ratio: 2902.62x


In [53]:
test_strings = ["", "?", "hello world!!!? (ì•ˆë…•í•˜ì„¸ìš”!) lol123 ðŸ˜‰"]

In [54]:
print([bpe_s.decode(bpe_s.encode(ts)) for ts in test_strings])

['', '?', 'hello world!!!? (ì•ˆë…•í•˜ì„¸ìš”!) lol123 ðŸ˜‰']


In [55]:
print([bpe_r.decode(bpe_r.encode(ts)) for ts in test_strings])

['', '?', 'hello world!!!? (ì•ˆë…•í•˜ì„¸ìš”!) lol123 ðŸ˜‰']


In [56]:
bpe_s._vocab

{0: b'\x00',
 1: b'\x01',
 2: b'\x02',
 3: b'\x03',
 4: b'\x04',
 5: b'\x05',
 6: b'\x06',
 7: b'\x07',
 8: b'\x08',
 9: b'\t',
 10: b'\n',
 11: b'\x0b',
 12: b'\x0c',
 13: b'\r',
 14: b'\x0e',
 15: b'\x0f',
 16: b'\x10',
 17: b'\x11',
 18: b'\x12',
 19: b'\x13',
 20: b'\x14',
 21: b'\x15',
 22: b'\x16',
 23: b'\x17',
 24: b'\x18',
 25: b'\x19',
 26: b'\x1a',
 27: b'\x1b',
 28: b'\x1c',
 29: b'\x1d',
 30: b'\x1e',
 31: b'\x1f',
 32: b' ',
 33: b'!',
 34: b'"',
 35: b'#',
 36: b'$',
 37: b'%',
 38: b'&',
 39: b"'",
 40: b'(',
 41: b')',
 42: b'*',
 43: b'+',
 44: b',',
 45: b'-',
 46: b'.',
 47: b'/',
 48: b'0',
 49: b'1',
 50: b'2',
 51: b'3',
 52: b'4',
 53: b'5',
 54: b'6',
 55: b'7',
 56: b'8',
 57: b'9',
 58: b':',
 59: b';',
 60: b'<',
 61: b'=',
 62: b'>',
 63: b'?',
 64: b'@',
 65: b'A',
 66: b'B',
 67: b'C',
 68: b'D',
 69: b'E',
 70: b'F',
 71: b'G',
 72: b'H',
 73: b'I',
 74: b'J',
 75: b'K',
 76: b'L',
 77: b'M',
 78: b'N',
 79: b'O',
 80: b'P',
 81: b'Q',
 82: b'R',
 83: b'

In [57]:
bpe_r._vocab

{0: b'\x00',
 1: b'\x01',
 2: b'\x02',
 3: b'\x03',
 4: b'\x04',
 5: b'\x05',
 6: b'\x06',
 7: b'\x07',
 8: b'\x08',
 9: b'\t',
 10: b'\n',
 11: b'\x0b',
 12: b'\x0c',
 13: b'\r',
 14: b'\x0e',
 15: b'\x0f',
 16: b'\x10',
 17: b'\x11',
 18: b'\x12',
 19: b'\x13',
 20: b'\x14',
 21: b'\x15',
 22: b'\x16',
 23: b'\x17',
 24: b'\x18',
 25: b'\x19',
 26: b'\x1a',
 27: b'\x1b',
 28: b'\x1c',
 29: b'\x1d',
 30: b'\x1e',
 31: b'\x1f',
 32: b' ',
 33: b'!',
 34: b'"',
 35: b'#',
 36: b'$',
 37: b'%',
 38: b'&',
 39: b"'",
 40: b'(',
 41: b')',
 42: b'*',
 43: b'+',
 44: b',',
 45: b'-',
 46: b'.',
 47: b'/',
 48: b'0',
 49: b'1',
 50: b'2',
 51: b'3',
 52: b'4',
 53: b'5',
 54: b'6',
 55: b'7',
 56: b'8',
 57: b'9',
 58: b':',
 59: b';',
 60: b'<',
 61: b'=',
 62: b'>',
 63: b'?',
 64: b'@',
 65: b'A',
 66: b'B',
 67: b'C',
 68: b'D',
 69: b'E',
 70: b'F',
 71: b'G',
 72: b'H',
 73: b'I',
 74: b'J',
 75: b'K',
 76: b'L',
 77: b'M',
 78: b'N',
 79: b'O',
 80: b'P',
 81: b'Q',
 82: b'R',
 83: b'