In [1]:
import os
from contextlib import ExitStack

import numpy as np
import rasterio
import rasterio.windows
from rasterio.crs import CRS
from rasterio.io import MemoryFile
from rasterio.warp import calculate_default_transform, reproject, Resampling
from tqdm import tqdm

# Point PROJ at the data directory of the active conda environment. Outside a
# conda environment CONDA_PREFIX is not defined; in that case PROJ_LIB is left
# untouched instead of failing with a KeyError.
_conda_prefix = os.environ.get("CONDA_PREFIX")
if _conda_prefix:
    os.environ["PROJ_LIB"] = os.path.join(_conda_prefix, "share", "proj")
# CRS that open_raster_as_wgs() brings rasters into.
TARGET_CRS = "EPSG:4326"

def open_raster_as_wgs(path):
    """
    Open a raster and make sure the returned dataset is in TARGET_CRS.

    Parameters
    ----------
    path : str
        Path of the raster, passed to ``rasterio.open``.

    Returns
    -------
    tuple
        ``(dataset, memfile)``. When the raster is already in TARGET_CRS the
        on-disk dataset is returned and ``memfile`` is ``None``. Otherwise every
        band is reprojected (nearest-neighbour) into a ``MemoryFile``; the
        dataset opened from it and the ``MemoryFile`` itself are returned.
        The caller must close ``dataset`` and then ``memfile`` (if not ``None``).

    Raises
    ------
    ValueError
        If the raster has no CRS, so it can neither be compared nor reprojected.

    Notes
    -----
    The reprojected copy is written in full before it is returned.
    NOTE(review): for very large rasters this can exceed the available space;
    a lazily warped view (``rasterio.vrt.WarpedVRT``) may be worth evaluating.
    """
    target_crs = CRS.from_string(TARGET_CRS)
    src = rasterio.open(path)
    memfile = None
    try:
        if src.crs is None:
            raise ValueError(f"Raster has no CRS and cannot be reprojected: {path}")

        # Compare CRS objects rather than their string form: an equivalent CRS
        # does not necessarily serialise to exactly "EPSG:4326".
        if src.crs == target_crs:
            return src, None

        # Compute the transform and new dimensions
        transform, width, height = calculate_default_transform(
            src.crs, target_crs, src.width, src.height, *src.bounds)
        kwargs = src.meta.copy()
        kwargs.update({
            'crs': target_crs,
            'transform': transform,
            'width': width,
            'height': height
        })
        memfile = MemoryFile()
        with memfile.open(**kwargs) as dst:
            for i in range(1, src.count + 1):
                reproject(
                    source=rasterio.band(src, i),
                    destination=rasterio.band(dst, i),
                    src_transform=src.transform,
                    src_crs=src.crs,
                    dst_transform=transform,
                    dst_crs=target_crs,
                    resampling=Resampling.nearest)
        # Open the reprojected dataset from the memory file.
        ds = memfile.open()
    except BaseException:
        # Do not leak the source dataset or a half-written memory file when
        # opening or reprojecting fails (including an interrupted run).
        if memfile is not None:
            memfile.close()
        src.close()
        raise
    src.close()
    return ds, memfile
        
def save_tile(raster, window, output_path):
    """
    Read ``window`` from ``raster`` and write it to ``output_path`` as a GeoTIFF.

    Parameters
    ----------
    raster : open rasterio dataset
        Source dataset; all of its bands are read.
    window : rasterio.windows.Window
        Region to extract, in the pixel coordinates of ``raster``.
    output_path : str
        Destination file, created (or overwritten) with the GTiff driver.

    The output keeps the band count, the dtype of the first band and the CRS of
    ``raster``; its transform is the window transform of ``raster``.
    """
    tile = raster.read(window=window)
    # Size the output from the array that was actually read. Windows derived
    # from geographic bounds can have fractional width/height, and truncating
    # them with int() may disagree with the shape returned by read(); writing
    # would then no longer be a 1:1 pixel copy.
    tile_height, tile_width = tile.shape[-2:]
    transform = raster.window_transform(window)
    with rasterio.open(
        output_path,
        'w',
        driver='GTiff',
        height=tile_height,
        width=tile_width,
        count=raster.count,
        dtype=raster.dtypes[0],
        crs=raster.crs,
        transform=transform,
    ) as dst:
        dst.write(tile)

def crop_image(src, x, y, crop_size):
    """
    Build a square pixel window of side ``crop_size`` around pixel (x, y).

    The window starts ``crop_size // 2`` pixels to the left of ``x`` (column) and
    above ``y`` (row). ``src`` is accepted for interface compatibility but is not
    used; the window is not clipped to any raster extent.
    """
    half_size = crop_size // 2
    col_off = x - half_size
    row_off = y - half_size
    return rasterio.windows.Window(col_off, row_off, crop_size, crop_size)

