# Config

In [None]:
# --- Config & helpers ---
from pathlib import Path
from Bio import SeqIO, SeqRecord, Seq
# `from Bio import SeqRecord, Seq` binds the submodules; the code calls the classes,
# so rebind both names to the classes themselves.
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
import pandas as pd, numpy as np
import os, re, shutil, subprocess
from pathlib import Path
import matplotlib.pyplot as plt



# Paths
# NOTE(review): "/~/..." is an absolute path with a literal "~" directory; pathlib does not
# expand "~" here. Confirm this is intended (vs. Path("~/...").expanduser()).
BASE        = Path("/~/HMM-PROFILES-FOR-CA-ENZYMES")
OG          = BASE / "OG_Labels"                 # original per-class FASTAs (*.txt)
NEW_LABELS  = BASE / "new_labels"                # <class>_<n>_seqdump.txt
REBASE      = BASE / "OUTPUT"                   # workspace
RE_OG_UPD   = REBASE / "OG_Labels_UPDATED"       # OUTPUT: merged + dedup per-class FASTAs
SPLIT_SRC   = REBASE / "per_class_split_sources" # recursively scanned

# Outputs for analysis
OUT_DIR   = REBASE / "length_consensus_outputs"
ALIGN_DIR = OUT_DIR / "aligned"      # aligner input/output FASTAs
CONS_DIR  = OUT_DIR / "consensus"    # majority-consensus FASTAs
HM_DIR    = OUT_DIR / "heatmaps"     # consensus-similarity heatmap PNGs
# Create the workspace and output folders up front. The input folders (OG, NEW_LABELS,
# SPLIT_SRC) are deliberately not created here.
for d in [REBASE, RE_OG_UPD, OUT_DIR, ALIGN_DIR, CONS_DIR, HM_DIR]:
    d.mkdir(parents=True, exist_ok=True)

def first_token(s: str) -> str:
    """Return the first whitespace-delimited token of ``s``.

    Falsy input (``""`` or ``None``) is returned unchanged, as before. A string
    containing only whitespace yields ``""`` instead of raising ``IndexError``.
    """
    if not s:
        return s
    tokens = s.split()
    return tokens[0] if tokens else ""

def load_fasta_records(path: Path):
    """Parse the FASTA file at ``path`` and return its records with normalised headers.

    Each record's ``id`` is reduced to the first whitespace-delimited token of its
    id (or of its description when the id is empty) and its ``description`` is
    cleared, so written headers contain only the bare identifier.
    """
    def _normalise(rec):
        rec.id = first_token(rec.id or rec.description)
        rec.description = ""
        return rec

    return [_normalise(rec) for rec in SeqIO.parse(str(path), "fasta")]

def write_fasta(records, path: Path):
    """Write ``records`` to ``path`` as FASTA; an empty ``records`` leaves ``path`` untouched."""
    if not records:
        return
    SeqIO.write(records, str(path), "fasta")


In [None]:
# Seed UPDATED from OG (ensures every class file exists)
for og_file in sorted(OG.glob("*.txt")):
    seeded = RE_OG_UPD / og_file.name
    if seeded.exists():
        continue  # keep an already-updated class file as is
    shutil.copy2(og_file, seeded)

# Index current IDs: class-file stem -> set of first-token sequence IDs
existing_ids = {
    class_file.stem: {
        first_token(rec.id or rec.description)
        for rec in SeqIO.parse(str(class_file), "fasta")
    }
    for class_file in sorted(RE_OG_UPD.glob("*.txt"))
}

# Infer class from split filenames or parent dir
class_pat_list = [
    re.compile(r'^(?P<cls>[A-Za-z]+)[._-]'),        # alpha.something.fa
    re.compile(r'8ca[._-](?P<cls>[A-Za-z]+)[._-]'), # 8ca.zeta.45.*
    re.compile(r'^(?P<cls>[A-Za-z]+)$'),            # plain stem
]
def infer_class_from_path(p: Path):
    """Return the lower-cased class name encoded in ``p``, or ``None``.

    The file stem is tried against ``class_pat_list`` in order. If none of the
    patterns hit, a purely alphabetic parent directory name is used instead.
    """
    stem_hits = (pat.search(p.stem) for pat in class_pat_list)
    hit = next((m for m in stem_hits if m), None)
    if hit is None:
        hit = re.match(r'^(?P<cls>[A-Za-z]+)$', p.parent.name)
    if hit is None:
        return None
    return hit.group("cls").lower()

def _targets_for_class(cls_simple: str):
    """Return the UPDATED class files whose stem contains ``cls_simple`` as a whole word.

    Matching is case-insensitive. When no class file matches, an empty
    ``8ca.<cls>.extra.txt`` is created in ``RE_OG_UPD`` and returned as the only target.
    """
    word = re.compile(rf'\b{re.escape(cls_simple)}\b', flags=re.IGNORECASE)
    targets = [p for p in RE_OG_UPD.glob("*.txt") if word.search(p.stem)]
    if not targets:
        targets = [RE_OG_UPD / f"8ca.{cls_simple}.extra.txt"]
        targets[0].touch(exist_ok=True)
    return targets


def _merge_into_targets(src: Path, cls_simple: str):
    """Append the records of ``src`` to every target file of ``cls_simple``.

    A record is added only when its ID is not yet known for that target.
    ``existing_ids`` is updated as records are added, which also drops repeated
    IDs within ``src`` itself. Each target is rewritten through ``write_fasta``.
    """
    new_recs = load_fasta_records(src)
    for out_path in _targets_for_class(cls_simple):
        known = existing_ids.setdefault(out_path.stem, set())
        current = list(SeqIO.parse(str(out_path), "fasta"))
        for rec in new_recs:
            if rec.id in known:
                continue
            current.append(rec)
            known.add(rec.id)
        write_fasta(current, out_path)


# Gather split files
split_files = []
if SPLIT_SRC.exists():
    for ext in ("*.fa","*.fasta","*.faa","*.fas","*.txt"):
        split_files += list(SPLIT_SRC.rglob(ext))

# Merge SPLIT sources (class inferred from the file name or parent directory)
for f in sorted(split_files):
    cls_simple = infer_class_from_path(f)
    if cls_simple:
        _merge_into_targets(f, cls_simple)

# Merge NEW_LABELS dumps (class taken from "<class>_<n>_seqdump.txt")
dump_re = re.compile(r'^(?P<cls>[A-Za-z]+)_(?P<num>\d+)_seqdump\.txt$')
if NEW_LABELS.exists():
    for f in sorted(NEW_LABELS.glob("*.txt")):
        m = dump_re.match(f.name)
        if m:
            _merge_into_targets(f, m.group("cls").lower())

