In [1]:
from __future__ import print_function, absolute_import

import os
import numpy as np
import json
import random
import math

import torch
import torch.utils.data as data

from pose.utils.osutils import *
from pose.utils.imutils import *
from pose.utils.transforms import *
from pose.utils.evaluation  import final_preds
import pose.models as models

import glob
import cv2
from tqdm import tqdm
import imageio

from scipy.io import loadmat
import scipy.misc
import scipy.ndimage

import imgaug as ia
import imgaug.augmenters as iaa
from imgaug.augmentables.kps import KeypointsOnImage


In [2]:
def cv2_plot_lines(frame, pts, order):
    color_mapping = {1: [255,0,255], 2: [255,0,0], 3: [255,0,127], 4: [255,255,255], 5: [0,0,255],
                 6: [0,127,255], 7: [0,255,255], 8: [0,255,0], 9: [200,162,200]} 
    point_size = 3
    if order==0:
        # other animals
        # plot neck-eyes
        cv2.line(frame, (pts[[0,2],:][0,0], pts[[0,2],:][0,1]), (pts[[0,2],:][1,0], pts[[0,2],:][1,1]), color_mapping[1], point_size)
        cv2.line(frame, (pts[[1,2],:][0,0], pts[[1,2],:][0,1]), (pts[[1,2],:][1,0], pts[[1,2],:][1,1]), color_mapping[2], point_size)

        # plot legs
        cv2.line(frame, (pts[[3,8],:][0,0], pts[[3,8],:][0,1]), (pts[[3,8],:][1,0], pts[[3,8],:][1,1]), color_mapping[5], point_size)
        cv2.line(frame, (pts[[8,14],:][0,0], pts[[8,14],:][0,1]), (pts[[8,14],:][1,0], pts[[8,14],:][1,1]), color_mapping[5], point_size)

        cv2.line(frame, (pts[[4,9],:][0,0], pts[[4,9],:][0,1]), (pts[[4,9],:][1,0], pts[[4,9],:][1,1]), color_mapping[6], point_size)
        cv2.line(frame, (pts[[9,15],:][0,0], pts[[9,15],:][0,1]), (pts[[9,15],:][1,0], pts[[9,15],:][1,1]), color_mapping[6], point_size)

        cv2.line(frame, (pts[[5,10],:][0,0], pts[[5,10],:][0,1]), (pts[[5,10],:][1,0], pts[[5,10],:][1,1]), color_mapping[7], point_size)
        cv2.line(frame, (pts[[10,16],:][0,0], pts[[10,16],:][0,1]), (pts[[10,16],:][1,0], pts[[10,16],:][1,1]), color_mapping[7], point_size)

        cv2.line(frame, (pts[[6,11],:][0,0], pts[[6,11],:][0,1]), (pts[[6,11],:][1,0], pts[[6,11],:][1,1]), color_mapping[8], point_size)
        cv2.line(frame, (pts[[11,17],:][0,0], pts[[11,17],:][0,1]), (pts[[11,17],:][1,0], pts[[11,17],:][1,1]), color_mapping[8], point_size)

        # plot hip-necks
        cv2.line(frame, (pts[[12,7],:][0,0], pts[[12,7],:][0,1]), (pts[[12,7],:][1,0], pts[[12,7],:][1,1]), color_mapping[1], point_size)
        cv2.line(frame, (pts[[13,7],:][0,0], pts[[13,7],:][0,1]), (pts[[13,7],:][1,0], pts[[13,7],:][1,1]), color_mapping[2], point_size)
    elif order==1:
        # elephant
        cv2.line(frame, (pts[[0,2],:][0,0], pts[[0,2],:][0,1]), (pts[[0,2],:][1,0], pts[[0,2],:][1,1]), color_mapping[1], point_size)
        cv2.line(frame, (pts[[1,2],:][0,0], pts[[1,2],:][0,1]), (pts[[1,2],:][1,0], pts[[1,2],:][1,1]), color_mapping[2], point_size)

        # plot legs
        cv2.line(frame, (pts[[3,8],:][0,0], pts[[3,8],:][0,1]), (pts[[3,8],:][1,0], pts[[3,8],:][1,1]), color_mapping[5], point_size)
        cv2.line(frame, (pts[[8,14],:][0,0], pts[[8,14],:][0,1]), (pts[[8,14],:][1,0], pts[[8,14],:][1,1]), color_mapping[5], point_size)

        cv2.line(frame, (pts[[4,9],:][0,0], pts[[4,9],:][0,1]), (pts[[4,9],:][1,0], pts[[4,9],:][1,1]), color_mapping[6], point_size)
        cv2.line(frame, (pts[[9,15],:][0,0], pts[[9,15],:][0,1]), (pts[[9,15],:][1,0], pts[[9,15],:][1,1]), color_mapping[6], point_size)

        cv2.line(frame, (pts[[5,10],:][0,0], pts[[5,10],:][0,1]), (pts[[5,10],:][1,0], pts[[5,10],:][1,1]), color_mapping[7], point_size)
        cv2.line(frame, (pts[[10,16],:][0,0], pts[[10,16],:][0,1]), (pts[[10,16],:][1,0], pts[[10,16],:][1,1]), color_mapping[7], point_size)

        cv2.line(frame, (pts[[6,11],:][0,0], pts[[6,11],:][0,1]), (pts[[6,11],:][1,0], pts[[6,11],:][1,1]), color_mapping[8], point_size)
        cv2.line(frame, (pts[[11,17],:][0,0], pts[[11,17],:][0,1]), (pts[[11,17],:][1,0], pts[[11,17],:][1,1]), color_mapping[8], point_size)

        # plot hip-necks
        cv2.line(frame, (pts[[12,7],:][0,0], pts[[12,7],:][0,1]), (pts[[12,7],:][1,0], pts[[12,7],:][1,1]), color_mapping[1], point_size)
        cv2.line(frame, (pts[[13,7],:][0,0], pts[[13,7],:][0,1]), (pts[[13,7],:][1,0], pts[[13,7],:][1,1]), color_mapping[2], point_size)

        
        cv2.line(frame, (pts[[18,19],:][0,0], pts[[18,19],:][0,1]), (pts[[18,19],:][1,0], pts[[18,19],:][1,1]), color_mapping[1], point_size)
        cv2.line(frame, (pts[[19,20],:][0,0], pts[[19,20],:][0,1]), (pts[[19,20],:][1,0], pts[[19,20],:][1,1]), color_mapping[1], point_size)
        cv2.line(frame, (pts[[20,21],:][0,0], pts[[20,21],:][0,1]), (pts[[20,21],:][1,0], pts[[20,21],:][1,1]), color_mapping[1], point_size)
        cv2.line(frame, (pts[[21,22],:][0,0], pts[[21,22],:][0,1]), (pts[[21,22],:][1,0], pts[[21,22],:][1,1]), color_mapping[1], point_size)
        cv2.line(frame, (pts[[22,23],:][0,0], pts[[22,23],:][0,1]), (pts[[22,23],:][1,0], pts[[22,23],:][1,1]), color_mapping[1], point_size)
        cv2.line(frame, (pts[[23,24],:][0,0], pts[[23,24],:][0,1]), (pts[[23,24],:][1,0], pts[[23,24],:][1,1]), color_mapping[1], point_size)


