In [1]:
import os
import numpy as np
from osgeo import gdal, gdalconst
from multiprocessing import Pool, cpu_count


def process_chunk_parallel(args):
    """Threshold one raster window and count in-range pixels per 10x10 block.

    Worker for ``process_large_raster``. It opens the raster itself so it can
    run in a separate process.

    Parameters
    ----------
    args : tuple
        ``(src_raster, x, y, x_end, y_end, min_threshold, max_threshold, nodata)``:
        the raster path, the pixel window ``[x, x_end) x [y, y_end)`` read from
        band 1, the inclusive value range to count, and the nodata value to
        ignore (``None`` means nothing is masked besides the range test).

    Returns
    -------
    tuple
        ``(x, y, chunk_resampled)``. ``chunk_resampled[i, j]`` is the number of
        pixels in the (i, j)-th 10x10 block of the window whose value lies in
        ``[min_threshold, max_threshold]`` and is not ``nodata``. Trailing
        rows/columns that do not fill a whole block are dropped.

    Raises
    ------
    FileNotFoundError
        If the raster cannot be opened.
    RuntimeError
        If GDAL fails to read the requested window.
    """
    src_raster, x, y, x_end, y_end, min_threshold, max_threshold, nodata = args
    block = 10  # input pixels per output pixel along each axis

    src_ds = gdal.Open(src_raster, gdalconst.GA_ReadOnly)
    if src_ds is None:
        raise FileNotFoundError(f"Unable to open input raster: {src_raster}")
    try:
        chunk_data = src_ds.GetRasterBand(1).ReadAsArray(x, y, x_end - x, y_end - y)
    finally:
        src_ds = None  # release the GDAL handle even if the read fails
    if chunk_data is None:
        raise RuntimeError(
            f"Failed to read window x={x}:{x_end}, y={y}:{y_end} from {src_raster}"
        )

    # Reclassify: True where the value is within the inclusive range.
    mask = (chunk_data >= min_threshold) & (chunk_data <= max_threshold)
    if nodata is not None:
        # A NaN nodata needs no extra step: NaN already fails both comparisons.
        mask &= chunk_data != nodata

    # Crop partial blocks at the right/bottom edge so the reshape below is exact;
    # without this, any window size that is not a multiple of 10 raises.
    rows = (mask.shape[0] // block) * block
    cols = (mask.shape[1] // block) * block
    mask = mask[:rows, :cols]

    # Count True pixels within each block x block tile.
    chunk_resampled = mask.reshape(rows // block, block, cols // block, block).sum(axis=(1, 3))

    return x, y, chunk_resampled


def _chunk_pixels(chunk_size, pixel_size, block):
    """Convert a chunk length in map units to a pixel count that is a positive multiple of *block*."""
    pixels = int(chunk_size // abs(pixel_size))
    return max(block, pixels - pixels % block)


def process_large_raster(input_raster, output_raster, min_threshold, max_threshold, chunk_size=10000, compression="DEFLATE"):
    """Count in-range pixels per 10x10 block of a large raster, in parallel.

    Band 1 of *input_raster* is split into chunks and processed by
    ``process_chunk_parallel`` in a process pool. Each output pixel of the Int32
    GeoTIFF *output_raster* holds the number of input pixels in the matching
    10x10 block whose value lies in ``[min_threshold, max_threshold]``. The
    output pixel size is 10x the input pixel size; trailing input rows/columns
    that do not fill a whole block are not represented, as in the output
    dimensions ``(xsize // 10, ysize // 10)``.

    Parameters
    ----------
    input_raster : str
        Path of the input raster.
    output_raster : str
        Path of the GeoTIFF to create (overwritten if it exists).
    min_threshold, max_threshold : number
        Inclusive value range that is counted.
    chunk_size : float, default 10000
        Chunk edge length in georeferenced units. It is converted to pixels with
        the input pixel size and rounded down to a multiple of 10 pixels, with a
        minimum of 10.
    compression : str, default "DEFLATE"
        Value of the GTiff ``COMPRESS`` creation option.

    Raises
    ------
    FileNotFoundError
        If the input raster cannot be opened.
    ValueError
        If the input is smaller than one 10x10 block.
    RuntimeError
        If the output raster cannot be created.
    """
    block = 10  # input pixels per output pixel along each axis

    src_ds = gdal.Open(input_raster, gdalconst.GA_ReadOnly)
    if src_ds is None:
        raise FileNotFoundError(f"Unable to open input raster: {input_raster}")
    # None when the source declares no nodata value; workers then mask nothing extra.
    nodata = src_ds.GetRasterBand(1).GetNoDataValue()
    xsize, ysize = src_ds.RasterXSize, src_ds.RasterYSize
    geotransform = src_ds.GetGeoTransform()
    projection = src_ds.GetProjection()
    src_ds = None  # each worker opens its own handle; don't keep one across the fork

    out_xsize, out_ysize = xsize // block, ysize // block
    if out_xsize == 0 or out_ysize == 0:
        raise ValueError(
            f"Input raster {input_raster} ({xsize}x{ysize}) is smaller than one {block}x{block} block"
        )
    # Only the area made of whole blocks is processed. Edge chunks then never
    # contain a partial block, which the worker's reshape could not handle.
    usable_x, usable_y = out_xsize * block, out_ysize * block

    # Chunk sizes are multiples of `block`, so every chunk origin maps to a whole
    # output pixel (x // block, y // block).
    chunk_x_pixels = _chunk_pixels(chunk_size, geotransform[1], block)
    chunk_y_pixels = _chunk_pixels(chunk_size, geotransform[5], block)

    tasks = [
        (input_raster, x, y,
         min(x + chunk_x_pixels, usable_x), min(y + chunk_y_pixels, usable_y),
         min_threshold, max_threshold, nodata)
        for y in range(0, usable_y, chunk_y_pixels)
        for x in range(0, usable_x, chunk_x_pixels)
    ]

    driver = gdal.GetDriverByName("GTiff")
    dst_ds = driver.Create(
        output_raster,
        out_xsize,
        out_ysize,
        1,
        gdalconst.GDT_Int32,
        options=["COMPRESS={}".format(compression)]
    )
    if dst_ds is None:
        raise RuntimeError(f"Unable to create output raster: {output_raster}")
    try:
        # One output pixel spans block x block input pixels. Scale every
        # geotransform term (rotation included) instead of assuming 1 m input.
        dst_ds.SetGeoTransform((
            geotransform[0], geotransform[1] * block, geotransform[2] * block,
            geotransform[3], geotransform[4] * block, geotransform[5] * block,
        ))
        dst_ds.SetProjection(projection)
        dst_band = dst_ds.GetRasterBand(1)
        # NOTE(review): 0 is also a legitimate count (no qualifying pixels), so
        # such blocks will read as nodata. Confirm this is intended.
        dst_band.SetNoDataValue(0)

        print(f"Using {cpu_count()} processors...")
        with Pool(cpu_count()) as pool:
            # Write each chunk as it arrives. This avoids holding every result
            # plus a full-size output array in memory.
            for x, y, counts in pool.imap_unordered(process_chunk_parallel, tasks):
                dst_band.WriteArray(counts.astype(np.int32, copy=False), x // block, y // block)
        dst_band.FlushCache()
    finally:
        # Dropping the references closes the dataset and flushes it to disk.
        dst_band = None
        dst_ds = None
    print(f"Processing complete: {output_raster}")


# Example usage. The __main__ guard is required by multiprocessing wherever
# workers are started with "spawn" (the default on Windows and macOS). Each
# worker re-imports the main module, and without the guard it would try to
# start the whole job again. In a notebook __name__ is "__main__", so this
# still runs there.
if __name__ == "__main__":
    input_raster = "7_max_AllHeights_NoMissingReprojected.tif"
    output_raster = "Above15m_10mRes_.tif"

    # Inclusive range of source values counted per 10x10 block.
    # NOTE(review): the output name suggests heights in centimetres
    # (1500 -> 15 m); confirm the units and the 6666 upper bound.
    min_threshold = 1500
    max_threshold = 6666

    process_large_raster(input_raster, output_raster, min_threshold, max_threshold)

Using 12 processors...




Processing complete: Above15m_10mRes_.tif