def get_window_from_gt_window(gt_window, gt_transform, target_transform):
    """
    Translate a ground-truth pixel window into the pixel space of another raster.

    The window is first converted to geographic bounds with ``gt_transform``;
    those bounds are then converted back to a window using ``target_transform``
    (for example the transform of an RGB image with a different resolution).
    Both transforms are expected to be expressed in the same CRS.
    """
    # (left, bottom, right, top) of the ground-truth window in map coordinates.
    geo_bounds = rasterio.windows.bounds(gt_window, gt_transform)
    # Same footprint, expressed in the target raster's pixel coordinates.
    return rasterio.windows.from_bounds(*geo_bounds, transform=target_transform)

def process_psoitive_files_with_overlap(ground_truth_path, 
                                        rgb_paths, 
                                        stream_order_path, 
                                        output_dir, 
                                        crop_size=128, 
                                        overlap_rate=0.5,
                                        tile_number=0,
                                        max_tiles=None):
    """
    Cut tiles around the positive pixels of a ground-truth raster.

    All rasters are opened through ``open_raster_as_wgs``. Band 1 of the ground
    truth is read; if it contains negative values the whole array is negated.
    Every pixel that is then > 0 is visited in ``np.where`` order and becomes a
    tile centre unless an already accepted centre is closer than
    ``crop_size * overlap_rate`` pixels on both axes. For each accepted centre
    one tile is written per raster:

    * ``<output_dir>/ground_truth/ground_truth_tile_<n>.tif``
    * ``<output_dir>/dem/dem_tile_<n>.tif``
    * ``<output_dir>/rgb_images/rgb_<i>_tile_<n>.tif`` for the i-th RGB raster,
      using the geographic bounds of the ground-truth window.

    NOTE(review): the DEM is cropped with the ground-truth pixel coordinates,
    which presumes both rasters share the same grid — confirm.

    Parameters
    ----------
    ground_truth_path : str
        Ground-truth raster.
    rgb_paths : iterable of str
        RGB rasters (possibly at other resolutions).
    stream_order_path : str
        DEM / stream-order raster.
    output_dir : str
        Root output directory; sub-directories are created as needed.
    crop_size : int
        Side of the ground-truth window in pixels.
    overlap_rate : float
        Fraction of ``crop_size`` used as minimum spacing between centres.
    tile_number : int
        Number given to the first tile written.
    max_tiles : int or None
        Optional cap on the number of tiles written by this call (useful for a
        quick trial run). ``None`` processes every eligible pixel.

    Returns
    -------
    int
        The next unused tile number.
    """
    gt_dir = os.path.join(output_dir, "ground_truth")
    stream_dir = os.path.join(output_dir, "dem")
    rgb_dir = os.path.join(output_dir, "rgb_images")
    for directory in (output_dir, gt_dir, stream_dir, rgb_dir):
        os.makedirs(directory, exist_ok=True)

    # The ExitStack closes every dataset and memory file even when reading or
    # writing a tile raises.
    with ExitStack() as stack:

        def _open_managed(path):
            """Open a raster in WGS84 and register its handles for cleanup."""
            dataset, memfile = open_raster_as_wgs(path)
            # Callbacks run in reverse order, so the dataset is closed before
            # the memory file that backs it.
            if memfile is not None:
                stack.callback(memfile.close)
            stack.callback(dataset.close)
            return dataset

        # Open ground truth and DEM (stream order) rasters in WGS84
        gt_src = _open_managed(ground_truth_path)
        stream_src = _open_managed(stream_order_path)
        # Open each RGB source and ensure they are in WGS84.
        rgb_list = [_open_managed(path) for path in rgb_paths]

        gt_data = gt_src.read(1)
        if gt_data.min() < 0:
            gt_data = -gt_data
        # np.where returns (rows, cols); we swap to (x, y) for consistency
        y_indices, x_indices = np.where(gt_data > 0)

        cropped_regions = []  # Centres already turned into tiles
        overlap_th = crop_size * overlap_rate
        tiles_written = 0

        for x, y in tqdm(zip(x_indices, y_indices), total=len(x_indices)):
            if max_tiles is not None and tiles_written >= max_tiles:
                break

            # Skip pixels too close to an accepted centre (ground truth pixel coordinates)
            if any(abs(prev_x - x) < overlap_th and abs(prev_y - y) < overlap_th
                   for prev_x, prev_y in cropped_regions):
                continue
            cropped_regions.append((x, y))

            # Ground truth and DEM are cropped with the same pixel offsets.
            gt_window = crop_image(gt_src, x, y, crop_size)
            stream_window = crop_image(stream_src, x, y, crop_size)
            save_tile(stream_src, stream_window, os.path.join(stream_dir, f'dem_tile_{tile_number}.tif'))
            save_tile(gt_src, gt_window, os.path.join(gt_dir, f'ground_truth_tile_{tile_number}.tif'))

            # For each RGB image (which may have a different resolution), convert the
            # ground truth window to the RGB image's window using geospatial bounds.
            for i, rgb_src in enumerate(rgb_list):
                rgb_window = get_window_from_gt_window(gt_window, gt_src.transform, rgb_src.transform)
                # Check that the window has valid size
                if int(rgb_window.width) <= 0 or int(rgb_window.height) <= 0:
                    print(f"Skipping RGB tile for tile {tile_number} due to invalid window size.")
                    continue
                save_tile(rgb_src, rgb_window, os.path.join(rgb_dir, f'rgb_{i}_tile_{tile_number}.tif'))

            tile_number += 1
            tiles_written += 1

    return tile_number


