# Stanford RNA3D submit notebook (TBM + DRfold2 risk router v77-alt)


In [None]:
import csv
import hashlib
import math
import os
import shutil
import subprocess
import sys
import time
from pathlib import Path

# Tag embedded in every log line and error "where" field emitted by this script.
SCRIPT_LOC = "submission_notebook_dynamic_tbm_drfold2_router_v77_alt_risknorm"

# Config (explicit; no silent fallback)
# Enables the DRfold2 branch; when True, DRfold_infer.py becomes a required asset.
USE_DRFOLD2 = True
# Upper bound on how many targets are routed to DRfold2 (must be > 0).
DRFOLD2_MAX_TARGETS = 6
# Only targets whose best retrieval similarity is strictly below this are eligible (must be in [0, 1]).
DRFOLD2_SIMILARITY_THRESHOLD = 0.62
# Only targets whose sequence length is <= this are eligible; also the length-ratio denominator in the risk score.
DRFOLD2_MAX_SEQ_LEN = 900
DRFOLD2_N_MODELS_RUN = 1  # run DRfold2 once per target (faster)
# Number of models per residue in the final submission; the DRfold2 branch requires exactly 5.
FINAL_N_MODELS = 5
# Absolute bound applied to centered coordinates and re-checked before the final copy.
SUBMISSION_ABS_CLIP = 900.0


def _die(stage: str, where: str, cause: str, impact, examples) -> None:
    raise RuntimeError(f"[{stage}] [{where}] {cause} | impacto={impact} | exemplos={examples}")


def _tail(stdout: str | None, stderr: str | None, n: int = 20) -> list[str]:
    text = ((stdout or "") + "\n" + (stderr or "")).strip()
    if not text:
        return []
    return text.splitlines()[-n:]


def _run(cmd: list[str], env: dict, cwd: Path) -> None:
    """Run ``cmd`` as a subprocess, echoing the command and its stdout.

    Output is captured. On a non-zero exit code the pipeline is aborted via
    ``_die`` with the last 20 lines of combined output; on success the
    stripped stdout is printed if non-empty.
    """
    print("[RUN]", " ".join(cmd))
    result = subprocess.run(cmd, env=env, text=True, capture_output=True, cwd=str(cwd))
    if result.returncode != 0:
        _die(
            "PIPELINE",
            f"{SCRIPT_LOC}:run",
            "comando falhou",
            result.returncode,
            _tail(result.stdout, result.stderr, 20),
        )
    captured = (result.stdout or "").strip()
    if captured:
        print(captured)


def _supports(subcmd: str, flag: str, env: dict, cwd: Path) -> bool:
    """Report whether the ``--help`` text of ``rna3d_local <subcmd>`` mentions ``flag``.

    The check is a plain substring test over stdout and stderr of the help
    invocation. A failing help invocation aborts via ``_die``.
    """
    help_cmd = [sys.executable, "-m", "rna3d_local", subcmd, "--help"]
    result = subprocess.run(help_cmd, env=env, text=True, capture_output=True, cwd=str(cwd))
    if result.returncode != 0:
        _die(
            "PIPELINE",
            f"{SCRIPT_LOC}:supports",
            f"falha ao consultar help de {subcmd}",
            result.returncode,
            _tail(result.stdout, result.stderr, 12),
        )
    help_text = "\n".join([result.stdout or "", result.stderr or ""])
    return flag in help_text


def _ensure_biopython(wheel_dir: Path | None):
    """Make sure Biopython is importable, installing it offline if needed.

    If ``import Bio`` already works nothing happens. Otherwise a
    ``biopython-*.whl`` must exist in ``wheel_dir``; it is installed with
    ``pip --no-index --find-links``. Any missing prerequisite or a failed
    install aborts via ``_die``.
    """
    where = f"{SCRIPT_LOC}:ensure_biopython"

    try:
        import Bio  # noqa: F401
    except Exception:
        already_available = False
    else:
        already_available = True
    if already_available:
        return

    if wheel_dir is None:
        _die("ENV", where, "Biopython ausente e diretorio de wheels nao encontrado", 1, [])

    # Require at least one matching wheel before invoking pip.
    if next(iter(wheel_dir.glob("biopython-*.whl")), None) is None:
        _die("ENV", where, "wheel biopython nao encontrado", 1, [str(wheel_dir)])

    pip_cmd = [sys.executable, "-m", "pip", "install"]
    pip_cmd += ["--no-index", "--find-links", str(wheel_dir)]
    pip_cmd.append("biopython")
    result = subprocess.run(pip_cmd, capture_output=True, text=True)
    if result.returncode != 0:
        _die(
            "ENV",
            where,
            "falha ao instalar biopython local",
            result.returncode,
            _tail(result.stdout, result.stderr, 12),
        )


def _discover_assets(input_root: Path) -> dict[str, Path | None]:
    """Locate the required assets among the dataset directories under ``input_root``.

    Directories are scanned in sorted order and, for each asset, the first
    directory containing its marker path wins. Returns a dict with keys
    ``src_root``, ``template_index``, ``templates``, ``wheel_dir`` and
    ``drfold2_root``. ``wheel_dir`` is optional; ``drfold2_root`` is required
    only when ``USE_DRFOLD2`` is set. Missing required assets abort via ``_die``.
    """
    where = f"{SCRIPT_LOC}:discover_assets"
    if not input_root.exists():
        _die("LOAD", where, "diretorio /kaggle/input ausente", 1, [str(input_root)])

    mounted_dirs = sorted(entry for entry in input_root.iterdir() if entry.is_dir())
    if not mounted_dirs:
        _die("LOAD", where, "nenhum dataset montado em /kaggle/input", 0, [])

    template_db = Path("runs") / "20260211_real_kaggle_baseline_full_v2" / "template_db"
    # (asset key, marker path relative to a dataset dir, value recorded when the marker exists)
    probes = (
        ("src_root", Path("src") / "rna3d_local" / "cli.py", lambda base: base / "src"),
        (
            "template_index",
            template_db / "template_index.parquet",
            lambda base: base / template_db / "template_index.parquet",
        ),
        (
            "templates",
            template_db / "templates.parquet",
            lambda base: base / template_db / "templates.parquet",
        ),
        ("wheel_dir", Path("wheels"), lambda base: base / "wheels"),
        ("drfold2_root", Path("DRfold_infer.py"), lambda base: base),
    )

    found: dict[str, Path | None] = {key: None for key, _marker, _resolve in probes}
    for dataset in mounted_dirs:
        for key, marker, resolve in probes:
            if found[key] is None and (dataset / marker).exists():
                found[key] = resolve(dataset)

    always_required = (
        ("src_root", "src/rna3d_local"),
        ("template_index", "template_index.parquet"),
        ("templates", "templates.parquet"),
    )
    missing = [label for key, label in always_required if found[key] is None]
    if USE_DRFOLD2 and found["drfold2_root"] is None:
        missing.append("DRfold_infer.py")

    if missing:
        mounted = [entry.name for entry in mounted_dirs]
        _die(
            "LOAD",
            where,
            "ativos obrigatorios ausentes em /kaggle/input",
            len(missing),
            missing + [f"mounted={mounted}"],
        )

    return found


