In [19]:
import os
from collections import OrderedDict
from natsort import natsorted

import numpy as np
import torch
import torch.nn.functional as f
import torchvision.transforms as transforms

from pose_hrnet import get_pose_net
from config import cfg
from config import update_config
from utils import pose_process, plot_pose

from PIL import Image
import cv2

In [20]:
# Left/right mirror lookup for the 133 keypoints: entry p is the 0-based id of
# the keypoint that keypoint p maps to when the image is flipped horizontally.
# The groups below are written with 1-based ids; the trailing "- 1" converts
# the whole table to 0-based indices.
# NOTE(review): the group sizes (17 + 6 + 68 + 42) look like the COCO-WholeBody
# body / feet / face / hands layout -- confirm against the dataset definition.
index_mirror = np.concatenate((
    # ids 1-17: id 1 maps to itself, then each adjacent pair is swapped
    [1, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14, 17, 16],
    # ids 18-23: the two triples trade places
    [21, 22, 23, 18, 19, 20],
    # ids 24-40 and 41-50: each run is reversed
    np.arange(24, 41)[::-1],
    np.arange(41, 51)[::-1],
    # ids 51-54 map to themselves; ids 55-59 are reversed
    np.arange(51, 55),
    np.arange(55, 60)[::-1],
    # ids 60-65 <-> 66-71
    [69, 68, 67, 66, 71, 70],
    [63, 62, 61, 60, 65, 64],
    # ids 72-78 and 79-83: each run is reversed
    np.arange(72, 79)[::-1],
    np.arange(79, 84)[::-1],
    # ids 84-91
    [88, 87, 86, 85, 84, 91, 90, 89],
    # ids 92-112 <-> 113-133: the two 21-point blocks trade places
    np.arange(113, 134),
    np.arange(92, 113),
)) - 1


In [21]:
def stack_flip(img):
    """Pair an image with its horizontally mirrored copy.

    Args:
        img: image array accepted by ``cv2.flip``.

    Returns:
        A new array with one extra leading axis of length 2: index 0 is
        ``img`` and index 1 is ``cv2.flip(img, 1)`` (left/right mirror).
        A (height, width, channel) input therefore yields
        (2, height, width, channel).
    """
    mirrored = cv2.flip(img, 1)
    pair = (img, mirrored)
    return np.stack(pair, axis=0)

In [22]:
def norm_numpy_totensor(img, mean, std):
    """Scale, normalise and convert a batch of images to a channel-first tensor.

    Args:
        img: array indexed as [batch, height, width, channel] with at least
            three channels; values are divided by 255 after a float32 cast.
        mean: per-channel means (indices 0..2 are used).
        std: per-channel standard deviations (indices 0..2 are used).

    Returns:
        ``torch.Tensor`` of dtype float32 laid out as
        [batch, channel, height, width]. ``img`` itself is not modified,
        because ``astype`` produces a copy.
    """
    scaled = img.astype(np.float32) / 255.0
    for channel in range(3):
        centred = scaled[:, :, :, channel] - mean[channel]
        scaled[:, :, :, channel] = centred / std[channel]
    # channels-last -> channels-first (axis order 0, 3, 1, 2)
    tensor = torch.from_numpy(scaled)
    return tensor.permute(0, 3, 1, 2)

In [23]:
def merge_hm(hms_list, mirror_index=None):
    """Average heatmaps over all scales and over the original/flipped pair.

    Each element of ``hms_list`` is indexed as [image, keypoint, height, width],
    where image 0 is the prediction for the original picture and image 1 is
    the prediction for its horizontal mirror (see ``stack_flip``).

    Image 1 is mapped back into the original frame in two steps:
      * keypoint channels are re-ordered with ``mirror_index`` so that a
        "left" channel predicted on the mirrored picture lands on the matching
        "right" channel (and vice versa);
      * the maps are flipped along dim 2. After selecting image 1 the tensor
        is [keypoint, height, width], so dim 2 is the *width* axis, i.e. this
        undoes the horizontal mirror.

    Args:
        hms_list: non-empty list of tensors with identical shapes.
        mirror_index: keypoint permutation used for the flipped prediction.
            Defaults to the module-level ``index_mirror`` table.

    Returns:
        Tensor [keypoint, height, width]: the mean over every image of every
        entry in ``hms_list``.

    Raises:
        ValueError: if ``hms_list`` is empty.
    """
    assert isinstance(hms_list, list)
    if not hms_list:
        raise ValueError('hms_list must contain at least one heatmap tensor')
    if mirror_index is None:
        mirror_index = index_mirror

    restored = []
    for hms in hms_list:
        # Work on a copy so the caller's tensors are left untouched and the
        # cell can be re-run without un-flipping the same data twice.
        fixed = hms.clone()
        fixed[1] = torch.flip(hms[1, mirror_index, :, :], [2])
        restored.append(fixed)

    # Average the concatenated stack. The previous code averaged the leftover
    # loop variable, which silently dropped every scale except the last one.
    hm = torch.cat(restored, dim=0)
    return torch.mean(hm, dim=0)

