학습은 Kaggle에 미리 업로드해 둔 데이터셋을 다운로드하여 Colab에서 진행했습니다.

In [1]:
# Mount Google Drive under /content/drive; the checkpoint paths configured later in this
# notebook point into this mount.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import kagglehub

# Download the latest version of both Kaggle datasets. The returned values are printed
# below as the local dataset paths; the later training cell hard-codes those same
# cache locations instead of reusing `path` / `path2`.
path = kagglehub.dataset_download("geon05/dataset2")
path2 = kagglehub.dataset_download("geon05/damages-masks")

print("Path to dataset files:", path)
print("Path to dataset files:", path2)

Downloading from https://www.kaggle.com/api/v1/datasets/download/geon05/dataset2?dataset_version_number=1...


100%|██████████| 17.7G/17.7G [02:07<00:00, 149MB/s]

Extracting files...





Downloading from https://www.kaggle.com/api/v1/datasets/download/geon05/damages-masks?dataset_version_number=1...


100%|██████████| 13.9G/13.9G [01:53<00:00, 131MB/s]

Extracting files...





Path to dataset files: /root/.cache/kagglehub/datasets/geon05/dataset2/versions/1
Path to dataset files: /root/.cache/kagglehub/datasets/geon05/damages-masks/versions/1


In [None]:
import os
import glob
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms, models
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from skimage.metrics import structural_similarity as ssim
from torchvision.models import resnet50, ResNet50_Weights

####################
# Parameter settings
####################
# Input folders inside the kagglehub cache: gray inputs, color targets and damage masks.
gray_dir = "/root/.cache/kagglehub/datasets/geon05/dataset2/versions/1/train_input"
color_dir = "/root/.cache/kagglehub/datasets/geon05/damages-masks/versions/1/damage_images/damage_images"
mask_dir = "/root/.cache/kagglehub/datasets/geon05/damages-masks/versions/1/output_masks/output_masks"

# Folder on Google Drive where checkpoints are written (created if missing).
save_dir = "/content/drive/MyDrive/Colab Notebooks/last/newcol2_01"
os.makedirs(save_dir, exist_ok=True)

# Checkpoint to resume from; it is only used when the file exists and resuming is enabled.
load_checkpoint_path = "/content/drive/MyDrive/Colab Notebooks/last/newcol2_01/best_by_val_loss_ep13.pth"

batch_size = 20
lr = 1e-4
epochs = 50          # number of epochs to run in this session (added on top of a resumed epoch)
test_size = 0.2      # fraction of the files held out for validation
lambda_perc = 0.2  # weight of the perceptual-loss term

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