def _export_submission_strict_from_long(predictions_path: Path, sample_path: Path, out_path: Path) -> None:
    """Convert long-format predictions into a wide CSV shaped like the sample submission.

    Reads the sample CSV and the predictions parquet (one row per ID/model_id
    with x, y, z), pivots each axis into ``<axis>_<model_id>`` columns and
    writes the result with exactly the sample's columns and row order.

    Aborts via ``_die`` when: required columns are missing, ``model_id`` is not
    numeric, (ID, model_id) pairs are duplicated, the ID sets of predictions and
    sample differ, the coordinate column sets differ, or the final table
    contains nulls in any coordinate column.
    """
    where = f"{SCRIPT_LOC}:export_strict"
    import pandas as pd

    sample_df = pd.read_csv(sample_path)
    pred_df = pd.read_parquet(predictions_path)

    required = {"ID", "model_id", "x", "y", "z"}
    miss = sorted(required - set(pred_df.columns))
    if miss:
        _die("EXPORT", where, "predicoes long sem colunas obrigatorias", len(miss), miss)

    # Normalize key dtypes: ID as string, model_id as int (non-numeric values are rejected).
    pred_df["ID"] = pred_df["ID"].astype(str)
    pred_df["model_id"] = pd.to_numeric(pred_df["model_id"], errors="coerce")
    if pred_df["model_id"].isna().any():
        bad = pred_df.loc[pred_df["model_id"].isna(), "ID"].head(8).tolist()
        _die("EXPORT", where, "model_id invalido em predicoes", int(pred_df["model_id"].isna().sum()), bad)
    pred_df["model_id"] = pred_df["model_id"].astype(int)

    # The pivot below needs unique (ID, model_id) pairs; fail with examples otherwise.
    dups = pred_df.duplicated(["ID", "model_id"])
    if dups.any():
        bad = pred_df.loc[dups, ["ID", "model_id"]].head(8).astype(str).agg(':'.join, axis=1).tolist()
        _die("EXPORT", where, "predicoes duplicadas por ID/model_id", int(dups.sum()), bad)

    # Prediction IDs must match the sample IDs exactly (no missing, no extra).
    key_sample = set(sample_df["ID"].astype(str).tolist())
    key_pred = set(pred_df["ID"].astype(str).tolist())
    missing = sorted(key_sample - key_pred)
    extra = sorted(key_pred - key_sample)
    if missing or extra:
        _die("EXPORT", where, "chaves de predicao nao batem com sample", f"missing={len(missing)} extra={len(extra)}", missing[:4] + extra[:4])

    # Start from the sample's identity columns and attach one pivoted block per axis.
    wide = sample_df[["ID", "resname", "resid"]].copy()
    for axis in ("x", "y", "z"):
        piv = pred_df.pivot(index="ID", columns="model_id", values=axis)
        piv.columns = [f"{axis}_{int(c)}" for c in piv.columns]
        piv = piv.reset_index()
        wide = wide.merge(piv, on="ID", how="left")

    # The produced coordinate columns must be exactly those of the sample.
    expected_cols = sample_df.columns.tolist()
    expected_pred_cols = [c for c in expected_cols if c not in ("ID", "resname", "resid")]
    got_pred_cols = [c for c in wide.columns.tolist() if c not in ("ID", "resname", "resid")]
    miss_pred = sorted(set(expected_pred_cols) - set(got_pred_cols))
    extra_pred = sorted(set(got_pred_cols) - set(expected_pred_cols))
    if miss_pred or extra_pred:
        _die("EXPORT", where, "colunas de predicao divergentes do sample", f"missing={len(miss_pred)} extra={len(extra_pred)}", miss_pred[:4] + extra_pred[:4])

    # Re-anchor on the sample's rows and enforce the sample's column order.
    wide = sample_df[["ID", "resname", "resid"]].merge(wide, on=["ID", "resname", "resid"], how="left")
    wide = wide.reindex(columns=expected_cols)

    pred_cols = [c for c in expected_cols if c not in ("ID", "resname", "resid")]
    if wide[pred_cols].isna().any().any():
        bad = wide.loc[wide[pred_cols].isna().any(axis=1), "ID"].head(8).tolist()
        _die("EXPORT", where, "submissao final contem nulos", int(wide[pred_cols].isna().any(axis=1).sum()), bad)

    out_path.parent.mkdir(parents=True, exist_ok=True)
    wide.to_csv(out_path, index=False)


