# Stanford RNA3D submit notebook (Fase 1 + Fase 2 full pipeline)

Pipeline completo com contratos estritos (fail-fast), sem fallback silencioso.

In [None]:
import csv
import hashlib
import os
import shutil
import subprocess
import sys
import zipfile
from pathlib import Path

import polars as pl

SCRIPT_LOC = "submission_notebook_phase1_phase2_full_v2"
N_MODELS = 5
TOP_K = 16
TEMPLATE_SCORE_THRESHOLD = 0.65
LEN_THRESHOLD = 68
TBM_CHUNK_SIZE = 12


def _die(stage: str, where: str, cause: str, impact, examples) -> None:
    examples_text = ",".join(str(x) for x in examples) if examples else "-"
    raise RuntimeError(f"[{stage}] [{where}] {cause} | impacto={impact} | exemplos={examples_text}")


def _tail(stdout: str | None, stderr: str | None, n: int = 20) -> list[str]:
    text = ((stdout or "") + "\n" + (stderr or "")).strip()
    if not text:
        return []
    return text.splitlines()[-n:]


def _run(cmd: list[str], env: dict[str, str], cwd: Path) -> None:
    """Run a command, echo its stdout, and abort on a non-zero exit code.

    Args:
        cmd: Argument vector handed to ``subprocess.run`` (no shell).
        env: Complete environment for the child process.
        cwd: Working directory for the child process.

    Raises:
        RuntimeError: Via ``_die`` when the command exits non-zero; the
            message carries the return code and the last 20 output lines.
    """
    print("[RUN]", " ".join(cmd))
    completed = subprocess.run(cmd, env=env, cwd=str(cwd), text=True, capture_output=True)
    if completed.returncode != 0:
        output_tail = _tail(completed.stdout, completed.stderr, 20)
        _die("PIPELINE", f"{SCRIPT_LOC}:run", "comando falhou", completed.returncode, output_tail)
    # Only stdout is echoed on success; captured stderr is not printed.
    captured = (completed.stdout or "").strip()
    if captured:
        print(captured)


def _collect_datasets(input_root: Path) -> list[Path]:
    """List the dataset directories found directly under ``input_root``.

    Args:
        input_root: Directory whose immediate sub-directories are datasets.

    Returns:
        The sub-directories of ``input_root`` in sorted order.

    Raises:
        RuntimeError: Via ``_die`` when ``input_root`` does not exist or
            contains no directory.
    """
    where = f"{SCRIPT_LOC}:collect_datasets"
    if not input_root.exists():
        _die("LOAD", where, "diretorio /kaggle/input ausente", 1, [str(input_root)])
    entries = sorted(input_root.iterdir())
    found = list(filter(Path.is_dir, entries))
    if not found:
        _die("LOAD", where, "nenhum dataset montado em /kaggle/input", 0, [])
    return found


def _find_by_filename(datasets: list[Path], filename: str) -> list[Path]:
    hits: list[Path] = []
    for ds in datasets:
        hits.extend([path for path in ds.rglob(filename) if path.is_file()])
    return hits


def _find_by_pattern(datasets: list[Path], pattern: str) -> list[Path]:
    hits: list[Path] = []
    for ds in datasets:
        hits.extend([path for path in ds.rglob(pattern) if path.is_file()])
    return hits


def _require_single(label: str, candidates: list[Path], *, stage: str, where: str) -> Path:
    unique: list[Path] = []
    seen: set[str] = set()
    for item in candidates:
        key = str(item.resolve())
        if key not in seen:
            seen.add(key)
            unique.append(item.resolve())
    if not unique:
        _die(stage, where, f"ativo obrigatorio ausente: {label}", 1, [label])
    if len(unique) > 1:
        _die(stage, where, f"ativo ambiguo para {label}", len(unique), [str(path) for path in unique[:8]])
    return unique[0]


def _require_any(label: str, candidates: list[Path], *, stage: str, where: str) -> list[Path]:
    unique: list[Path] = []
    seen: set[str] = set()
    for item in candidates:
        key = str(item.resolve())
        if key not in seen:
            seen.add(key)
            unique.append(item.resolve())
    if not unique:
        _die(stage, where, f"ativo obrigatorio ausente: {label}", 1, [label])
    return unique


def _src_supports_command(src_root: Path, command: str) -> bool:
    env = os.environ.copy()
    env["PYTHONPATH"] = str(src_root) + ((":" + env["PYTHONPATH"]) if env.get("PYTHONPATH") else "")
    proc = subprocess.run([sys.executable, "-m", "rna3d_local", command, "--help"], env=env, cwd=str(src_root.parent), text=True, capture_output=True)
    return proc.returncode == 0