####################
# Lab 변환 함수
####################
def rgb_to_lab_normalized(rgb):
    """Convert a channel-first RGB tensor into normalized Lab channel planes.

    The tensor (values expected in [0, 1]) is scaled to 8-bit, converted
    RGB -> BGR -> Lab with OpenCV and then rescaled.

    Returns:
        Tuple ``(L, a, b)`` of float32 ``(H, W)`` arrays where ``L`` is the 8-bit
        lightness divided by 255 and ``a`` / ``b`` are ``(value - 128) / 128``.
    """
    hwc_uint8 = (rgb.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    lab_f32 = cv2.cvtColor(
        cv2.cvtColor(hwc_uint8, cv2.COLOR_RGB2BGR),
        cv2.COLOR_BGR2Lab,
    ).astype(np.float32)

    lightness = lab_f32[..., 0] / 255.0
    # The 8-bit a/b channels are stored with a +128 offset; undo it and scale.
    chroma_a = (lab_f32[..., 1] - 128.0) / 128.0
    chroma_b = (lab_f32[..., 2] - 128.0) / 128.0
    return lightness, chroma_a, chroma_b

def lab_to_rgb(L, a, b):
    """Convert normalized Lab planes back to an RGB image in [0, 1].

    Inverse of the normalization used by ``rgb_to_lab_normalized``: the planes are
    mapped back to the 8-bit Lab range, clipped, converted Lab -> BGR -> RGB with
    OpenCV and divided by 255.

    Args:
        L, a, b: 2-D arrays of identical shape ``(H, W)``.

    Returns:
        ``(H, W, 3)`` float array with values in [0, 1].
    """
    height, width = L.shape[0], L.shape[1]
    lab_f32 = np.zeros((height, width, 3), dtype=np.float32)
    planes = (L * 255.0, a * 128.0 + 128.0, b * 128.0 + 128.0)
    for channel, plane in enumerate(planes):
        lab_f32[..., channel] = plane

    # Out-of-range predictions are clipped before the cast to 8-bit.
    lab_u8 = np.clip(lab_f32, 0, 255).astype(np.uint8)
    rgb_u8 = cv2.cvtColor(
        cv2.cvtColor(lab_u8, cv2.COLOR_Lab2BGR),
        cv2.COLOR_BGR2RGB,
    )
    rgb_u8 = np.clip(rgb_u8, 0, 255).astype(np.uint8)
    return rgb_u8 / 255.0

####################
# SSIM, Histogram
####################
def ssim_score(true, pred):
    """Return the SSIM of two images with channels on the last axis (data range 1.0)."""
    score = ssim(true, pred, data_range=1.0, channel_axis=-1)
    return score

def histogram_similarity(true, pred):
    """Return the correlation between the hue histograms of two RGB images.

    Both inputs are RGB arrays scaled to [0, 1]; each is converted to 8-bit HSV and
    its 180-bin hue histogram is normalized before the two are compared with
    ``cv2.HISTCMP_CORREL``.
    """
    def hue_histogram(rgb01):
        # RGB [0, 1] -> 8-bit BGR -> HSV, then histogram of channel 0 (hue).
        bgr = cv2.cvtColor((rgb01 * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)
        hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
        hist = cv2.calcHist([hsv], [0], None, [180], [0, 180])
        return cv2.normalize(hist, hist).flatten()

    return cv2.compareHist(hue_histogram(true), hue_histogram(pred), cv2.HISTCMP_CORREL)

####################
# VGGPerceptualLoss 정의
####################
class VGGPerceptualLoss(nn.Module):
    """Frozen VGG-16 feature extractor used to build a perceptual loss.

    Only the first ``max(layer_ids) + 1`` layers of ``vgg16().features`` are kept and
    their parameters are frozen. ``forward`` returns the activations produced at the
    requested layer indices.

    Args:
        layer_ids: Indices into ``vgg16().features`` whose outputs are returned.
            Any non-empty iterable of ints is accepted; it is copied, so the caller's
            object (and the default) is never shared with the instance.

    Raises:
        ValueError: If ``layer_ids`` is empty.
    """

    def __init__(self, layer_ids=(3, 8, 15, 22)):
        super().__init__()
        # Copy the argument: the previous list default was shared by every instance.
        layer_ids = list(layer_ids)
        if not layer_ids:
            raise ValueError("layer_ids must contain at least one layer index")

        vgg = models.vgg16(weights=models.VGG16_Weights.DEFAULT).features
        self.layers = nn.ModuleList([vgg[i] for i in range(max(layer_ids) + 1)])
        self.layer_ids = layer_ids
        self._capture_ids = frozenset(layer_ids)  # O(1) membership test in forward
        for param in self.layers.parameters():
            param.requires_grad = False

        # Normalization statistics used for the pretrained weights. Kept as
        # non-persistent buffers so they follow .to(device) without being rebuilt on
        # every forward pass and without adding keys to the state_dict.
        self.register_buffer(
            "mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False
        )
        self.register_buffer(
            "std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False
        )

    def forward(self, x):
        """Return the list of feature maps at ``layer_ids`` for an RGB batch ``x``.

        ``x`` is expected as (B, 3, H, W) in [0, 1]; it is normalized with the stored
        mean / std before being passed through the VGG layers.
        """
        x = (x - self.mean) / self.std

        feats = []
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i in self._capture_ids:
                feats.append(x)
        return feats

####################
# UpConv, DoubleConv
####################
class UpConv(nn.Module):
    """Upsampling step: a transposed convolution with kernel size 2 and stride 2."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=2,
            stride=2,
        )

    def forward(self, x):
        upsampled = self.up(x)
        return upsampled

class DoubleConv(nn.Module):
    """Two 3x3 convolutions with padding 1, each followed by an in-place ReLU."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # First conv maps in_ch -> out_ch, second keeps out_ch -> out_ch.
        for source_ch in (in_ch, out_ch):
            stages.append(nn.Conv2d(source_ch, out_ch, kernel_size=3, padding=1))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        features = self.conv(x)
        return features

####################
# Dataset 정의 (마스크 추가)
####################
class DamagedGrayColorDataset(Dataset):
    """Dataset of (damaged gray input, color target, damage mask) triples.

    The three path lists are paired by index. Each item is returned as
    ``(G_L, ab_t, L_t, mask_bin)``:

    * ``G_L``      - [1, H, W] Lab lightness of the damaged gray input,
    * ``ab_t``     - [2, H, W] normalized a/b channels of the color target,
    * ``L_t``      - [1, H, W] Lab lightness of the color target,
    * ``mask_bin`` - [1, H, W] float mask, 1.0 where the mask image is > 128.

    ``transform_gray`` and ``transform_color`` must produce channel-first tensors
    (e.g. end with ``ToTensor``), because the Lab conversion operates on tensors.
    """

    def __init__(self, gray_paths, color_paths, mask_paths,
                 transform_gray=None, transform_color=None):
        self.gray_paths = gray_paths
        self.color_paths = color_paths
        self.mask_paths = mask_paths
        self.transform_gray = transform_gray
        self.transform_color = transform_color
        assert len(self.gray_paths) == len(self.color_paths) == len(self.mask_paths), \
            "gray, color and mask path lists must have the same length"

    def __len__(self):
        return len(self.gray_paths)

    def __getitem__(self, idx):
        gray_img = Image.open(self.gray_paths[idx]).convert('L')
        color_img = Image.open(self.color_paths[idx]).convert('RGB')
        mask_img = Image.open(self.mask_paths[idx]).convert('L')  # 0..255

        if self.transform_gray:
            gray_img = self.transform_gray(gray_img)
        if self.transform_color:
            color_img = self.transform_color(color_img)

        # The image transforms may resize; the mask has no transform of its own, so
        # bring it to the same spatial size. Nearest-neighbour keeps it binary-friendly.
        if torch.is_tensor(gray_img):
            height, width = gray_img.shape[-2:]
            if mask_img.size != (width, height):  # PIL size is (width, height)
                mask_img = mask_img.resize((width, height), Image.NEAREST)

        # Binarize the mask: values above 128 become 1.0.
        mask_bin = (np.array(mask_img) > 128).astype(np.float32)
        mask_bin = torch.from_numpy(mask_bin).unsqueeze(0)  # [1,H,W]

        # Target image -> normalized Lab planes.
        L_t, a_t, b_t = rgb_to_lab_normalized(color_img)
        ab_t = torch.cat(
            [torch.from_numpy(a_t).unsqueeze(0), torch.from_numpy(b_t).unsqueeze(0)],
            dim=0,
        )
        L_t = torch.from_numpy(L_t).unsqueeze(0)

        # Gray input -> 3-channel RGB -> Lab; only its lightness plane is used.
        gray_3ch = torch.cat([gray_img, gray_img, gray_img], dim=0)
        G_L, _, _ = rgb_to_lab_normalized(gray_3ch)
        G_L = torch.from_numpy(G_L).unsqueeze(0)

        return G_L, ab_t, L_t, mask_bin

####################
# 데이터셋 구성
####################
# Both inputs are resized to 512x512 and converted to tensors.
transform_gray = transforms.Compose(
    [transforms.Resize((512, 512)), transforms.ToTensor()]
)

transform_color = transforms.Compose(
    [transforms.Resize((512, 512)), transforms.ToTensor()]
)

# Gray and color files are paired by their position in the sorted listings.
gray_files = sorted(glob.glob(os.path.join(gray_dir, "*")))
color_files = sorted(glob.glob(os.path.join(color_dir, "*")))

# Each mask is looked up under mask_dir by the file name of its gray image.
mask_files = [
    os.path.join(mask_dir, os.path.basename(gray_path)) for gray_path in gray_files
]

from sklearn.model_selection import train_test_split
(
    train_gray, val_gray,
    train_color, val_color,
    train_mask, val_mask,
) = train_test_split(
    gray_files,
    color_files,
    mask_files,
    test_size=test_size,
    random_state=42,
)

train_dataset = DamagedGrayColorDataset(
    train_gray, train_color, train_mask, transform_gray, transform_color
)
val_dataset = DamagedGrayColorDataset(
    val_gray, val_color, val_mask, transform_gray, transform_color
)

# Only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

####################
# ResNetEncoder (ResNet-50)
####################
class ResNetEncoder(nn.Module):
    """ResNet-50 backbone that exposes the feature maps of every stage.

    ``forward`` returns ``(x0, x1, x2, x3, x4)``: the stem output (64 channels)
    followed by the outputs of layer1..layer4 (256, 512, 1024 and 2048 channels).
    """

    def __init__(self, pretrained=True):
        super().__init__()
        weights = ResNet50_Weights.DEFAULT if pretrained else None
        net = models.resnet50(weights=weights)
        # Stem without the max-pool, so x0 can serve as the shallowest skip feature.
        self.initial = nn.Sequential(net.conv1, net.bn1, net.relu)
        self.maxpool = net.maxpool
        self.layer1 = net.layer1  # 256 channels
        self.layer2 = net.layer2  # 512 channels
        self.layer3 = net.layer3  # 1024 channels
        self.layer4 = net.layer4  # 2048 channels

    def forward(self, x):
        stem = self.initial(x)  # 64 channels
        features = [stem]
        hidden = self.maxpool(stem)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            hidden = stage(hidden)
            features.append(hidden)
        return tuple(features)

####################
# 디코더 (ResNet-50에 맞게)
####################
class ResNetUNet(nn.Module):
    """U-Net style decoder on top of ``ResNetEncoder`` that predicts ``out_ch`` maps.

    The single-channel input is repeated to three channels for the encoder. Each
    decoder stage upsamples, concatenates the matching encoder feature and applies a
    ``DoubleConv``; a final upsampling stage without skip connection precedes the
    1x1 output convolution.
    """

    def __init__(self, out_ch=2, pretrained=True):
        super().__init__()
        self.encoder = ResNetEncoder(pretrained=pretrained)
        # Stage 3: 2048 -> 1024, fused with the 1024-channel skip.
        self.up3 = UpConv(2048, 1024)
        self.dec3 = DoubleConv(2048, 1024)
        # Stage 2: 1024 -> 512, fused with the 512-channel skip.
        self.up2 = UpConv(1024, 512)
        self.dec2 = DoubleConv(1024, 512)
        # Stage 1: 512 -> 256, fused with the 256-channel skip.
        self.up1 = UpConv(512, 256)
        self.dec1 = DoubleConv(512, 256)
        # Stage 0: 256 -> 64, fused with the 64-channel stem output.
        self.up0 = UpConv(256, 64)
        self.dec0 = DoubleConv(128, 64)
        # Final upsampling stage (no skip connection) and output projection.
        self.up_final = UpConv(64, 64)
        self.dec_final = DoubleConv(64, 64)
        self.final_out = nn.Conv2d(64, out_ch, 1)

    def forward(self, x):
        x0, x1, x2, x3, x4 = self.encoder(x.repeat(1, 3, 1, 1))

        decoder_stages = (
            (self.up3, self.dec3, x3),
            (self.up2, self.dec2, x2),
            (self.up1, self.dec1, x1),
            (self.up0, self.dec0, x0),
        )
        hidden = x4
        for upsample, refine, skip in decoder_stages:
            hidden = refine(torch.cat([upsample(hidden), skip], dim=1))

        hidden = self.dec_final(self.up_final(hidden))
        return self.final_out(hidden)


# Colorization network predicting the two ab channels, plus the frozen VGG feature
# extractor used by the perceptual term of the loss.
model = ResNetUNet(out_ch=2, pretrained=True).to(device)
perceptual_extractor = VGGPerceptualLoss(layer_ids=[3,8,15,22]).to(device)
# reduction='none' keeps the per-element errors so compute_loss can apply the mask.
mse_loss = nn.MSELoss(reduction='none')

def _lab_to_rgb_torch(L, ab):
    """Differentiable conversion of normalized Lab tensors to RGB in [0, 1].

    Uses the same normalization as ``lab_to_rgb`` (L* = 100 * L, a* = 128 * a,
    b* = 128 * b, clamped to the 8-bit representable range) and the CIE Lab -> XYZ
    (D65) -> sRGB formulas from the OpenCV colour-conversion documentation, but
    stays on the tensor's device and keeps the autograd graph intact.

    Args:
        L: [B, 1, H, W] normalized lightness.
        ab: [B, 2, H, W] normalized a/b channels.

    Returns:
        [B, 3, H, W] RGB tensor with values in [0, 1].
    """
    l_star = L.clamp(0.0, 1.0) * 100.0
    a_star = ab[:, 0:1].clamp(-1.0, 127.0 / 128.0) * 128.0
    b_star = ab[:, 1:2].clamp(-1.0, 127.0 / 128.0) * 128.0

    fy = (l_star + 16.0) / 116.0
    fx = fy + a_star / 500.0
    fz = fy - b_star / 200.0

    delta = 6.0 / 29.0

    def f_inv(t):
        # Inverse of the piecewise Lab companding function.
        return torch.where(t > delta, t ** 3, 3.0 * delta ** 2 * (t - 4.0 / 29.0))

    x = 0.950456 * f_inv(fx)
    y = f_inv(fy)
    z = 1.088754 * f_inv(fz)

    red = 3.240479 * x - 1.537150 * y - 0.498535 * z
    green = -0.969256 * x + 1.875991 * y + 0.041556 * z
    blue = 0.055648 * x - 0.204043 * y + 1.057311 * z
    linear = torch.cat([red, green, blue], dim=1).clamp(0.0, 1.0)

    # sRGB gamma. The power branch is evaluated on a clamped copy so that its
    # (unused) gradient near zero cannot turn into inf/NaN inside torch.where.
    threshold = 0.0031308
    gamma = 1.055 * linear.clamp(min=threshold) ** (1.0 / 2.4) - 0.055
    return torch.where(linear > threshold, gamma, 12.92 * linear)


def compute_loss(L_img, ab_img, pred_ab, L_t_img, mask):
    """Masked colour loss: MSE on the ab channels plus a VGG perceptual term.

    Only pixels outside the damage mask (``mask == 0``) contribute. The perceptual
    term compares VGG features of the target and predicted RGB images, both
    reconstructed with the target lightness ``L_t_img``. The RGB reconstruction is
    done in torch so the perceptual term back-propagates into ``pred_ab``; building
    it from a detached NumPy copy makes that term a constant for the optimizer.

    Args:
        L_img: Input lightness batch (unused; kept for interface compatibility).
        ab_img: [B, 2, H, W] target ab channels.
        pred_ab: [B, 2, H, W] predicted ab channels.
        L_t_img: [B, 1, H, W] target lightness.
        mask: [B, 1, H, W] damage mask, 1 where pixels are excluded.

    Returns:
        ``(total_loss, l1_val, perc_loss_val)`` scalar tensors, where
        ``total_loss = l1_val + lambda_perc * perc_loss_val``. ``l1_val`` keeps its
        historical name although it is a masked squared error.
    """
    valid_area = 1 - mask

    # Masked squared error. As before, the sum over both ab channels is divided by
    # the number of valid pixels, so the loss scale of earlier runs is preserved.
    squared_error = mse_loss(pred_ab, ab_img) * valid_area
    l1_val = squared_error.sum() / (valid_area.sum() + 1e-8)

    # The target branch never needs gradients.
    with torch.no_grad():
        true_feats = perceptual_extractor(_lab_to_rgb_torch(L_t_img, ab_img))
    pred_feats = perceptual_extractor(_lab_to_rgb_torch(L_t_img, pred_ab))

    perc_loss_val = pred_ab.new_zeros(())
    for ft, fp in zip(true_feats, pred_feats):
        # Resize the validity mask to the feature resolution and share it across channels.
        down_mask = F.interpolate(valid_area, size=ft.shape[2:], mode='bilinear', align_corners=False)
        down_mask = down_mask.expand(-1, ft.shape[1], -1, -1)
        perc_sum = ((ft - fp) ** 2 * down_mask).sum()
        perc_loss_val = perc_loss_val + perc_sum / (down_mask.sum() + 1e-8)

    total_loss = l1_val + lambda_perc * perc_loss_val
    return total_loss, l1_val, perc_loss_val

def visualize_samples(model, data_loader, device, num_samples=3):
    """Plot input / target / prediction triples for the first batch of ``data_loader``.

    Puts ``model`` in eval mode, runs it on the first batch and shows up to
    ``num_samples`` rows with the damaged gray input, the target colour image and
    the predicted colour image (both rebuilt with the target lightness).

    Args:
        model: Network mapping a lightness batch to predicted ab channels.
        data_loader: Loader yielding ``(L, ab, L_t, mask)`` batches.
        device: Device the model lives on.
        num_samples: Maximum number of rows; capped at the batch size so a small
            batch no longer raises ``IndexError``.
    """
    model.eval()

    first_batch = next(iter(data_loader), None)
    if first_batch is None:
        print("visualize_samples: data loader is empty, nothing to display.")
        return
    L_batch, ab_batch, L_t_batch, _ = first_batch

    with torch.no_grad():
        L_batch = L_batch.to(device)
        pred_ab = model(L_batch)

    num_samples = min(num_samples, L_batch.size(0))
    if num_samples < 1:
        return

    # squeeze=False always yields a 2-D axes array, also for a single row.
    fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4*num_samples), squeeze=False)

    for i in range(num_samples):
        L_np = L_batch[i,0].cpu().numpy()
        ab_true = ab_batch[i].permute(1,2,0).numpy()
        L_t_np = L_t_batch[i,0].numpy()
        pred_ab_np = pred_ab[i].cpu().permute(1,2,0).numpy()

        true_rgb = lab_to_rgb(L_t_np, ab_true[:,:,0], ab_true[:,:,1])
        pred_rgb = lab_to_rgb(L_t_np, pred_ab_np[:,:,0], pred_ab_np[:,:,1])

        panels = (
            (L_np, "Damaged Gray Input", {"cmap": 'gray'}),
            (true_rgb, "Target Color (using L_t)", {}),
            (pred_rgb, "Predicted Color", {}),
        )
        for ax, (image, title, imshow_kwargs) in zip(axes[i], panels):
            ax.imshow(image, **imshow_kwargs)
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    plt.show()

