# DeepRoof Training Notebook (Reproducible)

Этот ноутбук намеренно сделан минимальным: без патчей site-packages и без скрытых runtime-фиксов.
Он использует тот же конфиг и тот же путь выполнения кода, что и обучение через CLI.

In [None]:
import os
import sys
from pathlib import Path

def detect_project_root() -> Path:
    """Locate the repository root by walking up from the working directory.

    A directory qualifies when it contains both a ``configs`` and a
    ``deeproof`` entry. The resolved current directory is examined first,
    then each of its parents from nearest to farthest.

    Returns:
        The first qualifying directory.

    Raises:
        FileNotFoundError: If no directory on the way up qualifies.
    """
    markers = ('configs', 'deeproof')
    here = Path.cwd().resolve()
    root = next(
        (
            folder
            for folder in (here, *here.parents)
            if all((folder / name).exists() for name in markers)
        ),
        None,
    )
    if root is None:
        raise FileNotFoundError('Could not auto-detect project root')
    return root

# Resolve the repository root once and put it at the front of sys.path
# (only if it is not listed yet) so that project packages can be imported.
project_root = detect_project_root()
if sys.path.count(str(project_root)) == 0:
    sys.path.insert(0, str(project_root))

# Report which checkout and which interpreter this kernel is using.
print(
    f'Project root: {project_root}',
    f'Python: {sys.executable}',
    sep='\n',
)


## Configuration
Поменяй только параметры ниже. Остальная логика должна совпадать с production-конфигом.

In [None]:
from mmengine.config import Config
from mmengine.runner import Runner
from mmseg.utils import register_all_modules
from deeproof.utils.runtime_compat import apply_runtime_compat

import warnings
from datetime import datetime
import torch
import deeproof.models.backbones.swin_v2_compat
import deeproof.models.deeproof_model
import deeproof.models.heads.mask2former_head
import deeproof.models.heads.geometry_head
import deeproof.models.heads.dense_normal_head
import deeproof.models.heads.edge_head
import deeproof.models.losses
import deeproof.datasets.roof_dataset
import deeproof.datasets.universal_roof_dataset
import deeproof.evaluation.metrics
import deeproof.hooks.progress_hook

# Called with init_default_scope=False; cfg.default_scope is assigned
# explicitly after the config has been loaded further down.
register_all_modules(init_default_scope=False)

# Config file loaded below; the notebook only layers overrides on top of it.
CONFIG_PATH = project_root / 'configs' / 'deeproof_production_swin_L.py'
# Each execution of this cell gets its own timestamped run directory.
WORK_DIR_BASE = project_root / 'work_dirs'
RUN_NAME = f"deeproof_notebook_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
WORK_DIR = WORK_DIR_BASE / RUN_NAME
# Written into the train/val dataset configs whose type is 'DeepRoofDataset'.
DATA_ROOT = project_root / 'data' / 'OmniCity'
# NOTE(review): FORCE_NEW_WORK_DIR is not read anywhere in the visible part of
# this notebook; WORK_DIR is always derived from the timestamped RUN_NAME.
FORCE_NEW_WORK_DIR = True
# Resume from the checkpoint recorded in WORK_DIR/last_checkpoint, if any.
AUTO_RESUME_LATEST = False
# If nothing was resumed, load weights from a best_mIoU*.pth file in WORK_DIR.
AUTO_WARMSTART_BEST = False
# Only accept checkpoints whose state dict has keys for every entry of
# REQUIRED_STATE_PREFIXES.
REQUIRE_COMPATIBLE_AUX_HEADS = True
# When the last checkpoint fails that check: load it through load_from
# (resume disabled) instead of skipping it.
FALLBACK_TO_LOAD_IF_RESUME_INCOMPATIBLE = False
# Clear model.backbone.init_cfg whenever a checkpoint was selected.
DISABLE_BACKBONE_PRETRAIN_IF_LOADING = True
# Force num_workers=0 / non-persistent workers on the train and val loaders.
SAFE_NOTEBOOK_DATALOADER = True
# Build the train dataset and fetch a few samples before training starts.
RUN_DATALOADER_PREFLIGHT = True
# Upper bound on the number of samples fetched by the preflight.
PREFLIGHT_NUM_SAMPLES = 16

