In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import json
import os
from PIL import Image
from torchvision import transforms

# Image preprocessing pipeline: resize to 224x224, then convert to a tensor
# with torchvision's ToTensor().
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Mean/std normalization (the commonly used ImageNet statistics) is
    # currently commented out, so tensors keep ToTensor()'s value range.
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

class BuildingDataset(Dataset):
    """Dataset pairing JSON building records with their images.

    The JSON file is expected to hold a sequence of records (dicts). Each
    record's ``"fixed_pitch_image"`` entry is treated as an image path
    relative to ``image_root``.
    """

    def __init__(self, json_path, transform=None, image_root=""):
        """
        Read JSON data, parse building metadata, and provide DataLoader access.

        Args:
            json_path: Path to the UTF-8 encoded JSON file with the records.
            transform: Optional callable applied to each loaded PIL image.
            image_root: Directory prepended to every record's image path.
                Defaults to ``""`` (paths are used as stored in the JSON).

        Raises:
            TypeError: If ``transform`` is given but is not callable.
        """
        # Fail fast: checking this per item inside the best-effort loader
        # would silently turn every sample into a black placeholder.
        if transform is not None and not callable(transform):
            raise TypeError(f"Transform is not callable: {type(transform)}")

        self.transform = transform
        self.image_root = image_root

        with open(json_path, "r", encoding="utf-8") as f:
            self.data = json.load(f)

    def __len__(self):
        """Return the number of records loaded from the JSON file."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Returns:
        - metadata: Metadata from the JSON record (including latitude/longitude, ground elevation, straight-line distance, etc.)
        - image: PyTorch Tensor (after transforms)
        """
        metadata = self.data[idx]
        image_path = os.path.join(
            self.image_root, metadata.get("fixed_pitch_image", "")
        )

        try:
            image = Image.open(image_path).convert("RGB")
            if self.transform is not None:
                image = self.transform(image)
        except Exception as e:
            # Deliberately best-effort: one unreadable image must not abort
            # a whole DataLoader epoch, so report it and substitute a blank.
            print(f"Error loading image {image_path}: {e}")
            image = torch.zeros(3, 224, 224)  # Placeholder black image to prevent crashes

        return metadata, image


def collate_fn(batch):
    """
    Collate ``(metadata, image)`` samples for a DataLoader.

    The metadata entries are returned untouched as a list, while the image
    tensors are stacked along a new leading batch dimension.
    """
    metadata_batch = []
    image_tensors = []
    for sample in batch:
        metadata_batch.append(sample[0])
        image_tensors.append(sample[1])
    return metadata_batch, torch.stack(image_tensors)


In [None]:
""" load data """
from PIL import Image
# Category names whose boxes are detected with Grounding DINO ("things").
# "building" is first, so it receives id 0 in category_name_to_id below.
# NOTE(review): 'funtain' looks like a typo for 'fountain'; it is left as-is
# because the string is used verbatim at runtime (prompt text and dict key).
thing_category_names = ["building","car", "person", "bus", "traffic light",'bridge','statue','funtain','bench','billboard','Roadblock','street lamp']
# Material categories segmented inside the building region.
stuff_category_names_building = [
    "brick", "plaster", "concrete", "metal", "stone", "wood", "glass", "sandstone", "metal_sheet"]
# Background "stuff" categories segmented on the full image.
stuff_category_names = ["street road","wood","sky", "trees", "sidewalk",'sand','water','grass']
# Concatenation order defines the numeric ids: things, building materials, stuff.
category_names = thing_category_names + stuff_category_names_building + stuff_category_names
# Name -> position in category_names.
# NOTE(review): "wood" occurs in both stuff lists, so this dict keeps only its
# last position; the earlier position has no name mapped to it.
category_name_to_id = {
    category_name: i for i, category_name in enumerate(category_names)
}

In [None]:
""" import """
import os, sys
sys.path.append('F:/Personal_projects_MSC/panoptic-segment-anything/Grounded-Segment-Anything')
import random
import requests

import torch
from torch import nn
import torch.nn.functional as F
from scipy import ndimage
from PIL import Image

from huggingface_hub import hf_hub_download
from segments import SegmentsClient
from segments.export import colorize
from segments.utils import bitmap2file
from getpass import getpass
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import math
# Grounding DINO
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util import box_ops
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict
from GroundingDINO.groundingdino.util.inference import annotate, predict

# segment anything
from segment_anything import build_sam, SamPredictor

# CLIPSeg
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
""" model import """
# `warnings` is needed by the fallback below and is not imported elsewhere
# in this cell, so bring it into scope here.
import warnings

# Prefer the GPU when PyTorch can see one.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

if device != "cpu":
    # Probe GroundingDINO's compiled extension module; if it cannot be
    # imported, only warn so the rest of the notebook keeps running.
    try:
        from GroundingDINO.groundingdino import _C
    except Exception:
        warnings.warn(
            "Failed to load custom C++ ops. Running on CPU mode Only in groundingdino!"
        )
        
def load_model_hf(repo_id, filename, ckpt_config_filename, device, local_files_only=False):
    """Build a Grounding DINO model and load its checkpoint.

    Args:
        repo_id: Hugging Face repo id, or a local directory when
            ``local_files_only`` is true.
        filename: Checkpoint file name inside ``repo_id``.
        ckpt_config_filename: Model config file name inside ``repo_id``.
        device: Device the model is moved to after loading.
        local_files_only: When true, read both files from the local
            ``repo_id`` directory instead of downloading them from the hub.

    Returns:
        The model in eval mode on ``device``.

    Raises:
        FileNotFoundError: In local mode, if either file does not exist.
    """
    if not local_files_only:
        config_path = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)
        weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
    else:
        config_path = os.path.join(repo_id, ckpt_config_filename)
        weights_path = os.path.join(repo_id, filename)

        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config file not found: {config_path}")
        if not os.path.exists(weights_path):
            raise FileNotFoundError(f"Model file not found: {weights_path}")

    config = SLConfig.fromfile(config_path)
    model = build_model(config)
    config.device = device

    # Weights are deserialized on the CPU and moved to `device` afterwards.
    state = torch.load(weights_path, map_location="cpu")
    # strict=False: key mismatches are reported in the printed result
    # instead of raising.
    load_report = model.load_state_dict(clean_state_dict(state["model"]), strict=False)
    print("Model loaded from {} \n => {}".format(weights_path, load_report))

    model.eval()
    model.to(device)
    return model

# Grounding DINO checkpoint location. With local_files_only=True below, the
# config and weight files are read from this local directory rather than
# downloaded, so they must already be present there.
ckpt_repo_id = "ckpt/ShilongLiu/GroundingDINO"
ckpt_filename = "groundingdino_swinb_cogcoor.pth"
ckpt_config_filename = "GroundingDINO_SwinB.cfg.py"

groundingdino_model = load_model_hf(
    ckpt_repo_id, ckpt_filename, ckpt_config_filename, device, local_files_only=True 
)

""" load SAM """
# Build SAM from a local checkpoint and wrap it in a predictor, which is the
# object the helper functions below call.
sam_checkpoint = "ckpt/sam/sam_vit_h_4b8939.pth"
sam = build_sam(checkpoint=sam_checkpoint)
sam.to(device=device)
sam_predictor = SamPredictor(sam)

""" load clipseg """
# CLIPSeg processor and model are both loaded from the same local directory.
clipseg_processor = CLIPSegProcessor.from_pretrained("ckpt/CIDAS/clipseg-rd64-refined")
clipseg_model = CLIPSegForImageSegmentation.from_pretrained(
    "ckpt/CIDAS/clipseg-rd64-refined"
)
clipseg_model.to(device)
# Bare expression with no effect when run as a script (likely a leftover
# notebook display statement).
clipseg_processor
# Override the processor's configured input size to 224x224.
clipseg_processor.feature_extractor.size = {"height": 224, "width": 224}

