# INITAL HMMR Profiles

In [41]:
# --- paths & imports ---
from pathlib import Path
from Bio import SeqIO, SeqRecord, Seq
import pandas as pd
import numpy as np
import random, collections, re, subprocess, shlex, shutil

# Your project base on Windows (WSL mount):
BASE        = Path("/mnt/c/Users/SAM/CODE/HMMR")
MSA         = BASE / "8ca-1024.aln-fasta"                     # master aligned FASTA (all seqs)
OG          = BASE / "OG_Labels"                              # per-class FASTA files (unaligned) e.g., *.txt
UNIPROT_BIG = BASE / "uniprotkb_carbonic_anhydrase_2025_08_22.fasta"

# Work/output dirs (safe to re-create):
ALIGN_DIR   = BASE / "per_class_aligned"                      # aligned subsets from MSA (full set per class)
SPLIT_DIR   = BASE / "per_class_split"                        # train/test ids per class
TRAIN_ALN   = BASE / "train_aligned"                          # training MSAs (aligned) after slicing
CLEAN_ALN   = BASE / "train_aligned_clean"                    # strict-cleaned training MSAs
TRIM_ALN    = BASE / "train_aligned_trimmed"                  # training MSAs after gappy-column trimming
PROFILES    = BASE / "profiles"                               # HMMs written here (on /mnt/c OK)
RESULTS     = BASE / "results"
LOGS        = BASE / "logs"
TMP         = BASE / "tmp"

# Where we'll press and keep the combined HMM library (Linux FS avoids Windows locking)
HMM_LIB     = Path.home() / "hmmer_lib"                       # e.g., /home/sfkaplan/hmmer_lib
HMM_LIB.mkdir(exist_ok=True, parents=True)

for d in [ALIGN_DIR, SPLIT_DIR, TRAIN_ALN, CLEAN_ALN, TRIM_ALN, PROFILES, RESULTS, LOGS, TMP]:
    d.mkdir(exist_ok=True, parents=True)

# Reproducible split:
RNG_SEED    = 1337
random.seed(RNG_SEED)

def run(cmd, log=None, check=False):
    """Run a shell command and (optionally) write logs."""
    print("$", cmd)
    p = subprocess.run(shlex.split(cmd), capture_output=True, text=True)
    if log:
        Path(log).write_text(p.stdout + "\n--- STDERR ---\n" + p.stderr)
    if check and p.returncode != 0:
        raise RuntimeError(f"Command failed ({p.returncode})\n{p.stderr}")
    return p

def first_token(s: str) -> str:
    """Return the first whitespace-delimited token from a FASTA header."""
    return s.split()[0] if s else s


In [42]:
assert MSA.exists(), f"Missing master alignment: {MSA}"
assert OG.exists(),  f"Missing OG_Labels dir: {OG}"

# Index the master alignment by first token of ID/description
msa_records = list(SeqIO.parse(str(MSA), "fasta"))
msa_index   = { first_token(r.id or r.description): r for r in msa_records }
print(f"[MSA] Loaded {len(msa_records)} aligned sequences")

rows = []
for f in sorted(OG.glob("*.txt")):
    cls = f.stem                          # e.g., '8ca.alpha.45'
    want_ids = [ first_token(r.id or r.description) for r in SeqIO.parse(str(f), "fasta") ]
    found = [ msa_index[i] for i in want_ids if i in msa_index ]
    out_aln = ALIGN_DIR / f"{cls}.aln.fasta"
    if found:
        SeqIO.write(found, str(out_aln), "fasta")
    rows.append({"class": cls, "wanted": len(want_ids), "found": len(found), "out": str(out_aln) })

slice_report = pd.DataFrame(rows).sort_values("class")
display(slice_report.head(10))
slice_report_path = BASE / "01_slice_report.csv"
slice_report.to_csv(slice_report_path, index=False)
print("[OK] wrote", slice_report_path)


[MSA] Loaded 499 aligned sequences


Unnamed: 0,class,wanted,found,out
0,8ca.alpha.45,82,82,/mnt/c/Users/SAM/CODE/HMMR/per_class_aligned/8...
1,8ca.beta.45,93,93,/mnt/c/Users/SAM/CODE/HMMR/per_class_aligned/8...
2,8ca.delta.45,50,50,/mnt/c/Users/SAM/CODE/HMMR/per_class_aligned/8...
3,8ca.eta.45,30,30,/mnt/c/Users/SAM/CODE/HMMR/per_class_aligned/8...
4,8ca.gamma.45,87,87,/mnt/c/Users/SAM/CODE/HMMR/per_class_aligned/8...
5,8ca.iota.45,97,97,/mnt/c/Users/SAM/CODE/HMMR/per_class_aligned/8...
6,8ca.theta.45,41,41,/mnt/c/Users/SAM/CODE/HMMR/per_class_aligned/8...
7,8ca.zeta.45,19,19,/mnt/c/Users/SAM/CODE/HMMR/per_class_aligned/8...


[OK] wrote /mnt/c/Users/SAM/CODE/HMMR/01_slice_report.csv


In [43]:
SPLIT_DIR.mkdir(exist_ok=True)

split_rows = []
for f in sorted(OG.glob("*.txt")):
    cls = f.stem
    ids = [ first_token(r.id or r.description) for r in SeqIO.parse(str(f), "fasta") ]
    ids = sorted(set(ids))
    if len(ids) < 3:                          # too small to split meaningfully
        split_rows.append({"class":cls, "n_total":len(ids), "n_train":len(ids), "n_test":0, "note":"too_small"})
        # still write a train file with all ids to avoid losing the class
        (SPLIT_DIR / f"{cls}.train.ids").write_text("\n".join(ids) + ("\n" if ids else ""))
        (SPLIT_DIR / f"{cls}.test.ids").write_text("")
        continue

    n_train = max(2, int(0.8 * len(ids)))     # ensure at least 2 in train
    random.shuffle(ids)
    train_ids = sorted(ids[:n_train])
    test_ids  = sorted(ids[n_train:])
    (SPLIT_DIR / f"{cls}.train.ids").write_text("\n".join(train_ids) + "\n")
    (SPLIT_DIR / f"{cls}.test.ids").write_text("\n".join(test_ids) + "\n")
    split_rows.append({"class":cls, "n_total":len(ids), "n_train":len(train_ids), "n_test":len(test_ids), "note":""})

split_report = pd.DataFrame(split_rows).sort_values("class")
display(split_report)
split_report_path = BASE / "02_split_report.csv"
split_report.to_csv(split_report_path, index=False)
print("[OK] wrote", split_report_path)


Unnamed: 0,class,n_total,n_train,n_test,note
0,8ca.alpha.45,82,65,17,
1,8ca.beta.45,93,74,19,
2,8ca.delta.45,50,40,10,
3,8ca.eta.45,30,24,6,
4,8ca.gamma.45,87,69,18,
5,8ca.iota.45,97,77,20,
6,8ca.theta.45,41,32,9,
7,8ca.zeta.45,19,15,4,


