# Arizona Gerrymandering

### Imports

In [None]:
import geopandas as gpd
import numpy as np
import pandas as pd
from shapely.geometry import Point
from scipy.spatial import distance_matrix
import networkx as nx
from collections import deque
import random
import matplotlib.pyplot as plt
import warnings
import heapq
import time

# Silence the noisy warning categories so the notebook output stays readable
for _warning_category in (FutureWarning, UserWarning, DeprecationWarning):
    warnings.filterwarnings('ignore', category=_warning_category)
del _warning_category

# Seed both NumPy's and the stdlib's generators so repeated runs match
RANDOM_SEED = 111
for _seed_generator in (np.random.seed, random.seed):
    _seed_generator(RANDOM_SEED)
del _seed_generator
print(f"Random seed set to {RANDOM_SEED} for reproducibility")

Random seed set to 111 for reproducibility


In [12]:
# Set Times New Roman as the default font and fix the text sizes used by
# every figure in the notebook.
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman'],
    'font.size': 10,
    'axes.labelsize': 11,
    'axes.titlesize': 12,
    'xtick.labelsize': 9,
    'ytick.labelsize': 9,
    'legend.fontsize': 10,
})

# Professional academic color palette, keyed by the role each color plays
COLORS = {
    'efficiency_gap': '#1f77b4',      # Navy Blue - primary metric
    'compactness': '#2ca02c',         # Forest Green - quality metric
    'democratic': '#3182bd',          # Blue - Democratic party
    'republican': '#de2d26',          # Red - Republican party
    'acceptance_rate': '#ff7f0e',     # Dark Orange - algorithm performance
    'scatter': '#9467bd',             # Royal Purple - analysis
    'neutral': '#7f7f7f',             # Gray - reference lines
    'background': '#f0f0f0'           # Light gray - backgrounds
}

# The check mark was mis-encoded ("âœ“") in the original message.
print("✓ Matplotlib styled with Times New Roman font and professional color palette")

✓ Matplotlib styled with Times New Roman font and professional color palette


### Graph Construction

In [13]:
def build_adjacency_graph(gdf):
    """Build the adjacency graph of the voting districts (VTDs) in ``gdf``.

    Every index label of ``gdf`` becomes a node. Two nodes are joined by an
    edge when their geometries intersect, which includes geometries that
    merely touch along a border or at a single point.

    Parameters:
    - gdf: GeoDataFrame with one geometry per VTD

    Returns:
    - networkx.Graph whose nodes are the index labels of ``gdf``
    """
    G = nx.Graph()
    labels = list(gdf.index)
    G.add_nodes_from(labels)

    # Compare every unordered pair once. Edges are keyed by index *label*
    # (not by position) so they always refer to the nodes added above, even
    # when the frame does not carry a default 0..n-1 index.
    geometries = list(gdf.geometry)
    for i, geom_i in enumerate(geometries):
        for j in range(i + 1, len(geometries)):
            # ``intersects`` already covers ``touches``; one predicate suffices.
            if geom_i.intersects(geometries[j]):
                G.add_edge(labels[i], labels[j])

    return G

### Initialisation

In [14]:
def build_spanning_tree(vtd_indices, graph):
    """Build a random spanning tree of the given VTDs using Kruskal's algorithm.

    The edges of the subgraph induced by ``vtd_indices`` are shuffled and then
    added one by one whenever they join two components that are still
    separate (tracked with a union-find structure).

    Parameters:
    - vtd_indices: node labels of the region to span
    - graph: adjacency graph containing those nodes

    Returns:
    - networkx.Graph holding every VTD in ``vtd_indices`` as a node plus the
      accepted tree edges. When the induced subgraph is not connected the
      result is a spanning forest, so callers can detect that case by the
      number of connected components.
    """
    # Subgraph of only these VTDs
    subgraph = graph.subgraph(vtd_indices).copy()

    # Random edge order is what makes the resulting tree random
    edges = list(subgraph.edges())
    random.shuffle(edges)

    # Union-find: every VTD starts as the root of its own component
    parent = {v: v for v in vtd_indices}

    def find(v):
        # Iterative on purpose: a recursive walk can exceed Python's
        # recursion limit when a long parent chain has built up.
        root = v
        while parent[root] != root:
            root = parent[root]
        # Path compression: point every node on the walked path at the root
        while parent[v] != root:
            parent[v], v = root, parent[v]
        return root

    def union(u, v):
        root_u = find(u)
        root_v = find(v)
        if root_u != root_v:
            parent[root_u] = root_v
            return True
        return False

    tree_edges = []
    for u, v in edges:
        if union(u, v):
            tree_edges.append((u, v))
            # A spanning tree on n nodes has exactly n - 1 edges
            if len(tree_edges) == len(vtd_indices) - 1:
                break

    tree = nx.Graph(tree_edges)
    # Keep VTDs without any in-region edge as isolated nodes; otherwise they
    # would silently vanish from every cut derived from this tree.
    tree.add_nodes_from(vtd_indices)
    return tree


def find_best_balanced_cut_multi_tree(gdf, graph, vtd_list, target_left_pop, region_pop, node_repeats=10):
    """
    Try multiple spanning trees and find the best balanced cut.

    For each of ``node_repeats`` random spanning trees, every tree edge is
    removed in turn. A removal that leaves exactly two components is a
    candidate cut, scored by the larger of the two sides' relative deviations
    from their population targets. Both ways of labelling the two components
    as (left, right) are scored, because the targets differ whenever the
    region is split into an unequal number of districts.

    Parameters:
    - gdf: GeoDataFrame with a 'TOTAL' population column
    - graph: adjacency graph of the VTDs
    - vtd_list: VTDs of the region to split
    - target_left_pop: desired population of the left side
    - region_pop: total population of the region
    - node_repeats: Number of random spanning trees to try

    Returns: (best_left_vtds, best_right_vtds) or None
    """
    target_right_pop = region_pop - target_left_pop

    def relative_deviation(pop, target):
        # A non-positive target can never be met; rank such cuts last
        return abs(pop - target) / target if target > 0 else float('inf')

    best_cut = None
    best_deviation = float('inf')

    for _ in range(node_repeats):
        # Build a random spanning tree
        tree = build_spanning_tree(vtd_list, graph)

        if tree.number_of_nodes() == 0:
            continue

        # Try all edges in this tree
        for edge in tree.edges():
            tree_temp = tree.copy()
            tree_temp.remove_edge(*edge)
            components = list(nx.connected_components(tree_temp))

            # Anything other than two pieces means the tree was not a single
            # connected spanning tree, so this is not a clean bisection
            if len(components) != 2:
                continue

            side_a = list(components[0])
            side_b = list(components[1])
            pop_a = gdf.loc[side_a, 'TOTAL'].sum()
            pop_b = gdf.loc[side_b, 'TOTAL'].sum()

            # The original orientation is scored first and only a strictly
            # better score replaces the incumbent, so ties keep that order.
            orientations = (
                (side_a, side_b, pop_a, pop_b),
                (side_b, side_a, pop_b, pop_a),
            )
            for left, right, left_pop, right_pop in orientations:
                max_dev = max(
                    relative_deviation(left_pop, target_left_pop),
                    relative_deviation(right_pop, target_right_pop),
                )
                if max_dev < best_deviation:
                    best_deviation = max_dev
                    best_cut = (left, right)

    return best_cut


