In [1]:
# Mount Google Drive at /content/drive so the dataset under MyDrive (ROOT_DIR) is readable (Colab only).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Dataset root: one sub-directory per MVTec AD category.
ROOT_DIR = '/content/drive/MyDrive/Data/MVTecAD'
# Default checkpoint name (later cells reassign SAVE_NAME before using it).
SAVE_NAME = 'best_skipgan.pt'

# =============================================================
# 0.  의존성 & 전역 설정
# =============================================================
import os, random, time
from glob   import glob
from pathlib import Path
import numpy as np, matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from sklearn.metrics import roc_auc_score
from torch.nn.utils import spectral_norm
import torch.autograd as autograd

IMG_SIZE   = 256          # images are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 16
EPOCHS     = 50
LR_G, LR_D = 1e-4, 1e-5   # generator / discriminator (critic) learning rates
BETAS = (0.5, 0.999)      # Adam betas used by both optimisers
lambda_GP = 0.0           # gradient-penalty weight; 0 because D relies on spectral norm instead

# Seed every RNG used in this notebook (Python, NumPy, torch on all devices) so
# weight init, shuffling and GP interpolation are reproducible.
# cuDNN kernels may still add small nondeterminism on GPU.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('✅ device:', DEVICE)

# =============================================================
# 1.  데이터셋 & DataLoader
# =============================================================
class MVTecADMulti(Dataset):
    """
    Scan every MVTec AD category under ``root_dir`` in one pass.

      * phase='train' -> every ``<cat>/train/good/*.png``                 (label 0)
      * phase='test'  -> ``<cat>/test/good/*.png``                        (label 0)
                         ``<cat>/test/<defect>/*.png`` for defect != good  (label 1)

    Directory and file listings are sorted, so ``self.paths`` (a list of
    ``(img_path, label)``) has a deterministic order; later cells index into it.
    Items are ``(tensor, label)`` where the tensor is 3 x IMG_SIZE x IMG_SIZE,
    normalised with mean 0.5 / std 0.5 (i.e. [0, 1] -> [-1, 1]).
    """
    def __init__(self, root_dir:str, phase:str='train'):
        assert phase in ('train', 'test')
        self.phase = phase
        self.paths = []          # (img_path, label)

        for cat in sorted(os.listdir(root_dir)):
            cat_path = Path(root_dir)/cat
            if not cat_path.is_dir():
                continue

            if phase == 'train':
                self.paths += [(p, 0) for p in self._pngs(cat_path/'train'/'good')]
            else:   # test
                test_dir = cat_path/'test'
                if not test_dir.is_dir():       # stray folder that is not a category
                    continue
                self.paths += [(p, 0) for p in self._pngs(test_dir/'good')]
                for defect in sorted(os.listdir(test_dir)):
                    if defect == 'good':
                        continue
                    self.paths += [(p, 1) for p in self._pngs(test_dir/defect)]

        self.tf = T.Compose([
            T.ToPILImage(),
            T.Resize((IMG_SIZE, IMG_SIZE)),
            T.ToTensor(),
            T.Normalize([0.5]*3, [0.5]*3)
        ])

    @staticmethod
    def _pngs(folder):
        """Sorted list of ``*.png`` file paths (as str) directly inside ``folder``."""
        return sorted(glob(str(Path(folder)/'*.png')))

    def __len__(self): return len(self.paths)

    def __getitem__(self, idx):
        img_path, label = self.paths[idx]
        img = plt.imread(img_path)
        if img.ndim == 2:                       # grayscale -> 3 channels
            img = np.stack([img]*3, -1)
        img = np.ascontiguousarray(img[..., :3])  # drop an alpha channel if present
        if np.issubdtype(img.dtype, np.floating):
            # plt.imread returns PNGs as floats in [0, 1]; ToPILImage needs uint8.
            # Decide by dtype (not by max value) and round instead of truncating.
            img = (np.clip(img, 0.0, 1.0)*255).round().astype(np.uint8)
        return self.tf(img), label


