In [None]:
#!/usr/bin/env python3
"""
CLI for the "second MAFFT" step:
- Run MAFFT a second time on an input FASTA
- Crop the alignment to the non-gap span of the reference sequence (configurable via --reference_name)
- Write both cropped (still-gapped) and cropped-then-ungapped FASTA outputs
- Length alarm: printed only if any *cropped & ungapped* sequence length falls outside
  [floor((1 - TOL) * target_len), ceil((1 + TOL) * target_len)], where TOL = --length_tolerance
  (default 0.30) and target_len is the cropped & ungapped length of the reference sequence

File-system / process hardening:
- Atomic writes for all outputs (temp file + fsync + atomic replace)
- Overwrite policy selectable: {skip, fail, replace} (default: fail)
- MAFFT process: timeout + stderr captured to <output>.stderr.txt on failure/warnings
- Honors overwrite policy consistently (no silent replace)
- Optional extra MAFFT args via --mafft_args (e.g., "--thread -1 --maxiterate 2")
"""

import argparse
import shutil
import subprocess
import sys
import tempfile
import os
import math
import shlex
from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple, Optional

# ---------------------------------------------------------------
# Input guard rails: reject pathological FASTA files early
# ---------------------------------------------------------------
MAX_SEQ_LENGTH = 20_000     # longest single record accepted (all characters, gaps included)
MAX_NUM_SEQUENCES = 15_000  # most records accepted from one FASTA file

# ---------------------------------------------------------------
# Overwrite policies, shared by every writer in this script
#   "skip"    -> leave an existing target untouched
#   "fail"    -> abort if the target already exists (default of --overwrite)
#   "replace" -> atomically overwrite an existing target
# ---------------------------------------------------------------
OVERWRITE_SKIP = "skip"
OVERWRITE_FAIL = "fail"
OVERWRITE_REPLACE = "replace"

# -------------------------------
# FASTA record container
# -------------------------------
@dataclass
class FastaEntry:
    """A single FASTA record: its identifier and its (possibly gapped) sequence."""

    name: str      # identifier; read_fasta stores the first token of the header line
    sequence: str  # sequence characters exactly as stored by the caller

# -------------------------------
# Atomic writers
#   - Data goes to a uniquely named sibling temp file, is fsync'ed, then
#     os.replace()d over the target, so readers never see a truncated file
#   - Implement overwrite policy strictly
# -------------------------------
def _should_write(target: Path, overwrite: str) -> bool:
    """Apply the overwrite policy to *target*.

    Returns:
        True if the caller should (over)write *target*, False if it should skip.

    Raises:
        ValueError: *overwrite* is not one of the known policies (previously an
            unknown value silently behaved like "replace").

    Exits the process with status 1 under the "fail" policy when *target* exists.
    """
    if overwrite not in (OVERWRITE_SKIP, OVERWRITE_FAIL, OVERWRITE_REPLACE):
        raise ValueError(f"Unknown overwrite policy: {overwrite!r}")
    if not target.exists():
        return True
    if overwrite == OVERWRITE_SKIP:
        return False
    if overwrite == OVERWRITE_FAIL:
        print(f"❌ ERROR: Refusing to overwrite existing file: {target}", file=sys.stderr)
        sys.exit(1)
    return True  # OVERWRITE_REPLACE


def _fsync_dir(directory: Path) -> None:
    """Best-effort fsync of *directory* so a completed rename is itself durable.

    Some platforms (e.g. Windows) cannot open directories; there the OSError is
    ignored on purpose — the file contents were already fsync'ed.
    """
    try:
        fd = os.open(str(directory), os.O_RDONLY)
    except OSError:
        return
    try:
        os.fsync(fd)
    except OSError:
        pass  # deliberate best-effort, see docstring
    finally:
        os.close(fd)


def _atomic_write_bytes(data: bytes, target: Path, overwrite: str = OVERWRITE_SKIP) -> None:
    """Atomically write *data* to *target*, honouring the overwrite policy.

    Args:
        data: exact bytes to store.
        target: destination path; missing parent directories are created.
        overwrite: "skip", "fail" or "replace" (see _should_write).
    """
    target = Path(target)
    target.parent.mkdir(parents=True, exist_ok=True)
    if not _should_write(target, overwrite):
        return

    # Unique temp name in the SAME directory: concurrent runs never share a temp
    # file, and same-directory placement keeps os.replace() an atomic rename.
    tmp = target.with_name(f".{target.name}.{os.getpid()}.{os.urandom(4).hex()}.tmp")
    fh = tmp.open("xb")  # exclusive create: never clobber a file we did not make
    try:
        with fh:
            fh.write(data)
            fh.flush()
            os.fsync(fh.fileno())
        os.replace(tmp, target)
    except BaseException:
        # Do not leave half-written temp files behind (also on Ctrl-C).
        try:
            tmp.unlink()
        except OSError:
            pass
        raise
    _fsync_dir(target.parent)