def recursive_tree_part_init(gdf, graph, num_districts, ideal_pop, tolerance=0.05, node_repeats=10):
    """
    Build an initial district plan by recursively bisecting the state.

    Each region is split in two with the best spanning-tree cut found by
    find_best_balanced_cut_multi_tree; the halves are split again until every
    piece stands for one district. Districts are numbered from 1.

    Note: when no tree cut is found, the region is split by list position
    instead, which does not look at adjacency. ``tolerance`` is accepted but
    not used by this function.

    Parameters:
    - gdf: GeoDataFrame with a 'TOTAL' population column
    - graph: adjacency graph of the VTDs
    - num_districts: number of districts to create
    - ideal_pop: target population per district (used for the report only)
    - node_repeats: Number of spanning trees to try per split (higher = better balance, slower)

    Returns:
    - pandas Series of district numbers, indexed like ``gdf``
    """
    print(f"  Using recursive tree partitioning (node_repeats={node_repeats})...")

    def bisect_region(vtd_list, num_parts, depth=0):
        """Split ``vtd_list`` into ``num_parts`` districts; returns {vtd: district}."""
        # A single part: the whole region is district 1 of this sub-problem
        if num_parts == 1:
            return dict.fromkeys(vtd_list, 1)

        # Fewer VTDs than districts requested: hand them out round-robin
        if len(vtd_list) < num_parts:
            return {vtd: (pos % num_parts) + 1 for pos, vtd in enumerate(vtd_list)}

        # The left side receives a population share proportional to the
        # number of districts it will be divided into
        left_parts = num_parts // 2
        right_parts = num_parts - left_parts
        region_pop = gdf.loc[vtd_list, 'TOTAL'].sum()
        target_left_pop = region_pop * (left_parts / num_parts)

        best_cut = find_best_balanced_cut_multi_tree(
            gdf, graph, vtd_list, target_left_pop, region_pop, node_repeats=node_repeats
        )
        if best_cut is None:
            # Fallback: cut the list in the middle
            mid = len(vtd_list) // 2
            best_cut = (vtd_list[:mid], vtd_list[mid:])

        left_vtds, right_vtds = best_cut
        left_assign = bisect_region(left_vtds, left_parts, depth + 1)
        right_assign = bisect_region(right_vtds, right_parts, depth + 1)

        # Right-hand district numbers continue after the left-hand ones
        combined = dict(left_assign)
        combined.update({vtd: dist + left_parts for vtd, dist in right_assign.items()})
        return combined

    # Start recursion on the whole state
    assignments = bisect_region(list(gdf.index), num_districts)

    # Group VTDs by district once, then total each district's population
    district_numbers = range(1, num_districts + 1)
    members = {number: [] for number in district_numbers}
    for vtd, dist in assignments.items():
        if dist in members:
            members[dist].append(vtd)
    district_pops = {number: gdf.loc[members[number], 'TOTAL'].sum() for number in district_numbers}

    print("  Initial district populations:")
    for number in district_numbers:
        dev = abs(district_pops[number] - ideal_pop) / ideal_pop
        print(f"    District {number}: {district_pops[number]:>10,.0f} ({dev:>6.1%} deviation)")

    ordered = list(map(assignments.__getitem__, gdf.index))
    return pd.Series(ordered, index=gdf.index)

### Constraint Calculations

In [15]:
def is_district_contiguous(gdf, graph, district):
    """Check if a district forms a connected component using Breadth First Search (BFS).

    Parameters:
    - gdf: frame with a 'district' column, indexed by VTD label
    - graph: adjacency graph exposing ``neighbors(node)``
    - district: district identifier to test

    Returns:
    - True when every VTD of the district can be reached from its first VTD
      without leaving the district; an empty district counts as contiguous.
    """
    vtds = gdf[gdf['district'] == district].index.tolist()

    if not vtds:
        return True

    # Set membership keeps each neighbour test O(1); testing against the
    # list made the whole search quadratic in the district size.
    members = set(vtds)

    # Breadth First Search to check connectivity
    visited = {vtds[0]}
    queue = deque([vtds[0]])

    while queue:
        current = queue.popleft()
        for neighbor in graph.neighbors(current):
            if neighbor in members and neighbor not in visited:
                visited.add(neighbor)
                queue.append(neighbor)

    return len(visited) == len(vtds)


def polsby_popper_score(district_stats, district):
    """
    Return the Polsby-Popper compactness score of one district.

    Score = 4π * Area / Perimeter²
    Range: 0 to 1 (1 = perfect circle)

    The area and perimeter are read from ``district_stats[district]``; a
    perimeter of zero yields a score of 0.
    """
    entry = district_stats[district]
    region_area = entry['area']
    boundary_length = entry['perimeter']

    if boundary_length != 0:
        return (4 * np.pi * region_area) / (boundary_length ** 2)
    return 0


def wasted_votes(r_votes, d_votes):
    """
    Calculate the wasted vote counts of both parties in one district.

    Wasted Votes:
    - For winning party: Votes beyond 50% + 1
    - For losing party: All votes

    Democrats win only with strictly more votes than Republicans; a tie
    (including an empty district) falls to the Republican branch.

    Parameters:
    - r_votes: Republican votes in the district
    - d_votes: Democratic votes in the district

    Returns:
    - (wasted_d, wasted_r), both never negative
    """
    total_votes = r_votes + d_votes
    # Smallest vote count that is a strict majority of the total
    half_votes = total_votes // 2 + 1

    if d_votes > r_votes:
        wasted_d = d_votes - half_votes
        wasted_r = r_votes
    else:
        wasted_d = d_votes
        # In a tie the "winner" holds fewer votes than the majority
        # threshold; clamp so the surplus cannot become -1.
        wasted_r = max(r_votes - half_votes, 0)
    return wasted_d, wasted_r


def district_stats_object(gdf, graph):
    """Collect per-district statistics into a dict keyed by district id.

    For every distinct value of gdf['district'] the entry holds contiguity,
    population, area, perimeter, Polsby-Popper score, registered voters per
    party and the wasted votes of each party.
    """
    district_stats = {}

    for district_id in gdf['district'].unique():
        members = gdf[gdf['district'] == district_id]
        # Dissolve the member VTDs into one shape to measure the district
        shape = members.unary_union

        registered_d = members['VoterReg_D'].sum()
        registered_r = members['VoterReg_R'].sum()
        wasted_d, wasted_r = wasted_votes(registered_r, registered_d)

        if shape.length > 0:
            compactness = (4 * np.pi * shape.area) / (shape.length ** 2)
        else:
            compactness = 0

        district_stats[district_id] = {
            'contiguity': is_district_contiguous(gdf, graph, district_id),
            'population': members['TOTAL'].sum(),
            'area': shape.area,
            'perimeter': shape.length,
            'polsby_popper': compactness,
            'VoterReg_D': registered_d,
            'VoterReg_R': registered_r,
            'wasted_D_votes': wasted_d,
            'wasted_R_votes': wasted_r,
        }

    return district_stats

### Constraints

In [16]:

# ------------- Constraint 1: Contiguity -------------

