In [None]:
# This notebook is designed to generate all the figures for the publication "Assembling and annotating an Asgard archaea and giant virus dataset of over 840,000 proteins." It loads the final, fully annotated proteome database and uses Plotly to create a series of visualizations that characterize the dataset, explore protein features, and analyze sequence conservation. Each cell is organized to produce a specific figure or a panel within a figure, with outputs saved to a dedicated directory for publication.

In [None]:
# Setup Cell
# This cell handles all the necessary imports, sets up logging, defines file paths, configures plotting defaults to match the Arcadia style guide, loads auxiliary data like InterPro entries, and defines helper functions. It concludes by loading the main proteome database and creating color maps for consistent data visualization across all figures.

In [None]:
# Cell 1: Setup for Figures Notebook - Load Data and Define Helpers

# --- Standard Library Imports ---
import pandas as pd
import numpy as np
import time
import re  # For parsing organism names
import sys  # For exit
import logging  # For detailed logging
from pathlib import Path  # For handling paths

# Ensure Biopython is installed: pip install biopython
try:
    from Bio import SeqIO
except ImportError:
    print(
        "ERROR: Biopython is required for this script. Please install it: pip install biopython"
    )
    sys.exit(1)

# --- Plotly Imports ---
import plotly.graph_objects as go
import plotly.express as px
import plotly.io as pio

# --- Setup Logging ---
# Module-level logger for this notebook. The level is always (re)set to INFO;
# a stdout handler is attached only when neither this logger nor any of its
# ancestors already has one, so re-running the cell does not stack handlers.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
if not logger.hasHandlers():
    log_formatter = logging.Formatter(
        "%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
    )
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(log_formatter)
    logger.addHandler(console_handler)


print("--- Figures Notebook Setup ---")

# --- Configuration ---
# Inputs
#   database_path             : latest annotated proteome database (Euk annotations,
#                               DIAMOND hit details, coverage columns, RBH flag, and
#                               possibly already-merged APSI/Motif data).
#   interpro_entry_path       : InterPro entry list used to translate IPR IDs to names.
#   output_summary_dir_phase1 : directory where the analysis notebook wrote its summary
#                               tables (APSI, motifs); must match that notebook's
#                               output_summary_dir_phase1.
# Update any of these if your files live elsewhere.
database_path = "proteome_database_v3.5.csv"
interpro_entry_path = "interpro_entry.txt"
output_summary_dir_phase1 = Path("./output_summary_data_hit_validation_phase1")

# Outputs: figures and the tables backing them are kept in separate directories.
output_figure_dir = "publication_figures"
output_figure_summary_dir = "publication_figure_data"

# Create both output directories up front (no-op if they already exist).
Path(output_figure_dir).mkdir(exist_ok=True, parents=True)
print(f"Ensured plot output directory exists: {output_figure_dir}")
Path(output_figure_summary_dir).mkdir(exist_ok=True, parents=True)
print(f"Ensured summary data directory exists: {output_figure_summary_dir}")


# --- Define Key Column Names (Based on your database structure) ---
# Every column name used by the figure cells is declared once here so that a
# schema change only needs to be made in one place.

# Identity and provenance
protein_id_col = "ProteinID"
sequence_col = "Sequence"
source_dataset_col = "Source_Dataset"
source_genome_accession_col = "Source_Genome_Accession"
source_protein_annotation_col = "Source_Protein_Annotation"
ncbi_taxid_col = "NCBI_TaxID"

# Taxonomy and orthology
asgard_phylum_col = "Asgard_Phylum"
virus_family_col = "Virus_Family"
virus_name_col = "Virus_Name"
orthogroup_col = "OG_ID"

# Domain / functional annotation
ipr_col = "IPR_Signatures"
ipr_go_terms_col = "IPR_GO_Terms"
uniprotkb_ac_col = "UniProtKB_AC"
num_domains_col = "Num_Domains"
domain_arch_col = "Domain_Architecture"
# Values such as 'Annotated' / 'Uncharacterized'
type_col = "Type"
is_hypothetical_col = "Is_Hypothetical"
has_known_structure_col = "Has_Known_Structure"
percent_disorder_col = "Percent_Disorder"
specific_func_cat_col = "Specific_Functional_Category"
broad_func_cat_col = "Broad_Functional_Category"
category_trigger_col = "Category_Trigger"

# Signal peptide / sequence length
signal_peptide_col = "Signal_Peptide_USPNet"
sp_cleavage_site_col = "SP_Cleavage_Site_USPNet"
original_seq_length_col = "Original_Seq_Length"

# Dataset group: 'Asgard' or 'GV'
group_col = "Group"

# Structure / sequence-search hits and localization
seqsearch_pdb_hit_col = "SeqSearch_PDB_Hit"
seqsearch_afdb_hit_col = "SeqSearch_AFDB_Hit"
has_reference_structure_col = "Has_Reference_Structure"
localization_col = "Predicted_Subcellular_Localization"
mature_protein_sequence_col = "Mature_Protein_Sequence"
mature_seq_length_col = "Mature_Seq_Length"
seqsearch_mgnify_hit_col = "SeqSearch_MGnify_Hit"
seqsearch_esma_hit_col = "SeqSearch_ESMA_Hit"

# Derived flags
structurally_dark_col = "Is_Structurally_Dark"
esp_col = "Is_ESP"

# Eukaryotic DIAMOND hit details (best hit per query)
# Flag for any Euk hit passing the initial e-value cutoff
hit_flag_col = "Has_Euk_DIAMOND_Hit"
euk_hit_sseqid_col = "Euk_Hit_SSEQID"
euk_hit_organism_col = "Euk_Hit_Organism"
euk_hit_pident_col = "Euk_Hit_PIDENT"
euk_hit_evalue_col = "Euk_Hit_EVALUE"
euk_hit_protein_name_col = "Euk_Hit_Protein_Name"
# Alignment start/end on the query (Asgard) sequence
euk_hit_qstart_col = "Euk_Hit_Qstart"
euk_hit_qend_col = "Euk_Hit_Qend"
# Alignment start/end on the subject (Euk) sequence
euk_hit_sstart_col = "Euk_Hit_Sstart"
euk_hit_send_col = "Euk_Hit_Send"
# Length of the Euk hit as reported in the DIAMOND output
euk_hit_slen_diamond_col = "Euk_Hit_Slen_Diamond"
# Calculated query / subject coverage
query_coverage_col = "Query_Coverage"
subject_coverage_col = "Subject_Coverage"
# Reciprocal Best Hit flag (present only if RBH analysis was included)
rbh_flag_col = "Is_RBH"

# APSI / motif columns (names assumed if these were merged into the main DB)
apsi_col = "Intra_OG_APSI"
# May hold a list or a string
motif_col = "Conserved_Motifs"
# Number of sequences in the OG (used for the MSA)
num_og_sequences_col = "Num_OG_Sequences"

# --- Define Arcadia Color Palettes (from Style Guide PDF) ---
print("\n--- Defining Manual Arcadia Color Palettes (from Style Guide) ---")
# Name -> hex lookup for every style-guide color. Insertion order is kept as
# primary, secondary, neutrals, backgrounds.
arcadia_colors_manual = {
    # Primary
    "aegean": "#5088C5",
    "amber": "#F28360",
    "seaweed": "#3B9886",
    "canary": "#F7B846",
    "aster": "#7A77AB",
    "rose": "#F898AE",
    "vital": "#73B5E3",
    # The PDF lists FFb883 for tangerine; FFB984 is kept to match earlier code.
    "tangerine": "#FFB984",
    "oat": "#F5E4BE",
    "wish": "#BABEE0",
    "lime": "#97CD78",
    "dragon": "#C85152",
    # Secondary
    "sky": "#C6E7F4",
    "dress": "#F8C5C1",
    "taupe": "#DBD1C3",
    "denim": "#B6C8D4",
    "sage": "#B5BEA4",
    "mars": "#DA9085",
    "marine": "#8A99AD",
    "shell": "#EDE0D6",
    # Neutrals
    "white": "#FFFFFF",
    "gray": "#EBEDE8",
    # The PDF lists BAB0A8 for chateau; B9AFA7 is used here.
    "chateau": "#B9AFA7",
    "bark": "#8F8885",
    "slate": "#43413F",
    "charcoal": "#484B50",
    "crow": "#292928",
    "black": "#09090A",
    # Forest is listed under Neutrals in the PDF.
    "forest": "#596F74",
    # Backgrounds
    "parchment": "#FEF7F1",
    # The PDF lists F4FBFE for zephyr; F4FBFF is used here.
    "zephyr": "#F4FBFF",
    "lichen": "#F7FBEF",
    "dawn": "#F8F4F1",
}

# Ordered hex-code lists for cycling through a palette; each list is resolved
# from the lookup above by color name.
arcadia_primary_palette = [
    arcadia_colors_manual[color_name]
    for color_name in (
        "aegean amber seaweed canary aster rose "
        "vital tangerine oat wish lime dragon"
    ).split()
]
arcadia_secondary_palette = [
    arcadia_colors_manual[color_name]
    for color_name in "sky dress taupe denim sage mars marine shell".split()
]
arcadia_neutrals_palette = [
    arcadia_colors_manual[color_name]
    for color_name in "gray chateau bark slate charcoal forest crow".split()
]
arcadia_background_palette = [
    arcadia_colors_manual[color_name]
    for color_name in "parchment zephyr lichen dawn".split()
]

print("Manual Arcadia palettes created.")

# --- Configure Plotly Defaults (Adhering to Style Guide) ---
print("\n--- Configuring Plotly Defaults (Adhering to Style Guide) ---")
# Start from Plotly's clean white template.
pio.templates.default = "plotly_white"

# Reusable layout object for figure cells (must exist before any figure cell runs).
# Both axes get the same settings: bold 15pt Arial black title, no gridlines,
# no zero line, a 1.5-wide black axis line on one side only, and 5px outside
# ticks that match the axis line. No plot title; global font is 15pt Arial black.
plotly_layout_defaults = go.Layout(
    **{
        _axis_key: dict(
            title=dict(
                font=dict(size=15, color="black", family="Arial", weight="bold")
            ),
            showgrid=False,
            zeroline=False,
            showline=True,
            linecolor="black",
            linewidth=1.5,
            mirror=False,
            ticks="outside",
            ticklen=5,
            tickwidth=1.5,
            tickcolor="black",
        )
        for _axis_key in ("xaxis", "yaxis")
    },
    title=None,
    font=dict(family="Arial", size=15, color="black"),
)

# Push the same axis styling into the "plotly_white" template itself so that
# plotly.express figures pick it up without an explicit update_layout call.
# Note: this mutates a built-in template, so it affects every figure created
# afterwards in this kernel; applying the layout per figure is the safer option.
_template_layout = pio.templates["plotly_white"].layout
for _template_axis in (_template_layout.xaxis, _template_layout.yaxis):
    _template_axis.title.font.update(size=15, weight="bold")
    _template_axis.update(
        showgrid=False,
        showline=True,
        linecolor="black",
        linewidth=1.5,
        mirror=False,
        ticks="outside",
        ticklen=5,
        tickwidth=1.5,
        tickcolor="black",
    )
_template_layout.title = None  # no plot title by default

print("Plotly default template set to 'plotly_white'.")
print(
    "Default layout settings configured for bold axis titles (15pt), NO gridlines, strong black axis lines (1.5pt), tick marks (5px, 1.5pt, black), and NO plot titles."
)


# --- Load InterPro Entry Data ---
# Builds ipr_lookup: {IPR_ID: {"Type": ..., "Name": ...}}, used to translate
# IPR IDs to readable names in domain architectures. On any failure the lookup
# stays empty and a warning is printed; the notebook keeps running.
print(f"\n--- Loading InterPro Entry Data from '{interpro_entry_path}' ---")
ipr_lookup = {}
start_time_ipr = time.time()
try:
    # The first three columns are read and renamed to IPR_ID / Type / Name.
    # Because `names=` is combined with `header=0`, the file's own header row is
    # always replaced, so no post-hoc renaming of ENTRY_AC/ENTRY_TYPE/ENTRY_NAME
    # is needed. Adjust usecols/names if your file layout differs.
    ipr_info_df = pd.read_csv(
        interpro_entry_path,
        sep="\t",
        usecols=[0, 1, 2],
        names=["IPR_ID", "Type", "Name"],
        header=0,
        comment="#",
        on_bad_lines="warn",
    )

    # Rows without an ID cannot be looked up; drop them rather than creating a
    # literal "nan" key.
    ipr_info_df = ipr_info_df.dropna(subset=["IPR_ID"])
    ipr_info_df["IPR_ID"] = ipr_info_df["IPR_ID"].astype(str).str.strip()

    # to_dict("index") raises on a non-unique index, which would otherwise
    # discard the whole lookup; keep the first row for each ID instead.
    num_rows_before_dedup = len(ipr_info_df)
    ipr_info_df = ipr_info_df.drop_duplicates(subset="IPR_ID", keep="first")
    num_duplicate_rows = num_rows_before_dedup - len(ipr_info_df)
    if num_duplicate_rows:
        print(
            f"Warning: Dropped {num_duplicate_rows} duplicate IPR_ID row(s) from '{interpro_entry_path}' (kept first occurrence)."
        )

    ipr_lookup = ipr_info_df.set_index("IPR_ID")[["Type", "Name"]].to_dict("index")
    print(
        f"Loaded InterPro entry data for {len(ipr_lookup)} entries in {time.time() - start_time_ipr:.2f} seconds."
    )
except FileNotFoundError:
    print(
        f"Warning: InterPro entry file not found at '{interpro_entry_path}'. Domain name translations will not be available."
    )
except Exception as e:
    # Deliberately best-effort: figures can still be produced without names.
    print(
        f"Warning: An error occurred loading or processing '{interpro_entry_path}': {e}"
    )