def _atomic_write_text(data: str, target: Path, overwrite: str = OVERWRITE_SKIP) -> None:
    """Atomically write *data* as UTF-8 with no newline translation.

    Equivalent to opening in text mode with newline="\\n": the string is encoded
    verbatim and handed to _atomic_write_bytes.
    """
    _atomic_write_bytes(data.encode("utf-8"), target, overwrite=overwrite)

# -------------------------------
# FASTA I/O
#   - read_fasta: parses FASTA, storing name + sequence, stripping spaces, uppercasing
#   - write_fasta: dumps all entries atomically
# -------------------------------
def read_fasta(filename: Path) -> List[FastaEntry]:
    """Parse a FASTA file into FastaEntry records.

    - The record name is the first whitespace-delimited token of the header
      (text after '>'); a bare '>' yields an empty name.
    - Sequence lines have ALL whitespace removed (spaces, tabs, and the '\\r' of
      CRLF files) and are uppercased. Blank lines are ignored.

    Exits the process with status 1 (message on stderr) if the path is not a
    regular file, sequence data appears before the first header, any record is
    longer than MAX_SEQ_LENGTH, there are more than MAX_NUM_SEQUENCES records,
    or no records were read.
    """
    filename = Path(filename)
    if not filename.is_file():
        print(f"❌ ERROR: Unable to open input file: {filename}", file=sys.stderr)
        sys.exit(1)

    entries: List[FastaEntry] = []
    current_name: Optional[str] = None
    current_seq_chunks: List[str] = []

    def _flush() -> None:
        # Validate and store the record accumulated so far (no-op before the first header).
        if current_name is None:
            return
        seq = "".join(current_seq_chunks)
        if len(seq) > MAX_SEQ_LENGTH:
            print("❌ ERROR: Sequence too long!", file=sys.stderr)
            sys.exit(1)
        entries.append(FastaEntry(current_name, seq))
        # Checked for every record, including the final one.
        if len(entries) > MAX_NUM_SEQUENCES:
            print("❌ ERROR: Too many sequences!", file=sys.stderr)
            sys.exit(1)

    with filename.open("r", encoding="utf-8") as fh:
        for lineno, raw in enumerate(fh, start=1):
            line = raw.rstrip("\r\n")
            if not line.strip():
                continue
            if line.startswith(">"):
                _flush()
                header = line[1:].strip()
                current_name = header.split()[0] if header else ""
                current_seq_chunks = []
            else:
                if current_name is None:
                    # Previously such lines were silently discarded.
                    print(
                        f"❌ ERROR: Sequence data before the first FASTA header in {filename} (line {lineno})",
                        file=sys.stderr,
                    )
                    sys.exit(1)
                current_seq_chunks.append("".join(line.split()).upper())
    _flush()

    if not entries:
        print(f"❌ ERROR: No sequences read from {filename}", file=sys.stderr)
        sys.exit(1)
    return entries

def write_fasta(entries: List[FastaEntry], out_path: Path, overwrite: str) -> None:
    """Write *entries* as unwrapped FASTA (header line + one sequence line each).

    The whole file is built in memory and stored through the atomic text writer
    under the given overwrite policy.
    """
    payload = "".join(f">{rec.name}\n{rec.sequence}\n" for rec in entries)
    _atomic_write_text(payload, out_path, overwrite=overwrite)

# -------------------------------
# Alignment helpers
#   - find_target_bounds: locate first and last non-gap indices in reference
#   - remove_gaps: strip '-' characters
# -------------------------------
def find_target_bounds(entries: List[FastaEntry], reference_name: str) -> Tuple[int, int]:
    """Return the (first, last) column indices of non-'-' characters in the reference row.

    Only the first entry named *reference_name* is examined. Returns (-1, -1)
    when no such entry exists or its sequence is empty or consists only of '-'.
    """
    reference = next((rec for rec in entries if rec.name == reference_name), None)
    if reference is None:
        return -1, -1

    row = reference.sequence
    if not row.strip("-"):
        return -1, -1

    leading_gaps = len(row) - len(row.lstrip("-"))
    last_residue = len(row.rstrip("-")) - 1
    return leading_gaps, last_residue