def _prepare_drfold2_runtime(*, drfold2_input_root: Path, runtime_root: Path, env: dict) -> Path:
    """Stage a writable copy of the DRfold2 dataset and rebuild its Arena binary.

    Every entry of ``drfold2_input_root`` is copied into ``runtime_root``,
    except ``model_hub`` which is symlinked (falling back to a copy if the
    symlink cannot be created). ``make -B`` is then run inside ``Arena``.

    Returns ``runtime_root``. Aborts via ``_die`` if the inference script is
    missing, ``runtime_root`` is non-empty, ``Arena/Makefile`` is missing, the
    build fails, or the ``Arena`` binary is absent after the build.
    """
    where = f"{SCRIPT_LOC}:prepare_drfold2_runtime"

    entry_script = drfold2_input_root / "DRfold_infer.py"
    if not entry_script.exists():
        _die("DRFOLD2", where, "DRfold_infer.py ausente no dataset", 1, [str(entry_script)])
    if runtime_root.exists() and any(runtime_root.iterdir()):
        _die("DRFOLD2", where, "runtime_root nao-vazio; use novo dir", 1, [str(runtime_root)])
    runtime_root.mkdir(parents=True, exist_ok=True)

    # /kaggle/input is read-only, so code and config go to a writable runtime dir.
    for entry in sorted(drfold2_input_root.iterdir()):
        target = runtime_root / entry.name
        if entry.name != "model_hub":
            if entry.is_dir():
                shutil.copytree(entry, target)
            else:
                shutil.copy2(entry, target)
            continue
        # model_hub is linked rather than copied; copy only if linking fails.
        try:
            target.symlink_to(entry, target_is_directory=True)
            print(f"[INFO] [{SCRIPT_LOC}] drfold2_model_hub_symlink dst={target} src={entry}")
        except Exception as err:
            print(f"[WARN] [{SCRIPT_LOC}] drfold2_model_hub_symlink_failed; copying | err={err!r}")
            shutil.copytree(entry, target)

    arena_dir = runtime_root / "Arena"
    if not (arena_dir / "Makefile").exists():
        _die("DRFOLD2", where, "Arena/Makefile ausente", 1, [str(arena_dir)])

    # Force a full rebuild inside Kaggle to avoid binary incompatibilities.
    build = subprocess.run(["make", "-B"], cwd=str(arena_dir), env=env, text=True, capture_output=True)
    if build.returncode != 0:
        _die("DRFOLD2", where, "falha ao compilar Arena", build.returncode, _tail(build.stdout, build.stderr, 20))

    arena_bin = arena_dir / "Arena"
    if not arena_bin.exists():
        _die("DRFOLD2", where, "Arena compilado mas binario nao encontrado", 1, [str(arena_bin)])

    return runtime_root


def _load_target_lengths(targets_csv: Path) -> dict[str, int]:
    """Map each ``target_id`` in ``targets_csv`` to its sequence length.

    Sequences are stripped and upper-cased, and ``|`` characters are removed
    before measuring. Aborts via ``_die`` on a missing header column, an empty
    ``target_id``, an empty sequence, or a file with no data rows.
    """
    where = f"{SCRIPT_LOC}:target_lengths"
    lengths: dict[str, int] = {}
    with targets_csv.open("r", encoding="utf-8", newline="") as handle:
        reader = csv.DictReader(handle)
        header = reader.fieldnames
        if header is None or not {"target_id", "sequence"}.issubset(header):
            _die("LOAD", where, "targets csv sem colunas esperadas", 1, [str(header)])
        for record in reader:
            target_id = str(record.get("target_id") or "").strip()
            sequence = str(record.get("sequence") or "").strip().upper().replace("|", "")
            if not target_id:
                _die("LOAD", where, "target_id vazio em test_sequences", 1, [str(record)])
            if not sequence:
                _die("LOAD", where, "sequence vazia em test_sequences", 1, [target_id])
            lengths[target_id] = len(sequence)
    if not lengths:
        _die("LOAD", where, "test_sequences vazio", 0, [str(targets_csv)])
    return lengths


def _select_drfold2_targets_by_risk(*, retrieval_path: Path, targets_csv: Path, out_ids_path: Path) -> list[str]:
    """Pick the targets to route to DRfold2, ranked by a risk score.

    A target is eligible when its maximum retrieval similarity is strictly
    below ``DRFOLD2_SIMILARITY_THRESHOLD`` and its sequence length is at most
    ``DRFOLD2_MAX_SEQ_LEN``. Eligible targets are scored as
    ``10 * (threshold - max_similarity) + length / max_len`` and the top
    ``DRFOLD2_MAX_TARGETS`` are kept.

    The selected ids are written one per line to ``out_ids_path`` and returned
    (possibly an empty list). Aborts via ``_die`` on invalid configuration,
    missing retrieval columns, or an empty retrieval table.
    """
    where = f"{SCRIPT_LOC}:select_drfold2_risk"
    import polars as pl

    # Validate the module-level configuration before touching any data.
    if DRFOLD2_MAX_TARGETS <= 0:
        _die("DRFOLD2", where, "DRFOLD2_MAX_TARGETS invalido", 1, [str(DRFOLD2_MAX_TARGETS)])
    if DRFOLD2_MAX_SEQ_LEN <= 0:
        _die("DRFOLD2", where, "DRFOLD2_MAX_SEQ_LEN invalido", 1, [str(DRFOLD2_MAX_SEQ_LEN)])
    if not (0.0 <= float(DRFOLD2_SIMILARITY_THRESHOLD) <= 1.0):
        _die("DRFOLD2", where, "DRFOLD2_SIMILARITY_THRESHOLD invalido", 1, [str(DRFOLD2_SIMILARITY_THRESHOLD)])

    # Scan lazily and check the schema without loading the candidates table.
    lf = pl.scan_parquet(str(retrieval_path))
    names = lf.collect_schema().names()
    if "target_id" not in names or "similarity" not in names:
        _die("DRFOLD2", where, "retrieval sem colunas esperadas", 1, ["target_id", "similarity"])

    lengths = _load_target_lengths(targets_csv)
    lengths_df = pl.DataFrame(
        {
            "target_id": list(lengths.keys()),
            "target_len": [int(v) for v in lengths.values()],
        }
    )

    # One row per target holding the best (maximum) similarity among its candidates.
    retrieval_df = (
        lf.group_by("target_id")
        .agg(pl.max("similarity").cast(pl.Float64).alias("retr_max_similarity"))
        .collect(streaming=True)
    )
    if retrieval_df.height <= 0:
        _die("DRFOLD2", where, "retrieval vazio", 0, [str(retrieval_path)])

    # Inner join: targets absent from either table are dropped from consideration.
    # Sort: highest risk first; ties broken by lower similarity, then longer
    # sequence, then target_id ascending, so the selection is deterministic.
    ranked = (
        retrieval_df.join(lengths_df, on="target_id", how="inner")
        .filter(
            (pl.col("retr_max_similarity") < float(DRFOLD2_SIMILARITY_THRESHOLD))
            & (pl.col("target_len") <= int(DRFOLD2_MAX_SEQ_LEN))
        )
        .with_columns(
            (float(DRFOLD2_SIMILARITY_THRESHOLD) - pl.col("retr_max_similarity")).alias("_sim_gap"),
            (pl.col("target_len").cast(pl.Float64) / float(DRFOLD2_MAX_SEQ_LEN)).alias("_len_ratio"),
        )
        .with_columns((pl.col("_sim_gap") * 10.0 + pl.col("_len_ratio")).alias("_risk"))
        .sort(
            by=["_risk", "retr_max_similarity", "target_len", "target_id"],
            descending=[True, False, True, False],
        )
        .limit(int(DRFOLD2_MAX_TARGETS))
    )

    # Persist the selection (an empty file when nothing qualifies) and log a preview.
    ids = [str(x) for x in ranked.get_column("target_id").to_list()]
    out_ids_path.write_text("\n".join(ids) + ("\n" if ids else ""), encoding="utf-8")
    if ids:
        preview = (
            ranked.select("target_id", "retr_max_similarity", "target_len", "_risk")
            .head(8)
            .to_dicts()
        )
        print(f"[INFO] [{SCRIPT_LOC}] drfold2_risk_selected n={len(ids)} ids={ids} preview={preview}")
    else:
        print(
            f"[INFO] [{SCRIPT_LOC}] drfold2_risk_selected n=0 "
            f"threshold={DRFOLD2_SIMILARITY_THRESHOLD} max_len={DRFOLD2_MAX_SEQ_LEN}"
        )
    return ids


