In [58]:
import os,math,time,random,warnings
from dataclasses import dataclass
warnings.filterwarnings("ignore")

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torchvision import models
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import FrEIA.framework as Ff
import FrEIA.modules as Fm
from sklearn.metrics import roc_auc_score
print(np.__version__, torch.__version__)


2.2.6 2.7.1+cu118


In [59]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("Using CPU")

print("PyTorch version:", torch.__version__, "with CUDA =", torch.cuda.is_available())

Using GPU: NVIDIA GeForce RTX 3060 Laptop GPU
PyTorch version: 2.7.1+cu118 with CUDA = True


In [60]:
data_path="mvtec_anomaly_detection"
class_name = "bottle"
encoder_architecture = "wide_resnet50_2"  # "resnet18"
input_size = 256
batch_size = 8

epoches = 1
workers = 0
learning_rate = 2e-4

coupling_blocks = 8
condition_dim =128
clamp_alpha = 1.9
seed=42


In [61]:
class MVTEC_Dataset(Dataset):
    def __init__(self, data_path, class_name, is_train=True, resize=256,
                 cropsize=256):
        self.data_path = data_path
        self.class_name = class_name
        self.is_train = is_train
        self.resize = resize
        self.cropsize = cropsize
        self.x, self.y ,self.mask = self._gather()
        self.t_img = T.Compose([
            T.Resize(resize,interpolation=Image.Resampling.LANCZOS),
            T.CenterCrop(cropsize),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
        ])
        self.t_mask = T.Compose([
            T.Resize(resize,interpolation=Image.Resampling.NEAREST),
            T.CenterCrop(cropsize),
            T.ToTensor(),
        ])

    def _gather(self):
        phase = "train" if self.is_train else "test"
        img_dir = os.path.join(self.data_path, self.class_name, phase) 
        gt_dir = os.path.join(self.data_path, self.class_name, "ground_truth")
        x,y,mask = [],[],[]
        for fname in os.listdir(img_dir):
            tdir = os.path.join(img_dir, fname)
            if not os.path.isdir(img_dir):continue
            files = sorted([os.path.join(tdir,f) for f in os.listdir(tdir) if f.endswith('.png')])
            x.extend(files)
            if fname=="good":
                y.extend([0]*len(files))
                mask.extend([None]*len(files))
            else:
                y.extend([1]*len(files))
                gt_tdir= os.path.join(gt_dir, fname)
                base=[os.path.splitext(os.path.basename(f))[0] for f in files]
                mask.extend([os.path.join(gt_tdir,b+"_mask.png") for b in base])
        assert len(x)==len(y)
        return x,y,mask
    def __len__(self):
        return len(self.x)
    def __getitem__(self, idx):
        img=Image.open(self.x[idx]).convert("RGB")
        img=self.t_img(img)
        if self.y[idx]==0:
            mask=torch.zeros(1,self.cropsize,self.cropsize)
        else:
            mask=Image.open(self.mask[idx])
            mask=self.t_mask(mask)
        return img,self.y[idx],mask

In [62]:
def set_seeds(seed=42):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True; torch.backends.cudnn.benchmark = False

set_seeds(seed)

