# 3. 데이터 정리하기
## 학습목표
데이터를 정리해서 TFRecord 형태로 저장하기

[TensorFlow] Dataset 모듈 및 TFRecord 기본 사용법 정리
TensorFlow에서 학습 데이터를 불러올 수 있게 해주는 모듈인 Dataset 모듈 및 TFRecord에 대한 기본적인 사용법 정리
https://hcnoh.github.io/2018-11-05-tensorflow-data-module

## TFRecord Builder

In [1]:
import glob
import os
import tensorflow as tf
import cv2

## Paths and Hyperparameters

In [2]:
# Glob patterns selecting the input PNG images of each class.
DATASET_OK_PATTERN: str = 'dataset/3/OK/*.png'      # OK images
DATASET_FAIL_PATTERN: str = 'dataset/3/FAIL/*.png'  # FAIL images

# Output directory for the generated .tfrecords shard files.
TFRECORD_PATH: str = 'tfrecords/'
# Maximum number of images serialized into a single shard file.
IMAGE_PER_TFRECORD: int = 100

## Import data

In [3]:
# Collect the image paths of each class. glob() returns paths in an arbitrary,
# filesystem-dependent order, so sort them to make the output reproducible.
ok_list = sorted(glob.glob(DATASET_OK_PATTERN))
fail_list = sorted(glob.glob(DATASET_FAIL_PATTERN))

# Per-class counts, taken before any resampling.
num_ok = len(ok_list)
num_fail = len(fail_list)

# Both classes must be present. The resampling below divides by num_fail, and
# an empty OK set would truncate the FAIL set to zero images (an empty
# dataset), so fail early with a message that names the offending pattern(s).
if num_ok == 0 or num_fail == 0:
    raise FileNotFoundError(
        'Need images for both classes; found %d OK (%s) and %d FAIL (%s).'
        % (num_ok, DATASET_OK_PATTERN, num_fail, DATASET_FAIL_PATTERN))

# Oversampling: OK images far outnumber FAIL images. Repeat the whole FAIL
# list as many times as fits, then top it up with a prefix so it ends up with
# exactly num_ok entries. (If FAIL were the larger class, this would instead
# truncate it to num_ok entries.)
repeats, remainder = divmod(num_ok, num_fail)
fail_list = fail_list * repeats + fail_list[:remainder]

# Labels after oversampling: 0 = OK, 1 = FAIL.
ok_label = [0] * len(ok_list)
fail_label = [1] * len(fail_list)

# Pair files with labels in matching order.
# NOTE(review): the samples stay class-contiguous (all OK, then all FAIL), so
# consecutive records are all one class. Shuffle when reading the records, or
# shuffle both lists together here, before training.
file_list = ok_list + fail_list
label_list = ok_label + fail_label

## TFRecord functions - TFRecord를 build하는 데 쓰이는 function들은 TensorFlow 공식 API tutorial에서 발췌하여 그대로 쓴 것이다.
그대로 쓰는 것도 괜찮다!

In [4]:
def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature`` (bytes_list).

    An EagerTensor argument is first converted to its NumPy value, because
    ``tf.train.BytesList`` cannot unpack a string held in an EagerTensor.
    Any other value is passed through to ``BytesList`` unchanged.
    """
    eager_tensor_type = type(tf.constant(0))
    payload = value.numpy() if isinstance(value, eager_tensor_type) else value
    bytes_list = tf.train.BytesList(value=[payload])
    return tf.train.Feature(bytes_list=bytes_list)

def _int64_feature(value):
    """Wrap a single bool / enum / int / uint value in a ``tf.train.Feature``.

    The value is stored as the only element of an ``int64_list``.
    """
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)

def image_example(image_string, label):
    """Build a ``tf.train.Example`` for one encoded image and its label.

    Args:
        image_string: encoded image file contents (bytes), stored verbatim
            under ``'image_raw'``.
        label: integer class label, stored under ``'label'``.

    Returns:
        A ``tf.train.Example`` with the features ``height``, ``width`` and
        ``depth`` (the first three dimensions of the decoded image's shape),
        plus ``label`` and ``image_raw``.
    """
    # Decode only to read the image dimensions; the encoded bytes are what
    # get stored.
    dims = tf.image.decode_image(image_string).shape
    height, width, depth = dims[0], dims[1], dims[2]

    return tf.train.Example(
        features=tf.train.Features(
            feature={
                'height': _int64_feature(height),
                'width': _int64_feature(width),
                'depth': _int64_feature(depth),
                'label': _int64_feature(label),
                'image_raw': _bytes_feature(image_string),
            }
        )
    )

## Write TFRecords

In [5]:
# Create the output directory, including any missing parent directories.
# This is a no-op if it already exists.
os.makedirs(TFRECORD_PATH, exist_ok=True)

# Number of shard files = ceil(len(file_list) / IMAGE_PER_TFRECORD).
num_tfrecords = len(file_list) // IMAGE_PER_TFRECORD
if len(file_list) % IMAGE_PER_TFRECORD:
    num_tfrecords += 1

# NOTE(review): shards left over from a previous, larger run are not removed
# and will sit next to the new ones. Clear TFRECORD_PATH first if that matters.
for idx in range(num_tfrecords):
    # Slice [idx0, idx1) of the file/label lists goes into shard number idx.
    idx0 = idx * IMAGE_PER_TFRECORD
    idx1 = idx0 + IMAGE_PER_TFRECORD
    # os.path.join works whether or not TFRECORD_PATH ends with a separator.
    record_file = os.path.join(TFRECORD_PATH, '%05d.tfrecords' % idx)
    with tf.io.TFRecordWriter(record_file) as writer:
        for filename, label in zip(file_list[idx0:idx1],
                                   label_list[idx0:idx1]):
            # Read the encoded image bytes as-is. The `with` block closes the
            # handle; the original open(...).read() leaked one per image.
            with open(filename, 'rb') as image_file:
                image_string = image_file.read()
            tf_example = image_example(image_string, label)
            writer.write(tf_example.SerializeToString())