""" help funcs """
def download_image(url, timeout=30):
    """Download an image over HTTP and open it with PIL.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` can block indefinitely on a stalled server.

    Returns:
        The opened ``PIL.Image.Image``, read from the streamed response body.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    response = requests.get(url, stream=True, timeout=timeout)
    # Surface HTTP errors here rather than letting PIL fail on an error page.
    response.raise_for_status()
    return Image.open(response.raw)


def load_image_for_dino(image):
    """Preprocess an image for Grounding DINO.

    Applies the GroundingDINO transform chain (resize to 800 with a 1333
    maximum size, tensor conversion, mean/std normalization) and returns
    only the transformed image; the transformed target is discarded.
    """
    preprocessing_steps = [
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    pipeline = T.Compose(preprocessing_steps)
    # The transforms operate on (image, target) pairs; no target is needed.
    image_tensor, _unused_target = pipeline(image, None)
    return image_tensor


def dino_detection(
    model,
    image,
    image_array,
    category_names,
    category_name_to_id,
    box_threshold,
    text_threshold,
    device,
    visualize=False,
):
    """Detect boxes for the given category names with Grounding DINO.

    Args:
        model: Grounding DINO model passed to ``predict``.
        image: Image handed to ``load_image_for_dino``.
        image_array: Array version of the image, used only for annotation.
        category_names: Names joined with " . " to form the text prompt.
        category_name_to_id: Mapping used to turn each returned phrase into
            a category id; a phrase missing from it raises ``KeyError``.
        box_threshold: Box confidence threshold forwarded to ``predict``.
        text_threshold: Text confidence threshold forwarded to ``predict``.
        device: Device used for inference.
        visualize: When true, the third return value is an annotated PIL
            image instead of the phrases.

    Returns:
        ``(boxes, category_ids, phrases)``, or
        ``(boxes, category_ids, visualization)`` when ``visualize`` is true.
    """
    caption = " . ".join(category_names)
    dino_input = load_image_for_dino(image).to(device)

    with torch.no_grad():
        boxes, logits, phrases = predict(
            model=model,
            image=dino_input,
            caption=caption,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            device=device,
            # remove_combined=True,
        )

    category_ids = []
    for phrase in phrases:
        category_ids.append(category_name_to_id[phrase])

    if not visualize:
        return boxes, category_ids, phrases

    annotated_bgr = annotate(
        image_source=image_array, boxes=boxes, logits=logits, phrases=phrases
    )
    # Reverse the channel axis: BGR to RGB.
    annotated_rgb = annotated_bgr[..., ::-1]
    return boxes, category_ids, Image.fromarray(annotated_rgb)


def sam_masks_from_dino_boxes(predictor, image_array, boxes, device):
    """Turn Grounding DINO boxes into SAM masks.

    The boxes are converted from normalized (cx, cy, w, h) to pixel
    (x1, y1, x2, y2), mapped through the predictor's transform, and used as
    box prompts with a single mask requested per box. The predictor must
    already hold the embedding for ``image_array``.

    Returns:
        The masks returned by ``predictor.predict_torch``.
    """
    height, width, _channels = image_array.shape
    pixel_scale = torch.Tensor([width, height, width, height])
    pixel_boxes = box_ops.box_cxcywh_to_xyxy(boxes) * pixel_scale

    sam_boxes = predictor.transform.apply_boxes_torch(
        pixel_boxes, image_array.shape[:2]
    ).to(device)

    masks, _scores, _low_res_logits = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=sam_boxes,
        multimask_output=False,
    )
    return masks


def preds_to_semantic_inds(preds, threshold):
    """Convert per-category score maps into one index map.

    A constant "unlabeled" layer filled with ``threshold`` is placed at
    index 0, ahead of the category layers. Each pixel receives the index of
    its highest-scoring layer, so 0 means no category exceeded the
    threshold and ``k`` means ``preds[k - 1]`` won.

    Args:
        preds: Tensor whose first dimension enumerates categories and whose
            last two dimensions are the spatial grid.
        threshold: Score assigned to the unlabeled layer.

    Returns:
        Tensor of shape ``(preds.shape[-2], preds.shape[-1])`` with indices.
    """
    num_layers = preds.shape[0]
    per_pixel_scores = preds.reshape((num_layers, -1))
    num_pixels = per_pixel_scores.shape[-1]

    # Row 0 is the background layer; rows 1.. hold the category scores.
    scores_with_background = torch.full((num_layers + 1, num_pixels), threshold)
    scores_with_background[1:, :] = per_pixel_scores

    _top_values, winners = torch.topk(scores_with_background, 1, dim=0)
    return winners.reshape((preds.shape[-2], preds.shape[-1]))


def clipseg_segmentation(
    processor, model, image, category_names, background_threshold, device
):
    """Run CLIPSeg once per category name and build a rough index map.

    Args:
        processor: CLIPSeg processor used to build the model inputs.
        model: CLIPSeg segmentation model.
        image: PIL image; its ``size`` gives the output resolution.
        category_names: One text prompt per category.
        background_threshold: Score below which a pixel stays unlabeled.
        device: Device the inputs are moved to.

    Returns:
        ``(preds, semantic_inds)``: sigmoid scores per category at image
        resolution, and the index map from ``preds_to_semantic_inds``.
    """
    num_prompts = len(category_names)
    inputs = processor(
        text=category_names,
        images=[image] * num_prompts,
        padding="max_length",
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        logits = model(**inputs).logits

    # A 2-D result (single prompt) gets its leading category axis restored.
    if logits.dim() == 2:
        logits = logits.unsqueeze(0)

    # PIL reports size as (width, height); interpolate expects (height, width).
    image_width, image_height = image.size[0], image.size[1]
    upscaled_logits = nn.functional.interpolate(
        logits.unsqueeze(1),
        size=(image_height, image_width),
        mode="bilinear",
    )
    preds = torch.sigmoid(upscaled_logits.squeeze(dim=1))
    return preds, preds_to_semantic_inds(preds, background_threshold)


def semantic_inds_to_shrunken_bool_masks(
    semantic_inds, shrink_kernel_size, num_categories
):
    """Split an index map into one eroded boolean mask per category.

    Args:
        semantic_inds: Integer index map (CPU tensor).
        shrink_kernel_size: Side length of the square erosion kernel; with
            0 the masks are returned without erosion.
        num_categories: Number of index values (0 .. num_categories - 1)
            to extract.

    Returns:
        Bool tensor of shape ``(num_categories, *semantic_inds.shape)``.
    """
    erosion_kernel = np.ones((shrink_kernel_size, shrink_kernel_size))
    should_erode = shrink_kernel_size > 0

    masks = torch.zeros((num_categories, *semantic_inds.shape), dtype=torch.bool)
    for category_id in range(num_categories):
        category_mask = (semantic_inds == category_id).numpy()
        if should_erode:
            # Shrinking keeps later point sampling away from region borders.
            category_mask = ndimage.binary_erosion(
                category_mask, structure=erosion_kernel
            )
        masks[category_id] = torch.from_numpy(category_mask)

    return masks


def clip_and_shrink_preds(semantic_inds, preds, shrink_kernel_size, num_categories):
    """Restrict each category's scores to its own (shrunken) region.

    Args:
        semantic_inds: Index map where 0 is unlabeled and ``k`` refers to
            ``preds[k - 1]``.
        preds: Per-category score maps.
        shrink_kernel_size: Erosion kernel size applied to every region.
        num_categories: Number of index values including the unlabeled one.

    Returns:
        ``(clipped_preds, relative_sizes)``: scores zeroed outside each
        category's region, and each region's pixel count divided by the
        largest count (raw counts when every region is empty).
    """
    bool_masks = semantic_inds_to_shrunken_bool_masks(
        semantic_inds, shrink_kernel_size, num_categories
    ).to(preds.device)

    # Index 0 is the unlabeled layer, so real categories start at 1.
    category_layers = range(1, bool_masks.size(0))

    pixel_counts = []
    for layer in category_layers:
        pixel_counts.append(torch.sum(bool_masks[layer].int()).item())

    largest_count = max(pixel_counts)
    if largest_count > 0:
        relative_sizes = [count / largest_count for count in pixel_counts]
    else:
        relative_sizes = pixel_counts

    clipped_preds = torch.zeros_like(preds)
    for layer in category_layers:
        region = bool_masks[layer].float()
        clipped_preds[layer - 1] = preds[layer - 1] * region

    return clipped_preds, relative_sizes


def sample_points_based_on_preds(preds, N):
    """Draw ``N`` pixel positions with probability proportional to ``preds``.

    Args:
        preds: 2-D array of non-negative weights, one per pixel.
        N: Number of points to draw (with replacement).

    Returns:
        List of ``(col, row)`` tuples, i.e. (x, y) pixel coordinates.
    """
    num_rows, num_cols = preds.shape
    flat_indices = np.arange(num_rows * num_cols)

    # Weighted draw over flattened (row-major) pixel indices.
    drawn = random.choices(flat_indices, weights=preds.ravel(), k=N)

    points = []
    for flat_index in drawn:
        col = flat_index % num_cols
        row = flat_index // num_cols
        points.append((col, row))
    return points


def upsample_pred(pred, image_source):
    """Upsample a square prediction map back to the image resolution.

    The prediction is first resized to a square whose side is the longer
    image dimension, then cropped at the bottom/right to the image's own
    height and width.

    The crop uses the image dimensions directly. Deriving it from a
    height/width ratio left portrait images uncropped, returning (H, H)
    instead of (H, W), and could lose a pixel to float rounding.

    Args:
        pred: Tensor of shape ``(C, h, w)`` holding the low-resolution map.
        image_source: Object whose ``shape[0]`` / ``shape[1]`` are the
            target height and width.

    Returns:
        Tensor of shape ``(C, H, W)`` when ``C > 1``; for ``C == 1`` the
        channel axis is squeezed away, giving ``(1, H, W)`` via the batch
        axis added here.
    """
    pred = pred.unsqueeze(dim=0)
    original_height = image_source.shape[0]
    original_width = image_source.shape[1]
    larger_dim = max(original_height, original_width)

    # upsample the tensor to the larger dimension
    upsampled_tensor = F.interpolate(
        pred, size=(larger_dim, larger_dim), mode="bilinear", align_corners=False
    )

    # remove the padding (at the end) to get the original image resolution
    upsampled_tensor = upsampled_tensor[:, :, :original_height, :original_width]
    return upsampled_tensor.squeeze(dim=1)


def sam_mask_from_points(predictor, image_array, points):
    """Predict a SAM probability map from positive point prompts.

    Args:
        predictor: SAM predictor that already holds the image embedding.
        image_array: Image array; only its shape is used, for upsampling.
        points: Sequence of (x, y) prompt coordinates.

    Returns:
        The upsampled probability map produced by ``upsample_pred``.
    """
    point_coords = np.array(points)
    # Only positive points are sampled, so every label is 1.
    point_labels = np.ones(len(points))

    # predict() is used rather than predict_torch(), which did not work here.
    _masks, _scores, low_res_logits = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
    )

    # Take the per-pixel maximum over the returned segmentation levels.
    probabilities = torch.sigmoid(torch.tensor(low_res_logits))
    best_per_pixel = torch.max(probabilities, dim=0)[0]

    # The low-resolution logits are upsampled back to the image shape.
    return upsample_pred(best_per_pixel.unsqueeze(dim=0), image_array)

def get_masked_building_image(building_masks, image):
    """
    Black out every pixel of the image that no building mask covers.

    Args:
    - building_masks: Iterable of boolean mask tensors, each squeezable to (H, W)
    - image: Original image array with shape (H, W, 3)

    Returns:
    - PIL image in which only the union of the masks remains visible

    Raises:
    - ValueError: if a mask's squeezed shape differs from the image's (H, W)
    """
    height, width = image.shape[0], image.shape[1]

    # Union of all masks; starts empty so no masks yields an all-black image.
    combined_mask = np.zeros((height, width), dtype=np.bool_)
    # Work on a copy so the caller's array is left untouched.
    masked_image_array = np.array(image).copy()

    for building_mask in building_masks:
        building_mask_np = building_mask.cpu().numpy().squeeze()
        if building_mask_np.shape != image.shape[:2]:
            raise ValueError(f"Mask shape {building_mask_np.shape} does not match image shape {image.shape[:2]}")
        combined_mask |= building_mask_np

    outside = ~combined_mask
    masked_image_array[outside] = 0

    return Image.fromarray(masked_image_array)

def show_mask(mask, ax, random_color=False):
    """Draw a mask on ``ax`` as a semi-transparent colored overlay.

    Args:
        mask: Array whose last two dimensions are (height, width).
        ax: Matplotlib-style axes providing ``imshow``.
        random_color: Use a random RGB color instead of the default blue;
            the alpha value is 0.6 in both cases.
    """
    if not random_color:
        rgba = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)

    height, width = mask.shape[-2:]
    # Broadcast (H, W, 1) against (1, 1, 4) to get an RGBA image.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)


def show_segmentation_preds(preds, category_names):
    """Plot the source image next to one score map per category.

    NOTE(review): the first panel shows ``image``, which is not a parameter;
    it is read from the enclosing (module/notebook) scope and must exist
    there when this function is called.

    Args:
        preds: Indexable collection of score tensors, one per category.
        category_names: Labels written above the corresponding panels.
    """
    num_categories = len(category_names)
    num_panels = num_categories + 1
    _, ax = plt.subplots(1, num_panels, figsize=(3 * num_panels, 4))

    for panel in ax.flatten():
        panel.axis("off")

    ax[0].imshow(image)
    for i in range(num_categories):
        ax[i + 1].imshow(preds[i].cpu())
    for i, category_name in enumerate(category_names):
        ax[i + 1].text(0, -15, category_name)


""" pipeline """

def generate_buiding_panoptic_mask(
    image,
    thing_category_names,
    stuff_category_names_building,
    stuff_category_names,
    category_name_to_id,
    dino_model,
    sam_predictor,
    clipseg_processor,
    clipseg_model,
    device,
    dino_box_threshold=0.3,
    dino_text_threshold=0.25,
    segmentation_background_threshold=0.1,
    shrink_kernel_size=20,
    num_samples_factor=1000,
):
    """Build a panoptic index map with a separate material map for buildings.

    Pipeline:
      1. Grounding DINO detects boxes for the "thing" categories and SAM
         turns each box into a mask.
      2. Masks whose category id is 0 are treated as buildings; the image is
         blacked out everywhere else and CLIPSeg + SAM segment the building
         material categories on that masked image.
      3. CLIPSeg + SAM segment the remaining "stuff" categories on the full
         image, excluding pixels already covered by thing masks.
      4. The maps are merged into one index map.

    Index layout of the returned ``panoptic_inds``: 0 is unlabeled,
    ``1 .. len(stuff_category_names)`` are stuff categories, the following
    ``len(stuff_category_names_building)`` values are building materials,
    and every detected thing instance gets its own index after that.

    Args:
        image: PIL image to segment.
        thing_category_names: Prompts for box detection.
        stuff_category_names_building: Material prompts used inside buildings.
        stuff_category_names: Background prompts used on the full image.
        category_name_to_id: Mapping from detected phrase to category id.
        dino_model: Grounding DINO model.
        sam_predictor: SAM predictor.
        clipseg_processor: CLIPSeg processor.
        clipseg_model: CLIPSeg model.
        device: Device used for inference.
        dino_box_threshold: Box threshold for Grounding DINO.
        dino_text_threshold: Text threshold for Grounding DINO.
        segmentation_background_threshold: Score below which a pixel is
            left unlabeled.
        shrink_kernel_size: Erosion kernel size applied to rough masks
            before sampling point prompts.
        num_samples_factor: Number of point prompts for the largest region;
            smaller regions get proportionally fewer.

    Returns:
        ``(panoptic_inds, thing_category_ids, panoptic_inds_buildings,
        building_masks)``.
    """
    image = image.convert("RGB")
    image_array = np.asarray(image)

    # Compute the SAM image embedding once; every box and point prompt below
    # reuses it, so it must not be recomputed for the same image.
    sam_predictor.set_image(image_array)

    # detect boxes for "thing" categories using Grounding DINO
    thing_category_ids = []
    thing_masks = []
    thing_boxes = []
    if len(thing_category_names) > 0:
        thing_boxes, thing_category_ids, _ = dino_detection(
            dino_model,
            image,
            image_array,
            thing_category_names,
            category_name_to_id,
            dino_box_threshold,
            dino_text_threshold,
            device,
        )
        if len(thing_boxes) > 0:
            # get segmentation masks for the thing boxes
            thing_masks = sam_masks_from_dino_boxes(
                sam_predictor, image_array, thing_boxes, device
            )

    # Keep the masks computed above. Recomputing them unconditionally would
    # bypass the two guards and fail when there are no boxes to convert.
    # Category id 0 is treated as the building category.
    building_masks = [
        mask for idx, mask in enumerate(thing_masks) if thing_category_ids[idx] == 0
    ]

    masked_image = get_masked_building_image(building_masks, image_array)

    if len(stuff_category_names_building) > 0:
        # rough material masks on the building-only image using CLIPSeg
        clipseg_preds_building, clipseg_semantic_inds_building = clipseg_segmentation(
            clipseg_processor,
            clipseg_model,
            masked_image,
            stuff_category_names_building,
            segmentation_background_threshold,
            device,
        )
        clipsed_clipped_preds_building, relative_sizes_building = clip_and_shrink_preds(
            clipseg_semantic_inds_building,
            clipseg_preds_building,
            shrink_kernel_size,
            len(stuff_category_names_building) + 1,
        )
        sam_preds = torch.zeros_like(clipsed_clipped_preds_building)
        for i in range(clipsed_clipped_preds_building.shape[0]):
            clipseg_pred_building = clipsed_clipped_preds_building[i]
            # for each material category, sample points in the rough mask
            num_samples = int(relative_sizes_building[i] * num_samples_factor)
            if num_samples == 0:
                continue
            points = sample_points_based_on_preds(
                clipseg_pred_building.cpu().numpy(), num_samples
            )
            if len(points) == 0:
                continue
            # use SAM to get mask for points
            pred = sam_mask_from_points(sam_predictor, image_array, points)
            sam_preds[i] = pred
        sam_semantic_inds_building = preds_to_semantic_inds(
            sam_preds, segmentation_background_threshold
        )

    if len(stuff_category_names) > 0:
        # get rough segmentation masks for "stuff" categories using CLIPSeg
        clipseg_preds, clipseg_semantic_inds = clipseg_segmentation(
            clipseg_processor,
            clipseg_model,
            image,
            stuff_category_names,
            segmentation_background_threshold,
            device,
        )
        # remove things from stuff masks
        clipseg_semantic_inds_without_things = clipseg_semantic_inds.clone()
        if len(thing_boxes) > 0:
            combined_things_mask = torch.any(thing_masks, dim=0)
            clipseg_semantic_inds_without_things[combined_things_mask[0]] = 0
        # clip CLIPSeg preds based on non-overlapping semantic segmentation inds (+ optionally shrink the mask of each category)
        # also returns the relative size of each category
        clipsed_clipped_preds, relative_sizes = clip_and_shrink_preds(
            clipseg_semantic_inds_without_things,
            clipseg_preds,
            shrink_kernel_size,
            len(stuff_category_names) + 1,
        )
        # get finer segmentation masks for the "stuff" categories using SAM
        sam_preds = torch.zeros_like(clipsed_clipped_preds)
        for i in range(clipsed_clipped_preds.shape[0]):
            clipseg_pred = clipsed_clipped_preds[i]
            # for each "stuff" category, sample points in the rough segmentation mask
            num_samples = int(relative_sizes[i] * num_samples_factor)
            if num_samples == 0:
                continue
            points = sample_points_based_on_preds(
                clipseg_pred.cpu().numpy(), num_samples
            )
            if len(points) == 0:
                continue
            # use SAM to get mask for points
            pred = sam_mask_from_points(sam_predictor, image_array, points)
            sam_preds[i] = pred
        sam_semantic_inds = preds_to_semantic_inds(
            sam_preds, segmentation_background_threshold
        )

    # combine the thing inds and the stuff inds into panoptic inds
    panoptic_inds = (
        sam_semantic_inds.clone()
        if len(stuff_category_names) > 0
        else torch.zeros(image_array.shape[0], image_array.shape[1], dtype=torch.long)
    )

    panoptic_inds_buildings = (
        sam_semantic_inds_building.clone()
        if len(stuff_category_names_building) > 0
        else torch.zeros(image_array.shape[0], image_array.shape[1], dtype=torch.long)
    )

    # Union of all building masks, on the same device as the index maps.
    combined_mask = torch.zeros_like(panoptic_inds, dtype=torch.bool, device=panoptic_inds.device)
    for building_mask in building_masks:
        building_mask = building_mask.to(panoptic_inds.device)
        combined_mask |= building_mask.squeeze(0).bool()

    # Update panoptic_inds to remove building mask regions
    panoptic_inds[combined_mask] = 0

    # Update panoptic_inds_buildings to retain only building mask regions,
    # then shift material indices so they follow the stuff indices.
    panoptic_inds_buildings[~combined_mask] = 0
    panoptic_inds_buildings[panoptic_inds_buildings > 0] += len(stuff_category_names)
    panoptic_inds[panoptic_inds_buildings > 0] = panoptic_inds_buildings[panoptic_inds_buildings > 0]

    # Thing instances are numbered after all stuff and material indices and
    # only claim pixels that carry no building material label.
    ind = len(stuff_category_names) + 1 + len(stuff_category_names_building)
    for thing_mask in thing_masks:
        thing_mask = thing_mask.squeeze(dim=0).bool()
        valid_indices = thing_mask & (panoptic_inds_buildings.to(device) == 0)
        panoptic_inds[valid_indices] = ind
        ind += 1

    return panoptic_inds, thing_category_ids, panoptic_inds_buildings, building_masks

""" mask visualization """
def display_images_in_grid(image, images, l, rows, cols, path, figsize=(25, 20)):
    """
    Save a grid figure with one highlighted segment per panel.

    Panel ``i`` shows ``image`` with the region ``images == i + 1`` overlaid
    at 50% opacity and is titled with the category name of
    ``thing_category_ids[i]``. ``thing_category_ids`` and
    ``category_name_to_id`` are read from the enclosing scope.

    Parameters:
    - image: background image drawn in every used panel
    - images: index map compared against ``i + 1`` for panel ``i``
    - l: number of panels to fill; the remaining panels stay empty
    - rows: number of rows in the grid
    - cols: number of columns in the grid
    - path: file the figure is saved to
    - figsize: size of the figure
    """
    _fig, axes = plt.subplots(rows, cols, figsize=figsize)
    for panel_index, panel in enumerate(axes.flatten()):
        if panel_index >= l:
            # Unused panel: just hide its frame.
            panel.axis('off')
            continue

        panel.imshow(image)
        panel.imshow(colorize(images==panel_index+1), alpha=0.5)
        panel.axis('off')

        # Reverse lookup: category id -> category name.
        names = list(category_name_to_id.keys())
        ids = list(category_name_to_id.values())
        panel.set_title(names[ids.index(thing_category_ids[panel_index])])

    plt.savefig(path)
    plt.close()
    
# display_images_in_grid(image,panoptic_inds, len(thing_category_ids), rows=5, cols=6)
""" mask visualization """

def pixel_to_camera_coordinates(u, v, D, K):
    """
    Back-project a pixel with known depth into the camera coordinate system.

    Args:
        u, v: Pixel coordinates (x, y).
        D: Depth value at that pixel.
        K: 3x3 camera intrinsic matrix, indexed as K[row, col].

    Returns:
        (X_camera, Y_camera, Z_camera): the point in camera coordinates,
        where Z_camera is D itself.
    """
    focal_x, focal_y = K[0, 0], K[1, 1]
    center_x, center_y = K[0, 2], K[1, 2]

    # Pinhole model: offset from the principal point, scaled by depth / focal.
    return (u - center_x) * D / focal_x, (v - center_y) * D / focal_y, D



def extract_top_and_bottom_points(segmentation_mask, plot=False):
    """
    Locate the topmost mask pixel and the lowest mask pixel in the same column.

    Args:
        segmentation_mask: Mask array in which target pixels equal 1.
        plot: Whether to display the mask with both points marked.

    Returns:
        (u_top, v_top), (u_bottom, v_bottom): pixel coordinates (x, y) of the
        top point and of the bottom point sharing its x-coordinate.

    Raises:
        ValueError: If the mask contains no pixel equal to 1.
    """
    foreground = np.where(segmentation_mask == 1)
    rows = foreground[0]  # y-coordinates
    cols = foreground[1]  # x-coordinates

    # Topmost pixel = smallest row index (first one found on ties).
    top_position = np.argmin(rows)
    u_top = cols[top_position]
    v_top = rows[top_position]

    # Bottom point: largest row index within the top point's column.
    u_bottom = u_top
    v_bottom = np.max(rows[cols == u_top])

    if plot:
        plt.imshow(segmentation_mask, cmap="gray")
        plt.plot(u_top, v_top, 'ro', label="Top Point")
        plt.plot(u_bottom, v_bottom, 'bo', label="Bottom Point")
        plt.legend()
        plt.axis('off')
        plt.show()

    return (u_top, v_top), (u_bottom, v_bottom)

import numpy as np
import matplotlib.pyplot as plt

import numpy as np
import matplotlib.pyplot as plt




def compute_object_height(depth, segmentation_mask, K, plot, calibrated_depth=None):
    """
    Compute the object's height using depth information and segmentation results.

    NOTE(review): this expects ``extract_top_and_bottom_points`` to return a
    single top point and a single bottom point. A later notebook cell
    redefines that helper to return per-column lists, which this function
    cannot unpack.

    Args:
        depth: Depth of the object (in meters); one value used for both points.
        segmentation_mask: Binary segmentation mask.
        K: Camera intrinsic matrix.
        plot: Whether to visualize the extracted points.
        calibrated_depth: Optional depth map used only for a diagnostic
            print of the depths at the two points. It is optional because
            ``compute_height`` calls this function without it.

    Returns:
        H: Estimated object height (in meters).
    """
    # Extract top and bottom points
    (u_top, v_top), (u_bottom, v_bottom) = extract_top_and_bottom_points(segmentation_mask, plot)

    if calibrated_depth is not None:
        # Arrays are indexed [row, col], i.e. [v, u].
        # Assumes calibrated_depth is laid out like the mask; confirm with caller.
        depth_top = calibrated_depth[v_top, u_top]
        depth_bottom = calibrated_depth[v_bottom, u_bottom]
        print(depth_top, depth_bottom, depth)

    # Compute height in the camera coordinate system
    Y_top = pixel_to_camera_coordinates(u_top, v_top, depth, K)[1]
    Y_bottom = pixel_to_camera_coordinates(u_bottom, v_bottom, depth, K)[1]
    H = abs(Y_top - Y_bottom)

    return H


def compute_height(depth, path, segmentation_mask, focallength_px=435.19, plot=False):
    """
    Estimate an object's height from a mask image and a single depth value.

    Args:
        depth: Depth of the object (in meters).
        path: Path of the source image; only opened when ``plot`` is true.
        segmentation_mask: Mask image whose target pixels have value 255.
        focallength_px: Focal length in pixels, used for both axes.
        plot: Whether to show the extracted points and the source image.

    Returns:
        H: Estimated object height (in meters).
    """
    # Intrinsics with the principal point fixed at (175, 175).
    K = np.array([[ focallength_px, 0, 175],
              [0,  focallength_px, 175],
              [0,  0,  1]])

    # Scale 0/255 mask values to 0/1, the form the point extraction expects.
    segmentation_mask = np.array(segmentation_mask)
    segmentation_mask = segmentation_mask / 255
    H = compute_object_height(depth, segmentation_mask, K, plot)
    print(f"Recovered object height: {H} meters")

    if plot:
        plt.imshow(Image.open(path))
        plt.axis('off')
        plt.show()
    return H

def display_images_in_grid(image, images, l, rows, cols, path, thing_category_ids, figsize=(25, 20)):
    """
    Save a grid figure with one highlighted segment per panel.

    Panel ``i`` shows ``image`` with the region ``images == i + 1`` overlaid
    at 50% opacity, titled with the category name of ``thing_category_ids[i]``
    (looked up in the module-level ``category_name_to_id``).

    Parameters:
    - image: background image drawn in every used panel
    - images: index map compared against ``i + 1`` for panel ``i``
    - l: number of panels to fill; the remaining panels stay empty
    - rows: number of rows in the grid
    - cols: number of columns in the grid
    - path: file the figure is saved to
    - thing_category_ids: category id for each filled panel
    - figsize: size of the figure
    """
    _fig, axes = plt.subplots(rows, cols, figsize=figsize)
    panels = axes.flatten()

    for position, panel in enumerate(panels):
        is_used = position < l
        if is_used:
            panel.imshow(image)
            panel.imshow(colorize(images==position+1), alpha=0.5)
        panel.axis('off')
        if is_used:
            # Reverse lookup: category id -> category name.
            known_names = list(category_name_to_id.keys())
            known_ids = list(category_name_to_id.values())
            title = known_names[known_ids.index(thing_category_ids[position])]
            panel.set_title(title)

    plt.savefig(path)
    plt.close()

def max_area(panoptic_inds, mask, thing_category_ids, thing_category_ids_building, thing_category_ids_):
    """Name the non-building segment that overlaps the mask the most.

    The three id lists are concatenated in the order ``thing_category_ids_``,
    ``thing_category_ids_building``, ``thing_category_ids``; entry ``i`` of
    that list describes the segment ``panoptic_inds == i + 1``.

    Args:
        panoptic_inds: Index map tensor.
        mask: Mask image whose target pixels have value 255.
        thing_category_ids: Ids appended last.
        thing_category_ids_building: Ids placed in the middle.
        thing_category_ids_: Ids placed first.

    Returns:
        Category name of the best-overlapping segment, or ``''`` if no
        non-building segment qualified.
    """
    all_ids = thing_category_ids_ + thing_category_ids_building + thing_category_ids
    target = (np.array(mask) / 255) == 1

    max_ind = ''
    max_sum = -1
    for segment_offset, category_id in enumerate(all_ids):
        segment = (panoptic_inds == segment_offset + 1).numpy()
        overlap = (target & segment).sum()

        # Reverse lookup: category id -> category name.
        names = list(category_name_to_id.keys())
        ids = list(category_name_to_id.values())
        name = names[ids.index(category_id)]

        # Buildings never win, even with the largest overlap.
        if overlap > max_sum and name != "building":
            max_sum = overlap
            max_ind = name

    print("max_ind:",max_ind, 'max_sum,',max_sum)
    return max_ind
import math

def haversine(lat1, lon1, lat2, lon2):
    """Return the great-circle distance in meters between two points.

    Args:
        lat1, lon1: Latitude and longitude of the first point, in degrees.
        lat2, lon2: Latitude and longitude of the second point, in degrees.
    """
    earth_radius_m = 6371000  # Earth radius, in meters

    lat1_rad = math.radians(lat1)
    lat2_rad = math.radians(lat2)
    lat_diff_rad = math.radians(lat2 - lat1)
    lon_diff_rad = math.radians(lon2 - lon1)

    # Haversine formula.
    a = (
        math.sin(lat_diff_rad / 2) ** 2
        + math.cos(lat1_rad) * math.cos(lat2_rad) * math.sin(lon_diff_rad / 2) ** 2
    )
    central_angle = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))

    return earth_radius_m * central_angle

def compute_camera_to_building_distance(camera_coordinates, building_geometry):
    """Return the distance in meters from the camera to a building footprint.

    The footprint is reduced to the arithmetic mean of the vertices of its
    first coordinate ring, and the haversine distance to that point is
    returned.

    Args:
        camera_coordinates: ``(latitude, longitude)`` of the camera.
        building_geometry: Mapping whose ``'coordinates'[0]`` is a sequence
            of ``(longitude, latitude, ...)`` points.
    """
    cam_lat, cam_lon = camera_coordinates

    # Only the first ring (the outer boundary) is considered.
    outer_ring = building_geometry['coordinates'][0]
    vertex_count = len(outer_ring)

    # Points are stored longitude-first.
    center_lon = sum(vertex[0] for vertex in outer_ring) / vertex_count
    center_lat = sum(vertex[1] for vertex in outer_ring) / vertex_count

    return haversine(cam_lat, cam_lon, center_lat, center_lon)



In [None]:
import numpy as np
import matplotlib.pyplot as plt

def extract_top_and_bottom_points(segmentation_mask, plot=False):
    """
    Collect a top point per mask column plus one shared bottom row.

    - For every column containing foreground, the top point is its
      highest foreground pixel (smallest v).
    - The shared bottom row is the truncated mean of the 30% of columns
      whose lowest foreground pixel lies furthest down the image
      (largest v); at least one column is always used.

    Args:
        segmentation_mask: 2-D binary mask (building region = 1, background = 0).
        plot (bool): Whether to visualize the points (default: False).

    Returns:
        top_points (List[Tuple[int, int]]): [(u, v_top)] for each valid column.
        bottom_points (List[Tuple[int, int]]): [(u, v_bottom_avg)] aligned
            one-to-one with top_points.

    Raises:
        ValueError: If no column contains a foreground pixel.
    """
    _height, width = segmentation_mask.shape

    top_points = []
    column_bottoms = []

    # Scan each column for its first and last foreground row.
    for u in range(width):
        foreground_rows = np.where(segmentation_mask[:, u] == 1)[0]
        if len(foreground_rows) == 0:
            continue
        top_points.append((u, foreground_rows.min()))
        column_bottoms.append(foreground_rows.max())

    if not column_bottoms:
        raise ValueError("No valid region found in the mask.")

    # Average the largest 30% of the per-column bottoms into one row value.
    ordered_bottoms = np.sort(column_bottoms)
    cutoff = max(int(len(ordered_bottoms) * 0.3), 1)
    v_bottom_avg = int(np.mean(ordered_bottoms[-cutoff:]))

    bottom_points = [(u, v_bottom_avg) for u, _v_top in top_points]

    if plot:
        plt.imshow(segmentation_mask, cmap='gray')
        for u, v_top in top_points:
            plt.plot(u, v_top, 'ro', markersize=2)          # top points (red)
            plt.plot(u, v_bottom_avg, 'bo', markersize=2)   # unified bottom (blue)
        plt.plot([0, width - 1], [v_bottom_avg, v_bottom_avg],
                 'b--', label='Bottom 30% Mean')
        plt.title(f"Column Heights: (bottom avg) {v_bottom_avg} - v_top[i]")
        plt.legend()
        plt.axis('off')
        plt.show()

    return top_points, bottom_points


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image  # needed if plot=True in compute_height

def cluster_and_get_largest_mean(data, threshold=1.0):
    """
    Cluster 1D values by gap and return the mean of the most populated cluster.

    The values are sorted in ascending order and split into runs: a new
    cluster starts whenever the gap between two consecutive sorted values
    exceeds `threshold`. Only neighbouring values are compared, so the total
    spread of a cluster may be larger than `threshold`.

    Args:
        data: Iterable of numeric values (e.g. list or 1D np.ndarray).
        threshold (float): Largest gap between consecutive sorted values that
            still keeps them in the same cluster.

    Returns:
        float: Mean of the cluster holding the most samples. If several
        clusters tie, the one with the smallest values is used. Returns
        np.nan when `data` is empty.
    """
    data = sorted(data)
    if not data:
        # Nothing to cluster; avoid the IndexError on data[0] below.
        return np.nan

    clusters = [[data[0]]]
    for value in data[1:]:
        if abs(value - clusters[-1][-1]) <= threshold:
            clusters[-1].append(value)
        else:
            clusters.append([value])

    # max() keeps the first of equally sized clusters, i.e. the lowest values.
    largest_cluster = max(clusters, key=len)
    return np.mean(largest_cluster)


def decide_distance(segmentation_mask, depth_image):
    """
    Estimate the distance to the masked object from its depth samples.

    Depth values are read wherever the mask is positive; non-finite and
    non-positive samples are dropped; the remaining values are sorted and
    handed to `cluster_and_get_largest_mean` with a gap threshold of 0.1.

    Args:
        segmentation_mask (np.ndarray): Mask of shape (H, W); pixels > 0 are
            treated as the target.
        depth_image (np.ndarray): Depth map of shape (H, W).

    Returns:
        float: Mean of the largest depth cluster, or np.nan when no usable
        depth sample remains inside the mask.
    """
    inside_mask = segmentation_mask > 0
    samples = depth_image[inside_mask]
    # For a normalized 8-bit depth map the samples could be inverted first:
    #   samples = 1 - samples / 255.0   (1 would then mean "farther")

    # Discard NaN/inf first, then zeros and negative readings.
    samples = samples[np.isfinite(samples)]
    samples = samples[samples > 0]

    if len(samples) == 0:
        return np.nan

    # NOTE: a previously considered alternative was a trimmed mean that drops
    # the lowest and highest 30% and averages the middle 40% of the samples.
    ordered = np.sort(samples)
    return cluster_and_get_largest_mean(ordered, threshold=0.1)
 

# def calibrate_and_filter_mask(segmentation_mask, depth_image, distance, tolerance=2):
#     """
#     Calibrate `depth_image` using the center pixel and create a new mask that
#     keeps only pixels within [distance ± tolerance].
#
#     Args:
#         segmentation_mask (ndarray): Binary mask of shape [H, W].
#         depth_image (ndarray): Uncalibrated normalized depth image in [0, 1].
#         distance (float): Ground-truth depth (meters) at the image center.
#         tolerance (float): Allowed absolute error (meters), default ±2 m.
#
#     Returns:
#         new_mask (ndarray): Binary mask of the same shape; 1 if within range.
#         depth_in_meters (ndarray): Calibrated depth map in meters.
#     """
#     depth_image = 1 - np.array(depth_image).astype(np.float32) / 255
#     height, width = segmentation_mask.shape
#     cy, cx = height // 2, width // 2
#
#     center_depth_value = depth_image[cy, cx]
#     if center_depth_value == 0:
#         raise ValueError("Center depth is 0; likely an invalid region.")
#
#     # Calibration: real_depth = normalized_depth * scale
#     scale = distance / center_depth_value
#     depth_in_meters = depth_image * scale
#     new_distance = decide_distance(segmentation_mask, depth_in_meters)
#     print(new_distance, distance)
#
#     # Build a new mask for pixels within [new_distance ± tolerance]
#     within_range = ((depth_in_meters >= (new_distance - tolerance)) &
#                     (depth_in_meters <= (new_distance + tolerance)))
#     new_mask = (segmentation_mask > 0) & within_range
#     return new_mask.astype(np.uint8), depth_in_meters, new_distance


