# DeepRoof Checkpoint Inference (PNG/JPG)

This notebook validates the checkpoint, loads the DeepRoof model, runs segmentation on a single image, and saves the visual results.

**Preprocessing:** Any input image is automatically center-cropped to a square and resized to 512×512 to match the training resolution.

In [None]:
from __future__ import annotations

import json
import os
import sys
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch

In [None]:
# ======================== CONFIGURATION ========================
# Training resolution — model was trained on 512×512 OmniCity crops.
# Side length, in pixels, of the square image handed to the model. The
# preprocessing cell resizes the center-cropped input to this size, and the
# fallback test_pipeline uses it as its Resize scale.
MODEL_INPUT_SIZE = 512


def detect_project_root() -> Path:
    """Return the first known location that looks like the project checkout.

    A directory qualifies when it contains both ``configs`` and ``deeproof``.
    The search order is: the current working directory, its parent, then two
    hard-coded install locations.

    Raises:
        FileNotFoundError: if none of the locations qualifies.
    """
    here = Path.cwd()
    search_order = (
        here,
        here.parent,
        Path('/workspace/roof'),
        Path('/Users/voskan/Desktop/DeepRoof-2026'),
    )
    found = next(
        (
            root
            for root in search_order
            if (root / 'configs').exists() and (root / 'deeproof').exists()
        ),
        None,
    )
    if found is None:
        raise FileNotFoundError('Could not auto-detect project root with configs/ and deeproof/.')
    return found


def resolve_checkpoint(work_dir: Path, fallback_ckpt: Path) -> Path:
    """Pick the checkpoint to load, preferring the ``last_checkpoint`` pointer.

    The pointer file ``work_dir / 'last_checkpoint'`` is expected to hold the
    path of the most recently saved checkpoint. Using it avoids accidentally
    loading stale weights from a hard-coded iteration file.

    Resolution order:
      1. The recorded path itself (relative paths are taken relative to
         ``work_dir``).
      2. A file with the same name directly inside ``work_dir``. This covers a
         pointer written on another machine or from another working
         directory, whose recorded path no longer resolves here.
      3. ``fallback_ckpt``.

    Args:
        work_dir: Training run directory that may contain ``last_checkpoint``.
        fallback_ckpt: Path returned when the pointer is missing, empty, or
            does not lead to an existing file.

    Returns:
        The resolved checkpoint path. ``fallback_ckpt`` is returned as-is even
        if it does not exist, so the caller must check existence.
    """
    pointer_file = work_dir / 'last_checkpoint'
    if pointer_file.is_file():
        target = pointer_file.read_text(encoding='utf-8').strip()
        if target:
            recorded = Path(target)
            candidates = [recorded if recorded.is_absolute() else work_dir / recorded]
            if recorded.name:
                # Same file name, but located in the run directory we were given.
                candidates.append(work_dir / recorded.name)
            for candidate in candidates:
                # is_file() rather than exists(): a directory is never a checkpoint.
                if candidate.is_file():
                    return candidate

    return fallback_ckpt


# Repository root; every project-relative path in this cell is derived from it.
PROJECT_ROOT = detect_project_root()

# Production config is used because it carries the test_pipeline needed for inference.
CONFIG_PATH = PROJECT_ROOT.joinpath('configs', 'deeproof_production_swin_L.py')

# Training run directory. A `last_checkpoint` pointer inside it wins over the
# explicit iteration file; the local download is only used when the resolved
# checkpoint does not exist.
WORK_DIR = PROJECT_ROOT.joinpath('work_dirs', 'swin_l_scratch_v1')
SERVER_CHECKPOINT_PATH = WORK_DIR.joinpath('iter_8000.pth')
LOCAL_ANALYSIS_CHECKPOINT_PATH = Path('/Users/voskan/Downloads/iter_8000.pth')
CHECKPOINT_PATH = resolve_checkpoint(WORK_DIR, SERVER_CHECKPOINT_PATH)
if LOCAL_ANALYSIS_CHECKPOINT_PATH.exists() and not CHECKPOINT_PATH.exists():
    CHECKPOINT_PATH = LOCAL_ANALYSIS_CHECKPOINT_PATH

