In [None]:
import os  # 운영 체제와 상호작용하기 위한 모듈
import numpy as np  # 수치 연산을 위한 Python 라이브러리
import tensorflow as tf  # 딥러닝 라이브러리 TensorFlow
from tensorflow.keras.preprocessing.image import load_img, img_to_array  # 이미지 불러오기 및 배열로 변환하기 위한 Keras 유틸리티
import matplotlib.pyplot as plt  # 데이터 시각화를 위한 라이브러리
from tensorflow.keras import mixed_precision  # 혼합 정밀도 계산을 위한 모듈
from tqdm import tqdm  # 진행 상황 표시를 위한 모듈
from tensorflow.keras.callbacks import ModelCheckpoint  # 학습 중간에 모델을 저장하기 위한 콜백 함수

# GPU selection: replace the "your_GPU" placeholder with a comma-separated list of
# device indices (for example "5,6,7"). It only takes effect if it is set before
# TensorFlow initialises the GPUs.
os.environ.update({"CUDA_VISIBLE_DEVICES": "your_GPU"})

def setup_gpus():
    """Enable on-demand memory growth on every GPU TensorFlow can see.

    If no GPU is visible this is a no-op. A ``RuntimeError`` from
    ``set_memory_growth`` (e.g. because the GPUs were already initialised)
    is printed instead of being propagated.
    """
    physical_gpus = tf.config.experimental.list_physical_devices('GPU')
    if not physical_gpus:
        return
    try:
        for device in physical_gpus:
            tf.config.experimental.set_memory_growth(device, True)
    except RuntimeError as err:
        print(err)

setup_gpus()

# Mixed precision
# Mixed precision combines 16-bit (float16) and 32-bit (float32) floating-point
# arithmetic during training to reduce GPU memory use and speed up computation.
mixed_precision.set_global_policy('mixed_float16')  # global Keras policy: float16 compute, float32 variables

# Hyperparameters
# These values strongly affect both model quality and training speed.
batch_size = 128        # number of samples per batch
img_size = (256, 256)   # (height, width) every image is resized to on load
latent_dim = 100        # latent-space size; not referenced elsewhere in this notebook section
n_critic = 10           # critic updates performed per generator update in train_step
lambda_gp = 10.0        # weight of the gradient-penalty term in the critic loss
epochs = 50             # number of passes over the training set

def make_dataset(original_dir, adjust_dir, batch_size=None):
    """Build batched train/validation/test datasets from two paired image folders.

    Every file ``<name>_rgb.png`` in ``original_dir`` is paired with
    ``<name>_rgb_adjusted_0.png`` in ``adjust_dir``; pairs with a missing file
    are reported and skipped. Pairs are split 70/20/10 in sorted filename
    order (no shuffling).

    Args:
        original_dir: Directory containing the original images.
        adjust_dir: Directory containing the adjusted images.
        batch_size: Batch size for all three datasets. ``None`` (the default)
            uses the module-level ``batch_size`` hyperparameter. (Previously
            this argument was accepted but silently ignored.)

    Returns:
        ``(train_dataset, val_dataset, test_dataset, train_samples,
        val_samples, test_samples)``. Each dataset yields
        ``(adjusted_batch, original_batch)`` tuples; each ``*_samples`` value
        is the number of file pairs assigned to that split.
    """
    dataset_pairs = []
    for original_image in sorted(os.listdir(original_dir)):
        base_name = original_image.replace("_rgb.png", "")
        adjust_image = base_name + "_rgb_adjusted_0.png"
        original_path = os.path.join(original_dir, original_image)
        adjust_path = os.path.join(adjust_dir, adjust_image)
        if os.path.exists(original_path) and os.path.exists(adjust_path):
            # Stored as (input, target): the adjusted image comes first, the
            # original second; the training loop unpacks them as (input_img, ref_img).
            dataset_pairs.append((adjust_path, original_path))
        else:
            print(f"Missing file: {original_path} or {adjust_path}")

    train_pairs, val_pairs, test_pairs = split_dataset(dataset_pairs)

    train_dataset, train_samples = make_tf_dataset(train_pairs, batch_size)
    val_dataset, val_samples = make_tf_dataset(val_pairs, batch_size)
    test_dataset, test_samples = make_tf_dataset(test_pairs, batch_size)

    return train_dataset, val_dataset, test_dataset, train_samples, val_samples, test_samples

