In [1]:
from __future__ import division, print_function

from collections import Counter

import codecs
import io
import os
import sys

import bokeh
from bokeh.models import ColumnDataSource, LabelSet
from bokeh.io import output_notebook, show
from bokeh.plotting import figure
# Configure Bokeh so that later show() calls render inline in this notebook.
output_notebook()

In [2]:
class Analysis(object):
    """Per-category and per-label error analysis of predicted clauses.

    The constructor loads two things:

    * ``sentences`` -- one ``[predicted, gold]`` pair per sentence, read from
      a results file (see :meth:`extract_triples`).
    * ``train_tags`` -- a ``Counter`` of tag frequencies read from
      ``train_input.txt`` in the data folder (see :meth:`parse_files`).

    :meth:`category_analysis` must be called before :meth:`analysis_tags`,
    because the latter reads ``self.tags`` which the former sets.
    """

    # Clause labels that are counted as operators rather than conditions.
    OPERATORS = ('NOT', 'NEC', 'POS', 'OR', 'DUPLEX', 'IMP', 'RESULTS',
                 'CONTINUATION', 'EXPLANATION', 'CONTRAST')

    def __init__(self, fpath, data_folder):
        """Load the results file ``fpath`` and the training tag counts.

        Args:
            fpath: path of the results file to analyse.
            data_folder: folder that contains ``train_input.txt``.
        """
        self.sentences = self.extract_triples(fpath)
        self.train_tags = self.parse_files(data_folder)

    def extract_triples(self, fpath):
        """Read predicted/gold clause pairs from a results file.

        A line containing ``## Matching clauses ##`` starts a new sentence and
        marks the clause lines that follow as matched; a line containing
        ``## Non-matching clauses ##`` marks the following ones as unmatched.
        Every line containing ``|`` is split into a predicted clause (left)
        and a gold clause (right).

        Args:
            fpath: path of the UTF-8 encoded results file.

        Returns:
            A list with one ``[pred, gold]`` entry per sentence, where
            ``pred`` is a list of ``(clause, matched)`` tuples and ``gold`` is
            a list of clause strings.

        Raises:
            ValueError: if a clause line does not contain exactly one ``|``.
            IndexError: if a clause line appears before the first
                ``## Matching clauses ##`` header.
        """
        in_matching_section = False
        sentences = []
        with codecs.open(fpath, 'rb', 'utf8') as f:
            for line in f:
                line = line.strip()
                if '## Matching clauses ##' in line:
                    sentences.append([[], []])
                    in_matching_section = True
                elif '## Non-matching clauses ##' in line:
                    in_matching_section = False

                if '|' in line:
                    # Expects exactly one separator per clause line.
                    p, g = line.split('|')
                    sentences[-1][0].append((p.strip(), in_matching_section))
                    sentences[-1][1].append(g.strip())
        return sentences

    def parse_files(self, data_folder):
        """Count tag occurrences in ``<data_folder>/train_input.txt``.

        The file is read in records of seven lines. Reading stops at the first
        record whose first line has no tokens (blank line or end of file).
        Only the sixth line of each record is used: every whitespace-separated
        token ending in ``(`` is counted, with the parenthesis removed and the
        text lower-cased.

        Args:
            data_folder: folder that contains ``train_input.txt``.

        Returns:
            A ``Counter`` mapping lower-cased tag to its number of occurrences.
        """
        def read_file(fpath):
            tags_cnt = Counter()
            with codecs.open(fpath, 'rb', 'utf8') as f:
                while True:
                    words = f.readline().strip().split()
                    if not words:
                        break
                    # Four lines are skipped unused (presumably lemmas, POS,
                    # dependency info and alignment -- TODO confirm against
                    # the data format).
                    for _ in range(4):
                        f.readline()
                    tags = f.readline().strip()
                    tags_cnt.update(t[:-1].lower() for t in tags.split()
                                    if t.endswith('('))
                    # Last line of the record, also unused.
                    f.readline()
            return tags_cnt

        return read_file(os.path.join(data_folder, 'train_input.txt'))

    @staticmethod
    def _ratio(numerator, denominator):
        """Return ``numerator / denominator``, or 0.0 if the denominator is 0."""
        return numerator / denominator if denominator else 0.0

    @staticmethod
    def _f1(p, r):
        """Harmonic mean of ``p`` and ``r``; 0 when both are 0."""
        return (2 * (p * r)) / (p + r) if p + r > 0 else 0

    def category_analysis(self):
        """Print micro-averaged P/R/F1 per clause category and set ``self.tags``.

        The label is the second token of each clause. It is classified as:

        * an operator, if it is listed in ``OPERATORS``;
        * a condition, if it is otherwise all upper-case. A condition clause
          of three tokens is "one arg", any other length is "two arg";
        * a lemma, in every other case.

        Correct predictions, predictions and gold clauses are summed over all
        sentences before dividing, for conditions, operators, lemmas and the
        two condition arities. A category with no predictions (or no gold
        clauses) reports a precision (or recall) of 0 instead of failing.

        Side effects:
            Prints a progress line every 100 sentences and one summary line
            per category. Sets ``self.tags`` to ``(correct_cnt, wrong_cnt,
            gold_cnt, cond1set, cond2set)``: three per-label ``Counter``
            objects over conditions, operators and lemmas, followed by the
            sets of condition labels seen in gold with one and with two
            arguments.

        Raises:
            AssertionError: if a gold condition clause has neither 3 nor 4
                tokens, a gold lemma clause does not have 3 tokens, or a label
                occurs in gold both as one-argument and two-argument condition
                (``TIME`` and ``LOCATION`` are exempt from the last check).
        """
        categories = ('cond', 'op', 'lemma', 'cond1', 'cond2')
        # totals[c] = [correct predictions, predictions, gold clauses]
        totals = dict((c, [0, 0, 0]) for c in categories)

        pred_correct, pred_wrong, all_gold = [], [], []
        cond1set = set()
        cond2set = set()

        for idx, sent in enumerate(self.sentences):
            if idx != 0 and idx % 100 == 0:
                print("Reading %d sentence..." % (idx))
            pred, gold = sent[0], sent[1]
            gold_nodes = dict((c, []) for c in categories)
            pred_nodes = dict((c, []) for c in categories)

            for clause in gold:
                if not clause:
                    continue
                items = clause.split()
                label = items[1]
                if label in self.OPERATORS:
                    gold_nodes['op'].append(label)
                elif label.isupper():
                    gold_nodes['cond'].append(label)
                    if len(items) == 3:
                        # TIME and LOCATION are not recorded as one-argument
                        # labels, so they may also occur with two arguments.
                        if label not in ('TIME', 'LOCATION'):
                            cond1set.add(label)
                        gold_nodes['cond1'].append(label)
                    else:
                        assert len(items) == 4
                        if label in cond1set:
                            print(label)
                        assert label not in cond1set
                        cond2set.add(label)
                        gold_nodes['cond2'].append(label)
                else:
                    assert len(items) == 3
                    gold_nodes['lemma'].append(label)

            for clause, is_correct in pred:
                if not clause:
                    continue
                items = clause.split()
                # Only the label is needed; the arguments are not indexed so
                # short predicted clauses are handled like gold ones.
                label = items[1]
                if label in self.OPERATORS:
                    pred_nodes['op'].append((label, is_correct))
                elif label.isupper():
                    pred_nodes['cond'].append((label, is_correct))
                    arity = 'cond1' if len(items) == 3 else 'cond2'
                    pred_nodes[arity].append((label, is_correct))
                else:
                    pred_nodes['lemma'].append((label, is_correct))

            for c in categories:
                totals[c][0] += sum(1 for _, ok in pred_nodes[c] if ok)
                totals[c][1] += len(pred_nodes[c])
                totals[c][2] += len(gold_nodes[c])

            # Per-label bookkeeping covers conditions, operators and lemmas
            # (the two arity categories are subsets of 'cond').
            for c in ('cond', 'op', 'lemma'):
                for label, ok in pred_nodes[c]:
                    (pred_correct if ok else pred_wrong).append(label)
                all_gold.extend(gold_nodes[c])

        report = (
            ('cond', "Conditions ||| P:%.4f R:%.4f F1:%.4f"),
            ('op', "Operations ||| P:%.4f R:%.4f F1:%.4f"),
            ('lemma', "Lemmas ||| P:%.4f R:%.4f F1: %.4f"),
            ('cond1', "Conditions one arg ||| P:%.4f R:%.4f F1:%.4f"),
            ('cond2', "Conditions two arg ||| P:%.4f R:%.4f F1:%.4f"),
        )
        for c, fmt in report:
            n_correct, n_pred, n_gold = totals[c]
            prec = self._ratio(n_correct, n_pred)
            rec = self._ratio(n_correct, n_gold)
            print(fmt % (prec, rec, self._f1(prec, rec)))

        self.tags = (Counter(pred_correct), Counter(pred_wrong),
                     Counter(all_gold), cond1set, cond2set)

    @staticmethod
    def _print_latex_rows(labels, rows):
        """Print one LaTeX table row per label, dashes if the label has no result.

        Args:
            labels: labels to print, in output order.
            rows: iterable of ``(freq, p, r, f1, tag)`` tuples.
        """
        by_tag = dict((tag, (p, r, f1, freq)) for freq, p, r, f1, tag in rows)
        for label in labels:
            if label in by_tag:
                p, r, f1, freq = by_tag[label]
                print('%s &%.4f &%.4f &%.4f & %d\\\\' % (label, p, r, f1, freq))
            else:
                print('%s &-- &-- &-- & --\\\\' % (label, ))

    @staticmethod
    def _print_by_f1(rows):
        """Print ``(freq, p, r, f1, tag)`` rows sorted by descending F1."""
        for freq, p, r, f1, t in sorted(rows, key=lambda x: x[3], reverse=True):
            print('%s ||| P:%.4f R:%.4f F1:%.4f ||| Train freq: %d' % (t, p, r, f1, freq))

    @staticmethod
    def _show_scatter(rows, **figure_kwargs):
        """Show a labelled Bokeh scatter plot of F1 (x) against frequency (y).

        Args:
            rows: list of ``(freq, p, r, f1, tag)`` tuples to plot.
            **figure_kwargs: extra keyword arguments passed to ``figure``.
        """
        source = ColumnDataSource(
            data=dict(
                x=[row[3] for row in rows],
                y=[row[0] for row in rows],
                desc=[row[4] for row in rows],
            )
        )
        plot = figure(plot_width=700, plot_height=600, x_range=[-0.04, 1.04],
                      **figure_kwargs)
        plot.scatter(x='x', y='y', size=8, source=source)
        labels = LabelSet(x='x', y='y', text='desc', level='glyph',
                          x_offset=5, y_offset=5, source=source,
                          render_mode='canvas', text_font_size='10px')
        plot.add_layout(labels)
        show(plot)

    def analysis_tags(self):
        """Print per-label P/R/F1 tables and show F1-vs-frequency scatter plots.

        Requires ``self.tags`` (set by :meth:`category_analysis`). For every
        gold label, precision, recall and F1 are computed from the per-label
        counters, and the label's training frequency is looked up
        (lower-cased) in ``self.train_tags``. Labels are grouped into
        operators, one-argument conditions, two-argument conditions and
        lemmas (everything else).

        Output, in order: a LaTeX table for a fixed list of operators, the
        operator listing, a LaTeX table for a fixed list of conditions, the
        one-argument, two-argument and lemma listings, then four scatter
        plots (operators, two-argument conditions, one-argument conditions,
        lemmas). Condition plots only include labels with training frequency
        >= 60; the lemma plot only includes labels with frequency < 60.
        """
        print('====ANALYSIS TAGS====')
        tags_correct_cnt, tags_wrong_cnt, tags_gold_cnt, cond1_set, cond2_set = self.tags
        freq_cutoff = 60  # training-frequency threshold used to filter the plots

        # group -> [(freq, p, r, f1, tag)], in descending gold-count order
        results = {'op': [], 'cond1': [], 'cond2': [], 'lemma': []}
        for t, _ in sorted(tags_gold_cnt.items(), key=lambda x: x[1], reverse=True):
            if t in tags_correct_cnt or t in tags_wrong_cnt:
                p = tags_correct_cnt[t] / (tags_wrong_cnt[t] + tags_correct_cnt[t])
            else:
                p = 0
            if t in tags_correct_cnt:
                r = tags_correct_cnt[t] / tags_gold_cnt[t]
            else:
                r = 0
            f1 = self._f1(p, r)
            freq = self.train_tags[t.lower()] if t.lower() in self.train_tags else 0

            if t in self.OPERATORS:
                group = 'op'
            elif t.isupper() and t in cond1_set:
                group = 'cond1'
            elif t.isupper() and t in cond2_set:
                group = 'cond2'
            else:
                group = 'lemma'
            results[group].append((freq, p, r, f1, t))

        # Operators reported in the LaTeX table, in table order.
        opList = ['POS', 'RESULTS', 'NOT', 'IMP', 'CONTINUATION', 'CONTRAST', 'NEC', 'EXPLANATION']
        print('====OP AGS====')
        self._print_latex_rows([item.upper() for item in opList], results['op'])

        # Conditions reported in the LaTeX table, in table order.
        condList = ['Eq', 'Agent', 'Location', 'Theme', 'Role', 'Quantity', 'Stimulus', 'Experiencer', 'Patient', 'Partof', 'Source', 'Of', 'Tsu', 'Recipient',
                    'Pivot', 'Value', 'Beneficiary', 'User', 'Owner', 'Topic', 'Destination', 'Yearofcentury', 'Attribute', 'Tab', 'Cotheme', 'Clocktime', 'Creator',
                    'Instrument', 'Coagent', 'Apx', 'Result', 'Manner', 'Asset', 'Product', 'Part', 'Causer', 'Colour', 'Monthofyear', 'Goal', 'Material',
                    'Path', 'Title', 'Madeof', 'Duration', 'Dayofweek', 'Dayofmonth', 'Finish', 'Copatient', 'Start', 'Mor', 'Leq', 'Neq', 'Affector', 'Measure']
        print('====OPTION TAGS====')
        self._print_by_f1(results['op'])

        print('====CONDITION TAGS====')
        self._print_latex_rows([item.upper() for item in condList],
                               results['cond1'] + results['cond2'])

        print('====CONDITION one arg TAGS====')
        self._print_by_f1(results['cond1'])

        print('====CONDITION two arg TAGS====')
        self._print_by_f1(results['cond2'])

        print('====LEMMA TAGS====')
        self._print_by_f1(results['lemma'])

        self._show_scatter(results['op'], y_range=[-2, 60])
        self._show_scatter([row for row in results['cond2'] if row[0] >= freq_cutoff])
        self._show_scatter([row for row in results['cond1'] if row[0] >= freq_cutoff])
        self._show_scatter([row for row in results['lemma'] if row[0] < freq_cutoff])