# Test image (PNG/JPG/TIF). Falls back to a copy in the project root when the
# primary location is absent.
INPUT_IMAGE_PATH = Path('/workspace/test.png')
fallback = PROJECT_ROOT.joinpath('test.png')
if not INPUT_IMAGE_PATH.exists() and fallback.exists():
    INPUT_IMAGE_PATH = fallback

# All artifacts of this notebook are written below OUTPUT_DIR.
OUTPUT_DIR = PROJECT_ROOT.joinpath('outputs', 'checkpoint_inference')
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
OVERLAY_PATH = OUTPUT_DIR.joinpath('test_segmentation_overlay.png')
SEM_MASK_PATH = OUTPUT_DIR.joinpath('test_semantic_mask.png')
SUMMARY_PATH = OUTPUT_DIR.joinpath('test_inference_summary.json')
POLYGONS_JSON_PATH = OUTPUT_DIR.joinpath('test_roof_polygons.json')
POLYGONS_GEOJSON_PATH = OUTPUT_DIR.joinpath('test_roof_polygons.geojson')
PREPROCESSED_PATH = OUTPUT_DIR.joinpath('test_preprocessed.png')

# Run on the first GPU when CUDA is usable, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'

# Echo the resolved configuration so a wrong path is obvious before inference.
print(f'PROJECT_ROOT: {PROJECT_ROOT}')
print(f'CONFIG: {CONFIG_PATH}')
print(f'WORK_DIR: {WORK_DIR}')
print(f'CHECKPOINT: {CHECKPOINT_PATH}')
if CHECKPOINT_PATH.exists():
    # Modification time (seconds since the epoch) helps spot stale weights.
    print(f'CHECKPOINT mtime: {CHECKPOINT_PATH.stat().st_mtime}')
print(f'INPUT: {INPUT_IMAGE_PATH}')
print(f'MODEL_INPUT_SIZE: {MODEL_INPUT_SIZE}x{MODEL_INPUT_SIZE}')
print(f'DEVICE: {DEVICE}')

In [None]:
# ======================== IMAGE PREPROCESSING ========================
# Center-crop to square, then resize to MODEL_INPUT_SIZE.
# This ensures ANY input image matches what the model was trained on.

# Fail fast if any required file is missing, before doing any work.
for p in (CONFIG_PATH, CHECKPOINT_PATH, INPUT_IMAGE_PATH):
    if not p.exists():
        raise FileNotFoundError(f'Path not found: {p}')

# cv2.imread signals failure by returning None instead of raising.
img_orig_bgr = cv2.imread(str(INPUT_IMAGE_PATH), cv2.IMREAD_COLOR)
if img_orig_bgr is None:
    raise RuntimeError(f'Could not load image: {INPUT_IMAGE_PATH}')

orig_h, orig_w = img_orig_bgr.shape[:2]
print(f'Original image size: {orig_w}x{orig_h}')

# Step 1: Center-crop to the largest square
crop_size = min(orig_h, orig_w)
y_start = (orig_h - crop_size) // 2
x_start = (orig_w - crop_size) // 2
img_cropped = img_orig_bgr[y_start:y_start+crop_size, x_start:x_start+crop_size]
print(f'Center-cropped to: {crop_size}x{crop_size}')

# Step 2: Resize to model input size
# INTER_AREA is the recommended filter for shrinking, but when enlarging it
# behaves like nearest-neighbour; use bilinear for inputs smaller than the target.
resize_interp = cv2.INTER_AREA if crop_size >= MODEL_INPUT_SIZE else cv2.INTER_LINEAR
img_bgr = cv2.resize(img_cropped, (MODEL_INPUT_SIZE, MODEL_INPUT_SIZE), interpolation=resize_interp)
print(f'Resized to: {MODEL_INPUT_SIZE}x{MODEL_INPUT_SIZE}')

# Save preprocessed image — inference_model will load from file via test_pipeline
# cv2.imwrite returns False on failure; without this check a stale file left by
# an earlier run would silently be used for inference.
if not cv2.imwrite(str(PREPROCESSED_PATH), img_bgr):
    raise RuntimeError(f'Could not write preprocessed image: {PREPROCESSED_PATH}')
