1. 라이브러리 임포트

In [1]:
import torch
import torchvision
import os
import numpy as np
from pathlib import Path
from PIL import Image
from torchvision import transforms
from src.res_model import ResNetUNet
import cv2

In [2]:
# 모델 설정
unet = ResNetUNet(
    in_channels=1,
    out_channels=2, 
    batch_norm=True, 
    upscale_mode="bilinear"
)

# GPU 설정
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
models_folder = Path("models")
model_name = "best_segmentation.pt"
checkpoint = torch.load(models_folder / model_name, map_location=torch.device("cpu"))
unet.load_state_dict(checkpoint)
unet.to(device)
unet.eval()



ResNetUNet(
  (encoder): ModuleList(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(6

In [3]:
def resize_and_pad(image, size=(1024, 1024)):
    old_size = image.size  # old_size is in (width, height) format
    ratio = float(size[0]) / max(old_size)
    new_size = tuple([int(x * ratio) for x in old_size])
    image = image.resize(new_size, Image.LANCZOS)
    new_image = Image.new("L", size)  # 변경: "L" 모드로 새 이미지 생성
    new_image.paste(image, ((size[0] - new_size[0]) // 2, (size[1] - new_size[1]) // 2))
    return new_image

In [4]:
def postprocess_mask(mask, min_area=1000, max_distance=50):
    # OpenCV로 마스크 후처리
    mask = (mask * 255).astype(np.uint8)
    _, thresh = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
    
    # 작은 조각 제거
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    for contour in contours:
        if cv2.contourArea(contour) < min_area:
            cv2.drawContours(thresh, [contour], -1, 0, thickness=cv2.FILLED)
    
    # 윤곽선 병합을 위한 새로운 마스크 생성
    new_mask = np.zeros_like(thresh)
    
    # 모든 윤곽선 병합
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    for contour in contours:
        cv2.drawContours(new_mask, [contour], -1, 255, thickness=cv2.FILLED)
    
    # 인접한 윤곽선 병합
    for i in range(len(contours) - 1):
        for j in range(i + 1, len(contours)):
            if len(contours[i]) == 0 or len(contours[j]) == 0:
                continue

            left_i = min(contours[i], key=lambda x: x[0][0])
            right_i = max(contours[i], key=lambda x: x[0][0])
            left_j = min(contours[j], key=lambda x: x[0][0])
            right_j = max(contours[j], key=lambda x: x[0][0])

            if np.linalg.norm(left_i - left_j) < max_distance and np.linalg.norm(right_i - right_j) < max_distance:
                cv2.line(new_mask, tuple(left_i[0]), tuple(left_j[0]), 255, 1)
                cv2.line(new_mask, tuple(right_i[0]), tuple(right_j[0]), 255, 1)
                cv2.drawContours(new_mask, [contours[i]], -1, 255, thickness=cv2.FILLED)
                cv2.drawContours(new_mask, [contours[j]], -1, 255, thickness=cv2.FILLED)
                contours[j] = np.array([])  # 병합된 윤곽선은 비웁니다

    # 폐 영역 중간에 구멍 뚫린 부분 채우기
    kernel = np.ones((5, 5), np.uint8)
    new_mask = cv2.morphologyEx(new_mask, cv2.MORPH_CLOSE, kernel)
    
    # 윤곽선 병합 후, convex hull 적용
    contours, _ = cv2.findContours(new_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    for contour in contours:
        hull = cv2.convexHull(contour)
        cv2.fillPoly(new_mask, [hull], 255)

    # 최종 큰 윤곽선 두 개만 남기기
    contours, _ = cv2.findContours(new_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    contours = sorted(contours, key=cv2.contourArea, reverse=True)[:2]
    final_mask = np.zeros_like(new_mask)
    for contour in contours:
        cv2.drawContours(final_mask, [contour], -1, 255, thickness=cv2.FILLED)
    
    return final_mask / 255.0  # 0과 1 사이의 값으로 정규화

In [5]:
# 이미지 세그멘테이션 함수
def CreateSegmentedImage(img_path):
    origin = Image.open(img_path).convert("L")  # 변경: "L" 모드로 변환
    
    # 변환 적용: resize and pad
    origin = resize_and_pad(origin, (1024, 1024))
    
    origin_tensor = transforms.functional.to_tensor(origin).unsqueeze(0) - 0.5
    
    with torch.no_grad():
        origin_tensor = origin_tensor.to(device)
        out = unet(origin_tensor)
        softmax = torch.nn.functional.log_softmax(out, dim=1)
        out = torch.argmax(softmax, dim=1)
    
        origin_tensor = origin_tensor[0].to("cpu")
        out = out[0].to("cpu")
    
    # 마스크 후처리 적용
    processed_mask = postprocess_mask(out.numpy())
    
    # 마스크를 사용하여 원본 이미지에서 폐 부분만 추출
    mask_resized = Image.fromarray((processed_mask * 255).astype(np.uint8)).resize(origin.size, Image.NEAREST)
    segmented_img = Image.composite(origin.convert("RGB"), Image.new('RGB', origin.size), mask_resized.convert("L"))
    
    return segmented_img

In [6]:
# 원본 데이터 경로 및 새로운 데이터 경로 설정
base_data_dir = 'input/chest_xray'
output_base_dir = 'input/segmented_chest_xray'
# 데이터셋 디렉토리 구조
phases = ['train', 'val', 'test']
classes = ['NORMAL', 'PNEUMONIA']

In [7]:
# 이미지 저장 함수
def save_image(img, save_path):
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    img.save(save_path, format='PNG')  # PNG 포맷으로 저장

7. 원본에서 폐영역 분리하고 저장

In [8]:
# 세그멘테이션 및 저장
for phase in phases:
    for class_name in classes:
        input_dir = os.path.join(base_data_dir, phase, class_name)
        output_dir = os.path.join(output_base_dir, phase, class_name)

        print(f"Processing {phase} phase, {class_name} class...")

        for img_name in os.listdir(input_dir):
            img_path = os.path.join(input_dir, img_name)
            print(f"Processing file: {img_name}")

            segmented_img = CreateSegmentedImage(img_path)

            # 동일한 디렉토리 구조와 파일명으로 저장
            save_path = os.path.join(output_dir, img_name)
            save_segmented_image(segmented_img, save_path)

        print(f"Finished processing {phase} phase, {class_name} class.")

Processing train phase, NORMAL class...
Processing file: IM-0409-0001.jpeg


NameError: name 'save_segmented_image' is not defined