def _read_target_sequences_subset(targets_csv: Path, selected_ids: list[str]) -> dict[str, str]:
    """Return the sequences of ``selected_ids`` read from ``targets_csv``.

    Sequences are stripped, upper-cased and have ``|`` characters removed, the
    same normalization ``_load_target_lengths`` applies. Without it, the length
    used for target selection and the per-residue numbering derived from the
    returned sequence could disagree.

    Args:
        targets_csv: CSV with at least ``target_id`` and ``sequence`` columns.
        selected_ids: Target ids to look up (compared after stripping).

    Returns:
        Dict from stripped target id to normalized sequence, in the order of
        ``selected_ids``.

    Raises:
        RuntimeError: via ``_die`` when the header lacks the expected columns,
            a selected target has an empty sequence, or a selected id is not
            present in the file.
    """
    where = f"{SCRIPT_LOC}:read_targets"
    selected: dict[str, str | None] = {str(x).strip(): None for x in selected_ids}
    with targets_csv.open('r', encoding='utf-8', newline='') as f:
        reader = csv.DictReader(f)
        if reader.fieldnames is None or 'target_id' not in reader.fieldnames or 'sequence' not in reader.fieldnames:
            _die('LOAD', where, 'targets csv sem colunas esperadas', 1, [str(reader.fieldnames)])
        for row in reader:
            tid = str(row.get('target_id') or '').strip()
            if tid not in selected:
                continue
            # Drop '|' so the sequence length matches what _load_target_lengths measured.
            seq = str(row.get('sequence') or '').strip().upper().replace('|', '')
            if not seq:
                _die('LOAD', where, 'sequence vazia para target', 1, [tid])
            selected[tid] = seq
    missing = [tid for tid, seq in selected.items() if not seq]
    if missing:
        _die('LOAD', where, 'target_id selecionado ausente em test_sequences', len(missing), missing[:8])
    return {tid: str(seq) for tid, seq in selected.items() if seq}


def _extract_c1prime_coords(pdb_path: Path, target_sequence: str):
    """Read the C1' coordinates of every standard residue in the first PDB model.

    Residues with a non-blank hetero flag are skipped. The number of remaining
    residues must equal ``len(target_sequence)`` and each must carry a ``C1'``
    atom. Returns a list of ``(x, y, z)`` float tuples in residue order; any
    violation aborts via ``_die``.
    """
    where = f"{SCRIPT_LOC}:extract_c1prime"
    from Bio.PDB import PDBParser

    if not pdb_path.exists():
        _die('DRFOLD2', where, 'PDB ausente', 1, [str(pdb_path)])

    try:
        structure = PDBParser(QUIET=True).get_structure('drfold2', str(pdb_path))
    except Exception as exc:  # noqa: BLE001
        _die('DRFOLD2', where, 'falha ao parsear PDB', 1, [f"{type(exc).__name__}:{exc}", str(pdb_path)])

    models = list(structure.get_models())
    if not models:
        _die('DRFOLD2', where, 'PDB sem modelos', 1, [str(pdb_path)])

    # Keep only residues whose hetero flag (first element of the id) is blank.
    kept = []
    for chain in models[0].get_chains():
        for residue in chain.get_residues():
            het_flag, _seq_number, _insertion = residue.id
            if not str(het_flag).strip():
                kept.append(residue)

    expected = len(target_sequence)
    if len(kept) != expected:
        _die(
            'DRFOLD2',
            where,
            'PDB com numero de residuos divergente da sequencia alvo',
            f"expected={expected} got={len(kept)}",
            [str(pdb_path)],
        )

    def _describe(position, residue):
        # Label used in the error examples for a residue lacking C1'.
        rid = residue.id
        number = rid[1] if isinstance(rid, tuple) and len(rid) > 1 else '?'
        return f"idx={position}:resseq={number}"

    without_c1 = [
        _describe(position, residue)
        for position, residue in enumerate(kept, start=1)
        if not residue.has_id("C1'")
    ]
    if without_c1:
        _die('DRFOLD2', where, "PDB sem atomo obrigatorio C1'", len(without_c1), without_c1[:8])

    coords = []
    for residue in kept:
        point = residue["C1'"].get_coord()
        coords.append((float(point[0]), float(point[1]), float(point[2])))

    if len(coords) != expected:
        _die('DRFOLD2', where, 'coords incompletas', f"expected={expected} got={len(coords)}", [str(pdb_path)])

    return coords


