### Setup

In [1]:
import os
# Set before TensorFlow is imported (in the next cell) so that its
# import-time log output is suppressed.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = "3"

Setting `TF_CPP_MIN_LOG_LEVEL` to `"3"` before TensorFlow is imported suppresses the log messages that TensorFlow would otherwise print on import.

In [2]:
import glob
from typing import Union

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

In [3]:
# Prefer a TPU when one can be resolved and initialised; otherwise fall back
# to TensorFlow's default distribution strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print("Device:", tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except Exception as tpu_error:
    # Catch `Exception` rather than using a bare `except:` so that
    # KeyboardInterrupt / SystemExit still propagate, and report why the
    # fallback happened instead of discarding the error silently.
    print("TPU unavailable, using the default strategy:", tpu_error)
    strategy = tf.distribute.get_strategy()
print("Number of replicas:", strategy.num_replicas_in_sync)

Number of replicas: 1


In [4]:
# Sentinel that lets tf.data pick parallelism / buffer sizes at runtime.
AUTOTUNE = tf.data.experimental.AUTOTUNE
# Global batch size: 16 examples for each replica of the active strategy.
BATCH_SIZE = 16 * strategy.num_replicas_in_sync
# Spatial size of the input images; combined with 3 channels for the model input.
IMAGE_SIZE = [176, 208]
# Number of training epochs (also passed to read_dataset as the repeat count).
EPOCHS = 100

### Convert the data

In [None]:
def _bytes_feature(value: Union[str, bytes, tf.Tensor]) -> tf.train.Feature:
    """Wrap a string / bytes value in a ``bytes_list`` feature.

    Args:
        value: Raw bytes, a text string (stored UTF-8 encoded) or an eager
            string tensor.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds the single value.
    """
    if isinstance(value, type(tf.constant(0))):
        # BytesList does not unpack a string from an EagerTensor.
        value = value.numpy()
    if isinstance(value, str):
        # BytesList only accepts bytes, so text has to be encoded first.
        value = value.encode("utf-8")

    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

