In [None]:
# Mount Google Drive at /content/drive (Colab-only helper) so that files stored
# on Drive can be accessed through the local filesystem.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# ==================================================
# 0. 의존성
# ==================================================
import os, random, time
from pathlib import Path
from glob import glob
from typing import List, Tuple

import numpy as np
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn as nn
import torch.nn.functional as F                         # 모델 연산용
import torchvision.transforms as T
import torchvision.transforms.functional as TF          # 이미지 변환용

import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import matplotlib.patches as mpatches
import math

In [None]:
# ==================================================
# 1. Global settings (edit as needed)
# ==================================================
ROOT_DIR    = "/content/drive/MyDrive/Data/MVTecAD"  # dataset root directory
IMG_SIZE    = 256          # side length (pixels) that images and masks are resized to
BATCH_SIZE  = 32           # samples per DataLoader batch
NUM_WORKERS = 4            # DataLoader worker processes
LR          = 1e-4         # optimizer learning rate
EPOCHS      = 100          # number of training epochs
VAL_SPLIT   = 0.2          # fraction of the dataset held out for validation
PRINT_FREQ  = 50           # print a training log line every PRINT_FREQ steps

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [None]:
# ==================================================
# 2. Dataset
# ==================================================
class MVTecSegDataset(Dataset):
    """Pixel-wise segmentation dataset built from an MVTec AD style directory tree.

    Expected layout under ``root_dir`` (one sub-directory per object category)::

        <category>/test/<defect>/<name>.<ext>             -> input image
        <category>/ground_truth/<defect>/<name>_mask.png  -> binary defect mask

    Category directories that lack either ``test`` or ``ground_truth`` are
    skipped, as are defect images without a matching mask file. Images in a
    ``good`` folder are only indexed when ``include_good`` is true; they have
    no mask file and are paired with an all-zero mask.

    Samples are indexed in sorted order (category, defect, file name) so that
    the dataset order, and therefore any seeded split of it, does not depend
    on the order in which the filesystem lists directory entries.

    Args:
        root_dir: Dataset root containing the category directories.
        include_good: Whether defect-free ``good`` images are indexed.
        img_size: Side length of the square images and masks that are returned.

    Each item is a pair ``(image, mask)``: a float tensor of shape
    ``(3, img_size, img_size)`` with values in [0, 1] and a float tensor of
    shape ``(1, img_size, img_size)`` containing only 0.0 and 1.0.
    """
    IMG_EXT = (".png", ".jpg", ".jpeg")

    def __init__(self, root_dir: str, include_good: bool=False, img_size: int=256):
        self.img_size = img_size
        # (image_path, mask_path); mask_path is None for defect-free images.
        self.samples: List[Tuple[str, str]] = self._collect_samples(Path(root_dir), include_good)

        print(f"총 {len(self.samples)}개 이미지/마스크 쌍 로드 완료")

        # Transform pipelines. Masks use nearest-neighbour resizing so that
        # resizing does not blend label values.
        self.img_tf = T.Compose([
            T.Resize((img_size, img_size), interpolation=T.InterpolationMode.BILINEAR),
            T.ToTensor(),                                   # (3,H,W) in [0,1]
        ])
        self.mask_tf = T.Compose([
            T.Resize((img_size, img_size), interpolation=T.InterpolationMode.NEAREST),
            T.ToTensor(),                                   # (1,H,W) in [0,1]
        ])

    @classmethod
    def _collect_samples(cls, root, include_good):
        """Walk ``root`` and return the sorted list of (image_path, mask_path) pairs."""
        pairs = []
        for category_dir in sorted(root.iterdir()):
            if not category_dir.is_dir():
                continue
            test_dir = category_dir / "test"
            gt_dir   = category_dir / "ground_truth"
            if not (test_dir.exists() and gt_dir.exists()):
                continue

            for defect in sorted(d.name for d in test_dir.iterdir() if d.is_dir()):
                is_good = defect == "good"
                if is_good and not include_good:
                    continue
                # Path.glob yields entries in arbitrary order; sort for determinism.
                for img_path in sorted((test_dir / defect).glob("*")):
                    if img_path.suffix.lower() not in cls.IMG_EXT:
                        continue
                    if is_good:
                        pairs.append((str(img_path), None))   # normal image -> zero mask
                        continue
                    mask_path = gt_dir / defect / f"{img_path.stem}_mask.png"
                    if mask_path.exists():
                        pairs.append((str(img_path), str(mask_path)))
        return pairs

    # Required Dataset methods ----------------------
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        ipath, mpath = self.samples[idx]
        # ----- load image (the context manager closes the file handle) -----
        with Image.open(ipath) as fh:
            img = fh.convert("RGB")
        imgT = self.img_tf(img)
        # ----- load mask -----
        if mpath is None:
            mask = Image.new("L", img.size, 0)    # all zeros
        else:
            with Image.open(mpath) as fh:
                mask = fh.convert("L")
        # Threshold so the target is strictly {0, 1} even if the mask file
        # contains intermediate grey values.
        maskT = (self.mask_tf(mask) > 0.5).float()  # (1,H,W)
        return imgT, maskT

