In [None]:
# Sanity-check the PyTorch installation: report the torch version together with
# CUDA availability, then the torchvision version.
import torch, torchvision

print(f'torch version: {torch.__version__} {torch.cuda.is_available()}')
print(f'torchvision version: {torchvision.__version__}')

In [None]:
# Sanity-check the MMPose installation by reporting its version.
import mmpose

print(f'mmpose version: {mmpose.__version__}')

# Sanity-check the mmcv ops: report what they return for the CUDA version
# and the compiler.
from mmcv.ops import get_compiling_cuda_version, get_compiler_version

print(f'cuda version: {get_compiling_cuda_version()}')
print(f'compiler information: {get_compiler_version()}')

In [None]:
import os
from collections import defaultdict
import numpy as np
from glob import glob
from matplotlib import pyplot as plt
from tqdm import tqdm

In [None]:
import cv2
from mmpose.apis import (inference_top_down_pose_model, init_pose_model,
                         vis_pose_result, process_mmdet_results)
from mmdet.apis import inference_detector, init_detector

In [None]:
from IPython.display import Image, display
import tempfile
import os.path as osp
import json

In [None]:
def predict(pose_model, img, person_results, cnt):
    """Run top-down pose estimation on one image and show a preview of the result.

    Args:
        pose_model: Pose model as returned by ``init_pose_model``.
        img: Image passed straight to ``inference_top_down_pose_model``
            (the caller in this notebook passes a file path).
        person_results: List of per-person dicts; the caller supplies
            ``{'bbox': [xmin, ymin, xmax, ymax], 'track_id': ...}`` entries.
        cnt: Running counter used to name the preview image
            ``test/pose_results_{cnt}.png``.

    Returns:
        The ``pose_results`` list produced by ``inference_top_down_pose_model``.

    Raises:
        OSError: If the preview image could not be written to disk.
    """
    dataset = pose_model.cfg.data.test.type

    # inference pose
    pose_results, _ = inference_top_down_pose_model(
        pose_model,
        img,
        person_results,
        format='xyxy',
        dataset=dataset)

    # draw the pose estimation results (not shown in a window: show=False)
    vis_result = mmpose.apis.vis_pose_tracking_result(
        pose_model,
        img,
        pose_results,
        show=False,
        dataset=dataset)

    # halve the preview so the notebook output stays small
    vis_result = cv2.resize(vis_result, dsize=None, fx=0.5, fy=0.5)

    # The preview is kept in the relative "test" folder. cv2.imwrite does not
    # create missing folders and only signals failure through its return value,
    # so create the folder first and check the result explicitly.
    file_name = osp.join("test", f'pose_results_{cnt}.png')
    os.makedirs(osp.dirname(file_name), exist_ok=True)
    if not cv2.imwrite(file_name, vis_result):
        raise OSError(f"could not write pose preview to {file_name}")
    display(Image(filename=file_name))
    return pose_results

In [None]:
# Build the pose model from a config file in the local mmpose checkout and a
# checkpoint referenced by URL.
pose_config = (
    '../source/mmpose/configs/body/2d_kpt_sview_rgb_img/'
    'topdown_heatmap/coco/hrnet_w48_coco_256x192.py'
)
pose_checkpoint = (
    'https://download.openmmlab.com/mmpose/top_down/hrnet/'
    'hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth'
)
pose_model = init_pose_model(pose_config, pose_checkpoint)

In [None]:
# Raw inputs: the folder with the video frame directories and the folder with
# the matching tracking annotations share one parent directory.
raw_data_dir = "/usr/data/datasets/kalman-data/personal_folder/cva/graduate-work/data/raw"
videos = f"{raw_data_dir}/videos"
annotations = f"{raw_data_dir}/volleyball_tracking_annotation"

In [None]:
# Root folder under which the per-person pose JSON files are written.
save_directory = (
    "/usr/data/datasets/kalman-data/personal_folder/cva/graduate-work"
    "/data/interim/2d/mmpose-result"
)

In [None]:
# Create the output root if needed. exist_ok=True makes re-running the cell a
# no-op and avoids the check-then-create race of a separate exists() test.
os.makedirs(save_directory, exist_ok=True)

In [None]:
def sort_by_num(path):
    """Sort key: the integer in front of the first dot of the last path component.

    Args:
        path: A "/"-separated path such as ``".../12"`` or ``".../0042.jpg"``.

    Returns:
        The leading part of the final component converted with ``int``.

    Raises:
        ValueError: If that leading part is not a valid integer literal.
    """
    basename = path.rsplit("/", 1)[-1]
    stem, _, _ = basename.partition(".")
    return int(stem)

In [None]:
# Every sub-directory of `videos` is one dataset; order them by their numeric name.
dataset_dirs = (entry.path for entry in os.scandir(videos) if entry.is_dir())
datasets = sorted(dataset_dirs, key=sort_by_num)

