# CenterNet_姿勢推定_動画解析
**参考**
*   [CenterNet (Objects as Points)](https://arxiv.org/abs/1904.07850) 
*   [xingyizhou/CenterNet](https://github.com/xingyizhou/CenterNet).
*   [tugstugi/dl-colab-notebooks](https://github.com/tugstugi/dl-colab-notebooks)

## CenterNetをインストールする

In [None]:
# Install the torch/torchvision versions this notebook was written against
# (CUDA 10.1 wheels). Per the captured output, this replaces Colab's
# preinstalled torch and pip reports a conflict with torchtext.
!pip install -U torch==1.4 torchvision==0.5 -f https://download.pytorch.org/whl/cu101/torch_stable.html

import os
from os.path import exists, join, basename, splitext

git_repo_url = 'https://github.com/xingyizhou/CenterNet.git'
# 'CenterNet' -- the directory name that `git clone` creates.
project_name = splitext(basename(git_repo_url))[0]
# Only clone/build on the first run; later runs reuse the existing checkout.
if not exists(project_name):
  # clone (shallow, quiet)
  !git clone -q --depth 1 $git_repo_url
  # fix DCNv2: delete the DCNv2 directory inside the CenterNet checkout,
  # re-clone CharlesShang/DCNv2 in its place and build it with ./make.sh
  !cd {project_name}/src/lib/models/networks && rm -rf DCNv2 && git clone https://github.com/CharlesShang/DCNv2.git && cd DCNv2 && ./make.sh
  # dependencies listed in CenterNet's requirements.txt
  !cd $project_name && pip install -q -r requirements.txt

import sys
# Make CenterNet's library code (src/lib) and scripts (src) importable.
sys.path.insert(0, join(project_name, 'src/lib'))
sys.path.append(join(project_name, 'src'))
# following 2 lines needed to avoid later import error:
# import DCN directly from the freshly built DCNv2 directory.
sys.path.append(join(project_name, 'src/lib/models/networks/DCNv2'))
from dcn_v2 import DCN

import time
import matplotlib
import matplotlib.pylab as plt
# Disable grid lines on matplotlib axes.
plt.rcParams["axes.grid"] = False

from IPython.display import clear_output

Looking in links: https://download.pytorch.org/whl/cu101/torch_stable.html
Collecting torch==1.4
[?25l  Downloading https://files.pythonhosted.org/packages/1a/3b/fa92ece1e58a6a48ec598bab327f39d69808133e5b2fb33002ca754e381e/torch-1.4.0-cp37-cp37m-manylinux1_x86_64.whl (753.4MB)
[K     |████████████████████████████████| 753.4MB 21kB/s 
[?25hCollecting torchvision==0.5
[?25l  Downloading https://files.pythonhosted.org/packages/1c/32/cb0e4c43cd717da50258887b088471568990b5a749784c465a8a1962e021/torchvision-0.5.0-cp37-cp37m-manylinux1_x86_64.whl (4.0MB)
[K     |████████████████████████████████| 4.0MB 41.3MB/s 
[31mERROR: torchtext 0.9.1 has requirement torch==1.8.1, but you'll have torch 1.4.0 which is incompatible.[0m
Installing collected packages: torch, torchvision
  Found existing installation: torch 1.8.1+cu101
    Uninstalling torch-1.8.1+cu101:
      Successfully uninstalled torch-1.8.1+cu101
  Found existing installation: torchvision 0.9.1+cu101
    Uninstalling torchvision-0.

モデルをダウンロードする

In [None]:
model_name = 'multi_pose_dla_3x.pth'
if not exists(model_name):
  !pip install -q gdown
  !gdown 'https://drive.google.com/uc?id=1PO1Ax_GDtjiemEmDVD7oPWwqQkUu28PI'

Downloading...
From: https://drive.google.com/uc?id=1PO1Ax_GDtjiemEmDVD7oPWwqQkUu28PI
To: /content/multi_pose_dla_3x.pth
82.7MB [00:00, 225MB/s]


モデルをCenterNetのmodelディレクトリへ移動する

In [None]:
!mv multi_pose_dla_3x.pth CenterNet/models/

## 読み込みたい動画のパスを指定

In [None]:
# Mount Google Drive at /content/drive so video files stored there can be read.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Path of the video to analyse (change this for your own video).
# The result videos are written next to it, named after it with a suffix.
# NOTE(review): this example path is outside the mounted /content/drive;
# point it under /content/drive/... if the video lives on Google Drive.
VIDEO_PATH = "/content/IMG_2658.MOV"

In [None]:
# No changes needed.
# Both paths are relative to the notebook's working directory (/content), so
# the working directory must stay there. (A former `!cd CenterNet/src` line was
# removed: `!` commands run in a throw-away subshell, so it had no effect.)
CENTERNET_LIB_PATH = 'CenterNet/src/lib'
MODEL_PATH = 'CenterNet/models/multi_pose_dla_3x.pth'

In [None]:
# Imports for the analysis cells below.
import cv2
import glob as glob
import matplotlib.pyplot as plt
%matplotlib inline
import warnings
# Suppress all Python warnings for the rest of the session.
warnings.filterwarnings('ignore')

import sys
# Put CenterNet's src/lib first on sys.path so `detectors` and `opts` resolve.
sys.path.insert(0, CENTERNET_LIB_PATH)

from detectors.detector_factory import detector_factory
from opts import opts

# NOTE(review): pickle, PIL, colorsys, rcParams, pandas and glob are not used
# by any of the cells shown here; they may be removable.
import pickle
from PIL import Image, ImageFont, ImageDraw
import numpy as np
import colorsys
from pylab import rcParams

import pandas as pd
# Notebook-widget progress bar, imported under the name `tqdm`.
# NOTE(review): `tqdm_notebook` is deprecated in recent tqdm releases in
# favour of `from tqdm.notebook import tqdm`.
from tqdm import tqdm_notebook as tqdm

In [None]:
# CenterNet task to run: 'multi_pose' is human pose estimation
# ('ctdet' would run plain object detection instead).
TASK = 'multi_pose'
# Pass the arguments as a list instead of splitting a formatted string on
# spaces, so a MODEL_PATH containing spaces stays a single argument.
opt = opts().init([TASK, '--load_model', MODEL_PATH])
# Build the detector class registered for this task, with the loaded model.
detector = detector_factory[opt.task](opt)

Fix size testing.
training chunk_sizes: [32]
The output will be saved to  CenterNet/src/lib/../../exp/multi_pose/default
heads {'hm': 1, 'wh': 2, 'hps': 34, 'reg': 2, 'hm_hp': 17, 'hp_offset': 2}
Creating model...
loaded CenterNet/models/multi_pose_dla_3x.pth, epoch 320


In [None]:
# Colour of the filled circle drawn at each of the 17 keypoints (list index =
# keypoint index). OpenCV colour tuples in BGR order: (255, 0, 255) magenta,
# (255, 0, 0) blue, (0, 0, 255) red. The alternating blue/red entries
# presumably mark the two body sides (COCO keypoint order assumed; TODO confirm).
colors_hp = [(255, 0, 255), (255, 0, 0), (0, 0, 255), 
        (255, 0, 0), (0, 0, 255), (255, 0, 0), (0, 0, 255),
        (255, 0, 0), (0, 0, 255), (255, 0, 0), (0, 0, 255),
        (255, 0, 0), (0, 0, 255), (255, 0, 0), (0, 0, 255),
        (255, 0, 0), (0, 0, 255)]
# Skeleton edges: pairs of keypoint indices that write_pose joins with a line.
# Kept as lists: write_pose indexes a (17, 2) array with them (points[e]).
edges = [[0, 1], [0, 2], [1, 3], [2, 4], 
                    [3, 5], [4, 6], [5, 6], 
                    [5, 7], [7, 9], [6, 8], [8, 10], 
                    [5, 11], [6, 12], [11, 12], 
                    [11, 13], [13, 15], [12, 14], [14, 16]]
# Line colour (BGR) for each entry of `edges`, same order and length (18).
ec = [(255, 0, 0), (0, 0, 255), (255, 0, 0), (0, 0, 255), 
                 (255, 0, 0), (0, 0, 255), (255, 0, 255),
                 (255, 0, 0), (255, 0, 0), (0, 0, 255), (0, 0, 255),
                 (255, 0, 0), (0, 0, 255), (255, 0, 255),
                 (255, 0, 0), (255, 0, 0), (0, 0, 255), (0, 0, 255)]

def write_pose(points, img, radius=3, thickness=2):
    """Draw a 17-keypoint skeleton onto ``img`` in place.

    Args:
        points: 34 numbers (or a (17, 2) array) holding x, y pixel
            coordinates, one pair per keypoint; truncated to int32 and
            reshaped to (17, 2).
        img: image array, drawn on in place with OpenCV.
        radius: radius in pixels of the filled circle at each keypoint.
        thickness: line width in pixels of each skeleton edge.
    """
    points = np.array(points, dtype=np.int32).reshape(17, 2)
    # Hand OpenCV plain Python ints: some OpenCV versions reject NumPy
    # integer scalars inside point tuples.
    coords = [(int(x), int(y)) for x, y in points]

    for j, center in enumerate(coords):
        cv2.circle(img, center, radius, colors_hp[j], -1)

    for j, (a, b) in enumerate(edges):
        # Only draw an edge when all four endpoint coordinates are > 0;
        # presumably meant to skip keypoints lying off the top/left border.
        if points[[a, b]].min() > 0:
            # ec[j] picks the edge colour (left and right sides differ).
            cv2.line(img, coords[a], coords[b], ec[j], thickness,
                     lineType=cv2.LINE_AA)

def write_rect(img, box, cl, color=(255, 255, 255), thickness=3):
    """Draw the bounding box ``box`` onto ``img`` in place.

    Args:
        img: image array, drawn on in place with OpenCV.
        box: sequence whose first four values are x1, y1, x2, y2 (two
            opposite corners, in pixels); extra values are ignored.
        cl: unused; kept so existing ``write_rect(img, box, cl)`` calls work.
        color: colour tuple passed to cv2.rectangle (white by default).
        thickness: rectangle line width in pixels.
    """
    x1, y1, x2, y2 = (int(v) for v in box[:4])
    cv2.rectangle(img, (x1, y1), (x2, y2), color, thickness)

## 動画解析実行

### 白背景ver

In [None]:
from os.path import splitext

cap = cv2.VideoCapture(VIDEO_PATH)
if not cap.isOpened():
    # Fail loudly instead of silently writing an empty output video.
    raise FileNotFoundError('Could not open video: {}'.format(VIDEO_PATH))
# Total frame count from the container metadata; only used for the progress bar.
frame_count = round(cap.get(cv2.CAP_PROP_FRAME_COUNT))
# Keep the exact (possibly fractional, e.g. 29.97) frame rate; rounding it to
# an int would make the output play at a slightly different speed.
fps = cap.get(cv2.CAP_PROP_FPS)

height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))  # frame height in pixels
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))  # frame width in pixels
size = (width, height)  # VideoWriter expects (width, height)

# Output path: the input path with its extension replaced by '_line.mp4'.
out_path = splitext(VIDEO_PATH)[0] + '_line.mp4'
out = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc('m','p','4', 'v'), fps, size)
if not out.isOpened():
    cap.release()
    raise RuntimeError('Could not open video writer: {}'.format(out_path))

score_threshold = 0.5  # minimum detection score for a skeleton to be drawn

try:
    # Read until the stream ends rather than trusting frame_count, which comes
    # from metadata and may not match the number of decodable frames.
    with tqdm(total=frame_count if frame_count > 0 else None) as pbar:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            results = detector.run(frame)['results']
            # White canvas the size of the video; only skeletons are drawn.
            white = np.full((height, width, 3), 255, dtype=np.uint8)
            # Rows of results[1] are indexed as: score = bbox[4],
            # 17 (x, y) keypoints = bbox[5:39].
            for bbox in results[1]:
                if bbox[4] > score_threshold:
                    write_pose(bbox[5:39], white)
            # Flip both axes (a 180-degree rotation) before writing.
            # NOTE(review): presumably compensates for this clip's
            # orientation -- confirm before reusing with other videos.
            out.write(cv2.flip(white, -1))
            pbar.update(1)
finally:
    # Release both handles even if detection fails part-way through.
    cap.release()
    out.release()

HBox(children=(FloatProgress(value=0.0, max=765.0), HTML(value='')))

白塗りはnp.ones((height,width,3), np.uint8)*255で作る

### 背景重ねver

In [None]:
from os.path import splitext

cap = cv2.VideoCapture(VIDEO_PATH)
if not cap.isOpened():
    # Fail loudly instead of silently writing an empty output video.
    raise FileNotFoundError('Could not open video: {}'.format(VIDEO_PATH))
# Total frame count from the container metadata; only used for the progress bar.
frame_count = round(cap.get(cv2.CAP_PROP_FRAME_COUNT))
# Keep the exact (possibly fractional, e.g. 29.97) frame rate; rounding it to
# an int would make the output play at a slightly different speed.
fps = cap.get(cv2.CAP_PROP_FPS)

height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))  # frame height in pixels
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))  # frame width in pixels
size = (width, height)  # VideoWriter expects (width, height)