In [63]:
train_ds=MVTEC_Dataset(data_path, class_name, is_train=True, resize=input_size, cropsize=input_size)
print("Number of training samples:", len(train_ds))
print("Train Batches :", len(train_ds)//batch_size)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True,drop_last=True)


Number of training samples: 209
Train Batches : 26


In [64]:
test_ds=MVTEC_Dataset(data_path, class_name, is_train=False, resize=input_size, cropsize=input_size)
print("Number of test samples:", len(test_ds))
print("Test Batches :", len(test_ds)//batch_size)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=workers, pin_memory=True)

Number of test samples: 83
Test Batches : 10


In [65]:
def subnet_fc(dims_in, dims_out):
    return nn.Sequential(
        nn.Linear(dims_in, 2*dims_in),
        nn.ReLU(),
        nn.Linear(2*dims_in, dims_out)
    )

In [66]:
def cflow_head(n_feat,condition_dim,coupling_blocks,clamp_alpha):
    coder=Ff.SequenceINN(n_feat)
    for _ in range(coupling_blocks):
                coder.append(Fm.AllInOneBlock, cond=0, cond_shape=(condition_dim,),
                    subnet_constructor=subnet_fc, affine_clamping=clamp_alpha,
                    global_affine_type='SOFTPLUS', permute_soft=False)
    return coder

In [67]:
def load_encoder(arch,pool_layers):
    acts={}
    def hook(name):
        def fn(module, input, output):
            acts[name] = output.detach()
        return fn
    
    if arch=="resnet18":
        encoder=models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        layers=[encoder.layer2,encoder.layer3,encoder.layer4]
        dims=[encoder.layer2[-1].conv2.out_channels,
              encoder.layer3[-1].conv2.out_channels,
              encoder.layer4[-1].conv2.out_channels]
    elif arch=="wide_resnet50_2":
        encoder=models.wide_resnet50_2(weights=models.Wide_ResNet50_2_Weights.DEFAULT)
        layers=[encoder.layer2,encoder.layer3,encoder.layer4]
        dims=[encoder.layer2[-1].conv3.out_channels,
              encoder.layer3[-1].conv3.out_channels,
              encoder.layer4[-1].conv3.out_channels]
    else:
        raise NotImplementedError(f"Encoder architecture {arch} not implemented")
    L=min(len(pool_layers),len(layers))
    for i in range(L):
        layers[i].register_forward_hook(hook(pool_layers[i]))
    return encoder.eval(),pool_layers[:L], dims[:L], acts


In [68]:
pool_layers=['l0','l1','l2']
encoder,pool_layers,pool_dims,acts=load_encoder(encoder_architecture,pool_layers)

In [69]:
encoder.to(device).eval()
decoders=[]
for d in pool_dims:
    decoder=cflow_head(d,condition_dim,coupling_blocks,clamp_alpha)
    decoder.to(device)
    decoders.append(decoder)

In [70]:
for p in encoder.parameters():
    p.requires_grad=False
optim=torch.optim.Adam(
    [p for decoder in decoders for p in decoder.parameters()],
    lr=learning_rate)



In [71]:
def positionalencoding2d(D, H, W):
    if D % 4 != 0:
        raise ValueError("positionalencoding2d: D must be divisible by 4")
    pe = torch.zeros(D, H, W)
    half = D // 2
    div_term = torch.exp(torch.arange(0., half, 2) * -(math.log(10000.0) / half))
    pos_w = torch.arange(0., W).unsqueeze(1)
    pos_h = torch.arange(0., H).unsqueeze(1)
    pe[0:half:2, :, :] = torch.sin(pos_w * div_term).T.unsqueeze(1).repeat(1, H, 1)
    pe[1:half:2, :, :] = torch.cos(pos_w * div_term).T.unsqueeze(1).repeat(1, H, 1)
    pe[half::2, :, :]  = torch.sin(pos_h * div_term).T.unsqueeze(2).repeat(1, 1, W)
    pe[half+1::2, :, :] = torch.cos(pos_h * div_term).T.unsqueeze(2).repeat(1, 1, W)
    return pe

In [72]:
def gaussian_nll_logprob(z, log_jac_det):
    k = z.size(1)
    const = -0.5 * k * torch.log(torch.tensor(2 * math.pi, device=z.device, dtype=z.dtype))
    quad = -0.5 * torch.sum(z**2, dim=1)
    return const + quad + log_jac_det

In [73]:
def train_epoch(encoder,decoders,loader,optim,device,pool_layers,acts,P):
    for d in decoders:
        d.train()
    
    total_loss=0.0
    total_B=0
    log_sigomoid=nn.LogSigmoid()
    N=256
    for imgs,_,_ in tqdm(loader,desc="Training"):
        images=imgs.to(device)
        optim.zero_grad()
        with torch.no_grad():_=encoder(images)
        loss_sum=0.0
        for l,name in enumerate(pool_layers):
            feat=acts[name]
            B,C,H,W=feat.size()
            S=H*W
            E=B*S
            p=positionalencoding2d(P,H,W).to(device).unsqueeze(0).repeat(B,1,1,1)
            c_r=p.reshape(B,P,H,W).transpose(1,2).reshape(E,P)
            e_r = feat.reshape(B,C,S).transpose(1,2).reshape(E,C)
            perm=torch.randperm(E,device=device)
            a=max(1,E//N)
            dec=decoders[l]
            for f in range(a):
                idx=torch.arange(f*N,min((f+1)*N,E),device=device)
                if idx.numel()==0:
                    continue
                c_p=c_r[perm[idx]]
                e_p=e_r[perm[idx]]
                z,log_jac=dec(e_p, [c_p,])
                log_prob=gaussian_nll_logprob(z, log_jac)/C
                loss=-log_sigomoid(log_prob).sum()
                loss_sum+=loss
        loss_sum.backward()
        optim.step()
        total_loss+=loss_sum.item()
        total_B+=images.size(0)
    avg_loss=total_loss/max(1,total_B)
    return avg_loss
    
    

In [74]:
def test_epoch(encoder, decoders, loader, device, pool_layers, acts, P, cropsize, fraction: float = 1.0):
    
    for d in decoders: 
        d.eval()
    heights, widths=[], []
    dists=[[] for _ in pool_layers]
    gt_label_list, gt_mask_list=[], []

    total_batches=len(loader)
    use_batches=max(1, int(math.ceil(total_batches * float(fraction))))

    with torch.no_grad():
        for i,(images, labels, masks) in enumerate(tqdm(loader, desc='Test')):
            if i>=use_batches: break
            gt_label_list.extend(labels.numpy()); gt_mask_list.extend(masks.numpy())
            images=images.to(device)
            _=encoder(images)
            for l,name in enumerate(pool_layers):
                feet=acts[name]
                B,C,H,W=feet.size(); S=H*W; E=B*S
                if i==0: heights.append(H); widths.append(W)
                p=positionalencoding2d(P, H, W).to(device).unsqueeze(0).repeat(B,1,1,1)
                c_r=p.reshape(B,P,S).transpose(1,2).reshape(E,P)
                e_r=feet.reshape(B,C,S).transpose(1,2).reshape(E,C)
                dec=decoders[l]
                z,log_jac = dec(e_r, [c_r,])
                log_prob=gaussian_nll_logprob(z, log_jac) / C
                dists[l].extend(log_prob.detach().cpu().tolist())
    return heights, widths, dists, gt_label_list, gt_mask_list

In [75]:
def compute_anomaly_map(dists, heights, widths, cropsize, pool_layers):
    maps=[]
    for l in range(len(pool_layers)):
        t=torch.tensor(dists[l], dtype=torch.double)
        t-=torch.max(t); prob = torch.exp(t)
        m=prob.reshape(-1, heights[l], widths[l])
        up=F.interpolate(m.unsqueeze(1), size=(cropsize,cropsize), mode='bilinear', align_corners=True).squeeze().numpy()
        maps.append(up)
    score=np.zeros_like(maps[0])
    for m in maps:
        score+=m
    super_mask=score.max()-score
    return super_mask

In [76]:
history = {"loss": [], "det_auroc_full": []}
ckpt_root = os.path.join("new models", class_name)
os.makedirs(ckpt_root, exist_ok=True)

best_det_auroc = -1.0
best_path = os.path.join(ckpt_root, "best_cflow_ad.pt")
last_path = os.path.join(ckpt_root, "last_cflow_ad.pt")

for epoch in range(1, epoches + 1):
    loss=train_epoch(encoder, decoders, train_loader, optim, device, pool_layers, acts, condition_dim)
    history["loss"].append(loss)
    heights, widths, dists, gt_label_list, gt_mask_list = test_epoch(
        encoder, decoders, test_loader, device, pool_layers, acts, condition_dim, input_size, fraction=1.0
)
    super_mask = compute_anomaly_map(dists, heights, widths, input_size, pool_layers)
    score_label = np.max(super_mask, axis=(1,2))
    gt_label = np.asarray(gt_label_list, dtype=bool)
    det_auroc = float(roc_auc_score(gt_label, score_label)*100)
    history["det_auroc_full"].append(det_auroc)

    print(f"Epoch {epoch+1}/{epoches} - loss {loss:.4f} - full-test AUROC {det_auroc:.2f}%")

    state = {

        "encoder_arch": encoder_architecture,
        "pool_layers": pool_layers,
        "epoch": (epoch+1),
        "model": [d.state_dict() for d in decoders],
        "optimizer": optim.state_dict(),
        "det_auroc_full": det_auroc,
    }

    # Save best if improved
    if det_auroc > best_det_auroc:
        torch.save(state, best_path)
        best_det_auroc = det_auroc
        print(f"  ✓ Saved BEST to {best_path} (AUROC {best_det_auroc:.2f}%)")

final_state = {
    "encoder_arch": encoder_architecture,
    "pool_layers": pool_layers,
    "epoch": epoches,
    "model": [d.state_dict() for d in decoders],
    "optimizer": optim.state_dict(),
    "det_auroc_full": history["det_auroc_full"][-1] if history["det_auroc_full"] else None,
}
torch.save(final_state, last_path)
print(f"Saved LAST to {last_path}")


Training: 100%|██████████| 26/26 [02:57<00:00,  6.82s/it]
Test: 100%|██████████| 11/11 [01:05<00:00,  5.95s/it]


Epoch 2/1 - loss 2124.8467 - full-test AUROC 99.37%
  ✓ Saved BEST to new models\bottle\best_cflow_ad.pt (AUROC 99.37%)
Saved LAST to new models\bottle\last_cflow_ad.pt