In [3]:
# Folder containing the results file that is analysed below (relative path).
model = 'seq2tree/output_nl_tst4'
# Data folder; train_input.txt and train_grammar.txt are read from it by later cells.
data = 'PMB2/gold/'

In [4]:
# Load the results file '27.results' from the model folder plus the training tag counts.
analysis = Analysis(os.path.join(model, '27.results'), data)
# Must run first: it sets analysis.tags, which analysis_tags() reads.
analysis.category_analysis()
analysis.analysis_tags()

Reading 100 sentence...
Reading 200 sentence...
Reading 300 sentence...
Reading 400 sentence...
Reading 500 sentence...
Conditions ||| P:0.6138 R:0.5675 F1:0.5897
Operations ||| P:0.8182 R:0.3243 F1:0.4645
Lemmas ||| P:0.4765 R:0.5440 F1: 0.5080
Conditions one arg ||| P:0.6795 R:0.6474 F1:0.6631
Conditions two arg ||| P:0.5587 R:0.5038 F1:0.5298
====ANALYSIS TAGS====
====OP AGS====
POS &0.0000 &0.0000 &0.0000 & 70\\
RESULTS &-- &-- &-- & --\\
NOT &0.8214 &0.4510 &0.5823 & 247\\
IMP &0.9231 &0.4615 &0.6154 & 124\\
CONTINUATION &0.0000 &0.0000 &0.0000 & 47\\
CONTRAST &-- &-- &-- & --\\
NEC &0.0000 &0.0000 &0.0000 & 20\\
EXPLANATION &0.3333 &0.2500 &0.2857 & 9\\
====OPTION TAGS====
IMP ||| P:0.9231 R:0.4615 F1:0.6154 ||| Train freq: 124
NOT ||| P:0.8214 R:0.4510 F1:0.5823 ||| Train freq: 247
EXPLANATION ||| P:0.3333 R:0.2500 F1:0.2857 ||| Train freq: 9
POS ||| P:0.0000 R:0.0000 F1:0.0000 ||| Train freq: 70
CONTINUATION ||| P:0.0000 R:0.0000 F1:0.0000 ||| Train freq: 47
NEC ||| P:0.0000 R:

In [9]:
def analysis_ranks(trees, grammar):
    """Count trees in which every child matches the rank of its nonterminal.

    Each tree is walked breadth-first from node 0. For every node that has
    children, the ``NT<k>`` edge labels of its fragment(s) are listed in the
    order given by ``f_tent_label[fragment]['e1']`` and paired with the
    children. A child matches when the number ``k`` equals:

    * ``len(f_ext[fragment])`` if the child has a single fragment, or
    * ``sum(len(f_ext[f]) - 1 for f in fragments) + 1`` if it has several.

    Args:
        trees: iterable of 5-tuples ``(_, struct, _, frags, _)`` where
            ``struct`` maps a node to the list of its children and ``frags``
            maps a node to the list of its fragment ids.
        grammar: 8-tuple ``(f_nodes, f_edges, f_src, f_tar, f_label,
            f_tent_label, f_ext, f_terms)``; only ``f_src``, ``f_label``,
            ``f_tent_label`` and ``f_ext`` are used, each indexed by
            fragment id.

    Returns:
        ``(wrong, correct, rank_analysis)``: the number of trees with at
        least one mismatching child, the number of trees without any, and a
        list with one ``(max_nt, 'w' | 'c')`` entry per tree, where
        ``max_nt`` is the largest ``k`` seen in that tree.

    Raises:
        AssertionError: if a node's number of children differs from its
            number of ordered nonterminal edges.
    """
    correct, wrong, rank_analysis = 0, 0, []
    f_nodes, f_edges, f_src, f_tar, f_label, f_tent_label, f_ext, f_terms = grammar

    def node_fragment(frags, node):
        # A single fragment is returned bare, several as a tuple.
        return frags[node][0] if len(frags[node]) == 1 else tuple(frags[node])

    def ordered_nts(frag_id):
        # NT labels of one fragment, ordered by the 'e1' ordering of the
        # source nodes of its NT edges.
        e2labels = {e: nt for e, nt in f_label[frag_id].items() if nt.startswith('NT')}
        edges2src = {f_src[frag_id][e]: e for e in e2labels}
        srt_labels = sorted(f_tent_label[frag_id]['e1'].items(), key=lambda x: x[1])
        return [e2labels[edges2src[k]] for k, _ in srt_labels if k in edges2src]

    for t in trees:
        t_wrong, max_nt = 0, 0
        _, struct, _, frags, _ = t
        queue = [0]
        while queue:
            curr = queue.pop(0)
            if curr not in struct:
                continue

            frag = node_fragment(frags, curr)
            parts = frag if isinstance(frag, tuple) else (frag,)
            ord_nts = []
            for f in parts:
                ord_nts.extend(ordered_nts(f))

            children = struct[curr]
            assert len(children) == len(ord_nts)
            for c, nt in zip(children, ord_nts):
                rank = int(nt[2:])
                if rank > max_nt:
                    max_nt = rank
                # Decide single vs. multi fragment from the child's own
                # fragment list, the same way as for the parent above.
                child_frag = node_fragment(frags, c)
                if isinstance(child_frag, tuple):
                    actual = sum(len(f_ext[f]) - 1 for f in child_frag) + 1
                else:
                    actual = len(f_ext[child_frag])
                if actual != rank:
                    t_wrong += 1
            queue.extend(children)

        if t_wrong:
            wrong += 1
            rank_analysis.append((max_nt, 'w'))
        else:
            correct += 1
            rank_analysis.append((max_nt, 'c'))

    return wrong, correct, rank_analysis
        

In [43]:
# NOTE(review): the recorded run of this cell failed with
# "ImportError: No module named 'grammar_scripts'"; the package must be importable
# from the kernel's working directory / sys.path before this cell can run.
import os
from grammar_scripts.tree_to_pmb_clone import load_trees, load_grammar

trees = load_trees('models/RDG_linear_gold_restrict/trees_results_dev.txt')
# `data` is the data folder defined in an earlier cell.
grammar = load_grammar(os.path.join(data,'train_grammar.txt'))
wrong, correct, rank_analysis = analysis_ranks(trees, grammar)

# Both ratios divide by wrong+correct and raise ZeroDivisionError when no trees were loaded.
print('Mismatch in ranking %.2f (%d/%d)' % (wrong/(wrong+correct),wrong, wrong+correct))
print('Match in ranking %.2f (%d/%d)' % (correct/(wrong+correct),correct, wrong+correct))
# Frequency of each (max_nt, 'w'/'c') pair across trees.
print(Counter(rank_analysis))

ImportError: No module named 'grammar_scripts'