[OK] wrote /mnt/c/Users/SAM/CODE/HMMR/02_split_report.csv


In [44]:
TRAIN_ALN.mkdir(exist_ok=True)
rows=[]
for aln in sorted(ALIGN_DIR.glob("*.aln.fasta")):
    cls = aln.stem.replace(".aln","")
    train_ids_path = SPLIT_DIR / f"{cls}.train.ids"
    if not train_ids_path.exists():
        rows.append({"class":cls, "train_written":0, "note":"no_train_ids"}); continue
    train_ids = set(train_ids_path.read_text().strip().splitlines())
    in_recs = list(SeqIO.parse(str(aln), "fasta"))
    out_recs = [ r for r in in_recs if first_token(r.id or r.description) in train_ids ]
    out_path = TRAIN_ALN / f"{cls}.train.aln.fasta"
    if out_recs:
        SeqIO.write(out_recs, str(out_path), "fasta")
    rows.append({"class":cls, "train_written":len(out_recs), "note":"" if out_recs else "empty"})

train_aln_report = pd.DataFrame(rows).sort_values("class")
display(train_aln_report.head(10))
train_aln_report_path = BASE / "03_train_alignment_report.csv"
train_aln_report.to_csv(train_aln_report_path, index=False)
print("[OK] wrote", train_aln_report_path)


Unnamed: 0,class,train_written,note
0,8ca.alpha.45,65,
1,8ca.beta.45,74,
2,8ca.delta.45,40,
3,8ca.eta.45,24,
4,8ca.gamma.45,69,
5,8ca.iota.45,77,
6,8ca.theta.45,32,
7,8ca.zeta.45,15,


[OK] wrote /mnt/c/Users/SAM/CODE/HMMR/03_train_alignment_report.csv


In [45]:
CLEAN_ALN.mkdir(exist_ok=True)
valid = set("ACDEFGHIKLMNPQRSTVWYBXZ-")  # dash for gaps; allow X

def clean_strict(seq_str: str) -> str:
    s = seq_str.upper().replace(".", "-")
    return "".join(ch if ch in valid else "X" for ch in s)

def enforce_modal_length(records):
    lengths = [len(r.seq) for r in records]
    if not lengths: return [], 0, 0
    modal_len, _ = collections.Counter(lengths).most_common(1)[0]
    kept = [r for r in records if len(r.seq) == modal_len]
    return kept, modal_len, len(records) - len(kept)

rows=[]
for aln in sorted(TRAIN_ALN.glob("*.train.aln.fasta")):
    cls = aln.stem.replace(".train.aln","")
    recs = []
    repl = 0
    for r in SeqIO.parse(str(aln), "fasta"):
        raw = str(r.seq)
        cleaned = clean_strict(raw)
        repl += sum(1 for a,b in zip(raw.upper(), cleaned) if a != b)
        recs.append(SeqRecord.SeqRecord(Seq.Seq(cleaned),
                                        id=first_token(r.id or r.description),
                                        description=""))
    kept, modal_len, dropped = enforce_modal_length(recs)
    out = CLEAN_ALN / f"{cls}.train.clean.aln.fasta"
    if kept:
        SeqIO.write(kept, str(out), "fasta")
    rows.append({"class":cls,"n_input":len(recs),"n_kept":len(kept),
                 "dropped":dropped,"modal_len":modal_len,"replacements":repl})

clean_report = pd.DataFrame(rows).sort_values("class")
display(clean_report.head(10))
clean_report_path = BASE / "04_clean_report.csv"
clean_report.to_csv(clean_report_path, index=False)
print("[OK] wrote", clean_report_path)


Unnamed: 0,class,n_input,n_kept,dropped,modal_len,replacements
0,8ca.alpha.45,65,65,0,3783,0
1,8ca.beta.45,74,74,0,3783,0
2,8ca.delta.45,40,40,0,3783,0
3,8ca.eta.45,24,24,0,3783,0
4,8ca.gamma.45,69,69,0,3783,0
5,8ca.iota.45,77,77,0,3783,0
6,8ca.theta.45,32,32,0,3783,0
7,8ca.zeta.45,15,15,0,3783,0


[OK] wrote /mnt/c/Users/SAM/CODE/HMMR/04_clean_report.csv


In [46]:
TRIM_ALN.mkdir(exist_ok=True)

def trim_alignment(in_fa: Path, out_fa: Path, min_symbol_frac=0.10):
    recs = list(SeqIO.parse(str(in_fa), "fasta"))
    if not recs:
        return {"class": in_fa.stem, "kept_cols": 0, "orig_cols": 0, "nseq": 0}
    arr = np.array([list(str(r.seq)) for r in recs])  # shape (N, L)
    gapmask = (arr == "-")
    symfrac = 1.0 - gapmask.mean(axis=0)
    keep = symfrac >= min_symbol_frac
    kept_cols = int(keep.sum()); L = arr.shape[1]
    if kept_cols == 0:
        keep = symfrac > 0.0
        kept_cols = int(keep.sum())
    if kept_cols == 0:
        return {"class": in_fa.stem, "kept_cols": 0, "orig_cols": L, "nseq": len(recs)}
    trimmed = ["".join(row[keep]) for row in arr]
    out_recs = [SeqRecord.SeqRecord(Seq.Seq(s), id=first_token(r.id or r.description), description="")
                for s, r in zip(trimmed, recs)]
    SeqIO.write(out_recs, str(out_fa), "fasta")
    return {"class": in_fa.stem, "kept_cols": kept_cols, "orig_cols": L, "nseq": len(recs)}

rows=[]
for aln in sorted(CLEAN_ALN.glob("*.train.clean.aln.fasta")):
    out = TRIM_ALN / aln.name.replace(".clean", "")
    rows.append(trim_alignment(aln, out, min_symbol_frac=0.10))

trim_report = pd.DataFrame(rows).sort_values("class")
display(trim_report.head(10))
trim_report_path = BASE / "05_trim_report.csv"
trim_report.to_csv(trim_report_path, index=False)
print("[OK] wrote", trim_report_path)


Unnamed: 0,class,kept_cols,orig_cols,nseq
0,8ca.alpha.45.train.clean.aln,350,3783,65
1,8ca.beta.45.train.clean.aln,412,3783,74
2,8ca.delta.45.train.clean.aln,710,3783,40
3,8ca.eta.45.train.clean.aln,1501,3783,24
4,8ca.gamma.45.train.clean.aln,192,3783,69
5,8ca.iota.45.train.clean.aln,379,3783,77
6,8ca.theta.45.train.clean.aln,1109,3783,32
7,8ca.zeta.45.train.clean.aln,841,3783,15


[OK] wrote /mnt/c/Users/SAM/CODE/HMMR/05_trim_report.csv


