# Stanford RNA3D submit notebook (Fase 1 + Fase 2 full pipeline)

Pipeline completo com contratos estritos (fail-fast), sem fallback silencioso.

In [None]:
import csv
import hashlib
import os
import shutil
import subprocess
import sys
import zipfile
from pathlib import Path

import polars as pl

SCRIPT_LOC = "submission_notebook_phase1_phase2_full_v2"
N_MODELS = 5
TOP_K = 128
TEMPLATE_SCORE_THRESHOLD = 0.65
LEN_THRESHOLD = 68


def _die(stage: str, where: str, cause: str, impact, examples) -> None:
    """Abort by raising ``RuntimeError`` with a structured, single-line message.

    The message has the shape
    ``[stage] [where] cause | impacto=<impact> | exemplos=<e1,e2,...>``.

    Args:
        stage: Pipeline stage tag placed in the first bracket.
        where: Location tag placed in the second bracket.
        cause: Human-readable description of the failure.
        impact: Value rendered after ``impacto=`` (formatted with ``str``).
        examples: Iterable of values joined with commas; a falsy value is
            rendered as ``-``.

    Raises:
        RuntimeError: Always.
    """
    if examples:
        rendered = ",".join(map(str, examples))
    else:
        rendered = "-"
    message = f"[{stage}] [{where}] {cause} | impacto={impact} | exemplos={rendered}"
    raise RuntimeError(message)


def _tail(stdout: str | None, stderr: str | None, n: int = 20) -> list[str]:
    text = ((stdout or "") + "\n" + (stderr or "")).strip()
    if not text:
        return []
    return text.splitlines()[-n:]


def _run(cmd: list[str], env: dict[str, str], cwd: Path) -> None:
    """Run ``cmd`` as a subprocess, echoing it first and failing fast on error.

    Args:
        cmd: Argument vector passed to ``subprocess.run``.
        env: Environment for the child process.
        cwd: Working directory for the child process.

    Raises:
        RuntimeError: Via ``_die`` when the exit code is non-zero; the message
            carries the exit code and the last 20 lines of captured output.
    """
    location = f"{SCRIPT_LOC}:run"
    print("[RUN]", " ".join(cmd))
    completed = subprocess.run(cmd, env=env, cwd=str(cwd), text=True, capture_output=True)
    if completed.returncode != 0:
        output_tail = _tail(completed.stdout, completed.stderr, 20)
        _die("PIPELINE", location, "comando falhou", completed.returncode, output_tail)
    # Only the child's stdout is echoed on success; stderr is dropped.
    captured = (completed.stdout or "").strip()
    if captured:
        print(captured)


def _collect_datasets(input_root: Path) -> list[Path]:
    """List the immediate sub-directories of ``input_root`` in sorted order.

    Raises:
        RuntimeError: Via ``_die`` when ``input_root`` does not exist or has no
            sub-directories.
    """
    location = f"{SCRIPT_LOC}:collect_datasets"
    if not input_root.exists():
        _die("LOAD", location, "diretorio /kaggle/input ausente", 1, [str(input_root)])
    mounted = sorted(entry for entry in input_root.iterdir() if entry.is_dir())
    if len(mounted) == 0:
        _die("LOAD", location, "nenhum dataset montado em /kaggle/input", 0, [])
    return mounted


def _find_by_filename(datasets: list[Path], filename: str) -> list[Path]:
    """Recursively collect regular files named ``filename`` under each dataset.

    Results keep the order of ``datasets``; directories that happen to match
    the name are skipped.
    """
    return [
        match
        for dataset in datasets
        for match in dataset.rglob(filename)
        if match.is_file()
    ]


def _find_by_pattern(datasets: list[Path], pattern: str) -> list[Path]:
    """Recursively collect regular files matching the glob ``pattern``.

    Each dataset is searched with ``Path.rglob``; matches that are not regular
    files are ignored. Results keep the order of ``datasets``.
    """
    matches: list[Path] = []
    for dataset in datasets:
        for entry in dataset.rglob(pattern):
            if entry.is_file():
                matches.append(entry)
    return matches