def split_dataset(dataset_pairs, train_ratio=0.7, val_ratio=0.2):
    """Split a sequence into contiguous train / validation / test slices.

    The first ``int(len * train_ratio)`` items go to train, the next
    ``int(len * val_ratio)`` to validation, and the remainder to test.
    Order is preserved; nothing is shuffled.

    Args:
        dataset_pairs: Any sliceable sequence (e.g. a list of path pairs).
        train_ratio: Fraction of items for the training split (default 0.7).
        val_ratio: Fraction of items for the validation split (default 0.2).

    Returns:
        ``(train_pairs, val_pairs, test_pairs)`` slices of ``dataset_pairs``.

    Raises:
        ValueError: If a ratio is negative or the two ratios sum to more than 1.
    """
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1:
        raise ValueError(
            f"invalid split ratios: train_ratio={train_ratio}, val_ratio={val_ratio}"
        )
    total_len = len(dataset_pairs)
    train_end = int(total_len * train_ratio)
    val_end = train_end + int(total_len * val_ratio)
    return dataset_pairs[:train_end], dataset_pairs[train_end:val_end], dataset_pairs[val_end:]

def load_and_preprocess_image(image_path):
    """Load one image, resize it to ``img_size`` and scale pixels to [0, 1].

    Best-effort by design: any failure (unreadable file, NaN pixels, ...) is
    printed and ``None`` is returned so the caller can skip the sample.

    Returns:
        A float array of shape ``(img_size[0], img_size[1], channels)``, or
        ``None`` on failure.
    """
    try:
        image = load_img(image_path, target_size=img_size)
        image = img_to_array(image) / 255.0
        if np.isnan(image).any():
            raise ValueError(f"NaN detected in image: {image_path}")
        return image
    except Exception as e:
        print(f"Error loading image {image_path}: {e}")
        return None

def _resolve_batch_size(requested):
    """Return *requested*, or the module-level ``batch_size`` when it is None."""
    return batch_size if requested is None else requested

def make_tf_dataset(pairs, batch_size=None):
    """Wrap ``(input_path, target_path)`` pairs in a batched, cached tf.data pipeline.

    Images are loaded lazily through a Python generator; pairs where either
    image fails to load are skipped.

    Args:
        pairs: Sequence of ``(input_path, target_path)`` tuples.
        batch_size: Batch size; ``None`` uses the module-level ``batch_size``.

    Returns:
        ``(dataset, len(pairs))``. The count is the number of pairs, which
        can exceed the number of examples actually yielded when some images
        fail to load.
    """
    effective_batch_size = _resolve_batch_size(batch_size)

    def data_generator():
        for input_path, target_path in pairs:
            input_image = load_and_preprocess_image(input_path)
            target_image = load_and_preprocess_image(target_path)
            if input_image is None or target_image is None:
                continue  # failure already reported by the loader
            yield input_image, target_image

    image_spec = tf.TensorSpec(shape=(img_size[0], img_size[1], 3), dtype=tf.float32)
    dataset = tf.data.Dataset.from_generator(
        data_generator,
        output_signature=(image_spec, image_spec),
    )
    # cache() must come before prefetch(): with the old prefetch().cache()
    # order, later epochs read from the cache without any prefetching overlap.
    dataset = (
        dataset.batch(effective_batch_size)
        .cache()
        .prefetch(tf.data.experimental.AUTOTUNE)
    )
    return dataset, len(pairs)

# Dataset directories (placeholders: point these at the real folders)
original_dir = 'your_path'  # folder with the original images
adjust_dir = 'your_path'  # folder with the adjusted images

# Build the train / validation / test splits
(train_dataset, val_dataset, test_dataset,
 train_samples, val_samples, test_samples) = make_dataset(original_dir, adjust_dir)

# Generator definition
def build_generator(input_shape=(256, 256, 3)):
    """Build the generator: (input image, reference image) -> output image.

    The two inputs are concatenated along the channel axis and passed through
    an encoder/decoder of conv blocks (64-128-256-128-64 filters) with two 2x
    max-pool downsamplings and two 2x upsamplings, without skip connections.
    The output is a 3-channel image in [0, 1] (sigmoid).

    Args:
        input_shape: Shape of each input image, ``(height, width, channels)``.
            Height and width should be divisible by 4 so the two pool/upsample
            stages restore the original size. Defaults to ``(256, 256, 3)``.

    Returns:
        A ``tf.keras.Model`` taking ``[input_img, ref_img]``.
    """
    input_img = tf.keras.layers.Input(shape=input_shape)
    ref_img = tf.keras.layers.Input(shape=input_shape)

    x = tf.keras.layers.Concatenate()([input_img, ref_img])
    # Encoder
    x = conv_block(x, 64)
    x = tf.keras.layers.MaxPooling2D((2, 2))(x)
    x = conv_block(x, 128)
    x = tf.keras.layers.MaxPooling2D((2, 2))(x)
    x = conv_block(x, 256)
    # Decoder
    x = tf.keras.layers.UpSampling2D((2, 2))(x)
    x = conv_block(x, 128)
    x = tf.keras.layers.UpSampling2D((2, 2))(x)
    x = conv_block(x, 64)
    # Under the global 'mixed_float16' policy the final layer is kept in
    # float32, as the TF mixed-precision guide recommends for model outputs.
    output_img = tf.keras.layers.Conv2D(
        3, (3, 3), padding='same', activation='sigmoid', dtype='float32'
    )(x)

    return tf.keras.models.Model([input_img, ref_img], output_img)

