生成视频训练的训练集和测试集

In [1]:
import cv2 
import numpy as np 
from pathlib import Path 
import json
from tqdm import tqdm
from pose_utils import vis_annotations

In [2]:
def _int(x, y): 
    return int(x), int(y) 

In [6]:
def out_of_bbox(bbox, keypoints):
    """Report whether at least one keypoint falls outside the bounding box.

    Visual inspection showed that detections with keypoints outside the box
    are generally of poor quality.

    Args:
        bbox: ``[x1, y1, w, h]`` of the box.
        keypoints: flat sequence whose length is a multiple of 3; the first two
            values of every triplet are used as the x and y coordinates.

    Returns:
        A truthy value if any x is left/right of the box or any y is
        above/below it, otherwise a falsy value.
    """
    left, top, box_w, box_h = bbox
    right, bottom = left + box_w, top + box_h
    points = np.array(keypoints).reshape((-1, 3))
    xs, ys = points[:, 0], points[:, 1]
    # One combined mask instead of four separate any() checks.
    outside = (xs < left) | (xs > right) | (ys < top) | (ys > bottom)
    return outside.any()

def normalization(bbox, keypoints, img_h, img_w):
    """Normalize a bounding box and its keypoints by the image size.

    Args:
        bbox: ``[x1, y1, w, h]``; x1 and w are divided by ``img_w``, y1 and h
            by ``img_h``.
        keypoints: flat sequence whose length is a multiple of 3. The first
            value of every triplet is divided by ``img_w``, the second by
            ``img_h``, and the third is passed through unchanged.
        img_h: image height used as the divisor for y values.
        img_w: image width used as the divisor for x values.

    Returns:
        ``(bbox, keypoints)`` as two flat lists of floats rounded to 5 decimals.
        The inputs are not modified.
    """
    x1, y1, w, h = bbox
    bbox = [x1 / img_w, y1 / img_h, w / img_w, h / img_h]
    # Force a float array: with all-integer keypoints np.array() would pick an
    # integer dtype and the in-place true division below would raise.
    keypoints = np.array(keypoints, dtype=float).reshape((-1, 3))
    keypoints[:, 0] /= img_w
    keypoints[:, 1] /= img_h
    keypoints = keypoints.flatten().tolist()
    bbox = [round(i, 5) for i in bbox]
    keypoints = [round(i, 5) for i in keypoints]
    return bbox, keypoints

