In [1]:
import numpy as np
import matplotlib.pyplot as plt
from nltk import Tree, grammar
import random
import queue
import pickle as pkl
from scipy.spatial import distance
from PYEVALB import scorer
from PYEVALB import parser
from sklearn.metrics import precision_recall_fscore_support
import math
from itertools import product
from time import time
import argparse
from tqdm import tqdm
from sentence_splitter import SentenceSplitter, split_text_into_sentences



def ignore_functional_labels(string):
    """Remove ``~``-suffixed label parts and the outer wrapper from a bracketed parse.

    The string is stripped of surrounding spaces and split on single spaces.
    The first token is dropped (presumably the outer ``(`` of a ``( (S ...))``
    wrapper -- confirm against the treebank format), every remaining token that
    starts with ``(`` is cut at its first ``~``, and the last character of the
    re-joined string is dropped (presumably the wrapper's closing ``)``).

    Parameters
    ----------
    string : str
        One bracketed parse.

    Returns
    -------
    str
        The cleaned bracketed string. Spacing between tokens is preserved.
    """
    tokens = string.strip(' ').split(' ')[1:]
    # `startswith` is safe on the empty tokens that consecutive spaces produce,
    # whereas indexing the first character would raise IndexError.
    cleaned = [token.split('~')[0] if token.startswith('(') else token for token in tokens]
    return ' '.join(cleaned)[:-1]




def extract_nodes(rule):
    """Split a production into the non-terminal and terminal symbols it mentions.

    Parameters
    ----------
    rule : nltk.grammar.Production
        The production to inspect.

    Returns
    -------
    tuple
        ``(is_lexical, non_terminal_symbols, terminal_symbols)``:

        * lexical rule: ``True``, a set holding the left-hand symbol, and the
          set of right-hand side items;
        * otherwise: ``False``, a set holding the left-hand symbol together with
          every right-hand symbol, and an empty set.
    """
    lhs_symbol = rule.lhs().symbol()
    if rule.is_lexical():
        return True, {lhs_symbol}, set(rule.rhs())

    # Collect every right-hand symbol instead of hard-coding positions 0 and 1,
    # so rules that are not exactly binary no longer raise IndexError.
    rhs_symbols = {node.symbol() for node in rule.rhs()}
    return False, {lhs_symbol} | rhs_symbols, set()
    
def grammar_infos(data):
    """Collect the symbol inventory of a treebank after binarization.

    Every bracketed string is cleaned with ``ignore_functional_labels``, parsed,
    converted to Chomsky normal form (``horzMarkov=2``) and has its unary chains
    collapsed, exactly as ``pcfg`` does, before its productions are inspected.

    Parameters
    ----------
    data : iterable of str
        Bracketed parses, one per item.

    Returns
    -------
    tuple of sets
        ``(non_terminals, pos_tags, terminals, binaries, starts)`` where
        ``pos_tags`` are the left-hand symbols of lexical rules, ``binaries``
        are the ``(left, right)`` symbol pairs of two-symbol right-hand sides
        and ``starts`` are the left-hand symbols of each tree's first production.
    """
    non_terminals, pos_tags, terminals = set(), set(), set()
    binaries, starts = set(), set()

    for string in data:
        try:
            t = Tree.fromstring(ignore_functional_labels(string))
        except (ValueError, IndexError):
            # Report the offending line and skip it. Falling through would
            # reuse the previous tree, or hit an unbound name on the first line.
            print(string)
            continue

        t.chomsky_normal_form(horzMarkov=2)
        t.collapse_unary(collapsePOS=True, collapseRoot=True)
        rules = t.productions()
        if not rules:
            continue

        starts.add(rules[0].lhs().symbol())
        for rule in rules:
            try:
                lexical, non_terminal_nodes, terminal_nodes = extract_nodes(rule)
            except IndexError:
                # A rule that cannot be classified is skipped rather than
                # silently reusing the values left over from the previous rule.
                continue

            # In-place updates avoid rebuilding the whole set for each rule.
            non_terminals.update(non_terminal_nodes)
            if lexical:
                pos_tags.update(non_terminal_nodes)
                terminals.update(terminal_nodes)
            elif len(rule.rhs()) == 2:
                left, right = rule.rhs()
                binaries.add((left.symbol(), right.symbol()))

    return non_terminals, pos_tags, terminals, binaries, starts

