# 🏠 DeepRoof-2026: Multi-Task Training Notebook

Welcome to the official training environment for the **DeepRoof-2026 AI Roof Layout Engine**. 

This notebook allows you to:
1. **Visualize** the OmniCity dataset labels (Instance Masks + Surface Normals).
2. **Configure** training parameters for either **Scratch Training** or **Fine-Tuning**.
3. **Launch** the high-performance training loop optimized for A100 GPUs.
4. **Evaluate** and visualize model predictions on new satellite imagery.

In [None]:
import os
import sys
import subprocess
import torch
from pathlib import Path

# --- 🛠 STEP 1: SOLVE PATHS & VENV ---
# Assumes the notebook runs from a direct sub-directory of the project root,
# so the repo root is the parent of the current working directory.
project_root = str(Path(os.getcwd()).parent)
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# 1a. Expose the project's venv packages (if present) by adding its site-packages
# to sys.path. Use the running interpreter's version in the path so we never load
# extension modules that were built for a different Python.
py_tag = f"python{sys.version_info.major}.{sys.version_info.minor}"
venv_site = os.path.join(project_root, ".venv", "lib", py_tag, "site-packages")
if os.path.isdir(venv_site):
    # Guard against duplicate entries when the cell is re-run.
    if venv_site not in sys.path:
        sys.path.insert(1, venv_site)
    print(f"üêç Using venv at: {venv_site}")

# --- 📦 STEP 2: AUTO-INSTALL MISSING DEPENDENCIES ---
def install_if_missing(package, import_name=None):
    """Pip-install *package* into the running interpreter unless it is already importable.

    Args:
        package: pip requirement to install (e.g. "rasterio").
        import_name: Module name used to probe for the package; defaults to
            *package*, for distributions whose import name differs.

    Raises:
        subprocess.CalledProcessError: If the pip install fails.
    """
    import importlib

    import_name = import_name or package
    try:
        importlib.import_module(import_name)
    except ImportError:
        print(f"üì¶ Installing {package}... (This may take a minute)")
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])
        # Packages installed while the interpreter is running may be invisible to
        # the import system's cached directory listings until these are invalidated.
        importlib.invalidate_caches()
    else:
        print(f"‚úÖ {import_name} is already installed.")

# OpenMMLab stack. Follows the MMSegmentation install order:
# openmim -> mmengine -> mmcv -> mmsegmentation (mmseg imports mmcv when imported).
import importlib

try:
    import mmengine
    print("‚úÖ mmengine found.")
except ImportError:
    print("Installing mmengine using MIM...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "openmim"])
    subprocess.check_call([sys.executable, "-m", "mim", "install", "mmengine"])

try:
    import mmseg
    print("‚úÖ mmsegmentation found.")
except ImportError:
    print("üì¶ Installing mmsegmentation using MIM...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "openmim"])
    # Install mmcv before mmsegmentation, as in the official install guide.
    subprocess.check_call([sys.executable, "-m", "mim", "install", "mmcv>=2.0.0"])
    subprocess.check_call([sys.executable, "-m", "mim", "install", "mmsegmentation>=1.0.0"])

# Make anything installed above visible to the `from mmengine ...` imports below.
importlib.invalidate_caches()

# Geospatial I/O and augmentation libraries; pip and import names coincide for all three.
for _dep in ("rasterio", "geopandas", "albumentations"):
    install_if_missing(_dep)

from mmengine.config import Config
from mmengine.runner import Runner

# Report whether PyTorch can see a CUDA device and, if so, which one.
cuda_available = torch.cuda.is_available()
print(f"\nüöÄ CUDA Available: {cuda_available}")
if cuda_available:
    print(f"üíª Device: {torch.cuda.get_device_name(0)}")

## 📂 1. Dataset Preview

Before training, let's look at what our model will see. We combine **Satellite View 1** images with **Instance Masks** (segmentation) and **Surface Normals** (geometry).

In [None]:
def preview_dataset(data_root, num_samples=3):
    """Show image / instance mask / surface normals side by side for the first training samples.

    Sample IDs are read from ``<data_root>/train.txt`` (one per line, blank lines
    skipped). For each ID this loads ``images/<id>.jpg``, ``masks/<id>.png`` and
    ``normals/<id>.npy`` and draws them as one figure row.

    Args:
        data_root: Dataset directory. A relative path is resolved against the
            notebook-level ``project_root``.
        num_samples: Maximum number of samples (figure rows) to plot.

    Returns:
        None. Prints a message and returns early when ``train.txt`` is missing
        or lists no sample IDs.

    Raises:
        FileNotFoundError: If an image or mask listed in ``train.txt`` cannot be read.
    """
    from itertools import islice

    data_path = Path(data_root)
    if not data_path.is_absolute():
        data_path = Path(project_root) / data_path

    train_file = data_path / 'train.txt'
    if not train_file.exists():
        print(f"‚ùå Could not find train.txt at {train_file}. Please run the preparation script first!")
        return

    # Read lazily and drop blank lines (e.g. a trailing newline) so they never become bogus IDs.
    with open(train_file, 'r', encoding='utf-8') as f:
        stripped = (line.strip() for line in f)
        sample_ids = list(islice(filter(None, stripped), num_samples))

    if not sample_ids:
        print(f"No sample IDs found in {train_file}; nothing to preview.")
        return

    # Imported only once there is something to draw, so the early exits above
    # work without the plotting stack installed.
    import cv2
    import matplotlib.pyplot as plt
    import numpy as np

    n_rows = len(sample_ids)
    # squeeze=False keeps `axes` 2-D even for a single row; plain subplots(1, 3)
    # returns a 1-D array and `axes[i, 0]` would raise IndexError.
    fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)

    for row_axes, sid in zip(axes, sample_ids):
        img_file = data_path / 'images' / (sid + '.jpg')
        img_bgr = cv2.imread(str(img_file))
        if img_bgr is None:  # cv2.imread reports failure by returning None
            raise FileNotFoundError(f"Could not read image: {img_file}")
        img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        mask_file = data_path / 'masks' / (sid + '.png')
        mask = cv2.imread(str(mask_file), cv2.IMREAD_UNCHANGED)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_file}")
        # Spread mask values over 20 colour bins. applyColorMap returns BGR,
        # while matplotlib expects RGB.
        mask_vis = cv2.applyColorMap(((mask % 20) * 12).astype(np.uint8), cv2.COLORMAP_JET)
        mask_vis = cv2.cvtColor(mask_vis, cv2.COLOR_BGR2RGB)

        normals = np.load(str(data_path / 'normals' / (sid + '.npy')))
        # Map components (assumed to lie in [-1, 1]) to [0, 255]. Clip so that
        # out-of-range values saturate instead of wrapping in the uint8 cast.
        normals_vis = np.clip((normals + 1) * 127.5, 0, 255).astype(np.uint8)

        panels = ((img, sid), (mask_vis, "Mask"), (normals_vis, "Normals"))
        for ax, (image, title) in zip(row_axes, panels):
            ax.imshow(image)
            ax.set_title(title)
            ax.axis('off')

    fig.tight_layout()
    plt.show()

