In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
from collections import defaultdict
from osgeo import gdal
import numpy as np
import shutil

# Import the local apl_tools module
# Assuming apl_tools.py is in the same directory
import apl_tools

# Per-AOI rescale bounds used when converting raw imagery to 8-bit
# (keeping original values).
#   Outer key: SpaceNet AOI number as a string. main() selects rescale['3']
#     (AOI_3_Paris) and passes it to apl_tools.convert_to_8Bit as rescale_type.
#   Inner key: 1-based GDAL band index (1-3, the same bands calc_rescale reads).
#   Value: a two-element [low, high] pair for that band.
# NOTE(review): presumably these are per-band means of 2nd/98th percentile
# bounds like those calc_rescale collects -- confirm provenance before editing.
rescale = {
    '2': {
        1: [25.48938322, 1468.79676441],
        2: [145.74823054, 1804.91911021],
        3: [155.47927199, 1239.49848332]
    },
    '4': {
        1: [79.29799666, 978.35058431],
        2: [196.66026711, 1143.74207012],
        3: [170.72954925, 822.32387312]
    },
    '3': {
        1: [46.26129032, 1088.43225806],
        2: [127.54516129, 1002.20322581],
        3: [141.64516129, 681.90967742]
    },
    '5': {
        1: [101.63250883, 1178.05300353],
        2: [165.38869258, 1190.5229682 ],
        3: [126.5335689, 776.70671378]
    }
}

def calc_rescale(im_file_raw, m, percentiles, bands=(1, 2, 3)):
    """Append per-band percentile bounds of a raster to an accumulator.

    Opens ``im_file_raw`` with GDAL and, for each band index in ``bands``,
    computes the ``percentiles[0]`` and ``percentiles[1]`` percentiles over
    all pixels of that band and appends them as a ``(low, high)`` tuple to
    ``m[band]``.

    Args:
        im_file_raw: path of the raster file to read.
        m: mapping of band index -> list of ``(low, high)`` tuples, e.g. a
            ``collections.defaultdict(list)``; it is mutated in place.
        percentiles: sequence whose first two items are the low and high
            percentiles on a 0-100 scale, e.g. ``[2, 98]``.
        bands: 1-based GDAL band indices to sample (default: bands 1-3).

    Returns:
        ``m`` itself, so calls can be chained or reassigned.

    Raises:
        IOError: if GDAL cannot open ``im_file_raw``.
    """
    src_raster = gdal.Open(im_file_raw)
    if src_raster is None:
        # Without gdal.UseExceptions(), gdal.Open signals failure by returning
        # None; fail here with a clear message instead of an AttributeError.
        raise IOError("GDAL could not open raster: {}".format(im_file_raw))
    try:
        for band in bands:
            # The band object is a temporary, so nothing outlives the dataset.
            band_arr = src_raster.GetRasterBand(band).ReadAsArray()
            # One call yields both bounds; np.percentile flattens internally
            # (axis=None), so no explicit flatten() copies are needed.
            bmin, bmax = np.percentile(band_arr, percentiles[:2])
            m[band].append((bmin, bmax))
    finally:
        # Dropping the last reference is how GDAL's Python bindings close a
        # dataset; do it even if reading a band fails.
        src_raster = None
    return m

def main(base_path=r"C:\2_data\SN3_roads_train_AOI_3_Paris\AOI_3_Paris",
         buffer_meters=2, burnValue=255, compute_rescale=False):
    """Convert AOI_3 (Paris) PS-RGB tiles to 8-bit and write road masks.

    For every ``.tif`` file in ``<base_path>/PS-RGB`` this:

    1. writes an 8-bit GeoTIFF with the same stem to
       ``<base_path>/images_8bit`` via ``apl_tools.convert_to_8Bit``, using
       the module-level ``rescale`` bounds for this AOI (skipped when the
       output file already exists);
    2. calls ``apl_tools.get_road_buffer`` with
       ``<base_path>/geojson_roads/<stem>.geojson`` and the 8-bit image,
       writing ``<base_path>/masks<buffer_meters>m/<stem>.png``.

    Args:
        base_path: AOI directory containing the ``PS-RGB`` and
            ``geojson_roads`` sub-directories.
        buffer_meters: forwarded to ``apl_tools.get_road_buffer``; also names
            the mask output directory.
        burnValue: forwarded to ``apl_tools.get_road_buffer``.
        compute_rescale: when True, also collect per-band percentile bounds
            of every raw image with ``calc_rescale`` and print their per-band
            mean at the end (e.g. to refresh the ``rescale`` table).
    """
    # Inputs live under base_path (raw-string default keeps Windows
    # backslashes literal).
    path_images_raw = os.path.join(base_path, 'PS-RGB')
    path_labels = os.path.join(base_path, 'geojson_roads')

    # Outputs, created on demand.
    path_outputs = os.path.join(base_path, 'masks{}m'.format(buffer_meters))
    path_images_8bit = os.path.join(base_path, 'images_8bit')
    for d in (path_outputs, path_images_8bit):
        os.makedirs(d, exist_ok=True)

    # Dataset name; its second '_'-separated field ('3') is the AOI key into
    # the module-level rescale table, so the two cannot drift apart.
    test_data_name = 'AOI_3_Paris_'
    rescale_key = test_data_name.split('_')[1]
    # Shared by the 8-bit conversion and the optional rescale statistics.
    percentiles = [2, 98]

    # band index -> list of (low, high) pairs; filled only if compute_rescale.
    m = defaultdict(list)

    # sorted() gives a deterministic processing order across platforms.
    for im_file in sorted(os.listdir(path_images_raw)):
        if not im_file.endswith('.tif'):
            continue

        # splitext strips only the final extension, so stems that contain
        # dots are kept intact.
        name_root = os.path.splitext(im_file)[0]

        im_file_raw = os.path.join(path_images_raw, im_file)
        im_file_out = os.path.join(path_images_8bit, name_root + '.tif')

        # Convert to 8-bit with this AOI's fixed rescale bounds; reuse any
        # previously converted file.
        if not os.path.isfile(im_file_out):
            apl_tools.convert_to_8Bit(im_file_raw, im_file_out,
                                      outputPixType='Byte',
                                      outputFormat='GTiff',
                                      rescale_type=rescale[rescale_key],
                                      percentiles=percentiles)

        if compute_rescale:
            m = calc_rescale(im_file_raw, m, percentiles)

        label_file = os.path.join(path_labels, name_root + '.geojson')
        output_raster = os.path.join(path_outputs, name_root + '.png')

        print("\nProcessing:", name_root)
        print("  Output raster:", output_raster)

        # Write the road buffer mask; the returned values are not used here.
        apl_tools.get_road_buffer(
            label_file,
            im_file_out,
            output_raster,
            buffer_meters=buffer_meters,
            burnValue=burnValue,
            bufferRoundness=6,
            plot_file=None,
            figsize=(6, 6),
            fontsize=8,
            dpi=200,
            show_plot=False,
            verbose=False
        )

    # Per-band mean of the collected (low, high) bounds; prints nothing
    # unless compute_rescale was requested.
    for band in sorted(m):
        print(test_data_name, band, np.mean(m[band], axis=0))

ModuleNotFoundError: No module named 'osgeo'