# Cluster Gene Neighborhoods

In [None]:
import pandas as pd
import sqlite3
from collections import defaultdict
import re
from scipy.cluster.hierarchy import linkage, fcluster, dendrogram
from scipy.spatial.distance import pdist
import matplotlib.pyplot as plt
import numpy as np
import os
import datetime
import string

# --- Configuration for the SQLite DB ---
# SQLITE_DB_PATH will be set in the __main__ block for convenience
GENES_TABLE = 'attributes' # Main table for hit gene data
NEIGHBORS_TABLE = 'neighbors' # Table for neighboring gene data

COL_NEIGHBORHOOD_ID = 'organism' # Use 'organism' as the unique identifier for each neighborhood for our clustering
COL_GENE_ID = 'id' # Unique identifier for each gene/protein in both tables
COL_LINKING_KEY = 'id' # The linking column between 'attributes' and 'neighbors'
COL_ACCESSION_ID = 'accession' # Column for accession (UniProt) ID

COL_FUNCTION_DESC = 'desc' # Main functional description in both tables
COL_PFAM_IDS = 'family' # Column for PFAM family IDs in both tables
COL_INTERPRO_IDS = 'ipro_family' # Column for InterPro family IDs in both tables
COL_REL_START = 'rel_start' # Column for relative start position
COL_REL_STOP = 'rel_stop' # Column for relative stop position

HIT_GENE_WEIGHT_FACTOR = 10 # Factor by which hit gene features are "copied" for emphasis
DIRECT_NEIGHBOR_WEIGHT_FACTOR = 3 # Factor for direct neighbor domain features

COL_SSN_CLUSTER_ID = 'cluster_num' # Column in 'attributes' that holds the SSN cluster ID
# DEFAULT_SSN_CLUSTER_VALUE_TO_FILTER can be a list of values that should be ignored as valid SSN clusters
DEFAULT_SSN_CLUSTER_VALUE_TO_FILTER = [None, 0] # Example: Filter out None or 0 as valid SSN IDs

SAVE_PLOTS = True # Set to True to save plots to files
OUTPUT_DIR = 'gnn_cluster_plots' # Directory to save plots
REPORT_FILENAME_BASE = 'gnn_clustering_report' # Base name, will append info dynamically
OUTPUT_FORMATS = ['svg', 'png', 'pdf'] # List of formats to save plots in
DPI = 300 # Dots per inch for raster formats like 'png'
HIGHLIGHT_COLOR = 'red' # Color for the original input sequence's leaf label

# Dynamic plot sizing parameters
MIN_PLOT_HEIGHT = 8 # Minimum height of the plot in inches
HEIGHT_PER_LEAF = 0.25 # Adjust this value (e.g., 0.2 to 0.5) to compress/stretch the Y-axis
MAX_PLOT_HEIGHT = 40 # Maximum height to prevent excessively tall plots

MIN_PLOT_WIDTH = 10 # Minimum width of the plot in inches
WIDTH_PER_LEAF = 0.3 # Adjust this value (e.g., 0.3 to 0.7) to compress/stretch the X-axis (more horizontal space for branches)
MAX_PLOT_WIDTH = 60 # Maximum width to prevent excessively wide plots

# Configuration for collapsing similar neighborhoods
COLLAPSE_IDENTICAL_NEIGHBORHOODS = True # Set to True to enable collapsing
COLLAPSE_CORE_SIMILARITY_THRESHOLD = 0.0 # Strict similarity for hit gene + direct neighbors. Usually 0.0 for exact matches.
COLLAPSE_FULL_NEIGHBORHOOD_SIMILARITY_THRESHOLD = 0.3 # Similarity for the entire neighborhood. e.g., 0.3 for 70% similarity.
# ----------------------------------------------------------------------


def parse_annotation_string(annotation_str, prefix=""):
    """
    Parses a string containing annotations (e.g., function, InterPro, PFAM)
    to extract individual features, applying a prefix.
    Handles potential multiple IDs separated by hyphens or semicolons.
    Does NOT split by spaces or underscores for terms like "cysteine desulfurase".
    Filters out "Uncharacterized protein", empty strings, and 'none' values.
    """
    if not isinstance(annotation_str, str) or pd.isna(annotation_str) or annotation_str.lower().strip() in ('none', '', 'null', 'uncharacterized protein'):
        return set()

    features = set()
    # Split only by hyphens (-) and semicolons (;)
    parts = [p.strip() for p in re.split(r'[-;]', annotation_str) if p.strip()]

    for part in parts:
        # After splitting, re-check for uninformative parts
        if part.lower().strip() in ('none', '', 'null', 'uncharacterized protein'):
            continue

        # Prioritize InterPro and PFAM IDs if they match the patterns
        if re.match(r'IPR\d+', part, re.IGNORECASE):
            features.add(f"{prefix}{part.upper()}")
        elif re.match(r'PF\d+', part, re.IGNORECASE):
            features.add(f"{prefix}{part.upper()}")
        else:
            # General terms, clean them up slightly (collapse multiple spaces)
            clean_part = re.sub(r'\s+', ' ', part).lower().strip()

            if clean_part: # Check if it's not empty after cleaning
                features.add(f"{prefix}{clean_part}")
    return features