def is_flip_contiguous(gdf, graph, vtd_id, old_district):
    """
    Check if flipping a VTD maintains contiguity by verifying all neighbor VTDs
    in old_district can still reach each other.

    Parameters:
    - gdf: frame with a 'district' column, indexed by VTD label
    - graph: adjacency graph exposing ``neighbors(node)``
    - vtd_id: VTD that would leave ``old_district``
    - old_district: district the VTD currently belongs to

    Returns:
    - True when the neighbours of ``vtd_id`` inside ``old_district`` stay
      mutually connected without passing through ``vtd_id`` (trivially true
      with zero or one such neighbour), otherwise False.
    """
    district_of = gdf['district']

    # Neighbours of the flipped VTD that stay behind in the old district
    old_district_neighbors = [
        n for n in graph.neighbors(vtd_id) if district_of.at[n] == old_district
    ]

    if len(old_district_neighbors) <= 1:
        return True

    # One BFS from the first neighbour is enough: the district stays
    # connected around the gap exactly when every other neighbour is
    # reached. (Restarting a search per neighbour repeats the same work.)
    start = old_district_neighbors[0]
    remaining = set(old_district_neighbors[1:])
    remaining.discard(start)

    visited = {start}
    queue = deque([start])

    while queue and remaining:
        current = queue.popleft()

        for neighbor in graph.neighbors(current):
            if neighbor == vtd_id or neighbor in visited:
                continue  # Skip the flipped VTD and anything already seen
            if district_of.at[neighbor] != old_district:
                continue  # Only traverse within old_district

            visited.add(neighbor)
            remaining.discard(neighbor)
            queue.append(neighbor)

    return not remaining


# ------------- Constraint 2: Population Balance -------------

def calculate_pop_bounds(district_stats, pop_tolerance):
    """Return the (lower, upper) population limits around the ideal district size.

    The ideal size is the combined population divided by the number of
    districts; the limits sit ``pop_tolerance`` (a fraction) below and above it.
    """
    populations = [entry['population'] for entry in district_stats.values()]
    ideal_pop = sum(populations) / len(district_stats)
    return (1 - pop_tolerance) * ideal_pop, (1 + pop_tolerance) * ideal_pop


def are_all_districts_pops_balanced(district_pops, pop_lower_bound, pop_upper_bound):
    """Return True when every district population lies within the inclusive bounds."""
    for pop in district_pops.values():
        if not (pop_lower_bound <= pop <= pop_upper_bound):
            return False
    return True


def is_flip_pop_balanced(district_pops, vtd_pop, old_district, new_district, pop_lower_bound, pop_upper_bound):
    """
    Tell whether moving a VTD keeps both affected districts inside the
    inclusive population bounds.

    The donor district loses ``vtd_pop`` and the receiving district gains it;
    ``district_pops`` itself is not modified.
    """
    def within_bounds(pop):
        return pop_lower_bound <= pop <= pop_upper_bound

    donor_after = district_pops[old_district] - vtd_pop
    recipient_after = district_pops[new_district] + vtd_pop

    return within_bounds(donor_after) and within_bounds(recipient_after)


# ------------- Constraint 3: Compactness -------------

def calculate_compactness_bounds(district_stats, compactness_tolerance):
    """Return the minimum acceptable Polsby-Popper score.

    The limit is the mean score over all districts, lowered by the fraction
    ``compactness_tolerance``.
    """
    scores = [entry['polsby_popper'] for entry in district_stats.values()]
    return (1 - compactness_tolerance) * np.mean(scores)


def are_all_districts_compactness_balanced(district_stats, compactness_lower_bound):
    """Return True when no district scores below the compactness lower bound."""
    for entry in district_stats.values():
        if not (entry['polsby_popper'] >= compactness_lower_bound):
            return False
    return True


def is_flip_compactness_balanced(gdf, vtd_id, vtd_geom, old_district, new_district, compactness_lower_bound):
    """
    Tell whether both affected districts stay at or above the compactness
    lower bound once the VTD has moved.

    The donor district is measured without the VTD (and is not checked at all
    if nothing remains of it); the receiving district is measured with the
    VTD added. ``gdf`` is only read. ``vtd_geom`` is not used here.
    """
    def score(shape):
        # Polsby-Popper: 4π * area / perimeter², or 0 for a zero perimeter
        area = shape.area
        perimeter = shape.length
        if perimeter > 0:
            return (4 * np.pi * area) / (perimeter ** 2)
        return 0

    # Donor district after the VTD has left
    donor_after = gdf[(gdf['district'] == old_district) & (gdf.index != vtd_id)]
    if len(donor_after) > 0 and score(donor_after.unary_union) < compactness_lower_bound:
        return False

    # Receiving district after the VTD has joined
    recipient_after = gdf[(gdf['district'] == new_district) | (gdf.index == vtd_id)]
    return not (score(recipient_after.unary_union) < compactness_lower_bound)


# ------------- Constraint 4: Vote Efficiency -------------

def does_flip_improve_vote_efficiency(district_stats, vtd_r_votes, vtd_d_votes, old_district, new_district, party_preference=None):
    """
    Decide whether moving a VTD helps vote efficiency for the chosen party.

    The wasted votes of the two affected districts are recomputed as if the
    VTD had moved and compared with the values stored in ``district_stats``.

    Returns:
    - 'D': True if Republican waste grows by more than Democratic waste
    - 'R': True if Democratic waste grows by more than Republican waste
    - anything else: True if the combined waste of both parties shrinks
    """
    donor = district_stats[old_district]
    recipient = district_stats[new_district]

    # Wasted votes each district would have after the move
    donor_wasted_d, donor_wasted_r = wasted_votes(
        donor['VoterReg_R'] - vtd_r_votes,
        donor['VoterReg_D'] - vtd_d_votes,
    )
    recipient_wasted_d, recipient_wasted_r = wasted_votes(
        recipient['VoterReg_R'] + vtd_r_votes,
        recipient['VoterReg_D'] + vtd_d_votes,
    )

    # Positive change = more waste after the move
    change_wasted_d = (donor_wasted_d + recipient_wasted_d) - (donor['wasted_D_votes'] + recipient['wasted_D_votes'])
    change_wasted_r = (donor_wasted_r + recipient_wasted_r) - (donor['wasted_R_votes'] + recipient['wasted_R_votes'])

    if party_preference == 'D':
        return change_wasted_r > change_wasted_d
    if party_preference == 'R':
        return change_wasted_d > change_wasted_r
    # Neutral: the total waste has to go down
    return (change_wasted_d + change_wasted_r) < 0


# ------------- Combined Constraint Check with Rejection Tracking -------------

def check_flip(gdf, graph, district_stats, vtd_id, vtd_geom, vtd_pop, vtd_r_votes, vtd_d_votes,
               old_district, new_district, pop_lower_bound, pop_upper_bound,
               compactness_lower_bound, party_preference=None):
    """
    Run the four flip constraints in order and report the first one that fails.

    Order: contiguity, population, compactness, vote efficiency. Later checks
    are not evaluated once one has failed.

    Returns:
    - (True, None) when every constraint passes
    - (False, reason) otherwise, where reason is 'contiguity', 'population',
      'compactness' or 'vote_efficiency'
    """
    # Each check is wrapped in a callable so it only runs when reached
    checks = (
        ('contiguity', lambda: is_flip_contiguous(gdf, graph, vtd_id, old_district)),
        ('population', lambda: is_flip_pop_balanced(
            {d: stats['population'] for d, stats in district_stats.items()},
            vtd_pop, old_district, new_district, pop_lower_bound, pop_upper_bound
        )),
        ('compactness', lambda: is_flip_compactness_balanced(
            gdf, vtd_id, vtd_geom, old_district, new_district, compactness_lower_bound
        )),
        ('vote_efficiency', lambda: does_flip_improve_vote_efficiency(
            district_stats, vtd_r_votes, vtd_d_votes, old_district, new_district, party_preference
        )),
    )

    for reason, passes in checks:
        if not passes():
            return False, reason

    return True, None

### Flip and Execution

