# River
runs through

In [None]:
#| default_exp water/river

### Prior Art

In [None]:
#| export
#standard
import numpy as np
import sys
import os
import math
import random

#data
from collections import namedtuple, deque
from dataclasses import dataclass,  field, asdict
from typing import List
from enum import Enum

#Jeremy
from dialoghelper import * 
from fastcore.basics import patch
from fasthtml.common import *
from fasthtml.jupyter import *
import httpx

# unique
from treelib import Tree
import heapq


In [None]:
#| export
from pathlib import Path
sys.path.insert(0, str(Path().resolve().parent.parent))

In [None]:
#| export


from HexMagic.styles import StyleCSS, SVGBuilder,SVGLayer, SVGPatternLoader, preview, app, StyleDemo,NamedColor
from HexMagic.primitives import MapCord, MapSize, MapRect, MapPath, Hex, HexGrid, HexWrapper, HexPosition, hexBackground,windy_edge, HexRegion, unique_windy_edge
from HexMagic.terrain import  TerraDemo, Terrain
from HexMagic.terrainpatterns import TerrainPatterns


In [None]:
#| export
from HexMagic.voronoi import PlateKind
from HexMagic.water.soil import SoilSystem, SoilType
from HexMagic.weather import TerraDemo

#| export
### Helpers

In [None]:
#| export
class RiverDemo:
    """Holder for the river demo methods, which are attached to it with ``@patch``."""

    def __init__(self) -> None:
        # Free-form help text; starts out empty.
        self.help = str()

## River

The `River` class represents a river network as a tree structure using `treelib.Tree`.

**Core Data Structure**
- `tree`: A `treelib.Tree` where each node's `data` is a list of hex indices forming a continuous segment
- `hexes`: A flat `set` of all hex indices for quick intersection checks
- Root of tree = outlet (lowest point); branches go upstream toward sources

**Creating Rivers: `from_peak(terrain, peak_index)`**
1. Traces downhill from the peak using `terrain.lowest_neighbor()`
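2. Stops at the first water hex (elevation <= 0), which is kept as the last hex of the path and as `ocean_outlet`, or when `lowest_neighbor()` returns `None` (a local minimum)
3. Reverses the path so the outlet comes first (index 0) in the root segment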
4. Creates a single-node tree with that path

**Merging Algorithm: `combine_rivers(rivers)`**
1. Loop until no changes occur
2. Check each pair of rivers for hex intersection
3. When found, call `_merge_with()` which:
   - Finds the intersection hex (picks highest elevation if multiple)
   - Locates which tree node contains it
   - Splits that node if needed (downstream stays, upstream becomes child)
   - Trims the other river's tree to just its upstream portion
   - Attaches it as a new branch at the intersection point

**Drawing: `drawRiver(river, riverStyle)`**
- Iterates through all tree nodes
- For each node, builds a list of center points from hex indices
- If the node has a parent, prepends the center of the parent's last hex, where the branch joins (for continuity)
- Draws a spline through those points using `MapPath.drawSpline()`

The result: rivers that branch upstream from a single outlet, drawn as smooth curves through hex centers.

In [None]:
#| export
class River:
    """A branching river network over a hex terrain.

    Attributes:
        terrain: the terrain the river runs over (read for ``elevations`` and
            ``lowest_neighbor``).
        tree: ``treelib.Tree`` of segments. Each node's ``data`` is a list of hex
            indices ordered outlet-first (index 0 is the downstream end); child
            segments join their parent at the parent's last hex.
        hexes: set of hex indices belonging to the river, for fast intersection checks.
        ocean_outlet: the water hex (elevation <= 0) the river drains into, or None.
    """

    def __init__(self, terrain):
        self.terrain = terrain
        self.tree = Tree()
        self.hexes = set()  # All hex indices for quick intersection checks
        self.ocean_outlet = None  # Set by from_peak when the trace reaches water

    @property
    def is_empty(self):
        """True when the river's tree has no segments."""
        return self.tree.size() == 0

    @classmethod
    def from_peak(cls, terrain, peak_index):
        """Create a river by tracing downhill from a peak.

        Repeatedly steps to ``terrain.lowest_neighbor(current)``. The trace stops when
        that returns None (a local minimum), when the step would revisit a hex already
        on the path, or when it reaches a hex with elevation <= 0. In the last case the
        water hex is kept as the final hex of the path and as ``ocean_outlet``.

        Args:
            terrain: terrain providing ``lowest_neighbor`` and ``elevations``.
            peak_index: hex index to start from.

        Returns:
            A single-segment River whose segment is ordered outlet-first, or None if the
            trace produced fewer than two hexes.
        """
        path = [peak_index]
        visited = {peak_index}
        current = peak_index
        ocean_outlet = None

        while True:
            lowest = terrain.lowest_neighbor(current)
            # Revisit guard: prevents an endless loop should lowest_neighbor ever return
            # a neighbour that is not strictly lower (e.g. on flat ground).
            if lowest is None or lowest in visited:
                break
            path.append(lowest)
            visited.add(lowest)
            if terrain.elevations[lowest] <= 0:
                ocean_outlet = lowest  # keep the water hex as the river mouth
                break
            current = lowest

        if len(path) < 2:
            return None

        path.reverse()  # outlet first, so the root segment starts at the mouth
        river = cls(terrain)
        river.tree.create_node(tag="segment", identifier=0, data=path)
        river.hexes.update(path)
        river.ocean_outlet = ocean_outlet
        return river

    @staticmethod
    def combine_rivers(rivers):
        """Merge intersecting rivers until no two share a hex.

        Repeatedly finds the first pair (i, j), i < j, whose ``hexes`` overlap. It
        replaces river i with ``result[i]._merge_with(result[j], overlap)``, drops river
        j, and restarts the scan. Rivers whose tree ends up empty are removed.

        Args:
            rivers: iterable of River objects (not modified).

        Returns:
            A new list of non-empty, pairwise non-intersecting rivers.
        """
        result = list(rivers)

        changed = True
        while changed:
            changed = False
            for i in range(len(result)):
                for j in range(i + 1, len(result)):
                    intersection = result[i].hexes & result[j].hexes
                    if intersection:
                        # Fold river j into river i, then rescan from the start.
                        result[i] = result[i]._merge_with(result[j], intersection)
                        result.pop(j)
                        changed = True
                        break
                if changed:
                    break

        return [r for r in result if not r.is_empty]

    def _merge_with(self, other, intersection):
        """Minimal merge: copy this river's tree and union both hex sets.

        Placeholder only. It does not graft ``other``'s branches; this notebook later
        replaces ``River._merge_with`` with a full implementation via ``@patch``.
        """
        merged = River(self.terrain)
        # Keep this river's segments so combine_rivers does not drop the result as empty.
        merged.tree = Tree(self.tree, deep=True)
        merged.hexes = self.hexes | other.hexes
        # Hex index 0 is a valid outlet, so test for None rather than truthiness.
        merged.ocean_outlet = (self.ocean_outlet if self.ocean_outlet is not None
                               else other.ocean_outlet)
        return merged


In [None]:
#| export
# In River class

@patch
def show(self:River):
    """Print the segment tree, one segment per line.

    Each line lists the segment's hex indices separated by ", ". Lines are indented
    four spaces per level below the root. Segments are printed in pre-order, with
    children in tree order. Prints "Empty river" when there are no segments.
    """
    if self.is_empty:
        print("Empty river")
        return

    pending = [(self.tree.get_node(self.tree.root), 0)]
    while pending:
        node, level = pending.pop()
        print("    " * level + ", ".join(str(h) for h in node.data))
        # Push children in reverse so the first child is printed next.
        pending.extend((child, level + 1)
                       for child in reversed(self.tree.children(node.identifier)))



In [None]:
#| export
@patch
def midTerrain(self:RiverDemo):
    """Build the 480x480 volcano demo terrain, coloured and with its hex grid refreshed.

    The terrain has radius 15 and path "volcano.svg", with eleven volcanoes at hex
    indices 3 * base.
    """
    size = MapSize(480, 480)
    terrain = Terrain(MapRect(MapCord(0, 0), size), radius=15, path="volcano.svg")
    for base in (23, 46, 57, 81, 123, 145, 167, 189, 211, 233, 255):
        terrain.volcano(center=base * 3, adjusted=200, num_rings=8,
                        variability=0.5, initial_threshold=0.4)
    terrain.colorMap()
    terrain.hexGrid.update()
    return terrain

@patch
def demoRiver(self:RiverDemo):
    """Trace rivers from the 40 highest peaks (min height 1), merge them, and print each one."""
    terrain = self.midTerrain()
    traced = (River.from_peak(terrain, p) for p in terrain.find_peaks(40, 1))
    for river in River.combine_rivers([r for r in traced if r is not None]):
        print("====")
        river.show()
    



In [None]:
# Run the volcano-terrain demo: prints the segment tree of every merged river.
RiverDemo().demoRiver()

### Merging Rivers

In [None]:
#| export
@patch
def _next_node_id(self:River):
    """Return an integer identifier not yet used by any node of ``self.tree``."""
    ids = [node.identifier for node in self.tree.all_nodes()]
    return max(ids) + 1 if ids else 0


@patch
def _split_at(self:River, hex_idx):
    """Make ``hex_idx`` the last hex of its segment and return that segment's node id.

    Children join a parent at the parent's last hex. To attach a branch at a
    mid-segment hex, the segment is therefore split:
    - the downstream part ``[..hex_idx]`` stays in the node;
    - the upstream remainder becomes a new child that takes over the node's children.

    Returns None if no segment contains ``hex_idx``.
    """
    for node in self.tree.all_nodes():
        if hex_idx not in node.data:
            continue
        idx = node.data.index(hex_idx)
        if idx < len(node.data) - 1:
            upstream = node.data[idx + 1:]
            node.data = node.data[:idx + 1]
            old_children = self.tree.children(node.identifier)
            new_id = self._next_node_id()
            self.tree.create_node(tag="segment", identifier=new_id,
                                  parent=node.identifier, data=upstream)
            # Tributaries joined the old end of the segment, which now ends the new child.
            for child in old_children:
                self.tree.move_node(child.identifier, new_id)
        return node.identifier
    return None


@patch
def _graft(self:River, src_tree, src_id, parent_id, data=None):
    """Copy the subtree of ``src_tree`` rooted at ``src_id`` under ``parent_id``, with fresh ids.

    ``data``, if given, replaces the hex list of the copied top node; descendants are
    copied unchanged. Hex lists are copied, so ``src_tree`` is never aliased or mutated.
    A ``parent_id`` of None makes the copy the root, which is only valid on an empty tree.

    Returns:
        The id of the new top node.
    """
    src = src_tree.get_node(src_id)
    new_id = self._next_node_id()
    self.tree.create_node(tag=src.tag, identifier=new_id, parent=parent_id,
                          data=list(src.data if data is None else data))
    for child in src_tree.children(src_id):
        self._graft(src_tree, child.identifier, new_id)
    return new_id


@patch
def _merge_with(self:River, other, intersection):
    """Merge another river into this one at intersection point.

    Returns a new River; neither input is modified. The result starts as a deep copy
    of this river's tree. ``other``'s tree is then walked from its root:
    - In each of other's segments, the junction is the last hex that is both in
      ``intersection`` and on this river's tree. The hexes upstream of it are
      attached there as a new branch; a segment is split if the junction is
      mid-segment.
    - Segments of other that share no hex are attached below the branch they
      flow into, so every tributary is kept.
    - Other's hexes downstream of a junction are assumed to coincide with this
      river and are not added to the tree again.

    Args:
        other: River to fold into this one.
        intersection: hex indices shared by ``self.hexes`` and ``other.hexes``.

    Returns:
        The merged River. Its ``hexes`` is the union of both hex sets, so it may hold
        hexes of ``other`` that are on no tree segment. Its ``ocean_outlet`` is this
        river's, falling back to other's.
    """
    merged = River(self.terrain)
    merged.tree = Tree(self.tree, deep=True)
    merged.hexes = self.hexes | other.hexes
    # Hex index 0 is a valid outlet, so test for None rather than truthiness.
    merged.ocean_outlet = (self.ocean_outlet if self.ocean_outlet is not None
                           else other.ocean_outlet)

    if other.is_empty:
        return merged
    if merged.is_empty:
        # Nothing to graft onto: the result is simply a copy of other's network.
        merged.tree = Tree(other.tree, deep=True)
        return merged

    # Junctions must lie on an actual segment of this river, not merely in `hexes`.
    tree_hexes = {h for node in merged.tree.all_nodes() for h in node.data}
    junctions = set(intersection) & tree_hexes

    def add_segment(tag, hexes, parent_id):
        # New segment with a fresh id and a private copy of `hexes`.
        new_id = merged._next_node_id()
        merged.tree.create_node(tag=tag, identifier=new_id, parent=parent_id,
                                data=list(hexes))
        return new_id

    def resolve(anchor):
        # ("node", id): a segment we added. ("hex", h): a shared hex, split lazily
        # only once something actually attaches to it.
        kind, value = anchor
        return merged._split_at(value) if kind == "hex" else value

    def walk(node_id, anchor):
        # `anchor` is where this segment's downstream end joins `merged` (None if nowhere).
        node = other.tree.get_node(node_id)
        data = node.data
        shared = [i for i, h in enumerate(data) if h in junctions]
        if shared:
            k = shared[-1]
            upstream = data[k + 1:]
            if upstream:
                end = ("node", add_segment(node.tag, upstream, merged._split_at(data[k])))
            else:
                end = ("hex", data[-1])
        elif data and anchor is not None:
            end = ("node", add_segment(node.tag, data, resolve(anchor)))
        elif data:
            end = None  # flows into nothing this river holds: cannot be attached
        else:
            end = anchor  # empty segment: pass the attachment point through
        for child in other.tree.children(node_id):
            walk(child.identifier, end)

    walk(other.tree.root, None)
    return merged


@patch
def _trim_tree(self:River, other, intersection_hex):
    """Return a copy of the part of other's tree strictly upstream of ``intersection_hex``.

    - If the hex is mid-segment, the copy is rooted at that segment's upstream
      remainder and includes all of its tributaries.
    - If the hex ends a segment, the copy is the first tributary joining there. A
      single Tree cannot hold several roots; ``_merge_with`` does not use this helper
      and keeps them all.

    Returns:
        The copied Tree, an empty Tree when nothing lies upstream, or None when the hex
        is on no segment of ``other``.
    """
    for node in other.tree.all_nodes():
        if intersection_hex not in node.data:
            continue
        idx = node.data.index(intersection_hex)
        scratch = River(self.terrain)
        if idx < len(node.data) - 1:
            scratch._graft(other.tree, node.identifier, None, node.data[idx + 1:])
        else:
            children = other.tree.children(node.identifier)
            if children:
                scratch._graft(other.tree, children[0].identifier, None)
        return scratch.tree
    return None


@patch
def _attach_tree(self:River, other_tree, parent_id):
    """Attach a copy of ``other_tree`` (from its root) as a child branch of node ``parent_id``."""
    if other_tree.size() == 0:
        return
    self._graft(other_tree, other_tree.root, parent_id)


In [None]:
#| export
@patch
def showIndices(self:Terrain):
    """Return the hex indices laid out as a grid, one text row per hex row.

    Each index is right-aligned in 3 characters followed by a space, and each row
    ends with a newline. Indices count row-major from 0.
    """
    cols = self.hexGrid.nCols
    rows = []
    for r in range(self.hexGrid.nRows):
        rows.append("".join(f"{r * cols + c:3} " for c in range(cols)) + "\n")
    return "".join(rows)


In [None]:
#| export


@patch
def demoRiverMerge(self:RiverDemo):
    """Merge three hand-built rivers on the tiny demo terrain; print them before and after."""
    sampleMap = TerraDemo().tiny()
    print(sampleMap.showIndices())
    print()

    # Each river is a single segment listed outlet-first.
    rivers = []
    for segment in ([4, 5, 6], [5, 9, 14], [14, 11]):
        river = River(sampleMap)
        river.tree.create_node(tag="segment", identifier=0, data=segment)
        river.hexes = set(segment)
        rivers.append(river)

    print("Before merge:")
    for number, river in enumerate(rivers, start=1):
        print("River 1:" if number == 1 else f"\nRiver {number}:")
        river.show()

    merged = River.combine_rivers(rivers)

    print("\n\nAfter merge:")
    for number, river in enumerate(merged, start=1):
        print(f"\nMerged River {number}:")
        river.show()
    
    




In [None]:
# Run the hand-built merge demo; its captured output follows this cell.
RiverDemo().demoRiverMerge()

  0   1   2   3 
  4   5   6   7 
  8   9  10  11 
 12  13  14  15 


Before merge:
River 1:
4, 5, 6

River 2:
5, 9, 14

River 3:
14, 11


After merge:

Merged River 1:
4, 5
    6
    9, 14
        11


### Working on Flow

Let's draw the rivers.

In [None]:
#| export
@patch
def svg(self: River, styles=(StyleCSS("river", fill="none", stroke="blue", stroke_width=4),)) -> str:
    """Render the river as SVG spline paths, one per tree segment.

    - Each segment is drawn through its hex centres, made windy with
      ``MapPath.make_windy``.
    - A child segment is prefixed with the centre of its parent's last hex, where it
      joins.
    - For the root segment, when its first (outlet) hex is ``ocean_outlet``, the ocean
      hex centre is replaced by the midpoint between the ocean hex and the first land
      hex. The river therefore ends on the coastline.

    Args:
        styles: sequence of styles indexed by tree depth (root = 0). Depths beyond the
            end reuse the last style.

    Returns:
        SVG markup: the concatenated ``drawSpline()`` output of every segment.
    """
    ret = ""
    grid_hexes = self.terrain.hexGrid.hexes
    root_id = self.tree.root

    for i, node in enumerate(self.tree.all_nodes()):
        segment = list(node.data)
        points = []

        parent_id = node.predecessor(self.tree.identifier)
        if parent_id is not None:
            # Children join their parent at the parent's last (most upstream) hex.
            parent_node = self.tree.get_node(parent_id)
            if parent_node.data:
                points.append(grid_hexes[parent_node.data[-1]].center)
        elif (node.identifier == root_id and self.ocean_outlet is not None
              and len(segment) > 1 and segment[0] == self.ocean_outlet):
            # Root is stored outlet-first: end the river at the coast edge, not the ocean centre.
            land = grid_hexes[segment[1]].center
            ocean = grid_hexes[segment[0]].center
            points.append(MapCord((land.x + ocean.x) / 2, (land.y + ocean.y) / 2))
            segment = segment[1:]

        points.extend(grid_hexes[idx].center for idx in segment)
        if not points:
            continue  # empty segment with nothing to connect to

        # Style chosen by tree depth: the root uses styles[0].
        depth = self.tree.depth(node.identifier)
        style = styles[min(depth, len(styles) - 1)]

        aPath = MapPath([], style=style)
        aPath.points = points
        windy_river = aPath.make_windy(iterations=max(5 - i, 2), offset_factor=0.2)
        ret += windy_river.drawSpline()

    return ret


In [None]:


@patch
def demoSanRivers(self:RiverDemo):
    """Render merged peak-traced rivers over a 0.25-scale Pompeii terrain.

    Rivers are traced from ``find_peaks(35, 7)``, merged with ``River.combine_rivers``
    and drawn with one style per seaborn "muted" colour.

    Returns:
        Whatever ``hexGrid.builder.show()`` returns.
    """
    # NOTE(review): riverStyle is never used below.
    riverStyle = StyleCSS("river",fill="none",stroke="blue",stroke_width=4)
    sampleMap = TerraDemo().pompeii_map().shrinkWeather(0.25)
    
    sampleMap.hexGrid.adjustRadius(20)
    aRender = sampleMap.hexGrid.builder
   
    peaks = sampleMap.find_peaks(35,7)

    rivers = [River.from_peak(sampleMap, peak) for peak in peaks]
    rivers = [r for r in rivers if r is not None]  # Filter out None rivers
    colors = []
    # Stroke width grows with the style index. River.svg picks the style by tree
    # depth, so deeper (tributary) segments are drawn thicker than the root.
    for i, c  in enumerate(StyleCSS.seaborn("muted")):
        aStyle = StyleCSS(f"river{i}",fill="none",stroke=c.properties["fill"] ,stroke_width=(i+1)*4)
        colors.append(aStyle)
        aRender.add_style(aStyle)
    
        
    #print(StyleCSS.generate(colors))
    # Merge them
    merged = River.combine_rivers(rivers)
    rivSVG = ""
    for i, v in enumerate(merged):
        rivSVG += v.svg(styles = colors)
        
    sampleMap.colorMap()
    sampleMap.hexGrid.update()
    sgrid = sampleMap.hexGrid
    # Redraw the "regions" layer with wavy hex edges.
    sgrid.builder.adjust("regions", sgrid.styleLayer(f=windy_edge(iterations=2, offset_factor=0.1)))


    aRender.adjust("rivers", rivSVG)
    #aRender.adjust("root","")
    
    return aRender.show()
    


In [None]:
# Render the Pompeii river map.
RiverDemo().demoSanRivers()

## Erode

In [None]:
#| export
@patch
def getLargestRiver(self:Terrain, num_peaks=40, min_height=1):
    """Trace rivers from peaks, merge them, and return the one covering the most hexes.

    Args:
        num_peaks: passed to ``find_peaks`` as the first argument.
        min_height: passed to ``find_peaks`` as the second argument.

    Returns:
        The merged River with the largest ``hexes`` set, or None if no river was traced.
    """
    traced = [River.from_peak(self, p) for p in self.find_peaks(num_peaks, min_height)]
    merged = River.combine_rivers([r for r in traced if r is not None])
    return max(merged, key=lambda r: len(r.hexes)) if merged else None


In [None]:
#| export
@patch
def _calculate_flow(self:River):
    """Calculate accumulated flow for each hex in the river.

    Every hex contributes one unit. A hex's flow is its own unit plus everything that
    drains into it:
    - inside a segment, the flow of the next hex upstream (segments are outlet-first);
    - at a segment's last hex, the outflow of every child segment joining there.

    Returns:
        dict mapping hex_index -> accumulated flow count (empty for an empty river).
    """
    flow = {}
    if self.is_empty:
        return flow

    def traverse(node):
        # Depth-first: fill `flow` for this subtree and return the flow leaving the
        # segment at its downstream end (data[0]). An empty segment just passes its
        # tributaries' flow through.
        inflow = sum(traverse(child) for child in self.tree.children(node.identifier))
        for hex_idx in reversed(node.data):  # upstream end -> downstream end
            inflow += 1
            flow[hex_idx] = inflow
        return inflow

    traverse(self.tree.get_node(self.tree.root))
    return flow


In [None]:
@patch
def demoFlow(self:RiverDemo):
    """Render merged rivers on a 0.75-scale Bay Area terrain with a flow-density overlay.

    Per hex, the overlay value is the largest ``_calculate_flow`` count of any merged
    river containing it, or -1 for hexes on no river.

    Returns:
        Whatever ``hexGrid.builder.show()`` returns.
    """

    fills=["#d4ff00eb","#ffb300ff","#ff0073ff","#9900ff97","#1e0e45eb"]
    # NOTE(review): riverStyle is never used below.
    riverStyle = StyleCSS("river",fill="none",stroke="blue",stroke_width=4)

    terrain = TerraDemo().bayArea_map()
    terrain.compute_weather()
    smaller = terrain.shrinkWeather(0.75)
    smaller.colorMap()
    smaller.hexGrid.update()
    sampleMap = smaller
    
    sampleMap.hexGrid.adjustRadius(15)
    aRender = sampleMap.hexGrid.builder
   
    peaks = sampleMap.find_peaks(35,7)

    rivers = [River.from_peak(sampleMap, peak) for peak in peaks]
    rivers = [r for r in rivers if r is not None]  # Filter out None rivers
    colors = []
    legends = []
    # For each fill: a stroke style for river segments and a filled "Level_n" legend style.
    for i, c  in enumerate(fills):
        aStyle = StyleCSS(f"river{i}",fill="none",stroke=c ,stroke_width=4)
        colors.append(aStyle)
        aRender.add_style(aStyle)

        aStyle = StyleCSS(f"Level_{i+1}",fill=c,stroke="#000000" ,stroke_width=2)
        legends.append(aStyle)
        aRender.add_style(aStyle)
    
    #print(StyleCSS.generate(colors))
    # Merge them
    merged = River.combine_rivers(rivers)
    rivSVG = ""
    

    # -1 marks hexes that are on no river.
    flowData = np.zeros(len(sampleMap.hexGrid.hexes)) - 1

    for i, stream in enumerate(merged):
        rivSVG += stream.svg(styles = colors)
        flows = stream._calculate_flow()
        streamFlow = np.zeros(len(sampleMap.hexGrid.hexes)) - 1
        for k,v in flows.items():
            streamFlow[k] = v
        # Keep the strongest flow seen for each hex across all merged rivers.
        flowData = np.maximum(flowData,streamFlow)

    sampleMap.colorMap()
    sampleMap.hexGrid.update()
    
    flowData = [int(x) for x in flowData]
    # Create patterns and overlay
    patternGen = TerrainPatterns(sampleMap)
    # Level count is max(flowData), not fixed. NOTE(review): fills has 5 colours;
    # confirm how ballDensity maps more levels than fills.
    patterns = patternGen.ballDensity(max(flowData),fills=fills)  # 5 levels
    sampleMap.makeOverlay(flowData, patterns)
    
    sampleMap.addCoast()
    sampleMap.colorMap()
    sampleMap.hexGrid.update()

    aRender.adjust("rivers", rivSVG)
    aRender.adjust("root","")
    aRender.adjust("legend",aRender.legendOverlay(legends))
    

    return aRender.show()


In [None]:

    

# Render the Bay Area flow-overlay map.
RiverDemo().demoFlow()

In [None]:
#| export
@patch
def _calculate_gradient(self:River):
    """Calculate elevation gradient for each hex in the river.

    A hex's gradient is ``max(0, elev[hex] - elev[downstream])``. The downstream hex
    is the previous hex in the segment, or the parent's last hex for a segment's first
    hex. The root segment's first hex (the outlet) gets 0. Segments are visited in
    pre-order.

    Returns:
        dict mapping hex_index -> non-negative elevation drop.
    """
    gradient = {}
    elevations = self.terrain.elevations
    stack = [self.tree.get_node(self.tree.root)]

    while stack:
        node = stack.pop()
        parent_id = node.predecessor(self.tree.identifier)
        for pos, hex_idx in enumerate(node.data):
            if pos:
                downstream = node.data[pos - 1]
            elif parent_id is not None:
                downstream = self.tree.get_node(parent_id).data[-1]
            else:
                gradient[hex_idx] = 0  # outlet: nothing further downstream
                continue
            gradient[hex_idx] = max(0, elevations[hex_idx] - elevations[downstream])
        # Reverse so the first child is processed next (same order as recursion).
        stack.extend(reversed(self.tree.children(node.identifier)))

    return gradient


#| export
### Terrain



In [None]:
#| export

#find_river_sources() - source detection
@patch
def find_river_sources(self: Terrain, 
                       min_precipitation=500,   # mm/year
                       min_elevation=500,       # meters
                       top_n=20,                # rows shown in the debug table
                       debug = False):          # print a ranked summary table
    """
    Find potential river source locations based on precipitation and elevation.

    A hex qualifies when it is land (elevation > 0), its precipitation is
    >= min_precipitation and its elevation is >= min_elevation. Its score is
    ``precipitation * (elevation / 1000)``, so wetter and higher hexes rank first.

    Args:
        min_precipitation: minimum precipitation (commented as mm/year).
        min_elevation: minimum elevation (commented as meters).
        top_n: number of rows printed in the debug table. It does NOT limit the
            returned list.
        debug: when True, print a summary and a ranked table of the top_n sources.

    Returns:
        List of (hex_index, precipitation, elevation, score) tuples for every
        qualifying hex, sorted by score descending.

    Raises:
        ValueError: if 'precipitation' is not in ``self.fields``.
    """
    if 'precipitation' not in self.fields:
        raise ValueError("Must compute precipitation first (compute_precipitation_sb)")

    precipitation = self.fields['precipitation']
    sources = []

    for i, elev in enumerate(self.elevations):
        precip = precipitation[i]

        # Skip ocean (matters when min_elevation <= 0)
        if elev <= 0:
            continue

        if precip >= min_precipitation and elev >= min_elevation:
            score = precip * (elev / 1000.0)  # Normalize elevation
            sources.append((i, precip, elev, score))

    sources.sort(key=lambda s: s[3], reverse=True)

    if debug:
        print(f"\n=== RIVER SOURCE ANALYSIS ===")
        print(f"Found {len(sources)} potential sources (precip >= {min_precipitation}mm, elev >= {min_elevation}m)")

        if sources:
            print(f"\nTop {min(top_n, len(sources))} sources:")
            print(f"{'Rank':<6} {'Hex':<8} {'Precip':<12} {'Elev':<10} {'Score':<10}")
            print("-" * 50)
            for rank, (idx, precip, elev, score) in enumerate(sources[:top_n], 1):
                print(f"{rank:<6} {idx:<8} {precip:<12.0f} {elev:<10.0f} {score:<10.1f}")

    return sources

In [None]:
@patch
def demoRiverSources(self:RiverDemo):
    """Find river sources on a 0.75-scale Bay Area terrain and mark the top 15 on the map.

    Prints the debug table from ``find_river_sources``. Each of the top 15 sources is
    drawn as a red circle labelled with its rank.

    Returns:
        Whatever ``hexGrid.builder.show()`` returns.
    """
    terrain = TerraDemo().bayArea_map()
    terrain.compute_weather()
    smaller = terrain.shrinkWeather(0.75)
    
    # Find sources with debug output
    sources = smaller.find_river_sources(
        min_precipitation=400,
        min_elevation=300,
        top_n=15,
        debug=True
    )
    
    # Visualize the top sources on the map
    smaller.colorMap()
    smaller.hexGrid.update()
    smaller.hexGrid.adjustRadius(15)
    aRender = smaller.hexGrid.builder
    
    # Mark top sources with circles
    sourceStyle = StyleCSS("source", fill="#ff0000", stroke="#000000", stroke_width=2)
    aRender.add_style(sourceStyle)
    
    markers = ""
    # find_river_sources returns every qualifying source, so slice to the top 15 here.
    for i, (idx, precip, elev, score) in enumerate(sources[:15]):
        center = smaller.hexGrid.hexes[idx].center
        markers += f'<circle cx="{center.x}" cy="{center.y}" r="8" class="source"/>'
        # Add rank number
        markers += f'<text x="{center.x}" y="{center.y+4}" text-anchor="middle" font-size="10" fill="white">{i+1}</text>'
    
    aRender.adjust("markers", markers)
    
    return aRender.show()


In [None]:
# Print the source ranking and render the marked map.
RiverDemo().demoRiverSources()


=== RIVER SOURCE ANALYSIS ===
Found 106 potential sources (precip >= 400mm, elev >= 300m)

Top 15 sources:
Rank   Hex      Precip       Elev       Score     
--------------------------------------------------
1      288      3142         1000       3141.8    
2      306      3077         1000       3077.0    
3      270      2813         827        2326.1    
4      324      2836         801        2270.4    
5      289      2401         891        2138.6    
6      271      2366         821        1943.0    
7      325      2190         644        1410.3    
8      252      2695         518        1395.8    
9      72       2242         622        1394.3    
10     90       2228         589        1311.5    
11     54       2096         622        1303.6    
12     342      2407         530        1277.0    
13     925      1795         706        1267.9    
14     907      1709         706        1207.4    
15     307      1747         691        1207.4    


## Terraform

In [None]:
#| export
@patch
def carve(self:River, base_erosion=0.5, valley_width=3, lower = 40, shape='river'):
    """Carve a valley along this river/glacier path.

    Side effect: every river hex's elevation in ``self.terrain.elevations`` is lowered
    by ``lower``, clamped at 0.

    Returns:
        A dict hex_index -> negative elevation adjustment. These are NOT applied here.
        - River hexes are visited from highest to lowest elevation. Each gets
          ``-base_erosion * (cumulative_flow * max(gradient, 0.01))``. Here
          cumulative_flow is the hex's 1-based rank in that order, and gradient is the
          drop to the next hex in that order (0.1 for the last one).
        - Every river hex always receives its own value.
        - Non-river hexes within ``valley_width`` rings get a decayed share from the
          first river hex that reaches them. The share halves per ring for
          shape='river' and falls off linearly otherwise (glacier).

    NOTE(review): flow and gradient follow global elevation order, not the tree's
    downstream links; ``_calculate_flow`` / ``_calculate_gradient`` give per-branch
    values.
    """
    elevations = self.terrain.elevations
    river_hexes = self.hexes

    # 1. Sort by elevation (highest first = upstream to downstream), then lower the channel.
    sorted_hexes = sorted(river_hexes, key=lambda i: elevations[i], reverse=True)
    for hex_idx in sorted_hexes:
        elevations[hex_idx] = max(elevations[hex_idx] - lower, 0)

    adjustments = {}
    cumulative_flow = 0
    last = len(sorted_hexes) - 1

    for i, hex_idx in enumerate(sorted_hexes):
        # 2. Stream power on the fly: each hex adds one unit of flow.
        cumulative_flow += 1
        if i < last:
            gradient = elevations[hex_idx] - elevations[sorted_hexes[i + 1]]
        else:
            gradient = 0.1  # Small default at mouth
        erosion_amount = base_erosion * (cumulative_flow * max(gradient, 0.01))

        # River hexes always get their own erosion; an earlier hex's valley spread
        # must not pre-empt it.
        adjustments[hex_idx] = -erosion_amount

        # 3. Spread to non-river neighbours that have not been adjusted yet.
        for dist in range(1, valley_width + 1):
            if shape == 'river':
                decay = erosion_amount * (0.5 ** dist)
            else:  # glacier
                decay = erosion_amount * (1 - dist / (valley_width + 1))
            for n in self.terrain.ring(hex_idx, dist):
                if n in river_hexes or n in adjustments:
                    continue
                adjustments[n] = -decay

    return adjustments


In [None]:
#| export


def find_local_minima(self: Terrain) -> list[int]:
    """Return indices of land hexes (elevation > 0) whose flow direction is -1.

    These are potential lakes. Called as a plain function (not a patched method),
    e.g. ``find_local_minima(terrain)``.
    """
    flow_dirs = self.flow_directions()
    minima = []
    for idx, elev in enumerate(self.elevations):
        if flow_dirs[idx] == -1 and elev > 0:
            minima.append(idx)
    return minima


@patch
def find_drainage_path(self: Terrain, start: int) -> list[int]:
    """Find lowest-cost path from start to ocean using Dijkstra.

    Stepping from hex a to neighbour b costs ``max(0, elev[b] - elev[a] + 1)``: how far
    b would have to be carved to sit strictly below a. Ocean hexes have elevation <= 0,
    and negative neighbour indices are ignored.

    Returns:
        The hex path from ``start`` to the first ocean hex settled; ``[start]`` if start
        is already ocean; ``[]`` if no ocean hex is reachable.
    """
    elevations = self.elevations
    ocean = {idx for idx, e in enumerate(elevations) if e <= 0}
    if start in ocean:
        return [start]

    # Heap entries: (cost so far, hex, path to hex)
    frontier = [(0, start, [start])]
    settled = set()

    while frontier:
        cost, here, route = heapq.heappop(frontier)
        if here in settled:
            continue
        settled.add(here)

        if here in ocean:
            return route

        here_elev = elevations[here]
        for nb in self.hexGrid.neighborsOf(here):
            if nb < 0 or nb in settled:
                continue
            # To flow, the neighbour must end up below the current hex.
            step = max(0, elevations[nb] - here_elev + 1)
            heapq.heappush(frontier, (cost + step, nb, route + [nb]))

    return []  # No path found




In [None]:
#| export
@patch
def carve_to_ocean(self: Terrain, num_lakes: int = 5, max_iters: int = 10) ->  list[River]:
    """Carve drainage using lowest-cost paths to ocean.

    On each iteration:
    - Find the local minima; stop once at most ``num_lakes`` remain.
    - Keep the ``num_lakes`` highest minima as lakes.
    - For every other minimum, find its drainage path with ``find_drainage_path`` and
      lower hexes along it (lake -> ocean) so elevation strictly decreases.

    Args:
        num_lakes: number of (highest) local minima left undrained.
        max_iters: maximum number of drain passes.

    Returns:
        One single-segment River per carved path. Segments are stored outlet-first,
        like ``River.from_peak``. Paths shorter than two hexes produce no river.
    """
    rivers = []

    for iteration in range(max_iters):
        minima = find_local_minima(self)
        if len(minima) <= num_lakes:
            print(f"Done at iter {iteration}: {len(minima)} lakes")
            break

        # Keep highest elevation minima as lakes
        minima.sort(key=lambda i: self.elevations[i], reverse=True)

        for lake_idx in minima[num_lakes:]:
            path = self.find_drainage_path(lake_idx)  # lake -> ocean
            if len(path) < 2:
                continue  # no route found (or nothing to carve)

            river = River(terrain=self)
            # Tree segments are outlet-first; keep `path` itself in lake -> ocean order for carving.
            river.tree.create_node(tag="segment", identifier=0, data=path[::-1])
            river.hexes.update(path)
            rivers.append(river)

            # Carve monotonically decreasing along path
            for curr, next_hex in zip(path, path[1:]):
                if self.elevations[next_hex] >= self.elevations[curr]:
                    self.elevations[next_hex] = self.elevations[curr] - 1

    return rivers

In [None]:
@patch
def createCarve(self:RiverDemo,scale = 3.0):
    """Drain a 0.25-scale Pompeii terrain to the ocean and merge the resulting rivers.

    Prints the number of carved rivers before and after merging.

    Args:
        scale: currently unused (elevation scaling is disabled). It is kept so existing
            callers keep working.

    Returns:
        (terrain, merged_rivers): the carved Terrain and the list returned by
        ``River.combine_rivers``.
    """
    # Only the 0.25-scale map is used, so no other map is built.
    sampleMap = TerraDemo().pompeii_map().shrinkWeather(0.25)
    rivers = sampleMap.carve_to_ocean(2)  # keep the 2 highest local minima as lakes
    print(f"intial rivers {len(rivers)}")

    merged = River.combine_rivers(rivers)
    print(f"merged rivers {len(merged)}")

    return sampleMap, merged

In [None]:
@patch
def showCarve(self:RiverDemo,sampleMap:Terrain,rivers):
    """Render ``rivers`` over ``sampleMap`` with a flow diagram and a legend.

    Args:
        sampleMap: terrain to draw; its hex grid is resized to radius 15 in place.
        rivers: iterable of River objects to draw via ``River.svg``.

    Returns:
        Whatever ``sampleMap.hexGrid.builder.show()`` returns.
    """

    #print(f"working with {len(rivers)} number of rivers")

    fills=["#d4ff00eb","#ffb300ff","#ff0073ff","#9900ff97","#1e0e45eb"]
    # NOTE(review): riverStyle is never used below.
    riverStyle = StyleCSS("river",fill="none",stroke="blue",stroke_width=4)
    
    
    sampleMap.hexGrid.adjustRadius(15)
    aRender = sampleMap.hexGrid.builder
    rivSVG = ""
 
    colors = []
    legends = []
    # For each fill: a stroke style for river segments and a filled "Level_n" legend style.
    for i, c  in enumerate(fills):
        aStyle = StyleCSS(f"river{i}",fill="none",stroke=c ,stroke_width=4)
        colors.append(aStyle)
        aRender.add_style(aStyle)

        aStyle = StyleCSS(f"Level_{i+1}",fill=c,stroke="#000000" ,stroke_width=2)
        legends.append(aStyle)
        aRender.add_style(aStyle)
 
    for i, stream in enumerate(rivers):
        #print(f"on river {i}")
        rivSVG += stream.svg(styles = colors)

    # Clears the builder's layer list before recolouring. NOTE(review): presumably
    # discards layers left over from a previous render of this map; confirm.
    aRender.layers = []
    sampleMap.colorMap()
    sampleMap.hexGrid.update()
    sampleMap.addCoast()
    aRender.adjust("flows",sampleMap.flow_diagram())
    aRender.adjust("rivers", rivSVG)
    aRender.adjust("legend",aRender.legendOverlay(legends))
    
    #return rivSVG
    return  sampleMap.hexGrid.builder.show()

In [None]:
# Build the carved terrain and its merged drainage rivers, then render them.
rivMap, rivRivers = RiverDemo().createCarve()
print(len(rivRivers))
RiverDemo().showCarve(rivMap,rivRivers)

Done at iter 1: 2 lakes
intial rivers 2
merged rivers 1
1


In [None]:
# carve() lowers each river hex of rivMap in place by `lower`; the returned
# adjustment dict is discarded here, so only that in-place lowering takes effect.
for i, riv in enumerate(rivRivers):
    print(f" carving river {i} ")
    riv.carve()

 carving river 0 


In [None]:
# Re-render after carving to compare with the pre-carve view.
RiverDemo().showCarve(rivMap,rivRivers)

Let's think about carving valleys.
From `carve_to_ocean` we have a list of drainage rivers, and we have soils.
We need to compute:
1. fluvial erosion using our soils
2. knickpoint migration
3. fluvial deposition
4. updates to our soil