In [None]:
 pip install numpy pandas matplotlib seaborn scikit-learn scipy geopandas rasterio shapely pyproj pycraf pvlib geopy requests joblib plotly contextily pulp ipywidgets ipyleaflet pymoo astropy

In [None]:
pip install python-dotenv

In [None]:
# ===========================================================
# TITLE: Wireless Sensor Network Gateway Placement
# ===========================================================

# 1. ===== GLOBAL IMPORTS, CONFIG, AND SEEDING =====
import os, random, warnings, json
from typing import List, Tuple, Dict, Optional, Any
from dataclasses import dataclass, field

import numpy as np
import pandas as pd

import geopandas as gpd
from shapely.geometry import Point, MultiPoint, Polygon, shape, box
from shapely.ops import unary_union
from geopy.distance import geodesic

import rasterio
from rasterio.io import MemoryFile
from rasterio.mask import mask
from pyproj import Geod

import pvlib
from pvlib.location import Location
from pvlib.pvsystem import PVSystem
from pvlib.modelchain import ModelChain

import pulp
from itertools import combinations

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import contextily as cx

from ipyleaflet import Map, DrawControl, basemaps, Marker, GeoJSON, Popup, CircleMarker, Polyline, Circle
from ipywidgets import Button, Output, VBox, HBox, HTML, FloatSlider, IntSlider, Dropdown, Layout, Tab, SelectMultiple

import requests
from astropy import units as u

from pymoo.core.problem import Problem
from pymoo.algorithms.moo.nsga3 import NSGA3
from pymoo.util.ref_dirs import get_reference_directions
from pymoo.optimize import minimize
from pymoo.visualization.scatter import Scatter
from pymoo.termination.default import DefaultMultiObjectiveTermination

import seaborn as sns

import plotly.io as pio
# Render plotly figures with the VS Code notebook renderer.
# NOTE(review): other front-ends (JupyterLab, nbviewer, Colab) may need a different renderer — confirm.
pio.renderers.default = 'vscode'

# Silence every warning process-wide.
# NOTE(review): this also hides numeric RuntimeWarnings (e.g. NaN / empty-slice means in the solar
# diagnostics) — consider narrowing to specific categories or modules.
warnings.filterwarnings('ignore')

# Seed both the stdlib and the legacy NumPy global RNGs so random draws
# (e.g. sample_points_in_poly, which uses `random.uniform`) repeat across runs.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
print(f"Random seed set: {RANDOM_SEED}")

@dataclass
class NetworkConfig:
    """Network-geometry constants shared by the placement code.

    Fields
    ------
    MAX_COMM_RANGE_M : presumably the longest gateway↔sensor link considered, in metres.
    GATEWAY_HEIGHT   : gateway antenna height — presumably metres above ground; confirm.
    SENSOR_HEIGHT    : sensor antenna height — presumably metres above ground; confirm.
    """
    MAX_COMM_RANGE_M: float = field(default=2500.0)
    GATEWAY_HEIGHT: float = field(default=5.0)
    SENSOR_HEIGHT: float = field(default=5.0)

@dataclass
class OptimizationConfig:
    """Optimisation sizes plus the radio link-budget inputs.

    Fields
    ------
    NUM_GATEWAY_CANDIDATES : number of candidate gateway sites to generate.
    COVERAGE_GRID_SIZE     : side length of the coverage grid (presumably passed as
                             generate_coverage_grid's ``grid_size``; confirm).
    TX_POWER_DBM           : transmit power (dBm).
    RX_SENSITIVITY_DBM     : receiver sensitivity (dBm, negative).
    ANTENNA_GAIN_DBI       : gain of each antenna (dBi); counted once per link end.
    SYSTEM_MARGIN_DB       : fade / implementation margin (dB).
    """
    NUM_GATEWAY_CANDIDATES: int = field(default=550)
    COVERAGE_GRID_SIZE: int = field(default=40)
    TX_POWER_DBM: float = field(default=10.0)
    RX_SENSITIVITY_DBM: float = field(default=-125.0)
    ANTENNA_GAIN_DBI: float = field(default=10.0)
    SYSTEM_MARGIN_DB: float = field(default=10.0)

    @property
    def MAX_ALLOWABLE_PATH_LOSS_DB(self) -> float:
        """Largest path loss (dB) a link can absorb.

        Tx power + gain of both antennas − Rx sensitivity − system margin.
        """
        tx_plus_gains = self.TX_POWER_DBM + self.ANTENNA_GAIN_DBI * 2
        return tx_plus_gains - self.RX_SENSITIVITY_DBM - self.SYSTEM_MARGIN_DB

NETWORK, OPTIMIZATION = NetworkConfig(), OptimizationConfig()

# LoRaWAN link budget: the largest path loss (dB) a link can absorb, taken straight from
# OptimizationConfig.MAX_ALLOWABLE_PATH_LOSS_DB and kept as a module-level constant.
LORAWAN_BUDGET = OPTIMIZATION.MAX_ALLOWABLE_PATH_LOSS_DB


print(f"LoRaWAN Link Budget: {LORAWAN_BUDGET} dB")
print(f"Updated MAX_ALLOWABLE_PATH_LOSS_DB: {OPTIMIZATION.MAX_ALLOWABLE_PATH_LOSS_DB} dB")

# On-disk cache folder for this project (also listed in clear_all_cache_dirs).
# joblib.Memory memoization is currently disabled; to re-enable it, create
# `memory = joblib.Memory(cachedir, verbose=0)` here.
cachedir = './.wsn_cache'
def clear_all_cache_dirs(cache_dirs: Optional[List[str]] = None,
                         include_checkpoints: bool = False):
    """Delete on-disk cache directories (best effort).

    Parameters
    ----------
    cache_dirs : directories to remove. Defaults to the project's known cache folders.
    include_checkpoints : when using the default list, also remove ``./.ipynb_checkpoints``.
        Off by default: that folder holds Jupyter's "revert to checkpoint" snapshots of the
        notebook itself, not cached data, so wiping it on every run destroys recovery points.

    Missing directories are skipped silently. Any other error is reported and the sweep
    continues with the next directory.
    """
    import shutil

    if cache_dirs is None:
        cache_dirs = [
            './.sensor_cache', './.joblib_cache', './.wsn_cache',
            './.phase1_cache', './.phase2_cache', './.dem_cache', './.cache',
            './output_cache', './.pymoo_cache',
            './.tmp_cache', './.memory_cache',
        ]
        if include_checkpoints:
            cache_dirs.append('./.ipynb_checkpoints')

    for d in cache_dirs:
        try:
            shutil.rmtree(d)
            print(f"🧹 Cleared {d} directory.")
        except FileNotFoundError:
            pass
        except Exception as e:
            print(f"Could not clear {d}: {e}")
    print("All listed cache folders cleared (if they existed).")

clear_all_cache_dirs()

# joblib.Memory is optional: only clear it when a `memory` object actually exists,
# instead of relying on a NameError being caught and misreported as a failure.
if "memory" in globals():
    try:
        memory.clear(warn=False)
        print("Cleared joblib.Memory cache successfully.")
    except Exception as e:
        print("Cache clear failed or not needed:", e)
else:
    print("No joblib.Memory cache configured; nothing to clear.")

# Notebook-wide mutable state, presumably filled in by later cells
# (e.g. generate_coverage_grid reads the global dem_manager).
dem_manager = None
aoi_poly: Optional[Polygon] = None
drawn_geometry: Dict[str, Any] = {}
gateway_candidate_df: Optional[pd.DataFrame] = None
saved_solutions = {}
solution_counter = 0

print("All libraries, configs, and globals loaded.")

## New solar functions below

In [None]:
# =============================================================
# Solar PV + Load + Battery — Annual Site Simulation 
# =============================================================
# Goal:
#   Given a site (lon/lat) and some terrain rasters (DEM/DSM/CHM),
#   I want a daily energy story for a tiny PV system powering a device:
#     1) Pull hour-by-hour irradiance & weather for a chosen year (NSRDB/PSM3 via pvlib)
#     2) Adjust irradiance for local terrain/vegetation shading (horizon mask)
#     3) Run a PVWatts-style PV model to get AC power each hour
#     4) Convert device behavior (boots + standby) into daily Wh load
#     5) March the battery state-of-charge day-by-day to catch outages
#   The output feeds my gateway/sensor candidate ranking (solar).

import os
import math
from dataclasses import dataclass

import numpy as np
import pandas as pd
from dotenv import load_dotenv

from pvlib.location import Location
from pvlib.pvsystem import PVSystem
from pvlib.modelchain import ModelChain
from pvlib.solarposition import get_solarposition

# -------------------------------------------------------------
# Environment & caching (cache weather pulls to avoid re-hitting API)
# -------------------------------------------------------------
load_dotenv()  # keep NSRDB_API_KEY and NSRDB_EMAIL in a .env file next to the notebook
# In-process memo of NSRDB pulls, keyed by (lat_q, lon_q, year, interval_min).
WEATHER_CACHE: dict[tuple, pd.DataFrame] = {}
# Lat/lon are snapped to multiples of this step (0.01° ≈ 1.1 km of latitude) before fetching,
# so nearby sites reuse the same NSRDB request.
WEATHER_Q_DEG = 0.01


# --- SAFE quantizer (handles q<=0) ---
def _q(val: float, q: float) -> float:
    """Quantize val to nearest multiple of q. If q<=0, return val unchanged."""
    v = float(val)
    if not q or q <= 0:
        return v
    return round(v / q) * q

# Sanity-check that the NSRDB credentials are visible to the process.
# Only report whether each is present: printing the raw email would leak it into saved notebook output.
api_key = os.getenv("NSRDB_API_KEY")
email = os.getenv("NSRDB_EMAIL")
print("API Key loaded?", bool(api_key))
print("Email loaded?", bool(email))

# -------------------------------------------------------------
# PV / Load / Battery parameter blocks (my defaults)
# -------------------------------------------------------------

@dataclass
class PVParams:
    """PV module, temperature-model and derate parameters for _pv_dc_daily_wh.

    pdc0_w / gamma_pdc feed pvlib's PVWatts DC model. temp_a / temp_b / temp_deltaT feed the
    SAPM cell-temperature model. soiling_loss, wiring_loss and mppt_eff are applied as
    multiplicative derates: (1 - soiling) * (1 - wiring) * mppt_eff.
    """
    # STC DC rating of the mini panel (W). Change this when swapping in a 10 W / 20 W / 25 W module.
    pdc0_w: float = 10.0

    # Temperature coefficient for DC power (fraction per °C): -0.004 = -0.4 %/°C.
    gamma_pdc: float = -0.004

    # Fractional soiling loss (dust/film/bird droppings); the default is 15 %.
    # NOTE(review): the original note assumed ~2 % for maintained field units — confirm 15 % is intended.
    soiling_loss: float = 0.15
    # Fractional cable + connector loss (3 %); short runs, but included so the estimate isn't optimistic.
    wiring_loss: float = 0.03

    # MPPT tracking efficiency: fraction of the true max-power point the controller captures.
    # The default is 87 %.
    # NOTE(review): the original note said 97 % — confirm which value is intended.
    mppt_eff: float = 0.87

    # SAPM cell-temperature coefficients (generic open-rack style set) used with pvlib's SAPM
    # temperature model; they set how hot the module runs for a given POA irradiance and wind.
    temp_a: float = -3.56
    temp_b: float = 0.0594
    temp_deltaT: float = 3.0


@dataclass
class LoadParams:
    # These came from Eddy’s device notes 
    boot_current_a: float = 0.690
    per_node_current_a: float = 0.070
    num_nodes: int = 6
    boot_secs: int = 5

    # I encode the worst-case poll cadence here; if I want a day-split schedule later, I’ll extend the model.
    check_every_minutes: int = 1

    # This 0.0025 A is the *gateway* standby ( 0.0149 sensor standby).
    standby_current_a: float = 0.0149

    
    sleep_current_a: float | None = None

    # Nominal bus voltage; my packs are 1S Li-ion-ish (3.6–3.7 V nominal)
    system_voltage_v: float = 3.7


@dataclass
class BatteryParams:
    """Battery "energy bucket" parameters for the daily SOC march in _soc_daily."""
    # Nominal capacity (Wh). 18.5 Wh ≈ 5 Ah at the 3.7 V bus assumed in LoadParams.
    # NOTE(review): a 12 V SLA pack was mentioned originally, which does not match a 3.7 V bus —
    # confirm chemistry/voltage.
    capacity_wh: float = 18.5

    # Charge/discharge path efficiencies, kept separate so they can be tuned asymmetrically.
    # _soc_daily adds PV * charge_eff and removes load / discharge_eff each day.
    charge_eff: float = 0.95
    discharge_eff: float = 0.95

    # Reserve fraction: days ending below min_soc * capacity_wh are counted as deficit/outage days.
    # It is a reporting threshold only; _soc_daily still lets the bucket drain to 0.
    # The simulation starts at initial_soc (80 %).
    min_soc: float = 0.10
    initial_soc: float = 0.80


# I cache PV results too because horizon masks & PV runs are moderately expensive per site
_PV_CACHE: dict[tuple, tuple[pd.Series, dict]] = {}

# -------------------------------------------------------------
# Horizon profile & sky openness (terrain/vegetation shading model)
# -------------------------------------------------------------
# I ray-cast 0..359° around the site out to ~300–400 m across DEM/DSM/CHM to find the
# maximum elevation angle of obstacles in each direction. That becomes my horizon line.
# If the sun is below that line in a direction/hour, I zero the direct beam (DNI).
# Diffuse is scaled by a crude sky-view factor derived from the mean horizon height.