def _predict_drfold2_selected(
    *,
    drfold2_root: Path,
    targets_csv: Path,
    selected_ids: list[str],
    work_dir: Path,
    out_parquet: Path,
    env: dict,
) -> tuple[list[str], list[str]]:
    """Run DRfold2 for each selected target and collect C1' coordinates.

    For every target a FASTA is written under ``work_dir/<target>``, the
    DRfold2 inference script is executed, and ``relax/model_1.pdb`` is parsed.
    The single predicted structure is repeated for model ids 1..5 in the
    long-format output.

    A failure for one target is logged to stderr and recorded, and the loop
    moves on to the next target (the caller keeps the TBM prediction for it).
    ``out_parquet`` is written only when at least one target succeeded.

    Returns:
        ``(succeeded_ids, failed_errors)`` where ``failed_errors`` holds
        ``"<target>:<ExceptionType>:<message>"`` strings.

    Raises:
        RuntimeError: via ``_die`` for an empty selection, invalid
            configuration, missing DRfold2 files, or a non-empty ``work_dir``.
    """
    where = f"{SCRIPT_LOC}:predict_drfold2"
    if not selected_ids:
        _die('DRFOLD2', where, 'selected_ids vazio', 0, [])
    if int(DRFOLD2_N_MODELS_RUN) not in (1, 5):
        _die('DRFOLD2', where, 'DRFOLD2_N_MODELS_RUN invalido (suporta 1 ou 5)', 1, [str(DRFOLD2_N_MODELS_RUN)])
    if int(FINAL_N_MODELS) != 5:
        _die('DRFOLD2', where, 'FINAL_N_MODELS invalido (esperado 5)', 1, [str(FINAL_N_MODELS)])

    infer_script = drfold2_root / 'DRfold_infer.py'
    if not infer_script.exists():
        _die('DRFOLD2', where, 'DRfold_infer.py ausente', 1, [str(infer_script)])
    if not (drfold2_root / 'model_hub').exists():
        _die('DRFOLD2', where, 'model_hub ausente', 1, [str(drfold2_root / 'model_hub')])

    # Require a fresh work dir so outputs of a previous run cannot be picked up.
    if work_dir.exists() and any(work_dir.iterdir()):
        _die('DRFOLD2', where, 'work_dir nao-vazio; use novo dir', 1, [str(work_dir)])
    work_dir.mkdir(parents=True, exist_ok=True)

    seqs = _read_target_sequences_subset(targets_csv=targets_csv, selected_ids=selected_ids)

    rows = []
    succeeded_ids: list[str] = []
    failed_errors: list[str] = []
    for tid in selected_ids:
        seq = seqs[tid]
        t0 = time.time()
        # Everything for one target runs inside the try so a failure only affects that target.
        try:
            target_dir = work_dir / tid
            target_dir.mkdir(parents=True, exist_ok=True)
            fasta_path = target_dir / 'target.fasta'
            fasta_path.write_text(f">{tid}\n{seq}\n", encoding='utf-8')

            cmd = [sys.executable, str(infer_script), str(fasta_path), str(target_dir)]
            # NOTE(review): the extra positional '1' is interpreted by DRfold_infer.py,
            # which is not visible here — confirm its meaning against that script.
            if int(DRFOLD2_N_MODELS_RUN) > 1:
                cmd.append('1')

            proc = subprocess.run(cmd, cwd=str(drfold2_root), env=env, text=True, capture_output=True)
            if proc.returncode != 0:
                _die('DRFOLD2', where, 'falha ao executar DRfold2', proc.returncode, _tail(proc.stdout, proc.stderr, 20))

            relax_dir = target_dir / 'relax'
            pdb_path = relax_dir / 'model_1.pdb'
            if not pdb_path.exists():
                # DRfold2 script may fail to produce relax output if Arena is incompatible.
                # Gather the process tail and directory listings to aid diagnosis.
                debug = []
                debug += _tail(proc.stdout, proc.stderr, 12)
                debug.append(f"ls_target_dir={sorted([p.name for p in target_dir.iterdir()])}")
                if (target_dir / 'folds').exists():
                    folds = sorted([p.name for p in (target_dir / 'folds').glob('*.pdb')])
                    debug.append(f"folds_pdb={folds[:6]}")
                _die('DRFOLD2', where, 'DRfold2 nao gerou relax/model_1.pdb', 1, debug[:20])

            coords = _extract_c1prime_coords(pdb_path=pdb_path, target_sequence=seq)

            # Replicate the same coordinates for model ids 1..5; residue ids are 1-based.
            for model_id in range(1, 6):
                for resid, (base, (x, y, z)) in enumerate(zip(seq, coords, strict=True), start=1):
                    rows.append(
                        {
                            'branch': 'drfold2',
                            'target_id': tid,
                            'ID': f"{tid}_{resid}",
                            'resid': resid,
                            'resname': base,
                            'model_id': model_id,
                            'x': x,
                            'y': y,
                            'z': z,
                        }
                    )

            dt = time.time() - t0
            succeeded_ids.append(tid)
            print(f"[INFO] [{SCRIPT_LOC}] drfold2_done target={tid} len={len(seq)} sec={dt:.1f}")
        except Exception as exc:  # noqa: BLE001
            # Deliberate best-effort: record the failure and continue with the next target.
            dt = time.time() - t0
            msg = f"{tid}:{type(exc).__name__}:{exc}"
            failed_errors.append(msg)
            print(
                f"[WARN] [{SCRIPT_LOC}] [DRFOLD2] target_fail -> fallback_tbm target={tid} sec={dt:.1f} err={type(exc).__name__}:{exc}",
                file=sys.stderr,
            )
            continue

    import polars as pl

    # No parquet is written when every target failed.
    if rows:
        df = pl.DataFrame(rows)
        df.write_parquet(str(out_parquet))
    print(
        f"[INFO] [{SCRIPT_LOC}] drfold2_summary selected={len(selected_ids)} "
        f"succeeded={len(succeeded_ids)} failed={len(failed_errors)}"
    )
    return succeeded_ids, failed_errors


def _build_drfold2_coord_map(drfold2_parquet: Path) -> tuple[list[str], dict[str, list[float]]]:
    """Load the DRfold2 long-format parquet as a per-ID coordinate lookup.

    Returns ``(coord_cols, mapping)`` where ``coord_cols`` is
    ``x_1, y_1, z_1, ..., x_5, y_5, z_5`` and ``mapping[ID]`` holds the float
    values in that same order. Aborts via ``_die`` when required columns are
    missing or an axis does not pivot into exactly five model columns.
    """
    where = f"{SCRIPT_LOC}:drfold2_map"
    import polars as pl

    long_df = pl.read_parquet(str(drfold2_parquet))
    absent = sorted({'ID', 'model_id', 'x', 'y', 'z'} - set(long_df.columns))
    if absent:
        _die('DRFOLD2', where, 'drfold2 parquet sem colunas obrigatorias', len(absent), absent)

    # One row per ID; each axis contributes its five per-model columns.
    wide = pl.DataFrame({'ID': long_df.get_column('ID').unique()})
    for axis in ('x', 'y', 'z'):
        pivoted = long_df.pivot(index='ID', columns='model_id', values=axis)
        model_cols = [name for name in pivoted.columns if name != 'ID']
        if len(model_cols) != 5:
            _die('DRFOLD2', where, 'pivot incompleto (esperado 5 modelos)', len(model_cols), [str(model_cols)])
        renamed = pivoted.rename({name: f"{axis}_{int(name)}" for name in model_cols})
        wide = wide.join(renamed, on='ID', how='left')

    coord_cols = [f"{axis}_{model}" for model in range(1, 6) for axis in ('x', 'y', 'z')]

    mapping = {
        str(record[0]): [float(value) for value in record[1:]]
        for record in wide.select(['ID', *coord_cols]).iter_rows()
    }
    return coord_cols, mapping