In [47]:
# Build per-class HMMs with 80% dropout (keep columns with >=20% symbols); fallback to 0.0 if needed.
records = []
for aln in sorted(TRIM_ALN.glob("*.train.aln.fasta")):
    cls = aln.name.replace(".train.aln.fasta","")           # clean class name
    hmm = PROFILES / f"{cls}.hmm"
    log1= LOGS / f"{cls}.hmmbuild.log"
    log2= LOGS / f"{cls}.hmmbuild.retry.log"

    # quick sanity
    nseq  = sum(1 for _ in SeqIO.parse(str(aln), "fasta"))
    nsyms = sum((c != '-') for r in SeqIO.parse(str(aln), "fasta") for c in str(r.seq))
    if nseq < 2 or nsyms == 0:
        records.append({"class":cls,"nseq":nseq,"symbols":nsyms,"status":"skip","note":"too_small_or_empty"})
        continue

    # NOTE: HMMER uses -n (short) not --name
    p1 = run(f'hmmbuild --amino --symfrac 0.2 -n "{cls}" --cpu 4 "{hmm}" "{aln}"', log=str(log1))
    if p1.returncode == 0 and hmm.exists() and hmm.stat().st_size > 0:
        records.append({"class":cls,"nseq":nseq,"symbols":nsyms,"status":"ok","note":"symfrac=0.2"})
        continue

    p2 = run(f'hmmbuild --amino --symfrac 0.0 -n "{cls}" --cpu 4 "{hmm}" "{aln}"', log=str(log2))
    if p2.returncode == 0 and hmm.exists() and hmm.stat().st_size > 0:
        records.append({"class":cls,"nseq":nseq,"symbols":nsyms,"status":"ok","note":"fallback symfrac=0.0"})
    else:
        records.append({"class":cls,"nseq":nseq,"symbols":nsyms,"status":"fail","note":"see logs"})

hmr = pd.DataFrame(records).sort_values(["status","class"])
display(hmr)
hmr_path = BASE / "06_hmmbuild_report.csv"
hmr.to_csv(hmr_path, index=False)
print("[OK] wrote", hmr_path)

# Combine to a single HMM lib on /mnt/c then copy to Linux FS and press there
combined = PROFILES / "all_classes.hmm"
if combined.exists(): combined.unlink()
parts = [p for p in sorted(PROFILES.glob("*.hmm")) if p.name != "all_classes.hmm" and p.stat().st_size > 0]
assert parts, "No non-empty HMMs built; check 06_hmmbuild_report.csv and logs."

with open(combined, "w") as out:
    for p in parts:
        s = p.read_text()
        out.write(s if s.endswith("\n") else s + "\n")
print(f"[combine] {len(parts)} models -> {combined} ({combined.stat().st_size} bytes)")

# Copy to Linux FS & hmmpress to avoid Windows file locking
dst = HMM_LIB / "all_classes.hmm"
shutil.copy2(combined, dst)

# Clean any stale indices at destination and press
for ext in (".h3f",".h3i",".h3m",".h3p"):
    p = Path(str(dst) + ext)
    if p.exists(): p.unlink()

press = run(f'hmmpress "{dst}"', log=str(LOGS / "hmmpress_all_classes.log"))
if press.returncode != 0:
    print(Path(LOGS / "hmmpress_all_classes.log").read_text()[:1000])
    raise SystemError("hmmpress failed; see log.")

ALL_HMM = dst   # <- use this path in later hmmscan calls
print("[hmmpress] OK at", ALL_HMM)


$ hmmbuild --amino --symfrac 0.2 -n "8ca.alpha.45" --cpu 4 "/mnt/c/Users/SAM/CODE/HMMR/profiles/8ca.alpha.45.hmm" "/mnt/c/Users/SAM/CODE/HMMR/train_aligned_trimmed/8ca.alpha.45.train.aln.fasta"
$ hmmbuild --amino --symfrac 0.2 -n "8ca.beta.45" --cpu 4 "/mnt/c/Users/SAM/CODE/HMMR/profiles/8ca.beta.45.hmm" "/mnt/c/Users/SAM/CODE/HMMR/train_aligned_trimmed/8ca.beta.45.train.aln.fasta"
$ hmmbuild --amino --symfrac 0.2 -n "8ca.delta.45" --cpu 4 "/mnt/c/Users/SAM/CODE/HMMR/profiles/8ca.delta.45.hmm" "/mnt/c/Users/SAM/CODE/HMMR/train_aligned_trimmed/8ca.delta.45.train.aln.fasta"
$ hmmbuild --amino --symfrac 0.2 -n "8ca.eta.45" --cpu 4 "/mnt/c/Users/SAM/CODE/HMMR/profiles/8ca.eta.45.hmm" "/mnt/c/Users/SAM/CODE/HMMR/train_aligned_trimmed/8ca.eta.45.train.aln.fasta"
$ hmmbuild --amino --symfrac 0.2 -n "8ca.gamma.45" --cpu 4 "/mnt/c/Users/SAM/CODE/HMMR/profiles/8ca.gamma.45.hmm" "/mnt/c/Users/SAM/CODE/HMMR/train_aligned_trimmed/8ca.gamma.45.train.aln.fasta"
$ hmmbuild --amino --symfrac 0.2 -n "8c

Unnamed: 0,class,nseq,symbols,status,note
0,8ca.alpha.45,65,10869,ok,symfrac=0.2
1,8ca.beta.45,74,15310,ok,symfrac=0.2
2,8ca.delta.45,40,13079,ok,symfrac=0.2
3,8ca.eta.45,24,14496,ok,symfrac=0.2
4,8ca.gamma.45,69,11130,ok,symfrac=0.2
5,8ca.iota.45,77,13583,ok,symfrac=0.2
6,8ca.theta.45,32,11683,ok,symfrac=0.2
7,8ca.zeta.45,15,3230,ok,symfrac=0.2


[OK] wrote /mnt/c/Users/SAM/CODE/HMMR/06_hmmbuild_report.csv
[combine] 8 models -> /mnt/c/Users/SAM/CODE/HMMR/profiles/all_classes.hmm (2431234 bytes)
$ hmmpress "/home/sfkaplan/hmmer_lib/all_classes.hmm"
[hmmpress] OK at /home/sfkaplan/hmmer_lib/all_classes.hmm


In [50]:
# === Robust held-out validation (replaces previous Cell 7) ===
from pathlib import Path
from Bio import SeqIO
import pandas as pd, re, subprocess, shlex

BASE     = Path("/mnt/c/Users/SAM/CODE/HMMR")
OG       = BASE / "OG_Labels"
SPLIT    = BASE / "per_class_split"
TMP      = BASE / "tmp";     TMP.mkdir(exist_ok=True)
RESULTS  = BASE / "results"; RESULTS.mkdir(exist_ok=True)
LOGS     = BASE / "logs";    LOGS.mkdir(exist_ok=True)

ALL_HMM  = Path.home() / "hmmer_lib" / "all_classes.hmm"  # pressed in Linux home

def first_token(s: str) -> str:
    return s.split()[0] if s else s

