# MFNet Inference — Streaming for Very Large Ortho/DSM (No Full-Image Load)

This notebook is optimized for **huge rasters** (e.g., 9 GB ortho). It:
- Streams tiles from disk using Rasterio **windows** (no full `read()` of entire images)
- Uses a **WarpedVRT** to align DSM to the Ortho grid **on the fly** (CRS/resolution/transform)
- Normalizes per-tile and runs MFNet inference
- Writes predictions **directly** into a tiled GeoTIFF

> Set `OVERLAP = 0` to minimize work and memory. Increase `OVERLAP` if you need smoother tile seams, with the tradeoff that overlapping regions are last-write-wins (the tile written later overwrites the earlier one).

In [None]:
# Optional: install deps if needed
# %pip install rasterio tqdm matplotlib
import os, sys, math, warnings
warnings.filterwarnings('ignore')

In [None]:
import numpy as np
import torch
import rasterio
from rasterio.enums import Resampling
from rasterio.vrt import WarpedVRT
from rasterio.windows import Window
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from pathlib import Path

In [None]:
# ==== USER CONFIG (edit me) ====
# Input rasters (paths are resolved relative to the working directory).
ORT_H = './data/AMNS/Cropped.tif'   # e.g., 'data/ortho.tif'
DSM_H = './data/AMNS/Cropped_DEM.tif'     # e.g., 'data/dsm.tif'

# Checkpoint with trained MFNet weights
# NOTE(review): the default file name looks like a SAM ViT-L backbone checkpoint
# rather than a fine-tuned MFNet checkpoint -- confirm before a production run.
CHECKPOINT = './weights/sam_vit_l_0b3195.pth'  # e.g., 'checkpoints/mfnet_best.pth'

# Model hyper-params (edit to match your training)
NUM_CLASSES = 6
BACKBONE    = 'sam-vit-l'      # per your training
DEVICE      = 'cuda' if torch.cuda.is_available() else 'cpu'

# 🚀 PERFORMANCE OPTIMIZATION FOR LARGE IMAGES
# For RTX 4070 (12GB VRAM) with 9GB ortho, consider these optimizations:

# ⚡ QUICK TESTING OPTIONS:

# Option A: Spatial crop (process only a centered square region) - fastest for testing.
# CROP_SIZE is always defined so that flipping USE_CROP to True cannot raise NameError.
USE_CROP = False
CROP_SIZE = 5000  # Edge length in pixels of the centered crop (used only when USE_CROP is True)

# Option B: Downsampling. 1 = full resolution, 2 = half, 4 = quarter (per axis),
# so the tile count shrinks by roughly DOWNSAMPLE_FACTOR ** 2.
DOWNSAMPLE_FACTOR = 1

# Option C: Tile geometry.
WINDOW_SIZE = 512   # Tile edge length in pixels; larger tiles = fewer tiles
OVERLAP     = 32    # Pixels shared by neighbouring tiles; 0 disables overlap

# For a full production run, disable crop and use for example:
# USE_CROP = False
# DOWNSAMPLE_FACTOR = 2
# WINDOW_SIZE = 256
# OVERLAP = 32

# Fail early on settings that cannot produce a valid tiling.
if WINDOW_SIZE <= 0 or not 0 <= OVERLAP < WINDOW_SIZE:
    raise ValueError(
        f'Need WINDOW_SIZE > 0 and 0 <= OVERLAP < WINDOW_SIZE, '
        f'got WINDOW_SIZE={WINDOW_SIZE}, OVERLAP={OVERLAP}'
    )
if DOWNSAMPLE_FACTOR < 1:
    raise ValueError(f'DOWNSAMPLE_FACTOR must be >= 1, got {DOWNSAMPLE_FACTOR}')
if USE_CROP and CROP_SIZE <= 0:
    raise ValueError(f'CROP_SIZE must be positive when USE_CROP is True, got {CROP_SIZE}')

# Memory management
BATCH_CLEAR_FREQ = 50  # Clear GPU cache every N tiles

# Normalization (per-channel mean/std applied after scaling RGB to [0, 1])
RGB_MEAN = np.array([0.5, 0.5, 0.5], dtype=np.float32)
RGB_STD  = np.array([0.5, 0.5, 0.5], dtype=np.float32)

# Output: the file name encodes the settings so runs do not overwrite each other.
crop_suffix = f'_crop{CROP_SIZE}' if USE_CROP else '_fullimage'
OUTPUT_TIF = f'prediction{crop_suffix}_ds{DOWNSAMPLE_FACTOR}_ws{WINDOW_SIZE}_ov{OVERLAP}.tif'

In [None]:
# ==== Model Import ====
# The model class is bound to the name MFNet, which the rest of the notebook uses.
# Point this import at wherever the class lives in your SSRS checkout,
# e.g. `from ssrs.models.mfnet import MFNet`.
try:
    from UNetFormer_MMSAM import UNetFormer as MFNet
except Exception:
    # Print a hint, then re-raise so the original traceback is still shown.
    print('⚠️ Could not import MFNet with default path. Update the import to match your repo structure.')
    raise


In [None]:
def norm_rgb(rgb, mean=RGB_MEAN, std=RGB_STD):
    """Convert an RGB tile to float32, rescale it if needed, and standardize it.

    Args:
        rgb: Image tile with the channel axis first; the caller passes
            (3, h, w) arrays.
        mean: 1-D per-channel mean, broadcast over the two spatial axes.
        std: 1-D per-channel standard deviation, broadcast the same way.

    Returns:
        float32 array of the same shape: ``(tile - mean) / std``.
    """
    # astype() returns a new array, so the in-place division below never
    # touches the caller's data.
    tile = rgb.astype(np.float32)

    # Values above 1 are treated as 0-255 data and scaled into [0, 1].
    # NOTE(review): the decision is made per tile from its maximum, so a tile
    # whose values are all <= 1 is left unscaled -- confirm this is intended.
    exceeds_unit_range = np.nanmax(tile) > 1.0
    if exceeds_unit_range:
        tile /= 255.0

    centered = tile - mean[:, None, None]
    return centered / std[:, None, None]

