In [1]:
from silence_tensorflow import silence_tensorflow
silence_tensorflow()
import warnings
warnings.filterwarnings('ignore')

import os
import cv2
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import random
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Concatenate, Dropout, Conv2DTranspose, concatenate
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras import backend as K
from tensorflow_examples.models.pix2pix import pix2pix
import matplotlib.pyplot as plt

# Fix one seed for every random number generator used in this notebook
# (NumPy, Python's `random`, and TensorFlow) so that runs are repeatable.
seed = 1004
np.random.seed(seed)
random.seed(seed)
tf.random.set_seed(seed)

In [2]:
def create_dataset(df, infer=False, img_size=(384, 384)):
    """Build a ``tf.data.Dataset`` that loads images (and masks) from file paths.

    Args:
        df: Table with an ``'img_path'`` column and, unless ``infer`` is true,
            a ``'mask_path'`` column. Only ``df[column].values`` is read.
        infer: When true, only images are loaded and the dataset yields single
            image tensors. Otherwise it yields ``(image, mask)`` pairs.
        img_size: ``(height, width)`` every image and mask is resized to.

    Returns:
        A dataset of float32 images scaled by 1/255, optionally zipped with
        float32 masks scaled by 1/255, in the same order as the rows of ``df``.
    """
    target_size = list(img_size)

    def _parse_image(img_path):
        img = tf.io.read_file(img_path)
        img = tf.image.decode_jpeg(img, channels=3)
        # The default (bilinear) resize already returns float32 in [0, 255],
        # so no separate dtype conversion is needed before scaling.
        img = tf.image.resize(img, target_size)
        return img / 255.0  # Normalize image to range [0, 1]

    def _parse_mask(mask_path):
        mask = tf.io.read_file(mask_path)
        mask = tf.image.decode_png(mask, channels=1)
        # Nearest-neighbour keeps the original label values; bilinear
        # interpolation would invent fractional values along object edges.
        mask = tf.image.resize(mask, target_size, method='nearest')
        # Nearest-neighbour resize preserves the integer dtype, so cast
        # explicitly before scaling.
        return tf.cast(mask, tf.float32) / 255.0  # Normalize mask to range [0, 1]

    # Parallel maps keep element order by default, so images and masks stay
    # paired when the two datasets are zipped below.
    img_dataset = tf.data.Dataset.from_tensor_slices(df['img_path'].values)
    img_dataset = img_dataset.map(_parse_image, num_parallel_calls=tf.data.AUTOTUNE)

    if infer:
        return img_dataset

    mask_dataset = tf.data.Dataset.from_tensor_slices(df['mask_path'].values)
    mask_dataset = mask_dataset.map(_parse_mask, num_parallel_calls=tf.data.AUTOTUNE)
    return tf.data.Dataset.zip((img_dataset, mask_dataset))

In [3]:
def augment_data(image, mask):
    """Randomly flip an image and its mask together.

    The image and mask are stacked along the channel axis before flipping so
    that a single random decision is applied to both. Flipping them with
    separate random calls would mirror one without the other and break the
    pixel-wise correspondence between image and label.

    Args:
        image: Image tensor with channels on the last axis.
        mask: Mask tensor with the same leading dimensions as ``image``.

    Returns:
        ``(image, mask)`` with identical horizontal/vertical flips applied and
        the original dtypes preserved.
    """
    # Prefer the static channel count so downstream layers still see a fully
    # defined shape; fall back to the dynamic one if it is unknown.
    img_channels = image.shape[-1]
    if img_channels is None:
        img_channels = tf.shape(image)[-1]

    stacked = tf.concat([image, tf.cast(mask, image.dtype)], axis=-1)
    stacked = tf.image.random_flip_left_right(stacked)
    stacked = tf.image.random_flip_up_down(stacked)

    flipped_image = stacked[..., :img_channels]
    flipped_mask = tf.cast(stacked[..., img_channels:], mask.dtype)
    return flipped_image, flipped_mask

