### Import libraries

In [1]:
import os
# Set before `import tensorflow` (done in the next cell) so the log level is
# already in the environment when TensorFlow's native library starts up.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = "3"

Set `TF_CPP_MIN_LOG_LEVEL` before importing TensorFlow so that its import-time log messages are suppressed.

In [2]:
import glob
import numpy as np
import tensorflow as tf
import IPython.display as display

### Convert raw files to TFRecord

In [3]:
def _bytes_feature(value: [str, bytes]) -> tf.train.Feature:
    """Wrap a string / bytes value in a ``bytes_list`` feature.

    Args:
        value: ``bytes``, ``str`` (encoded as UTF-8) or an eager string tensor.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds the single value.
    """
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
    if isinstance(value, str):
        # BytesList only accepts bytes, so text has to be encoded first.
        value = value.encode("utf-8")

    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

In [4]:
def _float_feature(value: float) -> tf.train.Feature:
    """Wrap a float / double value in a ``float_list`` feature."""
    wrapped = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=wrapped)

In [5]:
def _int64_feature(value: [bool, int]) -> tf.train.Feature:
    """Wrap a bool / enum / int / uint value in an ``int64_list`` feature."""
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)

In [6]:
def _image_to_byte(value: str) -> bytes:
    """Read an image file and return its raw, still-encoded bytes.

    Args:
        value: Path of the image file to read.

    Returns:
        The complete file content as ``bytes``.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # The context manager closes the handle immediately; this function is
    # called once per image, so leaked handles would pile up quickly.
    with open(value, "rb") as image_file:
        return image_file.read()

In [7]:
def serialize_example(raw_image: str, label_str: str, label_int: int, for_test: bool) -> bytes:
    """Build a ``tf.train.Example`` for one image and serialize it.

    Accepts plain Python values as well as the eager tensors that are handed
    over when this function is called through ``tf.py_function`` or while
    iterating a ``tf.data.Dataset``.

    Args:
        raw_image: Path of the image file whose bytes are stored.
        label_str: Class name of the image.
        label_int: Integer class id of the image.
        for_test: Whether the image belongs to the test split.

    Returns:
        The serialized ``tf.train.Example`` message as ``bytes``.
    """
    # Unwrap eager tensors (anything exposing ``.numpy()``) into Python / NumPy values.
    raw_image, label_str, label_int, for_test = (
        arg.numpy() if hasattr(arg, "numpy") else arg
        for arg in (raw_image, label_str, label_int, for_test)
    )

    # String tensors unwrap to bytes: the path is decoded for open(), while the
    # label is kept as / converted to bytes for the BytesList.
    if isinstance(raw_image, bytes):
        raw_image = raw_image.decode("utf-8")
    if isinstance(label_str, str):
        label_str = label_str.encode("utf-8")

    feature = {
        "raw_image": _bytes_feature(_image_to_byte(raw_image)),
        "label_str": _bytes_feature(label_str),
        # int() turns NumPy scalars (np.int64 / np.bool_) into plain Python ints.
        "label_int": _int64_feature(int(label_int)),
        "for_test": _int64_feature(int(for_test)),
    }

    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    # Call the method: returning the bound method itself yields no bytes at all.
    return example_proto.SerializeToString()

In [10]:
def get_dataset_information(path: str) -> tuple:
    """Collect image paths, labels and the test-split flag for a dataset folder.

    Images are looked up at ``<path>/<split>/<label>/<name>.jpg``. The split and
    label are taken from the two directories directly above each file, so the
    result does not depend on the OS path separator or on how many components
    ``path`` itself has.

    Args:
        path: Root directory of the dataset.

    Returns:
        A tuple ``(raw_image, label_str, label_int, for_test)`` of parallel
        ``np.ndarray`` objects: image paths, label names, label ids (int64) and
        booleans that are True for images whose split directory is ``"test"``.

    Raises:
        KeyError: If a label directory name is not one of the known labels.
    """
    label_form = {"NonDemented": 0, "VeryMildDemented": 1, "MildDemented": 2, "ModerateDemented": 3}
    raw_image, label_str, label_int, for_test = [], [], [], []

    # Sorted so the sample order does not depend on filesystem enumeration order.
    image_paths = sorted(glob.glob(os.path.join(path, "*", "*", "*.jpg")))

    for image_path in image_paths:
        label_dir = os.path.dirname(image_path)
        label = os.path.basename(label_dir)
        data_type = os.path.basename(os.path.dirname(label_dir))

        raw_image.append(image_path)
        label_str.append(label)
        label_int.append(label_form[label])
        for_test.append(data_type == "test")

    return np.array(raw_image), np.array(label_str), np.array(label_int, dtype=np.int64), np.array(for_test)

In [12]:
# Gather the parallel arrays describing every image found under ./dataset.
raw_image, label_str, label_int, for_test = get_dataset_information("./dataset")

# Slice the arrays so each dataset element is one
# (raw_image, label_str, label_int, for_test) tuple.
features_dataset = tf.data.Dataset.from_tensor_slices((raw_image, label_str, label_int, for_test))

In [15]:
# Peek at the first element to check the value and dtype of each field.
for first_example in features_dataset.take(1):
    for field in first_example:
        print(field)

tf.Tensor(b'./dataset\\test\\MildDemented\\26 (19).jpg', shape=(), dtype=string)
tf.Tensor(b'MildDemented', shape=(), dtype=string)
tf.Tensor(2, shape=(), dtype=int64)
tf.Tensor(True, shape=(), dtype=bool)


In [16]:
def tf_serialize_example(raw_image, label_str, label_int, for_test):
    """Run ``serialize_example`` through ``tf.py_function`` for use in ``Dataset.map``.

    Takes the four fields of one dataset element and returns the result of
    ``serialize_example`` as a ``tf.string`` tensor reshaped to a scalar.
    """
    serialized = tf.py_function(
        func=serialize_example,
        inp=(raw_image, label_str, label_int, for_test),
        Tout=tf.string,
    )
    # Reshape to () so every mapped element is a scalar string.
    return tf.reshape(serialized, shape=())

In [17]:
# Serialize every element inside the tf.data pipeline via the py_function wrapper.
# NOTE(review): this name is assigned again by the from_generator cell further
# down, so only one of the two approaches is in effect after Run All.
serialized_features_dataset = features_dataset.map(tf_serialize_example)

Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: 'arguments' object has no attribute 'posonlyargs'
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: 'arguments' object has no attribute 'posonlyargs'


In [18]:
def generator():
    """Yield the result of ``serialize_example`` for each element of ``features_dataset``."""
    for image_path, label_name, label_id, is_test in features_dataset:
        yield serialize_example(image_path, label_name, label_id, is_test)

In [20]:
# Build the serialized dataset from the Python generator instead of Dataset.map;
# each element is declared as a scalar (shape ()) tf.string.
serialized_features_dataset = tf.data.Dataset.from_generator(
    generator, output_types=tf.string, output_shapes=())

In [21]:
# NOTE(review): empty placeholder — presumably the output TFRecord path; TODO confirm
# and fill in before anything is written with it.
filename = ""

<FlatMapDataset shapes: (), types: tf.string>