def _find_src_root(datasets: list[Path], unpack_root: Path) -> Path:
    """Locate the single ``src`` directory whose CLI supports this pipeline.

    Two sources are considered: datasets that already ship
    ``src/rna3d_local/cli.py`` and, when present, a single ``src_reboot.zip``
    that is extracted into ``unpack_root``. Each resolved candidate is then
    probed for the ``build-embedding-index`` command.

    Args:
        datasets: Mounted dataset roots to search.
        unpack_root: Extraction directory for ``src_reboot.zip``; it must be
            missing or empty when the archive is found.

    Returns:
        The resolved path of the only compatible ``src`` directory.

    Raises:
        RuntimeError: Via ``_die`` when the archive is ambiguous or lacks
            ``src/rna3d_local/cli.py``, when ``unpack_root`` is non-empty,
            or when zero or several compatible ``src`` directories exist.
    """
    where = f"{SCRIPT_LOC}:find_src_root"
    cli_rel = Path("rna3d_local") / "cli.py"

    found: list[Path] = []
    for ds in datasets:
        src_dir = ds / "src"
        if (src_dir / cli_rel).exists():
            found.append(src_dir)

    archives = _find_by_filename(datasets, "src_reboot.zip")
    if archives:
        archive_path = _require_single("src_reboot.zip", archives, stage="LOAD", where=where)
        # Refuse to extract on top of existing content.
        if unpack_root.exists() and any(unpack_root.iterdir()):
            _die("LOAD", where, "diretorio de unpack do src nao-vazio", 1, [str(unpack_root)])
        unpack_root.mkdir(parents=True, exist_ok=True)
        with zipfile.ZipFile(str(archive_path), "r") as bundle:
            bundle.extractall(str(unpack_root))
        unpacked_src = unpack_root / "src"
        if not (unpacked_src / cli_rel).exists():
            _die("LOAD", where, "src_reboot.zip sem src/rna3d_local/cli.py", 1, [str(archive_path)])
        found.append(unpacked_src)

    if not found:
        _die("LOAD", where, "nenhum src/rna3d_local/cli.py encontrado", 1, [])

    # De-duplicate by resolved path, keeping first-seen order.
    distinct: dict[str, Path] = {}
    for item in found:
        real = item.resolve()
        distinct.setdefault(str(real), real)
    unique = list(distinct.values())

    usable = [src for src in unique if _src_supports_command(src, "build-embedding-index")]
    if not usable:
        _die("LOAD", where, "nenhum src compativel com CLI da Fase 1+2", len(unique), [str(path) for path in unique[:8]])
    if len(usable) > 1:
        _die("LOAD", where, "multiplos src compativeis; selecao ambigua", len(usable), [str(path) for path in usable[:8]])
    return usable[0]


def _find_template_assets(datasets: list[Path]) -> tuple[Path, Path]:
    """Locate ``template_index.parquet`` and ``templates.parquet``.

    Preferred datasets are tried first, in a fixed order, so that the result
    stays deterministic when several datasets ship a template database. A
    preferred dataset is used only if it contains both files. Otherwise the
    search falls back to all datasets and requires each file to be unique.

    Args:
        datasets: Mounted dataset roots to search.

    Returns:
        ``(template_index_path, templates_path)`` as resolved paths.

    Raises:
        RuntimeError: Via ``_die`` when a file is missing or ambiguous within
            the scope being searched.
    """
    where = f"{SCRIPT_LOC}:find_template_assets"
    index_name = "template_index.parquet"
    coords_name = "templates.parquet"

    for wanted in ("stanford-rna3d-reboot-src-v2", "ribonanza-quickstart-3d-templates"):
        for ds in datasets:
            if ds.name != wanted:
                continue
            index_hits = _find_by_filename([ds], index_name)
            coords_hits = _find_by_filename([ds], coords_name)
            if index_hits and coords_hits:
                chosen_index = _require_single(index_name, index_hits, stage="LOAD", where=where)
                chosen_coords = _require_single(coords_name, coords_hits, stage="LOAD", where=where)
                return chosen_index, chosen_coords
            # Only the first dataset carrying this name is considered.
            break

    # No preferred dataset had both files: search everywhere, demanding uniqueness.
    any_index = _require_single(index_name, _find_by_filename(datasets, index_name), stage="LOAD", where=where)
    any_coords = _require_single(coords_name, _find_by_filename(datasets, coords_name), stage="LOAD", where=where)
    return any_index, any_coords


def _find_quickstart_file(datasets: list[Path]) -> Path:
    """Locate the single quickstart CSV/parquet file across all datasets.

    Args:
        datasets: Mounted dataset roots to search.

    Returns:
        The resolved path of the only file matching one of the quickstart
        glob patterns.

    Raises:
        RuntimeError: Via ``_die`` when no file or several distinct files match.
    """
    hits: list[Path] = []
    for pattern in ("*QUICK_START*.csv", "*quickstart*.csv", "*quickstart*.parquet"):
        hits.extend(_find_by_pattern(datasets, pattern))
    return _require_single("quickstart", hits, stage="LOAD", where=f"{SCRIPT_LOC}:find_quickstart")


def _materialize_phase2_assets(datasets: list[Path], dst: Path) -> Path:
    """Locate the single phase-2 assets root among the mounted datasets.

    A directory qualifies as a root when every marker file below exists
    relative to it (layout ``<root>/models/{rnapro,chai1,boltz1}/...``). For
    each dataset a fixed set of candidate sub-directories is probed.

    Args:
        datasets: Mounted dataset roots to search.
        dst: Not used by this function; nothing is copied, only located.

    Returns:
        The resolved path of the only qualifying assets root.

    Raises:
        RuntimeError: Via ``_die`` when no root or several distinct roots
            qualify.
    """
    where = f"{SCRIPT_LOC}:materialize_phase2_assets"
    # (model label, marker file relative to the assets root)
    markers = [
        ("rnapro", Path("models/rnapro/rnapro-public-best-500m.ckpt")),
        ("chai1", Path("models/chai1/models_v2/trunk.pt")),
        ("boltz1", Path("models/boltz1/boltz1_conf.ckpt")),
    ]

    # Keyed by resolved string so the same root reached twice counts once.
    qualifying: dict[str, Path] = {}
    for ds in datasets:
        layouts = (
            ds,
            ds / "export" / "kaggle_assets",
            ds / "kaggle_assets",
            ds / "export" / "kaggle_assets" / "export" / "kaggle_assets",
        )
        for base in layouts:
            if all((base / rel).exists() for _label, rel in markers):
                real = base.resolve()
                qualifying.setdefault(str(real), real)

    found = list(qualifying.values())
    if not found:
        _die("LOAD", where, "assets phase2 nao encontrados (markers ausentes)", 1, [str(rel) for _label, rel in markers])
    if len(found) > 1:
        _die("LOAD", where, "assets phase2 ambiguos (multiplos roots)", len(found), [str(p) for p in found[:8]])
    return found[0]