# 1) Gather the complete set of held-out (test) IDs per class
test_map = {}   # seq_id -> true_class
perclass_counts = []
for f in sorted(OG.glob("*.txt")):
    cls = f.stem
    tid_path = SPLIT / f"{cls}.test.ids"
    if not tid_path.exists():
        perclass_counts.append({"class": cls, "n_test": 0}); continue
    ids = [ln.strip() for ln in tid_path.read_text().splitlines() if ln.strip()]
    for sid in ids:
        test_map[sid] = cls
    perclass_counts.append({"class": cls, "n_test": len(ids)})

perclass_counts = pd.DataFrame(perclass_counts).sort_values("class")
total_test = sum(perclass_counts["n_test"])
print(f"[held-out] total test sequences requested: {total_test}")
display(perclass_counts.head(20))

# 2) Build a FASTA with ALL held-out sequences
TEST_MERGED = TMP / "heldout_test_sequences.faa"
with open(TEST_MERGED, "w") as out:
    for f in sorted(OG.glob("*.txt")):
        cls = f.stem
        tid_path = SPLIT / f"{cls}.test.ids"
        if not tid_path.exists(): continue
        want = set(x.strip() for x in tid_path.read_text().splitlines() if x.strip())
        if not want: continue
        for rec in SeqIO.parse(str(f), "fasta"):
            sid = first_token(rec.id or rec.description)
            if sid in want:
                rec.id = sid
                rec.description = sid
                SeqIO.write(rec, out, "fasta")

# Sanity check: how many actually written?
n_written = sum(1 for _ in SeqIO.parse(str(TEST_MERGED), "fasta"))
print(f"[held-out] wrote {n_written} sequences to {TEST_MERGED}")

# 3) hmmscan against the pressed library
tbl = RESULTS / "hmmscan_heldout.tbl"
dom = RESULTS / "hmmscan_heldout.domtbl"
cmd = f'hmmscan --cpu 4 --tblout "{tbl}" --domtblout "{dom}" "{ALL_HMM}" "{TEST_MERGED}"'
print("$", cmd)
subprocess.run(shlex.split(cmd), check=True)

# 4) Robust parse of --tblout and pick best hit per sequence (if any)
colnames = [
    "target","tacc","tlen","query","qacc","qlen",
    "fs_evalue","fs_score","fs_bias",
    "n_dom","n_dom_exp","dom_cevalue","dom_ievalue",
    "hmmfrom","hmmto","alifrom","alito","envfrom","envto","acc","desc"
]
hits = pd.read_csv(tbl, sep=r"\s+", comment="#", header=None, names=colnames, engine="python")

# Keep essentials; sometimes 'query' can be '-' if parse breaks — filter those out
hits = hits.loc[hits["query"].notna() & (hits["query"] != "-"), ["query","target","fs_score","fs_evalue"]]
hits = hits.rename(columns={"query":"seq_id","target":"pred_class","fs_score":"bit_score","fs_evalue":"evalue"})

# Best hit per sequence
if not hits.empty:
    ranked = hits.sort_values(["seq_id","bit_score"], ascending=[True,False])
    best_hits = ranked.groupby("seq_id", as_index=False).first()
else:
    best_hits = pd.DataFrame(columns=["seq_id","pred_class","bit_score","evalue"])

# 5) Join with the full list of test IDs so NO-HIT seqs are included
truth = pd.Series(test_map, name="true_class")
full = pd.DataFrame({"seq_id": list(test_map.keys())}).merge(best_hits, on="seq_id", how="left")

# Mark no-hit rows explicitly
full["pred_class"] = full["pred_class"].fillna("NO_HIT")
# optional: set missing scores to NaN (already NaN)
# Compute correctness (NO_HIT counts as incorrect)
full = full.join(truth, on="seq_id")
full["correct"] = (full["pred_class"] == full["true_class"])

# 6) Metrics + artifacts
n_eval      = len(full)
n_hits      = int((full["pred_class"] != "NO_HIT").sum())
coverage    = n_hits / n_eval if n_eval else float("nan")
accuracy    = full["correct"].mean() if n_eval else float("nan")

print(f"[held-out] evaluated: {n_eval}  |  with ≥1 hit: {n_hits}  (coverage={coverage:.3f})")
print(f"[held-out] accuracy:  {accuracy:.3f}")

# Confusion matrix including NO_HIT column
cm = pd.crosstab(full["true_class"], full["pred_class"])
display(cm.head(20))

# Save reports
full_out = BASE / "07_heldout_predictions.csv"
cm_out   = BASE / "07_confusion_matrix.csv"
cov_out  = BASE / "07_heldout_summary.txt"
full.to_csv(full_out, index=False)
cm.to_csv(cm_out)
cov_out.write_text(
    f"total_test={total_test}\n"
    f"evaluated={n_eval}\n"
    f"with_hits={n_hits}\n"
    f"coverage={coverage:.4f}\n"
    f"accuracy={accuracy:.4f}\n"
)
print("[OK] wrote", full_out, cm_out, cov_out)


[held-out] total test sequences requested: 103


Unnamed: 0,class,n_test
0,8ca.alpha.45,17
1,8ca.beta.45,19
2,8ca.delta.45,10
3,8ca.eta.45,6
4,8ca.gamma.45,18
5,8ca.iota.45,20
6,8ca.theta.45,9
7,8ca.zeta.45,4


[held-out] wrote 103 sequences to /mnt/c/Users/SAM/CODE/HMMR/tmp/heldout_test_sequences.faa
$ hmmscan --cpu 4 --tblout "/mnt/c/Users/SAM/CODE/HMMR/results/hmmscan_heldout.tbl" --domtblout "/mnt/c/Users/SAM/CODE/HMMR/results/hmmscan_heldout.domtbl" "/home/sfkaplan/hmmer_lib/all_classes.hmm" "/mnt/c/Users/SAM/CODE/HMMR/tmp/heldout_test_sequences.faa"
# hmmscan :: search sequence(s) against a profile database
# HMMER 3.4 (Aug 2023); http://hmmer.org/
# Copyright (C) 2023 Howard Hughes Medical Institute.
# Freely distributed under the BSD open source license.
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# query sequence file:             /mnt/c/Users/SAM/CODE/HMMR/tmp/heldout_test_sequences.faa
# target HMM database:             /home/sfkaplan/hmmer_lib/all_classes.hmm
# per-seq hits tabular output:     /mnt/c/Users/SAM/CODE/HMMR/results/hmmscan_heldout.tbl
# per-dom hits tabular output:     /mnt/c/Users/SAM/CODE/HMMR/results/hmmscan_heldout.domtbl
# multithrea

pred_class,NO_HIT
true_class,Unnamed: 1_level_1
8ca.alpha.45,17
8ca.beta.45,19
8ca.delta.45,10
8ca.eta.45,6
8ca.gamma.45,18
8ca.iota.45,20
8ca.theta.45,9
8ca.zeta.45,4


[OK] wrote /mnt/c/Users/SAM/CODE/HMMR/07_heldout_predictions.csv /mnt/c/Users/SAM/CODE/HMMR/07_confusion_matrix.csv /mnt/c/Users/SAM/CODE/HMMR/07_heldout_summary.txt


