In [1]:
from torch.utils.data import Dataset

In [46]:
path_file = "../data/keypoints/as_arrays/kth_dataset/running//person01_running_d4_uncomp_out.npz"

In [127]:
import keypoints_io
keypoints,scores = keypoints_io.load(path_file)

In [210]:
import numpy as np
class SimpleKeyPointAssembler:
    """
        a class based on simple heuristics that assembles the links two sets of keypoints in consecutive frames,
        if they have a score that is sufficiently high and are close to each other.
        
        It is applied sequentially, starting from the keypoint set with the highest score, for the next frame,
        we keep only the set of keypoints that are above a certain threshold ,
        among these keypoints, if more than one is left, we keep the closest set of keypoints, in terms of pixel's
        distance, if no "detection" is proposed, we apply the same routine to the next frame, until all the frames
        are exhausted.
        
        
        
        Remark/disclaimer : this heuristic would not generalize well in the case of multiple person per video,and
        is out of the scope of this github experiment.
        
    """
    def __init__(self,keypoints,scores,min_score):
        self.keypoints = keypoints
        self.scores = scores
        self.min_score = min_score
        
    def filter_low_scores(self):

        roi = [[score>self.min_score for score in scores_on_frame ] 
               for scores_on_frame in self.scores]
        
        keypoints_filtered = [[keypoint for keypoint,keep_or_not in zip(keypoints_on_frame,roi_on_frame) 
                if keep_or_not] 
               for (keypoints_on_frame,roi_on_frame) in zip(self.keypoints,roi)]
        
        _score_filtered = [[score for score,keep_or_not in zip(scores_on_frame,roi_on_frame) 
                if keep_or_not] 
               for (scores_on_frame,roi_on_frame) in zip(self.scores,roi)]
        
        res = keypoints_filtered 
        
        return res
        
        
    @staticmethod
    def compute_distances(keypoints0,keypoints1):
        assert keypoints0.shape[1:] == keypoints1.shape[1:]
        assert len(keypoints0.shape) == len(keypoints1.shape) == 3
        arr1 = keypoints0.reshape(-1,np.product(keypoints0.shape[1:]))
        arr2 = keypoints1.reshape(-1,np.product(keypoints1.shape[1:]))
        res = np.linalg.norm(arr1[...,np.newaxis] - arr2.transpose(),axis=1,ord=np.inf)
        return res

    def assemble(self,sequence_keypoints,keypoints_candidate):
        distances = self.compute_distances(sequence_keypoints[-1:],keypoints_candidate)
        next_position = distances.flatten().argmin()
        sequence_keypoints = np.concatenate([sequence_keypoints,keypoints_candidate[next_position][np.newaxis]])
        return sequence_keypoints

    def get_keypoints(self,keypoints):
        best_candidate = keypoints[0][scores[0].argmax()]
        sequence_keypoints = best_candidate[np.newaxis]
        for kpt in keypoints[1:]:
            sequence_keypoints = self.assemble(sequence_keypoints,kpt)
        return sequence_keypoints

In [218]:
ff = SimpleKeyPointAssembler(keypoints,scores,0.99)
res = ff.filter_low_scores()
np.unique([len(el) for el in res])

array([0, 1])

In [203]:
kept_detections_rois = [[score>ff.min_score for score in scores_frame]  for scores_frame in ff.scores]
[[ kpt for ktp in kpts_frame if ] for kpts_frame in ff.keypoints]

[array([], dtype=float32),
 array([], dtype=float32),
 array([], dtype=float32),
 array([], dtype=float32),
 array([], dtype=float32),
 array([], dtype=float32),
 array([0.10790987, 0.06580681], dtype=float32),
 array([0.96272355, 0.2079537 , 0.10011367, 0.09631367], dtype=float32),
 array([0.990987  , 0.11845417], dtype=float32),
 array([0.99565375, 0.07204951], dtype=float32),
 array([0.9960479 , 0.11836872], dtype=float32),
 array([0.9930882 , 0.30924284], dtype=float32),
 array([0.99932814, 0.17351978], dtype=float32),
 array([0.9990165, 0.5594179], dtype=float32),
 array([0.9987948 , 0.44157547], dtype=float32),
 array([0.99892706, 0.39600095], dtype=float32),
 array([0.996403  , 0.37111878], dtype=float32),
 array([0.97974193, 0.8397285 , 0.0765117 ], dtype=float32),
 array([0.996075  , 0.24648096], dtype=float32),
 array([0.99892706, 0.5496587 ], dtype=float32),
 array([0.9992518 , 0.46600017], dtype=float32),
 array([0.9975501 , 0.3414779 , 0.08657485], dtype=float32),
 array([

In [176]:
res[8][0]

(array([[14, 25],
        [14, 23],
        [13, 23],
        [ 8, 23],
        [ 9, 23],
        [ 7, 33],
        [ 5, 33],
        [ 2, 46],
        [ 2, 46],
        [ 6, 56],
        [ 6, 56],
        [ 3, 58],
        [ 1, 59],
        [ 3, 83],
        [ 3, 83],
        [ 0, 96],
        [ 6, 56]], dtype=int32),
 0.990987)

In [94]:
class Found(Exception): pass

self = ff
try :
    for idx,(keypoints,scores) in enumerate(zip(self.keypoints,self.scores)):
        for keypoint,score in zip(keypoints,scores):
            print(score)
            if score>self.min_score:
                raise Found
except Found:
    pass

0.10790987
0.06580681
0.96272355
0.2079537
0.100113675
0.09631367
0.990987


In [77]:
ff.scores

array([], dtype=float32)

In [71]:
idx,self.min_score

(529, 0.99)

In [54]:
len(ff.idxs)

137

In [55]:
scores

[array([], dtype=float32),
 array([], dtype=float32),
 array([], dtype=float32),
 array([], dtype=float32),
 array([], dtype=float32),
 array([], dtype=float32),
 array([0.10790987, 0.06580681], dtype=float32),
 array([0.96272355, 0.2079537 , 0.10011367, 0.09631367], dtype=float32),
 array([0.990987  , 0.11845417], dtype=float32),
 array([0.99565375, 0.07204951], dtype=float32),
 array([0.9960479 , 0.11836872], dtype=float32),
 array([0.9930882 , 0.30924284], dtype=float32),
 array([0.99932814, 0.17351978], dtype=float32),
 array([0.9990165, 0.5594179], dtype=float32),
 array([0.9987948 , 0.44157547], dtype=float32),
 array([0.99892706, 0.39600095], dtype=float32),
 array([0.996403  , 0.37111878], dtype=float32),
 array([0.97974193, 0.8397285 , 0.0765117 ], dtype=float32),
 array([0.996075  , 0.24648096], dtype=float32),
 array([0.99892706, 0.5496587 ], dtype=float32),
 array([0.9992518 , 0.46600017], dtype=float32),
 array([0.9975501 , 0.3414779 , 0.08657485], dtype=float32),
 array([