# Drainage

We need to salute the lady of the lake

In [None]:
#| default_exp water/watershed

## Prior Art

In [None]:
#| export
import numpy as np
import sys
import os
import math
import random

#data
from collections import namedtuple
from dataclasses import dataclass,  field, asdict
from typing import List
from enum import Enum
import heapq

#Jeremy
from dialoghelper import * 
from fastcore.basics import patch
from fasthtml.common import *
from fasthtml.jupyter import *
import httpx

#custom
import copy

In [None]:
#| export
from pathlib import Path
sys.path.insert(0, str(Path().resolve().parent.parent))

In [None]:
#| export
from HexMagic.styles import StyleCSS, SVGBuilder, SVGLayer, SVGPatternLoader, preview, app, StyleDemo, LayerAnimation
from HexMagic.primitives import MapCord, MapSize, MapRect, MapPath, Hex, HexGrid, HexWrapper, HexPosition, hexBackground, HexRegion, unique_windy_edge
from HexMagic.terrain import  TerraDemo, Terrain, GeoBounds, ClimatePreset
from HexMagic.terrainpatterns import TerrainPatterns, SVGMask
from HexMagic.water.soil import SoilSystem, SoilType
from HexMagic.water.river import River, RiverDemo
from HexMagic.weather import TerraDemo

In [None]:
#    precipitation_elevation_summary() - analysis for river planning
    

## Watershed

In [None]:
#| export

class Watershed:
    """A drainage basin: a region of hexes plus the river network draining it.

    Attributes:
        region: HexRegion of hexes belonging to the basin.
        tributary: River network that drains the basin.
        style: StyleCSS used to render the basin.
        terrain: The river's terrain (``tributary.terrain``).
        system: SoilSystem consulted for runoff calculations.
    """

    def __init__(self, region: HexRegion, tributary: River, style: StyleCSS, system: SoilSystem = None):
        """Create a watershed.

        Args:
            region: Hexes that drain into ``tributary``.
            tributary: River network; its ``terrain`` is shared by the watershed.
            style: Rendering style for the basin.
            system: Soil model; when None, ``SoilSystem.from_plates(terrain, [])``
                is built from the river's terrain.
        """
        self.region = region
        self.tributary = tributary
        self.style = style
        self.terrain = tributary.terrain
        if system is None:
            system = SoilSystem.from_plates(self.terrain, [])
        self.system = system

    @property
    def terminal_hex(self) -> "int | None":
        """Index of the lowest-elevation hex in the region, or None if it is empty.

        Regions built by this class include the river's hexes, so for an
        ocean-draining basin this is usually the ocean outlet hex itself.
        """
        hexes = self.region.hexes
        if not hexes:
            return None
        elevations = self.tributary.terrain.elevations
        # min() keeps the first hex among equal minima, like a strict '<' scan.
        return min(hexes, key=lambda h: elevations[h])

    @property
    def is_ocean(self) -> bool:
        """True if this watershed drains to the ocean.

        Uses the river's explicit ``ocean_outlet`` when set; otherwise checks
        whether any river hex is at or below sea level (elevation <= 0).
        """
        if self.tributary.ocean_outlet is not None:
            return True
        elevations = self.terrain.elevations
        return any(elevations[h] <= 0 for h in self.tributary.hexes)

    @staticmethod
    def _trace_downhill(terrain, start):
        """Follow steepest descent from ``start``; return ``(path, ocean_outlet)``.

        ``path`` starts at ``start`` and runs downhill. The walk stops when:
          * ``terrain.lowest_neighbor`` returns None (local minimum);
          * a hex with elevation <= 0 is reached (it is appended to ``path``
            and returned as ``ocean_outlet``); or
          * the next hex was already visited (guards against endless cycles
            if ``lowest_neighbor`` ever returns an equal-height neighbour).
        ``ocean_outlet`` is None unless the ocean was reached.
        """
        path = [start]
        seen = {start}
        ocean_outlet = None
        current = start
        while True:
            lowest = terrain.lowest_neighbor(current)
            if lowest is None or lowest in seen:
                break
            path.append(lowest)
            if terrain.elevations[lowest] <= 0:
                ocean_outlet = lowest
                break
            seen.add(lowest)
            current = lowest
        return path, ocean_outlet

    @staticmethod
    def _river_from_path(terrain, path, ocean_outlet):
        """Wrap a source->mouth path in a single-segment River rooted at the mouth.

        The segment (node id 0) stores the path reversed, i.e. outlet first.
        """
        river = River(terrain)
        river.tree.create_node(tag="segment", identifier=0, data=list(reversed(path)))
        river.hexes.update(path)
        river.ocean_outlet = ocean_outlet
        return river

    @classmethod
    def river_peak(cls, terrain, peak_index):
        """Create a Watershed whose region is the river traced downhill from a peak.

        Returns:
            A Watershed with a placeholder grey style, or None if the peak has
            no downhill neighbour (path shorter than 2 hexes).
        """
        path, ocean_outlet = cls._trace_downhill(terrain, peak_index)
        if len(path) < 2:
            return None

        river = cls._river_from_path(terrain, path, ocean_outlet)
        region = HexRegion(hexes=river.hexes.copy(), hexGrid=terrain.hexGrid)
        style = StyleCSS("default", fill="#cccccc")  # Placeholder, assigned later
        return cls(region=region, tributary=river, style=style)

    @classmethod
    def from_peak(cls, terrain, peak_index):
        """Create a River (not a Watershed) by tracing downhill from a peak.

        The path includes the first ocean hex reached, stored as
        ``river.ocean_outlet``.

        Returns:
            The River, or None if the path is shorter than 2 hexes.
        """
        path, ocean_outlet = cls._trace_downhill(terrain, peak_index)
        if len(path) < 2:
            return None
        return cls._river_from_path(terrain, path, ocean_outlet)

    @classmethod
    def compute_all(cls, terrain: Terrain, num_peaks: int = 50, min_height: int = 1, debug: bool = False) -> list['Watershed']:
        """Partition the terrain's land into watersheds.

        1. Ask ``terrain.find_river_sources`` for up to ``num_peaks`` sources
           and trace a river from each with ``River.from_peak``.
        2. Merge intersecting rivers with ``River.combine_rivers``.
        3. Walk steepest descent from every land hex until it reaches a river
           hex. Walks that stop at a local minimum, reach the ocean first, or
           exceed 100 steps leave the hex unassigned.
        4. Build one Watershed per merged river from its hexes plus the land
           hexes assigned to it, coloured from the seaborn "tab20" palette.

        Args:
            terrain: Terrain with elevations, climate and a hex grid.
            num_peaks: Maximum number of river sources (``top_n``).
            min_height: Currently unused; the source search uses a fixed
                ``min_elevation=100``.
            debug: Print per-step diagnostics.

        Returns:
            One Watershed per merged river, in ``combine_rivers`` order.
        """
        # Step 1: Find peaks and create rivers
        sources = terrain.find_river_sources(
            min_precipitation=terrain.climate.precip_bins[0],  # Use first bin
            min_elevation=100,
            top_n=num_peaks,
        )
        peaks = [idx for idx, _precip, _elev, _score in sources]

        rivers = [River.from_peak(terrain, peak) for peak in peaks]
        rivers = [r for r in rivers if r is not None]

        if debug:
            print(f"\n=== STEP 1: Created {len(rivers)} rivers from {len(peaks)} peaks ===")
            for i, river in enumerate(rivers[:5]):  # Show first 5
                print(f"  River {i}: {len(river.hexes)} hexes, outlet={river.ocean_outlet}")

        # Step 2: Merge intersecting rivers
        merged_rivers = River.combine_rivers(rivers)

        if debug:
            print(f"\n=== STEP 2: Merged to {len(merged_rivers)} rivers ===")
            for i, river in enumerate(merged_rivers):
                outlet_type = "ocean" if river.ocean_outlet is not None else "lake/none"
                print(f"  River {i}: {len(river.hexes)} hexes, outlet={outlet_type}")

        # Step 3: For each land hex, trace downhill to find which river it flows to.
        # Precompute hex -> index of the FIRST merged river containing it, so each
        # step of a walk is one dict lookup instead of a scan over every river.
        river_of_hex = {}
        for river_idx, river in enumerate(merged_rivers):
            for h in river.hexes:
                river_of_hex.setdefault(h, river_idx)

        elevations = terrain.elevations
        hex_to_river = {}  # land hex -> river index
        unassigned = []    # (hex, reason, steps walked) for hexes reaching no river

        for i in range(len(elevations)):
            if elevations[i] <= 0:
                continue  # Skip ocean

            current = i
            visited = {i}
            path_length = 0

            while True:
                river_idx = river_of_hex.get(current)
                if river_idx is not None:
                    hex_to_river[i] = river_idx
                    break

                # Move downhill
                lowest = terrain.lowest_neighbor(current)
                if lowest is None or lowest in visited:
                    unassigned.append((i, "local_minimum", path_length))
                    break  # Local minimum (lake)
                if elevations[lowest] <= 0:
                    unassigned.append((i, "ocean_direct", path_length))
                    break  # Hit ocean without hitting river

                visited.add(lowest)
                current = lowest
                path_length += 1

                if path_length > 100:  # Safety cap on walk length
                    unassigned.append((i, "too_long", path_length))
                    break

        if debug:
            from collections import Counter

            print(f"\n=== STEP 3: Traced {len(hex_to_river)} land hexes to rivers ===")
            print(f"  Unassigned: {len(unassigned)} hexes")

            # Breakdown of unassigned reasons
            reasons = Counter(reason for _, reason, _ in unassigned)
            for reason, count in reasons.items():
                print(f"    {reason}: {count}")

            # Which rivers got the most drainage
            river_sizes = Counter(hex_to_river.values())
            print("\n  Top 5 watersheds by drainage area:")
            for river_idx, count in sorted(river_sizes.items(), key=lambda x: x[1], reverse=True)[:5]:
                river_hexes = len(merged_rivers[river_idx].hexes)
                print(f"    River {river_idx}: {count} drainage hexes + {river_hexes} river hexes = {count + river_hexes} total")

        # Step 4: Create watersheds from rivers + assigned hexes.
        # Group drainage hexes once (preserving assignment order) rather than
        # rescanning hex_to_river for every river.
        drainage = {}
        for hex_idx, river_idx in hex_to_river.items():
            drainage.setdefault(river_idx, []).append(hex_idx)

        colors = StyleCSS.seaborn("tab20", levels=20)
        watersheds = []

        for river_idx, river in enumerate(merged_rivers):
            basin_hexes = set(river.hexes)
            basin_hexes.update(drainage.get(river_idx, ()))

            region = HexRegion(hexes=basin_hexes, hexGrid=terrain.hexGrid)
            style = colors[river_idx % len(colors)]
            river.terrain = terrain  # __init__ reads tributary.terrain

            watersheds.append(cls(region=region, tributary=river, style=style))

        if debug:
            print(f"\n=== STEP 4: Created {len(watersheds)} watersheds ===")
            # River hexes at/below sea level may be counted too, so this can exceed 100%.
            total_assigned = sum(len(w.region.hexes) for w in watersheds)
            total_land = sum(1 for e in elevations if e > 0)
            pct = 100 * total_assigned / total_land if total_land else 0.0
            print(f"  Total coverage: {total_assigned}/{total_land} land hexes ({pct:.1f}%)")
            for i, watershed in enumerate(watersheds):
                print(f"  Watershed {i}: {len(watershed.region.hexes)} hexes, {watershed.tributary.terrain}")

        return watersheds