def prepare_for_training(ds, cache=True, shuffle_buffer_size=1000, batch_size=32,
                         augment=True, shuffle=True):
    """Turn a dataset of ``(image, mask)`` pairs into an endless batched stream.

    Args:
        ds: Dataset yielding ``(image, mask)`` elements.
        cache: ``True`` to cache in memory, a string to cache to that file
            path, or a falsy value to skip caching.
        shuffle_buffer_size: Buffer size passed to ``Dataset.shuffle``.
        batch_size: Number of elements per batch.
        augment: Apply ``augment_data`` to every element. Pass ``False`` for
            validation/test data so evaluation is not randomly perturbed.
        shuffle: Shuffle elements before batching. Pass ``False`` for
            validation/test data to keep a stable order.

    Returns:
        The dataset after cache, shuffle, augmentation, batch, repeat and
        prefetch. Because of ``repeat()``, it never ends; callers must bound
        iteration themselves (e.g. ``steps_per_epoch``).
    """
    # Cache before the random steps so that cached elements are the raw
    # decoded samples and augmentation stays different on every pass.
    if cache:
        if isinstance(cache, str):
            ds = ds.cache(cache)
        else:
            ds = ds.cache()

    if shuffle:
        ds = ds.shuffle(buffer_size=shuffle_buffer_size)

    # Apply image augmentation.
    if augment:
        ds = ds.map(augment_data, num_parallel_calls=tf.data.AUTOTUNE)

    # Split into batches and repeat so data can be supplied indefinitely.
    ds = ds.batch(batch_size)
    ds = ds.repeat()

    # Prefetch so the next batch is prepared while the current one trains.
    ds = ds.prefetch(buffer_size=tf.data.AUTOTUNE)

    return ds

In [4]:
# Load the table of image/mask file paths.
df = pd.read_csv('image_mask_paths.csv')

# Compute the total size and the train/validation split sizes.
dataset_size = len(df)
train_size = int(0.8 * dataset_size)
val_size = dataset_size - train_size

# Shuffle the *paths* once with the fixed seed, then build the dataset.
# Shuffling the tf.data pipeline itself would reshuffle on every iteration, so
# `take`/`skip` would pick different samples each time and the training and
# validation sets would overlap. It would also buffer every decoded image.
shuffled_df = df.sample(frac=1, random_state=seed).reset_index(drop=True)
dataset = create_dataset(shuffled_df)
train_dataset = dataset.take(train_size)
val_dataset = dataset.skip(train_size)

batch_size = 32

# Training data: cache, shuffle, augment, batch, repeat, prefetch.
train_dataset = prepare_for_training(train_dataset, batch_size=batch_size)