def remove_gaps(entries: List[FastaEntry]) -> List[FastaEntry]:
    """Return new records with every '-' deleted; names are kept, inputs untouched."""
    drop_dashes = str.maketrans("", "", "-")
    return [FastaEntry(rec.name, rec.sequence.translate(drop_dashes)) for rec in entries]

# -------------------------------
# MAFFT runner
#   - Checks overwrite policy before execution
#   - Captures stdout/stderr, times out safely
#   - Saves stderr to <output>.stderr.txt on error or warnings
# -------------------------------
def _save_mafft_log(data: bytes, log_path: Path, overwrite: str) -> None:
    """Persist MAFFT's stderr next to its output without ever aborting the run.

    The log is diagnostic only: if it already exists and the policy is not
    "replace", a warning is printed and the old log is kept, instead of the
    skip/exit that would otherwise mask the real MAFFT outcome.
    """
    if log_path.exists() and overwrite != OVERWRITE_REPLACE:
        print(
            f"⚠️ WARNING: Keeping existing MAFFT log {log_path} (overwrite policy '{overwrite}'); "
            "it may be from an earlier run.",
            file=sys.stderr,
        )
        return
    _atomic_write_bytes(data, log_path, overwrite=overwrite)


def run_mafft(input_fa: Path, output_fa: Path, overwrite: str, timeout_s: int, mafft_args: str = "") -> None:
    """Run ``mafft --auto [mafft_args] <input_fa>`` and store the alignment at *output_fa*.

    Args:
        input_fa: FASTA file to align.
        output_fa: destination for MAFFT's stdout (the alignment), written atomically.
        overwrite: "skip" / "fail" / "replace" for *output_fa*. Under "skip" an
            existing *output_fa* is reused and MAFFT is not run.
        timeout_s: seconds before MAFFT is killed; 0 or negative disables the timeout.
        mafft_args: extra arguments, split with shell-like quoting (no shell is spawned).

    MAFFT's stderr is echoed on failure and saved as ``<output_fa name>.stderr.txt``
    whenever it is non-empty. Exits the process with status 1 on any failure.
    """
    input_fa = Path(input_fa)
    output_fa = Path(output_fa)
    output_fa.parent.mkdir(parents=True, exist_ok=True)

    if output_fa.exists():
        if overwrite == OVERWRITE_SKIP:
            # Not silent: a reused alignment may stem from a different input.
            print(f"⏭️ {output_fa} already exists (overwrite=skip); reusing it without running MAFFT.")
            return
        if overwrite == OVERWRITE_FAIL:
            print(f"❌ ERROR: Refusing to overwrite existing file: {output_fa}", file=sys.stderr)
            sys.exit(1)

    if not input_fa.is_file():
        print(f"❌ ERROR: Input FASTA not found: {input_fa}", file=sys.stderr)
        sys.exit(1)

    mafft_exe = shutil.which("mafft")
    if mafft_exe is None:
        print("❌ ERROR: MAFFT not found in PATH.", file=sys.stderr)
        sys.exit(1)

    try:
        extra_args = shlex.split(mafft_args) if mafft_args else []
    except ValueError as exc:  # e.g. "No closing quotation"
        print(f"❌ ERROR: Could not parse --mafft_args {mafft_args!r}: {exc}", file=sys.stderr)
        sys.exit(1)

    print(f"🔄 Running second MAFFT on {input_fa.name}...")
    # Run exactly the executable located above.
    cmd = [mafft_exe, "--auto", *extra_args, str(input_fa)]
    timeout = timeout_s if timeout_s and timeout_s > 0 else None

    try:
        proc = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=False,
            check=False,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        print(f"❌ ERROR: MAFFT timed out after {timeout_s} seconds.", file=sys.stderr)
        sys.exit(1)
    except (OSError, ValueError, subprocess.SubprocessError) as e:
        print(f"❌ ERROR: Failed to run MAFFT: {e}", file=sys.stderr)
        sys.exit(1)

    stderr_bytes = proc.stderr or b""
    alignment = proc.stdout or b""
    log_path = output_fa.with_name(output_fa.name + ".stderr.txt")

    if proc.returncode != 0:
        # Echo MAFFT's diagnostics BEFORE touching the log file, so a log-writing
        # problem can never hide the real error.
        sys.stderr.write(stderr_bytes.decode(errors="replace"))
        _save_mafft_log(stderr_bytes, log_path, overwrite)
        print("❌ ERROR: Second MAFFT run failed.", file=sys.stderr)
        sys.exit(1)

    if not alignment.strip():
        sys.stderr.write(stderr_bytes.decode(errors="replace"))
        _save_mafft_log(stderr_bytes, log_path, overwrite)
        print("❌ ERROR: MAFFT exited successfully but produced no alignment.", file=sys.stderr)
        sys.exit(1)

    # Success: write alignment (stdout)
    _atomic_write_bytes(alignment, output_fa, overwrite=overwrite)

    # Save warnings/progress messages if any
    if stderr_bytes:
        _save_mafft_log(stderr_bytes, log_path, overwrite)

