In [None]:
import os
import torch
from PIL import Image
import torchvision.transforms as T
import matplotlib.pyplot as plt
import numpy as np
import cv2

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

############################
# (1) 모델 구조 정의 (학습시 사용한 Stage1Generator, Stage2Generator)
############################

class GatedConv2d(torch.nn.Module):
    """Gated 2-D convolution.

    Two convolutions with identical hyper-parameters read the same input: one
    produces features, the other a gate. The (optionally activated) features
    are multiplied element-wise by the sigmoid of the gate.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by both branches.
        kernel_size: Forwarded to both ``torch.nn.Conv2d`` layers.
        stride: Forwarded to both ``torch.nn.Conv2d`` layers.
        padding: Forwarded to both ``torch.nn.Conv2d`` layers.
        activation: Callable applied to the feature branch only; ``None``
            leaves the features un-activated. The default ``ReLU`` instance
            is created once, when the class is defined, and is shared by
            every block that does not pass its own.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, activation=torch.nn.ReLU()):
        super().__init__()
        # Both branches use the same geometry so their outputs line up.
        conv_args = (in_channels, out_channels, kernel_size, stride, padding)
        self.feature_conv = torch.nn.Conv2d(*conv_args)
        self.mask_conv = torch.nn.Conv2d(*conv_args)
        self.sigmoid = torch.nn.Sigmoid()
        self.activation = activation

    def forward(self, x):
        """Return ``activation(feature_conv(x)) * sigmoid(mask_conv(x))``."""
        features = self.feature_conv(x)
        gate = self.sigmoid(self.mask_conv(x))
        if self.activation is None:
            return features * gate
        return self.activation(features) * gate

class GatedDeconv2d(torch.nn.Module):
    """Gated 2-D transposed convolution.

    Two transposed convolutions with identical hyper-parameters read the same
    input: one produces features, the other a gate. The (optionally
    activated) features are multiplied element-wise by the sigmoid of the
    gate.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by both branches.
        kernel_size: Forwarded to both ``torch.nn.ConvTranspose2d`` layers.
        stride: Forwarded to both ``torch.nn.ConvTranspose2d`` layers.
        padding: Forwarded to both ``torch.nn.ConvTranspose2d`` layers.
        activation: Callable applied to the feature branch only; ``None``
            leaves the features un-activated. The default ``ReLU`` instance
            is created once, when the class is defined, and is shared by
            every block that does not pass its own.
    """

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, activation=torch.nn.ReLU()):
        super().__init__()
        # Both branches use the same geometry so their outputs line up.
        deconv_args = (in_channels, out_channels, kernel_size, stride, padding)
        self.feature_deconv = torch.nn.ConvTranspose2d(*deconv_args)
        self.mask_deconv = torch.nn.ConvTranspose2d(*deconv_args)
        self.sigmoid = torch.nn.Sigmoid()
        self.activation = activation

    def forward(self, x):
        """Return ``activation(feature_deconv(x)) * sigmoid(mask_deconv(x))``."""
        features = self.feature_deconv(x)
        gate = self.sigmoid(self.mask_deconv(x))
        if self.activation is None:
            return features * gate
        return self.activation(features) * gate

class ContextualAttention(torch.nn.Module):
    """Self-attention over all spatial positions, followed by a convolution.

    Every spatial position is treated as a token whose descriptor is its
    channel vector. Attention weights are the softmax (over the key axis) of
    the unscaled dot products between tokens; the re-weighted tokens are
    reshaped back to a feature map and passed through a bias-free
    convolution.

    Args:
        kernel_size: Kernel size of the output convolution (int or pair).
        stride: Stride of the output convolution.
        dilation: Dilation of the output convolution (int or pair). Padding
            is derived as ``dilation * (kernel_size - 1) // 2`` so that, for
            odd kernels and stride 1, the spatial size is preserved.
        channels: Number of input/output channels. Defaults to 512, the
            width this layer was originally hard-coded to.

    Note:
        With the default arguments the convolution is configured exactly as
        before (3x3 kernel, padding 1, dilation 1, no bias), so existing
        ``conv.weight`` checkpoints load unchanged. Previously ``dilation``
        was passed in ``Conv2d``'s positional *padding* slot, so it never
        dilated the kernel and any value other than 1 changed the output
        size.
    """

    def __init__(self, kernel_size=3, stride=1, dilation=1, channels=512):
        super().__init__()

        def _pair(value):
            # Accept either a single int or an (h, w) pair, like Conv2d does.
            return (value, value) if isinstance(value, int) else tuple(value)

        padding = tuple(
            d * (k - 1) // 2
            for k, d in zip(_pair(kernel_size), _pair(dilation))
        )
        self.conv = torch.nn.Conv2d(
            channels,
            channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False,
        )
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply attention across the H*W positions of ``x``.

        Args:
            x: Tensor of shape (B, C, H, W); C must equal ``channels``.

        Returns:
            The output of the convolution applied to the attended feature
            map; with the default arguments its shape is (B, C, H, W).
        """
        B, C, H, W = x.size()

        # flatten/reshape (rather than view) also accept non-contiguous input.
        tokens = x.flatten(2)               # (B, C, H*W)
        tokens_t = tokens.transpose(1, 2)   # (B, H*W, C)

        # Query, key and value are all the same tensor.
        attn = self.softmax(torch.bmm(tokens_t, tokens))  # (B, H*W, H*W)
        out = torch.bmm(attn, tokens_t)                   # (B, H*W, C)
        out = out.transpose(1, 2).reshape(B, C, H, W)
        return self.conv(out)

