In [1]:
import rasterio
from rasterio.windows import Window
import numpy as np
import cv2
import os
from tqdm import tqdm

In [41]:
def crop_and_save(image, x, y, size, save_path, file_name, is_binary=False, binary_threshold = 0.5):
    window = Window(x - size // 2, y - size // 2, size, size)
    cropped_image = image.read(window=window)
    if cropped_image.min() < 0:
        cropped_image = -cropped_image

    # Reshape and convert images
    if cropped_image.ndim == 3 and not is_binary:
        # For RGB or RGBA images
        if cropped_image.shape[0] == 3:  # RGB
            cropped_image = np.moveaxis(cropped_image, 0, -1)  # Rearrange bands to last dimension
        elif cropped_image.shape[0] == 4:  # RGBA
            # Convert RGBA to RGB by discarding the alpha channel
            cropped_image = np.moveaxis(cropped_image, 0, -1)[..., :3]
        else:
            print(cropped_image.shape[0])
            raise ValueError("Unexpected number of bands in image")
    else:
        # For single-band images (binary or grayscale)
        cropped_image = cropped_image.reshape(cropped_image.shape[1], cropped_image.shape[2])
        if is_binary:
            _, cropped_image = cv2.threshold(cropped_image, binary_threshold, 255, cv2.THRESH_BINARY)
        else:
            cropped_image = cv2.cvtColor(cropped_image, cv2.COLOR_GRAY2RGB)  # Convert grayscale to RGB

    cv2.imwrite(os.path.join(save_path, file_name), cropped_image)



In [42]:

def process_files(ground_truth_path, rgb_paths, stream_order_path, output_dir, crop_size=128):
    os.makedirs(output_dir, exist_ok=True)

    # Create subdirectories for different types of images
    gt_dir = os.path.join(output_dir, "ground_truth")
    stream_dir = os.path.join(output_dir, "stream_order")
    rgb_dir = os.path.join(output_dir, "rgb_images")
    os.makedirs(gt_dir, exist_ok=True)
    os.makedirs(stream_dir, exist_ok=True)
    os.makedirs(rgb_dir, exist_ok=True)

    with rasterio.open(ground_truth_path) as gt_src, rasterio.open(stream_order_path) as stream_src:
        gt_data = gt_src.read(1)
        y_indices, x_indices = np.where(gt_data > 0)

        rgb_srcs = [rasterio.open(path) for path in rgb_paths]

        for x, y in zip(x_indices, y_indices):
            crop_and_save(stream_src, x, y, crop_size, stream_dir, f'KS1_stream_{x}_{y}.png', is_binary=True, binary_threshold=1)
            crop_and_save(gt_src, x, y, crop_size, gt_dir, f'KS1_ground_truth_{x}_{y}.png', is_binary=True, binary_threshold=0.5)
            for i, rgb_src in enumerate(rgb_srcs):
                crop_and_save(rgb_src, x, y, crop_size, rgb_dir, f'KS1_rgb_{i}_{x}_{y}.png')

    for src in rgb_srcs:
        src.close()

In [38]:
output_dir = '../pos_test_1/'
ground_truth_path = '../raw_data/HUC_102701030402/GT/Res_KS_GT.tif'
rgb_paths = ['../raw_data/HUC_102701030402/RGB/Res_agg_11.tif', '../raw_data/HUC_102701030402/RGB/Res_agg_12.tif'
             , '../raw_data/HUC_102701030402/RGB/Res_agg_13.tif', '../raw_data/HUC_102701030402/RGB/Res_agg_14.tif',
            '../raw_data/HUC_102701030402/RGB/Res_agg_15.tif', '../raw_data/HUC_102701030402/RGB/Res_agg_16.tif']
stream_order_path = '../raw_data/HUC_102701030402/SO/SO_KS_Strahler.tif'
crop_path = '../raw_data/HUC_102701030402/CROP/R_aggregated_crop.tif'

In [10]:
process_files(ground_truth_path, rgb_paths, stream_order_path, output_dir, crop_size=128)

In [42]:
with rasterio.open(crop_path) as gt_src:
    gt_data = gt_src.read(1)



In [43]:
gt_data.max()

195.0

In [45]:
def save_tile(raster, window, output_path):
    tile = raster.read(window=window)
    transform = raster.window_transform(window)
    with rasterio.open(
        output_path,
        'w',
        driver='GTiff',
        height=window.height,
        width=window.width,
        count=raster.count,
        dtype=raster.dtypes[0],
        crs=raster.crs,
        transform=transform,
    ) as dst:
        dst.write(tile)

def crop_image(src, x, y, crop_size):
    window = rasterio.windows.Window(x - crop_size // 2, y - crop_size // 2, crop_size, crop_size)
    return window

def process_files_with_overlap(ground_truth_path, rgb_paths, stream_order_path, output_dir, crop_size=128, overlap_rate=0.5):
    os.makedirs(output_dir, exist_ok=True)

    gt_dir = os.path.join(output_dir, "ground_truth")
    stream_dir = os.path.join(output_dir, "stream_order")
    rgb_dir = os.path.join(output_dir, "rgb_images")
    os.makedirs(gt_dir, exist_ok=True)
    os.makedirs(stream_dir, exist_ok=True)
    os.makedirs(rgb_dir, exist_ok=True)

    with rasterio.open(ground_truth_path) as gt_src, rasterio.open(stream_order_path) as stream_src:
        gt_data = gt_src.read(1)
        if gt_data.min() < 0:
            gt_data = -gt_data
        y_indices, x_indices = np.where(gt_data > 0)

        rgb_srcs = [rasterio.open(path) for path in rgb_paths]
        cropped_regions = []  # List to keep track of cropped regions
        tile_number = 0
        
        overlap_th = crop_size * overlap_rate
        for x, y in tqdm(zip(x_indices, y_indices)):
            # Check for overlap with existing cropped regions
            overlap = False
            for (prev_x, prev_y) in cropped_regions:
                if abs(prev_x - x) < overlap_th and abs(prev_y - y) < overlap_th:
                    overlap = True
                    break

            if overlap:
                continue  # Skip cropping this region due to overlap

            cropped_regions.append((x, y))
            # Crop and save as usual
            window = crop_image(stream_src, x, y, crop_size)
            save_tile(stream_src, window, os.path.join(stream_dir, f'KS1_stream_tile_{tile_number}.tif'))
            
            window = crop_image(gt_src, x, y, crop_size)
            save_tile(gt_src, window, os.path.join(gt_dir, f'KS1_ground_truth_tile_{tile_number}.tif'))

            for i, rgb_src in enumerate(rgb_srcs):
                window = crop_image(rgb_src, x, y, crop_size)
                save_tile(rgb_src, window, os.path.join(rgb_dir, f'KS1_rgb_{i}_tile_{tile_number}.tif'))

            tile_number += 1

        for src in rgb_srcs:
            src.close()

In [None]:
process_files_with_overlap(ground_truth_path, 
                           rgb_paths, 
                           stream_order_path, 
                           output_dir, 
                           crop_size=128, 
                           overlap_rate=0.25)

2821it [02:24, 14.72it/s]

# Negative images

In [35]:

def save_tile(raster, window, output_path):
    tile = raster.read(window=window)
    transform = raster.window_transform(window)
    with rasterio.open(
        output_path,
        'w',
        driver='GTiff',
        height=window.height,
        width=window.width,
        count=raster.count,
        dtype=raster.dtypes[0],
        crs=raster.crs,
        transform=transform,
    ) as dst:
        dst.write(tile)

def crop_image(src, x, y, crop_size):
    window = rasterio.windows.Window(x - crop_size // 2, y - crop_size // 2, crop_size, crop_size)
    return window

def process_files_with_negative_check(ground_truth_path, rgb_paths, stream_order_path, output_dir, crop_size=128, overlap_rate=0.5, buffer_size=50):
    os.makedirs(output_dir, exist_ok=True)

    gt_dir = os.path.join(output_dir, "ground_truth")
    stream_dir = os.path.join(output_dir, "stream_order")
    rgb_dir = os.path.join(output_dir, "rgb_images")
    os.makedirs(gt_dir, exist_ok=True)
    os.makedirs(stream_dir, exist_ok=True)
    os.makedirs(rgb_dir, exist_ok=True)

    with rasterio.open(ground_truth_path) as gt_src, rasterio.open(stream_order_path) as stream_src:
        gt_data = gt_src.read(1)
        gt_data = -gt_data
        positive_points = np.argwhere(gt_data > 0)
        cropped_regions = []
        tile_number = 0  # Initialize tile number counter

        for px, py in tqdm(positive_points):
            for dx in range(-buffer_size, buffer_size + 1, crop_size):
                for dy in range(-buffer_size, buffer_size + 1, crop_size):
                    x, y = px + dx, py + dy
                    if not (0 <= x < gt_src.width and 0 <= y < gt_src.height):
                        continue  # Ensure within bounds

                    window = crop_image(gt_src, x, y, crop_size)
                    if any(np.sqrt((prev_x - window.col_off)**2 + (prev_y - window.row_off)**2) < overlap_rate * crop_size for prev_x, prev_y in cropped_regions):
                        continue  # Check for overlap

                    cropped_gt = gt_src.read(1, window=window)
                    if np.any(cropped_gt > 0):
                        continue  # Ensure no ground truth lines are included

                    cropped_regions.append((window.col_off, window.row_off))
                    save_tile(gt_src, window, os.path.join(gt_dir, f'negative_ground_truth_tile_{tile_number}.tif'))
                    save_tile(stream_src, window, os.path.join(stream_dir, f'stream_tile_{tile_number}.tif'))

                    for rgb_path in rgb_paths:
                        with rasterio.open(rgb_path) as rgb_src:
                            save_tile(rgb_src, window, os.path.join(rgb_dir, f'rgb_{os.path.basename(rgb_path).split(".")[0]}_tile_{tile_number}.tif'))

                    tile_number += 1  # Increment tile number after each successful save

        print(f"Total regions saved: {len(cropped_regions)}")

In [36]:
process_files_with_negative_check(ground_truth_path, 
                                  rgb_paths, 
                                  stream_order_path, 
                                  output_dir, 
                                  crop_size=128, 
                                  overlap_rate=0.5,  
                                  buffer_size=10)




100%|███████████████████████████████████████████████████████| 3119/3119 [01:36<00:00, 32.18it/s]

Total regions saved: 61





In [9]:
with rasterio.open(ground_truth_path) as gt_src:
    gt_data = gt_src.read(1)

In [18]:
gt_data = -gt_data

In [24]:
a = np.argwhere(gt_data>0)
a.shape

(3119, 2)