In [1]:
# def print_format_table():
#     """
#     prints table of formatted text format options
#     """
#     for style in range(8):
#         for fg in range(30, 38):
#             s1 = ''
#             for bg in range(40, 48):
#                 format = ';'.join([str(style), str(fg), str(bg)])
#                 s1 += '\x1b[%sm %s \x1b[0m' % (format, format)
#             print(s1)
#         print('\n')
# print_format_table()

# for fg in range(31, 38):
#     print('\x1b[2;%s;40m %s \x1b[0m' % (fg, "hello"))
    
class RainbowPrinter:
    """Print words to stdout, cycling through six ANSI foreground colours.

    Each word is wrapped in an escape sequence built from ``format_str`` with
    a foreground code that rotates through 31..36, so neighbouring words are
    visually distinguishable.
    """

    def __init__(self):
        # Position in the colour cycle. It is advanced *before* each word is
        # printed, so the first word uses code 31.
        self.idx = 0
        # Escape-sequence template: first %s is the colour code, second %s is
        # the word (followed by one space), then a reset sequence.
        self.format_str = '\x1b[1;%s;48m%s \x1b[0m'

    def print_word(self, word):
        """Print a single word in the next colour of the cycle (no newline)."""
        # 0 -> 1 -> 2 -> ... -> 6 -> 1 -> ...
        self.idx = self.idx % 6 + 1
        print(self.format_str % (30 + self.idx, word), end='')

    def print_words(self, words):
        """Print a sentence made up of token words, one colour per token.

        Args:
            words: a list or tuple of tokens to print.

        Raises:
            TypeError: if ``words`` is neither a list nor a tuple.
        """
        if not isinstance(words, (list, tuple)):
            # A bare ``raise`` here has no active exception to re-raise and
            # would surface as an unrelated RuntimeError.
            raise TypeError(
                "words must be a list or tuple, got %s" % type(words).__name__
            )
        for token_word in words:
            self.print_word(token_word)
        print('\n')

In [63]:
import re
from collections import Counter

"""
SentencePiece treats the input text just as a sequence of Unicode characters. 
Whitespace is also handled as a normal symbol. 
To handle the whitespace as a basic token explicitly, SentencePiece first 
escapes the whitespace with a meta symbol "▁" (U+2581) as follows.
"""

#
# corpus = {
#     word_0: [token0, token1, ..., tokenm],
#     word_1: [token0, token1, ..., tokenm],
#     ...
#     word_n: [token0, token1, ..., tokenm],
# }
#
# vocab = {
#     token_0: count_0,
#     token_1: count_1,
#       ...
#     token_m: count_m,
# }
#