# Quick summary: number of sequences per UPDATED class file
summary = [
    (class_file.name, sum(1 for _ in SeqIO.parse(str(class_file), "fasta")))
    for class_file in sorted(RE_OG_UPD.glob("*.txt"))
]
pd.DataFrame(summary, columns=["class_file","n_sequences"]).sort_values("class_file")


# EDA

In [None]:
def seq_len(rec: SeqRecord) -> int:
    """Return the length of ``rec.seq`` without counting gap characters ('-' and '.')."""
    residues = str(rec.seq)
    return len(residues) - residues.count("-") - residues.count(".")

def trimmed_mean(arr, lo=0.25, hi=0.75):
    """Return the mean of the sorted values between the ``lo`` and ``hi`` rank fractions.

    The slice runs from ``floor(n*lo)`` to ``ceil(n*hi)`` of the sorted values.
    Returns NaN for ``None``, empty input, or an empty slice. Accepts any
    iterable, including numpy arrays (the former ``if not arr`` check raised
    ``ValueError`` for arrays with more than one element).
    """
    a = np.sort(np.asarray(list(arr) if arr is not None else []))
    if a.size == 0:
        return float("nan")
    i0, i1 = int(np.floor(len(a) * lo)), int(np.ceil(len(a) * hi))
    mid = a[i0:i1]
    return float(np.mean(mid)) if len(mid) else float("nan")

def trimmed_median(arr, lo=0.25, hi=0.75):
    """Return the median of the sorted values between the ``lo`` and ``hi`` rank fractions.

    The slice runs from ``floor(n*lo)`` to ``ceil(n*hi)`` of the sorted values.
    Returns NaN for ``None``, empty input, or an empty slice. Accepts any
    iterable, including numpy arrays (the former ``if not arr`` check raised
    ``ValueError`` for arrays with more than one element).
    """
    a = np.sort(np.asarray(list(arr) if arr is not None else []))
    if a.size == 0:
        return float("nan")
    i0, i1 = int(np.floor(len(a) * lo)), int(np.ceil(len(a) * hi))
    mid = a[i0:i1]
    return float(np.median(mid)) if len(mid) else float("nan")

def _length_stats_row(label, lengths, n_seqs):
    """Build one row of the length-statistics table for the list ``lengths``."""
    missing = float("nan")
    populated = bool(lengths)
    return {
        "class": label,
        "n_seqs": n_seqs,
        "mean_len": np.mean(lengths) if populated else missing,
        "median_len": np.median(lengths) if populated else missing,
        # sample standard deviation needs at least two values
        "std_len": np.std(lengths, ddof=1) if len(lengths) > 1 else missing,
        "trimmed_mean_len_25_75": trimmed_mean(lengths),
        "trimmed_median_len_25_75": trimmed_median(lengths),
        "min_len": int(np.min(lengths)) if populated else None,
        "q25_len": int(np.quantile(lengths, 0.25)) if populated else None,
        "q75_len": int(np.quantile(lengths, 0.75)) if populated else None,
        "max_len": int(np.max(lengths)) if populated else None,
    }


rows = []
all_lengths = []
for class_file in sorted(RE_OG_UPD.glob("*.txt")):
    class_lengths = [seq_len(rec) for rec in SeqIO.parse(str(class_file), "fasta")]
    all_lengths.extend(class_lengths)
    rows.append(_length_stats_row(class_file.stem, class_lengths, len(class_lengths)))

# Pooled row; its count is the sum of the per-class counts gathered so far.
rows.append(_length_stats_row("__ALL__", all_lengths, sum(r["n_seqs"] for r in rows)))

df_stats = pd.DataFrame(rows).sort_values("class")
df_stats


In [None]:
# ECDF per class + explicit overall, with mean markers

IN_DIR = RE_OG_UPD                                  # folder holding the per-class FASTAs (*.txt)
SAVE_PNG = OUT_DIR / "length_ecdf_per_class.png"    # where the combined figure is saved
INCLUDE_OVERALL = True                              # also draw the pooled ("__ALL__") ECDF
MAX_CLASSES_PLOTTED = None                          # None = plot every class; otherwise keep the N largest classes
USE_LOG_X = False                                   # log-scale the length axis
SHOW_MEANS = True  # add vertical lines at class means
SHOW_OVERALL_MEAN = True                            # add a vertical line at the pooled mean

def seq_len_fast(r):
    """Return the length of ``r.seq`` without counting gap characters ('-' and '.')."""
    residues = str(r.seq)
    return len(residues) - residues.count("-") - residues.count(".")
def ecdf(vals):
    """Return ``(x, y)`` for an empirical CDF: sorted values and cumulative percent (0-100].

    Empty input gives a pair of empty arrays.
    """
    n = len(vals)
    if n == 0:
        return np.array([]), np.array([])
    ordered = np.sort(np.asarray(vals))
    percent = (np.arange(1, n + 1) / n) * 100.0
    return ordered, percent

# gather: gap-free sequence lengths per class file (classes with no sequences are left out)
lengths_by_class = {}
all_lengths = []
for f in sorted(Path(IN_DIR).glob("*.txt")):
    cls = f.stem
    lens = [seq_len_fast(r) for r in SeqIO.parse(str(f), "fasta")]
    if lens:
        lengths_by_class[cls] = np.array(lens, dtype=int)
        all_lengths.extend(lens)

# optional limit: keep only the MAX_CLASSES_PLOTTED classes with the most sequences.
# all_lengths is not recomputed, so the overall curve still pools every class.
if MAX_CLASSES_PLOTTED is not None and len(lengths_by_class) > MAX_CLASSES_PLOTTED:
    keep = dict(sorted(lengths_by_class.items(), key=lambda kv: len(kv[1]), reverse=True)[:MAX_CLASSES_PLOTTED])
    lengths_by_class = keep

fig, ax = plt.subplots(figsize=(10, 6))

# plot classes (handles/labels are collected by hand so the legend lists only the ECDF
# curves, not the mean marker lines)
handles = []
labels  = []
for cls, lens in sorted(lengths_by_class.items()):
    x, p = ecdf(lens)
    if len(x)==0: continue
    h = ax.step(x, p, where="post", linewidth=1.3, label=f"{cls} (n={len(lens)})")[0]
    handles.append(h); labels.append(h.get_label())
    if SHOW_MEANS:
        # mean marker reuses the colour of the class curve
        ax.axvline(np.mean(lens), color=h.get_color(), linestyle=":", linewidth=1.0, alpha=0.8)

