### Strategy  
The pipeline is split into three stages, and restoration proceeds through them in sequence.  
  
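### Model Components  
  
1. Mask generation model  
  
Training  
  
Convert the ground-truth images in train_gt to grayscale.  
Compare each grayscale ground-truth image with its train_input image and mark the pixels that differ as the mask.  
Use the generated mask as the new target image.  
Finally, train the mask generation model on train_input images paired with the generated masks so that it learns to detect damaged regions.  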
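
The mask-labeling steps above can be sketched with NumPy alone. This is a minimal illustration rather than the notebook's exact code (the training cell uses OpenCV): the names `to_gray` and `make_damage_mask`, the BT.601 luma weights, and the `threshold` default are assumptions made for the example.

```python
import numpy as np

def to_gray(rgb):
    """Convert an (H, W, 3) uint8 RGB image to (H, W) uint8 grayscale (BT.601 weights)."""
    weights = np.array([0.299, 0.587, 0.114])
    return np.round(rgb.astype(np.float64) @ weights).astype(np.uint8)

def make_damage_mask(gt_rgb, damaged_gray, threshold=1):
    """Return a uint8 mask (255 = damaged) where the two grayscale images differ."""
    gt_gray = to_gray(gt_rgb).astype(np.int16)
    diff = np.abs(gt_gray - damaged_gray.astype(np.int16))
    return np.where(diff > threshold, 255, 0).astype(np.uint8)

# Toy example: a 4x4 image whose top-left 2x2 block has been damaged
gt = np.full((4, 4, 3), 120, dtype=np.uint8)
damaged = to_gray(gt).copy()
damaged[:2, :2] = 0
mask = make_damage_mask(gt, damaged)
```

Casting to a signed type before subtracting avoids the wrap-around that unsigned 8-bit subtraction would produce.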
  
Testing  
  
Detect the damaged regions in each test image and generate a mask.  
Save the generated masks to the mask folder.  
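
The mask model outputs one probability per pixel, so saving a mask means mapping those probabilities to pixel values. A minimal sketch, assuming a hypothetical helper `probs_to_mask` and a 0.5 cut-off (neither comes from the notebook, whose test cell saves the scaled probabilities directly):

```python
import numpy as np

def probs_to_mask(probs, threshold=0.5):
    """Map an (H, W) array of probabilities in [0, 1] to a uint8 mask (255 = damaged)."""
    probs = np.asarray(probs, dtype=np.float32)
    return np.where(probs >= threshold, 255, 0).astype(np.uint8)

probs = np.array([[0.02, 0.97], [0.49, 0.51]])
mask = probs_to_mask(probs)
```

The later stages binarize each loaded mask again at 128, so a soft mask saved as `probability * 255` leads to a similar split.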
  
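2. Color restoration model  
  
Training  
  
Combine each train_gt image with its mask to create a damaged color image.  
Use this damaged image as the new target image.  
Train the color restoration model on train_input images paired with these damaged images.  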
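
How the ground truth and the mask are combined is not spelled out here. One plausible reading, shown purely as an assumption, is to overwrite the masked pixels with a fill value; the name `apply_damage` and the `fill_value` default are hypothetical.

```python
import numpy as np

def apply_damage(gt_rgb, mask, fill_value=0):
    """Return a copy of gt_rgb (H, W, 3) with pixels where mask > 128 set to fill_value."""
    damaged = gt_rgb.copy()          # leave the ground truth untouched
    damaged[mask > 128] = fill_value
    return damaged

gt = np.full((2, 2, 3), 200, dtype=np.uint8)
mask = np.array([[255, 0], [0, 0]], dtype=np.uint8)
damaged = apply_damage(gt, mask)
```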
  
Testing  
  
Apply the mask to each test_input image and restore the colors of everything outside the damaged region.  
Save the restored images to the output_grayTocol folder.  
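
Restricting the color model to undamaged pixels amounts to weighting each pixel's error by `1 - mask`. A NumPy sketch of such a masked mean squared error follows; `masked_mse` is an illustrative name, and the notebook's own loss is written in PyTorch.

```python
import numpy as np

def masked_mse(pred, target, mask, eps=1e-8):
    """Mean squared error over pixels where mask == 0.

    pred, target: (C, H, W) float arrays; mask: (1, H, W) with 1 = damaged.
    """
    valid = 1.0 - mask                      # 1 where the pixel counts
    sq_err = (pred - target) ** 2 * valid   # mask broadcasts over channels
    return sq_err.sum() / (valid.sum() * pred.shape[0] + eps)

pred = np.zeros((2, 2, 2))
target = np.ones((2, 2, 2))
target[:, 0, 0] = 100.0                     # large error, but inside the mask
mask = np.zeros((1, 2, 2))
mask[0, 0, 0] = 1.0
loss = masked_mse(pred, target, mask)
```

This sketch also divides by the channel count, which only rescales the value relative to a version that divides by the number of valid pixels.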
  
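3. Damage restoration model  
  
Training  
  
Train the damage restoration model on the deliberately damaged color images from stage (2) paired with the train_gt images.  
  
Testing  
  
In each color-restored image, use the mask to locate the damaged region and restore it.  
Save the restored images to the final folder.  
However, regions that the color restoration model already restored well are kept as they are.  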
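
The last rule, keeping the color model's output wherever it is already good, can be read as a mask-based composite: take the inpainted pixels inside the mask and the color-restored pixels everywhere else. This is an interpretation, and `composite_by_mask` is a hypothetical helper.

```python
import numpy as np

def composite_by_mask(colorized, inpainted, mask):
    """Use inpainted pixels where mask > 128 and colorized pixels elsewhere.

    colorized, inpainted: (H, W, 3) uint8 arrays; mask: (H, W) uint8.
    """
    damaged = (mask > 128)[..., None]       # (H, W, 1), broadcasts over RGB
    return np.where(damaged, inpainted, colorized)

colorized = np.full((2, 2, 3), 10, dtype=np.uint8)
inpainted = np.full((2, 2, 3), 99, dtype=np.uint8)
mask = np.array([[255, 0], [0, 0]], dtype=np.uint8)
final = composite_by_mask(colorized, inpainted, mask)
```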

--------------------

Preprocessing helper functions  
1. Mask generator  
Generates a mask of the damaged region.  
  
2. Damaged color image generator  
Generates damaged color images.  

# 1. Mask Generation Model  
  
![epoch14](https://github.com/Geon-05/submit_imagepro/blob/main/%EC%A0%9C%EC%B6%9C%EC%9A%A9%EC%9D%B4%EB%AF%B8%EC%A7%80/mask_model_epoch14%20copy.png?raw=true)
  

In [None]:
import os
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import torch
from torchvision.models.segmentation import deeplabv3_resnet50, DeepLabV3_ResNet50_Weights
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
import cv2
import numpy as np
import random
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

batch = 30
data_ratio = 1 # use 100% of the dataset
num_epochs = 60
test_size=0.2
lr=0.001

# Dataset and DataLoader
input_dir = '/root/.cache/kagglehub/datasets/geon05/dataset2/versions/1/train_input'
gt_dir = '/root/.cache/kagglehub/datasets/geon05/dataset2/versions/1/train_gt'

# List input and ground-truth files; sorting keeps the two lists aligned by filename
image_files = sorted(os.listdir(input_dir))
mask_files = sorted(os.listdir(gt_dir))

# Train-Validation Split (80:20) over the full dataset
train_images, val_images, train_masks, val_masks = train_test_split(
    image_files, mask_files, test_size=test_size, random_state=42
)

class DamageDataset(Dataset):
    def __init__(self, input_dir, gt_dir, image_files, mask_files, transform=None):
        self.input_dir = input_dir
        self.gt_dir = gt_dir
        self.image_files = image_files
        self.mask_files = mask_files
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        input_path = os.path.join(self.input_dir, self.image_files[idx])
        gt_path = os.path.join(self.gt_dir, self.mask_files[idx])

        # Load and preprocess images
        input_image = Image.open(input_path).convert("RGB")
        input_image_np = np.array(input_image)
        gt_image_gray = Image.open(gt_path).convert("L")
        gt_image_gray_np = np.array(gt_image_gray)

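        # Grayscale version of the damaged input, then per-pixel absolute difference
        input_image_gray_np = cv2.cvtColor(input_image_np, cv2.COLOR_RGB2GRAY)
        difference = cv2.absdiff(gt_image_gray_np, input_image_gray_np)
        # Any pixel that differs by more than 1 gray level counts as damaged
        _, binary_difference = cv2.threshold(difference, 1, 255, cv2.THRESH_BINARY)

        # Note: with a 1x1 kernel the closing and dilation below leave the mask unchanged;
        # a larger kernel (e.g. 3x3) would be needed to bridge gaps or grow the mask.
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 1))
        binary_difference = cv2.morphologyEx(binary_difference, cv2.MORPH_CLOSE, kernel)
        # Fill the interior of each outer contour so the mask has no holes
        contours, _ = cv2.findContours(binary_difference, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        mask_filled = np.zeros_like(binary_difference)
        cv2.drawContours(mask_filled, contours, -1, color=255, thickness=cv2.FILLED)
        mask_filled = cv2.dilate(mask_filled, kernel, iterations=1)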

        input_tensor = transforms.ToTensor()(input_image)
        mask_tensor = torch.tensor(mask_filled, dtype=torch.float32).unsqueeze(0) / 255.0

        return input_tensor, mask_tensor

# Training and Validation Datasets
train_dataset = DamageDataset(input_dir, gt_dir, train_images, train_masks)
val_dataset = DamageDataset(input_dir, gt_dir, val_images, val_masks)

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=batch, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch, shuffle=False)

# Load model with updated weights parameter
weights = DeepLabV3_ResNet50_Weights.DEFAULT
model = deeplabv3_resnet50(weights=weights)
model.classifier[4] = nn.Conv2d(256, 1, kernel_size=1)

# GPU setup: use cuda:0 as the primary device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Wrap the model in DataParallel when several GPUs are available.
# SyncBatchNorm is not applied here: it synchronizes statistics only under
# DistributedDataParallel, so it brings no benefit with DataParallel.
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs! Model training started...")
    model = nn.DataParallel(model)

# Move the model to the device
model = model.to(device)

# Loss and Optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

# Training and Validation Loop

best_val_loss = float('inf')

for epoch in range(num_epochs):
    # Training Phase
    model.train()
    train_loss = 0.0
    for inputs, masks in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):
        inputs, masks = inputs.to(device), masks.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)['out']
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    avg_train_loss = train_loss / len(train_loader)

    # Validation Phase
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, masks in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation"):
            inputs, masks = inputs.to(device), masks.to(device)
            outputs = model(inputs)['out']
            loss = criterion(outputs, masks)
            val_loss += loss.item()
    avg_val_loss = val_loss / len(val_loader)

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"  Training Loss: {avg_train_loss:.4f}")
    print(f"  Validation Loss: {avg_val_loss:.4f}")

    # Save best model
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model, f"best_model_{epoch+1}_{best_val_loss:.4f}.pth")
        print(f"  Best model saved with Validation Loss: {best_val_loss:.4f}")

    # Visualization (Optional)
    with torch.no_grad():
        inputs, masks = next(iter(val_loader))
        inputs, masks = inputs[:5].to(device), masks[:5].to(device)
        predictions = torch.sigmoid(model(inputs)['out'])
        predictions = predictions.cpu().numpy()
        masks = masks.cpu().numpy()
        inputs = inputs.cpu().numpy()

        fig, axes = plt.subplots(5, 3, figsize=(12, 15))
        for i in range(5):
            axes[i, 0].imshow(inputs[i].transpose(1, 2, 0))
            axes[i, 0].set_title("Input Image")
            axes[i, 0].axis("off")
            axes[i, 1].imshow(masks[i][0], cmap="gray")
            axes[i, 1].set_title("Ground Truth Mask")
            axes[i, 1].axis("off")
            axes[i, 2].imshow(predictions[i][0], cmap="gray")
            axes[i, 2].set_title("Predicted Mask")
            axes[i, 2].axis("off")
        plt.tight_layout()
        plt.show()


In [None]:
import os
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader, Dataset

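# GPU setup
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Path to the saved model file
model_path = "./best_model_30_0.0016.pth"

# Load the model. The training cell saved the whole model object with
# torch.save(model, ...), so weights_only=False is required on PyTorch versions
# where torch.load defaults to weights_only=True.
model = torch.load(model_path, map_location=device, weights_only=False)

# Switch to evaluation mode
model.eval()

# Test data path
test_input_dir = "/path/to/test_input_images"  # folder of grayscale test images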

# Test dataset definition
class TestDataset(Dataset):
    def __init__(self, input_dir, transform=None):
        self.input_dir = input_dir
        self.image_files = sorted(os.listdir(input_dir))
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        input_path = os.path.join(self.input_dir, self.image_files[idx])
        input_image = Image.open(input_path).convert("RGB")
        if self.transform:
            input_image = self.transform(input_image)
        return input_image, self.image_files[idx]  # return the image and its filename

# Transform setup
test_transform = transforms.ToTensor()

# Create the test dataset and DataLoader
test_dataset = TestDataset(test_input_dir, transform=test_transform)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Run the test and visualize the results
def test_and_visualize(model, test_loader, device, output_dir):
    model.to(device)
    os.makedirs(output_dir, exist_ok=True)

    with torch.no_grad():
        for inputs, filenames in test_loader:
            inputs = inputs.to(device)
            outputs = model(inputs)['out']
            predictions = torch.sigmoid(outputs).cpu().numpy()

            for i, pred in enumerate(predictions):
                pred_mask = pred[0]  # (1, H, W) -> (H, W)
                input_image = inputs[i].cpu().numpy().transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)

                # Visualize and save
                fig, axes = plt.subplots(1, 2, figsize=(8, 4))
                axes[0].imshow(input_image, cmap="gray")
                axes[0].set_title("Input Image")
                axes[0].axis("off")

                axes[1].imshow(pred_mask, cmap="gray")
                axes[1].set_title("Predicted Mask")
                axes[1].axis("off")

                plt.tight_layout()
                plt.show()

                # Save the predicted mask
                save_path = os.path.join(output_dir, filenames[i])
                pred_mask_uint8 = (pred_mask * 255).astype(np.uint8)
                Image.fromarray(pred_mask_uint8).save(save_path)

# Output path for the test results
output_dir = "./test_predictions"

# Run the test and visualization
test_and_visualize(model, test_loader, device, output_dir)


# 2. Color Restoration Model

![epoch14](https://github.com/Geon-05/submit_imagepro/blob/main/%EC%A0%9C%EC%B6%9C%EC%9A%A9%EC%9D%B4%EB%AF%B8%EC%A7%80/%EC%BB%AC%EB%9F%AC%EB%B3%B5%EC%9B%90.png?raw=true)

In [None]:
import os
import glob
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms, models
import numpy as np
import cv2
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
from skimage.metrics import structural_similarity as ssim
from torchvision.models import resnet50, ResNet50_Weights

####################
# Parameter settings
####################
gray_dir = "/root/.cache/kagglehub/datasets/geon05/dataset2/versions/1/train_input"
color_dir = "/root/.cache/kagglehub/datasets/geon05/damages-masks/versions/1/damage_images/damage_images"
mask_dir = "/root/.cache/kagglehub/datasets/geon05/damages-masks/versions/1/output_masks/output_masks"

batch_size =20
lr = 1e-4
epochs = 50
test_size = 0.2
lambda_perc = 0.1  # weight of the perceptual loss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

####################
# Lab conversion functions
####################
def rgb_to_lab_normalized(rgb):
    rgb_np = (rgb.permute(1,2,0).numpy() * 255).astype(np.uint8)
    bgr = cv2.cvtColor(rgb_np, cv2.COLOR_RGB2BGR)
    lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2Lab).astype(np.float32)
    L = lab[:,:,0] / 255.0
    a = (lab[:,:,1] - 128.0)/128.0
    b = (lab[:,:,2] - 128.0)/128.0
    return L, a, b

def lab_to_rgb(L, a, b):
    lab_0_255 = np.zeros((L.shape[0], L.shape[1], 3), dtype=np.float32)
    lab_0_255[:,:,0] = L * 255.0
    lab_0_255[:,:,1] = a * 128.0 + 128.0
    lab_0_255[:,:,2] = b * 128.0 + 128.0
    lab_0_255 = np.clip(lab_0_255, 0, 255).astype(np.uint8)
    bgr = cv2.cvtColor(lab_0_255, cv2.COLOR_Lab2BGR)
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    rgb = np.clip(rgb,0,255).astype(np.uint8)
    return rgb / 255.0

####################
# SSIM, Histogram Similarity
####################
def ssim_score(true, pred):
    return ssim(true, pred, channel_axis=-1, data_range=1.0)

def histogram_similarity(true, pred):
    true_bgr = cv2.cvtColor((true*255).astype(np.uint8), cv2.COLOR_RGB2BGR)
    pred_bgr = cv2.cvtColor((pred*255).astype(np.uint8), cv2.COLOR_RGB2BGR)
    true_hsv = cv2.cvtColor(true_bgr, cv2.COLOR_BGR2HSV)
    pred_hsv = cv2.cvtColor(pred_bgr, cv2.COLOR_BGR2HSV)
    hist_true = cv2.calcHist([true_hsv], [0], None, [180], [0, 180])
    hist_pred = cv2.calcHist([pred_hsv], [0], None, [180], [0, 180])
    hist_true = cv2.normalize(hist_true, hist_true).flatten()
    hist_pred = cv2.normalize(hist_pred, hist_pred).flatten()
    return cv2.compareHist(hist_true, hist_pred, cv2.HISTCMP_CORREL)

####################
# VGGPerceptualLoss definition
####################
class VGGPerceptualLoss(nn.Module):
    def __init__(self, layer_ids=(3, 8)):  # tuple default avoids a shared mutable list
        super().__init__()
        vgg = models.vgg16(weights=models.VGG16_Weights.DEFAULT).features
        self.layers = nn.ModuleList([vgg[i] for i in range(max(layer_ids)+1)])
        self.layer_ids = layer_ids
        for param in self.layers.parameters():
            param.requires_grad = False
    def forward(self, x):
        mean = torch.tensor([0.485,0.456,0.406], device=x.device).view(1,3,1,1)
        std = torch.tensor([0.229,0.224,0.225], device=x.device).view(1,3,1,1)
        x = (x - mean) / std
        feats = []
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i in self.layer_ids:
                feats.append(x)
        return feats

####################
# UpConv and DoubleConv classes
####################
class UpConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(UpConv, self).__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
    def forward(self, x):
        return self.up(x)

class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.conv(x)

####################
# Dataset definition (with masks)
####################
class DamagedGrayColorDataset(Dataset):
    def __init__(self, gray_paths, color_paths, mask_paths, transform_gray=None, transform_color=None):
        self.gray_paths = gray_paths
        self.color_paths = color_paths
        self.mask_paths = mask_paths
        self.transform_gray = transform_gray
        self.transform_color = transform_color
        assert len(self.gray_paths) == len(self.color_paths) == len(self.mask_paths)

    def __len__(self):
        return len(self.gray_paths)

    def __getitem__(self, idx):
        g_path = self.gray_paths[idx]
        c_path = self.color_paths[idx]
        m_path = self.mask_paths[idx]

        gray_img = Image.open(g_path).convert('L')
        color_img = Image.open(c_path).convert('RGB')
        mask_img = Image.open(m_path).convert('L')  # range 0-255; 255 presumably marks damage

        # Binarize the mask
        mask_np = np.array(mask_img)
        mask_bin = (mask_np > 128).astype(np.float32)
        mask_bin = torch.from_numpy(mask_bin).unsqueeze(0) # [1,H,W]

        if self.transform_gray:
            gray_img = self.transform_gray(gray_img)
        if self.transform_color:
            color_img = self.transform_color(color_img)

        # Convert the target to Lab
        L_t, a_t, b_t = rgb_to_lab_normalized(color_img)
        a_t = torch.from_numpy(a_t).unsqueeze(0)
        b_t = torch.from_numpy(b_t).unsqueeze(0)
        ab_t = torch.cat([a_t, b_t], dim=0)
        L_t = torch.from_numpy(L_t).unsqueeze(0)

        # Gray -> RGB -> Lab, keeping the L channel
        gray_3ch = torch.cat([gray_img,gray_img,gray_img], dim=0)
        G_L, G_a, G_b = rgb_to_lab_normalized(gray_3ch)
        G_L = torch.from_numpy(G_L).unsqueeze(0)

        return G_L, ab_t, L_t, mask_bin

transform_gray = transforms.Compose([
    transforms.Resize((512,512)),
    transforms.ToTensor()
])

transform_color = transforms.Compose([
    transforms.Resize((512,512)),
    transforms.ToTensor()
])

gray_files = sorted(glob.glob(os.path.join(gray_dir, "*")))
color_files = sorted(glob.glob(os.path.join(color_dir, "*")))

mask_files = []
for gf in gray_files:
    fname = os.path.basename(gf)
    m_path = os.path.join(mask_dir, fname)
    mask_files.append(m_path)

train_gray, val_gray, train_color, val_color, train_mask, val_mask = train_test_split(
    gray_files, color_files, mask_files, test_size=test_size, random_state=42
)

train_dataset = DamagedGrayColorDataset(train_gray, train_color, train_mask, transform_gray, transform_color)
val_dataset = DamagedGrayColorDataset(val_gray, val_color, val_mask, transform_gray, transform_color)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

####################
# ResNetEncoder (ResNet-50)
####################
class ResNetEncoder(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        net = models.resnet50(weights=ResNet50_Weights.DEFAULT if pretrained else None)
        self.initial = nn.Sequential(net.conv1, net.bn1, net.relu)
        self.maxpool = net.maxpool
        self.layer1 = net.layer1 # 256 channels
        self.layer2 = net.layer2 # 512 channels
        self.layer3 = net.layer3 # 1024 channels
        self.layer4 = net.layer4 # 2048 channels

    def forward(self, x):
        x0 = self.initial(x)   # 64 channels
        x1 = self.maxpool(x0)
        x1 = self.layer1(x1)   # 256 channels
        x2 = self.layer2(x1)   # 512 channels
        x3 = self.layer3(x2)   # 1024 channels
        x4 = self.layer4(x3)   # 2048 channels
        return x0, x1, x2, x3, x4

####################
# Decoder (sized for ResNet-50)
####################
class ResNetUNet(nn.Module):
    def __init__(self, out_ch=2, pretrained=True):
        super().__init__()
        self.encoder = ResNetEncoder(pretrained=pretrained)
        self.up3 = UpConv(2048, 1024)
        self.dec3 = DoubleConv(2048, 1024)
        self.up2 = UpConv(1024, 512)
        self.dec2 = DoubleConv(1024, 512)
        self.up1 = UpConv(512, 256)
        self.dec1 = DoubleConv(512, 256)
        self.up0 = UpConv(256, 64)
        self.dec0 = DoubleConv(128, 64)
        self.up_final = UpConv(64,64)
        self.dec_final = DoubleConv(64,64)
        self.final_out = nn.Conv2d(64, out_ch, 1)

    def forward(self, x):
        x = x.repeat(1,3,1,1)
        x0, x1, x2, x3, x4 = self.encoder(x)
        x_up3 = self.up3(x4)
        x_cat3 = torch.cat([x_up3, x3], dim=1)
        x_dec3 = self.dec3(x_cat3)

        x_up2 = self.up2(x_dec3)
        x_cat2 = torch.cat([x_up2, x2], dim=1)
        x_dec2 = self.dec2(x_cat2)

        x_up1 = self.up1(x_dec2)
        x_cat1 = torch.cat([x_up1, x1], dim=1)
        x_dec1 = self.dec1(x_cat1)

        x_up0 = self.up0(x_dec1)
        x_cat0 = torch.cat([x_up0, x0], dim=1)
        x_dec0 = self.dec0(x_cat0)

        x_upf = self.up_final(x_dec0)
        x_decf = self.dec_final(x_upf)
        out = self.final_out(x_decf)
        return out

model = ResNetUNet(out_ch=2, pretrained=True).to(device)
perceptual_extractor = VGGPerceptualLoss(layer_ids=[3,8]).to(device)
mse_loss = nn.MSELoss(reduction='none')

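def compute_loss(L_img, ab_img, pred_ab, L_t_img, mask):
    # Compute the pixel loss only outside the mask (the damaged region is excluded)
    diff = mse_loss(pred_ab, ab_img)  # per-pixel squared error
    valid_area = (1 - mask)
    diff = diff * valid_area
    denom = valid_area.sum() + 1e-8
    l1_val = diff.sum() / denom  # despite the name, this is a masked MSE, not an L1 loss

    # Note: the RGB images below are rebuilt through NumPy from a detached pred_ab,
    # so the perceptual term carries no gradient; only l1_val drives the update.
    B = L_img.size(0)
    ab_np = ab_img.permute(0,2,3,1).cpu().numpy()
    pred_ab_np = pred_ab.permute(0,2,3,1).detach().cpu().numpy()
    L_t_np = L_t_img[:,0].cpu().numpy()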

    true_rgb_list = []
    pred_rgb_list = []
    for i in range(B):
        trgb = lab_to_rgb(L_t_np[i], ab_np[i][:,:,0], ab_np[i][:,:,1])
        prgb = lab_to_rgb(L_t_np[i], pred_ab_np[i][:,:,0], pred_ab_np[i][:,:,1])
        true_rgb_list.append(trgb)
        pred_rgb_list.append(prgb)

    true_rgb_t = torch.from_numpy(np.stack(true_rgb_list,axis=0)).float().to(device)
    pred_rgb_t = torch.from_numpy(np.stack(pred_rgb_list,axis=0)).float().to(device)
    true_rgb_t = true_rgb_t.permute(0,3,1,2)
    pred_rgb_t = pred_rgb_t.permute(0,3,1,2)

    true_feats = perceptual_extractor(true_rgb_t)
    pred_feats = perceptual_extractor(pred_rgb_t)

    perc_loss_val = torch.tensor(0.0, device=device)
    for ft, fp in zip(true_feats, pred_feats):
        perc_diff = (ft - fp)**2
        down_mask = F.interpolate(valid_area, size=ft.shape[2:], mode='bilinear', align_corners=False)
        down_mask_3ch = down_mask.repeat(1,ft.shape[1],1,1)
        perc_diff = perc_diff * down_mask_3ch
        perc_sum = perc_diff.sum()
        perc_count = down_mask_3ch.sum() + 1e-8
        perc_loss_val += perc_sum / perc_count

    total_loss = l1_val + lambda_perc * perc_loss_val
    return total_loss, l1_val, perc_loss_val

def visualize_samples(model, data_loader, device, num_samples=3):
    model.eval()
    with torch.no_grad():
        val_iter = iter(data_loader)
        L_batch, ab_batch, L_t_batch, mask_batch = next(val_iter)
        L_batch = L_batch.to(device)
        pred_ab = model(L_batch)

    fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4 * num_samples))
    if num_samples == 1:
        axes = [axes]

    for i in range(num_samples):
        L_np = L_batch[i,0].cpu().numpy()
        ab_true = ab_batch[i].permute(1,2,0).numpy()
        L_t_np = L_t_batch[i,0].numpy()
        pred_ab_np = pred_ab[i].cpu().permute(1,2,0).numpy()

        true_rgb = lab_to_rgb(L_t_np, ab_true[:,:,0], ab_true[:,:,1])
        pred_rgb = lab_to_rgb(L_t_np, pred_ab_np[:,:,0], pred_ab_np[:,:,1])

        axes[i][0].imshow(L_np, cmap='gray')
        axes[i][0].set_title("Damaged Gray Input")
        axes[i][0].axis('off')

        axes[i][1].imshow(true_rgb)
        axes[i][1].set_title("Target Color (using L_t)")
        axes[i][1].axis('off')

        axes[i][2].imshow(pred_rgb)
        axes[i][2].set_title("Predicted Color")
        axes[i][2].axis('off')

    plt.tight_layout()
    plt.show()

optimizer = optim.Adam(model.parameters(), lr=lr)

checkpoint_path = "/content/20241207_05_best_model_1_0.046.pth"
if os.path.exists(checkpoint_path):
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    print(f"Loaded model state from {checkpoint_path}, continuing training...")

best_val_loss = float('inf')
best_combined_metric = float('-inf')

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    train_pbar = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{epochs}] Training")
    for L_img, ab_img, L_t_img, mask in train_pbar:
        L_img, ab_img, L_t_img, mask = L_img.to(device), ab_img.to(device), L_t_img.to(device), mask.to(device)
        optimizer.zero_grad()
        pred_ab = model(L_img)
        total_loss, l1_val, perc_val = compute_loss(L_img, ab_img, pred_ab, L_t_img, mask)
        total_loss.backward()
        optimizer.step()
        running_loss += total_loss.item()
        train_pbar.set_postfix(loss=total_loss.item(), L1=l1_val.item(), Perc=perc_val.item())

    avg_train_loss = running_loss / len(train_loader)

    model.eval()
    val_loss = 0.0
    ssim_list = []
    hist_sim_list = []
    val_pbar = tqdm(val_loader, desc=f"Epoch [{epoch+1}/{epochs}] Validation")

    with torch.no_grad():
        for i, (L_img, ab_img, L_t_img, mask) in enumerate(val_pbar):
            L_img, ab_img, L_t_img, mask = L_img.to(device), ab_img.to(device), L_t_img.to(device), mask.to(device)
            pred_ab = model(L_img)
            total_loss, l1_val, perc_val = compute_loss(L_img, ab_img, pred_ab, L_t_img, mask)
            val_loss += total_loss.item()
            val_pbar.set_postfix(val_loss=total_loss.item(), L1=l1_val.item(), Perc=perc_val.item())

            # Compute SSIM and histogram similarity only over the area outside the mask
            # Measure just a few images from the first batch (keeps the original code's logic)
            if i == 0:
                num_samples = min(3, ab_img.size(0))
                L_np_all = L_img[:,0].cpu().numpy()
                ab_t_np_all = ab_img.permute(0,2,3,1).cpu().numpy()
                pred_ab_np_all = pred_ab.permute(0,2,3,1).cpu().numpy()
                L_t_np_all = L_t_img[:,0].cpu().numpy()
                mask_np_all = mask[:,0].cpu().numpy()
                for j in range(num_samples):
                    true_rgb = lab_to_rgb(L_t_np_all[j], ab_t_np_all[j][:,:,0], ab_t_np_all[j][:,:,1])
                    pred_rgb = lab_to_rgb(L_t_np_all[j], pred_ab_np_all[j][:,:,0], pred_ab_np_all[j][:,:,1])
                    # Exclude the mask: set the masked area to 0
                    valid_area = (1 - mask_np_all[j]) # [H,W]
                    valid_area_3ch = np.stack([valid_area]*3, axis=-1) # [H,W,3]

                    # Zero out the masked region in both images
                    true_rgb = true_rgb * valid_area_3ch
                    pred_rgb = pred_rgb * valid_area_3ch

                    current_ssim = ssim_score(true_rgb, pred_rgb)
                    ssim_list.append(current_ssim)
                    current_hist_sim = histogram_similarity(true_rgb, pred_rgb)
                    hist_sim_list.append(current_hist_sim)

    avg_val_loss = val_loss / len(val_loader)
    print(f"Epoch [{epoch+1}/{epochs}] Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

    if len(ssim_list) > 0:
        mean_ssim = np.mean(ssim_list)
        mean_hist_sim = np.mean(hist_sim_list)
        print(f"Epoch [{epoch+1}] Mean SSIM (no-mask): {mean_ssim:.4f}, Mean Histogram Similarity (no-mask): {mean_hist_sim:.4f}")
        combined_metric = mean_ssim * mean_hist_sim
        print(f"Epoch [{epoch+1}] Combined Metric (no-mask): {combined_metric:.4f}")

        # The best model is chosen by quality outside the mask
        if combined_metric > best_combined_metric:
            best_combined_metric = combined_metric
            torch.save(model.state_dict(), f"20241207_05_best_model_{epoch+1}_{avg_val_loss:.3f}.pth")
            print("** Best Model Updated and Saved! **")

    visualize_samples(model, val_loader, device, num_samples=3)


In [None]:
import os
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import numpy as np
import cv2
import matplotlib.pyplot as plt
from torchvision import transforms, models
from torchvision.models import resnet50, ResNet50_Weights

#------------- Model definitions (must match the code used for training) -------------#
def rgb_to_lab_normalized(rgb):
    rgb_np = (rgb.permute(1,2,0).numpy() * 255).astype(np.uint8)
    bgr = cv2.cvtColor(rgb_np, cv2.COLOR_RGB2BGR)
    lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2Lab).astype(np.float32)
    L = lab[:,:,0] / 255.0
    a = (lab[:,:,1] - 128.0)/128.0
    b = (lab[:,:,2] - 128.0)/128.0
    return L, a, b

def lab_to_rgb(L, a, b):
    lab_0_255 = np.zeros((L.shape[0], L.shape[1], 3), dtype=np.float32)
    lab_0_255[:,:,0] = L * 255.0
    lab_0_255[:,:,1] = a * 128.0 + 128.0
    lab_0_255[:,:,2] = b * 128.0 + 128.0

    lab_0_255 = np.clip(lab_0_255, 0, 255).astype(np.uint8)
    bgr = cv2.cvtColor(lab_0_255, cv2.COLOR_Lab2BGR)
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    rgb = np.clip(rgb,0,255).astype(np.uint8)
    return rgb / 255.0

class UpConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(UpConv, self).__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
    def forward(self, x):
        return self.up(x)

class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.conv(x)

class ResNetEncoder(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        net = models.resnet50(weights=ResNet50_Weights.DEFAULT if pretrained else None)
        self.initial = nn.Sequential(net.conv1, net.bn1, net.relu)
        self.maxpool = net.maxpool
        self.layer1 = net.layer1 # 256 channels
        self.layer2 = net.layer2 # 512 channels
        self.layer3 = net.layer3 # 1024 channels
        self.layer4 = net.layer4 # 2048 channels

    def forward(self, x):
        x0 = self.initial(x)   # 64 channels
        x1 = self.maxpool(x0)
        x1 = self.layer1(x1)   #256
        x2 = self.layer2(x1)   #512
        x3 = self.layer3(x2)   #1024
        x4 = self.layer4(x3)   #2048
        return x0, x1, x2, x3, x4

class ResNetUNet(nn.Module):
    def __init__(self, out_ch=2, pretrained=True):
        super().__init__()
        self.encoder = ResNetEncoder(pretrained=pretrained)
        self.up3 = UpConv(2048, 1024)
        self.dec3 = DoubleConv(2048, 1024)

        self.up2 = UpConv(1024, 512)
        self.dec2 = DoubleConv(1024, 512)

        self.up1 = UpConv(512, 256)
        self.dec1 = DoubleConv(512, 256)

        self.up0 = UpConv(256, 64)
        self.dec0 = DoubleConv(128, 64)

        self.up_final = UpConv(64,64)
        self.dec_final = DoubleConv(64,64)
        self.final_out = nn.Conv2d(64, out_ch, 1)

    def forward(self, x):
        x = x.repeat(1,3,1,1)
        x0, x1, x2, x3, x4 = self.encoder(x)

        x_up3 = self.up3(x4)           
        x_cat3 = torch.cat([x_up3, x3], dim=1) 
        x_dec3 = self.dec3(x_cat3)     

        x_up2 = self.up2(x_dec3)       
        x_cat2 = torch.cat([x_up2, x2], dim=1)
        x_dec2 = self.dec2(x_cat2)     

        x_up1 = self.up1(x_dec2)       
        x_cat1 = torch.cat([x_up1, x1], dim=1)
        x_dec1 = self.dec1(x_cat1)     

        x_up0 = self.up0(x_dec1)       
        x_cat0 = torch.cat([x_up0, x0], dim=1)
        x_dec0 = self.dec0(x_cat0)

        x_upf = self.up_final(x_dec0)
        x_decf = self.dec_final(x_upf)
        out = self.final_out(x_decf)
        return out

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model
model = ResNetUNet(out_ch=2, pretrained=False).to(device)
model.load_state_dict(torch.load("/home/zqrc05/project/imagepro/test/model/grayTocol/12_05_best_model_7_0.043.pth", map_location=device))
model.eval()

# Test directories
test_gray_dir = "/home/zqrc05/project/imagepro/test/test_input"  # folder of damaged grayscale images
test_mask_dir = "/home/zqrc05/project/imagepro/test/mask"  # folder of mask images
output_dir = "/home/zqrc05/project/imagepro/test/output_grayTocol"     # folder where restored images are saved

os.makedirs(output_dir, exist_ok=True)

transform_gray = transforms.Compose([
    transforms.Resize((512,512)),
    transforms.ToTensor()
])

# Process every png image in test_gray_dir
gray_image_paths = sorted(glob.glob(os.path.join(test_gray_dir, "*.png")))

with torch.no_grad():
    for gray_path in gray_image_paths:
        fname = os.path.basename(gray_path)  # e.g. TEST_001.png
        mask_path = os.path.join(test_mask_dir, fname) # mask with the same filename

        if not os.path.exists(mask_path):
            print(f"No matching mask found for {fname}, skipping...")
            continue

        # Load the grayscale image
        gray_img = Image.open(gray_path).convert('L')
        gray_tensor = transform_gray(gray_img) # [1,H,W]

        # Load the mask (note: mask_bin is not used by the inference steps below)
        mask_img = Image.open(mask_path).convert('L')
        mask_np = np.array(mask_img)
        mask_bin = (mask_np > 128).astype(np.float32)
        mask_bin = torch.from_numpy(mask_bin).unsqueeze(0) # [1,H,W]
        # If the sizes differ, uncomment the line below to apply interpolate
        # mask_bin = F.interpolate(mask_bin.unsqueeze(0), size=(512,512), mode='bilinear', align_corners=False).squeeze(0)

        gray_tensor = gray_tensor.unsqueeze(0).to(device) # [1,1,H,W]
        mask_bin = mask_bin.to(device) # [1,H,W]

        # Model inference
        # (note: training fed the Lab L channel G_L, while this passes the raw gray tensor)
        pred_ab = model(gray_tensor) # [1,2,H,W]

        # Recover the result
        pred_ab_np = pred_ab[0].cpu().permute(1,2,0).numpy()  # [H,W,2]
        L_np = gray_tensor[0,0].cpu().numpy() # [H,W]

        # Lab->RGB reconstruction (L is taken from the gray image)
        pred_rgb = lab_to_rgb(L_np, pred_ab_np[:,:,0], pred_ab_np[:,:,1])

        # Save the result
        out_path = os.path.join(output_dir, fname)
        Image.fromarray((pred_rgb*255).astype(np.uint8)).save(out_path)
        print(f"Saved restored image: {out_path}")
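
The commented-out resize in the cell above uses bilinear interpolation, which produces fractional values along mask edges. Below is a minimal sketch, assuming a binary `[1, H, W]` float mask, of a nearest-neighbour resize that keeps every value at exactly 0 or 1. `resize_binary_mask` is a hypothetical helper name and is not part of the original notebook.

```python
import torch
import torch.nn.functional as F


def resize_binary_mask(mask: torch.Tensor, size=(512, 512)) -> torch.Tensor:
    """Resize a [1, H, W] binary mask to `size` without creating fractional values."""
    # F.interpolate expects a batch dimension: [1, H, W] -> [1, 1, H, W]
    resized = F.interpolate(mask.unsqueeze(0), size=size, mode="nearest")
    return resized.squeeze(0)  # back to [1, H', W']


# Toy mask (assumed data): a 64x64 damaged square inside a 256x256 image
mask = torch.zeros(1, 256, 256)
mask[:, 64:128, 64:128] = 1.0

resized = resize_binary_mask(mask)
print(resized.shape, resized.unique())
```

Nearest-neighbour sampling is the usual choice for label-like data; bilinear resizing would need a second thresholding step to become binary again.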


# 3. Damage restoration model

![epoch14](https://github.com/Geon-05/submit_imagepro/blob/main/%EC%A0%9C%EC%B6%9C%EC%9A%A9%EC%9D%B4%EB%AF%B8%EC%A7%80/%EC%86%90%EC%83%81%EB%B3%B5%EC%9B%90.png?raw=true)
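
The generator defined in the next cell stacks four stride-2 gated convolutions and four stride-2 gated transposed convolutions, all with kernel 4 and padding 1. The arithmetic sketch below traces the spatial size for a 512x512 input, using the standard output-size formulas for convolution and transposed convolution. The helper names `conv_out` and `deconv_out` are illustrative only.

```python
def conv_out(n: int, kernel: int, stride: int, padding: int) -> int:
    """Output size of a convolution along one spatial axis."""
    return (n + 2 * padding - kernel) // stride + 1


def deconv_out(n: int, kernel: int, stride: int, padding: int) -> int:
    """Output size of a transposed convolution (no output_padding, dilation 1)."""
    return (n - 1) * stride - 2 * padding + kernel


size = 512
encoder_sizes = []
for _ in range(4):  # four encoder layers: kernel 4, stride 2, padding 1
    size = conv_out(size, 4, 2, 1)
    encoder_sizes.append(size)

bottleneck = size
# The attention block compares every spatial position with every other one,
# so its score matrix has (H*W) x (H*W) entries per sample.
attention_entries = (bottleneck * bottleneck) ** 2

decoder_sizes = []
for _ in range(4):  # four decoder layers: kernel 4, stride 2, padding 1
    size = deconv_out(size, 4, 2, 1)
    decoder_sizes.append(size)

print(encoder_sizes, attention_entries, decoder_sizes)
```

Each encoder layer halves the resolution and each decoder layer doubles it, so the output matches the 512x512 input. The quadratic attention cost is why attention is applied only at the 32x32 bottleneck.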

In [None]:
import os
import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import matplotlib.pyplot as plt
import time
from tqdm import tqdm
import numpy as np
import cv2
from sklearn.model_selection import train_test_split
from torchvision.models import vgg19, VGG19_Weights

#---------------------------
# Hyperparameters
#---------------------------
batch_size = 4
lr = 0.0002
epochs = 50  # number of additional epochs to train
test_size = 0.2
num_workers = 4

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

input_dir = "/root/.cache/kagglehub/datasets/geon05/damages-masks/versions/1/damage_images/damage_images"
gt_dir = "/root/.cache/kagglehub/datasets/geon05/dataset2/versions/1/train_gt"
mask_dir = "/root/.cache/kagglehub/datasets/geon05/damages-masks/versions/1/output_masks/output_masks"

#---------------------------
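# Perceptual loss definition
#---------------------------
class VGGPerceptualLoss(nn.Module):
    # Note: only the conv layers at `layer_ids` are kept and applied back to back.
    # The ReLU, pooling, and intermediate conv layers of VGG19 are skipped, so the
    # compared features are not the standard VGG19 activations.
    def __init__(self, layer_ids=(0, 5, 10, 19, 28), requires_grad=False):
        super(VGGPerceptualLoss, self).__init__()
        vgg = vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features
        self.layers = nn.ModuleList([vgg[i] for i in layer_ids])
        if not requires_grad:
            # Freeze the feature extractor; it is used only as a fixed metric
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x, y):
        # Accumulate the L1 distance between the features of x and y after each layer
        loss = 0.0
        for layer in self.layers:
            x = layer(x)
            y = layer(y)
            loss += nn.functional.l1_loss(x, y)
        return loss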

vgg_mean = torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1).to(device)
vgg_std = torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1).to(device)

def normalize_vgg_inputs(img):
    return (img - vgg_mean) / vgg_std

#---------------------------
# Gated convolution layer definition
#---------------------------
class GatedConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, activation=nn.ReLU()):
        super(GatedConv2d, self).__init__()
        self.feature_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.sigmoid = nn.Sigmoid()
        self.activation = activation

    def forward(self, x):
        f = self.feature_conv(x)
        m = self.mask_conv(x)
        gated = self.sigmoid(m)
        if self.activation is not None:
            f = self.activation(f)
        return f * gated

class GatedDeconv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, activation=nn.ReLU()):
        super(GatedDeconv2d, self).__init__()
        self.feature_deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)
        self.mask_deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)
        self.sigmoid = nn.Sigmoid()
        self.activation = activation

    def forward(self, x):
        f = self.feature_deconv(x)
        m = self.mask_deconv(x)
        gated = self.sigmoid(m)
        if self.activation is not None:
            f = self.activation(f)
        return f * gated

#---------------------------
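# Contextual Attention (simplified example)
#---------------------------
class ContextualAttention(nn.Module):
    # Dot-product self-attention over the H*W spatial positions, followed by a 3x3 conv
    def __init__(self, kernel_size=3, stride=1, dilation=1):
        super(ContextualAttention, self).__init__()
        # `dilation` is passed as the conv padding (nn.Conv2d's positional order is
        # kernel_size, stride, padding); the convolution itself is not dilated.
        self.conv = nn.Conv2d(512, 512, kernel_size, stride, padding=dilation, bias=False)
        self.softmax = nn.Softmax(dim=-1)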

    def forward(self, x):
        B,C,H,W = x.size()
        query = x.view(B,C,-1)
        key = x.view(B,C,-1)
        value = x.view(B,C,-1)

        attn = torch.bmm(query.permute(0,2,1), key)
        attn = self.softmax(attn)
        out = torch.bmm(attn, value.permute(0,2,1))
        out = out.permute(0,2,1).view(B,C,H,W)
        out = self.conv(out)
        return out

class InpaintGenerator(nn.Module):
    def __init__(self):
        super(InpaintGenerator, self).__init__()
        self.encoder = nn.Sequential(
            GatedConv2d(4, 64, 4, 2, 1),
            GatedConv2d(64, 128, 4, 2, 1),
            GatedConv2d(128, 256, 4, 2, 1),
            GatedConv2d(256, 512, 4, 2, 1)
        )
        self.contextual_attention = ContextualAttention()
        self.decoder = nn.Sequential(
            GatedDeconv2d(512, 256, 4, 2, 1),
            GatedDeconv2d(256, 128, 4, 2, 1),
            GatedDeconv2d(128, 64, 4, 2, 1),
            GatedDeconv2d(64, 64, 4, 2, 1, activation=nn.ReLU()),
            nn.Conv2d(64, 3, 3, 1, 1),
            nn.Sigmoid()
        )

    def forward(self, x, mask):
        inp = torch.cat((x, mask), dim=1)
        feat = self.encoder(inp)
        feat = self.contextual_attention(feat)
        out = self.decoder(feat)
        return out

class InpaintDiscriminator(nn.Module):
    def __init__(self):
        super(InpaintDiscriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3,64,4,2,1), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64,128,4,2,1), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128,256,4,2,1), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256,1,4,1,1)
        )

    def forward(self, x):
        return self.model(x)

#---------------------------
# Dataset definition
#---------------------------
class InpaintDataset(Dataset):
    def __init__(self, input_paths, gt_paths, mask_paths, transform=None):
        self.input_paths = input_paths
        self.gt_paths = gt_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __len__(self):
        return len(self.input_paths)

    def __getitem__(self, idx):
        inp = Image.open(self.input_paths[idx]).convert("RGB")
        gt = Image.open(self.gt_paths[idx]).convert("RGB")
        mask = Image.open(self.mask_paths[idx]).convert("L")

        if self.transform:
            inp = self.transform(inp)
            gt = self.transform(gt)
            mask = self.transform(mask)

        return inp, gt, mask

#---------------------------
# Collect file paths (the three sorted lists are assumed to line up by file name)
#---------------------------
input_files = sorted([os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith(('png','jpg','jpeg'))])
gt_files = sorted([os.path.join(gt_dir, f) for f in os.listdir(gt_dir) if f.endswith(('png','jpg','jpeg'))])
mask_files = sorted([os.path.join(mask_dir, f) for f in os.listdir(mask_dir) if f.endswith(('png','jpg','jpeg'))])

train_input_paths, val_input_paths, train_gt_paths, val_gt_paths, train_mask_paths, val_mask_paths = train_test_split(
    input_files, gt_files, mask_files, test_size=test_size, random_state=42)

transform = T.Compose([
    T.Resize((512,512)),
    T.ToTensor()
])

train_dataset = InpaintDataset(train_input_paths, train_gt_paths, train_mask_paths, transform=transform)
val_dataset = InpaintDataset(val_input_paths, val_gt_paths, val_mask_paths, transform=transform)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

#---------------------------
# Initialize the models and load the previously trained weights
#---------------------------
generator = InpaintGenerator().to(device)
discriminator = InpaintDiscriminator().to(device)

g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.9))
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.9))

l1_loss = nn.L1Loss()
mse_loss = nn.MSELoss()
bce_loss = nn.BCEWithLogitsLoss()
perceptual_criterion = VGGPerceptualLoss().to(device)

pretrained_generator_path = "/content/2024_03/deepfillv2_generator_finetune_epoch11.pth"
pretrained_discriminator_path = "/content/2024_03/deepfillv2_discriminator_finetune_epoch11.pth"

generator.load_state_dict(torch.load(pretrained_generator_path, map_location=device, weights_only=True))
discriminator.load_state_dict(torch.load(pretrained_discriminator_path, map_location=device, weights_only=True))

print("Loaded the previously trained weights. Continuing training on the same dataset.")

save_dir = "2024_03"
os.makedirs(save_dir, exist_ok=True)

# Track the best model
best_val_loss = float('inf')  # lowest validation loss seen so far

#---------------------------
# Additional training loop
#---------------------------
for epoch in range(epochs):
    generator.train()
    discriminator.train()
    running_g_loss = 0.0
    running_d_loss = 0.0

    for i, (inp, gt, mask) in enumerate(tqdm(train_dataloader, desc=f"Additional Epoch [{epoch+1}/{epochs}]")):
        inp, gt, mask = inp.to(device), gt.to(device), mask.to(device)

        # -----------------
        # Train Discriminator
        # -----------------
        fake = generator(inp, mask)
        real_pred = discriminator(gt)
        fake_pred = discriminator(fake.detach())

        real_label = torch.ones_like(real_pred).to(device)
        fake_label = torch.zeros_like(fake_pred).to(device)

        d_loss_real = bce_loss(real_pred, real_label)
        d_loss_fake = bce_loss(fake_pred, fake_label)
        d_loss = (d_loss_real + d_loss_fake) * 0.5

        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # -----------------
        # Train Generator
        # -----------------
        fake_pred = discriminator(fake)
        real_label_g = torch.ones_like(fake_pred).to(device)
        g_adv_loss = bce_loss(fake_pred, real_label_g)

        fake_norm = normalize_vgg_inputs(fake)
        gt_norm = normalize_vgg_inputs(gt)

        g_l1 = l1_loss(fake, gt)
        g_mse = mse_loss(fake, gt)
        g_perc = perceptual_criterion(fake_norm, gt_norm)

        g_recon_loss = 10.0 * g_l1 + 5.0 * g_mse + 1.0 * g_perc
        g_loss = g_adv_loss + g_recon_loss

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        running_g_loss += g_loss.item()
        running_d_loss += d_loss.item()

    avg_g_loss = running_g_loss / len(train_dataloader)
    avg_d_loss = running_d_loss / len(train_dataloader)

    print(f"Additional Epoch [{epoch+1}/{epochs}] - G Loss: {avg_g_loss:.4f}, D Loss: {avg_d_loss:.4f}")

    # Validation
    generator.eval()
    val_l1 = 0.0
    with torch.no_grad():
        for inp, gt, mask in val_dataloader:
            inp, gt, mask = inp.to(device), gt.to(device), mask.to(device)
            fake = generator(inp, mask)
            val_l1 += l1_loss(fake, gt).item()

    val_l1 = val_l1 / len(val_dataloader)
    print(f"Validation L1 Loss: {val_l1:.4f}")

    # Save the models at the end of every epoch
    torch.save(generator.state_dict(), f"{save_dir}/01_deepfillv2_generator_finetune_epoch{epoch+1}.pth")
    torch.save(discriminator.state_dict(), f"{save_dir}/01_deepfillv2_discriminator_finetune_epoch{epoch+1}.pth")

    # Update the best model
    if val_l1 < best_val_loss:
        best_val_loss = val_l1
        torch.save(generator.state_dict(), f"{save_dir}/best_generator.pth")
        torch.save(discriminator.state_dict(), f"{save_dir}/best_discriminator.pth")
        print(f"Best model updated: Validation L1 Loss: {val_l1:.4f}, model saved.")

    # Visualization (example; shows the first 3 samples of the first validation batch)
    with torch.no_grad():
        sample_inp, sample_gt, sample_mask = next(iter(val_dataloader))
        sample_inp, sample_gt, sample_mask = sample_inp.to(device), sample_gt.to(device), sample_mask.to(device)
        sample_fake = generator(sample_inp, sample_mask)

        fig, axes = plt.subplots(3,3, figsize=(12,12))
        for idx in range(3):
            axes[idx,0].imshow(sample_inp[idx].cpu().permute(1,2,0).numpy())
            axes[idx,0].set_title("Corrupted")
            axes[idx,0].axis("off")

            axes[idx,1].imshow(sample_gt[idx].cpu().permute(1,2,0).numpy())
            axes[idx,1].set_title("Ground Truth")
            axes[idx,1].axis("off")

            axes[idx,2].imshow(sample_fake[idx].cpu().permute(1,2,0).numpy())
            axes[idx,2].set_title("Inpainted")
            axes[idx,2].axis("off")
        plt.tight_layout()
        plt.show()
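
`VGGPerceptualLoss` above applies five individual conv layers in sequence. A common alternative compares activations taken from contiguous slices of the feature stack, so that each slice includes the intermediate ReLU and pooling layers. The sketch below illustrates that variant on a tiny stand-in network so it runs without downloading weights. `SlicedFeatureLoss`, `cut_points`, and the toy network are assumptions made for illustration and are not part of the original training run.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SlicedFeatureLoss(nn.Module):
    """L1 distance between activations at the end of contiguous feature slices."""

    def __init__(self, features: nn.Sequential, cut_points=(2, 5)):
        super().__init__()
        layers = list(features.children())
        slices, start = [], 0
        for end in cut_points:
            # Each slice continues where the previous one stopped
            slices.append(nn.Sequential(*layers[start:end]))
            start = end
        self.slices = nn.ModuleList(slices)
        for param in self.parameters():
            param.requires_grad = False  # fixed metric, never trained

    def forward(self, x, y):
        loss = x.new_zeros(())
        for block in self.slices:
            x, y = block(x), block(y)
            loss = loss + F.l1_loss(x, y)
        return loss


torch.manual_seed(0)
# Toy stand-in for a pretrained feature extractor (assumed, for illustration only)
toy_features = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(),
)
criterion = SlicedFeatureLoss(toy_features, cut_points=(2, 5)).eval()

a = torch.rand(1, 3, 32, 32)
b = torch.rand(1, 3, 32, 32)
same = criterion(a, a).item()
diff = criterion(a, b).item()
print(same, diff)
```

With torchvision's VGG19, passing `vgg19(...).features` and cut points placed just after a ReLU (for example `(2, 7, 12, 21, 30)`) would be one plausible configuration. Switching the loss changes the training signal, so results would not be directly comparable with the checkpoints above.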


In [None]:
import os
import torch
from PIL import Image
import torchvision.transforms as T
import matplotlib.pyplot as plt
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

############################
# Model definitions (must match the architecture used to train the checkpoints loaded below)
############################

class GatedConv2d(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, activation=torch.nn.ReLU()):
        super(GatedConv2d, self).__init__()
        self.feature_conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.mask_conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.sigmoid = torch.nn.Sigmoid()
        self.activation = activation

    def forward(self, x):
        f = self.feature_conv(x)
        m = self.mask_conv(x)
        gated = self.sigmoid(m)
        if self.activation is not None:
            f = self.activation(f)
        return f * gated

class GatedDeconv2d(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, activation=torch.nn.ReLU()):
        super(GatedDeconv2d, self).__init__()
        self.feature_deconv = torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)
        self.mask_deconv = torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)
        self.sigmoid = torch.nn.Sigmoid()
        self.activation = activation

    def forward(self, x):
        f = self.feature_deconv(x)
        m = self.mask_deconv(x)
        gated = self.sigmoid(m)
        if self.activation is not None:
            f = self.activation(f)
        return f * gated

class ContextualAttention(torch.nn.Module):
    def __init__(self, kernel_size=3, stride=1, dilation=1):
        super(ContextualAttention, self).__init__()
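        # `dilation` is passed as the conv padding (positional order: kernel_size, stride, padding)
        self.conv = torch.nn.Conv2d(512, 512, kernel_size, stride, padding=dilation, bias=False)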
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, x):
        B,C,H,W = x.size()
        query = x.view(B,C,-1)
        key = x.view(B,C,-1)
        value = x.view(B,C,-1)
        attn = torch.bmm(query.permute(0,2,1), key)
        attn = self.softmax(attn)
        out = torch.bmm(attn, value.permute(0,2,1))
        out = out.permute(0,2,1).view(B,C,H,W)
        out = self.conv(out)
        return out

class Stage1Generator(torch.nn.Module):
    def __init__(self):
        super(Stage1Generator, self).__init__()
        self.encoder = torch.nn.Sequential(
            GatedConv2d(4, 64, 4, 2, 1),
            GatedConv2d(64, 128, 4, 2, 1),
            GatedConv2d(128, 256, 4, 2, 1),
            GatedConv2d(256, 512, 4, 2, 1)
        )
        self.decoder = torch.nn.Sequential(
            GatedDeconv2d(512, 256, 4, 2, 1),
            GatedDeconv2d(256, 128, 4, 2, 1),
            GatedDeconv2d(128, 64, 4, 2, 1),
            GatedDeconv2d(64, 64, 4, 2, 1, activation=torch.nn.ReLU()),
            torch.nn.Conv2d(64, 3, 3, 1, 1),
            torch.nn.Sigmoid()
        )

    def forward(self, x, mask):
        inp = torch.cat((x, mask), dim=1)
        feat = self.encoder(inp)
        out = self.decoder(feat)
        return out

class Stage2Generator(torch.nn.Module):
    def __init__(self):
        super(Stage2Generator, self).__init__()
        self.encoder = torch.nn.Sequential(
            GatedConv2d(7, 64, 4, 2, 1),
            GatedConv2d(64, 128, 4, 2, 1),
            GatedConv2d(128, 256, 4, 2, 1),
            GatedConv2d(256, 512, 4, 2, 1)
        )
        self.contextual_attention = ContextualAttention()
        self.decoder = torch.nn.Sequential(
            GatedDeconv2d(512, 256, 4, 2, 1),
            GatedDeconv2d(256, 128, 4, 2, 1),
            GatedDeconv2d(128, 64, 4, 2, 1),
            GatedDeconv2d(64, 64, 4, 2, 1, activation=torch.nn.ReLU()),
            torch.nn.Conv2d(64, 3, 3, 1, 1),
            torch.nn.Sigmoid()
        )

    def forward(self, coarse_out, inp, mask):
        fin_inp = torch.cat((coarse_out, inp, mask), dim=1)
        feat = self.encoder(fin_inp)
        feat = self.contextual_attention(feat)
        out = self.decoder(feat)
        return out

# Paths to the best model weights (coarse and fine generators)
best_coarse_path = "/home/zqrc05/project/imagepro/test/model/colToperPlus/12_best_coarse_generator_epoch9.pth"
best_fine_path = "/home/zqrc05/project/imagepro/test/model/colToperPlus/12_best_fine_generator_epoch9.pth"


# Folders for the test images (colorized but still damaged), the masks, and the output
test_input_dir = "/home/zqrc05/project/imagepro/test/output_grayTocol"
test_mask_dir = "/home/zqrc05/project/imagepro/test/mask"
output_dir = "/home/zqrc05/project/imagepro/test/output_colToper_13"

os.makedirs(output_dir, exist_ok=True)

coarse_generator = Stage1Generator().to(device)
fine_generator = Stage2Generator().to(device)

coarse_generator.load_state_dict(torch.load(best_coarse_path, map_location=device, weights_only=True))
fine_generator.load_state_dict(torch.load(best_fine_path, map_location=device, weights_only=True))

coarse_generator.eval()
fine_generator.eval()

transform = T.Compose([
    T.Resize((512,512)),
    T.ToTensor()
])

test_files = [f for f in os.listdir(test_input_dir) if f.lower().endswith(('.png','.jpg','.jpeg'))]

with torch.no_grad():
    for filename in test_files:
        input_path = os.path.join(test_input_dir, filename)
        mask_path = os.path.join(test_mask_dir, filename)

        if not os.path.exists(mask_path):
            print(f"{mask_path} does not exist. Skipping.")
            continue

        inp_img = Image.open(input_path).convert("RGB")
        mask_img = Image.open(mask_path).convert("L")

        inp_tensor = transform(inp_img).unsqueeze(0).to(device)   # (1,3,H,W)
        mask_tensor = transform(mask_img).unsqueeze(0).to(device) # (1,1,H,W)

        # Zero out the damaged region
        mask_broadcast = mask_tensor.expand_as(inp_tensor)
        damaged_inp = inp_tensor * (1 - mask_broadcast)

        # Restore: coarse pass first, then the fine pass refines it
        coarse_out = coarse_generator(damaged_inp, mask_tensor)
        fine_out = fine_generator(coarse_out, damaged_inp, mask_tensor)

        # Apply the restored output (fine_out) only where the mask is 1;
        # pixels outside the mask keep the colorization model's result
        # final_result = original_damaged_image * (1 - mask) + fine_out * mask
        final_result = inp_tensor * (1 - mask_broadcast) + fine_out * mask_broadcast

        # Convert the result tensor to an image
        final_result_pil = T.ToPILImage()(final_result.squeeze(0).cpu())

        save_path = os.path.join(output_dir, filename)
        final_result_pil.save(save_path)
        print(f"{filename} restored (only the masked region overwritten) -> {save_path}")
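
The final step of the cell above keeps the colorized pixels outside the mask and takes the inpainted pixels inside it. Because the mask is resized with `T.Resize`, its edges can hold fractional values, which blends the two images along the boundary. The NumPy sketch below shows the same compositing rule with an optional hard threshold. `composite` and `threshold` are hypothetical names used only for this illustration.

```python
import numpy as np


def composite(original, restored, mask, threshold=None):
    """Blend `restored` into `original` where `mask` is set.

    original, restored: float arrays of shape [3, H, W] in [0, 1]
    mask: float array of shape [1, H, W]; 1 marks damaged pixels
    threshold: if given, binarize the mask first (hard edges instead of a soft blend)
    """
    if threshold is not None:
        mask = (mask > threshold).astype(original.dtype)
    # Broadcasting expands the single mask channel over the 3 color channels
    return original * (1.0 - mask) + restored * mask


# Toy data (assumed): a gray original, a white restored output, one fractional mask pixel
original = np.full((3, 2, 2), 0.2, dtype=np.float32)
restored = np.ones((3, 2, 2), dtype=np.float32)
mask = np.array([[[1.0, 0.0], [0.6, 0.0]]], dtype=np.float32)

soft = composite(original, restored, mask)
hard = composite(original, restored, mask, threshold=0.5)
print(soft[0], hard[0])
```

A soft mask gives smoother seams, while a hard mask guarantees that pixels outside the damaged region are copied unchanged.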


------------

## 1. Mask generator

In [None]:
import os
from PIL import Image
import torch
import cv2
import numpy as np

# mask_function as provided (user code)
def mask_function(input_path, gt_path):
    try:
        # Load and preprocess images
        input_image = Image.open(input_path).convert("RGB")  # Ensure RGB format
        input_image_np = np.array(input_image)
        gt_image_gray = Image.open(gt_path).convert("L")  # Load the ground-truth image as grayscale
        gt_image_gray_np = np.array(gt_image_gray)

        # Convert input_image_np to grayscale
        input_image_gray_np = cv2.cvtColor(input_image_np, cv2.COLOR_RGB2GRAY)

        # Compute the difference
        difference = cv2.absdiff(gt_image_gray_np, input_image_gray_np)

        # Threshold the difference to create a binary mask
        _, binary_difference = cv2.threshold(difference, 1, 255, cv2.THRESH_BINARY)

        # Morphological closing; with the 1x1 kernel below this leaves the mask unchanged
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 1))
        binary_difference = cv2.morphologyEx(binary_difference, cv2.MORPH_CLOSE, kernel)

        # Find contours
        contours, _ = cv2.findContours(binary_difference, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # Fill the contours to create a mask
        mask_filled = np.zeros_like(binary_difference)
        cv2.drawContours(mask_filled, contours, -1, color=255, thickness=cv2.FILLED)

        # Dilation; also a no-op with the 1x1 kernel (a larger kernel such as (3, 3) would expand the mask)
        mask_filled = cv2.dilate(mask_filled, kernel, iterations=1)

        # Convert the mask to a PyTorch tensor of shape [1, H, W]
        mask_tensor = torch.tensor(mask_filled, dtype=torch.float32).unsqueeze(0) / 255.0  # Normalize mask to [0, 1]

        return mask_tensor
    except Exception as e:
        print(f"Mask creation failed for input: {input_path}, error: {e}")
        # Fall back to an empty (all-zero) mask; the 512x512 size is assumed and may not match the image
        return torch.zeros((1, 512, 512), dtype=torch.float32)


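# Set the damaged-image folder, ground-truth folder, and output mask folder
input_images_dir = "train_input"  # path to the damaged images
gt_images_dir = "train_gt"        # path to the ground-truth images
output_masks_dir = "output_masks"  # path where the generated masks are saved

os.makedirs(output_masks_dir, exist_ok=True)  # create the output folder if it does not exist

# List every image file name in the damaged-image folder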
input_image_files = sorted([f for f in os.listdir(input_images_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])

for filename in input_image_files:
    input_path = os.path.join(input_images_dir, filename)
    gt_path = os.path.join(gt_images_dir, filename)

    # Generate the mask
    mask_tensor = mask_function(input_path, gt_path)

    # mask_tensor has shape [1, H, W] with values in [0, 1];
    # convert it to NumPy and scale to 0-255 so it can be saved as an image
    mask_np = (mask_tensor.squeeze(0).numpy() * 255).astype(np.uint8)

    # Save the mask image
    output_path = os.path.join(output_masks_dir, filename)
    mask_img = Image.fromarray(mask_np)
    mask_img.save(output_path)

    print(f"Saved mask for {filename} at {output_path}")
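
The core of `mask_function` is a thresholded absolute difference between the grayscale ground truth and the grayscale input. The NumPy-only sketch below reproduces that step and adds a 3x3 dilation, for readers who want the mask to extend slightly past the detected damage. `difference_mask` and `dilate3x3` are hypothetical helpers, and the 3x3 size is an assumption rather than the notebook's setting.

```python
import numpy as np


def difference_mask(gt_gray: np.ndarray, input_gray: np.ndarray, threshold: int = 1) -> np.ndarray:
    """Return a uint8 mask: 255 where the images differ by more than `threshold`."""
    # Cast to a signed type first so the subtraction cannot wrap around in uint8
    diff = np.abs(gt_gray.astype(np.int16) - input_gray.astype(np.int16))
    return np.where(diff > threshold, 255, 0).astype(np.uint8)


def dilate3x3(mask: np.ndarray) -> np.ndarray:
    """Binary dilation with a 3x3 square: each pixel takes the max of its neighbourhood."""
    h, w = mask.shape
    padded = np.pad(mask, 1, mode="constant")
    out = np.zeros_like(mask)
    for dy in range(3):
        for dx in range(3):
            out = np.maximum(out, padded[dy:dy + h, dx:dx + w])
    return out


# Toy data (assumed): one clearly damaged pixel and one pixel off by a single gray level
gt = np.full((5, 5), 100, dtype=np.uint8)
damaged = gt.copy()
damaged[2, 2] = 0    # real damage
damaged[0, 0] = 101  # rounding-level difference, ignored by threshold=1

mask = difference_mask(gt, damaged)
expanded = dilate3x3(mask)
print(int(mask.sum()) // 255, int(expanded.sum()) // 255)
```

The threshold of 1 tolerates one-level rounding differences between the two grayscale conversions. The dilation trades a slightly larger inpainting area for fewer leftover damage pixels at the mask border.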


## 2. Color damaged-image generator

In [None]:
import os
import cv2
import numpy as np
from PIL import Image


def mask_function(input_path, gt_path):
    try:
        # Load and preprocess images
        input_image = Image.open(input_path).convert("RGB")  # Ensure RGB format
        input_image_np = np.array(input_image)
        gt_image_gray = Image.open(gt_path).convert("L")  # Load the ground-truth image as grayscale
        gt_image_gray_np = np.array(gt_image_gray)

        # Convert input_image_np to grayscale
        input_image_gray_np = cv2.cvtColor(input_image_np, cv2.COLOR_RGB2GRAY)

        # Compute the difference
        difference = cv2.absdiff(gt_image_gray_np, input_image_gray_np)

        # Threshold the difference to create a binary mask
        _, binary_difference = cv2.threshold(difference, 1, 255, cv2.THRESH_BINARY)

        # Morphological closing; with the 1x1 kernel below this leaves the mask unchanged
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 1))
        binary_difference = cv2.morphologyEx(binary_difference, cv2.MORPH_CLOSE, kernel)

        # Find contours
        contours, _ = cv2.findContours(binary_difference, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # Fill the contours to create a mask
        mask_filled = np.zeros_like(binary_difference)
        cv2.drawContours(mask_filled, contours, -1, color=255, thickness=cv2.FILLED)

        # Dilation; also a no-op with the 1x1 kernel (a larger kernel such as (3, 3) would expand the mask)
        mask_filled = cv2.dilate(mask_filled, kernel, iterations=1)

        return mask_filled  # Return mask as numpy array
    except Exception as e:
        print(f"Mask creation failed for input: {input_path}, error: {e}")
        return np.zeros((512, 512), dtype=np.uint8)


def apply_existing_mask(image, mask):
    damaged_image = image.copy()
    damaged_image[mask == 255] = 0
    return damaged_image


def process_images_in_folder(input_folder, output_folder, damage_folder):
    # Create the output folder if it does not exist
    os.makedirs(output_folder, exist_ok=True)

    # Process every image in the input folder
    for filename in os.listdir(input_folder):
        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            input_path = os.path.join(input_folder, filename)
            damage_path = os.path.join(damage_folder, filename)
            output_path = os.path.join(output_folder, filename)

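            # Load the image (BGR; read and written with OpenCV, so channel order stays consistent)
            image = cv2.imread(input_path)
            if image is None:
                print(f"Skipping {filename}: unable to read file.")
                continue

            # Generate the mask (damaged input first, ground truth second)
            mask = mask_function(damage_path, input_path)

            # Apply the mask
            damaged_image = apply_existing_mask(image, mask)

            # Save the result under the same file name as the damaged input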
            cv2.imwrite(output_path, damaged_image)
            print(f"Processed and saved: {output_path}")


# Usage example
input_folder = "train_gt"  # path to the original (ground-truth) images
damage_folder = "train_input"  # path to the damaged reference images
output_folder = "damage_images"  # path where the damaged color images are saved

process_images_in_folder(input_folder, output_folder, damage_folder)
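
When `mask_function` fails, it returns a fixed 512x512 zero mask. If the image has a different size, the boolean indexing in `apply_existing_mask` raises an error. The sketch below shows one way to guard against that, by deriving the empty mask from the image and checking shapes before masking. `empty_mask_like` and `apply_mask_checked` are hypothetical names and are not part of the original code.

```python
import numpy as np


def empty_mask_like(image: np.ndarray) -> np.ndarray:
    """All-zero uint8 mask with the same height and width as `image`."""
    return np.zeros(image.shape[:2], dtype=np.uint8)


def apply_mask_checked(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Black out pixels where mask == 255, after verifying that the sizes agree."""
    if mask.shape != image.shape[:2]:
        raise ValueError(f"mask shape {mask.shape} does not match image shape {image.shape[:2]}")
    damaged = image.copy()   # leave the caller's array untouched
    damaged[mask == 255] = 0  # the [H, W] boolean index zeroes all channels of selected pixels
    return damaged


# Toy data (assumed): a 4x6 white image with a 2x2 damaged block
image = np.full((4, 6, 3), 255, dtype=np.uint8)
mask = empty_mask_like(image)
mask[1:3, 2:4] = 255

damaged = apply_mask_checked(image, mask)
untouched = apply_mask_checked(image, empty_mask_like(image))
print(int((damaged == 0).all(axis=-1).sum()))
```

Raising on a size mismatch makes a failed mask visible immediately, rather than silently writing an undamaged training image.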
