In [None]:
import random
import numpy as np
import os

import torch
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import cv2

import zipfile

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print("Using device:", device)

# print(torch.cuda.is_available())  # check that this is True
# print(torch.cuda.device_count())  # number of available GPUs
# print(torch.cuda.current_device())  # index of the GPU currently in use
# print(torch.cuda.get_device_name(0))  # name of the GPU in use
# print(torch.version.cuda)  # CUDA version used by PyTorch

# origin_dir = 'train_gt'
# damage_dir = 'train_input'

# print(len(os.listdir(origin_dir)))
# print(len(os.listdir(damage_dir)))

# seed_everything(42)

# # Test random numbers
# print("Random:", random.random())
# print("NumPy Random:", np.random.rand(1))
# print("PyTorch Random:", torch.rand(1))

In [1]:
# Notebook-wide hyper-parameter settings, looked up by upper-case key.
CFG = dict(
    EPOCHS=2,
    LEARNING_RATE=3e-4,
    BATCH_SIZE=8,  # alternative value 16 is left disabled
    SEED=42,
)

In [4]:
# def seed_everything(seed):
#     random.seed(seed)
#     os.environ['PYTHONHASHSEED'] = str(seed)
#     np.random.seed(seed)
#     torch.manual_seed(seed)
#     torch.cuda.manual_seed(seed)
#     torch.backends.cudnn.deterministic = True  # guarantee deterministic operations
#     torch.backends.cudnn.benchmark = False     # favor consistency over performance tuning

# seed_everything(CFG['SEED'])  # fix the seed

In [5]:
def get_input_image(damage_img_path, origin_img_path):
    """Derive the damage mask for a (damaged, original) image pair.

    The original image is converted to grayscale and compared pixel-wise with
    the damaged image (read as grayscale). Every pixel whose absolute
    difference is greater than 1 is marked as damaged.

    Args:
        damage_img_path: Path of the damaged image; read as single-channel
            grayscale.
        origin_img_path: Path of the original image; read with OpenCV's
            default mode and converted with COLOR_BGR2GRAY.

    Returns:
        dict with two entries:
            'image_gray_masked': the damaged grayscale image as a PIL image.
            'mask': the 0/255 mask image converted by transforms.ToTensor().

    Raises:
        FileNotFoundError: if OpenCV cannot read one of the two files.
    """
    # cv2.imread does not raise on failure; it returns None. Check explicitly
    # so the error names the offending path instead of failing inside cvtColor.
    color_image = cv2.imread(origin_img_path)
    if color_image is None:
        raise FileNotFoundError(f"Could not read original image: {origin_img_path}")

    gray_image = cv2.imread(damage_img_path, cv2.IMREAD_GRAYSCALE)
    if gray_image is None:
        raise FileNotFoundError(f"Could not read damaged image: {damage_img_path}")

    # Bring the original into the same single-channel space as the damaged image.
    color_image_gray = cv2.cvtColor(color_image, cv2.COLOR_BGR2GRAY)

    # Per-pixel absolute difference between the two grayscale images.
    difference = cv2.absdiff(color_image_gray, gray_image)

    # Binarize: differences above 1 become 255, everything else 0.
    _, binary_difference = cv2.threshold(difference, 1, 255, cv2.THRESH_BINARY)

    # The thresholded array already holds 0/255 values, so it can be wrapped
    # as the mask image directly.
    mask = Image.fromarray(binary_difference)

    return {
        'image_gray_masked': Image.fromarray(gray_image),
        'mask': transforms.ToTensor()(mask)
    }

