# DeepRoof Checkpoint Inference Notebook (Production-Aligned)

Этот ноутбук запускает тот же inference-пайплайн, что и CLI (`tools/inference.py`),
чтобы результаты в ноутбуке и проде совпадали.

In [None]:
import json
import os
import subprocess
import sys
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np

def detect_project_root() -> Path:
    """Locate the project root by walking up from the current directory.

    The root is the nearest directory (the resolved working directory itself
    or one of its ancestors) that contains both a ``configs`` and a
    ``deeproof`` entry.

    Returns:
        The first directory on the way up that holds both entries.

    Raises:
        FileNotFoundError: If the filesystem root is reached without a match.
    """
    markers = ('configs', 'deeproof')
    directory = Path.cwd().resolve()
    while True:
        if all((directory / marker).exists() for marker in markers):
            return directory
        parent = directory.parent
        # At the filesystem root a path is its own parent: nothing left to try.
        if parent == directory:
            break
        directory = parent
    raise FileNotFoundError('Could not auto-detect project root')

# Resolve the project root once for the whole notebook and report it.
PROJECT_ROOT = detect_project_root()
print('PROJECT_ROOT =', PROJECT_ROOT)
# Put the root at the front of sys.path, unless it is already listed.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))


## Configuration

In [None]:
# Notebook settings. Most values below are forwarded to tools/inference.py as
# command-line flags by the run cell.

# Model config, training work dir (globbed below for best_mIoU*.pth and
# iter_*.pth checkpoints), and the directory that receives result.geojson.
CONFIG_PATH = PROJECT_ROOT / 'configs' / 'deeproof_production_swin_L.py'
WORK_DIR = PROJECT_ROOT / 'work_dirs' / 'deeproof_absolute_ideal_v1'
OUTPUT_DIR = PROJECT_ROOT / 'outputs' / 'checkpoint_inference'
# Create the output directory (and any missing parents); no error if present.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Input image (its existence is checked before inference is launched)
INPUT_IMAGE_PATH = PROJECT_ROOT / 'test.png'

# Optional multi-candidate file (.txt/.json). While this is an empty string,
# neither --input-candidates nor --candidate-selection is passed.
INPUT_CANDIDATES_FILE = ''
CANDIDATE_SELECTION = 'best'   # best|weighted

# Inference quality params (production-safe defaults)
# Flag names used for these: --min_confidence, --min_confidence_flat,
# --min_confidence_sloped, --min_area_px, --min_mask_density, --max_instances,
# --tta-mode.
MIN_CONF = 0.45
MIN_CONF_FLAT = 0.45
MIN_CONF_SLOPED = 0.55
MIN_AREA_PX = 250
MIN_MASK_DENSITY = 0.06
MAX_INSTANCES = 40
TTA_MODE = 'lite'  # full|lite|none
# Runtime switches (flag names: --disable-amp, --oom-retry-no-tta,
# --oom-fallback-cpu).
DISABLE_AMP = False
OOM_RETRY_NO_TTA = True
OOM_FALLBACK_CPU = True

# Tiling (prefer fewer seams for medium images)
# Passed as --tile-size / --stride. Presumably adjacent tiles overlap by
# TILE_SIZE - STRIDE = 256 px; confirm against tools/inference.py.
TILE_SIZE = 1024
STRIDE = 768

# Visualization
VIZ_FILL_ALPHA = 0.0   # 0.0 = contour-only (no green wash)
# Each of these adds its flag (--keep-background / --allow-semantic-fallback)
# only when set to True.
KEEP_BACKGROUND = False
ALLOW_SEMANTIC_FALLBACK = False

# Advanced options
# ENABLE_SR adds --sr-enable with --sr-scale SR_SCALE; the run cell fixes the
# backend to 'bicubic' and the fuse mode to 'weighted'.
ENABLE_SR = False
SR_SCALE = 2.0
# Graph options map to --graph-enable, --polygon-snap-to-graph and
# --graph-use-model-edge respectively.
ENABLE_GRAPH = False
ENABLE_GRAPH_SNAP = False
USE_MODEL_EDGE_FOR_GRAPH = False
# The SAM2 flags are only passed when ENABLE_SAM2 is True AND SAM2_CHECKPOINT
# is a non-empty path.
ENABLE_SAM2 = False
SAM2_MODEL_TYPE = 'vit_b'
SAM2_CHECKPOINT = ''
# Empty string = no --depth-map flag.
DEPTH_MAP = ''

# Checkpoint selection with architecture compatibility guard
# When True, selection raises RuntimeError instead of falling back to a
# checkpoint that lacks the dense_geometry_head / edge_head weights.
REQUIRE_COMPATIBLE_CHECKPOINT = False
import torch