class BytePairEncoder:
    """Small byte-pair-encoding tokenizer that works on whitespace-split words.

    Every word is prefixed with the meta symbol ``ws_token`` and starts out as
    a list of single characters. Training repeatedly merges the most frequent
    adjacent token pair. Encoding applies the learnt tokens greedily, always
    preferring the candidate with the highest vocabulary count.

    NOTE: ``init_state``/``train`` add to the existing state; they do not
    reset ``corpus``, ``word_count`` or ``vocab`` first.
    """

    def __init__(self):
        self.ws_token = '▁'        # prefix marking the start of a word
        self.unk_token = '<UNK>'   # token reported for unknown input (id 0)

        self.corpus = {}           # word -> current list of tokens
        self.word_count = {}       # word -> number of occurrences
        self.vocab = Counter()     # token -> weighted occurrence count

        self.id_tokens = {}        # id -> token
        self.token_ids = {}        # token -> id

    def init_state(self, content):
        """Load training text and seed the vocabulary with single characters.

        Args:
            content: an iterable of lines, or a single string which is split
                into lines first (iterating a raw string would otherwise
                treat every character as its own line).
        """
        if isinstance(content, str):
            content = content.splitlines()

        for line in content:
            self.process_sentence(self.preprocess(line.strip()))

        # Each token contributes once per occurrence of the word it is in.
        alphabet = Counter()
        for word, tokens in self.corpus.items():
            weight = self.word_count[word]
            for token in tokens:
                alphabet[token] += weight
        self.vocab.update(alphabet)

        # for debug
        self._dump_init()

    def process_sentence(self, sentence):
        """Register every whitespace-separated word of ``sentence``.

        New words are stored split into characters (with the ``ws_token``
        prefix); known words only have their counter increased.
        """
        for raw_word in sentence.split():
            word = self.ws_token + raw_word
            if word in self.corpus:
                self.word_count[word] += 1
            else:
                self.corpus[word] = list(word)
                self.word_count[word] = 1

    def preprocess(self, text):
        """Collapse every run of whitespace in ``text`` into a single space."""
        # Raw string: '\s' in a plain literal is an invalid escape sequence.
        return re.sub(r'\s+', ' ', text)

    def _dump_init(self):
        """Print corpus, word counts and vocabulary (debug helper)."""
        print("=" * 12 + " dump initial state " + "=" * 12)
        print("==> dump corpus <==")
        for word, text in self.corpus.items():
            print(f"{word} => {text}")
        print('-.' * 20)
        print("==> dump wordcnt <==")
        for word, count in self.word_count.items():
            print(f"{word} => {count}")
        print('-.' * 20)
        print("==> dump vocab <==")
        for token, count in self.vocab.items():
            print(f"{token} => {count}")
        print("-" * 40)

    def gen_bigrams(self):
        """Count adjacent token pairs over the corpus, weighted by word count.

        Returns:
            Counter mapping the *concatenation* of each adjacent pair to its
            weighted frequency.
        """
        bigram_counter = Counter()
        for word, tokens in self.corpus.items():
            weight = self.word_count[word]
            for left, right in zip(tokens, tokens[1:]):
                # NOTE: use '+' instead of (l, r) so that a,aa and aa,a are
                # counted as the same bigram.
                bigram_counter[left + right] += weight
        return bigram_counter

    @staticmethod
    def _merge_adjacent(tokens, target):
        """Merge, left to right, adjacent tokens whose concatenation is ``target``.

        Returns:
            ``(new_tokens, pairs)`` where ``pairs`` lists the ``(left, right)``
            tokens consumed by each merge. Merges never overlap.
        """
        merged, pairs = [], []
        i, last = 0, len(tokens) - 1
        while i <= last:
            if i < last and tokens[i] + tokens[i + 1] == target:
                pairs.append((tokens[i], tokens[i + 1]))
                merged.append(target)
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        return merged, pairs

    def merge_pair(self):
        """Merge the most frequent bigram everywhere in the corpus.

        Does nothing when there is no bigram left, or when the best bigram
        occurs only once. The vocabulary is adjusted by the merges actually
        performed, so overlapping matches (e.g. ``a a a``) are not counted
        twice.
        """
        ranked = self.gen_bigrams().most_common(1)
        if not ranked:
            # Empty corpus, or every word is already a single token.
            return
        top_bigram, top_count = ranked[0]
        print(f"=> top_bigram:{top_bigram}, top_count:{top_count}")
        if top_count == 1:
            return

        merged_total = 0
        for word, tokens in self.corpus.items():
            new_tokens, pairs = self._merge_adjacent(tokens, top_bigram)
            if not pairs:
                continue
            weight = self.word_count[word]
            for left, right in pairs:
                self.update_vocab(left, -weight)
                self.update_vocab(right, -weight)
            merged_total += weight * len(pairs)
            # Replacing the value of an existing key is safe while iterating.
            self.corpus[word] = new_tokens
        self.update_vocab(top_bigram, merged_total)

    def update_vocab(self, symbol, count):
        """Add ``count`` (may be negative) to the vocabulary entry of ``symbol``.

        NOTE: entries that drop to zero are kept on purpose; removing them
        would cut off the way to combine token words in ``segment``.
        """
        if symbol in self.vocab:
            self.vocab[symbol] += count
        else:
            self.vocab[symbol] = count

    def train(self, text, steps=3):
        """Initialise from ``text`` and run ``steps`` merge rounds.

        Args:
            text: training lines (iterable of strings, or one string).
            steps: number of calls to ``merge_pair``.
        """
        self.init_state(text)

        for step in range(steps):
            print("=" * 12 + f" step:{step} " + "=" * 12)
            self.merge_pair()
            # for debug
            # self._dump_merge()

        print("==> dump final vocab <==")
        for token, count in self.vocab.most_common():
            print(f"{token} => {count}")
        self.gen_id_token_map()

    def _dump_merge(self):
        """Print vocabulary and corpus after a merge step (debug helper)."""
        print("-" * 40)
        print("==> dump vocab <==")
        for token, count in self.vocab.most_common():
            print(f"{token} => {count}")
        print('-' * 40)
        print("==> dump corpus <==")
        for word, tokens in self.corpus.items():
            print(f"[{self.word_count[word]:3d}] * {word} => {tokens}")
        print("-" * 40)

    def gen_id_token_map(self):
        """Assign ids to tokens: 0 is ``unk_token``, then descending count."""
        self.id_tokens[0] = self.unk_token
        self.token_ids[self.unk_token] = 0

        for idx, (token, _) in enumerate(self.vocab.most_common(), start=1):
            self.id_tokens[idx] = token
            self.token_ids[token] = idx

    def encode(self, text):
        """Tokenize ``text``.

        Returns:
            ``(tokens, ids)``; tokens missing from the id map get id 0.
            ``None`` when ``text`` is empty.
        """
        if not text:
            return None
        text = self.preprocess(text)
        # preprocess() left only single spaces, so a plain replace suffices.
        text = self.ws_token + text.strip().replace(' ', self.ws_token)
        seg_txt = self.segment(text)
        seg_ids = [self.token_ids.get(token, 0) for token in seg_txt]
        return (seg_txt, seg_ids)

    def segment(self, text):
        """Split ``text`` into vocabulary tokens.

        Starts from single characters and repeatedly merges the adjacent pair
        whose concatenation has the highest vocabulary count (leftmost wins
        ties), applying that merge to every occurrence, until no adjacent
        pair forms a known token.

        Returns:
            list of token strings. A single unknown character yields
            ``[unk_token]``.
        """
        if len(text) == 1:
            # Always return a list so callers can iterate over tokens.
            return [text if text in self.vocab else self.unk_token]

        segments = list(text)
        while len(segments) > 1:
            best, best_count = None, None
            for left, right in zip(segments, segments[1:]):
                candidate = left + right
                if candidate not in self.vocab:
                    continue
                # Rank by the count of the candidate token itself.
                count = self.vocab[candidate]
                if best is None or count > best_count:
                    best, best_count = candidate, count
            if best is None:
                break
            segments, _ = self._merge_adjacent(segments, best)
        return segments

    def decode(self, ids):
        """Join the tokens for ``ids`` and turn each ``ws_token`` into a space."""
        text = ''.join([self.id_tokens[idx] for idx in ids]).replace(self.ws_token, ' ')
        return text
        


