In [None]:
import tensorflow as tf

import numpy as np

import matplotlib.pyplot as plt

from MakiPoseNet.pose_estimation.generators.pose_estimation import  LoadDataMethod, RIterator, \
                                    RandomCropMethod, AugmentationPostMethod
from MakiPoseNet.pose_estimation.generators.pipeline.tfr.tfr_pathgenerator import CycleGenerator
from MakiPoseNet.pose_estimation.generators.pipeline.tfr.tfr_gen_layers import InputGenLayerV2, InputGenLayerV2Batched

%matplotlib inline

In [None]:
import glob

# Collect the TFRecord shards. glob.glob returns paths in arbitrary (filesystem)
# order, so sort them to make the file order passed to the pipeline reproducible.
tfrecords = sorted(glob.glob('tfrecords/*'))
if not tfrecords:
    # Surface the problem here instead of as an obscure failure inside the pipeline.
    print("warning: no files found under 'tfrecords/'")

In [None]:
# Path generator over the TFRecord files; the input layer below receives it as
# `tfr_path_generator`. NOTE(review): presumably cycles over the paths endlessly — confirm.
cycle = CycleGenerator(tfrecords)

In [None]:
# Preprocessing chain used by InputGenLayerV2: load -> random 512x512 crop -> augment.
# Each stage is called with the previous one, so `map_me` ends up holding the whole chain.
# NOTE(review): the meaning of [24, 8] is not visible here (presumably the keypoints
# tensor shape) — confirm against LoadDataMethod.
map_me = LoadDataMethod([24, 8])
map_me = RandomCropMethod(512, 512)(map_me)
map_me = AugmentationPostMethod(
    use_rotation=True,
    angle_max=50,
    angle_min=-50,
    use_shift=False,  # shift disabled, so the dx/dy ranges below are presumably unused
    dx_min=-10.0,
    dx_max=10.0,
    dy_min=-10.0,
    dy_max=10.0,
    use_zoom=True,
    zoom_min=0.6,
    zoom_max=1.5,
)(map_me)

# Alternative configuration for InputGenLayerV2Batched, kept for reference as comments
# (a bare string literal here would be echoed by Jupyter as the cell's output).
# Note that the augmentation is NOT chained onto `map_read`: the batched layer takes
# `map_read` as `operation_before_batched` and `map_me` separately as `map_operation`.
#
# map_read = LoadDataMethod([21, 8])
# map_read = RandomCropMethod(512, 512)(map_read)
# map_me = AugmentationPostMethod(
#     use_rotation=True,
#     angle_max=20,
#     angle_min=-20,
#     use_shift=False,
#     dx_min=-10.0,
#     dx_max=10.0,
#     dy_min=-10.0,
#     dy_max=10.0,
#     use_zoom=True,
#     zoom_min=0.8,
#     zoom_max=1.2,
# )

In [None]:

# Input layer that reads the TFRecords provided by `cycle` and uses `map_me` as its
# map operation. The plotting cells below assume batches of this size (4).
# NOTE(review): num_parallel_calls=-1 presumably means "autotune" — confirm.
gen_layer = InputGenLayerV2(
    prefetch_size=3,
    batch_size=4,
    input_data_type=RIterator.IMAGE,
    tfr_path_generator=cycle,
    name='super_generator',
    map_operation=map_me,
    num_parallel_calls=-1,
    cycle_length=4,
    block_length=4
)

# Alternative: batched variant, kept for reference as comments (a bare string literal
# here would be echoed by Jupyter as the cell's output). It expects `map_read` and a
# standalone `map_me` built as in the commented-out configuration of the previous cell.
#
# gen_layer = InputGenLayerV2Batched(
#     prefetch_size=2,
#     batch_size=4,
#     input_data_type=RIterator.IMAGE,
#     tfr_path_generator=cycle,
#     name='super_generator',
#     operation_before_batched=map_read,
#     map_operation=map_me,
#     num_parallel_calls=-1,
#     cycle_length=4,
#     block_length=4
# )

In [None]:
# TF1 graph-mode session (tf.compat.v1.Session in TF2), used below to evaluate the
# generator's tensors with sess.run.
sess = tf.Session()

In [None]:
# Preview what the layer exposes; later cells pass this to sess.run and index the
# result with RIterator keys.
gen_layer.get_iterator()

In [None]:
# Evaluate the iterator once to get a single batch, then pull the four fields out
# of the resulting mapping (indexed by RIterator keys).
single = sess.run(gen_layer.get_iterator())
img, kp, kp_mask, size = (
    single[field]
    for field in (
        RIterator.IMAGE,
        RIterator.KEYPOINTS,
        RIterator.KEYPOINTS_MASK,
        RIterator.IMAGE_PROPERTIES,
    )
)