if not ipr_lookup:
    print(
        "Warning: ipr_lookup is empty. Domain name translations will not be available for plots/tables."
    )


# --- Define Helper Functions ---
print("\n--- Defining Helper Functions ---")
# Include necessary helper functions from your previous notebook here.
# You might need: translate_architecture, truncate_string, clean_protein_name, etc.
# Only include functions actually needed for figure generation and data manipulation *within* this notebook.


# Example: translate_architecture (requires ipr_lookup)
def translate_architecture(arch_string, lookup_dict=ipr_lookup):
    """Translate the IPR IDs in a domain-architecture string into readable names.

    Args:
        arch_string: IDs separated by ';' or '|'. Anything that is not a
            non-empty string is returned unchanged.
        lookup_dict: Mapping of IPR ID -> dict with a "Name" entry. Defaults to
            the ipr_lookup object that exists when this function is defined.
            If it is empty, arch_string is returned unchanged.

    Returns:
        A '; '-joined string. Known IDs become "Name (ID)" with names longer
        than 43 characters cut to 40 plus '...'. Unknown IDs, and IDs whose
        stored name is missing or not a string, are kept as the bare ID.
        Empty fragments are skipped. The whole result is capped at 150
        characters including a trailing '...'.
    """
    if not isinstance(arch_string, str) or not arch_string or not lookup_dict:
        return arch_string  # Nothing to translate, or nothing to translate with

    translated_parts = []
    # '|' and ';' are both treated as separators.
    for raw_id in arch_string.replace("|", ";").split(";"):
        ipr_id_clean = raw_id.strip()
        if not ipr_id_clean:
            continue

        if ipr_id_clean in lookup_dict:
            name = lookup_dict[ipr_id_clean].get("Name", ipr_id_clean)
            # A blank Name cell is loaded as NaN (a float); calling len() on it
            # would raise, so fall back to showing the bare ID.
            if isinstance(name, str) and name:
                if len(name) > 43:
                    name = name[:40] + "..."
                translated_parts.append(f"{name} ({ipr_id_clean})")
                continue

        translated_parts.append(ipr_id_clean)  # Unknown ID or unusable name

    full_translation = "; ".join(translated_parts)
    max_len_display = 150  # Max length for display
    if len(full_translation) > max_len_display:
        full_translation = full_translation[: max_len_display - 3] + "..."

    return full_translation


# Example: truncate_string
def truncate_string(text, max_len):
    """Shorten a string to at most max_len characters, ending in '...' when cut.

    Args:
        text: Value to shorten. Non-string values are returned unchanged.
        max_len: Maximum length of the returned string.

    Returns:
        text unchanged if it is not a string or already fits. Otherwise the
        first max_len - 3 characters followed by '...'. When max_len is below 3
        there is no room for the ellipsis, so the string is hard-cut to max_len
        characters (an empty string for max_len <= 0).
    """
    if not isinstance(text, str) or len(text) <= max_len:
        return text
    if max_len < 3:
        # A negative slice end would count from the right and return a result
        # longer than max_len; cut hard instead.
        return text[: max(max_len, 0)]
    return text[: max_len - 3] + "..."


# Example: clean_protein_name
def clean_protein_name(name):
    """Strip common annotation noise from a protein name.

    Args:
        name: Raw protein name. Missing values (as judged by pd.isna) map to
            "Unknown/Not Found"; any other value is converted to a string first.

    Returns:
        The cleaned name with its first character upper-cased, or
        "Unknown/Not Found" if nothing is left after cleaning.

    Cleaning steps, applied in order: drop everything from the first '|';
    remove "isoform <x>", "partial", "putative", "predicted protein",
    "type <x>" and the word "protein" (all case-insensitive); normalise
    "uncharacterized" to "Uncharacterized"; drop a trailing ';' or ',' and a
    trailing "(Fragment)"; collapse runs of whitespace.
    """
    if pd.isna(name):
        return "Unknown/Not Found"
    # Coerce to str before the regex steps so non-string values (e.g. numeric
    # IDs) do not raise inside re.sub.
    name = str(name).strip()
    name = re.sub(r"\s*\|\s*.*", "", name)
    name = re.sub(r"\bisoform\s+[\w-]+\b", "", name, flags=re.IGNORECASE).strip()
    name = re.sub(r"\bpartial\b", "", name, flags=re.IGNORECASE).strip()
    name = re.sub(r"\bputative\b", "", name, flags=re.IGNORECASE).strip()
    name = re.sub(r"\bpredicted protein\b", "", name, flags=re.IGNORECASE).strip()
    name = re.sub(r"\btype\s+\w+\b", "", name, flags=re.IGNORECASE).strip()
    name = re.sub(r"\bprotein\b", "", name, flags=re.IGNORECASE).strip()
    name = re.sub(
        r"\buncharacterized\b", "Uncharacterized", name, flags=re.IGNORECASE
    ).strip()
    name = re.sub(r"[;,]$", "", name).strip()
    name = re.sub(r"\s*\(Fragment\)$", "", name, flags=re.IGNORECASE).strip()
    name = re.sub(r"\s+", " ", name).strip()
    name = name[0].upper() + name[1:] if len(name) > 0 else name
    return name if name else "Unknown/Not Found"


# Add other necessary helper functions here...
# e.g., get_ipr_counts if you plot IPR frequencies directly


# --- Load Main Database ---
# Loads the annotated proteome database into df_full, then attaches the per-OG
# APSI values and a Has_Conserved_Motif flag from the analysis notebook's
# summary files. On exit, df_full always exists; it is an empty DataFrame if
# the database could not be loaded. After a successful load it always has
# apsi_col, num_og_sequences_col and "Has_Conserved_Motif" (NaN / False
# placeholders when the data are unavailable).
print(f"\n--- Loading Data from '{database_path}' ---")
try:
    # The database is expected to contain all annotations already, including
    # Euk hit details, coverage, and the RBH flag (if added).
    df_full = pd.read_csv(database_path, low_memory=False)
    # Ensure ProteinID is string type right after loading
    if protein_id_col in df_full.columns:
        df_full[protein_id_col] = df_full[protein_id_col].astype(str)

    # Summary files written by the analysis notebook.
    apsi_file_path = output_summary_dir_phase1 / "intra_og_apsi_values.csv"
    motifs_file_path = output_summary_dir_phase1 / "intra_og_conserved_motifs.csv"

    # --- Merge APSI data (one value per orthogroup) ---
    print(f"\nAttempting to merge APSI data from: {apsi_file_path}")
    if apsi_file_path.is_file():
        print("APSI file found.")
        try:
            df_apsi = pd.read_csv(apsi_file_path)
            print(
                f"APSI file read successfully. Shape: {df_apsi.shape}. Columns: {df_apsi.columns.tolist()}"
            )
            # Align the APSI file's column names with the names used in df_full.
            df_apsi.rename(
                columns={"Num_Sequences": num_og_sequences_col}, inplace=True
            )
            if orthogroup_col in df_apsi.columns:
                if "APSI" in df_apsi.columns:
                    df_apsi.rename(columns={"APSI": apsi_col}, inplace=True)
                    print(f"Renamed 'APSI' column to '{apsi_col}' in APSI data.")
                else:
                    print(
                        f"Warning: 'APSI' column not found in APSI data. Cannot rename to '{apsi_col}'."
                    )

                apsi_value_cols = [
                    col
                    for col in (apsi_col, num_og_sequences_col)
                    if col in df_apsi.columns
                ]
                # The database may already carry these columns. Merging them
                # again would produce '_x'/'_y' suffixed copies and remove the
                # expected column name, so existing columns are left untouched.
                cols_already_in_db = [
                    col for col in apsi_value_cols if col in df_full.columns
                ]
                cols_to_add = [
                    col for col in apsi_value_cols if col not in df_full.columns
                ]
                if cols_already_in_db:
                    print(
                        f"Note: column(s) {cols_already_in_db} already present in df_full; keeping existing values."
                    )

                if cols_to_add:
                    # One row per orthogroup: duplicate OG rows would multiply
                    # protein rows in the left merge below.
                    apsi_per_og = df_apsi[[orthogroup_col] + cols_to_add].drop_duplicates(
                        subset=orthogroup_col, keep="first"
                    )
                    num_duplicate_og_rows = len(df_apsi) - len(apsi_per_og)
                    if num_duplicate_og_rows:
                        print(
                            f"Warning: Dropped {num_duplicate_og_rows} duplicate '{orthogroup_col}' row(s) from APSI data before merging (kept first occurrence)."
                        )
                    print(
                        f"Merging APSI data on column '{orthogroup_col}'. df_full shape before merge: {df_full.shape}"
                    )
                    # Left merge keeps every protein; the per-OG value is
                    # repeated for all proteins in that OG.
                    df_full = df_full.merge(
                        apsi_per_og,
                        on=orthogroup_col,
                        how="left",
                    )
                    print(
                        f"Merged APSI data. df_full shape after merge: {df_full.shape}. Columns: {df_full.columns.tolist()}"
                    )
                elif not apsi_value_cols:
                    print(
                        "Warning: No relevant columns found in APSI data for merging."
                    )
            else:
                print(
                    f"Warning: Orthogroup column '{orthogroup_col}' not found in APSI data. Skipping APSI merge."
                )

        except Exception as e:
            logger.error(f"Failed to merge APSI data: {e}")
    else:
        print(
            f"Warning: APSI file not found at '{apsi_file_path}'. APSI data will not be available."
        )

    # Guarantee the APSI columns exist whichever path was taken above (file
    # missing, merge failed, or column absent) to prevent downstream errors.
    for placeholder_col in (apsi_col, num_og_sequences_col):
        if placeholder_col not in df_full.columns:
            df_full[placeholder_col] = np.nan

    # --- Load motif data and flag orthogroups that have a conserved motif ---
    print(f"\nAttempting to load and merge Motif data from: {motifs_file_path}")
    if motifs_file_path.is_file():
        print("Motif file found.")
        try:
            # df_motifs is left unmodified so it can be used directly for
            # motif analysis/plotting.
            df_motifs = pd.read_csv(motifs_file_path)
            print(
                f"Motif file read successfully. Shape: {df_motifs.shape}. Columns: {df_motifs.columns.tolist()}"
            )
            # Accept either the configured orthogroup column or the literal
            # 'Orthogroup' header.
            # NOTE(review): 'Orthogroup' is assumed to hold the same IDs as
            # orthogroup_col in df_full - confirm against the analysis notebook.
            if orthogroup_col in df_motifs.columns:
                motif_og_key = orthogroup_col
            elif "Orthogroup" in df_motifs.columns:
                motif_og_key = "Orthogroup"
            else:
                motif_og_key = None

            if motif_og_key is not None:
                # Drop missing IDs so proteins without an orthogroup are never
                # flagged as having a motif.
                ogs_with_motifs = df_motifs[motif_og_key].dropna().unique()
                print(
                    f"Loaded {len(df_motifs)} motifs for {len(ogs_with_motifs)} orthogroups."
                )
                df_full["Has_Conserved_Motif"] = df_full[orthogroup_col].isin(
                    ogs_with_motifs
                )
                print(
                    f"Added 'Has_Conserved_Motif' flag to df_full. Columns: {df_full.columns.tolist()}"
                )
            else:
                print(
                    f"Warning: Orthogroup column '{orthogroup_col}' not found in Motif data. Skipping 'Has_Conserved_Motif' flag."
                )

        except Exception as e:
            logger.error(f"Failed to load motif data: {e}")
    else:
        print(
            f"Warning: Motif file not found at '{motifs_file_path}'. Motif data will not be available."
        )

    # Guarantee the flag column exists whichever path was taken above.
    if "Has_Conserved_Motif" not in df_full.columns:
        df_full["Has_Conserved_Motif"] = False

    print(
        f"\nSuccessfully loaded and prepared data. Final df_full shape: {df_full.shape}. Columns: {df_full.columns.tolist()}"
    )

except FileNotFoundError:
    print(
        f"ERROR: Database file not found at '{database_path}'. Please check the path."
    )
    # Define df_full as empty to prevent downstream errors
    df_full = pd.DataFrame()
except Exception as e:
    print(f"An error occurred while loading or processing the database: {e}")
    # Define df_full as empty
    df_full = pd.DataFrame()


# --- Create Color Maps for Plotting (after df_full is loaded) ---
# Maps are rebuilt from the loaded data so that the categories match it.
print("\n--- Creating Color Maps ---")
# Asgard Phylum Colors: each phylum found in the Asgard rows gets a primary
# palette color in sorted order (wrapping around if there are more phyla than
# colors); "Unknown Phylum" is always gray. The map stays empty when either
# required column is absent.
asgard_phylum_color_map = {}
if group_col in df_full.columns and asgard_phylum_col in df_full.columns:
    is_asgard_row = df_full[group_col] == "Asgard"
    asgard_phyla_cat = df_full[is_asgard_row][asgard_phylum_col].dropna().unique()
    asgard_phyla_cat.sort()
    num_primary_colors = len(arcadia_primary_palette)
    for i, phylum in enumerate(asgard_phyla_cat):
        asgard_phylum_color_map[phylum] = arcadia_primary_palette[
            i % num_primary_colors
        ]
    asgard_phylum_color_map["Unknown Phylum"] = arcadia_colors_manual.get(
        "gray", "#bdbdbd"
    )
print(f"Asgard phylum color map created for {len(asgard_phylum_color_map)} phyla.")

