In [4]:
import os,math,time,random,warnings
from dataclasses import dataclass
warnings.filterwarnings("ignore")

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torchvision import models
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import FrEIA.framework as Ff
import FrEIA.modules as Fm

In [7]:
# Pick the compute device once: CUDA when PyTorch can see a GPU, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    # Index 0 is the first visible CUDA device.
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    print("Using CPU")

# Record the library version and CUDA availability alongside the results.
print("PyTorch version:", torch.__version__, "with CUDA =", torch.cuda.is_available())

Using GPU: NVIDIA GeForce RTX 3060 Laptop GPU
PyTorch version: 2.7.1+cu118 with CUDA = True


In [2]:
# --- Data ---------------------------------------------------------------
data_path = "mvtec_anomaly_detection"  # dataset root; joined with class_name by MVTEC_Dataset
class_name = "bottle"                  # category sub-folder under data_path
input_size = 256                       # passed as both `resize` and `cropsize` to the datasets
batch_size = 8

# --- Model / training (consumed outside this section — presumably by the
# --- encoder, DataLoader and optimizer cells; confirm against later cells)
encoder_architecture = "resnet18"
epoches = 50
workers = 4
learning_rate = 1e-4

# --- Flow hyper-parameters (usage not shown here; presumably FrEIA settings)
coupling_blocks = 8
condition_dim = 128
clamp_alpha = 1.9

# --- Reproducibility ----------------------------------------------------
seed = 42  # handed to set_seeds()


In [11]:
class MVTEC_Dataset(Dataset):
    """Image / label / mask dataset for one category laid out in MVTec AD style.

    Expected directory layout (derived from the paths built in ``_gather``)::

        <data_path>/<class_name>/train/<type>/*.png
        <data_path>/<class_name>/test/<type>/*.png
        <data_path>/<class_name>/ground_truth/<type>/<stem>_mask.png

    Images in a ``good`` folder get label 0 and an all-zero mask; images in any
    other folder get label 1 and the mask file with the matching stem.

    Args:
        data_path: Root folder of the dataset.
        class_name: Category sub-folder (e.g. ``"bottle"``).
        is_train: Read the ``train`` split when True, otherwise ``test``.
        resize: Size passed to ``T.Resize`` for both images and masks.
        cropsize: Side length of the centre crop (and of the zero mask).

    Each item is a tuple ``(img, label, mask)`` where ``img`` is the normalized
    image tensor, ``label`` is an int (0 or 1) and ``mask`` is a
    ``(1, cropsize, cropsize)`` tensor.
    """

    def __init__(self, data_path, class_name, is_train=True, resize=256,
                 cropsize=256):
        self.data_path = data_path
        self.class_name = class_name
        self.is_train = is_train
        self.resize = resize
        self.cropsize = cropsize
        self.x, self.y, self.mask = self._gather()
        # torchvision's own InterpolationMode enum is the documented argument
        # type for Resize; it selects the same LANCZOS / NEAREST filters.
        self.t_img = T.Compose([
            T.Resize(resize, interpolation=T.InterpolationMode.LANCZOS),
            T.CenterCrop(cropsize),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
        ])
        # NEAREST keeps mask pixels at their original values (no blending).
        self.t_mask = T.Compose([
            T.Resize(resize, interpolation=T.InterpolationMode.NEAREST),
            T.CenterCrop(cropsize),
            T.ToTensor(),
        ])

    def _gather(self):
        """Collect image paths, labels and mask paths for the selected split.

        Returns:
            Three parallel lists ``(x, y, mask)``: image paths, 0/1 labels and
            mask paths (``None`` for ``good`` images).

        Raises:
            OSError: from ``os.listdir`` if the split directory is missing.
        """
        phase = "train" if self.is_train else "test"
        img_dir = os.path.join(self.data_path, self.class_name, phase)
        gt_dir = os.path.join(self.data_path, self.class_name, "ground_truth")
        x, y, mask = [], [], []
        # Sorted so the sample order does not depend on the filesystem.
        for fname in sorted(os.listdir(img_dir)):
            tdir = os.path.join(img_dir, fname)
            # Skip stray files in the split folder. (The check must be on the
            # entry itself; testing img_dir would always pass.)
            if not os.path.isdir(tdir):
                continue
            files = sorted(
                os.path.join(tdir, f) for f in os.listdir(tdir) if f.endswith('.png')
            )
            x.extend(files)
            if fname == "good":
                y.extend([0] * len(files))
                mask.extend([None] * len(files))
            else:
                y.extend([1] * len(files))
                gt_tdir = os.path.join(gt_dir, fname)
                stems = [os.path.splitext(os.path.basename(f))[0] for f in files]
                mask.extend(os.path.join(gt_tdir, s + "_mask.png") for s in stems)
        assert len(x) == len(y) == len(mask)
        return x, y, mask

    def __len__(self):
        """Number of images in the selected split."""
        return len(self.x)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(img, label, mask)``."""
        # Context managers close the underlying file handle right away.
        with Image.open(self.x[idx]) as raw_img:
            img = self.t_img(raw_img.convert("RGB"))
        if self.y[idx] == 0:
            # Defect-free image: no ground-truth file, use an empty mask.
            mask = torch.zeros(1, self.cropsize, self.cropsize)
        else:
            with Image.open(self.mask[idx]) as raw_mask:
                # Force one channel so the tensor matches the (1, H, W) zero
                # placeholder above when samples are batched together.
                mask = self.t_mask(raw_mask.convert("L"))
        return img, self.y[idx], mask

In [6]:
def set_seeds(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and put cuDNN in deterministic mode.

    Args:
        seed: Integer seed applied to every generator (default 42).
    """
    # Python and NumPy global generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch CPU generator and all CUDA device generators.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Ask cuDNN for deterministic kernels and disable its auto-tuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed all RNGs with the notebook-level `seed` setting before any data or
# model objects are created.
set_seeds(seed)

In [13]:
# Training split: input_size is used for both the resize and the centre crop.
train_ds = MVTEC_Dataset(
    data_path=data_path,
    class_name=class_name,
    is_train=True,
    resize=input_size,
    cropsize=input_size,
)
print("Number of training samples:", len(train_ds))
# Floor division counts full batches only; a trailing partial batch is not
# included (NOTE(review): matches a DataLoader only if drop_last=True — confirm).
print("Train Batches :", len(train_ds) // batch_size)

Number of training samples: 209
Train Batches : 26


In [15]:
# Test split: same preprocessing sizes as the training split.
test_ds = MVTEC_Dataset(
    data_path=data_path,
    class_name=class_name,
    is_train=False,
    resize=input_size,
    cropsize=input_size,
)
print("Number of test samples:", len(test_ds))
# Floor division counts full batches only; a trailing partial batch is not
# included (NOTE(review): matches a DataLoader only if drop_last=True — confirm).
print("Test Batches :", len(test_ds) // batch_size)

Number of test samples: 83
Test Batches : 10