# Quick visual sanity check on two samples (relative path resolved against project_root).
preview_dataset(data_root="data/OmniCity", num_samples=2)

## ⚙️ 2. Training Configuration

### 📊 Hyperparameter Overview
| Parameter | Value | Rationale |
| :--- | :--- | :--- |
| **Resolution** | 1024x1024 | Highest detail for complex roof layouts. |
| **Duration** | 20,000 iters | ~16 Epochs (Ideal for fine-tuning without overfitting). |
| **Batch Size** | 4 per GPU | Optimized for A100 40GB/80GB memory. |
| **Optimization** | AMP + AdamW | Mixed precision for 2x speedup on A100. |
| **Task** | Multi-Task | Learns segmentation + geometry simultaneously. |

In [None]:
# "fine-tune" keeps the config's `load_from` checkpoint; "scratch" clears it
# and lowers the learning rate.
MODE = "fine-tune" # Options: "fine-tune" or "scratch"
CONFIG_FILE = str(Path(project_root) / "configs/deeproof_finetune_swin_L.py")
WORK_DIR = str(Path(project_root) / "work_dirs/swin_l_omnicity_v2")

# Fail fast on a typo instead of silently falling through to fine-tuning.
if MODE not in ("fine-tune", "scratch"):
    raise ValueError(f'MODE must be "fine-tune" or "scratch", got {MODE!r}')

cfg = Config.fromfile(CONFIG_FILE)
cfg.work_dir = WORK_DIR

# Point the train and val dataloaders at the local OmniCity copy.
cfg.data_root = str(Path(project_root) / "data/OmniCity/")
cfg.train_dataloader.dataset.data_root = cfg.data_root
cfg.val_dataloader.dataset.data_root = cfg.data_root

cfg.train_dataloader.batch_size = 4
cfg.train_cfg.max_iters = 20000

if MODE == "scratch":
    cfg.load_from = None
    # mmengine's Runner builds the optimizer from `cfg.optim_wrapper`, so the lr
    # must be set there. Also update a top-level `optimizer` entry if the config
    # defines one, to keep the two in agreement.
    cfg.optim_wrapper.optimizer.lr = 0.0001
    if cfg.get('optimizer') is not None:
        cfg.optimizer.lr = 0.0001
    print("üöÄ Configured for Training from Scratch")
else:
    # .get(): the config may not define `load_from` at all.
    print(f"üéØ Configured for Fine-tuning with weights: {cfg.get('load_from')}")

# Save every 2000 iterations and keep the best checkpoint by validation mIoU.
cfg.default_hooks.checkpoint = dict(
    type='CheckpointHook', by_epoch=False, interval=2000, save_best='mIoU', rule='greater')

print("‚úÖ Configuration Validated.")

## 🚀 3. Start Training


In [None]:
from mmengine.registry import DATASETS, MODELS

# Sanity check: how many model classes the mmengine root registry currently holds.
registered_model_count = len(MODELS.module_dict)
print(f"Registered Models: {registered_model_count}")

# Build the mmengine Runner from the configured `cfg` and launch training.
runner = Runner.from_cfg(cfg)
runner.train()

## 🔍 4. Visualize Prediction


In [None]:
from mmseg.apis import init_model, inference_model

# CheckpointHook(save_best='mIoU') names the best checkpoint
# `best_mIoU_iter_<N>.pth`, never plain `best_mIoU.pth`, so glob for it and
# take the most recently written match.
best_ckpts = sorted(Path(WORK_DIR).glob('best_mIoU*.pth'), key=lambda p: p.stat().st_mtime)
CHECKPOINT = str(best_ckpts[-1]) if best_ckpts else os.path.join(WORK_DIR, 'best_mIoU.pth')

# Fall back to CPU so this cell also works on machines without CUDA.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

if os.path.exists(CHECKPOINT):
    model = init_model(CONFIG_FILE, CHECKPOINT, device=DEVICE)
    # NOTE(review): looks like a placeholder file name; point it at a real image.
    img_path = str(Path(project_root) / "data/OmniCity/images/some_sample.jpg")
    if os.path.exists(img_path):
        result = inference_model(model, img_path)
        print("Prediction Complete.")
    else:
        print(f"Sample image not found: {img_path}. Point img_path at an existing image.")
else:
    print("No checkpoint found. Please complete training first.")