def calibrate_and_filter_mask(left_point, right_point, segmentation_mask, depth_image, distance, tolerance=2.0):
    """
    Calibrate a raw depth map against three anchors and keep the masked pixels
    whose calibrated depth lies close to the object's median depth.

    Steps:
        1. Convert `depth_image` to float32 and divide by 1000 (mm -> m).
        2. Read the raw depth on the middle row at the left anchor column, the
           right anchor column and the image center.
        3. Least-squares fit `distance_m ~= A * raw_depth + B` on those three
           pairs and apply it to the whole map.
        4. Take the median calibrated depth over the masked pixels that have a
           positive raw depth and a positive calibrated depth.
        5. Keep the masked pixels within `tolerance` of that median.

    Args:
        left_point (dict): Contains 'x_pixel' (column index) and 'distance_m'.
        right_point (dict): Same keys as left_point.
        segmentation_mask (np.ndarray): Mask of shape [H, W]; pixels > 0 are
            treated as the object.
        depth_image (array-like): Raw depth map in millimeters, shape [H, W].
        distance (float): Ground-truth depth (meters) at the image center.
        tolerance (float): Allowed absolute deviation (meters) from the median
            depth, default 2.0 m.

    Returns:
        new_mask (np.ndarray): uint8 mask (0/1) after depth filtering.
        depth_in_meters (np.ndarray): Calibrated depth map.
        new_distance (float): Median calibrated depth of the valid masked
            pixels, or np.nan if there are none (new_mask is then all zeros).

    Raises:
        ValueError: If the depth map and mask shapes differ, or if the raw
            depth at the image center is 0.
    """
    depth_image = np.array(depth_image).astype(np.float32) / 1000.0  # convert mm to meters
    if depth_image.shape != segmentation_mask.shape:
        # The center/anchor lookups and the boolean masking below all assume
        # that both arrays share one pixel grid.
        raise ValueError(
            f"depth_image shape {depth_image.shape} does not match "
            f"segmentation_mask shape {segmentation_mask.shape}."
        )
    height, width = segmentation_mask.shape
    cy, cx = height // 2, width // 2

    # Extract center depth
    center_depth = depth_image[cy, cx]
    if center_depth == 0:
        raise ValueError("Center depth is 0; likely an invalid region.")

    # Build anchor pairs: (raw_depth_m, ground_truth_distance_m)
    points = [
        (depth_image[cy, left_point['x_pixel']], left_point['distance_m']),
        (depth_image[cy, right_point['x_pixel']], right_point['distance_m']),
        (center_depth, distance),
    ]

    # Fit A, B s.t. distance_m ~= A * raw_depth + B (least squares)
    A_matrix = np.array([[d, 1] for d, _ in points])   # [[d1, 1], [d2, 1], [d3, 1]]
    b_vector = np.array([m for _, m in points])         # [m1, m2, m3]
    A, B = np.linalg.lstsq(A_matrix, b_vector, rcond=None)[0]

    # Apply linear calibration
    depth_in_meters = A * depth_image + B

    # Median depth of the object. Validity is judged on the RAW depth: after
    # the affine calibration a raw 0 (no measurement) becomes B, so testing
    # only the calibrated value would let invalid pixels skew the median.
    in_mask = segmentation_mask > 0
    usable = in_mask & (depth_image > 0) & (depth_in_meters > 0)
    masked_depth = depth_in_meters[usable]
    new_distance = np.median(masked_depth) if masked_depth.size > 0 else np.nan

    # Build a new mask for pixels within [new_distance ± tolerance]
    # (comparisons against NaN are False, so a NaN median gives an empty mask)
    within_range = ((depth_in_meters >= new_distance - tolerance) &
                    (depth_in_meters <= new_distance + tolerance))
    new_mask = in_mask & within_range

    return new_mask.astype(np.uint8), depth_in_meters, new_distance


