In [None]:
# !pip install segmentation-models-pytorch
# !pip install pytorch_msssim

Collecting pytorch_msssim
  Downloading pytorch_msssim-1.0.0-py3-none-any.whl.metadata (8.0 kB)
Downloading pytorch_msssim-1.0.0-py3-none-any.whl (7.7 kB)
Installing collected packages: pytorch_msssim
Successfully installed pytorch_msssim-1.0.0


In [1]:
import random
import numpy as np
import os

import torch
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import cv2

import zipfile

# Prefer the first CUDA GPU when one is visible to PyTorch; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: cuda:0


In [2]:
# Sanity check: display whether PyTorch can see a CUDA device.
torch.cuda.is_available()

True

In [2]:
# Hyperparameters shared by the cells below.
CFG = dict(
    EPOCHS=2,
    LEARNING_RATE=3e-4,
    BATCH_SIZE=64,  # alternative value 16 was left commented out in the original
    SEED=42,
)

In [3]:
def seed_everything(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, the ``PYTHONHASHSEED`` environment variable,
    NumPy, and PyTorch (CPU and all CUDA devices), and configures cuDNN for
    deterministic behaviour.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Also seed any additional GPUs; a no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN auto-tune (and vary) its convolution algorithm
    # between runs, which defeats the determinism requested above.
    torch.backends.cudnn.benchmark = False

seed_everything(CFG['SEED']) # Fix the random seed for reproducibility

In [4]:
# Loads each stored damaged image together with its original counterpart.

class CustomDataset(Dataset):
    """Paired dataset of damaged images and their original images.

    Pairs are formed by position: both directory listings are sorted and the
    i-th damaged file is matched with the i-th original file.

    Args:
        damage_dir: Directory containing the damaged images.
        origin_dir: Directory containing the original images.
        transform: Optional callable applied separately to each image of a
            pair. Because it is invoked twice, a transform with random
            behaviour would not be applied identically to both images.

    Raises:
        ValueError: If the two directories hold a different number of files,
            since positional pairing would then be wrong or incomplete.
    """

    def __init__(self, damage_dir, origin_dir, transform=None):
        self.damage_dir = damage_dir
        self.origin_dir = origin_dir
        self.transform = transform
        self.damage_files = sorted(os.listdir(damage_dir))
        self.origin_files = sorted(os.listdir(origin_dir))

        # Fail early instead of mispairing images or raising IndexError mid-epoch.
        if len(self.damage_files) != len(self.origin_files):
            raise ValueError(
                f"Cannot pair images by position: {damage_dir!r} has "
                f"{len(self.damage_files)} files but {origin_dir!r} has "
                f"{len(self.origin_files)}."
            )

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.damage_files)

    def __getitem__(self, idx):
        """Return pair ``idx`` as ``{'A': damaged image, 'B': original image}``."""
        damage_img_path = os.path.join(self.damage_dir, self.damage_files[idx])
        origin_img_path = os.path.join(self.origin_dir, self.origin_files[idx])

        damage_img = Image.open(damage_img_path).convert("RGB")
        origin_img = Image.open(origin_img_path).convert("RGB")

        if self.transform:
            damage_img = self.transform(damage_img)
            origin_img = self.transform(origin_img)

        return {'A': damage_img, 'B': origin_img}

In [5]:
from torch.utils.data import random_split, DataLoader
import torchvision.transforms as transforms

# Dataset locations
origin_dir = 'data/train_gt'     # folder of original images
damage_dir = 'data/train_input'  # folder of damaged images
test_dir = 'data/test_input'     # folder of test images

# Preprocessing applied to every image: resize to 256x256, convert to a
# tensor, then normalise with mean 0.5 and std 0.5.
preprocessing_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
]
transform = transforms.Compose(preprocessing_steps)

# Full paired dataset
dataset = CustomDataset(damage_dir=damage_dir, origin_dir=origin_dir, transform=transform)

