In [None]:
# ═══════════════════════════════════════════════════════════════════════════════
# TAB 1 — Setup · File Detection · GPU Detection
# Edit FASTA_INPUT and OUTPUT_DIR, then run all four cells in order.
# ═══════════════════════════════════════════════════════════════════════════════

import sys, os, importlib, glob, warnings
from pathlib import Path
warnings.filterwarnings('ignore')

# ── Add repo root to path ─────────────────────────────────────────────────────
_REPO_ROOT = os.path.abspath(os.getcwd())
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)

# ── Auto-install missing packages ─────────────────────────────────────────────
# ``import importlib`` alone does not guarantee that the ``importlib.util``
# submodule is loaded, so import it explicitly before calling find_spec.
import importlib.util

_REQUIRED = [   # (import name, pip requirement spec)
    ('psutil','psutil>=5.8.0'), ('pandas','pandas>=1.3.0'),
    ('numpy','numpy>=1.21.0'),  ('matplotlib','matplotlib>=3.5.0'),
    ('seaborn','seaborn>=0.11.0'), ('openpyxl','openpyxl>=3.0.0'),
    ('tqdm','tqdm>=4.64.0'),
]
_miss = [spec for mod, spec in _REQUIRED if importlib.util.find_spec(mod) is None]
if _miss:
    import subprocess
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', *_miss, '-q'])
    # The import system caches directory listings; invalidate them so the
    # freshly installed packages can be imported in this same kernel session.
    importlib.invalidate_caches()

import pandas as pd, numpy as np

# ── User Configuration ────────────────────────────────────────────────────────
# FASTA_INPUT: a single path, a glob pattern, or a list mixing both.
FASTA_INPUT = ['*.fna', '*.fasta']
# Root output folder; Tab 2 writes each run into a timestamped sub-folder.
OUTPUT_DIR = 'notebook_reports'
# None = all 9 detectors; otherwise passed to analyze_sequence(enabled_classes=...)
# (presumably a list of class names — confirm against Utilities.nonbscanner).
ENABLED_CLASSES = None
# None = auto-detect the memory budget; set an int (bytes) to override it.
RAM_OVERRIDE_BYTES = None
# Output formats written per file and for the master tables.
EXPORT_CSV = EXPORT_BED = EXPORT_JSON = EXPORT_EXCEL = True

# ── GPU detection ─────────────────────────────────────────────────────────────
def _detect_gpu():
    try:
        import torch
        if torch.cuda.is_available():
            return 'cuda', torch.cuda.get_device_name(0)
    except ImportError:
        pass
    try:
        import cupy as cp; cp.array([1])
        return 'cupy', 'CUDA GPU'
    except Exception:
        pass
    return None, None

GPU_BACKEND, GPU_NAME = _detect_gpu()
# One-line environment banner: Python version plus the detected GPU backend.
if GPU_BACKEND:
    _gpu_msg = f'GPU  {GPU_BACKEND} ({GPU_NAME})'
else:
    _gpu_msg = 'GPU  none (CPU only)'
print(f'\u2705 Deps OK | Python {sys.version.split()[0]} | {_gpu_msg}')

# ── Resolve input files ───────────────────────────────────────────────────────
def _resolve(inp):
    out = []
    for p in ([inp] if isinstance(inp, str) else list(inp)):
        hits = glob.glob(p)
        out.extend(hits)
        if not hits and os.path.isfile(p):
            out.append(p)
    return sorted({str(Path(f).resolve()) for f in out})

FASTA_FILES = _resolve(FASTA_INPUT)
# Fail fast: every later step iterates FASTA_FILES.
if not FASTA_FILES:
    raise FileNotFoundError('No FASTA files found: ' + str(FASTA_INPUT))

# ── Sequence-count helper ─────────────────────────────────────────────────────
def _seq_lengths(p):
    L, c = [], 0
    with open(p) as fh:
        for ln in fh:
            s = ln.strip()
            if s.startswith('>'):
                if c: L.append(c)
                c = 0
            else:
                c += len(s)
    if c: L.append(c)
    return L

# ── File-type classification ──────────────────────────────────────────────────
# path -> 'single' | 'multi' | 'multi_equal' | 'empty'
#   single      : exactly one record
#   multi_equal : two or more records, all of the same length
#   multi       : two or more records of differing lengths
#   empty       : no records at all. len(set([])) != 1, so without this
#                 explicit branch an empty file would fall through to 'multi'.
FILE_TYPES = {}
for fp in FASTA_FILES:
    ls = _seq_lengths(fp)
    if not ls:
        FILE_TYPES[fp] = 'empty'
    elif len(ls) == 1:
        FILE_TYPES[fp] = 'single'
    elif len(set(ls)) == 1:
        FILE_TYPES[fp] = 'multi_equal'
    else:
        FILE_TYPES[fp] = 'multi'

# ── GFF pairing (same stem, .gff3 preferred over .gff) ────────────────────────
# fasta_path -> gff_path for every input FASTA (whatever its extension) that
# has a sibling annotation file with the same stem in the same folder.
GFF_MAP = {}
for fp in FASTA_FILES:
    _fa = Path(fp)
    _candidates = [_fa.with_name(_fa.stem + ext) for ext in ('.gff3', '.gff')]
    _match = next((c for c in _candidates if c.exists()), None)
    if _match is not None:
        GFF_MAP[fp] = str(_match)

# ── Summary ───────────────────────────────────────────────────────────────────
# List each input with its classification and, if present, its paired GFF.
print(f'\n\U0001f4c2 Input files: {len(FASTA_FILES)}')
for fp in FASTA_FILES:
    _label = f'[{FILE_TYPES[fp]:12s}]'
    _gff = GFF_MAP.get(fp)
    _tag = f'  +GFF: {Path(_gff).name}' if _gff is not None else ''
    print(f'   {_label}  {Path(fp).name}{_tag}')
print(f'\n\U0001f4c1 Output dir : {OUTPUT_DIR}')


