# Cluster Gene Neighborhoods

In [35]:
import os
# Note: os.cpu_count() (used below) comes from the os module; multiprocessing is not needed for it.
import multiprocessing 
from threadpoolctl import threadpool_limits 

# OMP_NUM_THREADS sizes the internal OpenMP/BLAS thread pools (OpenBLAS/MKL) that
# NumPy/SciPy use. It is exported here, before NumPy/SciPy are imported further down,
# because those libraries typically read it only when they are first loaded (it has no
# effect if they were already imported earlier in the same kernel/session).
# joblib worker processes inherit the variable; where a worker must not spawn its own
# threads, `threadpool_limits` is used to override it locally.
num_logical_cores = os.cpu_count()
if not num_logical_cores:
    print("Could not detect CPU count. OMP_NUM_THREADS not explicitly set.")
else:
    # Give the numerical libraries every logical core.
    os.environ["OMP_NUM_THREADS"] = f"{num_logical_cores}"
    print(f"Set OMP_NUM_THREADS to {num_logical_cores} for SciPy/NumPy internal multi-threading.")


import pandas as pd
from collections import defaultdict
import re
from scipy.cluster.hierarchy import linkage, fcluster, dendrogram
import matplotlib.pyplot as plt
import numpy as np
import datetime
import string
import scipy.sparse as sp
import time
from tqdm import tqdm
import gc
from joblib import Parallel, delayed 
from openpyxl import load_workbook 


# --- Excel input configuration ---
# (EXCEL_FILE_PATH is expected to be set in the __main__ block.)

# Header names looked up (case-insensitively) in each sheet's header row.
EXCEL_COL_LENGTH = '长度(aa)'    # protein length in amino acids ("长度" = "length")
EXCEL_COL_PFAM = 'pfam'
EXCEL_COL_HIT_GENE = 'hit_gene'
EXCEL_HIT_GENE_MARKER = 'yes'    # value in the hit_gene column that marks the hit gene

# A data row containing any of these words (compared upper-case) ends a neighborhood table.
EXCEL_STOP_KEYWORDS = frozenset({'LOCUS', 'CDS', 'FEATURES'})

# Generic column names used in the per-neighborhood tables.
COL_NEIGHBORHOOD_ID = 'organism'  # one neighborhood per Excel sheet (the sheet name)
COL_GENE_ID = 'id'                # taken from the first column of the sheet (e.g. 'ctg5-78')

# Feature weighting: a weight factor k expands every domain term into k distinct
# features (suffixes _w0 .. _w{k-1}), so it counts k times in a Jaccard comparison.
HIT_GENE_WEIGHT_FACTOR = 10       # presumably applied to the hit gene's domains
DIRECT_NEIGHBOR_WEIGHT_FACTOR = 3 # presumably applied to the hit gene's direct neighbors

# --- Plot / report output ---
SAVE_PLOTS = True                               # True: save plots to files; False: show them
OUTPUT_DIR = 'gnn_cluster_plots_PEP'
REPORT_FILENAME_BASE = 'gnn_clustering_report'  # run-specific details get appended to this
OUTPUT_FORMATS = ['pdf']                        # any subset of 'svg', 'png', 'pdf'
DPI = 300                                       # resolution for raster formats such as 'png'
HIGHLIGHT_COLOR = 'red'                         # colour of the input sequence's leaf label

# Dendrogram figure size in inches, computed as
# clamp(number_of_leaves * <dim>_PER_LEAF, MIN_PLOT_<dim>, MAX_PLOT_<dim>).
MIN_PLOT_HEIGHT = 8
HEIGHT_PER_LEAF = 0.05
MAX_PLOT_HEIGHT = 40

MIN_PLOT_WIDTH = 10
WIDTH_PER_LEAF = 0.3    # larger values give each leaf more horizontal room
MAX_PLOT_WIDTH = 150

# --- Collapsing of near-identical neighborhoods ---
# The thresholds are Jaccard *distances* used as fcluster(criterion='distance') cut-offs.
COLLAPSE_IDENTICAL_NEIGHBORHOODS = True
COLLAPSE_CORE_SIMILARITY_THRESHOLD = 0.0                 # hit gene + direct neighbors; 0.0 = identical only
COLLAPSE_FULL_NEIGHBORHOOD_SIMILARITY_THRESHOLD = 0.3    # whole neighborhood; 0.3 ~ at least 70% similar

# Below this many items (rows for the Jaccard distance matrix, core groups for the
# second collapsing stage) the work runs sequentially instead of through joblib.
MIN_ITEMS_FOR_PARALLEL_PROCESSING = 20

# --- PFAM string cleaning ---
PFAM_PREFIX = "Pfam:"   # removed (case-insensitively) from the start of each PFAM term
# Terms (compared in lower case) that carry no domain information and are discarded.
# '无' means "none".
UNINFORMATIVE_PFAM_TERMS = frozenset({'unknown', '-', '无', '', 'none', 'null', 'low complexity', 'Pfam', 'pfam'})
# ----------------------------------------------------------------------


def parse_pfam_string(pfam_str, prefix=""):
    """
    Parse one cell of the Excel 'pfam' column into a set of normalised domain terms.

    The cell is split on ',', ';' or the full-width comma '，'. For each part:
    a leading PFAM_PREFIX ("Pfam:") is removed case-insensitively, runs of
    whitespace are collapsed to a single space, the text is stripped and
    lower-cased. Parts that end up empty or appear in UNINFORMATIVE_PFAM_TERMS
    are dropped. The uninformative check runs *after* normalisation, so variants
    such as "Low  Complexity" or "Pfam:  none" are filtered as well.

    Args:
        pfam_str: Raw cell value. Anything that is not a ``str`` (None, NaN,
            numbers, ...) yields an empty set.
        prefix: String prepended to every returned term.

    Returns:
        set[str]: The prefixed, normalised terms.
    """
    # Non-strings (None, float NaN from pandas, numeric cells) carry no domain names.
    if not isinstance(pfam_str, str):
        return set()

    prefix_lower = PFAM_PREFIX.lower()
    features = set()
    for part in re.split(r'[,;，]', pfam_str):
        term = part.strip()
        if term.lower().startswith(prefix_lower):
            term = term[len(PFAM_PREFIX):]
        # Normalise first, then filter: comparing the raw text let whitespace
        # variants of uninformative terms (e.g. "low  complexity") through.
        term = re.sub(r'\s+', ' ', term).strip().lower()
        if term and term not in UNINFORMATIVE_PFAM_TERMS:
            features.add(f"{prefix}{term}")
    return features