In [None]:
#| export
@patch
def calculate_flow(self: Watershed) -> dict[int, float]:
    """Accumulate surface runoff downhill across the watershed's land hexes.

    Each region hex with elevation > 0 produces
    ``precipitation * (1 - permeability)``, using the permeability of its soil
    type. Hexes are then visited from highest to lowest elevation. Each one
    passes its running total to ``terrain.lowest_neighbor(h)`` when that
    neighbour is itself one of these land hexes.

    Fallbacks: if the terrain has no ``'precipitation'`` field, a uniform 500
    is used; if it has no ``'soil_type'`` field, every hex uses soil type 0.

    Returns:
        {hex_index: accumulated flow}
    """
    terrain = self.terrain
    elevations = terrain.elevations
    n_hexes = len(elevations)

    rainfall = terrain.fields.get('precipitation', np.ones(n_hexes) * 500)
    catalog = self.system.types
    soil_of = terrain.fields.get('soil_type', np.zeros(n_hexes, dtype=int))

    # Runoff produced locally by every land hex.
    produced = {}
    for hx in self.region.hexes:
        if elevations[hx] <= 0:
            continue  # sea-level / ocean hexes generate nothing
        coefficient = 1.0 - catalog[int(soil_of[hx])].permeability
        produced[hx] = rainfall[hx] * coefficient

    # Highest first, so every upstream contribution is in place before a hex
    # hands its total further downhill.
    downhill_order = sorted(produced, key=lambda hx: elevations[hx], reverse=True)
    totals = {hx: produced[hx] for hx in downhill_order}

    for hx in downhill_order:
        receiver = terrain.lowest_neighbor(hx)
        if receiver is None or receiver not in totals:
            continue
        totals[receiver] += totals[hx]

    return totals


In [None]:
#| export
@patch
def max_flow_hex(self: Watershed):
    """Return ``(hex_index, flow)`` for the hex carrying the most accumulated flow.

    When ``calculate_flow`` yields nothing (no land hexes), returns
    ``(self.terminal_hex, 0)``.
    """
    flow_by_hex = self.calculate_flow()
    if flow_by_hex:
        busiest = max(flow_by_hex, key=flow_by_hex.get)
        return (busiest, flow_by_hex[busiest])
    return (self.terminal_hex, 0)



