|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +""" |
| 3 | +Non-Maxima Supression (NMS) for match template |
| 4 | +From a pool of bounding box each predicting possible object locations with a given score, |
| 5 | +the NMS removes the bounding boxes overlapping with a bounding box of higher score above the maxOverlap threshold |
| 6 | +
|
| 7 | +This effectively removes redundant detections of the same object and still allow the detection of close objects (ie small possible overlap) |
| 8 | +The "possible" allowed overlap is set by the variable maxOverlap (between 0 and 1) which is the ratio Intersection over Union (IoU) area for a given pair of overlaping BBoxes |
| 9 | +
|
| 10 | +
|
| 11 | +@author: Laurent Thomas |
| 12 | +""" |
| 13 | +from __future__ import division, print_function # for compatibility with Py2 |
| 14 | + |
| 15 | +def Point_in_Rectangle(Point, Rectangle): |
| 16 | + '''Return True if a point (x,y) is contained in a Rectangle(x, y, width, height)''' |
| 17 | + # unpack variables |
| 18 | + Px, Py = Point |
| 19 | + Rx, Ry, w, h = Rectangle |
| 20 | + |
| 21 | + return (Rx <= Px) and (Px <= Rx + w -1) and (Ry <= Py) and (Py <= Ry + h -1) # simply test if x_Point is in the range of x for the rectangle |
| 22 | + |
| 23 | + |
| 24 | +def computeIoU(BBox1,BBox2): |
| 25 | + ''' |
| 26 | + Compute the IoU (Intersection over Union) between 2 rectangular bounding boxes defined by the top left (Xtop,Ytop) and bottom right (Xbot, Ybot) pixel coordinates |
| 27 | + Code adapted from https://www.pyimagesearch.com/2016/11/07/intersection-over-union-iou-for-object-detection/ |
| 28 | + ''' |
| 29 | + #print('BBox1 : ', BBox1) |
| 30 | + #print('BBox2 : ', BBox2) |
| 31 | + |
| 32 | + # Unpack input (python3 - tuple are no more supported as input in function definition - PEP3113 - Tuple can be used in as argument in a call but the function will not unpack it automatically) |
| 33 | + Xleft1, Ytop1, Width1, Height1 = BBox1 |
| 34 | + Xleft2, Ytop2, Width2, Height2 = BBox2 |
| 35 | + |
| 36 | + # Compute bottom coordinates |
| 37 | + Xright1 = Xleft1 + Width1 -1 # we remove -1 from the width since we start with 1 pixel already (the top one) |
| 38 | + Ybot1 = Ytop1 + Height1 -1 # idem for the height |
| 39 | + |
| 40 | + Xright2 = Xleft2 + Width2 -1 |
| 41 | + Ybot2 = Ytop2 + Height2 -1 |
| 42 | + |
| 43 | + # determine the (x, y)-coordinates of the top left and bottom right points of the intersection rectangle |
| 44 | + Xleft = max(Xleft1, Xleft2) |
| 45 | + Ytop = max(Ytop1, Ytop2) |
| 46 | + Xright = min(Xright1, Xright2) |
| 47 | + Ybot = min(Ybot1, Ybot2) |
| 48 | + |
| 49 | + # Compute boolean for inclusion |
| 50 | + BBox1_in_BBox2 = Point_in_Rectangle((Xleft1, Ytop1), BBox2) and Point_in_Rectangle((Xleft1, Ybot1), BBox2) and Point_in_Rectangle((Xright1, Ytop1), BBox2) and Point_in_Rectangle((Xright1, Ybot1), BBox2) |
| 51 | + BBox2_in_BBox1 = Point_in_Rectangle((Xleft2, Ytop2), BBox1) and Point_in_Rectangle((Xleft2, Ybot2), BBox1) and Point_in_Rectangle((Xright2, Ytop2), BBox1) and Point_in_Rectangle((Xright2, Ybot2), BBox1) |
| 52 | + |
| 53 | + # Check that for the intersection box, Xtop,Ytop is indeed on the top left of Xbot,Ybot |
| 54 | + if BBox1_in_BBox2 or BBox2_in_BBox1: |
| 55 | + #print('One BBox is included within the other') |
| 56 | + IoU = 1 |
| 57 | + |
| 58 | + elif Xright<Xleft or Ybot<Ytop : # it means that there is no intersection (bbox is inverted) |
| 59 | + #print('No overlap') |
| 60 | + IoU = 0 |
| 61 | + |
| 62 | + else: |
| 63 | + # Compute area of the intersecting box |
| 64 | + Inter = (Xright - Xleft + 1) * (Ybot - Ytop + 1) # +1 since we are dealing with pixels. See a 1D example with 3 pixels for instance |
| 65 | + #print('Intersection area : ', Inter) |
| 66 | + |
| 67 | + # Compute area of the union as Sum of the 2 BBox area - Intersection |
| 68 | + Union = Width1 * Height1 + Width2 * Height2 - Inter |
| 69 | + #print('Union : ', Union) |
| 70 | + |
| 71 | + # Compute Intersection over union |
| 72 | + IoU = Inter/Union |
| 73 | + |
| 74 | + #print('IoU : ',IoU) |
| 75 | + return IoU |
| 76 | + |
| 77 | + |
| 78 | + |
| 79 | +def NMS(List_Hit, scoreThreshold=None, sortDescending=True, N_object=float("inf"), maxOverlap=0.5): |
| 80 | + ''' |
| 81 | + Perform Non-Maxima supression : it compares the hits after maxima/minima detection, and removes the ones that are too close (too large overlap) |
| 82 | + This function works both with an optionnal threshold on the score, and number of detected bbox |
| 83 | +
|
| 84 | + if a scoreThreshold is specified, we first discard any hit below/above the threshold (depending on sortDescending) |
| 85 | + if sortDescending = True, the hit with score below the treshold are discarded (ie when high score means better prediction ex : Correlation) |
| 86 | + if sortDescending = False, the hit with score above the threshold are discared (ie when low score means better prediction ex : Distance measure) |
| 87 | +
|
| 88 | + Then the hit are ordered so that we have the best hits first. |
| 89 | + Then we iterate over the list of hits, taking one hit at a time and checking for overlap with the previous validated hit (the Final Hit list is directly iniitialised with the first best hit as there is no better hit with which to compare overlap) |
| 90 | + |
| 91 | + This iteration is terminate once we have collected N best hit, or if there are no more hit left to test for overlap |
| 92 | + |
| 93 | + INPUT |
| 94 | + - ListHit : a list of dictionnary, with each dictionnary being a hit following the formating {'TemplateIdx'= (int),'BBox'=(x,y,width,height),'Score'=(float)} |
| 95 | + the TemplateIdx is the row index in the panda/Knime table |
| 96 | + |
| 97 | + - scoreThreshold : Float (or None), used to remove hit with too low prediction score. |
| 98 | + If sortDescending=True (ie we use a correlation measure so we want to keep large scores) the scores above that threshold are kept |
| 99 | + While if we use sortDescending=False (we use a difference measure ie we want to keep low score), the scores below that threshold are kept |
| 100 | + |
| 101 | + - N_object : number of best hit to return (by increasing score). Min=1, eventhough it does not really make sense to do NMS with only 1 hit |
| 102 | + - maxOverlap : float between 0 and 1, the maximal overlap authorised between 2 bounding boxes, above this value, the bounding box of lower score is deleted |
| 103 | + - sortDescending : use True when high score means better prediction, False otherwise (ex : if score is a difference measure, then the best prediction are low difference and we sort by ascending order) |
| 104 | +
|
| 105 | + OUTPUT |
| 106 | + List_nHit : List of the best detection after NMS, it contains max N detection (but potentially less) |
| 107 | + ''' |
| 108 | + |
| 109 | + # Apply threshold on prediction score |
| 110 | + if scoreThreshold==None : |
| 111 | + List_ThreshHit = List_Hit[:] # copy to avoid modifying the input list in place |
| 112 | + |
| 113 | + elif sortDescending : # We keep hit above the threshold |
| 114 | + List_ThreshHit = [dico for dico in List_Hit if dico['Score']>=scoreThreshold] |
| 115 | + |
| 116 | + elif not sortDescending : # We keep hit below the threshold |
| 117 | + List_ThreshHit = [dico for dico in List_Hit if dico['Score']<=scoreThreshold] |
| 118 | + |
| 119 | + |
| 120 | + # Sort score to have best predictions first (important as we loop testing the best boxes against the other boxes) |
| 121 | + if sortDescending: |
| 122 | + List_ThreshHit.sort(key=lambda dico: dico['Score'], reverse=True) # Hit = [list of (x,y),score] - sort according to descending (best = high correlation) |
| 123 | + else: |
| 124 | + List_ThreshHit.sort(key=lambda dico: dico['Score']) # sort according to ascending score (best = small difference) |
| 125 | + |
| 126 | + |
| 127 | + # Split the inital pool into Final Hit that are kept and restHit that can be tested |
| 128 | + # Initialisation : 1st keep is kept for sure, restHit is the rest of the list |
| 129 | + #print("\nInitialise final hit list with first best hit") |
| 130 | + FinalHit = [List_ThreshHit[0]] |
| 131 | + restHit = List_ThreshHit[1:] |
| 132 | + |
| 133 | + |
| 134 | + # Loop to compute overlap |
| 135 | + while len(FinalHit)<N_object and restHit : # second condition is restHit is not empty |
| 136 | + |
| 137 | + # Report state of the loop |
| 138 | + #print("\n\n\nNext while iteration") |
| 139 | + |
| 140 | + #print("-> Final hit list") |
| 141 | + #for hit in FinalHit: print(hit) |
| 142 | + |
| 143 | + #print("\n-> Remaining hit list") |
| 144 | + #for hit in restHit: print(hit) |
| 145 | + |
| 146 | + # pick the next best peak in the rest of peak |
| 147 | + test_hit = restHit[0] |
| 148 | + test_bbox = test_hit['BBox'] |
| 149 | + #print("\nTest BBox:{} for overlap against higher score bboxes".format(test_bbox)) |
| 150 | + |
| 151 | + # Loop over hit in FinalHit to compute successively overlap with test_peak |
| 152 | + for hit in FinalHit: |
| 153 | + |
| 154 | + # Recover Bbox from hit |
| 155 | + bbox2 = hit['BBox'] |
| 156 | + |
| 157 | + # Compute the Intersection over Union between test_peak and current peak |
| 158 | + IoU = computeIoU(test_bbox, bbox2) |
| 159 | + |
| 160 | + # Initialise the boolean value to true before test of overlap |
| 161 | + ToAppend = True |
| 162 | + |
| 163 | + if IoU>maxOverlap: |
| 164 | + ToAppend = False |
| 165 | + #print("IoU above threshold\n") |
| 166 | + break # no need to test overlap with the other peaks |
| 167 | + |
| 168 | + else: |
| 169 | + #print("IoU below threshold\n") |
| 170 | + # no overlap for this particular (test_peak,peak) pair, keep looping to test the other (test_peak,peak) |
| 171 | + continue |
| 172 | + |
| 173 | + |
| 174 | + # After testing against all peaks (for loop is over), append or not the peak to final |
| 175 | + if ToAppend: |
| 176 | + # Move the test_hit from restHit to FinalHit |
| 177 | + #print("Append {} to list of final hits, remove it from Remaining hit list".format(test_hit)) |
| 178 | + FinalHit.append(test_hit) |
| 179 | + restHit.remove(test_hit) |
| 180 | + |
| 181 | + else: |
| 182 | + # only remove the test_peak from restHit |
| 183 | + #print("Remove {} from Remaining hit list".format(test_hit)) |
| 184 | + restHit.remove(test_hit) |
| 185 | + |
| 186 | + |
| 187 | + # Once function execution is done, return list of hit without overlap |
| 188 | + #print("\nCollected N expected hit, or no hit left to test") |
| 189 | + #print("NMS over\n") |
| 190 | + return FinalHit |
| 191 | + |
| 192 | + |
| 193 | +if __name__ == "__main__": |
| 194 | + Hit1 = {'TemplateIdx':1,'BBox':(780, 350, 700, 480), 'Score':0.8} |
| 195 | + Hit2 = {'TemplateIdx':1,'BBox':(806, 416, 716, 442), 'Score':0.6} |
| 196 | + Hit3 = {'TemplateIdx':1,'BBox':(1074, 530, 680, 390), 'Score':0.4} |
| 197 | + |
| 198 | + ListHit = [Hit1, Hit2, Hit3] |
| 199 | + |
| 200 | + ListFinalHit = NMS(ListHit) |
| 201 | + |
| 202 | + print(ListFinalHit) |
0 commit comments