def read_excel_gene_neighborhoods(excel_path, report_file=None):
    """
    Read gene-neighborhood tables from an Excel workbook (one sheet per neighborhood).

    For every sheet, the first row containing all of EXCEL_COL_LENGTH,
    EXCEL_COL_PFAM and EXCEL_COL_HIT_GENE (case-insensitive, whitespace-stripped)
    is the header. Data rows start on the row *after* the header and are read
    until the first completely empty row or the first row that contains one of
    EXCEL_STOP_KEYWORDS in any cell. The gene ID comes from the first column
    (its header cell may be blank); when the ID cell is blank, a synthetic ID
    "<sheet>_gene_<row counter>" is used. Rows whose pfam cell is blank are skipped.

    Args:
        excel_path: Path of the workbook to read.
        report_file: Optional writable text handle; every progress/warning
            message is also written to it.

    Returns:
        dict mapping sheet name -> pandas DataFrame with the columns COL_GENE_ID,
        EXCEL_COL_LENGTH, EXCEL_COL_PFAM and EXCEL_COL_HIT_GENE. Sheets without a
        header or without data rows are omitted. If any exception is raised while
        reading, it is reported and an empty dict is returned.
    """
    def _write_and_print(text):
        print(text)
        if report_file:
            report_file.write(text + '\n')

    def _is_blank(value):
        return value is None or str(value).strip() == ''

    def _cell(row_values, col_idx):
        # Rows can be shorter than the header; missing cells read as None.
        if col_idx is None or col_idx >= len(row_values):
            return None
        return row_values[col_idx]

    # Lower-cased header text -> canonical column constant.
    required_headers = {col.lower(): col for col in (EXCEL_COL_LENGTH, EXCEL_COL_PFAM, EXCEL_COL_HIT_GENE)}
    gene_id_col = 0  # the gene ID is always the first column

    _write_and_print(f"Reading Excel file: {excel_path}")
    excel_data = {}  # sheet name -> DataFrame
    workbook = None

    try:
        workbook = load_workbook(excel_path, data_only=True)  # data_only=True reads cached values, not formulas

        for sheet_name in tqdm(workbook.sheetnames, desc="Processing Excel sheets"):
            _write_and_print(f"  Processing sheet: {sheet_name}")
            worksheet = workbook[sheet_name]

            # --- Locate the header row (1-based row number) ---
            header_row_number = None
            header_cols = {}  # canonical column constant -> 0-based column index
            for row_number, row_values in enumerate(worksheet.iter_rows(values_only=True), start=1):
                lower_case_values = [str(v).strip().lower() if v is not None else '' for v in row_values]
                if all(name in lower_case_values for name in required_headers):
                    for col_idx, value in enumerate(lower_case_values):
                        if value in required_headers:
                            header_cols[required_headers[value]] = col_idx
                    header_row_number = row_number
                    break

            if header_row_number is None:
                _write_and_print(f"  Warning: Valid header row not found in sheet '{sheet_name}'. Skipping sheet.")
                continue

            # --- Read data rows ---
            data_rows = []
            gene_idx_counter = 0  # used to build IDs for rows whose ID cell is blank
            # Start one row *below* the header; starting at header_row_number
            # re-read the header itself as a spurious, featureless gene.
            for row_values in worksheet.iter_rows(min_row=header_row_number + 1, values_only=True):
                if all(_is_blank(v) for v in row_values) or any(
                    str(v).strip().upper() in EXCEL_STOP_KEYWORDS for v in row_values if v is not None
                ):
                    break  # end of this neighborhood's table

                gene_id_raw = _cell(row_values, gene_id_col)
                gene_id = f"{sheet_name}_gene_{gene_idx_counter}" if _is_blank(gene_id_raw) else str(gene_id_raw)
                pfam = _cell(row_values, header_cols.get(EXCEL_COL_PFAM))

                if not _is_blank(pfam):
                    data_rows.append({
                        COL_GENE_ID: gene_id,
                        EXCEL_COL_LENGTH: _cell(row_values, header_cols.get(EXCEL_COL_LENGTH)),
                        EXCEL_COL_PFAM: pfam,
                        EXCEL_COL_HIT_GENE: _cell(row_values, header_cols.get(EXCEL_COL_HIT_GENE)),
                    })
                gene_idx_counter += 1

            if data_rows:
                excel_data[sheet_name] = pd.DataFrame(data_rows)
            else:
                _write_and_print(f"  Warning: No valid data rows found in sheet '{sheet_name}' after header. Skipping sheet.")

    except Exception as e:
        # Best-effort reader: report and return nothing rather than crash the notebook.
        _write_and_print(f"Error reading Excel file {excel_path}: {e}")
        return {}
    finally:
        if workbook is not None:
            workbook.close()

    _write_and_print(f"Successfully read data from {len(excel_data)} sheets.")
    return excel_data


def extract_features_from_excel_gene_row(gene_row_dict, current_weight_factor=1, base_prefix="N_"):
    """
    Turn one Excel gene row into a set of prefixed PFAM features.

    The row's EXCEL_COL_PFAM value is cleaned with parse_pfam_string and every
    resulting term is prefixed with ``base_prefix``. When ``current_weight_factor``
    is greater than 1, each term is expanded into that many distinct features
    suffixed ``_w0`` .. ``_w{k-1}``, so it contributes k elements to a set-based
    (Jaccard) comparison.

    Args:
        gene_row_dict: Mapping for one gene row; only the EXCEL_COL_PFAM key is
            read (a missing key yields no features).
        current_weight_factor: Replication count per term; 1 or less means no replication.
        base_prefix: String placed in front of every feature.

    Returns:
        set[str]: The generated features.
    """
    terms = parse_pfam_string(gene_row_dict.get(EXCEL_COL_PFAM))

    if current_weight_factor > 1:
        return {
            f"{base_prefix}{term}_w{copy_idx}"
            for term in terms
            for copy_idx in range(current_weight_factor)
        }
    return {f"{base_prefix}{term}" for term in terms}


def _plot_dendrogram(linked, neighborhood_ids_subset, labels_map, distance_threshold, 
                     plot_title_base, label_type, original_input_sequence_id,
                     save_plots, output_dir, output_formats, dpi, 
                     min_plot_height, height_per_leaf, max_plot_height,
                     min_plot_width, width_per_leaf, max_plot_width,
                     report_file=None):
    """
    Draw one dendrogram of gene neighborhoods and either save it or show it.

    Args:
        linked: Linkage matrix (scipy.cluster.hierarchy.linkage output). Its
            observations are assumed to be in the same order as
            ``neighborhood_ids_subset``.
        neighborhood_ids_subset: Neighborhood IDs in observation order.
        labels_map: Maps neighborhood ID -> 5-tuple
            (organism_name, hit_id, ssn_cluster_id, accession_id, <unused>).
            Missing IDs fall back to ('Unknown', 'Unknown', None, 'Unknown', None).
        distance_threshold: y value of the dashed cut-off line.
        plot_title_base: Title text; a sanitised form is used as the file name stem.
        label_type: 'organism' -> organism names (trailing '.' removed);
            'id' -> accession IDs; anything else -> the raw neighborhood IDs.
        original_input_sequence_id: Accession whose leaf label is drawn bold in
            HIGHLIGHT_COLOR; a falsy value disables highlighting.
        save_plots: If true, write one file per entry of ``output_formats`` into
            ``output_dir`` (created if needed) and close the figure; otherwise plt.show().
        output_dir, output_formats, dpi: Output location, formats and savefig resolution.
        min_plot_height, height_per_leaf, max_plot_height,
        min_plot_width, width_per_leaf, max_plot_width: Figure size in inches is
            clamp(n_leaves * per_leaf, min, max) for each dimension.
        report_file: Accepted for signature compatibility; not used.
    """
    labels_to_use = []
    # Accession per observation, in the same (input) order as labels_to_use.
    accession_ids_for_labels = []
    for nh_id in neighborhood_ids_subset:
        organism_name, _hit_id, _ssn_cluster_id, accession_id, _ = labels_map.get(
            nh_id, ('Unknown', 'Unknown', None, 'Unknown', None))

        if label_type == 'organism':
            labels_to_use.append(organism_name.rstrip('.'))
        elif label_type == 'id':  # 'id' labels show the accession
            labels_to_use.append(accession_id)
        else:
            labels_to_use.append(nh_id)  # fallback, not expected to be hit
        accession_ids_for_labels.append(accession_id)

    # Figure size grows with the number of leaves, clamped to [min, max].
    num_leaves = len(neighborhood_ids_subset)
    final_height = min(max(min_plot_height, num_leaves * height_per_leaf), max_plot_height)
    final_width = min(max(min_plot_width, num_leaves * width_per_leaf), max_plot_width)

    plt.figure(figsize=(final_width, final_height))

    dendro = dendrogram(linked,
                        orientation='top',
                        labels=labels_to_use,
                        distance_sort='descending',
                        show_leaf_counts=True)

    axis_label = 'Accession' if label_type == 'id' else label_type.capitalize()
    plt.title(f"{plot_title_base} ({label_type.capitalize()} Labels)")
    plt.xlabel('Gene Neighborhood (Labeled by ' + axis_label + ')')
    plt.ylabel('Jaccard Distance')
    plt.axhline(y=distance_threshold, color='r', linestyle='--', label=f'Cut-off at {distance_threshold}')
    plt.legend()

    ax = plt.gca()
    plt.setp(ax.get_xticklabels(), rotation=90, ha="right", rotation_mode="anchor")

    if original_input_sequence_id:
        # dendrogram() reorders the leaves: the k-th tick shows observation
        # dendro['leaves'][k], not observation k. Map through 'leaves' so the
        # highlighted label really belongs to the input sequence.
        for tick_label, leaf_idx in zip(ax.get_xticklabels(), dendro['leaves']):
            if accession_ids_for_labels[leaf_idx] == original_input_sequence_id:
                tick_label.set_color(HIGHLIGHT_COLOR)
                tick_label.set_weight('bold')

    plt.tight_layout()

    if save_plots:
        os.makedirs(output_dir, exist_ok=True)
        clean_plot_title_base = re.sub(r'[^\w\s-]', '', plot_title_base).replace(' ', '_')
        base_filename = f"{clean_plot_title_base}_{label_type}_labels"
        for fmt in output_formats:
            plt.savefig(os.path.join(output_dir, f"{base_filename}.{fmt}"), format=fmt, dpi=dpi)
        plt.close()
    else:
        plt.show()


