### Import libraries

In [4]:
import os
# Silence TensorFlow's C++ logging. This has to happen before `import tensorflow`
# (done in the next cell) for the backend to pick it up.
os.environ.update(TF_CPP_MIN_LOG_LEVEL="3")

Setting `TF_CPP_MIN_LOG_LEVEL` to `"3"` before TensorFlow is imported suppresses TensorFlow's C++ log messages, such as the ones printed on import.

In [5]:
import sys
import glob
import pickle
import shutil
import argparse
import tensorflow as tf

### Create TFRecord functions

In [6]:
# int형 feature값 변환
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

In [7]:
# Wrap a bytes value as a TFRecord feature.
def _bytes_feature(value):
    """Return a ``tf.train.Feature`` whose ``bytes_list`` holds exactly ``value``."""
    byte_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_list)

In [8]:
# image파일 byte단위로 변환
def _read_image_bytes(imagefile):
    file = open(imagefile, "rb")
    bytes = file.read()
    return bytes

In [38]:
# Convert a raw image dataset into TFRecord files (one record file per image).
def convert_to_tfrecord(input_file, output_file, label_form):
    """Write every ``<input_file>/<data_type>/<label>/*.jpg`` image to its own TFRecord.

    Each image becomes one ``tf.train.Example`` with the features ``image``
    (the encoded file bytes), ``type`` (data_type, UTF-8), ``label_string``
    (label, UTF-8) and ``label_int`` (``label_form[label]``). It is written to
    ``<output_file>/<data_type>/<data_type>_<label>_<index>.tfrecord``.
    ``index`` continues from the number of entries already in that directory,
    so a re-run appends new files instead of overwriting earlier ones.

    Args:
        input_file: Root directory of the raw dataset (a directory, despite the name).
        output_file: Root output directory (a directory, despite the name);
            one sub-directory per data_type is created as needed.
        label_form: Mapping from label directory name to integer class id.

    Raises:
        KeyError: If a label directory name is missing from ``label_form``.
    """
    image_paths = glob.glob(os.path.join(input_file, "*", "*", "*.jpg"))

    # Next free file index per output directory. It is seeded from os.listdir
    # once, then tracked in memory instead of re-listing the directory per image.
    next_index = {}

    for image_path in image_paths:
        image_bytes = _read_image_bytes(image_path)

        # The glob pattern fixes the last three components as
        # <data_type>/<label>/<file>. Splitting on os.sep after normpath works
        # on every OS and for any depth of input_file. Splitting on "\\" only
        # worked on Windows with a single-component input root.
        data_type, label = os.path.normpath(image_path).split(os.sep)[-3:-1]

        tf_example = tf.train.Example(features=tf.train.Features(feature={
            "image": _bytes_feature(image_bytes),
            "type": _bytes_feature(data_type.encode("utf8")),
            "label_string": _bytes_feature(label.encode("utf8")),
            "label_int": _int64_feature(label_form[label]),
        }))

        # One output directory per data_type (e.g. train/test), created on first use.
        type_directory = os.path.join(output_file, data_type)
        if type_directory not in next_index:
            os.makedirs(type_directory, exist_ok=True)
            next_index[type_directory] = len(os.listdir(type_directory))
        file_cnt = next_index[type_directory]
        next_index[type_directory] += 1

        record_path = os.path.join(
            type_directory,
            "{0}_{1}_{2}.tfrecord".format(data_type, label, file_cnt),
        )

        print("Start Generating %s" % record_path)
        try:
            # The context manager closes the writer, which flushes the record to
            # disk. The previous version never closed it and leaked a handle per image.
            with tf.io.TFRecordWriter(record_path) as writer:
                writer.write(tf_example.SerializeToString())
        except Exception as err:  # best effort: report the failure and continue
            print("Failed generating %s: %s" % (record_path, err))

In [39]:
# Integer class id for each label directory name in the dataset.
label_form = dict(
    NonDemented=0,
    VeryMildDemented=1,
    MildDemented=2,
    ModerateDemented=3,
)

