In [3]:
import os
import random
import cv2
import numpy as np

def ensure_divisible_by_256(img):
    """
    Zero-pad `img` along its bottom and right edges so that both its
    height and its width are multiples of 256.

    When both dimensions are already multiples of 256 the input object
    itself is returned; otherwise a new, padded image is returned.
    """
    height, width = img.shape[:2]

    # Distance from each dimension up to the next multiple of 256
    # (0 when the dimension is already aligned).
    extra_rows = -height % 256
    extra_cols = -width % 256

    # Nothing to add: hand back the original image untouched.
    if not (extra_rows or extra_cols):
        return img

    # A 2D array is treated as grayscale (scalar fill); anything else is
    # treated as a color image (per-channel fill).
    fill = 0 if img.ndim == 2 else (0, 0, 0)

    # Pad only at the bottom and the right so existing pixel coordinates
    # keep their positions.
    return cv2.copyMakeBorder(
        img, 0, extra_rows, 0, extra_cols,
        cv2.BORDER_CONSTANT, value=fill
    )

def create_patches_from_image_and_3masks(
    img_bgr,
    seed_mask,
    shoot_mask,
    root_mask,
    base_name,
    patch_size,
    out_img_dir,
    out_mask_dir,
):
    """
    Slice an image and its three masks into `patch_size` x `patch_size`
    tiles and write every tile to disk.

    Instead of saving 3 separate mask files, the three single-channel
    mask tiles are merged into a single 3-channel mask
    (channel 0 = seed, 1 = shoot, 2 = root) and saved as one .png.

    Each tile gets a unique name built from `base_name` and the tile's
    top-left pixel offset:
        image: <out_img_dir>/<base_name>_rRRRR_cCCCC.png
        mask:  <out_mask_dir>/<base_name>_rRRRR_cCCCC_mask.png

    Parameters:
        img_bgr:      image array; its first two dimensions define the grid.
        seed_mask, shoot_mask, root_mask:
                      single-channel mask arrays, sliced with the same
                      row/column ranges as the image.
        base_name:    filename stem shared by all tiles of this image.
        patch_size:   tile edge length in pixels.
        out_img_dir:  directory that receives the image tiles.
        out_mask_dir: directory that receives the merged mask tiles.

    Returns the number of patches for which both the image tile and the
    mask tile were written successfully.
    """
    h, w = img_bgr.shape[:2]
    patch_count = 0

    for row in range(0, h, patch_size):
        for col in range(0, w, patch_size):
            # The same window is cut out of the image and all three masks.
            rows = slice(row, row + patch_size)
            cols = slice(col, col + patch_size)

            img_patch = img_bgr[rows, cols]

            # Merge the 3 single-channel masks into one 3-channel image
            #   Channel 0 = seed
            #   Channel 1 = shoot
            #   Channel 2 = root
            mask_3ch = cv2.merge(
                [seed_mask[rows, cols], shoot_mask[rows, cols], root_mask[rows, cols]]
            )

            # The row/column offsets must be part of the name; without them
            # every patch of one image would overwrite the previous one.
            # For example: train_Alican_212231_im1_r0000_c0256.png
            patch_stem = f"{base_name}_r{row:04d}_c{col:04d}"
            img_path = os.path.join(out_img_dir, patch_stem + ".png")
            # The mask uses the same stem plus "_mask"
            # E.g. train_Alican_212231_im1_r0000_c0256_mask.png
            mask_path = os.path.join(out_mask_dir, patch_stem + "_mask.png")

            # cv2.imwrite reports failure through its return value.
            img_ok = cv2.imwrite(img_path, img_patch)
            mask_ok = cv2.imwrite(mask_path, mask_3ch)
            if not (img_ok and mask_ok):
                print(f"Warning: Could not write patch {patch_stem}. Skipping.")
                continue

            patch_count += 1

    return patch_count

