In [32]:
from pathlib import Path
from sklearn.model_selection import train_test_split
import shutil
import glob
import os
from itertools import chain
from PIL import Image
from torch.utils.data import Dataset
import torchvision
from torchvision.transforms import v2

In [24]:
data_path = Path.cwd() / 'data'
test_path = Path(data_path / "test")
train_path = Path(data_path / "train")
val_path = Path(data_path / "val")

for path in [test_path, train_path, val_path]:
    files = glob.glob(path.absolute().as_posix() + '/*')
    for f in files:
        os.remove(f)

original_data_path = data_path / 'landscape-pictures'
all_filenames = list(original_data_path.glob('*'))
train_val_filenames, test_filenames = train_test_split(all_filenames, test_size=1000, random_state=123)
train_filenames, val_filenames = train_test_split(train_val_filenames, test_size=500, random_state=123)

subdirectories = {
    test_path: test_filenames,
    val_path: val_filenames,
    train_path: train_filenames
}

def fill_sub_dir(sub_dir, file_subset):
    for file in file_subset:
        file_path = sub_dir / file.name
        shutil.copyfile(file, file_path)

for sub_dir, file_subset in subdirectories.items():
    fill_sub_dir(sub_dir, file_subset)

In [28]:
test_path, val_path, train_path = [path.absolute().as_posix() for path in list(subdirectories.keys())]

In [30]:
class ImageData(Dataset):
    def __init__(self, root, transform):
        """Constructor

        Args:
            root (Path/str): Filepath to the data root, e.g. './data/train'
            transform (Compose): A composition of image transforms, see below.
        """

        root = Path(root)
        if not (root.exists() and root.is_dir()):
            raise ValueError(f"Data root '{root}' is invalid")

        self.root = root
        self.transform = transform
        self.greyscale_transform = torchvision.transforms.v2.Grayscale

        # Collect samples, both cat and dog and store pairs of (filepath, label) in a simple list.
        self._samples = self._collect_samples()

    def __getitem__(self, index):
        """Get sample by index

        Args:
            index (int)

        Returns:
             The index'th sample (Tensor, int)
        """
        # Access the stored path and label for the correct index
        path = self._samples[index]
        # Load the image into memory
        original_img = Image.open(path)
        greyscale_img = Image.open(path)
        # Perform transforms, if any.
        if self.transform is not None:
            original_img = self.transform(original_img)
            greyscale_img = self.greyscale_transform(greyscale_img)
        return original_img, greyscale_img

    def __len__(self):
        """Total number of samples"""
        return len(self._samples)

    def _collect_samples(self):
        """"
        Helper method for the constructor
        """

        paths = self._collect_imgs_sub_dir(self.root)
        # Sorting is not strictly necessary, but filesystem globbing (wildcard search) is not deterministic,
        # and consistency is nice when debugging.
        return sorted(paths)

    @staticmethod
    def _collect_imgs_sub_dir(sub_dir: Path):
        """Collect image paths in a directory

        Helper method for the constructor
        """
        if not sub_dir.exists():
            raise ValueError(
                f"Directory '{sub_dir}' does not exist. Are you sure you have the correct path?"
            )
        return sub_dir.glob("*.jpg")

In [35]:
image_size = 256

transform = v2.Compose([
    v2.Resize(image_size),
    v2.CenterCrop(image_size),
    v2.RandomHorizontalFlip(0.05)
])

train_dataset = ImageData(train_path, transform)

train_dataset[0]

(<PIL.Image.Image image mode=RGB size=256x256>, Grayscale())