In [None]:
# ═══════════════════════════════════════════════════════════════════════════════
# TAB 2 — Adaptive Analysis  (whole-genome + GFF-region analysis)
# ═══════════════════════════════════════════════════════════════════════════════

import gc, time, datetime, concurrent.futures
import matplotlib; matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

from Utilities.system_resource_inspector import SystemResourceInspector
from Utilities.adaptive_chunk_planner    import AdaptiveChunkPlanner
from Utilities.nonbscanner               import analyze_sequence as _nbf_analyze
from Utilities.utilities                 import (
    read_fasta_file, export_to_csv, export_to_bed,
    export_to_json, export_to_excel,
)

# ── Adaptive resource plan ────────────────────────────────────────────────────
_insp   = SystemResourceInspector()
_budget = RAM_OVERRIDE_BYTES or _insp.get_memory_budget()
_cpus   = _insp.get_cpu_count()
# The plan is sized on the combined on-disk size of all inputs, floored at
# 1,000 bytes (presumably so the planner never receives 0).
_total  = max(sum(os.path.getsize(f) for f in FASTA_FILES if os.path.exists(f)), 1_000)
_plan   = AdaptiveChunkPlanner().plan(_total, _budget, _cpus)
CHUNK_SIZE, CHUNK_OVERLAP = _plan['chunk_size'], _plan['overlap']
N_WORKERS, EXEC_MODE      = _plan['workers'], _plan['mode']

# Boost workers when a GPU backend was detected, capped at the CPU count.
# The outer max() keeps the boost from *lowering* the planner's choice:
# min() alone would cut e.g. 8 planned workers down to a 4-core count.
# NOTE(review): _scan does not pass any GPU setting to analyze_sequence, so
# it is unclear that the scan itself uses the GPU — confirm.
if GPU_BACKEND:
    N_WORKERS = max(N_WORKERS, min(N_WORKERS * 2, os.cpu_count() or 4))

print(f'\u2699\ufe0f  RAM {_budget/1e9:.2f} GB | chunk={CHUNK_SIZE:,} overlap={CHUNK_OVERLAP:,} '
      f'workers={N_WORKERS} mode={EXEC_MODE} gpu={GPU_BACKEND or "none"}')

# ── Run folder ────────────────────────────────────────────────────────────────
# Each run writes into its own timestamped directory under OUTPUT_DIR; runs
# started within the same second share (and merge into) one folder.
_RUN_TS = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
_BASE = Path(OUTPUT_DIR).joinpath(_RUN_TS)
_BASE.mkdir(parents=True, exist_ok=True)
print('\U0001f4c2 Run output: ' + str(_BASE))

# ── Lightweight GFF parser ────────────────────────────────────────────────────
def _parse_gff(gff_path):
    feats = []
    with open(gff_path) as fh:
        for ln in fh:
            if ln.startswith('#') or not ln.strip():
                continue
            p = ln.rstrip('\n').split('\t')
            if len(p) < 8:
                continue
            try:
                feats.append({
                    'seqid':  p[0],
                    'type':   p[2],
                    'start':  max(int(p[3]) - 1, 0),
                    'end':    int(p[4]),
                    'strand': p[6],
                    'attrs':  p[8] if len(p) > 8 else '',
                })
            except ValueError:
                pass
    return feats

# ── Motif scan wrapper ────────────────────────────────────────────────────────
def _scan(name, seq):
    """Run the non-B motif detectors on one sequence using the Tab-2 chunk plan.

    This thin wrapper around ``analyze_sequence`` injects the adaptive
    chunk size/overlap and ENABLED_CLASSES. Parallel chunk processing is
    turned on only when the planner chose ``'hybrid'`` mode. Returns
    whatever ``analyze_sequence`` returns; callers treat it as a list of
    motif dicts.
    """
    parallel_chunks = EXEC_MODE == 'hybrid'
    options = dict(
        use_chunking=True,
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP,
        use_parallel_chunks=parallel_chunks,
        enabled_classes=ENABLED_CLASSES,
    )
    return _nbf_analyze(sequence=seq, sequence_name=name, **options)

# ── Per-file analysis ─────────────────────────────────────────────────────────
from collections import defaultdict

RESULTS_BY_FILE  = {}   # stem -> {df, folder, file_type, path}
GFF_RESULTS      = {}   # stem -> {region_df, gff_path, folder}
sns.set_theme(style='whitegrid')

_MIN_REGION_BP = 12     # 12 bp = minimum motif length in any detector


def _unique_stem(path, used):
    """Return the results key for *path* and record it in *used*.

    The key is the file stem. If that stem is already taken (e.g. ``x.fna``
    and ``x.fasta`` both matched FASTA_INPUT), a ``__2``, ``__3``, ...
    suffix is appended. This keeps a later file from overwriting an earlier
    file's output folder or RESULTS_BY_FILE entry.
    """
    base = Path(path).stem
    key, n = base, 1
    while key in used:
        n += 1
        key = f'{base}__{n}'
    used.add(key)
    return key


def _seq_lookup(seqs):
    """Return a dict that resolves GFF seqids (column 1) to sequences.

    The dict holds every key of *seqs* unchanged. As a fallback it also maps
    the first whitespace-delimited token of each key. GFF seqids are
    normally bare accessions. If ``read_fasta_file`` keys records by the
    full header line (not visible from this notebook — TODO confirm), an
    exact lookup alone would silently skip every feature. Exact keys always
    take precedence over tokens.
    """
    lookup = {}
    for key, seq in seqs.items():
        parts = str(key).split()
        if parts:
            lookup.setdefault(parts[0], seq)
    lookup.update(seqs)
    return lookup


