<a href="https://colab.research.google.com/github/ajayrfhp/LearningDeepLearning/blob/main/bytepairencoding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Bytepair encoding
- Find the most common pair of bytes, merge them into 2, add to vocab. Repeat untill no common pair of byte occurs more once.

In [1]:

text = "My name is ajay. This is a random sentence. This is a ranasd"

text_encoded = list(text.encode('utf-8'))
type(text_encoded)

vocab = {}
vocab_size = 0

for chr_code in range(256):
  vocab[chr_code] = chr(chr_code)
  vocab_size += 1

vocab
list(text_encoded)

[77,
 121,
 32,
 110,
 97,
 109,
 101,
 32,
 105,
 115,
 32,
 97,
 106,
 97,
 121,
 46,
 32,
 84,
 104,
 105,
 115,
 32,
 105,
 115,
 32,
 97,
 32,
 114,
 97,
 110,
 100,
 111,
 109,
 32,
 115,
 101,
 110,
 116,
 101,
 110,
 99,
 101,
 46,
 32,
 84,
 104,
 105,
 115,
 32,
 105,
 115,
 32,
 97,
 32,
 114,
 97,
 110,
 97,
 115,
 100]

In [2]:
def get_pair_count(text):
  counts = {}
  for i in range(len(text)-1):
    pair = tuple(text[i:i+2])
    counts[pair] = counts.get(pair, 0) + 1
  return counts

pair_count = get_pair_count(text_encoded)
pair_count

for k, v in sorted(pair_count.items(), key=lambda x: x[1], reverse=True):
  print(k, v)

(105, 115) 5
(115, 32) 5
(32, 105) 3
(32, 97) 3
(110, 97) 2
(46, 32) 2
(32, 84) 2
(84, 104) 2
(104, 105) 2
(97, 32) 2
(32, 114) 2
(114, 97) 2
(97, 110) 2
(101, 110) 2
(77, 121) 1
(121, 32) 1
(32, 110) 1
(97, 109) 1
(109, 101) 1
(101, 32) 1
(97, 106) 1
(106, 97) 1
(97, 121) 1
(121, 46) 1
(110, 100) 1
(100, 111) 1
(111, 109) 1
(109, 32) 1
(32, 115) 1
(115, 101) 1
(110, 116) 1
(116, 101) 1
(110, 99) 1
(99, 101) 1
(101, 46) 1
(97, 115) 1
(115, 100) 1


- Merge common tokens

In [18]:
# Add symbol for most frequent pair in vocab and run encoding again to replace most frequent pair with new symbol.

def merge(text_encoded, pair, idx):
  i = 0

  while i < len(text_encoded) - 1:
    if text_encoded[i] == pair[0] and text_encoded[i+1] == pair[1]:
      text_encoded[i] = idx
      text_encoded.pop(i+1)
    else:
      i += 1



  return text_encoded


In [4]:
merge(
    text_encoded=[1, 2, 3, 4, 5, 5],
    pair=[1, 2],
    idx=10
)

[10, 3, 4, 5, 5]

- Grab big text

In [5]:
import requests
big_text_url = "https://raw.githubusercontent.com/dscape/spell/refs/heads/master/test/resources/big.txt"

big_text = requests.get(big_text_url).text
big_text = big_text[:10000]
big_text_encoded = list(big_text.encode('utf-8'))
len(big_text_encoded)

10000

In [6]:
pair_count = get_pair_count(big_text_encoded)
sorted_pair_count = sorted(pair_count.items(), key=lambda x: x[1], reverse=True)
sorted_pair_count[:10]

[((101, 32), 305),
 ((32, 116), 188),
 ((32, 97), 179),
 ((104, 101), 160),
 ((116, 104), 160),
 ((115, 32), 158),
 ((100, 32), 154),
 ((116, 32), 145),
 ((105, 110), 132),
 ((101, 114), 120)]

- Prepare vocab

In [7]:

max_vocab_size = 300

vocab = {}
vocab_size = 0

for chr_code in range(256):
  vocab[chr_code] = chr(chr_code)
  vocab_size += 1


vocab