In [64]:
# Training cell: build an encoder and fit it on a few sample sentences.
bpe = BytePairEncoder()

# One entry per training line; the commented-out entries are alternative
# sample corpora.
corpus = [
    # "Alice is running faster than Bob",
    # "Bob run slower than Alice",
    # "FloydHub is the fastest way to build, train and deploy deep learning models. Build deep learning models in the cloud. Train deep learning models."
    # "old " * 7 + "older " * 3  + "finest " * 9 + "lowest " * 4
    # "hug " * 10 + "pug " * 5 + "pun " * 12 + "bun " * 4 + "hugs " * 5
    "这是OpenAI 团队前一段时间放出来的预印版论文。 他们的目标是学习一个通用的表示，能够在大量任务上进行应用。",
    "这篇论文的亮点主要在于， 他们利用了Transformer网络代替了LSTM作为语言模型来更好的捕获长距离语言结构。",
    "然后在进行具体任务有监督微调时, 使用了模型作为附属任务训练目标。"
]

# Run 20 merge steps over the sample corpus.
bpe.train(corpus, steps=20)
# bpe.init_state('\n'.join(corpus))

==> dump corpus <==
▁这是OpenAI => ['▁', '这', '是', 'O', 'p', 'e', 'n', 'A', 'I']
▁团队前一段时间放出来的预印版论文。 => ['▁', '团', '队', '前', '一', '段', '时', '间', '放', '出', '来', '的', '预', '印', '版', '论', '文', '。']
▁他们的目标是学习一个通用的表示，能够在大量任务上进行应用。 => ['▁', '他', '们', '的', '目', '标', '是', '学', '习', '一', '个', '通', '用', '的', '表', '示', '，', '能', '够', '在', '大', '量', '任', '务', '上', '进', '行', '应', '用', '。']
▁这篇论文的亮点主要在于， => ['▁', '这', '篇', '论', '文', '的', '亮', '点', '主', '要', '在', '于', '，']
▁他们利用了Transformer网络代替了LSTM作为语言模型来更好的捕获长距离语言结构。 => ['▁', '他', '们', '利', '用', '了', 'T', 'r', 'a', 'n', 's', 'f', 'o', 'r', 'm', 'e', 'r', '网', '络', '代', '替', '了', 'L', 'S', 'T', 'M', '作', '为', '语', '言', '模', '型', '来', '更', '好', '的', '捕', '获', '长', '距', '离', '语', '言', '结', '构', '。']
▁然后在进行具体任务有监督微调时, => ['▁', '然', '后', '在', '进', '行', '具', '体', '任', '务', '有', '监', '督', '微', '调', '时', ',']
▁使用了模型作为附属任务训练目标。 => ['▁', '使', '用', '了', '模', '型', '作', '为', '附', '属', '任', '务', '训', '练', '目', '标', '。']
-.-.-.-.-.-.-.-.-.-.-.-.-.-.-.-.-.-.-.-.
==> 

In [None]:
# Encode a sample sentence with the encoder built in an earlier cell, print
# the tokens in rotating colours, then decode the ids back to text.
printer = RainbowPrinter()
# segments, seg_ids = bpe.encode("huggpnun  what ugg is haasnb")
seg_txt, seg_ids = bpe.encode("他们论文的亮点是用语言模型完成对应的目标任务")
printer.print_words(seg_txt)
print(bpe.decode(seg_ids))

[1;31;48m▁他们 [0m[1;32;48m论文 [0m[1;33;48m的 [0m[1;34;48m亮 [0m[1;35;48m点 [0m[1;36;48m是 [0m[1;31;48m用 [0m[1;32;48m语言 [0m[1;33;48m模型 [0m[1;34;48m完 [0m[1;35;48m成 [0m[1;36;48m对 [0m[1;31;48m应 [0m[1;32;48m的 [0m[1;33;48m目标 [0m[1;34;48m任务 [0m

 <UNK>论文的亮点是用模型<UNK><UNK><UNK>应的目标任务


In [None]:
# Pass the sentences as a list of lines, the same way the first training
# cell calls train(); uses the default number of merge steps.
bpe.train(corpus)

In [None]:
# Run a single extra merge step on the encoder's current corpus state.
bpe.merge_pair()