# River
runs through

In [None]:
#| default_exp river

### Prior Art

In [None]:
#| export
#standard
import numpy as np
import sys
import os
import math
import random

#data
from collections import namedtuple
from dataclasses import dataclass,  field, asdict
from typing import List
from enum import Enum

#Jeremy
from dialoghelper import * 
from fastcore.basics import patch
from fasthtml.common import *
from fasthtml.jupyter import *
import httpx

# unique
from treelib import Tree
import heapq

In [None]:
#| export


from HexMagic.styles import StyleCSS, SVGBuilder,SVGLayer, SVGPatternLoader, preview, app, StyleDemo
from HexMagic.primitives import MapCord, MapSize, MapRect, MapPath, Hex, HexGrid, HexWrapper, HexPosition, hexBackground,windy_edge, HexRegion
from HexMagic.terrain import  TerraDemo, Terrain
from HexMagic.terrainpatterns import TerrainPatterns


#| export
### Helpers

In [None]:
#| export
class RiverDemo:
    """Container object onto which the notebook's river demo routines are patched."""

    def __init__(self):
        # Help text placeholder; starts out empty.
        self.help = str()

## River

The `River` class represents a river network as a tree structure using `treelib.Tree`.

**Core Data Structure**
- `tree`: A `treelib.Tree` where each node's `data` is a list of hex indices forming a continuous segment
- `hexes`: A flat `set` of all hex indices for quick intersection checks
- Root of tree = outlet (lowest point); branches go upstream toward sources

**Creating Rivers: `from_peak(terrain, peak_index)`**
1. Traces downhill from the peak using `terrain.lowest_neighbor()`
2. Stops when hitting water (elevation < 1) or a local minimum
3. Reverses the path so outlet is root
4. Creates a single-node tree with that path

**Merging Algorithm: `combine_rivers(rivers)`**
1. Loop until no changes occur
2. Check each pair of rivers for hex intersection
3. When found, call `_merge_with()` which:
   - Finds the intersection hex (picks highest elevation if multiple)
   - Locates which tree node contains it
   - Splits that node if needed (downstream stays, upstream becomes child)
   - Trims the other river's tree to just its upstream portion
   - Attaches it as a new branch at the intersection point

**Drawing: `River.svg(styles)`**
- Iterates through all tree nodes
- For each node, builds a list of center points from hex indices
- If the node has a parent, prepends the parent's last hex center (where the branch joins) for continuity
- Makes the path windy with `MapPath.make_windy()` and draws a spline through those points with `drawSpline()`; the style is chosen by the node's depth in the tree

The result: rivers that branch upstream from a single outlet, drawn as smooth curves through hex centers.

In [None]:
#| export
class River:
    """A river network: a ``treelib.Tree`` of hex-index segments plus a flat hex set.

    Attributes:
        terrain: Terrain the river lies on; ``from_peak`` reads its
            ``elevations`` and calls its ``lowest_neighbor``.
        tree: Each node's ``data`` is a list of hex indices. Segments built by
            ``from_peak`` run downstream -> upstream, so the root starts at the outlet.
        hexes: Set of all hex indices, used for quick intersection checks.
        ocean_outlet: Hex index the river drains into, or ``None`` (default).
    """
    def __init__(self, terrain):
        self.terrain = terrain
        self.tree = Tree()
        self.hexes = set()  # All hex indices for quick intersection checks
        self.ocean_outlet = None

    @property
    def is_empty(self):
        """True when the tree has no segments."""
        return self.tree.size() == 0

    @classmethod
    def from_peak(cls, terrain, peak_index):
        """Create a river by tracing downhill from a peak.

        Follows ``terrain.lowest_neighbor`` from ``peak_index``. The trace stops
        when there is no neighbor, the neighbor is water (elevation < 1), the
        neighbor is higher than the current hex, or the neighbor was already
        visited.

        Args:
            terrain: Terrain providing ``lowest_neighbor`` and ``elevations``.
            peak_index: Hex index to start from.

        Returns:
            A single-segment River whose path runs outlet -> peak, or ``None``
            if the trace never leaves the peak.
        """
        path = [peak_index]
        visited = {peak_index}
        current = peak_index

        while True:
            lowest = terrain.lowest_neighbor(current)
            # Stop if local minimum or hit water
            if lowest is None or terrain.elevations[lowest] < 1:
                break
            # Never climb and never revisit a hex: on flat ground or at a sink,
            # either would let the trace bounce between hexes forever.
            if lowest in visited or terrain.elevations[lowest] > terrain.elevations[current]:
                break
            path.append(lowest)
            visited.add(lowest)
            current = lowest

        # Don't create river if path is too short (just the peak itself)
        if len(path) < 2:
            return None

        # Reverse so the outlet is first: the root segment starts at the outlet.
        path.reverse()

        river = cls(terrain)
        river.tree.create_node(tag="segment", identifier=0, data=path)
        river.hexes.update(path)
        return river

    @staticmethod
    def combine_rivers(rivers):
        """Repeatedly merge any two rivers that share a hex.

        Pairs are scanned in index order; after each merge the scan restarts,
        so the loop ends only when no two remaining rivers intersect. The
        merged river replaces the lower-indexed one. Rivers whose tree ends up
        empty are dropped. The input sequence itself is not modified.

        Args:
            rivers: Iterable of River objects.

        Returns:
            List of non-intersecting, non-empty rivers.
        """
        def first_overlap(items):
            # (i, j, shared_hexes) for the first intersecting pair, else None.
            for i, first in enumerate(items):
                for j in range(i + 1, len(items)):
                    shared = first.hexes & items[j].hexes
                    if shared:
                        return i, j, shared
            return None

        result = list(rivers)
        while (hit := first_overlap(result)) is not None:
            i, j, shared = hit
            # Merge j into i
            result[i] = result[i]._merge_with(result[j], shared)
            result.pop(j)

        return [r for r in result if not r.is_empty]

    def _merge_with(self, other, intersection):
        """Placeholder merge that returns the union of both hex sets with an EMPTY tree.

        NOTE(review): this stub is replaced by the ``@patch``'d ``_merge_with``
        later in the notebook. On its own it makes ``combine_rivers`` drop every
        merged river, because the result's tree is empty.
        """
        merged = River(self.terrain)
        merged.hexes = self.hexes | other.hexes
        return merged


In [None]:
#| export
# In River class

@patch
def show(self:River):
    """Print the river tree, one line per segment.

    Each line lists the segment's hex indices separated by ", " and is
    indented four spaces per tree level. Segments are printed in pre-order
    (parent before children). An empty river prints "Empty river".
    """
    if self.is_empty:
        print("Empty river")
        return

    # Explicit stack instead of recursion; children are pushed reversed so
    # they are popped (and printed) in their original order.
    pending = [(self.tree.get_node(self.tree.root), 0)]
    while pending:
        segment, level = pending.pop()
        print("    " * level + ", ".join(str(h) for h in segment.data))
        kids = self.tree.children(segment.identifier)
        pending.extend((kid, level + 1) for kid in reversed(kids))