def _build_template_family_map(template_index_path: Path, out_path: Path) -> Path:
    """Write a placeholder template-family map built from the template index.

    The output has one row per distinct ``template_uid`` (cast to string,
    sorted) plus a constant ``family_label`` column set to ``"unknown"``.

    Args:
        template_index_path: Parquet file that must contain ``template_uid``.
        out_path: Destination parquet file; parent directories are created.

    Returns:
        ``out_path``.

    Raises:
        RuntimeError: Via ``_die`` when ``template_uid`` is missing or when
            the resulting map has no rows.
    """
    where = f"{SCRIPT_LOC}:build_template_family_map"
    df = pl.read_parquet(str(template_index_path))
    if "template_uid" not in df.columns:
        _die("PIPELINE", where, "template_index sem template_uid", 1, [str(template_index_path)])
    out = df.select(pl.col("template_uid").cast(pl.Utf8).unique().sort()).with_columns(pl.lit("unknown").alias("family_label"))
    # Fail fast *before* touching the filesystem so that an empty map never
    # leaves a stale, empty artefact behind for later stages to pick up.
    if out.height == 0:
        _die("PIPELINE", where, "template_family_map vazio", 0, [str(out_path)])
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out.write_parquet(str(out_path), compression="zstd")
    return out_path


def _build_weak_labels(candidates_path: Path, out_path: Path) -> Path:
    """Derive weak binary labels from retrieval candidates and write them.

    For each target the best-ranked template gets label ``1.0`` and every
    other distinct (target, template) pair gets ``0.0``. Ranking uses the
    first available column among ``final_score``, ``cosine_score`` (higher
    is better) and ``rank`` (lower is better).

    Args:
        candidates_path: Parquet file with ``target_id``, ``template_uid``
            and at least one ranking column.
        out_path: Destination parquet file; parent directories are created.

    Returns:
        ``out_path``.

    Raises:
        RuntimeError: Via ``_die`` when required columns or a ranking column
            are missing, or when fewer than 8 label rows are produced.
    """
    where = f"{SCRIPT_LOC}:build_weak_labels"
    df = pl.read_parquet(str(candidates_path))

    absent = [name for name in ("target_id", "template_uid") if name not in df.columns]
    if absent:
        _die("PIPELINE", where, "candidates sem colunas obrigatorias", len(absent), absent)

    ranking_priority = ("final_score", "cosine_score", "rank")
    score_col = next((name for name in ranking_priority if name in df.columns), None)
    if score_col is None:
        _die("PIPELINE", where, "candidates sem coluna de ranking/score", 1, [str(df.columns[:8])])

    # Scores sort descending (best first); a rank sorts ascending.
    higher_is_better = score_col != "rank"
    ordered = df.sort(["target_id", score_col], descending=[False, higher_is_better])

    top1 = ordered.group_by("target_id").agg(pl.first("template_uid").alias("top_template"))
    pairs = ordered.select("target_id", "template_uid").unique()
    with_top = pairs.join(top1, on="target_id", how="left")
    is_top = (pl.col("template_uid") == pl.col("top_template")).cast(pl.Float64).alias("label")
    labels = with_top.with_columns(is_top).select("target_id", "template_uid", "label")

    if labels.height < 8:
        _die("PIPELINE", where, "labels fracos insuficientes", labels.height, ["min=8"])
    out_path.parent.mkdir(parents=True, exist_ok=True)
    labels.write_parquet(str(out_path), compression="zstd")
    return out_path


def _assert_cli_commands(env: dict[str, str], repo_root: Path) -> None:
    """Verify that ``rna3d_local`` exposes every CLI command the pipeline needs.

    Each command is probed with ``python -m rna3d_local <command> --help``.

    Args:
        env: Environment for the probe processes.
        repo_root: Working directory for the probe processes.

    Raises:
        RuntimeError: Via ``_die`` on the first command whose probe exits
            non-zero; the message carries the last 12 output lines.
    """
    where = f"{SCRIPT_LOC}:assert_cli"
    needed = (
        "build-embedding-index",
        "infer-description-family",
        "retrieve-templates-latent",
        "predict-tbm",
        "predict-rnapro-offline",
        "export-submission",
        "check-submission",
    )
    for command in needed:
        probe = subprocess.run(
            [sys.executable, "-m", "rna3d_local", command, "--help"],
            env=env,
            cwd=str(repo_root),
            text=True,
            capture_output=True,
        )
        if probe.returncode == 0:
            continue
        _die("ENV", where, f"comando ausente no pacote rna3d_local: {command}", probe.returncode, _tail(probe.stdout, probe.stderr, 12))


import uuid

# --- Fixed input/output locations ------------------------------------------
comp_root = Path("/kaggle/input/stanford-rna-3d-folding-2")
input_root = Path("/kaggle/input")
work_root = Path("/kaggle/working")
# Every execution gets its own scratch directory (random 8-hex-char suffix).
run_root = work_root / f"run_{SCRIPT_LOC}_{uuid.uuid4().hex[:8]}"
submission_path = work_root / "submission.csv"

# Competition files are mandatory; abort before doing any work if absent.
sample_path = comp_root / "sample_submission.csv"
targets_path = comp_root / "test_sequences.csv"
if not sample_path.exists():
    _die("LOAD", f"{SCRIPT_LOC}:paths", "sample_submission.csv ausente", 1, [str(sample_path)])
if not targets_path.exists():
    _die("LOAD", f"{SCRIPT_LOC}:paths", "test_sequences.csv ausente", 1, [str(targets_path)])

# exist_ok=False: a pre-existing run directory raises instead of being reused.
run_root.mkdir(parents=True, exist_ok=False)