def norm_dsm(dsm):
    """Z-score normalize a DSM tile using statistics of that tile only.

    NaN samples are ignored when computing the mean and standard deviation and
    are replaced by 0.0 (the tile mean after normalization) in the result, so
    the returned array never contains NaN. A tile consisting only of NaN (or an
    empty tile) yields all zeros instead of NaNs and NumPy warnings.

    Args:
        dsm: Elevation tile; the caller passes (1, h, w) arrays.

    Returns:
        float32 array with the same shape as ``dsm``.
    """
    x = dsm.astype(np.float32)
    valid = ~np.isnan(x)

    # No usable samples: there is nothing to estimate statistics from.
    if not valid.any():
        return np.zeros_like(x)

    mu = np.nanmean(x)
    sd = np.nanstd(x) + 1e-6  # epsilon avoids division by zero on flat tiles
    out = (x - mu) / sd

    # NOTE(review): only NaN is treated as missing; numeric nodata sentinels
    # (e.g. -9999) would still skew the statistics -- confirm how the DSM
    # encodes nodata.
    out[~valid] = 0.0
    return out


def iter_windows(width, height, size=1024, overlap=0):
    """Yield rasterio ``Window`` objects that tile a ``width`` x ``height`` raster.

    Windows are produced row by row. Along each axis the offsets advance by
    ``size - overlap``; if the last regular window would stop short of the
    raster edge, one extra window flush with the edge is added (it overlaps its
    neighbour by more than ``overlap``). A raster no larger than ``size`` along
    an axis gets a single, smaller window on that axis.

    Args:
        width: Raster width in pixels.
        height: Raster height in pixels.
        size: Window edge length in pixels.
        overlap: Pixels shared by consecutive windows along an axis.

    Yields:
        ``Window(col_off, row_off, width, height)`` for every tile.

    Raises:
        ValueError: If an axis is larger than ``size`` and ``size - overlap``
            is not positive, because the offsets could never advance. The error
            is raised when iteration starts (this is a generator).
    """
    def axis_offsets(extent):
        # One window at the origin covers an axis that fits inside a tile.
        if extent <= size:
            return [0]
        step = size - overlap
        if step <= 0:
            raise ValueError(
                f'overlap ({overlap}) must be smaller than size ({size}) '
                f'to tile an extent of {extent} pixels'
            )
        last = extent - size
        offsets = list(range(0, last + 1, step))
        # Make sure the final window ends exactly at the raster edge.
        if offsets[-1] != last:
            offsets.append(last)
        return offsets

    xs = axis_offsets(width)
    ys = axis_offsets(height)

    for y in ys:
        for x in xs:
            # Clamp the window so it never extends past the raster bounds.
            actual_width = min(size, width - x)
            actual_height = min(size, height - y)

            # Degenerate rasters (zero width or height) produce no windows.
            if actual_width < 1 or actual_height < 1:
                continue

            yield Window(x, y, actual_width, actual_height)

In [None]:
# ==== Open datasets with optional cropping and downsampling ====
# This cell defines the processing grid used by every later cell:
#   ortho                        -> dataset (or VRT) that RGB tiles are read from
#   dsm                          -> DSM warped onto exactly the same grid
#   final_width / final_height / final_transform -> size and georeferencing of that grid
# Nothing is read into memory here; the datasets and VRTs are only opened.
ortho_src = rasterio.open(ORT_H)
print(f'Original Ortho: {ortho_src.width} x {ortho_src.height}, bands={ortho_src.count}')

# Step 1: Apply spatial cropping if enabled
if USE_CROP:
    # Calculate crop bounds (center region), clamped to the raster extent so a
    # CROP_SIZE larger than the image degrades to the full image.
    center_x = ortho_src.width // 2
    center_y = ortho_src.height // 2
    half_crop = CROP_SIZE // 2
    
    crop_left = max(0, center_x - half_crop)
    crop_top = max(0, center_y - half_crop)
    crop_right = min(ortho_src.width, center_x + half_crop)
    crop_bottom = min(ortho_src.height, center_y + half_crop)
    
    crop_width = crop_right - crop_left
    crop_height = crop_bottom - crop_top
    
    # Calculate transform for cropped area: the source transform composed with a
    # pixel-space translation, so pixel (0, 0) of the crop maps to the world
    # position of source pixel (crop_left, crop_top).
    crop_transform = ortho_src.transform * ortho_src.transform.translation(crop_left, crop_top)
    
    print(f'🎯 Cropping to {crop_width} x {crop_height} pixels (center region)')
    print(f'   Crop bounds: x={crop_left}-{crop_right}, y={crop_top}-{crop_bottom}')
    
    # Create cropped VRT (same CRS as the source, only origin and size change)
    cropped_ortho = WarpedVRT(ortho_src,
        crs=ortho_src.crs,
        transform=crop_transform,
        width=crop_width,
        height=crop_height,
        resampling=Resampling.bilinear
    )
else:
    # No crop: the "cropped" dataset is the source itself.
    cropped_ortho = ortho_src
    crop_width, crop_height = ortho_src.width, ortho_src.height
    crop_transform = ortho_src.transform

# Step 2: Apply downsampling
if DOWNSAMPLE_FACTOR > 1:
    # Integer division drops any remainder pixels at the right/bottom edge.
    final_width = crop_width // DOWNSAMPLE_FACTOR
    final_height = crop_height // DOWNSAMPLE_FACTOR
    # Scaling the pixel size by the factor keeps the origin and enlarges each pixel.
    final_transform = crop_transform * crop_transform.scale(DOWNSAMPLE_FACTOR, DOWNSAMPLE_FACTOR)
    
    print(f'📉 Downsampling by factor {DOWNSAMPLE_FACTOR}')
    print(f'   Final size: {final_width} x {final_height} (was {crop_width} x {crop_height})')
    
    # Create final downsampled VRT (stacked on top of the crop VRT when cropping is on)
    ortho = WarpedVRT(cropped_ortho,
        crs=cropped_ortho.crs,
        transform=final_transform,
        width=final_width,
        height=final_height,
        resampling=Resampling.bilinear
    )
else:
    ortho = cropped_ortho
    final_width, final_height = crop_width, crop_height
    final_transform = crop_transform

print(f'🎯 Processing Ortho: {ortho.width} x {ortho.height}, bands={ortho.count}')
# NOTE(review): this message is printed regardless of the band count; nothing
# here checks that the ortho actually has at least 3 bands.
print(f'   Using first 3 bands (RGB) from {ortho.count}-band ortho image')