def get_loaders_all(root_dir, batch_size=None, num_workers=0):
    """
    Build the all-category train / test DataLoaders.

    Args:
        root_dir: MVTec AD root containing one folder per category.
        batch_size: mini-batch size; ``None`` uses the global BATCH_SIZE.
        num_workers: DataLoader worker processes (0 = load in the main process).

    Returns:
        (train_loader, test_loader). The training loader shuffles and drops the
        last incomplete batch. The test loader does NOT shuffle: evaluation
        should be deterministic, and its batches follow
        ``test_loader.dataset.paths`` order so errors can be mapped back to files.
    """
    bs = BATCH_SIZE if batch_size is None else batch_size
    train_ds = MVTecADMulti(root_dir, 'train')
    test_ds  = MVTecADMulti(root_dir, 'test')

    train_loader = DataLoader(
        train_ds, batch_size=bs, shuffle=True, drop_last=True,
        num_workers=num_workers)

    test_loader  = DataLoader(
        test_ds, batch_size=bs, shuffle=False, num_workers=num_workers)

    return train_loader, test_loader



✅ device: cuda


In [3]:
# =============================================================
# 2.  네트워크 정의 (Skip-GANomaly)
# =============================================================
# ------------------------------------------------
# 공통 초기화 함수
# ------------------------------------------------
def weights_init(m):
    """Initialise a conv / transposed-conv layer for use with ``Module.apply``.

    Weights are drawn from N(0, 0.02) and biases (when present) are zeroed.
    Every other module type is left untouched.
    """
    conv_types = (nn.Conv2d, nn.ConvTranspose2d)
    if not isinstance(m, conv_types):
        return
    nn.init.normal_(m.weight, mean=0.0, std=0.02)
    if m.bias is not None:
        nn.init.zeros_(m.bias)