# Convert ./dataset/<type>/<label>/*.jpg into ./test/<type>/*.tfrecord.
convert_to_tfrecord("./dataset/", "./test/", label_form)

Start Generating ./test//test/test_MildDemented_0.tfrecord
Start Generating ./test//test/test_MildDemented_1.tfrecord
Start Generating ./test//test/test_MildDemented_2.tfrecord
Start Generating ./test//test/test_MildDemented_3.tfrecord
Start Generating ./test//test/test_MildDemented_4.tfrecord
Start Generating ./test//test/test_MildDemented_5.tfrecord
Start Generating ./test//test/test_MildDemented_6.tfrecord
Start Generating ./test//test/test_MildDemented_7.tfrecord
Start Generating ./test//test/test_MildDemented_8.tfrecord
Start Generating ./test//test/test_MildDemented_9.tfrecord
Start Generating ./test//test/test_MildDemented_10.tfrecord
Start Generating ./test//test/test_MildDemented_11.tfrecord
Start Generating ./test//test/test_MildDemented_12.tfrecord
Start Generating ./test//test/test_MildDemented_13.tfrecord
Start Generating ./test//test/test_MildDemented_14.tfrecord
Start Generating ./test//test/test_MildDemented_15.tfrecord
Start Generating ./test//test/test_MildDemented_16

Start Generating ./test//test/test_MildDemented_137.tfrecord
Start Generating ./test//test/test_MildDemented_138.tfrecord
Start Generating ./test//test/test_MildDemented_139.tfrecord
Start Generating ./test//test/test_MildDemented_140.tfrecord
Start Generating ./test//test/test_MildDemented_141.tfrecord
Start Generating ./test//test/test_MildDemented_142.tfrecord
Start Generating ./test//test/test_MildDemented_143.tfrecord
Start Generating ./test//test/test_MildDemented_144.tfrecord
Start Generating ./test//test/test_MildDemented_145.tfrecord
Start Generating ./test//test/test_MildDemented_146.tfrecord
Start Generating ./test//test/test_MildDemented_147.tfrecord
Start Generating ./test//test/test_MildDemented_148.tfrecord
Start Generating ./test//test/test_MildDemented_149.tfrecord
Start Generating ./test//test/test_MildDemented_150.tfrecord
Start Generating ./test//test/test_MildDemented_151.tfrecord
Start Generating ./test//test/test_MildDemented_152.tfrecord
Start Generating ./test/

KeyboardInterrupt: 

In [None]:
def _parse_and_decode(self, serialized_example):
    """Parse one serialized ``tf.train.Example`` (unfinished, never executed).

    NOTE(review): as written this does not run. Points to confirm and fix:
      * ``tf.io.parse_single_example`` takes ``(serialized, features)``. Here the
        feature spec dict is passed as the only argument, and
        ``serialized_example`` is not used in the visible lines. Presumably the
        intent was ``tf.io.parse_single_example(serialized_example, spec)``.
      * ``tf.FixedLenFeature`` is the TF1.x name. In TF2 it is
        ``tf.io.FixedLenFeature``. Confirm which TensorFlow version is targeted.
      * The keys ("image/encoded", "image/format", "image/class/label",
        "image/height", "image/width") do not match the keys written by
        ``convert_to_tfrecord`` ("image", "type", "label_string", "label_int").
      * It takes ``self`` but is defined at module level, not inside a class.
    """
    # NOTE(review): despite its name, this variable receives the parse result,
    # not the feature description dict.
    image_feature_description = tf.io.parse_single_example({
        "image/encoded": tf.FixedLenFeature([], tf.string),
        "image/format": tf.FixedLenFeature([], tf.string),
        "image/class/label": tf.FixedLenFeature([], tf.int64),
        "image/height": tf.FixedLenFeature([], tf.int64),
        "image/width": tf.FixedLenFeature([], tf.int64),
    })