# Stanford RNA3D submit notebook (Fase 1 + Fase 2 full pipeline)

Pipeline completo com contratos estritos (fail-fast), sem fallback silencioso.

In [None]:
import csv
import hashlib
import os
import shutil
import subprocess
import sys
import zipfile
from pathlib import Path
from typing import NoReturn

import polars as pl

SCRIPT_LOC = "submission_notebook_phase1_phase2_full_v2"  # tag used in log lines, _die locations and the run directory name
N_MODELS = 5  # predictions per target: passed to predict-tbm --n-models and used to replicate DRfold2 output
TOP_K = 128  # templates retrieved per target (retrieve-templates-latent --top-k)
TEMPLATE_SCORE_THRESHOLD = 0.65  # NOTE(review): not referenced anywhere in the code shown here
LEN_THRESHOLD = 68  # NOTE(review): not referenced anywhere in the code shown here


def _die(stage: str, where: str, cause: str, impact, examples) -> None:
    examples_text = ",".join(str(x) for x in examples) if examples else "-"
    raise RuntimeError(f"[{stage}] [{where}] {cause} | impacto={impact} | exemplos={examples_text}")


def _tail(stdout: str | None, stderr: str | None, n: int = 20) -> list[str]:
    text = ((stdout or "") + "\n" + (stderr or "")).strip()
    if not text:
        return []
    return text.splitlines()[-n:]


def _run(cmd: list[str], env: dict[str, str], cwd: Path) -> None:
    """Run ``cmd`` in ``cwd`` with ``env`` and abort the pipeline if it fails.

    The command line is echoed before execution and output is captured. On a
    non-zero exit, ``_die`` is called with the exit code and the last 20
    combined output lines. On success, only the stripped stdout is printed;
    stderr is discarded.
    """
    where = f"{SCRIPT_LOC}:run"
    print("[RUN]", " ".join(cmd))
    result = subprocess.run(cmd, env=env, cwd=str(cwd), text=True, capture_output=True)
    if result.returncode:
        _die("PIPELINE", where, "comando falhou", result.returncode, _tail(result.stdout, result.stderr, 20))
    stdout_text = (result.stdout or "").strip()
    if stdout_text:
        print(stdout_text)

def _collect_datasets(input_root: Path) -> list[Path]:
    """Return the directories directly under ``input_root``, sorted by path.

    Non-directory entries are ignored. Aborts via ``_die`` if ``input_root``
    does not exist or contains no directories.
    """
    where = f"{SCRIPT_LOC}:collect_datasets"
    if not input_root.exists():
        _die("LOAD", where, "diretorio /kaggle/input ausente", 1, [str(input_root)])
    mounted = sorted(filter(Path.is_dir, input_root.iterdir()))
    if not mounted:
        _die("LOAD", where, "nenhum dataset montado em /kaggle/input", 0, [])
    return mounted


def _find_by_filename(datasets: list[Path], filename: str) -> list[Path]:
    hits: list[Path] = []
    for ds in datasets:
        hits.extend([path for path in ds.rglob(filename) if path.is_file()])
    return hits


def _find_by_pattern(datasets: list[Path], pattern: str) -> list[Path]:
    hits: list[Path] = []
    for ds in datasets:
        hits.extend([path for path in ds.rglob(pattern) if path.is_file()])
    return hits


def _require_single(label: str, candidates: list[Path], *, stage: str, where: str) -> Path:
    unique: list[Path] = []
    seen: set[str] = set()
    for item in candidates:
        key = str(item.resolve())
        if key not in seen:
            seen.add(key)
            unique.append(item.resolve())
    if not unique:
        _die(stage, where, f"ativo obrigatorio ausente: {label}", 1, [label])
    if len(unique) > 1:
        _die(stage, where, f"ativo ambiguo para {label}", len(unique), [str(path) for path in unique[:8]])
    return unique[0]


def _require_any(label: str, candidates: list[Path], *, stage: str, where: str) -> list[Path]:
    unique: list[Path] = []
    seen: set[str] = set()
    for item in candidates:
        key = str(item.resolve())
        if key not in seen:
            seen.add(key)
            unique.append(item.resolve())
    if not unique:
        _die(stage, where, f"ativo obrigatorio ausente: {label}", 1, [label])
    return unique


