In [1]:
import os
import sys
sys.path.append('..')

import numpy as np
import matplotlib.pyplot as plt
import warnings
from pathlib import Path
warnings.filterwarnings("ignore")

# Detect Google Colab: the `google.colab` package is only importable there.
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False


def find_project_root(start, marker="src"):
    """Return the project root for a notebook started from ``start``.

    If ``start`` already contains the ``marker`` directory it is treated as
    the project root and returned unchanged; otherwise the parent of
    ``start`` is returned.

    This makes the cell idempotent. The cell ends with ``os.chdir`` into the
    project root, so computing "parent of the current directory" again on a
    re-run would otherwise climb one directory higher on every execution.

    Args:
        start: Directory to start from (``str`` or ``Path``).
        marker: Name of a directory expected directly under the project
            root. Defaults to ``"src"``, the package this notebook imports.

    Returns:
        ``Path`` of the project root.
    """
    start = Path(start)
    if (start / marker).is_dir():
        return start
    return start.parent


# Choose PROJECT_ROOT according to the environment.
# NOTE(review): this is a str on Colab and a Path locally; wrap it in
# str()/Path() before any type-specific operation.
if IN_COLAB:
    # Assumes Google Drive is already mounted at /content/drive - TODO confirm.
    PROJECT_ROOT = '/content/drive/Othercomputers/내 Mac/Road_Lane_segmentation'
else:
    PROJECT_ROOT = find_project_root(Path.cwd())

# Make the project importable (needed for the `src` package).
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

# Change the working directory so relative file paths resolve from the root.
os.chdir(PROJECT_ROOT)

print(f"Environment: {'Colab' if IN_COLAB else 'Local'}")
print(f"PROJECT_ROOT: {PROJECT_ROOT}")
print(f"Current working directory: {os.getcwd()}")

Environment: Local
PROJECT_ROOT: /Users/mungughyeon/내 드라이브/likelion/Road_Lane_segmentation
Current working directory: /Users/mungughyeon/내 드라이브/likelion/Road_Lane_segmentation


In [2]:
from src.inference import Inferencer
from src.inference.visualize import mask_to_rgb
from src.utils.log import setup_logger

# Run configuration. All paths are relative, so they are resolved against the
# current working directory at the time they are used.

# Model checkpoint file handed to Inferencer(checkpoint=...).
CHECKPOINT_PATH = "checkpoints/UnetPlusPlus_efficientnet-b4_focal+dice_weight_exp2/best.pt"
# Config file handed to Inferencer(config_path=...).
CONFIG_PATH = "configs/config.yaml"
# Directory scanned for the input images to run inference on.
IMAGE_PATH = "dataset/inference_img"
# Output directory for the predicted mask images.
# NOTE(review): the experiment name here is spelled differently from the
# checkpoint directory ("efficientnet_b4-weight_exp2" vs
# "efficientnet-b4_focal+dice_weight_exp2") - confirm this is intentional.
SAVE_PATH = "inference/UnetPlusPlus_efficientnet_b4-weight_exp2"

In [3]:
import cv2
from tqdm import tqdm

# Batch inference: predict a mask for every image in IMAGE_PATH and save it as
# "<image stem>_predicted_mask.png" inside SAVE_PATH.
os.makedirs(SAVE_PATH, exist_ok=True)
logger = setup_logger(name="InferenceScript")
inferencer = Inferencer(checkpoint=CHECKPOINT_PATH, config_path=CONFIG_PATH)
logger.info("Inferencer loaded successfully.")

valid_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']

# Collect the input images (case-insensitive extension match). Sorted because
# os.listdir returns entries in arbitrary order and the run should be
# reproducible.
# NOTE: two inputs that differ only by extension (a.jpg / a.png) map to the
# same output file name; the later one overwrites the earlier one.
image_files = sorted(
    os.path.join(IMAGE_PATH, filename)
    for filename in os.listdir(IMAGE_PATH)
    if filename.lower().endswith(tuple(valid_extensions))
)

if not image_files:
    logger.warning(
        f"No image files found in {IMAGE_PATH}. Please check the directory and file extensions."
    )
else:
    logger.info(f"Found {len(image_files)} images in {IMAGE_PATH}. Starting inference...")

    saved_count = 0  # masks actually written to disk

    # Run inference on each image and save the result (with a tqdm progress bar)
    for i, image_path in tqdm(
        enumerate(image_files),
        total=len(image_files),
        desc="Processing Images"
    ):
        logger.info(f"Processing image {i + 1}/{len(image_files)}: {image_path}")

        try:
            # Load the image from raw bytes (np.fromfile + imdecode) instead of
            # cv2.imread, so paths with non-ASCII characters can be read.
            image = cv2.imdecode(
                np.fromfile(image_path, dtype=np.uint8),
                cv2.IMREAD_COLOR
            )

            if image is None:
                logger.warning(f"Could not read image {image_path}. Skipping.")
                continue

            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # BGR -> RGB
            predicted_mask = inferencer.predict(image)  # inference

            # Convert the predicted mask to an RGB image
            mask_rgb = mask_to_rgb(
                predicted_mask,
                num_classes=inferencer.num_classes
            )

            # Output path
            original_filename_stem = Path(image_path).stem
            save_file_name = f"{original_filename_stem}_predicted_mask.png"
            save_full_path = os.path.join(SAVE_PATH, save_file_name)

            # Save (RGB -> BGR). Encode in memory and write the bytes ourselves:
            # this mirrors the fromfile/imdecode read path for non-ASCII paths,
            # and, unlike cv2.imwrite (which only returns False on failure),
            # a failed write cannot go unnoticed.
            encoded_ok, png_buffer = cv2.imencode(
                ".png",
                cv2.cvtColor(mask_rgb, cv2.COLOR_RGB2BGR)
            )
            if not encoded_ok:
                logger.error(f"Could not encode predicted mask for {image_path}. Skipping.")
                continue

            png_buffer.tofile(save_full_path)
            saved_count += 1

            logger.info(f"Saved predicted mask to: {save_full_path}")

        except Exception:
            # Best effort: log the traceback and keep going with the next image.
            logger.error(f"Error processing {image_path}", exc_info=True)

    logger.info("Inference for all images completed.")

    # Surface skipped/failed images instead of leaving them buried in the log.
    not_saved = len(image_files) - saved_count
    if not_saved:
        logger.warning(
            f"{not_saved} of {len(image_files)} images were skipped or failed; see the log for details."
        )

Device: mps
Loaded model weights from checkpoint: checkpoints/UnetPlusPlus_efficientnet-b4_focal+dice_weight_exp2/best.pt


Processing Images: 100%|██████████| 246/246 [00:56<00:00,  4.35it/s]