class Stage1Generator(torch.nn.Module):
    """Coarse (stage 1) generator: gated encoder-decoder with skip connections.

    The image and its mask are concatenated into a 4-channel input, reduced
    by four stride-2 gated convolutions and restored by four stride-2 gated
    transposed convolutions. The first three decoder stages concatenate the
    encoder feature map of matching resolution before a 3x3 gated
    convolution. The last stage maps to 3 channels and applies a sigmoid.
    """

    def __init__(self):
        super().__init__()

        # ----- Encoder (each layer halves the spatial size) -----
        self.enc1 = GatedConv2d(4, 64, kernel_size=4, stride=2, padding=1)
        self.enc2 = GatedConv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.enc3 = GatedConv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.enc4 = GatedConv2d(256, 512, kernel_size=4, stride=2, padding=1)

        # ----- Decoder (upsample, then fuse with the encoder skip) -----
        self.dec1_up = GatedDeconv2d(512, 256, kernel_size=4, stride=2, padding=1)
        self.dec1_conv = GatedConv2d(256 + 256, 256, kernel_size=3, stride=1, padding=1)

        self.dec2_up = GatedDeconv2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.dec2_conv = GatedConv2d(128 + 128, 128, kernel_size=3, stride=1, padding=1)

        self.dec3_up = GatedDeconv2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.dec3_conv = GatedConv2d(64 + 64, 64, kernel_size=3, stride=1, padding=1)

        # ----- Output head (no skip connection at full resolution) -----
        self.dec4_up = GatedDeconv2d(64, 64, kernel_size=4, stride=2, padding=1)
        self.dec4_conv = torch.nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)
        self.final_act = torch.nn.Sigmoid()

    def forward(self, x, mask):
        """Produce the coarse reconstruction.

        Args:
            x: Image tensor, (N, 3, H, W).
            mask: Mask tensor, (N, 1, H, W).

        Returns:
            Tensor with 3 channels and sigmoid-bounded values. H and W need
            to be divisible by 16 for the skip concatenations to line up.
        """
        feats = torch.cat((x, mask), dim=1)  # (N, 4, H, W)

        # Run the encoder, remembering every intermediate feature map.
        skips = []
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4):
            feats = encoder(feats)
            skips.append(feats)
        skips.pop()  # the deepest map is the decoder input, not a skip

        # Decoder stages consume the skips from deepest to shallowest.
        decoder_stages = (
            (self.dec1_up, self.dec1_conv),
            (self.dec2_up, self.dec2_conv),
            (self.dec3_up, self.dec3_conv),
        )
        for upsample, fuse in decoder_stages:
            feats = fuse(torch.cat([upsample(feats), skips.pop()], dim=1))

        return self.final_act(self.dec4_conv(self.dec4_up(feats)))