def _patch_submission_with_drfold2(*, base_submission: Path, drfold2_parquet: Path, selected_target_ids: list[str], out_submission: Path) -> None:
    """Overwrite the coordinates of selected targets with DRfold2 predictions.

    Streams ``base_submission`` row by row into ``out_submission``. For rows
    whose target id is in ``selected_target_ids``, the coordinate columns are
    replaced by the values from ``drfold2_parquet``. All other rows are copied
    unchanged.

    The target id is the part of ``ID`` before the LAST underscore, matching
    ``_parse_submission_id``. Splitting on the first underscore would
    mis-identify targets whose id contains ``_`` and leave them unpatched
    without any error.

    Args:
        base_submission: Wide CSV produced by the TBM branch.
        drfold2_parquet: Long-format DRfold2 predictions.
        selected_target_ids: Targets to patch; an empty list copies the base
            file verbatim.
        out_submission: Destination CSV.

    Raises:
        RuntimeError: via ``_die`` when the base file has no header or lacks a
            coordinate column, or when a row of a selected target has no
            DRfold2 coordinates.
    """
    where = f"{SCRIPT_LOC}:patch_submission"
    if not selected_target_ids:
        shutil.copyfile(base_submission, out_submission)
        return

    selected = set(selected_target_ids)
    coord_cols, mapping = _build_drfold2_coord_map(drfold2_parquet=drfold2_parquet)

    missing_examples: list[str] = []  # at most 8 IDs, used as error examples
    missing_total = 0  # true count; the examples list is capped
    touched = 0

    with base_submission.open('r', encoding='utf-8', newline='') as f_in, out_submission.open('w', encoding='utf-8', newline='') as f_out:
        reader = csv.DictReader(f_in)
        if reader.fieldnames is None:
            _die('EXPORT', where, 'base_submission sem header', 1, [str(base_submission)])
        for c in coord_cols:
            if c not in reader.fieldnames:
                _die('EXPORT', where, 'base_submission sem coluna esperada', 1, [c])

        writer = csv.DictWriter(f_out, fieldnames=reader.fieldnames)
        writer.writeheader()

        for row in reader:
            rid = str(row.get('ID') or '')
            # IDs are "<target>_<resid>"; the target itself may contain underscores.
            tid = rid.rsplit('_', 1)[0] if '_' in rid else ''
            if tid in selected:
                vals = mapping.get(rid)
                if vals is None:
                    missing_total += 1
                    if len(missing_examples) < 8:
                        missing_examples.append(rid)
                else:
                    for col, val in zip(coord_cols, vals, strict=True):
                        row[col] = repr(float(val))
                    touched += 1
            writer.writerow(row)

    if missing_total:
        _die('EXPORT', where, 'IDs selecionados sem coords drfold2', missing_total, missing_examples)

    print(f"[INFO] [{SCRIPT_LOC}] patched_rows={touched} targets={selected_target_ids}")


def _submission_layout(header: list[str]) -> tuple[list[int], list[str]]:
    """Derive the model ids and coordinate column names from a submission header.

    Model ids come from the integer suffix of every ``x_<id>`` column and are
    returned sorted. The coordinate columns are ``x_<id>, y_<id>, z_<id>`` for
    each id in that order; each must be present in ``header``. Aborts via
    ``_die`` on a non-integer suffix, no ``x_`` columns, or a missing column.
    """
    where = f"{SCRIPT_LOC}:submission_layout"

    model_ids: list[int] = []
    for name in header:
        prefix, separator, suffix = name.partition("_")
        if prefix != "x" or not separator:
            continue
        try:
            model_ids.append(int(suffix))
        except Exception:
            _die("EXPORT", where, "coluna de modelo invalida", 1, [name])
    model_ids.sort()
    if not model_ids:
        _die("EXPORT", where, "submission sem colunas x_<model>", 1, header[:8])

    coord_cols: list[str] = [f"{axis}_{mid}" for mid in model_ids for axis in ("x", "y", "z")]
    for name in coord_cols:
        if name not in header:
            _die("EXPORT", where, "submission sem coluna esperada", 1, [name])
    return model_ids, coord_cols


def _parse_submission_id(id_value: str) -> str:
    """Return the target part of a ``<target>_<resid>`` submission ID.

    The ID is split at its last underscore. Aborts via ``_die`` when there is
    no underscore, the target part is empty, or the residue part is not an
    integer.
    """
    where = f"{SCRIPT_LOC}:parse_submission_id"
    text = str(id_value or "").strip()
    target_id, separator, resid_text = text.rpartition("_")
    if not separator:
        _die("EXPORT", where, "ID invalido (esperado <target>_<resid>)", 1, [text])
    if not target_id:
        _die("EXPORT", where, "ID invalido (target vazio)", 1, [text])
    try:
        int(resid_text)
    except Exception:
        _die("EXPORT", where, "ID invalido (resid nao inteiro)", 1, [text])
    return target_id


