In [None]:
# Cell 1: Setup for Figures Notebook - Load Data and Define Helpers

# --- Standard Library Imports ---
import pandas as pd
import numpy as np
import os
import time
from collections import Counter
import glob # For finding files in a folder
import re # For parsing organism names
import sys # For exit
import logging # For detailed logging
from pathlib import Path # For handling paths
# Ensure Biopython is installed: pip install biopython
try:
    from Bio import SeqIO
except ImportError:
    print("ERROR: Biopython is required for this script. Please install it: pip install biopython")
    sys.exit(1)

# --- Plotly Imports ---
import plotly.graph_objects as go
import plotly.express as px
import plotly.io as pio

# --- Setup Logging ---
# Named logger for this notebook. The level is INFO regardless of whether a
# handler has to be attached below, so it is set once up front.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Attach a stdout handler only when neither this logger nor any ancestor has a
# handler yet; this keeps re-running the cell from stacking duplicate handlers.
if not logger.hasHandlers():
    log_formatter = logging.Formatter(
        '%(asctime)s [%(levelname)s] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
    )
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(log_formatter)
    logger.addHandler(console_handler)


print("--- Figures Notebook Setup ---")

# --- Configuration ---
# Input paths -----------------------------------------------------------------
# Main database CSV. Point this at the most recent, complete database version
# (Euk annotations, DIAMOND hit details, coverage columns, RBH flag, and
# possibly merged APSI/Motif data).
database_path = 'proteome_database_v2.3.csv' # <-- Update if your latest file has a different name/path

# Summary-data directory written by the analysis notebook; it must match that
# notebook's output_summary_dir_phase1. Used to load APSI/Motif data when those
# are not already part of the main database file.
output_summary_dir_phase1 = Path("./output_summary_data_hit_validation_phase1") # <-- Update this path if different

# InterPro entry list, used by the helper functions to translate IPR/domain IDs.
interpro_entry_path = 'interpro_entry.txt' # <-- Update path if needed

# Output paths ----------------------------------------------------------------
# Figures and the tables backing them go to separate directories, distinct from
# the analysis notebook's outputs.
output_figure_dir = 'publication_figures'
output_figure_summary_dir = 'publication_figure_data' # For tables/data backing figures

# Create each output directory (if missing) and report it, one after the other.
for _out_dir, _created_message in (
    (output_figure_dir, f"Ensured plot output directory exists: {output_figure_dir}"),
    (output_figure_summary_dir, f"Ensured summary data directory exists: {output_figure_summary_dir}"),
):
    Path(_out_dir).mkdir(parents=True, exist_ok=True)
    print(_created_message)


# --- Define Key Column Names (Based on your database structure) ---
# Every database column name used by the figure cells is bound to a variable
# here, so a renamed column only has to be changed in one place.

# Identity and provenance of each protein record
protein_id_col = 'ProteinID'
sequence_col = 'Sequence'
source_dataset_col = 'Source_Dataset'
source_genome_accession_col = 'Source_Genome_Assembly_Accession'
source_protein_annotation_col = 'Source_Protein_Annotation'
# Taxonomy / grouping
ncbi_taxid_col = 'NCBI_TaxID'
asgard_phylum_col = 'Asgard_Phylum'
virus_family_col = 'Virus_Family'
virus_name_col = 'Virus_Name'
orthogroup_col = 'Orthogroup'
# InterPro / domain annotation
ipr_col = 'IPR_Signatures'
ipr_go_terms_col = 'IPR_GO_Terms'
uniprotkb_ac_col = 'UniProtKB_AC'
num_domains_col = 'Num_Domains'
domain_arch_col = 'Domain_Architecture'
type_col = 'Type' # As in 'Annotated', 'Uncharacterized'
is_hypothetical_col = 'Is_Hypothetical'
has_known_structure_col = 'Has_Known_Structure'
percent_disorder_col = 'Percent_Disorder'
# Functional categorisation
specific_func_cat_col = 'Specific_Functional_Category'
broad_func_cat_col = 'Broad_Functional_Category'
category_trigger_col = 'Category_Trigger'
# Signal peptide predictions and sequence lengths
signal_peptide_col = 'Signal_Peptide_USPNet'
sp_cleavage_site_col = 'SP_Cleavage_Site_USPNet'
original_seq_length_col = 'Original_Seq_Length'
group_col = 'Group' # 'Asgard' or 'GV'
# Sequence-search hits against structure/sequence resources
seqsearch_pdb_hit_col = 'SeqSearch_PDB_Hit'
seqsearch_afdb_hit_col = 'SeqSearch_AFDB_Hit'
has_reference_structure_col = 'Has_Reference_Structure'
localization_col = 'Predicted_Subcellular_Localization'
mature_protein_sequence_col = 'Mature_Protein_Sequence'
mature_seq_length_col = 'Mature_Seq_Length'
seqsearch_mgnify_hit_col = 'SeqSearch_MGnify_Hit'
seqsearch_esma_hit_col = 'SeqSearch_ESMA_Hit'
structurally_dark_col = 'Is_Structurally_Dark' # Derived column
esp_col = 'Is_ESP' # Derived column
# Eukaryotic DIAMOND hit details (best hit per query)
hit_flag_col = 'Has_Euk_DIAMOND_Hit' # Flag for any Euk hit passing initial e-value
euk_hit_sseqid_col = 'Euk_Hit_SSEQID' # Best Euk hit SSEQID
euk_hit_organism_col = 'Euk_Hit_Organism' # Best Euk hit organism
euk_hit_pident_col = 'Euk_Hit_PIDENT' # Best Euk hit PIDENT
euk_hit_evalue_col = 'Euk_Hit_EVALUE' # Best Euk hit EVALUE
euk_hit_protein_name_col = 'Euk_Hit_Protein_Name' # Best Euk hit protein name
euk_hit_qstart_col = 'Euk_Hit_Qstart' # Alignment start on query (Asgard)
euk_hit_qend_col = 'Euk_Hit_Qend' # Alignment end on query (Asgard)
euk_hit_sstart_col = 'Euk_Hit_Sstart' # Alignment start on subject (Euk)
euk_hit_send_col = 'Euk_Hit_Send' # Alignment end on subject (Euk)
euk_hit_slen_diamond_col = 'Euk_Hit_Slen_Diamond' # Length of Euk hit from DIAMOND output
query_coverage_col = 'Query_Coverage' # Calculated query coverage
subject_coverage_col = 'Subject_Coverage' # Calculated subject coverage
rbh_flag_col = 'Is_RBH' # Flag for Reciprocal Best Hit (if RBH analysis was included)

# Per-orthogroup APSI / motif columns. These may already be in the main DB or
# may be merged in from the phase-1 summary CSVs when the database is loaded.
apsi_col = 'Intra_OG_APSI' # Assuming this column name if merged
motif_col = 'Conserved_Motifs' # Assuming this column name if merged (might be list/string)
num_og_sequences_col = 'Num_OG_Sequences' # Number of sequences in the OG (used for MSA)

# --- Define Arcadia Color Palettes (from Style Guide PDF) ---
print("\n--- Defining Manual Arcadia Color Palettes (from Style Guide) ---")
# Name -> hex code for every style-guide colour. Insertion order is kept as-is
# because callers may rely on it when iterating the dict.
arcadia_colors_manual = {
    # Primary
    "aegean": "#5088C5",
    "amber": "#F28360",
    "seaweed": "#3B9886",
    "canary": "#F7B846",
    "aster": "#7A77AB",
    "rose": "#F898AE",
    "vital": "#73B5E3",
    "tangerine": "#FFB984",  # earlier note: PDF lists FFB883; FFB984 kept as in previous code
    "oat": "#F5E4BE",
    "wish": "#BABEE0",
    "lime": "#97CD78",
    "dragon": "#C85152",
    # Secondary
    "sky": "#C6E7F4",
    "dress": "#F8C5C1",
    "taupe": "#DBD1C3",
    "denim": "#B6C8D4",
    "sage": "#B5BEA4",
    "mars": "#DA9085",
    "marine": "#8A99AD",
    "shell": "#EDE0D6",
    # Neutrals
    "white": "#FFFFFF",
    "gray": "#EBEDE8",
    "chateau": "#B9AFA7",  # earlier note: PDF lists BAB0A8; B9AFA7 is used here
    "bark": "#8F8885",
    "slate": "#43413F",
    "charcoal": "#484B50",
    "crow": "#292928",
    "black": "#09090A",
    "forest": "#596F74",  # listed under Neutrals in the PDF
    # Backgrounds
    "parchment": "#FEF7F1",
    "zephyr": "#F4FBFF",  # earlier note: PDF lists F4FBFE; F4FBFF is used here
    "lichen": "#F7FBEF",
    "dawn": "#F8F4F1",
}

# Ordered hex-code lists for cycling through a palette by index.
arcadia_primary_palette = [
    arcadia_colors_manual[color_name]
    for color_name in "aegean amber seaweed canary aster rose vital tangerine oat wish lime dragon".split()
]
arcadia_secondary_palette = [
    arcadia_colors_manual[color_name]
    for color_name in "sky dress taupe denim sage mars marine shell".split()
]
arcadia_neutrals_palette = [
    arcadia_colors_manual[color_name]
    for color_name in "gray chateau bark slate charcoal forest crow".split()
]
arcadia_background_palette = [
    arcadia_colors_manual[color_name]
    for color_name in "parchment zephyr lichen dawn".split()
]

print("Manual Arcadia palettes created.")

# --- Configure Plotly Defaults (Adhering to Style Guide) ---
print("\n--- Configuring Plotly Defaults (Adhering to Style Guide) ---")
pio.templates.default = "plotly_white" # Start with a clean white background


def _style_guide_axis():
    """Return a fresh dict with the style-guide axis settings (same for x and y)."""
    return dict(
        title=dict(font=dict(size=15, color='black', family='Arial', weight='bold')), # Bold 15pt axis title
        showgrid=False,     # no gridlines
        zeroline=False,     # no zero line either
        showline=True,      # draw the axis line ...
        linecolor='black',  # ... in black ...
        linewidth=1.5,      # ... and "strong"
        mirror=False,       # only on the side carrying ticks/labels
        ticks="outside",
        ticklen=5,          # 5px as per style guide
        tickwidth=1.5,      # matches the axis line width
        tickcolor='black',
    )


# Layout object applied per figure via fig.update_layout(plotly_layout_defaults).
# THIS VARIABLE NEEDS TO BE DEFINED *BEFORE* RUNNING FIGURE CELLS
plotly_layout_defaults = go.Layout(
    xaxis=_style_guide_axis(),
    yaxis=_style_guide_axis(),
    title=None, # no plot title by default
    font=dict(family="Arial", size=15, color="black"), # default font, e.g. 15pt tick labels
)

# Push the same axis settings into the plotly_white template itself so that
# px figures pick them up without an explicit update_layout call.
# Note: Modifying default templates can have broad effects. Applying to individual figures or using update_layout might be safer.
_white_layout = pio.templates['plotly_white'].layout
for _template_axis in (_white_layout.xaxis, _white_layout.yaxis):
    _template_axis.title.font.update(size=15, weight='bold')
    _template_axis.update(
        showgrid=False,
        showline=True,
        linecolor='black',
        linewidth=1.5,
        mirror=False,
        ticks="outside",
        ticklen=5,
        tickwidth=1.5,
        tickcolor='black',
    )
_white_layout.title = None # Set no title in template

print("Plotly default template set to 'plotly_white'.")
print("Default layout settings configured for bold axis titles (15pt), NO gridlines, strong black axis lines (1.5pt), tick marks (5px, 1.5pt, black), and NO plot titles.")


# --- Load InterPro Entry Data ---
# Builds ipr_lookup = {IPR_ID: {'Type': ..., 'Name': ...}}, which is used to
# translate IPR IDs to names in domain architectures/IPRs. Any failure leaves
# ipr_lookup empty (best effort): translations are then simply unavailable.
print(f"\n--- Loading InterPro Entry Data from '{interpro_entry_path}' ---")
ipr_lookup = {}
start_time_ipr = time.time()
try:
    # header=0 combined with names=[...] discards the file's own header row and
    # labels the first three fields IPR_ID / Type / Name, so no renaming of
    # ENTRY_AC / ENTRY_TYPE / ENTRY_NAME is needed afterwards.
    # Adjust usecols/names if your interpro_entry.txt file format is different.
    ipr_info_df = pd.read_csv(
        interpro_entry_path,
        sep='\t',
        usecols=[0, 1, 2],
        names=['IPR_ID', 'Type', 'Name'],
        header=0,
        comment='#',
        on_bad_lines='warn',
    )
    ipr_info_df['IPR_ID'] = ipr_info_df['IPR_ID'].astype(str).str.strip()

    # to_dict('index') raises ValueError on a non-unique index; without this
    # de-duplication one repeated ID would discard the whole lookup.
    _n_ipr_rows = len(ipr_info_df)
    ipr_info_df = ipr_info_df.drop_duplicates(subset='IPR_ID', keep='first')
    if len(ipr_info_df) < _n_ipr_rows:
        print(f"Warning: Dropped {_n_ipr_rows - len(ipr_info_df)} duplicate IPR_ID rows from '{interpro_entry_path}' (kept first occurrence).")

    ipr_lookup = ipr_info_df.set_index('IPR_ID')[['Type', 'Name']].to_dict('index')
    print(f"Loaded InterPro entry data for {len(ipr_lookup)} entries in {time.time() - start_time_ipr:.2f} seconds.")
except FileNotFoundError:
    print(f"Warning: InterPro entry file not found at '{interpro_entry_path}'. Domain name translations will not be available.")
except Exception as e:
    # Deliberately broad: a malformed entry file must not abort notebook setup.
    print(f"Warning: An error occurred loading or processing '{interpro_entry_path}': {e}")

if not ipr_lookup:
    print("Warning: ipr_lookup is empty. Domain name translations will not be available for plots/tables.")


# --- Define Helper Functions ---
print("\n--- Defining Helper Functions ---")
# Include necessary helper functions from your previous notebook here.
# You might need: translate_architecture, truncate_string, clean_protein_name, etc.
# Only include functions actually needed for figure generation and data manipulation *within* this notebook.

# Example: translate_architecture (requires ipr_lookup)
def translate_architecture(arch_string, lookup_dict=None):
    """Translate IPR IDs in a domain architecture string to readable names.

    Args:
        arch_string: Architecture string with IPR IDs separated by ';' or '|'.
        lookup_dict: Mapping of IPR ID -> dict with a 'Name' entry. When omitted
            (None), the module-level ``ipr_lookup`` is used, looked up at call
            time so that a reloaded lookup table is picked up.

    Returns:
        "Name (ID)" parts joined by "; ". IDs missing from the lookup are kept
        as-is, names longer than 43 characters are cut to 40 plus "...", and the
        whole result is capped at 150 characters. ``arch_string`` is returned
        unchanged when it is not a non-empty string or when the lookup is empty.
    """
    if lookup_dict is None:
        # Late binding: avoids freezing whatever ipr_lookup was at definition time.
        lookup_dict = ipr_lookup
    if not isinstance(arch_string, str) or not arch_string or not lookup_dict:
        return arch_string # Return original if invalid input or no lookup

    translated_parts = []
    for ipr_id in arch_string.replace('|', ';').split(';'):
        ipr_id_clean = ipr_id.strip()
        if ipr_id_clean in lookup_dict:
            name = lookup_dict[ipr_id_clean].get('Name', ipr_id_clean)
            # A missing name read through pandas is NaN (a float); fall back to
            # the ID, matching what a missing 'Name' key already produces.
            if not isinstance(name, str):
                name = ipr_id_clean
            # Truncate long names for display
            if len(name) > 43:
                name = name[:40] + '...'
            translated_parts.append(f"{name} ({ipr_id_clean})")
        elif ipr_id_clean:
            translated_parts.append(ipr_id_clean) # Keep unknown IDs

    full_translation = "; ".join(translated_parts)
    max_len_display = 150 # Max length for display
    if len(full_translation) > max_len_display:
        full_translation = full_translation[:max_len_display - 3] + "..."

    return full_translation

# Example: truncate_string
def truncate_string(text, max_len):
    """Shorten ``text`` to at most ``max_len`` characters, ending in '...'.

    Args:
        text: Value to shorten. Anything that is not a ``str`` is returned
            unchanged.
        max_len: Maximum length of the returned string.

    Returns:
        ``text`` itself when it already fits; otherwise the first
        ``max_len - 3`` characters followed by '...'. When ``max_len`` is below
        3 there is no room for the ellipsis, so the text is simply cut to
        ``max_len`` characters (an empty string for zero or negative values).
    """
    if not isinstance(text, str) or len(text) <= max_len:
        return text
    if max_len < 3:
        # A negative slice bound here would keep almost the whole string and
        # make the result longer than the input.
        return text[:max(max_len, 0)]
    return text[:max_len - 3] + "..."

# Example: clean_protein_name
def clean_protein_name(name):
    """Cleans common annotations from protein names."""
    if pd.isna(name): return "Unknown/Not Found"; name = str(name).strip()
    name = re.sub(r'\s*\|\s*.*','', name); name = re.sub(r'\bisoform\s+[\w-]+\b', '', name, flags=re.IGNORECASE).strip()
    name = re.sub(r'\bpartial\b', '', name, flags=re.IGNORECASE).strip(); name = re.sub(r'\bputative\b', '', name, flags=re.IGNORECASE).strip()
    name = re.sub(r'\bpredicted protein\b', '', name, flags=re.IGNORECASE).strip(); name = re.sub(r'\btype\s+\w+\b', '', name, flags=re.IGNORECASE).strip()
    name = re.sub(r'\bprotein\b', '', name, flags=re.IGNORECASE).strip(); name = re.sub(r'\buncharacterized\b', 'Uncharacterized', name, flags=re.IGNORECASE).strip()
    name = re.sub(r'[;,]$', '', name).strip(); name = re.sub(r'\s*\(Fragment\)$', '', name, flags=re.IGNORECASE).strip()
    name = re.sub(r'\s+', ' ', name).strip(); name = name[0].upper() + name[1:] if len(name) > 0 else name
    return name if name else "Unknown/Not Found"

# Add other necessary helper functions here...
# e.g., get_ipr_counts if you plot IPR frequencies directly


# --- Load Main Database ---
# Loads the main database into df_full and enriches it with per-orthogroup
# APSI values and a 'Has_Conserved_Motif' flag from the phase-1 summary CSVs.
# Whatever happens in the APSI/motif steps, df_full ends up with the columns
# apsi_col, num_og_sequences_col and 'Has_Conserved_Motif', so later figure
# cells can rely on them. If the database itself cannot be loaded, df_full is
# an empty DataFrame.
print(f"\n--- Loading Data from '{database_path}' ---")
try:
    # The database is expected to already contain all annotations, including
    # Euk hit details, coverage, and RBH flag (if added).
    df_full = pd.read_csv(database_path, low_memory=False)
    # Ensure ProteinID is string type right after loading
    if protein_id_col in df_full.columns:
        df_full[protein_id_col] = df_full[protein_id_col].astype(str)

    # --- Merge APSI and Motif Data ---
    # Files written by the previous (analysis) notebook.
    apsi_file_path = output_summary_dir_phase1 / "intra_og_apsi_values.csv"
    motifs_file_path = output_summary_dir_phase1 / "intra_og_conserved_motifs.csv"

    print(f"\nAttempting to merge APSI data from: {apsi_file_path}")
    if not apsi_file_path.is_file():
        print(f"Warning: APSI file not found at '{apsi_file_path}'. APSI data will not be available.")
    elif orthogroup_col not in df_full.columns:
        print(f"Warning: Orthogroup column '{orthogroup_col}' not found in database. Skipping APSI merge.")
    else:
        print("APSI file found.")
        try:
            df_apsi = pd.read_csv(apsi_file_path)
            print(f"APSI file read successfully. Shape: {df_apsi.shape}. Columns: {df_apsi.columns.tolist()}")
            # Align the APSI file's column names with the names used in df_full.
            df_apsi = df_apsi.rename(columns={'Num_Sequences': num_og_sequences_col})
            if orthogroup_col in df_apsi.columns:
                if 'APSI' in df_apsi.columns:
                    df_apsi = df_apsi.rename(columns={'APSI': apsi_col})
                    print(f"Renamed 'APSI' column to '{apsi_col}' in APSI data.")
                else:
                    print(f"Warning: 'APSI' column not found in APSI data. Cannot rename to '{apsi_col}'.")

                # Merge only value columns that exist in the APSI file and are
                # not already in the database; merging an existing column again
                # would split it into '<name>_x' / '<name>_y'.
                cols_to_merge_from_apsi = [orthogroup_col, apsi_col, num_og_sequences_col]
                cols_to_merge_from_apsi_existing = [orthogroup_col] + [
                    col for col in cols_to_merge_from_apsi[1:]
                    if col in df_apsi.columns and col not in df_full.columns
                ]

                if len(cols_to_merge_from_apsi_existing) > 1:
                    # APSI is per OG: keep one row per orthogroup so the left
                    # merge cannot multiply protein rows.
                    apsi_per_og = df_apsi[cols_to_merge_from_apsi_existing].drop_duplicates(subset=orthogroup_col, keep='first')
                    if len(apsi_per_og) < len(df_apsi):
                        print(f"Warning: APSI data has {len(df_apsi) - len(apsi_per_og)} duplicate '{orthogroup_col}' rows; keeping the first per orthogroup.")
                    print(f"Merging APSI data on column '{orthogroup_col}'. df_full shape before merge: {df_full.shape}")
                    # Left merge keeps every protein; the OG's values are added to each of its proteins.
                    df_full = df_full.merge(apsi_per_og, on=orthogroup_col, how='left')
                    print(f"Merged APSI data. df_full shape after merge: {df_full.shape}. Columns: {df_full.columns.tolist()}")
                else:
                    print("Warning: No relevant columns found in APSI data for merging.")
            else:
                print(f"Warning: Orthogroup column '{orthogroup_col}' not found in APSI data. Skipping APSI merge.")
        except Exception as e:
            # Best effort: a broken APSI file must not discard the loaded database.
            logger.exception("Failed to merge APSI data: %s", e)

    # Guarantee the APSI columns on every path (missing file, skipped or failed merge).
    if apsi_col not in df_full.columns:
        df_full[apsi_col] = np.nan
    if num_og_sequences_col not in df_full.columns:
        df_full[num_og_sequences_col] = np.nan

    print(f"\nAttempting to load and merge Motif data from: {motifs_file_path}")
    if motifs_file_path.is_file():
        print("Motif file found.")
        try:
            # df_motifs stays available for motif analysis/plotting in later cells.
            df_motifs = pd.read_csv(motifs_file_path)
            print(f"Motif file read successfully. Shape: {df_motifs.shape}. Columns: {df_motifs.columns.tolist()}")
            if orthogroup_col not in df_motifs.columns:
                print(f"Warning: Orthogroup column '{orthogroup_col}' not found in Motif data. Skipping 'Has_Conserved_Motif' flag.")
            elif orthogroup_col not in df_full.columns:
                print(f"Warning: Orthogroup column '{orthogroup_col}' not found in database. Skipping 'Has_Conserved_Motif' flag.")
            else:
                # dropna: isin() matches NaN against NaN, which would otherwise
                # flag every protein that has no orthogroup.
                ogs_with_motifs = df_motifs[orthogroup_col].dropna().unique()
                print(f"Loaded {len(df_motifs)} motifs for {len(ogs_with_motifs)} orthogroups.")
                df_full['Has_Conserved_Motif'] = df_full[orthogroup_col].isin(ogs_with_motifs)
                print(f"Added 'Has_Conserved_Motif' flag to df_full. Columns: {df_full.columns.tolist()}")
        except Exception as e:
            # Best effort, as for the APSI step.
            logger.exception("Failed to load motif data: %s", e)
    else:
        print(f"Warning: Motif file not found at '{motifs_file_path}'. Motif data will not be available.")

    # Guarantee the flag column on every path.
    if 'Has_Conserved_Motif' not in df_full.columns:
        df_full['Has_Conserved_Motif'] = False

    print(f"\nSuccessfully loaded and prepared data. Final df_full shape: {df_full.shape}. Columns: {df_full.columns.tolist()}")

except FileNotFoundError:
    print(f"ERROR: Database file not found at '{database_path}'. Please check the path.")
    # Define df_full as empty to prevent downstream errors
    df_full = pd.DataFrame()
except Exception as e:
    print(f"An error occurred while loading or processing the database: {e}")
    # Define df_full as empty
    df_full = pd.DataFrame()


# --- Create Color Maps for Plotting (after df_full is loaded) ---
# Builds three dicts mapping category value -> hex colour:
#   asgard_phylum_color_map, localization_color_map, broad_category_color_map.
# Each map is derived from the categories present in df_full and is left empty
# when a column it needs (including the group column used for filtering) is
# missing, e.g. because the database failed to load.
print("\n--- Creating Color Maps ---")
_has_group_col = group_col in df_full.columns
if not _has_group_col:
    print(f"Warning: Group column '{group_col}' not found in df_full. Color maps will be empty.")

# Asgard Phylum Colors
asgard_phylum_color_map = {}
if asgard_phylum_col in df_full.columns and _has_group_col:
    # Only phyla that occur in the Asgard group, sorted for stable colour assignment
    asgard_phyla_cat = df_full[df_full[group_col] == 'Asgard'][asgard_phylum_col].dropna().unique()
    asgard_phyla_cat.sort()
    if len(asgard_phyla_cat) > 0:
        asgard_phylum_color_map = {
            phylum: arcadia_primary_palette[i % len(arcadia_primary_palette)]
            for i, phylum in enumerate(asgard_phyla_cat)
        }
    asgard_phylum_color_map['Unknown Phylum'] = arcadia_colors_manual.get('gray', '#bdbdbd')
print(f"Asgard phylum color map created for {len(asgard_phylum_color_map)} phyla.")

# Localization Colors
localization_color_map = {}
if localization_col in df_full.columns and _has_group_col:
    # All unique localization values (NaN included) from the Asgard and GV groups
    all_localizations = df_full[df_full[group_col].isin(['Asgard', 'GV'])][localization_col].unique()
    # Missing values become their string form (e.g. 'nan') so they can be sorted and mapped
    all_localizations_list = [str(loc) if pd.isna(loc) else loc for loc in all_localizations]
    all_localizations_list.sort() # Sort for consistent color assignment
    localization_assignments = {
        'Archaea: Cytoplasmic/Membrane (non-SP)': arcadia_colors_manual.get('aegean', '#5088C5'),
        'Archaea: Membrane-associated (Lipoprotein/Pilin)': arcadia_colors_manual.get('amber', '#F28360'),
        'Archaea: Secreted/Membrane (Sec/Tat pathway)': arcadia_colors_manual.get('seaweed', '#3B9886'),
        'Host: Cytoplasm/Nucleus/Virus Factory': arcadia_colors_manual.get('vital', '#73B5E3'),
        'Host: Membrane-associated (Lipoprotein/Pilin-like)': arcadia_colors_manual.get('tangerine', '#FFB984'),
        'Host: Secretory Pathway (Secreted/Membrane/Organelle)': arcadia_colors_manual.get('lime', '#97CD78'),
        'CYTOPLASMIC': arcadia_colors_manual.get('vital', '#73B5E3'), # Include simplified terms if present
        'MEMBRANE': arcadia_colors_manual.get('tangerine', '#FFB984'),
        'EXTRACELLULAR': arcadia_colors_manual.get('lime', '#97CD78'),
        'Unknown': arcadia_colors_manual.get('gray', '#EBEDE8'),
        'nan': arcadia_colors_manual.get('gray', '#EBEDE8') # Handle NaN explicitly
    }
    # Categories without a fixed assignment cycle through this fallback palette
    fallback_palette_loc = arcadia_neutrals_palette + arcadia_secondary_palette
    fallback_idx_loc = 0
    for loc in all_localizations_list:
        if loc not in localization_color_map:
            localization_color_map[loc] = localization_assignments.get(loc, fallback_palette_loc[fallback_idx_loc % len(fallback_palette_loc)])
            # Advance the fallback index only when a fallback colour was consumed
            if loc not in localization_assignments:
                fallback_idx_loc += 1
print(f"Localization color map created for {len(localization_color_map)} categories.")

# Broad Functional Category Colors
broad_category_color_map = {}
if broad_func_cat_col in df_full.columns and _has_group_col:
    # All unique broad categories (NaN included) from the Asgard and GV groups
    all_broad_categories = df_full[df_full[group_col].isin(['Asgard', 'GV'])][broad_func_cat_col].unique()
    # Missing values become their string form (e.g. 'nan') so they can be sorted and mapped
    all_broad_categories_list = [str(cat) if pd.isna(cat) else cat for cat in all_broad_categories]
    all_broad_categories_list.sort() # Sort for consistent color assignment

    category_assignments = {
        'Cytoskeleton': arcadia_colors_manual.get('aegean', '#5088C5'),
        'Membrane Trafficking/Vesicles': arcadia_colors_manual.get('amber', '#F28360'),
        'ESCRT/Endosomal Sorting': arcadia_colors_manual.get('seaweed', '#3B9886'),
        'Ubiquitin System': arcadia_colors_manual.get('aster', '#7A77AB'),
        'N-glycosylation': arcadia_colors_manual.get('rose', '#F898AE'),
        'Nuclear Transport/Pore': arcadia_colors_manual.get('vital', '#73B5E3'),
        'DNA Info Processing': arcadia_colors_manual.get('canary', '#F7B846'),
        'RNA Info Processing': arcadia_colors_manual.get('lime', '#97CD78'),
        'Translation': arcadia_colors_manual.get('tangerine', '#FFB984'),
        'Signal Transduction': arcadia_colors_manual.get('aster', '#7A77AB'),
        'Metabolism': arcadia_colors_manual.get('sky', '#C6E7F4'), # Using sky as per previous code
        'Other Specific Annotation': arcadia_colors_manual.get('denim', '#B6C8D4'),
        'Unknown/Unclassified': arcadia_colors_manual.get('gray', '#EBEDE8'),
        'nan': arcadia_colors_manual.get('gray', '#EBEDE8') # Handle NaN explicitly
    }
    # Categories without a fixed assignment cycle through this fallback palette
    fallback_palette_cat = arcadia_primary_palette + arcadia_secondary_palette + arcadia_neutrals_palette
    fallback_idx_cat = 0
    for category in all_broad_categories_list:
        if category not in broad_category_color_map:
            broad_category_color_map[category] = category_assignments.get(category, fallback_palette_cat[fallback_idx_cat % len(fallback_palette_cat)])
            # Advance the fallback index only when a fallback colour was consumed
            if category not in category_assignments:
                fallback_idx_cat += 1
print(f"Broad functional category color map created for {len(broad_category_color_map)} categories.")


# Closing banner. The "loaded and prepared" message is only printed when the
# database actually loaded; a failed load leaves df_full empty.
print("\n--- Figures Notebook Setup Complete ---")
if df_full.empty:
    print("WARNING: DataFrame 'df_full' is empty (database failed to load or has no rows). Fix the errors reported above before generating figures.")
else:
    print("DataFrame 'df_full' is loaded and prepared for figure generation.")
print("Color maps and helper functions are defined.")
print("You can now proceed with generating figures in subsequent cells.")



In [None]:
# Cell 2: Figure 2 - General Database Characterization

# --- Imports ---
# Assumes pandas, numpy, plotly.express, plotly.graph_objects, plotly.io,
# and necessary variables/helper functions (like df_full, group_col,
# protein_id_col, orthogroup_col, broad_func_cat_col, structurally_dark_col,
# esp_col, hit_flag_col, euk_hit_protein_name_col, clean_protein_name,
# arcadia_colors_manual, arcadia_primary_palette, arcadia_secondary_palette,
# arcadia_neutrals_palette, output_figure_dir, output_figure_summary_dir,
# plotly_layout_defaults) are defined in the setup cell (Cell 1).

print("\n\n--- Generating Figure 2: General Database Characterization ---")

# Ensure df_full is loaded
if 'df_full' not in locals() or df_full.empty:
    print("ERROR: df_full not loaded. Please run the setup cell (Cell 1).")
else:
    # --- Filter out groups other than 'Asgard' and 'GV' ---
    # Create a filtered DataFrame containing only 'Asgard' and 'GV' groups
    if group_col in df_full.columns:
        df_filtered_groups = df_full[df_full[group_col].isin(['Asgard', 'GV'])].copy()
        print(f"Filtered out groups other than 'Asgard' and 'GV'. Remaining proteins: {len(df_filtered_groups)}")
        if df_filtered_groups.empty:
            print("WARNING: No 'Asgard' or 'GV' proteins remaining after filtering. Cannot generate Figure 2.")
    else:
        print(f"ERROR: Group column '{group_col}' not found in df_full. Cannot filter groups.")
        df_filtered_groups = pd.DataFrame() # Create empty df if group column is missing


    # Proceed with plotting only if filtered data is available
    if not df_filtered_groups.empty:

        # Define colors for Asgard and GV using the specified hex codes
        group_colors = {
            'Asgard': arcadia_colors_manual.get('aegean', '#5088C5'), # Using aegean as specified
            'GV': arcadia_colors_manual.get('amber', '#F28360') # Using amber as specified
        }
        print(f"\nUsing Group Colors: Asgard={group_colors.get('Asgard')}, GV={group_colors.get('GV')}")


        # --- Figure 2A Part 1: Number of Genomes per Group ---
        print("\n--- Figure 2A Part 1: Number of Genomes per Group ---")

        # Guard first: the panel needs both the genome-accession and the group column.
        if not (source_genome_accession_col in df_filtered_groups.columns and group_col in df_filtered_groups.columns):
            print(f"Skipping Figure 2A Part 1 due to missing columns in filtered data: '{source_genome_accession_col}' or '{group_col}'.")
        else:
            # Distinct genome accessions per group, flattened into a two-column frame.
            genomes_per_group = df_filtered_groups.groupby(group_col)[source_genome_accession_col].nunique()
            genome_counts = genomes_per_group.reset_index()
            genome_counts.columns = ['Group', 'Num_Genomes']
            print("\nGenome Counts:")
            print(genome_counts.to_markdown(index=False))

            # One bar per group, coloured by group; the figure deliberately has no title.
            fig_2a_genomes = px.bar(
                genome_counts,
                x='Group',
                y='Num_Genomes',
                color='Group',
                color_discrete_map=group_colors,
                labels={'Group': 'Group', 'Num_Genomes': 'Number of Genomes'},
                template='plotly_white',
            )
            # Shared layout defaults first, then hide the legend (the x-axis already
            # names the groups) and strip gridlines from both axes.
            fig_2a_genomes.update_layout(plotly_layout_defaults).update_layout(showlegend=False)
            fig_2a_genomes.update_xaxes(title_text='Group', showgrid=False)
            fig_2a_genomes.update_yaxes(title_text='Number of Genomes', showgrid=False)

            # Interactive HTML export.
            fig_2a_genomes_path_html = Path(output_figure_dir) / "figure2a_genomes_count.html"
            fig_2a_genomes.write_html(str(fig_2a_genomes_path_html))
            print(f"Figure 2A (Genomes) HTML saved to: {fig_2a_genomes_path_html}")

            # Static exports go through kaleido; a failure is reported but never fatal.
            fig_2a_genomes_path_pdf = Path(output_figure_dir) / "figure2a_genomes_count.pdf"
            try:
                fig_2a_genomes.write_image(str(fig_2a_genomes_path_pdf))
            except Exception as e:
                print(f"Warning: Could not export Figure 2A (Genomes) to PDF. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")
            else:
                print(f"Figure 2A (Genomes) PDF saved to: {fig_2a_genomes_path_pdf}")

            fig_2a_genomes_path_svg = Path(output_figure_dir) / "figure2a_genomes_count.svg"
            try:
                fig_2a_genomes.write_image(str(fig_2a_genomes_path_svg))
            except Exception as e:
                print(f"Warning: Could not export Figure 2A (Genomes) to SVG. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")
            else:
                print(f"Figure 2A (Genomes) SVG saved to: {fig_2a_genomes_path_svg}")


        # --- Figure 2A Part 2: Number of Proteins per Group ---
        print("\n--- Figure 2A Part 2: Number of Proteins per Group ---")

        # Guard first: the panel needs both the protein-ID and the group column.
        if not (protein_id_col in df_filtered_groups.columns and group_col in df_filtered_groups.columns):
            print(f"Skipping Figure 2A Part 2 due to missing columns in filtered data: '{protein_id_col}' or '{group_col}'.")
        else:
            # Non-null protein IDs per group, flattened into a two-column frame.
            proteins_per_group = df_filtered_groups.groupby(group_col)[protein_id_col].count()
            protein_counts = proteins_per_group.reset_index()
            protein_counts.columns = ['Group', 'Num_Proteins']
            print("\nProtein Counts:")
            print(protein_counts.to_markdown(index=False))

            # One bar per group, coloured by group; the figure deliberately has no title.
            fig_2a_proteins = px.bar(
                protein_counts,
                x='Group',
                y='Num_Proteins',
                color='Group',
                color_discrete_map=group_colors,
                labels={'Group': 'Group', 'Num_Proteins': 'Number of Proteins'},
                template='plotly_white',
            )
            # Shared layout defaults first, then hide the legend and strip gridlines.
            fig_2a_proteins.update_layout(plotly_layout_defaults).update_layout(showlegend=False)
            fig_2a_proteins.update_xaxes(title_text='Group', showgrid=False)
            fig_2a_proteins.update_yaxes(title_text='Number of Proteins', showgrid=False)

            # Interactive HTML export.
            fig_2a_proteins_path_html = Path(output_figure_dir) / "figure2a_proteins_count.html"
            fig_2a_proteins.write_html(str(fig_2a_proteins_path_html))
            print(f"Figure 2A (Proteins) HTML saved to: {fig_2a_proteins_path_html}")

            # Static exports go through kaleido; a failure is reported but never fatal.
            fig_2a_proteins_path_pdf = Path(output_figure_dir) / "figure2a_proteins_count.pdf"
            try:
                fig_2a_proteins.write_image(str(fig_2a_proteins_path_pdf))
            except Exception as e:
                print(f"Warning: Could not export Figure 2A (Proteins) to PDF. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")
            else:
                print(f"Figure 2A (Proteins) PDF saved to: {fig_2a_proteins_path_pdf}")

            fig_2a_proteins_path_svg = Path(output_figure_dir) / "figure2a_proteins_count.svg"
            try:
                fig_2a_proteins.write_image(str(fig_2a_proteins_path_svg))
            except Exception as e:
                print(f"Warning: Could not export Figure 2A (Proteins) to SVG. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")
            else:
                print(f"Figure 2A (Proteins) SVG saved to: {fig_2a_proteins_path_svg}")


        # --- Figure 2B: Orthogroup Size Distribution (Histogram) ---
        print("\n--- Figure 2B: Orthogroup Size Distribution (Histogram) ---")

        # Needs the orthogroup, protein-ID and group columns in the filtered data.
        if orthogroup_col in df_filtered_groups.columns and protein_id_col in df_filtered_groups.columns and group_col in df_filtered_groups.columns:
            # Orthogroup size = number of non-null protein IDs assigned to it.
            og_sizes = df_filtered_groups.groupby(orthogroup_col)[protein_id_col].count().reset_index()
            og_sizes.columns = [orthogroup_col, 'OG_Size']

            # Attach a group label to every orthogroup, taken from the first row seen
            # for that orthogroup. An orthogroup containing proteins from both groups
            # is therefore attributed to whichever group its first row belongs to.
            og_group = df_filtered_groups.drop_duplicates(subset=[orthogroup_col], keep='first')[[orthogroup_col, group_col]]
            og_sizes = pd.merge(og_sizes, og_group, on=orthogroup_col, how='left')

            # Report how many orthogroups reach the minimum size used for MSA.
            min_og_size_for_msa = 5  # Should match the criterion used in Cell 16
            ogs_for_msa = og_sizes[og_sizes['OG_Size'] >= min_og_size_for_msa]
            num_ogs_for_msa = len(ogs_for_msa)
            num_proteins_in_msa_ogs = ogs_for_msa['OG_Size'].sum()

            print(f"\nNumber of Orthogroups with size >= {min_og_size_for_msa} (used for MSA) in filtered data: {num_ogs_for_msa:,}")
            print(f"Total proteins in these OGs: {num_proteins_in_msa_ogs:,}")

            # Overlaid histograms, one trace per group. The merged frame carries the
            # group label under `group_col`, so colour by that name instead of a
            # hard-coded 'Group' (which only worked when the two happened to match).
            fig_2b = px.histogram(
                og_sizes,
                x='OG_Size',
                color=group_col,
                barmode='overlay',  # Overlay bars for Asgard and GV
                nbins=50,  # Adjust number of bins as needed
                labels={'OG_Size': 'Orthogroup Size (Number of Proteins)', group_col: 'Group'},
                color_discrete_map=group_colors,
                template='plotly_white'
            )
            # Apply default layout and remove gridlines
            fig_2b.update_layout(plotly_layout_defaults)
            fig_2b.update_xaxes(title_text='Orthogroup Size (Number of Proteins)', showgrid=False)
            fig_2b.update_yaxes(title_text='Count (Number of Orthogroups)', showgrid=False)
            fig_2b.update_layout(legend_title_text='Group')
            fig_2b.update_traces(opacity=0.75)  # Semi-transparent so overlaid bars stay visible

            # Interactive HTML export.
            fig_2b_path_html = Path(output_figure_dir) / "figure2b_og_size_distribution_histogram.html"
            fig_2b.write_html(str(fig_2b_path_html))
            print(f"Figure 2B (Histogram) HTML saved to: {fig_2b_path_html}")

            # Static exports (PDF, SVG) go through kaleido. This is best-effort: any
            # failure is reported as a warning and the notebook carries on.
            for image_ext in ('pdf', 'svg'):
                image_path = Path(output_figure_dir) / f"figure2b_og_size_distribution_histogram.{image_ext}"
                try:
                    fig_2b.write_image(str(image_path))
                except Exception as e:
                    print(f"Warning: Could not export Figure 2B (Histogram) to {image_ext.upper()}. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")
                else:
                    print(f"Figure 2B (Histogram) {image_ext.upper()} saved to: {image_path}")

        else:
            print("Skipping Figure 2B due to missing data in filtered data.")


        # --- Figure 2C: High-Level Annotation Summary & Euk Hit Protein Types ---
        print("\n--- Figure 2C: High-Level Annotation Summary & Euk Hit Protein Types ---")

        # Boolean annotation flags to summarise, in plotting order. Only columns that
        # actually exist in the filtered data are kept.
        annotation_cols_to_summarize = [
            flag_col
            for flag_col in ('Has_Conserved_Motif', structurally_dark_col, esp_col, hit_flag_col)
            if flag_col in df_filtered_groups.columns
        ]
        if num_domains_col in df_filtered_groups.columns:
            # Derived flag "has at least one domain". Testing `> 0` (rather than just
            # non-null) keeps a recorded count of 0 from being reported as a domain.
            # NOTE(review): assumes num_domains_col holds a numeric count - confirm.
            domain_counts = pd.to_numeric(df_filtered_groups[num_domains_col], errors='coerce')
            df_filtered_groups['Has_Any_Domain'] = domain_counts.fillna(0) > 0
            annotation_cols_to_summarize.append('Has_Any_Domain')

        # Human-readable label per flag column, computed once and reused for the
        # summary rows, the zero-fill rows and the x-axis category order.
        annotation_labels = {
            flag_col: flag_col.replace('_', ' ').replace('Is ', '').strip()
            for flag_col in annotation_cols_to_summarize
        }

        # Per flag: count and percentage of True values within each group.
        annotation_summary_list = []
        if group_col not in df_filtered_groups.columns:
            if annotation_cols_to_summarize:
                print(f"Warning: Group column '{group_col}' not found in filtered data. Skipping summary for all annotations.")
        else:
            for col in annotation_cols_to_summarize:
                # Coerce to real booleans in place (later code in this cell filters on
                # these columns). A bare astype(bool) maps NaN to True, so missing
                # values are explicitly forced to False.
                raw_flags = df_filtered_groups[col]
                df_filtered_groups[col] = raw_flags.notna() & raw_flags.astype(bool)

                # sum of booleans = number of True; mean of booleans = fraction True.
                flag_summary = (
                    df_filtered_groups.groupby(group_col)[col]
                    .agg(Count='sum', Percentage='mean')
                    .reset_index()
                )
                flag_summary['Count'] = flag_summary['Count'].astype(int)
                flag_summary['Percentage'] = flag_summary['Percentage'] * 100

                # Keep only groups with at least one True value; the ESP flag is
                # additionally never reported for the 'GV' group.
                keep_rows = flag_summary['Count'] > 0
                if col == esp_col:
                    keep_rows &= flag_summary[group_col] != 'GV'
                summary_true = flag_summary[keep_rows].copy()

                if not summary_true.empty:
                    # 'Value' marks these as the True rows (same columns as before:
                    # group, Value, Count, Percentage, Annotation_Type).
                    summary_true.insert(1, 'Value', True)
                    summary_true['Annotation_Type'] = annotation_labels[col]
                    annotation_summary_list.append(summary_true)

        if annotation_summary_list:
            df_annotation_summary = pd.concat(annotation_summary_list, ignore_index=True)

            # X-axis order follows the order in which the flags were collected.
            plot_category_order = [annotation_labels[col] for col in annotation_cols_to_summarize]

            # Flags that were processed but had no True value in any group get explicit
            # 0% rows so they still appear as categories in the table and the plot.
            present_annotation_types = set(df_annotation_summary['Annotation_Type'])
            missing_annotation_types = [atype for atype in plot_category_order if atype not in present_annotation_types]
            if missing_annotation_types:
                zero_percentage_rows = [
                    {group_col: group, 'Percentage': 0.0, 'Count': 0, 'Annotation_Type': atype}
                    for atype in missing_annotation_types
                    for group in df_filtered_groups[group_col].dropna().unique()
                ]
                df_annotation_summary = pd.concat([df_annotation_summary, pd.DataFrame(zero_percentage_rows)], ignore_index=True)

            print("\nHigh-Level Annotation Summary (% True by Group):")
            # Wide table for display: one row per annotation type, one column per group.
            df_annotation_display = df_annotation_summary.pivot_table(index='Annotation_Type', columns=group_col, values='Percentage', fill_value=0).round(1)
            print(df_annotation_display.to_markdown())

            # --- Grouped bar chart: side-by-side bars per annotation type ---
            fig_2c_stacked = px.bar(
                df_annotation_summary,
                x='Annotation_Type',
                y='Percentage',
                color=group_col,  # the summary frame stores the group under group_col
                barmode='group',
                labels={'Annotation_Type': 'Annotation Type', 'Percentage': 'Percentage of Proteins (%)', group_col: 'Group'},
                color_discrete_map=group_colors,
                category_orders={'Annotation_Type': plot_category_order},
                template='plotly_white'
            )
            # Apply default layout and remove gridlines
            fig_2c_stacked.update_layout(plotly_layout_defaults)
            fig_2c_stacked.update_xaxes(title_text='Annotation Type', showgrid=False)
            fig_2c_stacked.update_yaxes(title_text='Percentage of Proteins (%)', showgrid=False)
            fig_2c_stacked.update_layout(legend_title_text='Group')
            fig_2c_stacked.update_xaxes(tickangle=45)

            # Interactive HTML export.
            fig_2c_stacked_path_html = Path(output_figure_dir) / "figure2c_annotation_summary_grouped_bar.html"
            fig_2c_stacked.write_html(str(fig_2c_stacked_path_html))
            print(f"Figure 2C (Grouped Bar) HTML saved to: {fig_2c_stacked_path_html}")

            # Static exports (PDF, SVG) go through kaleido. This is best-effort: any
            # failure is reported as a warning and the notebook carries on.
            for image_ext in ('pdf', 'svg'):
                image_path = Path(output_figure_dir) / f"figure2c_annotation_summary_grouped_bar.{image_ext}"
                try:
                    fig_2c_stacked.write_image(str(image_path))
                except Exception as e:
                    print(f"Warning: Could not export Figure 2C (Grouped Bar) to {image_ext.upper()}. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")
                else:
                    print(f"Figure 2C (Grouped Bar) {image_ext.upper()} saved to: {image_path}")

            # --- Euk Hit Protein Name Analysis (Top Names) ---
            print("\n--- Euk Hit Protein Name Analysis (Top Names for Proteins with Euk Hits) ---")
            # Needs the hit flag, the hit-name column and the clean_protein_name helper
            # (expected to be defined by the setup cell).
            if hit_flag_col in df_filtered_groups.columns and euk_hit_protein_name_col in df_filtered_groups.columns and 'clean_protein_name' in locals():
                df_euk_hits_annotated = df_filtered_groups[df_filtered_groups[hit_flag_col] == True].copy()

                if not df_euk_hits_annotated.empty:
                    # Normalise the raw hit names, then tally how often each occurs.
                    df_euk_hits_annotated['Cleaned_Euk_Prot_Name'] = df_euk_hits_annotated[euk_hit_protein_name_col].apply(clean_protein_name)
                    euk_hit_name_counts = df_euk_hits_annotated['Cleaned_Euk_Prot_Name'].value_counts()

                    # Drop uninformative placeholder names before ranking.
                    euk_hit_name_counts_filtered = euk_hit_name_counts[~euk_hit_name_counts.index.isin(['Uncharacterized', 'Unknown/Not Found', 'nan'])]

                    top_n_euk_names = 20  # How many top names to show
                    df_top_euk_names = euk_hit_name_counts_filtered.head(top_n_euk_names).reset_index()
                    df_top_euk_names.columns = ['Cleaned_Euk_Prot_Name', 'Count']

                    print(f"\nTop {top_n_euk_names} Cleaned Eukaryotic Hit Protein Names (Functionally Annotated Hits):")
                    print(df_top_euk_names.to_markdown(index=False))

                    if not df_top_euk_names.empty:
                        # One colour per name, cycling through the qualitative palette
                        # when there are more names than colours.
                        euk_name_colors = px.colors.qualitative.Plotly
                        euk_name_color_map = {name: euk_name_colors[i % len(euk_name_colors)] for i, name in enumerate(df_top_euk_names['Cleaned_Euk_Prot_Name'])}

                        fig_2c_euk_names = px.bar(
                            df_top_euk_names,
                            x='Cleaned_Euk_Prot_Name',
                            y='Count',
                            color='Cleaned_Euk_Prot_Name',
                            color_discrete_map=euk_name_color_map,
                            labels={'Cleaned_Euk_Prot_Name': 'Cleaned Eukaryotic Hit Protein Name', 'Count': 'Number of Asgard/GV Proteins Hit'},
                            template='plotly_white'
                        )
                        # Apply default layout and remove gridlines
                        fig_2c_euk_names.update_layout(plotly_layout_defaults)
                        fig_2c_euk_names.update_xaxes(title_text='Cleaned Eukaryotic Hit Protein Name', showgrid=False)
                        fig_2c_euk_names.update_yaxes(title_text='Number of Asgard/GV Proteins Hit', showgrid=False)
                        fig_2c_euk_names.update_xaxes(tickangle=45, categoryorder='total descending')  # Order bars by count
                        fig_2c_euk_names.update_layout(showlegend=False)  # Names are already on the x-axis

                        # Interactive HTML export.
                        fig_2c_euk_names_path_html = Path(output_figure_dir) / "figure2c_top_euk_hit_protein_names.html"
                        fig_2c_euk_names.write_html(str(fig_2c_euk_names_path_html))
                        print(f"Figure 2C (Top Euk Hit Names) HTML saved to: {fig_2c_euk_names_path_html}")

                        # Best-effort static exports via kaleido.
                        for image_ext in ('pdf', 'svg'):
                            image_path = Path(output_figure_dir) / f"figure2c_top_euk_hit_protein_names.{image_ext}"
                            try:
                                fig_2c_euk_names.write_image(str(image_path))
                            except Exception as e:
                                print(f"Warning: Could not export Figure 2C (Top Euk Hit Names) to {image_ext.upper()}. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")
                            else:
                                print(f"Figure 2C (Top Euk Hit Names) {image_ext.upper()} saved to: {image_path}")

                else:
                    print("No proteins with Eukaryotic hits found in filtered groups to analyze protein names.")
            else:
                print(f"Skipping Euk Hit Protein Name analysis: Columns '{hit_flag_col}' or '{euk_hit_protein_name_col}' not found in filtered data, or 'clean_protein_name' function is missing.")

        else:
            print("Skipping Figure 2C annotation summary due to missing data in filtered data.")


        # --- Figure 2D: Broad Functional Category Distribution ---
        print("\n--- Figure 2D: Broad Functional Category Distribution ---")

        if broad_func_cat_col in df_filtered_groups.columns and group_col in df_filtered_groups.columns:
            # --- Full summary (all categories, including those excluded from the plot) ---
            # Built from df_full restricted to the two groups, and written out
            # regardless of whether anything remains to plot, so the summary file
            # exists whenever this section runs.
            df_full_filtered_groups_only = df_full[df_full[group_col].isin(['Asgard', 'GV'])].copy()
            df_full_category_summary = (
                df_full_filtered_groups_only.groupby(group_col)[broad_func_cat_col]
                .value_counts()
                .reset_index(name='Count')
            )
            # Percentage of each category within its own group.
            full_group_totals = df_full_category_summary.groupby(group_col)['Count'].transform('sum')
            df_full_category_summary['Percentage'] = df_full_category_summary['Count'] / full_group_totals * 100

            print("\nBroad Functional Category Distribution (Counts and Percentages - FULL SUMMARY):")
            print(df_full_category_summary.sort_values([group_col, 'Count'], ascending=[True, False]).to_markdown(index=False))

            full_category_summary_path = Path(output_figure_summary_dir) / "figure2d_broad_functional_category_full_summary.csv"
            try:
                df_full_category_summary.to_csv(str(full_category_summary_path), index=False)
            except Exception as e:
                # Best-effort: a failed write must not stop figure generation.
                print(f"Error saving full category summary: {e}")
            else:
                print(f"Full functional category summary saved to: {full_category_summary_path}")

            # --- Plot data: drop uninformative categories ---
            # Rows with a missing category are dropped as well: they can never form a
            # bar, and leaving them in would inflate the percentage denominators.
            categories_to_exclude = ['Unknown/Unclassified', 'Other Specific Annotation']
            category_values = df_filtered_groups[broad_func_cat_col]
            df_category_plot_data = df_filtered_groups[
                category_values.notna() & ~category_values.isin(categories_to_exclude)
            ].copy()

            if df_category_plot_data.empty:
                print("Skipping Figure 2D plotting as no functional categories remain after excluding 'Unknown/Unclassified' and 'Other Specific Annotation'.")
            else:
                # Proteins per (group, category) among the rows kept for plotting.
                category_counts_plot = df_category_plot_data.groupby([group_col, broad_func_cat_col]).size().reset_index(name='Count')

                # Percentages are relative to each group's total over the PLOTTED
                # categories, so each group's bars sum to 100%.
                plot_group_totals = category_counts_plot.groupby(group_col)['Count'].transform('sum')
                df_category_plot_summary = category_counts_plot.assign(
                    Percentage=category_counts_plot['Count'] / plot_group_totals * 100
                )

                if df_category_plot_summary.empty:
                    print("Skipping Figure 2D plot as df_category_plot_summary is empty or missing 'Count' column before filtering by count.")
                else:
                    # Drop zero-count combinations (these can appear when the grouping
                    # columns are categorical).
                    df_category_plot_summary_filtered_by_count = df_category_plot_summary[df_category_plot_summary['Count'] > 0].copy()

                    if df_category_plot_summary_filtered_by_count.empty:
                        print("No functional categories remaining after filtering for plotting Figure 2D (after removing zero counts).")
                    else:
                        # X-axis order: most frequent category (summed over groups) first.
                        plot_category_order = df_category_plot_summary_filtered_by_count.groupby(broad_func_cat_col)['Count'].sum().sort_values(ascending=False).index.tolist()

                        fig_2d = px.bar(
                            df_category_plot_summary_filtered_by_count,
                            x=broad_func_cat_col,
                            y='Percentage',
                            color=group_col,
                            barmode='group',  # Side-by-side bars for Asgard vs GV
                            labels={broad_func_cat_col: 'Broad Functional Category', 'Percentage': 'Percentage of Proteins (%)', group_col: 'Group'},
                            color_discrete_map=group_colors,
                            category_orders={broad_func_cat_col: plot_category_order},
                            template='plotly_white'
                        )
                        # Apply default layout and remove gridlines
                        fig_2d.update_layout(plotly_layout_defaults)
                        fig_2d.update_xaxes(title_text='Broad Functional Category', showgrid=False)
                        fig_2d.update_yaxes(title_text='Percentage of Proteins (%)', showgrid=False)
                        fig_2d.update_layout(legend_title_text='Group')
                        fig_2d.update_xaxes(tickangle=45)

                        # Interactive HTML export.
                        fig_2d_path_html = Path(output_figure_dir) / "figure2d_broad_functional_category_distribution_filtered.html"
                        fig_2d.write_html(str(fig_2d_path_html))
                        print(f"Figure 2D HTML saved to: {fig_2d_path_html}")

                        # Static exports (PDF, SVG) go through kaleido. This is
                        # best-effort: failures are reported as warnings only.
                        for image_ext in ('pdf', 'svg'):
                            image_path = Path(output_figure_dir) / f"figure2d_broad_functional_category_distribution_filtered.{image_ext}"
                            try:
                                fig_2d.write_image(str(image_path))
                            except Exception as e:
                                print(f"Warning: Could not export Figure 2D to {image_ext.upper()}. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")
                            else:
                                print(f"Figure 2D {image_ext.upper()} saved to: {image_path}")

        else:
            print("Skipping Figure 2D due to missing data in filtered groups.")

    else:
        print("Skipping Figure 2 generation as no 'Asgard' or 'GV' data is available after initial filtering.")


# Closing banner for the Figure 2 cell: names the output locations for the figures
# and the functional-category summary. Printed unconditionally.
for closing_line in (
    "\n\n--- Figure 2 Generation Complete ---",
    f"Figures saved to the '{output_figure_dir}' directory (HTML, PDF, SVG).",
    f"Full functional category summary saved to '{output_figure_summary_dir}'.",
):
    print(closing_line)


In [None]:
# Cell 4: Revised Figure 2C - Annotation Summary Bar Chart

# --- Imports ---
# Assumes pandas (pd), plotly.express (px), Path, and the variables group_col,
# structurally_dark_col, esp_col, hit_flag_col, arcadia_colors_manual,
# output_figure_dir and plotly_layout_defaults are defined in the setup cell
# (Cell 1).
# Assumes df_filtered_groups is available from Cell 2 (Figure 2 generation) and
# correctly filters for 'Asgard' and 'GV' groups.

print("\n\n--- Generating Revised Figure 2C: Annotation Summary Bar Chart (Cell 4) ---")
print("Note: Orthogroup size distribution figure (Figure 2B) is being reconsidered due to skewness.")
print("Note: Top Eukaryotic Hit Protein Names are presented as a table (in Cell 3).")


def _write_static_image(fig, target_path, figure_label, format_label):
    """Write ``fig`` to ``target_path`` with ``write_image``; warn on failure.

    A failed export (for example when kaleido is missing) is reported with a
    printed warning so the rest of the cell keeps running.
    """
    try:
        fig.write_image(str(target_path))
        print(f"{figure_label} {format_label} saved to: {target_path}")
    except Exception as e:
        print(f"Warning: Could not export {figure_label} to {format_label}. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")


def _annotation_label(column_name):
    """Turn a flag column name into its axis label (underscores -> spaces, 'Is ' removed)."""
    return column_name.replace('_', ' ').replace('Is ', '').strip()


# Tracks whether the figure was actually produced so the closing message is accurate.
fig_2c_revised_written = False

# Ensure df_filtered_groups is available and not empty
if 'df_filtered_groups' not in locals() or df_filtered_groups.empty:
    print("ERROR: df_filtered_groups not available or is empty. Please run Cell 2 first.")
else:
    # Define colors for Asgard and GV using the specified hex codes (re-defined for clarity in this cell)
    group_colors = {
        'Asgard': arcadia_colors_manual.get('aegean', '#5088C5'),
        'GV': arcadia_colors_manual.get('amber', '#F28360')
    }
    print(f"\nUsing Group Colors: Asgard={group_colors.get('Asgard')}, GV={group_colors.get('GV')}")

    # --- Revised Figure 2C: High-Level Annotation Summary Bar Chart ---
    print("\n--- Revised Figure 2C: High-Level Annotation Summary Bar Chart ---")

    # Boolean flag columns shown in THIS bar chart, in plotting order.
    # 'Has_Conserved_Motif' is deliberately left out of this plot.
    # 'Has_Any_Domain' is only present if it was created in Cell 2.
    candidate_flag_cols = [structurally_dark_col, esp_col, hit_flag_col, 'Has_Any_Domain']
    annotation_cols_for_plot = [c for c in candidate_flag_cols if c in df_filtered_groups.columns]

    annotation_plot_data = []
    if group_col in df_filtered_groups.columns:
        # Denominator for the percentages: all proteins per group in the filtered data
        total_proteins_per_group = df_filtered_groups.groupby(group_col).size()

        for col in annotation_cols_for_plot:
            # Coerce the flag to bool in place (as before), but treat missing
            # values as False: a plain astype(bool) turns NaN into True, which
            # would count unannotated proteins as positives.
            df_filtered_groups[col] = df_filtered_groups[col].map(
                lambda value: bool(value) if pd.notna(value) else False
            ).astype(bool)

            # Number of True values per group
            true_counts = df_filtered_groups.groupby(group_col)[col].sum().reset_index(name='Count')

            # Percentage of True values per group (0.0 when a group has no proteins)
            percentages = []
            for group, count in zip(true_counts[group_col], true_counts['Count']):
                total = total_proteins_per_group.get(group, 0)
                percentages.append((count / total * 100) if total > 0 else 0.0)

            # The summary table always exposes the group under the literal name
            # 'Group', whatever group_col is called in the source data.
            summary_combined = (
                true_counts.assign(Percentage=percentages)
                .rename(columns={group_col: 'Group'})
            )
            summary_combined['Annotation_Type'] = _annotation_label(col)

            # Handle ESP specifically for GV (remove if Group is GV and Annotation is ESP).
            # Note: the reindex further down re-adds this combination as a zero row,
            # so GV ends up with a 0% ESP bar rather than the measured value.
            if col == esp_col:
                summary_combined = summary_combined[summary_combined['Group'] != 'GV'].copy()

            if not summary_combined.empty:
                annotation_plot_data.append(summary_combined)

    else:
        print(f"Warning: Group column '{group_col}' not found in filtered data. Skipping annotation summary plot.")

    if annotation_plot_data:
        df_annotation_summary_plot = pd.concat(annotation_plot_data, ignore_index=True)

        # Ensure all combinations of Annotation_Type and Group are present, fill missing with 0.
        # Categories follow the order of annotation_cols_for_plot.
        all_plot_annotation_types = [_annotation_label(col) for col in annotation_cols_for_plot]
        all_groups = df_filtered_groups[group_col].unique()  # Use groups from filtered data

        multi_index = pd.MultiIndex.from_product(
            [all_plot_annotation_types, all_groups], names=['Annotation_Type', 'Group']
        )
        df_annotation_summary_plot = (
            df_annotation_summary_plot
            .set_index(['Annotation_Type', 'Group'])
            .reindex(multi_index, fill_value=0)
            .reset_index()
        )

        # Zero-percentage rows are kept on purpose so that empty bars stay visible.

        # Category order for plotting mirrors the order of annotation_cols_for_plot
        plot_category_order = list(all_plot_annotation_types)

        if not df_annotation_summary_plot.empty:
            fig_2c_revised = px.bar(
                df_annotation_summary_plot,
                x='Annotation_Type',
                y='Percentage',
                color='Group',
                barmode='group',  # Side-by-side bars for Asgard vs GV
                labels={'Annotation_Type': 'Annotation Type', 'Percentage': 'Percentage of Proteins (%)', 'Group': 'Group'},
                color_discrete_map=group_colors,
                category_orders={'Annotation_Type': plot_category_order},  # Apply consistent order
                template='plotly_white'
            )
            # Apply default layout and remove gridlines
            fig_2c_revised.update_layout(plotly_layout_defaults)
            fig_2c_revised.update_xaxes(title_text='Annotation Type', showgrid=False)
            fig_2c_revised.update_yaxes(title_text='Percentage of Proteins (%)', showgrid=False)
            fig_2c_revised.update_layout(legend_title_text='Group')  # Add legend title
            fig_2c_revised.update_xaxes(tickangle=45)  # Angle labels if needed

            # Export Revised Figure 2C (HTML, PDF, SVG)
            fig_2c_revised_path_html = Path(output_figure_dir) / "figure2c_annotation_summary_revised_bar.html"
            fig_2c_revised.write_html(str(fig_2c_revised_path_html))
            print(f"Revised Figure 2C HTML saved to: {fig_2c_revised_path_html}")
            fig_2c_revised_written = True

            # Export to PDF/SVG (requires kaleido)
            fig_2c_revised_path_pdf = Path(output_figure_dir) / "figure2c_annotation_summary_revised_bar.pdf"
            _write_static_image(fig_2c_revised, fig_2c_revised_path_pdf, "Revised Figure 2C", "PDF")
            fig_2c_revised_path_svg = Path(output_figure_dir) / "figure2c_annotation_summary_revised_bar.svg"
            _write_static_image(fig_2c_revised, fig_2c_revised_path_svg, "Revised Figure 2C", "SVG")

        else:
            print("No annotation data remaining after processing for plotting Revised Figure 2C.")

    else:
        print("Skipping Revised Figure 2C plotting as no annotation columns were defined or data is missing.")


print("\n\n--- Cell 4 (Revised Figure 2C) Complete ---")
if fig_2c_revised_written:
    print(f"Revised Figure 2C saved to the '{output_figure_dir}' directory (HTML, PDF, SVG).")
else:
    print("Revised Figure 2C was not generated; see the messages above.")


In [None]:
# Cell 5: Figure 3 - Localization and Functional Category Distribution

# --- Imports ---
# Assumes pandas (pd), plotly.express (px), Path, and the variables group_col,
# localization_col, broad_func_cat_col, arcadia_colors_manual,
# output_figure_dir, output_figure_summary_dir, plotly_layout_defaults,
# localization_color_map and broad_category_color_map are defined in the setup
# cell (Cell 1).
# Assumes df_filtered_groups is available from Cell 2 (Figure 2 generation) and
# correctly filters for 'Asgard' and 'GV' groups.

print("\n\n--- Generating Figure 3: Localization and Functional Category Distribution (Cell 5) ---")


def _write_static_image(fig, target_path, figure_label, format_label):
    """Write ``fig`` to ``target_path`` with ``write_image``; warn on failure.

    A failed export (for example when kaleido is missing) is reported with a
    printed warning so the rest of the cell keeps running.
    """
    try:
        fig.write_image(str(target_path))
        print(f"{figure_label} {format_label} saved to: {target_path}")
    except Exception as e:
        print(f"Warning: Could not export {figure_label} to {format_label}. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")


# Labels of the panels whose HTML was written; drives the closing message.
figure3_panels_written = []

# Ensure df_filtered_groups is available and not empty
if 'df_filtered_groups' not in locals() or df_filtered_groups.empty:
    print("ERROR: df_filtered_groups not available or is empty. Please run Cell 2 first.")
else:
    # Define colors for Asgard and GV using the specified hex codes (re-defined for clarity in this cell)
    group_colors = {
        'Asgard': arcadia_colors_manual.get('aegean', '#5088C5'),
        'GV': arcadia_colors_manual.get('amber', '#F28360')
    }
    print(f"\nUsing Group Colors: Asgard={group_colors.get('Asgard')}, GV={group_colors.get('GV')}")

    # --- Figure 3A: Predicted Subcellular Localization Distribution ---
    print("\n--- Figure 3A: Predicted Subcellular Localization Distribution ---")

    if localization_col in df_filtered_groups.columns and group_col in df_filtered_groups.columns:
        localization_by_group = df_filtered_groups.groupby(group_col)[localization_col]

        # Counts and within-group percentages for each localization category
        localization_counts = localization_by_group.value_counts().reset_index(name='Count')
        localization_percentages = localization_by_group.value_counts(normalize=True).mul(100).reset_index(name='Percentage')

        # Merge counts and percentages
        df_localization_summary = pd.merge(localization_counts, localization_percentages, on=[group_col, localization_col])

        print("\nPredicted Subcellular Localization Distribution (Counts and Percentages):")
        # Sort on group_col (not a hard-coded 'Group') so the printout works
        # whatever the group column is called.
        print(df_localization_summary.sort_values([group_col, 'Count'], ascending=[True, False]).to_markdown(index=False))

        # Save full localization summary to a summary file
        full_localization_summary_path = Path(output_figure_summary_dir) / "figure3a_localization_full_summary.csv"
        try:
            df_localization_summary.to_csv(str(full_localization_summary_path), index=False)
            print(f"Full localization summary saved to: {full_localization_summary_path}")
        except Exception as e:
            print(f"Error saving full localization summary: {e}")

        # Define category order based on overall frequency for consistent plotting
        overall_localization_order = df_filtered_groups[localization_col].value_counts().index.tolist()

        if not df_localization_summary.empty:
            fig_3a = px.bar(
                df_localization_summary,
                x=group_col,  # X-axis is Group
                y='Percentage',
                color=localization_col,  # Color is Localization Category
                barmode='group',  # Group bars by X-axis value (Group)
                labels={localization_col: 'Predicted Subcellular Localization', 'Percentage': 'Percentage of Proteins (%)', group_col: 'Group'},
                color_discrete_map=localization_color_map,  # Use localization color map
                category_orders={group_col: ['Asgard', 'GV'], localization_col: overall_localization_order},  # Apply consistent order
                template='plotly_white'
            )
            # Apply default layout and remove gridlines
            fig_3a.update_layout(plotly_layout_defaults)
            fig_3a.update_xaxes(title_text='Group', showgrid=False)  # X-axis title is Group
            fig_3a.update_yaxes(title_text='Percentage of Proteins (%)', showgrid=False)
            fig_3a.update_layout(legend_title_text='Predicted Subcellular Localization')  # Add legend title
            fig_3a.update_xaxes(tickangle=0)  # No angle needed for Group labels

            # Export Figure 3A (HTML, PDF, SVG)
            fig_3a_path_html = Path(output_figure_dir) / "figure3a_localization_distribution.html"
            fig_3a.write_html(str(fig_3a_path_html))
            print(f"Figure 3A HTML saved to: {fig_3a_path_html}")
            figure3_panels_written.append("Figure 3A")

            # Export to PDF/SVG (requires kaleido)
            fig_3a_path_pdf = Path(output_figure_dir) / "figure3a_localization_distribution.pdf"
            _write_static_image(fig_3a, fig_3a_path_pdf, "Figure 3A", "PDF")
            fig_3a_path_svg = Path(output_figure_dir) / "figure3a_localization_distribution.svg"
            _write_static_image(fig_3a, fig_3a_path_svg, "Figure 3A", "SVG")

        else:
            print("No localization data available for plotting Figure 3A.")

    else:
        print(f"Skipping Figure 3A due to missing columns in filtered data: '{localization_col}' or '{group_col}'.")

    # --- Figure 3B: Broad Functional Category Distribution ---
    print("\n--- Figure 3B: Broad Functional Category Distribution ---")

    if broad_func_cat_col in df_filtered_groups.columns and group_col in df_filtered_groups.columns:
        # Filter out unwanted categories for plotting (same as Figure 2D)
        categories_to_exclude = ['Unknown/Unclassified', 'Other Specific Annotation']
        df_category_plot_data = df_filtered_groups[
            ~df_filtered_groups[broad_func_cat_col].isin(categories_to_exclude)
        ].copy()

        if df_category_plot_data.empty:
            print("Skipping Figure 3B plotting as no functional categories remain after excluding 'Unknown/Unclassified' and 'Other Specific Annotation'.")
        else:
            # Counts per (group, category) in the data kept for plotting
            category_counts_plot = df_category_plot_data.groupby([group_col, broad_func_cat_col]).size().reset_index(name='Count')

            # Per-group totals of the kept rows: the denominator for the percentages
            total_filtered_proteins_plot_by_group = df_category_plot_data.groupby(group_col).size().reset_index(name='Total_Filtered_Plot')

            # Merge counts with totals, derive percentages, then drop the helper column
            df_category_plot_summary = pd.merge(category_counts_plot, total_filtered_proteins_plot_by_group, on=group_col, how='left')
            df_category_plot_summary['Percentage'] = (df_category_plot_summary['Count'] / df_category_plot_summary['Total_Filtered_Plot']) * 100
            df_category_plot_summary = df_category_plot_summary.drop(columns=['Total_Filtered_Plot'])

            # --- Filter by Count > 0 and Plot ---
            # A boolean mask already yields an empty frame (with the same columns)
            # when no count is positive, so no special-casing is needed.
            df_category_plot_summary_filtered_by_count = df_category_plot_summary[df_category_plot_summary['Count'] > 0].copy()

            if not df_category_plot_summary_filtered_by_count.empty:
                # Category order for plotting: most frequent first, based on the PLOTTED data
                plot_category_order = (
                    df_category_plot_summary_filtered_by_count
                    .groupby(broad_func_cat_col)['Count']
                    .sum()
                    .sort_values(ascending=False)
                    .index.tolist()
                )

                fig_3b = px.bar(
                    df_category_plot_summary_filtered_by_count,  # Use the DataFrame filtered by count
                    x=group_col,  # X-axis is Group
                    y='Percentage',
                    color=broad_func_cat_col,  # Color is Functional Category
                    barmode='group',  # Group bars by X-axis value (Group)
                    labels={broad_func_cat_col: 'Broad Functional Category', 'Percentage': 'Percentage of Proteins (%)', group_col: 'Group'},
                    color_discrete_map=broad_category_color_map,  # Use functional category color map
                    category_orders={group_col: ['Asgard', 'GV'], broad_func_cat_col: plot_category_order},  # Apply consistent order
                    template='plotly_white'
                )
                # Apply default layout and remove gridlines
                fig_3b.update_layout(plotly_layout_defaults)
                fig_3b.update_xaxes(title_text='Group', showgrid=False)  # X-axis title is Group
                fig_3b.update_yaxes(title_text='Percentage of Proteins (%)', showgrid=False)
                fig_3b.update_layout(legend_title_text='Broad Functional Category')  # Add legend title
                fig_3b.update_xaxes(tickangle=0)  # No angle needed for Group labels

                # Export Figure 3B (HTML, PDF, SVG)
                fig_3b_path_html = Path(output_figure_dir) / "figure3b_broad_functional_category_distribution_filtered.html"
                fig_3b.write_html(str(fig_3b_path_html))
                print(f"Figure 3B HTML saved to: {fig_3b_path_html}")
                figure3_panels_written.append("Figure 3B")

                # Export to PDF/SVG (requires kaleido)
                fig_3b_path_pdf = Path(output_figure_dir) / "figure3b_broad_functional_category_distribution_filtered.pdf"
                _write_static_image(fig_3b, fig_3b_path_pdf, "Figure 3B", "PDF")
                fig_3b_path_svg = Path(output_figure_dir) / "figure3b_broad_functional_category_distribution_filtered.svg"
                _write_static_image(fig_3b, fig_3b_path_svg, "Figure 3B", "SVG")

            else:
                # Previously this case was skipped without any message.
                print("Skipping Figure 3B plot as no functional categories with a non-zero count remain for plotting.")

    else:
        # Mirrors the Figure 3A message; previously this case was skipped silently.
        print(f"Skipping Figure 3B due to missing columns in filtered data: '{broad_func_cat_col}' or '{group_col}'.")

print("\n\n--- Figure 3 Generation Complete ---")
if figure3_panels_written:
    print(f"Figures saved to the '{output_figure_dir}' directory (HTML, PDF, SVG).")
else:
    print("No Figure 3 panels were generated; see the messages above.")
# Note: Full functional category summary is saved in Cell 2 (Figure 2D)



In [None]:
# Cell 6: Figure 4 - Genome and Protein Counts by Taxonomy

# --- Imports ---
# Assumes pandas, numpy, plotly.express, plotly.graph_objects, plotly.io,
# and necessary variables/helper functions (like df_full, group_col,
# protein_id_col, source_genome_accession_col, asgard_phylum_col,
# virus_family_col, arcadia_colors_manual, arcadia_primary_palette,
# arcadia_secondary_palette, arcadia_neutrals_palette, output_figure_dir,
# plotly_layout_defaults)
# are defined in the setup cell (Cell 1).
# Assumes df_filtered_groups is available from Cell 2 (Figure 2 generation) and
# correctly filters for 'Asgard' and 'GV' groups.

print("\n\n--- Generating Figure 4: Genome and Protein Counts by Taxonomy (Cell 6) ---")

# Ensure df_filtered_groups is available and not empty
if 'df_filtered_groups' not in locals() or df_filtered_groups.empty:
    print("ERROR: df_filtered_groups not available or is empty. Please run Cell 2 first.")
else:
    # Define colors for Asgard and GV groups (re-defined for context): each is
    # looked up in arcadia_colors_manual, with a hex fallback if the key is absent.
    group_colors = {
        'Asgard': arcadia_colors_manual.get('aegean', '#5088C5'),
        'GV': arcadia_colors_manual.get('amber', '#F28360')
    }
    # The dict key is 'GV'; looking up 'Amber' always printed GV=None.
    print(f"\nUsing Group Colors (for context): Asgard={group_colors.get('Asgard')}, GV={group_colors.get('GV')}")


    # --- Calculate Genome and Protein Counts by Taxonomy ---
    print("\n--- Calculating Genome and Protein Counts by Taxonomy ---")

    genome_counts_by_group_taxonomy = []
    protein_counts_by_group_taxonomy = []
    all_taxonomic_units_present = set() # Collect all unique units from both datasets

    if source_genome_accession_col in df_filtered_groups.columns and group_col in df_filtered_groups.columns and asgard_phylum_col in df_filtered_groups.columns and virus_family_col in df_filtered_groups.columns and protein_id_col in df_filtered_groups.columns:

        # Process Asgard counts by Phylum
        df_asgard = df_filtered_groups[df_filtered_groups[group_col] == 'Asgard'].copy()
        if not df_asgard.empty:
            if asgard_phylum_col in df_asgard.columns:
                # Genome counts
                asgard_genome_counts = df_asgard.groupby([group_col, asgard_phylum_col])[source_genome_accession_col].nunique().reset_index(name='Count')
                asgard_genome_counts.rename(columns={asgard_phylum_col: 'Taxonomic_Unit'}, inplace=True)
                genome_counts_by_group_taxonomy.append(asgard_genome_counts)
                all_taxonomic_units_present.update(asgard_genome_counts['Taxonomic_Unit'].dropna().unique())

                # Protein counts
                asgard_protein_counts = df_asgard.groupby([group_col, asgard_phylum_col])[protein_id_col].count().reset_index(name='Count')
                asgard_protein_counts.rename(columns={asgard_phylum_col: 'Taxonomic_Unit'}, inplace=True)
                protein_counts_by_group_taxonomy.append(asgard_protein_counts)
                all_taxonomic_units_present.update(asgard_protein_counts['Taxonomic_Unit'].dropna().unique())
            else:
                 print(f"Warning: Skipping Asgard counts. Column '{asgard_phylum_col}' missing.")
        else:
             print("Warning: Skipping Asgard counts. Data is empty.")


        # Process GV counts by Family
        df_gv = df_filtered_groups[df_filtered_groups[group_col] == 'GV'].copy()
        if not df_gv.empty:
            if virus_family_col in df_gv.columns:
                # Genome counts
                gv_genome_counts = df_gv.groupby([group_col, virus_family_col])[source_genome_accession_col].nunique().reset_index(name='Count')
                gv_genome_counts.rename(columns={virus_family_col: 'Taxonomic_Unit'}, inplace=True)
                genome_counts_by_group_taxonomy.append(gv_genome_counts)
                all_taxonomic_units_present.update(gv_genome_counts['Taxonomic_Unit'].dropna().unique())

                # Protein counts
                gv_protein_counts = df_gv.groupby([group_col, virus_family_col])[protein_id_col].count().reset_index(name='Count')
                gv_protein_counts.rename(columns={virus_family_col: 'Taxonomic_Unit'}, inplace=True)
                protein_counts_by_group_taxonomy.append(gv_protein_counts)
                all_taxonomic_units_present.update(gv_protein_counts['Taxonomic_Unit'].dropna().unique())
            else:
                 print(f"Warning: Skipping GV counts. Column '{virus_family_col}' missing.")
        else:
             print("Warning: Skipping GV counts. Data is empty.")


        # Combine results into DataFrames
        df_genome_counts_taxonomy = pd.concat(genome_counts_by_group_taxonomy, ignore_index=True) if genome_counts_by_group_taxonomy else pd.DataFrame()
        df_protein_counts_taxonomy = pd.concat(protein_counts_by_group_taxonomy, ignore_index=True) if protein_counts_by_group_taxonomy else pd.DataFrame()

        print("\nGenome Counts by Group and Taxonomy (Partial):")
        print(df_genome_counts_taxonomy.to_markdown(index=False))
        print("\nProtein Counts by Group and Taxonomy (Partial):")
        print(df_protein_counts_taxonomy.to_markdown(index=False))


        # --- Create Comprehensive Taxonomic Unit Color Map ---
        # Use a combination of Arcadia palettes for the color map
        # Prioritize primary, then secondary, then neutrals/monochromatic shades
        # Use a larger pool of colors to minimize repetition
        full_arcadia_palette = arcadia_primary_palette + arcadia_secondary_palette + arcadia_neutrals_palette + [
            arcadia_colors_manual.get('lapis', '#2B66A2'), arcadia_colors_manual.get('dusk', '#094468'), # Blue shades
            arcadia_colors_manual.get('melon', '#FFCFAF'), arcadia_colors_manual.get('cinnabar', '#9E3F41'), # Orange/Red shades
            arcadia_colors_manual.get('sun', '#FFD364'), arcadia_colors_manual.get('mustard', '#D68D22'), # Yellow shades
            arcadia_colors_manual.get('iris', '#DCDFEF'), arcadia_colors_manual.get('tanzanite', '#54448C'), # Purple shades
            arcadia_colors_manual.get('glass', '#C3E2DB'), arcadia_colors_manual.get('asparagus', '#2A6B5E'), # Teal shades
            arcadia_colors_manual.get('putty', '#FFE3D4'), arcadia_colors_manual.get('candy', '#E2718F'), # Pink shades
            arcadia_colors_manual.get('stone', '#EDE6DA'), arcadia_colors_manual.get('dove', '#CAD4DB'), # Warm/Cool gray shades
            arcadia_colors_manual.get('steel', '#687787'), arcadia_colors_manual.get('mud', '#635C5A') # More neutrals
        ]

        # Assign colors to ALL unique taxonomic units found in either dataset
        combined_taxonomy_color_map = {}
        sorted_all_taxonomic_units = sorted(list(all_taxonomic_units_present)) # Sort for consistent assignment
        for i, unit in enumerate(sorted_all_taxonomic_units):
            combined_taxonomy_color_map[unit] = full_arcadia_palette[i % len(full_arcadia_palette)]

        # Add 'Unknown Phylum'/'Unknown Family' if they exist and weren't in the unique list
        if 'Unknown Phylum' not in combined_taxonomy_color_map:
             combined_taxonomy_color_map['Unknown Phylum'] = arcadia_colors_manual.get('gray', '#bdbdbd')
        if 'Unknown Family' not in combined_taxonomy_color_map:
             combined_taxonomy_color_map['Unknown Family'] = arcadia_colors_manual.get('gray', '#bdbdbd')


        print(f"\nCombined taxonomic unit color map created for {len(combined_taxonomy_color_map)} units.")


        # --- Figure 4A: Number of Genomes by Taxonomy (Stacked Bar and Pie Charts) ---
        print("\n--- Figure 4A: Number of Genomes by Taxonomy ---")

        if not df_genome_counts_taxonomy.empty:
            # --- Figure 4A Part 1: Stacked Bar Chart ---
            # Calculate total count for each Taxonomic_Unit across both groups for sorting
            taxonomic_unit_total_counts_genomes = df_genome_counts_taxonomy.groupby('Taxonomic_Unit')['Count'].sum().sort_values(ascending=False)
            # Get the order of taxonomic units based on total counts
            taxonomic_unit_order_genomes = taxonomic_unit_total_counts_genomes.index.tolist()


            fig_4a_bar = px.bar(
                df_genome_counts_taxonomy,
                x=group_col, # X-axis is Group
                y='Count',
                color='Taxonomic_Unit', # Color is Phylum or Family
                barmode='stack', # Stack bars by Taxonomic_Unit
                # title='Number of Genomes by Group and Taxonomy (Stacked Bar)', # No title
                labels={group_col: 'Group', 'Count': 'Number of Genomes', 'Taxonomic_Unit': 'Phylum or Family'},
                color_discrete_map=combined_taxonomy_color_map, # Use combined color map
                category_orders={group_col: ['Asgard', 'GV'], 'Taxonomic_Unit': taxonomic_unit_order_genomes}, # Ensure Group order and Taxonomic Unit order
                template='plotly_white'
            )
            # Apply default layout and remove gridlines
            fig_4a_bar.update_layout(plotly_layout_defaults)
            fig_4a_bar.update_xaxes(title_text='Group', showgrid=False)
            fig_4a_bar.update_yaxes(title_text='Number of Genomes', showgrid=False)
            fig_4a_bar.update_layout(legend_title_text='Taxonomic Unit') # Add legend title


            # Export Figure 4A Part 1 (HTML, PDF, SVG)
            fig_4a_bar_path_html = Path(output_figure_dir) / "figure4a_genomes_by_taxonomy_stackedbar.html"
            fig_4a_bar.write_html(str(fig_4a_bar_path_html))
            print(f"Figure 4A (Stacked Bar) HTML saved to: {fig_4a_bar_path_html}")

            # Export to PDF/SVG (requires kaleido)
            try:
                fig_4a_bar_path_pdf = Path(output_figure_dir) / "figure4a_genomes_by_taxonomy_stackedbar.pdf"
                fig_4a_bar.write_image(str(fig_4a_bar_path_pdf))
                print(f"Figure 4A (Stacked Bar) PDF saved to: {fig_4a_bar_path_pdf}")
            except Exception as e:
                print(f"Warning: Could not export Figure 4A (Stacked Bar) to PDF. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")
            try:
                fig_4a_bar_path_svg = Path(output_figure_dir) / "figure4a_genomes_by_taxonomy_stackedbar.svg"
                fig_4a_bar.write_image(str(fig_4a_bar_path_svg))
                print(f"Figure 4A (Stacked Bar) SVG saved to: {fig_4a_bar_path_svg}")
            except Exception as e:
                print(f"Warning: Could not export Figure 4A (Stacked Bar) to SVG. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")


            # --- Figure 4A Part 2: Pie Charts (Asgard and GV) ---
            print("\n--- Figure 4A Part 2: Genome Counts Pie Charts (Asgard and GV) ---")

            # Asgard Genome Pie Chart
            df_asgard_genome_pie = df_genome_counts_taxonomy[df_genome_counts_taxonomy[group_col] == 'Asgard'].copy()
            if not df_asgard_genome_pie.empty:
                 # Largest phylum first, so the slices are laid out in descending order.
                 df_asgard_genome_pie = df_asgard_genome_pie.sort_values('Count', ascending=False)
                 asgard_genome_units = df_asgard_genome_pie['Taxonomic_Unit']
                 # Any unit absent from the combined colour map falls back to gray.
                 asgard_genome_slice_colors = [combined_taxonomy_color_map.get(unit, 'gray') for unit in asgard_genome_units]
                 asgard_genome_donut = go.Pie(
                     labels=asgard_genome_units,
                     values=df_asgard_genome_pie['Count'],
                     hole=.3,  # non-zero hole draws the pie as a donut
                     marker=dict(colors=asgard_genome_slice_colors),
                 )
                 fig_4a_pie_asgard = go.Figure(data=[asgard_genome_donut])
                 # Shared layout defaults first, then remove the margins for this panel.
                 fig_4a_pie_asgard.update_layout(plotly_layout_defaults)
                 fig_4a_pie_asgard.update_layout(margin=dict(t=0, b=0, l=0, r=0))
                 fig_4a_pie_asgard.update_traces(textinfo='percent+label', insidetextorientation='radial')

                 # Export Figure 4A (Asgard pie). The HTML write is unguarded; the PDF and SVG
                 # writes are best-effort and only print a warning when they fail.
                 fig_4a_pie_asgard_path_html = Path(output_figure_dir) / "figure4a_genomes_by_phylum_pie_asgard.html"
                 fig_4a_pie_asgard.write_html(str(fig_4a_pie_asgard_path_html))
                 print(f"Figure 4A (Asgard Pie) HTML saved to: {fig_4a_pie_asgard_path_html}")

                 fig_4a_pie_asgard_path_pdf = Path(output_figure_dir) / "figure4a_genomes_by_phylum_pie_asgard.pdf"
                 try:
                     fig_4a_pie_asgard.write_image(str(fig_4a_pie_asgard_path_pdf))
                 except Exception as e:
                     print(f"Warning: Could not export Figure 4A (Asgard Pie) to PDF. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")
                 else:
                     print(f"Figure 4A (Asgard Pie) PDF saved to: {fig_4a_pie_asgard_path_pdf}")

                 fig_4a_pie_asgard_path_svg = Path(output_figure_dir) / "figure4a_genomes_by_phylum_pie_asgard.svg"
                 try:
                     fig_4a_pie_asgard.write_image(str(fig_4a_pie_asgard_path_svg))
                 except Exception as e:
                     print(f"Warning: Could not export Figure 4A (Asgard Pie) to SVG. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")
                 else:
                     print(f"Figure 4A (Asgard Pie) SVG saved to: {fig_4a_pie_asgard_path_svg}")

            else:
                 print("No Asgard genome data for pie chart.")

            # GV genome donut chart: one slice per taxonomic unit of the GV group
            df_gv_genome_pie = df_genome_counts_taxonomy.loc[df_genome_counts_taxonomy[group_col] == 'GV'].copy()
            if df_gv_genome_pie.empty:
                print("No GV genome data for pie chart.")
            else:
                # Largest unit first, so the slices are laid out in descending order.
                df_gv_genome_pie = df_gv_genome_pie.sort_values('Count', ascending=False)
                gv_genome_units = df_gv_genome_pie['Taxonomic_Unit']
                # Any unit absent from the combined colour map falls back to gray.
                gv_genome_slice_colors = [combined_taxonomy_color_map.get(unit, 'gray') for unit in gv_genome_units]
                gv_genome_donut = go.Pie(
                    labels=gv_genome_units,
                    values=df_gv_genome_pie['Count'],
                    hole=.3,  # non-zero hole draws the pie as a donut
                    marker=dict(colors=gv_genome_slice_colors),
                )
                fig_4a_pie_gv = go.Figure(data=[gv_genome_donut])
                # Shared layout defaults first, then remove the margins for this panel.
                fig_4a_pie_gv.update_layout(plotly_layout_defaults)
                fig_4a_pie_gv.update_layout(margin=dict(t=0, b=0, l=0, r=0))
                fig_4a_pie_gv.update_traces(textinfo='percent+label', insidetextorientation='radial')

                # Export Figure 4A (GV pie). The HTML write is unguarded; the PDF and SVG
                # writes are best-effort and only print a warning when they fail.
                fig_4a_pie_gv_path_html = Path(output_figure_dir) / "figure4a_genomes_by_family_pie_gv.html"
                fig_4a_pie_gv.write_html(str(fig_4a_pie_gv_path_html))
                print(f"Figure 4A (GV Pie) HTML saved to: {fig_4a_pie_gv_path_html}")

                fig_4a_pie_gv_path_pdf = Path(output_figure_dir) / "figure4a_genomes_by_family_pie_gv.pdf"
                try:
                    fig_4a_pie_gv.write_image(str(fig_4a_pie_gv_path_pdf))
                except Exception as e:
                    print(f"Warning: Could not export Figure 4A (GV Pie) to PDF. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")
                else:
                    print(f"Figure 4A (GV Pie) PDF saved to: {fig_4a_pie_gv_path_pdf}")

                fig_4a_pie_gv_path_svg = Path(output_figure_dir) / "figure4a_genomes_by_family_pie_gv.svg"
                try:
                    fig_4a_pie_gv.write_image(str(fig_4a_pie_gv_path_svg))
                except Exception as e:
                    print(f"Warning: Could not export Figure 4A (GV Pie) to SVG. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")
                else:
                    print(f"Figure 4A (GV Pie) SVG saved to: {fig_4a_pie_gv_path_svg}")

        else:
            print("No genome counts by taxonomy available for plotting Figure 4A.")

    else:
        print(f"Skipping Figure 4A due to missing columns in filtered data: '{source_genome_accession_col}', '{group_col}', '{asgard_phylum_col}', or '{virus_family_col}'.")


    # --- Figure 4B: Number of Proteins by Taxonomy (Stacked Bar and Pie Charts) ---
    print("\n--- Figure 4B: Number of Proteins by Taxonomy ---")

    if protein_id_col in df_filtered_groups.columns and group_col in df_filtered_groups.columns and asgard_phylum_col in df_filtered_groups.columns and virus_family_col in df_filtered_groups.columns:

        # Per-group protein tallies. Each appended frame holds the group column,
        # a 'Taxonomic_Unit' column and a 'Count' column.
        protein_counts_by_group_taxonomy = []

        # Asgard proteins, tallied per phylum
        df_asgard_proteins = df_filtered_groups.loc[df_filtered_groups[group_col] == 'Asgard'].copy()
        if df_asgard_proteins.empty or asgard_phylum_col not in df_asgard_proteins.columns:
            print(f"Warning: Skipping Asgard protein counts by Phylum. Data empty or column '{asgard_phylum_col}' missing.")
        else:
            # count() tallies the non-null protein IDs within each (group, phylum) pair.
            asgard_ids_by_phylum = df_asgard_proteins.groupby([group_col, asgard_phylum_col])[protein_id_col]
            asgard_protein_counts = asgard_ids_by_phylum.count().reset_index(name='Count')
            asgard_protein_counts = asgard_protein_counts.rename(columns={asgard_phylum_col: 'Taxonomic_Unit'})
            protein_counts_by_group_taxonomy.append(asgard_protein_counts)


        # GV proteins, tallied per family
        df_gv_proteins = df_filtered_groups.loc[df_filtered_groups[group_col] == 'GV'].copy()
        if df_gv_proteins.empty or virus_family_col not in df_gv_proteins.columns:
            print(f"Warning: Skipping GV protein counts by Family. Data empty or column '{virus_family_col}' missing.")
        else:
            # count() tallies the non-null protein IDs within each (group, family) pair.
            gv_ids_by_family = df_gv_proteins.groupby([group_col, virus_family_col])[protein_id_col]
            gv_protein_counts = gv_ids_by_family.count().reset_index(name='Count')
            gv_protein_counts = gv_protein_counts.rename(columns={virus_family_col: 'Taxonomic_Unit'})
            protein_counts_by_group_taxonomy.append(gv_protein_counts)


        if protein_counts_by_group_taxonomy:
            # Stack the per-group tallies into a single long table and echo it.
            df_protein_counts_taxonomy = pd.concat(protein_counts_by_group_taxonomy, ignore_index=True)
            print("\nProtein Counts by Group and Taxonomy:")
            print(df_protein_counts_taxonomy.to_markdown(index=False))

            # --- Figure 4B Part 1: Stacked Bar Chart ---
            # Taxonomic units are ordered by their total protein count, largest first.
            taxonomic_unit_total_counts = (
                df_protein_counts_taxonomy
                .groupby('Taxonomic_Unit')['Count']
                .sum()
                .sort_values(ascending=False)
            )
            taxonomic_unit_order = taxonomic_unit_total_counts.index.tolist()

            # One bar per group, stacked and coloured by taxonomic unit.
            stacked_bar_options_4b = dict(
                x=group_col,
                y='Count',
                color='Taxonomic_Unit',
                barmode='stack',
                labels={group_col: 'Group', 'Count': 'Number of Proteins', 'Taxonomic_Unit': 'Phylum or Family'},
                color_discrete_map=combined_taxonomy_color_map,
                category_orders={group_col: ['Asgard', 'GV'], 'Taxonomic_Unit': taxonomic_unit_order},
                template='plotly_white',
            )
            fig_4b_bar = px.bar(df_protein_counts_taxonomy, **stacked_bar_options_4b)

            # Shared layout defaults, explicit axis titles without gridlines, legend title.
            fig_4b_bar.update_layout(plotly_layout_defaults)
            fig_4b_bar.update_xaxes(title_text='Group', showgrid=False)
            fig_4b_bar.update_yaxes(title_text='Number of Proteins', showgrid=False)
            fig_4b_bar.update_layout(legend_title_text='Taxonomic Unit')


            # Export Figure 4B Part 1. The HTML write is unguarded; the PDF and SVG
            # writes are best-effort and only print a warning when they fail.
            fig_4b_bar_path_html = Path(output_figure_dir) / "figure4b_proteins_by_taxonomy_stackedbar.html"
            fig_4b_bar.write_html(str(fig_4b_bar_path_html))
            print(f"Figure 4B (Stacked Bar) HTML saved to: {fig_4b_bar_path_html}")

            fig_4b_bar_path_pdf = Path(output_figure_dir) / "figure4b_proteins_by_taxonomy_stackedbar.pdf"
            try:
                fig_4b_bar.write_image(str(fig_4b_bar_path_pdf))
            except Exception as e:
                print(f"Warning: Could not export Figure 4B (Stacked Bar) to PDF. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")
            else:
                print(f"Figure 4B (Stacked Bar) PDF saved to: {fig_4b_bar_path_pdf}")

            fig_4b_bar_path_svg = Path(output_figure_dir) / "figure4b_proteins_by_taxonomy_stackedbar.svg"
            try:
                fig_4b_bar.write_image(str(fig_4b_bar_path_svg))
            except Exception as e:
                print(f"Warning: Could not export Figure 4B (Stacked Bar) to SVG. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")
            else:
                print(f"Figure 4B (Stacked Bar) SVG saved to: {fig_4b_bar_path_svg}")


            # --- Figure 4B Part 2: Pie Charts (Asgard and GV) ---
            print("\n--- Figure 4B Part 2: Protein Counts Pie Charts (Asgard and GV) ---")

            # Asgard protein donut chart: one slice per taxonomic unit of the Asgard group
            df_asgard_protein_pie = df_protein_counts_taxonomy.loc[df_protein_counts_taxonomy[group_col] == 'Asgard'].copy()
            if df_asgard_protein_pie.empty:
                print("No Asgard protein data for pie chart.")
            else:
                # Largest unit first, so the slices are laid out in descending order.
                df_asgard_protein_pie = df_asgard_protein_pie.sort_values('Count', ascending=False)
                asgard_protein_units = df_asgard_protein_pie['Taxonomic_Unit']
                # Any unit absent from the combined colour map falls back to gray.
                asgard_protein_slice_colors = [combined_taxonomy_color_map.get(unit, 'gray') for unit in asgard_protein_units]
                asgard_protein_donut = go.Pie(
                    labels=asgard_protein_units,
                    values=df_asgard_protein_pie['Count'],
                    hole=.3,  # non-zero hole draws the pie as a donut
                    marker=dict(colors=asgard_protein_slice_colors),
                )
                fig_4b_pie_asgard = go.Figure(data=[asgard_protein_donut])
                # Shared layout defaults first, then remove the margins for this panel.
                fig_4b_pie_asgard.update_layout(plotly_layout_defaults)
                fig_4b_pie_asgard.update_layout(margin=dict(t=0, b=0, l=0, r=0))
                fig_4b_pie_asgard.update_traces(textinfo='percent+label', insidetextorientation='radial')

                # Export Figure 4B (Asgard pie). The HTML write is unguarded; the PDF and SVG
                # writes are best-effort and only print a warning when they fail.
                fig_4b_pie_asgard_path_html = Path(output_figure_dir) / "figure4b_proteins_by_phylum_pie_asgard.html"
                fig_4b_pie_asgard.write_html(str(fig_4b_pie_asgard_path_html))
                print(f"Figure 4B (Asgard Pie) HTML saved to: {fig_4b_pie_asgard_path_html}")

                fig_4b_pie_asgard_path_pdf = Path(output_figure_dir) / "figure4b_proteins_by_phylum_pie_asgard.pdf"
                try:
                    fig_4b_pie_asgard.write_image(str(fig_4b_pie_asgard_path_pdf))
                except Exception as e:
                    print(f"Warning: Could not export Figure 4B (Asgard Pie) to PDF. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")
                else:
                    print(f"Figure 4B (Asgard Pie) PDF saved to: {fig_4b_pie_asgard_path_pdf}")

                fig_4b_pie_asgard_path_svg = Path(output_figure_dir) / "figure4b_proteins_by_phylum_pie_asgard.svg"
                try:
                    fig_4b_pie_asgard.write_image(str(fig_4b_pie_asgard_path_svg))
                except Exception as e:
                    print(f"Warning: Could not export Figure 4B (Asgard Pie) to SVG. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")
                else:
                    print(f"Figure 4B (Asgard Pie) SVG saved to: {fig_4b_pie_asgard_path_svg}")

            # GV protein donut chart: one slice per taxonomic unit of the GV group
            df_gv_protein_pie = df_protein_counts_taxonomy.loc[df_protein_counts_taxonomy[group_col] == 'GV'].copy()
            if df_gv_protein_pie.empty:
                print("No GV protein data for pie chart.")
            else:
                # Largest unit first, so the slices are laid out in descending order.
                df_gv_protein_pie = df_gv_protein_pie.sort_values('Count', ascending=False)
                gv_protein_units = df_gv_protein_pie['Taxonomic_Unit']
                # Any unit absent from the combined colour map falls back to gray.
                gv_protein_slice_colors = [combined_taxonomy_color_map.get(unit, 'gray') for unit in gv_protein_units]
                gv_protein_donut = go.Pie(
                    labels=gv_protein_units,
                    values=df_gv_protein_pie['Count'],
                    hole=.3,  # non-zero hole draws the pie as a donut
                    marker=dict(colors=gv_protein_slice_colors),
                )
                fig_4b_pie_gv = go.Figure(data=[gv_protein_donut])
                # Shared layout defaults first, then remove the margins for this panel.
                fig_4b_pie_gv.update_layout(plotly_layout_defaults)
                fig_4b_pie_gv.update_layout(margin=dict(t=0, b=0, l=0, r=0))
                fig_4b_pie_gv.update_traces(textinfo='percent+label', insidetextorientation='radial')

                # Export Figure 4B (GV pie). The HTML write is unguarded; the PDF and SVG
                # writes are best-effort and only print a warning when they fail.
                fig_4b_pie_gv_path_html = Path(output_figure_dir) / "figure4b_proteins_by_family_pie_gv.html"
                fig_4b_pie_gv.write_html(str(fig_4b_pie_gv_path_html))
                print(f"Figure 4B (GV Pie) HTML saved to: {fig_4b_pie_gv_path_html}")

                fig_4b_pie_gv_path_pdf = Path(output_figure_dir) / "figure4b_proteins_by_family_pie_gv.pdf"
                try:
                    fig_4b_pie_gv.write_image(str(fig_4b_pie_gv_path_pdf))
                except Exception as e:
                    print(f"Warning: Could not export Figure 4B (GV Pie) to PDF. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")
                else:
                    print(f"Figure 4B (GV Pie) PDF saved to: {fig_4b_pie_gv_path_pdf}")

                fig_4b_pie_gv_path_svg = Path(output_figure_dir) / "figure4b_proteins_by_family_pie_gv.svg"
                try:
                    fig_4b_pie_gv.write_image(str(fig_4b_pie_gv_path_svg))
                except Exception as e:
                    print(f"Warning: Could not export Figure 4B (GV Pie) to SVG. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")
                else:
                    print(f"Figure 4B (GV Pie) SVG saved to: {fig_4b_pie_gv_path_svg}")


        else:
            print("No protein counts by taxonomy available for plotting Figure 4B.")

    else:
        print(f"Skipping Figure 4B due to missing columns in filtered data: '{protein_id_col}', '{group_col}', '{asgard_phylum_col}', or '{virus_family_col}'.")


print("\n\n--- Figure 4 Generation Complete ---")
print(f"Figures saved to the '{output_figure_dir}' directory (HTML, PDF, SVG).")


In [None]:
# Cell 7: Figure 5 - Intra-Orthogroup Divergence (APSI) Landscape

# --- Imports ---
# Assumes pandas, numpy, plotly.express, plotly.graph_objects, plotly.io,
# and necessary variables/helper functions (like df_full, group_col,
# orthogroup_col, apsi_col, num_og_sequences_col, broad_func_cat_col,
# structurally_dark_col, esp_col, original_seq_length_col, percent_disorder_col,
# arcadia_colors_manual, arcadia_primary_palette, arcadia_secondary_palette,
# arcadia_neutrals_palette, output_figure_dir, output_figure_summary_dir,
# plotly_layout_defaults, broad_category_color_map)
# are defined in the setup cell (Cell 1).
# Assumes df_filtered_groups is available from Cell 2 (Figure 2 generation) and
# correctly filters for 'Asgard' and 'GV' groups.
# Assumes APSI data and 'Has_Conserved_Motif' flag are intended to be merged into df_full in Cell 1.

print("\n\n--- Generating Figure 5: Intra-Orthogroup Divergence (APSI) Landscape (Cell 7) ---")

# Ensure df_full is loaded and filtered groups are available
if 'df_full' not in locals() or df_full.empty:
    print("ERROR: df_full not loaded. Please run the setup cell (Cell 1).")
elif 'df_filtered_groups' not in locals() or df_filtered_groups.empty:
     print("ERROR: df_filtered_groups not available or is empty. Please run Cell 2 first.")
else:
    # Colours for the two groups, used by the Figure 5 panels in this cell.
    # The hex codes are fallbacks for when the palette dict lacks the named colour.
    group_colors = {
        'Asgard': arcadia_colors_manual.get('aegean', '#5088C5'),
        'GV': arcadia_colors_manual.get('amber', '#F28360')
    }
    # Look the GV colour up under its dict key 'GV'. 'Amber' is only the palette
    # colour name, is not a key of group_colors, and made this line print "GV=None".
    print(f"\nUsing Group Colors: Asgard={group_colors.get('Asgard')}, GV={group_colors.get('GV')}")

    # --- Verify that every column Figure 5 relies on exists in df_full ---
    # Core APSI columns plus the group column.
    required_core_cols = [orthogroup_col, apsi_col, num_og_sequences_col, group_col]

    # Panel-specific columns: the conserved-motif flag (Figure 5G) followed by the
    # columns used by the individual plots 5D, 5E and 5F.
    panel_specific_cols = [
        'Has_Conserved_Motif',
        broad_func_cat_col,
        structurally_dark_col,
        original_seq_length_col,
        percent_disorder_col,
    ]

    # Collected in check order: core columns first, then the panel-specific ones.
    missing_core_cols = [col for col in required_core_cols + panel_specific_cols if col not in df_full.columns]


    if missing_core_cols:
        print(f"ERROR: Required columns for Figure 5 analysis not found in df_full: {', '.join(missing_core_cols)}")
        print("Please ensure the main database is loaded correctly and includes these columns.")
        print("Also, ensure APSI data and 'Has_Conserved_Motif' flag were correctly merged into df_full in the setup cell (Cell 1).")
        print("Consider restarting the kernel and running all cells.")
    else:
        # Build df_apsi_data: one row per orthogroup that (a) occurs in
        # df_filtered_groups and (b) has a non-null APSI value.
        # NOTE(review): the row kept for each orthogroup is its first matching row in
        # df_full, so the group value on that row is whatever df_full carries there --
        # confirm it is always 'Asgard' or 'GV' if df_full also holds other groups.
        if orthogroup_col not in df_full.columns:
            print(f"ERROR: Orthogroup column '{orthogroup_col}' not found in df_full. Cannot filter for APSI data.")
            df_apsi_data = pd.DataFrame() # Empty frame so the emptiness check below skips plotting
        else:
            ogs_in_filtered_groups = df_filtered_groups[orthogroup_col].unique()
            og_is_in_filtered_groups = df_full[orthogroup_col].isin(ogs_in_filtered_groups)
            apsi_is_available = df_full[apsi_col].notna()
            df_apsi_data = (
                df_full[og_is_in_filtered_groups & apsi_is_available]
                .drop_duplicates(subset=[orthogroup_col])
                .copy()
            )


        if df_apsi_data.empty:
            print("WARNING: No orthogroups with calculated APSI found after filtering. Cannot generate Figure 5 plots.")
        else:
            print(f"\nFound {len(df_apsi_data)} orthogroups with calculated APSI for plotting.")

            # --- Figure 5A: Overall APSI Distribution (Histogram) ---
            print("\n--- Figure 5A: Overall APSI Distribution (Histogram) ---")
            print("Note: The distribution may be skewed; consider transformations or alternative plot types if needed.")

            # Histogram of the APSI column over every row of df_apsi_data.
            histogram_options_5a = dict(
                x=apsi_col,
                nbins=50,
                labels={apsi_col: 'Average Pairwise Sequence Identity (APSI)', 'count': 'Number of Orthogroups'},
                template='plotly_white',
            )
            fig_5a = px.histogram(df_apsi_data, **histogram_options_5a)

            # Shared layout defaults, then explicit axis titles without gridlines.
            fig_5a.update_layout(plotly_layout_defaults)
            fig_5a.update_xaxes(title_text='Average Pairwise Sequence Identity (APSI)', showgrid=False)
            fig_5a.update_yaxes(title_text='Number of Orthogroups', showgrid=False)


            # Export Figure 5A. The HTML write is unguarded; the PDF and SVG writes
            # are best-effort and only print a warning when they fail.
            fig_5a_path_html = Path(output_figure_dir) / "figure5a_overall_apsi_distribution_histogram.html"
            fig_5a.write_html(str(fig_5a_path_html))
            print(f"Figure 5A HTML saved to: {fig_5a_path_html}")

            fig_5a_path_pdf = Path(output_figure_dir) / "figure5a_overall_apsi_distribution_histogram.pdf"
            try:
                fig_5a.write_image(str(fig_5a_path_pdf))
            except Exception as e:
                print(f"Warning: Could not export Figure 5A to PDF. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")
            else:
                print(f"Figure 5A PDF saved to: {fig_5a_path_pdf}")

            fig_5a_path_svg = Path(output_figure_dir) / "figure5a_overall_apsi_distribution_histogram.svg"
            try:
                fig_5a.write_image(str(fig_5a_path_svg))
            except Exception as e:
                print(f"Warning: Could not export Figure 5A to SVG. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")
            else:
                print(f"Figure 5A SVG saved to: {fig_5a_path_svg}")


            # --- Figure 5B: APSI Comparison (Asgard vs. GV) (Violin Plot) ---
            print("\n--- Figure 5B: APSI Comparison (Asgard vs. GV) (Violin Plot) ---")

            # One violin per group, with an embedded box plot and every point drawn.
            violin_options_5b = dict(
                x=group_col,
                y=apsi_col,
                color=group_col,
                labels={group_col: 'Group', apsi_col: 'Average Pairwise Sequence Identity (APSI)'},
                color_discrete_map=group_colors,
                category_orders={group_col: ['Asgard', 'GV']},
                box=True,
                points="all",
                template='plotly_white',
            )
            fig_5b = px.violin(df_apsi_data, **violin_options_5b)

            # Shared layout defaults, explicit axis titles without gridlines, and no
            # legend because the x-axis already names the groups.
            fig_5b.update_layout(plotly_layout_defaults)
            fig_5b.update_xaxes(title_text='Group', showgrid=False)
            fig_5b.update_yaxes(title_text='Average Pairwise Sequence Identity (APSI)', showgrid=False)
            fig_5b.update_layout(showlegend=False)


            # Export Figure 5B. The HTML write is unguarded; the PDF and SVG writes
            # are best-effort and only print a warning when they fail.
            fig_5b_path_html = Path(output_figure_dir) / "figure5b_apsi_comparison_group_violinplot.html"
            fig_5b.write_html(str(fig_5b_path_html))
            print(f"Figure 5B HTML saved to: {fig_5b_path_html}")

            fig_5b_path_pdf = Path(output_figure_dir) / "figure5b_apsi_comparison_group_violinplot.pdf"
            try:
                fig_5b.write_image(str(fig_5b_path_pdf))
            except Exception as e:
                print(f"Warning: Could not export Figure 5B to PDF. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")
            else:
                print(f"Figure 5B PDF saved to: {fig_5b_path_pdf}")

            fig_5b_path_svg = Path(output_figure_dir) / "figure5b_apsi_comparison_group_violinplot.svg"
            try:
                fig_5b.write_image(str(fig_5b_path_svg))
            except Exception as e:
                print(f"Warning: Could not export Figure 5B to SVG. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")
            else:
                print(f"Figure 5B SVG saved to: {fig_5b_path_svg}")


            # --- Figure 5C: APSI vs. Orthogroup Size ---
            print("\n--- Figure 5C: APSI vs. Orthogroup Size ---")

            if num_og_sequences_col not in df_apsi_data.columns:
                print(f"Skipping Figure 5C due to missing column: '{num_og_sequences_col}'.")
            else:
                # Scatter of orthogroup size (x) against APSI (y), coloured by group.
                scatter_options_5c = dict(
                    x=num_og_sequences_col,
                    y=apsi_col,
                    color=group_col,
                    labels={num_og_sequences_col: 'Orthogroup Size (Number of Proteins)', apsi_col: 'Average Pairwise Sequence Identity (APSI)', group_col: 'Group'},
                    color_discrete_map=group_colors,
                    template='plotly_white',
                )
                fig_5c = px.scatter(df_apsi_data, **scatter_options_5c)

                # Shared layout defaults, explicit axis titles without gridlines, legend title.
                fig_5c.update_layout(plotly_layout_defaults)
                fig_5c.update_xaxes(title_text='Orthogroup Size (Number of Proteins)', showgrid=False)
                fig_5c.update_yaxes(title_text='Average Pairwise Sequence Identity (APSI)', showgrid=False)
                fig_5c.update_layout(legend_title_text='Group')


                # Export Figure 5C. The HTML write is unguarded; the PDF and SVG writes
                # are best-effort and only print a warning when they fail.
                fig_5c_path_html = Path(output_figure_dir) / "figure5c_apsi_vs_og_size_scatter.html"
                fig_5c.write_html(str(fig_5c_path_html))
                print(f"Figure 5C HTML saved to: {fig_5c_path_html}")

                fig_5c_path_pdf = Path(output_figure_dir) / "figure5c_apsi_vs_og_size_scatter.pdf"
                try:
                    fig_5c.write_image(str(fig_5c_path_pdf))
                except Exception as e:
                    print(f"Warning: Could not export Figure 5C to PDF. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")
                else:
                    print(f"Figure 5C PDF saved to: {fig_5c_path_pdf}")

                fig_5c_path_svg = Path(output_figure_dir) / "figure5c_apsi_vs_og_size_scatter.svg"
                try:
                    fig_5c.write_image(str(fig_5c_path_svg))
                except Exception as e:
                    print(f"Warning: Could not export Figure 5C to SVG. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")
                else:
                    print(f"Figure 5C SVG saved to: {fig_5c_path_svg}")


            # --- Figure 5D: APSI by Broad Functional Category ---
            print("\n--- Figure 5D: APSI by Broad Functional Category ---")

            if broad_func_cat_col not in df_full.columns:
                print(f"Skipping Figure 5D due to missing column in df_full: '{broad_func_cat_col}'.")
            else:
                # One functional category per orthogroup: the category found on the
                # first df_full row of each orthogroup.
                first_row_per_og = df_full.drop_duplicates(subset=[orthogroup_col], keep='first')
                df_og_categories = first_row_per_og[[orthogroup_col, broad_func_cat_col]].copy()

                if broad_func_cat_col not in df_og_categories.columns:
                    print(f"Skipping Figure 5D merge and plot: '{broad_func_cat_col}' column not found in merged OG categories.")
                else:
                    df_apsi_with_categories = pd.merge(df_apsi_data, df_og_categories, on=orthogroup_col, how='left')

                    # --- Debugging Print Statement ---
                    print(f"\nDEBUG: Columns in df_apsi_with_categories before filtering categories: {df_apsi_with_categories.columns.tolist()}")
                    # --- End Debugging Print Statement ---

                    # Categories left out of the plot.
                    categories_to_exclude = ['Unknown/Unclassified', 'Other Specific Annotation']
                    # When both merge inputs carry the category column, pandas suffixes the
                    # two copies; '_y' is the one coming from df_og_categories (right frame).
                    functional_category_col_merged = broad_func_cat_col + '_y'

                    if functional_category_col_merged not in df_apsi_with_categories.columns:
                        print(f"Skipping Figure 5D plot: Expected merged column '{functional_category_col_merged}' not found in df_apsi_with_categories.")
                    else:
                        is_excluded_category = df_apsi_with_categories[functional_category_col_merged].isin(categories_to_exclude)
                        df_apsi_with_categories_filtered = df_apsi_with_categories[~is_excluded_category].copy()

                        if df_apsi_with_categories_filtered.empty:
                            print("No APSI data with filtered functional categories available for plotting Figure 5D.")
                        else:
                            # Categories are ordered by how often they occur in the filtered data.
                            category_order_5d = df_apsi_with_categories_filtered[functional_category_col_merged].value_counts().index.tolist()

                            # One box per functional category and group.
                            box_options_5d = dict(
                                x=functional_category_col_merged,
                                y=apsi_col,
                                color=group_col,
                                labels={functional_category_col_merged: 'Broad Functional Category', apsi_col: 'Average Pairwise Sequence Identity (APSI)', group_col: 'Group'},
                                color_discrete_map=group_colors,
                                category_orders={functional_category_col_merged: category_order_5d, group_col: ['Asgard', 'GV']},
                                template='plotly_white',
                            )
                            fig_5d = px.box(df_apsi_with_categories_filtered, **box_options_5d)

                            # Shared layout defaults, axis titles without gridlines, legend
                            # title, and angled category tick labels.
                            fig_5d.update_layout(plotly_layout_defaults)
                            fig_5d.update_xaxes(title_text='Broad Functional Category', showgrid=False)
                            fig_5d.update_yaxes(title_text='Average Pairwise Sequence Identity (APSI)', showgrid=False)
                            fig_5d.update_layout(legend_title_text='Group')
                            fig_5d.update_xaxes(tickangle=45)


                            # Export Figure 5D. The HTML write is unguarded; the PDF and SVG
                            # writes are best-effort and only print a warning when they fail.
                            fig_5d_path_html = Path(output_figure_dir) / "figure5d_apsi_by_functional_category_boxplot.html"
                            fig_5d.write_html(str(fig_5d_path_html))
                            print(f"Figure 5D HTML saved to: {fig_5d_path_html}")

                            fig_5d_path_pdf = Path(output_figure_dir) / "figure5d_apsi_by_functional_category_boxplot.pdf"
                            try:
                                fig_5d.write_image(str(fig_5d_path_pdf))
                            except Exception as e:
                                print(f"Warning: Could not export Figure 5D to PDF. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")
                            else:
                                print(f"Figure 5D PDF saved to: {fig_5d_path_pdf}")

                            fig_5d_path_svg = Path(output_figure_dir) / "figure5d_apsi_by_functional_category_boxplot.svg"
                            try:
                                fig_5d.write_image(str(fig_5d_path_svg))
                            except Exception as e:
                                print(f"Warning: Could not export Figure 5D to SVG. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")
                            else:
                                print(f"Figure 5D SVG saved to: {fig_5d_path_svg}")


            # (Still Cell 7: Figures 5A, 5B, 5C and 5D are generated above; Figure 5E follows.)

            # --- Figure 5E: APSI by Structural Status ---
            print("\n--- Figure 5E: APSI by Structural Status ---")

            # df_apsi_data already contains 'Is_Structurally_Dark' for each OG,
            # so no need to merge it in again.
            # We will use a copy of df_apsi_data to avoid modifying it if it's used elsewhere.
            df_apsi_with_structural_status = df_apsi_data.copy()

            # Check if structurally_dark_col (Is_Structurally_Dark) exists
            if structurally_dark_col in df_apsi_with_structural_status.columns:
                 # Record which orthogroups have a structural status *before* the cast below:
                 # astype(bool) converts NaN to True, which would label unknown-status
                 # orthogroups as 'Structurally Dark' and leave the dropna further down
                 # with nothing to remove.
                 structural_status_known = df_apsi_with_structural_status[structurally_dark_col].notna()

                 # Ensure structural status column is boolean (kept as before for any later use).
                 # NOTE(review): assumes the column holds real booleans or NaN; string values
                 # such as 'False' would all cast to True -- confirm against the loader.
                 df_apsi_with_structural_status[structurally_dark_col] = df_apsi_with_structural_status[structurally_dark_col].astype(bool)

                 # Define labels for True/False structural status
                 structural_status_labels = {True: 'Structurally Dark', False: 'Has Known Structure'}

                 # Add a column with string labels for plotting; rows whose status was
                 # missing get NaN instead of a misleading 'Structurally Dark' label.
                 df_apsi_with_structural_status['Structural_Status_Label'] = (
                     df_apsi_with_structural_status[structurally_dark_col]
                     .map(structural_status_labels)
                     .where(structural_status_known)
                 )

                 # Drop orthogroups without a structural status label or without an APSI value.
                 df_apsi_with_structural_status_filtered = df_apsi_with_structural_status.dropna(subset=['Structural_Status_Label', apsi_col]).copy()


                 if not df_apsi_with_structural_status_filtered.empty:
                     fig_5e = px.box( # Using box plot as it's standard for comparing distributions across categories
                         df_apsi_with_structural_status_filtered,
                         x='Structural_Status_Label',
                         y=apsi_col,
                         color=group_col, # Color by group (Asgard/GV)
                         # title='APSI by Structural Status', # No title
                         labels={'Structural_Status_Label': 'Structural Status', apsi_col: 'Average Pairwise Sequence Identity (APSI)', group_col: 'Group'},
                         color_discrete_map=group_colors, # Defined in Cell 1
                         category_orders={'Structural_Status_Label': ['Structurally Dark', 'Has Known Structure'], group_col: ['Asgard', 'GV']}, # Apply consistent order
                         template='plotly_white'
                     )
                     # Apply default layout and remove gridlines
                     fig_5e.update_layout(plotly_layout_defaults) # Defined in Cell 1
                     fig_5e.update_xaxes(title_text='Structural Status', showgrid=False)
                     fig_5e.update_yaxes(title_text='Average Pairwise Sequence Identity (APSI)', showgrid=False)
                     fig_5e.update_layout(legend_title_text='Group') # Add legend title


                     # Export Figure 5E (HTML, PDF, SVG)
                     fig_5e_path_html = Path(output_figure_dir) / "figure5e_apsi_by_structural_status_boxplot.html" # Defined in Cell 1
                     fig_5e.write_html(str(fig_5e_path_html))
                     print(f"Figure 5E HTML saved to: {fig_5e_path_html}")
                     try:
                         fig_5e_path_pdf = Path(output_figure_dir) / "figure5e_apsi_by_structural_status_boxplot.pdf"
                         fig_5e.write_image(str(fig_5e_path_pdf))
                         print(f"Figure 5E PDF saved to: {fig_5e_path_pdf}")
                     except Exception as e:
                         print(f"Warning: Could not export Figure 5E to PDF. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")
                     try:
                         fig_5e_path_svg = Path(output_figure_dir) / "figure5e_apsi_by_structural_status_boxplot.svg"
                         fig_5e.write_image(str(fig_5e_path_svg))
                         print(f"Figure 5E SVG saved to: {fig_5e_path_svg}")
                     except Exception as e:
                         print(f"Warning: Could not export Figure 5E to SVG. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")

                 else:
                     print("No APSI data with structural status available for plotting Figure 5E after filtering NaNs.")
            else:
                print(f"Skipping Figure 5E: Column '{structurally_dark_col}' not found in df_apsi_data.")

            # --- Figure 5F: APSI vs. Protein Length/Disorder ---
            # Two scatter plots of per-orthogroup APSI: one against mean protein length,
            # one against mean percent disorder. Both means are computed per orthogroup
            # from df_filtered_groups and left-joined onto df_apsi_data.
            print("\n--- Figure 5F: APSI vs. Protein Length/Disorder ---")

            # Both per-protein feature columns must be present before any averaging is attempted
            if original_seq_length_col not in df_filtered_groups.columns or percent_disorder_col not in df_filtered_groups.columns:
                print(f"Skipping Figure 5F due to missing columns in filtered data: '{original_seq_length_col}' or '{percent_disorder_col}'.")
            else:
                # One row per orthogroup with the mean length and mean disorder of its proteins
                df_og_avg_features = (
                    df_filtered_groups
                    .groupby(orthogroup_col)
                    .agg(
                        Avg_Length=(original_seq_length_col, 'mean'),
                        Avg_Disorder=(percent_disorder_col, 'mean'),
                    )
                    .reset_index()
                )

                # Attach the averages to the APSI table (every APSI row is kept)
                df_apsi_with_avg_features = df_apsi_data.merge(df_og_avg_features, on=orthogroup_col, how='left')

                # Axis text shared by both scatter plots
                apsi_axis_title = 'Average Pairwise Sequence Identity (APSI)'

                # ---- Scatter plot: APSI vs. Average Length ----
                if 'Avg_Length' not in df_apsi_with_avg_features.columns:
                    print("Skipping Figure 5F (Length) due to missing 'Avg_Length' column after merge.")
                else:
                    length_axis_title = 'Average Protein Length per Orthogroup'
                    fig_5f_length = px.scatter(
                        df_apsi_with_avg_features,
                        x='Avg_Length',
                        y=apsi_col,
                        color=group_col,  # one colour per group
                        labels={'Avg_Length': length_axis_title, apsi_col: apsi_axis_title, group_col: 'Group'},
                        color_discrete_map=group_colors,
                        template='plotly_white',
                    )
                    # Shared layout defaults plus legend title, then axes without gridlines
                    fig_5f_length.update_layout(plotly_layout_defaults, legend_title_text='Group')
                    fig_5f_length.update_xaxes(title_text=length_axis_title, showgrid=False)
                    fig_5f_length.update_yaxes(title_text=apsi_axis_title, showgrid=False)

                    # Export targets (HTML, PDF, SVG)
                    fig_5f_length_path_html = Path(output_figure_dir) / "figure5f_apsi_vs_avg_length_scatter.html"
                    fig_5f_length_path_pdf = Path(output_figure_dir) / "figure5f_apsi_vs_avg_length_scatter.pdf"
                    fig_5f_length_path_svg = Path(output_figure_dir) / "figure5f_apsi_vs_avg_length_scatter.svg"

                    fig_5f_length.write_html(str(fig_5f_length_path_html))
                    print(f"Figure 5F (Length) HTML saved to: {fig_5f_length_path_html}")

                    # Static exports are attempted one by one; a failure only produces a warning
                    for image_format, image_path in (("PDF", fig_5f_length_path_pdf), ("SVG", fig_5f_length_path_svg)):
                        try:
                            fig_5f_length.write_image(str(image_path))
                            print(f"Figure 5F (Length) {image_format} saved to: {image_path}")
                        except Exception as e:
                            print(f"Warning: Could not export Figure 5F (Length) to {image_format}. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")

                # ---- Scatter plot: APSI vs. Average Disorder ----
                if 'Avg_Disorder' not in df_apsi_with_avg_features.columns:
                    print("Skipping Figure 5F (Disorder) due to missing 'Avg_Disorder' column after merge.")
                else:
                    disorder_axis_title = 'Average Protein Disorder (%) per Orthogroup'
                    fig_5f_disorder = px.scatter(
                        df_apsi_with_avg_features,
                        x='Avg_Disorder',
                        y=apsi_col,
                        color=group_col,  # one colour per group
                        labels={'Avg_Disorder': disorder_axis_title, apsi_col: apsi_axis_title, group_col: 'Group'},
                        color_discrete_map=group_colors,
                        template='plotly_white',
                    )
                    # Shared layout defaults plus legend title, then axes without gridlines
                    fig_5f_disorder.update_layout(plotly_layout_defaults, legend_title_text='Group')
                    fig_5f_disorder.update_xaxes(title_text=disorder_axis_title, showgrid=False)
                    fig_5f_disorder.update_yaxes(title_text=apsi_axis_title, showgrid=False)

                    # Export targets (HTML, PDF, SVG)
                    fig_5f_disorder_path_html = Path(output_figure_dir) / "figure5f_apsi_vs_avg_disorder_scatter.html"
                    fig_5f_disorder_path_pdf = Path(output_figure_dir) / "figure5f_apsi_vs_avg_disorder_scatter.pdf"
                    fig_5f_disorder_path_svg = Path(output_figure_dir) / "figure5f_apsi_vs_avg_disorder_scatter.svg"

                    fig_5f_disorder.write_html(str(fig_5f_disorder_path_html))
                    print(f"Figure 5F (Disorder) HTML saved to: {fig_5f_disorder_path_html}")

                    # Static exports are attempted one by one; a failure only produces a warning
                    for image_format, image_path in (("PDF", fig_5f_disorder_path_pdf), ("SVG", fig_5f_disorder_path_svg)):
                        try:
                            fig_5f_disorder.write_image(str(image_path))
                            print(f"Figure 5F (Disorder) {image_format} saved to: {image_path}")
                        except Exception as e:
                            print(f"Warning: Could not export Figure 5F (Disorder) to {image_format}. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")


            # --- Figure 5G: APSI and Conserved Motif Presence ---
            # Box plot of per-orthogroup APSI split by conserved-motif presence and coloured by group.
            print("\n--- Figure 5G: APSI and Conserved Motif Presence ---")

            if 'Has_Conserved_Motif' in df_apsi_data.columns:  # Check if the flag was merged
                # Cast the flag to boolean while keeping missing values missing.
                # A plain astype(bool) converts NaN to True, which would count OGs with no
                # motif information as 'Has Conserved Motif' and make the dropna() below a no-op.
                # Keeping NaN in the column also makes this cell safe to re-run.
                motif_flag_raw = df_apsi_data['Has_Conserved_Motif']
                df_apsi_data['Has_Conserved_Motif'] = motif_flag_raw.astype(bool).where(motif_flag_raw.notna())

                # Human-readable labels for True/False motif presence
                motif_presence_labels = {True: 'Has Conserved Motif', False: 'No Conserved Motif'}

                # Missing flags are not keys of the mapping, so they stay NaN in the label column
                df_apsi_data['Motif_Presence_Label'] = df_apsi_data['Has_Conserved_Motif'].map(motif_presence_labels)

                # Drop OGs with no motif information
                df_apsi_with_motif_filtered = df_apsi_data.dropna(subset=['Motif_Presence_Label']).copy()

                if not df_apsi_with_motif_filtered.empty:
                    fig_5g = px.box(  # Box plot: standard for comparing distributions across categories
                        df_apsi_with_motif_filtered,
                        x='Motif_Presence_Label',
                        y=apsi_col,
                        color=group_col,  # Color by group
                        labels={'Motif_Presence_Label': 'Conserved Motif Presence', apsi_col: 'Average Pairwise Sequence Identity (APSI)', group_col: 'Group'},
                        color_discrete_map=group_colors,
                        category_orders={'Motif_Presence_Label': ['Has Conserved Motif', 'No Conserved Motif'], group_col: ['Asgard', 'GV']},  # Consistent order
                        template='plotly_white'
                    )
                    # Apply default layout and remove gridlines
                    fig_5g.update_layout(plotly_layout_defaults)
                    fig_5g.update_xaxes(title_text='Conserved Motif Presence', showgrid=False)
                    fig_5g.update_yaxes(title_text='Average Pairwise Sequence Identity (APSI)', showgrid=False)
                    fig_5g.update_layout(legend_title_text='Group')  # Add legend title

                    # Export Figure 5G (HTML, PDF, SVG)
                    fig_5g_path_html = Path(output_figure_dir) / "figure5g_apsi_by_motif_presence_boxplot.html"
                    fig_5g_path_pdf = Path(output_figure_dir) / "figure5g_apsi_by_motif_presence_boxplot.pdf"
                    fig_5g_path_svg = Path(output_figure_dir) / "figure5g_apsi_by_motif_presence_boxplot.svg"

                    fig_5g.write_html(str(fig_5g_path_html))
                    print(f"Figure 5G HTML saved to: {fig_5g_path_html}")

                    # Static image export is best-effort: a failure is reported and the run continues.
                    for image_format, image_path in (("PDF", fig_5g_path_pdf), ("SVG", fig_5g_path_svg)):
                        try:
                            fig_5g.write_image(str(image_path))
                            print(f"Figure 5G {image_format} saved to: {image_path}")
                        except Exception as e:
                            print(f"Warning: Could not export Figure 5G to {image_format}. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}")

                else:
                    print("No APSI data with motif presence information available for plotting Figure 5G.")
            else:
                print("Skipping Figure 5G due to missing column: 'Has_Conserved_Motif'.")


# Closing banner for Cell 7: announce completion and where the figure files were written.
figure5_completion_lines = (
    "\n\n--- Figure 5 Generation Complete ---",
    f"Figures saved to the '{output_figure_dir}' directory (HTML, PDF, SVG).",
)
print("\n".join(figure5_completion_lines))


In [None]:
# Cell 8: Figure 4 (from Pub Ideas) - Conserved Motif Analysis
# Corresponds to Figure 6 in the notebook's sequence if Cell 7 was Figure 5.
# Purpose: summarise the motif table (df_motifs) as a top-N motif frequency bar chart
# and a motif length histogram, exporting each figure as HTML, PDF and SVG.

# --- Imports & Setup Assumptions ---
# This cell assumes that pandas (pd), numpy (np), plotly.express (px),
# plotly.graph_objects (go), plotly.io (pio), Path from pathlib,
# and all necessary variables/helper functions from Cell 1 (Setup) are available.
# Key variables from Cell 1:
# - df_motifs: DataFrame containing motif data.
# - output_figure_dir: Directory to save plots.
# - output_figure_summary_dir: Directory to save the motif frequency CSV summary
#   (used below; presumably also defined in Cell 1 - TODO confirm).
# - plotly_layout_defaults: Default layout for Plotly figures.
# - arcadia_primary_palette: Color palette.
# - group_colors: Color map for Asgard/GV.
# - orthogroup_col, group_col: Column names.

print("\n\n--- Generating Figure 6 (Pub Ideas): Conserved Motif Analysis (Cell 8) ---")

# --- Check if df_motifs is available and not empty ---
# The `or` short-circuits, so df_motifs.empty is only evaluated when the name exists.
if 'df_motifs' not in locals() or df_motifs.empty:
    print("ERROR: DataFrame 'df_motifs' not found or is empty. Please run Cell 1 first.")
    # Create an empty DataFrame to prevent errors if subsequent code relies on it, though plots won't generate.
    df_motifs = pd.DataFrame(columns=['Motif', 'Orthogroup']) 
else:
    # Everything that follows in this cell at this indentation level belongs to this
    # `else` branch, i.e. it only runs when df_motifs exists and has rows.
    print(f"Using 'df_motifs' (shape: {df_motifs.shape}) for motif analysis.")

    # --- Figure 6A (Pub Ideas): Frequency of Top Conserved Motifs ---
    # Counts how often each motif string occurs in df_motifs, prints and plots the
    # top N, and saves the complete frequency table as CSV.
    print("\n--- Figure 6A (Pub Ideas): Frequency of Top Conserved Motifs ---")

    if 'Motif' in df_motifs.columns and not df_motifs.empty:
        top_n_motifs_display = 30  # Number of top motifs to display
        motif_counts = df_motifs['Motif'].value_counts()

        # Name both columns explicitly instead of renaming whatever reset_index() produces:
        # value_counts() names its result differently across pandas versions ('Motif' with an
        # unnamed index before 2.0, 'count' with index name 'Motif' from 2.0 on), so a rename
        # keyed on 'index'/'Motif' yields wrong CSV headers on newer pandas.
        df_motif_counts_all = motif_counts.rename_axis('Motif').reset_index(name='Frequency')
        df_top_motifs_plot = df_motif_counts_all.head(top_n_motifs_display).copy()

        print(f"\nTop {top_n_motifs_display} Conserved Motifs (Overall):")
        try:
            print(df_top_motifs_plot.to_markdown(index=False))
        except ImportError:
            # to_markdown needs the optional 'tabulate' package; fall back to the plain repr
            print(df_top_motifs_plot)

        # Save the full list of motif counts
        motif_counts_summary_path = Path(output_figure_summary_dir) / "figure4a_motif_frequency_all.csv"
        try:
            df_motif_counts_all.to_csv(motif_counts_summary_path, index=False)
            print(f"Full motif frequency summary saved to: {motif_counts_summary_path}")
        except Exception as e:
            print(f"Error saving full motif frequency summary: {e}")

        # Plot
        fig_6a_motifs = px.bar(
            df_top_motifs_plot,
            x='Motif',
            y='Frequency',
            labels={'Motif': 'Conserved Motif', 'Frequency': 'Number of Orthogroups'},
            color_discrete_sequence=arcadia_primary_palette  # Use a palette
        )
        fig_6a_motifs.update_layout(plotly_layout_defaults)  # Apply Arcadia style
        fig_6a_motifs.update_xaxes(title_text='Conserved Motif', showgrid=False, categoryorder='total descending', tickangle=45)
        fig_6a_motifs.update_yaxes(title_text='Number of Orthogroups', showgrid=False)
        fig_6a_motifs.update_layout(showlegend=False)

        # Export Figure 6A (Motif Frequency).
        # All three paths are defined once and reused for both writing and reporting, so the
        # file that is written and the file that is announced can no longer drift apart.
        fig_6a_motif_freq_path_html = Path(output_figure_dir) / "figure6a_pub_ideas_motif_frequency.html"
        fig_6a_motif_freq_path_pdf = Path(output_figure_dir) / "figure6a_pub_ideas_motif_frequency.pdf"
        fig_6a_motif_freq_path_svg = Path(output_figure_dir) / "figure6a_pub_ideas_motif_frequency.svg"

        fig_6a_motifs.write_html(str(fig_6a_motif_freq_path_html))
        print(f"Figure 6A (Motif Frequency) HTML saved to: {fig_6a_motif_freq_path_html}")

        # Static image export is best-effort: a failure is reported and the run continues.
        for image_format, image_path in (("PDF", fig_6a_motif_freq_path_pdf), ("SVG", fig_6a_motif_freq_path_svg)):
            try:
                fig_6a_motifs.write_image(str(image_path))
                print(f"Figure 6A (Motif Frequency) {image_format} saved to: {image_path}")
            except Exception as e:
                print(f"Warning: Could not export Figure 6A (Motif Frequency) to {image_format}. Ensure 'kaleido' is installed. Error: {e}")

    else:
        print("Skipping Figure 6A (Motif Frequency): 'Motif' column not found or DataFrame is empty.")

    # --- Figure 6B (Pub Ideas): Motif Length Distribution ---
    # Adds a 'Motif_Length' column to df_motifs (character count of each motif string),
    # prints summary statistics and plots a histogram of the lengths.
    print("\n--- Figure 6B (Pub Ideas): Motif Length Distribution ---")

    if 'Motif' in df_motifs.columns and not df_motifs.empty:
        # Calculate motif lengths on the string form of each motif, but keep missing motifs
        # missing: astype(str) alone turns NaN into the text 'nan' (length 3), which would be
        # counted as a real 3-residue motif and could never be removed by the dropna() below.
        motif_text = df_motifs['Motif']
        df_motifs['Motif_Length'] = motif_text.astype(str).str.len().where(motif_text.notna())

        print("\nMotif Length Statistics:")
        print(df_motifs['Motif_Length'].describe())

        # Plot histogram of motif lengths
        fig_6b_motif_len = px.histogram(
            df_motifs.dropna(subset=['Motif_Length']),  # Exclude rows whose motif is missing
            x='Motif_Length',
            nbins=20,  # Adjust number of bins as needed
            labels={'Motif_Length': 'Motif Length (Amino Acids)', 'count': 'Number of Motifs'},
            color_discrete_sequence=[arcadia_primary_palette[1]]  # Use a color from the palette
        )
        fig_6b_motif_len.update_layout(plotly_layout_defaults)  # Apply Arcadia style
        fig_6b_motif_len.update_xaxes(title_text='Motif Length (Amino Acids)', showgrid=False)
        fig_6b_motif_len.update_yaxes(title_text='Number of Motifs', showgrid=False)
        fig_6b_motif_len.update_layout(bargap=0.1)  # Add a small gap between bars

        # Export Figure 6B (Motif Length).
        # All three paths are defined once and reused for both writing and reporting, so the
        # file that is written and the file that is announced can no longer drift apart.
        fig_6b_motif_len_path_html = Path(output_figure_dir) / "figure6b_pub_ideas_motif_length_distribution.html"
        fig_6b_motif_len_path_pdf = Path(output_figure_dir) / "figure6b_pub_ideas_motif_length_distribution.pdf"
        fig_6b_motif_len_path_svg = Path(output_figure_dir) / "figure6b_pub_ideas_motif_length_distribution.svg"

        fig_6b_motif_len.write_html(str(fig_6b_motif_len_path_html))
        print(f"Figure 6B (Motif Length) HTML saved to: {fig_6b_motif_len_path_html}")

        # Static image export is best-effort: a failure is reported and the run continues.
        for image_format, image_path in (("PDF", fig_6b_motif_len_path_pdf), ("SVG", fig_6b_motif_len_path_svg)):
            try:
                fig_6b_motif_len.write_image(str(image_path))
                print(f"Figure 6B (Motif Length) {image_format} saved to: {image_path}")
            except Exception as e:
                print(f"Warning: Could not export Figure 6B (Motif Length) to {image_format}. Error: {e}")

    else:
        print("Skipping Figure 6B (Motif Length): 'Motif' column not found or DataFrame is empty.")

    # --- Note on Figure 6C and 6D (Pub Ideas) ---
    # These two panels are not generated by this cell; the messages below record why.
    # The heading uses the same 6C/6D numbering as the detail lines that follow it.
    print("\n--- Regarding Figure 6C (Motif Conservation Level) & 6D (Example Motif Visualization) ---")
    print("Figure 6C (Motif Conservation Level): The current 'df_motifs' DataFrame does not explicitly store a per-motif conservation percentage.")
    print("  The motifs were identified as 'conserved' based on the parameters in the upstream script (e.g., >90% or 100% identity in columns).")
    print("  If a variable conservation level per motif is needed, the motif generation script or its output CSV would need to be adapted.")
    print("Figure 6D (Example Motif Visualization): Programmatically generating a publication-quality alignment snippet with highlighted motifs is complex.")
    print("  This is often best done manually or with specialized alignment visualization tools for a few key examples.")

# Closing banner for Cell 8: announce completion and where figures/summaries were written.
figure6_completion_lines = (
    "\n\n--- Figure 6 (Pub Ideas) - Conserved Motif Analysis (Cell 8) Complete ---",
    f"Figures and summaries saved to '{output_figure_dir}' and '{output_figure_summary_dir}'.",
)
print("\n".join(figure6_completion_lines))



In [None]:
# Cell 9: Figure 7 (User's numbering) / Figure 5 (Pub Ideas) - Integrating Divergence, Motifs, and Protein Features
# This first part builds df_og_summary: one row per orthogroup with its APSI, motif flag,
# group, Asgard ESP flag and structural-darkness summary. The plots below consume it.

# --- Imports & Setup Assumptions ---
# This cell assumes that pandas (pd), numpy (np), plotly.express (px),
# plotly.graph_objects (go), plotly.io (pio), Path from pathlib, Counter from collections,
# and all necessary variables/helper functions from Cell 1 (Setup) are available.
# Key variables from Cell 1:
# - df_full: Main DataFrame with all protein annotations, including Intra_OG_APSI and Has_Conserved_Motif.
# - df_motifs: DataFrame with detailed motif data (Orthogroup, Motif, etc.).
# - output_figure_dir, output_summary_dir_phase1: Directories to save plots/data.
# - plotly_layout_defaults: Default layout for Plotly figures.
# - group_colors, broad_category_color_map, arcadia_primary_palette: Color maps.
# - orthogroup_col, group_col, esp_col, structurally_dark_col,
#   apsi_col ('Intra_OG_APSI'), 'Has_Conserved_Motif' (flag in df_full)

print("\n\n--- Generating Figure 7 (User) / Figure 5 (Pub Ideas): Integrating Divergence, Motifs, and Features (Cell 9) ---")

# --- Check if df_full is available and has necessary columns ---
required_cols_fig5 = [
    orthogroup_col, group_col, esp_col, structurally_dark_col,
    apsi_col, 'Has_Conserved_Motif'  # Assuming 'Has_Conserved_Motif' is the flag in df_full
]
if 'df_full' not in locals() or df_full.empty:
    print("ERROR: DataFrame 'df_full' not found or is empty. Please run Cell 1 first.")
    # Create an empty DataFrame to prevent errors if subsequent code relies on it
    df_full = pd.DataFrame(columns=required_cols_fig5)
elif not all(col in df_full.columns for col in required_cols_fig5):
    missing = [col for col in required_cols_fig5 if col not in df_full.columns]
    print(f"ERROR: 'df_full' is missing required columns for Figure 5: {missing}. Please ensure Cell 1 ran correctly and merged APSI/Motif data.")
    # Execution continues: the preparation below only touches columns that are present,
    # so a missing column degrades to skipped plots instead of an exception.

# --- Prepare data: One row per Orthogroup with relevant flags and APSI ---
# This df_og_summary will be used for most plots in this figure.
df_og_summary = pd.DataFrame()
if orthogroup_col in df_full.columns and not df_full.empty:
    # Per-OG values taken from the first non-null row of each orthogroup:
    # APSI (per OG), motif flag (per OG) and group (assumes all proteins in an OG share one group).
    # Only columns that exist are aggregated, so the missing-column case reported above
    # does not raise a KeyError here.
    agg_dict = {
        col: 'first'
        for col in (apsi_col, 'Has_Conserved_Motif', group_col)
        if col in df_full.columns
    }

    # ESP flag: True if any Asgard protein in the OG is an ESP. Computed on Asgard rows only,
    # so OGs without Asgard proteins get no value here and are filled with False after the merge.
    has_esp_data = esp_col in df_full.columns and group_col in df_full.columns
    if has_esp_data:
        asgard_esp_agg = df_full[df_full[group_col] == 'Asgard'].groupby(orthogroup_col)[esp_col].any().rename(esp_col)
    else:
        asgard_esp_agg = pd.Series(name=esp_col, dtype=bool)  # Placeholder only; never merged

    if structurally_dark_col in df_full.columns:
        # Proportion of structurally dark proteins per OG.
        # Cast to float rather than int: an integer cast raises on missing flags, whereas
        # NaN survives the float cast and is simply ignored by the group mean.
        df_full['Is_Dark_Numeric'] = df_full[structurally_dark_col].astype(float)
        og_dark_prop = df_full.groupby(orthogroup_col)['Is_Dark_Numeric'].mean().rename('Prop_Structurally_Dark')
    else:
        og_dark_prop = pd.Series(name='Prop_Structurally_Dark', dtype=float)  # Empty series

    # Base aggregation for APSI, Motif Flag, and Group
    if agg_dict:
        df_og_summary_base = df_full.groupby(orthogroup_col).agg(agg_dict).reset_index()
    else:
        # None of the per-OG columns exist: keep just the list of orthogroups
        df_og_summary_base = df_full[[orthogroup_col]].dropna().drop_duplicates().reset_index(drop=True)

    # Merge ESP aggregation (only Asgard OGs carry a value); everything else becomes False.
    # The empty placeholder Series has no orthogroup index to join on, so it is never merged.
    if has_esp_data:
        df_og_summary = pd.merge(df_og_summary_base, asgard_esp_agg, on=orthogroup_col, how='left')
        # astype(bool) keeps the column a real boolean column after the NaN fill
        df_og_summary[esp_col] = df_og_summary[esp_col].fillna(False).astype(bool)
    else:
        df_og_summary = df_og_summary_base.copy()
        df_og_summary[esp_col] = False

    # Merge proportion of structurally dark proteins
    if not og_dark_prop.empty:
        df_og_summary = pd.merge(df_og_summary, og_dark_prop, on=orthogroup_col, how='left')
        if 'Prop_Structurally_Dark' in df_og_summary.columns:
            df_og_summary['Is_Mostly_Dark_OG'] = df_og_summary['Prop_Structurally_Dark'] > 0.5  # Example threshold
        else:
            df_og_summary['Is_Mostly_Dark_OG'] = False
    else:
        # Fallback if structurally_dark_col was directly aggregated or missing
        if structurally_dark_col in df_og_summary.columns:  # if it was aggregated with 'first'
            df_og_summary['Is_Mostly_Dark_OG'] = df_og_summary[structurally_dark_col].astype(bool)
        else:
            df_og_summary['Is_Mostly_Dark_OG'] = False

    # Filter for OGs with valid APSI values for relevant plots
    if apsi_col in df_og_summary.columns:
        df_og_summary = df_og_summary.dropna(subset=[apsi_col])
    else:
        print(f"WARNING: APSI column '{apsi_col}' not found in df_og_summary. APSI-related plots will fail.")
        df_og_summary = pd.DataFrame()  # Make it empty if APSI is crucial and missing

    print(f"Created 'df_og_summary' with {len(df_og_summary)} orthogroups for analysis.")
    if df_og_summary.empty and apsi_col in required_cols_fig5:  # Check if it became empty after APSI dropna
        print("WARNING: 'df_og_summary' is empty after filtering for valid APSI. Plots for Figure 5 might not generate.")

else:
    print("Skipping Figure 5 generation as 'df_full' is not suitable or 'Orthogroup' column is missing.")

# --- Figure 7A (Pub Ideas): APSI Comparison (ESPs vs. Non-ESPs) ---
# Violin plot of per-orthogroup APSI for Asgard orthogroups, split by ESP status.
print("\n--- Figure 7A (Pub Ideas): APSI Comparison (ESPs vs. Non-ESPs) ---")
if not df_og_summary.empty and esp_col in df_og_summary.columns and apsi_col in df_og_summary.columns and group_col in df_og_summary.columns:
    # Consider only Asgard OGs for ESP comparison
    df_asgard_og_summary = df_og_summary[df_og_summary[group_col] == 'Asgard'].copy()

    if not df_asgard_og_summary.empty:
        # Create a more descriptive label for ESP status
        df_asgard_og_summary['ESP_Status'] = df_asgard_og_summary[esp_col].apply(lambda x: 'ESP OG' if x else 'Non-ESP OG')

        fig_7a_apsi_esp = px.violin(
            df_asgard_og_summary.dropna(subset=[apsi_col, 'ESP_Status']),
            x='ESP_Status',
            y=apsi_col,
            color='ESP_Status',
            box=True,
            points="all",  # "all" or False or "outliers"
            labels={apsi_col: 'Average Pairwise Sequence Identity (APSI)', 'ESP_Status': 'ESP Orthogroup Status (Asgard)'},
            color_discrete_map={'ESP OG': arcadia_primary_palette[0], 'Non-ESP OG': arcadia_primary_palette[1]},
            category_orders={'ESP_Status': ['ESP OG', 'Non-ESP OG']}
        )
        fig_7a_apsi_esp.update_layout(plotly_layout_defaults)
        fig_7a_apsi_esp.update_xaxes(title_text='ESP Orthogroup Status (Asgard)', showgrid=False)
        fig_7a_apsi_esp.update_yaxes(title_text='Average Pairwise Sequence Identity (APSI)', showgrid=False)
        fig_7a_apsi_esp.update_layout(showlegend=False)

        # Export Figure.
        # All three formats share the 'figure7a_' file stem, and each path is defined once and
        # reused for both writing and reporting. The SVG must be written to this figure's own
        # path: a path variable left over from another figure would either be undefined or
        # overwrite that figure's file.
        fig_7a_path_html = Path(output_figure_dir) / "figure7a_pub_ideas_apsi_esp_vs_non_esp.html"
        fig_7a_path_pdf = Path(output_figure_dir) / "figure7a_pub_ideas_apsi_esp_vs_non_esp.pdf"
        fig_7a_path_svg = Path(output_figure_dir) / "figure7a_pub_ideas_apsi_esp_vs_non_esp.svg"

        fig_7a_apsi_esp.write_html(str(fig_7a_path_html))
        print(f"Figure 7A (APSI ESP vs Non-ESP) HTML saved to: {fig_7a_path_html}")

        # Static image export is best-effort: a failure is reported and the run continues.
        for image_format, image_path in (("PDF", fig_7a_path_pdf), ("SVG", fig_7a_path_svg)):
            try:
                fig_7a_apsi_esp.write_image(str(image_path))
                print(f"Figure 7A (APSI ESP vs Non-ESP) {image_format} saved to: {image_path}")
            except Exception as e:
                print(f"Warning: Could not export Figure 7A to {image_format}. Error: {e}")

    else:
        print("No Asgard OGs found in df_og_summary to compare ESP vs Non-ESP APSI.")
else:
    print("Skipping Figure 7A (APSI ESP vs Non-ESP): df_og_summary is empty or missing required columns.")

# --- Figure 7C (Pub Ideas): Divergence/Motifs in Structurally Dark Proteins ---
print("\n--- Figure 7C (Pub Ideas): Divergence & Motif Presence in Structurally Dark OGs ---")

# Part 1: APSI vs. Structural Darkness
# Violin plot of the per-OG APSI column, split by darkness status on the x axis and
# coloured by group. Skipped when the summary table lacks the inputs.
print("\n--- Part 1: APSI vs. Structural Darkness ---")
if df_og_summary.empty or 'Is_Mostly_Dark_OG' not in df_og_summary.columns or apsi_col not in df_og_summary.columns:
    print("Skipping Figure 7C Part 1 (APSI vs Darkness): df_og_summary is empty or missing 'Is_Mostly_Dark_OG' or apsi_col.")
else:
    # Readable label for the darkness flag; stored on df_og_summary itself.
    df_og_summary['Structural_Darkness_Status_OG'] = [
        'Mostly Dark OG' if is_mostly_dark else 'Not Mostly Dark OG'
        for is_mostly_dark in df_og_summary['Is_Mostly_Dark_OG']
    ]

    # Rows missing any plotted value are dropped before plotting.
    fig_7c_part1_rows = df_og_summary.dropna(
        subset=[apsi_col, 'Structural_Darkness_Status_OG', group_col]
    )
    fig_7c_part1_labels = {
        apsi_col: 'Average Pairwise Sequence Identity (APSI)',
        'Structural_Darkness_Status_OG': 'OG Structural Darkness Status',
        group_col: 'Group (Asgard/GV)',
    }
    fig_7c_part1_orders = {
        'Structural_Darkness_Status_OG': ['Mostly Dark OG', 'Not Mostly Dark OG'],
        group_col: ['Asgard', 'GV'],
    }

    fig_7c_apsi_dark = px.violin(
        fig_7c_part1_rows,
        x='Structural_Darkness_Status_OG',
        y=apsi_col,
        color=group_col,  # Compare Asgard vs GV within each darkness category
        box=True,
        points=False,
        labels=fig_7c_part1_labels,  # No figure title on purpose
        color_discrete_map=group_colors,  # group_colors defined in Cell 1
        category_orders=fig_7c_part1_orders,
    )
    fig_7c_apsi_dark.update_layout(plotly_layout_defaults)
    fig_7c_apsi_dark.update_xaxes(title_text='OG Structural Darkness Status', showgrid=False)
    fig_7c_apsi_dark.update_yaxes(title_text='Average Pairwise Sequence Identity (APSI)', showgrid=False)

    # Export: HTML always; PDF and SVG are best-effort and only warn on failure.
    fig_7c_part1_path_html = Path(output_figure_dir) / "figure7c_pub_ideas_apsi_vs_darkness.html"
    fig_7c_apsi_dark.write_html(str(fig_7c_part1_path_html))
    print(f"Figure 7C Part 1 (APSI vs Darkness) HTML saved to: {fig_7c_part1_path_html}")

    fig_7c_part1_path_pdf = Path(output_figure_dir) / "figure7c_pub_ideas_apsi_vs_darkness.pdf"
    try:
        fig_7c_apsi_dark.write_image(str(fig_7c_part1_path_pdf))
    except Exception as e:
        print(f"Warning: Could not export Figure 7C Part 1 to PDF. Error: {e}")
    else:
        print(f"Figure 7C Part 1 (APSI vs Darkness) PDF saved to: {fig_7c_part1_path_pdf}")

    fig_7c_part1_path_svg = Path(output_figure_dir) / "figure7c_pub_ideas_apsi_vs_darkness.svg"
    try:
        fig_7c_apsi_dark.write_image(str(fig_7c_part1_path_svg))
    except Exception as e:
        print(f"Warning: Could not export Figure 7C Part 1 to SVG. Error: {e}")
    else:
        print(f"Figure 7C Part 1 (APSI vs Darkness) SVG saved to: {fig_7c_part1_path_svg}")

# Part 2: Motif Presence vs. Structural Darkness
# Grouped bar chart of the percentage of OGs with a conserved motif, per darkness status
# and per group. Skipped when the summary table lacks any of the columns it needs.
print("\n--- Part 2: Motif Presence vs. Structural Darkness ---")
if (
    not df_og_summary.empty
    and 'Is_Mostly_Dark_OG' in df_og_summary.columns
    and 'Has_Conserved_Motif' in df_og_summary.columns
    and group_col in df_og_summary.columns
):

    # Part 1 normally adds this label column, but Part 1 is skipped when the APSI column
    # is missing. Derive it here so the groupby below does not raise a KeyError then.
    if 'Structural_Darkness_Status_OG' not in df_og_summary.columns:
        df_og_summary['Structural_Darkness_Status_OG'] = df_og_summary['Is_Mostly_Dark_OG'].apply(
            lambda x: 'Mostly Dark OG' if x else 'Not Mostly Dark OG'
        )

    # Calculate percentage of OGs with motifs within each darkness category, per group (Asgard/GV)
    df_motif_dark_summary = df_og_summary.groupby([group_col, 'Structural_Darkness_Status_OG'])['Has_Conserved_Motif'].agg(
        Total_OGs='count',
        OGs_With_Motif='sum'
    ).reset_index()
    # fillna(0) covers groups whose count is zero (0/0 -> NaN).
    df_motif_dark_summary['Percent_OGs_With_Motif'] = (df_motif_dark_summary['OGs_With_Motif'] / df_motif_dark_summary['Total_OGs'] * 100).fillna(0)

    print("\nMotif Presence by OG Structural Darkness Status and Group:")
    try:
        print(df_motif_dark_summary.to_markdown(index=False))
    except ImportError:
        # to_markdown needs an optional dependency; fall back to the default repr.
        print(df_motif_dark_summary)

    fig_7c_motif_dark = px.bar(
        df_motif_dark_summary,
        x='Structural_Darkness_Status_OG',
        y='Percent_OGs_With_Motif',
        color=group_col,
        barmode='group',
        # title='Motif Presence in Structurally Dark vs. Non-Dark OGs', # No title
        labels={'Percent_OGs_With_Motif': '% of OGs with Conserved Motif(s)',
                'Structural_Darkness_Status_OG': 'OG Structural Darkness Status',
                group_col: 'Group (Asgard/GV)'},
        color_discrete_map=group_colors,
        category_orders={'Structural_Darkness_Status_OG': ['Mostly Dark OG', 'Not Mostly Dark OG'],
                         group_col: ['Asgard', 'GV']}
    )
    fig_7c_motif_dark.update_layout(plotly_layout_defaults)
    fig_7c_motif_dark.update_xaxes(title_text='OG Structural Darkness Status', showgrid=False)
    fig_7c_motif_dark.update_yaxes(title_text='% of OGs with Conserved Motif(s)', showgrid=False)

    # Export Figure: HTML always; PDF and SVG are best-effort and only warn on failure.
    # Status messages consistently say "Figure 7C Part 2" (they previously mixed in
    # "Figure 57" and "Figure 5C").
    fig_7c_part2_path_html = Path(output_figure_dir) / "figure7c_pub_ideas_motif_vs_darkness.html"
    fig_7c_motif_dark.write_html(str(fig_7c_part2_path_html))
    print(f"Figure 7C Part 2 (Motif vs Darkness) HTML saved to: {fig_7c_part2_path_html}")
    try:
        fig_7c_part2_path_pdf = Path(output_figure_dir) / "figure7c_pub_ideas_motif_vs_darkness.pdf"
        fig_7c_motif_dark.write_image(str(fig_7c_part2_path_pdf))
        print(f"Figure 7C Part 2 (Motif vs Darkness) PDF saved to: {fig_7c_part2_path_pdf}")
    except Exception as e:
        print(f"Warning: Could not export Figure 7C Part 2 to PDF. Error: {e}")
    try:
        fig_7c_part2_path_svg = Path(output_figure_dir) / "figure7c_pub_ideas_motif_vs_darkness.svg"
        fig_7c_motif_dark.write_image(str(fig_7c_part2_path_svg))
        print(f"Figure 7C Part 2 (Motif vs Darkness) SVG saved to: {fig_7c_part2_path_svg}")
    except Exception as e:
        print(f"Warning: Could not export Figure 7C Part 2 to SVG. Error: {e}")
else:
    print("Skipping Figure 7C Part 2 (Motif vs Darkness): df_og_summary is empty or missing required columns.")

print("\n\n--- Figure 7 (User) / Figure 5 (Pub Ideas) - Cell 9 Complete ---")
print(f"Figures and summaries saved to '{output_figure_dir}' and '{output_summary_dir_phase1}'.")
print("Next, we can consider the Lokiactin case study or other specific analyses.")



In [None]:
# Cell 10: Case Study for Figure 8 - Lokiactin (OG0000203.ASG) - APSI and Structural Darkness
#
# For one target orthogroup this cell reports: the number of member proteins, whether the
# reference protein is among them, the OG's pre-computed APSI (looked up in df_og_summary),
# and how many members are flagged structurally dark. It then plots the dark / non-dark
# counts as a bar chart and saves a per-member CSV.

# --- Imports & Setup Assumptions ---
# This cell assumes that pandas (pd), numpy (np), plotly.express (px),
# Path from pathlib, and all necessary variables/helper functions from Cell 1 (Setup) are available.
# Key variables from Cell 1:
# - df_full: Main DataFrame with all protein annotations.
# - df_og_summary: DataFrame created in Cell 9, containing OG-level summaries including Intra_OG_APSI.
# - output_figure_dir, output_summary_dir_phase1: Directories to save plots/data.
# - plotly_layout_defaults: Default layout for Plotly figures.
# - arcadia_primary_palette: Color palette.
# - orthogroup_col, protein_id_col, structurally_dark_col, apsi_col ('Intra_OG_APSI')

print("\n\n--- Generating Lokiactin (OG0000203.ASG) Case Study - APSI and Structural Darkness (Cell 10) ---")

# --- Configuration for this Case Study ---
target_og_id = "OG0000203.ASG"
reference_protein_id = "UYP47028.1" # The experimentally characterized Lokiactin

# --- Check if necessary DataFrames are available ---
# Each failed precondition prints an error and leaves df_lokiactin_og_members as an empty
# DataFrame so the name is always defined after this cell; only the final `else` branch
# performs the analysis.
if 'df_full' not in locals() or df_full.empty:
    print(f"ERROR: DataFrame 'df_full' not found or is empty. Please run Cell 1 first.")
    df_lokiactin_og_members = pd.DataFrame() # Ensure it's defined
elif 'df_og_summary' not in locals() or df_og_summary.empty:
    print(f"ERROR: DataFrame 'df_og_summary' (with OG-level APSI) not found or is empty. Please run Cell 9 first.")
    df_lokiactin_og_members = pd.DataFrame() # Ensure it's defined
elif orthogroup_col not in df_full.columns or protein_id_col not in df_full.columns:
    print(f"ERROR: 'df_full' is missing required columns: '{orthogroup_col}' or '{protein_id_col}'.")
    df_lokiactin_og_members = pd.DataFrame()
elif orthogroup_col not in df_og_summary.columns or apsi_col not in df_og_summary.columns:
    print(f"ERROR: 'df_og_summary' is missing required columns: '{orthogroup_col}' or '{apsi_col}'.")
    df_lokiactin_og_members = pd.DataFrame()
else:
    # --- 1. Filter for the Lokiactin Orthogroup Members from df_full ---
    print(f"\n--- Analyzing Orthogroup: {target_og_id} ---")
    # .copy() so the in-place dtype normalisation in step 4 does not touch df_full.
    df_lokiactin_og_members = df_full[df_full[orthogroup_col] == target_og_id].copy()

    if df_lokiactin_og_members.empty:
        print(f"No proteins found for Orthogroup {target_og_id} in df_full.")
    else:
        num_members = len(df_lokiactin_og_members)
        print(f"Found {num_members} protein members in {target_og_id}.")

        # --- 2. Verify Presence of Reference Lokiactin ---
        reference_present = reference_protein_id in df_lokiactin_og_members[protein_id_col].values
        print(f"Reference Lokiactin ({reference_protein_id}) present in this OG: {reference_present}")
        if not reference_present:
            print(f"WARNING: Reference protein {reference_protein_id} was NOT found among the members of {target_og_id} in your dataset.")

        # --- 3. Get Pre-calculated APSI for this Orthogroup ---
        # The APSI is read from df_og_summary, not recomputed; NaN marks "no row for this OG".
        og_apsi_data = df_og_summary[df_og_summary[orthogroup_col] == target_og_id]

        if og_apsi_data.empty:
            lokiactin_og_apsi = np.nan
            print(f"WARNING: APSI data not found for {target_og_id} in df_og_summary.")
        else:
            # First matching row is used if the OG appears more than once.
            lokiactin_og_apsi = og_apsi_data[apsi_col].iloc[0]
            print(f"Intra-OG Average Pairwise Sequence Identity (APSI) for {target_og_id}: {lokiactin_og_apsi:.4f}")

        # --- 4. Analyze Structural Darkness within the Lokiactin OG ---
        if structurally_dark_col not in df_lokiactin_og_members.columns:
            # Defaults keep the names used by the summary section below defined.
            print(f"WARNING: Column '{structurally_dark_col}' not found. Cannot analyze structural darkness for this OG.")
            num_dark_members = 0
            percent_dark = 0.0
            df_darkness_counts = pd.DataFrame()
        else:
            # Missing flags are counted as "not dark" before casting to bool.
            df_lokiactin_og_members[structurally_dark_col] = df_lokiactin_og_members[structurally_dark_col].fillna(False).astype(bool)
            num_dark_members = df_lokiactin_og_members[structurally_dark_col].sum()
            percent_dark = (num_dark_members / num_members) * 100 if num_members > 0 else 0.0

            print(f"Number of structurally dark members in {target_og_id}: {num_dark_members} out of {num_members} ({percent_dark:.1f}%)")

            # Prepare data for plotting darkness breakdown
            # Column names are assigned explicitly rather than relying on the names
            # produced by value_counts().reset_index().
            darkness_counts = df_lokiactin_og_members[structurally_dark_col].value_counts().reset_index()
            darkness_counts.columns = ['Is_Structurally_Dark', 'Count']
            darkness_counts['Label'] = darkness_counts['Is_Structurally_Dark'].apply(lambda x: 'Structurally Dark' if x else 'Has Known/Predicted Structure')
            df_darkness_counts = darkness_counts


            # --- 5. Visualize Structural Darkness Breakdown for Lokiactin OG ---
            if not df_darkness_counts.empty:
                fig_lokiactin_darkness = px.bar(
                    df_darkness_counts,
                    x='Label',
                    y='Count',
                    color='Label',
                    # title=f'Structural Darkness of Proteins in Lokiactin OG ({target_og_id})', # No title
                    labels={'Count': 'Number of Proteins', 'Label': 'Structural Annotation Status'},
                    color_discrete_map={'Structurally Dark': arcadia_primary_palette[3], 'Has Known/Predicted Structure': arcadia_primary_palette[4]}
                )
                fig_lokiactin_darkness.update_layout(plotly_layout_defaults)
                fig_lokiactin_darkness.update_xaxes(title_text='Structural Annotation Status', showgrid=False)
                fig_lokiactin_darkness.update_yaxes(title_text='Number of Proteins', showgrid=False)
                # The x-axis categories already carry the labels, so the legend is hidden.
                fig_lokiactin_darkness.update_layout(showlegend=False)
                fig_lokiactin_darkness.show()

                # Save plot: HTML always; PDF and SVG are best-effort and only warn on failure.
                fig_lokiactin_dark_path_html = Path(output_figure_dir) / f"lokiactin_{target_og_id}_darkness_breakdown.html"
                fig_lokiactin_darkness.write_html(str(fig_lokiactin_dark_path_html))
                print(f"Lokiactin OG darkness breakdown plot HTML saved to: {fig_lokiactin_dark_path_html}")
                try:
                    fig_lokiactin_dark_path_pdf = Path(output_figure_dir) / f"lokiactin_{target_og_id}_darkness_breakdown.pdf"
                    fig_lokiactin_darkness.write_image(str(fig_lokiactin_dark_path_pdf))
                    print(f"Lokiactin OG darkness breakdown plot PDF saved to: {fig_lokiactin_dark_path_pdf}")
                except Exception as e:
                    print(f"Warning: Could not export Lokiactin darkness plot to PDF. Error: {e}")
                try:
                    fig_lokiactin_dark_path_svg = Path(output_figure_dir) / f"lokiactin_{target_og_id}_darkness_breakdown.svg"
                    fig_lokiactin_darkness.write_image(str(fig_lokiactin_dark_path_svg))
                    print(f"Lokiactin OG darkness breakdown plot SVG saved to: {fig_lokiactin_dark_path_svg}")
                except Exception as e:
                    print(f"Warning: Could not export Lokiactin darkness plot to SVG. Error: {e}")
            else:
                print("No data to plot for structural darkness breakdown of Lokiactin OG.")

        # --- Summary Output for Lokiactin OG ---
        print("\n--- Lokiactin Orthogroup Summary ---")
        print(f"Orthogroup ID: {target_og_id}")
        print(f"Contains Reference Lokiactin ({reference_protein_id}): {reference_present}")
        print(f"Number of Protein Members: {num_members}")
        # The conditional expression selects which whole message is printed.
        print(f"Intra-OG APSI: {lokiactin_og_apsi:.4f}" if not pd.isna(lokiactin_og_apsi) else "APSI: Not Available")
        if structurally_dark_col in df_lokiactin_og_members.columns:
            print(f"Structurally Dark Members: {num_dark_members} ({percent_dark:.1f}%)")

        # Save summary data for this OG
        lokiactin_og_details_path = Path(output_summary_dir_phase1) / f"lokiactin_{target_og_id}_detailed_summary.csv"
        try:
            # Save all members of the OG with their relevant info
            # Only the requested columns that actually exist in the frame are written.
            cols_to_save = [protein_id_col, sequence_col, 'Length', structurally_dark_col, group_col, asgard_phylum_col, esp_col, 'Has_Conserved_Motif']
            cols_present_in_df = [col for col in cols_to_save if col in df_lokiactin_og_members.columns]
            df_lokiactin_og_members[cols_present_in_df].to_csv(lokiactin_og_details_path, index=False)
            print(f"Detailed member data for {target_og_id} saved to: {lokiactin_og_details_path}")
        except Exception as e:
            print(f"Error saving Lokiactin OG detailed member data: {e}")

# These closing messages are printed on every path, including the error branches above.
print("\n\n--- Cell 10 (Lokiactin Case Study - APSI & Darkness) Complete ---")
print(f"Analysis for OG {target_og_id} focusing on its APSI and structural darkness of members is done.")
print(f"Figures and summaries saved to '{output_figure_dir}' and '{output_summary_dir_phase1}'.")



In [None]:
# Cell 11: Figure 8 (User's numbering) - Case Studies: Divergence, Diversity, and Structural Darkness

# --- Imports & Setup Assumptions ---
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots # ****** ENSURED THIS IMPORT IS PRESENT ******
from pathlib import Path
# NOTE: make_subplots is required by the subplot grid built below; import any further helpers (e.g. Counter) here if this cell starts using them.

# This cell assumes that pandas (pd), numpy (np), plotly.express (px),
# plotly.graph_objects (go), Path from pathlib,
# and all necessary variables/helper functions from Cell 1 (Setup) are available.
# Key variables from Cell 1 & subsequent cells:
# - df_full: Main DataFrame with all protein annotations.
# - df_og_summary: DataFrame created in Cell 9, containing OG-level summaries including Intra_OG_APSI.
# - output_figure_dir, output_summary_dir_phase1: Directories to save plots/data.
# - plotly_layout_defaults: Default layout for Plotly figures. (This is a go.Layout object)
# - arcadia_primary_palette: Color palette.
# - orthogroup_col, protein_id_col, structurally_dark_col, apsi_col ('Intra_OG_APSI'), group_col, sequence_col

print("\n\n--- Generating Figure 8: Case Studies - Divergence, Diversity, and Structural Darkness (Cell 11) ---")

# --- Configuration for Case Studies ---
# Semicolon-separated reference protein IDs; each one is later mapped to its orthogroup.
case_study_protein_ids_str = "UYP47028.1;KKK44605.1;KKK42122.1;OLS22855.1;TFG12995.1;OLS30618.1;TET76256.1;OLS23013.1;NHJ47629.1"
case_study_protein_ids = list(map(str.strip, case_study_protein_ids_str.split(';')))

diversity_metrics_path = "orthogroup_diversity_metrics.csv" # Path to the uploaded diversity file

# --- Check if necessary DataFrames are available ---
# The checks run in order and stop at the first failure: df_full present, df_og_summary
# present, required df_full columns, then required df_og_summary columns.
error_found = False
if 'df_full' not in locals() or df_full.empty:
    print("ERROR: DataFrame 'df_full' not found or is empty. Please run Cell 1 first.")
    error_found = True
elif 'df_og_summary' not in locals() or df_og_summary.empty:
    print("ERROR: DataFrame 'df_og_summary' (with OG-level APSI) not found or is empty. Please run Cell 9 first.")
    error_found = True
elif any(
    col not in df_full.columns
    for col in (orthogroup_col, protein_id_col, structurally_dark_col, group_col, sequence_col)
):
    # Essential columns of df_full
    missing_cols = [
        col
        for col in (orthogroup_col, protein_id_col, structurally_dark_col, group_col, sequence_col)
        if col not in df_full.columns
    ]
    print(f"ERROR: 'df_full' is missing required columns: {missing_cols}.")
    error_found = True
elif any(col not in df_og_summary.columns for col in (orthogroup_col, apsi_col)):
    # Essential columns of df_og_summary
    missing_cols = [col for col in (orthogroup_col, apsi_col) if col not in df_og_summary.columns]
    print(f"ERROR: 'df_og_summary' is missing required columns: {missing_cols}.")
    error_found = True

if error_found:
    print("Cannot proceed with case study generation due to missing data.")
    # For this environment, we'll just print the error and not run the rest.
else:
    # --- 1. Load Diversity Metrics ---
    # Reads the per-orthogroup diversity CSV and reduces it to the orthogroup ID plus the
    # two tree metrics used below, renamed to 'Observed_Richness' and 'Shannon_Entropy'.
    # On any failure df_diversity_subset becomes an empty frame with those columns, so the
    # later per-OG lookups find no row instead of raising.
    print(f"\n--- Loading Orthogroup Diversity Metrics from: {diversity_metrics_path} ---")
    try:
        df_diversity = pd.read_csv(diversity_metrics_path)
        # Rename OG_ID to orthogroup_col for consistent merging, if needed
        if 'OG_ID' in df_diversity.columns and orthogroup_col not in df_diversity.columns:
            df_diversity = df_diversity.rename(columns={'OG_ID': orthogroup_col})

        # Select and rename relevant diversity columns
        diversity_cols_to_use = {
            orthogroup_col: orthogroup_col,
            'Tree_n_tips': 'Observed_Richness', # Number of tips in the tree
            'Tree_entropy': 'Shannon_Entropy'   # Shannon entropy from tree
        }
        # Filter for columns that actually exist in the loaded df_diversity
        actual_diversity_cols = {k: v for k, v in diversity_cols_to_use.items() if k in df_diversity.columns}

        if len(actual_diversity_cols) < 3: # Expecting Orthogroup + 2 metrics
            print(f"WARNING: Not all expected diversity columns found in '{diversity_metrics_path}'. Found: {df_diversity.columns.tolist()}")
            print(f"         Will use available columns for diversity metrics: {list(actual_diversity_cols.keys())}")

        if orthogroup_col not in actual_diversity_cols:
            # Without the orthogroup ID column the metrics cannot be matched to any OG, and
            # the per-OG lookup later in this cell would raise a KeyError. Fall back to the
            # same empty frame used for the other failure modes.
            print(f"ERROR: Orthogroup ID column ('{orthogroup_col}' or 'OG_ID') not found in '{diversity_metrics_path}'. Diversity metrics will be missing.")
            df_diversity_subset = pd.DataFrame(columns=[orthogroup_col, 'Observed_Richness', 'Shannon_Entropy'])
        else:
            # Create the subset using only the columns that were actually found
            df_diversity_subset = df_diversity[list(actual_diversity_cols)].rename(columns=actual_diversity_cols).copy()
            print(f"Loaded and processed diversity metrics for {len(df_diversity_subset)} orthogroups.")
            if df_diversity_subset.empty:
                print("WARNING: Diversity metrics DataFrame is empty after processing.")

    except FileNotFoundError:
        print(f"ERROR: Diversity metrics file not found at '{diversity_metrics_path}'. Diversity metrics will be missing.")
        df_diversity_subset = pd.DataFrame(columns=[orthogroup_col, 'Observed_Richness', 'Shannon_Entropy']) # Empty df for graceful failure
    except Exception as e:
        # Deliberately broad: a malformed file must not abort the figure; metrics show as N/A.
        print(f"Error loading or processing diversity metrics file '{diversity_metrics_path}': {e}")
        df_diversity_subset = pd.DataFrame(columns=[orthogroup_col, 'Observed_Richness', 'Shannon_Entropy'])


    # --- 2. Process Input Protein IDs to Get Unique Orthogroups ---
    # Maps each reference protein ID to its orthogroup (first matching row in df_full) and
    # keeps the first reference protein seen for each distinct orthogroup.
    print("\n--- Identifying Unique Orthogroups for Case Studies ---")
    case_study_ogs_info = []
    seen_ogs = set()

    # protein_id_col and orthogroup_col are already verified to be in df_full by the
    # precondition checks at the top of this cell, so they are not re-checked per ID.
    for ref_pid in case_study_protein_ids:
        og_entry = df_full[df_full[protein_id_col] == ref_pid]
        if og_entry.empty:
            print(f"Warning: Reference Protein ID {ref_pid} not found in df_full.")
            continue

        og_id = og_entry[orthogroup_col].iloc[0]
        if pd.isna(og_id):
            # A missing orthogroup is not a usable case study. NaN values also do not
            # compare equal to each other, so they would slip past the seen_ogs
            # de-duplication and inflate the "unique orthogroups" count.
            print(f"Warning: Reference Protein ID {ref_pid} has no orthogroup assignment in df_full. Skipping.")
            continue

        if og_id not in seen_ogs:
            case_study_ogs_info.append({'Ref_ProteinID': ref_pid, orthogroup_col: og_id})
            seen_ogs.add(og_id)
    
    if not case_study_ogs_info:
        print("ERROR: No valid orthogroups found for the provided Protein IDs. Cannot create case studies.")
    else:
        print(f"Identified {len(case_study_ogs_info)} unique orthogroups for case studies.")

        # --- 3. Gather Data for Each Case Study OG ---
        # Builds one record per case-study orthogroup: member count, group label, APSI
        # (scaled to percent), percent structurally dark, and the two diversity metrics.
        plot_data_list = []
        for og_info in case_study_ogs_info:
            og_id, ref_pid = og_info[orthogroup_col], og_info['Ref_ProteinID']

            df_og_members = df_full.loc[df_full[orthogroup_col] == og_id]
            if df_og_members.empty:
                print(f"Warning: No members found for {og_id} in df_full during data gathering. Skipping.")
                continue

            num_members = len(df_og_members)

            # Group label is taken from the first member row.
            if group_col in df_og_members.columns and not df_og_members[group_col].empty:
                group_name = df_og_members[group_col].iloc[0]
            else:
                group_name = "N/A"

            # Pre-computed APSI for this OG; NaN when the summary table has no row for it.
            apsi_value_series = df_og_summary.loc[df_og_summary[orthogroup_col] == og_id, apsi_col]
            if apsi_value_series.empty:
                apsi_value = np.nan
            else:
                apsi_value = apsi_value_series.iloc[0]

            if structurally_dark_col not in df_og_members.columns:
                percent_dark = np.nan
                print(f"Warning: Structurally dark column '{structurally_dark_col}' not found for OG {og_id}")
            elif num_members > 0:
                percent_dark = (df_og_members[structurally_dark_col].sum() / num_members) * 100
            else:
                percent_dark = 0.0

            # Diversity metrics default to NaN when the OG or the metric column is absent.
            diversity_entry = df_diversity_subset[df_diversity_subset[orthogroup_col] == og_id]
            if not diversity_entry.empty and 'Shannon_Entropy' in diversity_entry.columns:
                shannon_entropy = diversity_entry['Shannon_Entropy'].iloc[0]
            else:
                shannon_entropy = np.nan
            if not diversity_entry.empty and 'Observed_Richness' in diversity_entry.columns:
                observed_richness = diversity_entry['Observed_Richness'].iloc[0]
            else:
                observed_richness = np.nan

            og_record = {
                'Orthogroup': og_id,
                'Ref_ProteinID': ref_pid,
                'APSI': np.nan if pd.isna(apsi_value) else apsi_value * 100,
                'Percent_Dark': percent_dark,
                'Shannon_Entropy': shannon_entropy,
                'Observed_Richness': observed_richness,
                'Total_Members': num_members,
                'Group': group_name,
            }
            plot_data_list.append(og_record)

        df_plot_data = pd.DataFrame(plot_data_list)

        # --- 4. Create the Figure with Subplots ---
        if not df_plot_data.empty:
            num_case_studies = len(df_plot_data)
            subplot_height = 150 
            annotation_lines = 4 
            annotation_line_height = 20 
            total_subplot_height = subplot_height + (annotation_lines * annotation_line_height)
            
            fig = make_subplots(
                rows=num_case_studies, 
                cols=1,
                subplot_titles=[f"<b>OG: {row['Orthogroup']}</b> (Ref: {row['Ref_ProteinID']})" for index, row in df_plot_data.iterrows()],
                vertical_spacing=0.08, 
                row_heights=[total_subplot_height]*num_case_studies 
            )

            for i, (index, row) in enumerate(df_plot_data.iterrows()):
                row_num = i + 1 
                
                # Two horizontal bars per subplot: APSI first, then % structurally dark.
                # A missing value is drawn as a zero-length bar labelled "N/A".
                case_study_bar_specs = (
                    # (category label, df_plot_data column, trace name, palette index)
                    ('APSI (%)', 'APSI', 'APSI', 0),
                    ('% Structurally Dark', 'Percent_Dark', '% Dark', 1),
                )
                for bar_label, metric_col, trace_name, palette_idx in case_study_bar_specs:
                    metric_value = row[metric_col]
                    metric_missing = pd.isna(metric_value)
                    fig.add_trace(
                        go.Bar(
                            y=[bar_label],
                            x=[0 if metric_missing else metric_value],
                            orientation='h',
                            name=trace_name,
                            marker_color=arcadia_primary_palette[palette_idx % len(arcadia_primary_palette)],
                            text=["N/A" if metric_missing else f"{metric_value:.1f}%"],
                            textposition='outside',
                            width=0.3,
                        ),
                        row=row_num,
                        col=1,
                    )

                # Axis frame styling shared by both axes of this subplot.
                case_study_axis_style = dict(
                    showgrid=False,
                    zeroline=False,
                    linecolor='black',
                    linewidth=1.5,
                    ticks="outside",
                    ticklen=5,
                    tickwidth=1.5,
                    tickcolor='black',
                )
                fig.update_xaxes(range=[0, 105], title_text="Percentage", row=row_num, col=1, **case_study_axis_style)
                fig.update_yaxes(autorange="reversed", title_text="", row=row_num, col=1, **case_study_axis_style)

                # Text lines shown with each subplot. A missing metric is rendered as a
                # single "<name>: N/A" line. The richness line previously embedded its
                # fallback inside the bold label, producing
                # "Observed Richness (Tree Tips): Observed Richness: N/A".
                annotations_texts = [
                    f"<b>Group:</b> {row['Group']}",
                    f"<b>Total Members:</b> {int(row['Total_Members']) if not pd.isna(row['Total_Members']) else 'N/A'}",
                    (
                        f"<b>Shannon Entropy (Tree):</b> {row['Shannon_Entropy']:.2f}"
                        if not pd.isna(row['Shannon_Entropy'])
                        else "Shannon Entropy: N/A"
                    ),
                    (
                        f"<b>Observed Richness (Tree Tips):</b> {int(row['Observed_Richness'])}"
                        if not pd.isna(row['Observed_Richness'])
                        else "Observed Richness: N/A"
                    ),
                ]
                
                # Anchor the text lines to this subplot's y axis (x is in paper
                # coordinates); each successive line is placed 0.25 axis units further
                # along, starting at -0.6.
                current_subplot_yaxis = "y" + str(row_num)
                annotation_common_style = dict(
                    align='left',
                    showarrow=False,
                    xref='paper',
                    yref=current_subplot_yaxis,
                    x=0.02,
                    xanchor='left',
                    yanchor='top',
                    font=dict(size=10, color="black"),
                )

                for k, text in enumerate(annotations_texts):
                    annotation_y = -0.6 - (k*0.25)
                    fig.add_annotation(text=text, y=annotation_y, **annotation_common_style)
            
            # Overall figure layout: height grows with the number of case studies, the
            # legend is hidden, backgrounds are transparent, and the top/bottom margins
            # grow with the number of subplots and annotation lines.
            fig.update_layout(
                height=total_subplot_height * num_case_studies + 100, 
                showlegend=False,
                plot_bgcolor='rgba(0,0,0,0)', 
                paper_bgcolor='rgba(0,0,0,0)',
                margin=dict(l=150, r=50, t=50 + (num_case_studies*25), b=50 + (num_case_studies * annotation_lines * 5)) 
            )
            
            # Apply font family/size to subplot titles, x-axis titles and y-axis tick labels.
            # NOTE(review): the attribute access below assumes plotly_layout_defaults is a
            # go.Layout-like object (not a plain dict) - confirm against Cell 1.
            if 'plotly_layout_defaults' in locals():
                # Fall back to 'Arial' / size 12 when the defaults do not provide a value.
                default_font_family = plotly_layout_defaults.font.family if plotly_layout_defaults.font else 'Arial'
                default_font_size_axis_title = plotly_layout_defaults.xaxis.title.font.size if plotly_layout_defaults.xaxis and plotly_layout_defaults.xaxis.title and plotly_layout_defaults.xaxis.title.font else 12
                
                for i in range(1, num_case_studies + 1):
                    # Subplot titles (annotations for subplots)
                    # Plotly subplot titles are actually part of the layout.annotations
                    # The first num_case_studies annotations are the subplot titles if subplot_titles was used.
                    if len(fig.layout.annotations) >= i: 
                         fig.layout.annotations[i-1].font.size=14 
                         fig.layout.annotations[i-1].font.family=default_font_family
                    
                    # X-axis titles for each subplot
                    if hasattr(fig.layout, f'xaxis{i}') and fig.layout[f'xaxis{i}'].title:
                        fig.layout[f'xaxis{i}'].title.font.size = default_font_size_axis_title
                        fig.layout[f'xaxis{i}'].title.font.family = default_font_family
                    
                    # Y-axis tick labels for each subplot
                    if hasattr(fig.layout, f'yaxis{i}'):
                        fig.layout[f'yaxis{i}'].tickfont.size = 11 # Adjust y-axis category label size
                        fig.layout[f'yaxis{i}'].tickfont.family = default_font_family
            # End of font styling.

            fig.show()

            # Export: HTML always; PDF and SVG are best-effort (a failure only warns) and
            # share one explicit pixel size matching the layout height set on the figure.
            case_study_image_size = dict(width=800, height=total_subplot_height * num_case_studies + 100)

            fig_case_study_path_html = Path(output_figure_dir) / "figure8_case_studies_summary.html"
            fig.write_html(str(fig_case_study_path_html))
            print(f"Figure 8 (Case Studies) HTML saved to: {fig_case_study_path_html}")

            fig_case_study_path_pdf = Path(output_figure_dir) / "figure8_case_studies_summary.pdf"
            try:
                fig.write_image(str(fig_case_study_path_pdf), **case_study_image_size)
            except Exception as e:
                print(f"Warning: Could not export Figure 8 (Case Studies) to PDF. Error: {e}")
            else:
                print(f"Figure 8 (Case Studies) PDF saved to: {fig_case_study_path_pdf}")

            fig_case_study_path_svg = Path(output_figure_dir) / "figure8_case_studies_summary.svg"
            try:
                fig.write_image(str(fig_case_study_path_svg), **case_study_image_size)
            except Exception as e:
                print(f"Warning: Could not export Figure 8 (Case Studies) to SVG. Error: {e}")
            else:
                print(f"Figure 8 (Case Studies) SVG saved to: {fig_case_study_path_svg}")

            # Persist the table behind the figure next to the other phase-1 summaries.
            case_study_summary_path = Path(output_summary_dir_phase1) / "figure8_case_studies_data.csv"
            try:
                df_plot_data.to_csv(case_study_summary_path, index=False)
            except Exception as e:
                print(f"Error saving case study summary data: {e}")
            else:
                print(f"Case study summary data saved to: {case_study_summary_path}")
        else:
            print("No data processed for plotting case studies after gathering all metrics.")

print("\n\n--- Cell 11 (Figure 8 Case Studies) Complete ---")
print(f"Figures and summaries saved to '{output_figure_dir}' and '{output_summary_dir_phase1}'.")



In [None]:
# Cell 12: Figure 8 (User's numbering) - Case Studies: Faceted by Metric

# --- Imports & Setup Assumptions ---
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from pathlib import Path

# This cell assumes that pandas (pd), numpy (np), plotly.express (px),
# plotly.graph_objects (go), Path from pathlib, make_subplots,
# and all necessary variables/helper functions from Cell 1 (Setup) are available.
# Key variables from Cell 1 & subsequent cells:
# - df_full: Main DataFrame with all protein annotations.
# - df_og_summary: DataFrame created in Cell 9, containing OG-level summaries including Intra_OG_APSI.
# - output_figure_dir, output_summary_dir_phase1: Directories to save plots/data.
# - plotly_layout_defaults: Default layout for Plotly figures.
# - arcadia_primary_palette, arcadia_secondary_palette: Color palettes.
# - orthogroup_col, protein_id_col, structurally_dark_col, apsi_col ('Intra_OG_APSI'), 
#   group_col, broad_func_cat_col

print("\n\n--- Generating Figure 8: Case Studies (Faceted by Metric) (Cell 12) ---")

# --- Configuration for Case Studies ---
# Semicolon-separated reference protein IDs; the orthogroup each one belongs to
# becomes a case study further down in this cell.
case_study_protein_ids_str = "UYP47028.1;KKK44605.1;KKK42122.1;OLS22855.1;TFG12995.1;OLS30618.1;TET76256.1;OLS23013.1;NHJ47629.1"
case_study_protein_ids = list(map(str.strip, case_study_protein_ids_str.split(';')))

diversity_metrics_path = "orthogroup_diversity_metrics.csv"  # CSV holding the per-orthogroup diversity metrics

# --- Pre-flight checks: required DataFrames, then required columns ---
error_found = False
required_dataframes = ['df_full', 'df_og_summary']
for df_name in required_dataframes:
    # A frame is usable only if the name is defined AND it holds at least one row.
    frame_is_usable = df_name in locals() and not locals()[df_name].empty
    if not frame_is_usable:
        print(f"ERROR: DataFrame '{df_name}' not found or is empty. Please run prerequisite cells.")
        error_found = True

required_cols_df_full = [orthogroup_col, protein_id_col, structurally_dark_col, group_col, broad_func_cat_col]
required_cols_df_og_summary = [orthogroup_col, apsi_col]

# Column checks are only attempted once both frames are known to exist.
if not error_found:
    missing = [col for col in required_cols_df_full if col not in df_full.columns]
    if missing:
        print(f"ERROR: 'df_full' is missing required columns: {missing}.")
        error_found = True

    missing = [col for col in required_cols_df_og_summary if col not in df_og_summary.columns]
    if missing:
        print(f"ERROR: 'df_og_summary' is missing required columns: {missing}.")
        error_found = True

if error_found:
    print("Cannot proceed with case study generation due to missing data or columns.")
else:
    # --- 1. Load Diversity Metrics ---
    # Reads the diversity CSV and reduces it to the orthogroup key plus the two
    # metrics used below ('Tree_n_tips' -> 'Observed_Richness',
    # 'Tree_entropy' -> 'Shannon_Entropy'). Whatever happens, `df_diversity_subset`
    # is defined afterwards AND carries the orthogroup key column, because the
    # per-orthogroup lookups later in this cell index it by `orthogroup_col`.
    print(f"\n--- Loading Orthogroup Diversity Metrics from: {diversity_metrics_path} ---")
    try:
        df_diversity = pd.read_csv(diversity_metrics_path)
        # Some exports name the key column 'OG_ID'; normalise it to the notebook's name.
        if 'OG_ID' in df_diversity.columns and orthogroup_col not in df_diversity.columns:
            df_diversity = df_diversity.rename(columns={'OG_ID': orthogroup_col})

        if orthogroup_col not in df_diversity.columns:
            # Without a key column no orthogroup can be matched; fall back to the
            # empty schema instead of letting the later key lookup raise KeyError.
            print(f"ERROR: Diversity metrics file '{diversity_metrics_path}' has no '{orthogroup_col}' or 'OG_ID' column. Metrics will be missing.")
            df_diversity_subset = pd.DataFrame(columns=[orthogroup_col, 'Observed_Richness', 'Shannon_Entropy'])
        else:
            diversity_cols_to_use = {
                orthogroup_col: orthogroup_col,
                'Tree_n_tips': 'Observed_Richness',
                'Tree_entropy': 'Shannon_Entropy'
            }
            # Keep only the source columns that are actually present in the file.
            actual_diversity_cols = {k: v for k, v in diversity_cols_to_use.items() if k in df_diversity.columns}
            df_diversity_subset = df_diversity[list(actual_diversity_cols.keys())].rename(columns=actual_diversity_cols).copy()
            print(f"Loaded and processed diversity metrics for {len(df_diversity_subset)} orthogroups.")
            if df_diversity_subset.empty:
                print("WARNING: Diversity metrics DataFrame is empty.")

    except FileNotFoundError:
        print(f"ERROR: Diversity metrics file '{diversity_metrics_path}' not found. Metrics will be missing.")
        df_diversity_subset = pd.DataFrame(columns=[orthogroup_col, 'Observed_Richness', 'Shannon_Entropy'])
    except Exception as e:
        # Best-effort: any other read/parse problem degrades to "no metrics" rather than aborting the cell.
        print(f"Error loading diversity metrics: {e}")
        df_diversity_subset = pd.DataFrame(columns=[orthogroup_col, 'Observed_Richness', 'Shannon_Entropy'])

    # --- 2. Process Input Protein IDs to Get Unique Orthogroups & Ref Protein Info ---
    # Maps every reference protein to its orthogroup and keeps the FIRST reference
    # protein seen for each orthogroup (later proteins of the same OG are ignored).
    print("\n--- Identifying Unique Orthogroups and Reference Protein Info ---")
    case_study_ogs_info_list = []
    seen_ogs = set()

    for ref_pid in case_study_protein_ids:
        og_entry = df_full[df_full[protein_id_col] == ref_pid]
        if og_entry.empty:
            print(f"Warning: Reference Protein ID {ref_pid} not found in df_full.")
            continue

        og_id = og_entry[orthogroup_col].iloc[0]
        if pd.isna(og_id):
            # A missing orthogroup can never be matched by an equality filter
            # (NaN != NaN) and is not reliably de-duplicated by a set, so it would
            # inflate the case-study count and then vanish silently downstream.
            print(f"Warning: Reference Protein ID {ref_pid} has no orthogroup assignment in df_full. Skipping.")
            continue
        if og_id in seen_ogs:
            continue

        ref_func_cat = og_entry[broad_func_cat_col].iloc[0] if broad_func_cat_col in og_entry.columns else "N/A"
        case_study_ogs_info_list.append({
            'Ref_ProteinID': ref_pid,
            orthogroup_col: og_id,
            'Ref_Function': ref_func_cat
        })
        seen_ogs.add(og_id)
    
    if not case_study_ogs_info_list:
        print("ERROR: No valid orthogroups found. Cannot create case studies.")
    else:
        df_case_study_refs = pd.DataFrame(case_study_ogs_info_list)
        print(f"Identified {len(df_case_study_refs)} unique orthogroups for case studies.")

        # --- 3. Gather Data for Each Case Study OG ---
        # Builds one record per case-study orthogroup combining: member count,
        # APSI (scaled by 100), share of structurally dark members, and the two
        # diversity metrics (NaN whenever a value cannot be looked up).
        plot_data_list = []
        for ref_idx, ref_row in df_case_study_refs.iterrows():
            og_id = ref_row[orthogroup_col]

            og_member_mask = df_full[orthogroup_col] == og_id
            df_og_members = df_full[og_member_mask]
            if df_og_members.empty:
                continue

            num_members = len(df_og_members)
            group_name = df_og_members[group_col].iloc[0]

            # APSI comes from the OG-level summary table; absent OGs yield NaN.
            apsi_value_series = df_og_summary[df_og_summary[orthogroup_col] == og_id][apsi_col]
            if apsi_value_series.empty:
                apsi_value = np.nan
            else:
                apsi_value = apsi_value_series.iloc[0] * 100

            # num_members is at least 1 here because empty member sets were skipped above.
            dark_member_total = df_og_members[structurally_dark_col].sum()
            percent_dark = (dark_member_total / num_members) * 100

            diversity_entry = df_diversity_subset[df_diversity_subset[orthogroup_col] == og_id]
            has_diversity_row = not diversity_entry.empty
            if has_diversity_row and 'Shannon_Entropy' in diversity_entry.columns:
                shannon_entropy = diversity_entry['Shannon_Entropy'].iloc[0]
            else:
                shannon_entropy = np.nan
            if has_diversity_row and 'Observed_Richness' in diversity_entry.columns:
                observed_richness = diversity_entry['Observed_Richness'].iloc[0]
            else:
                observed_richness = np.nan

            ref_function = ref_row['Ref_Function']
            plot_data_list.append({
                orthogroup_col: og_id,
                'Ref_ProteinID': ref_row['Ref_ProteinID'],
                'Ref_Function': ref_function,
                'APSI (%)': apsi_value,
                '% Structurally Dark': percent_dark,
                'Shannon_Entropy': shannon_entropy,
                'Observed_Richness': observed_richness,
                'Total_Members': num_members,
                'Group': group_name,
                # Two-line tick label: orthogroup ID above its reference function
                'OG_Label': f"{og_id}<br>({ref_function})"
            })

        df_plot_data = pd.DataFrame(plot_data_list)

        # --- 4. Create the Faceted Figure ---
        # One subplot row per metric and one bar per case-study orthogroup; all
        # rows share the same categorical x-axis (the OG labels).
        if df_plot_data.empty:
            print("No data processed for plotting case studies after gathering all metrics.")
        else:
            num_case_studies = len(df_plot_data)

            # Give each orthogroup a fixed colour, cycling through the combined palette.
            og_color_palette = arcadia_primary_palette + arcadia_secondary_palette
            palette_size = len(og_color_palette)
            og_color_map = {
                og: og_color_palette[position % palette_size]
                for position, og in enumerate(df_plot_data[orthogroup_col].unique())
            }

            # Insertion order of this mapping defines the top-to-bottom subplot order.
            metric_titles = {
                'APSI (%)': 'Intra-OG APSI (%)',
                '% Structurally Dark': '% Structurally Dark Members',
                'Shannon_Entropy': 'Shannon Entropy (Tree-based)',
                'Observed_Richness': 'Observed Richness (Tree Tips)'
            }
            metrics_to_plot = list(metric_titles)
            figure_height = 250 * len(metrics_to_plot)

            # Axis line/tick styling applied to every x- and y-axis of the figure.
            axis_line_style = dict(
                showgrid=False, zeroline=False, linecolor='black', linewidth=1.5,
                ticks="outside", ticklen=5, tickwidth=1.5, tickcolor='black'
            )

            fig = make_subplots(
                rows=len(metrics_to_plot),
                cols=1,
                subplot_titles=[metric_titles[metric] for metric in metrics_to_plot],
                shared_xaxes=True,  # every row uses the same OG categories
                vertical_spacing=0.1
            )

            for row_num, metric in enumerate(metrics_to_plot, start=1):
                for og_idx, og_data_row in df_plot_data.iterrows():
                    metric_value = og_data_row[metric]
                    value_missing = pd.isna(metric_value)

                    # Bar annotation: integer for richness, one decimal for float
                    # metrics, "N/A" for anything missing or non-float.
                    if value_missing:
                        bar_label = "N/A"
                    elif metric == 'Observed_Richness':
                        bar_label = f"{int(metric_value)}"
                    elif isinstance(metric_value, float):
                        bar_label = f"{metric_value:.1f}"
                    else:
                        bar_label = "N/A"

                    og_name = og_data_row[orthogroup_col]
                    fig.add_trace(go.Bar(
                        x=[og_data_row['OG_Label']],
                        y=[0 if value_missing else metric_value],  # missing values are drawn as zero-height bars
                        name=og_name,
                        legendgroup=og_name,  # ties the OG's bars in all rows to one legend entry
                        marker_color=og_color_map.get(og_name),
                        showlegend=(row_num == 1),  # legend entries come from the first row only
                        text=[bar_label],
                        textposition='outside'
                    ), row=row_num, col=1)

                fig.update_yaxes(title_text="", row=row_num, col=1, **axis_line_style)
                if metric in ('APSI (%)', '% Structurally Dark'):
                    # Percentage metrics get a fixed range with headroom for the outside labels.
                    fig.update_yaxes(range=[0, 105], row=row_num, col=1)

                # Carry the notebook-wide default fonts over to this row's y-axis title.
                if 'plotly_layout_defaults' in locals():
                    defaults_font = plotly_layout_defaults.font
                    default_font_family = defaults_font.family if defaults_font else 'Arial'

                    defaults_xaxis = plotly_layout_defaults.xaxis
                    if defaults_xaxis and defaults_xaxis.title and defaults_xaxis.title.font:
                        default_font_size_axis_title = defaults_xaxis.title.font.size
                    else:
                        default_font_size_axis_title = 12
                    fig.layout[f'yaxis{row_num}'].title.font.update(size=default_font_size_axis_title, family=default_font_family)

            fig.update_xaxes(tickangle=0, **axis_line_style)  # horizontal tick labels

            # Subplot titles are stored as layout annotations; restyle them here.
            for annotation in fig.layout.annotations:
                annotation.font.size = 14
                if 'plotly_layout_defaults' in locals() and plotly_layout_defaults.font:
                    annotation.font.family = plotly_layout_defaults.font.family

            fig.update_layout(
                height=figure_height,
                plot_bgcolor='rgba(0,0,0,0)',
                paper_bgcolor='rgba(0,0,0,0)',
                margin=dict(l=50, r=50, t=80, b=150),  # large bottom margin leaves room for two-line tick labels
                barmode='group',
                legend_title_text='Orthogroup'
            )
            fig.show()

            # --- Save Figure and Summary Data ---
            # HTML is written unguarded; the static exports and the CSV report
            # failures and let the cell continue.
            export_width = max(800, 120*num_case_studies)
            export_height = 250*len(metrics_to_plot)

            fig_path_html = Path(output_figure_dir) / "figure8_case_studies_faceted.html"
            fig.write_html(str(fig_path_html))
            print(f"Figure 8 (Faceted Case Studies) HTML saved to: {fig_path_html}")

            try:
                fig_path_pdf = Path(output_figure_dir) / "figure8_case_studies_faceted.pdf"
                fig.write_image(str(fig_path_pdf), width=export_width, height=export_height)
                print(f"Figure 8 (Faceted Case Studies) PDF saved to: {fig_path_pdf}")
            except Exception as e:
                print(f"Warning: Could not export Figure 8 to PDF. Error: {e}")

            try:
                fig_path_svg = Path(output_figure_dir) / "figure8_case_studies_faceted.svg"
                fig.write_image(str(fig_path_svg), width=export_width, height=export_height)
                print(f"Figure 8 (Faceted Case Studies) SVG saved to: {fig_path_svg}")
            except Exception as e:
                print(f"Warning: Could not export Figure 8 to SVG. Error: {e}")

            summary_path = Path(output_summary_dir_phase1) / "figure8_case_studies_faceted_data.csv"
            try:
                df_plot_data.to_csv(summary_path, index=False)
                print(f"Case study summary data saved to: {summary_path}")
            except Exception as e:
                print(f"Error saving case study summary data: {e}")

print("\n\n--- Cell 12 (Figure 8 Faceted Case Studies) Complete ---")



In [None]:
# Cell 13: Figure 9 - Overview of Orthogroup Diversity Metrics

# --- Imports & Setup Assumptions ---
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from pathlib import Path

# This cell assumes that pandas (pd), numpy (np), plotly.express (px),
# plotly.graph_objects (go), Path from pathlib,
# and all necessary variables/helper functions from Cell 1 (Setup) are available.
# Key variables from Cell 1 & subsequent cells:
# - df_og_summary: DataFrame created in Cell 9, containing OG-level summaries (Orthogroup, Intra_OG_APSI, Group).
# - output_figure_dir, output_summary_dir_phase1: Directories to save plots/data.
# - plotly_layout_defaults: Default layout for Plotly figures.
# - arcadia_primary_palette, group_colors: Color palettes/maps.
# - orthogroup_col, apsi_col ('Intra_OG_APSI'), group_col

print("\n\n--- Generating Figure 9: Overview of Orthogroup Diversity Metrics (Cell 13) ---")

# --- Configuration ---
diversity_metrics_path = "orthogroup_diversity_metrics.csv"  # CSV holding the per-orthogroup diversity metrics

# --- Pre-flight check: df_og_summary must exist, be non-empty and carry the key columns ---
error_found = False
og_summary_available = 'df_og_summary' in locals() and not df_og_summary.empty
if not og_summary_available:
    print(f"ERROR: DataFrame 'df_og_summary' (with OG-level APSI and Group) not found or is empty. Please run Cell 9 first.")
    error_found = True
else:
    # Columns are inspected only once the frame itself is known to be usable.
    missing_cols = [col for col in [orthogroup_col, apsi_col, group_col] if col not in df_og_summary.columns]
    if missing_cols:
        print(f"ERROR: 'df_og_summary' is missing required columns: {missing_cols}.")
        error_found = True

if error_found:
    print("Cannot proceed with diversity overview generation due to missing base data.")
    # Create an empty df_merged_diversity to prevent further errors in this cell if it's used later
    df_merged_diversity = pd.DataFrame()
else:
    # --- 1. Load Diversity Metrics ---
    # Reads the diversity CSV and narrows it to the orthogroup key plus the two
    # renamed metric columns. On any failure `df_diversity_subset` is set to an
    # empty frame with the expected columns so the following steps still run.
    print(f"\n--- Loading Orthogroup Diversity Metrics from: {diversity_metrics_path} ---")
    try:
        df_diversity = pd.read_csv(diversity_metrics_path)

        # Accept 'OG_ID' as an alternative name for the key column.
        needs_key_rename = 'OG_ID' in df_diversity.columns and orthogroup_col not in df_diversity.columns
        if needs_key_rename:
            df_diversity = df_diversity.rename(columns={'OG_ID': orthogroup_col})

        # Source column -> name used in the rest of the notebook.
        diversity_cols_to_use = {
            orthogroup_col: orthogroup_col,
            'Tree_n_tips': 'Observed_Richness',
            'Tree_entropy': 'Shannon_Entropy'
        }
        actual_diversity_cols = {
            source_name: target_name
            for source_name, target_name in diversity_cols_to_use.items()
            if source_name in df_diversity.columns
        }
        df_diversity_subset = df_diversity[list(actual_diversity_cols)].rename(columns=actual_diversity_cols).copy()

        # The merge key is compared as text, so normalise it to str here.
        if orthogroup_col in df_diversity_subset.columns:
            df_diversity_subset[orthogroup_col] = df_diversity_subset[orthogroup_col].astype(str)

        print(f"Loaded and processed diversity metrics for {len(df_diversity_subset)} orthogroups.")
        if df_diversity_subset.empty:
            print("WARNING: Diversity metrics DataFrame is empty.")

    except FileNotFoundError:
        print(f"ERROR: Diversity metrics file '{diversity_metrics_path}' not found. Diversity metrics will be missing.")
        df_diversity_subset = pd.DataFrame(columns=[orthogroup_col, 'Observed_Richness', 'Shannon_Entropy'])
    except Exception as e:
        print(f"Error loading or processing diversity metrics file: {e}")
        df_diversity_subset = pd.DataFrame(columns=[orthogroup_col, 'Observed_Richness', 'Shannon_Entropy'])

    # --- 2. Merge Diversity Metrics with df_og_summary (contains APSI and Group) ---
    # Produces `df_merged_diversity`: one row per orthogroup present in BOTH the
    # OG summary and the diversity table (inner join on the orthogroup key).
    # It is always defined afterwards, as an empty frame when no merge is possible.
    print("\n--- Merging Diversity Metrics with OG Summary Data (APSI, Group) ---")
    if not df_diversity_subset.empty and orthogroup_col in df_diversity_subset.columns:
        # Cast the join key to str on a private copy: df_og_summary is shared with
        # other cells, and rewriting its key column in place would silently change
        # its dtype (and turn missing keys into the literal string 'nan') for them.
        og_summary_for_merge = df_og_summary.copy()
        og_summary_for_merge[orthogroup_col] = og_summary_for_merge[orthogroup_col].astype(str)

        df_merged_diversity = pd.merge(og_summary_for_merge, df_diversity_subset, on=orthogroup_col, how='inner')
        print(f"Merged diversity data. Resulting DataFrame shape: {df_merged_diversity.shape}")
        if df_merged_diversity.empty:
            print("WARNING: Merged diversity DataFrame is empty. No common OGs found or one of the DFs was empty.")
    else:
        print("Skipping merge as diversity data is empty or missing orthogroup column.")
        df_merged_diversity = pd.DataFrame()  # Ensure it's defined as empty

    # --- 3. Generate Plots ---
    if not df_merged_diversity.empty:
        
        # A. Distribution of Shannon Entropy
        # Overlaid per-group histograms with a marginal box plot. Rows without an
        # entropy value are left out of the plot only, not out of the merged table.
        if 'Shannon_Entropy' not in df_merged_diversity.columns:
            print("Skipping Shannon Entropy distribution: column not found.")
        else:
            print("\n--- Plotting Distribution of Shannon Entropy ---")
            shannon_rows = df_merged_diversity.dropna(subset=['Shannon_Entropy'])
            fig_shannon = px.histogram(
                shannon_rows,
                x='Shannon_Entropy',
                color=group_col,
                marginal="box",
                barmode='overlay',
                labels={'Shannon_Entropy': 'Shannon Entropy (Tree-based)'},
                color_discrete_map=group_colors
            )
            fig_shannon.update_layout(plotly_layout_defaults)
            fig_shannon.update_xaxes(title_text='Shannon Entropy (Tree-based)', showgrid=False)
            fig_shannon.update_yaxes(title_text='Number of Orthogroups', showgrid=False)
            fig_shannon.update_traces(opacity=0.75)  # translucency keeps overlaid groups readable
            fig_shannon.show()

            # HTML is written unguarded; each static export reports its own failure
            # and lets the cell continue.
            fig_shannon_path_html = Path(output_figure_dir) / "figure9a_shannon_entropy_distribution.html"
            fig_shannon.write_html(str(fig_shannon_path_html))
            print(f"Figure 9A (Shannon Entropy) HTML saved to: {fig_shannon_path_html}")

            try:
                fig_shannon_path_pdf = Path(output_figure_dir) / "figure9a_shannon_entropy_distribution.pdf"
                fig_shannon.write_image(str(fig_shannon_path_pdf))
            except Exception as e:
                print(f"PDF export error for Shannon Entropy: {e}")

            try:
                fig_shannon_path_svg = Path(output_figure_dir) / "figure9a_shannon_entropy_distribution.svg"
                fig_shannon.write_image(str(fig_shannon_path_svg))
            except Exception as e:
                print(f"SVG export error for Shannon Entropy: {e}")

        # B. Distribution of Observed Richness
        # Same layout as panel A, for the tree-tip count per orthogroup.
        if 'Observed_Richness' not in df_merged_diversity.columns:
            print("Skipping Observed Richness distribution: column not found.")
        else:
            print("\n--- Plotting Distribution of Observed Richness (Tree Tips) ---")
            richness_rows = df_merged_diversity.dropna(subset=['Observed_Richness'])
            fig_richness = px.histogram(
                richness_rows,
                x='Observed_Richness',
                color=group_col,
                marginal="box",
                barmode='overlay',
                labels={'Observed_Richness': 'Observed Richness (Number of Tree Tips)'},
                color_discrete_map=group_colors
            )
            fig_richness.update_layout(plotly_layout_defaults)
            # A log x-axis (type="log") is an option if the distribution is highly skewed.
            fig_richness.update_xaxes(title_text='Observed Richness (Number of Tree Tips)', showgrid=False)
            fig_richness.update_yaxes(title_text='Number of Orthogroups', showgrid=False)
            fig_richness.update_traces(opacity=0.75)
            fig_richness.show()

            # HTML first (unguarded), then best-effort static exports.
            fig_richness_path_html = Path(output_figure_dir) / "figure9b_observed_richness_distribution.html"
            fig_richness.write_html(str(fig_richness_path_html))
            print(f"Figure 9B (Observed Richness) HTML saved to: {fig_richness_path_html}")

            try:
                fig_richness_path_pdf = Path(output_figure_dir) / "figure9b_observed_richness_distribution.pdf"
                fig_richness.write_image(str(fig_richness_path_pdf))
            except Exception as e:
                print(f"PDF export error for Richness: {e}")

            try:
                fig_richness_path_svg = Path(output_figure_dir) / "figure9b_observed_richness_distribution.svg"
                fig_richness.write_image(str(fig_richness_path_svg))
            except Exception as e:
                print(f"SVG export error for Richness: {e}")

        # C. Shannon Entropy vs. Observed Richness
        # Scatter coloured by group with a single OLS trendline fitted over all points.
        have_entropy_and_richness = (
            'Shannon_Entropy' in df_merged_diversity.columns
            and 'Observed_Richness' in df_merged_diversity.columns
        )
        if not have_entropy_and_richness:
            print("Skipping Shannon Entropy vs. Observed Richness plot: columns not found.")
        else:
            print("\n--- Plotting Shannon Entropy vs. Observed Richness ---")
            entropy_richness_rows = df_merged_diversity.dropna(subset=['Shannon_Entropy', 'Observed_Richness'])
            fig_ent_vs_rich = px.scatter(
                entropy_richness_rows,
                x='Observed_Richness',
                y='Shannon_Entropy',
                color=group_col,
                labels={'Observed_Richness': 'Observed Richness (Tree Tips)', 'Shannon_Entropy': 'Shannon Entropy (Tree-based)'},
                color_discrete_map=group_colors,
                opacity=0.7,
                trendline="ols",            # ordinary least squares fit
                trendline_scope="overall"   # one fit across groups ("trace" would fit per group)
            )
            fig_ent_vs_rich.update_layout(plotly_layout_defaults)
            # A log x-axis (type="log") is an option for skewed richness values.
            fig_ent_vs_rich.update_xaxes(title_text='Observed Richness (Tree Tips)', showgrid=False)
            fig_ent_vs_rich.update_yaxes(title_text='Shannon Entropy (Tree-based)', showgrid=False)
            fig_ent_vs_rich.show()

            # HTML first (unguarded), then best-effort static exports.
            fig_ent_rich_path_html = Path(output_figure_dir) / "figure9c_entropy_vs_richness.html"
            fig_ent_vs_rich.write_html(str(fig_ent_rich_path_html))
            print(f"Figure 9C (Entropy vs Richness) HTML saved to: {fig_ent_rich_path_html}")

            try:
                fig_ent_rich_path_pdf = Path(output_figure_dir) / "figure9c_entropy_vs_richness.pdf"
                fig_ent_vs_rich.write_image(str(fig_ent_rich_path_pdf))
            except Exception as e:
                print(f"PDF export error for Entropy vs Richness: {e}")

            try:
                fig_ent_rich_path_svg = Path(output_figure_dir) / "figure9c_entropy_vs_richness.svg"
                fig_ent_vs_rich.write_image(str(fig_ent_rich_path_svg))
            except Exception as e:
                print(f"SVG export error for Entropy vs Richness: {e}")

        # D. Shannon Entropy vs. APSI
        # Scatter of entropy against APSI (scaled by 100), coloured by group, with
        # one overall OLS trendline.
        have_entropy_and_apsi = (
            'Shannon_Entropy' in df_merged_diversity.columns
            and apsi_col in df_merged_diversity.columns
        )
        if not have_entropy_and_apsi:
            print("Skipping Shannon Entropy vs. APSI plot: columns not found.")
        else:
            print("\n--- Plotting Shannon Entropy vs. APSI ---")
            # Adds a derived percentage column to the merged table; it is also written to the CSV later.
            df_merged_diversity['APSI_Percent'] = df_merged_diversity[apsi_col] * 100

            entropy_apsi_rows = df_merged_diversity.dropna(subset=['Shannon_Entropy', 'APSI_Percent'])
            fig_ent_vs_apsi = px.scatter(
                entropy_apsi_rows,
                x='APSI_Percent',
                y='Shannon_Entropy',
                color=group_col,
                labels={'APSI_Percent': 'Intra-OG APSI (%)', 'Shannon_Entropy': 'Shannon Entropy (Tree-based)'},
                color_discrete_map=group_colors,
                opacity=0.7,
                trendline="ols",
                trendline_scope="overall"
            )
            fig_ent_vs_apsi.update_layout(plotly_layout_defaults)
            fig_ent_vs_apsi.update_xaxes(title_text='Intra-OG APSI (%)', showgrid=False)
            fig_ent_vs_apsi.update_yaxes(title_text='Shannon Entropy (Tree-based)', showgrid=False)
            fig_ent_vs_apsi.show()

            # HTML first (unguarded), then best-effort static exports.
            fig_ent_apsi_path_html = Path(output_figure_dir) / "figure9d_entropy_vs_apsi.html"
            fig_ent_vs_apsi.write_html(str(fig_ent_apsi_path_html))
            print(f"Figure 9D (Entropy vs APSI) HTML saved to: {fig_ent_apsi_path_html}")

            try:
                fig_ent_apsi_path_pdf = Path(output_figure_dir) / "figure9d_entropy_vs_apsi.pdf"
                fig_ent_vs_apsi.write_image(str(fig_ent_apsi_path_pdf))
            except Exception as e:
                print(f"PDF export error for Entropy vs APSI: {e}")

            try:
                fig_ent_apsi_path_svg = Path(output_figure_dir) / "figure9d_entropy_vs_apsi.svg"
                fig_ent_vs_apsi.write_image(str(fig_ent_apsi_path_svg))
            except Exception as e:
                print(f"SVG export error for Entropy vs APSI: {e}")

        # E. Observed Richness vs. APSI
        # Scatter of observed richness against APSI (scaled by 100), coloured by
        # group, with one overall OLS trendline.
        if 'Observed_Richness' in df_merged_diversity.columns and apsi_col in df_merged_diversity.columns:
            print("\n--- Plotting Observed Richness vs. APSI ---")
            # Derive APSI_Percent here instead of relying on panel D having run:
            # panel D is skipped when 'Shannon_Entropy' is absent, which would leave
            # the column undefined and make the dropna() below raise KeyError.
            # Recomputing is idempotent when the column already exists.
            df_merged_diversity['APSI_Percent'] = df_merged_diversity[apsi_col] * 100

            fig_rich_vs_apsi = px.scatter(
                df_merged_diversity.dropna(subset=['Observed_Richness', 'APSI_Percent']),
                x='APSI_Percent',
                y='Observed_Richness',
                color=group_col,
                # title='Observed Richness vs. Intra-OG APSI', # No title
                labels={'APSI_Percent': 'Intra-OG APSI (%)', 'Observed_Richness': 'Observed Richness (Tree Tips)'},
                color_discrete_map=group_colors,
                opacity=0.7,
                trendline="ols",
                trendline_scope="overall"
            )
            fig_rich_vs_apsi.update_layout(plotly_layout_defaults)
            fig_rich_vs_apsi.update_xaxes(title_text='Intra-OG APSI (%)', showgrid=False)
            fig_rich_vs_apsi.update_yaxes(title_text='Observed Richness (Tree Tips)', showgrid=False) # Consider log scale: type="log"
            fig_rich_vs_apsi.show()

            # HTML is written unguarded; each static export reports its own failure
            # and lets the cell continue.
            fig_rich_apsi_path_html = Path(output_figure_dir) / "figure9e_richness_vs_apsi.html"
            fig_rich_vs_apsi.write_html(str(fig_rich_apsi_path_html))
            print(f"Figure 9E (Richness vs APSI) HTML saved to: {fig_rich_apsi_path_html}")
            try:
                fig_rich_apsi_path_pdf = Path(output_figure_dir) / "figure9e_richness_vs_apsi.pdf"
                fig_rich_vs_apsi.write_image(str(fig_rich_apsi_path_pdf))
            except Exception as e:
                print(f"PDF export error for Richness vs APSI: {e}")
            try:
                fig_rich_apsi_path_svg = Path(output_figure_dir) / "figure9e_richness_vs_apsi.svg"
                fig_rich_vs_apsi.write_image(str(fig_rich_apsi_path_svg))
            except Exception as e:
                print(f"SVG export error for Richness vs APSI: {e}")
        else:
            print("Skipping Observed Richness vs. APSI plot: columns not found.")
            
        # Persist the merged table (including any derived columns added above).
        # A write failure is reported but does not abort the cell.
        merged_diversity_summary_path = Path(output_summary_dir_phase1).joinpath(
            "figure9_merged_diversity_and_apsi_data.csv"
        )
        try:
            df_merged_diversity.to_csv(merged_diversity_summary_path, index=False)
            print(f"\nMerged diversity and APSI data saved to: {merged_diversity_summary_path}")
        except Exception as e:
            print(f"Error saving merged diversity data: {e}")
            
    else:
        print("Skipping Figure 9 plots as 'df_merged_diversity' is empty.")

print("\n\n--- Cell 13 (Figure 9 - Overview of Orthogroup Diversity Metrics) Complete ---")



In [None]:
# Cell 14: Figure 10 - Orthogroup Diversity Metrics by Broad Functional Category

# --- Imports & Setup Assumptions ---
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from pathlib import Path

# This cell assumes that pandas (pd), numpy (np), plotly.express (px),
# plotly.graph_objects (go), Path from pathlib,
# and all necessary variables/helper functions from Cell 1 (Setup) are available.
# Key variables from previous cells:
# - df_merged_diversity: DataFrame from Cell 13 (contains OG, APSI, Group, Shannon_Entropy, Observed_Richness).
# - df_full: Main DataFrame (used to get Broad_Functional_Category if not in df_merged_diversity).
# - output_figure_dir, output_summary_dir_phase1: Directories to save plots/data.
# - plotly_layout_defaults: Default layout for Plotly figures.
# - group_colors, broad_category_color_map: Color maps.
# - orthogroup_col, apsi_col ('Intra_OG_APSI'), group_col, broad_func_cat_col

print("\n\n--- Generating Figure 10: Orthogroup Diversity Metrics by Functional Category (Cell 14) ---")

# --- Check if necessary DataFrames and columns are available ---
# error_found stays False only when df_merged_diversity exists, is non-empty,
# and carries every metric column the plots below read.
error_found = 'df_merged_diversity' not in locals() or df_merged_diversity.empty
if error_found:
    print("ERROR: DataFrame 'df_merged_diversity' not found or is empty. Please run Cell 13 first.")
else:
    missing_cols = [
        col
        for col in (orthogroup_col, apsi_col, group_col, 'Shannon_Entropy', 'Observed_Richness')
        if col not in df_merged_diversity.columns
    ]
    if missing_cols:
        print(f"ERROR: 'df_merged_diversity' is missing required metrics columns: {missing_cols}.")
        error_found = True

# Ensure the functional-category column is present in df_merged_diversity.
# When it is absent, derive one category per orthogroup from df_full (the most
# frequent non-null value among the OG's members) and left-merge it in.
if not error_found and broad_func_cat_col not in df_merged_diversity.columns:
    print(f"'{broad_func_cat_col}' not found in 'df_merged_diversity'. Attempting to add it.")
    if 'df_full' in locals() and not df_full.empty and orthogroup_col in df_full.columns and broad_func_cat_col in df_full.columns:
        def get_dominant_category(series):
            """Return the most frequent non-null value of *series*.

            Series.mode() ignores NaN and returns tied values sorted, so the
            first entry is a deterministic pick. Falls back to the
            'Unknown/Unclassified' placeholder when the group has no non-null
            values at all.
            """
            mode = series.mode()
            # Positional access: does not depend on the mode result's index labels.
            return mode.iloc[0] if not mode.empty else 'Unknown/Unclassified'

        df_og_func_cat = df_full.groupby(orthogroup_col)[broad_func_cat_col].apply(get_dominant_category).reset_index()
        df_og_func_cat[orthogroup_col] = df_og_func_cat[orthogroup_col].astype(str) # Ensure type for merge

        try:
            # A failed merge (e.g. incompatible key dtypes) raises rather than
            # silently dropping the column, so it has to be caught here; a
            # column-presence check after the merge can never fire.
            df_with_category = pd.merge(df_merged_diversity, df_og_func_cat, on=orthogroup_col, how='left')
        except ValueError as merge_error:
            print(f"ERROR: Failed to add '{broad_func_cat_col}' to 'df_merged_diversity'.")
            print(f"       Merge error: {merge_error}")
            error_found = True
        else:
            # OGs with no members in df_full come back as NaN from the left merge.
            df_with_category[broad_func_cat_col] = df_with_category[broad_func_cat_col].fillna('Unknown/Unclassified')
            df_merged_diversity = df_with_category
            print(f"Added '{broad_func_cat_col}' to 'df_merged_diversity'.")
    else:
        print(f"ERROR: Cannot add '{broad_func_cat_col}'. 'df_full' is missing or lacks necessary columns.")
        error_found = True

if not error_found:
    # --- Create Plots ---
    # Metric column -> axis label. Per the original note APSI is held as a 0-1
    # fraction; it is rescaled to a percentage only on the per-plot copy below.
    metrics_to_plot_vs_func = {
        apsi_col: 'Intra-OG APSI (%)',
        'Shannon_Entropy': 'Shannon Entropy (Tree-based)',
        'Observed_Richness': 'Observed Richness (Tree Tips)'
    }

    # Catch-all categories are dropped for cleaner plots (user can adjust);
    # if that removes everything, fall back to the unfiltered data.
    categories_to_exclude_plot = ['Unknown/Unclassified', 'Other Specific Annotation', 'General Protein Features']
    is_excluded_category = df_merged_diversity[broad_func_cat_col].isin(categories_to_exclude_plot)
    df_plot_func = df_merged_diversity[~is_excluded_category].copy()
    if df_plot_func.empty:
        print(f"WARNING: No data remains after excluding categories: {categories_to_exclude_plot}. Using all data.")
        df_plot_func = df_merged_diversity.copy()

    # Turns a metric column name into a filename fragment:
    # space -> '_', parentheses removed, '%' -> 'pct'.
    filename_cleanup = str.maketrans({' ': '_', '(': None, ')': None, '%': 'pct'})

    for metric_col, metric_label in metrics_to_plot_vs_func.items():
        if metric_col not in df_plot_func.columns:
            print(f"Skipping plot for {metric_label}: column '{metric_col}' not found.")
            continue

        print(f"\n--- Plotting {metric_label} by Broad Functional Category ---")

        # Work on a NaN-free copy so rescaling never touches df_plot_func.
        plot_df_metric = df_plot_func.dropna(subset=[metric_col, broad_func_cat_col, group_col]).copy()
        if metric_col == apsi_col:
            plot_df_metric[metric_col] = plot_df_metric[metric_col] * 100

        if plot_df_metric.empty:
            print(f"No data to plot for {metric_label} after dropping NaNs.")
            continue

        fig = px.box(
            plot_df_metric,
            x=broad_func_cat_col,
            y=metric_col,
            color=group_col,
            labels={metric_col: metric_label, broad_func_cat_col: "Broad Functional Category"},
            color_discrete_map=group_colors,  # from Cell 1
            points=False  # "all", False, "outliers"
        )
        fig.update_layout(plotly_layout_defaults)
        # NOTE(review): 'total descending' ranks categories by the summed
        # y-values, not by the median; switch to 'median descending' if a
        # median ordering is what the figure should show.
        fig.update_xaxes(
            title_text="Broad Functional Category",
            showgrid=False,
            tickangle=45,
            categoryorder='total descending'
        )

        # Richness can be highly skewed, so it alone gets a log axis and a
        # matching title.
        y_axis_settings = dict(title_text=metric_label, showgrid=False)
        if metric_col == 'Observed_Richness':
            y_axis_settings.update(type="log", title_text=f"{metric_label} (Log Scale)")
        fig.update_yaxes(**y_axis_settings)

        fig.show()

        # Save plot: HTML always; PDF/SVG are best-effort because static
        # export can fail independently of the figure itself.
        safe_metric_name = metric_col.lower().translate(filename_cleanup)
        fig_path_html = Path(output_figure_dir) / f"figure10_{safe_metric_name}_vs_function.html"
        fig.write_html(str(fig_path_html))
        print(f"Figure 10 ({metric_label}) HTML saved to: {fig_path_html}")

        fig_path_pdf = Path(output_figure_dir) / f"figure10_{safe_metric_name}_vs_function.pdf"
        fig_path_svg = Path(output_figure_dir) / f"figure10_{safe_metric_name}_vs_function.svg"
        for format_label, static_path in (("PDF", fig_path_pdf), ("SVG", fig_path_svg)):
            try:
                fig.write_image(str(static_path))
            except Exception as e:
                print(f"{format_label} export error for {metric_label} vs Function: {e}")

    # Save the data used for these plots (categories plus unscaled metrics).
    func_diversity_summary_path = Path(output_summary_dir_phase1) / "figure10_functional_category_diversity_metrics.csv"
    try:
        df_plot_func.to_csv(func_diversity_summary_path, index=False)
        print(f"\nSummary data for functional category diversity saved to: {func_diversity_summary_path}")
    except Exception as e:
        print(f"Error saving functional category diversity summary data: {e}")
else:
    print("Cannot proceed with Figure 10 generation due to missing data or columns.")

print("\n\n--- Cell 14 (Figure 10 - OG Diversity vs. Functional Category) Complete ---")


In [None]:
# Cell 15: Figure 11 - Highlighting Orthogroups with Extreme/Interesting Diversity Profiles (Expanded & Revised Criteria)

# --- Imports & Setup Assumptions ---
import pandas as pd
import numpy as np
from pathlib import Path

# This cell assumes that pandas (pd), numpy (np), Path from pathlib,
# and all necessary variables/helper functions from Cell 1 (Setup) are available.
# Key variables from previous cells:
# - df_merged_diversity: DataFrame from Cell 13 (contains OG, APSI, Group, Shannon_Entropy, Observed_Richness, Broad_Functional_Category).
# - df_full: Main DataFrame (used to get more detailed info like % dark members).
# - output_summary_dir_phase1: Directory to save data.
# - orthogroup_col, apsi_col, group_col, broad_func_cat_col, structurally_dark_col

print("\n\n--- Generating Figure 11: Highlighting OGs with Extreme & Interesting Diversity Profiles (Expanded & Revised Criteria) (Cell 15) ---")

# --- Check if necessary DataFrames and columns are available ---
# Assume failure; error_found is cleared only once both DataFrames exist,
# are non-empty, and df_merged_diversity has every required column.
error_found = True
if 'df_merged_diversity' not in locals() or df_merged_diversity.empty:
    print("ERROR: DataFrame 'df_merged_diversity' not found or is empty. Please run Cell 13 (and Cell 14 to ensure func cat) first.")
elif 'df_full' not in locals() or df_full.empty:
    print("ERROR: DataFrame 'df_full' not found or is empty. Please run Cell 1 first.")
else:
    missing_cols = [
        col
        for col in (orthogroup_col, apsi_col, group_col, 'Shannon_Entropy', 'Observed_Richness', broad_func_cat_col)
        if col not in df_merged_diversity.columns
    ]
    if missing_cols:
        print(f"ERROR: 'df_merged_diversity' is missing required columns: {missing_cols}.")
    else:
        error_found = False

if error_found:
    print("Cannot proceed with highlighting extreme OGs due to missing data or columns.")
else:
    # --- Define Criteria for "Extreme" OGs ---
    n_percentile_strict = 5  # For single-metric extremes (e.g., top/bottom 5%)
    # A more relaxed percentile for combined profiles increases the chance of
    # finding OGs in the intersection of two conditions.
    combo_percentile_relaxed = 25  # e.g., top/bottom 25% for combined profiles

    extreme_ogs_data = []   # one summary record per highlighted OG
    og_reasons_map = {}     # OG id -> list of reasons it was highlighted

    # Row positions of each OG's members in df_full, built in a single pass so
    # that each highlighted row costs one dict lookup instead of a full
    # boolean scan of df_full.
    og_member_positions = df_full.groupby(orthogroup_col, sort=False).indices
    no_members = df_full.iloc[0:0]

    def members_of_og(og_id):
        """Return the df_full rows belonging to *og_id* (empty frame if none)."""
        positions = og_member_positions.get(og_id)
        return no_members if positions is None else df_full.iloc[positions]

    def add_extreme_og_reason(og_id, reason, df_merged_diversity_row, df_full_og_members):
        """Record *reason* for *og_id* and, on first sight, its summary record.

        Args:
            og_id: Orthogroup identifier.
            reason: Label of the criterion that flagged the OG.
            df_merged_diversity_row: The OG's row from df_merged_diversity.
            df_full_og_members: Rows of df_full belonging to this OG; used for
                the member count and the structurally-dark percentage.
        """
        # og_reasons_map and extreme_ogs_data are only ever filled together
        # here, so dict membership replaces a linear scan of the records.
        first_time_seen = og_id not in og_reasons_map
        reasons = og_reasons_map.setdefault(og_id, [])
        if reason not in reasons:
            reasons.append(reason)
        if not first_time_seen:
            return

        percent_dark = np.nan
        num_members = 0
        if not df_full_og_members.empty:
            num_members = len(df_full_og_members)
            if structurally_dark_col in df_full_og_members.columns:
                percent_dark = (df_full_og_members[structurally_dark_col].sum() / num_members) * 100

        apsi_value = df_merged_diversity_row[apsi_col]
        extreme_ogs_data.append({
            'Orthogroup': og_id,
            'Group': df_merged_diversity_row[group_col],
            'APSI (%)': apsi_value * 100 if not pd.isna(apsi_value) else np.nan,
            'Shannon_Entropy': df_merged_diversity_row['Shannon_Entropy'],
            'Observed_Richness': df_merged_diversity_row['Observed_Richness'],
            'Total_Members_in_OG_from_df_full': num_members,
            'Percent_Structurally_Dark (%)': percent_dark,
            'Broad_Functional_Category': df_merged_diversity_row[broad_func_cat_col]
        })

    def _int_or_nan(value):
        """Format a threshold as an integer; int(NaN) would raise, so pass NaN through."""
        return 'nan' if pd.isna(value) else int(value)

    # Column presence was verified before reaching this branch.
    entropy_values = df_merged_diversity['Shannon_Entropy']
    apsi_values = df_merged_diversity[apsi_col]
    richness_values = df_merged_diversity['Observed_Richness']

    # --- Thresholds for single-metric extremes (n_percentile_strict) ---
    strict_low_q = n_percentile_strict / 100
    strict_high_q = 1 - n_percentile_strict / 100
    entropy_top = entropy_values.quantile(strict_high_q)
    entropy_bottom = entropy_values.quantile(strict_low_q)
    apsi_bottom = apsi_values.quantile(strict_low_q)
    apsi_top = apsi_values.quantile(strict_high_q)
    richness_top = richness_values.quantile(strict_high_q)
    # Never let the lower richness cut-off fall below the observed minimum.
    richness_bottom = max(richness_values.quantile(strict_low_q), richness_values.min())

    # --- Masks for combined profiles (combo_percentile_relaxed) ---
    combo_low_q = combo_percentile_relaxed / 100
    combo_high_q = 1 - combo_percentile_relaxed / 100
    is_high_entropy = entropy_values >= entropy_values.quantile(combo_high_q)
    is_low_entropy = entropy_values <= entropy_values.quantile(combo_low_q)
    is_high_apsi = apsi_values >= apsi_values.quantile(combo_high_q)
    is_low_apsi = apsi_values <= apsi_values.quantile(combo_low_q)
    is_high_richness = richness_values >= richness_values.quantile(combo_high_q)

    # (console header, reason label, row mask) in reporting order.
    # The four combined reason labels are matched verbatim by Cell 16.
    highlight_criteria = [
        # 1. Highest Shannon Entropy
        (f"\n--- OGs with Highest Shannon Entropy (Top {n_percentile_strict}%, Threshold >= {entropy_top:.2f}) ---",
         f"Top {n_percentile_strict}% Shannon Entropy",
         entropy_values >= entropy_top),
        # 2. Lowest Shannon Entropy
        (f"\n--- OGs with Lowest Shannon Entropy (Bottom {n_percentile_strict}%, Threshold <= {entropy_bottom:.2f}) ---",
         f"Bottom {n_percentile_strict}% Shannon Entropy",
         entropy_values <= entropy_bottom),
        # 3. Lowest APSI (Most Divergent)
        (f"\n--- OGs with Lowest APSI (Bottom {n_percentile_strict}%, Threshold <= {apsi_bottom*100:.1f}%) ---",
         f"Bottom {n_percentile_strict}% APSI",
         apsi_values <= apsi_bottom),
        # 4. Highest APSI (Most Conserved)
        (f"\n--- OGs with Highest APSI (Top {n_percentile_strict}%, Threshold >= {apsi_top*100:.1f}%) ---",
         f"Top {n_percentile_strict}% APSI",
         apsi_values >= apsi_top),
        # 5. Highest Observed Richness
        (f"\n--- OGs with Highest Observed Richness (Top {n_percentile_strict}%, Threshold >= {_int_or_nan(richness_top)}) ---",
         f"Top {n_percentile_strict}% Observed Richness",
         richness_values >= richness_top),
        # 6. Lowest Observed Richness
        (f"\n--- OGs with Lowest Observed Richness (Bottom {n_percentile_strict}%, Threshold <= {_int_or_nan(richness_bottom)}) ---",
         f"Bottom {n_percentile_strict}% Observed Richness",
         richness_values <= richness_bottom),
        # 7. High Entropy AND Low APSI ("Classic" Diversification)
        (f"\n--- OGs with High Entropy (Top {combo_percentile_relaxed}%) AND Low APSI (Bottom {combo_percentile_relaxed}%) ---",
         "High Entropy & Low APSI",
         is_high_entropy & is_low_apsi),
        # 8. (Surprising) High Entropy AND High APSI
        (f"\n--- OGs with High Entropy (Top {combo_percentile_relaxed}%) AND High APSI (Top {combo_percentile_relaxed}%) ---",
         "High Entropy & High APSI",
         is_high_entropy & is_high_apsi),
        # 9. (Surprising) High Richness AND High APSI
        (f"\n--- OGs with High Richness (Top {combo_percentile_relaxed}%) AND High APSI (Top {combo_percentile_relaxed}%) ---",
         "High Richness & High APSI",
         is_high_richness & is_high_apsi),
        # 10. Low Entropy AND Low APSI
        (f"\n--- OGs with Low Entropy (Bottom {combo_percentile_relaxed}%) AND Low APSI (Bottom {combo_percentile_relaxed}%) ---",
         "Low Entropy & Low APSI",
         is_low_entropy & is_low_apsi),
    ]

    for header, reason, row_mask in highlight_criteria:
        print(header)
        for _, og_row in df_merged_diversity[row_mask].iterrows():
            og_id = og_row[orthogroup_col]
            add_extreme_og_reason(og_id, reason, og_row, members_of_og(og_id))

    # --- Finalize and Display Table of Highlighted OGs ---
    if extreme_ogs_data:
        df_extreme_ogs_summary = pd.DataFrame(extreme_ogs_data)
        # Reasons are already de-duplicated per OG; sort for a stable label.
        df_extreme_ogs_summary['Reason_for_Highlight'] = df_extreme_ogs_summary['Orthogroup'].map(
            lambda og_id: "; ".join(sorted(og_reasons_map.get(og_id, [])))
        )

        df_extreme_ogs_summary = df_extreme_ogs_summary.sort_values('Orthogroup').reset_index(drop=True)

        print("\n\n--- Summary Table: Highlighted Orthogroups with Extreme/Interesting Diversity Profiles ---")
        print(f"Total unique OGs highlighted: {len(df_extreme_ogs_summary)}")
        display_cols = ['Orthogroup', 'Group', 'Reason_for_Highlight', 'APSI (%)', 'Shannon_Entropy',
                        'Observed_Richness', 'Total_Members_in_OG_from_df_full',
                        'Percent_Structurally_Dark (%)', 'Broad_Functional_Category']
        display_cols = [col for col in display_cols if col in df_extreme_ogs_summary.columns]
        try:
            # to_markdown needs the optional 'tabulate' package.
            print(df_extreme_ogs_summary[display_cols].to_markdown(index=False, floatfmt=".2f"))
        except ImportError:
            print(df_extreme_ogs_summary[display_cols])

        extreme_ogs_summary_path_revised = Path(output_summary_dir_phase1) / "figure11_extreme_diversity_ogs_summary_expanded_revised.csv"
        try:
            df_extreme_ogs_summary.to_csv(extreme_ogs_summary_path_revised, index=False, float_format='%.3f')
            print(f"\nSummary table of extreme OGs (revised criteria) saved to: {extreme_ogs_summary_path_revised}")
        except Exception as e:
            print(f"Error saving revised extreme OGs summary table: {e}")
    else:
        print("\nNo OGs met the defined extreme/interesting criteria with the revised percentiles.")

print("\n\n--- Cell 15 (Figure 11 - Highlighting Extreme & Interesting OGs - Revised Criteria) Complete ---")



In [None]:
# Cell 16: Figure 12 - Functional Breakdown of OGs in Extreme/Interesting Diversity Profiles

# --- Imports & Setup Assumptions ---
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from pathlib import Path

# This cell assumes that pandas (pd), numpy (np), plotly.express (px),
# plotly.graph_objects (go), Path from pathlib, make_subplots,
# and all necessary variables/helper functions from Cell 1 (Setup) are available.
# Key variables from previous cells:
# - output_summary_dir_phase1: Directory where Cell 15 saved its summary.
# - output_figure_dir: Directory to save plots.
# - plotly_layout_defaults: Default layout for Plotly figures.
# - broad_category_color_map: Color map for functional categories (from Cell 1).
# - orthogroup_col, broad_func_cat_col

print("\n\n--- Generating Figure 12: Functional Breakdown of Extreme Diversity OG Profiles (Cell 16) ---")

# --- Configuration ---
# Path to the summary file created by Cell 15
extreme_ogs_summary_path = Path(output_summary_dir_phase1) / "figure11_extreme_diversity_ogs_summary_expanded_revised.csv"

# Define the "Reason_for_Highlight" profiles we want to analyze
# These must exactly match the strings used in Cell 15 for combined categories
target_profiles = [
    "High Entropy & Low APSI",       # Classic Diversification
    "High Entropy & High APSI",      # Phylogenetically Diverse, Sequence-Conserved
    "High Richness & High APSI",     # Large OG, Sequence-Conserved
    "Low Entropy & Low APSI"       # Phylogenetically Constrained, Sequence-Diverged
]
# Shorter names for plot titles
profile_short_names = {
    "High Entropy & Low APSI": "High Entropy, Low APSI",
    "High Entropy & High APSI": "High Entropy, High APSI",
    "High Richness & High APSI": "High Richness, High APSI",
    "Low Entropy & Low APSI": "Low Entropy, Low APSI"
}


# --- Load Data from Cell 15 ---
# df_extreme_ogs_summary stays an empty DataFrame unless the file loads AND
# carries every column the plotting section reads; an empty frame is the
# signal for that section to skip itself.
df_extreme_ogs_summary = pd.DataFrame()
if not extreme_ogs_summary_path.is_file():
    print(f"ERROR: Extreme OGs summary file not found at '{extreme_ogs_summary_path}'. Please run Cell 15 first.")
else:
    try:
        loaded_extreme_summary = pd.read_csv(extreme_ogs_summary_path)
        print(f"Successfully loaded extreme OGs summary. Shape: {loaded_extreme_summary.shape}")
        # 'Reason_for_Highlight' is filtered on per profile and
        # broad_func_cat_col is counted per profile, so both are required.
        missing_summary_cols = [
            col
            for col in (broad_func_cat_col, 'Reason_for_Highlight')
            if col not in loaded_extreme_summary.columns
        ]
        for col in missing_summary_cols:
            print(f"ERROR: Column '{col}' not found in the loaded summary file.")
        if not missing_summary_cols:
            df_extreme_ogs_summary = loaded_extreme_summary
    except Exception as e:
        print(f"Error loading extreme OGs summary file: {e}")

# --- Generate Plots ---
# One bar subplot per target profile showing which share of that profile's
# OGs falls into each broad functional category.
if df_extreme_ogs_summary.empty:
    print("Skipping Figure 12 generation as the extreme OGs summary data is not available or empty.")
else:
    num_profiles = len(target_profiles)

    if num_profiles <= 0:
        print("No target profiles defined. Skipping plot generation.")
    else:
        # Two columns once there are more than two profiles (2x2 for four),
        # otherwise a single column; rows are the ceiling of profiles / columns.
        n_cols = 2 if num_profiles > 2 else 1
        n_rows = (num_profiles + n_cols - 1) // n_cols

        fig = make_subplots(
            rows=n_rows,
            cols=n_cols,
            subplot_titles=[profile_short_names.get(p, p) for p in target_profiles],
            vertical_spacing=0.15,
            horizontal_spacing=0.1
        )

        plot_data_for_summary_table = []  # rows for the exported summary CSV

        # Axis frame styling shared by the x and y axes of every subplot.
        axis_frame_style = dict(
            showgrid=False, zeroline=False, linecolor='black', linewidth=1.5,
            ticks="outside", ticklen=5, tickwidth=1.5, tickcolor='black'
        )

        for profile_index, profile_reason in enumerate(target_profiles):
            # Grid slot is fixed by the profile's position (row-major), so a
            # skipped profile leaves its slot empty instead of shifting others.
            zero_based_row, zero_based_col = divmod(profile_index, n_cols)
            current_row = zero_based_row + 1
            current_col = zero_based_col + 1
            profile_label = profile_short_names.get(profile_reason, profile_reason)

            # An OG may carry several reasons joined by "; ", hence substring match.
            has_profile = df_extreme_ogs_summary['Reason_for_Highlight'].str.contains(profile_reason, regex=False, na=False)
            df_profile_subset = df_extreme_ogs_summary[has_profile].copy()

            if df_profile_subset.empty:
                print(f"No OGs found for profile: '{profile_reason}'. Skipping its subplot.")
                continue

            # Share of each functional category within this profile, largest first.
            func_cat_counts = df_profile_subset[broad_func_cat_col].value_counts(normalize=True).mul(100).reset_index()
            func_cat_counts.columns = [broad_func_cat_col, 'Percentage']
            func_cat_counts = func_cat_counts.sort_values('Percentage', ascending=False)

            plot_data_for_summary_table.extend(
                {'Profile': profile_label, broad_func_cat_col: category, 'Percentage': percentage}
                for category, percentage in zip(func_cat_counts[broad_func_cat_col], func_cat_counts['Percentage'])
            )

            bar_colors = [broad_category_color_map.get(cat, '#cccccc') for cat in func_cat_counts[broad_func_cat_col]]
            bar_trace = go.Bar(
                x=func_cat_counts[broad_func_cat_col],
                y=func_cat_counts['Percentage'],
                name=profile_label,
                marker_color=bar_colors,
                text=func_cat_counts['Percentage'].map("{:.1f}%".format),
                textposition='none'  # labels are kept for hover only
            )
            fig.add_trace(bar_trace, row=current_row, col=current_col)

            # y-range reaches at least 50%, or 10% above the tallest bar.
            y_axis_top = max(50, func_cat_counts['Percentage'].max() * 1.1 if not func_cat_counts.empty else 10)
            fig.update_xaxes(title_text="", tickangle=45, row=current_row, col=current_col,
                             categoryorder='total descending', **axis_frame_style)
            fig.update_yaxes(title_text="Percentage of OGs", range=[0, y_axis_top],
                             row=current_row, col=current_col, **axis_frame_style)

        fig.update_layout(
            height=max(400 * n_rows, 600),
            showlegend=False,  # colors encode category; a per-bar legend adds nothing
            plot_bgcolor='rgba(0,0,0,0)',
            paper_bgcolor='rgba(0,0,0,0)',
            margin=dict(l=60, r=40, t=100, b=150)
        )
        if 'plotly_layout_defaults' in locals():
            fig.update_layout(font=plotly_layout_defaults.font)  # apply global font
            # Subplot titles are layout annotations; restyle each one.
            for title_annotation in fig.layout.annotations:
                if title_annotation:
                    title_annotation.font.size = 14
                    if plotly_layout_defaults.font:
                        title_annotation.font.family = plotly_layout_defaults.font.family

        fig.show()

        # --- Save Figure and Summary Data ---
        fig_path_html = Path(output_figure_dir) / "figure12_extreme_profile_functions.html"
        fig.write_html(str(fig_path_html))
        print(f"Figure 12 (Functional Breakdown of Extreme Profiles) HTML saved to: {fig_path_html}")

        # Static exports are best-effort: a failure is reported, not raised.
        fig_path_pdf = Path(output_figure_dir) / "figure12_extreme_profile_functions.pdf"
        fig_path_svg = Path(output_figure_dir) / "figure12_extreme_profile_functions.svg"
        for format_label, static_path in (("PDF", fig_path_pdf), ("SVG", fig_path_svg)):
            try:
                fig.write_image(str(static_path), width=max(800, 400*n_cols), height=max(400*n_rows, 600))
                print(f"Figure 12 {format_label} saved to: {static_path}")
            except Exception as e:
                print(f"Warning: Could not export Figure 12 to {format_label}. Error: {e}")

        # Save the aggregated data used for the plots
        if not plot_data_for_summary_table:
            print("No data was generated for the plot summary table.")
        else:
            df_plot_summary_table = pd.DataFrame(plot_data_for_summary_table)
            summary_table_path = Path(output_summary_dir_phase1) / "figure12_extreme_profile_functions_summary.csv"
            try:
                df_plot_summary_table.to_csv(summary_table_path, index=False, float_format='%.2f')
                print(f"Summary data for Figure 12 plots saved to: {summary_table_path}")
            except Exception as e:
                print(f"Error saving summary data for Figure 12: {e}")

print("\n\n--- Cell 16 (Figure 12 - Functional Breakdown of Extreme Profiles) Complete ---")



In [None]:
# Cell 17: Figure 9 (Part 2) - Correlation Analysis of Diversity Metrics

# --- Imports & Setup Assumptions ---
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from pathlib import Path
from scipy.stats import pearsonr, spearmanr

# This cell assumes that pandas (pd), numpy (np), plotly.express (px),
# plotly.graph_objects (go), Path from pathlib, pearsonr, spearmanr
# and all necessary variables/helper functions from Cell 1 (Setup) are available.
# Key variables from previous cells:
# - df_merged_diversity: DataFrame from Cell 13 (contains OG, APSI, Group, Shannon_Entropy, Observed_Richness).
# - output_figure_dir, output_summary_dir_phase1: Directories to save plots/data.
# - plotly_layout_defaults: Default layout for Plotly figures.
# - group_colors: Color map for Asgard/GV.
# - orthogroup_col, apsi_col ('Intra_OG_APSI'), group_col

print("\n\n--- Generating Figure 9 (Part 2): Correlation Analysis of Diversity Metrics (Cell 17) ---")

# --- Validate inputs: the merged diversity table must exist, be non-empty, and carry every metric column ---
error_found = False
if 'df_merged_diversity' not in locals() or df_merged_diversity.empty:
    error_found = True
    print("ERROR: DataFrame 'df_merged_diversity' not found or is empty. Please run Cell 13 first.")
else:
    # Column-name variables are only resolved once the DataFrame is known to exist.
    required_corr_cols = [apsi_col, 'Shannon_Entropy', 'Observed_Richness', group_col]
    missing_cols = [col for col in required_corr_cols if col not in df_merged_diversity.columns]
    if missing_cols:
        error_found = True
        print(f"ERROR: 'df_merged_diversity' is missing required metrics/group columns: {missing_cols}.")

if error_found:
    print("Cannot proceed with correlation analysis due to missing data or columns.")
    df_correlations_summary = pd.DataFrame() # Ensure it's defined
else:
    # --- Calculate Correlation Coefficients ---
    # Work on a private copy so the DataFrame shared with other cells is not modified here.
    df_merged_diversity_corr = df_merged_diversity.copy()
    # When a 0-100 'APSI_Percent' column is present, rebuild the 0-1 APSI column from it
    # so the calculations below always see APSI on the 0-1 scale.
    if 'APSI_Percent' in df_merged_diversity_corr.columns:
        df_merged_diversity_corr[apsi_col] = df_merged_diversity_corr['APSI_Percent'] / 100

    # Display label -> column holding the metric (APSI is computed on 0-1 but labelled as %)
    metrics_for_corr = {
        'APSI (%)': apsi_col,
        'Shannon_Entropy': 'Shannon_Entropy',
        'Observed_Richness': 'Observed_Richness'
    }

    # Every pairwise combination of the three metrics, by display label
    metric_pairs = [
        ('APSI (%)', 'Shannon_Entropy'),
        ('APSI (%)', 'Observed_Richness'),
        ('Shannon_Entropy', 'Observed_Richness')
    ]

    # Correlations are reported for the pooled data and for each group separately
    asgard_rows = df_merged_diversity_corr[group_col] == 'Asgard'
    gv_rows = df_merged_diversity_corr[group_col] == 'GV'
    datasets_to_correlate = {
        "All OGs": df_merged_diversity_corr,
        "Asgard OGs": df_merged_diversity_corr[asgard_rows],
        "GV OGs": df_merged_diversity_corr[gv_rows]
    }

    correlation_results = []
    for dataset_name, df_data in datasets_to_correlate.items():
        if df_data.empty:
            print(f"Skipping correlations for {dataset_name}: DataFrame is empty.")
            continue
        print(f"\nCalculating correlations for: {dataset_name} ({len(df_data)} OGs)")

        for m1_label, m2_label in metric_pairs:
            m1_col, m2_col = metrics_for_corr[m1_label], metrics_for_corr[m2_label]

            if any(col not in df_data.columns for col in (m1_col, m2_col)):
                print(f"  Skipping {m1_label} vs {m2_label}: one or both columns missing.")
                continue

            # Pairwise-complete observations only: drop rows with a NaN in either metric
            df_pair_data = df_data[[m1_col, m2_col]].dropna()
            n_pair_obs = len(df_pair_data)

            # Coefficients stay NaN unless there are at least 3 usable points
            pearson_coef = pearson_p = np.nan
            spearman_coef = spearman_p = np.nan
            if n_pair_obs >= 3:
                pearson_coef, pearson_p = pearsonr(df_pair_data[m1_col], df_pair_data[m2_col])
                spearman_coef, spearman_p = spearmanr(df_pair_data[m1_col], df_pair_data[m2_col])
            else:
                print(f"  Skipping {m1_label} vs {m2_label}: Insufficient data points after dropping NaNs ({n_pair_obs}).")

            # A row is recorded even for skipped pairs so the summary table stays complete
            correlation_results.append({
                'Dataset': dataset_name,
                'Metric 1': m1_label,
                'Metric 2': m2_label,
                'N_Obs': n_pair_obs,
                'Pearson_Correlation': pearson_coef,
                'Pearson_P_Value': pearson_p,
                'Spearman_Correlation': spearman_coef,
                'Spearman_P_Value': spearman_p
            })

    df_correlations_summary = pd.DataFrame(correlation_results)
    
    print("\n\n--- Correlation Coefficients Summary ---")
    if df_correlations_summary.empty:
        print("No correlation results to display.")
    else:
        # Prefer a markdown table; fall back to the plain DataFrame printout
        # when to_markdown raises ImportError.
        try:
            summary_text = df_correlations_summary.to_markdown(index=False, floatfmt=".3f")
        except ImportError:
            summary_text = df_correlations_summary
        print(summary_text)

        # Write the summary table to the phase-1 summary directory
        corr_summary_path = Path(output_summary_dir_phase1) / "figure9_part2_correlation_summary.csv"
        try:
            df_correlations_summary.to_csv(corr_summary_path, index=False, float_format='%.4f')
            print(f"\nCorrelation summary table saved to: {corr_summary_path}")
        except Exception as e:
            print(f"Error saving correlation summary table: {e}")

    # --- Scatter plots with trendlines (same logic as the Cell 13 / Fig 9 D & E plots) ---
    # These give a visual counterpart to the correlation table.

    # Plot 1: Shannon Entropy vs. APSI
    plot1_cols_present = all(
        col in df_merged_diversity_corr.columns for col in ('Shannon_Entropy', apsi_col)
    )
    if not plot1_cols_present:
        print("Skipping Shannon Entropy vs. APSI scatter plot: columns not found.")
    else:
        print("\n--- Plotting Shannon Entropy vs. APSI (with Trendlines) ---")
        # Percent-scaled APSI for the x-axis; stored on the working copy of the data
        df_merged_diversity_corr['APSI_Percent_Plot'] = df_merged_diversity_corr[apsi_col] * 100

        entropy_plot_data = df_merged_diversity_corr.dropna(subset=['Shannon_Entropy', 'APSI_Percent_Plot'])
        entropy_axis_labels = {
            'APSI_Percent_Plot': 'Intra-OG APSI (%)',
            'Shannon_Entropy': 'Shannon Entropy (Tree-based)',
        }
        # trendline_scope="overall" fits a single OLS line across both groups.
        # Per-group trendlines would need separate traces or facets, e.g. build a
        # px.scatter(..., trendline="ols") on one group's rows, take its last trace,
        # recolor/rename it and add it to this figure (once for Asgard, once for GV).
        fig_ent_vs_apsi_corr = px.scatter(
            entropy_plot_data,
            x='APSI_Percent_Plot',
            y='Shannon_Entropy',
            color=group_col,
            labels=entropy_axis_labels,
            color_discrete_map=group_colors,
            opacity=0.6,
            trendline="ols",
            trendline_scope="overall",
        )

        fig_ent_vs_apsi_corr.update_layout(plotly_layout_defaults)
        fig_ent_vs_apsi_corr.update_xaxes(title_text='Intra-OG APSI (%)', showgrid=False)
        fig_ent_vs_apsi_corr.update_yaxes(title_text='Shannon Entropy (Tree-based)', showgrid=False)
        fig_ent_vs_apsi_corr.show()

        # Only the interactive HTML is written here (no PDF/SVG export in this cell)
        fig_path_html = Path(output_figure_dir) / "figure9_part2_entropy_vs_apsi_scatter.html"
        fig_ent_vs_apsi_corr.write_html(str(fig_path_html))
        print(f"Shannon Entropy vs APSI scatter plot HTML saved to: {fig_path_html}")

    # Plot 2: Observed Richness vs. APSI
    if 'Observed_Richness' in df_merged_diversity_corr.columns and apsi_col in df_merged_diversity_corr.columns:
        print("\n--- Plotting Observed Richness vs. APSI (with Trendlines) ---")
        # The percent-scaled APSI column is otherwise only created inside the
        # Shannon-Entropy plot's branch; rebuild it here when absent so this plot
        # does not fail with a KeyError if that branch was skipped.
        if 'APSI_Percent_Plot' not in df_merged_diversity_corr.columns:
            df_merged_diversity_corr['APSI_Percent_Plot'] = df_merged_diversity_corr[apsi_col] * 100

        fig_rich_vs_apsi_corr = px.scatter(
            # Only rows with both plotted values present
            df_merged_diversity_corr.dropna(subset=['Observed_Richness', 'APSI_Percent_Plot']),
            x='APSI_Percent_Plot',
            y='Observed_Richness',
            color=group_col,
            labels={'APSI_Percent_Plot': 'Intra-OG APSI (%)', 'Observed_Richness': 'Observed Richness (Tree Tips)'},
            color_discrete_map=group_colors,
            opacity=0.6,
            trendline="ols",
            trendline_scope="overall"  # One OLS line across both groups
        )
        fig_rich_vs_apsi_corr.update_layout(plotly_layout_defaults)
        fig_rich_vs_apsi_corr.update_xaxes(title_text='Intra-OG APSI (%)', showgrid=False)
        fig_rich_vs_apsi_corr.update_yaxes(title_text='Observed Richness (Tree Tips)', showgrid=False) # Consider log type="log"
        fig_rich_vs_apsi_corr.show()
        # Only the interactive HTML is written here (no PDF/SVG export in this cell)
        fig_path_html = Path(output_figure_dir) / "figure9_part2_richness_vs_apsi_scatter.html"
        fig_rich_vs_apsi_corr.write_html(str(fig_path_html))
        print(f"Observed Richness vs APSI scatter plot HTML saved to: {fig_path_html}")
    else:
        print("Skipping Observed Richness vs. APSI scatter plot: columns not found.")

print("\n\n--- Cell 17 (Figure 9 Part 2 - Correlation Analysis) Complete ---")



In [None]:
# Cell 18: Figure 12 (Panel A) - Prevalence of OGs in Extreme/Interesting Diversity Profiles

# --- Imports & Setup Assumptions ---
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from pathlib import Path

# This cell assumes that pandas (pd), numpy (np), plotly.express (px),
# plotly.graph_objects (go), Path from pathlib,
# and all necessary variables/helper functions from Cell 1 (Setup) are available.
# Key variables from previous cells:
# - df_merged_diversity: DataFrame from Cell 13 (contains OG, APSI, Group, Shannon_Entropy, Observed_Richness).
# - output_summary_dir_phase1: Directory where Cell 15 saved its summary.
# - output_figure_dir: Directory to save plots.
# - plotly_layout_defaults: Default layout for Plotly figures.
# - group_colors: Color map for Asgard/GV.
# - orthogroup_col, group_col

print("\n\n--- Generating Figure 12 (Panel A): Prevalence of Extreme Diversity Profiles (Cell 18) ---")

# --- Configuration ---
extreme_ogs_summary_path = Path(output_summary_dir_phase1) / "figure11_extreme_diversity_ogs_summary_expanded_revised.csv"

# Percentile thresholds are taken from the current namespace when already defined
# (e.g. by Cell 15); otherwise fall back to the defaults below.
# These fallbacks MUST run before the dict that follows, because its keys
# interpolate n_percentile_strict -- defining the dict first raised a NameError
# whenever this cell was run without the thresholds already in scope.
if 'n_percentile_strict' not in locals(): n_percentile_strict = 5
if 'combo_percentile_relaxed' not in locals(): combo_percentile_relaxed = 25 # Matching revised Cell 15

# The 10 profile categories to count.
# Keys: the individual reason strings searched for (as substrings) inside the
#       'Reason_for_Highlight' column, which may hold several reasons joined
#       together, e.g. "Top 5% Shannon Entropy; High Entropy & Low APSI".
# Values: the shorter labels used on the plot axis.
# The last four entries are the combined (two-metric) categories.
profile_definitions_for_counting = {
    f"Top {n_percentile_strict}% Shannon Entropy": f"Top {n_percentile_strict}% Ent.",
    f"Bottom {n_percentile_strict}% Shannon Entropy": f"Bot. {n_percentile_strict}% Ent.",
    f"Bottom {n_percentile_strict}% APSI": f"Bot. {n_percentile_strict}% APSI",
    f"Top {n_percentile_strict}% APSI": f"Top {n_percentile_strict}% APSI",
    f"Top {n_percentile_strict}% Observed Richness": f"Top {n_percentile_strict}% Rich.",
    f"Bottom {n_percentile_strict}% Observed Richness": f"Bot. {n_percentile_strict}% Rich.",
    "High Entropy & Low APSI": "Hi Ent, Lo APSI",
    "High Entropy & High APSI": "Hi Ent, Hi APSI",
    "High Richness & High APSI": "Hi Rich, Hi APSI",
    "Low Entropy & Low APSI": "Lo Ent, Lo APSI"
}


# --- Load Data ---
# Start from the failure state: an empty summary is what every error path leaves behind.
df_extreme_ogs_summary = pd.DataFrame()
error_found_loading = False

if extreme_ogs_summary_path.is_file():
    try:
        loaded_extreme_summary = pd.read_csv(extreme_ogs_summary_path)
        print(f"Successfully loaded extreme OGs summary. Shape: {loaded_extreme_summary.shape}")
        required_summary_cols = (orthogroup_col, group_col, 'Reason_for_Highlight')
        if all(col in loaded_extreme_summary.columns for col in required_summary_cols):
            # Only a summary with every critical column is kept
            df_extreme_ogs_summary = loaded_extreme_summary
        else:
            print(f"ERROR: Loaded summary file is missing critical columns: '{orthogroup_col}', '{group_col}', or 'Reason_for_Highlight'.")
            df_extreme_ogs_summary = pd.DataFrame()
            error_found_loading = True
    except Exception as e:
        print(f"Error loading extreme OGs summary file: {e}")
        df_extreme_ogs_summary = pd.DataFrame()
        error_found_loading = True
else:
    print(f"ERROR: Extreme OGs summary file not found at '{extreme_ogs_summary_path}'. Please run Cell 15 (revised) first.")
    error_found_loading = True

# The merged diversity table supplies the per-group OG totals used as denominators
if 'df_merged_diversity' not in locals() or df_merged_diversity.empty:
    error_found_loading = True
    print("ERROR: DataFrame 'df_merged_diversity' (for total OG counts) not found or is empty. Please run Cell 13 first.")

if error_found_loading:
    print("Cannot proceed with Figure 12 Panel A due to missing input data.")
else:
    # --- Calculate Total OGs for Asgard and GV ---
    # Distinct orthogroups per group; these are the denominators for the percentages.
    is_asgard_row = df_merged_diversity[group_col] == 'Asgard'
    is_gv_row = df_merged_diversity[group_col] == 'GV'
    total_asgard_ogs = df_merged_diversity.loc[is_asgard_row, orthogroup_col].nunique()
    total_gv_ogs = df_merged_diversity.loc[is_gv_row, orthogroup_col].nunique()
    print(f"Total Asgard OGs in dataset: {total_asgard_ogs}")
    print(f"Total GV OGs in dataset: {total_gv_ogs}")

    # --- Calculate Percentage of OGs in Each Profile Category ---
    group_totals = {'Asgard': total_asgard_ogs, 'GV': total_gv_ogs}
    profile_prevalence_data = []

    for reason_str_to_check, short_label in profile_definitions_for_counting.items():
        # Literal substring match: an OG counts toward a category whenever its
        # 'Reason_for_Highlight' text contains that reason, even alongside other reasons.
        reason_matches = df_extreme_ogs_summary['Reason_for_Highlight'].str.contains(
            reason_str_to_check, regex=False, na=False
        )
        df_matching_profile = df_extreme_ogs_summary[reason_matches]

        # One record per group, Asgard first then GV
        for group_label, group_total in group_totals.items():
            in_group = df_matching_profile[group_col] == group_label
            og_count = df_matching_profile.loc[in_group, orthogroup_col].nunique()
            # Guard against dividing by zero when a group has no OGs at all
            og_percent = (og_count / group_total * 100) if group_total > 0 else 0
            profile_prevalence_data.append({
                'Profile_Category': short_label,
                'Group': group_label,
                'Percentage': og_percent,
                'Count': og_count,
            })

    df_profile_prevalence = pd.DataFrame(profile_prevalence_data)

    print("\n--- Prevalence of OGs in Defined Diversity Profiles ---")
    if not df_profile_prevalence.empty:
        try:
            # Display counts for verification
            df_display_counts = df_profile_prevalence.pivot(index='Profile_Category', columns='Group', values='Count').fillna(0).astype(int)
            print("Counts of OGs per profile category:")
            print(df_display_counts.to_markdown())
            df_display_perc = df_profile_prevalence.pivot(index='Profile_Category', columns='Group', values='Percentage').fillna(0)
            print("\nPercentage of OGs per profile category:")
            print(df_display_perc.to_markdown(floatfmt=".2f"))

        except ImportError:
            # to_markdown relies on the optional 'tabulate' package; fall back to the plain printout
            print(df_profile_prevalence)

        # --- Create Plot for Panel A ---
        fig_panel_a = px.bar(
            df_profile_prevalence,
            x='Profile_Category',
            y='Percentage',
            color='Group',
            barmode='group',
            labels={'Percentage': '% of Total OGs in Group', 'Profile_Category': 'Diversity Profile Category'},
            color_discrete_map=group_colors # from Cell 1
        )
        fig_panel_a.update_layout(plotly_layout_defaults)
        fig_panel_a.update_xaxes(
            title_text="Diversity Profile Category",
            showgrid=False,
            tickangle=45,
            categoryorder='array', # Ensure order matches definition
            categoryarray=list(profile_definitions_for_counting.values())
        )
        # Leave 15% headroom above the tallest bar. When every percentage is 0
        # (no OG matched any profile) the scaled maximum would give a degenerate
        # [0, 0] range, so fall back to a fixed upper bound of 10 instead.
        max_percentage = df_profile_prevalence['Percentage'].max()
        y_axis_upper = max_percentage * 1.15 if max_percentage > 0 else 10
        fig_panel_a.update_yaxes(title_text="% of Total OGs in Group", showgrid=False, range=[0, y_axis_upper])
        fig_panel_a.update_layout(legend_title_text='Group')

        fig_panel_a.show()

        # Save plot (HTML always; PDF/SVG are best-effort and only warn on failure)
        fig_path_html = Path(output_figure_dir) / "figure12_panel_a_profile_prevalence.html"
        fig_panel_a.write_html(str(fig_path_html))
        print(f"Figure 12 Panel A (Profile Prevalence) HTML saved to: {fig_path_html}")
        try:
            fig_path_pdf = Path(output_figure_dir) / "figure12_panel_a_profile_prevalence.pdf"
            fig_panel_a.write_image(str(fig_path_pdf))
            print(f"Figure 12 Panel A PDF saved to: {fig_path_pdf}")
        except Exception as e: print(f"PDF export error for Figure 12 Panel A: {e}")
        try:
            fig_path_svg = Path(output_figure_dir) / "figure12_panel_a_profile_prevalence.svg"
            fig_panel_a.write_image(str(fig_path_svg))
            print(f"Figure 12 Panel A SVG saved to: {fig_path_svg}")
        except Exception as e: print(f"SVG export error for Figure 12 Panel A: {e}")

        # Save the data used for this panel
        panel_a_data_path = Path(output_summary_dir_phase1) / "figure12_panel_a_profile_prevalence_data.csv"
        try:
            df_profile_prevalence.to_csv(panel_a_data_path, index=False, float_format='%.3f')
            print(f"Data for Figure 12 Panel A saved to: {panel_a_data_path}")
        except Exception as e:
            print(f"Error saving data for Figure 12 Panel A: {e}")

    else:
        print("No data generated for profile prevalence plot.")

print("\n\n--- Cell 18 (Figure 12 Panel A - Prevalence of Extreme Profiles) Complete ---")
print("Next: Cell 19 for Panels B, C, D (Functional Breakdown of Selected Profiles)")



In [None]:
# Cell 19: Figure 12 (Panels B, C, D) - Functional Breakdown of Selected Profiles & Correlation Plot

# --- Imports & Setup Assumptions ---
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from pathlib import Path
from scipy.stats import spearmanr # For correlation

# This cell assumes that pandas (pd), numpy (np), plotly.express (px),
# plotly.graph_objects (go), Path from pathlib, make_subplots, spearmanr
# and all necessary variables/helper functions from Cell 1 (Setup) are available.
# Key variables from previous cells:
# - df_merged_diversity: DataFrame from Cell 13 (OG, APSI, Group, Shannon_Entropy, Observed_Richness, Broad_Functional_Category).
# - output_summary_dir_phase1: Directory where Cell 15 saved its summary.
# - output_figure_dir: Directory to save plots.
# - plotly_layout_defaults: Default layout for Plotly figures.
# - group_colors, broad_category_color_map: Color maps.
# - orthogroup_col, group_col, apsi_col, broad_func_cat_col

print("\n\n--- Generating Figure 12 (Panels B, C, D): Functional Breakdown & Correlation (Cell 19) ---")

# --- Configuration & Definitions ---
extreme_ogs_summary_path = Path(output_summary_dir_phase1) / "figure11_extreme_diversity_ogs_summary_expanded_revised.csv"

# Thresholds are taken from the current namespace when already defined
# (e.g. by Cell 15); otherwise these defaults are used.
if 'n_percentile_strict' not in locals():
    n_percentile_strict = 5
if 'combo_percentile_relaxed' not in locals():
    combo_percentile_relaxed = 25

# Profiles whose functional composition is shown in Panels B and C.
# The keys must match, character for character, the reason strings that
# Cell 15 writes into "Reason_for_Highlight".
profile_for_panel_b_key = "High Entropy & High APSI"
profile_for_panel_c_key = "High Richness & High APSI"

# Short display labels for the panel titles, defined locally so this cell
# does not depend on the label dict built in Cell 18.
profile_labels = {
    profile_for_panel_b_key: "Hi Ent, Hi APSI",
    profile_for_panel_c_key: "Hi Rich, Hi APSI",
}
# An unmapped key falls back to the key itself as its label
profile_for_panel_b_label, profile_for_panel_c_label = (
    profile_labels.get(profile_key, profile_key)
    for profile_key in (profile_for_panel_b_key, profile_for_panel_c_key)
)


# --- Load Data ---
# Both frames default to empty so later code can test them even after a failed load.
df_extreme_ogs_summary = pd.DataFrame()
df_merged_diversity_fig12_panels_bcd = pd.DataFrame()
error_found_loading = False

# Extreme-OG summary CSV: source for Panels B and C
if extreme_ogs_summary_path.is_file():
    try:
        df_extreme_ogs_summary = pd.read_csv(extreme_ogs_summary_path)
        required_summary_cols = [orthogroup_col, group_col, 'Reason_for_Highlight', broad_func_cat_col]
        if any(k not in df_extreme_ogs_summary.columns for k in required_summary_cols):
            print("ERROR: Loaded extreme OGs summary file is missing critical columns.")
            error_found_loading = True
    except Exception as e:
        print(f"Error loading extreme OGs summary file: {e}")
        error_found_loading = True
else:
    print(f"ERROR: Extreme OGs summary file not found at '{extreme_ogs_summary_path}'. Please run Cell 15 (revised) first.")
    error_found_loading = True

# Merged diversity table: source for Panel D (copied so this cell never edits the shared frame)
if 'df_merged_diversity' in locals() and not df_merged_diversity.empty:
    df_merged_diversity_fig12_panels_bcd = df_merged_diversity.copy()
    required_panel_d_cols = [orthogroup_col, group_col, apsi_col, 'Shannon_Entropy']
    if any(k not in df_merged_diversity_fig12_panels_bcd.columns for k in required_panel_d_cols):
        print("ERROR: 'df_merged_diversity' is missing critical columns for Panel D.")
        error_found_loading = True
else:
    print("ERROR: DataFrame 'df_merged_diversity' not found or is empty. Please run Cell 13 first.")
    error_found_loading = True

if error_found_loading:
    print("Cannot proceed with Figure 12 Panels B,C,D due to missing input data.")
else:
    print(f"Successfully loaded data. df_extreme_ogs_summary: {df_extreme_ogs_summary.shape}, df_merged_diversity: {df_merged_diversity_fig12_panels_bcd.shape}")

    # --- Build the 3-row, 1-column figure: Panels B, C and D stacked vertically ---
    panel_titles = (
        f"<b>B.</b> Functional Categories of '{profile_for_panel_b_label}' OGs",
        f"<b>C.</b> Functional Categories of '{profile_for_panel_c_label}' OGs",
        "<b>D.</b> Shannon Entropy vs. Intra-OG APSI (with Correlations)",
    )
    fig = make_subplots(
        rows=3,
        cols=1,
        subplot_titles=panel_titles,
        vertical_spacing=0.15,           # Extra room between the stacked panels
        row_heights=[0.33, 0.33, 0.34],  # Relative heights of the three rows
    )

    # Collects the per-group percentages behind Panels B and C for the CSV export
    panel_b_c_plot_data_summary = []

    # --- Panel B: Functional Composition of "High Entropy & High APSI" Profile ---
    # Select OGs whose highlight reason contains the Panel B profile string (literal match)
    df_panel_b_subset = df_extreme_ogs_summary[
        df_extreme_ogs_summary['Reason_for_Highlight'].str.contains(profile_for_panel_b_key, regex=False, na=False)
    ].copy()

    if not df_panel_b_subset.empty and broad_func_cat_col in df_panel_b_subset.columns:
        # Per-group share (%) of each functional category; only used for the exported summary table
        func_cat_counts_b = df_panel_b_subset.groupby(group_col)[broad_func_cat_col].value_counts(normalize=True).mul(100).rename('Percentage').reset_index()
        # Sort on the actual grouping column (group_col) rather than a hard-coded
        # 'Group', which raised a KeyError whenever group_col had a different name.
        func_cat_counts_b = func_cat_counts_b.sort_values([group_col, 'Percentage'], ascending=[True, False])

        panel_b_c_plot_data_summary.extend(
            {
                'Panel': 'B', 'Profile': profile_for_panel_b_label, 'Group': group_value,
                broad_func_cat_col: category_value, 'Percentage': percentage_value
            }
            for group_value, category_value, percentage_value in zip(
                func_cat_counts_b[group_col],
                func_cat_counts_b[broad_func_cat_col],
                func_cat_counts_b['Percentage'],
            )
        )

        # The plotted bars pool both groups: overall share (%) of each functional
        # category among all OGs in this profile, largest first.
        overall_func_cat_counts_b = df_panel_b_subset[broad_func_cat_col].value_counts(normalize=True).mul(100).reset_index()
        overall_func_cat_counts_b.columns = [broad_func_cat_col, 'Percentage']
        overall_func_cat_counts_b = overall_func_cat_counts_b.sort_values('Percentage', ascending=False)

        if not overall_func_cat_counts_b.empty:
            fig.add_trace(go.Bar(
                x=overall_func_cat_counts_b[broad_func_cat_col],
                y=overall_func_cat_counts_b['Percentage'],
                # One color per bar; categories missing from the color map get light grey
                marker_color=[broad_category_color_map.get(cat, '#cccccc') for cat in overall_func_cat_counts_b[broad_func_cat_col]],
                name=profile_for_panel_b_label # Single legend entry for the profile type
            ), row=1, col=1) # Corresponds to subplot "B"
            fig.update_xaxes(title_text="", tickangle=45, categoryorder='total descending', row=1, col=1)
            # y-axis spans at least 0-10, with 10% headroom above the tallest bar
            fig.update_yaxes(title_text="% of OGs in Profile", range=[0, max(10, overall_func_cat_counts_b['Percentage'].max() * 1.1)], row=1, col=1)
        else:
            print(f"No data to plot for Panel B ({profile_for_panel_b_label}).")
    else:
        print(f"No OGs for profile '{profile_for_panel_b_key}' or functional category column missing for Panel B.")

    # --- Panel C: Functional Composition of "High Richness & High APSI" Profile ---
    # Select OGs whose highlight reason contains the Panel C profile string (literal match)
    df_panel_c_subset = df_extreme_ogs_summary[
        df_extreme_ogs_summary['Reason_for_Highlight'].str.contains(profile_for_panel_c_key, regex=False, na=False)
    ].copy()

    if not df_panel_c_subset.empty and broad_func_cat_col in df_panel_c_subset.columns:
        # Per-group share (%) of each functional category; only used for the exported summary table
        func_cat_counts_c = df_panel_c_subset.groupby(group_col)[broad_func_cat_col].value_counts(normalize=True).mul(100).rename('Percentage').reset_index()
        # Sort on the actual grouping column (group_col) rather than a hard-coded
        # 'Group', which raised a KeyError whenever group_col had a different name.
        func_cat_counts_c = func_cat_counts_c.sort_values([group_col, 'Percentage'], ascending=[True, False])

        panel_b_c_plot_data_summary.extend(
            {
                'Panel': 'C', 'Profile': profile_for_panel_c_label, 'Group': group_value,
                broad_func_cat_col: category_value, 'Percentage': percentage_value
            }
            for group_value, category_value, percentage_value in zip(
                func_cat_counts_c[group_col],
                func_cat_counts_c[broad_func_cat_col],
                func_cat_counts_c['Percentage'],
            )
        )

        # The plotted bars pool both groups: overall share (%) of each functional
        # category among all OGs in this profile, largest first.
        overall_func_cat_counts_c = df_panel_c_subset[broad_func_cat_col].value_counts(normalize=True).mul(100).reset_index()
        overall_func_cat_counts_c.columns = [broad_func_cat_col, 'Percentage']
        overall_func_cat_counts_c = overall_func_cat_counts_c.sort_values('Percentage', ascending=False)

        if not overall_func_cat_counts_c.empty:
            fig.add_trace(go.Bar(
                x=overall_func_cat_counts_c[broad_func_cat_col],
                y=overall_func_cat_counts_c['Percentage'],
                # One color per bar; categories missing from the color map get light grey
                marker_color=[broad_category_color_map.get(cat, '#cccccc') for cat in overall_func_cat_counts_c[broad_func_cat_col]],
                name=profile_for_panel_c_label # Single legend entry for the profile type
            ), row=2, col=1) # Corresponds to subplot "C"
            fig.update_xaxes(title_text="", tickangle=45, categoryorder='total descending', row=2, col=1)
            # y-axis spans at least 0-10, with 10% headroom above the tallest bar
            fig.update_yaxes(title_text="% of OGs in Profile", range=[0, max(10, overall_func_cat_counts_c['Percentage'].max() * 1.1)], row=2, col=1)
        else:
            print(f"No data to plot for Panel C ({profile_for_panel_c_label}).")
    else:
        print(f"No OGs for profile '{profile_for_panel_c_key}' or functional category column missing for Panel C.")

    # --- Panel D: Shannon Entropy vs. APSI with Correlation Coefficients ---
    panel_d_ready = (
        not df_merged_diversity_fig12_panels_bcd.empty
        and 'Shannon_Entropy' in df_merged_diversity_fig12_panels_bcd.columns
        and apsi_col in df_merged_diversity_fig12_panels_bcd.columns
    )
    if not panel_d_ready:
        print("Skipping Panel D: df_merged_diversity is empty or missing columns.")
    else:
        # Separate copy for plotting, with APSI rescaled by 100 for the x-axis
        df_panel_d_plot = df_merged_diversity_fig12_panels_bcd.copy()
        df_panel_d_plot['APSI_Percent_Plot'] = df_panel_d_plot[apsi_col] * 100

        # Correlation labels are stacked downward from the top-left of the panel
        annotation_y_step = 0.07
        annotation_y_start = 0.95

        for group_name_panel_d, group_df_panel_d in df_panel_d_plot.groupby(group_col):
            group_color = group_colors.get(group_name_panel_d)

            # One marker trace per group, placed in the third row ("D")
            group_scatter = go.Scatter(
                x=group_df_panel_d['APSI_Percent_Plot'],
                y=group_df_panel_d['Shannon_Entropy'],
                mode='markers',
                name=group_name_panel_d,
                legendgroup=group_name_panel_d,
                marker=dict(color=group_color, opacity=0.5, size=5),
            )
            fig.add_trace(group_scatter, row=3, col=1)

            # Spearman correlation on rows where both values are present;
            # groups with 2 or fewer such rows get no annotation.
            df_corr_subset = group_df_panel_d[['APSI_Percent_Plot', 'Shannon_Entropy']].dropna()
            if len(df_corr_subset) <= 2:
                continue

            rho, p_val = spearmanr(df_corr_subset['APSI_Percent_Plot'], df_corr_subset['Shannon_Entropy'])
            annotation_text = f"{group_name_panel_d}: ρ={rho:.2f} (p={p_val:.2g})"
            fig.add_annotation(
                # Coordinates are fractions of the third subplot's plotting area
                xref="x3 domain",
                yref="y3 domain",
                x=0.02,
                y=annotation_y_start,
                text=annotation_text,
                showarrow=False,
                font=dict(size=10, color=group_color),
                align="left",
                xanchor="left",
            )
            # The next group's label goes one step lower
            annotation_y_start -= annotation_y_step

        fig.update_xaxes(title_text="Intra-OG APSI (%)", row=3, col=1)
        fig.update_yaxes(title_text="Shannon Entropy (Tree-based)", row=3, col=1)

    # --- Final Layout Updates for the Entire Figure ---
    transparent_bg = 'rgba(0,0,0,0)'
    fig.update_layout(
        height=1350,  # Tall canvas for the three stacked panels
        plot_bgcolor=transparent_bg,
        paper_bgcolor=transparent_bg,
        margin=dict(l=80, r=50, t=100, b=100),
        showlegend=True,  # Legend carries the Panel D group colors
    )

    # Reuse the notebook-wide font when the shared layout defaults provide one
    has_default_font = bool('plotly_layout_defaults' in locals() and plotly_layout_defaults.font)
    if has_default_font:
        fig.update_layout(font=plotly_layout_defaults.font)

    # Subplot titles are stored as annotations; they are recognised by their
    # leading "<b>" (a heuristic), which the correlation labels do not have.
    for layout_annotation in fig.layout.annotations:
        if not layout_annotation.text.startswith("<b>"):
            continue
        layout_annotation.font.size = 14
        if has_default_font:
            layout_annotation.font.family = plotly_layout_defaults.font.family

    # Common axis styling for all subplots: no grid/zero line, black axis line, outside ticks
    shared_axis_style = dict(
        showgrid=False, zeroline=False,
        linecolor='black', linewidth=1,
        ticks="outside", ticklen=5, tickwidth=1, tickcolor='black',
    )
    for panel_row in (1, 2, 3):
        fig.update_xaxes(row=panel_row, col=1, **shared_axis_style)
        fig.update_yaxes(row=panel_row, col=1, **shared_axis_style)

    fig.show()

    # --- Save Figure ---
    figure_output_dir = Path(output_figure_dir)

    # Interactive HTML export is not wrapped in a try: a failure here propagates
    fig_path_html = figure_output_dir / "figure12_panels_BCD.html"
    fig.write_html(str(fig_path_html))
    print(f"Figure 12 Panels B,C,D HTML saved to: {fig_path_html}")

    # Static PDF export is best-effort: a failure only prints a warning
    try:
        fig_path_pdf = figure_output_dir / "figure12_panels_BCD.pdf"
        fig.write_image(str(fig_path_pdf), width=800, height=1350)
        print(f"Figure 12 Panels B,C,D PDF saved to: {fig_path_pdf}")
    except Exception as e:
        print(f"Warning: Could not export Figure 12 BCD to PDF. Error: {e}")

    # Export the per-group percentages collected for Panels B and C (skipped when none were collected)
    if panel_b_c_plot_data_summary:
        summary_table_path = Path(output_summary_dir_phase1) / "figure12_panels_BC_func_breakdown_summary.csv"
        df_panel_b_c_summary = pd.DataFrame(panel_b_c_plot_data_summary)
        try:
            df_panel_b_c_summary.to_csv(summary_table_path, index=False, float_format='%.2f')
            print(f"Summary data for Figure 12 Panels B & C saved to: {summary_table_path}")
        except Exception as e:
            print(f"Error saving summary data for Figure 12 Panels B & C: {e}")

print("\n\n--- Cell 19 (Figure 12 Panels B, C, D) Complete ---")
print("You can now combine the plot from Cell 18 (Panel A) with this plot (Panels B, C, D) externally to create your full Figure 12.")



In [None]:
# Cell 20: Figure 12 - Consolidated Diversity Profiles, Functional Enrichment, and Correlations

# --- Imports & Setup Assumptions ---
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from pathlib import Path
from scipy.stats import spearmanr

# This cell assumes that pandas (pd), numpy (np), plotly.express (px),
# plotly.graph_objects (go), Path from pathlib, make_subplots, spearmanr
# and all necessary variables/helper functions from Cell 1 (Setup) are available.
# Key variables from previous cells:
# - df_merged_diversity: DataFrame from Cell 13 (OG, APSI, Group, Shannon_Entropy, Observed_Richness, Broad_Functional_Category).
# - output_summary_dir_phase1: Directory where Cell 15 saved its summary.
# - output_figure_dir: Directory to save plots.
# - plotly_layout_defaults: Default layout for Plotly figures.
# - group_colors, broad_category_color_map: Color maps.
# - orthogroup_col, group_col, apsi_col, broad_func_cat_col

print("\n\n--- Generating Figure 12: Consolidated Diversity, Functional Enrichment, & Correlations (Cell 20) ---")

# --- Configuration & Definitions (mirroring Cell 15 & 18 for consistency) ---
# Summary CSV of extreme-diversity OGs that the load step of this cell reads.
extreme_ogs_summary_path = Path(
    output_summary_dir_phase1,
    "figure11_extreme_diversity_ogs_summary_expanded_revised.csv",
)

# Percentile thresholds: keep values already defined in the session, otherwise use defaults.
if 'n_percentile_strict' not in locals():
    n_percentile_strict = 5
if 'combo_percentile_relaxed' not in locals():
    combo_percentile_relaxed = 25  # From revised Cell 15

# Panel A (Prevalence): search strings matched against 'Reason_for_Highlight',
# paired position-by-position with the short labels used on the plot.
profile_keys_for_panel_a = [
    f"Top {n_percentile_strict}% Shannon Entropy",
    f"Bottom {n_percentile_strict}% APSI",
    "High Entropy & Low APSI",
    "High Entropy & High APSI",
    "Low Entropy & Low APSI",
    "High Richness & High APSI",
]
profile_labels_for_panel_a = [
    f"Top {n_percentile_strict}% Ent.",
    f"Bot. {n_percentile_strict}% APSI",
    "Hi Ent, Lo APSI",
    "Hi Ent, Hi APSI",
    "Lo Ent, Lo APSI",
    "Hi Rich, Hi APSI",
]
profiles_for_panel_a = dict(zip(profile_keys_for_panel_a, profile_labels_for_panel_a))

# Panel B (Functional Enrichment Heatmap): the "surprising" profiles, sharing Panel A's labels.
# "High Richness & High APSI" can be added to the tuple if a wider heatmap is desired.
profiles_for_panel_b_heatmap = {
    profile_key: profiles_for_panel_a[profile_key]
    for profile_key in ("High Entropy & High APSI", "Low Entropy & Low APSI")
}

# --- Load Data ---
# Both inputs are validated up front. Any failure sets error_found_loading so the
# plotting section is skipped rather than failing part-way through the figure.
error_found_loading = False
df_extreme_ogs_summary = pd.DataFrame()
df_merged_diversity_fig12 = pd.DataFrame()

# Columns each input must provide for the panels built in this cell.
required_summary_cols = [orthogroup_col, group_col, 'Reason_for_Highlight', broad_func_cat_col]
required_diversity_cols = [orthogroup_col, group_col, apsi_col, 'Shannon_Entropy', broad_func_cat_col]

if not extreme_ogs_summary_path.is_file():
    print(f"ERROR: Extreme OGs summary file not found at '{extreme_ogs_summary_path}'. Please run Cell 15 (revised) first.")
    error_found_loading = True
else:
    try:
        df_extreme_ogs_summary = pd.read_csv(extreme_ogs_summary_path)
    except Exception as e:
        # Best-effort notebook cell: report the read failure and fall through to the skip path.
        print(f"Error loading extreme OGs summary file: {e}")
        df_extreme_ogs_summary = pd.DataFrame()
        error_found_loading = True
    else:
        missing_summary_cols = [col for col in required_summary_cols if col not in df_extreme_ogs_summary.columns]
        if missing_summary_cols:
            # Name the missing columns so the upstream cell can be fixed without guesswork.
            print(f"ERROR: Loaded extreme OGs summary file is missing critical columns: {missing_summary_cols}")
            df_extreme_ogs_summary = pd.DataFrame()  # Invalidate if critical cols missing
            error_found_loading = True

if 'df_merged_diversity' not in locals() or df_merged_diversity.empty:
    print("ERROR: DataFrame 'df_merged_diversity' not found or is empty. Please run Cell 13 first.")
    error_found_loading = True
else:
    # Work on a copy so this cell never mutates the DataFrame shared with other cells.
    df_merged_diversity_fig12 = df_merged_diversity.copy()
    missing_diversity_cols = [col for col in required_diversity_cols if col not in df_merged_diversity_fig12.columns]
    if missing_diversity_cols:
        print(f"ERROR: 'df_merged_diversity' is missing critical columns for Panels B/C: {missing_diversity_cols}")
        error_found_loading = True

if error_found_loading:
    print("Cannot proceed with Figure 12 generation due to missing input data.")
else:
    print(f"Data loaded. df_extreme_ogs_summary: {df_extreme_ogs_summary.shape}, df_merged_diversity_fig12: {df_merged_diversity_fig12.shape}")

    # --- Create Figure with Subplots ---
    # Three stacked, full-width panels: A (prevalence), B (heatmap), C (correlation).
    # A single-column grid keeps titles and axis numbering aligned with the rows:
    # row 1 -> x/y, row 2 -> x2/y2, row 3 -> x3/y3. The previous 3x2 grid created
    # five subplots, which put the "C." title on the empty (3, 2) cell and made
    # x3/y3 refer to the heatmap instead of Panel C.
    fig = make_subplots(
        rows=3, cols=1,
        subplot_titles=(
            "<b>A.</b> Prevalence of OGs in Key Diversity Profiles",
            "<b>B.</b> Functional Enrichment in 'Surprising' Diversity Profiles",
            "<b>C.</b> Shannon Entropy vs. Intra-OG APSI"
        ),
        vertical_spacing=0.18, # Adjusted spacing
        row_heights=[0.3, 0.4, 0.3] # Relative heights: Prevalence, Heatmap, Correlation
    )

    # --- Panel A: Prevalence of Key Diversity Profiles ---
    # Denominators: number of distinct OGs per group in the merged diversity table.
    total_asgard_ogs = df_merged_diversity_fig12[df_merged_diversity_fig12[group_col] == 'Asgard'][orthogroup_col].nunique()
    total_gv_ogs = df_merged_diversity_fig12[df_merged_diversity_fig12[group_col] == 'GV'][orthogroup_col].nunique()
    total_ogs_by_group = {'Asgard': total_asgard_ogs, 'GV': total_gv_ogs}

    panel_a_plot_data = []
    for reason_key, short_label in profiles_for_panel_a.items():
        # Literal (non-regex) substring match; rows with a missing reason never match.
        df_matching = df_extreme_ogs_summary[df_extreme_ogs_summary['Reason_for_Highlight'].str.contains(reason_key, regex=False, na=False)]
        for group_name_panel_a, group_total_ogs in total_ogs_by_group.items():
            group_profile_count = df_matching[df_matching[group_col] == group_name_panel_a][orthogroup_col].nunique()
            panel_a_plot_data.append({
                'Profile': short_label,
                'Group': group_name_panel_a,
                # Guard against a group with no OGs at all.
                'Percentage': (group_profile_count / group_total_ogs * 100) if group_total_ogs > 0 else 0
            })
    df_panel_a_plot = pd.DataFrame(panel_a_plot_data)

    if not df_panel_a_plot.empty:
        # Using plotly express for simplicity in adding grouped bars to subplot
        fig_a_temp = px.bar(df_panel_a_plot, x='Profile', y='Percentage', color='Group', barmode='group', color_discrete_map=group_colors)
        for trace in fig_a_temp.data:
            fig.add_trace(trace, row=1, col=1) # Add to the first subplot area
        fig.update_xaxes(categoryorder='array', categoryarray=profile_labels_for_panel_a, tickangle=45, row=1, col=1, showgrid=False, zeroline=False, linecolor='black', linewidth=1)

        # 15% headroom above the tallest bar. When every percentage is 0 (or NaN) the
        # scaled maximum would give a degenerate [0, 0] range, so use a fixed ceiling.
        max_percentage = df_panel_a_plot['Percentage'].max()
        y_axis_upper = max_percentage * 1.15 if max_percentage > 0 else 10
        fig.update_yaxes(title_text="% of Total OGs", range=[0, y_axis_upper], row=1, col=1, showgrid=False, zeroline=False, linecolor='black', linewidth=1)
        fig.update_layout(legend_title_text='Group')
        # Make Panel A's x-axis span the full figure width; its y-domain is left as assigned.
        fig.layout.xaxis.domain = [0.0, 1.0]

    else:
        print("No data for Panel A.")


    # --- Panel B: Functional Enrichment Heatmap ---
    # Baseline ("expected") share of each functional category, in percent, per group.
    baseline_func_dist_asgard = (
        df_merged_diversity_fig12.loc[df_merged_diversity_fig12[group_col] == 'Asgard', broad_func_cat_col]
        .value_counts(normalize=True)
        .mul(100)
    )
    baseline_func_dist_gv = (
        df_merged_diversity_fig12.loc[df_merged_diversity_fig12[group_col] == 'GV', broad_func_cat_col]
        .value_counts(normalize=True)
        .mul(100)
    )

    heatmap_data_list = []
    for profile_key, short_profile_label in profiles_for_panel_b_heatmap.items():
        # Literal substring match on the highlight reason; missing reasons never match.
        profile_rows = df_extreme_ogs_summary['Reason_for_Highlight'].str.contains(profile_key, regex=False, na=False)
        for group_val, baseline_dist in (('Asgard', baseline_func_dist_asgard), ('GV', baseline_func_dist_gv)):
            df_profile_group_subset = df_extreme_ogs_summary[profile_rows & (df_extreme_ogs_summary[group_col] == group_val)]
            if df_profile_group_subset.empty:
                continue

            # Observed share of each category (percent) within this profile/group.
            observed_dist = df_profile_group_subset[broad_func_cat_col].value_counts(normalize=True).mul(100)
            heatmap_column_label = f"{group_val} - {short_profile_label}"

            for func_cat, obs_perc in observed_dist.items():
                # Categories absent from the baseline count as 0% expected; the small
                # epsilon on both terms avoids log(0) and division by zero.
                exp_perc = baseline_dist.get(func_cat, 0)
                log2_fold_change = np.log2((obs_perc + 1e-9) / (exp_perc + 1e-9))
                heatmap_data_list.append({
                    'Functional_Category': func_cat,
                    'Profile_Group': heatmap_column_label,
                    'Log2_Fold_Change': log2_fold_change
                })

    if not heatmap_data_list:
        print("No data for Panel B heatmap.")
    else:
        df_heatmap = pd.DataFrame(heatmap_data_list)
        # Categories x profile/group matrix; unobserved combinations are shown as 0,
        # and rows are ordered alphabetically for readability.
        df_heatmap_pivot = (
            df_heatmap
            .pivot(index='Functional_Category', columns='Profile_Group', values='Log2_Fold_Change')
            .fillna(0)
            .sort_index()
        )

        # Symmetric colour limits around 0 so over- and under-representation are comparable.
        abs_max = df_heatmap_pivot.abs().max().max()

        heatmap_trace = go.Heatmap(
            z=df_heatmap_pivot.values,
            x=df_heatmap_pivot.columns,
            y=df_heatmap_pivot.index,
            colorscale='RdBu', # Red-Blue diverging scale
            zmid=0, # Center color at 0
            zmin=-abs_max if abs_max > 0 else -1,
            zmax=abs_max if abs_max > 0 else 1,
            colorbar_title='Log2 (Obs/Exp %)'
        )
        fig.add_trace(heatmap_trace, row=2, col=1)
        fig.update_xaxes(tickangle=30, row=2, col=1)
        # Pin the y-axis to the sorted category order of the pivot table.
        fig.update_yaxes(categoryorder='array', categoryarray=df_heatmap_pivot.index.tolist(), row=2, col=1)


    # --- Panel C: Shannon Entropy vs. APSI with Correlation Coefficients ---
    if not df_merged_diversity_fig12.empty and 'Shannon_Entropy' in df_merged_diversity_fig12.columns and apsi_col in df_merged_diversity_fig12.columns:
        df_panel_c_plot = df_merged_diversity_fig12.copy()
        # NOTE(review): presumably APSI is stored as a fraction, hence the x100 for a percent axis - confirm.
        df_panel_c_plot['APSI_Percent_Plot'] = df_panel_c_plot[apsi_col] * 100

        # Correlation labels are stacked downwards from the top-left corner of the panel.
        annotation_y_start_c = 0.95
        annotation_y_step_c = 0.07

        for group_name_panel_c, group_df_panel_c in df_panel_c_plot.groupby(group_col):
            fig.add_trace(go.Scatter(
                x=group_df_panel_c['APSI_Percent_Plot'],
                y=group_df_panel_c['Shannon_Entropy'],
                mode='markers', name=group_name_panel_c, legendgroup=group_name_panel_c,
                marker=dict(color=group_colors.get(group_name_panel_c), opacity=0.5, size=4)
            ), row=3, col=1)

            # Spearman correlation on complete pairs only; skipped for groups with <= 2 points.
            df_corr_subset_c = group_df_panel_c[['APSI_Percent_Plot', 'Shannon_Entropy']].dropna()
            if len(df_corr_subset_c) > 2:
                rho, p_val = spearmanr(df_corr_subset_c['APSI_Percent_Plot'], df_corr_subset_c['Shannon_Entropy'])
                # Let plotly resolve the axis pair from row/col instead of hard-coding
                # "x3"/"y3": with a grid that has more than one subplot per row, the third
                # axis pair is not the row-3 panel and the labels would land on the heatmap.
                fig.add_annotation(
                    xref="x domain", yref="y domain", x=0.02, y=annotation_y_start_c,
                    text=f"{group_name_panel_c}: ρ={rho:.2f} (p={p_val:.2g})", showarrow=False,
                    font=dict(size=10, color=group_colors.get(group_name_panel_c)), align="left", xanchor="left",
                    row=3, col=1
                )
                annotation_y_start_c -= annotation_y_step_c

        # domain=[0, 1] makes Panel C span the full figure width; addressing the axis by
        # row/col guarantees the stretch hits Panel C's own x-axis. The y-domain is untouched.
        fig.update_xaxes(title_text="Intra-OG APSI (%)", domain=[0.0, 1.0], row=3, col=1, showgrid=False, zeroline=False, linecolor='black', linewidth=1)
        fig.update_yaxes(title_text="Shannon Entropy", row=3, col=1, showgrid=False, zeroline=False, linecolor='black', linewidth=1)

    else:
        print("Skipping Panel C: df_merged_diversity is empty or missing columns.")

    # --- Final Layout Updates for the Entire Figure ---
    fig.update_layout(
        height=1500, # Adjusted height for 3 rows
        plot_bgcolor='rgba(0,0,0,0)', paper_bgcolor='rgba(0,0,0,0)',
        margin=dict(l=100, r=50, t=120, b=100), # Adjusted margins
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1) # Position legend
    )
    # plotly_layout_defaults is optional; every use below is guarded on its presence.
    has_layout_defaults = 'plotly_layout_defaults' in locals()
    if has_layout_defaults and plotly_layout_defaults.font:
        fig.update_layout(font=plotly_layout_defaults.font)

    # Apply subplot title font styling. make_subplots adds the titles as annotations;
    # they are recognised here by their bold "<b>" prefix, which the correlation
    # labels of Panel C do not have.
    for subplot_annotation in fig.layout.annotations:
        if subplot_annotation.text.startswith("<b>"):
            subplot_annotation.font.size = 16 # Larger subplot titles
            if has_layout_defaults and plotly_layout_defaults.font:
                subplot_annotation.font.family = plotly_layout_defaults.font.family

    # Apply axis font styling from plotly_layout_defaults to every x- and y-axis in the
    # figure. Iterating the figure's axes (rather than the names xaxis..xaxis3) keeps
    # this correct however many subplots the grid defines.
    if has_layout_defaults:
        for axis_name_template, select_axes in (('xaxis', fig.select_xaxes), ('yaxis', fig.select_yaxes)):
            if not hasattr(plotly_layout_defaults, axis_name_template):
                continue
            default_axis_style = getattr(plotly_layout_defaults, axis_name_template)
            for current_axis in select_axes():
                if default_axis_style.title and default_axis_style.title.font:
                    current_axis.title.font.size = default_axis_style.title.font.size
                    current_axis.title.font.family = default_axis_style.title.font.family
                    current_axis.title.font.weight = default_axis_style.title.font.weight
                if default_axis_style.tickfont:
                    current_axis.tickfont.size = default_axis_style.tickfont.size
                    current_axis.tickfont.family = default_axis_style.tickfont.family


    fig.show()

    # --- Save Figure ---
    fig_path_html = Path(output_figure_dir) / "figure12_consolidated_diversity_final.html"
    fig.write_html(str(fig_path_html))
    print(f"Figure 12 (Consolidated Diversity Final) HTML saved to: {fig_path_html}")
    try:
        fig_path_pdf = Path(output_figure_dir) / "figure12_consolidated_diversity_final.pdf"
        fig.write_image(str(fig_path_pdf), width=900, height=1500) # Adjusted width
        print(f"Figure 12 PDF saved to: {fig_path_pdf}")
    except Exception as e: print(f"Warning: Could not export Figure 12 to PDF. Error: {e}")
    
    # Save summary data for heatmap
    if heatmap_data_list:
        df_heatmap_summary = pd.DataFrame(heatmap_data_list)
        heatmap_summary_path = Path(output_summary_dir_phase1) / "figure12_panel_b_heatmap_data.csv"
        try:
            df_heatmap_summary.to_csv(heatmap_summary_path, index=False, float_format='%.3f')
            print(f"Heatmap data for Figure 12 Panel B saved to: {heatmap_summary_path}")
        except Exception as e:
            print(f"Error saving heatmap data: {e}")


print("\n\n--- Cell 20 (Figure 12 - Consolidated Diversity Exploration) Complete ---")



In [None]:
# Cell 20: Figure 12 - Consolidated Diversity Profiles, Functional Enrichment, and Correlations (Revised Panel B & SVG Export)

# --- Imports & Setup Assumptions ---
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from pathlib import Path
from scipy.stats import spearmanr

# This cell assumes that pandas (pd), numpy (np), plotly.express (px),
# plotly.graph_objects (go), Path from pathlib, make_subplots, spearmanr
# and all necessary variables/helper functions from Cell 1 (Setup) are available.
# Key variables from previous cells:
# - df_merged_diversity: DataFrame from Cell 13 (OG, APSI, Group, Shannon_Entropy, Observed_Richness, Broad_Functional_Category).
# - output_summary_dir_phase1: Directory where Cell 15 saved its summary.
# - output_figure_dir: Directory to save plots.
# - plotly_layout_defaults: Default layout for Plotly figures.
# - group_colors, broad_category_color_map: Color maps.
# - orthogroup_col, group_col, apsi_col, broad_func_cat_col

print("\n\n--- Generating Figure 12: Consolidated Diversity, Functional Enrichment, & Correlations (Revised Panel B & SVG Export) (Cell 20) ---")

# --- Configuration & Definitions (mirroring Cell 15 & 18 for consistency) ---
# Summary CSV of extreme-diversity OGs that the load step of this cell reads.
extreme_ogs_summary_path = Path(
    output_summary_dir_phase1,
    "figure11_extreme_diversity_ogs_summary_expanded_revised.csv",
)

# Percentile thresholds: keep values already defined in the session, otherwise use defaults.
if 'n_percentile_strict' not in locals():
    n_percentile_strict = 5
if 'combo_percentile_relaxed' not in locals():
    combo_percentile_relaxed = 25  # From revised Cell 15

# Panel A: search strings matched against 'Reason_for_Highlight', paired
# position-by-position with the short labels used on the plot.
profile_keys_for_panel_a = [
    f"Top {n_percentile_strict}% Shannon Entropy",
    f"Bottom {n_percentile_strict}% APSI",
    "High Entropy & Low APSI",
    "High Entropy & High APSI",
    "Low Entropy & Low APSI",
    "High Richness & High APSI",
]
profile_labels_for_panel_a = [
    f"Top {n_percentile_strict}% Ent.",
    f"Bot. {n_percentile_strict}% APSI",
    "Hi Ent, Lo APSI",
    "Hi Ent, Hi APSI",
    "Lo Ent, Lo APSI",
    "Hi Rich, Hi APSI",
]
profiles_for_panel_a = dict(zip(profile_keys_for_panel_a, profile_labels_for_panel_a))

# Panel B heatmap: the two "surprising" profiles, sharing Panel A's labels.
profiles_for_panel_b_heatmap = {
    profile_key: profiles_for_panel_a[profile_key]
    for profile_key in ("High Entropy & High APSI", "Low Entropy & Low APSI")
}
# Define categories to EXCLUDE from the heatmap in Panel B
categories_to_exclude_from_heatmap = [
    "Other Specific Annotation",
    "General Protein Features",
    "Unknown/Unclassified",
]


# --- Load Data ---
# Both inputs are validated up front. Any failure sets error_found_loading so the
# plotting section is skipped rather than failing part-way through the figure.
error_found_loading = False
df_extreme_ogs_summary = pd.DataFrame()
df_merged_diversity_fig12 = pd.DataFrame()

# Columns each input must provide for the panels built in this cell.
required_summary_cols = [orthogroup_col, group_col, 'Reason_for_Highlight', broad_func_cat_col]
required_diversity_cols = [orthogroup_col, group_col, apsi_col, 'Shannon_Entropy', broad_func_cat_col]

if not extreme_ogs_summary_path.is_file():
    print(f"ERROR: Extreme OGs summary file not found at '{extreme_ogs_summary_path}'. Please run Cell 15 (revised) first.")
    error_found_loading = True
else:
    try:
        df_extreme_ogs_summary = pd.read_csv(extreme_ogs_summary_path)
    except Exception as e:
        # Best-effort notebook cell: report the read failure and fall through to the skip path.
        print(f"Error loading extreme OGs summary file: {e}")
        df_extreme_ogs_summary = pd.DataFrame()
        error_found_loading = True
    else:
        missing_summary_cols = [col for col in required_summary_cols if col not in df_extreme_ogs_summary.columns]
        if missing_summary_cols:
            # Name the missing columns so the upstream cell can be fixed without guesswork.
            print(f"ERROR: Loaded extreme OGs summary file is missing critical columns: {missing_summary_cols}")
            df_extreme_ogs_summary = pd.DataFrame()
            error_found_loading = True

if 'df_merged_diversity' not in locals() or df_merged_diversity.empty:
    print("ERROR: DataFrame 'df_merged_diversity' not found or is empty. Please run Cell 13 first.")
    error_found_loading = True
else:
    # Work on a copy so this cell never mutates the DataFrame shared with other cells.
    df_merged_diversity_fig12 = df_merged_diversity.copy()
    missing_diversity_cols = [col for col in required_diversity_cols if col not in df_merged_diversity_fig12.columns]
    if missing_diversity_cols:
        print(f"ERROR: 'df_merged_diversity' is missing critical columns for Panels B/C: {missing_diversity_cols}")
        error_found_loading = True

if error_found_loading:
    print("Cannot proceed with Figure 12 generation due to missing input data.")
else:
    print(f"Data loaded. df_extreme_ogs_summary: {df_extreme_ogs_summary.shape}, df_merged_diversity_fig12: {df_merged_diversity_fig12.shape}")

    # Three vertically stacked panels in a single column: A (prevalence, row 1),
    # B (enrichment heatmap, row 2) and C (entropy vs. APSI scatter, row 3).
    fig = make_subplots(
        cols=1,
        rows=3,
        row_heights=[0.3, 0.4, 0.3],
        vertical_spacing=0.12,
        subplot_titles=[
            "<b>A.</b> Prevalence of OGs in Key Diversity Profiles",
            "<b>B.</b> Functional Enrichment in 'Surprising' Diversity Profiles",
            "<b>C.</b> Shannon Entropy vs. Intra-OG APSI",
        ],
    )

    # --- Panel A: Prevalence of Key Diversity Profiles ---
    # Denominators: number of distinct OGs per group in the merged diversity table.
    total_asgard_ogs = df_merged_diversity_fig12[df_merged_diversity_fig12[group_col] == 'Asgard'][orthogroup_col].nunique()
    total_gv_ogs = df_merged_diversity_fig12[df_merged_diversity_fig12[group_col] == 'GV'][orthogroup_col].nunique()
    total_ogs_by_group = {'Asgard': total_asgard_ogs, 'GV': total_gv_ogs}

    panel_a_plot_data = []
    for reason_key, short_label in profiles_for_panel_a.items():
        # Literal (non-regex) substring match; rows with a missing reason never match.
        df_matching = df_extreme_ogs_summary[df_extreme_ogs_summary['Reason_for_Highlight'].str.contains(reason_key, regex=False, na=False)]
        for group_name_panel_a, group_total_ogs in total_ogs_by_group.items():
            group_profile_count = df_matching[df_matching[group_col] == group_name_panel_a][orthogroup_col].nunique()
            panel_a_plot_data.append({
                'Profile': short_label,
                'Group': group_name_panel_a,
                # Guard against a group with no OGs at all.
                'Percentage': (group_profile_count / group_total_ogs * 100) if group_total_ogs > 0 else 0
            })
    df_panel_a_plot = pd.DataFrame(panel_a_plot_data)

    if not df_panel_a_plot.empty:
        fig_a_temp = px.bar(df_panel_a_plot, x='Profile', y='Percentage', color='Group', barmode='group', color_discrete_map=group_colors)
        for trace in fig_a_temp.data:
            fig.add_trace(trace, row=1, col=1)
        fig.update_xaxes(categoryorder='array', categoryarray=profile_labels_for_panel_a, tickangle=35, row=1, col=1)

        # 15% headroom above the tallest bar. When every percentage is 0 (or NaN) the
        # scaled maximum would give a degenerate [0, 0] range, so use a fixed ceiling.
        max_percentage = df_panel_a_plot['Percentage'].max()
        y_axis_upper = max_percentage * 1.15 if max_percentage > 0 else 10
        fig.update_yaxes(title_text="% of Total OGs", range=[0, y_axis_upper], row=1, col=1)
        fig.update_layout(legend_title_text='Group', legend=dict(tracegroupgap=5))
    else:
        print("No data for Panel A.")

    # --- Panel B: Functional Enrichment Heatmap (Revised) ---
    # Baseline ("expected") category shares, in percent, per group. The excluded
    # categories are dropped *before* normalising, so the shares sum to 100 over
    # the retained categories only.
    baseline_rows_kept = ~df_merged_diversity_fig12[broad_func_cat_col].isin(categories_to_exclude_from_heatmap)
    baseline_func_dist_asgard = (
        df_merged_diversity_fig12.loc[
            (df_merged_diversity_fig12[group_col] == 'Asgard') & baseline_rows_kept,
            broad_func_cat_col,
        ]
        .value_counts(normalize=True)
        .mul(100)
    )
    baseline_func_dist_gv = (
        df_merged_diversity_fig12.loc[
            (df_merged_diversity_fig12[group_col] == 'GV') & baseline_rows_kept,
            broad_func_cat_col,
        ]
        .value_counts(normalize=True)
        .mul(100)
    )

    # The same category exclusion is applied to the extreme-OG rows.
    extreme_rows_kept = ~df_extreme_ogs_summary[broad_func_cat_col].isin(categories_to_exclude_from_heatmap)

    heatmap_data_list = []
    for profile_key, short_profile_label in profiles_for_panel_b_heatmap.items():
        # Literal substring match on the highlight reason; missing reasons never match.
        profile_rows = df_extreme_ogs_summary['Reason_for_Highlight'].str.contains(profile_key, regex=False, na=False)
        for group_val, baseline_dist in (('Asgard', baseline_func_dist_asgard), ('GV', baseline_func_dist_gv)):
            df_profile_group_subset = df_extreme_ogs_summary[
                profile_rows
                & (df_extreme_ogs_summary[group_col] == group_val)
                & extreme_rows_kept
            ]
            if df_profile_group_subset.empty:
                continue

            # Observed share of each retained category (percent) within this profile/group.
            observed_dist = df_profile_group_subset[broad_func_cat_col].value_counts(normalize=True).mul(100)
            heatmap_column_label = f"{group_val} - {short_profile_label}"

            for func_cat, obs_perc in observed_dist.items():
                # Defensive re-check; the subset filter above should already guarantee this.
                if func_cat in categories_to_exclude_from_heatmap:
                    continue
                # Categories absent from the baseline count as 0% expected; the small
                # epsilon on both terms avoids log(0) and division by zero.
                exp_perc = baseline_dist.get(func_cat, 0)
                log2_fold_change = np.log2((obs_perc + 1e-9) / (exp_perc + 1e-9))
                heatmap_data_list.append({
                    'Functional_Category': func_cat,
                    'Profile_Group': heatmap_column_label,
                    'Log2_Fold_Change': log2_fold_change
                })

    if not heatmap_data_list:
        print("No data for Panel B heatmap initially (heatmap_data_list was empty).")
    else:
        df_heatmap = pd.DataFrame(heatmap_data_list)
        if df_heatmap.empty:
            print("No data for Panel B heatmap after processing (df_heatmap was empty).")
        else:
            # Categories x profile/group matrix; unobserved combinations are shown as 0.
            df_heatmap_pivot = df_heatmap.pivot(index='Functional_Category', columns='Profile_Group', values='Log2_Fold_Change').fillna(0)
            # Safety net against excluded categories slipping through, then alphabetical order.
            df_heatmap_pivot = df_heatmap_pivot[~df_heatmap_pivot.index.isin(categories_to_exclude_from_heatmap)].sort_index()

            if df_heatmap_pivot.empty:
                print("Heatmap pivot table is empty after filtering and pivoting for Panel B.")
            else:
                # Symmetric colour limits around 0; fall back to [-1, 1] for an all-zero matrix.
                abs_max_val = df_heatmap_pivot.abs().max().max()
                z_max = abs_max_val if abs_max_val > 0 else 1
                z_min = -z_max

                heatmap_colorbar = dict(
                    title=dict(text='Log2 (Obs/Exp %)', side='right'),
                    tickfont=dict(size=8), # Smaller font for colorbar ticks
                    len=0.75, y=0.47, yanchor='middle' # Length and vertical position of the colorbar
                )
                fig.add_trace(go.Heatmap(
                    z=df_heatmap_pivot.values,
                    x=df_heatmap_pivot.columns,
                    y=df_heatmap_pivot.index,
                    colorscale='RdBu',
                    zmid=0,
                    zmin=z_min,
                    zmax=z_max,
                    colorbar=heatmap_colorbar
                ), row=2, col=1)
                fig.update_xaxes(tickangle=20, row=2, col=1, tickfont=dict(size=10)) # Adjusted angle
                # Pin the y-axis to the sorted category order of the pivot table.
                fig.update_yaxes(categoryorder='array', categoryarray=df_heatmap_pivot.index.tolist(), row=2, col=1, tickfont=dict(size=10))

    # --- Panel C: Shannon Entropy vs. APSI with Correlation Coefficients ---
    panel_c_inputs_ok = (
        not df_merged_diversity_fig12.empty
        and 'Shannon_Entropy' in df_merged_diversity_fig12.columns
        and apsi_col in df_merged_diversity_fig12.columns
    )
    if not panel_c_inputs_ok:
        print("Skipping Panel C: df_merged_diversity is empty or missing columns.")
    else:
        df_panel_c_plot = df_merged_diversity_fig12.copy()
        # NOTE(review): presumably APSI is stored as a fraction, hence the x100 for a percent axis - confirm.
        df_panel_c_plot['APSI_Percent_Plot'] = df_panel_c_plot[apsi_col] * 100

        # Correlation labels are stacked downwards from the top-left corner of the panel.
        annotation_y_step_c = 0.08
        annotation_y_start_c = 0.95

        for group_name_panel_c, group_df_panel_c in df_panel_c_plot.groupby(group_col):
            group_color_panel_c = group_colors.get(group_name_panel_c)
            fig.add_trace(
                go.Scatter(
                    x=group_df_panel_c['APSI_Percent_Plot'],
                    y=group_df_panel_c['Shannon_Entropy'],
                    mode='markers',
                    name=f"{group_name_panel_c} (Panel C)",
                    legendgroup=group_name_panel_c,
                    marker=dict(color=group_color_panel_c, opacity=0.5, size=4),
                ),
                row=3, col=1,
            )

            # Spearman correlation on complete pairs only; skipped for groups with <= 2 points.
            df_corr_subset_c = group_df_panel_c[['APSI_Percent_Plot', 'Shannon_Entropy']].dropna()
            if len(df_corr_subset_c) <= 2:
                continue
            rho, p_val = spearmanr(df_corr_subset_c['APSI_Percent_Plot'], df_corr_subset_c['Shannon_Entropy'])
            # x3/y3 are the axes of row 3 in this single-column grid; "domain" refs place
            # the label in panel-relative coordinates.
            fig.add_annotation(
                xref="x3 domain", yref="y3 domain",
                x=0.02, y=annotation_y_start_c,
                text=f"{group_name_panel_c}: ρ={rho:.2f} (p={p_val:.2g})",
                showarrow=False,
                font=dict(size=10, color=group_color_panel_c),
                align="left", xanchor="left",
            )
            annotation_y_start_c -= annotation_y_step_c

        fig.update_xaxes(title_text="Intra-OG APSI (%)", row=3, col=1)
        fig.update_yaxes(title_text="Shannon Entropy", row=3, col=1)

    # --- Final Layout Updates for the Entire Figure ---
    fig.update_layout(
        height=1400,
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        margin=dict(l=120, r=50, t=100, b=120),
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
    )
    # plotly_layout_defaults is optional; every use below is guarded on its presence.
    layout_defaults_available = 'plotly_layout_defaults' in locals()
    if layout_defaults_available and plotly_layout_defaults.font:
        fig.update_layout(font=plotly_layout_defaults.font)

    # Subplot titles are annotations; they are recognised by their bold "<b>" prefix.
    for subplot_annotation in fig.layout.annotations:
        if not subplot_annotation.text.startswith("<b>"):
            continue
        subplot_annotation.font.size = 14
        if layout_defaults_available and plotly_layout_defaults.font:
            subplot_annotation.font.family = plotly_layout_defaults.font.family

    # Style the three x-axes and three y-axes (one pair per row). All of this,
    # including the line/tick styling, only runs when the defaults object exists
    # and exposes the corresponding axis template.
    if layout_defaults_available:
        for axis_name_template in ('xaxis', 'yaxis'):
            for axis_suffix in ('', '2', '3'):
                axis_ref = axis_name_template + axis_suffix
                if not (hasattr(fig.layout, axis_ref) and hasattr(plotly_layout_defaults, axis_name_template)):
                    continue
                default_axis_style = getattr(plotly_layout_defaults, axis_name_template)
                current_axis = getattr(fig.layout, axis_ref)

                default_title = default_axis_style.title
                if default_title and default_title.font:
                    # Axes without a title text get a fixed size of 10.
                    if current_axis.title.text:
                        current_axis.title.font.size = default_title.font.size
                    else:
                        current_axis.title.font.size = 10
                    current_axis.title.font.family = default_title.font.family
                    current_axis.title.font.weight = default_title.font.weight

                default_tickfont = default_axis_style.tickfont
                if default_tickfont:
                    current_axis.tickfont.size = default_tickfont.size
                    current_axis.tickfont.family = default_tickfont.family

                # No grid or zero line, black axis line, outward black ticks.
                current_axis.update(
                    showgrid=False,
                    zeroline=False,
                    linecolor='black',
                    linewidth=1,
                    ticks="outside",
                    ticklen=5,
                    tickwidth=1,
                    tickcolor='black',
                )

    fig.show()

    # --- Save Figure ---
    fig_path_html = Path(output_figure_dir) / "figure12_consolidated_diversity_final_revised.html"
    fig.write_html(str(fig_path_html))
    print(f"Figure 12 (Consolidated Diversity Final Revised) HTML saved to: {fig_path_html}")
    try:
        fig_path_pdf = Path(output_figure_dir) / "figure12_consolidated_diversity_final_revised.pdf"
        fig.write_image(str(fig_path_pdf), width=850, height=1400) 
        print(f"Figure 12 PDF saved to: {fig_path_pdf}")
    except Exception as e: print(f"Warning: Could not export Figure 12 to PDF. Error: {e}")
    # Ensure SVG export is present
    try:
        fig_path_svg = Path(output_figure_dir) / "figure12_consolidated_diversity_final_revised.svg"
        fig.write_image(str(fig_path_svg), width=850, height=1400)
        print(f"Figure 12 SVG saved to: {fig_path_svg}")
    except Exception as e: print(f"Warning: Could not export Figure 12 to SVG. Error: {e}")

    
    if heatmap_data_list:
        df_heatmap_summary = pd.DataFrame(heatmap_data_list)
        heatmap_summary_path = Path(output_summary_dir_phase1) / "figure12_panel_b_heatmap_data_revised.csv"
        try:
            df_heatmap_summary.to_csv(heatmap_summary_path, index=False, float_format='%.3f')
            print(f"Heatmap data for Figure 12 Panel B (revised) saved to: {heatmap_summary_path}")
        except Exception as e:
            print(f"Error saving revised heatmap data: {e}")

print("\n\n--- Cell 20 (Figure 12 - Consolidated Diversity Exploration - Revised) Complete ---")



In [None]:
# Cell 21: Figure 13 - Characterizing Eukaryotic Hits for Asgard and Giant Virus Proteins
#
# Builds a three-row bar figure from the database at DB_PATH_V2_4:
#   Panel A - eukaryotic organisms most frequently hit by Asgard proteins
#   Panel B - eukaryotic organisms most frequently hit by Giant Virus (GV) proteins
#   Panel C - cleaned eukaryotic protein names most frequently hit by GV proteins
# The full (untruncated) count table behind each plotted panel is written as CSV
# to output_summary_dir_phase1; the figure itself is written as HTML/PDF/SVG to
# output_figure_dir.

# --- Imports & Setup Assumptions ---
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from pathlib import Path

# This cell assumes that pandas (pd), numpy (np), plotly.express (px),
# plotly.graph_objects (go), Path from pathlib, make_subplots,
# and all necessary variables/helper functions from Cell 1 (Setup) are available.
# Key variables from Cell 1:
# - output_figure_dir, output_summary_dir_phase1: Directories to save plots/data.
# - plotly_layout_defaults: Default layout for Plotly figures.
# - arcadia_primary_palette, arcadia_secondary_palette: Color palettes.
# - group_col, protein_id_col,
# - 'Has_Euk_DIAMOND_Hit', 'Euk_Hit_Organism', 'Euk_Hit_Protein_Name' (column names)
# - clean_protein_name (helper function from Cell 1)

print("\n\n--- Generating Figure 13: Characterizing Eukaryotic Hits (Cell 21) ---")

# --- Configuration ---
DB_PATH_V2_4 = 'proteome_database_v2.4_gv_euk_hits.csv' # INPUT: Database with GV Euk Hits
TOP_N_ORGANISMS = 25  # Number of top organisms to display in Panels A & B
TOP_N_PROTEIN_NAMES = 30 # Number of top protein names for Panel C
# Case-insensitive regex for uninformative protein names excluded from Panel C.
GENERIC_NAME_PATTERN = "Hypothetical|Unknown|Uncharacterized|^nan$"


def _count_tables(series, top_n, label_col):
    """Return (top_n_counts, all_counts) DataFrames with columns [label_col, 'Count'].

    Column names are assigned explicitly because the names produced by
    ``value_counts().reset_index()`` differ between pandas versions.
    """
    counts = series.value_counts()
    tables = []
    for part in (counts.nlargest(top_n), counts):
        table = part.reset_index()
        table.columns = [label_col, 'Count']
        tables.append(table)
    return tables[0], tables[1]


# --- Load Data ---
error_found_loading = False
df_db_v2_4 = pd.DataFrame()

if not Path(DB_PATH_V2_4).is_file():
    print(f"ERROR: Database file '{DB_PATH_V2_4}' not found. Please ensure it was created correctly.")
    error_found_loading = True
else:
    try:
        df_db_v2_4 = pd.read_csv(DB_PATH_V2_4, low_memory=False)
        # Ensure key columns exist
        required_cols = [group_col, 'Has_Euk_DIAMOND_Hit', 'Euk_Hit_Organism', 'Euk_Hit_Protein_Name', protein_id_col]
        missing = [col for col in required_cols if col not in df_db_v2_4.columns]
        if missing:
            print(f"ERROR: Loaded database '{DB_PATH_V2_4}' is missing critical columns: {missing}")
            error_found_loading = True
        else:
            print(f"Successfully loaded database '{DB_PATH_V2_4}'. Shape: {df_db_v2_4.shape}")
            # Data type conversions for safety: missing flags count as "no hit",
            # missing organism / protein names become the literal 'Unknown'.
            df_db_v2_4['Has_Euk_DIAMOND_Hit'] = df_db_v2_4['Has_Euk_DIAMOND_Hit'].fillna(False).astype(bool)
            df_db_v2_4['Euk_Hit_Organism'] = df_db_v2_4['Euk_Hit_Organism'].fillna('Unknown').astype(str)
            df_db_v2_4['Euk_Hit_Protein_Name'] = df_db_v2_4['Euk_Hit_Protein_Name'].fillna('Unknown').astype(str)
    except Exception as e:
        print(f"Error loading database '{DB_PATH_V2_4}': {e}")
        error_found_loading = True

if error_found_loading:
    print("Cannot proceed with Figure 13 generation due to missing input data.")
else:
    # --- Create Figure with Subplots ---
    fig = make_subplots(
        rows=3, cols=1,
        subplot_titles=(
            "<b>A.</b> Top Eukaryotic Organisms Hit by Asgard Proteins",
            "<b>B.</b> Top Eukaryotic Organisms Hit by Giant Virus Proteins",
            "<b>C.</b> Top Eukaryotic Protein Functions Hit by Giant Virus Proteins"
        ),
        vertical_spacing=0.1 # Adjust spacing
    )

    # 'Has_Euk_DIAMOND_Hit' was cast to bool above, so it can be used directly as a mask.
    has_euk_hit = df_db_v2_4['Has_Euk_DIAMOND_Hit']

    # --- Panel A: Top Eukaryotic Organisms Hit by Asgard Proteins ---
    print("\n--- Generating Panel A: Asgard Eukaryotic Hit Organisms ---")
    df_asgard_hits_panel_a = df_db_v2_4[(df_db_v2_4[group_col] == 'Asgard') & has_euk_hit].copy()
    asgard_org_counts, asgard_org_counts_full = _count_tables(
        df_asgard_hits_panel_a['Euk_Hit_Organism'], TOP_N_ORGANISMS, 'Euk_Hit_Organism'
    )

    if df_asgard_hits_panel_a.empty:
        print("No Asgard proteins with eukaryotic hits found for Panel A.")
    elif asgard_org_counts.empty:
        print("No eukaryotic hits found for Asgard proteins to plot for Panel A.")
    else:
        fig.add_trace(go.Bar(
            x=asgard_org_counts['Euk_Hit_Organism'],
            y=asgard_org_counts['Count'],
            name='Asgard Hits',
            marker_color=arcadia_primary_palette[0]
        ), row=1, col=1)
        fig.update_xaxes(title_text="", tickangle=45, categoryorder='total descending', row=1, col=1)
        fig.update_yaxes(title_text="Number of Asgard Proteins", row=1, col=1)

        # Save summary data for Panel A (all organisms, not only the plotted top N)
        panel_a_data_path = Path(output_summary_dir_phase1) / "figure13_panel_a_asgard_euk_org_hits.csv"
        try:
            asgard_org_counts_full.to_csv(panel_a_data_path, index=False)
            print(f"Data for Panel A saved to {panel_a_data_path}")
        except Exception as e:
            print(f"Error saving Panel A data: {e}")

    # --- Panel B: Top Eukaryotic Organisms Hit by Giant Virus Proteins ---
    print("\n--- Generating Panel B: Giant Virus Eukaryotic Hit Organisms ---")
    # Make sure 'GV' is the correct value in your 'Group' column
    df_gv_hits_panel_b = df_db_v2_4[(df_db_v2_4[group_col] == 'GV') & has_euk_hit].copy()
    gv_org_counts, gv_org_counts_full = _count_tables(
        df_gv_hits_panel_b['Euk_Hit_Organism'], TOP_N_ORGANISMS, 'Euk_Hit_Organism'
    )

    if df_gv_hits_panel_b.empty:
        print("No GV proteins with eukaryotic hits found for Panel B.")
    elif gv_org_counts.empty:
        print("No eukaryotic hits found for GV proteins to plot for Panel B.")
    else:
        fig.add_trace(go.Bar(
            x=gv_org_counts['Euk_Hit_Organism'],
            y=gv_org_counts['Count'],
            name='GV Hits',
            marker_color=arcadia_primary_palette[1 % len(arcadia_primary_palette)]
        ), row=2, col=1)
        fig.update_xaxes(title_text="", tickangle=45, categoryorder='total descending', row=2, col=1)
        fig.update_yaxes(title_text="Number of GV Proteins", row=2, col=1)

        # Save summary data for Panel B (all organisms, not only the plotted top N)
        panel_b_data_path = Path(output_summary_dir_phase1) / "figure13_panel_b_gv_euk_org_hits.csv"
        try:
            gv_org_counts_full.to_csv(panel_b_data_path, index=False)
            print(f"Data for Panel B saved to {panel_b_data_path}")
        except Exception as e:
            print(f"Error saving Panel B data: {e}")

    # --- Panel C: Top Eukaryotic Protein Functions Hit by Giant Virus Proteins ---
    print("\n--- Generating Panel C: Giant Virus Eukaryotic Hit Protein Functions ---")
    if df_gv_hits_panel_b.empty: # Reuses df_gv_hits_panel_b from Panel B
        print("No GV proteins with eukaryotic hits found for Panel C analysis.")
    else:
        if 'clean_protein_name' in locals() and callable(clean_protein_name):
            df_gv_hits_panel_b['Cleaned_Euk_Hit_Protein_Name'] = df_gv_hits_panel_b['Euk_Hit_Protein_Name'].apply(clean_protein_name)
        else:
            print("WARNING: 'clean_protein_name' function not found. Using raw Euk_Hit_Protein_Name for Panel C.")
            df_gv_hits_panel_b['Cleaned_Euk_Hit_Protein_Name'] = df_gv_hits_panel_b['Euk_Hit_Protein_Name']

        cleaned_name_counts = df_gv_hits_panel_b['Cleaned_Euk_Hit_Protein_Name'].value_counts()

        # Drop generic names BEFORE taking the top N. Filtering after truncation
        # lets "Hypothetical protein"-style entries consume top-N slots and leaves
        # the panel with fewer than TOP_N_PROTEIN_NAMES informative bars.
        is_generic_name = cleaned_name_counts.index.astype(str).str.contains(
            GENERIC_NAME_PATTERN, case=False, na=False, regex=True
        )
        gv_prot_name_counts = cleaned_name_counts[~is_generic_name].nlargest(TOP_N_PROTEIN_NAMES).reset_index()
        gv_prot_name_counts.columns = ['Cleaned_Euk_Hit_Protein_Name', 'Count']

        if gv_prot_name_counts.empty:
            print("No eukaryotic protein names found for GV hits to plot for Panel C after filtering generic names.")
        else:
            fig.add_trace(go.Bar(
                x=gv_prot_name_counts['Cleaned_Euk_Hit_Protein_Name'],
                y=gv_prot_name_counts['Count'],
                name='GV Hit Protein Functions',
                marker_color=arcadia_secondary_palette[0]
            ), row=3, col=1)
            fig.update_xaxes(title_text="Eukaryotic Protein Function (Cleaned Name)", tickangle=45, categoryorder='total descending', row=3, col=1)
            fig.update_yaxes(title_text="Number of GV Protein Hits", row=3, col=1)

            # Save summary data for Panel C: the full, unfiltered list of cleaned names and counts
            panel_c_data_path = Path(output_summary_dir_phase1) / "figure13_panel_c_gv_euk_prot_func_hits.csv"
            try:
                full_gv_prot_name_counts = cleaned_name_counts.reset_index()
                full_gv_prot_name_counts.columns = ['Cleaned_Euk_Hit_Protein_Name', 'Count']
                full_gv_prot_name_counts.to_csv(panel_c_data_path, index=False)
                print(f"Data for Panel C saved to {panel_c_data_path}")
            except Exception as e:
                print(f"Error saving Panel C data: {e}")

    # --- Final Layout Updates for the Entire Figure ---
    fig.update_layout(
        height=1300, # Adjusted height for 3 panels
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(0,0,0,0)',
        margin=dict(l=80, r=50, t=100, b=200), # Increased bottom margin for angled labels
        showlegend=False # Individual traces are self-explanatory by panel title
    )

    if 'plotly_layout_defaults' in locals():
        # Apply global font and subplot title styling from plotly_layout_defaults.
        fig.update_layout(font=plotly_layout_defaults.font)
        for annotation in fig.layout.annotations:
            if annotation.text.startswith("<b>"): # Subplot titles
                annotation.font.size = 14
                if plotly_layout_defaults.font:
                    annotation.font.family = plotly_layout_defaults.font.family

        # Apply common axis styling from plotly_layout_defaults to every subplot axis.
        for axis_name_template in ('xaxis', 'yaxis'):
            if not hasattr(plotly_layout_defaults, axis_name_template):
                continue
            default_axis_style = getattr(plotly_layout_defaults, axis_name_template)

            for row_number in range(1, 4): # For 3 rows
                # Row 1 uses the unsuffixed name: xaxis, xaxis2, xaxis3 / yaxis, yaxis2, yaxis3
                axis_ref_name = f"{axis_name_template}{row_number if row_number > 1 else ''}"
                if not hasattr(fig.layout, axis_ref_name):
                    continue
                current_axis = getattr(fig.layout, axis_ref_name)

                # Title font is copied whether or not this axis currently has title text.
                if default_axis_style.title and default_axis_style.title.font:
                    default_title_font = default_axis_style.title.font
                    current_axis.title.font.size = default_title_font.size
                    current_axis.title.font.family = default_title_font.family
                    current_axis.title.font.weight = default_title_font.weight

                if default_axis_style.tickfont:
                    current_axis.tickfont.size = default_axis_style.tickfont.size
                    current_axis.tickfont.family = default_axis_style.tickfont.family

                current_axis.update(
                    showgrid=False,
                    zeroline=False,
                    linecolor='black',
                    linewidth=1.5, # Match your preferred style
                    ticks="outside",
                    ticklen=5,
                    tickwidth=1.5,
                    tickcolor='black',
                )
    fig.show()

    # --- Save Figure ---
    fig_path_html = Path(output_figure_dir) / "figure13_eukaryotic_hit_characterization.html"
    fig.write_html(str(fig_path_html))
    print(f"Figure 13 (Eukaryotic Hit Characterization) HTML saved to: {fig_path_html}")

    # Static exports are best-effort: a failure is reported as a warning and
    # does not stop the cell.
    for image_format in ("pdf", "svg"):
        format_label = image_format.upper()
        try:
            fig_path_image = Path(output_figure_dir) / f"figure13_eukaryotic_hit_characterization.{image_format}"
            fig.write_image(str(fig_path_image), width=900, height=1300)
            print(f"Figure 13 {format_label} saved to: {fig_path_image}")
        except Exception as e:
            print(f"Warning: Could not export Figure 13 to {format_label}. Error: {e}")

print("\n\n--- Cell 21 (Figure 13 - Eukaryotic Hit Characterization) Complete ---")



In [None]:
# Cell 21: Figure 13 - Characterizing Eukaryotic Hits for Asgard and Giant Virus Proteins (Consistent Colors)
#
# Same three-panel figure as the plain Figure 13 cell, except that in Panels A
# and B every eukaryotic organism is drawn with the same color in both panels,
# so organisms shared by the Asgard and GV top lists can be matched by eye.

# --- Imports & Setup Assumptions ---
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from pathlib import Path

# This cell assumes that pandas (pd), numpy (np), plotly.express (px),
# plotly.graph_objects (go), Path from pathlib, make_subplots,
# and all necessary variables/helper functions from Cell 1 (Setup) are available.
# Key variables from Cell 1:
# - output_figure_dir, output_summary_dir_phase1: Directories to save plots/data.
# - plotly_layout_defaults: Default layout for Plotly figures.
# - arcadia_primary_palette, arcadia_secondary_palette, arcadia_neutrals_palette: Color palettes
#   (arcadia_neutrals_palette is used below and is expected to come from Cell 1 as well - verify).
# - group_col, protein_id_col,
# - 'Has_Euk_DIAMOND_Hit', 'Euk_Hit_Organism', 'Euk_Hit_Protein_Name' (column names)
# - clean_protein_name (helper function from Cell 1)

print("\n\n--- Generating Figure 13: Characterizing Eukaryotic Hits (Consistent Colors) (Cell 21) ---")

# --- Configuration ---
# NOTE: the variable keeps its historical "V2_4" name but points at the v2.5 database.
DB_PATH_V2_4 = 'proteome_database_v2.5.csv' # INPUT: Database with GV Euk Hits
TOP_N_ORGANISMS = 25
TOP_N_PROTEIN_NAMES = 30
# Case-insensitive regex for uninformative protein names excluded from Panel C.
GENERIC_NAME_PATTERN = "Hypothetical|Unknown|Uncharacterized|^nan$"
UNKNOWN_ORGANISM_COLOR = '#cccccc' # A neutral gray for Unknown
UNMAPPED_ORGANISM_COLOR = '#999999' # Fallback for organisms absent from the color map


def _count_tables(series, top_n, label_col):
    """Return (top_n_counts, all_counts) DataFrames with columns [label_col, 'Count'].

    Column names are assigned explicitly because the names produced by
    ``value_counts().reset_index()`` differ between pandas versions.
    """
    counts = series.value_counts()
    tables = []
    for part in (counts.nlargest(top_n), counts):
        table = part.reset_index()
        table.columns = [label_col, 'Count']
        tables.append(table)
    return tables[0], tables[1]


# --- Load Data ---
error_found_loading = False
df_db_v2_4 = pd.DataFrame()

if not Path(DB_PATH_V2_4).is_file():
    print(f"ERROR: Database file '{DB_PATH_V2_4}' not found. Please ensure it was created correctly.")
    error_found_loading = True
else:
    try:
        df_db_v2_4 = pd.read_csv(DB_PATH_V2_4, low_memory=False)
        required_cols = [group_col, 'Has_Euk_DIAMOND_Hit', 'Euk_Hit_Organism', 'Euk_Hit_Protein_Name', protein_id_col]
        missing = [col for col in required_cols if col not in df_db_v2_4.columns]
        if missing:
            print(f"ERROR: Loaded database '{DB_PATH_V2_4}' is missing critical columns: {missing}")
            error_found_loading = True
        else:
            print(f"Successfully loaded database '{DB_PATH_V2_4}'. Shape: {df_db_v2_4.shape}")
            # Missing flags count as "no hit"; missing names become the literal 'Unknown'.
            # Organism names are stripped so whitespace variants share one bar and one color.
            df_db_v2_4['Has_Euk_DIAMOND_Hit'] = df_db_v2_4['Has_Euk_DIAMOND_Hit'].fillna(False).astype(bool)
            df_db_v2_4['Euk_Hit_Organism'] = df_db_v2_4['Euk_Hit_Organism'].fillna('Unknown').astype(str).str.strip()
            df_db_v2_4['Euk_Hit_Protein_Name'] = df_db_v2_4['Euk_Hit_Protein_Name'].fillna('Unknown').astype(str)
    except Exception as e:
        print(f"Error loading database '{DB_PATH_V2_4}': {e}")
        error_found_loading = True

if error_found_loading:
    print("Cannot proceed with Figure 13 generation due to missing input data.")
else:
    # --- Prepare Data for Panels A and B ---
    # 'Has_Euk_DIAMOND_Hit' was cast to bool above, so it can be used directly as a mask.
    # group_col is guaranteed to exist here (checked via required_cols), so both
    # groups are selected with plain column indexing.
    has_euk_hit = df_db_v2_4['Has_Euk_DIAMOND_Hit']
    df_asgard_hits_panel_a = df_db_v2_4[(df_db_v2_4[group_col] == 'Asgard') & has_euk_hit].copy()
    df_gv_hits_panel_b = df_db_v2_4[(df_db_v2_4[group_col] == 'GV') & has_euk_hit].copy()

    # Count once; the same tables feed the color map, the plots and the CSV exports.
    asgard_org_counts, asgard_org_counts_full = _count_tables(
        df_asgard_hits_panel_a['Euk_Hit_Organism'], TOP_N_ORGANISMS, 'Euk_Hit_Organism'
    )
    gv_org_counts, gv_org_counts_full = _count_tables(
        df_gv_hits_panel_b['Euk_Hit_Organism'], TOP_N_ORGANISMS, 'Euk_Hit_Organism'
    )

    # --- Create a Consistent Color Map for Eukaryotic Organisms ---
    print("\n--- Creating Consistent Color Map for Eukaryotic Organisms ---")
    # Union of both groups' top organisms; sorted so the color assignment is
    # reproducible between runs. 'Unknown' is handled separately below.
    all_top_organisms = sorted(
        (set(asgard_org_counts['Euk_Hit_Organism']) | set(gv_org_counts['Euk_Hit_Organism'])) - {'Unknown'}
    )

    # Use a combined palette for more variety if many unique organisms.
    # Colors repeat (modulo) once there are more organisms than palette entries.
    combined_palette = arcadia_primary_palette + arcadia_secondary_palette + arcadia_neutrals_palette
    # Ensure 'Unknown' gets a specific color if it appears
    euk_organism_color_map = {'Unknown': UNKNOWN_ORGANISM_COLOR}
    for palette_index, org_name in enumerate(all_top_organisms):
        euk_organism_color_map[org_name] = combined_palette[palette_index % len(combined_palette)]

    print(f"Created color map for {len(euk_organism_color_map)} unique eukaryotic organisms (including Unknown).")

    # --- Create Figure with Subplots ---
    fig = make_subplots(
        rows=3, cols=1,
        subplot_titles=(
            "<b>A.</b> Top Eukaryotic Organisms Hit by Asgard Proteins",
            "<b>B.</b> Top Eukaryotic Organisms Hit by Giant Virus Proteins",
            "<b>C.</b> Top Eukaryotic Protein Functions Hit by Giant Virus Proteins"
        ),
        vertical_spacing=0.1
    )

    # --- Panel A: Top Eukaryotic Organisms Hit by Asgard Proteins ---
    print("\n--- Generating Panel A: Asgard Eukaryotic Hit Organisms ---")
    if df_asgard_hits_panel_a.empty:
        print("No Asgard proteins with eukaryotic hits found for Panel A.")
    elif asgard_org_counts.empty:
        print("No eukaryotic hits found for Asgard proteins to plot for Panel A.")
    else:
        fig.add_trace(go.Bar(
            x=asgard_org_counts['Euk_Hit_Organism'],
            y=asgard_org_counts['Count'],
            name='Asgard Hits', # Not shown: the figure-level legend is disabled below
            marker_color=[
                euk_organism_color_map.get(org, UNMAPPED_ORGANISM_COLOR)
                for org in asgard_org_counts['Euk_Hit_Organism']
            ] # Apply consistent map
        ), row=1, col=1)
        fig.update_xaxes(title_text="", tickangle=45, categoryorder='total descending', row=1, col=1)
        fig.update_yaxes(title_text="Number of Asgard Proteins", row=1, col=1)

        # Save all organism counts, not only the plotted top N
        panel_a_data_path = Path(output_summary_dir_phase1) / "figure13_panel_a_asgard_euk_org_hits.csv"
        try:
            asgard_org_counts_full.to_csv(panel_a_data_path, index=False)
            print(f"Data for Panel A saved to {panel_a_data_path}")
        except Exception as e:
            print(f"Error saving Panel A data: {e}")

    # --- Panel B: Top Eukaryotic Organisms Hit by Giant Virus Proteins ---
    print("\n--- Generating Panel B: Giant Virus Eukaryotic Hit Organisms ---")
    if df_gv_hits_panel_b.empty:
        print("No GV proteins with eukaryotic hits found for Panel B.")
    elif gv_org_counts.empty:
        print("No eukaryotic hits found for GV proteins to plot for Panel B.")
    else:
        fig.add_trace(go.Bar(
            x=gv_org_counts['Euk_Hit_Organism'],
            y=gv_org_counts['Count'],
            name='GV Hits',
            marker_color=[
                euk_organism_color_map.get(org, UNMAPPED_ORGANISM_COLOR)
                for org in gv_org_counts['Euk_Hit_Organism']
            ] # Apply consistent map
        ), row=2, col=1)
        fig.update_xaxes(title_text="", tickangle=45, categoryorder='total descending', row=2, col=1)
        fig.update_yaxes(title_text="Number of GV Proteins", row=2, col=1)

        # Save all organism counts, not only the plotted top N
        panel_b_data_path = Path(output_summary_dir_phase1) / "figure13_panel_b_gv_euk_org_hits.csv"
        try:
            gv_org_counts_full.to_csv(panel_b_data_path, index=False)
            print(f"Data for Panel B saved to {panel_b_data_path}")
        except Exception as e:
            print(f"Error saving Panel B data: {e}")

    # --- Panel C: Top Eukaryotic Protein Functions Hit by Giant Virus Proteins ---
    print("\n--- Generating Panel C: Giant Virus Eukaryotic Hit Protein Functions ---")
    if df_gv_hits_panel_b.empty:
        print("No GV proteins with eukaryotic hits found for Panel C analysis.")
    else:
        if 'clean_protein_name' in locals() and callable(clean_protein_name):
            df_gv_hits_panel_b['Cleaned_Euk_Hit_Protein_Name'] = df_gv_hits_panel_b['Euk_Hit_Protein_Name'].apply(clean_protein_name)
        else:
            print("WARNING: 'clean_protein_name' function not found. Using raw Euk_Hit_Protein_Name for Panel C.")
            df_gv_hits_panel_b['Cleaned_Euk_Hit_Protein_Name'] = df_gv_hits_panel_b['Euk_Hit_Protein_Name']

        gv_prot_name_counts_all = df_gv_hits_panel_b['Cleaned_Euk_Hit_Protein_Name'].value_counts()
        # Filter out very generic names AFTER counting, before selecting top N for plot
        is_generic_name = gv_prot_name_counts_all.index.astype(str).str.contains(
            GENERIC_NAME_PATTERN, case=False, na=False, regex=True
        )
        gv_prot_name_counts_filtered = gv_prot_name_counts_all[~is_generic_name]
        gv_prot_name_counts_plot = gv_prot_name_counts_filtered.nlargest(TOP_N_PROTEIN_NAMES).reset_index()
        gv_prot_name_counts_plot.columns = ['Cleaned_Euk_Hit_Protein_Name', 'Count']

        if gv_prot_name_counts_plot.empty:
            print("No eukaryotic protein names found for GV hits to plot for Panel C after filtering generic names.")
        else:
            fig.add_trace(go.Bar(
                x=gv_prot_name_counts_plot['Cleaned_Euk_Hit_Protein_Name'],
                y=gv_prot_name_counts_plot['Count'],
                name='GV Hit Protein Functions',
                marker_color=arcadia_secondary_palette[0] # Different palette for this panel
            ), row=3, col=1)
            fig.update_xaxes(title_text="Eukaryotic Protein Function (Cleaned Name)", tickangle=45, categoryorder='total descending', row=3, col=1)
            fig.update_yaxes(title_text="Number of GV Protein Hits", row=3, col=1)

            # Save all counts (unfiltered, untruncated) with explicit column names
            panel_c_data_path = Path(output_summary_dir_phase1) / "figure13_panel_c_gv_euk_prot_func_hits.csv"
            try:
                full_gv_prot_name_counts = gv_prot_name_counts_all.reset_index()
                full_gv_prot_name_counts.columns = ['Cleaned_Euk_Hit_Protein_Name', 'Count']
                full_gv_prot_name_counts.to_csv(panel_c_data_path, index=False)
                print(f"Data for Panel C saved to {panel_c_data_path}")
            except Exception as e:
                print(f"Error saving Panel C data: {e}")

    # --- Final Layout Updates for the Entire Figure ---
    fig.update_layout(
        height=1400, # Increased height for better label spacing
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(0,0,0,0)',
        margin=dict(l=80, r=50, t=100, b=220), # Increased bottom margin
        showlegend=False
    )

    if 'plotly_layout_defaults' in locals():
        # Apply global font and subplot title styling from plotly_layout_defaults.
        fig.update_layout(font=plotly_layout_defaults.font)
        for annotation in fig.layout.annotations:
            if annotation.text.startswith("<b>"): # Subplot titles
                annotation.font.size = 14
                if plotly_layout_defaults.font:
                    annotation.font.family = plotly_layout_defaults.font.family

        # Apply common axis styling from plotly_layout_defaults to every subplot axis.
        for axis_name_template in ('xaxis', 'yaxis'):
            if not hasattr(plotly_layout_defaults, axis_name_template):
                continue
            default_axis_style = getattr(plotly_layout_defaults, axis_name_template)

            for row_number in range(1, 4): # For 3 rows
                # Row 1 uses the unsuffixed name: xaxis, xaxis2, xaxis3 / yaxis, yaxis2, yaxis3
                axis_ref_name = f"{axis_name_template}{row_number if row_number > 1 else ''}"
                if not hasattr(fig.layout, axis_ref_name):
                    continue
                current_axis = getattr(fig.layout, axis_ref_name)

                # Title font is copied whether or not this axis currently has title text.
                if default_axis_style.title and default_axis_style.title.font:
                    default_title_font = default_axis_style.title.font
                    current_axis.title.font.size = default_title_font.size
                    current_axis.title.font.family = default_title_font.family
                    current_axis.title.font.weight = default_title_font.weight

                if default_axis_style.tickfont:
                    current_axis.tickfont.size = default_axis_style.tickfont.size
                    current_axis.tickfont.family = default_axis_style.tickfont.family

                current_axis.update(
                    showgrid=False,
                    zeroline=False,
                    linecolor='black',
                    linewidth=1.5,
                    ticks="outside",
                    ticklen=5,
                    tickwidth=1.5,
                    tickcolor='black',
                )
    fig.show()

    # --- Save Figure ---
    fig_path_html = Path(output_figure_dir) / "figure13_eukaryotic_hit_characterization_consistent_colors.html"
    fig.write_html(str(fig_path_html))
    print(f"Figure 13 (Consistent Colors) HTML saved to: {fig_path_html}")

    # Static exports are best-effort: a failure is reported as a warning and
    # does not stop the cell.
    for image_format in ("pdf", "svg"):
        format_label = image_format.upper()
        try:
            fig_path_image = Path(output_figure_dir) / f"figure13_eukaryotic_hit_characterization_consistent_colors.{image_format}"
            fig.write_image(str(fig_path_image), width=900, height=1400)
            print(f"Figure 13 {format_label} saved to: {fig_path_image}")
        except Exception as e:
            print(f"Warning: Could not export Figure 13 to {format_label}. Error: {e}")

print("\n\n--- Cell 21 (Figure 13 - Eukaryotic Hit Characterization with Consistent Colors) Complete ---")