def extract_features_from_gene_row(gene_row, current_weight_factor=1, base_prefix="N_", 
                                   include_desc=True, include_pfam=True, include_interpro=True):
    """
    Extracts features (InterPro, PFAM, function terms) from a single gene row,
    applying a base prefix and duplicating features by current_weight_factor.
    """
    
    features_set = set()
    raw_features = set()

    if include_desc:
        function_desc = gene_row[COL_FUNCTION_DESC]
        raw_features.update(parse_annotation_string(function_desc))
    
    if include_pfam:
        pfam_ids = gene_row[COL_PFAM_IDS]
        raw_features.update(parse_annotation_string(pfam_ids))
    
    if include_interpro:
        interpro_ids = gene_row[COL_INTERPRO_IDS]
        raw_features.update(parse_annotation_string(interpro_ids))

    if current_weight_factor > 1:
        for feature in raw_features:
            for i in range(current_weight_factor):
                features_set.add(f"{base_prefix}{feature}_w{i}") 
    else: 
        for feature in raw_features:
            features_set.add(f"{base_prefix}{feature}")

    return features_set


def _plot_dendrogram(linked, neighborhood_ids_subset, labels_map, distance_threshold, 
                     plot_title_base, label_type, original_input_sequence_id,
                     save_plots, output_dir, output_formats, dpi, 
                     min_plot_height, height_per_leaf, max_plot_height,
                     min_plot_width, width_per_leaf, max_plot_width):
    """
    Helper function to generate a single dendrogram plot.
    """
    fig_title = f"{plot_title_base} ({label_type.capitalize()} Labels)"
    
    labels_to_use = []
    # Store the hit_id for each label in the order it will be plotted, to match with xticklabels
    accession_ids_for_labels = []
    for nh_id in neighborhood_ids_subset:
        organism_name, hit_id_internal, ssn_cluster_id, accession_id, _ = labels_map.get(nh_id, ('Unknown', 'Unknown', None, 'Unknown', None))

        if label_type == 'organism':
            labels_to_use.append(organism_name.rstrip('.'))
        elif label_type == 'id': # This now means 'accession'
            labels_to_use.append(accession_id) # Use accession_id for 'id' labels
        else:
            labels_to_use.append(nh_id) # Fallback, should not be hit
        accession_ids_for_labels.append(accession_id)

    # Figure size calculation
    num_leaves = len(neighborhood_ids_subset)
    
    calculated_height = max(min_plot_height, num_leaves * height_per_leaf)
    final_height = min(calculated_height, max_plot_height)
    
    calculated_width = max(min_plot_width, num_leaves * width_per_leaf)
    final_width = min(calculated_width, max_plot_width)
    
    plt.figure(figsize=(final_width, final_height)) # Use dynamic width and height
    
    dendrogram(linked,
               orientation='top',
               labels=labels_to_use,
               distance_sort='descending',
               show_leaf_counts=True)
    
    plt.title(fig_title)
    plt.xlabel('Gene Neighborhood (Labeled by ' + ('Accession' if label_type == 'id' else label_type.capitalize()) + ')')     # Y-axis label is now fixed to Linear Scale
    plt.ylabel(f'Jaccard Distance (Linear Scale, Threshold: {distance_threshold})') 
    plt.axhline(y=distance_threshold, color='r', linestyle='--', label=f'Cut-off at {distance_threshold}')
    plt.legend()

    ax = plt.gca()    
    plt.setp(ax.get_xticklabels(), rotation=90, ha="right", rotation_mode="anchor")
    
    # Color specific leaf label
    if original_input_sequence_id:
        for i, tick_label in enumerate(ax.get_xticklabels()):
            if accession_ids_for_labels[i] == original_input_sequence_id:
                tick_label.set_color(HIGHLIGHT_COLOR)
                tick_label.set_weight('bold') # Make it bold for more prominence
    
    plt.tight_layout()

    if save_plots:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        clean_plot_title_base = re.sub(r'[^\w\s-]', '', plot_title_base).replace(' ', '_')
        base_filename = f"{clean_plot_title_base}_{label_type}_labels"
        
        # Save in multiple formats ---
        for fmt in output_formats:
            full_filename = f"{base_filename}.{fmt}"
            plt.savefig(os.path.join(output_dir, full_filename), format=fmt, dpi=dpi)
        plt.close()
    else:
        plt.show()


