# DeepRoof Checkpoint Inference (PNG/JPG)

Runs segmentation on any satellite image using a trained DeepRoof checkpoint.

**Preprocessing:** Center-crop to square → resize to 512×512 (training resolution).

**Post-processing:** Uses panoptic-style per-pixel query assignment with mask thresholding
instead of MMSeg's default einsum aggregation, which collapses to a single class when all queries are biased toward the same class.

In [None]:
from __future__ import annotations

import json
import os
import sys
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F

In [None]:
# ======================== CONFIGURATION ========================
MODEL_INPUT_SIZE = 512  # Training resolution (native OmniCity size)
MASK_THRESHOLD = 0.5    # Min sigmoid confidence for a pixel to be non-background


def detect_project_root(*, candidates: list[Path] | None = None) -> Path:
    """Locate the DeepRoof repository root.

    A directory qualifies when it contains both a ``configs`` and a
    ``deeproof`` entry. Candidates are tried in order and the first match wins.

    Args:
        candidates: Directories to try. When omitted, the current working
            directory, its parent, and two machine-specific checkout paths
            are tried.

    Returns:
        The first qualifying candidate directory.

    Raises:
        FileNotFoundError: If no candidate qualifies. The message lists every
            searched path so the user can see what to fix.
    """
    if candidates is None:
        search = [
            Path.cwd(),
            Path.cwd().parent,
            Path('/workspace/roof'),
            Path('/Users/voskan/Desktop/DeepRoof-2026'),
        ]
    else:
        # Materialize once so a one-shot iterable can also be reported on failure.
        search = list(candidates)
    for c in search:
        if (c / 'configs').exists() and (c / 'deeproof').exists():
            return c
    searched = ', '.join(str(c) for c in search)
    raise FileNotFoundError(f'Could not auto-detect project root. Searched: {searched}')


def resolve_checkpoint(work_dir: Path, fallback_ckpt: Path) -> Path:
    """Pick the checkpoint recorded in ``work_dir/last_checkpoint`` when usable.

    The pointer file holds one path, either absolute or relative to
    ``work_dir``; surrounding whitespace is ignored. ``fallback_ckpt`` is
    returned when the pointer file is missing or empty, or when the path it
    names does not exist.
    """
    pointer = work_dir / 'last_checkpoint'
    if not pointer.exists():
        return fallback_ckpt

    recorded = pointer.read_text(encoding='utf-8').strip()
    if not recorded:
        return fallback_ckpt

    candidate = Path(recorded)
    if not candidate.is_absolute():
        candidate = work_dir / candidate
    return candidate if candidate.exists() else fallback_ckpt


PROJECT_ROOT = detect_project_root()
CONFIG_PATH = PROJECT_ROOT / 'configs' / 'deeproof_production_swin_L.py'

# Checkpoint priority:
#   1. the target of WORK_DIR/last_checkpoint, if it exists;
#   2. otherwise the server checkpoint path;
#   3. if that path is missing too, a local analysis copy (when present).
WORK_DIR = PROJECT_ROOT / 'work_dirs' / 'swin_l_scratch_v1'
SERVER_CHECKPOINT_PATH = WORK_DIR / 'iter_16000.pth'
LOCAL_ANALYSIS_CHECKPOINT_PATH = Path('/Users/voskan/Downloads/iter_16000.pth')
CHECKPOINT_PATH = resolve_checkpoint(WORK_DIR, SERVER_CHECKPOINT_PATH)
use_local_copy = (not CHECKPOINT_PATH.exists()) and LOCAL_ANALYSIS_CHECKPOINT_PATH.exists()
if use_local_copy:
    CHECKPOINT_PATH = LOCAL_ANALYSIS_CHECKPOINT_PATH

# ---- TEST IMAGE ----
# Option A: external image of any size (it is center-cropped and resized later).
INPUT_IMAGE_PATH = Path('/workspace/test.png')
# Option B: OmniCity validation image. Uncomment to test on the training
# distribution (a relative path resolves against the current working directory).
# INPUT_IMAGE_PATH = Path('data/OmniCity/images/FIRST_VAL_IMAGE.jpg')
repo_test_image = PROJECT_ROOT / 'test.png'
if not INPUT_IMAGE_PATH.exists() and repo_test_image.exists():
    INPUT_IMAGE_PATH = repo_test_image

