-
Notifications
You must be signed in to change notification settings - Fork 581
/
score.py
161 lines (131 loc) · 6.14 KB
/
score.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""Evaluation Metrics for Semantic Segmentation"""
import torch
import numpy as np
# Public API of this module (names exported by ``from ... import *``).
__all__ = [
    'SegmentationMetric',
    'batch_pix_accuracy',
    'batch_intersection_union',
    'pixelAccuracy',
    'intersectionAndUnion',
    'hist_info',
    'compute_score',
]
class SegmentationMetric(object):
    """Accumulate pixel accuracy (pixAcc) and mean IoU (mIoU) across batches.

    Call ``update`` once per batch, ``get`` to read the running scores and
    ``reset`` to start a new evaluation.
    """

    def __init__(self, nclass):
        """
        Parameters
        ----------
        nclass : int
            Number of classes; sizes the per-class intersection/union totals.
        """
        super(SegmentationMetric, self).__init__()
        self.nclass = nclass
        self.reset()

    def update(self, preds, labels):
        """Updates the internal evaluation result.

        Parameters
        ----------
        preds : torch.Tensor or list/tuple of torch.Tensor
            Predicted class scores; the predicted class is the argmax over
            dim 1.
        labels : torch.Tensor or list/tuple of torch.Tensor
            Ground-truth class indices; negative labels are ignored. When
            ``preds`` is a list/tuple, ``labels`` is iterated alongside it
            with ``zip`` (extra elements of the longer one are dropped).

        Raises
        ------
        TypeError
            If ``preds`` is neither a tensor nor a list/tuple, instead of
            silently leaving the totals unchanged.
        """
        if isinstance(preds, torch.Tensor):
            self._evaluate_one(preds, labels)
        elif isinstance(preds, (list, tuple)):
            for pred, label in zip(preds, labels):
                self._evaluate_one(pred, label)
        else:
            raise TypeError(
                "preds must be a torch.Tensor or a list/tuple of torch.Tensor, "
                "got {}".format(type(preds).__name__))

    def _evaluate_one(self, pred, label):
        """Add the statistics of one (pred, label) pair to the running totals."""
        correct, labeled = batch_pix_accuracy(pred, label)
        inter, union = batch_intersection_union(pred, label, self.nclass)
        self.total_correct += correct
        self.total_label += labeled
        # Move the accumulators to the device of the incoming counts so the
        # in-place additions below do not fail on a device mismatch.
        if self.total_inter.device != inter.device:
            self.total_inter = self.total_inter.to(inter.device)
            self.total_union = self.total_union.to(union.device)
        self.total_inter += inter
        self.total_union += union

    def get(self):
        """Gets the current evaluation result.

        Returns
        -------
        pixAcc : float
            Accumulated correct pixels / accumulated labeled pixels.
        mIoU : float
            Mean over all ``nclass`` classes of intersection / union; a class
            with zero union contributes 0.
        """
        # Same value as np.spacing(1); avoids 0/0 (giving 0) before any data.
        eps = 2.220446049250313e-16
        pixAcc = 1.0 * self.total_correct / (eps + self.total_label)
        IoU = 1.0 * self.total_inter / (eps + self.total_union)
        mIoU = IoU.mean().item()
        return pixAcc, mIoU

    def reset(self):
        """Resets the internal evaluation result to initial state."""
        self.total_inter = torch.zeros(self.nclass)
        self.total_union = torch.zeros(self.nclass)
        self.total_correct = 0
        self.total_label = 0
# pytorch version
def batch_pix_accuracy(output, target):
    """Count correctly classified and labeled pixels for one batch.

    Parameters
    ----------
    output : torch.Tensor
        Per-class scores; the predicted class is the argmax over dim 1.
    target : torch.Tensor
        Ground-truth class indices; pixels with a negative label are ignored.

    Returns
    -------
    pixel_correct : int
        Labeled pixels whose predicted class equals the target.
    pixel_labeled : int
        Pixels with a non-negative target.
    """
    # Take the argmax on the raw scores. Casting to an integer type first
    # would truncate fractional logits/probabilities (0.3 and 0.7 both become
    # 0), so argmax would just return the first class.
    predict = torch.argmax(output, 1) + 1
    # Shift by one so ignored (negative) labels become <= 0.
    target = target.long() + 1
    labeled = target > 0
    pixel_labeled = torch.sum(labeled).item()
    pixel_correct = torch.sum((predict == target) & labeled).item()
    assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled"
    return pixel_correct, pixel_labeled
def batch_intersection_union(output, target, nclass):
    """Per-class intersection and union pixel counts for one batch.

    Parameters
    ----------
    output : torch.Tensor
        Per-class scores; the predicted class is the argmax over dim 1.
    target : torch.Tensor
        Ground-truth class indices; pixels with a negative label are ignored.
    nclass : int
        Number of classes.

    Returns
    -------
    (area_inter, area_union) : tuple of float CPU tensors of length ``nclass``
    """
    def _count(ids):
        # Histogram of 1-based class ids 1..nclass; values outside that range
        # (including the 0 used below for "ignored / no match") are dropped.
        return torch.histc(ids.cpu(), bins=nclass, min=1, max=nclass)

    # Work with 1-based ids so that 0 can mean "ignored or no match".
    pred_ids = torch.argmax(output, 1) + 1
    label_ids = target.float() + 1
    # Zero out predictions on ignored pixels so they are not counted.
    pred_ids = pred_ids.float() * (label_ids > 0).float()
    # Keep the class id only where prediction and label agree.
    hits = pred_ids * (pred_ids == label_ids).float()

    area_inter = _count(hits)
    area_union = _count(pred_ids) + _count(label_ids) - area_inter
    assert torch.sum(area_inter > area_union).item() == 0, "Intersection area should be smaller than Union area"
    return area_inter.float(), area_union.float()
def pixelAccuracy(imPred, imLab):
    """
    Return pixel-wise accuracy for the prediction and label of a single image.

    Pixels whose label is negative are treated as unlabeled and excluded.

    To compute over many images do:
        for i in range(Nimages):
            (pixel_accuracy[i], pixel_correct[i], pixel_labeled[i]) = pixelAccuracy(imPred[i], imLab[i])
        mean_pixel_accuracy = 1.0 * np.sum(pixel_correct) / (np.spacing(1) + np.sum(pixel_labeled))

    Returns
    -------
    (pixel_accuracy, pixel_correct, pixel_labeled)
        ``pixel_accuracy`` is 0.0 when the image has no labeled pixels.
    """
    # Remove classes from unlabeled pixels in gt image.
    # We should not penalize detections in unlabeled portions of the image.
    labeled_mask = imLab >= 0
    pixel_labeled = np.sum(labeled_mask)
    pixel_correct = np.sum((imPred == imLab) * labeled_mask)
    # Guard 0/0 for fully unlabeled images (would otherwise be nan + warning).
    if pixel_labeled > 0:
        pixel_accuracy = 1.0 * pixel_correct / pixel_labeled
    else:
        pixel_accuracy = 0.0
    return (pixel_accuracy, pixel_correct, pixel_labeled)
def intersectionAndUnion(imPred, imLab, numClass):
    """
    Return per-class intersection and union areas for a single image.

    Labels are class indices in ``[0, numClass)``; negative labels mark
    unlabeled pixels, whose predictions are ignored. Element ``c`` of each
    returned array refers to class ``c``.

    To compute over many images do:
        for i in range(Nimages):
            (area_intersection[:, i], area_union[:, i]) = intersectionAndUnion(imPred[i], imLab[i], numClass)
        IoU = 1.0 * np.sum(area_intersection, axis=1) / (np.spacing(1) + np.sum(area_union, axis=1))
    """
    # Shift to 1-based ids (new arrays; inputs are not modified). Without the
    # shift, class 0 falls outside the histogram range below and correct
    # class-0 predictions look like misses in ``intersection``.
    imPred = np.asarray(imPred) + 1
    imLab = np.asarray(imLab) + 1
    # Remove classes from unlabeled pixels in gt image.
    # We should not penalize detections in unlabeled portions of the image.
    imPred = imPred * (imLab > 0)
    # Compute area intersection (0 = mismatch/unlabeled, outside the range):
    intersection = imPred * (imPred == imLab)
    (area_intersection, _) = np.histogram(intersection, bins=numClass, range=(1, numClass))
    # Compute area union:
    (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass))
    (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass))
    area_union = area_pred + area_lab - area_intersection
    return (area_intersection, area_union)
def hist_info(pred, label, num_cls):
    """Build a confusion matrix plus labeled/correct pixel counts.

    Parameters
    ----------
    pred : np.ndarray
        Predicted class indices, same shape as ``label``. Values are assumed
        to lie in ``[0, num_cls)``; otherwise the reshape below fails.
    label : np.ndarray
        Ground-truth class indices; entries outside ``[0, num_cls)`` are
        ignored.
    num_cls : int
        Number of classes.

    Returns
    -------
    hist : np.ndarray, shape (num_cls, num_cls)
        ``hist[l, p]`` counts valid pixels with label ``l`` predicted as ``p``.
    labeled : int
        Number of valid (in-range) labeled pixels.
    correct : int
        Number of valid pixels where prediction equals label.
    """
    assert pred.shape == label.shape
    valid = (label >= 0) & (label < num_cls)
    labeled = np.sum(valid)
    correct = np.sum(pred[valid] == label[valid])
    # Cast both operands: with float predictions the combined index would be
    # float, and np.bincount only accepts integer input.
    index = num_cls * label[valid].astype(int) + pred[valid].astype(int)
    hist = np.bincount(index, minlength=num_cls ** 2).reshape(num_cls, num_cls)
    return hist, labeled, correct
def compute_score(hist, correct, labeled):
    """Derive IoU scores and pixel accuracy from a confusion matrix.

    Parameters
    ----------
    hist : np.ndarray, shape (num_cls, num_cls)
        Square confusion matrix (e.g. the first value returned by
        ``hist_info``).
    correct : int
        Number of correctly classified pixels.
    labeled : int
        Number of labeled pixels.

    Returns
    -------
    iu : np.ndarray
        Per-class IoU; nan for classes absent from both prediction and label.
    mean_IU : float
        nan-ignoring mean of ``iu``.
    mean_IU_no_back : float
        nan-ignoring mean of ``iu`` excluding class 0 (presumably background).
    mean_pixel_acc : float
        ``correct / labeled``.
    """
    diag = np.diag(hist)
    # 0/0 for classes that never occur yields nan on purpose (nanmean skips
    # it), so silence the corresponding floating-point warning.
    with np.errstate(divide='ignore', invalid='ignore'):
        iu = diag / (hist.sum(1) + hist.sum(0) - diag)
    mean_IU = np.nanmean(iu)
    mean_IU_no_back = np.nanmean(iu[1:])
    mean_pixel_acc = correct / labeled
    return iu, mean_IU, mean_IU_no_back, mean_pixel_acc