print(f'Preprocessed image saved: {PREPROCESSED_PATH}')

# For visualization later
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
H, W = img_rgb.shape[:2]

# Also keep original for side-by-side
img_orig_rgb = cv2.cvtColor(img_orig_bgr, cv2.COLOR_BGR2RGB)

# Side-by-side: original (left) vs. what the model will actually see (right).
plt.figure(figsize=(14, 6))
plt.subplot(1, 2, 1)
plt.imshow(img_orig_rgb)
plt.title(f'Original ({orig_w}x{orig_h})')
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(img_rgb)
plt.title(f'Preprocessed ({MODEL_INPUT_SIZE}x{MODEL_INPUT_SIZE})')
plt.axis('off')
plt.show()

In [None]:
# Checkpoint compatibility inspection
# setdefault only assigns the variable when it is not already present, so a
# value exported before the notebook was started is left untouched.
# NOTE(review): presumably this lets library code that calls torch.load
# without weights_only=False still load full pickled training checkpoints on
# newer PyTorch releases — confirm against the installed torch version.
os.environ.setdefault('TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD', '1')

def inspect_checkpoint(path: Path):
    """Load a checkpoint on the CPU and summarize its structure.

    Args:
        path: Checkpoint file to load.

    Returns:
        A dict. For a non-dict checkpoint it has only ``type`` and ``error``.
        Otherwise it has ``top_keys``, ``has_state_dict`` and ``meta_keys``,
        plus ``num_params``, ``first_keys`` (first 15 parameter names) and
        ``probe_hits`` when a weights dict is found under ``state_dict`` or,
        failing that, ``model``.
    """
    location = str(path)
    # NOTE: weights_only=False unpickles arbitrary objects; only use this on
    # checkpoints from a trusted source.
    try:
        payload = torch.load(location, map_location='cpu', weights_only=False)
    except TypeError:
        # Retry without the keyword for torch.load variants that reject it.
        payload = torch.load(location, map_location='cpu')

    if not isinstance(payload, dict):
        return {'type': str(type(payload)), 'error': 'checkpoint is not a dict'}

    weights = payload['state_dict'] if 'state_dict' in payload else payload.get('model')
    meta = payload.get('meta')
    has_weights = isinstance(weights, dict)

    report = {
        'top_keys': list(payload),
        'has_state_dict': has_weights,
        'meta_keys': list(meta) if isinstance(meta, dict) else [],
    }
    if not has_weights:
        return report

    param_names = list(weights)
    report['num_params'] = len(param_names)
    report['first_keys'] = param_names[:15]
    # Parameter names whose presence tells which components the weights cover;
    # the 'module.' variant checks for that prefix on the backbone key.
    probe_names = (
        'backbone.patch_embed.projection.weight',
        'decode_head.query_embed.weight',
        'geometry_head.layers.0.weight',
        'module.backbone.patch_embed.projection.weight',
    )
    report['probe_hits'] = {name: (name in weights) for name in probe_names}
    return report

# Inspect the resolved checkpoint and pretty-print the report as JSON
# (ensure_ascii=False keeps any non-ASCII key names readable).
ckpt_info = inspect_checkpoint(CHECKPOINT_PATH)
print(json.dumps(ckpt_info, indent=2, ensure_ascii=False))

In [None]:
# Make the project package importable so the `deeproof.*` imports below resolve.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from mmseg.utils import register_all_modules
from mmseg.apis import init_model, inference_model

# Register mmseg's built-in components before the custom ones are imported.
register_all_modules(init_default_scope=False)

# Ensure custom modules are imported and registered.
# These imports are needed for their side effects only; the names are not used directly.
import deeproof.models.backbones.swin_v2_compat
import deeproof.models.deeproof_model
import deeproof.models.heads.mask2former_head
import deeproof.models.heads.geometry_head
import deeproof.models.losses

# Build the model from the config and load the checkpoint weights onto DEVICE.
model = init_model(str(CONFIG_PATH), str(CHECKPOINT_PATH), device=DEVICE)