# ------------------------------------------------
# Generator  (Skip-GANomaly)
# ------------------------------------------------
class DownBlock(nn.Module):
    """Halve H and W: 4x4 stride-2 conv (no bias) -> BatchNorm -> LeakyReLU(0.2)."""
    def __init__(self, c_in, c_out):
        super().__init__()
        layers = [
            nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(c_out),
            nn.LeakyReLU(negative_slope=0.2),
        ]
        # kept as ``net`` with indices 0/1/2 so checkpoint keys stay the same
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class UpBlock(nn.Module):
    """Double H and W: 4x4 stride-2 transposed conv (no bias) -> BatchNorm -> ReLU."""
    def __init__(self, c_in, c_out):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(c_out),
            nn.ReLU(),
        ]
        # kept as ``net`` with indices 0/1/2 so checkpoint keys stay the same
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class Generator(nn.Module):
    """
    Skip-GANomaly-style encoder/decoder generator: 3-channel image in, 3-channel image out.

    • Encoder   : 3→64→128→256→512 channels (4 stride-2 DownBlocks, each halving H and W)
    • Bottleneck: two 3x3 stride-1 convs + ReLU at 512 channels (spatial size unchanged)
    • Decoder   : 512→256→128→64 channels (3 UpBlocks) + final ConvTranspose2d → 3 channels + Tanh
    • Skips     : each decoder output is concatenated (dim=1) with the encoder map of the
                  same resolution (e3, e2, e1); e4 is reached only through the bottleneck.

    Channel counts above assume base=64; spatial sizes in the comments assume a 256x256
    input (IMG_SIZE). H and W must be divisible by 16 for the skip concatenations to line up.

    NOTE(review): attribute names (e1..e4, bottle, d3, d2, d1, out_conv) determine the
    state_dict keys of saved checkpoints — renaming them breaks loading old weights.
    NOTE(review): the e1 skip (half input resolution) may let the generator copy its
    input; check that defects are not reconstructed too faithfully.
    """
    def __init__(self, base=64):
        super().__init__()
        # ---- Encoder ----
        self.e1 = DownBlock(3,      base)         # 256 → 128
        self.e2 = DownBlock(base,   base*2)       # 128 → 64
        self.e3 = DownBlock(base*2, base*4)       #  64 → 32
        self.e4 = DownBlock(base*4, base*8)       #  32 → 16

        # ---- Bottleneck ----  (two 3x3 stride-1 convs, spatial size unchanged)
        self.bottle = nn.Sequential(
            nn.Conv2d(base*8, base*8, 3, 1, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(base*8, base*8, 3, 1, 1, bias=False),
            nn.ReLU(),
        )

        # ---- Decoder ----
        self.d3 = UpBlock(base*8,   base*4)       # 16 → 32
        self.d2 = UpBlock(base*8,   base*2)       # 32 → 64   (input = d3 concat e3)
        self.d1 = UpBlock(base*4,   base)         # 64 → 128  (input = d2 concat e2)

        # ---- Output (reconstructed image) ----
        self.out_conv = nn.Sequential(
            nn.ConvTranspose2d(base*2, 3, 4, 2, 1, bias=False),  # 128 → 256 (input = d1 concat e1)
            nn.Tanh(),
        )

        self.apply(weights_init)   # N(0, 0.02) weights / zero bias for every (transposed) conv

    def forward(self, x):
        # Encoder
        e1 = self.e1(x)
        e2 = self.e2(e1)
        e3 = self.e3(e2)
        e4 = self.e4(e3)

        # Bottleneck
        b  = self.bottle(e4)

        # Decoder + skip (channel counts below assume base=64)
        d3 = self.d3(b)              # 16→32
        d3 = torch.cat([d3, e3], 1)  # 256 + 256 = 512 channels

        d2 = self.d2(d3)             # 32→64
        d2 = torch.cat([d2, e2], 1)  # 128 + 128 = 256 channels

        d1 = self.d1(d2)             # 64→128
        d1 = torch.cat([d1, e1], 1)  # 64 + 64 = 128 channels

        out = self.out_conv(d1)      # 128→256
        return out                   # values in [-1, 1] (tanh)

# ------------------------------------------------
# Discriminator  (PatchGAN + Feature extraction)
# ------------------------------------------------
class Discriminator(nn.Module):
    """
    PatchGAN-style critic that can also return multi-stage features.

    • 5 spectral-normalised stride-2 conv stages: a 256x256 input becomes an 8x8 map
      with base*8 channels; the first stage has no BatchNorm.
    • out_conv (4x4 kernel, stride 1, no padding) turns that 8x8 map into a 5x5 grid of
      patch scores. forward() flattens it with view(-1), so the score has shape (B*25,)
      for a 256x256 input — not (B,). Since every image has the same number of patches,
      .mean() over it equals the mean of the per-image mean scores.
    • Scores are raw (no sigmoid), so they can serve as WGAN critic values or as logits.
      NOTE(review): the BatchNorm layers mix samples within a batch, which conflicts with
      the per-sample assumption behind a WGAN-GP gradient penalty; this has no effect on
      the loss while lambda_GP is 0.
    • return_feat=True also returns the spatial mean of every stage's output, concatenated:
      shape (B, base*(1+2+4+8+8)) = (B, 1472) for base=64 — used for feature matching.
    """
    def __init__(self, base=64):
        super().__init__()
        def dblk(c_in, c_out, bn=True):
            # spectral-normalised 4x4 stride-2 conv [-> BatchNorm] -> LeakyReLU(0.2)
            layers = [spectral_norm(nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False))]
            if bn: layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return nn.Sequential(*layers)

        self.stem = nn.Sequential(
            dblk(3,      base,   bn=False),  # 256 → 128
            dblk(base,   base*2),            # 128 → 64
            dblk(base*2, base*4),            # 64  → 32
            dblk(base*4, base*8),            # 32  → 16
            dblk(base*8, base*8),            # 16  → 8
        )
        self.out_conv = spectral_norm(nn.Conv2d(base*8, 1, 4, 1, 0))  # 8 → 5 (k=4, s=1, p=0)

        self.apply(weights_init)

    def forward(self, x, return_feat=False):
        feats = []
        h = x
        for layer in self.stem:
            h = layer(h)
            feats.append(h)   # output of each stage, kept for feature matching

        score = self.out_conv(h).view(-1)   # (B*25,) for 256x256 input: 5x5 patch scores per image, flattened
        if return_feat:
            # spatial mean of each stage's feature map, concatenated (Skip-GANomaly feature-matching loss)
            fvec = torch.cat([f.mean([2, 3]) for f in feats], dim=1)
            return score, fvec
        return score

In [4]:
# =============================================================
# 3.  학습 루프 & 평가 함수
# =============================================================
# Loss objects. L_adv (BCE on logits) is defined but train_epoch uses a WGAN-style
# adversarial loss instead; L_rec / L_fm are the pixel and feature-matching L1 terms.
L_adv = nn.BCEWithLogitsLoss()
L_rec = nn.L1Loss()
L_fm  = nn.L1Loss()