_used_stems = set()
for fasta_path in tqdm(FASTA_FILES, desc='Files', unit='file'):
    stem     = _unique_stem(fasta_path, _used_stems)
    ftype    = FILE_TYPES[fasta_path]
    file_dir = _BASE / stem
    file_dir.mkdir(parents=True, exist_ok=True)
    tqdm.write(f'\n\u2500\u2500 {stem}  [{ftype}] \u2500\u2500')
    if stem != Path(fasta_path).stem:
        tqdm.write(f'  \u26a0\ufe0f  Stem already used \u2014 results for '
                   f'{Path(fasta_path).name} stored as "{stem}"')

    seqs = read_fasta_file(fasta_path)
    if not seqs:
        tqdm.write('  \u26a0\ufe0f  No sequences \u2014 skipping.')
        continue

    # ── Whole-genome scan (parallel over sequences) ───────────────────────────
    # Results are gathered per sequence and concatenated in input order, so
    # the row order does not depend on which thread finishes first.
    per_seq, t0 = {}, time.perf_counter()
    with concurrent.futures.ThreadPoolExecutor(max_workers=max(1, N_WORKERS)) as pool:
        futs = {pool.submit(_scan, sn, sq): sn for sn, sq in seqs.items()}
        for fut in tqdm(concurrent.futures.as_completed(futs),
                        total=len(futs), desc=f'  seqs({stem})', leave=False):
            sn = futs[fut]
            res = fut.result()
            tqdm.write(f'  \u25b8 {sn[:55]}  \u2192 {len(res):,} motifs')
            per_seq[sn] = res
    motifs_file = [mot for sn in seqs for mot in per_seq[sn]]
    tqdm.write(f'  \u2705 {len(motifs_file):,} motifs in {time.perf_counter()-t0:.1f}s')
    gc.collect()

    # ── Build per-file DataFrame ──────────────────────────────────────────────
    df = pd.DataFrame(motifs_file) if motifs_file else pd.DataFrame()
    for col, dflt in [('Class','Unknown'),('Subclass','Other'),('Start',0),
                      ('End',0),('Length',0),('Score',0.0),('Strand','+'),
                      ('Sequence_Name','')]:
        if col not in df.columns: df[col] = dflt
    if not df.empty:
        # Fill missing lengths from the coordinates (never negative).
        m = df['Length'] == 0
        df.loc[m,'Length'] = (df.loc[m,'End'] - df.loc[m,'Start']).clip(lower=0)
    df['Source_File'] = Path(fasta_path).name
    df['File_Type']   = ftype

    # ── Per-file exports ──────────────────────────────────────────────────────
    if not df.empty:
        rows = df.to_dict(orient='records')
        if EXPORT_CSV:   export_to_csv(rows,   filename=str(file_dir/'motifs.csv'))
        if EXPORT_BED:   export_to_bed(rows,   filename=str(file_dir/'motifs.bed'))
        if EXPORT_JSON:  export_to_json(rows,  filename=str(file_dir/'motifs.json'))
        if EXPORT_EXCEL: export_to_excel(rows, filename=str(file_dir/'motifs.xlsx'))

    # ── Per-file class-distribution plot ─────────────────────────────────────
    if not df.empty:
        cc  = df['Class'].value_counts()
        fig, ax = plt.subplots(figsize=(8,3))
        ax.barh(cc.index[::-1], cc.values[::-1])
        ax.set_xlabel('Motif Count'); ax.set_title(f'{stem} [{ftype}] \u2014 Class Distribution')
        plt.tight_layout(); fig.savefig(str(file_dir/'class_distribution.png'), dpi=150)
        plt.close(fig)

    RESULTS_BY_FILE[stem] = {'df':df,'folder':file_dir,'file_type':ftype,'path':fasta_path}

    # ── GFF region analysis ───────────────────────────────────────────────────
    if fasta_path in GFF_MAP:
        gff_path = GFF_MAP[fasta_path]
        tqdm.write(f'  \U0001f4cb GFF: {Path(gff_path).name}')
        features  = _parse_gff(gff_path)
        tqdm.write(f'     {len(features):,} features parsed')

        gff_dir = file_dir / 'gff_regions'
        gff_dir.mkdir(exist_ok=True)

        seq_by_id = _seq_lookup(seqs)
        # Group once instead of re-filtering the whole feature list per type.
        feats_by_type = defaultdict(list)
        for feat in features:
            feats_by_type[feat['type']].append(feat)
        unmatched = sum(1 for feat in features if feat['seqid'] not in seq_by_id)
        if unmatched:
            tqdm.write(f'     \u26a0\ufe0f  {unmatched:,} features reference seqids '
                       f'absent from the FASTA \u2014 skipped')

        region_rows = []
        for ftype_gff in tqdm(sorted(feats_by_type), desc=f'  GFF({stem})', leave=False):
            type_feats  = feats_by_type[ftype_gff]
            type_motifs = []
            for feat in type_feats:
                seq_id = feat['seqid']
                if seq_id not in seq_by_id:
                    continue
                region_seq = seq_by_id[seq_id][feat['start']:feat['end']]
                if len(region_seq) < _MIN_REGION_BP:
                    continue
                rname = f"{seq_id}:{ftype_gff}:{feat['start']}-{feat['end']}({feat['strand']})"
                mots  = _scan(rname, region_seq)
                _a = feat['attrs']
                _attrs_short = _a[:80] + ('...' if len(_a) > 80 else '')
                for m in mots:
                    # Motif Start/End stay relative to the region; the region's
                    # own coordinates are carried in the GFF_* columns.
                    m['GFF_Type']   = ftype_gff
                    m['GFF_SeqID']  = seq_id
                    m['GFF_Start']  = feat['start']
                    m['GFF_End']    = feat['end']
                    m['GFF_Strand'] = feat['strand']
                    m['GFF_Attrs']  = _attrs_short
                type_motifs.extend(mots)
            region_rows.extend(type_motifs)
            tqdm.write(f'     {ftype_gff}: {len(type_feats):,} regions -> {len(type_motifs):,} motifs')
            gc.collect()

        gff_df = pd.DataFrame(region_rows) if region_rows else pd.DataFrame()
        for col, dflt in [('Class','Unknown'),('Subclass','Other'),('Start',0),
                          ('End',0),('Length',0),('Score',0.0),
                          ('GFF_Type',''),('GFF_SeqID',''),
                          ('GFF_Start',0),('GFF_End',0),
                          ('GFF_Strand','+'),('GFF_Attrs','')]:
            if col not in gff_df.columns: gff_df[col] = dflt

        if not gff_df.empty:
            gff_df.to_csv(str(gff_dir/'gff_region_motifs.csv'), index=False)
            pivot = gff_df.groupby(['GFF_Type','Class']).size().unstack(fill_value=0)
            fig, ax = plt.subplots(figsize=(max(8, len(pivot)*1.4), 5))
            pivot.plot(kind='bar', ax=ax, colormap='tab20', width=0.8)
            ax.set_xlabel('GFF Feature Type'); ax.set_ylabel('Motif Count')
            ax.set_title(f'{stem} \u2014 Motifs per GFF Feature Type')
            ax.legend(title='Class', bbox_to_anchor=(1,1))
            plt.tight_layout()
            fig.savefig(str(gff_dir/'gff_motifs_by_type.png'), dpi=150)
            plt.close(fig)

        GFF_RESULTS[stem] = {'region_df':gff_df,'gff_path':gff_path,'folder':gff_dir}
        tqdm.write(f'  \u2705 GFF analysis: {len(gff_df):,} region motifs')

