## GRAIN Dataset - ML Application

In [None]:
#imports
import geopandas as gpd
import os
import gc
import pandas as pd
import warnings
import numpy as np
import math
import seaborn as sns
import matplotlib.pyplot as plt
import traceback
import joblib
import json
import rasterio
from rasterio.mask import mask
from pyrosm import OSM
from shapely.geometry import LineString, MultiLineString
from shapely.strtree import STRtree
from rasterio.sample import sample_gen
from core.get_sword_reach_id import get_sword_reach
from core.elevation_and_slope import compute_elevDiff_and_slope
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from geopandas.tools import sjoin
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay

In [None]:
# Silence every warning category for the rest of the session.
# NOTE(review): this also hides library deprecation warnings; consider narrowing it.
warnings.filterwarnings(action="ignore")

In [None]:
##Specifying all data file paths and formats

# --- OSM waterways
osm_data_folder = '../assets/osm_waterways/geoParquet/latest_record/'
waterways_save_folder = '../assets/osm_waterways'

# --- SWORD river reaches; the filename template takes (continent abbreviation, basin id)
sword_data_folder = '../assets/SWORD_v16_shp/shp'
sword_fileName_format = '{}_sword_reaches_hb{}_v16.shp'
sword_continent_map = '../assets/supporting_data/sword_continents.json'

# --- HydroBASINS
hydrobasins_l2_folder = '../assets/HydroBasins_world_L2'
hydrobasins_l2_fileName_format = 'hybas_{}_lev02_v1c.shp'
hydrobasin_l6_file = '../assets/supporting_data/hydrobasins_allBasins_l6_geoParquet_EPSG4326.parquet'

# --- Administrative boundaries
world_countries_filePath = '../assets/supporting_data/world-administrative-boundaries.geojson'

# --- Rasters and lookup tables used to enrich canal segments
koppen_class_map = '../assets/supporting_data/koppen_class_label.json'
koppen_geiger_fp = '../assets/koppen_geiger_data/koppen_geiger_0p00833333.cog'
dem_cog_fp = '../assets/dem_data/World_e-Atlas-UCSD_SRTM30-plus_v8.cog'
esa_cci_cog_path = "../assets/ESACCI/ESACCI-LC-L4-LCCS-Map-300m-P1Y-2015-v2.0.7.cog"

# --- Trained model location
ml_training_data_saveFolder = '../assets/outputs/ML_training_data'



In [None]:
#function to safely sample data from a dataframe by iteration
def safe_sample(df, columns, n, label):
    """Draw a reproducible random sample of ``df[columns]``, shrinking on failure.

    Tries ``n`` rows (``random_state=22``); whenever pandas raises ValueError
    (e.g. more rows requested than available) the request is reduced by 50 and
    retried while it stays positive. If every attempt fails, an empty frame with
    the selected columns is returned.

    ``label`` is only used in the progress messages.
    """
    subset = df[columns]
    requested = n
    while requested > 0:
        try:
            return subset.sample(n=requested, random_state=22)
        except ValueError:
            print(f"Not enough data in {label} for {requested} samples, trying {requested-50}")
            requested -= 50
    print(f" Failed to sample any data from {label}")
    return subset.head(1).iloc[0:0]

In [None]:
#functions to compute the features for ML Training and Application

def vertex_per_length(geom):
    """Return vertex density: number of vertices per 100 CRS length units.

    LineStrings count their own vertices; MultiLineStrings count the vertices
    of every part. Any other geometry type, a length that is not > 0, or a
    geometry that cannot be inspected yields NaN.
    """
    vertex_counters = {
        'LineString': lambda line: len(line.coords),
        'MultiLineString': lambda multi: sum(len(part.coords) for part in multi.geoms),
    }
    try:
        count_vertices = vertex_counters.get(geom.geom_type)
        if count_vertices is None:
            return np.nan  # unsupported geometry type
        vertex_total = count_vertices(geom)
        total_length = geom.length
        return vertex_total / (total_length / 100) if total_length > 0 else np.nan
    except Exception:
        return np.nan


