#### Lets do tokenization in non-english language Telugu

In [1]:
## write me a code to tokenize the text in telugu, with BPE in python.
sample_text = "తెలుగు భాష ఒక ద్రావిడ భాష."
print('before encoding:', sample_text, len(sample_text))
tokens = sample_text.encode("utf-8")
tokens = list(map(int, tokens))
print('after encoding:', tokens)
print('length of tokens:', len(tokens))

before encoding: తెలుగు భాష ఒక ద్రావిడ భాష. 26
after encoding: [224, 176, 164, 224, 177, 134, 224, 176, 178, 224, 177, 129, 224, 176, 151, 224, 177, 129, 32, 224, 176, 173, 224, 176, 190, 224, 176, 183, 32, 224, 176, 146, 224, 176, 149, 32, 224, 176, 166, 224, 177, 141, 224, 176, 176, 224, 176, 190, 224, 176, 181, 224, 176, 191, 224, 176, 161, 32, 224, 176, 173, 224, 176, 190, 224, 176, 183, 46]
length of tokens: 68


In [1]:
import pandas as pd
import re

# Load the CSV files
file_paths = [
    '/Users/anvesh/codebase/llm/data/telugu_books/telugu_books.csv',
    '/Users/anvesh/codebase/llm/data/telugu_news/1_telugu_news.csv',
    '/Users/anvesh/codebase/llm/data/telugu_news/2_telugu_news.csv'
]

# Combine data from all files
telugu_texts = []
for file_path in file_paths:
    df = pd.read_csv(file_path)
    if 'text' in df.columns:
        telugu_texts.append(' '.join(df['text'].astype(str).tolist()))
    elif 'body' in df.columns:
        telugu_texts.append(' '.join(df['body'].astype(str).tolist()))

# Concatenate all texts and remove all English, numerical values, and quotes
telugu_text = ' '.join(telugu_texts)
telugu_text = re.sub(r'[A-Za-z0-9\'"]', '', telugu_text)  # Remove English letters, numbers, and quotes
telugu_text = re.sub(r'[\r\n\xa0]', '', telugu_text)  # Remove line breaks and non-breaking spaces

print('telugu_text befores utf-8 encoding:', telugu_text[:100])


telugu_text befores utf-8 encoding:  సుశీలమ్మ కళ్ళలో భయం పారాడింది. అనాధ బిడ్డ అని చిన్నప్పుడే తెలిస్తే మన దగ్గిరవాడు అలా అరమరిక లేకుండా


In [2]:
vocabulary_size = len(set(telugu_text.split()))
print('Original text size:', len(telugu_text))
print('Vocabulary size of telugu_text:', vocabulary_size)

unique_characters = set(telugu_text)
unique_count = len(unique_characters)
print('Original text size:', len(telugu_text))
print('Unique character count in telugu_text:', unique_count)



Original text size: 143512307
Vocabulary size of telugu_text: 1985864
Original text size: 143512307
Unique character count in telugu_text: 194


In [3]:
import utils.encode_parallel_telugu as encode_parallel
import time

tokens = encode_parallel.load_telugu_texts()
# Start the timer
start_time = time.time()
# Encode the tokens in parallel and get concatenated results
encoded_tokens = encode_parallel.encode_tokens_parallel(tokens, chunk_size=1_000_000, max_workers=10)
print('encoded_tokens:', encoded_tokens[:100])
print(len(encoded_tokens))
# End the timer
end_time = time.time()
print(f"Time taken to encode and process tokens in parallel: {end_time - start_time:.4f} seconds")

Processing Chunks: 100%|██████████| 144/144 [00:08<00:00, 17.41it/s]