# Use the test_pipeline from the config.
# Production config defines: LoadImageFromFile -> Resize(512,512) -> PackSegInputs
if hasattr(model.cfg, 'test_pipeline') and model.cfg.test_pipeline:
    print('Using test_pipeline from config:', model.cfg.test_pipeline)
else:
    # Fallback if test_pipeline is missing from config
    model.cfg.test_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='Resize', scale=(MODEL_INPUT_SIZE, MODEL_INPUT_SIZE), keep_ratio=False),
        dict(type='PackSegInputs'),
    ]
    print('Using fallback test_pipeline:', model.cfg.test_pipeline)

# Use 'whole' mode for Mask2Former (slide mode adds stitching artifacts).
# This replaces whatever test_cfg the config or checkpoint provided.
from mmengine.config import ConfigDict
model.test_cfg = ConfigDict(dict(mode='whole'))

# Switch to evaluation mode before running inference.
model.eval()
print('Model loaded successfully.')
print('model.test_cfg:', model.test_cfg)

In [None]:
# ======================== INFERENCE ========================
# Run on the PREPROCESSED image (already center-cropped and resized to 512×512)
result = inference_model(model, str(PREPROCESSED_PATH))
# Some call paths return a list of results; only one image was passed in.
if isinstance(result, (list, tuple)):
    result = result[0]

# Semantic map
if hasattr(result, 'pred_sem_seg') and hasattr(result.pred_sem_seg, 'data'):
    # Drop the leading singleton dimension and move to a uint8 numpy array.
    sem_map = result.pred_sem_seg.data.squeeze(0).detach().cpu().numpy().astype(np.uint8)
else:
    # No semantic prediction available: treat everything as background.
    sem_map = np.zeros((H, W), dtype=np.uint8)

# Bring the map to the preprocessed image size; nearest-neighbour keeps labels discrete.
if sem_map.shape != (H, W):
    sem_map = cv2.resize(sem_map, (W, H), interpolation=cv2.INTER_NEAREST)

# Instances (if available)
masks = np.zeros((0, H, W), dtype=bool)
scores = np.array([], dtype=np.float32)
labels = np.array([], dtype=np.int64)
instance_source = 'none'

if hasattr(result, 'pred_instances') and result.pred_instances is not None:
    inst = result.pred_instances
    if hasattr(inst, 'masks') and inst.masks is not None:
        masks_t = inst.masks
        if torch.is_tensor(masks_t):
            masks_np = masks_t.detach().cpu().numpy().astype(bool)
            # Only a stack of 2-D masks (N, h, w) is accepted.
            if masks_np.ndim == 3:
                resized_masks = []
                for m in masks_np:
                    if m.shape != (H, W):
                        m = cv2.resize(m.astype(np.uint8), (W, H), interpolation=cv2.INTER_NEAREST).astype(bool)
                    resized_masks.append(m)
                masks = np.stack(resized_masks, axis=0) if resized_masks else np.zeros((0, H, W), dtype=bool)

    if hasattr(inst, 'scores') and inst.scores is not None:
        scores = inst.scores.detach().cpu().numpy()
    if hasattr(inst, 'labels') and inst.labels is not None:
        labels = inst.labels.detach().cpu().numpy()

    if len(masks) > 0:
        instance_source = 'model_pred_instances'

# Scores and labels only make sense next to the masks they describe. If no
# usable masks came out of pred_instances, drop them so stale values cannot
# leak into the diagnostics or the summary when the fallback finds nothing.
if len(masks) == 0:
    scores = np.array([], dtype=np.float32)
    labels = np.array([], dtype=np.int64)

# Fallback: build pseudo-instances from semantic connected components
if len(masks) == 0:
    comp_masks = []
    comp_labels = []
    min_area = 64  # components smaller than this many pixels are discarded
    for cls_id in np.unique(sem_map):
        cls_id = int(cls_id)
        # Skip background (0) and the 255 label.
        if cls_id <= 0 or cls_id == 255:
            continue
        binary = (sem_map == cls_id).astype(np.uint8)
        num_comp, comp = cv2.connectedComponents(binary, connectivity=8)
        # Component 0 is the background of the binary image.
        for comp_id in range(1, int(num_comp)):
            m = (comp == comp_id)
            if int(m.sum()) < min_area:
                continue
            comp_masks.append(m)
            comp_labels.append(cls_id)

    if comp_masks:
        masks = np.stack(comp_masks, axis=0).astype(bool)
        labels = np.array(comp_labels, dtype=np.int64)
        # Connected components carry no confidence score.
        scores = np.array([], dtype=np.float32)
        instance_source = 'semantic_connected_components_fallback'

