In [None]:
import torch
import os
import requests
import numpy as np


In [None]:


# Download the Tiny Shakespeare dataset once and cache it on disk.
# NOTE: on POSIX, os.path.dirname("/content/sample_data") evaluates to "/content",
# so the file is cached at /content/input.txt (next to sample_data, not inside it).
input_file_path = os.path.join(os.path.dirname("/content/sample_data"), 'input.txt')
if not os.path.exists(input_file_path):
    data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    # Fetch BEFORE opening the file: if the request fails we must not leave an
    # empty (or error-page) file behind, because the existence check above would
    # treat it as a valid cached dataset on every later run.
    response = requests.get(data_url, timeout=30)
    response.raise_for_status()
    with open(input_file_path, 'w', encoding='utf-8') as f:
        f.write(response.text)

# Load the whole corpus into memory as a single string.
with open(input_file_path, 'r', encoding='utf-8') as f:
    data = f.read()
n = len(data)  # corpus length in characters
# train_data = data[:int(n*0.9)]
# val_data = data[int(n*0.9):]



In [None]:
# Total number of characters in the raw corpus.
n

1115394

In [None]:
# Peek at the first 200 characters of the raw text.
data[:200]

'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you'

In [None]:
# Scratch check: len() counts every character, including the two consecutive spaces.
t='asd  a'
len(t)

6

In [None]:
def preprocess_text(text):
    """Normalize raw text for tokenizer training.

    Lower-cases the input and turns every newline into a single space, so the
    result is one long line. Consecutive newlines therefore become consecutive
    spaces; nothing else is stripped or collapsed.

    Args:
        text: The raw input string.

    Returns:
        The normalized string.
    """
    return text.lower().replace("\n", " ")


In [None]:
# Normalized training text: the raw dataset lower-cased with newlines as spaces.
corpus = preprocess_text(data)


In [None]:
# Peek at the first 200 characters after preprocessing (lower-cased, newlines turned into spaces).
corpus[:200]

'first citizen: before we proceed any further, hear me speak.  all: speak, speak.  first citizen: you are all resolved rather to die than to famish?  all: resolved. resolved.  first citizen: first, you'

In [None]:
from collections import defaultdict

def build_vocab(text):
    """Count how often each word occurs, keyed by its character-level spelling.

    The text is split on whitespace. Each word becomes a key made of its
    characters separated by single spaces, followed by the end-of-word marker
    ``</w>`` (e.g. ``"the"`` -> ``"t h e </w>"``).

    Args:
        text: The corpus string.

    Returns:
        A ``defaultdict(int)`` mapping each space-separated symbol string to
        the number of times the word appears in ``text``.
    """
    word_counts = defaultdict(int)
    for token in text.split():
        spelled = " ".join([*token, "</w>"])
        word_counts[spelled] += 1
    return word_counts

# Character-level word-frequency table for the preprocessed corpus (no merges applied yet).
vocab = build_vocab(corpus)


In [None]:
# Number of distinct whitespace-separated words in the corpus.
len(vocab)

23641

In [None]:
def get_pair_freqs(vocab):
    """Tally adjacent symbol pairs across the vocabulary, weighted by word count.

    Args:
        vocab: Mapping from a space-separated symbol string to its frequency.

    Returns:
        A ``defaultdict(int)`` mapping ``(left, right)`` symbol tuples to the
        summed frequency of all entries in which the two symbols are adjacent.
        Pairs are inserted in first-seen order.
    """
    counts = defaultdict(int)
    for entry, weight in vocab.items():
        parts = entry.split()
        # zip with the one-shifted list yields each neighbouring pair in order.
        for left_right in zip(parts, parts[1:]):
            counts[left_right] += weight
    return counts


In [None]:
def merge_pair(pair, vocab):
    """Fuse every adjacent occurrence of ``pair`` in each vocabulary entry.

    Entries are compared symbol by symbol rather than with a substring
    replace on the space-joined string. A substring replace ignores symbol
    boundaries: with pair ``('e', 'r')`` it would rewrite ``"he r"`` (symbols
    ``he``, ``r``) to ``"her"`` even though that pair does not occur there.

    Args:
        pair: A 2-tuple ``(left, right)`` of symbols to merge.
        vocab: Mapping from a space-separated symbol string to its frequency.

    Returns:
        A new ``dict`` with the same frequencies, where each adjacent
        ``left right`` is replaced by the single symbol ``left + right``.
        Matching is left-to-right and non-overlapping. If two entries end up
        with the same key, their frequencies are summed.
    """
    left, right = pair
    fused = left + right
    merged_vocab = {}

    for word, freq in vocab.items():
        symbols = word.split()
        rebuilt = []
        i = 0
        while i < len(symbols):
            is_match = (
                i + 1 < len(symbols)
                and symbols[i] == left
                and symbols[i + 1] == right
            )
            if is_match:
                rebuilt.append(fused)
                i += 2  # skip both halves of the pair just merged
            else:
                rebuilt.append(symbols[i])
                i += 1
        new_word = " ".join(rebuilt)
        # Accumulate instead of overwrite so colliding keys never lose counts.
        merged_vocab[new_word] = merged_vocab.get(new_word, 0) + freq

    return merged_vocab


