# DeepRoof Checkpoint Inference Notebook (Production-Aligned)

Этот ноутбук запускает тот же inference-пайплайн, что и CLI (`tools/inference.py`),
чтобы результаты в ноутбуке и проде совпадали.

In [None]:
import json
import os
import subprocess
import sys
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np

def detect_project_root() -> Path:
    """Locate the repository root by walking up from the current directory.

    A directory qualifies as the root when it contains both a ``configs`` and
    a ``deeproof`` entry. The resolved current working directory is checked
    first, then each of its ancestors in turn.

    Returns:
        The first qualifying directory.

    Raises:
        FileNotFoundError: If no directory on the way up qualifies.
    """
    start = Path.cwd().resolve()
    markers = ('configs', 'deeproof')
    root = next(
        (
            folder
            for folder in (start, *start.parents)
            if all((folder / marker).exists() for marker in markers)
        ),
        None,
    )
    if root is None:
        raise FileNotFoundError('Could not auto-detect project root')
    return root

# Resolve the project root once; later cells build every project path from it.
PROJECT_ROOT = detect_project_root()
# Put the root at the front of sys.path (only if it is not already listed)
# so that modules located under it can be imported from the notebook.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
print('PROJECT_ROOT =', PROJECT_ROOT)


## Configuration

In [None]:
# Config file shipped in the repository's configs/ directory.
DEFAULT_CONFIG_PATH = PROJECT_ROOT / 'configs' / 'deeproof_production_swin_L.py'
# Config actually passed to the inference CLI. It starts as the default and is
# reassigned further down when a config embedded in the checkpoint metadata or
# a config file found in WORK_DIR is available.
CONFIG_PATH = DEFAULT_CONFIG_PATH


def _pick_latest_work_dir() -> Path:
    """Return the most recently modified ``deeproof_notebook_*`` run directory.

    Two locations are searched: ``/workspace/roof/work_dirs`` and
    ``<PROJECT_ROOT>/work_dirs``. Only directories are considered, and the one
    with the largest modification time wins.

    Returns:
        The newest matching directory, or the hard-coded default run path
        under ``<PROJECT_ROOT>/work_dirs`` when nothing matches (that path is
        returned without checking that it exists).
    """
    search_roots = (Path('/workspace/roof/work_dirs'), PROJECT_ROOT / 'work_dirs')
    run_dirs = [
        entry
        for base in search_roots
        if base.exists()
        for entry in base.glob('deeproof_notebook_*')
        if entry.is_dir()
    ]
    if not run_dirs:
        return PROJECT_ROOT / 'work_dirs' / 'deeproof_notebook_20260223_085737'
    # max() keeps the first of several equally recent entries, which is the
    # same element a stable descending sort would put at index 0.
    return max(run_dirs, key=lambda entry: entry.stat().st_mtime)


# Requested run directory/checkpoint
# The pinned run directory is preferred; when it does not exist the choice is
# delegated to _pick_latest_work_dir().
WORK_DIR = Path('/workspace/roof/work_dirs/deeproof_notebook_20260223_085737')
if not WORK_DIR.exists():
    WORK_DIR = _pick_latest_work_dir()

# Directory that receives this notebook's outputs; created here (including
# missing parents) so later cells can write into it directly.
OUTPUT_DIR = PROJECT_ROOT / 'outputs' / 'checkpoint_inference'
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Input image
# Passed to the CLI as --input; its existence is checked before the run.
INPUT_IMAGE_PATH = PROJECT_ROOT / 'test.png'

# Optional multi-candidate file (.txt/.json)
# When non-empty it is passed as --input-candidates together with
# --candidate-selection; when empty neither flag is added.
INPUT_CANDIDATES_FILE = ''
CANDIDATE_SELECTION = 'best'   # best|weighted

# Segmentation-first profile (quality > speed)
# Forwarded as --min_confidence, --min_confidence_flat, --min_confidence_sloped,
# --min_area_px, --min_mask_density, --max_instances and --tta-mode.
MIN_CONF = 0.30
MIN_CONF_FLAT = 0.30
MIN_CONF_SLOPED = 0.40
MIN_AREA_PX = 120
MIN_MASK_DENSITY = 0.02
MAX_INSTANCES = 120
TTA_MODE = 'full'  # full|lite|none
# Each of the three switches below adds its CLI flag only when True:
# --disable-amp, --oom-retry-no-tta, --oom-fallback-cpu.
DISABLE_AMP = False
OOM_RETRY_NO_TTA = True
OOM_FALLBACK_CPU = True

# Tiling (large tiles + overlap to reduce seams)
# Forwarded as --tile-size and --stride. STRIDE is smaller than TILE_SIZE,
# presumably so neighbouring tiles overlap - confirm against the CLI.
TILE_SIZE = 1408
STRIDE = 1024