def _src_supports_command(src_root: Path, command: str) -> bool:
    env = os.environ.copy()
    env["PYTHONPATH"] = str(src_root) + ((":" + env["PYTHONPATH"]) if env.get("PYTHONPATH") else "")
    proc = subprocess.run([sys.executable, "-m", "rna3d_local", command, "--help"], env=env, cwd=str(src_root.parent), text=True, capture_output=True)
    return proc.returncode == 0


def _find_src_root(datasets: list[Path], unpack_root: Path) -> Path:
    """Locate the single ``src`` directory providing a Phase 1+2 compatible ``rna3d_local``.

    There are two kinds of candidate:

    * ``<dataset>/src`` for every dataset that ships ``src/rna3d_local/cli.py``.
    * The ``src`` tree from a ``src_reboot.zip``. The zip must be unique across
      the datasets and is extracted into ``unpack_root``, which must be missing
      or empty.

    Candidates are deduplicated by resolved path. Each is then probed with
    ``rna3d_local build-embedding-index --help``, and exactly one must succeed.

    Raises:
        RuntimeError: (via ``_die``) in any of these cases:
            * ``unpack_root`` is non-empty.
            * The zip lacks ``src/rna3d_local/cli.py``.
            * No candidate exists.
            * No candidate is compatible.
            * Several candidates are compatible.
    """
    where = f"{SCRIPT_LOC}:find_src_root"
    cli_marker = Path("rna3d_local") / "cli.py"

    found: list[Path] = []
    for dataset_dir in datasets:
        if (dataset_dir / "src" / cli_marker).exists():
            found.append(dataset_dir / "src")

    zip_hits = _find_by_filename(datasets, "src_reboot.zip")
    if zip_hits:
        archive_path = _require_single("src_reboot.zip", zip_hits, stage="LOAD", where=where)
        if unpack_root.exists() and any(unpack_root.iterdir()):
            _die("LOAD", where, "diretorio de unpack do src nao-vazio", 1, [str(unpack_root)])
        unpack_root.mkdir(parents=True, exist_ok=True)
        with zipfile.ZipFile(str(archive_path), "r") as bundle:
            bundle.extractall(str(unpack_root))
        unpacked_src = unpack_root / "src"
        if not (unpacked_src / cli_marker).exists():
            _die("LOAD", where, "src_reboot.zip sem src/rna3d_local/cli.py", 1, [str(archive_path)])
        found.append(unpacked_src)

    if not found:
        _die("LOAD", where, "nenhum src/rna3d_local/cli.py encontrado", 1, [])

    distinct: dict[str, Path] = {}
    for candidate in found:
        resolved = candidate.resolve()
        distinct.setdefault(str(resolved), resolved)
    roots = list(distinct.values())

    # Probe each root's CLI; only a tree exposing the Phase 1+2 subcommands qualifies.
    usable = [root for root in roots if _src_supports_command(root, "build-embedding-index")]
    if not usable:
        _die("LOAD", where, "nenhum src compativel com CLI da Fase 1+2", len(roots), [str(path) for path in roots[:8]])
    if len(usable) > 1:
        _die("LOAD", where, "multiplos src compativeis; selecao ambigua", len(usable), [str(path) for path in usable[:8]])
    return usable[0]


def _find_template_assets(datasets: list[Path]) -> tuple[Path, Path]:
    """Return the ``(template_index.parquet, templates.parquet)`` paths.

    Preferred datasets are checked first, in order. The first one containing
    both files wins, and each file must still be unique within it. If no
    preferred dataset qualifies, each file must be unique across all mounted
    datasets.
    """
    where = f"{SCRIPT_LOC}:find_template_assets"
    index_name = "template_index.parquet"
    table_name = "templates.parquet"

    # Prefer deterministic sources to avoid ambiguity when multiple datasets ship template_db.
    for preferred in ("stanford-rna3d-reboot-src-v2", "ribonanza-quickstart-3d-templates"):
        dataset_dir = next((d for d in datasets if d.name == preferred), None)
        if dataset_dir is None:
            continue
        index_hits = _find_by_filename([dataset_dir], index_name)
        table_hits = _find_by_filename([dataset_dir], table_name)
        if not index_hits or not table_hits:
            continue
        index_path = _require_single(index_name, index_hits, stage="LOAD", where=where)
        table_path = _require_single(table_name, table_hits, stage="LOAD", where=where)
        return index_path, table_path

    index_path = _require_single(index_name, _find_by_filename(datasets, index_name), stage="LOAD", where=where)
    table_path = _require_single(table_name, _find_by_filename(datasets, table_name), stage="LOAD", where=where)
    return index_path, table_path