def conv_block(x, filters):
    """Apply two stacked 3x3, same-padded, ReLU Conv2D layers with ``filters`` channels."""
    for _ in range(2):
        x = tf.keras.layers.Conv2D(filters, (3, 3), padding='same', activation='relu')(x)
    return x

# Critic definition
def build_critic(input_shape=(256, 256, 3)):
    """Build the WGAN critic: image -> unbounded scalar score.

    Conv blocks with 64/128/256 filters, each followed by 2x max-pooling, then
    a 512-filter conv block, Flatten, and a single linear Dense unit.

    Args:
        input_shape: Shape of the input image, ``(height, width, channels)``.
            Defaults to ``(256, 256, 3)``.

    Returns:
        A ``tf.keras.Model`` mapping an image batch to a ``(batch, 1)`` score.
    """
    input_img = tf.keras.layers.Input(shape=input_shape)

    x = input_img
    for filters in (64, 128, 256):
        x = conv_block(x, filters)
        x = tf.keras.layers.MaxPooling2D((2, 2))(x)
    x = conv_block(x, 512)
    x = tf.keras.layers.Flatten()(x)
    # The critic score is unbounded, so under the global 'mixed_float16' policy
    # the output layer is kept in float32 to avoid float16 overflow/precision
    # loss (TF mixed-precision guide: model outputs should be float32).
    output = tf.keras.layers.Dense(1, dtype='float32')(x)

    return tf.keras.models.Model(input_img, output)

# Loss function and gradient penalty
# The Wasserstein distance and the gradient penalty are the core ingredients of
# WGAN-GP and greatly improve the stability of adversarial training.

def wasserstein_loss(y_true, y_pred):
    """Return the Wasserstein loss, the mean of ``y_true * y_pred``.

    Not called by train_step/validation_step in this notebook section, which
    compute their losses inline.
    """
    product = y_true * y_pred
    return tf.reduce_mean(product)

def gradient_penalty(critic, real_images, fake_images):
    """Compute the WGAN-GP gradient penalty.

    Samples random points on the lines between real and fake images and
    penalises the critic's input-gradient norm for deviating from 1:
    ``mean((||grad critic(x_hat)||_2 - 1)^2)``.

    Args:
        critic: Callable mapping an image batch to scores.
        real_images: Real image batch; rank 4 is assumed (the norm reduces
            over axes 1, 2, 3).
        fake_images: Generated image batch with the same shape.

    Returns:
        A float32 scalar tensor.
    """
    # Cast before drawing alpha so alpha, real and fake are all float32.
    # Previously alpha used the pre-cast dtype, which breaks for float16 input.
    real_images = tf.cast(real_images, tf.float32)
    fake_images = tf.cast(fake_images, tf.float32)
    num_images = tf.shape(real_images)[0]
    alpha = tf.random.uniform([num_images, 1, 1, 1], 0.0, 1.0, dtype=tf.float32)
    interpolated = alpha * real_images + (1.0 - alpha) * fake_images

    with tf.GradientTape() as tape:
        tape.watch(interpolated)
        validity_interpolated = tf.cast(critic(interpolated), tf.float32)

    gradients = tape.gradient(validity_interpolated, interpolated)
    # The small epsilon keeps sqrt differentiable when the gradient norm is 0;
    # otherwise the second-order gradient becomes inf * 0 = NaN.
    squared_norm = tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3])
    gradients_norm = tf.sqrt(squared_norm + 1e-12)
    return tf.reduce_mean((gradients_norm - 1.0) ** 2)