print(f'\n\u2705 Analysis complete \u2014 {len(RESULTS_BY_FILE)} file(s) processed '
      f'({len(GFF_RESULTS)} with GFF).')


In [None]:
# ═══════════════════════════════════════════════════════════════════════════════
# TAB 3 — Master Tables, Summary Plots & Downloads
# ═══════════════════════════════════════════════════════════════════════════════

import base64
from IPython.display import display, HTML, Image

# ── Master DataFrames ─────────────────────────────────────────────────────────
# Stack the non-empty per-file whole-genome tables into one master table.
_dfs = [d for d in (r['df'] for r in RESULTS_BY_FILE.values()) if not d.empty]
_master_df = pd.concat(_dfs, ignore_index=True) if len(_dfs) > 0 else pd.DataFrame()

_master_dir = _BASE.joinpath('_master')
_master_dir.mkdir(exist_ok=True)

# Same for the GFF-region tables.
_gdfs = [g for g in (v['region_df'] for v in GFF_RESULTS.values()) if not g.empty]
_gff_df = pd.concat(_gdfs, ignore_index=True) if len(_gdfs) > 0 else pd.DataFrame()

# ── Summary tables ────────────────────────────────────────────────────────────
_tables = {}
if not _master_df.empty:
    _tables['1_global_class_distribution'] = (
        _master_df.groupby(['Source_File','File_Type','Class'])
        .size().reset_index(name='Count')
    )
    _tables['2_per_file_summary'] = pd.DataFrame([
        {'File': Path(r['path']).name, 'File_Type': r['file_type'],
         'Sequences': r['df']['Sequence_Name'].nunique() if not r['df'].empty else 0,
         'Total_Motifs': len(r['df']),
         'Classes': r['df']['Class'].nunique() if not r['df'].empty else 0}
        for r in RESULTS_BY_FILE.values()
    ])
    _tables['3_class_statistics'] = (
        _master_df.groupby('Class')
        .agg(Total_Count=('Class','count'), Mean_Length=('Length','mean'),
             Mean_Score=('Score','mean'))
        .round(3).reset_index().sort_values('Total_Count', ascending=False)
    )
    _tables['4_file_class_pivot'] = (
        _master_df.groupby(['Source_File','Class'])
        .size().unstack(fill_value=0).reset_index()
    )
    # Positional hot-spots only make sense when all records share a length.
    _eq_dfs = [r['df'] for r in RESULTS_BY_FILE.values()
               if r['file_type'] == 'multi_equal' and not r['df'].empty]
    if _eq_dfs:
        _eq = pd.concat(_eq_dfs, ignore_index=True)
        _tables['5_equal_length_positional'] = (
            _eq.groupby(['Source_File','Class','Start'])
            .size().reset_index(name='Frequency')
            .sort_values(['Source_File','Class','Frequency'], ascending=[True,True,False])
        )

# ── GFF-specific summary tables ───────────────────────────────────────────────
if not _gff_df.empty:
    _tables['6_gff_motifs_per_feature_type'] = (
        _gff_df.groupby(['GFF_Type','Class'])
        .size().reset_index(name='Count')
        .sort_values('Count', ascending=False)
    )
    # Mean_Region_Len is averaged over *distinct* regions. _gff_df has one row
    # per motif, so averaging its rows directly would weight each region by
    # its motif count and let long, motif-rich regions dominate the mean.
    # Regions in which no motif was found do not appear in _gff_df at all.
    _gff_regions = (
        _gff_df[['GFF_Type','GFF_SeqID','GFF_Start','GFF_End','GFF_Strand']]
        .drop_duplicates()
    )
    _gff_regions = _gff_regions.assign(
        Region_Len=(_gff_regions['GFF_End'] - _gff_regions['GFF_Start']).clip(lower=1)
    )
    _tables['7_gff_density_per_feature'] = (
        _gff_df.groupby('GFF_Type')
        .agg(Total_Motifs=('Class','count'),
             Unique_Classes=('Class','nunique'))
        .join(_gff_regions.groupby('GFF_Type')
              .agg(Mean_Region_Len=('Region_Len','mean')))
        .round(2).reset_index()
        .sort_values('Total_Motifs', ascending=False)
    )
    _tables['8_gff_class_pivot'] = (
        _gff_df.groupby(['GFF_Type','Class'])
        .size().unstack(fill_value=0).reset_index()
    )
    _tables['9_gff_top50_hotspot_regions'] = (
        _gff_df.groupby(['GFF_SeqID','GFF_Type','GFF_Start','GFF_End'])
        .agg(Motif_Count=('Class','count'), Classes=('Class','nunique'))
        .reset_index().sort_values('Motif_Count', ascending=False).head(50)
    )

