In [3]:
from collections import defaultdict
from tqdm import tqdm
import matplotlib.pyplot as plt

# Marker symbol prepended to every token that is not the first token of its
# line; it is also stored as its own leading sub-token in the token's split.
BEGINNING_OF_WORD = "<B>"
# NOTE(review): not referenced anywhere in the visible code — confirm it is
# unused elsewhere before removing.
START = "<START>"
def read_corpus(num_lines, path=None):
    """Read whitespace-tokenized lines from the BPE data file.

    Args:
        num_lines: If positive, return only the first ``num_lines`` lines.
            If zero or negative, return ``lines[num_lines:]`` — i.e. the last
            ``-num_lines`` lines for a negative value and every line for zero.
        path: Optional path of the text file to read. Defaults to the
            original hard-coded data file location.

    Returns:
        A list with one entry per selected line; each entry is the list of
        strings obtained by splitting the line on single spaces (so
        consecutive spaces produce empty-string tokens and an empty line
        becomes ``[""]``).
    """
    from itertools import islice

    if path is None:
        path = "/Users/anderson/Documents/CSE/cse447-nlp/hw1/code/BPE/BPE-data.txt"
    # Explicit encoding so the result does not depend on the platform locale.
    with open(path, "r", encoding="utf-8") as f:
        if num_lines > 0:
            # Stop reading as soon as the requested prefix has been consumed.
            lines = list(islice(f, num_lines))
        else:
            lines = f.readlines()[num_lines:]
    return [line.strip("\n").split(" ") for line in lines]

class BpeTokenizer:
    """Byte-pair-encoding tokenizer trained on a pre-tokenized corpus.

    The corpus is a list of lines, each line a list of token strings. Every
    token that is not the first token of its line is keyed with
    ``BEGINNING_OF_WORD`` prepended and carries that marker as its first
    sub-token.

    Attributes set by ``__init__``:
        cntr: occurrence count per (possibly marker-prefixed) token.
        split_map: current list of sub-tokens for each token.
        corpus: the corpus passed to the constructor.
        vocab: list of sub-token strings; the initial symbols in sorted
            order, followed by one entry per merge in the order learned.
        apply_rules: the merged pairs, in the order they were learned.
    """

    def __init__(self, corpus) -> None:
        """Build token counts, character-level splits and the initial vocab.

        Args:
            corpus: list of lines, each a list of token strings.
        """
        self.cntr = defaultdict(int)
        self.split_map = {}
        self.corpus = corpus
        alphabet = set()
        print("Initialize tokenizer")
        for line in self.corpus:
            for i, token in enumerate(line):
                key, split = self._initial_split(token, i > 0)
                # The split holds the characters plus, when present, the
                # marker as a single symbol.
                alphabet.update(split)
                self.cntr[key] += 1
                self.split_map[key] = split
        # Indexed vocab for inference. Sorted so that indices are stable
        # between runs (plain set iteration order of strings is not).
        self.vocab = sorted(alphabet)
        self.apply_rules = []

    @staticmethod
    def _initial_split(token, beginning_of_word):
        """Return ``(key, split)``: the lookup key and character-level split."""
        split = list(token)
        if not beginning_of_word:
            return token, split
        return BEGINNING_OF_WORD + token, [BEGINNING_OF_WORD] + split

    @staticmethod
    def _merge_split(split, pair):
        """Return a copy of ``split`` with each adjacent ``pair`` fused.

        The scan is left to right and non-overlapping: after a merge it
        resumes behind the merged symbols.
        """
        merged = []
        idx = 0
        last = len(split) - 1
        while idx <= last:
            if idx < last and (split[idx], split[idx + 1]) == pair:
                merged.append(split[idx] + split[idx + 1])
                idx += 2
            else:
                merged.append(split[idx])
                idx += 1
        return merged

    # Modified from: https://huggingface.co/learn/nlp-course/chapter6/5?fw=pt#implementing-bpe
    def compute_pair_freqs(self):
        """Find the most frequent adjacent sub-token pair in the training data.

        Each pair is weighted by the count of the token it occurs in.

        Returns:
            ``(max_freq, max_pair)``; ``(0, None)`` when no token has two or
            more sub-tokens. On ties, the pair whose running count reached
            the maximum first during the scan is returned.
        """
        pair_freqs = defaultdict(int)
        max_freq, max_pair = 0, None
        for word, freq in self.cntr.items():
            split = self.split_map[word]
            # zip yields nothing for splits shorter than two symbols.
            for pair in zip(split, split[1:]):
                pair_freqs[pair] += freq
                if pair_freqs[pair] > max_freq:
                    max_pair = pair
                    max_freq = pair_freqs[pair]
        return max_freq, max_pair

    def compute_corpus_size(self):
        """Return the training corpus length measured in current sub-tokens."""
        return sum(
            len(self.split_map[word]) * freq for word, freq in self.cntr.items()
        )

    def merge_pair(self, max_pair):
        """Record ``max_pair`` as a merge rule and apply it to every split.

        Args:
            max_pair: 2-tuple of adjacent sub-token strings to fuse.
        """
        self.vocab.append(max_pair[0] + max_pair[1])
        self.apply_rules.append(max_pair)
        # Only values of existing keys are replaced, so iterating the dict
        # while assigning is safe.
        for word, split in self.split_map.items():
            self.split_map[word] = self._merge_split(split, max_pair)

    def set_eval_corpus(self, eval_corpus):
        """Store ``eval_corpus`` and reset its splits to character level.

        Args:
            eval_corpus: list of lines, each a list of token strings.
        """
        self.eval_corpus = eval_corpus
        self.eval_split_map = {}
        for line in self.eval_corpus:
            for i, token in enumerate(line):
                key, split = self._initial_split(token, i > 0)
                self.eval_split_map[key] = split

    def apply(self):
        """Apply every learned merge rule, in order, to the eval splits.

        Raises:
            ValueError: if no eval corpus has been set, or it is empty.
        """
        if not hasattr(self, "eval_corpus") or not self.eval_corpus:
            raise ValueError("No eval corpus set")
        for pair in self.apply_rules:
            for word, split in self.eval_split_map.items():
                self.eval_split_map[word] = self._merge_split(split, pair)

    def lookup_eval_split(self, token, beginning_of_word=False):
        """Return the current split of an eval-corpus token.

        Args:
            token: the raw token string as it appears in the eval corpus.
            beginning_of_word: if true, look the token up under its
                ``BEGINNING_OF_WORD``-prefixed key.

        Raises:
            ValueError: if the (possibly prefixed) token is not in the eval
                split map.
        """
        if beginning_of_word:
            token = BEGINNING_OF_WORD + token
        if token not in self.eval_split_map:
            raise ValueError("Invalid token")
        return self.eval_split_map[token]