def _normalize_submission_coords(*, in_submission: Path, out_submission: Path, abs_clip: float) -> None:
    """Center coordinates per (target, model, axis) and clip them to ``±abs_clip``.

    Two passes over ``in_submission``: the first validates that every
    coordinate is a finite number and accumulates per-group means; the second
    subtracts the group mean, clips, and writes each value with six decimals
    to ``out_submission``. Aborts via ``_die`` on invalid ``abs_clip``, a
    missing header, an empty file, or a non-numeric / non-finite value.
    """
    where = f"{SCRIPT_LOC}:normalize_submission"
    if abs_clip <= 0:
        _die("EXPORT", where, "abs_clip invalido", 1, [str(abs_clip)])

    axes = ("x", "y", "z")
    # (target, model id, axis) -> [running sum, number of values]
    accumulators: dict[tuple[str, int, str], list] = {}
    n_rows = 0

    # Pass 1: validate and accumulate.
    with in_submission.open("r", encoding="utf-8", newline="") as src:
        reader = csv.DictReader(src)
        if reader.fieldnames is None:
            _die("EXPORT", where, "submission de entrada sem header", 1, [str(in_submission)])
        model_ids, _unused_cols = _submission_layout(list(reader.fieldnames))
        for line_no, record in enumerate(reader, start=1):
            n_rows += 1
            target = _parse_submission_id(str(record.get("ID") or ""))
            for mid in model_ids:
                for axis in axes:
                    column = f"{axis}_{mid}"
                    cell = record.get(column)
                    try:
                        number = float(str(cell))
                    except Exception:
                        _die("EXPORT", where, "valor nao-numerico na submission", 1, [f"row={line_no}", column, str(cell)])
                    if not math.isfinite(number):
                        _die("EXPORT", where, "valor nao-finito na submission", 1, [f"row={line_no}", column, str(cell)])
                    slot = accumulators.setdefault((target, mid, axis), [0.0, 0])
                    slot[0] = slot[0] + number
                    slot[1] = slot[1] + 1

    if n_rows <= 0:
        _die("EXPORT", where, "submission de entrada vazia", 0, [str(in_submission)])

    means: dict[tuple[str, int, str], float] = {}
    for group, (total, count) in accumulators.items():
        if count <= 0:
            _die("EXPORT", where, "contador invalido ao normalizar", 1, [str(group)])
        means[group] = total / float(count)

    # Pass 2: center, clip and write.
    n_clipped = 0
    with in_submission.open("r", encoding="utf-8", newline="") as src, out_submission.open("w", encoding="utf-8", newline="") as dst:
        reader = csv.DictReader(src)
        if reader.fieldnames is None:
            _die("EXPORT", where, "submission de entrada sem header (2a passagem)", 1, [str(in_submission)])
        model_ids, _unused_cols = _submission_layout(list(reader.fieldnames))
        writer = csv.DictWriter(dst, fieldnames=reader.fieldnames)
        writer.writeheader()

        for record in reader:
            target = _parse_submission_id(str(record.get("ID") or ""))
            for mid in model_ids:
                for axis in axes:
                    column = f"{axis}_{mid}"
                    group = (target, mid, axis)
                    if group not in means:
                        _die("EXPORT", where, "mean ausente para target/model/axis", 1, [str(group)])
                    centered = float(str(record[column])) - means[group]
                    bounded = max(-abs_clip, min(abs_clip, centered))
                    if bounded != centered:
                        n_clipped += 1
                    record[column] = f"{bounded:.6f}"
            writer.writerow(record)

    print(f"[INFO] [{SCRIPT_LOC}] normalize_submission done rows={n_rows} clipped={n_clipped} abs_clip={abs_clip:g}")


def _assert_submission_coord_bounds(submission_path: Path, *, abs_max: float) -> None:
    """Verify every coordinate in the submission is finite and within ``±abs_max``.

    Offending cells are collected as ``<col>@<row>:<reason>``; the check aborts
    via ``_die`` as soon as eight have been found, or at end of file if any
    were found. Also aborts on an invalid ``abs_max`` or a missing header.
    """
    where = f"{SCRIPT_LOC}:submission_bounds"
    if abs_max <= 0:
        _die("CHECK", where, "abs_max invalido", 1, [str(abs_max)])

    problems: list[str] = []

    def _record(entry: str) -> None:
        # Register one offending cell; stop early once eight are collected.
        problems.append(entry)
        if len(problems) >= 8:
            _die("CHECK", where, "coordenadas invalidas na submission", len(problems), problems)

    with submission_path.open("r", encoding="utf-8", newline="") as handle:
        reader = csv.DictReader(handle)
        if reader.fieldnames is None:
            _die("CHECK", where, "submission sem header", 1, [str(submission_path)])
        _model_ids, coord_cols = _submission_layout(list(reader.fieldnames))
        for line_no, record in enumerate(reader, start=1):
            for column in coord_cols:
                cell = record.get(column)
                try:
                    number = float(str(cell))
                except Exception:
                    _record(f"{column}@{line_no}:non-numeric")
                    continue
                within_bounds = math.isfinite(number) and abs(number) <= abs_max
                if not within_bounds:
                    _record(f"{column}@{line_no}:abs>{abs_max:g}")
        if problems:
            _die("CHECK", where, "coordenadas invalidas na submission", len(problems), problems)