def _balanced_row_ranges(n_samples, num_tasks):
    """Split rows 0..n_samples-2 into contiguous [start, end) ranges carrying roughly equal pair counts."""
    n_rows = n_samples - 1  # the last row has no later partner, so it never starts a pair
    if n_rows <= 0:
        return []
    num_tasks = max(1, min(num_tasks, n_rows))
    target_pairs = n_samples * n_rows / (2 * num_tasks)
    ranges = []
    start = 0
    pairs_in_range = 0
    for i in range(n_rows):
        # Row i is compared with rows i+1 .. n_samples-1 (n_rows - i pairs), so early rows
        # are much heavier; equal-width row ranges would give badly unbalanced tasks.
        pairs_in_range += n_rows - i
        if pairs_in_range >= target_pairs:
            ranges.append((start, i + 1))
            start = i + 1
            pairs_in_range = 0
    if start < n_rows:
        ranges.append((start, n_rows))
    return ranges


def _jaccard_distance_rows(start_i, end_i, feature_sets, set_sizes):
    """Return condensed-order Jaccard distances of rows [start_i, end_i) against every later row."""
    n_total = len(feature_sets)
    distances = []
    append = distances.append
    for i in range(start_i, end_i):
        set_i = feature_sets[i]
        size_i = set_sizes[i]
        for j in range(i + 1, n_total):
            shared = len(set_i & feature_sets[j])
            # |A ∪ B| = |A| + |B| - |A ∩ B|: avoids building the union set for every pair.
            union_size = size_i + set_sizes[j] - shared
            # Two empty rows are treated as identical (distance 0.0).
            append(1.0 - shared / union_size if union_size else 0.0)
    return np.asarray(distances, dtype=np.float64)


def parallel_pdist_jaccard(feature_matrix, num_cores=-1):
    """
    Compute the condensed pairwise Jaccard distance vector for the rows of a sparse matrix.

    Each row is treated as the set of its stored column indices (values are
    ignored). The distance between rows i and j is 1 - |S_i ∩ S_j| / |S_i ∪ S_j|,
    and 0.0 when both rows are empty. Pairs are emitted in condensed order
    (0,1), (0,2), ..., (n-2,n-1), the layout expected by
    ``scipy.cluster.hierarchy.linkage``.

    The computation runs sequentially when n_samples < MIN_ITEMS_FOR_PARALLEL_PROCESSING
    or when a single core is requested. Otherwise, rows are split into contiguous
    ranges with roughly equal pair counts (about 4 per core) and processed with
    joblib's loky backend.

    Args:
        feature_matrix: Any SciPy sparse matrix or sparse array (converted to CSR).
        num_cores: Number of workers. -1 = all CPUs; -k = all CPUs + 1 - k
            (joblib convention); 0 or 1 = sequential.

    Returns:
        numpy.ndarray: 1-D float64 array of length n*(n-1)/2 (empty when n <= 1).

    Raises:
        TypeError: If ``feature_matrix`` is not a SciPy sparse object.
    """
    if not sp.issparse(feature_matrix):
        raise TypeError("Input feature_matrix must be a SciPy sparse matrix (preferably CSR).")
    # tocsr() returns the object itself when it is already CSR, so the common case is free.
    feature_matrix = feature_matrix.tocsr()

    n_samples = feature_matrix.shape[0]
    if n_samples <= 1:
        return np.array([])

    # Resolve the worker count (joblib-style negative values), never below 1.
    if num_cores < 0:
        num_cores = max(1, (os.cpu_count() or 1) + 1 + num_cores)
    elif num_cores == 0:  # treat 0 as explicit sequential execution
        num_cores = 1

    # Build every row's feature set once, in the parent process. tolist() yields
    # plain Python ints, which hash faster and pickle smaller than NumPy scalars.
    print(f"  Pre-calculating {n_samples} feature sets...")
    set_precomputation_start = time.time()
    indptr, indices = feature_matrix.indptr, feature_matrix.indices
    feature_sets = [
        set(indices[indptr[i]:indptr[i + 1]].tolist())
        for i in tqdm(range(n_samples), desc=f"  Pre-computing feature sets (N={n_samples})", leave=False)
    ]
    set_sizes = [len(s) for s in feature_sets]
    print(f"  Feature set pre-calculation took {time.time() - set_precomputation_start:.2f} seconds.")

    # Drop this frame's references (including any CSR copy made by tocsr()).
    del feature_matrix, indptr, indices

    # The work is pure-Python set arithmetic, so no BLAS thread limiting is needed here.
    if n_samples < MIN_ITEMS_FOR_PARALLEL_PROCESSING or num_cores == 1:
        print(f"  Running Jaccard distance sequentially for N={n_samples} (below parallel threshold or num_cores=1).")
        results = [_jaccard_distance_rows(0, n_samples - 1, feature_sets, set_sizes)]
    else:
        print(f"  Running Jaccard distance in parallel for N={n_samples} using {num_cores} cores.")
        row_ranges = _balanced_row_ranges(n_samples, num_cores * 4)
        # Parallel returns results in task order, so concatenation keeps condensed order.
        results = Parallel(n_jobs=num_cores, backend="loky", verbose=0)(
            delayed(_jaccard_distance_rows)(start_i, end_i, feature_sets, set_sizes)
            for start_i, end_i in row_ranges
        )

    return np.concatenate(results)




