# DeepRoof Checkpoint Inference (PNG/JPG)

This notebook validates the checkpoint, loads the DeepRoof model, runs segmentation on a single image, and saves the visual results.

In [None]:
from __future__ import annotations

import copy
import json
import os
import sys
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch


In [None]:
def detect_project_root(env_var: str = 'DEEPROOF_ROOT') -> Path:
    """Return the first candidate directory containing both ``configs/`` and ``deeproof/``.

    Candidates are checked in order:

    1. the directory named by the ``env_var`` environment variable, if set;
    2. the current working directory;
    3. its parent;
    4. two hard-coded machine-specific checkouts (server and local).

    Args:
        env_var: Name of an environment variable that may point at the project
            root. It lets the notebook run on other machines without editing
            the hard-coded paths.

    Returns:
        The detected project root.

    Raises:
        FileNotFoundError: If no candidate contains both directories. The
            message lists every location that was checked.
    """
    candidates: list[Path] = []
    override = os.environ.get(env_var)
    if override:
        candidates.append(Path(override).expanduser())
    candidates += [
        Path.cwd(),
        Path.cwd().parent,
        Path('/workspace/roof'),
        Path('/Users/voskan/Desktop/DeepRoof-2026'),
    ]
    for c in candidates:
        # is_dir() rather than exists(): stray *files* named configs/deeproof
        # must not be mistaken for the project layout.
        if (c / 'configs').is_dir() and (c / 'deeproof').is_dir():
            return c
    tried = ', '.join(str(c) for c in candidates)
    raise FileNotFoundError(
        f'Could not auto-detect project root with configs/ and deeproof/ '
        f'(checked: {tried}); set {env_var} to override.'
    )

PROJECT_ROOT = detect_project_root()
CONFIG_PATH = PROJECT_ROOT / 'configs' / 'deeproof_scratch_swin_L.py'


def _first_existing(*candidates: Path, default: Path) -> Path:
    """Return the first of ``candidates`` that exists on disk, otherwise ``default``.

    Candidates are probed lazily and in order; later ones are not checked once
    an existing path is found.
    """
    return next((candidate for candidate in candidates if candidate.exists()), default)


# Checkpoint: the server copy wins; the local download is only a fallback for
# offline analysis/debugging.
SERVER_CHECKPOINT_PATH = Path('/workspace/roof/work_dirs/swin_l_scratch_v1/iter_8000.pth')
LOCAL_ANALYSIS_CHECKPOINT_PATH = Path('/Users/voskan/Downloads/iter_8000.pth')
CHECKPOINT_PATH = _first_existing(SERVER_CHECKPOINT_PATH, default=LOCAL_ANALYSIS_CHECKPOINT_PATH)

# Test image (PNG/JPG/TIF): /workspace/test.png if present, else <project>/test.png.
# When neither exists the primary path is kept, so the validation step reports it.
_PRIMARY_INPUT_IMAGE = Path('/workspace/test.png')
INPUT_IMAGE_PATH = _first_existing(
    _PRIMARY_INPUT_IMAGE,
    PROJECT_ROOT / 'test.png',
    default=_PRIMARY_INPUT_IMAGE,
)

# All artifacts of this run go into one output directory (created if needed).
OUTPUT_DIR = PROJECT_ROOT / 'outputs' / 'checkpoint_inference'
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
OVERLAY_PATH, SEM_MASK_PATH, SUMMARY_PATH = (
    OUTPUT_DIR / file_name
    for file_name in (
        'test_segmentation_overlay.png',
        'test_semantic_mask.png',
        'test_inference_summary.json',
    )
)

DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

print(f'PROJECT_ROOT: {PROJECT_ROOT}')
print(f'CONFIG: {CONFIG_PATH}')
print(f'CHECKPOINT: {CHECKPOINT_PATH}')
print(f'INPUT: {INPUT_IMAGE_PATH}')
print(f'DEVICE: {DEVICE}')