In [17]:
def find_boundary_vtds(gdf, graph):
    """Find all VTDs that are on district boundaries.

    A VTD is on a boundary when at least one of its graph neighbours is
    assigned to a different district. VTDs without neighbours never qualify.

    Parameters:
    - gdf: frame with a 'district' column, indexed by VTD label
    - graph: adjacency graph exposing ``neighbors(node)``

    Returns:
    - list of boundary VTD labels, in the order of ``gdf.index``
    """
    # Read the district column once; this runs on every proposal, and a
    # plain dict lookup is far cheaper than a labelled frame lookup per VTD.
    district_of = gdf['district'].to_dict()

    boundary_vtds = []
    for idx in gdf.index:
        current_dist = district_of[idx]
        if any(district_of[n] != current_dist for n in graph.neighbors(idx)):
            boundary_vtds.append(idx)

    return boundary_vtds


def propose_flip(gdf, graph):
    """Pick a random boundary VTD and a random adjacent district to move it to.

    Returns:
    - (vtd, old_dist, new_dist, vtd_geom, vtd_pop, vtd_r_votes, vtd_d_votes),
      or seven Nones when there is no boundary VTD or the chosen VTD has no
      neighbour in another district.
    """
    no_proposal = (None,) * 7

    boundary_vtds = find_boundary_vtds(gdf, graph)
    if not boundary_vtds:
        return no_proposal

    vtd = random.choice(boundary_vtds)
    old_dist = gdf.loc[vtd, 'district']
    vtd_pop = gdf.loc[vtd, 'TOTAL']
    vtd_r_votes = gdf.loc[vtd, 'VoterReg_R']
    vtd_d_votes = gdf.loc[vtd, 'VoterReg_D']
    vtd_geom = gdf.loc[vtd, 'geometry']

    # One entry per neighbouring VTD in another district, so a district that
    # borders the VTD several times is proportionally more likely to be drawn
    candidate_dists = []
    for neighbor in graph.neighbors(vtd):
        neighbor_dist = gdf.loc[neighbor, 'district']
        if neighbor_dist != old_dist:
            candidate_dists.append(neighbor_dist)

    if not candidate_dists:
        return no_proposal

    new_dist = random.choice(candidate_dists)
    return (vtd, old_dist, new_dist, vtd_geom, vtd_pop, vtd_r_votes, vtd_d_votes)


def execute_flip(gdf, district_stats, vtd_id, vtd_geom, vtd_pop, vtd_r_votes, vtd_d_votes, old_district, new_district):
    """
    Apply a flip in place: reassign the VTD in ``gdf`` and bring the stats of
    the donor and the receiving district up to date.

    Population and voter totals are adjusted by the VTD's values; area,
    perimeter, Polsby-Popper score and wasted votes are recomputed. The donor
    district's geometry fields are left untouched if it has no VTDs left.
    Nothing is returned.
    """
    def store_shape(entry, shape):
        # Area, perimeter and the Polsby-Popper score of the dissolved shape
        entry['area'] = shape.area
        entry['perimeter'] = shape.length
        if shape.length > 0:
            entry['polsby_popper'] = (4 * np.pi * shape.area) / (shape.length ** 2)
        else:
            entry['polsby_popper'] = 0

    def store_wasted(entry):
        wasted_d, wasted_r = wasted_votes(entry['VoterReg_R'], entry['VoterReg_D'])
        entry['wasted_D_votes'] = wasted_d
        entry['wasted_R_votes'] = wasted_r

    # Reassign the VTD first so the geometry queries below see the new plan
    gdf.loc[vtd_id, 'district'] = new_district

    # Donor district: remove the VTD's contribution
    donor = district_stats[old_district]
    donor['population'] -= vtd_pop
    donor['VoterReg_R'] -= vtd_r_votes
    donor['VoterReg_D'] -= vtd_d_votes

    donor_members = gdf[gdf['district'] == old_district]
    if len(donor_members) > 0:
        store_shape(donor, donor_members.unary_union)
    store_wasted(donor)

    # Receiving district: add the VTD's contribution
    recipient = district_stats[new_district]
    recipient['population'] += vtd_pop
    recipient['VoterReg_R'] += vtd_r_votes
    recipient['VoterReg_D'] += vtd_d_votes

    store_shape(recipient, gdf[gdf['district'] == new_district].unary_union)
    store_wasted(recipient)

### Redistricting Function

In [18]:
def _aggregate_metrics(district_stats):
    """Return (total wasted D votes, total wasted R votes, mean Polsby-Popper)."""
    entries = list(district_stats.values())
    total_wasted_d = sum(stats['wasted_D_votes'] for stats in entries)
    total_wasted_r = sum(stats['wasted_R_votes'] for stats in entries)
    avg_compactness = np.mean([stats['polsby_popper'] for stats in entries])
    return total_wasted_d, total_wasted_r, avg_compactness


def _history_record(district_stats, iteration, accepted):
    """Build one stats_history entry: plan-wide metrics plus a copy of the tracked per-district stats."""
    total_wasted_d, total_wasted_r, avg_compactness = _aggregate_metrics(district_stats)
    tracked = ('population', 'polsby_popper', 'VoterReg_D', 'VoterReg_R',
               'wasted_D_votes', 'wasted_R_votes')
    return {
        'iteration': iteration,
        'total_wasted_d': total_wasted_d,
        'total_wasted_r': total_wasted_r,
        'efficiency_gap': abs(total_wasted_d - total_wasted_r),
        'avg_compactness': avg_compactness,
        'accepted': accepted,
        'district_stats': {
            dist_id: {field: stats[field] for field in tracked}
            for dist_id, stats in district_stats.items()
        }
    }


def _print_district_report(heading, district_stats):
    """Print the per-district statistics under a banner with the given heading."""
    print("\n" + "="*60)
    print(heading)
    print("="*60)
    for dist_id, stats in sorted(district_stats.items()):
        print(f"\nDistrict {dist_id}:")
        print(f"  Population: {stats['population']:,}")
        print(f"  Polsby-Popper: {stats['polsby_popper']:.4f}")
        print(f"  Dem Votes: {stats['VoterReg_D']:,} | Rep Votes: {stats['VoterReg_R']:,}")
        print(f"  Wasted D: {stats['wasted_D_votes']:,} | Wasted R: {stats['wasted_R_votes']:,}")


def _print_aggregate_report(total_wasted_d, total_wasted_r, avg_compactness):
    """Print the plan-wide wasted votes, efficiency gap and average compactness."""
    print("\n" + "="*60)
    print("AGGREGATE METRICS")
    print("="*60)
    print(f"Total wasted Democratic votes: {total_wasted_d:,}")
    print(f"Total wasted Republican votes: {total_wasted_r:,}")
    print(f"Efficiency gap: {abs(total_wasted_d - total_wasted_r):,}")
    print(f"Average compactness: {avg_compactness:.4f}")


