In [1]:
import torch
import os
import cv2
import numpy as np
from sympy import categories
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from pathlib import Path

In [2]:
# Select the compute device: use the GPU when CUDA is available, otherwise the CPU.
# NOTE(review): `devise` looks like a typo for "device"; the name is kept as-is
# because other cells may reference it.
devise = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(devise)

# Side length (in pixels) that images and masks are resized to.
IMAG_SIZE = 256
# Raw string: in a normal literal "\d" and "\m" are invalid escape sequences,
# which Python reports as a SyntaxWarning (and plans to turn into an error).
# The resulting value is unchanged.
Base_PATH = r"D:\dataset\mvtec_anomaly_detection"

cuda


In [17]:
class MVTecSegmentation(Dataset):
    """Image/mask dataset built from an MVTec-AD-style directory tree.

    For every category directory under ``root_dir`` the constructor collects:

    * every ``*.png`` under ``<category>/test/<defect>/``; images in the
      ``good`` folder get no mask, other folders are paired with a mask looked
      up under ``<category>/ground_truth/<defect>/`` when one exists;
    * every ``*.png`` under ``<category>/train/good/`` (no mask).

    Samples without a mask file yield an all-zero mask.

    Args:
        root_dir: Dataset root containing one sub-directory per category.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``"image"`` and ``"mask"`` entries.

    Attributes:
        image_paths: Image file paths (as strings), one per sample.
        mask_paths: Mask file path (string) or ``None``, parallel to ``image_paths``.
        categories: Sorted names of the category directories found in ``root_dir``.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.image_paths = []
        self.mask_paths = []

        # Directory listing order is arbitrary; sorting keeps sample indices
        # reproducible across runs and file systems.
        self.categories = sorted(d.name for d in self.root_dir.iterdir() if d.is_dir())

        print(f"Building Dataset for {len(self.categories)} categories...")

        for cat in self.categories:
            test_dir = self.root_dir / cat / "test"
            ground_truth_dir = self.root_dir / cat / "ground_truth"
            train_dir = self.root_dir / cat / "train" / "good"

            if test_dir.exists():
                # Test images: one sub-folder per defect type plus "good".
                for folder in sorted(p for p in test_dir.iterdir() if p.is_dir()):
                    for img_path in sorted(folder.glob("*.png")):
                        self.image_paths.append(str(img_path))
                        if folder.name == "good":
                            self.mask_paths.append(None)
                        else:
                            self.mask_paths.append(
                                self._find_mask(ground_truth_dir / folder.name, img_path.stem)
                            )

            if train_dir.exists():
                # Defect-free training images never have a mask.
                print(f"Category: {cat}")
                train_images = sorted(train_dir.glob("*.png"))
                self.image_paths.extend(str(p) for p in train_images)
                self.mask_paths.extend([None] * len(train_images))
                print(f"images :{len(train_images)} ")

    @staticmethod
    def _find_mask(mask_dir, stem):
        """Return the mask path (as a string) for an image stem, or ``None`` if absent.

        MVTec AD names ground-truth masks ``<stem>_mask.png``; a plain
        ``<stem>.png`` is still accepted as a fallback for trees that use the
        same file name for image and mask.
        """
        for filename in (f"{stem}_mask.png", f"{stem}.png"):
            candidate = mask_dir / filename
            if candidate.exists():
                return str(candidate)
        return None

    def __len__(self):
        """Return the number of collected samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, mask)``.

        The image is read with OpenCV and converted from BGR to RGB. The mask is
        a float32 array holding 1.0 where the mask file is non-zero and 0.0
        elsewhere (all zeros when the sample has no mask file). When a
        transform is configured, both are passed through it and its
        ``"image"`` / ``"mask"`` outputs are returned instead.

        Raises:
            OSError: If OpenCV cannot read the image or the mask file.
        """
        image_path = self.image_paths[idx]
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask_path = self.mask_paths[idx]
        if mask_path is None:
            mask = np.zeros(image.shape[:2], dtype=np.float32)
        else:
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if mask is None:
                raise OSError(f"Could not read mask: {mask_path}")
            # Binarize: any non-zero pixel counts as anomalous.
            mask = (mask > 0).astype(np.float32)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]

        return image, mask

# Preprocessing/augmentation pipeline applied jointly to each image and its mask:
# resize to a square of side IMAG_SIZE, flip horizontally half of the time,
# normalize the image channels, then convert to tensors.
# NOTE(review): the mean/std values are presumably the usual ImageNet channel
# statistics - confirm they suit the model being trained.
transform = A.Compose(
    [
        A.Resize(height=IMAG_SIZE, width=IMAG_SIZE),
        A.HorizontalFlip(p=0.5),
        A.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
        ToTensorV2(),
    ]
)





In [18]:
# Build the dataset over every category found under Base_PATH.
# NOTE(review): despite its name, this dataset contains both the test images and
# the train/good images collected by MVTecSegmentation - confirm that mixing the
# two splits is intended.
all_train_DB = MVTecSegmentation(Base_PATH, transform=transform)

train_loader = DataLoader(
    all_train_DB,
    batch_size=32,
    shuffle=True,
    # Pinned host memory only speeds up host-to-GPU copies; enabling it without
    # CUDA makes PyTorch emit a warning, so follow the device availability.
    pin_memory=torch.cuda.is_available(),
)

Building Dataset for 15 categories...
Category: bottle
images :209 
Category: cable
images :224 
Category: capsule
images :219 
Category: carpet
images :280 
Category: grid
images :264 
Category: hazelnut
images :391 
Category: leather
images :245 
Category: metal_nut
images :220 
Category: pill
images :267 
Category: screw
images :320 
Category: tile
images :230 
Category: toothbrush
images :60 
Category: transistor
images :213 
Category: wood
images :247 
Category: zipper
images :240 
