# 🏠 DeepRoof-2026: Master Training Lab

### 🛠 Step 1: System-Level Environment Initialization
This cell handles **MMCV Source Compilation** (for Torch 2.4+), **MMSegmentation Repair**, and **CUDA Linking** directly in the system environment.

In [None]:
import os
import sys
import subprocess
import torch
import ctypes
import glob
import shutil
from pathlib import Path

print('üõ∞ Initializing DeepRepair Protocol V6 (Comprehensive)...')

# --- 1. PATH RESOLUTION (NO VENV) ---
project_root = Path('/workspace/roof')
if not project_root.exists():
    project_root = Path(os.getcwd()).parent

# Add project root to sys.path if not present
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))
    
print(f'üìÇ Project Root: {project_root}')
print(f'üêç Python: {sys.executable}')

# --- 2. NUCLEAR CUDA LINKER (System Scan) ---
def nuclear_cuda_fix():
    """Best-effort load of ``libcudart`` into the process with global symbols.

    Does nothing beyond printing a note when ``torch.cuda.is_available()`` is
    false. Otherwise the glob patterns below are tried in order; the first
    pattern that matches anything supplies the library, and a load failure is
    reported as a warning rather than raised.

    Returns:
        bool: Always ``True``.
    """
    if not torch.cuda.is_available():
        print('‚ùÑÔ∏è CPU Mode: Skipping CUDA linking.')
        return True

    print('üîç Searching for libcudart.so...')
    # Common locations in containers, in priority order.
    candidate_globs = (
        '/usr/local/cuda*/lib64/libcudart.so*',
        '/usr/lib/x86_64-linux-gnu/libcudart.so*',
        '/usr/lib/libcudart.so*',
    )

    # Lazily evaluate the patterns so later ones are not globbed once an
    # earlier one has produced a hit.
    hits_per_pattern = (glob.glob(candidate) for candidate in candidate_globs)
    first_hits = next((hits for hits in hits_per_pattern if hits), None)

    if first_hits is None:
        print('‚ö†Ô∏è Could not find libcudart.so in standard paths. Assuming built-in.')
        return True

    # Among several matches take the lexicographically last path.
    runtime_lib = max(first_hits)
    print(f'üìç Found linking target: {runtime_lib}')
    try:
        ctypes.CDLL(runtime_lib, mode=ctypes.RTLD_GLOBAL)
        print('‚úÖ CUDA Runtime force-loaded.')
    except Exception as err:
        print(f'‚ö†Ô∏è Force-load warning: {err}')
    return True