In [7]:
def generate_dataset(data_root, video_type='.avi'):
    """Convert per-video JSON annotations into per-video text annotation files.

    For every ``*.json`` file in ``<data_root>/annotations`` the video with the
    same stem is opened from ``<data_root>/Videos``. Each usable frame is
    previewed in an OpenCV window, and frames whose keypoints all lie inside
    the bounding box are appended to ``<data_root>/final_annotations/<stem>.txt``
    as one line::

        frame_id label width height <bbox values> <keypoint values>

    The bbox and keypoint values are those returned by ``normalization``.
    Pressing 'q' in the preview window stops the whole function.

    Args:
        data_root: directory containing the ``annotations`` and ``Videos``
            sub-directories.
        video_type: file extension (including the dot) appended to the
            annotation stem to locate the video.

    Raises:
        ValueError: if ``<data_root>/final_annotations`` already exists. This is
            raised before any processing; output files are opened in append mode.
    """
    ann_path = Path(data_root).joinpath("annotations")
    video_path = Path(data_root).joinpath("Videos")
    out_path = Path(data_root).joinpath("final_annotations")

    if not out_path.exists():
        out_path.mkdir(mode=0o777, parents=True, exist_ok=True)
        print(f"annotation txt file saved to {str(out_path)}")
    else:
        raise ValueError(f"directory is allready exist => {str(out_path)}")

    label_names = ("normal", "falling", "faint")  # indexed by label 0, 1, 2
    ann = None  # keeps the except handler safe if nothing was iterated yet
    try:
        cv2.namedWindow("videos", cv2.WINDOW_KEEPRATIO)
        print("Press 'q' to exit this function.")
        for ann in ann_path.glob("*.json"):
            with ann.open('r') as fd:
                json_dict = json.load(fd)

            video_file = video_path.joinpath(ann.stem + video_type)
            out_file = out_path.joinpath(ann.stem + ".txt")
            cap = cv2.VideoCapture(str(video_file))
            try:
                if not cap.isOpened():
                    # Without this, an unreadable video is skipped with no hint.
                    print(f"cannot open video, skipped => {video_file}")
                    continue

                # CAP_PROP_FRAME_COUNT is reported as a float.
                frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

                # NOTE(review): frames are read sequentially, so the preview only
                # matches the annotation if entries are ordered and contiguous
                # by frame_id -- confirm against the annotation files.
                for ann_info in tqdm(json_dict["annotations"], total=frame_count):
                    frame_id = int(ann_info['frame_id'])
                    bbox = ann_info['bbox']
                    keypoints = ann_info['keypoints']
                    label = ann_info['label']
                    # fall_direction is the fall direction, clockwise from 1 to 8:
                    # 1 is north (the default), 2 is north-east, 3 is east, ...
                    fall_direction = ann_info['fall_direction']
                    width = ann_info['width']
                    height = ann_info['height']

                    if frame_id >= frame_count:
                        continue

                    _, img = cap.read()
                    if img is None or sum(bbox) == 0 or keypoints == []:
                        continue

                    # Visualize the annotation.
                    img = vis_annotations(img, bbox, keypoints, size=1)
                    text = f"{frame_id:2.0f}, {fall_direction:2.0f}"
                    img = cv2.putText(img, text, (14, 17), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color=(0, 255, 0), thickness=1)
                    if label not in (0, 1, 2):
                        raise ValueError(f"{label=}, label should be 0, 1, or 2!")
                    img = cv2.putText(img, label_names[int(label)], (14, 34), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color=(0, 255, 0), thickness=1)

                    cv2.imshow("videos", img)
                    key = cv2.waitKey(1)
                    if (key & 0XFFFF) == ord('q'):
                        return  # the finally blocks release the capture and window

                    # Save the annotation.
                    if not out_of_bbox(bbox, keypoints):
                        bbox, keypoints = normalization(bbox, keypoints, height, width)
                        bbox = str(bbox).strip("[],").replace(",", "")  # '[x1, y1, w, h]' -> "x1 y1 w h"
                        keypoints = str(keypoints).strip("[],").replace(",", "")
                        with out_file.open("a+") as fd:
                            # Append this frame's annotation to the output file.
                            print(f"{frame_id} {label} {width} {height} {bbox} {keypoints}", file=fd)
            finally:
                # Runs on normal completion, on 'q', and on errors alike.
                cap.release()

    except ValueError as err:
        print(f"value error: ann_file => {ann} ({err})")

    finally:
        cv2.destroyAllWindows()

设置数据根目录

In [3]:
# Scene directories that are processed; the list is reused by later cells.
roots = [
    "FallDataset_mp4/Coffee_room_01/",
    "FallDataset_mp4/Coffee_room_02/",
    "FallDataset_mp4/Home_01",
    "FallDataset_mp4/Home_02/",
]

# Build the final annotation files one scene at a time.
for data_root in roots:
    generate_dataset(data_root)

检查生成帧的质量