# Correctly spelled alias; the original (misspelled) name is kept for existing callers.
process_positive_files_with_overlap = process_psoitive_files_with_overlap

def process_files_with_negative_check(ground_truth_path, 
                                      rgb_paths, 
                                      stream_order_path, 
                                      output_dir, 
                                      crop_size=128, 
                                      overlap_rate=0.5, 
                                      buffer_size=50,
                                      tile_number=0):
    """
    Cut "negative" tiles (no ground-truth pixels) near positive ground-truth pixels.

    Band 1 of the ground truth is read; if it contains negative values the whole
    array is negated. For every pixel that is then > 0, candidate centres are
    generated at offsets ``range(-buffer_size, buffer_size + 1, crop_size)`` on
    both axes. A candidate is kept when it lies inside the raster, its window
    origin is at least ``overlap_rate * crop_size`` pixels (Euclidean distance)
    away from every previously saved window origin, and its window contains no
    positive ground-truth pixel. Kept windows are written as:

    * ``<output_dir>/ground_truth/negative_ground_truth_tile_<n>.tif``
    * ``<output_dir>/dem/dem_tile_<n>.tif`` (same pixel window as the ground truth)
    * ``<output_dir>/rgb_images/rgb_<i>_tile_<n>.tif`` for the i-th RGB raster,
      using the geographic bounds of the ground-truth window.

    Rasters are opened as-is (no reprojection).
    NOTE(review): this presumes all rasters share one CRS and that the DEM shares
    the ground-truth grid — confirm.
    NOTE(review): a window only excludes the positive pixel it was generated from
    when one of the offsets exceeds ``crop_size // 2``; with the defaults
    (``buffer_size=50``, ``crop_size=128``) the only offset is -50, so every
    candidate contains its own positive pixel and is rejected.

    Parameters
    ----------
    ground_truth_path, stream_order_path : str
        Ground-truth and DEM / stream-order rasters.
    rgb_paths : iterable of str
        RGB rasters (possibly at other resolutions).
    output_dir : str
        Root output directory; sub-directories are created as needed.
    crop_size : int
        Side of the ground-truth window in pixels.
    overlap_rate : float
        Fraction of ``crop_size`` used as minimum distance between window origins.
    buffer_size : int
        Largest absolute offset, in pixels, of a candidate centre from a positive pixel.
    tile_number : int
        Number given to the first tile written.

    Returns
    -------
    int
        The next unused tile number.
    """
    gt_dir = os.path.join(output_dir, "ground_truth")
    stream_dir = os.path.join(output_dir, "dem")
    rgb_dir = os.path.join(output_dir, "rgb_images")
    for directory in (output_dir, gt_dir, stream_dir, rgb_dir):
        os.makedirs(directory, exist_ok=True)

    with ExitStack() as stack:
        gt_src = stack.enter_context(rasterio.open(ground_truth_path))
        stream_src = stack.enter_context(rasterio.open(stream_order_path))
        # Open every RGB raster once instead of once per saved tile.
        rgb_sources = [stack.enter_context(rasterio.open(rgb_path)) for rgb_path in rgb_paths]

        gt_data = gt_src.read(1)
        if gt_data.min() < 0:
            gt_data = -gt_data
        positive_points = np.argwhere(gt_data > 0)
        cropped_regions = []
        min_distance = overlap_rate * crop_size

        # np.argwhere yields (row, col) pairs, i.e. (y, x).
        for py, px in tqdm(positive_points):
            for dx in range(-buffer_size, buffer_size + 1, crop_size):
                for dy in range(-buffer_size, buffer_size + 1, crop_size):
                    x, y = px + dx, py + dy
                    if not (0 <= x < gt_src.width and 0 <= y < gt_src.height):
                        continue  # Ensure within bounds

                    # Compute the ground truth window
                    gt_window = crop_image(gt_src, x, y, crop_size)
                    # Check for overlap using the window's starting column/row
                    if any(np.hypot(prev_x - gt_window.col_off, prev_y - gt_window.row_off) < min_distance
                           for prev_x, prev_y in cropped_regions):
                        continue  # Skip overlapping regions

                    # Test the sign-normalised array (not the raw file values) so the
                    # check agrees with how positive pixels were detected above.
                    # Start indices are clamped at 0 because negative indices would
                    # wrap around; numpy clamps the stop indices itself.
                    row_start = max(int(gt_window.row_off), 0)
                    col_start = max(int(gt_window.col_off), 0)
                    row_stop = max(int(gt_window.row_off + gt_window.height), 0)
                    col_stop = max(int(gt_window.col_off + gt_window.width), 0)
                    if np.any(gt_data[row_start:row_stop, col_start:col_stop] > 0):
                        continue  # Ensure no ground truth lines are included

                    cropped_regions.append((gt_window.col_off, gt_window.row_off))
                    save_tile(gt_src, gt_window, os.path.join(gt_dir, f'negative_ground_truth_tile_{tile_number}.tif'))
                    save_tile(stream_src, gt_window, os.path.join(stream_dir, f'dem_tile_{tile_number}.tif'))

                    # For each RGB image, convert the ground truth window to the rgb window.
                    for i, rgb_src in enumerate(rgb_sources):
                        rgb_window = get_window_from_gt_window(gt_window, gt_src.transform, rgb_src.transform)
                        # Same guard as the positive-tile routine: skip degenerate windows.
                        if int(rgb_window.width) <= 0 or int(rgb_window.height) <= 0:
                            print(f"Skipping RGB tile for tile {tile_number} due to invalid window size.")
                            continue
                        save_tile(rgb_src, rgb_window, os.path.join(rgb_dir, f'rgb_{i}_tile_{tile_number}.tif'))

                    tile_number += 1

        print(f"Total regions saved: {len(cropped_regions)}")

    return tile_number


