In [7]:
# !pip install segmentation-models-pytorch
# !pip install pytorch_msssim
# !pip install lightning
# !pip install lightning[extra]
# !pip install tensorboard
# !pip install scikit-image

In [2]:
# !nvcc --version

In [3]:
# !pip install torch pillow torchvision opencv-python numpy

Collecting torch
  Downloading torch-2.5.1-cp39-cp39-manylinux1_x86_64.whl.metadata (28 kB)
Collecting pillow
  Downloading pillow-11.0.0-cp39-cp39-manylinux_2_28_x86_64.whl.metadata (9.1 kB)
Collecting torchvision
  Downloading torchvision-0.20.1-cp39-cp39-manylinux1_x86_64.whl.metadata (6.1 kB)
Collecting opencv-python
  Using cached opencv_python-4.10.0.84-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Collecting numpy
  Downloading numpy-2.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)
Collecting filelock (from torch)
  Using cached filelock-3.16.1-py3-none-any.whl.metadata (2.9 kB)
Collecting networkx (from torch)
  Downloading networkx-3.2.1-py3-none-any.whl.metadata (5.2 kB)
Collecting fsspec (from torch)
  Using cached fsspec-2024.10.0-py3-none-any.whl.metadata (11 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB

In [1]:
import random
import numpy as np
import os

import torch
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import cv2

import zipfile

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [2]:
print(torch.cuda.is_available())  # True인지 확인
print(torch.cuda.device_count())  # 사용 가능한 GPU 개수 확인
print(torch.cuda.current_device())  # 현재 사용 중인 GPU 인덱스 확인
print(torch.cuda.get_device_name(0))  # 사용 중인 GPU 이름 확인
print(torch.version.cuda)  # PyTorch가 사용하는 CUDA 버전 출력

True
1
0
NVIDIA GeForce RTX 4070 Laptop GPU
12.4


In [3]:
origin_dir = 'train_gt'
damage_dir = 'train_input'

print(len(os.listdir(origin_dir)))
print(len(os.listdir(damage_dir)))

29603
29603


In [4]:
CFG = {
    'EPOCHS':2,
    'LEARNING_RATE':3e-4,
    # 'BATCH_SIZE':16,
    'BATCH_SIZE':16,
    'SEED':42
}

In [5]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True  # 결정론적 연산 보장
    torch.backends.cudnn.benchmark = False     # 성능 최적화 대신 일관성 우선

seed_everything(CFG['SEED'])  # Seed 고정

In [6]:
seed_everything(42)

# 테스트 난수
print("Random:", random.random())
print("NumPy Random:", np.random.rand(1))
print("PyTorch Random:", torch.rand(1))

Random: 0.6394267984578837
NumPy Random: [0.37454012]
PyTorch Random: tensor([0.8823])


In [7]:
def get_input_image(damage_img_path, origin_img_path):
    # OpenCV로 이미지 읽기 (NumPy 배열로 읽음)
    color_image = cv2.imread(origin_img_path)
    gray_image = cv2.imread(damage_img_path, cv2.IMREAD_GRAYSCALE)  # 흑백 이미지로 읽기
    
    # 색상 이미지를 흑백으로 변환 (PIL로 변환 후 NumPy로 변환)
    color_image_gray = cv2.cvtColor(color_image, cv2.COLOR_BGR2GRAY)
    
    # 두 이미지의 차이 계산
    difference = cv2.absdiff(color_image_gray, gray_image)
    
    # 차이 값을 임계값으로 처리하여 이진화 이미지 생성
    _, binary_difference = cv2.threshold(difference, 1, 255, cv2.THRESH_BINARY)

    # 마스크 생성
    mask = binary_difference > 0  # 차이가 있는 부분을 마스크로 설정
    mask = Image.fromarray(mask.astype(np.uint8) * 255)  # 마스크 이미지를 PIL 형식으로 변환

    return {
        'image_gray_masked': Image.fromarray(gray_image),  # 손상된 이미지를 PIL 이미지로 반환
        'mask': transforms.ToTensor()(mask)  # 마스크를 텐서로 변환하여 사용
    }

In [8]:
class CustomDataset(Dataset):
    def __init__(self, damage_dir, origin_dir, transform=None, use_masks=True):
        self.damage_dir = damage_dir
        self.origin_dir = origin_dir
        self.transform = transform
        self.use_masks = use_masks
        self.damage_files = sorted(os.listdir(damage_dir), key=lambda x: x.lower())
        self.origin_files = sorted(os.listdir(origin_dir), key=lambda x: x.lower())

    def __len__(self):
        return len(self.damage_files)

    def __getitem__(self, idx):
        damage_img_name = self.damage_files[idx]
        origin_img_name = self.origin_files[idx]

        damage_img_path = os.path.join(self.damage_dir, damage_img_name)
        origin_img_path = os.path.join(self.origin_dir, origin_img_name)

        damage_img = Image.open(damage_img_path).convert("RGB")
        origin_img = Image.open(origin_img_path).convert("RGB")

        if self.use_masks:
            input_data = get_input_image(damage_img_path, origin_img_path)
            mask = input_data['mask']
            # `mask`가 이미 텐서인지 확인하고 변환 처리
            if not isinstance(mask, torch.Tensor):
                mask = transforms.ToTensor()(mask)
        else:
            mask = torch.zeros((1, damage_img.size[1], damage_img.size[0]))

        if self.transform:
            damage_img = self.transform(damage_img)
            origin_img = self.transform(origin_img)

        return {'A': damage_img, 'B': origin_img, 'mask': mask}


In [9]:
# todo:
# 1. val_score
# 2. get_input_image(image) 테스트해서 원본 (손상된 흑백)과 마스크(512*512로 나오되 배경은 다 0, 마스크부분은 회색)
# 3. loss 잘못됨 (아마 컬러를 흑백으로 바꾸면 걔를 정답으로 써도 될듯)

In [10]:
# training_step의 loss계산부분 변경 -> 원본 컬러사진을 흑백으로 변환해서 비고하는 것으로
# validation_step의 score 반환부분 변경 -> 학습완료시 최적의 모델 저장이 되지않는 부분으로 실제 학습에 적용이 되는지 확인필요

In [30]:
from segmentation_models_pytorch import UnetPlusPlus
import torch.nn.functional as F
import lightning as L
from skimage.metrics import structural_similarity as ssim
import cv2
import numpy as np
import torch

# 히스토그램 유사도 계산 함수
def get_histogram_similarity(true_np, pred_np, color_space=cv2.COLOR_RGB2HSV):
    true_hsv = cv2.cvtColor(true_np.astype(np.uint8), color_space)
    pred_hsv = cv2.cvtColor(pred_np.astype(np.uint8), color_space)
    
    hist_true = cv2.calcHist([true_hsv], [0], None, [180], [0, 180])
    hist_pred = cv2.calcHist([pred_hsv], [0], None, [180], [0, 180])
    
    hist_true = cv2.normalize(hist_true, hist_true).flatten()
    hist_pred = cv2.normalize(hist_pred, hist_pred).flatten()
    
    similarity = cv2.compareHist(hist_true, hist_pred, cv2.HISTCMP_CORREL)
    return similarity

# 마스크된 부분만 SSIM 계산
def get_masked_ssim_score(true_np, pred_np, mask_np):
    if mask_np.ndim == 3 and mask_np.shape[0] == 1:
        mask_np = mask_np.squeeze(0)
    elif mask_np.ndim == 3 and mask_np.shape[-1] == 1:
        mask_np = mask_np.squeeze(-1)
    
    if true_np.shape[:2] != mask_np.shape:
        mask_np = cv2.resize(mask_np, (true_np.shape[1], true_np.shape[0]), interpolation=cv2.INTER_NEAREST)
    
    true_masked = true_np[mask_np > 0]
    pred_masked = pred_np[mask_np > 0]

    if true_masked.size == 0 or pred_masked.size == 0:
        return 0

    ssim_value = ssim(
        true_masked, pred_masked, data_range=pred_masked.max() - pred_masked.min(), channel_axis=-1
    )
    return ssim_value

# SSIM 계산 함수
def get_ssim_score(true_np, pred_np):
    ssim_value = ssim(
        true_np, pred_np, channel_axis=-1, data_range=pred_np.max() - pred_np.min()
    )
    if np.isnan(ssim_value):
        ssim_value = 0
    return ssim_value

# Lightning Module 정의
class LitIRModel(L.LightningModule):
    def __init__(self, model_1, model_2, image_mean=0.5, image_std=0.5):
        super().__init__()
        self.model_1 = model_1
        self.model_2 = model_2
        self.image_mean = image_mean
        self.image_std = image_std

    def forward(self, images_gray_masked):
        images_gray_restored = self.model_1(images_gray_masked) + images_gray_masked
        images_restored = self.model_2(images_gray_restored)
        return images_gray_restored, images_restored

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-4)

    def training_step(self, batch, batch_idx):
        # 채널 확인
        assert batch['A'].shape[1] == 3, "batch['A']는 3채널 RGB 이미지여야 합니다."
        assert batch['B'].shape[1] == 3, "batch['B']는 3채널 Ground Truth 이미지여야 합니다."

        # 손상된 이미지 (Gray Scale)
        images_gray_masked = torch.mean(batch['A'], dim=1, keepdim=True)
        images_gt = batch['B']
        
        # 모델 입력
        images_gray_restored, images_restored = self(images_gray_masked)
        
        # 손실 계산
        loss_pixel_gray = (
            F.l1_loss(images_gray_masked, images_gray_restored, reduction='mean') * 0.5 +
            F.mse_loss(images_gray_masked, images_gray_restored, reduction='mean') * 0.5
        )
        loss_pixel = (
            F.l1_loss(images_gt, images_restored, reduction='mean') * 0.5 +
            F.mse_loss(images_gt, images_restored, reduction='mean') * 0.5
        )
        loss = loss_pixel_gray * 0.5 + loss_pixel * 0.5
        
        # 10번째 배치마다 로그 출력
        if batch_idx % 10 == 0:
            print(f"Epoch: {self.current_epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}")

        # 로깅
        self.log("train_loss", loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images_gray_masked = torch.mean(batch['A'], dim=1, keepdim=True)
        images_gt = batch['B']

        images_gray_restored, images_restored = self(images_gray_masked)

        # Ground Truth와 복원된 이미지 크기 맞추기
        if images_restored.shape != images_gt.shape:
            images_restored = torch.nn.functional.interpolate(
                images_restored, size=images_gt.shape[-2:], mode="bilinear", align_corners=False
            )

        # NumPy 변환
        images_restored_np = images_restored.detach().cpu().permute(0, 2, 3, 1).numpy()
        images_gt_np = images_gt.detach().cpu().permute(0, 2, 3, 1).numpy()

        # SSIM 및 기타 메트릭 계산
        total_ssim_score = get_ssim_score(images_gt_np[0], images_restored_np[0])
        self.log("val_ssim", total_ssim_score, on_step=False, on_epoch=True)
        return {"val_ssim": total_ssim_score}

# 모델 초기화
model_1 = UnetPlusPlus(
    encoder_name="efficientnet-b4",
    encoder_weights="imagenet",
    in_channels=1,  # Grayscale 입력
    classes=1
)

model_2 = UnetPlusPlus(
    encoder_name="efficientnet-b4",
    encoder_weights="imagenet",
    in_channels=1,  # Gray->Color 변환을 위해 1채널 사용
    classes=3  # RGB 출력
)

lit_ir_model = LitIRModel(model_1=model_1, model_2=model_2)


In [31]:
from torch.utils.data import random_split, DataLoader
import torchvision.transforms as transforms
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
import lightning as L
import os
import torch

# 경로 설정
origin_dir = 'train_gt'
damage_dir = 'train_input'
test_dir = 'test_input'

# 데이터 전처리 설정
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # 모든 이미지를 256x256으로 리사이즈
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# 데이터셋 생성
dataset = CustomDataset(damage_dir=damage_dir, origin_dir=origin_dir, transform=transform)

# 데이터셋 분할
validation_ratio = 0.2
train_size = int((1 - validation_ratio) * len(dataset))
val_size = len(dataset) - train_size
training_dataset, validation_dataset = random_split(dataset, [train_size, val_size])

class CollateFn:
    def __init__(self, mode='train'):
        self.mode = mode

    def __call__(self, batch):
        A = torch.stack([item['A'] for item in batch])
        B = torch.stack([item['B'] for item in batch])
        masks = torch.stack([item['mask'] for item in batch]) if 'mask' in batch[0] else torch.zeros_like(A)

        if self.mode in ['train', 'valid']:
            return {'A': A, 'B': B, 'masks': masks}
        elif self.mode == 'test':
            return {'A': A}

# CollateFn 정의
train_collate_fn = CollateFn(mode='train')
validation_collate_fn = CollateFn(mode='valid')

# DataLoader 설정
train_dataloader = DataLoader(
    training_dataset,
    batch_size=CFG['BATCH_SIZE'],
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    collate_fn=train_collate_fn
)

validation_dataloader = DataLoader(
    validation_dataset,
    batch_size=CFG['BATCH_SIZE'],
    shuffle=False,
    num_workers=8,
    collate_fn=validation_collate_fn
)

# 모델 저장 디렉토리 생성
model_save_dir = "./saved_models"
os.makedirs(model_save_dir, exist_ok=True)

# Trainer 설정 및 학습 시작
trainer = L.Trainer(
    max_epochs=CFG['EPOCHS'],
    precision="16-mixed",
    accelerator='gpu',
    devices=1,
    log_every_n_steps=10,
    callbacks=[
        ModelCheckpoint(
            monitor='val_ssim',
            mode='max',
            save_top_k=1,
            dirpath=model_save_dir,
            filename='best_model-{epoch:02d}-{val_ssim:.4f}'
        ),
        EarlyStopping(monitor='val_ssim', mode='max', patience=3)
    ]
)

# 학습 시작
trainer.fit(lit_ir_model, train_dataloader, validation_dataloader)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type         | Params | Mode 
-------------------------------------------------
0 | model_1 | UnetPlusPlus | 20.8 M | train
1 | model_2 | UnetPlusPlus | 20.8 M | train
-------------------------------------------------
41.6 M    Trainable params
0         Non-trainable params
41.6 M    Total params
166.499   Total estimated model params size (MB)
1274      Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                | 0/? [00:00…

Training: |                                                                                       | 0/? [00:00…

Epoch: 0, Batch: 0, Loss: 0.6740
Epoch: 0, Batch: 10, Loss: 0.4713
Epoch: 0, Batch: 20, Loss: 0.3466
Epoch: 0, Batch: 30, Loss: 0.2593
Epoch: 0, Batch: 40, Loss: 0.1928
Epoch: 0, Batch: 50, Loss: 0.1671
Epoch: 0, Batch: 60, Loss: 0.1280
Epoch: 0, Batch: 70, Loss: 0.1228
Epoch: 0, Batch: 80, Loss: 0.1083
Epoch: 0, Batch: 90, Loss: 0.0896
Epoch: 0, Batch: 100, Loss: 0.1046
Epoch: 0, Batch: 110, Loss: 0.0879
Epoch: 0, Batch: 120, Loss: 0.0892
Epoch: 0, Batch: 130, Loss: 0.0780
Epoch: 0, Batch: 140, Loss: 0.0777
Epoch: 0, Batch: 150, Loss: 0.0736
Epoch: 0, Batch: 160, Loss: 0.0767
Epoch: 0, Batch: 170, Loss: 0.0613
Epoch: 0, Batch: 180, Loss: 0.0643
Epoch: 0, Batch: 190, Loss: 0.0621
Epoch: 0, Batch: 200, Loss: 0.0771
Epoch: 0, Batch: 210, Loss: 0.0684
Epoch: 0, Batch: 220, Loss: 0.0600
Epoch: 0, Batch: 230, Loss: 0.0612
Epoch: 0, Batch: 240, Loss: 0.0708
Epoch: 0, Batch: 250, Loss: 0.0591
Epoch: 0, Batch: 260, Loss: 0.0682
Epoch: 0, Batch: 270, Loss: 0.0565
Epoch: 0, Batch: 280, Loss: 0.0

Validation: |                                                                                     | 0/? [00:00…

Epoch: 1, Batch: 0, Loss: 0.0481
Epoch: 1, Batch: 10, Loss: 0.0370
Epoch: 1, Batch: 20, Loss: 0.0323
Epoch: 1, Batch: 30, Loss: 0.0352
Epoch: 1, Batch: 40, Loss: 0.0444
Epoch: 1, Batch: 50, Loss: 0.0300
Epoch: 1, Batch: 60, Loss: 0.0319
Epoch: 1, Batch: 70, Loss: 0.0342
Epoch: 1, Batch: 80, Loss: 0.0405
Epoch: 1, Batch: 90, Loss: 0.0394
Epoch: 1, Batch: 100, Loss: 0.0408
Epoch: 1, Batch: 110, Loss: 0.0353
Epoch: 1, Batch: 120, Loss: 0.0447
Epoch: 1, Batch: 130, Loss: 0.0331
Epoch: 1, Batch: 140, Loss: 0.0384
Epoch: 1, Batch: 150, Loss: 0.0311
Epoch: 1, Batch: 160, Loss: 0.0530
Epoch: 1, Batch: 170, Loss: 0.0665
Epoch: 1, Batch: 180, Loss: 0.0315
Epoch: 1, Batch: 190, Loss: 0.0366
Epoch: 1, Batch: 200, Loss: 0.0444
Epoch: 1, Batch: 210, Loss: 0.0387
Epoch: 1, Batch: 220, Loss: 0.0384
Epoch: 1, Batch: 230, Loss: 0.0279
Epoch: 1, Batch: 240, Loss: 0.0345
Epoch: 1, Batch: 250, Loss: 0.0308
Epoch: 1, Batch: 260, Loss: 0.0303
Epoch: 1, Batch: 270, Loss: 0.0373
Epoch: 1, Batch: 280, Loss: 0.0

Validation: |                                                                                     | 0/? [00:00…

`Trainer.fit` stopped: `max_epochs=2` reached.


In [33]:
from torch.utils.data import DataLoader
import os
from PIL import Image
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# 데이터 전처리
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])


# 테스트 데이터셋 생성
test_dataset = CustomDataset(damage_dir=test_dir, origin_dir=test_dir, transform=transform, use_masks=False)

# 테스트 DataLoader 생성
test_dataloader = DataLoader(
    test_dataset,
    batch_size=CFG['BATCH_SIZE'],  # 적절히 설정
    shuffle=False,
    num_workers=2  # 경고를 피하기 위해 적절히 설정
)

# 모델 초기화
model = LitIRModel.load_from_checkpoint(
    checkpoint_path='saved_models/best_model-epoch=01-val_score=0.0000.ckpt',
    model_1=model_1,
    model_2=model_2
)

# 모델을 평가 모드로 설정
model.eval()
model.to(device)

# 테스트 데이터로 예측 실행
output_dir = "output/"
os.makedirs(output_dir, exist_ok=True)

model.eval()  # 모델을 평가 모드로 설정
with torch.no_grad():
    for idx, batch in enumerate(test_dataloader):
        # 입력 데이터 준비 (RGB -> Grayscale 변환)
        inputs = torch.mean(batch['A'], dim=1, keepdim=True).to(device)  # [N, 1, H, W]
        
        # 모델 예측
        gray_restored, color_restored = model(inputs)  # 모델 예측

        # 예측된 이미지를 저장
        for i, result in enumerate(color_restored):
            result_img = (result.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)  # [C, H, W] -> [H, W, C]
            output_path = os.path.join(output_dir, f"output_{idx}_{i}.png")
            plt.imsave(output_path, result_img)
            print(f"Saved: {output_path}")

Saved: output/output_0_0.png
Saved: output/output_0_1.png
Saved: output/output_0_2.png
Saved: output/output_0_3.png
Saved: output/output_0_4.png
Saved: output/output_0_5.png
Saved: output/output_0_6.png
Saved: output/output_0_7.png
Saved: output/output_0_8.png
Saved: output/output_0_9.png
Saved: output/output_0_10.png
Saved: output/output_0_11.png
Saved: output/output_0_12.png
Saved: output/output_0_13.png
Saved: output/output_0_14.png
Saved: output/output_0_15.png
Saved: output/output_1_0.png
Saved: output/output_1_1.png
Saved: output/output_1_2.png
Saved: output/output_1_3.png
Saved: output/output_1_4.png
Saved: output/output_1_5.png
Saved: output/output_1_6.png
Saved: output/output_1_7.png
Saved: output/output_1_8.png
Saved: output/output_1_9.png
Saved: output/output_1_10.png
Saved: output/output_1_11.png
Saved: output/output_1_12.png
Saved: output/output_1_13.png
Saved: output/output_1_14.png
Saved: output/output_1_15.png
Saved: output/output_2_0.png
Saved: output/output_2_1.png
Sa