Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ yarn add sort-node@npm:@techainer1t/sort-node

The `sort-node` package contain the object `SortNode` that can be use to track object detected from a single video or camera.

The `SortNode` object can be initialize with 2 arguments:
- `kMinHits`: (int) Minimum number of hits before a bounding box was assigned a new track ID
The `SortNode` object can be initialize with 4 arguments in the following order:
- `kMinHits`: (int) Minimum number of hits before a bounding box was assigned a new track ID (should be 3)
- `kMaxAge`: (int) Maximum number of frames to keep alive a track without associated detections
- `kIoUThreshold`: (float between 0 and 1) Minimum IOU for match (should be 0.3)
- `kMinConfidence`: (float between 0 and 1) Bouding boxes with confidence score less than this value will be ignored

With each frame, you will need to call `update` method.
Expand All @@ -53,8 +55,10 @@ Please noted that the number of returned object might not be the same as the num
```javascript
const sortnode = require("@techainer1t/sort-node");
const kMinHits = 3;
const kMaxAge = 1;
const kIoUThreshold = 0.3;
const kMinConfidence = 0.3;
const tracker = sortnode.SortNode(kMinHists, kMinConfidence);
const tracker = sortnode.SortNode(kMinHits, kMaxAge, kIoUThreshold, kMinConfidence);
while (true){
// Call the object detector
...
Expand Down
2 changes: 1 addition & 1 deletion include/tracker.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class Tracker {
std::vector<std::pair<cv::Rect, std::vector<float>>> &unmatched_det,
float iou_threshold = 0.3);

void Run(const std::vector<std::pair<cv::Rect, std::vector<float>>> &detections);
void Run(const std::vector<std::pair<cv::Rect, std::vector<float>>> &detections, int kMaxAge, float kIoUThreshold);

std::map<int, Track> GetTracks();

Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
},
"name": "@techainer1t/sort-node",
"description": "Node binding of SORT: Simple, online, and real-time tracking of multiple objects in a video sequence.",
"version": "1.1.0",
"version": "1.1.1",
"directories": {
"doc": "docs"
},
Expand Down
38 changes: 31 additions & 7 deletions src/sort_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ namespace sortnode
Napi::Env env = info.Env();
Napi::HandleScope scope(env);

if (info.Length() < 2 || info.Length() > 2)
if (info.Length() < 4 || info.Length() > 4)
{
Napi::TypeError::New(env, "SortTracker constructor received wrong number of arguments: kMinHits, kMinConfidence")
Napi::TypeError::New(env, "SortTracker constructor received wrong number of arguments, expect: kMinHits, kMaxAge, kIoUThreshold, kMinConfidence")
.ThrowAsJavaScriptException();
return;
}
Expand All @@ -37,22 +37,46 @@ namespace sortnode
}

auto kMinHits = info[0].As<Napi::Number>().DoubleValue();
if (fmod(kMinHits, 1) != 0)
if (fmod(kMinHits, 1) != 0 || kMinHits < 0)
{
Napi::TypeError::New(env, "kMinHits must be an interger")
Napi::TypeError::New(env, "kMinHits must be an interger greater than 0")
.ThrowAsJavaScriptException();
return;
}

if (!info[1].IsNumber())
{
Napi::TypeError::New(env, "kMaxAge must be an interger")
.ThrowAsJavaScriptException();
return;
}

auto kMaxAge = info[1].As<Napi::Number>().DoubleValue();
if (fmod(kMaxAge, 1) != 0 || kMaxAge < 0)
{
Napi::TypeError::New(env, "kMaxAge must be an interger greater than 0")
.ThrowAsJavaScriptException();
return;
}

if (!info[2].IsNumber() || info[2].As<Napi::Number>().DoubleValue() > 1 || info[2].As<Napi::Number>().DoubleValue() < 0)
{
Napi::TypeError::New(env, "kIoUThreshold must be a float between 0 and 1")
.ThrowAsJavaScriptException();
return;
}

if (!info[1].IsNumber() || info[1].As<Napi::Number>().DoubleValue() > 1 || info[1].As<Napi::Number>().DoubleValue() < 0)
if (!info[3].IsNumber() || info[3].As<Napi::Number>().DoubleValue() > 1 || info[3].As<Napi::Number>().DoubleValue() < 0)
{
Napi::TypeError::New(env, "kMinConfidence must be a float between 0 and 1")
.ThrowAsJavaScriptException();
return;
}

this->kMinHits = int(kMinHits);
this->kMinConfidence = float(info[1].As<Napi::Number>().DoubleValue());
this->kMaxAge = int(kMaxAge);
this->kIoUThreshold = float(info[2].As<Napi::Number>().DoubleValue());
this->kMinConfidence = float(info[3].As<Napi::Number>().DoubleValue());
}

Napi::Value SortNode::update(const Napi::CallbackInfo& info)
Expand Down Expand Up @@ -136,7 +160,7 @@ namespace sortnode
}

// Run SORT tracker
this->tracker.Run(bbox_per_frame);
this->tracker.Run(bbox_per_frame, this->kMaxAge, this->kIoUThreshold);
const auto tracks = this->tracker.GetTracks();

// Convert results from cv::Rect to normal float vector
Expand Down
2 changes: 2 additions & 0 deletions src/sort_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ namespace sortnode
{
public:
int kMinHits = 3;
int kMaxAge = 1;
int kMaxCoastCycles = 1;
float kIoUThreshold = 0.3;
float kMinConfidence = 0.6;
int frame_index = 0;
Tracker tracker;
Expand Down
6 changes: 3 additions & 3 deletions src/tracker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ void Tracker::AssociateDetectionsToTrackers(const std::vector<std::pair<cv::Rect
}


void Tracker::Run(const std::vector<std::pair<cv::Rect, std::vector<float>>>& detections) {
void Tracker::Run(const std::vector<std::pair<cv::Rect, std::vector<float>>>& detections, int kMaxAge, float kIoUThreshold) {

/*** Predict internal tracks from previous frame ***/
for (auto &track : tracks_) {
Expand All @@ -149,7 +149,7 @@ void Tracker::Run(const std::vector<std::pair<cv::Rect, std::vector<float>>>& de

// return values - matched, unmatched_det
if (!detections.empty()) {
AssociateDetectionsToTrackers(detections, tracks_, matched, unmatched_det);
AssociateDetectionsToTrackers(detections, tracks_, matched, unmatched_det, kIoUThreshold);
}

/*** Update tracks with associated bbox ***/
Expand All @@ -168,7 +168,7 @@ void Tracker::Run(const std::vector<std::pair<cv::Rect, std::vector<float>>>& de

/*** Delete lose tracked tracks ***/
for (auto it = tracks_.begin(); it != tracks_.end();) {
if (it->second.coast_cycles_ > kMaxCoastCycles) {
if (it->second.coast_cycles_ > kMaxAge) {
it = tracks_.erase(it);
} else {
it++;
Expand Down
8 changes: 5 additions & 3 deletions test/test_binding.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ assert(sortnode.SortNode, "The expected module is undefined");
function testBasic() {
console.log("Running testBasic");
const kMinHits = 3;
const kMaxAge = 1;
const kMinConfidence = 0.3;
const instance = new sortnode.SortNode(kMinHits, kMinConfidence);
const kIoUThreshold = 0.3;
const instance = new sortnode.SortNode(kMinHits, kMaxAge, kIoUThreshold, kMinConfidence);
assert(instance.update, "The expected method is not defined");
}

Expand Down Expand Up @@ -63,7 +65,7 @@ function testAccuracyWithoutLandmark() {

const total_frames = all_detections.length;

const tracker = new sortnode.SortNode(3, 0.6);
const tracker = new sortnode.SortNode(3, 1, 0.3, 0.6);
let frame_index = 0
let predicted = [];
const t1 = Date.now()
Expand Down Expand Up @@ -98,7 +100,7 @@ function testAccuracyWithoutLandmark() {

function testKeepLandmark(){
console.log("Running testKeepLandmark")
const tracker = new sortnode.SortNode(3, 0.3);
const tracker = new sortnode.SortNode(3, 1, 0.3, 0);

let input = [
[120, 240, 50, 70, 0.9, 23, 24, 25, 26, 27, 28, 29, 30],
Expand Down