# Every artifact produced by this notebook lands in OUTPUT_DIR (created if missing).
OUTPUT_DIR = PROJECT_ROOT / 'outputs' / 'checkpoint_inference'
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
PREPROCESSED_PATH = OUTPUT_DIR / 'test_preprocessed.png'
OVERLAY_PATH = OUTPUT_DIR / 'test_segmentation_overlay.png'
SEM_MASK_PATH = OUTPUT_DIR / 'test_semantic_mask.png'
SUMMARY_PATH = OUTPUT_DIR / 'test_inference_summary.json'
POLYGONS_JSON_PATH = OUTPUT_DIR / 'test_roof_polygons.json'
POLYGONS_GEOJSON_PATH = OUTPUT_DIR / 'test_roof_polygons.geojson'

if torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'

print(f'PROJECT_ROOT: {PROJECT_ROOT}')
print(f'CONFIG: {CONFIG_PATH}')
print(f'CHECKPOINT: {CHECKPOINT_PATH}')
print(f'INPUT: {INPUT_IMAGE_PATH}')
print(f'MODEL_INPUT_SIZE: {MODEL_INPUT_SIZE}x{MODEL_INPUT_SIZE}')
print(f'MASK_THRESHOLD: {MASK_THRESHOLD}')
print(f'DEVICE: {DEVICE}')

In [None]:
# ======================== IMAGE PREPROCESSING ========================
# Fail fast with a clear message before any model work starts.
for p in (CONFIG_PATH, CHECKPOINT_PATH, INPUT_IMAGE_PATH):
    if not p.exists():
        raise FileNotFoundError(f'Path not found: {p}')

img_orig_bgr = cv2.imread(str(INPUT_IMAGE_PATH), cv2.IMREAD_COLOR)
if img_orig_bgr is None:
    raise RuntimeError(f'Could not load image: {INPUT_IMAGE_PATH}')

orig_h, orig_w = img_orig_bgr.shape[:2]
print(f'Original image size: {orig_w}x{orig_h}')

# Center-crop to the largest square that fits inside the image.
crop_size = min(orig_h, orig_w)
y_start = (orig_h - crop_size) // 2
x_start = (orig_w - crop_size) // 2
img_cropped = img_orig_bgr[y_start:y_start + crop_size, x_start:x_start + crop_size]

# Resize to the model input. INTER_AREA is the recommended filter for
# shrinking, but when enlarging OpenCV documents it as behaving like
# INTER_NEAREST, so crops smaller than the model input use INTER_CUBIC.
resize_interp = cv2.INTER_AREA if crop_size >= MODEL_INPUT_SIZE else cv2.INTER_CUBIC
img_bgr = cv2.resize(img_cropped, (MODEL_INPUT_SIZE, MODEL_INPUT_SIZE), interpolation=resize_interp)

# The MMSeg pipeline reloads the image from this file. A failed write must not
# go unnoticed, or a stale file from an earlier run would be segmented.
if not cv2.imwrite(str(PREPROCESSED_PATH), img_bgr):
    raise RuntimeError(f'Could not write preprocessed image: {PREPROCESSED_PATH}')
print(f'Preprocessed: {orig_w}x{orig_h} -> center crop {crop_size}x{crop_size} -> resize {MODEL_INPUT_SIZE}x{MODEL_INPUT_SIZE}')

# RGB copy for display; H x W is the model-input frame used by later cells.
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
H, W = img_rgb.shape[:2]

In [None]:
# ======================== LOAD MODEL ========================
# Set before any checkpoint is loaded; setdefault keeps a value the user already exported.
# NOTE(review): this presumably makes torch.load fall back to weights_only=False
# (recent PyTorch defaults to weights_only=True, which rejects full mmengine
# checkpoints). That permits arbitrary pickle deserialization, so only point
# CHECKPOINT_PATH at trusted files.
os.environ.setdefault('TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD', '1')

# Make the local `deeproof` package importable from the project root.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from mmseg.utils import register_all_modules
from mmseg.apis import init_model
from mmengine.config import ConfigDict
from mmengine.dataset import Compose  # used by the raw forward-pass cell

