In [None]:
import os
import sys
import pickle
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import lmdb
import imgaug as ia
from imgaug import augmenters as iaa

sys.path.append(os.path.dirname(os.path.abspath('')))
import utils.prepare_data as pd
import utils.xyzuvd as xyzuvd
import utils.visual as visual
import utils.error as error
from utils.directory import DATA_DIR, DATASET_DIR

In [None]:
# Which split's caches to inspect (selects the keys cache and LMDB files below).
save_prefix = 'train_fpha'
keys_cache_file = os.path.join(DATA_DIR, save_prefix + '_keys_cache.p')
# pickle.load can execute arbitrary code: only load caches produced by this project.
with open(keys_cache_file, "rb") as keys_file:
    keys = pickle.load(keys_file)
# Joint permutation applied before handing joints to visual.visualize_joints_2d.
REORDER = visual.REORDER

idx = 0  # index of the sample to inspect
key = keys[idx]
img_path = os.path.join(DATASET_DIR, 'First_Person_Action_Benchmark', 'Video_files_416', key)
with Image.open(img_path) as pil_img:
    img = np.asarray(pil_img)


def read_lmdb_array(lmdb_path, key, dtype, shape):
    """Read the array stored under `key` in the LMDB at `lmdb_path`, then close it.

    py-lmdb warns against opening the same environment twice in one process, so
    the environment is always closed before returning. This keeps the cell safe
    to re-run, e.g. after changing `idx`. The result is copied so it cannot
    depend on a buffer owned by the closed environment.

    Parameters
    ----------
    lmdb_path : str
        Directory of the LMDB database.
    key : str
        Record key passed through to `pd.read_lmdb`.
    dtype : numpy dtype
        Element type of the stored array.
    shape : tuple of int
        Shape to give the stored array.
    """
    env = lmdb.open(lmdb_path, readonly=True, lock=False, readahead=False, meminit=False)
    try:
        return np.array(pd.read_lmdb(key, env, dtype, shape))
    finally:
        env.close()


# Ground-truth joints: 21 rows of (u, v, d).
uvd_gt = read_lmdb_array(os.path.join(DATA_DIR, save_prefix + '_uvd_gt_resize.lmdb'),
                         key, np.float32, (21, 3))
# 845 = 13 * 13 * 5 entries. Presumably the precomputed counterpart of the
# hand_cell_i vector built later in this notebook — TODO confirm.
hand_cell_idx = read_lmdb_array(os.path.join(DATA_DIR, save_prefix + '_hand_cell_idx.lmdb'),
                                key, np.uint8, (845,))

In [None]:
# Sanity check: ground-truth joints overlaid on the original (un-augmented) frame.
fig, ax = plt.subplots()
ax.imshow(img)
visual.visualize_joints_2d(ax, uvd_gt[REORDER], joint_idxs=False)
ax.set_title('Original frame with ground-truth joints')
plt.show()

In [None]:
# Seed imgaug's global RNG so Restart & Run All reproduces the same augmentation.
AUG_SEED = 0
ia.seed(AUG_SEED)

# Wrap each joint's (u, v) as an imgaug keypoint so geometric augmenters move
# the joints together with the image; depth is carried over separately below.
skel_kps = [ia.Keypoint(x=u, y=v) for u, v in uvd_gt[:, :2]]
skel_kpsoi = ia.KeypointsOnImage(skel_kps, shape=img.shape)

seq = iaa.Sequential([
    # Per-channel colour jitter in HSV space.
    # NOTE(review): imgaug's Add presumably saturates a uint8 hue channel rather
    # than wrapping around the hue circle — confirm (-90, 90) behaves as intended.
    iaa.ChangeColorspace(from_colorspace="RGB", to_colorspace="HSV"),
    iaa.WithChannels(0, iaa.Add((-90, 90))),    # hue
    iaa.WithChannels(1, iaa.Add((-128, 128))),  # saturation
    iaa.WithChannels(2, iaa.Add((-128, 128))),  # exposure (value)
    iaa.ChangeColorspace(from_colorspace="HSV", to_colorspace="RGB"),
    # Random shift of up to +/-10% of the image size along each axis.
    iaa.Affine(translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)}),
])

# Freeze the sampled parameters so the image and its keypoints get the same transform.
seq_det = seq.to_deterministic()
img_aug = seq_det.augment_images([img])[0]
kps_aug = seq_det.augment_keypoints([skel_kpsoi])[0].get_coords_array()
# Colour jitter and 2D translation leave depth unchanged: reuse the original d column.
uvd_gt_aug = np.concatenate((kps_aug, uvd_gt[:, 2:3]), axis=-1)

In [None]:
# Augmented frame with the transformed joints; the joints should still sit on the hand.
fig, ax = plt.subplots()
ax.imshow(img_aug)
visual.visualize_joints_2d(ax, uvd_gt_aug[REORDER], joint_idxs=False)
ax.set_title('Augmented frame with transformed joints')
plt.show()

# Per-channel maxima of the augmented image (channels are RGB after the final
# colourspace conversion), shown as one array instead of three prints.
img_aug.max(axis=(0, 1))

