#  HexRegion

In [None]:
#| default_exp plot.region

In [None]:
#| export
import sys
import math
import numpy as np
import math
from collections import namedtuple
from dataclasses import dataclass, field
from fastcore.basics import patch



In [None]:
#| export
from HexMagic.plot.primitives import MapCord, MapSize, MapRect, MapPath, PrimitiveDemo
from HexMagic.plot.cube import HexPosition
from HexMagic.plot.hex import Hex, HexGrid

from HexMagic.styles import StyleCSS,  SVGBuilder

In [None]:
#| export
@dataclass
class HexRegion:
    """A set of hexes from one ``HexGrid`` with helpers for their perimeter.

    ``hexes`` holds indices into ``hex_grid.hexes``.
    """
    hexes: set[int]  # indices into hex_grid.hexes
    hex_grid: 'HexGrid'  # grid owning the Hex objects the indices refer to

    def __post_init__(self):
        # Reserved cache slot for boundary paths (list of MapPath); nothing in
        # this class reads or fills it yet.
        self._boundaries = None

    def _vertex_key(self, v) -> tuple[int, int]:
        """Return a hashable integer key for the vertex ``v``.

        Coordinates are rounded to the nearest integer instead of truncated
        with ``int()``. The same corner computed from two neighbouring hexes
        can differ by floating-point noise (e.g. 29.9999999 vs 30.0000001).
        Truncation would give those two copies different keys (29 vs 30) and
        break the vertex counting in ``perimeter``.
        """
        return int(round(v.x)), int(round(v.y))

    def _key_to_point(self, key) -> MapCord:
        """Convert an integer vertex key back into a ``MapCord``.

        The sub-unit precision discarded by ``_vertex_key`` is not recovered.
        """
        return MapCord(key[0], key[1])

    def centroid_hex(self) -> 'int | None':
        """Return the member hex whose center is closest to the mean center.

        Returns None for an empty region. Ties go to the first hex in the
        set's iteration order.
        """
        if not self.hexes:
            return None

        grid_hexes = self.hex_grid.hexes
        count = len(self.hexes)
        # Mean of all member hex centers.
        cx = sum(grid_hexes[idx].center.x for idx in self.hexes) / count
        cy = sum(grid_hexes[idx].center.y for idx in self.hexes) / count

        def dist2(idx):
            # Squared distance is enough for comparison; no sqrt needed.
            c = grid_hexes[idx].center
            return (c.x - cx) ** 2 + (c.y - cy) ** 2

        return min(self.hexes, key=dist2)

    def perimeter(self) -> list[MapCord]:
        """Return the vertices that lie on the region's perimeter.

        Every vertex of every member hex is counted (keyed by
        ``_vertex_key``). A vertex shared by 3 member hexes is interior.
        Vertices touched by only 1 or 2 member hexes are on the perimeter.
        Points are integer-valued ``MapCord`` objects rebuilt from the keys,
        in first-seen order.
        """
        vertex_counts = {}
        for idx in self.hexes:
            for v in self.hex_grid.hexes[idx].vertices():
                v_key = self._vertex_key(v)
                vertex_counts[v_key] = vertex_counts.get(v_key, 0) + 1

        return [self._key_to_point(k) for k, count in vertex_counts.items() if count < 3]

    @classmethod
    def fromPath(cls, grid: HexGrid, path: list[int]):
        """Create a HexRegion covering the hex lines between consecutive path indices.

        For each consecutive pair ``(a, b)`` in ``path``:

        - ``b`` is expressed as a HexPosition relative to ``a``;
        - every position returned by ``line_to`` is converted back to a grid
          index, and negative (invalid) indices are skipped.

        The final path index is always included.

        Args:
            grid: grid the indices refer to.
            path: hex indices in visiting order; may be empty.

        Returns:
            A new HexRegion (empty when ``path`` is empty).
        """
        if not path:
            return cls(hexes=set(), hex_grid=grid)

        adds = set()
        for start_idx, end_idx in zip(path, path[1:]):
            # Positions are relative to start_idx, so the start maps to the origin.
            start_pos = grid.index_to_hexposition(start_idx, start_idx)
            end_pos = grid.index_to_hexposition(end_idx, start_idx)
            for hexpos in start_pos.line_to(end_pos):
                idx = grid.hexposition_to_index(hexpos, start_idx)
                if idx >= 0:  # negative means no valid grid index
                    adds.add(idx)

        # A single-element path has no segments; make sure its hex is kept.
        adds.add(path[-1])
        return cls(hexes=adds, hex_grid=grid)


