In [None]:
'''
This notebook is based on JARVIS create_videos2D.py.
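After JARVIS model prediction, this notebook overlays the 2D tracking points on each camera's recording and generates one 2D video per camera. The skeleton is drawn on a frame only when every keypoint's confidence is above the confidence threshold.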

# to-do-list
# process a list of videos (as well as 2D predictions)
# plot confidence level for individual points

JARVIS-MoCap (https://jarvis-mocap.github.io/jarvis-docs)
Copyright (c) 2022 Timo Hueser.
https://github.com/JARVIS-MoCap/JARVIS-HybridNet
Licensed under GNU Lesser General Public License v2.1
'''

In [1]:
import os
import time
import cv2
import numpy as np
from tqdm import tqdm
import yaml

import jarvis.visualization.visualization_utils as utils

In [7]:
# Run configuration for Create2DVideos. The paths are machine-specific.
params = dict(
    # Folder holding the JARVIS training projects
    project_dir=r'C:\Users\Yiting\Documents\GitHub\JARVIS-HybridNet\projects',
    # Training project whose config.yaml defines the keypoints/skeleton
    project_name='6cam_train_231130',
    # Recording session
    session='2023-11-30',
    # Recording trial within the session
    trial='2023-11-30_14-06-48_762665',
    # Camera names; one <cam>.mp4 video and one <cam>.csv prediction file each
    cameras=['camBL', 'camBo', 'camBR', 'camTL', 'camTo', 'camTR'],
    # Folder holding the recorded videos
    recording_dir=r'E:\Hand_tracking\Recordings\Videos',
    # Folder holding the 2D prediction results
    predictions2D_dir=r'E:\Hand_tracking\Predictions\Predictions_2D',
    # Folder where the annotated 2D tracking videos are written
    output_dir=r'E:\Hand_tracking\Visualization\Video_2D',
    # Index of the first video frame to render
    frame_start=0,
    # Number of frames to render; -1 means "from frame_start to the end"
    number_frames=-1,
    # A frame is annotated only if every keypoint's confidence exceeds this
    conf_threshold=0.3,
)

In [8]:
# Render the annotated 2D videos for every camera listed in `params`.
# NOTE: Create2DVideos, get_skeleton and Graph are defined in the cells BELOW
# this one (execution counts 4, 3 and 2). On a fresh kernel, run those cells
# first; otherwise this line raises NameError under Restart & Run All.
Create2DVideos(params)

In [4]:
def Create2DVideos(params):
    """Overlay 2D keypoint predictions on each camera's recording.

    For every camera in ``params['cameras']`` this reads
    ``<recording_dir>/<session>/<trial>/<cam>.mp4`` and the matching
    ``<predictions2D_dir>/<session>/<trial>/<cam>.csv``. On frames where
    *every* keypoint's confidence exceeds ``params['conf_threshold']``, it
    draws the skeleton lines and keypoints. The result is written to
    ``<output_dir>/<session>/<trial>/<cam>.mp4``.

    Args:
        params: dict with keys 'project_dir', 'project_name', 'session',
            'trial', 'cameras', 'recording_dir', 'predictions2D_dir',
            'output_dir', 'frame_start', 'number_frames' (-1 = up to the end
            of each video) and 'conf_threshold'. The dict is not modified.

    Raises:
        IOError: if a recording cannot be opened.
        AssertionError: if the requested frame range does not fit the video.
    """
    cfg_path = os.path.join(params['project_dir'], params['project_name'], 'config.yaml')
    with open(cfg_path, 'r') as yaml_file:
        cfg = yaml.safe_load(yaml_file)

    # The skeleton depends only on the project config, so build it once.
    colors, line_idxs = get_skeleton(cfg)

    output_folder = os.path.join(params['output_dir'], params['session'], params['trial'])
    os.makedirs(output_folder, exist_ok=True)

    for cam in params['cameras']:
        _render_camera(params, cam, colors, line_idxs, output_folder)


def _load_points2D(data2D_path):
    """Load a 2D prediction CSV as a float array, dropping its 2 header rows if present."""
    header_info = np.genfromtxt(data2D_path, delimiter=',', dtype=str, max_rows=2)
    points2D_all = np.genfromtxt(data2D_path, delimiter=',')
    # When the second row starts with 'x' the file has two text header rows,
    # which the float parse above turned into rows that must be discarded.
    if header_info[1, 0] == 'x':
        points2D_all = points2D_all[2:]
    return points2D_all


def _resolve_number_frames(frame_start, number_frames, frame_count):
    """Validate the requested segment and return how many frames to render."""
    assert frame_start < frame_count, "frame_start bigger than total framecount!"
    if number_frames == -1:
        return frame_count - frame_start
    assert frame_start + number_frames <= frame_count, \
        "make sure your selected segment is not longer that the " \
        "total video!"
    return number_frames


def _render_camera(params, cam, colors, line_idxs, output_folder):
    """Draw the predictions of one camera onto its video and write the result."""
    recording_path = os.path.join(params['recording_dir'], params['session'],
                                  params['trial'], cam + '.mp4')
    cap = cv2.VideoCapture(recording_path)
    out = None
    try:
        if not cap.isOpened():
            raise IOError(f"Could not open recording: {recording_path}")

        # Local copy: params['number_frames'] must apply to every camera and
        # must not be changed for the caller.
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        number_frames = _resolve_number_frames(params['frame_start'],
                                               params['number_frames'],
                                               frame_count)

        cap.set(cv2.CAP_PROP_POS_FRAMES, params['frame_start'])
        img_size = [int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                    int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))]
        frame_rate = cap.get(cv2.CAP_PROP_FPS)

        data2D_path = os.path.join(params['predictions2D_dir'], params['session'],
                                   params['trial'], cam + '.csv')
        points2D_all = _load_points2D(data2D_path)

        output_path = os.path.join(output_folder, cam + '.mp4')
        out = cv2.VideoWriter(output_path,
                              cv2.VideoWriter_fourcc('m', 'p', '4', 'v'), frame_rate,
                              (img_size[0], img_size[1]))

        for frame_num in tqdm(range(number_frames), desc=cam):
            ret, img_orig = cap.read()
            if not ret:
                # The container reported more frames than could be decoded.
                break
            # NOTE(review): prediction row `frame_num` (not frame_start + frame_num)
            # is used, i.e. the CSV is assumed to start at frame_start -- confirm.
            points2D = points2D_all[frame_num].reshape(-1, 3)
            # Draw only if every keypoint's confidence (3rd column) passes.
            if np.all(points2D[:, 2] > params['conf_threshold']):
                for line in line_idxs:
                    utils.draw_line(img_orig, line, points2D,
                                    img_size, colors[line[1]])
                for j, point in enumerate(points2D):
                    utils.draw_point(img_orig, point, img_size, colors[j])
            out.write(img_orig)
    finally:
        if out is not None:
            out.release()
        cap.release()

In [3]:
def get_skeleton(cfg):
    """Assign a drawing colour to every keypoint and collect the skeleton lines.

    Args:
        cfg: project config dict. Reads ``cfg['SKELETON']``, a list of bones
            whose first two entries are keypoint names, and
            ``cfg['KEYPOINT_NAMES']``.

    Returns:
        ``(colors, line_idxs)``: ``colors[i]`` is the colour for keypoint
        ``i``, and ``line_idxs`` holds one ``[start_idx, stop_idx]`` pair per
        bone (empty when no skeleton is configured).

    With a skeleton:
        - each detected cycle gets its own base colour;
        - each chain followed from a degree-1 keypoint that starts a bone gets
          its own colour;
        - keypoints with no bones each get their own colour;
        - everything else stays gray.

    Without a skeleton, colours are sampled from matplotlib's 'jet' colormap.
    """
    skeleton = cfg['SKELETON']
    keypoints = cfg['KEYPOINT_NAMES']
    if len(skeleton) == 0:
        return _jet_colors(len(keypoints)), []

    base_colors = [(255,0,0), (0,255,0), (0,0,255), (255,255,0),
                   (255,0,255), (0,255,255), (0,140,255), (140,255,0),
                   (255,140,0), (0,255,140), (255,140,140), (140,255,140),
                   (140,140,255), (140,140,140)]
    gray_color = (100,100,100)
    colors = [gray_color] * len(keypoints)
    # Number of bones touching each keypoint.
    connections = np.zeros(len(keypoints), dtype=int)
    line_idxs = []
    starting_idxs = []

    for bone in skeleton:
        index_start = keypoints.index(bone[0])
        index_stop = keypoints.index(bone[1])
        starting_idxs.append(index_start)
        line_idxs.append([index_start, index_stop])
        connections[index_start] += 1
        connections[index_stop] += 1

    seeds = np.nonzero(connections == 1)[0]
    unconnected = np.nonzero(connections == 0)[0]
    cycles = Graph(line_idxs).get_cycles()

    color_idx = 0
    for cycle in cycles:
        for point in cycle:
            colors[point] = base_colors[color_idx]
        color_idx = (color_idx + 1) % len(base_colors)

    def targets(idx):
        """Keypoints reached by bones that start at ``idx``."""
        return [line[1] for line in line_idxs if line[0] == idx]

    def sources(idx):
        """Keypoints whose bones end at ``idx``."""
        return [line[0] for line in line_idxs if line[1] == idx]

    accounted_for = []
    for seed in seeds:
        if seed not in starting_idxs:
            continue
        idx = seed
        colors[idx] = base_colors[color_idx]
        accounted_for.append(idx)
        conn_idxs = targets(idx)
        backward_idx = sources(idx)
        # Walk the chain while it neither branches nor merges.
        while len(conn_idxs) == 1 and len(backward_idx) < 2:
            idx = conn_idxs[0]
            # `part_of_cycle` was called here but is not defined anywhere in
            # this notebook; inlined as membership in a detected cycle.
            if connections[idx] < 3 or any(idx in cycle for cycle in cycles):
                if idx in accounted_for:
                    colors[idx] = gray_color
                else:
                    colors[idx] = base_colors[color_idx]
                    accounted_for.append(idx)
            conn_idxs = targets(idx)
            backward_idx = sources(idx)
        color_idx = (color_idx + 1) % len(base_colors)

    for point in unconnected:
        colors[point] = base_colors[color_idx]
        color_idx = (color_idx + 1) % len(base_colors)

    return colors, line_idxs


def _jet_colors(n):
    """Return ``n`` colours sampled evenly from matplotlib's 'jet' colormap.

    Each colour is the first three channels of the colormap's output, scaled
    to 0-255 and returned as a list of ints.
    """
    # matplotlib was referenced here but never imported by the notebook.
    import matplotlib
    try:
        cmap = matplotlib.colormaps['jet']
    except AttributeError:  # older matplotlib without the colormaps registry
        import matplotlib.cm
        cmap = matplotlib.cm.get_cmap('jet')
    return [(np.array(cmap(float(i) / n)) * 255).astype(int)[:3].tolist()
            for i in range(n)]

In [2]:
class Graph:
    """Enumerate simple cycles of an undirected graph given as an edge list.

    ``graph`` is a list of 2-element edges ``[a, b]``. Cycles are found by
    extending paths depth-first. Each one is stored rotated so that its
    smallest node comes first.

    A newly found cycle is kept when it shares no node with the stored
    cycles. It is also kept when it is strictly longer than every stored
    cycle it shares a node with; those cycles are then removed.
    """

    def __init__(self, graph):
        self.graph = graph
        self.cycles = []
        self.max_len = 0  # not read by any method of this class

    def get_cycles(self):
        """Search from both endpoints of every edge and return the cycle list."""
        endpoints = (node for edge in self.graph for node in edge)
        for start in endpoints:
            self.findNewCycles([start])
        return self.cycles

    def findNewCycles(self, path):
        """Extend ``path`` (most recent node first) and record closed cycles."""
        head = path[0]
        for first, second in self.graph:
            if head not in (first, second):
                continue
            neighbour = second if first == head else first

            if not self.visited(neighbour, path):
                # Unvisited neighbour: keep exploring from it.
                self.findNewCycles([neighbour] + path)
                continue

            # A visited neighbour closes a cycle only if it is the node the
            # path started from (the last element) and the path is long enough.
            if len(path) <= 2 or neighbour != path[-1]:
                continue

            candidate = self.rotate_to_smallest(path)
            mirrored = self.invert(candidate)
            if not (self.isNew(candidate) and self.isNew(mirrored)):
                continue

            clashes = self.overlapping(candidate)
            if not clashes:
                self.cycles.append(candidate)
            elif len(candidate) > max(len(c) for c in clashes):
                self.cycles.append(candidate)
                for c in clashes:
                    self.cycles.remove(c)

    def invert(self, path):
        """Return the reversed path, rotated to start at its smallest node."""
        backwards = path[::-1]
        return self.rotate_to_smallest(backwards)

    def rotate_to_smallest(self, path):
        """Rotate ``path`` so that its minimum element comes first."""
        pivot = path.index(min(path))
        head, tail = path[pivot:], path[:pivot]
        return head + tail

    def isNew(self, path):
        """True when ``path`` is not already a stored cycle."""
        return path not in self.cycles

    def overlapping(self, path):
        """Stored cycles sharing at least one node with ``path``, in stored order."""
        return [cycle for cycle in self.cycles
                if any(point in cycle for point in path)]

    def visited(self, node, path):
        """True when ``node`` already appears on ``path``."""
        return node in path