In [77]:
import itertools
import re
import random
import pandas as pd
import numpy as np
import math
from decision_trees import *
from timeit import default_timer as timer

import jupyternotify
ip = get_ipython()
ip.register_magics(jupyternotify.JupyterNotifyMagics)

base_forms = ["adj", "adja", "adjc", "adjp", "adv", "burk", "depr", "ger", "conj", "comp", "num", "pact",
               "pant", "pcon", "ppas", "ppron12", "ppron3", "pred", "prep", "siebie", "subst", "verb", "brev",
               "interj", "qub"]

verb_forms = ["nom", "gen", "acc", "dat", "inst", "loc", "voc"]

raw_form = {"subst:nom":[], "subst:gen":[], "subst:acc":[], "subst:dat":[], "subst:inst":[], "subst:loc":[], "subst:voc":[],
            "adj":[], "adja":[], "adjc":[], "adjp":[], "adv":[], "burk":[], "depr":[], "ger":[], "conj":[], "comp":[], "num":[], "pact":[],
            "pant":[], "pcon":[], "ppas":[], "ppron12":[], "ppron3":[], "pred":[], "prep":[], "siebie":[], "verb":[], "brev":[],
            "interj":[], "qub":[], "target":[]}

empty_form = {"subst:nom":0, "subst:gen":0, "subst:acc":0, "subst:dat":0, "subst:inst":0, "subst:loc":0, "subst:voc":0,
            "adj":0, "adja":0, "adjc":0, "adjp":0, "adv":0, "burk":0, "depr":0, "ger":0, "conj":0, "comp":0, "num":0, "pact":0,
            "pant":0, "pcon":0, "ppas":0, "ppron12":0, "ppron3":0, "pred":0, "prep":0, "siebie":0, "verb":0, "brev":0,
            "interj":0, "qub":0, "target":0}