In [None]:
def _float_feature(value: float) -> tf.train.Feature:
    """Wrap a float / double value in a ``float_list`` feature."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)

In [None]:
def _int64_feature(value: Union[bool, int]) -> tf.train.Feature:
    """Wrap a bool / enum / int / uint value in an ``int64_list`` feature.

    Args:
        value: The single integer-like value to store.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` holds the single value.
    """
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

In [None]:
def serialize_example(image: bytes, label: int) -> tf.train.Example:
    """Build the ``tf.train.Example`` message for one labelled image.

    Despite the name, the message is returned un-serialized; the caller is
    responsible for calling ``SerializeToString()`` before writing it.

    Args:
        image: Encoded image file contents, stored under ``"raw_image"``.
        label: Integer class id, stored under ``"label"``.

    Returns:
        The assembled ``tf.train.Example`` message.
    """
    feature = {
        "raw_image": _bytes_feature(image),
        "label": _int64_feature(label),
    }

    return tf.train.Example(features=tf.train.Features(feature=feature))

In [None]:
def _write_split(paths, output_file: str, image_labels: dict) -> None:
    """Write every image in ``paths`` to ``output_file``, one Example per image."""
    # Make sure the output directory exists before opening the writer.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with tf.io.TFRecordWriter(output_file) as writer:
        for path in paths:
            # Context manager so each image file handle is closed promptly.
            with open(path, "rb") as image_file:
                image_string = image_file.read()

            # The class name is the image's parent directory. os.path works
            # with the platform's separators, unlike splitting on "\\".
            label_str = os.path.basename(os.path.dirname(path))
            label = image_labels[label_str]

            tf_example = serialize_example(image_string, label)
            writer.write(tf_example.SerializeToString())


def write_tfrecord(main_path: str) -> None:
    """Convert the image folders under ``main_path`` into TFRecord files.

    Images are collected from ``<main_path>/train/<label>/*.jpg`` and
    ``<main_path>/test/<label>/*.jpg`` and written, together with their
    integer label, to ``./tfrecord/train.tfrecord`` and
    ``./tfrecord/test.tfrecord``.

    Args:
        main_path: Root directory of the dataset.

    Raises:
        KeyError: If an image sits in a directory whose name is not one of
            the known labels.
    """
    # Sorted so the record order does not depend on the file-system listing order.
    train_paths = sorted(glob.glob(os.path.join(main_path, "train", "*", "*.jpg")))
    test_paths = sorted(glob.glob(os.path.join(main_path, "test", "*", "*.jpg")))
    image_labels = {"NonDemented": 0, "VeryMildDemented": 1, "MildDemented": 2, "ModerateDemented": 3}
    train_file = "./tfrecord/train.tfrecord"
    test_file = "./tfrecord/test.tfrecord"

    # train TFRecord file
    _write_split(train_paths, train_file, image_labels)
    print("Train TFRecord Converting Done!")

    # test TFRecord file
    _write_split(test_paths, test_file, image_labels)
    print("Test TFRecord Converting Done!")

In [None]:
# Root folder expected to contain train/<label>/*.jpg and test/<label>/*.jpg,
# which is the layout globbed by write_tfrecord.
dataset_path = "./dataset"
write_tfrecord(dataset_path)

### Load the data

In [None]:
# Stream the serialized examples back from the files written by write_tfrecord.
train_dataset = tf.data.TFRecordDataset("./tfrecord/train.tfrecord")
test_dataset = tf.data.TFRecordDataset("./tfrecord/test.tfrecord")

In [None]:
# Seed for the one-off shuffle that defines the train/validation split.
SPLIT_SEED = 42

# Count the records by streaming them instead of holding them all in a list.
TRAIN_DATA_SIZE = sum(1 for _ in train_dataset)
train_size = int(0.75 * TRAIN_DATA_SIZE)

# The split needs ONE fixed order. By default shuffle() reshuffles on every
# iteration, so take()/skip() would cut a different permutation each time and
# validation records would leak into training. The buffer spans the whole
# file because the records were written grouped by class; a small buffer
# would leave the tail (the validation part) dominated by a few classes.
# At this point the elements are still serialized byte strings.
train_dataset = train_dataset.shuffle(
    max(TRAIN_DATA_SIZE, 1),  # shuffle() requires a positive buffer size
    seed=SPLIT_SEED,
    reshuffle_each_iteration=False,
)
test_dataset = test_dataset.shuffle(1000)

# First 75% of the fixed order for training, the remainder for validation.
validation_dataset = train_dataset.skip(train_size)
train_dataset = train_dataset.take(train_size)

In [None]:
# Count each split by streaming it once; building a list of every record
# just to take its length keeps the whole split in memory for no benefit.
train_len = sum(1 for _ in train_dataset)
validation_len = sum(1 for _ in validation_dataset)
test_len = sum(1 for _ in test_dataset)

print("Train dataset:", train_len)
print("Validation dataset:", validation_len)
print("Test dataset:", test_len)

In [None]:
# Schema used to parse each serialized tf.train.Example; the keys and types
# mirror the features written by serialize_example.
image_feature_description = {
    "raw_image": tf.io.FixedLenFeature([], tf.string),
    "label": tf.io.FixedLenFeature([], tf.int64),
}

In [None]:
@tf.autograph.experimental.do_not_convert
def _parse_image_function(example_proto):
    """Decode one serialized example into an ``(image, one_hot_label)`` pair.

    Args:
        example_proto: Scalar string tensor holding a serialized
            ``tf.train.Example`` with the features listed in
            ``image_feature_description``.

    Returns:
        image: float32 tensor of shape ``[*IMAGE_SIZE, 3]``; pixel values are
            left in the 0-255 range.
        label: float32 one-hot vector of length 4.
    """
    features = tf.io.parse_single_example(example_proto, image_feature_description)

    # "raw_image" holds the encoded JPEG file contents, so it has to be
    # decoded as JPEG; decode_raw would reinterpret the compressed bytes as
    # pixel values.
    image = tf.io.decode_jpeg(features["raw_image"], channels=3)
    # Checks the decoded size at runtime and pins the static shape the model
    # input layer expects.
    image = tf.ensure_shape(image, [*IMAGE_SIZE, 3])
    image = tf.cast(image, tf.float32)

    # The label is already an int64 tensor; .numpy() is not available inside
    # Dataset.map. Depth 4 = number of label ids assigned in write_tfrecord.
    label = tf.one_hot(features["label"], 4)

    return image, label

In [None]:
def read_dataset(epochs, batch_size, dataset):
    """Turn a dataset of serialized examples into a batched input pipeline.

    Args:
        epochs: How many times the (shuffled) data is repeated.
        batch_size: Number of examples per batch; incomplete batches are dropped.
        dataset: ``tf.data.Dataset`` of serialized ``tf.train.Example`` strings.

    Returns:
        A dataset of ``(image, label)`` batches. Because the data is already
        repeated ``epochs`` times, a consumer that iterates it to exhaustion
        once per Keras epoch sees all repetitions in each epoch; bound each
        epoch with ``steps_per_epoch`` / ``validation_steps`` if that is not
        wanted.
    """
    # Decode elements in parallel.
    dataset = dataset.map(_parse_image_function, num_parallel_calls=AUTOTUNE)
    # Shuffle before repeating so each repetition is a permutation of the
    # data rather than a blend of neighbouring epochs.
    dataset = dataset.shuffle(buffer_size=10 * batch_size)
    dataset = dataset.repeat(epochs)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    # Prefetch last so that finished batches are prepared while the previous
    # one is being consumed.
    dataset = dataset.prefetch(AUTOTUNE)

    return dataset

In [None]:
# Build the parsed, shuffled and batched input pipelines; read_dataset also
# repeats each dataset EPOCHS times.
train_dataset = read_dataset(EPOCHS, BATCH_SIZE, train_dataset)
validation_dataset = read_dataset(EPOCHS, BATCH_SIZE, validation_dataset)
test_dataset = read_dataset(EPOCHS, BATCH_SIZE, test_dataset)

In [None]:
# `parsed_train_dataset` is never defined in this notebook; preview the
# training pipeline built above instead. The dataset repr lists the element
# shapes and dtypes.
train_dataset.take(1)

### Visualize dataset

In [None]:
# train TFRecord
# Read a stored record straight from the file and parse it into its feature
# dictionary: the training pipeline yields batched (image, label) tensors,
# not the raw "raw_image" / "label" features this cell displays.
for serialized_example in tf.data.TFRecordDataset("./tfrecord/train.tfrecord").take(1):
    image_features = tf.io.parse_single_example(serialized_example, image_feature_description)
    image_raw = image_features["raw_image"].numpy()
    image_label = image_features["label"].numpy()
    # Decode the stored JPEG bytes and draw them with matplotlib.
    plt.imshow(tf.io.decode_jpeg(image_raw, channels=3).numpy())
    plt.axis("off")
    plt.show()
    print("Label:", image_label)

In [None]:
# test TFRecord
# Read a stored record straight from the file and parse it into its feature
# dictionary: the test pipeline yields batched (image, label) tensors, not
# the raw "raw_image" / "label" features this cell displays.
for serialized_example in tf.data.TFRecordDataset("./tfrecord/test.tfrecord").take(1):
    image_features = tf.io.parse_single_example(serialized_example, image_feature_description)
    image_raw = image_features["raw_image"].numpy()
    image_label = image_features["label"].numpy()
    # Decode the stored JPEG bytes and draw them with matplotlib.
    plt.imshow(tf.io.decode_jpeg(image_raw, channels=3).numpy())
    plt.axis("off")
    plt.show()
    print("Label:", image_label)

### Build Model

In [None]:
# Non-dementia, very mild dementia, mild dementia, moderate dementia.
# Listed in label-id order (0..3), matching the ids assigned to the
# NonDemented / VeryMildDemented / MildDemented / ModerateDemented folders in
# write_tfrecord, so that CLASS_NAMES[label] names the right class.
CLASS_NAMES = ['NonDementia', 'VeryMildDementia', 'MildDementia', 'ModerateDementia']
NUM_CLASSES = len(CLASS_NAMES)

In [None]:
# `parsed_train_dataset` is never defined in this notebook, so count the
# records straight from the training file; this is independent of any
# pipeline built earlier and avoids materialising the records in a list.
TRAIN_DATA_SIZE = sum(1 for _ in tf.data.TFRecordDataset("./tfrecord/train.tfrecord"))
# 75% of the records are used for training; validation takes the remainder.
train_size = int(0.75 * TRAIN_DATA_SIZE)
# The test set is stored in its own file, so it is not split off here.

In [None]:
# train / validation data split
# Start from the stored records (`parsed_train_dataset` is never defined in
# this notebook). They are shuffled once, while still small serialized
# strings, into a fixed order: without reshuffle_each_iteration=False the
# take()/skip() calls below would cut a different permutation on every pass
# and the two subsets would overlap.
train_records = tf.data.TFRecordDataset("./tfrecord/train.tfrecord").shuffle(
    max(TRAIN_DATA_SIZE, 1),  # shuffle() requires a positive buffer size
    seed=42,
    reshuffle_each_iteration=False,
)

# Both subsets are cut from the full record stream; calling skip(train_size)
# on the already-truncated training subset would always give an empty
# validation set.
validation_dataset = train_records.skip(train_size).map(_parse_image_function)
train_dataset = train_records.take(train_size).map(_parse_image_function)
# Re-order only the training examples between epochs.
train_dataset = train_dataset.shuffle(100)

train_dataset = train_dataset.batch(BATCH_SIZE)
validation_dataset = validation_dataset.batch(BATCH_SIZE)

In [None]:
def conv_block(filters):
    """Return a Sequential of two separable 3x3 convolutions, batch norm and max pooling.

    Args:
        filters: Number of output filters for both separable convolutions.
    """
    separable_convs = [
        tf.keras.layers.SeparableConv2D(filters, 3, activation='relu', padding='same')
        for _ in range(2)
    ]
    tail = [tf.keras.layers.BatchNormalization(), tf.keras.layers.MaxPool2D()]

    return tf.keras.Sequential(separable_convs + tail)

In [None]:
def dense_block(units, dropout_rate):
    """Return a Sequential of a ReLU dense layer, batch norm and dropout.

    Args:
        units: Width of the dense layer.
        dropout_rate: Rate passed to the trailing Dropout layer.
    """
    layers = tf.keras.layers
    return tf.keras.Sequential([
        layers.Dense(units, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(dropout_rate),
    ])

In [None]:
def build_model():
    """Assemble the CNN classifier as a single Sequential model.

    The stack is: a plain convolutional stem, four separable-convolution
    blocks (the last two each followed by dropout), then three dense blocks
    and a softmax layer with NUM_CLASSES outputs.
    """
    layers = tf.keras.layers

    # Input plus convolutional stem.
    stack = [
        tf.keras.Input(shape=(*IMAGE_SIZE, 3)),
        layers.Conv2D(16, 3, activation='relu', padding='same'),
        layers.Conv2D(16, 3, activation='relu', padding='same'),
        layers.MaxPool2D(),
    ]

    # Separable-convolution blocks; only the two widest are followed by dropout.
    stack += [conv_block(32), conv_block(64)]
    for filters in (128, 256):
        stack += [conv_block(filters), layers.Dropout(0.2)]

    # Classifier head.
    stack.append(layers.Flatten())
    for units, dropout_rate in ((512, 0.7), (128, 0.5), (64, 0.3)):
        stack.append(dense_block(units, dropout_rate))
    stack.append(layers.Dense(NUM_CLASSES, activation='softmax'))

    return tf.keras.Sequential(stack)

In [None]:
# Create the model, its metrics and the compile state inside the strategy
# scope chosen in the setup section.
with strategy.scope():
    model = build_model()

    # AUC is the only metric tracked in addition to the loss.
    METRICS = [tf.keras.metrics.AUC(name='auc')]
    
    # Categorical cross-entropy pairs with the softmax output of build_model.
    model.compile(
        optimizer='adam',
        loss=tf.losses.CategoricalCrossentropy(),
        metrics=METRICS
    )
    
    model.summary()

### Train Model

In [None]:
def exponential_decay(lr0, s):
    """Build a schedule that divides the learning rate by 10 every ``s`` epochs.

    The AutoGraph ``do_not_convert`` decorator was dropped: this is a plain
    Python factory whose result is called with an epoch number, so there is
    nothing for AutoGraph to convert.

    Args:
        lr0: Learning rate at epoch 0.
        s: Number of epochs over which the rate shrinks by a factor of 10.

    Returns:
        A function mapping an epoch index to ``lr0 * 0.1 ** (epoch / s)``.
    """
    def exponential_decay_fn(epoch):
        return lr0 * 0.1 ** (epoch / s)
    return exponential_decay_fn

# Schedule: start at 0.01 and shrink the learning rate 10x every 20 epochs.
exponential_decay_fn = exponential_decay(0.01, 20)

# Applies the schedule above during training.
lr_scheduler = tf.keras.callbacks.LearningRateScheduler(exponential_decay_fn)

# Writes the model to AICAv2.h5, keeping only the best one seen so far
# (the monitored quantity is left at the Keras default).
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint("AICAv2.h5",
                                                    save_best_only=True)

# Stops after 10 epochs without improvement and restores the best weights.
early_stopping_cb = tf.keras.callbacks.EarlyStopping(patience=10,
                                                     restore_best_weights=True)

In [None]:
# Train for up to EPOCHS epochs, validating on validation_dataset and using
# the checkpoint, early-stopping and learning-rate callbacks defined above;
# `history` keeps the object returned by fit.
history = model.fit(
    train_dataset,
    validation_data=validation_dataset,
    callbacks=[checkpoint_cb, early_stopping_cb, lr_scheduler],
    epochs=EPOCHS
)