In [24]:
# Input video and the folder that will receive its extracted frames.
# The frame folder is named after the video file (extension included).
video_file = 'KETI_SL_0000002337.avi'
video_path = os.path.join('./video', video_file)
image_path = os.path.join('./image', video_file)

# Per-channel (R, G, B) mean / std applied after pixels are scaled to [0, 1].
# NOTE(review): these match the ImageNet statistics commonly used with
# pretrained torchvision-style backbones -- confirm they are the values the
# checkpoint was trained with.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

# exist_ok=True makes the cell safe to re-run and avoids the
# check-then-create race of "if not exists: makedirs".
os.makedirs(image_path, exist_ok=True)

In [25]:
# Dump every frame of the video, unmodified, to image_path as frame_<n>.png
# (n starts at 1).
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    # Without this check a missing/unreadable video silently yields
    # "total 0 frames" and the failure only shows up in later cells.
    raise RuntimeError(f'cannot open video: {video_path}')

try:
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    print(f'width, height = {w}, {h}')

    t = 0
    while cap.isOpened():
        ret, image = cap.read()
        if not ret:
            break
        t += 1

        frame_file = os.path.join(image_path, f'frame_{t}.png')
        # cv2.imwrite reports failure through its return value.
        if not cv2.imwrite(frame_file, image):
            raise RuntimeError(f'failed to write frame: {frame_file}')
finally:
    # Release the capture handle even if reading or writing fails.
    cap.release()

print(f'total {t} frames')

width, height = 1280, 720
total 127 frames


In [31]:
# Re-extract the frames as 512x512 centre crops. This overwrites any
# frame_<n>.png files already present in image_path.
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    # Without this check a missing/unreadable video silently yields
    # "total 0 frames" and the failure only shows up in later cells.
    raise RuntimeError(f'cannot open video: {video_path}')

try:
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    print(f'width, height = {w}, {h}')

    # Centre square crop. For landscape video (w >= h) this is the same
    # margin as before; for portrait video the crop is taken vertically
    # instead of producing a negative margin and a wrong slice.
    side = min(w, h)
    margin = (w - side) // 2
    top = (h - side) // 2

    t = 0
    while cap.isOpened():
        ret, image = cap.read()
        if not ret:
            break
        t += 1

        image = image[top : top + side, margin : margin + side]
        # 512x512 is the base resolution the inference cell assumes.
        image = cv2.resize(image, (512, 512))
        # print(image.shape)
        frame_file = os.path.join(image_path, f'frame_{t}.png')
        # cv2.imwrite reports failure through its return value.
        if not cv2.imwrite(frame_file, image):
            raise RuntimeError(f'failed to write frame: {frame_file}')
finally:
    # Release the capture handle even if reading or writing fails.
    cap.release()

print(f'total {t} frames')

width, height = 1280, 720
total 127 frames