register_all_modules(init_default_scope=False)

# Imported only for their side effects. Presumably each module registers custom
# components (backbone, model, heads, losses) so the config can refer to them by
# type name. Keep these before init_model.
import deeproof.models.backbones.swin_v2_compat
import deeproof.models.deeproof_model
import deeproof.models.heads.mask2former_head
import deeproof.models.heads.geometry_head
import deeproof.models.losses

# Build the model from CONFIG_PATH and load CHECKPOINT_PATH onto DEVICE.
model = init_model(str(CONFIG_PATH), str(CHECKPOINT_PATH), device=DEVICE)
# Whole-image inference (no sliding window).
# NOTE(review): the raw forward-pass cell calls extract_feat/decode_head
# directly, so it does not appear to consult test_cfg.
model.test_cfg = ConfigDict(dict(mode='whole'))
model.eval()
print('Model loaded successfully.')

In [None]:
# ======================== RAW FORWARD PASS ========================
# Instead of using inference_model (which uses einsum aggregation),
# we run the model manually and build the semantic map ourselves
# using panoptic-style per-pixel query assignment.

# Test pipeline from the config, or a minimal load -> resize -> pack pipeline
# when the config defines none.
raw_test_pipeline = model.cfg.get('test_pipeline', [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(MODEL_INPUT_SIZE, MODEL_INPUT_SIZE), keep_ratio=False),
    dict(type='PackSegInputs'),
])
# A config test_pipeline commonly includes LoadAnnotations (used for evaluation).
# That step would fail here because an arbitrary input image has no
# ground-truth mask; MMSeg's own inference_model drops it for the same reason.
# A filtered copy is built so model.cfg itself is left untouched.
# NOTE(review): if the config's Resize targets a scale other than
# MODEL_INPUT_SIZE, the already-preprocessed image gets resized again. Confirm.
test_pipeline = [
    t for t in raw_test_pipeline
    if not (isinstance(t, dict) and t.get('type') == 'LoadAnnotations')
]
pipeline = Compose(test_pipeline)
data = pipeline(dict(img_path=str(PREPROCESSED_PATH)))

# Wrap as a batch of one and run the model's data preprocessor (normalize/pad
# per the model config). The second argument is `training`, so False = inference.
data_batch = dict(
    inputs=[data['inputs']],
    data_samples=[data['data_samples']],
)
data_batch = model.data_preprocessor(data_batch, False)
inputs = data_batch['inputs']

# Forward: backbone + decode_head (per-layer class scores and mask logits).
with torch.no_grad():
    x = model.extract_feat(inputs)
    all_cls_scores, all_mask_preds = model.decode_head(x, data_batch['data_samples'])

# Last decoder layer (most refined), first (only) image in the batch.
cls_scores = all_cls_scores[-1][0]   # [Q, C+1]  presumably 100 queries x (3 classes + no-object); confirm
mask_preds = all_mask_preds[-1][0]   # [Q, h, w]  low-res mask logits (e.g. 128x128 for a 512 input)

num_queries = cls_scores.shape[0]
num_classes = cls_scores.shape[1] - 1  # exclude no-object

print(f'cls_scores: {cls_scores.shape}  (queries={num_queries}, classes={num_classes}+no_obj)')
print(f'mask_preds: {mask_preds.shape}')

In [None]:
# ======================== PANOPTIC-STYLE SEMANTIC MAP ========================
# Standard Mask2Former semantic inference does:
#   sem = einsum('qc,qhw->chw', softmax(cls)[:, :-1], sigmoid(mask)).argmax(0)
# That can collapse to one class when all queries predict the same class
# (common during early training).
#
# Panoptic-style assignment used instead:
#   1. Score every (query, pixel) pair as P(best class of q) * sigmoid(mask_q[pixel]).
#   2. Each pixel goes to the query with the highest score and takes that
#      query's predicted class.
#   3. If the WINNING query's own mask probability is below MASK_THRESHOLD,
#      the pixel becomes background (class 0). Thresholding on the max over all
#      queries instead would let a confident mask from one query validate a
#      pixel claimed by a different query whose own mask says "not here".