# -------------------------------
# Alarm + summary reporting
#   - Alarm only prints when sequences violate length tolerance
#   - Summary gives min/max/target lengths and counts
# -------------------------------
def _alarm_if_violations(ungapped: List[FastaEntry], reference_name: str, tol: float) -> None:
    """Print a length report only if some sequence lies outside the tolerance window.

    The window is [floor((1 - tol) * T), ceil((1 + tol) * T)], where T is the
    length of the first record named *reference_name*. The bounds are rounded
    outward, so the check is slightly more permissive than the exact interval.

    Exits the process with status 1 if the reference record is missing.
    """
    target_len = next((len(e.sequence) for e in ungapped if e.name == reference_name), None)
    if target_len is None:
        print(f"❌ ERROR: Reference sequence '{reference_name}' not found.", file=sys.stderr)
        sys.exit(1)

    lo = math.floor((1.0 - tol) * target_len)
    hi = math.ceil((1.0 + tol) * target_len)

    offenders = [(e.name, len(e.sequence)) for e in ungapped if not lo <= len(e.sequence) <= hi]
    if not offenders:
        return

    print("\n=== Length Check ===")
    # ":g" avoids the truncation of int(tol * 100), which prints 0.29 as "28".
    print(f"Target length: {target_len} (allowed: {lo}..{hi}; tol=±{tol * 100:g}%)")
    print("Sequences OUT of range:")
    for name, seq_len in offenders:
        print(f"  - {name}: {seq_len}")

def summarize(entries_ungapped: List[FastaEntry], reference_name: str) -> None:
    """Print the shortest, longest and reference lengths plus the record count.

    The reference ("Target") is reported as 0 when no record is named
    *reference_name*. An empty list prints the count only, instead of raising
    ValueError from min()/max().
    """
    lengths = [len(e.sequence) for e in entries_ungapped]
    target_len = next((len(e.sequence) for e in entries_ungapped if e.name == reference_name), 0)

    print("\nSummary for Second MAFFT Run:")
    if lengths:
        print(f"1. Shortest: {min(lengths)}, Longest: {max(lengths)}, Target: {target_len}")
    else:
        print("1. No sequences to summarize.")
    print(f"2. Total sequences: {len(lengths)}")