# Validation data: batched and repeated like the training data, but without
# shuffling or random flips so that every epoch is scored on the same inputs.
val_dataset = (
    val_dataset
    .cache()
    .batch(batch_size)
    .repeat()
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

In [5]:
def UNet(input_shape=(384, 384, 3)):
    """Assemble a U-Net style encoder/decoder as a Keras functional ``Model``.

    The encoder has four stages of two 3x3 convolutions followed by 2x2 max
    pooling (64, 128, 256, 512 filters) and a 1024-filter bottleneck; dropout
    of 0.5 is applied after the fourth stage and after the bottleneck. The
    decoder upsamples by 2, applies a 2x2 convolution, concatenates the
    matching encoder output on the channel axis and applies two 3x3
    convolutions, for 512, 256, 128 and 64 filters. A 2-filter 3x3 convolution
    and a final 1x1 sigmoid convolution produce a single-channel output.

    Args:
        input_shape: Shape of one input sample, passed to ``Input``.

    Returns:
        The uncompiled ``Model`` mapping the input to the sigmoid output.
    """
    # Settings shared by every convolution except the final 1x1 layer.
    conv_kwargs = dict(activation='relu', padding='same', kernel_initializer='he_normal')

    def conv_pair(tensor, filters):
        # Two stacked 3x3 convolutions with the same number of filters.
        tensor = Conv2D(filters, 3, **conv_kwargs)(tensor)
        return Conv2D(filters, 3, **conv_kwargs)(tensor)

    def upsample_and_merge(tensor, skip, filters):
        # Upsample x2, apply a 2x2 convolution, then join the skip connection.
        reduced = Conv2D(filters, 2, **conv_kwargs)(UpSampling2D(size=(2, 2))(tensor))
        return concatenate([skip, reduced], axis=3)

    inputs = Input(input_shape)

    # Contraction path
    enc1 = conv_pair(inputs, 64)
    down1 = MaxPooling2D(pool_size=(2, 2))(enc1)

    enc2 = conv_pair(down1, 128)
    down2 = MaxPooling2D(pool_size=(2, 2))(enc2)

    enc3 = conv_pair(down2, 256)
    down3 = MaxPooling2D(pool_size=(2, 2))(enc3)

    enc4 = Dropout(0.5)(conv_pair(down3, 512))
    down4 = MaxPooling2D(pool_size=(2, 2))(enc4)

    bottleneck = Dropout(0.5)(conv_pair(down4, 1024))

    # Expansion path
    dec4 = conv_pair(upsample_and_merge(bottleneck, enc4, 512), 512)
    dec3 = conv_pair(upsample_and_merge(dec4, enc3, 256), 256)
    dec2 = conv_pair(upsample_and_merge(dec3, enc2, 128), 128)
    dec1 = conv_pair(upsample_and_merge(dec2, enc1, 64), 64)

    # Output head: 2-filter 3x3 convolution, then 1x1 sigmoid.
    head = Conv2D(2, 3, **conv_kwargs)(dec1)
    outputs = Conv2D(1, 1, activation='sigmoid')(head)

    return Model(inputs=inputs, outputs=outputs)

In [6]:
def dice_coef(y_true, y_pred, smooth=1):
    """Smoothed Dice coefficient computed per sample.

    Sums are taken over every axis except the first (batch) axis, so each
    sample gets one score based on its whole mask. Reducing only the last
    axis would, for single-channel masks, yield a separate ratio for every
    pixel instead of an overlap score for the mask.

    Args:
        y_true: Ground-truth tensor; cast to the dtype of ``y_pred``.
        y_pred: Prediction tensor with the same shape as ``y_true``.
        smooth: Constant added to numerator and denominator to avoid division
            by zero when both masks are empty.

    Returns:
        Tensor with one value per sample:
        ``(2 * sum(|t * p|) + smooth) / (sum(t^2) + sum(p^2) + smooth)``.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.cast(y_true, y_pred.dtype)

    # All non-batch axes; for rank <= 1 inputs fall back to the last axis.
    reduce_axes = list(range(1, len(y_pred.shape))) or [-1]

    intersection = tf.reduce_sum(tf.abs(y_true * y_pred), axis=reduce_axes)
    denominator = (tf.reduce_sum(tf.square(y_true), axis=reduce_axes)
                   + tf.reduce_sum(tf.square(y_pred), axis=reduce_axes))
    return (2. * intersection + smooth) / (denominator + smooth)
In [None]:
# Replicate the model across three GPUs; variables must be created (model
# built and compiled) inside the strategy scope.
strategy = tf.distribute.MirroredStrategy(devices=['/gpu:0', '/gpu:1', '/gpu:2'])
with strategy.scope():
    model = UNet()
    # UNet() ends in a sigmoid layer, so the loss receives probabilities, not
    # logits. `from_logits=True` would apply a second sigmoid inside the loss.
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005),
                  loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
                  metrics=[dice_coef])

    # Keep the best weights on disk, lower the learning rate when progress
    # stalls, and stop early if it does not recover.
    chk_point = tf.keras.callbacks.ModelCheckpoint('tf_unet.h5', save_best_only=True, save_weights_only=True, verbose=1)
    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(patience=5)
    early = tf.keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)

    # Both datasets repeat forever, so the step counts define an epoch. Use at
    # least one step so a split smaller than one batch is not silently skipped.
    history = model.fit(train_dataset,
                        validation_data=val_dataset,
                        epochs=100,
                        steps_per_epoch=max(1, train_size // batch_size),
                        validation_steps=max(1, val_size // batch_size),
                        callbacks=[chk_point, reduce_lr, early])

Epoch 1/100
