# DeepRoof Checkpoint Inference

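Runs full segmentation + geometry inference on a single satellite image. The image is center-cropped to a square and resized to the model input size before inference, so masks and polygons are expressed in that preprocessed pixel frame.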

**Output per roof instance:**
- Semantic class (flat / sloped)
- Instance mask + polygon
- Surface normal vector (nx, ny, nz)
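- Slope angle in degrees (angle between the surface normal and the vertical axis)
- Azimuth angle in degrees (direction of the normal's horizontal component)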

In [None]:
from __future__ import annotations

import json
import os
import sys
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F

In [None]:
# ======================== CONFIGURATION ========================
# Side length in pixels of the square image given to the model: the
# preprocessing cell resizes the center crop to
# MODEL_INPUT_SIZE x MODEL_INPUT_SIZE.
MODEL_INPUT_SIZE = 512  # Training resolution (native OmniCity size)
# Pixels whose highest per-query mask sigmoid is below this value are assigned
# class 0 (background) when the semantic map is built.
MASK_THRESHOLD = 0.5    # Min sigmoid confidence for a pixel to be non-background


def detect_project_root() -> Path:
    """Locate the project checkout containing both ``configs`` and ``deeproof``.

    Directories are probed in a fixed order: the current working directory,
    its parent, then two hard-coded install locations. The first directory
    that contains both entries wins.

    Returns:
        The first matching directory.

    Raises:
        FileNotFoundError: If none of the probed directories qualifies.
    """
    search_order = (
        Path.cwd(),
        Path.cwd().parent,
        Path('/workspace/roof'),
        Path('/Users/voskan/Desktop/DeepRoof-2026'),
    )
    required_entries = ('configs', 'deeproof')
    root = next(
        (
            directory
            for directory in search_order
            if all((directory / entry).exists() for entry in required_entries)
        ),
        None,
    )
    if root is None:
        raise FileNotFoundError('Could not auto-detect project root.')
    return root


def resolve_checkpoint(work_dir: Path, fallback_ckpt: Path) -> Path:
    """Return the checkpoint named by ``work_dir/last_checkpoint`` when usable.

    The pointer file is read as UTF-8 text and stripped. A relative path in it
    is interpreted relative to ``work_dir``.

    Args:
        work_dir: Directory that may contain a ``last_checkpoint`` pointer file.
        fallback_ckpt: Path returned when the pointer file is missing, empty,
            or names a path that does not exist.

    Returns:
        The existing path named by the pointer file, otherwise ``fallback_ckpt``.
    """
    pointer_file = work_dir / 'last_checkpoint'
    if not pointer_file.exists():
        return fallback_ckpt

    recorded = pointer_file.read_text(encoding='utf-8').strip()
    if not recorded:
        return fallback_ckpt

    candidate = Path(recorded)
    if not candidate.is_absolute():
        candidate = work_dir / candidate
    return candidate if candidate.exists() else fallback_ckpt


# --- Project layout ---
PROJECT_ROOT = detect_project_root()
CONFIG_PATH = PROJECT_ROOT.joinpath('configs', 'deeproof_production_swin_L.py')

# --- Checkpoint selection ---
# Order of preference: the path named by WORK_DIR/last_checkpoint, then the
# server checkpoint, then (only if the chosen path is missing) the local copy.
WORK_DIR = PROJECT_ROOT.joinpath('work_dirs', 'swin_l_scratch_v1')
SERVER_CHECKPOINT_PATH = WORK_DIR.joinpath('iter_16000.pth')
LOCAL_ANALYSIS_CHECKPOINT_PATH = Path('/Users/voskan/Downloads/iter_16000.pth')
CHECKPOINT_PATH = resolve_checkpoint(WORK_DIR, SERVER_CHECKPOINT_PATH)
if not CHECKPOINT_PATH.exists():
    if LOCAL_ANALYSIS_CHECKPOINT_PATH.exists():
        CHECKPOINT_PATH = LOCAL_ANALYSIS_CHECKPOINT_PATH

# --- Input image ---
# Falls back to <project root>/test.png when the default path is absent.
INPUT_IMAGE_PATH = Path('/workspace/test.png')
if not INPUT_IMAGE_PATH.exists():
    fallback = PROJECT_ROOT.joinpath('test.png')
    if fallback.exists():
        INPUT_IMAGE_PATH = fallback

# --- Output artifacts (the directory is created if needed) ---
OUTPUT_DIR = PROJECT_ROOT.joinpath('outputs', 'checkpoint_inference')
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
OVERLAY_PATH = OUTPUT_DIR.joinpath('test_segmentation_overlay.png')
SEM_MASK_PATH = OUTPUT_DIR.joinpath('test_semantic_mask.png')
SUMMARY_PATH = OUTPUT_DIR.joinpath('test_inference_summary.json')
POLYGONS_JSON_PATH = OUTPUT_DIR.joinpath('test_roof_polygons.json')
POLYGONS_GEOJSON_PATH = OUTPUT_DIR.joinpath('test_roof_polygons.geojson')
PREPROCESSED_PATH = OUTPUT_DIR.joinpath('test_preprocessed.png')

# --- Compute device: first CUDA device when available, otherwise CPU ---
if torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'

print(f'CHECKPOINT: {CHECKPOINT_PATH}')
print(f'INPUT: {INPUT_IMAGE_PATH}')
print(f'MODEL_INPUT_SIZE: {MODEL_INPUT_SIZE}, MASK_THRESHOLD: {MASK_THRESHOLD}')

In [None]:
# ======================== IMAGE PREPROCESSING ========================
# Fail fast if the config, checkpoint or input image is missing.
for p in (CONFIG_PATH, CHECKPOINT_PATH, INPUT_IMAGE_PATH):
    if not p.exists():
        raise FileNotFoundError(f'Path not found: {p}')

# cv2.imread returns None (instead of raising) when the file cannot be decoded.
img_orig_bgr = cv2.imread(str(INPUT_IMAGE_PATH), cv2.IMREAD_COLOR)
if img_orig_bgr is None:
    raise RuntimeError(f'Could not load image: {INPUT_IMAGE_PATH}')

orig_h, orig_w = img_orig_bgr.shape[:2]

# Center-crop to square, then resize to training resolution
crop_size = min(orig_h, orig_w)
y_start = (orig_h - crop_size) // 2
x_start = (orig_w - crop_size) // 2
img_cropped = img_orig_bgr[y_start:y_start+crop_size, x_start:x_start+crop_size]

# INTER_AREA is the recommended filter for shrinking; when the crop is smaller
# than the model input the image is being enlarged, where INTER_AREA degrades
# to a blocky result, so a bilinear filter is used instead.
resize_interp = cv2.INTER_AREA if crop_size >= MODEL_INPUT_SIZE else cv2.INTER_LINEAR
img_bgr = cv2.resize(img_cropped, (MODEL_INPUT_SIZE, MODEL_INPUT_SIZE), interpolation=resize_interp)

# The forward-pass cell loads the image back from PREPROCESSED_PATH, so a
# silent write failure (imwrite returns False) must not go unnoticed.
if not cv2.imwrite(str(PREPROCESSED_PATH), img_bgr):
    raise RuntimeError(f'Could not write preprocessed image: {PREPROCESSED_PATH}')
print(f'Preprocessed: {orig_w}x{orig_h} -> crop {crop_size}x{crop_size} -> {MODEL_INPUT_SIZE}x{MODEL_INPUT_SIZE}')

# RGB copy used for visualization; H and W are the preprocessed image size.
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
H, W = img_rgb.shape[:2]

In [None]:
# ======================== LOAD MODEL ========================
# setdefault only sets the variable when it is not already present in the
# environment, so an explicit user setting wins.
# NOTE(review): presumably this lets torch.load unpickle full (non
# weights-only) checkpoints — confirm against the installed torch version.
os.environ.setdefault('TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD', '1')

# Put the project root first on sys.path so the `deeproof` package imported
# below resolves to this checkout.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from mmseg.utils import register_all_modules
from mmseg.apis import init_model
from mmengine.config import ConfigDict
from mmengine.dataset import Compose

register_all_modules(init_default_scope=False)

# These modules are imported only for their import-time side effects; none of
# their names are referenced in this cell.
# NOTE(review): presumably each one registers a custom backbone/head/loss with
# the mm* registries so the config can be built — confirm.
import deeproof.models.backbones.swin_v2_compat
import deeproof.models.deeproof_model
import deeproof.models.heads.mask2former_head
import deeproof.models.heads.geometry_head
import deeproof.models.losses

# Build the model from CONFIG_PATH and load CHECKPOINT_PATH onto DEVICE.
model = init_model(str(CONFIG_PATH), str(CHECKPOINT_PATH), device=DEVICE)
# Replace whatever test_cfg the config defined with mode='whole'.
model.test_cfg = ConfigDict(dict(mode='whole'))
model.eval()
print('Model loaded successfully.')
print(f'Has geometry_head: {hasattr(model, "geometry_head")}')

In [None]:
# ======================== FORWARD PASS + GEOMETRY ========================
# Run backbone + decode_head + geometry_head manually
# (we bypass inference_model to use panoptic-style post-processing)

# Use the config's test pipeline when it defines one, otherwise a minimal
# load -> resize -> pack pipeline.
test_pipeline = model.cfg.get('test_pipeline', [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(MODEL_INPUT_SIZE, MODEL_INPUT_SIZE), keep_ratio=False),
    dict(type='PackSegInputs'),
])
# Only an image path is supplied below, so an annotation-loading step would
# fail for lack of a segmentation map path. mmseg's own inference helper drops
# 'LoadAnnotations' from the test pipeline for the same reason.
test_pipeline = [
    step for step in test_pipeline
    if not (isinstance(step, dict) and step.get('type') == 'LoadAnnotations')
]
pipeline = Compose(test_pipeline)
data = pipeline(dict(img_path=str(PREPROCESSED_PATH)))

# Wrap the single sample as a batch of one and let the model's data
# preprocessor prepare it (second argument False = not training).
data_batch = dict(
    inputs=[data['inputs']],
    data_samples=[data['data_samples']],
)
data_batch = model.data_preprocessor(data_batch, False)
inputs = data_batch['inputs']

with torch.no_grad():
    # 1. Backbone features
    x = model.extract_feat(inputs)

    # 2. Mask2Former decode head → cls_scores + mask_preds + query embeddings
    all_cls_scores, all_mask_preds = model.decode_head(x, data_batch['data_samples'])

    # 3. GeometryHead → surface normals per query
    # The decode head is expected to have stashed its query embeddings on
    # itself during the forward call above; None if it did not.
    query_embeddings = getattr(model.decode_head, 'last_query_embeddings', None)
    geo_preds = None  # [B, Q, 3] predicted normals
    if query_embeddings is not None and hasattr(model, 'geometry_head'):
        # Normalize query embeddings (same as model.predict does)
        qe = query_embeddings
        if isinstance(qe, (list, tuple)):
            qe = qe[-1]       # keep the last entry of a per-layer sequence
        if qe.ndim == 4:
            qe = qe[-1]       # 4-D stack: keep the last slice along dim 0
        if qe.ndim == 2:
            qe = qe.unsqueeze(0)  # add the batch dimension
        geo_preds = model.geometry_head(qe)  # [B, Q, 3] unit normals
        print(f'GeometryHead output: {geo_preds.shape}')
    else:
        print('WARNING: query_embeddings not available — geometry predictions skipped')

# Last decoder layer (most refined)
cls_scores = all_cls_scores[-1][0]   # [Q, C+1]
mask_preds = all_mask_preds[-1][0]   # [Q, h, w]

num_queries = cls_scores.shape[0]
num_classes = cls_scores.shape[1] - 1  # last logit is the "no object" class
print(f'cls_scores: {cls_scores.shape}, mask_preds: {mask_preds.shape}')

In [None]:
# ======================== PANOPTIC-STYLE SEMANTIC MAP ========================
def _resize_nearest_to_image(index_map):
    """Nearest-neighbour resize of a 2-D integer map to (H, W), as a long tensor."""
    batched = index_map.float()[None, None]
    return F.interpolate(batched, size=(H, W), mode='nearest')[0, 0].long()


# Class probabilities per query; the last column is the "no object" class.
cls_probs = cls_scores.softmax(dim=-1)           # [Q, C+1]
obj_probs = cls_probs[:, :-1]                    # [Q, C]
no_obj_probs = cls_probs[:, -1]                  # [Q]
mask_sigmoid = torch.sigmoid(mask_preds)         # [Q, h, w]

# Winning class and its probability for every query.
query_obj_confidence = obj_probs.max(dim=-1).values  # [Q]
query_class = obj_probs.argmax(dim=-1)               # [Q]

# Each pixel goes to the query maximising P(class) * sigmoid(mask).
pixel_scores = query_obj_confidence[:, None, None] * mask_sigmoid
best_score_per_pixel, best_query_per_pixel = pixel_scores.max(dim=0)

# Class of the winning query, then force background (0) wherever no query's
# mask reaches MASK_THRESHOLD.
best_mask_per_pixel = mask_sigmoid.max(dim=0).values
sem_map_lowres = query_class[best_query_per_pixel]
sem_map_lowres[best_mask_per_pixel < MASK_THRESHOLD] = 0

# Bring both maps up to image resolution: the class map (uint8) and the
# per-pixel winning-query index used later for the geometry lookup.
sem_map = _resize_nearest_to_image(sem_map_lowres).cpu().numpy().astype(np.uint8)
query_map = _resize_nearest_to_image(best_query_per_pixel).cpu().numpy()

# Area statistics as fractions of the whole image.
class_values, class_counts = np.unique(sem_map, return_counts=True)
unique_classes = class_values.tolist()
class_areas = {
    int(c): float(count) / float(sem_map.size)
    for c, count in zip(unique_classes, class_counts)
}
roof_ratio = float(np.count_nonzero(sem_map > 0)) / float(sem_map.size)

print(f'Semantic classes: {unique_classes}')
print(f'Class areas: {class_areas}')
print(f'Roof pixel ratio: {roof_ratio:.4f}')

In [None]:
# ======================== BUILD INSTANCES WITH GEOMETRY ========================
# Each connected region of a roof class in sem_map becomes one instance. Its
# score is the mean best-mask confidence inside the region, and its geometry
# comes from the query that owns most of the region's pixels.
instances = []  # List of dicts with mask, class, score, normal, slope, azimuth
MIN_AREA = 64   # Regions with fewer pixels (at sem_map resolution) are dropped
CLASS_NAMES = {0: 'background', 1: 'flat_roof', 2: 'sloped_roof'}

# Get per-query normals from GeometryHead
query_normals = geo_preds[0].cpu().numpy() if geo_preds is not None else None  # [Q, 3]

# Loop-invariant: best mask confidence over all queries, at mask resolution.
# Computed once on the host instead of once per connected component.
best_mask_conf = mask_sigmoid.max(dim=0).values.detach().float().cpu().numpy()  # [h, w]
low_h, low_w = best_mask_conf.shape
full_h, full_w = sem_map.shape

for cls_id in np.unique(sem_map):
    cls_id = int(cls_id)
    if cls_id <= 0 or cls_id == 255:  # skip background / ignore label
        continue
    binary = (sem_map == cls_id).astype(np.uint8)
    num_comp, comp = cv2.connectedComponents(binary, connectivity=8)
    for comp_id in range(1, int(num_comp)):  # label 0 is the non-class area
        m = (comp == comp_id)
        area = int(m.sum())
        if area < MIN_AREA:
            continue

        # Score: mean mask sigmoid in this region (sampled at mask resolution)
        m_lowres = cv2.resize(m.astype(np.uint8), (low_w, low_h),
                              interpolation=cv2.INTER_NEAREST).astype(bool)
        if m_lowres.any():
            mean_conf = float(best_mask_conf[m_lowres].mean())
        else:
            # A thin region can disappear in the nearest-neighbour downsample;
            # look up the confidence under each full-resolution pixel instead
            # of taking the mean of an empty selection (which would be NaN).
            ys, xs = np.nonzero(m)
            mean_conf = float(best_mask_conf[ys * low_h // full_h, xs * low_w // full_w].mean())

        # Geometry: find the dominant query in this region and get its normal
        region_queries = query_map[m]  # query indices for all pixels in this instance
        dominant_query = int(np.bincount(region_queries).argmax())  # most common query

        normal = None
        slope_deg = None
        azimuth_deg = None
        if query_normals is not None:
            normal = query_normals[dominant_query]  # [3] = (nx, ny, nz)
            # Slope = angle from vertical (Z axis); abs() ignores the sign of nz
            nz = float(np.clip(normal[2], -1.0, 1.0))
            slope_deg = float(np.degrees(np.arccos(abs(nz))))
            # Azimuth = angle of the horizontal component (nx, ny), measured
            # from the +x axis toward +y and wrapped with % 360.
            # NOTE(review): this is a compass bearing only if the normal's
            # axes are defined accordingly — confirm GeometryHead's convention.
            nx, ny = float(normal[0]), float(normal[1])
            azimuth_deg = float(np.degrees(np.arctan2(ny, nx))) % 360

        instances.append({
            'mask': m,
            'class_id': cls_id,
            'class_name': CLASS_NAMES.get(cls_id, f'class_{cls_id}'),
            'score': mean_conf,
            'area_px': area,
            'dominant_query': dominant_query,
            'normal': normal.tolist() if normal is not None else None,
            'slope_deg': slope_deg,
            'azimuth_deg': azimuth_deg,
        })

print(f'\n=== DETECTED ROOF INSTANCES: {len(instances)} ===')
print(f'{"#":>3} {"Class":>12} {"Slope°":>7} {"Azimuth°":>9} {"Normal (nx,ny,nz)":>25} {"Score":>6} {"Area":>8}')
print('-' * 80)
for i, inst in enumerate(instances):
    n_str = f'({inst["normal"][0]:+.3f}, {inst["normal"][1]:+.3f}, {inst["normal"][2]:+.3f})' if inst['normal'] is not None else 'N/A'
    s_str = f'{inst["slope_deg"]:.1f}' if inst['slope_deg'] is not None else 'N/A'
    a_str = f'{inst["azimuth_deg"]:.0f}' if inst['azimuth_deg'] is not None else 'N/A'
    print(f'{i:3d} {inst["class_name"]:>12} {s_str:>7} {a_str:>9} {n_str:>25} {inst["score"]:6.3f} {inst["area_px"]:8d}')

In [None]:
# ======================== VISUALIZATION ========================
# RGB colour per class id; ids above the last entry are clipped to it.
palette = np.array([
    [0, 0, 0],       # 0: background
    [0, 255, 0],     # 1: flat_roof
    [255, 0, 0],     # 2: sloped_roof
], dtype=np.uint8)

sem_vis = palette[np.clip(sem_map, 0, len(palette) - 1)]
overlay = cv2.addWeighted(img_rgb, 0.60, sem_vis, 0.40, 0.0)

# Extract polygons (outer contours only) and carry the instance geometry along
roof_polygons = []
MIN_POLY_AREA = 64

for inst in instances:
    contours, _ = cv2.findContours(inst['mask'].astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    for contour in contours:
        area = float(cv2.contourArea(contour))
        if area < MIN_POLY_AREA:
            continue
        # Simplify with a tolerance of 0.3% of the contour perimeter.
        epsilon = 0.003 * cv2.arcLength(contour, True)
        approx = cv2.approxPolyDP(contour, epsilon, True)
        poly = approx.reshape(-1, 2)
        if poly.shape[0] < 3:  # fewer than three vertices is not a polygon
            continue

        roof_polygons.append({
            'polygon_id': len(roof_polygons),
            'class_id': inst['class_id'],
            'class_name': inst['class_name'],
            'score': inst['score'],
            'area_px': area,  # contour area, not the instance pixel count
            'slope_deg': inst['slope_deg'],
            'azimuth_deg': inst['azimuth_deg'],
            'normal': inst['normal'],
            'points_xy': poly.astype(int).tolist(),
        })

# Draw polygons (white outline) with slope labels
for poly_obj in roof_polygons:
    pts = np.array(poly_obj['points_xy'], dtype=np.int32).reshape(-1, 1, 2)
    cv2.polylines(overlay, [pts], True, (255, 255, 255), 2)

    # Label anchor: polygon centroid, or the first vertex for a degenerate
    # polygon whose area moment is zero.
    moments = cv2.moments(pts)
    if moments['m00'] > 0:
        cx, cy = int(moments['m10'] / moments['m00']), int(moments['m01'] / moments['m00'])
    else:
        cx, cy = int(pts[0, 0, 0]), int(pts[0, 0, 1])

    # Label with slope angle, or the class name when no geometry is available
    if poly_obj['slope_deg'] is not None:
        txt = f"{poly_obj['slope_deg']:.0f}deg"
    else:
        txt = poly_obj['class_name']
    cv2.putText(overlay, txt, (cx - 15, cy), cv2.FONT_HERSHEY_SIMPLEX, 0.45, (255, 255, 0), 1, cv2.LINE_AA)

# Save. cv2.imwrite reports failure through its return value, so keep it to
# avoid announcing files that were never written.
overlay_saved = cv2.imwrite(str(OVERLAY_PATH), cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
sem_mask_saved = cv2.imwrite(str(SEM_MASK_PATH), sem_map)

# Plot: input | semantic map | overlay
plt.figure(figsize=(18, 6))
plt.subplot(1, 3, 1)
plt.imshow(img_rgb)
plt.title(f'Input ({W}x{H})')
plt.axis('off')
plt.subplot(1, 3, 2)
plt.imshow(sem_vis)
plt.title(f'Semantic (bg={class_areas.get(0,0):.0%}, roof={roof_ratio:.0%})')
plt.axis('off')
plt.subplot(1, 3, 3)
plt.imshow(overlay)
plt.title(f'{len(roof_polygons)} roof planes with slope angles')
plt.axis('off')
plt.tight_layout()
plt.show()

if overlay_saved:
    print(f'\nSaved overlay: {OVERLAY_PATH}')
else:
    print(f'\nWARNING: could not write overlay: {OVERLAY_PATH}')
if not sem_mask_saved:
    print(f'WARNING: could not write semantic mask: {SEM_MASK_PATH}')
print(f'Total roof polygons: {len(roof_polygons)}')

In [None]:
# ======================== TOP-6 QUERY MASKS ========================
# Show the sigmoid mask of the six queries with the highest class confidence.
sorted_idx = query_obj_confidence.argsort(descending=True)
top_queries = sorted_idx[:6]
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
flat_axes = axes.ravel()  # row-major: same order as axes[idx // 3][idx % 3]
for ax, qi in zip(flat_axes, top_queries.tolist()):
    mask_vis = mask_sigmoid[qi].cpu().numpy()
    ax.imshow(mask_vis, cmap='hot', vmin=0, vmax=1)
    # Coverage: fraction of mask pixels above the threshold that the semantic
    # map uses for its background cut.
    cov = float((mask_sigmoid[qi] > MASK_THRESHOLD).float().mean())
    n_str = ''
    if query_normals is not None:
        n = query_normals[qi]
        # Same slope definition as for instances: angle from the Z axis.
        slope = float(np.degrees(np.arccos(np.clip(abs(n[2]), -1, 1))))
        n_str = f', slope={slope:.0f}°'
    ax.set_title(f'Q{qi}: cls={int(query_class[qi])}, cov={cov:.3f}{n_str}')
    ax.axis('off')
# With fewer than six queries the remaining panels would show empty frames.
for ax in flat_axes[len(top_queries):]:
    ax.axis('off')
plt.suptitle('Top 6 query masks (sigmoid) + slope angles', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# ======================== SAVE JSON OUTPUT ========================
# Polygon vertices are pixel coordinates in the preprocessed image (center
# crop of the input, resized to MODEL_INPUT_SIZE), not in the original file.
# Record the crop and scale so a consumer can map them back:
#   x_orig = crop_offset_xy[0] + x * scale_to_original
#   y_orig = crop_offset_xy[1] + y * scale_to_original
coordinate_frame = {
    'space': 'preprocessed_pixels',
    'preprocessed_image': str(PREPROCESSED_PATH),
    'original_size_wh': [int(orig_w), int(orig_h)],
    'crop_offset_xy': [int(x_start), int(y_start)],
    'crop_size': int(crop_size),
    'scale_to_original': float(crop_size) / float(MODEL_INPUT_SIZE),
}

# Full JSON with polygons + geometry
json_polygons = []
for p in roof_polygons:
    # Defensive: never serialize a raw mask array should one be attached.
    json_polygons.append({k: v for k, v in p.items() if k != 'mask'})

with POLYGONS_JSON_PATH.open('w', encoding='utf-8') as f:
    json.dump({
        'image': str(INPUT_IMAGE_PATH),
        'checkpoint': str(CHECKPOINT_PATH),
        'model_input_size': MODEL_INPUT_SIZE,
        'mask_threshold': MASK_THRESHOLD,
        'coordinate_frame': coordinate_frame,
        'polygons': json_polygons,
    }, f, indent=2)

# GeoJSON. Coordinates are the same preprocessed-image pixels (no
# geo-referencing is applied); the frame description is attached as an extra
# top-level member.
geojson = {'type': 'FeatureCollection', 'coordinate_frame': coordinate_frame, 'features': []}
for p in roof_polygons:
    coords = [[float(x), float(y)] for x, y in p['points_xy']]
    # Close the ring: first and last position must be identical.
    if coords and coords[0] != coords[-1]:
        coords.append(coords[0])
    geojson['features'].append({
        'type': 'Feature',
        'properties': {
            'polygon_id': p['polygon_id'],
            'class_name': p['class_name'],
            'slope_deg': p['slope_deg'],
            'azimuth_deg': p['azimuth_deg'],
            'normal': p['normal'],
            'score': p['score'],
        },
        'geometry': {'type': 'Polygon', 'coordinates': [coords]},
    })
with POLYGONS_GEOJSON_PATH.open('w', encoding='utf-8') as f:
    json.dump(geojson, f, indent=2)

print(f'Saved: {POLYGONS_JSON_PATH}')
print(f'Saved: {POLYGONS_GEOJSON_PATH}')
print('\nExample polygon output:')
if json_polygons:
    print(json.dumps(json_polygons[0], indent=2))