# In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
import matplotlib.pyplot as plt

def _double_conv(x, filters):
    """Apply two 3x3, same-padded, ReLU Conv2D layers with ``filters`` channels."""
    x = layers.Conv2D(filters, 3, activation='relu', padding='same')(x)
    return layers.Conv2D(filters, 3, activation='relu', padding='same')(x)


def deep_unet_model(input_shape):
    """Build a convolutional encoder-decoder that maps an image to a 1-channel map.

    The network applies three 2x2 max-poolings followed by three 2x upsamplings,
    so internally it needs a height and width divisible by 8. Inputs whose sides
    are not multiples of 8 (e.g. 28x28 MNIST) are zero-padded symmetrically
    before the encoder, and the same padding is cropped off the output. The
    result therefore has exactly the input's height and width and is
    pixel-aligned with it.

    Note: despite the name, no skip connections link encoder and decoder stages
    (unlike a classic U-Net); this is a plain encoder-decoder.

    Args:
        input_shape: ``(height, width, channels)`` of a single image
            (channels-last). Height and width must be known, not ``None``.

    Returns:
        A ``tf.keras`` ``Model`` mapping ``(batch, height, width, channels)`` to
        ``(batch, height, width, 1)`` with sigmoid outputs in [0, 1].

    Raises:
        ValueError: if the height or width in ``input_shape`` is ``None``.
    """
    height, width = input_shape[0], input_shape[1]
    if height is None or width is None:
        raise ValueError(
            f'deep_unet_model needs a fixed height and width, got {input_shape!r}')

    # Three 2x2 poolings -> spatial sides must be divisible by 2**3 = 8.
    # Without padding, pooling floors odd sizes (28 -> 14 -> 7 -> 3 drops the
    # last row/column of the 7x7 map), and the decoder can only produce 24x24.
    factor = 2 ** 3
    pad_h, pad_w = -height % factor, -width % factor
    pad = ((pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2))
    needs_pad = pad_h > 0 or pad_w > 0

    inputs = tf.keras.Input(shape=input_shape)
    x = layers.ZeroPadding2D(padding=pad)(inputs) if needs_pad else inputs

    # Encoder
    conv1 = _double_conv(x, 64)
    pool1 = layers.MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = _double_conv(pool1, 128)
    pool2 = layers.MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = _double_conv(pool2, 256)
    pool3 = layers.MaxPooling2D(pool_size=(2, 2))(conv3)

    # Bottleneck (lowest resolution, widest features)
    conv4 = _double_conv(pool3, 512)

    # Decoder (no encoder features are concatenated in)
    up1 = layers.UpSampling2D(size=(2, 2))(conv4)
    conv5 = _double_conv(up1, 256)

    up2 = layers.UpSampling2D(size=(2, 2))(conv5)
    conv6 = _double_conv(up2, 128)

    up3 = layers.UpSampling2D(size=(2, 2))(conv6)
    conv7 = _double_conv(up3, 64)

    # Output layer: a single channel squashed to [0, 1].
    outputs = layers.Conv2D(1, 3, activation='sigmoid', padding='same')(conv7)

    # Undo the input padding so the output is pixel-aligned with the input.
    # This replaces a bilinear tf.image.resize, which stretched the map and
    # cannot be applied to symbolic tensors in Keras 3 functional models.
    if needs_pad:
        outputs = layers.Cropping2D(cropping=pad)(outputs)

    return models.Model(inputs=inputs, outputs=outputs)


# Load the MNIST dataset. Labels are discarded: the model is trained to
# reproduce its own input image.
(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
print(x_train.shape, x_test.shape)


def plot_images(images, title):
    """Show up to the first five single-channel images of ``images``.

    Args:
        images: array indexed as ``images[i, :, :, 0]``, i.e. (N, H, W, C).
        title: figure title.
    """
    plt.figure(figsize=(10, 5))
    for i in range(min(5, len(images))):
        plt.subplot(2, 5, i + 1)
        plt.imshow(images[i, :, :, 0], cmap='gray')
        plt.axis('off')
    plt.suptitle(title)
    plt.show()


def train_and_compare(train_images, test_images, epochs=10, batch_size=64,
                      initial_learning_rate=1e-3):
    """Train ``deep_unet_model`` to reconstruct images, then plot inputs vs outputs.

    Not executed on import (training takes a while); call it explicitly, e.g.
    ``model, predicted = train_and_compare(x_train, x_test)``.

    Args:
        train_images, test_images: raw 0-255 image arrays of shape (N, H, W),
            such as those returned by ``tf.keras.datasets.mnist.load_data()``.
        epochs, batch_size: forwarded to ``Model.fit``.
        initial_learning_rate: starting value of the exponentially decaying
            Adam learning rate. Defaults to Adam's usual 1e-3; values around
            0.1 are far too large for Adam and tend to diverge.

    Returns:
        ``(model, predicted_images)``, where ``predicted_images`` has shape
        (N_test, H, W, 1).
    """
    # Scale to [0, 1] (the sigmoid output range / BCE target) and add a
    # trailing channel axis.
    train_images = np.expand_dims(train_images.astype('float32') / 255.0, axis=-1)
    test_images = np.expand_dims(test_images.astype('float32') / 255.0, axis=-1)

    model = deep_unet_model(train_images.shape[1:])

    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate, decay_steps=1000, decay_rate=0.9, staircase=True)
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    # 'accuracy' is not meaningful for continuous pixel targets; track MSE.
    model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['mse'])
    model.fit(train_images, train_images, epochs=epochs, batch_size=batch_size,
              validation_data=(test_images, test_images))

    # deep_unet_model already returns maps at the input resolution, so no
    # cropping is needed (a crop of 0 would make ``a[0:-0]`` an empty slice).
    predicted_images = model.predict(test_images)

    # Compare original and reconstructed images.
    # NOTE(review): Hangul titles need a Korean-capable matplotlib font.
    plot_images(test_images, title='원본 이미지')
    plot_images(predicted_images, title='향상된 이미지')
    return model, predicted_images
