<a href="https://www.kaggle.com/code/asmitamukh/basic-bpe?scriptVersionId=145723510" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

## References

>https://huggingface.co/learn/nlp-course/chapter6/5?fw=pt

>https://huggingface.co/transformers/v3.2.0/_modules/transformers/tokenization_gpt2.html

## Imports

In [None]:
import pandas as pd
from tqdm.notebook import tqdm
import logging
import sys
import nltk
from collections import defaultdict

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
%%capture
# Fetch the NLTK data packages; the %%capture cell magic above hides the
# download output.
# NOTE(review): 'all' pulls every NLTK package, while the code in this notebook
# only calls nltk.word_tokenize -- presumably a much smaller download would be
# enough; confirm before narrowing it.
nltk.download('all')

In [None]:
from tqdm.notebook import tqdm

In [None]:
import logging
# Route DEBUG-and-above records to logs.log, each prefixed with a timestamp.
# force=True re-applies this configuration even if the root logger was already
# configured earlier in the kernel session.
logging.basicConfig(filename="logs.log",level=logging.DEBUG,format="%(asctime)s %(message)s",force=True)

## Load the data
> Download the first parquet shard of the TinyStories dataset

In [None]:
# The first four TinyStories training shards (0000.parquet .. 0003.parquet),
# generated from the shard number instead of being spelled out one by one.
urls = [
    f"https://huggingface.co/datasets/roneneldan/TinyStories/resolve/refs%2Fconvert%2Fparquet/default/train/{shard:04d}.parquet"
    for shard in range(4)
]

In [None]:
# Keep only the first shard URL; widen the slice to load more shards.
urls = urls[:1]

In [None]:
# Read every selected parquet shard straight from its URL into a DataFrame,
# log the size reported by sys.getsizeof for it, and stack all shards into a
# single frame.
list_of_downloaded_urls = []  # despite the name, this holds the loaded DataFrames
for url in tqdm(urls):
        df = pd.read_parquet(url)
        logging.info(f"downloaded file of size {sys.getsizeof(df)}")
        list_of_downloaded_urls.append(df)
# Row-wise concatenation; each shard keeps its own index values.
final_df = pd.concat(list_of_downloaded_urls)


In [None]:
# Summary of the loaded frame (columns, dtypes, non-null counts, memory usage).
final_df.info()

In [None]:
# Cutting down the data: keep only the first 30000 rows (positional slice) to
# limit how much text the cells below have to process.
final_df = final_df.iloc[:30000]

In [None]:
# Peek at the first few rows of the truncated frame.
final_df.head()

In [None]:
# Hold out 0.5% of the rows as test_corpus; the fixed random_state makes the
# split reproducible across runs.
train_corpus,test_corpus = train_test_split(final_df,test_size=0.005,random_state=2023)

In [None]:
# Row counts of the train and test splits.
len(train_corpus),len(test_corpus)

## BPE of the dataset

> Word Tokenize

>Splitting into individual chars

>Merging to build a vocab

In [None]:
def _merge_pair(tokens, pair, merged):
    """Return a copy of ``tokens`` with each occurrence of ``pair`` replaced by ``merged``.

    The scan runs left to right and never reuses a token in two merges, so
    ``['a', 'a', 'a']`` with pair ``('a', 'a')`` becomes ``['aa', 'a']``.
    """
    left, right = pair
    out = []
    i = 0
    last = len(tokens) - 1
    while i <= last:
        if i < last and tokens[i] == left and tokens[i + 1] == right:
            out.append(merged)
            i += 2  # skip both halves of the pair that was just merged
        else:
            out.append(tokens[i])
            i += 1
    return out