# ── Export all data ────────────────────────────────────────────────────────────
if not _master_df.empty:
    rows = _master_df.to_dict(orient='records')
    # (enabled flag, exporter, output file name) — written in this order.
    for _enabled, _export, _fname in (
        (EXPORT_CSV,   export_to_csv,   'master_motifs.csv'),
        (EXPORT_BED,   export_to_bed,   'master_motifs.bed'),
        (EXPORT_JSON,  export_to_json,  'master_motifs.json'),
        (EXPORT_EXCEL, export_to_excel, 'master_motifs.xlsx'),
    ):
        if _enabled:
            _export(rows, filename=str(_master_dir / _fname))
if not _gff_df.empty:
    _gff_df.to_csv(str(_master_dir / 'gff_region_motifs_all.csv'), index=False)
# Every summary table is also written as its own CSV.
for tname, tdf in _tables.items():
    tdf.to_csv(str(_master_dir / (tname + '.csv')), index=False)

# ── Master summary plots ──────────────────────────────────────────────────────
# Count only the panels that will actually be drawn. This leaves no blank
# axes when no whole-genome motifs were found, and never asks plt.subplots
# for a zero-column grid (which raises).
_show_master   = not _master_df.empty
_show_per_file = _show_master and len(RESULTS_BY_FILE) > 1
_show_gff      = not _gff_df.empty
_ncol = (2 if _show_master else 0) + int(_show_per_file) + int(_show_gff)

if _ncol:
    fig, axes = plt.subplots(1, _ncol, figsize=(6*_ncol, 4))
    # plt.subplots returns a bare Axes (not an array) when there is one panel.
    _ax = iter([axes] if _ncol == 1 else axes.flat)

    if _show_master:
        cc = _master_df['Class'].value_counts()
        a  = next(_ax)
        a.barh(cc.index[::-1], cc.values[::-1], color='steelblue')
        a.set_xlabel('Count'); a.set_title('Global Class Distribution')

        ft = _master_df.groupby('File_Type').size()
        a  = next(_ax)
        a.pie(ft.values, labels=ft.index, autopct='%1.1f%%', startangle=90)
        a.set_title('Motifs by File Type')

    if _show_per_file:
        pf = _master_df.groupby('Source_File').size().sort_values(ascending=False)
        a  = next(_ax)
        a.barh([Path(n).stem[:22] for n in pf.index[::-1]], pf.values[::-1], color='coral')
        a.set_xlabel('Count'); a.set_title('Motifs per File')

    if _show_gff:
        a  = next(_ax)
        gt = _gff_df['GFF_Type'].value_counts().head(15)
        a.barh(gt.index[::-1], gt.values[::-1], color='mediumseagreen')
        a.set_xlabel('Count'); a.set_title('Motifs by GFF Feature Type')

    plt.tight_layout()
    _plot = str(_master_dir/'master_summary.png')
    fig.savefig(_plot, dpi=150); plt.close(fig)
    display(Image(_plot))
else:
    print('\u2139\ufe0f  No motifs to plot \u2014 master summary figure skipped.')

# ── GFF feature-type x class heatmap ─────────────────────────────────────────
# Drawn only when GFF motifs exist (table 8 is built from them).
if '8_gff_class_pivot' in _tables and not _gff_df.empty:
    _piv = _tables['8_gff_class_pivot'].set_index('GFF_Type')
    _hm_width = max(10, len(_piv.columns) * 1.2)
    _hm_height = max(4, len(_piv) * 0.6)
    fig2, ax2 = plt.subplots(figsize=(_hm_width, _hm_height))
    sns.heatmap(
        _piv, annot=True, fmt='d', cmap='YlOrRd', ax=ax2,
        linewidths=0.4, cbar_kws={'label': 'Motif count'},
    )
    ax2.set_title('GFF Feature Type \u00d7 Non-B Class Heatmap')
    ax2.set_xlabel('Non-B Class')
    ax2.set_ylabel('GFF Feature Type')
    plt.tight_layout()
    _gheat = str(_master_dir / 'gff_class_heatmap.png')
    fig2.savefig(_gheat, dpi=150)
    plt.close(fig2)
    display(Image(_gheat))

# ── Display tables ────────────────────────────────────────────────────────────
# Each table is preceded by a '=' ruled, upper-cased heading.
_bar = '=' * 60
for tname, tdf in _tables.items():
    _heading = tname.replace('_', ' ').upper()
    print('\n' + _bar + '\n' + _heading + '\n' + _bar)
    display(tdf)

# ── Download helpers ──────────────────────────────────────────────────────────
_MIME = {'csv':'text/csv','bed':'text/plain','json':'application/json',
         'xlsx':'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
         'png':'image/png'}

def _dl(path, label):
    with open(path,'rb') as fh: b64 = base64.b64encode(fh.read()).decode()
    ext  = Path(path).suffix.lstrip('.')
    mime = _MIME.get(ext,'application/octet-stream')
    return (f'<a href="data:{mime};base64,{b64}" download="{Path(path).name}" '
            f'style="margin:2px 6px;padding:3px 8px;border:1px solid #aaa;'
            f'border-radius:4px;text-decoration:none;">{label}</a>')

import html


def _links(folder, items):
    """Return ``_dl`` anchor tags for each ``(relative_name, label)`` pair in
    *items* whose file exists under *folder*; missing files are skipped."""
    tags = []
    for fn, lbl in items:
        p = Path(folder) / fn
        if p.exists():
            tags.append(_dl(str(p), lbl))
    return tags


_html = ['<h2>\U0001f4e5 Downloads</h2><h3>Master Outputs</h3><div>']
_html += _links(_master_dir, [('master_motifs.csv', 'Master CSV'),
                              ('master_motifs.bed', 'Master BED'),
                              ('master_motifs.json', 'Master JSON'),
                              ('master_motifs.xlsx', 'Master Excel'),
                              ('gff_region_motifs_all.csv', 'GFF Regions CSV')])