In [None]:
#| export
@patch
def __or__(self: HexRegion, other: 'HexRegion') -> 'HexRegion':
    """Union ``a | b``: hexes in either region, on ``self``'s grid."""
    merged = self.hexes | other.hexes
    return HexRegion(hexes=merged, hex_grid=self.hex_grid)

@patch
def __and__(self: HexRegion, other: 'HexRegion') -> 'HexRegion':
    """Intersection ``a & b``: hexes present in both regions, on ``self``'s grid."""
    shared = self.hexes & other.hexes
    return HexRegion(hexes=shared, hex_grid=self.hex_grid)

@patch
def __sub__(self: HexRegion, other: 'HexRegion') -> 'HexRegion':
    """Difference ``a - b``: hexes of ``self`` not in ``other``, on ``self``'s grid."""
    remaining = self.hexes - other.hexes
    return HexRegion(hexes=remaining, hex_grid=self.hex_grid)

@patch
def __xor__(self: HexRegion, other: 'HexRegion') -> 'HexRegion':
    """Symmetric difference ``a ^ b``: hexes in exactly one of the regions."""
    exclusive = self.hexes ^ other.hexes
    return HexRegion(hexes=exclusive, hex_grid=self.hex_grid)

@patch
def __contains__(self: HexRegion, idx: int) -> bool:
    """Support ``idx in region`` by testing the hex index set."""
    members = self.hexes
    return idx in members

@patch
def __len__(self: HexRegion) -> int:
    """Support ``len(region)``: the number of member hexes."""
    count = len(self.hexes)
    return count

@patch
def __iter__(self: HexRegion):
    """Iterate over member hex indices (set iteration order)."""
    members = self.hexes
    return iter(members)


In [None]:
#| export
@patch
def outside(self: HexRegion, ring=1):
    """Return the non-member hexes on ring ``ring`` around any member hex.

    Each offset in ``HexPosition(0, 0, 0).ring(ring)`` is applied to every
    member hex. Resulting indices are dropped when they are:

    - negative,
    - ``>= nCols * nRows``, or
    - already in the region.

    With the default ``ring=1`` this is presumably the layer of immediate
    neighbours. For larger values only that single ring is collected, not
    everything within the distance (assuming ``ring()`` returns one ring;
    confirm against HexPosition).

    Args:
        ring: which ring of offsets to apply (default 1).

    Returns:
        A new HexRegion on the same grid.
    """
    grid = self.hex_grid
    total = grid.nCols * grid.nRows
    # The offsets are relative to each hex, so compute them once.
    # list() guards against ring() returning a one-shot iterator.
    offsets = list(HexPosition(0, 0, 0).ring(ring))

    found = set()
    for index in self.hexes:
        for hp in offsets:
            neighbor = grid.hexposition_to_index(hp, index)
            # Filter out-of-bounds indices and hexes already in the region.
            if 0 <= neighbor < total and neighbor not in self.hexes:
                found.add(neighbor)
    return HexRegion(found, grid)

@patch
def apply(self: HexRegion, direction: HexPosition):
    """Translate every hex of the region by the relative offset ``direction``.

    Each member index is mapped through
    ``grid.hexposition_to_index(direction, index)``. Results outside
    ``0 <= i < nCols * nRows`` are dropped.

    Unlike ``shift``, hexes that land on a position already in the region
    are kept.

    Returns:
        A new HexRegion on the same grid.
    """
    grid = self.hex_grid
    total = grid.nCols * grid.nRows
    moved = set()
    for index in self.hexes:
        target = grid.hexposition_to_index(direction, index)
        if 0 <= target < total:
            moved.add(target)
    return HexRegion(moved, grid)

