In [3]:
# 주의! ray를 tensorflow보다 먼저 import하면 오류가 발생할 수 있습니다
import io, json, os, math

import tensorflow as tf
from tensorflow.keras.layers import Add, Concatenate, Lambda
from tensorflow.keras.layers import Input, Conv2D, ReLU, MaxPool2D
from tensorflow.keras.layers import UpSampling2D, ZeroPadding2D
from tensorflow.keras.layers import BatchNormalization
import ray

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

PROJECT_PATH = os.getenv('HOME') + '/aiffel/mpii'
IMAGE_PATH = os.path.join(PROJECT_PATH, 'images')
MODEL_PATH = os.path.join(PROJECT_PATH, 'mine')
TFRECORD_PATH = os.path.join(PROJECT_PATH, 'tfrecords_mpii')
TRAIN_JSON = os.path.join(PROJECT_PATH, 'mpii_human_pose_v1_u12_2', 'train.json')
VALID_JSON = os.path.join(PROJECT_PATH, 'mpii_human_pose_v1_u12_2', 'validation.json')

In [None]:
with open(TRAIN_JSON) as train_json:
    train_annos = json.load(train_json)
    json_formatted_str = json.dumps(train_annos[0], indent=2)
    print(json_formatted_str)

0 - 오른쪽 발목
1 - 오른쪽 무릎
2 - 오른쪽 엉덩이
3 - 왼쪽 엉덩이
4 - 왼쪽 무릎
5 - 왼쪽 발목
6 - 골반
7 - 가슴(흉부)
8 - 목
9 - 머리 위
10 - 오른쪽 손목
11 - 오른쪽 팔꿈치
12 - 오른쪽 어깨
13 - 왼쪽 어깨
14 - 왼쪽 팔꿈치
15 - 왼쪽 손목

In [4]:
# json annotation parsing
def parse_one_annotation(anno, image_dir):
    filename = anno['image']
    joints = anno['joints']
    joints_visibility = anno['joints_vis']
    annotation = {
        'filename': filename,
        'filepath': os.path.join(image_dir, filename),
        'joints_visibility': joints_visibility,
        'joints': joints,
        'center': anno['center'],
        'scale' : anno['scale']
    }
    return annotation

In [5]:
# test
with open(TRAIN_JSON) as train_json:
    train_annos = json.load(train_json)
    test = parse_one_annotation(train_annos[0], IMAGE_PATH)
    print(test)

{'filename': '015601864.jpg', 'filepath': '/aiffel/aiffel/mpii/images/015601864.jpg', 'joints_visibility': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'joints': [[620.0, 394.0], [616.0, 269.0], [573.0, 185.0], [647.0, 188.0], [661.0, 221.0], [656.0, 231.0], [610.0, 187.0], [647.0, 176.0], [637.0201, 189.8183], [695.9799, 108.1817], [606.0, 217.0], [553.0, 161.0], [601.0, 167.0], [692.0, 185.0], [693.0, 240.0], [688.0, 313.0]], 'center': [594.0, 257.0], 'scale': 3.021046}


In [6]:
def generate_tfexample(anno):

    # byte 인코딩을 위한 함수
    def _bytes_feature(value):
        if isinstance(value, type(tf.constant(0))):
            value = value.numpy()
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    filename = anno['filename']
    filepath = anno['filepath']
    with open(filepath, 'rb') as image_file:
        content = image_file.read()

    image = Image.open(filepath)
    if image.format != 'JPEG' or image.mode != 'RGB':
        image_rgb = image.convert('RGB')
        with io.BytesIO() as output:
            image_rgb.save(output, format="JPEG", quality=95)
            content = output.getvalue()

    width, height = image.size
    depth = 3

    c_x = int(anno['center'][0])
    c_y = int(anno['center'][1])
    scale = anno['scale']

    x = [
        int(joint[0]) if joint[0] >= 0 else int(joint[0]) 
        for joint in anno['joints']
    ]
    y = [
        int(joint[1]) if joint[1] >= 0 else int(joint[0]) 
        for joint in anno['joints']
    ]

    v = [0 if joint_v == 0 else 2 for joint_v in anno['joints_visibility']]

    feature = {
        'image/height':
        tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),
        'image/width':
        tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),
        'image/depth':
        tf.train.Feature(int64_list=tf.train.Int64List(value=[depth])),
        'image/object/parts/x':
        tf.train.Feature(int64_list=tf.train.Int64List(value=x)),
        'image/object/parts/y':
        tf.train.Feature(int64_list=tf.train.Int64List(value=y)),
        'image/object/center/x': 
        tf.train.Feature(int64_list=tf.train.Int64List(value=[c_x])),
        'image/object/center/y': 
        tf.train.Feature(int64_list=tf.train.Int64List(value=[c_y])),
        'image/object/scale':
        tf.train.Feature(float_list=tf.train.FloatList(value=[scale])),
        'image/object/parts/v':
        tf.train.Feature(int64_list=tf.train.Int64List(value=v)),
        'image/encoded':
        _bytes_feature(content),
        'image/filename':
        _bytes_feature(filename.encode())
    }

    return tf.train.Example(features=tf.train.Features(feature=feature))

In [7]:
# 얼마나 많은 TFRecord를 만들지 결정할 함수 : shard
def chunkify(l, n):
    size = len(l) // n
    start = 0
    results = []
    for i in range(n):
        results.append(l[start:start + size])
        start += size
    return results

In [8]:
# test
test_chunks = chunkify([0] * 1000, 64)
print(test_chunks)
print(len(test_chunks))
print(len(test_chunks[0]))

