# BPE-dropout algorithm

![image](https://github.com/anminhhung/images/blob/main/nlp/bpe-dropout.png?raw=true)

- [Source](https://slideslive.com/38928817/bpedropout-simple-and-effective-subword-regularization)

# Build vocab

In [39]:
# Toy corpus: every word is listed once per occurrence.
sentences = (
    ["low"] * 5
    + ["lower"] * 2
    + ["newest"] * 6
    + ["widest"] * 3
    + ["happier"] * 2
)

In [40]:
# Join the words into one space-separated string and count the non-space
# characters across all entries of `sentences`.
corpus = " ".join(sentences)
corpus_length = 0
for sentence in sentences:
  corpus_length += sum(1 for char in sentence if char != " ")

In [41]:
# Show the character count and the joined corpus string.
print("corpus_length: ", corpus_length)
print("corpus: ", corpus)

corpus_length:  93
corpus:  low low low low low lower lower newest newest newest newest newest newest widest widest widest happier happier


build vocab

In [53]:
# Word-frequency table for the toy corpus. Each key is a word split into
# space-separated characters and closed with the end-of-word marker "</w>";
# each value is that word's frequency.
vocab = {
    " ".join(word) + " </w>": frequency
    for word, frequency in (
        ("low", 5),
        ("lower", 2),
        ("newest", 6),
        ("widest", 3),
        ("happier", 2),
    )
}

# Train BPE-Dropout

In [43]:
from operator import itemgetter
import re
from typing import Dict, Tuple, List, Set 

In [44]:
def calculate_word_probability(word: str, corpus: str, corpus_length: int, smooth_params=0.001):
  """Return the smoothed relative frequency of ``word`` in ``corpus``.

  Occurrences are counted with ``str.count`` (non-overlapping substring
  matches), ``smooth_params`` is added to that count, and the sum is divided
  by ``corpus_length``.
  """
  smoothed_count = corpus.count(word) + smooth_params
  return smoothed_count / corpus_length

In [45]:
def get_pair_stats(voacb: Dict[str, int]) -> Dict[Tuple[str, str], int]:
  """Count adjacent symbol pairs, weighted by the frequency of each word.

  Args:
    voacb: mapping from a space-separated symbol sequence (for example
      ``'l o w </w>'``) to the frequency of that word. The parameter keeps
      its original spelling so the signature is unchanged.

  Returns:
    Mapping from each adjacent ``(left, right)`` symbol pair to the sum of
    the frequencies of the words containing it, counted once per occurrence.
  """
  pairs: Dict[Tuple[str, str], int] = {}
  # Iterate over the argument itself; reading a module-level ``vocab`` here
  # would ignore whatever the caller passed in.
  for word, freq in voacb.items():
    symbols = word.split()
    for pair in zip(symbols, symbols[1:]):
      pairs[pair] = pairs.get(pair, 0) + freq

  return pairs

In [46]:
# Pair statistics for the initial vocabulary; the bare expression displays them in the notebook.
pair_stats = get_pair_stats(vocab)
pair_stats

{('a', 'p'): 2,
 ('d', 'e'): 3,
 ('e', 'r'): 4,
 ('e', 's'): 9,
 ('e', 'w'): 6,
 ('h', 'a'): 2,
 ('i', 'd'): 3,
 ('i', 'e'): 2,
 ('l', 'o'): 7,
 ('n', 'e'): 6,
 ('o', 'w'): 7,
 ('p', 'i'): 2,
 ('p', 'p'): 2,
 ('r', '</w>'): 4,
 ('s', 't'): 9,
 ('t', '</w>'): 9,
 ('w', '</w>'): 5,
 ('w', 'e'): 8,
 ('w', 'i'): 3}

In [47]:
def compute_pair_probability(current_pair: Tuple[str, str]):
  """Return the product of the two symbols' individual probabilities.

  Each symbol is scored with ``calculate_word_probability`` against the
  module-level ``corpus`` and ``corpus_length``.
  """
  left_symbol = current_pair[0]
  right_symbol = current_pair[1]
  return (
      calculate_word_probability(left_symbol, corpus, corpus_length)
      * calculate_word_probability(right_symbol, corpus, corpus_length)
  )

In [48]:
def compute_probability_in_pair_stats(pair_stats: Dict[Tuple[str, str], int]):
  """Map every pair in ``pair_stats`` to its ``compute_pair_probability`` score.

  The counts stored in ``pair_stats`` are not used, and the result keeps the
  iteration order of ``pair_stats`` (it is not sorted).
  """
  return {pair: compute_pair_probability(pair) for pair in pair_stats}

In [49]:
# Score every pair, then choose the pair to merge.
pair_stats_prob = compute_probability_in_pair_stats(pair_stats)
print(pair_stats_prob)
# Candidates are the keys of pair_stats_prob, but they are ranked by their
# count in pair_stats (not by probability).
best_pair = max(pair_stats_prob, key=pair_stats.get)
print("best_pair: ", best_pair)

{('l', 'o'): 0.005667013643195747, ('o', 'w'): 0.012952133310209276, ('w', '</w>'): 1.8500404671060242e-06, ('w', 'e'): 0.03515261891548157, ('e', 'r'): 0.008789802404902302, ('r', '</w>'): 4.625968320036999e-07, ('n', 'e'): 0.013183605156665513, ('e', 'w'): 0.03515261891548157, ('e', 's'): 0.019774309284310326, ('s', 't'): 0.00936732581801364, ('t', '</w>'): 1.0406983466296682e-06, ('w', 'i'): 0.009252052375997227, ('i', 'd'): 0.0017352296219216095, ('d', 'e'): 0.006592901029020696, ('h', 'a'): 0.0004629438085327783, ('a', 'p'): 0.0009256562608394034, ('p', 'p'): 0.0018508499248468034, ('p', 'i'): 0.002313446756850503, ('i', 'e'): 0.010986703780783907}
best_pair:  ('e', 's')


In [50]:
def merge_vocab(best_pair: Tuple[str, str], vocab_in: Dict[str, int]) -> Dict[str, int]:
  """Merge every occurrence of ``best_pair`` into a single symbol.

  Args:
    best_pair: the two adjacent symbols to join.
    vocab_in: mapping from space-separated symbol sequences to frequencies.

  Returns:
    A new dict with the same frequencies whose keys have each
    ``'<first> <second>'`` occurrence replaced by ``'<first><second>'``.
    ``vocab_in`` is not modified.
  """
  # The lookarounds anchor the match on whole symbols, so merging ('a', 'b')
  # does not also rewrite the tail of the symbol 'xa' followed by 'b'.
  pattern = re.compile(r'(?<!\S)' + re.escape(' '.join(best_pair)) + r'(?!\S)')
  replacement = ''.join(best_pair)

  # A callable replacement is inserted literally; a plain string template
  # would have backslashes inside the symbols interpreted as escapes.
  def insert_merged(_match):
    return replacement

  vocab_out = {}
  for word_in, freq in vocab_in.items():
    vocab_out[pattern.sub(insert_merged, word_in)] = freq

  return vocab_out

In [51]:
# One training step: pick the most frequent pair and look up its probability.
best_pair = max(pair_stats_prob, key=pair_stats.get)
print("best_pair: ", best_pair)
prob_merge_pair = pair_stats_prob[best_pair]
print("prob_merge_pair: ", prob_merge_pair)

# Merge only when the pair's probability exceeds the threshold.
# NOTE(review): the original note read "0.015: dropout_ratio", but the value
# compared against is 0.005 - confirm which threshold is intended.
if prob_merge_pair > 0.005:
  new_vocab = merge_vocab(best_pair, vocab)
  print("new_vocab: ", new_vocab)

best_pair:  ('e', 's')
prob_merge_pair:  0.019774309284310326
new_vocab:  {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w es t </w>': 6, 'w i d es t </w>': 3, 'h a p p i e r </w>': 2}


Phần code ở trên thể hiện cho 1 iter. Bây giờ sẽ tăng số iter lên và xem kết quả.

In [114]:
# vocab = {
#     'l o w </w>': 5,
#     'l o w e r </w>': 2,
#     'n e w e s t </w>': 6,
#     'w i d e s t </w>': 3,
#     'h a p p i e r </w>': 2
# }

In [115]:
# Learn the merge table: repeatedly merge the most frequent pair among those
# whose probability exceeds dropout_ratio, until no pair qualifies.
bpe_codes = {}
dropout_ratio = 0.005

# Rank given to the next learned merge. It is kept under a name that does not
# shadow the built-in ``iter``.
merge_rank = 0
while True:
  pair_stats = get_pair_stats(vocab)

  # Keep only the pairs whose probability clears the threshold.
  pair_stats_prob = {}
  for pair in pair_stats:
    prob_merge_pair = compute_pair_probability(pair)
    if prob_merge_pair > dropout_ratio:
      pair_stats_prob[pair] = prob_merge_pair

  if not pair_stats_prob:
    break

  # Past the emptiness check there is always at least one candidate, so a
  # merge happens on every pass. Candidates are ranked by their frequency.
  best_pair = max(pair_stats_prob, key=pair_stats.get)
  bpe_codes[best_pair] = merge_rank
  vocab = merge_vocab(best_pair, vocab)
  merge_rank += 1

# If no pair ever cleared the threshold, fall back to ranking every pair of
# the current vocabulary in iteration order.
if not bpe_codes:
  pair_stats = get_pair_stats(vocab)
  bpe_codes.update((pair, rank) for rank, pair in enumerate(pair_stats))

print("\n vocab: ", vocab)
print('byte pair encoding: ', bpe_codes)


 vocab:  {'low </w>': 5, 'low er </w>': 2, 'newest </w>': 6, 'wi d est </w>': 3, 'h a p p i er </w>': 2}
byte pair encoding:  {('e', 's'): 0, ('es', 't'): 1, ('l', 'o'): 2, ('lo', 'w'): 3, ('n', 'e'): 4, ('ne', 'w'): 5, ('new', 'est'): 6, ('e', 'r'): 7, ('w', 'i'): 8}


# Encode

In [116]:
# Split the example word into characters and close it with the end-of-word
# marker; the bare expression displays the list in the notebook.
original_word = 'lowest'
word = [*original_word, '</w>']
word

['l', 'o', 'w', 'e', 's', 't', '</w>']

In [117]:
def get_pairs(word: List[str]) -> Set[Tuple[str, str]]:
  """Return the set of adjacent symbol pairs in the input word.

  Args:
    word: sequence of symbols, for example ``['l', 'o', 'w', '</w>']``.

  Returns:
    Every distinct ``(symbol, next_symbol)`` pair; an empty set when ``word``
    holds fewer than two symbols.
  """
  # zip stops at the shorter argument, so an empty or single-symbol word
  # yields no pairs instead of failing on word[0].
  return set(zip(word, word[1:]))

In [118]:
# Collect the adjacent symbol pairs of the example word and display them.
pairs = get_pairs(word)
pairs

{('e', 's'), ('l', 'o'), ('o', 'w'), ('s', 't'), ('t', '</w>'), ('w', 'e')}

In [119]:
# Keep the word's pairs that appear in bpe_codes, each with its stored rank.
bpe_codes_pairs = [(pair, bpe_codes[pair]) for pair in pairs if pair in bpe_codes]
print("bpe_codes_pairs: ", bpe_codes_pairs)
# The pair with the lowest rank is selected for merging.
pair_to_merge = min(bpe_codes_pairs, key=itemgetter(1))[0]
print("pair_to_merge: ", pair_to_merge)

bpe_codes_pairs:  [(('e', 's'), 0), (('l', 'o'), 2)]
pair_to_merge:  ('e', 's')


In [120]:
def create_new_word(word: List[str], pair_to_merge: Tuple[str, str]) -> List[str]:
  """Return a copy of ``word`` with each ``pair_to_merge`` occurrence joined.

  Args:
    word: sequence of symbols.
    pair_to_merge: the two adjacent symbols to replace by their concatenation.

  Returns:
    A new list; ``word`` itself is left untouched. Occurrences are merged
    left to right without overlap, so ``['a', 'a', 'a']`` with ``('a', 'a')``
    becomes ``['aa', 'a']``.
  """
  first_char, second_char = pair_to_merge
  merged_symbol = first_char + second_char
  last_index = len(word) - 1

  new_word = []
  i = 0
  while i < len(word):
    # A merge needs the first symbol here and the second one right after it.
    if i < last_index and word[i] == first_char and word[i + 1] == second_char:
      new_word.append(merged_symbol)
      i += 2
    else:
      new_word.append(word[i])
      i += 1

  return new_word

In [121]:
# Apply the selected merge once to the example word and display the result.
new_word = create_new_word(word, pair_to_merge)
new_word

['l', 'o', 'w', 'es', 't', '</w>']

Phần code trên là quá trình thực hiện cho việc encode 1 từ qua 1 iter. Bây giờ thực hiện với nhiều iter.

In [122]:
def encode(original_word: str, bpe_codes: Dict[Tuple[str, str], int], len_word_split=1) -> List[str]:
  """Segment ``original_word`` by applying the merges in ``bpe_codes`` by rank.

  Args:
    original_word: the word to segment.
    bpe_codes: mapping from a symbol pair to its merge rank; the pair with the
      lowest rank present in the word is merged first.
    len_word_split: a word of exactly this length is returned unsplit.

  Returns:
    The list of symbols. A split word ends with the ``'</w>'`` marker; an
    unsplit word comes back as a one-element list without the marker.
  """
  # Wrap the unsplit word in a list so every path returns List[str] as
  # annotated, rather than a bare str on this branch only.
  # NOTE(review): the original comment described len_word_split as a minimum
  # word size, but the check is equality - confirm whether `<=` was intended.
  if len(original_word) == len_word_split:
    return [original_word]

  word = [*original_word, '</w>']

  while True:
    candidates = [pair for pair in get_pairs(word) if pair in bpe_codes]
    if not candidates:
      break

    # Apply the lowest-rank merge available in the current segmentation.
    pair_to_merge = min(candidates, key=bpe_codes.get)
    word = create_new_word(word, pair_to_merge)

  return word

In [123]:
# Encode the example word with bpe_codes; the bare call displays the result.
original_word = 'lowest'
encode(original_word, bpe_codes)

['low', 'est', '</w>']