# Output path: the input path with its extension replaced by '_result.mp4'.
out_path = splitext(VIDEO_PATH)[0] + '_result.mp4'
out = cv2.VideoWriter(
        out_path,
        cv2.VideoWriter_fourcc('m','p','4', 'v'),
        fps,
        size
        )
if not out.isOpened():
    cap.release()
    raise RuntimeError('Could not open video writer: {}'.format(out_path))

score_threshold = 0.5  # minimum detection score for a box/skeleton to be drawn

try:
    # Read until the stream ends rather than trusting frame_count, which comes
    # from metadata and may not match the number of decodable frames.
    with tqdm(total=frame_count if frame_count > 0 else None) as pbar:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            results = detector.run(frame)['results']
            # Rows of results[1] are indexed as: box = bbox[:4],
            # score = bbox[4], 17 (x, y) keypoints = bbox[5:39].
            for bbox in results[1]:
                if bbox[4] > score_threshold:
                    # Draw on the original frame: box first, then skeleton.
                    write_rect(frame, bbox[:4], 0)
                    write_pose(bbox[5:39], frame)
            # Flip both axes (a 180-degree rotation) before writing.
            # NOTE(review): presumably compensates for this clip's
            # orientation -- confirm before reusing with other videos.
            out.write(cv2.flip(frame, -1))
            pbar.update(1)
finally:
    # Release both handles even if detection fails part-way through.
    cap.release()
    out.release()

HBox(children=(FloatProgress(value=0.0, max=765.0), HTML(value='')))