def compute_height(left_point, right_point, depth, path, segmentation_mask, depth_image,
                   focallength_px=435.19, plot=False, principal_point=(175, 175)):
    """
    Estimate object heights from a segmentation mask and a raw depth map.

    Builds the intrinsic matrix K, rescales the mask from [0, 255] to [0, 1],
    calibrates the depth map and filters the mask with
    `calibrate_and_filter_mask`, then delegates the height computation to
    `compute_object_height` (defined in another cell of this file).

    Args:
        left_point (dict): Contains 'x_pixel' and 'distance_m'.
        right_point (dict): Contains 'x_pixel' and 'distance_m'.
        depth (float): Ground-truth center distance (meters).
        path (str): Image file path; only opened when `plot` is True.
        segmentation_mask (array-like): Mask with values in [0, 255] (e.g. a
            PIL "L" image); divided by 255 before use.
        depth_image (array-like): Raw depth map in millimeters.
        focallength_px (float): Focal length in pixels, used for both fx and fy.
        plot (bool): If True, show the original image; the flag is also
            forwarded to `compute_object_height`.
        principal_point (tuple): (cx, cy) in pixels placed in K. Defaults to
            (175, 175), the value that used to be hard-coded here.

    Returns:
        list: Whatever `compute_object_height` returns; in this file that is
        the list of height estimates (meters) sorted in ascending order, so
        `H[-1]` is the largest estimate.
    """
    cx, cy = principal_point
    K = np.array([[focallength_px, 0, cx],
                  [0, focallength_px, cy],
                  [0, 0, 1]])

    segmentation_mask = np.array(segmentation_mask).astype(np.float32) / 255.0
    new_mask, calibrated_depth, new_distance = calibrate_and_filter_mask(
        left_point, right_point, segmentation_mask, depth_image, depth
    )

    # Note: `compute_object_height` must be defined elsewhere.
    H = compute_object_height(depth, new_distance, new_mask, K, plot, calibrated_depth)
    print(f"Recovered object height: {H} meters")

    if plot:
        # imshow converts the PIL image to an array immediately, so the file
        # can be closed as soon as the call returns.
        with Image.open(path) as source_image:
            plt.imshow(source_image)
        plt.axis('off')
        plt.show()

    return H


