In [None]:
import cv2
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import glob
import os

# Dataset splits available on disk. Note: `dataset_generator` takes a
# parameter with the same name, which shadows this list inside the function.
dataset_type = [
    "train",
    "val",
]

def _list_images(data_root, ds, dataset_type):
    """Return the sorted PNG paths belonging to split `dataset_type` of dataset `ds`."""
    if ds == "Y2B_23":
        # Y2B_23 keeps each split in its own subfolder.
        # NOTE(review): this path repeats the dataset name (Y2B_23/Y2B_23/<split>)
        # although the original comment spoke of subfolders "in images" --
        # confirm it matches the on-disk layout.
        image_dir = f"{data_root}/{ds}/{ds}/{dataset_type}"
        return sorted(glob.glob(f"{glob.escape(image_dir)}/*.png"))

    # Other datasets keep every split in one `images` folder; the split is
    # encoded as a "<split>_" filename prefix.
    image_dir = f"{data_root}/{ds}/images"
    return sorted(
        p for p in glob.glob(f"{glob.escape(image_dir)}/*.png")
        if os.path.basename(p).startswith(f"{dataset_type}_")
    )


def _find_root_mask(mask_root, ds, stem):
    """Return the root-mask path for image `stem`, or None if the image should be skipped.

    Raises:
        FileNotFoundError: a Y2B_23 image has no mask next to it in `mask_root`.
    """
    mask_name = stem + "_root_mask.tif"
    if ds == "Y2B_23":
        # Y2B_23 masks sit directly in `masks/`; a missing one is a data error.
        mask_path = os.path.join(mask_root, mask_name)
        if not os.path.exists(mask_path):
            raise FileNotFoundError(f"Mask not found: {mask_path}")
        return mask_path

    # Other datasets store masks one directory level down (masks/<subdir>/).
    # Escape both parts so glob metacharacters in names are matched literally,
    # and sort so the pick is deterministic when several subdirs match.
    candidates = sorted(
        glob.glob(os.path.join(glob.escape(mask_root), "*", glob.escape(mask_name)))
    )
    return candidates[0] if candidates else None


def _tile(img, root_mask, patch_size):
    """Yield (image_patch, mask_patch) pairs for every full, non-overlapping square patch.

    Partial patches at the right/bottom edges are dropped. Mask patches get a
    trailing channel axis so they have shape (patch_size, patch_size, 1).
    """
    height, width = img.shape[:2]
    for top in range(0, height - patch_size + 1, patch_size):
        for left in range(0, width - patch_size + 1, patch_size):
            rows = slice(top, top + patch_size)
            cols = slice(left, left + patch_size)
            # Array indexing is (row, column, channel), i.e. (y, x, c).
            yield img[rows, cols, :], root_mask[rows, cols, np.newaxis]


def dataset_generator(dataset_type, patch_size, scaling_factor,
                      data_root="../Task_2/datasets", roots=("Y2B_23", "Y2B_24")):
    """Cut images and their root masks into square patches for one dataset split.

    Args:
        dataset_type: split name, e.g. "train" or "val".
        patch_size: edge length in pixels of the square, non-overlapping patches.
        scaling_factor: resize factor applied to image and mask before tiling
            (1 means no resizing).
        data_root: directory containing one folder per dataset in `roots`.
        roots: dataset folder names to read. "Y2B_23" uses a per-split image
            subfolder and flat `masks/`; any other name uses `images/` with a
            "<split>_" filename prefix and `masks/<subdir>/`.

    Returns:
        (X, y) float32 arrays: X has shape (N, patch_size, patch_size, 3) with
        the image pixels (OpenCV BGR channel order) divided by 255; y has shape
        (N, patch_size, patch_size, 1) with the mask values divided by 255
        (assumes masks are stored as 0/255 -- confirm). If no patches are
        found, both arrays have shape (0,).

    Raises:
        ValueError: `patch_size` is not positive, or an image and its mask
            differ in size.
        FileNotFoundError: a Y2B_23 image has no matching mask file.
        OSError: OpenCV fails to read an image or mask.
    """
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")

    X = []
    y = []

    for ds in roots:
        mask_root = f"{data_root}/{ds}/masks"

        for image_path in _list_images(data_root, ds, dataset_type):
            stem, _ = os.path.splitext(os.path.basename(image_path))

            root_mask_path = _find_root_mask(mask_root, ds, stem)
            if root_mask_path is None:
                continue  # image has no annotation; skip it

            img = cv2.imread(image_path)
            root_mask = cv2.imread(root_mask_path, cv2.IMREAD_GRAYSCALE)
            # cv2.imread signals failure by returning None instead of raising.
            if img is None:
                raise OSError(f"Could not read image: {image_path}")
            if root_mask is None:
                raise OSError(f"Could not read mask: {root_mask_path}")

            if scaling_factor != 1:
                img = cv2.resize(img, None, fx=scaling_factor, fy=scaling_factor)
                # Nearest-neighbour keeps mask values crisp; bilinear would
                # blend labels into intermediate values along edges.
                root_mask = cv2.resize(root_mask, None, fx=scaling_factor, fy=scaling_factor,
                                       interpolation=cv2.INTER_NEAREST)

            if img.shape[:2] != root_mask.shape[:2]:
                raise ValueError(
                    f"Image {image_path} {img.shape[:2]} and mask {root_mask_path} "
                    f"{root_mask.shape[:2]} differ in size"
                )

            for img_patch, root_patch in _tile(img, root_mask, patch_size):
                X.append(img_patch)
                y.append(root_patch)

    X = np.array(X, dtype=np.float32)
    y = np.array(y, dtype=np.float32)

    # Normalize to [0, 1] in place to avoid an extra full-size copy.
    X /= 255
    y /= 255
    return X, y

In [None]:
# Patch edge length in pixels, and the resize factor applied before tiling
# (1 keeps images at their native resolution).
patch_size, scaling_factor = 128, 1

In [None]:
# Build image/mask patches for the training split.
X_train, y_train = dataset_generator(
    dataset_type='train',
    patch_size=patch_size,
    scaling_factor=scaling_factor,
)
# Shapes seen on a previous run: ((4680, 128, 128, 3), (4680, 128, 128, 1))
X_train.shape, y_train.shape

In [None]:
# Build image/mask patches for the validation split.
X_val, y_val = dataset_generator(
    dataset_type='val',
    patch_size=patch_size,
    scaling_factor=scaling_factor,
)
# Shapes seen on a previous run: ((520, 128, 128, 3), (520, 128, 128, 1))
X_val.shape, y_val.shape