def _ckpt_has_required_heads(path: Path) -> bool:
    """Report whether a checkpoint holds weights for both required heads.

    The checkpoint is loaded on CPU and its parameter names are inspected. It
    counts as compatible only when at least one name starts with
    ``dense_geometry_head.`` and at least one starts with ``edge_head.``.
    Names are taken from the ``state_dict`` entry when the loaded object is a
    dict that has one, otherwise from the loaded dict itself.

    Args:
        path: Checkpoint file to inspect.

    Returns:
        True if both prefixes are present. False if either is missing, if the
        file cannot be loaded (a warning is printed), or if the loaded object
        does not provide a dict of weights.
    """
    required = ('dense_geometry_head.', 'edge_head.')
    try:
        try:
            # SECURITY: weights_only=False unpickles arbitrary Python objects.
            # Only point this at checkpoints you produced yourself or trust.
            obj = torch.load(str(path), map_location='cpu', weights_only=False)
        except TypeError:
            # Retry without the keyword, presumably for torch versions that
            # predate the weights_only argument.
            obj = torch.load(str(path), map_location='cpu')
    except Exception as exc:
        # Best effort: an unreadable file is treated as incompatible, but the
        # reason is reported instead of being silently discarded.
        print(f'WARNING: could not load checkpoint {path}: {exc!r}')
        return False
    state = obj.get('state_dict', obj) if isinstance(obj, dict) else {}
    if not isinstance(state, dict):
        return False
    # Ignore non-string keys; they cannot be parameter names and would make
    # str.startswith fail.
    names = [key for key in state if isinstance(key, str)]
    return all(any(name.startswith(prefix) for name in names) for prefix in required)

def _ckpt_order_key(path: Path):
    """Sort key that orders checkpoints by the last number in the file name.

    Plain string sorting ranks ``iter_9000.pth`` above ``iter_10000.pth``;
    comparing the number itself keeps the highest iteration first. Files
    without any digits sort last (key -1); the name breaks ties.
    """
    digit_runs = ''.join(ch if ch in '0123456789' else ' ' for ch in path.stem).split()
    return (int(digit_runs[-1]) if digit_runs else -1, path.name)


# Candidate order: best_mIoU checkpoints first, then iteration checkpoints,
# each group with the highest number first.
best_ckpts = sorted(WORK_DIR.glob('best_mIoU*.pth'), key=_ckpt_order_key, reverse=True)
iter_ckpts = sorted(WORK_DIR.glob('iter_*.pth'), key=_ckpt_order_key, reverse=True)
_best_ckpt_set = set(best_ckpts)  # built once, not once per iter checkpoint
candidates = best_ckpts + [p for p in iter_ckpts if p not in _best_ckpt_set]

# Every compatibility check loads a whole checkpoint from disk, and only the
# first compatible one is used, so stop scanning at the first match.
# `compatible` therefore holds at most one entry.
_first_compatible = next((p for p in candidates if _ckpt_has_required_heads(p)), None)
compatible = [] if _first_compatible is None else [_first_compatible]

if compatible:
    CHECKPOINT_PATH = compatible[0]
    CHECKPOINT_COMPATIBLE = True
    print('Selected compatible checkpoint:', CHECKPOINT_PATH)
else:
    CHECKPOINT_COMPATIBLE = False
    if REQUIRE_COMPATIBLE_CHECKPOINT:
        raise RuntimeError('No checkpoint with dense_geometry_head + edge_head found. Run training to produce a compatible checkpoint first.')
    # Fall back to the top-ranked candidate even though it lacks the heads.
    CHECKPOINT_PATH = candidates[0] if candidates else None
    print('WARNING: no fully compatible checkpoint found; fallback to:', CHECKPOINT_PATH)
if CHECKPOINT_PATH is None:
    raise FileNotFoundError(f'No checkpoint found in {WORK_DIR}')

# Fail early, before the inference subprocess is launched.
if not INPUT_IMAGE_PATH.exists():
    raise FileNotFoundError(f'Input image not found: {INPUT_IMAGE_PATH}')

print('CONFIG:', CONFIG_PATH)
print('CHECKPOINT:', CHECKPOINT_PATH)
print('CHECKPOINT_COMPATIBLE:', CHECKPOINT_COMPATIBLE)
print('INPUT:', INPUT_IMAGE_PATH)


## Run Production Inference