_html.append('</div><h3>Summary Tables</h3><div>')
_html += _links(_master_dir, [(f'{tn}.csv', tn.replace('_', ' ').title()) for tn in _tables])
_html.append('</div><h3>Summary Plots</h3><div>')
_html += _links(_master_dir, [('master_summary.png', 'Master Summary'),
                              ('gff_class_heatmap.png', 'GFF Heatmap')])
_html.append('</div><h3>Per-File Outputs</h3>')
for stem, res in RESULTS_BY_FILE.items():
    # Stems come from user file names; escape them before embedding in HTML.
    _html.append(f'<details style="margin:4px 0"><summary><b>{html.escape(stem)}</b>'
                 f' <em>[{html.escape(res["file_type"])}]</em></summary><div style="margin:4px 12px">')
    _html += _links(res['folder'], [('motifs.csv', 'CSV'), ('motifs.bed', 'BED'),
                                    ('motifs.json', 'JSON'), ('motifs.xlsx', 'Excel'),
                                    ('class_distribution.png', 'Plot'),
                                    ('gff_regions/gff_motifs_by_type.png', 'GFF Plot'),
                                    ('gff_regions/gff_region_motifs.csv', 'GFF CSV')])
    _html.append('</div></details>')

display(HTML('\n'.join(_html)))
print(f'\n\u2705 All outputs: {_BASE}')


In [None]:
# ═══════════════════════════════════════════════════════════════════════════════
# TAB 4 — Species & Region Comparison
#
# Groups files by species prefix (e.g. "Homo.Sapiens" from
# "Homo.Sapiens_promoters.fasta") and the region tag after the first "_".
#
# Within-species  : class distribution, subclass distribution, motif density,
#                   motif length, and GC % across all region types.
# Cross-species   : same metrics compared between species for shared regions.
#
# Requires Cell 1 (Setup) and Cell 2 (Analysis) to have been run first.
# ═══════════════════════════════════════════════════════════════════════════════

import re
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from IPython.display import display, Image

# Apply the same seaborn 'whitegrid' theme used in Tab 2 (this cell re-imports
# its own dependencies so it can also be read on its own).
sns.set_theme(style='whitegrid')

# ── helpers ──────────────────────────────────────────────────────────────────────────────

def _parse_species_region(stem):
    """Split 'Homo.Sapiens_promoters' into ('Homo.Sapiens', 'promoters').
    Falls back to (stem, 'unknown') when no '_' separator is present."""
    idx = stem.find('_')
    if idx == -1:
        return stem, 'unknown'
    return stem[:idx], stem[idx + 1:]


def _gc_and_length(fasta_path):
    """Return (gc_percent, total_bp) from a FASTA file."""
    gc = total = 0
    with open(fasta_path) as fh:
        for ln in fh:
            s = ln.strip()
            if not s or s.startswith('>'):
                continue
            su = s.upper()
            gc    += su.count('G') + su.count('C')
            total += len(su)
    gc_pct = (gc / total * 100) if total else 0.0
    return round(gc_pct, 2), total


def _savefig(fig, path):
    """Save *fig* to *path* (150 dpi, tight bounding box), close it, and show the image inline."""
    target = str(path)
    fig.savefig(target, dpi=150, bbox_inches='tight')
    plt.close(fig)
    display(Image(target))


# ── build comparison data frame ──────────────────────────────────────────────────────
# One row per analysed file: motif counts, motif density per kb of sequence,
# GC %, and motif-length / class diversity statistics.

_comp_rows = []
for stem, res in RESULTS_BY_FILE.items():
    species, region = _parse_species_region(stem)
    df  = res['df']
    fp  = res['path']
    gc_pct, seq_len = _gc_and_length(fp)
    n   = len(df)
    density = (n / seq_len * 1000) if seq_len > 0 else 0.0
    _comp_rows.append({
        'Stem':                stem,
        'Species':             species,
        'Region':              region,
        'Total_Motifs':        n,
        'Seq_Length_bp':       seq_len,
        'Density_per_kb':      round(density, 4),
        'GC_Percent':          gc_pct,
        'Mean_Motif_Length':   round(df['Length'].mean(),   2) if not df.empty else 0.0,
        'Median_Motif_Length': round(df['Length'].median(), 2) if not df.empty else 0.0,
        'Unique_Classes':      df['Class'].nunique()    if not df.empty else 0,
        'Unique_Subclasses':   df['Subclass'].nunique() if not df.empty else 0,
    })

# Listing the columns explicitly keeps the 'Species'/'Region' lookups below
# valid (just empty) when no file produced results. Otherwise a column-less
# DataFrame would raise KeyError here.
_comp_df = pd.DataFrame(_comp_rows, columns=[
    'Stem', 'Species', 'Region', 'Total_Motifs', 'Seq_Length_bp',
    'Density_per_kb', 'GC_Percent', 'Mean_Motif_Length',
    'Median_Motif_Length', 'Unique_Classes', 'Unique_Subclasses',
])
_species_list = sorted(_comp_df['Species'].unique())
_region_list  = sorted(_comp_df['Region'].unique())
print(f'Species detected : {_species_list}')
print(f'Regions detected : {_region_list}')

_cmp_dir = _BASE / '_comparisons'
_cmp_dir.mkdir(exist_ok=True)

# ══════════════════════════════════════════════════════════════════════════════
# WITHIN-SPECIES COMPARISON
# ══════════════════════════════════════════════════════════════════════════════

# The banner rule is built outside the f-string: a backslash escape such as
# '\u2550' inside an f-string replacement field is a SyntaxError before
# Python 3.12, which would stop this whole cell from compiling.
_WS_RULE = '\u2550' * 60