In [None]:
def compute_object_height(depth, new_depth, segmentation_mask, K, plot, calibrated_depth):
    """
    Estimate object heights from paired top/bottom pixels of a mask.

    For every (top, bottom) pixel pair returned by
    `extract_top_and_bottom_points`, both pixels are back-projected with
    `pixel_to_camera_coordinates` and the absolute difference of their camera
    Y coordinates is recorded as one height estimate.

    Args:
        depth: Reference distance used as a scalar: any sampled depth that
            differs from it by more than 3 is replaced by it.
        new_depth: Unused; reserved for future input.
        segmentation_mask: Binary mask, shape (H, W), object = 1.
        K: Camera intrinsic matrix, shape (3, 3).
        plot: Forwarded to `extract_top_and_bottom_points` to visualize the
            top/bottom point locations.
        calibrated_depth: Depth map sampled at the top/bottom pixels.

    Returns:
        list: Height estimates, one per usable top-bottom pair, sorted in
        ascending order. Empty if no pair is usable.
    """
    def _inside(u, v):
        # True when pixel (u, v) lies within the calibrated depth map.
        return 0 <= v < calibrated_depth.shape[0] and 0 <= u < calibrated_depth.shape[1]

    # All valid top pixels and their matching bottom pixels
    tops, bottoms = extract_top_and_bottom_points(segmentation_mask, plot)

    estimates = []
    for top_px, bottom_px in zip(tops, bottoms):
        u_t, v_t = top_px
        u_b, v_b = bottom_px

        # Sample depth only at pixels that fall inside the depth map
        if not _inside(u_t, v_t) or not _inside(u_b, v_b):
            continue
        z_top = calibrated_depth[v_t, u_t]
        z_bottom = calibrated_depth[v_b, u_b]

        # Fall back to the reference distance when a sample is too far from it
        if abs(z_top - depth) > 3:
            z_top = depth
        if abs(z_bottom - depth) > 3:
            z_bottom = depth

        # Skip pairs whose depth is still invalid (NaN, zero or negative)
        if any(np.isnan(z) for z in (z_top, z_bottom)) or z_top <= 0 or z_bottom <= 0:
            continue

        # Y coordinate (height axis) of each pixel in the camera frame
        y_top = pixel_to_camera_coordinates(u_t, v_t, z_top, K)[1]
        y_bottom = pixel_to_camera_coordinates(u_b, v_b, z_bottom, K)[1]
        estimates.append(abs(y_bottom - y_top))

    estimates.sort()
    return estimates


