In [1]:
import cv2
import rasterio

import osmnx as ox

import numpy as np
import pandas as pd
import geopandas as gpd

import matplotlib.pyplot as plt

import detectree as dtr

from io import BytesIO

from rasterio.features import rasterize
from rasterio.io import MemoryFile
from rasterio.mask import mask

from shapely.ops import unary_union
from shapely.geometry import Point

from tqdm.auto import tqdm

In [2]:
import warnings
# Installs a global "ignore" filter: every warning category is silenced from
# this cell onwards.
# NOTE(review): this also hides deprecation/CRS warnings raised by the geo
# libraries used below — consider narrowing the filter to specific categories.
warnings.filterwarnings('ignore')

### 1. Load Static Data

In [None]:
# Tile grid and building footprints, both read from GeoJSON files
grid, buildings = (
    gpd.read_file(layer_path)
    for layer_path in ('../data/yerevan_grid.geojson', '../data/yerevan_buildings.geojson')
)

# OSM network for the city (network_type="all"), converted to node and edge GeoDataFrames
G = ox.graph_from_place("Yerevan, Armenia", network_type="all")
nodes, edges = ox.graph_to_gdfs(G)

### 2. Define Pipeline Functions

In [4]:
def read_image_from_path(img_path):
    """
    Reads bands 1-3 of a raster (treated as RGB) as a channel-last array.

    Parameters:
    - img_path: path of the raster, passed straight to `rasterio.open`

    Returns:
    - pixels: np.ndarray with shape (H, W, 3)
    - transform: affine transform of the opened dataset
    - crs: coordinate reference system of the opened dataset
    """
    rgb_band_indexes = [1, 2, 3]
    with rasterio.open(img_path) as dataset:
        geo_transform, raster_crs = dataset.transform, dataset.crs
        bands_first = dataset.read(rgb_band_indexes)

    # rasterio hands back (bands, rows, cols); move the band axis to the end
    channel_last = np.moveaxis(bands_first, 0, -1)
    return channel_last, geo_transform, raster_crs

def filter_grid_data(index, grid, buildings, roads_gdf):
    """
    Selects the grid cell, buildings and roads that belong to one tile.

    Parameters:
    - index: tile identifier, matched against grid['Index'] and buildings['img_id']
    - grid: table of tiles with an 'Index' column
    - buildings: table of buildings with an 'img_id' column
    - roads_gdf: road geometries; spatially joined against the selected tile

    Returns:
    - (tile rows, building rows, roads intersecting the tile)
    """
    tile_rows = grid[grid['Index'] == index]
    tile_buildings = buildings[buildings['img_id'] == index]

    # Keep roads touching the tile, then discard the join's bookkeeping column
    joined = roads_gdf.sjoin(tile_rows, predicate='intersects')
    tile_roads = joined.drop(columns='index_right', errors='ignore')

    return tile_rows, tile_buildings, tile_roads

def paint_buildings_and_roads_white(img, buildings_gdf, roads_gdf, transform, road_buffer=0.00002):
    """
    Returns a copy of `img` with every building footprint painted white (255).

    Parameters:
    - img: np.ndarray whose first two axes are (rows, cols)
    - buildings_gdf: object exposing `.geometry`, an iterable of shapely
      geometries expressed in the same coordinates as `transform`
    - roads_gdf: currently unused (roads are not painted); kept for callers
    - transform: affine transform mapping pixels to coordinates
    - road_buffer: currently unused; kept for callers

    Returns:
    - np.ndarray: painted copy; the input array is never modified. When there
      is no valid (non-null, non-empty) footprint the copy is returned as-is.
    """
    footprints = [
        (geom, 1)
        for geom in buildings_gdf.geometry  # optionally: + roads_gdf.geometry.buffer(road_buffer)
        if geom is not None and not geom.is_empty
    ]

    painted = img.copy()
    if not footprints:
        # rasterize() raises ValueError when it receives no valid shapes,
        # so a tile without buildings has nothing to paint.
        return painted

    # Named footprint_mask so the imported rasterio `mask` function is not shadowed
    footprint_mask = rasterize(
        footprints,
        out_shape=img.shape[:2],
        transform=transform,
        fill=0,
        dtype='uint8'
    )
    painted[footprint_mask == 1] = 255
    return painted