# Notebook compute-safe profile: keeps production config intact but makes
# interactive runs stable/fast enough to pass iter=0 quickly.
NOTEBOOK_COMPUTE_SAFE_PROFILE = True
# Train batch size (values below 1 are raised to 1 when applied).
NOTEBOOK_BATCH_SIZE = 1
# Written to the `image_size` field of the train and val dataset configs.
NOTEBOOK_IMAGE_SIZE = (768, 768)
# decode_head.num_queries (values below 64 are raised to 64 when applied).
NOTEBOOK_NUM_QUERIES = 150
# decode_head.train_cfg.num_points (values below 1024 are raised to 1024).
NOTEBOOK_NUM_POINTS = 4096
# Switch the optimizer wrapper to 'AmpOptimWrapper' if it is not one already.
NOTEBOOK_USE_AMP = True

# Key prefixes that a checkpoint's state dict must contain to count as
# compatible when REQUIRE_COMPATIBLE_AUX_HEADS is enabled.
REQUIRED_STATE_PREFIXES = [
    'dense_geometry_head.',
    'edge_head.',
]


def _safe_torch_load(path: Path):
    """Deserialize a checkpoint file onto the CPU.

    ``weights_only=False`` is passed explicitly so the result does not depend
    on the PyTorch default for that argument. If ``torch.load`` raises
    ``TypeError`` (e.g. a build that does not accept the keyword), the call
    is repeated without it.

    NOTE: with ``weights_only=False`` the file goes through full unpickling,
    so this must only be pointed at trusted checkpoints.

    Args:
        path: Location of the checkpoint file.

    Returns:
        Whatever object ``torch.load`` produces for the file.
    """
    location = str(path)
    load_kwargs = {'map_location': 'cpu'}
    try:
        return torch.load(location, weights_only=False, **load_kwargs)
    except TypeError:
        return torch.load(location, **load_kwargs)


def _read_last_checkpoint(work_dir: Path):
    """Return the checkpoint named by ``work_dir/last_checkpoint``, if usable.

    The marker file is expected to hold a single path. A relative path is
    interpreted relative to ``work_dir`` and resolved.

    Args:
        work_dir: Run directory that may contain a ``last_checkpoint`` file.

    Returns:
        The checkpoint path when the marker exists, is non-empty and points
        at an existing file system entry; otherwise ``None``.
    """
    pointer = work_dir / 'last_checkpoint'
    recorded = pointer.read_text(encoding='utf-8').strip() if pointer.exists() else ''
    if not recorded:
        return None

    target = Path(recorded)
    if not target.is_absolute():
        target = (work_dir / target).resolve()
    if target.exists():
        return target
    return None


def _extract_state_dict(ckpt_obj):
    """Pull the parameter mapping out of a loaded checkpoint object.

    Args:
        ckpt_obj: Object returned by loading a checkpoint.

    Returns:
        ``ckpt_obj['state_dict']`` when ``ckpt_obj`` is a dict holding that
        key, ``ckpt_obj`` itself when it is a dict without it, and an empty
        dict when the input or the selected value is not a dict.
    """
    if not isinstance(ckpt_obj, dict):
        return {}
    weights = ckpt_obj.get('state_dict', ckpt_obj)
    return weights if isinstance(weights, dict) else {}


def _checkpoint_has_required_prefixes(ckpt_path: Path, prefixes):
    """Check that a checkpoint holds parameters for every required prefix.

    Args:
        ckpt_path: Checkpoint file to inspect.
        prefixes: Iterable of key prefixes; each must match the start of at
            least one state-dict key.

    Returns:
        True when every prefix is matched (vacuously True when there are no
        prefixes). False when a prefix is missing or when the checkpoint
        cannot be loaded; in the latter case a warning describes the failure.
    """
    try:
        ckpt_obj = _safe_torch_load(ckpt_path)
    except Exception as exc:
        # Still best-effort (an unreadable checkpoint is simply unusable),
        # but report the cause so it is not mistaken for a head mismatch.
        warnings.warn(
            f'Could not load checkpoint {ckpt_path} for the compatibility '
            f'check ({type(exc).__name__}: {exc}); treating it as incompatible.',
            stacklevel=2,
        )
        return False
    state_dict = _extract_state_dict(ckpt_obj)
    # Only string keys can carry a module prefix; skip anything else rather
    # than failing on `.startswith`.
    keys = tuple(k for k in state_dict if isinstance(k, str))
    return all(any(k.startswith(prefix) for k in keys) for prefix in prefixes)