In [None]:
# Hand-cell grid: 13 x 13 cells of 32 px in (u, v) and 5 bins of 120 depth units in d.
GRID_SHAPE = (13, 13, 5)
CELL_SIZE = (32, 32, 120)
PAD = 10  # margin added around the joint extent (same value on every axis)


def hand_cell_combinations(uvd, pad=PAD, cell_size=CELL_SIZE, grid_shape=GRID_SHAPE):
    """Return every grid cell overlapped by the padded bounding box of the joints.

    Parameters
    ----------
    uvd : array-like, shape (N, 3)
        Joint coordinates with columns (u, v, d).
    pad : int
        Margin added on both sides of the joint extent before binning. The lower
        edge is clamped at 0.
    cell_size : sequence of 3 numbers
        Cell extent along u, v and d.
    grid_shape : sequence of 3 ints
        Number of cells along u, v and d. The upper cell bound is clamped to it.

    Returns
    -------
    np.ndarray, shape (M, 3), dtype uint8
        (u, v, d) cell indices in itertools.product order (u slowest, d fastest).
        M is 0 when the padded box does not overlap the grid.
    """
    uvd = np.asarray(uvd)
    cell_size = np.asarray(cell_size)
    # Truncate toward zero (as int() does) before padding.
    lo = np.maximum(uvd.min(axis=0).astype(int) - pad, 0)
    hi = uvd.max(axis=0).astype(int) + pad
    lo_cell = lo // cell_size
    hi_cell = np.minimum(np.ceil(hi / cell_size), grid_shape)
    axes = [np.arange(start, stop) for start, stop in zip(lo_cell, hi_cell)]
    # meshgrid with 'ij' indexing + C-order reshape enumerates the same
    # combinations, in the same order, as itertools.product(*axes).
    mesh = np.meshgrid(*axes, indexing='ij')
    return np.stack(mesh, axis=-1).reshape(-1, 3).astype(np.uint8)


comb = hand_cell_combinations(uvd_gt_aug)

# Shade each (u, v) cell covered by the padded hand box.
# - `comb` repeats a (u, v) cell once per depth bin; drawing each cell once keeps
#   the overlay at a uniform alpha.
# - Indices are converted to Python ints first. `comb` is uint8, and under NumPy 2
#   promotion rules (NEP 50) uint8 * 32 wraps around for cells >= 8.
fig, ax = plt.subplots(figsize=(13, 13))
ax.imshow(img_aug)
visual.visualize_joints_2d(ax, uvd_gt_aug[REORDER], joint_idxs=False)
uv_cells = sorted({(int(cell[0]), int(cell[1])) for cell in np.asarray(comb).reshape(-1, 3)})
for u_cell, v_cell in uv_cells:
    ax.add_patch(plt.Rectangle((u_cell * 32, v_cell * 32), 32, 32, linewidth=1,
                               edgecolor='r', facecolor='r', fill=True, alpha=0.5))
ax.set_title('Grid cells covered by the padded hand box')
plt.show()
    

In [None]:
# Encode the covered cells as a flat 0/1 occupancy vector over the
# 13 x 13 x 5 = 845-cell grid (C-order flattening of (u, v, d)).
grid_shape = (13, 13, 5)
cells = np.asarray(comb, dtype=np.intp).reshape(-1, 3)  # (M, 3); also handles an empty comb
ravel_comb = np.ravel_multi_index(tuple(cells.T), grid_shape)
hand_cell_i = np.zeros(int(np.prod(grid_shape)), dtype=np.uint8)
hand_cell_i[ravel_comb] = 1

In [None]:
# Summarise the occupancy vector rather than dumping all of its entries.
print(f'{int(hand_cell_i.sum())} / {hand_cell_i.size} cells occupied')
np.flatnonzero(hand_cell_i)

In [None]:
# Decode the occupancy vector back into (u, v, d) cell indices as a (3, K) array.
# np.stack keeps the (3, K) shape for every K. np.squeeze would collapse it to
# (3,) when exactly one cell is occupied, breaking indexes[0, i] downstream.
occupied = np.flatnonzero(hand_cell_i == 1)
indexes = np.stack(np.unravel_index(occupied, (13, 13, 5)))

In [None]:
# Round-trip check: shade the (u, v) cells decoded from hand_cell_i.
# - Each (u, v) appears once per occupied depth bin; drawing it once keeps the
#   overlay at a uniform alpha.
# - reshape(3, -1) also accepts a squeezed (3,) array holding a single cell.
fig, ax = plt.subplots(figsize=(13, 13))
ax.imshow(img_aug)
visual.visualize_joints_2d(ax, uvd_gt_aug[REORDER], joint_idxs=False)
decoded = np.asarray(indexes).reshape(3, -1)
uv_cells = sorted({(int(u), int(v)) for u, v in zip(decoded[0], decoded[1])})
for u_cell, v_cell in uv_cells:
    ax.add_patch(plt.Rectangle((u_cell * 32, v_cell * 32), 32, 32, linewidth=1,
                               edgecolor='r', facecolor='r', fill=True, alpha=0.5))
ax.set_title('Grid cells decoded from hand_cell_i')
plt.show()