In [2]:
# Inputs for HUC_071100060307: ground truth, merged RGB tiles, one extra
# high-resolution image, and the merged DEM.
GT_path = '/home/macula/SMATousi/Gullies/ground_truth/organized_data/MO_Downloaded_HUCs/HUC_071100060307-done/data/gt/rasterized_gt.tif'

data_path = '/home/macula/SMATousi/Gullies/ground_truth/organized_data/MO_Downloaded_HUCs/HUC_071100060307-done/data/'
rgb_paths = [
    os.path.join(data_path, relative_name)
    for relative_name in (
        'merged/tile_10__merged.tif',
        'merged/tile_12__merged.tif',
        'merged/tile_14__merged.tif',
        'merged/tile_16__merged.tif',
        'merged/tile_18__merged.tif',
        'merged/tile_20__merged.tif',
    )
] + ['/home1/choroid/SMATousi/High_Resolution_Tiles/Monroe.tif']

dem_path = os.path.join(data_path, 'merged/dem_tile__merged.tif')

pos_output_dir = '/home1/choroid/SMATousi/High_Resolution_Tiles/Tiled_test/'

starting_pos_tile_number = 0

# Write the positive tiles; the return value is the next free tile number.
last_pos_tile_number = process_psoitive_files_with_overlap(
    ground_truth_path=GT_path,
    rgb_paths=rgb_paths,
    stream_order_path=dem_path,
    output_dir=pos_output_dir,
    crop_size=128,
    overlap_rate=0.25,
    tile_number=starting_pos_tile_number,
)

RasterioIOError: 2ba76a47-8b09-4534-9f4f-ac921f464cab.tif: Free disk space available is 33323130880 bytes, whereas 51392623068 are at least necessary. You can disable this check by defining the CHECK_DISK_FREE_SPACE configuration option to FALSE.

In [None]:
# Scratch cell: evaluating the bare name displays the function object that was
# imported from rasterio.warp in the first cell; it has no other effect.
calculate_default_transform

In [None]:
# Diagnostic cell: show the data directory that pyproj reports, e.g. to compare
# it with the PROJ_LIB value set in the first cell.
import pyproj

pyproj.datadir.get_data_dir()