In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# ==================================================
# 0. 의존성 (필요시: pip install -q --upgrade torch torchvision)
# ==================================================
import os, glob, json, random
from typing import List, Dict, Tuple, Optional

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import functional as F
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

In [3]:
# ==================================================
# 1. 전역 설정
# ==================================================
ROOT = "/content/drive/MyDrive/Data/MVTecAD/capsule"

TRAIN_IMG_DIR = os.path.join(ROOT, "test", "scratch")
TRAIN_IMG_DIR_2 = os.path.join(ROOT, "test", "crack")
VAL_IMG_DIR   = os.path.join(ROOT, "test", "poke")

TRAIN_JSON = glob.glob(os.path.join(ROOT, "ground_truth", "scratch", "*.json"))[0]
TRAIN_JSON_2 = glob.glob(os.path.join(ROOT, "ground_truth", "crack", "*.json"))[0]
VAL_JSON   = glob.glob(os.path.join(ROOT, "ground_truth", "poke", "*.json"))[0]

BATCH_SIZE   = 1
EPOCHS       = 10
LR           = 1e-4
SCORE_THRESH = 0.5

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)
# print(TRAIN_JSON, VAL_JSON)

device: cuda


In [4]:
# ==================================================
# 2. Dataset
# ==================================================
def coco_bbox_to_xyxy(bbox: List[float]) -> List[float]:
    x, y, w, h = bbox
    return [x, y, x + w, y + h]


class DefectDataset(Dataset):
    """
    img_dir : test/broken_*  폴더
    ann_json: ground_truth/broken_*/*.json (COCO 형식)
    """
    def __init__(self, img_dir: str, ann_json: Optional[str] = None, transforms=None):
        self.transforms = transforms

        # ---------- 이미지 경로 ----------
        self.img_paths = sorted(glob.glob(os.path.join(img_dir, "*.png")))
        if not self.img_paths:
            raise RuntimeError(f"{img_dir} 에 png 이미지가 없습니다.")

        # ---------- annotation ----------
        self.ann_map: Dict[str, List[dict]] = {}
        if ann_json is not None:
            with open(ann_json) as f:
                coco = json.load(f)

            def _norm(fname: str) -> str:
                """
                COCO  file_name -> test 이미지 이름과 동일하게 변환
                예) '013_mask.png'  →  '013.png'
                """
                base = os.path.basename(fname)
                if base.endswith("_mask.png"):
                    base = base.replace("_mask.png", ".png")
                return base

            id2fname = {img["id"]: _norm(img["file_name"]) for img in coco["images"]}

            for ann in coco["annotations"]:
                fname = id2fname[ann["image_id"]]
                self.ann_map.setdefault(fname, []).append(ann)

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        path  = self.img_paths[idx]
        fname = os.path.basename(path)

        img = Image.open(path).convert("RGB")
        img_tensor = F.pil_to_tensor(img).float() / 255.0

        anns  = self.ann_map.get(fname, [])
        boxes = [coco_bbox_to_xyxy(a["bbox"]) for a in anns]

        if boxes:
            boxes  = torch.tensor(boxes, dtype=torch.float32)
            labels = torch.ones((boxes.shape[0],), dtype=torch.int64)
            areas  = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        else:
            boxes  = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.ones((0,), dtype=torch.int64)
            areas  = torch.zeros((0,), dtype=torch.float32)

        target = {
            "boxes": boxes, "labels": labels,
            "image_id": torch.tensor([idx]),
            "area": areas,  "iscrowd": torch.zeros(labels.size(0), dtype=torch.int64),
        }
        if self.transforms:
            img_tensor, target = self.transforms(img_tensor, target)

        return img_tensor, target


In [5]:
# ==================================================
# 3. DataLoader
# ==================================================
from torch.utils.data import ConcatDataset

def collate_fn(batch):  # DataLoader용
    return tuple(zip(*batch))

concat_dataset = ConcatDataset([
    DefectDataset(TRAIN_IMG_DIR, TRAIN_JSON),
    DefectDataset(TRAIN_IMG_DIR_2, TRAIN_JSON_2),
])