# Load the configuration file and run the project's compatibility pass on it;
# every statement after this point only layers notebook overrides on `cfg`.
cfg = Config.fromfile(str(CONFIG_PATH))
apply_runtime_compat(cfg)
cfg.default_scope = 'mmseg'

# Silence IoUMetric prefix warning and keep save_best key stable.
# The evaluator may be a single dict or a list/tuple of dicts; normalise to a
# list first so one loop handles both shapes.
val_eval = cfg.get('val_evaluator', None)
if isinstance(val_eval, dict):
    _metric_cfgs = [val_eval]
elif isinstance(val_eval, (list, tuple)):
    _metric_cfgs = [m for m in val_eval if isinstance(m, dict)]
else:
    _metric_cfgs = []
for metric in _metric_cfgs:
    if metric.get('type') == 'IoUMetric' and metric.get('prefix', None) is None:
        metric['prefix'] = ''

# Create the run directory up front and make it the config's work_dir.
WORK_DIR.mkdir(parents=True, exist_ok=True)
cfg.work_dir = str(WORK_DIR)

# Explicit notebook heartbeat hook (print-based), useful when logger output is buffered.
# The same settings are either written onto an existing DeepRoofProgressHook
# entry or used to append a new one.
_progress_settings = dict(interval=10, heartbeat_sec=20, dataloader_warn_sec=90, flush=True)
cfg.custom_hooks = cfg.get('custom_hooks', [])
has_progress = False
for h in cfg.custom_hooks:
    if not (isinstance(h, dict) and h.get('type') == 'DeepRoofProgressHook'):
        continue
    for _name, _value in _progress_settings.items():
        h[_name] = _value
    has_progress = True
if not has_progress:
    cfg.custom_hooks.append(dict(type='DeepRoofProgressHook', **_progress_settings))



# Ensure frequent, visible train logs in notebook.
# The log processor is replaced outright; the logger hook is created when
# missing, otherwise only its interval / by-epoch flag are overwritten.
cfg.log_level = 'INFO'
cfg.log_processor = dict(by_epoch=False, window_size=10)
cfg.setdefault('default_hooks', {})
if cfg.default_hooks.get('logger') is None:
    cfg.default_hooks.logger = dict(type='LoggerHook', interval=10, log_metric_by_epoch=False)
else:
    cfg.default_hooks.logger.interval = 10
    cfg.default_hooks.logger.log_metric_by_epoch = False


# Notebook checkpoint policy: save early and often into this run folder.
# An existing checkpoint hook entry keeps its other fields; only the fields
# assigned below are overwritten.
cfg.setdefault('default_hooks', {})
if cfg.default_hooks.get('checkpoint') is None:
    cfg.default_hooks.checkpoint = dict(type='CheckpointHook')
cfg.default_hooks.checkpoint.by_epoch = False
cfg.default_hooks.checkpoint.interval = 500
cfg.default_hooks.checkpoint.save_best = 'mIoU'
cfg.default_hooks.checkpoint.rule = 'greater'
cfg.default_hooks.checkpoint.max_keep_ckpts = 5
cfg.default_hooks.checkpoint.save_last = True
# Cap the validation interval at 500; a missing value is read as 5000 and
# therefore also ends up as 500.
if cfg.get('train_cfg') is not None:
    cfg.train_cfg.val_interval = min(int(cfg.train_cfg.get('val_interval', 5000)), 500)

# Point the DeepRoofDataset entries of both splits at the notebook data root;
# datasets of any other type keep their configured data_root.
for _split in ('train_dataloader', 'val_dataloader'):
    if cfg.get(_split) and getattr(cfg, _split).get('dataset'):
        ds = getattr(cfg, _split).dataset
        if ds.get('type') == 'DeepRoofDataset':
            ds.data_root = str(DATA_ROOT)


if SAFE_NOTEBOOK_DATALOADER:
    # Jupyter-safe mode: avoid worker startup deadlocks/hangs on first batch.
    for _split in ('train_dataloader', 'val_dataloader'):
        if cfg.get(_split) is None:
            continue
        _loader = getattr(cfg, _split)
        _loader.num_workers = 0
        _loader.persistent_workers = False
        _loader.timeout = 0
        # Drop any configured prefetch_factor together with the workers.
        if 'prefetch_factor' in _loader:
            _loader.pop('prefetch_factor')