cls_probs = torch.softmax(cls_scores, dim=-1)  # [Q, C+1]
obj_probs = cls_probs[:, :-1]                   # [Q, C]  (without no-object)
no_obj_probs = cls_probs[:, -1]                 # [Q]     (reported in diagnostics only)

mask_sigmoid = mask_preds.sigmoid()             # [Q, h, w]

# Per-query predicted class (best object class) and its probability.
query_class = obj_probs.argmax(dim=-1)               # [Q]
query_obj_confidence = obj_probs.max(dim=-1).values  # [Q]

# score[q, y, x] = P(best_class_q) * sigmoid(mask_q[y, x])
pixel_scores = query_obj_confidence[:, None, None] * mask_sigmoid  # [Q, h, w]

# Winning query per pixel under the combined score.
best_score_per_pixel, best_query_per_pixel = pixel_scores.max(dim=0)  # [h, w]

# Assign class from the winning query.
sem_map_lowres = query_class[best_query_per_pixel]  # [h, w]

# Background: the winning query's own mask probability is below the threshold.
winning_mask_per_pixel = mask_sigmoid.gather(0, best_query_per_pixel.unsqueeze(0))[0]  # [h, w]
sem_map_lowres[winning_mask_per_pixel < MASK_THRESHOLD] = 0

# Upscale to the model-input resolution (H x W of the preprocessed image, not
# the original image). Nearest interpolation keeps the labels discrete.
sem_map = F.interpolate(
    sem_map_lowres.float().unsqueeze(0).unsqueeze(0),
    size=(H, W), mode='nearest'
)[0, 0].long().cpu().numpy().astype(np.uint8)

# ======================== DIAGNOSTICS ========================
unique_classes = np.unique(sem_map).tolist()
class_areas = {int(c): float((sem_map == c).sum()) / float(sem_map.size) for c in unique_classes}
roof_ratio = float((sem_map > 0).sum()) / float(sem_map.size)

print(f'\n=== SEMANTIC MAP (panoptic post-processing, threshold={MASK_THRESHOLD}) ===')
print(f'Unique classes: {unique_classes}')
print(f'Class area ratios: {class_areas}')
print(f'Roof pixel ratio: {roof_ratio:.4f}')

# Coverage of a query = fraction of its mask pixels with sigmoid > 0.5
# (the same definition the mask-visualization cell uses).
query_coverage = (mask_sigmoid > 0.5).float().mean(dim=(-1, -2))  # [Q]

# Per-query diagnostics
print('\n=== PER-QUERY ANALYSIS (top 20 by confidence) ===')
sorted_idx = query_obj_confidence.argsort(descending=True)
for qi in sorted_idx[:20]:
    qi = int(qi)
    print(f'  Query {qi:3d}: class={int(query_class[qi])}, '
          f'P(class)={float(query_obj_confidence[qi]):.3f}, '
          f'P(no-obj)={float(no_obj_probs[qi]):.3f}, '
          f'mask_coverage={float(query_coverage[qi]):.3f}')

# Mask statistics
print('\n=== MASK STATISTICS ===')
print(f'Sigmoid range: {float(mask_sigmoid.min()):.4f} .. {float(mask_sigmoid.max()):.4f}')
print(f'Sigmoid mean: {float(mask_sigmoid.mean()):.4f}')
print(f'Queries with >50% coverage: {int((query_coverage > 0.5).sum())} / {num_queries}')
print(f'Queries with any coverage (>0.5): {int((query_coverage > 0).sum())} / {num_queries}')

In [None]:
# ======================== BUILD INSTANCES ========================
# Extract connected components from the thresholded semantic map. Each
# 8-connected blob of a non-background class with at least MIN_AREA pixels
# becomes one instance: a boolean H x W mask, its class id, and a score.
masks_list = []
labels_list = []
scores_list = []
MIN_AREA = 64

# Per-pixel max mask probability over all queries on the low-res mask grid.
# This is loop-invariant, so it is computed once rather than per component.
max_mask_conf_lowres = mask_sigmoid.max(dim=0).values.cpu().numpy()  # [h, w]
lowres_h, lowres_w = max_mask_conf_lowres.shape