In [None]:
def train_bpe(text, num_merges):
    """Learn byte-pair-encoding merge rules from ``text``.

    Starting from the character-level vocabulary, repeatedly picks the most
    frequent adjacent symbol pair and merges it everywhere, until
    ``num_merges`` rules have been learned or no pairs remain.

    Args:
        text: The training corpus.
        num_merges: Maximum number of merge rules to learn.

    Returns:
        ``(merges, vocab)``: the list of merged pairs in the order they were
        learned, and the vocabulary after all merges were applied.
    """
    vocab = build_vocab(text)
    merges = []

    # Each pass appends exactly one rule, so len(merges) doubles as the counter.
    while len(merges) < num_merges:
        pair_freqs = get_pair_freqs(vocab)
        if len(pair_freqs) == 0:
            break

        # Ties resolve to the first pair encountered, as max() keeps the first maximum.
        best_pair, _ = max(pair_freqs.items(), key=lambda item: item[1])
        merges.append(best_pair)
        vocab = merge_pair(best_pair, vocab)

    return merges, vocab


In [None]:
# Learn up to 8000 merge rules from the preprocessed corpus (training stops early if no pairs remain).
merges, final_vocab = train_bpe(corpus, num_merges=8000)


In [None]:
# final_vocab

In [None]:
def bpe_encode(word, merges):
    """Segment a single word into BPE tokens using learned merge rules.

    The word is split into characters plus the end-of-word marker ``</w>``,
    then every rule in ``merges`` is applied in order, fusing each adjacent
    occurrence of that pair from left to right.

    Args:
        word: The word to encode (no whitespace expected).
        merges: Ordered list of ``(left, right)`` symbol pairs.

    Returns:
        The list of resulting token strings; the last one carries ``</w>``.
    """
    symbols = list(word) + ["</w>"]

    for pair in merges:
        fused = "".join(pair)
        i = 0
        while i < len(symbols) - 1:
            if symbols[i] == pair[0] and symbols[i + 1] == pair[1]:
                # Replace the two symbols in place; i is not advanced so the
                # freshly fused symbol is re-examined against its new neighbour.
                symbols[i:i+2] = [fused]
            else:
                i += 1

    return symbols


In [None]:
# print(bpe_encode("@strb", merges))
# final_vocab
# Learned merge rules, listed in the order they were learned.
merges

[('e', '</w>'),
 ('t', 'h'),
 (',', '</w>'),
 ('t', '</w>'),
 ('s', '</w>'),
 ('d', '</w>'),
 ('o', 'u'),
 ('a', 'n'),
 ('e', 'r'),
 ('i', 'n'),
 ('y', '</w>'),
 (':', '</w>'),
 ('o', 'r'),
 ('e', 'n'),
 ('o', '</w>'),
 ('a', 'r'),
 ('.', '</w>'),
 ('o', 'n'),
 ('l', 'l'),
 ('th', 'e</w>'),
 ('h', 'a'),
 ('an', 'd</w>'),
 ('e', 's'),
 ('i', 's</w>'),
 ('y', 'ou'),
 ('f', '</w>'),
 ('t', 'o</w>'),
 ('i', '</w>'),
 ('in', 'g'),
 ('ll', '</w>'),
 ('n', 'o'),
 ('w', 'i'),
 ('e', 'a'),
 ('o', 'm'),
 ('e', ',</w>'),
 ('o', 'f</w>'),
 ('s', 't'),
 (';', '</w>'),
 ('er', '</w>'),
 ('r', '</w>'),
 ('th', '</w>'),
 ('m', 'y</w>'),
 ('a', '</w>'),
 ('h', 'i'),
 ('l', 'i'),
 ('v', 'e</w>'),
 ('in', '</w>'),
 ('o', 'w'),
 ('s', 'e'),
 ('r', 'i'),
 ('t', 'i'),
 ('c', 'h'),
 ('you', '</w>'),
 ('?', '</w>'),
 ('tha', 't</w>'),
 ('th', 'e'),
 ('r', 'e'),
 ('m', 'a'),
 ('l', 'e'),
 ('b', 'u'),
 ('!', '</w>'),
 ('s', 'h'),
 ('o', 'o'),
 ('g', 'h'),
 ('d', ',</w>'),
 ('a', 's</w>'),
 ('b', 'e'),
 ('w', 'h

In [None]:
def build_tokenizer(final_vocab):
    """Build token<->id lookup tables from a merged BPE vocabulary.

    Args:
        final_vocab: Mapping whose keys are space-separated symbol strings
            (only the keys are used).

    Returns:
        ``(token2id, id2token)``: ``token2id`` maps every distinct symbol to a
        consecutive integer id assigned in sorted order, and ``id2token`` is
        its inverse.
    """
    tokens = set()
    for word in final_vocab:
        tokens.update(word.split())

    # Sorting makes the id assignment deterministic across runs.
    token2id = {tok: i for i, tok in enumerate(sorted(tokens))}
    id2token = {i: tok for tok, i in token2id.items()}
    return token2id, id2token


In [None]:

# Build the token<->id lookup tables from the symbols present in the merged vocabulary.
token2id, id2token = build_tokenizer(final_vocab)
