# **Grounded SAM2 for building part recognition**

A notebook that uses Grounded SAM2 to segment windows in a batch of building images and export the results.

https://github.com/autodistill/autodistill-grounded-sam-2

## **Preparation**

Check the CUDA version so that the matching PyTorch build gets installed.

In [1]:
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Jun__6_02:18:23_PDT_2024
Cuda compilation tools, release 12.5, V12.5.82
Build cuda_12.5.r12.5/compiler.34385749_0
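If you want to read the toolkit release programmatically instead of by eye, it can be pulled out of the `nvcc --version` text. This is a minimal stdlib sketch; `parse_nvcc_release` is a hypothetical helper, not part of any library. Keep in mind that PyTorch's pip wheels bundle their own CUDA runtime, so the toolkit release reported by `nvcc` (12.5 here) does not have to match the wheel's CUDA tag exactly; what generally matters is that the NVIDIA driver is new enough for the bundled runtime.

```python
import re
from typing import Optional, Tuple


# Hypothetical helper: extract (major, minor) from `nvcc --version` output.
def parse_nvcc_release(nvcc_output: str) -> Optional[Tuple[int, int]]:
    # nvcc prints a line like "Cuda compilation tools, release 12.5, V12.5.82"
    match = re.search(r"release (\d+)\.(\d+)", nvcc_output)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


sample = """Cuda compilation tools, release 12.5, V12.5.82
Build cuda_12.5.r12.5/compiler.34385749_0"""
print(parse_nvcc_release(sample))  # (12, 5)
```

In the notebook you could feed it `subprocess.run(["nvcc", "--version"], capture_output=True, text=True).stdout` instead of a literal string.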


This setup only worked with the PyTorch build for CUDA 12.1.
After installing it, you may need to restart the runtime for the new build to load properly.

In [2]:
# uninstall previously installed version if needed
!pip uninstall torch torchvision torchaudio -y

# # install PyTorch for CUDA 11.8
# !pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118

# install PyTorch for CUDA 12.1
!pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121

Found existing installation: torch 2.8.0+cu126
Uninstalling torch-2.8.0+cu126:
  Successfully uninstalled torch-2.8.0+cu126
Found existing installation: torchvision 0.23.0+cu126
Uninstalling torchvision-0.23.0+cu126:
  Successfully uninstalled torchvision-0.23.0+cu126
Found existing installation: torchaudio 2.8.0+cu126
Uninstalling torchaudio-2.8.0+cu126:
  Successfully uninstalled torchaudio-2.8.0+cu126
Looking in indexes: https://download.pytorch.org/whl/cu121
Collecting torch==2.5.1
  Downloading https://download.pytorch.org/whl/cu121/torch-2.5.1%2Bcu121-cp312-cp312-linux_x86_64.whl (780.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m780.4/780.4 MB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchvision==0.20.1
  Downloading https://download.pytorch.org/whl/cu121/torchvision-0.20.1%2Bcu121-cp312-cp312-linux_x86_64.whl (7.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.3/7.3 MB[0m [31m132.6 MB/s[0m eta [36m0:00:00[0m


In [3]:
!git clone https://github.com/autodistill/autodistill-grounded-sam-2.git

Cloning into 'autodistill-grounded-sam-2'...
remote: Enumerating objects: 98, done.
remote: Counting objects: 100% (22/22), done.
remote: Compressing objects: 100% (8/8), done.
remote: Total 98 (delta 17), reused 14 (delta 14), pack-reused 76 (from 1)
Receiving objects: 100% (98/98), 24.97 KiB | 1.31 MiB/s, done.
Resolving deltas: 100% (45/45), done.


In [4]:
%cd autodistill-grounded-sam-2
!pip install --no-build-isolation -e .
%cd ..

/content/autodistill-grounded-sam-2
Obtaining file:///content/autodistill-grounded-sam-2
  Preparing metadata (setup.py) ... done
Collecting autodistill (from autodistill_grounded_sam_2==0.1.0)
  Downloading autodistill-0.1.29-py3-none-any.whl.metadata (32 kB)
Collecting supervision (from autodistill_grounded_sam_2==0.1.0)
  Downloading supervision-0.26.1-py3-none-any.whl.metadata (13 kB)
Collecting roboflow (from autodistill_grounded_sam_2==0.1.0)
  Downloading roboflow-1.2.9-py3-none-any.whl.metadata (9.7 kB)
Collecting autodistill_florence_2 (from autodistill_grounded_sam_2==0.1.0)
  Downloading autodistill_florence_2-0.1.1-py3-none-any.whl.metadata (3.7 kB)
Collecting flash-attn (from autodistill_florence_2->autodistill_grounded_sam_2==0.1.0)
  Downloading flash_attn-2.8.3.tar.gz (8.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.4/8.4 MB[0m [31m77.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collec

Checking the torch and CUDA versions

In [5]:
import torch
print(torch.__version__, torch.cuda.is_available())
print(torch.version.cuda)

2.5.1+cu121 True
12.1
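The `+cu121` suffix in `torch.__version__` is the wheel's local version tag, so it is a quick way to confirm that the CUDA 12.1 build this notebook needs is the one actually loaded (for example after a runtime restart). A small sketch; `cuda_tag` and `EXPECTED_TAG` are hypothetical names, and the version string is passed as a literal to keep the block self-contained.

```python
from typing import Optional

# Hypothetical expected build tag for this notebook (PyTorch wheels for CUDA 12.1).
EXPECTED_TAG = "cu121"


def cuda_tag(torch_version: str) -> Optional[str]:
    """Return the local version tag, e.g. 'cu121' from '2.5.1+cu121', or None."""
    if "+" not in torch_version:
        return None
    return torch_version.split("+", 1)[1]


# In the notebook you would pass torch.__version__ here.
installed = "2.5.1+cu121"
if cuda_tag(installed) != EXPECTED_TAG:
    print(f"Unexpected torch build {installed!r}; reinstall and restart the runtime.")
else:
    print("torch build matches", EXPECTED_TAG)
```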


Checking if autodistill-grounded-sam-2 is installed

In [6]:
!pip show autodistill-grounded-sam-2

Name: autodistill_grounded_sam_2
Version: 0.1.0
Summary: Use Segment Anything 2, grounded with Florence-2, to auto-label data for use in training vision models.
Home-page: https://github.com/autodistill/autodistill-grounded-sam-2
Author: Roboflow
Author-email: autodistill@roboflow.com
License: 
Location: /content/autodistill-grounded-sam-2
Editable project location: /content/autodistill-grounded-sam-2
Requires: autodistill, autodistill_florence_2, numpy, opencv-python, roboflow, supervision, torch
Required-by: 
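The same check can be done from Python with the standard library's `importlib.metadata`, which is handy if later cells should fail early with a clear message. A minimal sketch; `installed_version` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# Prints the version string, or None if the package is not installed.
print(installed_version("autodistill_grounded_sam_2"))
```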


---
## **Installing Grounded SAM2**

Clone the repository and install the dependencies.

***Grounded-SAM object detection model with the AutoDistill framework***

`rf_groundingdino` is a packaging of GroundingDINO, an open-set object detector that locates objects described in natural language (e.g., the text prompt "window").
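Like other detectors, a text-prompted model such as GroundingDINO returns candidate boxes with confidence scores, and low-scoring boxes are usually dropped before segmentation. The numpy sketch below shows that filtering on toy boxes (illustrative values, not model output); the 0.30 cut-off mirrors `DEFAULT_CONF_THRESHOLD` in the app's `Config`, and `filter_by_confidence` is a hypothetical helper, not an autodistill function.

```python
import numpy as np


def filter_by_confidence(xyxy: np.ndarray, confidence: np.ndarray, threshold: float = 0.30):
    """Keep boxes (N x 4, as x1 y1 x2 y2) whose score is at least `threshold`."""
    keep = confidence >= threshold
    return xyxy[keep], confidence[keep]


# Toy detections for illustration only.
boxes = np.array([[10, 10, 60, 90], [70, 12, 120, 88], [5, 5, 200, 150]], dtype=float)
scores = np.array([0.82, 0.41, 0.12])
kept_boxes, kept_scores = filter_by_confidence(boxes, scores)
print(kept_boxes.shape)  # (2, 4)
```

With supervision's `sv.Detections` (imported as `sv` in this notebook), the equivalent is boolean indexing, e.g. `detections[detections.confidence >= 0.30]`.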

In [7]:
!pip install git+https://github.com/autodistill/autodistill-grounded-sam-2 rf_groundingdino -q

  Preparing metadata (setup.py) ... done
  Building wheel for autodistill_grounded_sam_2 (setup.py) ... done


In [8]:
!pip install transformers==4.49

Collecting transformers==4.49
  Downloading transformers-4.49.0-py3-none-any.whl.metadata (44 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers<0.22,>=0.21 (from transformers==4.49)
  Downloading tokenizers-0.21.4-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Downloading transformers-4.49.0-py3-none-any.whl (10.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.0/10.0 MB[0m [31m67.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading tokenizers-0.21.4-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m89.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tokenizers, transformers
  Attempting uninstall: tokenizers
    Fou

In [9]:
!pip install -q svgpathtools gradio

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/68.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m68.3/68.3 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/67.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.1/67.1 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[?25h

---
## **Importing the packages and testing SAM2**

Importing takes a while (about 4 minutes on an A100).

In [10]:
from autodistill_grounded_sam_2 import GroundedSAM2
print('GroundedSAM2 is loaded successfully')

Importing from timm.models.layers is deprecated, please import via timm.layers


GroundedSAM2 is loaded successfully


In [11]:
from autodistill.detection import CaptionOntology
from autodistill.utils import plot
import numpy as np
import cv2
import os
import random
import supervision as sv

print('Import is fine!')

Import is fine!


## **Gradio App**
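The classifier in the app leans on the circularity ratio 4πA/P² (`CIRCULARITY_MIN` in `Config`), which equals 1 for a perfect circle and drops for elongated or cornered outlines. Below is a self-contained sketch that computes it for a closed polygon with the shoelace formula; `circularity` is a hypothetical standalone helper, whereas the app computes the same quantity from shapely's `area` and `length`.

```python
import math
from typing import List, Tuple

Point = Tuple[float, float]


def circularity(vertices: List[Point]) -> float:
    """4*pi*A / P**2 for a closed polygon given as an ordered vertex list."""
    n = len(vertices)
    area2 = 0.0  # twice the signed area (shoelace formula)
    perimeter = 0.0
    for i in range(n):
        x1, y1 = vertices[i]
        x2, y2 = vertices[(i + 1) % n]  # wrap around to close the polygon
        area2 += x1 * y2 - x2 * y1
        perimeter += math.hypot(x2 - x1, y2 - y1)
    area = abs(area2) / 2.0
    return 4.0 * math.pi * area / (perimeter * perimeter) if perimeter > 0 else 0.0


square = [(0, 0), (10, 0), (10, 10), (0, 10)]
disc = [(50 * math.cos(2 * math.pi * k / 360), 50 * math.sin(2 * math.pi * k / 360))
        for k in range(360)]
print(round(circularity(square), 4))  # 0.7854 (pi/4)
print(round(circularity(disc), 4))
```

A square scores π/4 ≈ 0.785, well below the 0.90 threshold used for medium and large shapes, so a clean square can never pass the circle test, while a densely sampled circle scores just under 1.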

In [14]:

"""
Architectural Facade Understanding (GroundedSAM2)
Segment, vectorize, and classify facade elements (incl. arches, circles, ellipses & wavy shapes).
"""

import os
import cv2
import json
import zipfile
import logging
import tempfile
import numpy as np
import pandas as pd
import gradio as gr
from datetime import datetime
from dataclasses import dataclass

# -----------------------------------------------------------------------------
# Logging & torch quiet mode
# -----------------------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
try:
    import torch
    torch.set_grad_enabled(False)
except Exception:
    pass

# -----------------------------------------------------------------------------
# Globals
# -----------------------------------------------------------------------------
base_model = None
LAST_ONTOLOGY_SIG = None  # auto-reinit if user changes prompt


@dataclass
class Config:
    APP_TITLE = "Architectural Facade Understanding (GroundedSAM2)"
    APP_DESC = """
# 🏛️ Architectural Facade Understanding

A complete workflow to **detect**, **segment**, **vectorize**, and **classify** facade elements using GroundedSAM2 and geometry:

**What it does**
- 🎯 Prompt-driven detection & segmentation (GroundingDINO + SAM2 via autodistill)
- 🧼 Smooth **SVG** conversion of masks (preserves curves/arches)
- 🧠 Geometry-aware classification: **circle**, **ellipse**, **rectangle**, polygons (triangle…hexagon…n-gon), **arched_shape**, and **wavy_shape**
- 🗂️ Professional outputs: **SVG**, **CSV**, **JSON**, **JPG**, consolidated **ZIP**

**Roundness logic**
- Circles must pass several tests (circularity, aspect ratio, radius std, Hausdorff, min-enclosing-circle residuals) with a Hough fallback for small blobs.
- Ellipses are detected via direct ellipse fit + residual checks.
"""
    DEFAULT_BOX_THRESHOLD = 0.02
    DEFAULT_TEXT_THRESHOLD = 0.02
    DEFAULT_CONF_THRESHOLD = 0.30
    DEFAULT_MAX_INCLUSIONS = 2

    # Rectangle acceptance
    RECTANGLE_AREA_RATIO = 0.80
    RECTANGLE_HD_THRESHOLD = 8.0
    RECT_ORTHO_TOL_DEG = 12  # edges near 90°

    # Circle thresholds (for medium+ shapes)
    CIRCULARITY_MIN = 0.90             # 4πA/P^2
    ELLIPSE_AR_MAX = 1.06              # circle ~= ellipse AR <= 1.06
    R_STD_FRAC_MAX = 0.08              # std(radius)/max(w,h)
    CIRC_HAUSDORFF_MAX = 6.5
    MEC_RES_FRAC_MAX = 0.08            # mean(|d-R|)/R

    # Small-shape relaxed thresholds (scale = max(w,h) < SMALL_SHAPE_PX)
    SMALL_SHAPE_PX = 80
    CIRCULARITY_MIN_SMALL = 0.86
    ELLIPSE_AR_MAX_SMALL = 1.12
    R_STD_FRAC_MAX_SMALL = 0.12
    CIRC_HAUSDORFF_MAX_SMALL = 8.5
    MEC_RES_FRAC_MAX_SMALL = 0.12

    # Ellipse thresholds (non-circular)
    ELLIPSE_MIN_AR = 1.08               # must be more elongated than a circle
    ELLIPSE_MAX_AR = 3.50
    ELLIPSE_RESIDUAL_FRAC_MAX = 0.14    # avg point-to-ellipse distance / max(rx,ry)

    # Wavy shape heuristics
    WAVY_MIN_INFLECTIONS = 6
    WAVY_PERIM_TO_HULL_MIN = 1.18
    WAVY_AREA_HULL_RATIO_MIN = 0.85

    # --- NEW: Arch guards / tuning to reduce rectangle mislabels ---
    ARCH_RECT_GUARD_AREA_RATIO = 0.93     # if rectangle fit is this good...
    ARCH_RECT_GUARD_HD = 3.0              # ...and close in Hausdorff, don't call it an arch
    ARCH_BASE_COVERAGE_MIN = 0.70         # flat base must span >= 70% of width
    ARCH_MIN_ARC_SPAN_DEG = 150           # top arc must span at least 150°
    ARCH_CIRCLE_INLIER_MIN_FRAC = 0.45    # >=45% of top points on the arc

    PLOT_DPI = 150
    PLOT_FIGSIZE = (10, 10)


# -----------------------------------------------------------------------------
# Small utilities
# -----------------------------------------------------------------------------
def _to_list(x):
    if x is None:
        return []
    try:
        return list(np.array(x).tolist())
    except Exception:
        try:
            return list(x)
        except Exception:
            return []


def get_bounding_boxes(results):
    bbs = []
    xyxy = getattr(results, "xyxy", None)
    if xyxy is None and hasattr(results, "boxes"):
        xyxy = getattr(getattr(results, "boxes", None), "xyxy", None)
    if xyxy is None and isinstance(results, (list, tuple, np.ndarray)):
        xyxy = results
    if xyxy is None:
        return bbs
    xyxy = np.array(xyxy)
    if xyxy.ndim == 2 and xyxy.shape[1] >= 4:
        for x1, y1, x2, y2, *_ in xyxy:
            bbs.append([int(x1), int(y1), int(x2), int(y2)])
    else:
        for item in xyxy:
            if len(item) >= 4:
                x1, y1, x2, y2 = item[:4]
                bbs.append([int(x1), int(y1), int(x2), int(y2)])
    return bbs


def create_inclusion_mask(bounding_boxes, max_inclusions=3):
    n = len(bounding_boxes)
    mask = [False] * n
    for i in range(n):
        incl = 0
        x1_i, y1_i, x2_i, y2_i = bounding_boxes[i]
        for j in range(n):
            if i == j:
                continue
            x1_j, y1_j, x2_j, y2_j = bounding_boxes[j]
            if x1_i <= x1_j and y1_i <= y1_j and x2_i >= x2_j and y2_i >= y2_j:
                incl += 1
                if incl > max_inclusions:
                    mask[i] = True
                    break
    return mask


def _ensure_dir(p):
    os.makedirs(p, exist_ok=True)
    return p


def _parse_prompt_to_ontology_dict(prompt_text: str):
    if not prompt_text or not prompt_text.strip():
        return None
    pairs = []
    for chunk in [c.strip() for c in prompt_text.split(",") if c.strip()]:
        if "->" in chunk:
            cap, label = [t.strip() for t in chunk.split("->", 1)]
        elif ":" in chunk:
            cap, label = [t.strip() for t in chunk.split(":", 1)]
        else:
            cap = label = chunk.strip()
        if cap and label:
            pairs.append((cap, label))
    if not pairs:
        return None
    return {cap: label for cap, label in pairs}


# -----------------------------------------------------------------------------
# SVG from mask (dense contours to preserve curvature)
# -----------------------------------------------------------------------------
def _mask_to_svg_string(mask_uint8: np.ndarray) -> str:
    """
    Convert a binary mask (0/255) to SVG by tracing contours.
    CHAIN_APPROX_NONE + light blur → smoother curves for arches/circles.
    """
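    # Normalize any binary mask (bool, 0/1 or 0/255) to 0/255 before blurring;
    # a 0/1 mask would otherwise be wiped out by the 127 threshold below.
    m = (np.asarray(mask_uint8) > 0).astype(np.uint8) * 255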
    m = cv2.GaussianBlur(m, (5, 5), 0)
    _, m = cv2.threshold(m, 127, 255, cv2.THRESH_BINARY)

    contours, _ = cv2.findContours(m, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    h, w = m.shape[:2]

    try:
        import svgwrite
        dwg = svgwrite.Drawing(size=(w, h))
        dwg.viewbox(0, 0, w, h)

        def contour_to_path(contour):
            if len(contour) == 0:
                return ""
            pts = [f"{pt[0][0]},{pt[0][1]}" for pt in contour]
            return "M " + " L ".join(pts) + " Z"

        for contour in contours:
            d = contour_to_path(contour)
            if d:
                dwg.add(dwg.path(d=d, fill='black', stroke='none'))
        return dwg.tostring()
    except Exception:
        elems = []
        for c in contours:
            pts = [f"{pt[0][0]},{pt[0][1]}" for pt in c]
            if not pts:
                continue
            d = "M " + " L ".join(pts) + " Z"
            elems.append(f'<path d="{d}" fill="black" stroke="none" />')
        return f'<svg xmlns="http://www.w3.org/2000/svg" width="{w}" height="{h}" viewBox="0 0 {w} {h}">{"".join(elems)}</svg>'


# -----------------------------------------------------------------------------
# Classification (rectangle → arch → circle → ellipse → polygon → wavy)
# -----------------------------------------------------------------------------
def _classify_svg_assets(svg_file: str, out_dir: str):
    """
    Classify shapes from SVG and produce:
      - classified_output.jpg (labeled preview)
      - classified_output.svg (clean, typed SVG with id/class attrs)
    Detects: circle, ellipse, rectangle, polygons (triangle…dodecagon…N), 'arched_shape', and 'wavy_shape'.
    """
    try:
        import math
        import numpy as np
        import svgwrite
        from svgpathtools import svg2paths2
        from shapely.geometry import Polygon, Point, LineString
        from shapely.ops import unary_union
        import matplotlib.pyplot as plt
        from matplotlib.patches import Polygon as MplPolygon, Circle as MplCircle, Ellipse as MplEllipse, Rectangle as MplRect
    except Exception as e:
        raise gr.Error(
            f"Missing dependency for classification step: {e}\n"
            "Install: pip install svgpathtools shapely matplotlib svgwrite"
        )

    # ---- helpers ----
    def sample_points(path, n=360):
        return [(float(z.real), float(z.imag)) for z in (path.point(i/(n-1)) for i in range(n))]

    def circle_metrics(poly: "Polygon"):
        centroid = poly.centroid
        cx, cy = centroid.x, centroid.y
        coords = list(poly.exterior.coords)
        if not coords:
            return 0.0, 1e9, (cx, cy, 0.0), 1e9
        dists = [math.hypot(x-cx, y-cy) for x,y in coords]
        r = float(np.mean(dists)) if len(dists) else 0.0
        circle = Point(cx,cy).buffer(r, resolution=256)
        area_ratio = poly.area / (math.pi*r*r) if r > 0 else 0.0
        hd = poly.hausdorff_distance(circle)
        rad_std = float(np.std(dists))
        return area_ratio, hd, (cx,cy,r), rad_std

    def polygon_sides(poly, angle_tol=22, simplify_ratio=0.012):
        per = float(poly.length)
        tol = max(1.0, simplify_ratio * per)
        simp = LineString(list(poly.exterior.coords)).simplify(tol, preserve_topology=False)
        coords = list(simp.coords)
        if len(coords) < 3: return 0, []
        if coords[0] == coords[-1]: coords = coords[:-1]
        edges = []
        for i in range(len(coords)):
            x1,y1 = coords[i]; x2,y2 = coords[(i+1)%len(coords)]
            vx, vy = x2-x1, y2-y1
            nrm = math.hypot(vx,vy)
            if nrm > 1e-6:
                edges.append((vx/nrm, vy/nrm))
        dirs = []
        simp_coords = [coords[0]]
        for j,v in enumerate(edges):
            ang = math.degrees(math.atan2(v[1], v[0]))
            if not dirs:
                dirs.append(ang)
            else:
                prev = dirs[-1]
                diff = abs((ang-prev+180)%360 - 180)
                if diff > angle_tol:
                    dirs.append(ang)
                    simp_coords.append(coords[j])
        simp_coords.append(coords[-1])
        return max(0, len(dirs)), simp_coords

    # --- PCA + RANSAC helpers (for arches) ---
    def _pca_upright(coords):
        P = np.asarray(coords, dtype=float)
        C = P.mean(axis=0)
        X = P - C
        cov = np.cov(X.T)
        vals, vecs = np.linalg.eigh(cov)
        major = vecs[:, 1]
        angle = np.arctan2(major[1], major[0]) - np.pi/2.0
        ca, sa = np.cos(-angle), np.sin(-angle)
        R = np.array([[ca, -sa], [sa, ca]])
        Y = X @ R.T
        return Y, angle, C

    def _ransac_circle(points: np.ndarray, iters=450, tol=2.5, min_inliers=16):
        if len(points) < 6:
            return None
        best = None
        rng = np.random.default_rng(123)
        for _ in range(iters):
            idx = rng.choice(len(points), size=3, replace=False)
            (x1,y1),(x2,y2),(x3,y3) = points[idx]
            temp = 2*(x1*(y2 - y3) + x2*(y3 - y1) + x3*(y1 - y2))
            if abs(temp) < 1e-6:
                continue
            ux = ((x1*x1 + y1*y1)*(y2 - y3) + (x2*x2 + y2*y2)*(y3 - y1) + (x3*x3 + y3*y3)*(y1 - y2)) / temp
            uy = ((x1*x1 + y1*y1)*(x3 - x2) + (x2*x2 + y2*y2)*(x1 - x3) + (x3*x3 + y3*y3)*(x2 - x1)) / temp
            r = np.sqrt((points[:,0]-ux)**2 + (points[:,1]-uy)**2)
            R = np.median(r)
            dist = np.abs(r - R)
            inliers = dist < tol
            score = inliers.sum()
            if score >= (best[0] if best else -1) and score >= min_inliers:
                best = (score, ux, uy, float(R), inliers)
        if not best:
            return None
        _, ux, uy, R, inliers = best
        return (ux, uy, R, inliers)

    def _ransac_line(points: np.ndarray, iters=300, tol=2.0):
        if len(points) < 3:
            return None
        best = None
        rng = np.random.default_rng(321)
        for _ in range(iters):
            i1 = rng.integers(0, len(points)); i2 = rng.integers(0, len(points))
            if i1 == i2:
                continue
            x1,y1 = points[i1]; x2,y2 = points[i2]
            if x2 == x1:
                a = 1e9; b = 0.0
            else:
                a = (y2 - y1) / (x2 - x1)
                b = y1 - a*x1
            x = points[:,0]; y = points[:,1]
            d = np.abs(a*x - y + b) / np.sqrt(a*a + 1.0)
            inliers = d < tol
            score = inliers.sum()
            if (best is None or score > best[0]) and (abs(a) < 0.12):
                best = (score, float(a), float(b), inliers)
        return None if best is None else (best[1], best[2], best[3])

    def _is_arched_shape(coords):
        # Need enough points
        if not coords or len(coords) < 20:
            return False

        P = np.asarray(coords, dtype=float)

        # Build polygon and basic sanity
        poly = Polygon(P).buffer(0)
        if (poly.is_empty) or (not poly.is_valid):
            return False

        # QUICK REJECTION: too rectangle-like?
        rect = poly.minimum_rotated_rectangle
        area = float(poly.area) if poly.area > 0 else 0.0
        rect_area = float(rect.area) if rect.area > 0 else 1.0
        rect_area_ratio = area / rect_area
        rect_hd = float(poly.hausdorff_distance(rect))
        if (rect_area_ratio >= Config.ARCH_RECT_GUARD_AREA_RATIO) and (rect_hd <= Config.ARCH_RECT_GUARD_HD):
            return False  # looks very much like a rectangle

        # PCA upright as before
        Y, angle, center = _pca_upright(coords)
        x = Y[:,0]; y = Y[:,1]
        xmin, xmax = x.min(), x.max()
        ymin, ymax = y.min(), y.max()
        w, h = (xmax - xmin), (ymax - ymin)
        if w <= 0 or h <= 0:
            return False

        # Arch should be tall-ish vs width
        if h < 0.85 * w:
            return False

        # If it's nearly a circle, reject (arches aren't round blobs)
        cx0, cy0 = Y.mean(axis=0)
        R = np.sqrt((Y[:,0]-cx0)**2 + (Y[:,1]-cy0)**2)
        if R.std() < 0.06 * max(w, h):
            return False

        # Split top/bottom bands (slightly tightened)
        top_cut = ymin + 0.55*(ymax - ymin)
        bot_cut = ymin + 0.82*(ymax - ymin)
        top_pts = Y[Y[:,1] <= top_cut]
        bot_pts = Y[Y[:,1] >= bot_cut]
        if len(top_pts) < 18 or len(bot_pts) < 10:
            return False

        # Fit a nearly-horizontal base line on bottom points
        line = _ransac_line(bot_pts, iters=300, tol=max(1.8, 0.012*max(w,h)))
        if line is None:
            return False
        a,b,inliers_bot = line
        inlier_ratio_bot = float(inliers_bot.sum()) / float(len(bot_pts))
        y_bottom = np.median(bot_pts[inliers_bot,1]) if inliers_bot.any() else np.median(bot_pts[:,1])

        # Base must be flat, near the bottom, and cover most width
        if (inlier_ratio_bot < 0.60) or (abs(a) >= 0.12) or ((ymax - y_bottom) > 0.18*h):
            return False

        # NEW: base coverage across width
        base_x = bot_pts[inliers_bot, 0] if inliers_bot.any() else bot_pts[:,0]
        if base_x.size >= 2:
            base_cov = (base_x.max() - base_x.min()) / max(w, 1e-6)
            if base_cov < Config.ARCH_BASE_COVERAGE_MIN:
                return False

        # Fit a circular arc to the top
        circ = _ransac_circle(top_pts, iters=450, tol=max(2.4, 0.018*max(w,h)),
                              min_inliers=max(18, int(0.35*len(top_pts))))
        if circ is None:
            return False
        cx, cy, r, inliers_c = circ

        # Require a healthy fraction of top points to lie on the arc
        inlier_frac_circ = float(inliers_c.sum()) / float(len(top_pts))
        if inlier_frac_circ < Config.ARCH_CIRCLE_INLIER_MIN_FRAC:
            return False

        # Arc span (tightened)
        A = top_pts[inliers_c] - np.array([cx, cy])
        ang = np.arctan2(A[:,1], A[:,0])
        ang = np.unwrap(np.sort(ang))
        arc_span = (ang.max() - ang.min())
        if arc_span < np.deg2rad(Config.ARCH_MIN_ARC_SPAN_DEG):
            return False

        # Circle center should be above the arch centroid band (keeps it “cap-like”)
        if cy > (ymin + 0.60*h):
            return False

        # Radius should be plausible relative to width
        if not (0.30*w <= r <= 0.85*w):
            return False

        return True

    def _is_circle_adaptive(poly, circularity, ellipse_ar, mec_residual, circ_hd, rad_std, scale):
        if scale < Config.SMALL_SHAPE_PX:
            circ_ok  = circularity >= Config.CIRCULARITY_MIN_SMALL
            ar_ok    = ellipse_ar <= Config.ELLIPSE_AR_MAX_SMALL
            std_ok   = rad_std   <= Config.R_STD_FRAC_MAX_SMALL * max(scale, 1.0)
            hd_ok    = circ_hd   <= Config.CIRC_HAUSDORFF_MAX_SMALL
            mec_ok   = mec_residual <= Config.MEC_RES_FRAC_MAX_SMALL
        else:
            circ_ok  = circularity >= Config.CIRCULARITY_MIN
            ar_ok    = ellipse_ar <= Config.ELLIPSE_AR_MAX
            std_ok   = rad_std   <= Config.R_STD_FRAC_MAX * max(scale, 1.0)
            hd_ok    = circ_hd   <= Config.CIRC_HAUSDORFF_MAX
            mec_ok   = mec_residual <= Config.MEC_RES_FRAC_MAX
        return circ_ok and ar_ok and std_ok and hd_ok and mec_ok

    def _ellipse_fit_metrics(pts_np):
        """
        Returns (xc, yc, rx, ry, angle_deg, mean_residual_frac) or None
        angle in degrees, rx >= ry.
        """
        if len(pts_np) < 5:
            return None
        try:
            (xc, yc), (d1, d2), angle = cv2.fitEllipse(pts_np.astype(np.float32))
            # cv2 returns full axis lengths (diameters): d1 lies along `angle`,
            # d2 is perpendicular to it. Either one may be the major axis.
            a1, a2 = d1 / 2.0, d2 / 2.0
            if a1 <= 0 or a2 <= 0:
                return None
            # Rotate points into the ellipse frame so xr runs along the d1 axis
            cos_t = np.cos(np.deg2rad(angle)); sin_t = np.sin(np.deg2rad(angle))
            X = pts_np[:,0] - xc; Y = pts_np[:,1] - yc
            xr =  X*cos_t + Y*sin_t
            yr = -X*sin_t + Y*cos_t
            # Algebraic residual to the unit ellipse (already scale-normalized)
            res = np.abs((xr/a1)**2 + (yr/a2)**2 - 1.0)
            mean_res_frac = float(np.mean(res))
            # Report rx >= ry; if the axes swap, the major axis lies at angle + 90
            if a1 >= a2:
                rx, ry, ang = a1, a2, angle
            else:
                rx, ry, ang = a2, a1, (angle + 90.0) % 180.0
            return float(xc), float(yc), float(rx), float(ry), float(ang), mean_res_frac
        except Exception:
            return None

    def _is_rectangle(poly, rect_area_ratio, rect_hd):
        # Already coarse checks passed in caller. Add near-orthogonal angles:
        coords = list(poly.minimum_rotated_rectangle.exterior.coords)
        if len(coords) < 4:
            return False
        # vectors
        vecs = []
        for i in range(4):
            x1,y1 = coords[i]; x2,y2 = coords[(i+1)%4]
            vx, vy = x2-x1, y2-y1
            n = np.hypot(vx, vy)
            if n > 1e-6:
                vecs.append((vx/n, vy/n))
        if len(vecs) < 2:
            return False
        dots = []
        for i in range(2):
            a = np.array(vecs[i]); b = np.array(vecs[(i+1) % len(vecs)])
            dots.append(abs(np.dot(a,b)))
        # near 90° => dot ~ 0
        return (rect_area_ratio > Config.RECTANGLE_AREA_RATIO) and (rect_hd < Config.RECTANGLE_HD_THRESHOLD) and all(d < np.cos(np.deg2rad(90-Config.RECT_ORTHO_TOL_DEG)) for d in dots)

    def _wavy_score(poly):
        """
        Heuristics: count curvature sign changes along simplified contour
        + compare perimeter/hull and area/hull.
        """
        coords = np.asarray(list(poly.exterior.coords), dtype=float)
        if len(coords) < 20:
            return 0, 1.0, 0.0, 0
        # curvature via turning angle differences
        V = coords[1:] - coords[:-1]
        ang = np.unwrap(np.arctan2(V[:,1], V[:,0]))
        d_ang = np.diff(ang)
        # sign changes
        sgn = np.sign(d_ang + 1e-9)
        changes = np.sum(np.abs(np.diff(sgn)) > 1e-6)
        hull = poly.convex_hull
        perim_ratio = float(poly.length) / max(hull.length, 1e-6)
        area_ratio = float(poly.area) / max(hull.area, 1e-6)
        return changes, perim_ratio, area_ratio, len(coords)

    # ---- parse & classify ----
    paths, attrs, _ = svg2paths2(svg_file)
    results = []

    for i,(p,a) in enumerate(zip(paths,attrs)):
        pts = sample_points(p, 360)
        if pts and pts[0] != pts[-1]:
            pts.append(pts[0])
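        poly = Polygon(pts).buffer(0)
        if poly.is_empty or not poly.is_valid:
            continue
        # buffer(0) can split a self-intersecting outline; keep the largest piece
        if poly.geom_type == "MultiPolygon":
            poly = max(poly.geoms, key=lambda g: g.area)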

        area = float(poly.area)
        perim = float(poly.length) if poly.length > 1e-9 else 1e-9
        circularity = float(4.0*np.pi*area/(perim*perim))  # 1 for perfect circle

        # bounds
        minx, miny, maxx, maxy = poly.bounds
        W = maxx - minx
        H = maxy - miny
        scale = max(W, H)

        # min enclosing circle residual (OpenCV)
        pts_np = np.array(poly.exterior.coords, dtype=np.float32)
        (mec_cx, mec_cy), mec_r = cv2.minEnclosingCircle(pts_np)
        mec_r = float(mec_r)
        if mec_r <= 1e-6:
            mec_residual = 1e9
        else:
            d = np.sqrt((pts_np[:,0]-mec_cx)**2 + (pts_np[:,1]-mec_cy)**2)
            mec_residual = float(np.mean(np.abs(d - mec_r))) / mec_r  # normalized

        # ellipse fit aspect ratio
        ellipse_ar = 9.9
        fit_ell = _ellipse_fit_metrics(pts_np)
        if fit_ell is not None:
            xc, yc, rx, ry, ang_deg, ell_res_frac = fit_ell
            if rx > 0 and ry > 0:
                big, small = max(rx, ry), min(rx, ry)
                ellipse_ar = float(big / small)
        else:
            xc = yc = rx = ry = ang_deg = ell_res_frac = None

        # rectangle fit
        rect = poly.minimum_rotated_rectangle
        rect_area_ratio = area / float(rect.area) if rect.area > 0 else 0.0
        rect_hd = float(poly.hausdorff_distance(rect))
        is_rect_coarse = (rect_area_ratio > Config.RECTANGLE_AREA_RATIO) and (rect_hd < Config.RECTANGLE_HD_THRESHOLD)

        # circle fit (centroid-based) + Hausdorff + radius std
        circ_area_ratio, circ_hd, (cx,cy,r_est), rad_std = (0, 1e9, (0,0,0), 1e9)
        try:
            circ_area_ratio, circ_hd, (cx,cy,r_est), rad_std = circle_metrics(poly)
        except Exception:
            pass

        # polygon corners
        n_sides, simp_coords = polygon_sides(poly, angle_tol=22, simplify_ratio=0.012)

        # ---- classification cascade (rectangle before arched_shape) ----
        cls = "unclassified"
        geom = list(poly.exterior.coords)
        export_hint = "path"

        # 1) rectangle (strong rects shouldn't be called arches)
        if _is_rectangle(poly, rect_area_ratio, rect_hd):
            cls = "rectangle"
            rect_coords = list(rect.exterior.coords)
            xs = [c[0] for c in rect_coords[:-1]]
            ys = [c[1] for c in rect_coords[:-1]]
            xmin_r, xmax_r = min(xs), max(xs)
            ymin_r, ymax_r = min(ys), max(ys)
            v = np.array(rect_coords[1]) - np.array(rect_coords[0])
            ang = np.degrees(np.arctan2(v[1], v[0]))
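            # Axis-aligned if the edge angle is within 2° of a multiple of 90°
            ang_mod = ang % 90.0  # Python's % keeps this in [0, 90)
            axis_aligned = min(ang_mod, 90.0 - ang_mod) < 2.0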
            if axis_aligned:
                geom = {"x": round(float(xmin_r),2), "y": round(float(ymin_r),2),
                        "width": round(float(xmax_r-xmin_r),2), "height": round(float(ymax_r-ymin_r),2)}
                export_hint = "rect"
            else:
                geom = rect_coords
                export_hint = "path"

        else:
            # 2) arched shape (now after rectangle)
            if _is_arched_shape(list(poly.exterior.coords)):
                cls = "arched_shape"
                geom = list(poly.exterior.coords)
                export_hint = "path"

            else:
                # 3) circle
                is_round_enough = _is_circle_adaptive(poly, circularity, ellipse_ar, mec_residual, circ_hd, rad_std, scale)
                if is_round_enough:
                    cls = "circle"
                    ccx, ccy, rr = (cx if r_est > 0 else mec_cx,
                                    cy if r_est > 0 else mec_cy,
                                    r_est if r_est > 0 else mec_r)
                    geom = {"center": [round(float(ccx), 2), round(float(ccy), 2)],
                            "radius": round(float(rr), 2)}
                    export_hint = "circle"
                else:
                    # 4) ellipse (non-circular)
                    if fit_ell is not None:
                        ar_ok = (ellipse_ar >= Config.ELLIPSE_MIN_AR) and (ellipse_ar <= Config.ELLIPSE_MAX_AR)
                        res_ok = (ell_res_frac <= Config.ELLIPSE_RESIDUAL_FRAC_MAX)
                        if ar_ok and res_ok:
                            cls = "ellipse"
                            geom = {
                                "center": [round(float(xc),2), round(float(yc),2)],
                                "rx": round(float(rx),2),
                                "ry": round(float(ry),2),
                                "angle_deg": round(float(ang_deg),2)
                            }
                            export_hint = "ellipse"

                    if cls == "unclassified":
                        # 5) polygon (named if 3..12)
                        names = {3:"triangle",4:"quadrilateral",5:"pentagon",6:"hexagon",
                                 7:"heptagon",8:"octagon",9:"nonagon",10:"decagon",
                                 11:"hendecagon",12:"dodecagon"}
                        if n_sides >= 3:
                            cls = names.get(n_sides, f"polygon_{n_sides}_sides")
                            geom = simp_coords if len(simp_coords) >= 3 else list(poly.exterior.coords)
                            export_hint = "path"
                        else:
                            # 6) wavy shape: many inflections plus high perimeter-to-hull and area-to-hull ratios
                            changes, perim_ratio, area_ratio, npts = _wavy_score(poly)
                            if (changes >= Config.WAVY_MIN_INFLECTIONS and
                                perim_ratio >= Config.WAVY_PERIM_TO_HULL_MIN and
                                area_ratio >= Config.WAVY_AREA_HULL_RATIO_MIN):
                                cls = "wavy_shape"
                                geom = list(poly.exterior.coords)
                                export_hint = "path"
                            else:
                                cls = "unclassified"
                                geom = list(poly.exterior.coords)
                                export_hint = "path"

        results.append({
            "id": a.get("id", f"elem_{i}"),
            "classification": cls,
            "export_hint": export_hint,
            "n_sides": int(n_sides),
            "circularity": round(circularity,4),
            "ellipse_ar": round(float(ellipse_ar),4) if np.isfinite(ellipse_ar) else None,
            "ellipse_residual_frac": round(float(ell_res_frac),4) if fit_ell is not None else None,
            "mec_residual": round(float(mec_residual),4),
            "rect_area_ratio": round(rect_area_ratio,3),
            "rect_hd": round(rect_hd,2),
            "circ_area_ratio": round(float(circ_area_ratio),3),
            "circ_hd": round(float(circ_hd),2),
            "radius_std": round(float(rad_std),3),
            "geometry": geom
        })

    # Save JSON
    base, _ = os.path.splitext(svg_file)
    json_file = base + "_classified.json"
    with open(json_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    # Labeled JPG preview
    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    all_x, all_y = [], []
    for shape in data:
        geom = shape["geometry"]
        if isinstance(geom, dict):
            if "radius" in geom:  # circle
                cx, cy = geom["center"]; r = geom["radius"]
                all_x += [cx-r, cx+r]; all_y += [cy-r, cy+r]
            elif "rx" in geom and "ry" in geom:  # ellipse
                cx, cy = geom["center"]; rx, ry = geom["rx"], geom["ry"]
                all_x += [cx-rx, cx+rx]; all_y += [cy-ry, cy+ry]
            elif {"x","y","width","height"} <= set(geom.keys()):
                all_x += [geom["x"], geom["x"]+geom["width"]]
                all_y += [geom["y"], geom["y"]+geom["height"]]
        elif isinstance(geom, list) and geom:
            try:
                xs, ys = zip(*geom)
                all_x += xs; all_y += ys
            except Exception:
                pass
    if not all_x or not all_y:
        xmin, xmax, ymin, ymax = 0, 100, 0, 100
    else:
        xmin, xmax = min(all_x), max(all_x)
        ymin, ymax = min(all_y), max(all_y)
    if xmax == xmin: xmax += 1
    if ymax == ymin: ymax += 1

    import matplotlib.pyplot as plt
    from matplotlib.patches import Polygon as MplPolygon, Circle as MplCircle, Ellipse as MplEllipse, Rectangle as MplRect

    fig1, ax1 = plt.subplots(figsize=(10, 10))
    ax1.set_aspect("equal")
    ax1.invert_yaxis()
    for shape in data:
        cls = shape["classification"]; geom = shape["geometry"]
        try:
            if cls == "circle" and isinstance(geom, dict):
                cx, cy = geom["center"]; r = geom["radius"]
                circ = MplCircle((cx, cy), r, fill=True, facecolor="#000",
                                 edgecolor="#000", alpha=0.35, linewidth=1.2)
                ax1.add_patch(circ)
                ax1.text(cx, cy, cls, color='grey', ha="center", va="center", fontsize=8, fontweight='bold')
            elif cls == "ellipse" and isinstance(geom, dict):
                cx, cy = geom["center"]; rx, ry = geom["rx"], geom["ry"]; ang = geom["angle_deg"]
                ell = MplEllipse((cx, cy), 2*rx, 2*ry, angle=ang, fill=True, facecolor="#000",
                                 edgecolor="#000", alpha=0.35, linewidth=1.2)
                ax1.add_patch(ell)
                ax1.text(cx, cy, cls, color='grey', ha="center", va="center", fontsize=8, fontweight='bold')
            elif cls == "rectangle" and isinstance(geom, dict) and {"x","y","width","height"} <= set(geom.keys()):
                rect = MplRect((geom["x"], geom["y"]), geom["width"], geom["height"],
                               fill=True, facecolor="#000", edgecolor="#000", alpha=0.85, linewidth=1.0)
                ax1.add_patch(rect)
                cx = geom["x"] + geom["width"]/2; cy = geom["y"] + geom["height"]/2
                ax1.text(cx, cy, cls, color='grey', ha="center", va="center", fontsize=8, fontweight='bold')
            elif isinstance(geom, list) and len(geom) > 2:
                poly = MplPolygon(geom, fill=True, facecolor="#000",
                                  edgecolor="#000", alpha=0.85, linewidth=1.0)
                ax1.add_patch(poly)
                xs, ys = zip(*geom)
                cx = sum(xs)/len(xs); cy = sum(ys)/len(ys)
                ax1.text(cx, cy, cls, color='grey', ha="center", va="center", fontsize=8, fontweight='bold')
        except Exception:
            continue
    ax1.set_xlim(xmin, xmax)
    ax1.set_ylim(ymax, ymin)
    ax1.axis("off")
    jpg_path = os.path.join(out_dir, "classified_output.jpg")
    plt.savefig(jpg_path, dpi=150, bbox_inches="tight", facecolor='white', edgecolor='none')
    plt.close(fig1)

    # Clean, typed SVG with valid attributes
    import svgwrite
    dwg = svgwrite.Drawing(size=(xmax - xmin, ymax - ymin))
    dwg.viewbox(minx=xmin, miny=ymin, width=(xmax - xmin), height=(ymax - ymin))
    try:
        legend = {sh["id"]: sh["classification"] for sh in data}
        dwg.add(dwg.metadata(json.dumps({"classes": legend}, ensure_ascii=False)))
    except Exception:
        pass

    def _add_path(points, _id, cls):
        pts = [(float(x), float(y)) for x,y in points]
        d = "M " + " L ".join(f"{x},{y}" for x,y in pts) + " Z"
        el = dwg.path(d=d, fill='black', stroke='black', **{"stroke-width":1.5})
        if _id: el["id"] = _id
        el["class"] = cls
        el.set_desc(title=cls)
        dwg.add(el)

    for shape in data:
        cls = shape["classification"]
        geom = shape["geometry"]
        elem_id = shape.get("id", "")
        hint = shape.get("export_hint", "path")

        try:
            if hint == "circle" and isinstance(geom, dict):
                cx, cy = geom["center"]; r = geom["radius"]
                el = dwg.circle(center=(cx, cy), r=r, fill='black', stroke='black', **{"stroke-width":1.5})
                if elem_id: el["id"] = elem_id
                el["class"] = cls
                el.set_desc(title=cls)
                dwg.add(el)
            elif hint == "ellipse" and isinstance(geom, dict):
                cx, cy = geom["center"]; rx, ry = geom["rx"], geom["ry"]; ang = geom["angle_deg"]
                el = dwg.ellipse(center=(cx, cy), r=(rx, ry), fill='black', stroke='black', **{"stroke-width":1.5})
                # rotation about center
                el.rotate(ang, center=(cx, cy))
                if elem_id: el["id"] = elem_id
                el["class"] = cls
                el.set_desc(title=cls)
                dwg.add(el)
            elif hint == "rect" and isinstance(geom, dict) and {"x","y","width","height"} <= set(geom.keys()):
                el = dwg.rect(insert=(geom["x"], geom["y"]), size=(geom["width"], geom["height"]),
                              fill='black', stroke='black', **{"stroke-width":1.5})
                if elem_id: el["id"] = elem_id
                el["class"] = cls
                el.set_desc(title=cls)
                dwg.add(el)
            elif isinstance(geom, list) and len(geom) > 2:
                _add_path(geom, elem_id, cls)
        except Exception:
            # Fallback: path
            if isinstance(geom, list) and len(geom) > 2:
                _add_path(geom, elem_id, cls)

    svg_out_path = os.path.join(out_dir, "classified_output.svg")
    dwg.saveas(svg_out_path)

    return jpg_path, svg_out_path


# -----------------------------------------------------------------------------
# Inference function
# -----------------------------------------------------------------------------
def run_inference(
    image: np.ndarray,
    prompt_text: str,
    box_threshold: float,
    text_threshold: float,
    conf_threshold: float,
    max_inclusions: int,
    reinit_model: bool,
):
    try:
        import supervision as sv
    except Exception as e:
        raise gr.Error(f"Missing dependency 'supervision': {e}")
    try:
        from autodistill.detection import CaptionOntology
        from autodistill_grounded_sam_2 import GroundedSAM2
    except Exception as e:
        raise gr.Error(f"Autodistill / GroundedSAM2 not installed or import failed: {e}")

    if image is None:
        raise gr.Error("Please upload an image.")

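    # Auto-reinit model when the prompt (ontology) or the GroundingDINO thresholds change;
    # the thresholds are only passed to the constructor, so a threshold change alone must also trigger a rebuild
    global base_model, LAST_ONTOLOGY_SIG
    onto_dict = _parse_prompt_to_ontology_dict(prompt_text) or {"window": "window"}
    onto_sig = "|".join(f"{k}=>{v}" for k, v in sorted(onto_dict.items()))
    onto_sig += f"|box={float(box_threshold)}|text={float(text_threshold)}"
    need_init = reinit_model or base_model is None or LAST_ONTOLOGY_SIG != onto_sig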
    if need_init:
        try:
            base_model = GroundedSAM2(
                ontology=CaptionOntology(onto_dict),
                model="Grounding DINO",
                grounding_dino_box_threshold=float(box_threshold),
                grounding_dino_text_threshold=float(text_threshold),
            )
            LAST_ONTOLOGY_SIG = onto_sig
            logger.info(f"Initialized model with ontology: {onto_dict}")
        except Exception as e:
            raise gr.Error(f"Failed to initialize GroundedSAM2: {e}")

    # Persist inputs
    bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    tmp_root = _ensure_dir(os.path.join(tempfile.gettempdir(), "gsam2_gradio"))
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    out_dir = _ensure_dir(os.path.join(tmp_root, f"out_{stamp}"))
    img_path = os.path.join(out_dir, f"input_{stamp}.jpg")
    cv2.imwrite(img_path, bgr)

    # Predict
    results = base_model.predict(img_path)
    if hasattr(results, "with_nms"):
        try:
            results = results.with_nms()
        except Exception:
            pass

    # Confidence filter
    try:
        det_conf = np.array(getattr(results, "confidence", []))
        if det_conf.size:
            keep = det_conf > float(conf_threshold)
            results = results[keep]
    except Exception:
        pass

    # Labels for overlay
    labels = []
    names = getattr(results, "class_names", None) or getattr(results, "class_name", None)
    confidences = _to_list(getattr(results, "confidence", None))
    if names is not None and len(_to_list(names)):
        for name, conf in zip(_to_list(names), confidences):
            try:
                conf_f = float(conf)
            except Exception:
                conf_f = 0.0
            labels.append(f"{str(name)} {conf_f:.2f}")
    else:
        class_ids = _to_list(getattr(results, "class_id", None))
        onto_labels = list(onto_dict.values())
        for cid, conf in zip(class_ids, confidences):
            try:
                name = onto_labels[int(cid)]
            except Exception:
                name = str(cid)
            try:
                conf_f = float(conf)
            except Exception:
                conf_f = 0.0
            labels.append(f"{name} {conf_f:.2f}")

    # Annotated image
    image_bgr = bgr.copy()
    try:
        import supervision as sv
        box_annotator = sv.BoxAnnotator()
        label_annotator = sv.LabelAnnotator()
        annotated_bgr = box_annotator.annotate(image_bgr, detections=results)
        annotated_bgr = label_annotator.annotate(scene=annotated_bgr, detections=results, labels=labels)
    except Exception:
        annotated_bgr = image_bgr

    # Combined mask
    H, W = image.shape[:2]
    combined_mask = np.zeros((H, W), dtype=np.uint8)
    try:
        bdbs = get_bounding_boxes(results)
        inclusion_mask = create_inclusion_mask(bdbs, max_inclusions=max_inclusions)
        try:
            mask_bool = ~np.asarray(inclusion_mask, dtype=bool)  # keep detections not flagged by the inclusion filter
            filtered = results[mask_bool]
        except Exception:
            filtered = results

        masks = getattr(filtered, "mask", None)
        if masks is None:
            masks = getattr(filtered, "masks", None)
        if masks is not None and len(masks) > 0:
            m = np.zeros_like(masks[0], dtype=np.uint8)
            for mk in masks:
                m = np.maximum(m, mk.astype(np.uint8))
            combined_mask = (m > 0).astype(np.uint8) * 255
    except Exception:
        try:
            masks = getattr(results, "mask", None)
            if masks is None:
                masks = getattr(results, "masks", None)
            if masks is not None and len(masks) > 0:
                m = np.zeros_like(masks[0], dtype=np.uint8)
                for mk in masks:
                    m = np.maximum(m, mk.astype(np.uint8))
                combined_mask = (m > 0).astype(np.uint8) * 255
        except Exception:
            pass

    # Detections table
    rows = []
    try:
        xyxy = np.array(getattr(results, "xyxy", []))
        confs = _to_list(getattr(results, "confidence", None))
        names_attr = getattr(results, "class_names", None) or getattr(results, "class_name", None)
        names_list = _to_list(names_attr)
        onto_labels = list(onto_dict.values())
        default_label = onto_labels[0] if onto_labels else "object"

        label_names = []
        if names_list and len(names_list) == (xyxy.shape[0] if xyxy.ndim == 2 else len(names_list)):
            label_names = [str(n) if n is not None else default_label for n in names_list]
        else:
            cids = _to_list(getattr(results, "class_id", None))
            num_dets = xyxy.shape[0] if (xyxy.ndim == 2 and xyxy.shape[1] >= 4) else len(cids)
            for i in range(num_dets):
                name = default_label
                if i < len(cids):
                    cid = cids[i]
                    try:
                        cid_int = int(cid)
                        if 0 <= cid_int < len(onto_labels):
                            name = onto_labels[cid_int]
                        else:
                            name = str(cid)
                    except Exception:
                        name = str(cid) if cid is not None else default_label
                label_names.append(str(name))

        if xyxy.ndim == 2 and xyxy.shape[1] >= 4:
            for i, bb in enumerate(xyxy):
                x1, y1, x2, y2 = [int(v) for v in bb[:4]]
                cf = float(confs[i]) if (confs and i < len(confs)) else None
                label = label_names[i] if i < len(label_names) else default_label
                rows.append(
                    {"instance_label": None, "label": label, "confidence": cf, "x1": x1, "y1": y1, "x2": x2, "y2": y2}
                )
        df = pd.DataFrame(rows)
    except Exception:
        df = pd.DataFrame(columns=["instance_label", "label", "confidence", "x1", "y1", "x2", "y2"])

    # Add instance labels
    try:
        if not df.empty:
            df["label"] = df["label"].astype(str)
            missing_mask = df["label"].isin([None, "", "nan", "None"])
            onto_labels = list(onto_dict.values())
            default_label = onto_labels[0] if onto_labels else "object"
            df.loc[missing_mask, "label"] = default_label

            base = (
                df["label"]
                .str.lower()
                .str.replace(r"[^a-z0-9_-]+", "_", regex=True)
                .str.strip("_")
            )
            df["__class_norm__"] = base.where(base.ne(""), default_label)
            df["__idx__"] = df.groupby("__class_norm__").cumcount() + 1
            df["instance_label"] = df["__class_norm__"] + "_" + df["__idx__"].astype(str)
            df = df[["instance_label", "label", "confidence", "x1", "y1", "x2", "y2"]]
    except Exception:
        pass

    # Save artifacts
    ann_path = os.path.join(out_dir, "annotated.jpg")
    mask_path = os.path.join(out_dir, "combined_mask.png")
    csv_path = os.path.join(out_dir, "detections.csv")
    cv2.imwrite(ann_path, annotated_bgr)
    cv2.imwrite(mask_path, combined_mask)
    df.to_csv(csv_path, index=False)

    # SVG from mask
    try:
        svg_string = _mask_to_svg_string(combined_mask.astype(np.uint8))
    except Exception:
        h, w = combined_mask.shape[:2]
        svg_string = f'<svg xmlns="http://www.w3.org/2000/svg" width="{w}" height="{h}" viewBox="0 0 {w} {h}"></svg>'

    svg_path = os.path.join(out_dir, "mask_output.svg")
    with open(svg_path, "w", encoding="utf-8") as f:
        f.write(svg_string)

    # Classification assets
    classified_jpg, classified_svg = _classify_svg_assets(svg_path, out_dir)

    # ZIP
    zip_path = os.path.join(os.path.dirname(out_dir), f"gsam2_outputs_{os.path.basename(out_dir)}.zip")
    base, _ = os.path.splitext(svg_path)
    json_path = base + "_classified.json"
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for p in (ann_path, mask_path, csv_path, svg_path, json_path, classified_jpg, classified_svg):
            if os.path.isfile(p):
                zf.write(p, arcname=os.path.basename(p))

    annotated_rgb = cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB)
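    classified_bgr = cv2.imread(classified_jpg)  # None if the preview could not be written/read
    classified_rgb = cv2.cvtColor(classified_bgr, cv2.COLOR_BGR2RGB) if classified_bgr is not None else None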

    return annotated_rgb, combined_mask, df, svg_path, zip_path, classified_rgb, classified_svg


# -----------------------------------------------------------------------------
# Gradio UI
# -----------------------------------------------------------------------------
with gr.Blocks(title=Config.APP_TITLE) as demo:
    gr.Markdown(Config.APP_DESC)

    with gr.Row():
        with gr.Column():
            inp_image = gr.Image(label="Upload facade image", type="numpy")
            prompt = gr.Textbox(
                label="Objects to detect",
                placeholder='e.g., "window" or "window, door"',
                value="window",
            )
            gr.Examples(["window", "window, door", "window, door, balcony"], prompt)

            with gr.Accordion("Advanced thresholds", open=True):
                box_thr = gr.Slider(0.0, 1.0, value=Config.DEFAULT_BOX_THRESHOLD, step=0.01,
                                    label="GroundingDINO box threshold")
                text_thr = gr.Slider(0.0, 1.0, value=Config.DEFAULT_TEXT_THRESHOLD, step=0.01,
                                     label="GroundingDINO text threshold")
                conf_thr = gr.Slider(0.0, 1.0, value=Config.DEFAULT_CONF_THRESHOLD, step=0.01,
                                     label="Post-NMS detection confidence filter")
                max_incl = gr.Slider(0, 10, value=Config.DEFAULT_MAX_INCLUSIONS, step=1,
                                     label="Max inclusions (filter big enclosing boxes)")
            reinit_ck = gr.Checkbox(value=False, label="Re-initialize model from prompt (optional)")
            run_btn = gr.Button("Run segmentation", variant="primary")

        with gr.Column():
            out_ann = gr.Image(label="Annotated image", type="numpy")
            out_df = gr.Dataframe(label="Detections", interactive=False)
            out_zip = gr.File(label="Download outputs (zip)")

    with gr.Row(equal_height=True):
        out_mask = gr.Image(label="Combined mask (0/255)", type="numpy")
        out_classified = gr.Image(label="Cleaned & Classified Output (from SVG)", type="numpy")

    with gr.Row(equal_height=True):
        out_svg_btn = gr.DownloadButton(label="⬇️ Download mask as SVG")
        out_classified_svg_btn = gr.DownloadButton(label="⬇️ Download Cleaned & Classified SVG")

    run_btn.click(
        fn=run_inference,
        inputs=[inp_image, prompt, box_thr, text_thr, conf_thr, max_incl, reinit_ck],
        outputs=[out_ann, out_mask, out_df, out_svg_btn, out_zip, out_classified, out_classified_svg_btn],
        api_name="segment"
    )

# You may set share=True if you want a public link
demo.launch(quiet=True, debug=True, prevent_thread_lock=True)


* Running on public URL: https://911de03825aac8ec2a.gradio.live


trying to load grounding dino directly
final text_encoder_type: bert-base-uncased


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
The `device` argument is deprecated and will be removed in v5 of Transformers.
torch.utils.checkpoint: the use_reentrant parameter should be pas

Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://911de03825aac8ec2a.gradio.live
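
The rectangle branch of `_classify_svg_assets` exports a native `<rect>` only when the edge angle of the minimum rotated rectangle sits close to a multiple of 90°. Python's `%` folds a slightly negative angle such as -0.5° onto 89.5°, so the tolerance needs to be measured to the *nearer* multiple of 90°. The stdlib sketch below illustrates that test; `edge_angle_deg` and `is_axis_aligned` are illustrative helpers, not part of the notebook, and the 2° default simply mirrors the tolerance used in the cell above.

```python
import math


def edge_angle_deg(p0, p1):
    """Direction of the edge p0 -> p1 in degrees (atan2 convention)."""
    return math.degrees(math.atan2(p1[1] - p0[1], p1[0] - p0[0]))


def is_axis_aligned(angle_deg, tol_deg=2.0):
    """True if angle_deg is within tol_deg of any multiple of 90 degrees."""
    m = angle_deg % 90.0                # non-negative for a positive divisor
    return min(m, 90.0 - m) < tol_deg   # distance to the nearer multiple of 90


print(is_axis_aligned(edge_angle_deg((0, 0), (10, 0))))      # horizontal edge -> True
print(is_axis_aligned(edge_angle_deg((0, 0), (10, -0.05))))  # about -0.29 deg -> True
print(is_axis_aligned(edge_angle_deg((0, 0), (10, 10))))     # 45 deg -> False
```

With a plain `angle % 90 < 2` check, the second edge (about -0.29°) would be rejected and exported as a path even though it is visually axis-aligned.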

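The SVG export step dispatches on each shape's `export_hint`: native `<circle>`, `<ellipse>` (rotated about its centre) and `<rect>` elements where a primitive fits, and a closed `M x,y L x,y ... Z` path otherwise. The notebook builds these with `svgwrite`; the sketch below swaps in the standard library's `xml.etree.ElementTree` so it runs without extra dependencies. The record layout (`id`, `classification`, `export_hint`, `geometry`) follows the JSON written by `_classify_svg_assets`; the function names are illustrative.

```python
import xml.etree.ElementTree as ET

SVG_NS = "http://www.w3.org/2000/svg"


def shape_to_element(shape):
    """Map one classified shape record to an SVG element, or None if unusable."""
    hint = shape.get("export_hint", "path")
    geom = shape["geometry"]
    attrs = {"fill": "black", "stroke": "black", "stroke-width": "1.5",
             "class": shape["classification"]}
    if shape.get("id"):
        attrs["id"] = shape["id"]

    if hint == "circle" and isinstance(geom, dict):
        cx, cy = geom["center"]
        attrs.update(cx=str(cx), cy=str(cy), r=str(geom["radius"]))
        return ET.Element("circle", attrs)
    if hint == "ellipse" and isinstance(geom, dict):
        cx, cy = geom["center"]
        attrs.update(cx=str(cx), cy=str(cy), rx=str(geom["rx"]), ry=str(geom["ry"]),
                     transform=f"rotate({geom['angle_deg']} {cx} {cy})")  # rotate about centre
        return ET.Element("ellipse", attrs)
    if hint == "rect" and isinstance(geom, dict):
        attrs.update(x=str(geom["x"]), y=str(geom["y"]),
                     width=str(geom["width"]), height=str(geom["height"]))
        return ET.Element("rect", attrs)
    if isinstance(geom, list) and len(geom) > 2:
        # Closed outline: move to the first point, draw lines through the rest, close.
        attrs["d"] = "M " + " L ".join(f"{float(x)},{float(y)}" for x, y in geom) + " Z"
        return ET.Element("path", attrs)
    return None


def shapes_to_svg(shapes, minx, miny, width, height):
    """Serialise classified shapes into a standalone SVG string."""
    root = ET.Element("svg", {"xmlns": SVG_NS, "width": str(width), "height": str(height),
                              "viewBox": f"{minx} {miny} {width} {height}"})
    for shape in shapes:
        el = shape_to_element(shape)
        if el is not None:
            root.append(el)
    return ET.tostring(root, encoding="unicode")


shapes = [
    {"id": "elem_0", "classification": "rectangle", "export_hint": "rect",
     "geometry": {"x": 10, "y": 20, "width": 30, "height": 40}},
    {"id": "elem_1", "classification": "circle", "export_hint": "circle",
     "geometry": {"center": [50, 50], "radius": 5}},
    {"id": "elem_2", "classification": "triangle", "export_hint": "path",
     "geometry": [[0, 0], [10, 0], [5, 8], [0, 0]]},
]
print(shapes_to_svg(shapes, 0, 0, 100, 100))
```

Keeping the primitive types in the output, rather than flattening everything to paths, preserves the geometry parameters (centre, radius, width, height) for downstream CAD or BIM tools.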

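`run_inference` keeps one global `GroundedSAM2` instance and rebuilds it only when its configuration signature changes. Below is a small, framework-free sketch of that caching pattern; it includes the GroundingDINO thresholds in the signature because they are only passed to the constructor. `parse_prompt` is a hypothetical stand-in for `_parse_prompt_to_ontology_dict` (whose body is not shown here) that assumes a comma-separated prompt, and the lambda factory stands in for the real `GroundedSAM2(...)` call.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ModelConfig:
    ontology: tuple          # sorted (prompt, class name) pairs
    box_threshold: float
    text_threshold: float


def parse_prompt(prompt_text):
    """Hypothetical parser: 'window, door' -> {'window': 'window', 'door': 'door'}."""
    names = [p.strip() for p in prompt_text.split(",") if p.strip()]
    return {n: n for n in names} or {"window": "window"}  # same fallback as the cell


class ModelCache:
    """Build the model lazily and rebuild it only when its configuration changes."""

    def __init__(self, factory):
        self._factory = factory  # callable(ModelConfig) -> model object
        self._config = None
        self._model = None
        self.builds = 0

    def get(self, prompt_text, box_threshold, text_threshold, force=False):
        onto = parse_prompt(prompt_text)
        config = ModelConfig(tuple(sorted(onto.items())),
                             float(box_threshold), float(text_threshold))
        if force or self._model is None or config != self._config:
            self._model = self._factory(config)
            self._config = config
            self.builds += 1
        return self._model


# Stand-in factory; in the notebook this would be GroundedSAM2(...).
cache = ModelCache(factory=lambda cfg: {"config": cfg})
cache.get("window", 0.35, 0.25)
cache.get("window", 0.35, 0.25)        # identical config: reused
cache.get("window, door", 0.35, 0.25)  # new prompt: rebuilt
cache.get("door, window", 0.35, 0.25)  # same classes, other order: reused
cache.get("door, window", 0.40, 0.25)  # new box threshold: rebuilt
print(cache.builds)  # -> 3
```

Sorting the ontology items makes `"window, door"` and `"door, window"` map to the same configuration, so reordering the prompt does not trigger a reload of the weights.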

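The instance labels in the detections table (`window_1`, `window_2`, `door_1`, ...) come from normalising each label and numbering it within its class (`groupby(...).cumcount() + 1`). The same logic in plain Python, as an illustrative sketch (`instance_labels` is not a notebook function):

```python
import re
from collections import defaultdict


def instance_labels(labels, default_label="object"):
    """Return per-class running labels such as window_1, window_2, door_1."""
    counts = defaultdict(int)
    out = []
    for raw in labels:
        label = "" if raw is None else str(raw)
        if label in ("", "nan", "None"):          # missing label -> default class
            label = default_label
        # lower-case, collapse anything outside [a-z0-9_-] to "_", trim underscores
        norm = re.sub(r"[^a-z0-9_-]+", "_", label.lower()).strip("_") or default_label
        counts[norm] += 1                          # 1-based index within the class
        out.append(f"{norm}_{counts[norm]}")
    return out


print(instance_labels(["Window", "door", "Window", "Balcony Door", None],
                      default_label="window"))
# -> ['window_1', 'door_1', 'window_2', 'balcony_door_1', 'window_3']
```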