class Stage2Generator(torch.nn.Module):
    """Fine (stage 2) generator: gated U-Net with attention at the bottleneck.

    The coarse result, the input image and the mask are concatenated into a
    7-channel input. The encoder and decoder are four stride-2 gated stages
    each; the deepest encoder feature map passes through
    ``ContextualAttention`` before decoding. The first three decoder stages
    concatenate the encoder feature map of matching resolution. The last
    stage maps to 3 channels and applies a sigmoid.
    """

    def __init__(self):
        super().__init__()

        # ----- Encoder (each layer halves the spatial size) -----
        self.enc1 = GatedConv2d(7, 64, kernel_size=4, stride=2, padding=1)
        self.enc2 = GatedConv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.enc3 = GatedConv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.enc4 = GatedConv2d(256, 512, kernel_size=4, stride=2, padding=1)

        # ----- Attention on the 512-channel bottleneck -----
        self.contextual_attention = ContextualAttention()

        # ----- Decoder (upsample, then fuse with the encoder skip) -----
        self.dec1_up = GatedDeconv2d(512, 256, kernel_size=4, stride=2, padding=1)
        self.dec1_conv = GatedConv2d(256 + 256, 256, kernel_size=3, stride=1, padding=1)

        self.dec2_up = GatedDeconv2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.dec2_conv = GatedConv2d(128 + 128, 128, kernel_size=3, stride=1, padding=1)

        self.dec3_up = GatedDeconv2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.dec3_conv = GatedConv2d(64 + 64, 64, kernel_size=3, stride=1, padding=1)

        # ----- Output head (no skip connection at full resolution) -----
        self.dec4_up = GatedDeconv2d(64, 64, kernel_size=4, stride=2, padding=1)
        self.dec4_conv = torch.nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)
        self.final_act = torch.nn.Sigmoid()

    def forward(self, coarse_out, inp, mask):
        """Refine the coarse reconstruction.

        Args:
            coarse_out: Stage 1 output, (N, 3, H, W).
            inp: Input image, (N, 3, H, W).
            mask: Mask tensor, (N, 1, H, W).

        Returns:
            Tensor with 3 channels and sigmoid-bounded values. H and W need
            to be divisible by 16 for the skip concatenations to line up.
        """
        feats = torch.cat((coarse_out, inp, mask), dim=1)  # (N, 7, H, W)

        # Run the encoder, remembering every intermediate feature map.
        skips = []
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4):
            feats = encoder(feats)
            skips.append(feats)
        skips.pop()  # the deepest map is the decoder input, not a skip

        feats = self.contextual_attention(feats)

        # Decoder stages consume the skips from deepest to shallowest.
        decoder_stages = (
            (self.dec1_up, self.dec1_conv),
            (self.dec2_up, self.dec2_conv),
            (self.dec3_up, self.dec3_conv),
        )
        for upsample, fuse in decoder_stages:
            feats = fuse(torch.cat([upsample(feats), skips.pop()], dim=1))

        return self.final_act(self.dec4_conv(self.dec4_up(feats)))


############################
# (2) Checkpoint paths (best_coarse_generator / best_fine_generator written by the training code)
############################
best_coarse_path = "model/09/coarse_generator_epoch51.pth"
best_fine_path   = "model/09/fine_generator_epoch51.pth"

############################
# (3) Test image folder, mask folder and output folder
############################
test_input_dir = "../02_color/data/output_grayTocol_2025010602"
test_mask_dir  = "../data/output_01_mask"
output_dir     = "data/output_colToper_2025010602"
os.makedirs(output_dir, exist_ok=True)

############################
# (4) Build the models and load the trained weights
############################


def _load_generator(generator_cls, weights_path):
    """Build ``generator_cls`` on ``device``, load its weights, return it in eval mode.

    The checkpoint is expected to be a plain ``state_dict``.
    ``weights_only=True`` restricts unpickling to tensors and basic
    containers instead of executing arbitrary pickled code. The original
    calls omitted it, which newer PyTorch releases warn about.
    """
    generator = generator_cls().to(device)
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    generator.load_state_dict(state_dict)
    generator.eval()
    return generator


coarse_generator = _load_generator(Stage1Generator, best_coarse_path)
fine_generator   = _load_generator(Stage2Generator, best_fine_path)

############################
# (5) Test-time transform
############################
transform = T.Compose([
    T.Resize((512, 512)),  # same resolution as used for training (per the original author)
    T.ToTensor(),
])

############################
# (6) Run the test set (inpaint with a dilated mask, composite with the original mask)
############################
# Sorted so the processing order does not depend on the platform's listdir order.
test_files = sorted(
    f for f in os.listdir(test_input_dir)
    if f.lower().endswith(('.png', '.jpg', '.jpeg'))
)

# Loop invariants, built once instead of once per image.
kernel = np.ones((3, 3), np.uint8)  # a larger kernel (e.g. 5x5) widens the dilation
to_pil = T.ToPILImage()