# --- 3. MMCV SOURCE COMPILER (Torch 2.4+) ---
def match_mmcv_to_torch():
    """Check that MMCV and its compiled ops import; reinstall MMCV if not.

    When the import check fails, the installed Torch version picks the repair:
    Torch >= 2.4 builds MMCV v2.2.0 from source (cloned into
    ``project_root / 'mmcv-source'``); anything older installs ``mmcv>=2.0.0``
    through ``mim``.

    Returns:
        bool: ``True`` if MMCV was already importable. ``False`` after a
        repair install ran; the calling cell then asks for a kernel restart.

    Raises:
        subprocess.CalledProcessError: If the clone, build, or ``pip``/``mim``
            install command exits with a non-zero status.
        ValueError: If the leading ``major.minor`` of ``torch.__version__``
            is not numeric.
    """
    torch_ver = torch.__version__
    print(f'üîç Detected: Torch {torch_ver}')

    try:
        import mmcv  # noqa: F401
        # Importing an op proves the compiled extension is usable, not just
        # the pure-Python package.
        from mmcv.ops import point_sample  # noqa: F401
    except ImportError as e:  # also covers ModuleNotFoundError
        print(f'‚ùå MMCV Error: {e}')
    except Exception as e:
        print(f'‚ùå Unknown MMCV Error: {e}')
    else:
        print('‚úÖ MMCV is fully functional.')
        return True

    print('üîÑ Attempting Repair...')

    # Drop any local build tag ("2.4.0+cu121" -> "2.4.0") before parsing.
    base_ver = torch_ver.split('+')[0]
    major, minor = map(int, base_ver.split('.')[:2])

    # Tuple comparison so that e.g. 3.0 also counts as ">= 2.4"; comparing
    # major and minor independently would reject it.
    is_bleeding_edge = (major, minor) >= (2, 4)

    if not is_bleeding_edge:
        print('‚ÑπÔ∏è Standard Torch detected. Trying MIM install.')
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-U', 'openmim'])
        subprocess.check_call([sys.executable, '-m', 'mim', 'install', 'mmcv>=2.0.0'])
        return False

    print('‚ö†Ô∏è Bleeding-edge Torch (>=2.4) detected. BINARY WHEELS DO NOT EXIST.')
    print('üõ† Starting SOURCE COMPILATION (approx 5-10 mins)...')

    # Remove any existing install; a failure here (e.g. not installed) is fine.
    subprocess.run([sys.executable, '-m', 'pip', 'uninstall', '-y', 'mmcv'], check=False)

    mmcv_dir = project_root / 'mmcv-source'
    if mmcv_dir.exists():
        shutil.rmtree(mmcv_dir)

    try:
        subprocess.check_call(
            ['git', 'clone', '-b', 'v2.2.0', 'https://github.com/open-mmlab/mmcv.git', str(mmcv_dir)]
        )

        env = os.environ.copy()
        env['MMCV_WITH_OPS'] = '1'
        # Only force a CUDA build when this Torch was itself built with CUDA;
        # forcing it on a CPU-only build makes the compilation fail.
        if torch.version.cuda is not None:
            env['FORCE_CUDA'] = '1'
        # Respect a MAX_JOBS the user already exported; default to 8 otherwise.
        env.setdefault('MAX_JOBS', '8')

        print('‚è≥ Compiling... check terminal for details if stuck.')
        subprocess.check_call(
            [sys.executable, '-m', 'pip', 'install', '.'], cwd=str(mmcv_dir), env=env
        )
        print('‚úÖ Compilation Complete.')
    finally:
        # Always remove the source checkout, including after a failed build.
        shutil.rmtree(mmcv_dir, ignore_errors=True)
    return False

# --- 4. MMSEGMENTATION CHECK & REPAIR ---
def check_mmseg():
    """Verify MMSegmentation imports and exposes a usable Mask2Former path.

    Three import locations are tried in order: the ``mask2former`` segmentor
    module, the ``segmentors`` package, and finally the ``EncoderDecoder`` +
    ``Mask2FormerHead`` pair as a compatibility path. If none can be imported,
    MMSegmentation is reinstalled through ``mim``.

    Returns:
        bool: ``True`` when one of the import paths works. ``False`` after a
        reinstall (kernel restart needed) or on any non-import error.

    Raises:
        subprocess.CalledProcessError: If installing ``openmim`` or
            ``mmsegmentation`` exits with a non-zero status.
    """
    print('üîç Checking MMSegmentation installation...')
    try:
        import mmseg
        print(f'   MMSeg version: {mmseg.__version__}')
        try:
            from mmseg.models.segmentors.mask2former import Mask2Former  # noqa: F401
        except ImportError:
            try:
                from mmseg.models.segmentors import Mask2Former  # noqa: F401
            except ImportError:
                # Last resort: an ImportError raised here propagates to the
                # outer handler and triggers the reinstall.
                from mmseg.models.segmentors import EncoderDecoder  # noqa: F401
                from mmseg.models.decode_heads import Mask2FormerHead  # noqa: F401
                print('‚ÑπÔ∏è Mask2Former segmentor not exported by this mmseg build; using EncoderDecoder compatibility path.')
                return True
        print('‚úÖ Mask2Former segmentor is available.')
        return True
    except ImportError as e:  # also covers ModuleNotFoundError
        print(f'‚ùå MMSegmentation Issue: {e}')
        print('üîÑ Reinstalling MMSegmentation via MIM...')

        # `mim` is not guaranteed to be present (the MMCV source-build path
        # never installs it), so install it before anything is removed.
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-U', 'openmim'])
        subprocess.run([sys.executable, '-m', 'pip', 'uninstall', '-y', 'mmsegmentation'], check=False)
        subprocess.check_call([sys.executable, '-m', 'mim', 'install', 'mmsegmentation>=1.2.2'])
        print('‚úÖ Reinstall complete. PLEASE RESTART KERNEL.')
        return False
    except Exception as e:
        print(f'‚ö†Ô∏è Unknown MMSeg error: {e}')
        return False