In [None]:
#| export
@patch
def midTerrain(self:RiverDemo):
    """Build and return a 480x480 terrain (hex radius 15) seeded with 11 volcanoes, colored and updated."""
    size = MapSize(480, 480)
    bounds = MapRect(MapCord(0, 0), size)
    terrain = Terrain(bounds, radius=15, path="volcano.svg")

    # Volcano centers are hex indices: three times each listed base value.
    base_points = (23, 46, 57, 81, 123, 145, 167, 189, 211, 233, 255)
    for hex_index in (3 * b for b in base_points):
        terrain.volcano(center=hex_index, adjusted=200, num_rings=8,
                        variability=0.5, initial_threshold=0.4)

    terrain.colorMap()
    terrain.hexGrid.update()
    return terrain

@patch
def demoRiver(self:RiverDemo):
    """Trace a river from each peak of ``midTerrain()``, merge intersecting ones, and print every result."""
    terrain = self.midTerrain()
    summits = terrain.find_peaks(40, 1)

    candidates = []
    for summit in summits:
        traced = River.from_peak(terrain, summit)
        if traced is not None:  # from_peak returns None for peaks with no downhill path
            candidates.append(traced)

    for network in River.combine_rivers(candidates):
        print("====")
        network.show()
    



In [None]:
# Run the demo: trace rivers on the volcano terrain, merge them, and print each river tree.
RiverDemo().demoRiver()

### Merging Rivers

In [None]:
#| export
def _next_segment_id(tree):
    """Return an integer node identifier one larger than any already in ``tree`` (0 if empty)."""
    return max((node.identifier for node in tree.all_nodes()), default=-1) + 1


@patch 
def _merge_with(self:River, other, intersection):
    """Merge another river into this one at an intersection point.

    ``self``'s tree is kept and ``other``'s segments upstream of the chosen
    intersection hex are grafted on as a new branch at that hex.

    Args:
        other: River to merge in.
        intersection: Non-empty collection of hex indices shared by both rivers.

    Returns:
        A new River. Its ``hexes`` is rebuilt from its tree, so it lists exactly
        the hexes that are drawn and analysed. ``self`` is not modified.

    NOTE(review): branches of ``other`` that join it *below* the intersection
    hex are not carried over.
    """
    merged = River(self.terrain)
    merged.tree = Tree(self.tree, deep=True)
    merged.ocean_outlet = self.ocean_outlet  # the root (and outlet) is self's

    # Use the highest shared hex. For rivers traced by from_peak, paths that
    # meet follow the same hexes downstream, so this is where other first joins.
    intersection_hex = max(intersection, key=lambda h: self.terrain.elevations[h])

    node = next((n for n in merged.tree.all_nodes() if intersection_hex in n.data), None)
    if node is not None:
        idx = node.data.index(intersection_hex)
        if idx < len(node.data) - 1:
            # Split so the intersection hex ends this segment; branches always
            # join a parent at the parent's last hex.
            upstream = node.data[idx + 1:]
            node.data = node.data[:idx + 1]
            old_children = [c.identifier for c in merged.tree.children(node.identifier)]
            new_id = _next_segment_id(merged.tree)
            merged.tree.create_node(tag="segment", identifier=new_id,
                                    parent=node.identifier, data=upstream)
            # Existing branches joined at the old segment end, which now lives in `upstream`.
            for child_id in old_children:
                merged.tree.move_node(child_id, new_id)

        trimmed_other = merged._trim_tree(other, intersection_hex)
        if trimmed_other:
            merged._attach_tree(trimmed_other, node.identifier)

    # Derive hexes from the tree so trimmed-away hexes of `other` can't cause
    # spurious intersections in later merges.
    merged.hexes = {h for seg in merged.tree.all_nodes() for h in seg.data}
    return merged


@patch 
def _trim_tree(self:River, other, intersection_hex):
    """Return a copy of the part of ``other``'s tree upstream of ``intersection_hex``.

    The copy is rooted at the segment of ``other`` that contains the hex. The
    root keeps only the hexes after it, which may be none when the hex ends the
    segment. All branches above that segment are kept. ``other``'s segment data
    and structure are left unchanged.

    Returns:
        A ``treelib.Tree``, or ``None`` if the hex is not in ``other``'s tree.
    """
    for node in other.tree.all_nodes():
        if intersection_hex in node.data:
            idx = node.data.index(intersection_hex)
            # subtree() shares Node objects with the source tree, so deep-copy before editing.
            trimmed = Tree(other.tree.subtree(node.identifier), deep=True)
            root = trimmed.get_node(trimmed.root)
            root.data = root.data[idx + 1:]
            return trimmed
    return None


@patch 
def _attach_tree(self:River, other_tree, parent_id):
    """Copy ``other_tree``'s segments into ``self.tree`` under node ``parent_id``.

    Copied nodes get fresh identifiers that cannot clash with existing ones.
    If ``other_tree``'s root has no hexes, the root itself is skipped and each
    of its children is attached directly under ``parent_id``. ``other_tree``
    is not modified.
    """
    if other_tree is None or other_tree.size() == 0:
        return

    next_id = _next_segment_id(self.tree)

    def copy_subtree(other_node, new_parent_id):
        nonlocal next_id
        new_node_id = next_id
        next_id += 1
        self.tree.create_node(
            tag=other_node.tag,
            identifier=new_node_id,
            parent=new_parent_id,
            data=other_node.data.copy()
        )
        for child in other_tree.children(other_node.identifier):
            copy_subtree(child, new_node_id)

    other_root = other_tree.get_node(other_tree.root)
    if other_root.data:
        copy_subtree(other_root, parent_id)
    else:
        # The root is an empty placeholder; graft each branch above it directly.
        for child in other_tree.children(other_root.identifier):
            copy_subtree(child, parent_id)


In [None]:
#| export
@patch
def showIndices(self:Terrain):
    """Return a text grid of hex indices, row by row.

    Indices count up from 0 across each row. Each index is right-aligned in
    3 characters and followed by a space, and each row ends with a newline.
    """
    width = self.hexGrid.nCols
    rows = []
    for r in range(self.hexGrid.nRows):
        first = r * width
        cells = "".join(f"{idx:3} " for idx in range(first, first + width))
        rows.append(cells + "\n")
    return "".join(rows)


In [None]:
#| export


@patch
def demoRiverMerge(self:RiverDemo):
    """Build three hand-made single-segment rivers on the tiny terrain, print them, merge them, and print the result."""
    terrain = TerraDemo().tiny()
    print(terrain.showIndices())
    print()

    def manual_river(indices):
        # One root segment holding `indices`, with a matching hex set.
        built = River(terrain)
        built.tree.create_node(tag="segment", identifier=0, data=list(indices))
        built.hexes = set(indices)
        return built

    river1, river2, river3 = (manual_river(seq) for seq in ([4, 5, 6], [5, 9, 14], [14, 11]))

    print("Before merge:")
    for heading, original in (("River 1:", river1), ("\nRiver 2:", river2), ("\nRiver 3:", river3)):
        print(heading)
        original.show()

    combined = River.combine_rivers([river1, river2, river3])

    print("\n\nAfter merge:")
    for number, result in enumerate(combined, start=1):
        print(f"\nMerged River {number}:")
        result.show()
    




