In [None]:
# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras

# Helper libraries
import numpy as np
import matplotlib.pyplot as plt

from model.data import ds_train, ds_test, num_classes, ds_info
from model.base_code import apply_normalize_on_dataset

In [None]:
# Check GPU availability: list the physical GPU devices TensorFlow can see.
# An empty list means no GPU is visible to TensorFlow.
visible_gpus = tf.config.list_physical_devices('GPU')
visible_gpus

In [None]:
# Keep a reference to the untouched test split so this cell is idempotent (see below).
from model.data import ds_test as ds_test_raw

# Training datasets, one per augmentation experiment; the variable name states the
# augmentation configuration passed to apply_normalize_on_dataset.
# NOTE(review): `no_aug` and `no_cutmix` use identical flags (no augmentation, no CutMix);
# both are kept so each experiment has its own baseline pipeline.
ds_train_no_aug = apply_normalize_on_dataset(ds_train, with_aug=False, with_cutmix=False)
ds_train_aug = apply_normalize_on_dataset(ds_train, with_aug=True, with_cutmix=False)
ds_train_no_cutmix = apply_normalize_on_dataset(ds_train, with_aug=False, with_cutmix=False)
ds_train_cutmix = apply_normalize_on_dataset(ds_train, with_aug=False, with_cutmix=True)
ds_train_aug_cutmix = apply_normalize_on_dataset(ds_train, with_aug=True, with_cutmix=True)

# Test dataset. Built from the raw import rather than from `ds_test` itself, so that
# re-running this cell does not apply the test-time preprocessing a second time.
ds_test = apply_normalize_on_dataset(ds_test_raw, is_test=True)

In [None]:
from model.data import resnet50


def _independent_copy(base_model):
    """Return a new model with `base_model`'s architecture and a copy of its current weights.

    Importing `resnet50` under several aliases binds every alias to the *same* model
    object, so every experiment would train (and keep updating) one shared set of weights.
    Cloning gives each experiment its own model, all starting from identical weights.

    NOTE(review): `keras.models.clone_model` requires a Functional or Sequential model —
    confirm that `model.data.resnet50` is one.
    """
    model_copy = keras.models.clone_model(base_model)
    # clone_model re-initializes weights; copy them so every experiment starts identically.
    model_copy.set_weights(base_model.get_weights())
    return model_copy


# One independent model per augmentation experiment. The base `resnet50` itself is never
# trained, so re-running this cell yields fresh models with the same starting weights.
resnet50_no_aug = _independent_copy(resnet50)
resnet50_aug = _independent_copy(resnet50)
resnet50_no_cutmix = _independent_copy(resnet50)
resnet50_cutmix = _independent_copy(resnet50)
resnet50_aug_cutmix = _independent_copy(resnet50)

EPOCH = 20

In [None]:
# Models and their training datasets, positionally aligned: the i-th model is trained on
# the i-th dataset variant it is named after.
model_list = [resnet50_no_aug, resnet50_aug, resnet50_no_cutmix, resnet50_cutmix, resnet50_aug_cutmix]
train_ds_list = [ds_train_no_aug, ds_train_aug, ds_train_no_cutmix, ds_train_cutmix, ds_train_aug_cutmix]

# NOTE(review): assumed to match the batch size applied inside apply_normalize_on_dataset — confirm.
BATCH_SIZE = 16
steps_per_epoch = ds_info.splits['train'].num_examples // BATCH_SIZE
validation_steps = ds_info.splits['test'].num_examples // BATCH_SIZE

# history_list[i] holds the keras History object for model_list[i].
history_list = []
for res_model, train_ds in zip(model_list, train_ds_list):
    res_model.compile(
        loss='categorical_crossentropy',
        # A fresh optimizer per model so no optimizer state is shared between experiments.
        optimizer=tf.keras.optimizers.SGD(learning_rate=0.01, clipnorm=1.),
        metrics=['accuracy'],
    )
    # `use_multiprocessing` was dropped: Model.fit only uses it for generator/Sequence
    # inputs, and these datasets are presumably tf.data pipelines (tfds `ds_info`).
    history = res_model.fit(
        train_ds,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
        epochs=EPOCH,
        validation_data=ds_test,
        verbose=1,
    )
    history_list.append(history)


In [None]:
# Labels for history_list entries, in the order the training cell trains the models.
experiment_names = ['no_aug', 'aug', 'no_cutmix', 'cutmix', 'aug_cutmix']


def plot_history(history, title):
    """Plot train/validation loss and accuracy curves of one training run side by side.

    Args:
        history: the object returned by `Model.fit`; its `.history` dict must contain
            'loss', 'val_loss', 'accuracy' and 'val_accuracy'.
        title: figure title (e.g. the experiment name).

    Returns:
        The created matplotlib Figure.
    """
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))

    ax_loss.plot(history.history['loss'], 'b-', label='loss')
    ax_loss.plot(history.history['val_loss'], 'r--', label='val_loss')
    ax_loss.set(xlabel='Epoch', ylabel='Loss', title='Loss')
    ax_loss.legend()

    ax_acc.plot(history.history['accuracy'], 'g-', label='accuracy')
    ax_acc.plot(history.history['val_accuracy'], 'k--', label='val_accuracy')
    ax_acc.set(xlabel='Epoch', ylabel='Accuracy', title='Accuracy')
    ax_acc.legend()

    fig.suptitle(title)
    return fig


# One figure per experiment (the original only plotted history_list[0]).
for name, history in zip(experiment_names, history_list):
    plot_history(history, name)
plt.show()