def _find_quickstart_file(datasets: list[Path]) -> Path:
    """Return the unique quickstart CSV/parquet file across all datasets.

    Three glob patterns are searched in order. The combined hits must resolve
    to exactly one file, otherwise ``_die`` is called.
    """
    hits: list[Path] = []
    for pattern in ("*QUICK_START*.csv", "*quickstart*.csv", "*quickstart*.parquet"):
        hits.extend(_find_by_pattern(datasets, pattern))
    return _require_single("quickstart", hits, stage="LOAD", where=f"{SCRIPT_LOC}:find_quickstart")


def _materialize_phase2_assets(datasets: list[Path], dst: Path) -> Path:
    """Locate the unique Phase 2 assets root among the mounted datasets.

    A directory qualifies as a root when it contains all three marker
    checkpoints (rnapro, chai1, boltz1) under ``models/``. For each dataset,
    the dataset directory itself and three nested ``kaggle_assets`` layouts
    are probed. Roots are deduplicated by resolved path, and exactly one must
    remain.

    NOTE: despite the name, nothing is copied. ``dst`` is accepted but not
    used by this body.
    """
    where = f"{SCRIPT_LOC}:materialize_phase2_assets"
    # Layout expected by the rna3d_local offline runners: <root>/models/{rnapro,chai1,boltz1}/...
    markers = {
        "rnapro": Path("models/rnapro/rnapro-public-best-500m.ckpt"),
        "chai1": Path("models/chai1/models_v2/trunk.pt"),
        "boltz1": Path("models/boltz1/boltz1_conf.ckpt"),
    }
    found: dict[str, Path] = {}
    for ds in datasets:
        nested = ds / "export" / "kaggle_assets"
        for root in (ds, nested, ds / "kaggle_assets", nested / "export" / "kaggle_assets"):
            if all((root / marker).exists() for marker in markers.values()):
                resolved = root.resolve()
                found.setdefault(str(resolved), resolved)
    roots = list(found.values())
    if not roots:
        _die("LOAD", where, "assets phase2 nao encontrados (markers ausentes)", 1, [str(marker) for marker in markers.values()])
    if len(roots) > 1:
        _die("LOAD", where, "assets phase2 ambiguos (multiplos roots)", len(roots), [str(p) for p in roots[:8]])
    return roots[0]

def _build_template_family_map(template_index_path: Path, out_path: Path) -> Path:
    """Write a placeholder template -> family map covering every indexed template.

    Every distinct, non-null ``template_uid`` (cast to string and sorted) gets
    the family label ``"unknown"``. The map is written as zstd-compressed
    parquet to ``out_path``.

    Returns:
        ``out_path``.

    Raises:
        RuntimeError: (via ``_die``) if the index has no ``template_uid``
            column or yields no templates. The emptiness check runs before the
            write, so a failed run does not leave an empty map on disk.
    """
    where = f"{SCRIPT_LOC}:build_template_family_map"
    df = pl.read_parquet(str(template_index_path))
    if "template_uid" not in df.columns:
        _die("PIPELINE", where, "template_index sem template_uid", 1, [str(template_index_path)])
    out = (
        df.select(pl.col("template_uid").cast(pl.Utf8))
        # A null uid cannot identify a template; keep it out of the map.
        .drop_nulls()
        .unique()
        .sort("template_uid")
        .with_columns(pl.lit("unknown").alias("family_label"))
    )
    if out.height == 0:
        _die("PIPELINE", where, "template_family_map vazio", 0, [str(out_path)])
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out.write_parquet(str(out_path), compression="zstd")
    return out_path