encoded_tokens: [b' ', b'\xe0\xb0\xb8', b'\xe0\xb1\x81', b'\xe0\xb0\xb6', b'\xe0\xb1\x80', b'\xe0\xb0\xb2', b'\xe0\xb0\xae', b'\xe0\xb1\x8d', b'\xe0\xb0\xae', b' ', b'\xe0\xb0\x95', b'\xe0\xb0\xb3', b'\xe0\xb1\x8d', b'\xe0\xb0\xb3', b'\xe0\xb0\xb2', b'\xe0\xb1\x8b', b' ', b'\xe0\xb0\xad', b'\xe0\xb0\xaf', b'\xe0\xb0\x82', b' ', b'\xe0\xb0\xaa', b'\xe0\xb0\xbe', b'\xe0\xb0\xb0', b'\xe0\xb0\xbe', b'\xe0\xb0\xa1', b'\xe0\xb0\xbf', b'\xe0\xb0\x82', b'\xe0\xb0\xa6', b'\xe0\xb0\xbf', b'.', b' ', b'\xe0\xb0\x85', b'\xe0\xb0\xa8', b'\xe0\xb0\xbe', b'\xe0\xb0\xa7', b' ', b'\xe0\xb0\xac', b'\xe0\xb0\xbf', b'\xe0\xb0\xa1', b'\xe0\xb1\x8d', b'\xe0\xb0\xa1', b' ', b'\xe0\xb0\x85', b'\xe0\xb0\xa8', b'\xe0\xb0\xbf', b' ', b'\xe0\xb0\x9a', b'\xe0\xb0\xbf', b'\xe0\xb0\xa8', b'\xe0\xb1\x8d', b'\xe0\xb0\xa8', b'\xe0\xb0\xaa', b'\xe0\xb1\x8d', b'\xe0\xb0\xaa', b'\xe0\xb1\x81', b'\xe0\xb0\xa1', b'\xe0\xb1\x87', b' ', b'\xe0\xb0\xa4', b'\xe0\xb1\x86', b'\xe0\xb0\xb2', b'\xe0\xb0\xbf', b'\xe0\xb0\xb8', b'\xe

In [None]:
print('length of encoded_text:', len(encoded_tokens))
print('unique characters in encoded_text:', set(encoded_tokens))
print('unique characters in encoded_text:', len(set(encoded_tokens)))

In [None]:
merges

In [None]:
encoded_tokens[:100]

In [5]:
# ## lets read with bigger text
# with open('sample_telugu.txt', 'r') as f:
#     text = f.read()
# tokens = telugu_text.encode("utf-8")
#### **BPE implementation**

tokens = encoded_tokens

def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, idx):
    new_ids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
            new_ids.append(idx)
            i += 2
        else:
            new_ids.append(ids[i])
            i += 1
    return new_ids

# ---
vocab_size = 500 # the desired final vocabulary size
num_merges = vocab_size - 256 ## our unique tokens are 194, for our sample text.
ids = list(tokens) # copy so we don't destroy the original list

merges = {} # (int, int) -> int
from tqdm import tqdm  # Import tqdm for progress bar

for i in tqdm(range(num_merges), desc="Merging tokens"):
    stats = get_stats(ids)
    pair = max(stats, key=stats.get)
    idx = 256 + i
    # print(f"merging {pair} into a new token {idx}")
    ids = merge(ids, pair, idx)
    merges[pair] = idx # merge has a pair of tokens and the new token index
    
print("tokens length:", len(tokens))
print("ids length:", len(ids))
print(f"compression ratio: {len(tokens) / len(ids):.2f}X")
print(f"token size: {len(set(tokens))}")
    
# print(ids)

Merging tokens: 100%|██████████| 244/244 [8:39:11<00:00, 127.67s/it]    

tokens length: 143512307
ids length: 77428527
compression ratio: 1.85X
token size: 194





In [9]:
len(set(ids))

438

In [None]:
[print(i,end=' ') for i in set(ids)]

In [7]:
len(encoded_tokens)

143512307

In [None]:
# ## lets read with bigger text
# with open('sample_telugu.txt', 'r') as f:
#     text = f.read()
# tokens = telugu_text.encode("utf-8")
#### **BPE implementation**

tokens = encoded_tokens

def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, idx):
    new_ids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
            new_ids.append(idx)
            i += 2
        else:
            new_ids.append(ids[i])
            i += 1
    return new_ids

# ---
vocab_size = 500 # the desired final vocabulary size
num_merges = vocab_size - 256 ## our unique tokens are 194, for our sample text.
ids = list(tokens) # copy so we don't destroy the original list

merges = {} # (int, int) -> int
from tqdm import tqdm  # Import tqdm for progress bar