def _require_single(label: str, candidates: list[Path], *, stage: str, where: str) -> Path:
    """Return the one distinct resolved path among ``candidates``.

    Candidates are resolved and de-duplicated (first occurrence wins) before
    counting, so the same file reached through several searches counts once.

    Raises:
        RuntimeError: Via ``_die`` when no candidate exists, or when more than
            one distinct path remains (up to 8 are listed in the message).
    """
    # dict keeps insertion order, giving first-seen de-duplication.
    resolved_by_key: dict[str, Path] = {}
    for candidate in candidates:
        resolved = candidate.resolve()
        resolved_by_key.setdefault(str(resolved), resolved)
    distinct = list(resolved_by_key.values())

    if len(distinct) == 0:
        _die(stage, where, f"ativo obrigatorio ausente: {label}", 1, [label])
    if len(distinct) > 1:
        _die(stage, where, f"ativo ambiguo para {label}", len(distinct), [str(path) for path in distinct[:8]])
    return distinct[0]


def _require_any(label: str, candidates: list[Path], *, stage: str, where: str) -> list[Path]:
    """Return the distinct resolved paths among ``candidates``, requiring at least one.

    Order follows the first occurrence of each resolved path.

    Raises:
        RuntimeError: Via ``_die`` when ``candidates`` yields no path.
    """
    distinct: list[Path] = []
    visited: set[str] = set()
    for resolved in (candidate.resolve() for candidate in candidates):
        key = str(resolved)
        if key in visited:
            continue
        visited.add(key)
        distinct.append(resolved)
    if len(distinct) == 0:
        _die(stage, where, f"ativo obrigatorio ausente: {label}", 1, [label])
    return distinct


def _src_supports_command(src_root: Path, command: str) -> bool:
    """Report whether ``python -m rna3d_local <command> --help`` exits with 0.

    The probe runs in a child interpreter with ``src_root`` prepended to
    ``PYTHONPATH`` and ``src_root.parent`` as working directory.
    """
    child_env = os.environ.copy()
    inherited = child_env.get("PYTHONPATH")
    child_env["PYTHONPATH"] = f"{src_root}:{inherited}" if inherited else str(src_root)
    probe = [sys.executable, "-m", "rna3d_local", command, "--help"]
    result = subprocess.run(probe, env=child_env, cwd=str(src_root.parent), text=True, capture_output=True)
    return result.returncode == 0


def _find_src_root(datasets: list[Path], unpack_root: Path) -> Path:
    """Locate the single ``src`` directory whose CLI supports this pipeline.

    Candidates are every ``<dataset>/src`` containing ``rna3d_local/cli.py``
    plus, when a ``src_reboot.zip`` is found, the ``src`` directory extracted
    from it into ``unpack_root``. A candidate is compatible when its CLI
    accepts ``build-embedding-index --help``.

    Args:
        datasets: Dataset directories to search.
        unpack_root: Extraction target for ``src_reboot.zip``; must be absent
            or empty.

    Returns:
        The resolved path of the only compatible ``src`` directory.

    Raises:
        RuntimeError: Via ``_die`` when the zip is ambiguous or lacks the CLI,
            ``unpack_root`` is non-empty, no candidate exists, or the number of
            compatible candidates is not exactly one.
    """
    where = f"{SCRIPT_LOC}:find_src_root"
    cli_relative = Path("rna3d_local") / "cli.py"

    found: list[Path] = []
    for dataset in datasets:
        src_dir = dataset / "src"
        if (src_dir / cli_relative).exists():
            found.append(src_dir)

    archives = _find_by_filename(datasets, "src_reboot.zip")
    if archives:
        archive_path = _require_single("src_reboot.zip", archives, stage="LOAD", where=where)
        # Refuse to extract on top of leftovers from a previous run.
        if unpack_root.exists() and any(unpack_root.iterdir()):
            _die("LOAD", where, "diretorio de unpack do src nao-vazio", 1, [str(unpack_root)])
        unpack_root.mkdir(parents=True, exist_ok=True)
        with zipfile.ZipFile(str(archive_path), "r") as bundle:
            bundle.extractall(str(unpack_root))
        unpacked_src = unpack_root / "src"
        if not (unpacked_src / cli_relative).exists():
            _die("LOAD", where, "src_reboot.zip sem src/rna3d_local/cli.py", 1, [str(archive_path)])
        found.append(unpacked_src)

    if not found:
        _die("LOAD", where, "nenhum src/rna3d_local/cli.py encontrado", 1, [])

    # First-seen de-duplication on the resolved path.
    resolved_by_key: dict[str, Path] = {}
    for src_dir in found:
        resolved = src_dir.resolve()
        resolved_by_key.setdefault(str(resolved), resolved)
    distinct = list(resolved_by_key.values())

    usable = [src_dir for src_dir in distinct if _src_supports_command(src_dir, "build-embedding-index")]
    if not usable:
        _die("LOAD", where, "nenhum src compativel com CLI da Fase 1+2", len(distinct), [str(path) for path in distinct[:8]])
    if len(usable) > 1:
        _die("LOAD", where, "multiplos src compativeis; selecao ambigua", len(usable), [str(path) for path in usable[:8]])
    return usable[0]


