In [None]:
import os
import numpy as np
import rasterio
from rasterio.windows import Window
from rasterio.merge import merge
from osgeo import gdal
import subprocess
import tensorflow as tf
from tqdm import tqdm
import tempfile
from rasterio.enums import Resampling

2025-04-09 20:40:43.966916: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-09 20:40:44.363156: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-04-09 20:40:44.363207: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-04-09 20:40:44.438972: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-09 20:40:44.573125: I tensorflow/core/platform/cpu_feature_guar

In [2]:

def create_patches(input_tiff, output_dir, patch_size=224, overlap=125):
    """
    Splits a satellite image into patches with overlap and saves the patches.
    
    Args:
        input_tiff (str): Path to the input satellite image in TIFF format.
        output_dir (str): Directory where patches will be saved.
        patch_size (int): Size of each patch (default is 224x224 pixels).
        overlap (int): Number of pixels of overlap between patches.
    """
    # Create the output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    
    # Open the satellite image
    with rasterio.open(input_tiff) as src:
        width, height = src.width, src.height  # Get the image dimensions
        channels = src.count  # Get the number of bands (channels) in the image
        
        # Calculate the stride (the step size between patches)
        stride = patch_size - overlap
        
        # Initialize a counter to name the patches
        patch_count = 0
        
        # Loop through the image to extract patches
        for y in range(0, height - overlap, stride):
            for x in range(0, width - overlap, stride):
                # Check if the window stays within the image boundaries
                window = Window(x, y, patch_size, patch_size)
                if x + patch_size > width or y + patch_size > height:
                    continue
                
                # Read the data within the window (patch)
                patch = src.read(window=window)
                
                # Save the patch as a separate TIFF file
                patch_path = os.path.join(output_dir, f'patch_{patch_count}.tif')
                with rasterio.open(
                    patch_path,
                    'w',
                    driver='GTiff',
                    height=patch.shape[1],  # Height of the patch
                    width=patch.shape[2],   # Width of the patch
                    count=channels,         # Number of bands
                    dtype=patch.dtype,     # Data type of the patch
                    crs=src.crs,            # Coordinate Reference System
                    transform=rasterio.windows.transform(window, src.transform),  # Geospatial transform
                ) as dst:
                    dst.write(patch)  # Write the patch data
                
                patch_count += 1  # Increment patch counter
    
    # Print the number of patches created
    print(f"Splitting completed: {patch_count} patches saved in {output_dir}.")

# Example usage of the function
input_tiff = "/mnt/c/Users/PC/Bureau/Travail/Thèse_SSD/Unet_RS/Unet4RSImage/RTU/Image.tif"  # Path to your satellite image
output_dir = "/mnt/c/Users/PC/Bureau/Travail/Thèse_SSD/Unet_RS/Unet4RSImage/RTU/test2/directory"  # Directory to save the patches

# Call the function with a patch size of 224x224 and overlap of 32 pixels (you can change , if you increase the overlap the calculation time will be higher but with less artefacts)
create_patches(input_tiff, output_dir, patch_size=224, overlap=125)


Splitting completed: 9000 patches saved in /mnt/c/Users/PC/Bureau/Travail/Thèse_SSD/Unet_RS/Unet4RSImage/RTU/test2/directory.


In [3]:


# Image normalization function
def normalize_img(img):
    """Normalizes an image between 0 and 1."""
    return img / np.max(img)

# Function to load a test image with its georeferencing information
def load_test_image(image_path):
    """
    Loads a satellite image and normalizes it.
    
    Args:
        image_path (str): Path to the satellite image.
    
    Returns:
        img (ndarray): Normalized image.
        transform (Affine): Georeferencing transformation.
        crs (CRS): Coordinate reference system.
    """
    import tifffile as tiff
    img = tiff.imread(image_path).astype(np.float32)
    img = normalize_img(img)
    img = np.expand_dims(img, axis=0)  # Add a batch dimension
    
    # Read geospatial metadata
    with rasterio.open(image_path) as src:
        transform = src.transform
        crs = src.crs
        
    return img, transform, crs