# TEST Alpha

In [57]:
# === Held-out validation for ALPHA ONLY ===
from pathlib import Path
from Bio import SeqIO
import pandas as pd, subprocess, shlex

BASE     = Path("/mnt/c/Users/SAM/CODE/HMMR")
OG       = BASE / "OG_Labels"
SPLIT    = BASE / "per_class_split"
TMP      = BASE / "tmp";     TMP.mkdir(exist_ok=True)
RESULTS  = BASE / "results"; RESULTS.mkdir(exist_ok=True)

# HMM library (pressed) – keep all classes so we can detect off-target hits
ALL_HMM  = Path.home() / "hmmer_lib" / "all_classes.hmm"

# Configuration: evaluate only alpha subclasses
CLASS_PREFIX = "8ca.alpha."

def first_token(s: str) -> str:
    """Make sequence IDs consistent with the IDs in *.test.ids files."""
    return s.split()[0] if s else s

# 1) Collect ALPHA held-out IDs (true labels)
test_map = {}   # seq_id -> true_class (e.g., 8ca.alpha.45)
perclass_counts = []
for f in sorted(OG.glob("*.txt")):
    cls = f.stem
    if not cls.startswith(CLASS_PREFIX):
        continue  # skip non-alpha classes
    tid_path = SPLIT / f"{cls}.test.ids"
    if not tid_path.exists():
        perclass_counts.append({"class": cls, "n_test": 0})
        continue
    ids = [ln.strip() for ln in tid_path.read_text().splitlines() if ln.strip()]
    for sid in ids:
        test_map[sid] = cls
    perclass_counts.append({"class": cls, "n_test": len(ids)})

perclass_counts = pd.DataFrame(perclass_counts).sort_values("class")
total_test = int(perclass_counts["n_test"].sum()) if not perclass_counts.empty else 0
print(f"[alpha held-out] total test sequences requested: {total_test}")
display(perclass_counts)

# 2) Build a FASTA containing only ALPHA held-out sequences
TEST_MERGED = TMP / "heldout_alpha_test_sequences.faa"
with open(TEST_MERGED, "w") as out:
    for f in sorted(OG.glob("*.txt")):
        cls = f.stem
        if not cls.startswith(CLASS_PREFIX):
            continue
        tid_path = SPLIT / f"{cls}.test.ids"
        if not tid_path.exists():
            continue
        want = set(x.strip() for x in tid_path.read_text().splitlines() if x.strip())
        if not want:
            continue
        for rec in SeqIO.parse(str(f), "fasta"):
            sid = first_token(rec.id or rec.description)
            if sid in want:
                rec.id = sid
                rec.description = sid
                SeqIO.write(rec, out, "fasta")

# Sanity check: how many actually written?
n_written = sum(1 for _ in SeqIO.parse(str(TEST_MERGED), "fasta"))
print(f"[alpha held-out] wrote {n_written} sequences to {TEST_MERGED}")

# 3) hmmscan against the complete profile library
tbl = RESULTS / "hmmscan_alpha_heldout.tbl"
dom = RESULTS / "hmmscan_alpha_heldout.domtbl"
cmd = f'hmmscan --cpu 4 --tblout "{tbl}" --domtblout "{dom}" "{ALL_HMM}" "{TEST_MERGED}"'
print("$", cmd)
subprocess.run(shlex.split(cmd), check=True)
# 4) Robust parse of --domtblout and pick best full-seq hit per sequence
# HMMER 3.4 domtblout columns (per manual):
# HMMER 3.4 domtblout columns (hmmscan)
dom_cols = [
    "target","tacc","tlen",
    "query","qacc","qlen",
    "fs_evalue","fs_score","fs_bias",
    "#","of",                         # <— the two you were missing
    "dom_cevalue","dom_ievalue","dom_score","dom_bias",
    "hmmfrom","hmmto","alifrom","alito","envfrom","envto","acc",
    "desc"
]

dom_hits = pd.read_csv(
    dom,
    sep=r"\s+",
    comment="#",
    header=None,
    engine="python",
    names=dom_cols,
    usecols=list(range(0, 23))  # everything up to (but not including) desc
)

# Keep essentials & pick best full-seq hit per query
dom_hits = dom_hits.rename(columns={
    "query":"seq_id", "target":"pred_class",
    "fs_score":"bit_score", "fs_evalue":"evalue"
})[["seq_id","pred_class","bit_score","evalue"]]

ranked = dom_hits.sort_values(["seq_id","bit_score"], ascending=[True, False])
best_hits = ranked.groupby("seq_id", as_index=False).first()


# Best full-seq hit per sequence
if not dom_hits.empty:
    ranked = dom_hits.sort_values(["seq_id", "bit_score"], ascending=[True, False])
    best_hits = ranked.groupby("seq_id", as_index=False).first()
else:
    best_hits = pd.DataFrame(columns=["seq_id", "pred_class", "bit_score", "evalue"])

# 5) (unchanged) Join with the full list of test IDs so NO-HIT seqs are included
truth = pd.Series(test_map, name="true_class")
full = pd.DataFrame({"seq_id": list(test_map.keys())}).merge(best_hits, on="seq_id", how="left")
full["pred_class"] = full["pred_class"].fillna("NO_HIT")
full = full.join(truth, on="seq_id")
full["correct"] = (full["pred_class"] == full["true_class"])


# 6) Metrics + confusion matrix (restricted to ALPHA true classes)
n_eval   = len(full)
n_hits   = int((full["pred_class"] != "NO_HIT").sum())
coverage = (n_hits / n_eval) if n_eval else float("nan")
accuracy = full["correct"].mean() if n_eval else float("nan")

print(f"[alpha held-out] evaluated: {n_eval}  |  with ≥1 hit: {n_hits}  (coverage={coverage:.3f})")
print(f"[alpha held-out] accuracy:  {accuracy:.3f}")

# Confusion matrix: rows = true alpha subclass, columns = predicted model (incl. NO_HIT)
cm = pd.crosstab(full["true_class"], full["pred_class"])
display(cm)

# 7) Save artifacts
full_out = BASE / "07_alpha_heldout_predictions.csv"
cm_out   = BASE / "07_alpha_confusion_matrix.csv"
sum_out  = BASE / "07_alpha_heldout_summary.txt"

full.to_csv(full_out, index=False)
cm.to_csv(cm_out)
sum_out.write_text(
    f"total_test={total_test}\n"
    f"evaluated={n_eval}\n"
    f"with_hits={n_hits}\n"
    f"coverage={coverage:.4f}\n"
    f"accuracy={accuracy:.4f}\n"
)
print("[OK] wrote", full_out, cm_out, sum_out)


[alpha held-out] total test sequences requested: 17


Unnamed: 0,class,n_test
0,8ca.alpha.45,17


