-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathcalc_metrics.py
139 lines (106 loc) · 4.3 KB
/
calc_metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import argparse
import json
import os
import os.path as osp
import pandas as pd
import numpy as np
def parse_csv(csv_path):
    """Load a CSV file into a pandas DataFrame."""
    return pd.read_csv(csv_path)
def save_json(json_path, data):
    """Serialize `data` as JSON to `json_path`, terminated by a newline."""
    serialized = json.dumps(data)
    with open(json_path, 'w') as out_file:
        out_file.write(serialized + '\n')
def assert_same_images(gt_imgs, pred_imgs):
    """Verify the prediction image set matches the ground-truth image set exactly.

    Raises AssertionError (listing up to 10 offending names) when an image is
    missing a prediction or a prediction refers to an unknown image.

    Uses explicit raises instead of bare `assert` statements so the validation
    is NOT stripped away when Python runs with the -O flag.
    """
    gt_set, pred_set = set(gt_imgs), set(pred_imgs)
    # Ground-truth images that received no prediction.
    gt_without_pred = list(gt_set - pred_set)
    if gt_without_pred:
        raise AssertionError('Error! There are images without prediction, e.g. {}'.format(str(gt_without_pred[:10])))
    # Predictions that do not correspond to any ground-truth image.
    not_gt_predicted = list(pred_set - gt_set)
    if not_gt_predicted:
        raise AssertionError('Error! There are predictions for other images, e.g. {}'.format(str(not_gt_predicted[:10])))
def normalize(v):
    """Scale vector `v` to unit Euclidean (L2) length."""
    length = np.linalg.norm(v, ord=2)
    return np.array(v) / length
def angle(v1, v2):
    """Return the angle between vectors `v1` and `v2`, in degrees."""
    # Bring both vectors to unit length so the dot product is the cosine.
    u1 = np.array(v1) / np.linalg.norm(v1, ord=2)
    u2 = np.array(v2) / np.linalg.norm(v2, ord=2)
    # Clip guards arccos against tiny floating-point overshoot outside [-1, 1].
    cos_sim = (u1 * u2).sum().clip(-1, 1)
    return np.arccos(cos_sim) / np.pi * 180
def repr_ang_error(gt, pred):
    """Per-sample reproduction angular errors, in degrees.

    See Finlayson, Graham D., and Roshanak Zakizadeh, "Reproduction angular
    error: An improved performance metric for illuminant estimation."

    NOTE: in IEC#1 the argument order was mistakenly swapped. That slightly
    changed metric values while keeping the competitors' order the same.
    """
    white = [1, 1, 1]
    return np.array([angle(white, v_gt / v_pred) for v_gt, v_pred in zip(gt, pred)])
def mean_repr_ang_error(gt, pred):
    """Average reproduction angular error over all samples."""
    return np.mean(repr_ang_error(gt=gt, pred=pred))
def worst_mean_repr_ang_error(gt, pred, skip_share=0.75):
    """Mean over the worst (1 - skip_share) fraction of per-sample errors.

    The per-sample errors are sorted ascending and the lowest `skip_share`
    fraction is discarded before averaging.
    """
    errs = repr_ang_error(gt=gt, pred=pred)
    cutoff = int(skip_share * len(errs))
    return np.mean(sorted(errs)[cutoff:])
def two_illuminant_error(gt, pred):
    """Mean squared symmetrized reproduction angular error.

    `gt` and `pred` are each a pair of arrays (one per illuminant). For every
    sample, the sum of squared errors is computed for both possible pairings
    of predicted and ground-truth illuminants (straight and swapped), and the
    smaller of the two is kept before averaging.
    """
    gt_a, gt_b = gt
    pred_a, pred_b = pred
    sq_straight = repr_ang_error(gt_a, pred_a) ** 2 + repr_ang_error(gt_b, pred_b) ** 2
    sq_swapped = repr_ang_error(gt_a, pred_b) ** 2 + repr_ang_error(gt_b, pred_a) ** 2
    return np.mean(np.minimum(sq_straight, sq_swapped))
def calc_metrics(gt, pred, problem_type):
    """Compute the challenge metric for one problem track.

    Args:
        gt: DataFrame with an 'image' column plus ground-truth illuminant
            columns ('r', 'g', 'b'; for 'two_illuminant' instead
            'r1'..'b1' and 'r2'..'b2').
        pred: DataFrame with the same 'image' column and the corresponding
            predicted illuminant columns.
        problem_type: one of 'indoor', 'general', 'two_illuminant'.

    Returns:
        dict mapping the metric name to its scalar value.

    Raises:
        AssertionError: on an unknown problem type or mismatched image sets.
    """
    # Explicit raise (not a bare assert) so validation survives `python -O`;
    # this also makes the old unreachable trailing `else` branch unnecessary.
    if problem_type not in ('indoor', 'general', 'two_illuminant'):
        raise AssertionError('Unknown problem type: {!r}'.format(problem_type))
    assert_same_images(gt['image'], pred['image'])
    # Prefix prediction columns so they can coexist with GT columns after the
    # join; the image sets are equal, so the left join aligns every row.
    pred = pred.add_prefix('p_')
    joined = gt.set_index('image').join(pred.set_index('p_image'))
    if problem_type == 'indoor':
        error = mean_repr_ang_error(
            joined[['r', 'g', 'b']].to_numpy(),
            joined[['p_r', 'p_g', 'p_b']].to_numpy()
        )
        return {'mean_repr_ang_error': error}
    if problem_type == 'general':
        error = worst_mean_repr_ang_error(
            joined[['r', 'g', 'b']].to_numpy(),
            joined[['p_r', 'p_g', 'p_b']].to_numpy()
        )
        return {'worst_mean_repr_ang_error': error}
    # problem_type == 'two_illuminant'
    gt_vals = (
        joined[['r1', 'g1', 'b1']].to_numpy(),
        joined[['r2', 'g2', 'b2']].to_numpy(),
    )
    pred_vals = (
        joined[['p_r1', 'p_g1', 'p_b1']].to_numpy(),
        joined[['p_r2', 'p_g2', 'p_b2']].to_numpy(),
    )
    return {'two_illuminant_error': two_illuminant_error(gt_vals, pred_vals)}
###############################################################################
# main
def parse_args():
    """Parse command-line arguments for the metric-calculation script."""
    cli = argparse.ArgumentParser(
        "Calculate final metrics for the ICMV 2020 2nd IEC challenge")
    cli.add_argument('--gt', required=True, help="csv with ground_truth answers")
    cli.add_argument('--pred', required=True, help="csv with predicted answers")
    cli.add_argument('-o', '--output', help="file to save metrics info")
    cli.add_argument('-p', '--problem', required=True, help="problem type")
    return cli.parse_args()
def main(gt, pred, problem, output=None, verbose=True):
    """Load GT and prediction CSVs, compute metrics, optionally print/save.

    Returns the metrics dict produced by calc_metrics.
    """
    metrics = calc_metrics(parse_csv(gt), parse_csv(pred), problem)
    if verbose:
        print(metrics)
    if output:
        save_json(output, metrics)
    return metrics
if __name__ == "__main__":
    cli_args = parse_args()
    main(cli_args.gt, cli_args.pred, cli_args.problem, cli_args.output)