# Adam over all parameters of the model, using the learning rate configured above.
optimizer = optim.Adam(model.parameters(), lr=lr)

def save_checkpoint(
    epoch,
    model,
    optimizer,
    best_val_loss,
    best_combined_metric,
    train_losses,
    val_losses,
    save_path,
    filename="checkpoint.pth"
):
    """Write a training checkpoint to ``save_path/filename``.

    The checkpoint holds the epoch index, the model and optimizer state dicts, the
    best validation loss / combined metric so far and the loss histories.

    Metrics and histories are stored as plain Python floats. NumPy scalars (such as
    the product of two ``np.mean`` results) would be pickled as NumPy objects, which
    ``torch.load`` rejects when it runs in ``weights_only`` mode.

    Args:
        epoch: Zero-based index of the epoch that just finished.
        model: Module whose ``state_dict`` is saved.
        optimizer: Optimizer whose ``state_dict`` is saved.
        best_val_loss: Lowest validation loss so far.
        best_combined_metric: Highest combined metric so far.
        train_losses: Per-epoch training losses.
        val_losses: Per-epoch validation losses.
        save_path: Target directory (created if it does not exist).
        filename: File name inside ``save_path``.
    """
    checkpoint_dict = {
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "best_val_loss": float(best_val_loss),
        "best_combined_metric": float(best_combined_metric),
        "train_losses": [float(value) for value in train_losses],
        "val_losses": [float(value) for value in val_losses],
    }
    os.makedirs(save_path, exist_ok=True)
    filepath = os.path.join(save_path, filename)
    torch.save(checkpoint_dict, filepath)
    print(f"Checkpoint saved at epoch {epoch} -> {filepath}")