@tf.function
def train_step(generator, critic, input_img, ref_img, optimizer_G, optimizer_C):
    """Run one generator update followed by ``n_critic`` critic updates.

    Generator loss: ``-mean(critic(G(input, ref)))``.
    Critic loss: ``mean(critic(fake)) - mean(critic(ref)) + lambda_gp * GP``,
    with ``ref_img`` treated as the real sample.

    Intended to be called through ``strategy.run``. Under a multi-replica
    strategy, gradients are SUM-aggregated across replicas by
    ``apply_gradients``, and ``train()`` SUM-reduces the returned losses. For
    that reason each replica's losses are scaled by ``1 / num_replicas_in_sync``.
    With a single device the factor is 1.

    NOTE(review): canonical WGAN-GP trains the critic before the generator;
    here the generator step comes first. Also, with the 'mixed_float16' policy
    a custom loop normally uses a LossScaleOptimizer; none is used here.

    Returns:
        ``(generator_loss, critic_loss)``. The critic loss is from the last
        critic iteration. Both are scaled by ``1 / num_replicas_in_sync``.
        Assumes ``n_critic >= 1``.
    """
    replica_scale = 1.0 / tf.distribute.get_strategy().num_replicas_in_sync

    # Generator update
    with tf.GradientTape() as tape:
        fake_images = tf.cast(generator([input_img, ref_img]), tf.float32)
        fake_validity = critic(fake_images)
        generator_loss = -tf.reduce_mean(tf.cast(fake_validity, tf.float32)) * replica_scale

    grads = tape.gradient(generator_loss, generator.trainable_variables)
    optimizer_G.apply_gradients(zip(grads, generator.trainable_variables))

    # The generator's weights do not change during the critic updates, and
    # build_generator's layers are deterministic. Its output is therefore
    # computed once instead of n_critic times.
    fake_images = tf.cast(generator([input_img, ref_img]), tf.float32)

    # Critic updates
    for _ in range(n_critic):
        with tf.GradientTape() as tape:
            real_validity = critic(ref_img)
            fake_validity = critic(fake_images)
            gp = gradient_penalty(critic, ref_img, fake_images)
            critic_loss = (tf.reduce_mean(tf.cast(fake_validity, tf.float32)) -
                           tf.reduce_mean(tf.cast(real_validity, tf.float32)) +
                           lambda_gp * tf.cast(gp, tf.float32)) * replica_scale

        grads = tape.gradient(critic_loss, critic.trainable_variables)
        optimizer_C.apply_gradients(zip(grads, critic.trainable_variables))

    return generator_loss, critic_loss

@tf.function
def validation_step(generator, critic, input_img, ref_img):
    """Compute generator and critic losses for one batch without updating weights.

    Uses the same loss definitions as ``train_step``. Each replica's losses
    are scaled by ``1 / num_replicas_in_sync`` so that the SUM reduction in
    ``validate_model`` yields the true per-batch loss. With a single device
    the factor is 1.

    Returns:
        ``(generator_loss, critic_loss)`` as scaled float32 scalars.
    """
    replica_scale = 1.0 / tf.distribute.get_strategy().num_replicas_in_sync

    fake_images = tf.cast(generator([input_img, ref_img]), tf.float32)
    real_validity = critic(ref_img)
    fake_validity = critic(fake_images)
    gp = gradient_penalty(critic, ref_img, fake_images)
    critic_loss = (tf.reduce_mean(tf.cast(fake_validity, tf.float32)) -
                   tf.reduce_mean(tf.cast(real_validity, tf.float32)) +
                   lambda_gp * tf.cast(gp, tf.float32))
    generator_loss = -tf.reduce_mean(tf.cast(fake_validity, tf.float32))

    return generator_loss * replica_scale, critic_loss * replica_scale

