In [1]:
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from typing import List, Optional, Dict, Set
from brancharchitect.jumping_taxa.lattice.lattice_edge import LatticeEdge

def _cover_color_map(night_mode: bool) -> Dict[str, str]:
    """Return the colour palette for dark (night) or light rendering.

    Both palettes share the side colours ("tree1"/"tree2"), the accent colours
    and the relationship colours; only background/text/frame colours differ.
    """
    if night_mode:
        palette = {
            "background": "#080b17",
            "grid": "#192042",
            "text": "#ffffff",
            "node_fill": "#080b17",
            "node_border": "#5928f2",
        }
    else:
        palette = {
            "background": "#ffffff",
            "grid": "#cccccc",
            "text": "#000000",
            "node_fill": "#ffffff",
            "node_border": "#000000",
        }
    palette.update({
        "tree1": "#ee4266",  # left side
        "tree2": "#00cecb",  # right side
        "common": "#5928f2",
        "highlight": "#ffd23f",
        # Relationship type colors
        "divergent": "#FF4444",
        "collapsed": "#44FF44",
        "intermediate": "#4444FF",
        "mixed": "#FFFF44",
    })
    return palette


def _classify_cover(pset: "PartitionSet", child_meet: "PartitionSet") -> str:
    """Classify a cover set by how it relates to ``child_meet``.

    Returns "divergent" when the set shares no partition with ``child_meet``
    (this includes the empty set), "collapsed" when every partition is also in
    ``child_meet``, and "intermediate" otherwise.
    """
    if not (pset & child_meet):
        return "divergent"
    if pset.issubset(child_meet):
        return "collapsed"
    return "intermediate"


def _format_partition(partition) -> str:
    """Render a partition as ``{name,...}`` via its ``lookup``, else ``str(partition)``.

    ``lookup`` is read as a name -> index mapping (as the original scan did);
    the taxon names whose index appears in ``partition.indices`` are listed sorted.
    """
    lookup = getattr(partition, "lookup", None)
    if not lookup:
        return f"{partition}"
    wanted = set(partition.indices)  # one pass over lookup instead of one per index
    names = sorted(name for name, index in lookup.items() if index in wanted)
    return "{" + ",".join(names) + "}"


def _create_enhanced_label(
    pset: "PartitionSet", side: str, child_meet: "PartitionSet", max_shown: int = 3
) -> str:
    """Build the multi-line node label for one cover set.

    The label holds the set size, its relationship to ``child_meet``, the
    ``max_shown`` smallest partitions (each marked ✓/✗ for membership in
    ``child_meet``), and a summary line for the partitions not shown.
    """
    partitions = sorted(pset, key=lambda p: len(p.indices))
    count = len(partitions)
    if count == 0:
        return f"{side} Cover\n(empty set)"

    lines = [
        f"{side} Cover ({count} partitions)",
        f"Type: {_classify_cover(pset, child_meet)}",
    ]
    for partition in partitions[:max_shown]:
        marker = "✓" if partition in child_meet else "✗"
        lines.append(f"· {_format_partition(partition)} {marker}")

    # Summarise exactly the partitions that were NOT listed above (the sorted tail).
    hidden = partitions[max_shown:]
    if hidden:
        hidden_in_child = sum(1 for p in hidden if p in child_meet)
        lines.append(f"· ... ({len(hidden)} more, {hidden_in_child} in child_meet)")
    return "\n".join(lines)


def _build_cover_graph(edge: "LatticeEdge"):
    """Create the DiGraph of cover sets for ``edge``.

    Nodes ``L<i>``/``R<j>`` carry ``pset``, ``side``, ``label``, ``size`` and
    ``relationship``. An edge ``L<i> -> R<j>`` exists when the two sets share
    partitions and carries the shared count plus overlap percentages.

    Returns:
        (graph, left_node_ids, right_node_ids)
    """
    graph = nx.DiGraph()
    child_meet = edge.child_meet
    node_ids = {"left": [], "right": []}
    columns = (
        ("left", "L", "Left", edge.t1_common_covers),
        ("right", "R", "Right", edge.t2_common_covers),
    )
    for side, prefix, display_name, covers in columns:
        for i, pset in enumerate(covers):
            node_id = f"{prefix}{i}"
            graph.add_node(
                node_id,
                pset=pset,
                side=side,
                label=_create_enhanced_label(pset, display_name, child_meet),
                size=100 + (len(pset) * 50),
                relationship=_classify_cover(pset, child_meet),
            )
            node_ids[side].append(node_id)

    for l_id in node_ids["left"]:
        pset_left = graph.nodes[l_id]['pset']
        for r_id in node_ids["right"]:
            pset_right = graph.nodes[r_id]['pset']
            shared = pset_left & pset_right
            if not shared:
                continue
            # shared is non-empty here, so neither side (nor shared) can be empty.
            shared_in_child = shared & child_meet
            graph.add_edge(
                l_id, r_id,
                weight=len(shared),
                overlap_left=int((len(shared) / len(pset_left)) * 100),
                overlap_right=int((len(shared) / len(pset_right)) * 100),
                child_meet_overlap=int((len(shared_in_child) / len(shared)) * 100),
            )
    return graph, node_ids["left"], node_ids["right"]