def load_checkpoint(
    checkpoint_path,
    model,
    optimizer,
    map_location="cpu",
    weights_only=None
):
    """Restore model / optimizer state and training bookkeeping from a checkpoint.

    Args:
        checkpoint_path: Path of a file written by ``save_checkpoint``.
        model: Module that receives the saved ``state_dict``.
        optimizer: Optimizer that receives the saved ``state_dict``.
        map_location: Forwarded to ``torch.load``.
        weights_only: ``None`` (default) first uses ``torch.load``'s own default and,
            if the file cannot be unpickled that way (e.g. an older checkpoint that
            contains NumPy scalars), retries with full unpickling after printing a
            warning. ``True`` / ``False`` is forwarded to ``torch.load`` unchanged.

    Returns:
        ``(start_epoch, best_val_loss, best_combined_metric, train_losses,
        val_losses)`` where ``start_epoch`` is the saved epoch + 1 and the two
        metrics are plain Python floats. Missing loss histories default to ``[]``.
    """
    import pickle

    if weights_only is None:
        try:
            checkpoint = torch.load(checkpoint_path, map_location=map_location)
        except pickle.UnpicklingError:
            # SECURITY: full unpickling can execute arbitrary code. Acceptable only
            # because checkpoints are files this notebook wrote itself; pass
            # weights_only=True to forbid this fallback for untrusted files.
            print("Warning: checkpoint is not loadable in weights-only mode; "
                  "retrying with weights_only=False (trusted files only).")
            checkpoint = torch.load(
                checkpoint_path, map_location=map_location, weights_only=False
            )
    else:
        checkpoint = torch.load(
            checkpoint_path, map_location=map_location, weights_only=weights_only
        )

    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    start_epoch = checkpoint["epoch"] + 1
    # Normalize to built-in floats so later checkpoints stay free of NumPy scalars.
    best_val_loss = float(checkpoint["best_val_loss"])
    best_combined_metric = float(checkpoint["best_combined_metric"])
    train_losses = checkpoint.get("train_losses", [])
    val_losses   = checkpoint.get("val_losses", [])

    print(f"Checkpoint loaded: last epoch={checkpoint['epoch']}")

    return (
        start_epoch,
        best_val_loss,
        best_combined_metric,
        train_losses,
        val_losses
    )