def express_node(rule):
    """Describe a production as ``(nature, lhs, rhs)`` for the PCFG counters.

    Parameters
    ----------
    rule : nltk.grammar.Production
        A lexical or binary production.

    Returns
    -------
    tuple
        ``('lexical', lhs_symbol, lower-cased first right-hand item)`` for a
        lexical rule, ``('binary', lhs_symbol, (left_symbol, right_symbol))``
        for a binary rule.

    Raises
    ------
    ValueError
        If a non-lexical rule does not have exactly two right-hand symbols.
        The previous arity check inspected a freshly built 2-tuple and so could
        never fail: longer rules were silently truncated and shorter ones raised
        an opaque IndexError.
    """
    lhs = rule.lhs().symbol()
    rhs = rule.rhs()

    if rule.is_lexical():
        return 'lexical', lhs, rhs[0].lower()

    if len(rhs) != 2:
        raise ValueError(
            "expected a lexical or binary rule, got %d right-hand symbols for %r"
            % (len(rhs), lhs)
        )

    return 'binary', lhs, (rhs[0].symbol(), rhs[1].symbol())
        
def pcfg(data):
    """Estimate a PCFG from bracketed parses by relative frequency.

    Each tree is cleaned with ``ignore_functional_labels``, converted to Chomsky
    normal form (``horzMarkov=2``) and has its unary chains collapsed before its
    productions are counted through ``express_node``.

    Parameters
    ----------
    data : iterable of str
        Bracketed parses, one per item.

    Returns
    -------
    tuple of dict
        ``(dict_pcfg, dict_probas)``:

        * ``dict_pcfg[nature][lhs][rhs]`` is the probability of ``lhs -> rhs``,
          normalized over all rules (lexical and binary together) sharing ``lhs``;
        * ``dict_probas[nature][rhs][lhs]`` is the natural log of that
          probability, indexed by right-hand side first. It also carries an
          always-empty ``'unary'`` entry.
    """
    grammar_tables = {'lexical': {}, 'binary': {}}

    # Pass 1: raw rule counts.
    for bracketed in data:
        tree = Tree.fromstring(ignore_functional_labels(bracketed))
        tree.chomsky_normal_form(horzMarkov=2)
        tree.collapse_unary(collapsePOS=True, collapseRoot=True)
        for production in tree.productions():
            nature, lhs, rhs = express_node(production)
            expansions = grammar_tables[nature].setdefault(lhs, {})
            expansions[rhs] = expansions.get(rhs, 0) + 1

    # Pass 2: total number of expansions per left-hand side, across both natures.
    totals = {}
    for table in grammar_tables.values():
        for lhs, expansions in table.items():
            totals[lhs] = totals.get(lhs, 0) + sum(expansions.values())

    # Pass 3: counts -> probabilities.
    for table in grammar_tables.values():
        for lhs in table:
            table[lhs] = {rhs: count / totals[lhs] for rhs, count in table[lhs].items()}

    # Pass 4: inverted index of log-probabilities (rhs -> lhs -> log p).
    log_tables = {'lexical': {}, 'unary': {}, 'binary': {}}
    for nature, table in grammar_tables.items():
        for lhs, expansions in table.items():
            for rhs, probability in expansions.items():
                log_tables[nature].setdefault(rhs, {})[lhs] = math.log(probability)

    return grammar_tables, log_tables