[alpha held-out] wrote 17 sequences to /mnt/c/Users/SAM/CODE/HMMR/tmp/heldout_alpha_test_sequences.faa
$ hmmscan --cpu 4 --tblout "/mnt/c/Users/SAM/CODE/HMMR/results/hmmscan_alpha_heldout.tbl" --domtblout "/mnt/c/Users/SAM/CODE/HMMR/results/hmmscan_alpha_heldout.domtbl" "/home/sfkaplan/hmmer_lib/all_classes.hmm" "/mnt/c/Users/SAM/CODE/HMMR/tmp/heldout_alpha_test_sequences.faa"
# hmmscan :: search sequence(s) against a profile database
# HMMER 3.4 (Aug 2023); http://hmmer.org/
# Copyright (C) 2023 Howard Hughes Medical Institute.
# Freely distributed under the BSD open source license.
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# query sequence file:             /mnt/c/Users/SAM/CODE/HMMR/tmp/heldout_alpha_test_sequences.faa
# target HMM database:             /home/sfkaplan/hmmer_lib/all_classes.hmm
# per-seq hits tabular output:     /mnt/c/Users/SAM/CODE/HMMR/results/hmmscan_alpha_heldout.tbl
# per-dom hits tabular output:     /mnt/c/Users/SAM/CODE/HMMR/re

pred_class,8ca.alpha.45
true_class,Unnamed: 1_level_1
8ca.alpha.45,17


[OK] wrote /mnt/c/Users/SAM/CODE/HMMR/07_alpha_heldout_predictions.csv /mnt/c/Users/SAM/CODE/HMMR/07_alpha_confusion_matrix.csv /mnt/c/Users/SAM/CODE/HMMR/07_alpha_heldout_summary.txt


# TEsst All AGAIN

In [58]:
# === Robust held-out validation (ALL classes) ===
from pathlib import Path
from Bio import SeqIO
import pandas as pd, subprocess, shlex

BASE     = Path("/mnt/c/Users/SAM/CODE/HMMR")
OG       = BASE / "OG_Labels"
SPLIT    = BASE / "per_class_split"
TMP      = BASE / "tmp";     TMP.mkdir(exist_ok=True)
RESULTS  = BASE / "results"; RESULTS.mkdir(exist_ok=True)
LOGS     = BASE / "logs";    LOGS.mkdir(exist_ok=True)

ALL_HMM  = Path.home() / "hmmer_lib" / "all_classes.hmm"  # pressed in Linux home

def first_token(s: str) -> str:
    return s.split()[0] if s else s

# 1) Gather the complete set of held-out (test) IDs per class
test_map = {}   # seq_id -> true_class
perclass_counts = []
for f in sorted(OG.glob("*.txt")):
    cls = f.stem
    tid_path = SPLIT / f"{cls}.test.ids"
    if not tid_path.exists():
        perclass_counts.append({"class": cls, "n_test": 0})
        continue
    ids = [ln.strip() for ln in tid_path.read_text().splitlines() if ln.strip()]
    for sid in ids:
        test_map[sid] = cls
    perclass_counts.append({"class": cls, "n_test": len(ids)})

perclass_counts = pd.DataFrame(perclass_counts).sort_values("class")
total_test = int(perclass_counts["n_test"].sum()) if not perclass_counts.empty else 0
print(f"[held-out] total test sequences requested: {total_test}")
display(perclass_counts.head(20))

# 2) Build a FASTA with ALL held-out sequences
TEST_MERGED = TMP / "heldout_test_sequences.faa"
with open(TEST_MERGED, "w") as out:
    for f in sorted(OG.glob("*.txt")):
        cls = f.stem
        tid_path = SPLIT / f"{cls}.test.ids"
        if not tid_path.exists():
            continue
        want = set(x.strip() for x in tid_path.read_text().splitlines() if x.strip())
        if not want:
            continue
        for rec in SeqIO.parse(str(f), "fasta"):
            sid = first_token(rec.id or rec.description)
            if sid in want:
                rec.id = sid
                rec.description = sid
                SeqIO.write(rec, out, "fasta")

# Sanity check: how many actually written?
n_written = sum(1 for _ in SeqIO.parse(str(TEST_MERGED), "fasta"))
print(f"[held-out] wrote {n_written} sequences to {TEST_MERGED}")

# 3) hmmscan against the pressed library
tbl = RESULTS / "hmmscan_heldout.tbl"
dom = RESULTS / "hmmscan_heldout.domtbl"
cmd = f'hmmscan --cpu 4 --tblout "{tbl}" --domtblout "{dom}" "{ALL_HMM}" "{TEST_MERGED}"'
print("$", cmd)
subprocess.run(shlex.split(cmd), check=True)

# 4) Robust parse of --domtblout and pick best full-seq hit per sequence
# HMMER 3.4 domtblout columns (hmmscan)
dom_cols = [
    "target","tacc","tlen",
    "query","qacc","qlen",
    "fs_evalue","fs_score","fs_bias",
    "#","of",
    "dom_cevalue","dom_ievalue","dom_score","dom_bias",
    "hmmfrom","hmmto","alifrom","alito","envfrom","envto","acc",
    "desc"
]

dom_hits = pd.read_csv(
    dom,
    sep=r"\s+",
    comment="#",
    header=None,
    engine="python",
    names=dom_cols,
    usecols=list(range(0, 23))  # exclude 'desc' which may contain spaces
)

# Keep essentials; pick best full-seq hit per query
dom_hits = dom_hits.rename(columns={
    "query":"seq_id",
    "target":"pred_class",
    "fs_score":"bit_score",
    "fs_evalue":"evalue"
})[["seq_id","pred_class","bit_score","evalue"]]

if not dom_hits.empty:
    ranked = dom_hits.sort_values(["seq_id", "bit_score"], ascending=[True, False])
    best_hits = ranked.groupby("seq_id", as_index=False).first()
else:
    best_hits = pd.DataFrame(columns=["seq_id", "pred_class", "bit_score", "evalue"])

# 5) Join with the full list of test IDs so NO-HIT seqs are included
truth = pd.Series(test_map, name="true_class")
full = pd.DataFrame({"seq_id": list(test_map.keys())}).merge(best_hits, on="seq_id", how="left")

# Mark no-hit rows explicitly; compute correctness
full["pred_class"] = full["pred_class"].fillna("NO_HIT")
full = full.join(truth, on="seq_id")
full["correct"] = (full["pred_class"] == full["true_class"])

# 6) Metrics + artifacts
n_eval      = len(full)
n_hits      = int((full["pred_class"] != "NO_HIT").sum())
coverage    = n_hits / n_eval if n_eval else float("nan")
accuracy    = full["correct"].mean() if n_eval else float("nan")

print(f"[held-out] evaluated: {n_eval}  |  with ≥1 hit: {n_hits}  (coverage={coverage:.3f})")
print(f"[held-out] accuracy:  {accuracy:.3f}")

# Confusion matrix including NO_HIT column
cm = pd.crosstab(full["true_class"], full["pred_class"])
display(cm.head(20))