def cv2_visualize_keypoints(frame, pts, num_pts=18, order=0):
    
    points = pts
    x = []
    y = []
    for i in range(num_pts):
        x.append(points[i][0])
        y.append(points[i][1])
        # plot keypoints on each image 
        cv2.circle(frame,(x[-1],y[-1]), 3, (0,255,0), -1)   
    cv2_plot_lines(frame, points, order)
    return frame

def load_animal(data_dir='./', animal='horse'):
    """
    Output:
    img_list: Nx3   # each image is associated with a shot-id and a shot-id frame_id,
                    # e.g. ('***.jpg', 100, 2) means the second frame in video_id 100.
    anno_list: Nx3  # (x, y, visiblity)
    """

    range_path = os.path.join(data_dir, 'behaviorDiscovery2.0/ranges', animal, 'ranges.mat')
    landmark_path = os.path.join(data_dir, 'behaviorDiscovery2.0/landmarks', animal)

    img_list = []  # img_list contains all image paths
    anno_list = [] # anno_list contains all anno lists
    range_file = loadmat(range_path)

    for video in range_file['ranges']:
        # range_file['ranges'] is a numpy array [Nx3]: shot_id, start_frame, end_frame
        shot_id = video[0]
        landmark_path_video = os.path.join(landmark_path, str(shot_id)+'.mat')

        if not os.path.isfile(landmark_path_video):
            continue
        landmark_file = loadmat(landmark_path_video)

        for frame in range(video[1], video[2]+1): # ??? video[2]+1
            frame_id = frame - video[1]
            img_name = '0'*(8-len(str(frame))) + str(frame) + '.jpg'
            img_list.append([img_name, shot_id, frame_id])
            
            coord = landmark_file['landmarks'][frame_id][0][0][0][0]
            vis = landmark_file['landmarks'][frame_id][0][0][0][1]
            landmark = np.hstack((coord, vis))
            anno_list.append(landmark[:18,:])
            
    return img_list, anno_list