In [None]:
# Fail fast if any required input is missing (first missing path is reported).
_missing_path = next(
    (_path for _path in (CONFIG_PATH, CHECKPOINT_PATH, INPUT_IMAGE_PATH) if not _path.exists()),
    None,
)
if _missing_path is not None:
    raise FileNotFoundError(f'Path not found: {_missing_path}')

# cv2.imread signals a decode failure by returning None rather than raising.
img_bgr = cv2.imread(str(INPUT_IMAGE_PATH), cv2.IMREAD_COLOR)
if img_bgr is None:
    raise RuntimeError(f'Could not load image: {INPUT_IMAGE_PATH}')

# OpenCV decodes to BGR; matplotlib expects RGB.
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
H, W = img_rgb.shape[:2]
print(f'Loaded image: {W}x{H}')

fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(img_rgb)
ax.set_title('Input image')
ax.axis('off')
plt.show()


In [None]:
# Checkpoint compatibility inspection.
# PyTorch>=2.6 changed torch.load to weights_only=True by default, which rejects
# old MMEngine checkpoints that contain non-tensor objects. This opt-out is only
# applied when the variable is not already set in the environment.
# NOTE: weights_only=False unpickles arbitrary objects; load trusted files only.
if 'TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD' not in os.environ:
    os.environ['TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD'] = '1'

def _find_state_dict(ckpt: dict):
    """Return the model state dict stored in ``ckpt``, or None if there is none.

    Looks for a dict under ``'state_dict'`` and then ``'model'``. Failing that,
    a checkpoint whose values are all tensors is treated as a bare
    ``torch.save(model.state_dict())`` file and returned itself.
    """
    for key in ('state_dict', 'model'):
        candidate = ckpt.get(key)
        if isinstance(candidate, dict):
            return candidate
    if ckpt and all(torch.is_tensor(v) for v in ckpt.values()):
        return ckpt
    return None


def inspect_checkpoint(path: Path, probes: list[str] | None = None) -> dict:
    """Load a checkpoint on CPU and summarize its structure for debugging.

    Args:
        path: Checkpoint file to inspect.
        probes: State-dict keys whose presence should be reported. Defaults to
            a few DeepRoof keys, plus one ``module.``-prefixed key to detect
            weights saved from a DDP-wrapped model.

    Returns:
        A JSON-serializable summary dict.

        * Non-dict checkpoints yield ``type`` and ``error``.
        * Otherwise the dict has ``top_keys``, ``has_state_dict`` and
          ``meta_keys``.
        * When a state dict is found it also has ``num_params`` (the number of
          state-dict entries, not scalar parameters), ``first_keys`` (up to 15)
          and ``probe_hits``.

    Note:
        ``weights_only=False`` unpickles arbitrary Python objects; only inspect
        checkpoints from a trusted source.
    """
    if probes is None:
        probes = [
            'backbone.patch_embed.projection.weight',
            'decode_head.query_embed.weight',
            'geometry_head.layers.0.weight',
            'module.backbone.patch_embed.projection.weight',
        ]

    try:
        ckpt = torch.load(str(path), map_location='cpu', weights_only=False)
    except TypeError:
        # Older torch versions without the weights_only argument.
        ckpt = torch.load(str(path), map_location='cpu')

    if not isinstance(ckpt, dict):
        return {'type': str(type(ckpt)), 'error': 'checkpoint is not a dict'}

    state_dict = _find_state_dict(ckpt)
    meta = ckpt.get('meta')
    info = {
        'top_keys': list(ckpt.keys()),
        'has_state_dict': isinstance(state_dict, dict),
        'meta_keys': list(meta.keys()) if isinstance(meta, dict) else [],
    }

    if isinstance(state_dict, dict):
        keys = list(state_dict.keys())
        info['num_params'] = len(keys)
        info['first_keys'] = keys[:15]
        info['probe_hits'] = {k: (k in state_dict) for k in probes}

    return info

