Skip to content

Commit

Permalink
Add try/catch block across ZeroDivisionError
Browse files Browse the repository at this point in the history
  • Loading branch information
NISH1001 committed Mar 30, 2023
1 parent 98b1da3 commit 6bdf680
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions jury/metrics/accuracy/accuracy_for_language_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,14 @@
import evaluate
import numpy as np

from evaluate.utils.logging import get_logger

from jury.collator import Collator
from jury.metrics._core import MetricForLanguageGeneration
from jury.utils.nlp import normalize_text

logger = get_logger(__name__)

_CITATION = """\
@article{scikit-learn,
title={Scikit-learn: Machine Learning in {P}ython},
Expand Down Expand Up @@ -66,7 +70,7 @@
>>> accuracy = jury.load_metric("accuracy")
>>> predictions = [["the cat is on the mat", "There is cat playing on the mat"], ["Look! a wonderful day."]]
>>> references = [
["the cat is playing on the mat.", "The cat plays on the mat."],
["the cat is playing on the mat.", "The cat plays on the mat."],
["Today is a wonderful day", "The weather outside is wonderful."]
]
>>> results = accuracy.compute(predictions=predictions, references=references)
Expand Down Expand Up @@ -104,7 +108,11 @@ def _compute_single_pred_single_ref(
for token, pred_count in pred_counts.items():
if token in ref_counts:
score += min(pred_count, ref_counts[token]) # Intersection count
scores.append(score / max(len(pred), len(ref)))
try:
scores.append(score / max(len(pred), len(ref)))
except ZeroDivisionError:
logger.warning("Empty pred/ref. Ignoring!")

avg_score = sum(scores) / len(scores)
return {"score": avg_score}

Expand Down

0 comments on commit 6bdf680

Please sign in to comment.