def train_tokenize(texts, vocab_length, min_freq=5):
    """Learn a byte-pair-encoding style vocabulary from ``texts``.

    Every text is lower-cased and split into words with ``nltk.word_tokenize``.
    Words seen fewer than ``min_freq`` times are ignored. The remaining words
    are split into characters (the base vocabulary) and the most frequent
    adjacent token pair is merged repeatedly until the vocabulary holds
    ``vocab_length`` tokens or no pair is left to merge.

    Args:
        texts: iterable of strings to train on.
        vocab_length: maximum number of tokens in the returned vocabulary. If
            the base character vocabulary is already this large, no merges
            are learned.
        min_freq: minimum number of occurrences a word needs to be kept.

    Returns:
        A tuple ``(vocab, merge_rules)``. ``vocab`` is a set of token strings.
        ``merge_rules`` is a dict mapping each merged pair ``(left, right)`` to
        the merged token ``left + right``, in the order the merges were learned.
    """
    logging.info("Pre tokenizing")
    logging.info("Counting freq of words")
    # Count while streaming through the corpus instead of first materialising
    # every token of every text in one big list.
    word_count_dict = defaultdict(int)
    total_words = 0
    for text in texts:
        for word in nltk.word_tokenize(text.lower()):
            word_count_dict[word] += 1
            total_words += 1
    logging.info(f"Total nos of words in corpus {total_words}")

    # Drop rare words in a single pass over the unique words and split the
    # kept ones into their characters.
    word_dict = dict()
    kept_words = 0
    for word, count in word_count_dict.items():
        if count < min_freq:
            logging.info(f"Removed word {word} due to low frequency")
            continue
        word_dict[word] = list(word)
        kept_words += count
    logging.info(f"Nos of words in corpus after deleting {kept_words}")

    logging.info("Building the base vocab")
    vocab = set()
    for chars in word_dict.values():
        vocab.update(chars)

    # merge rule dictionary: (left, right) -> merged token
    merge_rules = dict()
    logging.info("Starting with building merge rules")
    # Strict '<' so the vocabulary never grows past the requested size.
    while len(vocab) < vocab_length:
        logging.info("vocab updation")
        # Frequency of each adjacent token pair, weighted by word frequency.
        pair_freq = defaultdict(int)
        for word, tokens in word_dict.items():
            w_count = word_count_dict[word]
            for pair in zip(tokens, tokens[1:]):
                pair_freq[pair] += w_count

        if not pair_freq:
            # Every kept word is already a single token; nothing left to merge.
            logging.warning(f"Vocabulary cannot be extended further. Exiting . Max vocab size achieved i.e all tokens in the corpus is captured {len(vocab)}")
            break

        # max() returns the first pair reaching the top count, so ties are
        # resolved by the order in which pairs were first encountered.
        max_pair = max(pair_freq, key=pair_freq.get)
        merged = max_pair[0] + max_pair[1]
        merge_rules[max_pair] = merged

        # Apply the merge to every word. A fresh list is built per word so that
        # repeated occurrences of the pair inside one word are handled correctly.
        for word in word_dict:
            word_dict[word] = _merge_pair(word_dict[word], max_pair, merged)

        vocab.add(merged)
    logging.info(f"Vocab updated \n {vocab}")

    return vocab, merge_rules


In [None]:
def tokenize(text, rules=None):
    """Split ``text`` into words and apply the learned merge rules to each word.

    The text is lower-cased and word-tokenized with ``nltk.word_tokenize``.
    Each distinct word starts as a list of its characters, then every merge
    rule is applied in the order it appears in ``rules``.

    Args:
        text: the string to tokenize.
        rules: dict mapping ``(left, right)`` token pairs to their merged
            token, as returned by ``train_tokenize``. When omitted, the
            notebook-level ``merge_rules`` variable is used.

    Returns:
        A dict mapping each distinct word to its list of tokens. The dict is
        also printed, as before. Repeated words appear once because they are
        the dict keys.
    """
    if rules is None:
        # Fall back to the rules held in the notebook-level variable.
        rules = merge_rules

    # word tokenize the text and split each word into its characters
    words = nltk.word_tokenize(text.lower())
    word_dict = {word: list(word) for word in words}

    # Replay the merge rules in order; each rule is applied to every word.
    for pair in rules.keys():
        left, right = pair
        joined = left + right
        for word, tokens in word_dict.items():
            last = len(tokens) - 1
            merged_tokens = []
            i = 0
            while i <= last:
                if i < last and tokens[i] == left and tokens[i + 1] == right:
                    merged_tokens.append(joined)
                    i += 2  # both halves of the pair were consumed
                else:
                    merged_tokens.append(tokens[i])
                    i += 1
            word_dict[word] = merged_tokens
    print(word_dict)
    return word_dict

In [None]:
# Reduce the training split to a plain Python list of its "text" column values.
# The isinstance guard makes the cell safe to re-run: once train_corpus is
# already a list, indexing it with "text" would raise TypeError.
if not isinstance(train_corpus, list):
    train_corpus = list(train_corpus["text"])

In [None]:
## collect the distinct word tokens across the training texts
# Note: the text is not lower-cased here, unlike in train_tokenize.
uq_words = set()

for sen in tqdm(train_corpus):
    uq_words.update(nltk.word_tokenize(sen))



In [None]:
# Number of distinct word tokens found in the training texts.
print(len(uq_words))

Since the number of unique words is around 16000, the vocab size is set to around 17000 in order to capture a good number of tokens.

In [None]:
# Target vocabulary size; passed to train_tokenize as vocab_length.
vocab_size = 17000

In [None]:
%%time
# Learn the vocabulary and merge rules from the training texts (min_freq is
# left at its default); the %%time cell magic reports how long the cell took.
vocab,merge_rules = train_tokenize(train_corpus,vocab_size)

In [None]:
# How many merge rules were learned and how many tokens the vocabulary holds.
len(merge_rules),len(vocab)

In [None]:
import pickle

# Persist the learned merge rules and vocabulary next to the notebook, one
# pickle file per object.
for filename, artifact in (("merge_rules.pkl", merge_rules), ("vocab.pkl", vocab)):
    with open(filename, "wb") as f:
        pickle.dump(artifact, f)
    

### Test time

In [None]:
# @title Sample sentence for the tokenizer demo
text = "This is not a token."

In [None]:
# Run the sample sentence through the learned merge rules; tokenize prints the
# token list of every word.
tokenize(text)