In [None]:
# Mount Google Drive at /content/drive (Colab only); ROOT_DIR below points into this mount.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# ==================================================
# 0. 의존성
# ==================================================
import os, random, time
from pathlib import Path
from glob import glob
from typing import List, Tuple

import numpy as np
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn as nn
import torch.nn.functional as F                         # 모델 연산용
import torchvision.transforms as T
import torchvision.transforms.functional as TF          # 이미지 변환용

import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import matplotlib.patches as mpatches
import math



In [None]:
# ==================================================
# 1. Global settings (edit as needed)
# ==================================================
ROOT_DIR    = "/content/drive/MyDrive/Data/MVTecAD"  # dataset root
IMG_SIZE    = 256          # side length images/masks are resized to
BATCH_SIZE  = 32
# Ask for at most 4 loader workers, but never more than the host has CPU cores:
# over-subscribing makes DataLoader emit a warning and can slow loading down.
NUM_WORKERS = min(4, os.cpu_count() or 1)
LR          = 1e-4
EPOCHS      = 100
VAL_SPLIT   = 0.2
PRINT_FREQ  = 50           # logging interval during training
SEED        = 42           # seed for the global RNGs seeded below

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every global RNG once, up front, so shuffling and weight initialisation
# are repeatable after "Restart & Run All".
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


In [None]:
# ==================================================
# 2. Dataset
# ==================================================
class MVTecSegDataset(Dataset):
    """Pixel-wise segmentation dataset built from the MVTecAD folder layout.

    Expected layout under ``root_dir`` (one sub-directory per category):
        <category>/test/<defect>/*.png                    -> input image
        <category>/ground_truth/<defect>/<stem>_mask.png  -> binary mask

    Images in a ``good`` folder have no mask file. They are skipped unless
    ``include_good=True``, in which case an all-zero mask is generated for them.
    Defective images whose mask file is missing are skipped.

    Each item is ``(image, mask)``: the image as a float tensor (3,H,W) in
    [0,1] and the mask as a float tensor (1,H,W) containing only 0.0 / 1.0,
    both resized to ``img_size`` x ``img_size``.
    """
    IMG_EXT = (".png", ".jpg", ".jpeg")

    def __init__(self, root_dir: str, include_good: bool=False, img_size: int=256):
        """Index all image/mask pairs under ``root_dir`` and build the transforms.

        Args:
            root_dir:     dataset root containing one directory per category.
            include_good: also index defect-free images (with an all-zero mask).
            img_size:     side length both image and mask are resized to.
        """
        self.img_size = img_size
        # (image_path, mask_path) pairs; mask_path is None for "good" images.
        self.samples: List[tuple] = self._scan_samples(root_dir, include_good)

        print(f"총 {len(self.samples)}개 이미지/마스크 쌍 로드 완료")

        # Images are interpolated bilinearly; masks use nearest-neighbour so
        # resizing never invents intermediate label values.
        self.img_tf = T.Compose([
            T.Resize((img_size, img_size), interpolation=T.InterpolationMode.BILINEAR),
            T.ToTensor(),                                   # (3,H,W) 0~1
        ])
        self.mask_tf = T.Compose([
            T.Resize((img_size, img_size), interpolation=T.InterpolationMode.NEAREST),
            T.ToTensor(),                                   # (1,H,W) 0~1
        ])

    @classmethod
    def _scan_samples(cls, root_dir, include_good=False) -> List[tuple]:
        """Walk ``root_dir`` and return the ordered (image_path, mask_path) list."""
        samples = []
        for category_dir in sorted(Path(root_dir).iterdir()):
            if not category_dir.is_dir():
                continue
            test_dir = category_dir / "test"
            gt_dir   = category_dir / "ground_truth"
            if not (test_dir.exists() and gt_dir.exists()):
                continue

            for defect in sorted(d.name for d in test_dir.iterdir() if d.is_dir()):
                is_good = defect == "good"
                if is_good and not include_good:
                    continue
                # sorted(): glob order depends on the filesystem, and a seeded
                # random_split is only reproducible if the sample order is fixed.
                for img_path in sorted((test_dir / defect).glob("*")):
                    if img_path.suffix.lower() not in cls.IMG_EXT:
                        continue
                    if is_good:
                        samples.append((str(img_path), None))   # normal -> zero mask
                        continue
                    mask_path = gt_dir / defect / f"{img_path.stem}_mask.png"
                    if mask_path.exists():                      # no mask -> skip image
                        samples.append((str(img_path), str(mask_path)))
        return samples

    # Required Dataset methods ----------------------
    def __len__(self):
        """Number of indexed image/mask pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load, resize and return the ``(image, mask)`` tensors for ``idx``."""
        ipath, mpath = self.samples[idx]
        # convert() yields a fully loaded copy, so the file can be closed right away.
        with Image.open(ipath) as src:
            img = src.convert("RGB")
        imgT = self.img_tf(img)

        if mpath is None:
            mask = Image.new("L", img.size, 0)    # defect-free: all zeros
        else:
            with Image.open(mpath) as src:
                mask = src.convert("L")
        # Any non-zero pixel counts as defect, so the mask is exactly 0.0/1.0
        # regardless of how the PNG encodes foreground.
        maskT = (self.mask_tf(mask) > 0).float()   # (1,H,W)
        return imgT, maskT



In [None]:
# ==================================================
# 3. DataLoader setup
# ==================================================
full_ds = MVTecSegDataset(root_dir=ROOT_DIR, img_size=IMG_SIZE, include_good=False)

# Hold out VAL_SPLIT of the samples (rounded down) for validation.
val_len   = int(VAL_SPLIT * len(full_ds))
train_len = len(full_ds) - val_len

# The generator is seeded with a fixed value so the split does not depend on
# the global RNG state.
train_set, val_set = random_split(
    full_ds,
    [train_len, val_len],
    generator=torch.Generator().manual_seed(42),
)

def collate_fn(batch):
    """Turn a list of (image, mask) samples into two batched tensors."""
    image_seq, mask_seq = zip(*batch)
    batched_images = torch.stack(image_seq)   # (B,3,H,W)
    batched_masks = torch.stack(mask_seq)     # (B,1,H,W)
    return batched_images, batched_masks

# Options common to both loaders; only the subset and the shuffling differ.
_loader_opts = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    collate_fn=collate_fn,
)
train_loader = DataLoader(train_set, shuffle=True, **_loader_opts)
val_loader = DataLoader(val_set, shuffle=False, **_loader_opts)



In [None]:
# ==================================================
# 4. 모델 (경량 U‑Net)
# ==================================================
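# TODO: define the lightweight U-Net here and bind an instance to `model` (the final cell calls `model.load_state_dict`).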

In [None]:
# ==================================================
# 5. 손실 & 지표
# ==================================================
# TODO: implement the loss function and the evaluation metrics.

In [None]:
# ==================================================
# 6. train & validate 함수
# ==================================================
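# TODO: implement the train/validate functions and set `best_path` to the saved checkpoint (the final cell loads it).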

In [None]:
# ==================================================
# ★ Segmentation overlay visualization (independent of batch size)
# ==================================================

@torch.no_grad()
def visualize_overlay(model, loader, device=DEVICE, n=6, alpha=0.45):
    """
    Draw ground-truth (green) and predicted (red) masks on top of the inputs.

    Batches are pulled from ``loader`` until ``n`` samples are available, so
    the result does not depend on the batch size. Panels are laid out four per
    row. If the loader holds fewer than ``n`` samples, all of them are shown.

    Args:
        model:  network returning logits; sigmoid(logits) > 0.5 is the
                predicted mask.
        loader: iterable of ``(images, masks)`` batches, indexed here as
                (B,3,H,W) and (B,1,H,W).
        device: device the images are moved to for the forward pass.
        n:      maximum number of samples to display (must be >= 1).
        alpha:  blend weight of the mask colour (0 leaves the image untouched).

    Raises:
        ValueError:   if ``n`` is smaller than 1.
        RuntimeError: if ``loader`` yields no samples.

    Note:
        The model is switched to eval mode and left in that mode.
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")

    model.eval()

    # ---------- collect up to n samples ----------
    imgs_acc, masks_acc, preds_acc = [], [], []
    collected = 0
    for imgs, masks in loader:
        # .cpu() is a no-op for CPU tensors and keeps .numpy() below valid
        # if the loader ever hands out device tensors.
        imgs_acc.append(imgs.cpu())
        masks_acc.append(masks.cpu())
        preds_acc.append((torch.sigmoid(model(imgs.to(device))) > 0.5).cpu())
        collected += imgs.size(0)
        if collected >= n:
            break

    if collected == 0:
        raise RuntimeError("visualize_overlay: loader yielded no samples")

    imgs  = torch.cat(imgs_acc)[:n]   # (k,3,H,W)
    masks = torch.cat(masks_acc)[:n]  # (k,1,H,W)
    preds = torch.cat(preds_acc)[:n]  # (k,1,H,W), bool
    n_show = imgs.size(0)             # k <= n: the loader may run out early

    # ---------- overlay ----------
    gt_color   = np.array([0, 1, 0], dtype=np.float32)   # green
    pred_color = np.array([1, 0, 0], dtype=np.float32)   # red

    cols = 4
    rows = math.ceil(n_show / cols)
    # squeeze=False always returns a 2-D axes array, whatever rows/cols are.
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4), squeeze=False)
    axes = axes.ravel()

    for i in range(n_show):
        img_np = imgs[i].permute(1, 2, 0).numpy()
        # Threshold rather than test "== 1" so masks that are not exactly
        # 1.0 after loading are still drawn.
        gt_sel = masks[i, 0].numpy() > 0.5
        pr_sel = preds[i, 0].numpy().astype(bool)

        overlay = img_np.copy()
        overlay[gt_sel] = (1 - alpha) * overlay[gt_sel] + alpha * gt_color
        # Drawn second, so the prediction colour is blended over the GT colour
        # where both masks overlap.
        overlay[pr_sel] = (1 - alpha) * overlay[pr_sel] + alpha * pred_color

        axes[i].imshow(overlay)
        axes[i].axis("off")
        axes[i].set_title(f"Sample {i}")

    # Blank out the unused panels of the last row.
    for j in range(n_show, rows * cols):
        axes[j].axis("off")

    # Legend, anchored just outside the last panel.
    green_patch = mpatches.Patch(color="green", label="Ground Truth")
    red_patch   = mpatches.Patch(color="red",   label="Prediction")
    axes[-1].legend(handles=[green_patch, red_patch],
                    bbox_to_anchor=(1.05, 1), loc="upper left")

    fig.tight_layout()
    plt.show()

In [None]:
# Load the best checkpoint, then show overlays for 16 validation samples.
# NOTE(review): `model` and `best_path` are not defined in the visible cells
# (sections 4-6 are still TODO); they must exist before this cell is run.
# weights_only=True limits unpickling to tensors and plain containers, which is
# all a state_dict needs, so a tampered checkpoint cannot run arbitrary code.
model.load_state_dict(torch.load(best_path, map_location=DEVICE, weights_only=True))
visualize_overlay(model, val_loader, n=16, alpha=0.4)