# Visualization (2D only, no mask wash)
# VIZ_FILL_ALPHA is always passed as --viz-fill-alpha; the two booleans add
# --keep-background / --allow-semantic-fallback only when True.
VIZ_FILL_ALPHA = 0.0
KEEP_BACKGROUND = False
ALLOW_SEMANTIC_FALLBACK = False

# Advanced options (disabled for pure segmentation quality)
# ENABLE_SR adds --sr-enable with --sr-scale plus a fixed bicubic backend and
# weighted fuse mode.
ENABLE_SR = False
SR_SCALE = 2.0
# Add --graph-enable, --polygon-snap-to-graph and --graph-use-model-edge
# respectively when True.
ENABLE_GRAPH = False
ENABLE_GRAPH_SNAP = False
USE_MODEL_EDGE_FOR_GRAPH = False
# The SAM2 flags are passed only when ENABLE_SAM2 is True AND SAM2_CHECKPOINT
# is non-empty; enabling it without a checkpoint has no effect.
ENABLE_SAM2 = False
SAM2_MODEL_TYPE = 'vit_b'
SAM2_CHECKPOINT = ''
# Passed as --depth-map when non-empty.
DEPTH_MAP = ''

# Explicit checkpoint preference
# CHECKPOINT_OVERRIDE is the first candidate tried when the checkpoint is
# resolved below; other files in WORK_DIR are used only if it is missing.
CHECKPOINT_NAME = 'iter_500.pth'
CHECKPOINT_OVERRIDE = WORK_DIR / CHECKPOINT_NAME

import torch


def _inspect_checkpoint(path: Path):
    """Load a checkpoint on CPU and report head compatibility and embedded config.

    Args:
        path: Checkpoint file to inspect.

    Returns:
        A ``(compatible, cfg_text)`` tuple.

        ``compatible`` is True when the state dict (the ``state_dict`` entry,
        or the loaded dict itself when that entry is absent) has at least one
        key starting with each prefix in ``required``.

        ``cfg_text`` is the first string stored under the ``meta`` keys
        ``config``, ``cfg`` or ``pretty_text`` that contains the substring
        ``'model'``, or None when no such string exists.

        When the file cannot be loaded a warning is printed and
        ``(False, None)`` is returned; the function never raises for a bad
        or missing file.
    """
    required = ('dense_geometry_head.', 'edge_head.')
    # NOTE(security): weights_only=False unpickles arbitrary Python objects,
    # so only inspect checkpoints that come from a trusted source.
    try:
        try:
            obj = torch.load(str(path), map_location='cpu', weights_only=False)
        except TypeError:
            # Presumably a torch build whose load() does not accept the
            # weights_only keyword; retry without it.
            obj = torch.load(str(path), map_location='cpu')
    except Exception as exc:
        # Inspection is best-effort, but say why it failed so a corrupt or
        # missing file is not mistaken for a merely incompatible one.
        print(f'WARNING: could not load checkpoint {path}: {exc!r}')
        return False, None

    if not isinstance(obj, dict):
        return False, None

    state = obj.get('state_dict', obj)
    compatible = False
    if isinstance(state, dict):
        # Skip non-string keys: calling startswith() on them would raise.
        names = [key for key in state if isinstance(key, str)]
        compatible = all(
            any(name.startswith(prefix) for name in names) for prefix in required
        )

    cfg_text = None
    meta = obj.get('meta', {})
    if isinstance(meta, dict):
        for key in ('config', 'cfg', 'pretty_text'):
            val = meta.get(key)
            if isinstance(val, str) and 'model' in val:
                cfg_text = val
                break
    return compatible, cfg_text


# Resolve the checkpoint to load. Candidates are tried in priority order and
# the first one that exists on disk is used:
#   1. the explicitly requested CHECKPOINT_OVERRIDE,
#   2. the two hard-coded iter_60000 file names,
#   3. glob matches inside WORK_DIR, highest iteration first.
def _checkpoint_iteration(path: Path) -> int:
    """Return the number that directly follows ``iter_`` in the stem, or -1."""
    digits = ''
    for ch in path.stem.partition('iter_')[2]:
        if ch not in '0123456789':
            break
        digits += ch
    return int(digits) if digits else -1


def _newest_first(pattern: str) -> list:
    """Return WORK_DIR entries matching ``pattern``, highest iteration first."""
    # Sort by the parsed iteration number: a plain string sort would rank
    # iter_9000.pth above iter_10000.pth. The file name breaks ties so the
    # order is deterministic.
    return sorted(
        WORK_DIR.glob(pattern),
        key=lambda ckpt: (_checkpoint_iteration(ckpt), ckpt.name),
        reverse=True,
    )