# Summarize the checkpoint structure as pretty-printed JSON before building the model.
ckpt_info = inspect_checkpoint(CHECKPOINT_PATH)
ckpt_report = json.dumps(ckpt_info, indent=2, ensure_ascii=False)
print(ckpt_report)


In [None]:
# Put the project root first on sys.path so the local `deeproof` package is
# importable. This must run before the `deeproof.*` imports below.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from mmseg.utils import register_all_modules
from mmseg.apis import init_model, inference_model
from mmengine.config import ConfigDict

# Register MMSegmentation's built-in modules.
# init_default_scope=False: leave the default registry scope unchanged here.
register_all_modules(init_default_scope=False)

# Ensure custom modules are imported and registered.
# These imports are kept only for their side effects; none of the imported
# names is referenced elsewhere in this notebook.
# NOTE(review): assumes each module registers its classes at import time —
# confirm against the deeproof package.
import deeproof.models.backbones.swin_v2_compat
import deeproof.models.deeproof_model
import deeproof.models.heads.mask2former_head
import deeproof.models.heads.geometry_head
import deeproof.models.losses

def _ensure_test_pipeline(cfg):
    if hasattr(cfg, 'test_pipeline') and cfg.test_pipeline:
        return

    pipeline = None
    for loader_key in ('test_dataloader', 'val_dataloader', 'train_dataloader'):
        if loader_key in cfg and cfg.get(loader_key) is not None:
            dataset_cfg = cfg[loader_key].get('dataset', None)
            if dataset_cfg is not None:
                candidate = dataset_cfg.get('pipeline', None)
                if candidate:
                    pipeline = list(candidate)
                    break

    if not pipeline:
        pipeline = [dict(type='LoadImageFromFile'), dict(type='PackSegInputs')]

    filtered = []
    has_load_image = False
    has_pack_inputs = False
    for transform in pipeline:
        if not isinstance(transform, dict):
            continue
        t = transform.get('type', '')
        if t == 'LoadAnnotations':
            continue
        if t == 'LoadImageFromFile':
            has_load_image = True
        if t == 'PackSegInputs':
            has_pack_inputs = True
        filtered.append(transform)

    if not has_load_image:
        filtered.insert(0, dict(type='LoadImageFromFile'))
    if not has_pack_inputs:
        filtered.append(dict(type='PackSegInputs'))

    cfg.test_pipeline = filtered

model = init_model(str(CONFIG_PATH), str(CHECKPOINT_PATH), device=DEVICE)
_ensure_test_pipeline(model.cfg)


def _test_cfg_with_mode(test_cfg):
    """Return a test_cfg that carries a ``mode`` entry (``'whole'`` unless already set).

    This guards against custom training configs where ``model.test_cfg`` is
    missing or of an unexpected type, which otherwise crashes inference.

    * ``None``: returns a fresh ``ConfigDict(mode='whole')``.
    * dict-like: copied into a new ``ConfigDict``, with ``mode`` defaulted.
    * other objects: ``mode`` is set in place when absent; if that assignment
      fails, ``ConfigDict(mode='whole')`` is returned instead.
    """
    if test_cfg is None:
        return ConfigDict(mode='whole')
    if isinstance(test_cfg, dict):
        merged = dict(test_cfg)
        merged.setdefault('mode', 'whole')
        return ConfigDict(merged)
    # Some versions expect attribute access (test_cfg.mode), others use get().
    if hasattr(test_cfg, 'mode'):
        return test_cfg
    try:
        test_cfg.mode = 'whole'
    except Exception:
        return ConfigDict(mode='whole')
    return test_cfg


_current_test_cfg = getattr(model, 'test_cfg', None)
_normalized_test_cfg = _test_cfg_with_mode(_current_test_cfg)
# Only reassign when a new object was produced; in-place fixes need no assignment.
if _normalized_test_cfg is not _current_test_cfg:
    model.test_cfg = _normalized_test_cfg