In [None]:
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    """Dataset of (damaged, original) image pairs read from two directories.

    The file names of each directory are sorted case-insensitively and paired
    by position, so the two directories must contain the same number of files.

    Args:
        damage_dir: Directory holding the damaged input images.
        origin_dir: Directory holding the original (ground-truth) images.
        transform: Optional callable applied to both images of a pair.
        use_masks: When True the damage mask is computed with
            get_input_image; otherwise an all-zero mask is returned.

    Raises:
        ValueError: if the two directories hold a different number of files.
    """

    def __init__(self, damage_dir, origin_dir, transform=None, use_masks=True):
        self.damage_dir = damage_dir
        self.origin_dir = origin_dir
        self.transform = transform
        self.use_masks = use_masks
        self.damage_files = sorted(os.listdir(damage_dir), key=lambda x: x.lower())
        self.origin_files = sorted(os.listdir(origin_dir), key=lambda x: x.lower())

        # Pairs are formed by index; a count mismatch would otherwise surface
        # as a late IndexError or as silently misaligned pairs.
        if len(self.damage_files) != len(self.origin_files):
            raise ValueError(
                f"File count mismatch: {len(self.damage_files)} in '{damage_dir}' "
                f"vs {len(self.origin_files)} in '{origin_dir}'"
            )

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.damage_files)

    def __getitem__(self, idx):
        """Return {'A': damaged image, 'B': original image, 'mask': mask} for pair idx."""
        damage_img_name = self.damage_files[idx]
        origin_img_name = self.origin_files[idx]

        damage_img_path = os.path.join(self.damage_dir, damage_img_name)
        origin_img_path = os.path.join(self.origin_dir, origin_img_name)

        damage_img = Image.open(damage_img_path).convert("RGB")
        origin_img = Image.open(origin_img_path).convert("RGB")

        if self.use_masks:
            input_data = get_input_image(damage_img_path, origin_img_path)
            mask = input_data['mask']
            # Defensive: convert only if the helper did not already return a tensor.
            if not isinstance(mask, torch.Tensor):
                mask = transforms.ToTensor()(mask)
        else:
            # PIL size is (width, height); the placeholder mask is (1, H, W).
            mask = torch.zeros((1, damage_img.size[1], damage_img.size[0]))

        if self.transform:
            damage_img = self.transform(damage_img)
            origin_img = self.transform(origin_img)

        return {'A': damage_img, 'B': origin_img, 'mask': mask}


In [None]:
from segmentation_models_pytorch import UnetPlusPlus
import torch.nn.functional as F
import lightning as L
from skimage.metrics import structural_similarity as ssim
import cv2
import numpy as np
import torch

# Histogram similarity between two images
def get_histogram_similarity(true_np, pred_np, color_space=cv2.COLOR_RGB2HSV):
    """Correlate the channel-0 histograms of two images after color conversion.

    Both inputs are cast to uint8, converted with `color_space`, and reduced to
    a 180-bin histogram of channel 0 over the range [0, 180). The normalized
    histograms are compared with cv2.HISTCMP_CORREL.

    Args:
        true_np: Reference image array.
        pred_np: Predicted image array.
        color_space: OpenCV color conversion code applied to both images.

    Returns:
        The value returned by cv2.compareHist.
    """
    def _first_channel_hist(image_np):
        # Same pipeline for both images: cast -> convert -> histogram -> normalize.
        converted = cv2.cvtColor(image_np.astype(np.uint8), color_space)
        hist = cv2.calcHist([converted], [0], None, [180], [0, 180])
        return cv2.normalize(hist, hist).flatten()

    reference_hist = _first_channel_hist(true_np)
    candidate_hist = _first_channel_hist(pred_np)
    return cv2.compareHist(reference_hist, candidate_hist, cv2.HISTCMP_CORREL)

# SSIM restricted to the masked region
def get_masked_ssim_score(true_np, pred_np, mask_np):
    """Compute SSIM using only the pixels selected by a mask.

    Args:
        true_np: Reference image array; its first two axes are compared with
            the mask shape (assumed to be height and width — TODO confirm).
        pred_np: Predicted image array, indexed with the same mask.
        mask_np: Mask array; pixels with a value > 0 are selected. A leading
            or trailing singleton axis on a 3-D mask is dropped.

    Returns:
        The SSIM value, or 0 when no pixel is selected or the result is NaN.
    """
    # Drop a singleton channel axis so the mask is 2-D.
    if mask_np.ndim == 3 and mask_np.shape[0] == 1:
        mask_np = mask_np.squeeze(0)
    elif mask_np.ndim == 3 and mask_np.shape[-1] == 1:
        mask_np = mask_np.squeeze(-1)

    # Nearest-neighbour resize keeps the mask values unchanged (no blending).
    if true_np.shape[:2] != mask_np.shape:
        mask_np = cv2.resize(mask_np, (true_np.shape[1], true_np.shape[0]), interpolation=cv2.INTER_NEAREST)

    selected = mask_np > 0
    true_masked = true_np[selected]
    pred_masked = pred_np[selected]

    if true_masked.size == 0 or pred_masked.size == 0:
        return 0

    ssim_value = ssim(
        true_masked, pred_masked, data_range=pred_masked.max() - pred_masked.min(), channel_axis=-1
    )
    # Match get_ssim_score: a NaN result (e.g. when data_range is 0) is
    # reported as 0 instead of propagating into aggregated metrics.
    if np.isnan(ssim_value):
        ssim_value = 0
    return ssim_value