def redistrict_iterative(gdf, graph, num_iterations=1000, pop_tolerance=0.05,
                         compactness_lower_bound=0.15, party_preference='D', acceptance_rate=0.95,
                         verbose=True):
    """
    Iteratively improve a district plan by proposing single-VTD flips.

    Each iteration proposes moving one boundary VTD to a neighboring district,
    checks the flip against the constraints (check_flip) and, if valid,
    executes it with probability ``acceptance_rate``. Valid flips discarded
    by that random draw count as rejected but carry no rejection reason.

    Note: ``gdf`` is modified in place (its 'district' column is updated by
    every executed flip); the returned GeoDataFrame is a copy of the final
    state.

    Parameters:
    - gdf: GeoDataFrame with VTD data and initialized districts
    - graph: Pre-built adjacency graph
    - num_iterations: Number of iterations to run
    - pop_tolerance: Population balance tolerance (default 0.05 = ±5%)
    - compactness_lower_bound: Lower bound for Polsby-Popper score
    - party_preference: 'D', 'R', or None for vote efficiency
    - acceptance_rate: Probability of accepting valid flips
    - verbose: Print progress updates

    Returns:
    - gdf_redistricted: Updated GeoDataFrame with new district assignments
    - district_stats_new: Updated district statistics (shallow copy)
    - metrics: Dictionary with execution metrics (includes rejection_counts)
    - stats_history: List of district stats at each iteration
    - gdf_snapshots: Dictionary of GeoDataFrame snapshots at every 100 iterations
    """

    print("="*60)
    print("STARTING ITERATIVE REDISTRICTING")
    print("="*60)

    # Step 1: Create initial district stats
    print("\n[1/3] Calculating initial district statistics...")
    start_time = time.time()
    district_stats = district_stats_object(gdf, graph)
    stats_time = time.time() - start_time
    print(f"✓ Calculated stats for {len(district_stats)} districts in {stats_time:.2f}s")

    _print_district_report("INITIAL DISTRICT STATISTICS", district_stats)
    _print_aggregate_report(*_aggregate_metrics(district_stats))

    # Step 2: Calculate population bounds
    print("\n[2/3] Calculating population bounds...")
    pop_lower, pop_upper = calculate_pop_bounds(district_stats, pop_tolerance)
    total_pop = sum(stats['population'] for stats in district_stats.values())
    ideal_pop = total_pop / len(district_stats)
    print(f"✓ Ideal population: {ideal_pop:,.0f}")
    print(f"  Lower bound: {pop_lower:,.0f} ({(1-pop_tolerance)*100:.0f}%)")
    print(f"  Upper bound: {pop_upper:,.0f} ({(1+pop_tolerance)*100:.0f}%)")
    print(f"  Compactness lower bound: {compactness_lower_bound:.4f}")

    # Step 3: Initialize tracking with the starting state as iteration 0
    stats_history = [_history_record(district_stats, 0, False)]
    gdf_snapshots = {0: gdf.copy()}

    # Step 4: Run iterations
    print("\n[3/3] Running redistricting iterations...")
    print("="*60)

    propose_times = []
    check_times = []
    execute_times = []
    accepted_flips = 0
    rejected_flips = 0

    # Track rejection reasons
    rejection_counts = {
        'contiguity': 0,
        'population': 0,
        'compactness': 0,
        'vote_efficiency': 0,
        'no_proposal': 0
    }

    iteration_start = time.time()

    for i in range(num_iterations):
        iteration = i + 1
        report_progress = verbose and iteration % 100 == 0
        flip_accepted = False

        # Propose
        start_propose = time.time()
        vtd, old_dist, new_dist, vtd_geom, vtd_pop, vtd_r_votes, vtd_d_votes = propose_flip(gdf=gdf, graph=graph)
        propose_times.append(time.time() - start_propose)

        if vtd is None:
            rejected_flips += 1
            rejection_counts['no_proposal'] += 1
            if report_progress:
                print(f"Iteration {iteration}/{num_iterations}: No valid proposal")
        else:
            # Check with rejection tracking
            start_check = time.time()
            is_valid, rejection_reason = check_flip(
                gdf, graph, district_stats, vtd,
                vtd_geom=vtd_geom, vtd_pop=vtd_pop,
                vtd_r_votes=vtd_r_votes, vtd_d_votes=vtd_d_votes,
                old_district=old_dist, new_district=new_dist,
                pop_lower_bound=pop_lower, pop_upper_bound=pop_upper,
                compactness_lower_bound=compactness_lower_bound,
                party_preference=party_preference
            )
            check_times.append(time.time() - start_check)

            # Execute if valid (and the random acceptance draw succeeds)
            if is_valid and random.random() < acceptance_rate:
                start_execute = time.time()
                execute_flip(gdf, district_stats, vtd, vtd_geom, vtd_pop,
                             vtd_r_votes, vtd_d_votes, old_dist, new_dist)
                execute_times.append(time.time() - start_execute)
                accepted_flips += 1
                flip_accepted = True

                if report_progress:
                    print(f"Iteration {iteration}/{num_iterations}: ✓ ACCEPTED (VTD {vtd}: {old_dist}→{new_dist})")
            else:
                rejected_flips += 1
                if rejection_reason:
                    rejection_counts[rejection_reason] += 1
                if report_progress:
                    reason_str = f" ({rejection_reason})" if rejection_reason else ""
                    print(f"Iteration {iteration}/{num_iterations}: ✗ Rejected{reason_str}")

        # Save stats after each iteration, whatever its outcome
        stats_history.append(_history_record(district_stats, iteration, flip_accepted))

        # Save GDF snapshot every 100 iterations. This now also happens when
        # no proposal was found, so the snapshot series has no gaps.
        if iteration % 100 == 0:
            gdf_snapshots[iteration] = gdf.copy()
            if verbose:
                print(f"  📸 Saved GDF snapshot at iteration {iteration}")

    total_iteration_time = time.time() - iteration_start

    # Avoid dividing by zero when the caller asks for no iterations
    rate_denominator = max(num_iterations, 1)

    # Summary
    print("\n" + "="*60)
    print("REDISTRICTING COMPLETE")
    print("="*60)
    print(f"\nTotal iterations: {num_iterations}")
    print(f"Accepted flips: {accepted_flips} ({accepted_flips/rate_denominator*100:.1f}%)")
    print(f"Rejected flips: {rejected_flips} ({rejected_flips/rate_denominator*100:.1f}%)")
    print(f"\nTotal time: {total_iteration_time:.2f}s")
    print(f"Time per iteration: {total_iteration_time/rate_denominator*1000:.2f}ms")

    if len(propose_times) > 0:
        print(f"\nAverage propose time: {np.mean(propose_times)*1000:.2f}ms")
    if len(check_times) > 0:
        print(f"Average check time: {np.mean(check_times)*1000:.2f}ms")
    if len(execute_times) > 0:
        print(f"Average execute time: {np.mean(execute_times)*1000:.2f}ms")

    # Print rejection summary, most frequent reason first
    print("\n" + "="*60)
    print("REJECTION REASONS")
    print("="*60)
    total_rejections = sum(rejection_counts.values())
    for reason, count in sorted(rejection_counts.items(), key=lambda x: -x[1]):
        pct = (count / total_rejections * 100) if total_rejections > 0 else 0
        print(f"{reason:20s}: {count:5d} ({pct:5.1f}%)")

    # Final report
    _print_district_report("FINAL DISTRICT STATISTICS", district_stats)
    total_wasted_d, total_wasted_r, avg_compactness = _aggregate_metrics(district_stats)
    _print_aggregate_report(total_wasted_d, total_wasted_r, avg_compactness)

    print(f" Saved {len(gdf_snapshots)} GDF snapshots")
    print(f"   Snapshot iterations: {sorted(gdf_snapshots.keys())}")

    # Prepare output
    gdf_redistricted = gdf.copy()
    district_stats_new = district_stats.copy()

    metrics = {
        'num_iterations': num_iterations,
        'accepted_flips': accepted_flips,
        'rejected_flips': rejected_flips,
        'acceptance_rate': accepted_flips / rate_denominator,
        'total_time': total_iteration_time,
        'avg_propose_time': np.mean(propose_times) if propose_times else 0,
        'avg_check_time': np.mean(check_times) if check_times else 0,
        'avg_execute_time': np.mean(execute_times) if execute_times else 0,
        'total_wasted_d': total_wasted_d,
        'total_wasted_r': total_wasted_r,
        'efficiency_gap': abs(total_wasted_d - total_wasted_r),
        'avg_compactness': avg_compactness,
        'rejection_counts': rejection_counts
    }

    return gdf_redistricted, district_stats_new, metrics, stats_history, gdf_snapshots

