In [1]:
import tensorflow as tf
from tensorflow import keras
import pathlib
import random
import os
import datetime
import time

In [29]:
# Let tf.data choose parallelism / buffer sizes where AUTOTUNE is passed.
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Allocate GPU memory on demand instead of reserving it all up front.
# The setting is applied to every visible GPU (not only the first one) because
# TensorFlow requires memory growth to be configured identically across GPUs.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Raised when the GPUs were already initialized; report it and carry on.
        print(e)


In [60]:
# Name used to namespace the output directory for checkpoints and saved models.
dataset_name = 'em24'

# Alternative per-dataset layout (currently disabled):
#   train_dataset_path = '../../datasets/test/' + dataset_name + '/train'
#   valid_dataset_path = '../../datasets/test/' + dataset_name + '/valid'

# Root directories handed to basic_processing() for the two splits.
train_dataset_path, valid_dataset_path = (
    '../../datasets/data/train',
    '../../datasets/data/validation',
)

In [64]:
# --- Training hyper-parameters ---
BATCH_SIZE = 32
IMG_SIZE = 224
NUM_EPOCHS = 30
EARLY_STOP_PATIENCE = 3

# Number of batches needed to cover each split once. Integer ceiling division
# keeps these plain Python ints; the former tf.math.ceil(...).numpy() form
# produced floats (e.g. 115.0 steps).
# NOTE: train_images_len / valid_images_len are returned by basic_processing(),
# so the dataset-preparation cells must have been executed before this one.
TRAIN_STEP_PER_EPOCH = -(-train_images_len // BATCH_SIZE)
VALID_STEP_PER_EPOCH = -(-valid_images_len // BATCH_SIZE)

# --- Output locations ---
saved_path = './model/'
# NOTE: this rebinds the name `time`, hiding the `time` module imported at the
# top of the notebook. The name is kept because later cells build the
# saved-model path from it.
time = datetime.datetime.now().strftime("%Y.%m.%d_%H:%M") + '_tf2'
# File-name template for checkpoints; the placeholders are filled from the
# epoch number and the logged val_accuracy by the checkpoint callback.
weight_file_name = '{epoch:02d}-{val_accuracy:.2f}.hdf5'
run_dir = saved_path + dataset_name + '/' + time
checkpoint_path = run_dir + '/' + weight_file_name

# Create the per-run output directory; exist_ok makes re-running the cell safe.
os.makedirs(run_dir, exist_ok=True)

## 1. Dataset preparation

In [4]:
def basic_processing(img_path, is_training):
    """Collect image file paths and one-hot labels from a class-per-folder tree.

    Each immediate sub-directory of ``img_path`` is treated as one class and
    every regular file directly inside a class directory as one sample.

    Args:
        img_path: Root directory (str or path-like) holding one sub-directory
            per class.
        is_training: If true, the sample list is shuffled with
            ``random.shuffle``; otherwise it stays in sorted path order.

    Returns:
        A 4-tuple ``(images, labels, len_images, labels_len)``: ``images`` is
        a list of file paths as strings, ``labels`` is the matching one-hot
        float32 array from ``to_categorical``, ``len_images`` is the number of
        samples and ``labels_len`` the number of classes.
    """
    img_path = pathlib.Path(img_path)

    # Sort for a reproducible order (glob order depends on the filesystem) and
    # keep regular files only, so nested directories are not mistaken for images.
    images = [str(path) for path in sorted(img_path.glob('*/*')) if path.is_file()]
    len_images = len(images)

    if is_training:
        random.shuffle(images)

    # Class index = position of the directory name in sorted order.
    class_names = sorted(item.name for item in img_path.glob('*/') if item.is_dir())
    labels_len = len(class_names)
    class_index = {name: index for index, name in enumerate(class_names)}

    # The label of a sample is the name of the directory that contains it.
    labels = [class_index[pathlib.Path(path).parent.name] for path in images]
    labels = tf.keras.utils.to_categorical(labels, num_classes=labels_len, dtype='float32')

    return images, labels, len_images, labels_len

In [None]:
# Index both splits; only the training sample list is shuffled.
# NOTE: the same two calls appear again in a later cell of this notebook.
(train_images, train_labels,
 train_images_len, train_labels_len) = basic_processing(train_dataset_path, is_training=True)
(valid_images, valid_labels,
 valid_images_len, valid_labels_len) = basic_processing(valid_dataset_path, is_training=False)

---

## 2. Create Dataset

In [42]:
def preprocess_image(image):
    """Decode raw JPEG bytes into a model-ready image tensor.

    Args:
        image: String tensor holding the encoded JPEG file contents.

    Returns:
        A 3-channel image tensor resized to ``IMG_SIZE`` x ``IMG_SIZE`` and
        passed through the Xception ``preprocess_input`` function.
    """
    image = tf.image.decode_jpeg(image, channels=3)
    # Resize to the configured IMG_SIZE so the input pipeline always matches
    # the model's input_shape instead of a separately hard-coded 224.
    image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
    image = keras.applications.xception.preprocess_input(image)

    return image

# Image path -> preprocessed image tensor
def load_and_preprocess_image(path):
    """Read the file at ``path`` and run its bytes through ``preprocess_image``."""
    return preprocess_image(tf.io.read_file(path))

# Build a tf.data dataset of (image, label) pairs
def make_tf_dataset(images, labels):
    """Pair decoded images with their labels in a single ``tf.data.Dataset``.

    Args:
        images: Sequence of image file paths.
        labels: Labels aligned with ``images``; cast to float32.

    Returns:
        An unbatched dataset yielding ``(image_tensor, label)`` tuples.
    """
    decoded_images = tf.data.Dataset.from_tensor_slices(images).map(
        load_and_preprocess_image, num_parallel_calls=AUTOTUNE)
    float_labels = tf.data.Dataset.from_tensor_slices(tf.cast(labels, tf.float32))

    return tf.data.Dataset.zip((decoded_images, float_labels))

In [5]:
# Index both splits; only the training sample list is shuffled.
# NOTE: an identical pair of calls also appears in an earlier cell.
(train_images, train_labels,
 train_images_len, train_labels_len) = basic_processing(train_dataset_path, is_training=True)
(valid_images, valid_labels,
 valid_images_len, valid_labels_len) = basic_processing(valid_dataset_path, is_training=False)

In [52]:
def _repeat_batch_prefetch(dataset):
    """Turn a per-sample dataset into an endless stream of prefetched batches."""
    return dataset.repeat().batch(BATCH_SIZE).prefetch(1)


# Both splits repeat indefinitely; the number of batches consumed per epoch is
# bounded by the step counts passed to model.fit.
train_ds = _repeat_batch_prefetch(make_tf_dataset(train_images, train_labels))
valid_ds = _repeat_batch_prefetch(make_tf_dataset(valid_images, valid_labels))

## 3. Create Model

In [66]:
# Xception backbone with ImageNet weights and without its classification head.
base_model = keras.applications.xception.Xception(input_shape=(IMG_SIZE, IMG_SIZE, 3),
                                                  weights="imagenet",
                                                  include_top=False)
# New head: global average pooling followed by a softmax over the dataset classes.
avg = tf.keras.layers.GlobalAveragePooling2D()(base_model.output)
output = tf.keras.layers.Dense(train_labels_len, activation="softmax")(avg)
model = tf.keras.Model(inputs=base_model.input, outputs=output)

# Fine-tune the whole network: every backbone layer is marked trainable.
for layer in base_model.layers:
    layer.trainable = True

# `learning_rate` is the supported argument name; `lr` is a deprecated alias.
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss="categorical_crossentropy", optimizer=optimizer, metrics=["accuracy"])

# Stop training once val_loss has not improved for EARLY_STOP_PATIENCE epochs
# (the configured constant replaces a duplicated literal 3).
cb_early_stopper = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                    patience=EARLY_STOP_PATIENCE)

