In [None]:
import os
import sys

# Resolve the directory one level above the notebook's working directory and
# make it importable (presumably the repository root that holds `treemort`).
current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir, ".."))

# Append only when absent so re-running the cell does not pile up duplicates.
if sys.path.count(parent_dir) == 0:
    sys.path.append(parent_dir)

In [None]:
from treemort.utils.config import setup

# Build the experiment configuration from the config file (path is relative to
# the notebook's working directory).
config_file_path = "../configs/sa_unet_bs8_cs256.txt"
conf = setup(config_file_path)

# Local-execution overrides; comment these two assignments out on HPC.
# NOTE(review): the data folder is a machine-specific absolute path.
conf.data_folder = "/Users/anisr/Documents/AerialImages"
# Re-root the configured output directory one level up from the working directory.
conf.output_dir = os.path.join("..", conf.output_dir)

print(conf)

In [None]:
import cv2
import torch
import geojson
import rasterio

import numpy as np

from tqdm import tqdm
from shapely.geometry import Polygon
from concurrent.futures import ThreadPoolExecutor

In [None]:
# Class-index -> label mapping handed to the model builder.
id2label = {0: "alive", 1: "dead"}

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {device}")

from treemort.modeling.builder import build_model

print("[INFO] Loading or resuming model...")
# build_model returns four objects; only `model` is used in the cells visible here.
model, optimizer, criterion, metrics = build_model(conf, id2label, device)
print(f"[INFO] Model, optimizer, criterion, and metrics are set up.")

model = model.to(device)

# NOTE(review): machine-specific absolute path; only the model name comes from the config.
checkpoint_path = f"/Users/anisr/Documents/TreeSeg/output/{conf.model}/best.weights.pth"

# The file is passed straight to load_state_dict, so it is expected to hold a state dict.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted
# source (consider weights_only=True if the installed torch version supports it).
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
print(f"[INFO] Loaded weights from {checkpoint_path}.")

In [None]:
import json
import numpy as np

from tqdm import tqdm
from shapely.geometry import Polygon, shape, mapping


def load_and_preprocess_image(tiff_file):
    """Read a raster file and return it as a scaled, channel-last float array.

    Args:
        tiff_file: Path of the raster to open with ``rasterio``.

    Returns:
        tuple: ``(image, transform)`` where ``image`` is the band stack cast to
        ``float32``, divided by 255 and transposed from (C, H, W) to (H, W, C),
        and ``transform`` is the dataset's ``transform`` attribute.
    """
    with rasterio.open(tiff_file) as dataset:
        bands = dataset.read()
        affine = dataset.transform

    # Divide by 255 — assumes 8-bit pixel values; TODO confirm for every input raster.
    scaled = bands.astype(np.float32) / 255.0

    # Channel-first (C, H, W) -> channel-last (H, W, C).
    channel_last = np.transpose(scaled, (1, 2, 0))
    return channel_last, affine


def threshold_prediction_map(prediction_map, threshold=0.5):
    """Binarise a prediction map.

    Args:
        prediction_map: NumPy array of scores.
        threshold: Scores greater than or equal to this value become 1.

    Returns:
        ``uint8`` array of the same shape holding 1 where
        ``prediction_map >= threshold`` and 0 elsewhere.
    """
    at_or_above = prediction_map >= threshold
    return at_or_above.astype(np.uint8)


def extract_contours(binary_mask):
    """Return the external contours that OpenCV finds in ``binary_mask``.

    Uses ``cv2.RETR_EXTERNAL`` (outermost contours only) with
    ``cv2.CHAIN_APPROX_SIMPLE`` point compression. The hierarchy returned by
    ``cv2.findContours`` is discarded.
    """
    found = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # Two-value unpacking matches the (contours, hierarchy) return signature.
    contours, _hierarchy = found
    return contours


