In [None]:
# Colab-only step: mount Google Drive at /content/drive so that the dataset
# path configured below (ROOT_DIR, which lives under /content/drive) resolves.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# ==================================================
# 0. 의존성
# ==================================================
import os, json, random
from glob import glob
from collections import defaultdict
from typing import List, Tuple, Dict

from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as F
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import torch.optim as optim
from tqdm.auto import tqdm

# ==================================================
# 1. 전역 설정
# ==================================================
# Dataset root (inside the mounted Google Drive).
ROOT_DIR = "/content/drive/MyDrive/Data/MVTecAD"

# DataLoader settings.
BATCH_SIZE = 2
NUM_WORKERS = 4

# Optimisation settings (consumed by the training section).
LR = 1e-4
EPOCHS = 50

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", device)



In [None]:
# ==================================================
# 2. COCO JSON 파싱 → {file_name: [[x1,y1,x2,y2], ...]}
# ==================================================
def parse_coco_json(json_path: str) -> Dict[str, List[List[float]]]:
    """Read a COCO-style annotation file and group its boxes by image file name.

    Args:
        json_path: Path to a JSON file containing top-level ``"images"`` and
            ``"annotations"`` lists. Each image entry needs ``"id"`` and
            ``"file_name"``; each annotation needs ``"image_id"`` and a
            ``"bbox"`` given as ``[x, y, width, height]``.

    Returns:
        A plain ``dict`` mapping ``file_name`` to a list of boxes converted to
        corner format ``[x1, y1, x2, y2]``. Images that have no annotation do
        not appear in the result.

    Raises:
        KeyError: If a required key is missing or an annotation refers to an
            ``image_id`` that is not listed under ``"images"``.
    """
    # JSON is read as UTF-8 explicitly instead of relying on the platform
    # default encoding.
    with open(json_path, encoding="utf-8") as f:
        coco = json.load(f)

    id_to_name = {img["id"]: img["file_name"] for img in coco["images"]}

    # A regular dict (not a defaultdict) is returned so that a lookup of an
    # unknown file name by a caller raises KeyError instead of silently
    # inserting an empty list.
    fname2boxes: Dict[str, List[List[float]]] = {}
    for ann in coco["annotations"]:
        x, y, w, h = ann["bbox"]
        box = [x, y, x + w, y + h]          # xywh -> xyxy
        fname2boxes.setdefault(id_to_name[ann["image_id"]], []).append(box)

    return fname2boxes

# ==================================================
# 3. Dataset
# ==================================================
class MVTecODDataset(Dataset):
    """Object-detection dataset built from an MVTec-AD-style directory tree.

    Expected layout under ``root_dir``::

        <category>/test/<defect>/<image files>
        <category>/ground_truth/<defect>/<exactly one COCO JSON file>

    Every image named in a defect's COCO file becomes one sample, and all of
    its boxes receive the integer label assigned to that defect folder name.
    Label ids start at 1 (0 is reserved for background) and are handed out in
    sorted traversal order; a defect name that occurs in several categories
    maps to a single shared id.

    Args:
        root_dir: Directory that contains the category folders.
        transforms: Optional callable applied to the PIL image only (the
            boxes are not passed through it).
        use_good: When False, the ``good`` defect folder is skipped.
    """

    def __init__(self, root_dir: str, transforms=None, use_good: bool=False):
        self.root_dir   = root_dir
        self.transforms = transforms
        # One entry per image: (image path, [[x1, y1, x2, y2], ...], label id).
        self.samples: List[Tuple[str, List[List[float]], int]] = []
        # Defect folder name -> label id.
        self.label_map: Dict[str, int] = {}
        lbl_idx = 1   # 0 = background

        for category in sorted(os.listdir(root_dir)):
            cat_dir = os.path.join(root_dir, category)
            if not os.path.isdir(cat_dir):
                continue

            test_dir = os.path.join(cat_dir, "test")
            gt_dir   = os.path.join(cat_dir, "ground_truth")
            if not (os.path.isdir(test_dir) and os.path.isdir(gt_dir)):
                continue

            for defect in sorted(os.listdir(test_dir)):
                if defect == "good" and not use_good:
                    continue

                img_dir  = os.path.join(test_dir, defect)
                ann_json = glob(os.path.join(gt_dir, defect, "*.json"))
                # Exactly one annotation file per defect folder is required;
                # anything else is reported and skipped.
                if len(ann_json) != 1:
                    print(f"[경고] {category}/{defect} JSON 수={len(ann_json)} → 스킵")
                    continue

                # Assign a label id the first time a defect name is seen.
                if defect not in self.label_map:
                    self.label_map[defect] = lbl_idx
                    lbl_idx += 1
                label_id = self.label_map[defect]

                # COCO JSON -> {file name: boxes}
                fname2boxes = parse_coco_json(ann_json[0])

                for fname, boxes in fname2boxes.items():
                    img_path = self._resolve_image_path(img_dir, fname)
                    if img_path is None:
                        # No matching file on disk: the entry is dropped.
                        continue
                    self.samples.append((img_path, boxes, label_id))

        print(f"총 {len(self.samples)}개 이미지,  라벨 맵: {self.label_map}")

    @staticmethod
    def _resolve_image_path(img_dir: str, fname: str):
        """Return the first existing image file for an annotation file name, or None.

        Candidates are tried in this order: the name with ``"_mask"`` removed,
        that name with a ``.png`` extension, that name with a ``.jpg``
        extension, and finally the annotation file name itself (the mask).
        """
        base = fname.replace("_mask", "")
        # splitext handles extensions of any length (".jpeg", ".tiff", none);
        # slicing off a fixed four characters only works for three-letter ones.
        stem = os.path.splitext(base)[0]
        candidates = [
            os.path.join(img_dir, base),
            os.path.join(img_dir, stem + ".png"),
            os.path.join(img_dir, stem + ".jpg"),
            os.path.join(img_dir, fname),     # the mask file itself
        ]
        return next((p for p in candidates if os.path.isfile(p)), None)

    def __len__(self):
        """Number of (image, boxes) samples collected at construction time."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            A tuple ``(img, target, img_path)`` where ``img`` is the RGB image
            (after ``transforms`` if one was given), ``target`` is a dict with
            ``"boxes"`` (float32, shape ``(N, 4)``) and ``"labels"`` (int64,
            shape ``(N,)``), and ``img_path`` is the source file path.
        """
        img_path, boxes, label_id = self.samples[idx]

        img    = Image.open(img_path).convert("RGB")
        # reshape keeps the (N, 4) layout even when the box list is empty,
        # where as_tensor alone would produce a shape-(0,) tensor.
        boxes  = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.full((len(boxes),), label_id, dtype=torch.int64)

        target = {"boxes": boxes, "labels": labels}

        if self.transforms:
            img = self.transforms(img)

        return img, target, img_path   # the path is returned alongside the sample


In [None]:
# ==================================================
# 4. Dataset 인스턴스 생성
# ==================================================
# Build the full dataset. F.to_tensor is passed directly: it does exactly what
# the former `lambda x: F.to_tensor(x)` wrapper did, but a plain function
# reference can be pickled, which DataLoader workers need on platforms that
# start them with "spawn".
dataset = MVTecODDataset(
    ROOT_DIR,
    transforms=F.to_tensor,
    use_good=False
)

# ==================================================
# 5. train / val split - random_split
# ==================================================
train_ratio = 0.8
n_total     = len(dataset)
n_train     = int(train_ratio * n_total)
n_val       = n_total - n_train   # the remainder, so the two parts always sum to n_total

# A dedicated generator keeps the split independent of the global RNG state.
# NOTE(review): no seed is set explicitly here, so repeatability relies on the
# generator's default initial seed - call .manual_seed(<int>) to make it explicit.
train_set, val_set = torch.utils.data.random_split(
    dataset,
    [n_train, n_val],
    generator=torch.Generator()
)

print(f"train: {len(train_set)} / val: {len(val_set)}")

# ==================================================
# 6. DataLoader 설정
# ==================================================
def collate_fn(batch):
    """Regroup a batch of (image, target, path) samples into three parallel lists.

    The items are kept as Python lists rather than stacked, so the images in
    one batch do not need to share a size.
    """
    images, annotations, file_paths = map(list, zip(*batch))
    return images, annotations, file_paths

# Settings shared by both loaders. Pinned host memory is requested only when a
# CUDA device is available; asking for it on a CPU-only runtime has no benefit
# and makes PyTorch emit a warning.
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    collate_fn=collate_fn,
    pin_memory=torch.cuda.is_available(),
)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(train_set, shuffle=True, **_loader_kwargs)

# Validation keeps a fixed order.
val_loader = DataLoader(val_set, shuffle=False, **_loader_kwargs)


In [None]:
# --------------------------------------------------
# 7. 간단 테스트 루프
# --------------------------------------------------
# Smoke test: fetch the first batch from each loader and print the first
# sample's image path and its first bounding box.
for phase, loader in {"train": train_loader, "val": val_loader}.items():
    imgs, targets, paths = next(iter(loader))
    print(f"[{phase}]  첫 경로:", paths[0])
    print(f"[{phase}]  첫 bbox:", targets[0]["boxes"][0])

In [None]:
# ==================================================
# 8. Object Detection 모델 구현 및 학습 스크립트
# ==================================================
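# TODO: build the detection model and the training loop here. This cell must
#       define `model`, which the visualization call in section 10 relies on.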

In [None]:
# ==================================================
# 9. 시각화: 예측 ↔ GT 비교
# ==================================================
import matplotlib.pyplot as plt
from torchvision.utils import draw_bounding_boxes
import torchvision.transforms.functional as TF

# Inverse of dataset.label_map (label id -> defect name); used to build the
# text labels handed to draw_bounding_boxes.
id2name = {label_id: defect_name
           for defect_name, label_id in dataset.label_map.items()}

def vis_predictions(model, loader, device,
                    score_thresh: float = 0.5,
                    num_images: int = 4):
    """Plot ground-truth boxes next to predicted boxes for images from a loader.

    Up to ``num_images`` images are taken from ``loader`` in iteration order
    and drawn in a 2-row grid: the top row shows the ground-truth boxes in
    green, the bottom row shows the predictions whose score is at least
    ``score_thresh`` in red, captioned ``"<name>:<score>"``.

    Args:
        model: Detection model; it is called with a list of image tensors and
            must return one dict per image with "boxes", "labels" and "scores".
            It is switched to eval mode and left that way.
        loader: Iterable yielding ``(imgs, targets, paths)`` batches.
        device: Device the images are moved to before inference.
        score_thresh: Minimum score for a prediction to be drawn.
        num_images: Maximum number of images to show; nothing is drawn when
            this is not positive.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    model.eval()
    if num_images <= 0:
        return

    def _label_name(label) -> str:
        # Fall back to the raw id when the model emits a class that the
        # dataset's label map does not know about.
        label_id = int(label)
        return id2name.get(label_id, str(label_id))

    def _draw(canvas, boxes, labels, color):
        # Nothing to draw: return the untouched image rather than calling
        # draw_bounding_boxes with an empty box tensor.
        if boxes.numel() == 0:
            return canvas
        return draw_bounding_boxes(
            canvas,
            boxes,
            labels=labels,
            colors=color,
            width=16,          # thicker than the default line width
            font_size=18       # larger than the default caption size
        )

    fig, axes = plt.subplots(2, num_images,
                             figsize=(num_images * 4, 4),
                             squeeze=False)
    shown = 0

    with torch.no_grad():
        for imgs, targets, paths in loader:
            imgs = [img.to(device) for img in imgs]
            outputs = model(imgs)

            for img, tgt, pred in zip(imgs, targets, outputs):
                # Image scaled to uint8 for drawing.
                canvas = (img.cpu() * 255).byte()

                # ----------------- ground truth -----------------
                gt_labels = [_label_name(l) for l in tgt["labels"]]
                gt_img = _draw(canvas, tgt["boxes"].cpu(), gt_labels, "green")

                # ----------------- predictions ------------------
                keep = pred["scores"] >= score_thresh
                pred_scores = pred["scores"][keep].cpu().tolist()
                # Append the score to each class name.
                pred_labels = [f"{_label_name(l)}:{s:.2f}"
                               for l, s in zip(pred["labels"][keep], pred_scores)]
                pred_img = _draw(canvas, pred["boxes"][keep].cpu(),
                                 pred_labels, "red")

                # ----------------- plot -------------------------
                axes[0, shown].imshow(TF.to_pil_image(gt_img))
                axes[1, shown].imshow(TF.to_pil_image(pred_img))
                if shown == 0:
                    axes[0, 0].set_title("Ground-Truth", fontsize=12)
                    axes[1, 0].set_title("Prediction", fontsize=12)

                shown += 1
                if shown >= num_images:
                    break
            if shown >= num_images:
                break

    if shown == 0:
        # The loader produced no images; discard the empty figure.
        plt.close(fig)
        return

    # Hide the axis frames everywhere, including slots left unused when the
    # loader ran out before num_images images were drawn, then render once.
    for ax in axes.ravel():
        ax.axis("off")
    fig.tight_layout()
    plt.show()


In [None]:
# ==================================================
# 10. 시각화
# ==================================================
# Show four validation images with ground-truth and predicted boxes.
# NOTE(review): `model` is not defined in any cell visible here (section 8 is
# still a TODO), so this call raises NameError until that section is filled in.
vis_predictions(
    model,
    val_loader,
    device,
    score_thresh=0.5,
    num_images=4,
)