def get_straightness_ratio(geom):
    """Return a simplified-path length divided by the true length of a line.

    With more than 3 vertices, the simplified path is start -> middle vertex
    (index ``len // 2``) -> end. With 3 or fewer vertices it is the straight
    segment between the first two vertices. MultiLineString parts are
    concatenated in order before picking these vertices, while the denominator
    is ``geom.length``.

    Returns NaN when the geometry has no vertices, has zero length, or cannot be
    inspected.
    """
    try:
        if geom.geom_type == 'MultiLineString':
            vertices = [pt for part in geom.geoms for pt in part.coords]
        else:
            vertices = list(geom.coords)

        if len(vertices) <= 3:
            # NOTE(review): for exactly 3 vertices this measures only the first
            # segment (vertices[1], not vertices[-1]). Left as-is because the
            # trained model's features were presumably computed this way.
            simplified = math.dist(vertices[0], vertices[1])
        else:
            first = vertices[0]
            middle = vertices[len(vertices) // 2]
            last = vertices[-1]
            simplified = math.dist(first, middle) + math.dist(middle, last)

        return simplified / geom.length
    except Exception:
        return np.nan


def mean_turning_angle(geom):
    """Return the mean turning angle (degrees) over a line's interior vertices.

    The turning angle at a vertex is ``180 - interior angle``: 0 for a straight
    continuation, 90 for a right-angle bend, approaching 180 for a hairpin.
    MultiLineString parts are concatenated in order with consecutive duplicate
    vertices removed. Lines with fewer than 3 vertices return 0; geometries that
    cannot be processed return NaN.
    """
    try:
        if geom.geom_type == "MultiLineString":
            flat = [pt for part in geom.geoms for pt in part.coords]
            points = [pt for k, pt in enumerate(flat) if k == 0 or pt != flat[k - 1]]
        else:
            points = list(geom.coords)

        if len(points) < 3:
            return 0

        turns = []
        for before, vertex, after in zip(points, points[1:], points[2:]):
            ux, uy = before[0] - vertex[0], before[1] - vertex[1]
            vx, vy = after[0] - vertex[0], after[1] - vertex[1]
            # atan2(|cross|, dot) is the unsigned interior angle in radians.
            interior = math.atan2(abs(ux*vy - uy*vx), ux*vx + uy*vy)
            turns.append(180 - interior*180/math.pi)

        return np.mean(turns) if turns else 0
    except Exception:
        return np.nan


def get_curvature_index(geom):
    """Return the summed turning angle (degrees) per 100 length units of a line.

    For each interior vertex the turning angle is ``180 - interior angle``, so a
    straight continuation contributes 0 and a hairpin close to 180. The sum over
    all interior vertices is divided by ``geom.length / 100``. The caller is
    assumed to pass a metric CRS such as EPSG:3857 (then this is "per ~100 m";
    note Web Mercator lengths are stretched away from the equator).

    MultiLineString parts are concatenated in order and consecutive duplicate
    vertices are dropped, so the joint between two parts also counts as a turn.

    Returns:
        0 for lines with fewer than 3 vertices or zero length; NaN (after printing
        the error and traceback) if the geometry cannot be processed.
    """

    def turning_angle(a, b, c):
        # Vectors from vertex b to its neighbours a and c.
        ba = (a[0] - b[0], a[1] - b[1])
        bc = (c[0] - b[0], c[1] - b[1])
        dot = ba[0]*bc[0] + ba[1]*bc[1]
        det = ba[0]*bc[1] - ba[1]*bc[0]
        # atan2(|cross|, dot) is the unsigned interior angle in radians.
        angle = math.atan2(abs(det), dot)
        return 180 - angle*180/math.pi

    try:
        if geom.geom_type == "MultiLineString":
            flat = [pt for line in geom.geoms for pt in line.coords]
            # Drop consecutive duplicates (e.g. a part starting where the previous ended),
            # which would otherwise register as a spurious 180-degree turn.
            coords = [pt for i, pt in enumerate(flat) if i == 0 or pt != flat[i - 1]]
        else:
            coords = list(geom.coords)

        if len(coords) < 3:
            return 0

        total_angle_change = sum(
            turning_angle(coords[i - 1], coords[i], coords[i + 1])
            for i in range(1, len(coords) - 1)
        )

        line_length_100m = geom.length / 100
        return total_angle_change / line_length_100m if line_length_100m > 0 else 0

    except Exception as e:
        print("Error: {}".format(e))
        traceback.print_exc()
        return np.nan



In [None]:
#function to identify koppen-geiger climate zone for a given canal vector
def get_koppen_climate_class(canal_dataset):
    """Return a copy of *canal_dataset* with a ``koppen_class_code`` column.

    For every row, the Koppen-Geiger raster at ``koppen_geiger_fp`` is sampled at
    the first vertex of the geometry (for a MultiLineString, the first vertex of
    the concatenated parts). Band 1's integer code is converted to a string and
    looked up in the JSON mapping stored at ``koppen_class_map``; the *mapped
    value* (not the raw raster code) is stored in ``koppen_class_code``.

    Coordinates are passed to the raster unchanged, so *canal_dataset* is assumed
    to be in the raster's CRS (NOTE(review): presumably EPSG:4326 — confirm).

    Rows that cannot be labelled (e.g. unsupported geometry type, or a sampled
    code missing from the mapping) keep ``pd.NA``; their count is printed.
    """
    with open(koppen_class_map) as map_file:
        koppen_climate_map = json.load(map_file)

    def startpoint(line):
        if isinstance(line, LineString):
            coords = list(line.coords)
        elif isinstance(line, MultiLineString):
            # flatten into a single coordinate list
            coords = [pt for seg in line.geoms for pt in seg.coords]
        else:
            raise ValueError("Geometry is neither LineString nor MultiLineString")
        return coords[0]

    canal_dataset_withKoppen = canal_dataset.copy()
    canal_dataset_withKoppen["koppen_class_code"] = pd.NA
    error_count = 0

    # Context manager so the raster handle is released even if iteration fails.
    with rasterio.open(koppen_geiger_fp) as koppen_data:
        for idx, row in tqdm(canal_dataset_withKoppen.iterrows(), total=canal_dataset_withKoppen.shape[0], desc="Assigning Koppen Climate Class"):
            try:
                first_coord = startpoint(row.geometry)
                # sample() yields one array of band values per point; take band 1
                # explicitly rather than converting a whole array to int.
                band_values = next(koppen_data.sample([(first_coord[0], first_coord[1])]))
                koppen_string = str(int(band_values[0]))
                canal_dataset_withKoppen.at[idx, "koppen_class_code"] = koppen_climate_map[koppen_string]
            except Exception:
                # Best effort per row: leave NA, keep going, but count it.
                error_count += 1

    if error_count:
        print(f"⚠ Koppen class could not be assigned for {error_count} of {len(canal_dataset_withKoppen)} canal segments")
    return canal_dataset_withKoppen

In [None]:
#function to add a unique grain id to each canal vector

def add_GRAIN_id(final_canal_gdf, country_iso):
    """Attach a GRAIN id ``<country_iso>_<PFAF_ID>_<NNNNN>`` to every canal.

    Each canal is spatially joined (``intersects``) to the HydroBASINS level-6
    polygons read from ``hydrobasin_l6_file`` (reprojected to the canals' CRS).
    The basin's PFAF_ID becomes ``pfaf_id`` and a 1-based counter, zero-padded to
    5 digits, is assigned per basin in row order.

    A canal crossing a basin boundary intersects several basins; only its first
    match in the join output is kept, so every input canal appears at most once
    and receives exactly one id. Canals that intersect no basin are dropped.

    Parameters:
        final_canal_gdf : GeoDataFrame of canal geometries (any CRS).
        country_iso : string prefix for the id (e.g. an ISO3 code).

    Returns:
        GeoDataFrame with added ``index_right``, ``pfaf_id``, ``id_counter``,
        ``id_counter_str`` and ``grain_id`` columns.
    """
    basins = gpd.read_parquet(hydrobasin_l6_file)[["PFAF_ID", "geometry"]]
    basins = basins.to_crs(final_canal_gdf.crs)
    print(country_iso)

    # Tag each input row with its position so duplicate join rows (one per
    # intersecting basin) can be collapsed regardless of the incoming index.
    canals_in = final_canal_gdf.copy()
    canals_in["_grain_row"] = np.arange(len(canals_in))

    canals = gpd.sjoin(
        canals_in,
        basins,
        predicate="intersects",
        how="left"
    ).rename(columns={"PFAF_ID": "pfaf_id"})

    canals = canals.dropna(subset=["pfaf_id"])
    canals = canals.drop_duplicates(subset="_grain_row", keep="first").drop(columns="_grain_row")

    canals["pfaf_id"] = canals["pfaf_id"].astype(int)
    canals["id_counter"] = canals.groupby("pfaf_id").cumcount() + 1
    canals["id_counter_str"] = canals["id_counter"].apply(lambda n: f"{n:05d}")

    canals["grain_id"] = (
        country_iso + "_" +
        canals["pfaf_id"].astype(str) + "_" +
        canals["id_counter_str"]
    )

    return canals

In [None]:
#helper functions to extract OSM tags
def get_osm_name(tags):
    """Return the ``name`` entry of a JSON-encoded OSM tag dictionary.

    Returns None when *tags* is None, is not JSON text (e.g. NaN or an already
    decoded dict), is malformed JSON, does not decode to a JSON object, or has no
    ``name`` key.
    """
    if tags is None:
        return None
    try:
        return json.loads(tags).get("name")
    except (TypeError, ValueError, AttributeError):
        # TypeError: not str/bytes; ValueError: malformed JSON;
        # AttributeError: valid JSON that is not an object (list, number, ...).
        return None
    

def get_osm_source(tags):
    """Return the ``source:name`` entry of a JSON-encoded OSM tag dictionary.

    Returns None when *tags* is None, is not JSON text (e.g. NaN or an already
    decoded dict), is malformed JSON, does not decode to a JSON object, or has no
    ``source:name`` key.
    """
    if tags is None:
        return None
    try:
        return json.loads(tags).get("source:name")
    except (TypeError, ValueError, AttributeError):
        # TypeError: not str/bytes; ValueError: malformed JSON;
        # AttributeError: valid JSON that is not an object (list, number, ...).
        return None

import re

# One `"key"=>"value"` pair of an `other_tags` string (the hstore-like format
# presumably written by GDAL's OSM driver); backslash-escaped characters
# (e.g. \" or \\) are allowed inside the quotes.
_OTHER_TAGS_PAIR_RE = re.compile(r'"((?:[^"\\]|\\.)*)"=>"((?:[^"\\]|\\.)*)"', re.DOTALL)
_OTHER_TAGS_ESCAPE_RE = re.compile(r'\\(.)', re.DOTALL)


def get_osm_name_fromOtherTags(other_tags):
    """Return the ``name:en`` value from an ``other_tags`` string.

    The string is a comma-separated list of ``"key"=>"value"`` pairs. Pairs are
    matched individually, so a comma inside some other tag's value no longer
    prevents ``name:en`` from being found. Backslash escapes in the value are
    removed. Returns None for non-string input or when ``name:en`` is absent.
    """
    if not isinstance(other_tags, str):
        return None
    # Later duplicates of a key override earlier ones.
    tags = dict(_OTHER_TAGS_PAIR_RE.findall(other_tags))
    name_en = tags.get("name:en")
    if name_en is None:
        return None
    return _OTHER_TAGS_ESCAPE_RE.sub(r"\1", name_en)

In [None]:
#function to perform topology-based promotion of canal segments
def promote_connected_canals_until_convergence(df, buffer_dist=10, max_rounds=10):
    """
    Repeatedly promote 'Canal_natural' segments that touch a man-made canal.

    In each round, rows whose ``predicted_class`` is exactly ``"Canal_natural"``
    are buffered by *buffer_dist* and tested for intersection against the
    geometries of rows whose ``predicted_class`` starts with ``"Canal_man_made"``.
    Hits are relabelled ``"Canal_man_made (connected_round_<N>)"`` and act as
    seeds in the next round. Rows with other labels (e.g. plain ``"canal"``)
    are not used as seeds.

    Stops when nothing is left to test, no new connection is found, or
    *max_rounds* rounds have been applied.

    Parameters:
        df : GeoDataFrame with a 'geometry' column and a 'predicted_class' column.
             Modified in place and also returned. Assumes a unique index.
        buffer_dist : buffer distance in CRS units (metres for a metric CRS such
             as EPSG:3857); values <= 0 test plain intersection. Default 10.
        max_rounds : maximum number of promotion rounds applied. Default 10.

    Returns:
        Updated GeoDataFrame with topologically promoted segments
    """
    round_num = 1

    while True:
        print(f"▶ Topology Promotion Round {round_num}")

        to_test = df[df["predicted_class"] == "Canal_natural"]
        if to_test.empty:
            print("✅ No more natural segments left to test. Done.")
            break

        # Seeds: every segment currently labelled as some kind of man-made canal.
        is_man_made = df["predicted_class"].str.startswith("Canal_man_made", na=False)
        man_made_geoms = list(df.loc[is_man_made].geometry.values)
        tree = STRtree(man_made_geoms)

        promoted_idxs = []
        for idx, geom in to_test.geometry.items():
            if geom is None or geom.is_empty:
                continue
            try:
                geom_to_check = geom.buffer(buffer_dist) if buffer_dist > 0 else geom
                # The tree query is a bounding-box prefilter; confirm with a real intersection.
                candidate_ids = tree.query(geom_to_check)
                if any(man_made_geoms[i].intersects(geom_to_check) for i in candidate_ids):
                    promoted_idxs.append(idx)
            except Exception as e:
                print(f"[[{idx}]]: Skipping due to error: {e}")

        if not promoted_idxs:
            print("✅ No new connections found — stopping.")
            break

        label = f"Canal_man_made (connected_round_{round_num})"
        df.loc[promoted_idxs, "predicted_class"] = label
        print(f"🔁 Promoted {len(promoted_idxs)} segments to: {label}")

        # Check the cap only after applying this round's promotions so the
        # final round's work is not discarded.
        if round_num >= max_rounds:
            print(f"🔁 Maximum rounds reached ({max_rounds}). Stopping promotion.")
            break

        round_num += 1

    return df


In [None]:
#end-point helpers: get_endpoints returns the start and end Point of a canal line
from shapely.geometry import Point
from shapely.strtree import STRtree
from shapely.geometry import MultiPoint
from shapely.ops import snap


def get_endpoints(geom):
    """Return ``(start, end)`` shapely Points for a line geometry.

    For a LineString these are its first and last vertices; for a
    MultiLineString, the first vertex of the first part and the last vertex
    of the last part. Any other geometry type yields ``(None, None)``.
    """
    geometry_kind = geom.geom_type
    if geometry_kind == 'LineString':
        head_line = tail_line = geom
    elif geometry_kind == 'MultiLineString':
        parts = geom.geoms
        head_line, tail_line = parts[0], parts[-1]
    else:
        return None, None
    return Point(head_line.coords[0]), Point(tail_line.coords[-1])

In [None]:

def assign_canal_use(grain_data_canals, essa_cci_cog_path):
    # print('Entered func')
    counter = 0
    cropland_class = [10, 11, 12, 20, 30, 130]
    urban_andBare_class = [190,200,201,202]
    water_class = [210]
    grain_data_canals= grain_data_canals.to_crs(epsg=4326)
    grain_data_canals['canal_use'] = pd.NA
    src = rasterio.open(essa_cci_cog_path) 
    for idx, row in tqdm(grain_data_canals.iterrows(), total=grain_data_canals.shape[0], desc="Processing canal segments" ):

        geom = row.geometry.buffer(.01) #5km buffer
        data, _ = mask(src, [geom], crop=True)           # data.shape = (1, h, w)
        values  = data[0].ravel()
        values  = values[values != src.nodata]

        class_counts = np.bincount(values)
        idx_sorted = class_counts.argsort()[::-1]
        majority_class = idx_sorted[0]
        second_major_class = idx_sorted[1] if class_counts[idx_sorted[1]] > 0 else None
        # print(row.id, majority_class)
        if majority_class in cropland_class:
            canal_use_case = "Agricultural"
            # print("Agricultural")
        elif majority_class in urban_andBare_class:
            if second_major_class in cropland_class:
                canal_use_case = "Agricultural"
            else:
                if majority_class in [200,201,202]:
                    canal_use_case = "Other"
                else:
                    canal_use_case = "Urban Waterway"
            # print("Urban")
        elif majority_class in water_class:
            if second_major_class in cropland_class:
                canal_use_case = "Agricultural"
            else:
                canal_use_case = "Navigational Waterway"
            # print("Navigational")
        else:
            if second_major_class in cropland_class:
                canal_use_case = "Agricultural"
            else:
                canal_use_case = "Other"
            # print("Natural")
        # print(majority_class,second_major_class,canal_use_case)
        # print(canal_use_case)
        grain_data_canals.at[idx, "canal_use"] = canal_use_case
        counter += 1
        
    print('✔ Completed processing {} canal segments'.format(counter))
    return grain_data_canals

In [None]:
# major function to run the GRAIN ML model
def run_grain_ml_model(country):
    feature_cols = ['elev_diff','straightness_ratio','curvature_index_per_100m','mean_turning_angle','slope']

    country_osm_waterways_data_fp = f'{osm_data_folder}{country.lower()}_waterway.parquet'
    country_osm_waterways_data = gpd.read_parquet(country_osm_waterways_data_fp)
    country_osm_waterways_data = country_osm_waterways_data[country_osm_waterways_data['geometry'].geom_type.isin(['LineString', 'MultiLineString'])]

    country_osm_waterways_data = country_osm_waterways_data.to_crs(epsg=3857)
    if "osm_id" in country_osm_waterways_data.columns:
        country_osm_waterways_data.rename(columns={"osm_id": "id"}, inplace=True)
        
    if "waterway"  in country_osm_waterways_data.columns:
            country_osm_waterways_data = country_osm_waterways_data.rename(columns={"waterway": "osm_label"})

    if "tags" in country_osm_waterways_data.columns:
        columns_to_keep = ["id", "geometry", "osm_label", "tags"]
    elif "other_tags" in country_osm_waterways_data.columns:
        columns_to_keep = ["id", "geometry", "osm_label", "other_tags"]
    else:
        columns_to_keep = ["id", "geometry", "osm_label"]
    if "name" in country_osm_waterways_data.columns:
        columns_to_keep.append("name")

    country_osm_canals = country_osm_waterways_data[country_osm_waterways_data['osm_label'].isin(['canal', 'ditch', 'drain'])]
    country_osm_rivers = country_osm_waterways_data[country_osm_waterways_data['osm_label'].isin(['river', 'stream'])]
    country_osm_rivers = country_osm_rivers[columns_to_keep]
    country_osm_canals = country_osm_canals[columns_to_keep]

    #computing features
    print(f'[{country}]: Computing features. This might take a while...')
    print(f'[{country}]: Rivers')

    country_osm_rivers["straightness_ratio"] = country_osm_rivers.geometry.apply(get_straightness_ratio)
    country_osm_rivers["mean_turning_angle"] = country_osm_rivers.geometry.apply(mean_turning_angle)
    country_osm_rivers["curvature_index_per_100m"] = country_osm_rivers.geometry.apply(get_curvature_index)
    #compute length of the geometry in km
    country_osm_rivers['length'] = country_osm_rivers['geometry'].length / 1E3
    country_osm_rivers = compute_elevDiff_and_slope(dem_cog_fp,country_osm_rivers)


    # do the same for canal
    print(f'[{country}]: Canals')
    country_osm_canals["straightness_ratio"] = country_osm_canals.geometry.apply(get_straightness_ratio)
    country_osm_canals["mean_turning_angle"] = country_osm_canals.geometry.apply(mean_turning_angle)
    country_osm_canals["curvature_index_per_100m"] = country_osm_canals.geometry.apply(get_curvature_index)
    #compute length of the geometry in km
    country_osm_canals['length'] = country_osm_canals['geometry'].length / 1E3
    country_osm_canals = compute_elevDiff_and_slope(dem_cog_fp,country_osm_canals)

    #drop nans
    cols_checkforNans = ['elev_diff', 'slope', 'straightness_ratio', 'mean_turning_angle', 'curvature_index_per_100m']
    country_osm_rivers_clean = country_osm_rivers.dropna(subset=cols_checkforNans).copy()
    country_osm_canals_clean = country_osm_canals.dropna(subset=cols_checkforNans).copy()

    #loading random forest model
    model_fp = f'{ml_training_data_saveFolder}/ML_model_random_forest.pkl'
    model = joblib.load(model_fp)


    X_country_osm_rivers = country_osm_rivers_clean[feature_cols]

    #predict using random forest model
    country_osm_rivers_clean["predicted_label"] = model.predict(X_country_osm_rivers)
    country_osm_rivers_clean["predicted_class"] = country_osm_rivers_clean["predicted_label"].map({0: "river", 1: "canal"})
    #getting confidence
    proba_rivers = model.predict_proba(X_country_osm_rivers)   
    country_osm_rivers_clean[["prob_river", "prob_canal"]] = proba_rivers

    X_country_osm_canals = country_osm_canals_clean[feature_cols]
    #predict using random forest model
    country_osm_canals_clean["predicted_label"] = model.predict(X_country_osm_canals)
    country_osm_canals_clean["predicted_class"] = country_osm_canals_clean["predicted_label"].map({0: "river", 1: "canal"})
    proba_canals = model.predict_proba(X_country_osm_canals)   
    country_osm_canals_clean[["prob_river", "prob_canal"]] = proba_canals


    ##adding name, osm source, and alt name fields
    #rivers
    if "name" in country_osm_rivers_clean.columns:
        country_osm_rivers_clean = country_osm_rivers_clean.rename(columns={"name": "osm_name"})
    if "tags" in country_osm_rivers_clean.columns:
        if "osm_name" not in country_osm_rivers_clean.columns:
            country_osm_rivers_clean['osm_name'] = country_osm_rivers_clean['tags'].apply(get_osm_name)
        country_osm_rivers_clean['osm_source'] = country_osm_rivers_clean['tags'].apply(get_osm_source) 

    if "other_tags" in country_osm_rivers_clean.columns:
        if "osm_name" not in country_osm_rivers_clean.columns:
            country_osm_rivers_clean['osm_name'] = country_osm_rivers_clean['other_tags'].apply(get_osm_name_fromOtherTags)
        else:
            country_osm_rivers_clean['alt_name'] = country_osm_rivers_clean['other_tags'].apply(get_osm_name_fromOtherTags)

    if "osm_source" not in country_osm_rivers_clean.columns:
        country_osm_rivers_clean['osm_source'] = None

    if "alt_name" not in country_osm_rivers_clean.columns:
        country_osm_rivers_clean['alt_name'] = None
    #canals
    if "name" in country_osm_canals_clean.columns:
        country_osm_canals_clean = country_osm_canals_clean.rename(columns={"name": "osm_name"})
    if "tags" in country_osm_canals_clean.columns:
        if "osm_name" not in country_osm_canals_clean.columns:
            country_osm_canals_clean['osm_name'] = country_osm_canals_clean['tags'].apply(get_osm_name)
        country_osm_canals_clean['osm_source'] = country_osm_canals_clean['tags'].apply(get_osm_source) 

    if "other_tags" in country_osm_canals_clean.columns:
        if "osm_name" not in country_osm_canals_clean.columns:
            country_osm_canals_clean['osm_name'] = country_osm_canals_clean['other_tags'].apply(get_osm_name_fromOtherTags)
        else:
            country_osm_canals_clean['alt_name'] = country_osm_canals_clean['other_tags'].apply(get_osm_name_fromOtherTags)

    if "osm_source" not in country_osm_canals_clean.columns:
        country_osm_canals_clean['osm_source'] = None

    if "alt_name" not in country_osm_canals_clean.columns:
        country_osm_canals_clean['alt_name'] = None


    #removing any sword rivers from canals

    sword_continent_map_data = json.load(open(sword_continent_map))
    random_canal = country_osm_canals_clean.sample(n=1, random_state=21)
    country_iso, country_name_official, continent_name, sword_reach_id = get_sword_reach(random_canal)

    sword_gdfs = []
    print(f'[[{country}]]:Removing any sword rivers from canal dataset')
    for sword_id in sword_reach_id:

        for key, value in sword_continent_map_data.items():
            if str(sword_id) in value:
                continent_abbrv = key

        sword_file_path = os.path.join(sword_data_folder,continent_abbrv,sword_fileName_format.format(continent_abbrv.lower(), sword_id))

        sword_data_gpd = gpd.read_file(sword_file_path)
        sword_gdfs.append(sword_data_gpd)

    sword_data_gpd = pd.concat(sword_gdfs)
    country_sword_mercator = sword_data_gpd.to_crs(epsg=3857)

    country_osm_canals_clean = country_osm_canals_clean.to_crs(epsg=3857)
    country_osm_canals_sword_intersection = gpd.sjoin(country_osm_canals_clean, country_sword_mercator, how="inner", predicate="intersects")
    to_remove_index = country_osm_canals_sword_intersection.index
    country_osm_canals_final = country_osm_canals_clean.drop(index=to_remove_index)

    print(f'[[{country}]]:Running checks for any osm canals labeled as rivers')
    country_osm_rivers_clean = country_osm_rivers_clean.to_crs(epsg=3857)
    country_osm_rivers_toCheckForNamedCanals = country_osm_rivers_clean.copy()
    cols = ["osm_name", "alt_name"]
    CANAL_RE = r"\bcanal\b" 

    mask_namedCanals = (
        country_osm_rivers_toCheckForNamedCanals[cols]
        .apply(lambda s: s.str.contains(CANAL_RE, case=False, na=False))
        .any(axis=1)                         
    )
    canals_to_merge = country_osm_rivers_toCheckForNamedCanals[mask_namedCanals]
    canals_to_merge['predicted_class'] = 'canal'
    print("Identified", len(canals_to_merge), "river segments that have tag 'canal' in its name. Merging with canals")
    country_osm_canals_final = pd.concat([country_osm_canals_clean, canals_to_merge], ignore_index=True)

    print(f'[[{country}]]:Relabeling canals classified as rivers if they are connected to canals endpoints')
    ml_canals_man_made = country_osm_canals_final[country_osm_canals_final["predicted_class"] == "canal"].copy()
    ml_canals_rivers = country_osm_canals_final[country_osm_canals_final["predicted_class"] == "river"].copy()

    # Collect all endpoints
    man_made_endpoints = []
    for geom in ml_canals_man_made.geometry:
        start, end = get_endpoints(geom)
        man_made_endpoints.extend([start, end])

    # Create a spatial index for fast lookup
    endpoint_tree = STRtree(man_made_endpoints)
    man_made_geoms = list(ml_canals_man_made.geometry)
    man_made_tree = STRtree(man_made_geoms)
    #checking if the endpoints of canals classified as rivers are within a buffer distance of man-made canals. if so they are man-made canals
    buffer_dist = 50  # in meters

    promoted_idxs = []

    for idx, geom in ml_canals_rivers.geometry.items():
            if geom is None or geom.is_empty:
                continue

            try:
                # Buffer canal geom slightly to catch near-touches (optional)
                geom_to_check = geom.buffer(buffer_dist) if buffer_dist > 0 else geom
                
                # Query STRtree for potential matches
                candidates_id = man_made_tree.query(geom_to_check)
                # Check if any intersect
                #if candidates in not empty break out of loop
                if len(candidates_id) > 0:
                    candidate_geoms = [man_made_geoms[i] for i in candidates_id]
                    if any(candidate.intersects(geom_to_check) for candidate in candidate_geoms):
                        promoted_idxs.append(idx)
                
            except Exception as e:
                print(f"[[{idx}]]: Skipping due to error: {e}")
                traceback.print_exc()
    print(f'Identified {len(promoted_idxs)} segments. Promoting...')

    ml_canals_rivers.loc[promoted_idxs, "predicted_class"] = "Canal_man_made (connected)"
    ml_canals_rivers.loc[~ml_canals_rivers.index.isin(promoted_idxs), "predicted_class"] = "Canal_natural"

    #combining back
    ml_canals_cleaned = pd.concat([ml_canals_man_made, ml_canals_rivers], ignore_index=True)

    print(f'[[{country}]]:Using successive topology checks to promote canal regments labelled as canal_natural')
    ml_canals_cleaned = promote_connected_canals_until_convergence(ml_canals_cleaned, buffer_dist=50)

    print("Adding osm rivers classified as canals to river dataset based on intersection criteria")
    ml_canals_cleaned_preRiver = ml_canals_cleaned.copy()
    ml_canals_cleaned_preRiver.loc[~ml_canals_cleaned_preRiver["predicted_class"].isin(["canal", "Canal_natural"]), "predicted_class"] = "canal"

    country_rivers_classified_man_made_canals = country_osm_rivers_clean[country_osm_rivers_clean["predicted_class"] == "canal"]
    river_man_made = country_rivers_classified_man_made_canals.copy()

    canal_geoms = list(ml_canals_cleaned_preRiver.geometry.values)
    canal_tree = STRtree(canal_geoms)

    intersecting_river_idxs = []

    joined = sjoin(river_man_made, ml_canals_cleaned_preRiver, how="left", predicate="intersects")
    intersecting_idxs = joined[~joined.index_right.isna()].index
    river_man_made.loc[intersecting_idxs, "predicted_class"] = "Canal_man_made_connected"

    final_canal_dataset = pd.concat(
            [ml_canals_cleaned_preRiver, river_man_made[river_man_made["predicted_class"] == "Canal_man_made_connected"]],
            ignore_index=True
        )

    final_canal_dataset.loc[final_canal_dataset['predicted_class']=='Canal_man_made_connected', 'predicted_class'] = 'canal'

    print(f"[[{country}]]:Running checks for any osm canals. Removes those that have 'River' in their OSM names")
    final_canal_dataset = final_canal_dataset.to_crs(epsg=3857)
    final_canal_dataset_toCheckForNamedRivers = final_canal_dataset.copy()
    cols = ["osm_name", "alt_name"]
    RIVER_RE = r"\briver\b" 
    CANAL_RE = r"\bcanal\b"
    mask_namedRivers = (
        final_canal_dataset_toCheckForNamedRivers[cols]
        .apply(lambda s: s.str.contains(RIVER_RE, case=False, na=False))
        .any(axis=1)                         
    )
    mask_namedCanals = (
        final_canal_dataset_toCheckForNamedRivers[cols]
        .apply(lambda s: s.str.contains(CANAL_RE, case=False, na=False))
        .any(axis=1)
    )
    mask_to_drop = mask_namedRivers & ~mask_namedCanals

    final_canal_dataset_clean = final_canal_dataset_toCheckForNamedRivers.loc[~mask_to_drop].copy()
    #Assigning Canal use case based on ES CCI LULC Dataset 2015

    print(f'[[{country}]]: Assigning canal use case based on ESA CCI LULC data')
    final_canal_dataset_withUseCase = assign_canal_use(final_canal_dataset_clean, esa_cci_cog_path)
    print(f'[[{country}]]: Adding GRAIN ID in the format ISO3-code_PFAF Level 6 ID_sequential numbering')
    final_canal_dataset_withGrainID = add_GRAIN_id(final_canal_dataset_withUseCase, country_iso[0])
    print(f'[[{country}]]: Adding country and continent names and Koppen Climate Class')
    final_canal_dataset_withGrainID['country'] = country_name_official[0]
    final_canal_dataset_withGrainID['continent'] = continent_name[0]
    final_canal_dataset_withGrainID['country_iso'] = country_iso[0]
    final_canal_dataset_withGrainID['update_date'] = "2025-07-31"
    final_canal_dataset_withGrainID['version'] = "v.1.0.0"
    final_canal_dataset_withKoppenClass = get_koppen_climate_class(final_canal_dataset_withGrainID)
    print(f'[[{country}]]: Trimming and renaming columns')
    columns_to_keep = ["grain_id", "id", "country", "continent", "country_iso", "length","elev_diff","slope","predicted_class","prob_canal",
    "osm_name","osm_label", "tags","osm_source", "alt_name", "canal_use","koppen_class_code", "update_date","version", "geometry"]
    if "tags" not in final_canal_dataset_withKoppenClass.columns:
        final_canal_dataset_withKoppenClass["tags"] = None

    final_grain_canal_dataset = final_canal_dataset_withKoppenClass[columns_to_keep]
    final_grain_canal_dataset = final_grain_canal_dataset.rename(columns={"id": "osm_id", "slope": "slope_MKM",
    "length": "length_KM", "elev_diff": "elev_diff_M", "prob_canal": "confidence", })

    print(f'[[{country}]]:Saving final dataset for {country} to parquet and geojson file')
    final_grain_canal_dataset.to_parquet(f"../assets/outputs/final_outputs_withUseCase/{country}_GRAIN_v.1.0.parquet")
    final_grain_canal_dataset.to_file(f"../assets/outputs/final_outputs_withUseCase/{country}_GRAIN_v.1.0.geojson", driver='GeoJSON')
    print(f'[[{country}]]:Completed')
    print('==========================================================================')

In [None]:
# Run the GRAIN pipeline for every *_waterway.parquet file in osm_data_folder,
# skipping countries whose parquet output already exists. Files are processed in
# sorted order so runs are reproducible.
osm_files = sorted(os.listdir(osm_data_folder))
error_list = []  # non-country files that were skipped, plus countries whose run raised
output_folder = '../assets/outputs/final_outputs_withUseCase/'
WATERWAY_SUFFIX = '_waterway.parquet'
total_countries = sum(file.endswith(WATERWAY_SUFFIX) for file in osm_files)
counter_country = 0
for file in osm_files:
    if not file.endswith(WATERWAY_SUFFIX):
        print(f"Skipping file: {file}")
        error_list.append(file)
        continue

    # Count every country file, including ones skipped because output exists.
    counter_country += 1
    country_name = file.split(WATERWAY_SUFFIX)[0]
    print(f"Processing country: {counter_country}/{total_countries}: {country_name}")

    save_path = os.path.join(output_folder, f"{country_name}_GRAIN_v.1.0.parquet")
    if os.path.exists(save_path):
        print(f"[[{country_name}]]: File already exists: {save_path}, skipping...")
        print('==========================================================================')
        continue

    try:
        run_grain_ml_model(country_name)
    except Exception as e:
        # Top-level boundary: log and continue with the next country.
        print(f"Error processing {country_name}: {e}")
        error_list.append(country_name)
        traceback.print_exc()
        