In [None]:
# ==================================================
# 3. DataLoader setup
# ==================================================
# Index the dataset without the defect-free "good" images.
full_ds = MVTecSegDataset(ROOT_DIR, include_good=False, img_size=IMG_SIZE)
val_len   = int(len(full_ds) * VAL_SPLIT)   # validation share, truncated to an integer
train_len = len(full_ds) - val_len          # everything else is used for training
# The explicitly seeded generator makes the random split repeatable for a given sample order.
train_set, val_set = random_split(full_ds, [train_len, val_len], generator=torch.Generator().manual_seed(42))

def collate_fn(batch):
    """Turn a list of (image, mask) samples into a pair of batched tensors.

    Returns:
        ``(images, masks)`` where each tensor gains a new leading batch
        dimension, e.g. (B,3,H,W) and (B,1,H,W).
    """
    # zip(*batch) regroups the samples column-wise: all images, then all masks.
    images, masks = (torch.stack(column) for column in zip(*batch))
    return images, masks

# Training batches are reshuffled every epoch; validation batches keep a fixed order.
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS, pin_memory=True, collate_fn=collate_fn)
val_loader   = DataLoader(val_set,   batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True, collate_fn=collate_fn)


In [None]:
# ==================================================
# 4. 모델 (경량 U‑Net)
# ==================================================
class DoubleConv(nn.Module):
    """Two consecutive (3x3 Conv -> BatchNorm -> ReLU) stages.

    The first stage maps ``in_c`` to ``out_c`` channels, the second keeps
    ``out_c`` channels. Padding of 1 preserves the spatial size, and the
    convolutions carry no bias term.
    """
    def __init__(self, in_c, out_c):
        super().__init__()
        stages = []
        for stage_in in (in_c, out_c):
            stages += [
                nn.Conv2d(stage_in, out_c, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
            ]
        # Kept under the attribute name `block` so state_dict keys stay the same.
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        out = self.block(x)
        return out

class UNet(nn.Module):
    """U-Net encoder/decoder that outputs per-pixel logits.

    Structure:
        * Encoder: one ``DoubleConv`` per entry of ``feat``, each followed by a
          2x2 max-pool in ``forward``.
        * Bottleneck: ``DoubleConv`` doubling the last encoder width.
        * Decoder: for each encoder width in reverse, a stride-2
          ``ConvTranspose2d`` upsampling, concatenation with the matching
          encoder output (skip connection) and a ``DoubleConv``.
        * Head: 1x1 convolution to ``out_c`` channels; no activation is applied.

    Args:
        in_c: Number of input channels.
        out_c: Number of output (logit) channels.
        feat: Channel widths of the encoder stages, shallow to deep. Any
            sequence is accepted; the default is a tuple so that no mutable
            object is shared between calls.
    """
    def __init__(self, in_c=3, out_c=1, feat=(64, 128, 256, 512)):
        super().__init__()
        feat = list(feat)
        self.downs = nn.ModuleList()
        self.ups   = nn.ModuleList()
        # Encoder (down path)
        for f in feat:
            self.downs.append(DoubleConv(in_c, f))
            in_c = f
        # Bottleneck
        self.bottleneck = DoubleConv(feat[-1], feat[-1]*2)
        # Decoder (up path): `ups` alternates [upsample, DoubleConv, upsample, ...]
        for f in reversed(feat):
            self.ups.append(nn.ConvTranspose2d(f*2, f, kernel_size=2, stride=2))
            self.ups.append(DoubleConv(f*2, f))
        self.final_conv = nn.Conv2d(feat[0], out_c, kernel_size=1)

    def forward(self, x):
        """Map an input of shape (B, in_c, H, W) to logits of shape (B, out_c, H, W)."""
        skips = []
        for down in self.downs:
            x = down(x)
            skips.append(x)
            x = F.max_pool2d(x, 2)
        x = self.bottleneck(x)
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip = skips[-(idx//2 + 1)]
            # Max-pooling floors odd sizes, so the upsampled map can be one
            # row/column short of the skip tensor. Compare spatial dims only
            # and zero-pad on the right/bottom to match.
            pad_h = skip.size(2) - x.size(2)
            pad_w = skip.size(3) - x.size(3)
            if pad_h != 0 or pad_w != 0:
                x = F.pad(x, [0, pad_w, 0, pad_h])
            x = torch.cat([skip, x], dim=1)
            x = self.ups[idx+1](x)
        return self.final_conv(x)

# ==================================================
# 5. 손실 & 지표
# ==================================================
class DiceLoss(nn.Module):
    """Soft Dice loss computed on sigmoid probabilities.

    For every (sample, channel) pair the Dice score
    ``(2*|P*T| + smooth) / (|P| + |T| + smooth)`` is evaluated over the two
    spatial dimensions; the loss is one minus the mean of those scores.

    Args:
        smooth: Constant added to numerator and denominator.
    """
    def __init__(self, smooth: float = 1.):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        spatial_dims = (2, 3)
        prob = logits.sigmoid()
        target = targets.float()
        overlap = torch.sum(prob * target, dim=spatial_dims)
        total = torch.sum(prob, dim=spatial_dims) + torch.sum(target, dim=spatial_dims)
        score = (2 * overlap + self.smooth) / (total + self.smooth)
        return 1 - score.mean()

def dice_coef(logits, targets):
    """Mean hard Dice coefficient of a batch.

    Logits are passed through a sigmoid and thresholded at 0.5; the Dice
    coefficient is then computed per sample over all non-batch dimensions and
    averaged over the batch.

    A sample whose prediction and target are both empty counts as a perfect
    match (Dice 1.0) instead of 0.0, so correctly predicted defect-free
    samples do not drag the metric down.

    Args:
        logits: Raw model outputs of shape (B, C, H, W).
        targets: Binary masks with the same shape as ``logits``.

    Returns:
        The batch-mean Dice coefficient as a Python float.
    """
    eps = 1e-8
    reduce_dims = (1, 2, 3)
    preds = (torch.sigmoid(logits) > 0.5).float()
    targets = targets.float()
    inter = (preds * targets).sum(dim=reduce_dims)
    union = preds.sum(dim=reduce_dims) + targets.sum(dim=reduce_dims)
    # eps in the numerator too: 0/0 (both masks empty) evaluates to 1.
    return ((2 * inter + eps) / (union + eps)).mean().item()



In [None]:
# ==================================================
# 6. train & validate 함수
# ==================================================

def train_one_epoch(model, loader, opt, loss_fn):
    """Run one optimisation pass over ``loader``.

    The objective for each batch is ``loss_fn(logits, masks)`` plus binary
    cross-entropy on the logits. A running average of loss and Dice is printed
    every ``PRINT_FREQ`` steps.

    Args:
        model: Network mapping an image batch to logits.
        loader: Iterable yielding ``(images, masks)`` batches.
        opt: Optimizer updating ``model``'s parameters.
        loss_fn: Callable ``(logits, masks) -> scalar tensor``.

    Returns:
        ``(mean_loss, mean_dice)`` averaged over the processed batches;
        ``(0.0, 0.0)`` if the loader yields no batch.
    """
    model.train()
    tot_loss = tot_dice = 0.
    n_batches = 0
    for step, (imgs, masks) in enumerate(loader, start=1):
        imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
        opt.zero_grad()
        logits = model(imgs)
        # Functional BCE: same value as nn.BCEWithLogitsLoss()(...) without
        # constructing a new module on every step.
        loss = loss_fn(logits, masks) + F.binary_cross_entropy_with_logits(logits, masks)
        loss.backward()
        opt.step()

        tot_loss += loss.item()
        tot_dice += dice_coef(logits.detach(), masks)
        n_batches = step

        if step % PRINT_FREQ == 0:
            print(f"step {step}/{len(loader)}  loss {tot_loss/step:.4f}  dice {tot_dice/step:.4f}")
    # Guard against an empty loader instead of dividing by zero.
    denom = max(n_batches, 1)
    return tot_loss/denom, tot_dice/denom

@torch.no_grad()
def evaluate(model, loader, loss_fn):
    """Compute mean loss and Dice over ``loader`` without updating the model.

    The model is switched to eval mode and gradients are disabled. The loss is
    ``loss_fn(logits, masks)`` plus binary cross-entropy on the logits, i.e.
    the same objective that is used for training.

    Args:
        model: Network mapping an image batch to logits.
        loader: Iterable yielding ``(images, masks)`` batches.
        loss_fn: Callable ``(logits, masks) -> scalar tensor``.

    Returns:
        ``(mean_loss, mean_dice)`` averaged over the processed batches;
        ``(0.0, 0.0)`` if the loader yields no batch.
    """
    model.eval()
    tot_loss = tot_dice = 0.
    n_batches = 0
    for imgs, masks in loader:
        imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
        logits = model(imgs)
        # Functional BCE: same value as nn.BCEWithLogitsLoss()(...) without
        # constructing a new module for every batch.
        loss = loss_fn(logits, masks) + F.binary_cross_entropy_with_logits(logits, masks)
        tot_loss += loss.item()
        tot_dice += dice_coef(logits, masks)
        n_batches += 1
    # Guard against an empty loader (e.g. a zero-length validation split).
    denom = max(n_batches, 1)
    return tot_loss/denom, tot_dice/denom

In [None]:
# ==================================================
# 7. Training driver
# ==================================================
model = UNet().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
loss_fn   = DiceLoss()

# Start below any reachable Dice value so the first epoch always writes a
# checkpoint. Starting at 0 with a strict ">" would leave `best_path` unwritten
# if validation Dice stayed at exactly 0, and the later reload would fail.
best_dice, best_path = float("-inf"), "best_unet.pth"

for epoch in range(EPOCHS):
    print(f"\n===== Epoch {epoch+1}/{EPOCHS} =====")
    tr_loss, tr_dice = train_one_epoch(model, train_loader, optimizer, loss_fn)
    val_loss, val_dice = evaluate(model, val_loader, loss_fn)
    print(f"Train loss {tr_loss:.4f} dice {tr_dice:.4f} │ Val loss {val_loss:.4f} dice {val_dice:.4f}")

    # Keep only the weights with the best validation Dice seen so far.
    if val_dice > best_dice:
        best_dice = val_dice
        torch.save(model.state_dict(), best_path)
        print(f"  ▶ 모델 갱신! (dice={best_dice:.4f})")



총 1258개 이미지/마스크 쌍 로드 완료

===== Epoch 1/100 =====
Train loss 1.4132 dice 0.1065 │ Val loss 1.8663 dice 0.0855
  ▶ 모델 갱신! (dice=0.0855)

===== Epoch 2/100 =====
Train loss 1.2850 dice 0.0996 │ Val loss 1.2517 dice 0.0696

===== Epoch 3/100 =====
Train loss 1.2472 dice 0.1207 │ Val loss 1.2386 dice 0.0950
  ▶ 모델 갱신! (dice=0.0950)

===== Epoch 4/100 =====
Train loss 1.2119 dice 0.1649 │ Val loss 1.1932 dice 0.1813
  ▶ 모델 갱신! (dice=0.1813)

===== Epoch 5/100 =====
Train loss 1.1789 dice 0.2329 │ Val loss 1.1676 dice 0.2230
  ▶ 모델 갱신! (dice=0.2230)

===== Epoch 6/100 =====
Train loss 1.1648 dice 0.2293 │ Val loss 1.1335 dice 0.2441
  ▶ 모델 갱신! (dice=0.2441)

===== Epoch 7/100 =====
Train loss 1.1387 dice 0.2687 │ Val loss 1.1236 dice 0.1670

===== Epoch 8/100 =====
Train loss 1.1337 dice 0.2553 │ Val loss 1.2203 dice 0.2346

===== Epoch 9/100 =====
Train loss 1.1234 dice 0.2704 │ Val loss 1.1414 dice 0.2784
  ▶ 모델 갱신! (dice=0.2784)

===== Epoch 10/100 =====
Train loss 1.0773 dice 0.3104 │ Val

In [None]:
# ==================================================
# ★ 세그멘테이션 오버레이 시각화 (배치 크기 무관)
# ==================================================

@torch.no_grad()
def visualize_overlay(model, loader, device=DEVICE, n=6, alpha=0.45):
    """Plot images with ground-truth and predicted masks blended on top.

    * Ground truth is tinted green and the prediction red, both semi-transparent.
    * Batches are pulled from ``loader`` until ``n`` samples are collected; if
      the loader holds fewer, only the available samples are drawn.
    * Panels are laid out four per row.

    Args:
        model: Network mapping an image batch to logits.
        loader: Iterable yielding ``(images, masks)`` batches.
        device: Device on which the model is run.
        n: Maximum number of samples to draw.
        alpha: Blend weight of the mask colour (0 = invisible, 1 = opaque).
    """
    model.eval()

    if n <= 0:
        print("visualize_overlay: nothing to display (n <= 0)")
        return

    # ---------- collect up to n samples ----------
    imgs_acc, masks_acc, preds_acc = [], [], []
    n_seen = 0
    for imgs, masks in loader:
        imgs_acc.append(imgs)
        masks_acc.append(masks)
        preds_acc.append((torch.sigmoid(model(imgs.to(device))) > 0.5).cpu())
        n_seen += imgs.size(0)
        if n_seen >= n:
            break

    if not imgs_acc:
        print("visualize_overlay: loader yielded no samples")
        return

    imgs  = torch.cat(imgs_acc)[:n]   # (n,3,H,W)
    masks = torch.cat(masks_acc)[:n]  # (n,1,H,W)
    preds = torch.cat(preds_acc)[:n]  # (n,1,H,W), bool
    n_show = imgs.size(0)             # may be smaller than the requested n

    # ---------- overlay ----------
    gt_color   = np.array([0, 1, 0], dtype=np.float32)   # green
    pred_color = np.array([1, 0, 0], dtype=np.float32)   # red

    cols = 4
    rows = math.ceil(n_show / cols)
    # squeeze=False always returns a 2-D array, whatever the grid shape.
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4), squeeze=False)
    axes = axes.ravel()

    for i in range(n_show):
        img_np  = imgs[i].permute(1, 2, 0).numpy()
        gt_sel  = masks[i, 0].numpy() > 0.5       # robust to non-exact 1.0 values
        pr_sel  = preds[i, 0].numpy().astype(bool)

        overlay = img_np.copy()
        overlay[gt_sel] = (1 - alpha) * overlay[gt_sel] + alpha * gt_color
        overlay[pr_sel] = (1 - alpha) * overlay[pr_sel] + alpha * pred_color

        axes[i].imshow(overlay)
        axes[i].axis("off")
        axes[i].set_title(f"Sample {i}")

    # Hide the unused panels of the last row.
    for j in range(n_show, rows * cols):
        axes[j].axis("off")

    # Figure-level legend; does not depend on pyplot's implicit "current axes".
    green_patch = mpatches.Patch(color="green", label="Ground Truth")
    red_patch   = mpatches.Patch(color="red",   label="Prediction")
    fig.legend(handles=[green_patch, red_patch],
               bbox_to_anchor=(1.0, 1.0), loc="upper left")

    fig.tight_layout()
    plt.show()

In [None]:
# Load the weights saved at `best_path` during training, then draw overlays
# for 16 validation samples.
model.load_state_dict(torch.load(best_path, map_location=DEVICE))
visualize_overlay(model, val_loader, n=16, alpha=0.4)


Output hidden; open in https://colab.research.google.com to view.