# Localization Colors
# Maps every localization label found in the data to a color. Labels listed in
# localization_assignments get their fixed color; any other label takes the
# next color from a neutrals+secondary fallback palette (wrapping around).
# Missing values are keyed by their string form (e.g. "nan").
localization_color_map = {}
if localization_col in df_full.columns:
    # Restrict to the Asgard and GV groups when the group column is available;
    # otherwise use every row instead of failing on the missing column.
    if group_col in df_full.columns:
        localization_values = df_full[df_full[group_col].isin(["Asgard", "GV"])][
            localization_col
        ]
    else:
        localization_values = df_full[localization_col]
    all_localizations = localization_values.unique()
    # Convert to list, turning missing values into their string form
    all_localizations_list = [
        str(loc) if pd.isna(loc) else loc for loc in all_localizations
    ]
    # Sort for consistent color assignment; key=str keeps the ordering of
    # string labels unchanged and avoids a TypeError on mixed-type labels.
    all_localizations_list.sort(key=str)
    localization_assignments = {
        "Archaea: Cytoplasmic/Membrane (non-SP)": arcadia_colors_manual.get(
            "aegean", "#5088C5"
        ),
        "Archaea: Membrane-associated (Lipoprotein/Pilin)": arcadia_colors_manual.get(
            "amber", "#F28360"
        ),
        "Archaea: Secreted/Membrane (Sec/Tat pathway)": arcadia_colors_manual.get(
            "seaweed", "#3B9886"
        ),
        "Host: Cytoplasm/Nucleus/Virus Factory": arcadia_colors_manual.get(
            "vital", "#73B5E3"
        ),
        "Host: Membrane-associated (Lipoprotein/Pilin-like)": arcadia_colors_manual.get(
            "tangerine", "#FFB984"
        ),
        "Host: Secretory Pathway (Secreted/Membrane/Organelle)": arcadia_colors_manual.get(
            "lime", "#97CD78"
        ),
        # Simplified terms, if present, reuse the corresponding Host colors
        "CYTOPLASMIC": arcadia_colors_manual.get("vital", "#73B5E3"),
        "MEMBRANE": arcadia_colors_manual.get("tangerine", "#FFB984"),
        "EXTRACELLULAR": arcadia_colors_manual.get("lime", "#97CD78"),
        "Unknown": arcadia_colors_manual.get("gray", "#EBEDE8"),
        "nan": arcadia_colors_manual.get("gray", "#EBEDE8"),  # Handle NaN explicitly
    }
    fallback_palette_loc = arcadia_neutrals_palette + arcadia_secondary_palette
    fallback_idx_loc = 0
    for loc in all_localizations_list:
        if loc not in localization_color_map:
            localization_color_map[loc] = localization_assignments.get(
                loc, fallback_palette_loc[fallback_idx_loc % len(fallback_palette_loc)]
            )
            # Only advance the fallback cursor when a fallback color was used.
            if loc not in localization_assignments:
                fallback_idx_loc += 1
print(f"Localization color map created for {len(localization_color_map)} categories.")

# Broad Functional Category Colors
# Maps every broad functional category found in the data to a color.
# Categories listed in category_assignments get their fixed color; any other
# category takes the next color from a primary+secondary+neutrals fallback
# palette (wrapping around). Missing values are keyed by their string form.
broad_category_color_map = {}
if broad_func_cat_col in df_full.columns:
    # Restrict to the Asgard and GV groups when the group column is available;
    # otherwise use every row instead of failing on the missing column.
    if group_col in df_full.columns:
        broad_category_values = df_full[df_full[group_col].isin(["Asgard", "GV"])][
            broad_func_cat_col
        ]
    else:
        broad_category_values = df_full[broad_func_cat_col]
    all_broad_categories = broad_category_values.unique()
    # Convert to list, turning missing values into their string form
    all_broad_categories_list = [
        str(cat) if pd.isna(cat) else cat for cat in all_broad_categories
    ]
    # Sort for consistent color assignment; key=str keeps the ordering of
    # string labels unchanged and avoids a TypeError on mixed-type labels.
    all_broad_categories_list.sort(key=str)

    # NOTE(review): "Ubiquitin System" and "Signal Transduction" both use
    # "aster", so they are indistinguishable when plotted together. Left as-is
    # so existing figures stay reproducible - confirm whether this is intended.
    category_assignments = {
        "Cytoskeleton": arcadia_colors_manual.get("aegean", "#5088C5"),
        "Membrane Trafficking/Vesicles": arcadia_colors_manual.get("amber", "#F28360"),
        "ESCRT/Endosomal Sorting": arcadia_colors_manual.get("seaweed", "#3B9886"),
        "Ubiquitin System": arcadia_colors_manual.get("aster", "#7A77AB"),
        "N-glycosylation": arcadia_colors_manual.get("rose", "#F898AE"),
        "Nuclear Transport/Pore": arcadia_colors_manual.get("vital", "#73B5E3"),
        "DNA Info Processing": arcadia_colors_manual.get("canary", "#F7B846"),
        "RNA Info Processing": arcadia_colors_manual.get("lime", "#97CD78"),
        "Translation": arcadia_colors_manual.get("tangerine", "#FFB984"),
        "Signal Transduction": arcadia_colors_manual.get("aster", "#7A77AB"),
        "Metabolism": arcadia_colors_manual.get(
            "sky", "#C6E7F4"
        ),  # Using sky as per previous code
        "Other Specific Annotation": arcadia_colors_manual.get("denim", "#B6C8D4"),
        "Unknown/Unclassified": arcadia_colors_manual.get("gray", "#EBEDE8"),
        "nan": arcadia_colors_manual.get("gray", "#EBEDE8"),  # Handle NaN explicitly
    }
    fallback_palette_cat = (
        arcadia_primary_palette + arcadia_secondary_palette + arcadia_neutrals_palette
    )
    fallback_idx_cat = 0
    for category in all_broad_categories_list:
        if category not in broad_category_color_map:
            broad_category_color_map[category] = category_assignments.get(
                category,
                fallback_palette_cat[fallback_idx_cat % len(fallback_palette_cat)],
            )
            # Only advance the fallback cursor when a fallback color was used.
            if category not in category_assignments:
                fallback_idx_cat += 1
print(
    f"Broad functional category color map created for {len(broad_category_color_map)} categories."
)


print("\n--- Figures Notebook Setup Complete ---")
print("DataFrame 'df_full' is loaded and prepared for figure generation.")
print("Color maps and helper functions are defined.")
print("You can now proceed with generating figures in subsequent cells.")

In [None]:
# This cell is designed to generate plots for database-overview_82.png, the figure in the pub that characterizes the composition of the dataset.

In [None]:
# Cell 6: Figure 4 - Genome and Protein Counts by Taxonomy

# --- Prerequisites ---
# Uses names defined in the setup cell (Cell 1): pd, px, go, Path, group_col,
# protein_id_col, source_genome_accession_col, asgard_phylum_col,
# virus_family_col, arcadia_colors_manual, arcadia_primary_palette,
# arcadia_secondary_palette, arcadia_neutrals_palette, output_figure_dir and
# plotly_layout_defaults.
# Uses df_filtered_groups from Cell 2 (Figure 2 generation), which is expected
# to contain only the 'Asgard' and 'GV' groups.

print("\n\n--- Generating Figure 4: Genome and Protein Counts by Taxonomy (Cell 6) ---")


# --- Helpers (prefixed with fig4_ so they do not clobber names from other cells) ---
def fig4_collect_counts(df, group_column, taxonomy_by_group, value_column, how):
    """Count `value_column` per (group, taxonomic unit) for each requested group.

    Args:
        df: DataFrame holding `group_column`, `value_column` and every taxonomy
            column named in `taxonomy_by_group`.
        group_column: Name of the column holding the group label.
        taxonomy_by_group: Mapping of group label -> taxonomy column to group
            that group's rows by (e.g. {"Asgard": <phylum col>, "GV": <family col>}).
        value_column: Column whose values are counted.
        how: "nunique" to count distinct values, "count" to count non-null rows.

    Returns:
        DataFrame with columns [group_column, "Taxonomic_Unit", "Count"], one row
        per (group, unit). An empty DataFrame (no columns) when no group has rows.

    Raises:
        ValueError: If `how` is neither "nunique" nor "count".
    """
    if how not in ("nunique", "count"):
        raise ValueError(f"how must be 'nunique' or 'count', got {how!r}")

    frames = []
    for group_name, taxonomy_column in taxonomy_by_group.items():
        subset = df[df[group_column] == group_name]
        if subset.empty:
            print(
                f"Warning: Skipping {group_name} counts for '{value_column}'. Data is empty."
            )
            continue
        grouped = subset.groupby([group_column, taxonomy_column])[value_column]
        counts = grouped.nunique() if how == "nunique" else grouped.count()
        frames.append(
            counts.reset_index(name="Count").rename(
                columns={taxonomy_column: "Taxonomic_Unit"}
            )
        )
    return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()


def fig4_build_color_map(units, palette, unknown_color):
    """Assign one palette color per taxonomic unit, cycling when units outnumber colors.

    Units are sorted first so the unit -> color assignment is stable across runs.
    'Unknown Phylum' and 'Unknown Family' receive `unknown_color` unless they were
    already among `units`.

    Raises:
        ValueError: If `palette` is empty while `units` is not.
    """
    ordered_units = sorted(units)
    if ordered_units and not palette:
        raise ValueError("palette must contain at least one color")
    color_map = {
        unit: palette[position % len(palette)]
        for position, unit in enumerate(ordered_units)
    }
    for placeholder in ("Unknown Phylum", "Unknown Family"):
        color_map.setdefault(placeholder, unknown_color)
    return color_map


def fig4_make_stacked_bar(df_counts, group_column, y_label, color_map, layout_defaults):
    """Build the stacked bar chart: one bar per group, stacked by taxonomic unit."""
    # Order the stack/legend by each unit's total count, largest first.
    unit_order = (
        df_counts.groupby("Taxonomic_Unit")["Count"]
        .sum()
        .sort_values(ascending=False)
        .index.tolist()
    )
    fig = px.bar(
        df_counts,
        x=group_column,
        y="Count",
        color="Taxonomic_Unit",
        barmode="stack",
        labels={
            group_column: "Group",
            "Count": y_label,
            "Taxonomic_Unit": "Phylum or Family",
        },
        color_discrete_map=color_map,
        category_orders={
            group_column: ["Asgard", "GV"],
            "Taxonomic_Unit": unit_order,
        },
        template="plotly_white",
    )
    fig.update_layout(layout_defaults)
    fig.update_xaxes(title_text="Group", showgrid=False)
    fig.update_yaxes(title_text=y_label, showgrid=False)
    fig.update_layout(legend_title_text="Taxonomic Unit")
    return fig


def fig4_make_donut(df_counts, color_map, layout_defaults):
    """Build a donut chart of `Count` per `Taxonomic_Unit` for a single group."""
    units = df_counts["Taxonomic_Unit"]
    fig = go.Figure(
        data=[
            go.Pie(
                labels=units,
                values=df_counts["Count"],
                hole=0.3,  # Creates a donut chart
                marker=dict(colors=[color_map.get(unit, "gray") for unit in units]),
            )
        ]
    )
    fig.update_layout(layout_defaults)
    fig.update_layout(margin=dict(t=0, b=0, l=0, r=0))
    fig.update_traces(textinfo="percent+label", insidetextorientation="radial")
    return fig


def fig4_export_figure(fig, output_dir, file_stem, label):
    """Write `fig` as <file_stem>.html, .pdf and .svg inside `output_dir`.

    HTML export is expected to work everywhere. PDF/SVG export goes through
    `write_image` (requires kaleido) and is best-effort: a failure is reported
    as a warning and does not stop the cell.
    """
    html_path = Path(output_dir) / f"{file_stem}.html"
    fig.write_html(str(html_path))
    print(f"{label} HTML saved to: {html_path}")

    for extension in ("pdf", "svg"):
        image_path = Path(output_dir) / f"{file_stem}.{extension}"
        try:
            fig.write_image(str(image_path))
            print(f"{label} {extension.upper()} saved to: {image_path}")
        except Exception as e:
            print(
                f"Warning: Could not export {label} to {extension.upper()}. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}"
            )


# --- Main ---
# Ensure df_filtered_groups is available and not empty
if "df_filtered_groups" not in globals() or df_filtered_groups.empty:
    print(
        "ERROR: df_filtered_groups not available or is empty. Please run Cell 2 first."
    )
