In [30]:
import nltk
import numpy as np
import re

from collections import Counter
from math import exp, log
from sklearn.base import TransformerMixin

In [35]:
text = open('warandpeace.txt', 'r', encoding="UTF-8").read()[2:]
len(text)

3408858

In [36]:
import string
import re

def preprocess_text(text):
    text = text.lower()
 #   punct = "“”‘’—" + string.punctuation.replace('.', '')
#    for item in punct:
    punctuation = '\]|!|"|#|\$|%|&|\'|\(|\)|\*|\+|,|-|/|:|;|<|=|>|\?|@|\[|\\|^|_|`|{|\||}|~|”|“|—|‘|’'
    text = text.replace(punctuation, " ")
    return re.sub('\s+', ' ', text)

text = preprocess_text(text)


In [37]:
text = text.split('.')
text = [x.strip() for x in text]

In [38]:
from progressbar import Percentage, Bar, ETA, FileTransferSpeed, ProgressBar
from IPython.display import clear_output

widgets = [Percentage(), ' ', Bar(marker='0', left='[', right=']'), ' ', ETA(), ' ', FileTransferSpeed()]

In [7]:
!pip install progressbar

Collecting progressbar
  Downloading https://files.pythonhosted.org/packages/a3/a6/b8e451f6cff1c99b4747a2f7235aa904d2d49e8e1464e0b798272aa84358/progressbar-2.5.tar.gz