def _build_weak_labels(candidates_path: Path, out_path: Path) -> Path:
    """Derive binary weak labels from retrieval candidates and write them as parquet.

    For each ``target_id``, candidates are ordered by the first available
    ranking column:

    * ``final_score`` or ``cosine_score``: higher is better.
    * ``rank``: lower is better.

    The top candidate's ``template_uid`` gets ``label = 1.0``. Every other
    distinct ``(target_id, template_uid)`` pair gets ``0.0``.

    Returns:
        ``out_path``.

    Raises:
        RuntimeError: (via ``_die``) in any of these cases:
            * ``target_id`` or ``template_uid`` is missing.
            * No ranking column is present.
            * Fewer than 8 label rows are produced.
    """
    where = f"{SCRIPT_LOC}:build_weak_labels"
    df = pl.read_parquet(str(candidates_path))
    required = ["target_id", "template_uid"]
    missing = [c for c in required if c not in df.columns]
    if missing:
        _die("PIPELINE", where, "candidates sem colunas obrigatorias", len(missing), missing)

    score_col = next((c for c in ("final_score", "cosine_score", "rank") if c in df.columns), None)
    if score_col is None:
        _die("PIPELINE", where, "candidates sem coluna de ranking/score", 1, [str(df.columns[:8])])

    # rank ascends (best first), scores descend. polars puts nulls FIRST by
    # default even in a descending sort, which would let a row with a missing
    # score/rank become a target's top template; nulls_last=True prevents that.
    ordered = df.sort(
        ["target_id", score_col],
        descending=[False, score_col != "rank"],
        nulls_last=True,
    )

    # polars preserves row order within each group, so first() is the best-ranked template.
    top1 = ordered.group_by("target_id").agg(pl.first("template_uid").alias("top_template"))
    labels = (
        ordered.select("target_id", "template_uid")
        .unique()
        .join(top1, on="target_id", how="left")
        .with_columns((pl.col("template_uid") == pl.col("top_template")).cast(pl.Float64).alias("label"))
        .select("target_id", "template_uid", "label")
    )
    if labels.height < 8:
        _die("PIPELINE", where, "labels fracos insuficientes", labels.height, ["min=8"])
    out_path.parent.mkdir(parents=True, exist_ok=True)
    labels.write_parquet(str(out_path), compression="zstd")
    return out_path


def _assert_cli_commands(env: dict[str, str], repo_root: Path) -> None:
    """Fail fast unless every rna3d_local subcommand used by this notebook answers ``--help``.

    Each subcommand is probed in order with ``python -m rna3d_local <cmd> --help``
    under ``env`` and ``repo_root``. The first non-zero exit aborts via ``_die``
    with the exit code and the last 12 output lines.
    """
    where = f"{SCRIPT_LOC}:assert_cli"
    needed = (
        "build-embedding-index",
        "infer-description-family",
        "retrieve-templates-latent",
        "predict-tbm",
        "export-submission",
        "check-submission",
    )
    for subcommand in needed:
        probe = subprocess.run(
            [sys.executable, "-m", "rna3d_local", subcommand, "--help"],
            env=env,
            cwd=str(repo_root),
            text=True,
            capture_output=True,
        )
        if probe.returncode == 0:
            continue
        _die("ENV", where, f"comando ausente no pacote rna3d_local: {subcommand}", probe.returncode, _tail(probe.stdout, probe.stderr, 12))


import uuid

# --- Paths --------------------------------------------------------------------
comp_root = Path("/kaggle/input/stanford-rna-3d-folding-2")  # competition data mount
input_root = Path("/kaggle/input")
work_root = Path("/kaggle/working")
# A random 8-hex-char suffix gives every run a fresh directory (created below with exist_ok=False).
run_root = work_root / f"run_{SCRIPT_LOC}_{uuid.uuid4().hex[:8]}"
submission_path = work_root / "submission.csv"

sample_path = comp_root / "sample_submission.csv"
targets_path = comp_root / "test_sequences.csv"
if not sample_path.exists():
    _die("LOAD", f"{SCRIPT_LOC}:paths", "sample_submission.csv ausente", 1, [str(sample_path)])
if not targets_path.exists():
    _die("LOAD", f"{SCRIPT_LOC}:paths", "test_sequences.csv ausente", 1, [str(targets_path)])

run_root.mkdir(parents=True, exist_ok=False)

# --- Asset discovery ----------------------------------------------------------
datasets = _collect_datasets(input_root)
src_root = _find_src_root(datasets, run_root / "src_unpack")
template_index_path, templates_path = _find_template_assets(datasets)
template_family_map_path = _build_template_family_map(template_index_path, run_root / "template_family_map.parquet")

repo_root = src_root.parent
# NOTE(review): this notebook process never imports rna3d_local in the code
# shown; the subprocesses get the package via PYTHONPATH below.
if str(src_root) not in sys.path:
    sys.path.insert(0, str(src_root))