In [None]:
# Run the manual merge demo: three small rivers on the tiny map, before and after combine_rivers.
RiverDemo().demoRiverMerge()

### Working on Flow

Let's draw the rivers.

In [None]:
#| export
@patch
def svg(self: River, styles=None) -> str:
    """Render the river as windy SVG splines, one per tree segment.

    Args:
        styles: List of StyleCSS objects indexed by the segment's depth in the
            tree (root = 0). Depths past the end of the list reuse the last
            style. Defaults to a single 4px blue "river" stroke.

    Returns:
        Concatenated path markup from ``drawSpline`` ("" for an empty river).
    """
    if not styles:
        # Built per call instead of as a shared mutable default argument.
        styles = [StyleCSS("river", fill="none", stroke="blue", stroke_width=4)]

    hexes = self.terrain.hexGrid.hexes
    tree_id = self.tree.identifier
    parts = []

    for i, node in enumerate(self.tree.all_nodes()):
        if not node.data:
            continue  # nothing to draw for an empty segment

        points = []

        # Branches join their parent at the parent's last hex, so start the
        # curve there to keep the network visually connected.
        parent_id = node.predecessor(tree_id)
        if parent_id is not None:
            parent_data = self.tree.get_node(parent_id).data
            if parent_data:
                points.append(hexes[parent_data[-1]].center)

        # Add all hex centers in this segment
        points.extend(hexes[idx].center for idx in node.data)

        # Only the root segment reaches the outlet. Extend it to the edge shared
        # with the ocean hex, from whichever end of the segment is nearer to it.
        if self.ocean_outlet is not None and node.identifier == self.tree.root:
            ocean = hexes[self.ocean_outlet].center
            first = hexes[node.data[0]].center
            last = hexes[node.data[-1]].center
            first_dist = math.hypot(first.x - ocean.x, first.y - ocean.y)
            last_dist = math.hypot(last.x - ocean.x, last.y - ocean.y)
            anchor = first if first_dist <= last_dist else last
            edge_point = MapCord((anchor.x + ocean.x) / 2, (anchor.y + ocean.y) / 2)
            if anchor is first and parent_id is None:
                points.insert(0, edge_point)
            else:
                points.append(edge_point)

        # Style chosen by tree depth, clamped to the last available style.
        depth = self.tree.depth(node.identifier)
        style = styles[min(depth, len(styles) - 1)]

        # Create path and make it windy
        a_path = MapPath([], style=style)
        a_path.points = points
        windy_river = a_path.make_windy(iterations=max(5 - i, 2), offset_factor=0.2)
        parts.append(windy_river.drawSpline())

    return "".join(parts)


In [None]:


@patch
def demoSanRivers(self:RiverDemo):
    """Draw merged rivers over the San Francisco terrain (hex radius 20) with windy region edges.

    One style per river-tree depth is built from the seaborn "muted" palette.
    Each level's stroke is 4px wider than the previous one.
    """
    unused_style = StyleCSS("river",fill="none",stroke="blue",stroke_width=4)
    terrain = TerraDemo().sanFran()

    terrain.hexGrid.adjustRadius(20)
    render = terrain.hexGrid.builder

    summits = terrain.find_peaks(35,7)
    candidates = [River.from_peak(terrain, s) for s in summits]
    candidates = [c for c in candidates if c is not None]  # drop peaks with no downhill path

    depth_styles = []
    for level, swatch in enumerate(StyleCSS.seaborn("muted")):
        level_style = StyleCSS(f"river{level}", fill="none",
                               stroke=swatch.properties["fill"], stroke_width=(level + 1) * 4)
        depth_styles.append(level_style)
        render.add_style(level_style)

    network = River.combine_rivers(candidates)
    river_markup = "".join(r.svg(styles=depth_styles) for r in network)

    terrain.colorMap()
    terrain.hexGrid.update()
    grid = terrain.hexGrid
    grid.builder.adjust("regions", grid.styleLayer(f=windy_edge(iterations=2, offset_factor=0.1)))

    render.adjust("rivers", river_markup)
    return render.show()
    


In [None]:
# Render the San Francisco map with its merged rivers drawn as windy splines.
RiverDemo().demoSanRivers()

## Erode

In [None]:
#| export
@patch
def getLargestRiver(self:Terrain, num_peaks=40, min_height=1):
    """Return the merged river with the most hexes, or ``None`` if no river can be traced.

    Rivers are traced from the peaks returned by
    ``find_peaks(num_peaks, min_height)`` and then merged with
    ``River.combine_rivers``. Ties go to the first river.
    """
    candidates = []
    for summit in self.find_peaks(num_peaks, min_height):
        traced = River.from_peak(self, summit)
        if traced is not None:
            candidates.append(traced)

    networks = River.combine_rivers(candidates)
    return max(networks, key=lambda r: len(r.hexes), default=None)


1. I think we can do a single pass (and wrap it in a loop if we need more passes).
2. I think a wider valley would look better.

##### 🤖Reply🤖<!-- SOLVEIT_SEPARATOR_7f3a9b2c -->

Perfect! Single pass with wider valleys it is.

**Here's my thinking for the algorithm:**

For each hex in the river tree, we'll:
1. Calculate its **flow accumulation** (count of upstream hexes feeding into it)
2. Calculate **gradient** (elevation drop to next downstream hex)
3. Create an erosion value: `erosion = base_factor * flow_accumulation * gradient`
4. Apply erosion to the hex AND its neighbors (with decay by distance)

For the last few hexes (near water/flat areas), switch to deposition instead.

**Structure suggestion:**
```python
@patch
def erode(self:River, base_erosion=0.5, valley_width=2, deposit_length=3):
    """Create elevation adjustments for river erosion."""
    adjustments = {}  # hex_index -> elevation_change (negative for erosion)
    
    # Your implementation here
    # - Walk the tree calculating flow accumulation
    # - Apply erosion based on gradient and flow
    # - Spread to neighbors within valley_width
    # - Handle deposition at the end
    
    return adjustments
```

Does this structure make sense? Should I walk you through calculating flow accumulation first, or would you like to try implementing part of it?

In [None]:
#| export
@patch
def _calculate_flow(self:River):
    """Calculate accumulated flow for each hex in the river.

    Flow at a hex is the number of river hexes at or upstream of it; the hex
    itself counts as 1. Segments are ordered downstream -> upstream, and child
    segments join their parent at the parent's last hex. That last hex
    therefore also receives the flow leaving the first hex of every child.

    Returns:
        dict mapping hex_index -> flow_count (empty for an empty river).
    """
    flow = {}
    if self.tree.size() == 0:
        return flow

    def traverse(node):
        # Fill `flow` for this subtree. Return the flow at this segment's
        # first (most downstream) hex, which is what drains into the parent.
        accumulated = sum(traverse(child) for child in self.tree.children(node.identifier))
        for hex_idx in reversed(node.data):
            accumulated += 1
            flow[hex_idx] = accumulated
        return accumulated

    traverse(self.tree.get_node(self.tree.root))
    return flow


