### Jupyter Notebook Settings

In [None]:
# Inject CSS forcing `.container` elements to 100% width (widens the notebook cell area).
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
# NOTE(review): this rebinds the name `display` from the function imported above to the
# IPython.display module, so any later bare `display(...)` call would fail — confirm intended.
import IPython.display as display

### Libraries

In [None]:
import re
import heapq
from tqdm import tqdm
from datasets import load_dataset
from collections import Counter, defaultdict
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed

### Data

In [None]:
# Load the first 30% of the `bookcorpus/bookcorpus` training split
# (the split string uses the `datasets` percentage-slicing syntax).
dataset = load_dataset(path="bookcorpus/bookcorpus", split="train[:30%]")

### Create BPE Vocabulary from given corpus/dataset

In [None]:
# Global word-frequency table shared by all worker threads, plus the lock that
# serializes updates to it. `Lock` is not imported by name in this notebook,
# so it is referenced through the imported `threading` module.
word_counter = Counter()
counter_lock = threading.Lock()

def process_document_bpe(document):
    """Count the word and punctuation tokens of one document into ``word_counter``.

    The document is lower-cased and split into runs of word characters and
    single non-space characters. The per-document tally is built without
    holding the lock. It is then merged into the global ``word_counter``
    while ``counter_lock`` is held.
    """
    doc_counts = Counter(re.findall(r'\w+|\S', document.lower()))
    with counter_lock:
        word_counter.update(doc_counts)

# Tokenize every document on a pool of worker threads. Each future's result()
# is read so that an exception raised inside a worker is re-raised here
# instead of being silently discarded (which would leave word_counter partial).
with ThreadPoolExecutor(max_workers=8) as executor:  # max_workers sets the thread-pool size
    futures = [executor.submit(process_document_bpe, document) for document in dataset['text']]
    for future in tqdm(as_completed(futures), total=len(futures), desc="Processing dataset"):
        future.result()

# Convert words to space-separated characters followed by the end-of-word token '</w>'.
def get_initial_vocab(word_counter):
    """Build the initial character-level BPE vocabulary.

    Each word is split into its characters, separated by single spaces, and
    followed by a ``</w>`` end-of-word symbol. For example, ``"cat"`` becomes
    ``"c a t </w>"``.

    Args:
        word_counter: Mapping of word -> frequency.

    Returns:
        dict mapping each spaced symbol string to the frequency of its word.
    """
    # The frequency must be looked up by the ORIGINAL word. Looking it up by
    # the spaced form would hit a missing Counter key and silently yield 0.
    return {' '.join(word) + ' </w>': freq for word, freq in word_counter.items()}

vocab = get_initial_vocab(word_counter=word_counter)  # character-level starting vocabulary

# BPE algorithm: merging the most frequent pairs
def get_pair_frequencies(vocab):
    """Return the total frequency of every adjacent symbol pair in ``vocab``.

    ``vocab`` maps space-separated symbol strings to counts. Each adjacent
    ``(left, right)`` pair inside an entry contributes that entry's count.
    The result is a ``defaultdict(int)`` keyed by ``(left, right)`` tuples.
    """
    pair_counts = defaultdict(int)
    for spelled, count in vocab.items():
        pieces = spelled.split()
        for left, right in zip(pieces, pieces[1:]):
            pair_counts[left, right] += count
    return pair_counts

def merge_vocab(pair, vocab):
    """Fuse every standalone occurrence of ``pair`` in ``vocab`` into one symbol.

    Args:
        pair: ``(left, right)`` symbol tuple to merge.
        vocab: Mapping of space-separated symbol strings to frequencies.

    Returns:
        New dict in which ``"left right"`` is replaced by ``"leftright"``
        wherever both appear as whole symbols. Frequencies are carried over
        (and summed in the unlikely case two entries become identical).
    """
    bigram = ' '.join(pair)
    replacement = ''.join(pair)
    # Anchor the bigram on whitespace/string boundaries so it only matches whole
    # symbols. A plain str.replace also matches across symbol boundaries, e.g.
    # merging ('a', 'b') would wrongly turn "xa b" into "xab".
    pattern = re.compile(r'(?<!\S)' + re.escape(bigram) + r'(?!\S)')
    merged_vocab = {}
    for word, freq in vocab.items():
        # A function replacement keeps backslashes inside symbols literal
        # (a string replacement would treat them as escape sequences).
        new_word = pattern.sub(lambda _match: replacement, word)
        merged_vocab[new_word] = merged_vocab.get(new_word, 0) + freq
    return merged_vocab