# Split into training and validation subsets (80% training, 20% validation)
validation_ratio = 0.2
total_size = len(dataset)
train_size = int((1 - validation_ratio) * total_size)
val_size = total_size - train_size
training_dataset, validation_dataset = random_split(dataset, [train_size, val_size])

# DataLoaders: only the training loader shuffles its samples
loader_options = {'batch_size': CFG['BATCH_SIZE'], 'num_workers': 1}
train_dataloader = DataLoader(training_dataset, shuffle=True, **loader_options)
validation_dataloader = DataLoader(validation_dataset, shuffle=False, **loader_options)

In [6]:
# PatchGAN-style discriminator
class PatchGANDiscriminator(nn.Module):
    """Convolutional discriminator that scores a pair of images.

    The two images are concatenated along the channel axis and passed through
    four stride-2 convolution blocks followed by a final single-channel
    convolution, so the output is a spatial map of scores rather than a
    single scalar.

    Args:
        in_channels: Channel count of each of the two input images.
    """

    def __init__(self, in_channels=3):
        super().__init__()

        def discriminator_block(in_filters, out_filters, normalization=True):
            """4x4 stride-2 conv, optional batch norm, then LeakyReLU(0.2)."""
            conv = nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)
            activation = nn.LeakyReLU(0.2, inplace=True)
            if normalization:
                return nn.Sequential(conv, nn.BatchNorm2d(out_filters), activation)
            return nn.Sequential(conv, activation)

        widths = (64, 128, 256, 512)
        # The first block sees both images stacked and skips normalisation.
        stages = [discriminator_block(in_channels * 2, widths[0], normalization=False)]
        for narrow, wide in zip(widths, widths[1:]):
            stages.append(discriminator_block(narrow, wide))
        stages.append(nn.Conv2d(widths[-1], 1, kernel_size=4, padding=1))
        self.model = nn.Sequential(*stages)

    def forward(self, img_A, img_B):
        """Score the pair ``(img_A, img_B)`` after channel-wise concatenation."""
        return self.model(torch.cat([img_A, img_B], dim=1))