In [None]:
#| export
def drawLegend(styles: "list[StyleCSS]", radius=15, pad=10):
    """Return SVG markup for a vertical legend: one circle swatch and label per style.

    Args:
        styles: Styles to show. Each circle uses ``style.name`` as its CSS
            class, and the label is the name with "_" replaced by spaces.
        radius: Swatch circle radius.
        pad: Gap between rows, and between a swatch and its label.

    Returns:
        SVG ``<circle>``/``<text>`` elements, tab-indented, one per line.
    """
    import html

    x = y = radius + pad
    parts = []
    for style in styles:
        css_class = html.escape(style.name)
        label = html.escape(style.name.replace("_", " "))
        parts.append(f"""\t<circle cx="{x}" cy="{y}" r="{radius}" class="{css_class}"/>\n""")
        parts.append(f"""\t<text x="{x + pad + radius}" y="{y}" text-anchor="start" dy="0.3em" class="keyLabel">{label}</text>\n""")
        y += radius * 2 + pad
    return "".join(parts)

        
        

In [None]:
@patch
def demoFlow(self:RiverDemo):
    """Render San Francisco rivers with a ball-density overlay of accumulated flow per hex, plus a legend.

    Each hex's overlay value is the largest flow any merged river assigns to
    it, or -1 if no river touches it.
    """
    palette = ["#d4ff00eb","#ffb300ff","#ff0073ff","#9900ff97","#1e0e45eb"]
    unused_style = StyleCSS("river",fill="none",stroke="blue",stroke_width=4)
    terrain = TerraDemo().sanFran()

    terrain.hexGrid.adjustRadius(15)
    render = terrain.hexGrid.builder

    summits = terrain.find_peaks(35,7)
    candidates = [River.from_peak(terrain, s) for s in summits]
    candidates = [c for c in candidates if c is not None]

    line_styles, legend_styles = [], []
    for level, colour in enumerate(palette):
        line_style = StyleCSS(f"river{level}", fill="none", stroke=colour, stroke_width=4)
        line_styles.append(line_style)
        render.add_style(line_style)

        legend_style = StyleCSS(f"Level_{level+1}", fill=colour, stroke="#000000", stroke_width=2)
        legend_styles.append(legend_style)
        render.add_style(legend_style)

    network = River.combine_rivers(candidates)
    svg_parts = []

    best_flow = np.full(len(terrain.hexGrid.hexes), -1.0)
    for stream in network:
        svg_parts.append(stream.svg(styles=line_styles))
        per_hex = stream._calculate_flow()
        layer = np.full(len(terrain.hexGrid.hexes), -1.0)
        for hex_idx, amount in per_hex.items():
            layer[hex_idx] = amount
        best_flow = np.maximum(best_flow, layer)

    terrain.colorMap()
    terrain.hexGrid.update()

    levels = list(map(int, best_flow))
    pattern_maker = TerrainPatterns(terrain)
    patterns = pattern_maker.ballDensity(max(levels), fills=palette)
    terrain.makeOverlay(levels, patterns)

    terrain.addCoast()
    terrain.colorMap()
    terrain.hexGrid.update()

    render.adjust("rivers", "".join(svg_parts))
    render.adjust("root","")
    render.adjust("legend", drawLegend(legend_styles))
    return render.show()


In [None]:

    

# Show the accumulated-flow overlay on the San Francisco map.
RiverDemo().demoFlow()

In [None]:
#| export
@patch
def _calculate_gradient(self:River):
    """Map each river hex to its elevation drop toward the next hex downstream.

    Within a segment the downstream neighbour is the previous hex in the list.
    A segment's first hex drains into its parent's last hex. The outlet (the
    root segment's first hex) gets 0, and rises are clamped to 0.

    Returns:
        dict mapping hex_index -> elevation_drop (>= 0).
    """
    elevations = self.terrain.elevations
    drops = {}
    stack = [self.tree.get_node(self.tree.root)]
    while stack:
        segment = stack.pop()
        for pos, hex_idx in enumerate(segment.data):
            here = elevations[hex_idx]
            if pos > 0:
                below = segment.data[pos - 1]
            else:
                parent_id = segment.predecessor(self.tree.identifier)
                if parent_id is None:
                    drops[hex_idx] = 0
                    continue
                below = self.tree.get_node(parent_id).data[-1]
            drops[hex_idx] = max(0, here - elevations[below])
        # Push children reversed so they are visited in their original order (pre-order walk).
        stack.extend(reversed(self.tree.children(segment.identifier)))
    return drops


In [None]:
#| export
@patch
def demoGradient(self:RiverDemo):
    """Render San Francisco rivers with a ball-density overlay of per-hex elevation drop, plus a legend.

    Each hex's overlay value is the largest drop any merged river assigns to
    it, truncated to int, or -1 if no river touches it.
    """
    palette = ["#d4ff00eb","#ffb300ff","#ff0073ff","#9900ff97","#1e0e45eb"]
    unused_style = StyleCSS("river",fill="none",stroke="blue",stroke_width=4)
    terrain = TerraDemo().sanFran()

    terrain.hexGrid.adjustRadius(15)
    render = terrain.hexGrid.builder

    summits = terrain.find_peaks(35,7)
    candidates = [River.from_peak(terrain, s) for s in summits]
    candidates = [c for c in candidates if c is not None]

    line_styles, legend_styles = [], []
    for level, colour in enumerate(palette):
        line_style = StyleCSS(f"river{level}", fill="none", stroke=colour, stroke_width=4)
        line_styles.append(line_style)
        render.add_style(line_style)

        legend_style = StyleCSS(f"Level_{level+1}", fill=colour, stroke="#000000", stroke_width=2)
        legend_styles.append(legend_style)
        render.add_style(legend_style)

    network = River.combine_rivers(candidates)
    svg_parts = []

    steepest = np.full(len(terrain.hexGrid.hexes), -1.0)
    for stream in network:
        svg_parts.append(stream.svg(styles=line_styles))
        per_hex = stream._calculate_gradient()
        layer = np.full(len(terrain.hexGrid.hexes), -1.0)
        for hex_idx, drop in per_hex.items():
            layer[hex_idx] = drop
        steepest = np.maximum(steepest, layer)

    terrain.colorMap()
    terrain.hexGrid.update()

    levels = list(map(int, steepest))
    pattern_maker = TerrainPatterns(terrain)
    patterns = pattern_maker.ballDensity(max(levels), fills=palette)
    terrain.makeOverlay(levels, patterns)

    terrain.colorMap()
    terrain.hexGrid.update()

    render.adjust("rivers", "".join(svg_parts))
    render.adjust("root","")
    terrain.addCoast()
    render.adjust("legend", drawLegend(legend_styles))
    return render.show()



In [None]:

# Show the elevation-drop (gradient) overlay on the San Francisco map.
RiverDemo().demoGradient()

In [None]:

@patch
def demoPower(self:RiverDemo):
    """Render San Francisco rivers with a ball-density overlay of stream power (flow * gradient), plus a legend.

    Each hex's overlay value is the largest power any merged river assigns to
    it, truncated to int, or -1 if no river touches it.
    """
    palette = ["#d4ff00eb","#ffb300ff","#ff0073ff","#9900ff97","#1e0e45eb"]
    unused_style = StyleCSS("river",fill="none",stroke="blue",stroke_width=4)
    terrain = TerraDemo().sanFran()

    terrain.hexGrid.adjustRadius(15)
    render = terrain.hexGrid.builder

    summits = terrain.find_peaks(35,7)
    candidates = [River.from_peak(terrain, s) for s in summits]
    candidates = [c for c in candidates if c is not None]

    line_styles, legend_styles = [], []
    for level, colour in enumerate(palette):
        line_style = StyleCSS(f"river{level}", fill="none", stroke=colour, stroke_width=4)
        line_styles.append(line_style)
        render.add_style(line_style)

        legend_style = StyleCSS(f"Level_{level+1}", fill=colour, stroke="#000000", stroke_width=2)
        legend_styles.append(legend_style)
        render.add_style(legend_style)

    network = River.combine_rivers(candidates)
    svg_parts = []

    strongest = np.full(len(terrain.hexGrid.hexes), -1.0)
    for stream in network:
        svg_parts.append(stream.svg(styles=line_styles))
        flow_by_hex = stream._calculate_flow()
        drop_by_hex = stream._calculate_gradient()

        layer = np.full(len(terrain.hexGrid.hexes), -1.0)
        for hex_idx, amount in flow_by_hex.items():
            layer[hex_idx] = amount * drop_by_hex[hex_idx]
        strongest = np.maximum(strongest, layer)

    terrain.colorMap()
    terrain.hexGrid.update()

    levels = list(map(int, strongest))
    pattern_maker = TerrainPatterns(terrain)
    patterns = pattern_maker.ballDensity(max(levels), fills=palette)
    terrain.makeOverlay(levels, patterns)

    terrain.colorMap()
    terrain.hexGrid.update()
    terrain.addCoast()
    render.adjust("rivers", "".join(svg_parts))
    render.adjust("root","")
    render.adjust("legend", drawLegend(legend_styles))
    return render.show()


In [None]:
# Show the stream-power (flow x gradient) overlay on the San Francisco map.
RiverDemo().demoPower()

Now that I have the stream power, how would I start to erode? There are two cases: a V-shaped river valley and a U-shaped glacial valley.

##### 🤖Reply🤖<!-- SOLVEIT_SEPARATOR_7f3a9b2c -->

Great question! The key difference is in how you spread the erosion to neighbors:

**V-shaped River (water erosion):**
- Erodes **most** at the river hex itself
- Falls off **quickly** to neighbors (exponential decay)
- Only affects hexes **lower** than the river hex (water cuts downward and slightly sideways)

**U-shaped Glacier:**
- Erodes **evenly** across the valley width
- Falls off **slowly** to neighbors (linear or gentle decay)
- Can affect hexes at **similar or higher** elevation (ice bulldozes everything)

Here's a structure to start:

```python
@patch
def erode(self:River, base_erosion=0.5, valley_width=2, shape='river'):
    """Apply erosion based on stream power.
    shape: 'river' for V-shaped, 'glacier' for U-shaped"""
    
    adjustments = {}
    flow = self._calculate_flow()
    gradient = self._calculate_gradient()
    
    for hex_idx in self.hexes:
        power = flow[hex_idx] * gradient[hex_idx]
        erosion_amount = base_erosion * power
        
        # Apply to river hex
        adjustments[hex_idx] = adjustments.get(hex_idx, 0) - erosion_amount
        
        # Spread to neighbors
        neighbors = self.terrain.hexGrid.neighbors(hex_idx)
        for dist in range(1, valley_width + 1):
            # Get neighbors at this distance
            ring = self.terrain.hexGrid.ring(hex_idx, dist)
            
            if shape == 'river':
                # V-shape: exponential decay, only lower neighbors
                decay = erosion_amount * (0.5 ** dist)
                for n in ring:
                    if self.terrain.elevations[n] < self.terrain.elevations[hex_idx]:
                        adjustments[n] = adjustments.get(n, 0) - decay
            else:
                # U-shape: linear decay, all neighbors
                decay = erosion_amount * (1 - dist / (valley_width + 1))
                for n in ring:
                    adjustments[n] = adjustments.get(n, 0) - decay
    
    return adjustments
```

Does this match your thinking? Should we test it on the San Francisco terrain?

In [None]:



def rainfall_selector_np(values: np.ndarray, bins=(0.1, 5, 12, 24, 48)) -> np.ndarray:
    """Map rainfall amounts to pattern-level indices.

    With the default bin edges:
        value < 0.1        -> -1  (no rain: no pattern)
        0.1 <= value < 5   -> 0
        5   <= value < 12  -> 1
        12  <= value < 24  -> 2
        24  <= value < 48  -> 3
        value >= 48        -> 4

    Args:
        values: Array-like of rainfall amounts.
        bins: Increasing bin edges. Results range from -1 to ``len(bins) - 1``.

    Returns:
        Integer array with the same shape as ``values``.
    """
    # np.digitize returns i such that bins[i-1] <= v < bins[i]; shift so the first bin is 0.
    return np.digitize(values, bins) - 1

# Quick check of rainfall_selector_np; expected: [-1, 0, 1, 2, 3, 4, -1, 2, 4]
test_data = np.array([-1, 2, 8, 15, 30, 60, 0, 12, 48])

indices = rainfall_selector_np(test_data)
indices



# Create a terrain
sampleMap = TerraDemo().tiny()

# Generate random rainfall data (unseeded, so the picture changes each run)
rainfall_data = np.random.uniform(-4, 60, len(sampleMap.elevations))
sampleMap.fields['rainfall'] = rainfall_data

# Get pattern indices (-1 .. 4 with the default bins)
pattern_indices = rainfall_selector_np(rainfall_data)

# Create patterns and overlay
patternGen = TerrainPatterns(sampleMap)
# NOTE(review): the first argument is the number of hexes, not the number of levels
# (5 fills / indices 0-4); confirm what ballScale expects here.
patterns = patternGen.ballScale(len(pattern_indices),fills=["#007fff","#d4ff00ff","#ee00ffff","#ff0099ff","#00ff1eff"])  # 5 fill colors
sampleMap.makeOverlay(pattern_indices, patterns)
#sampleMap.colorMap()
#sampleMap.hexGrid.update()

# View it
#print(sampleMap.hexGrid.builder.xml())
sampleMap.hexGrid.builder.show()

Yes, I would love to visualize the gradient.

##### 🤖Reply🤖<!-- SOLVEIT_SEPARATOR_7f3a9b2c -->

I'd be happy to help you visualize the gradient! However, I notice that due to the dialog being truncated (44 messages reduced to 1), I don't have the context about what specific gradient you're referring to.

Could you please clarify:
- What function or model are you working with?
- What type of gradient are you looking to visualize (e.g., gradient of a loss function, gradient field, gradient descent steps, etc.)?
- Do you have any existing code or data that I should work with?

