In [None]:
# Colab only: expose Google Drive under /content/drive so the dataset path
# configured below (/content/drive/MyDrive/...) is readable from this runtime.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

In [None]:
# =============================================================
# 0. 환경 준비
# =============================================================
import os, random, numpy as np, matplotlib.pyplot as plt
from glob import glob
from pathlib import Path
from tqdm import tqdm
from PIL import Image

import torch, torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from sklearn.metrics import roc_auc_score, average_precision_score

IMG_SIZE    = 128
BATCH_SIZE  = 32
EPOCHS      = 5
LR          = 1e-3
ROOT_DIR    = '/content/drive/MyDrive/Data/MVTecAD'
SAVE_PATH   = 'best_ae_mvtecad.pt'

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print('✅ device:', DEVICE)


In [None]:
# =============================================================
# 1. 데이터셋 & DataLoader
#         – train  : 모든 */train/good
#         – test   : */test/good + */test/<defect>/*
# =============================================================
class MVTecADMulti(Dataset):
    """MVTec AD images pooled across every category folder under ``root_dir``.

    Layout globbed below::

        root_dir/<category>/train/good/*.png           (phase='train', label 0)
        root_dir/<category>/test/good/*.png            (phase='test',  label 0)
        root_dir/<category>/test/<defect>/*.png        (phase='test',  label 1)

    Args:
        root_dir: dataset root containing one sub-directory per category.
        phase: 'train' (only train/good) or 'test' (test/good + every defect folder).

    Items are ``(image, label, path)``: a 3-channel IMG_SIZE x IMG_SIZE float
    tensor normalised to [-1, 1], the int label (0 = normal, 1 = defect), and
    the image path as a str.

    Raises:
        FileNotFoundError: if no ``*.png`` file is found for the requested phase.
    """
    def __init__(self, root_dir:str, phase:str='train'):
        assert phase in ('train', 'test')
        self.phase = phase
        root = Path(root_dir)
        self.paths = []                      # [(img_path: str, label: int)]
        # Every listing is sorted: os.listdir/glob order is filesystem-dependent,
        # and a stable order keeps the test set (and any "first N" pick) reproducible.
        for cat in sorted(os.listdir(root)):
            cat_path = root/cat
            if not cat_path.is_dir():        # skip non-directory entries at the root
                continue
            if phase == 'train':
                self.paths += [(p, 0) for p in sorted(glob(str(cat_path/'train'/'good'/'*.png')))]
            else:
                test_dir = cat_path/'test'
                self.paths += [(p, 0) for p in sorted(glob(str(test_dir/'good'/'*.png')))]
                for defect in sorted(os.listdir(test_dir)):
                    if defect == 'good':
                        continue
                    self.paths += [(p, 1) for p in sorted(glob(str(test_dir/defect/'*.png')))]
        if not self.paths:
            # Fail early with a readable message instead of an opaque DataLoader/percentile error.
            raise FileNotFoundError(f"No '*.png' images found for phase={phase!r} under {root_dir!r}")
        # Resize -> [0, 1] tensor -> (x - 0.5) / 0.5 per channel, i.e. [-1, 1].
        self.tf = T.Compose([
            T.Resize((IMG_SIZE, IMG_SIZE)),
            T.ToTensor(),
            T.Normalize([0.5]*3, [0.5]*3)
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        p, lbl = self.paths[idx]
        # Force 3 channels so single-channel PNGs match the 3-channel Normalize;
        # the context manager closes the file once convert() has decoded the pixels.
        with Image.open(p) as im:
            img = im.convert('RGB')
        img = self.tf(img)
        return img, lbl, p

def get_loaders(root, batch_size=None, num_workers=2):
    """Build the train/test DataLoaders for the pooled MVTec AD dataset.

    Args:
        root: dataset root passed to ``MVTecADMulti``.
        batch_size: samples per batch; ``None`` falls back to the global BATCH_SIZE.
        num_workers: worker processes per loader.

    Returns:
        ``(train_loader, test_loader)``. Training batches are shuffled and the
        last incomplete batch is dropped. Test batches keep dataset order, so
        per-sample results line up with ``test_loader.dataset`` indices.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    train_ds = MVTecADMulti(root, 'train')
    test_ds  = MVTecADMulti(root, 'test')
    # Pinned host memory only speeds up host->GPU copies; on a CPU-only runtime
    # it brings no benefit (recent PyTorch warns about it).
    pin = torch.cuda.is_available()
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin, drop_last=True)
    test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False,
                              num_workers=num_workers, pin_memory=pin)
    return train_loader, test_loader

train_loader, test_loader = get_loaders(ROOT_DIR)
# Report dataset sizes (number of images, not number of batches).
print("Train:{}  Test:{}".format(len(train_loader.dataset), len(test_loader.dataset)))


In [None]:
# =============================================================
# 2.  네트워크 (Convolutional Autoencoder)
# =============================================================
# TODO


In [None]:
# =============================================================
# 3.  학습 루프
# =============================================================
# TODO

In [None]:
# =============================================================
# 4.  Test & visualisation
#         – (0) threshold-free metrics (AUROC / AP)
#         – (1) reconstruction-error histogram
#         – (2) normal / anomaly examples, 2 rows x 5 images
# =============================================================
# `model` must be defined by the network/training cells above (currently TODO).
# weights_only=True restricts unpickling to tensors/containers, which is all a
# state_dict needs.
model.load_state_dict(torch.load(SAVE_PATH, map_location=DEVICE, weights_only=True))
model.eval()

# (1) Per-image error = mean squared error over every element of the image.
# Assumes model(imgs) returns a tensor shaped like imgs — TODO confirm once the
# network cell is written.
errs, labels, paths = [], [], []
with torch.no_grad():
    for imgs, lbls, ps in test_loader:
        imgs = imgs.to(DEVICE)
        recons = model(imgs)
        err = F.mse_loss(recons, imgs, reduction='none')
        err = err.flatten(1).mean(1).cpu().numpy()
        errs.append(err); labels.append(lbls.numpy()); paths += ps
errs = np.concatenate(errs); labels = np.concatenate(labels)

# (0) Threshold-free detection metrics: label 1 = anomaly, score = error.
if np.unique(labels).size == 2:
    auroc = roc_auc_score(labels, errs)
    ap    = average_precision_score(labels, errs)
    print(f'AUROC={auroc:.4f}  AP={ap:.4f}')
else:
    print('AUROC/AP skipped: test set contains only one class')

# Threshold = 95th percentile of the errors of normal test images.
# NOTE(review): this calibrates on the test set itself (optimistic); a held-out
# set of normal images would give an unbiased threshold.
THRESH = np.percentile(errs[labels==0], 95)

# (1-a) Histogram. Bin edges span ALL errors so anomalies lying outside the
# range of normal errors are still counted and drawn.
edges     = np.histogram_bin_edges(errs, bins=60)
hist_n, _ = np.histogram(errs[labels==0], bins=edges)
hist_a, _ = np.histogram(errs[labels==1], bins=edges)
centers   = (edges[:-1] + edges[1:]) / 2
width     = edges[1] - edges[0]

plt.figure(figsize=(7,4))
plt.bar(centers, hist_n, width=width, label='Normal', alpha=.6)
plt.bar(centers, hist_a, width=width, label='Anomaly', alpha=.6)
plt.axvline(THRESH, ls='--', lw=2, label=f'Thresh={THRESH:.4f}')
plt.xlabel('Reconstruction MSE'); plt.ylabel('Count')
plt.title('Error distribution (test)'); plt.legend(); plt.tight_layout(); plt.show()

# (2) Example images: row 0 = normals below THRESH, row 1 = anomalies above
# THRESH (the first 5 of each in test-set order).
def denorm(x): return x.mul(0.5).add(0.5).clamp(0,1)   # undo Normalize([0.5]*3, [0.5]*3)
idx_norm = np.where((labels==0) & (errs<THRESH))[0][:5]
idx_anom = np.where((labels==1) & (errs>THRESH))[0][:5]

fig, axes = plt.subplots(2, 5, figsize=(15,6))
for row, (name, idxs) in enumerate([('Normal', idx_norm), ('Anomaly', idx_anom)]):
    for col in range(5):
        ax = axes[row, col]
        ax.axis('off')
        if col >= len(idxs):            # fewer than 5 matches: leave the slot blank
            continue
        k = idxs[col]
        # test_loader is unshuffled, so errs[k] belongs to dataset item k. Going
        # through the dataset reuses its RGB conversion, so grayscale PNGs render too
        # (a 2-D imread array cannot be sliced with [:, :, :3]).
        img, _, _ = test_loader.dataset[k]
        ax.imshow(denorm(img).permute(1, 2, 0).numpy())
        ax.set_title(f"{name}\nMSE={errs[k]:.4f}")
plt.suptitle('Reconstruction-based Anomaly Detection'); plt.tight_layout(); plt.show()