for i in tqdm(range(num_merges), desc="Merging tokens"):
    stats = get_stats(ids)
    pair = max(stats, key=stats.get)
    idx = 256 + i
    # print(f"merging {pair} into a new token {idx}")
    ids = merge(ids, pair, idx)
    merges[pair] = idx # merge has a pair of tokens and the new token index
    
print("tokens length:", len(tokens))
print("ids length:", len(ids))
print(f"compression ratio: {len(tokens) / len(ids):.2f}X")
print(f"token size: {len(set(tokens))}")
    
# print(ids)

In [6]:
millions = 143512307 / 1_000_000
print(f"The number in millions is: {millions}")


The number in millions is: 143.512307


In [None]:
import json
with open('merges.json', 'w') as f:
        json.dump(merges, f)  # Save merges separately

In [75]:
telugu_unicode_chars = [chr(i) for i in range(0x0C00, 0x0C7F)]  # Telugu Unicode range

# Add these characters to the vocabulary
import json

vocab = {token: idx for token, idx in merges.items()}



# Add unique Telugu characters to the vocabulary
for idx, char in enumerate([chr(i).encode('utf-8') for i in range(0x0C00, 0x0C7F)]):
    if idx < 256:  # Ensure we only add up to 256 characters
        vocab[char] = idx  # Map the character to its index

vocab[b' '] = 255
vocab[b'.'] = 254
# Save merges and vocab to a file
# with open('merges_vocab.json', 'w') as f:
#     json.dump({'merges': merges, 'vocab': vocab}, f)

In [None]:
vocab

In [43]:
with open('merges_vocab.json', 'w') as f:
    json.dump({'merges': {str(k): v for k, v in merges.items()}, 'vocab': {str(k): v for k, v in vocab.items()}}, f)

In [None]:
import json
from collections import defaultdict

# Read the merges and vocab data from the JSON file
with open('merges_vocab.json', 'r') as f:
    data = json.load(f)

# Create a defaultdict to store the data in a distributed manner
distributed_data = defaultdict(list)

# Distribute the merges and vocab data
# for key, value in data['merges'].items():
#     distributed_data['merges'].append({key: value})

for key, value in data['vocab'].items():
    distributed_data['vocab'].append({key: value})

# Optionally, print the distributed data for verification
print(distributed_data)
distributed_data['vocab']
# Convert the list of dictionaries to a single dictionary
formatted_vocab = {}
for item in distributed_data['vocab']:
    for k, v in item.items():
        if ',' not in k:
            formatted_vocab[(eval(k),)] = v
        else:
            formatted_vocab[eval(k)] = v
print(formatted_vocab)

In [None]:
formatted_vocab

In [None]:
formatted_vocab

In [132]:
inverted_vocab = {v: k for k, v in formatted_vocab.items()}
inverted_vocab

