## 0. Import Dependencies

In [3]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.applications import Xception
from tensorflow.keras.layers import Flatten, Dense, Dropout, GlobalAveragePooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import Dense, Input, Activation, Flatten
import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Dense, Input, Activation, Flatten
from tensorflow.keras.layers import BatchNormalization,Add,Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import LeakyReLU, ReLU, Conv2D, MaxPooling2D, BatchNormalization, Conv2DTranspose, UpSampling2D, concatenate
from tensorflow.keras import callbacks
from tensorflow.keras import backend as K
from tensorflow.keras.applications import ResNet50
from tensorflow.python.keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.utils import load_img, img_to_array


## 1. Model initialization

In [5]:
# Image Size
IMAGE_SIZE = (500,500)
MODEL_NAME = "KyuminRes50"
NUM_CLASSES = 3
SEED = 42

# Seed Python, NumPy and TensorFlow together so the Dense-head initialisation (and, when the cells
# run in order, NumPy-driven augmentation/shuffling in section 3) is repeatable.
tf.keras.utils.set_random_seed(SEED)

# ImageNet-pretrained ResNet50 used as a frozen feature extractor. With include_top=False and
# pooling='max' it emits a single 2048-d vector per image. `classes` is intentionally not passed:
# Keras only uses it when include_top=True, so `classes=3` had no effect here.
# NOTE(review): the ImageNet weights expect inputs processed by
# tf.keras.applications.resnet50.preprocess_input, but the generators in section 3 feed raw
# 0-255 RGB values -- confirm this is intended.
base_model = ResNet50(include_top=False,
                      pooling='max',
                      input_shape=(IMAGE_SIZE[0], IMAGE_SIZE[1], 3),
                      weights='imagenet')
# DO NOT train pretrained model : ResNet50 -- only the Dense head below is trainable.
base_model.trainable = False

model = Sequential(name=MODEL_NAME)
model.add(base_model)
model.add(Flatten())  # pass-through: pooling='max' already yields (None, 2048)
model.add(Dense(1024, activation='relu'))
model.add(Dense(NUM_CLASSES, activation='softmax'))  # one probability per class

model.summary()

Model: "KyuminRes50"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 resnet50 (Functional)       (None, 2048)              23587712  
                                                                 
 flatten (Flatten)           (None, 2048)              0         
                                                                 
 dense (Dense)               (None, 1024)              2098176   
                                                                 
 dense_1 (Dense)             (None, 3)                 3075      
                                                                 
Total params: 25,688,963
Trainable params: 2,101,251
Non-trainable params: 23,587,712
_________________________________________________________________


## 2. Model Compilation

In [6]:
# The head ends in a softmax over the classes and the generators use one-hot labels
# (class_mode='categorical'), so categorical cross-entropy is the matching loss.
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

## 3. Train/Validation Data Preparation

In [7]:
train_dir = '/path/to/TRAIN_DATA/'
valid_dir = '/path/to/VALIDATION_DATA/'
# Batch Size
BATCH_SIZE = 128
# Seeds the training iterator's shuffling and random transformations for repeatable epochs.
AUGMENT_SEED = 42

# Augmentation for training data (light geometric jitter only; no rescaling, so pixels stay 0-255).
data_gen = ImageDataGenerator(
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.1,
    zoom_range=0.1,
)

# Labels come from sub-directory names and are one-hot encoded (class_mode='categorical').
train_gen = data_gen.flow_from_directory(
    train_dir,
    target_size=IMAGE_SIZE,
    color_mode='rgb',
    class_mode='categorical',
    batch_size=BATCH_SIZE,
    shuffle=True,
    seed=AUGMENT_SEED,
)

# No-Augmentation for validation data. Iteration order does not affect the aggregated
# validation metrics, so keep it fixed (shuffle=False) so batches line up with valid_gen.filenames.
default_gen = ImageDataGenerator()
valid_gen = default_gen.flow_from_directory(
    valid_dir,
    target_size=IMAGE_SIZE,
    color_mode='rgb',
    class_mode='categorical',
    batch_size=BATCH_SIZE,
    shuffle=False,
)

print("training data label: ", train_gen.class_indices)
print("validation data label: ", valid_gen.class_indices)

# Both splits must map class names to the same indices, or validation accuracy is meaningless.
assert train_gen.class_indices == valid_gen.class_indices, "train/validation class indices differ"

Found 5507 images belonging to 3 classes.
Found 1089 images belonging to 3 classes.
data label:  {'0.NOFINDING': 0, '1.THORAXDISEASE': 1, '2.COVID-19': 2}
data label:  {'0.NOFINDING': 0, '1.THORAXDISEASE': 1, '2.COVID-19': 2}


## 3-1. Visualize Augmented Images

In [None]:
###### A DirectoryIterator is stateful: every next() consumes a batch, so reinitialize it before re-running. #######

# Dedicated iterator over the augmented training images, so this cell does not consume batches
# from `train_gen`, which model.fit uses. Any DirectoryIterator can be substituted here.
MY_DIRECTORY_ITERATOR = data_gen.flow_from_directory(
    train_dir,
    target_size=IMAGE_SIZE,
    color_mode='rgb',
    class_mode='categorical',
    batch_size=1,
    shuffle=True,
)
# Invert class_indices so each panel can be titled with its label.
idx_to_class = {idx: name for name, idx in MY_DIRECTORY_ITERATOR.class_indices.items()}

