In [1]:
import os
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Concatenate
from tensorflow.keras.models import Model
from sklearn.model_selection import train_test_split

In [3]:
# Root folder of the dataset (Google Drive mount used by this notebook)
BASE_DIR = '/content/drive/MyDrive/open/'

# Dataset sub-folders, each resolved against BASE_DIR
TRAIN_INPUT_PATH, TRAIN_GT_PATH, TEST_INPUT_PATH = (
    os.path.join(BASE_DIR, subdir)
    for subdir in ('train_input', 'train_gt', 'test_input')
)

In [4]:
from tqdm import tqdm

# Image loading helper
def load_images(input_path, gt_path=None, target_size=(256, 256)):
    """Load, resize and normalize every image file found in ``input_path``.

    Files are processed in sorted filename order so the sample order is
    reproducible. Each image is read as 3-channel color with OpenCV, resized
    to ``target_size`` and scaled from [0, 255] to [0, 1]. Arrays are stored
    as float32 (dividing a uint8 image by ``255.0`` would silently produce
    float64 and double the memory footprint of the dataset).

    Args:
        input_path: Directory containing the input images.
        gt_path: Optional directory containing ground-truth images stored
            under the same filenames as the inputs. A falsy value means
            "no ground truth".
        target_size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        ``inputs`` when ``gt_path`` is falsy, otherwise ``(inputs, gts)``.
        Each is a float32 array with one entry per image file.

    Raises:
        OSError: If an input file (or its ground-truth counterpart) cannot be
            decoded by ``cv2.imread``.
    """
    def _read_normalized(path):
        image = cv2.imread(path, cv2.IMREAD_COLOR)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise OSError(f"Could not read image: {path}")
        image = cv2.resize(image, target_size)
        return image.astype(np.float32) / 255.0

    # Ignore sub-directories; only regular files are treated as images.
    filenames = sorted(
        name for name in os.listdir(input_path)
        if os.path.isfile(os.path.join(input_path, name))
    )

    input_images = []
    gt_images = []
    for filename in tqdm(filenames):
        input_images.append(_read_normalized(os.path.join(input_path, filename)))
        if gt_path:
            gt_images.append(_read_normalized(os.path.join(gt_path, filename)))

    inputs = np.array(input_images, dtype=np.float32)
    if gt_path:
        return inputs, np.array(gt_images, dtype=np.float32)
    return inputs

In [None]:

# Load the paired training images and the unlabeled test images
train_inputs, train_gts = load_images(input_path=TRAIN_INPUT_PATH, gt_path=TRAIN_GT_PATH)
test_inputs = load_images(input_path=TEST_INPUT_PATH)

  1%|          | 297/29603 [00:29<24:40, 19.80it/s]

In [None]:
# Use every loaded sample for training (no hold-out split is made here)
x_train = train_inputs
y_train = train_gts

# Report the shapes of the training arrays
print("Train data shape:", x_train.shape, y_train.shape)

In [None]:
# U-Net model definition
def build_unet(input_shape=(256, 256, 3)):
    """Build a U-Net style encoder/decoder that outputs a 3-channel image.

    The encoder applies four stages of two 3x3 ReLU convolutions followed by
    2x2 max pooling (64, 128, 256, 512 filters). A 1024-filter bottleneck
    follows, then four decoder stages that upsample by 2, concatenate the
    matching encoder feature map and apply two 3x3 ReLU convolutions
    (512, 256, 128, 64 filters). A final 1x1 convolution with sigmoid
    activation produces the 3-channel output.

    Args:
        input_shape: Shape passed to ``Input``.

    Returns:
        An uncompiled ``Model`` mapping the input tensor to the output tensor.
    """
    def double_conv(tensor, filters):
        # Two stacked 3x3 same-padded ReLU convolutions.
        for _ in range(2):
            tensor = Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
        return tensor

    inputs = Input(input_shape)

    # Contracting path: keep each stage's output for the skip connections.
    skip_features = []
    x = inputs
    for filters in (64, 128, 256, 512):
        x = double_conv(x, filters)
        skip_features.append(x)
        x = MaxPooling2D((2, 2))(x)

    # Bottleneck
    x = double_conv(x, 1024)

    # Expanding path: deepest skip connection is consumed first.
    for filters, skip in zip((512, 256, 128, 64), reversed(skip_features)):
        x = UpSampling2D((2, 2))(x)
        x = Concatenate()([x, skip])
        x = double_conv(x, filters)

    outputs = Conv2D(3, (1, 1), activation='sigmoid')(x)

    return Model(inputs, outputs)