# -------------------------------
# CLI entry point
# -------------------------------
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the second-MAFFT step."""
    parser = argparse.ArgumentParser(description="Second MAFFT run + crop/ungap, FS hardened.")

    # Input/output arguments
    parser.add_argument("--input", default="NAMEcroppedV1.fasta", help="Input FASTA for the second MAFFT run")
    parser.add_argument("--temp_aligned", default=None, help="Temporary aligned FASTA path")
    parser.add_argument("--aligned_out", default=None, help="Cropped aligned (still gapped) output path")
    parser.add_argument("--cropped_ungapped_out", default=None, help="Ungapped (cropped) output path")
    parser.add_argument("--output_dir", default=None, help="Directory for outputs (default: create a new temp folder)")

    # Filesystem safety knobs
    parser.add_argument("--overwrite", default=OVERWRITE_FAIL, choices=[OVERWRITE_SKIP, OVERWRITE_FAIL, OVERWRITE_REPLACE], help="Overwrite policy for outputs")
    parser.add_argument("--mafft_timeout", type=int, default=3600, help="Timeout in seconds for MAFFT")
    parser.add_argument("--mafft_args", default="", help="Extra MAFFT args appended after --auto")

    # Biology knobs
    parser.add_argument("--reference_name", default="target_sequence", help="Header name of reference sequence")
    parser.add_argument("--length_tolerance", type=float, default=0.30, help="Allowed ±fraction around target length")
    return parser


def main():
    """CLI entry point: align, crop to the reference span, write outputs, report lengths."""
    p = _build_arg_parser()

    # parse_known_args tolerates extra arguments injected by notebook kernels
    # (e.g. Colab's "-f <kernel.json>"). Unrecognized arguments are reported so
    # that a mistyped option does not silently fall back to its default.
    args, unknown = p.parse_known_args()
    if unknown:
        print(f"⚠️ WARNING: Ignoring unrecognized arguments: {' '.join(unknown)}", file=sys.stderr)

    # A negative/NaN tolerance makes every sequence "out of range"; inf makes math.floor raise.
    tol = args.length_tolerance
    if not (math.isfinite(tol) and tol >= 0.0):
        p.error("--length_tolerance must be a finite, non-negative fraction (e.g. 0.30)")

    input_fa = Path(args.input).expanduser().resolve()

    # Prepare output directory
    if args.output_dir:
        outdir = Path(args.output_dir).expanduser().resolve()
        outdir.mkdir(parents=True, exist_ok=True)
    else:
        outdir = Path(tempfile.mkdtemp(prefix="second_mafft_"))

    def _resolve(explicit: Optional[str], default_name: str) -> Path:
        # Explicit CLI paths win; otherwise the file goes into outdir.
        return Path(explicit).expanduser().resolve() if explicit else outdir / default_name

    temp_aligned = _resolve(args.temp_aligned, "temp_aligned2.fasta")
    aligned_out = _resolve(args.aligned_out, "NAMEalignedV2.fasta")
    cropped_ungapped_out = _resolve(args.cropped_ungapped_out, "NAMEcroppedV2.fasta")

    # Guardrails: outputs must differ from the input and from each other.
    outputs = (temp_aligned, aligned_out, cropped_ungapped_out)
    for pth in outputs:
        if pth == input_fa:
            print(f"❌ ERROR: Output path {pth} must differ from input {input_fa}", file=sys.stderr)
            sys.exit(1)
    # Shared paths would make a later write replace, skip, or abort on an earlier one.
    if len(set(outputs)) != len(outputs):
        print(
            "❌ ERROR: --temp_aligned, --aligned_out and --cropped_ungapped_out must all be different paths",
            file=sys.stderr,
        )
        sys.exit(1)

    def _write_and_report(records: List[FastaEntry], path: Path, label: str) -> None:
        # Under --overwrite skip an existing file is left untouched; report that
        # instead of claiming it was written.
        existed = path.exists()
        write_fasta(records, path, overwrite=args.overwrite)
        if existed and args.overwrite == OVERWRITE_SKIP:
            print(f"⏭️ {label} already exists, left unchanged: {path}")
        else:
            print(f"✅ {label} written to: {path}")

    # Run MAFFT
    run_mafft(input_fa, temp_aligned, overwrite=args.overwrite, timeout_s=args.mafft_timeout, mafft_args=args.mafft_args)

    # Load alignment
    entries = read_fasta(temp_aligned)

    # Compute cropping bounds
    start, end = find_target_bounds(entries, reference_name=args.reference_name)
    if start == -1 or end == -1:
        print("❌ ERROR: Could not determine cropping boundaries.", file=sys.stderr)
        sys.exit(1)
    ref_len = end - start + 1
    if ref_len <= 0:  # defensive; find_target_bounds returns start <= end
        print("❌ ERROR: Invalid crop window.", file=sys.stderr)
        sys.exit(1)
    print(f"🔍 Cropping boundaries: start={start}, end={end}, length={ref_len}")

    # Crop every row to the reference's non-gap column span (inclusive end).
    cropped_entries = [FastaEntry(e.name, e.sequence[start:end + 1]) for e in entries]
    _write_and_report(cropped_entries, aligned_out, "Cropped aligned file")

    ungapped = remove_gaps(cropped_entries)
    _write_and_report(ungapped, cropped_ungapped_out, "Ungapped (cropped) file")

    # Length alarm (prints only on violations), then summary
    _alarm_if_violations(ungapped, reference_name=args.reference_name, tol=tol)
    summarize(ungapped, reference_name=args.reference_name)
    print(f"\n(All outputs written under: {outdir})")


if __name__ == "__main__":
    # main() parses with parse_known_args, so extra arguments injected by
    # notebook environments such as Colab do not abort the run.
    main()