Building wheels for collected packages: progressbar
  Building wheel for progressbar (setup.py) ... [?25ldone
[?25h  Stored in directory: /Users/Asalamatina/Library/Caches/pip/wheels/c0/e9/6b/ea01090205e285175842339aa3b491adeb4015206cda272ff0
Successfully built progressbar
Installing collected packages: progressbar
Successfully installed progressbar-2.5


In [39]:
from collections import Counter
import nltk
from sklearn.base import TransformerMixin
from copy import deepcopy


class BPE(TransformerMixin):
    def __init__(self, vocab_size=100):
        super(BPE, self).__init__()
        self.vocab_size = vocab_size
        # index to token
        self.itos = []
        # token to index
        self.stoi = {}
        
    @staticmethod
    def update_encoding(text, new_token, new_id):
        new_text, i = [], 0
        
        while i < len(text):
            if i == len(text) - 1:
                new_text.append(text[i])
            elif (text[i], text[i + 1]) == new_token:
                new_text.append(new_id)
                i += 1
            else:
                new_text.append(text[i])
            i += 1
        return new_text
        
    def fit(self, text):
        """
        fit itos and stoi
        text: list of strings 
        """
        
        # TODO
        # tokenize text by symbols and fill in self.itos and self.stoi
        text_ = deepcopy(" ".join(text))
        self.itos = list(set(list(text_)))
        self.stoi = {token: i for i, token in enumerate(self.itos)}
        text_ = [self.stoi[char] for char in text_]
        
        pbar = ProgressBar(widgets=["fitting: "] + widgets, maxval=self.vocab_size)
        pbar.start()
        
        while len(self.itos) < self.vocab_size:
            new_token = Counter([(text_[i], text_[i+1]) for i in range(len(text_) - 1)]).most_common(1)[0]
            new_id = len(self.itos)
            
            self.itos.append(new_token)
            self.stoi[new_token] = new_id
            
            # find occurences of the new_token in the text and replace them with new_id
            text_ = self.update_encoding(text_, new_token, new_id)
            pbar.update(len(self.itos))
        pbar.finish()
        return self
    
    def transform(self, text):
        """
        convert text to a sequence of token ids
        text: list of strings
        """
        clear_output()
        text_ =  deepcopy(text)
        pbar = ProgressBar(widgets=["transforming: "] + widgets, maxval=len(text))
        pbar.start()
        for i, sent in enumerate(text_):
            token_sent = [self.stoi[char] for char in sent]
            for token_id, token in enumerate(self.itos):
                text_[i] = self.update_encoding(token_sent, token, token_id)
            pbar.update(i)
        pbar.finish()
        return text_
    
    def decode_token(self, tok):
        """
        tok: int or tuple
        """
        result = ""
        
        def recursive_search(token):
            if type(token) == str:
                nonlocal result
                result += token
            elif type(token) == int:
                recursive_search(self.itos[token])
            else:
                for el in token:
                    recursive_search(token)
                    
        recursive_search(tok)
        return result
            
    def decode(self, text):
        """
        convert token ids into text
        """
        return ''.join(map(self.decode_token, text))
    
vocab_size = 100
bpe = BPE(vocab_size)
tokenized_text = bpe.fit_transform(text)

transforming: 100% [00000000000000000000000000000000] Time: 0:01:43 298.71  B/s


In [18]:
assert bpe.decode(tokenized_text[0]) == text[0]

In [19]:
import numpy as np
        
    
start_token = vocab_size
end_token = vocab_size + 1
        
    
class LM:
    def __init__(self, vocab_size, delta=1):
        self.delta = delta
        self.vocab_size = vocab_size + 2
        self.proba = {}
        
    def smoothen_count(self, count, tau):
        return (count + self.delta) ** (1/tau)
        
    def infer(self, a, b, tau=1):
        """
        return vector of probabilities of size self.vocab for 3-grams which start with (a,b) tokens
        a: first token id
        b: second token id
        tau: temperature
        """
        result = []
        for token in range(self.vocab_size):
            result.append(self.get_proba(a, b, token, tau))
        return np.array(result)
        
    def get_proba(self, a, b, c, tau=1):
        """
        get probability of 3-gram (a,b,c)
        a: first token id
        b: second token id
        c: third token id
        tau: temperature
        """
        all_tri_proba = []
        for token in range(self.vocab_size):
            all_tri_proba.append(self.smoothen_count(self.proba[(a, b, token)], tau))
        return self.smoothen_count(self.proba[(a, b, c)], tau) / sum(all_tri_proba)
    
    def fit(self, text):
        """
        train language model on text
        text: list of lists
        """
        
        trigrams = []
        for sent in text:
            pre_sent = [start_token] + sent + [end_token]
            for i in range(len(pre_sent) - 2):
                trigrams.append((pre_sent[i], pre_sent[i + 1], pre_sent[i + 2]))
        self.proba = Counter(trigrams)
        
        return self
    
lm = LM(vocab_size, 1).fit(tokenized_text)

In [20]:
def get_top_k_probs(probs, k):
    out = []
    sorted_probs = sorted(probs, reverse=True)
    for i in range(k):
        out += np.argwhere(probs == sorted_probs[i]).flatten().tolist()
    return out[:k]

In [21]:
def beam_search(input_seq, lm, max_len=10, k=5, tau=1):
    """
    generate sequence from language model *lm* conditioned on input_seq
    input_seq: sequence of token ids for conditioning
    lm: language model
    max_len: max generated sequence length
    k: size of beam
    tau: temperature
    """
    probs = np.log(lm.infer(input_seq[-2], input_seq[-1], tau))
    best_probs = get_top_k_probs(probs, k)
    beam = [(input_seq + [tok], probs[tok]) for tok in best_probs]
    
    for i in range(max_len):
        candidates = []
        candidates_proba = []
        for snt, snt_proba in beam:
            if snt == end_token:
                continue
            else:    
                proba = lm.infer(snt[-2], snt[-1], tau)
                best_k = get_top_k_probs(proba, k)
                candidates += [snt + [token] for token in best_k]
                candidates_proba += [snt_proba + np.log(proba)[snt] for snt in best_k]

        idxs = get_top_k_probs(np.array(candidates_proba), k)        
        beam = [(candidates[k], candidates_proba[k]) for k in idxs]
    return beam

In [22]:
input1 = 'horse '
input1 = bpe.transform([input1])[0]
result = beam_search(input1, lm, max_len=10, k=10, tau=0.1)
for pair in result:
    print(f"sent: {bpe.decode(pair[0])}; log proba {pair[1]}")

transforming: 100% [00000000000000000000000000000000] Time: 0:00:00 347.12  B/s


sent: horse and the and; log proba -0.7874070258780741
sent: horse the and the; log proba -1.9460246479040955
sent: horse and the the; log proba -1.9460246479040957
sent: horse the the and; log proba -3.47621908548869
sent: horse and the sai; log proba -4.212776190427485
sent: horse was and the; log proba -4.359988315344083
sent: horse said the an; log proba -4.466179590784896
sent: horse the the the; log proba -4.634836707514712
sent: horse and the who; log proba -4.852139260100694
sent: horse and the so ; log proba -5.013154211615189


In [23]:
input1 = 'her'
input1 = bpe.transform([input1])[0]
result = beam_search(input1, lm, max_len=10, k=10, tau=0.1)
for pair in result:
    print(f"sent: {bpe.decode(pair[0])}; log proba {pair[1]}")

transforming: 100% [00000000000000000000000000000000] Time: 0:00:00 473.72  B/s


sent: her the and th; log proba -0.725990954303979
sent: her and the an; log proba -1.7070708592891286
sent: her the the an; log proba -2.2561846551558027
sent: her and the th; log proba -2.865689218047921
sent: her the the th; log proba -3.414803013914595
sent: her the said t; log proba -4.404737092177984
sent: her the was an; log proba -4.670148322595789
sent: her and the sa; log proba -5.1324122389981675
sent: her and the wh; log proba -5.244115031429863
sent: her and the so; log proba -5.45957057599175


In [24]:
input1 = 'what'
input1 = bpe.transform([input1])[0]
result = beam_search(input1, lm, max_len=10, k=10, tau=1)
for pair in result:
    print(f"sent: {bpe.decode(pair[0])}; log proba {pair[1]}")

transforming: 100% [00000000000000000000000000000000] Time: 0:00:00 414.95  B/s


sent: what of the the; log proba -9.629208191651452
sent: what the and th; log proba -9.701817500490872
sent: what of the and; log proba -10.070307580990702
sent: what the the th; log proba -10.093799431363207
sent: what the the an; log proba -10.460193178758342
sent: what the and an; log proba -10.61782262685417
sent: what and the an; log proba -10.640143023730092
sent: what the and to; log proba -10.752954893403667
sent: what the the to; log proba -11.144936824276002
sent: what of the tha; log proba -11.301791485156114


In [25]:
input1 = 'gun '
input1 = bpe.transform([input1])[0]
result = beam_search(input1, lm, max_len=10, k=10, tau=0.1)
for pair in result:
    print(f"sent: {bpe.decode(pair[0])}; log proba {pair[1]}")

transforming: 100% [00000000000000000000000000000000] Time: 0:00:00 542.32  B/s


sent: gun the and the; log proba -0.42396723466410147
sent: gun the the and; log proba -1.9541616722486956
sent: gun the the the; log proba -3.1127792942747172
sent: gun the said th; log proba -4.102740536303693
sent: gun and the and; log proba -4.204157793821503
sent: gun the was and; log proba -4.3681253396886826
sent: gun the so the ; log proba -4.663512655169986
sent: gun the and and; log proba -4.761463402319709
sent: gun the the sai; log proba -5.379530836798107
sent: gun the the who; log proba -6.018893906471315


In [26]:
from math import log, exp

def perplexity(snt, lm):
    """
    snt: sequence of token ids
    lm: language model
    """
    perplexity = 0
    
    snt = [start_token] + snt + [end_token]

    for char in range(len(snt) - 2):
        perplexity += log((1 / lm.infer(snt[char], snt[char + 1])[snt[char + 2]]))
    result = pow(perplexity, -1 / float(len(snt)))
    return exp(result)

perplexity(tokenized_text[0], lm)

2.677992011649401