<a href="https://colab.research.google.com/github/Palaeoprot/PRIDE/blob/main/Blast_ExonMapperv1_5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🧬 **Collagen Exon Mapper v1.5 (Colab Edition)**

A production-ready, Colab-focused pipeline for mapping collagen exon
architectures across taxa, with robust caching, phylogeny-weighted
consensus, optional rescue of misread sequences, and clear per-step
reporting for new users.

**What’s new vs v1.3**
- Colab/Drive aware setup with clear runtime checks.
- UniProt-first **Taxonomic Lineage Expansion** (IDs + names) producing
  stable **cluster keys** for downstream consensus/regex.
- **PhylogeneticConsensusEngine** with TimeTree distance cache (MYR),
  adaptive tolerances, and MRCA + reliability metrics.
- Structured **rescue hooks** (gap detection, 3-frame translation of
  Ensembl DNA; the slow DNA fetch is off by default via
  `ENABLE_DNA_FETCH`) with detailed logging.
- Persistent **rejection TSV** with explicit reasons across stages.
- Clear logs and status messages at each step for naïve users.

# **Part 1: Paths & Project Layout**


- This keeps *everything* for collagens in one place: raw UniProt input, per-run exon maps, rejected IDs, manifests, entropy stats, and regex extractor outputs.

---

## Directory & Path Variables

### 1. Root
- `EXONMAPS_ROOT = SHARED_DATA_ROOT / "ExonMaps" / "GeneFamily" / "collagens"`, where `SHARED_DATA_ROOT` is the `_SHARED_DATA` folder.
- Every run and cache path is relative to this.

### 2. Inputs
- `INPUTS_DIR = EXONMAPS_ROOT / "inputs"`
- `UNIPROT_TSV_GZ = INPUTS_DIR / "uniprot_collagens.tsv.gz"`
  - The canonical UniProt download.
- `FILTERED_UNIPROT_TSV = INPUTS_DIR / "filtered_uniprot_collagens.tsv"`
  - The merged working cache rewritten by Cell 21.
- `ADDED_FASTA_PATH = INPUTS_DIR / "added_fastas.fasta"`
  - For manually added FASTA sequences.

### 3. Run Outputs
- Each run gets a scratch/provenance folder:
  - `RUN_DIR = EXONMAPS_ROOT / "runs" / RUN_ID`, with `RUN_ID = "run_YYYYMMDD_HHMMSS"` (UTC).

Inside each `RUN_DIR`:
- `run.log`: console/file logging.
- `manifest.json`: provenance info.
- `rejected_ids_{RUN_ID}.txt`: accessions rejected this session.

Authoritative results go to sibling folders and are tagged with `RUN_ID`:
- `consensus/consensus_long_{RUN_ID}.tsv`: per-exon peptides & coords.
- `outputs/entropy_stats_{RUN_ID}.tsv`: Shannon entropy results.
- `rescue/rescue_log_{RUN_ID}.tsv`: audit of rejected/rescued exons.
- `regex/`: RegExTractor artifacts (`regex_bank.json`, `regex_training_manifest.json`).

### 4. Caches
- `CACHE_DIR = EXONMAPS_ROOT / "cache"`
  - Contains the incremental exon map cache (`raw_exons_cache.tsv`), the TimeTree distance cache, and QC reports.
- Each cache is backed up before overwrite.

### 5. Rejections
- `REJECTED_IDS_TSV = CACHE_DIR / "rejected_ids.tsv"`
  - Each run's newly rejected IDs and reasons are merged in here.
- Re-read at load time (Cell 21) as `rejected_set`.




## Cell 11 – Install core deps & mount Drive

Installs required libraries and mounts your Google Drive for persistent
storage. If you are **not** in Colab, mounting will be skipped.

In [None]:
# ===== Cell 11 =====
# Install core dependencies and mount Google Drive.

!pip install -q biopython ete3 requests pandas numpy matplotlib tqdm psutil

# Mount Google Drive for persistent storage access.
try:
    from google.colab import drive  # type: ignore
    print("💾 Mounting Google Drive...")
    drive.mount('/content/drive', force_remount=True)
    IN_COLAB = True
except Exception:
    print("⚠️ Not in a Google Colab environment. Drive mounting skipped.")
    IN_COLAB = False

import sys, pandas as pd, numpy as np
print(f"Python: {sys.version.split()[0]}")
print(f"pandas: {pd.__version__}, numpy: {np.__version__}")

## Cell 12 – Central configuration & directory setup

Defines user parameters, project/run paths on Drive, logging, and runtime
manifests. **All later cells assume this has been run.**
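Cell 12b defines `RUN_MANIFEST_JSON`, but none of these setup cells writes it. Cell 26 later adds a `rejection_summary` key to it. Below is a minimal sketch of an initial provenance manifest. It assumes a flat JSON holding the run ID, a UTC timestamp, the chosen parameters, and SHA-256 digests of the input files. `write_run_manifest` and `sha256_file` are hypothetical helpers, not part of the pipeline:

```python
import hashlib
import json
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, Iterable, Optional


def sha256_file(path: Path, chunk_size: int = 1 << 20) -> Optional[str]:
    """Return the hex SHA-256 of a file, or None if it does not exist."""
    if not path.exists():
        return None
    h = hashlib.sha256()
    with open(path, "rb") as fh:
        # Stream in 1 MiB chunks so large TSVs are not loaded into memory
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def write_run_manifest(manifest_path: Path, run_id: str,
                       params: Dict[str, object],
                       inputs: Iterable[Path]) -> dict:
    """Hypothetical helper: record run provenance, merging into any existing manifest."""
    manifest = {}
    if manifest_path.exists():
        manifest = json.loads(manifest_path.read_text())
    manifest.update({
        "run_id": run_id,
        "created_utc": datetime.now(timezone.utc).isoformat(),
        "params": params,
        "inputs": {str(p): sha256_file(Path(p)) for p in inputs},
    })
    manifest_path.parent.mkdir(parents=True, exist_ok=True)
    # default=str lets Path objects inside params serialize cleanly
    manifest_path.write_text(json.dumps(manifest, indent=2, default=str))
    return manifest
```

In the notebook this could be called once after Cell 12b, for example `write_run_manifest(RUN_MANIFEST_JSON, RUN_ID, {"MIN_LEN_AA": MIN_LEN_AA, "TARGET_TAXID": TARGET_TAXID}, [UNIPROT_TSV_GZ])`. Because it merges into an existing file, keys added by later cells are preserved.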

In [None]:
# ===== Cell 12a =====
# Central configuration panel and dynamic directory setup.

import os, re, io, time, json, hashlib, logging
from pathlib import Path
from datetime import datetime, timezone
from typing import Optional, List, Dict, Tuple
import pandas as pd
import numpy as np
import requests
import sys
from Bio import Entrez
from tqdm.notebook import tqdm

# ---------------- User-Configurable Parameters ----------------
#@markdown #### **Core Settings**
PROJECT_NAME = "CollagenExonMapper"  #@param {type:"string"}
USER_EMAIL   = "matthew@palaeome.org"  #@param {type:"string"}
Entrez.email = USER_EMAIL

#@markdown #### **Gene Selection**
PROCESS_ALL_GENES = True  #@param {type:"boolean"}
GENE_SYMBOLS     = "COL1A1,COL1A2"  #@param {type:"string"}

#@markdown #### **Taxonomic Filtering**
CLADE_TAXIDS = {
    "Metazoa": 33208, "Vertebrata": 7742, "Mammalia": 40674, "Aves": 8782,
    "Reptilia": 8504, "Amphibia": 8292, "Tetrapoda": 32523,
    "Bony fish": 117570, "Cartilaginous fish": 7777, "Catarrhini": 9526
}
TAXONOMIC_FILTER_NAME = "Metazoa" #@param ["Metazoa","Vertebrata","Mammalia","Aves","Reptilia","Amphibia","Tetrapoda","Bony fish","Cartilaginous fish","Catarrhini"]
TARGET_TAXID = CLADE_TAXIDS[TAXONOMIC_FILTER_NAME]

#@markdown #### **Thresholds (Sequences & Mapping)**
MIN_LEN_AA          = 600  #@param {type:"integer"}
MIN_GXY_TRIPLETS    = 30   #@param {type:"integer"}
CHAIN_LENGTH_THRESHOLD      = 0.90  #@param {type:"slider", min:0.5, max:1.0, step:0.05}
MAX_ALLOWED_GAP_PERCENTAGE  = 0.10  #@param {type:"slider", min:0.05, max:0.5, step:0.05}

# Rescue controls

#@markdown #### **Identify & rescue additional sequences (FAST; peptide-only)**
ENABLE_RESCUE             = True   #@param {type:"boolean"}
MINIMUM_RESCUE_SCORE      = 0.50   #@param {type:"number"}
RESCUE_SCORE_IMPROVEMENT  = 0.10   #@param {type:"number"}
GXY_CONTENT_THRESHOLD     = 0.80   #@param {type:"number"}

#@markdown #### **DNA fetch for rescue (SLOW; Ensembl REST)**
ENABLE_DNA_FETCH          = False  #@param {type:"boolean"}
ENSEMBL_REQUESTS_PER_SEC  = 5      #@param {type:"integer"}
ENSEMBL_TIMEOUT_SECS      = 20     #@param {type:"integer"}
ENSEMBL_BASE              = "https://rest.ensembl.org"  #@param {type:"string"}

#@markdown #### **Debugging**
DEBUG_SAMPLE_SIZE = -1  #@param {type:"integer"}

# Optional taxonomy engine (not required when UniProt lineage present)
USE_TAXONOMY_ENGINE = False  #@param {type:"boolean"}


In [None]:
# ===== Cell 12b =====
# Canonical path config for _SHARED_DATA/ExonMaps/GeneFamily/collagens
# Forward-only (no legacy aliases). All later parts must import from these.

import logging, sys
from pathlib import Path
from datetime import datetime, timezone

# --- roots ---
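# Anchor to Google Drive when mounted (Cell 11) so caches survive runtime resets.
# Assumes _SHARED_DATA sits at the top level of "My Drive"; edit DRIVE_BASE if not.
DRIVE_BASE       = Path("/content/drive/MyDrive")
SHARED_DATA_ROOT = (DRIVE_BASE if globals().get("IN_COLAB") else Path.cwd()) / "_SHARED_DATA"
EXONMAPS_ROOT    = SHARED_DATA_ROOT / "ExonMaps" / "GeneFamily" / "collagens"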

# --- per-run (scratch/provenance only; not authoritative) ---
RUN_ID = datetime.now(timezone.utc).strftime("run_%Y%m%d_%H%M%S")
RUN_DIR = EXONMAPS_ROOT / "runs" / RUN_ID

# --- subdirs (authoritative) ---
INPUTS_DIR    = EXONMAPS_ROOT / "inputs"
CACHE_DIR     = EXONMAPS_ROOT / "cache"
CONSENSUS_DIR = EXONMAPS_ROOT / "consensus"
OUTPUTS_DIR   = EXONMAPS_ROOT / "outputs"
REGEX_DIR     = EXONMAPS_ROOT / "regex"
RESCUE_DIR    = EXONMAPS_ROOT / "rescue"
LOGS_DIR      = EXONMAPS_ROOT / "logs"
MANIFESTS_DIR = EXONMAPS_ROOT / "manifests"

# ensure dirs
for p in [EXONMAPS_ROOT, RUN_DIR, INPUTS_DIR, CACHE_DIR, CONSENSUS_DIR,
          OUTPUTS_DIR, REGEX_DIR, RESCUE_DIR, LOGS_DIR, MANIFESTS_DIR]:
    p.mkdir(parents=True, exist_ok=True)

# --- inputs (authoritative) ---
UNIPROT_TSV_GZ              = INPUTS_DIR / "uniprot_collagens.tsv.gz"
FILTERED_UNIPROT_TSV        = INPUTS_DIR / "filtered_uniprot_collagens.tsv"
ADDED_FASTA_PATH            = INPUTS_DIR / "added_fastas.fasta"

# --- caches (authoritative) ---
EXON_CACHE_TSV              = CACHE_DIR   / "raw_exons_cache.tsv"
REJECTED_IDS_TSV            = CACHE_DIR   / "rejected_ids.tsv"
TIMETREE_CACHE_JSON         = CACHE_DIR   / "timetree_distance_cache.json"

# --- consensus snapshots (per-run) ---
CONSENSUS_LONG_TSV          = CONSENSUS_DIR / f"consensus_long_{RUN_ID}.tsv"
CONSENSUS_TABLE_TSV         = CONSENSUS_DIR / f"consensus_table_{RUN_ID}.tsv"

# --- final outputs (authoritative) ---
EXON_WIDE_TSV               = OUTPUTS_DIR / f"exon_wide_{RUN_ID}.tsv"
ENTROPY_STATS_TSV           = OUTPUTS_DIR / f"entropy_stats_{RUN_ID}.tsv"
MRCA_RELIABILITY_TSV        = OUTPUTS_DIR / f"mrca_reliability_{RUN_ID}.tsv"

# --- RegExTractor artifacts ---
REGEX_BANK_JSON             = REGEX_DIR   / "regex_bank.json"
REGEX_TRAINING_MANIFEST_JSON= REGEX_DIR   / "regex_training_manifest.json"
REX_RESCUE_LOG_TSV          = RESCUE_DIR  / f"rescue_log_{RUN_ID}.tsv"
REX_RESCUED_EXONS_TSV       = RESCUE_DIR  / f"rescued_exons_{RUN_ID}.tsv"

# --- per-run scratch ---
RUN_LOG_PATH                = RUN_DIR     / "run.log"
RUN_MANIFEST_JSON           = RUN_DIR     / "manifest.json"

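# --- logging: console + per-run log file (later cells use `logger`) ---
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
    handlers=[logging.StreamHandler(sys.stdout), logging.FileHandler(RUN_LOG_PATH)],
    force=True,  # replace handlers left over from earlier runs of this cell
)
logger = logging.getLogger(PROJECT_NAME)

print("[Paths] EXONMAPS_ROOT:", EXONMAPS_ROOT)
print("[Paths] RUN_DIR:", RUN_DIR)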


## Cell 13 – (Optional) Taxonomy engine initialization

For **FASTA-only** inputs or lineage QA you can enable an ETE3-backed engine.
By default we rely on UniProt lineage columns and skip ETE3.

In [None]:
# ===== Cell 13 =====
# Optional taxonomy engine (no-op by default)
class NCBITaxonomyEngine:
    def __init__(self, enabled: bool, ncbi_db_path: Optional[Path] = None):
        self._ok = False; self._source = "uniprot_lineage_only"; self.ncbi=None
        if not enabled: return
        try:
            from ete3 import NCBITaxa as _ETE_NCBITaxa  # type: ignore
            if ncbi_db_path and ncbi_db_path.exists():
                self.ncbi = _ETE_NCBITaxa(dbfile=str(ncbi_db_path))
                self._source = f"ete3:{ncbi_db_path}"
            else:
                self.ncbi = _ETE_NCBITaxa(); self._source = "ete3:default"
            _ = self.ncbi.get_lineage(1); self._ok = True
        except Exception as e:
            logger.info(f"ETE3 taxonomy disabled/unavailable: {e}")
    def ok(self) -> bool: return self._ok

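# Optional path to an existing ETE3 database file (e.g. taxa.sqlite); None lets ETE3 use its default
_tax_db = globals().get("DRIVE_TAXONOMY_PATH")
tax_engine = NCBITaxonomyEngine(USE_TAXONOMY_ENGINE, Path(_tax_db) if _tax_db else None)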
logger.info(f"Taxonomy mode: {'ETE3' if tax_engine.ok() else 'UniProt-lineage only'}")

# **Part 2: Incremental Data Loading & Profiling**

## Cell 21 – Unified import, normalization, and cache update

Loads prior cache, ingests new `/content/*.tsv`, normalizes gene symbols,
filters by clade, and snapshots the working dataset.
If nothing is found in `/content`, we proceed with cached data.
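The incremental update concatenates cached rows with newly found rows and de-duplicates on `Entry`, keeping the last occurrence. A freshly ingested row therefore replaces its cached copy. Below is a minimal sketch of that rule on toy data. The accessions and sequences are made up, and `merge_incremental` is a hypothetical name:

```python
import pandas as pd


def merge_incremental(cached: pd.DataFrame, new: pd.DataFrame,
                      key: str = "Entry") -> pd.DataFrame:
    """Append new rows to the cache; on duplicate keys the new row wins."""
    if new.empty or key not in new.columns:
        # Mirrors Cell 21: inputs without an 'Entry' column are not merged
        return cached.copy()
    merged = pd.concat([cached, new], ignore_index=True)
    return merged.drop_duplicates(subset=[key], keep="last").reset_index(drop=True)


cached = pd.DataFrame({"Entry": ["A1", "B2"], "Sequence": ["GPP", "GAP"]})
new = pd.DataFrame({"Entry": ["B2", "C3"], "Sequence": ["GAPGAP", "GEP"]})
merged = merge_incremental(cached, new)
print(merged)
```

Because `keep="last"` depends on row order, the cache must come before the new rows in the `concat` for new data to take precedence.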

In [None]:
# ===== Cell 21 =====
# Unified import, normalization, and raw cache update

def roman_to_int(s: str) -> int:
    m = {'I':1,'V':5,'X':10,'L':50,'C':100,'D':500,'M':1000}
    s = re.sub(r"[^IVXLCDM]", "", s.upper()); total = 0
    for i, ch in enumerate(s):
        v = m.get(ch, 0)
        if i+1 < len(s) and m.get(s[i+1],0) > v: total -= v
        else: total += v
    return total

def infer_specific_collagen_symbol(name: str) -> Optional[str]:
    if not isinstance(name, str) or 'COLLAGEN' not in name.upper(): return None
    rx = re.search(r"ALPHA[ -]?(\d+)\s*\((.*?)\)|TYPE\s+([IVXLCDM]+).*?ALPHA\s+(\d+)",
                   name.upper())
    if not rx: return None
    chain_num = rx.group(1) or rx.group(4)
    type_roman = rx.group(2) or rx.group(3)
    if not chain_num or not type_roman: return None
    t = roman_to_int(type_roman)
    return f"COL{t}A{chain_num}"

def enhanced_gene_normalization(gene: str, pname: str) -> str:
    if isinstance(gene, str) and gene.strip():
        g = gene.strip().upper()
        if re.match(r"^COL\d+A\d+$", g): return g
    if isinstance(pname, str):
        guess = infer_specific_collagen_symbol(pname)
        if guess: return guess
    if isinstance(pname, str) and 'COLLAGEN' in pname.upper():
        return "PROBABLE_COLLAGEN"
    return "UNKNOWN"

def load_new_inputs_from_content() -> pd.DataFrame:
    rows = []; content_dir = Path("/content").resolve()
    for p in content_dir.glob("*.tsv"):
        try:
            df = pd.read_csv(p, sep='\t', low_memory=False)
            df["source_file"] = p.name; rows.append(df)
            logger.info(f"Found input: {p.name} ({len(df)} rows)")
        except Exception as e:
            logger.warning(f"Could not read {p}: {e}")
    return pd.concat(rows, ignore_index=True) if rows else pd.DataFrame()

def safe_read_tsv(path: Path) -> pd.DataFrame:
    try: return pd.read_csv(path, sep='\t', low_memory=False)
    except Exception: return pd.DataFrame()

# Load architecture JSON if present (optional; define DRIVE_ARCHITECTURES_PATH to enable)
COLLAGEN_LENGTH_DATA = {}
_arch_path = globals().get("DRIVE_ARCHITECTURES_PATH")
if _arch_path and Path(_arch_path).exists():
    try:
        with open(_arch_path) as fh:
            COLLAGEN_LENGTH_DATA = json.load(fh)
        COLLAGEN_LENGTH_DATA = {k.upper(): v for k, v in COLLAGEN_LENGTH_DATA.items()}
        logger.info("Architecture JSON loaded.")
    except Exception as e:
        logger.warning(f"Architecture JSON load failed: {e}")

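# Load optional master table (define MASTER_TSV_PATH to use one), else the merged cache
full_df = pd.DataFrame()
_master_path = globals().get("MASTER_TSV_PATH")
if _master_path and Path(_master_path).exists(): full_df = safe_read_tsv(Path(_master_path))
if full_df.empty and FILTERED_UNIPROT_TSV.exists():
    full_df = safe_read_tsv(FILTERED_UNIPROT_TSV)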

# Load rejected list (TSV with Entry, reason, run_id)
if REJECTED_IDS_TSV.exists():
    rejected_master = safe_read_tsv(REJECTED_IDS_TSV)
    rejected_set = set(rejected_master.get("Entry", pd.Series(dtype=str)).dropna())
else:
    rejected_master = pd.DataFrame(columns=["Entry","reason","run_id"]); rejected_set = set()
logger.info(f"Loaded cache: {len(full_df)} rows; rejected: {len(rejected_set)}")

# Load new data from /content
new_df = load_new_inputs_from_content()
if not new_df.empty:
    if 'Entry' not in new_df.columns:
        logger.warning("New TSV missing 'Entry' column; skipping merge.")
    else:
        full_df = pd.concat([full_df, new_df], ignore_index=True)

if not full_df.empty:
    full_df.drop_duplicates(subset=['Entry'], keep='last', inplace=True)
    pname = full_df.get('Protein names', pd.Series(['']*len(full_df)))
    full_df['gene_symbol_norm'] = [
        enhanced_gene_normalization(g, n) for g, n in
        zip(full_df.get('Gene Names (primary)', pd.Series(['']*len(full_df))), pname)
    ]
    if PROCESS_ALL_GENES:
        keep_gene = full_df['gene_symbol_norm'].str.contains(r"^COL\d+A\d+$|PROBABLE_COLLAGEN", na=False)
    else:
        targets = [g.strip().upper() for g in GENE_SYMBOLS.split(',')]
        keep_gene = full_df['gene_symbol_norm'].isin(targets)
    keep_tax = True
    if 'Taxonomic lineage (Ids)' in full_df.columns:
        # Whole-ID match: a plain substring test would let 7742 match 77420
        keep_tax = full_df['Taxonomic lineage (Ids)'].astype(str).str.contains(
            rf"\b{TARGET_TAXID}\b", regex=True, na=False)
    mask = keep_gene & keep_tax
    working_df = full_df.loc[mask].copy()
else:
    working_df = pd.DataFrame()

if DEBUG_SAMPLE_SIZE and DEBUG_SAMPLE_SIZE > 0 and not working_df.empty:
    working_df = working_df.head(DEBUG_SAMPLE_SIZE).copy()
    logger.info(f"DEBUG mode: sampling first {len(working_df)} rows.")

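# Per-run snapshot of the working set (default file name chosen here; not defined in Cell 12b)
WORKING_SNAPSHOT = Path(globals().get("WORKING_SNAPSHOT", RUN_DIR / f"working_snapshot_{RUN_ID}.tsv"))
if not working_df.empty:
    working_df.to_csv(WORKING_SNAPSHOT, sep='\t', index=False)
if not full_df.empty:  # never overwrite the cache with an empty table
    full_df.to_csv(FILTERED_UNIPROT_TSV, sep='\t', index=False)
logger.info(f"Working rows: {len(working_df)} (snapshot saved)")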

## Cell 24 – Taxonomic lineage expansion (standardized ranks + cluster keys)

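Parses UniProt lineage columns into seven rank slots (`King_id` through
`Spec_id`, filled positionally from the last seven lineage IDs), infers
family and order names from their suffixes, and emits clustering keys used
for phylogeny-weighted consensus and regex derivation.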
No ETE3 dependency needed.
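Filling rank slots by position can misassign ranks, because UniProt lineages include many unranked clades. Your export's `Taxonomic lineage (Ids)` column may carry rank annotations in the form `ID (rank), ID (rank), ...`. This is an assumption, so check a few rows of your TSV. If it does, a rank-aware parse is straightforward, and it also gives an exact whole-ID test for clade membership. `parse_ranked_lineage`, `in_clade`, and `RANK_MAP` are hypothetical names, and the example string is illustrative:

```python
import re
from typing import Dict, List, Tuple

# Assumed input format: "7711 (phylum), 7742 (no rank), 40674 (class), ..."
_ID_RANK = re.compile(r"(\d+)\s*\(([^)]+)\)")

# Map NCBI rank names onto the notebook's RANK_CODES
RANK_MAP = {"kingdom": "King", "phylum": "Phyl", "class": "Clas",
            "order": "Orde", "family": "Fami", "genus": "Genu", "species": "Spec"}


def parse_ranked_lineage(s: str) -> Tuple[List[int], Dict[str, int]]:
    """Return all lineage IDs plus a {rank_code: taxid} map for canonical ranks."""
    ids: List[int] = []
    ranks: Dict[str, int] = {}
    if not isinstance(s, str):
        return ids, ranks
    for taxid, rank in _ID_RANK.findall(s):
        ids.append(int(taxid))
        code = RANK_MAP.get(rank.strip().lower())
        if code:  # unranked clades ("no rank", "clade") are skipped
            ranks[code] = int(taxid)
    return ids, ranks


def in_clade(s: str, target_taxid: int) -> bool:
    """Exact membership test: compares whole IDs, so 7742 never matches 77420."""
    ids, _ = parse_ranked_lineage(s)
    return target_taxid in ids


example = ("33208 (kingdom), 7711 (phylum), 7742 (no rank), 40674 (class), "
           "9443 (order), 9604 (family), 9605 (genus)")
ids, ranks = parse_ranked_lineage(example)
print(ranks)
```

Applied to the dataframe, this would be `working_df["Taxonomic lineage (Ids)"].apply(lambda s: in_clade(s, TARGET_TAXID))`.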

In [None]:
# ===== Cell 24 =====
# Taxonomic lineage expansion (standardized ranks + cluster keys)

RANK_CODES = ["King","Phyl","Clas","Orde","Fami","Genu","Spec"]

def enhanced_parse_taxonomic_lineage(lineage_ids_str: str) -> Dict[str, Optional[int]]:
    out = {k: None for k in RANK_CODES}
    if not isinstance(lineage_ids_str, str): return out
    ids = [int(x) for x in re.findall(r"\d+", lineage_ids_str)]
    if not ids: return out
    tail = ids[-7:] if len(ids) >= 7 else ids
    tail = [None] * (7 - len(tail)) + tail
    for code, taxid in zip(RANK_CODES, tail): out[code] = taxid
    return out

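def _split_uniprot_lineage_names(s: str) -> List[str]:
    if not isinstance(s, str) or not s.strip(): return []
    parts = [p.strip() for p in re.split(r"[;|,]\s*", s.strip()) if p.strip()]
    # Drop a trailing " (rank)" annotation if present, e.g. "Hominidae (family)"
    return [re.sub(r"\s*\([^)]*\)$", "", p) for p in parts]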

if not working_df.empty:
    # Names & IDs lists (if available)
    working_df["Lineage_Names"] = (
        working_df["Taxonomic lineage"].apply(_split_uniprot_lineage_names)
        if "Taxonomic lineage" in working_df.columns else [[]]*len(working_df)
    )
    working_df["Lineage_IDs_raw"] = (
        working_df["Taxonomic lineage (Ids)"].astype(str)
        if "Taxonomic lineage (Ids)" in working_df.columns else [""]*len(working_df)
    )

    parsed = working_df["Lineage_IDs_raw"].apply(enhanced_parse_taxonomic_lineage)
    for code in RANK_CODES:
        # Nullable integers keep IDs as "7742" rather than "7742.0" when some are missing
        working_df[f"{code}_id"] = parsed.apply(lambda d: d.get(code)).astype("Int64")

    # Genus/Species from Organism string (e.g. "Homo sapiens (Human)")
    if 'Organism' in working_df.columns:
        org = working_df['Organism'].fillna("").astype(str)
        working_df['Genus'] = org.str.split().str[0]  # NaN when Organism is blank
        working_df['Species'] = org
    else:
        working_df['Genus'] = np.nan

    # Cluster keys
    nz = lambda x: x if isinstance(x, str) and x.strip() else "NA"
    genus_name_from_lineage = working_df['Lineage_Names'].apply(lambda L: L[-1] if L else "")
    # Prefer the Organism-derived genus; fall back to the last lineage name
    working_df['Genus_name'] = working_df['Genus'].fillna(genus_name_from_lineage)
    working_df['cluster_genus'] = working_df['Genus_name'].apply(nz)

    def _infer_family(L: List[str]) -> str:
        if not L: return ""
        for n in reversed(L):
            if n.endswith("idae"): return n
        return ""
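    def _infer_order(L: List[str]) -> str:
        # Suffix heuristic: catches "-iformes" orders (common in birds and fishes)
        # but not orders such as Primates or Rodentia, which are left as "".
        if not L: return ""
        for n in reversed(L):
            if n.endswith("iformes"): return n
        return ""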
    working_df['Family'] = working_df['Lineage_Names'].apply(_infer_family)
    working_df['Order']  = working_df['Lineage_Names'].apply(_infer_order)

    working_df['cluster_family_genus'] = working_df.apply(
        lambda r: f"{nz(r['Family'])}|{nz(r['Genus_name'])}", axis=1)
    working_df['cluster_class_family_genus'] = working_df.apply(
        lambda r: f"{'' if pd.isna(r['Clas_id']) else r['Clas_id']}|{nz(r['Family'])}|{nz(r['Genus_name'])}", axis=1)

    logger.info(
        f"Taxonomy expanded: N={len(working_df)} → rank IDs + cluster keys ready.")
else:
    logger.info("No working rows; skipping taxonomy enrichment.")

## Cell 25 – Final input checks & sampling

Ensures required columns exist and applies DEBUG sampling (if set).
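Both Cell 21 and this cell sample with `head()`, which takes rows in file order and can over-represent whichever gene happens to come first. One option for a more representative debug subset is a seeded, per-gene sample. This is a hedged alternative, not what the cell does. The `debug_sample` name, the `seed` parameter, and grouping on `gene_symbol_norm` are all assumptions:

```python
import pandas as pd


def debug_sample(df: pd.DataFrame, n: int, by: str = "gene_symbol_norm",
                 seed: int = 0) -> pd.DataFrame:
    """Draw up to n rows, spread roughly evenly across groups in `by`."""
    if n <= 0 or df.empty:
        return df
    if by not in df.columns:
        # No grouping column: plain seeded sample
        return df.sample(n=min(n, len(df)), random_state=seed)
    n_groups = df[by].nunique()
    per_group = max(1, n // max(1, n_groups))
    parts = [g.sample(n=min(per_group, len(g)), random_state=seed)
             for _, g in df.groupby(by, sort=True)]
    return pd.concat(parts).head(n)
```

Fixing `random_state` keeps the subset identical across reruns, which matters when debugging cache behaviour.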

In [None]:
# ===== Cell 25 =====
# Final filters and input checks
req = ['Entry','Sequence']
missing = [c for c in req if c not in working_df.columns]
if missing:
    logger.warning(f"Missing required columns: {missing}")
else:
    working_df = working_df[working_df['Sequence'].fillna('').astype(str).str.len() > 0]
    logger.info(f"Post-filter rows: {len(working_df)}")
if DEBUG_SAMPLE_SIZE and DEBUG_SAMPLE_SIZE > 0 and not working_df.empty:
    working_df = working_df.head(DEBUG_SAMPLE_SIZE).copy()
    logger.info(f"DEBUG mode: sampling first {len(working_df)} rows.")

## Cell 26 – Persist session rejected IDs

Writes run-level rejections (filtered by gene/taxon) and merges them into
the master `rejected_ids.tsv` with reasons and `run_id`.
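The master table is meant to hold one row per accession with `Entry`, `reason`, and `run_id`, the same shape that Cell 21 and Cell 32 read. Below is a minimal sketch of that merge as a standalone function. It assumes later reasons overwrite earlier ones (`keep='last'`, as in Cell 32), and `merge_rejections` is a hypothetical name:

```python
from pathlib import Path
from typing import Iterable

import pandas as pd

COLS = ["Entry", "reason", "run_id"]


def merge_rejections(master_path: Path, entries: Iterable[str],
                     reason: str, run_id: str) -> pd.DataFrame:
    """Merge newly rejected accessions into the master TSV and rewrite it."""
    if master_path.exists():
        master = pd.read_csv(master_path, sep="\t", dtype=str)
    else:
        master = pd.DataFrame(columns=COLS)
    add = pd.DataFrame({"Entry": sorted(set(entries))})
    add["reason"] = reason
    add["run_id"] = run_id
    # One row per accession; the most recent reason wins
    merged = (pd.concat([master, add], ignore_index=True)
                .drop_duplicates(subset=["Entry"], keep="last")
                .reset_index(drop=True))
    master_path.parent.mkdir(parents=True, exist_ok=True)
    merged[COLS].to_csv(master_path, sep="\t", index=False)
    return merged[COLS]
```

Keeping a single tabular format matters here. A plain one-ID-per-line file would be read by `pd.read_csv` with the first accession as its header, and the `Entry` column would silently disappear.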

In [None]:
# ===== Cell 26 =====
# Identify entries not included in working_df, save them for this run,
# and merge them into the master rejection list in CACHE_DIR.

logging.info("--- Archiving Rejected Entries for This Session ---")

# Initialize session rejected IDs set
session_rejected_ids = set()

# Check if we have the required dataframes
if 'full_df' in globals() and not full_df.empty:
    if 'working_df' in globals() and not working_df.empty:
        # Calculate rejected entries for this session
        all_entries = set(full_df['Entry'].dropna().unique())
        acc_accepted = set(working_df['Entry'].dropna().unique())
        session_rejected_ids = all_entries - acc_accepted

        if session_rejected_ids:
            logging.info(
                f"   {len(session_rejected_ids):,} entries rejected this session."
            )

            # Use canonical path configuration - save to run directory
            rejected_path = RUN_DIR / f"rejected_ids_{RUN_ID}.txt"

            try:
                # Ensure directory exists
                rejected_path.parent.mkdir(parents=True, exist_ok=True)

                # Write rejected IDs to file
                with open(rejected_path, 'w') as f:
                    for acc in sorted(list(session_rejected_ids)):
                        f.write(f"{acc}\n")

                logging.info(f"   ✅ Rejected list saved to: {rejected_path}")

                # Also merge into the master rejection TSV (Entry, reason, run_id),
                # the same format that Cell 21 and Cell 32 read and write.
                master_rejected_path = REJECTED_IDS_TSV
                rej_cols = ["Entry", "reason", "run_id"]
                existing = pd.DataFrame(columns=rej_cols)
                if master_rejected_path.exists():
                    try:
                        existing = pd.read_csv(master_rejected_path, sep='\t', dtype=str)
                        if "Entry" not in existing.columns:
                            # Legacy plain list (one accession per line, no header)
                            existing = pd.read_csv(master_rejected_path, sep='\t', header=None,
                                                   names=["Entry"], dtype=str)
                        logging.info(f"   Loaded {len(existing)} existing rejected IDs from master list")
                    except Exception as e:
                        logging.warning(f"   Could not load existing rejected IDs: {e}")

                session_df = pd.DataFrame({"Entry": sorted(session_rejected_ids)})
                session_df["reason"] = "filtered_by_gene_or_taxon"
                session_df["run_id"] = RUN_ID

                # Save updated master rejected list
                try:
                    # Create backup of existing master list
                    if master_rejected_path.exists():
                        backup_path = CACHE_DIR / f"rejected_ids_backup_{RUN_ID}.tsv"
                        import shutil
                        shutil.copy2(master_rejected_path, backup_path)
                        logging.info(f"   Master rejected IDs backed up to: {backup_path}")

                    before = set(existing["Entry"].dropna()) if "Entry" in existing.columns else set()
                    merged = (pd.concat([existing, session_df], ignore_index=True)
                                .drop_duplicates(subset=["Entry"], keep="last"))
                    merged.reindex(columns=rej_cols).to_csv(master_rejected_path, sep='\t', index=False)

                    new_rejections = len(set(merged["Entry"]) - before)
                    logging.info(f"   ✅ Master rejected IDs updated: {len(merged)} total ({new_rejections} new)")

                except Exception as e:
                    logging.error(f"   Failed to update master rejected IDs: {e}")

            except Exception as e:
                logging.error(f"   Could not save rejected IDs for this run: {e}")
                # Set empty set to prevent issues downstream
                session_rejected_ids = set()

        else:
            logging.info("   No rejections this session.")

    else:
        if 'working_df' not in globals():
            logging.info("   working_df missing; cannot compute rejections.")
        else:
            logging.info("   working_df is empty; all entries were rejected.")
            # If working_df is empty, all entries in full_df were rejected
            session_rejected_ids = set(full_df['Entry'].dropna().unique())

            if session_rejected_ids:
                logging.info(f"   All {len(session_rejected_ids):,} entries were rejected.")

                # Save these rejected IDs
                rejected_path = RUN_DIR / f"rejected_ids_{RUN_ID}.txt"
                try:
                    rejected_path.parent.mkdir(parents=True, exist_ok=True)
                    with open(rejected_path, 'w') as f:
                        for acc in sorted(list(session_rejected_ids)):
                            f.write(f"{acc}\n")
                    logging.info(f"   ✅ All rejected IDs saved to: {rejected_path}")
                except Exception as e:
                    logging.error(f"   Could not save rejected IDs: {e}")
else:
    logging.info("   full_df empty or missing; cannot compute rejections.")

# Create summary statistics
rejection_summary = {
    'run_id': RUN_ID,
    'session_rejected_count': len(session_rejected_ids),
    'timestamp': datetime.now(timezone.utc).isoformat(),
}

# Log summary
if session_rejected_ids:
    logging.info(f"   📊 Session Summary: {len(session_rejected_ids)} entries rejected")

    # Save summary to run manifest if it exists
    try:
        if RUN_MANIFEST_JSON.exists():
            import json
            with open(RUN_MANIFEST_JSON, 'r') as f:
                manifest = json.load(f)
        else:
            manifest = {}

        manifest['rejection_summary'] = rejection_summary

        with open(RUN_MANIFEST_JSON, 'w') as f:
            json.dump(manifest, f, indent=2)

        logging.info(f"   ✅ Rejection summary added to run manifest")

    except Exception as e:
        logging.warning(f"   Could not update run manifest: {e}")

# This variable is used in Cell 73 to update the master rejection list
# (Note: session_rejected_ids is now available for downstream cells)

logging.info("--- Rejection Processing Complete ---")

In [None]:
# # ===== Old Cell 26 =====
# # Persist session rejected IDs (with reasons)
# session_rejected = pd.DataFrame(columns=["Entry","reason","run_id"])
# if 'full_df' in globals() and not full_df.empty:
#     all_ids = set(full_df.get('Entry', pd.Series(dtype=str)))
#     acc = set(working_df.get('Entry', pd.Series(dtype=str)))
#     rejected = sorted(all_ids - acc)
#     if rejected:
#         session_rejected = pd.DataFrame({
#             "Entry": rejected,
#             "reason": "filtered_by_gene_or_taxon",
#             "run_id": RUN_ID
#         })
#         session_rejected.to_csv(REJECTED_SNAPSHOT, sep='\t', index=False)
#         logger.info(f"Rejected this run: {len(session_rejected)} (snapshot)")
#     else:
#         logger.info("No session rejections.")

# if REJECTED_IDS_PATH.exists():
#     master_rej = pd.read_csv(REJECTED_IDS_PATH, sep='\t', low_memory=False)
# else:
#     master_rej = pd.DataFrame(columns=["Entry","reason","run_id"])
# master_rej = pd.concat([master_rej, session_rejected], ignore_index=True)
# master_rej.drop_duplicates(subset=['Entry'], keep='last', inplace=True)
# master_rej.to_csv(REJECTED_IDS_PATH, sep='\t', index=False)
# logger.info("Master rejected TSV updated.")

# **Part 3: Chain Identification & Quality Pre-selection**

## Cell 31 – Detect main G–X–Y chain(s)

Triplet-aligned scanning (step=3), minimal Cys-in-helix flag, and a simple
0–100 quality score used later in reliability metrics.
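The coordinate convention matters for the triplet count used in QC. With a 1-based inclusive `start`/`end`, a segment spans `end - start + 1` residues. Below is a compact regex-based sketch of the same idea: a run of at least `min_triplets` consecutive G-X-Y triplets. It is an alternative illustration rather than the cell's exact scanner, because regex backtracking picks reading frames slightly differently. `gxy_runs` is a hypothetical name:

```python
import re
from typing import Dict, List


def gxy_runs(seq: str, min_triplets: int) -> List[Dict[str, int]]:
    """Return G-X-Y runs as 1-based inclusive coordinates plus a triplet count."""
    pattern = re.compile(r"(?:G..){%d,}" % min_triplets)
    runs = []
    for m in pattern.finditer(str(seq)):
        start, end = m.start() + 1, m.end()   # convert to 1-based, inclusive
        runs.append({"start": start, "end": end,
                     "triplets": (end - start + 1) // 3})
    return runs


seq = "MA" + "GPP" * 35 + "KK" + "GAP" * 5
print(gxy_runs(seq, 30))
```

Slicing `seq[start - 1:end]` recovers exactly the matched helix, which is a quick way to check any coordinate convention.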

In [None]:
# ===== Cell 31 =====
# Identify main G–X–Y chain(s)

def find_gxy_segments(seq: str, min_triplets: int) -> List[Dict]:
    seq = str(seq); runs: List[Tuple[int,int]] = []
    i = 0
    while i <= max(0, len(seq)-3):
        if seq[i] == 'G' and i+2 < len(seq):
            j = i
            while j+2 < len(seq) and seq[j] == 'G':
                j += 3
            if (j - i) % 3 != 0: j -= ((j - i) % 3)
            triplets = (j - i) // 3
            if triplets >= min_triplets: runs.append((i+1, j))
            i = j
        else:
            i += 1
    return [{"start": s, "end": e} for s, e in runs]

def has_cys_in_range(seq: str, start: int, end: int) -> bool:
    s0 = max(0, start-1); e0 = min(len(seq), end)
    return 'C' in seq[s0:e0]

main_rows = []
if not working_df.empty and 'Sequence' in working_df.columns:
    for _, r in working_df.iterrows():
        segs = find_gxy_segments(r['Sequence'], MIN_GXY_TRIPLETS)
        if segs:
            seg = max(segs, key=lambda x: x['end']-x['start'])
            triplets = (seg['end'] - seg['start'] + 1) // 3   # start/end are 1-based inclusive
            # crude 0-100 scaling relative to MIN_GXY_TRIPLETS
            qscore = float(min(100.0, 100.0 * triplets / max(1, MIN_GXY_TRIPLETS)))
            flag_cys = has_cys_in_range(r['Sequence'], seg['start'], seg['end'])
            main_rows.append({
                **r,
                "main_chain_segments": [seg],
                "quality_score": qscore,
                "quality_flags": "cys_in_helix" if flag_cys else ""
            })
chain_df = pd.DataFrame(main_rows) if main_rows else pd.DataFrame()
logger.info(f"Candidates with main chain: {len(chain_df)}")

## Cell 32 – QC pass/fail & persistence

Applies length/GXY/Cys checks, logs tallies, and merges QC failures into
the master rejection table, recording the specific failure
(`no_main_chain`, `short_or_low_gxy`, or `cys_in_helix`) as `reason`.

In [None]:
# ===== Cell 32 =====
# QC pass/fail and rejection persistence
df_high_quality = pd.DataFrame(); df_failed_qc = pd.DataFrame()
if not chain_df.empty:
    pass_rows, fail_rows = [], []
    for _, r in chain_df.iterrows():
        seg = r['main_chain_segments'][0] if r.get('main_chain_segments') else None
        if not seg:
            r2 = r.copy(); r2['failure_reasons'] = 'no_main_chain'; fail_rows.append(r2); continue
        triplets = (seg['end'] - seg['start'] + 1) // 3   # start/end are 1-based inclusive
        if triplets < MIN_GXY_TRIPLETS or len(str(r.get('Sequence',''))) < MIN_LEN_AA:
            r2 = r.copy(); r2['failure_reasons'] = 'short_or_low_gxy'; fail_rows.append(r2); continue
        if 'cys_in_helix' in r.get('quality_flags',''):
            r2 = r.copy(); r2['failure_reasons'] = 'cys_in_helix'; fail_rows.append(r2); continue
        pass_rows.append(r)
    df_high_quality = pd.DataFrame(pass_rows)
    df_failed_qc = pd.DataFrame(fail_rows)
    logger.info(f"QC pass: {len(df_high_quality)}; fail: {len(df_failed_qc)}")
else:
    logger.info("No chain candidates to QC.")

if not df_failed_qc.empty and 'Entry' in df_failed_qc.columns:
    add = df_failed_qc[['Entry','failure_reasons']].dropna().rename(columns={'failure_reasons':'reason'})
    add['run_id'] = RUN_ID
    if REJECTED_IDS_TSV.exists(): master = pd.read_csv(REJECTED_IDS_TSV, sep='\t', dtype=str)
    else: master = pd.DataFrame(columns=["Entry","reason","run_id"])
    master = pd.concat([master, add], ignore_index=True).drop_duplicates(subset=['Entry'], keep='last')
    master.to_csv(REJECTED_IDS_TSV, sep='\t', index=False)
    logger.info("QC failures merged into master rejection TSV.")

map_df = df_high_quality.copy() if not df_high_quality.empty else pd.DataFrame()

# **Part 4: Enhanced Exon Mapping & Architecture Definition**

## Cell 40 – Mapping safety helpers (resume & atomic writes)
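An atomic write works in two steps. First, write the new cache to a temporary file in the same directory. Then swap it into place with `os.replace`. On a local POSIX filesystem that swap is a single rename, so a crash mid-write leaves the old cache intact. This guarantee may not hold on network or FUSE mounts such as Drive. Below is a minimal sketch that assumes a TSV cache and a timestamped backup alongside it. `atomic_write_tsv` is a hypothetical helper:

```python
import os
import shutil
import tempfile
from datetime import datetime
from pathlib import Path

import pandas as pd


def atomic_write_tsv(df: pd.DataFrame, target: Path, backup: bool = True) -> Path:
    """Write df to target via a temp file + os.replace; optionally back up the old file."""
    target = Path(target)
    target.parent.mkdir(parents=True, exist_ok=True)
    if backup and target.exists():
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        shutil.copy2(target, target.with_name(f"{target.stem}_autobak_{ts}{target.suffix}"))
    # Temp file in the same directory so os.replace is a same-filesystem rename
    fd, tmp = tempfile.mkstemp(dir=target.parent, suffix=".tmp")
    try:
        with os.fdopen(fd, "w", newline="") as fh:
            df.to_csv(fh, sep="\t", index=False)
        os.replace(tmp, target)
    except BaseException:
        # Clean up the partial temp file, then re-raise; the old target is untouched
        if os.path.exists(tmp):
            os.remove(tmp)
        raise
    return target
```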

In [None]:
# # ===== Cell 40 =====
# # Mapping safety helpers (resume & atomic writes)

# from datetime import datetime
# import shutil, hashlib

# def _sha256_file(p: Path) -> str:
#     if not p.exists(): return ""
#     h = hashlib.sha256()
#     with open(p, 'rb') as f:
#         for chunk in iter(lambda: f.read(1<<20), b''):
#             h.update(chunk)
#     return h.hexdigest()

# def load_mapped_accessions_from_cache() -> set:
#     """Read the master cache and return the set of already-mapped accessions."""
#     if not EXON_CACHE_TSV.exists():
#         return set()
#     try:
#         df = pd.read_csv(EXON_CACHE_TSV, sep='\t', usecols=['accession'], low_memory=False)
#         return set(df['accession'].astype(str).unique())
#     except Exception as e:
#         logger.warning(f"Could not read cache for resume: {e}")
#         return set()

# def atomic_merge_into_cache(new_rows: pd.DataFrame) -> None:
#     """
#     Append new rows to cache with dedup by (accession, exon_num_in_chain).
#     Writes atomically; never loses previously mapped unique keys.
#     Makes an auto-backup first.
#     """
#     if new_rows is None or new_rows.empty:
#         return

#     key_cols = ['accession', 'exon_num_in_chain']

#     # Load existing cache (if any)
#     if EXON_CACHE_TSV.exists():
#         old = pd.read_csv(EXON_CACHE_TSV, sep='\t', low_memory=False)
#         before_rows = len(old)
#         try:
#             old_keys = set(map(tuple, old[key_cols].astype(str).values))
#         except Exception:
#             # If columns are missing (shouldn't happen), fall back to empty set
#             old_keys = set()
#         before_sha = _sha256_file(EXON_CACHE_TSV)
#     else:
#         old = pd.DataFrame(columns=key_cols)
#         before_rows = 0
#         old_keys = set()
#         before_sha = ""

#     # Merge + dedup
#     merged = pd.concat([old, new_rows], ignore_index=True)
#     # Ensure key columns present
#     for c in key_cols:
#         if c not in merged.columns:
#             merged[c] = ""
#     merged[key_cols] = merged[key_cols].astype(str)
#     merged.drop_duplicates(subset=key_cols, keep='last', inplace=True)

#     # --- SAFETY GUARD: unique-key superset check ---
#     try:
#         new_keys = set(map(tuple, merged[key_cols].values))
#         if not old_keys.issubset(new_keys):
#             # Do NOT overwrite; keep old cache and raise/log
#             missing = old_keys - new_keys
#             logger.error(f"Refusing to shrink unique key set! {len(missing)} keys would be lost.")
#             raise RuntimeError("Unique-key loss detected; aborting cache write.")
#     except Exception as e:
#         # Re-raise after logging
#         logger.exception(f"Cache merge guard failed: {e}")
#         raise

#     # Auto-backup before replace
#     from datetime import datetime
#     ts = datetime.now().strftime("%Y%m%d_%H%M%S")
#     autobak = PROCESSED_DIR / f"raw_exons_cache_autobak_{ts}.tsv"
#     tmp_path = EXON_CACHE_TSV.with_suffix('.tsv.tmp')

#     if EXON_CACHE_TSV.exists():
#         shutil.copyfile(EXON_CACHE_TSV, autobak)

#     # Atomic replace
#     merged.to_csv(tmp_path, sep='\t', index=False)
#     Path(tmp_path).replace(EXON_CACHE_TSV)

#     after_rows = len(merged)
#     logger.info(
#         f"Cache merge: {before_rows} → {after_rows} rows "
#         f"(backup: {autobak.name if EXON_CACHE_TSV.exists() else 'n/a'}, "
#         f"prev sha256={before_sha[:10]}...)"
#     )


In [None]:
# ===== Cell 40 - Enhanced Cache Validation Functions =====
# Improved cache integrity checking and validation

import psutil
import os
from typing import Dict, Set, Optional, Tuple

def get_memory_usage() -> Dict[str, float]:
    """Get current memory usage statistics."""
    process = psutil.Process(os.getpid())
    memory_info = process.memory_info()
    return {
        'rss_mb': memory_info.rss / 1024 / 1024,  # Resident Set Size
        'vms_mb': memory_info.vms / 1024 / 1024,  # Virtual Memory Size
        'percent': process.memory_percent()
    }

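def validate_cache_integrity(cache_path: Path, expected_columns: list) -> Tuple[bool, str]:
    """
    Validate cache file integrity and structure.

    This is a cheap sanity check: it parses the header and the first five
    rows only, so corruption further down the file is not detected.

    Returns:
        Tuple of (is_valid, message); a missing file counts as valid.
    """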
    if not cache_path.exists():
        return True, "Cache file does not exist (will be created)"

    try:
        # Check if file is readable
        df_sample = pd.read_csv(cache_path, sep='\t', nrows=5, low_memory=False)

        # Validate expected columns exist
        missing_cols = set(expected_columns) - set(df_sample.columns)
        if missing_cols:
            return False, f"Missing required columns: {missing_cols}"

        # Check for completely empty file
        if df_sample.empty:
            return True, "Cache is empty but valid"

        # Check for corrupted entries in key columns
        if 'accession' in df_sample.columns:
            null_accessions = df_sample['accession'].isna().sum()
            if null_accessions > 0:
                return False, f"Found {null_accessions} null accession entries"

        # Check file size vs expected structure
        file_size_mb = cache_path.stat().st_size / 1024 / 1024
        if file_size_mb > 1000:  # >1GB cache file
            logger.warning(f"Large cache file detected: {file_size_mb:.1f}MB")

        return True, "Cache validation passed"

    except pd.errors.EmptyDataError:
        return False, "Cache file is empty or corrupted"
    except pd.errors.ParserError as e:
        return False, f"Cache file parsing error: {str(e)}"
    except Exception as e:
        return False, f"Cache validation error: {str(e)}"

def load_mapped_accessions_from_cache() -> Set[str]:
    """
    Read the master cache and return the set of already-mapped accessions.
    Enhanced with integrity checking and corruption detection.
    """
    if not EXON_CACHE_TSV.exists():
        logger.info("No existing exon cache found")
        return set()

    # Validate cache integrity first
    expected_columns = ['accession', 'exon_num_in_chain', 'begin_aa', 'end_aa', 'peptide']
    is_valid, error_msg = validate_cache_integrity(EXON_CACHE_TSV, expected_columns)

    if not is_valid:
        logger.error(f"Cache validation failed: {error_msg}")
        # Create backup of corrupted cache
        backup_path = EXON_CACHE_TSV.with_suffix('.corrupted_backup')
        EXON_CACHE_TSV.rename(backup_path)
        logger.info(f"Corrupted cache backed up to: {backup_path}")
        return set()

    logger.info(f"Cache validation: {error_msg}")

    try:
        # Load only the accession column for efficiency
        df = pd.read_csv(EXON_CACHE_TSV, sep='\t', usecols=['accession'], low_memory=False)

        # Additional corruption checks
        total_entries = len(df)
        valid_accessions = df['accession'].dropna().astype(str)
        invalid_count = total_entries - len(valid_accessions)

        if invalid_count > 0:
            logger.warning(f"Found {invalid_count} invalid accession entries in cache")

        # Check for obviously corrupted accession patterns
        accession_set = set(valid_accessions.unique())
        suspicious_entries = [acc for acc in accession_set if len(acc) < 3 or not acc.replace('_', '').replace('-', '').isalnum()]

        if suspicious_entries:
            logger.warning(f"Found {len(suspicious_entries)} suspicious accession patterns: {suspicious_entries[:5]}")

        logger.info(f"Loaded {len(accession_set)} unique mapped accessions from cache")
        return accession_set

    except Exception as e:
        logger.error(f"Error reading cache for resume: {e}")
        # Don't crash - just start fresh
        return set()

def atomic_merge_into_cache(new_rows: pd.DataFrame) -> None:
    """
    Merge new rows into the master exon cache with memory monitoring and validation.

    Rows are deduplicated on (accession, exon_num_in_chain), keeping the newest.
    The live cache is copied (not moved) to a backup, the merged table is written
    to a temp file and validated, and only then swapped in with an atomic
    replace, so a failure at any step leaves the existing cache untouched.
    """
    import shutil  # local import keeps this cell self-contained

    if new_rows is None or new_rows.empty:
        logger.info("No new rows to merge into cache")
        return

    # Memory check before processing
    initial_memory = get_memory_usage()
    logger.info(f"Memory before cache merge: {initial_memory['rss_mb']:.1f}MB ({initial_memory['percent']:.1f}%)")

    key_cols = ['accession', 'exon_num_in_chain']

    # Validate new rows before merging
    required_cols = ['accession', 'exon_num_in_chain', 'begin_aa', 'end_aa', 'peptide']
    missing_cols = set(required_cols) - set(new_rows.columns)
    if missing_cols:
        logger.error(f"New rows missing required columns: {missing_cols}")
        return

    # Load existing cache with validation
    if EXON_CACHE_TSV.exists():
        is_valid, error_msg = validate_cache_integrity(EXON_CACHE_TSV, required_cols)
        if not is_valid:
            logger.error(f"Cannot merge - existing cache invalid: {error_msg}")
            return

        try:
            old = pd.read_csv(EXON_CACHE_TSV, sep='\t', low_memory=False)
            logger.info(f"Loaded existing cache: {len(old)} rows")
        except Exception as e:
            logger.error(f"Failed to load existing cache: {e}")
            return
    else:
        old = pd.DataFrame(columns=required_cols)
        logger.info("Creating new cache file")

    # Memory check after loading
    post_load_memory = get_memory_usage()
    memory_increase = post_load_memory['rss_mb'] - initial_memory['rss_mb']
    if memory_increase > 100:  # >100MB increase
        logger.warning(f"Cache loading used {memory_increase:.1f}MB additional memory")

    backup_path = EXON_CACHE_TSV.with_suffix('.backup')
    temp_path = EXON_CACHE_TSV.with_suffix('.tmp')

    try:
        # Merge, then normalize key dtypes so a TSV-loaded "2" and a freshly
        # mapped 2 compare equal during deduplication
        combined = pd.concat([old, new_rows], ignore_index=True)
        combined[key_cols] = combined[key_cols].astype(str)
        old_keys = set(map(tuple, combined.iloc[:len(old)][key_cols].values))

        # Deduplicate by key columns, keeping the newest row
        combined.drop_duplicates(subset=key_cols, keep='last', inplace=True)

        # Defensive guard: never write a cache that drops an existing key
        lost_keys = old_keys - set(map(tuple, combined[key_cols].values))
        if lost_keys:
            logger.error(f"Refusing to write cache: {len(lost_keys)} existing keys would be lost")
            return

        # Back up by copying, so the live cache stays in place until the swap
        if EXON_CACHE_TSV.exists():
            shutil.copy2(EXON_CACHE_TSV, backup_path)
            logger.info(f"Created backup: {backup_path}")

        # Write to a temp file first
        combined.to_csv(temp_path, sep='\t', index=False)

        # Validate the written file before it replaces the cache
        is_valid, error_msg = validate_cache_integrity(temp_path, required_cols)
        if not is_valid:
            logger.error(f"Newly written cache failed validation: {error_msg}")
            temp_path.unlink()  # Delete invalid file; the live cache is untouched
            return

        # Atomic replace: overwrites the old cache in a single step
        temp_path.replace(EXON_CACHE_TSV)

        # Final memory check
        final_memory = get_memory_usage()
        total_increase = final_memory['rss_mb'] - initial_memory['rss_mb']

        logger.info(f"Cache merge complete: {len(old)} + {len(new_rows)} → {len(combined)} rows")
        logger.info(f"Memory after merge: {final_memory['rss_mb']:.1f}MB (Δ{total_increase:+.1f}MB)")

        # The backup is redundant once the new cache is in place
        if backup_path.exists():
            backup_path.unlink()

    except Exception as e:
        logger.error(f"Cache merge failed: {e}")
        # The live cache was never moved, so only the temp file needs cleanup
        if temp_path.exists():
            temp_path.unlink()

logger.info("✅ Enhanced cache validation functions loaded")

2025-08-22 05:42:50,804 [INFO] - ✅ Enhanced cache validation functions loaded
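The deduplication step in `atomic_merge_into_cache` depends on key columns comparing equal across the cached TSV and freshly mapped rows. The sketch below isolates that step on toy rows (the accessions and peptides are illustrative only, and `merge_exon_rows` is a hypothetical helper, not part of the pipeline). It shows that casting keys to `str` *after* `pd.concat` lets a text `"2"` and an integer `2` collapse to one row, whereas a naive `drop_duplicates` keeps both.

```python
import pandas as pd

KEY_COLS = ["accession", "exon_num_in_chain"]

def merge_exon_rows(old: pd.DataFrame, new: pd.DataFrame) -> pd.DataFrame:
    """Concatenate, normalize key dtypes, keep the newest row per key."""
    merged = pd.concat([old, new], ignore_index=True)
    # Cast keys to str after concat so "2" (text) and 2 (int) compare equal
    merged[KEY_COLS] = merged[KEY_COLS].astype(str)
    old_keys = set(map(tuple, merged.iloc[: len(old)][KEY_COLS].values))
    merged = merged.drop_duplicates(subset=KEY_COLS, keep="last")
    # Defensive superset check: no previously cached key may disappear
    lost = old_keys - set(map(tuple, merged[KEY_COLS].values))
    if lost:
        raise RuntimeError(f"{len(lost)} cached keys would be lost")
    return merged.reset_index(drop=True)

# Toy rows: the cached copy holds exon numbers as strings, the new batch as ints
old = pd.DataFrame({"accession": ["P02452", "P02452"],
                    "exon_num_in_chain": ["0", "2"],
                    "peptide": ["GPPGPP", "GAPGAP"]})
new = pd.DataFrame({"accession": ["P02452", "P08123"],
                    "exon_num_in_chain": [2, 0],
                    "peptide": ["GAPGAPGAP", "GERGER"]})

merged = merge_exon_rows(old, new)
naive = pd.concat([old, new], ignore_index=True).drop_duplicates(subset=KEY_COLS, keep="last")
print(len(merged), len(naive))  # 3 4
```

Because `keep="last"` retains the row from the newer batch, a remapped exon overwrites its cached peptide rather than being appended beside it.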


## Cell 41 – Exon coordinate mapper (EBI Proteins API)

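Fetches exon coordinates for each UniProt accession from the EBI Proteins
API, with retries and exponential backoff, extra back-off on HTTP 429 rate
limits, ±2 flanking exons around the main chain, and even numbering in which
the first exon lying fully inside the main chain (the first helix exon) is 0.
Strand and chromosome are recorded where the API provides them.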
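The flank-and-numbering rule can be sketched on its own. The function below is a hypothetical, simplified restatement of the logic in `_process_with_validation`: it takes exon dicts shaped like the `proteinLocation` entries parsed in this cell, assumes every exon has both begin and end positions (the real mapper skips exons that lack them), keeps in-chain exons plus two flanks on each side, and numbers them in steps of 2 with the first in-chain exon at 0.

```python
from typing import Dict, List, Tuple

def number_exons(exons: List[Dict], c_start: int, c_end: int,
                 flank: int = 2) -> List[Tuple[int, int, int]]:
    """Return (exon_num, begin, end) for in-chain exons plus `flank` exons per side."""
    spans = sorted(
        (ex["proteinLocation"]["begin"]["position"], ex["proteinLocation"]["end"]["position"])
        for ex in exons
    )
    # Exons lying fully inside the main chain
    inside = [i for i, (b, e) in enumerate(spans) if b >= c_start and e <= c_end]
    if not inside:
        return []
    first, last = inside[0], inside[-1]
    lo, hi = max(0, first - flank), min(len(spans) - 1, last + flank)
    # Even numbering: first in-chain exon is 0, upstream flanks are negative
    return [(2 * (i - first), *spans[i]) for i in range(lo, hi + 1)]

# Toy exon list (protein coordinates only), main chain spanning residues 41-100
toy = [{"proteinLocation": {"begin": {"position": b}, "end": {"position": e}}}
       for b, e in [(1, 20), (21, 40), (41, 60), (61, 80), (81, 100), (101, 120), (121, 140)]]

print(number_exons(toy, 41, 100))
# [(-4, 1, 20), (-2, 21, 40), (0, 41, 60), (2, 61, 80), (4, 81, 100), (6, 101, 120), (8, 121, 140)]
```

Numbering relative to the first in-chain exon, rather than from the start of the gene, keeps the helix exons aligned across taxa even when the number of N-terminal exons differs.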

In [None]:
# ===== Cell 41 - Enhanced Exon Coordinate Mapper with Granular Error Tracking =====
# Improved error classification and detailed failure reporting

from collections import defaultdict
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional
import json
import time

import requests

class ErrorCategory(Enum):
    """Categorize different types of mapping failures for better diagnostics."""
    API_FAILURE = "api_failure"
    TIMEOUT_FAILURE = "timeout_failure"
    DATA_FORMAT_FAILURE = "data_format_failure"
    COORDINATE_FAILURE = "coordinate_failure"
    SEQUENCE_MISMATCH = "sequence_mismatch"
    EMPTY_RESPONSE = "empty_response"
    PARSING_ERROR = "parsing_error"
    NETWORK_ERROR = "network_error"
    RATE_LIMIT_ERROR = "rate_limit_error"
    UNKNOWN_ERROR = "unknown_error"

class EnhancedExonCoordinateMapper:
    """
    Enhanced mapper with granular error tracking, detailed diagnostics,
    and improved retry logic with exponential backoff.
    """

    def __init__(self, base_url="https://www.ebi.ac.uk/proteins/api", max_retries=3, initial_delay=1.0):
        # The gnCoordinate/genomicLocation/exon structure parsed below is the
        # EBI Proteins API coordinates format, keyed by UniProt accession
        self.base_url = base_url
        self.max_retries = max_retries
        self.initial_delay = initial_delay
        self.cache = {}
        self.failed = set()

        # Enhanced statistics with granular error tracking
        self.stats = {
            'total_attempts': 0,
            'cache_hits': 0,
            'successes': 0,
            'total_failures': 0,
            'retry_attempts': 0,
            'processing_time': 0.0
        }

        # Granular error tracking
        self.error_stats = defaultdict(int)
        self.error_details = defaultdict(list)  # Store specific error messages
        self.failure_timeline = []  # Track when failures occur

        # Performance monitoring
        self.api_call_times = []
        self.slow_calls = []  # Calls taking >5 seconds

    def log_error(self, error_category: ErrorCategory, accession: str, details: str = "", context: dict = None):
        """
        Log an error with detailed context for debugging.

        Every call counts as one failure event, so failed retry attempts and
        per-exon coordinate mismatches are counted individually; totals can
        exceed the number of failed accessions.
        """
        self.error_stats[error_category.value] += 1
        self.stats['total_failures'] += 1

        error_entry = {
            'timestamp': time.time(),
            'accession': accession,
            'category': error_category.value,
            'details': details,
            'context': context or {}
        }

        self.error_details[error_category.value].append(error_entry)
        self.failure_timeline.append(error_entry)

        # Log to console with appropriate level
        if error_category in [ErrorCategory.API_FAILURE, ErrorCategory.NETWORK_ERROR]:
            logger.warning(f"API issue for {accession}: {error_category.value} - {details}")
        elif error_category in [ErrorCategory.DATA_FORMAT_FAILURE, ErrorCategory.COORDINATE_FAILURE]:
            logger.debug(f"Data issue for {accession}: {error_category.value} - {details}")
        else:
            logger.info(f"Mapping issue for {accession}: {error_category.value} - {details}")

    def _fetch_with_enhanced_error_tracking(self, accession: str) -> Optional[Dict]:
        """Fetch data with comprehensive error tracking and retry logic."""
        # EBI Proteins API coordinates endpoint: accepts a UniProt accession and
        # returns the gnCoordinate structure parsed in _process_with_validation
        url = f"{self.base_url}/coordinates/{accession}"

        for attempt in range(self.max_retries + 1):
            try:
                start_time = time.time()

                response = requests.get(url, headers={"Accept": "application/json"}, timeout=30)

                call_duration = time.time() - start_time
                self.api_call_times.append(call_duration)

                if call_duration > 5.0:
                    self.slow_calls.append({
                        'accession': accession,
                        'duration': call_duration,
                        'attempt': attempt + 1
                    })

                if response.status_code == 200:
                    try:
                        data = response.json()
                        if not data:
                            self.log_error(ErrorCategory.EMPTY_RESPONSE, accession,
                                         "API returned empty response")
                            return None
                        return data
                    except json.JSONDecodeError as e:
                        self.log_error(ErrorCategory.PARSING_ERROR, accession,
                                     f"JSON decode error: {str(e)}",
                                     {'response_text': response.text[:200]})
                        return None

                elif response.status_code == 429:
                    # Rate limiting
                    self.log_error(ErrorCategory.RATE_LIMIT_ERROR, accession,
                                 f"Rate limited on attempt {attempt + 1}")
                    if attempt < self.max_retries:
                        delay = self.initial_delay * (2 ** attempt) + 2  # Extra delay for rate limiting
                        time.sleep(delay)
                        self.stats['retry_attempts'] += 1
                        continue
                    return None

                elif response.status_code == 404:
                    self.log_error(ErrorCategory.API_FAILURE, accession,
                                 "Accession not found (HTTP 404)",
                                 {'status_code': 404})
                    return None

                else:
                    self.log_error(ErrorCategory.API_FAILURE, accession,
                                 f"HTTP {response.status_code}: {response.reason}",
                                 {'status_code': response.status_code, 'attempt': attempt + 1})
                    if attempt < self.max_retries:
                        delay = self.initial_delay * (2 ** attempt)
                        time.sleep(delay)
                        self.stats['retry_attempts'] += 1
                        continue
                    return None

            except requests.exceptions.Timeout:
                self.log_error(ErrorCategory.TIMEOUT_FAILURE, accession,
                             f"Request timeout on attempt {attempt + 1}")
                if attempt < self.max_retries:
                    delay = self.initial_delay * (2 ** attempt)
                    time.sleep(delay)
                    self.stats['retry_attempts'] += 1
                    continue
                return None

            except requests.exceptions.ConnectionError as e:
                self.log_error(ErrorCategory.NETWORK_ERROR, accession,
                             f"Connection error: {str(e)}")
                if attempt < self.max_retries:
                    delay = self.initial_delay * (2 ** attempt)
                    time.sleep(delay)
                    self.stats['retry_attempts'] += 1
                    continue
                return None

            except requests.exceptions.RequestException as e:
                self.log_error(ErrorCategory.NETWORK_ERROR, accession,
                             f"Request exception: {str(e)}")
                if attempt < self.max_retries:
                    delay = self.initial_delay * (2 ** attempt)
                    time.sleep(delay)
                    self.stats['retry_attempts'] += 1
                    continue
                return None

        return None

    def enhanced_get_mapped_exons(self, accession: str, main_chain: List[Dict], sequence: str) -> Optional[List[Dict]]:
        """Enhanced exon mapping with comprehensive error tracking."""
        self.stats['total_attempts'] += 1
        start_time = time.time()

        # Input validation
        if not accession or not main_chain or not sequence:
            self.log_error(ErrorCategory.DATA_FORMAT_FAILURE, accession or "unknown",
                         "Missing required input data",
                         {'has_accession': bool(accession), 'has_chain': bool(main_chain), 'has_sequence': bool(sequence)})
            return None

        # Check cache first
        if accession in self.cache:
            self.stats['cache_hits'] += 1
            return self.cache[accession]

        # Fetch data with enhanced error tracking
        data = self._fetch_with_enhanced_error_tracking(accession)
        if data is None:
            self.failed.add(accession)
            self.cache[accession] = None
            return None

        try:
            result = self._process_with_validation(accession, data, main_chain, sequence)

            # Update stats and cache
            if result:
                self.stats['successes'] += 1
                self.cache[accession] = result
            else:
                self.failed.add(accession)
                self.cache[accession] = None

            processing_time = time.time() - start_time
            self.stats['processing_time'] += processing_time

            return result

        except Exception as e:
            self.log_error(ErrorCategory.UNKNOWN_ERROR, accession,
                         f"Unexpected processing error: {str(e)}",
                         {'exception_type': type(e).__name__})
            self.failed.add(accession)
            self.cache[accession] = None
            return None

    def _process_with_validation(self, acc: str, data: Dict, main_chain: List[Dict], seq: str) -> Optional[List[Dict]]:
        """Process API response with enhanced validation and error reporting."""
        try:
            # Validate data structure
            if 'gnCoordinate' not in data:
                self.log_error(ErrorCategory.DATA_FORMAT_FAILURE, acc,
                             "Missing gnCoordinate in API response",
                             {'available_keys': list(data.keys())})
                return None

            gn_coords = data['gnCoordinate']
            if not gn_coords or not isinstance(gn_coords, list):
                self.log_error(ErrorCategory.DATA_FORMAT_FAILURE, acc,
                             "Invalid gnCoordinate structure")
                return None

            gn = gn_coords[0]

            # Validate genomic location
            if 'genomicLocation' not in gn:
                self.log_error(ErrorCategory.DATA_FORMAT_FAILURE, acc,
                             "Missing genomicLocation in gnCoordinate")
                return None

            genomic_loc = gn['genomicLocation']
            if 'exon' not in genomic_loc:
                self.log_error(ErrorCategory.DATA_FORMAT_FAILURE, acc,
                             "No exons found in genomicLocation")
                return None

            exons = genomic_loc['exon']
            if not exons:
                self.log_error(ErrorCategory.DATA_FORMAT_FAILURE, acc,
                             "Empty exon list")
                return None

            # Sort exons and validate coordinates
            try:
                sorted_exons = sorted(exons, key=lambda x: x.get('proteinLocation', {}).get('begin', {}).get('position', 0))
            except (KeyError, TypeError) as e:
                self.log_error(ErrorCategory.COORDINATE_FAILURE, acc,
                             f"Error sorting exons by coordinates: {str(e)}")
                return None

        except Exception as e:
            self.log_error(ErrorCategory.DATA_FORMAT_FAILURE, acc,
                         f"Error accessing API response structure: {str(e)}")
            return None

        # Validate main chain coordinates
        if not main_chain:
            self.log_error(ErrorCategory.COORDINATE_FAILURE, acc,
                         "Empty main chain provided")
            return None

        try:
            c_start = main_chain[0]['start']
            c_end = main_chain[-1]['end']
        except (KeyError, IndexError) as e:
            self.log_error(ErrorCategory.COORDINATE_FAILURE, acc,
                         f"Invalid main chain structure: {str(e)}")
            return None

        # Find exons within the main chain
        first_idx = last_idx = -1
        valid_exons = 0

        for i, ex in enumerate(sorted_exons):
            try:
                pb = ex.get('proteinLocation', {}).get('begin', {}).get('position')
                pe = ex.get('proteinLocation', {}).get('end', {}).get('position')

                if pb is None or pe is None:
                    continue

                valid_exons += 1

                if pb >= c_start and pe <= c_end:
                    if first_idx == -1:
                        first_idx = i
                    last_idx = i

            except Exception as e:
                logger.debug(f"Error processing exon {i} for {acc}: {e}")
                continue

        if first_idx == -1:
            self.log_error(ErrorCategory.COORDINATE_FAILURE, acc,
                         f"No exons found within main chain ({c_start}-{c_end})",
                         {'total_exons': len(sorted_exons), 'valid_exons': valid_exons})
            return None

        # Build mapped exons with validation
        s_idx = max(0, first_idx - 2)
        e_idx = min(len(sorted_exons) - 1, last_idx + 2)
        mapped = []
        even = -2 * (first_idx - s_idx)  # helix-first exon = 0

        for i in range(s_idx, e_idx + 1):
            try:
                ex = sorted_exons[i]
                pl = ex.get('proteinLocation', {})
                pb = pl.get('begin', {}).get('position')
                pe = pl.get('end', {}).get('position')

                if pb is None or pe is None:
                    continue

                # Validate sequence coordinates
                if pb > len(seq) or pe > len(seq) or pb < 1:
                    self.log_error(ErrorCategory.SEQUENCE_MISMATCH, acc,
                                 f"Exon coordinates ({pb}-{pe}) outside sequence length ({len(seq)})")
                    continue

                pep = seq[pb-1:pe]

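                mapped.append({
                    "accession": acc,
                    "exon_num_in_chain": even,
                    "begin_aa": int(pb),
                    "end_aa": int(pe),
                    "peptide": pep,
                    # In the Proteins API response, chromosome and strand sit on
                    # genomicLocation ('reverseStrand' is a boolean); strand is stored
                    # as 1 / -1 here, an Ensembl-style convention assumed downstream
                    "strand": (-1 if genomic_loc['reverseStrand'] else 1) if 'reverseStrand' in genomic_loc else None,
                    "chr": genomic_loc.get('chromosome')
                })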
                even += 2

            except Exception as e:
                logger.debug(f"Error processing exon {i} coordinates for {acc}: {e}")
                continue

        if not mapped:
            self.log_error(ErrorCategory.COORDINATE_FAILURE, acc,
                         "No valid exon mappings generated")
            return None

        return mapped

    def get_detailed_stats(self) -> Dict:
        """Get comprehensive statistics including error breakdown."""
        total_errors = sum(self.error_stats.values())
        avg_call_time = sum(self.api_call_times) / len(self.api_call_times) if self.api_call_times else 0

        stats = {
            'summary': dict(self.stats),
            'error_breakdown': dict(self.error_stats),
            'error_percentage': {
                category: (count / total_errors * 100) if total_errors > 0 else 0
                for category, count in self.error_stats.items()
            },
            'performance': {
                'avg_api_call_time': avg_call_time,
                'slow_calls_count': len(self.slow_calls),
                'total_api_calls': len(self.api_call_times),
                'cache_hit_rate': (self.stats['cache_hits'] / self.stats['total_attempts'] * 100) if self.stats['total_attempts'] > 0 else 0
            },
            'top_failure_reasons': {
                category: len(details) for category, details in
                sorted(self.error_details.items(), key=lambda x: len(x[1]), reverse=True)[:5]
            }
        }

        return stats

    def export_error_report(self, output_path: Path) -> None:
        """Export detailed error report for debugging."""
        report = {
            'generated_at': time.time(),
            'summary_stats': self.get_detailed_stats(),
            'error_timeline': self.failure_timeline[-100:],  # Last 100 errors
            'slow_calls': self.slow_calls,
            'failed_accessions': list(self.failed)
        }

        with open(output_path, 'w') as f:
            json.dump(report, f, indent=2, default=str)

        logger.info(f"Error report exported to: {output_path}")

# Initialize enhanced mapper
enhanced_exon_mapper = EnhancedExonCoordinateMapper()

logger.info("✅ Enhanced Exon Coordinate Mapper with granular error tracking loaded")

2025-08-22 05:43:03,834 [INFO] - ✅ Enhanced Exon Coordinate Mapper with granular error tracking loaded
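The retry schedule in `_fetch_with_enhanced_error_tracking` is plain exponential backoff: before retry `attempt` (counting from 0) it sleeps `initial_delay * 2**attempt` seconds, adding 2 s when the failure was an HTTP 429. The hypothetical helper below only computes those waits, so the worst-case delay per accession can be checked without touching the network.

```python
from typing import List

def backoff_delays(max_retries: int = 3, initial_delay: float = 1.0,
                   rate_limited: bool = False) -> List[float]:
    """Sleep before each retry: initial_delay * 2**attempt, plus 2 s after HTTP 429."""
    extra = 2.0 if rate_limited else 0.0
    # One sleep per retry; after the final attempt the mapper returns None instead
    return [initial_delay * (2 ** attempt) + extra for attempt in range(max_retries)]

print(backoff_delays())                   # [1.0, 2.0, 4.0]
print(backoff_delays(rate_limited=True))  # [3.0, 4.0, 6.0]
```

With the defaults, an accession whose every attempt fails the same way waits 7 s in backoff (13 s if every attempt is rate-limited), on top of up to four 30 s request timeouts.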


In [None]:
# # ===== Cell 41 =====
# # Enhanced exon coordinate mapper
# from requests.adapters import HTTPAdapter
# from urllib3.util.retry import Retry

# class EnhancedExonCoordinateMapper:
#     def __init__(self, rate_delay: float = 0.1, max_retries: int = 3, timeout: int = 30):
#         self.rate_delay = rate_delay; self.max_retries = max_retries; self.timeout = timeout
#         self.cache: Dict[str, Optional[List[Dict]]] = {}; self.failed: set[str] = set()
#         self.stats = {"api_calls":0,"cache_hits":0,"successes":0,"failures":0,"retries":0}
#         self.session = requests.Session()
#         retry = Retry(total=max_retries, backoff_factor=1.5, status_forcelist=[429,500,502,503,504], allowed_methods=["GET"])
#         adapter = HTTPAdapter(max_retries=retry)
#         self.session.mount("http://", adapter); self.session.mount("https://", adapter)
#         self.session.headers.update({"Accept":"application/json","User-Agent": f"CollagenExonMapper/1.4"})

#     def _fetch(self, acc: str) -> Optional[Dict]:
#         url = f"https://www.ebi.ac.uk/proteins/api/coordinates/{acc}"
#         for k in range(self.max_retries+1):
#             try:
#                 time.sleep(self.rate_delay); self.stats['api_calls'] += 1
#                 if k>0: self.stats['retries'] += 1
#                 r = self.session.get(url, timeout=self.timeout)
#                 if r.status_code == 404: return None
#                 if r.status_code == 429:
#                     wait = int(r.headers.get('Retry-After','5')); time.sleep(wait); continue
#                 r.raise_for_status(); return r.json()
#             except requests.exceptions.Timeout:
#                 if k < self.max_retries: time.sleep(2**k); continue
#                 break
#             except requests.exceptions.RequestException:
#                 if k < self.max_retries: time.sleep(2**k); continue
#                 break
#         return None

#     def enhanced_get_mapped_exons(self, accession: str, main_chain: List[Dict], sequence: str) -> Optional[List[Dict]]:
#         if not accession or not main_chain or not sequence: return None
#         if accession in self.cache: self.stats['cache_hits'] += 1; return self.cache[accession]
#         data = self._fetch(accession)
#         if data is None:
#             self.failed.add(accession); self.stats['failures'] += 1; self.cache[accession] = None; return None
#         try:
#             res = self._process(accession, data, main_chain, sequence)
#             self.cache[accession] = res; self.stats['successes'] += int(bool(res))
#             if not res: self.stats['failures'] += 1
#             return res
#         except Exception:
#             self.failed.add(accession); self.stats['failures'] += 1; self.cache[accession] = None; return None

#     def _process(self, acc: str, data: Dict, main_chain: List[Dict], seq: str) -> Optional[List[Dict]]:
#         try:
#             gn = data['gnCoordinate'][0]
#             exons = sorted(gn['genomicLocation']['exon'], key=lambda x: x['proteinLocation']['begin']['position'])
#         except Exception:
#             return None
#         if not exons: return None
#         c_start = main_chain[0]['start']; c_end = main_chain[-1]['end']
#         first_idx = last_idx = -1
#         for i, ex in enumerate(exons):
#             pb = ex.get('proteinLocation',{}).get('begin',{}).get('position')
#             pe = ex.get('proteinLocation',{}).get('end',{}).get('position')
#             if pb is None or pe is None: continue
#             if pb >= c_start and pe <= c_end:
#                 if first_idx == -1: first_idx = i
#                 last_idx = i
#         if first_idx == -1: return None
#         s_idx = max(0, first_idx-2); e_idx = min(len(exons)-1, last_idx+2)
#         mapped = []; even = -2*(first_idx - s_idx)  # helix-first exon = 0
#         for i in range(s_idx, e_idx+1):
#             ex = exons[i]; pl = ex.get('proteinLocation',{})
#             pb = pl.get('begin',{}).get('position'); pe = pl.get('end',{}).get('position')
#             if pb is None or pe is None: continue
#             pep = seq[pb-1:pe]
#             mapped.append({
#                 "accession": acc,
#                 "exon_num_in_chain": even,
#                 "begin_aa": int(pb),
#                 "end_aa": int(pe),
#                 "peptide": pep,
#                 "strand": gn.get('genomeLocation',{}).get('strand'),
#                 "chr": gn.get('genomeLocation',{}).get('chromosome')
#             })
#             even += 2
#         return mapped

# enhanced_exon_mapper = EnhancedExonCoordinateMapper()

## Cell 42 – Incremental mapping & **resume-safe** cache update
- Respects `MAPPING_STRATEGY` = `resume` / `skip` / `force`
- Skips accessions already present in cache (resume mode)
- Commits every `MAP_COMMIT_CHUNK` accessions (atomic append)
- Never overwrites/shrinks the master cache
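The strategy selection and chunked commits described above can be sketched as two small helpers. Both names are hypothetical, and the meaning of `skip` (reuse the cache, map nothing new) is an assumption inferred from the bullet list rather than taken from the cell code. In the pipeline, each yielded batch would be mapped and then merged into the cache before the next one starts, so an interrupted run loses at most one batch.

```python
from itertools import islice
from typing import Iterable, Iterator, List, Set

def accessions_to_map(all_accessions: List[str], cached: Set[str], strategy: str) -> List[str]:
    """Pick which accessions to (re)map under a MAPPING_STRATEGY value."""
    if strategy == "force":
        return list(all_accessions)  # remap everything
    if strategy == "resume":
        return [a for a in all_accessions if a not in cached]  # only uncached accessions
    if strategy == "skip":
        return []  # assumption: 'skip' reuses the cache and maps nothing new
    raise ValueError(f"Unknown MAPPING_STRATEGY: {strategy!r}")

def chunked(items: Iterable[str], size: int) -> Iterator[List[str]]:
    """Yield consecutive batches of at most `size` items (one commit per batch)."""
    it = iter(items)
    while batch := list(islice(it, size)):
        yield batch

# Toy accession list; ACC2 and ACC4 are already in the cache
ACCS = ["ACC1", "ACC2", "ACC3", "ACC4", "ACC5"]
CACHED = {"ACC2", "ACC4"}

todo = accessions_to_map(ACCS, CACHED, "resume")
print(list(chunked(todo, 2)))  # [['ACC1', 'ACC3'], ['ACC5']]
```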


In [None]:
# ===== Cell 42 - Updated Version =====
# Incremental exon mapping with improved cache loading using canonical paths

logging.info("--- Part 4: Incremental Exon Mapping ---")

# -------------------------
# NEW: Cache Loading Function (replaces old cache loading logic)
# -------------------------
def load_exon_cache():
    """Load existing exon cache with proper error handling and validation"""
    if EXON_CACHE_TSV.exists():
        try:
            cache_df = pd.read_csv(EXON_CACHE_TSV, sep='\t', low_memory=False)

            # Validate cache structure
            required_columns = ['accession', 'exon_num_in_chain', 'peptide']
            missing_cols = [col for col in required_columns if col not in cache_df.columns]

            if missing_cols:
                logging.warning(f"Cache missing required columns: {missing_cols}")
                logging.warning("Creating new cache due to invalid structure")
                return pd.DataFrame()

            if not cache_df.empty:
                unique_accessions = cache_df['accession'].nunique()
                total_exons = len(cache_df)
                logging.info(f"   ✅ Loaded {total_exons} exons for {unique_accessions} proteins.")
                return cache_df
            else:
                logging.info("   Cache file exists but is empty")
                return pd.DataFrame()

        except Exception as e:
            logging.warning(f"   Could not read exon cache: {e}")
            return pd.DataFrame()
    else:
        logging.info("   No existing cache found, starting fresh")
        return pd.DataFrame()

# -------------------------
# Cache saving: backup, write to a temp file, validate, then replace atomically
# -------------------------
def save_exon_cache_with_backup(combined_df):
    """Save exon cache with atomic write and backup"""
    try:
        if combined_df.empty:
            logging.warning("No data to save to cache")
            return False

        # Create backup before writing
        if EXON_CACHE_TSV.exists():
            backup_path = CACHE_DIR / f"raw_exons_cache_backup_{RUN_ID}.tsv"
            import shutil
            shutil.copy2(EXON_CACHE_TSV, backup_path)
            logging.info(f"   Cache backed up to: {backup_path}")

        # Atomic write with validation
        temp_path = EXON_CACHE_TSV.with_suffix('.tmp')
        combined_df.to_csv(temp_path, sep='\t', index=False)

        # Validate the written file by trying to read it
        try:
            validation_df = pd.read_csv(temp_path, sep='\t', nrows=5)
            if len(validation_df.columns) > 0:
                temp_path.replace(EXON_CACHE_TSV)
                logging.info(f"   ✅ Cache saved successfully: {len(combined_df)} rows")
                return True
            else:
                logging.error("   Validation failed: written file appears empty")
                temp_path.unlink()
                return False
        except Exception as ve:
            logging.error(f"   Validation failed: {ve}")
            temp_path.unlink()
            return False

    except Exception as e:
        logging.error(f"   Failed to save cache: {e}")
        return False

# -------------------------
# Load the master cache once per run
# -------------------------
df_cached_exons = load_exon_cache()

# -------------------------
# Incremental mapping: work out which accessions need mapping, then commit in batches
# -------------------------
df_raw_exons = pd.DataFrame()

if 'df_high_quality' in globals() and not df_high_quality.empty:
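    # Calculate what needs to be mapped, honouring MAPPING_STRATEGY (resume / skip / force)
    strategy = globals().get('MAPPING_STRATEGY', 'resume')
    acc_all = set(df_high_quality['Entry'].dropna().unique())
    acc_done = set(df_cached_exons['accession'].unique()) if not df_cached_exons.empty else set()
    if strategy == 'force':
        acc_new = set(acc_all)          # re-map everything; newer rows win on de-duplication
    elif strategy == 'skip':
        acc_new = set()                 # use the cache as-is
    else:
        acc_new = acc_all - acc_done    # 'resume': only accessions not yet cached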

    logging.info(f"   H.Q. entries total: {len(acc_all)}")
    logging.info(f"   Already cached: {len(acc_done)}")
    logging.info(f"   Need mapping: {len(acc_new)}")

    if acc_new:
        logging.info("   🔄 Starting incremental exon mapping...")

        # Create subset for mapping
        df_to_map = df_high_quality[df_high_quality['Entry'].isin(acc_new)].copy()

        # Initialize batch processing
        batch_rows = []
        mapped_count = 0

        # Process in batches (your existing mapping logic goes here)
        for idx, row in df_to_map.iterrows():
            accession = row['Entry']

            try:
                # Use the exon mapper from earlier cells if one is defined; it is expected to
                # return a list of dicts with at least 'accession', 'exon_num_in_chain', 'peptide'.
                # Without it, no exons are produced and nothing is committed for this accession.
                mapper = globals().get('exon_mapper')
                exon_details = (mapper.map_exons(accession, row['Sequence'])
                                if hasattr(mapper, 'map_exons') else [])

                if exon_details:
                    for detail in exon_details:
                        batch_rows.append(detail)
                    mapped_count += 1

                    # Commit batch if it reaches size limit
                    if len(batch_rows) >= globals().get('MAP_COMMIT_CHUNK', 100):  # exon rows, not accessions
                        chunk_df = pd.DataFrame(batch_rows)

                        # Merge with existing cache
                        if not df_cached_exons.empty:
                            combined = pd.concat([df_cached_exons, chunk_df], ignore_index=True)
                        else:
                            combined = chunk_df

                        # Remove duplicates and save
                        combined = combined.drop_duplicates(subset=['accession', 'exon_num_in_chain'], keep='last')
                        save_exon_cache_with_backup(combined)

                        # Update our working cache
                        df_cached_exons = combined

                        logging.info(f"   📝 Committed batch: {len(batch_rows)} exons")
                        batch_rows = []

            except Exception as e:
                logging.warning(f"   Failed to map {accession}: {e}")
                continue

        # Commit any remaining batch
        if batch_rows:
            chunk_df = pd.DataFrame(batch_rows)
            if not df_cached_exons.empty:
                combined = pd.concat([df_cached_exons, chunk_df], ignore_index=True)
            else:
                combined = chunk_df
            combined = combined.drop_duplicates(subset=['accession', 'exon_num_in_chain'], keep='last')
            save_exon_cache_with_backup(combined)
            df_cached_exons = combined
            logging.info(f"   📝 Final batch committed: {len(batch_rows)} exons")

        logging.info(f"   ✅ Mapping complete: {mapped_count} new accessions processed")
    else:
        logging.info("   ✅ All entries already cached, no mapping needed")

    # Set final result
    df_raw_exons = df_cached_exons

else:
    logging.warning("   No high-quality sequences available for mapping")
    df_raw_exons = df_cached_exons

# Final statistics
if not df_raw_exons.empty:
    total_proteins = df_raw_exons['accession'].nunique()
    total_exons = len(df_raw_exons)
    logging.info(f"   📊 Final dataset: {total_exons} exons from {total_proteins} proteins")
else:
    logging.info("   📊 No exons available for downstream processing")

logging.info("--- Incremental Exon Mapping Complete ---")
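The save step above uses the write-to-temp, validate, replace pattern. A standalone sketch of the same idea, assuming pandas; `atomic_write_tsv` is a hypothetical name, and it swaps in `tempfile.mkstemp` and a header round-trip check for the fixed `.tmp` suffix and column-count check used in Cell 42:

```python
import os
import tempfile
from pathlib import Path

import pandas as pd


def atomic_write_tsv(df: pd.DataFrame, target: Path) -> bool:
    """Write df to a temp file beside target, check the header reads back, then swap it in."""
    target = Path(target)
    # Same directory as the target, so os.replace is a rename within one filesystem
    fd, tmp_name = tempfile.mkstemp(suffix=".tmp", dir=target.parent)
    os.close(fd)
    tmp = Path(tmp_name)
    try:
        df.to_csv(tmp, sep="\t", index=False)
        header = pd.read_csv(tmp, sep="\t", nrows=5).columns
        if list(header) != [str(c) for c in df.columns]:
            raise ValueError("header changed on round trip")
        os.replace(tmp, target)  # readers see the old file or the new one, never a partial write
        return True
    except Exception:
        tmp.unlink(missing_ok=True)  # leave no stray temp file behind
        return False
```

On local disks `os.replace` is an atomic rename when source and target share a filesystem; on network or cloud-synced mounts such as Google Drive the guarantee may be weaker, which is one reason to keep the backup copy as well.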

In [None]:
# ===== Cell 44 =====
# Integrity re-check & auto-restore from backup

from pathlib import Path
import pandas as pd
import hashlib

def _sha256(p: Path) -> str:
    if not p.exists(): return ""
    h = hashlib.sha256()
    with open(p, 'rb') as f:
        for chunk in iter(lambda: f.read(1<<20), b''):
            h.update(chunk)
    return h.hexdigest()

def _safe_rows(p: Path) -> int:
    """Approximate data-row count: line count minus the header (faster than parsing)."""
    if not p.exists(): return -1
    try:
        with open(p, 'rb') as fh:
            return sum(1 for _ in fh) - 1
    except Exception:
        try:
            return len(pd.read_csv(p, sep='\t', usecols=['accession']))
        except Exception:
            return -1

current_rows = _safe_rows(EXON_CACHE_TSV)
current_sha = _sha256(EXON_CACHE_TSV) if EXON_CACHE_TSV.exists() else ""

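# Find latest manual or auto backup. Cell 42 writes its backups to CACHE_DIR,
# so search CACHE_DIR as well as PROCESSED_DIR (whichever are defined).
_bak_dirs = {Path(d) for d in (globals().get('CACHE_DIR'), globals().get('PROCESSED_DIR')) if d is not None}
baks = sorted(
    [p for d in _bak_dirs
       for pattern in ("raw_exons_cache_backup_*.tsv", "raw_exons_cache_autobak_*.tsv")
       for p in d.glob(pattern)],
    key=lambda p: p.stat().st_mtime, reverse=True)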

import shutil

if not EXON_CACHE_TSV.exists():
    assert baks, "No cache and no backups found – cannot restore."
    logger.warning(f"Cache missing. Restoring from {baks[0].name}")
    # Copy (not move) so the backup file is kept for future restores
    shutil.copy2(baks[0], EXON_CACHE_TSV)
    logger.info(f"✅ Restored cache to {EXON_CACHE_TSV.name}")
else:
    logger.info(f"Current cache: rows≈{current_rows}, sha256={current_sha[:16]}...")
    if baks:
        bak_rows = _safe_rows(baks[0])
        logger.info(f"Latest backup: {baks[0].name} rows≈{bak_rows}")
        if (current_rows >= 0) and (bak_rows > current_rows):
            # Auto-restore if backup is larger (safer)
            logger.warning("Backup is larger than current cache → restoring backup.")
            tmp = EXON_CACHE_TSV.with_suffix(".restore.tmp")
            # copy via pandas to normalize newline/encoding issues if any
            pd.read_csv(baks[0], sep='\t', low_memory=False).to_csv(tmp, sep='\t', index=False)
            Path(tmp).replace(EXON_CACHE_TSV)
            logger.info("✅ Auto-restore complete.")
    else:
        logger.info("No backups found; keeping current cache unchanged.")


# **Part 5: Consensus & Data Restructuring**

## Cell 50 – PhylogeneticConsensusEngine (FAST MRCA-based)

This replaces pairwise genus–genus distances with a **single-pass MRCA scheme**:

- Pre-index **TimeTree** once for all genera we’ll need (fast lookups).
- Per group, compute weights as: `w(g) = 1 / (1 + dist(leaf_g, MRCA(group)))`, then normalize.
- Falls back to **uniform weights** if the tree or taxa are missing.
- Keeps a small JSON cache file for backwards compatibility (not required in FAST mode).

Inputs assumed from earlier cells:
- `DRIVE_TIMETREE_PATH` (optional; soft-fallback if missing)
- `CACHE_DIR` (for `timetree_distance_cache.json`)
- `logger`
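The weighting rule can be traced by hand on a toy tree. A pure-Python sketch that does not use ete3: the tree is a hypothetical child-to-parent table, and its branch lengths are invented placeholders, not TimeTree ages (deliberately not clock-like, so the weights differ):

```python
from typing import Dict, List

# Hypothetical toy tree: child -> (parent, branch length). Lengths are placeholders.
TREE = {
    "root": (None, 0.0),
    "mammals": ("root", 2.0),
    "Homo": ("mammals", 1.0),
    "Mus": ("mammals", 1.0),
    "Gallus": ("root", 1.0),
}


def path_to_root(node: str) -> List[str]:
    path = [node]
    while TREE[path[-1]][0] is not None:
        path.append(TREE[path[-1]][0])
    return path


def mrca(nodes: List[str]) -> str:
    # First ancestor of nodes[0] (walking up) that also lies on every other node's root path
    others = [set(path_to_root(n)) for n in nodes[1:]]
    for anc in path_to_root(nodes[0]):
        if all(anc in s for s in others):
            return anc
    raise ValueError("nodes share no ancestor")


def dist_to_ancestor(node: str, anc: str) -> float:
    d = 0.0
    while node != anc:
        parent, length = TREE[node]
        d += length
        node = parent
    return d


def mrca_weights(genera: List[str]) -> Dict[str, float]:
    uniq = sorted(set(genera))
    m = mrca(uniq)
    raw = {g: 1.0 / (1.0 + dist_to_ancestor(g, m)) for g in uniq}  # w(g) = 1/(1+d)
    total = sum(raw.values())
    return {g: w / total for g, w in raw.items()}                  # normalise to sum to 1


print(mrca_weights(["Homo", "Mus", "Homo"]))    # {'Homo': 0.5, 'Mus': 0.5}, MRCA = mammals
print(mrca_weights(["Homo", "Mus", "Gallus"]))  # {'Gallus': 0.5, 'Homo': 0.25, 'Mus': 0.25}, MRCA = root
```

If the tree is strictly ultrametric, as time-calibrated trees are intended to be, every leaf below a given ancestor is the same distance from it, so `w(g)` comes out equal for all resolved genera in a group and normalisation yields uniform weights; the weights only differ where branch lengths are not clock-like, as in the toy tree above.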


In [None]:
# ===== Cell 50 =====
# PhylogeneticConsensusEngine (FAST MRCA-based) + preindex support

import json

try:
    from ete3 import Tree  # optional, only used if TimeTree present
    _HAS_ETE_TREE = True
except Exception:
    _HAS_ETE_TREE = False
    Tree = None

PHYLO_CACHE_JSON = CACHE_DIR / "timetree_distance_cache.json"

class PhylogeneticConsensusEngine:
    """
    MRCA-based, single-pass weighting to avoid O(k^2) pairwise calls.
    - Build a name->node index once (only for needed genera).
    - Weights for a group = 1/(1 + dist(leaf, MRCA(group))).
    - Optional dist cache kept for compatibility; not required in FAST mode.
    """
    def __init__(self, nwk_path: Path, cache_json: Path):
        self.nwk_path = nwk_path
        self.cache_json = cache_json
        self.tree = None
        self._node_by_name = {}   # normalized_name -> TreeNode
        self._norm_cache = {}
        self.dist_cache = {}
        if cache_json.exists():
            try:
                with open(cache_json) as fh:
                    raw = json.load(fh)
                self.dist_cache = {tuple(k.split("|")): v for k, v in raw.items()}
            except Exception:
                self.dist_cache = {}
        if _HAS_ETE_TREE and nwk_path.exists():
            try:
                self.tree = Tree(str(nwk_path), format=1)
            except Exception:
                self.tree = None

    def _norm(self, name: str) -> str:
        if name in self._norm_cache:
            return self._norm_cache[name]
        n = (name or "").strip().replace(" ", "_")
        self._norm_cache[name] = n
        return n

    def preindex_genera(self, genera) -> None:
        """Build name->node dict once for all genera we’ll query."""
        if not self.tree:
            return
        needed = {self._norm(g) for g in genera if isinstance(g, str) and g.strip()}
        # quick leaf-name lookup table
        name2node = {}
        for node in self.tree.iter_leaves():
            if node.name:
                name2node[node.name] = node
        hit, miss = 0, 0
        for n in needed:
            node = name2node.get(n)
            if node is not None:
                self._node_by_name[n] = node
                hit += 1
            else:
                miss += 1
        logger.info(f"TimeTree preindex: {hit} genera resolved, {miss} not found.")

    def _node(self, genus: str):
        if not self.tree or not genus:
            return None
        return self._node_by_name.get(self._norm(genus))

    def mrca_node(self, genera):
        """MRCA of all resolvable genera in the group."""
        if not self.tree:
            return None
        nodes = [self._node(g) for g in genera]
        nodes = [n for n in nodes if n is not None]
        if len(nodes) == 0:
            return None
        if len(nodes) == 1:
            return nodes[0]
        try:
            return self.tree.get_common_ancestor(nodes)
        except Exception:
            return None

    def weights_from_mrca(self, genera):
        """
        Fast weights: for each genus g, weight = 1/(1 + dist(g, MRCA(genera))).
        Normalized to sum to 1. Uniform if tree/labels missing.
        """
        uniq = sorted({g for g in genera if isinstance(g, str) and g.strip()})
        if not self.tree or not uniq:
            return {g: 1.0/len(uniq) for g in uniq} if uniq else {}
        mrca = self.mrca_node(uniq)
        if mrca is None:
            return {g: 1.0/len(uniq) for g in uniq}
        weights = {}
        for g in uniq:
            node = self._node(g)
            if node is None:
                w = 1.0  # treat unknowns as close to MRCA to avoid zeroing them
            else:
                try:
                    d = float(self.tree.get_distance(node, mrca))
                except Exception:
                    d = 1.0
                w = 1.0 / (1.0 + d)
            weights[g] = w
        s = sum(weights.values())
        return {g: (w/s if s > 0 else 0.0) for g, w in weights.items()}

    def persist_cache(self):
        """Persist pairwise cache (legacy compatibility)."""
        try:
            payload = {"|".join(k): v for k, v in self.dist_cache.items()}
            with open(self.cache_json, "w") as fh:
                json.dump(payload, fh)
        except Exception:
            pass

phylo_engine = PhylogeneticConsensusEngine(DRIVE_TIMETREE_PATH, PHYLO_CACHE_JSON)
logger.info("PhylogeneticConsensusEngine ready "
            f"(tree={'OK' if phylo_engine.tree else 'fallback'})")


## Cell 51 – FAST weighted consensus (MRCA mode, chunked writes, detailed logging)
- Uses `weights_from_mrca()` (O(k) per group) → big speed-up.
- Logs progress every `PHYLO_LOG_INTERVAL` groups with ETA.
- Writes **cumulative checkpoint snapshots** every `CHUNK_SIZE` groups (each chunk file holds all rows so far) to avoid losing work.
- Falls back to **uniform** weights if TimeTree isn’t available.
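To see what the median-plus-tolerance step does to a single exon group, here is a small numeric sketch. The positions and weights are invented for illustration; `weighted_median` mirrors the function in Cell 51, and the boundary check mirrors its `boundary_ok` rule:

```python
import numpy as np


def weighted_median(values: np.ndarray, weights: np.ndarray) -> float:
    # Same rule as Cell 51: sort, then take the first value whose cumulative weight reaches 50%
    idx = np.argsort(values)
    v, w = values[idx], weights[idx]
    c = np.cumsum(w) / np.sum(w)
    j = min(int(np.searchsorted(c, 0.5)), len(v) - 1)
    return float(v[j])


# Toy exon group: begin positions from four sequences and their normalised genus weights
begins = np.array([101.0, 101.0, 104.0, 130.0])  # 130 plays the role of a misplaced boundary
weights = np.array([0.4, 0.3, 0.2, 0.1])

cons = int(round(weighted_median(begins, weights)))
tol = 3  # the fallback tolerance adaptive_tolerance() returns when no depth proxy exists
flags = [abs(int(b) - cons) <= tol for b in begins]
print(cons, flags)  # 101 [True, True, True, False]
```

The weighted median depends only on the ordering and the weights, not on how far the outlier sits; a weighted mean of the same values would be 104.5.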


In [None]:
# ===== Cell 51 =====
# FAST weighted consensus (MRCA-mode), preindex, chunked writes, detailed logging

# --- Safety preamble (so Cell 51 can run standalone) ---
import time
from pathlib import Path

import numpy as np
import pandas as pd

if 'PHYLO_FAST_MODE' not in globals():
    PHYLO_FAST_MODE = True
if 'PHYLO_LOG_INTERVAL' not in globals():
    PHYLO_LOG_INTERVAL = 500
if 'CHUNK_SIZE' not in globals():
    CHUNK_SIZE = 1000
if 'MAX_GROUPS' not in globals():
    MAX_GROUPS = -1  # process all groups by default

# If phylo_engine wasn't built (Cell 50 not run), fall back to uniform weights
if 'phylo_engine' not in globals():
    class _DummyPhylo:
        tree = None
        def preindex_genera(self, genera):  # no-op
            pass
        def weights_from_mrca(self, genera):
            uniq = sorted({g for g in genera if isinstance(g, str) and g.strip()})
            return {u: 1.0/len(uniq) for u in uniq} if uniq else {}
        def persist_cache(self):
            pass
    phylo_engine = _DummyPhylo()
    logger.info("⚠️ Phylo engine not initialised; using uniform weights (no TimeTree).")

def weighted_median(values: np.ndarray, weights: np.ndarray) -> float:
    if len(values) == 0:
        return float("nan")
    idx = np.argsort(values)
    v = values[idx]; w = weights[idx]
    c = np.cumsum(w) / np.sum(w)
    j = min(np.searchsorted(c, 0.5), len(v)-1)
    return float(v[j])

def adaptive_tolerance(depth_proxy: float) -> int:
    """
    We use an approximate 'depth proxy' (see below) to assign tolerance:
      0.0 → 2 AA  …  ≥3.0 → 5 AA (linear mapping)
    If unavailable, return 3 AA.
    """
    if depth_proxy is None or (isinstance(depth_proxy, float) and np.isnan(depth_proxy)):
        return 3
    x = float(np.clip(depth_proxy, 0.0, 3.0))
    return int(round(2.0 + x * (3.0/3.0)))  # 2..5

consensus_rows = []
refined_rows = []

start = time.perf_counter()
groups_done = 0

if EXON_CACHE_TSV.exists():
    base = pd.read_csv(EXON_CACHE_TSV, sep='\t', low_memory=False)
    need = ['accession','gene_symbol','organism','exon_num_in_chain','begin_aa','end_aa','peptide']
    base = base[[c for c in need if c in base.columns]].copy()

    # Join cluster_genus from working_df if available (one row per accession to avoid fan-out)
    if 'working_df' in globals() and 'cluster_genus' in working_df.columns:
        base = base.merge(
            working_df[['Entry','cluster_genus']]
                .drop_duplicates('Entry')
                .rename(columns={'Entry':'accession'}),
            on='accession', how='left'
        )
        base['cluster_genus'] = base['cluster_genus'].fillna(
            base['organism'].astype(str).str.split().str[0]
        )
    else:
        base['cluster_genus'] = base['organism'].astype(str).str.split().str[0]

    # Preindex all genera for this run (fast MRCA lookups)
    try:
        all_genera = base['cluster_genus'].astype(str).tolist()
        phylo_engine.preindex_genera(all_genera)
    except Exception:
        pass

    grouped = base.groupby(['gene_symbol','exon_num_in_chain'], sort=False)
    total_groups = len(grouped)
    logger.info(f"Consensus: processing {total_groups} exon groups "
                f"(FAST mode={'ON' if PHYLO_FAST_MODE and getattr(phylo_engine, 'tree', None) else 'OFF'})")

    chunk_idx = 0
    for i, ((g, e), df) in enumerate(grouped):
        if MAX_GROUPS > 0 and i >= MAX_GROUPS:
            break

        genera = df['cluster_genus'].astype(str).tolist()
        if getattr(phylo_engine, 'tree', None) and PHYLO_FAST_MODE:
            weights_by_genus = phylo_engine.weights_from_mrca(genera)
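            # depth proxy: 1/w - 1. Weights are normalised to sum to 1, so this is not the
            # leaf-to-MRCA distance itself; for equal weights it equals (#genera - 1), i.e. it
            # grows with how many genera the group spans and is used only as a relative proxy.
            ds = [(1.0/max(1e-9, w) - 1.0) for w in weights_by_genus.values()]
            mean_depth = float(np.mean(ds)) if ds else float('nan')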
        else:
            uniq = sorted(set(genera))
            weights_by_genus = {u: 1.0/len(uniq) for u in uniq} if uniq else {}
            mean_depth = float('nan')

        tol = adaptive_tolerance(mean_depth)
        w = df['cluster_genus'].map(weights_by_genus).fillna(0.0).to_numpy()
        if w.sum() == 0:
            w = np.ones(len(df)) / max(1, len(df))

        b = df['begin_aa'].astype(float).to_numpy()
        epos = df['end_aa'].astype(float).to_numpy()

        cons_beg = weighted_median(b, w)
        cons_end = weighted_median(epos, w)
        cb, ce = int(round(cons_beg)), int(round(cons_end))

        consensus_rows.append({
            'gene_symbol': g,
            'exon_num_in_chain': e,
            'cons_begin': cb,
            'cons_end': ce,
            'tolerance_aa': int(tol),
            'depth_proxy': round(mean_depth, 3) if not np.isnan(mean_depth) else np.nan
        })

        for _, r in df.iterrows():
            db = int(r['begin_aa']) - cb
            de = int(r['end_aa'])   - ce
            refined_rows.append({
                'accession': r['accession'],
                'gene_symbol': g,
                'exon_num_in_chain': e,
                'begin_aa': int(r['begin_aa']),
                'end_aa': int(r['end_aa']),
                'peptide': r.get('peptide',''),
                'cluster_genus': r.get('cluster_genus',''),
                'cons_begin': cb,
                'cons_end': ce,
                'delta_begin': db,
                'delta_end': de,
                'boundary_ok': (abs(db) <= tol) and (abs(de) <= tol)
            })

        groups_done += 1
        if groups_done % PHYLO_LOG_INTERVAL == 0:
            elapsed = time.perf_counter() - start
            rate = groups_done / max(1e-9, elapsed)
            eta = (total_groups - groups_done) / max(1e-9, rate)
            logger.info(f"Consensus progress: {groups_done}/{total_groups} groups "
                        f"({rate:.2f} grp/s, ETA ~{eta/60:.1f} min)")

        if groups_done % CHUNK_SIZE == 0:
            # write checkpoints to guard against interrupts
            tmp_long = RUN_DIR / f"consensus_long_{RUN_ID}.chunk{chunk_idx}.tsv"
            tmp_tbl  = RUN_DIR / f"consensus_tbl_{RUN_ID}.chunk{chunk_idx}.tsv"
            pd.DataFrame(refined_rows).to_csv(tmp_long, sep='\t', index=False)
            pd.DataFrame(consensus_rows).to_csv(tmp_tbl,  sep='\t', index=False)
            logger.info(f"Checkpoint written: chunk {chunk_idx} "
                        f"(rows so far: {len(refined_rows)})")
            chunk_idx += 1

# Finalize outputs atomically
consensus_tbl  = pd.DataFrame(consensus_rows)
consensus_long = pd.DataFrame(refined_rows)
if not consensus_long.empty:
    tmp1 = CONSENSUS_LONG_SNAPSHOT.with_suffix(".tmp.tsv")
    tmp2 = CONSENSUS_LONG_SNAPSHOT.parent / f"consensus_tbl_{RUN_ID}.tmp.tsv"
    consensus_long.to_csv(tmp1, sep='\t', index=False)
    consensus_tbl.to_csv(tmp2,  sep='\t', index=False)
    Path(tmp1).replace(CONSENSUS_LONG_SNAPSHOT)
    Path(tmp2).replace(CONSENSUS_LONG_SNAPSHOT.parent / f"consensus_tbl_{RUN_ID}.tsv")
    # persist legacy cache (no-op in FAST mode)
    try:
        phylo_engine.persist_cache()
    except Exception:
        pass
    elapsed = time.perf_counter() - start
    fast_on = PHYLO_FAST_MODE and getattr(phylo_engine, 'tree', None) is not None
    logger.info(f"Consensus refined: groups={groups_done}, rows={len(consensus_long)}, "
                f"time={elapsed/60:.1f} min (FAST mode {'ON' if fast_on else 'OFF'})")
else:
    logger.info("No data for consensus refinement.")


## Cell 52 – MRCA Level & Reliability Scoring (robust lineage merge)

**Purpose.**  
Determine the deepest taxonomic level where exon boundaries are consistent (an MRCA-like
coherence check), and compute a 0–100 reliability score per exon group using
phylogenetic diversity, boundary consistency, and sequence quality.

**Inputs.**  
- `consensus_long`, `consensus_tbl` from **Cell 51** (contains `cons_begin`, `cons_end`, `tolerance_aa`, `depth_proxy`, and `boundary_ok`).  
- Lineage features from `working_df` (**Cell 24**): `Clas_id`, `Order`, `Family`, `cluster_genus`.  
- Optional quality from `EXON_CACHE_TSV`: per `(accession, exon_num_in_chain)` `quality_score`.  

**Method.**  
1. Merge `consensus_long` + `consensus_tbl`, then join lineage fields from `working_df`
   and `quality_score` from the exon cache. If `cluster_genus` is missing, it is created
   as `'NA'`.  
2. For each (`gene_symbol`, `exon_num_in_chain`) group:  
   - **MRCA level**: ranks present in the data are tested in order
     (`Class`→`Order`→`Family`→`Genus`). At each rank the rows are binned by that rank and
     the spread of exon length (SD of `end_aa - begin_aa`) is checked in every bin. The
     first rank, i.e. the deepest (most inclusive) one, at which all bins stay within
     `tolerance_aa` is reported; if no rank passes, `"None"` is returned.  
   - **Reliability score (0–100)**:
     \[
     100 \times \big(0.45 \cdot \text{phylo\_div} + 0.35 \cdot \text{ok\_frac} + 0.20 \cdot \text{seq\_q}\big) \times \text{tol\_factor}
     \]
     where:
     - `phylo_div = min(1, 0.2 * #genera + 0.1 * #orders)`  
     - `ok_frac = mean(boundary_ok)`  
     - `seq_q = mean(quality_score/100)` if available else `0.5`  
     - `tol_factor` slightly downweights wide tolerances: `1 - clip((tol-2)/3, 0, 1)*0.2`  
   - `depth_proxy` from **Cell 51** is carried for context but not directly in the score
     (tolerance already reflects phylogenetic spread).
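Plugging numbers into the formula helps calibrate what a given score means. A sketch that recomputes it from already-aggregated inputs; `reliability_score` is a hypothetical standalone helper (Cell 52's own function works on a DataFrame group), and the inputs are invented:

```python
import numpy as np


def reliability_score(n_genera: int, n_orders: int, boundary_ok: list,
                      quality_scores=None, tol: int = 3) -> float:
    """Recompute the Cell 52 score from its four ingredients (sketch)."""
    phylo_div = min(1.0, 0.2 * n_genera + 0.1 * n_orders)
    ok_frac = float(np.mean(boundary_ok)) if len(boundary_ok) else 0.0
    if quality_scores is None:
        seq_q = 0.5  # neutral default when the cache has no quality_score
    else:
        seq_q = float(np.clip(np.asarray(quality_scores, dtype=float) / 100.0, 0, 1).mean())
    tol_factor = 1.0 - float(np.clip((tol - 2) / 3.0, 0, 1)) * 0.2  # 1.0 at 2 AA, 0.8 at >=5 AA
    score = 100.0 * (0.45 * phylo_div + 0.35 * ok_frac + 0.20 * seq_q) * tol_factor
    return float(np.clip(score, 0, 100))


s = reliability_score(n_genera=3, n_orders=2, boundary_ok=[True, True, True, False], tol=3)
print(round(s, 1))  # 67.4
```

For this toy group (3 genera across 2 orders, 3 of 4 boundaries within tolerance, no quality column) the untempered score is 72.25; a 3 AA tolerance multiplies it by about 0.933, and at 5 AA or wider the factor bottoms out at 0.8.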

**Outputs.**  
- `mrca_df` with:  
  `['gene_symbol','exon_num_in_chain','MRCA_level','reliability_score','tolerance_aa','depth_proxy']`  
- Logged summary: number of rows and genes processed.  
- **Note**: Not written to disk by default. If needed, persist with:  
  `mrca_df.to_csv(OUTPUTS_PATH / f"mrca_reliability_{RUN_ID}.tsv", sep="\t", index=False)`

**Edge cases & safeguards.**  
- If a lineage column is absent, that level is skipped automatically.  
- If `consensus_long`/`consensus_tbl` are missing, the cell exits gracefully.  
- Grouping keys with `NaN` are coerced to string bins (e.g., `'NA'`) to avoid errors.
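The `'NA'` coercion only works reliably if missing values are filled before casting to string. A minimal pandas illustration with toy values; how `astype(str)` treats `NaN` may differ across pandas versions, so only the fill-first form is relied on:

```python
import numpy as np
import pandas as pd

order = pd.Series(["Primates", np.nan, "Rodentia"], dtype=object)

# Cast-then-fill: with pandas 2.x, astype(str) has already turned NaN into the string 'nan',
# so fillna('NA') finds nothing to fill and the missing rank gets its own 'nan' bin.
cast_first = order.astype(str).fillna("NA")

# Fill-then-cast: missing ranks always land in a single 'NA' bin.
fill_first = order.fillna("NA").astype(str)

print(cast_first.tolist())
print(fill_first.tolist())  # ['Primates', 'NA', 'Rodentia']
```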

**Downstream.**  
`mrca_df` can be merged in **Cell 54** (wide architecture) for annotation, reporting, and
filtering based on reliability thresholds.


In [None]:
# ===== Cell 52 =====
# MRCA level + reliability scoring (robust lineage merge)

def mrca_level_for_group(df: pd.DataFrame, tol: int) -> str:
    # Only use levels that are present
    levels = []
    if 'Clas_id' in df.columns: levels.append(('Clas_id','Class'))
    if 'Order'   in df.columns: levels.append(('Order','Order'))
    if 'Family'  in df.columns: levels.append(('Family','Family'))
    if 'cluster_genus' in df.columns: levels.append(('cluster_genus','Genus'))
    for col, label in levels:
        good = True
        # groupby needs non-null keys
        for _, sub in df.groupby(df[col].fillna('NA').astype(str)):
            dv = (sub['end_aa'].astype(float) - sub['begin_aa'].astype(float)).var()
            spread = float(np.sqrt(dv)) if pd.notna(dv) and dv > 0 else 0.0
            if spread > tol:
                good = False
                break
        if good:
            return label
    return "None"

def reliability_score_for_group(df: pd.DataFrame, tol: int) -> float:
    n_genera = df['cluster_genus'].nunique() if 'cluster_genus' in df.columns else 0
    n_orders = df['Order'].astype(str).nunique() if 'Order' in df.columns else 0
    phylo_div = min(1.0, 0.2 * n_genera + 0.1 * n_orders)
    ok_frac = float((df['boundary_ok'] == True).mean()) if 'boundary_ok' in df.columns else 0.0
    if 'quality_score' in df.columns:
        q = df['quality_score'].fillna(0).astype(float)
        seq_q = float(np.clip(q / 100.0, 0, 1).mean())
    else:
        seq_q = 0.5
    # Slightly relax score if tolerance is wide
    tol_factor = 1.0 - np.clip((tol - 2) / 3.0, 0, 1) * 0.2
    score = 100.0 * (0.45 * phylo_div + 0.35 * ok_frac + 0.20 * seq_q) * tol_factor
    return float(np.clip(score, 0, 100))

mrca_rows = []

if 'consensus_long' in globals() and not consensus_long.empty and 'consensus_tbl' in globals() and not consensus_tbl.empty:
    # ---- Build lineage LUT from working_df (one row per Entry to avoid row fan-out) ----
    _wdf = working_df if 'working_df' in globals() else pd.DataFrame(columns=['Entry'])
    lineage_cols = [c for c in ('Clas_id', 'Order', 'Family', 'cluster_genus') if c in _wdf.columns]
    lineage_lut = (_wdf[['Entry'] + lineage_cols].drop_duplicates('Entry')
                   if lineage_cols else pd.DataFrame(columns=['Entry']))

    # ---- Quality score LUT from exon cache (per accession/exon) ----
    q_lut = pd.DataFrame()
    if EXON_CACHE_TSV.exists():
        try:
            q_lut = pd.read_csv(
                EXON_CACHE_TSV, sep='\t', low_memory=False,
                usecols=lambda c: c in ('accession','exon_num_in_chain','quality_score')
            ).drop_duplicates(['accession','exon_num_in_chain'])
        except Exception:
            q_lut = pd.DataFrame()

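    # ---- Merge consensus_long + consensus_tbl ----
    # cons_begin/cons_end are already in consensus_long; drop them from the table to avoid _x/_y columns
    merged = consensus_long.merge(
        consensus_tbl.drop(columns=['cons_begin', 'cons_end'], errors='ignore'),
        on=['gene_symbol','exon_num_in_chain'], how='left')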

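    # ---- Join lineage & quality info ----
    if not lineage_lut.empty:
        lut = lineage_lut.rename(columns={'Entry':'accession'})
        # consensus_long already carries cluster_genus (Cell 51); keep only new columns so the
        # merge does not split it into cluster_genus_x / cluster_genus_y
        lut = lut[[c for c in lut.columns if c == 'accession' or c not in merged.columns]]
        merged = merged.merge(lut, on='accession', how='left')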
    if not q_lut.empty and 'quality_score' not in merged.columns:
        merged = merged.merge(q_lut, on=['accession','exon_num_in_chain'], how='left')

    # Ensure cluster_genus exists for scoring
    if 'cluster_genus' not in merged.columns:
        merged['cluster_genus'] = 'NA'

    # Compute MRCA level + reliability per exon group
    for (g, e), sub in merged.groupby(['gene_symbol','exon_num_in_chain']):
        tol = int(sub['tolerance_aa'].iloc[0]) if 'tolerance_aa' in sub.columns else 3
        mrca = mrca_level_for_group(sub, tol)
        rel  = reliability_score_for_group(sub, tol)
        depth = sub['depth_proxy'].iloc[0] if 'depth_proxy' in sub.columns else np.nan
        mrca_rows.append({
            'gene_symbol': g,
            'exon_num_in_chain': e,
            'MRCA_level': mrca,
            'reliability_score': round(rel, 1),
            'tolerance_aa': tol,
            'depth_proxy': depth
        })

    mrca_df = pd.DataFrame(mrca_rows)
    logger.info(f"MRCA/reliability computed rows: {len(mrca_df)} "
                f"(genes={mrca_df['gene_symbol'].nunique() if not mrca_df.empty else 0})")
else:
    mrca_df = pd.DataFrame()
    logger.info("No consensus tables available; run Cell 51 first.")


## Cell 54 – Final wide architecture (peptides + coords + status)

Builds a wide table per accession including exon peptides, AA coordinates,
boundary flags, and consensus coordinates.
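For moderate data sizes, the same long-to-wide reshape can be written with `DataFrame.pivot` instead of a per-row loop. A sketch on invented toy data, assuming the `consensus_long` column names from Cell 51; it is an alternative, not what Cell 54 does:

```python
import pandas as pd

# Toy long-format rows with the column names used by consensus_long (values invented)
long_df = pd.DataFrame({
    "accession":         ["A1", "A1", "B2"],
    "gene_symbol":       ["COL1A1", "COL1A1", "COL1A1"],
    "exon_num_in_chain": [1, 3, 1],
    "peptide":           ["GPPGPP", "GAPGER", "GPPGAP"],
    "begin_aa":          [10, 16, 12],
    "end_aa":            [15, 21, 17],
    "boundary_ok":       [True, False, True],
})

# Derive the two string fields Cell 54 builds per exon
long_df = long_df.assign(
    coords=long_df["begin_aa"].astype(str) + "-" + long_df["end_aa"].astype(str),
    status=long_df["boundary_ok"].map({True: "OK", False: "FLAG"}),
)

# One row per (accession, gene); raises ValueError if an exon appears twice for a pair
wide = long_df.pivot(index=["accession", "gene_symbol"],
                     columns="exon_num_in_chain",
                     values=["peptide", "coords", "status"])

# Flatten the (field, exon) column MultiIndex into exon_{k}_{field} names
wide.columns = [f"exon_{k}_{field}" for field, k in wide.columns]
wide = wide.reset_index()
print(wide[["accession", "exon_1_peptide", "exon_3_coords", "exon_3_status"]])
```

One behavioural difference: the loop in Cell 54 silently keeps the last row when an (accession, gene, exon) combination occurs twice, whereas `pivot` raises `ValueError`, which surfaces the duplicate instead.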

In [None]:
# ===== Cell 54 - Enhanced Wide Architecture with Memory Management =====
# Memory-optimized wide DataFrame creation with monitoring and chunked processing

import gc
from typing import List, Dict, Any
import warnings

def monitor_memory_usage(operation_name: str, threshold_mb: float = 500.0) -> Dict[str, float]:
    """Monitor memory usage and warn if above threshold."""
    memory_info = get_memory_usage()

    if memory_info['rss_mb'] > threshold_mb:
        logger.warning(f"{operation_name}: High memory usage detected - {memory_info['rss_mb']:.1f}MB ({memory_info['percent']:.1f}%)")
    else:
        logger.info(f"{operation_name}: Memory usage - {memory_info['rss_mb']:.1f}MB ({memory_info['percent']:.1f}%)")

    return memory_info

def estimate_dataframe_memory(num_rows: int, num_cols: int, avg_string_length: int = 50) -> float:
    """Estimate memory usage for a DataFrame in MB."""
    # Rough estimation: each cell ~8 bytes + string overhead
    estimated_bytes = num_rows * num_cols * (8 + avg_string_length)
    return estimated_bytes / 1024 / 1024

def create_wide_architecture_chunked(consensus_long: pd.DataFrame,
                                   consensus_tbl: pd.DataFrame,
                                   mrca_df: pd.DataFrame,
                                   chunk_size: int = 1000) -> pd.DataFrame:
    """
    Create wide architecture DataFrame using chunked processing to manage memory.

    Args:
        consensus_long: Long-format consensus data
        consensus_tbl: Consensus table for annotation
        mrca_df: MRCA data for merging
        chunk_size: Number of (accession, gene_symbol) groups to process per chunk

    Returns:
        Wide-format DataFrame or empty DataFrame if processing fails
    """

    if consensus_long.empty:
        logger.warning("Empty consensus_long DataFrame - returning empty result")
        return pd.DataFrame()

    # Initial memory check
    initial_memory = monitor_memory_usage("Before wide DataFrame creation")

    # Estimate memory requirements
    unique_accessions = consensus_long['accession'].nunique()
    unique_exons = consensus_long['exon_num_in_chain'].nunique()
    estimated_cols = unique_exons * 3 + 10  # 3 cols per exon + metadata

    estimated_memory = estimate_dataframe_memory(unique_accessions, estimated_cols)
    logger.info(f"Estimated memory for wide DataFrame: {estimated_memory:.1f}MB")

    if estimated_memory > 1000:  # >1GB
        logger.warning("Large DataFrame expected - consider reducing the input data; chunking limits "
                       "per-chunk overhead but not the size of the final table")

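    # Prepare annotation data (mrca_df may be empty if Cell 52 found no consensus data)
    try:
        if mrca_df is not None and not mrca_df.empty:
            # tolerance_aa/depth_proxy already come from consensus_tbl; avoid _x/_y duplicates
            extra = mrca_df.drop(columns=['tolerance_aa', 'depth_proxy'], errors='ignore')
            annot = consensus_tbl.merge(extra, on=['gene_symbol', 'exon_num_in_chain'], how='left')
        else:
            annot = consensus_tbl.copy()
        logger.info(f"Created annotation table with {len(annot)} rows")
    except Exception as e:
        logger.error(f"Failed to create annotation table: {e}")
        return pd.DataFrame()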

    # Group by accession and gene_symbol for chunked processing
    grouped = consensus_long.groupby(['accession', 'gene_symbol'])
    total_groups = len(grouped)

    logger.info(f"Processing {total_groups} accession-gene combinations in chunks of {chunk_size}")

    # Process in chunks to manage memory
    all_rows = []
    processed_count = 0

    # Get all group keys for chunking
    group_keys = list(grouped.groups.keys())

    for chunk_start in range(0, len(group_keys), chunk_size):
        chunk_end = min(chunk_start + chunk_size, len(group_keys))
        chunk_keys = group_keys[chunk_start:chunk_end]

        chunk_rows = []
        chunk_memory_start = get_memory_usage()

        logger.info(f"Processing chunk {chunk_start//chunk_size + 1}/{(len(group_keys)-1)//chunk_size + 1} "
                   f"({len(chunk_keys)} groups)")

        for (acc, g) in chunk_keys:
            try:
                sub = grouped.get_group((acc, g))
                row = {'accession': acc, 'gene_symbol': g}

                # Process each exon for this accession-gene combination
                for _, r in sub.iterrows():
                    try:
                        k = int(r['exon_num_in_chain'])

                        # Store peptide, coordinates, and status
                        row[f'exon_{k}_peptide'] = r.get('peptide', '')
                        row[f'exon_{k}_coords'] = f"{int(r['begin_aa'])}-{int(r['end_aa'])}"
                        row[f'exon_{k}_status'] = 'OK' if r.get('boundary_ok', False) else 'FLAG'

                    except (ValueError, KeyError) as e:
                        logger.debug(f"Error processing exon for {acc}: {e}")
                        continue

                chunk_rows.append(row)
                processed_count += 1

            except Exception as e:
                logger.warning(f"Error processing group ({acc}, {g}): {e}")
                continue

        # Accumulate this chunk's rows; they are converted to a DataFrame once, after the loop
        if chunk_rows:
            all_rows.extend(chunk_rows)

            # Memory monitoring for chunk
            chunk_memory_end = get_memory_usage()
            chunk_memory_used = chunk_memory_end['rss_mb'] - chunk_memory_start['rss_mb']

            if chunk_memory_used > 50:  # >50MB per chunk
                logger.warning(f"Chunk used {chunk_memory_used:.1f}MB - consider smaller chunk size")

            # Periodic garbage collection
            if (chunk_start // chunk_size + 1) % 5 == 0:  # Every 5 chunks
                gc.collect()
                logger.debug("Performed garbage collection")

        # Progress reporting
        progress = (chunk_end / len(group_keys)) * 100
        logger.info(f"Progress: {progress:.1f}% ({processed_count}/{total_groups} groups processed)")

    # Create final DataFrame
    if not all_rows:
        logger.warning("No valid rows created - returning empty DataFrame")
        return pd.DataFrame()

    try:
        logger.info(f"Creating final wide DataFrame from {len(all_rows)} rows")
        wide_df = pd.DataFrame(all_rows)

        # Memory check after DataFrame creation
        post_creation_memory = monitor_memory_usage("After wide DataFrame creation")
        memory_increase = post_creation_memory['rss_mb'] - initial_memory['rss_mb']

        logger.info(f"Wide DataFrame created: {wide_df.shape} (Memory increase: {memory_increase:.1f}MB)")

    except MemoryError:
        logger.error("MemoryError creating wide DataFrame - try smaller chunk size or filtering data")
        return pd.DataFrame()
    except Exception as e:
        logger.error(f"Error creating wide DataFrame: {e}")
        return pd.DataFrame()

    # Add consensus coordinates efficiently
    try:
        logger.info("Adding consensus coordinates...")
        cons_coords = {}

        for _, r in annot.iterrows():
            try:
                k = int(r['exon_num_in_chain'])
                gsym = r['gene_symbol']
                cons_coords[(gsym, k)] = f"{int(r['cons_begin'])}-{int(r['cons_end'])}"
            except (ValueError, KeyError):
                continue

        # Add consensus coordinate columns efficiently
        for (gsym, k), coord in cons_coords.items():
            col = f'cons_exon_{k}_coords'
            if col not in wide_df.columns:
                wide_df[col] = ''

            # Use boolean indexing for efficiency
            mask = wide_df['gene_symbol'] == gsym
            wide_df.loc[mask, col] = coord

        logger.info(f"Added consensus coordinates for {len(cons_coords)} exon-gene combinations")

    except Exception as e:
        logger.warning(f"Error adding consensus coordinates: {e}")
        # Continue without consensus coordinates rather than failing completely

    # Final memory and performance summary
    final_memory = monitor_memory_usage("Final wide DataFrame complete")
    total_memory_increase = final_memory['rss_mb'] - initial_memory['rss_mb']

    logger.info("Wide DataFrame processing complete:")
    logger.info(f"  Final shape: {wide_df.shape}")
    logger.info(f"  Total memory increase: {total_memory_increase:.1f}MB")
    logger.info(f"  DataFrame footprint (deep): {wide_df.memory_usage(deep=True).sum() / 1024 / 1024:.1f}MB")

    # Memory optimization suggestions
    if total_memory_increase > 200:
        logger.info("💡 Memory optimization suggestions:")
        logger.info("  - Consider filtering to fewer species/genes")
        logger.info("  - Reduce chunk_size parameter")
        logger.info("  - Use dtype optimization for string columns")

    return wide_df

# ===== Main Cell 54 Execution =====

logger.info("🔄 Creating final wide architecture...")

wide_df = pd.DataFrame()

if not consensus_long.empty:
    try:
        # Build wide_df inline (single pass, without the chunked memory monitoring above)

        # Build annotation table
        annot = consensus_tbl.merge(mrca_df, on=['gene_symbol','exon_num_in_chain'], how='left')

        # Build wide DataFrame rows
        rows = []
        for (acc, g), sub in consensus_long.groupby(['accession','gene_symbol']):
            row = {'accession': acc, 'gene_symbol': g}
            for _, r in sub.iterrows():
                k = int(r['exon_num_in_chain'])
                row[f'exon_{k}_peptide'] = r.get('peptide','')
                row[f'exon_{k}_coords']  = f"{int(r['begin_aa'])}-{int(r['end_aa'])}"
                row[f'exon_{k}_status']  = 'OK' if r.get('boundary_ok', False) else 'FLAG'
            rows.append(row)

        wide_df = pd.DataFrame(rows)

        # Add consensus coordinates
        cons_coords = {}
        for _, r in annot.iterrows():
            k = int(r['exon_num_in_chain'])
            gsym = r['gene_symbol']
            cons_coords[(gsym, k)] = f"{int(r['cons_begin'])}-{int(r['cons_end'])}"

        for (gsym, k), coord in cons_coords.items():
            col = f'cons_exon_{k}_coords'
            if col not in wide_df.columns:
                wide_df[col] = ''
            wide_df.loc[wide_df['gene_symbol']==gsym, col] = coord

        # ===== Save using canonical paths =====
        if not wide_df.empty:
            import json
            from pathlib import Path

            # Save to the canonical output path (replaces the old WIDE_ARCH_SNAPSHOT)
            wide_df.to_csv(EXON_WIDE_TSV, sep='\t', index=False)
            logger.info(f"✅ Final wide architecture rows: {len(wide_df)} (saved to {EXON_WIDE_TSV})")

            # Also save a copy in the run directory for provenance.
            # Path() makes the "/" join work whether RUN_DIR is a str or a pathlib.Path.
            run_dir = Path(RUN_DIR)
            run_wide_path = run_dir / f"wide_architecture_{RUN_ID}.tsv"
            wide_df.to_csv(run_wide_path, sep='\t', index=False)
            logger.info(f"📁 Run copy saved to: {run_wide_path}")

            # Optional: save summary statistics
            try:
                stats = {
                    'total_rows': len(wide_df),
                    'unique_accessions': wide_df['accession'].nunique(),
                    'unique_genes': wide_df['gene_symbol'].nunique(),
                    'total_columns': len(wide_df.columns),
                    'exon_columns': len([col for col in wide_df.columns if 'exon_' in col and '_peptide' in col])
                }

                stats_path = run_dir / f"wide_architecture_stats_{RUN_ID}.json"
                with open(stats_path, 'w') as f:
                    json.dump(stats, f, indent=2)
                logger.info(f"📊 Architecture statistics saved: {stats_path}")

            except Exception as stats_error:
                logger.warning(f"Could not save architecture statistics: {stats_error}")

        else:
            logger.warning("❌ Wide DataFrame is empty - no output generated")

    except Exception as e:
        logger.error(f"❌ Error creating wide architecture: {e}")
        wide_df = pd.DataFrame()

else:
    logger.warning("⚠️ No consensus_long data available; skipping final wide output.")

# Final memory cleanup if using memory management
if 'gc' in globals():
    gc.collect()

logger.info("✅ Cell 54 wide architecture processing complete")

2025-08-22 05:44:55,758 [INFO] - 🔄 Creating final wide architecture with memory management...
2025-08-22 05:44:56,302 [INFO] - Estimated memory for wide DataFrame: 205.6MB
2025-08-22 05:44:56,328 [INFO] - Created annotation table with 1150 rows
2025-08-22 05:44:58,730 [INFO] - Processing 14029 accession-gene combinations in chunks of 500
2025-08-22 05:44:58,741 [INFO] - Processing chunk 1/29 (500 groups)
2025-08-22 05:45:00,376 [INFO] - Progress: 3.6% (500/14029 groups processed)
2025-08-22 05:45:00,379 [INFO] - Processing chunk 2/29 (500 groups)
2025-08-22 05:45:01,408 [INFO] - Progress: 7.1% (1000/14029 groups processed)
2025-08-22 05:45:01,410 [INFO] - Processing chunk 3/29 (500 groups)
2025-08-22 05:45:02,381 [INFO] - Progress: 10.7% (1500/14029 groups processed)
2025-08-22 05:45:02,384 [INFO] - Processing chunk 4/29 (500 groups)
2025-08-22 05:45:03,352 [INFO] - Progress: 14.3% (2000/14029 groups processed)
2025-08-22 05:45:03,354 [INFO] - Processing chunk 5/29 (500 groups)
2025-08

In [None]:
# # ===== Cell 54 =====
# # Final wide architecture with coordinates & flags
# wide_df = pd.DataFrame()
# if not consensus_long.empty:
#     annot = consensus_tbl.merge(mrca_df, on=['gene_symbol','exon_num_in_chain'], how='left')
#     rows = []
#     for (acc, g), sub in consensus_long.groupby(['accession','gene_symbol']):
#         row = {'accession': acc, 'gene_symbol': g}
#         for _, r in sub.iterrows():
#             k = int(r['exon_num_in_chain'])
#             row[f'exon_{k}_peptide'] = r.get('peptide','')
#             row[f'exon_{k}_coords']  = f"{int(r['begin_aa'])}-{int(r['end_aa'])}"
#             row[f'exon_{k}_status']  = 'OK' if r.get('boundary_ok', False) else 'FLAG'
#         rows.append(row)
#     wide_df = pd.DataFrame(rows)
#     cons_coords = {}
#     for _, r in annot.iterrows():
#         k = int(r['exon_num_in_chain']); gsym = r['gene_symbol']
#         cons_coords[(gsym, k)] = f"{int(r['cons_begin'])}-{int(r['cons_end'])}"
#     for (gsym, k), coord in cons_coords.items():
#         col = f'cons_exon_{k}_coords';
#         if col not in wide_df.columns: wide_df[col] = ''
#         wide_df.loc[wide_df['gene_symbol']==gsym, col] = coord
#     if not wide_df.empty:
#         wide_df.to_csv(WIDE_ARCH_SNAPSHOT, sep='\t', index=False)
#         logger.info(f"Final wide architecture rows: {len(wide_df)} (saved)")
# else:
#     logger.info("No consensus_long; skipping final wide output.")

# **Part 6: RegExTractor — Exon‑centric rescue (replaces old Part 6)**

This Part **replaces the previous frameshift/DNA rescue logic** with the
exon‑ and clade‑aware **RegExTractor** engine. It **learns tiered regex
patterns per exon** from your **passed, high‑quality** mappings (Parts 1–5),
then scans the **not‑yet‑mapped** pool to recover additional COL1A1/COL1A2
exons. It is triplet‑aware (G‑X‑Y), entropy‑aware, and chain‑aware.

> Integration notes
> - **Does not touch** the heavy exon coordinate mapping built in **Cell 41**.
> - Uses **`wide_df`** (your main passed output matrix) to derive the training set.
> - Uses **`GENE_SYMBOLS`** global; restricts to **COL1A1/COL1A2** here.
> - If you keep DNA utilities from the old Part 6, use them **after** regex anchoring.
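
Triplet-awareness here comes down to how consistently Gly occupies the first position of each G-X-Y triplet. The standalone sketch below mirrors the idea behind `rex_triplet_offset_best` and `rex_ghead_density` (defined in Cell 61); the names `best_triplet_offset` and `ghead_density` are illustrative only and are not part of the notebook.

```python
def best_triplet_offset(seq: str) -> int:
    """Frame (0, 1 or 2) that places the most Gly residues at triplet heads."""
    best_off, best_cnt = 0, -1
    for off in (0, 1, 2):
        cnt = sum(1 for i in range(off, len(seq), 3) if seq[i] == "G")
        if cnt > best_cnt:  # ties keep the earliest frame
            best_off, best_cnt = off, cnt
    return best_off


def ghead_density(seq: str) -> float:
    """Fraction of complete triplets (in the best frame) that start with Gly."""
    off = best_triplet_offset(seq)
    heads = range(off, off + 3 * ((len(seq) - off) // 3), 3)
    if len(heads) == 0:
        return 0.0
    return sum(seq[i] == "G" for i in heads) / len(heads)


print(ghead_density("GPPGAPGEP"))   # ideal G-X-Y repeat -> 1.0
print(ghead_density("PPGAPGEPG"))   # same repeat shifted by two residues -> 1.0
print(ghead_density("MKTAYIAKQ"))   # non-collagenous peptide -> 0.0
```

A density near 1.0 is what the training filter in Cell 61 (cutoff 0.70) and the per-tier `ghead_min` values in Cell 64 look for.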

---


## Cell 60 – RegExTractor configuration

Parameters for pattern learning and matching. Uses `GENE_SYMBOLS` from earlier
Parts. To focus on COL1A1/COL1A2 initially, we filter at call time.
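
The `rex_len_tolerance_*` values are counted in G-X-Y triplets, not residues. In Cell 64 a spacer region of `L` residues becomes a repeat quantifier from `L // 3 - tol` to `L // 3 + tol` triplets, and each tier's `len_range` is the median exon length ± `3 * tol` residues. The sketch below reproduces that arithmetic for a hypothetical 54-residue (18-triplet) exon with the default tolerances, treating the whole exon as one spacer region (the no-anchor case); `triplet_quantifier` is an illustrative helper, not a notebook function.

```python
from typing import Tuple

# Default tolerances from this cell, in triplets
TOLERANCES = {"A": 1, "B": 3, "C": 6}  # rex_len_tolerance_strict / _moderate / _loose


def triplet_quantifier(region_len_aa: int, tol: int) -> Tuple[int, int]:
    """Min/max G-X-Y repeat count for a spacer of region_len_aa residues."""
    base = region_len_aa // 3
    return max(base - tol, 0), base + tol


exon_len = 54  # hypothetical exon: 18 triplets, treated as a single spacer region
for tier, tol in TOLERANCES.items():
    m, n = triplet_quantifier(exon_len, tol)
    aa_range = (exon_len - 3 * tol, exon_len + 3 * tol)
    # X/Y are shown as "." for brevity; Tiers A and B use learned character classes instead
    print(f"Tier {tier}: (?:G..){{{m},{n}}}  len_range={aa_range}")
```

With these defaults Tier A accepts 17 to 19 triplets (51 to 57 residues), Tier B 15 to 21 (45 to 63) and Tier C 12 to 24 (36 to 72).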


In [None]:
# ===== Cell 60 =====
# RegExTractor configuration (Colab-friendly; non-breaking)

from __future__ import annotations
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Optional, Iterable, Any
import re
import math
import numpy as np
import pandas as pd

# --- User-tunable parameters (safe defaults). Adjust with Colab #@param if desired.
rex_min_clade_samples = 8                     #@param {type:"integer"}
rex_anchor_triplets_min = 2                   #@param {type:"integer"}
rex_anchor_triplets_max = 4                   #@param {type:"integer"}
rex_anchor_entropy_max = 0.25                 #@param {type:"number"}
rex_freq_threshold_strict = 0.05              #@param {type:"number"}
rex_freq_threshold_moderate = 0.01            #@param {type:"number"}
rex_ghead_density_min = 0.80                  #@param {type:"number"}
rex_len_tolerance_strict = 1                  #@param {type:"integer"}
rex_len_tolerance_moderate = 3                #@param {type:"integer"}
rex_len_tolerance_loose = 6                   #@param {type:"integer"}
rex_chain_min_consecutive = 3                 #@param {type:"integer"}
rex_search_window_pad = 90                    #@param {type:"integer"}
rex_enable_fullregex_fallback = True          #@param {type:"boolean"}

# Taxonomic levels in descending specificity (must match your columns if available)
REX_TAXON_LEVELS = ["genus", "family", "order", "class", "kingdom", "pan"]

def rex_log(msg: str):
    """Lightweight logger."""
    print(f"[RegExTractor] {msg}")


## Cell 61 – Adapters and input preparation from Parts 1–5 outputs

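This cell converts **your passed wide output** (`wide_df`) into the long-form **training**
table RegExTractor needs, keeping only exons with status `OK` whose peptide length is a
multiple of 3 and whose G-head density is at least 0.70. It also builds the
**rejected/unmapped** pool from your **unmapped sequences** table if available.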

**Expected sources from earlier Parts**
- `wide_df` — main passed matrix with columns:
  - `accession`, `gene_symbol`,
  - `exon_<N>_peptide`, `exon_<N>_coords`, `exon_<N>_status` for each exon number N,
  - optional taxonomy columns (`genus`, `family`, `order`, `class`, `kingdom`).
- `UNMAPPED_SEQUENCES_DF` (optional) — a table of sequences not yet in `wide_df`,
  with `accession`, `sequence`, and optional taxonomy columns.

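If `UNMAPPED_SEQUENCES_DF` is not defined, we will **only** scan rows in `wide_df`
that have **at most `rex_max_ok_exons_for_rejected` (default 6) OK exons** *and* carry a
full-length sequence column (the first of `sequence`, `protein_sequence`, `aa_sequence`,
`seq` or `Sequence` that exists); if no such column exists, the cell stops with a `ValueError`.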
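
The training table is a wide-to-long reshape: one row per passed exon instead of one row per accession. Cell 61 does this with an `iterrows` loop; a vectorised pandas sketch of the same reshape is shown below for comparison. It assumes `wide_df`-style column names, omits the G-head density filter and taxonomy columns for brevity, and `wide_to_training` is a hypothetical helper name.

```python
import re

import pandas as pd

EXON_PEP = re.compile(r"^exon_(-?\d+)_peptide$")
OUT_COLS = ["accession", "gene_symbol", "exon_num_in_chain", "exon_peptide"]


def wide_to_training(wide: pd.DataFrame) -> pd.DataFrame:
    """Stack exon_<N>_peptide/exon_<N>_status pairs into one row per exon; keep OK, triplet-periodic peptides."""
    frames = []
    for col in wide.columns:
        m = EXON_PEP.match(col)
        status_col = f"exon_{m.group(1)}_status" if m else None
        if not m or status_col not in wide.columns:
            continue
        part = (wide[["accession", "gene_symbol", col, status_col]]
                .rename(columns={col: "exon_peptide", status_col: "status"})
                .assign(exon_num_in_chain=int(m.group(1))))
        frames.append(part)
    if not frames:
        return pd.DataFrame(columns=OUT_COLS)
    long = pd.concat(frames, ignore_index=True)
    pep_len = long["exon_peptide"].fillna("").astype(str).str.len()
    keep = (long["status"] == "OK") & (pep_len > 0) & (pep_len % 3 == 0)
    return long.loc[keep, OUT_COLS].reset_index(drop=True)


wide = pd.DataFrame({
    "accession": ["P1", "P2"],
    "gene_symbol": ["COL1A1", "COL1A1"],
    "exon_6_peptide": ["GPPGAPGEP", "GPPGA"],   # P2 peptide length is not a multiple of 3
    "exon_6_status": ["OK", "OK"],
    "exon_8_peptide": ["GAPGPQGFQ", None],      # P1 exon 8 is flagged, P2 exon 8 is missing
    "exon_8_status": ["FLAG", None],
})
out = wide_to_training(wide)
print(out)
```

Only P1's exon 6 survives: the other three exon cells are either flagged, missing, or not triplet-periodic.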


In [None]:
# ===== Cell 61 =====
@dataclass
class RexColumnMap:
    gene_col: str = "gene_symbol"
    exon_col: str = "exon_num_in_chain"
    pep_col: str = "exon_peptide"
    seq_col: str = "sequence"
    acc_col: str = "accession"
    tax_cols: Dict[str, str] = field(default_factory=lambda: {
        "genus": "genus",
        "family": "family",
        "order": "order",
        "class": "class",
        "kingdom": "kingdom"
    })

def rex_normalize_training_df(df: pd.DataFrame, cmap: RexColumnMap) -> pd.DataFrame:
    """Map training columns to canonical names; drop rows lacking an exon number or peptide."""
    needed = {cmap.gene_col, cmap.exon_col, cmap.pep_col}
    missing = [c for c in needed if c not in df.columns]
    if missing:
        raise KeyError(f"Training DF missing columns: {missing}")
    out = pd.DataFrame({
        "gene_symbol": df[cmap.gene_col].astype(str),
        "exon_num_in_chain": pd.to_numeric(df[cmap.exon_col], errors="coerce").astype("Int64"),
        # fillna("") so missing peptides become empty (not the string "nan") and are dropped below
        "exon_peptide": df[cmap.pep_col].fillna("").astype(str)
    })
    for lvl in ["genus","family","order","class","kingdom"]:
        col = cmap.tax_cols.get(lvl)
        out[lvl] = df[col].fillna("").astype(str) if (col in df.columns) else ""
    out = out.dropna(subset=["exon_num_in_chain"])
    out = out[out["exon_peptide"].str.len() > 0].copy()
    return out

def rex_normalize_rejected_df(df: pd.DataFrame, cmap: RexColumnMap) -> pd.DataFrame:
    """Map unmapped-pool columns to canonical names; drop rows without a sequence."""
    needed = {cmap.acc_col, cmap.seq_col}
    missing = [c for c in needed if c not in df.columns]
    if missing:
        raise KeyError(f"Rejected DF missing columns: {missing}")
    out = pd.DataFrame({
        "accession": df[cmap.acc_col].astype(str),
        # fillna("") so missing sequences are dropped instead of becoming the string "nan"
        "sequence": df[cmap.seq_col].fillna("").astype(str)
    })
    for lvl in ["genus","family","order","class","kingdom"]:
        col = cmap.tax_cols.get(lvl)
        out[lvl] = df[col].fillna("").astype(str) if (col in df.columns) else ""
    out = out[out["sequence"].str.len() > 0].copy()
    return out

# --- Derive training from PASSED wide_df ---
if 'wide_df' not in globals():
    raise NameError("`wide_df` not found. Run Parts 1–5 to produce the passed matrix before Part 6.")

# Use entropy stats if present (optional hook)
entropy_stats_df = None
for cand in ["entropy_stats_df", "ENTROPY_STATS_DF"]:
    if cand in globals() and isinstance(globals()[cand], pd.DataFrame):
        entropy_stats_df = globals()[cand]
        break

# Identify exon indices present
exon_pat = re.compile(r"^exon_(-?\d+)_peptide$")
exon_indices = sorted({int(m.group(1)) for c in wide_df.columns for m in [exon_pat.match(c)] if m})

# Count OK exons per row (used to filter rejected pool from wide_df if needed)
ok_cols = [f"exon_{n}_status" for n in exon_indices if f"exon_{n}_status" in wide_df.columns]
def _count_ok(row) -> int:
    return sum(1 for stc in ok_cols if row.get(stc, "") == "OK")
wide_df["_ok_exon_count"] = wide_df.apply(_count_ok, axis=1)

# Build long-form training table from PASSED exons (status == OK)
training_rows = []
tax_cols_available = [c for c in ["genus","family","order","class","kingdom"] if c in wide_df.columns]

def rex_triplet_offset_best(seq: str) -> int:
    best_off, best_cnt = 0, -1
    for off in (0,1,2):
        cnt = sum(1 for i,ch in enumerate(seq) if (i - off) % 3 == 0 and ch == "G")
        if cnt > best_cnt:
            best_cnt = cnt
            best_off = off
    return best_off

def rex_ghead_density(seq: str, offset: Optional[int] = None) -> float:
    if not seq:
        return 0.0
    off = rex_triplet_offset_best(seq) if offset is None else offset
    triplets = max((len(seq) - off) // 3, 0)
    if triplets == 0:
        return 0.0
    heads = [i for i in range(off, off + 3 * triplets, 3)]
    g_cnt = sum(1 for i in heads if seq[i] == "G")
    return g_cnt / len(heads) if heads else 0.0

for _, row in wide_df.iterrows():
    gene = row.get("gene_symbol", "")
    if not isinstance(gene, str) or len(gene) == 0:
        continue
    for n in exon_indices:
        pep_c = f"exon_{n}_peptide"
        st_c  = f"exon_{n}_status"
        pep = row.get(pep_c, "")
        st  = row.get(st_c, "")
        if st != "OK" or not isinstance(pep, str) or len(pep) == 0:
            continue
        # Keep training peptides that obey triplet periodicity and have decent G-head density
        if len(pep) % 3 != 0:
            continue
        if rex_ghead_density(pep) < 0.70:
            continue
        rec = {
            "gene_symbol": gene,
            "exon_num_in_chain": int(n),
            "exon_peptide": pep
        }
        for t in tax_cols_available:
            rec[t] = row.get(t, "")
        training_rows.append(rec)

mapped_exon_df = pd.DataFrame(training_rows)
if mapped_exon_df.empty:
    raise ValueError("Training set is empty. Check that wide_df contains exon_*_peptide with status == 'OK'.")

# Build rejected/unmapped pool
# Preferred source: user-provided UNMAPPED_SEQUENCES_DF (not present in wide_df)
if 'UNMAPPED_SEQUENCES_DF' in globals() and isinstance(UNMAPPED_SEQUENCES_DF, pd.DataFrame):
    rejected_df = UNMAPPED_SEQUENCES_DF.copy()
else:
    # Fallback: try to use rows in wide_df with few OK exons and that include full sequences
    candidate_seq_cols = [c for c in ["sequence","protein_sequence","aa_sequence","seq","Sequence"] if c in wide_df.columns]
    sequence_col = candidate_seq_cols[0] if candidate_seq_cols else None
    if sequence_col is None:
        raise ValueError(
            "No UNMAPPED_SEQUENCES_DF provided and no sequence column found in wide_df.\n"
            "Provide a DataFrame `UNMAPPED_SEQUENCES_DF` with columns ['accession','sequence', optional taxon cols]."
        )
    # Heuristic: treat rows with very low mapping success as 'rejected' to re-scan
    rex_max_ok_exons_for_rejected = 6  # conservative; adjust if needed
    rejected_rows = []
    for _, row in wide_df.iterrows():
        if row["_ok_exon_count"] <= rex_max_ok_exons_for_rejected:
            acc = row.get("accession", "")
            seq = row.get(sequence_col, "")
            if isinstance(acc, str) and acc and isinstance(seq, str) and len(seq) > 0:
                rr = {"accession": acc, "sequence": seq}
                for t in tax_cols_available:
                    rr[t] = row.get(t, "")
                rejected_rows.append(rr)
    rejected_df = pd.DataFrame(rejected_rows)
    if rejected_df.empty:
        raise ValueError(
            "Rejected pool is empty. Either supply UNMAPPED_SEQUENCES_DF or relax the ok-exon threshold."
        )

# Normalize for RegExTractor
cmap = RexColumnMap(
    gene_col="gene_symbol",
    exon_col="exon_num_in_chain",
    pep_col="exon_peptide",
    seq_col="sequence",
    acc_col="accession",
    tax_cols={k:k for k in ["genus","family","order","class","kingdom"] if k in tax_cols_available}
)
training_norm = rex_normalize_training_df(mapped_exon_df, cmap)
rejected_norm = rex_normalize_rejected_df(rejected_df, cmap)

rex_log(f"Training set: {training_norm.shape}, Rejected set: {rejected_norm.shape}")


## Cell 62 – Exon statistics & anchor discovery

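Shannon entropy and low-entropy anchor finding to seed compact regexes (the triplet
helpers `rex_triplet_offset_best` and `rex_ghead_density` are defined in Cell 61). An
anchor is a window of `rex_anchor_triplets_min` to `rex_anchor_triplets_max` triplets,
starting at a residue offset that is a multiple of 3, whose mean per-position entropy is
at most `rex_anchor_entropy_max` bits. The lowest-entropy window is kept, plus the
next-lowest window that does not overlap it, if any. (Triplet logic mirrors collagen
G-X-Y periodicity.)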
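
Per-position Shannon entropy, H = -sum(p * log2 p) over the residue frequencies in one alignment column, is 0 bits when every training peptide shares a residue and rises as the column diversifies. The stdlib sketch below illustrates the anchor idea: compute column entropies, then slide windows of `k` triplets in steps of 3 and keep the lowest mean. Unlike Cell 62 it truncates to the shortest peptide instead of padding to the median length, applies no entropy cutoff, and returns a single window; the peptides and function names are made up for illustration.

```python
import math
from collections import Counter
from typing import List, Optional, Tuple


def column_entropy(column: List[str]) -> float:
    """Shannon entropy in bits of the residues observed at one alignment position."""
    counts = Counter(column)
    total = sum(counts.values())
    return sum((c / total) * math.log2(total / c) for c in counts.values())


def lowest_entropy_window(seqs: List[str], k: int) -> Optional[Tuple[int, int, float]]:
    """(start, end, mean entropy) of the least variable k-triplet window, stepping by 3."""
    L = min(len(s) for s in seqs)      # simplification: truncate to the shortest peptide
    w = 3 * k
    if L < w:
        return None
    ent = [column_entropy([s[i] for s in seqs]) for i in range(L)]
    windows = [(i, i + w, sum(ent[i:i + w]) / w) for i in range(0, L - w + 1, 3)]
    return min(windows, key=lambda t: t[2])


peptides = ["GPPGAPGEPGPQ", "GPPGSPGKPGAE", "GPPGTPGRPGLS"]
print([round(column_entropy([p[i] for p in peptides]), 3) for i in range(12)])
print(lowest_entropy_window(peptides, k=1))   # the conserved leading GPP triplet: (0, 3, 0.0)
```

Columns where all three peptides agree score 0.0; the X positions holding three different residues score log2(3), about 1.585 bits.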


In [None]:
# ===== Cell 62 =====
def rex_shannon_entropy(chars: Iterable[str]) -> float:
    arr = np.array(list(chars))
    if arr.size == 0:
        return 0.0
    vals, cnts = np.unique(arr, return_counts=True)
    p = cnts / cnts.sum()
    p = np.clip(p, 1e-12, 1.0)
    return float(-(p * np.log2(p)).sum())

def rex_find_anchor_windows(
    seqs: List[str],
    k_min: int = rex_anchor_triplets_min,
    k_max: int = rex_anchor_triplets_max,
    entropy_max: float = rex_anchor_entropy_max
) -> List[Tuple[int, int]]:
    """Return up to two non-overlapping low-entropy windows (AA indices) aligned to triplets."""
    if not seqs:
        return []
    # Windows start at residue 0 and step by 3, so peptides are assumed to begin on a
    # triplet head (Gly); per-peptide frame offsets are not applied here.
    L = int(np.median([len(s) for s in seqs]))
    # Per-position (column) entropy; peptides are padded with "-" or truncated to length L
    cols: List[List[str]] = [[] for _ in range(L)]
    for s in seqs:
        ss = (s + "-" * max(0, L - len(s)))[:L]
        for i, ch in enumerate(ss):
            cols[i].append(ch)
    ent = np.array([rex_shannon_entropy(c) for c in cols], dtype=float)
    # Scan triplet windows for low-entropy runs
    candidates: List[Tuple[int,int,float]] = []
    for k in range(k_min, k_max + 1):
        w = 3 * k
        for i in range(0, L - w + 1, 3):
            e = float(ent[i:i+w].mean())
            if e <= entropy_max:
                candidates.append((i, i+w, e))
    candidates.sort(key=lambda x: x[2])
    anchors: List[Tuple[int,int]] = []
    for s, e, _ in candidates:
        if not anchors:
            anchors.append((s, e))
        elif len(anchors) == 1 and (e <= anchors[0][0] or s >= anchors[0][1]):
            anchors.append((s, e))
            break
    return sorted(anchors)


## Cell 63 – Character classes & spacer summaries

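For each spacer region (the stretches before, between and after the anchors), count the
residues seen at the X and Y positions of each G-X-Y triplet and keep those whose frequency
reaches the threshold, falling back to the most frequent residue if none does. The result
is a compact character class per position (e.g. `[AP]`) used to assemble the tiered regexes.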
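
To make the frequency threshold concrete: in the three phase-0 peptides below, the X-position residues P, A and E occur 6, 2 and 1 times out of 9. A 5% threshold keeps all three (`[AEP]`), while a 15% threshold drops E (1/9, about 11%) and yields `[AP]`. The sketch reproduces that logic in simplified form, over whole peptides rather than one spacer region; `xy_classes` and the peptides are illustrative.

```python
import re
from collections import Counter
from typing import List, Tuple


def xy_classes(seqs: List[str], freq_threshold: float) -> Tuple[str, str]:
    """Regex classes for the X and Y positions of G-X-Y triplets (peptides assumed in phase 0)."""
    x_counts, y_counts = Counter(), Counter()
    for s in seqs:
        for i in range(0, len(s) - 2, 3):      # G at i, X at i+1, Y at i+2
            x_counts[s[i + 1]] += 1
            y_counts[s[i + 2]] += 1

    def to_class(counts: Counter) -> str:
        if not counts:
            return "."
        total = sum(counts.values())
        keep = sorted(aa for aa, c in counts.items() if c / total >= freq_threshold)
        if not keep:                           # nothing passes: fall back to the modal residue
            keep = [counts.most_common(1)[0][0]]
        if len(keep) == 1:
            return re.escape(keep[0])
        return "[" + "".join(re.escape(aa) for aa in keep) + "]"

    return to_class(x_counts), to_class(y_counts)


peptides = ["GPPGAPGPP", "GPPGPPGAP", "GEPGPPGPP"]
for thr in (0.05, 0.15):
    x_cls, y_cls = xy_classes(peptides, thr)
    pattern = f"(?:G{x_cls}{y_cls}){{3}}"
    print(thr, pattern, all(re.fullmatch(pattern, p) for p in peptides))
```

The 15% classes reject the third training peptide because E is no longer allowed at X, which is why the looser tiers use a lower threshold or wildcards.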


In [None]:
# ===== Cell 63 =====
def rex_char_classes_for_region(
    seqs: List[str],
    start: int,
    end: int,
    freq_threshold: float = 0.05
) -> Tuple[str, str]:
    if start >= end:
        return ".", "."
    X_counts, Y_counts = {}, {}
    for s in seqs:
        if len(s) < end:
            continue
        for i in range(start, end, 3):
            if i + 2 >= len(s):
                break
            x, y = s[i+1], s[i+2]
            X_counts[x] = X_counts.get(x, 0) + 1
            Y_counts[y] = Y_counts.get(y, 0) + 1

    def build_class(d: Dict[str,int]) -> str:
        if not d:
            return "."
        total = sum(d.values())
        keep = sorted([aa for aa, c in d.items() if c/total >= freq_threshold])
        if not keep:
            keep = [max(d.items(), key=lambda kv: kv[1])[0]]
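        if len(keep) == 1:
            return re.escape(keep[0])
        # keep is already sorted and unique; escape each residue so symbols like "-" cannot break the class
        return "[" + "".join(re.escape(aa) for aa in keep) + "]"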

    return build_class(X_counts), build_class(Y_counts)


## Cell 64 – Tiered pattern builder (per exon, per clade)

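Construct strict→loose tiers (A/B/C/D) from the anchors and spacer summaries:

- **A**: anchor literals plus X/Y classes built at `rex_freq_threshold_strict`; spacer repeat counts ± `rex_len_tolerance_strict` triplets.
- **B**: the same layout with the broader `rex_freq_threshold_moderate` classes and ± `rex_len_tolerance_moderate` triplets.
- **C**: anchor literals with wildcard X/Y (`G..` repeats) and ± `rex_len_tolerance_loose` triplets.
- **D**: anchor-free `(?:G..){kmin,kmax}` fallback spanning the median exon length ± `rex_len_tolerance_loose` triplets.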
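
A toy illustration of how tier strictness changes what is recovered. The exon, its single `GPPGPP` anchor and its X/Y classes are hypothetical; `spacer` mirrors the token shape of `RegExTractorBuilder._spacer_token`. A single X-position substitution (E to V) defeats the Tier A-style pattern, while the Tier C-style pattern (same anchor, wildcard X/Y, wider repeat counts) still locates the exon inside a longer sequence.

```python
import re


def spacer(x_cls: str, y_cls: str, m: int, n: int) -> str:
    """G-X-Y repeat token with m..n repeats, same shape as _spacer_token in Cell 64."""
    return f"(?:G{x_cls}{y_cls})" + (f"{{{m},{n}}}" if m != n else f"{{{m}}}")


# Hypothetical 18-residue exon: GAP GEP | GPP GPP (anchor) | GAK GSR
anchor = re.escape("GPPGPP")
tier_a = spacer("[AE]", "P", 1, 3) + anchor + spacer("[AS]", "[KR]", 1, 3)  # 2 triplets ± 1
tier_c = spacer(".", ".", 0, 8) + anchor + spacer(".", ".", 0, 8)          # 2 triplets ± 6

wild = "MKTW" + "GAPGEPGPPGPPGAKGSR" + "QQ"
mutant = "MKTW" + "GAPGVPGPPGPPGAKGSR" + "QQ"   # E -> V at the second X position

for name, seq in (("wild", wild), ("mutant", mutant)):
    for tier, pat in (("A", tier_a), ("C", tier_c)):
        hit = re.search(pat, seq)
        print(name, tier, hit.span() if hit else None)
```

Both tiers find the wild-type exon at span (4, 22); for the mutant, Tier A returns `None` while Tier C still matches (4, 22).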


In [None]:
# ===== Cell 64 =====
@dataclass
class ExonStats:
    exon_len_median: int
    exon_len_q1: int
    exon_len_q3: int
    anchors: List[Tuple[int, int]]
    ghead_density: float

@dataclass
class ExonTierPattern:
    tier: str
    regex: re.Pattern
    anchors_literals: List[str]
    len_range: Tuple[int, int]
    ghead_min: float

@dataclass
class ExonRegexLibraryEntry:
    gene_symbol: str
    exon_num_in_chain: int
    clade_key: str
    tiers: List[ExonTierPattern]

class RegExTractorBuilder:
    def __init__(self,
                 anchor_k_min: int = rex_anchor_triplets_min,
                 anchor_k_max: int = rex_anchor_triplets_max,
                 anchor_entropy_max: float = rex_anchor_entropy_max):
        self.kmin = anchor_k_min
        self.kmax = anchor_k_max
        self.entmax = anchor_entropy_max

    @staticmethod
    def _length_stats(peps: List[str]) -> Tuple[int,int,int]:
        lengths = np.array([len(p) for p in peps if p], dtype=int)
        if lengths.size == 0:
            return 0,0,0
        return int(np.median(lengths)), int(np.percentile(lengths,25)), int(np.percentile(lengths,75))

    def build_stats(self, peps: List[str]) -> ExonStats:
        if not peps:
            return ExonStats(0,0,0,[],0.0)
        Lmed, Lq1, Lq3 = self._length_stats(peps)
        anchors = rex_find_anchor_windows(peps, k_min=self.kmin, k_max=self.kmax, entropy_max=self.entmax)
        dens = float(np.mean([rex_ghead_density(p) for p in peps])) if peps else 0.0
        return ExonStats(Lmed, Lq1, Lq3, anchors, dens)

    def _anchor_literals(self, peps: List[str], anchors: List[Tuple[int,int]]) -> List[str]:
        lits = []
        for (s,e) in anchors:
            counter: Dict[str,int] = {}
            for p in peps:
                if len(p) < e:
                    continue
                frag = p[s:e]
                counter[frag] = counter.get(frag, 0) + 1
            if counter:
                top = max(counter.items(), key=lambda kv: kv[1])[0]
                lits.append(top)
        return lits

    def _spacer_token(self, Xc: str, Yc: str, rep_range: Tuple[int,int]) -> str:
        m, n = rep_range
        return f"(?:G{Xc}{Yc})" + (f"{{{m},{n}}}" if m != n else f"{{{m}}}")

    def build_tiers(
        self,
        peps: List[str],
        stats: ExonStats,
        freq_strict: float = rex_freq_threshold_strict,
        freq_moderate: float = rex_freq_threshold_moderate
    ) -> List[ExonTierPattern]:
        tiers: List[ExonTierPattern] = []
        if not peps or stats.exon_len_median <= 0:
            return tiers

        Lmed, Lq1, Lq3 = stats.exon_len_median, stats.exon_len_q1, stats.exon_len_q3
        anchor_lits = self._anchor_literals(peps, stats.anchors)

        # Define regions (left/middle/right) around anchors
        if stats.anchors:
            if len(stats.anchors) == 1:
                (a1s,a1e) = stats.anchors[0]
                regions = [(0,a1s), (a1e,Lmed)]
            else:
                (a1s,a1e), (a2s,a2e) = sorted(stats.anchors)
                regions = [(0,a1s), (a1e,a2s), (a2e,Lmed)]
        else:
            regions = [(0,Lmed)]

        xy_strict = [rex_char_classes_for_region(peps, s, e, freq_strict) for (s,e) in regions]
        xy_mod    = [rex_char_classes_for_region(peps, s, e, freq_moderate) for (s,e) in regions]

        def triplet_len_range(s:int, e:int, tol:int) -> Tuple[int,int]:
            L = max(e - s, 0)
            m = max((L // 3) - tol, 0)
            n = max((L // 3) + tol, 0)
            if m > n:
                m, n = n, m
            return m, n

        strict_tol = rex_len_tolerance_strict
        mod_tol    = rex_len_tolerance_moderate
        loose_tol  = rex_len_tolerance_loose

        def assemble(xy_classes, tol) -> str:
            parts = []
            if stats.anchors:
                a1s, a1e = stats.anchors[0]
                if (a1s - 0) >= 3:
                    Xc,Yc = xy_classes[0]
                    m,n = triplet_len_range(0, a1s, tol)
                    parts.append(self._spacer_token(Xc,Yc,(m,n)))
                if anchor_lits:
                    parts.append(re.escape(anchor_lits[0]))
                if len(stats.anchors) == 2:
                    a2s, a2e = stats.anchors[1]
                    if (a2s - a1e) >= 3:
                        Xc,Yc = xy_classes[1]
                        m,n = triplet_len_range(a1e, a2s, tol)
                        parts.append(self._spacer_token(Xc,Yc,(m,n)))
                    if len(anchor_lits) > 1:
                        parts.append(re.escape(anchor_lits[1]))
                    if (Lmed - a2e) >= 3:
                        Xc,Yc = xy_classes[2]
                        m,n = triplet_len_range(a2e, Lmed, tol)
                        parts.append(self._spacer_token(Xc,Yc,(m,n)))
                else:
                    if (Lmed - a1e) >= 3:
                        Xc,Yc = xy_classes[1]
                        m,n = triplet_len_range(a1e, Lmed, tol)
                        parts.append(self._spacer_token(Xc,Yc,(m,n)))
            else:
                Xc,Yc = xy_classes[0]
                m,n = triplet_len_range(0, Lmed, tol)
                parts.append(self._spacer_token(Xc,Yc,(m,n)))
            return "".join(parts)

        # Tier A
        pat_A = assemble(xy_strict, strict_tol)
        if pat_A:
            tiers.append(ExonTierPattern(
                tier="A",
                regex=re.compile(pat_A, flags=re.ASCII),
                anchors_literals=anchor_lits,
                len_range=(Lmed - strict_tol*3, Lmed + strict_tol*3),
                ghead_min=max(rex_ghead_density_min, stats.ghead_density * 0.9)
            ))
        # Tier B
        pat_B = assemble(xy_mod, mod_tol)
        if pat_B:
            tiers.append(ExonTierPattern(
                tier="B",
                regex=re.compile(pat_B, flags=re.ASCII),
                anchors_literals=anchor_lits,
                len_range=(Lmed - mod_tol*3, Lmed + mod_tol*3),
                ghead_min=max(rex_ghead_density_min * 0.9, stats.ghead_density * 0.8)
            ))
        # Tier C (anchors + degenerate triplets): wildcard X/Y in every spacer region
        xy_loose = [(".", ".") for _ in regions]
        pat_C = assemble(xy_loose, loose_tol)
        if pat_C:
            tiers.append(ExonTierPattern(
                tier="C",
                regex=re.compile(pat_C, flags=re.ASCII),
                anchors_literals=anchor_lits,
                len_range=(Lmed - loose_tol*3, Lmed + loose_tol*3),
                ghead_min=max(rex_ghead_density_min * 0.8, 0.65)
            ))
        # Tier D (fallback)
        kmin = max((Lmed // 3) - loose_tol, 1)
        kmax = (Lmed // 3) + loose_tol
        pat_D = f"(?:G..)" + (f"{{{kmin},{kmax}}}" if kmin != kmax else f"{{{kmin}}}")
        tiers.append(ExonTierPattern(
            tier="D",
            regex=re.compile(pat_D, flags=re.ASCII),
            anchors_literals=[],
            len_range=(kmin*3, kmax*3),
            ghead_min=0.65
        ))
        return tiers


## Cell 65 – Library builder (per gene/exon/clade)
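
This cell stores one tier set per `(gene_symbol, exon_num_in_chain, clade_key)` key, where a clade key is a taxon name from any available level in `REX_TAXON_LEVELS` and `"pan"` covers all training rows; non-pan clades with fewer than `rex_min_clade_samples` rows are skipped. The lookup side is not part of this section, so the sketch below is only a hypothetical way a scanner could pick the most specific entry available for a query, falling back to `"pan"`. `pick_library_key` and the example keys are illustrative.

```python
from typing import Dict, Optional, Set, Tuple

Key = Tuple[str, int, str]
TAXON_LEVELS = ["genus", "family", "order", "class", "kingdom"]  # most specific first


def pick_library_key(keys: Set[Key], gene: str, exon: int,
                     taxonomy: Dict[str, str]) -> Optional[Key]:
    """Most specific (gene, exon, clade) key present in the library, else the pan key, else None."""
    for level in TAXON_LEVELS:
        clade = taxonomy.get(level, "")
        if clade and (gene, exon, clade) in keys:
            return (gene, exon, clade)
    pan = (gene, exon, "pan")
    return pan if pan in keys else None


library_keys = {("COL1A1", 6, "Bovidae"), ("COL1A1", 6, "Mammalia"),
                ("COL1A1", 6, "pan"), ("COL1A1", 8, "pan")}
sheep = {"genus": "Ovis", "family": "Bovidae", "class": "Mammalia"}
print(pick_library_key(library_keys, "COL1A1", 6, sheep))               # family-level entry
print(pick_library_key(library_keys, "COL1A1", 8, sheep))               # pan fallback
print(pick_library_key(library_keys, "COL1A1", 10, {"class": "Aves"}))  # no entry: None
```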


In [None]:
# ===== Cell 65 =====
class RegExTractorLibrary:
    def __init__(self):
        self.entries: Dict[Tuple[str,int,str], ExonRegexLibraryEntry] = {}
        self.stats: Dict[Tuple[str,int,str], ExonStats] = {}

    @staticmethod
    def _clade_keys(df: pd.DataFrame) -> List[str]:
        keys = []
        for lvl in REX_TAXON_LEVELS:
            if lvl == "pan":
                keys.append("pan")
            elif lvl in df.columns:
                # Skip placeholders the adapters write for missing taxonomy ("", "nan", "None")
                vals = {v for v in df[lvl].astype(str).unique() if v and v not in {"nan", "None"}}
                keys.extend(sorted(vals))
        return list(dict.fromkeys(keys))

    @staticmethod
    def _subset_by_clade(df: pd.DataFrame, clade_key: str) -> pd.DataFrame:
        if clade_key == "pan":
            return df
        for lvl in ["genus","family","order","class","kingdom"]:
            if lvl in df.columns and clade_key in set(df[lvl].astype(str).unique()):
                return df[df[lvl].astype(str) == clade_key]
        return df.iloc[0:0].copy()

    def build(self, training_df: pd.DataFrame, genes: List[str], min_clade_samples: int = rex_min_clade_samples):
        builder = RegExTractorBuilder()
        for gene in genes:
            gdf = training_df[training_df["gene_symbol"] == gene]
            if gdf.empty:
                rex_log(f"No training rows for {gene}")
                continue
            for exon in sorted(gdf["exon_num_in_chain"].dropna().astype(int).unique()):
                edf = gdf[gdf["exon_num_in_chain"] == exon]
                if edf.empty:
                    continue
                clade_keys = self._clade_keys(edf)
                if "pan" not in clade_keys:
                    clade_keys.append("pan")
                for ck in clade_keys:
                    sub = self._subset_by_clade(edf, ck)
                    if ck != "pan" and len(sub) < min_clade_samples:
                        continue
                    peps = sub["exon_peptide"].astype(str).tolist()
                    stats = builder.build_stats(peps)
                    tiers = builder.build_tiers(peps, stats)
                    if not tiers:
                        continue
                    key = (gene, exon, ck)
                    self.entries[key] = ExonRegexLibraryEntry(
                        gene_symbol=gene,
                        exon_num_in_chain=exon,
                        clade_key=ck,
                        tiers=tiers
                    )
                    self.stats[key] = stats
        rex_log(f"Built {len(self.entries)} exon/clade entries across {len(genes)} genes.")


## Cell 66 – Matching & scoring
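The matcher does not run each tier regex over the whole protein. It first locates every anchor literal with `str.find`, then searches only a padded window around each occurrence, using the `pos`/`endpos` arguments of `re.Pattern.search`. Below is a minimal sketch of that idea on a toy sequence. `scan_near_anchor` is a hypothetical name, and the pattern is illustrative rather than one the builder would emit:

```python
import re
from typing import List, Tuple

def scan_near_anchor(seq: str, anchor: str, pattern: re.Pattern, pad: int) -> List[Tuple[int, int, str]]:
    """Search `pattern` only inside [anchor_start - pad, anchor_end + pad) for each anchor occurrence."""
    hits: List[Tuple[int, int, str]] = []
    start = 0
    while True:
        idx = seq.find(anchor, start)
        if idx == -1:
            break
        lo, hi = max(0, idx - pad), min(len(seq), idx + len(anchor) + pad)
        m = pattern.search(seq, lo, hi)  # pos/endpos restrict the search to the window
        if m:
            hits.append((m.start(), m.end(), m.group(0)))
        # Always move past this anchor occurrence so the loop terminates,
        # even when the match ends before the anchor itself.
        start = max(m.end(), idx + 1) if m else idx + 1
    return hits

pat = re.compile(r"(?:GP.){2}GAA(?:GP.)")
seq = "MKLGPPGPPGAAGPPTTTT"
print(scan_near_anchor(seq, "GAA", pat, pad=6))  # [(3, 15, 'GPPGPPGAAGPP')]
print(scan_near_anchor(seq, "GAA", pat, pad=2))  # [] (window too narrow for the pattern)
```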


In [None]:
# ===== Cell 66 =====
@dataclass
class RexHit:
    accession: str
    gene_symbol: str
    exon_num_in_chain: int
    clade_used: str
    tier: str
    start: int
    end: int
    matched_peptide: str
    ghead_density: float
    len_residual: int
    anchors_hit: int
    score: float

class RegExTractorMatcher:
    def __init__(self, library: RegExTractorLibrary):
        self.lib = library

    @staticmethod
    def _tier_weight(tier: str) -> float:
        return {"A": 1.0, "B": 0.85, "C": 0.6, "D": 0.3}.get(tier, 0.1)

    @staticmethod
    def _choose_clade_key(row: pd.Series, keys_for_exon: List[str]) -> str:
        for lvl in REX_TAXON_LEVELS:
            if lvl == "pan":
                continue
            if lvl in row.index and isinstance(row[lvl], str) and row[lvl]:
                if row[lvl] in keys_for_exon:
                    return row[lvl]
        return "pan"

    def _scan_with_anchors(self, seq: str, pattern: ExonTierPattern, window_pad: int = rex_search_window_pad) -> List[Tuple[int,int,str]]:
        hits: List[Tuple[int,int,str]] = []
        if pattern.anchors_literals:
            for lit in pattern.anchors_literals:
                start_idx = 0
                while True:
                    idx = seq.find(lit, start_idx)
                    if idx == -1:
                        break
                    wstart = max(0, idx - window_pad)
                    wend   = min(len(seq), idx + len(lit) + window_pad)
                    m = pattern.regex.search(seq, wstart, wend)
                    if m:
                        hits.append((m.start(), m.end(), m.group(0)))
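                        # Always advance past this anchor; a match ending before idx would otherwise loop forever
                        start_idx = max(m.end(), idx + 1)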
                    else:
                        start_idx = idx + 1
        if (not hits) and rex_enable_fullregex_fallback:
            for m in pattern.regex.finditer(seq):
                hits.append((m.start(), m.end(), m.group(0)))
        return hits

    def _score_hit(self, pep: str, tier: ExonTierPattern) -> Tuple[float,float,int,int]:
        gden = rex_ghead_density(pep)
        # Length residual: 0 inside the tier's length range, else distance to the nearest bound
        lo, hi = tier.len_range
        L = len(pep)
        len_res = 0 if lo <= L <= hi else min(abs(L - lo), abs(L - hi))
        anchors_hit = sum(1 for lit in tier.anchors_literals if lit in pep)
        score = (
            0.45 * self._tier_weight(tier.tier) +
            # G-head term clipped to [0, 1] so the total score stays within [0, 1]
            0.30 * min(1.0, max(0.0, (gden - (tier.ghead_min - 0.1)) / 0.2)) +
            0.15 * (1.0 - min(len_res / 9.0, 1.0)) +
            0.10 * (anchors_hit / max(len(tier.anchors_literals), 1))
        )
        return float(score), float(gden), int(len_res), int(anchors_hit)

    def scan_sequence_row(self, row: pd.Series, genes: List[str]) -> List[RexHit]:
        seq = row["sequence"]; accession = row["accession"]
        all_hits: List[RexHit] = []
        for gene in genes:
            entries = [(k,v) for k,v in self.lib.entries.items() if k[0] == gene]
            exons   = sorted(set([k[1] for k,_ in entries]))
            for exon in exons:
                keys_for_exon = [k[2] for k,_ in entries if k[1] == exon]
                clade = self._choose_clade_key(row, keys_for_exon)
                key = (gene, exon, clade) if (gene,exon,clade) in self.lib.entries else (gene,exon,"pan")
                entry = self.lib.entries.get(key)
                if not entry:
                    continue
                for tier in entry.tiers:
                    for s,e,mtxt in self._scan_with_anchors(seq, tier):
                        score,gden,lres,ahit = self._score_hit(mtxt, tier)
                        if gden < (tier.ghead_min - 0.05):
                            continue
                        all_hits.append(RexHit(
                            accession=accession,
                            gene_symbol=gene,
                            exon_num_in_chain=exon,
                            clade_used=entry.clade_key,
                            tier=tier.tier,
                            start=s, end=e,
                            matched_peptide=mtxt,
                            ghead_density=gden,
                            len_residual=lres,
                            anchors_hit=ahit,
                            score=score
                        ))
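                    # Stop descending tiers once *this* exon already has a strong Tier A hit
                    if any(h.tier == "A" and h.score >= 0.85
                           and h.gene_symbol == gene and h.exon_num_in_chain == exon
                           for h in all_hits):
                        break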
        # Keep highest-score per (gene,exon)
        best: Dict[Tuple[str,int], RexHit] = {}
        for h in sorted(all_hits, key=lambda x: x.score, reverse=True):
            key = (h.gene_symbol, h.exon_num_in_chain)
            if key not in best:
                best[key] = h
        return list(best.values())


## Cell 67 – Chain reconstruction (anchor‑and‑walk)

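Seed with the strongest exon hit, then walk downstream and upstream through the
expected exon order (taken from the training table). A neighbouring exon is added
only if its best hit lies within `rex_search_window_pad` residues of the current
block's edge; the walk stops at the first missing or misplaced exon.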
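The walk rule is easiest to see on toy hits. Starting from the seed, the next exon in the expected order is accepted only if its hit starts at or after the current block end and no more than `window_pad` residues beyond it. The first miss stops the walk. The sketch below covers the downstream direction only; the notebook's `rex_walk_chain` also walks upstream. It uses hypothetical `(start, end)` tuples keyed by exon number:

```python
from typing import Dict, List, Tuple

def walk_downstream(
    seed_exon: int,
    hits: Dict[int, Tuple[int, int]],  # exon number -> (start, end) of its best hit
    exon_order: List[int],
    window_pad: int,
) -> List[int]:
    """Extend a chain from the seed exon through consecutive exons in `exon_order`."""
    i = exon_order.index(seed_exon)
    chain = [seed_exon]
    last_end = hits[seed_exon][1]
    for ex in exon_order[i + 1:]:
        h = hits.get(ex)
        # Stop at the first exon that is missing or not placed just after the block
        if h is None or not (last_end <= h[0] <= last_end + window_pad):
            break
        chain.append(ex)
        last_end = h[1]
    return chain

hits = {6: (100, 154), 8: (160, 214), 10: (214, 268), 12: (600, 654)}
print(walk_downstream(6, hits, [6, 8, 10, 12], window_pad=90))  # [6, 8, 10]
```

Exon 12 is rejected because its hit starts 332 residues after exon 10 ends, far beyond the 90-residue pad.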


In [None]:
# ===== Cell 67 =====
@dataclass
class RexChain:
    accession: str
    gene_symbol: str
    seed_exon: int
    exons: List[int]
    coords: List[Tuple[int,int]]
    tiers: List[str]
    mean_score: float
    consecutive_blocks: int

def rex_exon_order_model(training_df: pd.DataFrame, gene: str) -> List[int]:
    exons = (
        training_df[training_df["gene_symbol"] == gene]["exon_num_in_chain"]
        .dropna().astype(int).sort_values().unique().tolist()
    )
    return exons

def rex_walk_chain(
    seq_len: int,
    seed: RexHit,
    hits_for_gene: Dict[int, RexHit],
    exon_order: List[int],
    window_pad: int = rex_search_window_pad
) -> RexChain:
    if seed.exon_num_in_chain not in exon_order:
        return RexChain(seed.accession, seed.gene_symbol, seed.exon_num_in_chain,
                        [seed.exon_num_in_chain], [(seed.start, seed.end)],
                        [seed.tier], seed.score, 1)

    idx = exon_order.index(seed.exon_num_in_chain)
    exons = [seed.exon_num_in_chain]
    coords = [(seed.start, seed.end)]
    tiers = [seed.tier]
    last_end = seed.end
    consec = 1

    # downstream
    for j in range(idx+1, len(exon_order)):
        ex = exon_order[j]
        h = hits_for_gene.get(ex)
        if not h:
            break
        if h.start >= last_end and h.start <= last_end + window_pad:
            exons.append(ex); coords.append((h.start,h.end)); tiers.append(h.tier)
            last_end = h.end; consec += 1
        else:
            break

    # upstream
    first_start = seed.start
    for j in range(idx-1, -1, -1):
        ex = exon_order[j]
        h = hits_for_gene.get(ex)
        if not h:
            break
        if h.end <= first_start and h.end >= max(0, first_start - window_pad):
            exons.insert(0, ex); coords.insert(0,(h.start,h.end)); tiers.insert(0,h.tier)
            first_start = h.start; consec += 1
        else:
            break

    mean_score = float(np.mean([hits_for_gene[e].score if e in hits_for_gene else 0.0 for e in exons]))
    return RexChain(seed.accession, seed.gene_symbol, seed.exon_num_in_chain, exons, coords, tiers, mean_score, consec)


## Cell 68 – Orchestrator (build → scan → chain)

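Build the library from the well-mapped (QC-passed) exons, scan every rejected sequence, and assemble each gene's hits into chains; chains with fewer than `rex_chain_min_consecutive` exons are dropped.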


In [None]:
# ===== Cell 68 =====
def rex_run_regextractor(
    training_df: pd.DataFrame,
    rejected_df: pd.DataFrame,
    genes: List[str],
    min_clade_samples: int = rex_min_clade_samples
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    lib = RegExTractorLibrary()
    lib.build(training_df, genes, min_clade_samples=min_clade_samples)
    matcher = RegExTractorMatcher(lib)

    hits_rows: List[Dict[str,Any]] = []
    chains_rows: List[Dict[str,Any]] = []

    exon_orders = {g: rex_exon_order_model(training_df, g) for g in genes}

    for _, row in rejected_df.iterrows():
        seq_hits = matcher.scan_sequence_row(row, genes=genes)
        if not seq_hits:
            continue
        for h in seq_hits:
            hits_rows.append({
                "accession": h.accession,
                "gene_symbol": h.gene_symbol,
                "exon_num_in_chain": h.exon_num_in_chain,
                "clade_used": h.clade_used,
                "tier": h.tier,
                "start": h.start, "end": h.end,
                "length": h.end - h.start,
                "matched_peptide": h.matched_peptide,
                "ghead_density": h.ghead_density,
                "len_residual": h.len_residual,
                "anchors_hit": h.anchors_hit,
                "score": h.score
            })

        by_gene: Dict[str, List[RexHit]] = {}
        for h in seq_hits:
            by_gene.setdefault(h.gene_symbol, []).append(h)

        for gene, ghits in by_gene.items():
            hits_by_exon = {h.exon_num_in_chain: h for h in ghits}
            seed = sorted(ghits, key=lambda x: x.score, reverse=True)[0]
            chain = rex_walk_chain(
                seq_len=len(row["sequence"]),
                seed=seed,
                hits_for_gene=hits_by_exon,
                exon_order=exon_orders.get(gene, [])
            )
            chains_rows.append({
                "accession": chain.accession,
                "gene_symbol": chain.gene_symbol,
                "seed_exon": chain.seed_exon,
                "exons": chain.exons,
                "coords": chain.coords,
                "tiers": chain.tiers,
                "mean_score": chain.mean_score,
                "consecutive_blocks": chain.consecutive_blocks
            })

    rex_hits_df = pd.DataFrame(hits_rows)
    rex_chains_df = pd.DataFrame(chains_rows)
    if not rex_chains_df.empty:
        rex_chains_df = rex_chains_df[
            rex_chains_df["consecutive_blocks"].astype(int) >= rex_chain_min_consecutive
        ].reset_index(drop=True)

    rex_log(f"Emitted {len(rex_hits_df)} exon hits and {len(rex_chains_df)} chains.")
    return rex_hits_df, rex_chains_df


## Cell 69 – Run RegExTractor and export

Uses **training_norm** (from `wide_df`) and **rejected_norm** (from the unmapped pool).
Restricts targets to **COL1A1/COL1A2**: the intersection with `GENE_SYMBOLS` if that global exists, otherwise both genes.


In [None]:
# ===== Cell 69 =====
# Pick targets from global GENE_SYMBOLS
if 'GENE_SYMBOLS' in globals():
    target_genes = [g for g in GENE_SYMBOLS if g in {"COL1A1","COL1A2"}]
else:
    target_genes = ["COL1A1","COL1A2"]

rex_hits_df, rex_chains_df = rex_run_regextractor(
    training_norm,
    rejected_norm,
    genes=target_genes
)

# Inspect
try:
    display(rex_hits_df.head(20))
    display(rex_chains_df.head(10))
except Exception:
    print(rex_hits_df.head(20))
    print(rex_chains_df.head(10))

# Save to /content (non-breaking; does not depend on any prior path constants)
hits_path   = "/content/rex_hits_COL1A1_COL1A2.csv"
chains_path = "/content/rex_chains_COL1A1_COL1A2.csv"
rex_hits_df.to_csv(hits_path, index=False)
rex_chains_df.to_csv(chains_path, index=False)
rex_log(f"Saved:\n  {hits_path}\n  {chains_path}")


# **Part 7: Shannon Entropy Analysis & Visualisation**

## Cell 71 – Per-exon entropy and plotting

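For each gene and exon, computes the per-position Shannon entropy (in bits) of the
left-aligned exon peptides, takes the median across positions, and records the
median exon length. Saves the table to `ENTROPY_TSV` and one plot per gene (median
length as bars, entropy drawn as error bars on the same axis).
Charts use **matplotlib** only (no seaborn).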
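For one exon, the calculation takes the base-2 Shannon entropy $H = -\sum_a p_a \log_2 p_a$ of the residues observed at each position, then reports the median over positions. The peptides are simply left-aligned, with no gapped alignment. The sketch below uses only the standard library on toy peptides. `median_positional_entropy` is a hypothetical name, and, as in the cell below, positions past a peptide's end are skipped:

```python
import math
from collections import Counter
from statistics import median
from typing import List

def shannon_entropy_bits(residues: List[str]) -> float:
    """H = -sum(p * log2 p) over residue frequencies in one column."""
    n = len(residues)
    if n == 0:
        return 0.0
    return sum(-(c / n) * math.log2(c / n) for c in Counter(residues).values())

def median_positional_entropy(peptides: List[str]) -> float:
    """Median of per-position entropies over left-aligned peptides."""
    max_len = max((len(p) for p in peptides), default=0)
    ents = []
    for j in range(max_len):
        column = [p[j] for p in peptides if j < len(p)]  # skip peptides that end before j
        ents.append(shannon_entropy_bits(column))
    return median(ents) if ents else 0.0

print(shannon_entropy_bits(list("PAPA")))                       # 1.0
print(median_positional_entropy(["GPA", "GPE", "GAA", "GAE"]))  # 1.0
```

Here position 0 is invariant (G, 0 bits) while positions 1 and 2 each split 50/50 (1 bit), so the median is 1 bit.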

In [None]:
# ===== Cell 71 =====
# Shannon entropy per exon (robust computation + safe plotting)

import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def shannon_entropy(col: list[str]) -> float:
    """Shannon entropy (base-2) for a list of single-character residues."""
    if not col:
        return 0.0
    vals, cnts = np.unique(col, return_counts=True)
    p = cnts / cnts.sum()
    p = np.clip(p, 1e-12, 1.0)
    return float(-(p * np.log2(p)).sum())

def pad_peptides(peps: list[str]) -> np.ndarray:
    """Pad peptides to a rectangular (n_peps x max_len) array with '' as filler."""
    m = max((len(x) for x in peps), default=0)
    arr = np.full((len(peps), m), '', dtype=object)
    for i, s in enumerate(peps):
        for j, ch in enumerate(s):
            arr[i, j] = ch
    return arr

def safe_median(values, default=0.0) -> float:
    """Median that returns `default` for empty/all-NaN inputs."""
    arr = np.asarray(list(values), dtype=float)
    if arr.size == 0 or not np.isfinite(arr).any():
        return float(default)
    return float(np.nanmedian(arr))

def exon_idx(colname: str) -> int:
    """Extract numeric exon index from 'exon_<N>_peptide' for proper sorting."""
    m = re.search(r'exon_(\d+)_peptide', colname or "")
    return int(m.group(1)) if m else 10**9

entropy_rows = []
n_plots = 0

if 'wide_df' in globals() and not wide_df.empty:
    for g, sub in wide_df.groupby('gene_symbol', dropna=False):
        exon_cols = [c for c in sub.columns
                     if isinstance(c, str) and c.startswith('exon_') and c.endswith('_peptide')]
        exon_cols = sorted(exon_cols, key=exon_idx)

        # Per-exon entropy & median length
        for c in exon_cols:
            peps = sub[c].fillna('').astype(str).tolist()
            # entropy: position-wise entropy, then median across positions
            arr = pad_peptides(peps)
            ents = []
            for j in range(arr.shape[1]):
                # characters present at column j (skip empty fillers)
                col = [ch for ch in arr[:, j].tolist() if ch]
                if not col:
                    continue
                ents.append(shannon_entropy(col))
            exon_entropy = safe_median(ents, default=0.0)

            # robust median length (ignore empty strings)
            lengths = [len(p) for p in peps if p]
            exon_length = int(round(safe_median(lengths, default=0.0)))

            entropy_rows.append({
                'gene_symbol': g,
                'exon_col': c,
                'median_length': exon_length,
                'entropy': exon_entropy
            })

    entropy_df = pd.DataFrame(entropy_rows)

    if not entropy_df.empty:
        # Save table
        entropy_df.to_csv(ENTROPY_TSV, sep='\t', index=False)
        logger.info(f"Entropy stats rows: {len(entropy_df)} (saved → {ENTROPY_TSV.name})")

        # Plots: exon index on X, median length as bars, entropy as error bars
        for g, sub in entropy_df.groupby('gene_symbol', dropna=False):
            sub = sub.copy()
            sub['exon_idx'] = sub['exon_col'].map(exon_idx)
            sub = sub.sort_values('exon_idx')

            xs = list(range(len(sub)))
            heights = [int(h) if np.isfinite(h) else 0 for h in sub['median_length'].tolist()]
            errs = [float(e) if np.isfinite(e) else 0.0 for e in sub['entropy'].tolist()]

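            fig, ax = plt.subplots(figsize=(10, 4))
            try:
                ax.bar(xs, heights)
                ax.errorbar(xs, heights, yerr=errs, fmt='none')
                # Label ticks with the actual exon numbers rather than positional ranks
                ax.set_xticks(xs)
                ax.set_xticklabels([str(i) for i in sub['exon_idx']], rotation=90, fontsize=7)
                ax.set_title(f"Exon length (bars) + entropy in bits (error bars): {g}")
                ax.set_xlabel("Exon number")
                ax.set_ylabel("Median AA length")
                out_png = OUTPUTS_PATH / f"entropy_{g}_{RUN_ID}.png"
                fig.tight_layout()
                fig.savefig(out_png)
                n_plots += 1
            except Exception as e:
                logger.warning(f"Plot failed for gene {g}: {e}")
            finally:
                plt.close(fig)  # free the figure even if saving failed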

        logger.info(f"Entropy plots generated: {n_plots} (saved to {OUTPUTS_PATH})")
else:
    logger.info("No wide_df available; skipping entropy.")


# **Part 8: RegExTractor — Implementation (COL1A1 & COL1A2)**

This section adds a self‑contained, non‑breaking implementation of the
**RegExTractor** engine. It learns clade‑aware, exon‑specific, triplet‑aware
regex patterns from your **well‑mapped** exon atlas and then rescues exons
from **rejected** sequences, chaining hits into putative COL1A1/1A2 blocks.

> Assumptions (explicit):
> - You already have an exon‑level training table with **one peptide per exon**
>   (frame-correct, without stop codons) and taxonomic labels.
> - The **rejected** pool is a table of full protein sequences with species/clade
>   annotations.
> - Even‑indexed exon numbering is used within the helix region.
> - No function or variable from earlier parts is renamed. New utilities are
>   prefixed with `rex_` or placed in `RegExTractor*` classes to avoid conflicts.

In [None]:
# ===== Cell 80 =====
# RegExTractor configuration (Colab-friendly)
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Optional, Iterable, Any
import re
import math
import numpy as np
import pandas as pd

# --------------------------------------------------------------------------------------
# User-tunable parameters (safe defaults). Adjust via Colab's #@param if desired.
# --------------------------------------------------------------------------------------
rex_min_clade_samples = 8                     #@param {type:"integer"}
rex_anchor_triplets_min = 2                   #@param {type:"integer"}
rex_anchor_triplets_max = 4                   #@param {type:"integer"}
rex_anchor_entropy_max = 0.25                 #@param {type:"number"}
rex_freq_threshold_strict = 0.05              #@param {type:"number"}
rex_freq_threshold_moderate = 0.01            #@param {type:"number"}
rex_ghead_density_min = 0.80                  #@param {type:"number"}
rex_len_tolerance_strict = 1                  #@param {type:"integer"}
rex_len_tolerance_moderate = 3                #@param {type:"integer"}
rex_len_tolerance_loose = 6                   #@param {type:"integer"}
rex_chain_min_consecutive = 3                 #@param {type:"integer"}
rex_search_window_pad = 90                    #@param {type:"integer"}
rex_enable_fullregex_fallback = True          #@param {type:"boolean"}

# Taxonomic levels in descending specificity (must match your columns if available)
REX_TAXON_LEVELS = ["genus", "family", "order", "class", "kingdom", "pan"]


def rex_log(msg: str):
    """Lightweight logger (replace with proper logging if desired)."""
    print(f"[RegExTractor] {msg}")


## Cell 81 – Data adapters & expectations

To avoid breaking existing code, we **adapt** whatever exon/rejected tables you
already have into a minimal interface RegExTractor needs.

**Training (mapped) table** must provide per row:

- `gene_symbol` ∈ {COL1A1, COL1A2}
- `exon_num_in_chain` (int; even indexing for helix exons)
- `exon_peptide` (str; AA sequence of that exon)
- Optional taxon columns: `genus`, `family`, `order`, `class`, `kingdom`

**Rejected (unmapped) table** must provide:

- `accession` (str)
- `sequence` (str; full AA sequence)
- Optional taxon columns as above

If your column names differ, pass a `RexColumnMap` with the matching names (taxonomy columns go in its `tax_cols` dict) to the adapter helpers.
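When an upstream table uses different headers, the mapping only has to point each required field at an existing column. The sketch below does that renaming step with plain pandas. It assumes hypothetical source headers `gene`, `exon_no` and `aa_seq`, and applies the same cleanup rules as `rex_normalize_training_df`. In the notebook itself you would instead set the corresponding `RexColumnMap` fields (`gene_col`, `exon_col`, `pep_col`):

```python
import pandas as pd

raw = pd.DataFrame({
    "gene": ["COL1A1", "COL1A1", "COL1A2"],
    "exon_no": ["6", "8", "x"],          # one unparseable exon number
    "aa_seq": ["GPPGPP", "", "GAPGAP"],  # one empty peptide
})

# Hypothetical source -> standard column names
mapping = {"gene": "gene_symbol", "exon_no": "exon_num_in_chain", "aa_seq": "exon_peptide"}
train = raw.rename(columns=mapping)[list(mapping.values())].copy()
train["exon_num_in_chain"] = pd.to_numeric(train["exon_num_in_chain"], errors="coerce").astype("Int64")

# Drop rows with a missing exon number or an empty peptide
train = train.dropna(subset=["exon_num_in_chain"])
train = train[train["exon_peptide"].str.len() > 0]
print(train)
```

Only the first row survives: the second has an empty peptide, and the third has an exon number that cannot be parsed.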


In [None]:
# ===== Cell 81 =====
# Adapters to standardize input column names for RegExTractor

@dataclass
class RexColumnMap:
    gene_col: str = "gene_symbol"
    exon_col: str = "exon_num_in_chain"
    pep_col: str = "exon_peptide"
    seq_col: str = "sequence"
    acc_col: str = "accession"
    tax_cols: Dict[str, str] = field(default_factory=lambda: {
        "genus": "genus",
        "family": "family",
        "order": "order",
        "class": "class",
        "kingdom": "kingdom"
    })

def rex_normalize_training_df(df: pd.DataFrame, cmap: RexColumnMap) -> pd.DataFrame:
    """
    Normalize a training exon table into required columns.

    Returns a new DataFrame with columns:
        gene_symbol, exon_num_in_chain, exon_peptide, genus, family, order, class, kingdom
    """
    needed = {cmap.gene_col, cmap.exon_col, cmap.pep_col}
    missing = [c for c in needed if c not in df.columns]
    if missing:
        raise KeyError(f"Training DF missing columns: {missing}")

    out = pd.DataFrame({
        "gene_symbol": df[cmap.gene_col].astype(str),
        "exon_num_in_chain": pd.to_numeric(df[cmap.exon_col], errors="coerce").astype("Int64"),
        "exon_peptide": df[cmap.pep_col].astype(str)
    })

    for lvl in ["genus", "family", "order", "class", "kingdom"]:
        col = cmap.tax_cols.get(lvl)
        out[lvl] = df[col].astype(str) if (col in df.columns) else ""

    # Drop rows with NA exon index or empty peptides
    out = out.dropna(subset=["exon_num_in_chain"])
    out = out[out["exon_peptide"].str.len() > 0].copy()
    return out

def rex_normalize_rejected_df(df: pd.DataFrame, cmap: RexColumnMap) -> pd.DataFrame:
    """
    Normalize a rejected table into required columns.

    Returns a new DataFrame with columns:
        accession, sequence, genus, family, order, class, kingdom
    """
    needed = {cmap.acc_col, cmap.seq_col}
    missing = [c for c in needed if c not in df.columns]
    if missing:
        raise KeyError(f"Rejected DF missing columns: {missing}")

    out = pd.DataFrame({
        "accession": df[cmap.acc_col].astype(str),
        "sequence": df[cmap.seq_col].astype(str)
    })

    for lvl in ["genus", "family", "order", "class", "kingdom"]:
        col = cmap.tax_cols.get(lvl)
        out[lvl] = df[col].astype(str) if (col in df.columns) else ""

    # Remove empty sequences
    out = out[out["sequence"].str.len() > 0].copy()
    return out


## Cell 82 – Exon statistics & triplet utilities

- Shannon entropy per aligned column.
- Triplet grid helpers (G at head, X, Y positions).
- Anchor window detection (low‑entropy, short motifs).
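In a Gly-X-Y repeat, every third residue is glycine. The best reading phase is therefore the offset (0, 1 or 2) that places the most G residues at triplet heads. G-head density is the fraction of complete triplets under that phase that start with G. The sketch below mirrors `rex_triplet_offset_best` and `rex_ghead_density` on toy peptides, using only the standard library; the function names are hypothetical:

```python
def best_offset(seq: str) -> int:
    """Offset 0/1/2 with the most G at positions i where (i - offset) % 3 == 0."""
    counts = [sum(1 for i in range(off, len(seq), 3) if seq[i] == "G") for off in (0, 1, 2)]
    return counts.index(max(counts))  # ties resolve to the lowest offset

def ghead_density(seq: str) -> float:
    """Fraction of complete triplets (under the best offset) whose head is G."""
    off = best_offset(seq)
    n_triplets = (len(seq) - off) // 3
    if n_triplets <= 0:
        return 0.0
    heads = range(off, off + 3 * n_triplets, 3)
    return sum(seq[i] == "G" for i in heads) / n_triplets

print(best_offset("AGPPGPPGAP"))    # 1
print(ghead_density("AGPPGPPGAP"))  # 1.0
print(ghead_density("AGPPGPPKAP"))  # 0.6666666666666666 (K interrupts one triplet)
```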


In [None]:
# ===== Cell 82 =====
# Exon statistics & triplet helpers

def rex_shannon_entropy(chars: Iterable[str]) -> float:
    arr = np.array(list(chars))
    if arr.size == 0:
        return 0.0
    vals, cnts = np.unique(arr, return_counts=True)
    p = cnts / cnts.sum()
    p = np.clip(p, 1e-12, 1.0)
    return float(-(p * np.log2(p)).sum())

def rex_triplet_offset_best(seq: str) -> int:
    """
    Choose 0/1/2 offset maximizing G at head positions (positions i where (i-offset)%3==0).
    """
    best_off, best_cnt = 0, -1
    for off in (0, 1, 2):
        cnt = sum(1 for i,ch in enumerate(seq) if (i - off) % 3 == 0 and ch == "G")
        if cnt > best_cnt:
            best_cnt = cnt
            best_off = off
    return best_off

def rex_ghead_density(seq: str, offset: Optional[int] = None) -> float:
    """
    Fraction of triplets with G at head under given offset (or the best offset if None).
    """
    if not seq:
        return 0.0
    off = rex_triplet_offset_best(seq) if offset is None else offset
    triplets = max((len(seq) - off) // 3, 0)
    if triplets == 0:
        return 0.0
    heads = [i for i in range(off, off + 3 * triplets, 3)]
    g_cnt = sum(1 for i in heads if seq[i] == "G")
    return g_cnt / len(heads) if heads else 0.0

def rex_find_anchor_windows(
    seqs: List[str],
    k_min: int = 2,
    k_max: int = 4,
    entropy_max: float = 0.25
) -> List[Tuple[int, int]]:
    """
    Find candidate anchor windows (start, end) in amino-acid coordinates, not triplets.
    Sequences are left-aligned (no gapped alignment) up to the median length,
    per-position entropy is computed across the cohort, and up to two non-overlapping
    low-entropy windows of k_min..k_max triplets are returned.

    NOTE: This is alignment-light and relies on the collagen triplet periodicity;
    windows start on multiples of 3, i.e. exons are assumed to begin on a triplet head,
    so the returned coordinates can be sliced directly from each peptide.
    """
    if not seqs:
        return []

    # Nominal region length: median peptide length
    L = int(np.median([len(s) for s in seqs]))

    # Per-position entropy over the left-aligned cohort; sequences shorter than
    # position i simply do not contribute to that column.
    ent = np.zeros(L, dtype=float)
    for i in range(L):
        ent[i] = rex_shannon_entropy([s[i] for s in seqs if i < len(s)])

    # Slide over triplet windows to find low-entropy runs
    candidates: List[Tuple[int, int, float]] = []
    for k in range(k_min, k_max + 1):
        w = 3 * k
        for i in range(0, L - w + 1, 3):
            e = float(ent[i:i + w].mean())
            if e <= entropy_max:
                candidates.append((i, i + w, e))

    # Pick up to two non-overlapping windows, lowest mean entropy first
    candidates.sort(key=lambda x: x[2])
    anchors: List[Tuple[int, int]] = []
    for s, e, _ in candidates:
        if not anchors:
            anchors.append((s, e))
        elif len(anchors) == 1:
            # Accept the best-scoring window that does not overlap the first anchor
            s0, e0 = anchors[0]
            if e <= s0 or s >= e0:  # non-overlap
                anchors.append((s, e))
                break
    return sorted(anchors)


## Cell 83 – Character classes & spacer summaries

For each **spacer region** (between anchors, and the flanks), we summarize
allowed residues per triplet position into compact character classes to use
inside a repeating token: `(?:G[<Xclass>][<Yclass>]){m,n}`.
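For example, each X and Y position in a spacer keeps every residue seen in at least `freq_threshold` of observations and becomes a bracket class, or a single literal if only one residue survives. The sketch below is a simplified, stdlib-only version of that summary and of the resulting repeat token. It assumes the region starts on a triplet head, as the anchor-derived regions do. `xy_classes` and `spacer_token` are hypothetical names, and unlike the notebook, the sketch does not skip peptides shorter than the region:

```python
import re
from collections import Counter
from typing import List, Tuple

def xy_classes(peps: List[str], start: int, end: int, threshold: float) -> Tuple[str, str]:
    """Summarize X (i+1) and Y (i+2) residues for each triplet head i in [start, end)."""
    x_counts, y_counts = Counter(), Counter()
    for p in peps:
        for i in range(start, min(end, len(p)) - 2, 3):
            x_counts[p[i + 1]] += 1
            y_counts[p[i + 2]] += 1

    def to_class(counts: Counter) -> str:
        if not counts:
            return "."
        total = sum(counts.values())
        keep = sorted(aa for aa, c in counts.items() if c / total >= threshold)
        keep = keep or [counts.most_common(1)[0][0]]  # fall back to the top residue
        if len(keep) == 1:
            return re.escape(keep[0])
        return "[" + "".join(re.escape(aa) for aa in keep) + "]"

    return to_class(x_counts), to_class(y_counts)

def spacer_token(x: str, y: str, m: int, n: int) -> str:
    """Repeat token (?:G<X><Y>){m,n}, collapsing to {m} when m == n."""
    return f"(?:G{x}{y})" + (f"{{{m},{n}}}" if m != n else f"{{{m}}}")

peps = ["GPPGAPGPP", "GPAGPPGEP"]
print(xy_classes(peps, 0, 9, 0.1))             # ('[AEP]', '[AP]')
print(xy_classes(peps, 0, 9, 0.2))             # ('P', 'P')
print(spacer_token("[AEP]", "[AP]", 3, 3))     # (?:G[AEP][AP]){3}
```

Raising the threshold collapses the classes toward the dominant residue. This is why the stricter tier uses the higher frequency cut-off.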


In [None]:
# ===== Cell 83 =====
def rex_char_classes_for_region(
    seqs: List[str],
    start: int,
    end: int,
    freq_threshold: float = 0.05
) -> Tuple[str, str]:
    """
    Build character classes for X and Y positions over the [start, end) region,
    assuming `start` falls on a triplet head (G). Returns (Xclass, Yclass), each a
    single escaped literal, a bracket class, or "." when no residues were observed.
    """
    if start >= end:
        return ".", "."
    X_counts, Y_counts = {}, {}
    for s in seqs:
        # Guard: skip too-short sequences
        if len(s) < end:
            continue
        for i in range(start, end, 3):
            if i + 2 >= len(s):
                break
            x, y = s[i + 1], s[i + 2]
            X_counts[x] = X_counts.get(x, 0) + 1
            Y_counts[y] = Y_counts.get(y, 0) + 1

    def build_class(d: Dict[str, int]) -> str:
        if not d:
            return "."
        total = sum(d.values())
        keep = sorted([aa for aa, c in d.items() if c / total >= freq_threshold])
        if not keep:
            # keep the top-1 if none pass threshold
            keep = [max(d.items(), key=lambda kv: kv[1])[0]]
        if len(keep) == 1:
            return re.escape(keep[0])
        # Escape each residue so characters such as '-' or '^' stay literal inside the class
        return "[" + "".join(re.escape(aa) for aa in keep) + "]"

    return build_class(X_counts), build_class(Y_counts)


## Cell 84 – Tiered pattern builder (per exon, per clade)

We create a **stack** of patterns per exon:
- **Tier A:** strict anchors + tight X/Y classes + narrow length tolerance
- **Tier B:** same anchors + broader classes + moderate tolerance
- **Tier C:** anchors + degenerate X/Y (`.`) + loose tolerance
- **Tier D:** degenerate triplet token only (optional fallback)
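The tiers differ mainly in how tightly they constrain residues and length. Tier D is just a run of degenerate `G..` triplets. Its repeat count is the median exon length in triplets, plus or minus the loose tolerance, with a floor of one triplet. The sketch below shows that arithmetic for a hypothetical 54-residue median exon (18 triplets) and the default `rex_len_tolerance_loose = 6`. `tier_d_pattern` is a hypothetical helper, and `re.fullmatch` is used here only to make the length bounds visible:

```python
import re

def tier_d_pattern(median_len: int, loose_tol: int) -> str:
    """Degenerate Gly-X-Y run with (median triplets +/- tol) repeats, at least one."""
    kmin = max(median_len // 3 - loose_tol, 1)
    kmax = median_len // 3 + loose_tol
    return "(?:G..)" + (f"{{{kmin},{kmax}}}" if kmin != kmax else f"{{{kmin}}}")

pat = tier_d_pattern(54, 6)
print(pat)                                  # (?:G..){12,24}
print(bool(re.fullmatch(pat, "GPP" * 18)))  # True  (54 residues)
print(bool(re.fullmatch(pat, "GPP" * 10)))  # False (30 residues, below 12 triplets)
```

Tiers A to C wrap the same kind of repeat tokens around the anchor literals, with per-region X/Y classes and narrower tolerances.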


In [None]:
# ===== Cell 84 =====
@dataclass
class ExonStats:
    exon_len_median: int
    exon_len_q1: int
    exon_len_q3: int
    anchors: List[Tuple[int, int]]  # [(start,end)] AA indices within exon
    ghead_density: float

@dataclass
class ExonTierPattern:
    tier: str
    regex: re.Pattern
    anchors_literals: List[str]  # literal peptides of anchor windows
    len_range: Tuple[int, int]
    ghead_min: float

@dataclass
class ExonRegexLibraryEntry:
    gene_symbol: str
    exon_num_in_chain: int
    clade_key: str
    tiers: List[ExonTierPattern]

class RegExTractorBuilder:
    """
    Builds tiered regex patterns for one exon given training peptides and clade scope.
    """
    def __init__(self,
                 anchor_k_min: int = rex_anchor_triplets_min,
                 anchor_k_max: int = rex_anchor_triplets_max,
                 anchor_entropy_max: float = rex_anchor_entropy_max):
        self.kmin = anchor_k_min
        self.kmax = anchor_k_max
        self.entmax = anchor_entropy_max

    @staticmethod
    def _length_stats(peps: List[str]) -> Tuple[int, int, int]:
        lengths = np.array([len(p) for p in peps if p], dtype=int)
        if lengths.size == 0:
            return 0, 0, 0
        return int(np.median(lengths)), int(np.percentile(lengths, 25)), int(np.percentile(lengths, 75))

    def build_stats(self, peps: List[str]) -> ExonStats:
        if not peps:
            return ExonStats(0, 0, 0, [], 0.0)
        Lmed, Lq1, Lq3 = self._length_stats(peps)
        anchors = rex_find_anchor_windows(
            peps, k_min=self.kmin, k_max=self.kmax, entropy_max=self.entmax
        )
        # Compute average G-head density across peptides
        dens = np.mean([rex_ghead_density(p) for p in peps]) if peps else 0.0
        return ExonStats(Lmed, Lq1, Lq3, anchors, float(dens))

    @staticmethod
    def _literal(seq: str, start: int, end: int) -> str:
        return seq[start:end]

    def _anchor_literals(self, peps: List[str], anchors: List[Tuple[int, int]]) -> List[str]:
        # Use the most common literal for each anchor window
        lits = []
        for (s, e) in anchors:
            counter: Dict[str, int] = {}
            for p in peps:
                if len(p) < e:
                    continue
                frag = p[s:e]
                counter[frag] = counter.get(frag, 0) + 1
            if counter:
                top = max(counter.items(), key=lambda kv: kv[1])[0]
                lits.append(top)
        return lits

    def _spacer_token(self, Xclass: str, Yclass: str, rep_range: Tuple[int, int]) -> str:
        m, n = rep_range
        return f"(?:G{Xclass}{Yclass})" + (f"{{{m},{n}}}" if m != n else f"{{{m}}}")

    def build_tiers(
        self,
        peps: List[str],
        stats: ExonStats,
        freq_strict: float = rex_freq_threshold_strict,
        freq_moderate: float = rex_freq_threshold_moderate
    ) -> List[ExonTierPattern]:
        """
        Construct Tier A/B/C/D patterns. Returns compiled patterns with metadata.
        """
        tiers: List[ExonTierPattern] = []
        if not peps or stats.exon_len_median <= 0:
            return tiers

        Lmed, Lq1, Lq3 = stats.exon_len_median, stats.exon_len_q1, stats.exon_len_q3
        # Use at most two anchors, ordered by position, so that the regions, the
        # anchor literals and the assembled pattern all refer to the same windows
        anchors = sorted(stats.anchors)[:2]
        # Anchor literals from cohort
        anchor_lits = self._anchor_literals(peps, anchors)

        # Define spacer regions [by AA index] around the anchors to summarize X/Y classes
        # Regions: left flank, middle (two anchors only), right flank
        regions: List[Tuple[int, int]]
        if len(anchors) == 2:
            (a1s, a1e), (a2s, a2e) = anchors
            regions = [(0, a1s), (a1e, a2s), (a2e, Lmed)]
        elif len(anchors) == 1:
            (a1s, a1e) = anchors[0]
            regions = [(0, a1s), (a1e, Lmed)]
        else:
            regions = [(0, Lmed)]

        # Summarize X/Y classes per region at two thresholds
        xy_strict = [rex_char_classes_for_region(peps, s, e, freq_strict) for (s, e) in regions]
        xy_mod = [rex_char_classes_for_region(peps, s, e, freq_moderate) for (s, e) in regions]

        # Map a region length to a range of triplet counts (floor division) ± tolerance
        def triplet_len_range(s: int, e: int, tol: int) -> Tuple[int, int]:
            n_trip = max(e - s, 0) // 3
            lo, hi = max(n_trip - tol, 0), max(n_trip + tol, 0)
            return min(lo, hi), max(lo, hi)

        # Tiers A/B/C assembly
        # Length tolerances (in triplets) around per-region triplet counts
        strict_tol = rex_len_tolerance_strict
        mod_tol = rex_len_tolerance_moderate
        loose_tol = rex_len_tolerance_loose

        def assemble(xy_classes: List[Tuple[str, str]], tol: int) -> str:
            """Concatenate spacer tokens and escaped anchor literals in exon order."""
            parts: List[str] = []

            def add_spacer(region_idx: int, s: int, e: int) -> None:
                # Emit a spacer only when the region spans at least one triplet
                if e - s >= 3:
                    Xc, Yc = xy_classes[region_idx]
                    parts.append(self._spacer_token(Xc, Yc, triplet_len_range(s, e, tol)))

            if not anchors:
                # No anchors → one spacer covering the entire exon
                Xc, Yc = xy_classes[0]
                parts.append(self._spacer_token(Xc, Yc, triplet_len_range(0, Lmed, tol)))
                return "".join(parts)

            a1s, a1e = anchors[0]
            add_spacer(0, 0, a1s)                      # left flank
            if anchor_lits:
                parts.append(re.escape(anchor_lits[0]))
            if len(anchors) == 2:
                a2s, a2e = anchors[1]
                add_spacer(1, a1e, a2s)                # middle
                if len(anchor_lits) > 1:
                    parts.append(re.escape(anchor_lits[1]))
                add_spacer(2, a2e, Lmed)               # right flank
            else:
                add_spacer(1, a1e, Lmed)               # single anchor → right flank only
            return "".join(parts)

        # Tier A (strict)
        pat_A = assemble(xy_strict, strict_tol)
        if pat_A:
            tiers.append(ExonTierPattern(
                tier="A",
                regex=re.compile(pat_A, flags=re.ASCII),
                anchors_literals=anchor_lits,
                len_range=(Lmed - strict_tol*3, Lmed + strict_tol*3),
                ghead_min=max(rex_ghead_density_min, stats.ghead_density * 0.9)
            ))

        # Tier B (moderate)
        pat_B = assemble(xy_mod, mod_tol)
        if pat_B:
            tiers.append(ExonTierPattern(
                tier="B",
                regex=re.compile(pat_B, flags=re.ASCII),
                anchors_literals=anchor_lits,
                len_range=(Lmed - mod_tol*3, Lmed + mod_tol*3),
                ghead_min=max(rex_ghead_density_min * 0.9, stats.ghead_density * 0.8)
            ))

        # Tier C (loose; anchors + generic triplet tokens)
        # Replace classes with '.' and widen tolerances
        xy_loose = [(".", ".") for _ in regions]
        pat_C = assemble(xy_loose, loose_tol)
        if pat_C:
            tiers.append(ExonTierPattern(
                tier="C",
                regex=re.compile(pat_C, flags=re.ASCII),
                anchors_literals=anchor_lits,
                len_range=(Lmed - loose_tol*3, Lmed + loose_tol*3),
                ghead_min=max(rex_ghead_density_min * 0.8, 0.65)
            ))

        # Tier D (degenerate fallback): (?:G..){kmin,kmax} without anchors
        kmin = max((Lmed // 3) - loose_tol, 1)
        kmax = (Lmed // 3) + loose_tol
        pat_D = f"(?:G..)" + (f"{{{kmin},{kmax}}}" if kmin != kmax else f"{{{kmin}}}")
        tiers.append(ExonTierPattern(
            tier="D",
            regex=re.compile(pat_D, flags=re.ASCII),
            anchors_literals=[],
            len_range=(kmin*3, kmax*3),
            ghead_min=0.65
        ))
        return tiers
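As a worked example of the tier arithmetic, the sketch below rebuilds the spacer token and the Tier D fallback outside the class. The exon median of 54 residues and the tolerance of 2 triplets are illustrative values, not the notebook's configured `rex_len_tolerance_loose`; `spacer_token` and `tier_d_pattern` are hypothetical standalone helpers that mirror `_spacer_token` and the Tier D lines above.

```python
import re
from typing import Tuple

def spacer_token(x_class: str, y_class: str, rep_range: Tuple[int, int]) -> str:
    """Mirror of _spacer_token: the triplet G-X-Y repeated m..n times."""
    m, n = rep_range
    return f"(?:G{x_class}{y_class})" + (f"{{{m},{n}}}" if m != n else f"{{{m}}}")

def tier_d_pattern(exon_len_median: int, loose_tol: int) -> Tuple[str, Tuple[int, int]]:
    """Tier D: anchor-free (G..) repeats around the median triplet count."""
    kmin = max(exon_len_median // 3 - loose_tol, 1)
    kmax = exon_len_median // 3 + loose_tol
    return spacer_token(".", ".", (kmin, kmax)), (kmin * 3, kmax * 3)

# Illustrative values: a 54-residue exon (18 triplets) with a 2-triplet tolerance
pattern, len_range = tier_d_pattern(54, 2)
print(pattern, len_range)  # (?:G..){16,20} (48, 60)

# A strict-style spacer with residue classes for the X and Y positions
strict = re.compile(spacer_token("[PA]", "[PR]", (2, 3)), flags=re.ASCII)
print(bool(strict.fullmatch("GPPGAR")))  # True: two G-X-Y triplets
print(bool(strict.fullmatch("GPPGAK")))  # False: K is not in the Y class
```

A 54-residue exon is 18 Gly-X-Y triplets, so a tolerance of 2 triplets allows 16 to 20 repeats, i.e. 48 to 60 residues.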


## Cell 85 – Pattern library (per gene, exon, clade)

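We group training peptides by **clade** at each taxonomic level, build entries
only for clades with at least `rex_min_clade_samples` peptides (the `pan` entry
is always built), and key each compiled entry as `(gene, exon, clade_key)`.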
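A simplified, pure-Python sketch of this grouping follows. Plain dicts stand in for the training DataFrame, `LEVELS` is an assumed most-specific-first ordering (the notebook reads it from `REX_TAXON_LEVELS`), and `clade_keys`/`choose_clade` are hypothetical helpers: the first keeps clades with enough peptides plus the `pan` fallback, the second mirrors how the matcher later picks the most specific clade that has an entry.

```python
from collections import Counter
from typing import Dict, Iterable, List

# Hypothetical level order, most specific first (the notebook uses REX_TAXON_LEVELS)
LEVELS = ["genus", "family", "order", "class"]

def clade_keys(records: List[Dict[str, str]], min_samples: int) -> List[str]:
    """Clades with >= min_samples peptides, most specific level first; 'pan' always last."""
    keys: List[str] = []
    for lvl in LEVELS:
        counts = Counter(r.get(lvl, "") for r in records)
        keys.extend(sorted(v for v, c in counts.items() if v and c >= min_samples))
    keys.append("pan")
    return list(dict.fromkeys(keys))  # stable de-duplication

def choose_clade(record: Dict[str, str], available: Iterable[str]) -> str:
    """Most specific clade of `record` that has a library entry, else 'pan'."""
    available = set(available)
    for lvl in LEVELS:
        value = record.get(lvl, "")
        if value and value in available:
            return value
    return "pan"

def rec(genus: str, family: str, order: str, cls: str = "Mammalia") -> Dict[str, str]:
    return {"genus": genus, "family": family, "order": order, "class": cls}

# Illustrative training rows: 3 Homo, 1 Gorilla, 2 Mus peptides for one exon
training = (
    [rec("Homo", "Hominidae", "Primates") for _ in range(3)]
    + [rec("Gorilla", "Hominidae", "Primates")]
    + [rec("Mus", "Muridae", "Rodentia") for _ in range(2)]
)
keys = clade_keys(training, min_samples=2)
print(keys)
# ['Homo', 'Mus', 'Hominidae', 'Muridae', 'Primates', 'Rodentia', 'Mammalia', 'pan']
print(choose_clade(rec("Gorilla", "Hominidae", "Primates"), keys))  # Hominidae
```

A Gorilla sequence falls back to the family-level `Hominidae` entry because its genus has too few training peptides to get an entry of its own.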

In [None]:
# ===== Cell 85 =====
class RegExTractorLibrary:
    """
    Holds compiled patterns across genes/exons/clades.
    """
    def __init__(self):
        self.entries: Dict[Tuple[str, int, str], ExonRegexLibraryEntry] = {}
        self.stats: Dict[Tuple[str, int, str], ExonStats] = {}

    @staticmethod
    def _clade_keys(df: pd.DataFrame) -> List[str]:
        # Build clade keys in descending specificity, skipping blank/missing taxon labels
        keys = []
        for lvl in REX_TAXON_LEVELS:
            if lvl == "pan":
                keys.append("pan")
            elif lvl in df.columns:
                vals = df[lvl].dropna().astype(str).str.strip()
                keys.extend(sorted(v for v in vals.unique() if v and v.lower() != "nan"))
        return list(dict.fromkeys(keys))  # stable unique

    @staticmethod
    def _subset_by_clade(df: pd.DataFrame, clade_key: str) -> pd.DataFrame:
        if clade_key == "pan":
            return df
        # Use the same level list as _clade_keys so every generated key can be resolved
        for lvl in REX_TAXON_LEVELS:
            if lvl == "pan" or lvl not in df.columns:
                continue
            mask = df[lvl].astype(str) == clade_key
            if mask.any():
                return df[mask]
        # Not found → empty
        return df.iloc[0:0].copy()

    def build(
        self,
        training_df: pd.DataFrame,
        genes: List[str],
        min_clade_samples: int = rex_min_clade_samples
    ):
        """
        Build the tiered regex library from a normalized training dataframe.
        """
        builder = RegExTractorBuilder()
        for gene in genes:
            gdf = training_df[training_df["gene_symbol"] == gene]
            if gdf.empty:
                rex_log(f"No training rows for {gene}")
                continue
            for exon in sorted(gdf["exon_num_in_chain"].dropna().astype(int).unique()):
                edf = gdf[gdf["exon_num_in_chain"] == exon]
                if edf.empty:
                    continue
                clade_keys = self._clade_keys(edf)
                # Always include the 'pan' fallback (appended last if missing)
                if "pan" not in clade_keys:
                    clade_keys.append("pan")
                for ck in clade_keys:
                    sub = self._subset_by_clade(edf, ck)
                    if ck != "pan" and len(sub) < min_clade_samples:
                        continue
                    peps = sub["exon_peptide"].astype(str).tolist()
                    stats = builder.build_stats(peps)
                    tiers = builder.build_tiers(peps, stats)
                    if not tiers:
                        continue
                    key = (gene, exon, ck)
                    self.entries[key] = ExonRegexLibraryEntry(
                        gene_symbol=gene,
                        exon_num_in_chain=exon,
                        clade_key=ck,
                        tiers=tiers
                    )
                    self.stats[key] = stats
        rex_log(f"Built {len(self.entries)} exon/clade entries across {len(genes)} genes.")


## Cell 86 – Matching & scoring

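- Two‑stage search: **anchor‑first** (fast), running the full regex only in a local window around each anchor occurrence; if that finds nothing and `rex_enable_fullregex_fallback` is set, the regex is run over the whole sequence.
- Composite score is a weighted sum of tier weight (0.45), G‑head density (0.30), length plausibility (0.15), and the fraction of anchors present (0.10).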
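The anchor-first strategy can be sketched with the `pos`/`endpos` arguments of a compiled pattern's `search`, which restrict matching to a window without slicing the string. The function name, the toy pattern and the sequence are illustrative assumptions. Collecting hits in a set removes duplicates when several anchors prime the same window, and always resuming the anchor search at `idx + 1` guarantees the loop terminates.

```python
import re
from typing import List, Tuple

def anchor_primed_search(seq: str, pattern: re.Pattern, anchors: List[str],
                         pad: int, full_fallback: bool = True) -> List[Tuple[int, int, str]]:
    """Run `pattern` only in windows of +/- `pad` residues around anchor occurrences."""
    hits = set()
    for lit in anchors:
        idx = seq.find(lit)
        while idx != -1:
            wstart = max(0, idx - pad)
            wend = min(len(seq), idx + len(lit) + pad)
            m = pattern.search(seq, wstart, wend)  # matching confined to [wstart, wend)
            if m:
                hits.add((m.start(), m.end(), m.group(0)))
            idx = seq.find(lit, idx + 1)          # next occurrence of this anchor
    if not hits and full_fallback:
        # No anchor-primed hit: scan the whole sequence
        hits = {(m.start(), m.end(), m.group(0)) for m in pattern.finditer(seq)}
    return sorted(hits)

# Toy exon pattern: two G-X-Y triplets, the literal anchor GEK, two more triplets
pat = re.compile(r"(?:G[PA][PR]){2}GEK(?:G[PA][PR]){2}")
seq = "MKTGPPGPPGEKGAPGPPLLLL"
print(anchor_primed_search(seq, pat, ["GEK"], pad=8))  # [(3, 18, 'GPPGPPGEKGAPGPP')]
print(anchor_primed_search(seq, pat, ["GEK"], pad=2, full_fallback=False))  # []
```

With a narrow pad the window is too short for the 15-residue match, so only the whole-sequence fallback can find it.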


In [None]:
# ===== Cell 86 =====
@dataclass
class RexHit:
    accession: str
    gene_symbol: str
    exon_num_in_chain: int
    clade_used: str
    tier: str
    start: int
    end: int
    matched_peptide: str
    ghead_density: float
    len_residual: int
    anchors_hit: int
    score: float

class RegExTractorMatcher:
    def __init__(self, library: RegExTractorLibrary):
        self.lib = library

    @staticmethod
    def _tier_weight(tier: str) -> float:
        return {"A": 1.0, "B": 0.85, "C": 0.6, "D": 0.3}.get(tier, 0.1)

    @staticmethod
    def _choose_clade_key(row: pd.Series, keys_for_exon: List[str]) -> str:
        # Pick the most specific available clade key for this row; fallback to 'pan'
        for lvl in REX_TAXON_LEVELS:
            if lvl == "pan":
                continue
            if lvl in row.index and isinstance(row[lvl], str) and row[lvl]:
                if row[lvl] in keys_for_exon:
                    return row[lvl]
        return "pan"

    def _scan_with_anchors(
        self, seq: str, pattern: ExonTierPattern, window_pad: int = rex_search_window_pad
    ) -> List[Tuple[int, int, str]]:
        """
        Return list of (start, end, match_text) using anchor priming to limit regex search.
        """
        hits: List[Tuple[int, int, str]] = []
        seen = set()
        if pattern.anchors_literals:
            # Find candidate windows around anchor occurrences
            for lit in pattern.anchors_literals:
                idx = seq.find(lit)
                while idx != -1:
                    wstart = max(0, idx - window_pad)
                    wend = min(len(seq), idx + len(lit) + window_pad)
                    m = pattern.regex.search(seq, wstart, wend)
                    if m:
                        span = (m.start(), m.end(), m.group(0))
                        if span not in seen:  # several anchors may prime the same match
                            seen.add(span)
                            hits.append(span)
                        # Skip past this match, but always move beyond the current anchor
                        # occurrence: the match may end before it, and re-finding the
                        # same occurrence would loop forever
                        next_from = max(m.end(), idx + 1)
                    else:
                        next_from = idx + 1
                    idx = seq.find(lit, next_from)
        if (not hits) and rex_enable_fullregex_fallback:
            # No anchor-primed hit: scan the whole sequence
            for m in pattern.regex.finditer(seq):
                hits.append((m.start(), m.end(), m.group(0)))
        return hits

    def _score_hit(self, pep: str, tier: ExonTierPattern) -> Tuple[float, float, int, int]:
        gden = rex_ghead_density(pep)
        # Length residual: 0 inside the tier's expected range, else distance to the nearest bound
        lo, hi = tier.len_range
        L = len(pep)
        len_res = 0 if lo <= L <= hi else min(abs(L - lo), abs(L - hi))
        anchors_hit = sum(1 for lit in tier.anchors_literals if lit in pep)
        # Composite score; each component is normalized to [0, 1] before weighting
        g_term = min(1.0, max(0.0, (gden - (tier.ghead_min - 0.1)) / 0.2))
        score = (
            0.45 * self._tier_weight(tier.tier) +
            0.30 * g_term +
            0.15 * (1.0 - min(len_res / 9.0, 1.0)) +
            0.10 * (anchors_hit / max(len(tier.anchors_literals), 1))
        )
        return float(score), float(gden), int(len_res), int(anchors_hit)

    def scan_sequence_row(
        self,
        row: pd.Series,
        genes: List[str]
    ) -> List[RexHit]:
        """
        Scan a single rejected row for exon hits across the requested genes.
        """
        seq = row["sequence"]
        accession = row["accession"]
        all_hits: List[RexHit] = []

        for gene in genes:
            # Gather available exons for this gene
            entries = [(k, v) for k, v in self.lib.entries.items() if k[0] == gene]
            exons = sorted(set([k[1] for k, _ in entries]))
            for exon in exons:
                # Find best clade key usable for this exon
                keys_for_exon = [k[2] for k, _ in entries if k[1] == exon]
                clade = self._choose_clade_key(row, keys_for_exon)
                key = (gene, exon, clade) if (gene, exon, clade) in self.lib.entries else (gene, exon, "pan")
                entry = self.lib.entries.get(key)
                if not entry:
                    continue
                exon_hits: List[RexHit] = []
                for tier in entry.tiers:
                    # Anchor-primed search
                    for s, e, mtxt in self._scan_with_anchors(seq, tier):
                        score, gden, lres, ahit = self._score_hit(mtxt, tier)
                        if gden < (tier.ghead_min - 0.05):
                            continue
                        exon_hits.append(RexHit(
                            accession=accession,
                            gene_symbol=gene,
                            exon_num_in_chain=exon,
                            clade_used=entry.clade_key,
                            tier=tier.tier,
                            start=s,
                            end=e,
                            matched_peptide=mtxt,
                            ghead_density=gden,
                            len_residual=lres,
                            anchors_hit=ahit,
                            score=score
                        ))
                    # Early exit: once THIS exon has a strong Tier A hit, skip weaker tiers
                    if any(h.tier == "A" and h.score >= 0.85 for h in exon_hits):
                        break
                all_hits.extend(exon_hits)
        # Keep only the highest-scoring hit per (gene, exon)
        best: Dict[Tuple[str, int], RexHit] = {}
        for h in sorted(all_hits, key=lambda x: x.score, reverse=True):
            key = (h.gene_symbol, h.exon_num_in_chain)
            if key not in best:
                best[key] = h
        return list(best.values())
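To make the weights concrete, the sketch below evaluates the composite score for two hypothetical hits. It treats a length inside `len_range` as a zero residual and clamps the G-head term to [0, 1], so each component contributes at most its weight. `composite_score` is an illustrative standalone function, not part of the notebook, and the input values are made up.

```python
import math
from typing import Tuple

TIER_WEIGHT = {"A": 1.0, "B": 0.85, "C": 0.6, "D": 0.3}

def composite_score(tier: str, gden: float, ghead_min: float, pep_len: int,
                    len_range: Tuple[int, int], anchors_hit: int, n_anchors: int) -> float:
    lo, hi = len_range
    # Length residual: zero inside the expected range, else distance to the nearest bound
    len_res = 0 if lo <= pep_len <= hi else min(abs(pep_len - lo), abs(pep_len - hi))
    g_term = min(1.0, max(0.0, (gden - (ghead_min - 0.1)) / 0.2))  # clamped to [0, 1]
    len_term = 1.0 - min(len_res / 9.0, 1.0)                       # 9+ residues off → 0
    anchor_term = anchors_hit / max(n_anchors, 1)
    return (0.45 * TIER_WEIGHT.get(tier, 0.1) + 0.30 * g_term
            + 0.15 * len_term + 0.10 * anchor_term)

# Tier B, G-head density 0.90 vs minimum 0.85, length inside range, 1 of 2 anchors
s1 = composite_score("B", 0.90, 0.85, 57, (45, 63), 1, 2)
# Tier A, perfect G-head density, 6 residues too long, both anchors
s2 = composite_score("A", 1.00, 0.85, 69, (45, 63), 2, 2)
print(round(s1, 4), round(s2, 4))  # 0.8075 0.9
```

For the first hit: 0.45 × 0.85 + 0.30 × 0.75 + 0.15 × 1 + 0.10 × 0.5 = 0.8075.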


## Cell 87 – Chain reconstruction (anchor‑and‑walk)

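Given the per‑exon hits, we **seed** with the highest‑scoring hit, then walk
downstream and upstream in the expected exon order. Each neighboring exon must
have a hit that does not overlap the chain and lies within `rex_search_window_pad`
residues of it; each direction stops at the first exon that fails this test.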
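The walk can be sketched on plain `(start, end)` tuples. `walk_chain` and the toy hits below are hypothetical, and exon 3 is assumed to be the top-scoring seed. As in `rex_walk_chain`, a neighbor is accepted only if it does not overlap the chain and lies within `pad` residues, and each direction stops at the first exon that fails.

```python
from typing import Dict, List, Tuple

Hit = Tuple[int, int]  # (start, end) of one exon hit in the sequence

def walk_chain(seed_exon: int, hits: Dict[int, Hit], order: List[int], pad: int) -> List[int]:
    """Grow a chain of exon numbers outward from the seed exon."""
    i = order.index(seed_exon)
    chain = [seed_exon]
    # Downstream: next hit must start at or after the chain end, within `pad` residues
    last_end = hits[seed_exon][1]
    for ex in order[i + 1:]:
        h = hits.get(ex)
        if h is None or not (last_end <= h[0] <= last_end + pad):
            break
        chain.append(ex)
        last_end = h[1]
    # Upstream: previous hit must end at or before the chain start, within `pad` residues
    first_start = hits[seed_exon][0]
    for ex in reversed(order[:i]):
        h = hits.get(ex)
        if h is None or not (max(0, first_start - pad) <= h[1] <= first_start):
            break
        chain.insert(0, ex)
        first_start = h[0]
    return chain

order = [1, 2, 3, 4, 5]
hits = {1: (0, 18), 2: (18, 36), 3: (40, 58), 5: (80, 98)}  # no hit for exon 4
print(walk_chain(3, hits, order, pad=6))  # [1, 2, 3]
```

Exon 5 has a hit but is never reached, because exon 4 has none and the downstream walk stops there.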


In [None]:
# ===== Cell 87 =====
@dataclass
class RexChain:
    accession: str
    gene_symbol: str
    seed_exon: int
    exons: List[int]
    coords: List[Tuple[int, int]]
    tiers: List[str]
    mean_score: float
    consecutive_blocks: int

def rex_exon_order_model(training_df: pd.DataFrame, gene: str) -> List[int]:
    exons = (
        training_df[training_df["gene_symbol"] == gene]["exon_num_in_chain"]
        .dropna().astype(int).sort_values().unique().tolist()
    )
    return exons

def rex_walk_chain(
    seq_len: int,
    seed: RexHit,
    hits_for_gene: Dict[int, RexHit],
    exon_order: List[int],
    window_pad: int = rex_search_window_pad
) -> RexChain:
    # Determine exon neighbors
    if seed.exon_num_in_chain not in exon_order:
        return RexChain(seed.accession, seed.gene_symbol, seed.exon_num_in_chain, [seed.exon_num_in_chain],
                        [(seed.start, seed.end)], [seed.tier], seed.score, 1)

    idx = exon_order.index(seed.exon_num_in_chain)
    used = {seed.exon_num_in_chain}
    exons = [seed.exon_num_in_chain]
    coords = [(seed.start, seed.end)]
    tiers = [seed.tier]

    # Walk downstream
    last_end = seed.end
    consec = 1
    for j in range(idx + 1, len(exon_order)):
        ex = exon_order[j]
        h = hits_for_gene.get(ex)
        if not h:
            break
        # Require non-overlap and in-order progression within a window
        if h.start >= last_end and h.start <= last_end + window_pad:
            exons.append(ex)
            coords.append((h.start, h.end))
            tiers.append(h.tier)
            last_end = h.end
            consec += 1
        else:
            break

    # Walk upstream
    first_start = seed.start
    for j in range(idx - 1, -1, -1):
        ex = exon_order[j]
        h = hits_for_gene.get(ex)
        if not h:
            break
        if h.end <= first_start and h.end >= max(0, first_start - window_pad):
            exons.insert(0, ex)
            coords.insert(0, (h.start, h.end))
            tiers.insert(0, h.tier)
            first_start = h.start
            consec += 1
        else:
            break

    mean_score = float(np.mean([hits_for_gene[e].score if e in hits_for_gene else 0.0 for e in exons]))
    return RexChain(seed.accession, seed.gene_symbol, seed.exon_num_in_chain, exons, coords, tiers, mean_score, consec)


## Cell 88 – Orchestrator (build → scan → chain) + ledgers

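Produces:
- `rex_hits_df`: the best hit per accession, gene, and exon
- `rex_chains_df`: one chain per accession and gene, seeded by the top‑scoring hit and kept only if it spans at least `rex_chain_min_consecutive` consecutive exons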


In [None]:
# ===== Cell 88 =====
def rex_run_regextractor(
    training_df: pd.DataFrame,
    rejected_df: pd.DataFrame,
    genes: List[str] = GENE_SYMBOLS,
    min_clade_samples: int = rex_min_clade_samples
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Full pipeline:
      1) Build pattern library per gene/exon/clade
      2) Scan rejected sequences → hits
      3) Build chains per sequence/gene using anchor-and-walk
    """
    lib = RegExTractorLibrary()
    lib.build(training_df, genes, min_clade_samples=min_clade_samples)
    matcher = RegExTractorMatcher(lib)

    hits_rows: List[Dict[str, Any]] = []
    chains_rows: List[Dict[str, Any]] = []

    # Precompute exon order per gene
    exon_orders = {g: rex_exon_order_model(training_df, g) for g in genes}

    for _, row in rejected_df.iterrows():
        seq_hits = matcher.scan_sequence_row(row, genes=genes)
        if not seq_hits:
            continue

        # Record best hit per exon/gene
        for h in seq_hits:
            hits_rows.append({
                "accession": h.accession,
                "gene_symbol": h.gene_symbol,
                "exon_num_in_chain": h.exon_num_in_chain,
                "clade_used": h.clade_used,
                "tier": h.tier,
                "start": h.start,
                "end": h.end,
                "length": h.end - h.start,
                "matched_peptide": h.matched_peptide,
                "ghead_density": h.ghead_density,
                "len_residual": h.len_residual,
                "anchors_hit": h.anchors_hit,
                "score": h.score
            })

        # Build chains for each gene separately, pick strongest seed
        by_gene: Dict[str, List[RexHit]] = {}
        for h in seq_hits:
            by_gene.setdefault(h.gene_symbol, []).append(h)

        for gene, ghits in by_gene.items():
            # Index by exon for quick lookup
            hits_by_exon = {h.exon_num_in_chain: h for h in ghits}
            # Choose seed = highest-score hit
            seed = max(ghits, key=lambda x: x.score)
            chain = rex_walk_chain(
                seq_len=len(row["sequence"]),
                seed=seed,
                hits_for_gene=hits_by_exon,
                exon_order=exon_orders.get(gene, [])
            )
            chains_rows.append({
                "accession": chain.accession,
                "gene_symbol": chain.gene_symbol,
                "seed_exon": chain.seed_exon,
                "exons": chain.exons,
                "coords": chain.coords,
                "tiers": chain.tiers,
                "mean_score": chain.mean_score,
                "consecutive_blocks": chain.consecutive_blocks
            })

    rex_hits_df = pd.DataFrame(hits_rows)
    rex_chains_df = pd.DataFrame(chains_rows)

    # Optional: filter chains by minimal consecutive blocks
    if not rex_chains_df.empty:
        rex_chains_df = rex_chains_df[
            rex_chains_df["consecutive_blocks"].astype(int) >= rex_chain_min_consecutive
        ].reset_index(drop=True)

    rex_log(f"Emitted {len(rex_hits_df)} exon hits and {len(rex_chains_df)} chains.")
    return rex_hits_df, rex_chains_df
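The ledger rows above copy each `RexHit` field by hand. A more compact sketch uses `dataclasses.asdict` and adds the derived `length` column. `Hit` is a hypothetical, trimmed stand-in for `RexHit`, and the accessions and scores are made up for illustration. `seeds_by_gene` shows the per-gene seed choice as a single pass instead of a full sort.

```python
from dataclasses import asdict, dataclass
from typing import Dict, List

@dataclass
class Hit:  # hypothetical, trimmed stand-in for RexHit
    accession: str
    gene_symbol: str
    exon_num_in_chain: int
    tier: str
    start: int
    end: int
    score: float

def hit_row(h: Hit) -> Dict[str, object]:
    """Ledger row: every dataclass field plus the derived match length."""
    row = asdict(h)
    row["length"] = h.end - h.start
    return row

def seeds_by_gene(hits: List[Hit]) -> Dict[str, Hit]:
    """Highest-scoring hit per gene (the first one wins on ties)."""
    best: Dict[str, Hit] = {}
    for h in hits:
        if h.gene_symbol not in best or h.score > best[h.gene_symbol].score:
            best[h.gene_symbol] = h
    return best

hits = [
    Hit("ACC_001", "COL1A1", 12, "B", 300, 354, 0.81),
    Hit("ACC_001", "COL1A1", 13, "A", 354, 408, 0.93),
    Hit("ACC_001", "COL1A2", 7, "C", 120, 165, 0.62),
]
print(hit_row(hits[0])["length"])                         # 54
print(seeds_by_gene(hits)["COL1A1"].exon_num_in_chain)    # 13
```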


## Cell 89 – DNA rescue stub (optional integration point)

After chains are placed, you can bracket a missing exon between its two
neighboring exons and apply a lightweight DNA translation scan. This cell is a
**scaffold** only; plug in your genome accessor when ready.
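Below is a minimal sketch of the translation step the stub describes. It assumes the DNA has already been extracted and oriented on the coding strand, so the reverse complement is not handled here. The G-head density used to rank frames is a hypothetical definition (the best fraction of Gly at every third residue over the three Gly-X-Y phases) and may differ from the notebook's `rex_ghead_density`. The length check is omitted.

```python
from typing import Tuple

# Standard genetic code, codons enumerated in TCAG order ('*' = stop)
_BASES = "TCAG"
_AMINO = "FFLLSSSSYY**CC*WLLLLPPPPHHQQRRRRIIIMTTTTNNKKSSRRVVVVAAAADDEEGGGG"
_CODONS = [a + b + c for a in _BASES for b in _BASES for c in _BASES]
CODON_TABLE = dict(zip(_CODONS, _AMINO))

def translate(dna: str, frame: int) -> str:
    """Translate one forward frame; codons with non-ACGT characters become 'X'."""
    dna = dna.upper()
    return "".join(CODON_TABLE.get(dna[i:i + 3], "X")
                   for i in range(frame, len(dna) - 2, 3))

def ghead_density(pep: str) -> float:
    """Hypothetical G-head density: best Gly fraction at every third residue."""
    best = 0.0
    for phase in range(3):
        slots = pep[phase::3]
        if slots:
            best = max(best, slots.count("G") / len(slots))
    return best

def best_frame(dna: str) -> Tuple[int, str, float]:
    """Return (frame, peptide, density) for the forward frame with the highest density."""
    candidates = [(frame, translate(dna, frame)) for frame in range(3)]
    frame, pep = max(candidates, key=lambda c: ghead_density(c[1]))
    return frame, pep, ghead_density(pep)

# GGT CCT CCA GGT CCT CCA (Gly-Pro-Pro x2) shifted by one leading base
print(best_frame("AGGTCCTCCAGGTCCTCCA"))  # (1, 'GPPGPP', 1.0)
```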


In [None]:
# ===== Cell 89 =====
def rex_dna_rescue_stub(
    accession: str,
    bracket_left: Tuple[int, int],
    bracket_right: Tuple[int, int],
    genome_accessor: Any
) -> Optional[str]:
    """
    Placeholder for DNA rescue step (to be wired to your genome/FASTA accessor):

    - Extract DNA between `bracket_left` and `bracket_right` genomic coordinates.
    - Translate 3 frames on the coding strand.
    - Pick frame with highest G-head density and suitable length.
    - Return peptide if recovered, else None.

    This function is NOT implemented here to keep this section dependency-free.
    """
    _ = (accession, bracket_left, bracket_right, genome_accessor)
    return None


## Cell 90 – How to call this in your notebook

1) Prepare/normalize inputs:

```python
cmap = RexColumnMap(
    gene_col="gene_symbol",
    exon_col="exon_num_in_chain",
    pep_col="exon_peptide",
    seq_col="sequence",
    acc_col="accession",
    tax_cols={"genus":"genus","family":"family","order":"order","class":"class","kingdom":"kingdom"}
)

training_norm = rex_normalize_training_df(mapped_exon_df, cmap)
rejected_norm = rex_normalize_rejected_df(rejected_df, cmap)
```


In [None]:
cmap = RexColumnMap(
    gene_col="gene_symbol",
    exon_col="exon_num_in_chain",
    pep_col="exon_peptide",
    seq_col="sequence",
    acc_col="accession",
    tax_cols={"genus":"genus","family":"family","order":"order","class":"class","kingdom":"kingdom"}
)

training_norm = rex_normalize_training_df(mapped_exon_df, cmap)
rejected_norm = rex_normalize_rejected_df(rejected_df, cmap)


# **Part 9: Reproducibility & Manifest**

## Cell 91 – Manifest writer

Writes a plaintext manifest with the run ID, UTC timestamp, row counts of the main tables, and SHA-256 hashes of the output snapshots (an empty hash marks a missing file).
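Hashing in fixed-size chunks keeps memory use flat for large snapshots. The self-contained check below writes a temporary file with made-up contents (the file names are hypothetical) and confirms that the streamed digest equals a one-shot `hashlib.sha256` of the same bytes, and that a missing file yields an empty string.

```python
import hashlib
import tempfile
from pathlib import Path

def sha256_file(p: Path, chunk_size: int = 65536) -> str:
    """Stream a file through SHA-256 in fixed-size chunks ('' if the file is missing)."""
    if not p.exists():
        return ""
    h = hashlib.sha256()
    with open(p, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

with tempfile.TemporaryDirectory() as tmp:
    snap = Path(tmp) / "example_snapshot.tsv"       # hypothetical output file
    payload = b"accession\tgene_symbol\n" * 10000    # 220,000 bytes, i.e. 4 chunks
    snap.write_bytes(payload)
    streamed = sha256_file(snap)
    one_shot = hashlib.sha256(payload).hexdigest()
    missing = sha256_file(Path(tmp) / "not_written.tsv")

print(streamed == one_shot, repr(missing))  # True ''
```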

In [None]:
# ===== Cell 91 =====
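# Manifest writer
import hashlib
from pathlib import Path

def sha256_file(p: Path) -> str:
    """Return the hex SHA-256 of a file, streamed in 64 KiB chunks ('' if missing)."""
    if not p.exists():
        return ""
    h = hashlib.sha256()
    with open(p, "rb") as f:
        for chunk in iter(lambda: f.read(65536), b""):
            h.update(chunk)
    return h.hexdigest()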

manifest_lines = [
    f"RUN_ID: {RUN_ID}",
    f"TIME_UTC: {RUN_TIMESTAMP}",
    f"WORKING_ROWS: {len(working_df) if 'working_df' in globals() else 0}",
    f"CHAIN_ROWS: {len(chain_df) if 'chain_df' in globals() else 0}",
    f"MAPPED_ROWS: {len(exon_df) if 'exon_df' in globals() else 0}",
    f"CONSENSUS_ROWS: {len(consensus_long) if 'consensus_long' in globals() else 0}",
    f"WIDE_ROWS: {len(wide_df) if 'wide_df' in globals() else 0}",
    "FILES:"
]
for p in [WORKING_SNAPSHOT, REJECTED_SNAPSHOT, MAPPED_SNAPSHOT,
          CONSENSUS_LONG_SNAPSHOT, WIDE_ARCH_SNAPSHOT, ENTROPY_TSV, RESCUE_LOG_TSV]:
    manifest_lines.append(f" - {p.name}: {sha256_file(p)}")

with open(MANIFEST_PATH, 'w') as f:
    f.write("\n".join(manifest_lines))
logger.info(f"Manifest written: {MANIFEST_PATH}")