def _run_dynamic_pipeline(*, assets: dict[str, Path | None], sample: Path, targets: Path, run: Path, submission: Path, env: dict, repo_root: Path) -> None:
    """Run the full submission pipeline and write the final CSV to ``submission``.

    Stages, in order:
      1. ``rna3d_local retrieve-templates`` -> retrieval candidates parquet.
      2. ``rna3d_local predict-tbm`` -> TBM predictions, exported to a wide CSV.
      3. If ``USE_DRFOLD2``: prepare the DRfold2 runtime, select targets by
         risk, run DRfold2 on them and patch the successful ones into the CSV.
      4. Center and clip coordinates, assert bounds, copy to ``submission``.
      5. ``rna3d_local check-submission`` against the sample.

    All intermediate artifacts are written under ``run``. Subprocesses run
    with ``env`` and ``repo_root`` as working directory.
    """
    retrieval = run / 'retrieval_candidates.parquet'
    tbm = run / 'tbm_predictions.parquet'
    base_submission = run / 'submission_tbm.csv'

    # Stage 1: template retrieval.
    _run([
        sys.executable, '-m', 'rna3d_local', 'retrieve-templates',
        '--template-index', str(assets['template_index']),
        '--targets', str(targets),
        '--out', str(retrieval),
        '--top-k', '400',
        '--kmer-size', '3',
        '--length-weight', '0.25',
        '--chunk-size', '200000',
        '--memory-budget-mb', '8192',
        '--max-rows-in-memory', '10000000',
    ], env, repo_root)

    # Stage 2: template-based modelling for every target.
    tbm_cmd = [
        sys.executable, '-m', 'rna3d_local', 'predict-tbm',
        '--retrieval', str(retrieval),
        '--templates', str(assets['templates']),
        '--targets', str(targets),
        '--out', str(tbm),
        '--n-models', str(FINAL_N_MODELS),
        '--min-coverage', '0.01',
        '--chunk-size', '200000',
        '--memory-budget-mb', '8192',
        '--max-rows-in-memory', '10000000',
    ]
    # Optional flags are only passed when the installed CLI advertises them in --help.
    if _supports('predict-tbm', '--rerank-pool-size', env, repo_root):
        tbm_cmd += ['--rerank-pool-size', '128']
    if _supports('predict-tbm', '--mapping-mode', env, repo_root):
        tbm_cmd += ['--mapping-mode', 'hybrid', '--projection-mode', 'template_warped']
    _run(tbm_cmd, env, repo_root)

    _export_submission_strict_from_long(predictions_path=tbm, sample_path=sample, out_path=base_submission)

    # Defaults to the pure-TBM submission; replaced below only if DRfold2 patches something.
    final_submission = base_submission

    # Stage 3: optional DRfold2 branch for high-risk targets.
    if USE_DRFOLD2:
        drf_input = assets.get('drfold2_root')
        if drf_input is None:
            _die('DRFOLD2', f"{SCRIPT_LOC}:paths", 'drfold2_root ausente', 1, [])
        drf_runtime = run / 'drfold2_runtime'
        drf_root = _prepare_drfold2_runtime(drfold2_input_root=Path(drf_input), runtime_root=drf_runtime, env=env)

        ids_path = run / 'drfold2_target_ids.txt'
        selected_ids = _select_drfold2_targets_by_risk(
            retrieval_path=retrieval,
            targets_csv=targets,
            out_ids_path=ids_path,
        )
        if selected_ids:
            drf_work = run / 'drfold2_work'
            drf_pred = run / 'drfold2_predictions.parquet'
            succeeded_ids, failed_errors = _predict_drfold2_selected(
                drfold2_root=Path(drf_root),
                targets_csv=targets,
                selected_ids=selected_ids,
                work_dir=drf_work,
                out_parquet=drf_pred,
                env=env,
            )
            # Always written (empty when nothing failed) so the run leaves a record.
            failed_path = run / 'drfold2_failed_targets.txt'
            failed_path.write_text("\n".join(failed_errors) + ("\n" if failed_errors else ""), encoding='utf-8')

            # Only targets that actually succeeded are patched; the rest keep TBM coordinates.
            if succeeded_ids:
                patched = run / 'submission_patched.csv'
                _patch_submission_with_drfold2(
                    base_submission=base_submission,
                    drfold2_parquet=drf_pred,
                    selected_target_ids=succeeded_ids,
                    out_submission=patched,
                )
                final_submission = patched
            else:
                print(f"[INFO] [{SCRIPT_LOC}] drfold2 sem alvos validos; mantendo TBM para todos os selecionados")

    # Stage 4: center/clip coordinates, verify bounds, publish the final file.
    normalized_submission = run / 'submission_normalized.csv'
    _normalize_submission_coords(
        in_submission=final_submission,
        out_submission=normalized_submission,
        abs_clip=float(SUBMISSION_ABS_CLIP),
    )
    _assert_submission_coord_bounds(normalized_submission, abs_max=float(SUBMISSION_ABS_CLIP))
    shutil.copyfile(normalized_submission, submission)

    # Stage 5: final validation by the project CLI.
    _run([
        sys.executable, '-m', 'rna3d_local', 'check-submission',
        '--sample', str(sample),
        '--submission', str(submission),
    ], env, repo_root)


# --- Driver: resolve paths, prepare the environment, run the pipeline. ---

# Fixed Kaggle locations: competition data, writable working dir, mounted datasets.
comp = Path('/kaggle/input/stanford-rna-3d-folding-2')
work = Path('/kaggle/working')
input_root = Path('/kaggle/input')
run = work / 'run_dynamic_submit_v77_alt_risknorm'
run.mkdir(parents=True, exist_ok=True)

# Both competition files are mandatory; fail early if either is absent.
sample = comp / 'sample_submission.csv'
targets = comp / 'test_sequences.csv'
if not sample.exists():
    _die('LOAD', f"{SCRIPT_LOC}:paths", 'sample_submission.csv ausente', 1, [str(sample)])
if not targets.exists():
    _die('LOAD', f"{SCRIPT_LOC}:paths", 'test_sequences.csv ausente', 1, [str(targets)])

assets = _discover_assets(input_root)
_ensure_biopython(assets['wheel_dir'])

# Make the project package importable in this process ...
src_root = assets['src_root']
if str(src_root) not in sys.path:
    sys.path.insert(0, str(src_root))

# ... and in child processes, prepending to any existing PYTHONPATH.
env = os.environ.copy()
env['PYTHONPATH'] = str(src_root) + (":" + env['PYTHONPATH'] if env.get('PYTHONPATH') else "")
# Default numeric-library thread counts to 1 unless the caller already set them.
env.setdefault('OMP_NUM_THREADS', '1')
env.setdefault('MKL_NUM_THREADS', '1')
env.setdefault('OPENBLAS_NUM_THREADS', '1')
env.setdefault('NUMEXPR_NUM_THREADS', '1')

# The parent of src/ is used as the subprocess working directory; it must hold pyproject.toml.
repo_root = src_root.parent
if not (repo_root / 'pyproject.toml').exists():
    _die('LOAD', f"{SCRIPT_LOC}:repo_root", 'pyproject.toml ausente no repo_root detectado', 1, [str(repo_root)])

submission = work / 'submission.csv'

# Log the effective configuration before starting.
print(
    f"[INFO] [{SCRIPT_LOC}] mode=dynamic use_drfold2={USE_DRFOLD2} "
    f"drfold2_max_targets={DRFOLD2_MAX_TARGETS} thr={DRFOLD2_SIMILARITY_THRESHOLD} "
    f"drfold2_max_seq_len={DRFOLD2_MAX_SEQ_LEN} drfold2_n_models_run={DRFOLD2_N_MODELS_RUN} "
    f"abs_clip={SUBMISSION_ABS_CLIP}"
)
_run_dynamic_pipeline(
    assets=assets,
    sample=sample,
    targets=targets,
    run=run,
    submission=submission,
    env=env,
    repo_root=repo_root,
)
print(f"[DONE] [{SCRIPT_LOC}] dynamic_pipeline submission={submission}")

# Print a SHA-256 of the final file so runs can be compared.
h = hashlib.sha256(submission.read_bytes()).hexdigest()
print(f"[INFO] [{SCRIPT_LOC}] submission_sha256={h}")