In [None]:
#| export
@patch
def lake_basin(self: Watershed, base_size: int = 3, log_scale: float = 2.0, max_hexes=None, debug=False) -> HexRegion:
    """Grow a lake outward from the watershed's lowest hex.

    Target size is ``int(base_size + log_scale * log10(max(1, flow)))``, at
    least 1. Here ``flow`` is the accumulated flow at ``terminal_hex`` from
    ``calculate_flow``.

    Growth is a priority flood: the lowest-elevation frontier hex inside
    ``self.region`` is claimed next. It stops when the target size or
    ``max_hexes`` is reached, or when the frontier is empty.

    Args:
        base_size: Size offset in hexes before the log term.
        log_scale: Hexes added per decade of flow (higher = bigger lakes).
        max_hexes: Hard cap on lake size; defaults to the region's hex count.
        debug: Print step-by-step hex additions.

    Returns:
        HexRegion of the lake. It is empty when there is no terminal hex or the
        terminal hex has no computed flow.
    """
    terrain = self.terrain
    grid = terrain.hexGrid
    elevations = terrain.elevations
    region_hexes = self.region.hexes

    flows = self.calculate_flow()
    terminal = self.terminal_hex

    if terminal is None or terminal not in flows:
        if debug:
            print(f"No valid terminal hex (terminal={terminal})")
        return HexRegion(hexes=set(), hexGrid=grid)

    cap = len(region_hexes) if max_hexes is None else max_hexes

    flow = max(1, flows[terminal])
    target_hexes = max(1, int(base_size + log_scale * math.log10(flow)))

    if debug:
        print(f"\n=== LAKE BASIN GROWTH ===")
        print(f"Terminal hex: {terminal}, elevation: {elevations[terminal]:.1f}")
        print(f"Flow at terminal: {flow:.1f}")
        print(f"Target size: {target_hexes} hexes (base={base_size}, log_scale={log_scale})")
        print(f"Max hexes: {cap}")
        print(f"\nGrowth steps:")
        print(f"  Step 0: Added terminal {terminal} (elev={elevations[terminal]:.1f})")

    lake = {terminal}
    frontier = []  # min-heap of (elevation, hex)

    for nb in grid.neighborsOf(terminal):
        if nb < 0 or nb not in region_hexes:
            continue  # off-grid marker or outside this watershed
        heapq.heappush(frontier, (elevations[nb], nb))
        if debug:
            print(f"    Candidate: {nb} (elev={elevations[nb]:.1f})")

    step = 1
    while frontier and len(lake) < target_hexes and len(lake) < cap:
        elev, idx = heapq.heappop(frontier)

        if idx in lake:
            # The same hex can be queued by several neighbours.
            if debug:
                print(f"  Step {step}: Skipped {idx} (already in lake)")
            continue

        lake.add(idx)
        if debug:
            print(f"  Step {step}: Added {idx} (elev={elev:.1f}), lake size={len(lake)}")

        fresh = [nb for nb in grid.neighborsOf(idx)
                 if nb >= 0 and nb not in lake and nb in region_hexes]
        for nb in fresh:
            heapq.heappush(frontier, (elevations[nb], nb))

        if debug and fresh:
            print(f"    Added {len(fresh)} new candidates")

        step += 1

    if debug:
        print(f"\nFinal lake: {len(lake)} hexes")
        print(f"Remaining candidates: {len(frontier)}" if frontier else "No more candidates")

    return HexRegion(hexes=lake, hexGrid=grid)


In [None]:
def directions_in_cone(flow_dir: HexPosition, allowed_rotations: int) -> set[HexPosition]:
    """Return ``flow_dir`` plus every direction within ``allowed_rotations``
    single rotations of it, turning left and turning right."""
    cone = {flow_dir}
    leftward = rightward = flow_dir
    for _ in range(allowed_rotations):
        leftward, rightward = leftward.rotate_left(), rightward.rotate_right()
        cone.update((leftward, rightward))
    return cone

def hex_in_cone(hex_pos: HexPosition, flow_dir: HexPosition, allowed_rotations: int) -> bool:
    """Test whether ``hex_pos`` (an offset from the cone apex) lies in the cone.

    The cone is ``flow_dir`` widened by ``allowed_rotations`` on each side. A
    hex counts as inside when the first step of the hex line from the origin
    toward it is one of the cone's directions. The apex itself, and any target
    whose line has fewer than two points, are always inside.
    """
    if abs(hex_pos) == 0:
        return True

    line = HexPosition.origin().line_to(hex_pos)
    if len(line) < 2:
        return True

    heading = line[1] - line[0]
    return heading in directions_in_cone(flow_dir, allowed_rotations)


