diff --git a/README.md b/README.md index 9552bb8..09ab798 100644 --- a/README.md +++ b/README.md @@ -31,8 +31,10 @@ yarn add sort-node@npm:@techainer1t/sort-node The `sort-node` package contain the object `SortNode` that can be use to track object detected from a single video or camera. -The `SortNode` object can be initialize with 2 arguments: -- `kMinHits`: (int) Minimum number of hits before a bounding box was assigned a new track ID +The `SortNode` object can be initialize with 4 arguments in the following order: +- `kMinHits`: (int) Minimum number of hits before a bounding box was assigned a new track ID (should be 3) +- `kMaxAge`: (int) Maximum number of frames to keep alive a track without associated detections +- `kIoUThreshold`: (float between 0 and 1) Minimum IOU for match (should be 0.3) - `kMinConfidence`: (float between 0 and 1) Bouding boxes with confidence score less than this value will be ignored With each frame, you will need to call `update` method. @@ -53,8 +55,10 @@ Please noted that the number of returned object might not be the same as the num ```javascript const sortnode = require("@techainer1t/sort-node"); const kMinHits = 3; +const kMaxAge = 1; +const kIoUThreshold = 0.3; const kMinConfidence = 0.3; -const tracker = sortnode.SortNode(kMinHists, kMinConfidence); +const tracker = sortnode.SortNode(kMinHits, kMaxAge, kIoUThreshold, kMinConfidence); while (true){ // Call the object detector ... diff --git a/include/tracker.h b/include/tracker.h index a266366..9f2726a 100644 --- a/include/tracker.h +++ b/include/tracker.h @@ -33,7 +33,7 @@ class Tracker { std::vector>> &unmatched_det, float iou_threshold = 0.3); - void Run(const std::vector>> &detections); + void Run(const std::vector>> &detections, int kMaxAge, float kIoUThreshold); std::map GetTracks(); diff --git a/package.json b/package.json index cbef783..f76ead4 100644 --- a/package.json +++ b/package.json @@ -18,7 +18,7 @@ }, "name": "@techainer1t/sort-node", "description": "Node binding of SORT: Simple, online, and real-time tracking of multiple objects in a video sequence.", - "version": "1.1.0", + "version": "1.1.1", "directories": { "doc": "docs" }, diff --git a/src/sort_node.cc b/src/sort_node.cc index df3acb5..ef7e0b7 100644 --- a/src/sort_node.cc +++ b/src/sort_node.cc @@ -22,9 +22,9 @@ namespace sortnode Napi::Env env = info.Env(); Napi::HandleScope scope(env); - if (info.Length() < 2 || info.Length() > 2) + if (info.Length() < 4 || info.Length() > 4) { - Napi::TypeError::New(env, "SortTracker constructor received wrong number of arguments: kMinHits, kMinConfidence") + Napi::TypeError::New(env, "SortTracker constructor received wrong number of arguments, expect: kMinHits, kMaxAge, kIoUThreshold, kMinConfidence") .ThrowAsJavaScriptException(); return; } @@ -37,14 +37,36 @@ namespace sortnode } auto kMinHits = info[0].As().DoubleValue(); - if (fmod(kMinHits, 1) != 0) + if (fmod(kMinHits, 1) != 0 || kMinHits < 0) { - Napi::TypeError::New(env, "kMinHits must be an interger") + Napi::TypeError::New(env, "kMinHits must be an interger greater than 0") + .ThrowAsJavaScriptException(); + return; + } + + if (!info[1].IsNumber()) + { + Napi::TypeError::New(env, "kMaxAge must be an interger") + .ThrowAsJavaScriptException(); + return; + } + + auto kMaxAge = info[1].As().DoubleValue(); + if (fmod(kMaxAge, 1) != 0 || kMaxAge < 0) + { + Napi::TypeError::New(env, "kMaxAge must be an interger greater than 0") + .ThrowAsJavaScriptException(); + return; + } + + if (!info[2].IsNumber() || info[2].As().DoubleValue() > 1 || info[2].As().DoubleValue() < 0) + { + Napi::TypeError::New(env, "kIoUThreshold must be a float between 0 and 1") .ThrowAsJavaScriptException(); return; } - if (!info[1].IsNumber() || info[1].As().DoubleValue() > 1 || info[1].As().DoubleValue() < 0) + if (!info[3].IsNumber() || info[3].As().DoubleValue() > 1 || info[3].As().DoubleValue() < 0) { Napi::TypeError::New(env, "kMinConfidence must be a float between 0 and 1") .ThrowAsJavaScriptException(); @@ -52,7 +74,9 @@ namespace sortnode } this->kMinHits = int(kMinHits); - this->kMinConfidence = float(info[1].As().DoubleValue()); + this->kMaxAge = int(kMaxAge); + this->kIoUThreshold = float(info[2].As().DoubleValue()); + this->kMinConfidence = float(info[3].As().DoubleValue()); } Napi::Value SortNode::update(const Napi::CallbackInfo& info) @@ -136,7 +160,7 @@ namespace sortnode } // Run SORT tracker - this->tracker.Run(bbox_per_frame); + this->tracker.Run(bbox_per_frame, this->kMaxAge, this->kIoUThreshold); const auto tracks = this->tracker.GetTracks(); // Convert results from cv::Rect to normal float vector diff --git a/src/sort_node.h b/src/sort_node.h index 419e1dd..d17faff 100644 --- a/src/sort_node.h +++ b/src/sort_node.h @@ -13,7 +13,9 @@ namespace sortnode { public: int kMinHits = 3; + int kMaxAge = 1; int kMaxCoastCycles = 1; + float kIoUThreshold = 0.3; float kMinConfidence = 0.6; int frame_index = 0; Tracker tracker; diff --git a/src/tracker.cpp b/src/tracker.cpp index a666c32..dae7bc9 100644 --- a/src/tracker.cpp +++ b/src/tracker.cpp @@ -135,7 +135,7 @@ void Tracker::AssociateDetectionsToTrackers(const std::vector>>& detections) { +void Tracker::Run(const std::vector>>& detections, int kMaxAge, float kIoUThreshold) { /*** Predict internal tracks from previous frame ***/ for (auto &track : tracks_) { @@ -149,7 +149,7 @@ void Tracker::Run(const std::vector>>& de // return values - matched, unmatched_det if (!detections.empty()) { - AssociateDetectionsToTrackers(detections, tracks_, matched, unmatched_det); + AssociateDetectionsToTrackers(detections, tracks_, matched, unmatched_det, kIoUThreshold); } /*** Update tracks with associated bbox ***/ @@ -168,7 +168,7 @@ void Tracker::Run(const std::vector>>& de /*** Delete lose tracked tracks ***/ for (auto it = tracks_.begin(); it != tracks_.end();) { - if (it->second.coast_cycles_ > kMaxCoastCycles) { + if (it->second.coast_cycles_ > kMaxAge) { it = tracks_.erase(it); } else { it++; diff --git a/test/test_binding.js b/test/test_binding.js index b03f055..bb677f9 100644 --- a/test/test_binding.js +++ b/test/test_binding.js @@ -7,8 +7,10 @@ assert(sortnode.SortNode, "The expected module is undefined"); function testBasic() { console.log("Running testBasic"); const kMinHits = 3; + const kMaxAge = 1; const kMinConfidence = 0.3; - const instance = new sortnode.SortNode(kMinHits, kMinConfidence); + const kIoUThreshold = 0.3; + const instance = new sortnode.SortNode(kMinHits, kMaxAge, kIoUThreshold, kMinConfidence); assert(instance.update, "The expected method is not defined"); } @@ -63,7 +65,7 @@ function testAccuracyWithoutLandmark() { const total_frames = all_detections.length; - const tracker = new sortnode.SortNode(3, 0.6); + const tracker = new sortnode.SortNode(3, 1, 0.3, 0.6); let frame_index = 0 let predicted = []; const t1 = Date.now() @@ -98,7 +100,7 @@ function testAccuracyWithoutLandmark() { function testKeepLandmark(){ console.log("Running testKeepLandmark") - const tracker = new sortnode.SortNode(3, 0.3); + const tracker = new sortnode.SortNode(3, 1, 0.3, 0); let input = [ [120, 240, 50, 70, 0.9, 23, 24, 25, 26, 27, 28, 29, 30],