def _horizon_profile_from_dsm(dem_manager, coord, radius_m: int = 400, az_step_deg: int = 1):
    """
    dem_manager must expose:
      - get_elevation((lon, lat))  -> bare earth
      - get_surface_elevation((lon, lat)) -> DSM (if available)
      - get_canopy_height((lon, lat)) -> CHM (if available)
      - transform, slope_arr, aspect_arr optionally for orientation

    Returns:
      az_deg        : np.ndarray of azimuth centers [0..359]
      horz_elev_deg : np.ndarray of max elevation angle of obstacles at each azimuth (deg)
    """
    az_deg = np.arange(0, 360, az_step_deg, dtype=float)
    horz = np.zeros_like(az_deg, dtype=float)

    # Decide which "surface" to cast across (I prefer DSM; else DEM+CHM; else DEM only)
    def elev_at(lon, lat):
        if getattr(dem_manager, 'dsm_array', None) is not None:
            return dem_manager.get_surface_elevation((lon, lat))
        elif getattr(dem_manager, 'chm_array', None) is not None:
            return dem_manager.get_elevation((lon, lat)) + dem_manager.get_canopy_height((lon, lat))
        else:
            return dem_manager.get_elevation((lon, lat))

    lon0, lat0 = coord
    steps = max(10, int(radius_m // 10))  # I keep ~10 m step spacing along each spoke

    for i, az in enumerate(az_deg):
        max_alpha = 0.0
        for s in range(1, steps + 1):
            frac = s / steps
            dist = frac * radius_m

            # Flat-earth approximation is fine at <1 km: convert local meters to degrees
            dx = dist * math.sin(math.radians(az))
            dy = dist * math.cos(math.radians(az))
            m_per_deg_lat = 111_000.0
            m_per_deg_lon = m_per_deg_lat * math.cos(math.radians(lat0))
            lon = lon0 + dx / m_per_deg_lon
            lat = lat0 + dy / m_per_deg_lat

            z_obs = elev_at(lon0, lat0)
            z_pt = elev_at(lon, lat)
            dz = (z_pt - z_obs)
            alpha = math.degrees(math.atan2(dz, dist))  # elevation angle of that obstacle
            if alpha > max_alpha:
                max_alpha = alpha
        horz[i] = max(0.0, max_alpha)

    return az_deg, horz


def _sky_view_factor(horizon_elev_deg: np.ndarray) -> float:
    """Quick-and-dirty SVF proxy in [0,1]. 0=open sky, 1=closed (not physical; I clamp).
    I just normalize the mean horizon height over 0..90°.
    """
    mean_h = np.clip(np.mean(horizon_elev_deg), 0, 90)
    return float(max(0.0, 1.0 - float(mean_h) / 90.0))


def _orient_from_terrain(dem_manager, coord, force_south: bool = True):
    """Read slope/aspect at the pixel; I cap tilt to 60° and optionally force south-facing.
    If I pass force_south=True I ignore natural aspect and lock azimuth=180°.
    """
    col, row = ~dem_manager.transform * coord
    row, col = int(row), int(col)
    tilt = 30.0
    az = 180.0

    if dem_manager.slope_arr is not None and 0 <= row < dem_manager.slope_arr.shape[0] and 0 <= col < dem_manager.slope_arr.shape[1]:
        tilt = float(np.clip(dem_manager.slope_arr[row, col], 0, 60))

    if not force_south and dem_manager.aspect_arr is not None and 0 <= row < dem_manager.aspect_arr.shape[0] and 0 <= col < dem_manager.aspect_arr.shape[1]:
        az = float(dem_manager.aspect_arr[row, col] % 360)

    return tilt, (180.0 if force_south else az)


# -------------------------------------------------------------
# Weather ingestion (NSRDB/PSM3 through pvlib) + UTC normalization
# -------------------------------------------------------------
# I stick to the 1998–2022 availability window in NSRDB PSM3 actual-year data.
# I compute a diagnostic cloud_cover_pct as 1 - (GHI_measured / GHI_clearsky) for daylight hours.


def _annual_weather(
    loc: Location,
    year: int = 2022,
    nsrdb_api_key: str | None = None,
    nsrdb_email: str | None = None,
    interval: int = 60,  # 60 or 30 minutes
) -> pd.DataFrame:
    """
    Fetch one year of NSRDB PSM3 weather for ``loc`` (via pvlib) and normalize it.

    Returns a UTC-indexed DataFrame with columns:
      ['ghi','dni','dhi','temp_air','wind_speed','clearsky_ghi',
       'cloud_cover_pct','cloud_type']
    - cloud_cover_pct is a diagnostic derived from GHI/clearsky with a small cushion.
    - cloud_type is NSRDB’s categorical cloud field if returned; else NaN.

    Parameters
    ----------
    loc : pvlib Location. Its lat/lon are snapped with _q(..., WEATHER_Q_DEG), and the *snapped*
        coordinates are the ones requested and used for clear-sky / solar-position calculations.
    year : requested year; silently clamped to [1998, 2022].
    nsrdb_api_key, nsrdb_email : credentials. Fall back to the NSRDB_API_KEY / NSRDB_EMAIL
        environment variables.
    interval : 30 or 60 (minutes).

    Caching: results are memoized in WEATHER_CACHE under (lat_q, lon_q, year, interval).
    A hit returns a copy, and a miss stores a copy, so callers can mutate the result freely.

    Raises
    ------
    ValueError   : interval is not 30 or 60.
    RuntimeError : credentials are missing (checked only on a cache miss).
    """
    from pvlib import iotools
    from pvlib.solarposition import get_solarposition
    from datetime import timedelta, timezone

    if interval not in (30, 60):
        raise ValueError("NSRDB interval must be 30 or 60 minutes.")

    # Clamp to NSRDB PSM3 historical range
    y = int(min(2022, max(1998, year)))

    # Quantize to reduce redundant requests for nearby sites
    lat, lon = float(loc.latitude), float(loc.longitude)
    lat_q, lon_q = _q(lat, WEATHER_Q_DEG), _q(lon, WEATHER_Q_DEG)
    cache_key = (lat_q, lon_q, y, interval)
    if cache_key in WEATHER_CACHE:
        return WEATHER_CACHE[cache_key].copy()

    # Creds (explicit arguments win over environment variables)
    nsrdb_api_key = nsrdb_api_key or os.environ.get("NSRDB_API_KEY")
    nsrdb_email   = nsrdb_email   or os.environ.get("NSRDB_EMAIL")
    if not nsrdb_api_key or not nsrdb_email:
        raise RuntimeError("Set NSRDB_API_KEY and NSRDB_EMAIL (or pass them to _annual_weather).")

    attrs = [
        "ghi", "dhi", "dni",
        "air_temperature", "wind_speed",
        "clearsky_ghi", "clearsky_dni", "clearsky_dhi",
        "cloud_type",
    ]

    df, meta = iotools.get_psm3(
        latitude=float(lat_q),
        longitude=float(lon_q),
        api_key=nsrdb_api_key,
        email=nsrdb_email,
        names=str(y),
        attributes=attrs,         # LIST, not CSV string
        interval=interval,
        leap_day=False,
        map_variables=True,
    )

    # Convert to UTC (NSRDB is local-standard-time tz-naive).
    # The UTC offset in hours is looked up under several possible metadata key spellings.
    tz_hours = None
    for k in ("Time Zone", "TimeZone", "TZ", "tz", "timezone"):
        if k in meta:
            try:
                tz_hours = float(meta[k]); break
            except Exception:
                pass

    # NOTE(review): pvlib's get_psm3 may already return a tz-aware index, in which case only the
    # tz_convert branch runs. If the index is naive AND no offset was found, timestamps are
    # assumed to be UTC — that would shift local-standard-time data by the site's offset; confirm.
    if df.index.tz is None:
        if tz_hours is not None:
            fixed = timezone(timedelta(hours=tz_hours))
            df.index = df.index.tz_localize(fixed).tz_convert("UTC")
        else:
            df.index = df.index.tz_localize("UTC")
    else:
        df.index = df.index.tz_convert("UTC")

    # Column cleanup: guarantee numeric core columns exist (missing ones become all-NaN)
    if "temp_air" not in df.columns and "air_temperature" in df.columns:
        df = df.rename(columns={"air_temperature": "temp_air"})

    for c in ("ghi", "dni", "dhi", "temp_air", "wind_speed"):
        if c not in df.columns:
            df[c] = np.nan
        df[c] = pd.to_numeric(df[c], errors="coerce")

    # Ensure we have a clearsky_ghi (fallback: pvlib Ineichen model at the snapped location)
    if "clearsky_ghi" not in df.columns or df["clearsky_ghi"].isna().all():
        cs = Location(float(lat_q), float(lon_q)).get_clearsky(df.index, model="ineichen")
        df["clearsky_ghi"] = pd.to_numeric(cs["ghi"], errors="coerce")

    # --- Cloud cover diagnostic from GHI vs. clearsky ---
    # cloud_cover_pct = 100 * clip(1 - GHI / (1.08 * clearsky_GHI), 0, 1); NaN where clearsky is 0.
    cs_ghi = pd.to_numeric(df["clearsky_ghi"], errors="coerce").replace(0, np.nan)
    # small cushion (e.g., 8%) so tiny GHI>clearsky doesn’t force 0%
    # NOTE(review): with this cushion, GHI == clearsky reads as ~7.4 % cloud, and 0 % is only reached
    # at GHI >= 1.08 * clearsky — confirm this bias is intended.
    cushion = 1.08
    ratio = df["ghi"] / (cushion * cs_ghi)
    cloudiness = (1.0 - ratio).clip(lower=0, upper=1)
    df["cloud_cover_pct"] = 100.0 * cloudiness

    # Nights (apparent solar elevation <= 0) → NaN so they don’t pollute means
    sp = get_solarposition(df.index, float(lat_q), float(lon_q))
    daylight = sp["apparent_elevation"] > 0
    df.loc[~daylight, "cloud_cover_pct"] = np.nan

    # Keep cloud_type if present (coerced to numeric), else an all-NaN column
    if "cloud_type" not in df.columns:
        df["cloud_type"] = np.nan
    else:
        df["cloud_type"] = pd.to_numeric(df["cloud_type"], errors="coerce")

    # Final selection & typing: drop duplicate timestamps (keep first), sort, cast to float
    df = df[~df.index.duplicated(keep="first")].sort_index()
    out = df[[
        "ghi", "dni", "dhi", "temp_air", "wind_speed",
        "clearsky_ghi", "cloud_cover_pct", "cloud_type"
    ]].astype(float)

    WEATHER_CACHE[cache_key] = out.copy()
    return out


# -------------------------------------------------------------
# PV energy with shading (per site) → daily Wh
# -------------------------------------------------------------
# I compute the horizon, attenuate DNI/DHI accordingly, then run pvlib’s PVWatts DC/AC chain.


def _pv_dc_daily_wh(coord: tuple[float, float], dem_manager, pv: PVParams, year: int = 2022, force_south: bool = True):
    """Simulate one year of shaded, DC-coupled PV output at a site and return daily energy.

    Pipeline:
      1. Orient the panel from terrain, or lock it south.
      2. Ray-cast a 0..359° horizon (1° bins).
      3. Zero DNI whenever the sun is at or below that horizon, and scale DHI by the
         sky-openness proxy.
      4. Rebuild GHI.
      5. Run pvlib's PVWatts DC model with SAPM cell temperature.
      6. Apply soiling, wiring and MPPT derates, then integrate to daily Wh.

    Parameters
    ----------
    coord : (lon, lat) of the site, in the coordinates dem_manager's lookups accept.
    dem_manager : provides get_elevation and transform/slope_arr/aspect_arr (orientation), plus
        the surface accessors used by _horizon_profile_from_dsm.
    pv : PVParams for the module, the temperature model and the derates.
    year : NSRDB weather year (clamped inside _annual_weather).
    force_south : lock azimuth to 180° instead of using terrain aspect.

    Returns
    -------
    daily_wh : pd.Series named 'pv_dc_Wh' — net DC energy per UTC day (Wh).
    meta : dict of orientation, loss, shading and cloud diagnostics.

    Results are memoized in _PV_CACHE. A cache hit returns the cached objects themselves,
    so copy them before mutating.
    """
    lon, lat = coord
    alt = float(dem_manager.get_elevation(coord))
    tilt, az = _orient_from_terrain(dem_manager, coord, force_south=force_south)

    # Build the 0..359° horizon once per site; 1° bins are assumed by the sun-azimuth lookup below.
    _, horz_elev = _horizon_profile_from_dsm(dem_manager, coord, radius_m=300, az_step_deg=1)
    svf = _sky_view_factor(horz_elev)

    # The cache key must capture every input that changes the result: site, year, orientation,
    # shading summary and *all* PVParams fields used below (MPPT efficiency and the SAPM
    # temperature coefficients included, otherwise changing them would return stale results).
    key = (
        round(lon, 4), round(lat, 4), year,
        pv.pdc0_w, pv.gamma_pdc, pv.soiling_loss, pv.wiring_loss, pv.mppt_eff,
        pv.temp_a, pv.temp_b, pv.temp_deltaT,
        tilt, az, "DCpvwatts",
        round(float(np.mean(horz_elev)), 2), round(svf, 3),
    )
    if key in _PV_CACHE:
        return _PV_CACHE[key]

    # Weather (UTC-indexed); copied because the irradiance columns are overwritten below.
    loc = Location(lat, lon, tz='UTC', altitude=alt)
    wx = _annual_weather(loc, year=year, interval=60).copy()

    # Sun geometry at each timestamp
    sp = get_solarposition(wx.index, lat, lon, altitude=alt)
    sun_az = np.mod(sp['azimuth'].values, 360.0).astype(float)
    sun_el = sp['apparent_elevation'].values.astype(float)

    # Horizon height in the sun's direction (nearest 1° bin)
    bin_idx = np.clip(np.round(sun_az).astype(int), 0, 359)
    horz_at_sun = horz_elev[bin_idx]

    # Explicit copies: these arrays are edited in place, and under pandas Copy-on-Write a plain
    # to_numpy() may hand back a read-only view of the frame's data.
    dni = wx['dni'].to_numpy(dtype=float, copy=True)
    dhi = wx['dhi'].to_numpy(dtype=float, copy=True)

    # Hard occultation of the direct beam when the sun is behind terrain/vegetation
    blocked = sun_el <= horz_at_sun
    dni[blocked] = 0.0

    # Diffuse knocked down by the openness proxy (less open sky → less diffuse light)
    dhi *= svf

    # Recompose GHI = DNI·cos(zenith) + DHI, with cos(zenith) = sin(elevation) floored at 0
    cos_theta_z = np.cos(np.radians(90.0 - np.clip(sun_el, -90, 90)))
    ghi = dni * np.maximum(cos_theta_z, 0.0) + dhi

    wx['dni'] = dni
    wx['dhi'] = dhi
    wx['ghi'] = ghi

    # PV system & model chain
    system = PVSystem(
        surface_tilt=tilt,
        surface_azimuth=az,
        module_parameters={'pdc0': pv.pdc0_w, 'gamma_pdc': pv.gamma_pdc},
        inverter_parameters={'pdc0': pv.pdc0_w},  # required by ac_model='pvwatts'; AC output is unused
        temperature_model_parameters={'a': pv.temp_a, 'b': pv.temp_b, 'deltaT': pv.temp_deltaT},
    )

    mc = ModelChain(
        system, loc,
        dc_model='pvwatts', ac_model='pvwatts',
        aoi_model='physical', spectral_model='no_loss',
        temperature_model='sapm', losses_model='no_loss'
    )

    mc.run_model(wx)

    # Use DC for a DC-coupled system (no inverter)
    pdc = mc.results.dc
    if isinstance(pdc, pd.DataFrame) and 'pdc' in pdc.columns:
        pdc = pdc['pdc']
    pdc = pdc.fillna(0.0)

    # Step length of the weather series; used to integrate power into energy and for diagnostics.
    interval_min = int(pd.Series(wx.index).diff().dt.total_seconds().dropna().mode().iloc[0] / 60) if len(wx) > 1 else 60
    hours_per_step = interval_min / 60.0

    # Derates appropriate for the DC path (MPPT is applied here because there is no inverter)
    effective_gain = (1.0 - pv.soiling_loss) * (1.0 - pv.wiring_loss) * pv.mppt_eff
    pdc_net_w = pdc * effective_gain

    # Energy per step = power × step length; summing per day gives Wh for any data interval.
    daily_wh = (pdc_net_w * hours_per_step).resample('1D').sum().rename('pv_dc_Wh')

    # Loss fraction kept for metadata (clamped to [0, 0.95])
    total_loss = float(np.clip(1.0 - effective_gain, 0.0, 0.95))

    # Diagnostics for ranking & plots
    cloud_pct = wx.get('cloud_cover_pct')
    cloud_cover_mean_pct = float(np.nanmean(cloud_pct)) if cloud_pct is not None else np.nan

    daylight = sun_el > 0.0
    cloudy_hours_pct = float(
        100.0 * np.nanmean((cloud_pct[daylight] >= 50.0)) if (cloud_pct is not None and np.any(daylight)) else np.nan
    )

    blocked_sun_hours = float(np.sum(blocked & daylight) * hours_per_step)
    total_daylight_hours = float(np.sum(daylight) * hours_per_step)
    blocked_sun_hours_pct = float(100.0 * blocked_sun_hours / total_daylight_hours) if total_daylight_hours > 0 else np.nan

    meta = {
        'tilt_deg': float(tilt),
        'azimuth_deg': float(az),
        'altitude_m': float(alt),
        'loss_fraction': total_loss,
        'svf': float(svf),
        'mean_horizon_deg': float(np.mean(horz_elev)),
        'cloud_cover_mean_pct': cloud_cover_mean_pct,
        'cloudy_hours_pct': cloudy_hours_pct,
        'blocked_sun_hours': blocked_sun_hours,
        'blocked_sun_hours_pct': blocked_sun_hours_pct,
        'interval_min': interval_min,
    }

    _PV_CACHE[key] = (daily_wh, meta)
    return daily_wh, meta


# -------------------------------------------------------------
# Convert Eddy-style device behavior into daily Wh (load model)
# -------------------------------------------------------------
# Assume a repeating daily pattern: N short high-current boots at a fixed cadence,
# plus standby draw the rest of the day. If I want a day-split (e.g., 16h vs 8h cadence),
# I’ll extend this to two segments.


def _daily_load_wh(load: LoadParams) -> tuple[float, dict]:
    boots = int(24 * 60 / load.check_every_minutes)
    boot_I = load.boot_current_a + load.per_node_current_a * load.num_nodes
    boot_Ah = boot_I * (load.boot_secs / 3600.0) * boots

    standby_h = 24.0 - boots * (load.boot_secs / 3600.0)
    standby_Ah = load.standby_current_a * max(standby_h, 0)

    total_Ah = boot_Ah + standby_Ah
    return total_Ah * load.system_voltage_v, {
        'boots_per_day': boots,
        'boot_Ah_day': boot_Ah,
        'standby_Ah_day': standby_Ah,
        'total_Ah_day': total_Ah,
    }

def recommended_battery_wh(load_wh_day: float, days: int, min_soc: float=0.10) -> tuple[float, float]:
    """Return (nominal_Wh, usable_Wh) to ride through ``days`` with no PV, keeping a ``min_soc`` reserve.

    usable_Wh  = load_wh_day * days
    nominal_Wh = usable_Wh / (1 - min_soc)   (the reserve fraction of nominal is never drawn)

    Discharge efficiency is not applied here.

    Raises
    ------
    ValueError : min_soc outside [0, 1) (1 would divide by zero; > 1 yields a negative size),
                 or negative ``days``.
    """
    if not (0.0 <= min_soc < 1.0):
        raise ValueError(f"min_soc must be in [0, 1), got {min_soc!r}")
    if days < 0:
        raise ValueError(f"days must be non-negative, got {days!r}")
    usable = load_wh_day * days
    nominal = usable / (1.0 - min_soc)
    return nominal, usable

# Example: nominal battery size for a 3-day ride-through with the default device and battery
# settings. Reuses _daily_load_wh so this number cannot drift from the load model.
_example_load_wh_day, _ = _daily_load_wh(LoadParams())
nom3, _ = recommended_battery_wh(_example_load_wh_day, days=3, min_soc=BatteryParams().min_soc)

print(f"Suggested nominal capacity for 3-day ride-through ≈ {nom3:.1f} Wh")

# -------------------------------------------------------------
# Daily SOC march (simple energy bucket with eff & reserve)
# -------------------------------------------------------------
# I keep it day-resolution on purpose; the PV model is hourly, but decisions (outage/viability)
# for this class of device are fine at daily granularity.


def _soc_daily(daily_pv_wh: pd.Series, daily_load_wh: float, batt: BatteryParams):
    cap = batt.capacity_wh
    min_wh = batt.min_soc * cap

    soc = np.zeros(len(daily_pv_wh))
    soc[0] = np.clip(batt.initial_soc, 0, 1)
    outage = None

    for i, pv in enumerate(daily_pv_wh.values):
        stored = soc[i - 1] * cap if i > 0 else soc[0] * cap
        stored = stored + pv * batt.charge_eff - (daily_load_wh / batt.discharge_eff)
        stored = float(np.clip(stored, 0, cap))
        if outage is None and stored < min_wh:
            outage = daily_pv_wh.index[i]
        soc[i] = stored / cap

    s = pd.Series(soc, index=daily_pv_wh.index, name='soc_frac')
    deficit_days = int(((s * cap) < min_wh).sum())
    return s, deficit_days, outage


# -------------------------------------------------------------
# Feature wrapper used by my candidate ranking (single site)
# -------------------------------------------------------------
# This bundles PV → load → SOC plus diagnostics I want to plot.

def _longest_run_below(series: pd.Series, threshold: float) -> int:
    """Longest consecutive days where series < threshold."""
    arr = (series.values < threshold).astype(int)
    m = run = 0
    for v in arr:
        run = run + 1 if v else 0
        m = run if run > m else m
    return int(m)

def _p10(series: pd.Series) -> float:
    return float(series.quantile(0.10))

def solar_features_for_coord(
    coord: tuple[float, float],
    dem_manager,
    pv: PVParams | None = None,
    load: LoadParams | None = None,
    batt: BatteryParams | None = None,
    year: int = 2022,
    force_south: bool = True,
) -> dict:
    """Run PV → load → battery simulation for one site and return a flat feature dict.

    ``pv`` / ``load`` / ``batt`` default to fresh PVParams() / LoadParams() / BatteryParams()
    on every call. Default instances created once at definition time would be shared across
    calls, so any mutation would leak into later calls.

    Returned keys:
      - PV energy: mean, P10 and annual total;
      - load and balance: daily load, mean energy margin, SOC deficit days, first outage day,
        and the longest run of days where PV < load;
      - panel orientation and loss fraction;
      - shading and cloud diagnostics from _pv_dc_daily_wh.
    """
    pv = PVParams() if pv is None else pv
    load = LoadParams() if load is None else load
    batt = BatteryParams() if batt is None else batt

    daily_pv_wh, meta = _pv_dc_daily_wh(coord, dem_manager, pv, year=year, force_south=force_south)
    load_wh_day, _ = _daily_load_wh(load)
    _, deficit_days, outage = _soc_daily(daily_pv_wh, load_wh_day, batt)

    pv_mean_wh = float(daily_pv_wh.mean())
    pv_p10_wh = _p10(daily_pv_wh)
    longest_deficit_run = _longest_run_below(daily_pv_wh, load_wh_day)

    return {
        'pv_wh_day_mean': pv_mean_wh,
        'pv_p10_wh_day': float(pv_p10_wh),
        'longest_deficit_run_days': int(longest_deficit_run),
        'pv_wh_year': float(daily_pv_wh.sum()),
        'load_wh_day': float(load_wh_day),
        'energy_margin_wh_day': float(pv_mean_wh - load_wh_day),
        'deficit_days': int(deficit_days),
        'first_outage_day': outage,
        'tilt_deg': meta.get('tilt_deg'),
        'azimuth_deg': meta.get('azimuth_deg'),
        'pv_loss_fraction': meta.get('loss_fraction'),
        # Diagnostics useful for plots and sanity checks
        'svf': meta.get('svf'),
        'mean_horizon_deg': meta.get('mean_horizon_deg'),
        'cloud_cover_mean_pct': meta.get('cloud_cover_mean_pct'),
        'cloudy_hours_pct': meta.get('cloudy_hours_pct'),
        'blocked_sun_hours': meta.get('blocked_sun_hours'),
        'blocked_sun_hours_pct': meta.get('blocked_sun_hours_pct'),
        'wx_interval_min': meta.get('interval_min'),
    }


# -------------------------------------------------------------
# Map physical energy margin → bounded 0..1 solar score
# -------------------------------------------------------------
# I use a logistic so 0 Wh/day margin → score ~0.5. Positive margin pushes >0.5; negative <0.5.


def solar_score_from_energy_margin(margin_wh_day: float, scale: float = 50.0) -> float:
    """Map a daily energy margin (Wh/day) onto a 0..1 score with a logistic curve.

    score = 1 / (1 + exp(-margin / scale)), so a zero margin scores 0.5. Surpluses push the score
    toward 1 and deficits toward 0. A falsy ``scale`` (0 or None) is treated as 1.0.
    """
    logit = float(margin_wh_day) / (float(scale) if scale else 1.0)
    denominator = 1.0 + np.exp(-logit)
    return float(1.0 / denominator)


In [None]:
# Drop memoized NSRDB pulls so the next PV simulation re-fetches weather.
# _PV_CACHE is cleared as well: its entries were computed from the old weather, and a cache hit
# there returns before WEATHER_CACHE is ever consulted, which would make this reset ineffective.
WEATHER_CACHE.clear()
_PV_CACHE.clear()

In [None]:
# 2. ===== DEM HANDLING & CORE UTILITIES =====
from astropy import units as u
from geopy.distance import geodesic
import numpy as np
import pycraf
import pycraf.utils as pu
from pycraf.pathprof import PathProp, loss_freespace

class DEMManager:
    """
    Centralized manager for DEM (ground), DSM (surface), CHM (canopy) data and utilities.

    Lookup methods take ``(lon, lat)`` pairs. DEM lookups apply ``~self.transform`` directly,
    so the DEM grid is presumably in EPSG:4326 (TODO confirm). DSM/CHM lookups first project
    ``(lon, lat)`` into that layer's CRS (recorded via ``set_dsm`` / ``set_chm``, falling back
    to the DEM CRS) before applying the layer's own transform.
    """
    def __init__(self, dem_bytes: bytes):
        with MemoryFile(dem_bytes) as memfile:
            with memfile.open() as src:
                self.dem_array = src.read(1)
                self.transform = src.transform
                self.bounds = src.bounds
                self.profile = src.profile
                self.crs = src.crs
        self.shape = self.dem_array.shape
        self.slope_arr = None
        self.aspect_arr = None

        # Optional DSM (LiDAR surface) and CHM (canopy height) layers, each with its own grid/CRS.
        self.dsm_array = None
        self.dsm_transform = None
        self.dsm_crs = None
        self.chm_array = None
        self.chm_transform = None
        self.chm_crs = None

        # EPSG:4326 → layer-CRS transformers, built on first use and reused (construction is costly
        # and lookups happen thousands of times per horizon ray-cast).
        self._transformers = {}

        pixel_width = abs(self.transform.a)
        pixel_height = abs(self.transform.e)
        print(f"DEM pixel width: {pixel_width}, pixel height: {pixel_height}")
        if pixel_width < 1e-6 or pixel_height < 1e-6:
            raise ValueError("Invalid DEM: pixel size too small.")

    def set_dsm(self, dsm_array, dsm_transform, dsm_crs=None):
        """Attach a DSM grid. Pass ``dsm_crs`` when the DSM is not in the DEM's CRS."""
        self.dsm_array = dsm_array
        self.dsm_transform = dsm_transform
        if dsm_crs is not None:
            self.dsm_crs = dsm_crs

    def set_chm(self, chm_array, chm_transform, chm_crs=None):
        """Attach a CHM grid. Pass ``chm_crs`` when the CHM is not in the DEM's CRS."""
        self.chm_array = chm_array
        self.chm_transform = chm_transform
        if chm_crs is not None:
            self.chm_crs = chm_crs

    def _layer_crs(self, attr):
        """CRS recorded under ``attr`` ('dsm_crs' / 'chm_crs'), falling back to the DEM CRS."""
        crs = getattr(self, attr, None)
        return crs if crs is not None else getattr(self, 'crs', None)

    def _lonlat_to_crs(self, lon, lat, crs):
        """Project EPSG:4326 (lon, lat) into ``crs``; identity when crs is None or EPSG:4326."""
        if crs is None or str(crs) == "EPSG:4326":
            return lon, lat
        # Imported locally: only pyproj.Geod is imported at file level.
        from pyproj import Transformer
        cache = self.__dict__.setdefault('_transformers', {})
        key = str(crs)
        transformer = cache.get(key)
        if transformer is None:
            transformer = Transformer.from_crs("EPSG:4326", crs, always_xy=True)
            cache[key] = transformer
        return transformer.transform(lon, lat)

    def calculate_slope_and_aspect(self):
        """Fill slope_arr / aspect_arr (degrees) from DEM gradients.

        Pixel spacing in metres is measured once at the raster centre (see degree_to_meter).
        aspect = atan2(-dz/dx, dz/dy) mapped to [0, 360).
        NOTE(review): row index increases southward for north-up rasters — confirm the aspect
        convention matches the compass azimuth expected by _orient_from_terrain.
        """
        center_col = self.dem_array.shape[1] // 2
        center_row = self.dem_array.shape[0] // 2
        lon0, lat0 = self.transform * (center_col, center_row)
        dx_meter, dy_meter = self.degree_to_meter(lon0, lat0)
        gy, gx = np.gradient(self.dem_array.astype(float), dy_meter, dx_meter)
        self.slope_arr = np.degrees(np.arctan(np.sqrt(gx**2 + gy**2)))
        aspect_rad = np.arctan2(-gx, gy)
        self.aspect_arr = np.degrees(aspect_rad)
        self.aspect_arr[self.aspect_arr < 0] += 360

    def degree_to_meter(self, lon, lat):
        """Metre length (x, y) of one DEM pixel at (lon, lat), measured on the WGS84 ellipsoid."""
        geod = Geod(ellps='WGS84')
        dx = abs(self.transform.a)
        dy = abs(self.transform.e)
        _, _, x_meter = geod.inv(lon, lat, lon + dx, lat)
        _, _, y_meter = geod.inv(lon, lat, lon, lat + dy)
        return abs(x_meter), abs(y_meter)

    def get_elevation(self, coord: Tuple[float, float]) -> float:
        """Ground elevation (DEM) at (lon, lat); 0.0 outside the grid."""
        col, row = ~self.transform * coord
        row, col = int(row), int(col)
        if 0 <= row < self.shape[0] and 0 <= col < self.shape[1]:
            return float(self.dem_array[row, col])
        return 0.0

    def get_surface_elevation(self, coord: Tuple[float, float]) -> float:
        """Surface elevation (DSM/LiDAR) at (lon, lat).

        Falls back to the DEM when no DSM is loaded, ``coord`` is outside the DSM grid, or the
        DSM value is not finite (nodata stored as NaN).
        """
        if self.dsm_array is None or self.dsm_transform is None:
            return self.get_elevation(coord)
        x, y = self._lonlat_to_crs(coord[0], coord[1], self._layer_crs('dsm_crs'))
        col, row = ~self.dsm_transform * (x, y)
        row, col = int(row), int(col)
        if 0 <= row < self.dsm_array.shape[0] and 0 <= col < self.dsm_array.shape[1]:
            z = float(self.dsm_array[row, col])
            if np.isfinite(z):
                return z
        return self.get_elevation(coord)

    def lonlat_to_chm_pixel(self, lon, lat):
        """(lon, lat) → integer (row, col) on the CHM grid, projecting into the CHM's CRS first."""
        x, y = self._lonlat_to_crs(lon, lat, self._layer_crs('chm_crs'))
        col, row = ~self.chm_transform * (x, y)
        return int(row), int(col)

    def get_canopy_height(self, coord: Tuple[float, float]) -> float:
        """Canopy height (CHM) at (lon, lat).

        Returns 0.0 when there is no CHM, ``coord`` is outside the grid, the value is not finite,
        or the lookup fails (best effort: "no canopy information here").
        """
        if self.chm_array is None or self.chm_transform is None:
            return 0.0
        try:
            row, col = self.lonlat_to_chm_pixel(coord[0], coord[1])
            if 0 <= row < self.chm_array.shape[0] and 0 <= col < self.chm_array.shape[1]:
                h = float(self.chm_array[row, col])
                return h if np.isfinite(h) else 0.0
            return 0.0
        except Exception:
            return 0.0

    def get_slope_at_coord(self, coord: Tuple[float, float]) -> float:
        """Slope (deg) at (lon, lat) from slope_arr; 10.0 when unavailable or out of bounds."""
        col, row = ~self.transform * coord
        row, col = int(row), int(col)

        if (self.slope_arr is not None and
            0 <= row < self.slope_arr.shape[0] and
            0 <= col < self.slope_arr.shape[1]):
            return float(self.slope_arr[row, col])
        else:
            return 10.0  # Default moderate slope

    def get_terrain_profile(self, p1: Tuple[float, float], p2: Tuple[float, float], num_samples: int = 100) -> Tuple[np.ndarray, np.ndarray]:
        """Sample DEM elevations along the straight lon/lat segment p1 → p2.

        Returns (distances_m, elevations). Distances are evenly spaced over the geodesic length.
        """
        lons = np.linspace(p1[0], p2[0], num_samples)
        lats = np.linspace(p1[1], p2[1], num_samples)
        elevs = np.array([self.get_elevation((lon, lat)) for lon, lat in zip(lons, lats)])
        dist_m = geodesic(p1[::-1], p2[::-1]).meters
        dists = np.linspace(0, dist_m, num_samples)
        return dists, elevs

    def add_lidar_processing_to_demmanager(dem_manager, aoi_poly, lidar_tif):
        """
        Load a LiDAR surface raster (DSM) and attach it, plus a derived canopy height model (CHM).

        Steps:
          1. Clip the DSM to ``aoi_poly`` when given.
          2. Turn nodata and implausible values into NaN.
          3. Derive CHM = DSM - DEM, with the DEM resampled onto the LiDAR grid; negatives and
             gaps are set to 0.
          4. Attach both layers to ``dem_manager`` together with the LiDAR CRS, so lon/lat
             lookups are projected correctly.

        Used as a method (``dem_manager`` plays the role of ``self``). Returns dem_manager.
        """
        import pyproj
        from shapely.ops import transform as shapely_transform
        import rasterio
        from rasterio.mask import mask
        from rasterio.warp import reproject, Resampling
        from shapely.geometry import mapping

        def reproject_aoi_to_raster(aoi_poly, raster_crs):
            project = pyproj.Transformer.from_crs("EPSG:4326", raster_crs, always_xy=True).transform
            return shapely_transform(project, aoi_poly)

        # Load and mask LiDAR to AOI
        with rasterio.open(lidar_tif) as src:
            lidar_crs = src.crs
            nodata = src.nodata if src.nodata is not None else -999999.0
            if aoi_poly is not None:
                aoi_proj = reproject_aoi_to_raster(aoi_poly, lidar_crs)
                # filled=False keeps outside-AOI pixels masked instead of filling them with a value
                # (0 when the file has no nodata) that would look like a real surface height.
                out_image, out_transform = mask(src, [mapping(aoi_proj)], crop=True, filled=False)
                lidar_data = out_image[0]
                lidar_transform = out_transform
                print("Masked LiDAR to AOI bounds.")
            else:
                lidar_data = src.read(1)
                lidar_transform = src.transform
                print("⚠️ AOI not set, showing full raster.")

        # Outside-AOI pixels, nodata sentinels and implausible (< -1000 m) values become NaN so they
        # are never treated as real surface heights downstream.
        raw = np.ma.filled(np.ma.asarray(lidar_data, dtype=float), np.nan)
        lidar_masked = np.where((raw == nodata) | (raw < -1000), np.nan, raw)

        # Canopy Height Model: DSM minus the DEM resampled onto the LiDAR grid. The destination
        # starts as NaN so pixels the DEM does not cover are not left as uninitialized memory.
        dem_resampled = np.full(lidar_masked.shape, np.nan, dtype=float)
        reproject(
            dem_manager.dem_array,
            dem_resampled,
            src_transform=dem_manager.transform,
            src_crs=dem_manager.crs,
            dst_transform=lidar_transform,
            dst_crs=lidar_crs,
            dst_nodata=np.nan,
            resampling=Resampling.bilinear
        )
        chm = np.nan_to_num(lidar_masked - dem_resampled, nan=0.0)
        chm[chm < 0] = 0

        # Attach to the DEMManager together with the LiDAR CRS
        dem_manager.set_dsm(lidar_masked, lidar_transform, lidar_crs)
        dem_manager.set_chm(chm, lidar_transform, lidar_crs)

        print("LiDAR processing complete - DSM and CHM added to DEMManager")
        return dem_manager

# --- Sampling, Grid, Path Loss, Solar Utilities (with cache note) ---

def reproject_aoi_to_raster(aoi_poly, raster_crs):
    """Return ``aoi_poly`` (given in EPSG:4326 lon/lat) re-expressed in ``raster_crs``.

    ``always_xy=True`` keeps the (x=lon, y=lat) axis order regardless of the
    axis order the CRS authority defines.
    """
    from pyproj import Transformer
    from shapely.ops import transform as shapely_transform

    to_raster = Transformer.from_crs("EPSG:4326", raster_crs, always_xy=True)
    return shapely_transform(to_raster.transform, aoi_poly)


def sample_points_in_poly(polygon: Polygon, n_points: int) -> List[Tuple[float, float]]:
    """
    Draw ``n_points`` uniformly random (x, y) points strictly inside ``polygon``.

    Uses rejection sampling over the polygon's bounding box with the module-level
    ``random`` generator, so results follow the notebook's global seed.

    Args:
        polygon: Polygon to sample from.
        n_points: Number of points wanted; values <= 0 return an empty list.

    Returns:
        List of (x, y) tuples.

    Raises:
        ValueError: If the polygon is empty or has zero area. Rejection sampling
            could never succeed then, and would loop forever.
    """
    if n_points <= 0:
        return []
    if polygon.is_empty or not polygon.area > 0:
        raise ValueError("Cannot sample points inside a polygon with zero area")

    minx, miny, maxx, maxy = polygon.bounds
    points = []
    while len(points) < n_points:
        candidate = Point(random.uniform(minx, maxx), random.uniform(miny, maxy))
        if polygon.contains(candidate):
            points.append((candidate.x, candidate.y))
    return points

def generate_coverage_grid(polygon: Polygon, grid_size: int) -> List[Tuple[float, float]]:
    """
    Generate a regular grid of (lon, lat) coverage points over an AOI polygon.

    A ``grid_size`` x ``grid_size`` lattice spanning the polygon's bounding box is
    laid out with ``np.linspace``. Lattice points covered by the polygon (interior
    *or* boundary) are kept. Boundary points are included on purpose: the
    lattice's outer rows and columns sit exactly on the bounding box, so for a
    rectangular AOI ``polygon.contains`` would discard the whole outer ring of
    the grid.

    If a global ``dem_manager`` exists and is not None, each kept point's
    elevation is looked up and summarised. Points whose elevation is non-finite,
    outside (-500, 9000) m, or whose lookup raises are still kept; they are only
    counted and reported.

    Diagnostics are printed: spacing, counts, elevation range, distance from the
    centroid, and, if a global ``NETWORK`` exists, the share of points within
    ``NETWORK.MAX_COMM_RANGE_M``. Metre figures use a flat-earth approximation
    of 111 km per degree.

    Args:
        polygon: AOI polygon in lon/lat degrees.
        grid_size: Lattice points per axis; values < 1 yield an empty grid.

    Returns:
        List of (lon, lat) tuples.
    """
    # Looked up softly so the grid can be built before a DEM has been loaded.
    dem = globals().get('dem_manager')

    if grid_size < 1:
        print(f"\nGenerating Coverage Grid: grid_size={grid_size} < 1, no points generated\n")
        return []

    min_x, min_y, max_x, max_y = polygon.bounds

    # Lattice spacing in degrees (a single point has no spacing; report the full extent)
    x_spacing_deg = (max_x - min_x) / (grid_size - 1) if grid_size > 1 else max_x - min_x
    y_spacing_deg = (max_y - min_y) / (grid_size - 1) if grid_size > 1 else max_y - min_y

    # Flat-earth approximation: 111 km per degree of latitude, scaled by cos(lat) for longitude
    center_lat = (min_y + max_y) / 2
    meters_per_deg_lat = 111000
    meters_per_deg_lon = 111000 * np.cos(np.radians(center_lat))
    x_spacing_m = x_spacing_deg * meters_per_deg_lon
    y_spacing_m = y_spacing_deg * meters_per_deg_lat

    print(f"\nGenerating Coverage Grid:")
    print(f"   AOI Bounds: ({min_x:.6f}, {min_y:.6f}) to ({max_x:.6f}, {max_y:.6f})")
    print(f"   AOI Size: {(max_x-min_x)*meters_per_deg_lon:.0f}m × {(max_y-min_y)*meters_per_deg_lat:.0f}m")
    print(f"   Grid Configuration: {grid_size}×{grid_size} (max {grid_size*grid_size} points)")
    print(f"   Grid Spacing: {x_spacing_m:.1f}m × {y_spacing_m:.1f}m")

    x_coords = np.linspace(min_x, max_x, grid_size)
    y_coords = np.linspace(min_y, max_y, grid_size)

    total_attempted = 0
    inside_aoi = 0
    valid_elevation = 0
    elevation_failures = 0
    grid_points = []
    elevation_stats = []

    for x in x_coords:
        for y in y_coords:
            total_attempted += 1
            # covers() also accepts points on the polygon boundary (contains() does not)
            if not polygon.covers(Point(x, y)):
                continue
            inside_aoi += 1
            coord = (x, y)
            # Points are kept even when the elevation check below fails (robustness)
            grid_points.append(coord)

            if dem is None:
                continue
            try:
                elevation = dem.get_elevation(coord)
                # Reasonable Earth elevation range
                is_valid = bool(np.isfinite(elevation)) and -500 < elevation < 9000
            except Exception as e:
                elevation_failures += 1
                print(f"   Could not get elevation for {coord}: {e}")
                continue

            if is_valid:
                valid_elevation += 1
                elevation_stats.append(elevation)
            else:
                elevation_failures += 1
                print(f"   Invalid elevation at {coord}: {elevation}")

    print(f"\nGrid Generation Results:")
    print(f"   Total points checked: {total_attempted}")
    print(f"   Points inside AOI: {inside_aoi} ({100*inside_aoi/total_attempted:.1f}%)")

    if dem is not None:
        # Reported even when every lookup failed, so total failure is visible
        print(f"   Points with valid elevation: {valid_elevation}")
        if elevation_failures > 0:
            print(f"   Elevation failures: {elevation_failures}")
        if elevation_stats:
            print(f"   Elevation range: {min(elevation_stats):.1f}m to {max(elevation_stats):.1f}m")
            print(f"   Mean elevation: {np.mean(elevation_stats):.1f}m")

        # Report which elevation layers the DEM manager currently holds
        if getattr(dem, 'dsm_array', None) is not None:
            print(f"   Using DSM/LiDAR data for elevations")
        elif getattr(dem, 'chm_array', None) is not None:
            print(f"   CHM vegetation data available")
        else:
            print(f"  Using standard DEM elevations")

    # Coverage density (degree² converted with the same flat-earth factors)
    if polygon.area > 0:
        area_km2 = polygon.area * (meters_per_deg_lat * meters_per_deg_lon) / 1e6
        density = len(grid_points) / area_km2 if area_km2 > 0 else 0
        print(f"   Coverage density: {density:.1f} points/km²")

    # Distance analysis from centroid (geodesic wants (lat, lon), hence the [::-1])
    if grid_points:
        centroid = (polygon.centroid.x, polygon.centroid.y)
        distances = [geodesic(p[::-1], centroid[::-1]).meters for p in grid_points]
        print(f"   Distance from AOI center:")
        print(f"     Min: {min(distances):.0f}m")
        print(f"     Max: {max(distances):.0f}m")
        print(f"     Mean: {np.mean(distances):.0f}m")

        if 'NETWORK' in globals():
            comm_range = NETWORK.MAX_COMM_RANGE_M
            within_1x = sum(1 for d in distances if d <= comm_range)
            within_2x = sum(1 for d in distances if d <= comm_range * 2)
            print(f"   Points within {comm_range:.0f}m: {within_1x} ({100*within_1x/len(grid_points):.1f}%)")
            print(f"   Points within {comm_range*2:.0f}m: {within_2x} ({100*within_2x/len(grid_points):.1f}%)")

    print(f"   Final grid: {len(grid_points)} points\n")

    return grid_points

# --- Utility: Consistent Gateway Sorting for Plotting ---
def sort_gateways(solution, sort_by="lon"):
    """
    Return gateways in a deterministic order so plot labels stay consistent.

    Args:
        solution (list): Gateway dicts, each with a 'coord' entry of (lon, lat).
        sort_by (str): 'lon' (west to east), 'lat' (south to north) or 'centroid'
            (planar distance from the global ``aoi_poly`` centroid, nearest first).

    Returns:
        A new sorted list. The input is returned unchanged when it is empty, not
        a list, or lacks 'coord' on its first item. It is also returned unchanged
        when ``sort_by`` is unrecognised, or when 'centroid' is requested without
        a global ``aoi_poly``.
    """
    sortable = bool(solution) and isinstance(solution, list) and 'coord' in solution[0]
    if not sortable:
        return solution

    if sort_by in ("lon", "lat"):
        axis = 0 if sort_by == "lon" else 1
        return sorted(solution, key=lambda gw: gw['coord'][axis])

    if sort_by != "centroid":
        return solution

    if 'aoi_poly' not in globals() or aoi_poly is None:
        return solution

    from shapely.geometry import Point
    centre = aoi_poly.centroid
    return sorted(solution, key=lambda gw: Point(gw['coord']).distance(centre))
    
def normalize_df_columns(df, columns):
    """
    Min-max scale the listed columns of ``df`` into [0, 1], returning a copy.

    Names not present in ``df`` are ignored. A column whose min equals its max
    is set to 0.5 everywhere, which avoids a division by zero. ``df`` itself is
    not modified.
    """
    result = df.copy()
    present = [name for name in columns if name in df.columns]
    for name in present:
        source = df[name]
        lo, hi = source.min(), source.max()
        if lo == hi:
            result[name] = 0.5
        else:
            result[name] = (source - lo) / (hi - lo)
    return result


def calculate_solar_score(
    coord: Tuple[float, float],
    dem_manager: DEMManager,
    panel_power_watts: float = 50,
    gamma_pdc: float = -0.004,
    inverter_power_watts: float = 50,
    temp_air: float = 20.0,
    wind_speed: float = 1.0,
    date: str = '2025-06-21'
) -> float:
    """
    Estimate clear-sky daily AC energy (kWh) for a small PV panel at a DEM pixel.

    The panel takes the terrain's slope as tilt and its aspect as azimuth.
    It falls back to 30 deg / 180 deg when the slope/aspect arrays are missing
    or the pixel value is NaN. pvlib's Ineichen clear-sky model drives a PVWatts
    DC/AC model over 24 hourly steps starting at ``date`` 00:00 UTC. The hourly
    AC power (W) is summed and divided by 1000.

    Args:
        coord: (lon, lat) in the same coordinates as ``dem_manager.transform``.
        dem_manager: Provides ``transform``, ``shape``, ``dem_array`` and optional
            ``slope_arr`` / ``aspect_arr`` on the same grid.
        panel_power_watts: PVWatts module ``pdc0``.
        gamma_pdc: PVWatts temperature coefficient (1/°C).
        inverter_power_watts: PVWatts inverter ``pdc0``.
        temp_air: Constant air temperature (°C) for the whole day.
        wind_speed: Constant wind speed (m/s) for the whole day.
        date: Day to simulate (any string ``pd.date_range`` accepts). The default
            matches the previously hard-coded solstice.

    Returns:
        Daily AC energy in kWh. Returns 0.0 when ``coord`` falls outside the DEM
        grid, the DEM elevation there is not finite, or the model yields no AC
        results.
    """
    import math

    col_f, row_f = ~dem_manager.transform * coord
    # floor(), not int(): int() truncates toward zero, so points up to one pixel
    # west/north of the raster would be mapped onto row/col 0 instead of rejected.
    row, col = math.floor(row_f), math.floor(col_f)
    if not (0 <= row < dem_manager.shape[0] and 0 <= col < dem_manager.shape[1]):
        return 0.0

    altitude = float(dem_manager.dem_array[row, col])
    if not math.isfinite(altitude):
        return 0.0  # nodata pixel: no meaningful site

    # A NaN slope/aspect would propagate NaN through pvlib and zero the score,
    # so fall back to the defaults instead.
    tilt, azimuth = 30.0, 180.0
    if dem_manager.slope_arr is not None:
        slope = float(dem_manager.slope_arr[row, col])
        if math.isfinite(slope):
            tilt = float(np.clip(slope, 0, 90))
    if dem_manager.aspect_arr is not None:
        aspect = float(dem_manager.aspect_arr[row, col])
        if math.isfinite(aspect):
            azimuth = aspect % 360

    lon, lat = coord
    loc = Location(lat, lon, tz='UTC', altitude=altitude)
    # A 24-hour window covers every local hour once regardless of longitude.
    # A Timedelta avoids the 'H' alias deprecated in pandas 2.2.
    times = pd.date_range(date, periods=24, freq=pd.Timedelta(hours=1), tz=loc.tz)
    weather = loc.get_clearsky(times, model='ineichen')
    weather['temp_air'] = temp_air
    weather['wind_speed'] = wind_speed
    system = PVSystem(
        surface_tilt=tilt,
        surface_azimuth=azimuth,
        module_parameters={'pdc0': panel_power_watts, 'gamma_pdc': gamma_pdc},
        inverter_parameters={'pdc0': inverter_power_watts},
        temperature_model_parameters={'a': -3.56, 'b': 0.0594, 'deltaT': 3}
    )
    mc = ModelChain(system, loc, dc_model='pvwatts', ac_model='pvwatts', aoi_model='physical',
                    spectral_model='no_loss', temperature_model='sapm', losses_model='no_loss')
    mc.run_model(weather)
    # Sum of hourly W values = Wh; /1000 -> kWh
    return float(mc.results.ac.sum()) / 1000.0 if hasattr(mc.results, 'ac') else 0.0


# =================================================================
# PYCRAF FUNCTIONS WITH INTERNAL IMPORTS (PARALLEL-SAFE)
# =================================================================

from pycraf import conversions as cnv
from pycraf import pathprof as pp
from astropy import units as u
import numpy as np
import numbers

def _get_terrain_profile_for_pycraf(p1, p2, dem_array, dem_transform):
    """
    Sample terrain heights along the straight lon/lat line from p1 to p2.

    Args:
        p1, p2: (lon, lat) endpoints in degrees.
        dem_array: 2-D elevation grid indexed [row, col].
        dem_transform: Affine transform of ``dem_array``. Its inverse maps
            (lon, lat) to (col, row), so the DEM is presumably in geographic
            coordinates (TODO confirm).

    Returns:
        (elev, step_m, dist_m):
        - elev: heights at ``max(2, dist_m // 10 + 1)`` evenly spaced samples,
          the first and last at p1 and p2.
        - step_m: the nominal 10 m step. The actual spacing is
          ``dist_m / (len(elev) - 1)``, which only approximates step_m.
        - dist_m: the geodesic path length in metres.
        Samples falling outside the grid get the grid's NaN-ignoring mean
        elevation.
    """
    import math

    dist_m = geodesic(p1[::-1], p2[::-1]).meters
    step_m = 10
    num_samples = max(2, int(dist_m // step_m) + 1)
    lons = np.linspace(p1[0], p2[0], num_samples)
    lats = np.linspace(p1[1], p2[1], num_samples)

    n_rows, n_cols = dem_array.shape[0], dem_array.shape[1]
    to_pixel = ~dem_transform
    fallback = None  # nanmean over the whole grid is costly: compute at most once, lazily
    elev = []
    for lon, lat in zip(lons, lats):
        col_f, row_f = to_pixel * (lon, lat)
        # floor(), not int(): int() truncates toward zero and would accept
        # coordinates up to one pixel outside the west/north edges.
        row, col = math.floor(row_f), math.floor(col_f)
        if 0 <= row < n_rows and 0 <= col < n_cols:
            elev.append(float(dem_array[row, col]))
        else:
            if fallback is None:
                fallback = np.nanmean(dem_array)
            elev.append(fallback)
    return np.array(elev), step_m, dist_m

def calculate_bearing_array(lons, lats):
    """
    Planar bearing (degrees, clockwise from north, in [-180, 180]) between consecutive points.

    Bearings are computed directly in lon/lat degree space as
    ``atan2(d_lon, d_lat)``. This ignores the cos(latitude) scaling, so it is
    only an approximation of the true geodesic azimuth.

    Args:
        lons, lats: Equal-length sequences of longitudes and latitudes.

    Returns:
        float ndarray of the same length. Element i is the bearing from point i
        to point i+1. The last element repeats the previous bearing. Input with
        fewer than two points yields zeros, and an empty array for empty input.
    """
    lons = np.asarray(lons, dtype=float)
    lats = np.asarray(lats, dtype=float)
    bearings = np.zeros(len(lons), dtype=float)
    if len(bearings) < 2:
        return bearings  # previously raised IndexError for empty input
    bearings[:-1] = np.degrees(np.arctan2(np.diff(lons), np.diff(lats)))
    bearings[-1] = bearings[-2]
    return bearings

def safe_extract_loss(loss_result, loss_name="Loss", index=0):
    """
    Convert a pycraf loss result to a float in dB.

    Args:
        loss_result: A Quantity, a plain number, or a tuple/list of them. For a
            tuple/list, element ``index`` is used. pycraf's ``loss_complete``
            puts the total loss last, so pass ``index=-1`` for it.
        loss_name: Label used in the warning emitted when conversion fails.
        index: Which element of a tuple/list result to use (default 0, the
            first element).

    Returns:
        float: The loss in dB. Objects with ``.to`` are converted with
        ``.to(u.dB)``; objects with only ``.value`` use that value. Returns 0.0
        for None, an empty sequence, or anything that cannot be converted.
        Callers should note that 0.0 dB is an optimistic fallback; a warning is
        issued in the last case.
    """
    try:
        if loss_result is None:
            return 0.0

        if isinstance(loss_result, (tuple, list)):
            if len(loss_result) == 0:
                return 0.0
            result = loss_result[index]
        else:
            result = loss_result

        if hasattr(result, 'to'):  # astropy Quantity
            return float(result.to(u.dB).value)
        if hasattr(result, 'value'):
            return float(result.value)
        return float(result)

    except Exception as exc:
        # Deliberate best-effort fallback, but make the failure visible instead of silent.
        warnings.warn(f"safe_extract_loss: could not convert {loss_name} ({exc!r}); using 0.0 dB")
        return 0.0

print("Added safe_extract_loss() function to handle tuple returns")



from astropy import units as u
from pycraf.pathprof import PathProp, loss_complete
from geopy.distance import geodesic
import numpy as np

def calculate_path_loss_pycraf(
    p1: tuple, p2: tuple,
    dem_manager,          # DEMManager class
    tx_h: float, rx_h: float, freq_mhz: float,
    veg_threshold: float, attn_per_meter: float,
    max_comm_range: float
) -> float:
    """
    Path loss (dB) between two (lon, lat) points using pycraf's ITU-R P.452 model.

    The surface profile is sampled roughly every 10 m (at least 5 samples) along
    the straight lon/lat line from p1 to p2. It uses DEM + CHM canopy height if
    a CHM is loaded, else DSM surface elevation, else bare DEM. pycraf's
    ``loss_complete`` is run on that custom profile with geodesic bearings.
    When canopy/surface data is in use, a small vegetation penalty is added:
    0.02 dB per metre of mean canopy height over samples above
    ``veg_threshold``, capped at 2 dB. The result is clamped to [60, 200] dB.

    Args:
        p1, p2: (lon, lat) of transmitter and receiver in degrees.
        dem_manager: Provides get_elevation / get_canopy_height /
            get_surface_elevation and optional chm_array / dsm_array.
        tx_h, rx_h: Antenna heights above the profile (m).
        freq_mhz: Carrier frequency (MHz).
        veg_threshold: Canopy height (m) above which a sample counts as dense
            vegetation; <= 0 disables the penalty.
        attn_per_meter: Currently unused; kept for signature compatibility.
        max_comm_range: Links longer than this (m) are rejected.

    Returns:
        Loss in dB. ``np.inf`` if the link is longer than ``max_comm_range`` or
        has zero length. 120.0 for a degenerate (empty/all-NaN) profile or on
        any error (the traceback is printed).
    """
    try:
        # Distance check (geodesic expects (lat, lon))
        dist_m = geodesic(p1[::-1], p2[::-1]).meters
        if dist_m > max_comm_range or dist_m <= 0:
            return np.inf

        step_m = 10
        num_samples = max(5, int(dist_m // step_m) + 1)
        sample_pts = list(zip(np.linspace(p1[0], p2[0], num_samples),
                              np.linspace(p1[1], p2[1], num_samples)))

        # --- Surface profile: DEM + CHM if possible, else DSM, else bare DEM.
        canopy = None  # per-sample canopy heights; fetched at most once and reused below
        if getattr(dem_manager, 'chm_array', None) is not None:
            canopy = [dem_manager.get_canopy_height(pt) for pt in sample_pts]
            terrain_profile = np.array([
                dem_manager.get_elevation(pt) + h for pt, h in zip(sample_pts, canopy)
            ])
            using_vegetation = True
        elif getattr(dem_manager, 'dsm_array', None) is not None:
            terrain_profile = np.array([dem_manager.get_surface_elevation(pt) for pt in sample_pts])
            using_vegetation = True
        else:
            terrain_profile = np.array([dem_manager.get_elevation(pt) for pt in sample_pts])
            using_vegetation = False

        # Degenerate profile: fixed fallback loss; otherwise fill NaN gaps with the profile mean
        if len(terrain_profile) < 2 or np.all(np.isnan(terrain_profile)):
            return 120.0
        terrain_profile = np.nan_to_num(terrain_profile, nan=np.nanmean(terrain_profile))
        dists = np.linspace(0, dist_m, len(terrain_profile))

        # Imported here so the function also works in parallel worker processes
        from pycraf.pathprof import PathProp, loss_complete
        from astropy import units as u
        from pyproj import Geod

        # pycraf needs astropy.units on everything
        freq = freq_mhz * 1e6 * u.Hz
        temperature = 290.0 * u.K
        pressure = 1013.0 * u.hPa
        lon_t, lat_t = p1[0] * u.deg, p1[1] * u.deg
        lon_r, lat_r = p2[0] * u.deg, p2[1] * u.deg
        h_tg = tx_h * u.m
        h_rg = rx_h * u.m
        hprof_step = step_m * u.m
        timepercent = 2.0 * u.percent

        # Geodesic azimuths: forward at the transmitter, back (towards the TX) at the receiver.
        # A planar atan2 on degree differences would ignore the cos(latitude) scaling.
        fwd_az, back_az, _ = Geod(ellps='WGS84').inv(p1[0], p1[1], p2[0], p2[1])
        bearing = fwd_az % 360.0
        backbearing = back_az % 360.0

        pprop = PathProp(
            freq, temperature, pressure,
            lon_t, lat_t, lon_r, lat_r,
            h_tg, h_rg,
            hprof_step,
            timepercent,
            hprof_dists=dists * u.m,
            hprof_heights=terrain_profile * u.m,
            hprof_bearing=bearing * u.deg,
            hprof_backbearing=backbearing * u.deg
        )

        # loss_complete returns a tuple of Quantities (see pycraf docs):
        # (L_bfsg, L_bd, L_bs, L_ba, L_b, L_b_corr, L). The last one, L, is the total
        # loss including clutter and antenna-gain corrections (gains default to 0 dBi).
        losses = loss_complete(pprop)
        total_loss_result = losses[-1] if isinstance(losses, tuple) else losses

        if hasattr(total_loss_result, 'to'):
            total_loss = float(total_loss_result.to(u.dB).value)
        elif hasattr(total_loss_result, 'value'):
            total_loss = float(total_loss_result.value)
        else:
            total_loss = float(total_loss_result)

        # --- Extra vegetation penalty if using CHM/DSM ---
        veg_loss_db = 0.0
        if using_vegetation and veg_threshold > 0:
            if canopy is None:  # DSM branch: canopy heights were not fetched above
                canopy = [dem_manager.get_canopy_height(pt) for pt in sample_pts]
            dense_veg = [h for h in canopy if h > veg_threshold]
            if dense_veg:
                veg_loss_db = min(2.0, np.mean(dense_veg) * 0.02)

        total_loss += veg_loss_db
        return max(60.0, min(total_loss, 200.0))

    except Exception as e:
        print(f"Path loss error: {e}")
        import traceback
        traceback.print_exc()
        return 120.0


def calculate_path_analysis_pycraf(
    p1: tuple, p2: tuple,
    dem_manager,
    tx_h: float, rx_h: float, freq_mhz: float,
    veg_threshold: float, attn_per_meter: float,
    max_comm_range: float
) -> dict:
    """
    Evaluate a candidate link: P.452 path loss plus first-Fresnel-zone clearance.

    The surface profile is sampled at max(10, dist_m // 10 + 1) points along the
    straight lon/lat line p1 -> p2. It uses the same source preference as the
    loss function: DEM + CHM, else DSM, else bare DEM. NaN samples are filled
    with the profile mean.

    Returns:
        dict with three keys:
        - 'path_loss': dB, from ``calculate_path_loss_pycraf`` with the same
          arguments.
        - 'min_clearance': metres. The smallest margin of the straight TX-RX
          line above the profile after reserving 60% of the first Fresnel
          radius, over interior samples. Negative means obstructed.
        - 'percent_clear': % of interior samples with a non-negative margin.
        Links that are out of range or zero-length, and any exception (printed
        with traceback), yield
        {'path_loss': inf, 'min_clearance': -inf, 'percent_clear': 0.0}.
    """
    def _sample_profile(sample_points):
        """Surface height at each (lon, lat): DEM+CHM, else DSM, else bare DEM."""
        if getattr(dem_manager, 'chm_array', None) is not None:
            return np.array([dem_manager.get_elevation(pt) + dem_manager.get_canopy_height(pt)
                             for pt in sample_points])
        if getattr(dem_manager, 'dsm_array', None) is not None:
            return np.array([dem_manager.get_surface_elevation(pt) for pt in sample_points])
        return np.array([dem_manager.get_elevation(pt) for pt in sample_points])

    def _fresnel_margin(profile, clearance_ratio=0.6):
        """(min interior margin in m, % of interior samples clear) of the LOS over ``profile``."""
        wavelength = 3e8 / (freq_mhz * 1e6)
        count = len(profile)

        # Straight line between antenna tips (ground + mast height at each end)
        los = np.linspace(profile[0] + tx_h, profile[-1] + rx_h, count)

        # First Fresnel radius sqrt(lambda*d1*d2/d); zero at the endpoints
        d_near = np.linspace(0, dist_m, count)
        d_far = dist_m - d_near
        radius = np.zeros(count)
        inside = (d_near > 0) & (d_far > 0)
        radius[inside] = np.sqrt(wavelength * d_near[inside] * d_far[inside] / dist_m)

        margin = (los - profile) - clearance_ratio * radius
        if count <= 2:
            return 0.0, 100.0
        inner = margin[1:-1]  # TX/RX positions themselves are not obstacles
        return np.min(inner), 100.0 * np.sum(inner >= 0) / len(inner)

    try:
        dist_m = geodesic(p1[::-1], p2[::-1]).meters
        if dist_m > max_comm_range or dist_m <= 0:
            return {'path_loss': np.inf, 'min_clearance': -np.inf, 'percent_clear': 0.0}

        n_samples = max(10, int(dist_m // 10) + 1)
        sample_points = list(zip(np.linspace(p1[0], p2[0], n_samples),
                                 np.linspace(p1[1], p2[1], n_samples)))
        profile = _sample_profile(sample_points)
        profile = np.nan_to_num(profile, nan=np.nanmean(profile))

        min_clear, pct_clear = -np.inf, 0.0
        if len(profile) >= 2 and not np.all(np.isnan(profile)):
            min_clear, pct_clear = _fresnel_margin(profile)

        # Loss comes from the dedicated function so both stay in sync
        pl = calculate_path_loss_pycraf(
            p1, p2, dem_manager,
            tx_h, rx_h, freq_mhz,
            veg_threshold, attn_per_meter,
            max_comm_range
        )
        return {'path_loss': pl, 'min_clearance': min_clear, 'percent_clear': pct_clear}

    except Exception as e:
        print(f"⚠️ Path analysis failed for {p1} -> {p2}: {e}")
        import traceback
        traceback.print_exc()
        return {'path_loss': np.inf, 'min_clearance': -np.inf, 'percent_clear': 0.0}

import pyproj
from shapely.ops import transform as shapely_transform
from shapely.geometry import mapping, Point
import os
from pathlib import Path

def _chm_aoi_coverage(chm_path, aoi_poly):
    """Return (has_overlap, % of the AOI polygon's area inside the raster footprint at chm_path)."""
    import rasterio
    from rasterio.warp import transform_bounds
    from shapely.geometry import box

    with rasterio.open(chm_path) as src:
        footprint = box(*transform_bounds(src.crs, 'EPSG:4326', *src.bounds))
    if not aoi_poly.intersects(footprint):
        return False, 0
    if not aoi_poly.area > 0:
        return True, 0
    # Intersect with the polygon itself, not its bounding box, so the percentage is
    # a true share of the AOI (the bbox version could exceed 100% for irregular AOIs).
    return True, (aoi_poly.intersection(footprint).area / aoi_poly.area) * 100


def find_chm_dataset_for_aoi(aoi_poly):
    """
    Pick the local CHM dataset (Eaton or Palisades) that best covers the AOI.

    A dataset is available when its CHM, DSM and DTM files all exist in the
    working directory. Without an AOI, the first available dataset is returned
    (Eaton before Palisades). With an AOI, each available CHM's footprint,
    reprojected to EPSG:4326, is intersected with the AOI polygon. The dataset
    covering the largest share wins; ties keep the earlier one.

    Args:
        aoi_poly: AOI polygon in EPSG:4326 lon/lat, or None.

    Returns:
        (dataset_name, files_dict) or (None, None) if nothing is available or
        nothing overlaps.
    """
    candidates = {
        'Eaton': {
            'chm': 'EatonCHM.tif',
            'dsm': 'Eatonoutput.dsm.tif',
            'dtm': 'Eatonoutput.dtm.tif'
        },
        'Palisades': {
            'chm': 'palisadesCHM.tif',
            'dsm': 'palisadesoutput.dsm.tif',
            'dtm': 'palisadesoutput.dtm.tif'
        },
    }
    available = {name: files for name, files in candidates.items()
                 if all(os.path.exists(f) for f in files.values())}

    print(f"Eaton files available: {'Eaton' in available}")
    print(f"Palisades files available: {'Palisades' in available}")

    if not available:
        print("No CHM datasets found!")
        return None, None

    if aoi_poly is None:
        # No AOI provided: first available dataset in preference order
        name = next(iter(available))
        return name, available[name]

    centroid = aoi_poly.centroid
    print(f"AOI center: {centroid.x:.4f}, {centroid.y:.4f}")

    best_dataset, best_coverage, best_files = None, 0, None
    for name, files in available.items():
        try:
            has_overlap, coverage = _chm_aoi_coverage(files['chm'], aoi_poly)
        except Exception as e:
            print(f" Error checking {name} coverage: {e}")
            continue
        if has_overlap and coverage > best_coverage:
            best_dataset, best_coverage, best_files = name, coverage, files

    if best_dataset:
        print(f"Selected {best_dataset} dataset ({best_coverage:.1f}% coverage)")
        return best_dataset, best_files

    print(" No datasets overlap with AOI!")
    return None, None

def _read_band_as_nan_float(src, aoi_poly):
    """
    Read band 1 of an open rasterio dataset as float64 with invalid pixels set to NaN.

    If ``aoi_poly`` (EPSG:4326) is given, the band is cropped to it. NaN is set
    for pixels outside the AOI, the raster's own nodata pixels, pixels equal to
    the nodata sentinel, and values below -1000.

    Returns:
        (data, transform, nodata): The float array, its affine transform, and the
        sentinel used (the raster's own nodata, or -999999.0 if it declares none).
    """
    nodata = src.nodata if src.nodata is not None else -999999.0
    if aoi_poly is not None:
        aoi_proj = reproject_aoi_to_raster(aoi_poly, src.crs)
        # filled=False returns a masked array instead of writing a fill value. With
        # filled=True, rasterio fills outside-AOI pixels with 0 when the raster has no
        # nodata value, and those zeros would survive as real heights.
        out, transform = mask(src, [mapping(aoi_proj)], crop=True, filled=False)
        band = out[0]
    else:
        band = src.read(1, masked=True)
        transform = src.transform
    data = np.ma.filled(np.ma.asarray(band).astype(float), np.nan)
    data[(data == nodata) | (data < -1000)] = np.nan
    return data, transform, nodata


def process_opentopo_chm_for_wsn(aoi_poly, dem_manager):
    """
    Load the best local OpenTopography CHM/DSM pair for the AOI into ``dem_manager``.

    Steps:
    - Select a dataset with ``find_chm_dataset_for_aoi``.
    - Read the CHM and DSM, cropped to the AOI when given, using each raster's
      own nodata value, with invalid pixels as NaN.
    - Derive slope and aspect (degrees) from the DSM via ``np.gradient`` using
      the pixel size in the raster's CRS units.
    - Push everything into ``dem_manager`` (set_dsm / set_chm and extra
      attributes).

    Args:
        aoi_poly: AOI polygon in EPSG:4326, or None for the full rasters.
        dem_manager: DEM manager to update in place.

    Returns:
        dict with dataset_name, dsm, chm, transform (DSM's), crs (CHM's),
        nodata (CHM's sentinel) and files, or None if no dataset is available
        or processing fails (the error is printed).
    """
    print("🔄 Processing OpenTopography CHM data for WSN analysis...")

    dataset_name, dataset_files = find_chm_dataset_for_aoi(aoi_poly)

    if dataset_files is None:
        print("No CHM datasets available!")
        return None

    print(f"📂 Using {dataset_name} dataset")

    try:
        chm_file = dataset_files['chm']
        dsm_file = dataset_files['dsm']

        print(f"Loading CHM: {chm_file}")
        print(f"Loading DSM: {dsm_file}")

        # Load and mask CHM to AOI
        with rasterio.open(chm_file) as chm_src:
            chm_crs = chm_src.crs
            print(f"CHM CRS: {chm_crs}")
            print(f"CHM resolution: {abs(chm_src.transform.a):.2f}m")
            if aoi_poly is not None:
                print("Masking CHM to AOI bounds...")
            chm_masked, chm_transform, nodata = _read_band_as_nan_float(chm_src, aoi_poly)
            if aoi_poly is not None:
                print(" Masked CHM to AOI bounds.")
                print(f"   Cropped CHM shape: {chm_masked.shape}")

        # Load and mask DSM to AOI (with the DSM's own nodata value, not the CHM's)
        with rasterio.open(dsm_file) as dsm_src:
            dsm_masked, dsm_transform, _ = _read_band_as_nan_float(dsm_src, aoi_poly)
            if aoi_poly is not None:
                print(f"   Cropped DSM shape: {dsm_masked.shape}")

        print(f"🌲 CHM range: {np.nanmin(chm_masked):.1f} to {np.nanmax(chm_masked):.1f} m")
        print(f"📡 DSM range: {np.nanmin(dsm_masked):.1f} to {np.nanmax(dsm_masked):.1f} m")

        # Slope and aspect from the DSM (pixel size taken in CRS units)
        print("🗻 Computing slope and aspect from DSM...")
        dy, dx = abs(dsm_transform.e), abs(dsm_transform.a)
        gy, gx = np.gradient(dsm_masked.astype(float), dy, dx)
        slope = np.degrees(np.arctan(np.sqrt(gx**2 + gy**2)))
        aspect = np.degrees(np.arctan2(-gx, gy))
        aspect[aspect < 0] += 360

        print("🔄 Updating DEMManager with OpenTopography CHM data...")
        dem_manager.set_dsm(dsm_masked, dsm_transform)
        dem_manager.set_chm(chm_masked, chm_transform)
        dem_manager.chm_crs = chm_crs

        # NOTE(review): only set on the first call, so a later AOI keeps the first
        # slope/aspect; presumably deliberate to avoid overwriting another source - confirm.
        if not hasattr(dem_manager, 'lidar_slope'):
            dem_manager.lidar_slope = slope
            dem_manager.lidar_aspect = aspect
            dem_manager.lidar_transform = dsm_transform

        dem_manager.chm_dataset = dataset_name
        dem_manager.chm_files = dataset_files

        print("✅ OpenTopography CHM processing complete!")

        return {
            'dataset_name': dataset_name,
            'dsm': dsm_masked,
            'chm': chm_masked,
            'transform': dsm_transform,
            'crs': chm_crs,
            'nodata': nodata,
            'files': dataset_files
        }

    except FileNotFoundError as e:
        print(f"❌ ERROR: CHM file not found: {e}")
        return None
    except Exception as e:
        print(f"❌ ERROR processing OpenTopography CHM: {type(e).__name__}: {str(e)}")
        return None

def integrate_opentopo_chm():
    """
    Load OpenTopography CHM/DSM data for the current AOI into the global DEM manager.

    Reads the notebook globals ``aoi_poly`` and ``dem_manager``. Missing or None
    globals are reported with a message rather than a NameError, so this can
    safely run before AOI selection.

    Returns:
        The dict from ``process_opentopo_chm_for_wsn``, or None if the globals
        are not ready or processing failed.
    """
    aoi = globals().get('aoi_poly')
    dem = globals().get('dem_manager')

    if aoi is None or dem is None:
        print(" AOI or DEMManager not loaded. Please run AOI selection first.")
        return None

    print("🚀 Starting OpenTopography CHM integration...")
    chm_results = process_opentopo_chm_for_wsn(aoi, dem)

    if chm_results is None:
        print("CHM processing failed")
        return None

    print("✅ OpenTopography CHM integration complete!")
    print(f"   📂 Dataset: {chm_results['dataset_name']}")
    print(f"   🌲 CHM loaded with vegetation data")
    return chm_results

def _aoi_coverage_percent(aoi_poly, footprint):
    """Percentage of ``aoi_poly``'s area that lies inside ``footprint`` (both in the same CRS)."""
    aoi_area = aoi_poly.area
    if aoi_area <= 0:
        # Zero-area AOI (point/line): count it as fully covered when it touches the footprint.
        return 100.0 if aoi_poly.intersects(footprint) else 0.0
    return 100.0 * aoi_poly.intersection(footprint).area / aoi_area


def check_chm_coverage_for_aoi(aoi_poly):
    """
    Check if AOI overlaps with available CHM datasets and show coverage info.

    For every known CHM dataset (Eaton, Palisades) whose GeoTIFF exists in the
    working directory, the raster footprint is reprojected to EPSG:4326 and
    intersected with the AOI itself (not just its bounding box), so the reported
    coverage is correct for non-rectangular AOIs and never exceeds 100%.

    Args:
        aoi_poly: shapely geometry in lon/lat (EPSG:4326), or None.

    Returns:
        dict mapping dataset name -> {'available', 'overlap'} plus, when the file
        was readable, 'coverage_percent' and 'bounds' (in the raster's native CRS);
        None when no AOI is given.
    """
    print(" Checking CHM coverage for your AOI...")
    print("=" * 50)

    if aoi_poly is None:
        print(" No AOI selected yet!")
        return

    aoi_minx, aoi_miny, aoi_maxx, aoi_maxy = aoi_poly.bounds

    print(f"📍 Your AOI bounds:")
    print(f"   West: {aoi_minx:.6f}, East: {aoi_maxx:.6f}")
    print(f"   South: {aoi_miny:.6f}, North: {aoi_maxy:.6f}")
    print(f"   Center: {aoi_poly.centroid.x:.6f}, {aoi_poly.centroid.y:.6f}")

    chm_files = {
        'Eaton': {
            'chm': 'EatonCHM.tif',
            'dsm': 'Eatonoutput.dsm.tif'
        },
        'Palisades': {
            'chm': 'palisadesCHM.tif',
            'dsm': 'palisadesoutput.dsm.tif'
        }
    }

    coverage_results = {}

    for dataset_name, files in chm_files.items():
        chm_file = files['chm']

        if not os.path.exists(chm_file):
            print(f"\n{dataset_name} CHM file not found: {chm_file}")
            coverage_results[dataset_name] = {'available': False, 'overlap': False}
            continue

        try:
            from rasterio.warp import transform_bounds
            from shapely.geometry import box

            with rasterio.open(chm_file) as src:
                chm_bounds = src.bounds

                # Reproject CHM bounds from its CRS to EPSG:4326 (lat/lon) to match the AOI.
                chm_minx, chm_miny, chm_maxx, chm_maxy = transform_bounds(
                    src.crs, 'EPSG:4326', *chm_bounds
                )

                print(f"\n📂 {dataset_name} CHM bounds:")
                print(f"   West: {chm_minx:.6f}, East: {chm_maxx:.6f}")
                print(f"   South: {chm_miny:.6f}, North: {chm_maxy:.6f}")
                print(f"   Resolution: {abs(src.transform.a):.2f}m")
                print(f"   Size: {src.width} x {src.height} pixels")

            # Intersect the actual AOI shape with the raster footprint.
            footprint = box(chm_minx, chm_miny, chm_maxx, chm_maxy)
            has_overlap = aoi_poly.intersects(footprint)

            if has_overlap:
                overlap_percent = _aoi_coverage_percent(aoi_poly, footprint)

                print(f"   OVERLAP FOUND!")
                print(f"   Coverage: {overlap_percent:.1f}% of your AOI")

                if overlap_percent > 90:
                    print(f"   Excellent coverage - this dataset covers almost your entire AOI")
                elif overlap_percent > 50:
                    print(f"   Good coverage - this dataset covers most of your AOI")
                else:
                    print(f"   Partial coverage - only part of your AOI is covered")
            else:
                print(f"    NO OVERLAP - This dataset doesn't cover your AOI")

            coverage_results[dataset_name] = {
                'available': True,
                'overlap': has_overlap,
                'coverage_percent': overlap_percent if has_overlap else 0,
                'bounds': chm_bounds
            }

        except Exception as e:
            print(f"\n❌ Error reading {dataset_name} CHM file: {e}")
            coverage_results[dataset_name] = {'available': False, 'overlap': False}

    # Summary and recommendations
    print("\n" + "=" * 50)
    print("SUMMARY & RECOMMENDATIONS")
    print("=" * 50)

    best_dataset = None
    best_coverage = 0

    for dataset_name, result in coverage_results.items():
        if result.get('overlap', False):
            coverage = result.get('coverage_percent', 0)
            if coverage > best_coverage:
                best_coverage = coverage
                best_dataset = dataset_name

    if best_dataset:
        print(f"RECOMMENDED: Use {best_dataset} dataset ({best_coverage:.1f}% coverage)")
        print(f" Your AOI will have vegetation data from OpenTopography CHM")
    else:
        print("WARNING: No CHM datasets cover your AOI area!")
        print("Suggestions:")
        print("   • Try selecting an AOI in the Eaton Canyon or Palisades area")
        print("   • Check if your CHM files are in the correct location")
        print("   • Your WSN analysis will work but without vegetation data")

    return coverage_results
                

In [None]:
import rasterio
from rasterio.warp import transform_bounds
from ipyleaflet import Map, Rectangle, basemaps
from shapely.geometry import box
import os

def visualize_chm_coverage():
    """Show the coverage areas of both Palisades and Eaton DSM datasets.

    Each dataset's DSM GeoTIFF (if present in the working directory) is opened to
    read its footprint, which is reprojected to EPSG:4326 and drawn as a coloured
    rectangle on an Esri satellite basemap.

    Returns:
        ipyleaflet ``Map`` with one rectangle per readable dataset, or None when no
        DSM file could be read.
    """
    datasets = {
        'Palisades': {
            'dsm': 'palisadesoutput.dsm.tif',
            'chm': 'palisadesCHM.tif',
            'color': 'red'
        },
        'Eaton': {
            'dsm': 'Eatonoutput.dsm.tif',
            'chm': 'EatonCHM.tif',
            'color': 'blue'
        }
    }
    # Legend swatch per rectangle colour, so the printed key matches the map.
    legend_swatches = {'red': '🟥', 'blue': '🟦'}

    coverage_info = {}

    print("CHM/DSM DATASET COVERAGE AREAS")
    print("=" * 50)

    for dataset_name, files in datasets.items():
        dsm_file = files['dsm']

        if not os.path.exists(dsm_file):
            print(f"{dataset_name} DSM file not found: {dsm_file}")
            continue

        try:
            with rasterio.open(dsm_file) as src:
                bounds_orig = src.bounds
                crs_orig = src.crs

                # Transform bounds to lat/lon (EPSG:4326) for mapping
                bounds_4326 = transform_bounds(crs_orig, 'EPSG:4326', *bounds_orig)
                west, south, east, north = bounds_4326

                coverage_info[dataset_name] = {
                    'bounds': bounds_4326,
                    'crs': crs_orig,
                    'resolution': abs(src.transform.a),
                    'size': (src.width, src.height),
                    'color': files['color']
                }

                print(f"\n{dataset_name.upper()} DATASET:")
                print(f"   File: {dsm_file}")
                print(f"   CRS: {crs_orig}")
                print(f"   Resolution: {abs(src.transform.a):.2f}m")
                print(f"   Size: {src.width} x {src.height} pixels")
                print(f"   Bounds: W:{west:.6f}, E:{east:.6f}, S:{south:.6f}, N:{north:.6f}")
                print(f"   Center: {(west+east)/2:.6f}, {(south+north)/2:.6f}")

        except Exception as e:
            print(f"Error reading {dataset_name} DSM: {e}")

    if not coverage_info:
        print("No datasets found!")
        return None

    # Centre the map on the union of all footprints (coverage_info is non-empty here).
    all_bounds = [info['bounds'] for info in coverage_info.values()]
    all_west = min(b[0] for b in all_bounds)
    all_south = min(b[1] for b in all_bounds)
    all_east = max(b[2] for b in all_bounds)
    all_north = max(b[3] for b in all_bounds)
    center_lat = (all_south + all_north) / 2
    center_lon = (all_west + all_east) / 2

    m = Map(
        center=[center_lat, center_lon],
        zoom=11,  # Zoom out a bit to see both areas
        basemap=basemaps.Esri.WorldImagery
    )

    for dataset_name, info in coverage_info.items():
        west, south, east, north = info['bounds']
        coverage_rect = Rectangle(
            bounds=[[south, west], [north, east]],
            color=info['color'],
            fill_color=info['color'],
            fill_opacity=0.2,
            weight=3
        )
        m.add_layer(coverage_rect)

    print(f"\nMAP LEGEND:")
    print("=" * 20)
    for dataset_name, info in coverage_info.items():
        swatch = legend_swatches.get(info['color'], '■')
        print(f"{swatch} {info['color'].upper()} = {dataset_name} coverage area")

    print(f"\n👉 USAGE TIPS:")
    print("• Select your AOI INSIDE one of the colored rectangles")
    print("• Red (Palisades) covers western areas")
    print("• Blue (Eaton) covers eastern areas")
    print("• Choose the dataset that best covers your area of interest")

    return m

# Draw the dataset footprints; nothing is displayed when no DSM file was readable.
chm_coverage_map = visualize_chm_coverage()
if chm_coverage_map is not None:
    display(chm_coverage_map)

In [None]:
# 3. ===== AOI SELECTION INTERFACE (USABILITY) =====
from dotenv import load_dotenv
# Populate os.environ from a local .env file so the OpenTopography API key
# (read below via os.environ.get('OPENTOPO_API_KEY')) need not be hard-coded.
load_dotenv()

def fix_aoi_bounds(aoi_poly, dem_manager, margin: float = 0.0001):
    """Clip the AOI so it lies inside the DEM footprint.

    The DEM bounds are shrunk by ``margin`` on every side (in the bounds' units,
    i.e. degrees for a lon/lat DEM). The AOI is then intersected with that box,
    so the user's drawn shape is preserved rather than replaced by its bounding box.

    Args:
        aoi_poly: shapely geometry of the AOI, in the same CRS as ``dem_manager.bounds``.
        dem_manager: object exposing ``bounds`` as (minx, miny, maxx, maxy).
        margin: inset applied to the DEM bounds before clipping.

    Returns:
        ``aoi_poly`` unchanged when its bounds already fit inside the DEM bounds;
        otherwise the clipped geometry. If clipping does not produce a single
        non-empty Polygon (e.g. a concave AOI split in two), the clipped bounding
        box is returned instead, so callers always get a Polygon with ``.exterior``.
    """
    aoi_minx, aoi_miny, aoi_maxx, aoi_maxy = aoi_poly.bounds
    dem_minx, dem_miny, dem_maxx, dem_maxy = dem_manager.bounds

    sticks_out = (aoi_minx < dem_minx or aoi_miny < dem_miny or
                  aoi_maxx > dem_maxx or aoi_maxy > dem_maxy)
    if not sticks_out:
        return aoi_poly

    safe_box = box(dem_minx + margin, dem_miny + margin, dem_maxx - margin, dem_maxy - margin)
    trimmed_poly = aoi_poly.intersection(safe_box)

    if trimmed_poly.is_empty or trimmed_poly.geom_type != 'Polygon':
        # Fall back to the clipped bounding box (single Polygon) as before.
        trimmed_poly = box(
            max(aoi_minx, dem_minx + margin),
            max(aoi_miny, dem_miny + margin),
            min(aoi_maxx, dem_maxx - margin),
            min(aoi_maxy, dem_maxy - margin),
        )

    original_area = aoi_poly.area
    if original_area > 0:
        print(f"   Area lost: {100 * (original_area - trimmed_poly.area) / original_area:.2f}%")
    return trimmed_poly

def create_aoi_selection_interface(center: Tuple[float, float] = (34.2, -118.16)):
    """Build the interactive AOI picker: satellite map + draw tools + button + output pane.

    When "Complete Selection" is clicked:
      1. The last drawn shape becomes the global ``aoi_poly`` (shapely geometry).
      2. An SRTMGL1 DEM for the AOI bounding box is requested from the
         OpenTopography global-DEM API (key read from ``OPENTOPO_API_KEY``).
      3. It is wrapped in ``DEMManager`` (global ``dem_manager``), slope/aspect are
         computed and diagnostics are printed.
      4. The AOI is clipped to the DEM, CHM coverage is reported and the
         OpenTopography CHM integration is run.
    Errors at any step are printed (with traceback) in the output pane.

    Args:
        center: (lat, lon) of the initial map centre.

    Returns:
        ipywidgets ``VBox`` containing the map, the button and the output pane.
    """
    global drawn_geometry, aoi_poly, dem_manager, gateway_candidate_df
    # Let the click handler test ``drawn_geometry`` before anything has been drawn
    # (otherwise it raises NameError); an existing geometry is kept.
    if 'drawn_geometry' not in globals():
        drawn_geometry = None

    m = Map(center=center, zoom=13, basemap=basemaps.Esri.WorldImagery)
    draw_control = DrawControl(
        circle={'shapeOptions': {'color': '#6bc2e5', 'weight': 3}},
        rectangle={'shapeOptions': {'color': '#fca45d', 'weight': 3}},
        polygon={'shapeOptions': {'color': '#4bc44b', 'weight': 3}}
    )
    output = Output()

    def handle_draw(self, action, geo_json):
        """DrawControl callback: remember the most recently drawn shape (GeoJSON feature)."""
        global drawn_geometry
        output.clear_output()
        with output:
            print(f"✓ Shape drawn: {action}")
            print(f"  Type: {geo_json['geometry']['type']}")
            drawn_geometry = geo_json

    draw_control.on_draw(handle_draw)
    m.add_control(draw_control)

    btn = Button(description='Complete Selection', button_style='success')
    ui = VBox([m, btn, output])

    def on_complete_clicked(b):
        """Button callback: build the AOI, fetch the DEM and run the CHM steps."""
        global aoi_poly, dem_manager, drawn_geometry
        with output:
            output.clear_output()
            print("=" * 50)
            print("PROCESSING AOI SELECTION")
            print("=" * 50)
            if not drawn_geometry:
                print("ERROR: No area selected. Please draw a shape on the map.")
                return
            try:
                aoi_poly = shape(drawn_geometry['geometry'])
                bounds = aoi_poly.bounds
                print(f"✓ AOI bounds: {bounds}")
                minx, miny, maxx, maxy = bounds

                api_key = os.environ.get('OPENTOPO_API_KEY', None)
                if not api_key:
                    print("⚠️ OPENTOPO_API_KEY is not set; the DEM request may be rejected.")
                url = "https://portal.opentopography.org/API/globaldem"
                params = {
                    "demtype": "SRTMGL1", "south": miny, "north": maxy, "west": minx, "east": maxx,
                    "outputFormat": "GTiff", "API_Key": api_key
                }
                print("\n📡 Requesting DEM data from OpenTopography...")
                response = requests.get(url, params=params, timeout=30)
                if response.status_code != 200:
                    print(f"ERROR: DEM request failed (HTTP {response.status_code})")
                    print(f"  Response: {response.text[:200]}...")
                    return

                print(f"✓ DEM data received: {len(response.content) / 1024:.1f} KB")
                dem_manager = DEMManager(response.content)

                # ---- EXTENDED DEBUGGING ----
                print("\n--- DEM Debug Info ---")
                print("DEM CRS:", dem_manager.crs)
                print("DEM transform:", dem_manager.transform)
                print("DEM bounds:", dem_manager.bounds)
                print("DEM shape:", dem_manager.shape)
                pixel_width = abs(dem_manager.transform.a)
                pixel_height = abs(dem_manager.transform.e)
                print(f"DEM pixel width: {pixel_width}, pixel height: {pixel_height}")
                print(f"DEM min/max/mean elevation: {np.min(dem_manager.dem_array):.2f} / {np.max(dem_manager.dem_array):.2f} / {np.mean(dem_manager.dem_array):.2f}")

                dem_manager.calculate_slope_and_aspect()
                if dem_manager.slope_arr is not None and np.isfinite(dem_manager.slope_arr).any():
                    print("Slope array min/max/mean:",
                          np.nanmin(dem_manager.slope_arr),
                          np.nanmax(dem_manager.slope_arr),
                          np.nanmean(dem_manager.slope_arr))
                else:
                    print("Slope array is empty or contains only NaN values.")
                print("--- End DEM Debug ---\n")

                aoi_poly = fix_aoi_bounds(aoi_poly, dem_manager)
                check_chm_coverage_for_aoi(aoi_poly)
                integrate_opentopo_chm()
                # Only reached when every step above succeeded.
                print("👉 AOI and DEM loaded. Please run the next cell ('Gateway Candidate Precomputation').")
            except Exception as e:
                print(f"ERROR: {type(e).__name__}: {str(e)}")
                import traceback
                traceback.print_exc()

    btn.on_click(on_complete_clicked)
    return ui

# Build the AOI picker (map + draw tools + button) and render it under this cell.
aoi_interface = create_aoi_selection_interface(center=(34.2, -118.16))
display(aoi_interface)

In [None]:
# =================================================================
# POSTER-READY CHM/DSM VISUALIZATION (bigger maps + larger labels)
# =================================================================

import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib as mpl
from matplotlib.patches import Polygon as MPLPolygon
from matplotlib.ticker import ScalarFormatter, FuncFormatter, MaxNLocator
import numpy as np

# ---- Poster font sizes ----
TITLE_FSIZE = 20      # subplot titles
AXIS_FSIZE = 24       # x/y axis labels
TICK_FSIZE = 20       # tick label size
CBAR_FSIZE = 22       # colorbar label size
SUPTITLE_FSIZE = 35   # big title at the top
CAPTION_FSIZE = 15    # fig.text boxes, legends, etc.

# ---- Axis formatting helpers (poster-friendly) ----
# Globally disable offset notation and widen the range before matplotlib
# switches tick labels to scientific notation.
mpl.rcParams.update({
    'axes.formatter.useoffset': False,
    'axes.formatter.limits': (-7, 7),
})

# Integer tick labels with thousands separators, for large UTM coordinates.
_thousands_fmt = FuncFormatter(lambda value, _pos: f"{value:,.0f}")

def apply_poster_axis_format(ax, nbins=6, use_thousands=True, equal_aspect=True, tick_size=None):
    """Clean, poster-friendly axes.

    Args:
        ax: matplotlib Axes, formatted in place. Safe to call more than once.
        nbins: maximum number of major ticks per axis (MaxNLocator).
        use_thousands: label ticks as integers with thousands separators
            (e.g. ``3,781,000``); otherwise plain numbers without offset or
            scientific notation.
        equal_aspect: force a 1:1 data aspect (appropriate for projected map coordinates).
        tick_size: optional tick-label font size.
    """
    for axis in (ax.xaxis, ax.yaxis):
        # A fresh formatter/locator per Axis: matplotlib ticker objects hold a
        # reference to the Axis they are attached to and must not be shared.
        if use_thousands:
            formatter = FuncFormatter(lambda value, _pos: f"{value:,.0f}")
        else:
            formatter = ScalarFormatter(useOffset=False)
            formatter.set_scientific(False)
        axis.set_major_formatter(formatter)
        axis.set_major_locator(MaxNLocator(nbins=nbins))
    if tick_size is not None:
        ax.tick_params(axis='both', labelsize=tick_size)
    if equal_aspect:
        ax.set_aspect('equal', adjustable='box')

# =================================================================
# CHM VISUALIZATION - Shows your AOI with vegetation height data
# =================================================================
def _aoi_outlines_in_crs(aoi_geom, target_crs):
    """Exterior rings of a lon/lat AOI geometry, reprojected to ``target_crs`` as lists of (x, y)."""
    from pyproj import Transformer
    transformer = Transformer.from_crs('EPSG:4326', target_crs, always_xy=True)
    parts = list(aoi_geom.geoms) if hasattr(aoi_geom, 'geoms') else [aoi_geom]
    outlines = []
    for part in parts:
        ring = getattr(part, 'exterior', None)
        if ring is None:
            continue  # points/lines have no area outline to draw
        outlines.append([transformer.transform(lon, lat) for lon, lat in ring.coords])
    return outlines


def visualize_chm_over_aoi():
    """
    Create a visualization showing CHM (vegetation height) data over your AOI.

    Draws two side-by-side panels: the canopy height model, and the surface model
    (DSM, or the DEM when no DSM is loaded). Each panel shows the AOI outline in
    red. A statistics caption is added and a short summary is printed.

    Reads the notebook globals ``aoi_poly`` (lon/lat) and ``dem_manager``. Returns
    None without plotting when either of them, or the CHM, is missing.
    """
    aoi_poly = globals().get('aoi_poly')
    dem_manager = globals().get('dem_manager')

    if aoi_poly is None or dem_manager is None:
        print("No AOI or DEM data loaded. Please run AOI selection first.")
        return
    if getattr(dem_manager, 'chm_array', None) is None:
        print("No CHM data loaded. Please run CHM integration first.")
        return

    print("🎨 Creating CHM visualization...")

    chm_data = dem_manager.chm_array
    chm_transform = dem_manager.chm_transform
    dsm_array = getattr(dem_manager, 'dsm_array', None)
    # NOTE(review): both panels use the CHM extent, which assumes the DSM (or the
    # fallback DEM) shares the CHM grid — confirm for the DEM fallback.
    dsm_data = dsm_array if dsm_array is not None else dem_manager.dem_array

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(22, 12))  # large maps for a poster

    # Plot extent in the CHM raster's map coordinates.
    height, width = chm_data.shape
    west, north = chm_transform * (0, 0)
    east, south = chm_transform * (width, height)
    extent = [west, east, south, north]

    # The image is drawn in the CHM's CRS (``chm_crs``, stored by the CHM
    # integration step), so the AOI outline must be projected into that CRS.
    # The DEM's own CRS may differ and is only a fallback.
    overlay_crs = getattr(dem_manager, 'chm_crs', None)
    if overlay_crs is None:
        overlay_crs = getattr(dem_manager, 'crs', None) or 'EPSG:6340'
    aoi_outlines = _aoi_outlines_in_crs(aoi_poly, overlay_crs)

    # === LEFT PLOT: CHM (Vegetation Height) ===
    chm_clean = np.where(np.isnan(chm_data), 0, chm_data)
    vegetation_map = ax1.imshow(
        chm_clean,
        extent=extent,
        origin='upper',
        cmap='YlGn',
        vmin=0,
        vmax=np.nanpercentile(chm_clean, 95)
    )
    for outline in aoi_outlines:
        ax1.add_patch(MPLPolygon(outline, fill=False, edgecolor='red', linewidth=3, alpha=0.8))

    apply_poster_axis_format(ax1, tick_size=TICK_FSIZE)
    ax1.set_title(' Canopy Height Model (CHM)\nVegetation Height', fontsize=TITLE_FSIZE, fontweight='bold')
    ax1.set_xlabel('Easting (m)', fontsize=AXIS_FSIZE)
    ax1.set_ylabel('Northing (m)', fontsize=AXIS_FSIZE)

    cbar1 = plt.colorbar(vegetation_map, ax=ax1, shrink=0.8)
    cbar1.set_label('Vegetation Height (m)', rotation=270, labelpad=22, fontsize=CBAR_FSIZE)
    cbar1.ax.tick_params(labelsize=TICK_FSIZE)

    # === RIGHT PLOT: DSM (Surface Elevation) ===
    dsm_clean = np.where(np.isnan(dsm_data), np.nanmean(dsm_data), dsm_data)
    terrain_map = ax2.imshow(
        dsm_clean,
        extent=extent,
        origin='upper',
        cmap='terrain',
        vmin=np.nanpercentile(dsm_clean, 5),
        vmax=np.nanpercentile(dsm_clean, 95)
    )
    for outline in aoi_outlines:
        ax2.add_patch(MPLPolygon(outline, fill=False, edgecolor='red', linewidth=3, alpha=0.8))

    apply_poster_axis_format(ax2, tick_size=TICK_FSIZE)
    ax2.set_title(' Digital Surface Model (DSM)\nElevation', fontsize=TITLE_FSIZE, fontweight='bold')
    ax2.set_xlabel('Easting (m)', fontsize=AXIS_FSIZE)
    ax2.set_ylabel('Northing (m)', fontsize=AXIS_FSIZE)

    cbar2 = plt.colorbar(terrain_map, ax=ax2, shrink=0.8)
    cbar2.set_label('Elevation (m)', rotation=270, labelpad=22, fontsize=CBAR_FSIZE)
    cbar2.ax.tick_params(labelsize=TICK_FSIZE)

    # === Overall title and statistics ===
    dataset_name = getattr(dem_manager, 'chm_dataset', 'Unknown')
    fig.suptitle(
        f'OpenTopography {dataset_name} Dataset - High Resolution Terrain & Vegetation',
        fontsize=SUPTITLE_FSIZE, fontweight='bold', y=0.95
    )

    # Report the real pixel size from the CHM transform instead of a fixed value.
    res_x, res_y = abs(chm_transform.a), abs(chm_transform.e)
    vegetated = chm_clean[chm_clean > 0.5]
    mean_vegetation = np.nanmean(vegetated) if vegetated.size else float('nan')

    stats_text = f""" DATASET STATISTICS:
• Resolution: {res_x:.2f}m × {res_y:.2f}m
• CHM Range: {np.nanmin(chm_clean):.1f} - {np.nanmax(chm_clean):.1f} m
• Mean Vegetation: {mean_vegetation:.1f} m  
• DSM Range: {np.nanmin(dsm_clean):.1f} - {np.nanmax(dsm_clean):.1f} m
• Area: {chm_data.shape[0]} × {chm_data.shape[1]} pixels"""

    fig.text(0.02, 0.02, stats_text, fontsize=CAPTION_FSIZE,
             bbox=dict(boxstyle="round,pad=0.5", facecolor="lightblue", alpha=0.8),
             verticalalignment='bottom')

    fig.text(0.98, 0.02, ' AOI Boundary', fontsize=CAPTION_FSIZE,
             bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8),
             horizontalalignment='right', verticalalignment='bottom')

    plt.tight_layout()
    plt.subplots_adjust(top=0.90, bottom=0.15)
    plt.show()

    print("CHM Visualization Complete!")
    print(f"Your AOI has vegetation ranging from {np.nanmin(chm_clean):.1f}m to {np.nanmax(chm_clean):.1f}m")
    print(f"Average tree height: {mean_vegetation:.1f}m")
    print(f"Terrain elevation: {np.nanmin(dsm_clean):.1f}m to {np.nanmax(dsm_clean):.1f}m")


def _classify_vegetation_density(chm_data):
    """Bin canopy heights into classes 0-4 (<1, 1-5, 5-15, 15-30, >=30 m); NaN pixels stay NaN."""
    heights = np.asarray(chm_data, dtype=float)
    classes = np.full(heights.shape, np.nan)
    # NaN compares False everywhere, so no-data pixels are never assigned a class.
    classes[heights < 1] = 0
    classes[(heights >= 1) & (heights < 5)] = 1
    classes[(heights >= 5) & (heights < 15)] = 2
    classes[(heights >= 15) & (heights < 30)] = 3
    classes[heights >= 30] = 4
    return classes


def create_vegetation_density_map():
    """
    Create a simplified vegetation density visualization.

    Classifies the CHM held by the global ``dem_manager`` into five height
    classes, plots them with a light-to-dark green ramp, and prints the share of
    each class. Percentages are computed over pixels that have CHM data; no-data
    (NaN) pixels are left transparent and excluded.
    """
    dem_manager = globals().get('dem_manager')
    if getattr(dem_manager, 'chm_array', None) is None:
        print("No CHM data available for density map")
        return

    print("🌳 Creating vegetation density map...")

    density_map = _classify_vegetation_density(dem_manager.chm_array)

    fig, ax = plt.subplots(figsize=(12, 10))
    height, width = density_map.shape
    west, north = dem_manager.chm_transform * (0, 0)
    east, south = dem_manager.chm_transform * (width, height)
    extent = [west, east, south, north]

    # Monotonic light-to-dark ramp: named 'forestgreen' is lighter than 'green',
    # which is lighter than 'darkgreen', so taller canopy is always darker.
    density_cmap = colors.ListedColormap(
        ['lightgray', 'lightgreen', 'forestgreen', 'green', 'darkgreen']
    )

    # vmin/vmax of -0.5/4.5 centre each integer class on its colour band and tick.
    im = ax.imshow(density_map, extent=extent, origin='upper',
                   cmap=density_cmap, vmin=-0.5, vmax=4.5)

    apply_poster_axis_format(ax, tick_size=TICK_FSIZE)

    ax.set_title('Vegetation Density Classification', fontsize=TITLE_FSIZE, fontweight='bold')
    ax.set_xlabel('Easting (m)', fontsize=AXIS_FSIZE)
    ax.set_ylabel('Northing (m)', fontsize=AXIS_FSIZE)

    cbar = plt.colorbar(im, ax=ax, ticks=[0, 1, 2, 3, 4])
    cbar.ax.set_yticklabels(['None\n(0-1m)', 'Low\n(1-5m)', 'Medium\n(5-15m)',
                             'High\n(15-30m)', 'Very High\n(30m+)'])
    cbar.set_label('Vegetation Density', rotation=270, labelpad=22, fontsize=CBAR_FSIZE)
    cbar.ax.tick_params(labelsize=TICK_FSIZE)

    plt.tight_layout()
    plt.show()

    valid_pixels = int(np.isfinite(density_map).sum())
    if valid_pixels == 0:
        print("No valid CHM pixels to summarise.")
        return
    for i, category in enumerate(['None', 'Low', 'Medium', 'High', 'Very High']):
        count = int(np.sum(density_map == i))
        percentage = (count / valid_pixels) * 100
        print(f"🌿 {category:10}: {percentage:5.1f}% ({count:,} pixels)")

# =================================================================
# RUN THE VISUALIZATIONS
# =================================================================
print("Creating CHM visualizations for your AOI...")
print("=" * 60)
visualize_chm_over_aoi()

print("\n" + "=" * 60)
create_vegetation_density_map()

print(
    "\nAll visualizations complete!\n"
    "The red boundary shows your AOI, with high-resolution vegetation data from OpenTopography\n"
    "Left plot: Vegetation height (CHM) | 🗻 Right plot: Surface elevation (DSM)"
)

In [None]:
pip install tqdm

In [None]:
# =================================================================
# cell.4.GATEWAY PRECOMPUTATION - WITH PARALLEL TQDM PROGRESS BAR (NO OVERRUN)
# =================================================================

import pandas as pd
import numpy as np
from geopy.distance import geodesic
import matplotlib.pyplot as plt

from joblib import Parallel, delayed
from tqdm.notebook import tqdm
import joblib
# ==== Precomputed artifacts written by Cell 4 and reused by Phase 1 ====
# Create each placeholder only when it is missing, so re-running this cell never
# discards artifacts a previous run already computed.
for _precomp_name in ('PRECOMP_COVERAGE_GRID', 'PRECOMP_COVERAGE_MASKS', 'PRECOMP_PAIRWISE_DIST'):
    globals().setdefault(_precomp_name, None)
del _precomp_name


# --- Helper functions ---
def build_pairwise_dist_matrix(df_or_coords) -> np.ndarray:
    """
    Symmetric matrix of geodesic distances (metres) between every pair of points.

    Args:
        df_or_coords: a DataFrame with a 'coord' column of (lon, lat) tuples,
            or any iterable of (lon, lat) tuples.

    Returns:
        float ndarray of shape (N, N) with zeros on the diagonal.
    """
    points = (df_or_coords['coord'].tolist()
              if isinstance(df_or_coords, pd.DataFrame)
              else list(df_or_coords))

    size = len(points)
    dist = np.zeros((size, size), dtype=float)
    for row, (lon_a, lat_a) in enumerate(points):
        # Only the upper triangle is computed; the value is mirrored below the diagonal.
        for col in range(row + 1, size):
            lon_b, lat_b = points[col]
            meters = geodesic((lat_a, lon_a), (lat_b, lon_b)).meters  # geopy wants (lat, lon)
            dist[row, col] = meters
            dist[col, row] = meters
    return dist

def coverage_mask_for_candidate(
    coord, coverage_grid, dem_manager,
    tx_h, rx_h, freq_mhz, veg_threshold, attn_per_meter,
    max_comm_range, max_path_loss_db, grid_idxs=None
):
    """
    Boolean mask of the coverage-grid points reachable from a gateway at ``coord``.

    A grid point is covered when its geodesic distance from ``coord`` does not
    exceed ``max_comm_range``, and ``calculate_path_loss_pycraf`` returns a finite
    loss no larger than ``max_path_loss_db``.

    Args:
        coord: gateway position as (lon, lat).
        coverage_grid: indexable sequence of (lon, lat) grid points.
        dem_manager, tx_h, rx_h, freq_mhz, veg_threshold, attn_per_meter,
        max_comm_range: forwarded to ``calculate_path_loss_pycraf``.
        max_path_loss_db: maximum acceptable path loss in dB.
        grid_idxs: optional subset of grid indices to evaluate; all others stay False.

    Returns:
        numpy bool array of length ``len(coverage_grid)``.
    """
    covered = np.zeros(len(coverage_grid), dtype=bool)
    candidate_indices = grid_idxs if grid_idxs is not None else range(len(coverage_grid))

    for idx in candidate_indices:
        point = coverage_grid[idx]
        try:
            # Cheap range gate before the expensive propagation model (geopy wants lat, lon).
            if geodesic(coord[::-1], point[::-1]).meters > max_comm_range:
                continue

            loss_db = calculate_path_loss_pycraf(
                coord, point, dem_manager,
                tx_h, rx_h, freq_mhz, veg_threshold, attn_per_meter, max_comm_range
            )
            if np.isfinite(loss_db) and loss_db <= max_path_loss_db:
                covered[idx] = True
        except Exception:
            # Best effort: one failing ray must not abort the whole job.
            continue
    return covered

def get_slope_safe(coord, dem_manager):
    """Get slope at coordinate in degrees.

    Uses ``dem_manager.get_slope_at_coord`` when available. Otherwise it samples
    ``dem_manager.slope_arr`` at the pixel containing ``coord``, located through
    the inverse of ``dem_manager.transform``.

    Args:
        coord: (x, y) in the DEM's CRS (presumably (lon, lat) in this notebook).
        dem_manager: provides ``get_slope_at_coord``, or ``transform`` + ``slope_arr``.

    Returns:
        Slope in degrees, or 10.0 (a moderate default) when the point lies outside
        the raster, no slope grid exists, or the lookup raises.
    """
    try:
        if hasattr(dem_manager, "get_slope_at_coord"):
            return dem_manager.get_slope_at_coord(coord)

        col, row = ~dem_manager.transform * coord
        # floor, not int(): int() truncates toward zero, so points up to one
        # pixel left of / above the raster would wrongly map to pixel 0.
        row, col = int(np.floor(row)), int(np.floor(col))
        slope_arr = dem_manager.slope_arr
        if (slope_arr is not None and
                0 <= row < slope_arr.shape[0] and
                0 <= col < slope_arr.shape[1]):
            return float(slope_arr[row, col])
        return 10.0  # Default moderate slope
    except Exception:
        return 10.0

def get_aspect_safe(coord, dem_manager):
    """Get aspect (compass direction the slope faces, i.e. downslope) in degrees 0-360.

    0 = north, 90 = east, 180 = south, 270 = west. Uses
    ``dem_manager.get_aspect_at_coord`` when available. Otherwise it takes
    central differences on ``dem_manager.dem_array`` around the pixel containing
    ``coord``.

    Assumes a north-up raster (row index grows southward) and equal x/y pixel
    size ``dem_manager.resolution`` — TODO confirm for non-square geographic grids.

    Returns:
        Aspect in degrees, or 180.0 (south-facing default) at the raster border,
        when no DEM is available, or when the lookup raises.
    """
    try:
        if hasattr(dem_manager, "get_aspect_at_coord"):
            return dem_manager.get_aspect_at_coord(coord)

        col, row = ~dem_manager.transform * coord
        row, col = int(np.floor(row)), int(np.floor(col))

        dem = getattr(dem_manager, 'dem_array', None)
        if dem is not None and 1 <= row < dem.shape[0] - 1 and 1 <= col < dem.shape[1] - 1:
            res = dem_manager.resolution
            dz_dx = (dem[row, col + 1] - dem[row, col - 1]) / (2 * res)  # rise toward east
            dz_dy = (dem[row - 1, col] - dem[row + 1, col]) / (2 * res)  # rise toward north
            # The downslope vector is (-dz_dx, -dz_dy) in (east, north); its compass
            # bearing is atan2(east, north). Using +dz_dy would swap north/south faces.
            aspect = float(np.degrees(np.arctan2(-dz_dx, -dz_dy)))
            if aspect < 0:
                aspect += 360.0
            return aspect
        return 180.0  # Default south-facing
    except Exception:
        return 180.0

def _tobler_accessibility(slope, elevation):
    """Accessibility score in [0.1, 1] from Tobler hiking speed, an altitude penalty and a seasonal factor."""
    # Tobler's hiking function (km/h); slope capped at 70° to avoid extreme tan() values.
    slope_rad = np.radians(min(slope, 70.0))
    speed = 6.0 * np.exp(-3.5 * abs(np.tan(slope_rad) + 0.05))
    hiking_speed = max(0.5, min(5.0, speed))

    elevation_penalty = max(0.7, 1.0 - (elevation / 10000.0))

    if elevation > 2500:
        seasonal_factor = 0.5
    elif elevation > 2000:
        seasonal_factor = 0.75
    else:
        seasonal_factor = 1.0

    accessibility = (hiking_speed / 5.0) * elevation_penalty * seasonal_factor
    return max(0.1, min(1.0, accessibility))


def _installation_cost_score(slope, elevation):
    """Normalised installation-cost score in [0.1, 1] (1 = cheapest) from transport + foundation costs."""
    base_cost = 1000.0

    if slope < 45:
        # Drone logistics: 3 trips (simplified, ~80 kg of equipment), up to +50% for weather at altitude.
        drone_trips = 3.0
        drone_cost_per_trip = 150.0
        weather_factor = 1.0 + min(elevation / 3000.0, 1.0) * 0.5
        transport_cost = drone_trips * drone_cost_per_trip * weather_factor
    else:
        transport_cost = 2000.0  # fixed high cost for manual transport

    wind_factor = 1.0 + min(elevation / 5000.0, 1.0)
    foundation_cost = 2000.0 * (1.0 + min(slope / 30.0, 2.0)) * wind_factor

    total_cost = base_cost + transport_cost + foundation_cost
    min_possible_cost = 1000.0
    max_possible_cost = 10000.0
    cost_score = 1.0 - ((total_cost - min_possible_cost) / (max_possible_cost - min_possible_cost))
    return max(0.1, min(1.0, cost_score))


def _terrain_suitability_score(slope, elevation, aspect):
    """Terrain suitability in [0.1, 1]: slope stability, simplified avalanche risk, wind exposure."""
    if slope > 35:
        suitability = 0.3
    elif slope > 25:
        suitability = 0.7
    elif slope > 15:
        suitability = 1.0
    elif slope > 2:
        suitability = 0.9
    else:  # Too flat
        suitability = 0.8

    # Avalanche risk (simplified): steep, high, north-facing slopes.
    if elevation > 2000 and 30 <= slope <= 45:
        if aspect <= 45 or aspect >= 315:
            suitability *= 0.6

    # Wind exposure
    if elevation > 2500:
        suitability *= 0.8
    elif elevation > 2000:
        suitability *= 0.9

    return max(0.1, min(1.0, suitability))


def _valley_coverage_score(slope, elevation, min_elevation, max_elevation):
    """Valley-coverage score in [0.1, 1] from relative relief position (70%) and slope (30%)."""
    if max_elevation > min_elevation:
        relief_position = (elevation - min_elevation) / (max_elevation - min_elevation)
    else:
        relief_position = 0.5

    if 0.2 <= relief_position <= 0.4:
        position_score = 1.0
    elif relief_position < 0.2:
        position_score = 0.5 + 2.5 * relief_position
    elif relief_position <= 0.6:
        position_score = 1.0 - (relief_position - 0.4) * 2.5
    else:
        position_score = max(0.1, 0.5 * (1.0 - relief_position))

    if 10 <= slope <= 20:
        slope_coverage = 1.0
    elif slope < 10:
        slope_coverage = 0.7 + slope * 0.03
    else:
        slope_coverage = max(0.4, 1.0 - (slope - 20) * 0.02)

    valley_coverage = position_score * 0.7 + slope_coverage * 0.3
    return max(0.1, min(1.0, valley_coverage))


def calculate_enhanced_objectives(coord, dem_manager,
                                  min_elevation: float = 100.0,
                                  max_elevation: float = 2800.0) -> dict:
    """
    Physics-based objective calculations for gateway placement.

    Args:
        coord: candidate location, passed to ``dem_manager.get_elevation`` and the
            slope/aspect helpers.
        dem_manager: provides ``get_elevation`` plus whatever ``get_slope_safe`` /
            ``get_aspect_safe`` use.
        min_elevation, max_elevation: elevation range of the study area (metres),
            used to place the site in the local relief. Defaults match the
            previously hard-coded values; set them to your actual terrain.

    Returns:
        dict with 'accessibility', 'installation_cost' (a score, 1 = cheapest),
        'terrain_suitability' and 'valley_coverage', each in [0.1, 1]. On error, or
        when no valid elevation exists, all four are the neutral value 0.5.
    """
    try:
        elevation = float(dem_manager.get_elevation(coord))
        slope = get_slope_safe(coord, dem_manager)
        aspect = get_aspect_safe(coord, dem_manager)

        if not np.isfinite(elevation):
            print(f"⚠️ Enhanced objectives: no valid elevation at {coord}; using neutral scores")
            return {
                'accessibility': 0.5,
                'installation_cost': 0.5,
                'terrain_suitability': 0.5,
                'valley_coverage': 0.5
            }
        # NaN compares False against every threshold below, which would otherwise
        # give arbitrary scores (e.g. maximal accessibility and cost scores).
        # Use the same defaults as the safe slope/aspect lookups.
        if not np.isfinite(slope):
            slope = 10.0
        if not np.isfinite(aspect):
            aspect = 180.0

        return {
            'accessibility': _tobler_accessibility(slope, elevation),
            'installation_cost': _installation_cost_score(slope, elevation),
            'terrain_suitability': _terrain_suitability_score(slope, elevation, aspect),
            'valley_coverage': _valley_coverage_score(slope, elevation, min_elevation, max_elevation)
        }

    except Exception as e:
        print(f"⚠️ Enhanced objectives failed for {coord}: {e}")
        import traceback
        traceback.print_exc()
        # Neutral middle values so a failed site neither wins nor loses by default.
        return {
            'accessibility': 0.5,
            'installation_cost': 0.5,
            'terrain_suitability': 0.5,
            'valley_coverage': 0.5
        }


def _slope_for_filter(dem_manager, coord, default_slope=10.0):
    """Slope at `coord`: DEM helper if available, else raster lookup, else `default_slope`."""
    if hasattr(dem_manager, "get_slope_at_coord"):
        return dem_manager.get_slope_at_coord(coord)
    # A missing slope raster is treated like "no slope info", not as a per-candidate error
    # (an AttributeError here used to discard every candidate silently).
    slope_arr = getattr(dem_manager, "slope_arr", None)
    if slope_arr is None:
        return default_slope
    col, row = ~dem_manager.transform * coord
    # floor(), not int(): int() truncates toward zero, so a point just left of / above the
    # raster (e.g. col = -0.4) would wrongly map to pixel 0 instead of being out of bounds.
    row, col = int(np.floor(row)), int(np.floor(col))
    if 0 <= row < slope_arr.shape[0] and 0 <= col < slope_arr.shape[1]:
        return slope_arr[row, col]
    return default_slope


def filter_candidates_by_criteria(candidates, dem_manager, min_elev=100, max_elev=2800, max_slope=25,
                                  default_slope=10.0):
    """
    Keep candidate coordinates that pass simple elevation/slope accessibility limits.

    A candidate is kept when `min_elev <= elevation <= max_elev` and `slope <= max_slope`.
    NaN elevation or slope fails these comparisons, so such candidates are dropped.

    Parameters
    ----------
    candidates : sized iterable of coordinates (lon, lat order presumably; TODO confirm CRS)
    dem_manager : object with `get_elevation(coord)` and either `get_slope_at_coord(coord)`
        or an affine `transform` plus an optional `slope_arr` raster.
    min_elev, max_elev : elevation bounds (DEM units).
    max_slope : maximum allowed slope (units of the slope source).
    default_slope : slope assumed when no slope value is available for a point
        (no slope raster, or the point falls outside it).

    Returns
    -------
    list
        Accepted coordinates in input order. Candidates whose DEM lookup raises are skipped
        and counted in the printed summary.
    """
    filtered = []
    n_failed = 0
    for coord in candidates:
        try:
            elev = dem_manager.get_elevation(coord)
            slope = _slope_for_filter(dem_manager, coord, default_slope)
            if min_elev <= elev <= max_elev and slope <= max_slope:
                filtered.append(coord)
        except Exception:
            # Best effort: a failing candidate is skipped, but no longer invisibly.
            n_failed += 1
    print(f"Filtered {len(filtered)}/{len(candidates)} candidates as human-accessible.")
    if n_failed:
        print(f"   ({n_failed} candidates skipped because the DEM/slope lookup raised)")
    return filtered

def coverage_score_for_candidate_lorawan(
    coord, coverage_grid, dem_manager,
    tx_h, rx_h, freq_mhz, veg_threshold, attn_per_meter, max_comm_range, max_path_loss_db
):
    """
    Percent of `coverage_grid` points a gateway at `coord` can serve.

    A grid point counts as covered when it lies within `max_comm_range` metres
    (geodesic) and calculate_path_loss_pycraf returns a finite loss no greater than
    `max_path_loss_db`. Out-of-range points are classed unreachable without a
    path-loss call; infinite losses are unreachable too and are excluded from the
    loss statistics. A point whose evaluation raises is counted as an error and skipped.

    Coordinates are (lon, lat) and are reversed to (lat, lon) for geopy.

    Prints a per-gateway summary when coverage is below 100 % or at most 20 finite
    losses were seen, and an isolation notice when no finite loss was found.

    Returns
    -------
    float
        Coverage percentage in [0, 100]; 0.0 for an empty grid.
    """
    if not coverage_grid:
        return 0.0

    n_total = len(coverage_grid)
    budget_db = max_path_loss_db
    finite_losses = []
    n_covered = 0
    n_unreachable = 0
    n_beyond = 0
    n_errors = 0

    for target in coverage_grid:
        try:
            if geodesic(coord[::-1], target[::-1]).meters > max_comm_range:
                # Too far for the link model: unreachable, and kept out of the statistics.
                n_beyond += 1
                n_unreachable += 1
                continue
            loss_db = calculate_path_loss_pycraf(
                coord, target, dem_manager,
                tx_h, rx_h, freq_mhz, veg_threshold, attn_per_meter, max_comm_range
            )
            if not np.isfinite(loss_db):
                n_unreachable += 1
                continue
            finite_losses.append(loss_db)
            if loss_db <= budget_db:
                n_covered += 1
        except Exception:
            n_errors += 1

    pct = (n_covered / n_total) * 100.0 if n_total > 0 else 0.0

    if not finite_losses:
        # Every point was out of range, infinite, or errored.
        print(f"Gateway {coord}: {pct:.1f}% coverage")
        print("   No points with finite path loss - gateway may be isolated")
        print(f"   Unreachable: {n_unreachable}/{n_total} points")
        return pct

    if pct >= 100 and len(finite_losses) > 20:
        # Fully covered with plenty of samples: nothing worth reporting.
        return pct

    mean_db = np.mean(finite_losses)
    lo_db = np.min(finite_losses)
    hi_db = np.max(finite_losses)
    margin_db = budget_db - mean_db
    uses_dsm = hasattr(dem_manager, 'dsm_array') and dem_manager.dsm_array is not None
    source_label = "DSM+CHM" if uses_dsm else "DEM only"

    print(f"🎯 Gateway {coord}: {pct:.1f}% coverage")
    if n_unreachable > 0:
        # Statistics cover finite losses only; report how many points were excluded.
        print(f"   Path loss (finite): {lo_db:.0f}-{hi_db:.0f} dB (avg: {mean_db:.0f} dB)")
        print(f"   Unreachable points: {n_unreachable}/{n_total} ({100*n_unreachable/n_total:.1f}%)")
    else:
        print(f"   Path loss: {lo_db:.0f}-{hi_db:.0f} dB (avg: {mean_db:.0f} dB)")
    print(f"   Link margin: {margin_db:.0f} dB (budget: {budget_db:.0f} dB)")
    print(f"   Data source: {source_label}")
    if n_beyond > 0:
        print(f"   Beyond range: {n_beyond} points > {max_comm_range}m")
    if n_errors > 0:
        print(f"   Errors: {n_errors}")
    n_reachable = sum(1 for v in finite_losses if v <= budget_db)
    print(f"  Reachable: {n_reachable}/{len(finite_losses)} points with finite loss")

    return pct

# --- MAIN FUNCTION: Precompute Gateway Objectives with Parallel Progress ---

# Solar-model outputs copied from solar_features_for_coord; read with [] so a missing key
# sends the candidate to the fallback row.
_GATEWAY_SOLAR_KEYS = (
    'pv_wh_day_mean', 'pv_wh_year', 'load_wh_day', 'energy_margin_wh_day',
    'deficit_days', 'tilt_deg', 'azimuth_deg', 'pv_loss_fraction',
)
# Optional solar diagnostics; read with .get() so absent keys become None.
_GATEWAY_SOLAR_DIAG_KEYS = (
    'svf', 'mean_horizon_deg', 'cloud_cover_mean_pct', 'cloudy_hours_pct',
    'blocked_sun_hours', 'blocked_sun_hours_pct', 'pv_p10_wh_day', 'longest_deficit_run_days',
)
# Mid-range terrain objectives for failed candidates (the same values
# calculate_enhanced_objectives returns on its own failure path).
_GATEWAY_NEUTRAL_TERRAIN = {
    'accessibility': 0.5,
    'installation_cost': 0.5,
    'terrain_suitability': 0.5,
    'valley_coverage': 0.5,
}


def _gateway_candidate_row(i, coord, dem_manager, coverage_pct, centrality):
    """Full objective row for one candidate; any DEM/solar/terrain failure propagates."""
    elevation = dem_manager.get_elevation(coord)
    sf = solar_features_for_coord(coord, dem_manager)
    solar_score = solar_score_from_energy_margin(sf['energy_margin_wh_day'], scale=50.0)
    enhanced_obj = calculate_enhanced_objectives(coord, dem_manager)
    row = {
        'candidate_index': i,
        'coord': coord,
        'lon': coord[0],
        'lat': coord[1],
        'elevation': elevation,
        'solar': solar_score,
        'coverage': coverage_pct,
        'centrality': centrality,
    }
    row.update({key: sf[key] for key in _GATEWAY_SOLAR_KEYS})
    row.update({key: sf.get(key) for key in _GATEWAY_SOLAR_DIAG_KEYS})
    row.update(enhanced_obj)
    return row


def _gateway_fallback_row(i, coord, coverage_pct, centrality):
    """Placeholder row for a candidate whose DEM/solar/terrain evaluation raised.

    Coverage (from the masks) and centrality (pure geometry) do not depend on the failed
    step, so their real values are kept. Solar/diagnostic columns are NaN, elevation and
    solar are 0.0, and the terrain objectives are mid-range.
    """
    row = {
        'candidate_index': i,
        'coord': coord,
        'lon': coord[0],
        'lat': coord[1],
        'elevation': 0.0,
        'solar': 0.0,
        'coverage': coverage_pct,
        'centrality': centrality,
    }
    row.update(dict.fromkeys(_GATEWAY_SOLAR_KEYS + _GATEWAY_SOLAR_DIAG_KEYS, np.nan))
    row.update(_GATEWAY_NEUTRAL_TERRAIN)
    return row


def _report_gateway_solar_diagnostics(df):
    """Print solar/diagnostic summaries and draw PV sanity plots.

    Each plot gets its own figure, and plotting errors are reported rather than raised.
    That way a diagnostic failure cannot discard the computed candidate table.
    """
    solar_cols = ['pv_wh_day_mean', 'load_wh_day', 'energy_margin_wh_day', 'deficit_days']
    print("Solar sanity:")
    present = [c for c in solar_cols if c in df.columns]
    if present:
        print(df[present].describe())
    else:
        print("No solar columns present (all candidates failed before PV calc?).")

    sample_cols = [c for c in ['lon', 'lat', *solar_cols, 'solar'] if c in df.columns]
    print("Sample rows:")
    print(df[sample_cols].head(5))

    diag_cols = [
        'cloud_cover_mean_pct', 'cloudy_hours_pct',
        'blocked_sun_hours', 'blocked_sun_hours_pct',
        'svf', 'mean_horizon_deg'
    ]
    present = [c for c in diag_cols if c in df.columns]
    print("\nDiagnostics summary:")
    if present:
        print(df[present].describe())
    else:
        print("No diagnostic columns found – did you replace both `_pv_dc_daily_wh` and `solar_features_for_coord`?")

    # Quick sanity plots (feel free to comment out). Optional diagnostics may be None,
    # so everything is coerced to numeric first.
    try:
        pv = pd.to_numeric(df['pv_wh_day_mean'], errors='coerce')
        for x_col, title in (
            ('cloud_cover_mean_pct', 'PV vs Mean Cloud Cover (%)'),
            ('blocked_sun_hours_pct', 'PV vs Blocked Sun Hours (%)'),
        ):
            fig, ax = plt.subplots(figsize=(6, 4))
            ax.scatter(pd.to_numeric(df[x_col], errors='coerce'), pv, alpha=0.6)
            ax.set(title=title, xlabel=x_col, ylabel='pv_wh_day_mean')
            plt.show()

        # Raw physical margin on its own axes (it used to be drawn over the last scatter).
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.hist(pd.to_numeric(df['energy_margin_wh_day'], errors='coerce').dropna(), bins=30)
        ax.set(xlabel='Energy margin (Wh/day)', ylabel='Count')
        plt.show()
    except Exception as e:
        print(f" Diagnostic plotting failed: {e}")


def _plot_gateway_candidate_maps(df):
    """Best-effort lon/lat scatter of candidates coloured by coverage and elevation."""
    # NOTE(review): df['coverage'] may already be rescaled by normalize_df_columns;
    # the "Coverage (%)" label assumes raw percent - confirm.
    try:
        fig, axes = plt.subplots(1, 2, figsize=(15, 6))
        panels = (
            (axes[0], 'coverage', 'viridis', 'Gateway Candidates - Coverage Score', 'Coverage (%)'),
            (axes[1], 'elevation', 'terrain', 'Gateway Candidates - Elevation', 'Elevation (m)'),
        )
        for ax, col, cmap_name, title, label in panels:
            if col in df.columns and len(df) > 0:
                sc = ax.scatter(df['lon'], df['lat'], c=df[col], cmap=cmap_name, s=12, alpha=0.7)
                ax.set_title(title)
                ax.set_xlabel('Longitude'); ax.set_ylabel('Latitude')
                plt.colorbar(sc, ax=ax, label=label)
        plt.tight_layout(); plt.show()
    except Exception as e:
        print(f" Plotting failed: {e}")


def precompute_gateway_objectives_enhanced(
    aoi_poly, dem_manager,
    tx_h, rx_h, freq_mhz, veg_threshold, attn_per_meter, max_comm_range, max_path_loss_db,
    num_candidates=1000, export_csv=True, csv_path="gateway_candidates.csv"
) -> pd.DataFrame:
    """
    Sample, filter and score gateway candidates; build coverage masks in parallel.

    Steps:
      1. Sample `num_candidates` points in `aoi_poly` and keep the accessible ones.
      2. Build a boolean coverage mask per candidate over the coverage grid (joblib/loky,
         with a tqdm bar). Coverage % is derived from those masks.
      3. Per candidate, compute elevation, solar features and terrain objectives. A
         candidate whose evaluation raises gets a fallback row (NaN solar columns) that
         still keeps its real coverage and centrality.
      4. Print and plot diagnostics, normalise the objective columns, and optionally
         export a CSV.
      5. Write PRECOMP_COVERAGE_GRID / PRECOMP_COVERAGE_MASKS / PRECOMP_PAIRWISE_DIST
         globals for reuse by the Phase 1 optimizer.

    NOTE(review): coverage_mask_for_candidate and build_pairwise_dist_matrix are defined
    in the Phase 1 cell further down. On a fresh kernel they must already exist when this
    runs, otherwise the outer handler returns an empty DataFrame. Confirm they are defined
    earlier in the notebook.

    Returns
    -------
    pd.DataFrame
        One row per accepted candidate, or an empty DataFrame on failure.
    """
    global PRECOMP_COVERAGE_GRID, PRECOMP_COVERAGE_MASKS, PRECOMP_PAIRWISE_DIST
    import joblib
    from tqdm.notebook import tqdm
    from joblib import Parallel, delayed

    class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
        # Patched into joblib so every completed batch advances the attached tqdm bar.
        tqdm_bar = None
        def __call__(self, *args, **kwargs):
            if self.tqdm_bar is not None:
                self.tqdm_bar.update(n=self.batch_size)
            return super().__call__(*args, **kwargs)

    old_callback = joblib.parallel.BatchCompletionCallBack
    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback

    try:
        # 1) Candidates
        candidate_points = sample_points_in_poly(aoi_poly, num_candidates)
        candidates = list(filter_candidates_by_criteria(candidate_points, dem_manager))
        if not candidates:
            print("No viable candidates after filtering!")
            return pd.DataFrame()

        # 2) Coverage grid
        coverage_grid = generate_coverage_grid(aoi_poly, OPTIMIZATION.COVERAGE_GRID_SIZE)

        # 3) Boolean coverage masks in parallel (one pass)
        print(f"Building coverage masks for {len(candidates)} candidates …")
        def _mask_job(coord):
            return coverage_mask_for_candidate(
                coord, coverage_grid, dem_manager,
                tx_h, rx_h, freq_mhz, veg_threshold, attn_per_meter,
                max_comm_range, max_path_loss_db, grid_idxs=None
            )

        with tqdm(total=len(candidates), desc="Mask jobs", unit="cand") as bar:
            TqdmBatchCompletionCallback.tqdm_bar = bar
            try:
                masks = Parallel(n_jobs=-1, backend='loky', prefer="processes")(
                    delayed(_mask_job)(coord) for coord in candidates
                )
            finally:
                # Never leave a closed bar attached to the class if Parallel raises.
                TqdmBatchCompletionCallback.tqdm_bar = None
        masks = np.vstack(masks).astype(bool)

        # 4) Coverage % directly from the masks (no second pass)
        coverage_scores = masks.mean(axis=1) * 100.0

        # 5) Centrality: 1 at the AOI centroid, 0 for the farthest candidate.
        aoi_centroid = (aoi_poly.centroid.x, aoi_poly.centroid.y)
        distances_to_center = [geodesic(c[::-1], aoi_centroid[::-1]).meters for c in candidates]
        # `or 1.0` also covers every candidate sitting exactly on the centroid, which
        # previously divided by zero and pushed every candidate into the fallback row.
        max_dist_from_center = max(distances_to_center, default=0.0) or 1.0
        centralities = [1 - d / max_dist_from_center for d in distances_to_center]

        objective_data = []
        for i, coord in enumerate(candidates):
            coverage_pct = float(coverage_scores[i])
            try:
                objective_data.append(
                    _gateway_candidate_row(i, coord, dem_manager, coverage_pct, centralities[i])
                )
            except Exception as e:
                print(f"Objective calculation failed for candidate {i}: {e}")
                objective_data.append(_gateway_fallback_row(i, coord, coverage_pct, centralities[i]))

        df = pd.DataFrame(objective_data)
        if df.empty:
            print(" No valid candidates processed!")
            return df

        _report_gateway_solar_diagnostics(df)

        # Normalised (0..1) copy of the physical energy margin used by optimizers/scoring.
        margin = pd.to_numeric(df['energy_margin_wh_day'], errors='coerce')
        span = float(margin.max() - margin.min())
        df['energy_margin_norm'] = (margin - margin.min()) / (span if span and np.isfinite(span) else 1.0)
        # Keep the previously computed df['solar']; only fall back when it is entirely missing.
        if 'solar' not in df or df['solar'].isna().all():
            df['solar'] = df['energy_margin_norm']

        # Normalize and export
        gateway_objectives = [
            'coverage', 'solar', 'elevation', 'accessibility', 'installation_cost',
            'terrain_suitability', 'valley_coverage', 'centrality'
        ]
        df = normalize_df_columns(df, gateway_objectives)
        if export_csv:
            df.to_csv(csv_path, index=False)
            print(f" Candidate DataFrame exported to {csv_path}")

        # 6) Precomputed artifacts for Phase 1 reuse. Mask/distance rows follow `candidates`
        # order (assumes normalize_df_columns keeps row order - confirm).
        PRECOMP_COVERAGE_GRID = coverage_grid
        PRECOMP_COVERAGE_MASKS = masks
        PRECOMP_PAIRWISE_DIST = build_pairwise_dist_matrix(pd.DataFrame({'coord': candidates}))
        print(f" Saved PRECOMP_*: grid={len(PRECOMP_COVERAGE_GRID)} cols, masks={PRECOMP_COVERAGE_MASKS.shape}, dist={PRECOMP_PAIRWISE_DIST.shape}")

        _plot_gateway_candidate_maps(df)
        return df

    except Exception as e:
        print(f" CRITICAL ERROR in precomputation: {type(e).__name__}: {str(e)}")
        import traceback; traceback.print_exc()
        return pd.DataFrame()
    finally:
        joblib.parallel.BatchCompletionCallBack = old_callback

# =================================================================
# USAGE EXAMPLE (for batch/production)
# =================================================================
# The link-model constants below (915.0 / 2.0 / 0.0) are also hard-coded in
# ensure_precomputed_masks in the Phase 1 cell; keep both in sync so the two mask builds agree.
gateway_candidate_df = precompute_gateway_objectives_enhanced(
    aoi_poly=aoi_poly,
    dem_manager=dem_manager,
    tx_h=NETWORK.GATEWAY_HEIGHT,
    rx_h=NETWORK.SENSOR_HEIGHT,
    freq_mhz=915.0,
    veg_threshold=2.0,
    attn_per_meter=0.0,
    max_comm_range=NETWORK.MAX_COMM_RANGE_M,
    max_path_loss_db=OPTIMIZATION.MAX_ALLOWABLE_PATH_LOSS_DB,
    num_candidates=300,
    export_csv=True,
)

In [None]:
# --- Gateway Candidate Interactive Map & Histogram Visualization ---

from ipyleaflet import Map, CircleMarker, basemaps
from ipywidgets import HTML, VBox, Output
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np

# ----- Choose which objective to color by -----
objective_for_color = 'solar'  # Can set to 'coverage', 'elevation', etc.

# The precomputation cell returns an empty DataFrame on failure; say so clearly here
# instead of failing with a KeyError deep inside the plotting code.
if gateway_candidate_df is None or gateway_candidate_df.empty:
    raise RuntimeError("gateway_candidate_df is empty - re-run the gateway precomputation cell first.")
if objective_for_color not in gateway_candidate_df.columns:
    raise KeyError(f"'{objective_for_color}' is not a column of gateway_candidate_df")

# --- Robust Data/Bounds ---
# Coerce to float so None and NaN are handled the same way. The colour scale and the
# histogram use finite values only.
scores = np.asarray(gateway_candidate_df[objective_for_color], dtype=float)
finite_scores = scores[np.isfinite(scores)]
if finite_scores.size:
    norm = mcolors.Normalize(vmin=finite_scores.min(), vmax=finite_scores.max())
else:
    norm = mcolors.Normalize(vmin=0.0, vmax=1.0)
cmap = plt.cm.plasma
# The colormap's "bad" colour becomes black after to_hex, which is easy to mistake for
# a low plasma value, so missing scores are drawn grey instead.
MISSING_SCORE_COLOR = '#808080'

# --- Setup Map ---
m = Map(
    center=[aoi_poly.centroid.y, aoi_poly.centroid.x],
    zoom=13,
    basemap=basemaps.Esri.WorldImagery
)

def safe_fmt(val, fmt):
    """Format `val` with `fmt`; return "NA" for None, any NaN (incl. numpy float32) or non-numeric."""
    try:
        num = float(val)
        if np.isnan(num):
            return "NA"
        return format(num, fmt)
    except Exception:
        return "NA"

# --- Add Candidates as Markers ---
for rec, val in zip(gateway_candidate_df.to_dict('records'), scores):
    color = mcolors.to_hex(cmap(norm(val))) if np.isfinite(val) else MISSING_SCORE_COLOR
    popup_html = (
        f"<b>Solar:</b> {safe_fmt(rec.get('solar'), '.2f')}<br>"
        f"<b>PV p10 (Wh/day):</b> {safe_fmt(rec.get('pv_p10_wh_day'), '.1f')}<br>"
        f"<b>Load (Wh/day):</b> {safe_fmt(rec.get('load_wh_day'), '.1f')}<br>"
        f"<b>Longest deficit run:</b> {safe_fmt(rec.get('longest_deficit_run_days'), '.0f')} d<br>"
        f"<b>Coverage:</b> {safe_fmt(rec.get('coverage'), '.1f')}<br>"
        f"<b>Elevation:</b> {safe_fmt(rec.get('elevation'), '.1f')}<br>"
        f"<b>Accessibility:</b> {safe_fmt(rec.get('accessibility'), '.2f')}<br>"
        f"<b>Install Cost:</b> {safe_fmt(rec.get('installation_cost'), '.2f')}<br>"
        f"<b>Terrain Suitability:</b> {safe_fmt(rec.get('terrain_suitability'), '.2f')}<br>"
        f"<b>Valley Coverage:</b> {safe_fmt(rec.get('valley_coverage'), '.2f')}"
    )
    marker = CircleMarker(
        location=(rec['lat'], rec['lon']),
        radius=4,
        color=color,
        fill_color=color,
        fill_opacity=0.7
    )
    marker.popup = HTML(popup_html)
    m.add_layer(marker)

# --- Histogram for Chosen Objective ---
out = Output()
with out:
    fig, ax = plt.subplots(figsize=(6, 3))
    if finite_scores.size:
        n_vals, bins, patches = ax.hist(finite_scores, bins=30, edgecolor='black')
        # Colour each bar by its position on the same scale used for the markers.
        fracs = (bins[:-1] - bins.min()) / (bins.max() - bins.min() + 1e-8)
        for frac, patch in zip(fracs, patches):
            patch.set_facecolor(cmap(frac))
    else:
        ax.text(0.5, 0.5, 'No finite values', ha='center', va='center', transform=ax.transAxes)
    ax.set_title(f'Distribution of {objective_for_color.capitalize()}')
    ax.set_xlabel(objective_for_color.capitalize())
    ax.set_ylabel("Count")
    plt.tight_layout()
    plt.show()

# --- Show Map + Histogram Together ---
VBox([m, out])

In [None]:
# ===========================================================
# 6. ===== PHASE 1: GATEWAY MULTI-OBJECTIVE OPTIMIZATION & UI =====
# With precomputed coverage masks (one-time) + fast NSGA-III evaluation
# + Dropdown to browse and visualize ALL feasible Pareto solutions
# ===========================================================
from IPython.display import display
from ipywidgets import (
    FloatSlider, IntSlider, Dropdown, SelectMultiple, Button, Output,
    HBox, VBox, HTML, Layout
)
import numpy as np
import pandas as pd
from geopy.distance import geodesic
from itertools import combinations
from joblib import Parallel, delayed
from tqdm.notebook import tqdm

# For the map
from ipyleaflet import Map, CircleMarker, GeoJSON, basemaps, LayerGroup

# ==============================
# Globals / shared precomputes
# ==============================
# Module-level state for Phase 1. The three coverage artifacts are reset on every run of
# this cell and filled lazily by ensure_precomputed_masks(), which first tries the
# PRECOMP_* arrays written by the candidate-precomputation cell.
# Expects from earlier cells: aoi_poly, dem_manager, gateway_candidate_df, NETWORK, OPTIMIZATION
#   COVERAGE_GRID  : list[(lon, lat)]
#   COVERAGE_MASKS : np.ndarray [N_candidates, N_grid] bool
#   PAIRWISE_DIST  : np.ndarray [N_candidates, N_candidates] float
COVERAGE_GRID = COVERAGE_MASKS = PAIRWISE_DIST = None

# Adopted solution for downstream steps (single source of truth); reset on re-run.
CURRENT_SOLUTION = CURRENT_SOLUTION_INDEXES = None

# Saved solutions survive re-running this cell: existing objects are reused when present.
saved_solutions = globals().get('saved_solutions', {})
solution_counter = globals().get('solution_counter', 0)

# Entire feasible set from the most recent NSGA-III run, kept for browsing (all slots start None):
#   idxs_list       - list of arrays of candidate indices
#   F               - objective matrix for feasible solutions
#   rank_order      - browse order (best -> worst by distance to ideal)
#   objective_names - tuple/list of objective names used
#   num_gateways    - k
#   min_sep         - minimum separation
#   selected_objs   - same as objective_names
_last_nsga_store = dict.fromkeys(
    ("idxs_list", "F", "rank_order", "objective_names", "num_gateways", "min_sep", "selected_objs")
)

# Helper to enforce single source of truth in downstream cells
def get_current_solution_or_raise():
    """Return the adopted CURRENT_SOLUTION; raise RuntimeError if none has been adopted yet."""
    solution = CURRENT_SOLUTION
    if solution is not None:
        return solution
    raise RuntimeError("No CURRENT_SOLUTION set. Adopt a solution above (MIP or NSGA) first.")

# ==============================
# Utilities
# ==============================
def safe_fmt(val, fmt):
    """
    Format `val` with the format spec `fmt`, or return "NA" when it cannot be shown.

    The value is converted to float first, so NaN is detected for any numeric type.
    That includes numpy.float32, which is not a Python float subclass and previously
    slipped through as "nan". None, pd.NA, non-numeric values and invalid specs all give "NA".
    """
    try:
        num = float(val)
        if np.isnan(num):
            return "NA"
        return format(num, fmt)
    except Exception:
        return "NA"

def normalize_objective(df, col):
    """
    Min-max scale column `col` of `df` to [0, 1] as a float numpy array.

    The bounds come from the finite values only. Previously a single NaN made np.min and
    np.max NaN and the whole column collapsed to zeros. Non-finite entries (NaN/None/inf)
    map to 0.0 so they never produce NaN coefficients in the MIP objective. A constant or
    all-missing column returns all zeros.
    """
    vals = df[col].to_numpy(dtype=float)
    out = np.zeros_like(vals)
    finite = np.isfinite(vals)
    if not finite.any():
        return out
    lo, hi = vals[finite].min(), vals[finite].max()
    if hi > lo:
        out[finite] = (vals[finite] - lo) / (hi - lo)
    return out

def get_objective_presets():
    """
    Return a fresh dict of named weight presets: {preset_name: {objective: weight}}.

    Every preset assigns a weight to all eight objectives, and each preset's weights sum to 1.0.
    Preset order is preserved, because the preset dropdown lists the presets in this order.
    """
    objectives = (
        'coverage', 'solar', 'accessibility', 'installation_cost',
        'terrain_suitability', 'valley_coverage', 'elevation', 'centrality',
    )
    #                        cov   solar access cost  terr  valley elev  cent
    table = {
        'balanced':              (0.22, 0.13, 0.13, 0.10, 0.10, 0.10, 0.11, 0.11),
        'coverage_focused':      (0.48, 0.13, 0.07, 0.05, 0.05, 0.0,  0.11, 0.11),
        'accessibility_focused': (0.11, 0.07, 0.44, 0.18, 0.10, 0.0,  0.05, 0.05),
        'valley_monitoring':     (0.18, 0.10, 0.10, 0.09, 0.10, 0.22, 0.11, 0.10),
        'cost_optimized':        (0.22, 0.07, 0.14, 0.32, 0.10, 0.0,  0.07, 0.08),
    }
    return {name: dict(zip(objectives, weights)) for name, weights in table.items()}

# Fallback for sort_gateways if not defined elsewhere
if 'sort_gateways' not in globals():
    def sort_gateways(solution, sort_by="lon"):
        """
        Return the gateways sorted by longitude ("lon") or latitude ("lat") of gw['coord'].

        The input is returned unchanged (same object) in these cases:
        - it is empty or not a list;
        - its first element has no 'coord';
        - sort_by is any other value.
        """
        if not (solution and isinstance(solution, list) and 'coord' in solution[0]):
            return solution
        axis = 0 if sort_by == "lon" else 1 if sort_by == "lat" else None
        if axis is None:
            return solution
        return sorted(solution, key=lambda gw: gw['coord'][axis])

# ========= MIP Solver =========
def solve_gateway_placement_mip(candidate_df, num_gateways, weights, min_separation_m=800.0, solver=None):
    """
    Pick exactly `num_gateways` candidates that maximise a weighted sum of normalised objectives.

    Each column named in `weights` with weight > 0 is min-max normalised with
    normalize_objective and added into a per-candidate score. A binary ILP (PuLP) selects
    exactly `num_gateways` rows. Any pair closer than `min_separation_m` (geodesic, metres)
    cannot both be chosen.

    Parameters
    ----------
    candidate_df : pd.DataFrame with a 'coord' column of (lon, lat) plus objective columns.
    num_gateways : int, number of gateways to select.
    weights : dict {objective_column: weight}.
    min_separation_m : float, minimum pairwise distance; <= 0 disables the constraint.
    solver : PuLP solver; defaults to a silent CBC.

    Returns
    -------
    None if candidate_df is None or empty. Otherwise a dict:
      - on an optimal result: 'solution' (list of row dicts), 'status', 'objective_totals',
        'num_gateways', 'weights', 'solver_used';
      - on any other result: 'solution' None, 'status', 'error'.
    """
    import pulp
    from itertools import combinations
    from geopy.distance import geodesic

    if solver is None:
        solver = pulp.PULP_CBC_CMD(msg=0)
    if candidate_df is None or candidate_df.empty:
        return None
    df = candidate_df.copy()
    indices = df.index.tolist()
    all_objectives = ['coverage', 'solar', 'accessibility', 'installation_cost', 'terrain_suitability', 'valley_coverage', 'elevation', 'centrality']
    used_objectives = [obj for obj in all_objectives if obj in df.columns]

    # Weighted composite score
    df['weighted_score'] = 0.0
    for obj, weight in weights.items():
        if obj in df.columns and weight > 0:
            df['weighted_score'] += weight * normalize_objective(df, obj)

    prob = pulp.LpProblem("GatewayPlacement", pulp.LpMaximize)
    x = pulp.LpVariable.dicts("gateway", indices, cat='Binary')
    prob += pulp.lpSum([df.loc[i, 'weighted_score'] * x[i] for i in indices]), "Total_Weighted_Score"
    prob += pulp.lpSum([x[i] for i in indices]) == num_gateways, "Select_N_Gateways"

    if num_gateways > 1 and min_separation_m > 0:
        coords = df['coord'].to_dict()  # hoisted out of the O(n^2) pair loop
        for i, j in combinations(indices, 2):
            if geodesic(coords[i][::-1], coords[j][::-1]).meters < min_separation_m:
                prob += x[i] + x[j] <= 1, f"Separation_{i}_{j}"

    prob.solve(solver)
    if prob.status == pulp.LpStatusOptimal:
        # CBC returns binaries as floats that can be off by ~1e-9 (e.g. 0.9999999).
        # An exact `== 1` check can silently drop chosen gateways, so threshold at 0.5.
        # varValue is None for a variable the solver never set.
        selected_indices = [i for i in indices if (x[i].varValue or 0.0) > 0.5]
        solution_gateways = df.loc[selected_indices].to_dict('records')
        objective_totals = {obj: sum(gw.get(obj, 0) for gw in solution_gateways) for obj in used_objectives}
        return {
            'solution': solution_gateways,
            'status': 'optimal',
            'objective_totals': objective_totals,
            'num_gateways': num_gateways,
            'weights': weights.copy(),
            'solver_used': str(solver)
        }
    else:
        return {
            'solution': None,
            'status': pulp.LpStatus[prob.status],
            'error': f"Optimization failed with status: {pulp.LpStatus[prob.status]}"
        }

def get_solver():
    """Return a silent (msg=0) PuLP CBC solver, the same default solve_gateway_placement_mip uses."""
    from pulp import PULP_CBC_CMD
    return PULP_CBC_CMD(msg=0)

# ==============================
# Fast coverage + distance helpers
# ==============================
def build_pairwise_dist_matrix(candidate_df: pd.DataFrame) -> np.ndarray:
    """
    Symmetric matrix of geodesic distances (metres) between all candidate coords.

    candidate_df['coord'] holds (lon, lat) pairs. They are flipped to (lat, lon) for geopy.
    The diagonal is 0, and each unordered pair is computed once and mirrored.
    """
    latlon = [(lat, lon) for lon, lat in candidate_df['coord'].tolist()]
    size = len(latlon)
    dist = np.zeros((size, size), dtype=float)
    for a, b in combinations(range(size), 2):
        dist[a, b] = dist[b, a] = geodesic(latlon[a], latlon[b]).meters
    return dist

def coverage_mask_for_candidate(coord, coverage_grid, dem_manager,
    tx_h, rx_h, freq_mhz, veg_threshold, attn_per_meter, max_comm_range, max_path_loss_db,
    grid_idxs=None):
    """
    Boolean coverage row for one gateway candidate.

    mask[j] is True when both conditions hold:
    - grid point j is within `max_comm_range` metres (geodesic) of `coord`;
    - calculate_path_loss_pycraf returns a finite loss <= `max_path_loss_db`.

    Only indices in `grid_idxs` (default: all) are evaluated; the rest stay False.
    Coordinates are (lon, lat) and are reversed to (lat, lon) for geopy.

    A point whose evaluation raises is treated as not covered, so the build is best effort.
    Failures are counted and reported once per call. Without that, a systematic problem
    (an undefined helper, a broken DEM lookup) produced an all-False mask with no warning.
    """
    n = len(coverage_grid)
    mask = np.zeros(n, dtype=bool)
    idxs = range(n) if grid_idxs is None else grid_idxs
    n_checked = 0
    n_failed = 0
    first_error = None
    for j in idxs:
        p = coverage_grid[j]
        n_checked += 1
        try:
            if geodesic(coord[::-1], p[::-1]).meters > max_comm_range:
                continue
            pl = calculate_path_loss_pycraf(
                coord, p, dem_manager,
                tx_h, rx_h, freq_mhz, veg_threshold, attn_per_meter, max_comm_range
            )
            if np.isfinite(pl) and pl <= max_path_loss_db:
                mask[j] = True
        except Exception as exc:
            n_failed += 1
            if first_error is None:
                first_error = exc
    if n_failed:
        print(f"⚠️ coverage_mask_for_candidate: {n_failed}/{n_checked} link evaluations failed "
              f"for {coord} (first error: {first_error!r}); treated as not covered.")
    return mask

def _precomputed_arrays_consistent(n_candidates):
    """True when COVERAGE_GRID/MASKS/PAIRWISE_DIST all exist and their shapes fit n_candidates."""
    if COVERAGE_GRID is None or COVERAGE_MASKS is None or PAIRWISE_DIST is None:
        return False
    masks = np.asarray(COVERAGE_MASKS)
    dist = np.asarray(PAIRWISE_DIST)
    return (masks.ndim == 2
            and masks.shape == (n_candidates, len(COVERAGE_GRID))
            and dist.shape == (n_candidates, n_candidates))


def ensure_precomputed_masks(
    candidate_cap=None,
    grid_every=1,
    n_jobs=-1,
    backend='threading',
    force=False
):
    """
    Make COVERAGE_GRID, COVERAGE_MASKS and PAIRWISE_DIST available for Phase 1.

    Reuse order (unless `force`):
      1. arrays already held in this cell's globals (from an earlier call);
      2. the PRECOMP_* arrays written by precompute_gateway_objectives_enhanced.
    Reuse only happens when the shapes agree with gateway_candidate_df and with each
    other; missing or mismatched arrays trigger a rebuild with a progress bar.

    Parameters
    ----------
    candidate_cap : int or None
        Build masks for only the first N candidates. The remaining rows are all-False.
    grid_every : int
        Keep every k-th coverage-grid point. COVERAGE_GRID is subsampled as well, so mask
        columns always line up with COVERAGE_GRID and coverage fractions are not diluted
        by points that were never evaluated.
    n_jobs, backend :
        Passed to joblib.Parallel.
    force : bool
        Ignore cached/PRECOMP arrays and rebuild. This is needed to change `grid_every`
        or `candidate_cap` after a first build, e.g. a quick subsampled run followed by
        `force=True, grid_every=1` for full fidelity.

    Raises
    ------
    RuntimeError
        If gateway_candidate_df or the coverage grid is empty.
    """
    global COVERAGE_GRID, COVERAGE_MASKS, PAIRWISE_DIST

    n_candidates = len(gateway_candidate_df)
    if n_candidates == 0:
        raise RuntimeError("gateway_candidate_df is empty - run the candidate precomputation first.")

    if force:
        COVERAGE_GRID = COVERAGE_MASKS = PAIRWISE_DIST = None
    else:
        # Fill missing pieces from the Cell 4 artifacts (globals().get -> None when absent).
        if COVERAGE_GRID is None:
            COVERAGE_GRID = globals().get('PRECOMP_COVERAGE_GRID')
        if COVERAGE_MASKS is None:
            COVERAGE_MASKS = globals().get('PRECOMP_COVERAGE_MASKS')
        if PAIRWISE_DIST is None:
            PAIRWISE_DIST = globals().get('PRECOMP_PAIRWISE_DIST')

        if _precomputed_arrays_consistent(n_candidates):
            print(" Using precomputed masks/grid/dist (from Cell 4 or an earlier build)")
            return
        if any(a is not None for a in (COVERAGE_GRID, COVERAGE_MASKS, PAIRWISE_DIST)):
            print(" Precomputed arrays are incomplete or do not match gateway_candidate_df; rebuilding.")

    print(" Building precomputed coverage masks & distances (one-time)...")

    # 1) Grid (optionally subsampled; grid and mask columns stay aligned)
    full_grid = generate_coverage_grid(aoi_poly, OPTIMIZATION.COVERAGE_GRID_SIZE)
    n_grid = len(full_grid)
    if n_grid == 0:
        raise RuntimeError("Coverage grid is empty.")

    step = max(1, int(grid_every))
    if step > 1:
        COVERAGE_GRID = full_grid[::step]
        print(f"Using grid subsample: every {step}th point → {len(COVERAGE_GRID)}/{n_grid}")
    else:
        COVERAGE_GRID = full_grid
        print(f"Using full grid: {n_grid} points")
    grid = COVERAGE_GRID

    # 2) Candidates (optional cap while testing)
    base_coords = gateway_candidate_df['coord'].tolist()
    if candidate_cap is not None and candidate_cap < len(base_coords):
        print(f"🎯 Candidate cap active: {candidate_cap}/{len(base_coords)} (for faster first build)")
        coords = base_coords[:candidate_cap]
    else:
        coords = base_coords

    # 3) Parallel mask build with tqdm
    print(f" Building masks for {len(coords)} candidates …")
    def _job(coord):
        # Link constants must match the precomputation call (915 MHz, 2.0, 0.0).
        return coverage_mask_for_candidate(
            coord, grid, dem_manager,
            NETWORK.GATEWAY_HEIGHT, NETWORK.SENSOR_HEIGHT, 915.0, 2.0, 0.0,
            NETWORK.MAX_COMM_RANGE_M, OPTIMIZATION.MAX_ALLOWABLE_PATH_LOSS_DB,
            grid_idxs=None
        )

    rows = Parallel(n_jobs=n_jobs, backend=backend)(
        delayed(_job)(coord) for coord in tqdm(coords, desc="Mask jobs", unit="cand")
    )
    # An empty job list (candidate_cap=0) would make np.vstack raise.
    masks = np.vstack(rows).astype(bool) if rows else np.zeros((0, len(grid)), dtype=bool)

    # 4) Candidate padding if capped
    if len(coords) < len(base_coords):
        pad = np.zeros((len(base_coords) - len(coords), masks.shape[1]), dtype=bool)
        COVERAGE_MASKS = np.vstack([masks, pad])
        print(" Remaining (uncapped) candidates default to no coverage until you rebuild with full set.")
    else:
        COVERAGE_MASKS = masks

    # 5) Pairwise distances
    PAIRWISE_DIST = build_pairwise_dist_matrix(gateway_candidate_df)

    print(f"Masks ready: candidates={COVERAGE_MASKS.shape[0]}, grid_cols={COVERAGE_MASKS.shape[1]}")
    if step > 1:
        print(" You are using a subsampled grid for speed. For final results, call again with force=True, grid_every=1.")

# ==============================
# UI Controls
# ==============================
# (label, column) pairs for the NSGA-III objective picker. Labels are derived from the column
# names ('installation_cost' -> 'Installation Cost'), the same convention the preset dropdown uses.
nsga_objective_options = [
    (key.replace('_', ' ').title(), key)
    for key in ('coverage', 'solar', 'accessibility', 'elevation',
                'installation_cost', 'terrain_suitability', 'valley_coverage', 'centrality')
]
nsga_obj_select = SelectMultiple(
    options=list(nsga_objective_options),
    value=('coverage', 'solar', 'accessibility'),
    description='NSGA-III Objectives',
    disabled=False,
    layout=Layout(width='300px', height='110px'),
)

style = {'description_width': 'initial'}

# Preset picker: "Custom" followed by every named preset ('cost_optimized' -> 'Cost Optimized').
preset_dropdown = Dropdown(
    options=[('Custom', 'custom')] + [(name.replace('_', ' ').title(), name) for name in get_objective_presets()],
    value='balanced',
    description='Preset:',
    style=style,
)

# Weight sliders start at the 'balanced' preset, matching the dropdown's initial value.
_balanced_start = get_objective_presets()['balanced']

def _weight_slider(description, objective):
    """0..1 weight slider (step 0.05) initialised from the 'balanced' preset."""
    return FloatSlider(value=_balanced_start[objective], min=0, max=1, step=0.05,
                       description=description, style=style)

w_coverage = _weight_slider('Coverage:', 'coverage')
w_solar = _weight_slider('Solar:', 'solar')
w_accessibility = _weight_slider('Accessibility:', 'accessibility')
w_install_cost = _weight_slider('Install Cost:', 'installation_cost')
w_terrain = _weight_slider('Terrain Suitability:', 'terrain_suitability')
w_valley = _weight_slider('Valley Coverage:', 'valley_coverage')
w_elevation = _weight_slider('Elevation:', 'elevation')
w_centrality = _weight_slider('Centrality:', 'centrality')
del _weight_slider, _balanced_start  # construction helpers only; keep the namespace unchanged

min_separation = IntSlider(
    value=300, min=50, max=1000, step=50,
    description='Min Separation (m):', style=style,
)
num_gateways_slider = IntSlider(
    value=2, min=1, max=5, step=1,
    description='Gateways:', style=style,
)
optimization_method = Dropdown(
    options=[('Weighted (MIP)', 'mip'), ('Pareto (NSGA-III)', 'nsga3')],
    value='mip',
    description='Method:',
    style=style,
)

def get_current_weights():
    """Snapshot the eight weight sliders as ``{objective_key: weight}``.

    The keys are the same objective names that ``apply_preset`` reads from a preset.
    """
    slider_by_key = (
        ('coverage', w_coverage),
        ('solar', w_solar),
        ('accessibility', w_accessibility),
        ('installation_cost', w_install_cost),
        ('terrain_suitability', w_terrain),
        ('valley_coverage', w_valley),
        ('elevation', w_elevation),
        ('centrality', w_centrality),
    )
    return {key: slider.value for key, slider in slider_by_key}

def apply_preset(preset_name):
    """Copy the weights of a named preset onto the eight sliders.

    Names that are not in ``get_objective_presets()`` (presumably the 'custom'
    entry) leave the sliders untouched. Objectives missing from the chosen preset
    are set to 0.0. Sliders are updated in a fixed order.
    """
    presets = get_objective_presets()
    if preset_name not in presets:
        return
    chosen = presets[preset_name]
    for key, slider in (
        ('coverage', w_coverage),
        ('solar', w_solar),
        ('accessibility', w_accessibility),
        ('installation_cost', w_install_cost),
        ('terrain_suitability', w_terrain),
        ('valley_coverage', w_valley),
        ('elevation', w_elevation),
        ('centrality', w_centrality),
    ):
        slider.value = chosen.get(key, 0.0)


def _on_preset_change(change):
    """Preset dropdown observer: apply the newly selected preset on value changes."""
    if change['name'] == 'value':
        apply_preset(change['new'])


preset_dropdown.observe(_on_preset_change)

# ==============================
# Map viz helpers
# ==============================
def plot_gateways_on_map(aoi_poly, gateway_solution, label_prefix="Gateway", color='red'):
    """
    Build an ipyleaflet map previewing a set of gateways over the AOI outline.

    gateway_solution: list of dicts, each with at least {'coord': (lon, lat)}.
    Gateways are numbered after sorting by longitude. Each known metric key present
    in a gateway dict is appended to its popup: int/float values with two decimals,
    anything else verbatim. Returns the Map widget without displaying it.
    """
    centroid = aoi_poly.centroid
    gw_map = Map(center=(centroid.y, centroid.x), zoom=13, basemap=basemaps.Esri.WorldImagery)
    aoi_style = {'color': 'yellow', 'fillOpacity': 0.1, 'weight': 3}
    gw_map.add_layer(GeoJSON(data=aoi_poly.__geo_interface__, style=aoi_style))

    metric_keys = ('coverage', 'elevation', 'solar', 'accessibility', 'installation_cost',
                   'terrain_suitability', 'valley_coverage', 'centrality')
    for number, gw in enumerate(sort_gateways(gateway_solution, sort_by="lon"), start=1):
        lon, lat = gw['coord'][0], gw['coord'][1]
        lines = [f"<b>{label_prefix} {number}</b>", f"Lon: {lon:.5f}", f"Lat: {lat:.5f}"]
        for key in metric_keys:
            if key not in gw:
                continue
            val = gw[key]
            pretty = key.replace('_', ' ').title()
            if isinstance(val, (int, float)):
                lines.append(f"{pretty}: {val:.2f}")
            else:
                lines.append(f"{pretty}: {val}")
        gw_map.add_layer(CircleMarker(
            location=(lat, lon),
            radius=8,
            color=color,
            fill_color=color,
            fill_opacity=0.8,
            popup=HTML("<br>".join(lines))
        ))
    return gw_map

def _solution_records_from_idxs(idxs):
    """Build gateway record dicts from candidate_df rows + attach objective columns if present."""
    rows = gateway_candidate_df.iloc[list(idxs)]
    return rows.to_dict('records')

# ==============================
# Main Optimization Handler + Feasible Browser
# ==============================
output = Output()   # run logs
map_out = Output()  # single area where preview maps are drawn

# Browser over the NSGA-III feasible set: value = position in the feasible arrays, -1 = placeholder.
feasible_dropdown = Dropdown(
    options=[("— run NSGA-III to populate —", -1)],
    value=-1,
    description='Feasible:',
    layout=Layout(width='400px'),
)

# Adopt buttons: one for the dropdown selection, one for the latest MIP result.
# The MIP button starts disabled; the run handler enables it after a successful MIP solve.
adopt_btn = Button(description="Adopt Selected Feasible", button_style='info')
adopt_mip_btn = Button(
    description="Adopt MIP Result",
    button_style='info',
    disabled=True,
    layout=Layout(width='160px'),
)
_last_mip_records = None  # records of the most recent successful MIP run (None until then)

adopt_msg = Output()  # feedback printed by the adopt handlers

def _adopt_clicked(_):
    """Adopt the NSGA-III feasible solution currently selected in the dropdown.

    Sets the globals CURRENT_SOLUTION (candidate rows as record dicts) and
    CURRENT_SOLUTION_INDEXES, and records the choice in ``saved_solutions`` under a
    rank-tagged name. All messages are printed into ``adopt_msg``.

    The stored decision vector may be float-valued (the evaluator also converts it
    with int()), so it is cast to plain ints once. The same integer list is used for
    the records and for CURRENT_SOLUTION_INDEXES, so downstream code can use it to
    index numpy arrays such as COVERAGE_MASKS.

    Args:
        _: the clicked Button (unused; required by ``Button.on_click``).
    """
    global CURRENT_SOLUTION, CURRENT_SOLUTION_INDEXES
    adopt_msg.clear_output()
    with adopt_msg:
        store = _last_nsga_store
        if store["idxs_list"] is None or feasible_dropdown.value == -1:
            print("No feasible set selected.")
            return
        idx_in_feasible = int(feasible_dropdown.value)
        idxs = [int(i) for i in store["idxs_list"][idx_in_feasible]]
        records = gateway_candidate_df.iloc[idxs].to_dict('records')

        # Stash globally for downstream cells
        CURRENT_SOLUTION, CURRENT_SOLUTION_INDEXES = records, idxs

        # 1-based position of this solution in the closest-to-ideal ranking
        rank_1 = int(np.where(store["rank_order"] == idx_in_feasible)[0][0]) + 1
        name = f"Chosen_FeasibleRank_{rank_1}_{store['num_gateways']}GW_NSGA3"
        saved_solutions[name] = {"solution": records, "meta": {
            "rank": rank_1,
            "objectives": store["objective_names"],
            "k": store["num_gateways"],
            "min_sep": store["min_sep"],
            "method": "NSGA-III"
        }}
        print(f"Adopted: {name} → CURRENT_SOLUTION set ({len(records)} gateways).")

adopt_btn.on_click(_adopt_clicked)

def _adopt_mip_clicked(_):
    """Adopt the most recent MIP result as CURRENT_SOLUTION.

    CURRENT_SOLUTION_INDEXES is filled from each record's 'candidate_index' key
    when present, otherwise None per gateway. The choice is also recorded in
    ``saved_solutions``; messages go to ``adopt_msg``.

    Args:
        _: the clicked Button (unused; required by ``Button.on_click``).
    """
    global CURRENT_SOLUTION, CURRENT_SOLUTION_INDEXES, _last_mip_records
    adopt_msg.clear_output()
    with adopt_msg:
        if _last_mip_records is None:
            print("No MIP result to adopt.")
            return
        CURRENT_SOLUTION = _last_mip_records
        # Best-effort indexes: records need not carry 'candidate_index'
        CURRENT_SOLUTION_INDEXES = [rec.get('candidate_index', None) for rec in CURRENT_SOLUTION]
        n_gw = len(CURRENT_SOLUTION)
        name = f"Chosen_MIP_{n_gw}GW"
        saved_solutions[name] = {"solution": CURRENT_SOLUTION, "meta": {"method": "MIP", "k": n_gw}}
        print(f"Adopted: {name} → CURRENT_SOLUTION set ({n_gw} gateways).")

adopt_mip_btn.on_click(_adopt_mip_clicked)

run_btn = Button(description='Run Optimization', button_style='success')

def _rank_feasible(feasible_F):
    """Return indices 0..(m-1) sorted by distance to ideal (best -> worst)."""
    ideal = feasible_F.min(axis=0)
    nadir = feasible_F.max(axis=0)
    norm = (feasible_F - ideal) / (nadir - ideal + 1e-12)
    d = np.linalg.norm(norm, axis=1)
    return np.argsort(d), d

def _format_dropdown_label(rank, F_row, obj_names):
    parts = [f"{i+1}:{n}={(-F_row[i] if n!='installation_cost' else F_row[i]):.2f}" if n!='coverage' else f"{i+1}:{n}={-F_row[i]:.2f}" for i,n in enumerate(obj_names)]
    return f"{rank:>3d} | " + " | ".join(parts)

def _refresh_feasible_dropdown():
    """Rebuild the feasible-solution dropdown from ``_last_nsga_store``.

    Options are listed best-first, following the stored rank order. Each option is
    labelled with its rank and objective values, and its value is the position in
    the feasible arrays (a plain int). The best entry is selected, which notifies the
    dropdown's value observer when the selection changes.

    The placeholder option is restored when there are no NSGA results, or when the
    rank order is empty. Previously an empty rank order produced ``options=[]``
    followed by an invalid ``value=-1``.
    """
    placeholder = [("— run NSGA-III to populate —", -1)]
    store = _last_nsga_store
    order = store["rank_order"] if store["idxs_list"] is not None else None
    if order is None or len(order) == 0:
        feasible_dropdown.options = placeholder
        feasible_dropdown.value = -1
        return
    F = store["F"]
    names = store["objective_names"]
    feasible_dropdown.options = [
        (_format_dropdown_label(rank, F[pos], names), int(pos))
        for rank, pos in enumerate(order, start=1)
    ]
    feasible_dropdown.value = int(order[0])

def _draw_solution_by_feasible_index(idx_in_feasible):
    """Preview one NSGA-III feasible solution in ``map_out`` without adopting it.

    Args:
        idx_in_feasible: position in the feasible arrays (the dropdown value);
            -1 or None means "nothing selected".

    ``map_out`` is cleared only when a new preview is actually drawn. Resetting the
    dropdown to its placeholder (as the MIP path does) therefore no longer wipes a
    MIP preview that is already shown there.
    """
    store = _last_nsga_store
    if store["idxs_list"] is None or idx_in_feasible is None or idx_in_feasible == -1:
        return
    pos = int(idx_in_feasible)
    # decision vectors may be float-valued; positional lookup needs ints
    idxs = [int(i) for i in store["idxs_list"][pos]]
    records = gateway_candidate_df.iloc[idxs].to_dict('records')
    rank = store['rank_order'].tolist().index(pos) + 1
    map_out.clear_output()
    with map_out:
        display(HTML(f"<b>Preview (NSGA feasible #{rank}) — not adopted</b>"))
        display(plot_gateways_on_map(aoi_poly, records, color="lime"))

def _on_dropdown_change(change):
    """Feasible-dropdown observer: redraw the preview for the newly selected value."""
    if change['name'] != 'value':
        return
    _draw_solution_by_feasible_index(change['new'])

feasible_dropdown.observe(_on_dropdown_change, names='value')

def on_run_clicked(b):
    """Run the optimizer chosen in the UI and publish the results as previews.

    Reads the method, weights, gateway count and minimum separation from the widgets.

    - MIP: calls ``solve_gateway_placement_mip`` and stores the result in
      ``saved_solutions``. It clears the NSGA browser, shows an orange preview map
      and enables "Adopt MIP Result".
    - NSGA-III: optimizes the objectives selected in ``nsga_obj_select`` under
      pairwise min-separation constraints. The feasible set goes into
      ``_last_nsga_store`` for the dropdown browser, and the closest-to-ideal
      solution is saved in ``saved_solutions``. ``CURRENT_SOLUTION`` is reset until
      the user adopts a solution.

    Nothing is adopted automatically.

    Args:
        b: the clicked Button (unused; required by ``Button.on_click``).
    """
    global saved_solutions, solution_counter, _last_nsga_store, _last_mip_records
    global CURRENT_SOLUTION, CURRENT_SOLUTION_INDEXES
    output.clear_output()
    map_out.clear_output()
    adopt_mip_btn.disabled = True
    _last_mip_records = None

    with output:
        method = optimization_method.value
        weights = get_current_weights()
        num_gateways = num_gateways_slider.value
        min_sep = float(min_separation.value)

        if gateway_candidate_df is None or gateway_candidate_df.empty:
            print(" Candidate DataFrame not loaded. Run precomputation first.")
            return

        print(f" Running optimization: method={method}, gateways={num_gateways}, min_separation={int(min_sep)}m")

        if method == 'mip':
            res = solve_gateway_placement_mip(
                gateway_candidate_df, num_gateways, weights, min_separation_m=min_sep, solver=get_solver()
            )
            if not (res and res.get('solution')):
                # The solver may return None; report it instead of crashing on None.get(...)
                print(f"Optimization failed: {res.get('status') if res else None}")
                return

            solution_counter += 1
            solution_name = f"Solution_{solution_counter}_{num_gateways}GW"
            saved_solutions[solution_name] = res
            print(f" Solution saved: {solution_name}")
            dfsol = pd.DataFrame(res['solution'])
            display(dfsol)

            # Clear the NSGA browser BEFORE drawing: resetting the dropdown fires its
            # value observer, which may clear map_out and would erase the MIP preview.
            _last_nsga_store = {k: None for k in _last_nsga_store}
            _refresh_feasible_dropdown()

            # Preview map (NOT adopted)
            map_out.clear_output()
            with map_out:
                display(HTML("<b>Preview (MIP): not adopted</b>"))
                display(plot_gateways_on_map(aoi_poly, res['solution'], color="orange"))

            # Enable adopt
            _last_mip_records = res['solution']
            adopt_mip_btn.disabled = False
            return

        # ===== NSGA-III path (fast with masks) =====
        from pymoo.core.problem import Problem
        from pymoo.algorithms.moo.nsga3 import NSGA3
        from pymoo.util.ref_dirs import get_reference_directions
        from pymoo.optimize import minimize
        from pymoo.termination.default import DefaultMultiObjectiveTermination
        import time

        selected_objs = list(nsga_obj_select.value)
        if len(selected_objs) < 2:
            print(" Please select at least 2 objectives for multi-objective optimization")
            return

        print(f"📊 Selected objectives: {', '.join(selected_objs)}")

        # Ensure masks & distances exist (one-time build)
        ensure_precomputed_masks()

        # Fewer Das-Dennis partitions for more objectives keeps the population size manageable
        if len(selected_objs) == 2:
            n_partitions = 12
        elif len(selected_objs) == 3:
            n_partitions = 8
        else:
            n_partitions = 5

        ref_dirs = get_reference_directions("das-dennis", len(selected_objs), n_partitions=n_partitions)
        print(f"🔧 NSGA-III Configuration:")
        print(f"   Reference directions: {len(ref_dirs)}")
        print(f"   Population size: {len(ref_dirs)}")

        class GatewayPlacementProblem(Problem):
            # Decision vector = n_gateways candidate row positions; all objectives are
            # minimized (maximized ones are negated); G = metres short of min separation.
            def __init__(self, candidate_df, n_gateways, min_separation_m,
                         objectives, coverage_masks, pairwise_dist):
                self.candidate_df = candidate_df.reset_index(drop=True)
                self.n_gateways = n_gateways
                self.min_separation_m = float(min_separation_m)
                self.objectives = objectives
                self.coverage_masks = coverage_masks      # [N, G] bool
                self.pairwise_dist = pairwise_dist        # [N, N] float

                n_constraints = int(n_gateways * (n_gateways - 1) / 2) if n_gateways > 1 else 1
                super().__init__(n_var=n_gateways, n_obj=len(objectives), n_constr=n_constraints,
                                 xl=0, xu=len(self.candidate_df) - 1, elementwise_evaluation=False, vtype=int)

            def compute_union_coverage(self, idxs):
                union = self.coverage_masks[idxs].any(axis=0)
                return float(union.mean() * 100.0)

            def _evaluate(self, X, out, *args, **kwargs):
                n_solutions = X.shape[0]
                F = np.zeros((n_solutions, len(self.objectives)))
                G = []

                for i in range(n_solutions):
                    idxs = [int(k) for k in X[i]]

                    # duplicates → hard penalty
                    if len(set(idxs)) != len(idxs):
                        F[i, :] = 1e10
                        G.append([1e10] * self.n_constr)
                        continue

                    gw = self.candidate_df.iloc[idxs]

                    # objectives
                    for j, obj in enumerate(self.objectives):
                        if obj == 'coverage':
                            cov = self.compute_union_coverage(idxs)
                            F[i, j] = -cov                        # maximize
                        elif obj == 'installation_cost':
                            F[i, j] = float(np.sum(gw[obj])) if obj in gw.columns else 0.0  # minimize
                        elif obj in gw.columns:
                            F[i, j] = -float(np.mean(gw[obj]))     # maximize others
                        else:
                            F[i, j] = 0.0

                    # min-sep constraints via precomputed distances
                    if self.n_gateways > 1:
                        cons = []
                        for a in range(len(idxs)):
                            for b in range(a+1, len(idxs)):
                                d = self.pairwise_dist[idxs[a], idxs[b]]
                                cons.append(max(0.0, self.min_separation_m - d))
                        G.append(cons)
                    else:
                        G.append([0.0])

                out["F"] = F
                out["G"] = np.array(G)

        # Build problem
        problem = GatewayPlacementProblem(
            candidate_df=gateway_candidate_df,
            n_gateways=num_gateways,
            min_separation_m=min_sep,
            objectives=selected_objs,
            coverage_masks=COVERAGE_MASKS,
            pairwise_dist=PAIRWISE_DIST
        )

        # Algorithm/termination (lean while testing; you can scale back up)
        algorithm = NSGA3(
            ref_dirs=ref_dirs,
            pop_size=len(ref_dirs),
            eliminate_duplicates=True
        )
        termination = DefaultMultiObjectiveTermination(
            xtol=1e-8, cvtol=1e-6, ftol=0.005, period=20, n_max_gen=100
        )

        # Run
        print("🏃 Running NSGA-III optimization...")
        start_time = time.time()
        res = minimize(problem, algorithm, termination=termination,
                       seed=42, verbose=True, save_history=True)
        elapsed = time.time() - start_time
        print(f" Optimization completed in {elapsed:.1f} seconds")

        # pymoo's result may carry no optimum at all (res.X is None) when nothing feasible was found
        if res.X is None or len(res.X) == 0:
            print("NSGA-III returned no solutions. Try reducing min separation or adjusting objectives.")
            _last_nsga_store = {k: None for k in _last_nsga_store}
            _refresh_feasible_dropdown()
            return

        # Feasible set by constraint violation (1 m tolerance, then 10% of min_sep)
        if res.G is not None:
            max_violation = np.max(res.G, axis=1)
            feasible_idx = np.where(max_violation <= 1.0)[0]
            if len(feasible_idx) == 0:
                print("No feasible solutions found. Relaxing constraints (10% violation)...")
                feasible_idx = np.where(max_violation <= min_sep * 0.1)[0]
                if len(feasible_idx) == 0:
                    print("Still no feasible solutions. Try reducing min separation or adjusting objectives.")
                    _last_nsga_store = {k: None for k in _last_nsga_store}
                    _refresh_feasible_dropdown()
                    return
        else:
            feasible_idx = np.arange(len(res.X))

        print(f"✅ Found {len(feasible_idx)} feasible solutions")

        feasible_F = res.F[feasible_idx]
        feasible_X = res.X[feasible_idx]

        # Rank feasible solutions by distance to ideal (normalized)
        rank_order, dvals = _rank_feasible(feasible_F)

        # Store for browser (preview via dropdown, no auto-adopt)
        _last_nsga_store = {
            "idxs_list": [feasible_X[i] for i in range(len(feasible_X))],
            "F": feasible_F,
            "rank_order": rank_order,
            "objective_names": tuple(selected_objs),
            "num_gateways": num_gateways,
            "min_sep": min_sep,
            "selected_objs": tuple(selected_objs)
        }
        # Reset CURRENT_* until user adopts
        CURRENT_SOLUTION = None
        CURRENT_SOLUTION_INDEXES = None

        _refresh_feasible_dropdown()

        # Save the "closest to ideal" into saved_solutions (for record), but DO NOT adopt it
        best_idx_in_feasible = rank_order[0] if len(rank_order) > 0 else 0
        best_X = feasible_X[best_idx_in_feasible]
        best_F = feasible_F[best_idx_in_feasible]
        idxs = [int(i) for i in best_X]
        solution_gateways = gateway_candidate_df.iloc[idxs].to_dict('records')
        actual_coverage = float(COVERAGE_MASKS[idxs].any(axis=0).mean() * 100.0)

        solution_summary = {
            'solution': solution_gateways,
            'method': 'NSGA-III',
            'selected_objectives': selected_objs,
            'num_gateways': num_gateways,
            'actual_coverage': actual_coverage,
            'objective_values': {}
        }
        for i, obj in enumerate(selected_objs):
            if obj == 'coverage':
                value = actual_coverage
            elif obj == 'installation_cost':
                value = float(best_F[i])        # minimized and stored as-is
            else:
                value = float(-best_F[i])       # maximized objectives were negated in F
            solution_summary['objective_values'][obj] = value

        solution_counter += 1
        solution_name = f"Solution_{solution_counter}_{num_gateways}GW_NSGA3"
        saved_solutions[solution_name] = solution_summary

        print(f"\nNSGA-III run complete. Browse the feasible dropdown above, then click 'Adopt Selected Feasible' to set CURRENT_SOLUTION.")
        # Only show the summary columns the candidate table actually has
        preview_cols = [c for c in ['coord', 'elevation', 'coverage', 'solar'] if c in gateway_candidate_df.columns]
        display(pd.DataFrame(solution_gateways)[preview_cols])

# Wire up button
run_btn.on_click(on_run_clicked)

# ======== Compose UI =========
n_candidates = 0 if gateway_candidate_df is None else len(gateway_candidate_df)
print("Candidate DataFrame size:", n_candidates)
if gateway_candidate_df is not None and not gateway_candidate_df.empty:
    metric_cols = ['coverage', 'solar', 'elevation', 'centrality', 'accessibility',
                   'installation_cost', 'terrain_suitability', 'valley_coverage']
    cols_to_show = [c for c in metric_cols if c in gateway_candidate_df.columns]
    if cols_to_show:
        print(gateway_candidate_df[cols_to_show].describe())

ui_rows = [
    HBox([preset_dropdown, optimization_method]),
    HBox([w_coverage, w_solar, w_accessibility, w_install_cost]),
    HBox([w_terrain, w_valley, w_elevation, w_centrality]),
    HBox([num_gateways_slider, min_separation, run_btn]),
    HBox([nsga_obj_select]),
    Output(),                                             # empty spacer row
    HBox([feasible_dropdown, adopt_btn, adopt_mip_btn]),  # browse + both adopt buttons
    adopt_msg,
    map_out,                                              # single preview-map area
    output,                                               # run logs
]
ui = VBox(ui_rows)
display(HTML("<h3>Phase 1: Gateway Multi-Objective Optimization</h3>"))
display(ui)

In [None]:
# ============================================================
# Poster tradeoffs (Coverage, Solar, Elevation) — separate figs
# - Uses _last_nsga_store (feasible solutions) + gateway_candidate_df
# - Computes pairwise Pareto fronts (max/max) and draws a connected curve
# - Highlights CURRENT_SOLUTION if adopted; otherwise closest-to-ideal
# - Legends sit inside each axis 
# ============================================================

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Poster-sized typography, applied globally to every matplotlib figure created afterwards.
_POSTER_RC = dict()
_POSTER_RC["font.size"] = 15
_POSTER_RC["axes.titlesize"] = 22
_POSTER_RC["axes.labelsize"] = 19
_POSTER_RC["legend.fontsize"] = 13
_POSTER_RC["figure.figsize"] = (9.5, 7)
_POSTER_RC["axes.titlepad"] = 15
plt.rcParams.update(_POSTER_RC)

def _infer_selected_index_from_current_store(store, current_idxs):
    if not store or store.get("idxs_list") is None or current_idxs is None:
        return None
    target = tuple(sorted(map(int, current_idxs)))
    for i, xs in enumerate(store["idxs_list"]):
        try:
            if tuple(sorted(map(int, xs))) == target:
                return i
        except Exception:
            pass
    return None

def _safe_mkdir(path):
    try:
        if path:
            os.makedirs(path, exist_ok=True)
            return path
    except OSError:
        pass
    return "."

def _safe_savefig(path, dpi=300):
    """Save the current matplotlib figure to ``path`` with a tight bounding box.

    If that raises OSError (e.g. an unwritable directory), the figure is saved
    under the same file name in the working directory instead. An error from that
    retry propagates.
    """
    save_kwargs = dict(dpi=dpi, bbox_inches="tight")
    try:
        plt.savefig(path, **save_kwargs)
        print(f"saved: {path}")
    except OSError:
        fallback = os.path.join(".", os.path.basename(path))
        plt.savefig(fallback, **save_kwargs)
        print(f"saved (fallback): {fallback}")

def _pareto_front_2d(x, y):
    """Indices of 2D Pareto front for maximize/maximize."""
    # Sort by x descending, then keep running max of y
    order = np.argsort(-x, kind="mergesort")
    best_y = -np.inf
    front = []
    for i in order:
        if y[i] >= best_y - 1e-12:
            front.append(i)
            best_y = max(best_y, y[i])
    # sort the front by x ascending for a nice left→right curve
    front = np.array(front, dtype=int)
    return front[np.argsort(x[front])]

def _closest_to_ideal(xcols):
    """Return index of point closest to ideal (all columns max), after [0,1] norm."""
    X = np.column_stack(xcols).astype(float)
    # normalize each column 0..1 (avoid div/0)
    mins = X.min(axis=0)
    maxs = X.max(axis=0)
    rng  = np.maximum(maxs - mins, 1e-12)
    Xn   = (X - mins) / rng
    d    = np.linalg.norm(1.0 - Xn, axis=1)
    return int(np.argmin(d))

def _get_feasible_summaries():
    """
    Summarise every feasible NSGA-III solution in ``_last_nsga_store`` as one row.

    Columns:
      coverage_pct   - union coverage (%) from COVERAGE_MASKS. Without masks it is
                       recovered from the stored objective matrix F when 'coverage'
                       was an objective; otherwise it is None.
      solar_mean, elevation_mean - means over the solution's gateways (NaN if the
                       candidate table lacks the column)
      solar_mean_norm, elevation_mean_norm - min-max normalised versions (0..1)

    Raises RuntimeError if Phase 1 has not produced NSGA results or the candidate
    table is missing.
    """
    if "_last_nsga_store" not in globals() or _last_nsga_store.get("idxs_list") is None:
        raise RuntimeError("No NSGA results in _last_nsga_store. Run Phase 1 first.")
    if "gateway_candidate_df" not in globals():
        raise RuntimeError("gateway_candidate_df is missing.")

    store = _last_nsga_store
    cand = gateway_candidate_df
    masks = globals().get("COVERAGE_MASKS", None)
    use_masks = isinstance(masks, np.ndarray)

    def _coverage(row_pos, members):
        if use_masks:
            return float(masks[members].any(axis=0).mean() * 100.0)
        names = store.get("objective_names", [])
        if names and "coverage" in names:
            col = list(names).index("coverage")
            return -float(store["F"][row_pos, col])  # coverage was stored negated (maximized)
        return None

    def _column_mean(members, column):
        if column not in cand.columns:
            return np.nan
        return float(cand.iloc[members][column].mean())

    records = []
    for row_pos, raw in enumerate(store["idxs_list"]):
        members = list(map(int, raw))
        records.append({
            "coverage_pct": _coverage(row_pos, members),
            "solar_mean": _column_mean(members, "solar"),
            "elevation_mean": _column_mean(members, "elevation"),
        })

    summary = pd.DataFrame(records)
    # normalise solar & elevation so the poster axes read 0..1
    for col in ("solar_mean", "elevation_mean"):
        vals = summary[col].to_numpy(dtype=float)
        lo, hi = np.nanmin(vals), np.nanmax(vals)
        summary[f"{col}_norm"] = (vals - lo) / max(hi - lo, 1e-12)
    return summary

def _selected_feasible_index(df):
    """Row of ``df`` (position in the feasible set) to highlight with the star.

    The adopted solution is preferred when CURRENT_SOLUTION_INDEXES matches a
    feasible entry. Otherwise the row closest to the ideal over coverage and the
    normalised solar/elevation means is used.
    """
    adopted = _infer_selected_index_from_current_store(
        _last_nsga_store, globals().get("CURRENT_SOLUTION_INDEXES", None))
    if adopted is None:
        metric_cols = ["coverage_pct", "solar_mean_norm", "elevation_mean_norm"]
        return _closest_to_ideal([df[c].to_numpy() for c in metric_cols])
    return int(adopted)

def _plot_pair(fig_title, x, y, xlab, ylab, selected_idx, save_path=None, dpi=300):
    """Draw one poster-ready trade-off chart for two maximised metrics.

    Plots the dominated points and this pair's Pareto front (diamonds joined by a
    curve ordered by x), plus a gold star at ``selected_idx``. Limits are padded by
    5%. If ``save_path`` is given the figure is saved via _safe_savefig; it is then
    shown.
    """
    xs = np.asarray(x, float)
    ys = np.asarray(y, float)

    front = _pareto_front_2d(xs, ys)
    dominated = np.ones_like(xs, dtype=bool)
    dominated[front] = False
    front_x = xs[front]
    front_y = ys[front]
    by_x = np.argsort(front_x)

    fig, ax = plt.subplots(figsize=(9.5, 7))

    ax.scatter(xs[dominated], ys[dominated],
               s=50, c="#3b66b2", edgecolors="#3B75AE", linewidth=0.6, label="Dominated", zorder=1)
    ax.scatter(front_x, front_y,
               s=90, c="#d21f3c", marker="D", edgecolors="#7a0f22", linewidth=1.2, label=f"Pareto ({len(front)})", zorder=3)
    ax.plot(front_x[by_x], front_y[by_x], "-", lw=2.0, color="#8b192e", alpha=0.9, label="Pareto curve", zorder=2)
    ax.scatter(xs[selected_idx], ys[selected_idx],
               s=300, marker="*", c="#ffcc00", edgecolors="#333", linewidth=1.2, label="Selected", zorder=4)

    ax.set_title(fig_title, pad=14, fontweight="bold")
    ax.set_xlabel(xlab, fontweight="bold")
    ax.set_ylabel(ylab, fontweight="bold")
    ax.grid(True, alpha=0.25, ls="--", lw=0.7)
    ax.legend(loc="lower right", framealpha=0.96)  # inside the axes, away from the title

    def _padded(values):
        lo, hi = np.nanmin(values), np.nanmax(values)
        pad = 0.05 * (hi - lo + 1e-9)
        return lo - pad, hi + pad

    ax.set_xlim(*_padded(xs))
    ax.set_ylim(*_padded(ys))

    plt.tight_layout()
    if save_path:
        _safe_savefig(save_path, dpi=dpi)
    plt.show()

def poster_tradeoffs_coverage_solar_elev(save_dir="./figs", dpi=300):
    """Render and save the three pairwise trade-off charts for the NSGA-III feasible set.

    Charts: coverage vs solar, coverage vs elevation and solar vs elevation. Solar and
    elevation are normalised to 0..1. The same highlighted solution is starred in all
    three. PNGs go to ``save_dir``, or to "." if it cannot be created.
    """
    df = _get_feasible_summaries()
    if df.empty:
        print("No feasible solutions to plot.")
        return

    selected_idx = _selected_feasible_index(df)
    print(f"⭐ selected feasible index: {selected_idx}")

    out_dir = _safe_mkdir(save_dir)

    cov_axis = ("coverage_pct", "Coverage (%)")
    sol_axis = ("solar_mean_norm", "Solar (normalized 0–1)")
    ele_axis = ("elevation_mean_norm", "Elevation (normalized 0–1)")
    panels = [
        ("Coverage vs Solar (NSGA-III Feasible Set)", cov_axis, sol_axis,
         "poster_tradeoff_coverage_vs_solar.png"),
        ("Coverage vs Elevation (NSGA-III Feasible Set)", cov_axis, ele_axis,
         "poster_tradeoff_coverage_vs_elevation.png"),
        ("Solar vs Elevation (NSGA-III Feasible Set)", sol_axis, ele_axis,
         "poster_tradeoff_solar_vs_elevation.png"),
    ]
    for title, (xcol, xlab), (ycol, ylab), fname in panels:
        _plot_pair(
            fig_title=title,
            x=df[xcol].to_numpy(),
            y=df[ycol].to_numpy(),
            xlab=xlab,
            ylab=ylab,
            selected_idx=selected_idx,
            save_path=os.path.join(out_dir, fname),
            dpi=dpi
        )

# ---- Run it (creates ./figs and saves PNGs there) ----
poster_tradeoffs_coverage_solar_elev(save_dir="./figs", dpi=300)

In [None]:
from ipyleaflet import Map, CircleMarker, GeoJSON, basemaps, LayerGroup
from ipywidgets import HTML, VBox
import matplotlib.pyplot as plt

def plot_gateways_on_map(aoi_poly, gateway_solution, label_prefix="Gateway", color='red'):
    """
    Visualize selected gateway locations on an ipyleaflet map.

    NOTE(review): this redefines the plot_gateways_on_map from the Phase 1 cell;
    once this cell runs, every later call uses this version.

    Args:
        aoi_poly: Shapely Polygon (AOI boundary)
        gateway_solution: list of dicts, each with at least {'coord': (lon, lat)}
        label_prefix: Marker label prefix (default "Gateway")
        color: Marker color (default 'red')
    Returns:
        ipyleaflet Map widget

    Numeric attributes are formatted with a fixed precision/unit. Non-numeric ones
    (None, strings, ...) are shown verbatim, matching the Phase 1 version.
    Previously these raised TypeError from the ``:.2f`` format.
    """
    import numbers

    # (key, popup label, format spec, unit suffix) for each optional attribute, in popup order
    popup_fields = (
        ('coverage', 'Coverage', '.1f', '%'),
        ('elevation', 'Elevation', '.1f', 'm'),
        ('solar', 'Solar', '.2f', ''),
        ('accessibility', 'Accessibility', '.2f', ''),
        ('installation_cost', 'Install Cost', '.2f', ''),
        ('terrain_suitability', 'Terrain Suitability', '.2f', ''),
        ('valley_coverage', 'Valley Coverage', '.2f', ''),
        ('centrality', 'Centrality', '.2f', ''),
    )

    m = Map(
        center=(aoi_poly.centroid.y, aoi_poly.centroid.x),
        zoom=13,
        basemap=basemaps.Esri.WorldImagery
    )
    m.add_layer(GeoJSON(data=aoi_poly.__geo_interface__, style={
        'color': 'yellow', 'fillOpacity': 0.1, 'weight': 3
    }))

    # Add gateway markers (numbered west → east)
    solution_sorted = sort_gateways(gateway_solution, sort_by="lon")
    for idx, gw in enumerate(solution_sorted):
        lat, lon = gw['coord'][1], gw['coord'][0]
        popup_html = f"<b>{label_prefix} {idx+1}</b><br>Lon: {lon:.5f}<br>Lat: {lat:.5f}"
        for key, label, spec, unit in popup_fields:
            if key not in gw:
                continue
            value = gw[key]
            if isinstance(value, numbers.Real):  # includes numpy scalars
                popup_html += f"<br>{label}: {value:{spec}}{unit}"
            else:
                popup_html += f"<br>{label}: {value}"
        marker = CircleMarker(
            location=(lat, lon),
            radius=8,
            color=color,
            fill_color=color,
            fill_opacity=0.8,
            popup=HTML(popup_html)
        )
        m.add_layer(marker)
    return m

# --- Usage example: preview whichever solution is currently adopted ---
solution = get_current_solution_or_raise()  # presumably raises if nothing has been adopted yet
solution_sorted = sort_gateways(solution, sort_by="lat")  # not needed by the map call below (it sorts itself)
gateway_map = plot_gateways_on_map(aoi_poly, solution, color="green")
display(gateway_map)

In [None]:
import plotly.graph_objects as go
from pyproj import Transformer
from scipy.interpolate import RegularGridInterpolator
import numpy as np

# --- DEM grid creation ---
dem_arr = dem_manager.dem_array
nrows, ncols = dem_arr.shape
minx, miny, maxx, maxy = dem_manager.bounds

# Grid in lon, lat (as per DEM pixel organization); bounds are projected from EPSG:4326 below
lons = np.linspace(minx, maxx, ncols)
lats = np.linspace(miny, maxy, nrows)
lon_grid, lat_grid = np.meshgrid(lons, lats)

# Project grid to the UTM zone containing the DEM centre instead of a hard-coded zone.
# WGS84 / UTM EPSG codes: 326<zone> (north), 327<zone> (south).
center_lon = 0.5 * (minx + maxx)
center_lat = 0.5 * (miny + maxy)
utm_zone = min(int((center_lon + 180.0) // 6.0) + 1, 60)
utm_crs = f"EPSG:{(32600 if center_lat >= 0 else 32700) + utm_zone}"
transformer = Transformer.from_crs("EPSG:4326", utm_crs, always_xy=True)
x_grid, y_grid = transformer.transform(lon_grid, lat_grid)

# Row 0 of lat_grid is the southern edge (lats ascend). Flip the DEM so its rows
# line up with the mesh (assumes DEM row 0 = north — confirm for your source).
dem_arr_plot = np.flipud(dem_arr)

# --- UTM-axis interpolator (kept for downstream use) ---
# Approximate: projected lines of constant lon/lat are not axis-aligned, so row 0 /
# column 0 only stand in for the whole grid.
x_axis = x_grid[0, :]  # UTM X
y_axis = y_grid[:, 0]  # UTM Y

interp = RegularGridInterpolator(
    (y_axis, x_axis),
    dem_arr_plot,
    bounds_error=False,
    fill_value=None
)

# Exact lookup in the grid's native lon/lat axes, consistent with the surface vertices
interp_lonlat = RegularGridInterpolator(
    (lats, lons),
    dem_arr_plot,
    bounds_error=False,
    fill_value=None
)

# --- Prepare Gateway Points ---
solution_sorted = sort_gateways(solution, sort_by="lat")
gw_xs, gw_ys, gw_zs = [], [], []
for gw in solution_sorted:
    lon, lat = gw['coord']
    x, y = transformer.transform(lon, lat)
    gw_xs.append(x)
    gw_ys.append(y)
    gw_zs.append(float(interp_lonlat((lat, lon))))

# --- Plot 3D Surface ---
fig = go.Figure()
fig.add_trace(go.Surface(
    x=x_grid, y=y_grid, z=dem_arr_plot,
    colorscale='earth', colorbar=dict(title="Elevation (m)"),
    showscale=True, opacity=0.96
))
fig.add_trace(go.Scatter3d(
    x=gw_xs, y=gw_ys, z=np.array(gw_zs) + 15,  # Offset for visibility
    mode='markers+text',
    marker=dict(size=10, color='red', symbol='circle'),
    text=[f"GW {i+1}" for i in range(len(gw_xs))],
    textposition='top center'
))
# Mark SW and NE corners
fig.add_trace(go.Scatter3d(
    x=[x_grid[0,0], x_grid[-1,-1]],
    y=[y_grid[0,0], y_grid[-1,-1]],
    z=[dem_arr_plot[0,0]+30, dem_arr_plot[-1,-1]+30],
    mode='markers+text',
    marker=dict(size=10, color=['green','blue']),
    text=['SW', 'NE'],
    textposition='bottom center'
))
fig.update_layout(
    title="3D Gateway Placement on Terrain (UTM, North-Up, East Right, FINAL ORIENTATION)",
    scene=dict(
        xaxis_title='UTM X (m)',
        yaxis_title='UTM Y (m)',
        zaxis_title='Elevation (m)',  # aspectmode='data' is true scale: no exaggeration applied
        aspectmode='data'
    ),
    margin=dict(l=0, r=0, b=0, t=40)
)
fig.show()

In [None]:
import plotly.graph_objects as go
from pyproj import Transformer
from scipy.interpolate import RegularGridInterpolator
import numpy as np

# --- DEM grid: lon/lat mesh over the DEM bounds, projected to UTM ---
dem_arr = dem_manager.dem_array
nrows, ncols = dem_arr.shape
minx, miny, maxx, maxy = dem_manager.bounds

lons = np.linspace(minx, maxx, ncols)
lats = np.linspace(miny, maxy, nrows)
lon_grid, lat_grid = np.meshgrid(lons, lats)

utm_crs = "EPSG:32611"  # WGS 84 / UTM zone 11N
# always_xy=True: inputs are (lon, lat) and outputs (easting, northing)
transformer = Transformer.from_crs("EPSG:4326", utm_crs, always_xy=True)
x_grid, y_grid = transformer.transform(lon_grid, lat_grid)

# Align DEM with mesh origin (lower-left).
# NOTE(review): assumes the DEM array is stored north-up (row 0 = north) — confirm.
dem_arr_plot = np.flipud(dem_arr)

# Axes for interpolation: first row / first column of the projected mesh. A lon/lat
# grid is not exactly rectilinear in UTM, so sampled elevations are approximate.
x_axis = x_grid[0, :]
y_axis = y_grid[:, 0]

interp = RegularGridInterpolator(
    (y_axis, x_axis), dem_arr_plot, bounds_error=False, fill_value=None
)

# ---- Sanity printouts
print("DEM elevation stats (m):",
      f"min={np.nanmin(dem_arr):.1f}, max={np.nanmax(dem_arr):.1f}, "
      f"range={np.nanmax(dem_arr)-np.nanmin(dem_arr):.1f}")
print("XY span (m):",
      f"ΔE={x_axis[-1]-x_axis[0]:.0f}, ΔN={y_axis[-1]-y_axis[0]:.0f}")

# --- Gateways (lat-sorted; labelled GW 1..n in that order)
solution_sorted = sort_gateways(solution, sort_by="lat")
gw_xs, gw_ys, gw_zs = [], [], []
for gw in solution_sorted:
    lon, lat = gw['coord']
    x, y = transformer.transform(lon, lat)
    gw_xs.append(x); gw_ys.append(y); gw_zs.append(float(interp((y, x))))

gw_names = [f"GW {i+1}" for i in range(len(gw_xs))]

# ---- Vertical exaggeration control (1.0 = true-to-scale)
VE = 1.0
Zsurf = dem_arr_plot * VE
Zgw   = (np.array(gw_zs) + 10.0) * VE    # small lift for visibility

# ---- Aspect ratio from data ranges (meters)
# Plotly's aspectratio is relative: each axis gets its (exaggerated) span divided by
# the largest span. A flat or all-NaN DEM has no usable vertical span, which would give
# z = 0 (collapsed scene) or NaN (invalid layout); give it a small visible height instead.
FLAT_DEM_Z_FRACTION = 0.05  # z extent used for a flat DEM, as a fraction of the horizontal span
xR = float(x_axis[-1] - x_axis[0])
yR = float(y_axis[-1] - y_axis[0])
zR = float(np.nanmax(dem_arr_plot) - np.nanmin(dem_arr_plot))
if np.isfinite(zR) and zR > 0:
    zR_vis = zR * VE
else:
    zR_vis = FLAT_DEM_Z_FRACTION * max(xR, yR, 1.0)
maxR = max(xR, yR, zR_vis)
aspect = dict(x=xR/maxR, y=yR/maxR, z=zR_vis/maxR)

# ---- Figure
fig = go.Figure()

fig.add_trace(go.Surface(
    x=x_grid, y=y_grid, z=Zsurf,
    colorscale='earth',
    showscale=True,
    colorbar=dict(title="Elevation (m)", x=1.04, len=0.90, y=0.5),
    opacity=0.98,
    contours = dict(z=dict(show=True, width=1, color="rgba(255,255,255,0.6)",
                           highlight=True, project_z=True))
))

fig.add_trace(go.Scatter3d(
    x=gw_xs, y=gw_ys, z=Zgw,
    mode='markers+text',
    name='Gateways',
    marker=dict(size=8, color='crimson', symbol='circle', line=dict(width=1,color="#222")),
    text=gw_names,
    textposition='top center',
    textfont=dict(color='#111', size=11, family='Arial')
))

# Mark SW ([0,0]) and NE ([-1,-1]) mesh corners for orientation (no legend)
fig.add_trace(go.Scatter3d(
    x=[x_grid[0,0], x_grid[-1,-1]],
    y=[y_grid[0,0], y_grid[-1,-1]],
    z=[Zsurf[0,0]+20*VE, Zsurf[-1,-1]+20*VE],
    mode='markers+text',
    showlegend=False,
    marker=dict(size=6, color=['#2ecc71','#3498db']),
    text=['SW','NE'],
    textposition='bottom center'
))

fig.update_layout(
    title=f"3D Gateway Placement on Terrain (UTM) — VE ×{VE:g}",
    margin=dict(l=10, r=80, b=10, t=60),
    legend=dict(x=0.01, y=0.99, bgcolor="rgba(255,255,255,0.8)",
                bordercolor="rgba(0,0,0,0.2)", borderwidth=1),
    scene=dict(
        xaxis=dict(title='UTM Easting (m)',  tickformat=',.0f'),
        yaxis=dict(title='UTM Northing (m)', tickformat=',.0f'),
        zaxis=dict(title=f'Elevation (m) · VE×{VE:g}', tickformat=',.0f'),
        aspectmode='manual',
        aspectratio=aspect,
        camera=dict(eye=dict(x=1.6, y=1.9, z=0.7))  # natural oblique view
    )
)

fig.show()

In [None]:
# === DRAPE PATH-LOSS COVERAGE OVER *YOUR* 3D TERRAIN PLOT ===
# Assumes the variables in your cell exist:
#   dem_manager, x_grid, y_grid, lon_grid, lat_grid, dem_arr_plot,
#   transformer, solution (list of {'coord': (lon,lat)}),
#   COVERAGE_GRID (list of (lon,lat)), COVERAGE_MASKS (Ncandidates x Ngrid bool/0/1),
#   CURRENT_SOLUTION_INDEXES (indexes of the chosen feasible gateways)
#
# If you don't have CURRENT_SOLUTION_INDEXES handy, see the fallback near the bottom.

from scipy.interpolate import griddata
from scipy.ndimage import gaussian_filter
import numpy as np
import plotly.graph_objects as go

# ---------- 1) Terrain + gateway markers (same construction as the elevation figure) ----------
solution_sorted = sort_gateways(solution, sort_by="lat")

# Gateway z values are sampled from the DEM RegularGridInterpolator built on (y_axis, x_axis).
gw_xs, gw_ys, gw_zs = [], [], []
for gw in solution_sorted:
    lon, lat = gw['coord']
    x, y = transformer.transform(lon, lat)
    gw_xs.append(x)
    gw_ys.append(y)
    gw_zs.append(float(interp((y, x))))

fig = go.Figure()
# Terrain and coverage surfaces each draw a colorbar; pin them to different x positions
# so they do not overlap at Plotly's default location.
fig.add_trace(go.Surface(
    x=x_grid, y=y_grid, z=dem_arr_plot,
    colorscale='Earth', colorbar=dict(title="Elevation (m)", x=1.12, len=0.6),
    showscale=True, opacity=0.96, name="Terrain"
))
fig.add_trace(go.Scatter3d(
    x=np.array(gw_xs), y=np.array(gw_ys), z=np.array(gw_zs) + 15.0,
    mode='markers+text',
    marker=dict(size=10, color='red', symbol='diamond'),
    text=[f"GW {i+1}" for i in range(len(gw_xs))],
    textposition='top center',
    name="Gateways"
))

# ---------- 2) Turn the coverage masks into a raster on (lon_grid, lat_grid) ----------
# Rows of COVERAGE_MASKS for the adopted solution. Before adoption the name may be
# undefined OR explicitly None (np.asarray(None, dtype=int) raises TypeError, which a
# NameError handler would not catch), so check both. The fallback uses EVERY candidate
# mask, i.e. it shows candidate coverage rather than the adopted solution's — say so.
_adopted_idxs = globals().get("CURRENT_SOLUTION_INDEXES")
if _adopted_idxs is None:
    print("⚠️ CURRENT_SOLUTION_INDEXES is not set; the drape shows coverage of ALL "
          "candidate gateways, not the adopted solution.")
    gw_idxs = np.arange(COVERAGE_MASKS.shape[0], dtype=int)
else:
    gw_idxs = np.asarray(_adopted_idxs, dtype=int)

# Per-sample number of covering gateways, scaled so the best-covered sample is 1.
# With several gateways this is relative overlap density, not a covered/uncovered flag.
counts = COVERAGE_MASKS[gw_idxs].sum(axis=0).astype(float)   # shape: [N_grid_points]
cov_vals = counts / max(counts.max(), 1.0)

# Coordinates of each coverage-sample point (lon/lat)
cov_lon = np.array([p[0] for p in COVERAGE_GRID])
cov_lat = np.array([p[1] for p in COVERAGE_GRID])

# Rasterize the scattered coverage samples onto the DEM's lon/lat mesh
# (outside the samples' convex hull the value is 0).
cov_raster = griddata(
    points=np.column_stack([cov_lon, cov_lat]),
    values=cov_vals,
    xi=(lon_grid, lat_grid),
    method="linear",
    fill_value=0.0
)

# Cosmetic blur for a smooth drape. It changes the raster values (not `counts`),
# so the result is re-stretched to 0..1 for the colorbar.
cov_raster = gaussian_filter(cov_raster, sigma=1.3)
mn, mx = float(np.nanmin(cov_raster)), float(np.nanmax(cov_raster))
if mx > mn:
    cov_raster = (cov_raster - mn) / (mx - mn)

# ---------- 3) Add the coverage drape as a second surface (lifted slightly) ----------
fig.add_trace(go.Surface(
    x=x_grid, y=y_grid, z=dem_arr_plot + 0.6,          # small lift to avoid z-fighting
    surfacecolor=cov_raster,
    colorscale="Plasma", cmin=0.0, cmax=1.0,
    opacity=0.95, showscale=True,
    colorbar=dict(title="Gateway Coverage (0–1)", x=1.0, len=0.6),
    name="Coverage"
))

# Optional subtle contours for shape cues (no hover)
fig.add_trace(go.Surface(
    x=x_grid, y=y_grid, z=dem_arr_plot + 0.2,
    showscale=False, opacity=0.08, colorscale="Greys",
    contours=dict(z=dict(show=True, size=50, color="black")),
    hoverinfo="skip",
    name="Contours"
))

fig.update_layout(
    title="3D Terrain with Path-Loss Coverage Drape",
    scene=dict(
        xaxis_title='UTM X (m)',
        yaxis_title='UTM Y (m)',
        zaxis_title='Elevation (m)',
        aspectmode='data'
    ),
    margin=dict(l=0, r=0, b=0, t=48)
)
fig.show()

In [None]:
# === Per-Gateway / UNION Coverage Drape on YOUR 3D Terrain (with legend & tidy colorbars) ===
import numpy as np
import plotly.graph_objects as go
from ipywidgets import Dropdown, HBox
from IPython.display import display
from scipy.interpolate import griddata
from scipy.ndimage import gaussian_filter

# When True, render() also draws the terrain's elevation colorbar, placed at x=1.12
# next to the coverage colorbar at x=1.0.
SHOW_TERRAIN_COLORBAR = False   # set True to show both colorbars (non-overlapping)

def _mask_to_raster(mask_bool, lon_grid, lat_grid):
    cov_lon = np.array([p[0] for p in COVERAGE_GRID])
    cov_lat = np.array([p[1] for p in COVERAGE_GRID])
    raster = griddata(
        points=np.column_stack([cov_lon, cov_lat]),
        values=mask_bool.astype(float),
        xi=(lon_grid, lat_grid),
        method="linear",
        fill_value=0.0
    )
    raster = gaussian_filter(raster, sigma=1.0)
    mn, mx = float(np.nanmin(raster)), float(np.nanmax(raster))
    if mx > mn:
        raster = (raster - mn) / (mx - mn)
    return raster

# Adopted Phase-1 solution: row indices into COVERAGE_MASKS, one per gateway.
# globals().get() turns "never defined" into the same clear error as "still None".
_adopted_idxs = globals().get("CURRENT_SOLUTION_INDEXES")
if _adopted_idxs is None:
    raise RuntimeError("Adopt a feasible solution first so CURRENT_SOLUTION_INDEXES is set.")

gw_candidate_idxs = list(map(int, _adopted_idxs))
num_gw = len(gw_candidate_idxs)
if num_gw == 0:
    # The gateway Dropdown below would have no options and could not take value=0.
    raise RuntimeError("CURRENT_SOLUTION_INDEXES is empty; adopt a solution with at least one gateway.")

view_dd = Dropdown(
    options=[("Single (binary)", "single"),
             ("Single (exclusive)", "exclusive"),
             ("All (union)", "union")],
    value="single", description="View:"
)
gw_dd = Dropdown(
    options=[(f"GW {i+1}", i) for i in range(num_gw)],
    value=0, description="Gateway:"
)

def render(*_):
    """Draw the coverage drape for the current dropdown selection and show it.

    ``view_dd.value`` picks the mask:
      * "union"     – samples covered by ANY adopted gateway;
      * "single"    – samples covered by the gateway chosen in ``gw_dd``;
      * "exclusive" – covered by the chosen gateway and by none of the others
                      (identical to "single" when there is only one gateway).

    The title reports covered percentages over the coverage samples. Gateway
    markers become "Selected GW" / "Other GWs" legend entries, or one red
    "Gateways (union)" group in union view.

    NOTE(review): markers follow the order of ``solution`` while the mask follows
    ``gw_candidate_idxs``; the highlighted marker matches the drape only if both
    share one order — confirm. Labels here may also differ from the lat-sorted
    "GW n" labels used in the elevation figures.
    """
    view_mode = view_dd.value

    if view_mode == "union":
        union_mask = COVERAGE_MASKS[gw_candidate_idxs].any(axis=0)
        union_pct = float(union_mask.mean() * 100.0)
        cov_raster = _mask_to_raster(union_mask, lon_grid, lat_grid)
        title_extra = f"All GWs — union coverage: {union_pct:.1f}%"
        sel_idx = None
        mode_label = "All (union)"
    else:
        sel_idx = gw_dd.value
        own_mask = COVERAGE_MASKS[gw_candidate_idxs[sel_idx]].astype(bool)

        is_exclusive = view_mode == "exclusive" and num_gw > 1
        if is_exclusive:
            rival_rows = [row for pos, row in enumerate(gw_candidate_idxs) if pos != sel_idx]
            rivals_cover = COVERAGE_MASKS[rival_rows].any(axis=0)
            shown_mask = np.logical_and(own_mask, np.logical_not(rivals_cover))
        else:
            shown_mask = own_mask
        suffix = "exclusive" if is_exclusive else "binary"

        own_pct = float(own_mask.mean() * 100.0)
        shown_pct = float(shown_mask.mean() * 100.0)
        cov_raster = _mask_to_raster(shown_mask, lon_grid, lat_grid)
        title_extra = f"GW {sel_idx+1} — covered: {own_pct:.1f}%"
        if is_exclusive:
            title_extra += f" | exclusive: {shown_pct:.1f}%"
        mode_label = f"GW {sel_idx+1} ({suffix})"

    # Gateway markers: DEM elevation under each gateway, lifted 15 m for visibility.
    marker_xyz = []
    for gw in solution:
        lon, lat = gw['coord']
        x, y = transformer.transform(lon, lat)
        marker_xyz.append((x, y, float(interp((y, x))) + 15.0))
    gw_xs = [pt[0] for pt in marker_xyz]
    gw_ys = [pt[1] for pt in marker_xyz]
    gw_zs = [pt[2] for pt in marker_xyz]

    terrain_colorbar = (
        {"colorbar": dict(title="Elevation (m)", x=1.12, len=0.6)}
        if SHOW_TERRAIN_COLORBAR else {}
    )
    traces = [
        # Terrain underlay
        go.Surface(
            x=x_grid, y=y_grid, z=dem_arr_plot,
            colorscale='Earth', opacity=0.96, name="Terrain",
            showscale=SHOW_TERRAIN_COLORBAR,
            **terrain_colorbar
        ),
        # Coverage drape, lifted slightly above the terrain
        go.Surface(
            x=x_grid, y=y_grid, z=dem_arr_plot + 0.6,
            surfacecolor=cov_raster,
            colorscale="Plasma", cmin=0.0, cmax=1.0,
            opacity=0.95, showscale=True,
            colorbar=dict(title="Coverage (0–1)", x=1.0, len=0.6),
            name=f"Coverage — {mode_label}"
        ),
        # Faint contour shell for shape cues
        go.Surface(
            x=x_grid, y=y_grid, z=dem_arr_plot + 0.2,
            showscale=False, opacity=0.08, colorscale="Greys",
            contours=dict(z=dict(show=True, size=50, color="black")),
            hoverinfo="skip", name="Contours"
        ),
    ]

    if sel_idx is None:
        # union view: every gateway red, one legend entry
        traces.append(go.Scatter3d(
            x=np.array(gw_xs), y=np.array(gw_ys), z=np.array(gw_zs),
            mode="markers+text",
            marker=dict(size=10, color="red", symbol="diamond"),
            text=[f"GW {i+1}" for i in range(len(gw_xs))], textposition="top center",
            name="Gateways (union)", showlegend=True
        ))
    else:
        traces.append(go.Scatter3d(
            x=[gw_xs[sel_idx]], y=[gw_ys[sel_idx]], z=[gw_zs[sel_idx]],
            mode="markers+text",
            marker=dict(size=10, color="red", symbol="diamond"),
            text=[f"GW {sel_idx+1}"], textposition="top center",
            name="Selected GW", showlegend=True
        ))
        rest = [i for i in range(len(gw_xs)) if i != sel_idx]
        if rest:
            traces.append(go.Scatter3d(
                x=np.array(gw_xs)[rest], y=np.array(gw_ys)[rest], z=np.array(gw_zs)[rest],
                mode="markers+text",
                marker=dict(size=10, color="rgba(100,100,100,0.9)", symbol="diamond"),
                text=[f"GW {i+1}" for i in rest], textposition="top center",
                name="Other GWs", showlegend=True
            ))

    fig = go.Figure()
    for trace in traces:
        fig.add_trace(trace)

    fig.update_layout(
        title=f"Coverage Drape on Terrain — {title_extra}",
        scene=dict(
            xaxis_title='UTM X (m)',
            yaxis_title='UTM Y (m)',
            zaxis_title='Elevation (m)',
            aspectmode='data'
        ),
        margin=dict(l=0, r=0, b=0, t=60),
        legend=dict(orientation="h", x=0.02, y=0.02, bgcolor="rgba(255,255,255,0.6)")
    )

    # Mode badge in the top-left corner
    fig.add_annotation(
        xref="paper", yref="paper", x=0.01, y=0.98,
        text=f"View: {mode_label}", showarrow=False,
        bgcolor="rgba(255,255,255,0.6)", bordercolor="rgba(0,0,0,0.2)", borderwidth=1
    )

    fig.show()

def _toggle_vis(*_):
    """Hide the gateway chooser in union view, where it has no effect."""
    union_view = view_dd.value == "union"
    gw_dd.layout.display = "none" if union_view else ""

# Render into a dedicated Output widget. Output produced inside widget observer
# callbacks may not be displayed at all (the ipywidgets docs call this out for
# JupyterLab), and without clear_output each re-render would stack another figure.
from ipywidgets import Output, VBox

drape_out = Output()

def _render_into_output(*_):
    """Replace the contents of ``drape_out`` with a fresh ``render()``."""
    with drape_out:
        drape_out.clear_output(wait=True)
        render()

view_dd.observe(_render_into_output, names="value")
gw_dd.observe(_render_into_output, names="value")
view_dd.observe(_toggle_vis, names="value")
_toggle_vis()

display(VBox([HBox([view_dd, gw_dd]), drape_out]))
_render_into_output()

In [None]:
# === Per-Gateway / UNION Coverage Drape on YOUR 3D Terrain ===
import numpy as np
import plotly.graph_objects as go
from ipywidgets import Dropdown, Checkbox, HBox, VBox, HTML
from IPython.display import display
from scipy.interpolate import griddata
from scipy.ndimage import gaussian_filter

# --- helper: rasterize a boolean mask (at COVERAGE_GRID points) onto your lon/lat mesh
def _mask_to_raster(mask_bool, lon_grid, lat_grid):
    cov_lon = np.array([p[0] for p in COVERAGE_GRID])
    cov_lat = np.array([p[1] for p in COVERAGE_GRID])
    raster = griddata(
        points=np.column_stack([cov_lon, cov_lat]),
        values=mask_bool.astype(float),
        xi=(lon_grid, lat_grid),
        method="linear",
        fill_value=0.0
    )
    # cosmetic smoothing for nicer visuals (does not change “which cells”)
    raster = gaussian_filter(raster, sigma=1.0)
    # keep 0..1
    mn, mx = float(np.nanmin(raster)), float(np.nanmax(raster))
    if mx > mn:
        raster = (raster - mn) / (mx - mn)
    return raster

# Adopted Phase-1 solution: row indices into COVERAGE_MASKS, one per gateway.
# globals().get() turns "never defined" into the same clear error as "still None".
_adopted_idxs = globals().get("CURRENT_SOLUTION_INDEXES")
if _adopted_idxs is None:
    raise RuntimeError("Adopt a feasible solution first so CURRENT_SOLUTION_INDEXES is set.")

gw_candidate_idxs = list(map(int, _adopted_idxs))
num_gw = len(gw_candidate_idxs)
if num_gw == 0:
    # The gateway Dropdown below would have no options and could not take value=0.
    raise RuntimeError("CURRENT_SOLUTION_INDEXES is empty; adopt a solution with at least one gateway.")

# --- widgets ---
view_dd = Dropdown(
    options=[("Single (binary)", "single"),
             ("Single (exclusive)", "exclusive"),
             ("All (union)", "union")],
    value="single", description="View:"
)
gw_dd = Dropdown(
    options=[(f"GW {i+1}", i) for i in range(num_gw)],
    value=0, description="Gateway:"
)

def render(*_):
    """Draw the coverage drape for the current dropdown selection and show it.

    ``view_dd.value`` selects "union" (covered by any adopted gateway), "single"
    (covered by the gateway chosen in ``gw_dd``) or "exclusive" (covered by the
    chosen gateway and by no other; same as "single" with one gateway).
    Markers are drawn for every entry of ``solution``: all red in union view,
    otherwise the selected one red and the rest gray.

    NOTE(review): marker order follows ``solution`` while the selection indexes
    ``gw_candidate_idxs``; the red marker matches the drape only if both lists
    share one order — confirm.
    """
    mode = view_dd.value

    if mode == "union":
        # union over all selected gateways
        mask_union = COVERAGE_MASKS[gw_candidate_idxs].any(axis=0)
        cov_pct = float(mask_union.mean() * 100.0)
        cov_raster = _mask_to_raster(mask_union, lon_grid, lat_grid)
        title_extra = f"All GWs — union coverage: {cov_pct:.1f}%"
        # all markers red in union view
        colors = ["red"] * len(solution)
        sel_label = "All (union)"
    else:
        # single-GW views
        i_sel = gw_dd.value  # 0..num_gw-1
        cand_idx = gw_candidate_idxs[i_sel]
        mask_sel = COVERAGE_MASKS[cand_idx].astype(bool)

        if mode == "exclusive" and num_gw > 1:
            others = [idx for k, idx in enumerate(gw_candidate_idxs) if k != i_sel]
            mask_others_any = COVERAGE_MASKS[others].any(axis=0)
            mask_final = np.logical_and(mask_sel, np.logical_not(mask_others_any))
            mode_label = "exclusive"
        else:
            mask_final = mask_sel
            mode_label = "binary"

        cov_pct = float(mask_sel.mean() * 100.0)
        cov_excl_pct = float(mask_final.mean() * 100.0)
        cov_raster = _mask_to_raster(mask_final, lon_grid, lat_grid)
        title_extra = (f"GW {i_sel+1} — covered: {cov_pct:.1f}%"
                       + (f" | exclusive: {cov_excl_pct:.1f}%" if mode_label == "exclusive" else ""))
        # highlight selected in red, others gray
        colors = ["rgba(100,100,100,0.9)"] * len(solution)
        colors[i_sel] = "red"
        sel_label = f"GW {i_sel+1} ({mode_label})"

    # --- terrain base & markers built like the elevation figure ---
    gw_xs, gw_ys, gw_zs, labels = [], [], [], []
    for k, gw in enumerate(solution):
        lon, lat = gw['coord']
        x, y = transformer.transform(lon, lat)
        gw_xs.append(x); gw_ys.append(y)
        gw_zs.append(float(interp((y, x))) + 15.0)  # DEM interp + 15 m lift
        labels.append(f"GW {k+1}")

    fig = go.Figure()

    # Terrain and coverage both show a colorbar; give them separate x positions
    # so they do not overlap at Plotly's default location.
    fig.add_trace(go.Surface(
        x=x_grid, y=y_grid, z=dem_arr_plot,
        colorscale='Earth', showscale=True,
        colorbar=dict(title="Elevation (m)", x=1.12, len=0.6),
        opacity=0.96, name="Terrain"
    ))

    # coverage drape (union or single mask), lifted slightly to avoid z-fighting
    fig.add_trace(go.Surface(
        x=x_grid, y=y_grid, z=dem_arr_plot + 0.6,
        surfacecolor=cov_raster,
        colorscale="Plasma", cmin=0.0, cmax=1.0,
        opacity=0.95, showscale=True,
        colorbar=dict(title="Coverage (0–1)", x=1.0, len=0.6),
        name=f"Coverage — {sel_label}"
    ))

    # subtle contours
    fig.add_trace(go.Surface(
        x=x_grid, y=y_grid, z=dem_arr_plot + 0.2,
        showscale=False, opacity=0.08, colorscale="Greys",
        contours=dict(z=dict(show=True, size=50, color="black")),
        hoverinfo="skip", name="Contours"
    ))

    # gateways
    fig.add_trace(go.Scatter3d(
        x=np.array(gw_xs), y=np.array(gw_ys), z=np.array(gw_zs),
        mode="markers+text",
        marker=dict(size=10, color=colors, symbol="diamond"),
        text=labels, textposition="top center",
        name="Gateways"
    ))

    fig.update_layout(
        title=f"Coverage Drape on Terrain — {title_extra}",
        scene=dict(
            xaxis_title='UTM X (m)',
            yaxis_title='UTM Y (m)',
            zaxis_title='Elevation (m)',
            aspectmode='data'
        ),
        margin=dict(l=0, r=0, b=0, t=60)
    )
    fig.show()

# wire up — render into a dedicated Output widget. Output produced inside widget
# observer callbacks may not be displayed at all (the ipywidgets docs call this out
# for JupyterLab), and without clear_output each re-render would stack another figure.
from ipywidgets import Output, VBox

drape_out = Output()

def _render_into_output(*_):
    """Replace the contents of ``drape_out`` with a fresh ``render()``."""
    with drape_out:
        drape_out.clear_output(wait=True)
        render()

view_dd.observe(_render_into_output, names="value")
gw_dd.observe(_render_into_output, names="value")

# show controls (GW chooser hidden when union is selected—optional UX)
def _toggle_vis(*_):
    """Hide the gateway chooser in union view, where it has no effect."""
    gw_dd.layout.display = "none" if view_dd.value == "union" else ""
view_dd.observe(_toggle_vis, names="value")
_toggle_vis()

display(VBox([HBox([view_dd, gw_dd]), drape_out]))
_render_into_output()

In [None]:
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
import numpy as np
import pandas as pd
import random
from shapely.geometry import Point
from geopy.distance import geodesic
from joblib import Parallel, delayed
from tqdm.notebook import tqdm
import joblib

# ===== Parameters =====
NUM_CANDIDATES = 800
SENSOR_HEIGHT = NETWORK.SENSOR_HEIGHT
GATEWAY_HEIGHT = NETWORK.GATEWAY_HEIGHT
PATHLOSS_THRESHOLD = OPTIMIZATION.MAX_ALLOWABLE_PATH_LOSS_DB
# Cap on rejection-sampling draws so a degenerate AOI (zero area, or one covering a
# vanishing share of its bounding box) raises instead of looping forever.
MAX_SAMPLING_DRAWS = 1000 * NUM_CANDIDATES

# ===== Prepare Gateways =====
try:
    solution = get_current_solution_or_raise()
except Exception as e:
    print(f"⚠️ {e}")
    raise

solution_sorted = sort_gateways(solution, sort_by="lat")
gateways = [tuple(gw['coord']) for gw in solution_sorted]
print(f"Analyzing sensor connectivity to {len(gateways)} selected gateways.")
print("Using pycraf loss_complete (ITU-R P.452-16) for accurate propagation modeling")
print(f"Path loss threshold: {PATHLOSS_THRESHOLD} dB")

# ===== Random Points in AOI (rejection sampling within the bounding box) =====
# A dedicated RNG seeded from RANDOM_SEED makes the candidate set reproducible no matter
# how many draws earlier cells made from the shared global `random` state.
candidate_rng = random.Random(RANDOM_SEED)
minx, miny, maxx, maxy = aoi_poly.bounds
random_pts = []
draws = 0
while len(random_pts) < NUM_CANDIDATES:
    if draws >= MAX_SAMPLING_DRAWS:
        raise RuntimeError(
            f"Only {len(random_pts)}/{NUM_CANDIDATES} candidate points fell inside the AOI "
            f"after {draws} draws; check that aoi_poly has a non-zero area."
        )
    draws += 1
    x, y = candidate_rng.uniform(minx, maxx), candidate_rng.uniform(miny, maxy)
    if aoi_poly.contains(Point(x, y)):
        random_pts.append((x, y))

# ===== Helper: Sensor Candidate Analysis =====
def analyze_sensor_candidate(
    P, gateways, dem_manager, SENSOR_HEIGHT, GATEWAY_HEIGHT, PATHLOSS_THRESHOLD, NETWORK, pt_idx=None
):
    """Evaluate one candidate sensor location against every gateway.

    For each gateway, the path loss from ``calculate_path_loss_pycraf`` (915 MHz,
    sensor height as ``tx_h``, gateway height as ``rx_h``) and the geodesic
    distance are computed. A link is "connectable" when its path loss is
    <= ``PATHLOSS_THRESHOLD``; the best gateway is the connectable one with the
    lowest path loss.

    Args:
        P: (lon, lat) of the candidate sensor.
        gateways: sequence of (lon, lat) gateway coordinates.
        dem_manager: passed to the path-loss helper; also used for elevation and solar score.
        SENSOR_HEIGHT, GATEWAY_HEIGHT: antenna heights forwarded as tx_h / rx_h.
        PATHLOSS_THRESHOLD: maximum path loss (dB) for a usable link.
        NETWORK: config object; only ``MAX_COMM_RANGE_M`` is read.
        pt_idx: if given and < 5, every link is printed for debugging.

    Returns:
        dict forming one result row. When no gateway is connectable,
        ``Best_Gateway`` is -1, ``Best_PathLoss_dB`` is inf and ``Distance_m``
        is None. ``Gateways_Details`` lists every link, connectable or not.
    """
    details = []
    reachable = []
    best_gw, best_pl, best_dist = -1, np.inf, None
    verbose = pt_idx is not None and pt_idx < 5

    for gw_pos, gw in enumerate(gateways):
        pl = calculate_path_loss_pycraf(
            p1=P, p2=gw,
            dem_manager=dem_manager,
            tx_h=SENSOR_HEIGHT,
            rx_h=GATEWAY_HEIGHT,
            freq_mhz=915.0,
            veg_threshold=2.0,
            attn_per_meter=0.0,
            max_comm_range=NETWORK.MAX_COMM_RANGE_M
        )
        # geopy expects (lat, lon); coordinates here are (lon, lat)
        dist_m = geodesic(P[::-1], gw[::-1]).meters
        ok = pl <= PATHLOSS_THRESHOLD

        if verbose:
            status = "✓ CONNECTED" if ok else "✗ blocked"
            print(f"[S{pt_idx}→GW{gw_pos}] PL={pl:.1f}dB, Dist={dist_m:.0f}m, " + status)

        details.append({
            "idx": gw_pos,
            "path_loss": pl,
            "distance_m": dist_m,
            "connectable": ok
        })

        if not ok:
            continue
        reachable.append(gw_pos)
        if pl < best_pl:
            best_gw, best_pl, best_dist = gw_pos, pl, dist_m

    elevation = dem_manager.get_elevation(P)
    solar = calculate_solar_score(P, dem_manager)
    n_reachable = len(reachable)
    return {
        'Longitude': P[0],
        'Latitude': P[1],
        'Best_Gateway': best_gw,
        'Best_PathLoss_dB': best_pl,
        'Distance_m': best_dist,
        'Gateways_Details': details,
        'Connectable_Gateways': reachable,
        'Num_Connectable': n_reachable,
        'Multi_Connected': n_reachable > 1,
        'Elevation': elevation,
        'Solar': solar
    }

# ===== Debug: Serial for first 5 =====
# Runs the first few candidates serially with pt_idx set so each link is printed;
# the results are discarded (the parallel pass below evaluates every candidate).
_n_debug = min(5, NUM_CANDIDATES)
for pt_idx in range(_n_debug):
    _ = analyze_sensor_candidate(
        random_pts[pt_idx], gateways, dem_manager,
        SENSOR_HEIGHT, GATEWAY_HEIGHT, PATHLOSS_THRESHOLD, NETWORK,
        pt_idx=pt_idx,
    )

# ======= PARALLELIZED SENSOR CANDIDATE ANALYSIS WITH TQDM PROGRESS BAR =======
class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
    """joblib batch-completion callback that also advances a tqdm progress bar.

    ``tqdm_bar`` is a class attribute: assign a tqdm instance before running
    ``Parallel`` and reset it to ``None`` afterwards. Each completed batch
    advances the bar by the batch size, then defers to joblib's own callback.

    NOTE(review): this hooks joblib's internal BatchCompletionCallBack (not a
    public API); verify it still fires with the installed joblib version.
    """

    tqdm_bar = None

    def __call__(self, *args, **kwargs):
        bar = self.tqdm_bar
        if bar is not None:
            bar.update(n=self.batch_size)
        return super().__call__(*args, **kwargs)

# Temporarily swap in the tqdm-aware callback class for the duration of the run.
old_callback = joblib.parallel.BatchCompletionCallBack
joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback

try:
    with tqdm(total=len(random_pts), desc="Sensor Candidates", unit="sensor") as bar:
        TqdmBatchCompletionCallback.tqdm_bar = bar
        results = Parallel(n_jobs=-1, backend='loky')(
            delayed(analyze_sensor_candidate)(
                P,
                gateways,
                dem_manager,
                SENSOR_HEIGHT,
                GATEWAY_HEIGHT,
                PATHLOSS_THRESHOLD,
                NETWORK
                # pt_idx is not passed here to avoid print spam
            )
            for P in random_pts
        )
finally:
    # Detach the (now closed) bar and restore joblib's callback even if a worker
    # raised, so later Parallel runs never update a stale bar.
    TqdmBatchCompletionCallback.tqdm_bar = None
    joblib.parallel.BatchCompletionCallBack = old_callback

print("\nPath loss analysis complete using pycraf ITU-R P.452-16 model")

# ===== DATAFRAME + REPORT =====
df = pd.DataFrame(results)
if df.empty:
    raise RuntimeError("No sensor candidates were analyzed; check the AOI and candidate sampling.")
df['Is_Connected'] = df['Num_Connectable'] > 0

connected_sensors = df[df['Is_Connected']]
multi_connected = df[df['Multi_Connected']]

n_total = len(df)
print("\nSensor Connectivity Report (using loss_complete):")
print(f"   - Total candidates analyzed: {n_total}")
print(f"   - Connectable to at least one gateway: {len(connected_sensors)} ({len(connected_sensors)/n_total:.1%})")
print(f"   - Multi-connected (redundant): {len(multi_connected)} ({len(multi_connected)/n_total:.1%})")
if connected_sensors.empty:
    # mean/min/max of an empty column would print "nan"
    print("   - No candidate is within the path-loss threshold of any gateway.")
else:
    connected_pl = connected_sensors['Best_PathLoss_dB']
    print(f"   - Average path loss for connected: {connected_pl.mean():.1f} dB")
    print(f"   - Path loss range: {connected_pl.min():.1f} - {connected_pl.max():.1f} dB")

In [None]:
# ======= FIXED SENSOR CONNECTIVITY VISUALIZATION =======
# Run this cell after your sensor connectivity analysis

import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

from ipyleaflet import Map, CircleMarker, Popup, GeoJSON, basemaps
from ipywidgets import HTML
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np

print("CREATING SENSOR CONNECTIVITY VISUALIZATION")
print("=" * 50)

# Everything below indexes `df`; stop here with a clear message instead of printing an
# error and then failing later with an unrelated NameError/KeyError.
if 'df' not in globals() or df.empty:
    raise RuntimeError(
        "Sensor connectivity results (df) not found! "
        "Please run the sensor connectivity analysis first."
    )
print(f"Found sensor connectivity results: {len(df)} sensors analyzed")

# Function to get distinct colors for gateways
def get_distinct_colors(n):
    """Return ``n`` visually distinct colors, one per gateway.

    Up to 10 items use the qualitative ``tab10`` palette (RGB tuples). Beyond
    that, ``n`` evenly spaced hues of the ``hsv`` colormap are returned as an
    (n, 4) RGBA array.

    ``hsv`` is cyclic (0 and 1 are both red), so the end point is excluded;
    ``linspace(0, 1, n)`` would give the first and last gateway the same color.
    """
    if n <= 10:
        return plt.get_cmap('tab10').colors[:n]
    return plt.cm.hsv(np.linspace(0.0, 1.0, n, endpoint=False))

# ===== FIXED DATA ACCESS =====
# Prefer an existing `solution`; otherwise take the entry with the largest key in
# `saved_solutions`; otherwise continue with no gateways.
if 'solution' not in globals():
    if 'saved_solutions' in globals() and saved_solutions:
        latest_solution_key = max(saved_solutions.keys())
        solution = saved_solutions[latest_solution_key]['solution']
    else:
        print("ERROR: No gateway solution found!")
        solution = []

solution_sorted = sort_gateways(solution, sort_by="lat")
gateways = [tuple(gw['coord']) for gw in solution_sorted]

print(f"Visualizing {len(gateways)} gateways:")
for i, gw_coord in enumerate(gateways):
    lon_txt, lat_txt = f"{gw_coord[0]:.6f}", f"{gw_coord[1]:.6f}"
    print(f"   GW{i}: ({lon_txt}, {lat_txt})")

# ===== GENERATE GATEWAY COLORS =====
connected_sensors = df[df['Is_Connected']]

# Color EVERY gateway, keyed by its index in `gateways`, not only those that are some
# sensor's best server. This keeps the index->color mapping stable, stops gateways without
# sensors from falling back to red (the disconnected-sensor color), and lists every
# gateway in the legend.
gateway_indices = list(range(len(gateways)))
palette = get_distinct_colors(len(gateway_indices)) if gateway_indices else []
gateway_colors = {
    gw_idx: mcolors.to_hex(palette[k]) for k, gw_idx in enumerate(gateway_indices)
}

# ===== CREATE MAP =====
center = (aoi_poly.centroid.y, aoi_poly.centroid.x)
print(f"Creating map centered at: ({center[0]:.6f}, {center[1]:.6f})")

m = Map(center=center, zoom=13, basemap=basemaps.Esri.WorldImagery)
m.add_layer(GeoJSON(
    data=aoi_poly.__geo_interface__,
    style={'color': 'yellow', 'fillOpacity': 0.02, 'weight': 3}
))

# ===== ADD GATEWAYS =====
print("Adding gateways to map...")
n_candidates = len(df)
for idx, gw in enumerate(solution_sorted):
    coord = gw['coord']
    lat, lon = coord[1], coord[0]  # Note: leaflet uses (lat, lon)
    gateway_color = gateway_colors.get(idx, '#ff0000')
    # Best_Gateway is only set (>= 0) for sensors with a link under the threshold, so
    # "assigned" and "connected" counts are always equal. Report the best-server count
    # and, separately, how many candidates can reach this gateway at all.
    best_server_count = int((df['Best_Gateway'] == idx).sum())
    in_range_count = int(df['Connectable_Gateways'].apply(lambda lst: idx in lst).sum())
    share_pct = best_server_count / n_candidates * 100 if n_candidates else 0
    popup_html = f"""
    <div style='font-family: Arial; font-size: 12px;'>
        <h4 style='margin: 0 0 8px 0; color: {gateway_color};'>🏗️ Gateway {idx}</h4>
        <b>Location:</b><br>
        Lon: {lon:.6f}<br>
        Lat: {lat:.6f}<br>
        <b>Performance:</b><br>
        Best-server sensors: {best_server_count}<br>
        Sensors in range: {in_range_count}<br>
        Share of all candidates: {share_pct:.1f}%
    </div>
    """
    popup = HTML(popup_html)
    marker = CircleMarker(
        location=(lat, lon),
        radius=12,
        color='black',
        fill_color=gateway_color,
        fill_opacity=0.9,
        weight=2,
        popup=popup
    )
    m.add_layer(marker)

# ===== ADD CONNECTED SENSORS =====
print("Adding connected sensors to map...")
connected_count = 0
multi_connected_count = 0
MAX_GW_IN_POPUP = 3  # show at most this many links to keep popups compact

for _, row in connected_sensors.iterrows():
    gw_idx = row['Best_Gateway']
    if gw_idx < 0 or gw_idx >= len(gateways):
        continue
    color = gateway_colors.get(gw_idx, '#ff0000')
    multi_conn = row.get('Multi_Connected', False)
    if multi_conn:
        marker_color = "#ffff00"  # Yellow for multi-connected
        marker_radius = 7
        border_color = color      # border keeps the best gateway's color
        border_width = 3
        multi_connected_count += 1
    else:
        marker_color = color
        marker_radius = 4
        border_color = 'black'
        border_width = 1
    connected_count += 1

    elevation = row.get('Elevation')
    if elevation is None or pd.isna(elevation):
        elevation_txt = "N/A"  # avoid printing "nanm"
    else:
        elevation_txt = f"{elevation:.1f}m"

    popup_html = f"""
    <div style='font-family: Arial; font-size: 11px; max-width: 250px;'>
        <h4 style='margin: 0 0 8px 0; color: {color};'>📡 Sensor Candidate</h4>
        <b>Location:</b><br>
        Lon: {row['Longitude']:.6f}<br>
        Lat: {row['Latitude']:.6f}<br>
        <b>Best Gateway:</b> GW{gw_idx}<br>
        <b>Multi-Connected:</b> {'✅ Yes' if multi_conn else '❌ No'}<br>
        <b>Elevation:</b> {elevation_txt}<br>
    """
    conn_gws = row['Connectable_Gateways'] if 'Connectable_Gateways' in row else []
    if conn_gws:
        popup_html += f"<b>Connected Gateways ({len(conn_gws)}):</b><br>"
        details_by_gw = {d['idx']: d for d in row['Gateways_Details']}
        for conn_gw in conn_gws[:MAX_GW_IN_POPUP]:
            detail = details_by_gw.get(conn_gw)
            if not detail:
                continue
            link_color = gateway_colors.get(conn_gw, "#000000")
            # Clearance is shown only if it was recorded; the candidate analysis does not
            # store 'percent_clear', and a 0% placeholder would read as "fully obstructed".
            clear_txt = f", Clear={detail['percent_clear']:.0f}%" if 'percent_clear' in detail else ""
            popup_html += f"""
            <span style='color: {link_color};'>🏗️ GW{conn_gw}:</span> 
            PL={detail['path_loss']:.0f}dB{clear_txt}<br>
            """
        if len(conn_gws) > MAX_GW_IN_POPUP:
            popup_html += f"... and {len(conn_gws) - MAX_GW_IN_POPUP} more<br>"
    popup_html += "</div>"
    popup = HTML(popup_html)
    marker = CircleMarker(
        location=(row['Latitude'], row['Longitude']),
        radius=marker_radius,
        color=border_color,
        fill_color=marker_color,
        fill_opacity=0.8,
        weight=border_width,
        popup=popup
    )
    m.add_layer(marker)

# ===== ADD DISCONNECTED SENSORS (SAMPLE) =====
print("Adding sample of disconnected sensors...")
disconnected_sensors = df[~df['Is_Connected']]
if len(disconnected_sensors) > 0:
    sample_size = min(200, len(disconnected_sensors))
    sample_disconnected = disconnected_sensors.sample(n=sample_size, random_state=42)
    for _, row in sample_disconnected.iterrows():
        # Best_PathLoss_dB is only updated for links under the threshold, so it is always
        # inf for a disconnected sensor. Show the lowest path loss over ALL gateways
        # instead. No clearance figure is recorded per candidate, so none is shown.
        details = row['Gateways_Details'] or []
        if details:
            nearest = min(details, key=lambda d: d['path_loss'])
            best_link_txt = f"{nearest['path_loss']:.0f}dB (GW{nearest['idx']})"
        else:
            best_link_txt = "N/A"
        popup_html = f"""
        <div style='font-family: Arial; font-size: 11px;'>
            <h4 style='margin: 0 0 8px 0; color: red;'>❌ Disconnected Sensor</h4>
            <b>Location:</b><br>
            Lon: {row['Longitude']:.6f}<br>
            Lat: {row['Latitude']:.6f}<br>
            <b>Lowest Path Loss:</b> {best_link_txt}<br>
        </div>
        """
        popup = HTML(popup_html)
        marker = CircleMarker(
            location=(row['Latitude'], row['Longitude']),
            radius=2,
            color='darkred',
            fill_color='red',
            fill_opacity=0.5,
            weight=1,
            popup=popup
        )
        m.add_layer(marker)

# ===== ENHANCED LEGEND =====
# The legend is a map *control* pinned to a corner. The previous Popup is a layer, not a
# Control, and placing it next to the AOI centre covered the sensors it describes.
from ipyleaflet import WidgetControl

DOT = "&#9679;"  # filled circle: takes the CSS color (color emoji such as 🏗️ typically ignore it)

legend_html = """
<div style='background: white; padding: 12px; border-radius: 8px; 
            box-shadow: 0 2px 6px rgba(0,0,0,0.3); font-family: Arial; font-size: 12px; min-width: 200px;'>
    <h4 style='margin: 0 0 10px 0; color: #333;'> Network Map Legend</h4>
"""
legend_html += "<b>Gateway Colors:</b><br>"
for idx in sorted(gateway_colors.keys()):
    color = gateway_colors[idx]
    sensors_count = len(df[(df['Best_Gateway'] == idx) & (df['Is_Connected'])])
    legend_html += f"<span style='color:{color}; font-size: 16px;'>{DOT}</span> Gateway {idx} ({sensors_count} sensors)<br>"

# Matches the markers drawn above: multi-connected = yellow fill with a border in the best
# gateway's color; single-connected = filled with the best gateway's color; disconnected =
# small red dots (a sample).
legend_html += "<br><b>Sensor Status:</b><br>"
legend_html += (f"<span style='color: #ffff00; font-size: 14px; text-shadow: 0 0 1px #000;'>{DOT}</span> "
                "Multi-Connected Sensor (border = best gateway)<br>")
legend_html += (f"<span style='color: #777777; font-size: 14px;'>{DOT}</span> "
                "Single-Connected Sensor (fill = best gateway color)<br>")
legend_html += f"<span style='color: red; font-size: 14px;'>{DOT}</span> Disconnected Sensor (sample)<br>"

connectivity_pct = len(connected_sensors) / len(df) * 100 if len(df) > 0 else 0
multi_pct = multi_connected_count / len(df) * 100 if len(df) > 0 else 0
legend_html += f"""
<br><b>Network Statistics:</b><br>
Connected: {len(connected_sensors)}/{len(df)} ({connectivity_pct:.1f}%)<br>
Multi-connected: {multi_connected_count} ({multi_pct:.1f}%)<br>
"""

legend_html += "</div>"

legend_control = HTML(legend_html)
m.add_control(WidgetControl(widget=legend_control, position='bottomright'))

# Console summary of what was drawn, followed by a short usage hint.
n_disconnected = len(disconnected_sensors)
summary_lines = [
    "\nVISUALIZATION SUMMARY:",
    f"   Gateways displayed: {len(gateways)}",
    f"   Connected sensors: {connected_count}",
    f"   Multi-connected sensors: {multi_connected_count}",
    f"   Disconnected sensors (sample): {min(200, n_disconnected)}/{n_disconnected}",
    f"   Overall connectivity: {connectivity_pct:.1f}%",
    "\nInteractive map created successfully!",
    "   - Click on gateways to see performance stats",
    "   - Click on sensors to see connection details",
    "   - Yellow sensors have multiple gateway options",
]
print("\n".join(summary_lines))

# Display the map
display(m)

In [None]:
# =====================================================================================
# PHASE 2: Final Sensor Placement Optimizer (Multi-Objective, Mountain-Aware, Mapping)
# With Global & Per-Gateway Assignment Options (MIP + NSGA-III)
# Includes Generation Diagnostics and Debugging!
# (Updated: remainder-aware per-gateway quotas + cross-gateway non-overlap)
# =====================================================================================
import numpy as np
import pandas as pd
from geopy.distance import geodesic
from itertools import combinations
import pulp
from pymoo.core.problem import Problem
from pymoo.algorithms.moo.nsga3 import NSGA3
from pymoo.util.ref_dirs import get_reference_directions
from pymoo.optimize import minimize
from ipywidgets import IntSlider, Dropdown, FloatSlider, Button, Output, HBox, VBox, SelectMultiple, HTML, Layout
from ipyleaflet import Map, CircleMarker, Polyline, GeoJSON, Circle, basemaps

# -------- Parameters --------
SENSOR_RADIUS = 500                        # meters; sensing radius of one sensor (for non-overlap constraint)
max_sensors = 30                           # upper bound offered by the "Num Sensors" slider
min_sensor_separation = 2 * SENSOR_RADIUS  # two radius-R disks are disjoint iff centres are >= 2R apart

# ---------- Helper to REQUIRE an adopted gateway solution ----------
def get_adopted_gateways_or_fail():
    """
    Fetch the gateways of the currently adopted Phase 1 solution.

    Returns
    -------
    list of tuple
        Gateway ``coord`` tuples, read as (lon, lat), ordered by ``sort_gateways(..., sort_by="lat")``.

    Raises
    ------
    RuntimeError
        If no solution has been adopted, the adopted solution is empty, or
        nothing is left after sorting.
    """
    try:
        adopted = get_current_solution_or_raise()
    except RuntimeError as exc:
        raise RuntimeError(
            "No adopted gateway solution found. In Phase 1, run NSGA-III or MIP and click "
            "'Adopt Selected Feasible' (or 'Adopt MIP Result') before running Phase 2."
        ) from exc

    if not adopted:
        raise RuntimeError("CURRENT_SOLUTION is empty. Adopt a gateway solution in Phase 1 first.")

    coords = [tuple(entry['coord']) for entry in sort_gateways(adopted, sort_by="lat")]
    if not coords:
        raise RuntimeError("Gateway list is empty after sorting; cannot continue.")
    return coords

# Header preview: try to load the adopted gateways now. A failure is tolerated here
# (the list stays empty); the run callback fetches them again and reports errors itself.
gateways = []
try:
    gateways = get_adopted_gateways_or_fail()
except RuntimeError as err:
    print(f"⚠️ {err}")

# -------- UI Controls --------
num_sensors_slider = IntSlider(
    value=12,
    min=3,
    max=max_sensors,
    step=1,
    description="Num Sensors:",
)
optimization_method = Dropdown(
    options=[('Weighted (MIP)', 'mip'), ('Pareto (NSGA-III)', 'nsga3')],
    value='mip',
    description='Method:',
)
assignment_mode = Dropdown(
    options=[('Global', 'global'), ('Per-Gateway', 'partitioned')],
    value='global',
    description='Assignment:',
)

# (display label, candidate-DataFrame column) pairs for the NSGA-III objective picker.
sensor_objectives = [
    ('Solar', 'Solar'),
    ('Elevation', 'Elevation'),
    ('Multi-Connected', 'Num_Connectable'),
    ('PathLoss', 'Best_PathLoss_dB'),
]
obj_select = SelectMultiple(
    options=sensor_objectives,
    value=tuple(column for _, column in sensor_objectives),  # every objective pre-selected
    description='NSGA-III Objectives',
    layout=Layout(width='300px', height='110px'),
)

# Scoring weights. NOTE(review): w_coverage is displayed but get_sensor_weights() does
# not read it, so it currently has no effect on the score.
w_coverage = FloatSlider(
    value=0.3, min=0, max=1, step=0.05, description='Coverage:', style={'description_width': '100px'}
)
w_solar = FloatSlider(
    value=0.2, min=0, max=1, step=0.05, description='Solar:', style={'description_width': '100px'}
)
w_elevation = FloatSlider(
    value=0.2, min=0, max=1, step=0.05, description='Elevation:', style={'description_width': '100px'}
)
w_multiconn = FloatSlider(
    value=0.2, min=0, max=1, step=0.05, description='Multi-Conn:', style={'description_width': '100px'}
)
w_pathloss = FloatSlider(
    value=0.1, min=0, max=1, step=0.05, description='PathLoss:', style={'description_width': '100px'}
)

def normalize_df_columns(df, cols):
    """
    Min-max scale each listed column of ``df`` into a new ``<col>_norm`` column.

    Columns missing from ``df`` are skipped. NaNs are ignored when computing the
    range and stay NaN in the output. A column with no spread (constant values,
    all NaN, or no rows at all) gets an all-zero float ``_norm`` column.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to annotate; the ``_norm`` columns are written into it in place.
    cols : iterable of str
        Column names to normalize.

    Returns
    -------
    pandas.DataFrame
        The same ``df`` object, so callers can write ``df = normalize_df_columns(df, cols)``.
    """
    for col in cols:
        if col not in df.columns:
            continue
        vals = np.asarray(df[col].values, dtype=float)
        present = vals[~np.isnan(vals)]
        # np.nanmin/np.nanmax raise on empty input, so only take the range when data exists.
        if present.size and present.max() > present.min():
            lo, hi = present.min(), present.max()
            df[f"{col}_norm"] = (vals - lo) / (hi - lo)
        else:
            df[f"{col}_norm"] = np.zeros(len(vals))
    return df

def get_sensor_weights():
    """
    Snapshot the current weight-slider values as ``{candidate column: weight}``.

    NOTE(review): the 'Coverage' slider (w_coverage) is not part of this mapping,
    so moving it does not change any score.
    """
    column_sliders = (
        ('Solar', w_solar),
        ('Elevation', w_elevation),
        ('Num_Connectable', w_multiconn),
        ('Best_PathLoss_dB', w_pathloss),
    )
    return {column: slider.value for column, slider in column_sliders}

output = Output()  # pane the optimization callback prints and displays into
run_btn = Button(description='Run Sensor Optimization', button_style='success')

# --------- Helper Functions ----------
def normalize_col(df, col):
    """
    Return ``df[col]`` min-max scaled to [0, 1] as a float NumPy array.

    NaN-aware, matching ``normalize_df_columns``: NaNs are ignored when computing
    the range and remain NaN in the result, so callers can filter those rows out
    (e.g. with ``np.isfinite`` on a combined score). The previous ``np.max``/``np.min``
    version turned the whole column into zeros as soon as one NaN was present.
    A column with no spread (constant, all NaN, or empty) yields all zeros.
    """
    vals = np.asarray(df[col].values, dtype=float)
    present = vals[~np.isnan(vals)]
    if present.size == 0 or not present.max() > present.min():
        return np.zeros(len(vals))
    lo, hi = present.min(), present.max()
    return (vals - lo) / (hi - lo)

def get_coverage_union(selected_locs, coverage_grid, sensor_radius):
    """
    Percentage of coverage-grid points lying within ``sensor_radius`` meters of any sensor.

    Parameters
    ----------
    selected_locs : iterable of (lon, lat)
        Sensor positions (a list of tuples or an (n, 2) array).
    coverage_grid : sequence of (x, y)
        Grid points read as (lon, lat); a list or an (m, 2) NumPy array.
    sensor_radius : float
        Coverage radius in meters (geodesic distance).

    Returns
    -------
    float
        Covered share of the grid in percent; 0.0 for an empty grid.
    """
    n_grid = len(coverage_grid)
    # len() rather than truthiness: an array-valued grid has an ambiguous truth value.
    if n_grid == 0:
        return 0.0
    sensors = [(lon, lat) for lon, lat in selected_locs]
    covered = 0
    for x, y in coverage_grid:
        # any() short-circuits at the first sensor in range, so each point is counted
        # once and no further geodesic calls are spent on an already-covered point.
        if any(geodesic((lat, lon), (y, x)).meters <= sensor_radius for lon, lat in sensors):
            covered += 1
    return 100.0 * covered / n_grid

def plot_selected_sensors_map(aoi_poly, gateways, sensors_df, gateway_colors=None, gateway_col='Best_Gateway'):
    """
    Render the AOI, gateways and selected sensors on an ipyleaflet satellite map.

    Parameters
    ----------
    aoi_poly : shapely geometry
        Area of interest; its centroid centres the map and its outline is drawn.
    gateways : sequence of (lon, lat)
        Gateway coordinates, indexed by the integers found in ``gateway_col``.
    sensors_df : pandas.DataFrame
        Needs 'Latitude' and 'Longitude', plus ``gateway_col`` (falls back to 'Best_Gateway').
    gateway_colors : mapping or sequence, optional
        Colour per gateway index. Indices it does not cover (or ``None``) fall back
        to a cycling default palette instead of raising.
    gateway_col : str
        Column holding each sensor's gateway index.

    Returns
    -------
    ipyleaflet.Map
        Map with the AOI outline, one marker per gateway, and per sensor a marker,
        a SENSOR_RADIUS disk and a line to its gateway.
    """
    palette = ['red', 'green', 'blue', 'purple', 'orange', 'black', 'magenta']

    def color_for(gw_idx):
        # A caller-supplied colour map can be shorter than the gateway list
        # (e.g. a fixed 7-entry dict); fall back to the palette rather than fail.
        if gateway_colors is not None:
            try:
                return gateway_colors[gw_idx]
            except (KeyError, IndexError):
                pass
        return palette[gw_idx % len(palette)]

    m = Map(center=(aoi_poly.centroid.y, aoi_poly.centroid.x), zoom=13, basemap=basemaps.Esri.WorldImagery)
    m.add_layer(GeoJSON(data=aoi_poly.__geo_interface__, style={
        'color': 'yellow', 'fillOpacity': 0.03, 'weight': 2
    }))
    # Gateways
    for idx, gw in enumerate(gateways):
        lat, lon = gw[1], gw[0]
        gw_color = color_for(idx)
        m.add_layer(CircleMarker(location=(lat, lon), radius=10, color=gw_color, fill_color=gw_color, fill_opacity=0.8))
    # Sensors, their coverage disks, and links to the assigned gateway
    for _, row in sensors_df.iterrows():
        lat, lon = row['Latitude'], row['Longitude']
        gw_idx = int(row[gateway_col]) if gateway_col in row else int(row['Best_Gateway'])
        color = color_for(gw_idx)
        m.add_layer(CircleMarker(location=(lat, lon), radius=5, color=color, fill_color=color, fill_opacity=0.82))
        m.add_layer(Circle(location=(lat, lon), radius=SENSOR_RADIUS, color="#aaa", fill_color="#aaa", fill_opacity=0.12))
        gw = gateways[gw_idx]
        m.add_layer(Polyline(locations=[(lat, lon), (gw[1], gw[0])], color=color, weight=2, opacity=0.7))
    return m

def dist_m(p, q):
    """Geodesic distance in meters between two points given as (lon, lat)."""
    p_lat_lon = (p[1], p[0])
    q_lat_lon = (q[1], q[0])
    return geodesic(p_lat_lon, q_lat_lon).meters

def ok_vs_global(P, global_list, min_sep):
    """Return True when point P (lon, lat) is at least ``min_sep`` meters from every point in ``global_list``."""
    for Q in global_list:
        # Written as "not >=" so a NaN distance counts as a violation.
        if not dist_m(P, Q) >= min_sep:
            return False
    return True

def feasible_indices_from_result(res, tol=1e-6):
    """
    Row indices of the feasible solutions in a pymoo ``minimize`` result.

    Priority: the constraint-violation array ``CV`` if present; otherwise ``G``,
    where a row is feasible when every constraint is <= ``tol``; otherwise every
    row of ``X`` is assumed feasible (none when ``X`` is missing).
    """
    cv = getattr(res, "CV", None)
    if cv is not None:
        return np.nonzero(cv <= tol)[0]

    g = getattr(res, "G", None)
    if g is not None:
        within = g <= tol
        mask = within if g.ndim == 1 else within.all(axis=1)
        return np.nonzero(mask)[0]

    x = getattr(res, "X", None)
    return np.arange(0 if x is None else len(x))

# --------- Main Callback: With Debug & Generation Diagnostics ----------
def on_sensor_opt_run(b):
    """
    Button callback: run Phase 2 sensor placement and render the results into ``output``.

    Reads the UI widgets (sensor count, method, assignment mode, weights, NSGA-III
    objectives), the global candidate table ``df``, ``aoi_poly`` and the adopted
    Phase 1 gateways. Two assignment modes:

    * ``'global'``: choose N sensors from all connected candidates, with pairwise
      spacing of at least ``min_sensor_separation`` meters (MIP or NSGA-III).
    * ``'partitioned'``: split N into per-gateway quotas (base + remainder; the
      remainder goes to the gateways with the most nearest candidates). Each
      gateway is solved separately, excluding candidates that would overlap sensors
      already chosen for earlier gateways. Unmet slots are then greedily backfilled.

    The final sensor table is stored in the module-level ``selected`` so later
    cells (the 3D terrain cell checks ``'selected' in globals()``) can reuse it.

    Parameters
    ----------
    b : ipywidgets.Button
        The clicked button (unused).
    """
    global selected  # publish the chosen sensors for downstream cells
    output.clear_output()
    with output:
        print("=== Phase 2: Sensor Placement Optimization ===")
        N = num_sensors_slider.value
        method = optimization_method.value
        assign_mode = assignment_mode.value

        # === Prepare the candidate DataFrame ===
        if 'df' not in globals() or df.empty:
            print("Sensor candidate DataFrame 'df' is missing. Run the connectivity analysis first.")
            return

        cand_df = df[df['Is_Connected']].reset_index(drop=True).copy()

        # Get gateway coordinates for assignment from the CURRENT (adopted) solution
        try:
            gateways = get_adopted_gateways_or_fail()
        except RuntimeError as e:
            print(f"❌ {e}")
            return

        # Check the pool size before any per-row work: with an empty pool, the
        # apply/normalization steps below would fail before reaching this message.
        if len(cand_df) < N:
            print(f"⚠️ Only {len(cand_df)} valid sensor candidates found. Lower 'Num Sensors'.")
            return

        def assign_to_nearest_gateway(row, gateways):
            dists = [geodesic((row['Latitude'], row['Longitude']), (gw[1], gw[0])).meters for gw in gateways]
            return int(np.argmin(dists))

        cand_df['Nearest_Gateway'] = cand_df.apply(assign_to_nearest_gateway, axis=1, gateways=gateways)

        # Coverage grid (for a MultiPolygon AOI, only its largest part is used)
        from shapely.geometry import Polygon, MultiPolygon
        poly = aoi_poly
        if not isinstance(poly, Polygon):
            if hasattr(poly, 'geoms'):
                poly = max(poly.geoms, key=lambda p: p.area)
            else:
                print("AOI geometry is not a Polygon or MultiPolygon.")
                return
        coverage_grid = generate_coverage_grid(poly, 60)

        # Normalize features
        norm_cols = ['Solar', 'Elevation', 'Num_Connectable', 'Best_PathLoss_dB']
        cand_df = normalize_df_columns(cand_df, norm_cols)

        # --- GENERATION EVAL: Candidate Pool Stats ---
        print("========== SANITY CHECK: CANDIDATE GENERATION ==========")
        print(f"Total candidate sensors: {len(cand_df)}")
        print(f"Columns: {cand_df.columns.tolist()}")
        print("First 3 candidates:\n", cand_df[['Longitude', 'Latitude', 'Num_Connectable']].head(3))
        multi_conn_count = (cand_df['Num_Connectable'] > 1).sum()
        print(f"Multi-connectable candidates: {multi_conn_count} ({multi_conn_count/len(cand_df)*100:.1f}%)")
        if 'Best_Gateway' in cand_df.columns:
            print("Candidate count per gateway:")
            print(cand_df['Best_Gateway'].value_counts().sort_index())

        # == Gateway data ==
        num_gateways = len(gateways)
        if num_gateways == 0:
            print("No gateways available. Adopt a solution in Phase 1 and try again.")
            return
        # One colour per gateway, cycling the palette, so more than 7 gateways still map to a colour.
        palette = ['red', 'green', 'blue', 'purple', 'orange', 'black', 'magenta']
        gateway_colors = {i: palette[i % len(palette)] for i in range(num_gateways)}

        # ====== PARTITIONED PER-GATEWAY SENSOR ASSIGNMENT ======
        if assign_mode == "partitioned":
            print(f"\n=== Per-Gateway Assignment (non-overlapping across gateways) ===")
            if num_gateways == 1:
                print("ℹ️ Only one gateway; partitioned mode reduces to global selection for that gateway.")

            # ---------- Quota with remainder ----------
            base = N // num_gateways
            rem  = N %  num_gateways
            sizes = [(g, int((cand_df['Nearest_Gateway'] == g).sum())) for g in range(num_gateways)]
            quota = {g: base for g in range(num_gateways)}
            for g, _ in sorted(sizes, key=lambda x: x[1], reverse=True)[:rem]:
                quota[g] += 1
            total_quota = sum(quota.values())
            print(f"Requested N={N}. Quotas by gateway (base+remainder): {quota}  -> total {total_quota}")

            # ---------- weights for scoring ----------
            weights = get_sensor_weights()

            # ---------- global non-overlap tracking ----------
            global_selected_pts = []       # list of (lon, lat)
            global_selected_rows = []      # list of DataFrames to concat
            unmet = 0                      # how many slots we couldn't fill in pass 1

            # ---------- PASS 1: per-gateway selection with global filtering ----------
            for gw_idx, gw in enumerate(gateways):
                need = quota[gw_idx]
                if need <= 0:
                    continue
                local_cands = cand_df[cand_df['Nearest_Gateway'] == gw_idx].reset_index(drop=True)

                # Drop candidates that would overlap already-chosen sensors
                if global_selected_pts:
                    keep_idx = []
                    for i, r in local_cands.iterrows():
                        P = (float(r['Longitude']), float(r['Latitude']))
                        if ok_vs_global(P, global_selected_pts, min_sensor_separation):
                            keep_idx.append(i)
                    local_cands = local_cands.loc[keep_idx].reset_index(drop=True)

                print(f"Gateway {gw_idx}: {int((cand_df['Nearest_Gateway']==gw_idx).sum())} local candidates,"
                      f" {len(local_cands)} remain after cross-gateway non-overlap filter. Selecting {need} sensors.")
                print("  Top 3 local candidates:")
                print(local_cands[['Longitude','Latitude','Num_Connectable']].head(3))
                local_multi = (local_cands['Num_Connectable'] > 1).sum()
                print(f"  Multi-connectable (local): {local_multi} / {len(local_cands)}")

                if len(local_cands) == 0 or need == 0:
                    unmet += need
                    continue

                # Prepare normalized score (normalized within this gateway's pool)
                for col in weights:
                    if col in local_cands.columns:
                        local_cands[f"{col}_norm"] = normalize_col(local_cands, col)
                local_cands['score'] = sum(local_cands.get(f"{col}_norm", 0) * weights[col] for col in weights)
                local_cands = local_cands[np.isfinite(local_cands['score'])].copy()
                indices = local_cands.index.tolist()
                if not indices:
                    # Every local candidate had a non-finite score; leave the slots to the backfill pass.
                    unmet += need
                    continue

                # If fewer remain than needed, we'll take as many as possible and backfill later
                target = min(need, len(indices))

                # ------ MIP ------
                if method == 'mip':
                    prob = pulp.LpProblem(f"SensorPlacement_GW{gw_idx}", pulp.LpMaximize)
                    x = pulp.LpVariable.dicts("sensor", indices, cat='Binary')
                    prob += pulp.lpSum([local_cands.loc[i, 'score'] * x[i] for i in indices]), "Total_Score"
                    prob += pulp.lpSum([x[i] for i in indices]) == target, "Select_Target"
                    # Non-overlap within this gateway's set
                    for i, j in combinations(indices, 2):
                        loc_i = (local_cands.loc[i, 'Longitude'], local_cands.loc[i, 'Latitude'])
                        loc_j = (local_cands.loc[j, 'Longitude'], local_cands.loc[j, 'Latitude'])
                        if geodesic((loc_i[1], loc_i[0]), (loc_j[1], loc_j[0])).meters < min_sensor_separation:
                            prob += x[i] + x[j] <= 1, f"Separation_{i}_{j}"
                    prob.solve()
                    if prob.status == pulp.LpStatusOptimal:
                        # Solver values are floats (may be 0.9999999 or None); threshold instead of == 1.
                        sel = [i for i in indices if (x[i].varValue or 0) > 0.5]
                        assigned = local_cands.loc[sel].copy()
                        assigned['Assigned_Gateway'] = gw_idx
                        # append to global selections and block their radius globally
                        for _, rr in assigned.iterrows():
                            global_selected_pts.append((float(rr['Longitude']), float(rr['Latitude'])))
                        global_selected_rows.append(assigned)
                        print(f"  Gateway {gw_idx} selected {len(sel)} sensors.")
                        if len(sel) < need:
                            unmet += (need - len(sel))
                    else:
                        print(f"❌ MIP Optimization failed for Gateway {gw_idx}. Status: {pulp.LpStatus[prob.status]}")
                        unmet += need

                # ------ NSGA-III ------
                elif method == 'nsga3':
                    selected_obj = list(obj_select.value)
                    ref_dirs = get_reference_directions("das-dennis", len(selected_obj), n_partitions=3)
                    class LocalSensorPlacementProblem(Problem):
                        def __init__(self, candidate_df, n_sensors):
                            self.candidate_df = candidate_df.reset_index(drop=True)
                            self.n_sensors = n_sensors
                            super().__init__(n_var=n_sensors, n_obj=len(selected_obj), n_constr=1,
                                             xl=0, xu=len(self.candidate_df)-1, elementwise_evaluation=True)
                        def _evaluate(self, x, out, *args, **kwargs):
                            Fs, Gs = [], []
                            x = np.atleast_2d(x)
                            for xi_num, xi in enumerate(x):
                                idxs = [int(i) for i in xi]
                                sensors = self.candidate_df.iloc[idxs]
                                coords = [(row['Longitude'], row['Latitude']) for _, row in sensors.iterrows()]
                                F = []
                                for obj in selected_obj:
                                    norm_obj = obj + "_norm"
                                    if norm_obj in sensors.columns:
                                        if obj == 'Best_PathLoss_dB':
                                            F.append(np.mean(sensors[norm_obj]))      # minimize pathloss
                                        else:
                                            F.append(-np.mean(sensors[norm_obj]))     # maximize others
                                    else:
                                        F.append(0)
                                # within-set separation
                                min_dist = float('inf')
                                for p1, p2 in combinations(coords, 2):
                                    d = geodesic((p1[1], p1[0]), (p2[1], p2[0])).meters
                                    if d < min_dist:
                                        min_dist = d
                                g1 = min_sensor_separation - min_dist
                                Fs.append(F)
                                Gs.append([g1])
                                if xi_num < 3:
                                    print(f"    [DEBUG-GW{gw_idx}] Indiv {xi_num}: idxs={idxs}, F={F}, min_dist={min_dist:.2f}, g1={g1:.2f}")
                            out["F"] = np.array(Fs)
                            out["G"] = np.array(Gs)
                    problem = LocalSensorPlacementProblem(local_cands, target)
                    algorithm = NSGA3(ref_dirs=ref_dirs)
                    try:
                        res = minimize(problem, algorithm, termination=('n_gen', 100), seed=gw_idx+1, verbose=True)
                    except Exception as e:
                        print(f"NSGA-III crashed for Gateway {gw_idx}:", str(e))
                        unmet += need
                        continue

                    # pymoo leaves X/F/CV as None when the final population has no acceptable solution.
                    if getattr(res, "X", None) is None:
                        print(f"No solution returned by NSGA-III for Gateway {gw_idx}.")
                        unmet += need
                        continue
                    X_all = np.atleast_2d(res.X)
                    feasible = feasible_indices_from_result(res, tol=1e-6)

                    if getattr(res, "CV", None) is not None:
                        print(f"    [DEBUG-GW{gw_idx}] CV stats: min={np.min(res.CV):.3f}, median={np.median(res.CV):.3f}, max={np.max(res.CV):.3f}")

                    if len(feasible) == 0:
                        if getattr(res, "CV", None) is not None:
                            cv_flat = np.ravel(res.CV)
                            best_idx = int(np.argmin(cv_flat))
                            # cast to float: formatting a 1-element array with :.3f raises TypeError
                            print(f"No feasible solutions for Gateway {gw_idx}. Taking least-infeasible (CV={float(cv_flat[best_idx]):.3f}).")
                        else:
                            best_idx = 0
                    else:
                        F_all = np.atleast_2d(res.F)
                        best_idx_local = int(np.argmin(np.linalg.norm(F_all[feasible], axis=1)))
                        best_idx = int(feasible[best_idx_local])

                    best_X = X_all[best_idx]
                    # Decoded variables can repeat a candidate, and a least-infeasible pick can
                    # break the spacing rule; keep only positions that respect separation against
                    # everything chosen so far (including earlier picks from this solution).
                    picked = []
                    for cand_pos in dict.fromkeys(int(i) for i in best_X):
                        rr = local_cands.iloc[cand_pos]
                        P = (float(rr['Longitude']), float(rr['Latitude']))
                        if ok_vs_global(P, global_selected_pts, min_sensor_separation):
                            picked.append(cand_pos)
                            global_selected_pts.append(P)
                    assigned = local_cands.iloc[picked].copy()
                    assigned['Assigned_Gateway'] = gw_idx
                    global_selected_rows.append(assigned)
                    print(f"  Gateway {gw_idx} selected {len(picked)} sensors.")
                    if len(picked) < need:
                        unmet += (need - len(picked))
                else:
                    print("Invalid method selected.")
                    return

            # ---------- PASS 2: backfill any unmet quota from any remaining candidates ----------
            if unmet > 0:
                print(f"\nBackfilling {unmet} remaining sensors from all gateways without overlap...")
                remaining = cand_df.copy()
                if global_selected_pts:
                    keep_idx = []
                    for i, r in remaining.iterrows():
                        P = (float(r['Longitude']), float(r['Latitude']))
                        if ok_vs_global(P, global_selected_pts, min_sensor_separation):
                            keep_idx.append(i)
                    remaining = remaining.loc[keep_idx].copy()
                if not remaining.empty:
                    backfill_weights = get_sensor_weights()
                    for col in backfill_weights:
                        if col in remaining.columns:
                            remaining[f"{col}_norm"] = normalize_col(remaining, col)
                    remaining['score'] = sum(remaining.get(f"{col}_norm", 0) * backfill_weights[col] for col in backfill_weights)
                    remaining = remaining.sort_values('score', ascending=False)
                    rows = []
                    for _, r in remaining.iterrows():
                        if unmet <= 0:
                            break
                        P = (float(r['Longitude']), float(r['Latitude']))
                        if ok_vs_global(P, global_selected_pts, min_sensor_separation):
                            r = r.copy()
                            r['Assigned_Gateway'] = int(r.get('Nearest_Gateway', r.get('Best_Gateway', 0)))
                            rows.append(r)
                            global_selected_pts.append(P)
                            unmet -= 1
                    if rows:
                        global_selected_rows.append(pd.DataFrame(rows))
                if unmet > 0:
                    print(f"Could not place {unmet} sensors due to non-overlap/global constraints.")

            # ---------- Final reporting ----------
            if global_selected_rows:
                selected = pd.concat(global_selected_rows, ignore_index=True)
                union_cov = get_coverage_union(selected[['Longitude', 'Latitude']].values, coverage_grid, SENSOR_RADIUS)
                print(f"\nPartitioned {method.upper()} Solution: {len(selected)} sensors")
                print(f"  - AOI coverage: {union_cov:.1f}% (union of all sensors)")
                display(selected[['Longitude', 'Latitude', 'Solar', 'Elevation', 'Num_Connectable', 'Best_PathLoss_dB', 'Assigned_Gateway']])
                display(plot_selected_sensors_map(aoi_poly, gateways, selected, gateway_colors=gateway_colors, gateway_col='Assigned_Gateway'))
            else:
                print("No sensors selected in partitioned mode.")

        # ====== GLOBAL (NON-PARTITIONED) MODE ======
        else:
            print("\n=== Global Assignment: All sensors selected from all candidates. ===")
            if 'Best_Gateway' in cand_df.columns:
                print("Global candidate gateway distribution:")
                print(cand_df['Best_Gateway'].value_counts().sort_index())
            print(f"Multi-connectable sensors (Num_Connectable > 1): {(cand_df['Num_Connectable'] > 1).sum()}")
            weights = get_sensor_weights()
            for col in weights:
                if col in cand_df.columns:
                    cand_df[f"{col}_norm"] = normalize_col(cand_df, col)
            cand_df['score'] = sum(cand_df.get(f"{col}_norm", 0) * weights[col] for col in weights)
            cand_df = cand_df[np.isfinite(cand_df['score'])].copy()
            indices = cand_df.index.tolist()

            if method == 'mip':
                print("Solving global MIP for best sensor set...")
                prob = pulp.LpProblem("SensorPlacement", pulp.LpMaximize)
                x = pulp.LpVariable.dicts("sensor", indices, cat='Binary')
                prob += pulp.lpSum([cand_df.loc[i, 'score'] * x[i] for i in indices]), "Total_Score"
                prob += pulp.lpSum([x[i] for i in indices]) == N, "Select_N_Sensors"
                # Non-overlap
                for i, j in combinations(indices, 2):
                    loc_i = (cand_df.loc[i, 'Longitude'], cand_df.loc[i, 'Latitude'])
                    loc_j = (cand_df.loc[j, 'Longitude'], cand_df.loc[j, 'Latitude'])
                    if geodesic((loc_i[1], loc_i[0]), (loc_j[1], loc_j[0])).meters < min_sensor_separation:
                        prob += x[i] + x[j] <= 1, f"Separation_{i}_{j}"
                prob.solve()
                if prob.status == pulp.LpStatusOptimal:
                    # Solver values are floats (may be 0.9999999 or None); threshold instead of == 1.
                    sel = [i for i in indices if (x[i].varValue or 0) > 0.5]
                    selected = cand_df.loc[sel]
                    union_cov = get_coverage_union(selected[['Longitude', 'Latitude']].values, coverage_grid, SENSOR_RADIUS)
                    print(f"MIP Solution: {len(selected)} sensors")
                    print(f"  - AOI coverage: {union_cov:.1f}% (union of all sensors)")
                    display(selected[['Longitude', 'Latitude', 'Solar', 'Elevation', 'Num_Connectable', 'Best_PathLoss_dB']])
                    display(plot_selected_sensors_map(aoi_poly, gateways, selected, gateway_colors=gateway_colors, gateway_col='Best_Gateway'))
                else:
                    print("MIP Optimization failed. Status:", pulp.LpStatus[prob.status])

            elif method == 'nsga3':
                print("Running global NSGA-III for sensor Pareto optimization...")
                selected_obj = list(obj_select.value)
                ref_dirs = get_reference_directions("das-dennis", len(selected_obj), n_partitions=4)
                class SensorPlacementProblem(Problem):
                    def __init__(self, candidate_df, n_sensors, objectives, coverage_grid, sensor_radius):
                        self.candidate_df = candidate_df.reset_index(drop=True)
                        self.n_sensors = n_sensors
                        self.objectives = objectives
                        self.coverage_grid = coverage_grid
                        self.sensor_radius = sensor_radius
                        super().__init__(n_var=n_sensors, n_obj=len(objectives), n_constr=1, xl=0, xu=len(self.candidate_df)-1, elementwise_evaluation=True)
                    def _evaluate(self, x, out, *args, **kwargs):
                        Fs, Gs = [], []
                        x = np.atleast_2d(x)
                        for xi_num, xi in enumerate(x):
                            idxs = [int(i) for i in xi]
                            sensors = self.candidate_df.iloc[idxs]
                            coords = [(row['Longitude'], row['Latitude']) for _, row in sensors.iterrows()]
                            F = []
                            for obj in self.objectives:
                                norm_obj = obj + "_norm"
                                if norm_obj in sensors.columns:
                                    if obj == 'Best_PathLoss_dB':
                                        F.append(np.mean(sensors[norm_obj]))  # minimize pathloss
                                    else:
                                        F.append(-np.mean(sensors[norm_obj])) # maximize all others
                                else:
                                    F.append(0)
                            min_dist = float('inf')
                            for p1, p2 in combinations(coords, 2):
                                dist = geodesic((p1[1], p1[0]), (p2[1], p2[0])).meters
                                if dist < min_dist:
                                    min_dist = dist
                            g1 = min_sensor_separation - min_dist
                            Fs.append(F)
                            Gs.append([g1])
                            if xi_num < 3:
                                print(f"    [DEBUG-GLOBAL] Indiv {xi_num}: idxs={idxs}, F={F}, min_dist={min_dist:.2f}, g1={g1:.2f}")
                        out["F"] = np.array(Fs)
                        out["G"] = np.array(Gs)
                problem = SensorPlacementProblem(
                    candidate_df=cand_df,
                    n_sensors=N,
                    objectives=selected_obj,
                    coverage_grid=coverage_grid,
                    sensor_radius=SENSOR_RADIUS
                )
                algorithm = NSGA3(ref_dirs=ref_dirs)
                try:
                    res = minimize(problem, algorithm, termination=('n_gen', 10), seed=1, verbose=True)
                except Exception as e:  # same policy as the per-gateway branch
                    print("NSGA-III crashed:", str(e))
                    print("Try lowering n_partitions, reducing objectives, or loosening constraints.")
                    return
                # pymoo leaves X/F/G as None when no feasible solution survives.
                if getattr(res, 'X', None) is None:
                    print("NSGA-III returned no solutions.")
                    print("No feasible solutions.")
                    return
                X_all = np.atleast_2d(res.X)
                F_all = np.atleast_2d(res.F)
                print(f"NSGA-III generations completed. Solution shape F: {F_all.shape}")
                feasible = feasible_indices_from_result(res, tol=1e-6)
                print(f"Feasible solutions found: {len(feasible)}")
                if len(feasible) == 0:
                    print("No feasible solutions.")
                    return
                best_idx = int(np.argmin(np.linalg.norm(F_all[feasible], axis=1)))
                best_X = X_all[feasible][best_idx]
                idxs = [int(i) for i in best_X]
                selected = cand_df.iloc[idxs]
                union_cov = get_coverage_union(selected[['Longitude', 'Latitude']].values, coverage_grid, SENSOR_RADIUS)
                print(f"NSGA-III Solution: {N} sensors")
                print(f"  - AOI coverage: {union_cov:.1f}% (union of all sensors)")
                display(selected[['Longitude', 'Latitude', 'Solar', 'Elevation', 'Num_Connectable', 'Best_PathLoss_dB']])
                display(plot_selected_sensors_map(aoi_poly, gateways, selected, gateway_colors=gateway_colors, gateway_col='Best_Gateway'))
            else:
                print("Invalid method selected.")
        # Inside the `with` so the message lands in the output pane, not the widget log.
        print("✅ Phase 2 complete.")

# Wire the button to the callback, then lay out the panel:
# run controls, weight sliders, NSGA-III objective picker, and the output pane.
run_btn.on_click(on_sensor_opt_run)

ui = VBox(
    [
        HBox([num_sensors_slider, optimization_method, assignment_mode, run_btn]),
        HBox([w_coverage, w_solar, w_elevation, w_multiconn, w_pathloss]),
        HBox([obj_select]),
        output,
    ]
)
display(HTML("<h3>Phase 2: Sensor Multi-Objective Placement Optimization</h3>"), ui)

In [None]:
# --- 3D Terrain + Gateways/Sensors + Radius Disks (Plotly, all axes in meters) ---
# pip install plotly rasterio shapely geopy pyproj

import numpy as np, pandas as pd, rasterio
from shapely.geometry import Polygon
from pyproj import Transformer, Geod
from geopy.distance import distance as geo_distance
import plotly.graph_objects as go

# ---------------------------
import os

# DEM location: set the DEM_PATH environment variable rather than editing a user-specific
# absolute path; the previous hard-coded location remains the default.
DEM_PATH = os.environ.get("DEM_PATH", "/Users/chrisjuarez/moo_node/palisadesoutput.dsm.tif")
SENSOR_RADIUS_M = int(globals().get("SENSOR_RADIUS", 500))  # reuse the Phase 2 radius when defined
LINK_HEIGHT_M   = 15
MAX_GRID = 600
# ---------------------------

# Pick the sensors table to show: the Phase 2 selection if one exists, else all connected candidates.
if 'selected' in globals():
    sensors_df = selected.copy()
else:
    sensors_df = df[df['Is_Connected']].copy()

# Choose the gateway column from the table actually being plotted: per-gateway runs add
# 'Assigned_Gateway' to their selection, otherwise fall back to 'Best_Gateway'.
SENSOR_GATEWAY_COL = "Assigned_Gateway" if "Assigned_Gateway" in sensors_df.columns else "Best_Gateway"

# ---- DEM helpers (meters) ----
def open_dem_tools(dem_path):
    """Open a DEM raster and return helpers that work in the DEM's own CRS.

    The dataset is opened here and handed back to the caller, which owns it
    and should call ``ds.close()`` once the helpers are no longer needed.

    NOTE(review): the DEM CRS is assumed to be projected in metres (the 3D
    view labels every axis in metres); confirm for new DEMs.

    Parameters
    ----------
    dem_path : str
        Path to an elevation raster readable by rasterio (band 1 is used).

    Returns
    -------
    tuple
        ``(ds, lonlat_to_demxy, sample_elev_lola, terrain_surface_for_aoi_xy,
        circle_on_terrain_xy, air_path_xy)``. All helpers return x/y in the
        DEM CRS; elevations are band-1 values with the nodata sentinel
        replaced by NaN.
    """
    ds = rasterio.open(dem_path)
    to_dem = Transformer.from_crs("EPSG:4326", ds.crs, always_xy=True)   # lon/lat -> dem(x,y)
    geod = Geod(ellps="WGS84")
    nodata = ds.nodata

    def _nodata_to_nan(values):
        # Float copy of `values` with the DEM's nodata sentinel turned into NaN,
        # so missing cells vanish from plots instead of drawing spikes/pits.
        out = np.asarray(values, dtype=float)
        if nodata is not None:
            out = np.where(out == nodata, np.nan, out)
        return out

    def lonlat_to_demxy(lons, lats):
        """Project WGS84 lon/lat (scalars or sequences) into DEM-CRS x/y."""
        return to_dem.transform(np.asarray(lons), np.asarray(lats))

    def sample_elev_lola(lons, lats):
        """Sample band 1 at lon/lat points; nodata or empty samples become NaN."""
        xs, ys = lonlat_to_demxy(lons, lats)
        xs, ys = np.atleast_1d(xs), np.atleast_1d(ys)  # allow scalar inputs too
        vals = [v[0] if len(v) else np.nan for v in ds.sample(zip(xs, ys))]
        return _nodata_to_nan(vals)

    def terrain_surface_for_aoi_xy(aoi, max_grid=600):
        """Return ``(X, Y, Z)`` grids covering the AOI's lon/lat bounding box.

        X/Y are pixel-centre coordinates in the DEM CRS and Z the band-1
        values (nodata -> NaN), all shaped ``(rows, cols)``. The window is
        decimated by an integer stride so neither axis much exceeds
        ``max_grid`` samples. For a MultiPolygon only the largest part is used.

        Raises
        ------
        ValueError
            If ``aoi`` is not a (Multi)Polygon or does not overlap the DEM.
        """
        if not isinstance(aoi, Polygon) and not hasattr(aoi, "geoms"):
            raise ValueError("AOI must be Polygon or MultiPolygon.")
        if hasattr(aoi, "geoms"):
            aoi = max(aoi.geoms, key=lambda p: p.area)

        # aoi.bounds are lon/lat, but ds.index() expects DEM-CRS coordinates.
        # Project a grid spanning the lon/lat box and take its extent, which
        # keeps the whole box even though its edges bend under projection.
        minx, miny, maxx, maxy = aoi.bounds
        grid_lon, grid_lat = np.meshgrid(np.linspace(minx, maxx, 21), np.linspace(miny, maxy, 21))
        px, py = lonlat_to_demxy(grid_lon.ravel(), grid_lat.ravel())
        x0, x1 = float(np.min(px)), float(np.max(px))
        y0, y1 = float(np.min(py)), float(np.max(py))

        # ds.index returns (row, col). Checking all four corners keeps the
        # window right whatever the raster's axis orientation.
        corner_rc = np.array([ds.index(x, y) for x in (x0, x1) for y in (y0, y1)])
        rmin, cmin = corner_rc.min(axis=0)
        rmax, cmax = corner_rc.max(axis=0)
        if rmax < 0 or cmax < 0 or rmin >= ds.height or cmin >= ds.width:
            raise ValueError("AOI does not overlap the DEM extent.")
        rmin, rmax = int(np.clip(rmin, 0, ds.height - 1)), int(np.clip(rmax, 0, ds.height - 1))
        cmin, cmax = int(np.clip(cmin, 0, ds.width - 1)), int(np.clip(cmax, 0, ds.width - 1))

        width  = cmax - cmin + 1
        height = rmax - rmin + 1
        stride = int(np.ceil(max(width, height) / max_grid)) if max(width, height) > max_grid else 1

        window = rasterio.windows.Window(cmin, rmin, width, height)
        Z = _nodata_to_nan(ds.read(1, window=window)[::stride, ::stride])

        # Pixel centres (index + 0.5) mapped through the affine geotransform.
        # The meshgrid of (cols, rows) has the same (rows, cols) shape as Z.
        cols = np.arange(cmin, cmax + 1, stride) + 0.5
        rows = np.arange(rmin, rmax + 1, stride) + 0.5
        col_grid, row_grid = np.meshgrid(cols, rows)
        t = ds.transform
        X_map = t.a * col_grid + t.b * row_grid + t.c
        Y_map = t.d * col_grid + t.e * row_grid + t.f
        return X_map, Y_map, Z

    def circle_on_terrain_xy(lon, lat, radius_m, n=120):
        """Build a terrain-draped disk of ``radius_m`` metres around (lon, lat).

        Returns ``(xs, ys, zs, i, j, k)`` for a Mesh3d triangle fan. Vertex 0
        is the centre; vertices 1..n+1 are the geodesic ring, closed by
        repeating its first point.
        """
        bearings = np.linspace(0, 360, n, endpoint=False)
        pts = [geo_distance(meters=radius_m).destination((lat, lon), b) for b in bearings]
        lats = [p.latitude for p in pts] + [pts[0].latitude]
        lons = [p.longitude for p in pts] + [pts[0].longitude]
        elev = sample_elev_lola(lons, lats)
        x_ring, y_ring = lonlat_to_demxy(lons, lats)
        xc, yc = lonlat_to_demxy([lon], [lat])
        zc = float(sample_elev_lola([lon], [lat])[0])
        xs = np.concatenate([xc, x_ring])
        ys = np.concatenate([yc, y_ring])
        zs = np.concatenate([[zc], elev])
        # Fan triangles (centre, t, t+1) for t = 1..len(ring)-1.
        t = np.arange(1, len(x_ring))
        return xs, ys, zs, np.zeros_like(t), t, t + 1

    def air_path_xy(lon1, lat1, lon2, lat2, n=50, h_above=15):
        """Return ``n`` (at least 2) geodesic points from point 1 to point 2.

        Points are in DEM-CRS x/y, with z set to the sampled ground elevation
        plus ``h_above``.
        """
        n = max(2, int(n))
        mid = geod.npts(lon1, lat1, lon2, lat2, n-2) if n > 2 else []
        lons = [lon1] + [p[0] for p in mid] + [lon2]
        lats = [lat1] + [p[1] for p in mid] + [lat2]
        ground = sample_elev_lola(lons, lats)
        X, Y = lonlat_to_demxy(lons, lats)
        return X, Y, ground + h_above

    return ds, lonlat_to_demxy, sample_elev_lola, terrain_surface_for_aoi_xy, circle_on_terrain_xy, air_path_xy

# --- Build figure in meters ---
ds, lonlat_to_demxy, sample_elev_lola, terrain_xy, circle_xy, air_xy = open_dem_tools(DEM_PATH)
Xg, Yg, Zg = terrain_xy(aoi_poly, max_grid=MAX_GRID)

fig = go.Figure()

# Terrain surface (all meters)
fig.add_surface(x=Xg, y=Yg, z=Zg, showscale=False, colorscale="Earth", opacity=0.98)

# Colors: one palette entry per gateway index, cycling when there are more
# gateways than colors. Sensors whose gateway value is missing or out of range
# are drawn in UNASSIGNED_COLOR and get no link or disk.
PALETTE = ["#e41a1c","#377eb8","#4daf4a","#984ea3","#ff7f00","#a65628","#f781bf","#999999"]
UNASSIGNED_COLOR = "#cccccc"
gw_colors = {i: PALETTE[i % len(PALETTE)] for i in range(len(gateways))}


def _gateway_index(value):
    """Return ``value`` as a valid index into ``gateways``, else None.

    None is returned for missing values (None/NaN), non-numeric values and
    indices outside ``range(len(gateways))``.
    """
    try:
        idx = int(value)
    except (TypeError, ValueError):  # int(NaN) -> ValueError, int(None) -> TypeError
        return None
    return idx if 0 <= idx < len(gateways) else None


# Gateways
gw_lons = [g[0] for g in gateways]
gw_lats = [g[1] for g in gateways]
gw_x, gw_y = lonlat_to_demxy(gw_lons, gw_lats)
gw_z = sample_elev_lola(gw_lons, gw_lats)
fig.add_trace(go.Scatter3d(
    x=gw_x, y=gw_y, z=gw_z+5, mode="markers+text",
    marker=dict(size=6, color=[gw_colors[i] for i in range(len(gateways))]),
    text=[f"GW{i}" for i in range(len(gateways))], textposition="top center",
    name="Gateways"
))

# Sensors, colored by their validated gateway index
sns_lons = sensors_df["Longitude"].to_numpy()
sns_lats = sensors_df["Latitude"].to_numpy()
sns_x, sns_y = lonlat_to_demxy(sns_lons, sns_lats)
sns_z = sample_elev_lola(sns_lons, sns_lats)
sns_gw_idx = [_gateway_index(g) for g in sensors_df[SENSOR_GATEWAY_COL].to_numpy()]
fig.add_trace(go.Scatter3d(
    x=sns_x, y=sns_y, z=sns_z+3, mode="markers",
    marker=dict(size=3, color=[gw_colors.get(g, UNASSIGNED_COLOR) for g in sns_gw_idx]),
    name="Sensors"
))

# Links + radius disks
for _, r in sensors_df.iterrows():
    lon, lat = float(r["Longitude"]), float(r["Latitude"])
    gw_idx = _gateway_index(r.get(SENSOR_GATEWAY_COL, 0))
    if gw_idx is None or np.isnan(lon) or np.isnan(lat):
        continue
    gw_lon, gw_lat = gateways[gw_idx]
    link_color = gw_colors[gw_idx]

    # Air path in meters, LINK_HEIGHT_M above the sampled terrain
    path_x, path_y, path_z = air_xy(lon, lat, gw_lon, gw_lat, n=50, h_above=LINK_HEIGHT_M)
    fig.add_trace(go.Scatter3d(
        x=path_x, y=path_y, z=path_z, mode="lines",
        line=dict(width=2, color=link_color), opacity=0.7, showlegend=False
    ))

    # Draped sensor disk in meters. Do not name these `cx`/`cy`/`cz`: `cx` is
    # the contextily alias imported in the setup cell and would be clobbered.
    disk_x, disk_y, disk_z, tri_i, tri_j, tri_k = circle_xy(lon, lat, SENSOR_RADIUS_M, n=80)
    fig.add_trace(go.Mesh3d(
        x=disk_x, y=disk_y, z=disk_z, i=tri_i, j=tri_j, k=tri_k,
        color=link_color, opacity=0.15, name="Radius",
        hoverinfo="skip", lighting=dict(ambient=0.8, diffuse=0.5)
    ))

# Layout: all axes share meter units -> aspectmode 'data' is now correct
fig.update_layout(
    title="WSN 3D View — Terrain + Gateways/Sensors + Air Paths + Terrain-draped Radius",
    showlegend=True, legend=dict(x=0.02, y=0.98),
    scene=dict(
        xaxis_title="Easting (m)", yaxis_title="Northing (m)", zaxis_title="Elevation (m)",
        aspectmode="data",
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        zaxis=dict(showgrid=False)
    ),
    margin=dict(l=0, r=0, t=40, b=0),
)
fig.show()