In [None]:
import os
import subprocess

import numpy as np
import rasterio
import tensorflow as tf
from osgeo import gdal
from rasterio.windows import Window

In [None]:

def _patch_offsets(size, patch_size, stride, cover_edges):
    """Return the start offsets of every full patch along one axis of length ``size``."""
    # Only offsets whose patch fits entirely inside the axis.
    offsets = list(range(0, size - patch_size + 1, stride))
    if cover_edges and offsets and offsets[-1] + patch_size < size:
        # Add one patch flush with the far edge so the last strip is covered too.
        offsets.append(size - patch_size)
    return offsets


def create_patches(input_tiff, output_dir, patch_size=224, overlap=0, cover_edges=False):
    """
    Split a raster into square, optionally overlapping patches saved as GeoTIFFs.

    Patches are written as ``patch_<n>.tif`` (n counting from 0, row by row)
    in ``output_dir``. Each keeps the source CRS and nodata value and gets a
    transform shifted to its own window, so the patches can later be
    mosaicked back together by georeference.

    Args:
        input_tiff (str): Path to the input satellite image in TIFF format.
        output_dir (str): Directory where patches will be saved (created if missing).
        patch_size (int): Side length of each square patch in pixels (default 224).
        overlap (int): Number of pixels shared by neighbouring patches; must be
            smaller than ``patch_size``.
        cover_edges (bool): If False (default), only patches on the regular
            stride grid that fit entirely inside the image are written, so a
            strip at the right/bottom edge may be left uncovered. If True, an
            extra column/row of patches aligned to the right/bottom edge is
            added so every pixel is covered (these overlap their neighbours
            by more than ``overlap``).

    Raises:
        ValueError: If ``patch_size`` is not positive or ``overlap >= patch_size``
            (the stride between patches would be zero or negative).
    """
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")
    if overlap >= patch_size:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than patch_size ({patch_size})"
        )

    # Step between the top-left corners of consecutive patches.
    stride = patch_size - overlap

    os.makedirs(output_dir, exist_ok=True)

    patch_count = 0
    with rasterio.open(input_tiff) as src:
        row_offsets = _patch_offsets(src.height, patch_size, stride, cover_edges)
        col_offsets = _patch_offsets(src.width, patch_size, stride, cover_edges)

        for y in row_offsets:
            for x in col_offsets:
                window = Window(x, y, patch_size, patch_size)
                patch = src.read(window=window)  # (bands, rows, cols)

                patch_path = os.path.join(output_dir, f'patch_{patch_count}.tif')
                with rasterio.open(
                    patch_path,
                    'w',
                    driver='GTiff',
                    height=patch.shape[1],
                    width=patch.shape[2],
                    count=src.count,
                    dtype=patch.dtype,
                    crs=src.crs,
                    transform=src.window_transform(window),  # georeference of this window
                    nodata=src.nodata,  # keep fill value so downstream steps can recognise it
                ) as dst:
                    dst.write(patch)

                patch_count += 1

    print(f"Splitting completed: {patch_count} patches saved in {output_dir}.")

# Example: cut the satellite image into 224x224-pixel patches (replace the placeholder paths).
input_tiff = "/path/to/your/image.tif"  # raster to split
output_dir = "/path/to/output/directory"  # destination folder for the patch GeoTIFFs

# 32 px of overlap between neighbouring patches. A larger overlap yields more
# patches (so prediction takes longer) but produces fewer artefacts at patch borders.
create_patches(
    input_tiff,
    output_dir,
    patch_size=224,
    overlap=32,
)


In [None]:


# Image normalization function
def normalize_img(img):
    """
    Scale an image by its global maximum.

    Every value in ``img`` (all bands together) is divided by the single
    largest value, so for non-negative input the result lies in [0, 1].
    An image whose maximum is 0 (e.g. an all-zero fill patch) is returned
    as zeros instead of the NaNs that 0/0 would produce.

    Args:
        img (ndarray): Image array of any shape.

    Returns:
        ndarray: Floating-point array with the same shape as ``img``.
    """
    img = np.asarray(img)
    peak = np.max(img)
    if peak == 0:
        return np.zeros(img.shape, dtype=np.result_type(img.dtype, np.float32))
    return img / peak

# Function to load a test image with its georeferencing information
def load_test_image(image_path):
    """
    Load a satellite image patch, normalize it, and add a leading batch axis.

    The raster is read once with rasterio, which returns bands first
    (bands, rows, cols). Bands are then moved to the last axis, so the layout
    does not depend on whether the TIFF stores bands pixel- or band-interleaved.
    A single-band image has no channel axis in the result.

    Args:
        image_path (str): Path to the satellite image.

    Returns:
        img (ndarray): float32 array scaled by ``normalize_img``, shaped
            (1, rows, cols, bands) for multi-band input or (1, rows, cols)
            for single-band input. Assumes the model takes channels-last
            input — TODO confirm against the trained model.
        transform (Affine): Georeferencing transformation.
        crs (CRS): Coordinate reference system.
    """
    with rasterio.open(image_path) as src:
        data = src.read()  # (bands, rows, cols)
        transform = src.transform
        crs = src.crs

    if data.shape[0] == 1:
        data = data[0]  # (rows, cols)
    else:
        data = np.moveaxis(data, 0, -1)  # (rows, cols, bands)

    img = normalize_img(data.astype(np.float32))
    img = np.expand_dims(img, axis=0)  # Add a batch dimension

    return img, transform, crs