# plot overall explicitly with fixed styling & label
if INCLUDE_OVERALL and len(all_lengths)>0:
    x_all, p_all = ecdf(all_lengths)
    hov = ax.step(x_all, p_all, where="post",
                  color="black", linestyle="--", linewidth=2.0,
                  label="__ALL__ (overall)")[0]
    handles.append(hov); labels.append(hov.get_label())
    if SHOW_OVERALL_MEAN:
        ax.axvline(np.mean(all_lengths), color="black", linestyle="--", linewidth=1.2, alpha=0.9)

ax.set_xlabel("Sequence length (aa)")
ax.set_ylabel("Percent ≤ length")
ax.set_title("Cumulative Distribution of Sequence Lengths by Class")
if USE_LOG_X: ax.set_xscale("log")
ax.grid(True, alpha=0.3, linestyle=":")
ax.legend(handles, labels, loc="best", fontsize=9)  # ensure overall is included
fig.tight_layout()
fig.savefig(SAVE_PNG, dpi=150)
plt.show()

print("Saved plot to:", SAVE_PNG)


## Per Class ECDF

In [None]:
# === ECDF: one figure per class (with optional overall overlay) ===


IN_DIR = RE_OG_UPD                          # folder holding the per-class FASTAs (*.txt)
OUT_DIR_EACH = OUT_DIR / "ecdf_per_class"   # one PNG per class is written here
OUT_DIR_EACH.mkdir(parents=True, exist_ok=True)

INCLUDE_OVERALL = True          # overlay pooled ECDF
SHOW_MEAN = True                # vertical mean line (class)
SHOW_MEDIAN = False             # vertical median line (class)
SHOW_TRIMMED_MEAN = False       # vertical trimmed mean (25–75%)
SHOW_OVERALL_MEAN = True        # vertical overall mean
USE_LOG_X = False               # log-scale x axis if lengths span wide range

def seq_len_fast(r):
    """Count the residues of ``r.seq``, skipping the gap characters '-' and '.'."""
    text = str(r.seq)
    gap_count = text.count("-") + text.count(".")
    return len(text) - gap_count

def ecdf(vals):
    """Empirical CDF of ``vals`` as ``(sorted values, cumulative percent)``.

    Returns two empty arrays when ``vals`` is empty.
    """
    count = len(vals)
    if count == 0:
        return np.array([]), np.array([])
    xs = np.sort(np.asarray(vals))
    ys = (np.arange(1, count + 1) / count) * 100.0
    return xs, ys

def trimmed_mean(arr, lo=0.25, hi=0.75):
    """Mean of the sorted values between rank fractions ``lo`` and ``hi``.

    The slice runs from ``floor(n*lo)`` to ``ceil(n*hi)``. Returns NaN when the
    input or the slice is empty.
    """
    n = len(arr)
    if n == 0:
        return np.nan
    ordered = np.sort(np.asarray(arr))
    start = int(np.floor(n * lo))
    stop = int(np.ceil(n * hi))
    core = ordered[start:stop]
    if len(core) == 0:
        return np.nan
    return float(np.mean(core))

# Gather lengths
lengths_by_class = {}
pooled_lengths = []
for fasta_path in sorted(Path(IN_DIR).glob("*.txt")):
    class_lengths = [seq_len_fast(rec) for rec in SeqIO.parse(str(fasta_path), "fasta")]
    if not class_lengths:
        continue  # a class without sequences gets no figure
    lengths_by_class[fasta_path.stem] = np.array(class_lengths, dtype=int)
    pooled_lengths.extend(class_lengths)
all_lengths = np.array(pooled_lengths, dtype=int)

# Precompute overall ECDF
have_pooled = len(all_lengths) > 0
x_all, p_all = ecdf(all_lengths) if have_pooled else (np.array([]), np.array([]))
overall_mean = float(np.mean(all_lengths)) if have_pooled else np.nan

# Optional per-class marker lines: (enabled, statistic, linestyle, linewidth, legend label)
class_markers = (
    (SHOW_MEAN, np.mean, ":", 1.2, "mean"),
    (SHOW_MEDIAN, np.median, "-.", 1.0, "median"),
    (SHOW_TRIMMED_MEAN, trimmed_mean, "--", 1.0, "trimmed mean (25–75%)"),
)

for cls, lens in sorted(lengths_by_class.items()):
    x, p = ecdf(lens)
    if len(x) == 0:
        continue

    fig, ax = plt.subplots(figsize=(9, 5))
    # Class ECDF
    curve = ax.step(x, p, where="post", linewidth=2.0, label=f"{cls} (n={len(lens)})")[0]
    color = curve.get_color()

    # Class markers share the colour of the class curve
    for enabled, statistic, dash, width, tag in class_markers:
        if enabled:
            ax.axvline(statistic(lens), color=color, linestyle=dash, linewidth=width, alpha=0.9, label=tag)

    # Overall overlay
    if INCLUDE_OVERALL and len(x_all):
        ax.step(x_all, p_all, where="post", color="black", linestyle="--", linewidth=1.5, label="__ALL__ (overall)")
        if SHOW_OVERALL_MEAN and not np.isnan(overall_mean):
            ax.axvline(overall_mean, color="black", linestyle="--", linewidth=1.0, alpha=0.9, label="overall mean")

    ax.set_xlabel("Sequence length (aa)")
    ax.set_ylabel("Percent ≤ length")
    ax.set_title(f"ECDF of Sequence Lengths — {cls}")
    if USE_LOG_X:
        ax.set_xscale("log")
    ax.grid(True, alpha=0.3, linestyle=":")
    ax.legend(loc="best", fontsize=9)

    out_png = OUT_DIR_EACH / f"{cls}.length_ecdf.png"
    fig.tight_layout()
    fig.savefig(out_png, dpi=150)
    plt.show()
    print("Saved:", out_png)


In [None]:
# Correct class imports (add these even if you imported earlier):
# `from Bio import SeqRecord, Seq` binds the submodules, but the code below calls the classes.
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord

# Find clustalo or muscle: CLUSTALO_BIN (if set) or "clustalo" on PATH, else "muscle"
MSA_BIN = shutil.which(os.environ.get("CLUSTALO_BIN") or "clustalo") or shutil.which("muscle")
assert MSA_BIN, "Install clustalo or muscle (and/or set CLUSTALO_BIN)."

def write_class_fasta_for_alignment(in_path: Path, out_path: Path):
    """Copy the FASTA at ``in_path`` to ``out_path`` with bare-ID headers.

    Every record is rebuilt from its sequence text, with the first token of its
    id (or description) as the id and an empty description. Nothing is written
    when ``in_path`` holds no records (see ``write_fasta``).
    """
    cleaned = [
        SeqRecord(Seq(str(rec.seq)), id=first_token(rec.id or rec.description), description="")
        for rec in SeqIO.parse(str(in_path), "fasta")
    ]
    write_fasta(cleaned, out_path)

