# -*- coding: utf-8 -*-
"""metrics.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1iAjbOC_0Y8t2SafAmfKp29Dkrn-7T56q
"""
import tensorflow as tf
import numpy as np
from collections import Counter
def masked_accuracy(label, pred):
    """Token-level accuracy that ignores padding positions (label == 0).

    Args:
        label: Integer target tensor; zeros are treated as padding.
            (Assumed shape (batch, seq_len) — the argmax below implies
            `pred` has one extra trailing class dimension.)
        pred: Prediction tensor whose last axis holds per-class scores.

    Returns:
        Scalar float32 tensor: correct non-padding tokens divided by the
        total number of non-padding tokens.
    """
    # Collapse the class dimension to hard label predictions.
    predicted = tf.argmax(pred, axis=2)
    # Align dtypes so the equality comparison below is valid.
    targets = tf.cast(label, predicted.dtype)
    # True where the token is real content rather than padding.
    non_pad = targets != 0
    # Count a position only when it is both correct and not padding.
    correct = (targets == predicted) & non_pad
    correct_f = tf.cast(correct, dtype=tf.float32)
    non_pad_f = tf.cast(non_pad, dtype=tf.float32)
    # Mean accuracy over the unmasked positions only.
    return tf.reduce_sum(correct_f) / tf.reduce_sum(non_pad_f)
def compute_precision(candidate_ngrams, reference_ngrams):
    """Clipped n-gram precision of a candidate against a reference.

    Each candidate n-gram is credited at most as many times as it occurs
    in the reference (standard BLEU clipping via multiset intersection).

    Args:
        candidate_ngrams: List of n-gram tuples from the candidate sentence.
        reference_ngrams: List of n-gram tuples from the reference sentence.

    Returns:
        Precision in [0, 1], or 1e-10 when the candidate contributes no
        n-grams — a tiny floor that avoids division by zero downstream.
    """
    if not candidate_ngrams:
        return 1e-10
    cand_counts = Counter(candidate_ngrams)
    ref_counts = Counter(reference_ngrams)
    # Clip each n-gram's credit to its reference frequency.
    clipped_matches = sum(
        min(count, ref_counts[gram]) for gram, count in cand_counts.items()
    )
    # len(candidate_ngrams) equals the sum of the candidate counts.
    return clipped_matches / len(candidate_ngrams)
def compute_bleu_batch(references_batch, candidates_batch, max_n=4):
    """
    Compute the average BLEU score over a batch of sentence groups.

    Args:
        references_batch: List (length = batch size) of lists of reference
            sentences (whitespace-tokenizable strings).
        candidates_batch: List (length = batch size) of lists of candidate
            sentences, aligned pairwise with the references in each group.
        max_n: Maximum n-gram order used in the BLEU computation.

    Returns:
        Average BLEU score over the batch as a Python float (0.0 for an
        empty batch).
    """
    batch_size = len(references_batch)
    # Guard the final division: an empty batch scores 0 instead of crashing.
    if batch_size == 0:
        return 0.0
    total_bleu_score = 0.0
    for references, candidates in zip(references_batch, candidates_batch):
        # Reference lengths are shared by every candidate in this group when
        # choosing the closest reference for the brevity penalty.
        reference_lengths = [len(reference.split()) for reference in references]
        pair_scores = []
        for candidate, reference in zip(candidates, references):
            candidate_tokens = candidate.split()
            reference_tokens = reference.split()
            # Modified n-gram precisions for this pair only. This list is
            # reset per pair — accumulating across pairs (the old behavior)
            # mixed precisions from different sentences into one mean.
            precisions = []
            for n in range(1, max_n + 1):
                candidate_ngrams = [tuple(candidate_tokens[j:j + n])
                                    for j in range(len(candidate_tokens) - n + 1)]
                reference_ngrams = [tuple(reference_tokens[j:j + n])
                                    for j in range(len(reference_tokens) - n + 1)]
                precisions.append(compute_precision(candidate_ngrams, reference_ngrams))
            # Geometric mean of the precisions, floored at 1e-10 so a zero
            # precision does not send log() to -inf.
            geometric_mean = np.exp(np.mean(np.log(np.maximum(precisions, 1e-10))))
            candidate_length = len(candidate_tokens)
            if candidate_length == 0:
                # An empty candidate matches nothing; score it 0 rather than
                # dividing by zero in the brevity penalty.
                pair_scores.append(0.0)
                continue
            # Brevity penalty against the reference length closest to the
            # candidate length, capped at 1 for sufficiently long candidates.
            closest_ref = min(reference_lengths,
                              key=lambda length: abs(length - candidate_length))
            brevity_penalty = min(np.exp(1.0 - closest_ref / candidate_length), 1.0)
            pair_scores.append(float(geometric_mean * brevity_penalty))
        # Average over the pairs in this group; an empty group contributes 0.
        total_bleu_score += (sum(pair_scores) / len(pair_scores)) if pair_scores else 0.0
    # Scalar average over the whole batch.
    return total_bleu_score / batch_size