# ======================== DIAGNOSTICS ========================
unique_classes = np.unique(sem_map).tolist()
# Fraction of the image covered by each class id.
class_areas = {int(c): float((sem_map == c).sum()) / float(sem_map.size) for c in unique_classes}
# Every non-zero label counts as roof.
roof_ratio = float((sem_map > 0).sum()) / float(sem_map.size)
print(f'Unique semantic classes: {unique_classes}')
print(f'Class area ratios: {class_areas}')
print(f'Roof pixel ratio: {roof_ratio:.4f}')
print(f'Predicted instances: {len(masks)} (source={instance_source})')
if len(scores) > 0:
    print(f'Score range: {float(scores.min()):.4f} .. {float(scores.max()):.4f}')
if roof_ratio > 0.90:
    print('WARNING: segmentation marks >90% pixels as roof.')
    print('  Possible causes: model still undertrained, or test image too different from training data.')

In [None]:
# ======================== RAW LOGITS DIAGNOSTICS ========================
# Check what the model actually produces BEFORE argmax
# This helps distinguish 'model not converged' from 'inference bug'

# Run a manual forward pass to inspect raw outputs
from mmengine.dataset import Compose
from copy import deepcopy  # not used in this cell; kept in case other cells rely on it

# Push the preprocessed image through the same pipeline inference uses.
pipeline = Compose(model.cfg.test_pipeline)
data = dict(img_path=str(PREPROCESSED_PATH))
data = pipeline(data)

# Prepare input for model
data_batch = dict(
    inputs=[data['inputs']],
    data_samples=[data['data_samples']],
)
# Preprocess (normalize, pad)
data_batch = model.data_preprocessor(data_batch, False)
inputs = data_batch['inputs']

# Forward through backbone + decode_head
with torch.no_grad():
    x = model.extract_feat(inputs)
    all_cls_scores, all_mask_preds = model.decode_head(x, data_batch['data_samples'])

# Analyze last decoder layer outputs
cls_scores = all_cls_scores[-1]  # [B, num_queries, num_classes+1]
mask_preds = all_mask_preds[-1]  # [B, num_queries, H, W]

print(f'cls_scores shape: {cls_scores.shape}')
print(f'mask_preds shape: {mask_preds.shape}')

# Per-query class predictions
cls_probs = torch.softmax(cls_scores[0], dim=-1)  # [Q, C+1]
pred_classes = cls_probs[:, :-1].argmax(dim=-1)    # ignore no-obj column
no_obj_probs = cls_probs[:, -1]                     # P(no-object)
max_cls_probs = cls_probs[:, :-1].max(dim=-1).values

# Mask probabilities are computed once and shared by the per-query listing and
# the summary below. "Coverage" always means the fraction of pixels whose
# probability exceeds 0.5, so the two sections report the same quantity.
mask_sigmoid = mask_preds[0].sigmoid()  # [Q, H, W]
query_coverage = (mask_sigmoid > 0.5).float().mean(dim=(-1, -2))  # [Q]

print(f'\n--- Per-query analysis (top 20 by confidence) ---')
sorted_idx = max_cls_probs.argsort(descending=True)
for rank, qi in enumerate(sorted_idx[:20]):
    qi = int(qi)
    mask_coverage = float(query_coverage[qi])
    print(f'  Query {qi:3d}: class={int(pred_classes[qi])}, '
          f'P(class)={float(max_cls_probs[qi]):.3f}, '
          f'P(no-obj)={float(no_obj_probs[qi]):.3f}, '
          f'mask_coverage={mask_coverage:.3f}')