def _fasta_record_count(path: Path) -> int:
    """Count the FASTA header lines in ``path``; a missing file counts as zero records."""
    if not path.exists():
        return 0
    with open(path) as handle:
        return sum(1 for line in handle if line.startswith(">"))


def _align_or_copy(in_fa: Path, out_fa: Path) -> bool:
    """Align ``in_fa`` into ``out_fa`` with MSA_BIN; return True when ``out_fa`` was produced.

    Inputs with no sequences are skipped with a warning. A single sequence is
    copied through unchanged, since there is nothing to align.
    """
    n_seqs = _fasta_record_count(in_fa)
    if n_seqs == 0:
        print(f"[warn] No sequences in {in_fa}; skipping alignment")
        return False
    if n_seqs == 1:
        shutil.copyfile(in_fa, out_fa)
        return True
    if "clustalo" in MSA_BIN:
        cmd = [MSA_BIN, "-i", str(in_fa), "-o", str(out_fa), "--force", "--threads=4", "--output-order=input-order"]
    else:  # muscle
        cmd = [MSA_BIN, "-align", str(in_fa), "-output", str(out_fa)]
    subprocess.run(cmd, check=True)
    return True


# Per-class alignments
aln_paths = {}
for f in sorted(RE_OG_UPD.glob("*.txt")):
    cls = f.stem
    in_fa  = ALIGN_DIR / f"{cls}.in.fasta"
    out_fa = ALIGN_DIR / f"{cls}.aln.fasta"
    # write_fasta() writes nothing for an empty class, so remove any input left
    # over from an earlier run instead of silently re-aligning it.
    in_fa.unlink(missing_ok=True)
    write_class_fasta_for_alignment(f, in_fa)
    if _align_or_copy(in_fa, out_fa):
        aln_paths[cls] = out_fa

# ALL combined
all_in = ALIGN_DIR / "__ALL__.in.fasta"
with open(all_in, "w") as out:
    for f in sorted(RE_OG_UPD.glob("*.txt")):
        for r in SeqIO.parse(str(f), "fasta"):
            sid = first_token(r.id or r.description)
            SeqIO.write(SeqRecord(Seq(str(r.seq)), id=sid, description=""), out, "fasta")

all_out = ALIGN_DIR / "__ALL__.aln.fasta"
if _align_or_copy(all_in, all_out):
    aln_paths["__ALL__"] = all_out

aln_paths


# Consensus-similarity heatmaps 

In [None]:
# === Consensus-similarity heatmaps (sequence vs position) ===

# Reuse AA set if you want; not strictly needed here
def majority_consensus_from_alignment(aln_fa: Path, threshold: float = 0.5) -> str:
    """Return the majority-rule consensus string of the alignment FASTA ``aln_fa``.

    For each column, gaps ('-', '.') are ignored. The most frequent residue is
    emitted when its share of the non-gap residues is at least ``threshold``,
    otherwise "X". All-gap columns give "-". Rows shorter than the longest row
    are treated as gap-padded. Raises ``AssertionError`` for an empty alignment.
    """
    rows = [str(rec.seq) for rec in SeqIO.parse(str(aln_fa), "fasta")]
    assert rows, f"Empty alignment: {aln_fa}"
    width = max(map(len, rows))

    def _call_column(pos):
        residues = [row[pos] for row in rows if pos < len(row) and row[pos] not in "-."]
        if not residues:
            return "-"
        # np.unique returns sorted symbols, so ties go to the smallest symbol
        symbols, tallies = np.unique(residues, return_counts=True)
        top = int(np.argmax(tallies))
        share = tallies[top] / float(len(residues))
        return symbols[top] if share >= threshold else "X"

    return "".join(_call_column(pos) for pos in range(width))

def consensus_similarity_matrix(aln_fa: Path, consensus: str, ignore_when_cons_is_gap=True):
    """Score every aligned residue against ``consensus``.

    Returns ``(records, M)`` where ``M`` has shape (n_sequences, len(consensus)):
    1.0 for a residue equal to a consensus letter other than "X"/"-", 0.0 for any
    other residue, and NaN where the sequence has a gap (or is shorter than the
    consensus) or, with ``ignore_when_cons_is_gap``, where the consensus is a gap.
    """
    recs = list(SeqIO.parse(str(aln_fa), "fasta"))
    n_cols = len(consensus)
    M = np.full((len(recs), n_cols), np.nan, dtype=float)
    for row, rec in enumerate(recs):
        aligned = str(rec.seq)
        for col, ref in enumerate(consensus):
            if ignore_when_cons_is_gap and ref in "-.":
                continue  # leave NaN
            residue = aligned[col] if col < len(aligned) else "-"
            if residue in "-.":
                continue  # leave NaN (gap in sequence)
            is_match = residue == ref and ref not in {"X", "-"}
            M[row, col] = 1.0 if is_match else 0.0
    return recs, M
def plot_similarity_heatmap(M: np.ndarray, title: str, save_path: Path = None, sort_by_overall=True):
    """Draw a sequences × positions heatmap of consensus matches and show it.

    ``M`` holds 1 (match), 0 (mismatch) and NaN (ignored; drawn blank). With
    ``sort_by_overall`` the rows are ordered by decreasing mean similarity, with
    all-NaN rows last. The figure is saved to ``save_path`` when one is given.
    """
    import matplotlib as mpl

    grid = M.copy()

    # Optionally sort sequences (rows) by overall similarity to consensus
    if sort_by_overall:
        row_scores = np.nan_to_num(np.nanmean(grid, axis=1), nan=-1.0)
        grid = grid[np.argsort(-row_scores), :]

    # Mask NaNs so they render as blank
    masked_grid = np.ma.masked_invalid(grid)

    # Two discrete colour levels: mismatch=0, match=1
    two_level_cmap = mpl.colormaps.get_cmap("viridis").resampled(2)

    fig, ax = plt.subplots(figsize=(12, max(3, grid.shape[0] * 0.03)))
    im = ax.imshow(masked_grid, aspect="auto", interpolation="nearest", cmap=two_level_cmap, vmin=0, vmax=1)
    ax.set_xlabel("Alignment position")
    ax.set_ylabel("Sequences")
    ax.set_title(title)
    cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_ticks([0, 1])
    cbar.set_ticklabels(["mismatch", "match"])

    if save_path:
        fig.tight_layout()
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.show()