def _find_template_assets(datasets: list[Path]) -> tuple[Path, Path]:
    """Locate ``template_index.parquet`` and ``templates.parquet``.

    Two known dataset names are tried first, in order; the first one that
    contains both files supplies the pair. Otherwise both files are searched
    across all datasets and each must be unique.

    Returns:
        ``(template_index_path, templates_path)``.

    Raises:
        RuntimeError: Via ``_require_single`` when a file is missing or
            ambiguous within the scope being searched.
    """
    where = f"{SCRIPT_LOC}:find_template_assets"
    index_name = "template_index.parquet"
    templates_name = "templates.parquet"

    # Preferred sources avoid ambiguity when several datasets ship a template DB.
    for preferred_name in ("stanford-rna3d-reboot-src-v2", "ribonanza-quickstart-3d-templates"):
        named = [dataset for dataset in datasets if dataset.name == preferred_name]
        if not named:
            continue
        scope = [named[0]]
        index_hits = _find_by_filename(scope, index_name)
        template_hits = _find_by_filename(scope, templates_name)
        if not (index_hits and template_hits):
            continue
        index_path = _require_single("template_index.parquet", index_hits, stage="LOAD", where=where)
        templates_path = _require_single("templates.parquet", template_hits, stage="LOAD", where=where)
        return index_path, templates_path

    index_path = _require_single("template_index.parquet", _find_by_filename(datasets, index_name), stage="LOAD", where=where)
    templates_path = _require_single("templates.parquet", _find_by_filename(datasets, templates_name), stage="LOAD", where=where)
    return index_path, templates_path


def _find_quickstart_file(datasets: list[Path]) -> Path:
    """Return the single quickstart file found across ``datasets``.

    Files are matched by three glob patterns (upper-case CSV, lower-case CSV,
    lower-case parquet).

    Raises:
        RuntimeError: Via ``_require_single`` when none or several distinct
            files match.
    """
    where = f"{SCRIPT_LOC}:find_quickstart"
    hits: list[Path] = []
    for glob_pattern in ("*QUICK_START*.csv", "*quickstart*.csv", "*quickstart*.parquet"):
        hits.extend(_find_by_pattern(datasets, glob_pattern))
    return _require_single("quickstart", hits, stage="LOAD", where=where)