if NOTEBOOK_COMPUTE_SAFE_PROFILE:
    # Reduce per-iter compute so notebook training becomes responsive.
    # Batch size and image size go onto the dataloader / dataset configs.
    if cfg.get('train_dataloader') is not None:
        cfg.train_dataloader.batch_size = int(max(1, NOTEBOOK_BATCH_SIZE))
        if cfg.train_dataloader.get('dataset') is not None:
            cfg.train_dataloader.dataset.image_size = tuple(NOTEBOOK_IMAGE_SIZE)
    if cfg.get('val_dataloader') is not None and cfg.val_dataloader.get('dataset') is not None:
        cfg.val_dataloader.dataset.image_size = tuple(NOTEBOOK_IMAGE_SIZE)

    # Query / sampled-point counts have floors of 64 and 1024 respectively.
    if cfg.get('model') is not None and cfg.model.get('decode_head') is not None:
        cfg.model.decode_head.num_queries = int(max(64, NOTEBOOK_NUM_QUERIES))
        if cfg.model.decode_head.get('train_cfg') is not None:
            cfg.model.decode_head.train_cfg.num_points = int(max(1024, NOTEBOOK_NUM_POINTS))

    if NOTEBOOK_USE_AMP and cfg.get('optim_wrapper') is not None:
        ow = cfg.optim_wrapper
        if ow.get('type', 'OptimWrapper') != 'AmpOptimWrapper':
            from mmengine.device import get_device

            # AmpOptimWrapper refuses to build without an accelerator, so on a
            # CPU-only machine the configured wrapper is kept as is.
            if get_device() in ('cuda', 'npu', 'mlu', 'musa'):
                # Copy the whole wrapper config and only swap the type, so
                # settings such as `constructor` or `accumulative_counts`
                # survive the switch to mixed precision.
                amp_wrapper = dict(ow)
                amp_wrapper['type'] = 'AmpOptimWrapper'
                amp_wrapper.setdefault('loss_scale', 'dynamic')
                cfg.optim_wrapper = amp_wrapper
            else:
                print(
                    'NOTEBOOK_USE_AMP ignored: no AMP-capable accelerator detected; keeping',
                    ow.get('type', 'OptimWrapper'))

if RUN_DATALOADER_PREFLIGHT:
    # Preflight: build the train dataset from the (already overridden) config,
    # fetch the first few samples one by one, then pull one batch through a
    # plain single-process DataLoader. Timings are printed for each stage.
    import time
    from torch.utils.data import DataLoader
    from mmengine.dataset import pseudo_collate
    from mmseg.registry import DATASETS

    t0 = time.time()
    train_ds = DATASETS.build(cfg.train_dataloader.dataset)
    ds_len = len(train_ds)
    print('preflight_train_len:', ds_len)
    if ds_len <= 0:
        raise RuntimeError('Train dataset is empty. Check ann_file/data_root paths.')

    # Never index past the end of a small dataset.
    n_check = min(int(PREFLIGHT_NUM_SAMPLES), ds_len)
    t_fetch = time.time()
    first_keys = None
    max_instances_seen = 0
    for i in range(n_check):
        # A failing sample aborts the preflight and names the offending index.
        try:
            smp = train_ds[i]
        except Exception as e:
            raise RuntimeError(f'Preflight failed at sample index {i}: {e}') from e
        if first_keys is None:
            first_keys = sorted(list(smp.keys()))
        # Best-effort statistic: samples without
        # data_samples.gt_instances.labels are simply not counted.
        try:
            n_inst = int(getattr(smp['data_samples'].gt_instances, 'labels').shape[0])
            max_instances_seen = max(max_instances_seen, n_inst)
        except Exception:
            pass
    print('preflight_first_sample_keys:', first_keys)
    print('preflight_checked_samples:', n_check)
    print('preflight_max_instances_seen:', max_instances_seen)
    print(f'preflight_fetch_time_sec: {time.time() - t_fetch:.2f}')

    # DataLoader smoke test: catches first-batch stalls before runner.train().
    # Uses fixed settings (batch_size=1, no workers, pseudo_collate) rather
    # than the dataloader section of the config.
    t_dl = time.time()
    smoke_loader = DataLoader(
        train_ds,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        collate_fn=pseudo_collate,
        pin_memory=False,
    )
    try:
        first_batch = next(iter(smoke_loader))
    except Exception as e:
        raise RuntimeError(f'Dataloader preflight failed on first batch: {e}') from e
    print('preflight_first_batch_type:', type(first_batch).__name__)
    print(f'preflight_first_batch_time_sec: {time.time() - t_dl:.2f}')
    print(f'preflight_total_time_sec: {time.time() - t0:.2f}')