# Environment for rna3d_local subprocesses: src_root first on PYTHONPATH
# (os.pathsep is ":" on Linux/Kaggle), and numeric libraries limited to one
# thread unless already configured.
env = os.environ.copy()
env["PYTHONPATH"] = str(src_root) + ((os.pathsep + env["PYTHONPATH"]) if env.get("PYTHONPATH") else "")
env.setdefault("OMP_NUM_THREADS", "1")
env.setdefault("MKL_NUM_THREADS", "1")
env.setdefault("OPENBLAS_NUM_THREADS", "1")
env.setdefault("NUMEXPR_NUM_THREADS", "1")

_assert_cli_commands(env, repo_root)

print(f"[INFO] [{SCRIPT_LOC}] run_root={run_root}")
print(f"[INFO] [{SCRIPT_LOC}] src_root={src_root}")
print(f"[INFO] [{SCRIPT_LOC}] template_index={template_index_path}")
print(f"[INFO] [{SCRIPT_LOC}] templates={templates_path}")

# --- Intermediate artifacts (all inside this run's run_root) -------------------
desc_dir = run_root / "description_family"
emb_dir = run_root / "embedding"
# Assumes build-embedding-index writes template_embeddings.parquet into emb_dir — TODO confirm.
emb_path = emb_dir / "template_embeddings.parquet"
retrieval_path = run_root / "retrieval_candidates.parquet"
retrieval_tbm_path = run_root / "retrieval_candidates_tbm.parquet"
targets_tbm_path = run_root / "targets_tbm.csv"
targets_fallback_path = run_root / "targets_fallback.csv"
tbm_path = run_root / "tbm_predictions.parquet"
drfold_path = run_root / "drfold_predictions.parquet"
combined_path = run_root / "combined_predictions.parquet"

# Phase 1a: embed every template with the "mock" encoder (256-d, no ANN index built).
_run([sys.executable, "-m", "rna3d_local", "build-embedding-index", "--template-index", str(template_index_path), "--out-dir", str(emb_dir), "--embedding-dim", "256", "--encoder", "mock", "--ann-engine", "none"], env, repo_root)
# Phase 1b: infer description families for the targets with the "rules" backend,
# using the placeholder family map (every template labelled "unknown").
_run([sys.executable, "-m", "rna3d_local", "infer-description-family", "--targets", str(targets_path), "--out-dir", str(desc_dir), "--backend", "rules", "--template-family-map", str(template_family_map_path)], env, repo_root)
# Phase 1c: retrieve the top TOP_K templates per target by brute-force search over
# the same mock embeddings. The embedding / LLM-prior / sequence weights are passed
# as 0.70 / 0.20 / 0.10 (how they are combined is defined inside rna3d_local).
# Assumes infer-description-family wrote family_prior.parquet into desc_dir — TODO confirm.
_run([
    sys.executable, "-m", "rna3d_local", "retrieve-templates-latent",
    "--template-index", str(template_index_path),
    "--template-embeddings", str(emb_path),
    "--targets", str(targets_path),
    "--out", str(retrieval_path),
    "--top-k", str(TOP_K),
    "--encoder", "mock",
    "--embedding-dim", "256",
    "--ann-engine", "numpy_bruteforce",
    "--family-prior", str(desc_dir / "family_prior.parquet"),
    "--weight-embed", "0.70",
    "--weight-llm", "0.20",
    "--weight-seq", "0.10",
], env, repo_root)

# Coverage split. A target is routed to TBM only if at least one of its retrieved
# templates is (a) contiguous (every resid between min and max present) and
# (b) at least as long as the target. Target length excludes "|" separators.
# All other targets go to the DRfold2 fallback further below.
targets_df = pl.read_csv(targets_path)
if "target_id" not in targets_df.columns or "sequence" not in targets_df.columns:
    _die("LOAD", f"{SCRIPT_LOC}:targets", "targets sem colunas target_id/sequence", 1, [str(targets_df.columns)])
targets_df = targets_df.with_columns(pl.col("target_id").cast(pl.Utf8), pl.col("sequence").cast(pl.Utf8))
targets_len = targets_df.select(pl.col("target_id"), pl.col("sequence").str.replace_all(r"\|", "").str.len_chars().alias("target_len"))