@patch
def build_delta(self: Watershed, 
                sediment_scale: float = 0.01,
                max_hexes: int = 50,
                delta_elevation: float = 2.0,
                uplift_decay: float = 0.7,
                debug: bool = False) -> HexRegion:
    """Build an alluvial delta by raising ocean hexes, with gradient-based watershed uplift.

    Starting at ``tributary.ocean_outlet``, ocean hexes (elevation <= 0) are
    claimed best-first, preferring hexes aligned with the river's final flow
    direction. A hex is kept only if it lies inside a cone around that
    direction; the cone widens from 0 to 2 rotations with distance. At most
    ``min(int(mouth_flow * sediment_scale), max_hexes)`` hexes are claimed.

    Side effects (mutates ``self.terrain.elevations``):
      * every delta hex is set to ``delta_elevation``;
      * watershed land hexes reachable from the outlet through the region are
        raised by ``(delta_elevation + 1) * uplift_decay ** distance`` when
        that amount exceeds 0.1.

    Args:
        sediment_scale: Delta hexes per unit of flow at the mouth.
        max_hexes: Upper bound on delta size.
        delta_elevation: Elevation assigned to delta hexes.
        uplift_decay: Per-hex decay factor of the inland uplift.
        debug: Print progress.

    Returns:
        HexRegion of the delta hexes. It is empty when there is no ocean outlet,
        the root path is shorter than 2, or no hex qualified.
    """
    from collections import deque

    terrain = self.terrain
    grid = terrain.hexGrid
    river = self.tributary

    def ocean_neighbors(h):
        # neighborsOf() appears to use -1 for off-grid slots (lake_basin filters
        # `neighbor >= 0`). elevations[-1] would silently read, and later
        # overwrite, the LAST hex, so drop negative indices first.
        return [nb for nb in grid.neighborsOf(h) if nb >= 0 and terrain.elevations[nb] <= 0]

    if river.ocean_outlet is None:
        if debug: print("No ocean outlet")
        return HexRegion(hexes=set(), hexGrid=grid)

    outlet = river.ocean_outlet
    root_node = river.tree.get_node(river.tree.root)
    path = root_node.data

    if len(path) < 2:
        return HexRegion(hexes=set(), hexGrid=grid)

    # The path is assumed outlet-first (as from_peak builds it), so path[1] is
    # the last hex before the mouth. Flow direction is the step from it to the outlet.
    pre_outlet_pos = grid.index_to_hexposition(path[1], outlet)
    flow_direction = HexPosition.origin() - pre_outlet_pos

    if debug:
        print(f"Outlet: {outlet}, flow direction: {flow_direction.desc()}")

    # Sediment budget: flow arriving at the mouth.
    flows = self.calculate_flow()
    mouth_flow = flows.get(self.terminal_hex)
    if mouth_flow is None:
        # calculate_flow() only covers land hexes, but terminal_hex is usually
        # the ocean outlet itself; use the flow at the last land hex instead.
        mouth_flow = flows.get(path[1], 0)
    target_hexes = min(int(mouth_flow * sediment_scale), max_hexes)

    if debug:
        print(f"Target hexes: {target_hexes}")

    delta_hexes = set()
    candidates = []  # heap of (priority, distance, hex_idx)

    # Seed with the outlet's ocean neighbours; the cone check below filters them.
    for neighbor in ocean_neighbors(outlet):
        heapq.heappush(candidates, (0, 1, neighbor))

    while len(delta_hexes) < target_hexes and candidates:
        _, distance, hex_idx = heapq.heappop(candidates)

        if hex_idx in delta_hexes:
            continue

        # Cone expands with distance: 0 -> 1 -> 2 rotations
        hex_pos = grid.index_to_hexposition(hex_idx, outlet)
        allowed_rotations = min(2, distance // 2)

        if not hex_in_cone(hex_pos, flow_direction, allowed_rotations):
            continue

        delta_hexes.add(hex_idx)

        if debug:
            print(f"  Added {hex_idx}, distance={distance}, rotations={allowed_rotations}")

        # Queue ocean neighbours; better alignment with the flow = lower priority number.
        for neighbor in ocean_neighbors(hex_idx):
            if neighbor in delta_hexes:
                continue
            n_pos = grid.index_to_hexposition(neighbor, outlet)
            n_dist = abs(n_pos)
            dot = float(np.dot(n_pos._coords, flow_direction._coords))
            priority = -dot / n_dist if n_dist > 0 else 0
            heapq.heappush(candidates, (priority, distance + 1, neighbor))

    if not delta_hexes:
        if debug: print("No delta hexes created")
        return HexRegion(hexes=set(), hexGrid=grid)

    # Raise delta hexes
    for h in delta_hexes:
        terrain.elevations[h] = delta_elevation

    # Gradient uplift: BFS hop distance from the outlet through the watershed.
    distances = {outlet: 0}
    queue = deque([outlet])
    while queue:
        current = queue.popleft()
        for neighbor in grid.neighborsOf(current):
            if neighbor >= 0 and neighbor in self.region.hexes and neighbor not in distances:
                distances[neighbor] = distances[current] + 1
                queue.append(neighbor)

    base_uplift = delta_elevation + 1.0
    uplifted = 0
    for hex_idx, dist in distances.items():
        if terrain.elevations[hex_idx] > 0:
            uplift = base_uplift * (uplift_decay ** dist)
            if uplift > 0.1:
                terrain.elevations[hex_idx] += uplift
                uplifted += 1

    if debug:
        print(f"Delta: {len(delta_hexes)} hexes, uplifted {uplifted} watershed hexes")

    return HexRegion(hexes=delta_hexes, hexGrid=grid)


In [None]:
#| export
@patch
def simplify(self: Watershed, k: int = 3) -> 'Watershed':
    """Simplify the watershed to its k longest river paths.

    Strategy:
    1. Collect every leaf-to-root hex list in the river tree.
    2. Keep the k longest (ties keep tree node order).
    3. Re-trace a river downhill from the highest hex of each kept path.
    4. Merge those rivers with ``River.combine_rivers`` and keep the largest.

    Args:
        k: Number of major tributaries to keep.

    Returns:
        A new Watershed whose region is exactly the simplified river's hexes.
        It reuses this watershed's style and soil system. Returns ``self`` if
        no river could be traced or merged.
    """
    tree = self.tributary.tree
    terrain = self.tributary.terrain

    # Step 1: Find all paths from leaves to root
    all_paths = []
    for node in tree.all_nodes():
        if tree.children(node.identifier):
            continue  # Only start from leaves

        path = []
        current = node
        while current is not None:
            path.extend(current.data)
            parent_id = current.predecessor(tree.identifier)
            # Compare with None explicitly: from_peak/river_peak give the root
            # identifier 0, which a truthiness test would mistake for
            # "no parent", silently dropping the root segment.
            current = tree.get_node(parent_id) if parent_id is not None else None

        all_paths.append(path)

    # Step 2: Sort by length and keep top k
    all_paths.sort(key=len, reverse=True)
    top_paths = all_paths[:k]

    # Step 3: Re-trace a river from each path's highest hex
    rivers = []
    for path in top_paths:
        highest_hex = max(path, key=lambda h: terrain.elevations[h])
        new_river = Watershed.from_peak(terrain, highest_hex)
        if new_river is not None:
            rivers.append(new_river)

    # Step 4: Merge all rivers
    if not rivers:
        return self  # Fallback

    merged_rivers = River.combine_rivers(rivers)
    if not merged_rivers:
        return self

    main_river = max(merged_rivers, key=lambda r: len(r.hexes))
    main_river.terrain = terrain

    new_region = HexRegion(
        hexes=main_river.hexes.copy(),
        hexGrid=terrain.hexGrid
    )

    # Carry the soil model over; otherwise __init__ would rebuild a default one
    # and calculate_flow() on the result would ignore the caller's soils.
    return Watershed(
        region=new_region,
        tributary=main_river,
        style=self.style,
        system=self.system,
    )


In [None]:
#| export
@patch
def segments(self: Watershed) -> list[list[int]]:
    """Return one hex list per river-tree node, each ordered highest to lowest.

    For each node, take the land hexes (elevation > 0) of its segment. The root
    node also gets ``self.terminal_hex`` when the watershed drains to the ocean.

    Only segments with at least two such hexes are kept. For those, the last
    hex of the parent node (the junction) is prepended when a parent exists.
    The list is then sorted by descending elevation, so the junction ends up
    wherever its elevation places it.

    Returns:
        List of hex-index lists, in tree node order.
    """
    tree = self.tributary.tree
    elevations = self.tributary.terrain.elevations
    pieces = []

    for node in tree.all_nodes():
        chain = [hx for hx in node.data if elevations[hx] > 0]

        if node.identifier == tree.root and self.is_ocean:
            mouth = self.terminal_hex
            if mouth is not None:
                chain.append(mouth)

        if len(chain) < 2:
            continue

        parent_id = node.predecessor(tree.identifier)
        if parent_id is not None:
            junction = tree.get_node(parent_id).data[-1]
            chain.insert(0, junction)

        chain.sort(key=lambda hx: elevations[hx], reverse=True)
        pieces.append(chain)

    return pieces

In [None]:
#| export
@patch
def segment_to_points(self: Watershed, hexes: list[int]) -> list[MapCord]:
    """Turn a high-to-low hex list into drawable points, ending at the shoreline.

    Hex centres are used up to, but not including, the first hex with elevation
    <= 0. If such a hex follows at least one land hex, the line ends at the
    midpoint between that last land hex and the ocean hex rather than at the
    ocean hex's centre.

    Args:
        hexes: Hex indices ordered from high to low elevation.

    Returns:
        MapCord points for drawing; empty for an empty input.
    """
    terrain = self.tributary.terrain
    grid = terrain.hexGrid

    if not hexes:
        return []

    shore = next((pos for pos, hx in enumerate(hexes) if terrain.elevations[hx] <= 0), None)
    dry = hexes if shore is None else hexes[:shore]
    points = [grid.hexes[hx].center for hx in dry]

    if shore is not None and shore > 0:
        land_c = grid.hexes[hexes[shore - 1]].center
        sea_c = grid.hexes[hexes[shore]].center
        points.append(MapCord(
            (land_c.x + sea_c.x) / 2,
            (land_c.y + sea_c.y) / 2
        ))

    return points

In [None]:
@patch
def demoRiverSources(self: RiverDemo):
    """Render the Bay Area demo map with its top 15 river sources marked.

    Builds the demo terrain, computes weather and shrinks it by 0.75. It then
    asks ``find_river_sources`` (with debug output) for the top 15 sources and
    overlays a red circle labelled with the rank on each one.
    """
    base = TerraDemo().bayArea_map()
    base.compute_weather()
    terrain = base.shrinkWeather(0.75)

    sources = terrain.find_river_sources(
        min_precipitation=400,
        min_elevation=300,
        top_n=15,
        debug=True
    )

    terrain.colorMap()
    terrain.hexGrid.update()
    terrain.hexGrid.adjustRadius(15)
    canvas = terrain.hexGrid.builder

    canvas.add_style(StyleCSS("source", fill="#ff0000", stroke="#000000", stroke_width=2))

    def badge(rank, idx):
        # One circle plus its rank label, centred on the source hex.
        c = terrain.hexGrid.hexes[idx].center
        return (f'<circle cx="{c.x}" cy="{c.y}" r="8" class="source"/>'
                f'<text x="{c.x}" y="{c.y+4}" text-anchor="middle" font-size="10" fill="white">{rank}</text>')

    svg = "".join(badge(rank, idx) for rank, (idx, _precip, _elev, _score) in enumerate(sources[:15], start=1))
    canvas.adjust("markers", svg)

    return canvas.show()

### Alluvial Fans

In [None]:
@patch
def build_delta(self: Watershed, 
                sediment_scale: float = 0.02,      # Doubled from 0.01
                max_hexes: int = 75,                # Increased from 50
                delta_elevation: float = 5.0,       # Raised from 2.0
                uplift_decay: float = 0.6,          # Faster from 0.7
                cone_start_width: int = 2,          # NEW: start with 2 rotations
                debug: bool = False) -> HexRegion:
    """Build alluvial delta by raising ocean hexes in a cone pattern.

    Ocean hexes (elevation <= 0) around ``tributary.ocean_outlet`` are claimed
    best-first, preferring hexes aligned with the river's seaward direction.

    * Hexes at distance 0-1 are always accepted.
    * Farther hexes must lie in a cone of ``cone_start_width + distance // 3``
      rotations, capped at 3.
    * Each claimed hex gets elevation
      ``delta_elevation * (1 - 0.5 * min(distance / 5, 1))``.
    * Delta size is ``min(int(mouth_flow * sediment_scale), max_hexes)``, or
      ``min(int(10 + 8 * log10(mouth_flow)), max_hexes)`` when mouth_flow > 100.

    Watershed land hexes within 10 BFS hops of the outlet are then raised by
    ``1.5 * delta_elevation * uplift_decay ** distance`` when that exceeds 0.2.
    Mutates ``self.terrain.elevations``.

    Returns:
        HexRegion of the delta hexes (empty if there is no ocean outlet or the
        root path is shorter than 2).
    """
    from collections import Counter, deque

    terrain = self.terrain
    grid = terrain.hexGrid
    river = self.tributary

    def ocean_neighbors(h):
        # neighborsOf() appears to use -1 for off-grid slots (lake_basin filters
        # `neighbor >= 0`). elevations[-1] would silently read, and later
        # overwrite, the LAST hex, so drop negative indices first.
        return [nb for nb in grid.neighborsOf(h) if nb >= 0 and terrain.elevations[nb] <= 0]

    if river.ocean_outlet is None:
        if debug: print("No ocean outlet")
        return HexRegion(hexes=set(), hexGrid=grid)

    outlet = river.ocean_outlet
    root_node = river.tree.get_node(river.tree.root)
    path = root_node.data

    if len(path) < 2:
        return HexRegion(hexes=set(), hexGrid=grid)

    origin = HexPosition.origin()
    if len(path) >= 4:
        # The path is assumed outlet-first (as from_peak builds it), so
        # path[i + 1] lies UPSTREAM of path[i]. index_to_hexposition(a, b)
        # gives a's offset from b, i.e. an upstream step here. Negate it to get
        # the downstream (seaward) step, matching the fallback branch below.
        # The raw offset pointed the cone inland.
        steps = [origin - grid.index_to_hexposition(path[i + 1], path[i]) for i in range(3)]
        # Most common of the last three steps; ties go to the step nearest the mouth.
        flow_direction = Counter(steps).most_common(1)[0][0]
    else:
        # Fallback: simple two-hex direction, pre-outlet -> outlet
        pre_outlet_pos = grid.index_to_hexposition(path[1], outlet)
        flow_direction = origin - pre_outlet_pos

    if debug:
        print(f"Outlet: {outlet}, flow direction: {flow_direction.desc()}")

    # Sediment budget: flow arriving at the mouth.
    flows = self.calculate_flow()
    max_flow = flows.get(self.terminal_hex)
    if max_flow is None:
        # calculate_flow() only covers land hexes, but terminal_hex is usually
        # the ocean outlet itself; use the flow at the last land hex instead.
        max_flow = flows.get(path[1], 0)
    target_hexes = min(int(max_flow * sediment_scale), max_hexes)

    # Logarithmic scaling for visual impact
    if max_flow > 100:
        target_hexes = min(int(10 + math.log10(max_flow) * 8), max_hexes)

    if debug:
        print(f"Target hexes: {target_hexes}")

    delta_hexes = set()
    candidates = []  # heap of (priority, distance, hex_idx)

    # Seed with ALL ocean neighbours of the outlet (distance 0 bypasses the cone)
    for neighbor in ocean_neighbors(outlet):
        heapq.heappush(candidates, (0, 0, neighbor))

    while len(delta_hexes) < target_hexes and candidates:
        _, distance, hex_idx = heapq.heappop(candidates)

        if hex_idx in delta_hexes:
            continue

        # Expanding cone with distance
        hex_pos = grid.index_to_hexposition(hex_idx, outlet)
        allowed_rotations = min(3, cone_start_width + distance // 3)

        if distance >= 2 and not hex_in_cone(hex_pos, flow_direction, allowed_rotations):
            continue

        delta_hexes.add(hex_idx)

        # Elevation: higher at the apex, halved by distance 5
        distance_factor = min(distance / 5.0, 1.0)
        terrain.elevations[hex_idx] = delta_elevation * (1.0 - distance_factor * 0.5)

        # Queue ocean neighbours; better alignment with the flow = lower priority number.
        for neighbor in ocean_neighbors(hex_idx):
            if neighbor in delta_hexes:
                continue
            n_pos = grid.index_to_hexposition(neighbor, outlet)
            n_dist = abs(n_pos)
            dot = float(np.dot(n_pos._coords, flow_direction._coords))
            priority = -dot / n_dist if n_dist > 0 else 0
            heapq.heappush(candidates, (priority, distance + 1, neighbor))

    # Gradient uplift: BFS hop distance from the outlet through the watershed.
    distances = {outlet: 0}
    queue = deque([outlet])
    while queue:
        current = queue.popleft()
        for neighbor in grid.neighborsOf(current):
            if neighbor >= 0 and neighbor in self.region.hexes and neighbor not in distances:
                distances[neighbor] = distances[current] + 1
                queue.append(neighbor)

    base_uplift = delta_elevation * 1.5  # Stronger uplift
    uplifted = 0
    for hex_idx, dist in distances.items():
        if terrain.elevations[hex_idx] > 0 and dist < 10:  # Limit range
            uplift = base_uplift * (uplift_decay ** dist)
            if uplift > 0.2:  # Higher threshold
                terrain.elevations[hex_idx] += uplift
                uplifted += 1

    if debug:
        print(f"Delta: {len(delta_hexes)} hexes, uplifted {uplifted} watershed hexes")

    return HexRegion(hexes=delta_hexes, hexGrid=grid)

In [None]:
@patch
def build_alluvial_fan(self: Watershed,
                       gradient_threshold: float = 30.0,   # Steep to flat transition
                       fan_radius: int = 4,                # Max fan spread in rings
                       elevation_buildup: float = 15.0,    # Deposit scale near the apex
                       debug: bool = False) -> List[HexRegion]:
    """Create alluvial fans where the tributary leaves steep terrain for flatter ground.

    River hexes are visited from highest to lowest elevation. For each
    consecutive pair (previous, current), if the previous hex's gradient is
    above ``gradient_threshold``, the current hex's gradient is below 30% of
    it, and the current hex is land, the current hex becomes a fan apex.

    Around each apex, sediment is added to land hexes in rings 1..r. Only
    hexes inside a cone pointing away from the previous (higher) river hex
    receive it. The deposit decreases with ring distance, and deposits of
    1.0 or less are ignored.

    Side effect: ``terrain.elevations`` is raised in place for every hex that
    receives a deposit.

    Args:
        gradient_threshold: Gradient above which terrain counts as "steep".
        fan_radius: Upper bound on the number of rings a fan spreads.
        elevation_buildup: Deposit scale; the actual deposit tapers outward.
        debug: Print apex locations and fan sizes.

    Returns:
        One HexRegion per fan: the apex plus every hex that received a deposit.
    """
    terrain = self.terrain
    grid = terrain.hexGrid
    river = self.tributary

    fans = []
    gradient = river._calculate_gradient()
    flow = river._calculate_flow()

    # Visit river hexes from highest to lowest elevation.
    # NOTE(review): elevation order stands in for downstream order here;
    # consecutive entries are not guaranteed to be adjacent hexes.
    sorted_hexes = sorted(river.hexes,
                          key=lambda h: terrain.elevations[h],
                          reverse=True)

    for upstream_hex, apex in zip(sorted_hexes, sorted_hexes[1:]):
        prev_gradient = gradient.get(upstream_hex, 0)
        current_gradient = gradient.get(apex, 0)

        # Detect the steep->flat transition (alluvial fan location) on land
        if not (prev_gradient > gradient_threshold
                and current_gradient < gradient_threshold * 0.3
                and terrain.elevations[apex] > 0):
            continue

        if debug:
            print(f"Fan apex at hex {apex}: "
                  f"gradient {prev_gradient:.1f} -> {current_gradient:.1f}")

        # Fan size grows with log10 of the apex flow, capped at fan_radius
        fan_flow = flow.get(apex, 1)
        actual_radius = min(int(fan_radius * math.log10(fan_flow + 10) / 2),
                            fan_radius)

        # Cone axis: position of the apex relative to the previous (higher) river hex
        flow_dir = grid.index_to_hexposition(apex, upstream_hex)

        fan_hexes = {apex}
        for radius in range(1, actual_radius + 1):
            # The deposit depends only on the ring. Dividing by actual_radius + 1
            # (not actual_radius) keeps the outermost ring's deposit non-zero;
            # with the old divisor that ring always got exactly 0 and was dropped.
            distance_factor = radius / (actual_radius + 1)
            deposit = elevation_buildup * (1.0 - distance_factor ** 1.5)
            if deposit <= 1.0:  # Only add significant deposits
                continue

            allowed_rotations = 1 + radius // 2  # Cone widens with distance
            for fan_hex in terrain.ring(apex, radius):
                # Skip negative (presumably off-grid) indices and water hexes
                if fan_hex < 0 or terrain.elevations[fan_hex] <= 0:
                    continue
                hex_pos = grid.index_to_hexposition(fan_hex, apex)
                if hex_in_cone(hex_pos, flow_dir, allowed_rotations):
                    terrain.elevations[fan_hex] += deposit
                    fan_hexes.add(fan_hex)

        fans.append(HexRegion(hexes=fan_hexes, hexGrid=grid))

        if debug:
            print(f"  Built fan: {len(fan_hexes)} hexes, radius {actual_radius}")

    return fans

### Drawing

In [None]:
#| export
@patch
def drawRiver(self: Watershed,
              min_width: float = 1.0,
              max_width: float = 8.0,
              min_windiness: float = 0.05,
              max_windiness: float = 0.3,
              min_iterations: int = 2,
              max_iterations: int = 5,
              color: str = "#1565c0",
              opacity: float = 0.7,
              max_flow: float = None,
              debug: bool = False) -> str:
    """Render river with accumulated flow determining width at each point.

    Each segment from ``self.segments()`` becomes one windy spline.

    Stroke width: each land hex gets a width between ``min_width`` and
    ``max_width``, scaled by ``log10(flow + 1) / log10(max_flow + 1)``. The
    segment is drawn with the average of these widths.

    Windiness: the segment's drop is the elevation of its first land hex
    minus that of its last, relative to the map's highest elevation and
    clamped to [0, 1]. A larger drop moves ``offset_factor`` and
    ``iterations`` toward their minimums, making steep runs straighter.

    Args:
        min_width, max_width: Stroke-width range.
        min_windiness, max_windiness: ``offset_factor`` range passed to make_windy.
        min_iterations, max_iterations: ``iterations`` range passed to make_windy.
        color, opacity: Inline stroke colour and opacity.
        max_flow: Flow mapped to ``max_width``; defaults to the largest value
            returned by ``calculate_flow``.
        debug: Print per-segment width information.

    Returns:
        Concatenated SVG markup, or ``""`` when ``calculate_flow`` is empty.
    """
    flows = self.calculate_flow()
    if not flows:
        return ""

    terrain = self.terrain
    if max_flow is None:
        max_flow = max(flows.values())
    log_max = math.log10(max_flow + 1)
    # The map's highest elevation normalizes each segment's drop. It does not
    # change between segments, so compute it once.
    max_elev = max(terrain.elevations, default=0)

    parts = []
    for hexes in self.segments():
        points = self.segment_to_points(hexes)
        if len(points) < 2:
            continue

        land_hexes = [h for h in hexes if terrain.elevations[h] > 0]

        # Width at each land hex based on its accumulated flow
        widths = []
        for h in land_hexes:
            log_flow = math.log10(flows.get(h, 1) + 1)
            ratio = log_flow / log_max if log_max > 0 else 0
            widths.append(min_width + ratio * (max_width - min_width))

        if debug:
            if widths:
                print(f"Segment {len(hexes)} hexes: width {widths[0]:.1f} → {widths[-1]:.1f}")
            else:
                # An all-water segment has no widths; indexing widths[0] would raise.
                print(f"Segment {len(hexes)} hexes: no land hexes")

        # For now, use average width (variable-width paths need more work)
        avg_width = sum(widths) / len(widths) if widths else min_width

        # Windiness from gradient
        if len(land_hexes) >= 2:
            drop = terrain.elevations[land_hexes[0]] - terrain.elevations[land_hexes[-1]]
            # Clamp to [0, 1]. A segment whose last land hex is higher than its
            # first would otherwise push windiness above max_windiness and
            # iterations above max_iterations.
            grad_ratio = min(max(drop / max_elev, 0.0), 1.0) if max_elev > 0 else 0
            windiness = max_windiness - grad_ratio * (max_windiness - min_windiness)
            iterations = int(max_iterations - grad_ratio * (max_iterations - min_iterations))
        else:
            windiness = (min_windiness + max_windiness) / 2
            iterations = (min_iterations + max_iterations) // 2

        aPath = MapPath(points, style=StyleCSS("dummy"))
        windy = aPath.make_windy(iterations=iterations, offset_factor=windiness)

        inline_style = f"fill:none;stroke:{color};stroke-width:{avg_width:.1f};opacity:{opacity}"
        parts.append(windy.drawSpline(adds=f'style="{inline_style}"'))

    return "".join(parts)


In [None]:
#| export
@patch
def drawLake(self: Watershed,
             base_size: int = 3,
             log_scale: float = 2.0,
             fill="#0000ff",
             max_hexes=5,
             debug: bool = False,
             ) -> str:
    """Render the lake basin for this watershed as an SVG fragment.

    Watersheds flagged ``is_ocean`` get no lake. Otherwise:

    1. ``lake_basin`` selects the lake hexes.
    2. A ``"lake_"`` style (the given ``fill``, no stroke) is registered on
       the terrain's builder.
    3. The basin is drawn with a 0.25 inset and a windy edge.

    Returns:
        SVG markup, or ``""`` when there is no lake to draw.
    """
    basin = None if self.is_ocean else self.lake_basin(
        base_size=base_size, log_scale=log_scale, max_hexes=max_hexes, debug=debug)
    if basin is None or not basin.hexes:
        return ""

    # NOTE(review): every call registers the same style name "lake_"; if
    # add_style de-duplicates by name, lakes drawn with different `fill`
    # values may collide -- confirm.
    style = StyleCSS("lake_", fill=fill, stroke="none")
    self.terrain.builder.add_style(style)
    edge = unique_windy_edge(iterations=3)
    return basin.draw(style=style, inset=0.25, f=edge)
    

In [None]:
#| export
@patch
def draw(self: Watershed,
         min_width: float = 1.0,
         max_width: float = 8.0,
         min_windiness: float = 0.05,
         max_windiness: float = 0.3,
         min_iterations: int = 2,
         max_iterations: int = 5,
         color: str = "#1565c0",
         opacity: float = 0.7,
         max_flow: float = None,
         lake_base_size: int = 1,
         lake_max_size: int = 3,
         lake_log_scale: float = 2.0,
         debug: bool = False) -> str:
    """Render this watershed: its river segments followed by its lake (if any).

    If the terrain has no ``'precipitation'`` field yet, it is computed first
    with ``compute_precipitation_sb``. River options go to ``drawRiver``. The
    ``lake_*`` options go to ``drawLake``, and the lake is filled with the
    river ``color``.

    Returns:
        River SVG markup followed by lake SVG markup.
    """
    terrain = self.terrain
    if 'precipitation' not in terrain.fields:
        print("Computing precipitation first...")
        terrain.compute_precipitation_sb()

    river_options = dict(
        min_width=min_width, max_width=max_width,
        min_windiness=min_windiness, max_windiness=max_windiness,
        min_iterations=min_iterations, max_iterations=max_iterations,
        max_flow=max_flow, color=color, opacity=opacity, debug=debug,
    )
    river_svg = self.drawRiver(**river_options)

    lake_svg = self.drawLake(base_size=lake_base_size, log_scale=lake_log_scale,
                             fill=color, max_hexes=lake_max_size, debug=debug)
    return river_svg + lake_svg

In [None]:
@patch
def demo_sanF(self: RiverDemo, num_peaks: int = 30, debug: bool = False):
    """Demo showing watershed boundaries and rivers.

    Steps:

    1. Load the Bay Area demo terrain, compute its weather, shrink it to 75%
       and colour it.
    2. Compute watersheds from ``num_peaks`` peaks.
    3. Hand the concatenated SVG of every watershed to the builder's
       ``"rivers"`` layer.
    4. Print the watershed count and land coverage.

    Returns:
        The result of ``builder.show()``.
    """
    bay = TerraDemo().bayArea_map()
    bay.compute_weather()
    small = bay.shrinkWeather(0.75)
    grid = small.hexGrid
    small.colorMap()
    grid.update()
    svg = grid.builder

    sheds = Watershed.compute_all(small, num_peaks=num_peaks, debug=debug)

    # Register a "river" CSS class on the builder, then add all watershed drawings.
    svg.add_style(StyleCSS("river", fill="none", stroke="#1a5276",
                           stroke_width=2, stroke_linecap="round"))
    svg.adjust("rivers", "".join(shed.draw() for shed in sheds))

    # Coverage: summed watershed region sizes vs. count of land (> 0) hexes
    land = len([e for e in small.elevations if e > 0])
    covered = sum(len(shed.region.hexes) for shed in sheds)
    print(f"Watersheds: {len(sheds)}, Coverage: {covered}/{land} ({100*covered/land:.1f}%)")

    return svg.show()


In [None]:
# Render the Bay Area demo (all watersheds' rivers and lakes) and print coverage stats.
RiverDemo().demo_sanF()

Watersheds: 63, Coverage: 422/559 (75.5%)


In [None]:
@patch
def coneMap(self: RiverDemo, shed_index: int = 5):
    """Build a seeded test island with a Maui-style trade-wind climate and draw one watershed.

    The terrain is generated from fixed seeds (``seed=42``) and mapped onto
    Maui's geographic bounding box. It gets a Maui ``ClimatePreset`` and
    weather, then watersheds are computed with ``num_peaks=1``. The watershed
    at ``shed_index`` is drawn onto the terrain builder's ``"rivers"`` layer.

    Args:
        shed_index: Index into the list returned by ``Watershed.compute_all``
            of the watershed to draw and return. Defaults to 5, the previously
            hard-coded choice.

    Returns:
        ``(terrain, watershed)`` -- the generated terrain and the chosen watershed.

    Raises:
        IndexError: if ``shed_index`` is outside the computed watershed list.
    """
    map_bounds = MapRect(MapCord(0, 0), MapSize(500, 300))

    t, _plates = Terrain.fromSeeds(
        map_bounds, radius=15, num_plates=10,
        formation_type='ocean_distance',  # also: 'ridge', 'volcanic', 'rift', 'rolling'
        elevation_scale=3.0,
        oceanic_sides=['N', 'E', 'W'],
        age='young',  # Fewer subdivisions = clearer boundaries
        seed=42,
    )
    t.colorMap()
    t.hexGrid.update()

    # Map the grid onto Maui's bounding box; hex coordinates are derived from it.
    t.geo = GeoBounds(
        lat_min=20.57,   # Southern tip (near Makena)
        lat_max=21.03,   # Northern tip (near Kahakuloa)
        lon_min=-156.69, # Western tip (West Maui)
        lon_max=-155.97  # Eastern tip (Haleakalā/Hāna)
    )
    t._compute_hex_coordinates()

    # Maui-specific precipitation model
    t.climate = ClimatePreset(
        name='Maui',
        lat_range=(20.57, 21.03),  # Same latitude span as t.geo above
        base_temp_range=(26, 28),
        wind_speed=8.0,        # Slightly stronger trade winds
        wind_dir=50.0,         # Trade winds from the northeast (~50-60 degrees)
        precip_base=0.15,      # MUCH higher base moisture (tropical ocean)
        nm=0.008,              # Less stable (more convection)
        hw=2500.0,             # Higher moisture scale height
        cw=0.003,              # MUCH stronger orographic effect
        conv_time=1000.0,      # Faster conversion
        fall_time=1000.0,
    )
    t.compute_weather()

    watersheds = Watershed.compute_all(t, num_peaks=1)

    # Register a "river" CSS class on the builder
    river_style = StyleCSS("river", fill="none", stroke="#1a5276", stroke_width=2, stroke_linecap="round")
    t.hexGrid.builder.add_style(river_style)

    if not -len(watersheds) <= shed_index < len(watersheds):
        raise IndexError(
            f"shed_index {shed_index} out of range: only {len(watersheds)} watersheds found")
    retShed = watersheds[shed_index]

    t.hexGrid.builder.adjust("rivers", retShed.draw())
    return t, retShed

In [None]:
# Build the seeded test island; keep the terrain and the selected watershed.
terr, shed  = RiverDemo().coneMap()
# Outlet check: index of the watershed's lowest-elevation hex, and is_ocean
# (when True, drawLake draws no lake for this watershed).
shed.terminal_hex, shed.is_ocean

(568, True)

In [None]:
# Display the map built by coneMap, including the "rivers" layer it set.
terr.hexGrid.builder.show()

In [None]:
#| export
@patch
def buildUp(self: Terrain, shed, rings=3, ele=7):
    """Raise land around a watershed's terminal hex and carve a drainage path from it.

    All changes are made in place on ``self.elevations``:

    1. Set the terminal hex to 11. For each ring ``r`` in ``range(rings)``
       around it, lift every non-negative ring index to at least
       ``ele - 2*r``.
    2. Compute ``find_drainage_path`` from the terminal hex over the new
       surface, and set the terminal hex to ``len(path) + 2``.
    3. Walk the path. Wherever the next hex is not lower than the current
       one, set it to ``len(path) - i - 1``.
    4. Set the last hex of the path to 0.

    Args:
        shed: Watershed whose ``terminal_hex`` is the centre of the build-up.
        rings: Number of rings, starting at ring 0, to lift.
        ele: Minimum elevation for ring 0; each further ring is 2 lower.
    """
    # Read the terminal hex once. ``terminal_hex`` is a property that
    # re-finds the lowest-elevation hex from current elevations. Re-reading it
    # after the writes below could therefore return a different hex when
    # ``shed`` is built on this terrain.
    terminal = shed.terminal_hex
    if terminal is None:
        # Empty region. Indexing with None raises for a list, and for a NumPy
        # array it would broadcast the assignment to every hex.
        return

    self.elevations[terminal] = 11

    for r in range(rings):
        floor = ele - 2 * r
        for h in self.ring(terminal, r):
            if h < 0:  # skip negative (presumably off-grid) entries; -1 would hit the last hex
                continue
            self.elevations[h] = max(self.elevations[h], floor)

    path = self.find_drainage_path(terminal)
    if len(path) == 0:
        return  # nothing to carve; path[-1] below would raise

    n = len(path)
    self.elevations[terminal] = n + 2
    for i, (curr, next_hex) in enumerate(zip(path, path[1:])):
        if self.elevations[next_hex] >= self.elevations[curr]:
            # NOTE(review): n - i - 1 is not guaranteed to be below
            # elevations[curr] when curr kept its original (lower) height.
            self.elevations[next_hex] = n - i - 1

    self.elevations[path[-1]] = 0



In [None]:
# Recolour and re-render `terr` from its current elevations, recompute the
# watersheds, and pass watershed index 5's drawing to the "rivers" layer.
# NOTE(review): this cell does not call `terr.buildUp(...)` itself; presumably
# it is meant to run after one -- confirm.
terr.colorMap()
terr.hexGrid.update()
watersheds = Watershed.compute_all(terr, num_peaks=1)

rivers_svg = ""
retShed = watersheds[5]  # same index coneMap uses by default; IndexError if fewer than 6
rivers_svg += retShed.draw()

terr.hexGrid.builder.adjust("rivers", rivers_svg)
terr.hexGrid.builder.show()

In [None]:
# Check the terminal hex and its neighbors
terminal = retShed.terminal_hex
print(f"Terminal hex {terminal}: elevation {terr.elevations[terminal]:.1f}")
for n in terr.hexGrid.neighborsOf(terminal):
    if n >= 0:
        print(f"  Neighbor {n}: elevation {terr.elevations[n]:.1f}")


In [None]:
@patch
def demoBuildup(self: RiverDemo, num_peaks: int = 30, debug: bool = False):
    """Demo showing watershed boundaries and rivers.

    Steps:

    1. Load the Bay Area demo terrain, compute its weather and shrink it to 75%.
    2. Compute watersheds and run ``buildUp(ws, rings=5)`` for each one.
    3. Recompute the watersheds on the modified terrain, then colour and
       update the grid.
    4. Hand every watershed's SVG to the ``"rivers"`` layer.
    5. Print the watershed count and land coverage.

    Returns:
        The result of ``builder.show()``.
    """
    bay = TerraDemo().bayArea_map()
    bay.compute_weather()
    small = bay.shrinkWeather(0.75)
    grid = small.hexGrid
    svg = grid.builder  # taken before grid.update(), matching the original order

    def find_sheds():
        return Watershed.compute_all(small, num_peaks=num_peaks, debug=debug)

    for shed in find_sheds():
        small.buildUp(shed, rings=5)

    sheds = find_sheds()
    small.colorMap()
    grid.update()

    # Register a "river" CSS class on the builder, then add all watershed drawings.
    svg.add_style(StyleCSS("river", fill="none", stroke="#1a5276",
                           stroke_width=2, stroke_linecap="round"))
    svg.adjust("rivers", "".join(shed.draw() for shed in sheds))

    # Coverage: summed watershed region sizes vs. count of land (> 0) hexes
    land = len([e for e in small.elevations if e > 0])
    covered = sum(len(shed.region.hexes) for shed in sheds)
    print(f"Watersheds: {len(sheds)}, Coverage: {covered}/{land} ({100*covered/land:.1f}%)")

    return svg.show()

In [None]:
# Render the Bay Area demo after raising terrain around every watershed outlet.
RiverDemo().demoBuildup()

Watersheds: 41, Coverage: 549/771 (71.2%)


In [None]:
# Re-render the unmodified Bay Area demo for comparison with demoBuildup above.
RiverDemo().demo_sanF()

Watersheds: 63, Coverage: 422/559 (75.5%)