signs = ['.', '(', ')', ';', '"', '[', ']', ',', '?', '!', ':', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
polish = [('ź', 'z'), ('ż', 'z'), ('ą', 'a'), ('ę', 'e'), ('ó', 'o'), ('ł', 'l'), ('ć', 'c'), ('ń', 'n'), ('ś', 's')]

<IPython.core.display.Javascript object>

In [78]:
def tokenize(line):
    line2 = []
    line = line.split(' ')
    for base in line:
        base = base.lower()
        for sign in signs:
            base = base.replace(sign, ' ')
        base = base.strip()
        base = base.split(' ')
        if base != '' and base != ['']:
            line2.extend(base)

    return line2


def remove_polish(line):
    line2 = []
    for word in line:
        for sign in polish:
            word = word.replace(sign[0], sign[1])
        line2.append(word)
    return line2

In [79]:
def tokenize_big(line):
    line2 = []
    line = line.split(' ')
    for base in line:
        for sign in signs:
            base = base.replace(sign, ' ')
        base = base.strip()
        base = base.split(' ')
        if base != '' and base != ['']:
            line2.extend(base)

    return line2

In [80]:
def load_polimorph(file='polimorfologik-2.1.txt'):
    dictionary = {}
    with open(file, 'r', encoding='utf8') as base_file:
        for line in base_file:
            line = line.strip().lower()
            line = line.split(";")
            line[2] = line[2].split("+")
            nl = []
            for comp in line[2]:
                spl = comp.split(":")
                if spl[0] != "subst":
                    if spl[0] not in nl:
                        nl.append(spl[0])
                else:
                    if spl[0] + ":" + spl[2] not in nl:
                        nl.append(spl[0] + ":" + spl[2])
            line[2] = nl
            dictionary[line[1]] = (line[0], line[2])

    return dictionary


def create_casts(base_poli):
    dictionary = {}
    for key in base_poli:
        weak_key = remove_polish([key])[0]
        if weak_key not in dictionary:
            dictionary[weak_key] = []
        dictionary[weak_key].append(key)

    return dictionary

In [81]:
def load_polimorph2(file='polimorfologik-2.1.txt'):
    dictionary = {}
    with open(file, 'r', encoding='utf8') as base_file:
        for line in base_file:
            line = line.strip()
            line = line.split(";")
            if line[1].lower() != line[1]:
                dictionary[line[1].lower()] = line[1]

    return dictionary

In [82]:
def load_unigrams(file='1grams'):
    dictionary = {}
    with open(file, 'r', encoding='utf8') as base_vectors_lines:
        for line in base_vectors_lines:
            line = line.strip().lower()
            line = line.split(' ')
            dictionary[line[1]] = int(line[0])

    return dictionary


def load_2grams(file='2grams', k=5):
    dictionary = {}
    i = 0
    with open(file, 'r', encoding='utf8') as base_vectors_lines:
        for line in base_vectors_lines:
            line = line.strip().lower()
            line = line.split(' ')
            if int(line[0]) >= k:
                if line[1] not in dictionary:
                    dictionary[line[1]] = {}
                dictionary[line[1]][line[2]] = int(line[0])
            else:
                break
            i += 1

    return dictionary


def load_3grams(file='3grams', k=5):
    dictionary = {}
    i = 0
    with open(file, 'r', encoding='utf8') as base_vectors_lines:
        for line in base_vectors_lines:
            line = line.strip().lower()
            line = line.split(' ')
            if int(line[0]) >= k:
                if line[1] not in dictionary:
                    dictionary[line[1]] = {}
                dictionary[line[1]][line[2:]] = int(line[0])
            else:
                break
            i += 1

    return dictionary

In [83]:
def load_set(file='train_shuf.txt', k=10000):
    i = 0
    lines = []
    with open(file, 'r', encoding='utf8') as base_vectors_lines:
        for line in base_vectors_lines:
            if i == k:
                break
            lines.append(tokenize(line))
            i += 1

    return lines


def divide_set(total_set, k=0.7):
    s = round(len(total_set)*k)
    return total_set[:s], total_set[s:]

In [84]:
def load_big_set(file='train_shuf.txt', k=10000):
    i = 0
    lines = []
    with open(file, 'r', encoding='utf8') as base_vectors_lines:
        for line in base_vectors_lines:
            if i == k:
                break
            lines.append(tokenize_big(line))
            i += 1

    return lines

In [85]:
def flatten(listt):
        a = []
        for itemm in listt:
            if isinstance(itemm, list):
                a += flatten(itemm)
            else:
                a.append(itemm)
        return a

def combine(lines):
    conc = lines[0]
    if len(lines) == 1:
        conc = [conc]
    for part in lines[1:]:
        conc = list(map(list, itertools.product(conc, part)))
    conc2 = []
    for item in conc:
        conc2.append(flatten(item))
    return conc2

def permute(line, casts):
    line2 = []
    for word in line:
        if word in casts:
            line2.append(casts[word])
        else:
            line2.append([word])

    return combine(line2)

In [86]:
combine([[None, "a"], [None]])

[[None, None], ['a', None]]

In [87]:
def wrong(line, casts):
    if line not in permute(line, casts):
        return True
    return False

In [88]:
def windows(line, k=3):
    
    line2 = []
    line2.append(None)
    line2 += line
    line2.append(None)
    line = line2
    
    if len(line) < k:
        line2 = []
        for i in range(math.ceil((k-len(line))/2)):
            line2.append(None)
        line2 += line
        for i in range(math.floor((k-len(line))/2)):
            line2.append(None)
        line = line2
        
    if len(line) == k:
        return [line]
    
    else:
        lines = []
        line2 = line[0:k]
        lines.append(line2.copy())
        for word in line[k:]:
            del line2[0]
            line2.append(word)
            lines.append(line2.copy())
        return lines

In [89]:
def create_dgrams(training_set, casts):
    dgrams = {}
    for line in training_set:
        if len(line) >= 2:
            fst = line[0]
            for word in line[1:]:
                snd = word
                if not wrong([fst, snd], casts):
                    if fst not in dgrams:
                        dgrams[fst] = {}
                    if snd not in dgrams[fst]:
                        dgrams[fst][snd] = 0
                    dgrams[fst][snd] += 1
                fst = snd
    
    return dgrams

In [90]:
def create_database(training_set, poli, casts, digrams1, digrams2):
    # df = pd.DataFrame(data=raw_form)
    # duos = pd.DataFrame(data={"fst":[], "snd":[], "target":[]})
    # trios = pd.DataFrame(data={"fst":[], "snd":[], "trd":[], "target":[]})
    trios = []
    j = -1
    for base_line in training_set:
        j += 1
        if j % 10000 == 0:
            print(j)

        for line in windows(base_line, k=3):
            if not wrong(line, casts):
                for perm in permute(line, casts):
                    trio = {"fst":0, "snd":0, "trd":0, "lgram":"n", "rgram":"n", "target":0}
                    
                    if line == perm:
                        trio["target"] = "y"
                    else:
                        trio["target"] = "n"
                        
                    if perm[0] in digrams1:
                        if perm[1] in digrams1[perm[0]]:
                            trio["lgram"] = "y"
                    #if perm[0] in digrams2:
                    #    if perm[1] in digrams2[perm[0]]:
                    #        trio["lgram"] = "y"
                    if perm[1] in digrams1:
                        if perm[2] in digrams1[perm[1]]:
                            trio["rgram"] = "y"
                    #if perm[1] in digrams2:
                    #    if perm[2] in digrams2[perm[1]]:
                    #        trio["rgram"] = "y"
                                
                    
                    form = []
                    bad = False
                    for word in perm:
                        if word in poli:
                            form.append(poli[word][1])
                        else:
                            if word == None:
                                form.append([None])
                            else:
                                form.append(["na"])
                    
                    if not bad:
                        form = combine(form)
                        for comb in form:
                            trio["fst"] = comb[0]
                            trio["snd"] = comb[1]
                            trio["trd"] = comb[2]
                            trios.append(trio)

    trios = pd.DataFrame(trios)
    return trios

In [91]:
def find_big(training_set):
    dictionary = {}
    for line in training_set:
        if len(line) >= 2:
            for word in line[1:]:
                if len(word) > 0:
                    if word[0].lower() != word[0]:
                        if word.lower() not in dictionary:
                            dictionary[word.lower()] = [word, 0, 0]
                        dictionary[word.lower()][1] += 1
                        dictionary[word.lower()][2] += 1
                    else:
                        if word in dictionary:
                            dictionary[word][2] += 1
    
    return dictionary

In [92]:
print(create_database(load_set(k=3), Gpoli, Gcasts, Gdigrams1, Gdigrams2))

0
           fst        snd   trd lgram rgram target
0         None  subst:gen    na     n     n      y
1    subst:gen         na  verb     n     n      y
2           na       verb  verb     n     n      y
3         verb       verb   num     n     y      y
4         verb       verb   num     n     y      y
..         ...        ...   ...   ...   ...    ...
247  subst:gen       prep  verb     y     y      y
248       prep       verb  None     y     n      y
249       prep       verb  None     y     n      y
250       prep       verb  None     y     n      y
251       prep       verb  None     y     n      y

[252 rows x 6 columns]


In [153]:
def count_digrams(digrams):
    for fst in digrams:
        count = 0
        for snd in digrams[fst]:
            count += digrams[fst][snd]
        digrams[fst][0] = count
    return digrams

In [148]:
def fix_polish2(phrase, casts, digrams1, digrams2):
    perms = []
    i = 0
    for line in windows(phrase, k=2):
        perms.append([])
        best_perm = None
        bsc = -1
        for perm in permute(line, casts):
            pres = 0
            size = 1
            if perm[0] in digrams1:
                if perm[1] in digrams1[perm[0]]:
                    pres += digrams1[perm[0]][perm[1]]
                size += digrams1[perm[1]][0]
            if perm[0] in digrams2:
                if perm[1] in digrams2[perm[0]]:
                    pres += digrams2[perm[0]][perm[1]]
                size += digrams2[perm[1]][0]
            score = (pres/size) * math.log(size)
            perms[i].append((perm,score))
        i += 1
    
    output = []
    prev = None
    for i in range(len(perms)):
        mx = -1
        mperm = None
        for perm, sc in perms[i]:
            if perm[0] == prev:
                if sc > mx:
                    mx = sc
                    mperm = perm
        output.append(mperm[1])
        prev = mperm[1]           
    
    return output

In [149]:
def fix_polish(phrase, poli, casts, digrams1, digrams2, main_tree):
    ans = {}
    for i in range(len(phrase)+2):
        ans[i] = []
    
    perms = []
    i = 0
    for line in windows(phrase, k=3):
        perms.append([])
        i += 1
        mx = 0
        mn = 0
        mxperm = line
        for perm in permute(line, casts):          
            entry = {"fst":0, "snd":0, "trd":0, "lgram":"n", "rgram":"n"}
            if perm[0] in digrams1:
                if perm[1] in digrams1[perm[0]]:
                    entry["lgram"] = "y"
            if perm[0] in digrams2:
                if perm[1] in digrams2[perm[0]]:
                    entry["lgram"] = "y"
            if perm[1] in digrams1:
                if perm[2] in digrams1[perm[1]]:
                    entry["rgram"] = "y"
            if perm[1] in digrams2:
                if perm[2] in digrams2[perm[1]]:
                    entry["rgram"] = "y"

            forms = []
            for word in perm:
                if word in poli:
                    forms.append(poli[word][1])
                else:
                    if word == None:
                        forms.append([None])
                    else:
                        forms.append(["na"])

            forms = combine(forms)
            
            #mn1 = 1
            #mx1 = 0
            ttl = 0
            for form in forms:
                entry["fst"] = form[0]
                entry["snd"] = form[1]
                entry["trd"] = form[2]
                sc1 = main_tree.classify(entry)
                #print(perm)
                #print(entry, score)
                ttl += sc1*sc1
                #if sc1 > mx1:
                #    mx1 = sc1
                #if sc1 < mn1:
                #    mn1 = sc1
            #if 0.5 - mn1 >= mx1 - 0.5:
            #    entry["trio"] = mn1
            #else:
            #    entry["trio"] = mx1
            sc1 = ttl/len(forms)
            perms[i-1].append((perm, sc1))
            if sc1 > mx:
                mx = sc1
                mperm = perm
                

        #ans[i-1].append((mperm[0], mx))
        #ans[i].append((mperm[1], mx))
        #ans[i+1].append((mperm[2], mx))

    output = []
    prev = None
    for i in range(len(perms)):
        mx = -1
        mperm = None
        for perm, sc in perms[i]:
            if perm[0] == prev:
                if sc > mx:
                    mx = sc
                    mperm = perm
        output.append(mperm[1])
        prev = mperm[1]
    
    """
    for i in range(1,len(phrase)+1):
        mx = 0
        oword = None
        for word, sc in ans[i]:
            if sc > mx:
                mx = sc
                oword = word
        output.append(oword)
    """
    
    return output

In [94]:
def fix_case(phrase, bigs1, bigs2):
    ww = phrase[0]
    fixed = []
    fixed.append(ww.capitalize())
    if len(phrase) >= 2:
        for word in phrase[1:]:
            if word in bigs1 and (word not in bigs2 or (word in bigs2 and bigs2[word][1] / bigs2[word][2] >= 0.5)):
                fixed.append(bigs1[word])
            else:
                if word in bigs2 and bigs2[word][1] / bigs2[word][2] > 0.5:
                    fixed.append(bigs2[word][0])
                else:
                    fixed.append(word)
    
    return fixed

In [95]:
def score(phrase1, phrase2):
    s = 0
    for i in range(len(phrase1)):
        if phrase1[i] == phrase2[i]:
            s += 1

    return s/len(phrase1)

In [96]:
Gpoli = load_polimorph()
Gcasts = create_casts(Gpoli)
#unigramsS = load_unigrams()
Gdigrams1 = load_2grams(k=3)
#trigramsS = load_3grams()

In [97]:
Gbig = load_polimorph2()
Gtotal_set = load_big_set(k=1000000)

Gbig2 = find_big(Gtotal_set)
del Gtotal_set

In [112]:
Gtotal_set = load_set(k=1000000)
Gdigrams2 = create_dgrams(Gtotal_set, Gcasts)
del Gtotal_set

In [111]:
Gvalidation_set = load_set(k=1200000)[1000000:]
Gvalidation_set2 = load_big_set(k=1200000)[1000000:]

In [155]:
Gdigrams1 = count_digrams(Gdigrams1)
Gdigrams2 = count_digrams(Gdigrams2)

In [None]:
Gtrain_set = load_set(k=1000000)
print(len(Gtrain_set))

Gdatabase = create_database(Gtrain_set, Gpoli, Gcasts, Gdigrams1, Gdigrams2)
del Gtrain_set

Gmain_tree = Tree(Gdatabase)
del Gdatabase

In [109]:
Gtrain_set = load_set(k=120000)[100000:]
trios2 = create_database(Gtrain_set, Gpoli, Gcasts, Gdigrams1, Gdigrams2)
del Gtrain_set
Gmain_tree.start_prune(trios2)
del trios2

0
10000


In [None]:
Gmain_tree.draw().render('test-output/database_tree.gv', view=False)

In [114]:
line = load_set(k=3)[1]
print(fix_polish(remove_polish(line), Gpoli, Gcasts, Gdigrams1, Gdigrams2, Gmain_tree))
print(line)
print(remove_polish(line))
print(score(line, fix_polish(remove_polish(line), Gpoli, Gcasts, Gdigrams1, Gdigrams2, Gmain_tree)))

['parlament', 'zdecydował', 'jednak', 'inaczej', 'i', 'przyjął', 'w', 'ustawie', 'z', 'dnia', 'r', 'jednoinstancyjne', 'postępowanie', 'orzeczniczo-lekarskie']
['parlament', 'zdecydował', 'jednak', 'inaczej', 'i', 'przyjął', 'w', 'ustawie', 'z', 'dnia', 'r', 'jednoinstancyjne', 'postępowanie', 'orzeczniczo-lekarskie']
['parlament', 'zdecydowal', 'jednak', 'inaczej', 'i', 'przyjal', 'w', 'ustawie', 'z', 'dnia', 'r', 'jednoinstancyjne', 'postepowanie', 'orzeczniczo-lekarskie']
1.0


In [115]:
#k=20000
print(len(Gvalidation_set))
total = 0

for i in range(len(Gvalidation_set)):
    line = Gvalidation_set[i]
    line2 = Gvalidation_set2[i]
    broken_line = remove_polish(line)
    fixed_line = fix_polish(broken_line, Gpoli, Gcasts, Gdigrams1, Gdigrams2, Gmain_tree)
    sc1 = score(line, fixed_line)
    fixed_line2 = fix_case(fixed_line, Gbig, Gbig2)
    sc2 = score(line2, fixed_line2)
    #print(fixed_line2)
    #print(line2)
    #print(sc1, sc2)
    total += math.sqrt(sc1*sc1)

print(total/len(Gvalidation_set))

200000
0.9569000682611913


In [116]:
#k=20000
print(len(Gvalidation_set))
total = 0

for i in range(len(Gvalidation_set)):
    line = Gvalidation_set[i]
    line2 = Gvalidation_set2[i]
    broken_line = remove_polish(line)
    fixed_line = fix_polish(broken_line, Gpoli, Gcasts, Gdigrams1, Gdigrams2, Gmain_tree)
    sc1 = score(line, fixed_line)
    fixed_line2 = fix_case(fixed_line, Gbig, Gbig2)
    sc2 = score(line2, fixed_line2)
    #print(fixed_line2)
    #print(line2)
    #print(sc1, sc2)
    total += math.sqrt(sc1*sc2)

print(total/len(Gvalidation_set))

200000
0.9312928879197104


In [157]:
print(len(Gvalidation_set))
total = 0

k = 100
for i in range(k):
    line = Gvalidation_set[i]
    line2 = Gvalidation_set2[i]
    broken_line = remove_polish(line)
    fixed_line = fix_polish2(broken_line, Gcasts, Gdigrams1, Gdigrams2)
    sc1 = score(line, fixed_line)
    fixed_line2 = fix_case(fixed_line, Gbig, Gbig2)
    sc2 = score(line2, fixed_line2)
    #print(fixed_line2)
    #print(line2)
    #print(sc1, sc2)
    total += math.sqrt(sc1*sc2)

print(total/k)

200000
0.9449862768483344


In [101]:
print(Gbig2["w"])

['W', 3077, 607097]


In [112]:
print(Gdigrams1["pij"])

{'i': 21, 'dużo': 17, 'na': 16, '-': 12, 'za': 11, 'z': 8, 'tak': 7, 'alkoholu': 6, 'do': 6, 'wodę': 6, 'już': 5, 'mleko': 5, 'mleko,': 5, 'pij': 5, 'w': 5}


In [118]:
print(Gdigrams2["pij"])

{'świeżo': 1, 'wodę': 1, 'piotrek-elektryczne': 1, 'tylko': 1, 'nych': 1, 'skim': 1, 'i': 1, 'na': 1}


In [119]:
print(Gdigrams2["pij"])

{'świeżo': 1, 'wodę': 1, 'piotrek-elektryczne': 1, 'tylko': 1, 'nych': 1, 'skim': 1, 'i': 1, 'na': 1}


In [171]:
print(Gcasts["jesli"])

['jeśli']