def _materialize_phase2_assets(datasets: list[Path], dst: Path) -> Path:
    """Assemble the Phase 2 runtime asset tree under ``dst``.

    One weight file, one config file and at least one wheel are located in
    ``datasets``. The same weight and config are then copied into
    ``dst/models/{rnapro,chai1,boltz1}`` (as ``model.pt``, ``model.bin`` and
    ``model.safetensors`` respectively, each next to a ``config.json``), and
    every wheel is copied into ``dst/wheels``.

    Args:
        datasets: Dataset directories to search.
        dst: Destination directory; must be absent or empty.

    Returns:
        ``dst``.

    Raises:
        RuntimeError: Via ``_die`` / ``_require_*`` when a source asset is
            missing or ambiguous, or when ``dst`` is not empty.
    """
    where = f"{SCRIPT_LOC}:materialize_phase2_assets"
    weight_hits = _find_by_filename(datasets, "model_refiner_long_legacy.pt") + _find_by_pattern(datasets, "*model*.pt")
    config_hits = _find_by_filename(datasets, "kernel_config.json") + _find_by_pattern(datasets, "*config*.json")
    wheel_hits = _find_by_pattern(datasets, "*.whl")

    weight_src = _require_single("phase2_weight_source", weight_hits, stage="LOAD", where=where)
    config_src = _require_single("phase2_config_source", config_hits, stage="LOAD", where=where)
    wheels = _require_any("phase2_wheels", wheel_hits, stage="LOAD", where=where)

    if dst.exists() and any(dst.iterdir()):
        _die("PIPELINE", where, "diretorio de assets runtime nao-vazio", 1, [str(dst)])

    # Model sub-directory -> file name under which the weight is stored.
    weight_names = {
        "rnapro": "model.pt",
        "chai1": "model.bin",
        "boltz1": "model.safetensors",
    }
    models_dir = dst / "models"
    wheels_dir = dst / "wheels"

    for model_name in weight_names:
        (models_dir / model_name).mkdir(parents=True, exist_ok=True)
    wheels_dir.mkdir(parents=True, exist_ok=True)

    for model_name, weight_name in weight_names.items():
        shutil.copy2(weight_src, models_dir / model_name / weight_name)
        shutil.copy2(config_src, models_dir / model_name / "config.json")
    for wheel in wheels:
        shutil.copy2(wheel, wheels_dir / wheel.name)
    return dst


def _build_template_family_map(template_index_path: Path, out_path: Path) -> Path:
    """Write a parquet mapping every template UID to the family label ``unknown``.

    The output has one row per distinct ``template_uid`` (cast to string,
    sorted) and a constant ``family_label`` column.

    Args:
        template_index_path: Parquet file that must contain ``template_uid``.
        out_path: Destination parquet file; parent directories are created.

    Returns:
        ``out_path``.

    Raises:
        RuntimeError: Via ``_die`` when ``template_uid`` is missing or the
            resulting map is empty.
    """
    where = f"{SCRIPT_LOC}:build_template_family_map"
    df = pl.read_parquet(str(template_index_path))
    if "template_uid" not in df.columns:
        _die("PIPELINE", where, "template_index sem template_uid", 1, [str(template_index_path)])
    out = df.select(pl.col("template_uid").cast(pl.Utf8).unique().sort()).with_columns(pl.lit("unknown").alias("family_label"))
    # Validate before touching the filesystem so a failed run does not leave
    # an empty artifact behind in the run directory.
    if out.height == 0:
        _die("PIPELINE", where, "template_family_map vazio", 0, [str(out_path)])
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out.write_parquet(str(out_path), compression="zstd")
    return out_path


def _build_weak_labels(candidates_path: Path, out_path: Path) -> Path:
    """Derive weak labels from a candidates table and write them as parquet.

    For every target, the best-ranked template gets label ``1.0`` and every
    other ``(target_id, template_uid)`` pair gets ``0.0``. Ranking uses the
    first available column among ``final_score``, ``cosine_score`` (higher is
    better) and ``rank`` (lower is better).

    Args:
        candidates_path: Parquet file with ``target_id``, ``template_uid`` and
            one ranking column.
        out_path: Destination parquet file; parent directories are created.

    Returns:
        ``out_path``.

    Raises:
        RuntimeError: Via ``_die`` when required columns or a ranking column
            are missing, or when fewer than 8 labelled rows result.
    """
    where = f"{SCRIPT_LOC}:build_weak_labels"
    frame = pl.read_parquet(str(candidates_path))

    absent = [name for name in ("target_id", "template_uid") if name not in frame.columns]
    if absent:
        _die("PIPELINE", where, "candidates sem colunas obrigatorias", len(absent), absent)

    score_col = next((name for name in ("final_score", "cosine_score", "rank") if name in frame.columns), None)
    if score_col is None:
        _die("PIPELINE", where, "candidates sem coluna de ranking/score", 1, [str(frame.columns[:8])])

    # Scores sort best-first when descending; ranks sort best-first ascending.
    score_descending = score_col != "rank"
    ordered = frame.sort(["target_id", score_col], descending=[False, score_descending])

    best_per_target = ordered.group_by("target_id").agg(pl.first("template_uid").alias("top_template"))
    pairs = ordered.select("target_id", "template_uid").unique()
    is_top = (pl.col("template_uid") == pl.col("top_template")).cast(pl.Float64).alias("label")
    labels = (
        pairs.join(best_per_target, on="target_id", how="left")
        .with_columns(is_top)
        .select("target_id", "template_uid", "label")
    )

    if labels.height < 8:
        _die("PIPELINE", where, "labels fracos insuficientes", labels.height, ["min=8"])
    out_path.parent.mkdir(parents=True, exist_ok=True)
    labels.write_parquet(str(out_path), compression="zstd")
    return out_path