# Generate heatmaps for each class and ALL
for cls, aln in aln_paths.items():
    cons = majority_consensus_from_alignment(aln, threshold=0.5)
    recs, M = consensus_similarity_matrix(aln, cons, ignore_when_cons_is_gap=True)
    out_png = HM_DIR / f"{cls}.consensus_similarity.heatmap.png"
    plot_similarity_heatmap(M, f"Consensus Similarity — {cls}", save_path=out_png, sort_by_overall=True)
    # Also save consensus used, wrapped at 80 characters per line
    wrapped = "".join(cons[start:start+80] + "\n" for start in range(0, len(cons), 80))
    with open(CONS_DIR / f"{cls}.consensus.majority.fasta", "w") as fasta_out:
        fasta_out.write(f">{cls}\n" + wrapped)


# Expanded and Filtered Profiles

## Imports and Summary

In [None]:
# === Per-class length filtering into new FASTAs (range per class) ===
# ----- SOURCE & OUTPUT -----
SRC_DIR  = RE_OG_UPD                     # where your current per-class FASTAs (*.txt) live
OUT_BASE = REBASE / "filtered_sets_per_class"   # base folder for outputs
OUT_BASE.mkdir(parents=True, exist_ok=True)

# ----- CONFIG: specify ranges PER CLASS -----
# Keys must match the class file stem (e.g., "8ca.alpha.45" for "8ca.alpha.45.txt").
# Each class maps to:
#   - one dict: {"name": "...", "min_len": L1, "max_len": L2}
#   - OR a list of such dicts to create multiple outputs for that class.
# Use None for open-ended bounds; both bounds are inclusive.
# "name" becomes the output sub-folder under OUT_BASE.
CLASS_RANGE_SPECS = {
    "8ca.alpha.45": {"name": "alpha_no_change", "min_len": None, "max_len": None},
    "8ca.beta.45": {"name": "beta_no_change", "min_len": None, "max_len": None},
    "8ca.delta.45": {"name": "delta_250_600", "min_len": 250, "max_len": 600},
    "8ca.eta.45": {"name": "eta_300_no_max", "min_len": 300, "max_len": None},
    "8ca.gamma.45": {"name": "gamma_no_change", "min_len": None, "max_len": None},
    "8ca.iota.45": {"name": "iota_no_change", "min_len": None, "max_len": None},
    "8ca.theta.45": {"name": "theta_250_550", "min_len": 250, "max_len": 550},
    "8ca.zeta.45": {"name": "zeta_150_250", "min_len": 150, "max_len": 250}}

# Only write files with at least this many kept sequences
MIN_SEQS_TO_WRITE = 1

# Whether to ignore '-' and '.' when computing length
IGNORE_GAPS = True

# ----- HELPERS -----
def seq_len(rec: SeqRecord, ignore_gaps: bool = True) -> int:
    """Return the length of ``rec.seq``; gap characters ('-', '.') are excluded when ``ignore_gaps``."""
    text = str(rec.seq)
    if not ignore_gaps:
        return len(text)
    return len(text) - text.count("-") - text.count(".")

def in_range(n: int, lo, hi) -> bool:
    """Return True when ``n`` lies in the inclusive range ``[lo, hi]``; a ``None`` bound is open."""
    return (lo is None or not n < lo) and (hi is None or not n > hi)

def ensure_list(spec):
    """Normalise ``spec`` to a sequence: ``None`` -> ``[]``, list/tuple -> unchanged, other -> ``[spec]``."""
    if spec is None:
        return []
    if isinstance(spec, (list, tuple)):
        return spec
    return [spec]

# ----- MAIN -----
class_files = sorted(Path(SRC_DIR).glob("*.txt"))
assert class_files, f"No class FASTAs found in {SRC_DIR}"

# map stem -> path for quick lookup
by_stem = {class_file.stem: class_file for class_file in class_files}

summary_rows = []

# sanity: warn for spec keys not found
missing = [key for key in CLASS_RANGE_SPECS if key not in by_stem]
if missing:
    print("[warn] Specified classes not found in SRC_DIR:", ", ".join(missing))

for cls, spec in CLASS_RANGE_SPECS.items():
    src = by_stem.get(cls)
    if src is None:
        continue
    # load once, reuse for every range of this class
    records = list(SeqIO.parse(str(src), "fasta"))
    total = len(records)

    for rng in ensure_list(spec):
        lo = rng.get("min_len", None)
        hi = rng.get("max_len", None)
        name = rng.get("name") or f"{cls}_len_{rng.get('min_len')}_{rng.get('max_len')}"

        out_dir = OUT_BASE / name
        out_dir.mkdir(parents=True, exist_ok=True)

        # keep the records whose length is in range, rebuilt with bare-ID headers
        kept = [
            SeqRecord(Seq(str(rec.seq)), id=(rec.id or rec.description or "").split()[0], description="")
            for rec in records
            if in_range(seq_len(rec, IGNORE_GAPS), lo, hi)
        ]

        out_path = ""
        if len(kept) >= MIN_SEQS_TO_WRITE:
            out_fa = out_dir / f"{cls}.txt"
            SeqIO.write(kept, str(out_fa), "fasta")
            out_path = str(out_fa)

        summary_rows.append({
            "class": cls,
            "range_name": name,
            "min_len": lo,
            "max_len": hi,
            "n_total_in_src": total,
            "n_kept": len(kept),
            "out_path": out_path,
        })

# Save summary
summary_df = pd.DataFrame(summary_rows).sort_values(["class", "range_name"])
summary_csv = OUT_BASE / "per_class_filtered_summary.csv"
summary_df.to_csv(summary_csv, index=False)

print(f"[OK] Wrote filtered FASTAs under: {OUT_BASE}")
print(f"[OK] Summary → {summary_csv}")
summary_df


In [None]:
# === Length stats from multiple per-class folders (you specify the dirs) ===
from pathlib import Path
from Bio import SeqIO
from Bio.SeqRecord import SeqRecord
import numpy as np
import pandas as pd

# ---------------- CONFIG: list the folders to analyse here ----------------
# Each entry is a folder of per-class FASTA files; the file stem is used as the class name.
# NOTE(review): the folder names repeat the "name" values of CLASS_RANGE_SPECS used for
# filtering; confirm they are kept in sync when those ranges change.
DIRS = [
    REBASE / "filtered_sets_per_class" / "alpha_no_change",
    REBASE / "filtered_sets_per_class" / "beta_no_change",     
    REBASE / "filtered_sets_per_class" / "gamma_no_change",     
    REBASE / "filtered_sets_per_class" / "delta_250_600",
    REBASE / "filtered_sets_per_class" / "zeta_150_250",       
    REBASE / "filtered_sets_per_class" / "eta_300_no_max",       
    REBASE / "filtered_sets_per_class" / "theta_250_550",      
    REBASE / "filtered_sets_per_class" / "iota_no_change",      
]

# Destination of the statistics table
OUT_CSV = OUT_DIR / "length_stats_selected_dirs.csv"

