In [1]:
import unittest
from typing import List, Optional
from dataclasses import dataclass



In [2]:
from openptv_python.tracking_frame_buf import n_tupel

In [3]:
def run_test(tcls):
    """
    Runs unit tests from a test class
    :param tcls: A class, derived from unittest.TestCase
    """
    suite = unittest.TestLoader().loadTestsFromTestCase(tcls)
    runner = unittest.TextTestRunner(verbosity=2)
    runner.run(suite)

In [4]:

def take_best_candidates(src: List[n_tupel], dst: List[n_tupel], num_cams: int, num_cands: int, tusage: List[List[int]]) -> int:
    taken: int = 0

    # Sort candidates by match quality (.corr)
    src.sort(key=lambda x: x.corr)

    # Take quadruplets from the top to the bottom of the sorted list
    # only if none of the points has already been used
    for cand in range(num_cands):
        has_used_target: bool = False
        for cam in range(num_cams):
            tnum: int = src[cand]['p'][cam]

            # If any correspondence in this camera, check that target is free
            if tnum > -1 and tusage[cam][tnum] > 0:
                has_used_target = True
                break

        if has_used_target:
            continue

        # Only now can we commit to marking used targets.
        for cam in range(num_cams):
            tnum = src[cand]['p'][cam]
            if tnum > -1:
                tusage[cam][tnum] += 1
        dst[taken] = src[cand]
        taken += 1

    return taken



In [5]:
@run_test
class TestTakeBestCandidates(unittest.TestCase):
    def test_quadruplets(self):
        src = [
            n_tupel(p=[0, 1, 2, 3], corr=0.9),
            n_tupel(p=[1, 2, 3, 4], corr=0.8),
            n_tupel(p=[2, 3, 4, 5], corr=0.7),
            n_tupel(p=[3, 4, 5, 6], corr=0.6)
        ]
        dst = [None] * len(src)
        num_cams = 4
        num_cands = len(src)
        tusage = [[0] * 10 for _ in range(num_cams)]

        taken = take_best_candidates(src, dst, num_cams, num_cands, tusage)

        self.assertEqual(taken, len(src))
        self.assertEqual(dst, src)

    def test_triplets(self):
        src = [
            n_tupel(p=[0, 1, 2, -1], corr=0.9),
            n_tupel(p=[1, 2, 3, -1], corr=0.8),
            n_tupel(p=[2, 3, 4, -1], corr=0.7),
            n_tupel(p=[3, 4, 5, -1], corr=0.6)
        ]
        dst = [None] * len(src)
        num_cams = 4
        num_cands = len(src)
        tusage = [[0] * 10 for _ in range(num_cams)]

        taken = take_best_candidates(src, dst, num_cams, num_cands, tusage)

        self.assertEqual(taken, len(src))
        self.assertEqual(dst[:taken], src[:taken])

    def test_pairs(self):
        src = [
            n_tupel(p=[0, 1, -1, -1], corr=0.9),
            n_tupel(p=[1, 2, -1, -1], corr=0.8),
            n_tupel(p=[2, 3, -1, -1], corr=0.7),
            n_tupel(p=[3, 4, -1, -1], corr=0.6)
        ]
        dst = [None] * len(src)
        num_cams = 4
        num_cands = len(src)
        tusage = [[0] * 10 for _ in range(num_cams)]

        taken = take_best_candidates(src, dst, num_cams, num_cands, tusage)

        self.assertEqual(taken, len(src))
        print(dst[:taken])
        print(src[:taken])
        self.assertEqual(dst[:taken], src[:taken])

    def test_no_candidates(self):
        src = []
        dst = [None] * 10  # Just to test that dst is not modified
        num_cams = 4
        num_cands = 0
        tusage = [[0] * 10 for _ in range(num_cams)]

        taken = take_best_candidates(src, dst, num_cams, num_cands, tusage)

        self.assertEqual(taken, 0)
        self.assertEqual(dst, [None] * 10)


# if __name__ == '__main__':
# 


test_no_candidates (__main__.TestTakeBestCandidates) ... ok
test_pairs (__main__.TestTakeBestCandidates) ... ok
test_quadruplets (__main__.TestTakeBestCandidates) ... ok
test_triplets (__main__.TestTakeBestCandidates) ... ok

----------------------------------------------------------------------
Ran 4 tests in 0.006s

OK


[<openptv_python.tracking_frame_buf.n_tupel object at 0x7f69b4283e50>, <openptv_python.tracking_frame_buf.n_tupel object at 0x7f69b4281930>, <openptv_python.tracking_frame_buf.n_tupel object at 0x7f69b4283df0>, <openptv_python.tracking_frame_buf.n_tupel object at 0x7f69d99bee00>]
[<openptv_python.tracking_frame_buf.n_tupel object at 0x7f69b4283e50>, <openptv_python.tracking_frame_buf.n_tupel object at 0x7f69b4281930>, <openptv_python.tracking_frame_buf.n_tupel object at 0x7f69b4283df0>, <openptv_python.tracking_frame_buf.n_tupel object at 0x7f69d99bee00>]