def apply_transform(contour, transform):
    """Map every ``(x, y)`` point of ``contour`` through ``transform``.

    Args:
        contour: Iterable of 2-element points.
        transform: Object supporting ``transform * (x, y)``; in this notebook
            it is the ``transform`` attribute of the opened raster.

    Returns:
        NumPy array built from the transformed points, in input order.
    """
    mapped_points = []
    for col, row in contour:
        mapped_points.append(transform * (col, row))
    return np.array(mapped_points)


def contours_to_geojson(contours, transform, name, geojson_path):
    """Convert pixel-space contours into a GeoJSON FeatureCollection dict.

    The reference GeoJSON at ``geojson_path`` is read only to borrow metadata:
    the properties of its first feature are copied onto every new feature and
    its top-level ``crs`` member is reused.

    Args:
        contours: Iterable of OpenCV-style contour arrays; each is reshaped to
            ``(N, 2)``. Contours with fewer than 3 points are skipped with a
            printed message.
        transform: Pixel-to-world transform forwarded to ``apply_transform``.
        name: Value stored under the collection's ``"name"`` key.
        geojson_path: Path of the reference GeoJSON file.

    Returns:
        dict with keys ``"type"``, ``"name"``, ``"crs"`` and ``"features"``;
        every feature is a Polygon with one explicitly closed ring.

    Raises:
        OSError: if ``geojson_path`` cannot be opened.
        json.JSONDecodeError: if the reference file is not valid JSON.
    """
    with open(geojson_path, 'r', encoding='utf-8') as f:
        existing_geojson = json.load(f)

    # Tolerate a reference file with a missing/empty "features" member or a
    # first feature whose "properties" is null; both used to raise.
    existing_features = existing_geojson.get('features') or []
    existing_properties = (existing_features[0].get('properties') or {}) if existing_features else {}
    existing_crs = existing_geojson.get('crs', None)  # None when the reference declares no CRS

    new_features = []
    for contour in contours:
        if len(contour) < 3:  # a polygon ring needs at least 3 distinct points
            print(f"Skipped contour with {len(contour)} points")
            continue

        ring = apply_transform(contour.reshape(-1, 2), transform)

        # GeoJSON linear rings must end on their starting point.
        if not np.array_equal(ring[0], ring[-1]):
            ring = np.vstack([ring, ring[0]])

        new_features.append({
            "type": "Feature",
            "properties": existing_properties.copy(),  # same properties on every feature
            "geometry": {
                "type": "Polygon",
                "coordinates": [ring.tolist()]
            }
        })

    return {
        "type": "FeatureCollection",
        "name": name,
        "crs": existing_crs,  # same CRS as the reference GeoJSON
        "features": new_features
    }


def save_geojson(geojson_data, output_path):
    """Serialise ``geojson_data`` to ``output_path`` using ``geojson.dump``.

    The file is opened in write mode, so an existing file is overwritten.
    """
    with open(output_path, 'w') as out_file:
        geojson.dump(geojson_data, out_file)

In [None]:
import numpy as np
import torch

