In [None]:
import nltk
# Requests NLTK's "all" collection rather than individual packages.
# NOTE(review): the next cell downloads gutenberg, punkt, stopwords, wordnet
# and omw-1.4 one by one, so this blanket download may be redundant — confirm
# before trimming it.
nltk.download('all')

In [None]:
import nltk, random, math
from collections import Counter, defaultdict
from nltk.corpus import gutenberg, stopwords
from nltk.tokenize import sent_tokenize, word_tokenize

# Fetch the NLTK resources this notebook relies on (needed on a first run).
for _resource in ('gutenberg', 'punkt', 'stopwords', 'wordnet', 'omw-1.4'):
    nltk.download(_resource)

from nltk.stem import PorterStemmer, WordNetLemmatizer

# --- Preprocessing switches (read by clean_tokenize and the UNK step) ---
# STEM: run each kept token through the Porter stemmer.
STEM = False
# LEMMATIZE: run each kept token through the WordNet lemmatizer
# (applied after stemming when both switches are on).
LEMMATIZE = True
# REMOVE_STOPWORDS: drop tokens that appear in `stops`.
REMOVE_STOPWORDS = False
# USE_UNK: replace out-of-vocabulary words with '<UNK>' in all three splits.
USE_UNK = True
# UNK_THRESH: a word stays in the vocabulary only if its training count is
# strictly greater than this value.
UNK_THRESH = 3

# Shared normalisers and the English stopword set used by clean_tokenize.
stemmer = PorterStemmer()
lemmatizer = WordNetLemmatizer()
stops = set(stopwords.words('english'))

# Load the raw Moby Dick text from the Gutenberg corpus and split it into
# sentence strings.
text = gutenberg.raw('melville-moby_dick.txt')
sentences = sent_tokenize(text)
import string

def clean_tokenize(sent):
    """Lowercase and tokenize one sentence, then normalise its tokens.

    Tokens consisting solely of punctuation characters are dropped. Depending
    on the module-level switches, stopwords are removed (``REMOVE_STOPWORDS``)
    and the remaining tokens are stemmed (``STEM``) and/or lemmatized
    (``LEMMATIZE``); stemming runs first when both are enabled.

    Args:
        sent: Raw sentence text.

    Returns:
        List of processed token strings in their original order.
    """
    processed = []
    for t in word_tokenize(sent.lower()):
        # Check every character instead of `t in string.punctuation`: that is
        # a substring test against the punctuation string, so multi-character
        # punctuation tokens such as "--", "``", "''" or "..." slipped through.
        if all(ch in string.punctuation for ch in t):
            continue
        if REMOVE_STOPWORDS and t in stops:
            continue
        if STEM:
            t = stemmer.stem(t)
        if LEMMATIZE:
            t = lemmatizer.lemmatize(t)
        processed.append(t)
    return processed

# Tokenize every non-empty sentence string.
tokenized = list(map(clean_tokenize, filter(None, sentences)))

# Fixed seed so the shuffled order, and therefore the split, is reproducible.
random.seed(42)
random.shuffle(tokenized)

# 80% train / 10% validation / 10% test, cut by sentence index.
split = int(0.8 * len(tokenized))
split_valid = int(0.9 * len(tokenized))
train_toks, valid_toks, test_toks = (
    tokenized[:split],
    tokenized[split:split_valid],
    tokenized[split_valid:],
)

# UNK treatment: word frequencies are taken from the training split only.
word_counts = Counter()
for _train_sent in train_toks:
    word_counts.update(_train_sent)
def replace_unk(s, vocab):
    """Return a new list where tokens of `s` missing from `vocab` become '<UNK>'."""
    mapped = list(s)
    for idx, token in enumerate(mapped):
        if token not in vocab:
            mapped[idx] = '<UNK>'
    return mapped

if not USE_UNK:
    # Without UNK replacement the vocabulary is every word seen in training.
    vocab = {w for s in train_toks for w in s}
else:
    # Keep only words seen more than UNK_THRESH times in training, then map
    # everything else to '<UNK>' in all three splits.
    vocab = {w for w, c in word_counts.items() if c > UNK_THRESH}
    train_toks, valid_toks, test_toks = (
        [replace_unk(s, vocab) for s in part]
        for part in (train_toks, valid_toks, test_toks)
    )

