In [None]:
import sys
# Add the sibling SkewNet directory to the module search path; presumably this is
# where the `utils` package imported below lives (TODO confirm). The entry is
# relative, so it is resolved against the kernel's current working directory.
sys.path.append("../SkewNet")

In [None]:
from utils import coco_utils
from pycocotools import mask as mask_utils
import glob
import os
import random
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

In [None]:
def apply_mask_and_crop(image, mask):
    """Attach ``mask`` to ``image`` as an alpha channel and crop to the mask's bounding box.

    Parameters
    ----------
    image : array-like
        Pixel data shaped ``(H, W, C)`` where ``C`` broadcasts into three colour
        channels (1 or 3), or a 2-D ``(H, W)`` grayscale array. Values are cast
        to ``uint8`` when copied into the output.
    mask : array-like
        2-D ``(H, W)`` array; every non-zero entry is treated as foreground.

    Returns
    -------
    numpy.ndarray
        ``uint8`` array of shape ``(h, w, 4)``: three colour channels plus an
        alpha channel that is 255 on foreground pixels and 0 elsewhere, cropped
        to the tightest rectangle that contains all foreground pixels.

    Raises
    ------
    ValueError
        If the image and mask sizes differ, the image is not 2-D or 3-D, the
        image channels cannot be broadcast into three colour channels, or the
        mask has no foreground pixel (so there is nothing to crop to).
    """
    image = np.asarray(image)
    mask = np.asarray(mask)

    if image.shape[:2] != mask.shape:
        raise ValueError("Image and mask dimensions do not match")
    if image.ndim == 2:
        # Grayscale input: add a channel axis so it broadcasts across R, G and B.
        image = image[:, :, np.newaxis]
    elif image.ndim != 3:
        raise ValueError("Image must be a 2-D grayscale or 3-D (H, W, C) array")

    # Binarise first: multiplying a uint8 mask that holds values > 1 by 255 would
    # wrap around and produce a wrong alpha value.
    foreground = mask != 0
    if not foreground.any():
        # Without this guard the bounding-box lookup fails with an opaque IndexError.
        raise ValueError("Mask is empty; there is no region to crop to")

    h, w = foreground.shape
    output = np.zeros((h, w, 4), dtype=np.uint8)  # 4 channels, the last one is alpha
    output[:, :, :3] = image
    output[:, :, 3] = np.where(foreground, 255, 0)

    # Indices of the rows/columns that contain at least one foreground pixel.
    rows = np.flatnonzero(foreground.any(axis=1))
    cols = np.flatnonzero(foreground.any(axis=0))
    rmin, rmax = rows[0], rows[-1]
    cmin, cmax = cols[0], cols[-1]

    return output[rmin:rmax + 1, cmin:cmax + 1, :]

In [None]:
def find_image_mask_pairs(images_root, masks_root, image_extensions=(".jpg",)):
    """Match every mask JSON file under ``masks_root`` with its image under ``images_root``.

    A mask at ``<masks_root>/<sub/dirs>/<name>.json`` is paired with
    ``<images_root>/<sub/dirs>/<name><ext>`` for the first ``ext`` in
    ``image_extensions`` that names an existing file. Masks without a matching
    image are skipped.

    Parameters
    ----------
    images_root : str
        Directory that holds the images.
    masks_root : str
        Directory that is searched recursively for ``*.json`` mask files.
    image_extensions : str or sequence of str, optional
        Image file extension(s), including the leading dot, tried in order.
        Defaults to ``(".jpg",)``.

    Returns
    -------
    list of tuple
        ``(image_path, mask_path)`` pairs, ordered by mask path so that repeated
        runs process files in the same order.
    """
    if isinstance(image_extensions, str):
        image_extensions = (image_extensions,)

    # glob returns entries in arbitrary filesystem order; sort for reproducibility.
    mask_paths = sorted(
        glob.glob(os.path.join(masks_root, "**", "*.json"), recursive=True)
    )

    pairs = []
    for mask_path in mask_paths:
        base, _ = os.path.splitext(os.path.relpath(mask_path, masks_root))
        for extension in image_extensions:
            image_path = os.path.join(images_root, base + extension)
            # isfile (not exists) so that a directory named like an image is ignored.
            if os.path.isfile(image_path):
                pairs.append((image_path, mask_path))
                break

    return pairs

In [None]:
def save_masked_image(image_path, mask_path, output_directory, source_root=None):
    """Mask one image, crop it to the mask, and save the result as a PNG.

    The image's directory layout relative to ``source_root`` is recreated under
    ``output_directory`` and the result is written there as ``<image name>.png``.

    Parameters
    ----------
    image_path : str
        Path of the image to read.
    mask_path : str
        Path of the mask file; it is loaded with ``coco_utils.load_rle_from_file``
        and decoded with ``pycocotools.mask.decode``.
    output_directory : str
        Root directory for the generated PNG files.
    source_root : str, optional
        Directory that ``image_path`` is made relative to when rebuilding the
        folder structure. When omitted, the notebook-level ``images_root``
        variable is used, which keeps existing three-argument calls working.

    Raises
    ------
    ValueError
        Propagated from ``apply_mask_and_crop`` when image and mask are
        incompatible.
    """
    if source_root is None:
        # Backward-compatible fallback to the notebook global the original relied on.
        source_root = images_root

    # Find the directory structure of the original image and recreate it under the output directory
    relative_path = os.path.relpath(image_path, source_root)
    final_output_directory = os.path.join(output_directory, os.path.dirname(relative_path))
    os.makedirs(final_output_directory, exist_ok=True)

    # Close the file handle promptly, and normalise to 3-channel RGB so that
    # grayscale or CMYK files do not break the RGB + alpha compositing step.
    with Image.open(image_path) as source_image:
        image = np.asarray(source_image.convert("RGB"))

    rle = coco_utils.load_rle_from_file(mask_path)
    mask = mask_utils.decode(rle)
    masked_image = apply_mask_and_crop(image, mask)

    name_without_extension, _ = os.path.splitext(os.path.basename(image_path))
    output_path = os.path.join(final_output_directory, f"{name_without_extension}.png")
    Image.fromarray(masked_image).save(output_path)

In [None]:
# Shared parent directory of the three locations configured below.
_dataset_root = "/scratch/gpfs/RUSTOW/deskewing_datasets/images/cudl_images"

# Input images, their mask JSON files, and the destination for the cropped PNGs.
images_root = _dataset_root + "/images"
masks_root = _dataset_root + "/document_masks"
output_directory = _dataset_root + "/segmented_images"

In [None]:
# Collect every (image, mask) pair that exists on disk, then write one masked,
# cropped PNG per pair under output_directory. An exception raised for any
# single pair stops the loop, since nothing here catches it.
pairs = find_image_mask_pairs(images_root, masks_root)
for image_path, mask_path in pairs:
    save_masked_image(image_path, mask_path, output_directory)