[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0,

In [9]:
# 병렬처리 위한 ray
@ray.remote
def build_single_tfrecord(chunk, path):
    print('start to build tf records for ' + path)

    with tf.io.TFRecordWriter(path) as writer:
        for anno in chunk:
            tf_example = generate_tfexample(anno)
            writer.write(tf_example.SerializeToString())

    print('finished building tf records for ' + path)

In [10]:
def build_tf_records(annotations, total_shards, split):
    chunks = chunkify(annotations, total_shards)
    futures = [
        build_single_tfrecord.remote(
            chunk, '{}/{}_{}_of_{}.tfrecords'.format(
                TFRECORD_PATH,
                split,
                str(i + 1).zfill(4),
                str(total_shards).zfill(4),
            )) for i, chunk in enumerate(chunks)
    ]
    ray.get(futures)

In [None]:
num_train_shards = 64
num_val_shards = 8

ray.init()

print('Start to parse annotations.')
if not os.path.exists(TFRECORD_PATH):
    os.makedirs(TFRECORD_PATH)

with open(TRAIN_JSON) as train_json:
    train_annos = json.load(train_json)
    train_annotations = [
        parse_one_annotation(anno, IMAGE_PATH)
        for anno in train_annos
    ]
    print('First train annotation: ', train_annotations[0])

with open(VALID_JSON) as val_json:
    val_annos = json.load(val_json)
    val_annotations = [
        parse_one_annotation(anno, IMAGE_PATH) 
        for anno in val_annos
    ]
    print('First val annotation: ', val_annotations[0])
    
print('Start to build TF Records.')
build_tf_records(train_annotations, num_train_shards, 'train')
build_tf_records(val_annotations, num_val_shards, 'val')

print('Successfully wrote {} annotations to TF Records.'.format(
    len(train_annotations) + len(val_annotations)))

In [11]:
# data label로 만들기
def parse_tfexample(example):
    image_feature_description = {
        'image/height': tf.io.FixedLenFeature([], tf.int64),
        'image/width': tf.io.FixedLenFeature([], tf.int64),
        'image/depth': tf.io.FixedLenFeature([], tf.int64),
        'image/object/parts/x': tf.io.VarLenFeature(tf.int64),
        'image/object/parts/y': tf.io.VarLenFeature(tf.int64),
        'image/object/parts/v': tf.io.VarLenFeature(tf.int64),
        'image/object/center/x': tf.io.FixedLenFeature([], tf.int64),
        'image/object/center/y': tf.io.FixedLenFeature([], tf.int64),
        'image/object/scale': tf.io.FixedLenFeature([], tf.float32),
        'image/encoded': tf.io.FixedLenFeature([], tf.string),
        'image/filename': tf.io.FixedLenFeature([], tf.string),
    }
    return tf.io.parse_single_example(example, image_feature_description)

In [12]:
# image crop : 정사각형, 이 때 이미지 바깥으로 crop box가 나가면 안됨
def crop_roi(image, features, margin=0.2):
    img_shape = tf.shape(image)
    img_height = img_shape[0]
    img_width = img_shape[1]
    img_depth = img_shape[2]

    keypoint_x = tf.cast(tf.sparse.to_dense(features['image/object/parts/x']), dtype=tf.int32)
    keypoint_y = tf.cast(tf.sparse.to_dense(features['image/object/parts/y']), dtype=tf.int32)
    center_x = features['image/object/center/x']
    center_y = features['image/object/center/y']
    body_height = features['image/object/scale'] * 200.0

    # keypoint 중 유효한값(visible = 1) 만 사용합니다.
    masked_keypoint_x = tf.boolean_mask(keypoint_x, keypoint_x > 0)
    masked_keypoint_y = tf.boolean_mask(keypoint_y, keypoint_y > 0)

    # min, max 값을 찾습니다.
    keypoint_xmin = tf.reduce_min(masked_keypoint_x)
    keypoint_xmax = tf.reduce_max(masked_keypoint_x)
    keypoint_ymin = tf.reduce_min(masked_keypoint_y)
    keypoint_ymax = tf.reduce_max(masked_keypoint_y)

    # 높이 값을 이용해서 x, y 위치를 재조정 합니다. 박스를 정사각형으로 사용하기 위해 아래와 같이 사용합니다.
    xmin = keypoint_xmin - tf.cast(body_height * margin, dtype=tf.int32)
    xmax = keypoint_xmax + tf.cast(body_height * margin, dtype=tf.int32)
    ymin = keypoint_ymin - tf.cast(body_height * margin, dtype=tf.int32)
    ymax = keypoint_ymax + tf.cast(body_height * margin, dtype=tf.int32)

    # 이미지 크기를 벗어나는 점을 재조정 해줍니다.
    effective_xmin = xmin if xmin > 0 else 0
    effective_ymin = ymin if ymin > 0 else 0
    effective_xmax = xmax if xmax < img_width else img_width
    effective_ymax = ymax if ymax < img_height else img_height
    effective_height = effective_ymax - effective_ymin
    effective_width = effective_xmax - effective_xmin

    image = image[effective_ymin:effective_ymax, effective_xmin:effective_xmax, :]
    new_shape = tf.shape(image)
    new_height = new_shape[0]
    new_width = new_shape[1]

    effective_keypoint_x = (keypoint_x - effective_xmin) / new_width
    effective_keypoint_y = (keypoint_y - effective_ymin) / new_height

    return image, effective_keypoint_x, effective_keypoint_y

In [13]:
def generate_2d_guassian(height, width, y0, x0, visibility=2, sigma=1, scale=12):
    heatmap = tf.zeros((height, width))

    xmin = x0 - 3 * sigma
    ymin = y0 - 3 * sigma
    xmax = x0 + 3 * sigma
    ymax = y0 + 3 * sigma
    
    if xmin >= width or ymin >= height or xmax < 0 or ymax < 0 or visibility == 0:
        return heatmap

    size = 6 * sigma + 1
    x, y = tf.meshgrid(tf.range(0, 6 * sigma + 1, 1), tf.range(0, 6 * sigma + 1, 1), indexing='xy')

    center_x = size // 2
    center_y = size // 2

    gaussian_patch = tf.cast(tf.math.exp(
        -(tf.math.square(x - center_x) + tf.math.square(y - center_y)) / (tf.math.square(sigma) * 2)) * scale,
                             dtype=tf.float32)

    patch_xmin = tf.math.maximum(0, -xmin)
    patch_ymin = tf.math.maximum(0, -ymin)
    patch_xmax = tf.math.minimum(xmax, width) - xmin
    patch_ymax = tf.math.minimum(ymax, height) - ymin

    heatmap_xmin = tf.math.maximum(0, xmin)
    heatmap_ymin = tf.math.maximum(0, ymin)
    heatmap_xmax = tf.math.minimum(xmax, width)
    heatmap_ymax = tf.math.minimum(ymax, height)

    indices = tf.TensorArray(tf.int32, 1, dynamic_size=True)
    updates = tf.TensorArray(tf.float32, 1, dynamic_size=True)

    count = 0

    for j in tf.range(patch_ymin, patch_ymax):
        for i in tf.range(patch_xmin, patch_xmax):
            indices = indices.write(count, [heatmap_ymin + j, heatmap_xmin + i])
            updates = updates.write(count, gaussian_patch[j][i])
            count += 1

    heatmap = tf.tensor_scatter_nd_update(heatmap, indices.stack(), updates.stack())

    return heatmap

def make_heatmaps(features, keypoint_x, keypoint_y, heatmap_shape):
    v = tf.cast(tf.sparse.to_dense(features['image/object/parts/v']), dtype=tf.float32)
    x = tf.cast(tf.math.round(keypoint_x * heatmap_shape[0]), dtype=tf.int32)
    y = tf.cast(tf.math.round(keypoint_y * heatmap_shape[1]), dtype=tf.int32)

    num_heatmap = heatmap_shape[2]
    heatmap_array = tf.TensorArray(tf.float32, 16)

    for i in range(num_heatmap):
        gaussian = self.generate_2d_guassian(heatmap_shape[1], heatmap_shape[0], y[i], x[i], v[i])
        heatmap_array = heatmap_array.write(i, gaussian)

    heatmaps = heatmap_array.stack()
    heatmaps = tf.transpose(heatmaps, perm=[1, 2, 0])  # change to (64, 64, 16)

    return heatmaps

In [14]:
class Preprocessor(object):
    def __init__(self,
                 image_shape=(256, 256, 3),
                 heatmap_shape=(64, 64, 16),
                 is_train=False):
        self.is_train = is_train
        self.image_shape = image_shape
        self.heatmap_shape = heatmap_shape

    def __call__(self, example):
        features = self.parse_tfexample(example)
        image = tf.io.decode_jpeg(features['image/encoded'])

        if self.is_train:
            random_margin = tf.random.uniform([1], 0.1, 0.3)[0]
            image, keypoint_x, keypoint_y = self.crop_roi(image, features, margin=random_margin)
            image = tf.image.resize(image, self.image_shape[0:2])
        else:
            image, keypoint_x, keypoint_y = self.crop_roi(image, features)
            image = tf.image.resize(image, self.image_shape[0:2])

        image = tf.cast(image, tf.float32) / 127.5 - 1
        heatmaps = self.make_heatmaps(features, keypoint_x, keypoint_y, self.heatmap_shape)

        return image, heatmaps

        
    def crop_roi(self, image, features, margin=0.2):
        img_shape = tf.shape(image)
        img_height = img_shape[0]
        img_width = img_shape[1]
        img_depth = img_shape[2]

        keypoint_x = tf.cast(tf.sparse.to_dense(features['image/object/parts/x']), dtype=tf.int32)
        keypoint_y = tf.cast(tf.sparse.to_dense(features['image/object/parts/y']), dtype=tf.int32)
        center_x = features['image/object/center/x']
        center_y = features['image/object/center/y']
        body_height = features['image/object/scale'] * 200.0
        
        masked_keypoint_x = tf.boolean_mask(keypoint_x, keypoint_x > 0)
        masked_keypoint_y = tf.boolean_mask(keypoint_y, keypoint_y > 0)
        
        keypoint_xmin = tf.reduce_min(masked_keypoint_x)
        keypoint_xmax = tf.reduce_max(masked_keypoint_x)
        keypoint_ymin = tf.reduce_min(masked_keypoint_y)
        keypoint_ymax = tf.reduce_max(masked_keypoint_y)
        
        xmin = keypoint_xmin - tf.cast(body_height * margin, dtype=tf.int32)
        xmax = keypoint_xmax + tf.cast(body_height * margin, dtype=tf.int32)
        ymin = keypoint_ymin - tf.cast(body_height * margin, dtype=tf.int32)
        ymax = keypoint_ymax + tf.cast(body_height * margin, dtype=tf.int32)
        
        effective_xmin = xmin if xmin > 0 else 0
        effective_ymin = ymin if ymin > 0 else 0
        effective_xmax = xmax if xmax < img_width else img_width
        effective_ymax = ymax if ymax < img_height else img_height
        effective_height = effective_ymax - effective_ymin
        effective_width = effective_xmax - effective_xmin

        image = image[effective_ymin:effective_ymax, effective_xmin:effective_xmax, :]
        new_shape = tf.shape(image)
        new_height = new_shape[0]
        new_width = new_shape[1]
        
        effective_keypoint_x = (keypoint_x - effective_xmin) / new_width
        effective_keypoint_y = (keypoint_y - effective_ymin) / new_height
        
        return image, effective_keypoint_x, effective_keypoint_y
        
    
    def generate_2d_guassian(self, height, width, y0, x0, visibility=2, sigma=1, scale=12):
        
        heatmap = tf.zeros((height, width))

        xmin = x0 - 3 * sigma
        ymin = y0 - 3 * sigma
        xmax = x0 + 3 * sigma
        ymax = y0 + 3 * sigma

        if xmin >= width or ymin >= height or xmax < 0 or ymax <0 or visibility == 0:
            return heatmap

        size = 6 * sigma + 1
        x, y = tf.meshgrid(tf.range(0, 6*sigma+1, 1), tf.range(0, 6*sigma+1, 1), indexing='xy')

        center_x = size // 2
        center_y = size // 2

        gaussian_patch = tf.cast(tf.math.exp(-(tf.square(x - center_x) + tf.math.square(y - center_y)) / (tf.math.square(sigma) * 2)) * scale, dtype=tf.float32)

        patch_xmin = tf.math.maximum(0, -xmin)
        patch_ymin = tf.math.maximum(0, -ymin)
        patch_xmax = tf.math.minimum(xmax, width) - xmin
        patch_ymax = tf.math.minimum(ymax, height) - ymin

        heatmap_xmin = tf.math.maximum(0, xmin)
        heatmap_ymin = tf.math.maximum(0, ymin)
        heatmap_xmax = tf.math.minimum(xmax, width)
        heatmap_ymax = tf.math.minimum(ymax, height)

        indices = tf.TensorArray(tf.int32, 1, dynamic_size=True)
        updates = tf.TensorArray(tf.float32, 1, dynamic_size=True)

        count = 0

        for j in tf.range(patch_ymin, patch_ymax):
            for i in tf.range(patch_xmin, patch_xmax):
                indices = indices.write(count, [heatmap_ymin+j, heatmap_xmin+i])
                updates = updates.write(count, gaussian_patch[j][i])
                count += 1
                
        heatmap = tf.tensor_scatter_nd_update(heatmap, indices.stack(), updates.stack())

        return heatmap


    def make_heatmaps(self, features, keypoint_x, keypoint_y, heatmap_shape):
        v = tf.cast(tf.sparse.to_dense(features['image/object/parts/v']), dtype=tf.float32)
        x = tf.cast(tf.math.round(keypoint_x * heatmap_shape[0]), dtype=tf.int32)
        y = tf.cast(tf.math.round(keypoint_y * heatmap_shape[1]), dtype=tf.int32)
        
        num_heatmap = heatmap_shape[2]
        heatmap_array = tf.TensorArray(tf.float32, 16)

        for i in range(num_heatmap):
            gaussian = self.generate_2d_guassian(heatmap_shape[1], heatmap_shape[0], y[i], x[i], v[i])
            heatmap_array = heatmap_array.write(i, gaussian)
        
        heatmaps = heatmap_array.stack()
        heatmaps = tf.transpose(heatmaps, perm=[1, 2, 0]) # change to (64, 64, 16)
        
        return heatmaps

    def parse_tfexample(self, example):
        image_feature_description = {
            'image/height': tf.io.FixedLenFeature([], tf.int64),
            'image/width': tf.io.FixedLenFeature([], tf.int64),
            'image/depth': tf.io.FixedLenFeature([], tf.int64),
            'image/object/parts/x': tf.io.VarLenFeature(tf.int64),
            'image/object/parts/y': tf.io.VarLenFeature(tf.int64),
            'image/object/parts/v': tf.io.VarLenFeature(tf.int64),
            'image/object/center/x': tf.io.FixedLenFeature([], tf.int64),
            'image/object/center/y': tf.io.FixedLenFeature([], tf.int64),
            'image/object/scale': tf.io.FixedLenFeature([], tf.float32),
            'image/encoded': tf.io.FixedLenFeature([], tf.string),
            'image/filename': tf.io.FixedLenFeature([], tf.string),
        }
        return tf.io.parse_single_example(example,
                                          image_feature_description)

# Hourglass model

In [15]:
def BottleneckBlock(inputs, filters, strides=1, downsample=False, name=None):
    identity = inputs
    if downsample:
        identity = Conv2D(
            filters=filters,
            kernel_size=1,
            strides=strides,
            padding='same',
            kernel_initializer='he_normal')(inputs)

    x = BatchNormalization(momentum=0.9)(inputs)
    x = ReLU()(x)
    x = Conv2D(
        filters=filters // 2,
        kernel_size=1,
        strides=1,
        padding='same',
        kernel_initializer='he_normal')(x)

    x = BatchNormalization(momentum=0.9)(x)
    x = ReLU()(x)
    x = Conv2D(
        filters=filters // 2,
        kernel_size=3,
        strides=strides,
        padding='same',
        kernel_initializer='he_normal')(x)

    x = BatchNormalization(momentum=0.9)(x)
    x = ReLU()(x)
    x = Conv2D(
        filters=filters,
        kernel_size=1,
        strides=1,
        padding='same',
        kernel_initializer='he_normal')(x)

    x = Add()([identity, x])
    return x

In [16]:
def HourglassModule(inputs, order, filters, num_residual):
    
    up1 = BottleneckBlock(inputs, filters, downsample=False)
    for i in range(num_residual):
        up1 = BottleneckBlock(up1, filters, downsample=False)

    low1 = MaxPool2D(pool_size=2, strides=2)(inputs)
    for i in range(num_residual):
        low1 = BottleneckBlock(low1, filters, downsample=False)

    low2 = low1
    if order > 1:
        low2 = HourglassModule(low1, order - 1, filters, num_residual)
    else:
        for i in range(num_residual):
            low2 = BottleneckBlock(low2, filters, downsample=False)

    low3 = low2
    for i in range(num_residual):
        low3 = BottleneckBlock(low3, filters, downsample=False)

    up2 = UpSampling2D(size=2)(low3)

    return up2 + up1

In [17]:
def LinearLayer(inputs, filters):
    x = Conv2D(
        filters=filters,
        kernel_size=1,
        strides=1,
        padding='same',
        kernel_initializer='he_normal')(inputs)
    x = BatchNormalization(momentum=0.9)(x)
    x = ReLU()(x)
    return x

In [18]:
def StackedHourglassNetwork(
        input_shape=(256, 256, 3), 
        num_stack=4, 
        num_residual=1,
        num_heatmap=16):
    
    inputs = Input(shape=input_shape)

    x = Conv2D(
        filters=64,
        kernel_size=7,
        strides=2,
        padding='same',
        kernel_initializer='he_normal')(inputs)
    x = BatchNormalization(momentum=0.9)(x)
    x = ReLU()(x)
    x = BottleneckBlock(x, 128, downsample=True)
    x = MaxPool2D(pool_size=2, strides=2)(x)
    x = BottleneckBlock(x, 128, downsample=False)
    x = BottleneckBlock(x, 256, downsample=True)

    ys = []
    for i in range(num_stack):
        x = HourglassModule(x, order=4, filters=256, num_residual=num_residual)
        for i in range(num_residual):
            x = BottleneckBlock(x, 256, downsample=False)

        x = LinearLayer(x, 256)

        y = Conv2D(
            filters=num_heatmap,
            kernel_size=1,
            strides=1,
            padding='same',
            kernel_initializer='he_normal')(x)
        ys.append(y)

        if i < num_stack - 1:
            y_intermediate_1 = Conv2D(filters=256, kernel_size=1, strides=1)(x)
            y_intermediate_2 = Conv2D(filters=256, kernel_size=1, strides=1)(y)
            x = Add()([y_intermediate_1, y_intermediate_2])

    return tf.keras.Model(inputs, ys, name='stacked_hourglass')

# GPU가 여러개인 환경에서 학습 설정

In [19]:
class Trainer(object):
    def __init__(self,
                 model,
                 epochs,
                 global_batch_size,
                 strategy,
                 initial_learning_rate):
        self.model = model
        self.epochs = epochs
        self.strategy = strategy
        self.global_batch_size = global_batch_size
        self.loss_object = tf.keras.losses.MeanSquaredError(
            reduction=tf.keras.losses.Reduction.NONE)
        self.optimizer = tf.keras.optimizers.Adam(
            learning_rate=initial_learning_rate)
        self.model = model

        self.current_learning_rate = initial_learning_rate
        self.last_val_loss = math.inf
        self.lowest_val_loss = math.inf
        self.patience_count = 0
        self.max_patience = 10
        self.best_model = None

    def lr_decay(self):
        if self.patience_count >= self.max_patience:
            self.current_learning_rate /= 10.0
            self.patience_count = 0
        elif self.last_val_loss == self.lowest_val_loss:
            self.patience_count = 0
        self.patience_count += 1

        self.optimizer.learning_rate = self.current_learning_rate

    def lr_decay_step(self, epoch):
        if epoch == 25 or epoch == 50 or epoch == 75:
            self.current_learning_rate /= 10.0
        self.optimizer.learning_rate = self.current_learning_rate

    def compute_loss(self, labels, outputs):
        loss = tf.constant(0.0, dtype=tf.float32) # float32 타입으로 초기화
        labels = tf.cast(labels, dtype=tf.float32) # labels를 float32로 변환
        for output in outputs:
            output = tf.cast(output, dtype=tf.float32) # output을 float32로 변환
            weights = tf.cast(labels > 0, dtype=tf.float32) * 81 + 1
            loss += tf.math.reduce_mean(
                tf.math.square(labels - output) * weights) * (
                    1. / self.global_batch_size)
        return loss

    def train_step(self, inputs):
        images, labels = inputs
        with tf.GradientTape() as tape:
            outputs = self.model(images, training=True)
            loss = self.compute_loss(labels, outputs)

        grads = tape.gradient(
            target=loss, sources=self.model.trainable_variables)
        self.optimizer.apply_gradients(
            zip(grads, self.model.trainable_variables))

        return loss

    def val_step(self, inputs):
        images, labels = inputs
        outputs = self.model(images, training=False)
        loss = self.compute_loss(labels, outputs)
        return loss

    def run(self, train_dist_dataset, val_dist_dataset):
        @tf.function
        def distributed_train_epoch(dataset):
            tf.print('Start distributed traininng...')
            total_loss = 0.0
            num_train_batches = 0.0
            for one_batch in dataset:
                per_replica_loss = self.strategy.run(
                    self.train_step, args=(one_batch, ))
                batch_loss = self.strategy.reduce(
                    tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)
                total_loss += tf.cast(batch_loss, tf.float32)
                num_train_batches += 1
                tf.print('Trained batch', num_train_batches, 'batch loss',
                         batch_loss, 'epoch total loss', total_loss / num_train_batches)
            return total_loss, num_train_batches

        @tf.function
        def distributed_val_epoch(dataset):
            total_loss = 0.0
            num_val_batches = 0.0
            for one_batch in dataset:
                per_replica_loss = self.strategy.run(
                    self.val_step, args=(one_batch, ))
                num_val_batches += 1
                batch_loss = self.strategy.reduce(
                    tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)
                tf.print('Validated batch', num_val_batches, 'batch loss',
                         batch_loss)
                if not tf.math.is_nan(batch_loss):
                    # TODO: Find out why the last validation batch loss become NaN
                    total_loss += tf.cast(batch_loss, tf.float32)
                else:
                    num_val_batches -= 1

            return total_loss, num_val_batches

        for epoch in range(1, self.epochs + 1):
            self.lr_decay()
            print('Start epoch {} with learning rate {}'.format(
                epoch, self.current_learning_rate))

            train_total_loss, num_train_batches = distributed_train_epoch(
                train_dist_dataset)
            train_loss = train_total_loss / num_train_batches
            print('Epoch {} train loss {}'.format(epoch, train_loss))

            val_total_loss, num_val_batches = distributed_val_epoch(
                val_dist_dataset)
            val_loss = val_total_loss / num_val_batches
            print('Epoch {} val loss {}'.format(epoch, val_loss))

            # save model when reach a new lowest validation loss
            if val_loss < self.lowest_val_loss:
                self.save_model(epoch, val_loss)
                self.lowest_val_loss = val_loss
            self.last_val_loss = val_loss

        return self.best_model

    def save_model(self, epoch, loss):
        model_name = MODEL_PATH + '/model-epoch-{}-loss-{:.4f}.h5'.format(epoch, loss)
        self.model.save_weights(model_name)
        self.best_model = model_name
        print("Model {} saved.".format(model_name))

In [20]:
IMAGE_SHAPE = (256, 256, 3)
HEATMAP_SIZE = (64, 64)
# 데이터셋 만드는 함수
def create_dataset(tfrecords, batch_size, num_heatmap, is_train):
    preprocess = Preprocessor(
        IMAGE_SHAPE, (HEATMAP_SIZE[0], HEATMAP_SIZE[1], num_heatmap), is_train)

    dataset = tf.data.Dataset.list_files(tfrecords)
    dataset = tf.data.TFRecordDataset(dataset)
    dataset = dataset.map(
        preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if is_train:
        dataset = dataset.shuffle(batch_size)

    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

    return dataset

In [21]:
# 데이터셋 연결 및 학습
def train(epochs, learning_rate, num_heatmap, batch_size, train_tfrecords, val_tfrecords):
    strategy = tf.distribute.MirroredStrategy()
    global_batch_size = strategy.num_replicas_in_sync * batch_size
    train_dataset = create_dataset(
        train_tfrecords, global_batch_size, num_heatmap, is_train=True)
    val_dataset = create_dataset(
        val_tfrecords, global_batch_size, num_heatmap, is_train=False)

    if not os.path.exists(MODEL_PATH):
        os.makedirs(MODEL_PATH)

    with strategy.scope():
        train_dist_dataset = strategy.experimental_distribute_dataset(
            train_dataset)
        val_dist_dataset = strategy.experimental_distribute_dataset(
            val_dataset)

        #model = StackedHourglassNetwork(IMAGE_SHAPE, 4, 1, num_heatmap)
        model = SimpleBaseline(IMAGE_SHAPE)

        trainer = Trainer(
            model,
            epochs,
            global_batch_size,
            strategy,
            initial_learning_rate=learning_rate)

        print('Start training...')
        return trainer.run(train_dist_dataset, val_dist_dataset)

# 학습

In [23]:
train_tfrecords = os.path.join(TFRECORD_PATH, 'train*')
val_tfrecords = os.path.join(TFRECORD_PATH, 'val*')
epochs = 3
batch_size = 16
num_heatmap = 16
learning_rate = 0.0007

best_model_file = train(epochs, learning_rate, num_heatmap, batch_size, train_tfrecords, val_tfrecords)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
Start training...
Start epoch 1 with learning rate 0.0007
Start distributed traininng...
Trained batch 1 batch loss 2.5146277 epoch total loss 2.5146277
Trained batch 2 batch loss 2.45807815 epoch total loss 2.48635292
Trained batch 3 batch loss 2.55746269 epoch total loss 2.51005626
Trained batch 4 batch loss 2.37066984 epoch total loss 2.47520971
Trained batch 5 batch loss 2.38536167 epoch total loss 2.45724
Trained batch 6 batch loss 2.33650112 epoch total loss 2.43711686
Trained batch 7 batch loss 2.26773787 epoch total loss 2.41292
Trained batch 8 batch loss 2.09957528 epoch total loss 2.37375188
Trained batch 9 batch loss 2.00856137 epoch total loss 2.33317518
Trained batch 10 batch loss 2.1967082 epoch total loss 2.31952858
Trained batch 11 batch loss 2.12115669 epoch total loss 2.30149484
Trained batch 12 batch loss 2.12126493 epoch total loss 2.28647566
Trained batch 13 batch 

In [25]:
#WEIGHTS_PATH = os.path.join(PROJECT_PATH, 'models', 'model-v0.0.1-epoch-2-loss-1.3072.h5')

model = StackedHourglassNetwork(IMAGE_SHAPE, 4, 1)
model.load_weights(best_model_file)

In [26]:
# learning curve
plt.plot(model.history['train_loss'], label='Training Loss')
plt.plot(model.history['val_loss'], label='Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

TypeError: 'NoneType' object is not subscriptable

# Simple Baseline

In [22]:
def _make_deconv_layer(num_deconv_layers):
    seq_model = tf.keras.models.Sequential()
    for i in range(num_deconv_layers):
        seq_model.add(tf.keras.layers.Conv2DTranspose(256, kernel_size=(4,4), strides=(2,2), padding='same'))
        seq_model.add(tf.keras.layers.BatchNormalization())
        seq_model.add(tf.keras.layers.ReLU())
    return seq_model

In [23]:
# resnet 선언
resnet = tf.keras.applications.resnet.ResNet50(include_top=False, weights='imagenet')
# 3개의 deconv+bn+relu layer
upconv = _make_deconv_layer(3)
# 마지막 layer
final_layer = tf.keras.layers.Conv2D(16, kernel_size=(1,1), padding='same')

In [24]:
from tensorflow import keras
from tensorflow.keras import layers
# 모델 함수
def SimpleBaseline(input_shape=(256, 256, 3)):
    
    inputs = keras.Input(shape=input_shape)
    x = resnet(inputs)
    x = upconv(x)
    out = final_layer(x)
    model = keras.Model(inputs, out)
    
    return tf.keras.Model(inputs, out, name='simple_baseline')

In [25]:
# 모델 학습
MODEL_PATH = os.path.join(PROJECT_PATH, 'simple')
train_tfrecords = os.path.join(TFRECORD_PATH, 'train*')
val_tfrecords = os.path.join(TFRECORD_PATH, 'val*')
epochs = 3
batch_size = 16
num_heatmap = 16
learning_rate = 0.0007

best_model_file_simple = train(epochs, learning_rate, num_heatmap, batch_size, train_tfrecords, val_tfrecords)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/



INFO:tensorflow:Error reported to Coordinator: in user code:

    /tmp/ipykernel_472/3568804479.py:59 train_step  *
        self.optimizer.apply_gradients(
    /opt/conda/lib/python3.9/site-packages/keras/optimizer_v2/optimizer_v2.py:628 apply_gradients  **
        self._create_all_weights(var_list)
    /opt/conda/lib/python3.9/site-packages/keras/optimizer_v2/optimizer_v2.py:815 _create_all_weights
        self._create_slots(var_list)
    /opt/conda/lib/python3.9/site-packages/keras/optimizer_v2/adam.py:117 _create_slots
        self.add_slot(var, 'm')
    /opt/conda/lib/python3.9/site-packages/keras/optimizer_v2/optimizer_v2.py:892 add_slot
        raise ValueError(

    ValueError: Trying to create optimizer slot variable under the scope for tf.distribute.Strategy (<tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x798c03a58f70>), which is different from the scope used for the original variable (<tf.Variable 'conv1_conv/kernel:0' shape=(7, 7, 3, 64) dtype=f

ValueError: in user code:

    /tmp/ipykernel_472/3568804479.py:77 distributed_train_epoch  *
        per_replica_loss = self.strategy.run(
    /tmp/ipykernel_472/3568804479.py:59 train_step  *
        self.optimizer.apply_gradients(
    /opt/conda/lib/python3.9/site-packages/keras/optimizer_v2/optimizer_v2.py:628 apply_gradients  **
        self._create_all_weights(var_list)
    /opt/conda/lib/python3.9/site-packages/keras/optimizer_v2/optimizer_v2.py:815 _create_all_weights
        self._create_slots(var_list)
    /opt/conda/lib/python3.9/site-packages/keras/optimizer_v2/adam.py:117 _create_slots
        self.add_slot(var, 'm')
    /opt/conda/lib/python3.9/site-packages/keras/optimizer_v2/optimizer_v2.py:892 add_slot
        raise ValueError(

    ValueError: Trying to create optimizer slot variable under the scope for tf.distribute.Strategy (<tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x798c03a58f70>), which is different from the scope used for the original variable (<tf.Variable 'conv1_conv/kernel:0' shape=(7, 7, 3, 64) dtype=float32, numpy=
    array([[[[ 2.82526277e-02, -1.18737184e-02,  1.51488732e-03, ...,
              -1.07003953e-02, -5.27982824e-02, -1.36667420e-03],
             [ 5.86827798e-03,  5.04415408e-02,  3.46324709e-03, ...,
               1.01423981e-02,  1.39493728e-02,  1.67549420e-02],
             [-2.44090753e-03, -4.86173332e-02,  2.69966386e-03, ...,
              -3.44439060e-04,  3.48098315e-02,  6.28910400e-03]],
    
            [[ 1.81872323e-02, -7.20698107e-03,  4.80302610e-03, ...,
              -7.43396254e-03, -8.56800564e-03,  1.16849300e-02],
             [ 1.87554304e-02,  5.12730293e-02,  4.50406177e-03, ...,
               1.39413681e-02,  1.26296384e-02, -1.73004344e-02],
             [ 1.90453827e-02, -3.87909152e-02,  4.25842637e-03, ...,
               2.75742816e-04, -1.27962548e-02, -8.35626759e-03]],
    
            [[ 1.58849321e-02, -1.06073255e-02,  1.30999666e-02, ...,
              -2.26797583e-03, -3.98984266e-04,  3.39989027e-04],
             [ 3.61421369e-02,  5.02430499e-02,  1.22699486e-02, ...,
               1.19910473e-02,  2.02837810e-02, -1.96981970e-02],
             [ 2.17959806e-02, -3.86004597e-02,  1.12379901e-02, ...,
              -2.07756506e-03, -3.40645364e-03, -3.78638096e-02]],
    
            ...,
    
            [[-5.30153252e-02, -8.60502943e-03,  6.38643000e-03, ...,
              -4.49256925e-03,  3.48024699e-03, -1.40979560e-02],
             [-9.35578942e-02,  4.61557060e-02,  1.53722311e-03, ...,
               1.21013075e-02,  5.05337631e-03,  3.30474339e-02],
             [-7.69589692e-02, -3.51354294e-02,  2.22769519e-03, ...,
               9.18304977e-06, -1.15465783e-02,  2.29630154e-02]],
    
            [[-4.73558307e-02, -4.07940615e-03,  4.76515992e-03, ...,
              -9.73805040e-03, -1.03890402e-02,  1.62366014e-02],
             [-1.24100089e-01,  4.78516519e-02, -9.90210217e-04, ...,
               1.10340826e-02, -6.77202828e-03,  5.49102016e-02],
             [-7.13113099e-02, -2.86470409e-02,  6.20829698e-04, ...,
              -2.17762636e-03, -1.58942658e-02,  3.44766974e-02]],
    
            [[ 1.85429510e-02, -1.12518407e-02,  1.12506151e-02, ...,
              -1.51338596e-02, -5.66656142e-03, -1.30050071e-02],
             [-2.68079005e-02,  3.64737920e-02,  4.55197273e-03, ...,
               5.53486776e-03,  1.12653999e-02,  2.46754289e-03],
             [ 1.43940765e-02, -3.56382579e-02,  5.08728763e-03, ...,
              -7.46753719e-03,  1.61169283e-02,  1.12382937e-02]]],
    
    
           [[[ 7.99009297e-03, -9.49061289e-03, -4.21846565e-03, ...,
              -1.23715792e-02, -3.82804796e-02, -5.90979494e-03],
             [-7.68794632e-03,  5.46954982e-02, -1.03303632e-02, ...,
               1.40626412e-02,  1.99436247e-02,  2.51518637e-02],
             [ 3.70471564e-04, -3.70203964e-02, -9.80611611e-03, ...,
              -4.95379185e-03,  2.27415562e-02,  1.38941938e-02]],
    
            [[ 2.48856675e-02, -9.57963988e-03, -2.37837038e-03, ...,
              -1.08526833e-02,  2.24138368e-02, -2.40965877e-02],
             [ 2.42966190e-02,  4.93442900e-02, -1.32921906e-02, ...,
               1.47738317e-02,  2.67323572e-02,  1.14357602e-02],
             [ 2.91274227e-02, -3.05654686e-02, -1.42364930e-02, ...,
              -8.36174563e-03, -3.00847553e-02, -2.51545687e-03]],
    
            [[ 7.67260045e-02, -1.19650066e-02, -2.10191216e-03, ...,
               1.79589365e-03,  2.02653632e-02, -1.33340694e-02],
             [ 1.49444759e-01,  5.00719361e-02, -1.52172269e-02, ...,
               1.83409695e-02,  1.56401172e-02,  8.53796005e-02],
             [ 1.17180273e-01, -2.56576538e-02, -1.85890812e-02, ...,
              -2.50462536e-03, -5.22738546e-02,  1.17943510e-02]],
    
            ...,
    
            [[-1.89151186e-02, -1.06457584e-02, -1.19606184e-03, ...,
              -7.13960640e-03,  7.56816342e-02,  8.62411484e-02],
             [ 1.33888470e-02,  4.24321182e-02, -1.93305630e-02, ...,
               8.93499516e-03,  3.26688178e-02,  1.71118364e-01],
             [-9.38678440e-03, -2.88689751e-02, -1.87061988e-02, ...,
              -1.06920488e-02, -4.56195511e-02,  1.51734307e-01]],
    
            [[-7.93561861e-02, -8.69292021e-03,  1.06180850e-02, ...,
              -8.22936464e-03,  5.34521677e-02,  2.43676770e-02],
             [-1.76872283e-01,  4.03351039e-02, -6.91946782e-03, ...,
               1.14902109e-02,  2.45164465e-02,  1.30252065e-02],
             [-1.30214587e-01, -2.94868350e-02, -1.32359739e-03, ...,
              -8.08166154e-03, -3.32693383e-02,  1.78283844e-02]],
    
            [[-1.53617216e-02, -1.02823023e-02,  1.44553250e-02, ...,
              -1.23689836e-02,  2.81683691e-02, -1.52645903e-02],
             [-1.22947149e-01,  3.72432098e-02, -2.82740779e-03, ...,
               1.07275983e-02,  1.61965452e-02, -4.08420824e-02],
             [-7.92325959e-02, -3.09139602e-02,  1.91061670e-04, ...,
              -1.06926244e-02, -1.36199640e-02, -2.90216487e-02]]],
    
    
           [[[-2.74732877e-02, -1.59629062e-02,  5.87167032e-03, ...,
              -1.18064405e-02, -5.19699305e-02, -1.52737210e-02],
             [-7.46604949e-02,  5.22083789e-02, -1.98963331e-03, ...,
               1.27452025e-02,  7.53643783e-03, -1.96208209e-02],
             [-3.34048420e-02, -3.39833461e-02, -1.99538236e-03, ...,
              -9.30251833e-03,  3.30174603e-02, -1.65446047e-02]],
    
            [[-6.57535121e-02, -1.23513499e-02, -4.16519074e-03, ...,
              -1.22041989e-03,  2.09396798e-02,  3.62350084e-02],
             [-1.52494013e-01,  4.94739972e-02, -1.83443855e-02, ...,
               2.37025358e-02,  2.67230812e-02,  8.47681686e-02],
             [-8.80744159e-02, -2.57136654e-02, -2.17252262e-02, ...,
              -3.12197860e-03, -2.06513535e-02,  6.63726628e-02]],
    
            [[ 1.99921392e-02, -1.76080931e-02,  1.81755237e-03, ...,
               3.69562432e-02,  3.51557694e-02,  1.03931516e-01],
             [ 6.10242449e-02,  4.46803048e-02, -1.41719123e-02, ...,
               5.15808910e-02,  2.07974892e-02,  1.46060020e-01],
             [ 8.05315524e-02, -2.88072433e-02, -1.85981095e-02, ...,
               2.20173039e-02, -5.11762947e-02,  1.40093669e-01]],
    
            ...,
    
            [[ 1.15528561e-01, -1.67486407e-02,  8.49904679e-03, ...,
               4.99674492e-03,  7.98972845e-02, -1.11083500e-01],
             [ 3.32334489e-01,  4.24566194e-02, -9.70878359e-03, ...,
               1.92873720e-02,  1.25060824e-03, -3.40990961e-01],
             [ 2.16480315e-01, -2.68480480e-02, -8.96557700e-03, ...,
              -6.44540135e-03, -7.85448179e-02, -2.04899684e-01]],
    
            [[-8.99803787e-02, -8.51823762e-03,  2.25046948e-02, ...,
              -8.74274992e-04,  6.35959804e-02, -9.58404392e-02],
             [-8.15074593e-02,  4.37885672e-02,  3.69152403e-03, ...,
               1.71142723e-02,  6.33937493e-03, -2.73919165e-01],
             [-9.73245725e-02, -2.61962153e-02,  8.95403326e-03, ...,
              -7.23934872e-03, -5.64266555e-02, -1.84837982e-01]],
    
            [[-9.46454927e-02, -1.17739988e-02,  2.49665454e-02, ...,
              -7.38179125e-03,  3.05740479e-02, -1.17530329e-02],
             [-2.11111471e-01,  3.85808311e-02,  5.31885307e-03, ...,
               1.61544569e-02,  3.10361455e-03, -8.36645439e-02],
             [-1.75075874e-01, -3.21811885e-02,  9.45197884e-03, ...,
              -1.05473688e-02, -2.80730613e-02, -6.67640790e-02]]],
    
    
           ...,
    
    
           [[[ 2.31804699e-02, -1.62718501e-02,  1.22078890e-02, ...,
              -1.22131845e-02, -2.02786643e-02, -2.14508991e-03],
             [ 2.30488200e-02,  4.41800952e-02,  3.59291583e-03, ...,
               1.27932075e-02,  6.47032401e-03, -5.39429188e-02],
             [ 2.03978457e-02, -2.67958529e-02,  5.69844292e-03, ...,
              -8.20858125e-03,  2.51460597e-02, -3.12512405e-02]],
    
            [[-4.64516319e-02, -1.34653188e-02,  1.61393601e-02, ...,
              -2.20572166e-02,  5.05596139e-02,  1.47165358e-03],
             [-1.77852944e-01,  4.04180661e-02,  4.32515051e-03, ...,
               7.27979047e-03,  1.37663782e-02, -5.00506982e-02],
             [-1.09063022e-01, -2.11244933e-02,  6.98045455e-03, ...,
              -2.00869981e-02, -6.30094185e-02, -4.20499854e-02]],
    
            [[-1.83006614e-01, -1.79655701e-02,  1.82811301e-02, ...,
               1.56401389e-03,  9.29453745e-02,  4.12672907e-02],
             [-4.11783189e-01,  3.40776965e-02,  8.74394365e-03, ...,
               2.33494844e-02,  1.98237225e-02,  8.06325078e-02],
             [-2.76736170e-01, -2.83147153e-02,  1.31541817e-02, ...,
              -5.05925808e-03, -8.54580775e-02,  4.26753834e-02]],
    
            ...,
    
            [[ 5.36167026e-02, -1.07590063e-02,  2.19804980e-02, ...,
              -8.83348845e-03,  1.40453711e-01,  3.20528477e-01],
             [ 1.85792699e-01,  3.76442447e-02,  1.02089429e-02, ...,
               1.29263047e-02, -3.70457745e-03,  6.66479290e-01],
             [ 1.32038444e-01, -2.75047179e-02,  2.28339490e-02, ...,
              -1.19996015e-02, -1.22367747e-01,  4.83815670e-01]],
    
            [[ 8.34956467e-02, -9.09057911e-03,  2.50242520e-02, ...,
              -1.67011786e-02,  1.20522320e-01,  1.36462688e-01],
             [ 2.50555605e-01,  4.07686047e-02,  1.08884834e-02, ...,
               7.53540406e-03, -7.55708572e-03,  3.96415204e-01],
             [ 1.49690762e-01, -3.11034787e-02,  2.43526250e-02, ...,
              -1.65321939e-02, -1.09688722e-01,  2.64446586e-01]],
    
            [[ 3.69576029e-02, -1.27014471e-02,  3.19833457e-02, ...,
              -1.48784053e-02,  9.22970548e-02,  6.54868260e-02],
             [ 9.63706747e-02,  4.39107306e-02,  1.59802549e-02, ...,
               1.22494521e-02,  8.10312852e-03,  1.78935930e-01],
             [ 2.95156911e-02, -2.96487771e-02,  2.69996542e-02, ...,
              -1.38547905e-02, -7.72434175e-02,  1.32773802e-01]]],
    
    
           [[[ 4.22548056e-02, -8.30464344e-03,  5.34065207e-03, ...,
              -8.06468353e-03, -4.70053628e-02,  4.45614867e-02],
             [ 9.77012664e-02,  3.83502319e-02, -5.37837343e-03, ...,
               1.17106764e-02, -4.59602941e-03,  6.98771998e-02],
             [ 6.38262108e-02, -2.08319575e-02, -1.72756368e-03, ...,
              -8.19445588e-03,  4.25621867e-02,  4.83920909e-02]],
    
            [[ 4.59470600e-02, -4.77699284e-03,  7.04339007e-03, ...,
              -1.82104297e-02,  3.14848162e-02,  4.64068204e-02],
             [ 3.89483608e-02,  3.78783308e-02, -6.85291924e-03, ...,
               7.33014196e-03,  3.90656322e-04,  1.52848229e-01],
             [ 4.57218140e-02, -1.34090437e-02, -8.30697361e-04, ...,
              -1.85202472e-02, -3.45353335e-02,  9.25581828e-02]],
    
            [[-4.66161780e-02, -1.22223441e-02,  9.35023464e-03, ...,
              -1.31351836e-02,  6.08736612e-02,  9.18865502e-02],
             [-1.92336142e-01,  3.18407975e-02, -1.01881009e-03, ...,
               7.55425170e-03, -8.62357323e-04,  2.88297594e-01],
             [-1.15666650e-01, -2.35320851e-02,  6.74636895e-03, ...,
              -1.94703583e-02, -5.66169359e-02,  1.95824102e-01]],
    
            ...,
    
            [[-2.10239179e-02, -9.81471874e-03,  9.81596112e-03, ...,
              -1.36731779e-02,  1.20193027e-01, -1.26708716e-01],
             [-3.72992679e-02,  3.05935629e-02, -3.00194928e-03, ...,
               8.85152724e-03, -5.07611316e-03, -6.25461042e-02],
             [ 7.84674310e-04, -2.91344281e-02,  1.12569630e-02, ...,
              -1.38232643e-02, -9.49400812e-02, -8.74437019e-02]],
    
            [[ 3.32221799e-02, -4.22911346e-03,  1.13633750e-02, ...,
              -1.41841583e-02,  9.59840789e-02, -1.23203963e-01],
             [ 9.95653942e-02,  4.03233357e-02, -4.36036801e-03, ...,
               8.42505507e-03, -1.50266392e-02, -1.58158958e-01],
             [ 6.55353814e-02, -2.76978761e-02,  1.06595978e-02, ...,
              -1.31017175e-02, -9.93799716e-02, -1.52014121e-01]],
    
            [[ 2.50522885e-02, -1.08845932e-02,  1.29567981e-02, ...,
              -1.67823900e-02,  6.55406937e-02, -3.34061496e-02],
             [ 1.00219429e-01,  4.24924381e-02, -4.06364352e-03, ...,
               8.98410939e-03, -1.98677508e-03, -9.19047296e-02],
             [ 6.97101504e-02, -3.41515057e-02,  8.97936709e-03, ...,
              -1.51484888e-02, -8.06454644e-02, -8.53376985e-02]]],
    
    
           [[[ 1.46303158e-02, -9.15218703e-03,  5.24803856e-03, ...,
              -3.63799883e-03, -5.51798902e-02, -7.19531113e-03],
             [ 6.12211153e-02,  2.67034862e-02, -4.38000960e-03, ...,
               1.38858845e-02,  1.62421225e-03,  6.91889692e-03],
             [ 1.86353922e-02, -2.39325576e-02,  5.56383107e-04, ...,
              -6.68733614e-03,  7.36468807e-02,  3.71867418e-02]],
    
            [[ 3.52302976e-02, -3.27857491e-03,  7.14091491e-03, ...,
              -9.93822515e-03,  2.38756705e-02, -2.10771449e-02],
             [ 6.34438619e-02,  3.12160589e-02, -7.72275496e-03, ...,
               1.49217555e-02,  3.86624038e-03, -1.16395289e-02],
             [ 3.35849188e-02, -1.63664240e-02, -1.32562651e-03, ...,
              -1.30512416e-02, -7.29435496e-03, -1.24825155e-02]],
    
            [[ 4.10873676e-03, -4.66612726e-03,  1.21031692e-02, ...,
              -7.87103828e-03,  5.80726229e-02, -4.19587009e-02],
             [-2.23153979e-02,  2.99241953e-02,  8.01213668e-04, ...,
               1.82199273e-02,  9.57238674e-03, -8.57376456e-02],
             [-2.01183017e-02, -1.96383689e-02,  7.32050464e-03, ...,
              -1.07293837e-02, -2.17854325e-02, -7.95444921e-02]],
    
            ...,
    
            [[-1.71692297e-02, -3.16392444e-03,  2.40169745e-03, ...,
              -9.67177004e-03,  9.26117748e-02, -1.16062798e-02],
             [-8.63026828e-02,  3.55335064e-02, -1.06153013e-02, ...,
               1.85809545e-02, -2.19932254e-02, -1.47949710e-01],
             [-6.07556999e-02, -2.66596545e-02,  1.74473948e-03, ...,
              -4.85855900e-03, -8.82942155e-02, -8.43590796e-02]],
    
            [[ 1.15142548e-02,  2.20947526e-03,  5.08834422e-03, ...,
              -1.04352133e-02,  6.78158402e-02,  4.14623357e-02],
             [ 7.41827395e-03,  4.52373996e-02, -1.10873608e-02, ...,
               1.56368576e-02, -2.37460397e-02, -3.25448737e-02],
             [ 7.84576032e-03, -2.45320965e-02,  5.84031455e-04, ...,
              -8.31448287e-03, -8.92601907e-02, -3.36888898e-03]],
    
            [[ 4.79146978e-03, -4.22942545e-03,  1.15078716e-02, ...,
              -2.12721284e-02,  4.96782959e-02,  2.05268860e-02],
             [ 2.75192987e-02,  4.36737053e-02, -5.71439136e-03, ...,
               9.46100149e-03, -8.58635467e-04, -1.79863740e-02],
             [ 2.71184333e-02, -3.31169143e-02,  3.97488568e-03, ...,
              -1.41424611e-02, -6.35233149e-02,  1.29984575e-03]]]],
          dtype=float32)>). Make sure the slot variables are created under the same strategy scope. This may happen if you're restoring from a checkpoint outside the scope


# 예측 엔진

In [None]:
# 변수 지정
R_ANKLE = 0
R_KNEE = 1
R_HIP = 2
L_HIP = 3
L_KNEE = 4
L_ANKLE = 5
PELVIS = 6
THORAX = 7
UPPER_NECK = 8
HEAD_TOP = 9
R_WRIST = 10
R_ELBOW = 11
R_SHOULDER = 12
L_SHOULDER = 13
L_ELBOW = 14
L_WRIST = 15

MPII_BONES = [
    [R_ANKLE, R_KNEE],
    [R_KNEE, R_HIP],
    [R_HIP, PELVIS],
    [L_HIP, PELVIS],
    [L_HIP, L_KNEE],
    [L_KNEE, L_ANKLE],
    [PELVIS, THORAX],
    [THORAX, UPPER_NECK],
    [UPPER_NECK, HEAD_TOP],
    [R_WRIST, R_ELBOW],
    [R_ELBOW, R_SHOULDER],
    [THORAX, R_SHOULDER],
    [THORAX, L_SHOULDER],
    [L_SHOULDER, L_ELBOW],
    [L_ELBOW, L_WRIST]
]

In [None]:
# heatmap에서 최댓값을 찾는 함수
def find_max_coordinates(heatmaps):
    flatten_heatmaps = tf.reshape(heatmaps, (-1, 16))
    indices = tf.math.argmax(flatten_heatmaps, axis=0)
    y = tf.cast(indices / 64, dtype=tf.int64)
    x = indices - 64 * y
    return tf.stack([x, y], axis=1).numpy()

In [None]:
def extract_keypoints_from_heatmap(heatmaps):
    max_keypoints = find_max_coordinates(heatmaps)

    padded_heatmap = np.pad(heatmaps, [[1,1],[1,1],[0,0]], mode='constant')
    adjusted_keypoints = []
    for i, keypoint in enumerate(max_keypoints):
        max_y = keypoint[1]+1
        max_x = keypoint[0]+1
        
        patch = padded_heatmap[max_y-1:max_y+2, max_x-1:max_x+2, i]
        patch[1][1] = 0
        
        index = np.argmax(patch)
        
        next_y = index // 3
        next_x = index - next_y * 3
        delta_y = (next_y - 1) / 4
        delta_x = (next_x - 1) / 4
        
        adjusted_keypoint_x = keypoint[0] + delta_x
        adjusted_keypoint_y = keypoint[1] + delta_y
        adjusted_keypoints.append((adjusted_keypoint_x, adjusted_keypoint_y))
        
    adjusted_keypoints = np.clip(adjusted_keypoints, 0, 64)
    normalized_keypoints = adjusted_keypoints / 64
    return normalized_keypoints

In [None]:
# 모델과 이미지를 입력하면 이미지와 keypoint를 출력하는 함수
def predict(model, image_path):
    encoded = tf.io.read_file(image_path)
    image = tf.io.decode_jpeg(encoded)
    inputs = tf.image.resize(image, (256, 256))
    inputs = tf.cast(inputs, tf.float32) / 127.5 - 1
    inputs = tf.expand_dims(inputs, 0)
    outputs = model(inputs, training=False)
    if type(outputs) != list:
        outputs = [outputs]
    heatmap = tf.squeeze(outputs[-1], axis=0).numpy()
    kp = extract_keypoints_from_heatmap(heatmap)
    return image, kp

In [None]:
# keypoint 그리는 함수
def draw_keypoints_on_image(image, keypoints, index=None):
    fig,ax = plt.subplots(1)
    ax.imshow(image)
    joints = []
    for i, joint in enumerate(keypoints):
        joint_x = joint[0] * image.shape[1]
        joint_y = joint[1] * image.shape[0]
        if index is not None and index != i:
            continue
        plt.scatter(joint_x, joint_y, s=10, c='red', marker='o')
    plt.show()
# skeleton 그리는 함수
def draw_skeleton_on_image(image, keypoints, index=None):
    fig,ax = plt.subplots(1)
    ax.imshow(image)
    joints = []
    for i, joint in enumerate(keypoints):
        joint_x = joint[0] * image.shape[1]
        joint_y = joint[1] * image.shape[0]
        joints.append((joint_x, joint_y))
    
    for bone in MPII_BONES:
        joint_1 = joints[bone[0]]
        joint_2 = joints[bone[1]]
        plt.plot([joint_1[0], joint_2[0]], [joint_1[1], joint_2[1]], linewidth=5, alpha=0.7)
    plt.show()

In [None]:
# 테스트 이미지로 성능 확인
test_image = os.path.join(PROJECT_PATH, 'test_image.jpg')

image, keypoints = predict(model, test_image)
draw_keypoints_on_image(image, keypoints)
draw_skeleton_on_image(image, keypoints)

# 회고

- 오류 해결 : strategy를 모델 선언과 같은 곳에서 해야 한다.
하지만 고쳐볼 시간은 없었다.