Skip to content

Commit db208c7

Browse files
committed
initial commit
1 parent e60618d commit db208c7

8 files changed

+856
-2
lines changed

Diff for: MTM/NMS.py

+202
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Non-Maxima Supression (NMS) for match template
4+
From a pool of bounding box each predicting possible object locations with a given score,
5+
the NMS removes the bounding boxes overlapping with a bounding box of higher score above the maxOverlap threshold
6+
7+
This effectively removes redundant detections of the same object and still allow the detection of close objects (ie small possible overlap)
8+
The "possible" allowed overlap is set by the variable maxOverlap (between 0 and 1) which is the ratio Intersection over Union (IoU) area for a given pair of overlaping BBoxes
9+
10+
11+
@author: Laurent Thomas
12+
"""
13+
from __future__ import division, print_function # for compatibility with Py2
14+
15+
def Point_in_Rectangle(Point, Rectangle):
16+
'''Return True if a point (x,y) is contained in a Rectangle(x, y, width, height)'''
17+
# unpack variables
18+
Px, Py = Point
19+
Rx, Ry, w, h = Rectangle
20+
21+
return (Rx <= Px) and (Px <= Rx + w -1) and (Ry <= Py) and (Py <= Ry + h -1) # simply test if x_Point is in the range of x for the rectangle
22+
23+
24+
def computeIoU(BBox1,BBox2):
25+
'''
26+
Compute the IoU (Intersection over Union) between 2 rectangular bounding boxes defined by the top left (Xtop,Ytop) and bottom right (Xbot, Ybot) pixel coordinates
27+
Code adapted from https://www.pyimagesearch.com/2016/11/07/intersection-over-union-iou-for-object-detection/
28+
'''
29+
#print('BBox1 : ', BBox1)
30+
#print('BBox2 : ', BBox2)
31+
32+
# Unpack input (python3 - tuple are no more supported as input in function definition - PEP3113 - Tuple can be used in as argument in a call but the function will not unpack it automatically)
33+
Xleft1, Ytop1, Width1, Height1 = BBox1
34+
Xleft2, Ytop2, Width2, Height2 = BBox2
35+
36+
# Compute bottom coordinates
37+
Xright1 = Xleft1 + Width1 -1 # we remove -1 from the width since we start with 1 pixel already (the top one)
38+
Ybot1 = Ytop1 + Height1 -1 # idem for the height
39+
40+
Xright2 = Xleft2 + Width2 -1
41+
Ybot2 = Ytop2 + Height2 -1
42+
43+
# determine the (x, y)-coordinates of the top left and bottom right points of the intersection rectangle
44+
Xleft = max(Xleft1, Xleft2)
45+
Ytop = max(Ytop1, Ytop2)
46+
Xright = min(Xright1, Xright2)
47+
Ybot = min(Ybot1, Ybot2)
48+
49+
# Compute boolean for inclusion
50+
BBox1_in_BBox2 = Point_in_Rectangle((Xleft1, Ytop1), BBox2) and Point_in_Rectangle((Xleft1, Ybot1), BBox2) and Point_in_Rectangle((Xright1, Ytop1), BBox2) and Point_in_Rectangle((Xright1, Ybot1), BBox2)
51+
BBox2_in_BBox1 = Point_in_Rectangle((Xleft2, Ytop2), BBox1) and Point_in_Rectangle((Xleft2, Ybot2), BBox1) and Point_in_Rectangle((Xright2, Ytop2), BBox1) and Point_in_Rectangle((Xright2, Ybot2), BBox1)
52+
53+
# Check that for the intersection box, Xtop,Ytop is indeed on the top left of Xbot,Ybot
54+
if BBox1_in_BBox2 or BBox2_in_BBox1:
55+
#print('One BBox is included within the other')
56+
IoU = 1
57+
58+
elif Xright<Xleft or Ybot<Ytop : # it means that there is no intersection (bbox is inverted)
59+
#print('No overlap')
60+
IoU = 0
61+
62+
else:
63+
# Compute area of the intersecting box
64+
Inter = (Xright - Xleft + 1) * (Ybot - Ytop + 1) # +1 since we are dealing with pixels. See a 1D example with 3 pixels for instance
65+
#print('Intersection area : ', Inter)
66+
67+
# Compute area of the union as Sum of the 2 BBox area - Intersection
68+
Union = Width1 * Height1 + Width2 * Height2 - Inter
69+
#print('Union : ', Union)
70+
71+
# Compute Intersection over union
72+
IoU = Inter/Union
73+
74+
#print('IoU : ',IoU)
75+
return IoU
76+
77+
78+
79+
def NMS(List_Hit, scoreThreshold=None, sortDescending=True, N_object=float("inf"), maxOverlap=0.5):
80+
'''
81+
Perform Non-Maxima supression : it compares the hits after maxima/minima detection, and removes the ones that are too close (too large overlap)
82+
This function works both with an optionnal threshold on the score, and number of detected bbox
83+
84+
if a scoreThreshold is specified, we first discard any hit below/above the threshold (depending on sortDescending)
85+
if sortDescending = True, the hit with score below the treshold are discarded (ie when high score means better prediction ex : Correlation)
86+
if sortDescending = False, the hit with score above the threshold are discared (ie when low score means better prediction ex : Distance measure)
87+
88+
Then the hit are ordered so that we have the best hits first.
89+
Then we iterate over the list of hits, taking one hit at a time and checking for overlap with the previous validated hit (the Final Hit list is directly iniitialised with the first best hit as there is no better hit with which to compare overlap)
90+
91+
This iteration is terminate once we have collected N best hit, or if there are no more hit left to test for overlap
92+
93+
INPUT
94+
- ListHit : a list of dictionnary, with each dictionnary being a hit following the formating {'TemplateIdx'= (int),'BBox'=(x,y,width,height),'Score'=(float)}
95+
the TemplateIdx is the row index in the panda/Knime table
96+
97+
- scoreThreshold : Float (or None), used to remove hit with too low prediction score.
98+
If sortDescending=True (ie we use a correlation measure so we want to keep large scores) the scores above that threshold are kept
99+
While if we use sortDescending=False (we use a difference measure ie we want to keep low score), the scores below that threshold are kept
100+
101+
- N_object : number of best hit to return (by increasing score). Min=1, eventhough it does not really make sense to do NMS with only 1 hit
102+
- maxOverlap : float between 0 and 1, the maximal overlap authorised between 2 bounding boxes, above this value, the bounding box of lower score is deleted
103+
- sortDescending : use True when high score means better prediction, False otherwise (ex : if score is a difference measure, then the best prediction are low difference and we sort by ascending order)
104+
105+
OUTPUT
106+
List_nHit : List of the best detection after NMS, it contains max N detection (but potentially less)
107+
'''
108+
109+
# Apply threshold on prediction score
110+
if scoreThreshold==None :
111+
List_ThreshHit = List_Hit[:] # copy to avoid modifying the input list in place
112+
113+
elif sortDescending : # We keep hit above the threshold
114+
List_ThreshHit = [dico for dico in List_Hit if dico['Score']>=scoreThreshold]
115+
116+
elif not sortDescending : # We keep hit below the threshold
117+
List_ThreshHit = [dico for dico in List_Hit if dico['Score']<=scoreThreshold]
118+
119+
120+
# Sort score to have best predictions first (important as we loop testing the best boxes against the other boxes)
121+
if sortDescending:
122+
List_ThreshHit.sort(key=lambda dico: dico['Score'], reverse=True) # Hit = [list of (x,y),score] - sort according to descending (best = high correlation)
123+
else:
124+
List_ThreshHit.sort(key=lambda dico: dico['Score']) # sort according to ascending score (best = small difference)
125+
126+
127+
# Split the inital pool into Final Hit that are kept and restHit that can be tested
128+
# Initialisation : 1st keep is kept for sure, restHit is the rest of the list
129+
#print("\nInitialise final hit list with first best hit")
130+
FinalHit = [List_ThreshHit[0]]
131+
restHit = List_ThreshHit[1:]
132+
133+
134+
# Loop to compute overlap
135+
while len(FinalHit)<N_object and restHit : # second condition is restHit is not empty
136+
137+
# Report state of the loop
138+
#print("\n\n\nNext while iteration")
139+
140+
#print("-> Final hit list")
141+
#for hit in FinalHit: print(hit)
142+
143+
#print("\n-> Remaining hit list")
144+
#for hit in restHit: print(hit)
145+
146+
# pick the next best peak in the rest of peak
147+
test_hit = restHit[0]
148+
test_bbox = test_hit['BBox']
149+
#print("\nTest BBox:{} for overlap against higher score bboxes".format(test_bbox))
150+
151+
# Loop over hit in FinalHit to compute successively overlap with test_peak
152+
for hit in FinalHit:
153+
154+
# Recover Bbox from hit
155+
bbox2 = hit['BBox']
156+
157+
# Compute the Intersection over Union between test_peak and current peak
158+
IoU = computeIoU(test_bbox, bbox2)
159+
160+
# Initialise the boolean value to true before test of overlap
161+
ToAppend = True
162+
163+
if IoU>maxOverlap:
164+
ToAppend = False
165+
#print("IoU above threshold\n")
166+
break # no need to test overlap with the other peaks
167+
168+
else:
169+
#print("IoU below threshold\n")
170+
# no overlap for this particular (test_peak,peak) pair, keep looping to test the other (test_peak,peak)
171+
continue
172+
173+
174+
# After testing against all peaks (for loop is over), append or not the peak to final
175+
if ToAppend:
176+
# Move the test_hit from restHit to FinalHit
177+
#print("Append {} to list of final hits, remove it from Remaining hit list".format(test_hit))
178+
FinalHit.append(test_hit)
179+
restHit.remove(test_hit)
180+
181+
else:
182+
# only remove the test_peak from restHit
183+
#print("Remove {} from Remaining hit list".format(test_hit))
184+
restHit.remove(test_hit)
185+
186+
187+
# Once function execution is done, return list of hit without overlap
188+
#print("\nCollected N expected hit, or no hit left to test")
189+
#print("NMS over\n")
190+
return FinalHit
191+
192+
193+
if __name__ == "__main__":
194+
Hit1 = {'TemplateIdx':1,'BBox':(780, 350, 700, 480), 'Score':0.8}
195+
Hit2 = {'TemplateIdx':1,'BBox':(806, 416, 716, 442), 'Score':0.6}
196+
Hit3 = {'TemplateIdx':1,'BBox':(1074, 530, 680, 390), 'Score':0.4}
197+
198+
ListHit = [Hit1, Hit2, Hit3]
199+
200+
ListFinalHit = NMS(ListHit)
201+
202+
print(ListFinalHit)

Diff for: MTM/__init__.py

+168
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
import cv2
2+
import numpy as np
3+
from skimage.feature import peak_local_max
4+
from scipy.signal import find_peaks
5+
6+
from MTM.NMS import NMS
7+
8+
__all__ = ['NMS']
9+
10+
11+
def _findLocalMax_(corrMap, score_threshold=0.6):
12+
'''
13+
Get coordinates of the local maximas with values above a threshold in the image of the correlation map
14+
'''
15+
16+
# IF depending on the shape of the correlation map
17+
if corrMap.shape == (1,1): ## Template size = Image size -> Correlation map is a single digit')
18+
19+
if corrMap[0,0]>=score_threshold:
20+
Peaks = np.array([[0,0]])
21+
else:
22+
Peaks = []
23+
24+
# use scipy findpeaks for the 1D cases (would allow to specify the relative threshold for the score directly here rather than in the NMS
25+
elif corrMap.shape[0] == 1: ## Template is as high as the image, the correlation map is a 1D-array
26+
Peaks = find_peaks(corrMap[0], height=score_threshold) # corrMap[0] to have a proper 1D-array
27+
Peaks = [[0,i] for i in Peaks[0]] # 0,i since one coordinate is fixed (the one for which Template = Image)
28+
29+
30+
elif corrMap.shape[1] == 1: ## Template is as wide as the image, the correlation map is a 1D-array
31+
#Peaks = argrelmax(corrMap, mode="wrap")
32+
Peaks = find_peaks(corrMap[:,0], height=score_threshold)
33+
Peaks = [[i,0] for i in Peaks[0]]
34+
35+
36+
else: # Correlatin map is 2D
37+
Peaks = peak_local_max(corrMap, threshold_abs=score_threshold, exclude_border=False).tolist()
38+
39+
return Peaks
40+
41+
42+
43+
def _findLocalMin_(corrMap, score_threshold=0.4):
44+
'''Find coordinates of local minimas with values below a threshold in the image of the correlation map'''
45+
return _findLocalMax_(-corrMap, -score_threshold)
46+
47+
48+
49+
def findMatches(listTemplates, image, method=cv2.TM_CCOEFF_NORMED, N_object=float("inf"), score_threshold=0.5):
50+
'''
51+
Find all possible templates locations provided a list of template to search and an image
52+
- listTemplate : list of tuples [(templateName, templateImage), (templateName2, templateImage2) ]
53+
- method : one of OpenCV template matching method (0 to 5)
54+
- N_object: expected number of object in the image
55+
- score_threshold: if N>1, returns local minima/maxima respectively below/above the score_threshold
56+
'''
57+
if N_object!=float("inf") and type(N_object)!=int:
58+
raise TypeError("N_object must be an integer")
59+
60+
elif N_object<1:
61+
raise ValueError("At least one object should be expected in the image")
62+
63+
## 16-bit image are converted to 32-bit for matchTemplate
64+
if image.dtype == 'uint16': image = np.float32(image)
65+
66+
listHit = []
67+
for templateName, template in listTemplates:
68+
69+
#print('\nSearch with template : ',templateName)
70+
## 16-bit image are converted to 32-bit for matchTemplate
71+
if template.dtype == 'uint16': template = np.float32(template)
72+
73+
## Compute correlation map
74+
corrMap = cv2.matchTemplate(template, image, method)
75+
76+
## Find possible location of the object
77+
if N_object==1: # Detect global Min/Max
78+
minVal, maxVal, minLoc, maxLoc = cv2.minMaxLoc(corrMap)
79+
80+
if method==1:
81+
Peaks = [minLoc[::-1]] # opposite sorting than in the multiple detection
82+
83+
elif method in (3,5):
84+
Peaks = [maxLoc[::-1]]
85+
86+
87+
else:# Detect local max or min
88+
if method==1: # Difference => look for local minima
89+
Peaks = _findLocalMin_(corrMap, score_threshold)
90+
91+
elif method in (3,5):
92+
Peaks = _findLocalMax_(corrMap, score_threshold)
93+
94+
95+
#print('Initially found',len(Peaks),'hit with this template')
96+
97+
98+
# Once every peak was detected for this given template
99+
## Create a dictionnary for each hit with {'TemplateName':, 'BBox': (x,y,Width, Height), 'Score':coeff}
100+
101+
height, width = template.shape[0:2] # slicing make sure it works for RGB too
102+
for peak in Peaks :
103+
coeff = corrMap[tuple(peak)]
104+
newHit = {'TemplateName':templateName, 'BBox': [int(peak[1]), int(peak[0]), width, height], 'Score':coeff}
105+
106+
# append to list of potential hit before Non maxima suppression
107+
listHit.append(newHit)
108+
109+
110+
return listHit # All possible hit before Non-Maxima Supression
111+
112+
113+
def matchTemplates(listTemplates, image, method=cv2.TM_CCOEFF_NORMED, N_object=float("inf"), score_threshold=0.5, maxOverlap=0.25):
114+
'''
115+
Search each template in the image, and return the best N_object location which offer the best score and which do not overlap
116+
- listTemplate : list of tuples (templateName, templateImage)
117+
- method : one of OpenCV template matching method (0 to 5)
118+
- N_object: expected number of object in the image
119+
- score_threshold: if N>1, returns local minima/maxima respectively below/above the score_threshold
120+
'''
121+
if maxOverlap<0 or maxOverlap>1:
122+
raise ValueError("Maximal overlap between bounding box is in range [0-1]")
123+
124+
listHit = findMatches(listTemplates, image, method, N_object, score_threshold)
125+
126+
if method == 1: bestHits = NMS(listHit, N_object=N_object, maxOverlap=maxOverlap, sortDescending=False)
127+
128+
elif method in (3,5): bestHits = NMS(listHit, N_object=N_object, maxOverlap=maxOverlap, sortDescending=True)
129+
130+
return bestHits
131+
132+
133+
def drawBoxes(img, listHit, boxThickness=2, boxColor=(255, 255, 00), showLabel=True, labelColor=(255, 255, 255) ):
134+
"""
135+
Return a copy of the image with results of template matching drawn as yellow rectangle and name of the template on top
136+
"""
137+
outImage = img.copy()
138+
139+
for hit in listHit:
140+
x,y,w,h = hit['BBox']
141+
cv2.rectangle(outImage, (x, y), (x+w, y+h), color=boxColor, thickness=boxThickness)
142+
if showLabel: cv2.putText(outImage, text=hit['TemplateName'], org=(x, y), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.5, color=labelColor, lineType=cv2.LINE_AA)
143+
144+
return outImage
145+
146+
147+
if __name__ == '__main__':
148+
149+
from skimage.data import coins
150+
import matplotlib.pyplot as plt
151+
152+
## Get image and template
153+
smallCoin = coins()[37:37+38, 80:80+41]
154+
bigCoin = coins()[14:14+59,302:302+65]
155+
image = coins()
156+
157+
## Perform matching
158+
listHit = matchTemplates([('small', smallCoin), ('big', bigCoin)], image, score_threshold=0.4, method=cv2.TM_CCOEFF_NORMED, maxOverlap=0)
159+
#listHit = matchTemplates([('small', smallCoin), ('big', bigCoin)], image, N_object=1, score_threshold=0.4, method=cv2.TM_CCOEFF_NORMED, maxOverlap=0)
160+
161+
print("Found {} coins".format(len(listHit)))
162+
163+
for hit in listHit:
164+
print(hit)
165+
166+
## Display matches
167+
Overlay = drawBoxes(image, listHit)
168+
plt.imshow(Overlay)

0 commit comments

Comments
 (0)