# --- Asset discovery ---------------------------------------------------------
datasets = _collect_datasets(input_root)
src_root = _find_src_root(datasets, run_root / "src_unpack")
template_index_path, templates_path = _find_template_assets(datasets)
template_family_map_path = _build_template_family_map(template_index_path, run_root / "template_family_map.parquet")

# Make the located package importable in this process as well.
repo_root = src_root.parent
if str(src_root) not in sys.path:
    sys.path.insert(0, str(src_root))

# --- Child-process environment ----------------------------------------------
env = os.environ.copy()
# src_root goes first on PYTHONPATH for the CLI subprocesses.
# NOTE(review): the separator is a hard-coded ":" (POSIX); os.pathsep would be portable.
env["PYTHONPATH"] = str(src_root) + ((":" + env["PYTHONPATH"]) if env.get("PYTHONPATH") else "")
# Default each numeric/dataframe thread pool to one thread unless the caller
# already set the variable (setdefault never overrides an existing value).
env.setdefault("OMP_NUM_THREADS", "1")
env.setdefault("MKL_NUM_THREADS", "1")
env.setdefault("OPENBLAS_NUM_THREADS", "1")
env.setdefault("NUMEXPR_NUM_THREADS", "1")
env.setdefault("POLARS_MAX_THREADS", "1")
# Avoid export streaming partition path conflicts in Kaggle runtime; we already keep RAM low upstream.
env["RNA3D_EXPORT_STREAMING_THRESHOLD_BYTES"] = "1000000000000"

# Fail early if the located package lacks any CLI command used below.
_assert_cli_commands(env, repo_root)

print(f"[INFO] [{SCRIPT_LOC}] run_root={run_root}")
print(f"[INFO] [{SCRIPT_LOC}] src_root={src_root}")
print(f"[INFO] [{SCRIPT_LOC}] template_index={template_index_path}")
print(f"[INFO] [{SCRIPT_LOC}] templates={templates_path}")

# --- Intermediate artefact paths (all under run_root) -------------------------
# NOTE(review): retrieval_tbm_path, targets_tbm_path and drfold_path are not
# referenced again in the visible part of this script.
emb_dir = run_root / "embedding"
emb_path = emb_dir / "template_embeddings.parquet"
retrieval_path = run_root / "retrieval_candidates.parquet"
retrieval_tbm_path = run_root / "retrieval_candidates_tbm.parquet"
targets_tbm_path = run_root / "targets_tbm.csv"
targets_fallback_path = run_root / "targets_fallback.csv"
tbm_path = run_root / "tbm_predictions.parquet"
drfold_path = run_root / "drfold_predictions.parquet"
combined_path = run_root / "combined_predictions.parquet"

# Step 1: build the template embedding index into emb_dir.
# emb_path is presumably the file this command produces there — TODO confirm.
_run([sys.executable, "-m", "rna3d_local", "build-embedding-index", "--template-index", str(template_index_path), "--out-dir", str(emb_dir), "--embedding-dim", "256", "--encoder", "mock", "--ann-engine", "none"], env, repo_root)
# Step 2: retrieve the TOP_K template candidates per target. The encoder and
# embedding dimension match step 1; the three score weights sum to 1.0.
_run([
    sys.executable, "-m", "rna3d_local", "retrieve-templates-latent",
    "--template-index", str(template_index_path),
    "--template-embeddings", str(emb_path),
    "--targets", str(targets_path),
    "--out", str(retrieval_path),
    "--top-k", str(TOP_K),
    "--encoder", "mock",
    "--embedding-dim", "256",
    "--ann-engine", "numpy_bruteforce",
    "--weight-embed", "0.90",
    "--weight-llm", "0.00",
    "--weight-seq", "0.10",
], env, repo_root)

# Compute which targets can be covered by at least one contiguous template with length >= target_len.
targets_df = pl.read_csv(targets_path)
if "target_id" not in targets_df.columns or "sequence" not in targets_df.columns:
    _die("LOAD", f"{SCRIPT_LOC}:targets", "targets sem colunas target_id/sequence", 1, [str(targets_df.columns)])
targets_df = targets_df.with_columns(pl.col("target_id").cast(pl.Utf8), pl.col("sequence").cast(pl.Utf8))
# Target length = number of characters once "|" separators are removed.
targets_len = targets_df.select(pl.col("target_id"), pl.col("sequence").str.replace_all(r"\|", "").str.len_chars().alias("target_len"))

# Per-template residue statistics, built lazily from the templates parquet.
tpl_stats = pl.scan_parquet(str(templates_path)).select(pl.col("template_uid").cast(pl.Utf8), pl.col("resid").cast(pl.Int32))
tpl_stats = tpl_stats.group_by("template_uid").agg(pl.col("resid").min().alias("min_resid"), pl.col("resid").max().alias("max_resid"), pl.col("resid").n_unique().alias("n_unique"))
# tpl_len is the span max-min+1; a template is "contiguous" when the number of
# distinct residue ids equals that span (i.e. no gaps in the numbering).
tpl_stats = tpl_stats.with_columns((pl.col("max_resid") - pl.col("min_resid") + 1).cast(pl.Int32).alias("tpl_len"))
tpl_stats = tpl_stats.with_columns((pl.col("n_unique") == pl.col("tpl_len")).alias("contiguous"))