def dataset_filter(anno_list):
    """
    output:
    idxs: valid_idxs after filtering
    """
    num_kpts = anno_list[0].shape[0]
    idxs = []
    for i in range(len(anno_list)):
        s = sum(anno_list[i][:,2])
        if s>num_kpts//2:
            idxs.append(i)
    return idxs

def im_to_torch(img):
    img = np.transpose(img, (2, 0, 1)) # C*H*W
    img = to_torch(img).float()
    if img.max() > 1:
        img /= 255
    return img


# Visualization for horses/tigers

In [8]:
global_animal = 'horse' # can be horse, tiger
order = 0 # 0: for other animals | 1: for elephant
global_dataset = 'synthetic_animal' # models trained using synthetic datasets
nParts = 18 # number of keypoints
is_part = True # True: multitask setting | False: keypoint only


if not os.path.exists(os.path.join('./demo', global_animal)):
    os.makedirs(os.path.join('./demo', global_animal))

# define directories
img_folder = '/media/jm/000C65DB000784DF/workspace/animals/behaviorDiscovery/'
img_list, anno_list = load_animal(data_dir=img_folder, animal=global_animal)
img_idxs = dataset_filter(anno_list)

# define the model
if is_part:
    # multi-task checkpoint
    checkpoint_path = os.path.join("./checkpoint/",global_dataset,global_animal+"_multitask/model_best.pth.tar")
