In [None]:
# Colab only: mount Google Drive so the dataset under /content/drive/MyDrive is readable.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
# Inspect how many logical CPUs this runtime exposes (useful when sizing DataLoader workers).
import os

cpu_total = os.cpu_count()
cpu_total

12

In [None]:
# Scratch cell: DataLoader tuning experiment.
import os

import torch
from torch.utils.data import DataLoader

# Colab usually exposes only a few vCPUs; spawning one worker per visible CPU
# mostly adds process overhead, so cap the worker count.
MAX_WORKERS = 4
num_workers = min(MAX_WORKERS, os.cpu_count() or 1)

# prefetch_factor / persistent_workers are only accepted when num_workers > 0.
worker_kwargs = (
    dict(
        prefetch_factor=4,        # each worker keeps 4 batches loaded ahead
        persistent_workers=True,  # reuse worker processes across epochs
    )
    if num_workers > 0
    else {}
)

loader = DataLoader(
    dataset,  # NOTE(review): `dataset` is not defined anywhere in this notebook — define it before running this cell
    batch_size=32,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=torch.cuda.is_available(),
    **worker_kwargs,
)

In [None]:
# ==================================================
# 0. 의존성 & 전역 설정
# ==================================================
import os, random, time
from pathlib import Path
from typing import List, Tuple

import torch, torchvision
from torch import nn
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.transforms as T
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from tqdm import tqdm
import matplotlib.pyplot as plt

# ------------------------------
# User hyperparameters
# ------------------------------
ROOT_DIR   = "/content/drive/MyDrive/Data/MVTecAD"  # dataset root (on the mounted Drive)
IMG_SIZE   = 224
BATCH_SIZE = 32
VAL_RATIO  = 0.2
EPOCHS     = 10
LR         = 1e-4
BEST_PATH  = "best_model.pt"
SEED       = 42

# --------------------------------------------------
# Seed fixing: makes the train/val shuffle (python `random`) and torch-driven
# randomness (augmentations, DataLoader shuffling, weight init) repeatable.
# --------------------------------------------------
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # no-op when CUDA is unavailable

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("✅ 사용 디바이스:", DEVICE)


In [None]:
# ==================================================
# 1. Dataset (mask 제외 + 3채널 변환)
# ==================================================
class MVTecClsDataset(Dataset):
    """Classification Dataset that scans the whole MVTecAD tree (train + test) at once.

    Every image becomes one sample labelled ``f"{category}_{defect}"``:
    • images under ``<category>/train/`` are labelled ``<category>_good``
    • images under ``<category>/test/<defect>/`` use the defect folder name
      (e.g. ``broken_large``); ``test/good`` therefore also maps to ``<category>_good``
    • phase folders other than ``train``/``test`` (e.g. ``ground_truth``), files whose
      stem contains ``mask``, hidden folders and test images not inside a defect
      folder are skipped
    • samples are collected in sorted path order and class indices are assigned in
      sorted label order, so indices do not depend on filesystem traversal order

    Each item is ``(image, label_idx, path_str)`` where ``image`` is a float tensor
    scaled to [0, 1] with exactly 3 (RGB) channels, optionally passed through
    ``transforms``.
    """

    VALID_EXT = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}
    HIDDEN    = {".ipynb_checkpoints", ".DS_Store"}
    PHASES    = {"train", "test"}

    def __init__(self, root_dir: str, transforms=None):
        self.root = Path(root_dir)
        self.t    = transforms
        self.samples: List[Tuple[Path, int]] = []
        self.classes: List[str] = []
        self._scan()

    def _scan(self):
        """Populate ``self.samples`` and ``self.classes`` from the directory tree."""
        found: List[Tuple[Path, str]] = []
        for p in sorted(self.root.rglob("*")):
            # --- valid-file filter ---
            if not (p.is_file() and p.suffix.lower() in self.VALID_EXT):
                continue
            parts = p.relative_to(self.root).parts  # (category, phase, [defect, ...,] file)
            if any(h in parts for h in self.HIDDEN):
                continue
            if "mask" in p.stem.lower():
                continue
            if len(parts) < 3:
                continue
            category, phase = parts[0], parts[1]
            if phase not in self.PHASES:
                continue
            if phase == "train":
                defect = "good"
            elif len(parts) >= 4:
                defect = parts[2]
            else:
                # test/<file> with no defect folder: parts[2] would be the file name
                continue
            found.append((p, f"{category}_{defect}"))

        self.classes = sorted({lbl for _, lbl in found})
        lbl2idx = {lbl: i for i, lbl in enumerate(self.classes)}
        self.samples = [(p, lbl2idx[lbl]) for p, lbl in found]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        p, label = self.samples[idx]
        # ImageReadMode.RGB converts grayscale, gray+alpha and RGBA inputs to 3 channels.
        # NOTE(review): read_image decodes JPEG/PNG (newer torchvision adds a few more
        # formats); .bmp/.tif files listed in VALID_EXT may fail to decode — confirm.
        img = torchvision.io.read_image(str(p), mode=torchvision.io.ImageReadMode.RGB).float() / 255.0
        if self.t is not None:
            img = self.t(img)
        return img, label, str(p)