def gradient_penalty(D, real, fake):
    """
    WGAN-GP gradient penalty  E[(||grad_x D(x_hat)||_2 - 1)^2]  on random interpolates.

    Args:
        D: critic, called as ``D(x)``; its output may have any shape (all entries
           are summed via ``grad_outputs=ones``).
        real, fake: batches of identical shape (B, ...); any rank >= 1.

    Returns:
        Scalar tensor, differentiable w.r.t. D's parameters (create_graph=True).
    """
    # one interpolation weight per sample, broadcast over the remaining dims
    alpha_shape = (real.size(0),) + (1,) * (real.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real.device, dtype=real.dtype)
    inter = (alpha*real + (1-alpha)*fake).requires_grad_(True)
    score = D(inter)
    grad = autograd.grad(outputs=score, inputs=inter,
                         grad_outputs=torch.ones_like(score),
                         create_graph=True, retain_graph=True)[0]
    # reshape (not view): the gradient can be non-contiguous, e.g. for channels_last inputs
    grad_norm = grad.reshape(grad.size(0), -1).norm(2, dim=1)
    return ((grad_norm - 1)**2).mean()

def train_epoch(G,D,loader,optG,optD,lambda_rec=50,lambda_fm=10,lambda_gp=None):
    """
    Run one epoch of adversarial training.

    Per batch:
      1. Critic step   : -(mean D(real) - mean D(fake))  [+ lambda_gp * GP when lambda_gp > 0]
      2. Generator step: -mean D(fake) + lambda_rec * L1(fake, real)
                         + lambda_fm * L1(D-features(fake), D-features(real))

    Args:
        G, D: generator and critic (both switched to train mode here).
        loader: yields (images, labels); labels are ignored.
        optG, optD: optimizers for G and D.
        lambda_rec, lambda_fm: weights of the reconstruction / feature-matching terms.
        lambda_gp: gradient-penalty weight; None uses the global ``lambda_GP``.

    Returns:
        (mean generator loss, mean critic loss) over the epoch's batches.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    if lambda_gp is None:
        lambda_gp = lambda_GP
    G.train(); D.train()
    g_tot = d_tot = 0.0
    n_batches = 0
    for imgs, _ in tqdm(loader, leave=False):
        imgs = imgs.to(DEVICE)

        # single G forward, reused by both steps
        fake = G(imgs)
        fake_d = fake.detach()               # the critic step must not backprop into G

        # --- D (maximize real-fake gap) ---
        optD.zero_grad()
        d_loss = -(D(imgs).mean() - D(fake_d).mean())
        if lambda_gp > 0:                    # skip the costly double-backprop when disabled
            d_loss = d_loss + lambda_gp*gradient_penalty(D, imgs, fake_d)
        d_loss.backward(); optD.step()

        # --- G (fool D + recon + feature-match) ---
        optG.zero_grad()
        score_f, feat_f = D(fake, True)      # one D pass gives both score and features
        with torch.no_grad():                # real features are only a target
            _, feat_r = D(imgs, True)
        g_adv = -score_f.mean()              # WGAN generator loss
        g_rec = L_rec(fake, imgs)
        g_fm  = L_fm(feat_f, feat_r)
        g_loss = g_adv + lambda_rec*g_rec + lambda_fm*g_fm
        g_loss.backward(); optG.step()

        g_tot += g_loss.item();  d_tot += d_loss.item()
        n_batches += 1

    if n_batches == 0:
        raise ValueError('train_epoch: loader produced no batches')
    return g_tot/n_batches, d_tot/n_batches

@torch.inference_mode()
def get_scores(G, loader, device=None):
    """
    Per-image anomaly score = mean absolute reconstruction error.

    Args:
        G: generator; put into eval mode here (and left in eval mode).
        loader: yields (images, labels); labels as a 1-D tensor or sequence of 0/1.
        device: where to run G; None uses the global ``DEVICE``.

    Returns:
        (scores, labels) as 1-D NumPy arrays (float, int64) in loader order.
    """
    if device is None:
        device = DEVICE
    G.eval()
    scores, labels = [], []
    for imgs, lbl in loader:
        imgs = imgs.to(device)
        err = (G(imgs) - imgs).abs().flatten(1).mean(1)
        scores.extend(err.cpu().tolist())
        # .tolist() yields plain ints; extending with the tensor would store 0-d tensors
        labels.extend(torch.as_tensor(lbl).tolist())
    return np.array(scores), np.array(labels, dtype=np.int64)

def fit_all(root_dir, save_path, epochs=None):
    """
    Train on all categories and keep the generator checkpoint with the best test AUC.

    After every epoch, G is scored on the test split (mean |recon - input|) and its
    state_dict is saved to ``save_path`` whenever ROC-AUC improves.
    NOTE(review): model selection uses the *test* set, so the reported best AUC is
    optimistically biased; a held-out validation split would give an unbiased figure.

    Args:
        root_dir: MVTec AD root directory.
        save_path: checkpoint file path; missing parent folders are created.
        epochs: number of epochs; None uses the global EPOCHS.
    """
    n_epochs = EPOCHS if epochs is None else epochs
    train_loader, test_loader = get_loaders_all(root_dir)
    G, D = Generator().to(DEVICE), Discriminator().to(DEVICE)
    optG = torch.optim.Adam(G.parameters(), LR_G, betas=BETAS)
    optD = torch.optim.Adam(D.parameters(), LR_D, betas=BETAS)

    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    best_auc = float('-inf')   # any finite AUC beats this, so the first epoch is checkpointed
    for ep in range(1, n_epochs+1):
        g, d = train_epoch(G, D, train_loader, optG, optD)
        s, l = get_scores(G, test_loader)
        auc = roc_auc_score(l, s)
        if auc > best_auc:
            best_auc = auc
            torch.save(G.state_dict(), save_path)
        print(f'[Ep {ep:03d}] G:{g:.3f}  D:{d:.3f}  AUC:{auc:.4f} (best {best_auc:.4f})')
    print('🏁 Done! best AUC:', best_auc)


In [None]:
# =============================================================
# 4.  학습 실행
# =============================================================
# Absolute path: the visualisation / test cells below load '/content/best_skipgan_all.pt',
# so saving relative to the current working directory could put it somewhere else.
# NOTE(review): /content is wiped when the Colab VM resets; consider a path on Drive.
SAVE_NAME = '/content/best_skipgan_all.pt'
fit_all(ROOT_DIR, SAVE_NAME)

  0%|          | 0/226 [00:00<?, ?it/s]

In [None]:
# =============================================================
# 5.  재구성 시각화 (학습 후)
# =============================================================
# Checkpoint to visualise — presumably the one written by the training cell (verify the path).
SAVE_NAME = os.path.join('/content', 'best_skipgan_all.pt')

@torch.inference_mode()
def show_recon(G, loader, n=6):
    """
    Plot up to ``n`` images from the first batch (top row) above G's reconstructions
    (bottom row).

    Args:
        G: generator; its output is assumed to be in [-1, 1] like Generator's Tanh.
        loader: yields (images, labels) normalised to [-1, 1]; only the first batch is used.
        n: number of columns, clipped to the size of that batch.
    """
    G.eval()
    imgs, _ = next(iter(loader))
    n = min(n, imgs.size(0))             # a short batch may hold fewer than n images
    if n < 1:
        return
    imgs = imgs[:n].to(DEVICE)
    recon = G(imgs)

    def to_hwc(t):
        # undo Normalize(0.5, 0.5); clamp guards imshow against tiny float overshoot
        return (t.cpu()*0.5 + 0.5).clamp(0, 1).permute(0, 2, 3, 1)

    imgs, recon = to_hwc(imgs), to_hwc(recon)
    fig, axes = plt.subplots(2, n, figsize=(n*2, 4), squeeze=False)
    for i in range(n):
        axes[0, i].imshow(imgs[i]);   axes[0, i].axis('off')
        axes[1, i].imshow(recon[i]);  axes[1, i].axis('off')
    fig.suptitle('Input (top) vs Reconstruction (bottom)'); plt.show()

# 사용 예
# Example usage: reload the best checkpoint and show a few test reconstructions.
G = Generator().to(DEVICE)
# map_location lets a checkpoint saved on GPU load on a CPU-only runtime too
G.load_state_dict(torch.load(SAVE_NAME, map_location=DEVICE))
_, test_loader = get_loaders_all(ROOT_DIR)
show_recon(G, test_loader)


In [None]:
# 전체 데이터셋에 대해 정상/이상 개수 확인
# Count normal / anomalous test images straight from the (path, label) list;
# iterating the Dataset itself would decode and transform every image twice.
test_labels = [label for _, label in test_loader.dataset.paths]
num_norm = test_labels.count(0)
num_anom = test_labels.count(1)
print(f"✔ 정상(good)   : {num_norm:,}")
print(f"✔ 이상(defect) : {num_anom:,}")


In [None]:
# =============================================================
# 6.  테스트 & 시각화  (G · SAVE_NAME 로 변수 맞춤)
# =============================================================
import torch.nn.functional as F

# ① Build the Generator and load the selected weights
G = Generator().to(DEVICE)
G.load_state_dict(torch.load(SAVE_NAME, map_location=DEVICE))
G.eval()

# Score in dataset order with a dedicated non-shuffled loader, so errs / labels /
# paths stay aligned even if test_loader itself shuffles.
test_ds     = test_loader.dataset
eval_loader = DataLoader(test_ds, batch_size=test_loader.batch_size, shuffle=False)
paths       = [p for p, _ in test_ds.paths]

errs, labels = [], []
with torch.no_grad():
    for imgs, lbls in eval_loader:
        imgs_gpu = imgs.to(DEVICE)
        recons   = G(imgs_gpu)
        # per-image mean squared reconstruction error
        err = F.mse_loss(recons, imgs_gpu, reduction='none').flatten(1).mean(1)
        errs.append(err.cpu().numpy())
        labels.append(lbls.numpy())

errs   = np.concatenate(errs)
labels = np.concatenate(labels)
assert len(errs) == len(paths), 'reconstruction errors and image paths are out of sync'

# Threshold = 95th percentile of the normal test errors
# NOTE(review): calibrated on test normals; a held-out normal set would be cleaner.
THRESH = np.percentile(errs[labels == 0], 95)

# ---------- (1) error distribution ----------
hist_n, edges = np.histogram(errs[labels == 0], bins=60)
hist_a, _     = np.histogram(errs[labels == 1], bins=edges)
centers       = (edges[:-1] + edges[1:]) / 2
bar_w         = edges[1] - edges[0]

fig, ax = plt.subplots(figsize=(7, 4))
ax.bar(centers, hist_n, width=bar_w, label='Normal',  alpha=.6)
ax.bar(centers, hist_a, width=bar_w, label='Anomaly', alpha=.6)
ax.axvline(THRESH, ls='--', lw=2, color='k', label=f'Thresh={THRESH:.4f}')
ax.set_xlabel('Reconstruction MSE'); ax.set_ylabel('Count')
ax.set_title('Error distribution (test)'); ax.legend()
fig.tight_layout(); plt.show()

# ---------- (2) 정상·이상 예시 ----------
def pick_idxs(labels, errs, thresh, n_each=5):
    """Indices of up to ``n_each`` normals with error below ``thresh``, followed by
    up to ``n_each`` anomalies with error above it (earliest indices first)."""
    below = np.flatnonzero((labels == 0) & (errs < thresh))
    above = np.flatnonzero((labels == 1) & (errs > thresh))
    return np.concatenate([below[:n_each], above[:n_each]])

sel = pick_idxs(labels, errs, THRESH, 5)

# Split the selection by true label so each row only shows its own class,
# even when fewer than 5 examples qualify on one side.
rows = [('Normal',  [k for k in sel if labels[k] == 0]),
        ('Anomaly', [k for k in sel if labels[k] == 1])]

fig, axes = plt.subplots(2, 5, figsize=(15,6))
for row_axes, (kind, idxs) in zip(axes, rows):
    for j, ax in enumerate(row_axes):
        ax.axis('off')
        if j >= len(idxs):              # not enough examples: leave the slot empty
            continue
        k = idxs[j]
        img = plt.imread(paths[k])
        # grayscale images load as 2-D arrays; slicing [..., :3] would cut the width
        if img.ndim == 2:
            ax.imshow(img, cmap='gray')
        else:
            ax.imshow(img[..., :3])
        ax.set_title(f"{kind}\nMSE={errs[k]:.4f}")
plt.suptitle('Reconstruction-based Anomaly Detection')
plt.tight_layout(); plt.show()
