In [86]:
from typing import List, Tuple, Dict, Callable
import nltk
from utils import strip_word

BOS = "<BOS>"
EOS = "<EOS>"
UNK = "<UNK>"
N = 3

In [87]:
def add_special_tokens(s: str, n: int) -> str:
    start = (BOS + " ") * max(n - 1, 1)
    end = f" {EOS}"
    return f"{start}{s}{end}"

def filter_rare_tokens(tokens: List[str]) -> List[str]:
    words_count = nltk.FreqDist(tokens)
    # print(words_count.items())
    return [token if words_count[token] > 1 else UNK for token in tokens]

# returns corpus and vocab tuple
def preprocess(data_file: str, n: int = 1) -> Tuple[List[str], Dict[str, int]]:
    with open(data_file, encoding='utf8') as f:
        data_ = f.read()

    data = [add_special_tokens(s.strip(), n) for s in nltk.sent_tokenize(data_)]
    sentences = [sentence.split(' ') for sentence in data]
    tokens = [strip_word(word) for sentence in sentences for word in sentence if len(word) > 1]
    tokens = filter_rare_tokens(tokens)
    return tokens, nltk.FreqDist(tokens)

In [88]:
tokens, words_count = preprocess("../data/data3.txt", N)
print(words_count.items())

dict_items([('<BOS>', 26402), ('ედუარდ', 15), ('შევარდნაძე', 18), ('<UNK>', 26587), ('ძე', 41), ('დ', 56), ('<EOS>', 13201), ('25', 49), ('იანვარი', 90), ('1928', 8), ('სოფელი', 51), ('ოზურგეთის', 10), ('მაზრა', 5), ('საქართველოს', 561), ('სსრ', 33), ('გ', 55), ('ივლისი', 18), ('2014', 52), ('თბილისი', 58), ('საქართველო', 50), ('ქართველი', 49), ('პოლიტიკოსი', 2), ('და', 6361), ('სახელმწიფო', 211), ('მოღვაწე', 21), ('წლებში', 266), ('სსრ-ის', 7), ('კომპარტიის', 7), ('პირველი', 241), ('მდივანი', 15), ('საბჭოთა', 181), ('კავშირის', 81), ('საგარეო', 32), ('საქმეთა', 17), ('მინისტრი', 7), ('პრეზიდენტი', 49), ('იყო', 830), ('სსრკ-ის', 7), ('უმაღლესი', 83), ('საბჭოს', 49), ('დეპუტატი', 3), ('სოციალისტური', 24), ('შრომის', 18), ('გმირი', 16), ('1981', 7), ('სკკპ', 4), ('პოლიტბიუროს', 6), ('წევრობის', 6), ('კანდიდატი', 4), ('1978', 8), ('წევრი', 42), ('1985', 7), ('წლის', 812), ('ივლისიდან', 8), ('საზოგადოებრივი', 39), ('წესრიგის', 7), ('დაცვის', 14), ('მინისტრის', 5), ('მოადგილე', 7), ('შინაგა

In [89]:
def get_n_gram_count(n_gram, n_count, n_minus_1_vocab, laplace=1):
    n_minus_1_gram = n_gram[:-1]
    n_minus_1_count = n_minus_1_vocab[n_minus_1_gram]
    return (n_count + laplace) / (n_minus_1_count + laplace * len(words_count))

def laplace_smooth():
    n_grams = nltk.ngrams(tokens, N)
    n_vocab = nltk.FreqDist(n_grams)

    n_minus_1_grams = nltk.ngrams(tokens, N-1)
    n_minus_1_vocab = nltk.FreqDist(n_minus_1_grams)

    return {n_gram: get_n_gram_count(n_gram, count, n_minus_1_vocab) for n_gram, count in n_vocab.items()}


In [90]:
def create_model() -> Dict[Tuple[str], int]:
    if N == 1:
        return {(token,): count / len(tokens) for token, count in words_count.items()}
    else:
        return laplace_smooth()

In [119]:
def get_filter_fn(omit_words_list: Dict[Tuple[str, str], int]) -> Callable[[str], int]:
    return lambda candidate: candidate[0] not in omit_words_list

comparison_fn = lambda candidate: candidate[1]

def best_candidate(prev, i, omit_words_list) -> str:
    model = create_model()
    print(omit_words_list)
    print("prev: ", prev)
    omit_words_list += [UNK]
    candidates = ((ngram[-1], prob) for ngram, prob in model.items() if ngram[:-1] == prev)
    print("last word: ", model)
    candidates = filter(get_filter_fn(omit_words_list), candidates)
    candidates = sorted(candidates, key=comparison_fn, reverse=True)
    print(candidates)

    if len(candidates) == 0:
        return EOS
    else:
        return candidates[0 if prev != () and prev[-1] != BOS else i][0]


In [120]:
def generate_sentences(min_len: int = 8, max_len: int = 20, init_sent: str = "") -> str:
    init_sent = [strip_word(word) for word in init_sent.split(" ")]
    sent = [BOS] * max(1, N-1) + init_sent
    while sent[-1] != EOS:
        print("1")
        prev = () if N == 1 else tuple(sent[-(N-1):])
        print(f"2 {prev}")
        blacklist = sent + ([EOS] if len(sent) < min_len else [])
        print(f"3 {blacklist}")
        next_token = best_candidate(prev, 1, omit_words_list=blacklist)
        print(f"4 {next_token}")
        sent.append(next_token)
        # prob *= next_prob

        if len(sent) >= max_len:
            sent.append(EOS)

    return ' '.join(sent)#, -1/math.log(prob)

In [108]:
generate_sentences()

1
2 ('<BOS>', '')
3 ['<BOS>', '<BOS>', '', '<EOS>']
4 საციციანო
1
2 ('', 'საციციანო')
3 ['<BOS>', '<BOS>', '', 'საციციანო', '<EOS>']
4 XIV
1
2 ('საციციანო', 'XIV')
3 ['<BOS>', '<BOS>', '', 'საციციანო', 'XIV', '<EOS>']
4 ს
1
2 ('XIV', 'ს')
3 ['<BOS>', '<BOS>', '', 'საციციანო', 'XIV', 'ს', '<EOS>']
4 <EOS>


'<BOS> <BOS>  საციციანო XIV ს <EOS>'

In [121]:
generate_sentences(init_sent="მე")

1
2 ('<BOS>', 'მე')
3 ['<BOS>', '<BOS>', 'მე', '<EOS>']


IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



'<BOS> <BOS> მე <EOS>'