def main():
    """
    Build a patch dataset from the images and masks under `base_dir`.

    Steps:
      1. Create the train/val output folders.
      2. Keep only the .png images for which all three mask files
         (seed, shoot, root) exist in the masks folder.
      3. Shuffle those images with a fixed seed and split them
         75% train / 25% val.
      4. For each image of each split: load the image and masks, pad them
         to multiples of 256, cut them into patches and save the patches.

    Progress and warnings are reported with print().
    """
    # 1. Setup
    random.seed(42)  # for reproducibility if desired

    base_dir = "/Users/celinewu/Documents/2024_dataset"
    images_dir = os.path.join(base_dir, "images")  # .png images
    masks_dir = os.path.join(base_dir, "masks")    # .jif masks

    output_dir = "/Users/celinewu/Documents/dataset_patches"
    os.makedirs(output_dir, exist_ok=True)

    # Create subfolders for train and val
    train_images_dir = os.path.join(output_dir, "train_images")
    train_masks_dir  = os.path.join(output_dir, "train_masks")
    val_images_dir   = os.path.join(output_dir, "val_images")
    val_masks_dir    = os.path.join(output_dir, "val_masks")
    for folder in (train_images_dir, train_masks_dir, val_images_dir, val_masks_dir):
        os.makedirs(folder, exist_ok=True)

    patch_size = 256

    # Filename suffixes of the three masks that must accompany every image,
    # in the order seed, shoot, root.
    # NOTE(review): confirm that ".jif" is the real mask extension (and not
    # e.g. ".tif"); if it does not match the files on disk, every image is
    # filtered out in step 2.
    mask_suffixes = ("_seed_mask.jif", "_shoot_mask.jif", "_root_mask.jif")

    def mask_paths_for(base_name):
        """Return the [seed, shoot, root] mask paths for an image stem."""
        return [os.path.join(masks_dir, base_name + suffix) for suffix in mask_suffixes]

    # 2. Identify valid images (those that have all 3 masks)
    #    The script looks for *_seed_mask.jif, *_shoot_mask.jif, *_root_mask.jif
    # os.listdir returns entries in arbitrary order, so sort first; otherwise
    # the seeded shuffle below would not give the same split on every run.
    all_png_files = sorted(
        f for f in os.listdir(images_dir) if f.lower().endswith('.png')
    )

    valid_image_paths = [
        png_file
        for png_file in all_png_files
        # e.g. base name "train_Alican_212231_im1"
        if all(os.path.exists(p) for p in mask_paths_for(os.path.splitext(png_file)[0]))
    ]

    # 3. Split into train (75%) and val (25%)
    random.shuffle(valid_image_paths)
    num_valid = len(valid_image_paths)
    split_idx = int(0.75 * num_valid)
    train_list = valid_image_paths[:split_idx]
    val_list   = valid_image_paths[split_idx:]

    print(f"Total valid images (with 3 masks): {num_valid}")
    print(f"Train set: {len(train_list)} images")
    print(f"Val set:   {len(val_list)} images")

    def process_image_set(image_list, out_img_dir, out_mask_dir, set_name="Train"):
        """
        Patch every image in `image_list` into the given output folders and
        print a one-line summary labelled with `set_name`.

        Images that cannot be read, or whose masks cannot all be read, are
        skipped with a warning and are not counted as processed.
        """
        total_patches = 0
        processed_images = 0

        for png_file in image_list:
            base_name = os.path.splitext(png_file)[0]
            img_path = os.path.join(images_dir, png_file)

            # Load the image (BGR)
            img_bgr = cv2.imread(img_path)
            if img_bgr is None:
                print(f"Warning: Could not read image: {png_file}")
                continue

            # Load corresponding 3 masks (grayscale)
            masks = [
                cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
                for mask_path in mask_paths_for(base_name)
            ]

            # If any are missing or unreadable, skip
            if any(mask is None for mask in masks):
                print(f"Warning: Could not read all 3 masks for {png_file}. Skipping.")
                continue

            # Pad them so that dimensions are multiples of 256
            # (the padding step is fixed at 256, which equals patch_size here)
            img_bgr = ensure_divisible_by_256(img_bgr)
            seed_mask, shoot_mask, root_mask = (
                ensure_divisible_by_256(mask) for mask in masks
            )

            # Create patches and save
            total_patches += create_patches_from_image_and_3masks(
                img_bgr,
                seed_mask,
                shoot_mask,
                root_mask,
                base_name,
                patch_size,
                out_img_dir,
                out_mask_dir
            )
            processed_images += 1

        # Report the images actually processed, not the ones merely listed.
        print(f"{set_name}: Processed {processed_images} images, created {total_patches} patches.")

    # 4. Process the train set
    process_image_set(train_list, train_images_dir, train_masks_dir, set_name="Train")
    # 5. Process the val set
    process_image_set(val_list,   val_images_dir,   val_masks_dir,   set_name="Val")

# Run the patch-generation pipeline only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()


Total valid images (with 3 masks): 0
Train set: 0 images
Val set:   0 images
Train: Processed 0 images, created 0 patches.
Val: Processed 0 images, created 0 patches.