def sliding_window_inference(model, image, window_size=256, stride=128, device='cuda', batch_size=8, threshold=0.5):
    """Predict a per-pixel probability map by averaging overlapping windows.

    The image is zero-padded with ``pad_image``, cut into ``window_size``
    squares every ``stride`` pixels, and the windows are sent through
    ``process_batch`` in groups of ``batch_size``. ``process_batch`` adds each
    window's sigmoid output into the running map, so pixels covered by several
    windows accumulate several probabilities; the sum is divided by the
    per-pixel window count before it is returned. Without that normalisation
    the values can exceed 1 and a later 0.5 threshold is not meaningful.

    Args:
        model: Segmentation model; switched to eval mode here.
        image: Array whose first two axes are height and width.
        window_size: Side length of each square window.
        stride: Step between window origins. The padding only rounds the image
            up to a multiple of ``window_size``, so a stride that does not
            divide ``window_size`` can leave the bottom/right margin uncovered.
        device: Device forwarded to ``process_batch``.
        batch_size: Number of windows per forward pass.
        threshold: Forwarded to ``process_batch``; there it only affects the
            above-threshold vote map, not the returned probabilities.

    Returns:
        ``float32`` array of mean probabilities, cropped to the height and
        width of ``image``. Pixels not covered by any window stay 0.
    """
    model.eval()

    padded_image = pad_image(image, window_size)

    h, w = padded_image.shape[:2]
    prediction_map = np.zeros((h, w), dtype=np.float32)
    # Above-threshold vote counts maintained by process_batch (part of its signature).
    confidence_map = np.zeros((h, w), dtype=np.float32)
    # Number of windows touching each pixel: the denominator for the average.
    # Assumes the model output has the same spatial size as the input window.
    coverage_map = np.zeros((h, w), dtype=np.float32)

    patches = []
    coords = []

    for y in range(0, h - window_size + 1, stride):
        for x in range(0, w - window_size + 1, stride):
            patches.append(padded_image[y:y + window_size, x:x + window_size])
            coords.append((y, x))
            coverage_map[y:y + window_size, x:x + window_size] += 1

            if len(patches) == batch_size:
                prediction_map, confidence_map = process_batch(patches, coords, prediction_map, confidence_map, model, device, threshold)
                patches = []
                coords = []

    # Flush the final, possibly smaller, batch.
    if patches:
        prediction_map, confidence_map = process_batch(patches, coords, prediction_map, confidence_map, model, device, threshold)

    # Sum of probabilities -> mean probability; clamp the divisor so uncovered
    # pixels (coverage 0) are left at 0 instead of dividing by zero.
    prediction_map /= np.maximum(coverage_map, 1.0)

    # Drop the padding again.
    prediction_map = prediction_map[:image.shape[0], :image.shape[1]]

    return prediction_map

def process_batch(patches, coords, prediction_map, confidence_map, model, device, threshold=0.5):
    """Run the model on one batch of patches and accumulate the results in place.

    Args:
        patches: List of equally sized channel-last patches (H, W, C).
        coords: List of ``(y, x)`` top-left positions, one per patch.
        prediction_map: 2-D array; each patch's sigmoid output is added at its position.
        confidence_map: 2-D array; incremented by 1 wherever a patch's sigmoid
            output is strictly greater than ``threshold``.
        model: Callable applied to the (B, C, H, W) float tensor; axis 1 of its
            output is squeezed, so a single output channel is expected.
        device: Device the batch tensor is moved to.
        threshold: Cut-off used for ``confidence_map`` only.

    Returns:
        tuple: the same ``(prediction_map, confidence_map)`` objects, updated.
    """
    # Stack to (B, H, W, C), then reorder to the channel-first layout.
    stacked = np.array(patches)
    batch_tensor = torch.from_numpy(stacked).permute(0, 3, 1, 2).float().to(device)

    with torch.no_grad():
        probabilities = torch.sigmoid(model(batch_tensor))
        predictions = probabilities.squeeze(1).cpu().numpy()

    for index, (top, left) in enumerate(coords):
        confidence = predictions[index]
        rows, cols = confidence.shape[0], confidence.shape[1]
        window = (slice(top, top + rows), slice(left, left + cols))

        prediction_map[window] += confidence
        confidence_map[window] += (confidence > threshold).astype(np.float32)

    return prediction_map, confidence_map

def pad_image(image, window_size):
    """Zero-pad the bottom and right edges up to a multiple of ``window_size``.

    Args:
        image: Array whose first two axes are height and width. Any trailing
            axes (e.g. channels) are left unpadded, so both (H, W) and
            (H, W, C) inputs are accepted.
        window_size: Positive tile size the padded height and width must be
            divisible by.

    Returns:
        Padded copy of ``image``; the shape is unchanged when both dimensions
        are already multiples of ``window_size``.
    """
    h, w = image.shape[:2]
    # The outer modulo turns a "full window" of padding into zero padding.
    pad_h = (window_size - h % window_size) % window_size
    pad_w = (window_size - w % window_size) % window_size

    # Pad only the two spatial axes; one (0, 0) entry per remaining axis.
    pad_width = [(0, pad_h), (0, pad_w)] + [(0, 0)] * (image.ndim - 2)
    padded_image = np.pad(image, pad_width, mode='constant', constant_values=0)
    return padded_image