# Estimate processing time
# Rough tile count from the tile stride (WINDOW_SIZE - OVERLAP); the throughput of
# ~8 tiles/second is a hard-coded assumption, not a measurement.
estimated_tiles = math.ceil(final_width / (WINDOW_SIZE - OVERLAP)) * math.ceil(final_height / (WINDOW_SIZE - OVERLAP))
estimated_time_min = estimated_tiles / 8 / 60  # ~8 tiles/second
print(f'📊 Estimated: {estimated_tiles} tiles, ~{estimated_time_min:.1f} minutes processing time')

# Estimate memory usage
# Size of the whole 3-band grid as float32; informational only, since the
# inference loop reads one window at a time.
estimated_mb = (final_width * final_height * 3 * 4) / (1024 * 1024)  # 4 bytes per float32 pixel
print(f'💾 Estimated ortho memory per full read: {estimated_mb:.1f} MB')

# Create DSM VRT aligned to the final processing grid
# Same CRS, transform and size as `ortho`, so one Window addresses the same
# ground area in both datasets.
DSM_VRT_OPTS = dict(
    crs=ortho.crs,
    transform=final_transform,
    height=final_height,
    width=final_width,
    resampling=Resampling.bilinear,
)

dsm_src = rasterio.open(DSM_H)
dsm = WarpedVRT(dsm_src, **DSM_VRT_OPTS)
print(f'🗺️  DSM aligned to processing grid: {dsm.width} x {dsm.height}, bands={dsm.count}')