def _column_positions(node_ids: List[str], x: float) -> Dict[str, tuple]:
    """Place ``node_ids`` top-to-bottom over y in [0.1, 0.9] at horizontal position ``x``.

    A lone node is centred (y=0.5) rather than pinned to the bottom.
    """
    total = len(node_ids)
    if total == 1:
        return {node_ids[0]: (x, 0.5)}
    return {
        node_id: (x, 0.1 + ((total - i - 1) / (total - 1)) * 0.8)
        for i, node_id in enumerate(node_ids)
    }


def _edge_color(child_meet_pct: int, color_map: Dict[str, str]) -> str:
    """Pick the edge colour for a child_meet overlap percentage (matches the legend bands)."""
    if child_meet_pct > 80:
        return color_map["common"]  # strong overlap with child_meet
    if child_meet_pct >= 40:
        return color_map["intermediate"]  # partial overlap (40-80%)
    return color_map["divergent"]  # weak or no overlap (<40%)


def _draw_overlap_edges(G, pos, color_map: Dict[str, str]) -> None:
    """Draw one slightly curved edge per overlapping pair plus its overlap label."""
    edge_labels = {}
    for u, v, data in G.edges(data=True):
        weight = data.get('weight', 0)
        left_pct = data.get('overlap_left', 0)
        right_pct = data.get('overlap_right', 0)
        child_pct = data.get('child_meet_overlap', 0)
        edge_labels[(u, v)] = f"{weight} ({left_pct}%|{right_pct}%) [{child_pct}% in child]"
        nx.draw_networkx_edges(
            G, pos,
            edgelist=[(u, v)],
            width=1 + min(weight / 3, 6),  # limit max width
            alpha=0.8,
            edge_color=_edge_color(child_pct, color_map),
            connectionstyle="arc3,rad=0.1",  # slight curve
        )

    nx.draw_networkx_edge_labels(
        G, pos,
        edge_labels=edge_labels,
        label_pos=0.5,
        font_color=color_map["text"],
        font_size=9,
        font_family='monospace',
        font_weight='bold',
        bbox=dict(facecolor=color_map["background"], edgecolor="none", alpha=0.7),
    )


def _draw_cover_nodes(G, pos, color_map: Dict[str, str]) -> None:
    """Draw each node (fill = side colour, border = relationship colour) and its text label."""
    line_height = 0.025
    for node, attrs in G.nodes(data=True):
        fill = color_map["tree1"] if attrs['side'] == 'left' else color_map["tree2"]
        nx.draw_networkx_nodes(
            G,
            pos=pos,
            nodelist=[node],
            node_color=fill,
            node_size=attrs.get('size', 1000),
            edgecolors=color_map[attrs['relationship']],
            linewidths=3,
        )
        x, y = pos[node]
        # Stack the label lines downward, starting slightly above the node centre.
        for i, line in enumerate(attrs['label'].split('\n')):
            plt.text(
                x, y + (0.06 - i * line_height),
                line,
                ha='center',
                va='center',
                color=color_map["text"],
                fontsize=9,
                fontweight='bold',
                fontfamily='monospace',
                bbox=dict(facecolor=fill, edgecolor="none", alpha=0.7, pad=2),
            )


