# Checkpoint Inference Test

Run inference from a trained checkpoint on one input image (GeoTIFF), then inspect GeoJSON and visualization output.

In [None]:
from __future__ import annotations

import json
import shlex
import subprocess
import sys
from pathlib import Path

import cv2
import matplotlib.pyplot as plt


In [None]:
# Notebook configuration: every path and tunable used by the later cells is
# defined here so that a run can be re-targeted by editing a single cell.

# --- Inference settings ------------------------------------------------------
# DEVICE, TILE_SIZE, STRIDE and MIN_CONFIDENCE are forwarded as command-line
# options to tools/inference.py by the inference cell; SAVE_VIZ toggles the
# extra `--save_viz` flag.
DEVICE = 'cuda:0'
TILE_SIZE = 1024
STRIDE = 800
MIN_CONFIDENCE = 0.5
SAVE_VIZ = True

# --- Project layout ----------------------------------------------------------
PROJECT_ROOT = Path('/workspace/roof')
WORK_DIR = PROJECT_ROOT.joinpath('work_dirs', 'swin_l_scratch_v1')
CONFIG_PATH = PROJECT_ROOT.joinpath('configs', 'deeproof_scratch_swin_L.py')
CHECKPOINT_PATH = WORK_DIR.joinpath('iter_8000.pth')

# Set your input GeoTIFF path here before running inference.
# The default file name is a placeholder; the inference cell refuses to run
# until this points at an existing file.
INPUT_IMAGE_PATH = PROJECT_ROOT.joinpath(
    'data', 'OmniCity', 'images', 'SET_IMAGE_NAME.tif'
)

# --- Outputs -----------------------------------------------------------------
RUN_NAME = 'iter8000_test'
OUT_DIR = PROJECT_ROOT.joinpath('outputs', 'checkpoint_inference')
# Create the output directory up front (no error if it already exists).
OUT_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_GEOJSON = OUT_DIR.joinpath(f'{RUN_NAME}.geojson')
# The visualization path is the GeoJSON path with its suffix swapped to .png.
OUTPUT_VIZ = OUTPUT_GEOJSON.with_suffix('.png')

# Echo the resolved paths so a wrong location is obvious before running.
print(
    f'CONFIG: {CONFIG_PATH}',
    f'CHECKPOINT: {CHECKPOINT_PATH}',
    f'INPUT: {INPUT_IMAGE_PATH}',
    f'OUTPUT GEOJSON: {OUTPUT_GEOJSON}',
    f'OUTPUT VIZ: {OUTPUT_VIZ}',
    sep='\n',
)

In [None]:
# Discovery helper: list what is available on disk so CHECKPOINT_PATH and
# INPUT_IMAGE_PATH can be set to real files.

# Checkpoints: every *.pth directly inside WORK_DIR, in sorted order.
print('Available checkpoints in work_dir:')
for checkpoint in sorted(WORK_DIR.glob('*.pth')):
    print('-', checkpoint.name)

# Images: show at most the first ten *.tif files (sorted) as examples.
images_dir = PROJECT_ROOT / 'data' / 'OmniCity' / 'images'
if not images_dir.exists():
    print(f'Image directory not found: {images_dir}')
else:
    tif_candidates = sorted(images_dir.glob('*.tif'))[:10]
    print('Sample .tif images:')
    for tif_path in tif_candidates:
        print('-', tif_path.name)

In [None]:
# Run tools/inference.py as a subprocess using the configuration cell's values.

# Fail fast, in a fixed order, if any required input file is missing.
required_files = (
    (CONFIG_PATH, f'Config not found: {CONFIG_PATH}'),
    (CHECKPOINT_PATH, f'Checkpoint not found: {CHECKPOINT_PATH}'),
    (
        INPUT_IMAGE_PATH,
        f'Input image not found: {INPUT_IMAGE_PATH}\n'
        'Set INPUT_IMAGE_PATH to a real GeoTIFF file and rerun this cell.',
    ),
)
for required_path, missing_message in required_files:
    if not required_path.exists():
        raise FileNotFoundError(missing_message)

# Flag/value pairs, in the order they appear on the command line.
cli_options = (
    ('--config', CONFIG_PATH),
    ('--checkpoint', CHECKPOINT_PATH),
    ('--input', INPUT_IMAGE_PATH),
    ('--output', OUTPUT_GEOJSON),
    ('--device', DEVICE),
    ('--tile-size', TILE_SIZE),
    ('--stride', STRIDE),
    ('--min_confidence', MIN_CONFIDENCE),
)

# Use the kernel's own interpreter so the script runs in the same environment.
cmd = [sys.executable, str(PROJECT_ROOT / 'tools' / 'inference.py')]
for flag, value in cli_options:
    cmd.extend((flag, str(value)))
if SAVE_VIZ:
    cmd.append('--save_viz')

# Print a shell-quoted copy of the command so it can be pasted into a terminal.
print('Running command:')
print(' '.join(map(shlex.quote, cmd)))

# Output is captured (not streamed), so nothing is shown until the run ends.
proc = subprocess.run(
    cmd,
    cwd=str(PROJECT_ROOT),
    capture_output=True,
    text=True,
)

# Show only the last 6000 characters of each non-empty stream.
for stream_name, captured in (('stdout', proc.stdout), ('stderr', proc.stderr)):
    if captured:
        print(f'--- {stream_name} (tail) ---')
        print(captured[-6000:])

if proc.returncode != 0:
    raise RuntimeError(f'Inference failed with code {proc.returncode}')

print('Inference finished successfully.')

In [None]:
# Load the GeoJSON written by the inference cell and print a short summary.
if not OUTPUT_GEOJSON.exists():
    raise FileNotFoundError(f'GeoJSON not found: {OUTPUT_GEOJSON}')

data = json.loads(OUTPUT_GEOJSON.read_text(encoding='utf-8'))

# A missing "features" key is treated as an empty result.
features = data.get('features', [])
print(f'GeoJSON file: {OUTPUT_GEOJSON}')
print(f'Number of features: {len(features)}')

# Pretty-print the first feature's properties as a sample of the output.
if features:
    print('First feature properties:')
    first_properties = features[0].get('properties', {})
    print(json.dumps(first_properties, indent=2))

In [None]:
# Display the visualization PNG expected next to the GeoJSON, if present.
if not OUTPUT_VIZ.exists():
    print(f'Visualization not found: {OUTPUT_VIZ}')
    print('Set SAVE_VIZ=True and rerun inference cell.')
else:
    bgr = cv2.imread(str(OUTPUT_VIZ), cv2.IMREAD_COLOR)
    # A None result means the file exists but could not be read as an image.
    if bgr is None:
        raise RuntimeError(f'Could not read visualization image: {OUTPUT_VIZ}')

    # Convert from BGR to RGB channel order before handing the image to matplotlib.
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(rgb)
    ax.set_title(OUTPUT_VIZ.name)
    ax.axis('off')
    plt.show()

## Troubleshooting

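- If you see module import or registry errors, restart the kernel and rerun the notebook from the top.
- If a CUDA out-of-memory (OOM) error appears, reduce `TILE_SIZE` and/or increase `STRIDE` (keep `STRIDE` no larger than `TILE_SIZE` so that tiles still cover the whole image).
- If no features are produced, try a lower `MIN_CONFIDENCE` (for example, `0.3`).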