In [32]:
# Single-image whole-body pose inference with flip + multi-scale test-time
# augmentation. Gradients are disabled for the whole cell.
with torch.no_grad():

    # load pretrained wholebody estimation model
    config = 'wholebody_w48_384x288.yaml'
    cfg.merge_from_file(config)

    model = get_pose_net(cfg, is_train=False)
    # NOTE: torch.load unpickles the file -- only load checkpoints from a
    # trusted source. map_location='cpu' lets a checkpoint that was saved on
    # a GPU load on a CPU-only machine; the tensors fed to the model below
    # are CPU tensors anyway.
    checkpoint = torch.load(
        './hrnet_w48_coco_wholebody_384x288-6e061c6a_20200922.pth',
        map_location='cpu',
    )
    state_dict = checkpoint['state_dict']

    # The checkpoint keys carry a 'backbone.' / 'keypoint_head.' prefix;
    # strip it so the names can be loaded into `model`.
    # NOTE(review): presumably get_pose_net builds a single flat module
    # without these wrappers -- confirm against pose_hrnet.
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith('backbone.'):
            name = k[len('backbone.'):]
        elif k.startswith('keypoint_head.'):
            name = k[len('keypoint_head.'):]
        else:
            # Keep unknown keys unchanged instead of reusing the name of the
            # previous iteration (which silently overwrote another entry).
            name = k
        new_state_dict[name] = v

    model.load_state_dict(new_state_dict)
    model.eval()

    # load and convert image
    sample_img = os.path.join(image_path, 'frame_35.png')
    img = cv2.imread(sample_img)
    if img is None:
        # cv2.imread returns None instead of raising when the file is
        # missing or unreadable.
        raise FileNotFoundError(f'cannot read image: {sample_img}')
    height, width = img.shape[:2]
    # NOTE(review): the frame is mirrored before inference, so the predicted
    # coordinates and the saved visualisation are in the mirrored frame --
    # confirm this is intended.
    img = cv2.flip(img, flipCode=1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Multi-scale test-time augmentation: heatmaps predicted at each scale
    # are brought to a common size and averaged in merge_hm.
    multi_scales = [512, 640] # assuming input image is already 512x512
    out = []
    for scale in multi_scales:
        print(scale)
        if scale != 512: # if 640
            img_temp = cv2.resize(img, (scale, scale)) # upscale 512x512 -> 640x640
        else:
            img_temp = img # original 512x512

        img_temp = stack_flip(img_temp) # [2 images, height, width, channel]
        print('stack flipped:', img_temp.shape)
        img_temp = norm_numpy_totensor(img_temp, mean, std) # [2 images, channel, height, width]
        print('norm numpy tensored:', img_temp.shape)

        hms = model(img_temp)
        print('hms:', hms.shape)

        if scale != 512: # if 640
            # interpolate expects the target size as (height, width); the
            # heatmaps are 1/4 of the input resolution (160x160 -> 128x128).
            out.append(f.interpolate(hms, (height // 4, width // 4), mode='bilinear'))
        else:
            out.append(hms)

        for element in out:
            print('output element:', element.shape)

        print()
        
    out = merge_hm(out)
    print('merged:', out.shape)

    # Per-keypoint argmax over the flattened heatmap.
    num_kpts, hm_h, hm_w = out.shape
    result = out.reshape((num_kpts, -1))
    result = torch.argmax(result, dim=1)
    result = result.numpy().squeeze()
    print('result:', result.shape)

    # The flattened index is row-major (index = y * hm_w + x), so both the
    # row and the column are recovered with the heatmap *width*; this stays
    # correct when width != height.
    y = result // hm_w
    x = result % hm_w
    # Columns 0/1 hold x/y.
    # NOTE(review): column 2 is left at zero here; presumably pose_process
    # fills it (e.g. with a score) -- confirm in utils.
    pred = np.zeros((num_kpts, 3), dtype=np.float32)
    pred[:, 0] = x
    pred[:, 1] = y
    pred = pose_process(pred, out)
    pred[:,:2] *= 4.0 # heatmap coordinates -> input-image pixels (stride 4)
    print('pred:', pred.shape)

img = plot_pose(img, pred)
print(img.shape)
# cv2.imwrite expects BGR channel order. The conversion result used to be
# discarded, so the file was written with red and blue swapped.
# NOTE(review): assumes plot_pose returns the image in the same RGB order it
# received -- confirm in utils.
cv2.imwrite('sample_output.png', cv2.cvtColor(img, cv2.COLOR_RGB2BGR))

# final numpy array: [frame, keypoint, x, y, ?]
# what if not flip and not multi scale?


512
stack flipped: (2, 512, 512, 3)
norm numpy tensored: torch.Size([2, 3, 512, 512])
hms: torch.Size([2, 133, 128, 128])
output element: torch.Size([2, 133, 128, 128])

640
stack flipped: (2, 640, 640, 3)
norm numpy tensored: torch.Size([2, 3, 640, 640])
hms: torch.Size([2, 133, 160, 160])
output element: torch.Size([2, 133, 128, 128])
output element: torch.Size([2, 133, 128, 128])

merged: torch.Size([133, 128, 128])
result: (133,)
pred: (133, 3)
(512, 512, 3)


True