# Distinct (target, template) pairs from retrieval, joined with both length
# tables; a target is TBM-supported if any retrieved template is contiguous
# and at least as long as the target.
retr_lf = pl.scan_parquet(str(retrieval_path)).select(pl.col("target_id").cast(pl.Utf8), pl.col("template_uid").cast(pl.Utf8)).unique()
supported = retr_lf.join(targets_len.lazy(), on="target_id", how="inner").join(tpl_stats, on="template_uid", how="inner")
supported = supported.filter(pl.col("contiguous") & (pl.col("tpl_len") >= pl.col("target_len"))).select(pl.col("target_id")).unique().collect()
supported_ids = set(supported.get_column("target_id").to_list())
all_ids = set(targets_len.get_column("target_id").to_list())
# Every remaining target goes to the phase-2 fallback predictor.
fallback_ids = sorted(all_ids - supported_ids)
print(f"[INFO] [{SCRIPT_LOC}] tbm_supported={len(supported_ids)} fallback={len(fallback_ids)}")
# The TBM stage below needs at least one supported target.
if len(supported_ids) == 0:
    _die("PIPELINE", f"{SCRIPT_LOC}:coverage", "nenhum alvo com template valido para TBM", 0, [])

# Write the fallback targets CSV (written even when fallback_ids is empty).
# NOTE(review): the fallback section further down rewrites this same file.
targets_df.filter(pl.col("target_id").is_in(fallback_ids)).write_csv(targets_fallback_path)

# Run TBM in small chunks to cap peak RAM in hidden reruns.
supported_ids_sorted = sorted(supported_ids)
tbm_parts_dir = run_root / "tbm_parts"
tbm_parts_dir.mkdir(parents=True, exist_ok=True)
tbm_part_paths: list[Path] = []
for chunk_start in range(0, len(supported_ids_sorted), int(TBM_CHUNK_SIZE)):
    chunk_ids = supported_ids_sorted[chunk_start : chunk_start + int(TBM_CHUNK_SIZE)]
    if not chunk_ids:
        continue
    # Zero-padded chunk index used in the per-chunk file names.
    chunk_tag = f"{chunk_start // int(TBM_CHUNK_SIZE):04d}"
    targets_chunk_path = tbm_parts_dir / f"targets_tbm_{chunk_tag}.csv"
    retrieval_chunk_path = tbm_parts_dir / f"retrieval_tbm_{chunk_tag}.parquet"
    tbm_chunk_path = tbm_parts_dir / f"tbm_predictions_{chunk_tag}.parquet"

    # Materialise this chunk's inputs: its target rows and retrieval candidates.
    targets_df.filter(pl.col("target_id").is_in(chunk_ids)).write_csv(targets_chunk_path)
    pl.scan_parquet(str(retrieval_path)).filter(pl.col("target_id").is_in(chunk_ids)).sink_parquet(str(retrieval_chunk_path), engine="streaming")

    # One predict-tbm subprocess per chunk, requesting N_MODELS models.
    _run([
        sys.executable,
        "-m",
        "rna3d_local",
        "predict-tbm",
        "--retrieval",
        str(retrieval_chunk_path),
        "--templates",
        str(templates_path),
        "--targets",
        str(targets_chunk_path),
        "--out",
        str(tbm_chunk_path),
        "--n-models",
        str(N_MODELS),
    ], env, repo_root)
    # A zero exit code is not trusted on its own: the output file must exist.
    if not tbm_chunk_path.exists():
        _die("PIPELINE", f"{SCRIPT_LOC}:tbm_chunk", "predict-tbm sem saida no chunk", 1, [str(tbm_chunk_path)])
    tbm_part_paths.append(tbm_chunk_path)

if not tbm_part_paths:
    _die("PIPELINE", f"{SCRIPT_LOC}:tbm_chunk", "nenhum chunk TBM gerado", 0, [])

# Merge all chunk outputs into a single parquet file via a lazy scan + sink.
pl.scan_parquet([str(p) for p in tbm_part_paths]).sink_parquet(str(tbm_path), engine="streaming")

def _normalize_seq(seq: str) -> str:
    raw = str(seq or "").strip().upper().replace("T", "U")
    return "".join(ch for ch in raw if ch not in {"|", " ", "\t", "\n", "\r"})


def _assert_target_model_coverage(predictions_path: Path, target_id: str, *, n_models: int, target_len: int) -> None:
    """Check that fallback predictions fully cover one target for every model.

    The target must have exactly the model ids ``1..n_models``. For each
    model the distinct residue ids must number ``target_len`` and span
    exactly ``1..target_len``.

    Args:
        predictions_path: Parquet file with ``target_id``, ``model_id`` and
            ``resid`` columns.
        target_id: Target to validate.
        n_models: Expected number of models.
        target_len: Expected number of residues; must be positive.

    Raises:
        RuntimeError: Via ``_die`` when ``target_len`` is not positive, the
            target has no rows, the model ids differ from the expected set,
            or any model has incomplete residue coverage.
    """
    where = f"{SCRIPT_LOC}:fallback_coverage"
    if target_len <= 0:
        _die("FALLBACK", where, "target_len invalido", 1, [f"{target_id}:{target_len}"])

    rows = (
        pl.scan_parquet(str(predictions_path))
        .filter(pl.col("target_id") == target_id)
        .select(
            pl.col("model_id").cast(pl.Int32),
            pl.col("resid").cast(pl.Int32),
        )
        .collect(streaming=True)
    )
    if rows.height <= 0:
        _die("FALLBACK", where, "fallback sem linhas para target", 1, [target_id, str(predictions_path)])

    # One summary row per model: distinct residue count plus the id range.
    per_model = rows.group_by("model_id").agg(
        pl.col("resid").n_unique().alias("n_unique"),
        pl.col("resid").min().alias("min_resid"),
        pl.col("resid").max().alias("max_resid"),
    )
    stats = per_model.sort("model_id")

    mids = sorted(stats.get_column("model_id").to_list())
    expected_mids = list(range(1, int(n_models) + 1))
    if mids != expected_mids:
        _die("FALLBACK", where, "fallback sem model_id esperado", 1, [f"{target_id}:got={mids}:expected={expected_mids}"])

    expected_len = int(target_len)
    incomplete = (
        (pl.col("n_unique") != expected_len)
        | (pl.col("min_resid") != 1)
        | (pl.col("max_resid") != expected_len)
    )
    bad = stats.filter(incomplete)
    if bad.height == 0:
        return

    # Render up to 8 offenders as "model:n_unique:min_resid:max_resid".
    summary = (
        pl.col("model_id").cast(pl.Utf8)
        + pl.lit(":")
        + pl.col("n_unique").cast(pl.Utf8)
        + pl.lit(":")
        + pl.col("min_resid").cast(pl.Utf8)
        + pl.lit(":")
        + pl.col("max_resid").cast(pl.Utf8)
    ).alias("k")
    examples = bad.select(summary).head(8).get_column("k").to_list()
    _die("FALLBACK", where, "fallback com cobertura insuficiente", int(bad.height), [f"{target_id}:{x}" for x in examples])


