In [1]:
import os
import sys

current_dir = os.getcwd()

# The notebook is expected to run from a sub-directory of the repository;
# its parent presumably holds the `treemort` package imported in the next cell.
parent_dir = os.path.normpath(os.path.join(current_dir, os.pardir))

# Append only once so re-running this cell does not grow sys.path.
_already_importable = parent_dir in sys.path
if not _already_importable:
    sys.path.append(parent_dir)

In [2]:
import cv2
import json
import torch
import geojson
import rasterio

import numpy as np
import geopandas as gpd

from shapely.geometry import Polygon, shape, mapping

from treemort.utils.config import setup
from treemort.modeling.builder import build_model

  from .autonotebook import tqdm as notebook_tqdm


KeyboardInterrupt: 

In [None]:
def load_and_preprocess_image(tiff_file, nir_rgb_order):
    """Read all bands of a raster, scale them to float32 and reorder the bands.

    Args:
        tiff_file: Path to a raster readable by ``rasterio.open``.
        nir_rgb_order: 0-based band indices used as a numpy fancy index on the
            band axis (axis 0) to select/reorder bands, e.g. ``[3, 2, 1, 0]``.

    Returns:
        (image, transform): ``image`` is a float32 array of shape
        (len(nir_rgb_order), H, W) with values divided by 255; ``transform`` is
        the dataset's pixel-to-world affine transform.
    """
    with rasterio.open(tiff_file) as dataset:
        affine = dataset.transform
        raw_bands = dataset.read()  # (bands, rows, cols)

    # NOTE(review): dividing by 255 assumes 8-bit imagery; 16-bit rasters would
    # not land in [0, 1] — confirm against the training preprocessing.
    scaled = raw_bands.astype(np.float32) / 255.0
    return scaled[nir_rgb_order, :, :], affine


def threshold_prediction_map(prediction_map, threshold=0.5):
    """Binarise a probability map.

    Args:
        prediction_map: numpy array of per-pixel scores.
        threshold: Pixels with score ``>= threshold`` become foreground.

    Returns:
        uint8 array of the same shape: 1 for foreground, 0 otherwise
        (NaN scores compare False and therefore become 0).
    """
    is_foreground = prediction_map >= threshold
    return is_foreground.astype(np.uint8)


def extract_contours(binary_mask):
    """Trace the outer boundaries of the foreground regions in a 2-D mask.

    Args:
        binary_mask: 2-D array; any non-zero pixel is treated as foreground.
            It is converted to a C-contiguous 0/1 uint8 array first, because
            ``cv2.findContours`` with ``RETR_EXTERNAL`` only accepts 8-bit
            single-channel input (bool or wider-int masks would otherwise fail).

    Returns:
        The contours reported by OpenCV, one array of (x, y) pixel points per
        region (OpenCV's Python binding returns them as (N, 1, 2) int arrays).
        Only outermost contours are kept (``RETR_EXTERNAL``), and
        ``CHAIN_APPROX_SIMPLE`` compresses straight runs to their end points.
    """
    mask = (np.asarray(binary_mask) != 0).astype(np.uint8)
    # OpenCV 3.x returns (image, contours, hierarchy) while 4.x returns
    # (contours, hierarchy); indexing from the end works with both.
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    return contours


def apply_transform(contour, transform):
    """Map (x, y) pixel points to world coordinates with an affine transform.

    Args:
        contour: Iterable of (x, y) pairs, e.g. an (N, 2) array of
            (column, row) pixel coordinates.
        transform: Object supporting ``transform * (x, y)`` -> (X, Y), such as
            the rasterio dataset transform.

    Returns:
        numpy array of the transformed points, shape (N, 2) for N >= 1.

    NOTE(review): with a rasterio affine, integer (col, row) maps to the pixel's
    upper-left corner; OpenCV contour points index pixels, so the output may be
    offset by half a pixel — confirm whether that matters for evaluation.
    """
    world_points = []
    for pixel_x, pixel_y in contour:
        world_points.append(transform * (pixel_x, pixel_y))
    return np.array(world_points)