else:
    # Define colors for Asgard and GV groups (re-defined for context):
    group_colors = {
        "Asgard": arcadia_colors_manual.get("aegean", "#5088C5"),
        "GV": arcadia_colors_manual.get("amber", "#F28360"),
    }
    # The dict key is the group label 'GV' (the color *name* is amber).
    print(
        f"\nUsing Group Colors (for context): Asgard={group_colors.get('Asgard')}, GV={group_colors.get('GV')}"
    )

    # Asgard genomes are broken down by phylum, giant viruses by family.
    fig4_taxonomy_by_group = {"Asgard": asgard_phylum_col, "GV": virus_family_col}

    # Each panel only needs its own value column, so 4A and 4B are gated separately.
    fig4_columns = set(df_filtered_groups.columns)
    fig4_taxonomy_ready = {group_col, asgard_phylum_col, virus_family_col} <= fig4_columns
    fig4_genomes_ready = (
        fig4_taxonomy_ready and source_genome_accession_col in fig4_columns
    )
    fig4_proteins_ready = fig4_taxonomy_ready and protein_id_col in fig4_columns

    # --- Calculate Genome and Protein Counts by Taxonomy ---
    print("\n--- Calculating Genome and Protein Counts by Taxonomy ---")

    # Genomes: distinct source accessions per taxonomic unit.
    df_genome_counts_taxonomy = (
        fig4_collect_counts(
            df_filtered_groups,
            group_col,
            fig4_taxonomy_by_group,
            source_genome_accession_col,
            "nunique",
        )
        if fig4_genomes_ready
        else pd.DataFrame()
    )
    # Proteins: non-null protein IDs per taxonomic unit.
    df_protein_counts_taxonomy = (
        fig4_collect_counts(
            df_filtered_groups,
            group_col,
            fig4_taxonomy_by_group,
            protein_id_col,
            "count",
        )
        if fig4_proteins_ready
        else pd.DataFrame()
    )

    if not df_genome_counts_taxonomy.empty:
        print("\nGenome Counts by Group and Taxonomy:")
        print(df_genome_counts_taxonomy.to_markdown(index=False))
    if not df_protein_counts_taxonomy.empty:
        print("\nProtein Counts by Group and Taxonomy:")
        print(df_protein_counts_taxonomy.to_markdown(index=False))

    # --- Create Comprehensive Taxonomic Unit Color Map ---
    # Built from whichever count tables exist, so Figure 4B still has a color
    # map when Figure 4A has to be skipped (and vice versa).
    # Primary, then secondary, then neutrals, then extra shades to minimize repetition.
    full_arcadia_palette = (
        arcadia_primary_palette
        + arcadia_secondary_palette
        + arcadia_neutrals_palette
        + [
            arcadia_colors_manual.get("lapis", "#2B66A2"),
            arcadia_colors_manual.get("dusk", "#094468"),  # Blue shades
            arcadia_colors_manual.get("melon", "#FFCFAF"),
            arcadia_colors_manual.get("cinnabar", "#9E3F41"),  # Orange/Red shades
            arcadia_colors_manual.get("sun", "#FFD364"),
            arcadia_colors_manual.get("mustard", "#D68D22"),  # Yellow shades
            arcadia_colors_manual.get("iris", "#DCDFEF"),
            arcadia_colors_manual.get("tanzanite", "#54448C"),  # Purple shades
            arcadia_colors_manual.get("glass", "#C3E2DB"),
            arcadia_colors_manual.get("asparagus", "#2A6B5E"),  # Teal shades
            arcadia_colors_manual.get("putty", "#FFE3D4"),
            arcadia_colors_manual.get("candy", "#E2718F"),  # Pink shades
            arcadia_colors_manual.get("stone", "#EDE6DA"),
            arcadia_colors_manual.get("dove", "#CAD4DB"),  # Warm/Cool gray shades
            arcadia_colors_manual.get("steel", "#687787"),
            arcadia_colors_manual.get("mud", "#635C5A"),  # More neutrals
        ]
    )

    all_taxonomic_units_present = set()
    for fig4_counts in (df_genome_counts_taxonomy, df_protein_counts_taxonomy):
        if not fig4_counts.empty:
            all_taxonomic_units_present.update(
                fig4_counts["Taxonomic_Unit"].dropna().unique()
            )

    combined_taxonomy_color_map = fig4_build_color_map(
        all_taxonomic_units_present,
        full_arcadia_palette,
        arcadia_colors_manual.get("gray", "#bdbdbd"),
    )
    print(
        f"\nCombined taxonomic unit color map created for {len(combined_taxonomy_color_map)} units."
    )

    # --- Figure 4A: Number of Genomes by Taxonomy (Stacked Bar and Pie Charts) ---
    print("\n--- Figure 4A: Number of Genomes by Taxonomy ---")

    if not fig4_genomes_ready:
        print(
            f"Skipping Figure 4A due to missing columns in filtered data: '{source_genome_accession_col}', '{group_col}', '{asgard_phylum_col}', or '{virus_family_col}'."
        )
    elif df_genome_counts_taxonomy.empty:
        print("No genome counts by taxonomy available for plotting Figure 4A.")
    else:
        # Part 1: stacked bar
        fig_4a_bar = fig4_make_stacked_bar(
            df_genome_counts_taxonomy,
            group_col,
            "Number of Genomes",
            combined_taxonomy_color_map,
            plotly_layout_defaults,
        )
        fig4_export_figure(
            fig_4a_bar,
            output_figure_dir,
            "figure4a_genomes_by_taxonomy_stackedbar",
            "Figure 4A (Stacked Bar)",
        )

        # Part 2: one donut per group
        print("\n--- Figure 4A Part 2: Genome Counts Pie Charts (Asgard and GV) ---")

        df_asgard_genome_pie = df_genome_counts_taxonomy[
            df_genome_counts_taxonomy[group_col] == "Asgard"
        ].sort_values("Count", ascending=False)
        if df_asgard_genome_pie.empty:
            print("No Asgard genome data for pie chart.")
        else:
            fig_4a_pie_asgard = fig4_make_donut(
                df_asgard_genome_pie, combined_taxonomy_color_map, plotly_layout_defaults
            )
            fig4_export_figure(
                fig_4a_pie_asgard,
                output_figure_dir,
                "figure4a_genomes_by_phylum_pie_asgard",
                "Figure 4A (Asgard Pie)",
            )

        df_gv_genome_pie = df_genome_counts_taxonomy[
            df_genome_counts_taxonomy[group_col] == "GV"
        ].sort_values("Count", ascending=False)
        if df_gv_genome_pie.empty:
            print("No GV genome data for pie chart.")
        else:
            fig_4a_pie_gv = fig4_make_donut(
                df_gv_genome_pie, combined_taxonomy_color_map, plotly_layout_defaults
            )
            fig4_export_figure(
                fig_4a_pie_gv,
                output_figure_dir,
                "figure4a_genomes_by_family_pie_gv",
                "Figure 4A (GV Pie)",
            )

    # --- Figure 4B: Number of Proteins by Taxonomy (Stacked Bar and Pie Charts) ---
    print("\n--- Figure 4B: Number of Proteins by Taxonomy ---")

    if not fig4_proteins_ready:
        print(
            f"Skipping Figure 4B due to missing columns in filtered data: '{protein_id_col}', '{group_col}', '{asgard_phylum_col}', or '{virus_family_col}'."
        )
    elif df_protein_counts_taxonomy.empty:
        print("No protein counts by taxonomy available for plotting Figure 4B.")
    else:
        # Part 1: stacked bar
        fig_4b_bar = fig4_make_stacked_bar(
            df_protein_counts_taxonomy,
            group_col,
            "Number of Proteins",
            combined_taxonomy_color_map,
            plotly_layout_defaults,
        )
        fig4_export_figure(
            fig_4b_bar,
            output_figure_dir,
            "figure4b_proteins_by_taxonomy_stackedbar",
            "Figure 4B (Stacked Bar)",
        )

        # Part 2: one donut per group
        print("\n--- Figure 4B Part 2: Protein Counts Pie Charts (Asgard and GV) ---")

        df_asgard_protein_pie = df_protein_counts_taxonomy[
            df_protein_counts_taxonomy[group_col] == "Asgard"
        ].sort_values("Count", ascending=False)
        if df_asgard_protein_pie.empty:
            print("No Asgard protein data for pie chart.")
        else:
            fig_4b_pie_asgard = fig4_make_donut(
                df_asgard_protein_pie,
                combined_taxonomy_color_map,
                plotly_layout_defaults,
            )
            fig4_export_figure(
                fig_4b_pie_asgard,
                output_figure_dir,
                "figure4b_proteins_by_phylum_pie_asgard",
                "Figure 4B (Asgard Pie)",
            )

        df_gv_protein_pie = df_protein_counts_taxonomy[
            df_protein_counts_taxonomy[group_col] == "GV"
        ].sort_values("Count", ascending=False)
        if df_gv_protein_pie.empty:
            print("No GV protein data for pie chart.")
        else:
            fig_4b_pie_gv = fig4_make_donut(
                df_gv_protein_pie, combined_taxonomy_color_map, plotly_layout_defaults
            )
            fig4_export_figure(
                fig_4b_pie_gv,
                output_figure_dir,
                "figure4b_proteins_by_family_pie_gv",
                "Figure 4B (GV Pie)",
            )

    # Only announce completion when the cell actually ran the figure pipeline.
    print("\n\n--- Figure 4 Generation Complete ---")
    print(f"Figures saved to the '{output_figure_dir}' directory (HTML, PDF, SVG).")

In [None]:
# This cell generates the plots for localization-function_100.png, which summarizes the predicted subcellular localization and the broad functional categories of the proteins in the database.

In [None]:
# Cell 5: Plots for localization-function_100.png - Localization and Functional Category Distribution

# --- Imports ---
# Assumes pandas, numpy, plotly.express, plotly.graph_objects, plotly.io,
# and necessary variables/helper functions (like df_full, group_col,
# protein_id_col, localization_col, broad_func_cat_col,
# arcadia_colors_manual, arcadia_primary_palette, arcadia_secondary_palette,
# arcadia_neutrals_palette, output_figure_dir, output_figure_summary_dir,
# plotly_layout_defaults, localization_color_map, broad_category_color_map)
# are defined in the setup cell (Cell 1).
# Assumes df_filtered_groups is available from Cell 2 (Figure 2 generation) and
# correctly filters for 'Asgard' and 'GV' groups.

print(
    "\n\n--- Generating Figure 3: Localization and Functional Category Distribution (Cell 5) ---"
)

# Both panels are built from df_filtered_groups; without it there is nothing to
# plot, so report the problem and skip the whole figure instead of raising.
if "df_filtered_groups" not in locals() or df_filtered_groups.empty:
    print(
        "ERROR: df_filtered_groups not available or is empty. Please run Cell 2 first."
    )
