In [None]:
from cs336_basics.pretokenization_example import find_chunk_boundaries
import regex as re

sample_tiny_path = "/Users/prateekmahadevappahavanur/Documents/GitHub/test_task/assignment1-basics/tests/fixtures/tinystories_sample.txt"


def bpe_train(file_path,vocab_size,special_tokens):
    
    PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

    special_token_pattern = "|".join(re.escape(token) for token in special_tokens)
    word_freqs = {}

    with open(file_path, "rb") as f:
        boundaries = find_chunk_boundaries(
            f, 2, "<|endoftext|>".encode("utf-8"))
            
        # The following is a serial implementation, but you can parallelize this 
        # by sending each start/end pair to a set of processes.
        word_freqs = {}
        for start, end in zip(boundaries[:-1], boundaries[1:]):
            f.seek(start)
            chunk = f.read(end - start).decode("utf-8", errors="ignore")
            text = re.split(special_token_pattern, chunk)
            for segment in text:
                if not segment:
                    continue
                for match in re.finditer(PAT, segment):
                    token = match.group()
                    token_bytes = tuple(token.encode("utf-8"))
                    # print(token,"---",token_bytes)
                    word_freqs[token_bytes] = word_freqs.get(token_bytes, 0) + 1

    def get_count(mapped_list):
        count = {}
        for test_list,value in mapped_list.items():
            test_list = list(test_list)
            for a1,a2 in zip(test_list,test_list[1:]):
                count[(a1,a2)] = count.get((a1,a2),0) +  value
        return count


    def merge_pair(test_list,pair,replacement):
        i = 0
        new_list = []
        while i < len(test_list):
            if i < len(test_list) - 1 and test_list[i] == pair[0] and test_list[i+1] == pair[1]:
                new_list.append(replacement)
                i = i+2
            else:
                new_list.append(test_list[i])
                i += 1
        return new_list

    def get_token_bytes(token_id, vocab_list):
        """Get the bytes for a token ID"""
        if token_id < 256:
            return bytes([token_id])  # Single byte
        else:
            return vocab_list[token_id]  # Already merged token
        
    def make_comparable_pair(pair, vocab_list):
        """Convert pair elements to comparable format"""
        a, b = pair
        # Convert to bytes for consistent comparison
        if isinstance(a, int):
            a_bytes = bytes([a]) if a < 256 else f"merged_{a}".encode()
        else:
            a_bytes = a
        if isinstance(b, int):
            b_bytes = bytes([b]) if b < 256 else f"merged_{b}".encode()
        else:
            b_bytes = b
        return (a_bytes, b_bytes)

    # Initialize vocabulary
    vocab_list = []
    # Special tokens first
    for special_token in special_tokens:
        vocab_list.append(special_token.encode("utf-8"))
    # All 256 bytes
    for i in range(256):
        vocab_list.append(bytes([i]))

    num_merges = vocab_size - len(special_tokens) - 256

    merges = []
    for i in range(num_merges):
        # print(f"----iteration {i}-----")
        pair_counts = get_count(word_freqs)
        # Find best pair with lexicographic tie-breaking
        best_pair = max(
                pair_counts.items(),
                key=lambda x: (x[1], make_comparable_pair(x[0], vocab_list)),
            )[0]
        best_pair

        new_token_id = len(vocab_list)

        first_bytes = get_token_bytes(best_pair[0], vocab_list)
        second_bytes = get_token_bytes(best_pair[1], vocab_list)
        merged_bytes = first_bytes + second_bytes

        vocab_list.append(merged_bytes)

        # print(merged_bytes)

        new_word_freqs = {}
        for words,freq in word_freqs.items():
            merged_list = merge_pair(words, best_pair, new_token_id)
            new_word_freqs[tuple(merged_list)] = freq

        word_freqs = new_word_freqs
        merges.append(best_pair)


    # Build final outputs
    vocab_dict = {i: token for i, token in enumerate(vocab_list)}

    # Build merges_bytes with proper lookup
    merges_bytes = []
    for first, second in merges:
        first_bytes = get_token_bytes(first, vocab_list)
        second_bytes = get_token_bytes(second, vocab_list)
        merges_bytes.append((first_bytes, second_bytes))

    return vocab_dict,merges_bytes

vocab_dict,merges_bytes = bpe_train(sample_tiny_path,300,["<|endoftext|>"])