requested = [
    CHECKPOINT_OVERRIDE,
    WORK_DIR / 'iter_60000.pth',
    WORK_DIR / 'iter_060000.pth',
]
requested += _newest_first('iter_500*.pth')
requested += _newest_first('iter_*60*.pth')
requested += _newest_first('iter_*.pth')

# Drop duplicates while keeping the first (highest-priority) occurrence.
seen = set()
requested_unique = []
for p in requested:
    sp = str(p)
    if sp not in seen:
        seen.add(sp)
        requested_unique.append(p)

existing = [p for p in requested_unique if p.exists()]
if not existing:
    # Nothing usable: list what was tried and what WORK_DIR actually holds.
    available = sorted(WORK_DIR.glob('*.pth')) if WORK_DIR.exists() else []
    err_lines = ['Requested checkpoint not found. Tried:']
    err_lines.extend(f'- {p}' for p in requested_unique)
    if available:
        err_lines.append('Available in WORK_DIR:')
        err_lines.extend(f'- {p}' for p in available)
    else:
        err_lines.append('No .pth files found in WORK_DIR.')
    raise FileNotFoundError('\n'.join(err_lines))

CHECKPOINT_PATH = existing[0]
if CHECKPOINT_PATH != CHECKPOINT_OVERRIDE:
    # Make the substitution visible: results then come from a different
    # checkpoint than the one that was asked for.
    print(f'WARNING: requested checkpoint {CHECKPOINT_OVERRIDE} not found; using {CHECKPOINT_PATH} instead.')
# Inspect the chosen checkpoint: are the expected heads present, and does its
# metadata carry the config text it was trained with?
CHECKPOINT_COMPATIBLE, CKPT_CONFIG_TEXT = _inspect_checkpoint(CHECKPOINT_PATH)

# Use architecture-matching config embedded in checkpoint metadata when available.
if CKPT_CONFIG_TEXT:
    ckpt_cfg_path = OUTPUT_DIR / f'config_from_{CHECKPOINT_PATH.stem}.py'
    # Write as UTF-8 explicitly: the platform default encoding is
    # locale-dependent and may fail on, or mangle, non-ASCII characters in
    # the config text.
    ckpt_cfg_path.write_text(CKPT_CONFIG_TEXT, encoding='utf-8')
    CONFIG_PATH = ckpt_cfg_path
else:
    # Fallback: try config snapshot dumped in work_dir by mmengine.
    # The most recently modified deeproof_*.py file wins.
    dumped_cfgs = sorted(
        [p for p in WORK_DIR.glob('*.py') if p.name.startswith('deeproof_')],
        key=lambda p: p.stat().st_mtime,
        reverse=True,
    )
    if dumped_cfgs:
        CONFIG_PATH = dumped_cfgs[0]

if not CHECKPOINT_COMPATIBLE:
    print('WARNING: checkpoint is not fully compatible (dense_geometry_head/edge_head missing). Inference will still run in segmentation-first mode.')

# Fail early, before the inference subprocess is launched.
if not INPUT_IMAGE_PATH.exists():
    raise FileNotFoundError(f'Input image not found: {INPUT_IMAGE_PATH}')

# Summary of the resolved settings. CONFIG_SOURCE is 'checkpoint_meta' when the
# config came from the checkpoint, 'work_dir_dump' when CONFIG_PATH was replaced
# by a file from WORK_DIR, and 'default' otherwise.
print('CONFIG:', CONFIG_PATH)
print('CONFIG_SOURCE:', 'checkpoint_meta' if CKPT_CONFIG_TEXT else ('work_dir_dump' if CONFIG_PATH != DEFAULT_CONFIG_PATH else 'default'))
print('WORK_DIR:', WORK_DIR)
print('CHECKPOINT:', CHECKPOINT_PATH)
print('CHECKPOINT_COMPATIBLE:', CHECKPOINT_COMPATIBLE)
print('INPUT:', INPUT_IMAGE_PATH)



## Run Production Inference

In [None]:
# Assemble the tools/inference.py command line from the configuration cell
# and run it as a subprocess from the project root.
output_geojson = OUTPUT_DIR / 'result.geojson'

# Options that are always passed, as (flag, value) pairs in CLI order.
_valued_options = (
    ('--config', CONFIG_PATH),
    ('--checkpoint', CHECKPOINT_PATH),
    ('--input', INPUT_IMAGE_PATH),
    ('--output', output_geojson),
    ('--tile-size', TILE_SIZE),
    ('--stride', STRIDE),
    ('--min_confidence', MIN_CONF),
    ('--min_confidence_flat', MIN_CONF_FLAT),
    ('--min_confidence_sloped', MIN_CONF_SLOPED),
    ('--min_area_px', MIN_AREA_PX),
    ('--min_mask_density', MIN_MASK_DENSITY),
    ('--max_instances', MAX_INSTANCES),
    ('--tta-mode', TTA_MODE),
    ('--viz-fill-alpha', VIZ_FILL_ALPHA),
)