Once you provide these details, I can help you create an appropriate visualization!

## Generate

In [None]:
#| export
@patch
def erode(self:River, base_erosion=0.5, valley_width=2, shape='river'):
    """Compute elevation adjustments that carve a valley along the river.

    Each river hex erodes by ``base_erosion * flow * gradient`` (stream
    power). That amount then spreads to the hex rings at distances
    ``1..valley_width``:

    * ``shape='river'`` (V-shaped): the amount halves per ring, and only ring
      hexes lower than the river hex are affected.
    * any other value (U-shaped / glacier): the amount falls off linearly with
      ring distance, and every hex in the ring is affected.

    Returns:
        dict mapping hex_index -> elevation change (negative = erosion).
        The terrain itself is not modified.
    """
    adjustments = {}
    flow = self._calculate_flow()
    gradient = self._calculate_gradient()
    elevations = self.terrain.elevations
    grid = self.terrain.hexGrid

    # Iterate the hexes that are actually in the tree. `self.hexes` is kept
    # separately and may hold indices in no segment, which would raise KeyError.
    for hex_idx, accumulated in flow.items():
        power = accumulated * gradient.get(hex_idx, 0)
        erosion_amount = base_erosion * power

        # Apply to river hex
        adjustments[hex_idx] = adjustments.get(hex_idx, 0) - erosion_amount

        # Spread to the surrounding rings of hexes
        for dist in range(1, valley_width + 1):
            ring = grid.ring(hex_idx, dist)

            if shape == 'river':
                # V-shape: exponential decay.
                # NOTE(review): only hexes *lower* than the river hex erode, per the original design.
                decay = erosion_amount * (0.5 ** dist)
                for n in ring:
                    if elevations[n] < elevations[hex_idx]:
                        adjustments[n] = adjustments.get(n, 0) - decay
            else:
                # U-shape: linear decay, all neighbors
                decay = erosion_amount * (1 - dist / (valley_width + 1))
                for n in ring:
                    adjustments[n] = adjustments.get(n, 0) - decay

    return adjustments


In [None]:
#| export
@patch
def flow_directions(self: Terrain) -> np.ndarray:
    """Map every hex to the neighbour water would flow into, or -1.

    Entry ``i`` of the returned int array is ``self.lowest_neighbor(i)`` when that
    neighbour exists and is strictly lower than hex ``i``. Otherwise it is -1:
    no downhill neighbour, so the hex is a local minimum or a flat.
    """
    heights = self.elevations

    def downhill_of(idx):
        target = self.lowest_neighbor(idx)
        if target is None or not heights[target] < heights[idx]:
            return -1
        return target

    return np.array([downhill_of(idx) for idx in range(len(heights))], dtype=int)

In [None]:
#| export
@patch
def flow_diagram(self: Terrain) -> str:
    """Return an SVG overlay of arrows showing downhill flow between hex centres.

    For every land hex (elevation > 0) whose ``flow_directions()`` entry is a valid
    hex index, an arrow (``MapPath.with_arrowhead()``) is drawn from its centre to
    the centre of its downhill neighbour. The arrows use the "arrow" CSS class,
    which is registered on ``self.builder``.

    Returns:
        Concatenated SVG fragments, each followed by a newline. Returns "" when no
        land hex has a downhill neighbour.
    """
    n = len(self.elevations)
    style = StyleCSS("arrow", stroke="black", stroke_width=1)
    # NOTE(review): other code in this module registers styles on
    # ``hexGrid.builder``; confirm ``Terrain.builder`` is the same object.
    self.builder.add_style(style)

    hexes = self.hexGrid.hexes
    arrows = []
    for i, target in enumerate(self.flow_directions()):
        # -1 marks "no downhill neighbour"; only land hexes get arrows.
        if 0 <= target < n and self.elevations[i] > 0:
            path = MapPath([hexes[i].center, hexes[target].center], style)
            arrows.append(path.with_arrowhead() + "\n")

    return "".join(arrows)

In [None]:
#| export


def find_local_minima(self: Terrain) -> list[int]:
    """List land hexes that have no downhill neighbour (candidate lake sites).

    A hex qualifies when ``self.flow_directions()`` marks it -1 and its elevation
    is above 0; elevation <= 0 is treated as ocean.

    This is a module-level function, not ``@patch``-ed, so call it as
    ``find_local_minima(terrain)``.
    """
    sinks = self.flow_directions()
    land_sinks = []
    for idx, height in enumerate(self.elevations):
        if sinks[idx] == -1 and height > 0:
            land_sinks.append(idx)
    return land_sinks


@patch
def find_drainage_path(self: Terrain, start: int) -> list[int]:
    """Find the lowest-cost hex path from ``start`` to the ocean (Dijkstra).

    Ocean hexes are those with elevation <= 0. Stepping from hex ``a`` to
    neighbour ``b`` costs ``max(0, elev[b] - elev[a] + 1)``: the amount ``b``
    would have to be lowered to sit strictly below ``a``. A step down by at
    least 1 is free, a flat step costs 1, and an uphill step costs the rise + 1.
    Costs use the current (un-carved) elevations.

    Args:
        start: Index of the hex to drain.

    Returns:
        Hex indices from ``start`` to the first ocean hex reached, or ``[start]``
        if ``start`` is already ocean. Returns ``[]`` if no ocean hex is reachable.
        Among equal-cost paths the choice is arbitrary.
    """
    elevations = self.elevations
    if elevations[start] <= 0:
        return [start]

    # Parent pointers instead of carrying a copy of the whole path in every
    # heap entry (which cost O(path length) per push).
    best = {start: 0}       # lowest known cost to reach each hex
    parent = {start: None}  # predecessor on the best known path
    pq = [(0, start)]
    visited = set()

    while pq:
        cost, current = heapq.heappop(pq)
        if current in visited:
            continue
        visited.add(current)

        if elevations[current] <= 0:
            # Reached the ocean: walk parent pointers back to start.
            path = []
            node = current
            while node is not None:
                path.append(node)
                node = parent[node]
            path.reverse()
            return path

        elev_curr = elevations[current]
        for neighbor in self.hexGrid.neighborsOf(current):
            # Negative indices are "no neighbour" sentinels at grid edges.
            if neighbor < 0 or neighbor in visited:
                continue

            # To flow, neighbor must end up strictly below current.
            carve_cost = max(0, elevations[neighbor] - elev_curr + 1)
            new_cost = cost + carve_cost
            if new_cost < best.get(neighbor, math.inf):
                best[neighbor] = new_cost
                parent[neighbor] = current
                heapq.heappush(pq, (new_cost, neighbor))

    return []  # No path found