model.eval()
print('Model loaded successfully.')
print('test_pipeline:', model.cfg.test_pipeline)
print('model.test_cfg:', model.test_cfg)


In [None]:
def _to_numpy(value):
    """Return ``value`` as a NumPy array, or None if it is None or an unsupported type.

    Accepts torch tensors (detached and moved to CPU), NumPy arrays, lists and
    tuples. Result fields that are not tensors are therefore still read,
    instead of being silently skipped (masks) or crashing (scores/labels).
    """
    if value is None:
        return None
    if torch.is_tensor(value):
        return value.detach().cpu().numpy()
    if isinstance(value, (np.ndarray, list, tuple)):
        return np.asarray(value)
    return None


def _resize_nearest(arr: np.ndarray, height: int, width: int) -> np.ndarray:
    """Resize a 2-D label/mask array to (height, width) using nearest-neighbour sampling.

    Nearest-neighbour sampling means class ids are never blended. The array is
    returned unchanged if it already has the target shape.
    """
    if arr.shape == (height, width):
        return arr
    return cv2.resize(arr, (width, height), interpolation=cv2.INTER_NEAREST)


result = inference_model(model, str(INPUT_IMAGE_PATH))
if isinstance(result, (list, tuple)):
    result = result[0]

# Semantic map: one class id per pixel, at the input image resolution.
# Stored as uint8 because it indexes a small palette and is written as a PNG;
# class ids above 255 would wrap.
pred_sem_seg = getattr(result, 'pred_sem_seg', None)
sem_data = _to_numpy(getattr(pred_sem_seg, 'data', None))
if sem_data is not None:
    if sem_data.ndim > 2 and sem_data.shape[0] == 1:
        sem_data = sem_data[0]  # drop a leading singleton dim, e.g. (1, H, W) -> (H, W)
    sem_map = sem_data.astype(np.uint8)
else:
    # Warn so an all-background mask is not mistaken for a real prediction.
    print('WARNING: result has no pred_sem_seg.data; using an all-zero semantic map.')
    sem_map = np.zeros((H, W), dtype=np.uint8)
sem_map = _resize_nearest(sem_map, H, W)

# Instances (if available)
masks = np.zeros((0, H, W), dtype=bool)
scores = np.array([], dtype=np.float32)
labels = np.array([], dtype=np.int64)

inst = getattr(result, 'pred_instances', None)
if inst is not None:
    masks_np = _to_numpy(getattr(inst, 'masks', None))
    if masks_np is not None:
        masks_np = masks_np.astype(bool)
        if masks_np.ndim == 2:
            masks_np = masks_np[np.newaxis]  # a single (h, w) mask
        if masks_np.ndim == 3 and len(masks_np) > 0:
            masks = np.stack(
                [
                    m if m.shape == (H, W)
                    else _resize_nearest(m.astype(np.uint8), H, W).astype(bool)
                    for m in masks_np
                ],
                axis=0,
            )

    scores_np = _to_numpy(getattr(inst, 'scores', None))
    if scores_np is not None:
        scores = scores_np.reshape(-1)
    labels_np = _to_numpy(getattr(inst, 'labels', None))
    if labels_np is not None:
        labels = labels_np.reshape(-1)

if len(scores) > 0 and len(scores) != len(masks):
    print(f'WARNING: {len(masks)} masks but {len(scores)} scores; per-instance annotations may be misaligned.')

print(f'Unique semantic classes: {np.unique(sem_map).tolist()}')
print(f'Predicted instances: {len(masks)}')
if len(scores) > 0:
    print(f'Score range: {float(scores.min()):.4f} .. {float(scores.max()):.4f}')


In [None]:
def _imwrite_checked(path: Path, image: np.ndarray) -> None:
    """Write ``image`` to ``path`` with OpenCV, raising RuntimeError on failure.

    ``cv2.imwrite`` can report failure by returning False instead of raising.
    Without this check, a missing or stale output file would go unnoticed.
    """
    if not cv2.imwrite(str(path), image):
        raise RuntimeError(f'Could not write image: {path}')


