# 🏠 DeepRoof-2026: Multi-Task Training Notebook

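### 🛠 Step 1: Final Environment Repair
This cell patches the **mmsegmentation** version-assertion error and puts the project's virtual environment on `sys.path`. If it reports that a patch or install was applied, restart the kernel and run it again.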

In [None]:
import os
import sys
import subprocess
import torch
import ctypes
import glob
from pathlib import Path

# Replacement body for mmseg/__init__.py: identical version bookkeeping, but the
# upper bounds are lifted and the import-time version assertions are removed.
_CLEAN_MMSEG_INIT = """# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import mmengine
from mmengine.utils import digit_version
from .version import __version__, version_info
MMCV_MIN = '2.0.0rc4'
MMCV_MAX = '2.2.0'
MMENGINE_MIN = '0.7.1'
MMENGINE_MAX = '1.0.0'
mmcv_min_version = digit_version(MMCV_MIN)
mmcv_max_version = digit_version('9.9.9') # OVERRIDE by DeepRoof
mmcv_version = digit_version(mmcv.__version__)
mmengine_min_version = digit_version(MMENGINE_MIN)
mmengine_max_version = digit_version('9.9.9') # OVERRIDE by DeepRoof
mmengine_version = digit_version(mmengine.__version__)
__all__ = ['__version__', 'version_info', 'digit_version']\n"""


def final_repair():
    """Prepare the notebook kernel for training.

    Steps, in order:
      1. Put the project root and the venv's site-packages on ``sys.path`` and
         point ``sys.executable`` at the venv interpreter (when it exists).
      2. Overwrite ``mmseg/__init__.py`` with a version that has no version
         assertions, if the installed file still needs it.
      3. When CUDA is available, try importing ``mmcv.ops`` and install the
         CUDA runtime wheel if ``libcudart.so`` cannot be loaded.
      4. Install any of the auxiliary packages that ``pip show`` cannot find.

    Returns:
        bool: ``True`` when nothing had to be patched or installed. ``False``
        when something was changed and the kernel should be restarted (or the
        cell rerun) before continuing.

    Raises:
        subprocess.CalledProcessError: if a ``pip install`` command fails.
    """
    # Local import so this block does not depend on the file's import header.
    import importlib.util

    print("üõ∞ Initializing DeepRepair...")

    # 1. Project Root & Venv Activation
    project_root = Path("/workspace/roof")
    if not project_root.exists():
        project_root = Path(os.getcwd()).parent

    venv_path = project_root / "venv"
    if not venv_path.exists():
        venv_path = project_root / ".venv"

    if venv_path.exists():
        print(f"üêç Environment: {venv_path}")
        # sorted() makes the choice deterministic if several python* dirs exist.
        lib_dirs = sorted(venv_path.glob("lib/python*/site-packages"))
        if lib_dirs:
            site_packages = str(lib_dirs[0])
            if site_packages not in sys.path:
                sys.path.insert(0, site_packages)
            # Only redirect pip calls to the venv interpreter if it is really
            # there; otherwise every later subprocess call would fail.
            venv_python = venv_path / "bin" / "python"
            if venv_python.exists():
                sys.executable = str(venv_python)

    if str(project_root) not in sys.path:
        sys.path.insert(0, str(project_root))

    # 2. Repair of mmseg/__init__.py
    # The package is located WITHOUT importing it: importing an unpatched
    # mmseg can trip the very version assertion this step is meant to remove,
    # which would abort the repair before the file is ever touched.
    try:
        spec = importlib.util.find_spec("mmseg")
        if spec is None or not spec.origin:
            raise ModuleNotFoundError("No module named 'mmseg'")
        init_file = Path(spec.origin).parent / "__init__.py"
        print(f"ü©π Force-repairing: {init_file}")

        # Rewrite only when the file still asserts or lacks our marker line,
        # so rerunning the cell does not ask for a restart forever.
        with open(init_file, 'r', encoding='utf-8') as f:
            current = f.read()
        if "assert (mmcv_min_version" in current or "MMCV_MAX = '2.2.0'" not in current:
            with open(init_file, 'w', encoding='utf-8') as f:
                f.write(_CLEAN_MMSEG_INIT)
            print("‚úÖ Patch applied. RESTART KERNEL.")
            return False
    except (ImportError, OSError, ValueError) as e:
        # Best effort: mmseg may simply not be installed yet.
        print(f"‚ö†Ô∏è mmseg repair failed (might not be installed yet): {e}")

    # 3. CUDA Linker Repair
    if torch.cuda.is_available():
        try:
            print("üîç Testing CUDA Linker...")
            from mmcv.ops import point_sample  # noqa: F401  (import is the test)
        except ImportError as e:
            if "libcudart.so" in str(e):
                print("üì¶ Installing CUDA Runtime libs...")
                subprocess.check_call([sys.executable, "-m", "pip", "install", "nvidia-cuda-runtime-cu11"])
                return False
            print(f"CUDA linker test skipped (mmcv.ops not importable): {e}")
        except Exception as e:
            # Still best effort, but no longer silent.
            print(f"CUDA linker test raised {type(e).__name__}: {e}")

    # 4. Dependency Check
    # Install everything that is missing in one pass, then request one restart.
    installed_any = False
    for pkg in ["ftfy", "regex", "rasterio", "geopandas"]:
        try:
            subprocess.check_output([sys.executable, "-m", "pip", "show", pkg], stderr=subprocess.DEVNULL)
        except subprocess.CalledProcessError:
            # `pip show` exits non-zero when the package is not installed.
            print(f"üì¶ Installing missing: {pkg}")
            subprocess.check_call([sys.executable, "-m", "pip", "install", pkg])
            installed_any = True
    if installed_any:
        return False

    print("‚úÖ System Ready.")
    return True

# Run the repair pass; announce the active device only when it reports ready.
if final_repair():
    if torch.cuda.is_available():
        print(f"üöÄ Live | {torch.cuda.get_device_name(0)}")
    else:
        print("üöÄ Live | CPU")

## 📂 1. Dataset Preview

Visualize the **satellite imagery**, **instance masks**, and **surface normals**.

In [None]:
def preview_dataset(data_root, num_samples=3):
    """Plot image, instance mask and surface normals for the first samples.

    Sample ids are read from ``<project_root>/<data_root>/OmniCity/train.txt``;
    each id is looked up as ``images/<id>.jpg``, ``masks/<id>.png`` and
    ``normals/<id>.npy``. One figure row is drawn per sample id.

    Args:
        data_root: Path, relative to the project root, of the directory that
            contains the ``OmniCity`` folder.
        num_samples: Maximum number of sample ids to show.

    Returns:
        None. Prints a message and returns early when the project root or
        ``train.txt`` cannot be found, or when no sample ids are listed.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import cv2

    # Resolve project root from sys.path (first entry mentioning "roof").
    roof_entries = [p for p in sys.path if "roof" in p]
    if not roof_entries:
        print("Project root not found on sys.path; run the Step 1 repair cell first.")
        return
    project_root = Path(roof_entries[0])
    data_path = project_root / data_root / "OmniCity"

    train_file = data_path / 'train.txt'
    if not train_file.exists():
        print(f"‚ùå Could not find train.txt at {train_file}.")
        return

    with open(train_file, 'r', encoding='utf-8') as f:
        # Skip blank lines so they do not turn into empty sample ids.
        sample_ids = [sid for sid in map(str.strip, f) if sid][:num_samples]
    if not sample_ids:
        print(f"No sample ids listed in {train_file}.")
        return

    # Size the grid from what was actually read; squeeze=False keeps `axes`
    # two-dimensional even when there is only a single row.
    n_rows = len(sample_ids)
    fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)
    for i, sid in enumerate(sample_ids):
        for ax in axes[i]:
            ax.axis('off')

        # cv2.imread returns None (no exception) for a missing/unreadable file.
        img_path = str(data_path / 'images' / (sid + '.jpg'))
        img_bgr = cv2.imread(img_path)
        if img_bgr is not None:
            axes[i, 0].imshow(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
        axes[i, 0].set_title(sid)

        mask_path = str(data_path / 'masks' / (sid + '.png'))
        mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        if mask is not None:
            # Fold instance ids into 20 buckets spread over the 0-255 range.
            mask_vis = cv2.applyColorMap(((mask % 20) * 12).astype(np.uint8), cv2.COLORMAP_JET)
            # applyColorMap yields BGR; matplotlib expects RGB.
            axes[i, 1].imshow(cv2.cvtColor(mask_vis, cv2.COLOR_BGR2RGB))
        axes[i, 1].set_title("Mask")

        norm_path = data_path / 'normals' / (sid + '.npy')
        if norm_path.exists():
            normals = np.load(str(norm_path))
            # NOTE(review): assumes components lie in [-1, 1] — TODO confirm.
            axes[i, 2].imshow(((normals + 1) * 127.5).astype(np.uint8))
        axes[i, 2].set_title("Normals")

    plt.tight_layout()
    plt.show()

preview_dataset("data", num_samples=2)

## ⚙️ 2. Training Configuration


In [None]:
from mmengine.config import Config

# Project checkout = first sys.path entry whose text contains "roof".
project_root = Path(list(filter(lambda entry: "roof" in entry, sys.path))[0])

# Output directory for this run and the fine-tuning config that drives it.
WORK_DIR = str(project_root / "work_dirs/swin_l_omnicity_v2")
CONFIG_FILE = str(project_root / "configs/deeproof_finetune_swin_L.py")

# Load the config, then override paths and schedule length for this notebook.
cfg = Config.fromfile(CONFIG_FILE)
cfg.work_dir = WORK_DIR
cfg.data_root = str(project_root / "data/OmniCity/")
# Both dataloaders read from the same dataset directory.
cfg.train_dataloader.dataset.data_root = cfg.data_root
cfg.val_dataloader.dataset.data_root = cfg.data_root
cfg.train_cfg.max_iters = 20000

print(f"‚úÖ Configuration Validated. WorkDir: {WORK_DIR}")

## 🚀 3. Start Training


In [None]:
import torch
from mmengine.runner import Runner

# Report which device the run will use before handing control to the runner.
if torch.cuda.is_available():
    print(f"üöÄ Starting Trainer on: {torch.cuda.get_device_name(0)}")
else:
    print("üöÄ Starting Trainer on: CPU")

# Build the runner from the config prepared in the previous cell and train.
runner = Runner.from_cfg(cfg)
runner.train()