# Per-template residue statistics (lazy): resid span length and contiguity.
tpl_stats = pl.scan_parquet(str(templates_path)).select(pl.col("template_uid").cast(pl.Utf8), pl.col("resid").cast(pl.Int32))
tpl_stats = tpl_stats.group_by("template_uid").agg(pl.col("resid").min().alias("min_resid"), pl.col("resid").max().alias("max_resid"), pl.col("resid").n_unique().alias("n_unique"))
tpl_stats = tpl_stats.with_columns((pl.col("max_resid") - pl.col("min_resid") + 1).cast(pl.Int32).alias("tpl_len"))
tpl_stats = tpl_stats.with_columns((pl.col("n_unique") == pl.col("tpl_len")).alias("contiguous"))

# Join distinct (target, template) retrieval pairs with target lengths and template stats.
retr_lf = pl.scan_parquet(str(retrieval_path)).select(pl.col("target_id").cast(pl.Utf8), pl.col("template_uid").cast(pl.Utf8)).unique()
supported = retr_lf.join(targets_len.lazy(), on="target_id", how="inner").join(tpl_stats, on="template_uid", how="inner")
supported = supported.filter(pl.col("contiguous") & (pl.col("tpl_len") >= pl.col("target_len"))).select(pl.col("target_id")).unique().collect()
supported_ids = set(supported.get_column("target_id").to_list())
all_ids = set(targets_len.get_column("target_id").to_list())
fallback_ids = sorted(all_ids - supported_ids)
print(f"[INFO] [{SCRIPT_LOC}] tbm_supported={len(supported_ids)} fallback={len(fallback_ids)}")
# Fail fast: the pipeline requires at least one TBM-covered target.
if len(supported_ids) == 0:
    _die("PIPELINE", f"{SCRIPT_LOC}:coverage", "nenhum alvo com template valido para TBM", 0, [])

# Write filtered targets/retrieval for TBM.
# NOTE(review): retrieval rows are filtered by target only. A supported target
# keeps ALL its retrieved templates, including short or non-contiguous ones.
targets_df.filter(pl.col("target_id").is_in(sorted(supported_ids))).write_csv(targets_tbm_path)
targets_df.filter(pl.col("target_id").is_in(fallback_ids)).write_csv(targets_fallback_path)
pl.read_parquet(retrieval_path).filter(pl.col("target_id").is_in(sorted(supported_ids))).write_parquet(retrieval_tbm_path)

# Phase 2a: template-based modelling for the covered targets, N_MODELS per target.
_run([sys.executable, "-m", "rna3d_local", "predict-tbm", "--retrieval", str(retrieval_tbm_path), "--templates", str(templates_path), "--targets", str(targets_tbm_path), "--out", str(tbm_path), "--n-models", str(N_MODELS)], env, repo_root)

def _normalize_seq(seq: str) -> str:
    raw = str(seq or "").strip().upper().replace("T", "U")
    return "".join(ch for ch in raw if ch not in {"|", " ", "\t", "\n", "\r"})

def _parse_c1prime_coords(pdb_path: Path, *, target_id: str, seq: str):
    coords = {}
    try:
        for line in pdb_path.read_text("utf-8", errors="replace").splitlines():
            if not line.startswith("ATOM"):
                continue
            atom = line[12:16].strip()
            if atom != "C1'":
                continue
            try:
                resid = int(line[22:26].strip())
                x = float(line[30:38].strip())
                y = float(line[38:46].strip())
                z = float(line[46:54].strip())
            except Exception:
                continue
            coords[resid] = (x, y, z)
    except Exception as exc:
        _die("DRFOLD2", f"{SCRIPT_LOC}:parse_pdb", "falha ao ler pdb", 1, [target_id, str(pdb_path), f"{type(exc).__name__}:{exc}"])
    out = []
    if len(seq) == 0:
        _die("DRFOLD2", f"{SCRIPT_LOC}:parse_pdb", "sequencia vazia", 1, [target_id])
    for i,ch in enumerate(seq, start=1):
        if i not in coords:
            _die("DRFOLD2", f"{SCRIPT_LOC}:parse_pdb", "C1' ausente", 1, [f"{target_id}:{i}", str(pdb_path)])
        x, y, z = coords[i]
        out.append((i, ch, float(x), float(y), float(z)))
    return out