In [None]:
# For every example of every dataset: read the tracking markup, estimate the
# pose of each annotated person on each frame, and save one JSON file per
# track with its action label and per-frame keypoints.
#
# NOTE: this runs inference (and displays a preview) for every annotated frame
# of every example, so a full run can take a long time.
cnt = 0  # number of frames sent through `predict`; also names the preview images
for pdataset in datasets:
    ndataset = pdataset.split("/")[-1]
    print(f"INFO: dataset {ndataset}")

    examples = [f.path for f in os.scandir(pdataset) if f.is_dir()]
    for pexample in tqdm(examples):
        tracks = set()
        frames_info = defaultdict(list)   # frame number -> person dicts for `predict`
        action_info = defaultdict(str)    # track id -> action label

        # read the existing markup
        nexample = pexample.split("/")[-1]
        with open(f'{annotations}/{ndataset}/{nexample}/{nexample}.txt') as f:
            for line in f:
                data = line.split()
                if not data:
                    # tolerate blank lines instead of failing on the unpack below
                    continue
                # nine integer columns followed by the action label
                (person_id, xmin, ymin, xmax, ymax,
                 frame, lost, grouping, generated) = map(int, data[:-1])
                action_info[person_id] = str(data[-1])
                tracks.add(person_id)
                if lost == 0:
                    frames_info[frame].append(
                        {'bbox': np.array([xmin, ymin, xmax, ymax]), 'track_id': person_id})

        # predict the pose on every image that has markup
        images = sorted(glob(f"{pexample}/*.*"), key=sort_by_num)
        pose_collection = defaultdict(dict)   # track id -> {frame number: keypoints or None}
        for img in images:
            num = sort_by_num(img)
            # every known track gets an entry for every frame; None = no pose
            for track_id in tracks:
                pose_collection[track_id][num] = None
            # .get() so frames without markup are not inserted into frames_info
            people = frames_info.get(num)
            if people:
                cnt += 1
                pose_results = predict(pose_model, img, people, cnt)
                for pose in pose_results:
                    pose_collection[pose["track_id"]][num] = pose["keypoints"].tolist()

        # write the result for each person
        result_dir = f"{save_directory}/{ndataset}/{nexample}"
        os.makedirs(result_dir, exist_ok=True)

        for track_id in pose_collection:
            result = {
                "action": action_info[track_id],
                "pose": pose_collection[track_id]
            }
            with open(f'{result_dir}/{track_id}.json', 'w') as f:
                json.dump(result, f)

In [None]:
def plot_im_skeleton(skeleton, ax, color):
    """Draw one skeleton on ``ax`` as a set of straight bone segments.

    Args:
        skeleton: 2-D array indexed as ``skeleton[joint, coord]``; column 0 is
            used as x and column 1 as y. Joint indices up to 16 are accessed.
        ax: Object with a matplotlib-style ``plot(xs, ys, color=...)`` method.
        color: Colour sequence; only its first three components are used.
    """
    # pairs of joint indices that are connected by a line
    bones = (
        (15, 13), (13, 11), (16, 14), (14, 12), (11, 12),
        (5, 11), (6, 12), (5, 6), (5, 7), (6, 8), (7, 9),
        (8, 10), (1, 2), (0, 1), (0, 2), (1, 3), (2, 4),
        (3, 5), (4, 6),
    )
    xs, ys = skeleton[:, 0], skeleton[:, 1]
    rgb = color[:3]  # drop any component after the third (e.g. alpha)
    for start, end in bones:
        ax.plot([xs[start], xs[end]], [ys[start], ys[end]], color=rgb)

In [None]:
def get_img(data):
    """Plot a sequence of skeletons into one small figure and return it.

    Args:
        data: Sized iterable of skeletons, each accepted by ``plot_im_skeleton``.

    Returns:
        The matplotlib figure. It is closed before being returned, so it is
        only shown when the caller displays the returned object.
    """
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(3, 3))
    # flip the y axis (presumably so y grows downwards like image pixel rows)
    ax.invert_yaxis()
    # one colour per skeleton, sampled evenly from the 'hsv' colormap
    palette = plt.get_cmap('hsv')(np.linspace(0.0, 1, len(data)))
    for skeleton, colour in zip(data, palette):
        plot_im_skeleton(skeleton, ax, colour)
    plt.close(fig)
    return fig

In [None]:
# Turn each track's {frame number: keypoints or None} mapping into one array of
# stacked per-frame keypoints, ordered by frame number. Frames without a pose
# (None) are skipped. Calling np.array() on the dict itself would only wrap it
# in a 0-d object array, which can be neither sliced nor iterated.
for pid in pose_collection:
    track = pose_collection[pid]
    if not isinstance(track, dict):
        continue  # already converted by an earlier run of this cell
    detected = [track[num] for num in sorted(track) if track[num] is not None]
    pose_collection[pid] = np.array(detected)

In [None]:
# Draw everything stored under key 0 of `pose_collection` in one figure; the
# returned figure is the cell output.
first_track = pose_collection[0]
get_img(first_track[:])