In [None]:
# Show every image of the fetched batch with its visible keypoints overlaid, two
# panels per row (a batch of 4 gives a 2x2 grid).
n_images = len(img)
n_cols = 2
n_rows = max(1, (n_images + n_cols - 1) // n_cols)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(10, 5 * n_rows), squeeze=False)

for idx, ax in enumerate(axes.flat):
    if idx >= n_images:
        ax.axis('off')  # unused panel when the batch size is odd
        continue
    points = kp[idx]
    # Select points by mask instead of multiplying by it: multiplying moved every
    # hidden keypoint to (0, 0), drawing spurious dots in the image corner.
    # NOTE(review): assumes kp_mask values are 0/1 — confirm.
    visible = np.broadcast_to(kp_mask[idx][..., 0] > 0, points[..., 0].shape)
    ax.imshow(img[idx].astype(np.uint8))
    ax.scatter(points[..., 0][visible], points[..., 1][visible])
    ax.set_title('batch item {}'.format(idx))

plt.show()

## Measure the generation time of the pipeline

In [None]:
import time

In [None]:
# Measure how long the pipeline takes to produce N_TIMED_BATCHES batches.
N_TIMED_BATCHES = 100

# Fetch the iterator tensors once, outside the timed loop. Re-querying the layer
# every step adds overhead to the measurement, and would grow the graph if
# get_iterator() creates new ops per call (NOTE(review): not visible here).
batch_tensors = gen_layer.get_iterator()

start = time.perf_counter()  # monotonic, high-resolution clock intended for timing
for _ in range(N_TIMED_BATCHES):
    sess.run(batch_tensors)
elapsed = time.perf_counter() - start

print(elapsed)
print('per batch: {:.4f} s'.format(elapsed / N_TIMED_BATCHES))

In [None]:
# Reference result (seconds) of the timing cell above, recorded on an earlier run: 4.678076267242432

## Count good and bad images with keypoints

In [None]:
# Running tallies for the manual inspection below; re-run this cell to start over.
good, bad, unseen, count = 0, 0, 0, 0

In [None]:
# Fetch a fresh batch and plot it for manual inspection; afterwards run the tally
# cells below to record what you saw.
single = sess.run(gen_layer.get_iterator())
img, kp, kp_mask, size = [
    single[RIterator.IMAGE],
    single[RIterator.KEYPOINTS],
    single[RIterator.KEYPOINTS_MASK],
    single[RIterator.IMAGE_PROPERTIES]
]

# Two panels per row; the number of panels follows the actual batch size.
n_images = len(img)
n_cols = 2
n_rows = max(1, (n_images + n_cols - 1) // n_cols)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(10, 5 * n_rows), squeeze=False)

for idx, ax in enumerate(axes.flat):
    if idx >= n_images:
        ax.axis('off')  # unused panel when the batch size is odd
        continue
    points = kp[idx]
    # Select points by mask instead of multiplying by it: multiplying moved every
    # hidden keypoint to (0, 0), which would be miscounted as a bad annotation.
    # NOTE(review): assumes kp_mask values are 0/1 — confirm.
    visible = np.broadcast_to(kp_mask[idx][..., 0] > 0, points[..., 0].shape)
    ax.imshow(img[idx].astype(np.uint8))
    ax.scatter(points[..., 0][visible], points[..., 1][visible])
    ax.set_title('batch item {}'.format(idx))

plt.show()

In [None]:
# Manual tally: set the increment to the number of good images in the batch above.
good = good + 3

In [None]:
# Manual tally: set the increment to the number of bad images in the batch above.
bad = bad + 1

In [None]:
# Manual tally of "unseen" images in the batch above.
# NOTE(review): with the default increments, good + bad + unseen grows by 5 per batch
# while `count` grows by 4 — confirm whether "unseen" overlaps the other categories.
unseen = unseen + 1

In [None]:
# Add the number of images in the batch just inspected. Use the batch's actual
# length instead of a hard-coded 4 so the tally stays correct if batch_size changes.
count += len(img)

In [None]:
# Display the running number of inspected images (Jupyter echoes the last expression).
count

In [None]:
# Final report of the manual inspection, one "label value" line per tally.
print('\n'.join(
    '{} {}'.format(label, value)
    for label, value in (
        ('Image count: ', count),
        ('Good: ', good),
        ('Bad: ', bad),
        ('Unseen: ', unseen),
    )
))