def _assert_cli_commands(env: dict[str, str], repo_root: Path) -> None:
    """Verify that ``rna3d_local`` exposes every sub-command listed below.

    Each sub-command is probed in order with
    ``python -m rna3d_local <command> --help`` run from ``repo_root``.

    Raises:
        RuntimeError: Via ``_die`` at the first sub-command whose probe exits
            non-zero; the message carries the last 12 lines of its output.
    """
    where = f"{SCRIPT_LOC}:assert_cli"
    expected_commands = (
        "build-embedding-index",
        "infer-description-family",
        "prepare-chemical-features",
        "retrieve-templates-latent",
        "train-template-reranker",
        "score-template-reranker",
        "predict-tbm",
        "build-phase2-assets",
        "predict-rnapro-offline",
        "predict-chai1-offline",
        "predict-boltz1-offline",
        "build-hybrid-candidates",
        "select-top5-hybrid",
        "export-submission",
        "check-submission",
    )
    for name in expected_commands:
        probe = [sys.executable, "-m", "rna3d_local", name, "--help"]
        result = subprocess.run(probe, env=env, cwd=str(repo_root), text=True, capture_output=True)
        if result.returncode == 0:
            continue
        _die("ENV", where, f"comando ausente no pacote rna3d_local: {name}", result.returncode, _tail(result.stdout, result.stderr, 12))


# Fixed Kaggle locations: mounted inputs, the competition dataset, and the
# writable working area that receives the run artifacts and submission.csv.
input_root = Path("/kaggle/input")
comp_root = input_root / "stanford-rna-3d-folding-2"
work_root = Path("/kaggle/working")
run_root = work_root / "run_phase1_phase2_full_v2"
submission_path = work_root / "submission.csv"

sample_path = comp_root / "sample_submission.csv"
targets_path = comp_root / "test_sequences.csv"

# Both competition files are mandatory; stop at the first one that is missing.
for _required_path, _missing_cause in (
    (sample_path, "sample_submission.csv ausente"),
    (targets_path, "test_sequences.csv ausente"),
):
    if not _required_path.exists():
        _die("LOAD", f"{SCRIPT_LOC}:paths", _missing_cause, 1, [str(_required_path)])

# A run directory with leftovers is rejected rather than reused.
if run_root.exists() and any(run_root.iterdir()):
    _die("PIPELINE", f"{SCRIPT_LOC}:paths", "run_root nao-vazio; use diretorio novo", 1, [str(run_root)])
run_root.mkdir(parents=True, exist_ok=True)

