### Import libraries

In [1]:
import os
# Must run before TensorFlow is imported (next cell): level "3" hides
# TensorFlow's C++ log output, e.g. the messages printed at import time.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = "3"

Setting `TF_CPP_MIN_LOG_LEVEL` to `"3"` before TensorFlow is imported suppresses TensorFlow's C++ log output, such as the messages printed at import time.

In [2]:
import glob
from typing import Union

import numpy as np
import tensorflow as tf
import IPython.display as display

### Convert raw files to TFRecord

In [3]:
def _bytes_feature(value: Union[str, bytes]) -> tf.train.Feature:
    """Wrap a single string/bytes value in a ``tf.train.Feature`` (``bytes_list``).

    Args:
        value: The value to store. A ``str`` is UTF-8 encoded first because
            ``tf.train.BytesList`` only accepts ``bytes``. An eager tensor is
            unwrapped with ``.numpy()`` before being stored.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds exactly one element.
    """
    if isinstance(value, type(tf.constant(0))):
        # BytesList cannot unpack a string from an EagerTensor; take the raw value.
        value = value.numpy()
    if isinstance(value, str):
        # BytesList rejects Python str, so encode it explicitly.
        value = value.encode("utf-8")

    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

In [4]:
def _float_feature(value: float) -> tf.train.Feature:
    """Return a float/double ``value`` as a single-element ``float_list`` feature."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)

In [5]:
def _int64_feature(value: Union[bool, int]) -> tf.train.Feature:
    """Wrap a single bool/enum/int/uint value in a ``tf.train.Feature`` (``int64_list``).

    Args:
        value: The integer-like value to store (``bool`` is an ``int`` subclass).

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` holds exactly one element.
    """
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

In [6]:
def _image_to_byte(path: str) -> bytes:
    """image를 bytes로 반환합니다."""
    return open(path, "rb").read() 

In [7]:
def serialize_example(image: bytes, label: int, group: int) -> bytes:
    """Build a ``tf.train.Example`` for one image and return it serialized.

    Args:
        image: Raw image bytes, stored under the ``"raw_image"`` feature.
        label: Integer class id, stored under ``"label"`` as int64.
        group: Integer (or bool) split id, stored under ``"group"`` as int64
            (e.g. 0 = train, 1 = test).

    Returns:
        The serialized ``tf.train.Example`` protobuf as ``bytes``, suitable for
        ``tf.io.TFRecordWriter.write``.
    """
    feature = {
        "raw_image": _bytes_feature(image),
        "label": _int64_feature(label),
        "group": _int64_feature(group),
    }

    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()

In [24]:
def write_tfrecord(main_path: str, output_dir: str = "./tfrecord") -> None:
    """Convert every image under ``main_path`` into its own TFRecord file.

    Expects the layout ``<main_path>/<group>/<label>/<image>.jpg`` where
    ``group`` is a key of ``groups`` and ``label`` a key of ``image_labels``
    below. Each image is written to
    ``<output_dir>/<group>/<group>_<label>_<NNNN>.tfrecord`` as one serialized
    ``tf.train.Example``. ``NNNN`` is a 1-based counter shared by all labels
    within the same group. Progress is printed; nothing is returned.

    Args:
        main_path: Root directory of the raw image dataset.
        output_dir: Directory under which per-group folders are created.

    Raises:
        KeyError: If an image sits in a group or label directory that is not
            listed in the mappings below.
    """
    image_labels = {"NonDemented": 0, "VeryMildDemented": 1, "MildDemented": 2, "ModerateDemented": 3}
    groups = {"train": 0, "test": 1}

    # Sorted so file numbering is reproducible across runs and platforms.
    paths = sorted(glob.glob(os.path.join(main_path, "*", "*", "*.jpg")))
    # In-memory per-group counters: avoids re-listing the output directory for
    # every image, and a re-run overwrites the same names instead of appending.
    file_counts = {name: 0 for name in groups}

    for path in paths:
        # Take group/label from the two parent directories via os.path, which
        # handles both "/" and "\\" separators and any depth of main_path.
        label_dir = os.path.dirname(path)
        label = os.path.basename(label_dir)
        group = os.path.basename(os.path.dirname(label_dir))

        label_int = image_labels[label]
        group_int = groups[group]

        group_dir = os.path.join(output_dir, group)
        os.makedirs(group_dir, exist_ok=True)

        file_counts[group] += 1
        file_name = os.path.join(
            group_dir,
            "{0}_{1}_{2}.tfrecord".format(group, label, str(file_counts[group]).zfill(4)),
        )

        try:
            # Serialize before opening the writer so a failure does not leave
            # an empty .tfrecord file behind.
            example = serialize_example(_image_to_byte(path), label_int, group_int)
            with tf.io.TFRecordWriter(file_name) as writer:
                writer.write(example)
        except Exception as err:  # best effort: report and continue with the rest
            print("Converting Failed... %s (%s)" % (file_name, err))
            continue

        print("Converting... %s" % file_name)

    print("Done!")

In [25]:
# Root of the raw images, laid out as <dataset_path>/<group>/<label>/*.jpg
dataset_path = "./dataset"
write_tfrecord(main_path=dataset_path)

Converting... ./tfrecord/test/test_MildDemented_0001.tfrecord
Converting... ./tfrecord/test/test_MildDemented_0002.tfrecord
Converting... ./tfrecord/test/test_MildDemented_0003.tfrecord
Converting... ./tfrecord/test/test_MildDemented_0004.tfrecord
Converting... ./tfrecord/test/test_MildDemented_0005.tfrecord
Converting... ./tfrecord/test/test_MildDemented_0006.tfrecord
Converting... ./tfrecord/test/test_MildDemented_0007.tfrecord
Converting... ./tfrecord/test/test_MildDemented_0008.tfrecord
Converting... ./tfrecord/test/test_MildDemented_0009.tfrecord
Converting... ./tfrecord/test/test_MildDemented_0010.tfrecord
Converting... ./tfrecord/test/test_MildDemented_0011.tfrecord
Converting... ./tfrecord/test/test_MildDemented_0012.tfrecord
Converting... ./tfrecord/test/test_MildDemented_0013.tfrecord
Converting... ./tfrecord/test/test_MildDemented_0014.tfrecord
Converting... ./tfrecord/test/test_MildDemented_0015.tfrecord
Converting... ./tfrecord/test/test_MildDemented_0016.tfrecord
Converti

Converting... ./tfrecord/test/test_MildDemented_0149.tfrecord
Converting... ./tfrecord/test/test_MildDemented_0150.tfrecord
Converting... ./tfrecord/test/test_MildDemented_0151.tfrecord
Converting... ./tfrecord/test/test_MildDemented_0152.tfrecord
Converting... ./tfrecord/test/test_MildDemented_0153.tfrecord
Converting... ./tfrecord/test/test_MildDemented_0154.tfrecord
Converting... ./tfrecord/test/test_MildDemented_0155.tfrecord
Converting... ./tfrecord/test/test_MildDemented_0156.tfrecord
Converting... ./tfrecord/test/test_MildDemented_0157.tfrecord
Converting... ./tfrecord/test/test_MildDemented_0158.tfrecord
Converting... ./tfrecord/test/test_MildDemented_0159.tfrecord
Converting... ./tfrecord/test/test_MildDemented_0160.tfrecord
Converting... ./tfrecord/test/test_MildDemented_0161.tfrecord
Converting... ./tf

### Read TFRecord file

In [29]:
@tf.autograph.experimental.do_not_convert
def _parse_image_function(example_proto, feature_description=None):
    """Parse one serialized ``tf.train.Example`` into a dict of tensors.

    Args:
        example_proto: A scalar string tensor holding one serialized example,
            e.g. an element of a ``tf.data.TFRecordDataset``.
        feature_description: Optional mapping of feature name to
            ``tf.io.FixedLenFeature``. Defaults to the schema written by
            ``serialize_example``: ``raw_image`` (string), ``label`` (int64)
            and ``group`` (int64).

    Returns:
        Dict mapping each feature name to its parsed tensor.
    """
    if feature_description is None:
        # Built locally so the function is self-contained and usable directly
        # with ``dataset.map(_parse_image_function)``.
        feature_description = {
            "raw_image": tf.io.FixedLenFeature([], tf.string),
            "label": tf.io.FixedLenFeature([], tf.int64),
            "group": tf.io.FixedLenFeature([], tf.int64),
        }
    return tf.io.parse_single_example(example_proto, feature_description)

In [38]:
def read_tfrecord(main_path: str) -> "tf.data.Dataset":
    """Load all ``<main_path>/<group>/*.tfrecord`` files as a parsed dataset.

    Args:
        main_path: Root directory containing one sub-folder per group,
            e.g. ``"./tfrecord"``.

    Returns:
        A ``tf.data.Dataset`` whose elements are dicts with keys
        ``"raw_image"`` (tf.string, the stored image bytes), ``"label"``
        (tf.int64) and ``"group"`` (tf.int64).

    Raises:
        FileNotFoundError: If no ``.tfrecord`` files are found.
    """
    # Full ".tfrecord" extension; a truncated pattern silently matches nothing.
    filenames = sorted(glob.glob(os.path.join(main_path, "*", "*.tfrecord")))
    if not filenames:
        raise FileNotFoundError("No .tfrecord files found under %r" % main_path)

    image_feature_description = {
        "raw_image": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.int64),
        "group": tf.io.FixedLenFeature([], tf.int64),
    }

    @tf.autograph.experimental.do_not_convert
    def _parse(example_proto):
        # Decode one serialized tf.train.Example with the schema above.
        return tf.io.parse_single_example(example_proto, image_feature_description)

    raw_dataset = tf.data.TFRecordDataset(filenames)
    return raw_dataset.map(_parse)

In [39]:
read_tfrecord(main_path="./tfrecord")