def detect_trees_with_canopy_estimation(img, transform, classifier=None,
                                        prob_threshold=0.85, exg_threshold=20, min_area=120):
    """
    Estimates the number of trees and their canopy centres in an RGB image.

    A pixel counts as canopy when the classifier prediction exceeds
    `prob_threshold` AND its Excess Green index (2G - R - B) exceeds
    `exg_threshold`. The combined mask is morphologically closed, split into
    8-connected components, and components smaller than `min_area` pixels are
    dropped. The tree estimate is the total kept area divided by the median
    component area, so one large blob counts as several trees.

    Parameters:
    - img: np.ndarray, channel-last RGB image (H, W, 3)
    - transform: affine transform used to convert pixel (row, col) to coordinates
    - classifier: optional object exposing `predict_img(file_like)`. Defaults
      to a fresh `dtr.Classifier()`; pass one shared instance when processing
      many tiles so it is not rebuilt on every call.
    - prob_threshold: cut-off applied to the classifier output
    - exg_threshold: cut-off applied to the Excess Green index
    - min_area: minimum component area, in pixels, to be kept

    Returns:
    - est_trees: int, estimated number of trees
    - marked: copy of `img` with a filled green circle on each kept centre
    - latlon_points: list of shapely Points, one per kept component, in the
      coordinate system of `transform`

    Raises:
    - ValueError: if the image cannot be JPEG-encoded for the classifier
    """
    # ----- Model Inference -----
    ok, encoded = cv2.imencode('.jpg', cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
    if not ok:
        raise ValueError("Could not JPEG-encode the image for tree classification")
    if classifier is None:
        classifier = dtr.Classifier()
    # NOTE(review): the threshold works for a probability map and for a 0/255
    # label mask alike — confirm which one predict_img returns.
    y_pred = classifier.predict_img(BytesIO(encoded.tobytes()))
    pred_mask = (y_pred > prob_threshold).astype(np.uint8)

    # ----- ExG Thresholding -----
    # Widen to int16 first so 2G - R - B cannot wrap around in uint8
    red, green, blue = (img[..., band].astype(np.int16) for band in range(3))
    exg_mask = ((2 * green - red - blue) > exg_threshold).astype(np.uint8)

    # ----- Combine and Clean Mask -----
    combined = cv2.bitwise_and(pred_mask, exg_mask)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    cleaned = cv2.morphologyEx(combined, cv2.MORPH_CLOSE, kernel, iterations=2)

    # ----- Connected Components & Area Filtering -----
    num_labels, _, stats, centroids = cv2.connectedComponentsWithStats(cleaned, connectivity=8)
    # Label 0 is the background component, so start at 1
    kept = [i for i in range(1, num_labels) if stats[i, cv2.CC_STAT_AREA] >= min_area]
    areas = [stats[i, cv2.CC_STAT_AREA] for i in kept]
    centers = [centroids[i] for i in kept]

    if areas:
        typical_area = np.median(areas)
        est_trees = int(round(sum(areas) / typical_area))
    else:
        est_trees = 0

    # ----- Convert to Geo Points -----
    # Centroids are (x, y) = (col, row); rasterio expects (row, col)
    latlon_points = [
        Point(*rasterio.transform.xy(transform, int(cy), int(cx)))
        for cx, cy in centers
    ]

    # ----- Visualization -----
    marked = img.copy()
    for cx, cy in centers:
        cv2.circle(marked, (int(cx), int(cy)), 15, (0, 255, 0), -1)

    return est_trees, marked, latlon_points

In [5]:
def mask_image_by_geometry(image, transform, geometry, crs='EPSG:4326'):
    """
    Clips a band-first image to a given geometry (e.g., a buffer zone).

    The image is written to an in-memory GeoTIFF and cropped with
    `rasterio.mask.mask`; pixels outside the geometry are filled.

    Parameters:
    - image: np.ndarray, shape (bands, height, width). Channel-last arrays
      (H, W, C) must be transposed by the caller first.
    - transform: Affine transform of `image`
    - geometry: shapely.geometry.Polygon or MultiPolygon, expressed in the
      same coordinates as `transform`
    - crs: CRS recorded on the temporary dataset (default 'EPSG:4326', the
      value that used to be hard-coded)

    Returns:
    - clipped_image: np.ndarray, masked image (bands, H', W')
    - clipped_transform: Affine transform of the clipped image

    Raises:
    - ValueError: if `image` is not a 3-D array
    """
    if image.ndim != 3:
        raise ValueError(
            f"image must have shape (bands, height, width); got shape {image.shape}"
        )
    num_bands, height, width = image.shape
    geometry_geojson = [geometry.__geo_interface__]  # avoids overhead of mapping()

    with MemoryFile() as memfile:
        with memfile.open(
            driver='GTiff',
            height=height,
            width=width,
            count=num_bands,
            dtype=image.dtype,
            transform=transform,
            crs=crs
        ) as dataset:
            dataset.write(image)  # write all bands at once

            clipped_image, clipped_transform = mask(
                dataset,
                geometry_geojson,
                crop=True,
                all_touched=False,  # only pixels whose centre is inside; set True to keep edge pixels
                filled=True
            )

    return clipped_image, clipped_transform

### 3. Detection

In [None]:
# Tile to inspect and the GeoTIFF holding its imagery
tile_index = 2021
img_path = f"../georeferenced_images/{tile_index}.tif"

# 1) Pixels plus georeferencing of the tile
original_img, transform, crs = read_image_from_path(img_path)

# 2) Grid cell, buildings and roads restricted to this tile
tile_gdf, blds, roads = filter_grid_data(tile_index, grid, buildings, edges)

# 3) Paint building footprints white before detection
final_image = paint_buildings_and_roads_white(original_img, blds, roads, transform)

# 4) Tree estimate, annotated image and tree centres as geo points
tree_count, tree_img, tree_points = detect_trees_with_canopy_estimation(final_image, transform)

# 5) Side-by-side view: untouched tile vs. annotated detections
fig, axes = plt.subplots(1, 2, figsize=(20, 10))
panels = (
    (original_img, "Original Image"),
    (tree_img, f"Detected Trees: {tree_count}"),
)
for ax, (picture, caption) in zip(axes, panels):
    ax.imshow(picture)
    ax.set_title(caption)
    ax.axis("off")

plt.tight_layout()
plt.show()

# 6) Tree centres as a WGS84 GeoDataFrame tagged with the tile id
trees_gdf = (
    gpd.GeoDataFrame(geometry=tree_points, crs=crs)
    .to_crs(epsg=4326)
    .assign(img_id=tile_index)
)

### 4. Detection Loop

In [None]:
def detect_trees_around_buildings(tile_index, grid, buildings, edges):
    """
    Detects trees around the buildings of a tile, inside an adaptive buffer
    zone of 30–80 m (the more buildings, the smaller the buffer).

    Parameters:
    - tile_index: int, index of the tile
    - grid, buildings, edges: GeoDataFrames

    Returns:
    - trees_gdf: GeoDataFrame (EPSG:4326) with detected tree locations and img_id
    - blds: Filtered buildings for the tile
    - tree_img: Image with tree detections visualized (the untouched tile
      image when the tile has no buildings)
    - tree_count: Number of detected trees

    NOTE(review): building geometries are rasterized with the image transform,
    so they are assumed to share the image's coordinate system — confirm.
    """
    img_path = f"../georeferenced_images/{tile_index}.tif"
    original_img, transform, crs = read_image_from_path(img_path)

    _, blds, roads = filter_grid_data(tile_index, grid, buildings, edges)

    building_count = len(blds)
    if building_count == 0:
        empty = gpd.GeoDataFrame(columns=["geometry", "img_id"], crs="EPSG:4326")
        return empty, blds, original_img, 0

    # Adaptive buffer in metres: more buildings → smaller buffer, clamped to 30–80
    buffer_size = max(30, min(80, 80 - 0.25 * building_count))

    # White-out the footprints on a copy of the tile image
    base_image = paint_buildings_and_roads_white(original_img, blds, roads, transform)

    # The buffer distance is in metres. Buffering lon/lat geometries directly
    # would interpret it as degrees and cover the whole tile, so geographic
    # footprints are buffered in a projected (UTM) CRS instead.
    footprints = blds.geometry
    if footprints.crs is not None and footprints.crs.is_geographic:
        zone = footprints.to_crs(footprints.estimate_utm_crs()).buffer(buffer_size)
    else:
        zone = footprints.buffer(buffer_size)

    # The clip geometry must be expressed in the raster's own coordinates
    target_crs = crs if crs is not None else footprints.crs
    if target_crs is not None and zone.crs is not None:
        zone = zone.to_crs(target_crs)
    buffered_union = unary_union(zone)

    # mask_image_by_geometry works on (bands, H, W); the pipeline images are
    # channel-last (H, W, C), so transpose on the way in and on the way out.
    bands_first = np.ascontiguousarray(np.moveaxis(base_image, -1, 0))
    masked_bands, masked_transform = mask_image_by_geometry(bands_first, transform, buffered_union)
    masked_img = np.ascontiguousarray(np.moveaxis(masked_bands, 0, -1))

    # Tree detection (canopy estimation etc.) on the clipped image
    tree_count, tree_img, tree_points = detect_trees_with_canopy_estimation(masked_img, masked_transform)

    # GeoDataFrame of detected trees, reprojected to WGS84
    trees_gdf = gpd.GeoDataFrame(geometry=tree_points, crs=crs).to_crs(epsg=4326)
    trees_gdf["img_id"] = tile_index

    return trees_gdf, blds, tree_img, tree_count

In [8]:
# Distinct tile ids that appear in the buildings table
ids = pd.unique(buildings['img_id'])

In [None]:
all_tree_detections = []
all_buildings = []

for tile_index in tqdm(ids, total=len(ids), desc='Images Processed'):
    # Named throwaways instead of `_`, which IPython uses for the last output
    trees_gdf, blds, _tree_img, _tree_count = detect_trees_around_buildings(
        tile_index, grid, buildings, edges
    )

    all_tree_detections.append(trees_gdf)
    all_buildings.append(blds)

# Merge all detections into a single GeoDataFrame.
# Trees are always returned in EPSG:4326; buildings keep the CRS of the
# `buildings` table rather than borrowing the CRS of the last tile's trees.
if all_tree_detections:
    merged_trees_gdf = gpd.GeoDataFrame(
        pd.concat(all_tree_detections, ignore_index=True), crs="EPSG:4326"
    )
    merged_buildings_gdf = gpd.GeoDataFrame(
        pd.concat(all_buildings, ignore_index=True), crs=buildings.crs
    )
else:
    # pd.concat() raises on an empty list, so build empty results explicitly
    merged_trees_gdf = gpd.GeoDataFrame(columns=["geometry", "img_id"], crs="EPSG:4326")
    merged_buildings_gdf = buildings.iloc[0:0].copy()

In [None]:
# Project-local helper module; it must be importable from the notebook's working directory.
from tree_detection_threading import detect_all_trees

# Tile ids to process: every distinct img_id present in the buildings table
ids = buildings['img_id'].unique()

# Runs the per-tile detector defined earlier in this notebook
# (detect_trees_around_buildings) over all tile ids.
# NOTE(review): presumably detect_all_trees fans the tiles out over
# `max_workers` workers and returns the merged trees and buildings —
# confirm against tree_detection_threading.
merged_trees, merged_buildings = detect_all_trees(
    ids, grid, buildings, edges, detect_func=detect_trees_around_buildings, max_workers=10
)

In [None]:
# Name the driver explicitly instead of relying on inference from the file
# extension, so the output is GeoJSON regardless of the GeoPandas version.
merged_trees.to_file('../data/yerevan_trees.geojson', driver='GeoJSON')

In [None]:
# Sequential id per building row (0 .. n-1), following the merged row order
merged_buildings['building_id'] = range(len(merged_buildings))
# Name the driver explicitly instead of relying on inference from the file
# extension, so the output is GeoJSON regardless of the GeoPandas version.
merged_buildings.to_file('../data/yerevan_buildings_with_id.geojson', driver='GeoJSON')

In [20]:
# Preview the merged buildings table; the notebook repr elides the middle rows.
merged_buildings

Unnamed: 0,osmid,name,name:en,building,addr:country,addr:city,addr:district,addr:region,addr:housenumber,addr:postcode,addr:street,check_date,amenity,shop,img_id,geometry,building_id
0,218017460,,,yes,AM,Երևան,,,96,,Շիրակի փողոց,,,,1009,"POLYGON ((44.43474 40.13809, 44.43491 40.13746...",0
1,542793077,,,yes,AM,Երևան,,,,,,,,,1009,"POLYGON ((44.43557 40.13686, 44.43589 40.13663...",1
2,542793078,,,yes,AM,Երևան,,,,,,,,,1009,"POLYGON ((44.43549 40.13656, 44.43557 40.13618...",2
3,577406990,,,yes,AM,Երևան,,,,,,,,,1009,"POLYGON ((44.43620 40.13674, 44.43663 40.13644...",3
4,577407003,,,yes,AM,Երևան,,,,,,,,,1009,"POLYGON ((44.43389 40.13587, 44.43389 40.13539...",4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
82830,1229748417,,,yes,,,,,,,,,,,998,"POLYGON ((44.49528 40.13706, 44.49529 40.13729...",82830
82831,1229748418,,,yes,,,,,,,,,,,998,"POLYGON ((44.49426 40.13564, 44.49441 40.13564...",82831
82832,1229748419,,,yes,,,,,,,,,,,998,"POLYGON ((44.49512 40.13600, 44.49497 40.13600...",82832
82833,1229748420,,,yes,,,,,,,,,,,998,"POLYGON ((44.49495 40.13594, 44.49496 40.13607...",82833