# 7) Save reports
full_out = BASE / "07_heldout_predictions.csv"
cm_out   = BASE / "07_confusion_matrix.csv"
cov_out  = BASE / "07_heldout_summary.txt"

full.to_csv(full_out, index=False)
cm.to_csv(cm_out)
cov_out.write_text(
    f"total_test={total_test}\n"
    f"evaluated={n_eval}\n"
    f"with_hits={n_hits}\n"
    f"coverage={coverage:.4f}\n"
    f"accuracy={accuracy:.4f}\n"
)
print("[OK] wrote", full_out, cm_out, cov_out)


[held-out] total test sequences requested: 103


Unnamed: 0,class,n_test
0,8ca.alpha.45,17
1,8ca.beta.45,19
2,8ca.delta.45,10
3,8ca.eta.45,6
4,8ca.gamma.45,18
5,8ca.iota.45,20
6,8ca.theta.45,9
7,8ca.zeta.45,4


[held-out] wrote 103 sequences to /mnt/c/Users/SAM/CODE/HMMR/tmp/heldout_test_sequences.faa
$ hmmscan --cpu 4 --tblout "/mnt/c/Users/SAM/CODE/HMMR/results/hmmscan_heldout.tbl" --domtblout "/mnt/c/Users/SAM/CODE/HMMR/results/hmmscan_heldout.domtbl" "/home/sfkaplan/hmmer_lib/all_classes.hmm" "/mnt/c/Users/SAM/CODE/HMMR/tmp/heldout_test_sequences.faa"
# hmmscan :: search sequence(s) against a profile database
# HMMER 3.4 (Aug 2023); http://hmmer.org/
# Copyright (C) 2023 Howard Hughes Medical Institute.
# Freely distributed under the BSD open source license.
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# query sequence file:             /mnt/c/Users/SAM/CODE/HMMR/tmp/heldout_test_sequences.faa
# target HMM database:             /home/sfkaplan/hmmer_lib/all_classes.hmm
# per-seq hits tabular output:     /mnt/c/Users/SAM/CODE/HMMR/results/hmmscan_heldout.tbl
# per-dom hits tabular output:     /mnt/c/Users/SAM/CODE/HMMR/results/hmmscan_heldout.domtbl
# multithrea

pred_class,8ca.alpha.45,8ca.beta.45,8ca.delta.45,8ca.eta.45,8ca.gamma.45,8ca.iota.45,8ca.theta.45,8ca.zeta.45
true_class,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
8ca.alpha.45,17,0,0,0,0,0,0,0
8ca.beta.45,0,19,0,0,0,0,0,0
8ca.delta.45,0,0,10,0,0,0,0,0
8ca.eta.45,0,0,0,6,0,0,0,0
8ca.gamma.45,0,0,0,0,18,0,0,0
8ca.iota.45,0,0,0,0,0,20,0,0
8ca.theta.45,0,0,0,0,0,0,9,0
8ca.zeta.45,0,0,0,0,0,0,0,4


[OK] wrote /mnt/c/Users/SAM/CODE/HMMR/07_heldout_predictions.csv /mnt/c/Users/SAM/CODE/HMMR/07_confusion_matrix.csv /mnt/c/Users/SAM/CODE/HMMR/07_heldout_summary.txt


# Label More sequences

In [None]:
# === Label the first N UniProt sequences with HMMER (minimal CSV) ===
from pathlib import Path
from Bio import SeqIO
import pandas as pd, subprocess, shlex

BASE        = Path("/mnt/c/Users/SAM/CODE/HMMR")
UNIPROT_BIG = BASE / "uniprotkb_carbonic_anhydrase_2025_08_22.fasta"
N_TOP       = 100000  # <-- change as needed

TMP      = BASE / "tmp";     TMP.mkdir(exist_ok=True)
RESULTS  = BASE / "results"; RESULTS.mkdir(exist_ok=True)
ALL_HMM  = Path.home() / "hmmer_lib" / "all_classes.hmm"  # pressed HMM library

def first_token(s: str) -> str:
    return s.split()[0] if s else s

# 1) Subset first N sequences
SUBSET_FASTA = TMP / f"uniprot_first_{N_TOP}.faa"
subset_records = []
for i, rec in enumerate(SeqIO.parse(str(UNIPROT_BIG), "fasta"), start=1):
    subset_records.append(rec)
    if i >= N_TOP:
        break
SeqIO.write(subset_records, str(SUBSET_FASTA), "fasta")
print(f"[label] wrote {len(subset_records)} sequences to {SUBSET_FASTA}")

# 2) Run hmmscan
tbl = RESULTS / f"hmmscan_uniprot_first_{N_TOP}.tbl"
dom = RESULTS / f"hmmscan_uniprot_first_{N_TOP}.domtbl"
cmd = f'hmmscan --cpu 4 --tblout "{tbl}" --domtblout "{dom}" "{ALL_HMM}" "{SUBSET_FASTA}"'
print("$", cmd)
subprocess.run(shlex.split(cmd), check=True)

# 3) Parse domtblout (stable format)
dom_cols = [
    "target","tacc","tlen",
    "query","qacc","qlen",
    "fs_evalue","fs_score","fs_bias",
    "#","of",
    "dom_cevalue","dom_ievalue","dom_score","dom_bias",
    "hmmfrom","hmmto","alifrom","alito","envfrom","envto","acc",
    "desc"
]
dom_hits = pd.read_csv(
    dom, sep=r"\s+", comment="#", header=None, engine="python",
    names=dom_cols, usecols=list(range(0,23))
)

# keep only essentials
dom_hits = dom_hits.rename(columns={
    "query":"seq_id","target":"pred_class",
    "fs_score":"bit_score","fs_evalue":"evalue",
    "qlen":"query_len","tlen":"target_len"
})[["seq_id","pred_class","bit_score","evalue","query_len","target_len"]]

# best hit per sequence
if not dom_hits.empty:
    ranked = dom_hits.sort_values(["seq_id","bit_score"], ascending=[True,False])
    best_hits = ranked.groupby("seq_id", as_index=False).first()
else:
    best_hits = pd.DataFrame(columns=["seq_id","pred_class","bit_score","evalue","query_len","target_len"])

# 4) Join with metadata for all input sequences
rows = []
for rec in subset_records:
    sid = first_token(rec.id or rec.description)
    rows.append({
        "seq_id": sid,
        "description": rec.description,
        "length": len(rec.seq),
    })
meta_df = pd.DataFrame(rows)

full = meta_df.merge(best_hits, on="seq_id", how="left")
full["pred_class"] = full["pred_class"].fillna("NO_HIT")

# 5) Save CSV
out_csv = BASE / f"08_uniprot_first_{N_TOP}_labels.csv"
full.to_csv(out_csv, index=False)
print("[OK] wrote", out_csv)

#display(full.head(20))


# Fix no hits

In [None]:
# === Rescue pass for NO_HITs (relaxed hmmscan), robust ===
from pathlib import Path
from Bio import SeqIO
import pandas as pd, subprocess, shlex, sys