# DRfold2 fallback for targets without valid TBM templates.
# Targets are processed one at a time. Only <outdir>/relax/model_1.pdb is read,
# and its C1' coordinates are replicated for every model_id in 1..N_MODELS, so all
# N_MODELS predictions for a fallback target are identical.
drfold_parts = []
if fallback_ids:
    # DRfold_infer.py must be unique across the mounted datasets; it runs with its own directory as cwd.
    drfold_candidates = _find_by_filename(datasets, "DRfold_infer.py")
    drfold_script = _require_single("DRfold_infer.py", drfold_candidates, stage="DRFOLD2", where=f"{SCRIPT_LOC}:drfold")
    drfold_root = drfold_script.parent

    # NOTE(review): effectively redundant. drfold_candidates come from
    # _find_by_filename, which keeps only paths whose is_file() was True.
    if not drfold_script.exists():
        _die("DRFOLD2", f"{SCRIPT_LOC}:drfold", "DRfold_infer.py ausente", 1, [str(drfold_script)])
    drfold_work = run_root / "drfold2"
    drfold_work.mkdir(parents=True, exist_ok=True)
    for tid in fallback_ids:
        # .item() requires exactly one row per target_id in the targets CSV.
        seq_raw = targets_df.filter(pl.col("target_id") == tid).get_column("sequence").item()
        seq = _normalize_seq(seq_raw)
        fasta = drfold_work / f"{tid}.fa"
        fasta.write_text(f">{tid}\n{seq}\n", encoding="utf-8")
        outdir = drfold_work / tid
        outdir.mkdir(parents=True, exist_ok=False)
        cmd = [sys.executable, str(drfold_script), str(fasta), str(outdir)]
        print("[RUN]", " ".join(cmd))
        # No env= here: DRfold2 inherits this process's environment, not the
        # thread-limited `env` built for rna3d_local. There is also no timeout.
        proc = subprocess.run(cmd, text=True, capture_output=True, cwd=str(drfold_root))
        if proc.returncode != 0:
            _die("DRFOLD2", f"{SCRIPT_LOC}:drfold", "runner falhou", proc.returncode, _tail(proc.stdout, proc.stderr, 20))
        # Assumes DRfold2 writes its relaxed top model here — the check below fails fast otherwise.
        pdb = outdir / "relax" / "model_1.pdb"
        if not pdb.exists():
            _die("DRFOLD2", f"{SCRIPT_LOC}:drfold", "model_1.pdb ausente", 1, [tid, str(pdb)])
        coords = _parse_c1prime_coords(pdb, target_id=tid, seq=seq)
        rows = []
        for model_id in range(1, int(N_MODELS) + 1):
            for resid, resname, x, y, z in coords:
                rows.append({"target_id": tid, "model_id": model_id, "resid": resid, "resname": resname, "x": x, "y": y, "z": z})
        part_path = drfold_work / f"{tid}_pred.parquet"
        pl.DataFrame(rows).write_parquet(part_path)
        drfold_parts.append(part_path)

# Merge TBM and DRfold2 predictions into one parquet, then export the submission
# and validate it against sample_submission.csv.
if drfold_parts:
    pl.concat([pl.read_parquet(p) for p in drfold_parts], how="vertical_relaxed").write_parquet(drfold_path)
    # NOTE(review): "vertical_relaxed" only coerces dtypes. It assumes predict-tbm
    # output has the same columns as the DRfold2 rows (target_id, model_id, resid,
    # resname, x, y, z) — TODO confirm.
    combined = pl.concat([pl.read_parquet(tbm_path), pl.read_parquet(drfold_path)], how="vertical_relaxed")
else:
    combined = pl.read_parquet(tbm_path)
combined.write_parquet(combined_path)

_run([sys.executable, "-m", "rna3d_local", "export-submission", "--sample", str(sample_path), "--predictions", str(combined_path), "--out", str(submission_path)], env, repo_root)
_run([sys.executable, "-m", "rna3d_local", "check-submission", "--sample", str(sample_path), "--submission", str(submission_path)], env, repo_root)

# Log a SHA-256 fingerprint of the final submission for traceability.
sha = hashlib.sha256(submission_path.read_bytes()).hexdigest()
print(f"[DONE] [{SCRIPT_LOC}] submission={submission_path}")
print(f"[INFO] [{SCRIPT_LOC}] submission_sha256={sha}")