# --- 5. SAFETY NOTE ---
def patch_assertions():
    """Print a note that installed package files are deliberately not patched.

    Returns:
        bool: Always ``True``.
    """
    notice_lines = (
        '‚ÑπÔ∏è Skipping site-packages patching for mmseg/__init__.py (safer and reproducible).',
        '   If mmcv/mmseg versions are incompatible, reinstall matching versions instead of editing package files.',
    )
    for text in notice_lines:
        print(text)
    return True

# Run the checks in order. `and` short-circuits, so MMCV is only inspected
# once the CUDA step has returned True; a falsy result from either step means
# something was (re)installed and the kernel must be restarted.
if not (nuclear_cuda_fix() and match_mmcv_to_torch()):
    print('\n‚ö†Ô∏è  ENVIRONMENT UPDATED. PLEASE RESTART KERNEL.')
elif not check_mmseg():
    print('\n‚ö†Ô∏è  MMSEG UPDATED. PLEASE RESTART KERNEL.')
else:
    patch_assertions()
    print('üöÄ System Ready.')

## 📂 1. Dataset Preview

Visualize the **OmniCity** satellite imagery and ground truth **Masks** + **Surface Normals**.

In [None]:
def preview_dataset(num_samples=3):
    """Show image / mask / surface-normal previews for the first training ids.

    Sample ids are read from ``project_root / 'data/OmniCity' / 'train.txt'``
    (one id per line, blank lines ignored). Each id gets one row of three
    panels: ``images/<id>.jpg``, ``masks/<id>.png`` rendered with a JET colour
    map, and ``normals/<id>.npy`` when that file exists. A file that cannot be
    read leaves its panel empty and prints a warning.

    Args:
        num_samples: Maximum number of ids to preview. Fewer rows are drawn
            when the list file holds fewer ids.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    from itertools import islice

    import matplotlib.pyplot as plt
    import numpy as np
    import cv2

    data_path = project_root / 'data/OmniCity'
    train_file = data_path / 'train.txt'

    if not train_file.exists():
        print(f'‚ùå Multi-task training data not found at {data_path}.')
        return

    with open(train_file, 'r') as f:
        non_blank_ids = (sid for sid in map(str.strip, f) if sid)
        sample_ids = list(islice(non_blank_ids, max(num_samples, 0)))

    if not sample_ids:
        print(f'[warn] No sample ids to preview in {train_file}.')
        return

    # squeeze=False keeps `axes` two-dimensional even for a single row, and
    # sizing by len(sample_ids) avoids blank rows when the file is short.
    _, axes = plt.subplots(
        len(sample_ids), 3, figsize=(15, 5 * len(sample_ids)), squeeze=False
    )
    for (img_ax, mask_ax, normals_ax), sid in zip(axes, sample_ids):
        img_ax.set_title(sid)
        mask_ax.set_title('Mask')
        normals_ax.set_title('Normals')
        for ax in (img_ax, mask_ax, normals_ax):
            ax.axis('off')

        # cv2.imread returns None (rather than raising) for unreadable files.
        img = cv2.imread(str(data_path / 'images' / (sid + '.jpg')))
        if img is None:
            print(f'[warn] Could not read image for sample {sid}.')
        else:
            img_ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

        mask = cv2.imread(str(data_path / 'masks' / (sid + '.png')), cv2.IMREAD_UNCHANGED)
        if mask is None:
            print(f'[warn] Could not read mask for sample {sid}.')
        else:
            # Fold ids into 20 buckets and spread them over 0..228 so that
            # neighbouring ids get distinguishable colours.
            mask_vis = cv2.applyColorMap(((mask % 20) * 12).astype(np.uint8), cv2.COLORMAP_JET)
            # applyColorMap yields BGR; matplotlib expects RGB.
            mask_ax.imshow(cv2.cvtColor(mask_vis, cv2.COLOR_BGR2RGB))

        norm_path = data_path / 'normals' / (sid + '.npy')
        if norm_path.exists():
            normals = np.load(str(norm_path))
            # Assumes components lie in [-1, 1] so they map onto 0..255 -- TODO confirm.
            normals_ax.imshow(((normals + 1) * 127.5).astype(np.uint8))

    plt.tight_layout()
    plt.show()

preview_dataset(num_samples=2)

## ⚙️ 2. Scratch Training Configuration (Epoch-Based)

We are using the **MASTER EPOCH-BASED SCRATCH PROFILE**:
- **Duration**: 150 Epochs (~160k steps).
- **Val Interval**: Every 1 Epoch (Reports results per-epoch).
- **No Pre-Training**: `load_from = None`.
- **Checkpoints**: Interval snapshots every 5 epochs + `best_mIoU.pth`.

In [None]:
from mmengine.config import Config

# Locations of the scratch-training config and of this run's output directory.
CONFIG_PATH = str(project_root.joinpath('configs', 'deeproof_scratch_swin_L.py'))
WORK_DIR = str(project_root.joinpath('work_dirs', 'swin_l_scratch_v1'))

cfg = Config.fromfile(CONFIG_PATH)
cfg.work_dir = WORK_DIR

# Point the config, and both dataloaders, at the local OmniCity directory.
cfg.data_root = str(project_root.joinpath('data', 'OmniCity'))
cfg.train_dataloader.dataset.data_root = cfg.data_root
cfg.val_dataloader.dataset.data_root = cfg.data_root

# Summary of the loaded profile.
print('üèÜ MASTER SCRATCH CONFIG LOADED')
print(f'üî• Max Epochs: {cfg.train_cfg.max_epochs}')
print(f'üìâ Initial LR: {cfg.optimizer.lr}')
print('üìä Reporting Interval: Every Epoch')

## 🚀 3. Kickoff Training

This will invoke the `mmengine.Runner` and begin the full model convergence process. **Detailed stats will print to this output at the end of every epoch.**

In [None]:
import torch
from mmengine.runner import Runner

print(f'üöÄ Starting Master Trainer on: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}')

runner = Runner.from_cfg(cfg)
runner.train()

## 📊 4. Monitoring & Metrics

Run this cell during or after training to visualize performance trends.

In [None]:
import json
import matplotlib.pyplot as plt

def plot_training_logs(log_path):
    """Plot training loss, and mIoU when logged, from a JSON-lines scalar log.

    Every non-blank line of ``log_path`` is parsed as one JSON object. Objects
    with a ``'loss'`` key contribute a point to the loss curve (x = their
    ``'iter'`` value, 0 when absent); objects with an ``'mIoU'`` key contribute
    to the mIoU curve. Blank or unparsable lines are skipped so the function
    can be called while the log is still being written.

    Args:
        log_path: Path of the log file to read.

    Returns:
        None. Prints a note instead of plotting when the file is missing or
        holds no loss/mIoU entries.
    """
    if not os.path.exists(log_path):
        print('üïí No logs found yet.')
        return

    iters, losses, miou = [], [], []
    with open(log_path, 'r') as f:
        for raw_line in f:
            raw_line = raw_line.strip()
            if not raw_line:
                continue
            try:
                data = json.loads(raw_line)
            except json.JSONDecodeError:
                # Typically a half-written final line during live training.
                continue
            if not isinstance(data, dict):
                continue
            if 'loss' in data:
                iters.append(data.get('iter', 0))
                losses.append(data['loss'])
            if 'mIoU' in data:
                miou.append(data['mIoU'])

    if not losses and not miou:
        print('No loss or mIoU entries found in the log yet.')
        return

    # One panel for the loss, plus a second one only when mIoU was logged.
    n_panels = 2 if miou else 1
    _, axes = plt.subplots(1, n_panels, figsize=(10, 5), squeeze=False)

    loss_ax = axes[0, 0]
    loss_ax.plot(iters, losses, label='Loss')
    loss_ax.set_xlabel('Iteration')
    loss_ax.set_title('Training Progress')
    loss_ax.legend()

    if miou:
        miou_ax = axes[0, 1]
        # x is the ordinal of the evaluation; no iteration is recorded for it here.
        miou_ax.plot(range(1, len(miou) + 1), miou, label='mIoU')
        miou_ax.set_xlabel('Evaluation #')
        miou_ax.set_title('Validation mIoU')
        miou_ax.legend()

    plt.tight_layout()
    plt.show()

# log_json = glob.glob(os.path.join(WORK_DIR, '*/vis_data/scalars.json'))
# if log_json: plot_training_logs(log_json[-1])