# Greedy BPE: repeatedly fuse the most frequent adjacent pair for at most
# `num_merges` rounds, stopping early once no pairs remain.
num_merges = 100000  # upper bound on merges to perform (controls vocabulary size)
merge_progress = tqdm(range(num_merges), desc="Performing BPE merges")
for i in merge_progress:
    pairs = get_pair_frequencies(vocab)
    if pairs:
        best_pair = max(pairs.keys(), key=lambda candidate: pairs[candidate])
        vocab = merge_vocab(best_pair, vocab)
    else:
        break  # every entry is already a single symbol

# Extract the final vocabulary: the distinct BPE symbols (subword tokens such as
# "th" or "e</w>") left after merging. Joining each entry's symbols back together
# would only rebuild the whole word + '</w>', making the merges irrelevant.
bpe_vocab = sorted({symbol for word in vocab for symbol in word.split()})

# Persist the BPE vocabulary as plain text, one token per line.
with open("bpe-vocab_bookcorpus-30p_100000-merges.txt", "w") as vocab_file:
    vocab_file.writelines(f"{token}\n" for token in bpe_vocab)

### Clean BPE Vocabulary
##### This cell applies a set of rule-based filters to the generated tokens to reduce the final vocabulary size (only about a 5% reduction).
##### Note that it removes all existing symbol and digit tokens (if any) and then re-adds them at the end of the file, for consistency.

In [None]:
# Lock guarding appends to the shared `processed_vocab` list made by the
# worker threads that run process_token.
lock = threading.Lock()

# Short words that are always kept, regardless of the other filters.
_TOKEN_WHITELIST = frozenset({
    "a", "i", "the", "an", "he", "him", "it", "in", "on", "at", "of", "to",
    "by", "is", "as", "up", "we", "us", "me",
})
# Patterns compiled once instead of on every call.
_EDGE_SYMBOLS_RE = re.compile(r'^\W+|\W+$')
_ALL_SYMBOLS_RE = re.compile(r'^[\W_]+$')
_TRIPLE_REPEAT_RE = re.compile(r'(.)\1\1')
_DIGIT_RE = re.compile(r'\d')
_ASCII_LETTER_RE = re.compile(r'[a-zA-Z]')


def is_valid_token(token, token_freq):
    """Decide whether a BPE token survives vocabulary cleaning.

    A trailing ``</w>`` end-of-word marker is ignored. Leading and trailing
    non-word characters are then stripped; ``_`` is a word character, so it
    is kept. Whitelisted short words are always kept. Otherwise the token is
    rejected if, after cleaning, it:

    * consists only of symbols/underscores and is longer than 2 characters;
    * contains the same character three times in a row;
    * mixes digits and ASCII letters;
    * is longer than 20 characters;
    * contains any single character more than 3 times;
    * is shorter than 3 characters and has a frequency below 3 in
      ``token_freq`` (looked up by the cleaned token).

    Args:
        token: Raw vocabulary token, possibly ending in ``</w>``.
        token_freq: Mapping of token -> frequency, used by the short-token check.

    Returns:
        True if the token should be kept, False otherwise.
    """
    cleaned_token = token[:-4] if token.endswith('</w>') else token
    cleaned_token = _EDGE_SYMBOLS_RE.sub('', cleaned_token)

    if cleaned_token in _TOKEN_WHITELIST:
        return True

    # Symbol-only tokens (underscore included) longer than 2 characters.
    if len(cleaned_token) > 2 and _ALL_SYMBOLS_RE.match(cleaned_token):
        return False

    # Three identical characters in a row, e.g. "aaa".
    if _TRIPLE_REPEAT_RE.search(cleaned_token):
        return False

    # Any mix of digits and letters, e.g. "3rd" or "a1". This also covers the
    # digits-followed-by-letters shape, so no separate check is needed for it.
    if _DIGIT_RE.search(cleaned_token) and _ASCII_LETTER_RE.search(cleaned_token):
        return False

    # Tokens that are too long (more than 20 characters).
    if len(cleaned_token) > 20:
        return False

    # Any single character occurring more than 3 times.
    if any(cleaned_token.count(c) > 3 for c in set(cleaned_token)):
        return False

    # Short (0-2 character) non-whitelisted tokens must be frequent enough.
    if len(cleaned_token) < 3 and token_freq.get(cleaned_token, 0) < 3:  # Adjust frequency threshold here
        return False

    return True

def process_token(token, token_freq, processed_vocab):
    """Append ``token`` to ``processed_vocab`` if it passes ``is_valid_token``.

    The append runs while the module-level ``lock`` is held, so several
    worker threads can safely share one output list.
    """
    if not is_valid_token(token, token_freq):
        return
    with lock:
        processed_vocab.append(token)

# Read the BPE vocabulary, one token per line, with surrounding whitespace stripped.
# NOTE(review): this is a machine-specific absolute path to a 25000-merges file,
# not the 100000-merges file written earlier in this notebook — confirm intended.
with open("/Users/ciprian/Desktop/Projects/Microcontroller Tokenizer/microcontroller-tokenizer/data/bpe-vocab_bookcorpus-30p_25000-merges.txt", "r") as vocab_file:
    bpe_vocab = [raw_line.strip() for raw_line in vocab_file]