In [None]:
#| export
@patch
def carve_to_ocean(self: Terrain, num_lakes: int = 5, max_iters: int = 10) -> 'Terrain':
    """Carve drainage channels so that only ``num_lakes`` land minima remain.

    Each pass finds the land local minima (``find_local_minima``). If there are
    at most ``num_lakes``, it prints a message and stops. Otherwise it:

    - keeps the ``num_lakes`` highest minima as lakes;
    - drains every other minimum along ``find_drainage_path``, lowering each
      path hex that is not already below its predecessor to one below it.

    The drained path therefore descends strictly. Hexes may be pushed to
    elevation <= 0. ``self.elevations`` is modified in place. If ``max_iters``
    passes run out first, the loop ends silently.

    Args:
        num_lakes: Number of local minima allowed to survive as lakes.
        max_iters: Maximum number of find-and-carve passes.

    Returns:
        ``self``, for chaining.
    """
    for iteration in range(max_iters):
        minima = find_local_minima(self)
        if len(minima) <= num_lakes:
            print(f"Done at iter {iteration}: {len(minima)} lakes")
            break

        # Keep highest elevation minima as lakes; drain the rest.
        minima.sort(key=lambda i: self.elevations[i], reverse=True)
        for lake_idx in minima[num_lakes:]:
            path = self.find_drainage_path(lake_idx)
            # Carve monotonically decreasing along consecutive (upstream, downstream)
            # pairs; paths shorter than 2 hexes produce no pairs.
            for curr, next_hex in zip(path, path[1:]):
                if self.elevations[next_hex] >= self.elevations[curr]:
                    self.elevations[next_hex] = self.elevations[curr] - 1

    return self

In [None]:
@patch
def demoPower(self:RiverDemo):
    """Demo: shade a generated map by the stream power of its rivers.

    Steps:

    1. Build a 480x480 terrain with ``Terrain.fromSeeds`` and carve drainage,
       keeping 2 lakes.
    2. Trace a river from each peak returned by ``find_peaks(35, 7)`` and merge
       the rivers.
    3. For every hex, record the largest ``flow * gradient`` of any merged
       river through it (-1 where none flows).
    4. Truncate those values to int; they drive a ball-density pattern overlay.
    5. Draw the rivers and a legend on top.

    Returns:
        The result of ``show()`` on the hex grid's SVG builder.
    """
    palette = ["#d4ff00eb","#ffb300ff","#ff0073ff","#9900ff97","#1e0e45eb"]
    _river_style = StyleCSS("river",fill="none",stroke="blue",stroke_width=4)  # built but not referenced below

    size = MapSize(480,480)
    bounds = MapRect(MapCord(0,0), size)
    terrain, _plates = Terrain.fromSeeds(bounds, slope=40, variation=90, radius=15)
    terrain.carve_to_ocean(2)

    terrain.hexGrid.adjustRadius(15)
    builder = terrain.hexGrid.builder

    peaks = terrain.find_peaks(35,7)
    rivers = [r for r in (River.from_peak(terrain, p) for p in peaks) if r is not None]

    # One stroke style per river level and one fill swatch per legend entry.
    riverStyles, legendStyles = [], []
    for level, colour in enumerate(palette):
        line = StyleCSS(f"river{level}",fill="none",stroke=colour,stroke_width=4)
        riverStyles.append(line)
        builder.add_style(line)

        swatch = StyleCSS(f"Level_{level+1}",fill=colour,stroke="#000000",stroke_width=2)
        legendStyles.append(swatch)
        builder.add_style(swatch)

    merged = River.combine_rivers(rivers)
    svgParts = []
    hexCount = len(terrain.hexGrid.hexes)
    bestPower = np.full(hexCount, -1.0)

    for stream in merged:
        svgParts.append(stream.svg(styles=riverStyles))
        flow = stream._calculate_flow()
        slope = stream._calculate_gradient()

        thisPower = np.full(hexCount, -1.0)
        for idx, q in flow.items():
            thisPower[idx] = q * slope[idx]
        bestPower = np.maximum(bestPower, thisPower)

    terrain.colorMap()
    terrain.hexGrid.update()

    levels = [int(v) for v in bestPower]
    patterns = TerrainPatterns(terrain).ballDensity(max(levels), fills=palette)
    terrain.makeOverlay(levels, patterns)

    terrain.colorMap()
    terrain.hexGrid.update()
    terrain.addCoast()
    builder.adjust("rivers", "".join(svgParts))
    builder.adjust("root", "")
    builder.adjust("legend", drawLegend(legendStyles))

    return builder.show()

In [None]:
# Render the stream-power demo (plain notebook cell, no `#| export`).
RiverDemo().demoPower()

# Soil

In [None]:
#| export
@dataclass
class SoilType:
    """A bedrock/soil category together with the properties that control erosion."""
    name: str
    bedrock: str  # human-readable description of the underlying rock
    erosion_resistance: float  # 0-1; larger means harder to erode
    permeability: float  # 0-1; larger means water passes through faster
    color: str  # hex colour string used for rendering

    @classmethod
    def standard_types(cls) -> List['SoilType']:
        """Return the five built-in soil types, from hardest rock to loosest sediment.

        List position matters: it is the integer index by which a type is referred
        to (e.g. in a per-hex soil-type array), so keep the order stable.
        """
        catalogue = (
            # name,        bedrock description,                     resist, perm, colour
            ("Granite",   "Plutonic igneous (continental crust)", 0.9,  0.2, "#8B7355"),  # gray-brown
            ("Basalt",    "Volcanic igneous (oceanic crust)",     0.85, 0.3, "#4A4A4A"),  # dark gray
            ("Limestone", "Sedimentary carbonate (dissolves)",    0.5,  0.6, "#D4C5B9"),  # light tan
            ("Sandstone", "Sedimentary clastic (crumbles)",       0.4,  0.7, "#C2B280"),  # sandy tan
            ("Alluvial",  "Deposited sediment (clay/silt)",       0.1,  0.4, "#8B6914"),  # dark gold/brown
        )
        return [
            cls(name=label, bedrock=rock, erosion_resistance=resist,
                permeability=perm, color=colour)
            for label, rock, resist, perm, colour in catalogue
        ]

In [None]:
#| export
@patch
def to_style(self: SoilType, stroke_width: float = 1.0, opacity: float = 1.0) -> StyleCSS:
    """Return the ``StyleCSS`` used to paint hexes of this soil type.

    The style is named ``soil_<lowercased soil name>``. It fills with the soil's
    colour and draws a black outline.

    Args:
        stroke_width: Outline width for hex borders.
        opacity: Opacity of the style (0-1).

    Returns:
        A new StyleCSS instance.
    """
    style_kwargs = dict(
        name=f"soil_{self.name.lower()}",
        fill=self.color,
        stroke="#000000",
        stroke_width=stroke_width,
        opacity=opacity,
    )
    return StyleCSS(**style_kwargs)