@patch
def shift(self: HexRegion, direction: HexPosition):
    """Return the hexes newly covered when the region is moved by ``direction``.

    This is ``apply(direction)`` with the original members removed. The
    result holds only the in-bounds translated hexes that were NOT already
    part of the region, so it is a "leading edge" rather than a full copy of
    the moved region.

    Returns:
        A new HexRegion on the same grid.
    """
    moved = self.apply(direction)
    return HexRegion(moved.hexes - self.hexes, self.hex_grid)

In [None]:
#| export
@patch
def inside(self: HexRegion, ring=1):
    """Return the region with its outer ``ring`` layers of hexes peeled off.

    Each pass removes every member hex that lies in
    ``current.outside().outside()``, i.e. every member adjacent to an
    in-bounds non-member hex. A hex whose only non-member neighbours are
    ones ``outside`` discards as out of bounds survives the peel.

    Args:
        ring: number of layers to remove (default 1). Earlier versions
            ignored this argument and always removed one layer; ``ring=0``
            now returns an unchanged copy.

    Returns:
        A new HexRegion on the same grid.
    """
    current = HexRegion(set(self.hexes), self.hex_grid)
    for _ in range(ring):
        if not current.hexes:
            break  # nothing left to peel
        shell = current.outside().outside()
        current = HexRegion(current.hexes - shell.hexes, self.hex_grid)
    return current

In [None]:
#| export
@patch
def styleHexes(self: HexRegion, style=StyleCSS):
    """Set the ``style`` attribute of every hex in the region to ``style``.

    NOTE(review): the default is the ``StyleCSS`` class object itself, not an
    instance (it looks like it was meant as a type annotation). Callers
    should pass a style explicitly.
    """
    grid_hexes = self.hex_grid.hexes
    for member in self.hexes:
        grid_hexes[member].style = style

In [None]:
# Demo: build a small sample grid and a style used to highlight hexes.
sgrid = PrimitiveDemo().sampleGrid(3, fill="lightgray")
perimeter_style=StyleCSS("perimeter_path", fill="red",  stroke="#ba3ca3ff", stroke_width=3)