In [None]:
import os

def process_image(image_path, window_size=256, stride=128, threshold=0.5):
    """Run the full prediction pipeline for one image and save it as GeoJSON.

    Steps: load the raster, run sliding-window inference with the module-level
    ``model`` and ``device``, threshold the prediction map, extract contours,
    convert them to GeoJSON and write the result. Paths are derived from the
    grandparent directory of ``image_path``: the reference file is read from
    ``<root>/Geojsons/<name>.geojson`` and the output is written to
    ``<root>/predictions/<name>.geojson``.

    Args:
        image_path: Path of the raster to process.
        window_size: Window side length passed to ``sliding_window_inference``.
        stride: Window step passed to ``sliding_window_inference``.
        threshold: Cut-off applied to the prediction map by
            ``threshold_prediction_map``. It is not forwarded to
            ``sliding_window_inference``, which keeps its own default.

    Raises:
        FileNotFoundError: if the reference GeoJSON does not exist. This is
            checked before inference so no compute is wasted.
    """
    image_name = os.path.splitext(os.path.basename(image_path))[0]

    dataset_root = os.path.dirname(os.path.dirname(image_path))
    geojson_path = os.path.join(dataset_root, "Geojsons", image_name + ".geojson")
    predictions_dir = os.path.join(dataset_root, "predictions")
    predictions_path = os.path.join(predictions_dir, image_name + ".geojson")

    print(f"[INFO] Starting process for image: {image_path}")

    # contours_to_geojson needs this file; fail now rather than after inference.
    if not os.path.isfile(geojson_path):
        raise FileNotFoundError(f"Reference GeoJSON not found: {geojson_path}")

    image, transform = load_and_preprocess_image(image_path)
    print(f"[INFO] Image loaded and preprocessed. Shape: {image.shape}, Transform: {transform}")

    prediction_map = sliding_window_inference(model, image, window_size, stride, device)
    print(f"[INFO] Prediction map generated with shape: {prediction_map.shape}")

    binary_mask = threshold_prediction_map(prediction_map, threshold)
    print(f"[INFO] Binary mask created with threshold: {threshold}. Mask shape: {binary_mask.shape}")

    contours = extract_contours(binary_mask)
    print(f"[INFO] {len(contours)} contours extracted from binary mask")

    geojson_data = contours_to_geojson(contours, transform, image_name, geojson_path)
    print(f"[INFO] Contours converted to GeoJSON format")

    # The output folder may not exist yet on a fresh dataset checkout.
    os.makedirs(predictions_dir, exist_ok=True)
    save_geojson(geojson_data, predictions_path)
    print(f"[INFO] GeoJSON saved to {predictions_path}")


# Single-image run of the pipeline.
# NOTE(review): machine-specific absolute path; `image_path` stays in the kernel
# namespace and is read again by later cells.
image_path = "/Users/anisr/Documents/AerialImages/4band_25cm/M4424E_4_1.tiff"

process_image(image_path)

In [None]:
# Manual run of loading + inference only (no thresholding or export) so the raw
# prediction map can be inspected in the following cell.
window_size = 256
stride = 128

# NOTE(review): machine-specific absolute path; overwrites the global `image_path`.
image_path = "/Users/anisr/Documents/AerialImages/4band_25cm/63223_2.tif"

image, transform = load_and_preprocess_image(image_path)

prediction_map = sliding_window_inference(model, image, window_size, stride, device)


In [None]:
# Display the sigmoid of the prediction map.
# NOTE(review): process_batch already applies torch.sigmoid to the model output,
# so this squashes the values a second time — presumably exploratory; confirm.
torch.sigmoid(torch.tensor(prediction_map))

In [None]:
# Load one named tile and run sliding-window inference on it.
window_size = 256
stride = 128

filename = "N4212G_2013_1.tiff"

# NOTE(review): machine-specific absolute path.
data_folder = "/Users/anisr/Documents/AerialImages"

