In [174]:
import os
import multiprocessing
from collections import defaultdict
import datetime
import gc
import re
import sqlite3
import string
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.sparse as sp
from joblib import Parallel, delayed
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import squareform
from scipy.spatial import ConvexHull
from threadpoolctl import threadpool_limits
from tqdm import tqdm

# ----------------------------------------------------------------------
# Parallel / BLAS configuration
# ----------------------------------------------------------------------
num_logical_cores = os.cpu_count()
if num_logical_cores:
    os.environ["OMP_NUM_THREADS"] = str(num_logical_cores)
else:
    os.environ["OMP_NUM_THREADS"] = "1"

# ----------------------------------------------------------------------
# Configuration
# ----------------------------------------------------------------------
GENES_TABLE = "attributes"
NEIGHBORS_TABLE = "neighbors"

COL_NEIGHBORHOOD_ID = "organism"
COL_GENE_ID = "id"
COL_LINKING_KEY = "id"
COL_ACCESSION_ID = "accession"
COL_FUNCTION_DESC = "desc"
COL_PFAM_IDS = "family"
COL_INTERPRO_IDS = "ipro_family"
COL_REL_START = "rel_start"
COL_REL_STOP = "rel_stop"
COL_SSN_CLUSTER_ID = "cluster_num"

HIT_GENE_WEIGHT_FACTOR = 10
DIRECT_NEIGHBOR_WEIGHT_FACTOR = 3

DEFAULT_SSN_CLUSTER_VALUE_TO_FILTER = [None, 0]

SAVE_PLOTS = True
OUTPUT_DIR = "gnn_cluster_plots_circular"
REPORT_FILENAME_BASE = "gnn_clustering_report"
OUTPUT_FORMATS = ["pdf"]
DPI = 600
HIGHLIGHT_COLOR = "red"

COLLAPSE_IDENTICAL_NEIGHBORHOODS = True
COLLAPSE_CORE_SIMILARITY_THRESHOLD = 0.0
COLLAPSE_FULL_NEIGHBORHOOD_SIMILARITY_THRESHOLD = 0.3

MIN_ITEMS_FOR_PARALLEL_PROCESSING = 20

# Precompiled regex / constants
_IPR_REGEX = re.compile(r"IPR\d+", re.IGNORECASE)
_PFAM_REGEX = re.compile(r"PF\d+", re.IGNORECASE)
_UNINFORMATIVE_TERMS = frozenset(["none", "", "null", "uncharacterized protein"])


# ----------------------------------------------------------------------
# Basic feature parsing
# ----------------------------------------------------------------------
def parse_annotation_string(annotation_str, prefix=""):
    if not isinstance(annotation_str, str) or pd.isna(annotation_str):
        return set()
    if annotation_str.lower().strip() in _UNINFORMATIVE_TERMS:
        return set()

    features = set()
    parts = [p.strip() for p in re.split(r"[-;]", annotation_str) if p.strip()]

    for part in parts:
        if part.lower().strip() in _UNINFORMATIVE_TERMS:
            continue

        if _IPR_REGEX.match(part):
            features.add(f"{prefix}{part.upper()}")
        elif _PFAM_REGEX.match(part):
            features.add(f"{prefix}{part.upper()}")
        else:
            clean_part = re.sub(r"\s+", " ", part).lower().strip()
            if clean_part:
                features.add(f"{prefix}{clean_part}")
    return features


def extract_features_from_gene_row(
    gene_row,
    current_weight_factor=1,
    base_prefix="N_",
    include_desc=True,
    include_pfam=True,
    include_interpro=True,
):
    raw_features = set()

    if include_desc:
        raw_features.update(parse_annotation_string(gene_row[COL_FUNCTION_DESC]))
    if include_pfam:
        raw_features.update(parse_annotation_string(gene_row[COL_PFAM_IDS]))
    if include_interpro:
        raw_features.update(parse_annotation_string(gene_row[COL_INTERPRO_IDS]))

    if current_weight_factor <= 1:
        return {f"{base_prefix}{f}" for f in raw_features}

    features = set()
    for f in raw_features:
        for i in range(current_weight_factor):
            features.add(f"{base_prefix}{f}_w{i}")
    return features