def train(generator, critic, train_dataset, val_dataset, epochs, batch_size):
    """Train the generator/critic pair and return per-epoch loss histories.

    Each batch is run through ``train_step`` via the module-level ``strategy``,
    and the per-replica losses are SUM-reduced. After every epoch the models
    are evaluated with ``validate_model`` and saved with ``save_model``.

    Args:
        generator: Generator model (built under ``strategy.scope()``).
        critic: Critic model (built under ``strategy.scope()``).
        train_dataset: Dataset yielding ``(input_img, ref_img)`` batches.
        val_dataset: Validation dataset with the same structure.
        epochs: Number of passes over ``train_dataset``.
        batch_size: Kept for interface compatibility; not used here, because
            batching happens when the datasets are built.

    Returns:
        ``(train_gen_losses, train_critic_losses, val_gen_losses,
        val_critic_losses)``: lists holding one mean loss per epoch.
    """
    # The optimizers are created here, so train() must be called inside
    # strategy.scope() for their variables to be mirrored.
    optimizer_G = tf.keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.5, beta_2=0.9)
    optimizer_C = tf.keras.optimizers.Adam(learning_rate=0.00005, beta_1=0.5, beta_2=0.9)

    train_gen_losses, train_critic_losses, val_gen_losses, val_critic_losses = [], [], [], []

    for epoch in range(epochs):
        epoch_gen_loss, epoch_critic_loss = [], []
        progress = tqdm(train_dataset, desc=f'Epoch {epoch+1}/{epochs}', ncols=100, unit='batch')

        for input_img, ref_img in progress:
            # NOTE(review): train_dataset is not passed through
            # strategy.experimental_distribute_dataset, so every replica gets
            # the same full batch (redundant work on multi-GPU setups).
            per_replica_g_loss, per_replica_c_loss = strategy.run(
                train_step, args=(generator, critic, input_img, ref_img, optimizer_G, optimizer_C))
            g_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_g_loss, axis=None)
            c_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_c_loss, axis=None)
            epoch_gen_loss.append(g_loss.numpy())
            epoch_critic_loss.append(c_loss.numpy())
            progress.set_postfix({'g_loss': np.mean(epoch_gen_loss), 'c_loss': np.mean(epoch_critic_loss)})

        mean_gen_loss = np.mean(epoch_gen_loss)
        mean_critic_loss = np.mean(epoch_critic_loss)
        train_gen_losses.append(mean_gen_loss)
        train_critic_losses.append(mean_critic_loss)

        val_gen_loss, val_critic_loss = validate_model(generator, critic, val_dataset)
        val_gen_losses.append(val_gen_loss)
        val_critic_losses.append(val_critic_loss)

        save_model(generator, critic, epoch)

        print(f"Epoch: {epoch+1}/{epochs}, Generator Loss: {mean_gen_loss}, Critic Loss: {mean_critic_loss}")
        print(f"Validation Generator Loss: {val_gen_loss}, Validation Critic Loss: {val_critic_loss}")

    return train_gen_losses, train_critic_losses, val_gen_losses, val_critic_losses

def validate_model(generator, critic, val_dataset):
    """Evaluate the models on ``val_dataset``.

    Runs ``validation_step`` on every batch via the module-level ``strategy``,
    SUM-reduces the per-replica losses, and averages them over batches.

    Returns:
        ``(mean_generator_loss, mean_critic_loss)``.
    """
    gen_history = []
    critic_history = []
    for batch_input, batch_ref in val_dataset:
        replica_g, replica_c = strategy.run(
            validation_step, args=(generator, critic, batch_input, batch_ref))
        gen_history.append(
            strategy.reduce(tf.distribute.ReduceOp.SUM, replica_g, axis=None).numpy())
        critic_history.append(
            strategy.reduce(tf.distribute.ReduceOp.SUM, replica_c, axis=None).numpy())

    return np.mean(gen_history), np.mean(critic_history)

def save_model(generator, critic, epoch, generator_path='your_path', critic_path='your_path'):
    """Save the generator and critic after an epoch.

    Each path is a ``str.format`` template that may contain ``{epoch}``,
    which is replaced by the ``epoch`` argument (``train`` passes the
    zero-based epoch index). Example: ``'generator_epoch{epoch}.keras'``.
    A path without a placeholder is used as-is, so each call overwrites the
    previous save. Literal braces in a path must be doubled (``{{``/``}}``).

    NOTE(review): with the default placeholder paths both models are written
    to the same location, so the critic overwrites the generator. Set
    distinct paths before training.
    """
    generator.save(generator_path.format(epoch=epoch))
    critic.save(critic_path.format(epoch=epoch))

# Distribution strategy
strategy = tf.distribute.MirroredStrategy()

# The models, and the optimizers that train() creates, are built inside the
# strategy scope so that their variables are mirrored across replicas.
with strategy.scope():
    generator = build_generator()
    critic = build_critic()

    for model in (generator, critic):
        model.summary()

    (train_gen_losses, train_critic_losses,
     val_gen_losses, val_critic_losses) = train(
        generator, critic, train_dataset, val_dataset, epochs, batch_size)

# Visualise training progress
def plot_losses(train_gen_losses, val_gen_losses, train_critic_losses, val_critic_losses):
    """Plot the per-epoch training and validation losses on a single figure."""
    curves = (
        (train_gen_losses, 'Train Generator Loss'),
        (val_gen_losses, 'Validation Generator Loss'),
        (train_critic_losses, 'Train Critic Loss'),
        (val_critic_losses, 'Validation Critic Loss'),
    )
    plt.figure(figsize=(10, 5))
    for values, label in curves:
        plt.plot(values, label=label)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend(loc='upper right')
    plt.title('Training and Validation Loss')
    plt.show()

plot_losses(train_gen_losses, val_gen_losses, train_critic_losses, val_critic_losses)