In [None]:
import pickle
import regex as re
class Tokenizer:
    def __init__(self, vocab, merges, special_tokens=None):
        self.vocab = vocab
        self.merges = merges
        self.special_tokens = special_tokens
        self.vocab_dict = {v: k for k, v in self.vocab.items()}

    @classmethod
    def from_file(cls, vocab_path, merges_path, special_tokens=None):
        with open(vocab_path, "rb") as f:
            vocab = pickle.load(f)
        with open(merges_path, "rb") as f:
            merges = pickle.load(f)
        
        # Create and return a new instance
        return cls(vocab, merges, special_tokens)

    def encode(self, text):

        if self.special_tokens is None:
            special_token_pattern = []

        else:
            special_token_pattern = "|".join(re.escape(token) for token in self.special_tokens)
            text = re.split(f"({special_token_pattern})", text)

        PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""


        words = []
        for segment in text:
            if not segment:
                continue
            if self.special_tokens and segment in self.special_tokens:
                # Handle special tokens - they should be encoded as single tokens
                special_token_bytes = segment.encode("utf-8")
                words.append([special_token_bytes])
            else:
                for match in re.finditer(PAT, segment):
                    token = match.group()
                    words.append([bytes([x]) for x in token.encode("utf-8")])

        final_words = []
        for word in words:            
            while True:
                best_merge = None
                best_pos = -1

                for i in range(len(word) - 1):
                    pair = (word[i], word[i + 1])
                    if pair in self.merges:
                        merge_priority = self.merges.index(pair)
                        if best_merge is None or merge_priority < best_merge[1]:
                            best_merge = (i, merge_priority)
                            best_pos = i

                if best_merge is None:
                    break

                # Apply the best merge
                word = (
                    word[:best_pos]
                    + [word[best_pos] + word[best_pos + 1]]
                    + word[best_pos + 2 :]
                )
            final_words.append(word)

        # print(final_words)

        tokenized_words = []     # print(word)
        for word in final_words:
            tokenized_word = []
            for i in word:
                tokenized_word.append(self.vocab_dict.get(i,"<UNK>"))
            tokenized_words.extend(tokenized_word)
        return tokenized_words
    
    def encode_iterable(self, iterable):
        for item in iterable:
            yield self.encode(item)

    def decode(self, tokens):
        """Decode token IDs back to text"""
        # Step 1: Convert IDs to tokens
        tokens = [self.vocab.get(token_id,"<UNK>")for token_id in tokens]
        
        # Step 2: Concatenate all bytes
        full_bytes = b''.join(tokens)
        
        # Step 3: Decode to string
        text = full_bytes.decode('utf-8')
        
        return text


tokenizer = Tokenizer.from_file("gpt2_vocab.json", "../data/results/merges.pkl",["<|endoftext|>", "<|endoftext|><|endoftext|>"])


UnpicklingError: invalid load key, '{'.

In [12]:
tokenizer.decode(tokenizer.encode("Héllò hôw <|endoftext|><|endoftext|> are ü? 🙃<|endoftext|>"))


'Héllò hôw <|endoftext|><|endoftext|> are ü? 🙃<|endoftext|>'

In [None]:
import regex as re
import pickle

PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

vocab_path = "../data/results/vocab.pkl"
merges_path = "../data/results/merges.pkl"

with open(merges_path, "rb") as f:
    merges = pickle.load(f)

with open(vocab_path, "rb") as f:
    vocab = pickle.load(f)

# print(vocab)

vocab_dict = {v: k for k, v in vocab.items()}

chunk = "Hello, world! This is a test. <|endoftext|>"
special_tokens = ["<|endoftext|>"]

special_token_pattern = "|".join(re.escape(token) for token in special_tokens)

PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

text = re.split(f"({special_token_pattern})", chunk)

words = []
for segment in text:
    if not segment:
        continue
    if segment in special_tokens:
        continue
    else:
        for match in re.finditer(PAT, segment):
            token = match.group()
            words.append([bytes([x]) for x in token.encode("utf-8")])

final_words = []
for word in words:
    # print("processing word",word,"\n")
    
    while True:
        i = 0
        merge_found = False
        while i < len(word) - 1:
            if (word[i],word[i+1]) in merges:
                word = word[:i] + [word[i] + word[i+1]] + word[i+2:]
                merge_found = True
                break
            i = i+1
        if not merge_found:
            break
    final_words.append(word)

# print(final_words)

tokenized_words = []     # print(word)
for word in final_words:
    tokenized_word = []
    for i in word:
        tokenized_word.append(vocab_dict.get(i,"<UNK>"))
    tokenized_words.extend(tokenized_word)
        # print(vocab_dict[i])
# for word,token in zip(final_words,tokenized_words):
#     print(word,"==>",token)




print(tokenized_words)