def _perform_collapsing(all_neighborhood_features, full_neighborhood_labels_map, 
                        core_neighborhood_features,
                        collapse_core_similarity_threshold, collapse_full_neighborhood_similarity_threshold,
                        output_prefix=""):
    """
    Performs a two-stage collapsing of similar neighborhoods.
    Stage 1: Group by core (hit + direct neighbors) features.
    Stage 2: Within these groups, sub-group by full neighborhood features.

    Returns: (final_neighborhood_features, final_neighborhood_labels_map, collapsed_groups_report)
    """
    print(f"{output_prefix}  Starting two-stage collapsing (Core Thr: {collapse_core_similarity_threshold}, Full Thr: {collapse_full_neighborhood_similarity_threshold}).")

    collapsed_groups_report = {}
    
    # --- Stage 1: Group by CORE features (hit + direct neighbors) ---
    core_labels_ordered = sorted(list(core_neighborhood_features.keys()))
    if len(core_labels_ordered) < 2:
        print(f"{output_prefix}  Only {len(core_labels_ordered)} neighborhood(s) to process. Skipping collapsing.")
        return all_neighborhood_features, full_neighborhood_labels_map, collapsed_groups_report

    core_vocabulary = sorted(list(set.union(*core_neighborhood_features.values())))
    if not core_vocabulary:
        print(f"{output_prefix}  Warning: No core features found for collapsing. Skipping collapsing step.")
        return all_neighborhood_features, full_neighborhood_labels_map, collapsed_groups_report

    core_feature_vectors = np.array([
        [1 if feature in core_neighborhood_features[nh_label] else 0 for feature in core_vocabulary]
        for nh_label in core_labels_ordered
    ])

    if core_feature_vectors.shape[0] < 2: # Check again after creating vectors
        print(f"{output_prefix}  Only {core_feature_vectors.shape[0]} valid core feature vector(s). Skipping collapsing.")
        return all_neighborhood_features, full_neighborhood_labels_map, collapsed_groups_report

    if all(np.array_equal(vec, core_feature_vectors[0]) for vec in core_feature_vectors):
        core_pre_clusters = {core_labels_ordered[i]: 1 for i in range(len(core_labels_ordered))} # All to one cluster
        print(f"{output_prefix}  All core feature vectors are identical. Treating as one initial core group.")
    else:
        core_distances = pdist(core_feature_vectors, metric='jaccard')
        core_linked = linkage(core_distances, method='average')
        core_pre_clusters = fcluster(core_linked, collapse_core_similarity_threshold, criterion='distance')
    
    # Group neighborhoods by their initial core-feature-based cluster
    initial_core_groups = defaultdict(list)
    for i, group_id in enumerate(core_pre_clusters):
        initial_core_groups[group_id].append(core_labels_ordered[i])

    print(f"{output_prefix}  Stage 1: Grouped into {len(initial_core_groups)} initial core groups based on a threshold of {collapse_core_similarity_threshold}.")

    # --- Stage 2: Sub-group by FULL neighborhood features within each core group ---
    final_neighborhood_features = {}
    final_neighborhood_labels_map = {}
    
    collapsed_total_count = 0
    unique_collapsed_group_counter = 0 # Using a single counter for all collapsed groups
    
    # We'll generate letter codes like A, B, ..., Z, AA, AB, ... for robustness
    def generate_letter_code(index):
        if index < 26:
            return string.ascii_uppercase[index]
        else:
            first_char_idx = (index // 26) - 1
            second_char_idx = index % 26
            return f"{string.ascii_uppercase[first_char_idx]}{string.ascii_uppercase[second_char_idx]}"

    for group_id in sorted(initial_core_groups.keys()):
        members_in_core_group = initial_core_groups[group_id]
        
        if len(members_in_core_group) < 2:
            member_label = members_in_core_group[0]
            final_neighborhood_features[member_label] = all_neighborhood_features[member_label]
            final_neighborhood_labels_map[member_label] = full_neighborhood_labels_map[member_label]
            continue

        sub_group_vocabulary = sorted(list(set.union(*[all_neighborhood_features[m] for m in members_in_core_group])))
        if not sub_group_vocabulary:
            sub_group_assignments = {m: 1 for m in members_in_core_group}
            print(f"{output_prefix}  Warning: No full features for core group {group_id}. Treating all {len(members_in_core_group)} as one sub-cluster.")
        else:
            sub_group_feature_vectors = np.array([
                [1 if feature in all_neighborhood_features[m] else 0 for feature in sub_group_vocabulary]
                for m in members_in_core_group
            ])
            # Avoid pdist if all sub_group_feature_vectors are identical ---
            if sub_group_feature_vectors.shape[0] > 1 and all(np.array_equal(vec, sub_group_feature_vectors[0]) for vec in sub_group_feature_vectors):
                sub_group_assignments = {members_in_core_group[i]: 1 for i in range(len(members_in_core_group))} # All to one sub_cluster
                print(f"{output_prefix}  All full feature vectors in core group {group_id} are identical. Treating as one sub-cluster.")
            else:
                sub_group_distances = pdist(sub_group_feature_vectors, metric='jaccard')
                sub_group_linked = linkage(sub_group_distances, method='average')
                sub_group_assignments = fcluster(sub_group_linked, collapse_full_neighborhood_similarity_threshold, criterion='distance')
        
        current_sub_groups = defaultdict(list)
        for i, sub_cluster_id in enumerate(sub_group_assignments):
            current_sub_groups[sub_cluster_id].append(members_in_core_group[i])
        
        for sub_cluster_id in sorted(current_sub_groups.keys()):
            collapsed_members = current_sub_groups[sub_cluster_id]
            
            if len(collapsed_members) > 1:
                collapsed_total_count += (len(collapsed_members) - 1)
                
                representative_label = collapsed_members[0]
                
                letter_code = generate_letter_code(unique_collapsed_group_counter)
                unique_collapsed_group_counter += 1
                
                orig_organism, orig_hit_id, orig_ssn_id, orig_accession, _ = full_neighborhood_labels_map[representative_label]
                final_neighborhood_labels_map[representative_label] = (orig_organism, orig_hit_id, orig_ssn_id, orig_accession, (len(collapsed_members), letter_code))
                
                collapsed_groups_report[letter_code] = {
                    'representative': representative_label,
                    'members': sorted(collapsed_members),
                    'count': len(collapsed_members)
                }
                
                union_features = set()
                for member_label in collapsed_members:
                    union_features.update(all_neighborhood_features[member_label])
                final_neighborhood_features[representative_label] = union_features
                
            else:
                member_label = collapsed_members[0]
                final_neighborhood_features[member_label] = all_neighborhood_features[member_label]
                final_neighborhood_labels_map[member_label] = full_neighborhood_labels_map[member_label]

    if collapsed_total_count > 0:
        print(f"{output_prefix}  Collapsed a total of {collapsed_total_count} neighborhoods into {len(final_neighborhood_features)} unique entities after two stages.")
    else:
        print(f"{output_prefix}  No neighborhoods were collapsed after two stages (or disabled).")

    return final_neighborhood_features, final_neighborhood_labels_map, collapsed_groups_report


def cluster_gene_neighborhoods_from_sqlite(
    db_path,
    genes_table=GENES_TABLE,
    neighbors_table=NEIGHBORS_TABLE,
    col_neighborhood_id=COL_NEIGHBORHOOD_ID,
    col_gene_id=COL_GENE_ID,
    col_linking_key=COL_LINKING_KEY,
    col_accession_id=COL_ACCESSION_ID,
    col_function_desc=COL_FUNCTION_DESC,
    col_pfam_ids=COL_PFAM_IDS,
    col_interpro_ids=COL_INTERPRO_IDS,
    col_rel_start=COL_REL_START, 
    col_rel_stop=COL_REL_STOP,   
    col_ssn_cluster_id=COL_SSN_CLUSTER_ID,
    hit_gene_weight_factor=HIT_GENE_WEIGHT_FACTOR,
    direct_neighbor_weight_factor=DIRECT_NEIGHBOR_WEIGHT_FACTOR, 
    differentiate_by_ssn_cluster=False,
    ssn_cluster_value_to_filter=DEFAULT_SSN_CLUSTER_VALUE_TO_FILTER,
    collapse_identical_neighborhoods=COLLAPSE_IDENTICAL_NEIGHBORHOODS,
    collapse_core_similarity_threshold=COLLAPSE_CORE_SIMILARITY_THRESHOLD,
    collapse_full_neighborhood_similarity_threshold=COLLAPSE_FULL_NEIGHBORHOOD_SIMILARITY_THRESHOLD,
    original_input_sequence_id=None,
    distance_threshold=0.8,
    plot_dendrogram=True,
    save_plots=SAVE_PLOTS,
    output_dir=OUTPUT_DIR,
    output_formats=OUTPUT_FORMATS, 
    dpi=DPI,
    min_plot_height=MIN_PLOT_HEIGHT, 
    height_per_leaf=HEIGHT_PER_LEAF, 
    max_plot_height=MAX_PLOT_HEIGHT,
    min_plot_width=MIN_PLOT_WIDTH,
    width_per_leaf=WIDTH_PER_LEAF,
    max_plot_width=MAX_PLOT_WIDTH
):
    """
    Clusters gene neighborhoods from an SQLite database based on detailed functional annotations,
    considering every row in 'attributes' as a hit gene and including its neighbors.
    Applies extra weight to the hit gene and direct neighbors.
    Optionally differentiates by SSN cluster.
    Can collapse highly similar neighborhoods (hit + direct neighbors + entire neighborhood) for plotting.

    Returns:
        tuple: A tuple containing (clusters_dict, full_neighborhood_labels_map, collapsed_groups_report).
               clusters_dict is a dict where keys are SSN cluster IDs (or 'All' if no differentiation)
               and values are dicts of cluster_id -> list of unique_neighborhood_labels (which might be representatives).
               final_labels_map is a dict mapping unique_neighborhood_label to (organism, hit_id, ssn_cluster_id, accession_id, collapsed_members_info).
               collapsed_groups_report is a dict detailing the collapsed groups.
    """

    conn = sqlite3.connect(db_path)
    conn.row_factory = sqlite3.Row # Allows accessing columns by name
    cursor = conn.cursor()

    all_neighborhood_features = defaultdict(set) # Full features for main clustering and full-neighborhood-similarity check
    core_neighborhood_features = defaultdict(set) # Subset features for strict core comparison for collapsing

    full_neighborhood_labels_map = {} # Stores (organism, hit_id, ssn_id, accession, collapsed_info)


    # 1. Fetch all 'hit genes' from the 'attributes' table
    hit_gene_select_columns = [
        col_gene_id,
        col_neighborhood_id,
        col_function_desc,
        col_pfam_ids,
        col_interpro_ids,
        col_ssn_cluster_id,
        col_accession_id,
    ]
    query_hit_genes = f"SELECT {', '.join(hit_gene_select_columns)} FROM {genes_table}"
    cursor.execute(query_hit_genes)
    hit_genes_data = cursor.fetchall()

    if not hit_genes_data:
        print("No hit genes found in the 'attributes' table. Please check your database and configuration.")
        conn.close()
        return {}, {}

    raw_ssn_ids_counts = defaultdict(int)
    
    for hit_row in hit_genes_data:
        hit_id = hit_row[col_gene_id]
        organism_name = hit_row[col_neighborhood_id]
        ssn_cluster_id = hit_row[col_ssn_cluster_id]
        accession_id = hit_row[col_accession_id]
        
        raw_ssn_ids_counts[ssn_cluster_id] += 1

        unique_neighborhood_label = f"{organism_name}_{hit_id}"
        
        current_full_features = set() # Accumulates features for the entire neighborhood
        current_core_features = set() # Accumulates features for hit + direct neighbors for collapsing

        # Add features of the HIT gene itself
        # For FULL features, use HIT_ prefix and full weight factor
        hit_full_features = extract_features_from_gene_row(
            gene_row=hit_row, 
            current_weight_factor=hit_gene_weight_factor,
            base_prefix="HIT_",
            include_desc=True, include_pfam=True, include_interpro=True
        )
        current_full_features.update(hit_full_features)
        
        # For CORE features (used for collapsing), use COLLAPSE_ prefix, no extra weighting, maybe simplified annotations
        hit_core_features = extract_features_from_gene_row(
            gene_row=hit_row, 
            current_weight_factor=1, # No extra weight for collapsing criteria
            base_prefix="HIT_CORE_", # Unique prefix for core features
            include_desc=False, include_pfam=True, include_interpro=True # Only PFAM/InterPro for core
        )
        current_core_features.update(hit_core_features)

        full_neighborhood_labels_map[unique_neighborhood_label] = (organism_name, hit_id, ssn_cluster_id, accession_id, None)

        # 2. Fetch NEIGHBOR genes for this specific hit gene's neighborhood
        # Include rel_start and rel_stop to identify direct neighbors
        neighbor_select_columns = [
            col_gene_id,
            col_function_desc,
            col_pfam_ids,
            col_interpro_ids,
            col_rel_start, 
            col_rel_stop,  
        ]
        query_neighbors = f"""
            SELECT {', '.join(neighbor_select_columns)}
            FROM {neighbors_table}
            WHERE {col_linking_key} = ?
        """
        cursor.execute(query_neighbors, (hit_id,))
        raw_neighbor_genes_data = cursor.fetchall() # Store raw data to find direct neighbors

        # Identify direct neighbors based on closest rel_start/stop to 0
        closest_left_neighbor_gene_id = None
        closest_right_neighbor_gene_id = None
        max_neg_rel_stop = -np.inf # Largest negative value, closest to 0 from left
        min_pos_rel_start = np.inf  # Smallest positive value, closest to 0 from right

        for neighbor_row in raw_neighbor_genes_data:
            rel_start = neighbor_row[col_rel_start]
            rel_stop = neighbor_row[col_rel_stop]
            neighbor_gene_id = neighbor_row[col_gene_id]

            if rel_stop is not None and rel_stop < 0 and rel_stop > max_neg_rel_stop:
                max_neg_rel_stop = rel_stop
                closest_left_neighbor_gene_id = neighbor_gene_id

            if rel_start is not None and rel_start > 0 and rel_start < min_pos_rel_start:
                min_pos_rel_start = rel_start
                closest_right_neighbor_gene_id = neighbor_gene_id
            
        for neighbor_row in raw_neighbor_genes_data:
            neighbor_gene_id = neighbor_row[col_gene_id]
            
            current_neighbor_weight_factor = 1 # Default for other neighbors (for full features)
            is_direct_neighbor = False

            if (closest_left_neighbor_gene_id is not None and neighbor_gene_id == closest_left_neighbor_gene_id) or \
               (closest_right_neighbor_gene_id is not None and neighbor_gene_id == closest_right_neighbor_gene_id):
                current_neighbor_weight_factor = direct_neighbor_weight_factor
                is_direct_neighbor = True
            
            # For FULL features, apply N_ prefix and appropriate weight factor
            neighbor_full_features = extract_features_from_gene_row(
                gene_row=neighbor_row, 
                current_weight_factor=current_neighbor_weight_factor,
                base_prefix="N_",
                include_desc=True, include_pfam=True, include_interpro=True
            ) 
            current_full_features.update(neighbor_full_features)

            if is_direct_neighbor: # For CORE features, only direct neighbors get special prefix
                 neighbor_core_features = extract_features_from_gene_row(
                    gene_row=neighbor_row, 
                    current_weight_factor=1, # No extra weight for collapsing criteria
                    base_prefix="N_CORE_", # Unique prefix for core features
                    include_desc=False, include_pfam=True, include_interpro=True # Only PFAM/InterPro for core
                )
                 current_core_features.update(neighbor_core_features)
            # Other neighbors (not direct) do NOT contribute to core_features for collapsing

        all_neighborhood_features[unique_neighborhood_label].update(current_full_features)
        core_neighborhood_features[unique_neighborhood_label].update(current_core_features) # Store core features
    
    conn.close()

    print("\n--- Diagnostic: Raw SSN Cluster ID Distribution in 'attributes' table ---")
    for ssn_id, count in sorted(raw_ssn_ids_counts.items(), key=lambda item: str(item[0])):
        print(f"  SSN ID '{ssn_id}': {count} neighborhoods")
    print("-------------------------------------------------------------------")

    if not all_neighborhood_features:
        print("No gene neighborhoods found or parsed. Exiting.")
        return {}, {}, {}

    all_unique_features_vocabulary_initial = sorted(list(set.union(*all_neighborhood_features.values())))
    
    if not all_unique_features_vocabulary_initial:
        print("No significant features extracted for clustering from any neighborhood. Check parsing logic and data. Exiting.")
        return {1: list(all_neighborhood_features.keys())}, full_neighborhood_labels_map, {} 

    # Pre-clustering (collapsing) similar neighborhoods 
    if collapse_identical_neighborhoods:
        final_neighborhood_features, final_neighborhood_labels_map, collapsed_groups_report = \
            _perform_collapsing(all_neighborhood_features, full_neighborhood_labels_map, 
                                core_neighborhood_features,
                                collapse_core_similarity_threshold, collapse_full_neighborhood_similarity_threshold)
    else:
        print("\nCollapsing identical/similar neighborhoods is disabled. Proceeding with all original neighborhoods.")
        final_neighborhood_features = all_neighborhood_features
        final_neighborhood_labels_map = full_neighborhood_labels_map
        collapsed_groups_report = {}


    # Continue with main clustering using final_neighborhood_features and final_neighborhood_labels_map
    all_unique_features_vocabulary = sorted(list(set.union(*final_neighborhood_features.values()))) # Use final features
    
    if not all_unique_features_vocabulary:
        print("No significant features extracted from final neighborhoods for clustering. Check parsing logic and data. Exiting.")
        return {1: list(final_neighborhood_features.keys())}, final_neighborhood_labels_map, collapsed_groups_report

    ssn_clusters_to_process = defaultdict(list)
    if differentiate_by_ssn_cluster:
        for nh_label, (_, _, ssn_id, _, _) in full_neighborhood_labels_map.items(): # <<< This is the problematic line
            if ssn_id not in ssn_cluster_value_to_filter:
                ssn_clusters_to_process[ssn_id].append(nh_label)
            else:
                print(f"  Diagnostic: Skipping neighborhood {nh_label} (SSN ID '{ssn_id}') due to filtering by '{ssn_cluster_value_to_filter}'.")
                
        print(f"\nFound {len(ssn_clusters_to_process)} distinct SSN clusters to process (after filtering invalid IDs).")
        if ssn_clusters_to_process:
            print(f"  SSN Clusters to be processed: {sorted(list(ssn_clusters_to_process.keys()), key=str)}")
    else:
        ssn_clusters_to_process['All_Neighborhoods'] = sorted(list(final_neighborhood_features.keys())) # Use final features
        print("Processing all neighborhoods together (no SSN cluster differentiation).")

    clusters_output_dict = defaultdict(dict) 
    
    if differentiate_by_ssn_cluster and not ssn_clusters_to_process:
        print("\nNo valid SSN clusters with any neighborhoods found after filtering for differentiation. No plots/reports generated for individual SSN clusters.")
        return {}, final_neighborhood_labels_map, collapsed_groups_report


    for ssn_id, neighborhood_labels_in_ssn_cluster in ssn_clusters_to_process.items():
        if differentiate_by_ssn_cluster:
            print(f"\n--- Processing SSN Cluster: {ssn_id} (contains {len(neighborhood_labels_in_ssn_cluster)} neighborhoods) ---")
            plot_title_prefix = f"SSN Cluster {ssn_id}"
        else:
            plot_title_prefix = "All Gene Neighborhoods"

        current_ssn_neighborhood_features = {
            label: final_neighborhood_features[label] for label in neighborhood_labels_in_ssn_cluster # Use final features
        }
        current_ssn_neighborhood_labels_map = {
            label: final_neighborhood_labels_map[label] for label in neighborhood_labels_in_ssn_cluster # Use final map
        }

        num_neighborhoods_in_group = len(current_ssn_neighborhood_features)
        if num_neighborhoods_in_group < 2:
            print(f"  Skipping SSN Cluster {ssn_id}: Not enough distinct neighborhoods ({num_neighborhoods_in_group}) for clustering. Requires at least 2.")
            clusters_output_dict[ssn_id] = {1: neighborhood_labels_in_ssn_cluster}
            continue

        neighborhood_ids_sorted = sorted(list(current_ssn_neighborhood_features.keys()))
        feature_vectors = []
        for nh_id in neighborhood_ids_sorted:
            vector = [1 if feature in current_ssn_neighborhood_features[nh_id] else 0 for feature in all_unique_features_vocabulary]
            feature_vectors.append(vector)
        
        if all(np.array_equal(vec, feature_vectors[0]) for vec in feature_vectors):
            print(f"  All neighborhoods in {plot_title_prefix} have identical features. No meaningful distance calculated. Skipping plotting.")
            clusters_output_dict[ssn_id] = {1: neighborhood_labels_in_ssn_cluster}
            if plot_dendrogram:
                 print(f"  (No plots generated for {plot_title_prefix} due to identical features)")
            continue
            
        distances = pdist(np.array(feature_vectors), metric='jaccard')
        linked = linkage(distances, method='average')

        if plot_dendrogram:
            # Organism Labels
            _plot_dendrogram(linked, neighborhood_ids_sorted, current_ssn_neighborhood_labels_map, distance_threshold, 
                            f"{plot_title_prefix} GNN", 'organism', 
                            original_input_sequence_id, 
                            save_plots, output_dir, output_formats, dpi, 
                            min_plot_height, height_per_leaf, max_plot_height,
                            min_plot_width, width_per_leaf, max_plot_width)
            # ID Labels
            _plot_dendrogram(linked, neighborhood_ids_sorted, current_ssn_neighborhood_labels_map, distance_threshold, 
                            f"{plot_title_prefix} GNN", 'id', 
                            original_input_sequence_id, 
                            save_plots, output_dir, output_formats, dpi, 
                            min_plot_height, height_per_leaf, max_plot_height,
                            min_plot_width, width_per_leaf, max_plot_width) 
            
        cluster_assignments = fcluster(linked, distance_threshold, criterion='distance')
        current_ssn_clusters = defaultdict(list)
        for i, cluster_id in enumerate(cluster_assignments):
            current_ssn_clusters[cluster_id].append(neighborhood_ids_sorted[i])
        
        clusters_output_dict[ssn_id] = current_ssn_clusters

    return clusters_output_dict, full_neighborhood_labels_map, collapsed_groups_report

In [6]:
# SQLITE_DB_PATH = '39061_CaMES_10kBlast_10e_50eEdge_noFilter_300AST_min900AA_withoutEgtD_withoutMethyltrans.sqlite' 
# SQLITE_DB_PATH = '39063_CaMES_10kBlast_10e_50eEdge_noFilter_300AST_min900AA-clusterseperated_10N.sqlite' 
SQLITE_DB_PATH = '39094_CaMES_10k-Blast_noFilter_40ID-60AST_min950AA_10N.sqlite' 
# SQLITE_DB_PATH = '39151_EanB_10k_Blast_noFilter_50ASTcolorized_onlybigCluster-renamed_10N.sqlite'
# SQLITE_DB_PATH = '39150_OvoA_10k_Blast_search_noFilter_80ASTcolorized_onlyBigCluster-separated_10N.sqlite'

# Set this to the UniProt ID of your original query protein if you want to highlight it.
# Set to None or an empty string if you don't want to highlight any specific protein.
ORIGINAL_INPUT_SEQUENCE_ID = 'A0A7V4WV16' # e.g., 'A0A0B0EG43' or None or ''

# You can change this setting here or keep the global config
DIFFERENTIATE_BY_SSN_CLUSTER = True     # Set to True or False as needed
chosen_distance_threshold = 0.5         # Change as needed

# Configuration for collapsing similar neighborhoods (Local Override)
COLLAPSE_IDENTICAL_NEIGHBORHOODS_ACTIVE = True 
COLLAPSE_CORE_SIMILARITY_THRESHOLD_ACTIVE = 0.0 # Use 0.0 for exact match of hit+direct neighbors
COLLAPSE_FULL_NEIGHBORHOOD_SIMILARITY_THRESHOLD_ACTIVE = 0.3 # e.g., 0.3 for 70% similarity of full neighborhood

# Prepare report file
report_suffix = "_ssn_differentiated" if DIFFERENTIATE_BY_SSN_CLUSTER else "_all_neighborhoods"
report_filename = f"{REPORT_FILENAME_BASE}{report_suffix}.txt"
report_path = os.path.join(OUTPUT_DIR, report_filename)
    
if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR) # Ensure output directory exists for report and plots