# Set to False to ignore load_checkpoint_path and train from scratch.
resume_training = True

# Defaults for a fresh run; overwritten below when a checkpoint is loaded.
start_epoch = 0
best_val_loss = float('inf')
best_combined_metric = float('-inf')
train_losses = []
val_losses   = []

if resume_training and os.path.exists(load_checkpoint_path):
    # Restore model / optimizer state and continue the loss histories.
    (
        start_epoch,
        best_val_loss,
        best_combined_metric,
        old_train_losses,
        old_val_losses
    ) = load_checkpoint(
        load_checkpoint_path,
        model,
        optimizer,
        map_location=device
    )
    train_losses = old_train_losses
    val_losses   = old_val_losses
    print(f"Resuming training from epoch {start_epoch}")
else:
    print("No checkpoint found or not resuming. Start fresh training...")
    start_epoch = 0


# `epochs` more epochs are run on top of the (possibly resumed) start epoch.
last_epoch = start_epoch + epochs

# Main loop: one training pass and one validation pass per epoch, followed by metric
# computation, best-checkpoint saving, the loss curve and a sample visualization.
for epoch in range(start_epoch, last_epoch):
    model.train()
    running_loss = 0.0
    train_pbar = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{last_epoch}] Training")

    for L_img, ab_img, L_t_img, mask in train_pbar:
        L_img, ab_img, L_t_img, mask = (
            L_img.to(device),
            ab_img.to(device),
            L_t_img.to(device),
            mask.to(device)
        )

        optimizer.zero_grad()
        pred_ab = model(L_img)
        total_loss, l1_val, perc_val = compute_loss(L_img, ab_img, pred_ab, L_t_img, mask)
        total_loss.backward()
        optimizer.step()

        running_loss += total_loss.item()
        train_pbar.set_postfix(loss=total_loss.item(), L1=l1_val.item(), Perc=perc_val.item())

    # Mean of the per-batch losses (every batch weighs the same, whatever its size).
    avg_train_loss = running_loss / len(train_loader)

    # -------------------------------
    # Validation
    # -------------------------------
    model.eval()
    val_loss = 0.0
    ssim_list = []
    hist_sim_list = []

    val_pbar = tqdm(val_loader, desc=f"Epoch [{epoch+1}/{last_epoch}] Validation")

    with torch.no_grad():
        for i, (L_img, ab_img, L_t_img, mask) in enumerate(val_pbar):
            L_img, ab_img, L_t_img, mask = (
                L_img.to(device),
                ab_img.to(device),
                L_t_img.to(device),
                mask.to(device)
            )

            # Model inference
            pred_ab = model(L_img)

            # Loss computation
            total_loss, l1_val, perc_val = compute_loss(L_img, ab_img, pred_ab, L_t_img, mask)
            val_loss += total_loss.item()
            val_pbar.set_postfix(val_loss=total_loss.item(), L1=l1_val.item(), Perc=perc_val.item())

            # --------------------------
            # (A) SSIM / histogram measurement
            # --------------------------
            B = ab_img.size(0)
            ab_np_all    = ab_img.permute(0,2,3,1).cpu().numpy()     # (B,H,W,2)
            pred_ab_np_all = pred_ab.permute(0,2,3,1).cpu().numpy()  # (B,H,W,2)
            L_t_np_all   = L_t_img[:,0].cpu().numpy()                # (B,H,W)
            mask_np_all  = mask[:,0].cpu().numpy()                   # (B,H,W)

            for j in range(B):
                # Lab -> RGB conversion (both images use the target lightness)
                true_rgb = lab_to_rgb(L_t_np_all[j], ab_np_all[j][:,:,0], ab_np_all[j][:,:,1])
                pred_rgb = lab_to_rgb(L_t_np_all[j], pred_ab_np_all[j][:,:,0], pred_ab_np_all[j][:,:,1])

                # Keep only the area outside the mask
                valid_area = (1 - mask_np_all[j])  # (H,W); damaged=1 -> excluded
                valid_area_3ch = np.stack([valid_area]*3, axis=-1)  # (H,W,3)

                # NOTE(review): masked pixels are set to zero in both images rather than
                # removed, so they still take part in both metrics as identical pixels.
                true_rgb = true_rgb * valid_area_3ch
                pred_rgb = pred_rgb * valid_area_3ch

                # (1) SSIM
                current_ssim = ssim_score(true_rgb, pred_rgb)
                ssim_list.append(current_ssim)

                # (2) Histogram Similarity
                current_hist_sim = histogram_similarity(true_rgb, pred_rgb)
                hist_sim_list.append(current_hist_sim)
            # --------------------------

    avg_val_loss = val_loss / len(val_loader)
    print(f"Epoch [{epoch+1}] Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)

    # -------------------------
    # (B) Use SSIM / histogram similarity
    # -------------------------
    combined_metric = float('-inf')  # default when no metric could be computed
    if len(ssim_list) > 0 and len(hist_sim_list) > 0:
        mean_ssim = np.mean(ssim_list)
        mean_hist_sim = np.mean(hist_sim_list)
        # Product of two np.mean results, i.e. a NumPy scalar rather than a Python float.
        combined_metric = mean_ssim * mean_hist_sim
        print(f"Epoch [{epoch+1}] Mean SSIM(no-mask): {mean_ssim:.4f}, "
              f"HistSim(no-mask): {mean_hist_sim:.4f}, Combined: {combined_metric:.4f}")

    # -------------------------
    # (C) Model saving logic (criteria: ValLoss / CombinedMetric)
    # -------------------------
    # 1) Validation-loss criterion
    global_best_val = False
    global_best_metric = False
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        global_best_val = True

    # 2) Combined-metric criterion
    if combined_metric > best_combined_metric:
        best_combined_metric = combined_metric
        global_best_metric = True

    # Save whenever at least one of the two criteria set a new record
    if global_best_val or global_best_metric:
        # Only the validation loss improved
        if global_best_val and (not global_best_metric):
            ckpt_name = f"best_by_val_loss_ep{epoch+1}.pth"
        elif (not global_best_val) and global_best_metric:
            ckpt_name = f"best_by_metric_ep{epoch+1}.pth"
        else:
            # Both criteria improved
            ckpt_name = f"best_by_both_ep{epoch+1}.pth"

        # The zero-based epoch is stored; the file name uses the one-based epoch number.
        save_checkpoint(
            epoch,
            model,
            optimizer,
            best_val_loss,
            best_combined_metric,
            train_losses,
            val_losses,
            save_path=save_dir,
            filename=ckpt_name
        )
        print(f"** Best Model Updated (Epoch={epoch+1}, "
              f"ValLoss={avg_val_loss:.4f}, Metric={combined_metric:.4f}) **")


    # Loss curve (x axis covers every epoch recorded so far, including resumed ones)
    x_vals = range(1, epoch + 2)
    plt.figure(figsize=(8,5))
    plt.plot(x_vals, train_losses, label='Train Loss')
    plt.plot(x_vals, val_losses,   label='Val Loss')
    plt.title("Loss Curve")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

    # Visualization of a few validation samples
    visualize_samples(model, val_loader, device, num_samples=3)


Output hidden; open in https://colab.research.google.com to view.