In [None]:
import os
import json
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
import torch

# Custom library imports (make sure you have these functions and models)
# from your_module import BuildingDataset, haversine, generate_buiding_panoptic_mask, compute_height, \
#     max_area, display_images_in_grid, colorize

# Batch height estimation for the dataset that carries left/right/intersection
# projections. For every record, each existing building-mask PNG (indices 0-4)
# is combined with the record's depth map through compute_height, and the
# accumulated results are rewritten to `json_save_path` after every record.

# File paths
image_root = r"F:\Projects\msc_personal_project\GoogleStreetScript"
json_path = r"F:\Projects\msc_personal_project\GoogleStreetScript\data\osm\processed_buildings_with_projections84.json"
json_save_path = Path(r"F:\Projects\msc_personal_project\result\0516\building_heights_v0810_view.json")
depth_root = r"F:\Projects\msc_personal_project\GoogleStreetScript\data\depth_maps"

# Load dataset (no transform is passed)
dataset = BuildingDataset(json_path)

# Buffer for JSON results; one entry per successfully processed record
results_data = []

# Iterate over the dataset
for i in range(len(dataset)):
    try:
        # `image` is only referenced by the commented-out segmentation code
        # below; the active code uses `metadata` alone.
        metadata, image = dataset[i]
        image_path = os.path.join(image_root, metadata["fixed_pitch_image"])

        # Distance from camera location to building (in meters)
        distance = metadata['intersection_point_projection']['distance_m']

        # Generate building panoptic mask
        # panoptic_inds, thing_category_ids, panoptic_inds_buildings, building_masks = generate_buiding_panoptic_mask(
        #     image, thing_category_names, stuff_category_names_building,
        #     stuff_category_names, category_name_to_id,
        #     groundingdino_model, sam_predictor,
        #     clipseg_processor, clipseg_model, device
        # )

        # Binarize mask
        # threshold = 0.5
        # binary_masks = [(mask >= threshold).to(torch.uint8) for mask in building_masks]

        # Anchor points (dicts read for 'x_pixel' and 'distance_m' by the
        # depth calibration inside compute_height)
        left_point = metadata['left_point_projection']
        right_point = metadata['right_point_projection']
        building_entries = []

        # thing_category_ids_ = [category_name_to_id[name] for name in stuff_category_names]
        # thing_category_ids_building = [category_name_to_id[name] for name in stuff_category_names_building]

        # Up to five masks per record ({i}_0 ... {i}_4); missing files are skipped
        for k in range(5):
            mask_path = f'F:/Projects/msc_personal_project/result/0516/{i}_{k}_building_mask.png'
            if os.path.exists(mask_path):
                # Load the mask as 8-bit grayscale
                binary_image = Image.open(mask_path).convert("L")
                # binary_image = Image.fromarray(binary_mask.squeeze().cpu().numpy() * 255).convert("L")

                # The depth map is keyed by osm_id, so it is the same file for every k
                depth_path = os.path.join(depth_root, f'{metadata["osm_id"]}.png')
                depth_image = Image.open(depth_path)

                # Compute category index (if needed)
                # max_ind = max_area(
                #     panoptic_inds, binary_image,
                #     thing_category_ids, thing_category_ids_building, thing_category_ids_
                # )

                # Estimate building height
                # (435.19 is focal length; here we pass 400 as an example focal length)
                # Positional tail: 400 -> focallength_px, False -> plot
                H = compute_height(left_point, right_point, distance, image_path, binary_image, depth_image, 400, False)
                # H = compute_height(distance, image_path, binary_image, 250, True)

                # H is the ascending-sorted list of height estimates, so H[-1]
                # is the largest one. An empty list raises IndexError here,
                # which the outer `except` reports for the whole record.
                print(metadata["new_height"])
                print(f"Building height: {H[-1]} meters")

                # NOTE(review): H is written to JSON as-is below; this assumes
                # its elements are JSON-serializable (e.g. not numpy float32)
                # — TODO confirm.
                building_entries.append({
                    "H": H,
                    # "max_ind": max_ind,
                    "building_mask_path": mask_path
                })

        # # Show and save mask grid
        # all_ids = thing_category_ids_ + thing_category_ids_building + thing_category_ids
        # display_images_in_grid(
        #     image,
        #     panoptic_inds,
        #     len(all_ids),
        #     rows=8,
        #     cols=6,
        #     path=f'F:/Projects/msc_personal_project/result/0516/{i}_all.png',
        #     thing_category_ids=all_ids
        # )

        # # Save overlay visualization
        # _, ax = plt.subplots()
        # ax.imshow(image)
        # ax.imshow(colorize(panoptic_inds), alpha=0.5)
        # plt.savefig(f'F:/Projects/msc_personal_project/result/0516/{i}_seg.png', bbox_inches='tight', dpi=300)
        # plt.close()

        # Record results (reached only if nothing above raised)
        result_entry = {
            "idx": i,
            "image_path": image_path,
            "buildings": building_entries,
            "original_height": metadata["new_height"]
        }
        results_data.append(result_entry)

        # Incrementally write to JSON: the file is opened in "w" mode, so the
        # complete `results_data` list is rewritten on every iteration
        with open(json_save_path, "w", encoding="utf-8") as json_file:
            json.dump(results_data, json_file, indent=4, ensure_ascii=False)

        print(f"[{i}] Done.")

    except Exception as e:
        # Any failure is reported and the loop moves on to the next record
        print(f"[{i}] Error occurred: {e}")