def contours_to_geojson(contours, transform, name, geojson_path):
    """Convert OpenCV contours into a GeoJSON FeatureCollection in world coordinates.

    The first feature's properties and the top-level ``crs`` member are borrowed
    from an existing GeoJSON file so the output matches its schema and CRS.

    Args:
        contours: Sequence of contour arrays, each reshapeable to (N, 2) of
            (x, y) pixel coordinates (as returned by ``cv2.findContours``).
        transform: Pixel-to-world transform, applied via ``apply_transform``.
        name: Value for the collection's ``name`` member.
        geojson_path: Path to the existing (UTF-8) GeoJSON to borrow
            properties and CRS from.

    Returns:
        dict: A FeatureCollection with one Polygon feature per contour that has
        at least 3 points; shorter contours are skipped (and reported).
    """
    # RFC 7946 mandates UTF-8, so don't depend on the platform's default encoding.
    with open(geojson_path, 'r', encoding='utf-8') as f:
        existing_geojson = json.load(f)

    existing_features = existing_geojson.get('features') or []
    # GeoJSON allows "properties": null, so fall back to an empty dict before copying.
    existing_properties = (existing_features[0].get('properties') or {}) if existing_features else {}
    existing_crs = existing_geojson.get('crs', None)  # Top-level CRS member, if present

    new_features = []
    for contour in contours:
        if len(contour) < 3:  # A polygon ring needs at least 3 vertices
            print(f"Skipped contour with {len(contour)} points")
            continue

        ring = apply_transform(contour.reshape(-1, 2), transform)

        # GeoJSON linear rings must be closed (first position == last position).
        if not np.array_equal(ring[0], ring[-1]):
            ring = np.vstack([ring, ring[0]])

        # NOTE(review): OpenCV contours can be self-touching or collinear (zero
        # area), which yields invalid polygons — likely why calculate_iou has to
        # repair geometries with buffer(0). Confirm before filtering here.
        new_features.append({
            "type": "Feature",
            "properties": existing_properties.copy(),  # Independent copy per feature
            "geometry": {
                "type": "Polygon",
                "coordinates": [ring.tolist()],
            },
        })

    return {
        "type": "FeatureCollection",
        "name": name,
        "crs": existing_crs,  # Same CRS member as the source GeoJSON (None if absent)
        "features": new_features,
    }