def cyk(sentence):
    """Fill the CYK chart for a tokenized sentence using the log-probability PCFG.

    Reads the module-level ``non_terminals``, ``dict_probas``, ``B_binary``,
    ``C_binary``, ``set_binary`` and ``dict_non_terminals_indices``.

    Parameters
    ----------
    sentence : sequence of str
        Tokens, looked up as-is in ``dict_probas['lexical']``.

    Returns
    -------
    tuple
        ``(scores, back)``. ``scores[i][j]`` maps a symbol to the best
        log-probability found for the span ``sentence[i:j]``.
        ``back[i][j][dict_non_terminals_indices[A]]`` is ``(split, B, C)`` for
        the best binary derivation of ``A`` over that span, or ``None`` when the
        entry was never set (lexical cells and underivable symbols). If a token
        has no lexical rule, a message is printed and the partially filled
        tables are returned immediately.
    """
    n = len(sentence) + 1
    k = len(non_terminals)
    lexical_rules = dict_probas['lexical']
    binary_rules = dict_probas['binary']

    scores = [[{} for _ in range(n)] for _ in range(n)]
    back = [[[None] * k for _ in range(n)] for _ in range(n)]
    # Per cell: the symbols that can still act as a left (B) or right (C) child
    # of some binary rule. Only such symbols are ever inserted, so the sets need
    # no further filtering when they are read back.
    Bs_scores = [[set() for _ in range(n)] for _ in range(n)]
    Cs_scores = [[set() for _ in range(n)] for _ in range(n)]

    # Spans of length one: lexical rules.
    for i, word in enumerate(sentence):
        if word not in lexical_rules:
            print(word + " : isn't in terminals")
            return scores, back

        for A, log_prob in lexical_rules[word].items():
            scores[i][i+1][A] = log_prob
            if A in B_binary:
                Bs_scores[i][i+1].add(A)

            if A in C_binary:
                Cs_scores[i][i+1].add(A)

    # Longer spans: combine two adjacent sub-spans with a binary rule.
    for span in range(2, n):
        for begin in range(0, n-span):
            end = begin + span
            cell = scores[begin][end]
            for split in range(begin+1, end):
                left_scores, right_scores = scores[begin][split], scores[split][end]
                Bs, Cs = Bs_scores[begin][split], Cs_scores[split][end]

                # Test each candidate pair directly instead of materializing the
                # whole cartesian product as a set for every split point.
                for B, C in product(Bs, Cs):
                    if (B, C) not in set_binary:
                        continue

                    children_score = left_scores.get(B, -np.inf) + right_scores.get(C, -np.inf)
                    for A, rule_log_prob in binary_rules[(B, C)].items():
                        prob = children_score + rule_log_prob
                        if prob > cell.get(A, -np.inf):
                            cell[A] = prob
                            back[begin][end][dict_non_terminals_indices[A]] = (split, B, C)
                            if A in B_binary:
                                Bs_scores[begin][end].add(A)

                            if A in C_binary:
                                Cs_scores[begin][end].add(A)

    return scores, back

def not_in_grammar(sentence):
    """Build the fallback bracketing used when no parse was found.

    Parameters
    ----------
    sentence : sequence of str
        The tokens of the sentence.

    Returns
    -------
    str
        ``'( (Not_in_Grammar ' + one '(Not_in_Grammar <token>)' per token + '))'``.
        An empty sentence yields ``'( (Not_in_Grammar ))'`` instead of raising
        IndexError.
    """
    leaves = ''.join('(Not_in_Grammar ' + token + ')' for token in sentence)
    return '( (Not_in_Grammar ' + leaves + '))'