# Checkpoint selection. The outcome is recorded in three names:
#   selected_mode    'fresh', 'resume_latest', 'warmstart_last_incompatible'
#                    or 'warmstart_best'
#   selected_ckpt    path of the chosen checkpoint, or None
#   selection_reason human-readable explanation of the decision
# NOTE(review): both lookups search WORK_DIR, which this notebook derives from
# a timestamped RUN_NAME; unless WORK_DIR is pointed at an existing run
# directory there is nothing to find and the mode stays 'fresh'.
selected_mode = 'fresh'
selected_ckpt = None
selection_reason = 'no checkpoint selected'

if AUTO_RESUME_LATEST:
    last_ckpt = _read_last_checkpoint(WORK_DIR)
    if last_ckpt is not None:
        compatible = (not REQUIRE_COMPATIBLE_AUX_HEADS) or _checkpoint_has_required_prefixes(last_ckpt, REQUIRED_STATE_PREFIXES)
        if compatible:
            # Full resume: load_from stays unset and cfg.resume is enabled.
            cfg.resume = True
            cfg.load_from = None
            selected_mode = 'resume_latest'
            selected_ckpt = last_ckpt
            selection_reason = 'last checkpoint is architecture-compatible; full resume enabled'
        elif FALLBACK_TO_LOAD_IF_RESUME_INCOMPATIBLE:
            # Incompatible checkpoint: hand it to load_from with resume off.
            cfg.resume = False
            cfg.load_from = str(last_ckpt)
            selected_mode = 'warmstart_last_incompatible'
            selected_ckpt = last_ckpt
            selection_reason = 'last checkpoint missing new heads; switched to weights-only load_from'
        else:
            selection_reason = 'last checkpoint is incompatible with current model heads; checkpoint load skipped (fresh start with configured pretrain only)'

if selected_mode == 'fresh' and AUTO_WARMSTART_BEST:
    # Newest file first. Sorting by name would rank e.g. '..._iter_500.pth'
    # above '..._iter_1000.pth', so order by modification time instead (the
    # name only breaks ties).
    best_ckpts = sorted(
        WORK_DIR.glob('best_mIoU*.pth'),
        key=lambda p: (p.stat().st_mtime, p.name),
        reverse=True,
    )
    for ckpt in best_ckpts:
        compatible = (not REQUIRE_COMPATIBLE_AUX_HEADS) or _checkpoint_has_required_prefixes(ckpt, REQUIRED_STATE_PREFIXES)
        if compatible:
            cfg.load_from = str(ckpt)
            cfg.resume = False
            selected_mode = 'warmstart_best'
            selected_ckpt = ckpt
            selection_reason = 'compatible best checkpoint selected for weights-only warmstart'
            break

# Whenever a checkpoint was selected, clear the backbone's init_cfg.
if DISABLE_BACKBONE_PRETRAIN_IF_LOADING and (selected_mode in ('resume_latest', 'warmstart_best', 'warmstart_last_incompatible')):
    if cfg.get('model') and cfg.model.get('backbone'):
        cfg.model.backbone.init_cfg = None

# Fill in default loop configs when a dataloader and evaluator exist for a
# split but no loop config was provided for it.
if cfg.get('val_dataloader') is not None and cfg.get('val_evaluator') is not None and cfg.get('val_cfg') is None:
    cfg.val_cfg = dict(type='ValLoop')
if cfg.get('test_dataloader') is not None and cfg.get('test_evaluator') is not None and cfg.get('test_cfg') is None:
    cfg.test_cfg = dict(type='TestLoop')