def add_tokens(data, n):
    """Pad each sentence with n-1 leading '<s>' markers and one trailing '</s>'.

    Args:
        data: Iterable of token lists.
        n: N-gram order that determines how many start markers are prepended.

    Returns:
        New list of padded token lists; the input sentences are not modified.
    """
    start_pad = ['<s>'] * (n - 1)
    padded = []
    for sent in data:
        padded.append(start_pad + sent + ['</s>'])
    return padded

# Hyperparameter: select n by validation set perplexity
def build_and_eval(n):
    """Train an interpolated add-k n-gram model and score it on validation data.

    Reads the module-level ``train_toks``, ``valid_toks`` and ``vocab`` and
    pads sentences with ``add_tokens``. N-gram counts for every order from 1
    to ``n`` are collected from the padded training sentences; each order gets
    an add-k (k = 0.1) smoothed estimate, and the estimates are mixed with
    fixed interpolation weights.

    Args:
        n: Highest n-gram order to interpolate; must be 2, 3 or 4.

    Returns:
        Perplexity of the padded validation sentences under the model.

    Raises:
        ValueError: If no interpolation weights are defined for ``n``, or if
            there are no validation tokens to score.
    """
    # Interpolation weights per model order, listed from unigram upwards.
    weights_by_order = {
        2: [0.1, 0.9],
        3: [0.1, 0.3, 0.6],
        4: [0.1, 0.2, 0.3, 0.4],
    }
    if n not in weights_by_order:
        raise ValueError(
            f"unsupported n-gram order {n!r}; expected one of {sorted(weights_by_order)}"
        )
    lambdas = weights_by_order[n]

    train_sents = add_tokens(train_toks, n)
    valid_sents = add_tokens(valid_toks, n)

    def count_ngrams(data, order):
        """Count n-grams of `order` and their (order-1)-token histories."""
        ngrams, hist = Counter(), Counter()
        for s in data:
            for i in range(len(s) - order + 1):
                ng = tuple(s[i:i + order])
                ngrams[ng] += 1
                hist[ng[:-1]] += 1
        return ngrams, hist

    # Only the orders that are actually interpolated are counted. For order 1
    # the history is the empty tuple, so hist[()] is the total unigram count;
    # this replaces re-summing all unigram counts on every probability lookup.
    counts = {order: count_ngrams(train_sents, order) for order in range(1, n + 1)}

    # NOTE(review): `vocab` holds word types only, so V does not count '</s>'
    # (nor '<UNK>' when USE_UNK is on) — confirm whether the add-k denominator
    # should include them.
    V = len(vocab)
    k = 0.1

    def smoothed_prob(order, context, w):
        """Add-k estimate of `w` given the last order-1 tokens of `context`."""
        ngrams, hist = counts[order]
        ctx = tuple(context[len(context) - order + 1:])
        return (ngrams.get(ctx + (w,), 0) + k) / (hist.get(ctx, 0) + k * V)

    log_prob, N = 0.0, 0
    for s in valid_sents:
        # Start after the n-1 start markers so every scored token has a full
        # history of n-1 tokens.
        for i in range(n - 1, len(s)):
            context, w = s[i - n + 1:i], s[i]
            p = 0.0
            for order, lam in enumerate(lambdas, start=1):
                p += lam * smoothed_prob(order, context, w)
            log_prob += math.log(p)
            N += 1

    if N == 0:
        raise ValueError("cannot compute perplexity: validation data is empty")
    return math.exp(-log_prob / N)

# Score each candidate order (2, 3, 4) on the validation set and keep the one
# with the lowest perplexity.
ppl_ns = dict((n, build_and_eval(n)) for n in (2, 3, 4))
best_n = min(ppl_ns, key=lambda order: ppl_ns[order])
print(f"Best n based on validation set: {best_n}\nPerplexities: {ppl_ns}")

In [None]:
# Interactive demo: read seed words from the user and print one generated
# sentence per model order.
print("\n Enter two starting words (e.g., 'the whale'):")
# The reply is lowercased and split on whitespace into a list of words.
# NOTE(review): the input is not passed through clean_tokenize or replace_unk,
# and the number of words entered is not checked — confirm that `generate`
# copes with unseen words and with fewer or more than two seed words.
user_input = input(">> ").lower().strip().split()

# `generate` is not defined in this part of the file; presumably it takes the
# n-gram order and the seed words and returns the generated text — confirm.
print("\n Sentence from Bigram:")
print(generate(2, user_input))

print("\n Sentence from Trigram:")
print(generate(3, user_input))

print("\n Sentence from 4-gram:")
print(generate(4, user_input))