# If the same class appears in multiple dirs:
#   - True  => merge sequences across files, dedup by ID
#   - False => raise an error if a class is seen more than once
MERGE_DUPLICATE_CLASSES = True

# ---------------- Helpers ----------------
def first_token(s: str) -> str:
    """Return the first whitespace-delimited token of ``s``.

    ``None``, empty and whitespace-only input yield ``""``. The previous
    ``(s or "").split()[0]`` raised ``IndexError`` for all three.
    """
    tokens = (s or "").split()
    return tokens[0] if tokens else ""

def seq_len(rec: SeqRecord) -> int:
    """Return the number of characters in ``rec.seq`` that are not gaps ('-' or '.')."""
    text = str(rec.seq)
    n_gaps = text.count("-") + text.count(".")
    return len(text) - n_gaps

def trimmed_mean(arr, lo=0.25, hi=0.75):
    """Return the mean of the sorted values between the ``lo`` and ``hi`` rank fractions.

    The slice runs from ``floor(n*lo)`` to ``ceil(n*hi)`` of the sorted values.
    Returns NaN for ``None``, empty input, or an empty slice. Accepts any
    iterable, including numpy arrays (the former ``if not arr`` check raised
    ``ValueError`` for arrays with more than one element).
    """
    a = np.sort(np.asarray(list(arr) if arr is not None else []))
    if a.size == 0:
        return float("nan")
    i0, i1 = int(np.floor(len(a) * lo)), int(np.ceil(len(a) * hi))
    mid = a[i0:i1]
    return float(np.mean(mid)) if len(mid) else float("nan")

def trimmed_median(arr, lo=0.25, hi=0.75):
    """Return the median of the sorted values between the ``lo`` and ``hi`` rank fractions.

    The slice runs from ``floor(n*lo)`` to ``ceil(n*hi)`` of the sorted values.
    Returns NaN for ``None``, empty input, or an empty slice. Accepts any
    iterable, including numpy arrays (the former ``if not arr`` check raised
    ``ValueError`` for arrays with more than one element).
    """
    a = np.sort(np.asarray(list(arr) if arr is not None else []))
    if a.size == 0:
        return float("nan")
    i0, i1 = int(np.floor(len(a) * lo)), int(np.ceil(len(a) * hi))
    mid = a[i0:i1]
    return float(np.median(mid)) if len(mid) else float("nan")

# ---------------- Collect class files from the specified dirs ----------------
FASTA_GLOBS = ("*.txt", "*.fa", "*.fasta", "*.faa", "*.fas")
files = []
for d in DIRS:
    d = Path(d)
    assert d.exists(), f"Directory does not exist: {d}"
    for ext in FASTA_GLOBS:
        files.extend(sorted(d.glob(ext)))

assert files, f"No FASTA files found in: {', '.join(str(Path(d)) for d in DIRS)}"

# ---------------- Load per-class sequences (merge or enforce uniqueness) ----------------
class_to_ids = {}
class_to_lengths = {}

for fp in files:
    cls = fp.stem  # e.g., '8ca.alpha.45'
    # (normalised ID, gap-free length) for every record in this file
    id_length_pairs = [
        (first_token(rec.id or rec.description), seq_len(rec))
        for rec in SeqIO.parse(str(fp), "fasta")
    ]

    seen = class_to_ids.setdefault(cls, set())
    lengths = class_to_lengths.setdefault(cls, [])

    if not MERGE_DUPLICATE_CLASSES and lengths:
        raise RuntimeError(f"Class {cls} appears in multiple input folders; set MERGE_DUPLICATE_CLASSES=True to merge.")

    # Merge (dedup by ID)
    for rid, L in id_length_pairs:
        if rid not in seen:
            seen.add(rid)
            lengths.append(L)

# ---------------- Compute stats ----------------
def _filled_stats_row(label, values, n_seqs):
    """Build one statistics row for the non-empty list of lengths ``values``."""
    arr = np.array(values, dtype=float)
    return {
        "class": label,
        "n_seqs": int(n_seqs),
        "mean_len": float(np.mean(arr)),
        "median_len": float(np.median(arr)),
        # sample standard deviation needs at least two values
        "std_len": float(np.std(arr, ddof=1)) if len(arr) > 1 else float("nan"),
        "trimmed_mean_len_25_75": trimmed_mean(values),
        "trimmed_median_len_25_75": trimmed_median(values),
        "min_len": int(np.min(arr)),
        "q25_len": int(np.quantile(arr, 0.25)),
        "q75_len": int(np.quantile(arr, 0.75)),
        "max_len": int(np.max(arr)),
    }


rows = []
all_lengths = []

for cls, lens in sorted(class_to_lengths.items()):
    if not lens:
        # class present but without sequences: placeholder row
        rows.append({
            "class": cls, "n_seqs": 0,
            "mean_len": float("nan"), "median_len": float("nan"),
            "std_len": float("nan"),
            "trimmed_mean_len_25_75": float("nan"),
            "trimmed_median_len_25_75": float("nan"),
            "min_len": None, "q25_len": None, "q75_len": None, "max_len": None,
        })
        continue
    all_lengths.extend(lens)
    rows.append(_filled_stats_row(cls, lens, len(lens)))

# Overall row (its count is the sum of the per-class counts gathered so far)
if all_lengths:
    rows.append(_filled_stats_row("__ALL__", all_lengths, sum(r["n_seqs"] for r in rows)))

df_stats = pd.DataFrame(rows).sort_values("class")
df_stats.to_csv(OUT_CSV, index=False)
print("Dirs used:")
for d in DIRS:
    print(" -", d)
print("Saved:", OUT_CSV)
df_stats


## Plots

In [None]:
# === ECDF (length distributions) from multiple filtered-set directories ===

SAVE_PNG = OUT_DIR / "length_ecdf_combined_filtered_dirs.png"   # where the figure is saved
INCLUDE_OVERALL = True          # also draw the pooled ("__ALL__") ECDF
MAX_CLASSES_PLOTTED = None      # None = plot every class; otherwise keep the N largest classes
USE_LOG_X = False               # log-scale the length axis
SHOW_MEANS = True               # vertical line at each class mean
SHOW_OVERALL_MEAN = True        # vertical line at the pooled mean
MERGE_DUPLICATES = True  # dedup if same class appears in multiple dirs

# ---------------- Helpers ----------------
def first_token(s: str) -> str:
    """Return the first whitespace-delimited token of ``s``, or ``""`` when there is none.

    ``None``, empty and whitespace-only input no longer raise ``IndexError``.
    """
    parts = (s or "").split()
    if not parts:
        return ""
    return parts[0]