In [None]:
# Build the tools/inference.py command line from the notebook settings and run
# it as a subprocess with the project root as working directory.
output_geojson = OUTPUT_DIR / 'result.geojson'
cmd = [
    sys.executable, str(PROJECT_ROOT / 'tools' / 'inference.py'),
    '--config', str(CONFIG_PATH),
    '--checkpoint', str(CHECKPOINT_PATH),
    '--input', str(INPUT_IMAGE_PATH),
    '--output', str(output_geojson),
    '--tile-size', str(TILE_SIZE),
    '--stride', str(STRIDE),
    '--min_confidence', str(MIN_CONF),
    '--min_confidence_flat', str(MIN_CONF_FLAT),
    '--min_confidence_sloped', str(MIN_CONF_SLOPED),
    '--min_area_px', str(MIN_AREA_PX),
    '--min_mask_density', str(MIN_MASK_DENSITY),
    '--max_instances', str(MAX_INSTANCES),
    '--tta-mode', str(TTA_MODE),
    '--viz-fill-alpha', str(VIZ_FILL_ALPHA),
    '--save_viz',
    '--save_metadata',
]

# Optional switches: each one is appended only when enabled in the settings.
if KEEP_BACKGROUND:
    cmd += ['--keep-background']
if ALLOW_SEMANTIC_FALLBACK:
    cmd += ['--allow-semantic-fallback']
if INPUT_CANDIDATES_FILE:
    cmd += ['--input-candidates', str(INPUT_CANDIDATES_FILE), '--candidate-selection', CANDIDATE_SELECTION]
if ENABLE_SR:
    cmd += ['--sr-enable', '--sr-scale', str(SR_SCALE), '--sr-backend', 'bicubic', '--sr-fuse-mode', 'weighted']
if ENABLE_GRAPH:
    cmd += ['--graph-enable']
if ENABLE_GRAPH_SNAP:
    cmd += ['--polygon-snap-to-graph']
if USE_MODEL_EDGE_FOR_GRAPH:
    cmd += ['--graph-use-model-edge']
if ENABLE_SAM2 and SAM2_CHECKPOINT:
    cmd += ['--sam2-enable', '--sam2-model-type', str(SAM2_MODEL_TYPE), '--sam2-checkpoint', str(SAM2_CHECKPOINT)]
elif ENABLE_SAM2:
    # Make the silent skip visible: SAM2 was requested but cannot be passed.
    print('WARNING: ENABLE_SAM2 is True but SAM2_CHECKPOINT is empty; SAM2 flags are not passed.')
if DEPTH_MAP:
    cmd += ['--depth-map', str(DEPTH_MAP)]

# These must be appended BEFORE the command is launched: flags added to `cmd`
# after check_call() returns never reach the inference process.
if DISABLE_AMP:
    cmd += ['--disable-amp']
if OOM_RETRY_NO_TTA:
    cmd += ['--oom-retry-no-tta']
if OOM_FALLBACK_CPU:
    cmd += ['--oom-fallback-cpu']

print('Running command:', ' '.join(cmd))
# check_call raises CalledProcessError on a non-zero exit status, so 'Done' is
# printed only after a successful run.
subprocess.check_call(cmd, cwd=str(PROJECT_ROOT))
print('Done:', output_geojson)


## Quick Results Check

In [None]:
# Inspect whatever the inference run left in OUTPUT_DIR. Each artifact is
# optional here: a missing file is reported instead of being skipped silently.
result_geojson = OUTPUT_DIR / 'result.geojson'
result_overlay = OUTPUT_DIR / 'result.png'
result_meta = OUTPUT_DIR / 'result.meta.json'

if result_geojson.exists():
    data = json.loads(result_geojson.read_text(encoding='utf-8'))
    feats = data.get('features', [])
    print('features:', len(feats))
    if feats:
        # The first feature's property names serve as a quick schema check.
        props = feats[0].get('properties', {})
        print('sample properties keys:', sorted(props.keys()))
else:
    print('WARNING: GeoJSON not found:', result_geojson)

if result_meta.exists():
    meta = json.loads(result_meta.read_text(encoding='utf-8'))
    print('runtime versions:', meta.get('runtime_versions', {}))
else:
    print('WARNING: metadata not found:', result_meta)

if result_overlay.exists():
    # cv2.imread returns None (it does not raise) when the file cannot be
    # decoded; passing None on to cvtColor would fail with an opaque error.
    overlay_bgr = cv2.imread(str(result_overlay))
    if overlay_bgr is None:
        print('WARNING: could not decode overlay image:', result_overlay)
    else:
        # OpenCV loads BGR; matplotlib expects RGB.
        img = cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB)
        plt.figure(figsize=(10, 10))
        plt.imshow(img)
        plt.title('Inference Overlay')
        plt.axis('off')
        plt.show()
else:
    print('WARNING: overlay image not found:', result_overlay)