In [None]:
# Mount Google Drive (Colab only) so the dataset path under /content/drive is readable.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
ROOT_DIR  = '/content/drive/MyDrive/Data/MVTecAD'  # dataset root; each sub-directory is one category
SAVE_NAME = 'best_skipgan.pt'                      # checkpoint file name (plain string: nothing to interpolate)

# =============================================================
# 0.  의존성 & 전역 설정
# =============================================================
import os, random, time
from glob   import glob
from pathlib import Path
import numpy as np, matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from sklearn.metrics import roc_auc_score
from torch.nn.utils import spectral_norm
import torch.autograd as autograd

IMG_SIZE   = 256
BATCH_SIZE = 16
EPOCHS     = 5
LR_G = LR_D = 2e-4
BETAS = (0.5, 0.999)
lambda_GP = 0.0  # gradient-penalty weight; 0 because (per the original note) spectral norm is used instead — TODO confirm once the networks exist
SEED = 42

# Seed every RNG the notebook can touch (Python, NumPy, torch incl. CUDA) so that
# Restart & Run All reproduces loader shuffling and any random weight initialisation.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('✅ device:', DEVICE)


In [None]:
# =============================================================
# 1.  데이터셋 & DataLoader
# =============================================================
class MVTecADMulti(Dataset):
    """Image dataset spanning every MVTec AD category under one root directory.

    Scans all categories (train/good, test/good + defects) in one pass:
      * phase='train' -> every ``*/train/good/*.png``      (label 0)
      * phase='test'  -> every ``*/test/good/*.png``       (label 0)
                         every ``*/test/<defect>/*.png``   (label 1)

    Items are ``(image, label)``. ``image`` is a 3 x IMG_SIZE x IMG_SIZE float
    tensor normalized with mean = std = 0.5, i.e. to [-1, 1].

    Args:
        root_dir: directory whose sub-directories are the categories.
        phase: 'train' or 'test'.

    Attributes:
        paths: list of ``(img_path, label)``. Directory listings are sorted, so
            the order is deterministic and dataset index ``i`` is ``paths[i]``.
    """
    def __init__(self, root_dir: str, phase: str = 'train'):
        assert phase in ('train', 'test')
        self.phase = phase
        self.paths = []          # (img_path, label)

        root = Path(root_dir)
        for cat in sorted(os.listdir(root)):
            cat_path = root / cat
            if not cat_path.is_dir():          # skip plain files next to the categories
                continue

            if phase == 'train':
                self.paths += self._scan(cat_path / 'train' / 'good', 0)
            else:   # test
                test_dir = cat_path / 'test'
                if not test_dir.is_dir():      # not a category folder; os.listdir would raise
                    continue
                self.paths += self._scan(test_dir / 'good', 0)
                for defect in sorted(os.listdir(test_dir)):
                    if defect == 'good':
                        continue
                    self.paths += self._scan(test_dir / defect, 1)

        self.tf = T.Compose([
            T.ToPILImage(),
            T.Resize((IMG_SIZE, IMG_SIZE)),
            T.ToTensor(),
            T.Normalize([0.5]*3, [0.5]*3)
        ])

    @staticmethod
    def _scan(folder: Path, label: int):
        """Return ``(path, label)`` for every PNG directly inside ``folder``, sorted by path."""
        return [(p, label) for p in sorted(glob(str(folder / '*.png')))]

    def __len__(self): return len(self.paths)

    def __getitem__(self, idx):
        img_path, label = self.paths[idx]
        img = plt.imread(img_path)
        if img.ndim == 2:                     # grayscale -> 3-channel
            img = np.stack([img]*3, -1)
        img = img[..., :3]                    # drop an alpha channel if the PNG has one
        if img.dtype != np.uint8:
            # plt.imread returns PNGs as floats in [0, 1]. Round (not truncate):
            # float32(k/255)*255 can come out as k - 1e-5, which astype would floor to k-1.
            img = np.clip(np.rint(img * 255), 0, 255).astype(np.uint8)
        return self.tf(img), label


