In [8]:
import torch
import os
import cv2
import numpy as np
from sympy import categories
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from pathlib import Path

In [10]:
devise = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(devise)

IMAG_SIZE = 256
Base_PATH = "D:\dataset\mvtec_anomaly_detection"

cuda


In [16]:
class MVTecSegmentation(Dataset):
    def __init__ (self , root_dir , transform = None):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.image_paths = []
        self.mask_paths = []

        self.categories = [d.name for d in self.root_dir.iterdir() if d.is_dir()]

        print(f"Building Dataset for {len(self.categories)} categories...")

        for cat in self.categories:
            cat_dir = self.root_dir / cat / "test"
            gt_dir = self.root_dir / cat / "ground_truth"

        if cat_dir.exists():
                for folder in cat_dir.iterdir():
                    if folder.is_dir() and folder.name != "good":
                        for img_path in folder.glob("*.png"):
                            mask_path = gt_dir / folder.name / (img_path.stem + "_mask.png")

                            if mask_path.exists():
                                self.image_paths.append(str(img_path))
                                self.mask_paths.append(str(mask_path))
                    print(f"Found {len(self.image_paths)} images in {cat} category")
        print(f"Total images found across all categories: {len(self.image_paths)}")
        print(f"-----Category: {len(self.categories)}-----")

    def __len__ (self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = cv2.imread((self.image_paths[idx]))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.mask_paths[idx] , cv2.IMREAD_GRAYSCALE)
        mask = (mask>0).astype(np.float32)
        return image, mask

transform = A.Compose([
    A.Resize(IMAG_SIZE , IMAG_SIZE),
    A.HorizontalFlip(p=0.5),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])





In [17]:
all_train_DB = MVTecSegmentation(Base_PATH , transform = transform)

train_loader = DataLoader(
    all_train_DB,
    batch_size=32,
    shuffle=True,
    pin_memory=True,
)

Building Dataset for 15 categories...
Found 19 images in zipper category
Found 35 images in zipper category
Found 52 images in zipper category
Found 68 images in zipper category
Found 68 images in zipper category
Found 85 images in zipper category
Found 103 images in zipper category
Found 119 images in zipper category
Total images found across all categories: 119
-----Category: 15-----