with open(report_path, 'w') as report_file:
    def write_and_print(text):
        print(text)
        report_file.write(text + '\n')

    write_and_print(f"\n--- GNN Clustering Report ({datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}) ---") 
    write_and_print(f"Database: {SQLITE_DB_PATH}")
    write_and_print(f"Jaccard Distance Threshold: {chosen_distance_threshold}")
    write_and_print(f"Hit Gene Weight Factor: {HIT_GENE_WEIGHT_FACTOR}")
    write_and_print(f"Direct Neighbor Weight Factor: {DIRECT_NEIGHBOR_WEIGHT_FACTOR}")
    if DIFFERENTIATE_BY_SSN_CLUSTER:
        write_and_print(f"Clustering differentiated by SSN Cluster ID (column: '{COL_SSN_CLUSTER_ID}').")
    else:
        write_and_print("Clustering all neighborhoods together (no SSN cluster differentiation).")
    
    if COLLAPSE_IDENTICAL_NEIGHBORHOODS_ACTIVE:
        write_and_print(f"Collapsing identical/similar neighborhoods enabled:")
        write_and_print(f"  Stage 1 (Hit+Direct Neighbor Core): Threshold {COLLAPSE_CORE_SIMILARITY_THRESHOLD_ACTIVE}")
        write_and_print(f"  Stage 2 (Full Neighborhood): Threshold {COLLAPSE_FULL_NEIGHBORHOOD_SIMILARITY_THRESHOLD_ACTIVE}")
    else:
        write_and_print("Collapsing identical/similar neighborhoods disabled.")

    if ORIGINAL_INPUT_SEQUENCE_ID:
        write_and_print(f"Original Input Sequence Accession ID for highlighting: '{ORIGINAL_INPUT_SEQUENCE_ID}' (colored '{HIGHLIGHT_COLOR}')")
    else:
        write_and_print("No specific original input sequence ID provided for highlighting.")
    write_and_print(f"Plots saved to: {OUTPUT_DIR} in {OUTPUT_FORMATS} formats at {DPI} DPI.")
    write_and_print(f"Report also saved to: {report_path}")
    write_and_print("-" * 70)


    clusters_by_ssn, final_labels_map, collapsed_groups_report = cluster_gene_neighborhoods_from_sqlite(
                                            db_path=SQLITE_DB_PATH,
                                            col_accession_id=COL_ACCESSION_ID,
                                            col_rel_start=COL_REL_START,
                                            col_rel_stop=COL_REL_STOP,
                                            hit_gene_weight_factor=HIT_GENE_WEIGHT_FACTOR,
                                            direct_neighbor_weight_factor=DIRECT_NEIGHBOR_WEIGHT_FACTOR,
                                            differentiate_by_ssn_cluster=DIFFERENTIATE_BY_SSN_CLUSTER,
                                            ssn_cluster_value_to_filter=DEFAULT_SSN_CLUSTER_VALUE_TO_FILTER,
                                            collapse_identical_neighborhoods=COLLAPSE_IDENTICAL_NEIGHBORHOODS_ACTIVE,
                                            collapse_core_similarity_threshold=COLLAPSE_CORE_SIMILARITY_THRESHOLD_ACTIVE,
                                            collapse_full_neighborhood_similarity_threshold=COLLAPSE_FULL_NEIGHBORHOOD_SIMILARITY_THRESHOLD_ACTIVE,
                                            original_input_sequence_id=ORIGINAL_INPUT_SEQUENCE_ID,
                                            distance_threshold=chosen_distance_threshold,
                                            plot_dendrogram=True,
                                            save_plots=SAVE_PLOTS,
                                            output_dir=OUTPUT_DIR,
                                            output_formats=OUTPUT_FORMATS, 
                                            dpi=DPI,
                                            min_plot_height=MIN_PLOT_HEIGHT,
                                            height_per_leaf=HEIGHT_PER_LEAF, 
                                            max_plot_height=MAX_PLOT_HEIGHT,
                                            min_plot_width=MIN_PLOT_WIDTH, 
                                            width_per_leaf=WIDTH_PER_LEAF, 
                                            max_plot_width=MAX_PLOT_WIDTH  
                                        )

    if clusters_by_ssn:
        write_and_print("\n--- Final Clustering Results ---")
        for ssn_id, clusters_in_ssn in sorted(clusters_by_ssn.items(), key=lambda item: str(item[0])):
            write_and_print(f"\n### Results for SSN Cluster: {ssn_id} ###")
            if not clusters_in_ssn:
                write_and_print("  No clusters formed for this SSN group, or insufficient data.")
                continue

            for cluster_id, neighborhoods_in_cluster in sorted(clusters_in_ssn.items()):
                write_and_print(f"  Cluster {cluster_id}: {len(neighborhoods_in_cluster)} neighborhoods")
                for nh_id in neighborhoods_in_cluster:
                    organism_name, hit_id_internal, _, accession_id, collapsed_info = final_labels_map.get(nh_id, ('UNKNOWN', 'UNKNOWN', None, 'UNKNOWN', None))
                    
                    highlight_indicator = " (ORIGINAL INPUT)" if accession_id == ORIGINAL_INPUT_SEQUENCE_ID else ""
                    collapsed_suffix = ""
                    if collapsed_info:
                        count, letter_code = collapsed_info
                        collapsed_suffix = f" (Collapsed: {count} neighborhoods, Ref: {letter_code})"
                    
                    write_and_print(f"    - Organism: {organism_name}, Hit Accession: {accession_id}{highlight_indicator}{collapsed_suffix} (Internal ID: {hit_id_internal}) (Neighborhood ID: {nh_id})")
            write_and_print("  " + "-" * 30)
        
        # --- NEW: Report on Collapsed Groups ---
        if collapsed_groups_report:
            write_and_print("\n--- Detailed Report on Collapsed Neighborhood Groups ---")
            for code, group_data in sorted(collapsed_groups_report.items()):
                write_and_print(f"  Group ({code}): Representative: {group_data['representative']} (Total: {group_data['count']} members)")
                for member_nh_id in group_data['members']:
                    member_organism, member_hit_id, _, member_accession, _ = final_labels_map.get(member_nh_id, ('UNKNOWN', 'UNKNOWN', None, 'UNKNOWN', None)) # <<< IMPORTANT: Use final_labels_map here
                    write_and_print(f"    - {member_organism} (Accession: {member_accession}) (Internal ID: {member_hit_id}) (NH ID: {member_nh_id})")
            write_and_print("-------------------------------------------------------")
    else:
        write_and_print("\nNo clusters formed at all. This could mean your database is empty, or no features were extracted after filtering, or no valid SSN clusters with multiple neighborhoods were found.")

    write_and_print("\n--- Report End ---")


--- GNN Clustering Report (2025-08-29 21:19:18) ---
Database: 39094_CaMES_10k-Blast_noFilter_40ID-60AST_min950AA_10N.sqlite
Jaccard Distance Threshold: 0.5
Hit Gene Weight Factor: 10
Direct Neighbor Weight Factor: 3
Clustering differentiated by SSN Cluster ID (column: 'cluster_num').
Collapsing identical/similar neighborhoods enabled:
  Stage 1 (Hit+Direct Neighbor Core): Threshold 0.0
  Stage 2 (Full Neighborhood): Threshold 0.3
Original Input Sequence Accession ID for highlighting: 'A0A7V4WV16' (colored 'red')
Plots saved to: gnn_cluster_plots in ['svg', 'png', 'pdf'] formats at 300 DPI.
Report also saved to: gnn_cluster_plots\gnn_clustering_report_ssn_differentiated.txt
----------------------------------------------------------------------

--- Diagnostic: Raw SSN Cluster ID Distribution in 'attributes' table ---
  SSN ID '1': 102 neighborhoods
-------------------------------------------------------------------
  Starting two-stage collapsing (Core Thr: 0.0, Full Thr: 0.3).
  Stage

ValueError: too many values to unpack (expected 3)