print(f"All results saved to {json_save_path}")


In [None]:
import os
import json
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
import torch

# Custom library imports (make sure you have these functions and models)
# from your_module import BuildingDataset, haversine, generate_buiding_panoptic_mask, compute_height, \
#     max_area, display_images_in_grid, colorize

# Batch height estimation for the "filtered_buildings_heading_distance_height"
# dataset. Same flow as the projections-based cell: for every record, each
# existing building-mask PNG (indices 0-4) is passed to compute_height together
# with the record's depth map, and the accumulated results are rewritten to
# `json_save_path` after every record.

# File paths
image_root = r"F:\Projects\msc_personal_project\GoogleStreetScript"
json_path = r"F:\Projects\msc_personal_project\GoogleStreetScript\data\osm\filtered_buildings_heading_distance_height.json"
json_save_path = Path(r"F:\Projects\msc_personal_project\result\0516\building_heights_v0802.json")
depth_root = r"F:\Projects\msc_personal_project\GoogleStreetScript\data\depth_maps"

# Load dataset (no transform is passed)
dataset = BuildingDataset(json_path)

# Results buffer for JSON; one entry per successfully processed record
results_data = []

# Iterate over the dataset
for i in range(len(dataset)):
    try:
        # `image` is only referenced by the commented-out segmentation code
        # below; the active code uses `metadata` alone.
        metadata, image = dataset[i]
        image_path = os.path.join(image_root, metadata["fixed_pitch_image"])

        # Distance from camera to building (meters)
        distance = metadata['camera_to_building_distance_m']

        # Generate building panoptic mask
        # panoptic_inds, thing_category_ids, panoptic_inds_buildings, building_masks = generate_buiding_panoptic_mask(
        #     image, thing_category_names, stuff_category_names_building,
        #     stuff_category_names, category_name_to_id,
        #     groundingdino_model, sam_predictor,
        #     clipseg_processor, clipseg_model, device
        # )

        # Binarize masks
        # threshold = 0.5
        # binary_masks = [(mask >= threshold).to(torch.uint8) for mask in building_masks]

        building_entries = []
        # These lookups need `category_name_to_id`, `stuff_category_names` and
        # `stuff_category_names_building` from an earlier cell; the resulting
        # lists are only consumed by commented-out code further down.
        thing_category_ids_ = [category_name_to_id[name] for name in stuff_category_names]
        thing_category_ids_building = [category_name_to_id[name] for name in stuff_category_names_building]

        # Up to five masks per record ({i}_0 ... {i}_4); missing files are skipped
        for k in range(5):
            mask_path = f'F:/Projects/msc_personal_project/result/0516/{i}_{k}_building_mask.png'
            if os.path.exists(mask_path):
                # Load the mask as 8-bit grayscale
                binary_image = Image.open(mask_path).convert("L")
                # binary_image = Image.fromarray(binary_mask.squeeze().cpu().numpy() * 255).convert("L")

                # The depth map is keyed by osm_id, so it is the same file for every k
                depth_path = os.path.join(depth_root, f'{metadata["osm_id"]}.png')
                depth_image = Image.open(depth_path)

                # Compute category index (if needed)
                # max_ind = max_area(
                #     panoptic_inds, binary_image,
                #     thing_category_ids, thing_category_ids_building, thing_category_ids_
                # )

                # Compute building height
                # (435.19 is a typical focal length; here 250 is provided as an example)
                # NOTE(review): compute_height as defined in this file takes
                # (left_point, right_point, depth, path, segmentation_mask,
                # depth_image, focallength_px, plot). With this 6-argument call
                # `distance` binds to left_point, `image_path` to right_point,
                # `binary_image` to depth, `depth_image` to path, 250 to
                # segmentation_mask and False to depth_image. This looks like a
                # call written for an earlier signature without anchor points
                # — confirm which compute_height is live before running.
                H = compute_height(distance, image_path, binary_image, depth_image, 250, False)
                # H = compute_height(distance, image_path, binary_image, 250, True)

                # H[-1] is the last element of the returned list; an empty
                # list raises IndexError, reported by the outer `except`.
                print(f"Building height: {H[-1]} meters, original height: {metadata['new_height']}")
                building_entries.append({
                    "H": H,
                    # "max_ind": max_ind,
                    "building_mask_path": mask_path
                })

        # # Show and save mask grid
        # all_ids = thing_category_ids_ + thing_category_ids_building + thing_category_ids
        # display_images_in_grid(
        #     image,
        #     panoptic_inds,
        #     len(all_ids),
        #     rows=8,
        #     cols=6,
        #     path=f'F:/Projects/msc_personal_project/result/0516/{i}_all.png',
        #     thing_category_ids=all_ids
        # )

        # # Save overlay visualization
        # _, ax = plt.subplots()
        # ax.imshow(image)
        # ax.imshow(colorize(panoptic_inds), alpha=0.5)
        # plt.savefig(f'F:/Projects/msc_personal_project/result/0516/{i}_seg.png', bbox_inches='tight', dpi=300)
        # plt.close()

        # Record this sample's result (reached only if nothing above raised)
        result_entry = {
            "idx": i,
            "image_path": image_path,
            "buildings": building_entries,
            "original_height": metadata["new_height"]
        }
        results_data.append(result_entry)

        # Incrementally write to JSON: the file is opened in "w" mode, so the
        # complete `results_data` list is rewritten on every iteration
        with open(json_save_path, "w", encoding="utf-8") as json_file:
            json.dump(results_data, json_file, indent=4, ensure_ascii=False)

        print(f"[{i}] Done.")

    except Exception as e:
        # Any failure is reported and the loop moves on to the next record
        print(f"[{i}] Error occurred: {e}")

print(f"All results have been saved to {json_save_path}")