def bpe_train(corpus, threshold=2):
    """Train a ``BpeTokenizer`` on ``corpus`` and print a short summary.

    The most frequent adjacent pair is merged repeatedly until its frequency
    is no longer greater than ``threshold``.

    Args:
        corpus: list of lines, each a list of token strings.
        threshold: merging stops once the best pair's frequency is less than
            or equal to this value. Defaults to 2.

    Returns:
        The trained ``BpeTokenizer``.
    """
    tokenizer = BpeTokenizer(corpus)
    max_freq, max_pair = tokenizer.compute_pair_freqs()
    corpus_size = []
    # Width of the largest count, used to right-align the progress output so
    # shorter numbers fully overwrite longer ones after the carriage return.
    max_num_dig = len(str(max_freq))
    print(f"inital vocab size = {len(tokenizer.vocab)}")
    while max_freq > threshold:
        output = str(max_freq).rjust(max_num_dig)
        print(f"max bigram count = {output}", end="\r")
        # actual training
        tokenizer.merge_pair(max_pair)
        corpus_size.append(tokenizer.compute_corpus_size())
        max_freq, max_pair = tokenizer.compute_pair_freqs()
    # When no merge was performed the history is empty; report the size of
    # the unmerged corpus instead of failing on corpus_size[-1].
    final_size = corpus_size[-1] if corpus_size else tokenizer.compute_corpus_size()
    print(f"final training corpus size = {final_size}")
    print(f"final vocab size = {len(tokenizer.vocab)}")
    # Show the learned segmentation of the first training line, if any.
    if corpus:
        for i, token in enumerate(corpus[0]):
            if i != 0:
                token = BEGINNING_OF_WORD + token
            print(tokenizer.split_map[token])
    return tokenizer