def build_parentheses(back_track, scores, sentence):
    """Turn a CYK back-pointer table into a bracketed parse string.

    Reads the module-level ``starts`` and ``dict_non_terminals_indices``.

    Parameters
    ----------
    back_track : list
        Back-pointer table as returned by ``cyk``; indexed
        ``[begin][end][dict_non_terminals_indices[symbol]]``.
    scores : list
        Score table as returned by ``cyk``; ``scores[begin][end]`` maps a
        symbol to its best score for that span.
    sentence : sequence of str
        The tokens written at the leaves of the output.

    Returns
    -------
    str
        A bracketed string that starts with ``'( ('`` followed by the best
        scoring start symbol, or the result of ``not_in_grammar(sentence)``
        when no start symbol has a finite score over the whole sentence.
    """
    
    # LIFO stack of pending nodes: (split, symbol, status, depth, border).
    q = queue.LifoQueue()
    begin = 0
    end = len(back_track[0]) - 1
    # Score of every start symbol over the full span [begin, end].
    starts_scores = []
    for start in starts:
        starts_scores.append((scores[begin][end].get(start,-np.inf), start))
    
    # Highest score first; equal scores are ordered by symbol name (tuple comparison).
    starts_scores.sort(reverse=True)
    if starts_scores[0][0] == -np.inf:
        return not_in_grammar(sentence)
    
    sent = starts_scores[0][1]
    string = '('
    # The root starts at depth 1: its final leaf must also close the outer '(' opened above.
    q.put((None, sent, 'start', 1, [begin, end]))

    while not q.empty():
        # `border` is the [begin, end] span of the parent; `split` cuts it in two.
        split, symbol, status, depth, border = q.get()
        string += ' (' + symbol

        # Narrow the parent's span to the half this node covers, then fetch its back-pointer.
        if status == 'left':
            border = [border[0], split]
            children = back_track[border[0]][border[1]][dict_non_terminals_indices[symbol]]

        elif status == 'right':
            border = [split, border[1]]
            children = back_track[border[0]][border[1]][dict_non_terminals_indices[symbol]]

        elif status == 'start':
            children = back_track[border[0]][border[1]][dict_non_terminals_indices[symbol]]

        if children is not None:
            # Single-symbol back-pointer. NOTE(review): `cyk` as defined in this
            # notebook only stores (split, B, C) tuples, so this branch looks unused.
            if isinstance(children, grammar.Nonterminal):
                node = (split, children, status, depth+1, border)
                q.put(node)

            elif len(children) == 3:
                split, l_child, r_child = children
                # The left child restarts at depth 0 because nothing but itself closes
                # after it; the right child inherits the pending closes plus its parent's.
                l_node, r_node = (split, l_child, 'left', 0, border), (split, r_child, 'right', depth+1, border)
                # Push right first so the left child is popped, and written, first.
                q.put(r_node)
                q.put(l_node)

        else:
            # No back-pointer: treat the node as a leaf. Write the first token of its
            # span and close this node plus the `depth` constituents that end with it.
            depth += 1
            word = sentence[border[0]:border[1]][0]
            string += ' ' + word + ')'*depth
            
    return string


def levenshtein(s1, s2):
    """Case-insensitive Levenshtein (edit) distance between two strings.

    Insertions, deletions and substitutions all cost 1.

    Parameters
    ----------
    s1, s2 : str
        The strings to compare; both are lower-cased first.

    Returns
    -------
    numpy.float64
        The edit distance between the lower-cased strings.
    """
    # Lower-case BEFORE measuring: lower-casing can change a string's length
    # (e.g. U+0130), and stale lengths would run the table over a truncated string.
    s1, s2 = s1.lower(), s2.lower()
    n1, n2 = len(s1), len(s2)

    lev = np.zeros((n1 + 1, n2 + 1))
    lev[:, 0] = np.arange(n1 + 1)   # delete every character of the s1 prefix
    lev[0, :] = np.arange(n2 + 1)   # insert every character of the s2 prefix

    for i in range(1, n1 + 1):
        for j in range(1, n2 + 1):
            substitution = 0 if s1[i-1] == s2[j-1] else 1
            lev[i, j] = min(lev[i-1, j] + 1, lev[i, j-1] + 1, lev[i-1, j-1] + substitution)

    return lev[n1, n2]


def unigram(data):
    """Relative frequency of every terminal over the leaves of the given trees.

    Reads the module-level ``terminals`` and ``dict_terminals_indices``.

    Parameters
    ----------
    data : iterable of str
        Bracketed parses.

    Returns
    -------
    numpy.ndarray
        Vector of length ``len(terminals)``; entry ``dict_terminals_indices[w]``
        is the count of ``w`` divided by the total number of leaves.
    """
    counts = np.zeros(len(terminals))
    for tree_string in data:
        leaves = Tree.fromstring(tree_string).leaves()
        for leaf in leaves:
            position = dict_terminals_indices[leaf]
            counts[position] += 1

    total = counts.sum()
    return counts/total

def bigram(data):
    """Add-one smoothed bigram transition matrix over the terminals.

    Reads the module-level ``terminals`` and ``dict_terminals_indices``.

    Parameters
    ----------
    data : iterable of str
        Bracketed parses.

    Returns
    -------
    numpy.ndarray
        Square matrix of side ``len(terminals)``. Row ``i`` is the distribution
        of the word following terminal ``i``; every cell starts at 1 before the
        observed adjacent pairs are counted, and each row is normalized to sum to 1.
    """
    counts = np.ones((len(terminals), len(terminals)))
    for tree_string in data:
        words = Tree.fromstring(tree_string).leaves()
        # Pair each word with its successor; sentences shorter than two words
        # contribute no pair.
        for previous_word, current_word in zip(words, words[1:]):
            row = dict_terminals_indices[previous_word]
            column = dict_terminals_indices[current_word]
            counts[row, column] += 1

    row_totals = counts.sum(axis = 1).reshape(-1, 1)
    return counts/row_totals

