In [None]:
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as T

class ImageFolderDataset(Dataset):
    """Image-classification dataset read from a ``root/<class>/<image>`` tree.

    Every immediate sub-directory of ``root`` is treated as one class, and
    classes are numbered in sorted-name order.  Only regular files located
    directly inside a class directory are used (no recursion), and only those
    whose extension is listed in ``IMG_EXTENSIONS`` (case-insensitive).

    Attributes:
        root: Dataset root as a ``Path``.
        class_to_idx: Mapping from class-directory name to integer label.
        samples: List of ``(image_path, label)`` pairs, ordered by class and
            then by path so the index -> sample mapping is reproducible.
        tf: Transform pipeline applied to every image in ``__getitem__``.
    """

    # Lower-case file extensions (including the dot) accepted as images.
    IMG_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png", ".bmp", ".webp"})

    def __init__(self, root: Path, img_size: int = 224, augment: bool = False):
        """Index the image files under ``root`` and build the transform.

        Args:
            root: Directory containing one sub-directory per class.
            img_size: Side length of the square every image is resized to.
            augment: If true, apply a random horizontal flip, a random
                rotation of up to 15 degrees and colour jitter before the
                resize/normalise steps.

        Note:
            A root without any matching image yields an empty dataset
            (``len(self) == 0``); no error is raised here.
        """
        self.root = Path(root)
        classes = sorted(d.name for d in self.root.iterdir() if d.is_dir())
        self.class_to_idx = {name: idx for idx, name in enumerate(classes)}
        self.samples = self._collect_samples(self.root, self.class_to_idx)

        base = [
            T.Resize((img_size, img_size)),
            T.ToTensor(),
            # Per-channel mean/std; these are the values commonly used for
            # ImageNet-pretrained models.
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        if augment:
            aug = [
                T.RandomHorizontalFlip(),
                T.RandomRotation(15),
                T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            ]
            # Augmentations operate on the PIL image, before resize/ToTensor.
            self.tf = T.Compose(aug + base)
        else:
            self.tf = T.Compose(base)

    @classmethod
    def _collect_samples(cls, root, class_to_idx):
        """Return sorted ``(path, label)`` pairs for the images of each class.

        Directory listing order is filesystem-dependent, so the files of each
        class are sorted.  Non-files (e.g. a directory named ``x.png``) are
        skipped because they could never be opened as images.
        """
        samples = []
        for name, label in class_to_idx.items():
            files = sorted(
                f
                for f in (root / name).iterdir()
                if f.is_file() and f.suffix.lower() in cls.IMG_EXTENSIONS
            )
            samples.extend((f, label) for f in files)
        return samples

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.samples)

    def __getitem__(self, i):
        """Load sample ``i`` and return ``(transformed_image, label)``."""
        path, y = self.samples[i]
        # The context manager closes the underlying file deterministically;
        # ``convert`` returns a fully loaded copy that outlives the handle.
        with Image.open(path) as src:
            img = src.convert("RGB")
        x = self.tf(img)
        return x, y