with torch.no_grad():
    for filename in test_files:
        input_path = os.path.join(test_input_dir, filename)
        # The mask is looked up under the same file name as the image.
        mask_path  = os.path.join(test_mask_dir, filename)

        if not os.path.exists(mask_path):
            print(f"[WARNING] {mask_path}가 존재하지 않아 스킵합니다.")
            continue

        # (6-1) Load the image and the mask; convert() returns a loaded
        #       image, so the file handles can be closed right away.
        with Image.open(input_path) as src:
            inp_img = src.convert("RGB")
        with Image.open(mask_path) as src:
            mask_img = src.convert("L")

        # (6-2) Dilate the mask slightly. np.array() copies the pixels, so
        #       mask_img itself stays untouched and serves as the original mask.
        mask_np = np.array(mask_img, dtype=np.uint8)
        dilated_mask = cv2.dilate(mask_np, kernel, iterations=2)
        expanded_mask_img = Image.fromarray(dilated_mask)

        # (6-3) PIL -> tensors with a leading batch dimension
        inp_tensor        = transform(inp_img).unsqueeze(0).to(device)            # (1,3,H,W)
        expanded_mask_ten = transform(expanded_mask_img).unsqueeze(0).to(device)  # (1,1,H,W)
        original_mask_ten = transform(mask_img).unsqueeze(0).to(device)           # (1,1,H,W)

        # (6-4) Zero out the damaged region (where the dilated mask is 1)
        expanded_mask_bc  = expanded_mask_ten.expand_as(inp_tensor)               # (1,3,H,W)
        damaged_inp       = inp_tensor * (1 - expanded_mask_bc)

        # (6-5) Stage 1 (coarse), driven by the dilated mask
        coarse_out = coarse_generator(damaged_inp, expanded_mask_ten)

        # (6-6) Stage 2 (fine), driven by the dilated mask
        fine_out   = fine_generator(coarse_out, damaged_inp, expanded_mask_ten)

        # (6-7) Use the generated pixels only where the ORIGINAL mask is set;
        #       everywhere else keep the input image.
        orig_mask_bc = original_mask_ten.expand_as(inp_tensor)                    # (1,3,H,W)
        final_result = inp_tensor * (1 - orig_mask_bc) + fine_out * orig_mask_bc

        # (6-8) Save the result under the same file name; the saved image has
        #       the resolution produced by `transform`.
        final_result_pil = to_pil(final_result.squeeze(0).cpu())
        save_path = os.path.join(output_dir, filename)
        final_result_pil.save(save_path)

        print(f"[INFO] {filename} 복원 완료 -> {save_path}")


  coarse_generator.load_state_dict(torch.load(best_coarse_path, map_location=device))
  fine_generator.load_state_dict(torch.load(best_fine_path,   map_location=device))


[INFO] TEST_000.png 복원 완료(마스크 팽창 적용) -> data/output_colToper_2025010602\TEST_000.png
[INFO] TEST_001.png 복원 완료(마스크 팽창 적용) -> data/output_colToper_2025010602\TEST_001.png
[INFO] TEST_002.png 복원 완료(마스크 팽창 적용) -> data/output_colToper_2025010602\TEST_002.png
[INFO] TEST_003.png 복원 완료(마스크 팽창 적용) -> data/output_colToper_2025010602\TEST_003.png
[INFO] TEST_004.png 복원 완료(마스크 팽창 적용) -> data/output_colToper_2025010602\TEST_004.png
[INFO] TEST_005.png 복원 완료(마스크 팽창 적용) -> data/output_colToper_2025010602\TEST_005.png
[INFO] TEST_006.png 복원 완료(마스크 팽창 적용) -> data/output_colToper_2025010602\TEST_006.png
[INFO] TEST_007.png 복원 완료(마스크 팽창 적용) -> data/output_colToper_2025010602\TEST_007.png
[INFO] TEST_008.png 복원 완료(마스크 팽창 적용) -> data/output_colToper_2025010602\TEST_008.png
[INFO] TEST_009.png 복원 완료(마스크 팽창 적용) -> data/output_colToper_2025010602\TEST_009.png
[INFO] TEST_010.png 복원 완료(마스크 팽창 적용) -> data/output_colToper_2025010602\TEST_010.png
[INFO] TEST_011.png 복원 완료(마스크 팽창 적용) -> data/output_colToper_2025