def _damerau_levenshtein(s1, s2):
    """Case-insensitive edit distance that also counts an adjacent swap as one edit."""
    s1, s2 = s1.lower(), s2.lower()
    n1, n2 = len(s1), len(s2)
    dist = [[0] * (n2 + 1) for _ in range(n1 + 1)]
    for i in range(n1 + 1):
        dist[i][0] = i
    for j in range(n2 + 1):
        dist[0][j] = j

    for i in range(1, n1 + 1):
        for j in range(1, n2 + 1):
            substitution = 0 if s1[i-1] == s2[j-1] else 1
            best = min(dist[i-1][j] + 1, dist[i][j-1] + 1, dist[i-1][j-1] + substitution)
            # Transposition of two adjacent characters.
            if i > 1 and j > 1 and s1[i-1] == s2[j-2] and s1[i-2] == s2[j-1]:
                best = min(best, dist[i-2][j-2] + 1)
            dist[i][j] = best

    return float(dist[n1][n2])


def closest_formal(s, terminals, n_words = 5, swap = False):
    """Return the terminals closest to ``s`` in spelling.

    Parameters
    ----------
    s : str
        The word to match.
    terminals : iterable of str
        Candidate words. It is iterated twice, so it must be a re-iterable
        collection such as a set or list.
    n_words : int
        Maximum number of candidates to return.
    swap : bool
        If True, adjacent transpositions count as a single edit (computed by
        ``_damerau_levenshtein``, since the notebook defines no
        ``damerau_levenshtein``); otherwise ``levenshtein`` is used.

    Returns
    -------
    dict
        Maps each of the ``n_words`` closest terminals to its distance,
        inserted by increasing distance, ties broken alphabetically.
    """
    measure = _damerau_levenshtein if swap else levenshtein
    ranked = sorted(zip([measure(s, word) for word in terminals], terminals))
    return {word : score for score, word in ranked[:n_words]}


def closest_embedding(s, terminals, n_words = 10):
    """Return the embedded terminals closest to ``s`` by cosine distance.

    Reads the module-level ``dict_words_embeddings``,
    ``dict_terminals_embeddings`` and ``terminals_embedded``. The ``terminals``
    argument is accepted but not used: candidates always come from
    ``terminals_embedded``.

    Parameters
    ----------
    s : str
        Query word; must be a key of ``dict_words_embeddings``.
    terminals : collection of str
        Unused.
    n_words : int
        Maximum number of candidates to return.

    Returns
    -------
    dict
        Maps each of the ``n_words`` nearest terminals to its cosine distance,
        inserted by increasing distance.
    """
    def cosine_to_query(candidate):
        return distance.cosine(dict_words_embeddings[s], dict_terminals_embeddings[candidate])

    ranked = sorted(zip(map(cosine_to_query, terminals_embedded), terminals_embedded))
    nearest = {}
    for gap, candidate in ranked[:n_words]:
        nearest[candidate] = gap
    return nearest