### Visualisation

In [None]:
# ============================================================================
# SECTION 8: VISUALIZATION FUNCTIONS (FULLY STYLED)
# ============================================================================

def plot_single_map(gdf, column, title, ax, cmap='tab10', show_legend=False):
    """
    Draw one choropleth of ``gdf`` onto an existing axis.

    Args:
        gdf: Object exposing a GeoDataFrame-style ``plot`` method.
        column: Name of the column used to colour the geometries.
        title: Text placed above the map.
        ax: Axis that receives the map.
        cmap: Colormap forwarded to ``gdf.plot``.
        show_legend: Forwarded to ``gdf.plot`` as ``legend``.
    """
    # Thin black outlines keep the individual geometries visible.
    map_style = {
        'column': column,
        'cmap': cmap,
        'edgecolor': 'black',
        'linewidth': 0.1,
        'legend': show_legend,
        'ax': ax,
    }
    gdf.plot(**map_style)

    title_style = {'fontsize': 14, 'fontweight': 'bold'}
    ax.set_title(title, **title_style)

    # Hide ticks and frame around the map.
    ax.axis('off')


def create_district_colormap():
    """
    Build the nine-colour palette used to paint districts.

    Returns:
        A matplotlib ``ListedColormap`` holding nine muted colours, in the
        order listed below.
    """
    from matplotlib.colors import ListedColormap

    # (description, hex) pairs; only the hex codes are passed to matplotlib.
    palette = (
        ('muted blue', '#4e79a7'),
        ('muted orange', '#f28e2b'),
        ('muted red', '#e15759'),
        ('muted teal', '#76b7b2'),
        ('muted green', '#59a14f'),
        ('muted yellow', '#edc948'),
        ('muted purple', '#b07aa1'),
        ('muted pink', '#ff9da7'),
        ('muted brown', '#9c755f'),
    )
    hex_codes = [hex_code for _, hex_code in palette]

    return ListedColormap(hex_codes)


def plot_redistricting_evolution(gdf_snapshots, stats_history, figsize=(28, 14), save_path=None,
                                 title='Redistricting Evolution: Republican Gerrymandering Algorithm'):
    """
    Plot evolution of redistricting with maps, statistics, and population bar charts.

    One column is drawn per snapshot, with three rows: the district map, a
    text panel of aggregate metrics, and a bar chart of district populations
    coloured by the party with more registered voters (ties count as
    Republican, matching the rest of this file).

    Args:
        gdf_snapshots: Mapping of iteration number -> snapshot exposing a
            GeoDataFrame-style ``plot`` method and a 'district' column.
        stats_history: Sequence of per-iteration dicts. Keys read here are
            'iteration', 'district_stats', 'efficiency_gap', 'total_wasted_d',
            'total_wasted_r' and 'avg_compactness'; every 'district_stats'
            entry must provide 'VoterReg_D', 'VoterReg_R' and 'population'.
        figsize: Figure size passed to ``plt.figure``.
        save_path: If truthy, the finished figure is written there at 300 dpi.
        title: Figure-level title (defaults to the original hard-coded text).

    Returns:
        None. The figure is shown with ``plt.show()``. When ``gdf_snapshots``
        is empty a message is printed and nothing is drawn.
    """
    if not gdf_snapshots:
        # A GridSpec with zero columns cannot be built, so bail out early.
        print("No snapshots to plot")
        return

    from matplotlib.gridspec import GridSpec

    # Get sorted iterations
    iterations = sorted(gdf_snapshots)
    n_snapshots = len(iterations)

    # Index the history once; the first entry wins when an iteration repeats.
    stats_by_iteration = {}
    for entry in stats_history:
        stats_by_iteration.setdefault(entry['iteration'], entry)

    # Create figure with custom grid (3 rows: maps, metrics, population bars)
    fig = plt.figure(figsize=figsize)
    gs = GridSpec(3, n_snapshots, figure=fig, hspace=0.4, wspace=0.3)

    # Custom professional colormap for districts
    cmap_districts = create_district_colormap()

    # Plot each snapshot
    for idx, iter_num in enumerate(iterations):
        gdf_snap = gdf_snapshots[iter_num]

        stats = stats_by_iteration.get(iter_num)
        if stats is None:
            # No metrics recorded for this snapshot: its column stays blank.
            continue

        district_stats = stats['district_stats']
        district_ids = sorted(district_stats)
        dem_held = [
            district_stats[d]['VoterReg_D'] > district_stats[d]['VoterReg_R']
            for d in district_ids
        ]
        dem_districts = sum(dem_held)
        total_districts = len(district_stats)

        # Row 1: Map
        ax_map = fig.add_subplot(gs[0, idx])
        gdf_snap.plot(column='district',
                      cmap=cmap_districts,
                      edgecolor='black',
                      linewidth=0.15,
                      ax=ax_map,
                      legend=False)
        ax_map.set_title(f'Iteration {iter_num}\nDem Districts: {dem_districts}/{total_districts}',
                         fontsize=12, fontweight='bold')
        ax_map.axis('off')

        # Row 2: Key metrics text
        ax_text = fig.add_subplot(gs[1, idx])
        ax_text.axis('off')

        metrics_text = "\n".join([
            f"Efficiency Gap: {stats['efficiency_gap']:,}",
            "",
            "Wasted Votes:",
            f"  D: {stats['total_wasted_d']:,}",
            f"  R: {stats['total_wasted_r']:,}",
            "",
            "Avg Compactness:",
            f"  {stats['avg_compactness']:.4f}",
        ])

        ax_text.text(0.1, 0.5, metrics_text,
                     fontsize=10,
                     family='Times New Roman',
                     verticalalignment='center',
                     transform=ax_text.transAxes)

        # Row 3: District Population Bar Chart
        ax_pop = fig.add_subplot(gs[2, idx])

        populations = [district_stats[d]['population'] for d in district_ids]

        # Color bars by party winner (red/blue for partisan)
        colors = [
            COLORS['democratic'] if is_dem else COLORS['republican']
            for is_dem in dem_held
        ]

        # Ideal population line: equal split of the total across districts
        ideal_pop = sum(populations) / len(district_ids)

        # Plot bars
        ax_pop.bar(district_ids, populations, color=colors, alpha=0.75,
                   edgecolor='black', linewidth=1.2)
        ax_pop.axhline(ideal_pop, color='black', linestyle='--', linewidth=1.8, alpha=0.7)

        ax_pop.set_xlabel('District', fontsize=9)
        ax_pop.set_ylabel('Population', fontsize=9)
        ax_pop.set_title('District Populations', fontsize=10, fontweight='bold')
        ax_pop.tick_params(labelsize=8)
        ax_pop.grid(True, alpha=0.3, axis='y')
        ax_pop.set_ylim([0, max(populations) * 1.1])

    plt.suptitle(title, fontsize=18, fontweight='bold', y=0.98)

    # Lay out BEFORE saving so the file on disk matches what is displayed.
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"✓ Saved figure to {save_path}")

    plt.show()