In [None]:
# Build the network and configure training: Adam optimizer, MSE loss, MAE metric
model = build_unet()
model.compile(loss='mse', optimizer='adam', metrics=['mae'])


In [None]:
# Train on the full training set; `history` keeps the object returned by fit
history = model.fit(x=x_train, y=y_train, batch_size=16, epochs=2)


In [None]:
# Persist the trained model once training has finished.
# The path is relative, so the file lands in the kernel's working directory.
model.save(filepath='image_restoration_model_v1.h5')

In [None]:
# Run the trained model on the test inputs
preds = model.predict(x=test_inputs)

In [None]:
def ssim_score(true, pred):
    """Return the SSIM between ``true`` and ``pred`` over all color channels.

    The last axis is treated as the channel axis, and the data range handed
    to ``ssim`` is the spread of values found in ``pred`` (max minus min).

    NOTE(review): ``ssim`` is not imported in any visible cell; presumably it
    is ``skimage.metrics.structural_similarity`` -- confirm and import it
    before running this cell.
    """
    pred_span = pred.max() - pred.min()
    return ssim(true, pred, data_range=pred_span, channel_axis=-1)

def masked_ssim_score(true, pred, mask):
    """Return the SSIM computed only on pixels where ``mask`` is positive.

    ``true`` and ``pred`` are indexed with the boolean array ``mask > 0`` and
    the selected values are compared with ``ssim`` using the last axis as the
    channel axis. The data range is taken from the *whole* ``pred`` array
    (max minus min), not just from the masked pixels.

    With a 2-D ``(H, W)`` mask on ``(H, W, C)`` images the selection has shape
    ``(N, C)``; a mask with the same shape as the image would flatten the
    selection to 1-D instead.

    NOTE(review): ``ssim`` is not imported in any visible cell; presumably it
    is ``skimage.metrics.structural_similarity`` -- confirm.
    """
    selected = mask > 0
    pred_span = pred.max() - pred.min()

    # Compare only the pixels inside the lost region.
    return ssim(
        true[selected],
        pred[selected],
        data_range=pred_span,
        channel_axis=-1,
    )

def _as_uint8_image(image):
    """Return ``image`` as uint8; float images are assumed to be scaled to [0, 1]."""
    if np.issubdtype(image.dtype, np.floating):
        # Scale to [0, 255], round, and clip out-of-range values.
        return np.clip(np.rint(image * 255.0), 0, 255).astype(np.uint8)
    return image


def histogram_similarity(true, pred):
    """Return the correlation between the hue histograms of two BGR images.

    Both images are converted to HSV and a 180-bin histogram over [0, 180] is
    computed on the H channel, normalized, and compared with
    ``cv2.HISTCMP_CORREL``.

    Floating-point inputs are first converted to uint8 because this notebook
    stores images as floats in [0, 1], which ``cv2.cvtColor`` does not handle
    as this function needs: float64 input is rejected, and float32 input
    yields hue in [0, 360], which does not match the 180-bin range used here.
    Non-float inputs are passed through unchanged.

    Args:
        true: Reference BGR image (uint8, or float scaled to [0, 1]).
        pred: Predicted BGR image (uint8, or float scaled to [0, 1]).

    Returns:
        The histogram correlation returned by ``cv2.compareHist``.
    """
    # Convert BGR images to HSV (on uint8 data so that H spans 0..179).
    true_hsv = cv2.cvtColor(_as_uint8_image(true), cv2.COLOR_BGR2HSV)
    pred_hsv = cv2.cvtColor(_as_uint8_image(pred), cv2.COLOR_BGR2HSV)

    # Hue-channel histograms, normalized before comparison.
    hist_true = cv2.calcHist([true_hsv], [0], None, [180], [0, 180])
    hist_pred = cv2.calcHist([pred_hsv], [0], None, [180], [0, 180])
    hist_true = cv2.normalize(hist_true, hist_true).flatten()
    hist_pred = cv2.normalize(hist_pred, hist_pred).flatten()

    # Similarity between the histograms (correlation coefficient).
    similarity = cv2.compareHist(hist_true, hist_pred, cv2.HISTCMP_CORREL)
    return similarity