def choose_words(sentence, terminals, probas_unigram, probas_bigram, swap = True, n_words = 10, n_words_formal = 5, lambd = 0.8):
    """Replace out-of-vocabulary tokens by their most likely in-vocabulary neighbour.

    Tokens found in ``terminals`` are kept. For any other token, candidates are
    gathered from ``closest_formal`` (spelling) and, when the token has an
    embedding in ``dict_words_embeddings``, from ``closest_embedding``. The
    candidate with the highest language-model log-score is chosen: the unigram
    probability for the first token, otherwise the interpolation
    ``lambd * bigram(previous, candidate) + (1 - lambd) * unigram(candidate)``.

    Candidates are ranked by this local score alone. The earlier running
    sentence log-probability was added into every candidate score and then
    added again, doubling at each replaced token; it never changed which
    candidate was best except by eventually swamping the differences between
    candidates through rounding.

    Reads the module-level ``dict_terminals_indices`` and ``dict_words_embeddings``.

    Parameters
    ----------
    sentence : list of str
        Tokens of the sentence; not modified.
    terminals : collection of str
        Known vocabulary.
    probas_unigram : numpy.ndarray
        Unigram probabilities indexed by ``dict_terminals_indices``.
    probas_bigram : numpy.ndarray
        Bigram probabilities indexed ``[previous, current]``.
    swap : bool
        Forwarded to ``closest_formal``.
    n_words : int
        Number of embedding neighbours requested.
    n_words_formal : int
        Number of spelling neighbours requested.
    lambd : float
        Interpolation weight of the bigram model.

    Returns
    -------
    list of str
        The corrected tokens, all lower-cased.

    Raises
    ------
    ValueError
        If an out-of-vocabulary token has no replacement candidate at all.
    """
    sent = sentence.copy()
    for i in range(len(sent)):
        s = sent[i]
        if s in terminals:
            continue

        candidates = set(closest_formal(s, terminals, n_words_formal, swap))
        if s in dict_words_embeddings:
            candidates |= set(closest_embedding(s, terminals, n_words))

        if not candidates:
            raise ValueError("no replacement candidate for out-of-vocabulary word %r" % (s,))

        if i == 0:
            scored = [
                (math.log(probas_unigram[dict_terminals_indices[word]]), word)
                for word in candidates
            ]
        else:
            # sent[i-1] is always in the vocabulary here: it was either known
            # or already replaced on the previous iteration.
            previous = dict_terminals_indices[sent[i-1]]
            scored = []
            for word in candidates:
                current = dict_terminals_indices[word]
                mixture = lambd*probas_bigram[previous, current] + (1-lambd)*probas_unigram[current]
                scored.append((math.log(mixture), word))

        # Highest score wins; equal scores fall back to the larger word (tuple order).
        sent[i] = max(scored)[1]

    return [word.lower() for word in sent]

def parse(sentence, n_words_formal=2, n_words=20, swap=False, lambd=0.8):
    """Parse a whitespace-tokenized sentence and return its bracketed tree.

    Out-of-vocabulary tokens are first mapped to known words with
    ``choose_words``. The chart is filled with ``cyk`` on those words, and the
    tree is rebuilt with the original tokens at the leaves before the Chomsky
    normal form is undone.

    Reads the module-level ``terminals``, ``probas_unigram`` and ``probas_bigram``.

    Parameters
    ----------
    sentence : str
        The sentence, with tokens separated by whitespace.
    n_words_formal, n_words, swap, lambd
        Forwarded to ``choose_words``.

    Returns
    -------
    str
        The parse as a single-line bracketed string.

    Raises
    ------
    ValueError
        If the sentence contains no token.
    """
    # split() without an argument drops the empty tokens that repeated or
    # trailing spaces produce; those would otherwise be "corrected" into
    # phantom words.
    tokens = sentence.split()
    if not tokens:
        raise ValueError("cannot parse an empty sentence")

    sentence_oov = choose_words(tokens, terminals, probas_unigram, probas_bigram, n_words_formal=n_words_formal,
                                n_words=n_words, swap=swap, lambd=lambd)
    scores, back = cyk(sentence_oov)
    bracketed = build_parentheses(back, scores, tokens)
    t_test = Tree.fromstring(bracketed)
    t_test.un_chomsky_normal_form()
    return " ".join(str(t_test).split())
    
    

        

In [2]:
# Configuration: treebank path, train/test split and OOV-handling settings.
data_train = 'IT_Treebank.txt'
train_size = 0.9
n_words_formal = 2
n_words = 20
swap = True
lambd = 0.8


# Load the treebank, one bracketed parse per line. The `with` block closes the
# file, so no explicit close() is needed.
with open(data_train, 'r', encoding = 'utf-8') as file:
    data = file.read().splitlines()

# The first `train_size` share of the lines is used for training, the rest for testing.
n_lines = int(train_size*len(data))
data_test = data[n_lines:]
data = data[:n_lines]