def bpe_apply(eval_corpus, tokenizer):
    """Segment ``eval_corpus`` with ``tokenizer`` and print a preview.

    The tokenizer is given the eval corpus and asked to apply its merge
    rules; the resulting split of every token in the first five lines is
    then printed, one per line, followed by a blank line.

    Args:
        eval_corpus: list of lines, each a list of token strings.
        tokenizer: object providing ``set_eval_corpus``, ``apply`` and
            ``lookup_eval_split``.
    """
    tokenizer.set_eval_corpus(eval_corpus)
    tokenizer.apply()
    preview = eval_corpus[:5]
    for sentence in preview:
        for position, word in enumerate(sentence):
            # Every token after the first in a line is looked up under its
            # beginning-of-word key.
            is_continuation = position > 0
            segmentation = tokenizer.lookup_eval_split(word, is_continuation)
            print(segmentation)
    print()

if __name__ == "__main__":
    # Train on the first 4000 lines of the data file.
    corpus = read_corpus(4000)
    tokenizer = bpe_train(corpus)
    # A negative count selects the last 1000 lines of the same file.
    # NOTE(review): presumably meant as held-out text; this only holds if the
    # file has at least 5000 lines — confirm.
    eval_corpus = read_corpus(-1000)
    bpe_apply(eval_corpus, tokenizer)

Initialize tokenizer
inital vocab size = 83
final training corpus size = 110439
final vocab size = 9219
['P', 'rague']
['<B>Stock']
['<B>Market']
['<B>falls']
['<B>to']
['<B>minus']
['<B>by']
['<B>the']
['<B>end']
['<B>of']
['<B>the']
['<B>trading']
['<B>day']
['When']
['<B>our']
['<B>brain']
['<B>runs']
['<B>hot']
['W', 'hether']
['<B>these']
['<B>ill', 'usions']
['<B>are']
['<B>perceived']
['<B>to']
['<B>be']
['<B>ple', 'as', 'ant']
['<B>or']
['<B>unp', 'le', 'as', 'ant']
['<B>is']
['<B>often']
['<B>only']
['<B>a']
['<B>matter']
['<B>of']
['<B>perspective', '.']
['K', 'ast', 'en']
['<B>has']
['<B>been']
['<B>collect', 'ing']
['<B>reports']
['<B>for']
['<B>ten']
['<B>years']
['<B>on']
['<B>hallucinations']
['<B>and']
['<B>del', 'usions', ';']
['<B>last']
['<B>year']
['<B>he']
['<B>assess', 'ed']
['<B>his']
['<B>results']
['<B>in']
['<B>a']
['<B>book']
['<B>tit', 'led']
['<B>The']
['<B>un', 'real']
['<B>World']
['<B>in']
['<B>our']
['<B>He', 'ad', '.']
['H', 'all', 'ucin', 'ations']
['

In [8]:
# Print the learned merge rules, one pair per line, in the order they were
# appended during training.
for pair in tokenizer.apply_rules:
    print(pair)

('<B>', 't')
('h', 'e')
('<B>', 'a')
('i', 'n')
('<B>t', 'he')
('r', 'e')
('o', 'n')
('e', 'r')
('<B>', 'o')
('<B>', 's')
('<B>', 'w')
('a', 't')
('e', 'n')
('a', 'n')
('i', 's')
('<B>', 'c')
('o', 'r')
('i', 't')
('<B>', 'b')
('<B>', 'p')
('e', 's')
('e', 'd')
('<B>o', 'f')
('<B>', 'f')
('<B>', 'in')
('<B>t', 'o')
('in', 'g')
('a', 'r')
('a', 'l')
('i', 'c')
('<B>', 'm')
('o', 'u')
('<B>', 'h')
('<B>a', 'n')
('i', 'on')
('<B>', 'd')
('a', 's')
('<B>t', 'h')
('l', 'e')
('r', 'o')
('en', 't')
('<B>an', 'd')
('<B>', 'n')
('i', 'l')
('<B>', 're')
('<B>b', 'e')
('s', 't')
('o', 'm')
('<B>', 'e')
('v', 'e')
('i', 'd')
('c', 't')
('l', 'y')
('o', 't')
('<B>', 'on')
('<B>', 'l')
('a', 'd')
('<B>', 'is')
('<B>f', 'or')
('a', 'y')
('u', 't')
('<B>', 'g')
('c', 'e')
('T', 'he')
('i', 'm')
('<B>th', 'at')
('v', 'er')
('e', 't')
('u', 'r')
('at', 'ion')
('o', 'l')
('u', 's')
('<B>w', 'h')
('i', 'r')
('i', 'g')
('e', 'l')
('u', 'n')
('t', 'er')
('c', 'h')
('<B>s', 't')
('a', 'm')
('<B>', 'it')
('s'