for species in _species_list:
    sp_rows  = _comp_df[_comp_df['Species'] == species].copy()
    sp_stems = sp_rows['Stem'].tolist()
    sp_dir   = _cmp_dir / species
    sp_dir.mkdir(exist_ok=True)

    if len(sp_stems) < 2:
        print(f"\n\u26a0  '{species}' has only one region file — skipping within-species plots.")
        continue

    print(f"\n{_WS_RULE}\nWithin-species comparison: {species}\n{_WS_RULE}")

    # stem -> region label (stems are unique keys of RESULTS_BY_FILE)
    _region_of = dict(zip(sp_rows['Stem'], sp_rows['Region']))

    # ── 1. Class distribution by region ───────────────────────────────────────────
    _class_by_region = {}
    for stem in sp_stems:
        df = RESULTS_BY_FILE[stem]['df']
        _class_by_region[_region_of[stem]] = (
            df['Class'].value_counts() if not df.empty else pd.Series(dtype=int)
        )

    _all_cls = sorted({c for s in _class_by_region.values() for c in s.index})
    if _all_cls:
        _cmat = pd.DataFrame(
            {r: s.reindex(_all_cls, fill_value=0)
             for r, s in _class_by_region.items()}
        ).T
        fig, ax = plt.subplots(figsize=(max(8, len(_all_cls) * 1.2), 4))
        _cmat.plot(kind='bar', ax=ax, colormap='tab20', width=0.8)
        ax.set_title(f'{species} \u2014 Class Distribution by Region')
        ax.set_xlabel('Region'); ax.set_ylabel('Motif Count')
        ax.legend(title='Class', bbox_to_anchor=(1, 1))
        plt.xticks(rotation=30, ha='right')
        _savefig(fig, sp_dir / 'class_by_region.png')

    # ── 2. Subclass distribution by region ─────────────────────────────────────────
    _sub_by_region = {}
    for stem in sp_stems:
        df = RESULTS_BY_FILE[stem]['df']
        _sub_by_region[_region_of[stem]] = (
            df['Subclass'].value_counts() if not df.empty else pd.Series(dtype=int)
        )

    _all_subs = sorted({c for s in _sub_by_region.values() for c in s.index})
    if _all_subs:
        _smat = pd.DataFrame(
            {r: s.reindex(_all_subs, fill_value=0)
             for r, s in _sub_by_region.items()}
        ).T
        fig, ax = plt.subplots(figsize=(max(8, len(_all_subs) * 1.2), 4))
        _smat.plot(kind='bar', ax=ax, colormap='tab20', width=0.8)
        ax.set_title(f'{species} \u2014 Subclass Distribution by Region')
        ax.set_xlabel('Region'); ax.set_ylabel('Motif Count')
        ax.legend(title='Subclass', bbox_to_anchor=(1, 1))
        plt.xticks(rotation=30, ha='right')
        _savefig(fig, sp_dir / 'subclass_by_region.png')

    # ── 3. Coverage density by region ─────────────────────────────────────────────
    fig, ax = plt.subplots(figsize=(max(6, len(sp_rows) * 1.2), 4))
    bars = ax.bar(sp_rows['Region'], sp_rows['Density_per_kb'], color='steelblue')
    ax.bar_label(bars, fmt='%.3f', padding=2)
    ax.set_title(f'{species} \u2014 Motif Density (motifs per kb) by Region')
    ax.set_xlabel('Region'); ax.set_ylabel('Motifs per kb')
    plt.xticks(rotation=30, ha='right')
    _savefig(fig, sp_dir / 'density_by_region.png')

    # ── 4. Motif length distribution by region ───────────────────────────────────
    _len_parts = []
    for stem in sp_stems:
        df = RESULTS_BY_FILE[stem]['df']
        if not df.empty and 'Length' in df.columns:
            tmp = df[['Length']].copy()
            tmp['Region'] = _region_of[stem]
            _len_parts.append(tmp)
    if _len_parts:
        _len_df = pd.concat(_len_parts, ignore_index=True)
        _len_df = _len_df[_len_df['Length'] > 0]
        if not _len_df.empty:
            fig, ax = plt.subplots(figsize=(max(8, len(sp_stems) * 2), 4))
            sns.boxplot(data=_len_df, x='Region', y='Length', ax=ax, palette='Set2')
            ax.set_title(f'{species} \u2014 Motif Length Distribution by Region')
            ax.set_xlabel('Region'); ax.set_ylabel('Motif Length (bp)')
            plt.xticks(rotation=30, ha='right')
            _savefig(fig, sp_dir / 'length_by_region.png')

    # ── 5. GC content by region ───────────────────────────────────────────────────
    fig, ax = plt.subplots(figsize=(max(6, len(sp_rows) * 1.2), 4))
    bars = ax.bar(sp_rows['Region'], sp_rows['GC_Percent'], color='mediumseagreen')
    ax.bar_label(bars, fmt='%.1f%%', padding=2)
    ax.set_title(f'{species} \u2014 GC Content (%) by Region')
    ax.set_xlabel('Region'); ax.set_ylabel('GC %')
    ax.set_ylim(0, 100)
    plt.xticks(rotation=30, ha='right')
    _savefig(fig, sp_dir / 'gc_by_region.png')

    # ── Summary table ────────────────────────────────────────────────────────────────────
    _sp_summary = sp_rows.set_index('Region')[[
        'Total_Motifs', 'Seq_Length_bp', 'Density_per_kb', 'GC_Percent',
        'Mean_Motif_Length', 'Median_Motif_Length',
        'Unique_Classes', 'Unique_Subclasses'
    ]]
    _sp_summary.to_csv(str(sp_dir / 'within_species_summary.csv'))
    print(f'\n{species} \u2014 Summary')
    display(_sp_summary)


# ══════════════════════════════════════════════════════════════════════════════
# CROSS-SPECIES COMPARISON  (only when >= 2 species are present)
# ══════════════════════════════════════════════════════════════════════════════