# Weight initialisation helper
def weights_init_normal(m):
    """Initialise one module in place; intended for ``model.apply(...)``.

    Modules whose class name contains ``'Conv'`` get weights drawn from
    N(0, 0.02). Modules whose class name contains ``'BatchNorm2d'`` get
    weights drawn from N(1, 0.02) and biases set to 0. Everything else is
    left untouched.

    Matching is done on the class name, so container modules that merely have
    ``Conv`` in their name (and own no ``weight``) and norm layers created
    with ``affine=False`` (whose ``weight``/``bias`` are ``None``) can be
    visited too; those are skipped instead of raising.

    Args:
        m: The module to initialise.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if classname.find('Conv') != -1:
        if isinstance(weight, torch.Tensor):
            torch.nn.init.normal_(weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        if isinstance(weight, torch.Tensor):
            torch.nn.init.normal_(weight.data, 1.0, 0.02)
        if isinstance(bias, torch.Tensor):
            torch.nn.init.constant_(bias.data, 0.0)

In [7]:
from segmentation_models_pytorch import UnetPlusPlus

# U-Net++ generator with an EfficientNet-B4 encoder initialised from ImageNet
# weights; configured for 3 input channels and 3 output channels.
generator = UnetPlusPlus(
    encoder_name="efficientnet-b4",
    encoder_weights="imagenet",
    in_channels=3,
    classes=3,
).to(device)

# `UNetPP` refers to the same model object; the inference cell uses this name.
UNetPP = generator

In [8]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from IPython.display import FileLink
from pytorch_msssim import ssim

# Directory for saved checkpoints and the training log
model_save_dir = "./saved_models"
os.makedirs(model_save_dir, exist_ok=True)
log_file = os.path.join(model_save_dir, "training_log.txt")

best_loss = float("inf")  # not read by the visible cells; kept so the name stays defined
lambda_pixel = 100  # weight of the pixel-wise L1 term in the generator loss

# PatchGAN discriminator, initialised with the notebook's weight-init helper
discriminator = PatchGANDiscriminator().to(device)
discriminator.apply(weights_init_normal)

# Loss functions
criterion_GAN = nn.MSELoss()       # compared against all-ones / all-zeros targets in the loop
criterion_pixelwise = nn.L1Loss()  # pixel-wise reconstruction loss

# Optimizers. The training loop calls optimizer_G / optimizer_D, so both must
# exist before it runs. betas=(0.5, 0.999) is the usual choice for GAN training.
optimizer_G = optim.Adam(generator.parameters(), lr=CFG['LEARNING_RATE'], betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=CFG['LEARNING_RATE'], betas=(0.5, 0.999))

# Colour-histogram similarity
def histogram_similarity(pred, target, bins=256):
    """Return the mean per-channel histogram correlation of two image tensors.

    Both inputs are expected to be normalised with mean 0.5 / std 0.5 (values
    in [-1, 1]), as produced by this notebook's ``transform``. They are mapped
    back to 8-bit pixel levels, a histogram over [0, 256] is built for every
    channel (pooled over the whole batch), and the histograms are compared
    with the correlation measure that ``cv2.HISTCMP_CORREL`` uses.

    Args:
        pred: Tensor shaped (N, C, H, W) or (C, H, W).
        target: Tensor with the same layout as ``pred``.
        bins: Number of histogram bins per channel.

    Returns:
        A Python float in [-1, 1]; 1.0 means identical histograms.
    """
    def channel_histograms(images):
        # Undo the normalisation and quantise to 8-bit pixel levels.
        pixels = (images.detach().float().cpu() * 0.5 + 0.5).clamp(0.0, 1.0) * 255.0
        pixels = pixels.round()
        if pixels.dim() == 3:
            pixels = pixels.unsqueeze(0)
        # (N, C, H, W) -> (C, N*H*W): one flat row of pixel values per channel.
        per_channel = pixels.transpose(0, 1).reshape(pixels.shape[1], -1)
        return torch.stack(
            [torch.histc(row, bins=bins, min=0, max=256) for row in per_channel]
        ).double()

    pred_hist = channel_histograms(pred)
    target_hist = channel_histograms(target)

    pred_centered = pred_hist - pred_hist.mean(dim=1, keepdim=True)
    target_centered = target_hist - target_hist.mean(dim=1, keepdim=True)
    numerator = (pred_centered * target_centered).sum(dim=1)
    denominator = torch.sqrt(
        (pred_centered ** 2).sum(dim=1) * (target_centered ** 2).sum(dim=1)
    )

    # A flat histogram has zero variance; treat that degenerate case as a
    # perfect match, which is how OpenCV's correlation comparison handles it.
    eps = 1e-12
    correlation = torch.where(
        denominator > eps,
        numerator / denominator.clamp_min(eps),
        torch.ones_like(numerator),
    )
    return float(correlation.mean())

# Combined validation loss
def combined_loss(fake_B, real_B):
    """Combine MSE, L1, SSIM and histogram terms into one loss value.

    Both inputs are expected to be normalised with mean 0.5 / std 0.5. Every
    term is expressed as "lower is better", so the two similarity measures
    (SSIM and histogram correlation) enter as ``1 - similarity``.

    Args:
        fake_B: Generated images.
        real_B: Ground-truth images with the same shape as ``fake_B``.

    Returns:
        A scalar tensor:
        ``mse + lambda_pixel * l1 + 0.4 * (1 - ssim) + 0.4 * (1 - hist)``.
    """
    loss_mse = criterion_GAN(fake_B, real_B)
    loss_pixel = criterion_pixelwise(fake_B, real_B)

    # SSIM is called with data_range=1.0, so bring the images to [0, 1] first.
    fake_unit = (fake_B * 0.5 + 0.5).clamp(0.0, 1.0)
    real_unit = (real_B * 0.5 + 0.5).clamp(0.0, 1.0)
    loss_ssim = 1 - ssim(fake_unit, real_unit, data_range=1.0)

    # histogram_similarity grows with similarity; invert it so it acts as a loss.
    loss_hist = 1 - histogram_similarity(fake_B, real_B)

    return loss_mse + lambda_pixel * loss_pixel + 0.4 * loss_ssim + 0.4 * loss_hist

# Best validation score seen so far; checkpoints are overwritten whenever it improves.
best_score = float('-inf')


def _to_unit_range(images):
    """Undo the mean-0.5 / std-0.5 normalisation so values lie in [0, 1]."""
    return (images * 0.5 + 0.5).clamp(0.0, 1.0)


# Training and validation loop
for epoch in range(1, CFG['EPOCHS'] + 1):
    generator.train()
    discriminator.train()
    for i, batch in enumerate(train_dataloader):
        real_A = batch['A'].to(device)  # damaged input
        real_B = batch['B'].to(device)  # original image

        # Generator step: fool the discriminator while staying close to real_B.
        optimizer_G.zero_grad()
        fake_B = generator(real_A)
        pred_fake = discriminator(fake_B, real_A)
        loss_GAN = criterion_GAN(pred_fake, torch.ones_like(pred_fake))
        loss_pixel = criterion_pixelwise(fake_B, real_B)
        loss_G = loss_GAN + lambda_pixel * loss_pixel
        loss_G.backward()
        optimizer_G.step()

        # Discriminator step: real pairs -> 1, generated pairs -> 0.
        # fake_B is detached so this step does not backpropagate into the generator.
        optimizer_D.zero_grad()
        pred_real = discriminator(real_B, real_A)
        loss_real = criterion_GAN(pred_real, torch.ones_like(pred_real))
        pred_fake = discriminator(fake_B.detach(), real_A)
        loss_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))
        loss_D = 0.5 * (loss_real + loss_fake)
        loss_D.backward()
        optimizer_D.step()

        if i % 100 == 0:
            # Discriminator accuracy is only needed for logging, so compute it here.
            real_accuracy = torch.mean((pred_real > 0.5).float())
            fake_accuracy = torch.mean((pred_fake < 0.5).float())
            accuracy = 0.5 * (real_accuracy + fake_accuracy)
            print(f"[Epoch {epoch}/{CFG['EPOCHS']}] [Batch {i}/{len(train_dataloader)}] "
                  f"[D loss: {loss_D.item()}] [G loss: {loss_G.item()}] [D accuracy: {accuracy.item() * 100:.2f}%]")

    # Validation
    generator.eval()
    validation_loss = 0.0
    ssim_scores, masked_ssim_scores, hist_similarities = [], [], []
    with torch.no_grad():
        for val_batch in validation_dataloader:
            real_val_A = val_batch['A'].to(device)
            real_val_B = val_batch['B'].to(device)
            fake_val_B = generator(real_val_A)

            # SSIM is called with data_range=1.0, so it must see [0, 1] images
            # rather than the [-1, 1] tensors produced by the normalisation.
            fake_unit = _to_unit_range(fake_val_B)
            real_unit = _to_unit_range(real_val_B)
            input_unit = _to_unit_range(real_val_A)

            ssim_score = ssim(fake_unit, real_unit, data_range=1.0).item()
            # NOTE(review): the damaged input is used as a per-pixel weight here,
            # not a binary mask of the damaged region — confirm this matches the
            # intended "masked SSIM" definition.
            masked_ssim_score = ssim(fake_unit * input_unit, real_unit * input_unit, data_range=1.0).item()
            hist_similarity = histogram_similarity(fake_val_B, real_val_B)

            ssim_scores.append(ssim_score)
            masked_ssim_scores.append(masked_ssim_score)
            hist_similarities.append(hist_similarity)

            val_loss = combined_loss(fake_val_B, real_val_B)
            validation_loss += val_loss.item()

    if not ssim_scores:
        # Without this guard the averages below fail with ZeroDivisionError.
        raise RuntimeError("Validation dataloader produced no batches; cannot compute a validation score.")

    # Average the per-batch scores and combine them into one weighted score.
    S = sum(ssim_scores) / len(ssim_scores)
    M = sum(masked_ssim_scores) / len(masked_ssim_scores)
    C = sum(hist_similarities) / len(hist_similarities)
    score = (0.2 * S) + (0.4 * M) + (0.4 * C)

    validation_loss /= len(validation_dataloader)
    validation_log_message = f"Validation loss: {validation_loss}, Score: {score}"
    print(validation_log_message)

    # Append the validation result to the text log.
    with open(log_file, "a") as log:
        log.write(validation_log_message + "\n")

    # Save both networks whenever the combined score improves.
    if score > best_score:
        best_score = score
        torch.save(generator.state_dict(), os.path.join(model_save_dir, "best_generator.pth"))
        torch.save(discriminator.state_dict(), os.path.join(model_save_dir, "best_discriminator.pth"))
        save_message = f"Best model saved at epoch {epoch} with Score: {score}"
        print(save_message)
        with open(log_file, "a") as log:
            log.write(save_message + "\n")

# Report the best score once training has finished.
final_message = f"Training completed. Best Score: {best_score}"
print(final_message)

KeyboardInterrupt: 

In [2]:
import cv2
from PIL import Image
import numpy as np
import os

# Output directory for the predicted test images; created if it is missing.
submission_dir = "./submission3"
if not os.path.isdir(submission_dir):
    os.makedirs(submission_dir, exist_ok=True)

# Image loading and preprocessing
def load_image(image_path):
    """Open ``image_path`` as RGB, apply the notebook-level ``transform`` and
    add a leading batch dimension.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The transformed image with an extra first axis of size 1.
    """
    rgb_image = Image.open(image_path).convert("RGB")
    return transform(rgb_image).unsqueeze(0)

# Generator checkpoint written by the training cell whenever the validation
# score improves.
generator_path = os.path.join('saved_models', 'best_generator.pth')

# Side length in pixels of the square images written to the submission folder.
OUTPUT_SIZE = 512

# Load the weights onto the active device so a checkpoint saved on GPU can
# also be restored on a CPU-only machine.
UNetPP.load_state_dict(torch.load(generator_path, map_location=device))
UNetPP.eval()

# Test files, in sorted order
test_images = sorted(os.listdir(test_dir))

# Run inference on every test image
for image_name in test_images:
    test_image_path = os.path.join(test_dir, image_name)

    # Load and preprocess the damaged test image
    test_image = load_image(test_image_path).to(device)

    with torch.no_grad():
        pred_image = UNetPP(test_image)
        pred_image = pred_image.cpu().squeeze(0)  # drop the batch dimension
        # Undo the normalisation, then clamp: the generator output is unbounded,
        # and out-of-range values would wrap around in the uint8 cast below.
        pred_image = (pred_image * 0.5 + 0.5).clamp(0.0, 1.0)
        pred_image = pred_image.numpy().transpose(1, 2, 0)  # CHW -> HWC
        pred_image = (pred_image * 255).astype('uint8')  # scale to 0-255

        # Resize the prediction to the submission resolution
        pred_image_resized = cv2.resize(
            pred_image, (OUTPUT_SIZE, OUTPUT_SIZE), interpolation=cv2.INTER_LINEAR
        )

    # Save the result; cv2.imwrite reports failure via its return value, so check it.
    output_path = os.path.join(submission_dir, image_name)
    if not cv2.imwrite(output_path, cv2.cvtColor(pred_image_resized, cv2.COLOR_RGB2BGR)):
        raise IOError(f"Failed to write {output_path}")

print("Saved all images")


NameError: name 'UNetPP' is not defined

In [None]:
# Bundle the saved result images into a single ZIP archive.
# NOTE(review): the archive is named "submission2.zip" while the images come
# from `submission_dir`; confirm the archive name is the intended one.
zip_filename = "submission2.zip"
archive_members = [
    (os.path.join(submission_dir, image_name), image_name)
    for image_name in test_images
]
with zipfile.ZipFile(zip_filename, 'w') as submission_zip:
    for image_path, image_name in archive_members:
        # Store each file under its bare name, without the directory prefix.
        submission_zip.write(image_path, arcname=image_name)

print(f"All images saved in {zip_filename}")