# A three-hex region: the grid midpoint, the next index, and half the midpoint index.
region = HexRegion(set([sgrid.midpoint,sgrid.midpoint+1,sgrid.midpoint//2]), sgrid) 

# Paint the region, register the style with the SVG builder, then re-render.
region.styleHexes(style=perimeter_style)
sgrid.builder.add_style(perimeter_style)
sgrid.update()
sgrid.builder.show()

In [None]:
# Replace `region` with the ring of hexes around it and paint that ring as well.
# The original three hexes keep the style applied in the previous cell.
region = region.outside()
region.styleHexes(style=perimeter_style)
sgrid.update()
sgrid.builder.show()

In [None]:

def demoRegionFromPath():
    """Build a HexRegion from a short index path and print diagnostics.

    The function:

    1. Creates a HexGrid covering a 200x200 rectangle (hex radius 20,
       ``StyleCSS.elevations()[3]`` as base style).
    2. Prints the grid size and each path hex's row/col/center.
    3. Builds the region with ``HexRegion.fromPath`` and prints its hexes
       and perimeter vertex count.

    Returns:
        The HexRegion, or None if building or inspecting it raised (the
        error and its traceback are printed).
    """
    # Local import: the module-level imports do not include traceback, which
    # is needed on the error path below.
    import traceback

    mySize = MapSize(200, 200)
    myBounds = MapRect(MapCord(0, 0), mySize)
    baseStyle = StyleCSS.elevations()[3]
    aGrid = HexGrid.from_bounds(bounds=myBounds, style=baseStyle, radius=20)

    print(f"Grid has {aGrid.nRows} rows, {aGrid.nCols} cols = {len(aGrid.hexes)} hexes")

    # Hex indices to connect. Per index_to_row_col the indices are row-major,
    # so on a 10-column grid this runs along row 0, jumps back to the start
    # of row 1 and runs along it. It is not a diagonal.
    path = [0, 5, 10, 15]

    print(f"\nPath indices: {path}")
    print("Path hex positions:")
    for idx in path:
        row, col = aGrid.index_to_row_col(idx)
        print(f"  Index {idx}: row={row}, col={col}, center={aGrid.hexes[idx].center}")

    try:
        region = HexRegion.fromPath(aGrid, path)
        print(f"\nRegion created with {len(region.hexes)} hexes")
        print(f"Region hexes: {sorted(region.hexes)}")

        perimeter = region.perimeter()
        print(f"Perimeter has {len(perimeter)} vertices")
    except Exception as e:  # demo boundary: report and keep the notebook running
        print(f"\nError creating region: {e}")
        traceback.print_exc()
        return None
    return region




In [None]:
# Run the fromPath demo; it prints diagnostics and returns the region (or None).
demoRegionFromPath()

Grid has 10 rows, 10 cols = 100 hexes

Path indices: [0, 5, 10, 15]
Path hex positions:
  Index 0: row=0, col=0, center=(0.0,0.0)
  Index 5: row=0, col=5, center=(173.21,0.0)
  Index 10: row=1, col=0, center=(17.32,30.0)
  Index 15: row=1, col=5, center=(190.53,30.0)

Region created with 12 hexes
Region hexes: [0, 1, 2, 3, 4, 5, 10, 11, 12, 13, 14, 15]
Perimeter has 30 vertices


HexRegion(hexes={0, 1, 2, 3, 4, 5, 10, 11, 12, 13, 14, 15}, hex_grid=<HexMagic.plot.Hex.HexGrid object>)

## borders

In [None]:
#| export
@patch
def styleRegions(self: HexGrid):
    """Group the grid's hexes into one HexRegion per style name.

    Returns:
        dict mapping ``hex.style.name`` to a HexRegion of the indices that
        use that style, ordered by each style's first appearance.
    """
    ret = {}
    for i, h in enumerate(self.hexes):
        name = h.style.name
        region = ret.get(name)
        if region is None:
            # Build a region only the first time a style name is seen,
            # instead of constructing a throwaway default on every hex.
            region = ret[name] = HexRegion(set(), self)
        region.hexes.add(i)
    return ret

In [None]:
# Inspect how the sample grid's hexes group by style name.
sgrid.styleRegions().items()

dict_items([('HexStyle', HexRegion(hexes={0, 1, 2, 3, 4, 7, 8, 9, 10, 14, 15, 16, 21, 22, 27, 28, 29, 30, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48}, hex_grid=<HexMagic.plot.hex.HexGrid object>)), ('perimeter_path', HexRegion(hexes={32, 33, 5, 6, 11, 12, 13, 17, 18, 19, 20, 23, 24, 25, 26, 31}, hex_grid=<HexMagic.plot.hex.HexGrid object>))])

In [None]:
#| export
@patch
def contiguous(self: HexRegion) -> list[HexRegion]:
    """Split the region into its connected components.

    Two member hexes are connected when one is reached from the other by an
    offset in ``HexPosition(0, 0, 0).ring(1)``. Components are found with a
    flood fill over member hexes only.

    Returns:
        list of HexRegion on the same grid, one per component (empty list
        for an empty region).
    """
    grid = self.hex_grid
    # Neighbour offsets are the same for every hex; compute them once.
    # list() guards against ring() returning a one-shot iterator.
    neighbor_offsets = list(HexPosition(0, 0, 0).ring(1))
    remaining = set(self.hexes)
    regions = []

    while remaining:
        # Seed a new component with any hex not yet assigned.
        start = next(iter(remaining))
        connected = set()
        frontier = {start}

        while frontier:
            current = frontier.pop()
            if current in connected:
                continue
            connected.add(current)
            remaining.discard(current)

            # Only expand into hexes that are members and still unassigned.
            for hp in neighbor_offsets:
                neighbor_idx = grid.hexposition_to_index(hp, current)
                if neighbor_idx >= 0 and neighbor_idx in remaining:
                    frontier.add(neighbor_idx)

        regions.append(HexRegion(connected, grid))

    return regions


In [None]:
# Split the current demo region into connected components.
region.contiguous()

[HexRegion(hexes={32, 33, 5, 6, 11, 13, 17, 18, 19, 20, 23, 26, 31}, hex_grid=<HexMagic.plot.hex.HexGrid object>)]

In [None]:
#| export
@patch
def trace_boundary(self: HexRegion, verbose=False) -> list[tuple[int, int]]:
    """Walk one closed boundary of the region as (hex_index, vertex_index) steps.

    A starting edge is chosen first. An edge whose neighbour index is
    negative (off-grid) is preferred; otherwise the first edge found that
    faces a non-member hex is used. NOTE(review): in the fallback case that
    edge may face a hole, so the loop traced may be an inner boundary rather
    than the outer perimeter.

    The walk then repeatedly considers the edge (vertex, vertex + 1) of the
    current hex:

    - if the hex across that edge is a member, it jumps to that hex;
    - otherwise it advances to the next vertex of the current hex.

    It stops on returning to the start state, or after
    ``len(self.hexes) * 6`` steps with a printed warning.

    Args:
        verbose: print the start choice and the first 20 steps.

    Returns:
        list of (hex_index, vertex_index) pairs in walk order. Returns an
        empty list for an empty region or when no boundary edge is found
        ("No boundary found!" is printed in the latter case).
    """
    if not self.hexes:
        return []
    
    grid = self.hex_grid
    
    # Find a boundary hex - PREFER edges facing off-grid
    start_hex = None
    start_vertex = None
    fallback_hex = None
    fallback_vertex = None
    
    for idx in self.hexes:
        for hp in HexPosition(0, 0, 0).ring(1):
            neighbor = grid.hexposition_to_index(hp, idx)
            dir_idx = hp.direction_index()  # Use actual direction, not enumeration index
            
            if neighbor < 0:
                # Off-grid edge - best choice for perimeter
                # Start at the first vertex of the edge in this direction.
                start_hex = idx
                start_vertex = Hex._direction_to_vertices[dir_idx][0]
                if verbose:
                    print(f"Found off-grid start: hex {idx}, dir {dir_idx}, vertex {start_vertex}")
                break
            elif neighbor not in self.hexes and fallback_hex is None:
                # Non-region neighbor - save as fallback (only the first one found)
                fallback_hex = idx
                fallback_vertex = Hex._direction_to_vertices[dir_idx][0]
        if start_hex is not None:
            break
    
    # Use fallback if no off-grid edge found
    if start_hex is None:
        start_hex = fallback_hex
        start_vertex = fallback_vertex
        if verbose:
            print(f"Using fallback start: hex {start_hex}, vertex {start_vertex}")
    
    if start_hex is None:
        print("No boundary found!")
        return []
    
    path = []
    current_hex = start_hex
    current_vertex = start_vertex
    
    while True:
        path.append((current_hex, current_vertex))
        if verbose and len(path) <= 20:
            print(f"Step {len(path)}: hex={current_hex}, vertex={current_vertex}")
        
        next_vertex = (current_vertex + 1) % 6
        
        # Find direction for this edge
        # (the index into Hex._direction_to_vertices whose pair is (current, next)).
        edge_dir = None
        for dir_idx, (v1, v2) in enumerate(Hex._direction_to_vertices):
            if v1 == current_vertex and v2 == next_vertex:
                edge_dir = dir_idx
                break
        
        if edge_dir is not None:
            # NOTE(review): assumes HexPosition.directions() is ordered the same
            # way as Hex._direction_to_vertices — confirm in plot.cube / plot.hex.
            hp = HexPosition.directions()[edge_dir]
            neighbor = grid.hexposition_to_index(hp, current_hex)
            
            if verbose and len(path) <= 20:
                print(f"  Edge dir={edge_dir}, neighbor={neighbor}, in_region={neighbor in self.hexes if neighbor >= 0 else 'OOB'}")
            
            if neighbor >= 0 and neighbor in self.hexes:
                # Jump to neighbor
                # The edge is internal to the region, so the boundary continues on
                # the neighbour. The +3 re-entry offset depends on the Hex vertex
                # numbering (not visible here); presumably it names the same corner
                # as current_vertex — verify.
                current_hex = neighbor
                current_vertex = (next_vertex + 3) % 6
                if verbose and len(path) <= 20:
                    print(f"  -> Jump to hex {current_hex}, enter at vertex {current_vertex}")
            else:
                # Stay, advance vertex
                # The edge faces a non-member or off-grid hex, so it is on the boundary.
                current_vertex = next_vertex
                if verbose and len(path) <= 20:
                    print(f"  -> Stay, advance to vertex {current_vertex}")
        else:
            current_vertex = next_vertex
        
        if current_hex == start_hex and current_vertex == start_vertex:
            if verbose:
                print(f"Completed loop after {len(path)} steps")
            break
        
        # Safety cap: there are only len(hexes) * 6 (hex, vertex) states.
        if len(path) > len(self.hexes) * 6:
            print(f"Warning: path too long ({len(path)}), breaking")
            break
    
    return path

In [None]:
#| export


@patch
def boundary_to_coords(self: HexRegion, path: list[tuple[int, int]]) -> list[MapCord]:
    """Map each ``(hex_index, vertex_index)`` step to that hex's vertex point.

    Looks up ``hex_grid.hexes[hex_index].v[vertex_index]`` for every entry,
    preserving the order of ``path``.
    """
    grid_hexes = self.hex_grid.hexes
    return [grid_hexes[h].v[k] for h, k in path]

@patch
def boundary_path(self: HexRegion, style=None) -> MapPath:
    """Trace this region's boundary and wrap it in a closed MapPath.

    Args:
        style: StyleCSS for the path. When None, a fresh "boundary" style is
            created (no fill, stroke "#333", stroke_width 2).

    Returns:
        The closed MapPath over ``boundary_to_coords(trace_boundary())``.
    """
    chosen = style if style is not None else StyleCSS("boundary", fill="none", stroke="#333", stroke_width=2)
    steps = self.trace_boundary()
    return MapPath(self.boundary_to_coords(steps), chosen).closed()


In [None]:
#| export
@patch
def trace_perimeter(self: HexRegion, debug=False, style=None):
    """Return one closed boundary MapPath per connected component of the region.

    The region is split with ``contiguous()``. Each component's boundary is
    traced with ``trace_boundary`` and turned into a closed MapPath.

    Args:
        debug: forwarded to ``trace_boundary(verbose=...)`` to print the walk.
        style: StyleCSS applied to every path. When omitted, a "perimeter_path"
            style (no fill, stroke "#ba3ca3ff", stroke_width 3) is built per
            call instead of one instance shared across all calls.

    Returns:
        list of MapPath, in the order ``contiguous()`` yields components.
    """
    if style is None:
        style = StyleCSS("perimeter_path", fill="none",
                         stroke="#ba3ca3ff", stroke_width=3)

    paths = []
    for subR in self.contiguous():
        steps = subR.trace_boundary(verbose=debug)
        coords = subR.boundary_to_coords(steps)
        paths.append(MapPath(coords, style).closed())
    return paths


In [None]:
#| export
@patch
def styleLayer(self: HexGrid):
    """Return SVG markup outlining every same-style region of the grid.

    Hexes are grouped by style name via ``styleRegions``. Each group's
    connected components are outlined with ``trace_perimeter``, using the
    builder's style registered under that name. The ``svg()`` text of all
    paths is concatenated.

    NOTE(review): a style name missing from ``self.builder.styles`` is not
    handled here; the lookup failure propagates to the caller.
    """
    parts = []
    for styleName, region in self.styleRegions().items():
        style = self.builder.styles[styleName]
        parts.extend(path.svg() for path in region.trace_perimeter(style=style))
    # join avoids repeated string concatenation in the loop.
    return "".join(parts)

In [None]:
# Hand the region-outline SVG to the builder under the name "regions"
# (presumably replacing that layer), then display the result.
sgrid.builder.adjust("regions", sgrid.styleLayer())
sgrid.builder.show()