if len(_species_list) >= 2:
    _xs_dir = _cmp_dir / '_cross_species'
    _xs_dir.mkdir(exist_ok=True)

    # Built outside the f-string: a backslash escape inside an f-string
    # replacement field is a SyntaxError before Python 3.12.
    _XS_RULE = '\u2550' * 60
    print(f"\n{_XS_RULE}\nCross-species comparison: {_species_list}\n{_XS_RULE}")

    _sp_region_sets = {
        sp: set(_comp_df[_comp_df['Species'] == sp]['Region'])
        for sp in _species_list
    }
    _shared_regions = sorted(set.intersection(*_sp_region_sets.values()))
    _all_regions    = sorted(set.union(*_sp_region_sets.values()))
    print(f'Shared regions : {_shared_regions}')
    print(f'All regions    : {_all_regions}')

    # ── 1. Density heatmap (species x region) ────────────────────────────────────
    _dens_pivot = _comp_df.pivot_table(
        index='Species', columns='Region',
        values='Density_per_kb', aggfunc='mean'
    )
    if not _dens_pivot.empty:
        fig, ax = plt.subplots(figsize=(max(8, len(_all_regions) * 1.4),
                                        max(4, len(_species_list) * 0.8)))
        sns.heatmap(_dens_pivot, annot=True, fmt='.3f', cmap='YlOrRd', ax=ax,
                    linewidths=0.4, cbar_kws={'label': 'Motifs per kb'})
        ax.set_title('Cross-Species \u2014 Motif Density Heatmap (motifs per kb)')
        ax.set_xlabel('Region'); ax.set_ylabel('Species')
        _savefig(fig, _xs_dir / 'cross_species_density_heatmap.png')

    # ── 2. GC% heatmap (species x region) ────────────────────────────────────────
    _gc_pivot = _comp_df.pivot_table(
        index='Species', columns='Region',
        values='GC_Percent', aggfunc='mean'
    )
    if not _gc_pivot.empty:
        fig, ax = plt.subplots(figsize=(max(8, len(_all_regions) * 1.4),
                                        max(4, len(_species_list) * 0.8)))
        sns.heatmap(_gc_pivot, annot=True, fmt='.1f', cmap='YlGn', ax=ax,
                    linewidths=0.4, cbar_kws={'label': 'GC %'})
        ax.set_title('Cross-Species \u2014 GC Content Heatmap (%)')
        ax.set_xlabel('Region'); ax.set_ylabel('Species')
        _savefig(fig, _xs_dir / 'cross_species_gc_heatmap.png')

    # ── 3. Class comparison per shared region ─────────────────────────────────────
    for region in _shared_regions:
        _in_region = _comp_df[_comp_df['Region'] == region]
        _rc = {}
        for stem, sp in zip(_in_region['Stem'], _in_region['Species']):
            df = RESULTS_BY_FILE[stem]['df']
            _rc[sp] = (
                df['Class'].value_counts() if not df.empty else pd.Series(dtype=int)
            )
        _all_cls = sorted({c for s in _rc.values() for c in s.index})
        if _all_cls:
            _rmat = pd.DataFrame(
                {sp: s.reindex(_all_cls, fill_value=0) for sp, s in _rc.items()}
            ).T
            fig, ax = plt.subplots(figsize=(max(8, len(_all_cls) * 1.2), 4))
            _rmat.plot(kind='bar', ax=ax, colormap='tab20', width=0.8)
            ax.set_title(f'Cross-Species Class Comparison \u2014 {region}')
            ax.set_xlabel('Species'); ax.set_ylabel('Motif Count')
            ax.legend(title='Class', bbox_to_anchor=(1, 1))
            plt.xticks(rotation=30, ha='right')
            _safe = re.sub(r'[^\w\-]', '_', region)   # filesystem-safe region tag
            _savefig(fig, _xs_dir / f'cross_species_class_{_safe}.png')

    # ── 4. Length distribution for shared regions ───────────────────────────────
    _xs_len_parts = []
    for stem, sp, region in zip(_comp_df['Stem'], _comp_df['Species'], _comp_df['Region']):
        if region not in _shared_regions:
            continue
        df = RESULTS_BY_FILE[stem]['df']
        if not df.empty and 'Length' in df.columns:
            tmp = df[['Length']].copy()
            tmp['Species'] = sp
            tmp['Region']  = region
            _xs_len_parts.append(tmp)
    if _xs_len_parts:
        _xs_len_df = pd.concat(_xs_len_parts, ignore_index=True)
        _xs_len_df = _xs_len_df[_xs_len_df['Length'] > 0]
        if not _xs_len_df.empty:
            _w = max(10, len(_shared_regions) * len(_species_list) * 1.5)
            fig, ax = plt.subplots(figsize=(_w, 5))
            sns.boxplot(data=_xs_len_df, x='Region', y='Length',
                        hue='Species', ax=ax, palette='Set2')
            ax.set_title('Cross-Species \u2014 Motif Length Distribution by Region')
            ax.set_xlabel('Region'); ax.set_ylabel('Motif Length (bp)')
            plt.xticks(rotation=30, ha='right')
            ax.legend(title='Species', bbox_to_anchor=(1, 1))
            _savefig(fig, _xs_dir / 'cross_species_length_by_region.png')

    # ── Cross-species summary table ────────────────────────────────────────────────────
    _xs_summary = _comp_df[[
        'Species', 'Region', 'Total_Motifs', 'Seq_Length_bp',
        'Density_per_kb', 'GC_Percent', 'Mean_Motif_Length',
        'Median_Motif_Length', 'Unique_Classes', 'Unique_Subclasses'
    ]].sort_values(['Species', 'Region'])
    _xs_summary.to_csv(str(_xs_dir / 'cross_species_summary.csv'), index=False)
    print('\nCross-Species Summary Table')
    display(_xs_summary)

# ── Save master comparison CSV ─────────────────────────────────────────────────────────────
# One CSV holding every per-file comparison row (all species, all regions).
_all_csv = _cmp_dir / 'all_comparisons_summary.csv'
_comp_df.to_csv(str(_all_csv), index=False)
print('\n\u2705 Comparison outputs saved to: ' + str(_cmp_dir))