In [None]:
#| export
@dataclass
class SoilSystem:
    """Manages soil types and their distribution across terrain.

    ``regions[i]`` holds the hexes whose soil is ``types[i]``. The terrain's
    ``fields['soil_type']`` stores the per-hex index into ``types``.
    """
    terrain: 'Terrain'  # terrain whose fields['soil_type'] holds per-hex type indices
    types: List[SoilType]  # soil types; list position is the stored index
    regions: List[HexRegion]  # One per type (index matches types)

    def __post_init__(self):
        """Ensure the terrain has a ``soil_type`` field (all zeros if it is missing)."""
        if 'soil_type' not in self.terrain.fields:
            self.terrain.fields['soil_type'] = np.zeros(len(self.terrain.elevations), dtype=int)

    @classmethod
    def from_plates(cls, terrain: 'Terrain', plates: List['Plate'],
                    elev_threshold: float = 50.0,
                    precip_threshold: float = 750.0,
                    debug: bool = False,
                    bedrock_threshold: float = 1000.0) -> 'SoilSystem':
        """Create a soil system from plate tectonics and terrain properties.

        Each hex is classified as follows (indices follow ``SoilType.standard_types()``):

        - elevation <= 0 -> Basalt (1), i.e. oceanic crust.
        - elevation < ``elev_threshold`` and distance_to_coast <= 1 -> Alluvial (4).
        - elevation > ``bedrock_threshold`` -> Basalt on an oceanic plate,
          otherwise Granite (0).
        - Otherwise, if a 'precipitation' field exists: precipitation above
          ``precip_threshold`` -> Limestone (2), else Sandstone (3).
        - Otherwise: oceanic plate -> Limestone, else Sandstone.

        Args:
            terrain: Terrain with elevations and a ``fields`` mapping. ``fields``
                gains 'soil_type'; 'distance_to_coast' is computed if absent.
            plates: Tectonic plates; each plate's ``hexes`` are hex indices.
            elev_threshold: Elevation below which alluvial can form (meters).
            precip_threshold: Precipitation above which limestone forms (mm/year).
            debug: Print per-type hex counts.
            bedrock_threshold: Elevation above which bedrock (granite/basalt) is
                assigned. Defaults to the previously hard-coded 1000.

        Returns:
            SoilSystem with the initial soil distribution.
        """
        # Indices into SoilType.standard_types().
        GRANITE, BASALT, LIMESTONE, SANDSTONE, ALLUVIAL = range(5)

        types = SoilType.standard_types()
        n_hexes = len(terrain.elevations)
        soil_type = np.zeros(n_hexes, dtype=int)

        # hex index -> owning plate (a later plate wins if plates overlap)
        plate_map = {hex_idx: plate for plate in plates for hex_idx in plate.hexes}

        if 'distance_to_coast' not in terrain.fields:
            terrain.compute_distance_to_coast()
        coast_dist = terrain.fields['distance_to_coast']
        precip = terrain.fields['precipitation'] if 'precipitation' in terrain.fields else None

        for i in range(n_hexes):
            elev = terrain.elevations[i]

            # Ocean hexes get basalt (oceanic crust)
            if elev <= 0:
                soil_type[i] = BASALT
                continue

            plate = plate_map.get(i)
            is_oceanic = plate is not None and plate.kind == PlateKind.oceanic

            if elev < elev_threshold and coast_dist[i] <= 1:
                # Coastal lowlands -> Alluvial
                soil_type[i] = ALLUVIAL
            elif elev > bedrock_threshold:
                # High elevation -> exposed bedrock, rock type from plate
                soil_type[i] = BASALT if is_oceanic else GRANITE
            elif precip is not None:
                # Mid elevation -> sedimentary; wet -> Limestone, dry -> Sandstone
                soil_type[i] = LIMESTONE if precip[i] > precip_threshold else SANDSTONE
            else:
                # No precipitation data: fall back to plate type
                soil_type[i] = LIMESTONE if is_oceanic else SANDSTONE

        # Store in terrain
        terrain.fields['soil_type'] = soil_type

        # One region per type, in the same order as `types`.
        regions = [
            HexRegion(hexes=set(np.where(soil_type == type_idx)[0]), hex_grid=terrain.hexGrid)
            for type_idx in range(len(types))
        ]

        if debug:
            # Every hex is assigned exactly once, so counting the final array
            # matches the per-branch tallies without maintaining them by hand.
            counts = np.bincount(soil_type, minlength=len(types))
            print("\n=== SOIL SYSTEM CREATED ===")
            print(f"Total hexes: {n_hexes}")
            for i, soil_type_obj in enumerate(types):
                count = int(counts[i])
                pct = 100 * count / n_hexes if n_hexes else 0.0
                print(f"{soil_type_obj.name:12s}: {count:5d} hexes ({pct:5.1f}%)")

        return cls(terrain=terrain, types=types, regions=regions)

In [None]:
#| export
@patch
def soilOverlay(self:SoilSystem,f=None,smooth=False)->str:
    """Build an SVG overlay of soil regions, similar to ``HexGrid.styleLayer``.

    Each region in ``self.regions`` is traced with ``trace_perimeter_cached``.
    It uses the style of the soil type at the same index in ``self.types``, and
    that style is registered on the terrain's hex-grid builder. The resulting
    paths are emitted in ``reversed(sorted(...))`` order.

    Args:
        f: Optional edge function forwarded to ``trace_perimeter_cached``.
        smooth: If True emit ``path.svg()``; otherwise ``path.drawClosed()``.

    Returns:
        The concatenated SVG string ("" if there are no paths).
    """
    borders = {}  # perimeter cache shared across all regions
    builder = self.terrain.hexGrid.builder
    # Styles follow self.types (not the standard list) so custom systems match
    # their regions; regions[i] corresponds to types[i].
    soilStyles = [t.to_style() for t in self.types]

    allPaths = []
    for i, region in enumerate(self.regions):
        style = soilStyles[i]
        builder.add_style(style)
        allPaths.extend(region.trace_perimeter_cached(borders, f, style))

    # Kept as reversed(sorted(...)): sorted(..., reverse=True) would order
    # equal-comparing paths differently.
    pieces = []
    for path in reversed(sorted(allPaths)):
        # NOTE: path smoothing (path.smooth()) is currently disabled; smooth=True
        # only switches the output to path.svg().
        pieces.append(path.svg() if smooth else path.drawClosed())

    return "".join(pieces)

In [None]:
#| export
def soilInformation(showText=True,terrain=None,plates=None):
    """Classify soil for a terrain, then show a text summary or render the soil map.

    Args:
        showText: If True, call ``terrain.field_summary("soil_type")`` and return None.
        terrain: Terrain to use; defaults to the module-level ``sampleTerrain``.
        plates: Tectonic plates for ``terrain``; defaults to the module-level
            ``samplePlates``. Pass these whenever you pass a non-default terrain,
            so that plates and hexes agree.

    Returns:
        None when ``showText`` is True, otherwise the result of ``builder.show()``
        with the soil overlay in the "regions" layer.
    """
    if terrain is None:
        terrain = sampleTerrain
    if plates is None:
        plates = samplePlates

    terrain.colorMap()
    sGrid = terrain.hexGrid
    sGrid.builder.layers = []  # start from an empty layer list
    sampleSoil = SoilSystem.from_plates(terrain, plates, debug=True)

    if showText:
        terrain.field_summary("soil_type")
        return None

    sGrid.builder.adjust("regions", sampleSoil.soilOverlay(f=unique_windy_edge(iterations=2)))
    return sGrid.builder.show()