Skip to content

Commit

Permalink
Improved SORT algorithm and fixed bug when cutting clips
Browse files Browse the repository at this point in the history
  • Loading branch information
BrennoCaldato committed Jul 29, 2020
1 parent a8d877c commit 58d2e8f
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 57 deletions.
119 changes: 62 additions & 57 deletions src/CVObjectDetection.cpp
Expand Up @@ -166,70 +166,74 @@ void CVObjectDetection::postprocess(const cv::Size &frameDims, const std::vector
std::vector<int> indices;
cv::dnn::NMSBoxes(boxes, confidences, confThreshold, nmsThreshold, indices);

// Pass boxes to SORT algorithm
std::vector<cv::Rect> sortBoxes;
for(auto box : boxes)
sortBoxes.push_back(box);
sort.update(sortBoxes, frameId, sqrt(pow(frameDims.width,2) + pow(frameDims.height, 2)), confidences, classIds);


sortBoxes.clear(); boxes.clear(); confidences.clear(); classIds.clear();

// Clear data vectors
boxes.clear(); confidences.clear(); classIds.clear();
// Get SORT predicted boxes
for(auto TBox : sort.frameTrackingResult){
if(TBox.frame == frameId){
boxes.push_back(TBox.box);
confidences.push_back(TBox.confidence);
classIds.push_back(TBox.classId);
}
}

// for(int i = 0; i<boxes.size(); i++){
// bool found = false;
// for(int j = 0; j<sortBoxes.size(); j++){
// if( iou(boxes[i], sortBoxes[j]) ){
// boxes[i] = sortBoxes[j];
// sortBoxes.erase(sortBoxes.begin() + j);
// found = true;
// break;
// }
// }
// if(!found){
// boxes.erase(boxes.begin() + i);
// confidences.erase(confidences.begin() + i);
// classIds.erase(classIds.begin() + i);
// }
// }



// std::map<int, std::vector<cv::Rect> > rectAndClasses;
// for(int i=0; i<boxes.size(); i++){
// if(rectAndClasses.find(classIds[i]) == rectAndClasses.end()){
// std::vector<cv::Rect> bboxes;
// rectAndClasses[classIds[i]] = bboxes;
// }

// rectAndClasses[classIds[i]].push_back(boxes[i]);
// }

// for(std::map<int, std::vector<cv::Rect> >::iterator it = rectAndClasses.begin(); it != rectAndClasses.end(); it++){
// if(sort.find(it->first) == sort.end()){
// SortTracker classTracker;
// sort[it->first] = classTracker;
// }
// sort[it->first].update(it->second, frameId, sqrt(pow(frameDims.width,2) + pow(frameDims.height, 2)));
// }

// classIds.clear(); boxes.clear(); confidences.clear();

// for(std::map<int, SortTracker>::iterator it = sort.begin(); it != sort.end(); it++){
// for(auto TBox : it->second.frameTrackingResult){
// boxes.push_back(TBox.box);
// classIds.push_back(it->first);
// confidences.push_back(1);
// }
// }
// Remove boxes based on controids distance
for(uint i = 0; i<boxes.size(); i++){
for(uint j = i+1; j<boxes.size(); j++){
int xc_1 = boxes[i].x + (int)(boxes[i].width/2), yc_1 = boxes[i].y + (int)(boxes[i].width/2);
int xc_2 = boxes[j].x + (int)(boxes[j].width/2), yc_2 = boxes[j].y + (int)(boxes[j].width/2);

if(fabs(xc_1 - xc_2) < 10 && fabs(yc_1 - yc_2) < 10){
if(classIds[i] == classIds[j]){
if(confidences[i] >= confidences[j]){
boxes.erase(boxes.begin() + j);
classIds.erase(classIds.begin() + j);
confidences.erase(confidences.begin() + j);
break;
}
else{
boxes.erase(boxes.begin() + i);
classIds.erase(classIds.begin() + i);
confidences.erase(confidences.begin() + i);
i = 0;
break;
}
}
}
}
}

// Remove boxes based in IOU score
for(uint i = 0; i<boxes.size(); i++){
for(uint j = i+1; j<boxes.size(); j++){

if( iou(boxes[i], boxes[j])){
if(classIds[i] == classIds[j]){
if(confidences[i] >= confidences[j]){
boxes.erase(boxes.begin() + j);
classIds.erase(classIds.begin() + j);
confidences.erase(confidences.begin() + j);
break;
}
else{
boxes.erase(boxes.begin() + i);
classIds.erase(classIds.begin() + i);
confidences.erase(confidences.begin() + i);
i = 0;
break;
}
}
}
}
}

// Normalize boxes coordinates
std::vector<cv::Rect_<float>> normalized_boxes;
for(auto box : boxes){
cv::Rect_<float> normalized_box;
Expand All @@ -243,25 +247,26 @@ void CVObjectDetection::postprocess(const cv::Size &frameDims, const std::vector
detectionsData[frameId] = CVDetectionData(classIds, confidences, normalized_boxes, frameId);
}

// Compute IOU between 2 boxes
bool CVObjectDetection::iou(cv::Rect pred_box, cv::Rect sort_box){
// determine the (x, y)-coordinates of the intersection rectangle
// Determine the (x, y)-coordinates of the intersection rectangle
int xA = std::max(pred_box.x, sort_box.x);
int yA = std::max(pred_box.y, sort_box.y);
int xB = std::min(pred_box.x + pred_box.width, sort_box.x + sort_box.width);
int yB = std::min(pred_box.y + pred_box.height, sort_box.y + sort_box.height);

// compute the area of intersection rectangle
// Compute the area of intersection rectangle
int interArea = std::max(0, xB - xA + 1) * std::max(0, yB - yA + 1);
// compute the area of both the prediction and ground-truth
// rectangles

// Compute the area of both the prediction and ground-truth rectangles
int boxAArea = (pred_box.width + 1) * (pred_box.height + 1);
int boxBArea = (sort_box.width + 1) * (sort_box.height + 1);
// compute the intersection over union by taking the intersection
// area and dividing it by the sum of prediction + ground-truth
// areas - the interesection area

// Compute the intersection over union by taking the intersection
float iou = interArea / (float)(boxAArea + boxBArea - interArea);

if(iou > 0.75)
// If IOU is above this value the boxes are very close (probably a variation of the same bounding box)
if(iou > 0.5)
return true;
return false;
}
Expand Down
3 changes: 3 additions & 0 deletions src/effects/ObjectDetection.cpp
Expand Up @@ -174,6 +174,9 @@ bool ObjectDetection::LoadObjDetectdData(std::string inputFilePath){
classNames.clear();
detectionsData.clear();

// Seed to generate same random numbers
std::srand(1);
// Get all classes names and assign a color to them
for(int i = 0; i < objMessage.classnames_size(); i++){
classNames.push_back(objMessage.classnames(i));
classesColor.push_back(cv::Scalar(std::rand()%205 + 50, std::rand()%205 + 50, std::rand()%205 + 50));
Expand Down

0 comments on commit 58d2e8f

Please sign in to comment.