def seq_len_fast(r):
    """Length of ``r.seq`` with the gap characters '-' and '.' left out."""
    seq_text = str(r.seq)
    return len(seq_text) - seq_text.count("-") - seq_text.count(".")

def ecdf(vals):
    """Return sorted ``vals`` and the matching cumulative percentages (empty arrays for empty input)."""
    size = len(vals)
    if size == 0:
        return np.array([]), np.array([])
    sorted_vals = np.sort(np.asarray(vals))
    cumulative_pct = (np.arange(1, size + 1) / size) * 100.0
    return sorted_vals, cumulative_pct

# ---------------- Collect lengths per class ----------------
lengths_by_class = {}
ids_by_class = {}
all_lengths = []

for d in DIRS:
    d = Path(d)
    assert d.exists(), f"Missing directory: {d}"
    for ext in ("*.txt", "*.fa", "*.fasta", "*.faa", "*.fas"):
        for f in sorted(d.glob(ext)):
            cls = f.stem
            for r in SeqIO.parse(str(f), "fasta"):
                rid = first_token(r.id or r.description)
                L = seq_len_fast(r)
                if not MERGE_DUPLICATES:
                    lengths_by_class.setdefault(cls, []).append(L)
                    continue
                # dedup by ID within a class across all directories
                if cls not in ids_by_class:
                    ids_by_class[cls] = set()
                    lengths_by_class[cls] = []
                if rid not in ids_by_class[cls]:
                    ids_by_class[cls].add(rid)
                    lengths_by_class[cls].append(L)

# Merge all lengths into one list (done before any class limit is applied)
for class_lengths in lengths_by_class.values():
    all_lengths.extend(class_lengths)

# Optionally limit plotted classes to those with the most sequences
if MAX_CLASSES_PLOTTED is not None and len(lengths_by_class) > MAX_CLASSES_PLOTTED:
    by_size = sorted(lengths_by_class.items(), key=lambda kv: len(kv[1]), reverse=True)
    keep = dict(by_size[:MAX_CLASSES_PLOTTED])
    lengths_by_class = keep

# ---------------- Plot ----------------
fig, ax = plt.subplots(figsize=(10, 6))
handles, labels = [], []

for cls, lens in sorted(lengths_by_class.items()):
    if not lens:
        continue
    x, p = ecdf(lens)
    curve = ax.step(x, p, where="post", linewidth=1.3, label=f"{cls} (n={len(lens)})")[0]
    handles.append(curve)
    labels.append(curve.get_label())
    if SHOW_MEANS:
        # mean marker reuses the colour of the class curve
        ax.axvline(np.mean(lens), color=curve.get_color(), linestyle=":", linewidth=1.0, alpha=0.8)

# Overall ECDF
if INCLUDE_OVERALL and len(all_lengths) > 0:
    x_all, p_all = ecdf(all_lengths)
    hov = ax.step(x_all, p_all, where="post",
                  color="black", linestyle="--", linewidth=2.0,
                  label="__ALL__ (overall)")[0]
    handles.append(hov)
    labels.append(hov.get_label())
    if SHOW_OVERALL_MEAN:
        ax.axvline(np.mean(all_lengths), color="black", linestyle="--", linewidth=1.2, alpha=0.9)

ax.set_xlabel("Sequence length (aa)")
ax.set_ylabel("Percent ≤ length")
ax.set_title("Cumulative Distribution of Sequence Lengths — Combined Filtered Sets")
if USE_LOG_X:
    ax.set_xscale("log")
ax.grid(True, alpha=0.3, linestyle=":")
ax.legend(handles, labels, loc="best", fontsize=9)
fig.tight_layout()
fig.savefig(SAVE_PNG, dpi=150)
plt.show()

print(f"Saved plot → {SAVE_PNG}")


In [None]:
# === Build alignments from multiple filtered-set folders, then make consensus-similarity heatmaps ===

# ---------- OUTPUT locations (re-use if you already set these elsewhere) ----------
# Each name keeps its current value when it already exists in the notebook globals;
# otherwise it falls back to a folder under REBASE / "idr_length_consensus_outputs".
# NOTE: when re-used, outputs of this cell overwrite same-named files written by the
# earlier alignment/heatmap cells (e.g. "<class>.aln.fasta", "__ALL__.aln.fasta").
OUT_DIR   = OUT_DIR if 'OUT_DIR' in globals() else (REBASE / "idr_length_consensus_outputs")
ALIGN_DIR = ALIGN_DIR if 'ALIGN_DIR' in globals() else (OUT_DIR / "aligned")
CONS_DIR  = CONS_DIR  if 'CONS_DIR'  in globals() else (OUT_DIR / "consensus")
HM_DIR    = HM_DIR    if 'HM_DIR'    in globals() else (OUT_DIR / "heatmaps")
for d in [OUT_DIR, ALIGN_DIR, CONS_DIR, HM_DIR]:
    Path(d).mkdir(parents=True, exist_ok=True)

# ---------- MSA tool ----------
# CLUSTALO_BIN (if set) or "clustalo" on PATH is preferred; "muscle" is the fallback.
MSA_BIN = shutil.which(os.environ.get("CLUSTALO_BIN") or "clustalo") or shutil.which("muscle")
assert MSA_BIN, "Install clustalo or muscle (and/or set CLUSTALO_BIN)."

# ---------- Helpers ----------
def first_token(s: str) -> str:
    """Return the first whitespace-delimited token of ``s``; ``""`` if ``s`` has no token.

    Handles ``None``, empty and whitespace-only input, which previously raised
    ``IndexError`` from ``split()[0]``.
    """
    for token in (s or "").split():
        return token
    return ""

def write_fasta(records, path: Path):
    """Write ``records`` as FASTA to ``path``; with no records the file is not created or modified."""
    if not records:
        return
    SeqIO.write(records, str(path), "fasta")

def run_msa(in_fa: Path, out_fa: Path, threads: int = 4):
    """Align the FASTA ``in_fa`` into ``out_fa`` with the tool found in ``MSA_BIN``.

    A single-sequence input is copied to ``out_fa`` unchanged: there is nothing
    to align, and the aligner is expected to reject such input. Any other input
    is passed to the tool (``threads`` is only used for clustalo). A non-zero
    exit status raises ``subprocess.CalledProcessError``.
    """
    try:
        with open(in_fa) as handle:
            n_seqs = sum(1 for line in handle if line.startswith(">"))
    except OSError:
        n_seqs = None  # unreadable input: let the aligner report the problem as before
    if n_seqs == 1:
        shutil.copyfile(in_fa, out_fa)
        return

    if "clustalo" in MSA_BIN:
        cmd = [MSA_BIN, "-i", str(in_fa), "-o", str(out_fa), "--force", f"--threads={threads}", "--output-order=input-order"]
    else:  # muscle
        cmd = [MSA_BIN, "-align", str(in_fa), "-output", str(out_fa)]
    subprocess.run(cmd, check=True)