for cls_id in np.unique(sem_map):
    cls_id = int(cls_id)
    # Skip background (0) and 255 (presumably the ignore index).
    if cls_id <= 0 or cls_id == 255:
        continue
    binary = (sem_map == cls_id).astype(np.uint8)
    num_comp, comp = cv2.connectedComponents(binary, connectivity=8)
    for comp_id in range(1, int(num_comp)):  # label 0 is the area outside this class
        m = (comp == comp_id)
        area = int(m.sum())
        if area < MIN_AREA:
            continue
        # Score = mean max-mask probability over the low-res cells the component
        # touches. Area-averaged downsampling (> 0) keeps every touched cell, so a
        # thin component cannot vanish and produce a NaN score (NaN would pass the
        # score filter and write invalid JSON). The .any() guard is a last resort.
        m_lowres = cv2.resize(m.astype(np.float32), (lowres_w, lowres_h),
                              interpolation=cv2.INTER_AREA) > 0
        mean_conf = float(max_mask_conf_lowres[m_lowres].mean()) if m_lowres.any() else 0.0
        masks_list.append(m)
        labels_list.append(cls_id)
        scores_list.append(mean_conf)

masks = np.stack(masks_list, axis=0) if masks_list else np.zeros((0, H, W), dtype=bool)
labels = np.array(labels_list, dtype=np.int64)
scores = np.array(scores_list, dtype=np.float32)

print(f'Extracted {len(masks)} instances from semantic map')
for i in range(min(10, len(masks))):
    print(f'  Instance {i}: class={labels[i]}, score={scores[i]:.3f}, area={int(masks[i].sum())}px')

In [None]:
# ======================== VISUALIZATION ========================
palette = np.array([
    [0, 0, 0],       # 0: background (black)
    [0, 255, 0],     # 1: flat_roof (green)
    [255, 0, 0],     # 2: sloped_roof (red)
], dtype=np.uint8)

# RGB colour rendering of the semantic map; ids beyond the palette use its last colour.
sem_vis = palette[np.clip(sem_map, 0, len(palette) - 1)]
overlay = cv2.addWeighted(img_rgb, 0.60, sem_vis, 0.40, 0.0)

# Polygon extraction: simplified outer contours of every instance whose score
# reaches MIN_SCORE. Coordinates are pixels of the preprocessed
# MODEL_INPUT_SIZE x MODEL_INPUT_SIZE image (after center-crop + resize).
roof_polygons = []
MIN_SCORE = 0.3
MIN_POLY_AREA = 64

for mask, inst_label, inst_score in zip(masks, labels, scores):
    if float(inst_score) < MIN_SCORE:
        continue
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    for contour in contours:
        area = float(cv2.contourArea(contour))
        if area < MIN_POLY_AREA:
            continue
        # Simplify with a tolerance of 0.3% of the contour perimeter.
        epsilon = 0.003 * cv2.arcLength(contour, True)
        approx = cv2.approxPolyDP(contour, epsilon, True)
        poly = approx.reshape(-1, 2)
        if poly.shape[0] < 3:
            continue
        roof_polygons.append({
            'polygon_id': len(roof_polygons),
            'class_id': int(inst_label),
            'score': float(inst_score),
            'area_px': area,
            'points_xy': poly.astype(int).tolist(),
        })

# Draw a white outline plus a "cls:<id> <score>" label at each polygon centroid.
for poly_obj in roof_polygons:
    pts = np.array(poly_obj['points_xy'], dtype=np.int32).reshape(-1, 1, 2)
    cv2.polylines(overlay, [pts], True, (255, 255, 255), 2)
    m = cv2.moments(pts)
    if m['m00'] > 0:
        cx, cy = int(m['m10'] / m['m00']), int(m['m01'] / m['m00'])
    else:
        # Degenerate (zero-area) polygon: anchor the label at its first vertex.
        cx, cy = int(pts[0, 0, 0]), int(pts[0, 0, 1])
    txt = f"cls:{poly_obj['class_id']}"
    if poly_obj['score'] is not None:
        txt += f" {poly_obj['score']:.2f}"
    cv2.putText(overlay, txt, (cx, max(0, cy - 4)), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 0), 1, cv2.LINE_AA)