# Phase-2 fallback for targets without strict TBM coverage.
if fallback_ids:
    # Locate the phase-2 model assets. The second argument is accepted by the
    # helper but not used by it (nothing is copied into run_root).
    phase2_assets_root = _materialize_phase2_assets(datasets, run_root / "phase2_assets")
    rnapro_model_dir = phase2_assets_root / "models" / "rnapro"
    if not rnapro_model_dir.exists():
        _die("FALLBACK", f"{SCRIPT_LOC}:rnapro", "model dir RNApro ausente", 1, [str(rnapro_model_dir)])

    # Rewrite the fallback targets CSV, this time sorted by target_id.
    fallback_targets_df = targets_df.filter(pl.col("target_id").is_in(fallback_ids)).sort("target_id")
    fallback_targets_df.write_csv(targets_fallback_path)

    fallback_pred_path = run_root / "fallback_rnapro_predictions.parquet"
    _run([
        sys.executable,
        "-m",
        "rna3d_local",
        "predict-rnapro-offline",
        "--model-dir",
        str(rnapro_model_dir),
        "--targets",
        str(targets_fallback_path),
        "--out",
        str(fallback_pred_path),
        "--n-models",
        str(N_MODELS),
    ], env, repo_root)

    # Validate every fallback target: all N_MODELS models must cover residues
    # 1..len(normalised sequence). Any gap aborts the run.
    for tid in fallback_ids:
        # .item() requires exactly one row for this target_id.
        seq_raw = targets_df.filter(pl.col("target_id") == tid).get_column("sequence").item()
        seq = _normalize_seq(seq_raw)
        _assert_target_model_coverage(fallback_pred_path, tid, n_models=int(N_MODELS), target_len=len(seq))

    # Stack TBM and fallback predictions; "vertical_relaxed" lets polars
    # coerce differing column dtypes to a common supertype.
    pl.concat([pl.scan_parquet(str(tbm_path)), pl.scan_parquet(str(fallback_pred_path))], how="vertical_relaxed").sink_parquet(str(combined_path), engine="streaming")
else:
    # Every target was TBM-supported: the combined file is just the TBM output.
    pl.scan_parquet(str(tbm_path)).sink_parquet(str(combined_path), engine="streaming")


def _center_predictions_long(predictions_path: Path, out_path: Path) -> None:
    """Centre each (target, model) structure on its coordinate mean.

    For every ``(target_id, model_id)`` group the mean of ``x``, ``y`` and
    ``z`` is subtracted from the respective coordinate. The result is written
    with a fixed column order and dtypes.

    Args:
        predictions_path: Long-format parquet with columns ``target_id``,
            ``model_id``, ``resid``, ``resname``, ``x``, ``y``, ``z``.
        out_path: Destination parquet file.

    Raises:
        RuntimeError: Via ``_die`` when a required column is missing.
    """
    where = f"{SCRIPT_LOC}:center_predictions"
    axes = ("x", "y", "z")
    keys = ["target_id", "model_id"]

    source = pl.scan_parquet(str(predictions_path))
    present = set(source.collect_schema().names())
    absent = sorted({"target_id", "model_id", "resid", "resname", "x", "y", "z"} - present)
    if absent:
        _die("EXPORT", where, "predictions sem colunas obrigatorias para centralizacao", len(absent), absent)

    # Per-group centroid stored in helper columns _mx/_my/_mz.
    centroid = source.group_by(keys).agg(
        [pl.col(axis).cast(pl.Float64).mean().alias(f"_m{axis}") for axis in axes]
    )
    shifted = source.join(centroid, on=keys, how="left").with_columns(
        [(pl.col(axis).cast(pl.Float64) - pl.col(f"_m{axis}")).alias(axis) for axis in axes]
    )
    # The final projection drops the helper columns and pins the dtypes.
    result = shifted.select(
        pl.col("target_id").cast(pl.Utf8),
        pl.col("model_id").cast(pl.Int32),
        pl.col("resid").cast(pl.Int32),
        pl.col("resname").cast(pl.Utf8),
        pl.col("x").cast(pl.Float64),
        pl.col("y").cast(pl.Float64),
        pl.col("z").cast(pl.Float64),
    )
    result.sink_parquet(str(out_path), engine="streaming")


def _coord_problem(raw: str, abs_max: float) -> str | None:
    """Classify one coordinate cell; return a problem tag or ``None`` if valid."""
    try:
        val = float(raw)
    except ValueError:
        # float() on a str raises only ValueError for unparsable text.
        return "non-numeric"
    if val != val:
        # NaN is the only float that differs from itself.
        return "nan"
    if abs(val) > abs_max:
        # Also catches +/-inf.
        return f"abs>{abs_max:g}"
    return None