train_loader = DataLoader(
    concat_dataset,
    batch_size=BATCH_SIZE, shuffle=True,
    collate_fn=collate_fn
)

val_loader = DataLoader(
    DefectDataset(VAL_IMG_DIR, VAL_JSON),  # JSON 없으면 None
    batch_size=1, shuffle=False, collate_fn=collate_fn
)


In [6]:
# ==================================================
# 4. 모델
# ==================================================
from torchvision.models.detection.rpn import AnchorGenerator, RPNHead

model = fasterrcnn_resnet50_fpn_v2(weights='DEFAULT')# Detection model

new_anchor_size = ((4,), (8,), (16,), (32,), (64,))
new_aspect_ratios = ((0.25, 0.5, 1.0, 2.0),) * len(new_anchor_size)

new_anchor_generator = AnchorGenerator(
    sizes=new_anchor_size,
    aspect_ratios=new_aspect_ratios
)

model.rpn.anchor_generator = new_anchor_generator

in_channel = model.backbone.out_channels
num_anchors = 4
model.rpn.head = RPNHead(in_channel, num_anchors)

in_feat = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_feat, num_classes=2)
model.to(device)

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(params, lr=LR, weight_decay=1e-4)

In [7]:
# ==================================================
# 5. 학습
# ==================================================
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0

    for imgs, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        imgs    = [img.to(device) for img in imgs]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss = sum(model(imgs, targets).values())
        optimizer.zero_grad();
        loss.backward();
        optimizer.step()
        total_loss += loss.item()

    print(f"  ↳ mean loss = {total_loss/len(train_loader):.4f}")

print("Fine-tuning 완료\n")

Epoch 1/10:   0%|          | 0/46 [00:00<?, ?it/s]

  ↳ mean loss = 0.6649


Epoch 2/10:   0%|          | 0/46 [00:00<?, ?it/s]

  ↳ mean loss = 0.2372


Epoch 3/10:   0%|          | 0/46 [00:00<?, ?it/s]

  ↳ mean loss = 0.0636


Epoch 4/10:   0%|          | 0/46 [00:00<?, ?it/s]

  ↳ mean loss = 0.0551


Epoch 5/10:   0%|          | 0/46 [00:00<?, ?it/s]

  ↳ mean loss = 0.0541


Epoch 6/10:   0%|          | 0/46 [00:00<?, ?it/s]

  ↳ mean loss = 0.0385


Epoch 7/10:   0%|          | 0/46 [00:00<?, ?it/s]

  ↳ mean loss = 0.0375


Epoch 8/10:   0%|          | 0/46 [00:00<?, ?it/s]

  ↳ mean loss = 0.0355


Epoch 9/10:   0%|          | 0/46 [00:00<?, ?it/s]

  ↳ mean loss = 0.0275


Epoch 10/10:   0%|          | 0/46 [00:00<?, ?it/s]

  ↳ mean loss = 0.0251
Fine-tuning 완료



In [8]:
# ==================================================
# 6. 검증 시각화 (전체 이미지)
# ==================================================
model.eval()

with torch.no_grad():
    for idx, (imgs, targets) in enumerate(val_loader):
        img  = imgs[0].to(device)
        pred = model([img])[0]

        # ── 시각화용 PIL 변환 ────────────────────────────
        vis  = Image.fromarray((img.cpu().permute(1,2,0).numpy()*255).astype("uint8"))
        draw = ImageDraw.Draw(vis, "RGBA")

        # GT (green)
        for b in targets[0]["boxes"]:
            draw.rectangle(b.tolist(), outline=(0,255,0,255), width=3)

        # Pred (red, score ≥ threshold)
        for b, s in zip(pred["boxes"], pred["scores"]):
            if s < SCORE_THRESH: continue
            draw.rectangle(b.cpu().tolist(), outline=(255,0,0,180), width=3)

        # ── 개별 그림 출력 ───────────────────────────────
        plt.figure(figsize=(6, 6))
        plt.imshow(vis); plt.axis("off")
        plt.title(f"Val sample {idx+1}  —  GT(G) / Pred (R)")
        plt.show()


Output hidden; open in https://colab.research.google.com to view.