In [None]:
# Evaluation metrics aggregation
def evaluate_model(true_images, pred_images, min_masked_pixels=7):
    """Average SSIM, masked SSIM and hue-histogram similarity over image pairs.

    For every ``(true, pred)`` pair a lost-region mask is derived from
    ``true``: a pixel belongs to the lost region when *all* of its channels
    are exactly 0. The mask is therefore 2-D ``(H, W)``, so indexing an
    ``(H, W, C)`` image with it keeps the channel axis that
    ``masked_ssim_score`` relies on.

    Args:
        true_images: Iterable of reference images, channels on the last axis.
        pred_images: Iterable of predicted images, paired with
            ``true_images`` by position.
        min_masked_pixels: Minimum number of lost pixels required to compute
            the masked SSIM for a pair; pairs below it are skipped for that
            metric only. The default of 7 assumes ``ssim`` uses a 7-sample
            window, as skimage's ``structural_similarity`` does by default
            -- TODO confirm.

    Returns:
        Dict with keys ``"SSIM"``, ``"Masked SSIM"`` and
        ``"Histogram Similarity"`` holding the mean of each metric. A metric
        with no contributing pair is reported as NaN.
    """
    ssim_scores = []
    masked_ssim_scores = []
    hist_similarities = []

    for true, pred in zip(true_images, pred_images):
        # Per-pixel lost-region mask: every channel of the pixel is 0.
        mask = np.all(true == 0, axis=-1).astype(np.uint8)

        # SSIM over the whole image.
        ssim_scores.append(ssim_score(true, pred))

        # SSIM restricted to the lost region, when it is large enough.
        if int(mask.sum()) >= min_masked_pixels:
            masked_ssim_scores.append(masked_ssim_score(true, pred, mask))

        # Hue-histogram similarity.
        hist_similarities.append(histogram_similarity(true, pred))

    def _mean_or_nan(values):
        # np.mean([]) would warn and return NaN; make that case explicit.
        return np.mean(values) if values else np.nan

    return {
        "SSIM": _mean_or_nan(ssim_scores),
        "Masked SSIM": _mean_or_nan(masked_ssim_scores),
        "Histogram Similarity": _mean_or_nan(hist_similarities),
    }

# Score the predictions made earlier for the test inputs.
# NOTE(review): `test_inputs` are the images fed to the model, not ground
# truth, so these scores measure similarity to the inputs -- confirm intent.
results = evaluate_model(true_images=test_inputs, pred_images=preds)

# Show the averaged metrics
print("Evaluation Results:", results)


In [None]:
# Build the submission archive
import zipfile

SUBMISSION_DIR = os.path.join(BASE_DIR, 'submission')
submission_file = os.path.join(BASE_DIR, 'sample_submission.zip')

# NOTE(review): no visible cell writes the predictions into SUBMISSION_DIR;
# it is assumed to be populated elsewhere -- confirm.
# List the folder *before* opening the archive so that a missing or empty
# folder fails here instead of leaving an empty zip file behind.
submission_names = sorted(
    name for name in os.listdir(SUBMISSION_DIR)
    if os.path.isfile(os.path.join(SUBMISSION_DIR, name))
)
if not submission_names:
    raise FileNotFoundError(f"No files to submit in {SUBMISSION_DIR}")

with zipfile.ZipFile(submission_file, 'w') as zipf:
    for filename in submission_names:
        # Store each file at the archive root under its own name.
        zipf.write(os.path.join(SUBMISSION_DIR, filename), arcname=filename)

print(f"Submission file saved at {submission_file}")