# Save
cv2.imwrite(str(OVERLAY_PATH), cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
cv2.imwrite(str(SEM_MASK_PATH), sem_map)  # raw class ids per pixel, not a colour image

# Polygon coordinates live in the preprocessed frame. Record how to map them
# back to the original image:
#   x_orig = x_start + x * crop_size / MODEL_INPUT_SIZE   (likewise for y)
coordinate_frame = {
    'frame': 'preprocessed',
    'preprocessed_image': str(PREPROCESSED_PATH),
    'preprocessed_size_wh': [int(W), int(H)],
    'crop_offset_xy': [int(x_start), int(y_start)],
    'crop_size': int(crop_size),
    'scale_to_original': float(crop_size) / float(MODEL_INPUT_SIZE),
}
with POLYGONS_JSON_PATH.open('w', encoding='utf-8') as f:
    json.dump(
        {'image': str(INPUT_IMAGE_PATH), 'coordinate_frame': coordinate_frame, 'polygons': roof_polygons},
        f,
        indent=2,
    )

# GeoJSON (coordinates are pixel x/y in the preprocessed frame, not lon/lat).
geojson = {'type': 'FeatureCollection', 'features': []}
for poly_obj in roof_polygons:
    coords = [[float(x), float(y)] for x, y in poly_obj['points_xy']]
    # GeoJSON linear rings must be closed (first point == last point).
    if coords and coords[0] != coords[-1]:
        coords.append(coords[0])
    geojson['features'].append({
        'type': 'Feature',
        'properties': {k: v for k, v in poly_obj.items() if k != 'points_xy'},
        'geometry': {'type': 'Polygon', 'coordinates': [coords]},
    })
with POLYGONS_GEOJSON_PATH.open('w', encoding='utf-8') as f:
    json.dump(geojson, f, indent=2)

# Plot
plt.figure(figsize=(18, 6))
plt.subplot(1, 3, 1)
plt.imshow(img_rgb)
plt.title(f'Input ({W}x{H})')
plt.axis('off')
plt.subplot(1, 3, 2)
plt.imshow(sem_vis)
plt.title(f'Semantic (classes: {unique_classes})')
plt.axis('off')
plt.subplot(1, 3, 3)
plt.imshow(overlay)
plt.title(f'Overlay ({len(roof_polygons)} polygons)')
plt.axis('off')
plt.tight_layout()
plt.show()

print(f'Saved overlay: {OVERLAY_PATH}')
print(f'Saved polygons: {POLYGONS_JSON_PATH}')
print(f'Total roof polygons: {len(roof_polygons)}')

In [None]:
# ======================== MASK VISUALIZATION ========================
# Plot the sigmoid masks of the six most class-confident queries on a 2x3 grid,
# to check whether different queries attend to different regions.

top_queries = sorted_idx[:6]
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
for ax, q in zip(axes.flat, top_queries):
    q = int(q)
    q_mask = mask_sigmoid[q]
    coverage = float((q_mask > 0.5).float().mean())
    ax.imshow(q_mask.cpu().numpy(), cmap='hot', vmin=0, vmax=1)
    ax.set_title(f'Query {q}: cls={int(query_class[q])}, P={float(query_obj_confidence[q]):.2f}, cov={coverage:.3f}')
    ax.axis('off')
plt.suptitle('Top 6 query masks (sigmoid)', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# ======================== SUMMARY ========================
# Gather the run's inputs, key parameters and headline results; persist them
# next to the other outputs and echo the same JSON to the notebook.
summary = dict(
    input_image=str(INPUT_IMAGE_PATH),
    checkpoint=str(CHECKPOINT_PATH),
    original_size=[int(orig_h), int(orig_w)],
    model_input_size=MODEL_INPUT_SIZE,
    mask_threshold=MASK_THRESHOLD,
    semantic_classes=unique_classes,
    class_areas=class_areas,
    roof_pixel_ratio=roof_ratio,
    polygon_count=len(roof_polygons),
)
summary_text = json.dumps(summary, indent=2)
SUMMARY_PATH.write_text(summary_text, encoding='utf-8')
print(summary_text)