def _draw_legend(color_map: Dict[str, str]) -> None:
    """Add the legend explaining side fills, edge colours and border colours."""
    legend_elements = [
        mpatches.Patch(color=color_map["tree1"], label="Left Cover", alpha=0.9),
        mpatches.Patch(color=color_map["tree2"], label="Right Cover", alpha=0.9),
        mpatches.Patch(color=color_map["common"], label="Strong Child Overlap (>80%)", alpha=0.9),
        mpatches.Patch(color=color_map["intermediate"], label="Partial Child Overlap (40-80%)", alpha=0.9),
        mpatches.Patch(color=color_map["divergent"], label="Weak Child Overlap (<40%)", alpha=0.9),
        mpatches.Patch(edgecolor=color_map["collapsed"], facecolor='none', linewidth=2, label="Collapsed (subset of child_meet)"),
        mpatches.Patch(edgecolor=color_map["divergent"], facecolor='none', linewidth=2, label="Divergent (no overlap with child_meet)"),
    ]
    lgd = plt.legend(
        handles=legend_elements,
        loc="lower center",
        frameon=True,
        fancybox=False,
        framealpha=0.8,
        shadow=False,
        edgecolor=color_map["node_border"],
        fontsize=10,
        bbox_to_anchor=(0.5, 0.05),  # position near the bottom of the plot
        ncol=2,
    )
    for text in lgd.get_texts():
        text.set_color(color_map["text"])


def _union_of(psets):
    """Return the union of ``psets`` (without mutating any of them), or None if empty."""
    union = None
    for pset in psets:
        union = pset if union is None else union | pset
    return union


def _cover_summary(G, edge: "LatticeEdge", left_labels: List[str], right_labels: List[str]) -> str:
    """Build the two-line statistics text shown at the top of the figure."""
    left_psets = [G.nodes[n]['pset'] for n in left_labels]
    right_psets = [G.nodes[n]['pset'] for n in right_labels]

    left_union = _union_of(left_psets)
    right_union = _union_of(right_psets)
    shared = 0 if left_union is None or right_union is None else len(left_union & right_union)

    def count(labels: List[str], relationship: str) -> int:
        return sum(1 for n in labels if G.nodes[n]['relationship'] == relationship)

    return (
        f"Statistics: {sum(len(ps) for ps in left_psets)} partitions in left cover, "
        f"{sum(len(ps) for ps in right_psets)} in right cover, {shared} shared\n"
        f"Child meet: {len(edge.child_meet)} partitions, "
        f"Left: {count(left_labels, 'divergent')} divergent, {count(left_labels, 'collapsed')} collapsed, "
        f"Right: {count(right_labels, 'divergent')} divergent, {count(right_labels, 'collapsed')} collapsed"
    )


def visualize_cover_sets(
    edge: LatticeEdge,
    title: str = "Comparison of Left vs. Right Covers",
    night_mode: bool = True,
    output_path: Optional[str] = None
):
    """
    Plot the PartitionSets of ``edge.t1_common_covers`` (left column) against
    those of ``edge.t2_common_covers`` (right column).

    - Each cover set is a node labelled with its size, its relationship to
      ``edge.child_meet`` (divergent / collapsed / intermediate) and up to three
      of its smallest partitions, each marked ✓/✗ for membership in child_meet.
    - Node size grows with the number of partitions; the node border colour
      encodes the relationship to child_meet; the fill encodes the side.
    - A left/right pair is connected when the sets share partitions. The edge
      label shows the shared count, the share of each side, and the share of
      the shared partitions lying in child_meet (which also sets the colour).

    Args:
        edge: Lattice edge providing ``t1_common_covers``, ``t2_common_covers``,
            ``child_meet`` and, optionally, ``split``.
        title: Plot title; " for split <split>" is appended when available.
        night_mode: Use the dark palette and monospace rcParams. The rcParams
            overrides are active only while this figure is built and are not
            left behind in the global matplotlib state.
        output_path: If given, save the figure there at 300 dpi and close it;
            otherwise display it with ``plt.show()``.
    """
    color_map = _cover_color_map(night_mode)
    rc_overrides = {}
    if night_mode:
        rc_overrides = {
            'font.family': 'monospace',
            'font.monospace': ['Menlo', 'DejaVu Sans Mono', 'Courier New'],
            'font.size': 9,
            'axes.titlesize': 16,
            'axes.labelsize': 14,
            'figure.facecolor': color_map["background"],
            'axes.facecolor': color_map["background"],
        }

    G, left_labels, right_labels = _build_cover_graph(edge)
    pos = {
        **_column_positions(left_labels, -0.6),
        **_column_positions(right_labels, 0.6),
    }

    # Scope the rcParams to this figure instead of mutating them globally.
    with plt.rc_context(rc_overrides):
        fig_width = max(16, 10 + min(len(left_labels), len(right_labels)))
        fig_height = max(8, 6 + max(len(left_labels), len(right_labels)))
        fig = plt.figure(figsize=(fig_width, fig_height), dpi=100, facecolor=color_map["background"])
        ax = plt.gca()
        ax.set_facecolor(color_map["background"])

        _draw_overlap_edges(G, pos, color_map)
        _draw_cover_nodes(G, pos, color_map)
        _draw_legend(color_map)

        plt.figtext(
            0.5, 0.95,  # top of the figure
            _cover_summary(G, edge, left_labels, right_labels),
            ha='center',
            fontsize=9,
            family='monospace',
            color=color_map["text"],
        )

        split = getattr(edge, 'split', None)
        split_info = f" for split {split}" if split else ""
        plt.title(
            f"{title}{split_info}",
            fontsize=16,
            fontweight='bold',
            family='monospace',
            color=color_map["text"],
            pad=20,
        )

        plt.axis('off')
        plt.tight_layout(pad=3)  # extra padding to avoid tight edges

        if output_path:
            plt.savefig(output_path, dpi=300, bbox_inches='tight', facecolor=color_map["background"])
            # Close the figure so repeated calls do not accumulate open figures
            # (and a notebook does not additionally render it inline).
            plt.close(fig)
            print(f"Cover comparison saved to {output_path}")
        else:
            plt.show()