else:
    # Define colors for Asgard and GV using the specified hex codes (re-defined for clarity in this cell)
    group_colors = {
        "Asgard": arcadia_colors_manual.get("aegean", "#5088C5"),
        "GV": arcadia_colors_manual.get("amber", "#F28360"),
    }
    print(
        f"\nUsing Group Colors: Asgard={group_colors.get('Asgard')}, GV={group_colors.get('GV')}"
    )

    # --- Figure 3A: Predicted Subcellular Localization Distribution ---
    print("\n--- Figure 3A: Predicted Subcellular Localization Distribution ---")

    if (
        localization_col in df_filtered_groups.columns
        and group_col in df_filtered_groups.columns
    ):
        # Per-group counts of each localization category
        localization_counts = (
            df_filtered_groups.groupby(group_col)[localization_col]
            .value_counts()
            .reset_index(name="Count")
        )

        # Per-group percentages (each group sums to 100) used for the bar heights
        localization_percentages = (
            df_filtered_groups.groupby(group_col)[localization_col]
            .value_counts(normalize=True)
            .mul(100)
            .reset_index(name="Percentage")
        )

        # One row per (group, localization) carrying both Count and Percentage
        df_localization_summary = pd.merge(
            localization_counts,
            localization_percentages,
            on=[group_col, localization_col],
        )

        print(
            "\nPredicted Subcellular Localization Distribution (Counts and Percentages):"
        )
        # Sort by the configured group column (not a hard-coded "Group" label) so
        # this keeps working if group_col is ever renamed in the setup cell.
        print(
            df_localization_summary.sort_values(
                [group_col, "Count"], ascending=[True, False]
            ).to_markdown(index=False)
        )

        # Save full localization summary to a summary file
        full_localization_summary_path = (
            Path(output_figure_summary_dir) / "figure3a_localization_full_summary.csv"
        )
        try:
            df_localization_summary.to_csv(
                str(full_localization_summary_path), index=False
            )
            print(
                f"Full localization summary saved to: {full_localization_summary_path}"
            )
        except Exception as e:
            print(f"Error saving full localization summary: {e}")

        # Legend/bar order follows overall frequency across both groups
        overall_localization_order = (
            df_filtered_groups[localization_col].value_counts().index.tolist()
        )

        if not df_localization_summary.empty:
            fig_3a = px.bar(
                df_localization_summary,
                x=group_col,  # X-axis is Group
                y="Percentage",
                color=localization_col,  # Color is Localization Category
                barmode="group",  # Group bars by X-axis value (Group)
                # title='Predicted Subcellular Localization Distribution by Group', # No title
                labels={
                    localization_col: "Predicted Subcellular Localization",
                    "Percentage": "Percentage of Proteins (%)",
                    group_col: "Group",
                },
                color_discrete_map=localization_color_map,  # Use localization color map
                category_orders={
                    group_col: ["Asgard", "GV"],
                    localization_col: overall_localization_order,
                },  # Apply consistent order
                template="plotly_white",
            )
            # Apply default layout and remove gridlines
            fig_3a.update_layout(plotly_layout_defaults)
            fig_3a.update_xaxes(
                title_text="Group", showgrid=False
            )  # X-axis title is Group
            fig_3a.update_yaxes(title_text="Percentage of Proteins (%)", showgrid=False)
            fig_3a.update_layout(
                legend_title_text="Predicted Subcellular Localization"
            )  # Add legend title
            fig_3a.update_xaxes(tickangle=0)  # No angle needed for Group labels

            # Export Figure 3A (HTML, PDF, SVG)
            fig_3a_path_html = (
                Path(output_figure_dir) / "figure3a_localization_distribution.html"
            )
            fig_3a.write_html(str(fig_3a_path_html))
            print(f"Figure 3A HTML saved to: {fig_3a_path_html}")

            # Static exports are best-effort: a failure is reported as a warning
            # so the HTML output and the remaining panels are still produced.
            try:
                fig_3a_path_pdf = (
                    Path(output_figure_dir) / "figure3a_localization_distribution.pdf"
                )
                fig_3a.write_image(str(fig_3a_path_pdf))
                print(f"Figure 3A PDF saved to: {fig_3a_path_pdf}")
            except Exception as e:
                print(
                    f"Warning: Could not export Figure 3A to PDF. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}"
                )
            try:
                fig_3a_path_svg = (
                    Path(output_figure_dir) / "figure3a_localization_distribution.svg"
                )
                fig_3a.write_image(str(fig_3a_path_svg))
                print(f"Figure 3A SVG saved to: {fig_3a_path_svg}")
            except Exception as e:
                print(
                    f"Warning: Could not export Figure 3A to SVG. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}"
                )

        else:
            print("No localization data available for plotting Figure 3A.")

    else:
        print(
            f"Skipping Figure 3A due to missing columns in filtered data: '{localization_col}' or '{group_col}'."
        )

    # --- Figure 3B: Broad Functional Category Distribution ---
    print("\n--- Figure 3B: Broad Functional Category Distribution ---")

    if (
        broad_func_cat_col in df_filtered_groups.columns
        and group_col in df_filtered_groups.columns
    ):
        # Filter out unwanted categories for plotting (same as Figure 2D)
        categories_to_exclude = ["Unknown/Unclassified", "Other Specific Annotation"]
        df_category_plot_data = df_filtered_groups[
            (~df_filtered_groups[broad_func_cat_col].isin(categories_to_exclude))
        ].copy()

        # Check if df_category_plot_data is empty after filtering categories
        if df_category_plot_data.empty:
            print(
                "Skipping Figure 3B plotting as no functional categories remain after excluding 'Unknown/Unclassified' and 'Other Specific Annotation'."
            )
        else:
            # Count proteins per (group, category) in the FILTERED data via groupby().size()
            category_counts_plot = (
                df_category_plot_data.groupby([group_col, broad_func_cat_col])
                .size()
                .reset_index(name="Count")
            )

            # Per-group totals of the FILTERED data; these are the denominators, so
            # percentages are relative to the plotted categories only.
            total_filtered_proteins_plot_by_group = (
                df_category_plot_data.groupby(group_col)
                .size()
                .reset_index(name="Total_Filtered_Plot")
            )

            # Merge counts with total counts to calculate percentages
            df_category_plot_summary = pd.merge(
                category_counts_plot,
                total_filtered_proteins_plot_by_group,
                on=group_col,
                how="left",
            )
            df_category_plot_summary["Percentage"] = (
                df_category_plot_summary["Count"]
                / df_category_plot_summary["Total_Filtered_Plot"]
            ) * 100
            df_category_plot_summary = df_category_plot_summary.drop(
                columns=["Total_Filtered_Plot"]
            )  # Drop the temporary total column

            # Keep only combinations that actually occur; on an all-zero or empty
            # summary this simply yields an empty frame with the same columns.
            df_category_plot_summary_filtered_by_count = df_category_plot_summary[
                df_category_plot_summary["Count"] > 0
            ].copy()

            if not df_category_plot_summary_filtered_by_count.empty:
                # Define category order for plotting based on frequency in the PLOTTED data
                plot_category_order = (
                    df_category_plot_summary_filtered_by_count.groupby(
                        broad_func_cat_col
                    )["Count"]
                    .sum()
                    .sort_values(ascending=False)
                    .index.tolist()
                )

                fig_3b = px.bar(
                    df_category_plot_summary_filtered_by_count,  # Use the DataFrame filtered by count
                    x=group_col,  # X-axis is Group
                    y="Percentage",
                    color=broad_func_cat_col,  # Color is Functional Category
                    barmode="group",  # Group bars by X-axis value (Group)
                    # title='Broad Functional Category Distribution by Group', # No title
                    labels={
                        broad_func_cat_col: "Broad Functional Category",
                        "Percentage": "Percentage of Proteins (%)",
                        group_col: "Group",
                    },
                    color_discrete_map=broad_category_color_map,  # Use functional category color map
                    category_orders={
                        group_col: ["Asgard", "GV"],
                        broad_func_cat_col: plot_category_order,
                    },  # Apply consistent order
                    template="plotly_white",
                )
                # Apply default layout and remove gridlines
                fig_3b.update_layout(plotly_layout_defaults)
                fig_3b.update_xaxes(
                    title_text="Group", showgrid=False
                )  # X-axis title is Group
                fig_3b.update_yaxes(
                    title_text="Percentage of Proteins (%)", showgrid=False
                )
                fig_3b.update_layout(
                    legend_title_text="Broad Functional Category"
                )  # Add legend title
                fig_3b.update_xaxes(tickangle=0)  # No angle needed for Group labels

                # Export Figure 3B (HTML, PDF, SVG)
                fig_3b_path_html = (
                    Path(output_figure_dir)
                    / "figure3b_broad_functional_category_distribution_filtered.html"
                )
                fig_3b.write_html(str(fig_3b_path_html))
                print(f"Figure 3B HTML saved to: {fig_3b_path_html}")

                # Static exports are best-effort, as for Figure 3A
                try:
                    fig_3b_path_pdf = (
                        Path(output_figure_dir)
                        / "figure3b_broad_functional_category_distribution_filtered.pdf"
                    )
                    fig_3b.write_image(str(fig_3b_path_pdf))
                    print(f"Figure 3B PDF saved to: {fig_3b_path_pdf}")
                except Exception as e:
                    print(
                        f"Warning: Could not export Figure 3B to PDF. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}"
                    )
                try:
                    fig_3b_path_svg = (
                        Path(output_figure_dir)
                        / "figure3b_broad_functional_category_distribution_filtered.svg"
                    )
                    fig_3b.write_image(str(fig_3b_path_svg))
                    print(f"Figure 3B SVG saved to: {fig_3b_path_svg}")
                except Exception as e:
                    print(
                        f"Warning: Could not export Figure 3B to SVG. Ensure 'kaleido' is installed (`pip install kaleido`). Error: {e}"
                    )

            else:
                # Mirror Figure 3A: say why nothing was drawn instead of skipping silently
                print("No functional category data available for plotting Figure 3B.")

    else:
        print(
            f"Skipping Figure 3B due to missing columns in filtered data: '{broad_func_cat_col}' or '{group_col}'."
        )

print("\n\n--- Figure 3 Generation Complete ---")
print(f"Figures saved to the '{output_figure_dir}' directory (HTML, PDF, SVG).")
# Note: Full functional category summary is saved in Cell 2 (Figure 2D)

In [None]:
# intraorthogroup-diversity_100.png Figure

# This cell loads the orthogroup diversity metrics and generates the plots for intraorthogroup-diversity_100.png

In [None]:
# Cell 13: Figure 9 - Overview of Orthogroup Diversity Metrics

# --- Imports & Setup Assumptions ---
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from pathlib import Path

# This cell assumes that pandas (pd), numpy (np), plotly.express (px),
# plotly.graph_objects (go), Path from pathlib,
# and all necessary variables/helper functions from Cell 1 (Setup) are available.
# Key variables from Cell 1 & subsequent cells:
# - df_og_summary: DataFrame created in Cell 9, containing OG-level summaries (Orthogroup, Intra_OG_APSI, Group).
# - output_figure_dir, output_summary_dir_phase1: Directories to save plots/data.
# - plotly_layout_defaults: Default layout for Plotly figures.
# - arcadia_primary_palette, group_colors: Color palettes/maps.
# - orthogroup_col, apsi_col ('Intra_OG_APSI'), group_col

print(
    "\n\n--- Generating Figure 9: Overview of Orthogroup Diversity Metrics (Cell 13) ---"
)

# --- Configuration ---
diversity_metrics_path = (
    "orthogroup_diversity_metrics.csv"  # Path to the uploaded diversity file
)

# --- Check if necessary DataFrames are available ---
# df_og_summary must exist, be non-empty and carry the orthogroup, APSI and group
# columns; otherwise the cell reports the problem and produces no plots.
error_found = False
if "df_og_summary" not in locals() or df_og_summary.empty:
    print(
        "ERROR: DataFrame 'df_og_summary' (with OG-level APSI and Group) not found or is empty. Please run Cell 9 first."
    )
    error_found = True
elif not all(
    col in df_og_summary.columns for col in [orthogroup_col, apsi_col, group_col]
):
    missing_cols = [
        col
        for col in [orthogroup_col, apsi_col, group_col]
        if col not in df_og_summary.columns
    ]
    print(f"ERROR: 'df_og_summary' is missing required columns: {missing_cols}.")
    error_found = True

if error_found:
    print("Cannot proceed with diversity overview generation due to missing base data.")
    # Create an empty df_merged_diversity to prevent further errors in this cell if it's used later
    df_merged_diversity = pd.DataFrame()