def save_geojson(geojson_data, output_path):
    """Write ``geojson_data`` to ``output_path`` using ``geojson.dump``.

    The parent directory is created when missing (e.g. the ``predictions``
    folder on a first run), instead of failing with FileNotFoundError.

    Args:
        geojson_data: JSON-serialisable GeoJSON mapping.
        output_path: Destination file path; overwritten if it exists.
    """
    output_dir = os.path.dirname(output_path)
    if output_dir:  # A bare filename has no directory component to create
        os.makedirs(output_dir, exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        geojson.dump(geojson_data, f)

In [None]:
def sliding_window_inference(model, image, window_size=256, stride=128, batch_size=8, threshold=0.5):
    """Run tiled inference over an image and average the overlapping tiles.

    The image is zero-padded (bottom/right) to a multiple of ``window_size``,
    tiles are predicted in batches via ``process_batch``, and each pixel's
    scores are averaged over all tiles that covered it.

    Args:
        model: torch module whose parameters determine the device; called on a
            (B, C, window_size, window_size) float tensor by ``process_batch``.
        image: numpy array of shape (C, H, W).
        window_size: Side length of the square tiles, in pixels.
        stride: Step between tile origins. A tile is also anchored at the last
            full position along each axis, so any stride in [1, window_size]
            covers every pixel even when it does not divide the padded size.
        batch_size: Number of tiles per forward pass.
        threshold: Unused; kept for backward compatibility (binarisation is
            done by ``threshold_prediction_map``).

    Returns:
        float32 array of shape (H, W) — cropped back to the input size so it
        lines up with the raster transform (padding is never returned).
    """
    model.eval()

    device = next(model.parameters()).device  # Run inputs on the model's device

    _, orig_h, orig_w = image.shape
    padded_image = pad_image(image, window_size)

    _, h, w = padded_image.shape
    prediction_map = np.zeros((h, w), dtype=np.float32)
    count_map = np.zeros((h, w), dtype=np.float32)

    patches = []
    coords = []

    for y in _window_origins(h, window_size, stride):
        for x in _window_origins(w, window_size, stride):
            patches.append(padded_image[:, y:y + window_size, x:x + window_size])
            coords.append((y, x))

            if len(patches) == batch_size:
                prediction_map, count_map = process_batch(patches, coords, prediction_map, count_map, model, device)
                patches = []
                coords = []

    if patches:  # Flush the final, partially filled batch
        prediction_map, count_map = process_batch(patches, coords, prediction_map, count_map, model, device)

    count_map[count_map == 0] = 1  # Avoid division by zero for uncovered pixels
    prediction_map /= count_map

    # Drop the zero padding: predictions there lie outside the raster extent and
    # would otherwise produce contours beyond the image footprint.
    return np.ascontiguousarray(prediction_map[:orig_h, :orig_w])


def _window_origins(length, window_size, stride):
    """Return tile start offsets along one axis, always including the last full window."""
    if length < window_size:
        return []
    origins = list(range(0, length - window_size + 1, stride))
    if origins[-1] != length - window_size:
        origins.append(length - window_size)
    return origins


def process_batch(patches, coords, prediction_map, count_map, model, device):
    """Predict a batch of tiles and accumulate the results into the running maps.

    Args:
        patches: List of equally shaped (C, h, w) numpy tiles.
        coords: List of (row, col) top-left offsets, parallel to ``patches``.
        prediction_map: 2-D float array of summed scores; updated in place.
        count_map: 2-D float array of per-pixel tile counts; updated in place.
        model: Callable mapping a (B, C, h, w) tensor to logits; after sigmoid
            and ``squeeze(1)`` each sample must be 2-D (assumes a single output
            channel — TODO confirm for multi-class heads).
        device: torch device the batch is moved to.

    Returns:
        (prediction_map, count_map): the same array objects that were passed in.
    """
    batch = torch.from_numpy(np.array(patches)).float().to(device)

    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.sigmoid(logits).squeeze(1).cpu().numpy()

    for idx, (row, col) in enumerate(coords):
        tile_scores = probabilities[idx]
        tile_h, tile_w = tile_scores.shape[0], tile_scores.shape[1]
        prediction_map[row:row + tile_h, col:col + tile_w] += tile_scores
        count_map[row:row + tile_h, col:col + tile_w] += 1

    return prediction_map, count_map


def pad_image(image, window_size):
    """Zero-pad a (C, H, W) array on the bottom and right edges.

    H and W are each grown to the next multiple of ``window_size``; dimensions
    that are already multiples are left unchanged. The channel axis is never padded.

    Returns:
        A new numpy array of shape (C, H', W').
    """
    _, height, width = image.shape

    # -n % k is the distance from n up to the next multiple of k (0 if aligned).
    extra_rows = -height % window_size
    extra_cols = -width % window_size

    pad_spec = [(0, 0), (0, extra_rows), (0, extra_cols)]
    return np.pad(image, pad_spec, mode='constant', constant_values=0)

In [None]:
def process_image(model, image_path, window_size=256, stride=128, threshold=0.5, nir_rgb_order = [3, 2, 1, 0]):
    """Predict on one raster and save the result as a GeoJSON next to the dataset.

    Paths are derived from the image location:
    ``<root>/<subdir>/<stem>.tiff`` -> ground truth ``<root>/Geojsons/<stem>.geojson``
    (read for properties/CRS) and output ``<root>/predictions/<stem>.geojson``.

    Args:
        model: Segmentation model passed to ``sliding_window_inference``.
        image_path: Path to the input raster.
        window_size: Tile size for sliding-window inference.
        stride: Step between tiles.
        threshold: Probability threshold for the binary mask.
        nir_rgb_order: Band indices passed to ``load_and_preprocess_image``.
            The default list is only read, never mutated.

    Returns:
        str: Path of the written predictions GeoJSON.
    """
    image_name = os.path.splitext(os.path.basename(image_path))[0]
    dataset_root = os.path.dirname(os.path.dirname(image_path))

    geojson_path = os.path.join(dataset_root, "Geojsons", image_name + ".geojson")
    predictions_path = os.path.join(dataset_root, "predictions", image_name + ".geojson")

    print(f"[INFO] Starting process for image: {image_path}")

    # Forward the caller's arguments instead of the previously hard-coded values.
    image, transform = load_and_preprocess_image(image_path, nir_rgb_order=nir_rgb_order)
    print(f"[INFO] Image loaded and preprocessed. Shape: {image.shape}, Transform: {transform}")

    prediction_map = sliding_window_inference(model, image, window_size=window_size, stride=stride)
    print(f"[INFO] Prediction map generated with shape: {prediction_map.shape}")

    binary_mask = threshold_prediction_map(prediction_map, threshold)
    print(f"[INFO] Binary mask created with threshold: {threshold}. Mask shape: {binary_mask.shape}")

    contours = extract_contours(binary_mask)
    print(f"[INFO] {len(contours)} contours extracted from binary mask")

    geojson_data = contours_to_geojson(contours, transform, image_name, geojson_path)
    print("[INFO] Contours converted to GeoJSON format")

    save_geojson(geojson_data, predictions_path)
    print(f"[INFO] GeoJSON saved to {predictions_path}")

    return predictions_path

In [None]:
def load_model(
    config_path="../configs/sa_unet_bs8_cs256.txt",
    checkpoint_path=None,
    output_root="/Users/anisr/Documents/TreeSeg/output",
):
    """Build the segmentation model from a config file and load trained weights.

    Args:
        config_path: Config file passed to ``setup``.
        checkpoint_path: State-dict file to load. When None it is resolved as
            ``f"{output_root}/{conf.model}/best.weights.pth"`` (the previous
            hard-coded location when ``output_root`` keeps its default).
        output_root: Training output directory used only to resolve the default
            ``checkpoint_path``.

    Returns:
        The model on CUDA if available, otherwise CPU, with weights loaded.
    """
    id2label = {0: "alive", 1: "dead"}

    print("[INFO] Loading model configuration...")
    conf = setup(config_path)
    print("[INFO] Model configuration are loaded.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[INFO] Using device: {device}")

    print("[INFO] Loading or resuming model...")
    # Only the model is needed for inference; the training objects are discarded.
    model, _optimizer, _criterion, _metrics = build_model(conf, id2label, device)
    print("[INFO] Model, optimizer, criterion, and metrics are set up.")

    model = model.to(device)

    if checkpoint_path is None:
        checkpoint_path = f"{output_root}/{conf.model}/best.weights.pth"

    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints;
    # consider weights_only=True if the installed torch supports it.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    print(f"[INFO] Loaded weights from {checkpoint_path}.")

    return model

In [None]:
def calculate_iou(geojson_path, predictions_path):
    """Area IoU between the dissolved ground-truth and predicted geometries.

    Each file is dissolved into a single geometry (``union_all``); the score is
    area(intersection) / area(union), with areas in the ground-truth CRS units.

    Args:
        geojson_path: Ground-truth vector file.
        predictions_path: Predicted vector file.

    Returns:
        float: IoU in [0, 1]; 0.0 when the union has zero area (e.g. both empty).
    """
    true_gdf = _load_repaired_geometries(geojson_path)
    pred_gdf = _load_repaired_geometries(predictions_path)

    # Areas are only comparable in a common CRS; reproject predictions if needed.
    if true_gdf.crs is not None and pred_gdf.crs is not None and pred_gdf.crs != true_gdf.crs:
        pred_gdf = pred_gdf.to_crs(true_gdf.crs)

    true_union = true_gdf.geometry.union_all()
    pred_union = pred_gdf.geometry.union_all()

    union_area = true_union.union(pred_union).area
    if union_area == 0:
        return 0.0

    intersection_area = true_union.intersection(pred_union).area
    return intersection_area / union_area


def _load_repaired_geometries(path):
    """Read a vector file, drop null/empty geometries, and repair invalid ones with buffer(0)."""
    gdf = gpd.read_file(path)
    # Null geometries would crash the `.is_valid` check below.
    gdf = gdf[gdf.geometry.notna() & ~gdf.geometry.is_empty].copy()

    # TODO. find the root cause
    # Fix invalid geometries
    gdf['geometry'] = gdf['geometry'].apply(lambda geom: geom.buffer(0) if not geom.is_valid else geom)
    return gdf

In [None]:
if __name__ == "__main__":

    # NOTE(review): absolute local path — adjust for your machine.
    image_path = "/Users/anisr/Documents/AerialImages/4band_25cm/M4424E_4_1.tiff"

    model = load_model()

    predictions_path = process_image(model, image_path)

    # Ground truth: <dataset root>/Geojsons/<image stem>.geojson
    dataset_root = os.path.dirname(os.path.dirname(image_path))
    image_stem = os.path.splitext(os.path.basename(image_path))[0]
    geojson_path = os.path.join(dataset_root, "Geojsons", image_stem + ".geojson")

    iou_score = calculate_iou(geojson_path, predictions_path)
    print(f"[INFO] IoU Score: {iou_score:.4f}")