Skip to content

Commit

Permalink
Bugfix: swapped precision & recall
Browse files (browse the repository at this point in the history)
  • Loading branch information
tuetschek committed Oct 7, 2014
1 parent 3d72744 commit edd44d2
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
8 changes: 4 additions & 4 deletions run_tgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from tgen.candgen import RandomCandidateGenerator
from tgen.rank import PerceptronRanker
from tgen.planner import SamplingPlanner, ASearchPlanner
from tgen.eval import p_r_f1_from_counts, tp_fp_fn, f1_from_counts, ASearchListsAnalyzer, \
from tgen.eval import p_r_f1_from_counts, corr_pred_gold, f1_from_counts, ASearchListsAnalyzer, \
EvalTypes, Evaluator
from tgen.parallel_percrank_train import ParallelPerceptronRanker

Expand Down Expand Up @@ -149,14 +149,14 @@ def sample_gen(args):
for gold_tree, gen_trees in zip(gold_trees, chunk_list(gen_trees, num_to_generate)):
# find best of predicted trees (in terms of F1)
_, tc, tp, tg = max([(f1_from_counts(c, p, g), c, p, g) for c, p, g
in map(lambda gen_tree: tp_fp_fn(gold_tree, gen_tree),
in map(lambda gen_tree: corr_pred_gold(gold_tree, gen_tree),
gen_trees)],
key=lambda x: x[0])
correct += tc
predicted += tp
gold += tg
# evaluate oracle F1
log_info("Oracle Precision: %.6f, Recall: %.6f, F1: %.6f" % p_r_f1_from_counts(correct, gold, predicted))
log_info("Oracle Precision: %.6f, Recall: %.6f, F1: %.6f" % p_r_f1_from_counts(correct, predicted, gold))
# write output
if fname_ttrees_out is not None:
log_info('Writing output...')
Expand Down Expand Up @@ -228,7 +228,7 @@ def asearch_gen(args):
evaler = Evaluator()
for eval_bundle, eval_ttree, gen_ttree in zip(eval_doc.bundles, eval_ttrees, gen_ttrees):
add_bundle_text(eval_bundle, tgen.language, tgen.selector + 'Xscore',
"P: %.4f R: %.4f F1: %.4f" % p_r_f1_from_counts(*tp_fp_fn(eval_ttree, gen_ttree)))
"P: %.4f R: %.4f F1: %.4f" % p_r_f1_from_counts(*corr_pred_gold(eval_ttree, gen_ttree)))
evaler.append(eval_ttree, gen_ttree)
log_info("NODE precision: %.4f, Recall: %.4f, F1: %.4f" % evaler.p_r_f1())
log_info("DEP precision: %.4f, Recall: %.4f, F1: %.4f" % evaler.p_r_f1(EvalTypes.DEP))
Expand Down
20 changes: 10 additions & 10 deletions tgen/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def collect_counts(ttree, eval_type=EvalTypes.NODE):
return counts


def tp_fp_fn(gold_ttree, pred_ttree, eval_type=EvalTypes.NODE):
def corr_pred_gold(gold_ttree, pred_ttree, eval_type=EvalTypes.NODE):
"""Given a golden tree and a predicted tree, this counts correctly
predicted nodes (true positives), all predicted nodes (true + false
positives), and all golden nodes (true positives + false negatives).
Expand All @@ -59,25 +59,25 @@ def tp_fp_fn(gold_ttree, pred_ttree, eval_type=EvalTypes.NODE):

def precision(gold_ttree, pred_ttree, eval_type=EvalTypes.NODE):
# # correct / # predicted
correct, predicted, _ = tp_fp_fn(gold_ttree, pred_ttree, eval_type)
correct, predicted, _ = corr_pred_gold(gold_ttree, pred_ttree, eval_type)
return correct / float(predicted)


def recall(gold_ttree, pred_ttree, eval_type=EvalTypes.NODE):
# # correct / # gold
correct, _, gold = tp_fp_fn(gold_ttree, pred_ttree, eval_type)
correct, _, gold = corr_pred_gold(gold_ttree, pred_ttree, eval_type)
return correct / float(gold)


def f1(gold_ttree, pred_ttree, eval_type=EvalTypes.NODE):
return f1_from_counts(tp_fp_fn(gold_ttree, pred_ttree, eval_type))
return f1_from_counts(corr_pred_gold(gold_ttree, pred_ttree, eval_type))


def f1_from_counts(correct, gold, predicted):
return p_r_f1_from_counts(correct, gold, predicted)[2]
def f1_from_counts(correct, predicted, gold):
return p_r_f1_from_counts(correct, predicted, gold)[2]


def p_r_f1_from_counts(correct, gold, predicted):
def p_r_f1_from_counts(correct, predicted, gold):
"""Return precision, recall, and F1 given counts of true positives (correct),
total predicted nodes, and total gold nodes.
Expand Down Expand Up @@ -105,7 +105,7 @@ def __init__(self):

def append(self, gold_tree, pred_tree):
for eval_type in EvalTypes:
correct, predicted, gold = tp_fp_fn(gold_tree, pred_tree, eval_type)
correct, predicted, gold = corr_pred_gold(gold_tree, pred_tree, eval_type)
self.correct[eval_type] += correct
self.predicted[eval_type] += predicted
self.gold[eval_type] += gold
Expand All @@ -126,8 +126,8 @@ def recall(self, eval_type=EvalTypes.NODE):

def p_r_f1(self, eval_type=EvalTypes.NODE):
return p_r_f1_from_counts(self.correct[eval_type],
self.gold[eval_type],
self.predicted[eval_type])
self.predicted[eval_type],
self.gold[eval_type])


class ASearchListsAnalyzer(object):
Expand Down

0 comments on commit edd44d2

Please sign in to comment.