# Placeholder frequencies: each token's count is the number of lines it occupies
# in the file (typically 1). Real corpus counts could be used here instead.
token_freq = Counter(bpe_vocab)
processed_vocab = []  # filled concurrently by process_token

# Validate every token on a thread pool; process_token appends the survivors
# to processed_vocab under the shared lock.
with ThreadPoolExecutor() as executor:
    token_results = executor.map(
        lambda candidate: process_token(candidate, token_freq, processed_vocab),
        bpe_vocab,
    )
    # Consuming the result iterator (wrapped in tqdm for a progress bar) waits
    # for every task and re-raises any exception raised in a worker.
    list(tqdm(token_results, total=len(bpe_vocab)))

# Remove duplicates
cleaned_vocab = sorted(set(processed_vocab))

# Ensure the digits 0-9, the symbols and the space token are included in the
# vocabulary. Each one is appended only if it is missing: re-running the cleaner
# on an already-cleaned file keeps "<space>" (it passes is_valid_token), and
# appending it again would reintroduce a duplicate.
space_token = ["<space>"]
digits = ["0</w>", "1</w>", "2</w>", "3</w>", "4</w>", "5</w>", "6</w>", "7</w>", "8</w>", "9</w>"]
symbols = [
    ".", ",", ";", ":", "'", "\"", "!", "?", "(", ")", "[", "]", "{", "}",
    "<", ">", "/", "\\", "|", "-", "_", "+", "=", "*", "&", "^", "%", "$",
    "#", "@", "~", "`"
]

already_present = set(cleaned_vocab)
for special_token in space_token + symbols + digits:
    if special_token not in already_present:
        cleaned_vocab.append(special_token)
        already_present.add(special_token)

# Write the cleaned vocabulary (one token per line), then report the size change.
with open("/Users/ciprian/Desktop/Projects/Microcontroller Tokenizer/microcontroller-tokenizer/data/bpe-vocab_bookcorpus-30p_25000-merges_cleaned_v2.txt", "w") as cleaned_file:
    cleaned_file.writelines(f"{token}\n" for token in cleaned_vocab)

print(f"Original vocabulary size: {len(bpe_vocab)}")
print(f"Cleaned vocabulary size: {len(cleaned_vocab)}")

### Converting a binary string back to a phrase (deserialization, i.e. reversing the encoding process)

In [None]:
# Function to convert a binary string to bytearray
def binary_string_to_bytearray(binary_string):
    """Convert a string of bits into a ``bytearray``.

    Two input formats are accepted, and they may be mixed:

    * whitespace-separated groups of up to 8 bits, e.g. ``"01001000 1101001"``;
      each group is left-padded with zeros to 8 bits and becomes one byte;
    * contiguous bit strings whose length is a multiple of 8, e.g.
      ``"0100100001101001"``; these are split into consecutive 8-bit bytes.

    Args:
        binary_string: Text made of ``0``/``1`` digits and whitespace.

    Returns:
        bytearray with one element per decoded byte (empty for blank input).

    Raises:
        ValueError: if a group contains non-binary characters, or is longer
            than 8 bits without being a whole number of bytes.
    """
    byte_values = []
    for segment in binary_string.split():
        if len(segment) <= 8:
            chunks = [segment.zfill(8)]
        elif len(segment) % 8 == 0:
            chunks = [segment[i:i + 8] for i in range(0, len(segment), 8)]
        else:
            raise ValueError(
                f"bit group {segment!r} is longer than 8 bits and not a multiple of 8"
            )
        for chunk in chunks:
            if set(chunk) - {'0', '1'}:
                raise ValueError(f"bit group {segment!r} contains non-binary characters")
            byte_values.append(int(chunk, 2))
    return bytearray(byte_values)

# Load the vocabulary
# NOTE(review): load_bpe_vocabulary and decode_bpe are not defined in this part of
# the notebook; presumably they come from elsewhere — confirm. This file name also
# differs from the cleaned vocabulary file written above, and vocab_dict is unused here.
vocab_file_path = "bpe_vocab_cleaned.txt"
vocab_dict, vocabulary = load_bpe_vocabulary(vocab_file_path)

# Prompt the user for the encoded bit string (whitespace-separated bit groups;
# presumably the format produced by the encoder — confirm).
input_binary_string = input("Enter the binary string to decode: ")

# Convert the input binary string to a bytearray
byte_array = binary_string_to_bytearray(input_binary_string)

# Decode the bytearray back to the original phrase using the loaded vocabulary list
decoded_phrase = decode_bpe(byte_array, vocabulary)

# Output the decoded phrase
print("Decoded phrase:", decoded_phrase)