# Whole-image SSIM
def get_ssim_score(true_np, pred_np):
    """Return the SSIM between two channel-last images, mapping NaN to 0.

    The data range handed to SSIM is the spread (max - min) of `pred_np`.
    """
    pred_spread = pred_np.max() - pred_np.min()
    score = ssim(true_np, pred_np, channel_axis=-1, data_range=pred_spread)
    return 0 if np.isnan(score) else score

# Lightning module definition
class LitIRModel(L.LightningModule):
    """Two-stage image restoration model.

    Stage 1 (`model_1`) predicts a residual that is added to the damaged
    grayscale input; stage 2 (`model_2`) maps the restored grayscale image to
    the final output image.

    Args:
        model_1: Network applied to the damaged grayscale input.
        model_2: Network applied to the restored grayscale image.
        image_mean: Stored as `self.image_mean`; not used inside this class.
        image_std: Stored as `self.image_std`; not used inside this class.
        learning_rate: Learning rate for AdamW. Defaults to 1e-5, the value
            that was previously hard-coded.
    """

    def __init__(self, model_1, model_2, image_mean=0.5, image_std=0.5, learning_rate=1e-5):
        super().__init__()
        self.model_1 = model_1
        self.model_2 = model_2
        self.image_mean = image_mean
        self.image_std = image_std
        self.learning_rate = learning_rate

    @staticmethod
    def _rgb_to_gray(images_rgb):
        """Collapse dim 1 of a batch to one luminance channel.

        Uses the 0.299/0.587/0.114 weights of OpenCV's RGB->gray conversion,
        which is what get_input_image compares the damaged image against.
        The weights sum to 1, so the result is the same whether or not the
        batch has been mean/std normalized.
        Assumes dim 1 holds R, G, B in that order — TODO confirm against the
        training transform.
        """
        weights = images_rgb.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
        return (images_rgb * weights).sum(dim=1, keepdim=True)

    def forward(self, images_gray_masked):
        """Return (restored grayscale, final restored image) for a damaged batch."""
        # model_1 output is treated as a residual on top of the input.
        images_gray_restored = self.model_1(images_gray_masked) + images_gray_masked
        images_restored = self.model_2(images_gray_restored)
        return images_gray_restored, images_restored

    def configure_optimizers(self):
        """Create the AdamW optimizer over all parameters."""
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

    def training_step(self, batch, batch_idx):
        """Compute the combined grayscale + final-image reconstruction loss."""
        # Channel mean turns the damaged input into a single-channel image.
        images_gray_masked = torch.mean(batch['A'], dim=1, keepdim=True)
        images_gt = batch['B']  # ground-truth image

        # Target for stage 1 is the grayscale of the ground truth. Comparing
        # against the damaged input would train model_1 towards a zero
        # residual, i.e. to leave the damage untouched.
        images_gray_gt = self._rgb_to_gray(images_gt)

        images_gray_restored, images_restored = self(images_gray_masked)

        # Each term is an equal blend of L1 and MSE.
        loss_pixel_gray = (
            F.l1_loss(images_gray_restored, images_gray_gt, reduction='mean') * 0.5 +
            F.mse_loss(images_gray_restored, images_gray_gt, reduction='mean') * 0.5
        )
        loss_pixel = (
            F.l1_loss(images_restored, images_gt, reduction='mean') * 0.5 +
            F.mse_loss(images_restored, images_gt, reduction='mean') * 0.5
        )
        loss = loss_pixel_gray * 0.5 + loss_pixel * 0.5

        # Per-batch console output plus Lightning logging.
        print(f"Batch {batch_idx}, Loss: {loss.item()}")
        self.log("train_loss", loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log and return the mean SSIM over all images of the batch."""
        images_gray_masked = torch.mean(batch['A'], dim=1, keepdim=True)
        images_gt = batch['B']  # ground-truth image

        images_gray_restored, images_restored = self(images_gray_masked)

        # Resize the prediction when its shape differs from the ground truth.
        if images_restored.shape != images_gt.shape:
            images_restored = torch.nn.functional.interpolate(
                images_restored, size=images_gt.shape[-2:], mode="bilinear", align_corners=False
            )

        # Channel-last NumPy arrays, as expected by get_ssim_score.
        images_restored_np = images_restored.detach().cpu().permute(0, 2, 3, 1).numpy()
        images_gt_np = images_gt.detach().cpu().permute(0, 2, 3, 1).numpy()

        # Score every image of the batch, not only the first one.
        ssim_scores = [
            get_ssim_score(gt_np, restored_np)
            for gt_np, restored_np in zip(images_gt_np, images_restored_np)
        ]
        total_ssim_score = float(np.mean(ssim_scores))
        self.log("val_ssim", total_ssim_score, on_step=False, on_epoch=True)
        return {"val_ssim": total_ssim_score}

# Model construction
# Stage 1: single-channel input -> single-channel output.
model_1 = UnetPlusPlus(
    in_channels=1,
    classes=1,
    encoder_name="efficientnet-b0",
    encoder_weights="imagenet",
)

# Stage 2: single-channel input -> three-channel output.
model_2 = UnetPlusPlus(
    in_channels=1,
    classes=3,
    encoder_name="efficientnet-b0",
    encoder_weights="imagenet",
)

# Wrap both stages in the Lightning module.
lit_ir_model = LitIRModel(model_1, model_2)


  from .autonotebook import tqdm as notebook_tqdm
Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth" to C:\Users\zqrc0/.cache\torch\hub\checkpoints\efficientnet-b0-355c32eb.pth
100%|██████████| 20.4M/20.4M [00:01<00:00, 10.9MB/s]


In [22]:
from torch.utils.data import DataLoader
import os
from PIL import Image
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
# Directory the restored test images are written to (created further below if missing).
output_dir = "output/04/"
# Checkpoint file passed to LitIRModel.load_from_checkpoint for inference.
model_path = 'saved_models/112302best_model-epoch=00-val_ssim=0.5549.ckpt'

class CustomDataset(Dataset):
    """Dataset of (damaged, original) image pairs that also reports file names.

    The file names of each directory are sorted case-insensitively and paired
    by position, so the two directories must contain the same number of files.

    Args:
        damage_dir: Directory holding the damaged input images.
        origin_dir: Directory holding the original (ground-truth) images.
        transform: Optional callable applied to both images of a pair.
        use_masks: When True the damage mask is computed with
            get_input_image; otherwise an all-zero mask is returned.

    Raises:
        ValueError: if the two directories hold a different number of files.
    """

    def __init__(self, damage_dir, origin_dir, transform=None, use_masks=True):
        self.damage_dir = damage_dir
        self.origin_dir = origin_dir
        self.transform = transform
        self.use_masks = use_masks
        self.damage_files = sorted(os.listdir(damage_dir), key=lambda x: x.lower())
        self.origin_files = sorted(os.listdir(origin_dir), key=lambda x: x.lower())

        # Pairs are formed by index; a count mismatch would otherwise surface
        # as a late IndexError or as silently misaligned pairs.
        if len(self.damage_files) != len(self.origin_files):
            raise ValueError(
                f"File count mismatch: {len(self.damage_files)} in '{damage_dir}' "
                f"vs {len(self.origin_files)} in '{origin_dir}'"
            )

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.damage_files)

    def __getitem__(self, idx):
        """Return the pair at idx as a dict with keys 'A', 'B', 'mask' and 'filename'."""
        damage_img_name = self.damage_files[idx]
        origin_img_name = self.origin_files[idx]

        damage_img_path = os.path.join(self.damage_dir, damage_img_name)
        origin_img_path = os.path.join(self.origin_dir, origin_img_name)

        damage_img = Image.open(damage_img_path).convert("RGB")
        origin_img = Image.open(origin_img_path).convert("RGB")

        if self.use_masks:
            input_data = get_input_image(damage_img_path, origin_img_path)
            mask = input_data['mask']
            # Defensive: convert only if the helper did not already return a tensor.
            if not isinstance(mask, torch.Tensor):
                mask = transforms.ToTensor()(mask)
        else:
            # PIL size is (width, height); the placeholder mask is (1, H, W).
            mask = torch.zeros((1, damage_img.size[1], damage_img.size[0]))

        if self.transform:
            damage_img = self.transform(damage_img)
            origin_img = self.transform(origin_img)

        # 'filename' is the damaged image's file name, used to name the output file.
        return {'A': damage_img, 'B': origin_img, 'mask': mask, 'filename': damage_img_name}


# Data preprocessing: tensor conversion followed by (x - 0.5) / 0.5 normalization.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

test_dir = 'data/test_input'

# Test dataset: there is no ground truth, so the input directory is passed for
# both sides and mask computation is disabled.
test_dataset = CustomDataset(damage_dir=test_dir, origin_dir=test_dir, transform=transform, use_masks=False)

# Batch size 1 so that every file is processed on its own.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False
)

# Restore the trained weights into the two-stage model.
model = LitIRModel.load_from_checkpoint(
    checkpoint_path=model_path,
    model_1=model_1,
    model_2=model_2
)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Inference only: move to the target device and switch to evaluation mode.
model.to(device)
model.eval()

os.makedirs(output_dir, exist_ok=True)

with torch.no_grad():
    for batch in test_dataloader:
        # Channel mean turns the loaded image into a single-channel input [N, 1, H, W].
        inputs = torch.mean(batch['A'], dim=1, keepdim=True).to(device)

        # Only the final restored image is saved; the intermediate grayscale is discarded.
        _, color_restored = model(inputs)

        # Undo the mean/std normalization and clamp to [0, 1] before scaling to
        # 8 bit. Without this, values outside [0, 1] wrap around in the uint8
        # cast. Assumes the checkpoint was trained on targets normalized with
        # the same mean/std as the transform above — TODO confirm.
        color_restored = (color_restored * model.image_std + model.image_mean).clamp(0.0, 1.0)

        # Save each prediction under the name of its input file.
        for filename, result in zip(batch['filename'], color_restored):
            # [C, H, W] -> [H, W, C], then scale to 0..255.
            result_img = (result.permute(1, 2, 0).cpu().numpy() * 255).round().astype(np.uint8)
            output_path = os.path.join(output_dir, filename)
            plt.imsave(output_path, result_img)
            print(f"Saved: {output_path}")

Saved: output/04/TEST_000.png
Saved: output/04/TEST_001.png
Saved: output/04/TEST_002.png
Saved: output/04/TEST_003.png
Saved: output/04/TEST_004.png
Saved: output/04/TEST_005.png
Saved: output/04/TEST_006.png
Saved: output/04/TEST_007.png
Saved: output/04/TEST_008.png
Saved: output/04/TEST_009.png
Saved: output/04/TEST_010.png
Saved: output/04/TEST_011.png
Saved: output/04/TEST_012.png
Saved: output/04/TEST_013.png
Saved: output/04/TEST_014.png
Saved: output/04/TEST_015.png
Saved: output/04/TEST_016.png
Saved: output/04/TEST_017.png
Saved: output/04/TEST_018.png
Saved: output/04/TEST_019.png
Saved: output/04/TEST_020.png
Saved: output/04/TEST_021.png
Saved: output/04/TEST_022.png
Saved: output/04/TEST_023.png
Saved: output/04/TEST_024.png
Saved: output/04/TEST_025.png
Saved: output/04/TEST_026.png
Saved: output/04/TEST_027.png
Saved: output/04/TEST_028.png
Saved: output/04/TEST_029.png
Saved: output/04/TEST_030.png
Saved: output/04/TEST_031.png
Saved: output/04/TEST_032.png
Saved: out