fig, axes = plt.subplots(3, 3, figsize=(10, 10))
for ax in axes.flat:
    images, labels = next(MY_DIRECTORY_ITERATOR)
    # Take the first sample so this works for any batch_size. Clip before casting,
    # because astype('uint8') would wrap any value outside 0-255.
    image = np.clip(images[0], 0, 255).astype('uint8')
    ax.imshow(image)
    ax.set_title(idx_to_class[int(np.argmax(labels[0]))])
    ax.axis("off")
plt.show()

## 4. Callback configuration (early stopping & checkpointing)

In [9]:
# Callbacks come from the public `tensorflow.keras.callbacks` module (imported as `callbacks`)
# instead of the private `tensorflow.python.keras` package. Callbacks from that legacy copy can be
# incompatible with a `tensorflow.keras` model's fit().
CHECKPOINT_PATH = '/path/to/save_model/my_model.hdf5'

# Stop once val_loss has not improved for 3 consecutive epochs.
cb_early_stopper = callbacks.EarlyStopping(monitor='val_loss', patience=3)
# Overwrite the file only when val_loss improves, so it always holds the best epoch seen so far.
cb_checkpointer = callbacks.ModelCheckpoint(filepath=CHECKPOINT_PATH,
                                            monitor='val_loss',
                                            save_best_only=True,
                                            mode='auto')

## 5. Model Training

In [None]:
# Upper bound on epochs; cb_early_stopper may end training earlier.
EPOCHS = 60

# Keep the callback order: the checkpoint runs before the early-stopping check each epoch.
fit_history = model.fit(
    train_gen,
    validation_data=valid_gen,
    epochs=EPOCHS,
    callbacks=[cb_checkpointer, cb_early_stopper],
)

Epoch 1/60
Epoch 2/60
10/44 [=====>........................] - ETA: 4:54 - loss: 0.3992 - accuracy: 0.8781

## 5-1. Notification on training completion

- Sends a push notification via [ntfy.sh](https://ntfy.sh/) when training completes.

In [None]:
import os

import requests

# ntfy topic names work like passwords (anyone who knows one can subscribe), so take the topic
# from the environment when it is set; "MY_NOTI_CHANNEL" is the placeholder fallback.
NTFY_TOPIC = os.environ.get("NTFY_TOPIC", "MY_NOTI_CHANNEL")

# The timeout keeps a network stall from hanging the notebook after a long training run.
requests.post(f"https://ntfy.sh/{NTFY_TOPIC}",
    data=f"Training on {MODEL_NAME} is Done!",
    headers={
        "Title": "Training Done!",
        "Priority": "urgent",
        "Tags": "warning"
    },
    timeout=10)

## 6. Show training results

In [None]:
def plot_metric(ax, history, metric):
    """Draw the train and validation curves of `metric` on `ax`.

    `history` is a Keras `History.history` dict; it must contain the keys `metric`
    and `val_<metric>`, each holding one value per epoch.
    """
    epochs = range(1, len(history[metric]) + 1)  # 1-based, matching the "Epoch 1/60" log
    ax.plot(epochs, history[metric])
    ax.plot(epochs, history[f'val_{metric}'])
    ax.set_title(f'model_{metric}')
    ax.set_ylabel(metric)
    ax.set_xlabel('epoch')
    ax.legend(['train', 'valid'])


print("fit history keys : ", fit_history.history.keys())

# A 1x2 grid: placing two panels at 221/222 of a 2x2 grid left the bottom half of the figure blank.
fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(15, 4))
plot_metric(ax_acc, fit_history.history, 'accuracy')
plot_metric(ax_loss, fit_history.history, 'loss')
plt.show()

## 7. Prediction on a test-data directory

In [None]:
test_dir = '/path/to/test_data/'
# NOTE(review): section 4 saves to '/path/to/save_model/my_model.hdf5'; confirm this points to the same file.
model_path = '/path/to/my_model.hdf5'
result_path = '/path/to/result_text.txt'

# Position i is the class the model outputs at index i; it must follow the class_indices
# printed in section 3 ('0.NOFINDING': 0, '1.THORAXDISEASE': 1, '2.COVID-19': 2).
class_name = ['NOFINDING', 'THORAXDISEASE', 'COVID-19']
# Same formats flow_from_directory accepts during training.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.ppm', '.tif', '.tiff')

model = load_model(model_path)

# Sort for a deterministic output order. Skip sub-directories and non-image files (e.g. .DS_Store),
# which load_img cannot read.
imgs = sorted(
    file for file in os.listdir(test_dir)
    if file.lower().endswith(IMAGE_EXTENSIONS) and os.path.isfile(os.path.join(test_dir, file))
)

with open(result_path, 'w', encoding='utf-8') as f:
    for image_name in imgs:
        full_img_path = os.path.join(test_dir, image_name)
        # RGB, resized to IMAGE_SIZE, raw 0-255 values; no rescaling, matching the training generators.
        image = img_to_array(load_img(full_img_path, target_size=IMAGE_SIZE))
        image = np.expand_dims(image, axis=0)  # add the batch dimension -> (1, H, W, 3)

        predictions = model.predict(image, verbose=0)
        predicted_class = int(np.argmax(predictions, axis=1)[0])

        result_line = f"{image_name}\t{predicted_class}\t{class_name[predicted_class]}"
        print(result_line)
        # Terminate each record; without the newline every result ran together on one line.
        f.write(result_line + "\n")