else:
    # keypoint only checkpoint
    checkpoint_path = os.path.join("./checkpoint/",global_dataset,global_animal+'_ssl/synthetic_animal_sp.pth.tar")
print(checkpoint_path)
meanstd_file = os.path.join('./data/', global_dataset, global_animal+'_combineds5r5_texture', 'mean.pth.tar')


# Define the codec and create VideoWriter object
fourcc = cv2.VideoWriter_fourcc(*'XVID')
out = cv2.VideoWriter('output.avi',fourcc, 20.0, (512,256))

# load model
if is_part:
    model1 = models.__dict__['hg_multitask'](num_stacks=4, num_blocks=1, num_classes=nParts, resnet_layers=50)
else:
    model1 = models.__dict__['hg'](num_stacks=4, num_blocks=1, num_classes=nParts, resnet_layers=50)
model1 = torch.nn.DataParallel(model1).cuda()
checkpoint = torch.load(checkpoint_path)
model1.load_state_dict(checkpoint['state_dict'])
idxs = np.arange(nParts)

# calculate meand and std
mean, std = torch.load(meanstd_file)['mean'], torch.load(meanstd_file)['std']

for i in tqdm(range(len(img_idxs))):
    # color_mapping for segmentations
    # head, eye, ear, torso, left_front, right_front, left_back, right_back, tail
    color_mapping = {1: [255,0,0], 2: [203,192,255], 3: [255,0,127], 4: [255,255,255], 5: [0,0,255],
                 6: [0,127,255], 7: [0,255,255], 8: [0,255,0], 9: [200,162,200]}
    
    # calculate keypoints
    img = scipy.misc.imread(os.path.join(img_folder, 'behaviorDiscovery2.0/', global_animal, img_list[img_idxs[i]][0]), mode='RGB')
    img_path = img_list[img_idxs[i]][0]
    frame = img.copy()
    img = im_to_torch(img)
    
    # get correct scale and center
    x_min = float(np.min(anno_list[img_idxs[i]][:,0] \
                         [anno_list[img_idxs[i]][:,0]>0]))
    x_max = float(np.max(anno_list[img_idxs[i]][:,0] \
                         [anno_list[img_idxs[i]][:,0]>0]))
    y_min = float(np.min(anno_list[img_idxs[i]][:,1] \
                         [anno_list[img_idxs[i]][:,1]>0]))
    y_max = float(np.max(anno_list[img_idxs[i]][:,1] \
                         [anno_list[img_idxs[i]][:,1]>0]))

    c = torch.Tensor(( (x_min+x_max)/2.0, (y_min+y_max)/2.0 ))
    s = max(x_max-x_min, y_max-y_min)/200.0 * 1.5
    rot = 0
    
    inp = crop(img, c, s, [256, 256], rot)
    
    frame = torch.Tensor(frame.transpose(2,0,1))
    frame = crop(frame, c, s, [256, 256], rot)
    frame = (frame.numpy().transpose(1,2,0))*255
    frame = np.uint8(frame)
    frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
#         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    
    inp_show = inp.clone()
    inp = color_normalize(inp, mean, std)

    inp = inp.unsqueeze(0).cuda()
    if is_part:
        output, mask = model1(inp)
    else:
        output = model1(inp)
    score_map = output[-1].cpu() if type(output) == list else output.cpu()

    preds = final_preds(score_map, [c], [s], [64, 64])
    preds = preds.squeeze(0)
    pts = preds.clone().cpu().numpy()
    
    # get confidence score
    confidence_score = np.max(score_map.detach().cpu().numpy(), axis=(0,2,3))
    confidence = confidence_score>0.5
    transformed_preds = np.zeros((nParts, 2), dtype=int)
    for j in range(nParts):
        transformed_preds[j] = transform(pts[j]+1, c, s, [256, 256], invert=0, rot=0)
    
    # get predicted mask
    if is_part:
        _, pred_seg = torch.max(mask[-1], 1)
        part_seg = pred_seg.detach().cpu().numpy().transpose(1,2,0).astype(np.uint8).repeat(3, axis=2)
        for i in range(1,10):
            part_seg[:,:,0][part_seg[:,:,0]==i] = color_mapping[i][0]
            part_seg[:,:,1][part_seg[:,:,1]==i] = color_mapping[i][1]    
            part_seg[:,:,2][part_seg[:,:,2]==i] = color_mapping[i][2]

        pred_mask = scipy.misc.imresize(part_seg, [256,256,3])
        cv2.imshow('mask', pred_mask)

#     # color_mapping for kpts
#     color_mapping = {1: [255,0,255], 2: [255,0,0], 3: [255,0,127], 4: [255,255,255], 5: [0,0,255],
#                  6: [0,127,255], 7: [0,255,255], 8: [0,255,0], 9: [200,162,200]}    
    
    # plot keypoints
    frame = cv2_visualize_keypoints(frame, transformed_preds[idxs], num_pts=nParts, order=0)

    cv2.imshow('frame', frame)

    if is_part:
        cv2.imwrite('demo/'+global_animal+'/'+img_path + '_all.jpg', np.concatenate( (frame, pred_mask), axis=1))
        cv2.imwrite('demo/'+global_animal+'/'+img_path+'_seg.jpg', pred_mask)
        # out.write(np.concatenate( (frame, pred_mask), axis=1))
    cv2.imwrite('demo/'+global_animal+'/'+img_path, frame)

    
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

# Release everything if job is finished
out.release()
cv2.destroyAllWindows()

../checkpoint/synthetic_animal/horse/combineds5r5_decay0.01_parallel/model_best.pth.tar


`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imread`` instead.
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.3.0.
Use Pillow instead: ``numpy.array(Image.fromarray(arr).resize())``.
  0%|          | 44/10152 [00:11<43:37,  3.86it/s]