# 🏠 DeepRoof-2026: Multi-Task Training Notebook

Welcome to the official training environment for the **DeepRoof-2026 AI Roof Layout Engine**. 

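### 🛠 Step 1: Initialize & Fix Environment
This cell resolves path issues and **binary version conflicts** (MMCV/Torch/MMSeg) automatically. If it has to reinstall or patch anything, it asks you to restart the kernel; restart, then run this cell again.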

In [None]:
import os
import sys
import subprocess
import re
from pathlib import Path

# --- 1. SET UP PATHS. ---
# project_root is the parent of the working directory; presumably the notebook
# is launched from a subfolder of the repository (e.g. notebooks/) — confirm.
project_root = os.fspath(Path.cwd().parent)
if project_root not in sys.path:
    # Prepend so repository modules take precedence over installed copies.
    sys.path.insert(0, project_root)
    print(f"‚úÖ Added {project_root} to sys.path")

# --- 2. ROBUST VERSION FIXER ---
def get_pkg_version(package_name):
    """Return the installed version of *package_name* as reported by ``pip show``.

    Args:
        package_name: Distribution name to query (e.g. ``"mmcv"``).

    Returns:
        The version string, or None if the package is not installed, pip could
        not be run, or pip's output has no ``Version:`` field.
    """
    try:
        output = subprocess.check_output(
            [sys.executable, "-m", "pip", "show", package_name],
            stderr=subprocess.DEVNULL,
        ).decode(errors="replace")  # tolerate non-UTF-8 console encodings
    except (subprocess.CalledProcessError, OSError):
        # pip exits non-zero for unknown packages; OSError covers a missing
        # interpreter. KeyboardInterrupt is deliberately not caught.
        return None
    for line in output.splitlines():
        if line.startswith('Version: '):
            return line.split(': ', 1)[1].strip()
    return None

def _locate_mmseg_init():
    """Return the path of the installed ``mmseg/__init__.py``, or ``""`` if unknown.

    The install location is read from the ``Location:`` field of
    ``pip show mmsegmentation``.
    """
    try:
        output = subprocess.check_output(
            [sys.executable, "-m", "pip", "show", "mmsegmentation"],
            stderr=subprocess.DEVNULL,
        ).decode(errors="replace")
    except (subprocess.CalledProcessError, OSError):
        # Not installed (pip exits non-zero) or pip unavailable: nothing to patch.
        return ""
    for line in output.splitlines():
        if line.startswith('Location: '):
            return os.path.join(line.split(': ', 1)[1].strip(), "mmseg", "__init__.py")
    return ""


def _strip_mmcv_version_assert(init_path):
    """Remove mmseg's import-time MMCV version assertion from *init_path*.

    Returns:
        True if the file was rewritten, False if no assertion was found.
    """
    with open(init_path, 'r', encoding='utf-8') as f:
        content = f.read()

    # Detection only: the ``assert (mmcv_min_version ...)`` header sits on one line.
    if not re.search(r"assert \(mmcv_min_version.*?\)", content):
        return False

    print("ü©π Removing corrupted assertions from mmsegmentation...")
    # Any line containing one of these markers is dropped. The first three
    # match the assert and its f-string continuation lines; the last two match
    # artifacts of earlier patch attempts.
    # NOTE(review): "if False:" drops *any* such line in the file — confirm
    # mmseg's __init__ never contains one legitimately.
    drop_markers = (
        "assert (mmcv_min_version",
        "is used but incompatible",
        "Please install mmcv",
        "Patched by DeepRoof",
        "if False:",
    )
    kept_lines = [
        line for line in content.splitlines()
        if not any(marker in line for marker in drop_markers)
    ]
    new_content = "\n".join(kept_lines)
    if content.endswith("\n"):
        # splitlines() discards the final newline; keep the file's original ending.
        new_content += "\n"

    # Also lift the upper bound, presumably so any remaining comparison
    # against mmcv_max_version accepts newer MMCV releases.
    if "mmcv_max_version = digit_version('9.9.9')" not in new_content:
        new_content = new_content.replace(
            "mmcv_max_version = digit_version(MMCV_MAX)",
            "mmcv_max_version = digit_version('9.9.9') # Override",
        )

    with open(init_path, 'w', encoding='utf-8') as f:
        f.write(new_content)
    return True


def _install_missing(packages):
    """pip-install each package in *packages* that is not installed; return True if any was."""
    installed_any = False
    for pkg in packages:
        if not get_pkg_version(pkg):
            print(f"üì¶ Installing {pkg}...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", pkg])
            installed_any = True
    return installed_any


def setup_environment():
    """Check and repair the MMCV / MMSegmentation / helper-package installation.

    Steps:
      1. Reinstall ``mmcv==2.2.0`` when a different version (or none) is found.
      2. Strip mmsegmentation's MMCV version assertion from its ``__init__.py``.
      3. pip-install any missing helper packages.

    Returns:
        True if nothing had to change. False if anything was installed or
        patched; the kernel must then be restarted before importing these
        libraries.

    Raises:
        subprocess.CalledProcessError: if a ``pip install`` fails.
    """
    print("üîç Checking Environment Health...")
    requires_restart = False

    # 1. Pin MMCV 2.2.0 (the version this notebook pairs with the installed Torch).
    mmcv_ver = get_pkg_version("mmcv")
    if mmcv_ver != "2.2.0":
        print(f"‚ö†Ô∏è Found MMCV {mmcv_ver}. Re-installing compatible MMCV 2.2.0...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "mmcv==2.2.0"])
        requires_restart = True

    # 2. mmseg's __init__ asserts an MMCV version range at import time; patch it out.
    mmseg_init = _locate_mmseg_init()
    if mmseg_init and os.path.exists(mmseg_init):
        if _strip_mmcv_version_assert(mmseg_init):
            requires_restart = True

    # 3. Other runtime dependencies.
    if _install_missing(["ftfy", "regex", "rasterio", "geopandas", "albumentations"]):
        requires_restart = True

    if requires_restart:
        print("\n" + "!"*50)
        print("CRITICAL: Environment fixed! PLEASE RESTART THE KERNEL NOW.")
        print("!"*50)
        return False

    print("‚úÖ Environment is HEALTHY and COMPATIBLE.")
    return True

if setup_environment():
    # Heavy framework imports only happen once the environment check reports
    # that nothing was changed (i.e. no kernel restart is pending).
    import torch
    from mmengine.config import Config
    from mmengine.runner import Runner

    cuda_ok = torch.cuda.is_available()
    device_label = torch.cuda.get_device_name(0) if cuda_ok else 'CPU'
    print(f"üöÄ CUDA Ready: {cuda_ok} | Device: {device_label}")

## 📂 1. Dataset Preview

Visualize the **satellite imagery**, **instance masks**, and **surface normals**.

In [None]:
def preview_dataset(data_root, num_samples=3):
    """Plot image, instance mask and surface normals for the first training samples.

    Sample IDs are read from ``<data_root>/train.txt`` (blank lines skipped).
    For each ID this loads ``images/<id>.jpg``, ``masks/<id>.png`` and, if
    present, ``normals/<id>.npy``, then draws them side by side in one row.

    Args:
        data_root: Dataset directory. A relative path is resolved against the
            notebook-level ``project_root``.
        num_samples: Maximum number of rows to draw; values <= 0 draw nothing.

    Returns:
        None. Prints a message and returns early if ``train.txt`` is missing
        or yields no sample IDs.
    """
    from itertools import islice

    data_path = Path(data_root)
    if not data_path.is_absolute():
        data_path = Path(project_root) / data_root

    train_file = data_path / 'train.txt'
    if not train_file.exists():
        print(f"‚ùå Could not find train.txt at {train_file}. Run prepare_omnicity_v2_final.py first!")
        return

    with open(train_file, 'r', encoding='utf-8') as f:
        stripped = (line.strip() for line in f)
        sample_ids = list(islice((sid for sid in stripped if sid), max(num_samples, 0)))

    if not sample_ids:
        print(f"No sample IDs to preview in {train_file}.")
        return

    # Plotting libraries are only imported once there is something to draw.
    import cv2
    import matplotlib.pyplot as plt
    import numpy as np

    n_rows = len(sample_ids)
    # squeeze=False keeps `axes` 2-D so axes[i] works even for a single row.
    fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)
    for i, sid in enumerate(sample_ids):
        img_ax, mask_ax, norm_ax = axes[i]

        # cv2.imread returns None (rather than raising) for missing/unreadable files.
        img = cv2.imread(str(data_path / 'images' / (sid + '.jpg')))
        if img is not None:
            img_ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        img_ax.set_title(sid if img is not None else f"{sid} (image missing)")
        img_ax.axis('off')

        mask = cv2.imread(str(data_path / 'masks' / (sid + '.png')), cv2.IMREAD_UNCHANGED)
        if mask is not None:
            # Wrap mask values into 20 bands (0..228) so any instance ID fits the 8-bit colormap.
            mask_vis = cv2.applyColorMap(((mask % 20) * 12).astype(np.uint8), cv2.COLORMAP_JET)
            # applyColorMap returns BGR; matplotlib expects RGB.
            mask_ax.imshow(cv2.cvtColor(mask_vis, cv2.COLOR_BGR2RGB))
        mask_ax.set_title("Mask" if mask is not None else "Mask (missing)")
        mask_ax.axis('off')

        norm_path = data_path / 'normals' / (sid + '.npy')
        if norm_path.exists():
            normals = np.load(str(norm_path))
            # Map [-1, 1] to [0, 255]; clip first so out-of-range values don't wrap in uint8.
            # NOTE(review): assumes channel-last (H, W, 3) layout — confirm against data prep.
            norm_ax.imshow(((np.clip(normals, -1.0, 1.0) + 1) * 127.5).astype(np.uint8))
        norm_ax.set_title("Normals")
        norm_ax.axis('off')

    plt.tight_layout()
    plt.show()

# Sanity-check the prepared dataset by rendering its first two training samples.
preview_dataset(data_root="data/OmniCity", num_samples=2)

## ⚙️ 2. Training Configuration


In [None]:
from mmengine.config import Config

# "fine-tune" keeps whatever checkpoint the config file sets in load_from;
# "scratch" clears load_from so no checkpoint is loaded.
MODE = "fine-tune"
VALID_MODES = ("fine-tune", "scratch")
CONFIG_FILE = str(Path(project_root) / "configs/deeproof_finetune_swin_L.py")
WORK_DIR = str(Path(project_root) / "work_dirs/swin_l_omnicity_v2")
MAX_ITERS = 20000

# Fail fast on a typo instead of silently falling through to fine-tuning.
if MODE not in VALID_MODES:
    raise ValueError(f"Unknown MODE {MODE!r}; expected one of {VALID_MODES}.")

cfg = Config.fromfile(CONFIG_FILE)
cfg.work_dir = WORK_DIR
# Point the train and val dataloaders at the local dataset copy.
cfg.data_root = str(Path(project_root) / "data/OmniCity/")
cfg.train_dataloader.dataset.data_root = cfg.data_root
cfg.val_dataloader.dataset.data_root = cfg.data_root
cfg.train_cfg.max_iters = MAX_ITERS

if MODE == "scratch":
    cfg.load_from = None
print(f"‚úÖ Configuration Validated. WorkDir: {WORK_DIR}")

## 🚀 3. Start Training


In [None]:
import torch
from mmengine.runner import Runner

# Report the accelerator before handing control to MMEngine.
device_label = 'CPU'
if torch.cuda.is_available():
    device_label = torch.cuda.get_device_name(0)
print(f"üöÄ Starting Trainer on: {device_label}")

# Build the runner from the prepared config and launch the training loop.
runner = Runner.from_cfg(cfg)
runner.train()