def _assert_submission_coord_bounds(sub_path: Path, *, abs_max: float = 1000.0) -> None:
    """Validate every coordinate cell of a submission CSV.

    Coordinate columns are those whose header starts with ``x_``, ``y_`` or
    ``z_``. Each cell must parse as a float, must not be NaN, and must
    satisfy ``abs(value) <= abs_max``.

    Args:
        sub_path: Submission CSV to validate.
        abs_max: Largest allowed absolute coordinate; must be positive.

    Raises:
        RuntimeError: Via ``_die`` when ``abs_max`` is not a positive number,
            the file has no header, a row's width differs from the header,
            or any coordinate is invalid. Scanning stops as soon as 8
            invalid cells have been collected.
    """
    where = f"{SCRIPT_LOC}:submission_bounds"
    # Written as "not >" so a NaN bound is rejected as well.
    if not (abs_max > 0):
        _die("CHECK", where, "abs_max invalido", 1, [str(abs_max)])
    max_reported = 8
    with sub_path.open("r", encoding="utf-8", newline="") as handle:
        reader = csv.reader(handle)
        header = next(reader, None)
        if not header:
            _die("CHECK", where, "submission csv vazio", 1, [str(sub_path)])
        coord_idxs = [idx for idx, name in enumerate(header) if name.startswith(("x_", "y_", "z_"))]
        bad: list[str] = []
        for row_idx, row in enumerate(reader, start=1):
            if len(row) != len(header):
                _die("CHECK", where, "linha com numero de colunas invalido", 1, [f"row={row_idx}"])
            for cidx in coord_idxs:
                problem = _coord_problem(row[cidx], abs_max)
                if problem is None:
                    continue
                bad.append(f"{header[cidx]}@{row_idx}:{problem}")
                # Abort early: a handful of examples is enough to diagnose.
                if len(bad) >= max_reported:
                    _die("CHECK", where, "coordenadas invalidas na submission", len(bad), bad)
        if bad:
            _die("CHECK", where, "coordenadas invalidas na submission", len(bad), bad)


def _parse_target_resid(id_value: str) -> tuple[str, int]:
    """Split a submission ID of the form ``<target>_<resid>``.

    The split happens at the *last* underscore, so target ids may themselves
    contain underscores.

    Args:
        id_value: Raw ID cell; a falsy value is treated as the empty string.

    Returns:
        ``(target_id, resid)`` with ``resid`` parsed as an integer.

    Raises:
        RuntimeError: Via ``_die`` when the ID has no underscore or when the
            part after the last underscore is not an integer.
    """
    where = f"{SCRIPT_LOC}:parse_id"
    key = str(id_value or "")
    if "_" not in key:
        _die("EXPORT", where, "ID invalido (esperado <target>_<resid>)", 1, [key])
    target_id, resid_str = key.rsplit("_", 1)
    try:
        resid = int(resid_str)
    except ValueError:
        # int() on a str can only fail with ValueError; catching anything
        # broader would mask unrelated bugs.
        _die("EXPORT", where, "ID invalido (resid nao-inteiro)", 1, [key])
    return target_id, resid


def _model_ids_from_sample_header(header: list[str]) -> list[int]:
    """Extract the model ids declared by the sample submission header.

    Model ids are taken from the ``x_<id>`` columns; each id must also have
    matching ``y_<id>`` and ``z_<id>`` columns.

    Args:
        header: Column names of the sample submission.

    Returns:
        The model ids in ascending order.

    Raises:
        RuntimeError: Via ``_die`` when an ``x_`` column has a non-integer
            suffix, when there is no ``x_`` column at all, or when a
            ``y_``/``z_`` companion column is missing.
    """
    where = f"{SCRIPT_LOC}:sample_header"
    mids: list[int] = []
    for name in header:
        if not name.startswith("x_"):
            continue
        suffix = name.split("_", 1)[1]
        try:
            mids.append(int(suffix))
        except ValueError:
            # Report malformed columns through the pipeline's structured
            # error instead of leaking a bare ValueError.
            _die("EXPORT", where, "sample com coluna de modelo invalida", 1, [name])
    mids.sort()
    if not mids:
        _die("EXPORT", where, "sample sem colunas de modelo", 1, header[:8])
    for mid in mids:
        for pref in ("y_", "z_"):
            col = f"{pref}{mid}"
            if col not in header:
                _die("EXPORT", where, "sample sem coluna obrigatoria de modelo", 1, [col])
    return mids


