In [None]:
import os
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from tqdm import tqdm

# ------------------------------------------------------------
# Helper functions for extracting patches
# ------------------------------------------------------------
def extract_patch(image, mask, y_center, x_center, patch_size):
    """Extract a ``patch_size`` x ``patch_size`` window around ``(y_center, x_center)``.

    The window is centered on the given coordinates when it fits inside the
    image. When it would cross a border, it is shifted back inside the image
    (the same way on every side) instead of being cut off, so the patch stays
    filled with real pixels. Zero padding on the bottom/right is only added
    when the image itself is smaller than ``patch_size`` along an axis.

    Args:
        image: tensor sliced as ``image[:, y, x]`` (presumably (C, H, W) — confirm).
        mask: 3-D tensor ``(_, H, W)``; its H and W define the image bounds.
        y_center: row of the desired patch center.
        x_center: column of the desired patch center.
        patch_size: side length of the square patch.

    Returns:
        ``(img_patch, mask_patch)``, both with spatial size ``patch_size`` x ``patch_size``.
    """
    _, H, W = mask.shape
    half = patch_size // 2

    # Clamp the top-left corner so the window stays inside the image whenever the
    # image is large enough; max(0, ...) handles images smaller than the patch.
    y1 = min(max(0, y_center - half), max(0, H - patch_size))
    x1 = min(max(0, x_center - half), max(0, W - patch_size))
    y2 = min(H, y1 + patch_size)
    x2 = min(W, x1 + patch_size)

    img_patch = image[:, y1:y2, x1:x2]
    mask_patch = mask[:, y1:y2, x1:x2]

    # Only reachable when the image is smaller than the patch along an axis.
    pad_y = patch_size - img_patch.shape[1]
    pad_x = patch_size - img_patch.shape[2]
    if pad_y > 0 or pad_x > 0:
        img_patch = F.pad(img_patch, (0, pad_x, 0, pad_y))
        mask_patch = F.pad(mask_patch, (0, pad_x, 0, pad_y))

    return img_patch, mask_patch


def extract_object_patches(image, mask, patch_size=64):
    """Extract patches around the foreground object in ``mask``.

    The foreground is every pixel where the first mask channel is > 0. One patch
    is centered on the center of that region's bounding box. When the bounding
    box sticks out of that centered window on a side (top, bottom, left, right),
    one extra patch is added for that side, placed so the window's edge lines up
    with the object's extreme row/column there.

    Args:
        image: image tensor, passed through to ``extract_patch``.
        mask: mask tensor indexed as ``mask[0]`` for the foreground test.
        patch_size: side length of each square patch.

    Returns:
        List of ``(img_patch, mask_patch)`` tuples, the centered patch first;
        an empty list when the mask has no foreground pixels.
    """
    ys, xs = torch.where(mask[0] > 0)
    if ys.numel() == 0:
        return []  # no object

    y_min, y_max = ys.min().item(), ys.max().item()
    x_min, x_max = xs.min().item(), xs.max().item()
    y_center = (y_min + y_max) // 2
    x_center = (x_min + x_max) // 2
    half = patch_size // 2

    # Nominal window of the centered patch: [top, top + patch_size) x [left, left + patch_size).
    top = y_center - half
    left = x_center - half

    centers = [(y_center, x_center)]
    # A window starting at `start` is centered at `start + half`. The bottom/right
    # extra windows start at `max - patch_size + 1` so they include the extreme
    # row/column (centering at `max - half` drops it for even patch sizes).
    if y_min < top:
        centers.append((y_min + half, x_center))
    if y_max >= top + patch_size:
        centers.append((y_max - patch_size + 1 + half, x_center))
    if x_min < left:
        centers.append((y_center, x_min + half))
    if x_max >= left + patch_size:
        centers.append((y_center, x_max - patch_size + 1 + half))

    return [extract_patch(image, mask, y, x, patch_size) for y, x in centers]


# ------------------------------------------------------------
# Creating and saving patches to disk
# ------------------------------------------------------------
def _build_transforms(resolution):
    """Return ``(image_transform, mask_transform)``, resizing to ``resolution`` x ``resolution`` when given."""
    image_steps = []
    mask_steps = []
    if resolution is not None:
        image_steps.append(transforms.Resize((resolution, resolution)))
        # Nearest-neighbour keeps the mask binary. InterpolationMode replaces the
        # PIL integer constant (Image.NEAREST), which torchvision deprecates here.
        mask_steps.append(
            transforms.Resize(
                (resolution, resolution),
                interpolation=transforms.InterpolationMode.NEAREST,
            )
        )
    image_steps.append(transforms.ToTensor())
    mask_steps.extend([
        transforms.ToTensor(),
        # Binarize the [0, 1]-scaled mask.
        transforms.Lambda(lambda x: (x > 0.5).float()),
    ])
    return transforms.Compose(image_steps), transforms.Compose(mask_steps)