# Overall statistics
print(f'\n--- Mask statistics ---')
print(f'mask sigmoid range: {float(mask_sigmoid.min()):.4f} .. {float(mask_sigmoid.max()):.4f}')
print(f'mask sigmoid mean: {float(mask_sigmoid.mean()):.4f}')
print(f'Queries with >50% coverage: {int((query_coverage > 0.5).sum())} / {mask_sigmoid.shape[0]}')
print(f'Queries with <10% coverage: {int((query_coverage < 0.1).sum())} / {mask_sigmoid.shape[0]}')

In [None]:
# ======================== VISUALIZATION ========================
# Palette: background, flat_roof, sloped_roof
palette = np.array([
    [0, 0, 0],
    [0, 255, 0],
    [255, 0, 0],
], dtype=np.uint8)

# Prefer the palette shipped with the model when it has at least three entries.
if hasattr(model, 'dataset_meta') and isinstance(model.dataset_meta, dict):
    model_palette = model.dataset_meta.get('palette', None)
    if model_palette is not None and len(model_palette) >= 3:
        palette = np.array(model_palette, dtype=np.uint8)

# Label ids beyond the palette are clamped to its last entry.
sem_vis = palette[np.clip(sem_map, 0, len(palette) - 1)]
overlay = cv2.addWeighted(img_rgb, 0.60, sem_vis, 0.40, 0.0)

# --- Polygon extraction from instance masks ---
roof_polygons = []
MIN_SCORE = 0.25      # instances scoring below this are skipped (only when a score exists)
MIN_POLY_AREA = 64    # contours enclosing fewer pixels than this are skipped

for i, mask in enumerate(masks):
    if i < len(scores) and float(scores[i]) < MIN_SCORE:
        continue

    # Index [-2] picks the contour list on both OpenCV 3 (3 return values)
    # and OpenCV 4 (2 return values).
    contours = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    for contour in contours:
        area = float(cv2.contourArea(contour))
        if area < MIN_POLY_AREA:
            continue

        # Simplify the outline; tolerance is 0.3% of the contour perimeter.
        epsilon = 0.003 * cv2.arcLength(contour, True)
        approx = cv2.approxPolyDP(contour, epsilon, True)
        poly = approx.reshape(-1, 2)
        if poly.shape[0] < 3:
            continue

        # Missing label defaults to class 1; missing score is recorded as None.
        cls_id = int(labels[i]) if i < len(labels) else 1
        score = float(scores[i]) if i < len(scores) else None

        roof_polygons.append({
            'polygon_id': len(roof_polygons),
            'class_id': cls_id,
            'score': score,
            'area_px': area,  # area of the unsimplified contour
            'points_xy': poly.astype(int).tolist(),
        })

# Draw polygons
for poly_obj in roof_polygons:
    pts = np.array(poly_obj['points_xy'], dtype=np.int32).reshape(-1, 1, 2)
    cv2.polylines(overlay, [pts], True, (255, 255, 255), 2)

    # label near polygon centroid
    m = cv2.moments(pts)
    if m['m00'] > 0:
        cx = int(m['m10'] / m['m00'])
        cy = int(m['m01'] / m['m00'])
    else:
        # Degenerate polygon: anchor the label at its first vertex. Cast to
        # plain ints so both branches hand cv2.putText the same types.
        cx, cy = int(pts[0, 0, 0]), int(pts[0, 0, 1])

    txt = f"id:{poly_obj['polygon_id']} cls:{poly_obj['class_id']}"
    if poly_obj['score'] is not None:
        txt += f" {poly_obj['score']:.2f}"
    cv2.putText(overlay, txt, (cx, max(0, cy - 4)), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (255, 255, 0), 1, cv2.LINE_AA)