else:
    # --- 1. Load Diversity Metrics ---
    print(
        f"\n--- Loading Orthogroup Diversity Metrics from: {diversity_metrics_path} ---"
    )
    try:
        df_diversity = pd.read_csv(diversity_metrics_path)
        # Accept files that name the key column "OG_ID" instead of orthogroup_col
        if (
            "OG_ID" in df_diversity.columns
            and orthogroup_col not in df_diversity.columns
        ):
            df_diversity = df_diversity.rename(columns={"OG_ID": orthogroup_col})

        # Select and rename relevant diversity columns
        diversity_cols_to_use = {
            orthogroup_col: orthogroup_col,
            "Tree_n_tips": "Observed_Richness",
            "Tree_entropy": "Shannon_Entropy",
        }
        # Only keep the columns that are actually present in the file
        actual_diversity_cols = {
            k: v for k, v in diversity_cols_to_use.items() if k in df_diversity.columns
        }
        df_diversity_subset = (
            df_diversity[list(actual_diversity_cols.keys())]
            .rename(columns=actual_diversity_cols)
            .copy()
        )

        # Ensure orthogroup_col is string for merging
        if orthogroup_col in df_diversity_subset.columns:
            df_diversity_subset[orthogroup_col] = df_diversity_subset[
                orthogroup_col
            ].astype(str)

        print(
            f"Loaded and processed diversity metrics for {len(df_diversity_subset)} orthogroups."
        )
        if df_diversity_subset.empty:
            print("WARNING: Diversity metrics DataFrame is empty.")

    except FileNotFoundError:
        print(
            f"ERROR: Diversity metrics file '{diversity_metrics_path}' not found. Diversity metrics will be missing."
        )
        # Empty placeholder with the expected columns so the merge step is skipped cleanly
        df_diversity_subset = pd.DataFrame(
            columns=[orthogroup_col, "Observed_Richness", "Shannon_Entropy"]
        )
    except Exception as e:
        print(f"Error loading or processing diversity metrics file: {e}")
        df_diversity_subset = pd.DataFrame(
            columns=[orthogroup_col, "Observed_Richness", "Shannon_Entropy"]
        )

    # --- 2. Merge Diversity Metrics with df_og_summary (contains APSI and Group) ---
    print("\n--- Merging Diversity Metrics with OG Summary Data (APSI, Group) ---")
    if not df_diversity_subset.empty and orthogroup_col in df_diversity_subset.columns:
        # Ensure orthogroup_col in df_og_summary is also string.
        # NOTE: this rewrites the column on the shared df_og_summary frame; the
        # cast is idempotent, so re-running the cell is safe.
        df_og_summary[orthogroup_col] = df_og_summary[orthogroup_col].astype(str)

        df_merged_diversity = pd.merge(
            df_og_summary, df_diversity_subset, on=orthogroup_col, how="inner"
        )
        # Using 'inner' merge to only keep OGs present in both (i.e., those with diversity metrics and APSI/Group)
        print(
            f"Merged diversity data. Resulting DataFrame shape: {df_merged_diversity.shape}"
        )
        if df_merged_diversity.empty:
            print(
                "WARNING: Merged diversity DataFrame is empty. No common OGs found or one of the DFs was empty."
            )
    else:
        print("Skipping merge as diversity data is empty or missing orthogroup column.")
        df_merged_diversity = pd.DataFrame()  # Ensure it's defined as empty

    # --- 3. Generate Plots ---
    if not df_merged_diversity.empty:
        # A. Distribution of Shannon Entropy
        if "Shannon_Entropy" in df_merged_diversity.columns:
            print("\n--- Plotting Distribution of Shannon Entropy ---")
            fig_shannon = px.histogram(
                df_merged_diversity.dropna(subset=["Shannon_Entropy"]),
                x="Shannon_Entropy",
                color=group_col,
                marginal="box",  # or "rug", "violin"
                barmode="overlay",
                # title="Distribution of Shannon Entropy per Orthogroup", # No title
                labels={"Shannon_Entropy": "Shannon Entropy (Tree-based)"},
                color_discrete_map=group_colors,  # from Cell 1
            )
            fig_shannon.update_layout(plotly_layout_defaults)
            fig_shannon.update_xaxes(
                title_text="Shannon Entropy (Tree-based)", showgrid=False
            )
            fig_shannon.update_yaxes(title_text="Number of Orthogroups", showgrid=False)
            fig_shannon.update_traces(opacity=0.75)
            fig_shannon.show()
            # Save plot
            fig_shannon_path_html = (
                Path(output_figure_dir) / "figure9a_shannon_entropy_distribution.html"
            )
            fig_shannon.write_html(str(fig_shannon_path_html))
            print(f"Figure 9A (Shannon Entropy) HTML saved to: {fig_shannon_path_html}")
            # Static exports are best-effort: failures are reported, not raised
            try:
                fig_shannon_path_pdf = (
                    Path(output_figure_dir)
                    / "figure9a_shannon_entropy_distribution.pdf"
                )
                fig_shannon.write_image(str(fig_shannon_path_pdf))
            except Exception as e:
                print(f"PDF export error for Shannon Entropy: {e}")
            try:
                fig_shannon_path_svg = (
                    Path(output_figure_dir)
                    / "figure9a_shannon_entropy_distribution.svg"
                )
                fig_shannon.write_image(str(fig_shannon_path_svg))
            except Exception as e:
                print(f"SVG export error for Shannon Entropy: {e}")
        else:
            print("Skipping Shannon Entropy distribution: column not found.")

        # B. Distribution of Observed Richness
        if "Observed_Richness" in df_merged_diversity.columns:
            print("\n--- Plotting Distribution of Observed Richness (Tree Tips) ---")
            fig_richness = px.histogram(
                df_merged_diversity.dropna(subset=["Observed_Richness"]),
                x="Observed_Richness",
                color=group_col,
                marginal="box",
                barmode="overlay",
                # title="Distribution of Observed Richness (Tree Tips) per Orthogroup", # No title
                labels={"Observed_Richness": "Observed Richness (Number of Tree Tips)"},
                color_discrete_map=group_colors,
            )
            fig_richness.update_layout(plotly_layout_defaults)
            fig_richness.update_xaxes(
                title_text="Observed Richness (Number of Tree Tips)", showgrid=False
            )  # Consider log scale if highly skewed: type="log"
            fig_richness.update_yaxes(
                title_text="Number of Orthogroups", showgrid=False
            )
            fig_richness.update_traces(opacity=0.75)
            fig_richness.show()
            # Save plot
            fig_richness_path_html = (
                Path(output_figure_dir) / "figure9b_observed_richness_distribution.html"
            )
            fig_richness.write_html(str(fig_richness_path_html))
            print(
                f"Figure 9B (Observed Richness) HTML saved to: {fig_richness_path_html}"
            )
            try:
                fig_richness_path_pdf = (
                    Path(output_figure_dir)
                    / "figure9b_observed_richness_distribution.pdf"
                )
                fig_richness.write_image(str(fig_richness_path_pdf))
            except Exception as e:
                print(f"PDF export error for Richness: {e}")
            try:
                fig_richness_path_svg = (
                    Path(output_figure_dir)
                    / "figure9b_observed_richness_distribution.svg"
                )
                fig_richness.write_image(str(fig_richness_path_svg))
            except Exception as e:
                print(f"SVG export error for Richness: {e}")
        else:
            print("Skipping Observed Richness distribution: column not found.")

        # C. Shannon Entropy vs. Observed Richness
        if (
            "Shannon_Entropy" in df_merged_diversity.columns
            and "Observed_Richness" in df_merged_diversity.columns
        ):
            print("\n--- Plotting Shannon Entropy vs. Observed Richness ---")
            fig_ent_vs_rich = px.scatter(
                df_merged_diversity.dropna(
                    subset=["Shannon_Entropy", "Observed_Richness"]
                ),
                x="Observed_Richness",
                y="Shannon_Entropy",
                color=group_col,
                # title='Shannon Entropy vs. Observed Richness per Orthogroup', # No title
                labels={
                    "Observed_Richness": "Observed Richness (Tree Tips)",
                    "Shannon_Entropy": "Shannon Entropy (Tree-based)",
                },
                color_discrete_map=group_colors,
                opacity=0.7,
                trendline="ols",  # Ordinary Least Squares trendline (needs statsmodels installed)
                trendline_scope="overall",  # or "trace" for per-group trendlines
            )
            fig_ent_vs_rich.update_layout(plotly_layout_defaults)
            fig_ent_vs_rich.update_xaxes(
                title_text="Observed Richness (Tree Tips)", showgrid=False
            )  # Consider log scale: type="log"
            fig_ent_vs_rich.update_yaxes(
                title_text="Shannon Entropy (Tree-based)", showgrid=False
            )
            fig_ent_vs_rich.show()
            # Save plot
            fig_ent_rich_path_html = (
                Path(output_figure_dir) / "figure9c_entropy_vs_richness.html"
            )
            fig_ent_vs_rich.write_html(str(fig_ent_rich_path_html))
            print(
                f"Figure 9C (Entropy vs Richness) HTML saved to: {fig_ent_rich_path_html}"
            )
            try:
                fig_ent_rich_path_pdf = (
                    Path(output_figure_dir) / "figure9c_entropy_vs_richness.pdf"
                )
                fig_ent_vs_rich.write_image(str(fig_ent_rich_path_pdf))
            except Exception as e:
                print(f"PDF export error for Entropy vs Richness: {e}")
            try:
                fig_ent_rich_path_svg = (
                    Path(output_figure_dir) / "figure9c_entropy_vs_richness.svg"
                )
                fig_ent_vs_rich.write_image(str(fig_ent_rich_path_svg))
            except Exception as e:
                print(f"SVG export error for Entropy vs Richness: {e}")
        else:
            print(
                "Skipping Shannon Entropy vs. Observed Richness plot: columns not found."
            )

        # Derive the percentage-scaled APSI once, before BOTH plots that use it.
        # It used to be created only inside plot D, so plot E raised a KeyError
        # whenever Shannon_Entropy was absent but Observed_Richness was present.
        # The value is multiplied by 100 unconditionally — assumes apsi_col holds
        # fractions in [0, 1]; TODO confirm against the cell that builds df_og_summary.
        if apsi_col in df_merged_diversity.columns:
            df_merged_diversity["APSI_Percent"] = df_merged_diversity[apsi_col] * 100

        # D. Shannon Entropy vs. APSI
        if (
            "Shannon_Entropy" in df_merged_diversity.columns
            and apsi_col in df_merged_diversity.columns
        ):
            print("\n--- Plotting Shannon Entropy vs. APSI ---")

            fig_ent_vs_apsi = px.scatter(
                df_merged_diversity.dropna(subset=["Shannon_Entropy", "APSI_Percent"]),
                x="APSI_Percent",
                y="Shannon_Entropy",
                color=group_col,
                # title='Shannon Entropy vs. Intra-OG APSI', # No title
                labels={
                    "APSI_Percent": "Intra-OG APSI (%)",
                    "Shannon_Entropy": "Shannon Entropy (Tree-based)",
                },
                color_discrete_map=group_colors,
                opacity=0.7,
                trendline="ols",
                trendline_scope="overall",
            )
            fig_ent_vs_apsi.update_layout(plotly_layout_defaults)
            fig_ent_vs_apsi.update_xaxes(title_text="Intra-OG APSI (%)", showgrid=False)
            fig_ent_vs_apsi.update_yaxes(
                title_text="Shannon Entropy (Tree-based)", showgrid=False
            )
            fig_ent_vs_apsi.show()
            # Save plot
            fig_ent_apsi_path_html = (
                Path(output_figure_dir) / "figure9d_entropy_vs_apsi.html"
            )
            fig_ent_vs_apsi.write_html(str(fig_ent_apsi_path_html))
            print(
                f"Figure 9D (Entropy vs APSI) HTML saved to: {fig_ent_apsi_path_html}"
            )
            try:
                fig_ent_apsi_path_pdf = (
                    Path(output_figure_dir) / "figure9d_entropy_vs_apsi.pdf"
                )
                fig_ent_vs_apsi.write_image(str(fig_ent_apsi_path_pdf))
            except Exception as e:
                print(f"PDF export error for Entropy vs APSI: {e}")
            try:
                fig_ent_apsi_path_svg = (
                    Path(output_figure_dir) / "figure9d_entropy_vs_apsi.svg"
                )
                fig_ent_vs_apsi.write_image(str(fig_ent_apsi_path_svg))
            except Exception as e:
                print(f"SVG export error for Entropy vs APSI: {e}")
        else:
            print("Skipping Shannon Entropy vs. APSI plot: columns not found.")

        # E. Observed Richness vs. APSI
        if (
            "Observed_Richness" in df_merged_diversity.columns
            and apsi_col in df_merged_diversity.columns
        ):
            print("\n--- Plotting Observed Richness vs. APSI ---")
            # APSI_Percent is guaranteed here: it is derived above whenever apsi_col exists

            fig_rich_vs_apsi = px.scatter(
                df_merged_diversity.dropna(
                    subset=["Observed_Richness", "APSI_Percent"]
                ),
                x="APSI_Percent",
                y="Observed_Richness",
                color=group_col,
                # title='Observed Richness vs. Intra-OG APSI', # No title
                labels={
                    "APSI_Percent": "Intra-OG APSI (%)",
                    "Observed_Richness": "Observed Richness (Tree Tips)",
                },
                color_discrete_map=group_colors,
                opacity=0.7,
                trendline="ols",
                trendline_scope="overall",
            )
            fig_rich_vs_apsi.update_layout(plotly_layout_defaults)
            fig_rich_vs_apsi.update_xaxes(
                title_text="Intra-OG APSI (%)", showgrid=False
            )
            fig_rich_vs_apsi.update_yaxes(
                title_text="Observed Richness (Tree Tips)", showgrid=False
            )  # Consider log scale: type="log"
            fig_rich_vs_apsi.show()
            # Save plot
            fig_rich_apsi_path_html = (
                Path(output_figure_dir) / "figure9e_richness_vs_apsi.html"
            )
            fig_rich_vs_apsi.write_html(str(fig_rich_apsi_path_html))
            print(
                f"Figure 9E (Richness vs APSI) HTML saved to: {fig_rich_apsi_path_html}"
            )
            try:
                fig_rich_apsi_path_pdf = (
                    Path(output_figure_dir) / "figure9e_richness_vs_apsi.pdf"
                )
                fig_rich_vs_apsi.write_image(str(fig_rich_apsi_path_pdf))
            except Exception as e:
                print(f"PDF export error for Richness vs APSI: {e}")
            try:
                fig_rich_apsi_path_svg = (
                    Path(output_figure_dir) / "figure9e_richness_vs_apsi.svg"
                )
                fig_rich_vs_apsi.write_image(str(fig_rich_apsi_path_svg))
            except Exception as e:
                print(f"SVG export error for Richness vs APSI: {e}")
        else:
            print("Skipping Observed Richness vs. APSI plot: columns not found.")

        # Save the merged diversity data (includes the derived APSI_Percent column)
        merged_diversity_summary_path = (
            Path(output_summary_dir_phase1)
            / "figure9_merged_diversity_and_apsi_data.csv"
        )
        try:
            df_merged_diversity.to_csv(merged_diversity_summary_path, index=False)
            print(
                f"\nMerged diversity and APSI data saved to: {merged_diversity_summary_path}"
            )
        except Exception as e:
            print(f"Error saving merged diversity data: {e}")

    else:
        print("Skipping Figure 9 plots as 'df_merged_diversity' is empty.")

print(
    "\n\n--- Cell 13 (Figure 9 - Overview of Orthogroup Diversity Metrics) Complete ---"
)

In [None]:
# This cell generates plots to show which eukaryotic organisms frequently appear as top hits in DIAMOND searches. These plots comprise Figure Eukaryote-homologs_85.png

In [None]:
# Cell 21: Figure Eukaryote-homologs_85.png - Characterizing Eukaryotic Hits for Asgard and Giant Virus Proteins (Consistent Colors)

# --- Imports & Setup Assumptions ---
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from pathlib import Path

# This cell assumes that pandas (pd), numpy (np), plotly.express (px),
# plotly.graph_objects (go), Path from pathlib, make_subplots,
# and all necessary variables/helper functions from Cell 1 (Setup) are available.
# Key variables from Cell 1:
# - output_figure_dir, output_summary_dir_phase1: Directories to save plots/data.
# - plotly_layout_defaults: Default layout for Plotly figures.
# - arcadia_primary_palette, arcadia_secondary_palette, arcadia_neutrals_palette: Color palettes.
# - group_col, protein_id_col,
# - 'Has_Euk_DIAMOND_Hit', 'Euk_Hit_Organism', 'Euk_Hit_Protein_Name' (column names)
# - clean_protein_name (helper function from Cell 1)

# Banner so the notebook output shows where Cell 21 begins.
print(
    "\n\n--- Generating Figure Eukaryote-homologs_85.png: "
    "Characterizing Eukaryotic Hits (Consistent Colors) (Cell 21) ---"
)

# --- Configuration ---
# INPUT: Database with GV Euk Hits.
# NOTE(review): the constant is named "V2_4" but the file it points at is v2.5 —
# confirm which version is intended before renaming either one.
DB_PATH_V2_4 = "proteome_database_v2.5.csv"
# Maximum number of bars per panel: organisms (Panels A/B) and protein names (Panel C).
TOP_N_ORGANISMS, TOP_N_PROTEIN_NAMES = 25, 30

# --- Load Data ---
# Start pessimistic: the error flag is cleared only once the file has been read,
# validated and normalised without raising.
df_db_v2_4 = pd.DataFrame()
error_found_loading = True

if Path(DB_PATH_V2_4).is_file():
    try:
        df_db_v2_4 = pd.read_csv(DB_PATH_V2_4, low_memory=False)
        required_cols = [
            group_col,
            "Has_Euk_DIAMOND_Hit",
            "Euk_Hit_Organism",
            "Euk_Hit_Protein_Name",
            protein_id_col,
        ]
        missing = [col for col in required_cols if col not in df_db_v2_4.columns]
        if missing:
            print(
                f"ERROR: Loaded database '{DB_PATH_V2_4}' is missing critical columns: {missing}"
            )
        else:
            print(
                f"Successfully loaded database '{DB_PATH_V2_4}'. Shape: {df_db_v2_4.shape}"
            )
            # A missing hit flag is treated as "no hit".
            hit_flags = df_db_v2_4["Has_Euk_DIAMOND_Hit"].fillna(False)
            df_db_v2_4["Has_Euk_DIAMOND_Hit"] = hit_flags.astype(bool)

            # Missing organism / protein names become the literal "Unknown";
            # organism names are additionally stripped of surrounding whitespace.
            organism_names = df_db_v2_4["Euk_Hit_Organism"].fillna("Unknown")
            df_db_v2_4["Euk_Hit_Organism"] = organism_names.astype(str).str.strip()

            protein_names = df_db_v2_4["Euk_Hit_Protein_Name"].fillna("Unknown")
            df_db_v2_4["Euk_Hit_Protein_Name"] = protein_names.astype(str)

            error_found_loading = False
    except Exception as e:
        print(f"Error loading database '{DB_PATH_V2_4}': {e}")
else:
    print(
        f"ERROR: Database file '{DB_PATH_V2_4}' not found. Please ensure it was created correctly."
    )

if error_found_loading:
    print("Cannot proceed with Figure 13 generation due to missing input data.")