BASE        = Path("/mnt/c/Users/SAM/CODE/HMMR")
UNIPROT_BIG = BASE / "uniprotkb_carbonic_anhydrase_2025_08_22.fasta"
RESULTS     = BASE / "results"; RESULTS.mkdir(exist_ok=True)
TMP         = BASE / "tmp";     TMP.mkdir(exist_ok=True)
ALL_HMM     = Path.home() / "hmmer_lib" / "all_classes.hmm"

FIRST_PASS_CSV = BASE / f"08_uniprot_first_{N_TOP}_labels.csv"  # <-- your existing first pass

def first_token(s: str) -> str:
    return s.split()[0] if s else s

def parse_domtbl(dom_path: Path) -> pd.DataFrame:
    dom_cols = [
        "target","tacc","tlen",
        "query","qacc","qlen",
        "fs_evalue","fs_score","fs_bias",
        "#","of",
        "dom_cevalue","dom_ievalue","dom_score","dom_bias",
        "hmmfrom","hmmto","alifrom","alito","envfrom","envto","acc",
        "desc"
    ]
    if not dom_path.exists() or dom_path.stat().st_size == 0:
        return pd.DataFrame(columns=["seq_id","pred_class","bit_score","evalue","query_len","target_len"])
    df = pd.read_csv(
        dom_path, sep=r"\s+", comment="#", header=None, engine="python",
        names=dom_cols, usecols=list(range(0,23))
    )
    if df.empty:
        return pd.DataFrame(columns=["seq_id","pred_class","bit_score","evalue","query_len","target_len"])
    df = df.rename(columns={
        "query":"seq_id", "target":"pred_class",
        "fs_score":"bit_score", "fs_evalue":"evalue",
        "qlen":"query_len", "tlen":"target_len"
    })[["seq_id","pred_class","bit_score","evalue","query_len","target_len"]]
    ranked = df.sort_values(["seq_id","bit_score"], ascending=[True, False])
    return ranked.groupby("seq_id", as_index=False).first()

def hmmscan_relaxed(dom_out: Path, query_fa: Path) -> pd.DataFrame:
    tbl_out = dom_out.with_suffix(".tbl")
    # Correct: per-seq threshold has short form (-E), per-domain is long (--domE)
    cmd = f'hmmscan --cpu 4 --tblout "{tbl_out}" --domtblout "{dom_out}" --max -E 100 --domE 100 "{ALL_HMM}" "{query_fa}"'
    print("$", cmd)
    proc = subprocess.run(shlex.split(cmd), text=True, capture_output=True)
    if proc.returncode != 0:
        print("[hmmscan stderr]\n" + (proc.stderr or "(no stderr)"))
        print("[hmmscan stdout]\n" + (proc.stdout or "(no stdout)"))
        raise subprocess.CalledProcessError(proc.returncode, cmd, output=proc.stdout, stderr=proc.stderr)
    return parse_domtbl(dom_out)


# 1) Load first-pass labels
first_df = pd.read_csv(FIRST_PASS_CSV)
assert {"seq_id","pred_class"}.issubset(first_df.columns), "first-pass CSV missing columns"

# 2) Gather NO_HITs
nohit_ids = set(first_df.loc[first_df["pred_class"] == "NO_HIT", "seq_id"])
if not nohit_ids:
    print("[rescue] nothing to rescue — no NO_HITs")
else:
    # 3) Rebuild rescue FASTA straight from UNIPROT_BIG
    rescue_fa = TMP / "uniprot_nohit_rescue.faa"
    written = 0
    # also collect a small sample of UniProt IDs for debugging
    uniprot_ids_sample = []
    with open(rescue_fa, "w") as out:
        for i, rec in enumerate(SeqIO.parse(str(UNIPROT_BIG), "fasta")):
            sid = first_token(rec.id or rec.description)
            if i < 1000:  # sample first 1000 for debugging
                uniprot_ids_sample.append(sid)
            if sid in nohit_ids:
                rec.id = sid
                rec.description = sid
                SeqIO.write(rec, out, "fasta")
                written += 1
    print(f"[rescue] wrote {written} NO_HIT sequences to {rescue_fa}")

    # 3a) Guard: if zero written, explain why and stop cleanly
    if written == 0:
        missing_preview = sorted(list(nohit_ids - set(uniprot_ids_sample)))[:10]
        print("[rescue] 0 sequences written. Likely the seq_id normalization differs between first-pass CSV and UniProt FASTA.")
        print("          Example NO_HIT ids not seen among first ~1000 UniProt headers:", missing_preview)
        print("          First few NO_HIT ids:", sorted(list(nohit_ids))[:10])
        print("          First few UniProt ids found:", uniprot_ids_sample[:10])
        print("          Ensure both steps use first_token() on IDs.")
        # Optionally: re-scan ALL of UniProt to confirm (can be slow); uncomment if needed:
        # all_uniprot_ids = { first_token(rec.id or rec.description) for rec in SeqIO.parse(str(UNIPROT_BIG), "fasta") }
        # print("          Example NO_HIT ids not present in full UniProt:", sorted(list(nohit_ids - all_uniprot_ids))[:10])
    else:
        # 4) Relaxed hmmscan on the NO_HITs
        dom_relaxed = RESULTS / "hmmscan_uniprot_rescue.domtbl"
        rescued = hmmscan_relaxed(dom_relaxed, rescue_fa)

        # 5) Acceptance rule (tune to taste)
        def accept(row):
            try:
                e = float(row["evalue"])
            except Exception:
                return False
            bs = float(row["bit_score"]) if pd.notna(row["bit_score"]) else -1e9
            return (e <= 1e-2) or (bs >= 25.0)

        rescued_kept = rescued.loc[rescued.apply(accept, axis=1)]

        # 6) Merge back into first_df
        if not rescued_kept.empty:
            rescued_map = rescued_kept.set_index("seq_id")[["pred_class","bit_score","evalue","query_len","target_len"]]
            mask = first_df["pred_class"].eq("NO_HIT") & first_df["seq_id"].isin(rescued_map.index)
            for col in ["pred_class","bit_score","evalue","query_len","target_len"]:
                first_df.loc[mask, col] = first_df.loc[mask, "seq_id"].map(rescued_map[col])
            print(f"[rescue] accepted {int(mask.sum())} upgraded labels")
        else:
            print("[rescue] no acceptable hits found under relaxed settings")

        # 7) Save updated CSV
        out_csv = FIRST_PASS_CSV.with_name(FIRST_PASS_CSV.stem + "_rescued.csv")
        first_df.to_csv(out_csv, index=False)
        print("[OK] wrote", out_csv)

In [99]:
# If you have the DataFrame loaded
df = pd.read_csv(out_csv)

# count NO_HITs in the pred_class column
nohit_count = (df["pred_class"] == "NO_HIT").sum()
print(nohit_count)

# or proportion
nohit_frac = (df["pred_class"] == "NO_HIT").mean()
print(nohit_frac)


2361
0.02361