# Function to predict all patches in a directory
def predict_patches(model_path, input_dir, output_dir):
    """
    Run the model on every .tif patch in ``input_dir`` and save each class map.

    For each patch, the per-pixel argmax over the model's last output axis is
    written as a single-band uint8 GeoTIFF named ``pred_<patch file name>`` in
    ``output_dir``. It reuses the patch's CRS, transform and size so the
    predictions can be mosaicked by georeference.

    Args:
        model_path (str): Path to the saved model (SavedModel).
        input_dir (str): Directory containing the input patches.
        output_dir (str): Directory to save the predictions (created if missing).
    """
    # NOTE(review): Keras 3 load_model only reads .keras/.h5 files; a TF
    # SavedModel directory may need a different loader — confirm for your TF version.
    model = tf.keras.models.load_model(model_path)

    os.makedirs(output_dir, exist_ok=True)

    # Sorted so patches are processed (and logged) in a reproducible order.
    patch_files = sorted(f for f in os.listdir(input_dir) if f.endswith('.tif'))

    for patch_file in patch_files:
        patch_path = os.path.join(input_dir, patch_file)

        test_image, _, _ = load_test_image(patch_path)

        # verbose=0 avoids one Keras progress bar per patch in the cell output.
        prediction = model.predict(test_image, verbose=0)
        # Assumes a (1, rows, cols, n_classes) score output — TODO confirm.
        # Take argmax over classes, then drop only the batch axis (squeeze
        # would also collapse any spatial axis of size 1).
        predicted_mask = np.argmax(prediction, axis=-1)[0].astype(np.uint8)

        predicted_mask_path = os.path.join(output_dir, f"pred_{patch_file}")

        # Reuse the patch's georeferencing and size for the mask.
        with rasterio.open(patch_path) as src:
            meta = src.meta.copy()
        # nodata=None: the patch's nodata value is meaningless for class ids,
        # and a value such as -9999 is outside the uint8 range, so rasterio rejects it on write.
        meta.update(driver='GTiff', dtype=rasterio.uint8, count=1, nodata=None)

        with rasterio.open(predicted_mask_path, 'w', **meta) as dst:
            dst.write(predicted_mask, 1)  # Write the mask as the first band

        print(f"Prediction saved: {predicted_mask_path}")

# Inputs for the prediction step (replace the placeholder paths).
model_path = '/path/to/your/saved_model/'  # trained model to load
input_dir = '/path/to/your/patches/'  # patch folder, normally the output_dir used for create_patches
output_dir = '/path/to/your/predictions/'  # where the pred_*.tif masks are written

predict_patches(model_path=model_path, input_dir=input_dir, output_dir=output_dir)


In [None]:
def merge_tiff_files(input_dir, output_path):
    """
    Merge all .tif files in a directory into a single .tif file using gdal_merge.py.

    The inputs are passed in sorted filename order. Where they overlap,
    gdal_merge copies later files over earlier ones, so sorting makes the
    result reproducible. ``-pct`` takes the pseudo-color table from the first
    input. The output file itself is never used as an input, so re-running
    with ``output_path`` inside ``input_dir`` does not merge the previous
    result back in. A failing gdal_merge run is reported, not raised.

    :param input_dir: Directory containing the .tif files to merge.
    :param output_path: Output path for the merged image.
    """
    target = os.path.realpath(output_path)

    tiff_files = sorted(
        path
        for path in (os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith('.tif'))
        if os.path.realpath(path) != target
    )

    if not tiff_files:
        print("No .tif files found in the specified directory.")
        return

    gdal_merge_cmd = [
        'gdal_merge.py',    # GDAL's mosaicking script (must be on PATH)
        '-o', output_path,  # output file
        '-pct',             # copy the pseudo-color table from the first input
    ]
    gdal_merge_cmd.extend(tiff_files)

    try:
        subprocess.run(gdal_merge_cmd, check=True)
        print(f"Merging completed. The merged image is saved at {output_path}")
    except subprocess.CalledProcessError as e:
        print(f"Error merging TIFF files: {e}")

# Mosaic the per-patch predictions into one raster (replace the placeholder paths).
input_dir = '/path/to/your/predictions/'  # folder holding the pred_*.tif masks
output_path = '/path/to/your/output/merged_image.tif'  # file the merged image is written to

merge_tiff_files(input_dir=input_dir, output_path=output_path)