# 🏠 DeepRoof-2026: Master Training Lab

### 🛠 Step 1: Nuclear Environment Initialization
This cell fixes **CUDA Linking errors** (`libcudart.so.11.0`), activates the venv, and patches **mmsegmentation**.

In [None]:
import os
import sys
import subprocess
import torch
import ctypes
import glob
from pathlib import Path

print("üõ∞ Initializing DeepRepair Protocol...")

# --- 1. PATH RESOLUTION ---
project_root = Path("/workspace/roof")
if not project_root.exists():
    project_root = Path(os.getcwd()).parent

venv_path = project_root / "venv"
if not venv_path.exists():
    venv_path = project_root / ".venv"

if venv_path.exists():
    site_pkgs = list(venv_path.glob("lib/python*/site-packages"))
    if site_pkgs:
        if str(site_pkgs[0]) not in sys.path:
            sys.path.insert(0, str(site_pkgs[0]))
        sys.executable = str(venv_path / "bin" / "python")
    print(f"üêç Venv Active: {venv_path}")

if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

# --- 2. NUCLEAR CUDA LINKER (FIXES libcudart.so.11.0) ---
def _find_libcudart():
    """Return the path of the first libcudart.so.11.0 found on disk, or None."""
    wheel_rel = "nvidia/cuda_runtime/lib/libcudart.so.11.0"

    # Deep search locations (glob patterns), checked in priority order.
    search_patterns = [
        "/usr/local/cuda*/lib64/libcudart.so.11.0",
        "/usr/lib/x86_64-linux-gnu/libcudart.so.11.0",
        str(venv_path / "lib/python*/site-packages" / wheel_rel),
    ]
    for pattern in search_patterns:
        matches = glob.glob(pattern)
        if matches:
            return matches[0]

    # pip installs into the environment of sys.executable, which is not
    # necessarily venv_path (e.g. when no venv directory exists). Checking
    # every sys.path entry also covers the running interpreter's packages.
    for entry in sys.path:
        if not entry:
            continue
        candidate = Path(entry) / wheel_rel
        if candidate.is_file():
            return str(candidate)

    return None