def _perform_collapsing(all_neighborhood_features, full_neighborhood_labels_map, 
                        core_neighborhood_features,
                        collapse_core_similarity_threshold, collapse_full_neighborhood_similarity_threshold,
                        output_prefix="", report_file=None, parallelize_pdist=False):
    """
    Performs a two-stage collapsing of similar neighborhoods.
    Stage 1: Group by core (hit + direct neighbors) features.
    Stage 2: Within these groups, sub-group by full neighborhood features.

    Returns: (final_neighborhood_features, final_neighborhood_labels_map, collapsed_groups_report)
    """
    # An internal helper function for consistent logging
    def _write_and_print_internal(text):
        print(text)
        if report_file:
            report_file.write(text + '\n')

    _write_and_print_internal(f"{output_prefix}  Starting two-stage collapsing (Core Thr: {collapse_core_similarity_threshold}, Full Thr: {collapse_full_neighborhood_similarity_threshold}).")
    
    collapsing_overall_start = time.time()

    collapsed_groups_report = {}
    
    # --- Stage 1: Group by CORE features (hit + direct neighbors) ---
    stage1_start = time.time()
    core_labels_ordered = sorted(list(core_neighborhood_features.keys()))
    if len(core_labels_ordered) < 2:
        _write_and_print_internal(f"{output_prefix}  Only {len(core_labels_ordered)} neighborhood(s) to process. Skipping collapsing.")
        return all_neighborhood_features, full_neighborhood_labels_map, collapsed_groups_report

    core_vocabulary = sorted(list(set.union(*core_neighborhood_features.values())))
    if not core_vocabulary:
        _write_and_print_internal(f"{output_prefix}  Warning: No core features found for collapsing. Skipping collapsing step.")
        return all_neighborhood_features, full_neighborhood_labels_map, collapsed_groups_report

    _write_and_print_internal(f"{output_prefix}  Stage 1: Building sparse matrix for {len(core_labels_ordered)} neighborhoods and {len(core_vocabulary)} core features...")
    matrix_build_start = time.time()
    core_feature_to_idx = {feature: i for i, feature in enumerate(core_vocabulary)}
    num_core_neighborhoods = len(core_labels_ordered)
    num_core_features = len(core_vocabulary)

    # Use LIL for efficient construction, then convert to CSR for computation
    core_feature_vectors_lil = sp.lil_matrix((num_core_neighborhoods, num_core_features), dtype=np.int8) 
    for i, nh_label in enumerate(tqdm(core_labels_ordered, desc=f"{output_prefix}  Stage 1: Populating core features", leave=False)): # <-- MODIFIED: Added tqdm
        for feature in core_neighborhood_features[nh_label]:
            if feature in core_feature_to_idx: # Safety check
                j = core_feature_to_idx[feature]
                core_feature_vectors_lil[i, j] = 1
    core_feature_vectors = core_feature_vectors_lil.tocsr() # Convert to CSR
    _write_and_print_internal(f"{output_prefix}  Stage 1: Sparse matrix built in {time.time() - matrix_build_start:.2f} seconds. Shape: {core_feature_vectors.shape}, NNZ: {core_feature_vectors.nnz}") # <-- ADDED: Detailed report
    gc.collect() # <-- ADDED: Garbage collection

    if core_feature_vectors.shape[0] < 2: # Check again after creating vectors
        _write_and_print_internal(f"{output_prefix}  Only {core_feature_vectors.shape[0]} valid core feature vector(s). Skipping collapsing.")
        return all_neighborhood_features, full_neighborhood_labels_map, collapsed_groups_report

    # Check for identical sparse vectors
    if num_core_neighborhoods > 1 and all(
        (core_feature_vectors[0] != core_feature_vectors[i]).nnz == 0 # Compare sparse rows
        for i in range(1, num_core_neighborhoods)
    ):
        core_pre_clusters = {core_labels_ordered[i]: 1 for i in range(len(core_labels_ordered))}
        _write_and_print_internal(f"{output_prefix}  All core feature vectors are identical. Treating as one initial core group.")
    else:
        _write_and_print_internal(f"{output_prefix}  Stage 1: Calculating core distances using scipy.pdist...") # <-- MODIFIED: Detailed report
        distance_calc_start = time.time()
        core_distances = parallel_pdist_jaccard(core_feature_vectors, num_cores=-1 if parallelize_pdist else 1)
        _write_and_print_internal(f"{output_prefix}  Stage 1: Core distance calculation took {time.time() - distance_calc_start:.2f} seconds.")

        _write_and_print_internal(f"{output_prefix}  Stage 1: Performing linkage and clustering for core features...") # <-- MODIFIED: Detailed report
        linkage_start = time.time()
        core_linked = linkage(core_distances, method='average')
        _write_and_print_internal(f"{output_prefix}  Stage 1: Linkage took {time.time() - linkage_start:.2f} seconds.") # <-- ADDED: Detailed report
        core_pre_clusters = fcluster(core_linked, collapse_core_similarity_threshold, criterion='distance')
    
    # Group neighborhoods by their initial core-feature-based cluster
    initial_core_groups = defaultdict(list)
    # The output of fcluster is an array of cluster IDs, map them back to labels_ordered
    # Handling for when core_pre_clusters is already a dict (all identical) or an array
    if isinstance(core_pre_clusters, np.ndarray): 
        for i, group_id in enumerate(core_pre_clusters):
            initial_core_groups[group_id].append(core_labels_ordered[i])
    else: 
        initial_core_groups = {1: core_labels_ordered} # Put all in one group (from the `if all(...)` block)
    
    _write_and_print_internal(f"{output_prefix}  Stage 1: Grouped into {len(initial_core_groups)} initial core groups based on a threshold of {collapse_core_similarity_threshold}. Total stage 1 took {time.time() - stage1_start:.2f} seconds.") # <-- ADDED: Detailed report

    # Explicit memory cleanup for Stage 1 objects
    del core_feature_vectors 
    if 'core_distances' in locals(): del core_distances 
    if 'core_linked' in locals(): del core_linked 
    gc.collect() # Garbage collection

    # --- Stage 2: Sub-group by FULL neighborhood features within each core group ---
    stage2_start = time.time()
    
    # We'll generate letter codes like A, B, ..., Z, AA, AB, ... for robustness
    def generate_letter_code(index):
        if index < 26:
            return string.ascii_uppercase[index]
        else:
            first_char_idx = (index // 26) - 1
            second_char_idx = index % 26
            return f"{string.ascii_uppercase[first_char_idx]}{string.ascii_uppercase[second_char_idx]}"

    _write_and_print_internal(f"{output_prefix}  Stage 2: Processing {len(initial_core_groups)} core groups for full neighborhood similarity...")

    # --- Define the worker function for processing a chunk of core groups ---
    def process_core_group_chunk(core_group_ids_chunk,
                                 all_neighborhood_features_ref,
                                 full_neighborhood_labels_map_ref,
                                 collapse_full_neighborhood_similarity_threshold_ref,
                                 parallelize_pdist_ref):
        
        results_to_aggregate = []
        local_collapsed_total_count = 0
        
        # Determine internal parallelism for pdist calls *within this worker*
        # This decision is based on parallelize_pdist_ref (if outer parallelization is on)
        # and the MIN_ITEMS_FOR_PARALLEL_PROCESSING threshold.
        # So, if parallelize_pdist_ref is True, then we allow the internal pdist calls to be parallel
        # if they meet their own MIN_ITEMS_FOR_PARALLEL_PROCESSING threshold (which is now called MIN_ITEMS_FOR_PARALLEL_PROCESSING).
        # Otherwise, internal pdist calls are sequential.
        num_cores_for_internal_pdist = -1 if parallelize_pdist_ref else 1 

        with threadpool_limits(limits=1, user_api='blas'):
            for group_id in core_group_ids_chunk:
                members_in_core_group = initial_core_groups[group_id]
                
                if len(members_in_core_group) < 2:
                    member_label = members_in_core_group[0]
                    results_to_aggregate.append((member_label,
                                                 all_neighborhood_features_ref[member_label],
                                                 full_neighborhood_labels_map_ref[member_label],
                                                 None))
                    continue

                sub_group_vocabulary = sorted(list(set.union(*[all_neighborhood_features_ref[m] for m in members_in_core_group])))
                if not sub_group_vocabulary:
                    sub_group_assignments = {m: 1 for m in members_in_core_group}
                else:
                    sub_group_feature_to_idx = {feature: i for i, feature in enumerate(sub_group_vocabulary)}
                    num_sub_group_neighborhoods = len(members_in_core_group)
                    num_sub_group_features = len(sub_group_vocabulary)

                    sub_group_feature_vectors_lil = sp.lil_matrix((num_sub_group_neighborhoods, num_sub_group_features), dtype=np.int8)
                    for i, nh_label in enumerate(members_in_core_group):
                        for feature in all_neighborhood_features_ref[nh_label]:
                            if feature in sub_group_feature_to_idx:
                                j = sub_group_feature_to_idx[feature]
                                sub_group_feature_vectors_lil[i, j] = 1
                    sub_group_feature_vectors = sub_group_feature_vectors_lil.tocsr()

                    if num_sub_group_neighborhoods > 1 and all(
                        (sub_group_feature_vectors[0] != sub_group_feature_vectors[i]).nnz == 0
                        for i in range(1, num_sub_group_neighborhoods)
                    ):
                        sub_group_assignments = {members_in_core_group[i]: 1 for i in range(len(members_in_core_group))}
                    else:
                        sub_group_distances = parallel_pdist_jaccard(
                            sub_group_feature_vectors,
                            num_cores=num_cores_for_internal_pdist # Use the determined internal pdist cores
                        )
                        if sub_group_distances.size == 0 and num_sub_group_neighborhoods <= 1: # Added <=1 for robustness
                            sub_group_assignments = {members_in_core_group[0]: 1}
                        elif sub_group_distances.size == 0 and num_sub_group_neighborhoods > 1:
                            sub_group_assignments = {m: 1 for m in members_in_core_group}
                        else:
                            sub_group_linked = linkage(sub_group_distances, method='average')
                            sub_group_assignments = fcluster(sub_group_linked, collapse_full_neighborhood_similarity_threshold_ref, criterion='distance')
                    
                    del sub_group_feature_vectors
                    if 'sub_group_distances' in locals(): del sub_group_distances
                    if 'sub_group_linked' in locals(): del sub_group_linked
                    gc.collect()

                current_sub_groups = defaultdict(list)
                if isinstance(sub_group_assignments, np.ndarray):
                    for i, sub_cluster_id in enumerate(sub_group_assignments):
                        current_sub_groups[sub_cluster_id].append(members_in_core_group[i])
                else:
                    current_sub_groups = {1: members_in_core_group}

                for sub_cluster_id in sorted(current_sub_groups.keys()):
                    collapsed_members = current_sub_groups[sub_cluster_id]
                    
                    if len(collapsed_members) > 1:
                        local_collapsed_total_count += (len(collapsed_members) - 1)
                        representative_label = collapsed_members[0]
                        
                        union_features = set()
                        for member_label in collapsed_members:
                            union_features.update(all_neighborhood_features_ref[member_label])

                        results_to_aggregate.append((representative_label,
                                                     union_features,
                                                     full_neighborhood_labels_map_ref[representative_label],
                                                     collapsed_members))
                    else:
                        member_label = collapsed_members[0]
                        results_to_aggregate.append((member_label,
                                                     all_neighborhood_features_ref[member_label],
                                                     full_neighborhood_labels_map_ref[member_label],
                                                     None))
        return results_to_aggregate, local_collapsed_total_count


    # Prepare arguments for parallel execution
    all_core_group_ids = sorted(list(initial_core_groups.keys()))
    num_total_core_groups = len(all_core_group_ids)

    # Conditional Parallelization for Stage 2 Outer Loop 
    if num_total_core_groups < MIN_ITEMS_FOR_PARALLEL_PROCESSING or not parallelize_pdist:
        _write_and_print_internal(f"{output_prefix}  Stage 2: Running sequentially for {num_total_core_groups} core groups (below parallel threshold or parallelization disabled).")
        # Run sequentially
        results_from_workers = [
            process_core_group_chunk(
                all_core_group_ids,
                all_neighborhood_features,
                full_neighborhood_labels_map,
                collapse_full_neighborhood_similarity_threshold,
                False # Explicitly tell inner pdist to run sequentially if outer is sequential
            )
        ]
    else:
        # Determine the number of cores to use for stage 2 parallelism
        num_stage2_cores = os.cpu_count() if parallelize_pdist else 1
        if num_stage2_cores <= 0: num_stage2_cores = 1

        _write_and_print_internal(f"{output_prefix}  Stage 2: Distributing {num_total_core_groups} core groups among {num_stage2_cores} workers (parallelized).")

        # Create chunks of core group IDs to distribute
        chunk_size = max(1, num_total_core_groups // num_stage2_cores)
        group_id_chunks = [all_core_group_ids[i:i + chunk_size] for i in range(0, num_total_core_groups, chunk_size)]
        
        results_from_workers = Parallel(n_jobs=num_stage2_cores, backend="loky", verbose=100)(
            delayed(process_core_group_chunk)(
                chunk,
                all_neighborhood_features,
                full_neighborhood_labels_map,
                collapse_full_neighborhood_similarity_threshold,
                parallelize_pdist # Pass the overall parallelize_pdist flag
            ) for chunk in group_id_chunks
        )

    # --- Aggregate results from workers in the main process ---
    final_neighborhood_features = {}
    final_neighborhood_labels_map = {}
    collapsed_total_count = 0
    unique_collapsed_group_counter = 0 # Single global counter

    for worker_results, worker_collapsed_count in results_from_workers:
        collapsed_total_count += worker_collapsed_count
        for representative_label, features, original_labels_map_entry, collapsed_members in worker_results:
            if collapsed_members is not None: # This was a collapsed group
                letter_code = generate_letter_code(unique_collapsed_group_counter)
                unique_collapsed_group_counter += 1
                
                orig_organism, orig_hit_id, orig_ssn_id, orig_accession, _ = original_labels_map_entry
                final_neighborhood_labels_map[representative_label] = (orig_organism, orig_hit_id, orig_ssn_id, orig_accession, (len(collapsed_members), letter_code))
                
                collapsed_groups_report[letter_code] = {
                    'representative': representative_label,
                    'members': sorted(collapsed_members),
                    'count': len(collapsed_members)
                }
                final_neighborhood_features[representative_label] = features
            else: # Not a collapsed group, just an individual neighborhood
                final_neighborhood_features[representative_label] = features
                final_neighborhood_labels_map[representative_label] = original_labels_map_entry


    if collapsed_total_count > 0:
        _write_and_print_internal(f"{output_prefix}  Collapsed a total of {collapsed_total_count} neighborhoods into {len(final_neighborhood_features)} unique entities after two stages. Stage 2 took {time.time() - stage2_start:.2f} seconds.")
    else:
        _write_and_print_internal(f"{output_prefix}  No neighborhoods were collapsed after two stages (or disabled). Stage 2 took {time.time() - stage2_start:.2f} seconds.")

    _write_and_print_internal(f"{output_prefix}  Overall collapsing took {time.time() - collapsing_overall_start:.2f} seconds.")
    return final_neighborhood_features, final_neighborhood_labels_map, collapsed_groups_report


def _find_hit_gene_position(df_neighborhood):
    """Return the 0-based row *position* of the hit gene in one neighborhood sheet.

    A row counts as the hit gene when its ``EXCEL_COL_HIT_GENE`` cell, converted to
    text, stripped of surrounding whitespace and lower-cased, equals
    ``EXCEL_HIT_GENE_MARKER`` (case-insensitive). If several rows are marked, the
    first one in DataFrame order is used.

    Returns:
        int | None: Positional row index of the hit gene, or None if no row is marked.
    """
    marker = EXCEL_HIT_GENE_MARKER.strip().lower()
    is_marked = (
        df_neighborhood[EXCEL_COL_HIT_GENE].astype(str).str.strip().str.lower() == marker
    ).to_numpy()
    marked_positions = np.flatnonzero(is_marked)
    return int(marked_positions[0]) if marked_positions.size else None


def _extract_excel_neighborhood_features(df_neighborhood, hit_pos,
                                         hit_gene_weight_factor,
                                         direct_neighbor_weight_factor):
    """Collect the full and core feature sets for one neighborhood sheet.

    Full features:
      * the hit gene's features, weighted by ``hit_gene_weight_factor``, with the
        "HIT_" prefix;
      * every other gene's features with the "N_" prefix, weighted by
        ``direct_neighbor_weight_factor`` for the rows immediately before/after the
        hit gene and by 1 otherwise.
    Core features (used by the collapsing stage) are always weighted by 1:
      * the hit gene ("HIT_CORE_") and its direct neighbours ("N_CORE_").

    Neighbours and the hit row are identified by row *position*, so the result does
    not depend on the DataFrame's index labels or on gene IDs being unique.

    Args:
        df_neighborhood: DataFrame of one sheet (one row per gene).
        hit_pos: Positional index of the hit-gene row (see ``_find_hit_gene_position``).
        hit_gene_weight_factor: Weight passed for the hit gene's full features.
        direct_neighbor_weight_factor: Weight passed for direct neighbours' full features.

    Returns:
        tuple: (hit_gene_id, full_features, core_features) where both feature
        collections are ``set`` objects.
    """
    num_rows = len(df_neighborhood)
    hit_gene_row = df_neighborhood.iloc[hit_pos].to_dict()
    hit_gene_id = hit_gene_row[COL_GENE_ID]

    full_features = set()
    core_features = set()

    # Hit gene: weighted full features plus unweighted core features.
    full_features.update(extract_features_from_excel_gene_row(
        gene_row_dict=hit_gene_row,
        current_weight_factor=hit_gene_weight_factor,
        base_prefix="HIT_"
    ))
    core_features.update(extract_features_from_excel_gene_row(
        gene_row_dict=hit_gene_row,
        current_weight_factor=1,  # Core features are never redundantly weighted
        base_prefix="HIT_CORE_"
    ))

    # Left/right neighbours by position, clipped to the sheet boundaries.
    direct_neighbor_positions = {
        pos for pos in (hit_pos - 1, hit_pos + 1) if 0 <= pos < num_rows
    }

    for pos, (_, gene_row_series) in enumerate(df_neighborhood.iterrows()):
        if pos == hit_pos:
            continue  # The hit gene was already processed with its own weights
        gene_row_dict = gene_row_series.to_dict()
        is_direct_neighbor = pos in direct_neighbor_positions

        full_features.update(extract_features_from_excel_gene_row(
            gene_row_dict=gene_row_dict,
            current_weight_factor=direct_neighbor_weight_factor if is_direct_neighbor else 1,
            base_prefix="N_"
        ))
        if is_direct_neighbor:
            core_features.update(extract_features_from_excel_gene_row(
                gene_row_dict=gene_row_dict,
                current_weight_factor=1,  # Always 1 for core features
                base_prefix="N_CORE_"
            ))

    return hit_gene_id, full_features, core_features


def _build_binary_feature_matrix(labels, features_by_label, vocabulary, desc=None):
    """Build an int8 CSR presence/absence matrix (rows = labels, columns = vocabulary).

    Entry (i, j) is 1 when ``vocabulary[j]`` is in ``features_by_label[labels[i]]``;
    features not present in ``vocabulary`` are ignored. The matrix is assembled in one
    shot from (row, col) coordinates instead of element-wise LIL assignment. Each
    label's features form a set, so no coordinate pair is duplicated.
    """
    feature_to_idx = {feature: j for j, feature in enumerate(vocabulary)}
    row_indices = []
    col_indices = []
    for i, label in enumerate(tqdm(labels, desc=desc, leave=False)):
        for feature in features_by_label[label]:
            j = feature_to_idx.get(feature)
            if j is not None:
                row_indices.append(i)
                col_indices.append(j)
    data = np.ones(len(row_indices), dtype=np.int8)
    return sp.csr_matrix(
        (data, (np.asarray(row_indices, dtype=np.intp), np.asarray(col_indices, dtype=np.intp))),
        shape=(len(labels), len(vocabulary)),
        dtype=np.int8,
    )


def cluster_gene_neighborhoods_from_excel(
    excel_path,
    project_name,
    hit_gene_weight_factor=HIT_GENE_WEIGHT_FACTOR,
    direct_neighbor_weight_factor=DIRECT_NEIGHBOR_WEIGHT_FACTOR, 
    original_input_sequence_id=None, 
    distance_threshold=0.8,
    plot_dendrogram=True,
    save_plots=SAVE_PLOTS,
    output_dir=OUTPUT_DIR,
    output_formats=OUTPUT_FORMATS, 
    dpi=DPI,
    min_plot_height=MIN_PLOT_HEIGHT, 
    height_per_leaf=HEIGHT_PER_LEAF, 
    max_plot_height=MAX_PLOT_HEIGHT,
    min_plot_width=MIN_PLOT_WIDTH,
    width_per_leaf=WIDTH_PER_LEAF,
    max_plot_width=MAX_PLOT_WIDTH,
    report_file_handle=None, 
    parallelize_pdist=False,
    collapse_identical_neighborhoods=COLLAPSE_IDENTICAL_NEIGHBORHOODS,
    collapse_core_similarity_threshold=COLLAPSE_CORE_SIMILARITY_THRESHOLD,
    collapse_full_neighborhood_similarity_threshold=COLLAPSE_FULL_NEIGHBORHOOD_SIMILARITY_THRESHOLD
):
    """
    Cluster gene neighborhoods read from an Excel workbook.

    Each sheet is one gene neighborhood; the sheet name is used as the organism name.
    The row whose ``EXCEL_COL_HIT_GENE`` cell equals ``EXCEL_HIT_GENE_MARKER``
    (case-insensitive, surrounding whitespace ignored) is the hit gene, and the rows
    immediately before/after it are its direct neighbours. Sheets without a marked
    hit gene are skipped with a warning.

    Pipeline:
      1. Read sheets (``read_excel_gene_neighborhoods``) and extract weighted "full"
         feature sets and unweighted "core" feature sets per neighborhood.
      2. Optionally collapse similar neighborhoods (``_perform_collapsing``).
      3. Cluster all remaining neighborhoods together: distances from
         ``parallel_pdist_jaccard``, average linkage, ``fcluster`` with
         ``criterion='distance'`` at ``distance_threshold``. Dendrograms (organism
         labels and ID labels) are drawn when ``plot_dendrogram`` is True.

    Every progress message is printed and, if ``report_file_handle`` is given, also
    written to it.

    Returns:
        tuple: (clusters_dict, final_neighborhood_labels_map, collapsed_groups_report).
            clusters_dict maps 'All_Excel_Neighborhoods' to a dict of
            cluster_id -> list of unique neighborhood labels (possibly representatives
            of collapsed groups). An empty dict is returned when nothing could be read.
            final_neighborhood_labels_map maps each unique neighborhood label to
            (organism_name, hit_gene_id, ssn_cluster_id (dummy), accession_id (gene_id),
            collapsed_members_info).
            collapsed_groups_report details the collapsed groups (empty if collapsing
            is disabled or nothing was collapsed).
    """
    def _write_and_print_internal(text):
        print(text)
        if report_file_handle:
            report_file_handle.write(text + '\n')

    start_time_overall = time.time()

    # --- Step 1: Read all data from Excel ---
    excel_neighborhood_data = read_excel_gene_neighborhoods(excel_path, report_file=report_file_handle)

    if not excel_neighborhood_data:
        _write_and_print_internal("No gene neighborhoods found in the Excel file. Exiting.")
        return {}, {}, {}

    # --- Step 2: Feature extraction (one neighborhood per sheet) ---
    all_neighborhood_features = defaultdict(set)
    core_neighborhood_features = defaultdict(set)
    # Maps unique_neighborhood_label to
    # (organism_name, hit_gene_id, ssn_cluster_id (dummy), accession_id (gene_id), collapsed_info)
    full_neighborhood_labels_map = {}
    # Excel data has no SSN clusters: a single dummy ID is used because all
    # neighborhoods are clustered together.
    ssn_cluster_id = 'Excel_Data_Cluster'

    _write_and_print_internal(f"Processing features for {len(excel_neighborhood_data)} gene neighborhoods from Excel sheets...")
    feature_extraction_start_time = time.time()

    for sheet_name, df_neighborhood in tqdm(excel_neighborhood_data.items(), desc="Extracting features from Excel neighborhoods", unit="neighborhood"):
        organism_name = sheet_name  # Worksheet name serves as the organism/neighborhood ID

        hit_pos = _find_hit_gene_position(df_neighborhood)
        if hit_pos is None:
            _write_and_print_internal(f"  Warning: No 'hit_gene' marked 'yes' in sheet '{sheet_name}'. Skipping this neighborhood.")
            continue

        hit_gene_id, current_full_features, current_core_features = _extract_excel_neighborhood_features(
            df_neighborhood, hit_pos, hit_gene_weight_factor, direct_neighbor_weight_factor
        )

        # No explicit accession ID in Excel data: the gene ID doubles as the
        # accession ID used for highlighting in plots.
        accession_id = hit_gene_id
        unique_neighborhood_label = f"{organism_name}_{hit_gene_id}"

        full_neighborhood_labels_map[unique_neighborhood_label] = (organism_name, hit_gene_id, ssn_cluster_id, accession_id, None)
        all_neighborhood_features[unique_neighborhood_label].update(current_full_features)
        core_neighborhood_features[unique_neighborhood_label].update(current_core_features)

    _write_and_print_internal(f"Finished feature extraction in {time.time() - feature_extraction_start_time:.2f} seconds.")
    del excel_neighborhood_data  # Free memory of raw Excel data
    gc.collect()

    if not all_neighborhood_features:
        _write_and_print_internal("No gene neighborhoods with valid hit genes and features found. Exiting.")
        return {}, {}, {}

    # The union of all feature sets is empty exactly when every set is empty.
    if not any(all_neighborhood_features.values()):
        _write_and_print_internal("No significant features extracted from any neighborhood. Cannot cluster.")
        # Return a dummy cluster so callers expecting this structure keep working.
        return {'All_Excel_Neighborhoods': {1: list(all_neighborhood_features.keys())}}, full_neighborhood_labels_map, {}

    # --- Step 3: Optional pre-clustering (collapsing) of similar neighborhoods ---
    if collapse_identical_neighborhoods:
        _write_and_print_internal(f"\nPerforming collapsing of similar neighborhoods (Enabled).")
        final_neighborhood_features, final_neighborhood_labels_map, collapsed_groups_report = \
            _perform_collapsing(all_neighborhood_features, full_neighborhood_labels_map, 
                                core_neighborhood_features,
                                collapse_core_similarity_threshold, collapse_full_neighborhood_similarity_threshold,
                                output_prefix="  [Collapsing]", 
                                report_file=report_file_handle, 
                                parallelize_pdist=parallelize_pdist)
    else:
        _write_and_print_internal("\nCollapsing identical/similar neighborhoods is disabled. Proceeding with all original neighborhoods.")
        final_neighborhood_features = all_neighborhood_features
        final_neighborhood_labels_map = full_neighborhood_labels_map
        collapsed_groups_report = {}

    # Drop references to the pre-collapse feature dictionaries
    del all_neighborhood_features
    del core_neighborhood_features
    gc.collect()

    # Re-check after collapsing; also safe if collapsing returned no neighborhoods
    # (set.union(*[]) would raise TypeError, any([]) does not).
    if not any(final_neighborhood_features.values()):
        _write_and_print_internal("No significant features extracted from final neighborhoods after collapsing. Cannot cluster.")
        return {'All_Excel_Neighborhoods': {1: list(final_neighborhood_features.keys())}}, final_neighborhood_labels_map, collapsed_groups_report

    # --- Step 4: Cluster all Excel neighborhoods together ---
    # `clusters_output_dict` is keyed like an SSN ID for compatibility with SSN-based callers.
    ssn_id_for_clustering = 'All_Excel_Neighborhoods'
    neighborhood_labels_to_cluster = sorted(final_neighborhood_features)

    clusters_output_dict = defaultdict(dict)  # Outer dict for SSN, inner for clusters within SSN

    _write_and_print_internal(f"\n--- Processing all Excel neighborhoods ({len(neighborhood_labels_to_cluster)} neighborhoods) ---")
    plot_title_prefix = f"Clustering by Gene Neighborhoods for {project_name}"

    num_neighborhoods_in_group = len(neighborhood_labels_to_cluster)
    if num_neighborhoods_in_group < 2:
        _write_and_print_internal(f"  Skipping group '{ssn_id_for_clustering}': Not enough distinct neighborhoods ({num_neighborhoods_in_group}) for clustering. Requires at least 2.")
        # If less than 2, assign to a single cluster ID 1
        clusters_output_dict[ssn_id_for_clustering] = {1: neighborhood_labels_to_cluster}
    else:
        # Non-empty: the any(...) check above guarantees at least one feature.
        current_vocabulary = sorted(set().union(
            *(final_neighborhood_features[label] for label in neighborhood_labels_to_cluster)
        ))

        feature_vector_creation_start = time.time()
        feature_vectors_np = _build_binary_feature_matrix(
            neighborhood_labels_to_cluster,
            final_neighborhood_features,
            current_vocabulary,
            desc=f"  Populating features for {ssn_id_for_clustering}",
        )
        _write_and_print_internal(f"  Feature vector creation for {num_neighborhoods_in_group} neighborhoods ({len(current_vocabulary)} features) took {time.time() - feature_vector_creation_start:.2f} seconds.") 
        _write_and_print_internal(f"  Matrix shape: {feature_vectors_np.shape}, NNZ: {feature_vectors_np.nnz}") 
        gc.collect()

        # All matrix rows are identical exactly when all feature sets are equal,
        # because the vocabulary is the union of those sets. Linkage on all-zero
        # distances is meaningless, so that case is handled separately.
        first_features = final_neighborhood_features[neighborhood_labels_to_cluster[0]]
        all_identical = all(
            final_neighborhood_features[label] == first_features
            for label in neighborhood_labels_to_cluster[1:]
        )
        if all_identical:
            _write_and_print_internal(f"  All neighborhoods in {plot_title_prefix} have identical features. No meaningful distance calculated. Skipping plotting.")
            clusters_output_dict[ssn_id_for_clustering] = {1: neighborhood_labels_to_cluster}
        else:
            distance_calc_start = time.time()
            _write_and_print_internal(f"  Calculating distances for {num_neighborhoods_in_group} neighborhoods...") 
            distances = parallel_pdist_jaccard(feature_vectors_np, num_cores=-1 if parallelize_pdist else 1) 
            _write_and_print_internal(f"  Distance calculation took {time.time() - distance_calc_start:.2f} seconds.") 
            del feature_vectors_np  # No longer needed once distances exist

            linkage_start = time.time()
            _write_and_print_internal(f"  Performing linkage for {num_neighborhoods_in_group} neighborhoods...") 
            linked = linkage(distances, method='average')
            _write_and_print_internal(f"  Linkage calculation took {time.time() - linkage_start:.2f} seconds.") 
            del distances
            gc.collect()

            if plot_dendrogram:
                # Timer starts here so the reported time covers plotting only
                # (previously it also included distance and linkage computation).
                plot_start = time.time()
                _write_and_print_internal(f"  Generating dendrogram plots (Organism Labels)...") 
                _plot_dendrogram(linked, neighborhood_labels_to_cluster, final_neighborhood_labels_map, distance_threshold, 
                                 f"{plot_title_prefix}", 'organism', 
                                 original_input_sequence_id,  # Matched against accession_id (the gene_id for Excel)
                                 save_plots, output_dir, output_formats, dpi, 
                                 min_plot_height, height_per_leaf, max_plot_height,
                                 min_plot_width, width_per_leaf, max_plot_width,
                                 report_file=report_file_handle)
                _write_and_print_internal(f"  Generating dendrogram plots (ID Labels)...") 
                _plot_dendrogram(linked, neighborhood_labels_to_cluster, final_neighborhood_labels_map, distance_threshold, 
                                 f"{plot_title_prefix}", 'id', 
                                 original_input_sequence_id, 
                                 save_plots, output_dir, output_formats, dpi, 
                                 min_plot_height, height_per_leaf, max_plot_height,
                                 min_plot_width, width_per_leaf, max_plot_width,
                                 report_file=report_file_handle)
                _write_and_print_internal(f"  Plotting took {time.time() - plot_start:.2f} seconds.") 

            cluster_assignments = fcluster(linked, distance_threshold, criterion='distance')
            current_clusters = defaultdict(list)
            for label, cluster_id in zip(neighborhood_labels_to_cluster, cluster_assignments):
                current_clusters[cluster_id].append(label)

            clusters_output_dict[ssn_id_for_clustering] = current_clusters
            _write_and_print_internal(f"--- Finished processing Excel neighborhoods in {time.time() - start_time_overall:.2f} seconds. ---") 
            del linked  # Free memory
            gc.collect()

    _write_and_print_internal(f"\nTotal runtime: {time.time() - start_time_overall:.2f} seconds.")
    return clusters_output_dict, final_neighborhood_labels_map, collapsed_groups_report


Set OMP_NUM_THREADS to 12 for SciPy/NumPy internal multi-threading.


In [36]:
EXCEL_FILE_PATH = 'PEP project_phosphonate gene clusters-20160423.xlsx'
project_name = 'PEP gene clusters'

ORIGINAL_INPUT_SEQUENCE_ID = None 

# Column names in the Excel sheets (re-declared here so this cell can override them;
# the clustering functions read these module-level names at call time)
EXCEL_COL_LENGTH = '长度(aa)'
EXCEL_COL_PFAM = 'pfam'
EXCEL_COL_HIT_GENE = 'hit_gene'
EXCEL_HIT_GENE_MARKER = 'yes' # Value indicating the hit gene

# You can change this setting here or keep the global config
chosen_distance_threshold = 0.5         # Change as needed
PARALLELIZE_PDIST_ENABLED = True        # Set to True to enable parallel pdist.

# Configuration for collapsing similar neighborhoods (Local Override)
COLLAPSE_IDENTICAL_NEIGHBORHOODS_ACTIVE = False 
COLLAPSE_CORE_SIMILARITY_THRESHOLD_ACTIVE = 0.0 # Use 0.0 for exact match of hit+direct neighbors
COLLAPSE_FULL_NEIGHBORHOOD_SIMILARITY_THRESHOLD_ACTIVE = 0.3 # e.g., 0.3 for 70% similarity of full neighborhood

# Prepare report file
report_suffix = "_excel_input" # Changed suffix for Excel input
report_filename = f"{REPORT_FILENAME_BASE}{report_suffix}.txt"
report_path = os.path.join(OUTPUT_DIR, report_filename)

# Ensure the output directory exists for report and plots (no-op if already present,
# and free of the check-then-create race of os.path.exists + os.makedirs).
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Explicit UTF-8: report lines can contain non-ASCII text (e.g. sheet/organism names),
# which the platform default encoding (e.g. cp1252 on Windows) may fail to encode.
with open(report_path, 'w', encoding='utf-8') as report_file:
    # Print to stdout and mirror every line into the report file.
    def write_and_print_to_file(text):
        print(text)
        report_file.write(text + '\n')

    # --- Report header: run configuration ---
    write_and_print_to_file(f"\n--- GNN Clustering Report ({datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}) ---") 
    write_and_print_to_file(f"Data Source: Excel File ({EXCEL_FILE_PATH})") 
    write_and_print_to_file(f"Jaccard Distance Threshold: {chosen_distance_threshold}")
    write_and_print_to_file(f"Hit Gene Weight Factor: {HIT_GENE_WEIGHT_FACTOR}")
    write_and_print_to_file(f"Direct Neighbor Weight Factor: {DIRECT_NEIGHBOR_WEIGHT_FACTOR}")

    # SSN differentiation is not applicable for Excel input
    write_and_print_to_file("Clustering all Excel neighborhoods together (SSN cluster differentiation not applicable for this data format).")

    if COLLAPSE_IDENTICAL_NEIGHBORHOODS_ACTIVE:
        write_and_print_to_file(f"Collapsing identical/similar neighborhoods enabled:")
        write_and_print_to_file(f"  Stage 1 (Hit+Direct Neighbor Core): Threshold {COLLAPSE_CORE_SIMILARITY_THRESHOLD_ACTIVE}")
        write_and_print_to_file(f"  Stage 2 (Full Neighborhood): Threshold {COLLAPSE_FULL_NEIGHBORHOOD_SIMILARITY_THRESHOLD_ACTIVE}")
    else:
        write_and_print_to_file("Collapsing identical/similar neighborhoods disabled.")

    write_and_print_to_file(f"Distance calculation parallelism: {'Enabled (via joblib)' if PARALLELIZE_PDIST_ENABLED else 'Disabled (sequential custom Jaccard)'}")
    write_and_print_to_file(f"SciPy/NumPy internal parallelism (OMP_NUM_THREADS): {os.environ.get('OMP_NUM_THREADS', 'Not set (defaults will apply)')}")

    if ORIGINAL_INPUT_SEQUENCE_ID:
        write_and_print_to_file(f"Original Input Sequence Gene ID for highlighting: '{ORIGINAL_INPUT_SEQUENCE_ID}' (colored '{HIGHLIGHT_COLOR}')")
    else:
        write_and_print_to_file("No specific original input sequence ID provided for highlighting.")
    write_and_print_to_file(f"Plots saved to: {OUTPUT_DIR} in {OUTPUT_FORMATS} formats at {DPI} DPI.")
    write_and_print_to_file(f"Report also saved to: {report_path}")
    write_and_print_to_file("-" * 70)

    # --- Run the clustering ---
    clusters_by_ssn, final_labels_map, collapsed_groups_report = cluster_gene_neighborhoods_from_excel(
                                            excel_path=EXCEL_FILE_PATH,
                                            project_name=project_name,
                                            hit_gene_weight_factor=HIT_GENE_WEIGHT_FACTOR,
                                            direct_neighbor_weight_factor=DIRECT_NEIGHBOR_WEIGHT_FACTOR,
                                            original_input_sequence_id=ORIGINAL_INPUT_SEQUENCE_ID,
                                            distance_threshold=chosen_distance_threshold,
                                            plot_dendrogram=True,
                                            save_plots=SAVE_PLOTS,
                                            output_dir=OUTPUT_DIR,
                                            output_formats=OUTPUT_FORMATS, 
                                            dpi=DPI,
                                            min_plot_height=MIN_PLOT_HEIGHT,
                                            height_per_leaf=HEIGHT_PER_LEAF, 
                                            max_plot_height=MAX_PLOT_HEIGHT,
                                            min_plot_width=MIN_PLOT_WIDTH, 
                                            width_per_leaf=WIDTH_PER_LEAF, 
                                            max_plot_width=MAX_PLOT_WIDTH,
                                            report_file_handle=report_file,
                                            parallelize_pdist=PARALLELIZE_PDIST_ENABLED,
                                            collapse_identical_neighborhoods=COLLAPSE_IDENTICAL_NEIGHBORHOODS_ACTIVE,
                                            collapse_core_similarity_threshold=COLLAPSE_CORE_SIMILARITY_THRESHOLD_ACTIVE,
                                            collapse_full_neighborhood_similarity_threshold=COLLAPSE_FULL_NEIGHBORHOOD_SIMILARITY_THRESHOLD_ACTIVE
                                        )

    # --- Report body: cluster membership ---
    unknown_entry = ('UNKNOWN', 'UNKNOWN', None, 'UNKNOWN', None)
    if clusters_by_ssn:
        write_and_print_to_file("\n--- Final Clustering Results ---")
        # For Excel input there is typically only one group, 'All_Excel_Neighborhoods'
        for ssn_id, clusters_in_ssn in sorted(clusters_by_ssn.items(), key=lambda item: str(item[0])):
            write_and_print_to_file(f"\n### Results for Group: {ssn_id} ###")
            if not clusters_in_ssn:
                write_and_print_to_file("  No clusters formed for this group, or insufficient data.")
                continue

            for cluster_id, neighborhoods_in_cluster in sorted(clusters_in_ssn.items()):
                write_and_print_to_file(f"  Cluster {cluster_id}: {len(neighborhoods_in_cluster)} neighborhoods")
                for nh_id in neighborhoods_in_cluster:
                    organism_name, hit_id_internal, _, accession_id, collapsed_info = final_labels_map.get(nh_id, unknown_entry)

                    # Highlight based on accession_id (the gene_id for Excel); never match when no ID was given
                    is_original_input = ORIGINAL_INPUT_SEQUENCE_ID is not None and accession_id == ORIGINAL_INPUT_SEQUENCE_ID
                    highlight_indicator = " (ORIGINAL INPUT)" if is_original_input else ""
                    collapsed_suffix = ""
                    if collapsed_info:
                        count, letter_code = collapsed_info
                        collapsed_suffix = f" (Collapsed: {count} neighborhoods, Ref: {letter_code})"

                    write_and_print_to_file(f"    - Organism: {organism_name}, Hit Gene ID: {accession_id}{highlight_indicator}{collapsed_suffix} (Internal NH ID: {nh_id})")
            write_and_print_to_file("  " + "-" * 30)

        if collapsed_groups_report:
            write_and_print_to_file("\n--- Detailed Report on Collapsed Neighborhood Groups ---")
            for code, group_data in sorted(collapsed_groups_report.items()):
                write_and_print_to_file(f"  Group ({code}): Representative: {group_data['representative']} (Total: {group_data['count']} members)")
                for member_nh_id in group_data['members']:
                    # NOTE(review): the collapsing stage only stores representatives in the
                    # returned labels map, so non-representative members print as UNKNOWN here.
                    member_organism, member_hit_id, _, member_accession, _ = final_labels_map.get(member_nh_id, unknown_entry)
                    write_and_print_to_file(f"    - {member_organism} (Gene ID: {member_accession}) (Internal Hit ID: {member_hit_id}) (NH ID: {member_nh_id})")
            write_and_print_to_file("-------------------------------------------------------")
    else:
        write_and_print_to_file("\nNo clusters formed at all. This could mean your Excel file is empty, or no features were extracted after filtering, or not enough valid neighborhoods for clustering were found.")

    write_and_print_to_file("\n--- Report End ---")


--- GNN Clustering Report (2025-09-11 22:17:45) ---
Data Source: Excel File (PEP project_phosphonate gene clusters-20160423.xlsx)
Jaccard Distance Threshold: 0.5
Hit Gene Weight Factor: 10
Direct Neighbor Weight Factor: 3
Clustering all Excel neighborhoods together (SSN cluster differentiation not applicable for this data format).
Collapsing identical/similar neighborhoods disabled.
Distance calculation parallelism: Enabled (via joblib)
SciPy/NumPy internal parallelism (OMP_NUM_THREADS): 12
No specific original input sequence ID provided for highlighting.
Plots saved to: gnn_cluster_plots_PEP in ['pdf'] formats at 300 DPI.
Report also saved to: gnn_cluster_plots_PEP\gnn_clustering_report_excel_input.txt
----------------------------------------------------------------------
Reading Excel file: PEP project_phosphonate gene clusters-20160423.xlsx


Processing Excel sheets:   6%|▌         | 1/18 [00:00<00:02,  6.59it/s]

  Processing sheet: NC5 LS110018 cluster38
  Processing sheet: NC3 N272 cluster71


Processing Excel sheets:  17%|█▋        | 3/18 [00:00<00:02,  6.25it/s]

  Processing sheet: NC9 LS1801 cluster 63
  Processing sheet: 182 LS2039-cluster 17
  Processing sheet: NC31-LS130724-scaf1-no cluster
  Processing sheet: 185-LS784-cluster 50
  Processing sheet: 131-LS477-cluster 26


Processing Excel sheets:  61%|██████    | 11/18 [00:00<00:00, 21.07it/s]

  Processing sheet: NC10 LS130053
  Processing sheet: NC24 LS132251
  Processing sheet: NC31 LS130724
  Processing sheet: NC4-LS2542-cluster100
  Processing sheet: 71-LS130084-cluster 114


Processing Excel sheets: 100%|██████████| 18/18 [00:01<00:00, 17.15it/s]


  Processing sheet: 71-LS130084-cluster 128
  Processing sheet: 78-LS131321-cluster 37
  Processing sheet: 170-LS795-cluster 72
  Processing sheet: 177-LS120054-no cluster-7
  Processing sheet: NC30-LS131440-sca33-no cluster
  Processing sheet: Sheet1
Successfully read data from 16 sheets.
Processing features for 16 gene neighborhoods from Excel sheets...


Extracting features from Excel neighborhoods: 100%|██████████| 16/16 [00:00<00:00, 54.86neighborhood/s]


Finished feature extraction in 0.30 seconds.

Collapsing identical/similar neighborhoods is disabled. Proceeding with all original neighborhoods.

--- Processing all Excel neighborhoods (16 neighborhoods) ---


                                                                                         

  Feature vector creation for 16 neighborhoods (502 features) took 0.03 seconds.
  Matrix shape: (16, 502), NNZ: 903
  Calculating distances for 16 neighborhoods...
  Pre-calculating 16 feature sets...


                                                                           

  Feature set pre-calculation took 0.01 seconds.
  Running Jaccard distance sequentially for N=16 (below parallel threshold or num_cores=1).
  Distance calculation took 0.26 seconds.
  Performing linkage for 16 neighborhoods...
  Linkage calculation took 0.00 seconds.
  Generating dendrogram plots (Organism Labels)...
  Generating dendrogram plots (ID Labels)...
  Plotting took 2.89 seconds.
--- Finished processing Excel neighborhoods in 12.33 seconds. ---

Total runtime: 12.66 seconds.

--- Final Clustering Results ---

### Results for Group: All_Excel_Neighborhoods ###
  Cluster 1: 1 neighborhoods
    - Organism: NC24 LS132251, Hit Gene ID: ctg4_81 (Internal NH ID: NC24 LS132251_ctg4_81)
  Cluster 2: 1 neighborhoods
    - Organism: NC9 LS1801 cluster 63, Hit Gene ID: ctg8_386 (Internal NH ID: NC9 LS1801 cluster 63_ctg8_386)
  Cluster 3: 1 neighborhoods
    - Organism: NC10 LS130053, Hit Gene ID: ctg4_452 (Internal NH ID: NC10 LS130053_ctg4_452)
  Cluster 4: 1 neighborhoods
    - Orga