# ----------------------------------------------------------------------
# Distance calculation (Jaccard) with optional parallelism
# ----------------------------------------------------------------------
def parallel_pdist_jaccard(feature_matrix, num_cores=-1):
    """
    Compute condensed Jaccard distance for a binary sparse matrix (CSR).
    """
    if not isinstance(feature_matrix, sp.csr_matrix):
        if isinstance(feature_matrix, (sp.csc_matrix, sp.lil_matrix, sp.coo_matrix)):
            feature_matrix = feature_matrix.tocsr()
        else:
            raise TypeError("feature_matrix must be a SciPy sparse matrix")

    n_samples = feature_matrix.shape[0]
    if n_samples <= 1:
        return np.array([])

    if num_cores == -1:
        detected = os.cpu_count()
        num_cores = detected if detected and detected > 0 else 1
    elif num_cores == 0:
        num_cores = 1

    # Precompute row index sets once
    print(f"  Pre-computing feature sets for N={n_samples} ...")
    t0 = time.time()
    feature_sets = [
        set(feature_matrix.indices[feature_matrix.indptr[i] : feature_matrix.indptr[i + 1]])
        for i in tqdm(range(n_samples), desc="  Feature sets", leave=False)
    ]
    print(f"  Done in {time.time() - t0:.2f}s")
    del feature_matrix
    gc.collect()

    def _dist_chunk(start_i, end_i, sets_ref, n_total):
        with threadpool_limits(limits=1, user_api="blas"):
            chunk = []
            for i in range(start_i, end_i):
                si = sets_ref[i]
                for j in range(i + 1, n_total):
                    sj = sets_ref[j]
                    inter = len(si & sj)
                    union = len(si | sj)
                    d = 0.0 if union == 0 else 1.0 - inter / union
                    chunk.append(d)
            return chunk

    total_i = n_samples - 1
    if total_i <= 0:
        return np.array([])

    if n_samples < MIN_ITEMS_FOR_PARALLEL_PROCESSING or num_cores == 1:
        print(f"  Using sequential Jaccard (N={n_samples})")
        res = [_dist_chunk(0, total_i, feature_sets, n_samples)]
    else:
        print(f"  Using parallel Jaccard (N={n_samples}, cores={num_cores})")
        num_tasks = min(total_i, num_cores * 4)
        chunk_size = max(1, (total_i + num_tasks - 1) // num_tasks)
        ranges = [(k, min(k + chunk_size, total_i)) for k in range(0, total_i, chunk_size)]

        tasks = [
            delayed(_dist_chunk)(start, end, feature_sets, n_samples) for start, end in ranges
        ]
        res = Parallel(n_jobs=num_cores, backend="loky", verbose=0)(tasks)

    return np.concatenate(res)


# ----------------------------------------------------------------------
# Collapsing similar neighborhoods
# ----------------------------------------------------------------------
def _perform_collapsing(
    all_neighborhood_features,
    full_neighborhood_labels_map,
    core_neighborhood_features,
    collapse_core_similarity_threshold,
    collapse_full_neighborhood_similarity_threshold,
    output_prefix="",
    report_file=None,
    parallelize_pdist=False,
):

    def log(msg):
        print(msg)
        if report_file:
            report_file.write(msg + "\n")

    log(
        f"{output_prefix}  Collapsing neighborhoods (core thr: {collapse_core_similarity_threshold}, "
        f"full thr: {collapse_full_neighborhood_similarity_threshold})"
    )
    t0_all = time.time()
    collapsed_groups_report = {}

    # ---- Stage 1: core-based grouping ----
    labels = sorted(core_neighborhood_features.keys())
    if len(labels) < 2:
        log(f"{output_prefix}  <2 neighborhoods; skipping collapsing.")
        return all_neighborhood_features, full_neighborhood_labels_map, collapsed_groups_report

    core_vocab = sorted(set.union(*core_neighborhood_features.values()))
    if not core_vocab:
        log(f"{output_prefix}  No core features; skipping collapsing.")
        return all_neighborhood_features, full_neighborhood_labels_map, collapsed_groups_report

    log(f"{output_prefix}  Stage1: building core sparse matrix")
    feat_to_idx = {f: i for i, f in enumerate(core_vocab)}
    n_nh = len(labels)
    n_feat = len(core_vocab)
    mat_lil = sp.lil_matrix((n_nh, n_feat), dtype=np.int8)
    for i, nh in enumerate(tqdm(labels, desc=f"{output_prefix}  Core features", leave=False)):
        for f in core_neighborhood_features[nh]:
            j = feat_to_idx[f]
            mat_lil[i, j] = 1
    core_mat = mat_lil.tocsr()
    del mat_lil
    gc.collect()

    if core_mat.shape[0] < 2:
        log(f"{output_prefix}  <2 core vectors; skipping collapsing.")
        return all_neighborhood_features, full_neighborhood_labels_map, collapsed_groups_report

    # Check if all identical
    if n_nh > 1 and all((core_mat[0] != core_mat[i]).nnz == 0 for i in range(1, n_nh)):
        # everything identical -> one group
        pre_clusters = {labels[i]: 1 for i in range(len(labels))}
        log(f"{output_prefix}  All core vectors identical; treating as one initial group.")
    else:
        log(f"{output_prefix}  Stage1: computing core Jaccard distances")
        t0 = time.time()
        dists = parallel_pdist_jaccard(core_mat, num_cores=-1 if parallelize_pdist else 1)
        log(f"{output_prefix}  Stage1: distance calc in {time.time() - t0:.2f}s")
        from scipy.cluster.hierarchy import linkage, fcluster

        link_core = linkage(dists, method="average")
        pre_clusters_array = fcluster(link_core, collapse_core_similarity_threshold, criterion="distance")
        pre_clusters = {labels[i]: pre_clusters_array[i] for i in range(len(labels))}

    initial_core_groups = defaultdict(list)
    for nh, cid in pre_clusters.items():
        initial_core_groups[cid].append(nh)

    log(
        f"{output_prefix}  Stage1: {len(initial_core_groups)} core groups "
        f"({time.time() - t0_all:.2f}s partial)"
    )

    del core_mat
    if "dists" in locals():
        del dists
    gc.collect()

    # ---- Stage 2: full neighborhood collapsing within each core group ----
    def gen_letter(idx):
        if idx < 26:
            return string.ascii_uppercase[idx]
        first = (idx // 26) - 1
        second = idx % 26
        return f"{string.ascii_uppercase[first]}{string.ascii_uppercase[second]}"

    log(f"{output_prefix}  Stage2: full-neighborhood collapsing")

    def process_core_group_chunk(
        group_ids,
        all_feat_ref,
        labels_map_ref,
        thr_full,
        allow_parallel,
    ):
        out = []
        collapsed_count = 0
        cores_inner = -1 if allow_parallel else 1

        with threadpool_limits(limits=1, user_api="blas"):
            for gid in group_ids:
                members = initial_core_groups[gid]
                if len(members) < 2:
                    m = members[0]
                    out.append((m, all_feat_ref[m], labels_map_ref[m], None))
                    continue

                vocab = sorted(set.union(*[all_feat_ref[m] for m in members]))
                if not vocab:
                    assignments = {m: 1 for m in members}
                else:
                    ft_idx = {f: i for i, f in enumerate(vocab)}
                    n_m = len(members)
                    m_lil = sp.lil_matrix((n_m, len(vocab)), dtype=np.int8)
                    for i, nh in enumerate(members):
                        for f in all_feat_ref[nh]:
                            j = ft_idx[f]
                            m_lil[i, j] = 1
                    mat = m_lil.tocsr()

                    if n_m > 1 and all((mat[0] != mat[i]).nnz == 0 for i in range(1, n_m)):
                        assignments = {m: 1 for m in members}
                    else:
                        d = parallel_pdist_jaccard(mat, num_cores=cores_inner)
                        if d.size == 0:
                            assignments = {m: 1 for m in members}
                        else:
                            link = linkage(d, method="average")
                            arr = fcluster(link, thr_full, criterion="distance")
                            assignments = {members[i]: arr[i] for i in range(len(members))}
                    del mat
                    if "d" in locals():
                        del d
                    if "link" in locals():
                        del link
                    gc.collect()

                groups = defaultdict(list)
                for nh, cid in assignments.items():
                    groups[cid].append(nh)

                for cid, collapsed_members in sorted(groups.items()):
                    if len(collapsed_members) > 1:
                        collapsed_count += len(collapsed_members) - 1
                        rep = collapsed_members[0]
                        union_features = set()
                        for nh in collapsed_members:
                            union_features.update(all_feat_ref[nh])
                        out.append((rep, union_features, labels_map_ref[rep], collapsed_members))
                    else:
                        m = collapsed_members[0]
                        out.append((m, all_feat_ref[m], labels_map_ref[m], None))

        return out, collapsed_count

    all_ids = sorted(initial_core_groups.keys())
    n_core_groups = len(all_ids)
    if n_core_groups < MIN_ITEMS_FOR_PARALLEL_PROCESSING or not parallelize_pdist:
        chunk_results = [
            process_core_group_chunk(
                all_ids,
                all_neighborhood_features,
                full_neighborhood_labels_map,
                collapse_full_neighborhood_similarity_threshold,
                False,
            )
        ]
    else:
        n_cores = os.cpu_count() or 1
        chunk_size = max(1, n_core_groups // n_cores)
        chunks = [all_ids[i : i + chunk_size] for i in range(0, n_core_groups, chunk_size)]
        chunk_results = Parallel(n_jobs=n_cores, backend="loky", verbose=0)(
            delayed(process_core_group_chunk)(
                ch,
                all_neighborhood_features,
                full_neighborhood_labels_map,
                collapse_full_neighborhood_similarity_threshold,
                parallelize_pdist,
            )
            for ch in chunks
        )

    final_neighborhood_features = {}
    final_neighborhood_labels_map = {}
    collapsed_total = 0
    letter_counter = 0

    for res, local_count in chunk_results:
        collapsed_total += local_count
        for rep_label, feats, label_entry, collapsed_members in res:
            if collapsed_members is not None:
                code = gen_letter(letter_counter)
                letter_counter += 1
                org, hit_id, ssn_id, acc, _ = label_entry
                final_neighborhood_labels_map[rep_label] = (
                    org,
                    hit_id,
                    ssn_id,
                    acc,
                    (len(collapsed_members), code),
                )
                collapsed_groups_report[code] = {
                    "representative": rep_label,
                    "members": sorted(collapsed_members),
                    "count": len(collapsed_members),
                }
                final_neighborhood_features[rep_label] = feats
            else:
                final_neighborhood_features[rep_label] = feats
                final_neighborhood_labels_map[rep_label] = label_entry

    if collapsed_total > 0:
        log(
            f"{output_prefix}  Collapsed {collapsed_total} neighborhoods -> "
            f"{len(final_neighborhood_features)} unique entries "
            f"({time.time() - t0_all:.2f}s total)."
        )
    else:
        log(f"{output_prefix}  No collapsing performed ({time.time() - t0_all:.2f}s).")

    return final_neighborhood_features, final_neighborhood_labels_map, collapsed_groups_report


# ----------------------------------------------------------------------
# Circular phylogram plotting
# ----------------------------------------------------------------------
def _hierarchy_to_children(linkage_matrix, labels):
    """
    Build adjacency (children) list from scipy linkage matrix.
    Leaf node ids: 0..(n-1)
    Internal node ids: n..(n+linkage_matrix.shape[0]-1)
    """
    n_leaves = len(labels)
    children = {}
    for i, (c1, c2, dist, _) in enumerate(linkage_matrix):
        node_id = n_leaves + i
        children[node_id] = (int(c1), int(c2))
    node_ids = set(range(n_leaves)) | set(children.keys())
    return children, node_ids


def _gather_descendant_leaves(node_id, children, n_leaves, cache=None):
    """
    Return list of leaf node indices descending from node_id.
    """
    if cache is None:
        cache = {}
    if node_id in cache:
        return cache[node_id]

    if node_id < n_leaves:
        cache[node_id] = [node_id]
        return cache[node_id]

    c1, c2 = children[node_id]
    leaves = _gather_descendant_leaves(c1, children, n_leaves, cache) + \
             _gather_descendant_leaves(c2, children, n_leaves, cache)
    cache[node_id] = leaves
    return leaves


def _collect_cluster_colors(cluster_assignments, labels, cmap_name="tab20"):
    clusters = sorted(set(cluster_assignments))
    cmap = plt.get_cmap(cmap_name, len(clusters))
    cluster_to_color = {cid: cmap(i) for i, cid in enumerate(clusters)}
    label_to_color = {}
    for label_idx, cid in enumerate(cluster_assignments):
        label_to_color[labels[label_idx]] = cluster_to_color[cid]
    return cluster_to_color, label_to_color


def _compute_dual_scale_radii(
    feature_matrix, 
    linkage_matrix, 
    labels, 
    inner_boundary=0.6, 
    consensus_threshold=0.25,
    leaf_stretch_power=1.5
):
    n_leaves = len(labels)
    n_internal = linkage_matrix.shape[0]
    
    # 1. Consensus Ancestor Logic
    col_counts = np.array((feature_matrix > 0).sum(axis=0)).flatten()
    threshold = n_leaves * consensus_threshold
    mca_vec = (col_counts >= threshold).astype(np.int8)
    
    leaf_div = np.zeros(n_leaves)
    for i in range(n_leaves):
        leaf_vec = feature_matrix[i].toarray().flatten().astype(np.int8)
        intersection = np.sum(np.logical_and(leaf_vec, mca_vec))
        union = np.sum(np.logical_or(leaf_vec, mca_vec))
        leaf_div[i] = 1.0 - (intersection / union) if union != 0 else 1.0

    max_div = np.max(leaf_div) if np.max(leaf_div) > 0 else 1.0
    norm_leaf_div = leaf_div / max_div

    # 2. Tree Scale (Inner branches remain linear relative to height)
    max_h = linkage_matrix[-1, 2] if n_internal > 0 else 1.0
    final_radii = {}
    for i in range(n_internal):
        node_id = n_leaves + i
        h = linkage_matrix[i, 2]
        progress = 1.0 - (h / max_h) 
        final_radii[node_id] = progress * inner_boundary

    # 3. Leaf Scale (Non-linear Stretch)
    children, _ = _hierarchy_to_children(linkage_matrix, labels)
    parent_map = {c1: p_id for p_id, (c1, c2) in children.items()}
    parent_map.update({c2: p_id for p_id, (c1, c2) in children.items()})

    for i in range(n_leaves):
        parent_id = parent_map.get(i)
        parent_r = final_radii[parent_id] if parent_id is not None else 0
        
        # Apply the Non-Linear Power Scale here
        # This only affects the distance from inner_boundary to 1.0
        stretched_div = norm_leaf_div[i] ** leaf_stretch_power
        
        proposed_r = inner_boundary + (stretched_div * (1.0 - inner_boundary))
        final_radii[i] = max(proposed_r, parent_r + 0.01)

    return final_radii


def _compute_tree_metrics(linkage_matrix, labels):
    """
    Computes root-to-node distances and subtree weights.
    Returns:
      node_root_dist: dict of {node_id: distance_from_root}
      subtree_weight: dict of {node_id: sum_of_internal_distances} (used for angular spacing)
    """
    n_leaves = len(labels)
    n_total = n_leaves + linkage_matrix.shape[0]
    children, node_ids = _hierarchy_to_children(linkage_matrix, labels)
    
    # Height of nodes from linkage (distance at which clusters merge)
    # heights[leaf] = 0; heights[root] = max_dist
    heights = {i: 0.0 for i in range(n_leaves)}
    for i, (_, _, dist, _) in enumerate(linkage_matrix):
        heights[n_leaves + i] = float(dist)

    root_id = n_total - 1
    node_root_dist = {root_id: 0.0}
    subtree_weight = {i: 1.0 for i in range(n_leaves)} # Base weight for leaves

    # Breadth-first or Depth-first to get root-to-node distances
    # Distance = parent_height - child_height is the standard branch length
    stack = [root_id]
    while stack:
        curr = stack.pop()
        if curr in children:
            c1, c2 = children[curr]
            # Branch length is the difference in linkage 'age'
            branch_l1 = heights[curr] - heights[c1]
            branch_l2 = heights[curr] - heights[c2]
            
            node_root_dist[c1] = node_root_dist[curr] + branch_l1
            node_root_dist[c2] = node_root_dist[curr] + branch_l2
            
            stack.extend([c1, c2])

    # Compute subtree weights (sum of distances) to drive angular distribution
    # We do this bottom-up
    for i in range(linkage_matrix.shape[0]):
        node_id = n_leaves + i
        c1, c2 = children[node_id]
        # Weight is proportional to the distance of the split + children weights
        subtree_weight[node_id] = subtree_weight[c1] + subtree_weight[c2] + linkage_matrix[i, 2]

    return node_root_dist, subtree_weight, heights


def _compute_cluster_aware_angles(linkage_matrix, labels, subtree_weight, gap_factor=0.05):
    """
    Assigns angles with explicit gaps between branches to separate clusters.
    gap_factor: percentage of the available arc to leave empty at each split.
    """
    n_leaves = len(labels)
    children, _ = _hierarchy_to_children(linkage_matrix, labels)
    root_id = n_leaves + linkage_matrix.shape[0] - 1
    angles = {}

    def assign_angle(node_id, theta_start, theta_end):
        if node_id < n_leaves:
            angles[node_id] = (theta_start + theta_end) / 2.0
            return

        c1, c2 = children[node_id]
        w1, w2 = subtree_weight[c1], subtree_weight[c2]
        
        # Calculate available arc after removing a gap
        total_arc = theta_end - theta_start
        current_gap = total_arc * gap_factor
        usable_arc = total_arc - current_gap
        
        # Split usable arc proportional to weights
        mid_point = theta_start + usable_arc * (w1 / (w1 + w2))
        
        # Branch 1 gets first part, Branch 2 starts after the gap
        assign_angle(c1, theta_start, mid_point)
        assign_angle(c2, mid_point + current_gap, theta_end)

    assign_angle(root_id, 0.0, 2 * np.pi * 0.95) # Leave a final gap at the end
    return angles


def _draw_circular_tree_with_clusters(
    linkage_matrix,
    leaf_labels,
    cluster_assignments,
    label_to_cluster_color,
    original_input_sequence_id=None,
    labels_map=None,
    title=None,
    out_prefix=None,
    output_dir=".",
    output_formats=("pdf",),
    dpi=600,
    show=False,
    feature_matrix=None,
    inner_boundary=0.5,    # Where tree ends
    consensus_threshold=0.25,     # If Features are in 25% of NHs they are "Ancestral"
    leaf_stretch_power=1.5, # Exaggerate leaf distances non-linearly
    gap_factor=0.08
):
    n_leaves = len(leaf_labels)
    
    # 1. Radii & Angles
    radii = _compute_dual_scale_radii(
        feature_matrix, linkage_matrix, leaf_labels, 
        inner_boundary=inner_boundary, consensus_threshold=consensus_threshold, leaf_stretch_power=leaf_stretch_power
    )
    _, subtree_weight, _ = _compute_tree_metrics(linkage_matrix, leaf_labels)
    angles = _compute_cluster_aware_angles(linkage_matrix, leaf_labels, subtree_weight, gap_factor)
    children, _ = _hierarchy_to_children(linkage_matrix, leaf_labels)
    
    # Coordinate mapping (Angles for internal nodes)
    coords = {}
    def get_coords(node_id):
        if node_id in coords: return coords[node_id]
        r = radii[node_id]
        if node_id < n_leaves:
            theta = angles[node_id]
        else:
            c1, c2 = children[node_id]
            _, th1 = get_coords(c1); _, th2 = get_coords(c2)
            theta = (th1 * subtree_weight[c1] + th2 * subtree_weight[c2]) / (subtree_weight[c1] + subtree_weight[c2])
        coords[node_id] = (r, theta)
        return coords[node_id]

    for nid in range(n_leaves + linkage_matrix.shape[0]): get_coords(nid)

    # Publication-quality figure setup
    fig, ax = plt.subplots(subplot_kw={"projection": "polar"}, figsize=(20, 20))
    ax.set_theta_direction(-1)
    ax.set_theta_offset(np.pi / 2.0)
    ax.set_axis_off()
    
    # Set publication-ready style parameters
    plt.rcParams['font.family'] = 'sans-serif'
    plt.rcParams['font.sans-serif'] = ['Arial', 'Helvetica', 'DejaVu Sans']
    plt.rcParams['pdf.fonttype'] = 42  # TrueType fonts for publication
    plt.rcParams['ps.fonttype'] = 42

    # 2. Draw Branches with publication-quality thickness
    for node_id, (c1, c2) in children.items():
        r_p, th_p = coords[node_id]
        for child in (c1, c2):
            r_c, th_c = coords[child]
            
            # Default color with better contrast
            color = label_to_cluster_color[leaf_labels[child]] if child < n_leaves else (0.3, 0.3, 0.3, 0.6)
            linewidth = 2.0  # Thicker for publication
            zorder = 2

            # HIGHLIGHT LOGIC: If this branch leads directly to our target
            if child < n_leaves and labels_map:
                # labels_map[leaf_label] = (org, hit_id, ssn, accession, collapsed)
                acc = labels_map[leaf_labels[child]][3]
                if original_input_sequence_id and acc == original_input_sequence_id:
                    color = "#DC143C"  # Crimson for better publication contrast
                    linewidth = 3.0  # Much thicker for highlight
                    zorder = 10
            
            # Radial branch with smooth rendering
            ax.plot([th_c, th_c], [r_p, r_c], color=color, lw=linewidth, 
                    zorder=zorder, solid_capstyle='round', solid_joinstyle='round')
            
            # Connecting arc
            t_vals = np.linspace(min(th_p, th_c), max(th_p, th_c), 50)
            arc_color = color if child < n_leaves and zorder == 10 else (0.3, 0.3, 0.3, 0.6)
            ax.plot(t_vals, np.full_like(t_vals, r_p), color=arc_color, 
                    lw=linewidth * 0.75, zorder=zorder-1, solid_capstyle='round')

    # 3. Draw Leaves & Labels with improved positioning using zigzag stacking
    base_label_offset = 0.04  # Base offset for label separation from markers
    min_label_radius = 0.7  # Minimum radius for labels to ensure they're on the outside
    max_label_offset = 0.15  # Maximum radial offset to prevent labels from going too far
    zigzag_spacing = 0.04  # Spacing between zigzag levels
    
    # Pre-calculate all leaf positions and group by angular proximity
    leaf_data = []
    for i in range(n_leaves):
        r, th = coords[i]
        leaf_data.append({'idx': i, 'r': r, 'th': th, 'th_deg': np.degrees(th) % 360})
    
    # Sort by angle for grouping
    leaf_data_sorted = sorted(leaf_data, key=lambda x: x['th'])
    
    # Group leaves that are close in angle (within ~12 degrees)
    angular_group_threshold = np.radians(12)
    groups = []
    current_group = [leaf_data_sorted[0]]
    
    for i in range(1, len(leaf_data_sorted)):
        prev_leaf = leaf_data_sorted[i-1]
        curr_leaf = leaf_data_sorted[i]
        
        # Calculate angular distance
        ang_diff = abs(curr_leaf['th'] - prev_leaf['th'])
        if ang_diff > np.pi:
            ang_diff = 2 * np.pi - ang_diff
        
        if ang_diff < angular_group_threshold:
            current_group.append(curr_leaf)
        else:
            groups.append(current_group)
            current_group = [curr_leaf]
    
    groups.append(current_group)
    
    # Calculate label positions with zigzag stacking for each group
    label_positions = {}
    
    for group in groups:
        if len(group) == 1:
            # Single leaf - simple positioning
            leaf = group[0]
            offset = base_label_offset
            label_r = max(leaf['r'] + offset, min_label_radius)
            label_positions[leaf['idx']] = {'r': label_r, 'th': leaf['th']}
        else:
            # Multiple leaves - zigzag stacking
            # Sort by radial distance within group
            group_sorted = sorted(group, key=lambda x: x['r'])
            
            for j, leaf in enumerate(group_sorted):
                # Alternate between closer and farther positions (zigzag)
                zigzag_level = j % 3  # Use 3 levels: 0, 1, 2
                offset = base_label_offset + (zigzag_level * zigzag_spacing)
                offset = min(offset, max_label_offset)  # Cap the offset
                
                label_r = max(leaf['r'] + offset, min_label_radius)
                # Keep angle unchanged - no angular adjustment
                label_positions[leaf['idx']] = {'r': label_r, 'th': leaf['th']}
    
    # Now draw leaves and labels using calculated positions
    for i in range(n_leaves):
        r, th = coords[i]
        internal_label = leaf_labels[i]  # Internal key (organism_hit_id)
        
        # Extract readable label components from labels_map
        if labels_map and internal_label in labels_map:
            organism, hit_id, ssn_id, accession_id, collapsed_info = labels_map[internal_label]
            # Create display label: Organism + Accession ID
            label_text = f"{organism}_{accession_id}" if accession_id else internal_label
        else:
            label_text = internal_label
        
        color = label_to_cluster_color[internal_label]
        size = 50  # Larger markers
        edge_w = 1.0  # Thicker edges
        
        # Check for highlight
        is_highlight = False
        if labels_map and original_input_sequence_id:
            acc = labels_map[internal_label][3]
            if acc == original_input_sequence_id:
                is_highlight = True
                color = "#DC143C"  # Crimson
                size = 180  # Much larger for prominence
                edge_w = 2.5  # Thicker edge

        # Draw leaf node with publication-quality styling
        ax.scatter([th], [r], color=color, s=size, edgecolors='black', 
                   linewidths=edge_w, zorder=20, alpha=0.9)
        
        # Get pre-calculated label position
        label_r = label_positions[i]['r']
        label_th = label_positions[i]['th']
        th_deg = np.degrees(label_th) % 360
        
        # Determine text alignment based on position
        # All labels are positioned OUTSIDE (radially outward from) the circle
        # Right side (315-45°): left-aligned, labels go to the right
        # Top (45-135°): center-aligned, labels go upward  
        # Left side (135-225°): right-aligned, labels go to the left
        # Bottom (225-315°): center-aligned, labels go downward
        
        if 315 <= th_deg or th_deg < 45:
            ha = 'left'
            va = 'center'
        elif 45 <= th_deg < 135:
            ha = 'center'
            va = 'bottom'
        elif 135 <= th_deg < 225:
            ha = 'right'
            va = 'center'
        else:
            ha = 'center'
            va = 'top'
        
        # Always keep text horizontal for maximum readability
        ax.text(label_th, label_r, label_text, 
                fontsize=15 if not is_highlight else 16,  # Larger fonts for publication
                rotation=0,  # Keep horizontal
                ha=ha, 
                va=va, 
                fontweight='bold' if is_highlight else 'normal',
                color="#DC143C" if is_highlight else "black",
                bbox=dict(boxstyle='round,pad=0.4', 
                         facecolor='white' if is_highlight else 'white',
                         edgecolor='#DC143C' if is_highlight else 'none',
                         alpha=0.85 if is_highlight else 0.75,
                         linewidth=1.5 if is_highlight else 0),
                zorder=25)

    # Adjust plot limits to accommodate labels pushed further outward
    ax.set_ylim(0, 1.2)
    
    plt.tight_layout()

    if out_prefix and SAVE_PLOTS:
        os.makedirs(output_dir, exist_ok=True)
        for fmt in output_formats:
            fig.savefig(
                os.path.join(output_dir, f"{out_prefix}_circular_tree.{fmt}"),
                dpi=dpi,
                bbox_inches="tight",
                facecolor='white',
                edgecolor='none'
            )
        plt.close(fig)
    elif show:
        plt.show()
    else:
        plt.close(fig)
                       
                

# ----------------------------------------------------------------------
# Dynamic clustering methods
# ----------------------------------------------------------------------
def _find_optimal_clusters_by_gap(linkage_matrix, min_clusters=2, min_gap_ratio=1.5):
    """
    Find optimal number of clusters by detecting large gaps in merge heights.
    Looks for the largest relative jump in the dendrogram.
    
    Parameters:
    - linkage_matrix: scipy linkage matrix
    - min_clusters: minimum number of clusters to consider
    - min_gap_ratio: minimum ratio between consecutive heights to consider it a significant gap
    
    Returns:
    - optimal_distance: distance threshold for cutting the dendrogram
    - num_clusters: estimated number of clusters
    """
    if linkage_matrix.shape[0] < 2:
        return 0.5, 1
    
    # Extract merge heights (distances at which clusters merge)
    heights = linkage_matrix[:, 2]
    
    # Calculate gaps between consecutive merges
    gaps = np.diff(heights)
    gap_ratios = gaps[1:] / (gaps[:-1] + 1e-10)  # Avoid division by zero
    
    # Find the largest gap that produces at least min_clusters
    # Start from the end (highest merges = fewest clusters)
    max_gap_ratio = -np.inf
    best_idx = len(heights) - 1
    
    for i in range(len(heights) - min_clusters, 0, -1):
        if i < len(gap_ratios):
            if gap_ratios[i] > max_gap_ratio and gap_ratios[i] >= min_gap_ratio:
                max_gap_ratio = gap_ratios[i]
                best_idx = i
    
    # Cut at the height just before the large gap
    optimal_distance = (heights[best_idx] + heights[best_idx + 1]) / 2.0 if best_idx < len(heights) - 1 else heights[best_idx] * 1.1
    num_clusters = linkage_matrix.shape[0] - best_idx + 1
    
    return optimal_distance, num_clusters


def _find_optimal_clusters_by_inconsistency(linkage_matrix, depth=2, threshold=1.5):
    """
    Use inconsistency coefficient to find natural clusters.
    Inconsistency measures how different a merge is compared to nearby merges.
    
    Parameters:
    - linkage_matrix: scipy linkage matrix
    - depth: how many levels to look back for comparison
    - threshold: inconsistency threshold for cutting
    
    Returns:
    - cluster_assignments: array of cluster labels
    """
    from scipy.cluster.hierarchy import inconsistent, fcluster
    
    if linkage_matrix.shape[0] < 2:
        return np.ones(linkage_matrix.shape[0] + 1, dtype=int)
    
    # Calculate inconsistency coefficients
    incons = inconsistent(linkage_matrix, d=depth)
    
    # Find a good threshold based on inconsistency statistics
    # Use the mean + threshold * std of inconsistency values
    mean_incons = np.mean(incons[:, 3])  # Column 3 is the inconsistency coefficient
    std_incons = np.std(incons[:, 3])
    
    cutoff = mean_incons + threshold * std_incons
    
    # Use the inconsistency criterion for clustering
    cluster_assignments = fcluster(linkage_matrix, cutoff, criterion='inconsistent', depth=depth)
    
    return cluster_assignments


def _find_optimal_clusters_by_elbow(linkage_matrix, max_clusters=None):
    """
    Use the elbow method on within-cluster distances.
    Finds the point where adding more clusters doesn't significantly improve separation.
    
    Parameters:
    - linkage_matrix: scipy linkage matrix
    - max_clusters: maximum number of clusters to test (default: sqrt(n))
    
    Returns:
    - optimal_distance: distance threshold for cutting
    - optimal_k: optimal number of clusters
    """
    from scipy.cluster.hierarchy import fcluster
    
    n = linkage_matrix.shape[0] + 1  # Number of leaves
    
    if n < 3:
        return linkage_matrix[-1, 2] / 2.0 if linkage_matrix.shape[0] > 0 else 0.5, 1
    
    if max_clusters is None:
        max_clusters = min(int(np.sqrt(n)), n - 1)
    
    max_clusters = max(2, min(max_clusters, n - 1))
    
    # Test different numbers of clusters
    wcss_values = []  # Within-cluster sum of squares (approximated)
    k_range = range(1, max_clusters + 1)
    
    heights = linkage_matrix[:, 2]
    
    for k in k_range:
        # Height at which we'd get k clusters
        if k == 1:
            wcss = heights[-1] if len(heights) > 0 else 0
        else:
            # The height where we cut to get k clusters
            idx = -(k - 1) if k <= len(heights) else 0
            wcss = heights[idx] if idx < len(heights) else 0
        
        wcss_values.append(wcss)
    
    wcss_values = np.array(wcss_values)
    
    # Find elbow using the method of maximum distance to the line
    if len(wcss_values) < 3:
        return heights[-1] / 2.0, 2
    
    # Normalize
    x = np.arange(len(wcss_values))
    y = wcss_values
    
    # Line from first to last point
    p1 = np.array([x[0], y[0]])
    p2 = np.array([x[-1], y[-1]])
    
    # Calculate distances from each point to the line
    distances = np.zeros(len(x))
    for i in range(len(x)):
        p = np.array([x[i], y[i]])
        distances[i] = np.abs(np.cross(p2 - p1, p1 - p)) / np.linalg.norm(p2 - p1)
    
    # Find the elbow (maximum distance)
    optimal_idx = np.argmax(distances)
    optimal_k = k_range[optimal_idx]
    
    # Find the distance threshold for optimal_k clusters
    if optimal_k == 1:
        optimal_distance = heights[-1] * 1.1 if len(heights) > 0 else 1.0
    else:
        cut_idx = -(optimal_k - 1)
        if cut_idx >= -len(heights):
            optimal_distance = (heights[cut_idx] + heights[cut_idx - 1]) / 2.0 if cut_idx > -len(heights) else heights[cut_idx]
        else:
            optimal_distance = heights[0] / 2.0
    
    return optimal_distance, optimal_k


def _find_optimal_clusters_by_lifetime(linkage_matrix, percentile=90):
    """
    Use cluster lifetime (persistence) to find stable clusters.
    Clusters that exist for a long 'time' (distance range) are considered robust.
    
    Parameters:
    - linkage_matrix: scipy linkage matrix
    - percentile: percentile of lifetimes to use as threshold
    
    Returns:
    - optimal_distance: distance threshold
    - num_clusters: estimated number of clusters
    """
    if linkage_matrix.shape[0] < 2:
        return 0.5, 1
    
    heights = linkage_matrix[:, 2]
    n = linkage_matrix.shape[0] + 1
    
    # Calculate lifetimes: for each merge, how long did the clusters exist?
    # Lifetime = height at which cluster dies - height at which it was born
    lifetimes = []
    birth_times = np.zeros(2 * n - 1)  # Birth time for each node
    
    for i in range(len(heights)):
        c1, c2 = int(linkage_matrix[i, 0]), int(linkage_matrix[i, 1])
        merge_height = heights[i]
        new_cluster = n + i
        
        # Lifetime of clusters being merged
        life1 = merge_height - birth_times[c1]
        life2 = merge_height - birth_times[c2]
        
        lifetimes.append(life1)
        lifetimes.append(life2)
        
        # New cluster born at this height
        birth_times[new_cluster] = merge_height
    
    lifetimes = np.array(lifetimes)
    
    # Use percentile of lifetimes as threshold
    significant_lifetime = np.percentile(lifetimes, percentile)
    
    # Find merges where both clusters have long lifetimes
    cut_height = None
    for i in range(len(heights)):
        c1, c2 = int(linkage_matrix[i, 0]), int(linkage_matrix[i, 1])
        merge_height = heights[i]
        
        life1 = merge_height - birth_times[c1]
        life2 = merge_height - birth_times[c2]
        
        if life1 >= significant_lifetime or life2 >= significant_lifetime:
            cut_height = merge_height
            break
    
    if cut_height is None:
        cut_height = np.median(heights)
    
    # Estimate number of clusters at this cut
    num_clusters = np.sum(heights > cut_height) + 1
    
    return cut_height, num_clusters


def _find_optimal_clusters_by_topology(linkage_matrix, max_clusters=None, coherence_threshold=0.7):
    """
    Find clusters based on tree topology to ensure visual coherence.
    Groups leaves that share recent common ancestors and form complete subtrees.
    
    Parameters:
    - linkage_matrix: scipy linkage matrix
    - max_clusters: maximum number of clusters (default: sqrt(n))
    - coherence_threshold: how coherent subtrees should be (0-1, higher = more coherent)
    
    Returns:
    - optimal_distance: distance threshold
    - num_clusters: number of clusters
    """
    from scipy.cluster.hierarchy import fcluster
    
    n = linkage_matrix.shape[0] + 1  # Number of leaves
    
    if n < 3:
        return linkage_matrix[-1, 2] / 2.0 if linkage_matrix.shape[0] > 0 else 0.5, 1
    
    if max_clusters is None:
        max_clusters = max(2, min(int(np.sqrt(n) * 1.5), n - 1))
    
    heights = linkage_matrix[:, 2]
    
    # Build tree structure
    children_dict = {}
    for i in range(linkage_matrix.shape[0]):
        node_id = n + i
        children_dict[node_id] = (int(linkage_matrix[i, 0]), int(linkage_matrix[i, 1]))
    
    def get_descendants(node_id):
        """Get all leaf descendants of a node"""
        if node_id < n:
            return [node_id]
        c1, c2 = children_dict[node_id]
        return get_descendants(c1) + get_descendants(c2)
    
    # Evaluate different cut heights for coherence
    best_score = -np.inf
    best_height = heights[-1] / 2.0
    best_k = 2
    
    # Test different numbers of clusters
    for k in range(2, max_clusters + 1):
        # Cut tree to get k clusters
        if k > len(heights):
            continue
            
        cut_idx = -(k - 1)
        if cut_idx < -len(heights):
            continue
            
        cut_height = (heights[cut_idx] + heights[cut_idx - 1]) / 2.0 if cut_idx > -len(heights) else heights[cut_idx]
        
        # Get cluster assignments
        assignments = fcluster(linkage_matrix, cut_height, criterion='distance')
        
        # Calculate coherence score: how well do clusters correspond to subtrees?
        # Look at the tree structure at this cut height
        coherence_scores = []
        
        for cluster_id in np.unique(assignments):
            cluster_leaves = np.where(assignments == cluster_id)[0]
            if len(cluster_leaves) < 2:
                coherence_scores.append(1.0)  # Single leaf is perfectly coherent
                continue
            
            # Find the lowest common ancestor (LCA) height for leaves in this cluster
            # A good cluster should have leaves that diverge at similar heights
            cluster_heights = []
            for i, leaf1 in enumerate(cluster_leaves):
                for leaf2 in cluster_leaves[i+1:]:
                    # Find LCA height in linkage matrix
                    # This is approximate - look for merges involving these leaves
                    cluster_heights.append(0.0)  # Simplified for now
            
            if cluster_heights:
                # Coherence is high if all leaves in cluster are close in tree
                coherence = 1.0 - np.std(cluster_heights) / (np.mean(heights) + 1e-10)
            else:
                coherence = 1.0
            
            coherence_scores.append(coherence)
        
        # Score combines number of clusters with coherence
        avg_coherence = np.mean(coherence_scores)
        
        # Penalize too few or too many clusters, reward high coherence
        k_penalty = abs(k - np.sqrt(n)) / n
        score = avg_coherence * (1.0 - k_penalty * 0.5)
        
        if score > best_score:
            best_score = score
            best_height = cut_height
            best_k = k
    
    return best_height, best_k


def _find_optimal_clusters_by_maxclust(linkage_matrix, target_clusters=None):
    """
    Find optimal clustering by specifying target number of clusters.
    Uses the maxclust criterion to create exactly k clusters.
    
    Parameters:
    - linkage_matrix: scipy linkage matrix
    - target_clusters: desired number of clusters (default: sqrt(n))
    
    Returns:
    - cluster_assignments: array of cluster labels
    - num_clusters: actual number of clusters created
    """
    from scipy.cluster.hierarchy import fcluster
    
    n = linkage_matrix.shape[0] + 1
    
    if n < 2:
        return np.ones(n, dtype=int), 1
    
    if target_clusters is None:
        target_clusters = max(2, int(np.sqrt(n)))
    
    target_clusters = max(1, min(target_clusters, n))
    
    # Use maxclust criterion for exact number of clusters
    cluster_assignments = fcluster(linkage_matrix, target_clusters, criterion='maxclust')
    
    return cluster_assignments, target_clusters


def determine_optimal_clustering(linkage_matrix, method='gap', **kwargs):
    """
    Determine optimal clustering using various dynamic methods.
    
    Parameters:
    - linkage_matrix: scipy linkage matrix
    - method: 'gap', 'inconsistency', 'elbow', 'lifetime', 'topology', 'maxclust', or 'combined'
    - **kwargs: additional parameters for specific methods
    
    Returns:
    - cluster_assignments: array of cluster labels
    - info: dict with method-specific information
    """
    from scipy.cluster.hierarchy import fcluster
    
    n = linkage_matrix.shape[0] + 1
    
    if n < 2:
        return np.ones(n, dtype=int), {'method': method, 'n_clusters': 1}
    
    info = {'method': method}
    
    if method == 'gap':
        distance_threshold, n_clusters = _find_optimal_clusters_by_gap(
            linkage_matrix, 
            min_clusters=kwargs.get('min_clusters', 2),
            min_gap_ratio=kwargs.get('min_gap_ratio', 1.5)
        )
        info['distance_threshold'] = distance_threshold
        info['n_clusters'] = n_clusters
        cluster_assignments = fcluster(linkage_matrix, distance_threshold, criterion='distance')
        
    elif method == 'inconsistency':
        cluster_assignments = _find_optimal_clusters_by_inconsistency(
            linkage_matrix,
            depth=kwargs.get('depth', 2),
            threshold=kwargs.get('threshold', 1.5)
        )
        info['n_clusters'] = len(np.unique(cluster_assignments))
        
    elif method == 'elbow':
        distance_threshold, n_clusters = _find_optimal_clusters_by_elbow(
            linkage_matrix,
            max_clusters=kwargs.get('max_clusters', None)
        )
        info['distance_threshold'] = distance_threshold
        info['n_clusters'] = n_clusters
        cluster_assignments = fcluster(linkage_matrix, distance_threshold, criterion='distance')
        
    elif method == 'lifetime':
        distance_threshold, n_clusters = _find_optimal_clusters_by_lifetime(
            linkage_matrix,
            percentile=kwargs.get('percentile', 90)
        )
        info['distance_threshold'] = distance_threshold
        info['n_clusters'] = n_clusters
        cluster_assignments = fcluster(linkage_matrix, distance_threshold, criterion='distance')
        
    elif method == 'topology':
        distance_threshold, n_clusters = _find_optimal_clusters_by_topology(
            linkage_matrix,
            max_clusters=kwargs.get('max_clusters', None),
            coherence_threshold=kwargs.get('coherence_threshold', 0.7)
        )
        info['distance_threshold'] = distance_threshold
        info['n_clusters'] = n_clusters
        info['note'] = 'Tree-topology aware clustering for visual coherence'
        cluster_assignments = fcluster(linkage_matrix, distance_threshold, criterion='distance')
        
    elif method == 'maxclust':
        cluster_assignments, n_clusters = _find_optimal_clusters_by_maxclust(
            linkage_matrix,
            target_clusters=kwargs.get('target_clusters', None)
        )
        info['n_clusters'] = n_clusters
        info['note'] = f'Forced {n_clusters} clusters using maxclust criterion'
        
    elif method == 'combined':
        # Use multiple methods and take consensus
        methods_to_try = ['gap', 'elbow', 'topology']
        all_n_clusters = []
        all_assignments = []
        
        for m in methods_to_try:
            try:
                assignments, m_info = determine_optimal_clustering(linkage_matrix, method=m)
                all_n_clusters.append(m_info['n_clusters'])
                all_assignments.append(assignments)
            except:
                pass
        
        if all_n_clusters:
            # Use median number of clusters
            median_k = int(np.median(all_n_clusters))
            info['n_clusters'] = median_k
            info['individual_estimates'] = all_n_clusters
            
            # Find the assignment closest to median
            differences = [abs(len(np.unique(a)) - median_k) for a in all_assignments]
            best_idx = np.argmin(differences)
            cluster_assignments = all_assignments[best_idx]
        else:
            # Fallback
            cluster_assignments = fcluster(linkage_matrix, 0.7, criterion='distance')
            info['n_clusters'] = len(np.unique(cluster_assignments))
    
    else:
        raise ValueError(f"Unknown method: {method}")
    
    return cluster_assignments, info


# ----------------------------------------------------------------------
# Main clustering function (using circular tree plotting)
# ----------------------------------------------------------------------
def cluster_gene_neighborhoods_from_sqlite(
    db_path,
    genes_table=GENES_TABLE,
    neighbors_table=NEIGHBORS_TABLE,
    col_neighborhood_id=COL_NEIGHBORHOOD_ID,
    col_gene_id=COL_GENE_ID,
    col_linking_key=COL_LINKING_KEY,
    col_accession_id=COL_ACCESSION_ID,
    col_function_desc=COL_FUNCTION_DESC,
    col_pfam_ids=COL_PFAM_IDS,
    col_interpro_ids=COL_INTERPRO_IDS,
    col_rel_start=COL_REL_START,
    col_rel_stop=COL_REL_STOP,
    col_ssn_cluster_id=COL_SSN_CLUSTER_ID,
    hit_gene_weight_factor=HIT_GENE_WEIGHT_FACTOR,
    direct_neighbor_weight_factor=DIRECT_NEIGHBOR_WEIGHT_FACTOR,
    differentiate_by_ssn_cluster=True,
    ssn_cluster_value_to_filter=DEFAULT_SSN_CLUSTER_VALUE_TO_FILTER,
    collapse_identical_neighborhoods=COLLAPSE_IDENTICAL_NEIGHBORHOODS,
    collapse_core_similarity_threshold=COLLAPSE_CORE_SIMILARITY_THRESHOLD,
    collapse_full_neighborhood_similarity_threshold=COLLAPSE_FULL_NEIGHBORHOOD_SIMILARITY_THRESHOLD,
    original_input_sequence_id=None,
    distance_threshold=0.8,
    clustering_method='gap',
    clustering_params=None,
    save_plots=SAVE_PLOTS,
    output_dir=OUTPUT_DIR,
    output_formats=OUTPUT_FORMATS,
    dpi=DPI,
    report_file_handle=None,
    parallelize_pdist=False,
):
    """
    Cluster gene neighborhoods based on domain/annotation features and
    draw circular phylograms for each SSN cluster (or all neighborhoods).
    
    Clustering methods:
    - 'gap': Finds large gaps in merge heights (fast, good for well-separated clusters)
    - 'inconsistency': Uses inconsistency coefficients (hierarchical structure aware)
    - 'elbow': Elbow method on within-cluster distances (classic approach)
    - 'lifetime': Cluster persistence/lifetime analysis (finds stable clusters)
    - 'topology': Tree-topology aware (best for visual coherence) **RECOMMENDED**
    - 'maxclust': Forces specific number of clusters (when you know k)
    - 'combined': Consensus of multiple methods (most robust but slower)
    - 'static': Use fixed distance_threshold (legacy)
    """
    def log(msg):
        print(msg)
        if report_file_handle:
            report_file_handle.write(msg + "\n")

    t0_all = time.time()
    conn = sqlite3.connect(db_path)

    # --- Fetch data ---
    log("Fetching hit gene data ...")
    q_hits = f"""
        SELECT {col_gene_id}, {col_neighborhood_id}, {col_function_desc},
               {col_pfam_ids}, {col_interpro_ids}, {col_ssn_cluster_id},
               {col_accession_id}
        FROM {genes_table}
    """
    hit_genes_df = pd.read_sql_query(q_hits, conn)
    log(f"  Retrieved {len(hit_genes_df)} hit genes")

    if hit_genes_df.empty:
        log("No hit genes found; aborting.")
        conn.close()
        return {}, {}, {}

    log("Fetching neighbor data ...")
    q_neighbors = f"""
        SELECT {col_linking_key}, {col_gene_id}, {col_function_desc},
               {col_pfam_ids}, {col_interpro_ids},
               {col_rel_start}, {col_rel_stop}
        FROM {neighbors_table}
    """
    neighbors_df = pd.read_sql_query(q_neighbors, conn)
    conn.close()
    log(f"  Retrieved {len(neighbors_df)} neighbor rows")

    # Group neighbors by linking key once
    log("Grouping neighbors by hit gene id ...")
    neighbors_by_link_id = {
        key: group.to_dict("records")
        for key, group in tqdm(
            neighbors_df.groupby(col_linking_key),
            desc="  Grouping neighbors",
            leave=False,
        )
    }
    del neighbors_df
    gc.collect()

    all_neighborhood_features = defaultdict(set)
    core_neighborhood_features = defaultdict(set)
    full_neighborhood_labels_map = {}
    raw_ssn_counts = defaultdict(int)

    # --- Feature extraction ---
    log(f"Extracting features for {len(hit_genes_df)} neighborhoods ...")
    for _, hit_row in tqdm(
        hit_genes_df.iterrows(),
        total=len(hit_genes_df),
        desc="  Neighborhoods",
        unit="nh",
    ):
        hit_id = hit_row[col_gene_id]
        organism = hit_row[col_neighborhood_id]
        ssn_id = hit_row[col_ssn_cluster_id]
        accession_id = hit_row[col_accession_id]

        raw_ssn_counts[ssn_id] += 1
        nh_label = f"{organism}_{hit_id}"

        full_feats = set()
        core_feats = set()

        # hit gene, strongly weighted for full neighborhood features
        hit_full = extract_features_from_gene_row(
            hit_row,
            current_weight_factor=hit_gene_weight_factor,
            base_prefix="HIT_",
            include_desc=True,
            include_pfam=True,
            include_interpro=True,
        )
        full_feats.update(hit_full)

        hit_core = extract_features_from_gene_row(
            hit_row,
            current_weight_factor=1,
            base_prefix="HIT_CORE_",
            include_desc=False,
            include_pfam=True,
            include_interpro=True,
        )
        core_feats.update(hit_core)

        full_neighborhood_labels_map[nh_label] = (
            organism,
            hit_id,
            ssn_id,
            accession_id,
            None,
        )

        # determine closest left/right neighbors by relative coordinates
        neighbor_rows = neighbors_by_link_id.get(hit_id, [])
        closest_left_id = None
        closest_right_id = None
        max_neg_rel_stop = -np.inf
        min_pos_rel_start = np.inf

        for nrow in neighbor_rows:
            rel_start = nrow[col_rel_start]
            rel_stop = nrow[col_rel_stop]
            nid = nrow[col_gene_id]

            if rel_stop is not None and rel_stop < 0 and rel_stop > max_neg_rel_stop:
                max_neg_rel_stop = rel_stop
                closest_left_id = nid
            if rel_start is not None and rel_start > 0 and rel_start < min_pos_rel_start:
                min_pos_rel_start = rel_start
                closest_right_id = nid

        for nrow in neighbor_rows:
            nid = nrow[col_gene_id]
            is_direct = (
                (closest_left_id is not None and nid == closest_left_id)
                or (closest_right_id is not None and nid == closest_right_id)
            )
            w = direct_neighbor_weight_factor if is_direct else 1

            nf_full = extract_features_from_gene_row(
                nrow,
                current_weight_factor=w,
                base_prefix="N_",
                include_desc=True,
                include_pfam=True,
                include_interpro=True,
            )
            full_feats.update(nf_full)

            if is_direct:
                nf_core = extract_features_from_gene_row(
                    nrow,
                    current_weight_factor=1,
                    base_prefix="N_CORE_",
                    include_desc=False,
                    include_pfam=True,
                    include_interpro=True,
                )
                core_feats.update(nf_core)

        all_neighborhood_features[nh_label] = full_feats
        core_neighborhood_features[nh_label] = core_feats

    del neighbors_by_link_id
    gc.collect()
    log("Feature extraction complete.")

    log("\nRaw SSN id distribution:")
    for sid, count in sorted(raw_ssn_counts.items(), key=lambda x: str(x[0])):
        log(f"  SSN {sid}: {count} neighborhoods")
    log("-" * 60)

    if not all_neighborhood_features:
        log("No neighborhoods with features; aborting.")
        return {}, {}, {}

    # --- Collapsing similar neighborhoods (optional) ---
    if collapse_identical_neighborhoods:
        final_neighborhood_features, final_neighborhood_labels_map, collapsed_groups_report = _perform_collapsing(
            all_neighborhood_features,
            full_neighborhood_labels_map,
            core_neighborhood_features,
            collapse_core_similarity_threshold,
            collapse_full_neighborhood_similarity_threshold,
            output_prefix="[Collapsing]",
            report_file=report_file_handle,
            parallelize_pdist=parallelize_pdist,
        )
    else:
        log("Collapsing disabled; using all neighborhoods as-is.")
        final_neighborhood_features = all_neighborhood_features
        final_neighborhood_labels_map = full_neighborhood_labels_map
        collapsed_groups_report = {}

    del all_neighborhood_features
    del core_neighborhood_features
    gc.collect()

    # sanity check on features
    if not final_neighborhood_features:
        log("No final neighborhoods after collapsing; aborting.")
        return {}, {}, collapsed_groups_report

    # determine grouping (SSN-separated or all)
    ssn_groups = defaultdict(list)
    if differentiate_by_ssn_cluster:
        for nh_label, (_, _, ssn_id, _, _) in final_neighborhood_labels_map.items():
            if ssn_id not in ssn_cluster_value_to_filter:
                ssn_groups[ssn_id].append(nh_label)
        log(f"\nProcessing {len(ssn_groups)} SSN groups (after filtering).")
        if ssn_groups:
            log("  SSN ids: " + ", ".join(map(str, sorted(ssn_groups.keys(), key=str))))
    else:
        ssn_groups["All_Neighborhoods"] = sorted(final_neighborhood_features.keys())
        log("Processing all neighborhoods together (no SSN separation).")

    clusters_output = defaultdict(dict)
    if differentiate_by_ssn_cluster and not ssn_groups:
        log("No valid SSN clusters with neighborhoods; nothing to cluster.")
        return {}, final_neighborhood_labels_map, collapsed_groups_report

    # --- Per-group clustering and plotting ---
    for ssn_id, nh_labels in tqdm(
        ssn_groups.items(),
        desc="Clustering SSN groups",
        unit="group",
    ):
        group_start = time.time()
        label_list = sorted(nh_labels)
        n_nh = len(label_list)
        if n_nh < 2:
            log(f"SSN {ssn_id}: only {n_nh} neighborhood(s); skipping clustering.")
            clusters_output[ssn_id] = {1: label_list}
            continue
        
        print(f"Processing SSN {ssn_id} with {len(label_list)} neighborhoods.")
        
        log(f"\n--- SSN {ssn_id}: {n_nh} neighborhoods ---")
        group_feats = {l: final_neighborhood_features[l] for l in label_list}
        group_labels_map = {l: final_neighborhood_labels_map[l] for l in label_list}

        # vocabulary and binary matrix
        vocab = sorted(set.union(*group_feats.values()))
        if not vocab:
            log(f"  SSN {ssn_id}: no features; skipping.")
            clusters_output[ssn_id] = {1: label_list}
            continue

        ft_idx = {f: i for i, f in enumerate(vocab)}
        mat_lil = sp.lil_matrix((n_nh, len(vocab)), dtype=np.int8)
        for i, nh in enumerate(tqdm(label_list, desc=f"  Features SSN {ssn_id}", leave=False)):
            for f in group_feats[nh]:
                j = ft_idx[f]
                mat_lil[i, j] = 1
        mat = mat_lil.tocsr()
        del mat_lil
        gc.collect()

        # all identical?
        if n_nh > 1 and all((mat[0] != mat[i]).nnz == 0 for i in range(1, n_nh)):
            log(f"  SSN {ssn_id}: all neighborhoods identical; no tree/distance.")
            clusters_output[ssn_id] = {1: label_list}
            del mat
            gc.collect()
            continue

        # distances + linkage
        log(f"  SSN {ssn_id}: computing Jaccard distances ...")
        t0 = time.time()
        dists = parallel_pdist_jaccard(mat, num_cores=-1 if parallelize_pdist else 1)
        log(f"    distance calc in {time.time() - t0:.2f}s")

        log(f"  SSN {ssn_id}: linkage ...")
        t0 = time.time()
        linked = linkage(dists, method="average")
        log(f"    linkage in {time.time() - t0:.2f}s")

        del dists
        gc.collect()

        # Dynamic cluster assignments
        log(f"  SSN {ssn_id}: determining optimal clusters (method: {clustering_method}) ...")
        t0 = time.time()
        
        if clustering_method == 'static':
            # Legacy: use fixed threshold
            cluster_assignments = fcluster(linked, distance_threshold, criterion="distance")
            n_clusters = len(np.unique(cluster_assignments))
            log(f"    static threshold={distance_threshold:.3f} -> {n_clusters} clusters")
        else:
            # Dynamic clustering
            params = clustering_params or {}
            cluster_assignments, cluster_info = determine_optimal_clustering(
                linked, method=clustering_method, **params
            )
            n_clusters = cluster_info['n_clusters']
            log(f"    {clustering_method} method -> {n_clusters} clusters")
            if 'distance_threshold' in cluster_info:
                log(f"    dynamic threshold: {cluster_info['distance_threshold']:.3f}")
            if 'individual_estimates' in cluster_info:
                log(f"    individual method estimates: {cluster_info['individual_estimates']}")
        
        log(f"    clustering complete in {time.time() - t0:.2f}s")
        
        clusters = defaultdict(list)
        for i, cid in enumerate(cluster_assignments):
            clusters[cid].append(label_list[i])
        clusters_output[ssn_id] = clusters

        # prepare coloring
        _, label_to_color = _collect_cluster_colors(cluster_assignments, label_list, cmap_name="tab20")

        # draw circular tree
        title = f"SSN {ssn_id} gene neighborhoods"
        out_prefix = f"SSN_{ssn_id}_gnn"
        _draw_circular_tree_with_clusters(
            linkage_matrix=linked,
            leaf_labels=label_list,
            cluster_assignments=cluster_assignments,
            label_to_cluster_color=label_to_color,
            original_input_sequence_id=original_input_sequence_id,
            labels_map=group_labels_map,
            title=title,
            out_prefix=out_prefix,
            output_dir=output_dir,
            output_formats=output_formats,
            dpi=dpi,
            show=not save_plots,
            feature_matrix=mat,
            inner_boundary=0.45,
            consensus_threshold=0.2,
            leaf_stretch_power=7,
            gap_factor=0.1,
        )

        log(f"--- SSN {ssn_id} done in {time.time() - group_start:.2f}s ---")
        del linked
        gc.collect()

    log(f"\nTotal runtime: {time.time() - t0_all:.2f}s")
    return clusters_output, final_neighborhood_labels_map, collapsed_groups_report

In [175]:
# ----------------------------------------------------------------------
# High-level driver / example usage
# ----------------------------------------------------------------------
SQLITE_DB_PATH = r"D:\Studium\PhD\Rotation\1st_WetLab\Enzyme_Homologe_Search\CaMES\CaMES_10kBlast_10e_50eEdge_noFilter_300AST_min900AA_withoutEgtD_withoutMethyltrans_10N\39061_CaMES_10kBlast_10e_50eEdge_noFilter_300AST_min900AA_withoutEgtD_withoutMethyltrans_10N.sqlite"
ORIGINAL_INPUT_SEQUENCE_ID = "A0A7V4WV16"  # or None
DIFFERENTIATE_BY_SSN_CLUSTER = True  # Whether to cluster neighborhoods separately by SSN cluster

# --- Clustering method configuration ---
# Choose: 'gap', 'inconsistency', 'elbow', 'lifetime', 'topology', 'maxclust', 'combined', or 'static'
# CLUSTERING_METHOD = 'topology'  # RECOMMENDED: ensures visually coherent clusters matching tree structure
# CLUSTERING_PARAMS = {
#     'max_clusters': None,        # Maximum clusters (None = auto, uses sqrt(n))
#     'coherence_threshold': 0.7   # How coherent subtrees should be (0-1, higher = more coherent)
# }

# Alternative methods:
# CLUSTERING_METHOD = 'gap'  # Fast, good for well-separated clusters
# CLUSTERING_PARAMS = {'min_clusters': 2, 'min_gap_ratio': 1.5}

# CLUSTERING_METHOD = 'maxclust'  # Force specific number of clusters
# CLUSTERING_PARAMS = {'target_clusters': 5}

CLUSTERING_METHOD = 'combined'  # Most robust (consensus of multiple methods)

# For legacy behavior with fixed threshold:
# CLUSTERING_METHOD = 'static'

chosen_distance_threshold = 0.6  # Only used if CLUSTERING_METHOD = 'static'
PARALLELIZE_PDIST_ENABLED = True

COLLAPSE_IDENTICAL_NEIGHBORHOODS_ACTIVE = True
COLLAPSE_CORE_SIMILARITY_THRESHOLD_ACTIVE = 0.0
COLLAPSE_FULL_NEIGHBORHOOD_SIMILARITY_THRESHOLD_ACTIVE = 0.2

os.makedirs(OUTPUT_DIR, exist_ok=True)
report_suffix = "_ssn_differentiated" if DIFFERENTIATE_BY_SSN_CLUSTER else "_all_neighborhoods"
report_filename = f"{REPORT_FILENAME_BASE}{report_suffix}.txt"
report_path = os.path.join(OUTPUT_DIR, report_filename)

with open(report_path, "w") as report_file:
    def log_top(msg):
        print(msg)
        report_file.write(msg + "\n")

    log_top(
        f"\n--- GNN Clustering Report "
        f"({datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}) ---"
    )
    log_top(f"Database: {SQLITE_DB_PATH}")
    log_top(f"Clustering method: {CLUSTERING_METHOD}")
    if CLUSTERING_METHOD == 'static':
        log_top(f"  Static distance threshold: {chosen_distance_threshold}")
    else:
        log_top(f"  Dynamic clustering parameters: {CLUSTERING_PARAMS}")
    log_top(f"Hit weight: {HIT_GENE_WEIGHT_FACTOR}")
    log_top(f"Direct neighbor weight: {DIRECT_NEIGHBOR_WEIGHT_FACTOR}")
    log_top(
        "Mode: "
        + ("SSN-separated" if DIFFERENTIATE_BY_SSN_CLUSTER else "all neighborhoods together")
    )
    if COLLAPSE_IDENTICAL_NEIGHBORHOODS_ACTIVE:
        log_top("Collapsing enabled:")
        log_top(f"  Stage1 (core) thr: {COLLAPSE_CORE_SIMILARITY_THRESHOLD_ACTIVE}")
        log_top(f"  Stage2 (full) thr: {COLLAPSE_FULL_NEIGHBORHOOD_SIMILARITY_THRESHOLD_ACTIVE}")
    else:
        log_top("Collapsing disabled.")

    log_top(
        "Distance parallelism: "
        + ("joblib (enabled)" if PARALLELIZE_PDIST_ENABLED else "sequential")
    )
    log_top(f"OMP_NUM_THREADS: {os.environ.get('OMP_NUM_THREADS', 'not set')}")
    if ORIGINAL_INPUT_SEQUENCE_ID:
        log_top(
            f"Highlight accession: {ORIGINAL_INPUT_SEQUENCE_ID} "
            f"(color {HIGHLIGHT_COLOR})"
        )
    else:
        log_top("No highlight accession set.")
    log_top(f"Plots -> {OUTPUT_DIR} as {OUTPUT_FORMATS}, dpi={DPI}")
    log_top(f"Report -> {report_path}")
    log_top("-" * 70)

    clusters_by_ssn, final_labels_map, collapsed_groups_report = cluster_gene_neighborhoods_from_sqlite(
        db_path=SQLITE_DB_PATH,
        genes_table=GENES_TABLE,
        neighbors_table=NEIGHBORS_TABLE,
        differentiate_by_ssn_cluster=DIFFERENTIATE_BY_SSN_CLUSTER,
        ssn_cluster_value_to_filter=DEFAULT_SSN_CLUSTER_VALUE_TO_FILTER,
        collapse_identical_neighborhoods=COLLAPSE_IDENTICAL_NEIGHBORHOODS_ACTIVE,
        collapse_core_similarity_threshold=COLLAPSE_CORE_SIMILARITY_THRESHOLD_ACTIVE,
        collapse_full_neighborhood_similarity_threshold=COLLAPSE_FULL_NEIGHBORHOOD_SIMILARITY_THRESHOLD_ACTIVE,
        original_input_sequence_id=ORIGINAL_INPUT_SEQUENCE_ID,
        distance_threshold=chosen_distance_threshold,
        clustering_method=CLUSTERING_METHOD,
        clustering_params=CLUSTERING_PARAMS,
        save_plots=SAVE_PLOTS,
        output_dir=OUTPUT_DIR,
        output_formats=OUTPUT_FORMATS,
        dpi=DPI,
        report_file_handle=report_file,
        parallelize_pdist=PARALLELIZE_PDIST_ENABLED,
    )

    if clusters_by_ssn:
        log_top("\n--- Final clustering results ---")
        for ssn_id, clusters_in_ssn in sorted(clusters_by_ssn.items(), key=lambda x: str(x[0])):
            log_top(f"\n### SSN {ssn_id} ###")
            if not clusters_in_ssn:
                log_top("  No clusters.")
                continue
            for cid, nh_list in sorted(clusters_in_ssn.items()):
                log_top(f"  Cluster {cid}: {len(nh_list)} neighborhoods")
                for nh in nh_list:
                    org, hit_id, _, acc, collapsed_info = final_labels_map.get(
                        nh, ("UNKNOWN", "UNKNOWN", None, "UNKNOWN", None)
                    )
                    highlight = " (ORIGINAL)" if ORIGINAL_INPUT_SEQUENCE_ID and acc == ORIGINAL_INPUT_SEQUENCE_ID else ""
                    collapsed_suffix = ""
                    if collapsed_info:
                        count, code = collapsed_info
                        collapsed_suffix = f" (Collapsed: {count}, code: {code})"
                    log_top(
                        f"    - {org} | Acc: {acc}{highlight}{collapsed_suffix} "
                        f"(hit_id={hit_id}, nh={nh})"
                    )
            log_top("  " + "-" * 30)

        if collapsed_groups_report:
            log_top("\n--- Collapsed neighborhood groups ---")
            for code, info in sorted(collapsed_groups_report.items()):
                log_top(
                    f"  Group {code}: rep={info['representative']} "
                    f"(n={info['count']})"
                )
                for nh in info["members"]:
                    org, hit_id, _, acc, _ = final_labels_map.get(
                        nh, ("UNKNOWN", "UNKNOWN", None, "UNKNOWN", None)
                    )
                    log_top(
                        f"    - {org} | Acc: {acc} (hit_id={hit_id}, nh={nh})"
                    )
            log_top("-" * 60)
    else:
        log_top("\nNo clusters formed at all.")

    log_top("\n--- Report end ---")


--- GNN Clustering Report (2026-02-21 13:39:45) ---
Database: D:\Studium\PhD\Rotation\1st_WetLab\Enzyme_Homologe_Search\CaMES\CaMES_10kBlast_10e_50eEdge_noFilter_300AST_min900AA_withoutEgtD_withoutMethyltrans_10N\39061_CaMES_10kBlast_10e_50eEdge_noFilter_300AST_min900AA_withoutEgtD_withoutMethyltrans_10N.sqlite
Clustering method: combined
  Dynamic clustering parameters: {'max_clusters': None, 'coherence_threshold': 0.7}
Hit weight: 10
Direct neighbor weight: 3
Mode: SSN-separated
Collapsing enabled:
  Stage1 (core) thr: 0.0
  Stage2 (full) thr: 0.2
Distance parallelism: joblib (enabled)
OMP_NUM_THREADS: 12
Highlight accession: A0A7V4WV16 (color red)
Plots -> gnn_cluster_plots_circular as ['pdf'], dpi=600
Report -> gnn_cluster_plots_circular\gnn_clustering_report_ssn_differentiated.txt
----------------------------------------------------------------------
Fetching hit gene data ...
  Retrieved 37 hit genes
Fetching neighbor data ...
  Retrieved 800 neighbor rows
Grouping neighbors by 

  Grouping neighbors:   0%|          | 0/34 [00:00<?, ?it/s]

                                                            

Extracting features for 37 neighborhoods ...


  Neighborhoods: 100%|██████████| 37/37 [00:00<00:00, 939.78nh/s]


Feature extraction complete.

Raw SSN id distribution:
  SSN 1: 35 neighborhoods
  SSN 2: 1 neighborhoods
  SSN 3: 1 neighborhoods
------------------------------------------------------------
[Collapsing]  Collapsing neighborhoods (core thr: 0.0, full thr: 0.2)
[Collapsing]  Stage1: building core sparse matrix


                                                                   

[Collapsing]  Stage1: computing core Jaccard distances
  Pre-computing feature sets for N=37 ...


                                                      

  Done in 0.00s
  Using parallel Jaccard (N=37, cores=12)
[Collapsing]  Stage1: distance calc in 0.64s
[Collapsing]  Stage1: 37 core groups (1.02s partial)
[Collapsing]  Stage2: full-neighborhood collapsing
[Collapsing]  No collapsing performed (1.52s).

Processing 3 SSN groups (after filtering).
  SSN ids: 1, 2, 3


Clustering SSN groups:   0%|          | 0/3 [00:00<?, ?group/s]

Processing SSN 1 with 35 neighborhoods.

--- SSN 1: 35 neighborhoods ---




  SSN 1: computing Jaccard distances ...
  Pre-computing feature sets for N=35 ...




  Done in 0.01s
  Using parallel Jaccard (N=35, cores=12)
    distance calc in 1.55s
  SSN 1: linkage ...
    linkage in 0.00s
  SSN 1: determining optimal clusters (method: combined) ...
    combined method -> 6 clusters
    individual method estimates: [13, 2, 6]
    clustering complete in 0.00s


meta NOT subset; don't know how to subset; dropped
meta NOT subset; don't know how to subset; dropped
Clustering SSN groups: 100%|██████████| 3/3 [00:03<00:00,  1.25s/group]

--- SSN 1 done in 3.37s ---
SSN 3: only 1 neighborhood(s); skipping clustering.
SSN 2: only 1 neighborhood(s); skipping clustering.

Total runtime: 6.47s

--- Final clustering results ---

### SSN 1 ###
  Cluster 1: 4 neighborhoods
    - Anaeromyxobacter diazotrophicus. | Acc: A0A7I9VRP5 (hit_id=BJTG01000010, nh=Anaeromyxobacter diazotrophicus._BJTG01000010)
    - Calditrichota bacterium. | Acc: A0A3M2FG68 (hit_id=RFFS01000159, nh=Calditrichota bacterium._RFFS01000159)
    - Desulfatitalea sp. BRH_c12. | Acc: A0A0F2R6L8 (hit_id=LADR01000008, nh=Desulfatitalea sp. BRH_c12._LADR01000008)
    - Magnetococcus massalia (strain MO-1) | Acc: A0A1S7LH69 (hit_id=LO017727, nh=Magnetococcus massalia (strain MO-1)_LO017727)
  Cluster 2: 27 neighborhoods
    - Acidobacteriota bacterium. | Acc: A0A399XDS8 (hit_id=QEUP01000041, nh=Acidobacteriota bacterium._QEUP01000041)
    - Actinomycetes bacterium. | Acc: A0A662DMY1 (hit_id=QMQC01000053, nh=Actinomycetes bacterium._QMQC01000053)
    - Alteromonada