# ---------- 1) Gather sequences per class from all DIRS (merge + dedup by ID) ----------
class_to_records = {}   # stem -> dict(id -> SeqRecord)
for d in DIRS:
    d = Path(d)
    assert d.exists(), f"Missing directory: {d}"
    for ext in ("*.txt","*.fa","*.fasta","*.faa","*.fas"):
        for f in d.glob(ext):
            # file stem is the class name, e.g. "8ca.alpha.45"
            bucket = class_to_records.setdefault(f.stem, {})
            for r in SeqIO.parse(str(f), "fasta"):
                rid = first_token(r.id or r.description)
                if rid not in bucket:  # first occurrence of an ID wins
                    bucket[rid] = SeqRecord(Seq(str(r.seq)), id=rid, description="")

# ---------- 2) Build per-class MSAs ----------
aln_paths = {}
for cls, id2rec in sorted(class_to_records.items()):
    if not id2rec:
        continue
    in_fa  = Path(ALIGN_DIR) / f"{cls}.in.fasta"
    out_fa = Path(ALIGN_DIR) / f"{cls}.aln.fasta"
    write_fasta(list(id2rec.values()), in_fa)
    run_msa(in_fa, out_fa)
    aln_paths[cls] = out_fa

# ---------- 3) Build ALL combined MSA ----------
all_in  = Path(ALIGN_DIR) / "__ALL__.in.fasta"
all_out = Path(ALIGN_DIR) / "__ALL__.aln.fasta"
all_recs = []
for id2rec in class_to_records.values():
    all_recs.extend(id2rec.values())
write_fasta(all_recs, all_in)
run_msa(all_in, all_out)
aln_paths["__ALL__"] = all_out

# ---------- 4) Consensus + similarity matrix + heatmap (same as your existing logic, bundled here) ----------
def majority_consensus_from_alignment(aln_fa: Path, threshold: float = 0.5) -> str:
    """Build a majority-rule consensus string from the alignment FASTA ``aln_fa``.

    Per column: gaps ('-', '.') are ignored, the most common residue is used if
    it makes up at least ``threshold`` of the non-gap residues, "X" otherwise,
    and "-" when the column holds only gaps. Short rows count as gap-padded.
    Raises ``AssertionError`` when the alignment has no records.
    """
    aligned = [str(rec.seq) for rec in SeqIO.parse(str(aln_fa), "fasta")]
    assert aligned, f"Empty alignment: {aln_fa}"
    n_cols = max(len(row) for row in aligned)

    consensus_chars = []
    for col in range(n_cols):
        letters = [row[col] for row in aligned if col < len(row) and row[col] not in "-."]
        if not letters:
            consensus_chars.append("-")
            continue
        # sorted unique symbols: ties resolve to the smallest symbol
        symbols, tallies = np.unique(letters, return_counts=True)
        best = int(np.argmax(tallies))
        if tallies[best] / float(len(letters)) >= threshold:
            consensus_chars.append(symbols[best])
        else:
            consensus_chars.append("X")
    return "".join(consensus_chars)

def consensus_similarity_matrix(aln_fa: Path, consensus: str, ignore_when_cons_is_gap=True):
    """Return ``(records, M)`` with ``M`` of shape (n_seq, len(consensus)).

    A cell is 1.0 when the residue equals a consensus letter other than
    "X"/"-", 0.0 for any other residue, and NaN where the sequence has a gap
    (or ends early) or, with ``ignore_when_cons_is_gap``, the consensus is a gap.
    """
    recs = list(SeqIO.parse(str(aln_fa), "fasta"))
    width = len(consensus)
    M = np.full((len(recs), width), np.nan, dtype=float)
    for i, rec in enumerate(recs):
        text = str(rec.seq)
        for j, c_cons in enumerate(consensus):
            if ignore_when_cons_is_gap and c_cons in "-.":
                continue
            c_seq = text[j] if j < len(text) else "-"
            if c_seq in "-.":
                continue
            matched = c_seq == c_cons and c_cons not in {"X","-"}
            M[i, j] = 1.0 if matched else 0.0
    return recs, M

def plot_similarity_heatmap(M: np.ndarray, title: str, save_path: Path = None, sort_by_overall=True):
    """Show (and optionally save) a sequences × positions heatmap of consensus matches.

    ``M`` holds 1 (match), 0 (mismatch) and NaN (ignored, drawn blank). With
    ``sort_by_overall`` rows are ordered by decreasing mean similarity, all-NaN
    rows last. The figure is written to ``save_path`` when one is given.
    """
    import matplotlib as mpl

    heat = M.copy()
    if sort_by_overall:
        row_means = np.nanmean(heat, axis=1)
        ranking = np.argsort(-np.nan_to_num(row_means, nan=-1.0))
        heat = heat[ranking, :]

    # NaNs are masked so they render as blank cells
    heat_masked = np.ma.masked_invalid(heat)
    # two discrete colour levels: mismatch=0, match=1
    cmap = mpl.colormaps.get_cmap("viridis").resampled(2)

    fig_height = max(3, heat.shape[0] * 0.03)
    plt.figure(figsize=(12, fig_height))
    im = plt.imshow(heat_masked, aspect="auto", interpolation="nearest", cmap=cmap, vmin=0, vmax=1)
    plt.xlabel("Alignment position")
    plt.ylabel("Sequences")
    plt.title(title)

    cbar = plt.colorbar(im, fraction=0.046, pad=0.04)
    cbar.set_ticks([0, 1])
    cbar.set_ticklabels(["mismatch", "match"])

    if save_path:
        plt.tight_layout()
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.show()

# ---------- 5) Run heatmaps for each class + ALL ----------
for cls, aln in aln_paths.items():
    cons = majority_consensus_from_alignment(aln, threshold=0.5)
    _, M = consensus_similarity_matrix(aln, cons, ignore_when_cons_is_gap=True)
    out_png = Path(HM_DIR) / f"{cls}.consensus_similarity.heatmap.png"
    plot_similarity_heatmap(M, f"Consensus Similarity — {cls}", save_path=out_png, sort_by_overall=True)
    # save consensus, wrapped at 80 characters per line
    wrapped = "".join(cons[start:start+80] + "\n" for start in range(0, len(cons), 80))
    with open(Path(CONS_DIR) / f"{cls}.consensus.majority.fasta", "w") as fasta_out:
        fasta_out.write(f">{cls}\n" + wrapped)

print("[OK] Built alignments from DIRS and saved consensus-similarity heatmaps to:", HM_DIR)
