Add PPL as evaluation function
lvapeab committed Apr 14, 2020
1 parent bbf132b commit 22d10a9
Showing 1 changed file with 22 additions and 10 deletions.
32 changes: 22 additions & 10 deletions keras_wrapper/extra/evaluation.py
@@ -14,7 +14,7 @@

# Evaluation function selection

-def get_sacrebleu_score(pred_list, verbose, extra_vars, split):
+def get_sacrebleu_score(pred_list, verbose, extra_vars, split, **kwargs):
"""
SacreBLEU! metrics
:param pred_list: dictionary of hypothesis sentences (id, sentence)
@@ -42,7 +42,8 @@ def get_sacrebleu_score(pred_list, verbose, extra_vars, split):
num_references = len(initial_references)
refs = [[] for _ in range(num_references)]
for references in gts.values():
-assert len(references) == num_references, '"get_sacrebleu_score" does not support a different number of references per sample.'
+assert len(
+    references) == num_references, '"get_sacrebleu_score" does not support a different number of references per sample.'
for ref_idx, reference in enumerate(references):
# De/Tokenize references if needed
tokenized_ref = extra_vars['tokenize_f'](reference) if extra_vars.get('tokenize_references', False) \
@@ -71,7 +72,7 @@ def get_sacrebleu_score(pred_list, verbose, extra_vars, split):
return final_scores


-def get_coco_score(pred_list, verbose, extra_vars, split):
+def get_coco_score(pred_list, verbose, extra_vars, split, **kwargs):
"""
COCO challenge metrics
:param pred_list: dictionary of hypothesis sentences (id, sentence)
@@ -135,7 +136,17 @@ def get_coco_score(pred_list, verbose, extra_vars, split):
return final_scores


-def eval_vqa(pred_list, verbose, extra_vars, split):
+def get_perplexity(*args, **kwargs):
+    """
+    Get perplexity
+    """
+    metric = 'Perplexity'
+    ppl = np.average(np.exp(kwargs['costs']))
+    logger.info(metric + ': ' + str(ppl))
+    return {metric: ppl}
+
+
+def eval_vqa(pred_list, verbose, extra_vars, split, **kwargs):
"""
VQA challenge metrics
:param pred_list: dictionary of hypothesis sentences (id, sentence)
@@ -184,7 +195,7 @@ def eval_vqa(pred_list, verbose, extra_vars, split):
'other accuracy': acc_other}


-def multilabel_metrics(pred_list, verbose, extra_vars, split):
+def multilabel_metrics(pred_list, verbose, extra_vars, split, **kwargs):
"""
Multilabel classification metrics. See multilabel ranking metrics in sklearn library for more info:
http://scikit-learn.org/stable/modules/model_evaluation.html#multilabel-ranking-metrics
@@ -265,7 +276,7 @@ def multilabel_metrics(pred_list, verbose, extra_vars, split):
'f1': f1}


-def multiclass_metrics(pred_list, verbose, extra_vars, split):
+def multiclass_metrics(pred_list, verbose, extra_vars, split, **kwargs):
"""
Multiclass classification metrics. See multilabel ranking metrics in sklearn library for more info:
http://scikit-learn.org/stable/modules/model_evaluation.html#multilabel-ranking-metrics
@@ -352,7 +363,7 @@ def multiclass_metrics(pred_list, verbose, extra_vars, split):
'top5_fps': list(top5_fps)}


-def semantic_segmentation_accuracy(pred_list, verbose, extra_vars, split):
+def semantic_segmentation_accuracy(pred_list, verbose, extra_vars, split, **kwargs):
"""
Semantic Segmentation Accuracy metric
@@ -404,7 +415,7 @@ def semantic_segmentation_accuracy(pred_list, verbose, extra_vars, split):
return {'semantic global accuracy': accuracy}


-def semantic_segmentation_meaniou(pred_list, verbose, extra_vars, split):
+def semantic_segmentation_meaniou(pred_list, verbose, extra_vars, split, **kwargs):
"""
Semantic Segmentation Mean IoU metric
@@ -475,7 +486,7 @@ def semantic_segmentation_meaniou(pred_list, verbose, extra_vars, split):
return {'mean IoU': mean_iou, 'semantic global accuracy': acc}


-def averagePrecision(pred_list, verbose, extra_vars, split):
+def averagePrecision(pred_list, verbose, extra_vars, split, **kwargs):
"""
Computes a Precision-Recall curve and its associated mAP score given a set of precalculated reports.
The parameter "report_all" must include the following information for each sample:
@@ -759,7 +770,7 @@ def _computeMeasures(IoU, n_classes, predicted_bboxes, predicted_Y, predicted_sc


# AUXILIARY FUNCTIONS
-def vqa_store(question_id_list, answer_list, path):
+def vqa_store(question_id_list, answer_list, path, **kwargs):
"""
Saves the answers on question_id_list in the VQA-like format.
@@ -793,6 +804,7 @@ def caption_store(samples, path):
'vqa': eval_vqa, # Metric for the VQA challenge
'coco': get_coco_score, # MS COCO evaluation library (BLEU, METEOR and CIDEr scores)
'sacrebleu': get_sacrebleu_score, # SacreBLEU evaluation library (BLEU score)
+'perplexity': get_perplexity, # Perplexity
'multilabel_metrics': multilabel_metrics, # Set of multilabel classification metrics from sklearn
'multiclass_metrics': multiclass_metrics, # Set of multiclass classification metrics from sklearn
'AP': averagePrecision,
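The new metric boils down to np.average(np.exp(costs)): the mean of the per-sample perplexities exp(cost_i), under the assumption that each cost is that sample's average negative log-likelihood. A minimal usage sketch, not part of the commit; the cost values below are made up:

import numpy as np

from keras_wrapper.extra.evaluation import get_perplexity

# get_perplexity ignores positional arguments and only reads kwargs['costs'],
# assumed here to be per-sample average negative log-likelihoods.
sample_costs = np.array([2.1, 1.7, 2.4])
scores = get_perplexity(costs=sample_costs)
print(scores)  # {'Perplexity': ~8.22} == np.average(np.exp(sample_costs))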

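Adding **kwargs to every evaluation function keeps their signatures call-compatible, so a caller can dispatch through the metric-selection dict at the bottom of the file and pass extra keyword arguments (such as costs) that only some metrics consume. A rough dispatch sketch under that assumption; the dict's real variable name is not visible in this diff, so select below is purely illustrative:

from keras_wrapper.extra import evaluation

# 'select' is an assumed name standing in for the selection dict shown above.
select = {
    'coco': evaluation.get_coco_score,
    'perplexity': evaluation.get_perplexity,
}

def run_metric(metric_name, pred_list, extra_vars, split, **kwargs):
    # Thanks to **kwargs, keywords a metric does not need (e.g. costs for
    # 'coco', or pred_list for 'perplexity') are simply ignored.
    return select[metric_name](pred_list=pred_list, verbose=1,
                               extra_vars=extra_vars, split=split, **kwargs)

# run_metric('perplexity', pred_list=[], extra_vars={}, split='val', costs=[2.1, 1.7, 2.4])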