{0: '\x00',
 1: '\x01',
 2: '\x02',
 3: '\x03',
 4: '\x04',
 5: '\x05',
 6: '\x06',
 7: '\x07',
 8: '\x08',
 9: '\t',
 10: '\n',
 11: '\x0b',
 12: '\x0c',
 13: '\r',
 14: '\x0e',
 15: '\x0f',
 16: '\x10',
 17: '\x11',
 18: '\x12',
 19: '\x13',
 20: '\x14',
 21: '\x15',
 22: '\x16',
 23: '\x17',
 24: '\x18',
 25: '\x19',
 26: '\x1a',
 27: '\x1b',
 28: '\x1c',
 29: '\x1d',
 30: '\x1e',
 31: '\x1f',
 32: ' ',
 33: '!',
 34: '"',
 35: '#',
 36: '$',
 37: '%',
 38: '&',
 39: "'",
 40: '(',
 41: ')',
 42: '*',
 43: '+',
 44: ',',
 45: '-',
 46: '.',
 47: '/',
 48: '0',
 49: '1',
 50: '2',
 51: '3',
 52: '4',
 53: '5',
 54: '6',
 55: '7',
 56: '8',
 57: '9',
 58: ':',
 59: ';',
 60: '<',
 61: '=',
 62: '>',
 63: '?',
 64: '@',
 65: 'A',
 66: 'B',
 67: 'C',
 68: 'D',
 69: 'E',
 70: 'F',
 71: 'G',
 72: 'H',
 73: 'I',
 74: 'J',
 75: 'K',
 76: 'L',
 77: 'M',
 78: 'N',
 79: 'O',
 80: 'P',
 81: 'Q',
 82: 'R',
 83: 'S',
 84: 'T',
 85: 'U',
 86: 'V',
 87: 'W',
 88: 'X',
 89: 'Y',
 90: 'Z',
 91: '[',


In [19]:
num_merges = max_vocab_size - vocab_size
for i in range(num_merges):
  pair_count = get_pair_count(big_text_encoded)
  sorted_pair_count = sorted(pair_count.items(), key=lambda x: x[1], reverse=True)
  most_frequent_pair = sorted_pair_count[0][0]
  print(most_frequent_pair)
  merge(big_text_encoded, most_frequent_pair, idx=vocab_size)
  vocab_size += 1
  vocab[vocab_size] = ''.join(map(chr, most_frequent_pair))
  print(f"merging {most_frequent_pair} to {vocab[vocab_size]}")




(101, 114)
merging (101, 114) to er
(115, 32)
merging (115, 32) to s 
(256, 256)
merging (256, 256) to ĀĀ
(100, 32)
merging (100, 32) to d 
(116, 32)
merging (116, 32) to t 
(111, 117)
merging (111, 117) to ou
(101, 110)
merging (101, 110) to en
(256, 32)
merging (256, 32) to Ā 
(111, 110)
merging (111, 110) to on
(121, 32)
merging (121, 32) to y 
(115, 256)
merging (115, 256) to sĀ
(44, 32)
merging (44, 32) to , 
(111, 102)
merging (111, 102) to of
(46, 32)
merging (46, 32) to . 
(111, 32)
merging (111, 32) to o 
(105, 116)
merging (105, 116) to it
(114, 101)
merging (114, 101) to re
(10, 10)
merging (10, 10) to 


(104, 105)
merging (104, 105) to hi
(97, 110)
merging (97, 110) to an
(256, 110)
merging (256, 110) to Ān
(111, 114)
merging (111, 114) to or
(104, 256)
merging (104, 256) to hĀ
(104, 97)
merging (104, 97) to ha
(256, 103)
merging (256, 103) to Āg
(97, 114)
merging (97, 114) to ar
(101, 100)
merging (101, 100) to ed
(111, 119)
merging (111, 119) to ow
(115, 116)
merging (11

In [20]:
vocab

{0: '\x00',
 1: '\x01',
 2: '\x02',
 3: '\x03',
 4: '\x04',
 5: '\x05',
 6: '\x06',
 7: '\x07',
 8: '\x08',
 9: '\t',
 10: '\n',
 11: '\x0b',
 12: '\x0c',
 13: '\r',
 14: '\x0e',
 15: '\x0f',
 16: '\x10',
 17: '\x11',
 18: '\x12',
 19: '\x13',
 20: '\x14',
 21: '\x15',
 22: '\x16',
 23: '\x17',
 24: '\x18',
 25: '\x19',
 26: '\x1a',
 27: '\x1b',
 28: '\x1c',
 29: '\x1d',
 30: '\x1e',
 31: '\x1f',
 32: ' ',
 33: '!',
 34: '"',
 35: '#',
 36: '$',
 37: '%',
 38: '&',
 39: "'",
 40: '(',
 41: ')',
 42: '*',
 43: '+',
 44: ',',
 45: '-',
 46: '.',
 47: '/',
 48: '0',
 49: '1',
 50: '2',
 51: '3',
 52: '4',
 53: '5',
 54: '6',
 55: '7',
 56: '8',
 57: '9',
 58: ':',
 59: ';',
 60: '<',
 61: '=',
 62: '>',
 63: '?',
 64: '@',
 65: 'A',
 66: 'B',
 67: 'C',
 68: 'D',
 69: 'E',
 70: 'F',
 71: 'G',
 72: 'H',
 73: 'I',
 74: 'J',
 75: 'K',
 76: 'L',
 77: 'M',
 78: 'N',
 79: 'O',
 80: 'P',
 81: 'Q',
 82: 'R',
 83: 'S',
 84: 'T',
 85: 'U',
 86: 'V',
 87: 'W',
 88: 'X',
 89: 'Y',
 90: 'Z',
 91: '[',


In [25]:
reverse_vocab = {v: k for k, v in vocab.items()}
max_token_size = max(map(len, reverse_vocab.keys()))
max_token_size

2

In [27]:
def encode(text, reverse_vocab):
  i = 0
  text_encoded = []
  while i < len(text):
    for j in range(max_token_size, 0, -1):
      potential_token = text[i:i+j]
      if potential_token in reverse_vocab:
        text_encoded.append(reverse_vocab[potential_token])
        i += j
        break
  return text_encoded


def decode(text_encoded, vocab):
  text = ""
  for code in text_encoded:
    text += vocab[code]
  return text

encoded_text = encode("Hello this is Ajay", reverse_vocab)
print(encoded_text)
decoded_text = decode(encoded_text, vocab)
print(decoded_text)


[72, 101, 294, 271, 116, 275, 258, 105, 258, 65, 106, 97, 121]
Hello this is Ajay