def _load_target_coords(predictions_path: Path, target_id: str, model_ids: list[int]) -> dict[int, dict[int, tuple[float, float, float]]]:
    """Load one target's coordinates as ``{resid: {model_id: (x, y, z)}}``.

    Args:
        predictions_path: Long-format parquet with ``target_id``,
            ``model_id``, ``resid``, ``x``, ``y`` and ``z`` columns.
        target_id: Target whose rows are loaded.
        model_ids: Model ids to keep; rows of other models are ignored.

    Returns:
        Nested mapping from residue id to model id to an ``(x, y, z)`` tuple.

    Raises:
        RuntimeError: Via ``_die`` when the target has no rows, when a
            ``(model_id, resid)`` key occurs more than once, or when a kept
            coordinate is null, NaN or infinite.
    """
    import math  # local import keeps this block free of new file-level deps

    where = f"{SCRIPT_LOC}:load_target_coords"
    df = (
        pl.scan_parquet(str(predictions_path))
        .filter(pl.col("target_id") == target_id)
        .select(
            pl.col("model_id").cast(pl.Int32),
            pl.col("resid").cast(pl.Int32),
            pl.col("x").cast(pl.Float64),
            pl.col("y").cast(pl.Float64),
            pl.col("z").cast(pl.Float64),
        )
        .collect(streaming=True)
    )
    if df.height <= 0:
        _die("EXPORT", where, "predictions sem target_id", 1, [target_id])

    # Duplicate keys are checked over *all* models, including ignored ones.
    dup = df.group_by(["model_id", "resid"]).agg(pl.len().alias("n")).filter(pl.col("n") > 1)
    if dup.height > 0:
        ex = (
            dup.select((pl.col("model_id").cast(pl.Utf8) + pl.lit(":") + pl.col("resid").cast(pl.Utf8)).alias("k"))
            .head(8)
            .get_column("k")
            .to_list()
        )
        _die("EXPORT", where, "predictions com chave duplicada no target", int(dup.height), [f"{target_id}:{x}" for x in ex])

    # Set membership instead of a list scan per row.
    wanted = set(model_ids)
    coords: dict[int, dict[int, tuple[float, float, float]]] = {}
    for row in df.iter_rows(named=True):
        mid = int(row["model_id"])
        if mid not in wanted:
            continue
        resid = int(row["resid"])
        xyz = (row["x"], row["y"], row["z"])
        # The error message promises a *finiteness* check, so reject nulls
        # and +/-inf as well as NaN (a plain ``v == v`` test lets inf through).
        if any(value is None or not math.isfinite(value) for value in xyz):
            _die("EXPORT", where, "coordenadas nao-finitas nas predictions", 1, [f"{target_id}:{mid}:{resid}"])
        per_resid = coords.setdefault(resid, {})
        if mid in per_resid:
            _die("EXPORT", where, "predictions com chave duplicada no target", 1, [f"{target_id}:{mid}:{resid}"])
        per_resid[mid] = (float(xyz[0]), float(xyz[1]), float(xyz[2]))
    return coords


def _export_submission_streaming_local(sample_csv: Path, predictions_path: Path, out_csv: Path) -> None:
    """Fill the sample submission with predicted coordinates, row by row.

    The sample CSV dictates row order and columns. For each row the target
    and residue are parsed from ``ID``, cross-checked against the ``resid``
    column, and the ``x_/y_/z_<model>`` cells are overwritten with the
    prediction. Coordinates are loaded one target at a time and reloaded
    whenever the target id changes between consecutive rows.

    Args:
        sample_csv: Sample submission providing header and row order.
        predictions_path: Long-format predictions parquet.
        out_csv: Destination CSV (opened for writing immediately).

    Raises:
        RuntimeError: Via ``_die`` when the sample is empty or malformed, or
            when the predictions lack a target, residue or model.
    """
    where = f"{SCRIPT_LOC}:export_streaming_local"
    with sample_csv.open("r", encoding="utf-8", newline="") as src, out_csv.open("w", encoding="utf-8", newline="") as dst:
        sample_rows = csv.DictReader(src)
        if not sample_rows.fieldnames:
            _die("EXPORT", where, "sample vazio", 1, [str(sample_csv)])
        header = list(sample_rows.fieldnames)
        if not {"ID", "resid"}.issubset(header):
            _die("EXPORT", where, "sample sem colunas obrigatorias ID/resid", 1, header[:8])
        model_ids = _model_ids_from_sample_header(header)

        out_rows = csv.DictWriter(dst, fieldnames=header)
        out_rows.writeheader()

        # Cache of the most recently loaded target's coordinates.
        loaded_target = None
        loaded_coords = None
        for row_idx, record in enumerate(sample_rows, start=1):
            tid, resid = _parse_target_resid(str(record.get("ID", "")))
            resid_text = str(record.get("resid", "")).strip()
            try:
                resid_col = int(resid_text)
            except Exception:
                _die("EXPORT", where, "resid invalido no sample", 1, [f"row={row_idx}"])
            if resid_col != resid:
                _die("EXPORT", where, "sample com resid divergente do ID", 1, [f"{tid}_{resid}:resid={resid_col}"])

            if tid != loaded_target:
                loaded_target = tid
                loaded_coords = _load_target_coords(predictions_path, tid, model_ids)

            assert loaded_coords is not None
            by_model = loaded_coords.get(resid)
            if by_model is None:
                _die("EXPORT", where, "predictions sem resid para alvo", 1, [f"{tid}:{resid}"])

            for mid in model_ids:
                xyz = by_model.get(mid)
                if xyz is None:
                    _die("EXPORT", where, "predictions sem model_id para resid", 1, [f"{tid}:{mid}:{resid}"])
                # Overwrite the x_/y_/z_ cells of this model with the prediction.
                for axis, value in zip("xyz", xyz):
                    record[f"{axis}_{mid}"] = f"{value}"
            out_rows.writerow(record)


# --- Final export -----------------------------------------------------------
# The submission is first built inside run_root and only copied to its final
# location after both validation steps below have passed.
centered_path = run_root / "combined_predictions_centered.parquet"
submission_tmp_path = run_root / "submission.csv"
# 1) Centre every (target, model) structure on its coordinate mean.
_center_predictions_long(combined_path, centered_path)
# 2) Fill the sample submission layout with the centred coordinates.
_export_submission_streaming_local(sample_path, centered_path, submission_tmp_path)
# 3) Validate against the sample with the package CLI, then check that all
#    coordinates are numeric and within +/-1000.
_run([sys.executable, "-m", "rna3d_local", "check-submission", "--sample", str(sample_path), "--submission", str(submission_tmp_path)], env, repo_root)
_assert_submission_coord_bounds(submission_tmp_path, abs_max=1000.0)
# 4) Publish the validated file as /kaggle/working/submission.csv.
shutil.copyfile(submission_tmp_path, submission_path)

# Log a SHA-256 digest of the published file.
sha = hashlib.sha256(submission_path.read_bytes()).hexdigest()
print(f"[DONE] [{SCRIPT_LOC}] submission={submission_path}")
print(f"[INFO] [{SCRIPT_LOC}] submission_sha256={sha}")