# Save masks/overlay
# overlay is RGB; cv2.imwrite expects BGR channel order.
cv2.imwrite(str(OVERLAY_PATH), cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
cv2.imwrite(str(SEM_MASK_PATH), sem_map)

# Save polygons JSON
# Polygon coordinates are pixels of the preprocessed (cropped + resized) image,
# not of the original input. Record that frame so consumers do not overlay
# them on the wrong image.
with POLYGONS_JSON_PATH.open('w', encoding='utf-8') as f:
    json.dump(
        {
            'image': str(INPUT_IMAGE_PATH),
            'preprocessed_image': str(PREPROCESSED_PATH),
            'image_size_wh': [int(W), int(H)],
            'polygons': roof_polygons,
        },
        f,
        indent=2,
        ensure_ascii=False,
    )

# Save simple GeoJSON-like pixel-space output
geojson = {
    'type': 'FeatureCollection',
    'features': []
}
for poly_obj in roof_polygons:
    coords = [[float(x), float(y)] for x, y in poly_obj['points_xy']]
    # close ring
    if coords and coords[0] != coords[-1]:
        coords.append(coords[0])
    geojson['features'].append({
        'type': 'Feature',
        'properties': {
            'polygon_id': poly_obj['polygon_id'],
            'class_id': poly_obj['class_id'],
            'score': poly_obj['score'],
            'area_px': poly_obj['area_px'],
        },
        'geometry': {
            'type': 'Polygon',
            'coordinates': [coords],
        },
    })

with POLYGONS_GEOJSON_PATH.open('w', encoding='utf-8') as f:
    json.dump(geojson, f, indent=2, ensure_ascii=False)

# Three panels: model input, colorized semantic map, overlay with polygons.
plt.figure(figsize=(18, 8))
plt.subplot(1, 3, 1)
plt.imshow(img_rgb)
plt.title('Input (preprocessed)')
plt.axis('off')

plt.subplot(1, 3, 2)
plt.imshow(sem_vis)
plt.title('Semantic map')
plt.axis('off')

plt.subplot(1, 3, 3)
plt.imshow(overlay)
plt.title(f'Roof polygons: {len(roof_polygons)}')
plt.axis('off')
plt.show()

print(f'Saved overlay: {OVERLAY_PATH}')
print(f'Saved semantic mask: {SEM_MASK_PATH}')
print(f'Saved polygons JSON: {POLYGONS_JSON_PATH}')
print(f'Saved polygons GeoJSON: {POLYGONS_GEOJSON_PATH}')
print(f'Extracted polygons: {len(roof_polygons)}')

In [None]:
# Gather the run's inputs, settings and headline results into one
# JSON-serializable record, write it to SUMMARY_PATH and echo it.
classes_present = np.unique(sem_map).astype(int).tolist()
roof_fraction = float(np.count_nonzero(sem_map > 0)) / float(sem_map.size)

# Score statistics only exist when the instances carry scores.
has_scores = len(scores) > 0
mean_score = float(scores.mean()) if has_scores else None
best_score = float(scores.max()) if has_scores else None

summary = {
    # Where things came from
    'project_root': str(PROJECT_ROOT),
    'work_dir': str(WORK_DIR),
    'input_image': str(INPUT_IMAGE_PATH),
    'preprocessed_image': str(PREPROCESSED_PATH),
    'config': str(CONFIG_PATH),
    'checkpoint': str(CHECKPOINT_PATH),
    'device': DEVICE,
    # Geometry: original size is stored as [height, width]
    'original_size': [int(orig_h), int(orig_w)],
    'model_input_size': [MODEL_INPUT_SIZE, MODEL_INPUT_SIZE],
    # Prediction statistics
    'semantic_classes': classes_present,
    'class_areas': class_areas,
    'roof_pixel_ratio': roof_fraction,
    'instance_count': int(len(masks)),
    'instance_source': instance_source,
    'polygon_count': int(len(roof_polygons)),
    'score_mean': mean_score,
    'score_max': best_score,
    # Artifacts written by the visualization cell
    'overlay_path': str(OVERLAY_PATH),
    'semantic_mask_path': str(SEM_MASK_PATH),
    'polygons_json_path': str(POLYGONS_JSON_PATH),
    'polygons_geojson_path': str(POLYGONS_GEOJSON_PATH),
}

# Serialize once; the same text goes to disk and to the notebook output.
summary_text = json.dumps(summary, indent=2, ensure_ascii=False)
SUMMARY_PATH.write_text(summary_text, encoding='utf-8')

print(summary_text)
print(f'Saved summary: {SUMMARY_PATH}')