tiff_path = os.path.join(data_folder, "4band_25cm", filename)
geojson_path = os.path.join(data_folder, "Geojsons", os.path.splitext(filename)[0] + ".geojson")
predictions_path = os.path.join(data_folder, "predictions", os.path.splitext(filename)[0] + ".geojson")

# Load the tile selected above; the previous code read the stale `image_path`
# left over from an earlier cell, i.e. a different image.
image, transform = load_and_preprocess_image(tiff_path)
print(f"[INFO] Image loaded and preprocessed. Shape: {image.shape}, Transform: {transform}")

# sliding_window_inference, as defined above this cell, returns a single array,
# so it must not be unpacked into two names.
prediction_map = sliding_window_inference(model, image, window_size, stride, device)
print(f"[INFO] Prediction map generated with shape: {prediction_map.shape}")

In [None]:
import json
import geopandas as gpd
from shapely.geometry import shape
from shapely.ops import unary_union

def _collect_geometries(feature_collection):
    """Return shapely geometries for every feature that carries a geometry."""
    return [
        shape(feature["geometry"])
        for feature in feature_collection.get("features") or []
        if feature.get("geometry")
    ]


def calculate_iou(true_geojson, pred_geojson):
    """Area-based intersection over union between two GeoJSON feature sets.

    Each side is first dissolved into a single region with ``unary_union``, so
    overlapping polygons are counted once and polygons without a counterpart
    on the other side (missed or spurious detections) enlarge the union and
    lower the score. Summing intersection/union areas over intersecting pairs
    only, as done previously, ignored unmatched polygons entirely and counted
    a polygon once per partner it overlapped.

    Args:
        true_geojson: Parsed GeoJSON mapping with the reference features.
        pred_geojson: Parsed GeoJSON mapping with the predicted features.

    Returns:
        float: ``area(true ∩ pred) / area(true ∪ pred)``; ``0.0`` when either
        side has no geometries or the union has zero area.
    """
    true_geometries = _collect_geometries(true_geojson)
    pred_geometries = _collect_geometries(pred_geojson)

    # Nothing can intersect when one side is empty.
    if not true_geometries or not pred_geometries:
        return 0.0

    true_region = unary_union(true_geometries)
    pred_region = unary_union(pred_geometries)

    union_area = true_region.union(pred_region).area
    if union_area == 0:
        return 0.0

    return true_region.intersection(pred_region).area / union_area

# Evaluate one tile: compare the saved prediction GeoJSON with its reference.
# In a notebook kernel __name__ is "__main__", so this block runs with the cell.
if __name__ == "__main__":

    filename = "M4424E_4_1.tiff"

    # NOTE(review): machine-specific absolute path.
    data_folder = "/Users/anisr/Documents/AerialImages"

    # tiff_path is built for symmetry but not read in this block.
    tiff_path = os.path.join(data_folder, "4band_25cm", filename)
    geojson_path = os.path.join(data_folder, "Geojsons", os.path.splitext(filename)[0] + ".geojson")
    predictions_path = os.path.join(data_folder, "predictions", os.path.splitext(filename)[0] + ".geojson")

    with open(geojson_path) as f:
        true_geojson = json.load(f)
    
    with open(predictions_path) as f:
        pred_geojson = json.load(f)
    
    iou_score = calculate_iou(true_geojson, pred_geojson)
    print(f"IoU Score: {iou_score:.4f}")

In [None]:
# File names (not full paths) of the tiles to run through the pipeline; each is
# joined onto <conf.data_folder>/4band_25cm below.
image_paths = [
    "P5322A_2017_1.tif",
    "L5242G_2017_1.tif",
    "M5221F_2016_1.tiff",
    "L3343D_2019_1.tif",
    "M-34-56-B-d-1-2_1.tiff",
    "L3344B_2019_1.tif",
    "U5224D_1.tif",
    "N5442C_2014_1.tiff",
    "N4212G_2013_1.tiff",
    "L3211A_1.tif",
    "V4331B_2018_1.tif",
    "L3433D_2019_1.tif",
    "P4341G_1.tif",
    "M4123D_2015_1.tiff",
    "M4211G_2023_2.tif",
    "63223_3.tif",
    "L4411F_tile_3_band_1multiband_tile_3.tif",
    "Q4211E_2019_1.tif",
    "N5132F_1.tif",
    "63471_4_1.tiff",
    "P5322F_2_1.tiff",
    "M-34-105-B-b-3-1_1.tiff",
    "V4331A_2018_1.tif",
    "M3422E_2016_1.tiff",
    "M-34-105-B-b-2-4_1.tiff",
]