def _to_uint8(tensor):
    """Convert a float tensor with values in [0, 1] to a contiguous uint8 numpy array.

    Rounds to the nearest integer instead of truncating: after the /255 then *255
    round trip, float round-off can leave a value just below an integer, which
    ``astype("uint8")`` would floor to the next lower level.
    """
    return tensor.mul(255).round().clamp(0, 255).to(torch.uint8).contiguous().numpy()


def create_patches_dataset(dataset_dir, output_dir, patch_size=128, resolution=None,
                           splits=("train", "valid", "test")):
    """Generate a dataset of object-centered patches and save it to disk.

    Reads ``<dataset_dir>/images/<split>/*.png`` and the same-named mask in
    ``<dataset_dir>/labels/<split>/``. For every image whose mask has foreground,
    the patches returned by ``extract_object_patches`` are written to
    ``<output_dir>/images/<split>/`` and ``<output_dir>/labels/<split>/`` as
    ``<stem>_patch<i>.png``. Images with a missing mask are skipped with a warning;
    images with an empty mask are skipped silently.

    Args:
        dataset_dir: root of the source dataset.
        output_dir: root of the patch dataset (directories are created if missing).
        patch_size: side length of the square patches.
        resolution: if given, images and masks are first resized to
            ``resolution`` x ``resolution``.
        splits: split sub-directories to process.

    Raises:
        FileNotFoundError: if ``<dataset_dir>/images/<split>`` does not exist.
    """
    image_transform, mask_transform = _build_transforms(resolution)

    for split in splits:
        img_dir = os.path.join(dataset_dir, "images", split)
        mask_dir = os.path.join(dataset_dir, "labels", split)

        out_img_dir = os.path.join(output_dir, "images", split)
        out_mask_dir = os.path.join(output_dir, "labels", split)
        os.makedirs(out_img_dir, exist_ok=True)
        os.makedirs(out_mask_dir, exist_ok=True)

        image_files = sorted(f for f in os.listdir(img_dir) if f.endswith(".png"))
        print(f"\nProcessing {split} ({len(image_files)} images)...")

        for fname in tqdm(image_files, desc=split):
            mask_path = os.path.join(mask_dir, fname)
            if not os.path.exists(mask_path):
                print(f"[Warning] Mask not found for {fname}")
                continue

            # Context managers close the file handles that Image.open keeps open.
            with Image.open(os.path.join(img_dir, fname)) as img_file:
                img_t = image_transform(img_file.convert("RGB"))
            with Image.open(mask_path) as mask_file:
                mask_t = mask_transform(mask_file.convert("L"))

            patches = extract_object_patches(img_t, mask_t, patch_size)
            if not patches:
                continue  # ignore image without object

            base_name = os.path.splitext(fname)[0]
            for i, (p_img, p_mask) in enumerate(patches):
                patch_name = f"{base_name}_patch{i}.png"
                # (C, H, W) -> (H, W, C) for PIL; the single-channel mask drops its channel dim.
                Image.fromarray(_to_uint8(p_img.permute(1, 2, 0))).save(
                    os.path.join(out_img_dir, patch_name)
                )
                Image.fromarray(_to_uint8(p_mask[0])).save(
                    os.path.join(out_mask_dir, patch_name)
                )


## Creating patch datasets from the already-augmented dataset

In [None]:
# Root of the augmented 512 dataset. Set FUSEG_AUGMENTED_DATASET to override it on
# machines without this mount. normpath strips a trailing slash, which would otherwise
# make the "<root>_patchesN/" output directories land inside the source dataset.
augmented_dataset_512 = os.path.normpath(
    os.environ.get("FUSEG_AUGMENTED_DATASET", "/mnt/TUDAO/0Datasets/fuseg/augmented-v3-512")
)

In [None]:

if __name__ == "__main__":
    patch_size = 48  # side length (px) of each square patch
    # Output goes to a sibling directory named after the patch size.
    create_patches_dataset(
        augmented_dataset_512,
        f"{augmented_dataset_512}_patches{patch_size}/",
        patch_size=patch_size,
    )


In [None]:
if __name__ == "__main__":
    patch_size = 128  # side length (px) of each square patch
    # Output goes to a sibling directory named after the patch size.
    create_patches_dataset(
        augmented_dataset_512,
        f"{augmented_dataset_512}_patches{patch_size}/",
        patch_size=patch_size,
    )

In [None]:
if __name__ == "__main__":
    patch_size = 256  # side length (px) of each square patch
    # Output goes to a sibling directory named after the patch size.
    create_patches_dataset(
        augmented_dataset_512,
        f"{augmented_dataset_512}_patches{patch_size}/",
        patch_size=patch_size,
    )

In [None]:
if __name__ == "__main__":
    patch_size = 384  # side length (px) of each square patch
    # Output goes to a sibling directory named after the patch size.
    create_patches_dataset(
        augmented_dataset_512,
        f"{augmented_dataset_512}_patches{patch_size}/",
        patch_size=patch_size,
    )