def nuclear_cuda_fix():
    """Locate libcudart.so.11.0 and load it into the process with RTLD_GLOBAL.

    When CUDA is unavailable nothing is done. Otherwise the library is
    searched for on disk; if it is missing, ``nvidia-cuda-runtime-cu11`` is
    installed with pip and the search is repeated.

    Returns:
        True in CPU mode or when the library was loaded; False when the
        library could not be found after installation or failed to load.

    Raises:
        subprocess.CalledProcessError: if the pip installation fails.
    """
    if not torch.cuda.is_available():
        print("‚ùÑÔ∏è CPU Mode: Skipping CUDA linking.")
        return True

    print("üîç Searching for libcudart.so.11.0...")

    found_lib = _find_libcudart()
    if not found_lib:
        print("üì¶ Missing libcudart.so.11.0. Installing nvidia-cuda-runtime-cu11...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "nvidia-cuda-runtime-cu11"])
        # Re-search after install
        found_lib = _find_libcudart()
        if not found_lib:
            print("‚ùå ERROR: Still could not find libcudart after install.")
            return False

    print(f"üìç Found linking target: {found_lib}")
    try:
        # Force load into process GLOBAL scope.
        # This satisfies later 'import mmcv._ext' calls that would otherwise fail.
        ctypes.CDLL(found_lib, mode=ctypes.RTLD_GLOBAL)
    except OSError as e:
        # ctypes reports a failed load of the shared object as OSError.
        print(f"‚ö†Ô∏è Force-load failed: {e}")
        return False

    print("‚úÖ CUDA Runtime successfully force-loaded into memory.")
    return True

# --- 3. MMSEG & MMCV INTEGRITY ---
def verify_and_patch():
    """Relax mmseg's version assertions and verify the compiled mmcv ops.

    The venv's ``mmseg/__init__.py`` is located on disk (not imported, since
    the import itself may crash). If it still contains the mmcv version
    assertion, the file is overwritten with an assertion-free version.
    On CUDA machines ``mmcv.ops`` is then imported; on failure mmcv is
    reinstalled from the prebuilt wheel index.

    Returns:
        False when something on disk was changed (file rewritten or mmcv
        reinstalled), True when nothing needed changing.

    Raises:
        subprocess.CalledProcessError: if the mmcv pip installation fails.
    """
    # Patch mmsegmentation assertions
    target_file = None
    matches = glob.glob(str(venv_path / "lib/python*/site-packages/mmseg/__init__.py"))
    if matches:
        target_file = Path(matches[0])

    if target_file and target_file.exists():
        with open(target_file, 'r', encoding='utf-8') as f:
            content = f.read()
        if "assert (mmcv_min_version" in content:
            print(f"ü©π Deleting assertions in {target_file}")
            # NOTE: this replaces the whole file; no backup copy is kept.
            unlocked = """# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import mmengine
from mmengine.utils import digit_version
from .version import __version__, version_info
MMCV_MIN = '2.0.0rc4'
MMCV_MAX = '2.2.0'
MMENGINE_MIN = '0.7.1'
MMENGINE_MAX = '1.0.0'
mmcv_min_version = digit_version(MMCV_MIN)
mmcv_max_version = digit_version('9.9.9') # OVERRIDE by DeepRoof
mmcv_version = digit_version(mmcv.__version__)
mmengine_min_version = digit_version(MMENGINE_MIN)
mmengine_max_version = digit_version('9.9.9') # OVERRIDE by DeepRoof
mmengine_version = digit_version(mmengine.__version__)
__all__ = ['__version__', 'version_info', 'digit_version']
"""
            with open(target_file, 'w', encoding='utf-8') as f:
                f.write(unlocked)
            return False

    # Verify MMCV Ops
    if torch.cuda.is_available():
        try:
            from mmcv.ops import point_sample  # noqa: F401 (import is the check)
            print("‚úÖ MMCV Binary (Ops) verified.")
        except Exception:
            # Any import-time failure triggers the reinstall, but Ctrl-C
            # (KeyboardInterrupt) and SystemExit are no longer swallowed.
            print("üîÑ Re-installing MMCV Binary wheel...")
            subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", "mmcv"], check=False)
            subprocess.check_call([sys.executable, "-m", "pip", "install", "mmcv==2.2.0", "-f", "https://download.openmmlab.com/mmcv/dist/cu118/torch2.1/index.html"])
            return False

    return True

# Run both repair steps; the second is skipped when the first one fails.
# Only a fully clean pass reports the system as live.
if not (nuclear_cuda_fix() and verify_and_patch()):
    print("\n‚ö†Ô∏è  DISK UPDATED. PLEASE RESTART KERNEL TO FINALIZE.")
elif torch.cuda.is_available():
    print(f"üöÄ System Live | {torch.cuda.get_device_name(0)}")
else:
    print("üöÄ System Live | CPU")

## 📂 1. Dataset Preview

Visualize the **OmniCity** satellite imagery and ground truth **Masks** + **Surface Normals**.

In [None]:
def preview_dataset(num_samples=3):
    """Show image, mask and surface normals for the first training samples.

    Sample ids are read from ``train.txt`` under ``data/OmniCity``; blank
    lines are skipped. One figure row is drawn per sample, with three
    columns: image, colour-mapped mask, normals. Files that are missing or
    unreadable leave their panel empty instead of raising.

    Args:
        num_samples: Maximum number of sample ids to preview.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import cv2

    data_path = project_root / "data/OmniCity"
    train_file = data_path / 'train.txt'

    if not train_file.exists():
        print(f"‚ùå Multi-task training data not found at {data_path}.")
        return

    # Collect up to num_samples non-empty ids without reading the whole file.
    sample_ids = []
    with open(train_file, 'r') as f:
        for line in f:
            if len(sample_ids) >= num_samples:
                break
            sid = line.strip()
            if sid:
                sample_ids.append(sid)

    if not sample_ids:
        print(f"[warn] No sample ids to preview in {train_file}.")
        return

    # squeeze=False keeps `axes` two-dimensional even for a single row, so
    # axes[i, j] indexing also works when only one sample is shown.
    n_rows = len(sample_ids)
    _fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)
    for i, sid in enumerate(sample_ids):
        # cv2.imread returns None (no exception) when a file cannot be read.
        bgr = cv2.imread(str(data_path / 'images' / (sid + '.jpg')))
        if bgr is None:
            print(f"[warn] Could not read image for sample '{sid}'.")
        else:
            axes[i, 0].imshow(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        axes[i, 0].set_title(sid); axes[i, 0].axis('off')

        mask = cv2.imread(str(data_path / 'masks' / (sid + '.png')), cv2.IMREAD_UNCHANGED)
        if mask is None:
            print(f"[warn] Could not read mask for sample '{sid}'.")
        else:
            # Fold ids into 20 buckets and spread them over the colormap range.
            mask_vis = cv2.applyColorMap(((mask % 20) * 12).astype(np.uint8), cv2.COLORMAP_JET)
            axes[i, 1].imshow(mask_vis)
        axes[i, 1].set_title("Mask"); axes[i, 1].axis('off')

        norm_path = data_path / 'normals' / (sid + '.npy')
        if norm_path.exists():
            normals = np.load(str(norm_path))
            # Presumably components lie in [-1, 1], mapped here to 0..255
            # for display - TODO confirm against the dataset export.
            axes[i, 2].imshow(((normals + 1) * 127.5).astype(np.uint8))
        axes[i, 2].set_title("Normals"); axes[i, 2].axis('off')

    plt.tight_layout(); plt.show()

# Preview the first two samples listed in train.txt.
preview_dataset(num_samples=2)

## ⚙️ 2. Scratch Training Configuration (Epoch-Based)

We are using the **MASTER EPOCH-BASED SCRATCH PROFILE**:
- **Duration**: 150 Epochs (~160k steps).
- **Val Interval**: Every 1 Epoch (Reports results per-epoch).
- **No Pre-Training**: `load_from = None`.
- **Checkpoints**: Interval snapshots every 5 epochs + `best_mIoU.pth`.

In [None]:
from mmengine.config import Config

# Experiment config file and the work directory assigned to the run.
CONFIG_PATH = str(project_root / "configs" / "deeproof_scratch_swin_L.py")
WORK_DIR = str(project_root / "work_dirs" / "swin_l_scratch_v1")

cfg = Config.fromfile(CONFIG_PATH)
cfg.work_dir = WORK_DIR

# Point the top-level data root and both dataloaders at the same dataset copy.
cfg.data_root = str(project_root / "data" / "OmniCity")
cfg.train_dataloader.dataset.data_root = cfg.data_root
cfg.val_dataloader.dataset.data_root = cfg.data_root

# Echo the key settings so the loaded profile can be checked at a glance.
print("üèÜ MASTER SCRATCH CONFIG LOADED")
print(f"üî• Max Epochs: {cfg.train_cfg.max_epochs}")
print(f"üìâ Initial LR: {cfg.optimizer.lr}")
print("üìä Reporting Interval: Every Epoch")

## 🚀 3. Kickoff Training

This will invoke the `mmengine.Runner` and begin the full model convergence process. **Detailed stats will print to this output at the end of every epoch.**

In [None]:
import torch
from mmengine.runner import Runner

print(f"üöÄ Starting Master Trainer on: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")

runner = Runner.from_cfg(cfg)
runner.train()

## 📊 4. Monitoring & Metrics

Run this cell during or after training to visualize performance trends.

In [None]:
import json
import matplotlib.pyplot as plt

def plot_training_logs(log_path):
    """Plot loss (and mIoU, when present) from a JSON-lines scalar log.

    Each non-empty line of the file is parsed as one JSON object. Records
    with a ``loss`` key feed the loss curve; records with an ``mIoU`` key
    feed a second panel. Blank lines and lines that are not valid JSON are
    skipped, so the function can be run while the log is still being written.

    Args:
        log_path: Path of the log file to read.

    Returns:
        None. Prints a message and returns early when the file does not
        exist or holds no loss/mIoU records.
    """
    if not os.path.exists(log_path):
        print("üïí No logs found yet.")
        return

    iters, losses = [], []
    miou_steps, miou = [], []
    with open(log_path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                # Most likely a partially written line; ignore it.
                continue
            if not isinstance(data, dict):
                continue
            if 'loss' in data:
                iters.append(data.get('iter', 0))
                losses.append(data['loss'])
            if 'mIoU' in data:
                # Use the record's 'step' when present, else its ordinal index.
                miou_steps.append(data.get('step', len(miou)))
                miou.append(data['mIoU'])

    if not losses and not miou:
        print("No loss or mIoU records found in the log yet.")
        return

    if miou:
        # Validation records exist: show them next to the loss curve.
        _fig, (loss_ax, miou_ax) = plt.subplots(1, 2, figsize=(16, 5))
        miou_ax.plot(miou_steps, miou, marker='o', label='mIoU')
        miou_ax.set_title("Validation mIoU")
        miou_ax.set_xlabel("step")
        miou_ax.legend()
    else:
        _fig, loss_ax = plt.subplots(figsize=(10, 5))

    loss_ax.plot(iters, losses, label='Loss')
    loss_ax.set_title("Training Progress")
    loss_ax.set_xlabel("iter")
    loss_ax.legend()
    plt.show()

# log_json = glob.glob(os.path.join(WORK_DIR, "*/vis_data/scalars.json"))
# if log_json: plot_training_logs(log_json[-1])