def get_loaders_all(root_dir):
    """Build DataLoaders over every MVTec AD category under ``root_dir``.

    Returns:
        (train_loader, test_loader)
          * train_loader: shuffled; the last incomplete batch is dropped.
          * test_loader: NOT shuffled, so batches follow ``test_ds.paths`` order
            and the k-th score of a full pass belongs to
            ``test_loader.dataset.paths[k]``.
    """
    train_ds = MVTecADMulti(root_dir, 'train')
    test_ds  = MVTecADMulti(root_dir, 'test')

    train_loader = DataLoader(
        train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

    # Evaluation must be order-stable: code that maps scores back to file paths
    # by position would otherwise pair every score with the wrong image.
    test_loader  = DataLoader(
        test_ds, batch_size=BATCH_SIZE, shuffle=False)

    return train_loader, test_loader


In [None]:
# =============================================================
# 2.  네트워크 정의
# =============================================================
# TODO

In [None]:
# =============================================================
# 3.  학습 루프 & 평가 함수
# =============================================================
# TODO

In [None]:
# =============================================================
# 4.  학습 실행
# =============================================================
# TODO

In [None]:
# =============================================================
# 5.  재구성 시각화 (학습 후)
# =============================================================
# NOTE(review): this re-binds SAVE_NAME, overriding 'best_skipgan.pt' from the config cell.
# Both the usage example below and the test cell load weights from this path — make sure
# it matches the file the training cell actually writes.
SAVE_NAME = '/content/best_gan_all.pt'

@torch.inference_mode()
def show_recon(G, loader, n=6):
    """Plot ``n`` inputs from one batch of ``loader`` (top row) and G's reconstructions (bottom row).

    Args:
        G: generator; called on the image batch and its output is shown as images
           (assumed same shape as the input — TODO confirm once Generator is defined).
        loader: DataLoader yielding ``(images, labels)`` normalized with mean = std = 0.5.
        n: number of columns; capped at the size of the first batch.

    Side effects: switches ``G`` to eval mode and shows a matplotlib figure.
    """
    G.eval()
    imgs, _ = next(iter(loader))
    n = min(n, imgs.size(0))                 # a (last/small) batch may hold fewer than n images
    imgs = imgs[:n].to(DEVICE)
    recon = G(imgs)
    # Undo Normalize(0.5, 0.5); clamp because reconstructions may overshoot [0, 1],
    # which imshow would clip with a warning.
    imgs  = (imgs.cpu()  * 0.5 + 0.5).clamp(0, 1)
    recon = (recon.cpu() * 0.5 + 0.5).clamp(0, 1)
    plt.figure(figsize=(n*2, 4))
    for i in range(n):
        plt.subplot(2, n, i+1);     plt.imshow(imgs[i].permute(1, 2, 0));   plt.axis('off')
        plt.subplot(2, n, n+i+1);   plt.imshow(recon[i].permute(1, 2, 0));  plt.axis('off')
    plt.suptitle('Input (top) vs Reconstruction (bottom)'); plt.show()

# Usage example: rebuild the generator, load the checkpoint and visualize test reconstructions.
# map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
G = Generator().to(DEVICE)
G.load_state_dict(torch.load(SAVE_NAME, map_location=DEVICE))
_, test_loader = get_loaders_all(ROOT_DIR)
show_recon(G, test_loader)


In [None]:
# =============================================================
# 6.  테스트 & 시각화  (G · SAVE_NAME 로 변수 맞춤)
# =============================================================
import torch.nn.functional as F

# ① Build the generator and load the trained weights
G = Generator().to(DEVICE)
G.load_state_dict(torch.load(SAVE_NAME, map_location=DEVICE))
G.eval()

# ② Score each test image by its mean per-pixel reconstruction MSE.
# A dedicated, unshuffled loader over the same dataset guarantees that the k-th
# score belongs to test_ds.paths[k], regardless of how test_loader was built
# (a shuffled loader would silently pair scores with the wrong files).
test_ds     = test_loader.dataset
eval_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

errs, labels = [], []
with torch.no_grad():
    for imgs, lbls in eval_loader:               # (img, label)
        imgs_gpu = imgs.to(DEVICE)
        recons   = G(imgs_gpu)

        err = F.mse_loss(recons, imgs_gpu, reduction='none')
        errs.append(err.flatten(1).mean(1).cpu().numpy())
        labels.append(lbls.numpy())

errs   = np.concatenate(errs)
labels = np.concatenate(labels)
paths  = [p for p, _ in test_ds.paths]           # positionally aligned with errs / labels
assert len(paths) == len(errs) == len(labels)

# ③ Threshold = 95th percentile of the normal-image scores
THRESH = np.percentile(errs[labels == 0], 95)

# Threshold-free summary: image-level AUROC (defined only when both classes are present).
if len(np.unique(labels)) == 2:
    print(f'Image-level AUROC: {roc_auc_score(labels, errs):.4f}')

# ---------- (1) Error distribution ----------
# Bin edges span ALL scores (normal + anomalous). Edges fitted to the normal scores
# alone would make np.histogram silently drop every anomaly outside the normal range,
# i.e. exactly the most clearly detected ones.
edges     = np.histogram_bin_edges(errs, bins=60)
hist_n, _ = np.histogram(errs[labels == 0], bins=edges)
hist_a, _ = np.histogram(errs[labels == 1], bins=edges)
centers   = (edges[:-1] + edges[1:]) / 2
bin_w     = edges[1] - edges[0]

plt.figure(figsize=(7,4))
plt.bar(centers, hist_n, width=bin_w, label='Normal',  alpha=.6)
plt.bar(centers, hist_a, width=bin_w, label='Anomaly', alpha=.6)
plt.axvline(THRESH, ls='--', lw=2, color='k', label=f'Thresh={THRESH:.4f}')
plt.xlabel('Reconstruction MSE'); plt.ylabel('Count')
plt.title ('Error distribution (test)'); plt.legend()
plt.tight_layout(); plt.show()

# ---------- (2) Normal / anomaly examples ----------
def pick_idxs(labels, errs, thresh, n_each=5):
    """Choose example indices to display.

    Returns up to ``n_each`` indices of normal images (label 0) scored below
    ``thresh``, followed by up to ``n_each`` indices of anomalous images
    (label 1) scored above it. Each group keeps its original index order.
    """
    normal_hits  = np.nonzero((labels == 0) & (errs < thresh))[0]
    anomaly_hits = np.nonzero((labels == 1) & (errs > thresh))[0]
    return np.concatenate((normal_hits[:n_each], anomaly_hits[:n_each]))

N_EACH = 5                                       # examples per row
sel = pick_idxs(labels, errs, THRESH, N_EACH)
# pick_idxs may return fewer than N_EACH per class (e.g. few anomalies above the
# threshold), so split by label instead of assuming the first N_EACH are normal.
groups = [('Normal', sel[labels[sel] == 0]), ('Anomaly', sel[labels[sel] == 1])]

fig, axes = plt.subplots(2, N_EACH, figsize=(3 * N_EACH, 6))
for row_axes, (name, idxs) in zip(axes, groups):
    for col, ax in enumerate(row_axes):
        ax.axis('off')
        if col >= len(idxs):                     # not enough qualifying images: leave slot blank
            continue
        k = idxs[col]
        img = plt.imread(paths[k])
        if img.ndim == 2:                        # grayscale PNG: [..., :3] would slice the width axis
            ax.imshow(img, cmap='gray')
        else:
            ax.imshow(img[..., :3])              # drop alpha if present
        ax.set_title(f"{name}\nMSE={errs[k]:.4f}")
plt.suptitle('Reconstruction-based Anomaly Detection')
plt.tight_layout(); plt.show()