else:
    # --- Prepare Data for Panels A and B ---
    # Rows flagged as having a eukaryotic DIAMOND hit; shared by both subsets.
    has_euk_hit_mask = df_db_v2_4["Has_Euk_DIAMOND_Hit"] == True

    # Panel A input: Asgard proteins with a eukaryotic hit.
    is_asgard_mask = df_db_v2_4[group_col] == "Asgard"
    df_asgard_hits_panel_a = df_db_v2_4[is_asgard_mask & has_euk_hit_mask].copy()

    # Panel B (and C) input: giant-virus proteins with a eukaryotic hit.
    # .get() is kept from the original as a defensive lookup of the group column.
    is_gv_mask = df_db_v2_4.get(group_col) == "GV"
    df_gv_hits_panel_b = df_db_v2_4[is_gv_mask & has_euk_hit_mask].copy()

    # --- Create a Consistent Color Map for Eukaryotic Organisms ---
    print("\n--- Creating Consistent Color Map for Eukaryotic Organisms ---")

    # Top-N organism names per group; an empty hit table contributes nothing.
    top_asgard_orgs_list = (
        []
        if df_asgard_hits_panel_a.empty
        else df_asgard_hits_panel_a["Euk_Hit_Organism"]
        .value_counts()
        .nlargest(TOP_N_ORGANISMS)
        .index.tolist()
    )
    top_gv_orgs_list = (
        []
        if df_gv_hits_panel_b.empty
        else df_gv_hits_panel_b["Euk_Hit_Organism"]
        .value_counts()
        .nlargest(TOP_N_ORGANISMS)
        .index.tolist()
    )

    # Union of both top lists, minus "Unknown" (it gets a fixed colour below).
    # Sorting makes the organism -> colour assignment deterministic.
    candidate_organisms = set(top_asgard_orgs_list) | set(top_gv_orgs_list)
    candidate_organisms.discard("Unknown")
    all_top_organisms = sorted(candidate_organisms)

    # Use a combined palette for more variety if many unique organisms
    combined_palette = (
        arcadia_primary_palette + arcadia_secondary_palette + arcadia_neutrals_palette
    )

    # "Unknown" always maps to a neutral gray; every other organism cycles
    # through the combined palette in sorted order.
    euk_organism_color_map = {"Unknown": "#cccccc"}
    euk_organism_color_map.update(
        (org_name, combined_palette[idx % len(combined_palette)])
        for idx, org_name in enumerate(all_top_organisms)
    )

    print(
        f"Created color map for {len(euk_organism_color_map)} unique eukaryotic organisms (including Unknown)."
    )

    # --- Create Figure with Subplots ---
    # Three panels stacked in a single column; titles are listed top to bottom.
    panel_titles = (
        "<b>A.</b> Top Eukaryotic Organisms Hit by Asgard Proteins",
        "<b>B.</b> Top Eukaryotic Organisms Hit by Giant Virus Proteins",
        "<b>C.</b> Top Eukaryotic Protein Functions Hit by Giant Virus Proteins",
    )
    fig = make_subplots(
        rows=len(panel_titles),
        cols=1,
        subplot_titles=panel_titles,
        vertical_spacing=0.1,
    )

    # --- Panels A and B: Top Eukaryotic Organisms Hit by Asgard / Giant Virus Proteins ---
    def _add_top_organism_panel(
        figure,
        df_hits,
        row,
        panel_letter,
        group_short,
        group_long,
        trace_name,
        csv_filename,
    ):
        """Draw one "top eukaryotic organisms" bar panel and save its counts.

        Panels A and B run the same steps and differ only in labels, so the
        shared logic lives here instead of being copy-pasted.

        Args:
            figure: Subplot figure to draw into.
            df_hits: Proteins of one group that have a eukaryotic hit; must
                contain an "Euk_Hit_Organism" column.
            row: 1-based subplot row of ``figure`` to draw into (column is 1).
            panel_letter: Panel label ("A" / "B") used in log messages.
            group_short: Short group label used in messages and the y-axis title.
            group_long: Long group label used in the panel banner message.
            trace_name: Name given to the bar trace.
            csv_filename: File name, inside ``output_summary_dir_phase1``, that
                receives the full (not top-N) organism counts.

        Returns:
            The top-N counts DataFrame that was plotted, or None when there was
            nothing to plot.
        """
        print(
            f"\n--- Generating Panel {panel_letter}: {group_long} Eukaryotic Hit Organisms ---"
        )
        if df_hits.empty:
            print(
                f"No {group_short} proteins with eukaryotic hits found for Panel {panel_letter}."
            )
            return None

        org_counts = (
            df_hits["Euk_Hit_Organism"]
            .value_counts()
            .nlargest(TOP_N_ORGANISMS)
            .reset_index()
        )
        org_counts.columns = ["Euk_Hit_Organism", "Count"]
        if org_counts.empty:
            print(
                f"No eukaryotic hits found for {group_short} proteins to plot for Panel {panel_letter}."
            )
            return None

        figure.add_trace(
            go.Bar(
                x=org_counts["Euk_Hit_Organism"],
                y=org_counts["Count"],
                name=trace_name,  # Not shown: the figure's legend is disabled later
                # Same organism -> same colour in every panel; organisms outside
                # the shared map fall back to a mid gray.
                marker_color=[
                    euk_organism_color_map.get(org, "#999999")
                    for org in org_counts["Euk_Hit_Organism"]
                ],
            ),
            row=row,
            col=1,
        )
        figure.update_xaxes(
            title_text="",
            tickangle=45,
            categoryorder="total descending",
            row=row,
            col=1,
        )
        figure.update_yaxes(
            title_text=f"Number of {group_short} Proteins", row=row, col=1
        )

        # Persist the complete organism counts (not only the plotted top N).
        # Saving is best-effort: a failure is reported but does not stop the figure.
        data_path = Path(output_summary_dir_phase1) / csv_filename
        try:
            df_hits["Euk_Hit_Organism"].value_counts().reset_index().to_csv(
                data_path, index=False
            )
            print(f"Data for Panel {panel_letter} saved to {data_path}")
        except Exception as e:
            print(f"Error saving Panel {panel_letter} data: {e}")
        return org_counts

    asgard_org_counts = _add_top_organism_panel(
        fig,
        df_asgard_hits_panel_a,
        row=1,
        panel_letter="A",
        group_short="Asgard",
        group_long="Asgard",
        trace_name="Asgard Hits",
        csv_filename="figure13_panel_a_asgard_euk_org_hits.csv",
    )
    gv_org_counts = _add_top_organism_panel(
        fig,
        df_gv_hits_panel_b,
        row=2,
        panel_letter="B",
        group_short="GV",
        group_long="Giant Virus",
        trace_name="GV Hits",
        csv_filename="figure13_panel_b_gv_euk_org_hits.csv",
    )

    # --- Panel C: Top Eukaryotic Protein Functions Hit by Giant Virus Proteins ---
    print("\n--- Generating Panel C: Giant Virus Eukaryotic Hit Protein Functions ---")
    if df_gv_hits_panel_b.empty:
        print("No GV proteins with eukaryotic hits found for Panel C analysis.")
    else:
        raw_hit_names = df_gv_hits_panel_b["Euk_Hit_Protein_Name"]
        # Prefer the shared name-cleaning helper; fall back to raw names if it
        # is not defined in this session.
        if "clean_protein_name" in locals() and callable(clean_protein_name):
            df_gv_hits_panel_b["Cleaned_Euk_Hit_Protein_Name"] = raw_hit_names.apply(
                clean_protein_name
            )
        else:
            print(
                "WARNING: 'clean_protein_name' function not found. Using raw Euk_Hit_Protein_Name for Panel C."
            )
            df_gv_hits_panel_b["Cleaned_Euk_Hit_Protein_Name"] = raw_hit_names

        # Count every cleaned name first; the unfiltered counts are what is saved.
        gv_prot_name_counts_all = df_gv_hits_panel_b[
            "Cleaned_Euk_Hit_Protein_Name"
        ].value_counts()

        # Filter out very generic names AFTER counting, before selecting top N for plot
        is_generic_name = gv_prot_name_counts_all.index.astype(str).str.contains(
            "Hypothetical|Unknown|Uncharacterized|^nan$",
            case=False,
            na=False,
            regex=True,
        )
        gv_prot_name_counts_filtered = gv_prot_name_counts_all[~is_generic_name]

        gv_prot_name_counts_plot = gv_prot_name_counts_filtered.nlargest(
            TOP_N_PROTEIN_NAMES
        ).reset_index()
        gv_prot_name_counts_plot.columns = ["Cleaned_Euk_Hit_Protein_Name", "Count"]

        if gv_prot_name_counts_plot.empty:
            print(
                "No eukaryotic protein names found for GV hits to plot for Panel C after filtering generic names."
            )
        else:
            # Single-colour bars from the secondary palette (Panels A/B use the
            # per-organism colour map instead).
            panel_c_bar_color = arcadia_secondary_palette[
                0 % len(arcadia_secondary_palette)
            ]
            panel_c_trace = go.Bar(
                x=gv_prot_name_counts_plot["Cleaned_Euk_Hit_Protein_Name"],
                y=gv_prot_name_counts_plot["Count"],
                name="GV Hit Protein Functions",
                marker_color=panel_c_bar_color,
            )
            fig.add_trace(panel_c_trace, row=3, col=1)
            fig.update_xaxes(
                title_text="Eukaryotic Protein Function (Cleaned Name)",
                tickangle=45,
                categoryorder="total descending",
                row=3,
                col=1,
            )
            fig.update_yaxes(title_text="Number of GV Protein Hits", row=3, col=1)

            # Save all counts (including the generic names filtered from the plot).
            panel_c_data_path = (
                Path(output_summary_dir_phase1)
                / "figure13_panel_c_gv_euk_prot_func_hits.csv"
            )
            try:
                gv_prot_name_counts_all.reset_index().to_csv(
                    panel_c_data_path, index=False
                )
                print(f"Data for Panel C saved to {panel_c_data_path}")
            except Exception as e:
                print(f"Error saving Panel C data: {e}")

    # --- Final Layout Updates for the Entire Figure ---
    fig.update_layout(
        height=1400,  # Increased height for better label spacing
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
        margin=dict(l=80, r=50, t=100, b=220),  # Increased bottom margin
        showlegend=False,
    )

    # Everything below copies styling from the notebook-wide defaults, so it is
    # skipped as a whole when those defaults are not defined in this session.
    if "plotly_layout_defaults" in locals():
        fig.update_layout(font=plotly_layout_defaults.font)

        # Panel titles are the annotations whose text starts with "<b>".
        for annotation in fig.layout.annotations:
            if annotation.text.startswith("<b>"):
                annotation.font.size = 14
                if plotly_layout_defaults.font:
                    annotation.font.family = plotly_layout_defaults.font.family

        for axis_name_template in ["xaxis", "yaxis"]:
            for i in range(1, 4):
                # Axis attribute names are "xaxis", "xaxis2", "xaxis3" (same for y).
                axis_ref_name = f"{axis_name_template}{i if i > 1 else ''}"
                if not (
                    hasattr(fig.layout, axis_ref_name)
                    and hasattr(plotly_layout_defaults, axis_name_template)
                ):
                    continue

                default_axis_style = getattr(plotly_layout_defaults, axis_name_template)
                current_axis = getattr(fig.layout, axis_ref_name)

                # The original if/elif applied identical assignments in both
                # branches, so the title font is copied whenever the defaults
                # define one, whether or not this axis already has title text.
                if default_axis_style.title and default_axis_style.title.font:
                    default_title_font = default_axis_style.title.font
                    current_axis.title.font.size = default_title_font.size
                    current_axis.title.font.family = default_title_font.family
                    current_axis.title.font.weight = default_title_font.weight

                if default_axis_style.tickfont:
                    current_axis.tickfont.size = default_axis_style.tickfont.size
                    current_axis.tickfont.family = default_axis_style.tickfont.family

                # Fixed axis chrome: no grid or zero line, black axis line and
                # outside ticks.
                current_axis.showgrid = False
                current_axis.zeroline = False
                current_axis.linecolor = "black"
                current_axis.linewidth = 1.5
                current_axis.ticks = "outside"
                current_axis.ticklen = 5
                current_axis.tickwidth = 1.5
                current_axis.tickcolor = "black"
    fig.show()

    # --- Save Figure ---
    euk_fig_dir = Path(output_figure_dir)

    # The HTML export is not guarded: a failure here raises.
    fig_path_html = (
        euk_fig_dir / "figure_eukaryotic_hit_characterization_consistent_colors.html"
    )
    fig.write_html(str(fig_path_html))
    print(
        f"Figure Eukaryote-homologs_85.png (Consistent Colors) HTML saved to: {fig_path_html}"
    )

    # Static exports are best-effort: a failure is reported as a warning and
    # the cell carries on.
    try:
        fig_path_pdf = (
            euk_fig_dir / "figure_eukaryotic_hit_characterization_consistent_colors.pdf"
        )
        fig.write_image(str(fig_path_pdf), width=900, height=1400)
        print(f"Figure Eukaryote-homologs_85.png PDF saved to: {fig_path_pdf}")
    except Exception as e:
        print(f"Warning: Could not export Figure 13 to PDF. Error: {e}")

    # NOTE(review): the SVG file name carries a "figure13_" prefix while the
    # HTML and PDF names use "figure_" — confirm which naming is intended.
    try:
        fig_path_svg = (
            euk_fig_dir
            / "figure13_eukaryotic_hit_characterization_consistent_colors.svg"
        )
        fig.write_image(str(fig_path_svg), width=900, height=1400)
        print(f"Figure Eukaryote-homologs_85.png SVG saved to: {fig_path_svg}")
    except Exception as e:
        print(f"Warning: Could not export Figure 13 to SVG. Error: {e}")

# Completion banner: printed whether or not the figure could be generated.
print(
    "\n\n--- Cell 21 "
    "(Figure Eukaryote-homologs_85.png - Eukaryotic Hit Characterization with Consistent Colors) "
    "Complete ---"
)

In [None]:
# This cell reproduces the orthogroup diversity/functional enrichment heatmap with the green-purple color scale used in the publication.

In [None]:
# Cell 20B (New): Regenerate Functional Enrichment Heatmap with Green-Purple Colorscale

# --- Imports & Setup ---
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from pathlib import Path