In [10]:
def check_final_annotations(data_root, video_type='.avi'):
    """Replay the generated annotation files on top of their videos.

    For every ``*.txt`` file in ``<data_root>/final_annotations`` the video with
    the same stem is opened from ``<data_root>/Videos``. Each row is read as
    ``frame_id label width height`` followed by four bbox values and the
    keypoint triplets. The bbox and keypoint x/y values are scaled back by
    ``width``/``height`` and drawn on the frame at ``frame_id``. Pressing 'q'
    in the preview window stops the whole function.

    Args:
        data_root: directory containing the ``final_annotations`` and ``Videos``
            sub-directories.
        video_type: file extension (including the dot) appended to the
            annotation stem to locate the video.
    """
    ann_path = Path(data_root).joinpath("final_annotations")
    video_path = Path(data_root).joinpath("Videos")

    label_names = ("normal", "falling", "faint")  # indexed by label 0, 1, 2
    ann = None  # keeps the except handler safe if nothing was iterated yet
    try:
        cv2.namedWindow("videos", cv2.WINDOW_KEEPRATIO)
        print("Press 'q' to exit this function.")
        for ann in ann_path.glob("*.txt"):
            # ndmin=2 keeps a single-line file as one row instead of a 1-D
            # array whose elements would be scalars.
            txt_info = np.loadtxt(str(ann), ndmin=2)  # one video's annotations
            if txt_info.size == 0:
                continue

            video_file = video_path.joinpath(ann.stem + video_type)
            cap = cv2.VideoCapture(str(video_file))
            try:
                if not cap.isOpened():
                    print(f"cannot open video, skipped => {video_file}")
                    continue

                for line in tqdm(txt_info, total=len(txt_info)):
                    frame_id = int(line[0])
                    label = int(line[1])
                    width, height = line[2:4]

                    # Seek to the annotated frame before reading it.
                    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_id)
                    _, img = cap.read()
                    if img is None:
                        continue

                    # Scale back to the original image size.
                    x1, y1, w, h = line[4:8]
                    bbox = [x1 * width, y1 * height, w * width, h * height]
                    # Broadcasting builds a new array, so txt_info is not modified.
                    keypoints = line[8:].reshape((-1, 3)) * (width, height, 1.0)
                    keypoints = keypoints.flatten().tolist()

                    # Visualize the annotation.
                    img = vis_annotations(img, bbox, keypoints, size=1)
                    text = f"{frame_id:2.0f}, {label:2.0f}"
                    img = cv2.putText(img, text, (14, 17), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color=(0, 255, 0), thickness=1)
                    if label not in (0, 1, 2):
                        raise ValueError(f"{label=}, label should be 0, 1, or 2!")
                    img = cv2.putText(img, label_names[label], (14, 34), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color=(0, 255, 0), thickness=1)

                    cv2.imshow("videos", img)
                    key = cv2.waitKey(1)
                    if (key & 0XFFFF) == ord('q'):
                        return  # the finally blocks release the capture and window
            finally:
                # Runs on normal completion, on 'q', and on errors alike.
                cap.release()
    except ValueError as err:
        print(f"value error: ann_file => {ann} ({err})")

    finally:
        cv2.destroyAllWindows()

In [11]:
# Replay the generated annotations for every scene directory.
for data_root in roots:
    check_final_annotations(data_root)

Press 'q' to exit this function.


100%|██████████| 137/137 [00:01<00:00, 68.99it/s]
100%|██████████| 344/344 [00:05<00:00, 65.95it/s]
100%|██████████| 373/373 [00:05<00:00, 65.77it/s]
100%|██████████| 178/178 [00:02<00:00, 65.97it/s]
100%|██████████| 242/242 [00:03<00:00, 66.59it/s]
100%|██████████| 137/137 [00:02<00:00, 66.50it/s]
100%|██████████| 92/92 [00:01<00:00, 66.13it/s]
100%|██████████| 168/168 [00:02<00:00, 66.09it/s]
100%|██████████| 157/157 [00:02<00:00, 66.82it/s]
100%|██████████| 268/268 [00:04<00:00, 65.82it/s]
100%|██████████| 212/212 [00:03<00:00, 65.74it/s]
100%|██████████| 269/269 [00:04<00:00, 66.17it/s]
100%|██████████| 443/443 [00:06<00:00, 66.18it/s]
100%|██████████| 181/181 [00:02<00:00, 65.70it/s]
100%|██████████| 224/224 [00:03<00:00, 65.87it/s]
100%|██████████| 310/310 [00:04<00:00, 65.99it/s]
100%|██████████| 265/265 [00:03<00:00, 66.25it/s]
 50%|█████     | 131/261 [00:02<00:01, 65.47it/s]


Press 'q' to exit this function.


 13%|█▎        | 56/424 [00:00<00:05, 67.63it/s]


Press 'q' to exit this function.


 56%|█████▌    | 95/171 [00:01<00:01, 66.94it/s]


Press 'q' to exit this function.


 14%|█▍        | 42/295 [00:00<00:03, 70.18it/s]
