In [None]:
# %pip install transformers  # uncomment on a fresh environment; %pip installs into this kernel's env

from transformers import AutoTokenizer

MODEL_CHECKPOINT = "bert-base-cased"

# Pretrained tokenizer; its pre-tokenizer is what this notebook reuses to split text into words.
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)

In [None]:
from collections import defaultdict

# Word -> number of times it occurs in the corpus. With defaultdict(int),
# an unseen word starts at 0, so the counting loop can just do `+= 1`.
freqs: "defaultdict[str, int]" = defaultdict(int)
freqs

In [None]:
from nltk.corpus import brown
import nltk
# Fetch the Brown corpus data that the `brown` corpus reader loads from.
nltk.download(info_or_id="brown")

In [None]:
# One string per Brown sentence: the sentence's tokens re-joined with single spaces.
content = [" ".join(sentence_tokens) for sentence_tokens in brown.sents()]

# Displaying the whole corpus floods the notebook; show the size and a small sample.
len(content), content[:3]

In [None]:
# Split every sentence into words with the tokenizer's pre-tokenizer and count
# how often each word occurs. `freqs` is rebuilt here (not mutated in place) so
# that re-running this cell does not double-count words.
freqs = defaultdict(int)
offsets = []  # per sentence: the (word, offsets) pairs returned by the pre-tokenizer
pre_tokenizer = tokenizer.backend_tokenizer.pre_tokenizer  # same object on every iteration

for text in content:
    words_with_offsets = pre_tokenizer.pre_tokenize_str(text)
    offsets.append(words_with_offsets)
    for word, _ in words_with_offsets:
        freqs[word] += 1

len(freqs)

In [None]:
# Pre-tokenized (word, offsets) pairs of the second sentence in the corpus.
offsets[1]


In [None]:
# Number of distinct pre-tokenized words counted in the corpus.
len(freqs)

In [None]:
# Initial WordPiece alphabet: each word's first character as-is, and every later
# character with the "##" continuation prefix. Sets give O(1) de-duplication
# (the list-membership version was quadratic), and empty words, if any, are
# skipped rather than raising IndexError on word[0].
alphabet = sorted(
    {word[0] for word in freqs if word}
    | {f"##{letter}" for word in freqs for letter in word[1:]}
)

print(alphabet)

In [None]:
# Special tokens first, then the initial character alphabet (a fresh list, so
# appending learned tokens to `vocab` never touches `alphabet`).
vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", *alphabet]

In [None]:
# Start with every word split into single characters; all characters after the
# first carry the "##" continuation prefix.
splits = {}
for word in freqs:
    chars = list(word)
    splits[word] = chars[:1] + [f"##{ch}" for ch in chars[1:]]

In [None]:
def compute_pair_scores(splits, word_freqs=None):
    """Score every adjacent token pair as a WordPiece merge candidate.

    score(a, b) = freq(a, b) / (freq(a) * freq(b)), where every count is
    weighted by how often the containing word occurs.

    Args:
        splits: Mapping word -> current list of tokens for that word. Must
            contain every word in ``word_freqs``.
        word_freqs: Mapping word -> corpus count. Defaults to the
            notebook-level ``freqs`` so existing calls keep working.

    Returns:
        Dict mapping (left_token, right_token) -> score, in first-seen order.
    """
    if word_freqs is None:
        word_freqs = freqs

    token_freqs = defaultdict(int)
    pair_freqs = defaultdict(int)

    for word, freq in word_freqs.items():
        split = splits[word]
        if not split:  # an empty word has no tokens (indexing split[-1] would fail)
            continue
        for token in split:
            token_freqs[token] += freq
        for pair in zip(split, split[1:]):
            pair_freqs[pair] += freq

    # Each token of a counted pair was itself counted, so the denominator is never 0.
    return {
        pair: freq / (token_freqs[pair[0]] * token_freqs[pair[1]])
        for pair, freq in pair_freqs.items()
    }

In [None]:
pair_scores = compute_pair_scores(splits)

# Peek at the first six scored pairs (or all of them, if there are fewer).
for pair in list(pair_scores)[:6]:
    print(f"{pair}: {pair_scores[pair]}")

In [None]:
# Highest-scoring pair. max() returns the first of several equal maxima, and
# ("", None) is the result when there are no pairs at all.
best_pair, max_score = max(
    pair_scores.items(),
    key=lambda pair_and_score: pair_and_score[1],
    default=("", None),
)

print(best_pair, max_score)

In [None]:
def merge_pair(a, b, splits):
    """Merge every adjacent occurrence of tokens ``a`` then ``b`` in ``splits``.

    A "##"-prefixed ``b`` loses its prefix when glued on ("h" + "##u" -> "hu").
    Occurrences are merged left to right without overlap.

    Args:
        a: Left token of the pair.
        b: Right token of the pair.
        splits: Mapping word -> list of tokens. Updated in place.

    Returns:
        The same ``splits`` mapping, for convenience.
    """
    # The merged token is identical for every word, so build it once.
    merged = a + b[2:] if b.startswith("##") else a + b

    # Iterate the mapping we were given (not a global) so callers control the words.
    # Only values of existing keys are reassigned, which is safe during iteration.
    for word, split in splits.items():
        if len(split) < 2:
            continue
        i = 0
        while i < len(split) - 1:
            if split[i] == a and split[i + 1] == b:
                split = split[:i] + [merged] + split[i + 2:]
            else:
                i += 1
        splits[word] = split
    return splits

In [None]:
# Starting vocabulary size: the special tokens plus the initial character alphabet.
len(vocab)

In [None]:
vocab_size = 10000
log_every = 500  # print progress every `log_every` tokens instead of every iteration
# Different pairs can merge into the same string (e.g. "a"+"##bc" and "ab"+"##c"),
# so track known tokens to keep `vocab` free of duplicates.
seen_tokens = set(vocab)

while len(vocab) < vocab_size:
    scores = compute_pair_scores(splits)
    if not scores:
        # Every word is a single token: nothing left to merge
        # (unpacking an empty best pair into merge_pair would raise TypeError).
        print(f"No pairs left to merge; stopping at vocab size {len(vocab)}")
        break

    # Highest-scoring pair; max() keeps the first pair when scores tie.
    best_pair, max_score = max(scores.items(), key=lambda pair_and_score: pair_and_score[1])
    splits = merge_pair(*best_pair, splits)
    new_token = (
        best_pair[0] + best_pair[1][2:]
        if best_pair[1].startswith("##")
        else best_pair[0] + best_pair[1]
    )

    if new_token in seen_tokens:
        # The merge still shrinks the splits, so the loop keeps making progress.
        continue
    seen_tokens.add(new_token)
    vocab.append(new_token)

    if len(vocab) % log_every == 0:
        print(f"Vocab Size: {len(vocab)}")

In [None]:
# Printing every token floods the output; show the last 50 tokens appended by training.
vocab[-50:]

In [None]:
# Final vocabulary size after training.
len(vocab)

In [None]:
import pickle

# Persist the vocabulary list with pickle (pickle.dump writes bytes, hence "wb").
# Note: only unpickle this file from a trusted source; pickle.load can run code.
with open("vocab", mode="wb") as vocab_file:
    pickle.dump(vocab, vocab_file)