# Optional argument groups, as (condition, arguments) pairs. A group is
# appended only when its condition is truthy; the order here is the order on
# the command line.
_optional_flags = (
    (KEEP_BACKGROUND, ['--keep-background']),
    (ALLOW_SEMANTIC_FALLBACK, ['--allow-semantic-fallback']),
    (
        INPUT_CANDIDATES_FILE,
        ['--input-candidates', str(INPUT_CANDIDATES_FILE), '--candidate-selection', CANDIDATE_SELECTION],
    ),
    (
        ENABLE_SR,
        ['--sr-enable', '--sr-scale', str(SR_SCALE), '--sr-backend', 'bicubic', '--sr-fuse-mode', 'weighted'],
    ),
    (ENABLE_GRAPH, ['--graph-enable']),
    (ENABLE_GRAPH_SNAP, ['--polygon-snap-to-graph']),
    (USE_MODEL_EDGE_FOR_GRAPH, ['--graph-use-model-edge']),
    (
        # SAM2 needs both the switch and a checkpoint path.
        ENABLE_SAM2 and SAM2_CHECKPOINT,
        ['--sam2-enable', '--sam2-model-type', str(SAM2_MODEL_TYPE), '--sam2-checkpoint', str(SAM2_CHECKPOINT)],
    ),
    (DEPTH_MAP, ['--depth-map', str(DEPTH_MAP)]),
    # AMP/OOM switches: like everything else they must be part of the
    # command before the process starts.
    (DISABLE_AMP, ['--disable-amp']),
    (OOM_RETRY_NO_TTA, ['--oom-retry-no-tta']),
    (OOM_FALLBACK_CPU, ['--oom-fallback-cpu']),
)

cmd = [sys.executable, str(PROJECT_ROOT / 'tools' / 'inference.py')]
cmd += [token for flag, value in _valued_options for token in (flag, str(value))]
cmd += ['--save_viz', '--save_metadata']
cmd += [token for enabled, tokens in _optional_flags if enabled for token in tokens]

print('Running command:', ' '.join(cmd))

# Run with a copy of the current environment; the CUDA allocator setting is
# added only when the caller has not already set one.
run_env = dict(os.environ)
if 'PYTORCH_CUDA_ALLOC_CONF' not in run_env:
    run_env['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
subprocess.check_call(cmd, cwd=str(PROJECT_ROOT), env=run_env)
print('Done:', output_geojson)



## Quick Results Check

In [None]:
# Summarize whichever result artifacts exist in OUTPUT_DIR. The .png and
# .meta.json names are presumably what tools/inference.py writes next to the
# GeoJSON for --save_viz / --save_metadata - confirm against the CLI.
result_geojson = OUTPUT_DIR / 'result.geojson'
result_overlay = OUTPUT_DIR / 'result.png'
result_meta = OUTPUT_DIR / 'result.meta.json'

if result_geojson.exists():
    data = json.loads(result_geojson.read_text(encoding='utf-8'))
    # Treat a missing or null "features" member as an empty list.
    feats = data.get('features') or []
    print('features:', len(feats))
    if feats:
        # Number of features per class_name, in first-seen order.
        classes = {}
        for f in feats:
            # GeoJSON allows "properties": null, so fall back to an empty
            # mapping instead of calling .get() on None.
            name = (f.get('properties') or {}).get('class_name', 'unknown')
            classes[name] = classes.get(name, 0) + 1
        print('class counts:', classes)
        props = feats[0].get('properties') or {}
        print('sample properties keys:', sorted(props.keys()))

if result_meta.exists():
    meta = json.loads(result_meta.read_text(encoding='utf-8'))
    print('runtime versions:', meta.get('runtime_versions', {}))

if result_overlay.exists():
    # cv2.imread returns None (it does not raise) when the file cannot be
    # decoded; check before converting so the failure is reported clearly.
    overlay_bgr = cv2.imread(str(result_overlay))
    if overlay_bgr is None:
        print('WARNING: could not read overlay image:', result_overlay)
    else:
        # OpenCV loads images as BGR; matplotlib expects RGB.
        img = cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB)
        plt.figure(figsize=(10,10))
        plt.imshow(img)
        plt.title('Segmentation Overlay (2D)')
        plt.axis('off')
        plt.show()