# NOTE(review): the loop variable overwrites the global `image_path` that other
# cells read; after the loop it holds the last bare file name.
for image_path in image_paths:
    process_image(os.path.join(conf.data_folder, "4band_25cm", image_path))

In [None]:
# NOTE(review): this repeats the load_and_preprocess_image definition from an
# earlier cell line for line; keep a single definition to avoid silent shadowing.
def load_and_preprocess_image(tiff_file):
    """Read a raster with rasterio and return ``(image, transform)``.

    ``image`` is the band stack cast to float32, divided by 255 and transposed
    to (H, W, C); ``transform`` is the dataset's ``transform`` attribute.
    """
    with rasterio.open(tiff_file) as src:
        image = src.read()
        # Divide by 255 — assumes 8-bit pixel values; TODO confirm.
        image = image.astype(np.float32) / 255.0
        transform = src.transform

    image = np.transpose(image, (1, 2, 0))  # From (C, H, W) to (H, W, C)

    return image, transform


# NOTE(review): this redefinition shadows the earlier sliding_window_inference
# and returns a 2-tuple instead of a single array, so callers written against
# the earlier version (e.g. process_image) break once this cell has run.
def sliding_window_inference(model, image, window_size, stride, device, batch_size=8):
    """Debug variant: enumerate the sliding windows without running inference.

    The image is padded with ``pad_image`` and every window is collected as a
    ``(y, x, image_patch)`` tuple while a tqdm bar tracks progress. The calls
    that would feed the patches to the model are commented out, so
    ``prediction_map`` and ``count_map`` are never written to and the returned
    map is all zeros. ``device`` and ``batch_size`` are therefore unused.

    Returns:
        tuple: ``(prediction_map, patches)`` — the map cropped to the input's
        height and width, and the list of collected window tuples.
    """
    model.eval()

    padded_image = pad_image(image, window_size)

    h, w = padded_image.shape[:2]
    prediction_map = np.zeros((h, w), dtype=np.float32)
    count_map = np.zeros((h, w), dtype=np.float32)

    patches = []

    # Number of window positions along each axis, multiplied together.
    total_patches = ((h - window_size) // stride + 1) * ((w - window_size) // stride + 1)
    with tqdm(total=total_patches, desc="Processing patches") as pbar:
        
        for y in range(0, h - window_size + 1, stride):
            for x in range(0, w - window_size + 1, stride):
                image_patch = padded_image[y:y + window_size, x:x + window_size]

                patches.append((y, x, image_patch))

                #if len(patches) == batch_size:
                #    process_batch(model, patches, prediction_map, count_map, device)
                #    patches = []  # Clear the list for the next batch

                pbar.update(1)

        # if patches:
        #    process_batch(model, patches, prediction_map, count_map, device)

    # With inference disabled count_map is all zeros, so this is 0/0 everywhere;
    # the warnings are silenced and the resulting NaNs are reset to 0.
    with np.errstate(divide='ignore', invalid='ignore'):
        prediction_map /= count_map
        prediction_map[count_map == 0] = 0  # Handle divisions by zero

    # NOTE(review): this tuple check runs after `image` was already passed to
    # pad_image above, so a tuple input would have failed earlier.
    if isinstance(image, tuple):
        image = image[0]

    prediction_map = prediction_map[:image.shape[0], :image.shape[1]]

    return prediction_map, patches

# Collect the window patches of a demo tile; the returned map is discarded.
# NOTE(review): machine-specific absolute path.
image_path = "/Users/anisr/Documents/TreeSeg/demo/files/M4124C_2017_1.tiff"
image, transform = load_and_preprocess_image(image_path)
_, patches = sliding_window_inference(model, image, 256, 128, device, 8)

In [None]:
# Turn the first 8 collected patches into one batch. Element [2] of each entry
# is the image patch; it is permuted to channel-first and given a batch axis.
# Raises IndexError if fewer than 8 patches were collected.
batch_patches = [torch.from_numpy(patches[i][2]).permute(2, 0, 1).unsqueeze(0).float().to(device) for i in range(8)]
batch_patches_tensor = torch.cat(batch_patches, dim=0)  # Create batch tensor

print(batch_patches_tensor.shape)

# Forward pass in eval mode without gradient tracking.
model.eval()
with torch.no_grad():
    outputs = model(batch_patches_tensor)

In [None]:
# Not the last expression of the cell, so the notebook does not display it.
outputs.shape

# Scratch check of the slice bounds for the window at grid position (y, x).
# NOTE(review): overwrites the globals x, y, window_size and stride.
x = 2
y = 1

window_size = 256
stride = 128

# (row_start, row_end, col_start, col_end) using the batch's spatial size.
y*stride, y*stride + batch_patches_tensor.shape[2], x*stride, x*stride + batch_patches_tensor.shape[3]

In [None]:
# Scratch arithmetic left from debugging; the notebook displays the result (256).
384 - 128

In [None]:


    # NOTE(review): these lines are indented but no enclosing `def` is present,
    # so the cell cannot be parsed as written. The call at the bottom suggests
    # this was the body of process_batch(model, patches, prediction_map,
    # count_map, device) — confirm and restore the header.
    # The body reads `batch_patches_tensor`, which is not one of those arguments.
    with torch.no_grad():
        outputs = model(batch_patches_tensor)

    # Add each output into the running map at its patch position and count the visit.
    # zip stops at the shorter of `patches` and `outputs`.
    for (y, x, _), output in zip(patches, outputs):
        prediction = output.squeeze(0).squeeze(0).cpu().numpy()
        prediction_map[y:y + batch_patches_tensor.shape[2], x:x + batch_patches_tensor.shape[3]] += prediction
        count_map[y:y + batch_patches_tensor.shape[2], x:x + batch_patches_tensor.shape[3]] += 1



# NOTE(review): this five-argument call does not match the process_batch defined
# earlier in the notebook (patches, coords, prediction_map, confidence_map,
# model, device, threshold).
process_batch(model, patches, prediction_map, count_map, device)

In [None]:
import numpy as np
import cv2
from tqdm import tqdm

def calculate_iou_from_topo(true_topo, pred_topo, threshold=128):
    """Intersection over union of two maps after binarising them.

    Args:
        true_topo: NumPy array with the reference map.
        pred_topo: NumPy array with the predicted map.
        threshold: Values greater than or equal to this count as foreground.

    Returns:
        Foreground intersection count divided by foreground union count, or
        ``0.0`` when neither map has any foreground.
    """
    # Foreground masks for both maps.
    true_fg = (true_topo >= threshold).astype(np.uint8)
    pred_fg = (pred_topo >= threshold).astype(np.uint8)

    overlap = np.logical_and(true_fg, pred_fg).sum()
    combined = np.logical_or(true_fg, pred_fg).sum()

    # An empty union means there is nothing to compare.
    if combined == 0:
        return 0.0
    return overlap / combined

# Example usage:
# Example usage.
# NOTE(review): in a notebook kernel __name__ is "__main__", so this block runs
# with the cell. segmap_to_topo, true_image_np, true_contours, pred_image_np and
# pred_contours are not defined anywhere in the visible notebook — they must be
# provided before this cell can run.
if __name__ == "__main__":
    # Generate segmentation maps (replace contours with actual data)
    true_topo = segmap_to_topo(true_image_np, true_contours)
    pred_topo = segmap_to_topo(pred_image_np, pred_contours)

    # Calculate IoU score
    iou_score = calculate_iou_from_topo(true_topo, pred_topo)
    print(f"IoU Score: {iou_score:.4f}")