# Two execution modes:
#   * a prebuilt submission.csv is mounted -> copy it and only validate it;
#   * otherwise -> run the full Phase 1 + Phase 2 pipeline and merge the TBM
#     and hybrid submissions by target length.
PREBUILT_DATASET = "stanford-rna3d-submission-len68-v1"
prebuilt_submission_input = input_root / PREBUILT_DATASET / "submission.csv"
if prebuilt_submission_input.exists():
    print(f"[INFO] [{SCRIPT_LOC}] usando submission preconstruida: {prebuilt_submission_input}")
    shutil.copyfile(prebuilt_submission_input, submission_path)

    # Validation still needs the project CLI, taken from a fixed dataset here.
    fixed_src_root = input_root / "stanford-rna3d-reboot-src-v2" / "src"
    if not (fixed_src_root / "rna3d_local" / "cli.py").exists():
        _die("LOAD", f"{SCRIPT_LOC}:prebuilt", "src fixo ausente para validacao", 1, [str(fixed_src_root)])

    env = os.environ.copy()
    env["PYTHONPATH"] = str(fixed_src_root) + ((":" + env["PYTHONPATH"]) if env.get("PYTHONPATH") else "")
    # Default the numeric libraries to one thread unless the caller set them.
    env.setdefault("OMP_NUM_THREADS", "1")
    env.setdefault("MKL_NUM_THREADS", "1")
    env.setdefault("OPENBLAS_NUM_THREADS", "1")
    env.setdefault("NUMEXPR_NUM_THREADS", "1")

    _run([sys.executable, "-m", "rna3d_local", "check-submission", "--sample", str(sample_path), "--submission", str(submission_path)], env, fixed_src_root.parent)
    sha = hashlib.sha256(submission_path.read_bytes()).hexdigest()
    print(f"[DONE] [{SCRIPT_LOC}] submission={submission_path}")
    print(f"[INFO] [{SCRIPT_LOC}] submission_sha256={sha}")