# Summary of the effective configuration for this run.
# NOTE(review): several lines below dereference train_dataloader, train_cfg,
# model.decode_head(.train_cfg) and optim_wrapper without the `is not None`
# guards used when those sections were overridden earlier.
print('Config loaded')
print('work_dir:', cfg.work_dir)
print('checkpoint_mode:', selected_mode)
print('checkpoint_path:', str(selected_ckpt) if selected_ckpt else 'None')
print('checkpoint_reason:', selection_reason)
print('resume:', cfg.get('resume', False))
print('load_from:', cfg.get('load_from', 'None'))
print('batch_size:', cfg.train_dataloader.batch_size)
print('max_iters:', cfg.train_cfg.get('max_iters', 'N/A'))

print('logger_interval:', cfg.default_hooks.get('logger', {}).get('interval', 'N/A'))
print('log_by_epoch:', cfg.log_processor.get('by_epoch', 'N/A'))

print('custom_hooks:', cfg.get('custom_hooks', []))
print('train_num_workers:', cfg.train_dataloader.get('num_workers', 'N/A'))
print('train_persistent_workers:', cfg.train_dataloader.get('persistent_workers', 'N/A'))
print('run_name:', RUN_NAME)
print('checkpoint_interval:', cfg.default_hooks.get('checkpoint', {}).get('interval', 'N/A'))
print('val_interval:', cfg.train_cfg.get('val_interval', 'N/A'))
print('compute_safe_profile:', NOTEBOOK_COMPUTE_SAFE_PROFILE)
print('batch_size_effective:', cfg.train_dataloader.get('batch_size', 'N/A'))
print('image_size_effective:', cfg.train_dataloader.dataset.get('image_size', 'N/A'))
print('num_queries_effective:', cfg.model.decode_head.get('num_queries', 'N/A'))
print('num_points_effective:', cfg.model.decode_head.train_cfg.get('num_points', 'N/A'))
print('optim_wrapper_type:', cfg.optim_wrapper.get('type', 'N/A'))


## Optional Dataset Preview

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt

def preview_dataset(num_samples=2):
    """Show image, instance mask and normals for the first training samples.

    Sample ids are read from the train dataset's ``ann_file`` (one id per
    non-empty line; a relative path is taken relative to ``data_root``). For
    every id one row is drawn from ``images/<id>.jpg``, ``masks/<id>.png``
    and, when present, ``normals/<id>.npy`` below ``data_root``.

    A missing annotation file or an empty id list prints a message and
    returns. An image or mask that cannot be read is replaced by a text
    placeholder instead of aborting the whole preview.

    Args:
        num_samples: Number of leading ids to preview.
    """
    def _placeholder(ax, message):
        # Text stand-in for a panel whose file is missing or unreadable.
        ax.text(0.5, 0.5, message, ha='center', va='center')

    ds_cfg = cfg.train_dataloader.dataset
    train_txt = Path(ds_cfg.ann_file)
    if not train_txt.is_absolute():
        train_txt = Path(ds_cfg.data_root) / train_txt
    if not train_txt.exists():
        print('train.txt not found:', train_txt)
        return

    lines = train_txt.read_text(encoding='utf-8').splitlines()
    sample_ids = [x.strip() for x in lines if x.strip()][:num_samples]
    if not sample_ids:
        # plt.subplots() cannot create a grid with zero rows.
        print('No sample ids to preview in:', train_txt)
        return
    data_root = Path(ds_cfg.data_root)

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    _, axes = plt.subplots(
        len(sample_ids), 3, figsize=(15, 5 * len(sample_ids)), squeeze=False)

    for row, sid in zip(axes, sample_ids):
        img_ax, mask_ax, normal_ax = row

        # cv2.imread returns None (it does not raise) for unreadable files.
        bgr = cv2.imread(str(data_root / 'images' / f'{sid}.jpg'))
        if bgr is None:
            _placeholder(img_ax, 'image not readable')
        else:
            img_ax.imshow(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        img_ax.set_title(f'image: {sid}')
        img_ax.axis('off')

        mask = cv2.imread(str(data_root / 'masks' / f'{sid}.png'), cv2.IMREAD_UNCHANGED)
        if mask is None:
            _placeholder(mask_ax, 'mask not readable')
        else:
            mask_ax.imshow(mask, cmap='tab20')
        mask_ax.set_title('instance mask')
        mask_ax.axis('off')

        norm_p = data_root / 'normals' / f'{sid}.npy'
        if norm_p.exists():
            n = np.load(str(norm_p))
            # Affine map (n + 1) * 127.5, clipped to 0..255 for display;
            # presumably the stored values lie in [-1, 1] (TODO confirm).
            n_vis = ((n + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
            normal_ax.imshow(n_vis)
        else:
            _placeholder(normal_ax, 'no normal file')
        normal_ax.set_title('normals')
        normal_ax.axis('off')

    plt.tight_layout()
    plt.show()

# Quick visual sanity check on the first two ids listed in the train split.
preview_dataset(num_samples=2)

## Train

In [None]:
import torch
from datetime import datetime
from pathlib import Path

# Report the compute device: the name of CUDA device 0, or 'CPU' when CUDA is
# not available.
if torch.cuda.is_available():
    device_name = torch.cuda.get_device_name(0)
else:
    device_name = 'CPU'
print('Device:', device_name)
print('Training work_dir:', cfg.work_dir)

# Make sure the run directory exists and leave a start-time marker in it
# before the runner is built.
_run_dir = Path(cfg.work_dir)
_run_dir.mkdir(parents=True, exist_ok=True)
_started_marker = _run_dir / 'RUN_STARTED.txt'
_started_marker.write_text(
    f'started_at={datetime.now().isoformat()}\n', encoding='utf-8')

# Build the runner from the assembled config and start training.
runner = Runner.from_cfg(cfg)
runner.train()


## Plot Training Curves

In [None]:
import glob
import json
import os
import matplotlib.pyplot as plt

def plot_training_logs(work_dir):
    """Plot training losses and validation mIoU from a run's scalar log.

    Looks for ``vis_data/scalars.json`` directly in ``work_dir`` and in its
    immediate sub-directories, and reads the path that sorts last. The file
    is parsed line by line as JSON; lines that are blank, malformed or not a
    JSON object are skipped.

    * Train record: has a ``loss`` key and its ``mode`` is not ``'val'``.
      Every numeric key starting with ``loss`` becomes one curve.
    * Validation record: has an ``mIoU`` key and either ``mode == 'val'`` or
      no ``mode`` at all together with no ``loss`` key (covers logs that do
      not tag records with a mode).

    The x value of a point is the record's ``iter``, else ``step``, else 0.

    Args:
        work_dir: Run directory (string or path-like).

    Returns:
        None. Prints a message and returns early when no log file is found.
    """
    candidates = (
        glob.glob(os.path.join(work_dir, '*/vis_data/scalars.json'))
        + glob.glob(os.path.join(work_dir, 'vis_data/scalars.json'))
    )
    if not candidates:
        print('No scalar logs found.')
        return
    log_path = sorted(candidates)[-1]
    print('Reading:', log_path)

    # Each loss keeps its own x values, so a loss that is absent from some
    # records is still plotted against the iterations it was logged at.
    loss_points = {}  # loss name -> ([iterations], [values])
    val_iter = []
    val_miou = []

    with open(log_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                d = json.loads(line)
            except ValueError:  # json.JSONDecodeError is a ValueError
                continue
            if not isinstance(d, dict):
                continue

            it = d.get('iter', d.get('step', 0))
            mode = d.get('mode')

            if 'loss' in d and mode != 'val':
                for k, v in d.items():
                    if k.startswith('loss') and isinstance(v, (int, float)):
                        xs, ys = loss_points.setdefault(k, ([], []))
                        xs.append(it)
                        ys.append(float(v))

            is_val_record = mode == 'val' or (mode is None and 'loss' not in d)
            if is_val_record and isinstance(d.get('mIoU'), (int, float)):
                val_iter.append(it)
                val_miou.append(float(d['mIoU']))

    _, axes = plt.subplots(1, 2, figsize=(16, 5))
    if loss_points:
        for k in sorted(loss_points):
            xs, ys = loss_points[k]
            axes[0].plot(xs, ys, label=k, linewidth=1.2)
        axes[0].set_title('Train losses')
        axes[0].set_xlabel('iter')
        axes[0].grid(alpha=0.3)
        axes[0].legend(fontsize=8)

    if val_iter:
        axes[1].plot(val_iter, val_miou, marker='o', linewidth=2)
        axes[1].set_title('Validation mIoU')
        axes[1].set_xlabel('iter')
        axes[1].grid(alpha=0.3)

    plt.tight_layout()
    plt.show()

# Plot the curves logged under the work_dir of the current config.
plot_training_logs(cfg.work_dir)