{256: (b'\xe0\xb0\xbf', b' '),
 257: (b'\xe0\xb1\x81', b' '),
 258: (b'.', b' '),
 259: (b'\xe0\xb0\xa8', b'\xe0\xb1\x8d'),
 260: (b'\xe0\xb0\xbe', b' '),
 261: (b'\xe0\xb0\x82', b'\xe0\xb0\xa6'),
 262: (259, b'\xe0\xb0\xa8'),
 263: (b'\xe0\xb0\xb0', b'\xe0\xb1\x8d'),
 264: (b'\xe0\xb1\x8d', b'\xe0\xb0\xb0'),
 265: (b'\xe0\xb0\xb8', b'\xe0\xb1\x8d'),
 266: (b'\xe0\xb1\x8b', b' '),
 267: (b'\xe0\xb1\x81', 258),
 268: (b'\xe0\xb0\x82', b' '),
 269: (b'\xe0\xb0\xbe', b'\xe0\xb0\xb0'),
 270: (b'\xe0\xb0\xa8', b'\xe0\xb0\xbf'),
 271: (b'\xe0\xb1\x87', b' '),
 272: (b'\xe0\xb1\x81', b'\xe0\xb0\x95'),
 273: (b'\xe0\xb0\xbe', b'\xe0\xb0\xb2'),
 274: (b'\xe0\xb0\xa8', 256),
 275: (b'\xe0\xb0\x82', b'\xe0\xb0\x9a'),
 276: (b'\xe0\xb0\x9f', b'\xe0\xb1\x8d'),
 277: (b'\xe0\xb0\xbf', 258),
 278: (b'\xe0\xb0\x95', b'\xe0\xb1\x8d'),
 279: (b'\xe0\xb0\xbe', b'\xe0\xb0\xa1'),
 280: (b' ', b'\xe0\xb0\x85'),
 281: (b'\xe0\xb1\x81', b'\xe0\xb0\xb2'),
 282: (b'.', b'.'),
 283: (b'\xe0\xb0\x82', b'\xe0\xb0\

In [133]:
def convert_to_bytes(value):
    if isinstance(value, bytes):
        return value
    elif value in inverted_vocab:
        return process_tuple(inverted_vocab[value])
    else:
        print(f'value not found in inverted_vocab: {value}')
        return None

def process_tuple(value_tuple):
    # print(f'value_tuple: {value_tuple}')
    # for vi in value_tuple:
    #     print(f'v: {vi}')
    converted_values = []
    for v in value_tuple:
        result = convert_to_bytes(v)
        if isinstance(result, tuple):
            converted_values.extend(result)
        else:
            converted_values.append(result)
    return tuple(converted_values)

decoder_map = {k: process_tuple(v) for k, v in inverted_vocab.items()}





In [136]:
text = "తెలుగు భాష ఒక ద్రావిడ భాష."
li = encode_parallel.encode_tokens_parallel(text, chunk_size=1_000_000, max_workers=10)
_ = [print(i.decode('utf-8'), end=' ') for i in li]
print('\n')
_ = [print(vocab[i], end=' ') for i in li if i in vocab]
print('\n')
_ = [print(i, end=' ') for i in li if i not in vocab]
encoed_li  = [vocab[i] for i in li if i in vocab]
print('\n')
decoded_text = ''.join([k.decode('utf-8') for i in encoed_li for k, v in vocab.items() if v == i])
print(decoded_text)


Processing Chunks: 100%|██████████| 1/1 [00:00<00:00,  3.12it/s]

త ె ల ు గ ు   భ ా ష   ఒ క   ద ్ ర ా వ ి డ   -   భ ా ష . 

36 70 50 65 23 65 255 45 62 55 255 18 21 255 38 77 48 62 53 63 33 255 255 45 62 55 254 

b'-' 

తెలుగు భాష ఒక ద్రావిడ  భాష.





In [46]:
# text = "తెలుగు భాష ఒక ద్రావిడ భాష."
# li = encode_parallel.encode_tokens_parallel(text, chunk_size=1_000_000, max_workers=10)
# # [print(i.decode('utf-8')) for i in li]

In [None]:
def encode_text(text, mapping):
    encoded = []
    # Convert the text to bytes
    byte_text = encode_parallel.encode_tokens_parallel(text, chunk_size=1_000_000, max_workers=10)
    print(f'byte_text: {byte_text}')
    print(f'mapping: {mapping}')
    # Iterate through the byte pairs in the text
    i = 0
    while i < len(byte_text):
        # Check for pairs
        if i < len(byte_text) - 1:
            pair = (byte_text[i:i+2], byte_text[i+1:i+2])
            # pair = tuple(pair)
            print(f'pair: {pair}')
            if pair in mapping:
                print(f'pair: {pair}')
                encoded.append(mapping[pair])
                i += 2  # Move past the pair
                continue
        
        # Check for single byte matches
        single_byte = byte_text[i:i+1]
        if single_byte in mapping:
            encoded.append(mapping[single_byte])
        else:
            # If no match, append the original byte as an index (or handle as needed)
            encoded.append(single_byte)
        
        i += 1  # Move to the next byte

    return encoded

# Example usage
text_to_encode = "తెలుగు భాష ఒక ద్రావిడ భాష."  # Replace with your actual text
encoded_output = encode_text(text_to_encode, vocab)
print(encoded_output)

In [None]:
text = "తెలుగు భాష ఒక ద్రావిడ భాష."
print(text)

byte_text = encode_parallel.encode_tokens_parallel(text, chunk_size=1_000_000, max_workers=10)
i = 0
while i < len(byte_text):
    if i < len(byte_text) - 1:
        te = []
        for i,val in enumerate(byte_text[i:i+2]):
            te.append(val)
        for i,val in enumerate(byte_text[i+1:i+2]):
            te.append(val)
        pair = tuple(te)
        print(f'pair: {pair}')
        # if pair in vocab:
        #     encoded.append(vocab[pair])
        i += 2  # Move past the pair
# (byte_text[i:i+2], byte_text[i+1:i+2])

In [56]:
i

b'\xe0\xb1\x86'

In [55]:
byte_text[i+1:i+2]

TypeError: can't concat int to bytes

In [None]:
unpacked_list = [item for sublist in li for item in sublist]
print(unpacked_list)


In [58]:
tokens = 'abc d'.encode('utf-8')
tokens = list(map(int, tokens))
print(tokens)

[97, 98, 99, 32, 100]


In [50]:
print(b'\xe0\xb0\x80'.decode('utf-8'))
print(b'\xe0\xb0\xa4\xe0\xb1\x86\xe0\xb0\xb2\xe0\xb1\x81\xe0\xb0\x97\xe0\xb1\x81'.decode('utf-8'))

ఀ
తెలుగు


In [56]:
li = [b'\xe0', b'\xb0', b'\xa4', b'\xe0', b'\xb1', b'\x86', b'\xe0', b'\xb0', b'\xb2', b'\xe0', b'\xb1', b'\x81', b'\xe0', b'\xb0', b'\x97', b'\xe0', b'\xb1', b'\x81', b' ', b'\xe0', b'\xb0', b'\xad', b'\xe0', b'\xb0', b'\xbe', b'\xe0', b'\xb0', b'\xb7', b' ', b'\xe0', b'\xb0', b'\x92', b'\xe0', b'\xb0', b'\x95', b' ', b'\xe0', b'\xb0', b'\xa6', b'\xe0', b'\xb1', b'\x8d', b'\xe0', b'\xb0', b'\xb0', b'\xe0', b'\xb0', b'\xbe', b'\xe0', b'\xb0', b'\xb5', b'\xe0', b'\xb0', b'\xbf', b'\xe0', b'\xb0', b'\xa1', b' ', b'\xe0', b'\xb0', b'\xad', b'\xe0', b'\xb0', b'\xbe', b'\xe0', b'\xb0', b'\xb7', b'.']
# for i in li:
print(li.decode('utf-8'))


AttributeError: 'list' object has no attribute 'decode'

In [2]:
decoded_text = b''.join([b'\xe0', b'\xb0', b'\xa4', b'\xe0', b'\xb1', b'\x86', b'\xe0', b'\xb0', b'\xb2', b'\xe0', b'\xb1', b'\x81', b'\xe0', b'\xb0', b'\x97', b'\xe0', b'\xb1', b'\x81', b' ', b'\xe0', b'\xb0', b'\xad', b'\xe0', b'\xb0', b'\xbe', b'\xe0', b'\xb0', b'\xb7', b' ', b'\xe0', b'\xb0', b'\x92', b'\xe0', b'\xb0', b'\x95', b' ', b'\xe0', b'\xb0', b'\xa6', b'\xe0', b'\xb1', b'\x8d', b'\xe0', b'\xb0', b'\xb0', b'\xe0', b'\xb0', b'\xbe', b'\xe0', b'\xb0', b'\xb5', b'\xe0', b'\xb0', b'\xbf', b'\xe0', b'\xb0', b'\xa1', b' ', b'\xe0', b'\xb0', b'\xad', b'\xe0', b'\xb0', b'\xbe', b'\xe0', b'\xb0', b'\xb7', b'.'])
decoded_text = decoded_text.decode('utf-8')
print(decoded_text)


తెలుగు భాష ఒక ద్రావిడ భాష.


In [4]:
b''.join([b'\xe0', b'\xb0', b'\xa4', b'\xe0', b'\xb1', b'\x86', b'\xe0', b'\xb0', b'\xb2', b'\xe0', b'\xb1', b'\x81', b'\xe0', b'\xb0', b'\x97', b'\xe0', b'\xb1', b'\x81', b' ', b'\xe0', b'\xb0', b'\xad', b'\xe0', b'\xb0', b'\xbe', b'\xe0', b'\xb0', b'\xb7', b' ', b'\xe0', b'\xb0', b'\x92', b'\xe0', b'\xb0', b'\x95', b' ', b'\xe0', b'\xb0', b'\xa6', b'\xe0', b'\xb1', b'\x8d', b'\xe0', b'\xb0', b'\xb0', b'\xe0', b'\xb0', b'\xbe', b'\xe0', b'\xb0', b'\xb5', b'\xe0', b'\xb0', b'\xbf', b'\xe0', b'\xb0', b'\xa1', b' ', b'\xe0', b'\xb0', b'\xad', b'\xe0', b'\xb0', b'\xbe', b'\xe0', b'\xb0', b'\xb7', b'.'])


b'\xe0\xb0\xa4\xe0\xb1\x86\xe0\xb0\xb2\xe0\xb1\x81\xe0\xb0\x97\xe0\xb1\x81 \xe0\xb0\xad\xe0\xb0\xbe\xe0\xb0\xb7 \xe0\xb0\x92\xe0\xb0\x95 \xe0\xb0\xa6\xe0\xb1\x8d\xe0\xb0\xb0\xe0\xb0\xbe\xe0\xb0\xb5\xe0\xb0\xbf\xe0\xb0\xa1 \xe0\xb0\xad\xe0\xb0\xbe\xe0\xb0\xb7.'

In [36]:
def encode_text(text, mapping):
    encoded = []
    # Convert the text to bytes
    byte_text = text.encode('utf-8')
    print(f'byte_text: {byte_text}')
    # Iterate through the byte pairs in the text
    i = 0
    while i < len(byte_text):
        # Check for pairs
        if i < len(byte_text) - 1:
            pair = (byte_text[i:i+1], byte_text[i+1:i+1])  # Adjusted to match the mapping structure
            if pair in mapping:
                encoded.append(mapping[pair])
                i += 2  # Move past the pair
                continue
        
        # Check for single byte matches
        single_byte = byte_text[i:i+1]
        if single_byte in mapping:
            encoded.append(mapping[single_byte])
        else:
            # If no match, append a placeholder or handle as needed
            encoded.append(-1)  # Placeholder for unmatched bytes
        
        i += 1  # Move to the next byte

    return encoded

# Example usage
text_to_encode = "తెలుగు భాష ఒక ద్రావిడ భాష."  # Replace with your actual text
encoded_output = encode_text(text_to_encode, vocab)
print(encoded_output)

byte_text: b'\xe0\xb0\xa4\xe0\xb1\x86\xe0\xb0\xb2\xe0\xb1\x81\xe0\xb0\x97\xe0\xb1\x81 \xe0\xb0\xad\xe0\xb0\xbe\xe0\xb0\xb7 \xe0\xb0\x92\xe0\xb0\x95 \xe0\xb0\xa6\xe0\xb1\x8d\xe0\xb0\xb0\xe0\xb0\xbe\xe0\xb0\xb5\xe0\xb0\xbf\xe0\xb0\xa1 \xe0\xb0\xad\xe0\xb0\xbe\xe0\xb0\xb7.'
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]


In [None]:
## now lets try with a larger text.
with open('sample_1000_words.txt', 'r') as f:
    text = f.read()

tokens = text.encode("utf-8") # raw bytes
tokens = list(map(int, tokens)) # convert to a list of integers in range 0..255 just for convenience
print('---')
print(text[:200], '\n')
print("length:", len(text))
print('---')
print(tokens[:100]) ## print first 100 tokens
print("length:", len(tokens))
print('--------------------------------')

def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, idx):
    new_ids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
            new_ids.append(idx)
            i += 2
        else:
            new_ids.append(ids[i])
            i += 1
    return new_ids

# ---
vocab_size = 1000 # the desired final vocabulary size
num_merges = vocab_size - 256 ## our unique tokens are 226, for our sample text.
ids = list(tokens) # copy so we don't destroy the original list

merges = {} # (int, int) -> int
for i in range(num_merges):
    stats = get_stats(ids)
    if not stats:  # Check if stats is empty to avoid errors
        break
    pair = max(stats, key=stats.get)
    idx = 256 + i
    ids = merge(ids, pair, idx)
    merges[pair] = idx

print("tokens length:", len(tokens))
print("ids length:", len(ids))
print(f"compression ratio: {len(tokens) / len(ids):.2f}X")