In [None]:
# ==== Prepare output GeoTIFF (tiled) ====
def _tiff_block_size(requested, extent):
    """Return a block edge length that is legal for a tiled GeoTIFF.

    Tiled GeoTIFF block dimensions must be multiples of 16. The requested size
    is capped at the raster extent, rounded down to a multiple of 16, and never
    drops below 16 (a block may be larger than a very small raster).
    """
    return max(16, (min(requested, extent) // 16) * 16)


# Create a clean profile (not from VRT) for writing, so no creation options,
# nodata value or band layout are inherited from the input rasters.
output_profile = {
    'driver': 'GTiff',
    'dtype': 'uint16',
    'count': 1,
    'width': ortho.width,
    'height': ortho.height,
    'crs': ortho.crs,
    'transform': ortho.transform,
    'tiled': True,
    'blockxsize': _tiff_block_size(WINDOW_SIZE, ortho.width),
    'blockysize': _tiff_block_size(WINDOW_SIZE, ortho.height),
    'compress': 'lzw',
    'bigtiff': 'YES'
}

# Start from a fresh file so stale data from an earlier run cannot survive.
if Path(OUTPUT_TIF).exists():
    Path(OUTPUT_TIF).unlink()

# `out` stays open on purpose: the inference cell writes into it tile by tile
# and is responsible for closing it.
out = rasterio.open(OUTPUT_TIF, 'w', **output_profile)
print(f'Output: {OUTPUT_TIF}')
print(f'Output dimensions: {output_profile["width"]} x {output_profile["height"]}')
print(f'Output blocks: {output_profile["blockxsize"]} x {output_profile["blockysize"]}')
print(f'Output transform: {output_profile["transform"]}')
print(f'Window size: {WINDOW_SIZE}, Overlap: {OVERLAP}')

In [None]:
# ==== Load model & weights ====
# Build the network, move it to the target device and switch to inference mode.
model = MFNet(num_classes=NUM_CLASSES)
model = model.to(DEVICE)
model = model.eval()

# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
ckpt = torch.load(CHECKPOINT, map_location=DEVICE)

# Use the 'model_state' entry when the checkpoint has one; otherwise treat the
# checkpoint object itself as the state dict.
state = ckpt.get('model_state', ckpt)

# strict=False means mismatched keys are reported instead of raising, so check
# the two printed lists to confirm the weights really matched.
load_report = model.load_state_dict(state, strict=False)
missing, unexpected = load_report
print('Missing keys:', missing)
print('Unexpected keys:', unexpected)

In [None]:
# ==== Streaming inference ====
# Reads one window at a time from the ortho and the aligned DSM, runs the model
# and writes the argmax class map straight into the open output GeoTIFF.
import traceback

# Materialize the window list so the reported tile count and the progress bar
# match the tiles actually processed (including OVERLAP and edge tiles).
windows = list(iter_windows(ortho.width, ortho.height, size=WINDOW_SIZE, overlap=OVERLAP))
total_tiles = len(windows)

print(f"Starting inference on {total_tiles} tiles...")
print(f"Image size: {ortho.width} x {ortho.height}, Window size: {WINDOW_SIZE}, Overlap: {OVERLAP}")
print(f"Device: {DEVICE}, CUDA available: {torch.cuda.is_available()}")

tile_count = 0     # tiles predicted and written successfully
failed_tiles = 0   # tiles skipped because of an error
try:
    with torch.no_grad(), tqdm(total=total_tiles, desc='Tiles') as pbar:
        for win in windows:
            try:
                # Read ortho tile (only first 3 bands: RGB)
                rgb = ortho.read([1, 2, 3], window=win, out_dtype='float32')  # (3,h,w)
                if rgb.shape[1] == 0 or rgb.shape[2] == 0:
                    pbar.update(1)
                    continue
                rgb = norm_rgb(rgb)

                # Read DSM tile (already aligned by WarpedVRT)
                d = dsm.read(window=win, out_dtype='float32')      # (count,h,w)
                if d.ndim == 2:
                    d = d[None, ...]
                # If DSM has more than 1 band, pick the first, else keep (1,h,w)
                dsm_tile = norm_dsm(d[:1, ...])

                # To torch: RGB becomes (1,3,h,w); the DSM drops its band axis
                # and gains a batch axis, giving (1,h,w).
                rgb_t = torch.from_numpy(rgb).unsqueeze(0).to(DEVICE)
                dsm_t = torch.from_numpy(dsm_tile).squeeze(0).unsqueeze(0).to(DEVICE)

                # Debug shapes for first tile
                if tile_count == 0:
                    print(f"First tile shapes - RGB: {rgb_t.shape}, DSM: {dsm_t.shape}")
                    print(f"Window: {win}")

                # Forward pass with RGB and DSM as separate inputs
                logits = model(rgb_t, dsm_t)

                # Debug output shape for first tile
                if tile_count == 0:
                    print(f"Output logits shape: {logits.shape}")

                pred = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy().astype('uint16')

                # Write this window. Where windows overlap, the tile written
                # later overwrites the earlier one (last-write-wins).
                out.write(pred, 1, window=win)

                # Clear GPU memory periodically
                if tile_count % BATCH_CLEAR_FREQ == 0 and torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    if tile_count > 0:
                        print(f"💾 Cleared GPU cache at tile {tile_count}, VRAM: {torch.cuda.memory_allocated()/1024**3:.1f}GB")

                tile_count += 1
                pbar.update(1)

            except Exception as e:
                # Best effort: report the failing tile and keep going so a single
                # bad window does not abort a run over the whole raster.
                failed_tiles += 1
                print(f"Error processing tile {tile_count} at window {win}: {e}")
                print(f"  Window bounds: {win.col_off}-{win.col_off + win.width}, {win.row_off}-{win.row_off + win.height}")
                print(f"  Image bounds: 0-{ortho.width}, 0-{ortho.height}")
                print(f"  Full error: {traceback.format_exc()}")
                pbar.update(1)
                continue
finally:
    # Close files even if the run is interrupted, so the tiles written so far
    # are flushed to disk. ortho_src is left alone here because the optional
    # upsampling cell still reads its size and georeferencing.
    out.close()
    dsm.close()
    dsm_src.close()
    ortho.close()

print(f'✅ Done. Processed {tile_count} tiles. Saved: {OUTPUT_TIF}')
if failed_tiles:
    # Failed tiles were never written to the output -- inspect the errors above.
    print(f'⚠️ {failed_tiles} of {total_tiles} tiles failed and were not written.')

## 🚀 Performance Optimization Guide for RTX 4070 (12GB VRAM)

### Quick Settings for Different Scenarios:

**🧪 Fast Testing (2-3 minutes):**
```python
DOWNSAMPLE_FACTOR = 4  # Quarter resolution
WINDOW_SIZE = 256      # Small tiles
OVERLAP = 0            # No overlap
```

**⚡ Balanced Performance (~10-15 minutes):**
```python
DOWNSAMPLE_FACTOR = 2  # Half resolution  
WINDOW_SIZE = 256      # Medium tiles
OVERLAP = 32           # Moderate overlap
```

**🎯 High Quality (30-60 minutes):**
```python
DOWNSAMPLE_FACTOR = 1  # Full resolution
WINDOW_SIZE = 256      # Small tiles for memory safety
OVERLAP = 64           # Good overlap
```

### Expected Processing Times:
- **Original 9GB ortho**: ~2-4 hours (67447 x 72998 pixels)
- **Downsampled 2x**: ~15-30 minutes (33723 x 36499 pixels) 
- **Downsampled 4x**: ~5-10 minutes (16861 x 18249 pixels)

### Memory Usage:
- **RTX 4070**: 12GB VRAM available
- **Model**: ~2-3GB VRAM
- **Per tile (256x256)**: ~10-20MB
- **Safe concurrent tiles**: 1 (sequential processing)

### Tips:
- Start with `DOWNSAMPLE_FACTOR = 4` for initial testing
- Monitor GPU memory with `nvidia-smi`
- If you get OOM errors, reduce `WINDOW_SIZE` to 128
- The final prediction can be upsampled back to original resolution if needed


In [None]:
# ==== Optional: Upsample prediction back to original resolution ====
# Resamples the downsampled class map onto the original ortho grid with
# nearest-neighbour resampling (class IDs must not be interpolated). The work is
# streamed block by block, so the full-resolution raster is never held in memory.
if DOWNSAMPLE_FACTOR > 1:
    print(f"🔍 Upsampling prediction from {DOWNSAMPLE_FACTOR}x downsampled back to original resolution...")

    upsampled_tif = f'prediction_fullres_from_ds{DOWNSAMPLE_FACTOR}.tif'

    # Clean profile on the original ortho grid. Built from scratch rather than
    # copied from the ortho so its nodata value and creation options are not
    # inherited by the single-band class map.
    upsampled_profile = {
        'driver': 'GTiff',
        'dtype': 'uint16',
        'count': 1,
        'width': ortho_src.width,
        'height': ortho_src.height,
        'crs': ortho_src.crs,
        'transform': ortho_src.transform,
        'tiled': True,
        'blockxsize': 512,
        'blockysize': 512,
        'compress': 'lzw',
        'bigtiff': 'YES',
    }

    with rasterio.open(OUTPUT_TIF) as pred_ds:
        # The VRT presents the prediction on the ortho grid; pixels the
        # prediction does not cover (e.g. outside a crop) are left to the
        # VRT's fill value.
        with WarpedVRT(
            pred_ds,
            crs=ortho_src.crs,
            transform=ortho_src.transform,
            width=ortho_src.width,
            height=ortho_src.height,
            resampling=Resampling.nearest,
        ) as pred_fullres:
            with rasterio.open(upsampled_tif, 'w', **upsampled_profile) as upsampled_ds:
                # 512-pixel windows without overlap line up with the output blocks.
                for block_win in iter_windows(ortho_src.width, ortho_src.height, size=512, overlap=0):
                    block = pred_fullres.read(1, window=block_win)
                    upsampled_ds.write(block, 1, window=block_win)

    print(f"✅ Upsampled prediction saved: {upsampled_tif}")
    print(f"Original resolution: {ortho_src.width} x {ortho_src.height}")
else:
    print("No upsampling needed - already at full resolution")


In [None]:
# ==== Check Prediction Results ====
# Prints file metadata and the per-class pixel distribution of the prediction.
# NOTE: this reads the entire prediction band into memory.
print("🔍 Checking prediction file...")

if not Path(OUTPUT_TIF).exists():
    print(f"❌ Prediction file not found: {OUTPUT_TIF}")
    print("Make sure inference completed successfully")
else:
    # Size on disk (compressed), in MiB
    file_size_mb = Path(OUTPUT_TIF).stat().st_size / (1024 * 1024)
    print(f"✅ Prediction file exists: {OUTPUT_TIF}")
    print(f"📁 File size: {file_size_mb:.1f} MB")

    with rasterio.open(OUTPUT_TIF) as pred:
        # Dataset-level metadata
        print(f"🗺️ Dimensions: {pred.width} x {pred.height}")
        print(f"🎯 CRS: {pred.crs}")
        print(f"📊 Data type: {pred.dtypes[0]}")
        print(f"🔢 Bands: {pred.count}")

        prediction_data = pred.read(1)

        # One pass yields both the distinct class IDs and their pixel counts.
        unique, counts = np.unique(prediction_data, return_counts=True)
        total_pixels = prediction_data.size

        print(f"\n📈 Prediction Statistics:")
        print(f"   Shape: {prediction_data.shape}")
        print(f"   Min value: {prediction_data.min()}")
        print(f"   Max value: {prediction_data.max()}")
        print(f"   Unique classes: {unique}")

        # Share of the raster covered by each class
        percentages = (counts / total_pixels) * 100
        print(f"\n🎨 Class Distribution:")
        for class_id, count, percentage in zip(unique, counts, percentages):
            print(f"   Class {class_id}: {count:,} pixels ({percentage:.1f}%)")


In [None]:
# ==== Visualize Prediction Results ====
# Renders the class map next to a per-class histogram, saves the figure as a PNG
# beside the prediction, and adds a zoomed centre crop for larger rasters.
# NOTE: the whole prediction band is read into memory and drawn at once.
if Path(OUTPUT_TIF).exists():
    print("🎨 Creating visualizations...")
    
    with rasterio.open(OUTPUT_TIF) as pred:
        prediction_data = pred.read(1)
        
        # Create figure with subplots
        fig, axes = plt.subplots(1, 2, figsize=(15, 6))
        
        # Plot 1: Raw prediction classes
        # vmin/vmax pin the colour scale to the full class range, so colours stay
        # stable even when some classes are absent from this prediction.
        im1 = axes[0].imshow(prediction_data, cmap='tab10', vmin=0, vmax=NUM_CLASSES-1)
        axes[0].set_title(f'Prediction Classes\n({prediction_data.shape[0]}×{prediction_data.shape[1]} pixels)')
        axes[0].set_xlabel('X')
        axes[0].set_ylabel('Y')
        
        # Add colorbar
        cbar1 = plt.colorbar(im1, ax=axes[0], shrink=0.8)
        cbar1.set_label('Class ID')
        cbar1.set_ticks(range(NUM_CLASSES))
        
        # Plot 2: Class histogram
        unique, counts = np.unique(prediction_data, return_counts=True)
        # Sample tab10 at NUM_CLASSES evenly spaced positions; intended to give
        # each bar the colour its class has in the image on the left.
        colors = plt.cm.tab10(np.linspace(0, 1, NUM_CLASSES))
        # NOTE(review): colors[int(u)] assumes every class ID is < NUM_CLASSES.
        bars = axes[1].bar(unique, counts, color=[colors[int(u)] for u in unique])
        axes[1].set_title('Class Distribution')
        axes[1].set_xlabel('Class ID')
        axes[1].set_ylabel('Pixel Count')
        axes[1].set_xticks(range(NUM_CLASSES))
        
        # Add percentage labels on bars
        total_pixels = prediction_data.size
        for bar, count in zip(bars, counts):
            height = bar.get_height()
            percentage = (count / total_pixels) * 100
            axes[1].text(bar.get_x() + bar.get_width()/2., height,
                        f'{percentage:.1f}%', ha='center', va='bottom', fontsize=9)
        
        plt.tight_layout()
        
        # Save visualization (must happen before plt.show(), while this figure is
        # still the current one)
        viz_filename = OUTPUT_TIF.replace('.tif', '_visualization.png')
        plt.savefig(viz_filename, dpi=150, bbox_inches='tight')
        print(f"💾 Visualization saved: {viz_filename}")
        
        plt.show()
        
        # Optional: Show a zoomed section if image is large
        if prediction_data.shape[0] > 500 or prediction_data.shape[1] > 500:
            print("\n🔍 Showing center 256×256 crop for detail:")
            center_y, center_x = prediction_data.shape[0]//2, prediction_data.shape[1]//2
            # Half-width of the crop: 128 pixels either side of the centre = 256 total.
            crop_size = 128
            crop = prediction_data[center_y-crop_size:center_y+crop_size, 
                                center_x-crop_size:center_x+crop_size]
            
            plt.figure(figsize=(8, 6))
            plt.imshow(crop, cmap='tab10', vmin=0, vmax=NUM_CLASSES-1)
            plt.title(f'Center Crop Detail ({crop.shape[0]}×{crop.shape[1]} pixels)')
            plt.colorbar(label='Class ID')
            crop_viz_filename = OUTPUT_TIF.replace('.tif', '_center_crop.png')
            plt.savefig(crop_viz_filename, dpi=150, bbox_inches='tight')
            print(f"💾 Center crop saved: {crop_viz_filename}")
            plt.show()

else:
    print("❌ No prediction file to visualize")


## 🔍 Other Ways to Check Your Prediction File

### 1. **QGIS (Recommended for GIS analysis)**
```bash
# Open QGIS and drag-drop the prediction file
# Or use command line:
qgis prediction_crop5000_ds4_ws512_ov0.tif
```
**Benefits:**
- Geographic context with basemaps
- Advanced styling and symbology
- Measurement tools
- Export to other formats

### 2. **Command Line Tools**
```bash
# File information
gdalinfo prediction_crop5000_ds4_ws512_ov0.tif

# Quick statistics
gdalinfo -stats prediction_crop5000_ds4_ws512_ov0.tif

# Convert to different format
gdal_translate -of PNG prediction_crop5000_ds4_ws512_ov0.tif prediction.png
```

### 3. **Python (Alternative Visualization)**
```python
import rasterio
import matplotlib.pyplot as plt

# Quick preview
with rasterio.open('prediction_crop5000_ds4_ws512_ov0.tif') as src:
    data = src.read(1)
    plt.figure(figsize=(10, 8))
    plt.imshow(data, cmap='tab10')
    plt.colorbar(label='Class ID')
    plt.title('Prediction Results')
    plt.show()
```

### 4. **Web Viewers**
- **Drag into Google Earth Pro** (if georeferenced)
- **Use GDAL Web Viewer** for quick preview
- **Upload to GIS platforms** like ArcGIS Online

### 📊 **What to Look For:**
- **Class distribution**: Are all 6 classes present?
- **Spatial patterns**: Do predictions make geographic sense?
- **Edge effects**: Any artifacts at tile boundaries?
- **Data quality**: No missing/corrupt areas?


## 🏢 Building-Specific Analysis Guide

The next cells provide comprehensive building detection analysis. Here's what you'll get:

### 📊 **Statistical Analysis:**
- **Building coverage percentage** in your area
- **Total building area** in m² and km²
- **Number of building clusters** (connected building regions)
- **Cluster size statistics** (largest, smallest, average buildings)
- **Building density patterns**

### 🎨 **Visualizations:**
1. **All classes with building highlights** (red outlines)
2. **Buildings-only binary mask** (red = buildings)
3. **Individual building clusters** (colored by cluster)
4. **Building density heatmap** (intensity map)
5. **Cluster size distribution** (histogram)
6. **Cumulative building area** (largest contributors)

### 💾 **Export Files:**
- `*_buildings_only.tif` → Binary mask (1=building, 0=other)
- `*_building_footprints.tif` → Buildings with original class IDs
- `*_building_analysis.png` → 4-panel visualization
- `*_building_stats.png` → Statistical plots

### 🔧 **Configuration Required:**
**You MUST set the correct building class ID in the next cell!**

1. **First**: Run the general analysis (previous cells) to see all class distributions
2. **Identify**: Which class ID(s) represent buildings (usually 5-20% of pixels)  
3. **Update**: `BUILDING_CLASS_ID = X` in the next cell
4. **Run**: Building analysis cells

### 💡 **Expected Building Percentages by Area Type:**
- **Urban areas**: 15-30% building coverage
- **Suburban areas**: 5-15% building coverage  
- **Rural areas**: 1-5% building coverage
- **Dense cities**: 30-50% building coverage


In [None]:
# ==== Quick Command Line Checks ====
# Tries the external `gdalinfo` tool first, then always prints an equivalent
# summary through rasterio so the cell is useful without the GDAL CLI.
if not Path(OUTPUT_TIF).exists():
    print(f"❌ Prediction file not found: {OUTPUT_TIF}")
else:
    print("🖥️ Running quick command line checks...")

    import subprocess

    try:
        # List-form arguments (no shell); give up after 30 seconds.
        result = subprocess.run(['gdalinfo', OUTPUT_TIF],
                                capture_output=True, text=True, timeout=30)
    except (subprocess.TimeoutExpired, FileNotFoundError):
        print("❌ gdalinfo not available or timed out")
    else:
        if result.returncode != 0:
            print("❌ gdalinfo failed or not available")
        else:
            print("✅ GDAL Info:")
            all_lines = result.stdout.split('\n')
            # Show the non-blank lines among the first 20 lines of output
            lines = all_lines[:20]
            for line in lines:
                if line.strip():
                    print(f"   {line}")
            if len(all_lines) > 20:
                print("   ... (output truncated)")

    # Alternative: Use rasterio to show similar info
    print(f"\n📊 File Summary:")
    with rasterio.open(OUTPUT_TIF) as src:
        print(f"   File: {OUTPUT_TIF}")
        print(f"   Size: {src.width} x {src.height}")
        print(f"   Bands: {src.count}")
        print(f"   Data Type: {src.dtypes[0]}")
        print(f"   CRS: {src.crs}")
        print(f"   Bounds: {src.bounds}")
        print(f"   Resolution: {src.res}")

        # Peek at the top-left corner only (at most 100 x 100 pixels)
        peek_width = min(100, src.width)
        peek_height = min(100, src.height)
        sample = src.read(1, window=rasterio.windows.Window(0, 0, peek_width, peek_height))
        print(f"   Value range: {sample.min()} - {sample.max()}")
        print(f"   Sample values: {np.unique(sample)[:10]}...")  # First 10 unique values


In [None]:
# ==== Building-Specific Analysis ====
# Computes coverage, area and connected-cluster statistics for the class ID(s)
# that represent buildings, then prints the full class distribution for context.
if Path(OUTPUT_TIF).exists():
    print("🏢 Analyzing Building Predictions...")
    
    # !! CONFIGURE YOUR BUILDING CLASS ID(S) HERE !!
    # First, run the basic analysis to see which class(es) represent buildings
    # Look at the class distribution and identify which class ID(s) are buildings
    
    BUILDING_CLASS_ID = 1  # 🔧 CHANGE THIS to your building class ID (0-5)
    
    # Options for different building types:
    # BUILDING_CLASS_IDS = [1]        # Single building class
    # BUILDING_CLASS_IDS = [1, 2]     # Multiple building types (e.g., residential, commercial)  
    # BUILDING_CLASS_IDS = [3, 4, 5]  # If buildings are classes 3, 4, 5
    
    BUILDING_CLASS_IDS = [BUILDING_CLASS_ID]  # Single building class
    
    # 💡 If you're unsure which class is buildings:
    # 1. Run the general prediction analysis first (previous cells)
    # 2. Look at the class distribution percentages
    # 3. Identify which class(es) likely represent buildings based on:
    #    - Reasonable percentage (usually 5-20% for buildings)
    #    - Spatial patterns that look like buildings
    # 4. Update BUILDING_CLASS_ID above and re-run this analysis
    
    print(f"🎯 Building class ID(s): {BUILDING_CLASS_IDS}")
    
    with rasterio.open(OUTPUT_TIF) as pred:
        prediction_data = pred.read(1)
        
        # Create building mask (binary: building vs non-building)
        building_mask = np.isin(prediction_data, BUILDING_CLASS_IDS)
        
        # Calculate building statistics
        total_pixels = prediction_data.size
        building_pixels = np.sum(building_mask)
        building_percentage = (building_pixels / total_pixels) * 100
        
        print(f"\n📊 Building Statistics:")
        print(f"   Total pixels: {total_pixels:,}")
        print(f"   Building pixels: {building_pixels:,}")
        print(f"   Building coverage: {building_percentage:.2f}%")
        
        # Estimate building area from the pixel resolution.
        # NOTE(review): the result is in square metres only if the CRS units are
        # metres; for a geographic CRS (degrees) these numbers are meaningless.
        pixel_area_m2 = abs(pred.res[0] * pred.res[1])
        building_area_m2 = building_pixels * pixel_area_m2
        building_area_km2 = building_area_m2 / 1_000_000
        
        print(f"   Pixel resolution: {pred.res[0]:.2f}m x {pred.res[1]:.2f}m")
        print(f"   Building area: {building_area_m2:,.0f} m² ({building_area_km2:.3f} km²)")
        
        # Find building clusters/regions (connected components of the mask)
        from scipy import ndimage
        labeled_buildings, num_buildings = ndimage.label(building_mask)
        
        if num_buildings > 0:
            # Pixel count per cluster in a single pass over the label image:
            # index 0 is the background, indices 1..num_buildings are clusters.
            cluster_sizes = np.bincount(labeled_buildings.ravel())[1:]
            
            print(f"\n🏘️ Building Clusters:")
            print(f"   Number of building clusters: {num_buildings}")
            print(f"   Largest cluster: {cluster_sizes.max()} pixels ({cluster_sizes.max() * pixel_area_m2:.0f} m²)")
            print(f"   Smallest cluster: {cluster_sizes.min()} pixels ({cluster_sizes.min() * pixel_area_m2:.0f} m²)")
            print(f"   Average cluster size: {cluster_sizes.mean():.1f} pixels ({cluster_sizes.mean() * pixel_area_m2:.0f} m²)")
        else:
            print(f"   ❌ No building clusters detected")
            
        # Show class distribution for context
        unique, counts = np.unique(prediction_data, return_counts=True)
        print(f"\n📈 Full Class Distribution:")
        # Placeholder names: only the building label is a guess, the others are
        # unknown until matched against the training label scheme.
        class_names = {
            0: "Background/Other",
            1: "Buildings(?)",  # Mark uncertain
            2: "Class 2", 
            3: "Class 3",
            4: "Class 4", 
            5: "Class 5"
        }
        
        for class_id, count in zip(unique, counts):
            percentage = (count / total_pixels) * 100
            name = class_names.get(class_id, f"Class {class_id}")
            is_building = "🏢" if class_id in BUILDING_CLASS_IDS else "  "
            print(f"   {is_building} Class {class_id} ({name}): {count:,} pixels ({percentage:.1f}%)")

else:
    print("❌ No prediction file found. Run inference first.")


In [None]:
# ==== Building-Focused Visualizations ====
if Path(OUTPUT_TIF).exists():
    print("🎨 Creating building-focused visualizations...")
    
    with rasterio.open(OUTPUT_TIF) as pred:
        prediction_data = pred.read(1)
        building_mask = np.isin(prediction_data, BUILDING_CLASS_IDS)
        
        # Create comprehensive building visualization
        fig, axes = plt.subplots(2, 2, figsize=(16, 12))
        fig.suptitle('Building Detection Analysis', fontsize=16, fontweight='bold')
        
        # Plot 1: All classes with buildings highlighted
        im1 = axes[0,0].imshow(prediction_data, cmap='tab10', vmin=0, vmax=NUM_CLASSES-1)
        axes[0,0].set_title('All Classes (Buildings Highlighted)')
        axes[0,0].set_xlabel('X (pixels)')
        axes[0,0].set_ylabel('Y (pixels)')
        
        # Overlay building outlines
        from scipy import ndimage
        building_edges = ndimage.binary_erosion(building_mask) ^ building_mask
        if np.any(building_edges):
            axes[0,0].contour(building_edges, colors='red', linewidths=1.5, alpha=0.8)
        
        cbar1 = plt.colorbar(im1, ax=axes[0,0], shrink=0.8)
        cbar1.set_label('Class ID')
        
        # Plot 2: Buildings only (binary mask)
        building_display = building_mask.astype(float)
        building_display[~building_mask] = np.nan  # Transparent non-buildings
        
        im2 = axes[0,1].imshow(building_display, cmap='Reds', vmin=0, vmax=1)
        axes[0,1].set_title('Buildings Only (Binary Mask)')
        axes[0,1].set_xlabel('X (pixels)')
        axes[0,1].set_ylabel('Y (pixels)')
        
        cbar2 = plt.colorbar(im2, ax=axes[0,1], shrink=0.8)
        cbar2.set_label('Building (1) / Non-building (0)')
        
        # Plot 3: Building clusters with labels
        labeled_buildings, num_buildings = ndimage.label(building_mask)
        
        if num_buildings > 0:
            im3 = axes[1,0].imshow(labeled_buildings, cmap='tab20', vmin=0)
            axes[1,0].set_title(f'Building Clusters (n={num_buildings})')
            axes[1,0].set_xlabel('X (pixels)')
            axes[1,0].set_ylabel('Y (pixels)')
            
            # Add cluster size annotations for largest clusters
            cluster_sizes = []
            cluster_centers = []
            for i in range(1, min(6, num_buildings + 1)):  # Show top 5 clusters
                cluster_pixels = (labeled_buildings == i)
                if np.any(cluster_pixels):
                    size = np.sum(cluster_pixels)
                    cluster_sizes.append(size)
                    
                    # Find cluster center
                    y_coords, x_coords = np.where(cluster_pixels)
                    center_y, center_x = np.mean(y_coords), np.mean(x_coords)
                    cluster_centers.append((center_x, center_y))
                    
                    # Annotate largest clusters
                    if i <= 3:  # Top 3 clusters
                        axes[1,0].annotate(f'{size}px', (center_x, center_y), 
                                         color='white', fontweight='bold', 
                                         ha='center', va='center',
                                         bbox=dict(boxstyle='round,pad=0.3', facecolor='black', alpha=0.7))
        else:
            axes[1,0].text(0.5, 0.5, 'No building clusters detected', 
                          transform=axes[1,0].transAxes, ha='center', va='center', 
                          fontsize=12, bbox=dict(boxstyle='round', facecolor='lightgray'))
            axes[1,0].set_title('Building Clusters')
        
        # Plot 4: Building density heatmap
        from scipy.ndimage import gaussian_filter
        
        # Create density map using Gaussian smoothing
        density_map = gaussian_filter(building_mask.astype(float), sigma=10)
        
        im4 = axes[1,1].imshow(density_map, cmap='YlOrRd', vmin=0, vmax=density_map.max())
        axes[1,1].set_title('Building Density Heatmap')
        axes[1,1].set_xlabel('X (pixels)')
        axes[1,1].set_ylabel('Y (pixels)')
        
        cbar4 = plt.colorbar(im4, ax=axes[1,1], shrink=0.8)
        cbar4.set_label('Building Density')
        
        plt.tight_layout()
        
        # Save building analysis
        building_viz_file = OUTPUT_TIF.replace('.tif', '_building_analysis.png')
        plt.savefig(building_viz_file, dpi=150, bbox_inches='tight')
        print(f"💾 Building analysis saved: {building_viz_file}")
        
        plt.show()
        
        # Create building summary statistics plot
        if num_buildings > 0:
            plt.figure(figsize=(12, 4))
            
            # Subplot 1: Cluster size distribution
            plt.subplot(1, 2, 1)
            cluster_sizes = [np.sum(labeled_buildings == i) for i in range(1, num_buildings + 1)]
            plt.hist(cluster_sizes, bins=min(20, num_buildings), edgecolor='black', alpha=0.7, color='skyblue')
            plt.xlabel('Cluster Size (pixels)')
            plt.ylabel('Number of Clusters')
            plt.title('Building Cluster Size Distribution')
            plt.grid(True, alpha=0.3)
            
            # Subplot 2: Cumulative building area
            sorted_sizes = sorted(cluster_sizes, reverse=True)
            cumulative_area = np.cumsum(sorted_sizes)
            cumulative_percent = (cumulative_area / np.sum(cluster_sizes)) * 100
            
            plt.subplot(1, 2, 2)
            plt.plot(range(1, len(sorted_sizes) + 1), cumulative_percent, marker='o', markersize=3)
            plt.xlabel('Cluster Rank (largest to smallest)')
            plt.ylabel('Cumulative Building Area (%)')
            plt.title('Cumulative Building Coverage')
            plt.grid(True, alpha=0.3)
            
            # Add 80% line
            plt.axhline(y=80, color='red', linestyle='--', alpha=0.7, label='80% threshold')
            plt.legend()
            
            plt.tight_layout()
            
            cluster_stats_file = OUTPUT_TIF.replace('.tif', '_building_stats.png')
            plt.savefig(cluster_stats_file, dpi=150, bbox_inches='tight')
            print(f"💾 Building statistics saved: {cluster_stats_file}")
            
            plt.show()

else:
    print("❌ No prediction file found for building analysis.")


In [None]:
# ==== Export Building-Only Results ====
# Writes two single-band uint8 GeoTIFFs next to OUTPUT_TIF, reusing its CRS and transform:
#   <name>_buildings_only<ext>      -> 1 where the predicted class is in BUILDING_CLASS_IDS, else 0
#   <name>_building_footprints<ext> -> original class ID on building pixels, 0 elsewhere
# NOTE: the whole prediction band is read into memory here (no windowed streaming).
if Path(OUTPUT_TIF).exists():
    print("💾 Exporting building-specific results...")

    # Derive sibling paths from the file extension only. str.replace('.tif', ...) would
    # rewrite every '.tif' occurrence in the path (directories included) and would return
    # the path unchanged for e.g. '.TIF', which would open OUTPUT_TIF itself for writing.
    output_root, output_ext = os.path.splitext(OUTPUT_TIF)
    building_output = f"{output_root}_buildings_only{output_ext}"
    footprints_output = f"{output_root}_building_footprints{output_ext}"

    with rasterio.open(OUTPUT_TIF) as pred:
        prediction_data = pred.read(1)
        building_mask = np.isin(prediction_data, BUILDING_CLASS_IDS)

        # Tiled GeoTIFF block dimensions must be multiples of 16, so round the
        # "at most 512, at most the raster size" choice down to a multiple of 16
        # (with a floor of 16 for very small rasters).
        block_width = max(16, min(512, pred.width) // 16 * 16)
        block_height = max(16, min(512, pred.height) // 16 * 16)

        # Shared profile for both outputs
        building_profile = {
            'driver': 'GTiff',
            'dtype': 'uint8',  # Binary data, uint8 is sufficient
            'count': 1,
            'width': pred.width,
            'height': pred.height,
            'crs': pred.crs,
            'transform': pred.transform,
            'compress': 'lzw',
            'tiled': True,
            'blockxsize': block_width,
            'blockysize': block_height,
        }

        # Create building-only GeoTIFF (binary: 1=building, 0=non-building)
        with rasterio.open(building_output, 'w', **building_profile) as building_dst:
            building_dst.write(building_mask.astype('uint8'), 1)

        print(f"✅ Building mask saved: {building_output}")

        # Create building footprints GeoTIFF (buildings with original class IDs, others = 0)
        building_footprints = prediction_data.copy()
        building_footprints[~building_mask] = 0  # Set non-buildings to 0

        with rasterio.open(footprints_output, 'w', **building_profile) as footprints_dst:
            footprints_dst.write(building_footprints.astype('uint8'), 1)

        print(f"✅ Building footprints saved: {footprints_output}")

        # Summary of exported files
        print(f"\n📁 Building Analysis Files Created:")
        print(f"   🗺️  Original prediction: {OUTPUT_TIF}")
        print(f"   🏠 Building mask (binary): {building_output}")
        print(f"   🏢 Building footprints (with class IDs): {footprints_output}")

        # Check file sizes
        for filepath in [OUTPUT_TIF, building_output, footprints_output]:
            if Path(filepath).exists():
                size_mb = Path(filepath).stat().st_size / (1024 * 1024)
                print(f"      {Path(filepath).name}: {size_mb:.1f} MB")

        print(f"\n💡 Usage Tips:")
        print(f"   • Load {building_output} in QGIS for building overlay analysis")
        print(f"   • Use {footprints_output} to distinguish building types")
        print(f"   • Both files maintain original georeference and can be used in GIS")

else:
    print("❌ No prediction file found for building export.")


## Notes
- For **IRRG ortho** (NIR,R,G), keep the same band order you trained with, or reorder before `norm_rgb`.
- If you want overlapped tiles with smoothing, set `OVERLAP>0` and implement feathering. For simplicity (and to save memory), this notebook uses **last-write-wins** on overlaps.
- If your DSM is noisy, consider a **median filter** or **bilateral filter** per tile before `norm_dsm`.
- To visualize a small area, use `rasterio.plot.show` on a window or open `OUTPUT_TIF` in QGIS.

In [None]:
# ==== Simple upsampling if needed ====
# This cell only prints a hint; it does not resample anything itself.
if not DOWNSAMPLE_FACTOR > 1:
    print("Already at full resolution - no upsampling needed")
else:
    # The suggested gdalwarp call takes its target resolution (-tr) from elements 0 and 4
    # of the ortho transform (presumably pixel width / height — confirm against ortho_src)
    # and passes `-r near` as the resampling option.
    print(f"To upsample back to original resolution, use:")
    print(f"gdalwarp -tr {ortho_src.transform[0]} {abs(ortho_src.transform[4])} -r near {OUTPUT_TIF} prediction_fullres.tif")