# Function to predict all patches in a directory
def predict_patches(model_path, input_dir, output_dir):
    """
    Makes predictions for each patch in the input directory.
    
    Args:
        model_path (str): Path to the saved model (SavedModel).
        input_dir (str): Directory containing the input patches.
        output_dir (str): Directory to save the predictions.
    """
    # Load the model
    model = tf.keras.models.load_model(model_path)
    
    # Create the output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    
    # List input patch files
    patch_files = [f for f in os.listdir(input_dir) if f.endswith('.tif')]
    
    for patch_file in patch_files:
        patch_path = os.path.join(input_dir, patch_file)
        
        # Load the test image
        test_image, transform, crs = load_test_image(patch_path)
        
        # Make the prediction
        prediction = model.predict(test_image)
        predicted_mask = tf.argmax(prediction, axis=-1)
        predicted_mask = tf.squeeze(predicted_mask).numpy().astype(np.uint8)
        
        # Define the output path for the predicted mask
        predicted_mask_path = os.path.join(output_dir, f"pred_{patch_file}")
        
        # Save the predicted mask with geospatial information
        with rasterio.open(patch_path) as src:
            meta = src.meta  # Get the metadata
            meta.update(driver='GTiff', dtype=rasterio.uint8, count=1)  # Adjust for the mask
            
            # Save the predicted mask
            with rasterio.open(predicted_mask_path, 'w', **meta) as dst:
                dst.write(predicted_mask, 1)  # Write the mask as the first band
        
        print(f"Prediction saved: {predicted_mask_path}")

# Parameters
model_path = '/mnt/c/Users/PC/Bureau/Travail/Thèse_SSD/Unet_RS/Unet4RSImage/RTU/saved_model'  # Path to your SavedModel
input_dir = '/mnt/c/Users/PC/Bureau/Travail/Thèse_SSD/Unet_RS/Unet4RSImage/RTU/test2/directory'       # Directory containing the input patches
output_dir = '/mnt/c/Users/PC/Bureau/Travail/Thèse_SSD/Unet_RS/Unet4RSImage/RTU/test2/predictions/'  # Directory to save the predictions

# Run the prediction
predict_patches(model_path, input_dir, output_dir)


2025-04-09 20:45:58.205304: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:887] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2025-04-09 20:45:58.362566: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:887] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2025-04-09 20:45:58.362611: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:887] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2025-04-09 20:45:58.364852: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:887] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2025-04-09 20:45:58.364922: I external/local_xla/xla/stream_executor

Prediction saved: /mnt/c/Users/PC/Bureau/Travail/Thèse_SSD/Unet_RS/Unet4RSImage/RTU/test2/predictions/pred_patch_0.tif
Prediction saved: /mnt/c/Users/PC/Bureau/Travail/Thèse_SSD/Unet_RS/Unet4RSImage/RTU/test2/predictions/pred_patch_1.tif
Prediction saved: /mnt/c/Users/PC/Bureau/Travail/Thèse_SSD/Unet_RS/Unet4RSImage/RTU/test2/predictions/pred_patch_10.tif
Prediction saved: /mnt/c/Users/PC/Bureau/Travail/Thèse_SSD/Unet_RS/Unet4RSImage/RTU/test2/predictions/pred_patch_100.tif
Prediction saved: /mnt/c/Users/PC/Bureau/Travail/Thèse_SSD/Unet_RS/Unet4RSImage/RTU/test2/predictions/pred_patch_1000.tif
Prediction saved: /mnt/c/Users/PC/Bureau/Travail/Thèse_SSD/Unet_RS/Unet4RSImage/RTU/test2/predictions/pred_patch_1001.tif
Prediction saved: /mnt/c/Users/PC/Bureau/Travail/Thèse_SSD/Unet_RS/Unet4RSImage/RTU/test2/predictions/pred_patch_1002.tif
Prediction saved: /mnt/c/Users/PC/Bureau/Travail/Thèse_SSD/Unet_RS/Unet4RSImage/RTU/test2/predictions/pred_patch_1003.tif
Prediction saved: /mnt/c/Users/PC

