In [1]:
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
import itertools
from tqdm import tqdm 

%config IPCompleter.greedy=True

In [2]:
polish_corpora_path = '../../../polish_corpora.txt'
poleval2_path = '../../../poleval_2grams.txt'
poleval3_path = '../../../poleval_3grams.txt'

## Sentences

In [3]:
sentences = []
with open('sentences.txt', encoding='utf8') as f:
    for line in f:
        line = line.strip().lower().split()
        sentences.append(line)
        
print(len(sentences))

68


## Functions

In [68]:
def generate_permutations(sentence: list) -> list:
    return list(itertools.permutations(sentence))

def order_permutations(permutations: list):
    scores = [pbb_sentence(sentence=x) for x in permutations]
    return sorted(zip(scores, permutations), reverse=True)

def score_permuatations(valid_sentence: list):
    permutations: list = generate_permutations(sentence=valid_sentence)
    permutations = order_permutations(permutations=permutations)
    permutations = [y for x, y in permutations]
    for i, sentence in enumerate(permutations):
        if tuple(sentence) == tuple(valid_sentence):
            return 1 / (i + 1)
        
    return 0

## Unigrams

In [5]:
unigrams: dict = {}  # word -> number of occurrences

with open(polish_corpora_path, encoding="utf8") as f:
    for line in tqdm(f, desc='Loading data...', position=0, leave=True):
        line = line.strip().lower().split()
        for word in line:
            _word = bytes(bytearray(word, 'UTF-8'))
            if _word in unigrams:
                unigrams[_word] += 1
            else:
                unigrams[_word] = 1

Loading data...: 23011601it [05:50, 65686.98it/s]


In [6]:
print(f'Number of unigrams: {len(unigrams)}')

Number of unigrams: 3591114


In [12]:
all_unigrams = sum([x for x in unigrams.values()])
all_unigrams

451846640

## Bigrams

In [7]:
bigrams: dict = {}  # word -> List[Tuple[word_after, number of occurrences]]
    
with open(poleval2_path, encoding="utf8") as f:
    for line in tqdm(f, desc='Loading data...', position=0, leave=True):
        line = line.strip().lower().split()
        key: bytearray = bytes(bytearray(line[1], 'UTF-8'))
        value: tuple = (bytes(bytearray(line[2], 'UTF-8')), line[0])
            
        if key in bigrams:
            bigrams[key].append(value)
        else:
            bigrams[key] = [value]

Loading data...: 59134224it [02:15, 437635.80it/s]


In [8]:
print(f'Number of bigrams: {len(bigrams)}')

Number of bigrams: 3591115


In [51]:
def cond_pbb_bigrams(w1: bytes, w2: bytes):
    """ Calculate P(w2 | w1) """
    cnt_w1_w2: int = 1
    if w1 in bigrams:
        for word_after_w1, n_ocurr in bigrams[w1]:
            if word_after_w1 == w2:
                cnt_w1_w2 = int(n_ocurr)
                break
        
    cnt_w1: int = 1
    if w1 in unigrams:
        cnt_w1 = unigrams[w1]
        
    return cnt_w1_w2 / cnt_w1


def pbb_unigram(w1: bytes):
    cnt_w1 = 1
    if w1 in unigrams:
        cnt_w1 = unigrams[w1]
        
    return cnt_w1 / all_unigrams


def pbb_sentence(sentence: list):
    """ Calculate P(w1 ... wn ) """
    n: int = len(sentence)
    assert n > 0
    sentence = list(map(lambda x: bytes(bytearray(x, encoding='UTF-8')), sentence))
    pbb = pbb_unigram(w1=sentence[0])
    for i in range(1, n):
        pbb *= cond_pbb_bigrams(w1=sentence[i - 1], w2=sentence[i])
        
    return pbb

In [63]:
sentences[0]

['judyta', 'dała', 'wczoraj', 'stefanowi', 'czekoladki']

In [78]:
score_permuatations(sentences[0])

0.010526315789473684

## Trigrams

In [None]:
trigrams: dict = {}  # Tuple[w1, w2] -> List[Tuple[w3, number of occurrences]]
total_iters, skipped1, skipped2 = 0, 0, 0

with open(poleval3_path, encoding="utf8") as f:
    for line in tqdm(f, desc='Loading data...', position=0, leave=True):
        total_iters += 1
        line = line.strip().lower().split()

        if len(line) != 4:
            skipped1 += 1
            continue
            
        if int(line[0]) == 1:
            skipped2 += 1
            continue
            
        key: tuple = (bytes(bytearray(line[1], 'UTF-8')),  bytes(bytearray(line[2], 'UTF-8')))
        value: tuple = (bytes(bytearray(line[3], 'UTF-8')), line[0])
 
        if key in trigrams:
            trigrams[key].append(value)
        else:
            trigrams[key] = [value]