In [7]:
# !pip install segmentation-models-pytorch
# !pip install pytorch_msssim
# !pip install lightning
# !pip install lightning[extra]
# !pip install tensorboard
# !pip install scikit-image

In [2]:
# !nvcc --version

In [3]:
!pip install torch pillow torchvision opencv-python numpy

Collecting torch
  Downloading torch-2.5.1-cp39-cp39-manylinux1_x86_64.whl.metadata (28 kB)
Collecting pillow
  Downloading pillow-11.0.0-cp39-cp39-manylinux_2_28_x86_64.whl.metadata (9.1 kB)
Collecting torchvision
  Downloading torchvision-0.20.1-cp39-cp39-manylinux1_x86_64.whl.metadata (6.1 kB)
Collecting opencv-python
  Using cached opencv_python-4.10.0.84-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Collecting numpy
  Downloading numpy-2.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)
Collecting filelock (from torch)
  Using cached filelock-3.16.1-py3-none-any.whl.metadata (2.9 kB)
Collecting networkx (from torch)
  Downloading networkx-3.2.1-py3-none-any.whl.metadata (5.2 kB)
Collecting fsspec (from torch)
  Using cached fsspec-2024.10.0-py3-none-any.whl.metadata (11 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB

In [18]:
import random
import numpy as np
import os

import torch
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import cv2

import zipfile


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [19]:
print(torch.cuda.is_available())  # True인지 확인
print(torch.cuda.device_count())  # 사용 가능한 GPU 개수 확인
print(torch.cuda.current_device())  # 현재 사용 중인 GPU 인덱스 확인
print(torch.cuda.get_device_name(0))  # 사용 중인 GPU 이름 확인
print(torch.version.cuda)  # PyTorch가 사용하는 CUDA 버전 출력

True
1
0
NVIDIA GeForce RTX 4070 Laptop GPU
12.4


In [20]:
origin_dir = 'train_gt'
damage_dir = 'train_input'

print(len(os.listdir(origin_dir)))
print(len(os.listdir(damage_dir)))

29603
29603


In [21]:
CFG = {
    'EPOCHS':2,
    'LEARNING_RATE':3e-4,
    # 'BATCH_SIZE':16,
    'BATCH_SIZE':16,
    'SEED':42
}

In [22]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True  # 결정론적 연산 보장
    torch.backends.cudnn.benchmark = False     # 성능 최적화 대신 일관성 우선

seed_everything(CFG['SEED'])  # Seed 고정

In [23]:
seed_everything(42)

# 테스트 난수
print("Random:", random.random())
print("NumPy Random:", np.random.rand(1))
print("PyTorch Random:", torch.rand(1))

Random: 0.6394267984578837
NumPy Random: [0.37454012]
PyTorch Random: tensor([0.8823])


In [24]:
def get_input_image(image):
    # 이 함수는 손상된 이미지로부터 마스크를 생성하여 반환합니다.
    # 예시: 특정 조건에 따라 임의로 마스크 생성 (여기서는 단순히 모든 픽셀을 마스크로 가정)
    
    # 손상된 이미지의 그레이스케일 버전과 마스크 생성 (단순한 예제)
    image_gray = image.convert("L")
    mask = np.array(image_gray) > 128  # 임의의 임계값 기준으로 마스크 생성
    mask = Image.fromarray(mask.astype(np.uint8) * 255)  # 마스크 이미지를 PIL 형식으로 변환
    
    return {
        'image_gray_masked': image_gray,
        'mask': transforms.ToTensor()(mask)  # 마스크를 텐서로 변환하여 사용
    }

In [25]:
class CustomDataset(Dataset):
    def __init__(self, damage_dir, origin_dir, transform=None, use_masks=False):
        self.damage_dir = damage_dir
        self.origin_dir = origin_dir
        self.transform = transform
        self.use_masks = use_masks
        self.damage_files = sorted(os.listdir(damage_dir), key=lambda x: x.lower())
        self.origin_files = sorted(os.listdir(origin_dir), key=lambda x: x.lower())

    def __len__(self):
        return len(self.damage_files)

    def __getitem__(self, idx):
        damage_img_name = self.damage_files[idx]
        origin_img_name = self.origin_files[idx]

        damage_img_path = os.path.join(self.damage_dir, damage_img_name)
        origin_img_path = os.path.join(self.origin_dir, origin_img_name)

        damage_img = Image.open(damage_img_path).convert("RGB")
        origin_img = Image.open(origin_img_path).convert("RGB")

        if self.use_masks:
            input_data = get_input_image(damage_img)
            mask = transforms.ToTensor()(input_data['mask'])
        else:
            mask = torch.zeros((1, damage_img.size[1], damage_img.size[0]))

        if self.transform:
            damage_img = self.transform(damage_img)
            origin_img = self.transform(origin_img)

        return {'A': damage_img, 'B': origin_img, 'mask': mask}


In [28]:
from segmentation_models_pytorch import UnetPlusPlus
import torch.nn.functional as F
import lightning as L
from skimage.metrics import structural_similarity as ssim
import cv2
import numpy as np
import torch

def get_histogram_similarity(true_np, pred_np, color_space=cv2.COLOR_RGB2HSV):
    # BGR 이미지를 HSV로 변환
    true_hsv = cv2.cvtColor(true_np.astype(np.uint8), color_space)
    pred_hsv = cv2.cvtColor(pred_np.astype(np.uint8), color_space)
    
    # 히스토그램 계산 및 비교
    hist_true = cv2.calcHist([true_hsv], [0], None, [180], [0, 180])
    hist_pred = cv2.calcHist([pred_hsv], [0], None, [180], [0, 180])
    
    hist_true = cv2.normalize(hist_true, hist_true).flatten()
    hist_pred = cv2.normalize(hist_pred, hist_pred).flatten()
    
    similarity = cv2.compareHist(hist_true, hist_pred, cv2.HISTCMP_CORREL)
    return similarity

def get_masked_ssim_score(true_np, pred_np, mask_np):
    # true_np: (height, width, channels)
    # pred_np: (height, width, channels)
    # mask_np: (height, width) or (1, height, width)
    
    # mask_np 차원 변환
    if mask_np.ndim == 3 and mask_np.shape[0] == 1:
        mask_np = mask_np.squeeze(0)  # (1, height, width) -> (height, width)
    elif mask_np.ndim == 3 and mask_np.shape[-1] == 1:
        mask_np = mask_np.squeeze(-1)  # (height, width, 1) -> (height, width)
    
    # true_np와 mask_np의 크기가 맞지 않으면 리사이즈
    if true_np.shape[:2] != mask_np.shape:
        mask_np = cv2.resize(mask_np, (true_np.shape[1], true_np.shape[0]), interpolation=cv2.INTER_NEAREST)
    
    # 마스크가 적용된 부분만 추출
    true_masked = true_np[mask_np > 0]
    pred_masked = pred_np[mask_np > 0]

    if true_masked.size == 0 or pred_masked.size == 0:
        return 0  # 마스크된 부분이 없으면 SSIM을 0으로 처리

    # SSIM 계산
    ssim_value = ssim(
        true_masked, pred_masked, data_range=pred_masked.max() - pred_masked.min(), channel_axis=-1
    )
    return ssim_value


def ssim_score(true, pred):
    true_np = true.permute(0, 2, 3, 1).cpu().numpy()
    pred_np = pred.permute(0, 2, 3, 1).cpu().numpy()

    print(f"true_np shape: {true_np.shape}")
    print(f"pred_np shape: {pred_np.shape}")
    print(f"mask_np shape: {mask_np.shape}")

    scores = [
        ssim(t, p, channel_axis=-1, data_range=p.max() - p.min())
        for t, p in zip(true_np, pred_np)
    ]
    return np.mean(scores)

def get_ssim_score(true_np, pred_np):
    """
    두 이미지를 비교하여 SSIM(Structural Similarity Index Measure)을 계산합니다.
    NaN 값을 방지하기 위해 추가적인 조건을 추가합니다.

    Args:
        true_np (numpy.ndarray): Ground truth 이미지 (H, W, C 형태)
        pred_np (numpy.ndarray): 예측 이미지 (H, W, C 형태)

    Returns:
        float: SSIM 값 (NaN 발생 시 0으로 설정)
    """
    ssim_value = ssim(
        true_np, pred_np, channel_axis=-1, data_range=pred_np.max() - pred_np.min()
    )
    if np.isnan(ssim_value):
        ssim_value = 0  # NaN 발생 시 SSIM 값을 0으로 설정
    return ssim_value

# LitIRModel 클래스 정의
class LitIRModel(L.LightningModule):
    def __init__(self, model_1, model_2, image_mean=0.5, image_std=0.5):
        super().__init__()
        self.model_1 = model_1
        self.model_2 = model_2
        self.image_mean = image_mean
        self.image_std = image_std
        self.validation_outputs = []

    def forward(self, images_gray_masked):
        images_gray_restored = self.model_1(images_gray_masked) + images_gray_masked
        images_restored = self.model_2(images_gray_restored)
        return images_gray_restored, images_restored

    def unnormalize(self, output, round=False):
        image_restored = ((output * self.image_std + self.image_mean) * 255).clamp(0, 255)
        if round:
            image_restored = torch.round(image_restored)
        return image_restored

    def configure_optimizers(self):
        opt = torch.optim.AdamW(self.parameters(), lr=1e-5)
        return opt

    def training_step(self, batch, batch_idx):
        images_gray_masked = batch['A']  # 손상된 이미지
        images_gt = batch['B']  # Ground Truth 이미지
        masks = batch['masks']  # 마스크 (현재 사용 여부에 따라 다름)

        # 모델에 입력
        images_gray_restored, images_restored = self(images_gray_masked)

        # 손실 계산
        loss_pixel_gray = (
            F.l1_loss(images_gray_masked, images_gray_restored, reduction='mean') * 0.5 +
            F.mse_loss(images_gray_masked, images_gray_restored, reduction='mean') * 0.5
        )
        loss_pixel = (
            F.l1_loss(images_gt, images_restored, reduction='mean') * 0.5 +
            F.mse_loss(images_gt, images_restored, reduction='mean') * 0.5
        )
        loss = loss_pixel_gray * 0.5 + loss_pixel * 0.5

        # 로깅 (Batch와 손실 값 출력)
        print(f"Batch {batch_idx}, Loss: {loss.item()}")
        
        # 로그 기록
        self.log("train_loss", loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        # Extract data from batch
        images_gray_masked = batch['A']
        images_gt = batch['B']
        masks = batch['masks']

        # Forward pass
        images_gray_restored, images_restored = self(images_gray_masked)

        # Ground Truth와 복원된 이미지 크기 맞추기
        if images_restored.shape != images_gt.shape:
            images_restored = torch.nn.functional.interpolate(
                images_restored, size=images_gt.shape[-2:], mode="bilinear", align_corners=False
            )

        # NumPy 변환
        images_restored_np = images_restored.detach().cpu().permute(0, 2, 3, 1).float().numpy().astype(np.uint8)
        images_gt_np = images_gt.detach().cpu().permute(0, 2, 3, 1).float().numpy().astype(np.uint8)
        masks_np = masks.detach().cpu().numpy()

        # Metric 계산
        total_ssim_score = 0
        masked_ssim_score = 0
        hist_sim_score = 0
        for image_gt_np, image_restored_np, mask_np in zip(images_gt_np, images_restored_np, masks_np):
            total_ssim_score += get_ssim_score(image_gt_np, image_restored_np) / len(images_gt)
            masked_ssim_score += get_masked_ssim_score(image_gt_np, image_restored_np, mask_np) / len(images_gt)
            hist_sim_score += get_histogram_similarity(image_gt_np, image_restored_np) / len(images_gt)

        # 최종 점수
        score = total_ssim_score * 0.2 + masked_ssim_score * 0.4 + hist_sim_score * 0.4

        # 배치별 검증 결과 출력
        print(f"Validation Batch {batch_idx}, SSIM Score: {total_ssim_score:.4f}")

        # 로깅
        self.log("val_score", score, on_step=False, on_epoch=True)
        self.log("val_total_ssim_score", total_ssim_score, on_step=False, on_epoch=True)
        self.log("val_masked_ssim_score", masked_ssim_score, on_step=False, on_epoch=True)
        self.log("val_hist_sim_score", hist_sim_score, on_step=False, on_epoch=True)

    def on_validation_epoch_end(self):
        total_ssim_score = 0
        masked_ssim_score = 0
        hist_sim_score = 0

        for output in self.validation_outputs:
            images_gt = output["images_gt"]
            images_restored = output["images_restored"]
            masks = output["masks"]

            masks_np = masks.detach().cpu().numpy()
            images_gt_np = images_gt.detach().cpu().permute(0, 2, 3, 1).float().numpy().astype(np.uint8)
            images_restored_np = images_restored.detach().cpu().permute(0, 2, 3, 1).float().numpy().astype(np.uint8)

            for image_gt_np, image_restored_np, mask_np in zip(images_gt_np, images_restored_np, masks_np):
                total_ssim_score += get_ssim_score(image_gt_np, image_restored_np) / len(images_gt_np)
                masked_ssim_score += get_masked_ssim_score(image_gt_np, image_restored_np, mask_np) / len(images_gt_np)
                hist_sim_score += get_histogram_similarity(image_gt_np, image_restored_np, cv2.COLOR_RGB2HSV) / len(images_gt_np)

        score = total_ssim_score * 0.2 + masked_ssim_score * 0.4 + hist_sim_score * 0.4

        self.log(f"val_score", score, on_step=False, on_epoch=True)
        self.log(f"val_total_ssim_score", total_ssim_score, on_step=False, on_epoch=True)
        self.log(f"val_masked_ssim_score", masked_ssim_score, on_step=False, on_epoch=True)
        self.log(f"val_hist_sim_score", hist_sim_score, on_step=False, on_epoch=True)

        self.validation_outputs = []

# 모델 초기화
# 첫 번째 모델: Gray Mask Restoration
model_1 = UnetPlusPlus(
    encoder_name="efficientnet-b4",
    encoder_weights="imagenet",
    in_channels=3,
    classes=1
)

# 두 번째 모델: Gray → Color
model_2 = UnetPlusPlus(
    encoder_name="efficientnet-b4",
    encoder_weights="imagenet",
    in_channels=3,
    classes=3
)

# LitIRModel 초기화
lit_ir_model = LitIRModel(model_1=model_1, model_2=model_2)


In [29]:
from torch.utils.data import random_split, DataLoader
import torchvision.transforms as transforms
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
import lightning as L
import os
import torch

# 경로 설정
origin_dir = 'train_gt'
damage_dir = 'train_input'
test_dir = 'test_input'

# 데이터 전처리 설정
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # 모든 이미지를 256x256으로 리사이즈
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# 데이터셋 생성
dataset = CustomDataset(damage_dir=damage_dir, origin_dir=origin_dir, transform=transform)

# 데이터셋 분할
validation_ratio = 0.2
train_size = int((1 - validation_ratio) * len(dataset))
val_size = len(dataset) - train_size
training_dataset, validation_dataset = random_split(dataset, [train_size, val_size])

class CollateFn:
    def __init__(self, mode='train'):
        self.mode = mode

    def __call__(self, batch):
        A = torch.stack([item['A'] for item in batch])
        B = torch.stack([item['B'] for item in batch])
        masks = torch.stack([item['mask'] for item in batch]) if 'mask' in batch[0] else torch.zeros_like(A)

        if self.mode in ['train', 'valid']:
            return {'A': A, 'B': B, 'masks': masks}
        elif self.mode == 'test':
            return {'A': A}

# CollateFn 정의
train_collate_fn = CollateFn(mode='train')
validation_collate_fn = CollateFn(mode='valid')

# DataLoader 설정
train_dataloader = DataLoader(
    training_dataset,
    batch_size=CFG['BATCH_SIZE'],
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    collate_fn=train_collate_fn
)

validation_dataloader = DataLoader(
    validation_dataset,
    batch_size=CFG['BATCH_SIZE'],
    shuffle=False,
    num_workers=8,
    collate_fn=validation_collate_fn
)

# 모델 저장 디렉토리 생성
model_save_dir = "./saved_models"
os.makedirs(model_save_dir, exist_ok=True)

# Trainer 설정 및 학습 시작
trainer = L.Trainer(
    max_epochs=CFG['EPOCHS'],
    precision=16,
    accelerator='gpu',
    devices=1,
    log_every_n_steps=10,
    callbacks=[
        ModelCheckpoint(
            monitor='val_score',
            mode='max',
            save_top_k=1,
            dirpath=model_save_dir,
            filename='best_model-{epoch:02d}-{val_score:.4f}'
        ),
        EarlyStopping(monitor='val_score', mode='max', patience=3)
    ]
)

# 학습 시작
trainer.fit(lit_ir_model, train_dataloader, validation_dataloader)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/zqrc05/anaconda3/envs/imagepro/lib/python3.9/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/zqrc05/project/imagepro/saved_models exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type         | Params | Mode 
-------------------------------------------------
0 | model_1 | UnetPlusPlus | 20.8 M | train
1 | model_2 | UnetPlusPlus | 20.8 M | train
-------------------------------------------------
41.6 M    Trainable params
0         Non-trainable params
41.6 M    Total params
166.506   Total estimated model params size (MB)
1274      Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                    | 0/? [0…

Validation Batch 0, SSIM Score: 0.0000
Validation Batch 1, SSIM Score: 0.0001


Training: |                                                                                           | 0/? [0…

Batch 0, Loss: 0.8272649049758911
Batch 1, Loss: 0.8304033875465393
Batch 2, Loss: 0.8322097659111023
Batch 3, Loss: 0.8043810129165649
Batch 4, Loss: 0.7906537055969238
Batch 5, Loss: 0.8028305768966675
Batch 6, Loss: 0.796210527420044
Batch 7, Loss: 0.7989515662193298
Batch 8, Loss: 0.789950966835022
Batch 9, Loss: 0.775140643119812
Batch 10, Loss: 0.796289324760437
Batch 11, Loss: 0.7630201578140259
Batch 12, Loss: 0.7792887687683105
Batch 13, Loss: 0.7583441734313965
Batch 14, Loss: 0.761902928352356
Batch 15, Loss: 0.7608380317687988
Batch 16, Loss: 0.768484890460968
Batch 17, Loss: 0.7625593543052673
Batch 18, Loss: 0.7498657703399658
Batch 19, Loss: 0.7356884479522705
Batch 20, Loss: 0.7419204711914062
Batch 21, Loss: 0.7342249155044556
Batch 22, Loss: 0.7478272914886475
Batch 23, Loss: 0.7444553375244141
Batch 24, Loss: 0.7183412313461304
Batch 25, Loss: 0.7040150165557861
Batch 26, Loss: 0.7220981121063232
Batch 27, Loss: 0.7267026901245117
Batch 28, Loss: 0.7095220685005188
B

Validation: |                                                                                         | 0/? [0…

Validation Batch 0, SSIM Score: 0.8063
Validation Batch 1, SSIM Score: 0.8371
Validation Batch 2, SSIM Score: 0.8610
Validation Batch 3, SSIM Score: 0.9024
Validation Batch 4, SSIM Score: 0.8917
Validation Batch 5, SSIM Score: 0.9227
Validation Batch 6, SSIM Score: 0.8926
Validation Batch 7, SSIM Score: 0.8914
Validation Batch 8, SSIM Score: 0.8888
Validation Batch 9, SSIM Score: 0.8163
Validation Batch 10, SSIM Score: 0.7556
Validation Batch 11, SSIM Score: 0.9295
Validation Batch 12, SSIM Score: 0.8791
Validation Batch 13, SSIM Score: 0.8851
Validation Batch 14, SSIM Score: 0.8152
Validation Batch 15, SSIM Score: 0.7994
Validation Batch 16, SSIM Score: 0.9204
Validation Batch 17, SSIM Score: 0.7988
Validation Batch 18, SSIM Score: 0.7895
Validation Batch 19, SSIM Score: 0.9469
Validation Batch 20, SSIM Score: 0.8208
Validation Batch 21, SSIM Score: 0.8654
Validation Batch 22, SSIM Score: 0.8694
Validation Batch 23, SSIM Score: 0.9258
Validation Batch 24, SSIM Score: 0.8136
Validation

Validation: |                                                                                         | 0/? [0…

Validation Batch 0, SSIM Score: 0.7509
Validation Batch 1, SSIM Score: 0.8452
Validation Batch 2, SSIM Score: 0.8858
Validation Batch 3, SSIM Score: 0.9261
Validation Batch 4, SSIM Score: 0.8984
Validation Batch 5, SSIM Score: 0.8872
Validation Batch 6, SSIM Score: 0.8966
Validation Batch 7, SSIM Score: 0.9147
Validation Batch 8, SSIM Score: 0.8997
Validation Batch 9, SSIM Score: 0.8320
Validation Batch 10, SSIM Score: 0.7744
Validation Batch 11, SSIM Score: 0.9405
Validation Batch 12, SSIM Score: 0.8242
Validation Batch 13, SSIM Score: 0.8918
Validation Batch 14, SSIM Score: 0.7857
Validation Batch 15, SSIM Score: 0.8255
Validation Batch 16, SSIM Score: 0.9406
Validation Batch 17, SSIM Score: 0.8144
Validation Batch 18, SSIM Score: 0.8087
Validation Batch 19, SSIM Score: 0.8944
Validation Batch 20, SSIM Score: 0.8534
Validation Batch 21, SSIM Score: 0.8765
Validation Batch 22, SSIM Score: 0.8843
Validation Batch 23, SSIM Score: 0.9505
Validation Batch 24, SSIM Score: 0.8318
Validation

`Trainer.fit` stopped: `max_epochs=2` reached.


Validation Batch 369, SSIM Score: 0.9660
Validation Batch 370, SSIM Score: 0.9973


In [32]:
from torch.utils.data import DataLoader
import os
from PIL import Image
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# 데이터 전처리
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# 테스트 데이터셋 생성
test_dataset = CustomDataset(damage_dir=test_dir, origin_dir=test_dir, transform=transform, use_masks=False)

# 테스트 DataLoader 생성
test_dataloader = DataLoader(
    test_dataset,
    batch_size=CFG['BATCH_SIZE'],  # 적절히 설정
    shuffle=False,
    num_workers=2  # 경고를 피하기 위해 적절히 설정
)

# 모델 초기화
model = LitIRModel.load_from_checkpoint(
    checkpoint_path='saved_models/best_model-epoch=00-val_score=0.0000-v1.ckpt',
    model_1=model_1,
    model_2=model_2
)

# 모델을 평가 모드로 설정
model.eval()
model.to(device)

# 테스트 데이터로 예측 실행
output_dir = "output/"
os.makedirs(output_dir, exist_ok=True)

with torch.no_grad():
    for idx, batch in enumerate(test_dataloader):
        inputs = batch['A'].to(device)  # 손상된 이미지
        gray_restored, color_restored = model(inputs)  # 모델 예측

        # 예측된 이미지를 저장
        for i, result in enumerate(color_restored):
            result_img = (result.squeeze().cpu().numpy() * 255).astype(np.uint8)  # [0, 255]로 변환
            result_img = np.transpose(result_img, (1, 2, 0))  # [C, H, W] -> [H, W, C]
            output_path = os.path.join(output_dir, f"output_{idx}_{i}.png")
            plt.imsave(output_path, result_img)
            print(f"Saved: {output_path}")

Saved: output/output_0_0.png
Saved: output/output_0_1.png
Saved: output/output_0_2.png
Saved: output/output_0_3.png
Saved: output/output_0_4.png
Saved: output/output_0_5.png
Saved: output/output_0_6.png
Saved: output/output_0_7.png
Saved: output/output_0_8.png
Saved: output/output_0_9.png
Saved: output/output_0_10.png
Saved: output/output_0_11.png
Saved: output/output_0_12.png
Saved: output/output_0_13.png
Saved: output/output_0_14.png
Saved: output/output_0_15.png
Saved: output/output_1_0.png
Saved: output/output_1_1.png
Saved: output/output_1_2.png
Saved: output/output_1_3.png
Saved: output/output_1_4.png
Saved: output/output_1_5.png
Saved: output/output_1_6.png
Saved: output/output_1_7.png
Saved: output/output_1_8.png
Saved: output/output_1_9.png
Saved: output/output_1_10.png
Saved: output/output_1_11.png
Saved: output/output_1_12.png
Saved: output/output_1_13.png
Saved: output/output_1_14.png
Saved: output/output_1_15.png
Saved: output/output_2_0.png
Saved: output/output_2_1.png
Sa