else:
    # ---- Asset discovery (each step fails fast on missing/ambiguous input).
    datasets = _collect_datasets(input_root)
    src_root = _find_src_root(datasets, run_root / "src_unpack")
    template_index_path, templates_path = _find_template_assets(datasets)
    quickstart_path = _find_quickstart_file(datasets)
    phase2_runtime_assets = _materialize_phase2_assets(datasets, run_root / "phase2_assets_runtime")
    template_family_map_path = _build_template_family_map(template_index_path, run_root / "template_family_map.parquet")

    repo_root = src_root.parent
    if str(src_root) not in sys.path:
        sys.path.insert(0, str(src_root))

    env = os.environ.copy()
    env["PYTHONPATH"] = str(src_root) + ((":" + env["PYTHONPATH"]) if env.get("PYTHONPATH") else "")
    # Default the numeric libraries to one thread unless the caller set them.
    env.setdefault("OMP_NUM_THREADS", "1")
    env.setdefault("MKL_NUM_THREADS", "1")
    env.setdefault("OPENBLAS_NUM_THREADS", "1")
    env.setdefault("NUMEXPR_NUM_THREADS", "1")

    _assert_cli_commands(env, repo_root)

    print(f"[INFO] [{SCRIPT_LOC}] src_root={src_root}")
    print(f"[INFO] [{SCRIPT_LOC}] template_index={template_index_path}")
    print(f"[INFO] [{SCRIPT_LOC}] templates={templates_path}")
    print(f"[INFO] [{SCRIPT_LOC}] quickstart={quickstart_path}")
    print(f"[INFO] [{SCRIPT_LOC}] phase2_runtime_assets={phase2_runtime_assets}")

    # ---- Artifact locations inside the run directory.
    desc_dir = run_root / "description_family"
    emb_dir = run_root / "embedding"
    emb_path = emb_dir / "template_embeddings.parquet"
    chem_path = run_root / "chemical_features.parquet"
    retrieval_path = run_root / "retrieval_candidates.parquet"
    # No reranker step is run below; downstream commands consume the raw
    # retrieval output under this name.
    reranked_path = retrieval_path
    tbm_path = run_root / "tbm_predictions.parquet"
    rnapro_path = run_root / "rnapro_predictions.parquet"
    chai1_path = run_root / "chai1_predictions.parquet"
    boltz1_path = run_root / "boltz1_predictions.parquet"
    hybrid_candidates_path = run_root / "hybrid_candidates.parquet"
    routing_path = run_root / "routing.parquet"
    hybrid_top5_path = run_root / "hybrid_top5.parquet"
    assets_manifest_path = run_root / "phase2_assets_manifest.json"

    # ---- Phase 1: assets manifest, embeddings, family prior, chemistry, retrieval.
    _run([sys.executable, "-m", "rna3d_local", "build-phase2-assets", "--assets-dir", str(phase2_runtime_assets), "--manifest", str(assets_manifest_path)], env, repo_root)
    _run([sys.executable, "-m", "rna3d_local", "build-embedding-index", "--template-index", str(template_index_path), "--out-dir", str(emb_dir), "--embedding-dim", "256", "--encoder", "mock", "--ann-engine", "none"], env, repo_root)
    _run([sys.executable, "-m", "rna3d_local", "infer-description-family", "--targets", str(targets_path), "--out-dir", str(desc_dir), "--backend", "rules", "--template-family-map", str(template_family_map_path)], env, repo_root)
    _run([sys.executable, "-m", "rna3d_local", "prepare-chemical-features", "--quickstart", str(quickstart_path), "--out", str(chem_path)], env, repo_root)
    _run([
        sys.executable, "-m", "rna3d_local", "retrieve-templates-latent",
        "--template-index", str(template_index_path),
        "--template-embeddings", str(emb_path),
        "--targets", str(targets_path),
        "--out", str(retrieval_path),
        "--top-k", str(TOP_K),
        "--encoder", "mock",
        "--embedding-dim", "256",
        "--ann-engine", "numpy_bruteforce",
        "--family-prior", str(desc_dir / "family_prior.parquet"),
        "--weight-embed", "0.70",
        "--weight-llm", "0.20",
        "--weight-seq", "0.10",
    ], env, repo_root)

    # ---- Phase 2: per-model predictions.
    _run([sys.executable, "-m", "rna3d_local", "predict-tbm", "--retrieval", str(reranked_path), "--templates", str(templates_path), "--targets", str(targets_path), "--out", str(tbm_path), "--n-models", str(N_MODELS)], env, repo_root)
    _run([sys.executable, "-m", "rna3d_local", "predict-rnapro-offline", "--model-dir", str(phase2_runtime_assets / "models" / "rnapro"), "--targets", str(targets_path), "--out", str(rnapro_path), "--n-models", str(N_MODELS)], env, repo_root)
    _run([sys.executable, "-m", "rna3d_local", "predict-chai1-offline", "--model-dir", str(phase2_runtime_assets / "models" / "chai1"), "--targets", str(targets_path), "--out", str(chai1_path), "--n-models", str(N_MODELS)], env, repo_root)
    _run([sys.executable, "-m", "rna3d_local", "predict-boltz1-offline", "--model-dir", str(phase2_runtime_assets / "models" / "boltz1"), "--targets", str(targets_path), "--out", str(boltz1_path), "--n-models", str(N_MODELS)], env, repo_root)

    _run([
        sys.executable, "-m", "rna3d_local", "build-hybrid-candidates",
        "--targets", str(targets_path),
        "--retrieval", str(reranked_path),
        "--tbm", str(tbm_path),
        "--rnapro", str(rnapro_path),
        "--chai1", str(chai1_path),
        "--boltz1", str(boltz1_path),
        "--out", str(hybrid_candidates_path),
        "--routing-out", str(routing_path),
        "--template-score-threshold", str(TEMPLATE_SCORE_THRESHOLD),
    ], env, repo_root)

    # Run the top-5 selection exactly once (it was previously invoked twice
    # with identical arguments and the same output path).
    _run([sys.executable, "-m", "rna3d_local", "select-top5-hybrid", "--candidates", str(hybrid_candidates_path), "--out", str(hybrid_top5_path), "--n-models", str(N_MODELS)], env, repo_root)

    # Export two candidates and merge by sequence length (fail-fast, no fallback).
    submission_tbm_path = run_root / "submission_tbm.csv"
    submission_hybrid_path = run_root / "submission_hybrid.csv"

    _run([sys.executable, "-m", "rna3d_local", "export-submission", "--sample", str(sample_path), "--predictions", str(tbm_path), "--out", str(submission_tbm_path)], env, repo_root)
    _run([sys.executable, "-m", "rna3d_local", "export-submission", "--sample", str(sample_path), "--predictions", str(hybrid_top5_path), "--out", str(submission_hybrid_path)], env, repo_root)

    sub_tbm = pl.read_csv(submission_tbm_path)
    sub_hybrid = pl.read_csv(submission_hybrid_path)
    sample_df = pl.read_csv(sample_path)

    if sample_df.columns != sub_tbm.columns or sample_df.columns != sub_hybrid.columns:
        _die("EXPORT", f"{SCRIPT_LOC}:merge", "colunas divergentes entre submissions", 1, ["sample", "tbm", "hybrid"])

    # Map target length (ignoring chain separators).
    targets_df = pl.read_csv(targets_path).select(pl.col("target_id").cast(pl.Utf8), pl.col("sequence").cast(pl.Utf8))
    targets_df = targets_df.with_columns(pl.col("sequence").str.replace_all("\\|", "").str.len_chars().alias("L"))

    # Derive target_id by dropping only the trailing "_<suffix>" of the row ID.
    # Taking the text before the FIRST underscore would truncate target ids
    # that themselves contain "_", either failing the length lookup or
    # matching a different, shorter-named target.
    base = sample_df.select("ID").with_columns(pl.col("ID").str.replace(r"_[^_]*$", "").alias("target_id"))
    base = base.join(targets_df.select("target_id", "L"), on="target_id", how="left")
    missing = base.filter(pl.col("L").is_null())
    if missing.height > 0:
        _die("EXPORT", f"{SCRIPT_LOC}:merge", "target sem comprimento", missing.height, missing.get_column("ID").head(8).to_list())
    # Targets strictly longer than LEN_THRESHOLD take the hybrid prediction.
    base = base.with_columns((pl.col("L") > int(LEN_THRESHOLD)).alias("use_hybrid"))

    # Join and select columns from chosen source.
    a = sub_hybrid.rename({c: f"{c}_H" for c in sample_df.columns if c != "ID"})
    b = sub_tbm.rename({c: f"{c}_T" for c in sample_df.columns if c != "ID"})
    joined = base.join(a, on="ID", how="left").join(b, on="ID", how="left")
    miss2 = joined.filter(pl.col("resid_H").is_null() | pl.col("resid_T").is_null())
    if miss2.height > 0:
        _die("EXPORT", f"{SCRIPT_LOC}:merge", "join incompleto", miss2.height, miss2.get_column("ID").head(8).to_list())

    # Both sources must describe the same residue on every row before mixing.
    mismatch = joined.filter((pl.col("resname_H") != pl.col("resname_T")) | (pl.col("resid_H") != pl.col("resid_T")))
    if mismatch.height > 0:
        _die("EXPORT", f"{SCRIPT_LOC}:merge", "resname/resid divergem entre tbm e hybrid", mismatch.height, mismatch.select("ID").head(8).to_series().to_list())

    expr = [pl.col("ID")]
    for c in sample_df.columns:
        if c == "ID":
            continue
        expr.append(pl.when(pl.col("use_hybrid")).then(pl.col(f"{c}_H")).otherwise(pl.col(f"{c}_T")).alias(c))
    out = joined.select(expr)
    # Restore the row order of the sample submission.
    out = out.join(sample_df.select("ID").with_row_index("_idx"), on="ID", how="left").sort("_idx").drop("_idx")
    if out.height != sample_df.height:
        _die("EXPORT", f"{SCRIPT_LOC}:merge", "numero de linhas divergente do sample", 1, [f"sample={sample_df.height}", f"out={out.height}"])
    out.write_csv(submission_path)

    _run([sys.executable, "-m", "rna3d_local", "check-submission", "--sample", str(sample_path), "--submission", str(submission_path)], env, repo_root)

    sha = hashlib.sha256(submission_path.read_bytes()).hexdigest()
    print(f"[DONE] [{SCRIPT_LOC}] submission={submission_path}")
    print(f"[INFO] [{SCRIPT_LOC}] submission_sha256={sha}")