In [None]:
import os
import numpy as np
import rasterio
from rasterio.merge import merge
from rasterio.windows import Window
from tqdm import tqdm
import tempfile
from rasterio.enums import Resampling

def reconstruct_from_patches(input_dir, output_path, overlap=200, buffer_size=50):
    """
    Reconstructs a large-scale image from GeoTIFF patches with overlap handling.
    
    Args:
        input_dir (str): Directory containing the GeoTIFF patches
        output_path (str): Output path for the reconstructed image
        overlap (int): Number of pixels of overlap between patches
        buffer_size (int): Area to ignore at the edges of each patch to avoid artifacts
    """
    # List all GeoTIFF files in the directory
    patch_files = [f for f in os.listdir(input_dir) if f.lower().endswith(('.tif', '.tiff'))]
    if not patch_files:
        raise ValueError("No GeoTIFF files found in the input directory")

    # Load metadata from the first patch
    first_patch = os.path.join(input_dir, patch_files[0])
    with rasterio.open(first_patch) as src:
        meta = src.meta.copy()
        dtype = src.dtypes[0]
        count = src.count

    # Create a temporary directory for processed patches
    temp_dir = tempfile.mkdtemp()
    processed_patches = []

    print("Processing individual patches...")
    for patch_file in tqdm(patch_files):
        patch_path = os.path.join(input_dir, patch_file)
        temp_path = os.path.join(temp_dir, f"processed_{patch_file}")
        
        with rasterio.open(patch_path) as src:
            # Read the data with a reduced window to avoid borders
            height, width = src.shape
            window = Window(
                buffer_size, 
                buffer_size, 
                width - 2 * buffer_size, 
                height - 2 * buffer_size
            )
            
            # Read the data within the window
            data = src.read(window=window)
            
            # Compute the new transform for this window
            transform = src.window_transform(window)
            
            # Update metadata
            meta.update({
                'height': window.height,
                'width': window.width,
                'transform': transform
            })
            
            # Write the processed patch
            with rasterio.open(temp_path, 'w', **meta) as dst:
                dst.write(data)
        
        processed_patches.append(temp_path)

    print("Merging processed patches...")
    # Merge all processed patches
    src_files_to_mosaic = [rasterio.open(patch) for patch in processed_patches]
    mosaic, out_trans = merge(
        src_files_to_mosaic,
        method='first',  # Takes the first non-null value (avoids overlaps)
        resampling=Resampling.nearest
    )

    # Close all opened files
    for src in src_files_to_mosaic:
        src.close()

    # Update final metadata
    meta.update({
        'driver': 'GTiff',
        'height': mosaic.shape[1],
        'width': mosaic.shape[2],
        'transform': out_trans,
        'compress': 'lzw',
        'nodata': 8,
        'dtype': dtype,
        'count': count
    })

    print("Writing the final image...")
    # Write the final image
    with rasterio.open(output_path, 'w', **meta) as dst:
        dst.write(mosaic)

    # Clean up temporary files
    for patch in processed_patches:
        try:
            os.remove(patch)
        except:
            pass
    try:
        os.rmdir(temp_dir)
    except:
        pass

    print(f"Reconstruction completed. Image saved at: {output_path}")


if __name__ == "__main__":
    # Example usage
    input_directory = "path/to/your/directory"
    output_file = "path/to/your/output.tif"
    
    reconstruct_from_patches(
        input_dir=input_directory,
        output_path=output_file,
        overlap=125,  # Matches your description
        buffer_size=40  # Area to ignore at the borders
    )
