<a href="https://colab.research.google.com/github/Palaeoprot/PRIDE/blob/main/BLAST_Collagen_Exon_Mapper_v1_9.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🧬 **Collagen Exon Mapper v1.9 (Colab Edition)**

A production-ready, Colab-focused pipeline for mapping collagen exon
architectures across taxa, with robust caching, phylogeny-weighted
consensus, optional rescue of misread sequences, and clear per-step
reporting for new users.


# **Part 1: Paths & Project Layout (Harmonised)**

This notebook uses **two roots**:

- **WORK_ROOT (ephemeral)**: all per‑run outputs under `CollagenExonMapper/`.  
  These can be deleted after each run without affecting other users.

- **SHARED_ROOT (durable)**: all caches, master datasets, manifests, and archives
  under `_SHARED_DATA/ExonMaps/collagens/`.  
  These persist across runs and are used by all users.

Compatibility:
- We keep legacy aliases (`EXON_CACHE_PATH`, `FILTERED_UNIPROT_CACHE_PATH`, …)
  so v1.4 and v1.5 code paths continue to work.
- We also define v1.5 names (`RAW_EXONS_CACHE`, `FILTERED_UNIPROT_TSV`, …).

Deletion policy:
- It is safe to delete **only** `CollagenExonMapper/run_*` after a run.
- Never delete `_SHARED_DATA/ExonMaps/collagens/cache` or `…/runs_archive`.
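The deletion policy can be enforced in code rather than by hand. The helper below is hypothetical: `prune_run_dirs` is not defined anywhere in the notebook. It is a minimal sketch that only considers `run_*` folders sitting directly under `WORK_ROOT`, skips symlinks, and refuses any folder that resolves inside `SHARED_ROOT`. It defaults to a dry run.

```python
import shutil
from pathlib import Path
from typing import List


def prune_run_dirs(work_root: Path, shared_root: Path, dry_run: bool = True) -> List[Path]:
    """List (dry_run=True) or delete the ephemeral run_* folders under work_root.

    Hypothetical helper: only direct children named 'run_*' are considered,
    symlinks are skipped, and any folder that resolves inside shared_root is
    refused, so durable caches and runs_archive are never touched.
    """
    shared_root = shared_root.resolve()
    selected: List[Path] = []
    for d in sorted(work_root.glob("run_*")):
        if d.is_symlink() or not d.is_dir():
            continue
        resolved = d.resolve()
        # Guard against a run_* folder that actually lives under SHARED_ROOT.
        if resolved == shared_root or shared_root in resolved.parents:
            continue
        if not dry_run:
            shutil.rmtree(resolved)
        selected.append(d)
    return selected
```

Calling `prune_run_dirs(WORK_ROOT, SHARED_ROOT)` only lists the candidates. Pass `dry_run=False` once the list looks right.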


In [None]:
# ===== Cell 1: Dependencies =====
# Description: Install and import all required libraries for the notebook.

# --- Installations ---
# Use pip to install necessary third-party libraries quietly (-q).
!pip install -q biopython ete3 requests pandas numpy matplotlib tqdm psutil

# --- Core Imports ---
import sys
import os
import pandas as pd
import numpy as np

# --- Environment Check & Setup ---
# Determine if running in Google Colab for environment-specific logic.
try:
    from google.colab import drive
    print("✅ Running in Google Colab environment.")
    IN_COLAB = True
except ImportError:
    print("⚠️ Not in a Google Colab environment.")
    IN_COLAB = False

# --- Version Information ---
# Print library versions for reproducibility.
print(f"Python: {sys.version.split()[0]}")
print(f"pandas: {pd.__version__}, numpy: {np.__version__}")

## Cell 10 – Mount Google Drive

Mounts your Google Drive for persistent storage. Dependencies are installed by Cell 1 above. If you are not in a Google Colab environment, mounting is skipped.

In [None]:
# ===== Mount Google Drive =====
# Description: Mounts your Google Drive to the Colab virtual machine's
#              filesystem at the '/content/drive' directory.

try:
    from google.colab import drive
    print("💾 Mounting Google Drive...")
    # The force_remount=True option will re-mount the drive if it's already mounted.
    drive.mount('/content/drive', force_remount=True)
    IN_COLAB = True
    print("✅ Google Drive mounted successfully at /content/drive")
except ImportError:
    # This block runs if the code is not in a Google Colab environment.
    IN_COLAB = False
    print("⚠️ Not in a Google Colab environment. Drive mounting skipped.")

## Cell 11 – Central Configuration Panel

Defines all user-adjustable parameters for the run. These variables are used by subsequent cells to control gene selection, filtering thresholds, and optional features.
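Colab form widgets constrain values only when you edit them through the form. If you type a value directly into the code, the widget bounds no longer apply. The check below is a hypothetical sketch and is not part of the notebook. It assumes the parameter names used in the cell and reports obviously inconsistent settings before a long run starts.

```python
from typing import Any, Dict, List


def validate_config(cfg: Dict[str, Any], clade_taxids: Dict[str, int]) -> List[str]:
    """Return a list of human-readable problems; an empty list means the settings look sane."""
    problems: List[str] = []
    name = cfg.get("TAXONOMIC_FILTER_NAME")
    if name not in clade_taxids:
        problems.append(f"TAXONOMIC_FILTER_NAME {name!r} is not a key of CLADE_TAXIDS")
    # These two are fractions exposed as sliders; values typed into code bypass the slider bounds.
    for key in ("CHAIN_LENGTH_THRESHOLD", "MAX_ALLOWED_GAP_PERCENTAGE"):
        value = cfg.get(key)
        if value is not None and not 0.0 <= value <= 1.0:
            problems.append(f"{key}={value} is outside [0, 1]")
    for key in ("MIN_LEN_AA", "MIN_GXY_TRIPLETS"):
        value = cfg.get(key)
        if value is not None and value < 0:
            problems.append(f"{key}={value} must be non-negative")
    if cfg.get("ENSEMBL_REQUESTS_PER_SEC", 1) <= 0:
        problems.append("ENSEMBL_REQUESTS_PER_SEC must be positive")
    if not cfg.get("PROCESS_ALL_GENES", True):
        genes = [g.strip() for g in str(cfg.get("GENE_SYMBOLS", "")).split(",") if g.strip()]
        if not genes:
            problems.append("GENE_SYMBOLS is empty while PROCESS_ALL_GENES is False")
    return problems
```

In the notebook, `validate_config(globals(), CLADE_TAXIDS)` would run the checks against the live settings, because `globals()` is a plain dict.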

In [None]:
# ===== Cell 11 =====
# Central configuration panel and dynamic directory setup.

import os, re, io, time, json, hashlib, logging, psutil
from pathlib import Path
from datetime import datetime, timezone
from typing import Optional, List, Dict, Tuple, Set, Iterable
import pandas as pd
import numpy as np
import requests
import sys
from Bio import Entrez
from tqdm.notebook import tqdm

# ---------------- User-Configurable Parameters ----------------
#@markdown #### **Core Settings**
PROJECT_NAME = "CollagenExonMapper"  #@param {type:"string"}
USER_EMAIL   = "matthew@palaeome.org"  #@param {type:"string"}
Entrez.email = USER_EMAIL

#@markdown #### **Gene Selection**
PROCESS_ALL_GENES = True  #@param {type:"boolean"}
GENE_SYMBOLS     = "COL1A1,COL1A2"  #@param {type:"string"}

#@markdown #### **Taxonomic Filtering**
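CLADE_TAXIDS = {
    "Metazoa": 33208, "Vertebrata": 7742, "Mammalia": 40674, "Aves": 8782,
    "Reptilia": 8504,  # NCBI taxid 8504 is Lepidosauria (lizards, snakes, tuatara)
    "Amphibia": 8292, "Tetrapoda": 32523,
    "Bony fish": 117570,  # NCBI Teleostomi, which also contains the tetrapods
    "Cartilaginous fish": 7777, "Catarrhini": 9526
}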
TAXONOMIC_FILTER_NAME = "Metazoa" #@param ["Metazoa","Vertebrata","Mammalia","Aves","Reptilia","Amphibia","Tetrapoda","Bony fish","Cartilaginous fish","Catarrhini"]
TARGET_TAXID = CLADE_TAXIDS[TAXONOMIC_FILTER_NAME]

#@markdown #### **Thresholds (Sequences & Mapping)**
MIN_LEN_AA          = 600  #@param {type:"integer"}
MIN_GXY_TRIPLETS    = 30   #@param {type:"integer"}
CHAIN_LENGTH_THRESHOLD      = 0.90  #@param {type:"slider", min:0.5, max:1.0, step:0.05}
MAX_ALLOWED_GAP_PERCENTAGE  = 0.10  #@param {type:"slider", min:0.05, max:0.5, step:0.05}

# Rescue controls

#@markdown #### **Identify & rescue additional sequences (FAST; peptide-only)**
ENABLE_RESCUE             = True   #@param {type:"boolean"}
MINIMUM_RESCUE_SCORE      = 0.50   #@param {type:"number"}
RESCUE_SCORE_IMPROVEMENT  = 0.10   #@param {type:"number"}
GXY_CONTENT_THRESHOLD     = 0.80   #@param {type:"number"}

#@markdown #### **DNA fetch for rescue (SLOW; Ensembl REST)**
ENABLE_DNA_FETCH          = False  #@param {type:"boolean"}
ENSEMBL_REQUESTS_PER_SEC  = 5      #@param {type:"integer"}
ENSEMBL_TIMEOUT_SECS      = 20     #@param {type:"integer"}
ENSEMBL_BASE              = "https://rest.ensembl.org"  #@param {type:"string"}

#@markdown #### **Debugging**
DEBUG_SAMPLE_SIZE = -1  #@param {type:"integer"}

#@markdown #### **Taxonomy Engine (optional; not required when UniProt lineage is present)**
USE_TAXONOMY_ENGINE = False  #@param {type:"boolean"}
#@markdown #### **Phylogenetic Cache Settings**
#@markdown Enable this only if you have updated the Newick tree and need to
#@markdown regenerate the node age lookup table. This is a slow, one-time process.
REGENERATE_NODE_AGE_CACHE = False  #@param {type:"boolean"}

## Cell 12 – Two-Root Path Model & Project Setup

This cell establishes the core directory structure for the project, separating ephemeral run-specific outputs from durable shared data and caches. It also initializes logging. **All subsequent cells depend on these path variables.**
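`RUNS_ARCHIVE_DIR` is the durable home for finished runs. The sketch below shows one way to move a run there before pruning `run_*`. It assumes one zip per run is an acceptable archive format. `archive_run` is a hypothetical helper, not a notebook function, and it uses only the standard library's `shutil.make_archive`.

```python
import shutil
from pathlib import Path


def archive_run(run_dir: Path, archive_dir: Path) -> Path:
    """Zip the contents of an ephemeral run directory into the durable archive folder."""
    archive_dir.mkdir(parents=True, exist_ok=True)
    # make_archive appends the '.zip' extension to base_name itself.
    base_name = archive_dir / run_dir.name
    zip_path = shutil.make_archive(str(base_name), "zip", root_dir=str(run_dir))
    return Path(zip_path)
```

For example, `archive_run(RUN_DIR, RUNS_ARCHIVE_DIR)` would produce `runs_archive/run_<timestamp>.zip`. Keep `archive_dir` outside `run_dir`; otherwise the archive tries to include itself.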

In [None]:
# ===== Cell 12 =====
# Two-root path model & project setup

# --- ROOTS ---
# Ephemeral outputs for this specific run. Safe to delete.
WORK_ROOT   = Path("/content/drive/MyDrive/CollagenExonMapper")
# Durable, shared data across all runs and users. Do not delete.
SHARED_ROOT = Path("/content/drive/MyDrive/Colab_Notebooks/GitHub/_SHARED_DATA/ExonMaps/collagens")

# --- TIMESTAMPED RUN DIRECTORY (EPHEMERAL) ---
RUN_TIMESTAMP = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
RUN_DIR = WORK_ROOT / f"run_{RUN_TIMESTAMP}"
RUN_ID = RUN_DIR.name # For use in filenames

# --- SHARED, DURABLE DIRECTORIES ---
CACHE_DIR        = SHARED_ROOT / "cache"
RUNS_ARCHIVE_DIR = SHARED_ROOT / "runs_archive"
DICT_DIR         = Path("/content/drive/MyDrive/Colab_Notebooks/GitHub/_SHARED_DATA/DICTIONARIES")
TAXO_DIR         = Path("/content/drive/MyDrive/Colab_Notebooks/GitHub/_SHARED_DATA/TAXONOMY")

# --- DURABLE FILES (AUTHORITATIVE) ---
FILTERED_UNIPROT_TSV = CACHE_DIR / "filtered_uniprot_cache.tsv"
MASTER_TSV_PATH      = CACHE_DIR / "master_collagen_dataset.tsv"
RAW_EXONS_CACHE      = CACHE_DIR / "raw_exons_cache.tsv"
REJECTED_IDS_PATH    = CACHE_DIR / "rejected_ids.tsv"
NODE_AGES_CSV_PATH   = DICT_DIR / "metazoan_genus_node_ages.csv"
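SYNTENY_CACHE_TSV    = CACHE_DIR / "synteny_location_cache.tsv"

# --- LEGACY ALIASES (older names kept for compatibility) ---
EXON_CACHE_PATH             = RAW_EXONS_CACHE
FILTERED_UNIPROT_CACHE_PATH = FILTERED_UNIPROT_TSV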

# --- EPHEMERAL RUN-SPECIFIC FILES ---
RUN_LOG_FILE         = RUN_DIR / f"exonmapper_{RUN_TIMESTAMP}.log"
WORKING_SNAPSHOT     = RUN_DIR / "working_snapshot.tsv"
MAPPED_SNAPSHOT      = RUN_DIR / "raw_exons_this_run.tsv"
CONSENSUS_LONG_TSV   = RUN_DIR / "consensus_long.tsv"
CONSENSUS_TABLE_TSV  = RUN_DIR / "consensus_table.tsv"
EVOLUTION_EVENTS_TSV = RUN_DIR / "exon_evolution_events.tsv"
WIDE_ARCH_TSV        = RUN_DIR / "exon_wide.tsv"
ENTROPY_TSV          = RUN_DIR / "entropy_stats.tsv"
RESCUE_LOG_TSV       = RUN_DIR / "rescue_log.tsv"
RESCUE_HITS_TSV      = RUN_DIR / "rex_hits.tsv"
RESCUE_CHAINS_TSV    = RUN_DIR / "rex_chains.tsv"
ERROR_REPORT_PATH    = RUN_DIR / "error_analysis_report.json"
SESSION_REJECTED_PATH = RUN_DIR / f"rejected_ids_{RUN_ID}.txt"
RUN_MANIFEST_JSON    = RUN_DIR / "manifest.json"

# --- REFERENCES (SOFT DEPENDENCIES) ---
DRIVE_ARCHITECTURES_PATH = DICT_DIR / "GeneFamily/collagens/COLLAGEN_GXY_REPEAT_STRUCTURE.json"
DRIVE_METAZOAN_TREE_PATH = DICT_DIR / "multicellular animals_genus.nwk"
DRIVE_TAXONOMY_PATH      = TAXO_DIR / "ete3_ncbi_taxa.sqlite"

# --- INITIALIZATION ---
# Create directories (idempotent)
for p in [RUN_DIR, CACHE_DIR, RUNS_ARCHIVE_DIR]:
    p.mkdir(parents=True, exist_ok=True)

# Configure logging to both console and a run-specific file
# Remove any existing handlers to prevent duplicate logs in re-runs
for h in logging.root.handlers[:]:
    logging.root.removeHandler(h)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout), logging.FileHandler(RUN_LOG_FILE)]
)
logger = logging.getLogger("exon-mapper")

logger.info("Two-root path model initialised successfully.")
logger.info(f"WORK_ROOT (ephemeral outputs): {WORK_ROOT}")
logger.info(f"SHARED_ROOT (durable data): {SHARED_ROOT}")
logger.info(f"RUN_DIR (this run's outputs): {RUN_DIR}")

## Cell 13 – Manifest & Taxonomy Engine Setup

This cell provides a helper function to write a JSON manifest for run provenance and initializes the optional taxonomy engine. The manifest is written once at the beginning of the run and can be updated at the end with final file hashes.
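Storing SHA-256 hashes in the manifest makes it possible to check later whether an output changed or disappeared. The sketch below assumes the manifest layout written by `write_run_manifest`, where `output_files` is a list of `{name, path, sha256}` entries. `verify_manifest` and `file_sha256` are hypothetical names. Unlike the notebook's `sha256_file`, `file_sha256` raises on a missing file instead of returning an empty string.

```python
import hashlib
import json
from pathlib import Path
from typing import Dict


def file_sha256(p: Path, chunk_size: int = 65536) -> str:
    """Chunked SHA-256 of a file (same streaming approach as the notebook's sha256_file)."""
    h = hashlib.sha256()
    with open(p, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify_manifest(manifest_path: Path) -> Dict[str, str]:
    """Re-hash each entry under 'output_files' and report 'ok', 'changed' or 'missing'."""
    manifest = json.loads(manifest_path.read_text())
    status: Dict[str, str] = {}
    for entry in manifest.get("output_files", []):
        p = Path(entry["path"])
        if not p.is_file():
            status[entry["name"]] = "missing"
        elif file_sha256(p) != entry["sha256"]:
            status[entry["name"]] = "changed"
        else:
            status[entry["name"]] = "ok"
    return status
```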

In [None]:
# ===== Cell 13 =====
# Manifest helper function & optional taxonomy engine

def sha256_file(p: Path) -> str:
    """Computes SHA256 hash of a file, returns empty string if not found."""
    if not p.is_file():
        return ""
    h = hashlib.sha256()
    with open(p, 'rb') as f:
        # Read in chunks to handle large files
        for chunk in iter(lambda: f.read(65536), b''):
            h.update(chunk)
    return h.hexdigest()

def write_run_manifest(extra: Optional[dict] = None, final_files: Optional[List[Path]] = None) -> None:
    """
    Writes/updates a JSON manifest for the run, including parameters and file hashes.
    """
    manifest_data = {}
    # If manifest exists, load it to update it
    if RUN_MANIFEST_JSON.exists():
        try:
            with open(RUN_MANIFEST_JSON, 'r') as f:
                manifest_data = json.load(f)
        except json.JSONDecodeError:
            logger.warning("Could not parse existing manifest; creating a new one.")
            manifest_data = {}

    # Initial data
    if not manifest_data:
        manifest_data = {
            "run_id": RUN_ID,
            "timestamp_utc": RUN_TIMESTAMP,
            "python_version": sys.version.split()[0],
            "parameters": {
                "PROJECT_NAME": PROJECT_NAME,
                "USER_EMAIL": USER_EMAIL,
                "PROCESS_ALL_GENES": PROCESS_ALL_GENES,
                "GENE_SYMBOLS": GENE_SYMBOLS,
                "TAXONOMIC_FILTER_NAME": TAXONOMIC_FILTER_NAME,
                "TARGET_TAXID": TARGET_TAXID,
                "MIN_LEN_AA": MIN_LEN_AA,
                "MIN_GXY_TRIPLETS": MIN_GXY_TRIPLETS,
                "DEBUG_SAMPLE_SIZE": DEBUG_SAMPLE_SIZE,
            },
            "paths": {
                "run_dir": str(RUN_DIR),
                "cache_dir": str(CACHE_DIR),
            }
        }

    if extra:
        manifest_data.update(extra)

    if final_files:
        manifest_data['output_files'] = [
            {"name": p.name, "path": str(p), "sha256": sha256_file(p)}
            for p in final_files
        ]

    with open(RUN_MANIFEST_JSON, "w") as f:
        json.dump(manifest_data, f, indent=2)
    logger.info(f"Run manifest written/updated → {RUN_MANIFEST_JSON}")

# Write the initial manifest for this run
write_run_manifest()

# --- (Optional) Taxonomy Engine ---
class NCBITaxonomyEngine:
    def __init__(self, enabled: bool, ncbi_db_path: Optional[Path] = None):
        self._ok = False
        self._source = "uniprot_lineage_only"
        self.ncbi=None
        if not enabled: return
        try:
            from ete3 import NCBITaxa as _ETE_NCBITaxa  # type: ignore
            if ncbi_db_path and ncbi_db_path.exists():
                self.ncbi = _ETE_NCBITaxa(dbfile=str(ncbi_db_path))
                self._source = f"ete3:{ncbi_db_path}"
            else:
                self.ncbi = _ETE_NCBITaxa()
                self._source = "ete3:default"
            # Test query
            _ = self.ncbi.get_lineage(1)
            self._ok = True
        except Exception as e:
            logger.warning(f"ETE3 taxonomy engine disabled or unavailable: {e}")

    def ok(self) -> bool: return self._ok

tax_engine = NCBITaxonomyEngine(USE_TAXONOMY_ENGINE, DRIVE_TAXONOMY_PATH)
logger.info(f"Taxonomy mode: {'ETE3' if tax_engine.ok() else 'UniProt-lineage only'}")

In [None]:
# ===== Cell 13 (continued) =====
# Safe TSV reader used by the data-loading cells (e.g. Cell 21)

def safe_read_tsv(p: Path) -> pd.DataFrame:
    """
    Safely reads a TSV file, returning an empty DataFrame if the file
    does not exist or an error occurs, preventing pipeline crashes.
    """
    if not p.exists():
        logger.warning(f"Cache file not found: {p.name}. Returning empty DataFrame.")
        return pd.DataFrame()
    try:
        return pd.read_csv(p, sep='\t', low_memory=False)
    except Exception as e:
        logger.error(f"Failed to read {p.name}: {e}. Returning empty DataFrame.")
        return pd.DataFrame()


# **Part 1A: (Optional) Phylogenetic Cache Generation**

## Cell 15 – Rebuild Node Age Cache from Newick Tree

This cell traverses the Newick tree at `DRIVE_METAZOAN_TREE_PATH` and computes, for every node, its cumulative branch-length distance from the root, stored as the node's `age`. The output CSV (`NODE_AGES_CSV_PATH`) maps node names, plus NCBI TaxIDs where a node name ends in one, to these values in millions of years.

**This is a slow, one-time operation.** It should only be run if the `REGENERATE_NODE_AGE_CACHE` flag in Cell 11 is set to `True`, which is typically only necessary after updating the source Newick tree. All subsequent cells will use the generated CSV for fast age lookups.
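In an ultrametric time tree with branch lengths in millions of years, a node's cumulative distance from the root is not its age before present. The age before present is the tree height minus that depth. The pure-Python sketch below computes both quantities, assuming the tree is ultrametric. It uses a hypothetical three-leaf tree with made-up branch lengths and does not rely on ETE3.

```python
from typing import Dict, List


def node_depths(children: Dict[str, List[str]], branch_length: Dict[str, float],
                root: str) -> Dict[str, float]:
    """Cumulative branch-length distance from the root for every node."""
    depth = {root: 0.0}
    stack = [root]
    while stack:  # depth-first; a parent's depth is always set before its children's
        node = stack.pop()
        for child in children.get(node, []):
            depth[child] = depth[node] + branch_length[child]
            stack.append(child)
    return depth


def ages_before_present(depth: Dict[str, float],
                        children: Dict[str, List[str]]) -> Dict[str, float]:
    """Age = tree height - depth; meaningful when all tips sit at the present (ultrametric)."""
    height = max(d for n, d in depth.items() if not children.get(n))
    return {n: height - d for n, d in depth.items()}
```

On the toy tree `root -> (AB, C)`, `AB -> (A, B)`, the root has depth 0 but is the oldest node. The tips have the largest depth but age 0.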

In [None]:
# ===== Cell 15 =====
# Rebuild node age cache from Newick file (run-on-demand)

if REGENERATE_NODE_AGE_CACHE:
    logger.info("▶️ REGENERATING NODE AGE CACHE. This may take several minutes.")

    if not DRIVE_METAZOAN_TREE_PATH.exists():
        logger.error(f"CRITICAL: Cannot find Newick tree at {DRIVE_METAZOAN_TREE_PATH}.")
        logger.warning("Aborting node age cache regeneration. The notebook will continue, "
                       "but phylogenetic features may be disabled or use uniform weights.")
        # This is no longer a fatal error; execution continues.
    else:
        try:
            from ete3 import Tree

            def generate_node_age_cache(nwk_path: Path, output_csv: Path):
                """
                Traverses a Newick tree to calculate and save the age of each node.
                """
                logger.info(f"Loading tree from: {nwk_path}")
                tree = Tree(str(nwk_path), format=1)

                # The root has no distance to itself, so its age is 0
                tree.dist = 0

                node_data = []

                # Traverse the tree from root to leaves (preorder), so a parent's
                # cumulative distance is final before any of its children is visited.
                # With format=1, ETE3 parses branch lengths into `dist`; `support` is
                # not read from the file and keeps its default value.
                for node in tqdm(tree.traverse("preorder"), desc="Calculating node ages"):
                    # Replace each child's branch length with parent depth + branch length
                    parent_age = node.dist
                    for child in node.children:
                        child.dist = parent_age + child.dist

                    # Store data for the current node
                    taxid = None
                    try:
                        # Node names may end in an NCBI taxid, e.g. 'Genus_name_9606'
                        if '_' in node.name and node.name.split('_')[-1].isdigit():
                            taxid = int(node.name.split('_')[-1])
                    except (ValueError, IndexError):
                        pass

                    node_data.append({
                        "node_name": node.name,
                        "age": node.dist,
                        "taxid": taxid
                    })

                df_ages = pd.DataFrame(node_data)
                df_ages.to_csv(output_csv, index=False)
                logger.info(f"✅ Successfully generated and saved node age cache to: {output_csv}")

            # Execute the function
            generate_node_age_cache(DRIVE_METAZOAN_TREE_PATH, NODE_AGES_CSV_PATH)

        except ImportError:
            logger.error("`ete3` is required for this step. Please ensure it is installed.")
        except Exception as e:
            logger.error(f"An error occurred during cache generation: {e}")

else:
    logger.info("☑️ Skipping node age cache regeneration (flag is False).")
    if not NODE_AGES_CSV_PATH.exists():
        logger.warning(f"Node age cache not found at {NODE_AGES_CSV_PATH}. "
                       "The dating engine may not work. "
                       "Consider setting REGENERATE_NODE_AGE_CACHE to True.")

# **Part 2: Data Loading & Pre-processing**

## Cell 21 – Unified Data Import, Normalization, and Healing

This cell is the primary data entry point for the pipeline. It combines the cached master dataset with any new FASTA or `uniprot*.tsv.gz` files found in `/content/`, then applies a multi-stage normalization routine:
1.  **Data Healing:** Proactively reconstructs the essential `'Organism (ID)'` column if it is missing from older cache files.
2.  **Variant Identification:** Identifies and separates gene variants (e.g., `COL1A1A`, `COL1A1_L`) into dedicated columns.
3.  **Robust Classification:** Uses a tiered regex system to infer gene symbols from protein names and classify probable but un-named collagens.

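The annotated dataset is then filtered to create the `working_df` for the current run. It is filtered by gene symbol unless `PROCESS_ALL_GENES` is set, and by clade when lineage IDs are available. It can optionally be down-sampled for debugging.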
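The three tiers can be seen on toy inputs in the simplified, standalone sketch below. It condenses the cell's logic and differs from it in three ways: it uses one optional-suffix regex instead of two, it handles only the `alpha-N(Roman)` protein-name pattern, and it has no receptor/binding/`ASE` exclusion. Treat it as an illustration rather than a drop-in replacement; `normalise` is a hypothetical name.

```python
import re
from typing import Tuple

ROMAN = {"I": 1, "V": 5, "X": 10, "L": 50}
# Gene-name tier: base symbol plus an optional '_SUFFIX' or single-letter paralog suffix.
GENE_RX = re.compile(r"^(COL\d+A\d+)(?:_([A-Z0-9_]+)|([A-Z]))?$")
# Protein-name tier: only the 'alpha-N(Roman)' pattern, e.g. 'Collagen alpha-1(I) chain'.
PROTEIN_RX = re.compile(r"ALPHA[ -]?(\d+)\s*\(([IVXL]+)\)")
GXY_RX = re.compile(r"\bG(?:LY)?\s*-?\s*X\s*-?\s*Y\b")


def roman_to_int(s: str) -> int:
    total = 0
    for i, ch in enumerate(s):
        value = ROMAN[ch]
        # Subtractive notation: a numeral followed by a larger one counts negatively (IV = 4).
        if i + 1 < len(s) and ROMAN[s[i + 1]] > value:
            total -= value
        else:
            total += value
    return total


def normalise(gene_name: str, protein_name: str) -> Tuple[str, str]:
    """Return (gene_symbol_norm, variant_type); the first tier that matches wins."""
    words = gene_name.upper().split()
    m = GENE_RX.match(words[0]) if words else None
    if m:
        return m.group(1), m.group(2) or m.group(3) or ""
    upper = protein_name.upper()
    p = PROTEIN_RX.search(upper)
    if p:
        return f"COL{roman_to_int(p.group(2))}A{p.group(1)}", ""
    if "COLLAGEN" in upper and GXY_RX.search(upper):
        return "PROBABLE_COLLAGEN", ""
    return "UNKNOWN", ""
```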

In [None]:
# ===== Cell 21 =====
# Unified import, normalization, and variant handling (with Data Healing & Restored Logic)

from Bio import SeqIO

# --- Restored, more robust normalization helper functions ---
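def roman_to_int(s: str) -> int:
    """Convert a Roman numeral (e.g. 'XXVIII') to an integer."""
    roman_map = {'I': 1, 'V': 5, 'X': 10, 'L': 50, 'C': 100, 'D': 500, 'M': 1000}
    val = 0
    s = s.upper()
    for i in range(len(s)):
        if i > 0 and roman_map.get(s[i], 0) > roman_map.get(s[i - 1], 0):
            # Subtractive pair (e.g. 'IV'): undo the earlier addition and add the difference
            val += roman_map.get(s[i], 0) - 2 * roman_map.get(s[i - 1], 0)
        else:
            val += roman_map.get(s[i], 0)
    return val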

def infer_specific_collagen_symbol(name: str) -> Optional[str]:
    if not isinstance(name, str) or 'COLLAGEN' not in name.upper(): return None
    match = re.search(r'ALPHA[ -]?(\d+)\s*\((.*?)\)|TYPE\s+([IVXLCDM]+).*?ALPHA\s+(\d+)', name.upper())
    if not match: return None
    chain_num = match.group(1) or match.group(4); type_roman = match.group(2) or match.group(3)
    if not chain_num or not type_roman: return None
    try:
        type_roman_cleaned = re.sub(r'[^IVXLCDM]', '', type_roman)
        if not type_roman_cleaned: return None
        return f"COL{roman_to_int(type_roman_cleaned)}A{chain_num}"
    except (KeyError, IndexError): return None

RX_GXY = re.compile(r"(?i)\b(g(?:ly)?\s*[- ]?\s*x\s*[- ]?\s*y\b|gxy\b)")
def classify_as_probable_collagen(name: str) -> bool:
    if not isinstance(name, str): return False
    if 'COLLAGEN' in name.upper() and RX_GXY.search(name) and not re.search(r'RECEPTOR|BINDING|ASE', name.upper()):
        return True
    return False

def consolidated_gene_normalization(row: pd.Series) -> pd.Series:
    """Applies the full, robust normalization logic to a DataFrame row."""
    primary_name = row.get('Gene Names (primary)', '')
    protein_names = row.get('Protein names', '')
    base_gene, variant_type, note = "UNKNOWN", "", ""

    if isinstance(primary_name, str) and primary_name.strip():
        # Take the first listed gene name and upper-case it
        words = primary_name.split()
        if words: # Ensure the list is not empty
            first_word = words[0].upper()

            variant_match = re.search(r'^(COL\d+A\d+)(?:_([A-Z0-9_]+)|([A-Z]))$', first_word)
            if variant_match:
                base_gene = variant_match.group(1)
                variant = next((v for v in variant_match.groups()[1:] if v is not None), "")
                return pd.Series([base_gene, variant, f"gene_variant_{variant}"], index=['gene_symbol_norm', 'variant_type', 'species_note'])

            if re.match(r'^COL\d+A\d+$', first_word):
                base_gene = first_word

    if base_gene == 'UNKNOWN':
        inferred = infer_specific_collagen_symbol(protein_names)
        if inferred:
            base_gene = inferred

    if base_gene == 'UNKNOWN' and classify_as_probable_collagen(protein_names):
        base_gene = 'PROBABLE_COLLAGEN'

    return pd.Series([base_gene, variant_type, note], index=['gene_symbol_norm', 'variant_type', 'species_note'])

def load_new_inputs_from_content() -> pd.DataFrame:
    """Scans /content/ for new FASTA or TSV.GZ files and loads them."""
    loaded_dfs = []
    content_dir = Path("/content").resolve()

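    def parse_fasta_header(record, file_path):
        """Parse a UniProt-style header:
        >db|Accession|EntryName ProteinName OS=Organism OX=TaxID [GN=Gene ]PE=n SV=n
        """
        h = record.description
        parts = record.id.split('|')
        # 'sp|P02452|CO1A1_HUMAN' -> 'P02452'; non-UniProt IDs are kept as-is
        accession = parts[1] if len(parts) >= 3 else record.id
        gn = re.search(r"GN=(\S+)", h)
        # OS= ends at the next ' OX=' / ' GN=' / ' PE=' / ' SV=' tag, or at end of line.
        # (Named os_match so the `os` module is not shadowed.)
        os_match = re.search(r"OS=(.+?)(?=\s+(?:OX|GN|PE|SV)=|$)", h)
        pn = re.search(r"^\S+\s+(.+?)\s+OS=", h)
        return {
            'Entry': accession,
            'Gene Names (primary)': gn.group(1) if gn else "",
            'Protein names': pn.group(1) if pn else "",
            'Organism': os_match.group(1).strip().replace('_', ' ') if os_match else "",
            'Sequence': str(record.seq),
            # UniProt prefixes Swiss-Prot entries with 'sp' and TrEMBL entries with 'tr'
            'Reviewed': 'reviewed' if parts[0] == 'sp' else 'unreviewed',
            'source_file': file_path.name,
        }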

    for p in content_dir.glob("*.fasta"):
        logger.info(f"Found FASTA for dynamic loading: {p.name}")
        recs = [parse_fasta_header(rec, p) for rec in SeqIO.parse(p, "fasta")]
        loaded_dfs.append(pd.DataFrame([r for r in recs if r]))

    for p in content_dir.glob("uniprot*.tsv.gz"):
        logger.info(f"Found TSV.GZ for dynamic loading: {p.name}")
        try:
            df = pd.read_csv(p, sep='\t', compression='gzip', low_memory=False)
            df['source_file'] = p.name
            loaded_dfs.append(df)
        except Exception as e:
            logger.warning(f"Could not load TSV {p.name}: {e}")

    if not loaded_dfs:
        return pd.DataFrame()
    return pd.concat(loaded_dfs, ignore_index=True)

# --- Main Data Loading and Processing Logic ---
logger.info("--- Starting Unified Data Import, Normalization, and Healing ---")

df_master = safe_read_tsv(MASTER_TSV_PATH)
df_new = load_new_inputs_from_content()
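full_df = pd.concat([df_master, df_new], ignore_index=True)
if full_df.empty or 'Entry' not in full_df.columns:
    raise RuntimeError("No input data: the master cache is missing or empty and no new "
                       "FASTA/TSV files were found in /content/.")
# New inputs (df_new) come last, so keep='last' lets them override cached rows
full_df.drop_duplicates(subset=['Entry'], keep='last', inplace=True)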

if 'Organism (ID)' not in full_df.columns:
    logger.warning("Column 'Organism (ID)' not found. Attempting to heal from legacy columns.")
    if 'Organism' in full_df.columns and 'Organism ID' in full_df.columns:
        full_df['Organism (ID)'] = full_df['Organism'].astype(str) + " (" + full_df['Organism ID'].astype(str) + ")"
        logger.info("Successfully reconstructed 'Organism (ID)'.")
    else:
        logger.warning("Cannot reconstruct 'Organism (ID)': legacy 'Organism'/'Organism ID' columns are missing.")

logger.info(f"Normalizing gene symbols for {len(full_df)} entries...")
tqdm.pandas(desc="Normalizing Gene Symbols")
norm_cols = full_df.progress_apply(consolidated_gene_normalization, axis=1)
full_df[['gene_symbol_norm', 'variant_type', 'species_note']] = norm_cols

target_genes = [s.strip().upper() for s in GENE_SYMBOLS.split(',')]
if PROCESS_ALL_GENES:
    working_df = full_df.copy()
    logger.info("PROCESS_ALL_GENES=True: keeping all entries (no gene-symbol filter applied).")
else:
    working_df = full_df[full_df['gene_symbol_norm'].isin(target_genes)].copy()
    logger.info(f"Filtering for specific genes: {target_genes}")

if 'Taxonomic lineage (Ids)' in working_df.columns:
    mask = working_df['Taxonomic lineage (Ids)'].astype(str).str.contains(f"\\b{TARGET_TAXID}\\b")
    working_df = working_df[mask].copy()
    logger.info(f"Filtered for Taxon ID {TARGET_TAXID} ({TAXONOMIC_FILTER_NAME}), {len(working_df)} entries remaining.")
else:
    logger.warning("Column 'Taxonomic lineage (Ids)' not found; clade filter NOT applied.")

if DEBUG_SAMPLE_SIZE > 0:
    logger.warning(f"DEBUG MODE: Sampling down to {DEBUG_SAMPLE_SIZE} entries.")
    working_df = working_df.sample(n=min(DEBUG_SAMPLE_SIZE, len(working_df)), random_state=42)

working_df.to_csv(WORKING_SNAPSHOT, sep='\t', index=False)
logger.info(f"✅ Unified import complete. Working set has {len(working_df)} entries.")
logger.info(f"Snapshot for this run saved to: {WORKING_SNAPSHOT.name}")

## Cell 21A – Diagnostic Check for Data Normalization

This cell is a verification step for the normalization and classification logic in Cell 21. It profiles `working_df` so you can check that gene symbols and variants were assigned as expected before the pipeline moves on to the more expensive mapping and analysis stages.
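The same profile can be reduced to a few numbers that are easy to compare between runs. The sketch below uses only the standard library, and `normalisation_profile` is a hypothetical helper. It takes rows as dicts, for example from `working_df.to_dict("records")`. It counts only non-empty variant strings, because non-variant rows carry an empty string.

```python
from collections import Counter
from typing import Any, Dict, List


def normalisation_profile(rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Counts of normalised symbols and variants, plus the share of rows left unresolved."""
    symbols = Counter(r.get("gene_symbol_norm", "UNKNOWN") for r in rows)
    # Only non-empty strings count as variants (non-variant rows carry '' or NaN).
    variants = Counter(v for v in (r.get("variant_type") for r in rows)
                       if isinstance(v, str) and v)
    unresolved = symbols["UNKNOWN"] + symbols["PROBABLE_COLLAGEN"]
    return {
        "symbols": symbols,
        "variants": variants,
        "unresolved_share": unresolved / len(rows) if rows else 0.0,
    }
```

A rising `unresolved_share` between runs is a quick signal that new inputs use naming patterns the regex tiers do not cover.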

In [None]:
# ===== Cell 21A =====
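# Diagnostic Check for Data Normalization
from IPython.display import display  # injected automatically in Colab/Jupyter; imported explicitly here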

if 'working_df' in globals() and not working_df.empty:
    logger.info("--- 📊 DIAGNOSTIC PROFILE of Cell 21 Output ---")

    # 1. Profile the main output: gene_symbol_norm
    logger.info("\n[1] Distribution of Normalized Gene Symbols in `working_df`:")
    gene_counts = working_df['gene_symbol_norm'].value_counts()
    with pd.option_context('display.max_rows', 20):
        display(gene_counts)

    # 2. Profile the variant/paralog identification
    if 'variant_type' in working_df.columns:
        logger.info("\n[2] Profile of Identified Gene Variants/Paralogs:")
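        # Non-variant rows carry an empty string, not NaN, so drop both before counting
        variants = working_df['variant_type'].fillna('').astype(str)
        variant_counts = variants[variants != ''].value_counts()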
        if not variant_counts.empty:
            display(variant_counts)
        else:
            logger.info("   -> No variants with suffixes were identified in the working set.")

    # 3. Display a sample of the key columns to visually inspect the results
    logger.info("\n[3] Sample of DataFrame showing key normalization columns:")
    display_cols = [
        'Entry',
        'Organism',
        'Gene Names (primary)',
        'Protein names',
        'gene_symbol_norm', # The final base gene
        'variant_type',     # The extracted variant
        'species_note'
    ]
    # Ensure all columns exist before trying to display them
    final_display_cols = [col for col in display_cols if col in working_df.columns]
    display(working_df[final_display_cols].head(10))

else:
    logger.warning("`working_df` is empty. Cannot generate a diagnostic profile.")


## Cell 22 – Proactive Ensembl Gene ID Discovery (Hybrid Strategy)

This cell enriches our dataset with high-quality Ensembl data using a robust, two-pass hybrid strategy.

### Process

1.  It builds a comprehensive map of all Ensembl gene family members.
2.  **Pass 1 (Protein ID Link):** It first attempts to link UniProt entries to Ensembl data using the stable Ensembl Protein ID found in the cross-reference column. This is the most reliable method.
3.  **Pass 2 (Species/Gene Name Link):** For any entries that could not be matched via Protein ID, it falls back to a second attempt, matching on the combination of the species name and the normalized gene symbol.
4.  The results are coalesced, creating the final `ensembl_id` and `ensembl_genus` columns.
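The two-pass coalescing step can be sketched independently of the REST calls. This is a minimal sketch with several assumptions. The lookup tables are plain dicts already built from the homology results. Cross-referenced protein IDs may carry a version suffix (e.g. `.3`). Ensembl species names are lowercase binomials joined by underscores (as in the cell's own `seed_species="homo_sapiens"`). UniProt's `Organism` value includes a parenthesised common name. `link_entries`, `ensembl_species`, and all IDs in the usage test are hypothetical.

```python
import re
from typing import Dict, List, Optional, Tuple

Link = Tuple[Optional[str], Optional[str]]  # (ensembl_gene_id, method)


def ensembl_species(organism: str) -> str:
    """'Homo sapiens (Human)' -> 'homo_sapiens' (drops parenthesised common names)."""
    return re.sub(r"\s*\(.*?\)", "", organism).strip().lower().replace(" ", "_")


def link_entries(entries: List[dict],
                 by_protein: Dict[str, str],
                 by_species_gene: Dict[Tuple[str, str], str]) -> Dict[str, Link]:
    links: Dict[str, Link] = {}
    for e in entries:
        gene_id, method = None, None
        # Pass 1: any cross-referenced Ensembl protein ID (version suffix stripped).
        for pid in e.get("ensembl_protein_ids", []):
            gene_id = by_protein.get(pid.split(".")[0])
            if gene_id:
                method = "protein_id"
                break
        # Pass 2: fall back to (species, normalised gene symbol) for rows still unlinked.
        if gene_id is None:
            key = (ensembl_species(e["organism"]), e["gene_symbol_norm"])
            gene_id = by_species_gene.get(key)
            method = "species_gene" if gene_id else None
        links[e["entry"]] = (gene_id, method)
    return links
```

Recording which pass produced each link keeps the weaker name-based matches easy to audit later.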

In [None]:
# ===== Cell 22 =====
# Proactive Ensembl Gene ID Discovery via Homology API (Hybrid Strategy)

from dataclasses import dataclass, asdict
from typing import Optional, Dict, List, Set, Any, Tuple

# --- Data Structures and API Client ---
@dataclass
class GeneResult:
    """Container for an Ensembl homology member."""
    ensembl_gene_id: str
    gene_symbol: str
    species: str
    ensembl_protein_id: str
    is_paralog: bool = False


class EnsemblClient:
    """
    Robust, rate-limited client for the Ensembl REST API with retry logic.
    Notes
    -----
    * Treats HTTP 400 as a 'clean miss' (e.g., symbol not defined in species).
    * Retries on 429/5xx with backoff; respects a minimal inter-request delay.
    """
    def __init__(self, rate_limit_delay: float = 0.1) -> None:
        self.base_url = "https://rest.ensembl.org"
        self.session = requests.Session()
        self.session.headers.update({"Accept": "application/json"})
        self.last_request_time = 0.0
        self.rate_limit_delay = rate_limit_delay

    def get(self, endpoint: str, params: Optional[Dict[str, Any]] = None) -> Optional[requests.Response]:
        """Perform a rate-limited GET request with bounded retries; return None on definitive failure."""
        # Respect a minimal delay between requests
        elapsed = time.time() - self.last_request_time
        if elapsed < self.rate_limit_delay:
            time.sleep(self.rate_limit_delay - elapsed)

        for attempt in range(3):
            try:
                full_url = self.base_url + endpoint
                r = self.session.get(full_url, params=params, timeout=30)
                self.last_request_time = time.time()

                # Explicit handling of rate limiting; float() accepts both '2' and '2.0'
                if r.status_code == 429:
                    retry_after = float(r.headers.get("Retry-After", 2))
                    logger.warning(f"Rate limit hit on {endpoint}. Retrying after {retry_after}s...")
                    time.sleep(retry_after)
                    continue

                # Treat 400 as a clean miss (e.g., symbol not defined for this species)
                if r.status_code == 400:
                    logger.warning(f"HTTP 400 for {full_url} (likely undefined symbol/species combo); skipping.")
                    return None

                r.raise_for_status()
                return r

            except requests.RequestException as e:
                logger.error(f"Ensembl API request failed on attempt {attempt + 1}: {e}")
                if attempt < 2:
                    time.sleep(2 ** attempt)  # backoff

        return None


class EnsemblGeneTreeFinder:
    """Discover gene family members using the Ensembl Homology API."""
    def __init__(self) -> None:
        self.client = EnsemblClient()

    def discover_family_members(self, seed_symbol: str, seed_species: str = "homo_sapiens") -> List[GeneResult]:
        """
        Query Ensembl for orthologues and paralogues of a seed gene.

        Parameters
        ----------
        seed_symbol : str
            HGNC-like gene symbol (e.g., 'COL1A1').
        seed_species : str
            Ensembl species name (default 'homo_sapiens').

        Returns
        -------
        List[GeneResult]
            Zero or more family members discovered for the seed.
        """
        endpoint = f"/homology/symbol/{seed_species}/{seed_symbol}"
        params = {"content-type": "application/json", "sequence": "protein", "type": "all", "format": "full"}

        logger.info(f"Discovering family for '{seed_symbol}'...")
        r = self.client.get(endpoint, params=params)
        if not r:
            logger.error(f"Failed to retrieve data for {seed_symbol}.")
            return []

        payload = r.json()
        data = payload.get("data", [])
        if not data:
            logger.warning(f"No homology data for {seed_symbol}.")
            return []

        results: List[GeneResult] = []
        seen: Set[str] = set()

        for group in data:
            for homology in group.get("homologies", []):
                # capture both source and target entries
                for side in ("source", "target"):
                    member = homology.get(side, {}) or {}
                    pid = member.get("protein_id")
                    if not pid:
                        continue
                    # safeguard for pathological objects
                    if member.get("cigar_line") is None:
                        continue
                    if pid in seen:
                        continue

                    species_raw = member.get("species", "unknown")
                    # Normalise species to Ensembl style 'genus_species'
                    species_norm = str(species_raw).replace(" ", "_").lower()

                    results.append(
                        GeneResult(
                            ensembl_gene_id=member.get("id", ""),
                            gene_symbol=member.get("symbol", seed_symbol),
                            species=species_norm,
                            ensembl_protein_id=pid,
                            is_paralog=("paralog" in str(homology.get("type", "")).lower()),
                        )
                    )
                    seen.add(pid)

        logger.info(f"   Discovered {len(results)} members for {seed_symbol}.")
        return results


def _normalise_species_from_organism(organism: str) -> str:
    """
    Convert UniProt 'Organism' field to Ensembl-like 'genus_species' key.

    Examples
    --------
    'Homo sapiens (Human)' -> 'homo_sapiens'
    'Gallus gallus (Chicken)' -> 'gallus_gallus'
    """
    s = str(organism)
    # Drop parenthetical synonyms, trailing/leading whitespace
    s = re.sub(r"\s*\(.*?\)", "", s).strip()
    # Collapse multiple spaces and convert
    s = re.sub(r"\s+", " ", s)
    parts = s.split(" ")
    if len(parts) >= 2:
        s = (parts[0] + "_" + parts[1]).lower()
    else:
        s = s.replace(" ", "_").lower()
    return s


# --- Main Logic ---
if 'working_df' in globals() and isinstance(working_df, pd.DataFrame) and not working_df.empty:
    logger.info("--- Starting Proactive Ensembl Gene ID Discovery (Hybrid Strategy) ---")
    finder = EnsemblGeneTreeFinder()

    # Identify canonical collagen symbols from your normalised column
    all_symbols = working_df['gene_symbol_norm'].dropna().astype(str).unique()
    canonical_seeds = {re.sub(r'(?:_[A-Z0-9_]+|[A-Z])$', '', s) for s in all_symbols}
    seed_genes = sorted([s for s in canonical_seeds if re.match(r'^COL\d+A\d+$', s)])

    logger.info(f"Identified {len(seed_genes)} canonical gene families to query.")

    # Discover across all seeds; seeds that return HTTP 400 (symbol undefined in human) yield no members
    all_discovered: List[GeneResult] = []
    for gene_symbol in tqdm(seed_genes, desc="Discovering Gene Families"):
        all_discovered.extend(finder.discover_family_members(gene_symbol))

    if all_discovered:
        # Build DataFrame robustly from dicts to avoid column/index surprises
        df_ensembl_map = pd.DataFrame([asdict(x) for x in all_discovered])

        # Defensive normalisation of expected columns
        expected_cols = {'ensembl_gene_id', 'gene_symbol', 'species', 'ensembl_protein_id', 'is_paralog'}
        missing = expected_cols.difference(df_ensembl_map.columns)
        if missing:
            logger.error(f"Missing expected columns in df_ensembl_map: {sorted(missing)}")
            # Create placeholders if truly absent (should not happen with asdict)
            for col in missing:
                df_ensembl_map[col] = pd.NA

        # 1) Map ENSP -> ENSG
        protein_to_gene_map: Dict[str, str] = {}
        for row in df_ensembl_map.itertuples(index=False):
            if getattr(row, 'ensembl_protein_id', None) and getattr(row, 'ensembl_gene_id', None):
                protein_to_gene_map[row.ensembl_protein_id] = row.ensembl_gene_id

        # 2) Map ENSP -> genus (first token of species)
        protein_to_genus_map: Dict[str, str] = {}
        for row in df_ensembl_map.itertuples(index=False):
            pid = getattr(row, 'ensembl_protein_id', None)
            sp = getattr(row, 'species', None)
            if pid and sp:
                genus = str(sp).split('_', 1)[0]
                protein_to_genus_map[pid] = genus

        # Ensure Cross-reference column exists
        if 'Cross-reference (Ensembl)' not in working_df.columns:
            logger.warning("'Cross-reference (Ensembl)' column not found. Creating it as empty.")
            working_df['Cross-reference (Ensembl)'] = ""

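        # Extract Ensembl protein accessions from Cross-reference (Ensembl).
        # Non-human Ensembl IDs carry a species prefix (e.g. ENSMUSP for mouse),
        # so match ENS + optional capital letters + P + digits, not just ENSP.
        def extract_ensp(xref: Any) -> Optional[str]:
            """Return the first Ensembl protein accession (version suffix dropped) in a cross-reference field."""
            m = re.search(r'\b(ENS[A-Z]*P\d+)', str(xref))
            return m.group(1) if m else None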

        working_df['ensembl_protein_id'] = working_df['Cross-reference (Ensembl)'].apply(extract_ensp)
        working_df['ensembl_id_pass1'] = working_df['ensembl_protein_id'].map(protein_to_gene_map)
        working_df['ensembl_genus_pass1'] = working_df['ensembl_protein_id'].map(protein_to_genus_map)

        # Build species+gene_symbol -> ENSG and -> genus maps without set_index()
        species_gene_to_id_map: Dict[Tuple[str, str], str] = {}
        species_gene_to_genus_map: Dict[Tuple[str, str], str] = {}
        for row in df_ensembl_map.itertuples(index=False):
            sp = getattr(row, 'species', None)
            gs = getattr(row, 'gene_symbol', None)
            eg = getattr(row, 'ensembl_gene_id', None)
            if sp and gs and eg:
                species_gene_to_id_map[(sp, gs)] = eg
                species_gene_to_genus_map[(sp, gs)] = str(sp).split('_', 1)[0]

        # Derive a species_key in working_df compatible with Ensembl style
        working_df['species_key'] = working_df['Organism'].apply(_normalise_species_from_organism)

        # Second-pass matching on (species_key, gene_symbol_norm)
        unmatched_mask = working_df['ensembl_id_pass1'].isna()
        if unmatched_mask.any():
            sp_keys = working_df.loc[unmatched_mask, 'species_key'].astype(str)
            gene_syms = working_df.loc[unmatched_mask, 'gene_symbol_norm'].astype(str)
            keys = list(zip(sp_keys, gene_syms))
            working_df.loc[unmatched_mask, 'ensembl_id_pass2'] = [species_gene_to_id_map.get(k) for k in keys]
            working_df.loc[unmatched_mask, 'ensembl_genus_pass2'] = [species_gene_to_genus_map.get(k) for k in keys]
        else:
            working_df['ensembl_id_pass2'] = pd.NA
            working_df['ensembl_genus_pass2'] = pd.NA

        # Consolidate passes
        working_df['ensembl_id'] = working_df['ensembl_id_pass1'].fillna(working_df['ensembl_id_pass2'])
        working_df['ensembl_genus'] = working_df['ensembl_genus_pass1'].fillna(working_df['ensembl_genus_pass2'])

        # Clean up (be permissive)
        working_df.drop(
            columns=['ensembl_id_pass1', 'ensembl_genus_pass1', 'ensembl_id_pass2', 'ensembl_genus_pass2'],
            inplace=True,
            errors='ignore'
        )

        matched_count = int(working_df['ensembl_id'].notna().sum())
        total = int(len(working_df))
        percentage = (matched_count / total) * 100 if total > 0 else 0.0
        logger.info(f"✅ Ensembl ID Discovery complete. Matched {matched_count}/{total} ({percentage:.1f}%) entries.")
    else:
        logger.warning("No gene family members were discovered. Synteny check will be skipped.")
else:
    logger.warning("Working dataframe is empty. Skipping Ensembl ID discovery.")


## Cell 23 – Gene Variant & Paralog Handling

This cell processes the working dataset to handle gene variants, such as the duplicated genes of teleost fish (e.g. `COL1A1A`/`COL1A1B`), the `_L`/`_S` homeologs of the allotetraploid *Xenopus laevis*, and numbered forms such as `_1`, `_2`. It normalizes these gene symbols to a common base name and adds columns (`variant_type`, `species_note`) to track them for downstream analysis.
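A minimal sketch of the suffix split, using the same regular expression as the cell below and assuming symbols have already been upper-cased (the pattern only recognises capital letters, so a lower-case `col1a1a` would not match until normalised). `split_variant` is a hypothetical helper, not part of the notebook.

```python
import re
from typing import Tuple

# Same suffix pattern as Cell 23: an '_'-delimited tag, or a single trailing capital letter.
SUFFIX_RE = re.compile(r"(?:_([A-Z0-9_]+)|([A-Z]))$")

def split_variant(symbol: str) -> Tuple[str, str]:
    """Return (base_gene, variant_type); variant_type is '' when no suffix matches."""
    m = SUFFIX_RE.search(symbol)
    if not m:
        return symbol, ""
    # Exactly one of the two groups is populated, depending on which branch matched.
    return symbol[:m.start()], m.group(1) or m.group(2)

for sym in ["COL1A1", "COL1A1A", "COL1A1_L", "COL1A1_2", "COL1A1_ISOX1"]:
    print(f"{sym:14} -> {split_variant(sym)}")
```

Because any single trailing capital letter counts as a suffix, the pattern relies on canonical collagen symbols ending in a digit (`COL1A1`, `COL4A5`).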

In [None]:
# ===== Cell 23 =====
# Gene Variant & Paralog Handling (Generalized)

def handle_gene_variants(df: pd.DataFrame) -> pd.DataFrame:
    """
    Identify and normalise gene variants carrying a suffix.
    Handles four patterns: COL1A1A (single trailing letter), COL1A1_L, COL1A1_2, COL1A1_ISOX1.
    """
    if df.empty or 'gene_symbol_norm' not in df.columns:
        return df

    logger.info("🧬 Processing gene variants/paralogs...")
    processed_df = df.copy()

    # Regex to find any known variant suffix pattern
    variant_mask = processed_df['gene_symbol_norm'].str.contains(r'(?:_[A-Z0-9_]+|[A-Z])$', na=False)

    if not variant_mask.any():
        logger.info("No gene variants with recognized suffixes detected.")
        return processed_df

    variant_df = processed_df[variant_mask].copy()

    # Extract the suffix as the variant type: group 0 holds an '_'-delimited tag,
    # group 1 a single trailing capital letter; exactly one of them is populated.
    suffix_parts = variant_df['gene_symbol_norm'].str.extract(r'(?:_([A-Z0-9_]+)|([A-Z]))$')
    variant_df['variant_type'] = suffix_parts[0].fillna(suffix_parts[1])
    # Strip the suffix to obtain the base gene name
    variant_df['base_gene'] = variant_df['gene_symbol_norm'].str.replace(r'(?:_[A-Z0-9_]+|[A-Z])$', '', regex=True)
    # Create a descriptive note
    variant_df['species_note'] = 'gene_variant_' + variant_df['variant_type']

    # Add the tracking columns if they don't exist yet
    if 'variant_type' not in processed_df.columns: processed_df['variant_type'] = pd.NA
    if 'species_note' not in processed_df.columns: processed_df['species_note'] = pd.NA

    # Write results back for the variant rows only (aligned on the index);
    # gene_symbol_norm is replaced by the base form for consistent grouping.
    processed_df.loc[variant_mask, 'variant_type'] = variant_df['variant_type']
    processed_df.loc[variant_mask, 'species_note'] = variant_df['species_note']
    processed_df.loc[variant_mask, 'gene_symbol_norm'] = variant_df['base_gene']

    num_fixed = variant_mask.sum()
    logger.info(f"✅ Gene variant processing complete: {num_fixed} mappings updated.")
    return processed_df

if 'working_df' in globals() and not working_df.empty:
    working_df = handle_gene_variants(working_df)
    working_df.to_csv(WORKING_SNAPSHOT, sep='\t', index=False)
    logger.info("📸 Updated working snapshot with variant/paralog fixes.")
else:
    logger.info("⚠️ No working dataset to process for variants/paralogs.")

## Cell 24 – Taxonomic Lineage Expansion & Genus Finalization

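This cell standardizes the taxonomic information for each entry. It splits the UniProt lineage names into a list, maps the lineage taxon IDs onto seven rank columns (`King_id`, `Phyl_id`, `Clas_id`, `Orde_id`, `Fami_id`, `Genu_id`, `Spec_id`), and infers `Family` and `Order` names from their standard suffixes (`-idae` and `-iformes`; the latter covers most fish and bird orders but not, for example, Primates or Rodentia).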

**Crucially, it creates the final, authoritative `cluster_genus` column.** It prefers the genus derived from Ensembl in Cell 22, because Ensembl species names follow a standard `genus_species` form, and falls back to the first word of the UniProt `Organism` string when no Ensembl genus is available. This gives every entry the most reliable genus available for downstream phylogenetic weighting and grouping.
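The genus coalescing can be sketched as follows. The rows are toy data, and the `.str.capitalize()` step is a suggested safeguard rather than something the source prescribes: Ensembl-derived genera come from lower-case `genus_species` keys while the UniProt fallback is capitalised, and mixed case would split one genus into two groups.

```python
import pandas as pd

# Toy rows: Ensembl genera are lower-case (from 'genus_species' keys); None = no Ensembl link.
df = pd.DataFrame({
    "Organism": ["Homo sapiens (Human)", "Xenopus laevis (African clawed frog)", "Danio rerio (Zebrafish)"],
    "ensembl_genus": ["homo", None, "danio"],
})

# Fallback: first word of the UniProt Organism string.
df["uniprot_genus"] = df["Organism"].str.split().str[0]

# Prefer the Ensembl genus, fall back to UniProt, then normalise case
# so that 'homo' and 'Homo' end up in the same group.
df["cluster_genus"] = df["ensembl_genus"].fillna(df["uniprot_genus"]).str.capitalize()
print(df["cluster_genus"].tolist())  # ['Homo', 'Xenopus', 'Danio']
```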

In [None]:
# ===== Cell 24 =====
# Taxonomic lineage expansion & authoritative genus finalization

RANK_CODES = ["King","Phyl","Clas","Orde","Fami","Genu","Spec"]

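def parse_taxonomic_lineage(lineage_ids_str: str) -> Dict[str, Optional[int]]:
    """
    Map the last seven taxon IDs of a UniProt lineage onto RANK_CODES.

    Heuristic: assumes the lineage ends with kingdom, phylum, class, order, family,
    genus and species in that order; lineages that include intermediate ranks
    (subphylum, superorder, clade, etc.) may be mis-assigned.
    """
    out = {k: None for k in RANK_CODES}
    if not isinstance(lineage_ids_str, str): return out
    ids = [int(x) for x in re.findall(r"\d+", lineage_ids_str)]
    if not ids: return out
    # Left-pad with None when fewer than seven IDs are available
    tail = ids[-7:] if len(ids) >= 7 else [None]*(7-len(ids)) + ids
    return {code: taxid for code, taxid in zip(RANK_CODES, tail)}

def _split_lineage_names(s: str) -> List[str]:
    """Split a ';'- or ','-separated lineage string into a list of stripped names."""
    if not isinstance(s, str) or not s.strip(): return []
    return [p.strip() for p in re.split(r"[;,]\s*", s) if p]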

if 'working_df' in globals() and not working_df.empty:
    logger.info("Expanding taxonomic lineages and finalizing genus...")

    # Standard lineage expansion from UniProt strings; the fallback Series share
    # working_df's index so the results align when the columns are absent
    working_df["Lineage_Names"] = working_df.get("Taxonomic lineage", pd.Series("", index=working_df.index)).apply(_split_lineage_names)
    parsed_ids = working_df.get("Taxonomic lineage (Ids)", pd.Series("", index=working_df.index)).astype(str).apply(parse_taxonomic_lineage)
    for code in RANK_CODES:
        working_df[f"{code}_id"] = parsed_ids.apply(lambda d: d.get(code))

    # --- FIX: Create the authoritative 'cluster_genus' column ---
    # 1. Parse Genus from UniProt 'Organism' string as a fallback
    working_df['uniprot_genus'] = working_df['Organism'].str.split().str[0]

    # 2. Coalesce, prioritizing the Ensembl-derived genus
    # If 'ensembl_genus' exists and is not null, use it; otherwise, use 'uniprot_genus'.
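    if 'ensembl_genus' in working_df.columns:
        working_df['cluster_genus'] = working_df['ensembl_genus'].fillna(working_df['uniprot_genus'])
    else:
        logger.warning("'ensembl_genus' column not found, falling back to UniProt only.")
        working_df['cluster_genus'] = working_df['uniprot_genus']
    # Ensembl genera are lower-case ('homo'), UniProt ones capitalised ('Homo'); unify the case
    working_df['cluster_genus'] = working_df['cluster_genus'].str.capitalize()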

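    # 3. Infer Family and Order from their name suffixes (heuristic: '-idae' marks families;
    #    '-iformes' covers most fish and bird orders but misses e.g. Primates or Rodentia)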
    working_df['Family'] = working_df['Lineage_Names'].apply(lambda L: next((n for n in reversed(L) if n.endswith("idae")), ""))
    working_df['Order'] = working_df['Lineage_Names'].apply(lambda L: next((n for n in reversed(L) if n.endswith("iformes")), ""))

    logger.info(f"Taxonomy expanded. Authoritative 'cluster_genus' column created.")
else:
    logger.info("No working rows; skipping taxonomy enrichment.")

## Cell 25 – Final Input Checks

This final pre-processing step ensures that the `working_df` DataFrame contains the essential columns (`Entry`, `Sequence`) and that sequences are not empty, preventing errors in downstream cells.

In [None]:
# ===== Cell 25 =====
# Final filters and input checks

if 'working_df' in globals() and not working_df.empty:
    required_cols = ['Entry', 'Sequence']
    missing_cols = [c for c in required_cols if c not in working_df.columns]
    if missing_cols:
        logger.error(f"FATAL: working_df is missing required columns: {missing_cols}. Halting execution.")
        # Stop execution if critical columns are missing (raise rather than assert,
        # because assertions are skipped when Python runs with -O)
        raise RuntimeError(f"Missing critical columns in working_df: {missing_cols}")
    else:
        # Keep only rows with a non-missing, non-empty sequence
        initial_rows = len(working_df)
        seq = working_df['Sequence']
        working_df = working_df[seq.notna() & (seq.astype(str).str.len() > 0)].copy()
        final_rows = len(working_df)
        logger.info(f"Post-filter rows: {initial_rows} -> {final_rows} (removed empty sequences).")
else:
    logger.warning("working_df is empty or not defined. Downstream cells may fail.")

## Cell 26 – Persist Session Rejected Entries

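Identifies entries from the initial dataset (`full_df`) that were removed during pre-processing, for example by the gene or taxon filters or by the empty-sequence check in Cell 25. These rejected IDs are saved to a run-specific file and merged into the durable master rejection list. QC failures are recorded separately, with their specific reasons, in Cell 32.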
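The rejection bookkeeping reduces to a set difference followed by an append-and-deduplicate. A minimal sketch, with a hypothetical `merge_rejections` helper and toy IDs:

```python
from typing import Iterable
import pandas as pd

def merge_rejections(master: pd.DataFrame, all_ids: Iterable[str], kept_ids: Iterable[str],
                     reason: str, run_id: str) -> pd.DataFrame:
    """Append IDs in all_ids but not in kept_ids; the newest record per Entry wins."""
    rejected = sorted(set(all_ids) - set(kept_ids))
    new = pd.DataFrame({"Entry": rejected,
                        "reason": [reason] * len(rejected),
                        "run_id": [run_id] * len(rejected)})
    merged = pd.concat([master, new], ignore_index=True)
    return merged.drop_duplicates(subset=["Entry"], keep="last").reset_index(drop=True)

# Toy master list from an earlier run, then a new session.
master = pd.DataFrame({"Entry": ["A1"], "reason": ["low_gxy_content"], "run_id": ["run_001"]})
out = merge_rejections(master, ["A1", "B2", "C3"], ["C3"], "filtered_by_gene_or_taxon", "run_002")
print(out)
```

With `keep="last"`, an entry rejected again has its reason and `run_id` overwritten by the most recent run, which is also how the cell below behaves.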

In [None]:
# ===== Cell 26 =====
# Persist session rejected IDs

logger.info("Archiving entries rejected during pre-processing...")

if 'full_df' in globals() and 'working_df' in globals():
    all_entries = set(full_df['Entry'].dropna())
    accepted_entries = set(working_df['Entry'].dropna())
    session_rejected_ids = sorted(list(all_entries - accepted_entries))

    if session_rejected_ids:
        logger.info(f"{len(session_rejected_ids)} entries were rejected in this session.")

        # Save session-specific rejection list
        with open(SESSION_REJECTED_PATH, 'w') as f:
            for acc in session_rejected_ids:
                f.write(f"{acc}\n")
        logger.info(f"Session rejection list saved to: {SESSION_REJECTED_PATH}")

        # Update the master rejection TSV
        new_rejections = pd.DataFrame({
            "Entry": session_rejected_ids,
            "reason": "filtered_by_gene_or_taxon",
            "run_id": RUN_ID
        })

        # Load, append, deduplicate, and save master list
        master_rejected = pd.DataFrame(columns=["Entry","reason","run_id"])
        if REJECTED_IDS_PATH.exists():
            master_rejected = safe_read_tsv(REJECTED_IDS_PATH)

        updated_master = pd.concat([master_rejected, new_rejections], ignore_index=True)
        updated_master.drop_duplicates(subset=['Entry'], keep='last', inplace=True)
        updated_master.to_csv(REJECTED_IDS_PATH, sep='\t', index=False)

        new_count = len(updated_master) - len(master_rejected)
        logger.info(f"Master rejection TSV updated with {new_count} new entries. Total: {len(updated_master)}.")
    else:
        logger.info("No new entries were rejected in this session.")
else:
    logger.warning("full_df or working_df not available; cannot compute session rejections.")

# **Part 3: Chain Identification & Quality Pre-selection**

## Cell 31 – Detect Main G-X-Y Collagenous Chain

Scans each protein sequence for contiguous, in-frame runs of G-X-Y triplets (at least `MIN_GXY_TRIPLETS` long), which define the primary collagenous domain. It keeps the longest run as the main chain, computes a preliminary quality score from its length, and flags any cysteine inside it (`cys_in_helix`).
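The scan can be restated as a small standalone function for experimenting with thresholds. `gxy_runs` is a hypothetical helper that mirrors the loop below but also returns the triplet count. With 1-based inclusive coordinates, a run covers `end - start + 1` residues, so its triplet count is `(end - start + 1) // 3`.

```python
from typing import List, Tuple

def gxy_runs(seq: str, min_triplets: int) -> List[Tuple[int, int, int]]:
    """Return (start, end, n_triplets) for G-X-Y runs; start/end are 1-based and inclusive."""
    runs = []
    i = 0
    while i <= len(seq) - 3:
        if seq[i] != "G":
            i += 1
            continue
        # Step one triplet at a time while the next triplet also starts with 'G'.
        j = i
        while j + 2 < len(seq) and seq[j] == "G":
            j += 3
        n = (j - i) // 3
        if n >= min_triplets:
            runs.append((i + 1, j, n))   # 0-based [i, j) becomes 1-based [i+1, j]
        i = j
    return runs

seq = "MK" + "GPP" * 4 + "AA" + "GPA" * 2 + "W"
print(gxy_runs(seq, 3))  # [(3, 14, 4)]
print(gxy_runs(seq, 2))  # [(3, 14, 4), (17, 22, 2)]
```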

In [None]:
# ===== Cell 31 =====
# Identify main G–X–Y chain(s) (with robust column preservation)

def find_main_chain_info(sequence: str) -> pd.Series:
    """
    Analyzes a single sequence to find the main GXY chain and its properties.
    Returns a pandas Series, suitable for use with DataFrame.apply.
    """
    if not isinstance(sequence, str):
        return pd.Series({
            'main_chain_segments': None,
            'quality_score': 0.0,
            'quality_flags': 'missing_sequence'
        })

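    # Greedy scan: from each 'G', step one triplet at a time while the next
    # triplet also starts with 'G'. Runs are recorded 1-based and inclusive.
    segs = []
    i = 0
    while i <= len(sequence) - 3:
        if sequence[i] == 'G':
            j = i
            while j + 2 < len(sequence) and sequence[j] == 'G':
                j += 3
            # j - i is always a multiple of 3 here, so (j - i) // 3 counts whole triplets
            if (j - i) // 3 >= MIN_GXY_TRIPLETS:
                segs.append({'start': i + 1, 'end': j})
            i = j
        else:
            i += 1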

    if not segs:
        return pd.Series({
            'main_chain_segments': None,
            'quality_score': 0.0,
            'quality_flags': 'no_gxy_chain_found'
        })

    main_seg = max(segs, key=lambda x: x['end'] - x['start'])
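    # 'start'/'end' are 1-based and inclusive, so the run spans end - start + 1 residues
    triplets = (main_seg['end'] - main_seg['start'] + 1) // 3
    # Every kept segment has >= MIN_GXY_TRIPLETS triplets, so this score saturates at 100
    qscore = float(min(100.0, 100.0 * triplets / max(1, MIN_GXY_TRIPLETS)))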
    has_cys = 'C' in sequence[main_seg['start']-1:main_seg['end']]

    return pd.Series({
        'main_chain_segments': [main_seg],
        'quality_score': qscore,
        'quality_flags': "cys_in_helix" if has_cys else ""
    })

if 'working_df' in globals() and not working_df.empty:
    logger.info("Identifying main GXY chains using robust apply method...")

    # Series.progress_apply with a Series-returning function yields a DataFrame
    # (one row per entry) aligned on working_df's index
    tqdm.pandas(desc="Finding GXY Chains")
    chain_info_df = working_df['Sequence'].progress_apply(find_main_chain_info)

    # Join the new columns back; drop stale copies first so re-running the cell does not raise
    chain_df = working_df.drop(columns=chain_info_df.columns, errors='ignore').join(chain_info_df)

    logger.info(f"Chain identification complete. All original columns preserved.")
else:
    chain_df = pd.DataFrame()
    logger.warning("Working dataframe is empty. Skipping chain identification.")

## Cell 32 – Quality Control and Rejection Persistence

Applies quality-control filters based on sequence length (`MIN_LEN_AA`), G-X-Y content (`MIN_GXY_TRIPLETS`), and the presence of cysteine in the helix. Entries that fail QC are logged and merged into the master rejection list with a `;`-separated failure reason. Pass/fail splitting uses a boolean mask over `chain_df`, so all original columns are preserved in the `df_high_quality` output.
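A row-wise sketch of the QC rules and the boolean-mask split, on toy rows. The thresholds here are illustrative values, not the notebook's configured `MIN_GXY_TRIPLETS` and `MIN_LEN_AA`, and `qc_reasons` is a hypothetical helper. The triplet count uses the inclusive span `end - start + 1`, matching the 1-based inclusive segment coordinates.

```python
from typing import Optional
import pandas as pd

MIN_GXY_TRIPLETS = 3   # illustrative thresholds only
MIN_LEN_AA = 20

def qc_reasons(df: pd.DataFrame) -> pd.Series:
    """Return a ';'-joined failure string per row, or None when the row passes."""
    def reasons_for(row: pd.Series) -> Optional[str]:
        segs = row["main_chain_segments"]
        if not isinstance(segs, list) or not segs:   # covers None, NaN and []
            return "no_main_chain"
        seg = segs[0]
        out = []
        if (seg["end"] - seg["start"] + 1) // 3 < MIN_GXY_TRIPLETS:
            out.append("low_gxy_content")
        if len(str(row["Sequence"])) < MIN_LEN_AA:
            out.append("short_sequence")
        if "cys_in_helix" in str(row["quality_flags"]):
            out.append("cys_in_helix")
        return ";".join(out) or None
    return df.apply(reasons_for, axis=1)

chain_df = pd.DataFrame({
    "Entry": ["A", "B", "C"],
    "Sequence": ["MK" + "GPP" * 4 + "AAGPAGPAW", "MKLLV", "GPPGCPGPP"],
    "main_chain_segments": [[{"start": 3, "end": 14}], None, [{"start": 1, "end": 9}]],
    "quality_flags": ["", "no_gxy_chain_found", "cys_in_helix"],
})
reasons = qc_reasons(chain_df)
passed = chain_df[reasons.isna()]
failed = chain_df[reasons.notna()].assign(failure_reasons=reasons)
print(passed["Entry"].tolist(), failed["failure_reasons"].tolist())
```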

In [None]:
# ===== Cell 32 =====
# QC pass/fail and rejection persistence (with robust column preservation)

df_high_quality = pd.DataFrame()
df_failed_qc = pd.DataFrame()

if 'chain_df' in globals() and not chain_df.empty:
    reasons_list = []
    for _, r in chain_df.iterrows():
        reasons = []

        # Defensively check for a valid segment list before subscripting
        segments = r.get('main_chain_segments')

        # isinstance() covers None and NaN; `not segments` covers an empty list.
        # (pd.isna() is avoided here because it returns an array when given a list.)
        if not isinstance(segments, list) or not segments:
            reasons.append('no_main_chain')
        else:
            seg = segments[0]
            # 'start'/'end' are 1-based inclusive, so the run spans end - start + 1 residues
            if (seg['end'] - seg['start'] + 1) // 3 < MIN_GXY_TRIPLETS:
                reasons.append('low_gxy_content')
            if len(str(r.get('Sequence',''))) < MIN_LEN_AA:
                reasons.append('short_sequence')
            # quality_flags comes from Cell 31; str() guards against NaN
            if 'cys_in_helix' in str(r.get('quality_flags', '')):
                reasons.append('cys_in_helix')

        reasons_list.append(';'.join(reasons) if reasons else None)

    # Add the failure reasons as a new column to the original dataframe
    chain_df['failure_reasons'] = reasons_list

    # Use boolean indexing to create the pass/fail dataframes, preserving all columns
    pass_mask = chain_df['failure_reasons'].isna()
    df_high_quality = chain_df[pass_mask].copy()
    df_failed_qc = chain_df[~pass_mask].copy()

    # Clean up the reasons column from the passing dataframe
    df_high_quality.drop(columns=['failure_reasons'], inplace=True, errors='ignore')

    logger.info(f"QC results: {len(df_high_quality)} passed, {len(df_failed_qc)} failed.")

    if not df_failed_qc.empty:
        qc_rejections = df_failed_qc[['Entry','failure_reasons']].rename(columns={'failure_reasons':'reason'})
        qc_rejections['run_id'] = RUN_ID

        master_rejected = safe_read_tsv(REJECTED_IDS_PATH)
        updated_master = pd.concat([master_rejected, qc_rejections], ignore_index=True)
        updated_master.drop_duplicates(subset=['Entry'], keep='last', inplace=True)
        updated_master.to_csv(str(REJECTED_IDS_PATH), sep='\t', index=False)

        logger.info(f"{len(qc_rejections)} QC failures merged into master rejection TSV.")
else:
    logger.info("No chain candidates to perform QC on.")

In [None]:
# ===== Cell 32A =====
# Diagnostic Check on df_high_quality

if 'df_high_quality' in globals() and not df_high_quality.empty:
    print("--- First 5 rows of df_high_quality ---")
    display(df_high_quality.head())
    print("\n--- Columns in df_high_quality ---")
    print(df_high_quality.columns.tolist())
else:
    print("df_high_quality is empty or not defined.")

## Cell 33 – Ensembl Genus Pre-flight Filter

To make the exon mapping step more efficient, this cell filters the high-quality candidate list by genus before any mapping is attempted. It fetches the species list from the Ensembl REST `/info/species` endpoint (which, by default, covers the vertebrate division) and reduces it to a set of **genera**. It then removes any candidate from `df_high_quality` whose genus (taken from the authoritative **`cluster_genus`** column) is not in this set. As a result, exons are only mapped for species that have at least one species of the same genus in Ensembl.
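The pre-flight check boils down to a genus set and an `isin` mask. A sketch with an offline, hand-written list standing in for the `name` fields of the `/info/species` response (the species shown are examples; the real list comes from the API), and a hypothetical `genera_from_species` helper:

```python
from typing import Iterable, Set
import pandas as pd

def genera_from_species(names: Iterable[str]) -> Set[str]:
    """Lower-case genus (first '_' token) of each Ensembl-style 'genus_species' name."""
    return {n.split("_", 1)[0].lower() for n in names if n and "_" in n}

# Offline stand-in for the 'name' fields of the /info/species response.
species_names = ["homo_sapiens", "mus_musculus", "canis_lupus_familiaris", "danio_rerio"]
ensembl_genera = genera_from_species(species_names)

candidates = pd.DataFrame({"Entry": ["P1", "P2", "P3"],
                           "cluster_genus": ["Homo", "Xenopus", None]})
# Missing genera become NaN after .str.lower(), and NaN is never in the set.
keep = candidates["cluster_genus"].str.lower().isin(ensembl_genera)
print(candidates.loc[keep, "Entry"].tolist())  # ['P1']
```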

In [None]:
# ===== Cell 33 =====
# Ensembl Genus Pre-flight Filter (with Production-Grade Client)

from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

def _session_with_retries() -> requests.Session:
    """Return a Session that retries, with exponential backoff, on connection errors and HTTP 408/429/500/502/503/504."""
    sess = requests.Session()
    retry = Retry(total=5, backoff_factor=0.5, status_forcelist=(408, 429, 500, 502, 503, 504))
    adapter = HTTPAdapter(max_retries=retry)
    sess.mount("https://", adapter)
    sess.headers.update({"Accept": "application/json"})
    return sess

def get_ensembl_genera_set() -> Set[str]:
    """
    Retrieves the species list from Ensembl and returns a set of all unique,
    lowercase genus names.
    """
    if 'ensembl_genera_set' in globals() and globals().get('ensembl_genera_set'):
        logger.info("Using cached Ensembl genera list.")
        return globals()['ensembl_genera_set']
    logger.info("Fetching complete species list from Ensembl to build genus filter...")
    sess = _session_with_retries()
    try:
        r = sess.get(f"{ENSEMBL_BASE}/info/species", timeout=60)
        r.raise_for_status()
        species_list = r.json().get("species", [])
        genera_set = {name.split('_')[0].lower() for s in species_list if (name := s.get('name')) and '_' in name}
        logger.info(f"Successfully identified {len(genera_set)} unique genera in Ensembl.")
        globals()['ensembl_genera_set'] = genera_set
        return genera_set
    except (requests.RequestException, KeyError, ValueError) as e:
        logger.error(f"Failed to fetch Ensembl species data: {e}")
        logger.warning("Proceeding without genus pre-flight check.")
        globals()['ensembl_genera_set'] = set()
        return set()

# --- Main Filtering Logic ---
ensembl_genera = get_ensembl_genera_set()

if "df_high_quality" in globals() and not df_high_quality.empty and ensembl_genera:
    initial_count = len(df_high_quality)
    logger.info(f"Applying Ensembl genus pre-filter to {initial_count} candidates...")

    # --- FIX: Use the authoritative 'cluster_genus' column for filtering ---
    if "cluster_genus" in df_high_quality.columns:
        # Filter based on the presence of the genus in Ensembl
        mask = df_high_quality["cluster_genus"].str.lower().isin(ensembl_genera)
        df_high_quality_filtered = df_high_quality[mask].copy()

        removed_count = initial_count - len(df_high_quality_filtered)
        logger.info(f"Removed {removed_count} entries for genera not present in Ensembl.")
        logger.info(f"High-quality candidates remaining for mapping: {len(df_high_quality_filtered)}")
        df_high_quality = df_high_quality_filtered
    else:
        logger.warning("Cannot apply filter: 'cluster_genus' column not found in df_high_quality.")
else:
    logger.info("Skipping Ensembl genus pre-filter (no data or failed API call).")

## Cell 34 – Synteny-Based Pseudogene Filter

This cell introduces a critical pre-filtering step to improve the quality of candidates for exon mapping by removing likely pseudogenes and mis-annotated entries based on their genomic location.

### Rationale: Leveraging Evolutionary Synteny

This filter relies on the **`ensembl_id`** column populated by the proactive discovery in **Cell 22**. It checks whether each gene lies on an assembled chromosome or on an unplaced scaffold or patch, so that only stable genome models are used for the seed consensus. The check itself looks only at the gene's sequence region; it does not compare neighbouring genes.

### Process

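*   For each high-quality candidate sequence, the gene's `ensembl_id` is used to query the Ensembl REST API (`/lookup/id`) for the sequence region it lies on. Results are cached per UniProt entry in `SYNTENY_CACHE_TSV`, so only new entries hit the API.
*   Sequences located on unplaced scaffolds or patches are filtered out, as are entries without an `ensembl_id` or whose lookup failed.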

### Outcome

This step produces the final, highly purified "seed set" for building the consensus architectural template.
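A sketch of the placement test and of a caching rule that stores only definitive answers, so that a network hiccup does not permanently exclude an entry. `classify_region` and `cacheable` are hypothetical helpers; the status labels match those used in the cell below, and the region names are illustrative rather than taken from a real lookup.

```python
from typing import Dict, Optional

# Status labels used by the synteny cell; TRANSIENT ones describe a failed request, not the genome.
TRANSIENT = {"network_error", "unknown_error"}
REJECTED = TRANSIENT | {"unplaced_scaffold", "lookup_failed"}

def classify_region(seq_region_name: Optional[str]) -> str:
    """Name-based placement test: region names mentioning 'scaffold' or 'patch' are rejected."""
    if not seq_region_name:
        return "lookup_failed"
    name = seq_region_name.lower()
    if "scaffold" in name or "patch" in name:
        return "unplaced_scaffold"
    return seq_region_name

def cacheable(results: Dict[str, str]) -> Dict[str, str]:
    """Keep only definitive answers so transient failures are retried on the next run."""
    return {entry: loc for entry, loc in results.items() if loc not in TRANSIENT}

# Illustrative region names, not real lookups.
lookups = {"P1": classify_region("17"),
           "P2": classify_region("Scaffold_42"),
           "P3": "network_error"}
to_cache = cacheable(lookups)                                  # P3 is left out
kept = [e for e, loc in lookups.items() if loc not in REJECTED]
print(to_cache, kept)
```

Rejecting transient failures for the current run while leaving them out of the cache keeps the run conservative without making the exclusion permanent.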

In [None]:
# ===== Cell 34 =====
# Synteny-Based Sanity Check to Filter Pseudogenes (with Caching)

#@markdown #### **Synteny Filter Settings**
ENABLE_SYNTENY_FILTER = True #@param {type:"boolean"}

def check_gene_location(ensembl_id: str) -> Optional[str]:
    """
    Return the seq_region (chromosome) name for an Ensembl gene ID, or a status label:
    'unplaced_scaffold', 'lookup_failed', 'network_error' or 'unknown_error'.
    """
    if not isinstance(ensembl_id, str) or not ensembl_id.startswith("ENS"):
        return None
    url = f"{ENSEMBL_BASE}/lookup/id/{ensembl_id}?content-type=application/json"
    try:
        # Short timeout because we expect a fast response for valid IDs
        response = requests.get(url, timeout=15)
        if response.status_code != 200:
            # e.g. 400 for an unknown ID, 429 when rate-limited
            return 'unknown_error'
        seq_region = str(response.json().get('seq_region_name') or '')
    except (requests.RequestException, ValueError):
        return 'network_error'
    if not seq_region:
        return 'lookup_failed'
    # Name-based heuristic: only regions whose names mention a scaffold or patch are flagged
    if 'scaffold' in seq_region.lower() or 'patch' in seq_region.lower():
        return "unplaced_scaffold"
    return seq_region

if ENABLE_SYNTENY_FILTER and 'df_high_quality' in globals() and not df_high_quality.empty:
    if 'ensembl_id' not in df_high_quality.columns:
        logger.error("FATAL: Prerequisite 'ensembl_id' column not found. Skipping Synteny Filter.")
    else:
        logger.info("🧬 Applying synteny filter with caching...")

        # --- Caching Logic ---
        # 1. Load existing cache if it exists
        cached_locations = {}
        if SYNTENY_CACHE_TSV.exists():
            try:
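                # dtype=str keeps chromosome names such as '1' or 'X' as strings, matching fresh API results
                cache_df = pd.read_csv(SYNTENY_CACHE_TSV, sep='\t', dtype=str)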
                # Create a fast lookup dictionary from Entry -> chromosome
                cached_locations = pd.Series(cache_df.chromosome.values, index=cache_df.Entry).to_dict()
                logger.info(f"Loaded {len(cached_locations)} locations from synteny cache.")
            except Exception as e:
                logger.warning(f"Could not load synteny cache, will rebuild. Error: {e}")

        # 2. Identify entries that need to be checked via API
        entries_to_check_df = df_high_quality[
            ~df_high_quality['Entry'].isin(cached_locations.keys())
        ].dropna(subset=['ensembl_id']).copy()

        # 3. Perform the slow API calls ONLY for new entries
        if not entries_to_check_df.empty:
            logger.info(f"Found {len(entries_to_check_df)} new entries to check for synteny.")
            tqdm.pandas(desc="Checking New Gene Locations")
            new_locations = entries_to_check_df['ensembl_id'].progress_apply(check_gene_location)

            # 4. Update the cache file with the new results.
            #    Transient 'network_error' results are kept in memory for this run
            #    but not written to disk, so they are retried on the next run.
            new_results_df = pd.DataFrame({
                'Entry': entries_to_check_df['Entry'],
                'chromosome': new_locations
            })
            persistable_df = new_results_df[new_results_df['chromosome'] != 'network_error']
            # Append without a header if the cache file already exists
            persistable_df.to_csv(
                SYNTENY_CACHE_TSV,
                sep='\t',
                index=False,
                mode='a',
                header=not SYNTENY_CACHE_TSV.exists()
            )
            logger.info(f"Appended {len(persistable_df)} new results to synteny cache.")

            # Add all new results (including transient failures) to the in-memory lookup
            cached_locations.update(new_results_df.set_index('Entry')['chromosome'].to_dict())
        else:
            logger.info("No new entries to check; all locations found in cache.")

        # 5. Map all locations (cached + new) to the main dataframe
        initial_count = len(df_high_quality)
        df_high_quality['chromosome'] = df_high_quality['Entry'].map(cached_locations)

        # 6. Filter out entries on scaffolds, with failed lookups, or with no Ensembl ID
        failure_codes = ['unplaced_scaffold', 'lookup_failed', 'network_error', 'unknown_error']
        valid_mask = df_high_quality['chromosome'].notna() & ~df_high_quality['chromosome'].isin(failure_codes)
        df_high_quality = df_high_quality[valid_mask].copy()

        removed_count = initial_count - len(df_high_quality)
        logger.info(f"Synteny filter complete. Removed {removed_count} entries.")
        logger.info(f"High-quality candidates remaining for mapping: {len(df_high_quality)}")
else:
    logger.info("☑️ Skipping synteny filter.")
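The caching steps in the synteny filter follow a cache-aside pattern: reuse every location already resolved, query only the cache misses, then drop entries whose location is a failure code. A minimal pure-Python sketch of that flow, assuming a hypothetical `lookup` callable in place of `check_gene_location`; `resolve_with_cache`, `fake_lookup`, and the accession/chromosome values are illustrative only.

```python
from typing import Callable, Dict, Iterable, Tuple

# Failure codes used by the synteny filter; entries carrying any of these are dropped
FAILURE_CODES = {"unplaced_scaffold", "lookup_failed", "network_error", "unknown_error"}


def resolve_with_cache(
    entries: Iterable[str],
    cache: Dict[str, str],
    lookup: Callable[[str], str],
) -> Tuple[Dict[str, str], Dict[str, str]]:
    """Return (new_results, passing).

    new_results holds only the entries that had to be looked up; passing maps every
    entry whose (cached or new) location is not a failure code to that location.
    """
    unique = list(dict.fromkeys(entries))  # de-duplicate while keeping order
    new_results = {e: lookup(e) for e in unique if e not in cache}  # calls only for cache misses
    merged = {**cache, **new_results}
    passing = {e: merged[e] for e in unique if merged[e] not in FAILURE_CODES}
    return new_results, passing


# Toy accessions and return values, for illustration only
calls = []

def fake_lookup(entry: str) -> str:
    calls.append(entry)
    return {"P02452": "17", "Q9XYZ1": "unplaced_scaffold"}.get(entry, "lookup_failed")

new, passing = resolve_with_cache(["P02461", "P02452", "Q9XYZ1"], {"P02461": "2"}, fake_lookup)
print(calls)    # ['P02452', 'Q9XYZ1']  (the cached P02461 is not re-queried)
print(passing)  # {'P02461': '2', 'P02452': '17'}
```

Only `new` would be appended to the cache file; the scaffold hit is cached too, so it is filtered again on later runs without another API call.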

In [None]:
# # ===== TEMPORARY PATCH CELL =====
# # Save in-memory synteny results to the new cache file.
# # This cell can be deleted after you run it successfully one time.

# if 'df_high_quality' in globals() and 'chromosome' in df_high_quality.columns:
#     logger.info("PATCH: Saving in-memory synteny results to cache file...")

#     # Select the two columns needed for the cache
#     synteny_results_to_save = df_high_quality[['Entry', 'chromosome']].copy()

#     # Drop any rows where the lookup might have failed or was skipped
#     synteny_results_to_save.dropna(subset=['chromosome'], inplace=True)

#     # Ensure there are no duplicate entries
#     synteny_results_to_save.drop_duplicates(subset=['Entry'], keep='first', inplace=True)

#     try:
#         # Write the results to the cache file
#         synteny_results_to_save.to_csv(SYNTENY_CACHE_TSV, sep='\t', index=False)
#         logger.info(f"✅ PATCH SUCCESS: Saved {len(synteny_results_to_save)} results to {SYNTENY_CACHE_TSV.name}")

#         # Optional: Display the first few rows to verify
#         print("\n--- Sample of saved cache data ---")
#         display(synteny_results_to_save.head())

#     except Exception as e:
#         logger.error(f"PATCH FAILED: Could not write to file. Error: {e}")
#         logger.error("Please ensure you have re-run Cell 12 to define the SYNTENY_CACHE_TSV path.")

# else:
#     logger.warning("PATCH SKIPPED: `df_high_quality` with 'chromosome' column not found in memory.")
#     logger.warning("This means the run was stopped before the API calls were made.")
#     logger.warning("You will need to let the new caching version of Cell 34 run once to build the cache.")

## Cell 34A – Seed Set Quality Assessment

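This cell provides a quick diagnostic profile of the final "seed set" of high-quality candidates. The goal is to verify that, although the set contains relatively few sequences, it still spans many genera, because phylogenetic diversity is crucial for building a robust consensus template. The diversity score is the number of unique genera divided by the number of entries, expressed as a percentage.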

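A worked example of the diversity score on a toy seed set (the entries and genus labels are invented for illustration): five entries drawn from three genera give 3 / 5 × 100 = 60%.

```python
import pandas as pd

# Toy seed set: five entries drawn from three genera (values are illustrative only)
seed = pd.DataFrame({
    "Entry": ["A1", "A2", "A3", "A4", "A5"],
    "cluster_genus": ["Bos", "Bos", "Homo", "Gallus", "Homo"],
})

total_entries = len(seed)
unique_genera = seed["cluster_genus"].nunique()
diversity_score = (unique_genera / total_entries) * 100  # 3 / 5 * 100

print(f"Phylogenetic Diversity Score (Genera/Entries): {diversity_score:.1f}%")  # 60.0%
print(seed["cluster_genus"].value_counts().head(15))
```

A score of 100% means every entry comes from a different genus; low scores flag a seed set dominated by a few well-sampled genera.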
In [None]:
# ===== Cell 34A =====
# Diagnostic Profiling of the Final Seed Set

if 'df_high_quality' in globals() and not df_high_quality.empty:
    logger.info("--- 📊 Profiling Final Seed Set for Mapping ---")

    total_entries = len(df_high_quality)
    unique_species = df_high_quality['Organism'].nunique()
    # Count genera using the authoritative 'cluster_genus' column assigned in Cell 24
    unique_genera = df_high_quality['cluster_genus'].nunique()

    logger.info(f"Total Entries in Seed Set: {total_entries}")
    logger.info(f"Unique Species Represented: {unique_species}")
    logger.info(f"Unique Genera Represented: {unique_genera}")

    if total_entries > 0:
        diversity_score = (unique_genera / total_entries) * 100
        logger.info(f"Phylogenetic Diversity Score (Genera/Entries): {diversity_score:.1f}%")

    print("\n--- Top 15 Genera in Seed Set ---")
    # Value counts also use the authoritative 'cluster_genus' column
    display(df_high_quality['cluster_genus'].value_counts().head(15))

else:
    logger.warning("Seed set (df_high_quality) is empty. No profile to generate.")

# **Part 4: Exon Mapping & Architecture Definition**

## Cell 40 – Exon Mapping Safety & Validation Utilities

This cell defines helper functions for robust, incremental exon mapping. It includes utilities for memory monitoring, validating the integrity of cache files, loading already-mapped accessions to resume runs, and atomically merging new results into the cache with backups.
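The write-then-swap idea behind `atomic_merge_into_cache` can be shown on its own. The sketch below is an illustrative variant, not the notebook's exact procedure: the hypothetical helper `atomic_write_text` writes to a temporary file in the target's directory, flushes it to disk, and swaps it in with `os.replace`, which replaces the target in a single rename when both paths are on the same filesystem. Assumption to note: network or FUSE-mounted storage (such as a mounted Google Drive) may not guarantee the same atomicity as a local disk.

```python
import os
import tempfile
from pathlib import Path


def atomic_write_text(target: Path, text: str) -> None:
    """Write text so readers see either the old file or the new one, never a partial write."""
    target = Path(target)
    # Create the temp file next to the target so os.replace stays on one filesystem
    fd, tmp_name = tempfile.mkstemp(dir=target.parent, prefix=target.name, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as fh:
            fh.write(text)
            fh.flush()
            os.fsync(fh.fileno())  # make sure the bytes reach disk before the swap
        os.replace(tmp_name, target)  # single rename; overwrites an existing target
    except BaseException:
        # Remove the temp file if anything failed before the swap
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
        raise


with tempfile.TemporaryDirectory() as d:
    cache_file = Path(d) / "raw_exons_cache.tsv"
    atomic_write_text(cache_file, "accession\texon_num_in_chain\n")
    atomic_write_text(cache_file, "accession\texon_num_in_chain\nP02452\t0\n")
    lines = cache_file.read_text().splitlines()
    leftovers = sorted(p.name for p in Path(d).iterdir())

print(lines)      # ['accession\texon_num_in_chain', 'P02452\t0']
print(leftovers)  # ['raw_exons_cache.tsv']  (no stray .tmp files)
```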

In [None]:
# ===== Cell 40 =====
# Mapping safety helpers (cache validation, resume, atomic writes)

def get_memory_usage() -> Dict[str, float]:
    """Returns current memory usage statistics."""
    process = psutil.Process(os.getpid())
    mem_info = process.memory_info()
    return {'rss_mb': mem_info.rss / 1e6, 'percent': process.memory_percent()}

def validate_cache_integrity(cache_path: Path, expected_cols: list) -> Tuple[bool, str]:
    """Validates cache file integrity, returning (is_valid, message)."""
    if not cache_path.exists():
        return True, "Cache file does not exist (will be created)."
    try:
        df_sample = pd.read_csv(cache_path, sep='\t', nrows=5, low_memory=False)
        missing = set(expected_cols) - set(df_sample.columns)
        if missing:
            return False, f"Missing required columns: {missing}"
        return True, "Cache validation passed."
    except Exception as e:
        return False, f"Cache validation error: {e}"

def load_mapped_accessions_from_cache() -> Set[str]:
    """Reads the master exon cache and returns the set of already-mapped accessions."""
    if not RAW_EXONS_CACHE.exists():
        logger.info("No existing exon cache found.")
        return set()

    is_valid, msg = validate_cache_integrity(RAW_EXONS_CACHE, ['accession'])
    if not is_valid:
        logger.error(f"Exon cache validation failed: {msg}. Moving corrupted file.")
        corrupted_path = RAW_EXONS_CACHE.with_suffix(f'.corrupted_{RUN_TIMESTAMP}')
        RAW_EXONS_CACHE.rename(corrupted_path)
        logger.info(f"Corrupted cache backed up to: {corrupted_path}")
        return set()

    try:
        df = pd.read_csv(RAW_EXONS_CACHE, sep='\t', usecols=['accession'], low_memory=False)
        accession_set = set(df['accession'].dropna().astype(str))
        logger.info(f"Loaded {len(accession_set)} unique mapped accessions from cache.")
        return accession_set
    except Exception as e:
        logger.error(f"Error reading exon cache for resume: {e}. Starting fresh.")
        return set()

def atomic_merge_into_cache(new_rows_df: pd.DataFrame) -> None:
    """Atomically merges new rows into the master exon cache with backups."""
    if new_rows_df is None or new_rows_df.empty:
        logger.info("No new rows to merge into cache.")
        return

    key_cols = ['accession', 'exon_num_in_chain']
    if not all(c in new_rows_df.columns for c in key_cols):
        logger.error("New rows missing key columns for merge. Aborting cache write.")
        return

    old_df = pd.DataFrame()
    if RAW_EXONS_CACHE.exists():
        old_df = safe_read_tsv(RAW_EXONS_CACHE)

    combined = pd.concat([old_df, new_rows_df], ignore_index=True)
    combined.drop_duplicates(subset=key_cols, keep='last', inplace=True)

    # Write-then-swap: write the merged table to a temp file, move the current
    # cache aside as a timestamped backup, then move the temp file into place.
    temp_path = RAW_EXONS_CACHE.with_suffix('.tmp')
    backup_path = RAW_EXONS_CACHE.with_suffix(f'.bak_{RUN_TIMESTAMP}')
    try:
        combined.to_csv(temp_path, sep='\t', index=False)
        if RAW_EXONS_CACHE.exists():
            RAW_EXONS_CACHE.replace(backup_path)
        temp_path.replace(RAW_EXONS_CACHE)
        if backup_path.exists():
            backup_path.unlink()  # Swap succeeded; the backup is no longer needed
        logger.info(f"Cache merge complete: {len(old_df)} + {len(new_rows_df)} -> {len(combined)} rows.")
    except Exception as e:
        logger.error(f"Cache merge failed: {e}")
        # Restore the backup only if the new cache was never moved into place
        if backup_path.exists() and not RAW_EXONS_CACHE.exists():
            backup_path.replace(RAW_EXONS_CACHE)
            logger.info("Restored cache from backup.")
        if temp_path.exists():
            temp_path.unlink()  # Discard the partial temp file

## Cell 41 – Enhanced Exon Coordinate Mapper (Ensembl REST API)

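This class fetches exon coordinates from the Ensembl REST API. It retries failed requests with exponential backoff, counts errors by category, logs each error at debug level, and tracks API calls, cache hits, and successes to help diagnose API issues. Failed accessions are cached as `MAPPING_FAILED` placeholder rows (negative caching): they are not re-queried within the session, and once written to the master cache they are skipped by later runs.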

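The retry schedule in `_fetch` doubles the wait after each failed attempt (`initial_delay * 2**attempt`). A minimal, generic sketch of the same idea, assuming a hypothetical zero-argument `call` that returns `None` on failure; `retry_with_backoff` and `backoff_delays` are illustrative names, and `sleep` is injectable so the schedule can be checked without waiting. This sketch deliberately skips sleeping after the final attempt.

```python
import time
from typing import Callable, List, Optional, TypeVar

T = TypeVar("T")


def backoff_delays(initial_delay: float, max_retries: int) -> List[float]:
    """Delays slept between attempts: initial_delay * 2**attempt, none after the last attempt."""
    return [initial_delay * (2 ** attempt) for attempt in range(max_retries - 1)]


def retry_with_backoff(
    call: Callable[[], Optional[T]],
    max_retries: int = 3,
    initial_delay: float = 1.0,
    sleep: Callable[[float], None] = time.sleep,
) -> Optional[T]:
    """Call `call` up to max_retries times; a None result or an OSError triggers a retry."""
    delays = backoff_delays(initial_delay, max_retries)
    for attempt in range(max_retries):
        try:
            result = call()
        except OSError:
            result = None  # network-style errors (requests' exceptions subclass OSError)
        if result is not None:
            return result
        if attempt < len(delays):
            sleep(delays[attempt])
    return None


slept = []
responses = iter([None, None, {"id": "ENSG_TOY"}])  # fails twice, then succeeds
result = retry_with_backoff(lambda: next(responses), max_retries=3, initial_delay=1.0, sleep=slept.append)
print(result)                  # {'id': 'ENSG_TOY'}
print(slept)                   # [1.0, 2.0]
print(backoff_delays(1.0, 4))  # [1.0, 2.0, 4.0]
```

With the defaults (`max_retries=3`, `initial_delay=1.0`) a permanently failing request costs at most three calls and three seconds of waiting.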
In [None]:
# ===== Cell 41 =====
# Enhanced Exon Coordinate Mapper with Negative Caching support

from collections import defaultdict
from enum import Enum

class ErrorCategory(Enum):
    """Categories used to count and report mapping failures."""
    API_FAILURE = "api_failure"
    TIMEOUT = "timeout"
    DATA_FORMAT = "data_format"
    COORDINATE_ERROR = "coordinate_error"
    SEQUENCE_MISMATCH = "sequence_mismatch"
    EMPTY_RESPONSE = "empty_response"
    PARSING_ERROR = "parsing_error"
    NETWORK_ERROR = "network_error"
    UNKNOWN = "unknown_error"

class EnhancedExonCoordinateMapper:
    """Fetches exon coordinates with robust error handling and negative caching."""
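    def __init__(self, base_url=ENSEMBL_BASE, max_retries=3, initial_delay=1.0):
        self.base_url = base_url
        self.max_retries = max_retries
        self.initial_delay = initial_delay
        self.cache = {}                         # accession -> mapped exons or failure placeholder
        self.failed = set()                     # accessions that could not be mapped
        self.stats = defaultdict(int)           # call, hit, success, and per-category error counters
        self.error_details = defaultdict(list)  # per-category error records for later analysis

    def log_error(self, cat: ErrorCategory, acc: str, details: str, context: dict = None):
        """Count an error under its category and keep a record of it in error_details."""
        self.stats[cat.value] += 1
        self.stats['total_failures'] += 1
        self.error_details[cat.value].append({'accession': acc, 'details': details, **(context or {})})
        logger.debug(f"Mapping error for {acc} ({cat.value}): {details}")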

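    def _fetch(self, accession: str) -> Optional[Dict]:
        """GET the expanded Ensembl lookup record, retrying with exponential backoff."""
        url = f"{self.base_url}/lookup/id/{accession}?content-type=application/json;expand=1"
        for attempt in range(self.max_retries):
            try:
                self.stats['api_calls'] += 1
                response = requests.get(url, timeout=ENSEMBL_TIMEOUT_SECS)
                if response.status_code == 200:
                    data = response.json()
                    if not data:
                        self.log_error(ErrorCategory.EMPTY_RESPONSE, accession, "API returned empty JSON.")
                        return None
                    return data
                self.log_error(ErrorCategory.API_FAILURE, accession, f"HTTP {response.status_code}", {'attempt': attempt})
                if response.status_code in (400, 404):
                    return None  # Unknown or malformed ID: retrying cannot succeed
            except requests.Timeout as e:
                self.log_error(ErrorCategory.TIMEOUT, accession, str(e), {'attempt': attempt})
            except ValueError as e:
                # Malformed JSON body (requests' JSON errors subclass ValueError); not retried
                self.log_error(ErrorCategory.PARSING_ERROR, accession, str(e), {'attempt': attempt})
                return None
            except requests.RequestException as e:
                self.log_error(ErrorCategory.NETWORK_ERROR, accession, str(e), {'attempt': attempt})
            # Back off before the next attempt; no sleep after the final one
            if attempt < self.max_retries - 1:
                time.sleep(self.initial_delay * (2 ** attempt))
        return None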

    def get_mapped_exons(self, accession: str, main_chain: List[Dict], sequence: str) -> List[Dict]:
        """
        Main method to get mapped exons.
        Returns a list of exon dicts on success, or a single placeholder dict on failure.
        """
        self.stats['total_attempts'] += 1
        if accession in self.cache:
            self.stats['cache_hits'] += 1
            return self.cache[accession]

        # Placeholder stored for failed mappings (a negative-cache entry)
        failure_placeholder = [{
            "accession": accession, "exon_num_in_chain": -999, "begin_aa": pd.NA,
            "end_aa": pd.NA, "peptide": "MAPPING_FAILED", "strand": pd.NA, "chr": pd.NA
        }]

        data = self._fetch(accession)
        if data is None:
            self.failed.add(accession); self.cache[accession] = failure_placeholder
            return failure_placeholder

        try:
            result = self._process_response(accession, data, main_chain, sequence)
            if result:
                self.stats['successes'] += 1
                self.cache[accession] = result
                return result
            else:
                self.failed.add(accession)
                self.cache[accession] = failure_placeholder
                return failure_placeholder
        except Exception as e:
            self.log_error(ErrorCategory.UNKNOWN, accession, f"Unexpected processing error: {e}")
            self.failed.add(accession); self.cache[accession] = failure_placeholder
            return failure_placeholder

    def _process_response(self, acc: str, data: Dict, main_chain: List[Dict], seq: str) -> Optional[List[Dict]]:
        # Returns None on processing failure; get_mapped_exons converts that into a failure placeholder.
        try:
            gn = data['Transcript'][0]['Exon']
            exons = sorted(gn, key=lambda x: x.get('start', 0))
        except (KeyError, IndexError):
            self.log_error(ErrorCategory.DATA_FORMAT, acc, "Could not find exon data in response.")
            return None

        # ... (rest of the processing logic is unchanged)
        first_idx, last_idx = -1, -1
        for i, ex in enumerate(exons):
            if first_idx == -1: first_idx = i
            last_idx = i
        if first_idx == -1: return None

        mapped, current_pos = [], 1
        s_idx, e_idx = max(0, first_idx - 2), min(len(exons) - 1, last_idx + 2)
        even_num = -2 * (first_idx - s_idx)
        for i in range(s_idx, e_idx + 1):
            ex = exons[i]
            length = ex['end'] - ex['start'] + 1
            begin_aa, end_aa = current_pos, current_pos + (length // 3) - 1
            if end_aa > len(seq): continue
            mapped.append({
                "accession": acc, "exon_num_in_chain": even_num,
                "begin_aa": begin_aa, "end_aa": end_aa,
                "peptide": seq[begin_aa-1:end_aa],
                "strand": ex.get('strand'), "chr": data.get('seq_region_name')
            })
            current_pos = end_aa + 1
            even_num += 2
        return mapped

    def get_stats_summary(self) -> Dict:
        summary = dict(self.stats)
        summary['cache_size'] = len(self.cache)
        summary['failed_count'] = len(self.failed)
        return summary

enhanced_exon_mapper = EnhancedExonCoordinateMapper()
logger.info("✅ Enhanced Exon Coordinate Mapper (with negative caching) loaded.")
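Negative caching means a failure is memoised just like a success, so a bad accession costs one lookup per session instead of one per request. A minimal sketch of that idea, assuming a hypothetical `fetch` that returns a (possibly empty) list of exon rows; `make_negative_cache` and `toy_fetch` are illustrative names, and the placeholder is trimmed to three of the mapper's fields.

```python
from typing import Callable, Dict, List

FAILED = "MAPPING_FAILED"  # sentinel peptide value, mirroring the mapper's placeholder


def make_negative_cache(fetch: Callable[[str], List[dict]]) -> Callable[[str], List[dict]]:
    """Wrap fetch so that both successes and failures are cached per accession."""
    cache: Dict[str, List[dict]] = {}

    def cached(accession: str) -> List[dict]:
        if accession in cache:
            return cache[accession]  # hit: no repeat call, even for a past failure
        try:
            rows = fetch(accession)
        except Exception:
            rows = []  # treat any fetch error as a failed mapping in this sketch
        if not rows:
            rows = [{"accession": accession, "exon_num_in_chain": -999, "peptide": FAILED}]
        cache[accession] = rows
        return rows

    return cached


calls = []

def toy_fetch(acc: str) -> List[dict]:
    calls.append(acc)
    return [] if acc == "BAD1" else [{"accession": acc, "exon_num_in_chain": 0, "peptide": "GPP"}]

mapper = make_negative_cache(toy_fetch)
first = mapper("BAD1")
second = mapper("BAD1")
ok = mapper("GOOD1")
print(calls)                 # ['BAD1', 'GOOD1']
print(second[0]["peptide"])  # MAPPING_FAILED
```

Because the placeholder carries the accession, writing it to the master cache also lets later runs treat the accession as already attempted.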

## Cell 42 – Incremental Exon Mapping Runner (Batch-Processed & Resumable)

This cell orchestrates the exon mapping process in a robust, resumable manner. It identifies high-quality sequences that are not already in the cache, then iterates through them in manageable **batches**. After each batch is processed, the results are immediately and atomically saved to the master cache. If the run is interrupted, it can be restarted and will automatically resume from the last completed batch, preventing loss of work.
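Resuming works because the set of finished accessions is re-read from the cache at the start of every run, and each batch is committed before the next one starts, so an interruption loses at most the batch in progress. A minimal sketch of the batching arithmetic, assuming toy accession strings; `pending_batches` is a hypothetical helper, not a function from this notebook.

```python
from typing import Iterable, Iterator, List, Set


def pending_batches(all_ids: Iterable[str], done: Set[str], batch_size: int) -> Iterator[List[str]]:
    """Yield the sorted, not-yet-done IDs in chunks of at most batch_size."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    todo = sorted(set(all_ids) - done)
    for start in range(0, len(todo), batch_size):
        yield todo[start:start + batch_size]


all_ids = [f"P{i:05d}" for i in range(1, 8)]  # 7 toy accessions
done = {"P00001", "P00002", "P00005"}         # already in the cache from an earlier run
batches = list(pending_batches(all_ids, done, batch_size=3))
print(batches)  # [['P00003', 'P00004', 'P00006'], ['P00007']]
```

The number of batches equals the ceiling division `(len(todo) + batch_size - 1) // batch_size`, here (4 + 3 - 1) // 3 = 2.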

In [None]:
# ===== Cell 42 =====
# Incremental exon mapping runner (batched and resumable)

logger.info("--- Starting Part 4: Incremental Exon Mapping (Batch Mode) ---")

BATCH_SIZE = 100  #@param {type:"integer"}

# Initialise here so the summary below never reports stale values from an earlier execution
all_hq_accessions: set = set()
already_mapped: set = set()
to_map_accessions: list = []

if 'df_high_quality' not in globals() or df_high_quality.empty:
    logger.warning("df_high_quality is empty. No new entries to map.")
else:
    all_hq_accessions = set(df_high_quality['Entry'].dropna().astype(str).str.strip())
    already_mapped = load_mapped_accessions_from_cache()
    to_map_accessions = sorted(all_hq_accessions - already_mapped)

logger.info(f"High-quality entries total: {len(all_hq_accessions)}")
logger.info(f"Already attempted/mapped in cache: {len(already_mapped)}")
logger.info(f"Entries to map this run: {len(to_map_accessions)}")

if not to_map_accessions:
    logger.info("✅ No new entries to map after filtering. Skipping mapping.")
else:
    num_batches = (len(to_map_accessions) + BATCH_SIZE - 1) // BATCH_SIZE
    logger.info(f"Processing {len(to_map_accessions)} entries in {num_batches} batches of up to {BATCH_SIZE} each.")

    for i in range(num_batches):
        batch_start_time = time.time()
        start_index, end_index = i * BATCH_SIZE, (i + 1) * BATCH_SIZE
        batch_accessions = to_map_accessions[start_index:end_index]

        logger.info(f"--- Starting Batch {i+1}/{num_batches} ({len(batch_accessions)} entries) ---")

        # Normalise 'Entry' the same way as when building the accession set above
        df_batch = df_high_quality[df_high_quality['Entry'].astype(str).str.strip().isin(batch_accessions)]

        newly_mapped_rows_batch = []
        with tqdm(total=len(df_batch), desc=f"Mapping Batch {i+1}/{num_batches}") as pbar:
            for _, row in df_batch.iterrows():
                # The mapper now ALWAYS returns a list, even for failures
                exon_data = enhanced_exon_mapper.get_mapped_exons(row['Entry'], row['main_chain_segments'], row['Sequence'])
                newly_mapped_rows_batch.extend(exon_data)
                pbar.update(1)

        # --- Commit After Each Batch ---
        # This block now runs for every batch, recording both successes and failures.
        batch_df = pd.DataFrame(newly_mapped_rows_batch)
        atomic_merge_into_cache(batch_df)
        successful_maps = int((batch_df['peptide'] != 'MAPPING_FAILED').sum()) if 'peptide' in batch_df.columns else 0
        logger.info(f"✅ Batch {i+1} complete. Committed {len(batch_df)} total rows ({successful_maps} successful) to cache.")

        batch_snapshot_path = RUN_DIR / f"mapped_exons_batch_{i+1}.tsv"
        batch_df.to_csv(batch_snapshot_path, sep='\t', index=False)

        batch_duration = time.time() - batch_start_time
        logger.info(f"Batch {i+1} finished in {batch_duration / 60:.2f} minutes.")

logger.info("--- Incremental Exon Mapping Complete ---")
logger.info(f"Final mapping stats for the run: {json.dumps(enhanced_exon_mapper.get_stats_summary(), indent=2)}")

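# Sort numerically so batch 10 follows batch 9 rather than batch 1
all_batch_files = sorted(RUN_DIR.glob("mapped_exons_batch_*.tsv"), key=lambda p: int(p.stem.rsplit('_', 1)[-1]))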
if all_batch_files:
    all_run_mapped_df = pd.concat([pd.read_csv(f, sep='\t', low_memory=False) for f in all_batch_files], ignore_index=True)
    all_run_mapped_df.to_csv(MAPPED_SNAPSHOT, sep='\t', index=False)
    logger.info(f"Consolidated snapshot for this run saved to {MAPPED_SNAPSHOT.name}")

## Cell 43 – Error Analysis & Recovery Planning

Summarises the exon mapper's per-category error counters to identify systematic issues. It reports each category's count and its share of all logged errors, adds a recovery suggestion when the most frequent category has a known remedy, saves the report as JSON, and returns it for inspection.

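The percentages in the report are shares of all logged errors, and the mapper logs one error per failed attempt, so an accession that fails on every retry contributes several counts. A minimal sketch of the breakdown on a toy `stats` dictionary (the counts are invented); `failure_breakdown` is a hypothetical helper.

```python
from typing import Dict

# Bookkeeping counters that are not error categories
NON_ERROR_KEYS = {"total_failures", "total_attempts", "api_calls", "successes", "cache_hits"}


def failure_breakdown(stats: Dict[str, int]) -> Dict[str, str]:
    """Map each error category to its share of all logged errors, largest first."""
    total = stats.get("total_failures", 0)
    if total == 0:
        return {}
    errors = sorted(
        ((k, v) for k, v in stats.items() if k not in NON_ERROR_KEYS),
        key=lambda kv: kv[1],
        reverse=True,
    )
    return {k: f"{100 * v / total:.1f}%" for k, v in errors}


stats = {"api_calls": 40, "total_attempts": 20, "successes": 15,
         "total_failures": 8, "api_failure": 6, "timeout": 2}
breakdown = failure_breakdown(stats)
print(breakdown)  # {'api_failure': '75.0%', 'timeout': '25.0%'}
```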
In [None]:
# ===== Cell 43 =====
# Error analysis and recovery system

def analyze_mapping_failures(mapper: EnhancedExonCoordinateMapper) -> Dict:
    """Analyzes failure patterns and generates a report."""
    logger.info("🔍 Analyzing mapping failure patterns...")
    stats = mapper.stats
    total_failures = stats.get('total_failures', 0)

    analysis = {
        'run_id': RUN_ID,
        'summary': dict(stats),
        'failure_analysis': {},
        'recommendations': []
    }

    if total_failures > 0:
        sorted_errors = sorted(
            [(k, v) for k, v in stats.items() if k not in ['total_failures', 'total_attempts', 'api_calls', 'successes', 'cache_hits']],
            key=lambda item: item[1], reverse=True
        )
        for category, count in sorted_errors:
            percentage = (count / total_failures) * 100
            analysis['failure_analysis'][category] = {'count': count, 'percentage': f"{percentage:.1f}%"}

        top_failure = sorted_errors[0][0] if sorted_errors else None
        if top_failure == 'api_failure':
            analysis['recommendations'].append("High API failure rate suggests network issues or server-side problems. Check Ensembl status.")
        elif top_failure == 'coordinate_error':
            analysis['recommendations'].append("Coordinate errors are frequent. Review the logic for mapping genomic to protein coordinates in Cell 41.")

    else:
        logger.info("✅ No mapping failures detected.")

    try:
        with open(ERROR_REPORT_PATH, 'w') as f:
            json.dump(analysis, f, indent=2)
        logger.info(f"💾 Error analysis report saved to: {ERROR_REPORT_PATH}")
    except Exception as e:
        logger.error(f"Failed to save error analysis report: {e}")

    return analysis

failure_analysis = analyze_mapping_failures(enhanced_exon_mapper)

## Cell 44 – Load Seed Exon Data from Cache

This cell first compares the row count of the master exon cache (`raw_exons_cache.tsv`) with its most recent timestamped backup, and restores the backup if it holds more rows than the current cache. It then loads the mapped exons, excluding `MAPPING_FAILED` placeholders, as the high-confidence **seed data** for building our initial architectural template. This seed dataset contains exons mapped only from the most reliable, pre-filtered sequences.

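The restore rule can be checked in isolation. A minimal sketch, assuming throwaway files in a temporary directory; `count_data_rows` mirrors `_safe_row_count` (lines minus the header, `-1` when the file cannot be read) and `should_restore` mirrors the condition in this cell. Both names are hypothetical.

```python
import tempfile
from pathlib import Path


def count_data_rows(path: Path) -> int:
    """Number of lines minus the header; -1 if the file is missing or unreadable."""
    try:
        with open(path, "rb") as fh:
            return sum(1 for _ in fh) - 1
    except OSError:
        return -1


def should_restore(current_rows: int, backup_rows: int) -> bool:
    """Restore only when the current cache is readable and the backup is larger."""
    return current_rows >= 0 and backup_rows > current_rows


with tempfile.TemporaryDirectory() as d:
    cur = Path(d) / "raw_exons_cache.tsv"
    bak = Path(d) / "raw_exons_cache.bak_toy"
    cur.write_text("accession\nA\n")
    bak.write_text("accession\nA\nB\nC\n")
    cur_rows, bak_rows = count_data_rows(cur), count_data_rows(bak)
    missing_rows = count_data_rows(Path(d) / "absent.tsv")

print(cur_rows, bak_rows, missing_rows)        # 1 3 -1
print(should_restore(cur_rows, bak_rows))      # True
print(should_restore(missing_rows, bak_rows))  # False (a missing cache is not auto-restored)
```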
In [None]:
# ===== Cell 44 =====
# Load seed exon data from cache

def _safe_row_count(p: Path) -> int:
    """Safely estimates row count, returning -1 on failure."""
    if not p.is_file(): return -1
    try:
        # Fast line count for large files
        with open(p, 'rb') as f:
            return sum(1 for _ in f) - 1
    except Exception:
        return -1

logger.info("Performing final cache integrity check...")
current_rows = _safe_row_count(RAW_EXONS_CACHE)
logger.info(f"Current exon cache: ~{current_rows} rows.")

backups = sorted(
    list(CACHE_DIR.glob("raw_exons_cache.bak_*")),
    key=lambda p: p.stat().st_mtime,
    reverse=True
)

if backups:
    latest_bak = backups[0]
    bak_rows = _safe_row_count(latest_bak)
    logger.info(f"Latest backup: {latest_bak.name} (~{bak_rows} rows).")

    if current_rows >= 0 and bak_rows > current_rows:
        logger.warning(f"Backup ({bak_rows} rows) is larger than current cache ({current_rows} rows). Restoring from backup.")
        try:
            latest_bak.replace(RAW_EXONS_CACHE)
            logger.info("✅ Auto-restore from backup complete.")
        except Exception as e:
            logger.error(f"Failed to restore from backup: {e}")
else:
    logger.info("No backups found for comparison.")

# --- Load the mapped data into a 'seed' DataFrame ---
# This data comes from the high-confidence run and is used to build the rescue template.
if RAW_EXONS_CACHE.exists():
    df_raw_exons_seed = safe_read_tsv(RAW_EXONS_CACHE)
    # Drop failure placeholders (negative-cache rows) recorded by the mapper
    if 'peptide' in df_raw_exons_seed.columns:
        df_raw_exons_seed = df_raw_exons_seed[df_raw_exons_seed['peptide'] != 'MAPPING_FAILED'].copy()
    logger.info(f"Loaded high-confidence seed exon dataset with {len(df_raw_exons_seed)} rows.")
else:
    df_raw_exons_seed = pd.DataFrame()
    logger.warning("Seed exon cache is not available for building rescue template.")

## Cell 45 – Phylogenetic Consensus Engine

This cell defines the core `PhylogeneticConsensusEngine`. It is initialized once and then used by all subsequent consensus and dating cells. It loads the Metazoan Newick tree and the pre-computed node age cache (from Cell 15), so node ages are dictionary lookups rather than recomputed, and provides two services:
1.  **MRCA Dating**: Finds the Most Recent Common Ancestor (MRCA) for any group of genera and returns its age in millions of years.
2.  **Phylogenetic Weighting**: Weights each genus by an inverse-distance score, 1 / (1 + max(0, age(genus) - age(MRCA))), normalised to sum to 1 over the genera found in the tree; genera absent from the tree keep a baseline weight of 0.1.

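Both services can be illustrated without ete3. This sketch swaps the Newick tree for a hypothetical child-to-parent dictionary and finds the MRCA by intersecting each leaf's path to the root; the tree, node names, and ages are toy values, and all queried genera are assumed to be in the toy tree. The weighting follows the formula in the list above.

```python
from typing import Dict, List, Optional

# Toy rooted tree as child -> parent links, with node ages in Myr (illustrative values only)
PARENT = {"Bos": "Laurasiatheria", "Sus": "Laurasiatheria", "Laurasiatheria": "Boreoeutheria",
          "Homo": "Boreoeutheria", "Boreoeutheria": "Amniota", "Gallus": "Amniota"}
AGES = {"Bos": 0.0, "Sus": 0.0, "Homo": 0.0, "Gallus": 0.0,
        "Laurasiatheria": 80.0, "Boreoeutheria": 90.0, "Amniota": 320.0}


def path_to_root(node: str) -> List[str]:
    """The node followed by each of its ancestors up to the root."""
    path = [node]
    while path[-1] in PARENT:
        path.append(PARENT[path[-1]])
    return path


def mrca(nodes: List[str]) -> Optional[str]:
    """Deepest node shared by every leaf's path to the root."""
    if not nodes:
        return None
    shared = set(path_to_root(nodes[0]))
    for n in nodes[1:]:
        shared &= set(path_to_root(n))
    # Walking up from any leaf, the first shared node is the MRCA
    return next((a for a in path_to_root(nodes[0]) if a in shared), None)


def inverse_distance_weights(genera: List[str]) -> Dict[str, float]:
    """weight(g) proportional to 1 / (1 + max(0, age(g) - age(MRCA))), normalised to sum to 1."""
    if not genera:
        return {}
    mrca_age = AGES[mrca(genera)]
    raw = {g: 1.0 / (1.0 + max(0.0, AGES.get(g, mrca_age) - mrca_age)) for g in genera}
    total = sum(raw.values())
    return {g: w / total for g, w in raw.items()}


print(mrca(["Bos", "Sus"]), AGES[mrca(["Bos", "Sus"])])  # Laurasiatheria 80.0
print(mrca(["Bos", "Homo", "Gallus"]))                   # Amniota
weights = inverse_distance_weights(["Bos", "Homo", "Gallus"])
print(weights)
```

With extant leaves at age 0, `age(g) - age(MRCA)` is never positive, so every genus in this toy example receives the same weight (1/3); any non-uniform weighting therefore depends on the ages the node-age cache assigns to leaves.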
In [None]:
# ===== Cell 45 =====
# Phylogenetic Consensus Engine (Genus-only; with summary reporting)

from __future__ import annotations

import logging
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import pandas as pd
from ete3 import Tree

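# Reuse the notebook-wide logger configured in earlier cells; fall back to a module logger
if "logger" not in globals():
    logger = logging.getLogger(__name__)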


class PhylogeneticConsensusEngine:
    """
    Handles phylogenetic weighting and dating using a pre-computed cache.

    This engine:
      1) loads a Newick phylogeny and a CSV of node ages,
      2) normalises all leaf labels to extract *only* the Genus,
      3) builds a genus → leaf map (single representative per genus for lookups),
      4) records a summary of how many leaves were associated with each genus.

    Parameters
    ----------
    nwk_path : Path
        Path to the Newick tree file.
    node_ages_csv : Path
        Path to a CSV with columns ['node_name', 'age'] giving node ages.
    """

    # --- Robust genus-extraction machinery (class-level constants) ---
    _GENUS_PAT = re.compile(
        r"""
        ^\s*["']?                             # optional leading quote/space
        (?P<genus>[A-Z][a-zA-Z\-]+)           # Genus: capitalised Latin word (allow hyphen)
        (?:                                   # optional trailing parts (species, subspecies, qualifiers)
            [\s_./|:-]+
            (?:cf\.|aff\.|nr\.|gr\.|sp\.|spp\.|x|×|hybrid)?
            .*?
        )?
        ["']?\s*$                             # optional closing quote/space
        """,
        re.VERBOSE,
    )
    _HYBRID_SEP = re.compile(r"[\s_]+(?:x|×|hybrid)[\s_]+", re.IGNORECASE)
    _SEP_NORM = re.compile(r"[ \t\n\r_.:/|\\-]+")

    def __init__(self, nwk_path: Path, node_ages_csv: Path) -> None:
        self.tree: Optional[Tree] = None
        self.age_lookup: Dict[str, float] = {}
        self.genus_to_leaf_map: Dict[str, Tree] = {}
        self._genus_leaf_count: Dict[str, int] = {}
        self._is_ready: bool = False

        if not nwk_path.exists() or not node_ages_csv.exists():
            logger.error(
                "Phylogenetic tree (%s) or node age cache (%s) missing. Engine disabled.",
                nwk_path.name,
                node_ages_csv.name,
            )
            return

        try:
            logger.info("Initializing Phylogenetic Engine from tree: %s", nwk_path.name)
            self.tree = Tree(str(nwk_path), format=1)

            # First pass: compute genus counts across *all* leaves.
            for leaf in self.tree.iter_leaves():
                label = str(leaf.name) if leaf.name is not None else ""
                genus_key = self.extract_genus(label, return_canonical=False)
                if genus_key is None:
                    # If desired, you can switch this to 'continue' to be more permissive.
                    raise ValueError(f"Could not extract genus from label: {label!r}")
                self._genus_leaf_count[genus_key] = self._genus_leaf_count.get(genus_key, 0) + 1

            # Second pass: build representative genus → leaf mapping.
            # Policy: keep the *first* leaf encountered for a genus as the representative.
            # (We already have counts in _genus_leaf_count for summary purposes.)
            for leaf in self.tree.iter_leaves():
                label = str(leaf.name) if leaf.name is not None else ""
                genus_key = self.extract_genus(label, return_canonical=False)
                if genus_key is None:
                    continue
                if genus_key not in self.genus_to_leaf_map:
                    self.genus_to_leaf_map[genus_key] = leaf

            # Load node ages
            df_ages = pd.read_csv(node_ages_csv)
            required_cols = {"node_name", "age"}
            if not required_cols.issubset(set(df_ages.columns)):
                raise ValueError(
                    f"Node-age CSV must contain columns {sorted(required_cols)}; "
                    f"found {sorted(df_ages.columns)}"
                )
            df_ages["node_name"] = df_ages["node_name"].astype(str)
            self.age_lookup = df_ages.set_index("node_name")["age"].to_dict()

            self._is_ready = True
            logger.info(
                "✅ Phylogenetic Engine is ready with %d genera (from %d total leaves).",
                len(self.genus_to_leaf_map),
                sum(self._genus_leaf_count.values()),
            )

        except Exception as e:
            logger.error("Failed to initialize Phylogenetic Engine: %s", e, exc_info=True)

    # ------------------ Genus extraction (static utility) ------------------

    @staticmethod
    def extract_genus(label: str, *, return_canonical: bool = True) -> Optional[str]:
        """
        Extract the Genus epithet from a taxon label robustly.

        The function:
          - normalises separators,
          - breaks hybrid expressions at the first hybrid marker (×, x, 'hybrid'),
          - ignores qualifiers (cf., aff., nr., gr., sp., spp.),
          - returns only the Genus part.

        Parameters
        ----------
        label : str
            The original taxon label from the Newick tree or input list.
        return_canonical : bool, optional
            If True, return the Genus with canonical capitalisation ('Bos').
            If False, return a lower-case normalised form ('bos').

        Returns
        -------
        Optional[str]
            The extracted Genus, or None if no valid genus token is found.
        """
        if label is None:
            return None

        # Normalise whitespace and mixed separators.
        s = str(label)
        s = PhylogeneticConsensusEngine._SEP_NORM.sub(" ", s).strip()

        # If the label appears to encode a hybrid, only consider the left-hand side.
        s = PhylogeneticConsensusEngine._HYBRID_SEP.split(s, maxsplit=1)[0]

        # Strip qualifiers at the head (e.g., "cf. Bos taurus" → "Bos taurus").
        # Separator normalisation above has already removed the dot, so it is optional here.
        s = re.sub(r"^(?:cf|aff|nr|gr)\.?\s+", "", s, flags=re.IGNORECASE)

        # Now match the leading Genus token.
        m = PhylogeneticConsensusEngine._GENUS_PAT.match(s)
        if not m:
            return None

        genus = m.group("genus")
        if return_canonical:
            return genus.capitalize()  # e.g. "BOS" or "Bos" -> "Bos"
        return genus.lower()

    # ------------------------- Public API methods --------------------------

    def is_ready(self) -> bool:
        """Return True if the engine initialised without errors."""
        return self._is_ready

    def genus_leaf_counts(self, *, descending: bool = True) -> Dict[str, int]:
        """
        Return a dictionary of {genus(lower-case): leaf_count}.

        Parameters
        ----------
        descending : bool
            If True, sort by count descending.

        Returns
        -------
        Dict[str, int]
            Mapping of genera to their observed leaf counts in the tree.
        """
        items = sorted(self._genus_leaf_count.items(), key=lambda kv: kv[1], reverse=descending)
        return dict(items)

    def save_genus_leaf_summary(self, out_csv: Path) -> Path:
        """
        Save the per-genus leaf-count summary to CSV.

        Parameters
        ----------
        out_csv : Path
            Output path for the CSV file.

        Returns
        -------
        Path
            The path written.
        """
        df = pd.DataFrame(
            [{"genus": g, "leaf_count": c} for g, c in self._genus_leaf_count.items()]
        ).sort_values("leaf_count", ascending=False, kind="mergesort")
        df.to_csv(out_csv, index=False)
        return out_csv

    def get_mrca_age(self, genera_list: List[str]) -> Tuple[Optional[str], float]:
        """
        Find the MRCA for a list of genera and return its node name and age.

        The lookup is case-insensitive. Returns (None, 0.0) if the engine is not
        ready or the list is empty, ("No_Valid_Nodes", 0.0) if none of the genera
        are in the tree, and ("Ancestor_Not_Found", 0.0) if the ancestor cannot be
        determined.
        """
        if not self.is_ready() or not genera_list:
            return None, 0.0

        leaf_nodes = [
            self.genus_to_leaf_map.get(str(g).lower())
            for g in genera_list
            if g is not None and str(g).strip() != ""
        ]
        valid_nodes = [node for node in leaf_nodes if node is not None]

        if not valid_nodes:
            return "No_Valid_Nodes", 0.0

        ancestor = self.tree.get_common_ancestor(valid_nodes)
        if ancestor is None:
            return "Ancestor_Not_Found", 0.0

        age = self.age_lookup.get(str(ancestor.name), 0.0)
        return str(ancestor.name), float(age)

    def weights_from_mrca(self, genera_list: List[str]) -> Dict[str, float]:
        """
        Calculate phylogenetic weights for genera based on distance from the MRCA age.

        The weighting here is a simple inverse-distance scheme:
            weight(g) ∝ 1 / (1 + max(0, age(g) - age(MRCA)))
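        normalised to sum to 1 over the provided genera that are present in the tree.
        Genera not found in the tree (or every genus, if the engine is not ready) keep
        a fixed baseline weight of 0.1 that is excluded from the normalisation.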
        """
        unique_genera = sorted({g for g in genera_list if g})
        weights = {g: 0.1 for g in unique_genera}  # default baseline

        if not self.is_ready() or not unique_genera:
            return weights

        _, mrca_age = self.get_mrca_age(unique_genera)

        total_inverse_dist = 0.0
        temp_weights: Dict[str, float] = {}

        for genus in unique_genera:
            node = self.genus_to_leaf_map.get(str(genus).lower())
            if node:
                genus_age = float(self.age_lookup.get(str(node.name), mrca_age))
                distance = max(0.0, genus_age - mrca_age)
                inverse_dist = 1.0 / (1.0 + distance)
                temp_weights[genus] = inverse_dist
                total_inverse_dist += inverse_dist

        if total_inverse_dist > 0.0:
            for genus, inv_dist in temp_weights.items():
                weights[genus] = inv_dist / total_inverse_dist

        return weights


# Initialise the engine for use in subsequent cells
phylo_engine = PhylogeneticConsensusEngine(DRIVE_METAZOAN_TREE_PATH, NODE_AGES_CSV_PATH)


In [None]:
counts = phylo_engine.genus_leaf_counts()
list(counts.items())[:15]  # preview top 15 genera by leaf count

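As a worked illustration of the inverse-distance weighting documented in `weights_from_mrca`, the standalone sketch below re-implements the formula outside the engine. The helper `inverse_distance_weights`, the genus names and the ages are all hypothetical and were chosen only to make the arithmetic easy to follow.

```python
from typing import Dict, Mapping, Optional


def inverse_distance_weights(genus_ages: Mapping[str, Optional[float]], mrca_age: float,
                             baseline: float = 0.1) -> Dict[str, float]:
    """Hypothetical standalone restatement of the scheme in `weights_from_mrca`.

    `genus_ages` maps each genus to its leaf age; None marks a genus absent from
    the tree, which keeps the un-normalised baseline weight.
    """
    weights = {g: baseline for g in genus_ages}
    inverse: Dict[str, float] = {}
    for genus, age in genus_ages.items():
        if age is None:
            continue
        distance = max(0.0, age - mrca_age)        # clipped at 0
        inverse[genus] = 1.0 / (1.0 + distance)    # smaller distance -> larger weight
    total = sum(inverse.values())
    if total > 0.0:
        for genus, inv in inverse.items():
            weights[genus] = inv / total           # normalised over in-tree genera only
    return weights


# Invented ages: "bos" lies 3 units from the MRCA, "ovis" lies at it, "unknownia" is not in the tree.
print(inverse_distance_weights({"bos": 13.0, "ovis": 10.0, "unknownia": None}, mrca_age=10.0))
# {'bos': 0.2, 'ovis': 0.8, 'unknownia': 0.1}
```

Genera absent from the tree keep the 0.1 baseline outside the normalisation, so the weights of a mixed group need not sum to 1. Also, if every leaf age is less than or equal to the MRCA age (for example, if the node-age CSV stores time before present), or if leaf names are missing from the node-age CSV (in which case `weights_from_mrca` falls back to the MRCA age), every distance clips to 0 and the weights become uniform.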
In [None]:
# ===== Cell 45A =====
# Pre-flight check for Phylogenetic Engine dependencies

logger.info("--- 🔎 Performing pre-flight check for Phylogenetic Engine ---")

# 1. Check for the Newick tree file
if not DRIVE_METAZOAN_TREE_PATH.exists():
    logger.error("CRITICAL: The Metazoan Newick tree is not found at the expected path.")
    logger.error(f"Expected path: {DRIVE_METAZOAN_TREE_PATH}")
    logger.error("Please verify the path in Cell 12. Halting execution.")
    assert False, "Missing required Newick tree file."
else:
    logger.info(f"✅ Newick tree found: {DRIVE_METAZOAN_TREE_PATH.name}")

# 2. Check for the Node Age Cache file
if not NODE_AGES_CSV_PATH.exists():
    logger.error("CRITICAL: The pre-computed node age cache is not found.")
    logger.error(f"Expected path: {NODE_AGES_CSV_PATH}")
    logger.error("This is likely because it has not been generated yet.")
    logger.error("SOLUTION: Go to Cell 11, set REGENERATE_NODE_AGE_CACHE = True, and re-run from Cell 11.")
    assert False, "Missing required node age cache. Please regenerate it."
else:
    logger.info(f"✅ Node age cache found: {NODE_AGES_CSV_PATH.name}")

# 3. Final check on the engine object itself
if not phylo_engine.is_ready():
    logger.error("CRITICAL: The phylogenetic engine object failed to initialize even though files were found.")
    logger.error("This may indicate a problem with the file contents (e.g., corrupted tree).")
    assert False, "Phylogenetic engine is not ready. Check logs for initialization errors."
else:
    logger.info("✅ Pre-flight check passed. Phylogenetic engine is ready for use.")

# **Part 5: Architecture-Driven Rescue & Correction**

## Cell 50 – Generate Seed Consensus & Exon "Baits"

This cell is the first pass of our two-pass consensus strategy. It takes the high-confidence "seed" exon data and computes a preliminary consensus architecture: for each `(gene_symbol, exon_num_in_chain)` pair, the consensus begin and end coordinates are weighted medians across the seed entries, with per-genus weights from `phylo_engine`. From the same seed peptides it then builds one position-specific regular-expression "bait" per canonical exon (exons with fewer than five seed peptides get no bait). The next cell uses these baits to "fish" for exons across the entire working dataset.
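The bait construction can be tried on toy data. `build_bait` below is a hypothetical standalone helper that mirrors the logic of the cell (minimum of five peptides, median peptide length, and at each position every residue seen in at least 5% of the peptides covering it). It is not used by the notebook, and the peptides are invented.

```python
import re
from collections import Counter
from statistics import median
from typing import List, Optional


def build_bait(peptides: List[str], min_freq: float = 0.05,
               min_peptides: int = 5) -> Optional[re.Pattern]:
    """Hypothetical sketch: position-wise regex bait from a list of exon peptides."""
    peptides = [p for p in peptides if p]
    if len(peptides) < min_peptides:
        return None
    median_len = int(median(len(p) for p in peptides))
    parts: List[str] = []
    for i in range(median_len):
        counts = Counter(p[i] for p in peptides if len(p) > i)  # residues at position i
        total = sum(counts.values())
        keep = sorted(c for c, n in counts.items() if n / total >= min_freq)
        if len(keep) == 1:
            parts.append(re.escape(keep[0]))
        elif keep:
            # Escape each member so '-', ']' or '^' cannot change the class's meaning.
            parts.append("[" + "".join(re.escape(c) for c in keep) + "]")
        else:
            parts.append(".")  # no residue is frequent enough: accept anything
    return re.compile("".join(parts))


bait = build_bait(["GPPGPP", "GPAGPP", "GPPGEP", "GPPGPP", "GPPGPP"])
print(bait.pattern)                      # GP[AP]G[EP]P
print(bait.search("MMGPAGEPK").span())   # (2, 8)
```

Because a residue seen once among n covering peptides has frequency 1/n, the 5% cut-off only starts discarding rare residues once more than 20 peptides cover a position; with smaller seed sets every observed residue is kept.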

In [None]:
# ===== Cell 50 =====
# Generate Seed Consensus & Build Regex Exon Baits

from __future__ import annotations

import numpy as np
import pandas as pd
import re
from dataclasses import dataclass
from collections import defaultdict
from typing import Iterable, Optional, List

from tqdm.auto import tqdm

logger.info("--- Pass 1: Generating high-confidence seed consensus architecture ---")

# ---------- small utilities ----------

def _has_cols(df: pd.DataFrame, cols: Iterable[str]) -> bool:
    """Return True if all columns in `cols` exist in df."""
    cols = list(cols)
    missing = [c for c in cols if c not in df.columns]
    return len(missing) == 0

def _first_existing_col(df: pd.DataFrame, candidates: List[str]) -> Optional[str]:
    """Return the first column name that exists in df from candidates, else None."""
    for c in candidates:
        if c in df.columns:
            return c
    return None

def _coalesce_columns(df: pd.DataFrame, target: str, sources: List[str]) -> pd.DataFrame:
    """
    If `target` missing, but any of `sources` exist, create `target` by taking
    the first non-null value row-wise across the sources.
    """
    if target in df.columns:
        return df
    available = [c for c in sources if c in df.columns]
    if not available:
        return df
    df = df.copy()  # avoid mutating the caller's frame
    if len(available) == 1:
        df[target] = df[available[0]]
    else:
        # Row-wise first non-null value across the available source columns.
        df[target] = df[available].bfill(axis=1).iloc[:, 0]
    return df

def _derive_cluster_genus_with_engine(df: pd.DataFrame) -> pd.DataFrame:
    """
    If `cluster_genus` is absent or largely null, attempt to derive it from
    any available taxon label using the same genus extractor as the engine.
    """
    if "cluster_genus" in df.columns and df["cluster_genus"].notna().any():
        return df

    # Try to find a suitable label column to parse a genus from.
    label_col = _first_existing_col(
        df,
        ["cluster_label", "species", "scientific_name", "taxon", "Entry", "entry_name"]
    )
    if label_col is None:
        return df

    # Use the engine's extractor for consistency with the tree vocabulary.
    # Fallback: simple first-token on underscore/space if extractor not available.
    def _to_genus(label: object) -> Optional[str]:
        if label is None or (isinstance(label, float) and np.isnan(label)):
            return None
        s = str(label)
        try:
            g = PhylogeneticConsensusEngine.extract_genus(s, return_canonical=False)
            if g:
                return g
        except Exception:
            pass
        # last resort (very permissive)
        token = re.split(r"[ _]+", s.strip(), maxsplit=1)[0]
        return token.lower() if token else None

    df = df.copy()
    df["cluster_genus"] = df[label_col].map(_to_genus)
    return df

def weighted_median(values: np.ndarray, weights: np.ndarray) -> float:
    """
    Compute a weighted median; if inputs are empty or all weights zero, return NaN.
    """
    if values.size == 0 or weights.size == 0:
        return float("nan")
    wsum = float(np.sum(weights))
    if wsum <= 0.0:
        return float("nan")
    idx = np.argsort(values)
    v = values[idx]
    w = weights[idx]
    c = np.cumsum(w) / wsum
    j = min(int(np.searchsorted(c, 0.5)), len(v) - 1)
    return float(v[j])


# ---------- seed consensus generation ----------

if 'df_raw_exons_seed' in globals() and isinstance(df_raw_exons_seed, pd.DataFrame) \
   and not df_raw_exons_seed.empty and phylo_engine.is_ready():

    # 1) Build a robust tax_info table from working_df
    if 'working_df' not in globals() or not isinstance(working_df, pd.DataFrame) or working_df.empty:
        logger.error("working_df is not available or empty; cannot annotate seed with taxonomy.")
        consensus_tbl_seed = pd.DataFrame()
        assert False, "Cannot proceed without working_df."

    # Ensure we have an accession column to merge on.
    tax_info_df = working_df.copy()

    # If 'Entry' is your accession, keep it; otherwise try to coalesce to 'accession'.
    if "Entry" in tax_info_df.columns and "accession" not in tax_info_df.columns:
        tax_info_df = tax_info_df.rename(columns={"Entry": "accession"})
    elif "accession" not in tax_info_df.columns:
        # Try common alternatives.
        acc_col = _first_existing_col(tax_info_df, ["Accession", "uniprot_id", "uniprot_accession"])
        if acc_col:
            tax_info_df = tax_info_df.rename(columns={acc_col: "accession"})
        else:
            logger.error("No accession column found in working_df (expected 'Entry' or 'accession').")
            consensus_tbl_seed = pd.DataFrame()
            assert False, "Cannot proceed without an accession column in working_df."

    # Bring gene_symbol into a canonical name.
    if "gene_symbol_norm" in tax_info_df.columns:
        tax_info_df = tax_info_df.rename(columns={"gene_symbol_norm": "gene_symbol"})
    elif "gene_symbol" not in tax_info_df.columns:
        # Try fallback columns; if none exist, leave missing and we'll handle after merge.
        gs_fallback = _first_existing_col(tax_info_df, ["Gene", "gene", "symbol", "GN"])
        if gs_fallback:
            tax_info_df = tax_info_df.rename(columns={gs_fallback: "gene_symbol"})

    # Ensure cluster_genus exists (coalesce common suffix variants)
    tax_info_df = _coalesce_columns(
        tax_info_df,
        target="cluster_genus",
        sources=["cluster_genus", "cluster_genus_x", "cluster_genus_y"]
    )
    # If still missing or mostly null, derive it from a label via the engine extractor.
    tax_info_df = _derive_cluster_genus_with_engine(tax_info_df)

    # Keep only the columns we need for the merge
    cols_to_keep = [c for c in ["accession", "cluster_genus", "gene_symbol"] if c in tax_info_df.columns]
    tax_info_df = tax_info_df[cols_to_keep].drop_duplicates()

    # 2) Make sure df_raw_exons_seed has an accession to join on
    if "accession" not in df_raw_exons_seed.columns:
        # Try to rename common fields to 'accession'
        ex_acc = _first_existing_col(df_raw_exons_seed, ["Entry", "Accession", "uniprot_id", "uniprot_accession"])
        if ex_acc:
            df_raw_exons_seed = df_raw_exons_seed.rename(columns={ex_acc: "accession"})
        else:
            logger.error("df_raw_exons_seed lacks an accession column to merge on.")
            consensus_tbl_seed = pd.DataFrame()
            assert False, "Cannot proceed without a common accession key."

    # 3) Merge, then coalesce any suffixed columns back
    seed_with_tax = df_raw_exons_seed.merge(tax_info_df, on="accession", how="left")

    # After merges, Pandas can suffix duplicate names; coalesce them explicitly.
    # (Do this even if you do not expect collisions; it prevents intermittent KeyErrors.)
    for base in ("cluster_genus", "gene_symbol"):
        if base not in seed_with_tax.columns:
            left = f"{base}_x"
            right = f"{base}_y"
            if left in seed_with_tax.columns or right in seed_with_tax.columns:
                seed_with_tax = _coalesce_columns(seed_with_tax, base, [left, right])

    # 4) Enforce presence of required columns safely (avoid KeyError in dropna)
    required = ["cluster_genus", "gene_symbol"]
    missing_now = [c for c in required if c not in seed_with_tax.columns]
    if missing_now:
        logger.error("Missing required columns after merge: %s", missing_now)
        consensus_tbl_seed = pd.DataFrame()
        assert False, f"Cannot proceed without columns: {missing_now}"

    # 5) Filter to rows that actually have both fields
    seed_with_tax = seed_with_tax.dropna(subset=["cluster_genus", "gene_symbol"])

    # ---------- weighted consensus ----------

    def _aligned_weighted_median(series: pd.Series, row_weights: pd.Series) -> float:
        """
        Compute a weighted median for a numeric series, aligning weights to the series' index.
        """
        snum = pd.to_numeric(series, errors="coerce").dropna()
        if snum.empty:
            return float("nan")
        # Align weights to the numeric series' index
        w = row_weights.reindex(snum.index).fillna(0.0).to_numpy(dtype=float)
        if np.sum(w) == 0.0:
            w = np.ones_like(snum.to_numpy(), dtype=float)
        return weighted_median(snum.to_numpy(dtype=float), w)

    seed_consensus_rows = []
    grouped = seed_with_tax.groupby(["gene_symbol", "exon_num_in_chain"], sort=False)

    for (g, e), df_grp in tqdm(grouped, desc="Building Seed Consensus"):
        # Compute phylogenetic weights per genus present in this group
        genera = df_grp["cluster_genus"].astype(str).str.lower().unique().tolist()
        weights_by_genus = phylo_engine.weights_from_mrca(genera)

        # Map each row to its genus weight (default 0.0); will be re-aligned per series below
        row_w = df_grp["cluster_genus"].astype(str).str.lower().map(weights_by_genus).fillna(0.0)

        cb = _aligned_weighted_median(df_grp["begin_aa"], row_w)
        ce = _aligned_weighted_median(df_grp["end_aa"], row_w)

        if not np.isnan(cb) and not np.isnan(ce):
            seed_consensus_rows.append(
                {"gene_symbol": g, "exon_num_in_chain": e, "cons_begin": int(round(cb)), "cons_end": int(round(ce))}
            )

    consensus_tbl_seed = pd.DataFrame(seed_consensus_rows)
    logger.info("Generated seed consensus for %d canonical exons.", len(consensus_tbl_seed))

else:
    consensus_tbl_seed = pd.DataFrame()
    logger.error("No seed exon data or phylogenetic engine available to build consensus. Halting rescue.")
    if 'phylo_engine' in globals() and not phylo_engine.is_ready():
        logger.error("Phylogenetic Engine failed to initialise. Check paths in Cell 12 and cache in Cell 15.")
    assert False, "Cannot proceed without seed consensus."

# --- Step 2: Build Regex Baits (Logic from original RegExTractor) ---
logger.info("--- Building Regex 'Baits' from Seed Consensus ---")

@dataclass
class ExonBait:
    regex: re.Pattern
    median_len: int

regex_baits_library: dict[tuple[str, int], ExonBait] = {}

if not consensus_tbl_seed.empty:
    grouped = seed_with_tax.groupby(["gene_symbol", "exon_num_in_chain"], sort=False)
    for (gene, exon_num), group_df in tqdm(grouped, desc="Building Regex Baits"):
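        # Exclude non-string values and placeholder strings, which would otherwise be read as residues.
        peptides = [
            p for p in group_df["peptide"].dropna().astype(str)
            if p and p not in {"MAPPING_FAILED", "MISSING_EXON", "NO_BAIT_DEFINED"}
        ]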
        if len(peptides) < 5:
            continue

        lens = [len(p) for p in peptides if p]
        if not lens:
            continue
        median_len = int(np.median(lens))
        if median_len <= 0:
            continue

        pattern_parts: list[str] = []
        for i in range(median_len):
            chars_at_pos = [p[i] for p in peptides if len(p) > i]
            if not chars_at_pos:
                continue

            counts = defaultdict(int)
            for ch in chars_at_pos:
                counts[ch] += 1
            total = sum(counts.values())
            keep = [c for c, n in counts.items() if (n / total) >= 0.05]

            if len(keep) > 1:
                # Sort for reproducibility; escape every member so characters such as
                # '-' or ']' cannot form a range or close the class early.
                part = "[" + "".join(re.escape(c) for c in sorted(keep)) + "]"
                pattern_parts.append(part)
            elif len(keep) == 1:
                pattern_parts.append(re.escape(keep[0]))
            else:
                pattern_parts.append(".")

        regex_str = "".join(pattern_parts)
        try:
            compiled = re.compile(regex_str)
        except re.error as exc:
            logger.error("Failed to compile regex for (%s, %s): %s", gene, exon_num, exc)
            continue

        regex_baits_library[(gene, exon_num)] = ExonBait(regex=compiled, median_len=median_len)

    logger.info("Successfully built %d high-confidence exon baits.", len(regex_baits_library))


## Cell 51 – Architecture-Driven Rescue Runner

This is the core of the "fishing" operation. It iterates over every sequence in `working_df` and, for each one, walks the exons expected for that gene (from the seed consensus built in the previous cell) in order. Each exon's regex bait is searched for downstream of the previously located exon, and the first match is recorded with 1-based `begin_aa`/`end_aa` coordinates. An exon whose bait finds no match is recorded as `MISSING_EXON`, and an exon for which no bait could be built is recorded as `NO_BAIT_DEFINED`, so every protein ends up with a complete, uniformly structured set of rows.
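The left-to-right search can be illustrated on a toy sequence. `fish_exons` and the baits below are hypothetical. The sketch relies on `re.Pattern.search(string, pos)`, which starts scanning at `pos` and reports spans relative to the whole string, and it converts the 0-based half-open span to the 1-based inclusive `begin_aa`/`end_aa` convention used in this cell.

```python
import re
from typing import Dict, List


def fish_exons(sequence: str, baits: Dict[int, re.Pattern]) -> List[dict]:
    """Hypothetical sketch: locate each expected exon in order, left to right."""
    rows = []
    cursor = 0  # searching resumes after the last exon that was found
    for exon_num in sorted(baits):
        match = baits[exon_num].search(sequence, cursor)
        if match is None:
            rows.append({"exon_num_in_chain": exon_num, "peptide": "MISSING_EXON"})
            continue  # cursor is not advanced past a missing exon
        start, end = match.span()  # 0-based, end-exclusive, relative to the full sequence
        rows.append({"exon_num_in_chain": exon_num, "begin_aa": start + 1,
                     "end_aa": end, "peptide": match.group(0)})
        cursor = end
    return rows


baits = {1: re.compile("GP[AP]"), 2: re.compile("GE[KR]"), 3: re.compile("WWW")}
for row in fish_exons("MGPAGERGPP", baits):
    print(row)
```

Because a missing exon does not advance the cursor, the next bait is searched for from the end of the last exon that was found.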

In [None]:
# ===== Cell 51 =====
# Architecture-Driven Rescue Runner ("Fishing")

logger.info("--- Starting Architecture-Driven Rescue ('Fishing') ---")
rescued_rows = []

# Get a lookup of expected exons for each gene from the seed consensus
expected_exons_by_gene = consensus_tbl_seed.groupby('gene_symbol')['exon_num_in_chain'].apply(list).to_dict()

if 'working_df' in globals() and not working_df.empty:
    for _, row in tqdm(working_df.iterrows(), total=len(working_df), desc="Rescuing Exon Architectures"):
        accession = row['Entry']
        gene = row['gene_symbol_norm']
        sequence = row['Sequence']
        if not isinstance(sequence, str) or not sequence:
            continue  # no usable sequence to search

        expected_exons = sorted(expected_exons_by_gene.get(gene, []))
        if not expected_exons:
            continue

        last_exon_end = 0
        for exon_num in expected_exons:
            bait_info = regex_baits_library.get((gene, exon_num))
            if not bait_info:
                # No bait could be built for this exon: record a placeholder row
                rescued_rows.append({
                    'accession': accession, 'gene_symbol': gene, 'exon_num_in_chain': exon_num,
                    'peptide': 'NO_BAIT_DEFINED', 'source': 'no_bait'
                })
                continue

            # Each bait is a fixed-length pattern (one token per position), so every hit has the
            # bait's median length; take the first hit downstream of the last exon found.
            # Pattern.search(string, pos) reports spans relative to the full sequence.
            match_obj = bait_info.regex.search(sequence, last_exon_end)

            if match_obj:
                start, end = match_obj.span()
                rescued_rows.append({
                    'accession': accession, 'gene_symbol': gene, 'exon_num_in_chain': exon_num,
                    'begin_aa': start + 1,  # 1-based, inclusive
                    'end_aa': end,
                    'peptide': match_obj.group(0),
                    'source': 'rescued'
                })
                last_exon_end = end
            else:
                # If no match is found, this exon is missing
                rescued_rows.append({
                    'accession': accession, 'gene_symbol': gene, 'exon_num_in_chain': exon_num,
                    'peptide': 'MISSING_EXON', 'source': 'rescued_missing'
                })

    # Fixed columns keep the summary below (and Cell 52) working even if nothing was rescued.
    df_rescued_exons = pd.DataFrame(
        rescued_rows,
        columns=['accession', 'gene_symbol', 'exon_num_in_chain', 'begin_aa', 'end_aa', 'peptide', 'source'],
    )
    logger.info(f"✅ Rescue complete. Reconstructed architectures for {df_rescued_exons['accession'].nunique()} proteins.")
    logger.info(f"Found {(df_rescued_exons['source'] == 'rescued').sum()} exons.")
    logger.info(f"Identified {(df_rescued_exons['source'] == 'rescued_missing').sum()} missing exons.")
else:
    df_rescued_exons = pd.DataFrame()
    logger.warning("Working dataframe is empty. Skipping rescue.")

## Cell 52 – Consolidate Data and Finalize Raw Exon Cache

This cell finalizes the rescue process. It merges the original high-confidence seed data with the newly rescued exon data, keeping one row per `(accession, gene_symbol, exon_num_in_chain)`. The consolidated table is then merged atomically into the master `raw_exons_cache.tsv` via `atomic_merge_into_cache`, replacing older, less complete entries. Finally, the master cache is reloaded into the global `df_raw_exons` DataFrame, which all downstream analyses in Part 6 use.
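The consolidation rule can be checked on a toy table. The sketch below uses a hypothetical `consolidate` helper and invented accessions and peptides. It assumes a policy in which a rescued hit replaces the seed row, but a missing-exon placeholder never replaces a seed exon. It ranks sources explicitly rather than sorting the `source` strings, because a descending alphabetical sort would place `'seed'` ahead of `'rescued'`.

```python
import pandas as pd

# Hypothetical ranking: the lower rank wins when the same exon key appears more than once.
SOURCE_RANK = {"rescued": 0, "seed": 1, "rescued_missing": 2}
KEY_COLS = ["accession", "gene_symbol", "exon_num_in_chain"]


def consolidate(seed: pd.DataFrame, rescued: pd.DataFrame) -> pd.DataFrame:
    combined = pd.concat([seed.assign(source="seed"), rescued], ignore_index=True)
    rank = combined["source"].map(SOURCE_RANK).fillna(len(SOURCE_RANK))  # unknown labels rank last
    return (
        combined.assign(_rank=rank)
        .sort_values("_rank", kind="mergesort")        # stable: ties keep input order
        .drop_duplicates(subset=KEY_COLS, keep="first")
        .drop(columns="_rank")
        .sort_values(KEY_COLS, kind="mergesort")
        .reset_index(drop=True)
    )


seed = pd.DataFrame({"accession": ["P1", "P1"], "gene_symbol": ["COL1A1", "COL1A1"],
                     "exon_num_in_chain": [1, 2], "peptide": ["GPP", "GAP"]})
rescued = pd.DataFrame({"accession": ["P1", "P1"], "gene_symbol": ["COL1A1", "COL1A1"],
                        "exon_num_in_chain": [1, 2], "peptide": ["GPA", "MISSING_EXON"],
                        "source": ["rescued", "rescued_missing"]})
print(consolidate(seed, rescued)[["exon_num_in_chain", "peptide", "source"]])
```

Exon 1 keeps the rescued peptide, while exon 2 keeps the seed peptide instead of the `MISSING_EXON` placeholder.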

In [None]:
# ===== Cell 52 =====
# Consolidate Rescued Data and Update Master Cache

logger.info("--- Consolidating seed and rescued exon data ---")

if 'df_rescued_exons' in globals() and not df_rescued_exons.empty:
    # Mark the seed data for clarity
    df_raw_exons_seed['source'] = 'seed'

    # Combine the dataframes
    df_complete_exons = pd.concat([df_raw_exons_seed, df_rescued_exons], ignore_index=True)

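    # Keep one row per exon key. Priority: a rescued hit (part of a complete, reconstructed
    # architecture) wins over the seed row, and the seed row wins over a missing-exon
    # placeholder; any other source label ranks last. An explicit rank is needed because
    # sorting 'source' in descending alphabetical order would put 'seed' before 'rescued'.
    key_cols = ['accession', 'gene_symbol', 'exon_num_in_chain']
    source_rank = {'rescued': 0, 'seed': 1, 'rescued_missing': 2}
    df_complete_exons['_rank'] = df_complete_exons['source'].map(source_rank).fillna(len(source_rank))
    df_complete_exons = (
        df_complete_exons.sort_values('_rank', kind='mergesort')  # stable sort
        .drop_duplicates(subset=key_cols, keep='first')
        .drop(columns='_rank')
    )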

    logger.info("Merging rescued data into the master exon cache...")
    atomic_merge_into_cache(df_complete_exons)

else:
    logger.warning("No rescued exons to consolidate. Using only seed data.")

# --- CRITICAL STEP: Reload the global df_raw_exons from the now-complete cache ---
logger.info("Reloading the master raw exons dataset for final analysis...")
if RAW_EXONS_CACHE.exists():
    df_raw_exons = safe_read_tsv(RAW_EXONS_CACHE)
    logger.info(f"✅ Loaded final, complete raw exons dataset with {len(df_raw_exons)} rows.")
else:
    df_raw_exons = pd.DataFrame()
    logger.warning("Master raw exons cache is not available for Part 6.")

# **Part 6: Final Consensus & Evolutionary Analysis**

## Cell 61 – Weighted Consensus Calculation

This cell calculates the consensus start and end coordinates for each exon across all species. It uses the `PhylogeneticConsensusEngine` for weighting and applies an adaptive boundary tolerance of 2 to 5 residues. The tolerance grows with the phylogenetic spread of each exon group and defaults to 3 residues when the spread cannot be estimated. **Crucially, it first validates and backfills taxonomic data into older cache entries to ensure data consistency.**
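The two helpers defined in this cell can be exercised on invented numbers. The sketch below is a standalone restatement of their logic under the hypothetical names `weighted_median_demo` and `adaptive_tolerance_demo`, so it does not shadow the notebook's functions. A weighted median picks the smallest value at which the cumulative normalised weight reaches 0.5. The tolerance is `round(2 + clip(depth, 0, 3))` residues, with 3 when the depth proxy is NaN.

```python
import numpy as np


def weighted_median_demo(values, weights) -> float:
    """Hypothetical restatement: smallest value whose cumulative weight share reaches 0.5."""
    values = np.asarray(values, dtype=float)
    weights = np.asarray(weights, dtype=float)
    if values.size == 0 or weights.sum() <= 0.0:
        return float("nan")
    order = np.argsort(values)
    share = np.cumsum(weights[order]) / weights.sum()
    j = min(int(np.searchsorted(share, 0.5)), values.size - 1)
    return float(values[order][j])


def adaptive_tolerance_demo(depth_proxy: float) -> int:
    """Hypothetical restatement: 2 + clip(depth, 0, 3) residues, rounded; 3 if NaN."""
    if np.isnan(depth_proxy):
        return 3
    return int(round(2.0 + float(np.clip(depth_proxy, 0.0, 3.0))))


# Invented start coordinates reported by three genera for one exon, with their weights.
print(weighted_median_demo([101, 104, 110], [0.6, 0.3, 0.1]))  # 101.0
print([adaptive_tolerance_demo(x) for x in (0.0, 1.4, 9.0, float("nan"))])  # [2, 3, 5, 3]
```

The genus carrying 60% of the weight alone crosses the 0.5 share, so its coordinate becomes the consensus start regardless of how far the other two genera disagree.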

In [None]:
# ===== Cell 61 =====
# FAST weighted consensus calculation (with integrated cache upgrade)

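def weighted_median(values: np.ndarray, weights: np.ndarray) -> float:
    """Compute the weighted median; return NaN for empty input or non-positive total weight."""
    values = np.asarray(values, dtype=float)
    weights = np.asarray(weights, dtype=float)
    if values.size == 0 or weights.size == 0:
        return float("nan")
    wsum = float(np.sum(weights))
    if wsum <= 0.0:
        return float("nan")
    idx = np.argsort(values)
    v, w = values[idx], weights[idx]
    c = np.cumsum(w) / wsum
    j = min(int(np.searchsorted(c, 0.5)), len(v) - 1)
    return float(v[j])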

def adaptive_tolerance(depth_proxy: float) -> int:
    """Calculates boundary tolerance based on phylogenetic depth."""
    if pd.isna(depth_proxy): return 3
    x = float(np.clip(depth_proxy, 0.0, 3.0))
    return int(round(2.0 + x))

consensus_rows, refined_rows = [], []
consensus_tbl, consensus_long = pd.DataFrame(), pd.DataFrame()

if 'df_raw_exons' in globals() and not df_raw_exons.empty:
    base_df = df_raw_exons.copy()
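    # Drop failed mappings, rescue placeholders and rows without coordinates: a NaN
    # coordinate can be selected by the weighted median and then break int(round(...)).
    placeholders = {'MAPPING_FAILED', 'MISSING_EXON', 'NO_BAIT_DEFINED'}
    base_df = base_df[~base_df['peptide'].isin(placeholders)]
    base_df = base_df.dropna(subset=['begin_aa', 'end_aa']).copy()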

    # --- ROBUST CACHE UPGRADE AND VALIDATION ---
    required_cols = ['cluster_genus', 'Family', 'Order']
    missing_cols = [c for c in required_cols if c not in base_df.columns]

    if missing_cols:
        logger.warning(f"Cache is missing required columns: {missing_cols}. Backfilling from current working_df...")
        tax_cols_to_add = ['Entry', 'cluster_genus', 'Family', 'Order', 'Clas_id']
        available_tax_cols = [c for c in tax_cols_to_add if c in working_df.columns]

        if 'Entry' in available_tax_cols:
            lineage_lut = working_df[available_tax_cols].drop_duplicates(subset=['Entry'])
            lineage_lut = lineage_lut.rename(columns={'Entry': 'accession'})
            # Replace any partial taxonomy already in the cache; otherwise the merge would
            # create cluster_genus_x/_y and the dropna below would fail with a KeyError.
            stale_cols = [c for c in lineage_lut.columns if c != 'accession' and c in base_df.columns]
            if 'organism' in base_df.columns: stale_cols.append('organism')
            base_df = base_df.drop(columns=stale_cols)
            base_df = pd.merge(base_df, lineage_lut, on='accession', how='left')
            logger.info("Cache backfill complete.")

    if 'cluster_genus' not in base_df.columns:
        raise KeyError("'cluster_genus' is still missing after the cache backfill; cannot weight the consensus.")

    initial_rows = len(base_df)
    base_df = base_df.dropna(subset=['cluster_genus'])
    removed_rows = initial_rows - len(base_df)
    if removed_rows:
        logger.info(f"Removed {removed_rows} orphaned exon rows from cache that are not in the current working set.")

    grouped = base_df.groupby(['gene_symbol', 'exon_num_in_chain'], sort=False)
    logger.info(f"Calculating consensus for {len(grouped)} exon groups...")

    for (g, e), df in tqdm(grouped, desc="Calculating Consensus"):
        genera = df['cluster_genus'].tolist()
        weights_by_genus = phylo_engine.weights_from_mrca(genera)

        # Depth proxy per genus: inverts w = 1/(1 + d), i.e. d = 1/w - 1
        ds = [(1.0 / max(1e-9, wt) - 1.0) for wt in weights_by_genus.values()]
        mean_depth = float(np.mean(ds)) if ds else float('nan')
        tol = adaptive_tolerance(mean_depth)

        w = df['cluster_genus'].map(weights_by_genus).fillna(0.0).to_numpy()
        if w.sum() == 0: w = np.ones(len(df)) / max(1, len(df))

        b = df['begin_aa'].astype(float).to_numpy()
        epos = df['end_aa'].astype(float).to_numpy()
        cb = int(round(weighted_median(b, w)))
        ce = int(round(weighted_median(epos, w)))

        consensus_rows.append({
            'gene_symbol': g, 'exon_num_in_chain': e,
            'cons_begin': cb, 'cons_end': ce,
            'tolerance_aa': tol, 'depth_proxy': mean_depth
        })

        for _, r in df.iterrows():
            db, de = int(r['begin_aa']) - cb, int(r['end_aa']) - ce
            refined_rows.append({
                'accession': r['accession'], 'gene_symbol': g, 'exon_num_in_chain': e,
                'begin_aa': int(r['begin_aa']), 'end_aa': int(r['end_aa']),
                'peptide': r.get('peptide',''), 'cluster_genus': r.get('cluster_genus',''),
                'cons_begin': cb, 'cons_end': ce, 'delta_begin': db, 'delta_end': de,
                'boundary_ok': (abs(db) <= tol) and (abs(de) <= tol)
            })

    consensus_tbl = pd.DataFrame(consensus_rows)
    consensus_long = pd.DataFrame(refined_rows)

    consensus_tbl.to_csv(CONSENSUS_TABLE_TSV, sep='\t', index=False)
    consensus_long.to_csv(CONSENSUS_LONG_TSV, sep='\t', index=False)
    logger.info(f"Consensus calculated. Long table: {len(consensus_long)} rows, Consensus table: {len(consensus_tbl)} rows.")
    logger.info(f"Results saved to {CONSENSUS_LONG_TSV.name} and {CONSENSUS_TABLE_TSV.name}")
else:
    logger.warning("No raw exon data available; skipping consensus calculation.")

## Cell 62 – MRCA Level & Reliability Scoring

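This cell determines the deepest (most inclusive) taxonomic rank, from Genus through Family and Order to Class, at which exon lengths stay consistent: within every group at that rank, their standard deviation must not exceed the exon's adaptive tolerance. It also computes a 0–100 reliability score for each exon group: 45% phylogenetic diversity (0.2 per genus plus 0.1 per order, capped at 1), 35% the fraction of boundaries within tolerance, and 20% mean sequence quality (50 when unavailable), scaled down by up to 20% for wide tolerances.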
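To make the weighting explicit, the sketch below recomputes the same formula from summary numbers rather than a DataFrame. The inputs are invented for illustration; `reliability_score` is a hypothetical helper that mirrors `reliability_score_for_group`.

```python
import numpy as np

def reliability_score(n_genera, n_orders, ok_frac, mean_quality, tol):
    """Mirror of the cell's formula, operating on summary numbers instead of a DataFrame."""
    phylo_div = min(1.0, 0.2 * n_genera + 0.1 * n_orders)   # capped diversity term
    seq_q = mean_quality / 100.0                              # quality on a 0-1 scale
    tol_factor = 1.0 - np.clip((tol - 2) / 3.0, 0, 1) * 0.2   # 0% penalty at 2 aa, 20% at 5 aa
    score = 100.0 * (0.45 * phylo_div + 0.35 * ok_frac + 0.20 * seq_q) * tol_factor
    return float(np.clip(score, 0, 100))

# 3 genera in 2 orders, 90% of boundaries within tolerance, default quality 50, tolerance 4 aa:
# phylo_div = 0.8, weighted sum = 0.775, tol_factor = 1 - (2/3) * 0.2, score is about 67.17
print(round(reliability_score(3, 2, 0.9, 50.0, 4), 1))  # 67.2
```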

In [None]:
# ===== Cell 62 =====
# MRCA level + reliability scoring

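def mrca_level_for_group(df: pd.DataFrame, tol: int) -> str:
    """Returns the deepest (most inclusive) taxonomic rank with consistent exon boundaries.

    Ranks are tested from Genus upwards; the walk stops at the first rank where any
    subgroup's exon-length standard deviation exceeds `tol`.
    """
    levels = [('cluster_genus', 'Genus'), ('Family', 'Family'), ('Order', 'Order'), ('Clas_id', 'Class')]
    deepest = "None"
    for col, label in levels:
        if col not in df.columns: continue
        # fillna before astype(str): astype(str) turns NaN into 'nan', which fillna cannot catch
        keys = df[col].fillna('NA').astype(str)
        is_consistent = True
        for _, sub in df.groupby(keys):
            spread = (sub['end_aa'] - sub['begin_aa']).std()  # exon-length spread as a boundary proxy
            if pd.notna(spread) and spread > tol:
                is_consistent = False
                break
        if not is_consistent:
            break
        deepest = label
    return deepest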

def reliability_score_for_group(df: pd.DataFrame, tol: int) -> float:
    """Computes a 0-100 reliability score for an exon group."""
    n_genera = df['cluster_genus'].nunique()
    n_orders = df['Order'].nunique() if 'Order' in df.columns else 0
    phylo_div = min(1.0, 0.2 * n_genera + 0.1 * n_orders)
    ok_frac = float((df['boundary_ok'] == True).mean())
    seq_q = df.get('quality_score', pd.Series([50.0])).fillna(50.0).mean() / 100.0
    tol_factor = 1.0 - np.clip((tol - 2) / 3.0, 0, 1) * 0.2 # Penalize wide tolerance
    score = 100.0 * (0.45 * phylo_div + 0.35 * ok_frac + 0.20 * seq_q) * tol_factor
    return float(np.clip(score, 0, 100))

mrca_rows = []

if not consensus_long.empty and not consensus_tbl.empty:
    logger.info("Building a comprehensive taxonomic lookup from the master dataset...")
    lineage_cols = ['Entry', 'Clas_id', 'Order', 'Family', 'cluster_genus']

    source_df = full_df if 'full_df' in globals() and not full_df.empty else working_df

    lineage_lut = source_df[[c for c in lineage_cols if c in source_df.columns]].drop_duplicates(subset=['Entry'])
    lineage_lut = lineage_lut.rename(columns={'Entry':'accession'})

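    q_lut = pd.DataFrame()
    if {'accession', 'exon_num_in_chain', 'quality_score'}.issubset(df_raw_exons.columns):
        # One quality score per (accession, exon) so the merge cannot duplicate rows
        q_lut = (df_raw_exons[['accession', 'exon_num_in_chain', 'quality_score']]
                 .drop_duplicates(subset=['accession', 'exon_num_in_chain']))

    # consensus_long already carries cons_begin/cons_end; take only the per-group fields from consensus_tbl
    merged = consensus_long.merge(
        consensus_tbl[['gene_symbol', 'exon_num_in_chain', 'tolerance_aa', 'depth_proxy']],
        on=['gene_symbol', 'exon_num_in_chain']
    )
    # Add only lineage columns not already present (e.g. cluster_genus), avoiding _x/_y suffixes
    lineage_lut = lineage_lut[[c for c in lineage_lut.columns if c == 'accession' or c not in merged.columns]]
    merged = merged.merge(lineage_lut, on='accession', how='left')
    if not q_lut.empty:
        merged = merged.merge(q_lut, on=['accession', 'exon_num_in_chain'], how='left')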

    for (g, e), sub in tqdm(merged.groupby(['gene_symbol','exon_num_in_chain']), desc="Scoring Reliability"):
        if 'cluster_genus' not in sub.columns or sub['cluster_genus'].isna().all():
            logger.warning(f"Skipping scoring for {g} exon {e} due to missing taxonomic data even after master lookup.")
            mrca = "N/A"
            rel = 0.0
        else:
            tol = int(sub['tolerance_aa'].iloc[0])
            mrca = mrca_level_for_group(sub, tol)
            rel = reliability_score_for_group(sub.dropna(subset=['cluster_genus']), tol)

        mrca_rows.append({
            'gene_symbol': g, 'exon_num_in_chain': e,
            'MRCA_level': mrca, 'reliability_score': round(rel, 1),
            'tolerance_aa': sub['tolerance_aa'].iloc[0],
            'depth_proxy': sub['depth_proxy'].iloc[0]
        })

mrca_df = pd.DataFrame(mrca_rows)
if not mrca_df.empty:
    logger.info(f"MRCA/reliability computed for {len(mrca_df)} exon groups.")
else:
    logger.warning("No consensus data to calculate MRCA/reliability scores.")

## Cell 63 – Evolutionary Event Dating

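This cell uses the phylogenetic dating engine to analyze the `consensus_long` table. It identifies subgroups of species that share a specific structural trait and uses the pre-computed node-age cache to estimate *when* that trait likely evolved, taking the age of the Most Recent Common Ancestor (MRCA) of the subgroup. At present the only trait modelled is a shift of an exon's start boundary relative to the consensus (a nonzero `delta_begin`). For each shift, `branch_of_origin_MY` is the age of the MRCA of all genera carrying the exon minus the age of the MRCA of the genera carrying the shift; subgroups whose MRCA age is 0 (e.g. a single genus) are not reported. The results provide a quantitative model of how the gene's exon architecture evolved over time.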
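The dating logic can be illustrated on a toy tree. This sketch assumes, as the cell does, an MRCA lookup that returns `(node_name, age)`; the tree, node names (`N1`–`N3`), genera (`G1`–`G4`) and ages are all hypothetical, and `mrca_age` is a simple stand-in for `phylo_engine.get_mrca_age`, not its implementation.

```python
from collections import defaultdict

# Hypothetical toy phylogeny (child -> parent) and node ages in MYA; leaves (genera) have age 0.
PARENT = {"G1": "N1", "G2": "N1", "G3": "N2", "N1": "N2", "G4": "N3", "N2": "N3"}
NODE_AGE = {"N1": 10.0, "N2": 40.0, "N3": 100.0}

def lineage(taxon):
    """Path from a taxon up to the root, inclusive."""
    path = [taxon]
    while path[-1] in PARENT:
        path.append(PARENT[path[-1]])
    return path

def mrca_age(genera):
    """MRCA = first node on the first genus's path that also lies on every other genus's path."""
    paths = [lineage(g) for g in genera]
    shared = set(paths[0]).intersection(*map(set, paths[1:]))
    node = next(n for n in paths[0] if n in shared)
    return node, NODE_AGE.get(node, 0.0)

# (genus, delta_begin) for one exon; 0 means the genus uses the consensus start boundary.
records = [("G1", 3), ("G2", 3), ("G3", 0), ("G4", 0)]

basal_name, basal_age = mrca_age([g for g, _ in records])
by_shift = defaultdict(list)
for genus, shift in records:
    if shift != 0:
        by_shift[shift].append(genus)

events = []
for shift, genera in by_shift.items():
    node, age = mrca_age(genera)
    if age > 0:  # a single-genus subgroup resolves to a leaf (age 0) and is not dated
        events.append({"delta_begin": shift, "event_mrca": node,
                       "event_mrca_age_MYA": age,
                       "branch_of_origin_MY": basal_age - age})
print(basal_name, events)
```

Here the +3 shift is shared by `G1` and `G2`, whose MRCA (`N1`, 10 MYA) sits 90 MY below the basal MRCA of all four genera, so the shift is placed somewhere on that 90 MY stretch of lineage.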

In [None]:
# ===== Cell 63 =====
# Model Exon Evolution by Dating Structural Changes

if 'consensus_long' in globals() and not consensus_long.empty and 'phylo_engine' in globals():
    logger.info("Modeling exon evolution by dating structural changes...")

    evolutionary_events = []

    # Group by each consensus exon (gene + exon number)
    grouped = consensus_long.groupby(['gene_symbol', 'exon_num_in_chain'])

    for (gene, exon_num), group_df in tqdm(grouped, desc="Dating Evolutionary Events"):

        # 1. Find the BASAL MRCA for the existence of this exon in this dataset
        all_genera_in_group = group_df['cluster_genus'].dropna().unique().tolist()
        basal_mrca_name, basal_mrca_age = phylo_engine.get_mrca_age(all_genera_in_group)

        # 2. Identify subgroups based on shared structural TRAITS
        # Trait modelled here: N-terminal (start) boundary shifts, i.e. delta_begin != 0
        boundary_shifts = group_df['delta_begin'].unique()
        for shift in boundary_shifts:
            if shift == 0: continue # This is the consensus state, not an event

            subgroup_df = group_df[group_df['delta_begin'] == shift]
            subgroup_genera = subgroup_df['cluster_genus'].dropna().unique().tolist()

            # 3. Find the MRCA for the subgroup that shares the trait
            event_mrca_name, event_mrca_age = phylo_engine.get_mrca_age(subgroup_genera)

            if event_mrca_age > 0:
                evolutionary_events.append({
                    "gene": gene,
                    "exon": exon_num,
                    "event_type": "boundary_shift",
                    "event_detail": f"delta_begin = {int(shift)}",
                    "basal_mrca_age_MYA": round(basal_mrca_age, 2),
                    "event_mrca_age_MYA": round(event_mrca_age, 2),
                    "branch_of_origin_MY": round(basal_mrca_age - event_mrca_age, 2),
                    "event_mrca_name": event_mrca_name,
                    "num_taxa_with_event": len(subgroup_df)
                })

    # 4. Create and save the final event DataFrame
    if evolutionary_events:
        df_evolution = pd.DataFrame(evolutionary_events)
        df_evolution.to_csv(EVOLUTION_EVENTS_TSV, sep='\t', index=False)
        logger.info(f"✅ Identified and dated {len(df_evolution)} evolutionary events.")
        logger.info(f"Saved event model to: {EVOLUTION_EVENTS_TSV.name}")
        display(df_evolution.head())
    else:
        logger.info("No distinct evolutionary events (like boundary shifts) were identified.")

else:
    logger.warning("Skipping evolutionary event dating: consensus data or phylo engine not available.")

## Cell 64 – Final Wide Architecture Generation

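This cell transforms the long-format consensus data into a wide-format table in which each row is a single protein sequence (`accession`, `gene_symbol`) and per-exon values sit in columns named `exon_<field>_<n>` (peptide, begin/end coordinates, boundary flag), followed by per-gene consensus columns `cons_exon_<field>_<n>` (consensus coordinates and reliability). This "exon architecture" format is useful for comparative analysis and visualization. To limit memory use, each table is pivoted once and the two wide tables are joined with a single merge on `gene_symbol`.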
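A minimal sketch of the pivot-and-flatten step on a hypothetical two-accession table, showing how the `(field, exon)` MultiIndex becomes names like `exon_peptide_2`, and how the exon number can be read back from the column name (as Cell 71 does). The data are invented for illustration.

```python
import pandas as pd

# Hypothetical long table: two accessions, two exons each
long_df = pd.DataFrame({
    "accession": ["P1", "P1", "P2", "P2"],
    "gene_symbol": ["COL1A1"] * 4,
    "exon_num_in_chain": [1, 2, 1, 2],
    "peptide": ["GPPGAP", "GEPGAP", "GPAGAP", "GEPGAP"],
    "begin_aa": [1, 7, 1, 8],
})

# aggfunc="first" keeps string values such as peptides (the default "mean" would fail on them)
wide = long_df.pivot_table(
    index=["accession", "gene_symbol"],
    columns="exon_num_in_chain",
    values=["peptide", "begin_aa"],
    aggfunc="first",
)
# Flatten the (field, exon) MultiIndex into names like 'exon_peptide_2'
wide.columns = [f"exon_{field}_{exon}" for field, exon in wide.columns]
wide = wide.reset_index()

# The exon number is the last underscore-separated token of the column name
exon_nums = sorted(int(c.rsplit("_", 1)[-1]) for c in wide.columns if c.startswith("exon_peptide_"))
print(exon_nums)  # [1, 2]
```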

In [None]:
# ===== Cell 64 =====
# Final wide architecture with memory management

def create_wide_architecture(
    long_df: pd.DataFrame,
    tbl_df: pd.DataFrame,
    mrca_df: pd.DataFrame
) -> pd.DataFrame:
    """Builds the final wide-format exon architecture dataframe."""
    if long_df.empty: return pd.DataFrame()

    logger.info(f"Creating wide architecture from {long_df['accession'].nunique()} accessions.")

    # 1. Pivot the long table to create exon-specific columns for sequence data
    pivoted = long_df.pivot_table(
        index=['accession', 'gene_symbol'],
        columns='exon_num_in_chain',
        values=['peptide', 'begin_aa', 'end_aa', 'boundary_ok'],
        aggfunc='first'
    )
    pivoted.columns = [f"exon_{v}_{c}" if c != '' else v for v, c in pivoted.columns]
    pivoted.reset_index(inplace=True)

    # 2. Prepare a separate, wide-format table for consensus and reliability info
    annot = tbl_df.merge(mrca_df, on=['gene_symbol', 'exon_num_in_chain'], how='left')

    annot['cons_coords'] = annot['cons_begin'].astype(str) + '-' + annot['cons_end'].astype(str)
    annot = annot.rename(columns={'reliability_score': 'cons_reliability'})

    # Pivot the annotation data to a wide format
    consensus_wide = annot.pivot_table(
        index='gene_symbol',
        columns='exon_num_in_chain',
        values=['cons_coords', 'cons_reliability'],
        # aggfunc='first' is required: the default 'mean' fails on string columns such as 'cons_coords'
        aggfunc='first'
    )
    # Flatten the multi-level column index from the pivot
    consensus_wide.columns = [f"cons_exon_{v}_{c}" for v, c in consensus_wide.columns]
    consensus_wide.reset_index(inplace=True)

    # 3. Perform a single, efficient merge to combine the two wide tables
    final_wide_df = pd.merge(pivoted, consensus_wide, on='gene_symbol', how='left')

    return final_wide_df

if 'consensus_long' in globals() and not consensus_long.empty:
    wide_df = create_wide_architecture(consensus_long, consensus_tbl, mrca_df)
    if not wide_df.empty:
        wide_df = wide_df.copy()
        wide_df.to_csv(WIDE_ARCH_TSV, sep='\t', index=False)
        logger.info(f"Final wide architecture saved: {len(wide_df)} rows, {len(wide_df.columns)} columns.")
        logger.info(f"File saved to: {WIDE_ARCH_TSV}")
    else:
        logger.warning("Wide DataFrame is empty after processing.")
else:
    logger.warning("No consensus data available; skipping final wide architecture.")
    wide_df = pd.DataFrame()

# **Part 7: RegExTractor – Gene Classification Engine**

## Cell 70 – RegExTractor Configuration

This cell defines the user-configurable parameters for the RegExTractor engine. In this pipeline the engine uses the final consensus architectures from Part 6 to build orthology-aware exon patterns, then scans the pool of sequences that were rejected or left unclassified during pre-processing and predicts their gene identity.

In [None]:
# ===== Cell 70 =====
# RegExTractor configuration

from __future__ import annotations
from dataclasses import dataclass, field

# --- User-tunable parameters for RegExTractor ---
rex_min_clade_samples = 5           #@param {type:"integer"}
rex_anchor_triplets_min = 2         #@param {type:"integer"}
rex_anchor_triplets_max = 4         #@param {type:"integer"}
rex_anchor_entropy_max = 0.25       #@param {type:"number"}
rex_freq_threshold_strict = 0.05    #@param {type:"number"}
rex_freq_threshold_moderate = 0.01  #@param {type:"number"}
rex_ghead_density_min = 0.80        #@param {type:"number"}
rex_len_tolerance_strict = 1        #@param {type:"integer"}
rex_len_tolerance_moderate = 3      #@param {type:"integer"}
rex_len_tolerance_loose = 6         #@param {type:"integer"}
rex_chain_min_consecutive = 3       #@param {type:"integer"}
rex_search_window_pad = 90          #@param {type:"integer"}

# --- New Path Variables for Classification Output ---
# NOTE: These should be added to the path definitions in Cell 12
CLASSIFIED_HITS_TSV = RUN_DIR / "classified_hits.tsv"
CLASSIFIED_CHAINS_TSV = RUN_DIR / "classified_chains.tsv"

def rex_log(msg: str):
    """Lightweight logger for the RegExTractor module."""
    logger.info(f"[RegExTractor] {msg}")

## Cell 71 – Data Preparation for Classification

This cell prepares the two key inputs for the classification engine:
1.  **Training Data:** The final, high-quality, complete exon data from `wide_df` (output of Part 6), reshaped to one row per exon peptide. `paralog_group` and taxonomy (`Genus`, `Family`, `Order`) are merged in from `working_df`, so each training row is orthology-aware.
2.  **Target Pool:** The set of truly unclassified sequences—those from the original `full_df` that were rejected during the initial pre-processing in Part 2.
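The row-by-row extraction in the cell below can also be expressed with `DataFrame.melt`, which avoids `iterrows` on large tables; the target pool is a set difference on `Entry`. A minimal sketch, assuming the `exon_peptide_<n>` column naming produced in Cell 64. The helper names (`wide_to_training_long`, `unclassified_pool`) and demo data are hypothetical, and the taxonomy/paralog columns are omitted for brevity.

```python
import pandas as pd

PLACEHOLDERS = {"MISSING_EXON", "NO_BAIT_DEFINED"}

def wide_to_training_long(wide: pd.DataFrame) -> pd.DataFrame:
    """One row per real exon peptide; placeholders and empty cells are dropped."""
    pep_cols = [c for c in wide.columns if str(c).startswith("exon_peptide_")]
    long = wide.melt(id_vars=["accession", "gene_symbol"], value_vars=pep_cols,
                     var_name="col", value_name="exon_peptide")
    # 'exon_peptide_12' -> 12 (float() first tolerates labels such as '12.0')
    long["exon_num_in_chain"] = long["col"].str.rsplit("_", n=1).str[-1].astype(float).astype(int)
    keep = long["exon_peptide"].notna() & ~long["exon_peptide"].isin(PLACEHOLDERS)
    return long.loc[keep].drop(columns="col").reset_index(drop=True)

def unclassified_pool(full: pd.DataFrame, working: pd.DataFrame) -> pd.DataFrame:
    """Rows of `full` whose Entry never made it into the working (classified) set."""
    classified = set(working["Entry"].dropna().astype(str))
    return full[~full["Entry"].astype(str).isin(classified)]

# Demo with invented data
wide_demo = pd.DataFrame({"accession": ["P1", "P2"], "gene_symbol": ["COL1A1", "COL1A1"],
                          "exon_peptide_1": ["GPP", "MISSING_EXON"],
                          "exon_peptide_2": ["GAP", None]})
train_long = wide_to_training_long(wide_demo)
pool = unclassified_pool(pd.DataFrame({"Entry": ["A", "B", "C"]}),
                         pd.DataFrame({"Entry": ["B"]}))
print(len(train_long), list(pool["Entry"]))
```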

In [None]:
# ===== Cell 71 =====
# Input preparation for RegExTractor Classification

# 1. Create the TRAINING set from the final, rescued wide_df (from Part 6)
training_rows_rex = []

if 'wide_df' in globals() and isinstance(wide_df, pd.DataFrame) and not wide_df.empty:
    rex_log(f"[71] Preparing training set from wide_df: rows={len(wide_df)}")

    # We need the paralog group and taxonomy for each entry; merge from working_df
    training_info_cols = ['Entry', 'paralog_group', 'Genus', 'Family', 'Order']
    if 'working_df' in globals() and isinstance(working_df, pd.DataFrame) and not working_df.empty:
        have_cols = [c for c in training_info_cols if c in working_df.columns]
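        # One row per accession so the merge cannot duplicate wide_df rows
        info_lut = working_df[have_cols].drop_duplicates(subset=['Entry']).rename(columns={'Entry':'accession'})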
        rex_log(f"[71] Info LUT columns: {have_cols}")
        wide_for_training = wide_df.merge(info_lut, on='accession', how='left')
    else:
        rex_log("[71] WARNING: working_df missing/empty; proceeding without taxonomy/paralog_group merge.")
        wide_for_training = wide_df.copy()

    rex_log(f"[71] Merged view: rows={len(wide_for_training)}; building long-form exon table…")

    # Extract all exon peptides from the wide format into a long format
    err, done = 0, 0
    for i, (_, row) in enumerate(wide_for_training.iterrows(), 1):
        try:
            # Map exon number -> peptide column, e.g. 'exon_peptide_12' -> 12
            # (float() also tolerates labels such as 'exon_peptide_12.0')
            pep_cols = {int(float(str(c).rsplit('_', 1)[-1])): c
                        for c in row.index if str(c).startswith('exon_peptide_')}
            for exon_num, pep_col in pep_cols.items():
                peptide = row[pep_col]
                # Skip empty cells and placeholder strings; every real peptide (seed or rescued) is used
                if pd.notna(peptide) and peptide not in ["MISSING_EXON", "NO_BAIT_DEFINED"]:
                    rec = {
                        "accession": row.get("accession"),
                        "gene_symbol": row.get('gene_symbol'),
                        "paralog_group": row.get('paralog_group', 'unknown_paralog'),
                        "exon_num_in_chain": exon_num,
                        "exon_peptide": peptide,
                        "genus": row.get('Genus'),
                        "family": row.get('Family'),
                        "order": row.get('Order'),
                    }
                    training_rows_rex.append(rec)
            done += 1
        except Exception as e:
            err += 1
            if err <= 5:
                rex_log(f"[71] Row parse error example (#{i}): {e}")
        # Light progress ping every 2k rows
        if i % 2000 == 0:
            rex_log(f"[71] Processed {i} records…")

    training_df_rex = pd.DataFrame(training_rows_rex)
    rex_log(f"[71] Created RegExTractor training set with {len(training_df_rex)} exon rows "
            f"(source records ok={done}, errors={err}).")

else:
    rex_log("[71] WARNING: wide_df missing/empty; training_df_rex will be empty.")
    training_df_rex = pd.DataFrame(columns=[
        "accession","gene_symbol","paralog_group","exon_num_in_chain","exon_peptide",
        "genus","family","order"
    ])

# 2. Create the TARGET pool of unclassified sequences to scan
target_pool_rex = pd.DataFrame()
if 'full_df' in globals() and 'working_df' in globals() \
   and isinstance(full_df, pd.DataFrame) and isinstance(working_df, pd.DataFrame) \
   and not full_df.empty and not working_df.empty:

    classified_accessions = set(working_df['Entry'].dropna().astype(str).unique())
    unclassified_df = full_df[~full_df['Entry'].astype(str).isin(classified_accessions)].copy()

    # Minimum taxonomy for targets
    unclassified_df['Genus'] = unclassified_df['Organism'].str.split().str[0] if 'Organism' in unclassified_df.columns else pd.NA

    target_pool_rex = unclassified_df[['Entry', 'Sequence', 'Genus']].rename(
        columns={'Entry':'accession', 'Sequence':'sequence', 'Genus': 'genus'}
    )
    rex_log(f"[71] Created classification target pool with {len(target_pool_rex)} unclassified sequences.")
else:
    rex_log("[71] WARNING: full_df/working_df missing/empty; target_pool_rex will be empty.")


## Cell 72 – Paralog Annotation & Validation for RegEx Training
This single cell **replaces 72B/72C/72E/72F** by combining enrichment, inference, and
conservative validation of *gene-level* paralogs for `training_df_rex`.

**Inputs (assumed present):**
- `training_df_rex` (required)
- `df_high_quality` (optional; used to enrich missing text/meta via a pragmatic join)

**What this cell does (non-destructive):**
1) *(Optional)* **Enrich** `training_df_rex` with `gene_symbol_norm`, “Gene Names (primary)”,
   “Protein names”, `stitle`, `description`, `raw_header` from `df_high_quality`
   using common ID keys (`seq_id`, `accession`, `uniprot`, etc.).
2) **Infer** a *gene-level* paralog label (`COLxAy`) with priority:
   `gene_symbol_norm` → “Gene Names (primary)” → parsed **Protein names** / titles
   (e.g., “Collagen alpha-1(I) chain” → `COL1A1`).
   - Adds (but **does not overwrite** existing values):  
     `paralog_group`, `paralog_source`, `paralog_confidence`.
3) **Validate** with a conservative relabel:
   - Adds: `paralog_group_validated`, `paralog_validated_source`,
     `paralog_validated_confidence`.
   - Conflict audit: `paralog_conflict_flag`, `paralog_conflict_reason`
     (e.g., flags likely **XI ↔ I** confusions).
4) **Diagnostics**: compact counts + top groups.
5) **Snapshot** (optional): writes `training_df_rex_paralog_annotated.tsv` to `RUN_DIR` (if set).

**Notes**
- No functions/variables elsewhere are renamed or removed.
- Existing `paralog_group` is **only filled where missing** (never overwritten).
- Use `paralog_group_validated` downstream if you want the conservative label.
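Step 2's text parsing hinges on converting the Roman-numeral collagen type into an Arabic number. The cell uses a fixed lookup table (`_ROMAN_TO_ARABIC`); the sketch below shows an alternative general subtractive-notation converter used with the same regexes. It is an illustrative variant, not the cell's code; `roman_to_int` and `gene_from_protein_name` are hypothetical names.

```python
import re

_RX_PROTEIN = re.compile(
    r'\b(?:pro-)?collagen\s+alpha[-\s]?(?P<chain>\d)\s*\(\s*(?P<roman>[IVXLC]+)\s*\)\s*chain\b',
    re.IGNORECASE)
_RX_TYPE_FIRST = re.compile(
    r'\btype\s+(?P<roman>[IVXLC]+)\s+collagen\b.*?\balpha[-\s]?(?P<chain>\d)\b',
    re.IGNORECASE)
_RX_LITERAL = re.compile(r'COL\d+A\d+', re.IGNORECASE)

_ROMAN = {"I": 1, "V": 5, "X": 10, "L": 50, "C": 100}

def roman_to_int(s: str) -> int:
    """Subtractive Roman numerals: a smaller symbol before a larger one is subtracted (IV = 4)."""
    s = s.upper()
    total = 0
    for i, ch in enumerate(s):
        v = _ROMAN[ch]
        if i + 1 < len(s) and _ROMAN[s[i + 1]] > v:
            total -= v
        else:
            total += v
    return total

def gene_from_protein_name(text: str):
    """Literal COLxAy first, then 'Collagen alpha-n(ROMAN) chain' or 'type ROMAN collagen ... alpha-n'."""
    m = _RX_LITERAL.search(text)
    if m:
        return m.group(0).upper()
    m = _RX_PROTEIN.search(text) or _RX_TYPE_FIRST.search(text)
    if not m:
        return None
    return f"COL{roman_to_int(m.group('roman'))}A{m.group('chain')}"

for name in ["Collagen alpha-1(I) chain", "Collagen alpha-2(XI) chain",
             "Type XXVIII collagen, alpha-1 subunit", "Keratin, type I cytoskeletal 10"]:
    print(name, "->", gene_from_protein_name(name))
```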


In [None]:
# ===== Cell 72 =====
# Paralog Annotation & Validation (combined, non-destructive)

import re
from pathlib import Path
from typing import Optional, List, Tuple
import pandas as pd

# ----------------------- User-tunable controls (#@param) -----------------------
ATTEMPT_JOIN_FROM_DF_HQ   = True   #@param {type:"boolean"}
STRICT_GENE_PATTERN       = True   #@param {type:"boolean"}  # require COL\d+A\d+
ENABLE_DIAGNOSTICS        = True   #@param {type:"boolean"}
SAVE_SNAPSHOT             = True   #@param {type:"boolean"}

# -------------------------- Helpers (local to cell) ----------------------------
_ROMAN_TO_ARABIC = {
    "I":1,"II":2,"III":3,"IV":4,"V":5,"VI":6,"VII":7,"VIII":8,"IX":9,
    "X":10,"XI":11,"XII":12,"XIII":13,"XIV":14,"XV":15,"XVI":16,"XVII":17,
    "XVIII":18,"XIX":19,"XX":20,"XXI":21,"XXII":22,"XXIII":23,"XXIV":24,
    # Vertebrate collagen types run to XXVIII (COL28A1)
    "XXV":25,"XXVI":26,"XXVII":27,"XXVIII":28
}

_RX_COL_LITERAL      = re.compile(r'(COL\d+A\d+)', re.IGNORECASE)
_RX_PROTEIN_COLLAGEN = re.compile(
    r'\b(?:pro-)?collagen\s+alpha[-\s]?(?P<chain>\d)\s*\(\s*(?P<roman>[IVXLC]+)\s*\)\s*chain\b',
    re.IGNORECASE
)
_RX_TYPE_FIRST = re.compile(
    r'\btype\s+(?P<roman>[IVXLC]+)\s+collagen\b.*?\balpha[-\s]?(?P<chain>\d)\b',
    re.IGNORECASE
)

def _looks_like_col_gene(s: str) -> bool:
    return bool(re.fullmatch(r'COL\d+A\d+', str(s).strip().upper()))

def _coalesce(row: pd.Series, cols: List[str]) -> Optional[str]:
    for c in cols:
        v = row.get(c)
        if isinstance(v, str) and v.strip():
            return v.strip()
    return None

def _parse_gene_from_text(txt: Optional[str]) -> Optional[str]:
    """Parse from literal COLxAy or 'Collagen alpha-<n>(<ROMAN>) chain' / 'type <ROMAN> collagen ... alpha-<n>'."""
    if not isinstance(txt, str) or not txt.strip():
        return None
    s = txt.strip()
    m = _RX_COL_LITERAL.search(s)
    if m:
        return m.group(1).upper()
    m = _RX_PROTEIN_COLLAGEN.search(s) or _RX_TYPE_FIRST.search(s)
    if not m:
        return None
    try:
        chain = int(m.group("chain"))
        roman = m.group("roman").upper()
        ctype = _ROMAN_TO_ARABIC.get(roman)
        return f"COL{ctype}A{chain}" if ctype else None
    except Exception:
        return None

def _pick_join_key(left: List[str], right: List[str]) -> Optional[str]:
    for k in ["seq_id","source_seq_id","record_id","qseqid","sseqid",
              "accession","uniprot","uniprot_id","uniprot_accession",
              "subject_id","protein_id","entry","Entry"]:
        if k in left and k in right:
            return k
    return None

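def _enrich_from_df_hq(df_train: pd.DataFrame, df_hq: pd.DataFrame) -> pd.DataFrame:
    key = _pick_join_key(list(df_train.columns), list(df_hq.columns))
    if not key:
        return df_train
    candidates = [
        "gene_symbol_norm", "gene_symbol",
        "Gene Names (primary)", "Protein names",
        "stitle", "description", "raw_header",
        "species", "species_name",
        "accession", "uniprot", "uniprot_id", "uniprot_accession",
    ]
    # Bring in only columns df_train lacks: avoids _x/_y suffixes and a duplicate label when the
    # join key (e.g. 'accession') also appears in the candidate list. Keeping one df_hq row per
    # key prevents the merge from multiplying training rows.
    cols_want = [key] + [c for c in candidates
                         if c in df_hq.columns and c not in df_train.columns and c != key]
    hq = df_hq[cols_want].drop_duplicates(subset=[key])
    return df_train.merge(hq, on=key, how="left")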

def _infer_gene_paralog(row: pd.Series) -> Tuple[Optional[str], str, float]:
    # 1) gene_symbol_norm / gene_symbol
    gsn = _coalesce(row, ["gene_symbol_norm","gene_symbol"])
    if gsn and (not STRICT_GENE_PATTERN or _looks_like_col_gene(gsn)):
        return gsn.upper(), "gene_symbol_norm", 1.0
    # 2) Gene Names (primary)
    gnp = _coalesce(row, ["Gene Names (primary)","Gene names (primary)","gene_names_primary","Gene Names"])
    if gnp:
        m = _RX_COL_LITERAL.search(gnp)
        if m:
            return m.group(1).upper(), "gene_names_primary", 0.9
    # 3) Protein names / titles
    for col in ["Protein names","Protein name","Protein names (recommended)",
                "protein_name","protein_names","stitle","description","raw_header"]:
        lab = _parse_gene_from_text(row.get(col))
        if lab and (not STRICT_GENE_PATTERN or _looks_like_col_gene(lab)):
            return lab, "protein_or_title.parsed", 0.75
    # 4) fallback to existing (may be empty)
    pg = row.get("paralog_group")
    if isinstance(pg, str) and pg.strip():
        return pg.strip().upper(), "existing", 0.5
    return None, "NA", 0.0

def _validated_label(row: pd.Series) -> Tuple[Optional[str], str, float, bool, str]:
    """
    Conservative validated label & conflict audit.
    Returns: (label, source, conf, conflict_flag, reason)
    """
    original = row.get("paralog_group")
    gsn = _coalesce(row, ["gene_symbol_norm","gene_symbol"])

    # Prefer GSN if it looks like a collagen gene
    if gsn and (not STRICT_GENE_PATTERN or _looks_like_col_gene(gsn)):
        label = gsn.upper()
        conflict = isinstance(original, str) and original.upper() != label
        return label, "gene_symbol_norm", 1.0, conflict, ("GSN!=original" if conflict else "")

    # Gene Names (primary)
    gnp = _coalesce(row, ["Gene Names (primary)","Gene names (primary)","gene_names_primary","Gene Names"])
    if gnp:
        m = _RX_COL_LITERAL.search(gnp)
        if m:
            label = m.group(1).upper()
            conflict = isinstance(original, str) and original.upper() != label
            return label, "gene_names_primary", 0.9, conflict, ("GNP!=original" if conflict else "")

    # Parse text
    for col in ["Protein names","Protein name","Protein names (recommended)",
                "protein_name","protein_names","stitle","description","raw_header"]:
        cand = row.get(col)
        lab = _parse_gene_from_text(cand)
        if lab and (not STRICT_GENE_PATTERN or _looks_like_col_gene(lab)):
            conflict = isinstance(original, str) and original.upper() != lab
            reason = ""
            if conflict and isinstance(cand, str):
                # flag likely XI ↔ I confusion
                if re.search(r'\btype\s+I\b', cand, re.IGNORECASE) and not re.search(r'\btype\s+XI\b', cand, re.IGNORECASE):
                    if isinstance(original, str) and original.upper().startswith("COL11A"):
                        reason = "possible XI↔I confusion"
            return lab, "protein_or_title.parsed", 0.75, conflict, reason

    # Fallback to original
    lab = original.upper() if isinstance(original, str) else None
    return lab, ("existing" if lab else "NA"), (0.5 if lab else 0.0), False, ("fallback_to_existing" if lab else "unlabeled")

# --------------------------------- Main logic ---------------------------------
if "training_df_rex" in globals() and isinstance(training_df_rex, pd.DataFrame) and not training_df_rex.empty:
    df = training_df_rex.copy()

    # (1) Optional enrichment from df_high_quality
    if ATTEMPT_JOIN_FROM_DF_HQ and "df_high_quality" in globals() \
       and isinstance(df_high_quality, pd.DataFrame) and not df_high_quality.empty:
        before_cols = set(df.columns)
        df = _enrich_from_df_hq(df, df_high_quality)
        if ENABLE_DIAGNOSTICS:
            gained = sorted(list(set(df.columns) - before_cols))
            print(f"[Paralog-Enrich] Columns gained from df_high_quality: {gained}")

    # (2) Primary inference → only FILL missing 'paralog_group'
    inferred = df.apply(_infer_gene_paralog, axis=1)
    df["_pg_inferred"]  = inferred.apply(lambda t: t[0])
    df["_pg_src"]       = inferred.apply(lambda t: t[1])
    df["_pg_conf"]      = inferred.apply(lambda t: t[2])

    if "paralog_group" not in df.columns:
        df["paralog_group"] = pd.NA
    if "paralog_source" not in df.columns:
        df["paralog_source"] = pd.NA
    if "paralog_confidence" not in df.columns:
        df["paralog_confidence"] = pd.NA

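    # Missing = NaN/None, blank, or the Cell 71 placeholder 'unknown_paralog'.
    # (astype(str) alone turns NaN into the non-empty string 'nan', so it never looked missing.)
    pg_norm = df["paralog_group"].fillna("").astype(str).str.strip().str.lower()
    mask_missing_pg = pg_norm.isin(["", "nan", "none", "<na>", "unknown_paralog"])
    # Fill only where a new label was inferred, not merely echoed back from the existing value
    mask_fill = mask_missing_pg & df["_pg_inferred"].notna() & df["_pg_src"].ne("existing")
    df.loc[mask_fill, "paralog_group"]      = df.loc[mask_fill, "_pg_inferred"]
    df.loc[mask_fill, "paralog_source"]     = df.loc[mask_fill, "_pg_src"]
    df.loc[mask_fill, "paralog_confidence"] = df.loc[mask_fill, "_pg_conf"]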

    # (3) Validated label + conflicts (always computed; non-destructive)
    validated = df.apply(_validated_label, axis=1)
    df["paralog_group_validated"]       = validated.apply(lambda t: t[0])
    df["paralog_validated_source"]      = validated.apply(lambda t: t[1])
    df["paralog_validated_confidence"]  = validated.apply(lambda t: t[2])
    df["paralog_conflict_flag"]         = validated.apply(lambda t: t[3])
    df["paralog_conflict_reason"]       = validated.apply(lambda t: t[4])

    # tidy temp columns
    df.drop(columns=["_pg_inferred","_pg_src","_pg_conf"], inplace=True, errors="ignore")

    # (4) Diagnostics
    if ENABLE_DIAGNOSTICS:
        total = len(df)
        n_pg   = int(df["paralog_group"].notna().sum())
        n_pgv  = int(df["paralog_group_validated"].notna().sum())
        n_conf = int(df["paralog_conflict_flag"].sum())
        print(f"[Paralog] rows={total}; filled(paralog_group)={n_pg}; validated={n_pgv}; conflicts={n_conf}")
        with pd.option_context("display.max_rows", 25):  # large enough to show the top-20 lists in full
            print("[Paralog] Top groups:")
            display(df["paralog_group"].value_counts(dropna=True).head(20))
            print("[Paralog-Validated] Top groups:")
            display(df["paralog_group_validated"].value_counts(dropna=True).head(20))
            if n_conf:
                print("[Paralog] Conflict reasons:")
                display(df["paralog_conflict_reason"].value_counts(dropna=True).head(10))

    # (5) Save snapshot
    if SAVE_SNAPSHOT:
        try:
            run_dir = RUN_DIR if "RUN_DIR" in globals() else Path(".")
            outp = Path(run_dir) / "training_df_rex_paralog_annotated.tsv"
            df.to_csv(outp, sep="\t", index=False)
            if "logger" in globals() and logger:
                logger.info(f"Saved paralog-annotated training set → {outp}")
        except Exception as e:
            if "logger" in globals() and logger:
                logger.warning(f"Could not save paralog snapshot: {e}")

    # Commit back to canonical variable (non-destructive to others)
    training_df_rex = df
else:
    print("training_df_rex not found or empty; skipping Cell 72.")


## Cell 73 – Orthology-Aware Regex Library Builder (label-agnostic)

### Purpose
Build an **orthology-aware regex library** of exon motifs per paralog and exon index,
without hard-coding a particular label column. This cell consumes the training set
from Cell 71 and the validated paralog labels from Cell 72, then emits a library
used by the matcher and classifier in Cells 75–76.

### Inputs & Dependencies
- **DataFrames**
  - `training_df_rex` (from Cell 71; required).
- **Labels**
  - Chooses label column via `rex_label_column` (`"auto"`, `"paralog_group_validated"`,
    `"paralog_group"`, or `"custom"`). Creates a safe view `SOURCE_DF` where the chosen
    label is exposed as `paralog_group`.
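- **Config**
  - `rex_min_clade_samples` (Cell 70), `rex_freq_cutoff_consensus`, `rex_freq_cutoff_stringent`,
    `rex_ghead_min_consensus`, `rex_ghead_min_stringent`, and optional `REX_TAXON_LEVELS`.
    Cell 70 defines `rex_freq_threshold_strict`, `rex_freq_threshold_moderate` and
    `rex_ghead_density_min` rather than the four cutoff names above, so make sure those are defined.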
- **Expected columns in the effective training view (`SOURCE_DF`)**
  - Required: `paralog_group`, `exon_num_in_chain`, `exon_peptide`.
  - Optional: `gene_symbol` (metadata), clade columns named in `REX_TAXON_LEVELS`
    (e.g., `order`, `family`, `genus`).

### What it creates
- **`SOURCE_DF`** — a copy of `training_df_rex` with the chosen label normalized to
  the column name `paralog_group` for legacy compatibility.
- **`orthology_aware_library`**
  - Internal container (`OrthologyAwareLibrary`) with `entries` keyed by
    `(paralog_group, exon_num_in_chain, clade_key)`.
  - Each value is an `ExonRegexLibraryEntry` holding one or more `RexTier` objects:
    - *consensus* tier (uses `rex_freq_cutoff_consensus`, `rex_ghead_min_consensus`)
    - *stringent* tier (uses `rex_freq_cutoff_stringent`, `rex_ghead_min_stringent`)
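The key space above can be illustrated with a small sketch. `enumerate_library_keys` is a hypothetical helper (not part of this notebook) that lists which `(paralog_group, exon_num_in_chain, clade_key)` keys the builder would attempt. It assumes the same rules as the cell: skip `unknown_paralog`, add a `pan` key when the group has at least `min_samples` rows, and add one key per clade value in any `REX_TAXON_LEVELS` column that also meets the threshold. Peptide cleaning and tier compilation are left out.

```python
import pandas as pd

def enumerate_library_keys(df: pd.DataFrame, taxon_levels: list, min_samples: int) -> list:
    """Hypothetical helper: list (paralog_group, exon, clade_key) keys with enough samples."""
    keys = []
    for (pg, exon), edf in df.groupby(["paralog_group", "exon_num_in_chain"]):
        if pg == "unknown_paralog":
            continue  # unknown labels never get a library entry
        if len(edf) >= min_samples:
            keys.append((pg, int(exon), "pan"))  # pan model over all taxa
        for level in taxon_levels:
            if level == "pan" or level not in edf.columns:
                continue  # only real clade columns produce clade-specific keys
            for clade, sub in edf.groupby(level):
                if len(sub) >= min_samples:
                    keys.append((pg, int(exon), str(clade)))
    return keys

toy = pd.DataFrame({
    "paralog_group": ["COL1A1"] * 5 + ["unknown_paralog"] * 2,
    "exon_num_in_chain": [6] * 7,
    "order": ["Primates"] * 3 + ["Rodentia"] * 2 + ["Primates"] * 2,
    "exon_peptide": ["GPPGPP"] * 7,
})
keys = enumerate_library_keys(toy, ["pan", "order"], min_samples=3)
print(keys)
```

Here `Rodentia` (2 rows) misses the threshold and the `unknown_paralog` rows are never keyed, so only the `pan` and `Primates` entries are produced for `COL1A1` exon 6.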

### High-level flow
1. **Label selection**  
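   Resolve `LABEL_COL` from `rex_label_column`. In `"auto"` mode, prefer `paralog_group_validated`, then `paralog_group`, then `paralog_group_effective`, and finally any column whose name contains `paralog`. An explicit choice whose column is missing falls back to this order; `"custom"` raises instead.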
   Build `SOURCE_DF` and expose the chosen label as `paralog_group`.
2. **Sanity checks**  
   Ensure `SOURCE_DF` has `paralog_group`, `exon_num_in_chain`, `exon_peptide`.
   Warn if `unknown_paralog` is present.
3. **Tier construction**  
   For each `(paralog_group, exon_num_in_chain[, clade])` with at least
   `rex_min_clade_samples` peptides:
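   - Count residues at each position, with peptides left-aligned at position 0 (no gapped alignment), up to the median peptide length.
   - Derive per-position regex tokens by thresholding residue frequencies; a position where no residue reaches the cutoff becomes `.`.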
   - Build two tiers (*consensus*, *stringent*), compile regex.
4. **Library assembly**  
   Store tiers in `orthology_aware_library.entries[(pg, exon, clade)]`.
   Always attempt a **pan** model; optionally add clade-specific models
   if `REX_TAXON_LEVELS` contains columns present in `SOURCE_DF`.
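A minimal sketch of the tier-construction step, assuming the same rules as `RegExTractorBuilder`: peptides are left-aligned at position 0, the pattern width is the median peptide length, a residue is kept at a position when its frequency reaches `freq_cutoff`, and a position where nothing qualifies becomes `.`. `tier_pattern` is a hypothetical stand-alone function; the cutoffs 0.05 and 0.15 are the notebook defaults for `rex_freq_cutoff_consensus` and `rex_freq_cutoff_stringent`.

```python
import re
from collections import Counter
from statistics import median

def tier_pattern(peptides, freq_cutoff):
    """Hypothetical stand-alone version of the position-wise thresholding."""
    peps = [p for p in peptides if p]
    length = int(median([len(p) for p in peps]))
    parts = []
    for i in range(length):
        column = [p[i] for p in peps if len(p) > i]   # residues observed at position i
        counts = Counter(column)
        keep = sorted(aa for aa, n in counts.items() if n / len(column) >= freq_cutoff)
        if not keep:
            parts.append(".")                          # nothing frequent enough: wildcard
        elif len(keep) == 1:
            parts.append(re.escape(keep[0]))           # single dominant residue
        else:
            parts.append("[" + "".join(re.escape(aa) for aa in keep) + "]")
    return "".join(parts)

peps = ["GPA", "GPC", "GPD", "GPE", "GPF", "GPH", "GPI", "GPK", "GPL", "GSM"]
consensus = tier_pattern(peps, 0.05)
stringent = tier_pattern(peps, 0.15)
print(consensus, stringent)
```

The higher cutoff narrows position 1 to `P`, but it also turns the highly variable position 2 into a wildcard, so a "stringent" tier is not uniformly tighter than the "consensus" tier. The builder adds no `^`/`$` anchors, and each pattern has a fixed width equal to the median peptide length.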

### Diagnostics (printed)
- Selected label column name, total rows, labeled vs unlabeled counts.
- Final count of library entries and number of distinct paralog groups.

### Notes & Guarantees
- **Non-destructive**: does not rename or drop user data; only creates `SOURCE_DF`
  and `orthology_aware_library`.
- **Label-agnostic**: switch labels by changing `rex_label_column` or specifying a
  `custom_label_name` that exists in `training_df_rex`.
- **Downstream**: Cells **75–76** (matcher & classification) read from
  `orthology_aware_library` and/or `SOURCE_DF`.


In [None]:
# ===== Cell 73 =====
# Orthology-Aware Regex Library Builder (label-agnostic)

from __future__ import annotations

import logging
import re
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple

import numpy as np
import pandas as pd

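# Reuse the notebook-wide logger if an earlier cell configured one; do not clobber it
logger = globals().get("logger") or logging.getLogger(__name__)

# Fallback for rex_log, which is normally defined in an earlier RegExTractor config cell
if "rex_log" not in globals():
    def rex_log(msg: str) -> None:
        """Print a RegExTractor message and mirror it to the logger."""
        print(msg)
        logger.info(msg)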

# ------------------------------------------------------------------------------
# Label selection (no hard-coding)
# ------------------------------------------------------------------------------
rex_label_column = "auto"   #@param ["auto", "paralog_group_validated", "paralog_group", "custom"]
custom_label_name = ""      #@param {type:"string"}

if "training_df_rex" not in globals() or not isinstance(training_df_rex, pd.DataFrame) or training_df_rex.empty:
    raise RuntimeError("Cell 73: training_df_rex is missing or empty. Build it in Cell 71 and label it in Cell 72.")

LABEL_CANDIDATES = [
    "paralog_group_validated",
    "paralog_group",
    "paralog_group_effective",  # optional alias if you already created it earlier
]

def _select_label_col(df: pd.DataFrame) -> str:
    # explicit choices first
    if rex_label_column == "paralog_group_validated" and "paralog_group_validated" in df.columns:
        return "paralog_group_validated"
    if rex_label_column == "paralog_group" and "paralog_group" in df.columns:
        return "paralog_group"
    if rex_label_column == "custom":
        if custom_label_name and custom_label_name in df.columns:
            return custom_label_name
        raise KeyError(f"Requested custom label column '{custom_label_name}' not found in training_df_rex.")

    # 'auto' fallback: prefer validated
    for c in LABEL_CANDIDATES:
        if c in df.columns:
            return c
    # last resort: try to guess any string column containing 'paralog' in its name
    for c in df.columns:
        if isinstance(c, str) and "paralog" in c.lower():
            return c
    raise KeyError("No suitable label column found. Expected one of: "
                   "paralog_group_validated / paralog_group / custom.")

LABEL_COL = _select_label_col(training_df_rex)

# ------------------------------------------------------------------------------
# Create a safe view for legacy code: always expose chosen label as 'paralog_group'
# ------------------------------------------------------------------------------
SOURCE_DF = training_df_rex.copy()
if LABEL_COL != "paralog_group":
    SOURCE_DF["paralog_group"] = SOURCE_DF[LABEL_COL]

# Light sanity checks
_n = len(SOURCE_DF)
_n_labeled = SOURCE_DF["paralog_group"].notna().sum()
print(f"[Cell 73] Using label column: {LABEL_COL}")
print(f"[Cell 73] Rows: {_n}; labeled: {_n_labeled}; unlabeled: {_n - _n_labeled}")

# Warn about unknown labels (the library builder below skips them)
if "unknown_paralog" in SOURCE_DF["paralog_group"].astype(str).unique():
    unk = int((SOURCE_DF["paralog_group"] == "unknown_paralog").sum())
    print(f"[Cell 73] Warning: {unk} rows have 'unknown_paralog'. They are skipped when building the library.")

# ---------------------------------------------------------------------------
# Configuration knobs (safe defaults; override earlier in the notebook if desired)
# ---------------------------------------------------------------------------

# Which taxonomic levels to attempt beyond 'pan' (must be column names in training_df)
try:
    REX_TAXON_LEVELS  # noqa: F823
except NameError:
    REX_TAXON_LEVELS: List[str] = ["pan"]  # extend e.g. ["pan", "order", "family"]

# Minimum number of peptide examples per (paralog, exon, clade) to build a model
try:
    rex_min_clade_samples  # noqa: F823
except NameError:
    rex_min_clade_samples: int = 5

# Position-wise residue frequency cut-offs for tiers
try:
    rex_freq_cutoff_consensus  # noqa: F823
except NameError:
    rex_freq_cutoff_consensus: float = 0.05

try:
    rex_freq_cutoff_stringent  # noqa: F823
except NameError:
    rex_freq_cutoff_stringent: float = 0.15

# Tier acceptance thresholds: minimum "g-head density" (gden, the pattern-conservation
# proxy computed by the matcher) required for a hit from each tier
try:
    rex_ghead_min_consensus  # noqa: F823
except NameError:
    rex_ghead_min_consensus: float = 0.85

try:
    rex_ghead_min_stringent  # noqa: F823
except NameError:
    rex_ghead_min_stringent: float = 0.92

# ---------------------------------------------------------------------------
# Utility functions for robust column handling and simple stats
# ---------------------------------------------------------------------------

def _has_cols(df: pd.DataFrame, cols: Iterable[str]) -> bool:
    """Return True if all columns exist in the DataFrame."""
    cols = list(cols)
    return all(c in df.columns for c in cols)

def _first_existing_col(df: pd.DataFrame, candidates: List[str]) -> Optional[str]:
    """Return the first candidate column name that exists in df, else None."""
    for c in candidates:
        if c in df.columns:
            return c
    return None

def _coalesce_training_columns(df: pd.DataFrame) -> pd.DataFrame:
    """
    Normalise expected column names in training_df:
      - peptides in 'exon_peptide' (fall back to 'peptide')
      - ensure 'paralog_group' and 'exon_num_in_chain' exist
      - keep 'gene_symbol' if available (used for metadata)
    """
    df = df.copy()

    # Peptide column
    if "exon_peptide" not in df.columns:
        fallback = _first_existing_col(df, ["peptide", "Peptide", "peptides"])
        if fallback:
            df["exon_peptide"] = df[fallback]

    # Sanity checks
    required = ["paralog_group", "exon_num_in_chain", "exon_peptide"]
    missing = [c for c in required if c not in df.columns]
    if missing:
        raise ValueError(f"training_df is missing required columns: {missing}")

    return df

# ---------------------------------------------------------------------------
# Dataclasses expected by downstream matcher/orchestrator
# ---------------------------------------------------------------------------

@dataclass
class RexTier:
    """
    One regex tier describing an exon motif at a given stringency.
    Attributes are aligned with the expectations of RegExTractorMatcher/_score_hit.
    """
    tier: str               # e.g., "consensus", "stringent"
    regex: re.Pattern       # compiled Python regex
    ghead_min: float        # acceptance threshold used by the matcher

@dataclass
class ExonRegexLibraryEntry:
    """Container for all tiers available for one (paralog_group, exon, clade)."""
    gene_symbol: str
    exon_num_in_chain: int
    clade: str              # e.g., "pan", "Cetartiodactyla"
    tiers: List[RexTier]

# ---------------------------------------------------------------------------
# Core builder: statistics + tier construction
# ---------------------------------------------------------------------------

class RegExTractorBuilder:
    """
    Builds per-exon regex tiers from peptide examples.

    The default implementation provides two tiers:
      - 'consensus'  : frequency cut-off = rex_freq_cutoff_consensus; ghead_min = rex_ghead_min_consensus
      - 'stringent'  : frequency cut-off = rex_freq_cutoff_stringent; ghead_min = rex_ghead_min_stringent

    If you have advanced utilities (Shannon entropy, anchor windows, etc.),
    you can refine `build_stats` and/or `build_tiers` here without touching the orchestrator.
    """

    def __init__(self) -> None:
        pass

    @staticmethod
    def build_stats(peptides: List[str]) -> Dict[str, object]:
        """
        Compute simple per-position character frequencies and lengths.

        Returns
        -------
        Dict[str, object]
            keys: 'median_len', 'pos_counts' (list[dict[char->count]]), 'pos_totals' (list[int])
        """
        peps = [p for p in peptides if isinstance(p, str) and p]
        if not peps:
            return {"median_len": 0, "pos_counts": [], "pos_totals": []}

        median_len = int(np.median([len(p) for p in peps]))
        pos_counts: List[Dict[str, int]] = []
        pos_totals: List[int] = []

        for i in range(median_len):
            counts = defaultdict(int)
            total = 0
            for p in peps:
                if len(p) > i:
                    counts[p[i]] += 1
                    total += 1
            pos_counts.append(dict(counts))
            pos_totals.append(total)

        return {"median_len": median_len, "pos_counts": pos_counts, "pos_totals": pos_totals}

    @staticmethod
    def _pattern_from_stats(stats: Dict[str, object], freq_cutoff: float) -> Tuple[str, int]:
        """
        Convert per-position counts to a regex pattern using a frequency threshold.

        Parameters
        ----------
        stats : dict produced by build_stats
        freq_cutoff : float in [0,1], minimal residue frequency to include at a position

        Returns
        -------
        pattern : str
        median_len : int
        """
        median_len: int = int(stats.get("median_len", 0))
        pos_counts: List[Dict[str, int]] = stats.get("pos_counts", [])  # type: ignore
        pos_totals: List[int] = stats.get("pos_totals", [])              # type: ignore

        parts: List[str] = []
        for i in range(median_len):
            counts = pos_counts[i] if i < len(pos_counts) else {}
            total = pos_totals[i] if i < len(pos_totals) else 0
            if not counts or total <= 0:
                parts.append(".")
                continue

            keep = [aa for aa, n in counts.items() if (n / total) >= freq_cutoff]
            if len(keep) == 1:
                parts.append(re.escape(keep[0]))
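            elif len(keep) > 1:
                # Escape each residue so symbols such as '-' or '*' cannot form class ranges
                parts.append("[" + "".join(re.escape(aa) for aa in sorted(keep)) + "]")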
            else:
                parts.append(".")
        return "".join(parts), median_len

    def build_tiers(self, peptides: List[str], stats: Dict[str, object]) -> List[RexTier]:
        """
        Construct a set of tiers from peptides and stats. Returns an empty list if unusable.
        """
        if not peptides or stats.get("median_len", 0) <= 0:
            return []

        tiers: List[RexTier] = []

        # Consensus tier
        patt_cons, _ = self._pattern_from_stats(stats, freq_cutoff=rex_freq_cutoff_consensus)
        try:
            tiers.append(RexTier(tier="consensus", regex=re.compile(patt_cons), ghead_min=rex_ghead_min_consensus))
        except re.error:
            pass  # ignore un-compilable patterns

        # Stringent tier
        patt_str, _ = self._pattern_from_stats(stats, freq_cutoff=rex_freq_cutoff_stringent)
        try:
            tiers.append(RexTier(tier="stringent", regex=re.compile(patt_str), ghead_min=rex_ghead_min_stringent))
        except re.error:
            pass

        return [t for t in tiers if isinstance(t.regex, re.Pattern)]

# ---------------------------------------------------------------------------
# Library container keyed by (paralog_group, exon_num_in_chain, clade_key)
# ---------------------------------------------------------------------------

class OrthologyAwareLibrary:
    """
    Builds and stores an orthology-aware regex library for exon detection.

    Key space: (paralog_group, exon_num_in_chain, clade_key)
    Each value: ExonRegexLibraryEntry with one or more RexTier objects.
    """

    def __init__(self) -> None:
        self.entries: Dict[Tuple[str, int, str], ExonRegexLibraryEntry] = {}

    def build(self, training_df: pd.DataFrame) -> None:
        """
        Construct the library from training data.

        Expected columns in training_df:
          - 'paralog_group'        : str (not 'unknown_paralog')
          - 'exon_num_in_chain'    : int
          - 'exon_peptide'         : str (sequence)
        Optional (but recommended):
          - 'gene_symbol'          : str (metadata only)
          - additional clade columns named in REX_TAXON_LEVELS (e.g., 'order', 'family')
        """
        builder = RegExTractorBuilder()
        df = _coalesce_training_columns(training_df)

        built = 0
        for paralog_group, pg_df in df.groupby("paralog_group"):
            if str(paralog_group) == "unknown_paralog":
                continue

            for exon_num in sorted(pg_df["exon_num_in_chain"].unique()):
                edf = pg_df[pg_df["exon_num_in_chain"] == exon_num]

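                # Keep usable peptides only (drop NaN, empty strings and the placeholder
                # tokens that Cell 74 also treats as empty) before applying the threshold
                peps = [
                    p for p in edf["exon_peptide"].dropna().astype(str)
                    if p.strip() and p not in ("MISSING_EXON", "NO_BAIT_DEFINED")
                ]
                # Always attempt a 'pan' model if enough usable samples exist
                if len(peps) >= rex_min_clade_samples:
                    stats = builder.build_stats(peps)
                    tiers = builder.build_tiers(peps, stats)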
                    if tiers:
                        # Robust gene_symbol fallback: gene_symbol → gene_symbol_norm → "unknown"
                        gene_symbol = (
                            edf["gene_symbol"].dropna().astype(str).iloc[0]
                            if "gene_symbol" in edf.columns and edf["gene_symbol"].notna().any()
                            else (
                                edf["gene_symbol_norm"].dropna().astype(str).iloc[0]
                                if "gene_symbol_norm" in edf.columns and edf["gene_symbol_norm"].notna().any()
                                else "unknown"
                            )
                        )
                        self.entries[(paralog_group, int(exon_num), "pan")] = ExonRegexLibraryEntry(
                            gene_symbol=gene_symbol,
                            exon_num_in_chain=int(exon_num),
                            clade="pan",
                            tiers=tiers,
                        )
                        built += 1

                # Then attempt additional clade-specific models if configured and present
                for clade_level in REX_TAXON_LEVELS:
                    if clade_level in (None, "pan"):
                        continue
                    if clade_level not in edf.columns:
                        continue

                    for clade_key, sub_df in edf.groupby(clade_level):
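                        # Drop NaN, empty strings and placeholder tokens before applying the threshold
                        peps = [
                            p for p in sub_df["exon_peptide"].dropna().astype(str)
                            if p.strip() and p not in ("MISSING_EXON", "NO_BAIT_DEFINED")
                        ]
                        if len(peps) < rex_min_clade_samples:
                            continue
                        stats = builder.build_stats(peps)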
                        tiers = builder.build_tiers(peps, stats)
                        if not tiers:
                            continue
                        # Robust gene_symbol fallback for clade entries
                        gene_symbol = (
                            sub_df["gene_symbol"].dropna().astype(str).iloc[0]
                            if "gene_symbol" in sub_df.columns and sub_df["gene_symbol"].notna().any()
                            else (
                                sub_df["gene_symbol_norm"].dropna().astype(str).iloc[0]
                                if "gene_symbol_norm" in sub_df.columns and sub_df["gene_symbol_norm"].notna().any()
                                else "unknown"
                            )
                        )
                        self.entries[(paralog_group, int(exon_num), str(clade_key))] = ExonRegexLibraryEntry(
                            gene_symbol=gene_symbol,
                            exon_num_in_chain=int(exon_num),
                            clade=str(clade_key),
                            tiers=tiers,
                        )
                        built += 1

        rex_log(f"Built {built} orthology-aware exon/clade entries "
                f"across {len(set(k[0] for k in self.entries))} paralog groups.")

# ---- Runner: build the library from SOURCE_DF ----
orthology_aware_library = OrthologyAwareLibrary()
orthology_aware_library.build(SOURCE_DF)

# quick diag
n_entries = len(orthology_aware_library.entries)
n_groups  = len({k[0] for k in orthology_aware_library.entries})
rex_log(f"[Cell 73] Library entries: {n_entries} across {n_groups} paralog groups.")


## Cell 74 – Diagnostics: training coverage (“why 0 entries?”)

### Purpose
Explain **why the library in Cell 73** produced few/zero entries by summarizing coverage
per `(paralog_group, exon_num_in_chain)` on the **effective label view**.  
This cell reads the same view used by the builder (`SOURCE_DF`), so counts reflect your
chosen label (e.g., `paralog_group_validated`).

### Inputs
- `SOURCE_DF` (from Cell 73). If absent, falls back to `training_df_rex`.
- `rex_min_clade_samples` (from Cell 70).

### What it prints
1) Selected label column (from Cell 73) and the top labels.  
2) Non-empty peptide count; rows with `unknown_paralog`.  
3) Group coverage per `(paralog_group, exon_num_in_chain)` and how many meet
   `rex_min_clade_samples`.  
4) Optional clade-column coverage (`order`, `family`, `genus`) if present.
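The threshold gate in step 3 can be reproduced on a toy frame. `coverage_gate` is a simplified, hypothetical version of `_summary_training`: it treats missing values as empty, drops the placeholder tokens `MISSING_EXON` and `NO_BAIT_DEFINED` as well as `unknown_paralog` rows, counts peptides per `(paralog_group, exon_num_in_chain)`, and keeps the groups that reach `min_samples`.

```python
import pandas as pd

PLACEHOLDERS = ["MISSING_EXON", "NO_BAIT_DEFINED"]  # tokens the diagnostics treat as empty

def coverage_gate(df: pd.DataFrame, min_samples: int):
    """Hypothetical helper: per-group peptide counts and the groups that pass the gate."""
    pep = df["exon_peptide"].fillna("").astype(str)
    usable = ~(pep.eq("") | pep.isin(PLACEHOLDERS)) & (df["paralog_group"] != "unknown_paralog")
    grp = (
        df.loc[usable]
          .groupby(["paralog_group", "exon_num_in_chain"])
          .size()
          .rename("n")
          .reset_index()
    )
    return grp, grp[grp["n"] >= min_samples]

toy = pd.DataFrame({
    "paralog_group": ["COL1A1"] * 6 + ["COL1A2"] * 3,
    "exon_num_in_chain": [6] * 9,
    "exon_peptide": ["GPPGAP"] * 4 + ["MISSING_EXON", None] + ["GPAGPP"] * 3,
})
grp, passing = coverage_gate(toy, min_samples=4)
print(grp)
print(passing)
```

`COL1A1` exon 6 keeps 4 usable peptides (the placeholder and the missing value are dropped), `COL1A2` exon 6 has 3, so only `COL1A1` passes `min_samples=4`. A group that is absent from the library but present here with a small `n` is most likely being filtered by `rex_min_clade_samples`.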

### Notes
- **Read-only**: does not modify any DataFrame.  
- Set `DIAG_COMPARE_RAW_VIEW=True` to also show diagnostic stats on the raw
  `training_df_rex` for side-by-side comparison.



In [None]:
# ===== Cell 74 =====
# Diagnostics: training coverage (“why 0 entries?”)

import pandas as pd

#@markdown ### Controls
DIAG_COMPARE_RAW_VIEW = False  #@param {type:"boolean"}
DIAG_SHOW_TOP = 10             #@param {type:"integer"}

def _summary_training(df: pd.DataFrame) -> pd.DataFrame:
    """
    Summarise coverage needed for regex building.

    Returns a DataFrame with counts per (paralog_group, exon_num_in_chain).
    """
    req = {"paralog_group", "exon_num_in_chain", "exon_peptide"}
    cols_ok = req.issubset(df.columns)
    print(f"[Diag] training_df shape: {df.shape}; required columns present: {cols_ok}")
    if not cols_ok:
        missing = sorted(req - set(df.columns))
        print(f"[Diag] Missing columns: {missing}")
        return pd.DataFrame()

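    tmp = df.copy()
    # Treat NaN as empty (astype(str) alone would turn NaN into the non-empty string "nan")
    tmp["exon_peptide"] = tmp["exon_peptide"].fillna("").astype(str).str.strip()
    tmp["is_empty"] = (
        tmp["exon_peptide"].eq("")
        | tmp["exon_peptide"].isin(["MISSING_EXON", "NO_BAIT_DEFINED"])
    )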

    # Top-level counts
    print(f"[Diag] Non-empty peptides: {(~tmp['is_empty']).sum()} / {len(tmp)}")
    unk = (tmp["paralog_group"] == "unknown_paralog").sum()
    print(f"[Diag] Rows with 'unknown_paralog': {unk}")

    # Coverage by (paralog, exon)
    grp = (
        tmp.loc[~tmp["is_empty"] & (tmp["paralog_group"] != "unknown_paralog")]
          .groupby(["paralog_group", "exon_num_in_chain"])
          .size()
          .rename("n")
          .reset_index()
          .sort_values(["paralog_group", "exon_num_in_chain"], kind="stable")
    )
    print(f"[Diag] Groups with at least 1 sample: {len(grp)}")
    with pd.option_context("display.max_rows", 20):
        display(grp.head(20))

    # Threshold gate
    thr = int(globals().get("rex_min_clade_samples", 5))
    passing = grp[grp["n"] >= thr]
    print(f"[Diag] Groups meeting rex_min_clade_samples={thr}: {len(passing)}")
    if len(passing) < 10:
        with pd.option_context("display.max_rows", 20):
            display(passing)

    # Optional clade diagnostics
    clade_cols = [c for c in ["order", "family", "genus"] if c in tmp.columns]
    if clade_cols:
        print(f"[Diag] Available clade columns: {clade_cols}")
        for c in clade_cols:
            cov = tmp[c].notna().mean()
            uniq = tmp[c].nunique(dropna=True)
            print(f"[Diag] {c}: non-null coverage={cov:.2%}, unique levels={uniq}")

    return grp

# ---- Choose the same view the builder used ----
df_effective = SOURCE_DF if 'SOURCE_DF' in globals() else (
    training_df_rex if 'training_df_rex' in globals() else None
)

if isinstance(df_effective, pd.DataFrame) and not df_effective.empty:
    if 'LABEL_COL' in globals():
        print(f"[Diag] Using LABEL_COL from builder: {LABEL_COL}")
    # Show top labels seen by the builder (effective view); the raw fallback may lack the column
    if "paralog_group" in df_effective.columns:
        with pd.option_context("display.max_rows", DIAG_SHOW_TOP):
            top_labels = df_effective["paralog_group"].value_counts(dropna=True).head(DIAG_SHOW_TOP)
            print("[Diag] Top labels in effective view:")
            display(top_labels)

    print("\n— Effective view (used by builder) —")
    diag_grp_effective = _summary_training(df_effective)
else:
    print("[Diag] training data is missing or empty (no SOURCE_DF / training_df_rex).")
    diag_grp_effective = pd.DataFrame()

# ---- Optional: compare with the raw training_df_rex view ----
if DIAG_COMPARE_RAW_VIEW and 'training_df_rex' in globals() \
   and isinstance(training_df_rex, pd.DataFrame) and not training_df_rex.empty:
    print("\n— Raw training_df_rex (for comparison only) —")
    _summary_training(training_df_rex)


## Cell 75 – RegExTractor Matcher, Hits, and Chain Reconstruction

This cell defines the **core matching and chain reconstruction primitives** that the classification engine relies on:

1. **`RegExTractorMatcher`**  
   - Provides a consistent interface for scoring regex matches against peptide sequences.  
   - Implements `_score_hit(matched_str, tier) → (score, gden)`. `gden`, a proxy for "g-head density", is the constrained fraction of positions in the tier's regex pattern, and `score = sqrt(len(matched_str)) * (0.5 + 0.5 * gden)`.  
   - Because `gden` is derived from the pattern rather than the matched text, scoring is lightweight and reproducible.

2. **`RexHit` dataclass**  
   - Encapsulates information about a single regex match (accession, paralog group, exon number, clade, tier, start/end coordinates, peptide, conservation proxy, and score).  
   - Designed to be serialisable (`.__dict__`) for logging or DataFrame conversion.

3. **`RexChain` dataclass**  
   - Represents a reconstructed exon chain for a given paralog group within a sequence.  
   - Records the exons found, chain start/end, total score, and the length of the longest consecutive exon run.  
   - Holds the mapping of exon numbers to their chosen `RexHit`.

4. **Chain reconstruction utilities**  
   - `rex_longest_consecutive_run(exons)` computes the longest run of consecutive exons detected.  
   - `rex_walk_chain(seed, hits_by_exon, expected_exons)` assembles the best possible chain around a seed hit, extending left and right along the canonical exon order, subject to non-overlap and monotonic coordinate rules.
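As a worked example of the scoring in item 1, the sketch below re-implements the `gden` heuristic for patterns made only of literals, `.` and simple character classes (escapes are not handled), then applies `score = sqrt(len(matched)) * (0.5 + 0.5 * gden)`. `pattern_conservation` and `score_hit` are hypothetical stand-ins for the methods on `RegExTractorMatcher`. For `GP[AS]G.P` the six tokens contribute 1, 1, 1 - 1/19, 1, 0 and 1, so `gden = (4 + 18/19) / 6`, roughly 0.825.

```python
import math

def pattern_conservation(pattern: str) -> float:
    """Hypothetical simplified gden: literal 1, '.' 0, a k-residue class 1 - (k - 1) / 19."""
    tokens, constrained, i = 0, 0.0, 0
    while i < len(pattern):
        ch = pattern[i]
        if ch == "[":
            j = pattern.index("]", i)          # assumes every class is closed
            k = len(set(pattern[i + 1:j]))
            tokens += 1
            constrained += max(0.0, 1.0 - (k - 1) / 19.0) if k else 0.0
            i = j + 1
        elif ch == ".":
            tokens += 1                        # wildcard: counted, no constraint
            i += 1
        elif ch.isalpha():
            tokens += 1                        # literal residue: fully constrained
            constrained += 1.0
            i += 1
        else:
            i += 1                             # other regex syntax is skipped
    return constrained / tokens if tokens else 0.0

def score_hit(matched: str, pattern: str):
    """Return (score, gden); gden depends only on the pattern, not on the match."""
    gden = pattern_conservation(pattern)
    return math.sqrt(len(matched)) * (0.5 + 0.5 * gden), gden

score, gden = score_hit("GPAGEP", "GP[AS]G.P")
print(round(gden, 3), round(score, 3))
```

Because `gden` comes from the tier's pattern, every hit from the same tier shares one `gden`; within a tier, only the length term separates competing hits.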

---

**Dependencies:**  
- Must be executed after **Cell 73** (which builds the orthology-aware regex library).  
- Relies on configuration constants defined in **Cell 70** (e.g. `rex_chain_min_consecutive`).  
- Produces the functions and dataclasses that are invoked in the **Classification Engine** (Cell 76).

This separation ensures that **library building (Cell 73)** and **matching/chain logic (Cell 75)** remain modular and can be tested or modified independently.
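The chain utilities in item 4 can be illustrated with simplified types. In this hypothetical sketch a hit is a `(start, end, score)` tuple rather than a `RexHit`, and the walk applies the `enforce_order=True` rules: exons left of the seed must end at or before the current span start, and exons right of it must start at or after the current span end. The visible part of `rex_walk_chain` does not show whether the walk stops at the first missing or rejected exon; this sketch assumes it skips that exon and keeps walking.

```python
from typing import Dict, List, Tuple

Hit = Tuple[int, int, float]  # hypothetical simplified hit: (start, end, score)

def longest_consecutive_run(exons: List[int]) -> int:
    """Length of the longest run of consecutive exon numbers."""
    s = set(exons)
    best = 0
    for x in s:
        if x - 1 not in s:                     # x starts a run
            cur = 1
            while x + cur in s:
                cur += 1
            best = max(best, cur)
    return best

def walk_chain(seed_exon: int, hits: Dict[int, Hit], expected: List[int]) -> Dict[int, Hit]:
    """Greedy walk outward from the seed along the expected exon order."""
    chosen = {seed_exon: hits[seed_exon]}
    span_start, span_end = hits[seed_exon][0], hits[seed_exon][1]
    idx = expected.index(seed_exon)
    for exon in reversed(expected[:idx]):      # walk left
        h = hits.get(exon)
        if h is not None and h[1] <= span_start:
            chosen[exon] = h
            span_start = min(span_start, h[0])
    for exon in expected[idx + 1:]:            # walk right
        h = hits.get(exon)
        if h is not None and h[0] >= span_end:
            chosen[exon] = h
            span_end = max(span_end, h[1])
    return chosen

hits = {1: (25, 43, 2.0), 2: (20, 38, 3.0), 3: (40, 58, 4.0), 5: (80, 98, 2.5)}
chain = walk_chain(3, hits, [1, 2, 3, 4, 5])
exons = sorted(chain)
print(exons, longest_consecutive_run(exons), sum(h[2] for h in chain.values()))
```

Exon 1 is rejected because it ends after the span start set by exon 2, and exon 4 has no hit, so the chain is `[2, 3, 5]` with a longest consecutive run of 2. That run length is presumably what a setting such as `rex_chain_min_consecutive` is compared against downstream.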

In [None]:
# ===== Cell 75 =====
# RegExTractor Matcher, Hit/Chain dataclasses, and Chain Reconstruction

from __future__ import annotations

import re
from dataclasses import dataclass, asdict
from typing import Dict, List, Tuple, Optional


# ----------------------------- Dataclasses ----------------------------- #

@dataclass
class RexHit:
    """
    A single regex match against a target sequence, annotated with metadata.

    Attributes
    ----------
    accession : str
        Accession of the scanned sequence.
    gene_symbol : str
        Used in your pipeline to hold the paralog_group label.
    exon_num_in_chain : int
        Exon ordinal within the paralog architecture.
    clade : str
        Clade key used for the regex (e.g., 'pan').
    tier : str
        Tier label for the regex (e.g., 'consensus', 'stringent').
    start : int
        Match start (0-based, inclusive) on the sequence.
    end : int
        Match end (0-based, exclusive) on the sequence.
    peptide : str
        The matched substring.
    gden : float
        "G-head density" (or generalised pattern conservation) used as a gate.
    score : float
        Overall score used to choose among competing hits.
    """
    accession: str
    gene_symbol: str
    exon_num_in_chain: int
    clade: str
    tier: str
    start: int
    end: int
    peptide: str
    gden: float
    score: float


@dataclass
class RexChain:
    """
    A reconstructed exon chain assembled from per-exon hits.

    Attributes
    ----------
    accession : str
        Accession of the scanned sequence.
    paralog_group : str
        Paralog group this chain is assigned to.
    exons_found : List[int]
        Sorted list of exon ordinals present in the chain.
    start : int
        Start coordinate of the chain span (min of hits).
    end : int
        End coordinate of the chain span (max of hits).
    total_score : float
        Sum (or other aggregator) of per-exon hit scores.
    consecutive_blocks : int
        Length (count) of the longest run of consecutive exons in the chain.
    hits : Dict[int, RexHit]
        Mapping exon_num_in_chain → chosen hit.
    """
    accession: str
    paralog_group: str
    exons_found: List[int]
    start: int
    end: int
    total_score: float
    consecutive_blocks: int
    hits: Dict[int, RexHit]


# ----------------------------- Matcher ----------------------------- #

class RegExTractorMatcher:
    """
    Lightweight matcher that scores regex hits consistently with your tier design.

    Notes
    -----
    - The matcher exposes `_score_hit(matched_str, tier)` returning `(score, gden)`,
      as expected by the classification orchestrator in Cell 76.
    - `gden` is computed from the regex pattern as an approximation of positional
      constraint (literals count 1, character classes partially, '.' counts 0).
      The final `score` combines length and conservation as
      sqrt(len(matched_str)) * (0.5 + 0.5 * gden).
    """

    def __init__(self, library) -> None:
        """
        Parameters
        ----------
        library : OrthologyAwareLibrary
            Not strictly required for scoring (tiers are provided by caller),
            but kept for future extensions (e.g., clade-aware rescoring).
        """
        self.library = library

    @staticmethod
    def _pattern_conservation(pattern: str) -> float:
        """
        Approximate per-position constraint of a regex pattern in [0, 1].

        Heuristic:
        - Count positions as tokens:
            '.'          → unconstrained (0)
            '[...]'      → partially constrained: 1 - (k - 1) / 19 for a class of
                           k residues (1 for a single residue, 0 for all 20)
            literal char → constrained (1)
        - Escapes such as '\\.' are treated as literals.
        - Groups, quantifiers and other advanced constructs are rare in the
          generated tiers; if present, their characters are skipped (they add
          neither a token nor any constraint).

        This does not parse the full regex grammar; it is a pragmatic estimator
        for motif density that is robust for patterns generated by Cell 73.
        """
        i, n = 0, len(pattern)
        tokens = 0
        constrained = 0.0

        while i < n:
            ch = pattern[i]

            # Character class
            if ch == '[':
                j = i + 1
                # Find closing bracket (naively, sufficient for our generated patterns)
                while j < n and pattern[j] != ']':
                    j += 1
                content = pattern[i + 1 : j] if j < n else ""
                # Treat the character set size as a soft constraint (larger set → weaker)
                k = len(set(content)) if content else 0
                tokens += 1
                constrained += min(1.0, max(0.0, 1.0 - (k - 1) / 19.0)) if k > 0 else 0.0
                i = j + 1
                continue

            # Escaped literal (e.g., '\.')
            if ch == '\\' and i + 1 < n:
                tokens += 1
                constrained += 1.0
                i += 2
                continue

            # Unconstrained position
            if ch == '.':
                tokens += 1
                i += 1
                continue

            # Literal single character
            if ch.isalpha():
                tokens += 1
                constrained += 1.0
                i += 1
                continue

            # Any other construct: treat as 0 contribution, do not advance tokens
            i += 1

        return float(constrained) / float(tokens) if tokens > 0 else 0.0

    def _score_hit(self, matched: str, tier) -> Tuple[float, float]:
        """
        Score a single regex match.

        Parameters
        ----------
        matched : str
            The matched peptide substring.
        tier : object
            Must expose: `.regex` (compiled pattern) and `.ghead_min` (float).

        Returns
        -------
        (score, gden) : Tuple[float, float]
            gden is the conservation proxy in [0, 1].
            score is a monotonic function of length and gden.
        """
        # Conservation proxy from the pattern itself
        patt = getattr(tier, "regex").pattern
        gden = self._pattern_conservation(patt)

        # Length component
        L = len(matched)
        # Combine length and conservation; sqrt dampens extreme lengths
        score = (L ** 0.5) * (0.5 + 0.5 * gden)

        return float(score), float(gden)


# ----------------------- Chain reconstruction utils ----------------------- #

def rex_longest_consecutive_run(exons: List[int]) -> int:
    """
    Compute the length of the longest run of consecutive integers in `exons`.
    """
    if not exons:
        return 0
    s = set(int(x) for x in exons)
    best = 1
    for x in s:
        if (x - 1) not in s:
            cur = 1
            while (x + cur) in s:
                cur += 1
            if cur > best:
                best = cur
    return best


def rex_walk_chain(
    seed: RexHit,
    hits_by_exon: Dict[int, RexHit],
    expected_exons: List[int],
    *,
    enforce_order: bool = True,
) -> RexChain:
    """
    Greedy chain reconstruction around a seed hit.

    The algorithm:
      1) Start from the `seed` exon.
      2) Walk left (lower exon numbers), then right (higher exon numbers),
         in `expected_exons` order, nearest exon first on each side.
      3) For each exon, take its pre-selected hit from `hits_by_exon`. When
         `enforce_order` is True, accept it only if it lies entirely before the
         current chain span (left side) or entirely after it (right side).
         When False, accept it unconditionally (overlaps are not checked).

    Parameters
    ----------
    seed : RexHit
        The starting hit (usually the highest-scoring hit overall).
    hits_by_exon : Dict[int, RexHit]
        Best per-exon hits (already selected upstream), indexed by exon number.
    expected_exons : List[int]
        Paralog's canonical exon architecture (sorted). If empty, the keys of
        `hits_by_exon` are used instead.
    enforce_order : bool, default True
        Require non-overlapping, monotonic coordinates relative to the seed.

    Returns
    -------
    RexChain
        The assembled chain.
    """
    # Initialise chain state from the seed
    chosen: Dict[int, RexHit] = {int(seed.exon_num_in_chain): seed}
    span_start = seed.start
    span_end = seed.end

    # Determine walking order
    if not expected_exons:
        expected_exons = sorted(hits_by_exon.keys())

    # Helper to try add an exon hit if constraints permit
    def _try_add(exon: int) -> bool:
        nonlocal span_start, span_end
        h = hits_by_exon.get(exon)
        if h is None:
            return False

        if enforce_order:
            # Maintain non-overlap and monotonic coordinates relative to seed side
            if exon < seed.exon_num_in_chain and not (h.end <= span_start):
                return False
            if exon > seed.exon_num_in_chain and not (h.start >= span_end):
                return False

        # Accept
        chosen[exon] = h
        span_start = min(span_start, h.start)
        span_end = max(span_end, h.end)
        return True

    # Walk left (decreasing exon numbers)
    seed_idx = expected_exons.index(seed.exon_num_in_chain) if seed.exon_num_in_chain in expected_exons else None
    if seed_idx is not None:
        for exon in reversed(expected_exons[:seed_idx]):
            _try_add(exon)

        # Walk right (increasing exon numbers)
        for exon in expected_exons[seed_idx + 1 :]:
            _try_add(exon)
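    else:
        # Fallback (seed exon not in expected_exons): walk outward from the seed,
        # nearest exon first on each side, as in the primary path above.
        for exon in sorted((x for x in expected_exons if x < seed.exon_num_in_chain), reverse=True):
            _try_add(exon)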
        for exon in sorted(x for x in expected_exons if x > seed.exon_num_in_chain):
            _try_add(exon)

    exons_found = sorted(chosen.keys())
    consecutive = rex_longest_consecutive_run(exons_found)
    total_score = float(sum(chosen[e].score for e in exons_found))

    chain = RexChain(
        accession=seed.accession,
        paralog_group=seed.gene_symbol,  # hits carry the paralog_group in their gene_symbol field
        exons_found=exons_found,
        start=span_start,
        end=span_end,
        total_score=total_score,
        consecutive_blocks=consecutive,
        hits=chosen,
    )
    return chain


# --------------------------- Runtime safeguards --------------------------- #

# Provide a default for the orchestrator threshold if it was not defined earlier.
if "rex_chain_min_consecutive" not in globals():
    rex_chain_min_consecutive: int = 3
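The greedy walk and the longest-run metric can be tried on toy data. The sketch below is a simplified re-implementation, not the notebook's code: `ToyHit` is a hypothetical stand-in for `RexHit` that carries only the fields the walk reads, and `toy_walk` mirrors the `enforce_order=True` path of `rex_walk_chain`.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ToyHit:
    # Hypothetical stand-in for RexHit: only the fields the walk needs
    exon: int
    start: int  # 0-based start in the protein
    end: int    # end-exclusive
    score: float


def longest_run(nums: List[int]) -> int:
    """Longest run of consecutive integers (same idea as rex_longest_consecutive_run)."""
    s = set(nums)
    best = 0
    for x in s:
        if x - 1 not in s:          # x starts a run
            cur = 1
            while x + cur in s:
                cur += 1
            best = max(best, cur)
    return best


def toy_walk(seed: ToyHit, hits: Dict[int, ToyHit], order: List[int]) -> Dict[int, ToyHit]:
    """Greedy walk outward from the seed, requiring monotonic, non-overlapping coordinates."""
    chosen = {seed.exon: seed}
    lo, hi = seed.start, seed.end
    i = order.index(seed.exon)
    for e in reversed(order[:i]):   # left side, nearest exon first
        h = hits.get(e)
        if h is not None and h.end <= lo:
            chosen[e] = h
            lo = h.start            # accepted hits end before lo, so this is the new minimum
    for e in order[i + 1:]:         # right side, nearest exon first
        h = hits.get(e)
        if h is not None and h.start >= hi:
            chosen[e] = h
            hi = h.end
    return chosen


hits = {
    8: ToyHit(8, 190, 245, 3.0),
    9: ToyHit(9, 246, 300, 4.0),
    11: ToyHit(11, 354, 408, 4.5),
    12: ToyHit(12, 100, 154, 2.0),  # out of order: lies before the seed
    13: ToyHit(13, 410, 464, 3.5),
}
seed = ToyHit(10, 300, 354, 6.0)
chain = toy_walk(seed, hits, order=[8, 9, 10, 11, 12, 13])
print(sorted(chain), longest_run(sorted(chain)))  # [8, 9, 10, 11, 13] 4
```

Exon 12 is rejected because its coordinates fall before the seed, so the chain has a gap and its longest consecutive run is 4 (exons 8 to 11) rather than 5.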


## Cell 76a – Classification Cache (signature + reuse)

This cell avoids re-running the classifier if **inputs haven’t changed**.

**What it does**
- Computes stable, row-order-insensitive SHA-256 signatures of:
  - the effective **training view** (`SOURCE_DF` or `training_df_rex`), limited to
    the columns the classifier actually uses (`paralog_group`, `exon_num_in_chain`,
    `gene_symbol`, `gene_symbol_norm`).
  - the **target pool** (`target_pool_rex`), limited to `accession` and `sequence`.
- Records a lightweight shape signature of the regex **library** (entry and tier
  counts), if one is provided.
- If all signatures match the saved meta file and the cached TSVs still exist with
  their recorded byte sizes, loads the **cached TSVs** instead of re-running.
- If signatures differ or files are missing, runs classification and refreshes the cache.

**Outputs**
- Reuses your existing files:
  - `CLASSIFIED_HITS_TSV`
  - `CLASSIFIED_CHAINS_TSV`
- Saves meta at `RUN_DIR/classification_cache.meta.json`.

**Flags**
- `ENABLE_CLASSIFY_CACHE` (default `True`)
- `FORCE_RECLASSIFY` (default `False`)
    

In [None]:
# ===== Cell 76a =====
# Classification cache wrapper: signature + reuse

import hashlib, json
from pathlib import Path

import pandas as pd

ENABLE_CLASSIFY_CACHE = bool(globals().get("ENABLE_CLASSIFY_CACHE", True))
FORCE_RECLASSIFY = bool(globals().get("FORCE_RECLASSIFY", False))

RUN_DIR = globals().get("RUN_DIR", Path("."))
RUN_DIR = Path(RUN_DIR)
CACHE_META = RUN_DIR / "classification_cache.meta.json"

def _stable_df_signature(df: pd.DataFrame, cols: list) -> str:
    """
    Return a SHA-256 over the selected columns that is insensitive to row order.
    Rows are sorted by all selected columns and serialised to CSV before hashing.
    Returns "EMPTY" for an empty/invalid frame and "NOCOLS" if none of `cols` exist.
    """
    if not isinstance(df, pd.DataFrame) or df.empty:
        return "EMPTY"
    use = [c for c in cols if c in df.columns]
    if not use:
        return "NOCOLS"
    # Note: the whole selection is serialised in memory at once (it is not chunked).
    payload = (
        df[use]
          .astype("string")
          .fillna("")
          .sort_values(by=use)
          .to_csv(index=False, header=True)
          .encode("utf-8")
    )
    return hashlib.sha256(payload).hexdigest()

def _library_signature(lib) -> str:
    """Small signature for library shape (counts of entries/tiers)."""
    try:
        n_entries = len(getattr(lib, "entries", {}))
        # Count total tiers for extra sensitivity
        n_tiers = 0
        for ent in getattr(lib, "entries", {}).values():
            n_tiers += len(getattr(ent, "tiers", []))
        return f"E{n_entries}_T{n_tiers}"
    except Exception:
        return "NO_LIB"

def _save_meta(meta: dict):
    try:
        CACHE_META.write_text(json.dumps(meta, indent=2))
    except Exception as e:
        # Fall back to print() if rex_log has not been defined by an earlier cell
        globals().get("rex_log", print)(f"[ClassifyCache] Could not write meta: {e}")

def _load_meta() -> dict:
    try:
        if CACHE_META.exists():
            return json.loads(CACHE_META.read_text())
    except Exception:
        pass
    return {}

def _can_reuse_cache(meta_now: dict, meta_prev: dict) -> bool:
    keys = ["sig_training", "sig_target", "sig_library",
            "hits_tsv", "chains_tsv"]
    if not all(k in meta_prev for k in keys):
        return False
    # same signatures?
    if any(meta_prev.get(k) != meta_now.get(k) for k in ["sig_training","sig_target","sig_library"]):
        return False
    # files exist and sizes match prior record
    ht, ct = Path(meta_prev["hits_tsv"]), Path(meta_prev["chains_tsv"])
    if not ht.exists() or not ct.exists():
        return False
    if meta_prev.get("hits_bytes", -1) != ht.stat().st_size:
        return False
    if meta_prev.get("chains_bytes", -1) != ct.stat().st_size:
        return False
    return True
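Sorting before hashing is what keeps the cache valid when upstream cells return the same rows in a different order. A minimal sketch of that property, using a simplified copy of `_stable_df_signature` (`stable_signature` is an illustrative name) and placeholder accessions `P1`/`P2`:

```python
import hashlib

import pandas as pd


def stable_signature(df: pd.DataFrame, cols: list) -> str:
    """SHA-256 over selected columns, independent of row order (sketch of _stable_df_signature)."""
    use = [c for c in cols if c in df.columns]
    payload = (
        df[use]
        .astype("string")
        .fillna("")
        .sort_values(by=use)   # canonical row order
        .to_csv(index=False)   # index excluded, so original row labels do not matter
        .encode("utf-8")
    )
    return hashlib.sha256(payload).hexdigest()


cols = ["accession", "sequence"]
a = pd.DataFrame({"accession": ["P1", "P2"], "sequence": ["GPPGPP", "GAPGAP"]})
b = a.iloc[::-1]                               # same rows, reversed
c = a.assign(sequence=["GPPGPP", "GAPGAA"])    # one residue changed

print(stable_signature(a, cols) == stable_signature(b, cols))  # True
print(stable_signature(a, cols) == stable_signature(c, cols))  # False
```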


## Cell 76b – Execute Classification (with cache)

Uses the cache if valid; otherwise runs your existing **Cell 76** logic and writes meta.

*No renames; your `run_classification_engine` stays intact.*
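Cell 76b's code is not reproduced in this section. As a hedged sketch of the reuse-or-run flow described above, `classify_with_cache` is a hypothetical helper (not part of the notebook) that reuses outputs only when every signature matches and each output file still has its recorded byte size; otherwise it calls `run_fn` and rewrites the meta file. Its meta keys (`<name>`, `<name>_bytes`) are simplified relative to Cell 76a.

```python
import json
import tempfile
from pathlib import Path
from typing import Callable, Dict


def classify_with_cache(meta_path: Path, signatures: Dict[str, str],
                        outputs: Dict[str, Path], run_fn: Callable[[], None],
                        force: bool = False) -> bool:
    """Return True if cached outputs were reused, False if run_fn() was executed."""
    prev = json.loads(meta_path.read_text()) if meta_path.exists() else {}
    same_sigs = all(prev.get(k) == v for k, v in signatures.items())
    files_ok = all(
        p.exists() and prev.get(f"{name}_bytes") == p.stat().st_size
        for name, p in outputs.items()
    )
    if same_sigs and files_ok and not force:
        return True
    run_fn()  # expected to (re)write every path in `outputs`
    meta = dict(signatures)
    for name, p in outputs.items():
        meta[name] = str(p)
        meta[f"{name}_bytes"] = p.stat().st_size
    meta_path.write_text(json.dumps(meta, indent=2))
    return False


# Toy demo in a temporary directory
demo_dir = Path(tempfile.mkdtemp())
hits_tsv = demo_dir / "hits.tsv"
calls = []

def fake_run():
    calls.append(1)
    hits_tsv.write_text("accession\texon\nP1\t10\n")

meta_file = demo_dir / "meta.json"
first = classify_with_cache(meta_file, {"sig_training": "abc"}, {"hits": hits_tsv}, fake_run)
second = classify_with_cache(meta_file, {"sig_training": "abc"}, {"hits": hits_tsv}, fake_run)
third = classify_with_cache(meta_file, {"sig_training": "xyz"}, {"hits": hits_tsv}, fake_run)
print(first, second, third, len(calls))  # False True False 2
```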


## Cell 79 – Ensembl ID Coverage Audit (folded from Part 8)

This cell **audits** `ensembl_id` coverage produced earlier in Part 7 and (optionally)
writes a gap report. It does **not** perform any network/API lookups or refilling.

**Inputs**
- `df_high_quality` — the consolidated, high-quality table produced by Part 7.
- `RUN_DIR` (optional) — output directory for artifacts.

**Outputs**
- Console summary of Ensembl ID coverage.
- Optional: `RUN_DIR/ensembl_id_gaps.tsv` when coverage < threshold.

**Notes**
- Idempotent; does not rename or remove any columns.
- Ensures required columns exist with stable dtypes.


In [None]:
# ===== Cell 79 =====
# Ensembl ID Coverage Audit (no lookups; folded from former Part 8)

import pandas as pd
from pathlib import Path

# ---- Parameters (safe defaults; editable) ----
MIN_COVERAGE = 0.90  #@param {type:"number", min:0.0, max:1.0}
WRITE_GAP_REPORT = bool(globals().get("WRITE_GAP_REPORT", True))  #@param {type:"boolean"}

# ---- Resolve output dir ----
RUN_DIR = globals().get("RUN_DIR", Path("."))  # defined upstream in your pipeline
RUN_DIR = Path(RUN_DIR)

def _ensure_id_cols(df: pd.DataFrame) -> pd.DataFrame:
    """
    Ensure ID columns exist with robust nullable dtypes (idempotent).
    Does not overwrite existing values.
    """
    if not isinstance(df, pd.DataFrame) or df.empty:
        return df
    for col, dtype in (("ensembl_id", "string"),
                       ("ensembl_id_source", "string"),
                       ("ensembl_id_conf", "Float64")):
        if col not in df.columns:
            df[col] = pd.Series(pd.NA, index=df.index, dtype=dtype)
        else:
            try:
                if dtype == "Float64":
                    df[col] = pd.to_numeric(df[col], errors="coerce").astype("Float64")
                else:
                    df[col] = df[col].astype("string")
            except (TypeError, ValueError) as e:
                # Leave the column unchanged rather than retrying the identical coercion
                print(f"[ID-Audit] Could not coerce '{col}' to {dtype}: {e}")
    return df

def _coverage_stats(df: pd.DataFrame) -> tuple[int, int, float]:
    """Return (#have, total, fraction) for ensembl_id coverage."""
    if not isinstance(df, pd.DataFrame) or df.empty:
        return (0, 0, 0.0)
    if "ensembl_id" not in df.columns:
        return (0, len(df), 0.0)
    s = df["ensembl_id"].astype("string")
    have = int(s.fillna("").str.len().gt(0).sum())
    total = len(df)
    frac = (have / total) if total else 0.0
    return (have, total, frac)

# ---- Audit run ----
if "df_high_quality" in globals() and isinstance(df_high_quality, pd.DataFrame) and not df_high_quality.empty:
    df_high_quality = _ensure_id_cols(df_high_quality)

    have, total, frac = _coverage_stats(df_high_quality)
    print(f"[ID-Audit] df_high_quality Ensembl coverage: {have}/{total} = {frac:.2%}")

    if WRITE_GAP_REPORT and frac < MIN_COVERAGE:
        gap_mask = df_high_quality["ensembl_id"].astype("string").fillna("").str.len().eq(0)
        gaps = df_high_quality.loc[gap_mask].copy()
        cols = [c for c in [
            "species","Organism","gene_symbol_norm","gene_symbol",
            "Entry","accession","Protein names","paralog_group"
        ] if c in gaps.columns]
        outp = RUN_DIR / "ensembl_id_gaps.tsv"
        (gaps[cols] if cols else gaps).to_csv(outp, sep="\t", index=False)
        print(f"[ID-Audit] Coverage below threshold ({MIN_COVERAGE:.0%}). Gap list → {outp}")
else:
    print("[ID-Audit] df_high_quality not found or empty.")
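The counting rule in `_coverage_stats` treats a row as covered when its `ensembl_id` string is non-empty after filling missing values with `""`. A small sketch of that rule on placeholder IDs (`id_coverage` is an illustrative copy):

```python
import pandas as pd


def id_coverage(ids: pd.Series) -> tuple:
    """(#have, total, fraction): same counting rule as _coverage_stats above."""
    s = ids.astype("string").fillna("")
    have = int(s.str.len().gt(0).sum())
    total = len(s)
    return have, total, (have / total if total else 0.0)


# Placeholder IDs: None and "" count as missing, but a whitespace-only string counts as present
ids = pd.Series(["ENSG_PLACEHOLDER_1", None, "", " "])
print(id_coverage(ids))  # (2, 4, 0.5)
```

If whitespace-only values can occur upstream, adding `.str.strip()` before `.str.len()` would count them as gaps; that is a possible refinement, not what Cell 79 currently does.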


In [None]:
df_high_quality.head(20)

# **Part 9: X/Y Substitution, Recovery & Outliers**

## Cell 90 – Bridge Part 7 → RegExTractor Inputs
This cell adapts `df_high_quality` from Part 7 into the structures required by the
RegExTractor (Cells 93–94). It:
- Verifies required columns from Part 7 are present.
- Resolves a **single** `paralog_group` (e.g., COL1A1, COL1A2, COL2A1, …).  
  Ambiguities like `"COL1A1 COL1A2"` are set to `"unknown_paralog"` unless already
  disambiguated upstream.
- Extracts the first triple-helix span from `gxy_spans` into `gxystart`/`gxyend`.
- Emits `training_df_rex` (kept for backward compatibility) and `rex_sequences_df`
  used by Cells 93–94.

> Assumptions:
> * Part 7 produced `df_high_quality` with UniProt/Ensembl-mapped sequences, species,
>   and a list column of G-X-Y spans (assumed here to be named `gxy_spans`).
> * We **do not** rename existing functions or globals used later; we only add safe
>   adapters and checks.

**Outputs used later**
- `rex_sequences_df` : tidy, per-sequence inputs for the recovery walker.
- `training_df_rex`  : same rows + a limited superset of columns for training paths.
- `PARALOG_ALLOWED_SET` : allowed gene symbols for strict mapping.
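The single-paralog rule can be illustrated in isolation. This is a minimal sketch with a hypothetical `resolve_paralog` and a small illustrative subset of `PARALOG_ALLOWED_SET`; it splits on spaces, commas, semicolons and slashes, since UniProt gene-name fields may list several genes separated by semicolons.

```python
import re
from typing import Optional, Set

# Illustrative subset only; the notebook uses PARALOG_ALLOWED_SET
ALLOWED: Set[str] = {"COL1A1", "COL1A2", "COL2A1"}


def resolve_paralog(val: Optional[str], allowed: Set[str] = ALLOWED) -> str:
    """Return the single allowed symbol found in `val`, else 'unknown_paralog'."""
    if not isinstance(val, str) or not val.strip():
        return "unknown_paralog"
    tokens = {t for t in re.split(r"[\s,;/]+", val.strip()) if t}
    found = tokens & allowed
    return next(iter(found)) if len(found) == 1 else "unknown_paralog"


for raw in ["COL1A1", "COL1A1 COL1A2", "COL1A1; COL1A2", "COL2A1, COL2A1", "COL9A1", None]:
    print(repr(raw), "->", resolve_paralog(raw))
```

Repeated mentions of the same symbol collapse to one, so only genuinely mixed labels such as `COL1A1; COL1A2` become `unknown_paralog`.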


In [None]:
# ===== Cell 90 =====
# Bridge Part 7 → RegExTractor inputs (non-breaking adapter)

import logging
from typing import List, Optional

logger = logging.getLogger(__name__)

# --- Configuration (do not rename downstream globals) ---
# Allowed collagen genes (fibrillar and selected non-fibrillar) for strict paralog resolution; extend as needed
PARALOG_ALLOWED_SET = {
    "COL1A1","COL1A2","COL2A1","COL3A1","COL5A1","COL5A2","COL5A3",
    "COL10A1","COL11A1","COL11A2","COL12A1","COL16A1","COL21A1","COL22A1",
    "COL8A1","COL6A1","COL6A3","COL4A3","COL4A4","COL4A5"  # keep list in sync with your pipeline
}

# Column name hints seen in your Part 7 preview
CANDIDATE_GENE_COLS = ["gene_name", "Gene Names (primary)", "primary_gene", "uniprot_gene"]
CANDIDATE_SPECIES_COLS = ["Organism", "species", "taxon_name"]
CANDIDATE_ACCESSION_COLS = ["Entry", "accession", "uniprot_id"]
CANDIDATE_SEQ_COLS = ["Sequence", "sequence", "aa_sequence"]
GXY_SPANS_COL = "gxy_spans"  # Part 7 preview showed a list like [{'start': 201, 'end': 1217}]

def _first_present(df, candidates: List[str]) -> Optional[str]:
    for c in candidates:
        if c in df.columns:
            return c
    return None

def _resolve_paralog(val: Optional[str]) -> str:
    """
    Map any gene string to a single canonical paralog_group.
    - Symbols may be separated by spaces, commas, semicolons or slashes
      (e.g. "COL1A1 COL1A2", "COL1A1; COL1A2"); if more than one distinct
      allowed symbol is present, mark as unknown.
    - If no allowed symbol is present (or the value is NaN), mark as unknown.
    """
    if not isinstance(val, str) or not val.strip():
        return "unknown_paralog"
    for sep in (",", ";", "/"):
        val = val.replace(sep, " ")
    # keep only exact allowed tokens; de-duplicate repeats such as "COL1A1 COL1A1"
    exact = {p for p in val.split() if p in PARALOG_ALLOWED_SET}
    if len(exact) == 1:
        return exact.pop()
    # ambiguous or none
    return "unknown_paralog"

def _extract_gxy_bounds(spans):
    """
    Given a list like [{'start': 179, 'end': 1192}, ...], return (start, end).
    If missing/empty, return (None, None).
    """
    try:
        if isinstance(spans, list) and spans:
            s = spans[0]
            return int(s.get("start")), int(s.get("end"))
    except Exception:
        pass
    return None, None

# --- Preconditions ---
if 'df_high_quality' not in globals():
    raise RuntimeError("FATAL: Part 7 output 'df_high_quality' not found.")

dfhq = df_high_quality.copy()

acc_col = _first_present(dfhq, CANDIDATE_ACCESSION_COLS)
sp_col  = _first_present(dfhq, CANDIDATE_SPECIES_COLS)
gn_col  = _first_present(dfhq, CANDIDATE_GENE_COLS)
seq_col = _first_present(dfhq, CANDIDATE_SEQ_COLS)

missing = [n for n, v in {
    "accession": acc_col, "species": sp_col, "gene": gn_col, "sequence": seq_col
}.items() if v is None]

if missing:
    raise RuntimeError(f"FATAL: df_high_quality is missing required columns: {missing}")

if GXY_SPANS_COL not in dfhq.columns:
    # tolerate absence; walker can still run but ordering is weaker
    logger.warning(f"[Bridge] '{GXY_SPANS_COL}' missing; gxystart/gxyend will be None.")

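# --- Build RegExTractor-ready view ---
# Also carry a protein-name column (if present) so Cell 90c can rescue paralogs from it.
pn_col = _first_present(dfhq, ["Protein names", "Protein Name", "protein_name", "Entry name"])
keep_cols = [acc_col, sp_col, gn_col, seq_col]
if GXY_SPANS_COL in dfhq.columns:
    keep_cols.append(GXY_SPANS_COL)
if pn_col is not None and pn_col not in keep_cols:
    keep_cols.append(pn_col)
tmp = dfhq[keep_cols].copy()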
tmp = tmp.rename(columns={
    acc_col: "accession",
    sp_col:  "species",
    gn_col:  "gene_name_raw",
    seq_col: "sequence"
})

# Resolve paralog_group
tmp["paralog_group"] = tmp["gene_name_raw"].map(_resolve_paralog)

# Extract G-X-Y bounds if present
if GXY_SPANS_COL in tmp.columns:
    bounds = tmp[GXY_SPANS_COL].apply(_extract_gxy_bounds)
    tmp["gxystart"] = bounds.apply(lambda t: t[0])
    tmp["gxyend"]   = bounds.apply(lambda t: t[1])
else:
    tmp["gxystart"] = None
    tmp["gxyend"]   = None

# Minimal hygiene
before = len(tmp)
tmp = tmp.dropna(subset=["accession", "species"]).drop_duplicates(subset=["accession"])
after = len(tmp)

logger.info(f"[Bridge] RegExTractor input built: {after} rows (dropped {before-after} empty/dupes).")
logger.info(f"[Bridge] Paralog distribution:\n{tmp['paralog_group'].value_counts(dropna=False).to_string()}")

# Flag ambiguous rows so you can decide to keep/exclude
tmp["is_ambiguous_paralog"] = (tmp["paralog_group"] == "unknown_paralog")

ambig_n = int(tmp["is_ambiguous_paralog"].sum())
if ambig_n > 0:
    logger.warning(f"[Bridge] Ambiguous/unknown paralog rows: {ambig_n}. "
                   f"Examples:\n{tmp.loc[tmp['is_ambiguous_paralog'], ['accession','gene_name_raw']].head(10).to_string(index=False)}")

# --- Emit both legacy and new globals (non-breaking) ---
rex_sequences_df = tmp  # primary input for Cells 93–94
training_df_rex  = tmp.copy()  # kept for backward compatibility with earlier cells

# Optional: exclude ambiguous rows here (safer for recovery). If you prefer to keep them for later disambiguation,
# comment the next two lines.
rex_sequences_df = rex_sequences_df.loc[~rex_sequences_df["is_ambiguous_paralog"]].copy()
training_df_rex  = training_df_rex.loc[~training_df_rex["is_ambiguous_paralog"]].copy()

logger.info(f"[Bridge] Final rex_sequences_df: {len(rex_sequences_df)} rows; "
            f"unknown_paralog removed: {ambig_n - int(rex_sequences_df['is_ambiguous_paralog'].sum())}")


## Cell 90b – Compute G-X-Y spans from sequence (triple-helix bounds)
This cell scans each sequence for regions consistent with collagen Gly–X–Y
periodicity (every 3rd residue is G). It emits `gxy_spans` as a list of dicts in
sequence order; the first span is used for `gxystart`/`gxyend`, so later cells
can order anchors within the triple helix.

Heuristics:
- Require a **minimum run length** (default 300 aa) to count as a helix.
- Allow a small **budget of mismatches** (frameshifts/misreads): a span is closed
  once more than `max_off_period_misses` (default 2) Gly positions in it are not G.
  The budget is cumulative per span, not a sliding window.
- Merge spans separated by gaps of 6 aa or less.
- Never renames existing columns; it **adds** `gxy_spans`, `gxystart`, `gxyend`,
  overwriting any values carried over from Cell 90.
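The miss budget can be seen on a short synthetic sequence. `gxy_run` below is an illustrative copy of the inner per-frame loop of `_find_gxy_spans` (the toy sequence is made up), returning how far a run extends from a given start before the budget is exhausted.

```python
def gxy_run(seq: str, start: int, max_misses: int = 2) -> int:
    """
    Length of the run beginning at `start` in which every 3rd residue should be G.
    Non-G residues at those positions are counted cumulatively; the run stops at
    the (max_misses + 1)-th miss. Mirrors the inner loop of _find_gxy_spans.
    """
    misses = 0
    j = start
    while j < len(seq):
        if (j - start) % 3 == 0 and seq[j] != "G":
            misses += 1
            if misses > max_misses:
                break
        j += 1
    return j - start


# Toy sequence: 2-residue lead, 10 Gly-Pro-Ala triplets (30 aa), then 9 Leu
seq = "MK" + "GPA" * 10 + "L" * 9
for frame in (0, 1, 2):
    print(frame, gxy_run(seq, frame))
```

Frame 2 is the in-register frame and wins with a run of 36, while frames 0 and 1 stop after 6 residues. The winning run overshoots the 30-residue G-X-Y block by 6 residues (`max_misses` extra periods) because the budget is spent on the trailing Leu residues, so real span ends can be blurred by a few residues in the same way.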


In [None]:
# ===== Cell 90b =====
# Compute G-X-Y spans (triple helix) from raw sequences

from typing import List, Dict, Tuple, Optional

def _find_gxy_spans(seq: str,
                    min_span_len: int = 300,
                    max_off_period_misses: int = 2) -> List[Dict[str, int]]:
    """
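    Find regions where every 3rd residue is 'G' (Gly-X-Y motif).
    Up to `max_off_period_misses` non-G residues at Gly positions are tolerated
    per span (counted cumulatively, not in a sliding window); the next miss
    closes the span. Spans shorter than `min_span_len` are discarded, and spans
    separated by gaps of 6 aa or less are merged.

    Returns list of dicts: [{'start': i0, 'end': i1}, ...]  (0-based, end exclusive)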
    """
    spans: List[Tuple[int,int]] = []
    n = len(seq)
    i = 0
    while i < n:
        # Try to align to a frame so that positions i, i+3, ... are 'G'
        best_frame = None
        best_len = 0
        best_span = None

        for frame in (0, 1, 2):
            start = i + frame
            if start >= n:
                continue
            j = start
            misses = 0
            # advance in steps of 1 but enforce 'G' at (j-start)%3==0
            while j < n:
                if ((j - start) % 3 == 0) and (seq[j] != 'G'):
                    misses += 1
                    if misses > max_off_period_misses:
                        break
                j += 1
            span_len = j - start
            if span_len > best_len:
                best_len = span_len
                best_frame = frame
                best_span = (start, j)

        if best_span and (best_span[1] - best_span[0]) >= min_span_len:
            spans.append(best_span)
            i = best_span[1]  # continue after this span
        else:
            i += 1

    # Merge adjacent/near-adjacent spans separated by very short gaps (≤6 aa)
    merged: List[Tuple[int,int]] = []
    for s in spans:
        if not merged:
            merged.append(s)
        else:
            a0, a1 = merged[-1]
            b0, b1 = s
            if b0 - a1 <= 6:
                merged[-1] = (a0, max(a1, b1))
            else:
                merged.append(s)

    return [{"start": s, "end": e} for (s, e) in merged]

if 'rex_sequences_df' not in globals():
    raise RuntimeError("FATAL: rex_sequences_df not found. Run Cell 90 first.")

rex_sequences_df = rex_sequences_df.copy()
rex_sequences_df["gxy_spans"] = rex_sequences_df["sequence"].apply(_find_gxy_spans)

# Fill gxystart/gxyend from first span (if present)
def _first_bounds(spans):
    if isinstance(spans, list) and spans:
        s = spans[0]
        return int(s["start"]), int(s["end"])
    return None, None

bounds = rex_sequences_df["gxy_spans"].apply(_first_bounds)
rex_sequences_df["gxystart"] = bounds.apply(lambda t: t[0])
rex_sequences_df["gxyend"]   = bounds.apply(lambda t: t[1])

logger.info(
    "[GXY] gxy_spans computed. With bounds: "
    f"{int(rex_sequences_df['gxystart'].notna().sum())} / {len(rex_sequences_df)} "
    f"({100.0*rex_sequences_df['gxystart'].notna().mean():.1f}%)."
)


## Cell 90c – Canonicalise/Rescue `paralog_group`
This cell reduces `unknown_paralog` rows by:
1) **Canonicalising gene symbols** using an alias map (e.g., synonyms, case/spacing).
2) **Parsing protein names** (e.g., “Collagen type I alpha 1 chain”) to set a
   **single** canonical `paralog_group`.
3) (**Optional**) A **G-X-Y length heuristic**: if the triple-helix length
   matches a narrow window typical of a paralog (e.g., ~1014 aa for COL1A1/1A2,
   ~1284 aa for COL2A1, etc.), assign provisionally.

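All assignments are conservative and never overwrite an existing confident label.
We keep your `PARALOG_ALLOWED_SET`; anything outside remains unknown.
Note that Cell 90 drops `unknown_paralog` rows by default, so this rescue only has
rows to work on if the two exclusion lines at the end of Cell 90 are commented out.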


In [None]:
# ===== Cell 90c =====
# Canonicalise/Rescue paralog_group from aliases, protein names, and optional length heuristics

import re
from typing import Optional

if 'rex_sequences_df' not in globals():
    raise RuntimeError("FATAL: rex_sequences_df not found. Run Cell 90 first.")

# Candidate protein-name columns from Part 7 exports
CANDIDATE_PNAME_COLS = ["Protein names", "Protein Name", "protein_name", "Entry name"]

# 1) Alias map for gene symbols → canonical
GENE_ALIAS_TO_CANON = {
    # exact canonical passthrough
    "COL1A1":"COL1A1","COL1A2":"COL1A2","COL2A1":"COL2A1","COL3A1":"COL3A1",
    "COL5A1":"COL5A1","COL5A2":"COL5A2","COL5A3":"COL5A3",
    "COL10A1":"COL10A1","COL11A1":"COL11A1","COL11A2":"COL11A2",
    "COL12A1":"COL12A1","COL16A1":"COL16A1","COL21A1":"COL21A1","COL22A1":"COL22A1",
    "COL8A1":"COL8A1","COL6A1":"COL6A1","COL6A3":"COL6A3","COL4A3":"COL4A3",
    "COL4A4":"COL4A4","COL4A5":"COL4A5",
    # common noisy forms
    "COL1A1/1A2":"unknown_paralog", "COL1A1 COL1A2":"unknown_paralog",
    "COL1A1;COL1A2":"unknown_paralog", "COL1A1, COL1A2":"unknown_paralog",
}

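def _canon_gene(g: Optional[str]) -> str:
    if not isinstance(g, str) or not g.strip():
        return "unknown_paralog"
    g = g.strip().upper().replace(" ", "")
    # Compare against alias keys normalised the same way (upper-case, no spaces);
    # otherwise multi-word keys such as "COL1A1 COL1A2" could never match.
    aliases = {k.upper().replace(" ", ""): v for k, v in GENE_ALIAS_TO_CANON.items()}
    if g in aliases:
        return aliases[g]
    # quick acceptance of exact canonical tokens
    if g in PARALOG_ALLOWED_SET:
        return g
    return "unknown_paralog"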

# 2) Parse protein names like "Collagen type I alpha 1 chain"
_PN_PATTERNS = [
    # (regex, canonical)
    (re.compile(r"\bCOLLAGEN\s+TYPE\s+I\s+ALPHA\s*1\b", re.I),  "COL1A1"),
    (re.compile(r"\bCOLLAGEN\s+TYPE\s+I\s+ALPHA\s*2\b", re.I),  "COL1A2"),
    (re.compile(r"\bCOLLAGEN\s+TYPE\s+II\s+ALPHA\s*1\b", re.I), "COL2A1"),
    (re.compile(r"\bCOLLAGEN\s+TYPE\s+III\s+ALPHA\s*1\b", re.I),"COL3A1"),
    (re.compile(r"\bCOLLAGEN\s+TYPE\s+V\s+ALPHA\s*1\b", re.I),  "COL5A1"),
    (re.compile(r"\bCOLLAGEN\s+TYPE\s+V\s+ALPHA\s*2\b", re.I),  "COL5A2"),
    (re.compile(r"\bCOLLAGEN\s+TYPE\s+V\s+ALPHA\s*3\b", re.I),  "COL5A3"),
    (re.compile(r"\bCOLLAGEN\s+TYPE\s+XI\s+ALPHA\s*1\b", re.I), "COL11A1"),
    (re.compile(r"\bCOLLAGEN\s+TYPE\s+XI\s+ALPHA\s*2\b", re.I), "COL11A2"),
    (re.compile(r"\bCOLLAGEN\s+TYPE\s+XII\s+ALPHA\s*1\b", re.I),"COL12A1"),
    (re.compile(r"\bCOLLAGEN\s+TYPE\s+XVI\s+ALPHA\s*1\b", re.I),"COL16A1"),
    (re.compile(r"\bCOLLAGEN\s+TYPE\s+XXI\s+ALPHA\s*1\b", re.I),"COL21A1"),
    (re.compile(r"\bCOLLAGEN\s+TYPE\s+XXII\s+ALPHA\s*1\b", re.I),"COL22A1"),
    (re.compile(r"\bCOLLAGEN\s+TYPE\s+VIII\s+ALPHA\s*1\b", re.I),"COL8A1"),
    (re.compile(r"\bCOLLAGEN\s+TYPE\s+VI\s+ALPHA\s*1\b", re.I), "COL6A1"),
    (re.compile(r"\bCOLLAGEN\s+TYPE\s+VI\s+ALPHA\s*3\b", re.I), "COL6A3"),
    (re.compile(r"\bCOLLAGEN\s+TYPE\s+IV\s+ALPHA\s*3\b", re.I), "COL4A3"),
    (re.compile(r"\bCOLLAGEN\s+TYPE\s+IV\s+ALPHA\s*4\b", re.I), "COL4A4"),
    (re.compile(r"\bCOLLAGEN\s+TYPE\s+IV\s+ALPHA\s*5\b", re.I), "COL4A5"),
]

def _pn_to_paralog(pn: Optional[str]) -> str:
    if not isinstance(pn, str) or not pn:
        return "unknown_paralog"
    for pat, canon in _PN_PATTERNS:
        if pat.search(pn):
            return canon
    return "unknown_paralog"

pname_col = next((c for c in CANDIDATE_PNAME_COLS if c in rex_sequences_df.columns), None)

# Step A: canonicalise any existing gene label (only fill unknowns)
mask_unknown = rex_sequences_df["paralog_group"] == "unknown_paralog"
rex_sequences_df.loc[mask_unknown, "paralog_group"] = rex_sequences_df.loc[mask_unknown, "gene_name_raw"].map(_canon_gene)

# Step B: rescue via protein names where still unknown
if pname_col:
    mask_unknown = rex_sequences_df["paralog_group"] == "unknown_paralog"
    rescued = rex_sequences_df.loc[mask_unknown, pname_col].map(_pn_to_paralog)
    # Only accept rescues that are in allowed set
    rescued = rescued.where(rescued.isin(PARALOG_ALLOWED_SET), other="unknown_paralog")
    rex_sequences_df.loc[mask_unknown, "paralog_group"] = rescued

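# Step C (optional): GXY-length heuristic (very conservative; default OFF)
USE_GXY_LENGTH_HEURISTIC = False
GXY_LENGTH_WINDOWS = {
    # helix aa lengths are approximate (end-start); tighten/expand as you wish
    "COL1A1": (900, 1050),  # ~1014 typical
    "COL1A2": (900, 1050),
    "COL2A1": (1200, 1350), # ~1284 typical
    "COL3A1": (975, 1125),  # ~1029 typical
}
if USE_GXY_LENGTH_HEURISTIC:
    m = rex_sequences_df["paralog_group"] == "unknown_paralog"
    lengths = (pd.to_numeric(rex_sequences_df["gxyend"], errors="coerce")
               - pd.to_numeric(rex_sequences_df["gxystart"], errors="coerce"))
    # The windows overlap (COL1A1 and COL1A2 are identical, COL3A1 overlaps both), so
    # assign only when a length falls in exactly one window; assigning window by
    # window would let the last matching window win.
    in_window = pd.DataFrame({canon: lengths.between(lo, hi, inclusive="both")
                              for canon, (lo, hi) in GXY_LENGTH_WINDOWS.items()})
    unique = m & in_window.sum(axis=1).eq(1)
    if unique.any():
        rex_sequences_df.loc[unique, "paralog_group"] = (
            in_window.loc[unique].astype(int).idxmax(axis=1)
        )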

# Mirror back to training_df_rex to keep older cells happy
training_df_rex = rex_sequences_df.copy()

# Report
unk = int((rex_sequences_df["paralog_group"] == "unknown_paralog").sum())
logger.info(f"[Paralog Rescue] unknown_paralog remaining: {unk} / {len(rex_sequences_df)} "
            f"({100.0*unk/max(1, len(rex_sequences_df)):.1f}%).")
by_para = rex_sequences_df.groupby("paralog_group")["accession"].nunique().sort_values(ascending=False)
logger.info(f"[Paralog Rescue] Unique accessions by paralog_group:\n{by_para.to_string()}")
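`_PN_PATTERNS` targets names of the form "Collagen type I alpha 1 chain". Reviewed UniProt entries often use the form "Collagen alpha-1(I) chain" instead, which those patterns do not match. A hedged sketch of a parser for that second form; `uniprot_name_to_symbol` and `roman_to_int` are hypothetical helpers, not part of the notebook.

```python
import re
from typing import Optional

_ROMAN = {"I": 1, "V": 5, "X": 10}


def roman_to_int(s: str) -> int:
    """Convert a Roman numeral made of I, V, X (enough for collagen types I to XXVIII)."""
    total = 0
    for i, ch in enumerate(s):
        v = _ROMAN[ch]
        # subtractive notation: IV = 4, IX = 9, XIV = 14
        if i + 1 < len(s) and _ROMAN[s[i + 1]] > v:
            total -= v
        else:
            total += v
    return total


# Matches names such as "Collagen alpha-1(I) chain" or "Collagen alpha-2(XI) chain"
_UNIPROT_PN = re.compile(r"\bCOLLAGEN\s+ALPHA[-\s]?(\d+)\s*\(([IVX]+)\)", re.I)


def uniprot_name_to_symbol(pn: Optional[str]) -> Optional[str]:
    """Return e.g. 'COL1A1' for 'Collagen alpha-1(I) chain', else None."""
    if not isinstance(pn, str):
        return None
    m = _UNIPROT_PN.search(pn)
    if m is None:
        return None
    chain_no, roman = m.group(1), m.group(2).upper()
    return f"COL{roman_to_int(roman)}A{chain_no}"


for name in ["Collagen alpha-1(I) chain", "Collagen alpha-3(VI) chain",
             "Collagen alpha-1(XXII) chain", "Fibronectin"]:
    print(name, "->", uniprot_name_to_symbol(name))
```

Any symbol produced this way would still need to pass the `PARALOG_ALLOWED_SET` filter that Step B applies before it is assigned.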


## Cell 90d – Triple-Helix Bounds Preflight (auto-compute if missing)
Ensures `gxystart/gxyend` are populated. If `gxy_spans` is absent or no row has a
`gxystart`, recomputes `gxy_spans` from the sequence (same scanner as Cell 90b; the
first qualifying G–X–Y span gives the bounds). Logs coverage via `logger`.


In [None]:
# ===== Cell 90d =====
# Ensure gxystart/gxyend exist; auto-compute from sequences if needed.

from typing import List, Tuple, Dict

def _find_gxy_spans(seq: str,
                    min_span_len: int = 300,
                    max_off_period_misses: int = 2) -> List[Dict[str, int]]:
    spans: List[Tuple[int,int]] = []
    n = len(seq)
    i = 0
    while i < n:
        best_span = None
        best_len = 0
        for frame in (0, 1, 2):
            start = i + frame
            if start >= n:
                continue
            j = start
            misses = 0
            while j < n:
                if ((j - start) % 3 == 0) and (seq[j] != 'G'):
                    misses += 1
                    if misses > max_off_period_misses:
                        break
                j += 1
            span_len = j - start
            if span_len > best_len:
                best_len = span_len
                best_span = (start, j)
        if best_span and (best_span[1] - best_span[0]) >= min_span_len:
            spans.append(best_span)
            i = best_span[1]
        else:
            i += 1

    merged: List[Tuple[int,int]] = []
    for s in spans:
        if not merged:
            merged.append(s)
        else:
            a0, a1 = merged[-1]
            b0, b1 = s
            if b0 - a1 <= 6:
                merged[-1] = (a0, max(a1, b1))
            else:
                merged.append(s)
    return [{"start": s, "end": e} for (s, e) in merged]

def _first_bounds(spans):
    if isinstance(spans, list) and spans:
        s = spans[0]
        return int(s["start"]), int(s["end"])
    return None, None

if 'rex_sequences_df' not in globals():
    raise RuntimeError("FATAL: rex_sequences_df missing. Run Cell 90 first.")

rex_sequences_df = rex_sequences_df.copy()

# Compute if missing
need_spans = ("gxy_spans" not in rex_sequences_df.columns) or \
             (rex_sequences_df["gxystart"].isna().all() if "gxystart" in rex_sequences_df else True)

if need_spans:
    rex_sequences_df["gxy_spans"] = rex_sequences_df["sequence"].apply(_find_gxy_spans)
    bounds = rex_sequences_df["gxy_spans"].apply(_first_bounds)
    rex_sequences_df["gxystart"] = bounds.apply(lambda t: t[0])
    rex_sequences_df["gxyend"]   = bounds.apply(lambda t: t[1])

n_total = len(rex_sequences_df)
n_bounds = int(rex_sequences_df["gxystart"].notna().sum())
logger.info(f"[GXY] Bounds available: {n_bounds}/{n_total} "
            f"({100.0 * (n_bounds/max(1,n_total)):.1f}%).")
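Cell 90d recomputes spans only when the column is absent or every row lacks bounds, so individual rows with missing bounds in an otherwise populated frame are left as they are. If per-row back-filling is wanted, a hypothetical variant could look like the sketch below; `fill_missing_bounds` is an illustrative name, and `find_spans` would be `_find_gxy_spans` in the notebook.

```python
from typing import Callable, Dict, List

import pandas as pd


def fill_missing_bounds(df: pd.DataFrame,
                        find_spans: Callable[[str], List[Dict[str, int]]]) -> pd.DataFrame:
    """Recompute G-X-Y bounds only for rows that lack a gxystart (hypothetical variant)."""
    out = df.copy()
    for col in ("gxystart", "gxyend"):
        if col not in out.columns:
            out[col] = float("nan")
    missing = out["gxystart"].isna()
    # Scan only the rows that need it; each result is a list of span dicts
    spans = out.loc[missing, "sequence"].apply(find_spans)
    for col, key in (("gxystart", "start"), ("gxyend", "end")):
        vals = spans.apply(lambda s: s[0][key] if s else None)
        out.loc[missing, col] = pd.to_numeric(vals, errors="coerce")
    return out
```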


## Cell 90e – Collagen Coverage Audit (Where did other collagens go?)
This audit:
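1) Extracts raw collagen tokens from `df_high_quality` (the first available gene-name
   column plus, if present, a protein-name column) using a permissive
   `COL(\d+)A(\d+)` detector. Protein names written as "Collagen alpha-1(I) chain"
   contain no such token and are not counted.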
2) Compares them to the **final** `paralog_group` in `rex_sequences_df`.
3) Reports which collagen families were present upstream but **not admitted**
   to the final set (typically because they’re not in `PARALOG_ALLOWED_SET`
   or our protein-name patterns didn’t include them).

Nothing is renamed; we only *read* existing globals and log a compact report.


In [None]:
# ===== Cell 90e =====
# Collagen coverage audit: df_high_quality (raw) vs rex_sequences_df (final)

import re
import pandas as pd

if 'df_high_quality' not in globals():
    raise RuntimeError("FATAL: df_high_quality not found (Part 7).")
if 'rex_sequences_df' not in globals():
    raise RuntimeError("FATAL: rex_sequences_df not found (Cell 90).")

# columns we used earlier
gn_col  = next((c for c in ["gene_name_raw","gene_name","Gene Names (primary)",
                            "primary_gene","uniprot_gene"] if c in df_high_quality.columns), None)
pname_col_candidates = ["Protein names","Protein Name","protein_name","Entry name"]
pname_col = next((c for c in pname_col_candidates if c in df_high_quality.columns), None)

if gn_col is None:
    raise RuntimeError("FATAL: Could not find a gene name column in df_high_quality.")

# 1) Extract permissive collagen tokens from df_high_quality
def _find_coll_tokens(s: str) -> list:
    if not isinstance(s, str):
        return []
    # match COL # A #  (accept lower/upper, spaces or dashes)
    pats = re.findall(r"COL\s*([0-9]+)\s*A\s*([0-9]+)", s.replace("-", " ").upper())
    return [f"COL{a}A{b}" for a,b in pats]

raw_df = df_high_quality[[gn_col] + ([pname_col] if pname_col else [])].copy()
raw_df["__raw_text__"] = raw_df[gn_col].astype(str)
if pname_col:
    raw_df["__raw_text__"] = raw_df["__raw_text__"] + " ; " + raw_df[pname_col].astype(str)

raw_df["raw_coll_tokens"] = raw_df["__raw_text__"].apply(_find_coll_tokens)

# explode and count
raw_tokens = (raw_df.explode("raw_coll_tokens")
                     .dropna(subset=["raw_coll_tokens"]))
raw_counts = raw_tokens["raw_coll_tokens"].value_counts()

# 2) Final paralog set from rex_sequences_df
final_counts = rex_sequences_df["paralog_group"].value_counts()

# 3) Which collagens appeared upstream but not present downstream?
raw_set   = set(raw_counts.index.tolist())
final_set = set(final_counts.index.tolist())
missing_downstream = sorted([c for c in raw_set if c not in final_set])

print("=== Collagens detected in Part 7 (raw labels/protein names) ===")
print(raw_counts.head(50).to_string())
print("\n=== Collagens present in final rex_sequences_df ===")
print(final_counts.to_string())
print("\n=== Collagens seen upstream but not admitted downstream ===")
print(", ".join(missing_downstream) if missing_downstream else "(none)")

# Optional: show a few example rows for the first 5 "missing" families
if missing_downstream:
    probe = missing_downstream[:5]
    ex = raw_tokens[raw_tokens["raw_coll_tokens"].isin(probe)].head(20)
    display_cols = [gn_col] + ([pname_col] if pname_col else [])
    print("\nExamples from df_high_quality carrying these missing labels:")
    print(ex[display_cols + ["raw_coll_tokens"]].to_string(index=False))


## Cell 90f – Anchor Catalog Coverage Audit (Do we even support them?)
Even if we detect COL4A2, COL8A2, COL9A1, COL14A1, COL17A1 upstream, recovery
won’t use them unless the **anchor catalog** contains ordered exon anchors for
those chains. This cell lists which genes the catalog currently supports.
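The coverage question is a set difference between normalised upstream labels and the catalog's gene symbols. The standalone sketch below shows that logic; `coll_tokens` and `uncovered_genes` are hypothetical helper names (the notebook does the same work inline across Cells 90e and 90f), and the example labels are made up.

```python
import re
from typing import Iterable, List, Set

# Permissive "COL<type>A<chain>" pattern; dashes are turned into spaces first,
# so "COL-4A2", "col 4 a 2" and "COL4A2" all normalise to "COL4A2".
_COL_TOKEN = re.compile(r"COL\s*([0-9]+)\s*A\s*([0-9]+)")

def coll_tokens(text: str) -> List[str]:
    """Canonical COLnAm tokens found in a free-text label (empty list for non-strings)."""
    if not isinstance(text, str):
        return []
    return [f"COL{n}A{m}" for n, m in _COL_TOKEN.findall(text.replace("-", " ").upper())]

def uncovered_genes(upstream_labels: Iterable[str], catalog_genes: Iterable[str]) -> List[str]:
    """Collagen tokens seen upstream that have no entry in the anchor catalog."""
    seen: Set[str] = set()
    for label in upstream_labels:
        seen.update(coll_tokens(label))
    return sorted(seen - set(catalog_genes))

labels = ["COL1A1 ; Collagen alpha-1(I) chain", "col-4 a2", "COL9A1"]
print(uncovered_genes(labels, catalog_genes=["COL1A1", "COL9A1"]))  # ['COL4A2']
```

Note that "Collagen alpha-1(I)" contributes nothing here: only explicit `COLnAm` symbols are tokenised, which is why a separate protein-name rescue is useful.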


In [None]:
# ===== Cell 90f =====
# What genes/chains does the anchor catalog actually cover?

if 'ANCHOR_CATALOG_DF' not in globals():
    print("ANCHOR_CATALOG_DF not found. Run Cell 93 first.")
else:
    cat = ANCHOR_CATALOG_DF
    have_cols = [c for c in ["gene","chain","exon_index","anchor_regex","loaded"] if c in cat.columns]
    print(f"Catalog columns: {have_cols}")
    if "gene" in cat.columns:
        print("\nGenes present in catalog (top 50):")
        print(cat["gene"].value_counts().head(50).to_string())
        # useful set for comparison
        catalog_genes = set(cat["gene"].unique())
        # pull upstream raw tokens again (from 90e) if available
        try:
            raw_tokens_set = set(raw_tokens["raw_coll_tokens"].unique())
            unseen = sorted([g for g in raw_tokens_set if g not in catalog_genes])
            print("\nCollagens referenced upstream but NOT in catalog:")
            print(", ".join(unseen) if unseen else "(none)")
        except NameError:
            print("\n(raw_tokens not found; run Cell 90e first to compare upstream labels.)")


## Cell 90g – Expand rescue to additional collagen families (non-breaking)
This cell:
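1) Extends `PARALOG_ALLOWED_SET` to include the “missing” families found by the audit.
2) Adds permissive protein-name patterns for those families, e.g. “Collagen type XIV alpha 1” and “Collagen alpha-1(XIV) chain”, with arabic or Roman type numbers.
3) Re-runs the rescue *only for rows still marked `unknown_paralog`* (safe, non-overwriting).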

**Note**: Recovery still requires anchors. If the anchor catalog has no entries
for a gene, stitching will skip it even after rescue succeeds.
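A self-contained sketch of the pattern builder, assuming UniProt-style names such as "Collagen alpha-1(XIV) chain". The helper names (`to_roman`, `pn_patterns`, `match_gene`) are illustrative, not the notebook's. One detail worth seeing in isolation: a `\b` placed right after `\)` can only match when a word character follows, so it would reject "...(XIV) chain"; the sketch therefore ends the alpha pattern at the closing parenthesis.

```python
import re
from typing import Iterable, List

def to_roman(n: int) -> str:
    """Integer -> Roman numeral; correct for 1-39, enough for collagen types."""
    out = ""
    for value, symbol in [(10, "X"), (9, "IX"), (5, "V"), (4, "IV"), (1, "I")]:
        while n >= value:
            out += symbol
            n -= value
    return out

def pn_patterns(gene: str) -> List[re.Pattern]:
    """Case-insensitive protein-name patterns for a COLnAm gene symbol."""
    m = re.fullmatch(r"COL(\d+)A(\d+)", gene)
    if not m:
        return []
    num, chain = int(m.group(1)), int(m.group(2))
    pats = []
    for t in (str(num), to_roman(num)):  # arabic and Roman type numbers
        # "Collagen type XIV alpha 1"
        pats.append(re.compile(rf"\bCOLLAGEN\s+TYPE\s+{t}\s+ALPHA\s*{chain}\b", re.I))
        # "Collagen alpha-1(XIV) chain": no \b after ")" since ")" is not a word character
        pats.append(re.compile(rf"\bCOLLAGEN\s+ALPHA[-\s]*{chain}\s*\(\s*{t}\s*\)", re.I))
    return pats

def match_gene(protein_name: str, genes: Iterable[str]) -> str:
    """First gene (in sorted order) whose patterns match, else 'unknown_paralog'."""
    for gene in sorted(genes):
        if any(p.search(protein_name) for p in pn_patterns(gene)):
            return gene
    return "unknown_paralog"

print(match_gene("Collagen alpha-1(XIV) chain", {"COL4A1", "COL14A1"}))  # COL14A1
```

Because the type number is delimited by whitespace or parentheses, "(IV)" does not match the COL14A1 pattern and "type 1" does not match "type 10".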


In [None]:
# ===== Cell 90g =====
# Expand allow-list + protein-name patterns; re-run rescue for remaining unknowns

import re

if 'rex_sequences_df' not in globals():
    raise RuntimeError("rex_sequences_df missing. Run Cell 90 first.")

# 1) Extend allowed set (explicit list from your audit)
EXTENDED_PARALOGS = {
    "COL4A1","COL4A2","COL4A6",
    "COL6A2","COL6A5","COL6A6",
    "COL7A1","COL8A2","COL9A1","COL9A2","COL9A3",
    "COL10A1","COL13A1","COL14A1","COL15A1","COL17A1",
    "COL18A1","COL19A1","COL20A1","COL23A1","COL24A1",
    "COL25A1","COL27A1",
}
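if 'PARALOG_ALLOWED_SET' not in globals():
    PARALOG_ALLOWED_SET = set()   # fallback if the upstream allow-list cell was not run
PARALOG_ALLOWED_SET |= EXTENDED_PARALOGS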

# 2) Add protein-name patterns dynamically for added families
# Accept both "Collagen type XIV alpha 1" and "Collagen alpha-1(XIV)"
def _roman(n: int) -> str:
    # Minimal Roman numerals; correct for 1-39, which covers every collagen type.
    vals = [(10,'X'),(9,'IX'),(5,'V'),(4,'IV'),(1,'I')]
    res, x = "", n
    for v, s in vals:
        while x >= v:
            res += s
            x -= v
    return res

_new_patterns = []
for gene in sorted(EXTENDED_PARALOGS):
    m = re.fullmatch(r"COL(\d+)A(\d+)", gene)
    if not m:
        continue
    num = int(m.group(1)); a = int(m.group(2))
    roman = _roman(num)
    # Patterns:
    #  A) "Collagen type N alpha M"
    _new_patterns.append((re.compile(fr"\bCOLLAGEN\s+TYPE\s+{num}\s+ALPHA\s*{a}\b", re.I), gene))
    _new_patterns.append((re.compile(fr"\bCOLLAGEN\s+TYPE\s+{roman}\s+ALPHA\s*{a}\b", re.I), gene))
    #  B) "Collagen alpha-M(N)" where N can be arabic or roman, allow hyphen/space.
    #     No trailing \b: ")" is a non-word character, so "\)\b" would fail on
    #     names such as "Collagen alpha-1(XIV) chain".
    _new_patterns.append((re.compile(fr"\bCOLLAGEN\s+ALPHA[-\s]*{a}\s*\(\s*{num}\s*\)", re.I), gene))
    _new_patterns.append((re.compile(fr"\bCOLLAGEN\s+ALPHA[-\s]*{a}\s*\(\s*{roman}\s*\)", re.I), gene))

# Attach to the existing PN patterns list if present; else local fallback
if '_PN_PATTERNS' in globals() and isinstance(_PN_PATTERNS, list):
    _PN_PATTERNS.extend(_new_patterns)
else:
    _PN_PATTERNS = _new_patterns  # local fallback if earlier cell not run

# 3) Re-run rescue **only for unknowns**
pname_col = next((c for c in ["Protein names","Protein Name","protein_name","Entry name"]
                  if c in rex_sequences_df.columns), None)

def _pn_to_paralog(pn: str) -> str:
    if not isinstance(pn, str) or not pn:
        return "unknown_paralog"
    for pat, canon in _PN_PATTERNS:
        if pat.search(pn):
            return canon
    return "unknown_paralog"

mask_unknown = rex_sequences_df["paralog_group"].eq("unknown_paralog")
if pname_col:
    rescued = rex_sequences_df.loc[mask_unknown, pname_col].map(_pn_to_paralog)
    rescued = rescued.where(rescued.isin(PARALOG_ALLOWED_SET), other="unknown_paralog")
    rex_sequences_df.loc[mask_unknown, "paralog_group"] = rescued
else:
    logger.warning("[Rescue-Extended] No protein-name column in rex_sequences_df; rescue skipped.")

# Mirror to training_df_rex for backward compatibility
training_df_rex = rex_sequences_df.copy()

# Report change
unk = int((rex_sequences_df["paralog_group"] == "unknown_paralog").sum())
logger.info(f"[Rescue-Extended] unknown_paralog remaining: {unk} / {len(rex_sequences_df)}")
logger.info(f"[Rescue-Extended] Genes now admitted (top 30):\n"
            f"{rex_sequences_df['paralog_group'].value_counts().head(30).to_string()}")


## Cell 91 – XY Substitution Framework (config + utilities)

**What this adds**

- Extracts **X** and **Y** residues from G–X–Y triplets, treating every **G** as a triplet start.
- Builds MRCA-anchored substitution **counts** and per-AA **frequencies**.
- Small-n protections and graceful fallbacks if MRCA strings are missing.

**Assumptions / Inputs**

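- `SOURCE_DF` exists and includes at minimum:
  - `paralog_group` (gene key), `exon_num_in_chain`, `exon_peptide`, `species`
  - Optional: a clade column named by `XY_CLADES_FROM` (default `order`; if missing, every row is `pan`); optional `mrca_exon_peptide`.
- MRCA per (gene, exon) may be provided in `MRCA_EXON_MAP[(gene, exon)]`.
  Otherwise (or when the MRCA's G–X–Y frame count differs from a peptide's), each peptide
  is compared against its own most frequent X and Y residue: a flat **modal baseline**
  rather than a true ancestral sequence.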

**Outputs** (materialised by Cell 92 using these utilities)

- `xy_subst_by_exon_df`: MRCA-anchored 20×20 counts per (gene, exon, pos_class, clade).
- `xy_subst_counts_df`: Aggregated counts per (gene, pos_class, clade), in long format (`anc`, `obs`, `count`).
- `xy_position_diversity_df`: Per-exon AA diversity by position (for anchor scans).
- Utility functions reused by later cells.
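A minimal sketch of MRCA-anchored counting, assuming an MRCA peptide that is already in register with every observed peptide. It omits the notebook's `XY_MIN_GXY_TRIPLETS` filter and modal fallback; `xy_from_gxy` and `subst_counts` are illustrative names, and the toy peptides are made up.

```python
import numpy as np

AA = "ACDEFGHIKLMNPQRSTVWY"
IDX = {aa: i for i, aa in enumerate(AA)}

def xy_from_gxy(peptide: str):
    """X and Y residues of every triplet that starts at a 'G'."""
    xs, ys = [], []
    for i in range(len(peptide) - 2):
        if peptide[i] == "G":
            xs.append(peptide[i + 1])
            ys.append(peptide[i + 2])
    return xs, ys

def subst_counts(peptides, mrca: str):
    """20x20 counts per position class: rows = ancestral residue, cols = observed."""
    mats = {"X": np.zeros((20, 20), dtype=int), "Y": np.zeros((20, 20), dtype=int)}
    anc_x, anc_y = xy_from_gxy(mrca)
    for pep in peptides:
        obs_x, obs_y = xy_from_gxy(pep)
        for cls, anc, obs in (("X", anc_x, obs_x), ("Y", anc_y, obs_y)):
            for a, o in zip(anc, obs):  # frame i of the MRCA vs frame i of the peptide
                if a in IDX and o in IDX:
                    mats[cls][IDX[a], IDX[o]] += 1
    return mats

mats = subst_counts(["GPPGAPGEK", "GPPGPPGEK"], mrca="GPPGPPGEK")
print(mats["X"][IDX["P"], IDX["A"]])  # 1: one ancestral P observed as A at an X position
```

Diagonal cells count conserved positions, so row-normalising a matrix gives per-ancestral-residue substitution frequencies.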

In [None]:
# ===== Cell 91 =====
# XY Substitution Framework (config + utilities)

from collections import Counter, defaultdict
from typing import Dict, Tuple, List, Optional
import numpy as np
import pandas as pd
import math
import re

# --- Logging shim (reuse your logger/rex_log if present) ---
def _log(msg: str):
    if 'rex_log' in globals():
        try:
            rex_log(msg)
            return
        except Exception:
            pass
    print(msg)

# --- Parameters (do NOT clobber if already defined upstream) ---
XY_MIN_GXY_TRIPLETS = int(globals().get("XY_MIN_GXY_TRIPLETS", 5))       # min G–X–Y frames per peptide to include it
XY_MIN_SPECIES_PER_EXON = int(globals().get("XY_MIN_SPECIES_PER_EXON", 5)) # min peptides per (gene, exon, clade) group
XY_ALPHA = float(globals().get("XY_ALPHA", 0.01))                         # sig. threshold
XY_EFFECT_CV = float(globals().get("XY_EFFECT_CV", 0.15))                 # Cramér's V small/medium boundary
XY_MIN_CELL = int(globals().get("XY_MIN_CELL", 1))                        # smoothing pseudo-counts
XY_REGEX_MIN_FREQ = float(globals().get("XY_REGEX_MIN_FREQ", 0.05))       # AA must reach this freq to be included
XY_REGEX_IMPORTANCE_FREQ = float(globals().get("XY_REGEX_IMPORTANCE_FREQ", 0.02))  # always include K/R, D/E, P if >= this
XY_CLADES_FROM = globals().get("XY_CLADES_FROM", "order")                 # column giving clade; fallback to 'pan'
XY_AA_ORDER = list("ACDEFGHIKLMNPQRSTVWY")                                # 20 AA (no O/U); 'P' crucial here
AA_INDEX = {aa:i for i,aa in enumerate(XY_AA_ORDER)}

# --- Safe import for tests (chi2) with fallback G-test ---
try:
    from scipy.stats import chi2_contingency
    _HAVE_SCIPY = True
except Exception:
    _HAVE_SCIPY = False

def _g_test(obs: np.ndarray) -> Tuple[float, float]:
    """Likelihood-ratio (G) test for independence of rows and columns.
    Returns (G, p). Assumes obs shape (r, c) with no all-zero rows/columns.
    Only used when SciPy is missing: the p-value comes from the Wilson–Hilferty
    normal approximation to the chi-square upper tail (rough for very small df)."""
    r, c = obs.shape
    row_sums = obs.sum(axis=1, keepdims=True)
    col_sums = obs.sum(axis=0, keepdims=True)
    total = obs.sum()
    expected = row_sums @ col_sums / max(total, 1)
    with np.errstate(divide='ignore', invalid='ignore'):
        valid = (obs > 0) & (expected > 0)
        G = 2.0 * np.sum(obs[valid] * np.log(obs[valid] / expected[valid]))
    df = (r - 1) * (c - 1)
    if df <= 0:
        return float(G), 1.0
    # (G/df)^(1/3) is approximately normal with mean 1 - 2/(9 df) and variance 2/(9 df)
    v = 2.0 / (9.0 * df)
    z = ((max(float(G), 0.0) / df) ** (1.0 / 3.0) - (1.0 - v)) / math.sqrt(v)
    p = 0.5 * math.erfc(z / math.sqrt(2.0))
    return float(G), float(p)

def _trim_zero_margins(obs: np.ndarray) -> np.ndarray:
    """Drop all-zero rows and columns (e.g. amino acids never observed).
    SciPy's chi2_contingency raises on zero expected frequencies, and such
    rows/columns carry no information for the test anyway."""
    obs = np.asarray(obs)
    obs = obs[obs.sum(axis=1) > 0]
    return obs[:, obs.sum(axis=0) > 0]

def _chi2_independence(obs: np.ndarray) -> Tuple[float, float]:
    """Return (stat, p) using SciPy chi2 if available, else G-test fallback.
    Degenerate tables (fewer than 2 informative rows or columns) give (0.0, 1.0)."""
    obs = _trim_zero_margins(obs)
    if min(obs.shape) < 2:
        return 0.0, 1.0
    if _HAVE_SCIPY:
        stat, p, _, _ = chi2_contingency(obs, correction=False)
        return float(stat), float(p)
    return _g_test(obs)

def _cramers_v(obs: np.ndarray) -> float:
    """Cramér's V effect size for a contingency table (0.0 if degenerate)."""
    obs = _trim_zero_margins(obs)
    n = obs.sum()
    if n <= 0 or min(obs.shape) < 2:
        return 0.0
    stat, _ = _chi2_independence(obs)
    r, c = obs.shape
    return math.sqrt(stat / (n * (min(r, c) - 1)))

# --- Core helpers ---
_GXY_TRIPLET = re.compile(r"G..")

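def extract_xy_triplets(peptide: str) -> Tuple[List[str], List[str]]:
    """Return lists of Xs and Ys from every triplet that starts at a 'G'.
    Every G is treated as a frame start, so frames can overlap in G-rich or
    out-of-register stretches; frames containing non-letters are skipped."""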
    xs, ys = [], []
    if not peptide:
        return xs, ys
    s = peptide
    for i in range(len(s) - 2):
        if s[i] == "G":
            x, y = s[i+1], s[i+2]
            if x.isalpha() and y.isalpha():
                xs.append(x)
                ys.append(y)
    return xs, ys

def consensus_from_xy_lists(xs_list: List[str], ys_list: List[str]) -> Tuple[str, str]:
    """Flat modal baseline: return the most frequent X residue and the most
    frequent Y residue. This single-residue stand-in for an ancestral sequence
    is used when no in-register MRCA peptide is available."""
    if not xs_list or not ys_list:
        return "", ""
    x_counts = Counter(xs_list)
    y_counts = Counter(ys_list)
    return (max(x_counts, key=x_counts.get), max(y_counts, key=y_counts.get))

def xy_subst_counts_for_exon(exon_peptides: List[str],
                             mrca_peptide: Optional[str] = None
                             ) -> Dict[str, np.ndarray]:
    """Build 20x20 substitution matrices for X and Y vs MRCA.
    Keys: {'X','Y'} -> np.ndarray[20,20] where rows = ancestral AA, cols = observed AA."""
    mats = {"X": np.zeros((20, 20), dtype=int),
            "Y": np.zeros((20, 20), dtype=int)}
    # If MRCA is provided as full peptide, we derive its X/Y lists from it.
    mrca_xs, mrca_ys = extract_xy_triplets(mrca_peptide) if mrca_peptide else ([], [])
    for pep in exon_peptides:
        xs, ys = extract_xy_triplets(pep)
        if len(xs) < XY_MIN_GXY_TRIPLETS or len(ys) < XY_MIN_GXY_TRIPLETS:
            continue
        # If MRCA not given or length mismatch, fallback to flat MRCA by modal residue
        if not mrca_xs or len(mrca_xs) != len(xs):
            # per-exon flat MRCA token (single AA baseline)
            mx, my = consensus_from_xy_lists(xs, ys)
            for x in xs:
                if x in AA_INDEX and mx in AA_INDEX:
                    mats["X"][AA_INDEX[mx], AA_INDEX[x]] += 1
            for y in ys:
                if y in AA_INDEX and my in AA_INDEX:
                    mats["Y"][AA_INDEX[my], AA_INDEX[y]] += 1
        else:
            k = min(len(xs), len(mrca_xs))
            for i in range(k):
                a, o = mrca_xs[i], xs[i]
                if a in AA_INDEX and o in AA_INDEX:
                    mats["X"][AA_INDEX[a], AA_INDEX[o]] += 1
            k = min(len(ys), len(mrca_ys))
            for i in range(k):
                a, o = mrca_ys[i], ys[i]
                if a in AA_INDEX and o in AA_INDEX:
                    mats["Y"][AA_INDEX[a], AA_INDEX[o]] += 1
    return mats

def _safe_get_clade(row: pd.Series) -> str:
    return str(row.get(XY_CLADES_FROM, "pan")) if XY_CLADES_FROM in row else "pan"

## Cell 92 – Build MRCA-anchored X/Y substitution matrices

**What this does**

- Groups `SOURCE_DF` by `(paralog_group, exon_num_in_chain, clade)` and
  builds **20×20** substitution matrices for **X** and **Y** positions.
- Aggregates to per-gene/per-clade matrices (`xy_subst_counts_df`).
- Emits per-position **AA diversity** for plotting anchors.

**Notes**

- Uses `MRCA_EXON_MAP[(gene, exon)]` if present; falls back to modal MRCA.
- Skips under-powered groups (fewer than `XY_MIN_SPECIES_PER_EXON` peptides per (gene, exon, clade)) and peptides with fewer than `XY_MIN_GXY_TRIPLETS` G–X–Y frames.
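The aggregation step (sum per-exon matrices, then emit a long table of non-zero cells) can be sketched on its own as below. The record layout mirrors the dicts Cell 92 appends (`gene`, `exon`, `clade`, `pos_class`, `matrix`); `aggregate_and_tidy` is a hypothetical name and the toy matrices and exon numbers are made up.

```python
import numpy as np
import pandas as pd

AA = list("ACDEFGHIKLMNPQRSTVWY")

def aggregate_and_tidy(records):
    """Sum per-exon 20x20 matrices by (gene, pos_class, clade) and return
    (dict of summed matrices, long DataFrame of non-zero anc->obs counts)."""
    summed = {}
    for r in records:
        key = (r["gene"], r["pos_class"], r["clade"])
        summed[key] = summed.get(key, np.zeros((20, 20), dtype=int)) + r["matrix"]
    rows = []
    for (gene, pos_class, clade), M in summed.items():
        for a, o in zip(*np.nonzero(M)):  # only non-zero cells are kept
            rows.append({"gene": gene, "pos_class": pos_class, "clade": clade,
                         "anc": AA[a], "obs": AA[o], "count": int(M[a, o])})
    return summed, pd.DataFrame(rows)

# Toy input: two exons of one gene, X position, pan-clade
P, A = AA.index("P"), AA.index("A")
m1 = np.zeros((20, 20), dtype=int)
m1[P, P] = 2
m2 = np.zeros((20, 20), dtype=int)
m2[P, P] = 2
m2[P, A] = 1
summed, long_df = aggregate_and_tidy([
    {"gene": "COL1A1", "exon": 6, "clade": "pan", "pos_class": "X", "matrix": m1},
    {"gene": "COL1A1", "exon": 7, "clade": "pan", "pos_class": "X", "matrix": m2},
])
print(long_df)
```

Summing over exons trades exon-level resolution for statistical power; the per-exon records are still kept for audits.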

In [None]:
# ===== Cell 92 =====
# Build MRCA-anchored X/Y substitution matrices (per exon → aggregated)

xy_subst_by_exon_records = []
xy_position_diversity_records = []

if 'SOURCE_DF' in globals() and isinstance(SOURCE_DF, pd.DataFrame) and not SOURCE_DF.empty:
    df = SOURCE_DF.copy()
    if XY_CLADES_FROM not in df.columns:
        df[XY_CLADES_FROM] = "pan"
    df['exon_num_in_chain'] = df['exon_num_in_chain'].astype(int, errors='ignore')

    # Prefer provided MRCA map if available
    MRCA_EXON_MAP = globals().get("MRCA_EXON_MAP", {})

    for (gene, exon, clade), g in df.groupby(['paralog_group', 'exon_num_in_chain', XY_CLADES_FROM]):
        peps = [str(x) for x in g['exon_peptide'].dropna().astype(str).tolist()]
        if len(peps) < XY_MIN_SPECIES_PER_EXON:
            continue

        # Optional MRCA peptide
        mrca = MRCA_EXON_MAP.get((gene, int(exon)))
        mats = xy_subst_counts_for_exon(peps, mrca)

        # Save 20x20 per-exon matrices
        for pos_class in ('X', 'Y'):
            mat = mats[pos_class]
            if mat.sum() == 0:
                continue
            xy_subst_by_exon_records.append({
                "gene": gene, "exon": int(exon), "clade": clade,
                "pos_class": pos_class, "matrix": mat
            })

        # Position-wise AA diversity (number of distinct residues at each X/Y index).
        # Peptides shorter than a given index are skipped there, so tail positions
        # may reflect fewer sequences than the group size.
        xs_all, ys_all = [], []
        for p in peps:
            xs, ys = extract_xy_triplets(p)
            xs_all.append(xs); ys_all.append(ys)
        max_len = max((len(x) for x in xs_all), default=0)
        for i in range(max_len):
            obs_x = set(x[i] for x in xs_all if len(x) > i)
            obs_y = set(y[i] for y in ys_all if len(y) > i)
            if obs_x:
                xy_position_diversity_records.append(
                    {"gene":gene, "exon":int(exon), "clade":clade, "pos_class":"X",
                     "position_ix": i, "aa_diversity": len(obs_x)}
                )
            if obs_y:
                xy_position_diversity_records.append(
                    {"gene":gene, "exon":int(exon), "clade":clade, "pos_class":"Y",
                     "position_ix": i, "aa_diversity": len(obs_y)}
                )

# Materialize DataFrames
xy_subst_by_exon_df = pd.DataFrame(xy_subst_by_exon_records)
xy_position_diversity_df = pd.DataFrame(xy_position_diversity_records)

# Aggregate per-gene/clade/pos_class into 20x20 matrices (sum over exons)
xy_subst_matrices: Dict[Tuple[str,str,str], np.ndarray] = {}
for _, r in xy_subst_by_exon_df.iterrows():
    key = (r["gene"], r["pos_class"], r["clade"])
    if key not in xy_subst_matrices:
        xy_subst_matrices[key] = np.zeros((20,20), dtype=int)
    xy_subst_matrices[key] += r["matrix"]

# Tidy count table (long) for downstream stats/plots
xy_subst_rows = []
for (gene, pos_class, clade), M in xy_subst_matrices.items():
    for ai, a in enumerate(XY_AA_ORDER):
        for oi, o in enumerate(XY_AA_ORDER):
            c = int(M[ai, oi])
            if c > 0:
                xy_subst_rows.append(
                    {"gene":gene, "pos_class":pos_class, "clade":clade,
                     "anc":a, "obs":o, "count":c}
                )
xy_subst_counts_df = pd.DataFrame(xy_subst_rows)

_log(f"[XY] Built matrices: {len(xy_subst_matrices)} keys; "
     f"{len(xy_subst_by_exon_df)} exon-level records; "
     f"{len(xy_position_diversity_df)} diversity rows.")

In [None]:
xy_subst_counts_df.head(20)


## Cell 93 – Merge decisions (between genes; between X vs Y) & Regex tokens

**Goal**  
Combine matrices **unless** there is strong evidence of heterogeneity.

**Tests**

- Independence tests across **genes** (per clade, per pos_class); and across **X vs Y**.
- Criteria to **keep separate**:
  - p < `XY_ALPHA` **and** Cramér’s V ≥ `XY_EFFECT_CV`.
- Otherwise we **merge** (sum counts).
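The keep-separate rule can be sketched without SciPy. This version uses Pearson's chi-square with a Wilson–Hilferty normal approximation for the p-value (a swapped-in approximation, not SciPy's exact tail), and drops all-zero rows/columns first because amino acids never observed make expected counts zero. Defaults mirror `XY_ALPHA = 0.01` and `XY_EFFECT_CV = 0.15`; the function names are illustrative.

```python
import math
import numpy as np

def pearson_chi2(table):
    """Pearson chi-square after dropping all-zero rows/columns: (stat, df, trimmed)."""
    t = np.asarray(table, dtype=float)
    t = t[t.sum(axis=1) > 0]
    t = t[:, t.sum(axis=0) > 0]
    r, c = t.shape
    if r < 2 or c < 2:
        return 0.0, 0, t
    expected = np.outer(t.sum(axis=1), t.sum(axis=0)) / t.sum()
    return float(((t - expected) ** 2 / expected).sum()), (r - 1) * (c - 1), t

def chi2_sf_wh(x, df):
    """Wilson-Hilferty approximation to P(chi2_df > x)."""
    if df <= 0:
        return 1.0
    v = 2.0 / (9.0 * df)
    z = ((max(x, 0.0) / df) ** (1.0 / 3.0) - (1.0 - v)) / math.sqrt(v)
    return 0.5 * math.erfc(z / math.sqrt(2.0))

def merge_decision(table, alpha=0.01, min_v=0.15):
    """('separate' | 'merge', stat, p, Cramer's V); separate only if significant AND large."""
    stat, df, t = pearson_chi2(table)
    if df == 0:
        return "merge", 0.0, 1.0, 0.0
    p = chi2_sf_wh(stat, df)
    v = math.sqrt(stat / (t.sum() * (min(t.shape) - 1)))
    return ("separate" if (p < alpha and v >= min_v) else "merge"), stat, p, v

print(merge_decision([[90, 10], [10, 90]])[0])  # separate
```

Requiring both a small p and a non-trivial V avoids splitting on differences that are only "significant" because the counts are large.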

**Regex token spec (for anchors/recovery)**

- From merged (or per-gene) frequencies, build position-class tokens:
  - Include AA with freq ≥ `XY_REGEX_MIN_FREQ`.
  - Always include **K/R**, **D/E**, **P** if they reach `XY_REGEX_IMPORTANCE_FREQ`.
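  - Guarantee at least one AA (fallback to the top 3 by frequency).
- Output:
  - `xy_merge_decisions_df` (audit)
  - `xy_subst_merged_spec`: {(scope, pos_class, clade) → {"matrix":..., "freq":..., "token":"[...]"}}
  - `xy_unified_tokens`: X and Y tokens unioned into a single `"XY"` token unless the X-vs-Y test says to keep them separate.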
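The token rule can be sketched with a plain residue-to-frequency dict (the notebook works on 20-element arrays). Thresholds default to `XY_REGEX_MIN_FREQ = 0.05` and `XY_REGEX_IMPORTANCE_FREQ = 0.02`; the function names and toy frequencies are assumptions for illustration.

```python
import re

PRIVILEGED = set("KRDEP")  # K/R, D/E and P, as in the spec above

def token_from_freqs(freq: dict, min_freq=0.05, importance_freq=0.02, top_n=3) -> str:
    """Degenerate character class '[...]' from a residue -> frequency mapping."""
    include = {aa for aa, p in freq.items() if p >= min_freq}
    include |= {aa for aa in PRIVILEGED if freq.get(aa, 0.0) >= importance_freq}
    if not include:  # guarantee at least one residue
        include = set(sorted(freq, key=freq.get, reverse=True)[:top_n])
    return "[" + "".join(sorted(include)) + "]"

def union_tokens(tok_x: str, tok_y: str) -> str:
    """Merge X and Y tokens by set union (used when X and Y are not kept separate)."""
    return "[" + "".join(sorted(set(tok_x.strip("[]")) | set(tok_y.strip("[]")))) + "]"

x_tok = token_from_freqs({"P": 0.60, "A": 0.30, "S": 0.04, "K": 0.03, "Q": 0.02, "E": 0.01})
y_tok = token_from_freqs({"P": 0.55, "R": 0.25, "E": 0.15, "T": 0.05})
print(x_tok, y_tok, union_tokens(x_tok, y_tok))  # [AKP] [EPRT] [AEKPRT]
gxy = re.compile("G" + x_tok + y_tok)
print(bool(gxy.fullmatch("GAR")))  # True
```

K at 0.03 makes it into the X token only because it is privileged; S at 0.04 and Q at 0.02 do not.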

In [None]:
# ===== Cell 93 =====
# Merge decisions & regex token spec

def _stack_by(group_keys, pos_class, clade) -> Tuple[np.ndarray, List[str]]:
    """Stack matrices by group (e.g., by gene) to test heterogeneity."""
    mats = []
    labels = []
    for key, M in xy_subst_matrices.items():
        g, pc, cl = key
        if pc != pos_class or cl != clade:
            continue
        if isinstance(group_keys, list) and g not in group_keys:
            continue
        mats.append(M.sum(axis=0))  # collapse ancestral rows: focus on obs mix
        labels.append(g)
    if not mats:
        return np.zeros((0, 20), dtype=int), labels
    return np.stack(mats, axis=0), labels

def _freqs_from_matrix(M: np.ndarray) -> np.ndarray:
    """Return observed AA frequency vector (20,) from 20x20 matrix."""
    v = M.sum(axis=0).astype(float)
    s = v.sum()
    return (v / s) if s > 0 else np.zeros_like(v)

def _token_from_freqs(freq: np.ndarray) -> str:
    """Build degenerate class token '[...]' from freq with constraints."""
    include = set()
    # primary filter
    for aa, p in zip(XY_AA_ORDER, freq):
        if p >= XY_REGEX_MIN_FREQ:
            include.add(aa)
    # privileged residues (K/R; D/E; P)
    privileged = {"K","R","D","E","P"}
    for aa in privileged:
        idx = AA_INDEX[aa]
        if freq[idx] >= XY_REGEX_IMPORTANCE_FREQ:
            include.add(aa)
    # ensure at least one AA
    if not include:
        # take top 3 by freq
        top_idx = np.argsort(freq)[::-1][:3]
        include = {XY_AA_ORDER[i] for i in top_idx}
    return "[" + "".join(sorted(include)) + "]"

merge_rows = []
xy_subst_merged_spec = {}   # keys: ('merged' or gene, pos_class, clade) -> dict

# 1) Decide per clade whether to merge across genes (for each pos_class)
clades = sorted(set(k[2] for k in xy_subst_matrices.keys())) or ["pan"]
genes = sorted(set(k[0] for k in xy_subst_matrices.keys()))

for clade in clades:
    for pos_class in ("X","Y"):
        table, labels = _stack_by(genes, pos_class, clade)
        # Underpowered: if fewer than 2 groups or small totals, auto-merge
        if table.shape[0] < 2 or table.sum() < 100:
            decision = "merge"
            stat, p, cv = 0.0, 1.0, 0.0
        else:
            stat, p = _chi2_independence(table)
            cv = _cramers_v(table)
            decision = "separate" if (p < XY_ALPHA and cv >= XY_EFFECT_CV) else "merge"

        merge_rows.append({"scope":"genes", "clade":clade, "pos_class":pos_class,
                           "stat":stat, "p":p, "cramers_v":cv, "decision":decision})

        if decision == "merge":
            # build merged matrix (sum all genes); skip if this clade/pos_class has no data
            M = np.zeros((20,20), dtype=int)
            for g in genes:
                M += xy_subst_matrices.get((g, pos_class, clade), 0)
            if M.sum() > 0:
                freq = _freqs_from_matrix(M)
                token = _token_from_freqs(freq)
                xy_subst_merged_spec[("merged", pos_class, clade)] = {"matrix": M, "freq": freq, "token": token}
        else:
            # keep per-gene specs
            for g in genes:
                M = xy_subst_matrices.get((g, pos_class, clade))
                if M is None or M.sum() == 0:
                    continue
                freq = _freqs_from_matrix(M)
                token = _token_from_freqs(freq)
                xy_subst_merged_spec[(g, pos_class, clade)] = {"matrix": M, "freq": freq, "token": token}

# 2) Decide whether X and Y can be combined (per clade, and per-scope)
xy_merge_decisions_rows = []
for clade in clades:
    scopes = set(s for (s,_,c) in xy_subst_merged_spec.keys() if c == clade)
    for scope in scopes:
        Mx = xy_subst_merged_spec.get((scope, "X", clade), {}).get("matrix")
        My = xy_subst_merged_spec.get((scope, "Y", clade), {}).get("matrix")
        if Mx is None or My is None:
            continue
        # 2 x 20 table of observed-residue counts (row 0 = X, row 1 = Y); using the
        # real counts keeps n equal to the number of observed residues.
        counts = np.stack([Mx.sum(axis=0), My.sum(axis=0)], axis=0)
        stat, p = _chi2_independence(counts)
        cv = _cramers_v(counts)
        xy_merge_decisions_rows.append({
            "scope": scope, "clade": clade, "compare": "X_vs_Y",
            "stat": stat, "p": p, "cramers_v": cv,
            "decision": "separate" if (p < XY_ALPHA and cv >= XY_EFFECT_CV) else "merge"
        })

xy_merge_decisions_df = pd.DataFrame(merge_rows + xy_merge_decisions_rows)

# Convenience: build unified tokens preferring merges when permitted
xy_unified_tokens = {}
for (scope, pos_class, clade), obj in xy_subst_merged_spec.items():
    # Check whether we can unify X&Y for this scope/clade
    row = xy_merge_decisions_df[
        (xy_merge_decisions_df['scope']==scope) &
        (xy_merge_decisions_df['clade']==clade) &
        (xy_merge_decisions_df['compare']=="X_vs_Y")
    ]
    separate_xy = (not row.empty) and (row.iloc[0]['decision'] == 'separate')
    if separate_xy:
        xy_unified_tokens[(scope, pos_class, clade)] = obj["token"]
    else:
        # Merge X and Y tokens by union
        tok_x = xy_subst_merged_spec.get((scope, "X", clade), {}).get("token", "")
        tok_y = xy_subst_merged_spec.get((scope, "Y", clade), {}).get("token", "")
        merged = "[" + "".join(sorted(set(tok_x.strip("[]")) | set(tok_y.strip("[]")))) + "]"
        xy_unified_tokens[(scope, "XY", clade)] = merged

_log("[XY] Merge decisions ready; tokens constructed.")

## Cell 93a – Anchor Catalog Sanity Check (must be `loaded=True`)
Confirms that Cell 93 has a **loaded** catalog with exon ordering context (not
a tiny, in-memory fallback). Prints shape and gene coverage.
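The same checks can be wrapped in a small function that returns a report instead of logging. This is a hypothetical helper (`catalog_sanity`), not part of the notebook; the 1000-row threshold mirrors the warning in the cell, and the required columns are left to the caller because the catalog built later in this section uses `gene`/`exon` rather than `exon_index`/`anchor_regex`.

```python
import pandas as pd

def catalog_sanity(cat, required=("gene",), min_rows=1000) -> dict:
    """Advisory summary of an anchor catalog; never raises."""
    if not isinstance(cat, pd.DataFrame):
        return {"ok": False, "reason": "not a DataFrame"}
    missing = [c for c in required if c not in cat.columns]
    report = {
        "shape": cat.shape,
        "missing_columns": missing,
        "n_genes": int(cat["gene"].nunique()) if "gene" in cat.columns else 0,
        "loaded_values": sorted(map(str, cat["loaded"].unique())) if "loaded" in cat.columns else None,
        "too_small": len(cat) < min_rows,
    }
    report["ok"] = not missing and not report["too_small"]
    return report

toy = pd.DataFrame({"gene": ["COL1A1", "COL1A1", "COL2A1"], "exon": [6, 7, 6]})
print(catalog_sanity(toy, required=("gene", "exon")))
```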


In [None]:
# ===== Cell 93a =====
# Inspect anchor catalog globals set by Cell 93.

def _catshape(df):
    try:
        return df.shape
    except Exception:
        return None

if 'ANCHOR_CATALOG_DF' not in globals():
    logger.warning("[Anchors] ANCHOR_CATALOG_DF not found (Cell 93 may not have run).")
else:
    shp = _catshape(ANCHOR_CATALOG_DF)
    logger.info(f"[Anchors] Catalog present: shape={shp}")
    cols = [c for c in ("gene","chain","exon_index","anchor_regex","loaded") if c in ANCHOR_CATALOG_DF.columns]
    logger.info(f"[Anchors] Columns detected: {cols}")
    if "loaded" in ANCHOR_CATALOG_DF.columns:
        logger.info(f"[Anchors] loaded flag values: {ANCHOR_CATALOG_DF['loaded'].unique()[:5]}")
    if "gene" in ANCHOR_CATALOG_DF.columns:
        logger.info(f"[Anchors] Genes in catalog (top 10):\n"
                    f"{ANCHOR_CATALOG_DF['gene'].value_counts().head(10).to_string()}")

# Optional: warn (do not fail) if the catalog has fewer than 1000 rows
if ('ANCHOR_CATALOG_DF' in globals()) and (ANCHOR_CATALOG_DF.shape[0] < 1000):
    logger.warning("[Anchors] Catalog seems small; recovery may underperform (hits but no chains).")


## Cell 93a – Anchor Catalog (robust load/build + safe fallbacks)

Build or load `anchor_catalog_df` for Recovery.

**Strategy**
1) If a non-empty `anchor_catalog_df` already exists → keep it.
2) Else try to **load** `RUN_DIR/anchor_catalog_df.tsv`.
3) Else **build** from `SOURCE_DF` (entropy↓, length↑; optional gden if available).
4) If strict gates yield nothing, **relax** them and still pick top-K per gene.

**Outputs**
- `anchor_catalog_df` with columns: `gene, exon, median_len, entropy, gden, anchor_score`.
- Saved to `RUN_DIR/anchor_catalog_df.tsv` for reuse.

**Notes**
- Uses `xy_position_diversity_df` as an optional proxy for `gden` when matcher signal isn’t available.
- Respects your globals: `REC_MIN_ANCHOR_LEN`, `REC_MIN_ANCHOR_GDEN`, `REC_K_TOP_ANCHORS`.
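A condensed sketch of the scoring and gate cascade, assuming a per-exon stats table (`gene, exon, median_len, entropy, gden`) has already been computed. The 0.4/0.4/0.2 weights and default gates mirror the cell; unlike the cell, the last fallback here uses `k` rather than `max(k, 3)`. `char_entropy` and `pick_anchors` are hypothetical names and the toy numbers are made up.

```python
import math
from collections import Counter

import pandas as pd

def char_entropy(peptides) -> float:
    """Shannon entropy (nats) of the pooled residue composition."""
    s = "".join(peptides)
    if not s:
        return 0.0
    n = len(s)
    return -sum((c / n) * math.log(c / n) for c in Counter(s).values())

def pick_anchors(stats: pd.DataFrame, min_len=18, min_gden=0.20, k=6) -> pd.DataFrame:
    """Rank-normalised score, then strict -> length-only -> no gates, top-k per gene."""
    df = stats.copy()
    df["gden"] = df["gden"].fillna(df["gden"].median()).fillna(0.0)
    df["anchor_score"] = (0.4 * df["gden"].rank(pct=True)
                          + 0.4 * (1.0 - df["entropy"].rank(pct=True))
                          + 0.2 * df["median_len"].rank(pct=True))
    ok_len = df["median_len"] >= min_len
    ok_gden = df["gden"] >= min_gden
    for mask in (ok_len & ok_gden, ok_len, pd.Series(True, index=df.index)):
        sel = df[mask]
        if not sel.empty:
            break
    return (sel.sort_values(["gene", "anchor_score"], ascending=[True, False])
               .groupby("gene").head(k).reset_index(drop=True))

stats = pd.DataFrame({"gene": ["COL1A1"] * 3, "exon": [1, 2, 3],
                      "median_len": [30, 10, 25], "entropy": [1.0, 0.5, 2.0],
                      "gden": [0.5, 0.9, 0.1]})
print(pick_anchors(stats)["exon"].tolist())  # [1]
```

Percentile ranks make the three criteria comparable without assuming their scales, at the cost of ignoring how far apart the raw values are.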


In [None]:
# ===== Cell 93a =====
# Anchor Catalog (robust): load from disk or build with safe fallbacks

import numpy as np
import pandas as pd
import math
from pathlib import Path

# ---- Params / defaults (non-destructive) ----
REC_MIN_ANCHOR_LEN   = int(globals().get("REC_MIN_ANCHOR_LEN", 18))
REC_MIN_ANCHOR_GDEN  = float(globals().get("REC_MIN_ANCHOR_GDEN", 0.20))
REC_K_TOP_ANCHORS    = int(globals().get("REC_K_TOP_ANCHORS", 6))
RUN_DIR = Path(globals().get("RUN_DIR", "."))

def _char_entropy(strings):
    if not strings:
        return 0.0
    s = "".join(strings)
    if not s:
        return 0.0
    from collections import Counter
    cnt = Counter(s)
    total = sum(cnt.values())
    ps = [c/total for c in cnt.values()]
    return -sum(p*math.log(p+1e-12) for p in ps)

def _optional_gden_proxy(gene, exon) -> float:
    """
    Prefer matcher/library conservation if available; else derive a proxy from
    XY position diversity (lower diversity => higher gden).
    """
    # 1) Try library/matcher conservation estimate
    try:
        if 'RegExTractorMatcher' in globals() and 'orthology_aware_library' in globals():
            # reuse internal heuristic if exposed (best-effort)
            lib = orthology_aware_library
            entry = lib.entries.get((gene, int(exon), "pan")) if getattr(lib, "entries", None) else None
            if entry and hasattr(RegExTractorMatcher, "_pattern_conservation"):
                # Use highest-tier pattern for stability
                patt = entry.tiers[0].pattern if entry.tiers else None
                if patt is not None:
                    return float(RegExTractorMatcher._pattern_conservation(patt))
    except Exception:
        pass

    # 2) Fallback: use XY diversity (needs xy_position_diversity_df)
    if 'xy_position_diversity_df' in globals() and isinstance(xy_position_diversity_df, pd.DataFrame):
        sub = xy_position_diversity_df
        mask = (sub["gene"]==gene) & (sub["exon"].astype(int)==int(exon))
        if mask.any():
            # Normalize mean diversity to [0,1] and invert
            vals = sub.loc[mask, "aa_diversity"].astype(float)
            if not vals.empty:
                m = float(vals.mean())
                # typical AA diversity range ~1..6; scale conservatively
                gden = 1.0 - min(m/6.0, 1.0)
                return gden
    # 3) Last resort
    return 0.0

def _build_anchor_catalog_from_source() -> pd.DataFrame:
    rows = []
    if 'SOURCE_DF' not in globals() or not isinstance(SOURCE_DF, pd.DataFrame) or SOURCE_DF.empty:
        return pd.DataFrame(columns=["gene","exon","median_len","entropy","gden","anchor_score"])
    df = SOURCE_DF.copy()
    if "exon_num_in_chain" not in df.columns or "exon_peptide" not in df.columns or "paralog_group" not in df.columns:
        return pd.DataFrame(columns=["gene","exon","median_len","entropy","gden","anchor_score"])

    df["exon_num_in_chain"] = df["exon_num_in_chain"].astype(int, errors="ignore")

    for (gene, exon), g in df.groupby(["paralog_group","exon_num_in_chain"]):
        peps = g["exon_peptide"].dropna().astype(str).tolist()
        if not peps:
            continue
        med_len = float(np.median([len(p) for p in peps]))
        ent = _char_entropy(peps)
        gden = _optional_gden_proxy(str(gene), int(exon))
        rows.append({"gene":str(gene), "exon":int(exon),
                     "median_len":med_len, "entropy":ent, "gden":gden})
    stat_df = pd.DataFrame(rows)
    if stat_df.empty:
        return pd.DataFrame(columns=["gene","exon","median_len","entropy","gden","anchor_score"])

    # rank-based normalisation (robust)
    stat_df["len_z"] = stat_df["median_len"].rank(pct=True)
    stat_df["ent_z"] = 1.0 - stat_df["entropy"].rank(pct=True)  # lower entropy better
    # Fill missing gden with the median of observed values (0.0 if none were observed).
    # Ranks are used below, so an all-zero gden column only adds a constant to every score.
    fill_g = stat_df["gden"].median()
    fill_g = 0.0 if pd.isna(fill_g) else float(fill_g)
    stat_df["gden_f"] = stat_df["gden"].fillna(fill_g)
    stat_df["gden_z"] = stat_df["gden_f"].rank(pct=True)

    stat_df["anchor_score"] = 0.4*stat_df["gden_z"] + 0.4*stat_df["ent_z"] + 0.2*stat_df["len_z"]

    # Strict gates first
    stat_df["ok_len"] = stat_df["median_len"] >= REC_MIN_ANCHOR_LEN
    stat_df["ok_gden"] = stat_df["gden_f"] >= REC_MIN_ANCHOR_GDEN

    strict = (
        stat_df[stat_df["ok_len"] & stat_df["ok_gden"]]
        .sort_values(["gene","anchor_score"], ascending=[True, False])
        .groupby("gene", as_index=False)
        .head(REC_K_TOP_ANCHORS)
    )

    if not strict.empty:
        out = strict[["gene","exon","median_len","entropy","gden","anchor_score"]].reset_index(drop=True)
        return out

    # Fallback 1: drop gden gate
    relaxed = (
        stat_df[stat_df["ok_len"]]
        .sort_values(["gene","anchor_score"], ascending=[True, False])
        .groupby("gene", as_index=False)
        .head(REC_K_TOP_ANCHORS)
    )
    if not relaxed.empty:
        return relaxed[["gene","exon","median_len","entropy","gden","anchor_score"]].reset_index(drop=True)

    # Fallback 2: ignore gates, still pick top-K
    topk = (
        stat_df.sort_values(["gene","anchor_score"], ascending=[True, False])
        .groupby("gene", as_index=False)
        .head(max(REC_K_TOP_ANCHORS, 3))
    )
    return topk[["gene","exon","median_len","entropy","gden","anchor_score"]].reset_index(drop=True)

# ---- Load or build ----

loaded = False
if "anchor_catalog_df" in globals() and isinstance(anchor_catalog_df, pd.DataFrame) and not anchor_catalog_df.empty:
    loaded = True
else:
    # Try disk
    tsv = RUN_DIR / "anchor_catalog_df.tsv"
    if tsv.exists():
        try:
            anchor_catalog_df = pd.read_csv(tsv, sep="\t")
            if not anchor_catalog_df.empty:
                loaded = True
        except Exception:
            pass

if not loaded:
    anchor_catalog_df = _build_anchor_catalog_from_source()

# Persist for reuse
try:
    if isinstance(anchor_catalog_df, pd.DataFrame) and not anchor_catalog_df.empty:
        RUN_DIR.mkdir(parents=True, exist_ok=True)
        anchor_catalog_df.to_csv(RUN_DIR/"anchor_catalog_df.tsv", sep="\t", index=False)
except Exception as e:
    if 'rex_log' in globals():
        try: rex_log(f"Anchor catalog save failed: {e}")
        except Exception: pass

# Diagnostics
n_genes = anchor_catalog_df["gene"].nunique() if not anchor_catalog_df.empty else 0
n_rows  = len(anchor_catalog_df) if isinstance(anchor_catalog_df, pd.DataFrame) else 0
msg = f"[RegExTractor] Anchor catalog ready: {n_rows} rows across {n_genes} genes (loaded={loaded})."
print(msg)
if 'rex_log' in globals():
    try: rex_log(msg)
    except Exception: pass
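The anchor selection above relaxes its gates in tiers: the strict gates first, then only the length gate (`ok_len`), then no gates at all, each time keeping the top-K anchors per gene by `anchor_score`. The sketch below reproduces that pattern on a toy table. It assumes the strict tier combines `ok_len` with a G-density threshold (the strict filter itself sits above this excerpt). `select_anchors`, `k` and `min_gden` are hypothetical names; the defaults mirror `REC_K_TOP_ANCHORS` and `REC_MIN_ANCHOR_GDEN` from Cell 94.

```python
import pandas as pd

COLS = ["gene", "exon", "median_len", "entropy", "gden", "anchor_score"]

def select_anchors(stat_df: pd.DataFrame, k: int = 6, min_gden: float = 0.2) -> pd.DataFrame:
    """Top-k anchors per gene, relaxing the gates tier by tier (hypothetical helper)."""
    ranked = stat_df.sort_values(["gene", "anchor_score"], ascending=[True, False])
    all_rows = pd.Series(True, index=ranked.index)
    tiers = [
        (ranked["ok_len"] & (ranked["gden"] >= min_gden), k),  # strict: length and G-density
        (ranked["ok_len"], k),                                 # fallback 1: drop the gden gate
        (all_rows, max(k, 3)),                                 # fallback 2: no gates at all
    ]
    for mask, top_k in tiers:
        picked = ranked[mask].groupby("gene").head(top_k)
        if not picked.empty:
            return picked[COLS].reset_index(drop=True)
    return pd.DataFrame(columns=COLS)

toy = pd.DataFrame({
    "gene": ["COL1A1", "COL1A1", "COL1A2"],
    "exon": [10, 11, 5],
    "median_len": [18, 54, 27],
    "entropy": [0.4, 0.9, 0.5],
    "gden": [0.05, 0.33, 0.10],
    "ok_len": [True, True, False],
    "anchor_score": [0.7, 0.9, 0.8],
})
print(select_anchors(toy, k=1))   # strict tier keeps only COL1A1 exon 11
```

In this sketch, as in the two fallbacks shown above, the tiers relax for the whole table at once rather than per gene, so COL1A2 is dropped from the toy result because COL1A1 still has a strict anchor.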


## Cell 94 – Recovery Engine (anchor-guided, species-wise)

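Reuses the existing matcher and library; function names and signatures are unchanged.  
Adds:
- Robust gene-vote tie-breaks (most anchor exons hit first, then highest summed anchor score)
- Naming rule: if the G–X–Y fraction is ≥ `REC_GXY_MIN_FRACTION` **and** the chain length is within
  ±15% of the expected length → assign the bare gene name (e.g., **COL1A1**), else **COL1A1\_0.xx**.
  Chains must already pass the G–X–Y gate to be reported, so in practice the suffix flags a length outside the expected range.
- Optional "unified X/Y token" fallback tiers for pan-clade anchors (disabled by default to avoid inflating false positives).

**Inputs:** `anchor_catalog_df` from Cell 93; `orthology_aware_library`, `RegExTractorMatcher`, `SOURCE_DF`.  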
**Outputs:** `recovery_hits_df`, `recovered_chains_df`.
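The naming rule can be tried in isolation. This sketch assumes the G–X–Y fraction is the share of triplets that begin with glycine in the best of the three reading frames, and that "within expected" means ±15% of the expected length, the bounds `_name_prediction` in the cell below uses. `gxy_fraction` and `name_chain` are hypothetical helpers, not notebook functions.

```python
from typing import Optional

def gxy_fraction(seq: str) -> float:
    """Share of triplets starting with G, in the best of the three reading frames."""
    best = 0.0
    for frame in range(3):
        triplets = [seq[i:i + 3] for i in range(frame, len(seq) - 2, 3)]
        if triplets:
            best = max(best, sum(t[0] == "G" for t in triplets) / len(triplets))
    return best

def name_chain(gene: str, frac: float, expected_len: Optional[int] = None,
               observed_len: Optional[int] = None, min_frac: float = 0.85) -> str:
    """Bare gene name for a clean G-X-Y chain of plausible length, else gene_0.xx."""
    ok_len = True
    if expected_len and observed_len:
        ok_len = 0.85 * expected_len <= observed_len <= 1.15 * expected_len
    if frac >= min_frac and ok_len:
        return gene
    return f"{gene}_{frac:.2f}"

pep = "GPPGAPGPQGFQGPPGEP"                    # six triplets, all G-initial in frame 0
print(gxy_fraction(pep))                      # 1.0
print(name_chain("COL1A1", 1.0, 18, 18))      # COL1A1
print(name_chain("COL1A1", 0.72))             # COL1A1_0.72
```

Scoring triplets rather than single residues matters: a perfect (G-X-Y)n repeat has glycine at only one residue in three, so a per-residue glycine count stays near 1/3 and could never reach a 0.85 threshold.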

In [None]:
# ===== Cell 94 =====
# Recovery engine (keeps your function/class names; adds guardrails)

from collections import Counter
from typing import Optional, Tuple

import pandas as pd

# Defaults if not present
REC_GXY_MIN_FRACTION = float(globals().get("REC_GXY_MIN_FRACTION", 0.85))
REC_MIN_ANCHOR_LEN = int(globals().get("REC_MIN_ANCHOR_LEN", 18))
REC_MIN_ANCHOR_GDEN = float(globals().get("REC_MIN_ANCHOR_GDEN", 0.2))
REC_K_TOP_ANCHORS = int(globals().get("REC_K_TOP_ANCHORS", 6))
rex_chain_min_consecutive = int(globals().get("rex_chain_min_consecutive", 3))

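def _is_gxy_like(seq: str) -> float:
    """Fraction of triplets that start with G, in the best of the three reading frames.

    A perfect (G-X-Y)n repeat scores 1.0, so the value is comparable to REC_GXY_MIN_FRACTION.
    """
    if not seq or len(seq) < 3:
        return 0.0
    best = 0.0
    for frame in range(3):
        triplets = [seq[i:i + 3] for i in range(frame, len(seq) - 2, 3)]
        if triplets:
            best = max(best, sum(t[0] == "G" for t in triplets) / len(triplets))
    return best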

def _name_prediction(paralog_group: str, gxy_fraction: float,
                     expected_len: Optional[int] = None,
                     observed_len: Optional[int] = None) -> str:
    """Assign final gene name; if ambiguous, append _0.xx."""
    gene = paralog_group
    ok_len = True
    if expected_len and observed_len:
        ok_len = (0.85 * expected_len) <= observed_len <= (1.15 * expected_len)
    if gxy_fraction >= REC_GXY_MIN_FRACTION and ok_len:
        return gene
    # keep the leading zero so the suffix reads e.g. COL1A1_0.72
    return f"{gene}_{gxy_fraction:.2f}"

def run_recovery_engine(target_pool: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
    if 'orthology_aware_library' not in globals():
        _log("Recovery: orthology_aware_library not found; run Cell 73 first.")
        return pd.DataFrame(), pd.DataFrame()
    if 'anchor_catalog_df' not in globals() or anchor_catalog_df.empty:
        _log("Recovery: anchor catalog unavailable; run Cell 93.")
        return pd.DataFrame(), pd.DataFrame()
    if 'SOURCE_DF' not in globals() or SOURCE_DF.empty:
        _log("Recovery: SOURCE_DF not found.")
        return pd.DataFrame(), pd.DataFrame()

    lib = orthology_aware_library
    matcher = RegExTractorMatcher(lib)

    anchors = {(r["gene"], int(r["exon"])) for _, r in anchor_catalog_df.iterrows()}
    if not anchors:
        _log("Recovery: no anchors available after filters.")
        return pd.DataFrame(), pd.DataFrame()

    # expected exon architectures per gene (training)
    arch = SOURCE_DF.groupby('paralog_group')['exon_num_in_chain'] \
                    .apply(lambda x: sorted(set(map(int, x)))) \
                    .to_dict()

    rec_hits, rec_chains = [], []

    for _, row in target_pool.iterrows():
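        acc = row.get("accession") or row.get("Entry") or row.get("accession_id")
        # take the first non-empty string column; str(NaN) would otherwise yield "nan"
        seq = next((v for v in (row.get("sequence"), row.get("Sequence"))
                    if isinstance(v, str) and v), "")
        if not acc or not seq:
            continue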

        # 1) Anchor matching
        best_anchor_hit = {}
        for (gene, exon) in anchors:
            entry = lib.entries.get((gene, exon, "pan"))
            if not entry:
                continue
            for tier in entry.tiers:
                for m in tier.regex.finditer(seq):
                    score, gden = matcher._score_hit(m.group(0), tier)
                    h = RexHit(acc, gene, exon, "pan", tier.tier,
                               m.start(), m.end(), m.group(0), gden, score)
                    cur = best_anchor_hit.get((gene, exon))
                    if (cur is None) or (h.score > cur.score):
                        best_anchor_hit[(gene, exon)] = h

        if not best_anchor_hit:
            # Optional: fallback pass using unified XY tokens as very permissive anchors
            # (only if you want a last-ditch scan)
            # Skipped by default to avoid FP inflation.
            continue

        # 2) Gene identification by anchor votes (tie-break: total score)
        votes = Counter(gene for (gene, _exon) in best_anchor_hit)
        # break ties with accumulated score
        candidates = []
        for g, v in votes.items():
            s = sum(h.score for (gg, _e), h in best_anchor_hit.items() if gg == g)
            candidates.append((g, v, s))
        candidates.sort(key=lambda t: (t[1], t[2]), reverse=True)
        top_gene, top_votes, _ = candidates[0]

        # 3) Extend chain for chosen gene
        gene_hits = [h for (g, e), h in best_anchor_hit.items() if g == top_gene]
        # best per exon
        best_per_exon = {}
        for h in gene_hits:
            cur = best_per_exon.get(h.exon_num_in_chain)
            if (cur is None) or (h.score > cur.score):
                best_per_exon[h.exon_num_in_chain] = h
        if not best_per_exon:
            continue

        seed = max(best_per_exon.values(), key=lambda h: h.score)
        chain = rex_walk_chain(seed, best_per_exon, arch.get(top_gene, []))

        # 4) Quality gates & naming
        span_pep = seq[chain.start:chain.end]
        gxy_frac = _is_gxy_like(span_pep)
        ok_blocks = getattr(chain, "consecutive_blocks", 1) >= rex_chain_min_consecutive
        if ok_blocks and gxy_frac >= REC_GXY_MIN_FRACTION:
            # expected length (heuristic from training)
            exons_for_gene = arch.get(top_gene, [])
            expected_len = 3 * len(exons_for_gene) * 9  # rough: ~27 aa per exon; replace with tracked exon lengths if available
            name = _name_prediction(top_gene, gxy_frac, expected_len, len(span_pep))
            chain_dict = chain.__dict__.copy()
            chain_dict.update({
                "predicted_gene": name,
                "gxy_fraction": gxy_frac,
                "accession": acc
            })
            rec_chains.append(chain_dict)

        # record hits (for auditing)
        rec_hits.extend(h.__dict__ for h in best_per_exon.values())

    return pd.DataFrame(rec_hits), pd.DataFrame(rec_chains)

# Run (user may pre-filter target_pool_rex to 5 trial species)
recovery_hits_df, recovered_chains_df = run_recovery_engine(
    target_pool_rex if 'target_pool_rex' in globals() else pd.DataFrame()
)
_log(f"Recovery: {len(recovery_hits_df)} anchor hits; "
     f"{len(recovered_chains_df)} recovered chains.")
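Gene assignment ranks candidates by the number of distinct anchor exons hit, breaking ties by the summed anchor score. A self-contained sketch of that ranking, using plain `(gene, exon) -> score` pairs in place of the notebook's `RexHit` objects (a simplification for illustration; `pick_gene` is a hypothetical name):

```python
from collections import Counter, defaultdict
from typing import Dict, Tuple

def pick_gene(best_hits: Dict[Tuple[str, int], float]) -> Tuple[str, int, float]:
    """Return (gene, votes, total_score) for the best-supported gene; assumes best_hits is non-empty."""
    votes = Counter(gene for gene, _exon in best_hits)
    totals: Dict[str, float] = defaultdict(float)
    for (gene, _exon), score in best_hits.items():
        totals[gene] += score
    # Sort by vote count first, then by summed score; highest wins.
    ranked = sorted(((g, votes[g], totals[g]) for g in votes),
                    key=lambda t: (t[1], t[2]), reverse=True)
    return ranked[0]

hits = {("COL1A1", 12): 30.0, ("COL1A1", 14): 25.0,
        ("COL1A2", 12): 40.0, ("COL1A2", 20): 20.0,
        ("COL2A1", 7): 90.0}
print(pick_gene(hits))   # ('COL1A2', 2, 60.0)
```

Because votes are compared before scores, a single very strong anchor (COL2A1 here) does not outrank two concordant anchors from another paralog.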

In [None]:
anchor_catalog_df.head()

In [None]:
recovered_chains_df.head()

## Cell 95 – Outlier Exons (per clade)

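Flags exon–clade combinations that **deviate** more than expected given clade variability.

**Metrics**

- Entropy vs the gene-wide baseline, expressed as a z-score
- KL divergence of the clade's X/Y residue composition from the baseline mix
- Missingness (clade groups with no observed exon peptides)

**Outputs**: `outlier_exons_df` with `[gene, exon, clade, entropy, baseline_entropy, z, kl, status, n_obs]`, where `status` is one of `ok`, `high_variability`, `missing`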
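The entropy metric treats all peptides for one exon and clade as a single bag of residues. One way to turn it into a z-score, sketched here as an assumption rather than the pipeline's exact definition, is to compare each clade's entropy with the gene-wide (pooled) entropy and scale by the spread of entropies across clades. `residue_entropy` and `clade_entropy_z` are hypothetical names.

```python
import math
from collections import Counter
from typing import Dict, List

def residue_entropy(peptides: List[str]) -> float:
    """Shannon entropy (nats) of residue composition over the concatenated peptides."""
    counts = Counter("".join(peptides))
    total = sum(counts.values())
    if total == 0:
        return 0.0
    return -sum((c / total) * math.log(c / total) for c in counts.values())

def clade_entropy_z(by_clade: Dict[str, List[str]]) -> Dict[str, float]:
    """z = (clade entropy - gene-wide entropy) / std of clade entropies; 0 if no spread.

    Assumes at least one clade.
    """
    baseline = residue_entropy([p for peps in by_clade.values() for p in peps])
    ents = {clade: residue_entropy(peps) for clade, peps in by_clade.items()}
    mean = sum(ents.values()) / len(ents)
    sd = math.sqrt(sum((e - mean) ** 2 for e in ents.values()) / len(ents))
    return {clade: (e - baseline) / sd if sd > 0 else 0.0 for clade, e in ents.items()}

exon_peps = {
    "Mammalia": ["GPPGPP"],
    "Aves": ["GPPGPP"],
    "Actinopterygii": ["GAPGSQ"],
}
print(clade_entropy_z(exon_peps))
```

With only a handful of clades the cross-clade spread is a noisy estimate, which is one reason the cell also requires a minimum number of observations before flagging.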

In [None]:
# ===== Cell 95 =====
# Outlier exons by clade (improved; entropy + KL + missingness)

import math
from collections import Counter
from typing import List

import numpy as np
import pandas as pd

def _entropy_of_string_list(strings: List[str]) -> float:
    # character-level entropy (nats) across concatenated peptides
    if not strings:
        return 0.0
    counts = Counter("".join(strings))
    total = sum(counts.values())
    ps = [c/total for c in counts.values()]
    return -sum(p*math.log(p+1e-12) for p in ps)

def _kl(p: np.ndarray, q: np.ndarray) -> float:
    # KL(p||q), summed only where both p and q are non-zero
    p = p.astype(float); q = q.astype(float)
    p = p / max(p.sum(), 1.0); q = q / max(q.sum(), 1.0)
    with np.errstate(divide='ignore', invalid='ignore'):
        m = (p > 0) & (q > 0)
        return float(np.sum(p[m] * np.log((p[m] + 1e-12) / (q[m] + 1e-12))))

def _xy_freq(peptides: List[str]) -> np.ndarray:
    # residue counts at the X and Y positions of the G-X-Y triplets
    freq = np.zeros(20)
    for pep in peptides:
        xs, ys = extract_xy_triplets(pep)
        for aa in xs + ys:
            if aa in AA_INDEX:
                freq[AA_INDEX[aa]] += 1
    return freq

rows = []
if 'SOURCE_DF' in globals() and not SOURCE_DF.empty:
    df = SOURCE_DF.copy()
    if XY_CLADES_FROM not in df.columns:
        df[XY_CLADES_FROM] = "pan"
    df['exon_num_in_chain'] = df['exon_num_in_chain'].astype(int, errors='ignore')

    for (gene, exon), g in df.groupby(["paralog_group", "exon_num_in_chain"]):
        peps_all = g["exon_peptide"].dropna().astype(str).tolist()
        base_ent = _entropy_of_string_list(peps_all)
        base_freq = _xy_freq(peps_all)

        # collect per-clade stats first so z can use the spread of entropies across clades
        clade_stats = []
        for clade, gc in g.groupby(XY_CLADES_FROM):
            peps = gc["exon_peptide"].dropna().astype(str).tolist()
            clade_stats.append((clade, peps, _entropy_of_string_list(peps), _xy_freq(peps)))
        ent_sd = float(np.std([s[2] for s in clade_stats])) if len(clade_stats) > 1 else 0.0

        for clade, peps, ent, freq in clade_stats:
            kl = _kl(freq, base_freq) if base_freq.sum() > 0 else 0.0
            n = len(peps)
            # clade entropy vs gene-wide baseline, scaled by the cross-clade spread
            z = (ent - base_ent) / ent_sd if ent_sd > 0 else 0.0
            status = "ok"
            if n >= max(XY_MIN_SPECIES_PER_EXON, 8) and (z > 3.0 or kl > 0.5):
                status = "high_variability"
            elif n == 0:
                status = "missing"

            rows.append({"gene": gene, "exon": int(exon), "clade": clade,
                         "entropy": ent, "baseline_entropy": base_ent, "z": z,
                         "kl": kl, "status": status, "n_obs": n})

outlier_exons_df = pd.DataFrame(rows)
# no backslash inside the f-string (a SyntaxError before Python 3.12); safe on an empty frame
n_flagged = int((outlier_exons_df["status"] != "ok").sum()) if not outlier_exons_df.empty else 0
_log(f"[Outliers] flagged {n_flagged} cases.")
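The KL term compares a clade's X/Y residue composition with the gene-wide mix. Because the baseline in this cell pools every clade, any residue a clade uses is also present in the baseline, so the `(p > 0) & (q > 0)` mask in `_kl` never hides an infinite term here. A self-contained sketch of the same divergence follows; `xy_composition` is hypothetical and assumes each peptide starts in frame at a glycine (the notebook's `extract_xy_triplets` is not shown in this section and may handle framing differently).

```python
import numpy as np

AA = "ACDEFGHIKLMNPQRSTVWY"
AA_IX = {aa: i for i, aa in enumerate(AA)}

def xy_composition(peptides):
    """Residue counts at X and Y positions, assuming each peptide starts in frame at G."""
    v = np.zeros(len(AA))
    for pep in peptides:
        for i in range(0, len(pep) - 2, 3):      # one G-X-Y triplet per step
            for aa in pep[i + 1:i + 3]:          # X and Y only
                if aa in AA_IX:
                    v[AA_IX[aa]] += 1
    return v

def kl_divergence(p_counts: np.ndarray, q_counts: np.ndarray) -> float:
    """KL(p || q) in nats; q must be non-zero wherever p is (true when q pools p)."""
    if p_counts.sum() == 0 or q_counts.sum() == 0:
        return 0.0
    p = p_counts / p_counts.sum()
    q = q_counts / q_counts.sum()
    m = p > 0
    return float(np.sum(p[m] * np.log(p[m] / q[m])))

clade = xy_composition(["GPPGPP"])
baseline = xy_composition(["GPPGPP", "GAPGSQ"])
print(kl_divergence(clade, baseline))   # ln(1.6), about 0.470
```

The clade uses only proline at X/Y (4 of 4), while the pooled baseline is 5/8 proline, so the divergence is 1 × ln(1 / 0.625) = ln 1.6.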



## Cell 96 – Reports & Artifacts

- Save (empty or missing tables are skipped):
  - `xy_subst_counts_df.tsv`, `xy_subst_by_exon_df.tsv`, `xy_subst_matrices.npz`,
  - `xy_merge_decisions_df.tsv`, `xy_position_diversity_df.tsv`, `xy_unified_tokens.tsv`,
  - `anchor_catalog_df.tsv`, `recovery_hits_df.tsv`, `recovered_chains_df.tsv`,
  - `outlier_exons_df.tsv`
- Plot (optional stubs):
  - Heatmaps of X/Y substitution (merged vs per-gene)
  - Per-exon AA diversity lines to visualise anchor candidates

In [None]:
# ===== Cell 96 =====
# Persist artifacts + plotting stubs (matplotlib)

import json
from pathlib import Path

import numpy as np
import pandas as pd

RUN_DIR = Path(globals().get("RUN_DIR", Path(".")))
RUN_DIR.mkdir(parents=True, exist_ok=True)

def _save_df(df, name):
    """Write a non-empty DataFrame to RUN_DIR/<name>.tsv; None or empty frames are skipped."""
    try:
        if isinstance(df, pd.DataFrame) and not df.empty:
            df.to_csv(RUN_DIR / f"{name}.tsv", sep="\t", index=False)
            _log(f"[Save] {name}.tsv")
    except Exception as e:
        _log(f"[Save] failed {name}: {e}")

try:
    # tables, looked up by name so one undefined upstream table does not abort the other saves
    for _name in ["xy_subst_counts_df", "xy_subst_by_exon_df", "xy_merge_decisions_df",
                  "xy_position_diversity_df", "anchor_catalog_df", "recovery_hits_df",
                  "recovered_chains_df", "outlier_exons_df"]:
        _save_df(globals().get(_name), _name)

    # tokens as a small table
    if 'xy_unified_tokens' in globals() and xy_unified_tokens:
        tok_rows = [{"scope":k[0], "pos_class":k[1], "clade":k[2], "token":v}
                    for k, v in xy_unified_tokens.items()]
        pd.DataFrame(tok_rows).to_csv(RUN_DIR / "xy_unified_tokens.tsv", sep="\t", index=False)

    # matrices bundle
    if 'xy_subst_matrices' in globals() and xy_subst_matrices:
        np.savez_compressed(RUN_DIR / "xy_subst_matrices.npz",
                            **{f"{g}_{pc}_{cl}": M
                               for (g,pc,cl), M in xy_subst_matrices.items()})

    _log("Recovery artifacts saved.")
except Exception as e:
    _log(f"Could not save recovery artifacts: {e}")

# --- Plotting stubs (optional) ---
try:
    import matplotlib.pyplot as plt

    # Example: diversity plot for a single gene/exon
    # (Commented to avoid accidental runtime in batch)
    # gex = xy_position_diversity_df.query("gene == @some_gene and exon == @some_exon")
    # for pos_class, gg in gex.groupby("pos_class"):
    #     plt.figure()
    #     plt.plot(gg["position_ix"], gg["aa_diversity"])
    #     plt.title(f"{some_gene} exon {some_exon} {pos_class}-diversity")
    #     plt.xlabel("Position in GXY frame"); plt.ylabel("# distinct AA")
    #     plt.show()
except Exception:
    pass
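`np.savez_compressed` needs string keys, so the matrices dictionary keyed by `(gene, pos_class, clade)` tuples is flattened to `gene_posclass_clade` names. A round-trip sketch with toy matrices, writing to an in-memory buffer instead of `RUN_DIR`. It assumes that gene symbols and position classes contain no underscore (clade labels may, because the split stops after two underscores); if that does not hold, a separate key table would be safer.

```python
import io

import numpy as np

mats = {("COL1A1", "X", "Mammalia"): np.eye(3),
        ("COL1A1", "Y", "Aves"): np.ones((3, 3))}

buf = io.BytesIO()                     # stands in for RUN_DIR / "xy_subst_matrices.npz"
np.savez_compressed(buf, **{f"{g}_{pc}_{cl}": M for (g, pc, cl), M in mats.items()})
buf.seek(0)

# Rebuild the tuple keys by splitting on the first two underscores only.
with np.load(buf) as bundle:
    restored = {tuple(name.split("_", 2)): bundle[name] for name in bundle.files}

for key, M in mats.items():
    assert np.array_equal(restored[key], M)
print(sorted(restored))
```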


# **Part 7: Shannon Entropy Analysis**

## Cell 71 – Per-Exon Entropy Calculation and Visualization

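Calculates the Shannon entropy at each amino-acid position of each exon across all sequences of a gene (peptides are left-aligned and padded, not multiply aligned), then summarises each exon by the median per-position entropy. The results are saved as a TSV file, and one plot per gene shows the median length and median entropy of each exon.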

In [None]:
# ===== Cell 71 =====
# Per-exon entropy and plotting

import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

entropy_rows = []
if 'wide_df' in globals() and not wide_df.empty:
    for g, sub in wide_df.groupby('gene_symbol'):
        exon_pep_cols = sorted(
            [c for c in sub.columns if c.startswith('exon_peptide_')],
            key=lambda c: int(re.search(r'_(-?\d+)$', c).group(1))
        )

        for col in exon_pep_cols:
            peps = sub[col].dropna().tolist()
            if not peps: continue

            max_len = max((len(p) for p in peps), default=0)
            padded = [p.ljust(max_len, '-') for p in peps]
            ents = [rex_shannon_entropy([p[i] for p in padded if i < len(p) and p[i] != '-']) for i in range(max_len)]

            entropy_rows.append({
                'gene_symbol': g,
                'exon_col': col,
                'median_length': np.median([len(p) for p in peps]),
                'entropy': np.median(ents) if ents else 0
            })

    entropy_df = pd.DataFrame(entropy_rows)
    if not entropy_df.empty:
        entropy_df.to_csv(ENTROPY_TSV, sep='\t', index=False)
        logger.info(f"Entropy stats saved to {ENTROPY_TSV.name}")

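        # Plotting: bars = median exon length; line on a secondary axis = median per-position entropy
        for g, sub_df in entropy_df.groupby('gene_symbol'):
            sub_df = sub_df.copy()
            sub_df['exon_idx'] = sub_df['exon_col'].str.extract(r'_(-?\d+)$', expand=False).astype(int)
            sub_df = sub_df.sort_values('exon_idx')
            pos = np.arange(len(sub_df))

            fig, ax_len = plt.subplots(figsize=(12, 5))
            ax_len.bar(pos, sub_df['median_length'], color='tab:blue')
            ax_len.set_xticks(pos)
            ax_len.set_xticklabels(sub_df['exon_idx'].astype(str), rotation=90)
            ax_len.set_xlabel("Exon Number")
            ax_len.set_ylabel("Median Length (aa)")
            ax_ent = ax_len.twinx()
            ax_ent.plot(pos, sub_df['entropy'], color='tab:red', marker='o')
            ax_ent.set_ylabel("Median Per-Position Entropy")
            ax_len.set_title(f"Median Exon Length and Entropy for {g}")
            fig.tight_layout()
            plot_path = RUN_DIR / f"entropy_plot_{g}_{RUN_ID}.png"
            fig.savefig(plot_path)
            plt.close(fig)
            logger.info(f"Entropy plot for {g} saved to {plot_path.name}")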
else:
    logger.warning("wide_df not available; skipping entropy analysis.")
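Per exon, the cell computes the Shannon entropy of the residues in each column of the left-padded peptides (padding excluded) and reports the median over columns. A self-contained sketch of that summary, using a hypothetical `shannon_entropy` in bits as a stand-in for `rex_shannon_entropy`, whose log base is not shown in this section:

```python
import math
from collections import Counter
from statistics import median
from typing import List

def shannon_entropy(symbols: List[str]) -> float:
    """Entropy in bits of a list of residues; 0.0 for an empty column."""
    n = len(symbols)
    if n == 0:
        return 0.0
    return -sum((c / n) * math.log2(c / n) for c in Counter(symbols).values())

def exon_median_entropy(peptides: List[str]) -> float:
    """Median per-column entropy over left-aligned peptides (shorter ones simply end early)."""
    max_len = max((len(p) for p in peptides), default=0)
    cols = [[p[i] for p in peptides if i < len(p) and p[i] != "-"] for i in range(max_len)]
    ents = [shannon_entropy(c) for c in cols]
    return median(ents) if ents else 0.0

print(exon_median_entropy(["GAP", "GEQ"]))   # columns: 0.0, 1.0, 1.0 -> 1.0
```

Left alignment is only meaningful when exon peptides are near-equal in length, as is typical of collagen G-X-Y exons; an indel early in a peptide shifts every later column.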

# **Part 8: Reproducibility & Manifest**

## Cell 99 – Final Manifest Generation

This cell concludes the run by generating a final JSON manifest. It records key counts from each major step and the SHA-256 hash of every primary output file that exists, so the outputs of this run can be verified later.

In [None]:
# ===== Cell 99 =====
# Final manifest writer

logger.info("Generating final run manifest...")

# Gather final counts from key dataframes
final_counts = {
    "working_rows": len(working_df) if 'working_df' in globals() else 0,
    "high_quality_rows": len(df_high_quality) if 'df_high_quality' in globals() else 0,
    "total_raw_exons": len(df_raw_exons) if 'df_raw_exons' in globals() else 0,
    "consensus_exons": len(consensus_tbl) if 'consensus_tbl' in globals() else 0,
    "wide_architectures": len(wide_df) if 'wide_df' in globals() else 0,
    "rex_rescued_chains": len(rex_chains_df) if 'rex_chains_df' in globals() else 0,
}

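from pathlib import Path

# Primary output files for this run; undefined path variables and missing files are skipped
_output_names = [
    "WORKING_SNAPSHOT", "CONSENSUS_LONG_TSV", "CONSENSUS_TABLE_TSV",
    "WIDE_ARCH_TSV", "ENTROPY_TSV", "ERROR_REPORT_PATH",
    "EVOLUTION_EVENTS_TSV", "RESCUE_HITS_TSV", "RESCUE_CHAINS_TSV",
]
output_files_to_hash = [Path(globals()[n]) for n in _output_names if globals().get(n) is not None]
output_files_to_hash = [p for p in output_files_to_hash if p.exists()]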

# Update the manifest with final stats and file hashes
write_run_manifest(extra={"final_counts": final_counts}, final_files=output_files_to_hash)

logger.info("✅ Collagen Exon Mapper pipeline finished successfully.")
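Hashing each output file lets a later run confirm that an artifact is byte-identical to the one recorded here. `write_run_manifest` is not defined in this section, so the sketch below is a stand-alone illustration of chunked SHA-256 hashing into a JSON manifest; `sha256_file`, `build_manifest` and the manifest layout are hypothetical and need not match what the notebook writes.

```python
import hashlib
import json
import tempfile
from pathlib import Path
from typing import Dict, Iterable

def sha256_file(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hex SHA-256 of a file, read in 1 MiB chunks so large TSVs need not fit in memory."""
    h = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

def build_manifest(files: Iterable[Path], counts: Dict[str, int]) -> dict:
    """Hypothetical layout: final counts plus {file name: {bytes, sha256}} for existing files."""
    return {
        "final_counts": counts,
        "files": {p.name: {"bytes": p.stat().st_size, "sha256": sha256_file(p)}
                  for p in files if p.exists()},
    }

with tempfile.TemporaryDirectory() as tmp:
    demo = Path(tmp) / "demo.tsv"
    demo.write_text("gene\texon\nCOL1A1\t12\n")
    manifest = build_manifest([demo, Path(tmp) / "absent.tsv"], {"working_rows": 1})
    print(json.dumps(manifest, indent=2))
```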

## Cell 100 – Final Run Summary Report

This cell provides a high-level summary of the key metrics and outputs from the entire pipeline run. It consolidates the most important counts from each major stage, offering a quick and clear overview of what was accomplished.

In [None]:
# ===== Cell 100 =====
# Final Run Summary Report

logger.info("="*50)
logger.info(" PIPELINE RUN SUMMARY")
logger.info("="*50)

# --- Part 2: Data Loading & Pre-processing ---
total_initial_seqs = len(full_df) if 'full_df' in globals() else 0
working_set_seqs = len(working_df) if 'working_df' in globals() else 0
logger.info(f"[Part 2] Initial Data Loading:")
logger.info(f"  - Total sequences in master dataset: {total_initial_seqs}")
logger.info(f"  - Sequences in this run's working set: {working_set_seqs}")

# --- Part 3 & 4: Seed Mapping ---
hq_candidates = len(df_high_quality) if 'df_high_quality' in globals() else 0
seed_exons_mapped = len(df_raw_exons_seed) if 'df_raw_exons_seed' in globals() else 0
logger.info(f"\n[Part 3 & 4] High-Confidence Seed Mapping:")
logger.info(f"  - High-quality candidates after all filters: {hq_candidates}")
logger.info(f"  - Total exons mapped from seed sequences: {seed_exons_mapped}")

# --- Part 5: Architecture-Driven Rescue ---
if 'df_rescued_exons' in globals():
    _is_missing = df_rescued_exons['peptide'] == 'MISSING_EXON'
    rescued_proteins = df_rescued_exons['accession'].nunique()
    rescued_exons_found = int((~_is_missing).sum())
    rescued_exons_missing = int(_is_missing.sum())
else:
    rescued_proteins = rescued_exons_found = rescued_exons_missing = 0
logger.info("\n[Part 5] Architecture-Driven Rescue:")
logger.info(f"  - Full architectures reconstructed for: {rescued_proteins} proteins")
logger.info(f"  - Total exons found via 'fishing': {rescued_exons_found}")
logger.info(f"  - Total missing exons identified (padded): {rescued_exons_missing}")

# --- Part 6: Final Consensus & Analysis ---
final_consensus_exons = len(consensus_tbl) if 'consensus_tbl' in globals() else 0
dated_events = len(df_evolution) if 'df_evolution' in globals() else 0
final_architectures = len(wide_df) if 'wide_df' in globals() else 0
logger.info(f"\n[Part 6] Final Consensus & Evolutionary Analysis:")
logger.info(f"  - Canonical exons in final consensus: {final_consensus_exons}")
logger.info(f"  - Evolutionary events dated: {dated_events}")
logger.info(f"  - Final wide-format architectures generated: {final_architectures}")

# --- Part 7: RegExTractor Gene Classification ---
unclassified_pool_size = len(target_pool_rex) if 'target_pool_rex' in globals() else 0
newly_classified_chains = len(classified_chains_df) if 'classified_chains_df' in globals() and not classified_chains_df.empty else 0
logger.info(f"\n[Part 7] Gene Classification:")
logger.info(f"  - Sequences in the unclassified target pool: {unclassified_pool_size}")
logger.info(f"  - New sequences classified with high confidence: {newly_classified_chains}")

logger.info("="*50)