In [1]:
import os
from contextlib import ExitStack

import cv2
import numpy as np
import rasterio
from rasterio.windows import Window

In [25]:
def crop_and_save(image, x, y, size, save_path, file_name, is_binary=False, binary_threshold = 0.5,
                  boundless=True, fill_value=0):
    """Cut a ``size`` x ``size`` window centred on pixel (x, y) of ``image`` and write it to disk.

    Parameters
    ----------
    image : rasterio dataset
        Open dataset; all of its bands are read for the window.
    x, y : int
        Column and row (pixel coordinates) of the window centre.
    size : int
        Side length of the square window, in pixels.
    save_path : str
        Directory the file is written into (must already exist).
    file_name : str
        Output file name; cv2 chooses the encoder from its extension.
    is_binary : bool, default False
        Save the (single-band) crop as a 0/255 mask instead of as an image.
    binary_threshold : float, default 0.5
        With ``is_binary``, pixels strictly greater than this become 255, all others 0
        (``cv2.THRESH_BINARY``).
    boundless : bool, default True
        Pad windows that extend past the raster edge with ``fill_value`` so every tile is
        exactly ``size`` x ``size``. With ``False`` the window is not padded and tiles near
        the edge can come out smaller.
    fill_value : number, default 0
        Padding value used for out-of-raster pixels when ``boundless`` is True.

    Raises
    ------
    ValueError
        If the band count cannot be mapped to a mask, a grayscale image or an RGB image.
    IOError
        If ``cv2.imwrite`` reports that the file could not be written.
    """
    window = Window(x - size // 2, y - size // 2, size, size)
    bands = image.read(window=window, boundless=boundless, fill_value=fill_value)
    # NOTE(review): the whole crop is negated as soon as any value is negative, presumably
    # because some rasters are stored sign-flipped. Confirm this is never triggered by a
    # negative nodata value inside the window.
    if bands.size and bands.min() < 0:
        bands = -bands

    out_file = os.path.join(save_path, file_name)
    # cv2.imwrite signals failure through its return value instead of raising.
    if not cv2.imwrite(out_file, _bands_to_cv2_image(bands, is_binary, binary_threshold)):
        raise IOError(f"cv2.imwrite could not write {out_file}")


def _bands_to_cv2_image(bands, is_binary, binary_threshold):
    """Turn a band-first ``(bands, rows, cols)`` array into an array ``cv2.imwrite`` can store."""
    n_bands = bands.shape[0]

    if is_binary:
        if n_bands != 1:
            raise ValueError(f"Binary output expects a single band, got {n_bands}")
        _, mask = cv2.threshold(bands[0], binary_threshold, 255, cv2.THRESH_BINARY)
        return mask

    if n_bands == 1:
        # Grayscale: replicate the single band into three identical channels.
        return np.repeat(bands[0][..., np.newaxis], 3, axis=-1)

    if n_bands in (3, 4):
        # Assumes bands 1-3 are R, G, B (band 4, if present, is dropped as alpha).
        # rasterio is band-first and cv2 is channel-last *BGR*, so move the band axis to
        # the end and reverse the channel order; cv2 needs a contiguous array.
        rgb = np.moveaxis(bands[:3], 0, -1)
        return np.ascontiguousarray(rgb[..., ::-1])

    raise ValueError(f"Unexpected number of bands in image: {n_bands}")



In [3]:

def process_files(ground_truth_path, rgb_paths, stream_order_path, output_dir, crop_size=128, prefix='KS1'):
    """Save a crop centred on every positive ground-truth pixel, from every input raster.

    For each pixel whose ground-truth value (band 1) is > 0, a ``crop_size`` x ``crop_size``
    window centred on it is cut via ``crop_and_save`` from the stream-order raster, the
    ground-truth raster and each RGB raster. The results are written as PNG into the
    ``stream_order``, ``ground_truth`` and ``rgb_images`` sub-directories of ``output_dir``.

    Parameters
    ----------
    ground_truth_path : str
        Ground-truth raster. Nodata pixels count as 0. If the band holds any negative
        value, it is negated before the ``> 0`` test.
    rgb_paths : list of str
        RGB rasters; the i-th one is saved as ``{prefix}_rgb_{i}_{x}_{y}.png``.
    stream_order_path : str
        Stream-order raster, saved as a binary mask (threshold 1).
    output_dir : str
        Root output directory; it and its sub-directories are created if missing.
    crop_size : int, default 128
        Side length of each square crop, in pixels.
    prefix : str, default 'KS1'
        File-name prefix for every saved tile.
    """
    gt_dir = os.path.join(output_dir, "ground_truth")
    stream_dir = os.path.join(output_dir, "stream_order")
    rgb_dir = os.path.join(output_dir, "rgb_images")
    for directory in (gt_dir, stream_dir, rgb_dir):
        os.makedirs(directory, exist_ok=True)  # also creates output_dir itself

    # ExitStack closes every dataset -- including the RGB ones -- even if an open or a
    # crop raises part-way through.
    with ExitStack() as stack:
        gt_src = stack.enter_context(rasterio.open(ground_truth_path))
        stream_src = stack.enter_context(rasterio.open(stream_order_path))
        rgb_srcs = [stack.enter_context(rasterio.open(path)) for path in rgb_paths]

        # Treat nodata as background; negate sign-flipped rasters the same way
        # process_files_with_overlap does so both functions pick the same centres.
        gt_data = gt_src.read(1, masked=True).filled(0)
        if gt_data.min() < 0:
            gt_data = -gt_data
        y_indices, x_indices = np.where(gt_data > 0)

        for x, y in zip(x_indices, y_indices):
            crop_and_save(stream_src, x, y, crop_size, stream_dir, f'{prefix}_stream_{x}_{y}.png', is_binary=True, binary_threshold=1)
            crop_and_save(gt_src, x, y, crop_size, gt_dir, f'{prefix}_ground_truth_{x}_{y}.png', is_binary=True, binary_threshold=0.5)
            for i, rgb_src in enumerate(rgb_srcs):
                crop_and_save(rgb_src, x, y, crop_size, rgb_dir, f'{prefix}_rgb_{i}_{x}_{y}.png')

In [9]:
# All input rasters for HUC 102701030402 live under a single directory.
huc_dir = '../raw_data/HUC_102701030402'
output_dir = '../Res_128x128_50p/'
ground_truth_path = f'{huc_dir}/GT/Res_KS_GT.tif'
stream_order_path = f'{huc_dir}/SO/SO_KS_Strahler.tif'
crop_path = f'{huc_dir}/CROP/R_aggregated_crop.tif'
# The six aggregated RGB rasters Res_agg_11.tif ... Res_agg_16.tif, in that order.
rgb_paths = [f'{huc_dir}/RGB/Res_agg_{idx}.tif' for idx in range(11, 17)]

In [10]:
# Dense extraction: one tile per positive ground-truth pixel. It gets its own directory
# because `output_dir` ('..._50p') is the target of the 50%-overlap run further down, and
# both runs use identical file names -- sharing a directory would leave every dense tile
# mixed in with (and indistinguishable from) the overlap-filtered ones.
dense_output_dir = '../Res_128x128_dense/'
process_files(ground_truth_path, rgb_paths, stream_order_path, dense_output_dir, crop_size=128)

In [42]:
# Load band 1 of the aggregated crop raster for inspection. Despite its name, `gt_data`
# here holds the crop raster (crop_path), not the ground truth.
with rasterio.open(crop_path) as crop_src:
    gt_data = crop_src.read(1)



In [43]:
# Largest value in the band loaded into `gt_data` by the previous cell.
np.max(gt_data)

195.0

In [26]:
def process_files_with_overlap(ground_truth_path, rgb_paths, stream_order_path, output_dir, crop_size=128, overlap_rate = 0.5, prefix='KS1'):
    """Save crops around positive ground-truth pixels, skipping centres too close to a kept one.

    Candidate centres are the pixels whose ground-truth value (band 1) is > 0. Nodata
    pixels count as 0, and the band is negated first if it contains any negative value.
    Candidates are visited in row-major order. One is kept only if, against every centre
    kept so far, ``|dx| >= crop_size * overlap_rate`` or ``|dy| >= crop_size * overlap_rate``.
    For each kept centre a ``crop_size`` x ``crop_size`` tile is cut via ``crop_and_save``
    from the stream-order, ground-truth and every RGB raster. The tiles are written as PNG
    into the ``stream_order``, ``ground_truth`` and ``rgb_images`` sub-directories of
    ``output_dir``.

    Note: ``overlap_rate`` is the minimum centre spacing as a fraction of ``crop_size``, so
    two kept tiles can share up to ``1 - overlap_rate`` of a tile's area. For the default
    0.5 the two readings coincide; for other values, a larger ``overlap_rate`` means
    *less* overlap.

    Parameters
    ----------
    ground_truth_path, stream_order_path : str
        Ground-truth and stream-order rasters (the latter saved as a mask, threshold 1).
    rgb_paths : list of str
        RGB rasters; the i-th one is saved as ``{prefix}_rgb_{i}_{x}_{y}.png``.
    output_dir : str
        Root output directory; it and its sub-directories are created if missing.
    crop_size : int, default 128
        Side length of each square crop, in pixels.
    overlap_rate : float, default 0.5
        Minimum centre spacing as a fraction of ``crop_size`` (see note above).
    prefix : str, default 'KS1'
        File-name prefix for every saved tile.
    """
    gt_dir = os.path.join(output_dir, "ground_truth")
    stream_dir = os.path.join(output_dir, "stream_order")
    rgb_dir = os.path.join(output_dir, "rgb_images")
    for directory in (gt_dir, stream_dir, rgb_dir):
        os.makedirs(directory, exist_ok=True)  # also creates output_dir itself

    # ExitStack closes every dataset -- including the RGB ones -- even if an open or a
    # crop raises part-way through.
    with ExitStack() as stack:
        gt_src = stack.enter_context(rasterio.open(ground_truth_path))
        stream_src = stack.enter_context(rasterio.open(stream_order_path))
        rgb_srcs = [stack.enter_context(rasterio.open(path)) for path in rgb_paths]

        # Nodata is treated as background so it can never become a crop centre.
        gt_data = gt_src.read(1, masked=True).filled(0)
        # NOTE(review): presumably the GT is sometimes stored sign-flipped; confirm.
        if gt_data.min() < 0:
            gt_data = -gt_data
        y_indices, x_indices = np.where(gt_data > 0)

        centers = _select_spaced_centers(x_indices, y_indices, crop_size * overlap_rate)
        for x, y in centers:
            crop_and_save(stream_src, x, y, crop_size, stream_dir, f'{prefix}_stream_{x}_{y}.png', is_binary=True, binary_threshold=1)
            crop_and_save(gt_src, x, y, crop_size, gt_dir, f'{prefix}_ground_truth_{x}_{y}.png', is_binary=True, binary_threshold=0.5)
            for i, rgb_src in enumerate(rgb_srcs):
                crop_and_save(rgb_src, x, y, crop_size, rgb_dir, f'{prefix}_rgb_{i}_{x}_{y}.png')


def _select_spaced_centers(x_indices, y_indices, min_spacing):
    """Greedily keep (x, y) candidates in input order, dropping any too close to a kept one.

    A candidate conflicts with a kept centre when ``|dx| < min_spacing`` and
    ``|dy| < min_spacing``. Kept centres are bucketed on a square grid of side
    ``min_spacing``, so a conflicting centre can only sit in the candidate's own grid cell
    or one of its 8 neighbours. This gives the same result as comparing against every kept
    centre, without the quadratic cost.
    """
    if min_spacing <= 0:
        # No distance is < a non-positive spacing, so nothing can conflict.
        return list(zip(x_indices, y_indices))

    buckets = {}  # (cell_x, cell_y) -> kept centres in that grid cell
    kept = []
    for x, y in zip(x_indices, y_indices):
        cell_x, cell_y = int(x // min_spacing), int(y // min_spacing)
        nearby = (
            center
            for dx in (-1, 0, 1)
            for dy in (-1, 0, 1)
            for center in buckets.get((cell_x + dx, cell_y + dy), ())
        )
        if any(abs(kx - x) < min_spacing and abs(ky - y) < min_spacing for kx, ky in nearby):
            continue
        kept.append((x, y))
        buckets.setdefault((cell_x, cell_y), []).append((x, y))
    return kept

In [27]:
# 128 px tiles whose centres are at least 64 px (128 * 0.5) apart in x or in y.
process_files_with_overlap(
    ground_truth_path,
    rgb_paths,
    stream_order_path,
    output_dir=output_dir,
    crop_size=128,
    overlap_rate=0.5,
)

check1