# Prepare evaluation data: keep only the space-tokenized sentence of each test parse.
data_test_sentences = []
for bracketed in data_test:
    t = Tree.fromstring(bracketed)
    data_test_sentences.append(" ".join(t.leaves()))

# Describe the grammar of the training data.
non_terminals, pos_tags, terminals, binaries, starts = grammar_infos(data)

# Symbol <-> index lookup tables.
dict_non_terminals_indices = {non_terminal : index for index, non_terminal in enumerate(non_terminals)}
dict_indices_non_terminals = {index : non_terminal for non_terminal, index in dict_non_terminals_indices.items()}

dict_pos_indices = {pos : index for index, pos in enumerate(pos_tags)}
dict_indices_pos = {index : pos for pos, index in dict_pos_indices.items()}

dict_terminals_indices = {terminal : index for index, terminal in enumerate(terminals)}
dict_indices_terminals = {index : terminal for terminal, index in dict_terminals_indices.items()}

# The reverse mapping of the binaries gets its own name; it used to overwrite
# dict_indices_terminals.
dict_binaries_indices = {binary : index for index, binary in enumerate(binaries)}
dict_indices_binaries = {index : binary for binary, index in dict_binaries_indices.items()}

dict_starts_indices = {start : index for index, start in enumerate(starts)}
dict_indices_starts = {index : start for start, index in dict_starts_indices.items()}

# Create the PCFG.
dict_pcfg, dict_probas = pcfg(data)
print('Creating the PCFG : Finished')

# Symbols seen as the first (B) and second (C) child of a binary rule, and the
# set of all (B, C) pairs.
B_binary, C_binary = set([binary[0] for binary in dict_probas['binary'].keys()]), set([binary[1] for binary in dict_probas['binary'].keys()])
set_binary = set(dict_probas['binary'].keys())

# Load polyglot embeddings.
# NOTE(security): unpickling executes arbitrary code; only load this file from a trusted source.
with open('polyglot-it.pkl', 'rb') as embeddings_file:
    words, embeddings = pkl.load(embeddings_file, encoding = 'latin')

dict_words_embeddings = dict(zip(words, embeddings)) # Maps polyglot words to embeddings.
words_polyglot = set(words) # All the words of the polyglot vocabulary.
dict_terminals_embeddings = {terminal : embedding for (terminal, embedding) in dict_words_embeddings.items() if terminal in terminals} # Maps terminals to embeddings.
terminals_embedded = set(dict_terminals_embeddings.keys()) # All terminals having an embedding.
del embeddings

# Define the language model.
probas_unigram, probas_bigram = unigram(data), bigram(data)
print('Defining the language model : Finished')
    


Creating the PCFG : Finished
Defining the language model : Finished


In [11]:
# Input for the parsing cell: a list of raw sentences, here a single example.
list_of_sentences = ['Oggi è una bella giornata!']
sentence = list_of_sentences[0]


In [20]:
splitter = SentenceSplitter(language='it')


def _parse_or_fallback(text):
    """Parse ``text`` with the notebook settings; return "(S)" if parsing fails."""
    try:
        return parse(text, n_words_formal=n_words_formal, n_words=n_words, swap=swap, lambd=lambd)
    except Exception as error:
        # Best effort: keep going on a failed sentence, but say why instead of
        # hiding the error. `except Exception` lets KeyboardInterrupt through.
        tqdm.write("parse failed for %r: %r" % (text, error))
        return "(S)"


for raw_text in tqdm(list_of_sentences):
    # Collapse ellipses, then pad sentence-final punctuation with spaces so it
    # becomes a token of its own.
    padded = raw_text.replace('...','.')
    padded = padded.replace('.', ' . ').replace('!',' ! ').replace('?',' ? ')
    sentences = splitter.split(text=padded)
    if len(sentences) == 1:
        parses = _parse_or_fallback(sentences[0])
    else:
        # Several sentences: parse each one and wrap the results under one S node.
        list_parse = [_parse_or_fallback(sub_sent) for sub_sent in sentences]
        parses = "(S " + "".join(list_parse) + ")"

    print(parses)

100%|██████████| 1/1 [00:00<00:00, 49.07it/s]

( (S (ADVP (ADVB Oggi)) (VP (VMA è) (NP (ART una) (ADJ bella) (NOU giornata))) (. !)))