def plot_metrics_over_time(stats_history, figsize=(16, 10), save_path=None, window_size=50):
    """
    Plot how key metrics evolve over iterations on a 2x3 grid of panels.

    Panels: efficiency gap, average compactness, rolling acceptance rate
    (top row); wasted votes per party, number of Democratic districts, and an
    efficiency-gap-vs-compactness scatter coloured by iteration (bottom row).

    Args:
        stats_history: Sequence of per-iteration dicts. Keys read here are
            'iteration', 'efficiency_gap', 'avg_compactness',
            'total_wasted_d', 'total_wasted_r', 'accepted' and
            'district_stats' (whose entries provide 'VoterReg_D' and
            'VoterReg_R').
        figsize: Figure size passed to ``plt.subplots``.
        save_path: If truthy, the figure is written there at 300 dpi.
        window_size: Number of trailing entries averaged for the rolling
            acceptance rate. Must be at least 1.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    iterations = [s['iteration'] for s in stats_history]
    efficiency_gaps = [s['efficiency_gap'] for s in stats_history]
    avg_compactness = [s['avg_compactness'] for s in stats_history]
    total_wasted_d = [s['total_wasted_d'] for s in stats_history]
    total_wasted_r = [s['total_wasted_r'] for s in stats_history]

    # Count Democratic districts at each iteration (ties count as Republican)
    dem_districts = [
        sum(1 for dist_stats in s['district_stats'].values()
            if dist_stats['VoterReg_D'] > dist_stats['VoterReg_R'])
        for s in stats_history
    ]

    # Size the "Democratic Districts" axis from the data rather than assuming
    # nine districts; nine is only the fallback for an empty history.
    n_districts = max((len(s['district_stats']) for s in stats_history), default=9)

    # Rolling acceptance rate over the trailing `window_size` entries. The
    # very first point is pinned to 0, as in the original implementation.
    accepted = [1 if s['accepted'] else 0 for s in stats_history]
    acceptance_rates = [
        0 if i == 0 else np.mean(accepted[max(0, i - window_size + 1):i + 1])
        for i in range(len(accepted))
    ]

    # Create 2x3 subplot grid
    fig, axes = plt.subplots(2, 3, figsize=figsize)

    def _style_axis(ax, panel_title, xlabel, ylabel):
        """Apply the title/label/grid styling shared by every panel."""
        ax.set_title(panel_title, fontweight='bold', fontsize=13)
        ax.set_xlabel(xlabel, fontsize=11)
        ax.set_ylabel(ylabel, fontsize=11)
        ax.grid(True, alpha=0.3)
        ax.tick_params(labelsize=10)

    # Plot 1: Efficiency Gap
    axes[0, 0].plot(iterations, efficiency_gaps, linewidth=2.5,
                    color=COLORS['efficiency_gap'])
    _style_axis(axes[0, 0], 'Efficiency Gap', 'Iteration', 'Efficiency Gap')

    # Plot 2: Average Compactness
    axes[0, 1].plot(iterations, avg_compactness, linewidth=2.5,
                    color=COLORS['compactness'])
    _style_axis(axes[0, 1], 'Average Compactness (Polsby-Popper)',
                'Iteration', 'Polsby-Popper Score')

    # Plot 3: Wasted Votes (Blue/Red for partisan)
    axes[1, 0].plot(iterations, total_wasted_d, linewidth=2.5,
                    label='Democratic', color=COLORS['democratic'])
    axes[1, 0].plot(iterations, total_wasted_r, linewidth=2.5,
                    label='Republican', color=COLORS['republican'])
    _style_axis(axes[1, 0], 'Wasted Votes', 'Iteration', 'Wasted Votes')
    axes[1, 0].legend(fontsize=10)

    # Plot 4: Democratic Districts
    axes[1, 1].plot(iterations, dem_districts, linewidth=2.5,
                    color=COLORS['scatter'])
    _style_axis(axes[1, 1], 'Democratic Districts', 'Iteration', 'Number of Dem Districts')
    axes[1, 1].set_ylim([0, n_districts + 1])

    # Plot 5: Acceptance Rate
    axes[0, 2].plot(iterations, acceptance_rates, linewidth=2.5,
                    color=COLORS['acceptance_rate'])
    _style_axis(axes[0, 2], f'Acceptance Rate (Rolling {window_size})',
                'Iteration', 'Acceptance Rate')
    axes[0, 2].set_ylim([0, 1])

    # Plot 6: Efficiency Gap vs Compactness Scatter
    # Sequential blue shading (light to dark) encodes the iteration number.
    from matplotlib.colors import LinearSegmentedColormap
    colors_gradient = ['#c6dbef', '#9ecae1', '#6baed6', '#4292c6', '#2171b5', '#084594']
    n_bins = 100
    cmap_custom = LinearSegmentedColormap.from_list('blues_custom', colors_gradient, N=n_bins)

    scatter = axes[1, 2].scatter(avg_compactness, efficiency_gaps,
                                 c=iterations, cmap=cmap_custom,
                                 s=50, alpha=0.7, edgecolors='black', linewidth=0.7)
    _style_axis(axes[1, 2], 'Efficiency Gap vs Compactness',
                'Average Compactness', 'Efficiency Gap')

    # Add colorbar
    cbar = plt.colorbar(scatter, ax=axes[1, 2])
    cbar.set_label('Iteration', rotation=270, labelpad=15, fontsize=10)
    cbar.ax.tick_params(labelsize=9)

    plt.suptitle('Redistricting Metrics Over Time', fontsize=18, fontweight='bold')
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"✓ Saved figure to {save_path}")

    plt.show()


def plot_rejection_reasons(metrics, figsize=(10, 6), save_path=None):
    """
    Plot a bar chart of flip rejection reasons, most frequent first.

    Each bar is annotated with its count and its share of all rejections.

    Args:
        metrics: Dict containing 'rejection_counts', a mapping of reason
            name -> number of rejections. Underscores in reason names are
            replaced by spaces and the result is title-cased for display.
        figsize: Figure size passed to ``plt.subplots``.
        save_path: If truthy, the figure is written there at 300 dpi.

    Returns:
        The matplotlib figure that was drawn.
    """
    rejection_counts = metrics['rejection_counts']

    # Sort by count (descending)
    sorted_reasons = sorted(rejection_counts.items(), key=lambda item: -item[1])
    reasons = [name.replace('_', ' ').title() for name, _ in sorted_reasons]
    counts = [count for _, count in sorted_reasons]

    # Share of all rejections; guarded so an all-zero tally yields 0%.
    total = sum(counts)
    percentages = [(c / total * 100) if total > 0 else 0 for c in counts]

    fig, ax = plt.subplots(figsize=figsize)

    # Use consistent color palette
    bar_colors = [
        COLORS['efficiency_gap'],
        COLORS['compactness'],
        COLORS['acceptance_rate'],
        COLORS['scatter'],
        COLORS['neutral'],
    ]

    bars = ax.bar(reasons, counts, color=bar_colors[:len(reasons)],
                  alpha=0.8, edgecolor='black', linewidth=1.5)

    # Add value labels centred above each bar
    for bar, count, pct in zip(bars, counts, percentages):
        ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height(),
                f'{count:,}\n({pct:.1f}%)',
                ha='center', va='bottom', fontweight='bold', fontsize=11)

    # Formatting
    ax.set_xlabel('Rejection Reason', fontsize=13, fontweight='bold')
    ax.set_ylabel('Number of Rejections', fontsize=13, fontweight='bold')
    ax.set_title('Flip Rejection Reasons', fontsize=15, fontweight='bold', pad=20)
    ax.grid(axis='y', alpha=0.3, linestyle='--')
    ax.tick_params(labelsize=11)

    # Rotate x-axis labels
    plt.xticks(rotation=45, ha='right')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"✓ Saved rejection reasons chart to {save_path}")

    plt.show()

    return fig


def print_district_summary(stats_history, iteration=None):
    """
    Print a detailed summary for one iteration (default: the last entry).

    Args:
        stats_history: Sequence of per-iteration dicts. Keys read here are
            'iteration', 'efficiency_gap', 'avg_compactness',
            'total_wasted_d', 'total_wasted_r' and 'district_stats'; every
            'district_stats' entry must provide 'VoterReg_D', 'VoterReg_R',
            'population', 'polsby_popper', 'wasted_D_votes' and
            'wasted_R_votes'.
        iteration: Iteration number to report. When None, the last element
            of ``stats_history`` is used. When several entries share the
            number, the first one is reported.

    Returns:
        None. If the history is empty or the iteration is not found, a
        one-line message is printed instead of the summary.
    """
    if iteration is None:
        if not stats_history:
            # Nothing recorded yet; avoid an IndexError on stats_history[-1].
            print("No stats available: stats_history is empty")
            return
        stats = stats_history[-1]
    else:
        stats = next((s for s in stats_history if s['iteration'] == iteration), None)
        if stats is None:
            print(f"No stats found for iteration {iteration}")
            return

    district_stats = stats['district_stats']

    print("="*80)
    print(f"DISTRICT SUMMARY - ITERATION {stats['iteration']}")
    print("="*80)

    print("\nAGGREGATE METRICS:")
    print(f"  Efficiency Gap: {stats['efficiency_gap']:,}")
    print(f"  Avg Compactness: {stats['avg_compactness']:.4f}")
    print(f"  Total Wasted D: {stats['total_wasted_d']:,}")
    print(f"  Total Wasted R: {stats['total_wasted_r']:,}")

    # Count Democratic districts (strictly more D than R; ties count as R)
    dem_count = sum(1 for dist_stats in district_stats.values()
                    if dist_stats['VoterReg_D'] > dist_stats['VoterReg_R'])
    print(f"  Democratic Districts: {dem_count}/{len(district_stats)}")

    print("\nDISTRICT DETAILS:\n")

    for dist_id in sorted(district_stats):
        dist_stats = district_stats[dist_id]
        votes_d = dist_stats['VoterReg_D']
        votes_r = dist_stats['VoterReg_R']
        winner = 'D' if votes_d > votes_r else 'R'
        margin = abs(votes_d - votes_r)

        print(f"  District {dist_id} (Winner: {winner}, Margin: {margin:,}):")
        print(f"    Population: {dist_stats['population']:,}")
        print(f"    Compactness: {dist_stats['polsby_popper']:.4f}")
        print(f"    D Votes: {votes_d:,} | R Votes: {votes_r:,}")
        print(f"    Wasted D: {dist_stats['wasted_D_votes']:,} | Wasted R: {dist_stats['wasted_R_votes']:,}")
        print()

### Main Execution

In [None]:
if __name__ == "__main__":
    # End-to-end run: load the VTD shapefile, build the adjacency graph,
    # create an initial district plan, run the iterative redistricting, then
    # plot and print the results.

    # Configuration
    NUM_DISTRICTS = 9
    POP_TOLERANCE = 0.1  # 10% tolerance
    NUM_ITERATIONS = 10000
    COMPACTNESS_LOWER_BOUND = 0.15
    PARTY_PREFERENCE = 'R'  # 'D', 'R', or None
    ACCEPTANCE_RATE = 0.8
    NODE_REPEATS = 10  # Number of spanning trees to try during initialization

    # Load shapefile
    print("="*60)
    print("LOADING DATA AND INITIALIZING")
    print("="*60)
    print("\nLoading shapefile...")
    gdf = gpd.read_file('/Data/VTD.shp')

    # CRITICAL: Reset index
    gdf = gdf.reset_index(drop=True)
    print(f"✓ Loaded {len(gdf)} VTDs")

    # Calculate ideal population
    total_population = gdf['TOTAL'].sum()
    ideal_pop = total_population / NUM_DISTRICTS
    print(f"✓ Total population: {total_population:,}")
    print(f"✓ Ideal population per district: {ideal_pop:,.0f}")

    # Build adjacency graph
    print("\nBuilding adjacency graph...")
    start_time = time.time()
    graph = build_adjacency_graph(gdf)
    graph_time = time.time() - start_time
    print(f"✓ Built graph with {graph.number_of_nodes()} nodes and {graph.number_of_edges()} edges in {graph_time:.2f}s")

    # Initialize districts using recursive tree partitioning
    print("\nInitializing districts with recursive tree partitioning...")
    gdf['district'] = recursive_tree_part_init(gdf, graph, NUM_DISTRICTS, ideal_pop,
                                                tolerance=POP_TOLERANCE, node_repeats=NODE_REPEATS)

    # Optional: Save initialized map
    initialized_path = '/Data/Initial_Map.shp'
    gdf.to_file(initialized_path)
    print(f"\n✓ Initialized map saved to: {initialized_path}")

    # Run iterative redistricting
    gdf_redistricted, district_stats_new, metrics, stats_history, gdf_snapshots = redistrict_iterative(
        gdf=gdf,
        graph=graph,  # Pass the pre-built graph
        num_iterations=NUM_ITERATIONS,
        pop_tolerance=POP_TOLERANCE,
        compactness_lower_bound=COMPACTNESS_LOWER_BOUND,
        party_preference=PARTY_PREFERENCE,
        acceptance_rate=ACCEPTANCE_RATE,
        verbose=True
    )

    # Visualizations
    print("\n" + "="*60)
    print("GENERATING VISUALIZATIONS")
    print("="*60)

    # Six evenly spaced checkpoints (0%, 20%, ..., 100% of the run). Derived
    # from NUM_ITERATIONS so they stay valid when the run length changes;
    # for 10000 iterations this is 0, 2000, 4000, 6000, 8000, 10000.
    snapshot_step = max(1, NUM_ITERATIONS // 5)
    selected_iterations = set(range(0, NUM_ITERATIONS + 1, snapshot_step))
    filtered_snapshots = {k: v for k, v in gdf_snapshots.items() if k in selected_iterations}

    # Plot evolution (skipped when no stored snapshot matches a checkpoint)
    if filtered_snapshots:
        plot_redistricting_evolution(
            filtered_snapshots,
            stats_history,
            figsize=(20, 12),
            save_path='redistricting_evolution.png'
        )
    else:
        print(f"No snapshots found for iterations {sorted(selected_iterations)}; skipping evolution plot")

    # Plot metrics over time
    plot_metrics_over_time(
        stats_history,
        figsize=(16, 10),
        save_path='metrics_evolution.png'
    )

    # Plot rejection reasons
    plot_rejection_reasons(
        metrics,
        figsize=(10, 6),
        save_path='rejection_reasons.png'
    )

    # Print detailed summary for final iteration
    print_district_summary(stats_history)

    # Print summaries at 10% and 50% of the run (iterations 1000 and 5000 for
    # a 10000-iteration run), but only when those iterations were recorded.
    recorded_iterations = {s['iteration'] for s in stats_history}
    for checkpoint in sorted({NUM_ITERATIONS // 10, NUM_ITERATIONS // 2}):
        if checkpoint in recorded_iterations:
            print_district_summary(stats_history, iteration=checkpoint)