In [None]:
# Subclassing torch.utils.data.Dataset
class CastomDataset(Dataset):
    """
    Custom dataset for loading images laid out as ``img_dir/<class_name>/<file>``.

    Each sample is ``(transform(image), target_transform(class_index))``, where
    the class index comes from ``find_classes(img_dir)`` and the folder the
    image sits in.
    """

    def __init__(self,
                 img_dir: str,
                 transform: "transforms.Compose | None" = None,
                 target_transform: "transforms.Compose | None" = None):
        """
        Args:
            img_dir: path to the root image directory. Each immediate
                subdirectory is treated as one class and holds its images.
            transform: transformation to apply to each PIL image. ``None``
                (the default) uses ``transforms.Compose([transforms.ToTensor()])``.
            target_transform: transformation to apply to the integer class
                index. ``None`` (the default) returns the index unchanged.
                The previous default (``ToTensor``) cannot accept an ``int``
                and failed on every sample.
        """
        self.img_dir = pathlib.Path(img_dir)
        # Images sit one level down, inside their class folder. Globbing '*'
        # would only see files in img_dir itself, whose parent is not a class.
        # Sorting makes the index -> sample mapping reproducible across filesystems.
        self.img_paths = sorted(
            path for path in self.img_dir.glob('*/*') if path.is_file()
        )
        # Build the default per instance instead of sharing one object created
        # at class-definition time.
        if transform is None:
            transform = transforms.Compose([transforms.ToTensor()])
        self.transform = transform
        self.target_transform = target_transform
        self.classes, self.class_to_idx = find_classes(self.img_dir)
        self.targets = [self.class_to_idx[path.parent.name] for path in self.img_paths]

    def __len__(self):
        """Return the number of image files found."""
        return len(self.img_paths)

    def __getitem__(self, index):
        """Return ``(transformed_image, target)`` for the sample at ``index``."""
        img_path = self.img_paths[index]
        # Image.open is lazy and keeps the file open. Load the pixel data inside
        # a context manager so the handle is released right away.
        with Image.open(img_path) as img:
            img.load()
            img_transformed = self.transform(img)
        target = self.targets[index]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img_transformed, target


In [None]:
class ImageFolder(Dataset):
    """
    Custom dataset for loading images laid out as ``targ_dir/<class_name>/<file>``.

    Returns ``(image, class_index)`` pairs, with the optional ``transform`` and
    ``target_transform`` applied when given.
    """

    def __init__(self,
                 targ_dir: str,
                 transform=None,
                 target_transform=None):
        """Create class attributes.

        Args:
            targ_dir: root directory whose immediate subdirectories are the
                classes and contain the image files.
            transform: optional callable applied to each loaded PIL image.
            target_transform: optional callable applied to each integer target.
        """
        # Images sit one level down, inside their class folder. Globbing '*'
        # would only see files in targ_dir itself, whose parent is not a class.
        # Sorted for a reproducible index -> sample mapping.
        self.paths = sorted(
            path for path in pathlib.Path(targ_dir).glob('*/*') if path.is_file()
        )
        self.classes, self.class_to_idx = find_classes(targ_dir)
        self.targets = [self.class_to_idx[path.parent.name] for path in self.paths]
        self.transform = transform
        self.target_transform = target_transform

    def load_image(self, index):
        """Load and return the PIL image at ``index`` with its file handle closed."""
        img_path = self.paths[index]
        # Image.open is lazy. Force the load inside the context manager so the
        # returned image keeps its pixel data while the file is closed.
        with Image.open(img_path) as img:
            img.load()
        return img

    def __len__(self):
        """Return the number of image files found."""
        return len(self.paths)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        """Return the (optionally transformed) image and target at ``index``."""
        img = self.load_image(index)
        # Targets were precomputed in __init__ from each path's class folder.
        target = self.targets[index]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target
    

    
