##### Implementation of n-gram model:
1. Process text: split into sentences and tokens, augment sentences with n-1 SOS tokens and one EOS token. Optionally, you can only replace rare words with token UNK.
2. Compute n-grams from your text
3. Compute probabilities for your n-grams using smoothing
4. using formula for n-gram, generate several sentences

In [253]:
import re
import requests
import nltk
from nltk.tokenize import sent_tokenize, word_tokenize 
from nltk.stem.porter import PorterStemmer
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords
import math

In [254]:
def clean_text(paragraph, n):
    p = paragraph[n]

    href_pattern = r"<a href=[^>]+>([^<]+)</a>"
    p = re.sub(href_pattern, r"\1", p)

    boldface_pattern = r"<b>([^<]+)</b>"
    p = re.sub(boldface_pattern, r"\1", p)

    italic_pattern = r"<i>([^<]+)</i>"
    p = re.sub(italic_pattern, r"\1", p)

    span_pattern = r"<span\s[^<]+</span>"
    p = re.sub(span_pattern, "", p)

    sup_pattern = r"<sup\s[^<]+</sup>"
    p = re.sub(sup_pattern, "", p)

    return p

In [255]:
#r = requests.get(url="https://en.wikipedia.org/wiki/Machine_learning")
r = requests.get(url="https://en.wikipedia.org/wiki/Rock_music")
wiki_text = r.text

In [256]:
paragraph = re.findall(r"<p>(.*)\n", wiki_text)

In [257]:
# get text
t = []
paragraphs_num = range(len(paragraph))

for n in paragraphs_num:
    p = clean_text(paragraph, n)
    t.append(p)

try:
    t.remove('')
except:
    pass

In [258]:
SOS = '<s> '
EOS = '</s>'
UNK = '<unk>'

def add_tokens(s, n):
    sos = SOS * (n-1) if n > 1 else SOS
    return ['{}{} {}'.format(sos, sent, EOS) for sent in s]

In [259]:
# get sentences
s = []
for p in t:
    p_sent_nltk = sent_tokenize(p)
    for sent in p_sent_nltk:    
        s.append(sent)

s = add_tokens(s, 3)
s = [sent.lower() for sent in s]

In [260]:
def replace_single(w):
    vocab = nltk.FreqDist(w)
    return [token if vocab[token] > 1 else UNK for token in w]

In [261]:
# get tokens
w = ' '.join(s).split(' ')
w = replace_single(w)

In [262]:
# get vocabulary
vocab = nltk.FreqDist(w)
vocab

FreqDist({'<unk>': 2878, 'the': 1081, '<s>': 960, 'and': 760, 'of': 507, '</s>': 480, 'in': 337, 'rock': 269, 'a': 261, 'to': 251, ...})

In [263]:
n = 3
vocab_size = len(vocab)

n_grams = nltk.ngrams(w, n)
n_vocab = nltk.FreqDist(n_grams)

n1_grams = nltk.ngrams(w, n-1)
n1_vocab = nltk.FreqDist(n1_grams)

In [264]:
def smooth(n_gram, n_count, k, n1_vocab, vocab_size):
    n1_gram = n_gram[:-1]
    n1_count = n1_vocab[n1_gram]
    return -math.log((n_count + k) / (n1_count + k * vocab_size))

In [265]:
prob = {n_gram: smooth(n_gram, c, 1, n1_vocab, vocab_size) for n_gram, c in n_vocab.items()}
prob

{('<s>', '<s>', 'rock'): 4.791926339849912,
 ('<s>', 'rock', 'is'): 6.508769136971682,
 ('rock', 'is', 'a'): 6.501289670540389,
 ('is', 'a', 'broad'): 6.503539390274405,
 ('a', 'broad', 'genre'): 6.499787040655854,
 ('broad', 'genre', 'of'): 6.4990348781533,
 ('genre', 'of', 'popular'): 6.501289670540389,
 ('of', 'popular', 'music'): 6.501289670540389,
 ('popular', 'music', 'that'): 6.09732493780746,
 ('music', 'that', 'originated'): 6.501289670540389,
 ('that', 'originated', 'as'): 6.4990348781533,
 ('originated', 'as', '"rock'): 6.4990348781533,
 ('as', '"rock', 'and'): 6.4990348781533,
 ('"rock', 'and', 'roll"'): 5.808142489980444,
 ('and', 'roll"', 'in'): 6.501289670540389,
 ('roll"', 'in', 'the'): 6.4990348781533,
 ('in', 'the', 'united'): 4.99991133073328,
 ('the', 'united', 'states'): 5.256006168476314,
 ('united', 'states', 'in'): 6.502790045915623,
 ('states', 'in', 'the'): 6.4990348781533,
 ('in', 'the', 'late'): 4.663439094112067,
 ('the', 'late', '1940s'): 6.115156019450532

In [266]:
def choise(model, prev_w, i, options_bad):
    w_bad = ['<unk>'] + options_bad
    w_options = ((ngram[-1], prob) for ngram, prob in model.items() if ngram[:-1]==prev_w)
    w_options = filter(lambda word: word[0] not in w_bad, w_options)
    w_options = sorted(w_options, key=lambda word: word[1], reverse=True)
    if len(w_options) == 0:
        return ('</s>', 1)
    else:
        return w_options[0 if prev_w != () and prev_w[-1] != '<s>' else i]

In [267]:
def generate(model, n, num_s, len_s=[10, 20]):
    for i in range(num_s):
        sent, prob = ['<s>'] * max(1, n-1), 1
        while sent[-1] != '</s>':
            prev_w = () if n == 1 else tuple(sent[-(n-1):])
            options_bad = sent + (['</s>'] if len(sent) < len_s[0] else [])
            next_w, next_prob = choise(model, prev_w, i, options_bad)
            sent.append(next_w)
            prob += next_prob

            if len(sent) >= len_s[1]:
                sent.append('</s>')

        print(' '.join(sent), 1/prob)


In [269]:
generate(prob, 3, 5)

<s> <s> further fusion subgenres have since emerged, including pop-punk, electronic rock, rap </s> 0.013550515520893976
<s> <s> melodies often </s> 0.06533443550881976
<s> <s> harmonies range from the mid-1960s, particularly in california and texas. </s> 0.013730218879121516
<s> <s> as a rock group with electric bass guitar, which emerged in its modern form </s> 0.01080918121714382
<s> <s> christgau, writing in christgau's record guide: the </s> 0.02127971206202656