# Import the necessary components from arcadia-pycolor.
# If the package is unavailable, minimal stand-ins are defined so the rest of
# the cell still runs (the plot will look wrong: grey colours, Viridis scale).
try:
    from arcadia_pycolor import colors
    from arcadia_pycolor.gradient import Gradient
except ImportError:
    print(
        "ERROR: arcadia-pycolor is not installed. Please install it: pip install arcadia-pycolor"
    )

    class colors:
        """Stand-in palette: every colour name used by this cell is "grey"."""

        fern = "grey"
        lime = "grey"
        lichen = "grey"
        lilac = "grey"
        aster = "grey"
        ghost = "grey"

    class Gradient:
        """Stand-in gradient that ignores its inputs and yields "Viridis"."""

        def __init__(self, name, color_list, pos_list):
            """Accept the constructor arguments and store nothing."""

        def __add__(self, other):
            """Concatenation is a no-op: the left operand is returned."""
            return self

        def reverse(self):
            """Reversal is a no-op: the same object is returned."""
            return self

        def to_plotly_colorscale(self):
            """Return the name of a built-in Plotly colorscale."""
            return "Viridis"

else:
    print("Successfully imported arcadia_pycolor components.")


# --- Configuration (should match previous cells) ---
# Standard internal column names used by this cell.
# NOTE(review): group_col is re-assigned here; earlier cells read the same name
# from the setup cell — confirm both agree on "Group".
orthogroup_col, group_col = "OG_ID", "Group"
broad_func_cat_col, apsi_col = "Broad_Functional_Category", "Intra_OG_APSI"

# --- Paths to input files ---
# Main database, which has the functional categories.
# IMPORTANT: Point this to your main (corrected) DB.
main_database_path = "proteome_database_v3.5.csv"
# Summary file with extreme OG flags from Cell 15
extreme_ogs_summary_path = Path(
    "output_summary_data_hit_validation_phase1",
    "figure11_extreme_diversity_ogs_summary_expanded_revised.csv",
)
# Summary file with diversity metrics from Cell 13
merged_diversity_path = Path(
    "output_summary_data_hit_validation_phase1",
    "figure9_merged_diversity_and_apsi_data.csv",
)

# Profiles to include in the heatmap: the key is the text searched for in
# "Reason_for_Highlight", the value is the short label used in column names.
profiles_for_heatmap = dict(
    [
        ("High Entropy & High APSI", "Hi Ent, Hi APSI"),
        ("Low Entropy & Low APSI", "Lo Ent, Lo APSI"),
    ]
)

# Categories to exclude from the heatmap for clarity
categories_to_exclude_from_heatmap = [
    "Other Specific Annotation",
    "General Protein Features",
    "Unknown/Unclassified",
]

# --- Load Data ---
# Each input is read independently. A missing file is reported, replaced by an
# empty DataFrame and recorded in error_found; other read errors propagate.
# After a successful read, a column named "Orthogroup" is renamed to
# orthogroup_col unless that column already exists.
print("\n--- Loading necessary data for heatmap generation ---")
error_found = False

# 1) Extreme-OG summary
try:
    df_extreme_ogs_summary = pd.read_csv(extreme_ogs_summary_path)
except FileNotFoundError:
    print(
        f"ERROR: Extreme OGs summary file not found at '{extreme_ogs_summary_path}'. Please run Cell 15 first."
    )
    df_extreme_ogs_summary = pd.DataFrame()
    error_found = True
else:
    extreme_cols = df_extreme_ogs_summary.columns
    if "Orthogroup" in extreme_cols and orthogroup_col not in extreme_cols:
        df_extreme_ogs_summary = df_extreme_ogs_summary.rename(
            columns={"Orthogroup": orthogroup_col}
        )
    print(f"Loaded extreme OGs summary from: {extreme_ogs_summary_path}")

# 2) Merged diversity metrics
try:
    df_merged_diversity = pd.read_csv(merged_diversity_path)
except FileNotFoundError:
    print(
        f"ERROR: Merged diversity data file not found at '{merged_diversity_path}'. Please run Cell 13 first."
    )
    df_merged_diversity = pd.DataFrame()
    error_found = True
else:
    diversity_cols = df_merged_diversity.columns
    if "Orthogroup" in diversity_cols and orthogroup_col not in diversity_cols:
        df_merged_diversity = df_merged_diversity.rename(
            columns={"Orthogroup": orthogroup_col}
        )
    print(f"Loaded merged diversity data from: {merged_diversity_path}")

# 3) Main database (source of the functional categories)
try:
    df_full = pd.read_csv(main_database_path, low_memory=False)
except FileNotFoundError:
    print(
        f"ERROR: Main database file not found at '{main_database_path}'. Cannot get functional categories."
    )
    df_full = pd.DataFrame()
    error_found = True
else:
    # Header names are stripped of surrounding whitespace before any lookup.
    df_full.columns = df_full.columns.str.strip()
    if "Orthogroup" in df_full.columns and orthogroup_col not in df_full.columns:
        df_full = df_full.rename(columns={"Orthogroup": orthogroup_col})
    print(
        f"Loaded main database from '{main_database_path}' to get functional categories."
    )


# --- Data Preparation for Heatmap ---
if not error_found:
    print("\n--- Preparing data for heatmap ---")

    # --- Merge Broad_Functional_Category into df_merged_diversity ---
    # Only needed when the diversity table does not already carry the category.
    if broad_func_cat_col not in df_merged_diversity.columns:
        print(
            f"'{broad_func_cat_col}' not in diversity data, merging from main database..."
        )
        if orthogroup_col not in df_full.columns:
            print(
                f"ERROR: Cannot merge categories because '{orthogroup_col}' is missing from df_full."
            )
            error_found = True
        elif broad_func_cat_col not in df_full.columns:
            # Without this check the column selection below raises KeyError.
            print(
                f"ERROR: Cannot merge categories because '{broad_func_cat_col}' is missing from df_full."
            )
            error_found = True
        else:
            # One representative functional category per orthogroup: the first
            # row encountered for each orthogroup wins.
            og_categories = df_full[
                [orthogroup_col, broad_func_cat_col]
            ].drop_duplicates(subset=[orthogroup_col])
            # Left merge keeps every diversity row, matched or not.
            df_merged_diversity = pd.merge(
                df_merged_diversity, og_categories, on=orthogroup_col, how="left"
            )
            # Assign the filled column back. Calling fillna(inplace=True) on the
            # selected column is a chained in-place update, which pandas warns
            # about and which does not modify the frame under Copy-on-Write.
            df_merged_diversity[broad_func_cat_col] = df_merged_diversity[
                broad_func_cat_col
            ].fillna("Unknown/Unclassified")
            print("Merge complete.")

    # Check for group column and standardize its name too
    if (
        "Source_Dataset" in df_merged_diversity.columns
        and group_col not in df_merged_diversity.columns
    ):
        df_merged_diversity = df_merged_diversity.rename(
            columns={"Source_Dataset": group_col}
        )

    if not error_found:
        # Long-format records (one per category x profile/group), filled below.
        heatmap_data_list = []

        # Calculate baseline functional distributions AFTER excluding unwanted categories
        is_kept_category = ~df_merged_diversity[broad_func_cat_col].isin(
            categories_to_exclude_from_heatmap
        )

        # Baseline = percentage of each remaining category among the rows of
        # one group in the diversity table.
        asgard_baseline_rows = df_merged_diversity[
            (df_merged_diversity[group_col] == "Asgard") & is_kept_category
        ]
        baseline_func_dist_asgard = (
            asgard_baseline_rows[broad_func_cat_col]
            .value_counts(normalize=True)
            .mul(100)
        )

        gv_baseline_rows = df_merged_diversity[
            (df_merged_diversity[group_col] == "GV") & is_kept_category
        ]
        baseline_func_dist_gv = (
            gv_baseline_rows[broad_func_cat_col].value_counts(normalize=True).mul(100)
        )

        for profile_key, short_profile_label in profiles_for_heatmap.items():
            # Neither of these masks depends on the group, so build them once
            # per profile. The profile match is a literal substring test.
            in_profile = df_extreme_ogs_summary["Reason_for_Highlight"].str.contains(
                profile_key, regex=False, na=False
            )
            has_kept_category = ~df_extreme_ogs_summary[broad_func_cat_col].isin(
                categories_to_exclude_from_heatmap
            )

            for group_val, baseline_dist in (
                ("Asgard", baseline_func_dist_asgard),
                ("GV", baseline_func_dist_gv),
            ):
                df_profile_group_subset = df_extreme_ogs_summary[
                    in_profile
                    & (df_extreme_ogs_summary[group_col] == group_val)
                    & has_kept_category
                ]
                if df_profile_group_subset.empty:
                    continue

                # Observed percentage of each category within this profile/group.
                observed_dist = (
                    df_profile_group_subset[broad_func_cat_col]
                    .value_counts(normalize=True)
                    .mul(100)
                )
                profile_group_label = f"{group_val} - {short_profile_label}"

                # log2(observed % / expected %), with 1e-9 added to both sides
                # so a category absent from the baseline (expected 0) does not
                # divide by zero.
                # NOTE(review): in that case the ratio is obs / 1e-9, which gives
                # a very large value that will dominate the symmetric colour
                # range — confirm this is intended.
                heatmap_data_list.extend(
                    {
                        "Functional_Category": func_cat,
                        "Profile_Group": profile_group_label,
                        "Log2_Fold_Change": np.log2(
                            (obs_perc + 1e-9)
                            / (baseline_dist.get(func_cat, 0) + 1e-9)
                        ),
                    }
                    for func_cat, obs_perc in observed_dist.items()
                    if func_cat not in categories_to_exclude_from_heatmap
                )

        if not heatmap_data_list:
            print("No data available to generate the heatmap.")
        else:
            # Long -> wide: one row per functional category, one column per
            # profile/group. Combinations with no record become 0, and rows
            # are ordered by category name.
            df_heatmap = pd.DataFrame(heatmap_data_list)
            df_heatmap_pivot = (
                df_heatmap.pivot(
                    index="Functional_Category",
                    columns="Profile_Group",
                    values="Log2_Fold_Change",
                )
                .fillna(0)
                .sort_index()
            )

            if not df_heatmap_pivot.empty:
                print("\n--- Generating Heatmap ---")

                # --- Define the necessary gradients to build purple_green ---
                green_stops = [colors.fern, colors.lime, colors.lichen]
                greens = Gradient("greens", green_stops, [0, 0.622, 1])
                purple_stops = [colors.lilac, colors.aster, colors.ghost]
                purples = Gradient("purples", purple_stops, [0, 0.144, 1.0])

                # --- Build the purple_green gradient object ---
                # Purples first, then the greens gradient reversed.
                purple_green = purples + greens.reverse()
                purple_green.name = "purple_green"

                # --- Get the Green-Purple Colorscale for Plotly ---
                green_purple_colorscale = purple_green.to_plotly_colorscale()
                print(
                    "Using 'purple_green' gradient and converting to Plotly colorscale."
                )

                # Symmetric colour range around 0, falling back to [-1, 1] when
                # the largest absolute value is not positive.
                abs_max_val = df_heatmap_pivot.abs().max().max()
                if abs_max_val > 0:
                    z_min, z_max = -abs_max_val, abs_max_val
                else:
                    z_min, z_max = -1, 1

                # Create the heatmap figure
                heatmap_trace = go.Heatmap(
                    z=df_heatmap_pivot.values,
                    x=df_heatmap_pivot.columns,
                    y=df_heatmap_pivot.index,
                    colorscale=green_purple_colorscale,
                    zmid=0,
                    zmin=z_min,
                    zmax=z_max,
                    colorbar_title="Log2 (Obs/Exp %)",
                )
                fig_heatmap = go.Figure(data=heatmap_trace)

                # Apply styling consistent with the rest of the notebook
                heatmap_layout_options = dict(
                    title_text="<b>Functional Enrichment in 'Surprising' Diversity Profiles</b>",
                    xaxis_title="",
                    yaxis_title="",
                    height=500,
                    width=800,
                    plot_bgcolor="rgba(0,0,0,0)",
                    paper_bgcolor="rgba(0,0,0,0)",
                    # Wide left / tall bottom margins leave room for the labels.
                    margin=dict(l=250, r=50, t=50, b=150),
                )
                fig_heatmap.update_layout(**heatmap_layout_options)

                fig_heatmap.update_xaxes(tickangle=20, tickfont=dict(size=10))
                # Pin the y-axis category order to the pivot's row order.
                row_order = df_heatmap_pivot.index.tolist()
                fig_heatmap.update_yaxes(
                    categoryorder="array",
                    categoryarray=row_order,
                    tickfont=dict(size=10),
                )

                fig_heatmap.show()

                # --- Save Figure ---
                # NOTE(review): output_figure_dir is re-bound here; earlier cells
                # read the same name from the setup cell — confirm the two
                # locations are meant to be the same.
                output_figure_dir = Path("publication_figures")
                output_figure_dir.mkdir(exist_ok=True)

                heatmap_file_stem = "figure12_panel_b_heatmap_green_purple"

                # The HTML export is not guarded: a failure here raises.
                fig_path_html = output_figure_dir / f"{heatmap_file_stem}.html"
                fig_heatmap.write_html(str(fig_path_html))
                print(f"\nUpdated heatmap HTML saved to: {fig_path_html}")

                # The SVG export is best-effort: a failure is reported as a warning.
                try:
                    fig_path_svg = output_figure_dir / f"{heatmap_file_stem}.svg"
                    fig_heatmap.write_image(str(fig_path_svg))
                    print(f"Updated heatmap SVG saved to: {fig_path_svg}")
                except Exception as e:
                    print(f"Warning: Could not export heatmap to SVG. Error: {e}")
else:
    print("\nCould not generate heatmap due to data loading errors.")

# Completion banner: printed whether or not the heatmap was produced.
print(
    "\n"
    "--- Heatmap Regeneration Cell Complete ---"
)