# Palette: background, flat_roof, sloped_roof (RGB rows indexed by class id).
palette = np.array([
    [0, 0, 0],
    [0, 255, 0],
    [255, 0, 0],
], dtype=np.uint8)

# Prefer the palette stored in the model's dataset_meta, but only if it is a
# usable (N, 3) colour table with at least three entries.
if hasattr(model, 'dataset_meta') and isinstance(model.dataset_meta, dict):
    model_palette = model.dataset_meta.get('palette', None)
    if model_palette is not None and len(model_palette) >= 3:
        candidate_palette = np.array(model_palette, dtype=np.uint8)
        if candidate_palette.ndim == 2 and candidate_palette.shape[1] == 3:
            palette = candidate_palette
        else:
            print(f'WARNING: ignoring dataset_meta palette with shape {candidate_palette.shape}; expected (N, 3).')

# Class ids beyond the palette are clamped to its last colour.
sem_vis = palette[np.clip(sem_map, 0, len(palette) - 1)]
overlay = cv2.addWeighted(img_rgb, 0.60, sem_vis, 0.40, 0.0)

# Outline instances at or above MIN_SCORE. Where a score exists, label the
# largest contour with the score (prefixed by the class id when labels exist).
MIN_SCORE = 0.25
for i, mask in enumerate(masks):
    if i < len(scores) and float(scores[i]) < MIN_SCORE:
        continue

    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(overlay, contours, -1, (255, 255, 255), 1)

    if len(contours) > 0 and i < len(scores):
        c = max(contours, key=cv2.contourArea)
        x, y, w, h = cv2.boundingRect(c)
        txt = f'{float(scores[i]):.2f}'
        if i < len(labels):
            txt = f'cls:{int(labels[i])} {txt}'
        cv2.putText(
            overlay,
            txt,
            (x, max(0, y - 5)),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.4,
            (255, 255, 255),
            1,
            cv2.LINE_AA,
        )

# The overlay is RGB, so convert back to BGR for OpenCV. The semantic mask is
# written as raw class ids (not colourised).
_imwrite_checked(OVERLAY_PATH, cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
_imwrite_checked(SEM_MASK_PATH, sem_map)

plt.figure(figsize=(16, 8))
plt.subplot(1, 2, 1)
plt.imshow(img_rgb)
plt.title('Input')
plt.axis('off')

plt.subplot(1, 2, 2)
plt.imshow(overlay)
plt.title('Segmentation Overlay')
plt.axis('off')
plt.show()

print(f'Saved overlay: {OVERLAY_PATH}')
print(f'Saved semantic mask: {SEM_MASK_PATH}')


In [None]:
# Machine-readable record of this run: inputs, device, predictions and output paths.
_has_scores = len(scores) > 0
_semantic_classes = np.unique(sem_map).astype(int).tolist()

summary = {
    'project_root': str(PROJECT_ROOT),
    'input_image': str(INPUT_IMAGE_PATH),
    'config': str(CONFIG_PATH),
    'checkpoint': str(CHECKPOINT_PATH),
    'device': DEVICE,
    'image_size': [int(H), int(W)],
    'semantic_classes': _semantic_classes,
    'instance_count': len(masks),
    'score_mean': float(scores.mean()) if _has_scores else None,
    'score_max': float(scores.max()) if _has_scores else None,
    'overlay_path': str(OVERLAY_PATH),
    'semantic_mask_path': str(SEM_MASK_PATH),
}

# Serialize once; the same text is written to disk and echoed to the notebook.
summary_text = json.dumps(summary, indent=2, ensure_ascii=False)
SUMMARY_PATH.write_text(summary_text, encoding='utf-8')

print(summary_text)
print(f'Saved summary: {SUMMARY_PATH}')
