# Toxoplasma gondii Haplogroup Analysis

This notebook analyzes SNP data from Toxoplasma gondii strains to identify haplogroups based on allele patterns across the genome.

## 1. Setup & Data Loading

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.patches as mpatches # For custom legends
from collections import Counter
import time # To time potentially long operations

# Input and output locations used throughout the notebook
snp_file = 'snp.merge.txt'
clade_file = 'Strain_clades.xlsx'
genome_file = 'genome_table_ME49.txt'
output_plot_file = 'haplogroups.pdf'

# --- SNP table -------------------------------------------------------------
# low_memory=False asks pandas not to infer dtypes chunk by chunk, which
# avoids mixed-type columns in large SNP tables.
# On any failure snp_data is set to None so later cells can skip themselves.
print(f"Loading SNP data from {snp_file}...")
try:
    snp_data = pd.read_csv(snp_file, sep='\t', low_memory=False)
    print("SNP data loaded successfully.")
    print(f"Shape: {snp_data.shape}")
    # The first two columns hold the SNP coordinates, whatever their header says
    coordinate_names = {
        snp_data.columns[0]: 'Chromosome',
        snp_data.columns[1]: 'Position',
    }
    snp_data.rename(columns=coordinate_names, inplace=True)
    # Unparseable positions become NaN, those rows are discarded, the rest is cast to int
    snp_data['Position'] = pd.to_numeric(snp_data['Position'], errors='coerce')
    snp_data.dropna(subset=['Position'], inplace=True)
    snp_data['Position'] = snp_data['Position'].astype(int)
    print("First 5 rows of SNP data:")
    print(snp_data.head())
except Exception as e:
    # A missing file gets its own message; everything else reports the exception text
    if isinstance(e, FileNotFoundError):
        print(f"Error: File not found at {snp_file}")
    else:
        print(f"Error loading {snp_file}: {e}")
    snp_data = None

# Load Clade information
# Reads the strain -> clade table from the 'Sheet1' worksheet. On any failure
# (missing file, missing columns, parse error) clade_info is set to None so
# that later cells can detect it and skip themselves.
print(f"\nLoading Clade data from {clade_file}...")
try:
    # Using the confirmed sheet name and column names
    clade_info = pd.read_excel(clade_file, sheet_name='Sheet1')
    # Ensure the specified columns exist
    if 'StrainID' in clade_info.columns and 'Clade' in clade_info.columns:
        print("Clade data loaded successfully.")
        print(f"Shape: {clade_info.shape}")
        # Drop incomplete rows BEFORE casting to str: astype(str) turns a missing
        # value into the literal string 'nan', which would otherwise be treated
        # as a genuine strain ID / clade name downstream.
        incomplete_rows = clade_info[['StrainID', 'Clade']].isna().any(axis=1)
        if incomplete_rows.any():
            print(f"Warning: Dropping {int(incomplete_rows.sum())} clade rows with a missing StrainID or Clade.")
            clade_info = clade_info.loc[~incomplete_rows].copy()
        # Convert StrainID to string for reliable comparison with SNP column names
        # NOTE(review): numeric IDs that Excel stores as floats would become e.g. '12.0' - confirm IDs are text
        clade_info['StrainID'] = clade_info['StrainID'].astype(str)
        # Standardize Clade names (convert to string, remove surrounding whitespace)
        clade_info['Clade'] = clade_info['Clade'].astype(str).str.strip()
        print("First 5 rows of Clade data:")
        print(clade_info.head())
    else:
        print(f"Error: Required columns 'StrainID' or 'Clade' not found in {clade_file}, Sheet1.")
        print(f"Available columns: {clade_info.columns.tolist()}")
        clade_info = None
except FileNotFoundError:
    print(f"Error: File not found at {clade_file}")
    clade_info = None
except Exception as e:
    print(f"Error loading {clade_file}: {e}")
    clade_info = None

# --- Genome table ----------------------------------------------------------
# Headerless, tab-separated file read as two columns: chromosome name and length.
# genome_info is set to None on failure so the binning cell can skip itself.
print(f"\nLoading Genome table from {genome_file}...")
try:
    genome_column_names = ['Chromosome', 'Length']
    genome_info = pd.read_csv(genome_file, sep='\t', header=None, names=genome_column_names)
    print("Genome table loaded successfully.")
    print(f"Shape: {genome_info.shape}")
    print("First 5 rows of Genome table:")
    print(genome_info.head())
except Exception as e:
    # A missing file gets its own message; everything else reports the exception text
    if isinstance(e, FileNotFoundError):
        print(f"Error: File not found at {genome_file}")
    else:
        print(f"Error loading {genome_file}: {e}")
    genome_info = None


## 2. Data Preprocessing & Merging

Here we preprocess the data:
- Identify common strains between SNP data and Clade info.
- Filter SNP data to include only these common strains.
- Transpose the SNP data so strains are rows and SNPs are columns.
- Add Clade information to the transposed data.

In [None]:
# Strain x SNP matrix (strains as rows, (Chromosome, Position) MultiIndex as
# columns) with the clade label in the first column; None until built.
processed_snp_data = None
snp_positions_df = None # To store original SNP positions

if snp_data is not None and clade_info is not None:
    print("\nStarting Data Preprocessing and Merging...")

    # Identify strain columns in snp_data (all columns except the first two)
    strain_columns_snp = snp_data.columns[2:].tolist()
    print(f"Found {len(strain_columns_snp)} strain columns in SNP data.")

    # Get strain IDs from clade_info
    strain_ids_clade = set(clade_info['StrainID'].astype(str).unique())
    print(f"Found {len(strain_ids_clade)} unique strain IDs in Clade data.")

    # Find common strains. Keep the SNP-file column order instead of going
    # through a set intersection: set iteration order for strings changes
    # between interpreter runs, which made the row order of every downstream
    # table (and of the final heatmap within a clade) non-reproducible.
    common_strains = [strain for strain in strain_columns_snp if strain in strain_ids_clade]
    print(f"Found {len(common_strains)} common strains between SNP and Clade data.")

    # Report discrepancies (deterministic order so repeated runs print the same sample)
    snp_strain_set = set(strain_columns_snp)
    strains_only_in_snp = [strain for strain in strain_columns_snp if strain not in strain_ids_clade]
    strains_only_in_clade = sorted(strain_ids_clade - snp_strain_set)
    if strains_only_in_snp:
        print(f"Warning: {len(strains_only_in_snp)} strains found only in SNP data (will be excluded): {strains_only_in_snp[:10]}...")
    if strains_only_in_clade:
        print(f"Warning: {len(strains_only_in_clade)} strains found only in Clade data: {strains_only_in_clade[:10]}...")

    if not common_strains:
        print("Error: No common strains found between SNP data and Clade info. Cannot proceed.")
    else:
        # Store original SNP positions before filtering columns
        snp_positions_df = snp_data[['Chromosome', 'Position']].copy()
        snp_positions_df.set_index(['Chromosome', 'Position'], inplace=True)

        # Filter snp_data to keep only common strains and identifier columns
        snp_data_filtered = snp_data[['Chromosome', 'Position'] + common_strains].copy()

        # Set Chromosome and Position as index
        snp_data_filtered.set_index(['Chromosome', 'Position'], inplace=True)

        # Transpose the DataFrame: strains become rows, SNPs become columns
        print("\nTransposing SNP data...")
        processed_snp_data = snp_data_filtered.transpose()
        print(f"Transposed SNP data shape: {processed_snp_data.shape}")

        # Create a mapping from StrainID to Clade
        # (if a StrainID is listed more than once, the last row wins)
        clade_map = clade_info.set_index('StrainID')['Clade'].to_dict()

        # Add Clade information to the transposed data.
        # With MultiIndex columns pandas stores this column as ('Clade', '').
        processed_snp_data['Clade'] = processed_snp_data.index.map(clade_map)

        # Reorder columns to put 'Clade' first
        clade_column = ('Clade', '')
        cols = [clade_column] + [col for col in processed_snp_data.columns if col != clade_column]
        processed_snp_data = processed_snp_data[cols]

        # Check for strains that didn't get a clade assigned (shouldn't happen with common_strains logic)
        clade_missing_mask = processed_snp_data[clade_column].isna()
        if clade_missing_mask.any():
            missing_strains = processed_snp_data.index[clade_missing_mask].tolist()
            print(f"Warning: {len(missing_strains)} strains could not be mapped to a clade:")
            print(missing_strains)
            # Remove them with a boolean mask on the exact ('Clade', '') column
            # rather than dropna(subset=['Clade']), whose single-label lookup is
            # not guaranteed to resolve against MultiIndex columns.
            processed_snp_data = processed_snp_data.loc[~clade_missing_mask]
            print(f"Removed {len(missing_strains)} strains with missing clades.")

        print("\nPreprocessing and merging complete.")
        print("First 5 rows of processed SNP data (Strains as rows, SNPs as columns):")
        # Displaying only first few SNP columns for brevity
        print(processed_snp_data.iloc[:, :6].head())

else:
    print("\nSkipping Preprocessing: Required data (SNP or Clade) was not loaded successfully.")


## 3. Genome Binning

Divide the genome into non-overlapping 100kb bins based on the chromosome lengths. Calculate cumulative positions for plotting the concatenated genome.

In [None]:
# Outputs of this cell; they stay None when the genome table is unavailable.
genome_bins = None
chromosome_boundaries = None
chromosome_map = None
chromosome_midpoints = None

if genome_info is None:
    print("\nSkipping Genome Binning: Genome information table was not loaded successfully.")
else:
    print("\nStarting Genome Binning...")
    bin_size = 100000
    bins_list = []
    cumulative_pos = 0
    chromosome_boundaries = [0] # Cumulative offsets where chromosomes start/end; first one is 0
    chromosome_map = {}
    chromosome_midpoints = {}

    # Coerce lengths to integers, discarding rows whose length cannot be parsed
    genome_info['Length'] = pd.to_numeric(genome_info['Length'], errors='coerce')
    genome_info.dropna(subset=['Length'], inplace=True)
    genome_info['Length'] = genome_info['Length'].astype(int)

    # Chromosomes are laid out in this explicit order (not alphabetically);
    # names missing from the list map to NaN in the sort key.
    chromosome_order = ['TGME49_chrIa', 'TGME49_chrIb', 'TGME49_chrII', 'TGME49_chrIII', 'TGME49_chrIV', 'TGME49_chrV', 'TGME49_chrVI', 'TGME49_chrVIIa', 'TGME49_chrVIIb','TGME49_chrVIII', 'TGME49_chrIX', 'TGME49_chrX', 'TGME49_chrXI', 'TGME49_chrXII']
    chromosome_rank = {name: rank for rank, name in enumerate(chromosome_order)}
    genome_info_sorted = genome_info.sort_values(
        by='Chromosome',
        key=lambda names: names.map(chromosome_rank),
    ).reset_index(drop=True)
    print(f"Processing {len(genome_info_sorted)} chromosomes for binning.")

    for chromosome, length in zip(genome_info_sorted['Chromosome'], genome_info_sorted['Length']):
        chromosome_map[chromosome] = {'start_cumulative': cumulative_pos, 'length': length}
        chromosome_midpoints[chromosome] = cumulative_pos + length / 2

        # Walk the chromosome in bin_size steps; the last bin is clipped to the chromosome end
        for offset in range(0, length, bin_size):
            clipped_end = min(offset + bin_size, length)
            bins_list.append({
                'BinID': f"{chromosome}_{offset+1}-{clipped_end}",
                'Chromosome': chromosome,
                'BinStart': offset + 1, # 1-based start coordinate
                'BinEnd': clipped_end,
                # Position of the bin on the concatenated genome
                'CumulativeStart': cumulative_pos + offset,
                'CumulativeEnd': cumulative_pos + clipped_end,
            })

        # The next chromosome starts where this one ends
        cumulative_pos += length
        chromosome_boundaries.append(cumulative_pos)

    genome_bins = pd.DataFrame(bins_list)
    print(f"\nCreated {len(genome_bins)} genomic bins.")
    print("First 5 rows of genome_bins:")
    print(genome_bins.head())
    print("\nLast 5 rows of genome_bins:")
    print(genome_bins.tail())
    print(f"\nChromosome boundaries (cumulative positions): {chromosome_boundaries}")
    print(f"Total concatenated genome length: {cumulative_pos}")


## 4. Clade Allele Profile Identification (Per Bin)

For each 100kb bin, identify the SNPs within it. Then, for each clade (A-F), determine the representative allele profile based on the **most frequent allele** at each SNP position among strains belonging to that clade.

In [None]:
# Builds, for every genomic bin that contains SNPs, one representative allele
# profile per clade.
#   clade_profiles: bin_id -> {clade -> pandas Series of the most frequent allele per SNP}
#   snps_in_bins:   bin_id -> list of SNP column labels falling inside the bin
#   unique_clades:  sorted clade labels found in processed_snp_data
clade_profiles = {}
snps_in_bins = {} # Optional: store which SNPs fall in which bin
unique_clades = []

if processed_snp_data is not None and genome_bins is not None:
    print("\nStarting Clade Allele Profile Identification...")
    start_time_profiles = time.time()
    
    # Get unique clades from the processed data (ensure no NaN)
    unique_clades = sorted(processed_snp_data['Clade'].dropna().unique())
    print(f"Identifying profiles for clades: {unique_clades}")
    
    # Ensure the SNP columns in processed_snp_data have a MultiIndex
    if not isinstance(processed_snp_data.columns, pd.MultiIndex):
        # NOTE: despite the "Cannot proceed" wording, a recovery is attempted right below.
        print("Error: Columns of processed_snp_data are not a MultiIndex (Chromosome, Position). Cannot proceed.")
        # Attempt to recreate MultiIndex if snp_positions_df exists
        if snp_positions_df is not None:
             try:
                 snp_cols = snp_positions_df.index
                 # Select only SNP columns (exclude 'Clade')
                 snp_data_only = processed_snp_data.drop(columns=['Clade'])
                 snp_data_only.columns = snp_cols
                 # NOTE(review): the concatenated frame mixes the plain 'Clade' label with
                 # MultiIndex SNP columns - verify the result really has MultiIndex columns.
                 processed_snp_data = pd.concat([processed_snp_data[['Clade']], snp_data_only], axis=1)
                 print("Successfully recreated MultiIndex columns.")
             except Exception as e:
                 print(f"Error recreating MultiIndex: {e}")
                 processed_snp_data = None # Mark as unusable
        else:
             processed_snp_data = None # Mark as unusable
    
if processed_snp_data is not None and genome_bins is not None: # Check again after potential index fix
    # Get SNP columns (MultiIndex)
    # Relies on 'Clade' being the first column of processed_snp_data.
    snp_columns = processed_snp_data.columns[1:] # Exclude 'Clade' column
    
    total_bins = len(genome_bins)
    processed_count = 0
    print_interval = max(1, total_bins // 10) # Print progress roughly 10 times
    
    # Iterate through each bin
    for index, bin_info in genome_bins.iterrows():
        bin_id = bin_info['BinID']
        chrom = bin_info['Chromosome']
        start = bin_info['BinStart']
        end = bin_info['BinEnd']
        
        # Find SNPs within this bin
        # Need to access the levels of the MultiIndex
        # Both bin ends are inclusive (start <= Position <= end).
        snps_in_bin_mask = (snp_columns.get_level_values(0) == chrom) & \
                           (snp_columns.get_level_values(1) >= start) & \
                           (snp_columns.get_level_values(1) <= end)
        
        bin_snp_columns = snp_columns[snps_in_bin_mask]
        snps_in_bins[bin_id] = bin_snp_columns.tolist() # Store SNP identifiers
        
        if not bin_snp_columns.empty:
            clade_profiles[bin_id] = {}
            # Get the SNP data just for this bin
            # Ensure we only select columns that actually exist after filtering
            existing_bin_snp_cols = [col for col in bin_snp_columns if col in processed_snp_data.columns]
            if not existing_bin_snp_cols:
                # Note: this bin is skipped without updating the progress counter below.
                continue # Skip if somehow no SNP columns remain
                
            bin_snp_data = processed_snp_data[[('Clade','')] + existing_bin_snp_cols]
            
            # Calculate profile for each clade
            for clade in unique_clades:
                # Filter strains belonging to the current clade
                clade_specific_data = bin_snp_data[bin_snp_data['Clade'] == clade]
                
                if not clade_specific_data.empty:
                    # Calculate the mode (most frequent allele) for each SNP column
                    # Drop 'Clade' column before calculating mode
                    # mode() returns a DataFrame; we take the first row [0] in case of ties for mode
                    # fillna('N') handles cases where a SNP might be all NaN for a clade
                    # Use .infer_objects() to handle potential mixed types before mode calculation
                    try:
                        mode_profile = clade_specific_data.drop(columns=['Clade']).infer_objects().mode(axis=0, dropna=True).iloc[0].fillna('N') 
                        clade_profiles[bin_id][clade] = mode_profile # Store the profile (Pandas Series)
                    except IndexError: # Handle cases where mode returns empty df (e.g., all NaNs)
                         # Create a profile of 'N's with the correct index
                         clade_profiles[bin_id][clade] = pd.Series('N', index=existing_bin_snp_cols)
                         # print(f"Debug: Mode calculation failed for bin {bin_id}, clade {clade}. Assigning 'N' profile.")
                    except Exception as e:
                         print(f"Error calculating mode for bin {bin_id}, clade {clade}: {e}")
                         # Create a profile of 'N's as fallback
                         clade_profiles[bin_id][clade] = pd.Series('N', index=existing_bin_snp_cols)
                # else: profile remains empty for this clade in this bin
        # else: No SNPs in this bin, clade_profiles[bin_id] will not be created
        
        processed_count += 1
        if processed_count % print_interval == 0 or processed_count == total_bins:
            elapsed_time = time.time() - start_time_profiles
            print(f"  Processed {processed_count}/{total_bins} bins for profiles... ({elapsed_time:.2f}s elapsed)")

    end_time_profiles = time.time()
    print(f"\nClade Allele Profile Identification complete. Took {end_time_profiles - start_time_profiles:.2f} seconds.")
    print(f"Generated profiles for {len(clade_profiles)} bins with SNPs.")
    
    # Example: Print profile for the first bin with data
    first_bin_with_profile = next(iter(clade_profiles.items()), None)
    if first_bin_with_profile:
        bin_key, profiles = first_bin_with_profile
        print(f"\nExample profile for bin '{bin_key}':")
        for clade, profile in profiles.items():
            # Display only first few alleles of the profile Series
            if isinstance(profile, pd.Series):
                 print(f"  Clade {clade}: {profile.head().to_list()}...") 
            else:
                 print(f"  Clade {clade}: Profile is not a Series (type: {type(profile)}) ")
            
else:
    print("\nSkipping Profile Identification: Required data (processed SNP data or genome bins) is missing.")


## 5. Strain Bin Assignment

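Compare each strain's allele pattern within each bin to the representative clade profiles. Assign the bin to the clade with the highest **percentage identity**. If there's a tie, assign 'Mixed'. Bins that cannot be evaluated (no SNPs in the bin, or no clade profiles for it) are labeled 'NoData', bins whose profile set is empty are labeled 'NoProfiles', and bins where the best identity is 0% are labeled 'NoMatch'.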

In [None]:
# Strain x bin table of assigned labels; remains None unless the assignment step below runs.
strain_bin_assignments = None

def calculate_percentage_identity(series1, series2):
    """Return the percentage of matching alleles between two pandas Series.

    The Series are first restricted to the index labels they share. Positions
    where either value is NaN or the placeholder 'N' are excluded from both
    the match count and the denominator.

    Returns 0.0 when either argument is not a Series, when no labels are
    shared, or when no comparable position remains; otherwise a value
    between 0 and 100.
    """
    both_are_series = isinstance(series1, pd.Series) and isinstance(series2, pd.Series)
    if not both_are_series:
        return 0.0

    # Keep only the SNPs present in both inputs
    left, right = series1.align(series2, join='inner')
    if left.empty:
        return 0.0

    left_values = left.to_numpy()
    right_values = right.to_numpy()

    # A position is comparable only if neither side is 'N' or NaN
    comparable = (left_values != 'N') & pd.notna(left_values) & (right_values != 'N') & pd.notna(right_values)
    n_comparable = comparable.sum()
    if n_comparable == 0:
        # Nothing to compare; also avoids a division by zero
        return 0.0

    n_identical = (left_values[comparable] == right_values[comparable]).sum()
    return (n_identical / n_comparable) * 100

# Labels every (strain, bin) pair with the clade whose profile it matches best.
# Possible labels: a clade name, 'Mixed' (tie for best identity), 'NoMatch'
# (best identity is 0), 'NoProfiles' (bin has an empty profile dict) and
# 'NoData' (bin has no SNPs / no profiles, or the strain lacks those SNPs).
if processed_snp_data is None or genome_bins is None or not clade_profiles:
    print("\nSkipping Strain Bin Assignment: Required data (processed SNP data, genome bins, or clade profiles) is missing.")
else:
    print("\nStarting Strain Bin Assignment...")
    start_time_assignment = time.time()

    strains_to_process = processed_snp_data.index
    bin_ids = genome_bins['BinID'].tolist()

    # Result table: one row per strain, one column per bin
    strain_bin_assignments = pd.DataFrame(index=strains_to_process, columns=bin_ids, dtype=object)

    total_assignments = len(strains_to_process) * len(bin_ids)
    processed_assignments = 0
    print_interval_assign = max(1, total_assignments // 20) # Roughly 20 progress updates

    for strain_id in strains_to_process:
        strain_data = processed_snp_data.loc[strain_id]

        for bin_id in bin_ids:
            assignment = 'NoData' # Kept whenever the bin cannot be evaluated for this strain

            # SNPs of this bin, but only if the bin also has clade profiles
            bin_is_usable = bin_id in clade_profiles and bin_id in snps_in_bins
            bin_snp_cols = snps_in_bins[bin_id] if bin_is_usable else []
            snp_cols_present = [col for col in bin_snp_cols if col in strain_data.index]

            if snp_cols_present:
                strain_pattern = strain_data[snp_cols_present]

                # Percentage identity of the strain against every clade profile of the bin;
                # anything that is not a Series counts as an invalid profile (0.0)
                identities = {
                    clade: (
                        calculate_percentage_identity(strain_pattern, profile)
                        if isinstance(profile, pd.Series)
                        else 0.0
                    )
                    for clade, profile in clade_profiles[bin_id].items()
                }

                if not identities:
                    assignment = 'NoProfiles'
                else:
                    max_identity = max(identities.values())
                    if max_identity > 0.0:
                        best_clades = [clade for clade, identity in identities.items() if identity == max_identity]
                        # A unique winner gives its clade name; a tie gives 'Mixed'
                        assignment = best_clades[0] if len(best_clades) == 1 else 'Mixed'
                    else:
                        assignment = 'NoMatch'

            strain_bin_assignments.loc[strain_id, bin_id] = assignment

            processed_assignments += 1
            if processed_assignments % print_interval_assign == 0 or processed_assignments == total_assignments:
                elapsed_time = time.time() - start_time_assignment
                print(f"    Assigned {processed_assignments}/{total_assignments} strain-bins... ({elapsed_time:.2f}s elapsed)", end='\r')

    end_time_assignment = time.time()
    print(f"\n\nStrain Bin Assignment complete. Took {end_time_assignment - start_time_assignment:.2f} seconds.")
    print("First 5x5 of strain_bin_assignments matrix:")
    print(strain_bin_assignments.iloc[:5, :5])


## 6. Plotting

Visualize the strain bin assignments as a heatmap. Strains are sorted by clade, and bins represent the concatenated genome.

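Virulence factors of interest (coordinates as used for the marker lines in the plots below):

1. ROP18 — TGME49_chrVIIa:1,513,497..1,516,225 (-)
2. ROP16 — TGME49_chrVIIb:1,053,320..1,056,800
3. GRA15 — TGME49_chrX:7,286,296..7,289,756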

In [None]:
# Draws the strain x bin haplogroup heatmap for the whole (concatenated) genome,
# saves it to output_plot_file, then draws and saves one heatmap per chromosome
# ("<chromosome>_haplogroups.pdf"). Genes of interest are marked with dashed lines.
if strain_bin_assignments is not None and processed_snp_data is not None and genome_bins is not None and chromosome_boundaries is not None:
    print("\nStarting Plotting...")

    # --- Prepare data for plotting ---

    # 1. Get Clade order and define colors
    # Use the unique clades found earlier, plus special categories
    all_categories = unique_clades + ['Mixed', 'NoMatch', 'NoProfiles', 'NoData']
    # 'tab10' colors for clades, grey shades / white for the special categories
    cmap_clades = plt.get_cmap('tab10', len(unique_clades))
    color_map = {clade: cmap_clades(i) for i, clade in enumerate(unique_clades)}
    color_map['Mixed'] = 'lightgrey'
    color_map['NoMatch'] = 'grey'
    color_map['NoProfiles'] = 'darkgrey'
    color_map['NoData'] = 'white'

    # 2. Create a numerical matrix for imshow
    category_to_int = {category: i for i, category in enumerate(all_categories)}
    int_to_category = {i: category for category, i in category_to_int.items()}

    # Map assignments to integers; unexpected values fall back to the NoData color
    default_int = category_to_int['NoData']
    numerical_matrix = strain_bin_assignments.map(lambda x: category_to_int.get(x, default_int))

    # 3. Sort strains by Clade
    strain_clades_original = processed_snp_data['Clade']
    sorted_strains_index = strain_clades_original.loc[numerical_matrix.index].sort_values().index
    numerical_matrix_sorted = numerical_matrix.loc[sorted_strains_index]
    num_strains = len(sorted_strains_index)

    # Genes marked on the plots. 'start'/'end' are 1-based positions on 'chr'.
    # Defined once, up front, so both the genome-wide and the per-chromosome
    # plots can rely on it regardless of how many chromosomes there are.
    genes_of_interest = {
        'ROP18': {'chr': 'TGME49_chrVIIa', 'start': 1513497, 'end': 1516225, 'color': 'red'},
        'GRA15': {'chr': 'TGME49_chrX', 'start': 7286296, 'end': 7289756, 'color': 'red'},
        'ROP16': {'chr': 'TGME49_chrVIIb', 'start': 1053320, 'end': 1056800, 'color': 'red'}
    }

    # --- Create the plot ---
    fig, ax = plt.subplots(figsize=(20, 10)) # Adjust figure size as needed

    # Create the colormap and norm for imshow (one discrete color per category)
    cmap_list = [color_map[int_to_category[i]] for i in range(len(all_categories))]
    custom_cmap = mcolors.ListedColormap(cmap_list)
    norm = mcolors.BoundaryNorm(np.arange(len(all_categories) + 1) - 0.5, len(all_categories))

    # Display the heatmap
    im = ax.imshow(numerical_matrix_sorted, aspect='auto', cmap=custom_cmap, norm=norm, interpolation='none')

    # Add horizontal lines between strains
    ax.hlines(np.arange(num_strains - 1) + 0.5, -0.5, numerical_matrix_sorted.shape[1] - 0.5, color='white', lw=1, alpha=1)

    # Add chromosome boundary lines
    for boundary in chromosome_boundaries[1:-1]: # Exclude start and end boundaries
        # Last bin that ends at or before the boundary; the line goes on its right edge
        boundary_bin_index = genome_bins[genome_bins['CumulativeEnd'] <= boundary].index.max()
        if pd.notna(boundary_bin_index):
            ax.vlines(boundary_bin_index + 0.5, -0.5, num_strains - 0.5, color='black', lw=3)

    # --- Add vertical lines for genes of interest ---
    # This runs once, AFTER the boundary loop. It used to be nested inside that
    # loop, which redrew every gene marker and label once per boundary.
    if chromosome_map:
        for gene_name, info in genes_of_interest.items():
            if info['chr'] in chromosome_map:
                chrom_start_cumulative = chromosome_map[info['chr']]['start_cumulative']
                # Gene positions are 1-based while CumulativeStart/CumulativeEnd are
                # 0-based half-open offsets, hence the "- 1".
                gene_start_cumulative = chrom_start_cumulative + info['start'] - 1

                # Bin with CumulativeStart <= position < CumulativeEnd, restricted to the
                # gene's own chromosome so an out-of-range coordinate cannot match a
                # bin of the following chromosome.
                gene_bin_index_match = genome_bins[
                    (genome_bins['Chromosome'] == info['chr']) &
                    (genome_bins['CumulativeStart'] <= gene_start_cumulative) &
                    (genome_bins['CumulativeEnd'] > gene_start_cumulative)
                ]

                if not gene_bin_index_match.empty:
                    gene_bin_index = gene_bin_index_match.index[0]
                    # Draw line at the start of the bin containing the gene start
                    ax.axvline(gene_bin_index - 0.5, color=info['color'], linestyle='--', lw=1.5, alpha=0.8)
                    # Add text label above the plot, slightly offset from the top edge
                    ax.text(gene_bin_index, -0.02 * num_strains, gene_name, color=info['color'], ha='center', va='bottom', fontsize=8, rotation=90)
                else:
                    print(f"Warning: Could not find bin index for gene {gene_name} at cumulative position {gene_start_cumulative}")
            else:
                print(f"Warning: Chromosome {info['chr']} for gene {gene_name} not found in chromosome_map.")
    # -------------------------------------------------

    # Set labels and title
    ax.set_yticks(np.arange(num_strains))
    ax.set_yticklabels(sorted_strains_index, fontsize=8) # Adjust fontsize if needed
    ax.set_ylabel("Strains (Sorted by Clade)")

    # Set x-axis ticks to chromosome midpoints
    chrom_tick_positions = []
    chrom_tick_labels = []
    if chromosome_map:
        for chrom, data in chromosome_map.items():
            midpoint_cumulative = data['start_cumulative'] + data['length'] / 2
            # Find the bin index closest to the midpoint
            midpoint_bin_index = (genome_bins['CumulativeStart'] - midpoint_cumulative).abs().idxmin()
            chrom_tick_positions.append(midpoint_bin_index)
            # Shorten chromosome names for labels
            chrom_tick_labels.append(chrom.replace('TGME49_chr', ''))
        ax.set_xticks(chrom_tick_positions)
        ax.set_xticklabels(chrom_tick_labels, rotation=90, fontsize=8)
    else:
        ax.set_xlabel("Genomic Bins (Concatenated Chromosomes)")

    # Create custom legend
    patches = [mpatches.Patch(color=color_map[category], label=category) for category in all_categories]
    ax.legend(handles=patches, bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0., title="Clade")

    fig.tight_layout(rect=[0, 0, 0.9, 1]) # Adjust layout to make space for legend

    # Save the plot
    print(f"\nSaving plot to {output_plot_file}...")
    try:
        fig.savefig(output_plot_file, dpi=300, bbox_inches='tight')
        print("Plot saved successfully.")
    except Exception as e:
        print(f"Error saving plot: {e}")

    plt.show()
    plt.close(fig) # Release the figure once shown and saved

    # --- Create and save separate plots for each Chromosome ---
    for chr_target in genome_info_sorted['Chromosome']:
        print(f"\nGenerating plot for Chromosome {chr_target}...")
        output_chr_plot_file = f"{chr_target}_haplogroups.pdf"

        # Filter data for the target chromosome
        chr_bins = genome_bins[genome_bins['Chromosome'] == chr_target].copy()
        if chr_bins.empty:
            # This message used to hang off the for loop (for/else) and was printed
            # once after every run; it now fires only for a chromosome with no bins.
            print(f"\nSkipping plot generation for {chr_target}: No bins found for this chromosome.")
            continue

        # Reset index so row labels equal the bin's column position in this plot
        chr_bins.reset_index(drop=True, inplace=True)
        chr_bin_ids = chr_bins['BinID'].tolist()
        strain_bin_assignments_chr = strain_bin_assignments[chr_bin_ids]
        numerical_matrix_chr = strain_bin_assignments_chr.map(lambda x: category_to_int.get(x, default_int))
        numerical_matrix_chr_sorted = numerical_matrix_chr.loc[sorted_strains_index]

        # Create new figure for this chromosome
        fig_chr, ax_chr = plt.subplots(figsize=(8, 10)) # Adjust size as needed

        # Display heatmap for the chromosome
        im_chr = ax_chr.imshow(numerical_matrix_chr_sorted, aspect='auto', cmap=custom_cmap, norm=norm, interpolation='none')

        # Add horizontal lines between strains
        ax_chr.hlines(np.arange(num_strains - 1) + 0.5, -0.5, numerical_matrix_chr_sorted.shape[1] - 0.5, color='white', lw=1, alpha=1)

        # Mark genes of interest located on this chromosome
        for gene, gene_info in genes_of_interest.items():
            if gene_info['chr'] != chr_target:
                continue
            gene_start_pos = gene_info['start']
            # BinStart and BinEnd are both inclusive 1-based coordinates
            gene_bin_match = chr_bins[
                (chr_bins['BinStart'] <= gene_start_pos) &
                (chr_bins['BinEnd'] >= gene_start_pos)
            ]
            if not gene_bin_match.empty:
                gene_bin_index_relative = gene_bin_match.index[0]
                ax_chr.axvline(gene_bin_index_relative - 0.5, color=gene_info['color'], linestyle='--', lw=1.5, alpha=0.8)
                ax_chr.text(gene_bin_index_relative, -0.02 * num_strains, gene, color=gene_info['color'], ha='center', va='bottom', fontsize=8, rotation=90)
            else:
                print(f"Warning: Could not find bin index for {gene} within {chr_target}.")

        # Set labels and title for this chromosome's plot
        ax_chr.set_yticks(np.arange(num_strains))
        ax_chr.set_yticklabels(sorted_strains_index, fontsize=8)
        ax_chr.set_ylabel("Strains (Sorted by Clade)")

        # About 10 evenly spaced x ticks labeled with bin start positions;
        # np.unique removes repeated ticks on chromosomes with fewer than 10 bins
        tick_indices = np.unique(np.linspace(0, len(chr_bins) - 1, num=10, dtype=int))
        tick_labels = [f"{chr_bins.loc[i, 'BinStart']:,}" for i in tick_indices]
        ax_chr.set_xticks(tick_indices)
        ax_chr.set_xticklabels(tick_labels, rotation=45, ha='right', fontsize=8)
        ax_chr.set_xlabel(f"Position on {chr_target} (bp)")

        ax_chr.set_title(f"Haplogroup Assignments across {chr_target} Bins (100kb)", pad=40)

        # Add legend
        ax_chr.legend(handles=patches, bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0., title="Clade")

        fig_chr.tight_layout(rect=[0, 0, 0.9, 1])

        # Save this chromosome's plot
        print(f"\nSaving plot to {output_chr_plot_file}...")
        try:
            fig_chr.savefig(output_chr_plot_file, dpi=300, bbox_inches='tight')
            print(f"Chromosome {chr_target} plot saved successfully.")
        except Exception as e:
            print(f"Error saving chromosome {chr_target} plot: {e}")

        plt.show()
        plt.close(fig_chr) # Avoid accumulating one open figure per chromosome
    # ---------------------------------------------------------

else:
    print("\nSkipping Plotting: Required data is missing.")