In [None]:
# ==================================================
# 2. DataLoader
# ==================================================
# Transforms are applied per sample at collate time, so one base dataset can feed
# both the augmented train loader and the plain validation loader.
# NOTE(review): no Normalize here, so section 5 can plot the tensors directly;
# add ImageNet normalization inside the model if pretrained weights are used.
train_tfms = T.Compose([T.Resize((IMG_SIZE, IMG_SIZE)), T.RandomHorizontalFlip(), T.RandomRotation(10)])
val_tfms = T.Resize((IMG_SIZE, IMG_SIZE))
SPLIT_SEED = 42  # fixed so the train/val split is identical after every kernel restart

base_ds = MVTecClsDataset(ROOT_DIR)


def stratified_split(labels, val_ratio, seed):
    """Split sample indices class by class so each class keeps ~`val_ratio` in validation.

    Args:
        labels: class index of every sample, in dataset order.
        val_ratio: fraction of each class moved to validation.
        seed: seed for the local RNG (global `random` state is not touched).

    Returns:
        (train_idx, val_idx) lists of dataset indices. When val_ratio > 0, every
        class with at least 2 samples contributes at least one validation sample,
        so small classes cannot silently vanish from validation.
    """
    rng = random.Random(seed)
    by_class = {}
    for i, y in enumerate(labels):
        by_class.setdefault(y, []).append(i)
    train_idx, val_idx = [], []
    for y in sorted(by_class):
        idx = by_class[y]
        rng.shuffle(idx)
        k = int(len(idx) * val_ratio)
        if val_ratio > 0 and len(idx) > 1:
            k = max(1, k)
        val_idx += idx[:k]
        train_idx += idx[k:]
    return train_idx, val_idx


def make_loader(ds, tfm, shuf):
    """Build a DataLoader over `ds` that applies `tfm` to every image in a batch.

    Each batch is (stacked image tensor, int64 label tensor, list of path strings).
    """
    def collate(batch):
        imgs, labels, paths = zip(*batch)
        return torch.stack([tfm(x) for x in imgs]), torch.tensor(labels), list(paths)

    use_cuda = torch.cuda.is_available()
    # NOTE: the nested `collate` closure works with fork-based workers (Linux/Colab);
    # spawn-based platforms cannot pickle it for num_workers > 0.
    return DataLoader(ds, batch_size=BATCH_SIZE, shuffle=shuf,
                      num_workers=2 if use_cuda else 0,
                      pin_memory=use_cuda,
                      collate_fn=collate)


train_idx, val_idx = stratified_split([y for _, y in base_ds.samples], VAL_RATIO, SPLIT_SEED)
val_len = len(val_idx)
train_set, val_set = Subset(base_ds, train_idx), Subset(base_ds, val_idx)

train_loader = make_loader(train_set, train_tfms, True)
val_loader = make_loader(val_set, val_tfms, False)
idx2lbl = base_ds.classes
print("클래스:", idx2lbl)



In [None]:
# ==================================================
# 3. 모델 정의
# ==================================================
# TODO


In [None]:
# ==================================================
# 4. 학습 루프
# ==================================================
# TODO

In [None]:
# ==================================================
# 5. Visualization (blue = correct, red = wrong)
# ==================================================
# Ground truth (GT) vs prediction (PR) on the first validation batch; title colour marks correctness.
# NOTE(review): relies on `model` from sections 3–4, which are still TODO.
N_SHOW, N_COLS = 8, 4

model.eval()
val_imgs, val_labels, _ = next(iter(val_loader))
with torch.inference_mode():
    preds = model(val_imgs.to(DEVICE)).argmax(1).cpu()

# The first batch can contain fewer than N_SHOW images when the validation set is tiny.
n_show = min(N_SHOW, len(val_imgs))
n_rows = (N_SHOW + N_COLS - 1) // N_COLS
fig, axes = plt.subplots(n_rows, N_COLS, figsize=(4 * N_COLS, 4 * n_rows), squeeze=False)
for ax in axes.flat:
    ax.axis('off')
for i, ax in zip(range(n_show), axes.flat):
    color = 'blue' if preds[i] == val_labels[i] else 'red'
    # Defensive clamp so imshow never receives float values outside [0, 1].
    ax.imshow(val_imgs[i].permute(1, 2, 0).clamp(0, 1).numpy())
    ax.set_title(f"GT:{idx2lbl[int(val_labels[i])]}\nPR:{idx2lbl[int(preds[i])]}", color=color)
plt.tight_layout(); plt.show()