# Write a checkpoint to checkpoint_path only when val_accuracy improves.
cb_checkpointer = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                     monitor='val_accuracy',
                                                     save_best_only=True,
                                                     mode='auto')

In [67]:
# Print the layer-by-layer architecture with output shapes and parameter counts.
model.summary()

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
block1_conv1 (Conv2D)           (None, 111, 111, 32) 864         input_3[0][0]                    
__________________________________________________________________________________________________
block1_conv1_bn (BatchNormaliza (None, 111, 111, 32) 128         block1_conv1[0][0]               
__________________________________________________________________________________________________
block1_conv1_act (Activation)   (None, 111, 111, 32) 0           block1_conv1_bn[0][0]            
____________________________________________________________________________________________

In [68]:
# Train on the repeated datasets; the step counts define how many batches make
# up one epoch. NOTE: epochs is fixed at 5 here, independent of NUM_EPOCHS.
fit_callbacks = [cb_early_stopper, cb_checkpointer]
history = model.fit(
    train_ds,
    validation_data=valid_ds,
    epochs=5,
    steps_per_epoch=TRAIN_STEP_PER_EPOCH,
    validation_steps=VALID_STEP_PER_EPOCH,
    shuffle=False,
    verbose=1,
    callbacks=fit_callbacks,
)

# Persist the final model next to the per-epoch checkpoints of this run.
final_model_path = saved_path + dataset_name + '/' + time + '/' + dataset_name + '.h5'
model.save(final_model_path)

Train for 115.0 steps, validate for 13.0 steps
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


In [63]:
# Display the location the final model was saved to.
''.join([saved_path, dataset_name, '/', time, '/', dataset_name, '.h5'])

'./model/em24/2020.05.20_05:47_tf2/em24.h5'