def compare_sub_lattice_edges(edges: List[LatticeEdge], max_edges: int = 10):
    """
    Print a short summary of, and plot, the left/right covers of each edge in a sub-lattice.

    For every edge, its split and the number of cover sets/partitions on each
    side are printed, then ``visualize_cover_sets`` draws the comparison.
    Only the first ``max_edges`` edges are processed to avoid flooding the
    output with figures.

    Args:
        edges: List of LatticeEdge objects to visualize
        max_edges: Maximum number of edges to visualize (default: 10)
    """
    if edges:
        found = len(edges)
        if found > max_edges:
            print(f"Warning: Many edges found ({found}). Showing first {max_edges}...")
            edges = edges[:max_edges]

        shown = len(edges)
        for position, lattice_edge in enumerate(edges, start=1):
            print(f"\n=== Edge {position}/{shown} ===")
            print(f"Split: {lattice_edge.split}")

            left_sets = lattice_edge.t1_common_covers
            right_sets = lattice_edge.t2_common_covers
            left_total = sum(map(len, left_sets))
            right_total = sum(map(len, right_sets))

            print(f"Left cover: {len(left_sets)} sets with {left_total} total partitions")
            print(f"Right cover: {len(right_sets)} sets with {right_total} total partitions")

            visualize_cover_sets(lattice_edge, title=f"Edge {position} Cover Comparison")
    else:
        print("No edges to compare.")

# Example usage: parse two example trees, build their sub-lattices and plot the
# cover comparison for each resulting edge.
if __name__ == "__main__":
    # Re-imported for a self-contained example; both are already imported at the top.
    import matplotlib.pyplot as plt
    import networkx as nx
    # NOTE(review): in the recorded notebook run this import failed with
    # "ModuleNotFoundError: No module named 'brancharchitect.newick_parser'".
    # TODO: confirm the current module path of parse_newick and update this import.
    from brancharchitect.newick_parser import parse_newick
    from brancharchitect.jumping_taxa.lattice.lattice_construction import construct_sub_lattices

    # Two trees over the same taxa (A1-A3, B1-B2, C, D1-D2, E, F1-F3, O1-O2) with
    # different topologies (e.g. E is sister to D2 in the first, to F1 in the second).
    tree_str1 = "(((((A1:1.0,A2:1.0):1.0,A3:1.0):1.0,(((B1:1.0,B2:1.0):1.0,C:1.0):1.0,(D1:1.0,(E:1.0,D2:1.0):1.0):1.0):1.0):1.0,(F1:1.0,(F2:1.0,F3:1.0):1.0):1.0):1.0,(O1:1.0,O2:1.0):1.0):1.0;"
    tree_str2 = "((((A1:1.0,((A2:1.0,(B1:1.0,B2:1.0):1.0):1.0,A3:1.0):1.0):1.0,(C:1.0,(D1:1.0,D2:1.0):1.0):1.0):1.0,((E:1.0,F1:1.0):1.0,(F2:1.0,F3:1.0):1.0):1.0):1.0,(O1:1.0,O2:1.0):1.0):1.0;"

    # Each string ends with ';', so the concatenation holds two Newick trees.
    # Presumably force_list=True makes parse_newick return a list of trees — confirm.
    t1, t2 = parse_newick(tree_str1 + tree_str2, force_list=True)

    # Assumes construct_sub_lattices returns a list of LatticeEdge objects — TODO confirm.
    sub_lattices = construct_sub_lattices(t1, t2)
    compare_sub_lattice_edges(sub_lattices)

ModuleNotFoundError: No module named 'brancharchitect.newick_parser'