In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the path of every file available under the Kaggle input mount
input_files = (
    os.path.join(root, name)
    for root, _, names in os.walk('/kaggle/input')
    for name in names
)
for file_path in input_files:
    print(file_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# Install dependencies (offline wheels)


In [None]:
# List the wheel files bundled in the boltz-dependencies dataset
%ls /kaggle/input/boltz-dependencies

In [None]:
# Offline install: --no-index never contacts PyPI and --no-deps skips dependency resolution,
# so every package needed at runtime must already be present as a wheel in these datasets
!pip install --no-index /kaggle/input/boltz-dependencies/*whl --no-deps

In [None]:
# fairscale wheel, installed the same offline way (presumably required by Boltz — confirm)
!pip install --no-index /kaggle/input/fairscale-0413/*whl --no-deps

In [None]:
# Biopython provides Bio.PDB.MMCIF2Dict, used in "Gather results" to read the predicted .cif files
!pip install --no-index /kaggle/input/biopython/*whl --no-deps

# Prepare scripts

In [None]:
%cd /kaggle/working/

In [7]:
%mkdir inputs_prediction
%mkdir outputs_prediction

In [8]:
%cp -rf /kaggle/input/rna-prediction-boltz/boltz/src/boltz .

In [None]:
%ls boltz

# Write the inference script (`inference.py`)

In [None]:
%%writefile inference.py
import pickle
import urllib.request
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Literal, Optional
import click
import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.utilities import rank_zero_only
from tqdm import tqdm
from boltz.data import const
from boltz.data.module.inference import BoltzInferenceDataModule
from boltz.data.msa.mmseqs2 import run_mmseqs2
from boltz.data.parse.a3m import parse_a3m
from boltz.data.parse.csv import parse_csv
from boltz.data.parse.fasta import parse_fasta
from boltz.data.parse.yaml import parse_yaml
from boltz.data.types import MSA, Manifest, Record
from boltz.data.write.writer import BoltzWriter
from boltz.model.model import Boltz1

CCD_URL = "https://huggingface.co/boltz-community/boltz-1/resolve/main/ccd.pkl"
MODEL_URL = "https://huggingface.co/boltz-community/boltz-1/resolve/main/boltz1_conf.ckpt"


@dataclass
class InferenceArtifacts:
    """Manifest and directories produced by preprocessing, ready for the data module.

    Built in ``run_inference`` from ``<output_dir>/processed``.
    """
    # Parsed processed/manifest.json: the records to predict
    manifest: Manifest
    # Directory holding one ``<record_id>.npz`` structure file per target
    targets_dir: Path
    # Directory holding the processed MSA ``.npz`` files
    msa_dir: Path


@dataclass
class SamplerHyperparams:
    """Diffusion/sampling hyperparameters forwarded to Boltz1.

    ``run_inference`` passes ``asdict(instance)`` as ``diffusion_process_args`` to
    ``Boltz1.load_from_checkpoint`` and overrides only ``step_scale``; every other
    field is used at the default below. The meaning of each field is defined by the
    Boltz diffusion module, not visible here; defaults presumably mirror upstream
    Boltz-1 inference settings — confirm against the checkpoint.
    """
    gamma_0: float = 0.605
    gamma_min: float = 1.107
    noise_scale: float = 0.901
    rho: float = 8  # NOTE(review): int literal on a float field; asdict() yields 8, not 8.0
    step_scale: float = 1.638  # overridden from run_inference(step_scale=...)
    sigma_min: float = 0.0004
    sigma_max: float = 160.0
    sigma_data: float = 16.0
    P_mean: float = -1.2
    P_std: float = 1.5
    coordinate_augmentation: bool = True
    alignment_reverse_diff: bool = True
    synchronize_sigmas: bool = True
    use_inference_model_cache: bool = True


def _download_if_missing(url: str, dest: Path, label: str) -> None:
    """Download ``url`` to ``dest`` unless ``dest`` already exists.

    The payload is written to a sibling ``<name>.part`` file and renamed onto
    ``dest`` only after the transfer completes. An interrupted download therefore
    never leaves a truncated file that a later ``exists()`` check would accept
    as a valid cache entry.
    """
    if dest.exists():
        return
    click.echo(
        f"Downloading the {label} to {dest}. "
        "You may change the cache directory with the --cache flag."
    )
    partial = dest.with_name(dest.name + ".part")
    try:
        urllib.request.urlretrieve(url, str(partial))  # noqa: S310
        partial.replace(dest)
    except BaseException:
        # Also covers KeyboardInterrupt: never leave a half-written .part behind
        partial.unlink(missing_ok=True)
        raise


@rank_zero_only
def ensure_artifacts_downloaded(cache_dir: Path) -> None:
    """Fetch the CCD dictionary and the Boltz-1 weights into ``cache_dir`` if missing.

    Files: ``ccd.pkl`` (from ``CCD_URL``) and ``boltz1_conf.ckpt`` (from ``MODEL_URL``).
    Decorated with ``rank_zero_only`` so only rank 0 downloads in multi-process runs.
    """
    _download_if_missing(CCD_URL, cache_dir / "ccd.pkl", "CCD dictionary")
    _download_if_missing(MODEL_URL, cache_dir / "boltz1_conf.ckpt", "model weights")


def resolve_and_filter_inputs(
    input_path: Path,
    output_dir: Path,
    override_existing: bool = False,
) -> list[Path]:
    """Expand inputs, validate types, and skip already-predicted targets.

    Args:
        input_path: A single input file, or a directory whose (non-hidden) entries
            must all be ``.fa``/``.fas``/``.fasta``/``.yml``/``.yaml`` files. Hidden
            entries (name starting with ``.``, e.g. Jupyter's ``.ipynb_checkpoints``)
            are ignored instead of aborting the run.
        output_dir: Run directory; the name of any sub-directory found under
            ``output_dir / "predictions"`` counts as an already-predicted target.
        override_existing: Keep every input even if it was predicted before.

    Returns:
        The input files to process; sorted by name when ``input_path`` is a
        directory, so the processing order does not depend on the filesystem.

    Raises:
        RuntimeError: If the directory contains a sub-directory or a file with an
            unsupported suffix.
    """
    click.echo("Checking input data.")

    if input_path.is_dir():
        paths: list[Path] = []
        for p in sorted(input_path.glob("*")):
            # pathlib's "*" also matches dot-entries; those are editor artefacts, not targets
            if p.name.startswith("."):
                continue
            if p.suffix in (".fa", ".fas", ".fasta", ".yml", ".yaml"):
                paths.append(p)
            elif p.is_dir():
                raise RuntimeError(f"Found directory {p} instead of .fasta or .yaml.")
            else:
                raise RuntimeError(
                    f"Unable to parse filetype {p.suffix}, "
                    "please provide a .fasta or .yaml file."
                )
    else:
        paths = [input_path]

    # Skip those with existing predictions unless override
    existing_ids = {
        e.name for e in (output_dir / "predictions").rglob("*") if e.is_dir()
    }

    if existing_ids and not override_existing:
        pruned = [p for p in paths if p.stem not in existing_ids]
        num_skipped = len(paths) - len(pruned)
        if num_skipped:
            click.echo(
                f"Found some existing predictions ({num_skipped}), "
                "skipping and running only the missing ones. "
                "Use --override to recompute."
            )
        paths = pruned
    elif existing_ids and override_existing:
        click.echo("Found existing predictions, will override.")

    return paths


def build_msa_alignments(
    sequences_by_entity: dict[str, str],
    target_id: str,
    msa_raw_dir: Path,
    msa_server_url: str,
    msa_pairing_strategy: str,
) -> None:
    """Query MMSeqs2 and write one ``<entity>.csv`` MSA per entity into ``msa_raw_dir``.

    A paired search runs only when there is more than one entity; an unpaired
    search always runs. Each CSV starts with a ``key,sequence`` header:
      * paired rows use their row index in the (truncated) paired alignment as
        key, and gap-only rows are dropped;
      * unpaired rows use key ``-1``; when paired rows exist, the first unpaired
        row (the query) is dropped because the paired block already has it.
    """

    def _search(tag: str, pairing: bool) -> list[str]:
        # One MMSeqs2 call; scratch output goes to <msa_raw_dir>/<target>_<tag>_tmp
        return run_mmseqs2(
            list(sequences_by_entity.values()),
            msa_raw_dir / f"{target_id}_{tag}_tmp",
            use_env=True,
            use_pairing=pairing,
            host_url=msa_server_url,
            pairing_strategy=msa_pairing_strategy,
        )

    def _sequence_lines(a3m_text: str) -> list[str]:
        # Lines alternate header/sequence: keep the sequences only
        return a3m_text.strip().splitlines()[1::2]

    def _is_all_gaps(seq: str) -> bool:
        return seq == "-" * len(seq)

    n_entities = len(sequences_by_entity)
    paired_msas = _search("paired", True) if n_entities > 1 else [""] * n_entities
    unpaired_msas = _search("unpaired", False)

    for idx, entity_name in enumerate(sequences_by_entity):
        paired_rows = _sequence_lines(paired_msas[idx])[: const.max_paired_seqs]
        kept = [(row, seq) for row, seq in enumerate(paired_rows) if not _is_all_gaps(seq)]
        keys = [row for row, _ in kept]
        paired = [seq for _, seq in kept]

        unpaired = _sequence_lines(unpaired_msas[idx])[: const.max_msa_seqs - len(paired)]
        if paired:
            unpaired = unpaired[1:]  # the query row is already part of the paired block

        keys.extend([-1] * len(unpaired))
        csv_rows = ["key,sequence"]
        csv_rows.extend(f"{key},{seq}" for key, seq in zip(keys, paired + unpaired))
        (msa_raw_dir / f"{entity_name}.csv").write_text("\n".join(csv_rows))


@rank_zero_only
def preprocess_dataset(  # noqa: C901, PLR0912, PLR0915
    inputs: list[Path],
    output_dir: Path,
    ccd_pickle: Path,
    msa_server_url: str,
    msa_pairing_strategy: str,
    max_msa_seqs: int = 4096,
    allow_msa_server: bool = False,
) -> None:
    """Parse inputs, generate/convert MSAs, dump structures + manifest.

    Written under ``output_dir``:
      * ``msa/<target>_<entity>.csv``: MSAs generated with MMSeqs2 for protein
        chains that have none (only when ``allow_msa_server``);
      * ``processed/msa/<target>_<k>.npz``: processed MSAs;
      * ``processed/structures/<record_id>.npz``: parsed structures;
      * ``processed/manifest.json``: one record per successfully processed input.

    If a manifest already exists, records for requested ids are reused and only
    the missing inputs are processed.

    Args:
        inputs: .fa/.fas/.fasta/.yml/.yaml files to process.
        output_dir: Run directory (created if missing).
        ccd_pickle: Pickled CCD dictionary handed to the parsers.
        msa_server_url: MMSeqs2 server used for MSA generation.
        msa_pairing_strategy: Pairing strategy forwarded to ``run_mmseqs2``.
        max_msa_seqs: Cap on sequences kept when parsing an MSA file.
        allow_msa_server: Permit MSA generation; otherwise a protein chain
            without an MSA is an error.

    Raises:
        Exception: A per-input error is re-raised when only one input is
            processed; with several inputs the failing one is reported and skipped.
    """
    click.echo("Processing input data.")
    prior_records: Optional[list[Record]] = None

    # Reuse/update existing manifest if present
    manifest_path = output_dir / "processed" / "manifest.json"
    if manifest_path.exists():
        click.echo(f"Found a manifest file at output directory: {output_dir}")
        manifest: Manifest = Manifest.load(manifest_path)
        requested_ids = [p.stem for p in inputs]
        found_pairs = [
            (rec, rec.id) for rec in manifest.records if rec.id in requested_ids
        ]
        if found_pairs:
            prior_records, already = zip(*found_pairs)
            prior_records = list(prior_records)
            missing_count = len(requested_ids) - len(already)
            if missing_count == 0:
                click.echo("All examples in data are processed. Updating the manifest")
                Manifest(prior_records).dump(manifest_path)
                return
            click.echo(f"{missing_count} missing ids. Preprocessing these ids")
            missing_ids = set(requested_ids).difference(set(already))
            inputs = [p for p in inputs if p.stem in missing_ids]

    # Directories
    msa_raw_dir = output_dir / "msa"
    struct_dir = output_dir / "processed" / "structures"
    msa_proc_dir = output_dir / "processed" / "msa"
    preds_dir = output_dir / "predictions"
    for directory in (output_dir, msa_raw_dir, struct_dir, msa_proc_dir, preds_dir):
        directory.mkdir(parents=True, exist_ok=True)

    # Load CCD. pickle can execute code: only point ccd_pickle at a trusted cache.
    with ccd_pickle.open("rb") as fh:
        ccd = pickle.load(fh)  # noqa: S301

    if prior_records is not None:
        click.echo(f"Found {len(prior_records)} records. Adding them to records")

    records: list[Record] = prior_records if prior_records is not None else []
    for path in tqdm(inputs):
        try:
            # Parse input target
            if path.suffix in (".fa", ".fas", ".fasta"):
                parsed = parse_fasta(path, ccd)
            elif path.suffix in (".yml", ".yaml"):
                parsed = parse_yaml(path, ccd)
            elif path.is_dir():
                raise RuntimeError(f"Found directory {path} instead of .fasta or .yaml, skipping.")
            else:
                raise RuntimeError(
                    f"Unable to parse filetype {path.suffix}, please provide a .fasta or .yaml file."
                )

            target_id = parsed.record.id

            # Determine which chains need MSA generation
            to_generate: dict[str, str] = {}
            protein_type = const.chain_type_ids["PROTEIN"]
            for ch in parsed.record.chains:
                if (ch.mol_type == protein_type) and (ch.msa_id == 0):
                    entity_id = ch.entity_id
                    msa_stub = f"{target_id}_{entity_id}"
                    to_generate[msa_stub] = parsed.sequences[entity_id]
                    ch.msa_id = msa_raw_dir / f"{msa_stub}.csv"
                elif ch.msa_id == 0:
                    ch.msa_id = -1  # unsupported for non-protein

            if to_generate and not allow_msa_server:
                raise RuntimeError("Missing MSAs and --use_msa_server flag not set.")

            if to_generate:
                click.echo(f"Generating MSA for {path} with {len(to_generate)} protein entities.")
                build_msa_alignments(
                    sequences_by_entity=to_generate,
                    target_id=target_id,
                    msa_raw_dir=msa_raw_dir,
                    msa_server_url=msa_server_url,
                    msa_pairing_strategy=msa_pairing_strategy,
                )

            # Convert raw MSAs to processed NPZs. key=str: user-supplied msa ids are
            # str while generated ones are Path, and sorting a mix raises TypeError.
            distinct_msas = sorted(
                {c.msa_id for c in parsed.record.chains if c.msa_id != -1}, key=str
            )
            msa_id_map: dict[Path | str, str] = {}
            for idx, raw_id in enumerate(distinct_msas):
                raw_path = Path(raw_id)
                if not raw_path.exists():
                    raise FileNotFoundError(f"MSA file {raw_path} not found.")

                npz_out = msa_proc_dir / f"{target_id}_{idx}.npz"
                msa_id_map[raw_id] = f"{target_id}_{idx}"
                if not npz_out.exists():
                    if raw_path.suffix == ".a3m":
                        msa_obj: MSA = parse_a3m(raw_path, taxonomy=None, max_seqs=max_msa_seqs)
                    elif raw_path.suffix == ".csv":
                        msa_obj: MSA = parse_csv(raw_path, max_seqs=max_msa_seqs)
                    else:
                        raise RuntimeError(f"MSA file {raw_path} not supported, only a3m or csv.")
                    msa_obj.dump(npz_out)

            # Repoint chains to processed MSA ids
            for ch in parsed.record.chains:
                if (ch.msa_id != -1) and (ch.msa_id in msa_id_map):
                    ch.msa_id = msa_id_map[ch.msa_id]

            # Dump the structure BEFORE registering the record: a failed dump is
            # swallowed below for multi-input runs, and the manifest must never
            # list a record whose structure file was not written.
            parsed.structure.dump(struct_dir / f"{parsed.record.id}.npz")
            records.append(parsed.record)

        except Exception as exc:
            if len(inputs) > 1:
                click.echo(f"Failed to process {path}. Skipping. Error: {exc}.")
            else:
                raise

    Manifest(records).dump(manifest_path)


def run_inference(
    data: str,
    out_dir: str,
    cache: str = "~/.boltz",
    checkpoint: Optional[str] = None,
    devices: int = 1,
    accelerator: str = "gpu",
    recycling_steps: int = 3,
    sampling_steps: int = 200,
    diffusion_samples: int = 10,
    step_scale: float = 1.638,
    write_full_pae: bool = False,
    write_full_pde: bool = False,
    output_format: Literal["pdb", "mmcif"] = "mmcif",
    num_workers: int = 2,
    override: bool = False,
    seed: Optional[int] = None,
    use_msa_server: bool = False,
    msa_server_url: str = "https://api.colabfold.com",
    msa_pairing_strategy: str = "greedy",
) -> None:
    """Run the Boltz-1 pipeline end to end: download, preprocess, predict, write.

    Results go to ``<out_dir>/boltz_results_<stem of data>/``: preprocessing
    artefacts under ``processed/`` and model outputs under ``predictions/``.
    ``checkpoint`` defaults to ``<cache>/boltz1_conf.ckpt``. With more than one
    device a DDP strategy is used, which requires at least one input per device.

    Raises:
        ValueError: If more devices are requested than there are inputs to predict.
    """
    if accelerator == "cpu":
        click.echo("Running on CPU, this will be slow. Consider using a GPU.")

    # Inference only: no autograd graph; full-precision float32 matmuls
    torch.set_grad_enabled(False)
    torch.set_float32_matmul_precision("highest")

    if seed is not None:
        seed_everything(int(seed))

    cache_dir = Path(cache).expanduser()
    cache_dir.mkdir(parents=True, exist_ok=True)

    input_path = Path(data).expanduser()
    output_dir = Path(out_dir).expanduser().joinpath(f"boltz_results_{input_path.stem}")
    output_dir.mkdir(parents=True, exist_ok=True)

    ensure_artifacts_downloaded(cache_dir)

    input_paths = resolve_and_filter_inputs(input_path, output_dir, override_existing=override)
    if len(input_paths) == 0:
        click.echo("No predictions to run, exiting.")
        return

    # Requested device count (0 when `devices` is neither an int nor a list)
    if isinstance(devices, int):
        n_devices = devices
    elif isinstance(devices, list):
        n_devices = len(devices)
    else:
        n_devices = 0
    strategy = "auto"
    if n_devices > 1:
        strategy = DDPStrategy()
        if len(input_paths) < n_devices:
            raise ValueError("Requested more devices than predictions.")

    noun = "structures" if len(input_paths) > 1 else "structure"
    click.echo(f"Running predictions for {len(input_paths)} {noun}")

    preprocess_dataset(
        inputs=input_paths,
        output_dir=output_dir,
        ccd_pickle=cache_dir / "ccd.pkl",
        allow_msa_server=use_msa_server,
        msa_server_url=msa_server_url,
        msa_pairing_strategy=msa_pairing_strategy,
    )

    processed_root = output_dir / "processed"
    artifacts = InferenceArtifacts(
        manifest=Manifest.load(processed_root / "manifest.json"),
        targets_dir=processed_root / "structures",
        msa_dir=processed_root / "msa",
    )

    dm = BoltzInferenceDataModule(
        manifest=artifacts.manifest,
        target_dir=artifacts.targets_dir,
        msa_dir=artifacts.msa_dir,
        num_workers=num_workers,
    )

    if checkpoint is None:
        ckpt_path = cache_dir / "boltz1_conf.ckpt"
    else:
        ckpt_path = Path(checkpoint)

    model: Boltz1 = Boltz1.load_from_checkpoint(
        ckpt_path,
        strict=True,
        predict_args=dict(
            recycling_steps=recycling_steps,
            sampling_steps=sampling_steps,
            diffusion_samples=diffusion_samples,
            write_confidence_summary=True,
            write_full_pae=write_full_pae,
            write_full_pde=write_full_pde,
        ),
        map_location="cpu",
        diffusion_process_args=asdict(SamplerHyperparams(step_scale=step_scale)),
        ema=False,
    )
    model.eval()

    prediction_writer = BoltzWriter(
        data_dir=artifacts.targets_dir,
        output_dir=output_dir / "predictions",
        output_format=output_format,
    )
    trainer = Trainer(
        default_root_dir=output_dir,
        strategy=strategy,
        callbacks=[prediction_writer],
        accelerator=accelerator,
        devices=devices,
        precision=32,
    )
    trainer.predict(model, datamodule=dm, return_predictions=False)


# Backwards-compatible wrapper (optional): keep old entrypoint name working.
def predict(*args, **kwargs):
    """Legacy name for ``run_inference``; forwards all arguments unchanged."""
    outcome = run_inference(*args, **kwargs)
    return outcome


if __name__ == "__main__":
    # Kaggle run: YAML inputs from ./inputs_prediction, results under ./outputs_prediction,
    # CCD + weights read from the attached (read-only) input dataset used as the cache.
    kaggle_run_config = dict(
        data="./inputs_prediction",
        out_dir="./outputs_prediction",
        cache="/kaggle/input/rna-prediction-boltz/",
        diffusion_samples=10,
        seed=42,
        override=True,
    )
    run_inference(**kaggle_run_config)

Writing inference.py


# Prepare inputs

In [None]:
from pathlib import Path
import pandas as pd

def generate_yaml_inputs(
    csv_path: str | Path,
    output_dir: str | Path = "/kaggle/working/inputs_prediction",
) -> list[Path]:
    """
    Read a CSV of targets and sequences and write one YAML file per target.

    Expected CSV columns: 'target_id', 'sequence'
    """
    csv_path = Path(csv_path)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    df_sequences = pd.read_csv(csv_path)
    # Optional peek
    display(df_sequences.head())

    target_ids = df_sequences["target_id"].tolist()
    rna_seqs = df_sequences["sequence"].tolist()

    written_files: list[Path] = []
    for target_id, rna_seq in zip(target_ids, rna_seqs):
        yaml_text = (
            "constraints: []\n"
            "sequences:\n"
            "- rna:\n"
            "    id:\n"
            "    - A1\n"
            f"    sequence: {rna_seq}"
        )
        yaml_path = output_dir / f"{target_id}.yaml"
        yaml_path.write_text(yaml_text)
        written_files.append(yaml_path)

    return written_files


# Run it: one YAML per row of the competition's test_sequences.csv
yaml_inputs = generate_yaml_inputs(
    output_dir="/kaggle/working/inputs_prediction",
    csv_path="/kaggle/input/stanford-rna-3d-folding/test_sequences.csv",
)

In [12]:
# Sanity check: one YAML file per test target should be listed
%ls inputs_prediction

R1107.yaml  R1116.yaml    R1126.yaml  R1136.yaml  R1149.yaml  R1189.yaml
R1108.yaml  R1117v2.yaml  R1128.yaml  R1138.yaml  R1156.yaml  R1190.yaml


In [13]:
# Should be empty in a fresh session; inference.py writes its results here
%ls outputs_prediction

# Run inference

In [None]:
import gc
import logging
import subprocess
from pathlib import Path
import torch

def clear_cuda_cache() -> None:
    """Free GPU memory held by this (notebook) process before launching inference.

    Garbage is collected first, so tensors kept alive only by reference cycles
    are actually released. Only then does ``torch.cuda.empty_cache()`` return the
    now-unused cached blocks to the driver. Emptying the cache before collecting
    would miss those tensors.

    Best-effort: a no-op on machines without CUDA, and a CUDA error never raises.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    try:
        torch.cuda.empty_cache()
    except Exception:
        # Best-effort cleanup; a failure here must not block the inference run
        pass


# INFO-level "LEVEL | message" logging for the launcher helpers below
logging.basicConfig(format="%(levelname)s | %(message)s", level=logging.INFO)
logger = logging.getLogger(name="inference_exec")

def _tail(text: str, n: int = 40) -> str:
    lines = text.strip().splitlines()
    return "\n".join(lines[-n:]) if lines else ""

def run_inference_script(
    python_bin: str = "python",
    script_path: str | Path = "inference.py",
    extra_args: list[str] | None = None,
    workdir: str | Path | None = None,
) -> subprocess.CompletedProcess:
    """
    Execute ``script_path`` with ``python_bin`` in a child process and return the result.

    stdout/stderr are captured as text (nothing streams live). Once the child
    exits, the return code and the last lines of each stream are logged. This
    function never raises on a non-zero exit; the caller must check ``returncode``.
    """
    command = [python_bin, str(script_path)]
    command.extend(extra_args or [])

    logger.info(f"Launching: {' '.join(command)}")
    completed = subprocess.run(
        command,
        cwd=None if not workdir else str(workdir),
        text=True,
        capture_output=True,
    )
    logger.info(f"Return code: {completed.returncode}")

    # stderr often holds non-fatal warnings, so it is logged as WARNING rather than ERROR
    for header, stream, emit in (
        ("STDOUT (tail):\n", completed.stdout, logger.info),
        ("STDERR (tail):\n", completed.stderr, logger.warning),
    ):
        if stream:
            emit(header + _tail(stream))
    return completed

clear_cuda_cache()  # release GPU memory held by this kernel before the child process starts
result = run_inference_script(
    script_path="inference.py",
    python_bin="python",
    workdir=None,  # run in the notebook's current directory (/kaggle/working)
    extra_args=None,  # inference.py parses no CLI args; its __main__ block holds the config
)

# Check the inference run (return code and log tails)

In [None]:
def summarize_subprocess_result(res: subprocess.CompletedProcess, tail_lines: int = 60):
    """Print the return code and the tails of stdout/stderr, then fail loudly on error.

    Args:
        res: Result of ``run_inference_script`` (or any ``subprocess.run`` call).
            ``stdout``/``stderr`` may be ``None`` when output was not captured.
        tail_lines: Number of trailing lines printed from each stream.

    Raises:
        RuntimeError: If ``res.returncode`` is non-zero, so that Run-All stops
            here instead of gathering results from missing predictions.
    """
    print(f"Return code: {res.returncode}")
    print(f"\n=== STDOUT (last {tail_lines} lines) ===")
    print(_tail(res.stdout or "", tail_lines))
    if res.stderr:
        print(f"\n=== STDERR (last {tail_lines} lines) ===")
        print(_tail(res.stderr, tail_lines))
    if res.returncode != 0:
        # Surface failure early so the notebook stops before "gather results"
        raise RuntimeError(
            f"Inference script failed (non-zero exit code: {res.returncode})."
        )

summarize_subprocess_result(result, tail_lines=80)

CompletedProcess(args=['python', 'inference.py'], returncode=0, stdout='Checking input data.\nRunning predictions for 12 structures\nProcessing input data.\n\nPredicting: |          | 0/? [00:00<?, ?it/s]\nPredicting:   0%|          | 0/12 [00:00<?, ?it/s]\nPredicting DataLoader 0:   0%|          | 0/12 [00:00<?, ?it/s]\nPredicting DataLoader 0:   8%|▊         | 1/12 [00:33<06:04,  0.03it/s]\nPredicting DataLoader 0:  17%|█▋        | 2/12 [04:23<21:56,  0.01it/s]\nPredicting DataLoader 0:  25%|██▌       | 3/12 [05:12<15:36,  0.01it/s]\nPredicting DataLoader 0:  33%|███▎      | 4/12 [19:00<38:01,  0.00it/s]\nPredicting DataLoader 0:  42%|████▏     | 5/12 [19:30<27:18,  0.00it/s]\nPredicting DataLoader 0:  50%|█████     | 6/12 [20:19<20:19,  0.00it/s]\nPredicting DataLoader 0:  58%|█████▊    | 7/12 [21:26<15:18,  0.01it/s]\nPredicting DataLoader 0:  67%|██████▋   | 8/12 [21:42<10:51,  0.01it/s]\nPredicting DataLoader 0:  75%|███████▌  | 9/12 [25:24<08:28,  0.01it/s]\nPredicting DataLoade

# Gather results

In [None]:
from __future__ import annotations
from pathlib import Path
import json
import logging
from typing import List, Tuple

import pandas as pd
from Bio.PDB.MMCIF2Dict import MMCIF2Dict

# config: where inference.py wrote its predictions and where the submission goes.
# PREDICTIONS_ROOT is <out_dir>/boltz_results_<input dir name>/predictions for
# run_inference(data="./inputs_prediction", out_dir="./outputs_prediction") in inference.py.
PREDICTIONS_ROOT = Path("outputs_prediction/boltz_results_inputs_prediction/predictions")
SUBMISSION_TEMPLATE = Path("/kaggle/input/stanford-rna-3d-folding/sample_submission.csv")
SUBMISSION_OUT = Path("submission.csv")
# Model indices 0..N-1 probed per target; keep in sync with diffusion_samples=10 in inference.py
MAX_MODELS_PER_TARGET = 10
TOP_K = 5  # number of predictions to keep per target (fills the x_1..z_5 submission columns)

# logging: basicConfig does nothing if the root logger was already configured by an earlier cell
logging.basicConfig(level=logging.INFO, format="%(levelname)s | %(message)s")
log = logging.getLogger("gather_results")

#  helpers 
def read_c1prime_coords_from_cif(cif_path: Path) -> List[Tuple[float, float, float]]:
    """Return the (x, y, z) of every C1' atom in ``cif_path``, in file order.

    Reads the ``_atom_site`` columns via Biopython's ``MMCIF2Dict``.

    Raises:
        ValueError: If the file contains no C1' atom.
    """
    table = MMCIF2Dict(str(cif_path))
    x_col = table["_atom_site.Cartn_x"]
    y_col = table["_atom_site.Cartn_y"]
    z_col = table["_atom_site.Cartn_z"]
    names = table["_atom_site.label_atom_id"]

    hits: List[Tuple[float, float, float]] = []
    for row, name in enumerate(names):
        if name != "C1'":
            continue
        hits.append((float(x_col[row]), float(y_col[row]), float(z_col[row])))

    if len(hits) == 0:
        raise ValueError(f"No C1' atoms found in {cif_path}")
    return hits


def read_confidence_from_json(json_path: Path) -> float:
    """Return the ``confidence_score`` entry of a Boltz ``confidence_*.json`` file as float.

    Raises:
        KeyError: If the JSON has no ``confidence_score`` key.
    """
    payload = json.loads(json_path.read_text())
    if "confidence_score" in payload:
        return float(payload["confidence_score"])
    raise KeyError(f"'confidence_score' missing in {json_path}")


def load_coords_and_confidence(target_id: str, model_idx: int) -> Tuple[list[Tuple[float, float, float]], float]:
    """
    Return ``(C1' coordinates, confidence score)`` for one target/model pair.

    Files are looked up under ``PREDICTIONS_ROOT/<target_id>/`` using the Boltz
    writer naming: ``<target>_model_<k>.cif`` and ``confidence_<target>_model_<k>.json``.

    Raises:
        FileNotFoundError: If either file is absent.
    """
    target_dir = PREDICTIONS_ROOT / target_id
    stem = f"{target_id}_model_{model_idx}"
    cif_path = target_dir / f"{stem}.cif"
    json_path = target_dir / f"confidence_{stem}.json"

    if not cif_path.exists():
        raise FileNotFoundError(f"Missing CIF: {cif_path}")
    if not json_path.exists():
        raise FileNotFoundError(f"Missing JSON: {json_path}")

    return read_c1prime_coords_from_cif(cif_path), read_confidence_from_json(json_path)


def fill_submission_for_target(
    df_sub: pd.DataFrame,
    target_id: str,
    best_samples: list[tuple[float, list[tuple[float, float, float]], int]],
) -> None:
    """
    Write the coordinates of the ranked samples into ``df_sub`` in place.

    Sample ``k`` (1-based, in the given order) fills columns ``x_k``, ``y_k``,
    ``z_k`` of the rows belonging to ``target_id``, i.e. rows whose ``ID`` is
    ``f"{target_id}_<resid>"``. Rows are matched exactly on the part before the
    last underscore. A prefix test such as ``startswith("R1")`` would also
    claim (and overwrite) the rows of ``R10``, ``R11``, ...

    Coordinates are assigned positionally to the matching rows in template order.
    On a length mismatch a warning is logged and only the first
    ``min(len(coords), n_rows)`` rows are filled.

    Args:
        df_sub: Submission frame with an ``ID`` column and ``x_k/y_k/z_k`` columns.
        target_id: Target whose rows are filled.
        best_samples: ``(confidence, coords, model_idx)`` tuples, best first.
    """
    row_target = df_sub["ID"].astype(str).str.rsplit("_", n=1).str[0]
    mask = row_target == target_id
    if not mask.any():
        log.warning(f"No submission rows match target '{target_id}'. Skipping.")
        return

    target_index = df_sub.index[mask]
    num_rows = len(target_index)

    for rank, (_conf, coords, _model_idx) in enumerate(best_samples, start=1):
        if len(coords) != num_rows:
            log.warning(
                f"Row/coord mismatch for target {target_id} (rank {rank}): "
                f"{len(coords)} coords vs {num_rows} rows in template. Truncating to min."
            )
        limit = min(len(coords), num_rows)
        # Assign only the first `limit` matching rows to keep alignment
        idxs = target_index[:limit]
        kept = coords[:limit]
        df_sub.loc[idxs, f"x_{rank}"] = [c[0] for c in kept]
        df_sub.loc[idxs, f"y_{rank}"] = [c[1] for c in kept]
        df_sub.loc[idxs, f"z_{rank}"] = [c[2] for c in kept]


# main aggregation
# Fail fast with a clear message if inference produced nothing
if not PREDICTIONS_ROOT.is_dir():
    raise FileNotFoundError(f"Predictions directory not found: {PREDICTIONS_ROOT}")

submission_df = pd.read_csv(SUBMISSION_TEMPLATE)
target_dirs = sorted(p.name for p in PREDICTIONS_ROOT.iterdir() if p.is_dir())

log.info(f"Found {len(target_dirs)} targets in {PREDICTIONS_ROOT}")

for target_id in target_dirs:
    samples: list[tuple[float, list[tuple[float, float, float]], int]] = []
    for idx in range(MAX_MODELS_PER_TARGET):
        try:
            coords, conf = load_coords_and_confidence(target_id, idx)
        except FileNotFoundError as exc:
            # Fewer than MAX_MODELS_PER_TARGET samples is expected; keep it quiet
            log.debug(f"Skipping {target_id} model_{idx}: {exc}")
            continue
        except Exception as exc:
            # The files exist but could not be parsed: unexpected, so surface it
            log.warning(f"Skipping {target_id} model_{idx} (unreadable output): {exc}")
            continue
        samples.append((conf, coords, idx))

    if not samples:
        log.warning(f"No valid models found for {target_id}.")
        continue

    # highest confidence first
    samples.sort(key=lambda s: s[0], reverse=True)
    top_samples = samples[:TOP_K]
    # Every x_k/y_k/z_k column should hold a prediction: with fewer than TOP_K
    # usable models, cycle through the best ones rather than leaving the
    # template's values in the remaining columns.
    if len(top_samples) < TOP_K:
        log.warning(
            f"{target_id}: only {len(top_samples)} model(s) available; "
            f"repeating them to fill {TOP_K} slots."
        )
        top_samples = [top_samples[i % len(top_samples)] for i in range(TOP_K)]

    fill_submission_for_target(submission_df, target_id, top_samples)
    log.info(f"{target_id}: used models {[s[2] for s in top_samples]} (by confidence)")

R1128: Used models [0, 1, 2, 3, 4] by confidence
R1136: Used models [0, 1, 2, 3, 4] by confidence
R1108: Used models [0, 1, 2, 3, 4] by confidence
R1126: Used models [0, 1, 2, 3, 4] by confidence
R1189: Used models [0, 1, 2, 3, 4] by confidence
R1149: Used models [0, 1, 2, 3, 4] by confidence
R1107: Used models [0, 1, 2, 3, 4] by confidence
R1190: Used models [0, 1, 2, 3, 4] by confidence
R1116: Used models [0, 1, 2, 3, 4] by confidence
R1117v2: Used models [0, 1, 2, 3, 4] by confidence
R1156: Used models [0, 1, 2, 3, 4] by confidence
R1138: Used models [0, 1, 2, 3, 4] by confidence


In [22]:
# Inspect the working directory before cleanup
%ls

[0m[01;34mboltz[0m/        [01;34minputs_prediction[0m/  [01;34moutputs_prediction[0m/
inference.py  __notebook__.ipynb  submission.csv


In [23]:
# Remove intermediate files from /kaggle/working, which is preserved as output (see the note in the
# first cell). submission_df was already built in memory in "Gather results", so the "Submission"
# cell below does not need any of these files.
%rm -rf boltz

In [24]:
%rm -rf inputs_prediction

In [25]:
%rm -rf outputs_prediction

In [26]:
%rm -rf inference.py

In [27]:
# Confirm the cleanup: only the notebook (and any submission.csv) should remain
%ls

__notebook__.ipynb  submission.csv


# Submission

In [None]:
# Save the submission frame built in "Gather results"
submission_df.to_csv(SUBMISSION_OUT, index=False)
log.info(f"Wrote submission to {SUBMISSION_OUT.resolve()}")

# Read the file back from SUBMISSION_OUT (not a second hard-coded name) to verify what
# was written, without overwriting the in-memory frame
submission_written = pd.read_csv(SUBMISSION_OUT)
assert submission_written.shape == submission_df.shape, "written submission does not round-trip"
submission_written.head(20)

Unnamed: 0,ID,resname,resid,x_1,y_1,z_1,x_2,y_2,z_2,x_3,y_3,z_3,x_4,y_4,z_4,x_5,y_5,z_5
0,R1107_1,G,1,3.70495,-14.33893,-10.09669,4.02854,-17.7368,17.98308,-9.5709,5.12235,16.47663,14.67717,-13.7387,0.36041,-18.78674,7.66877,-6.27028
1,R1107_2,G,2,1.14741,-11.06134,-13.74821,2.70712,-21.31407,14.04303,-5.63369,8.51553,14.91798,10.09968,-16.43514,1.8279,-17.50484,10.34362,-1.59239
2,R1107_3,G,3,-2.14035,-6.70239,-15.06713,-1.0175,-24.00727,11.00196,-0.53064,10.03004,13.33584,5.27593,-16.88308,4.64238,-15.32987,10.8727,3.56209
3,R1107_4,G,4,-2.17234,-0.6093,-13.42071,-6.05388,-24.44052,8.94807,4.4123,9.29037,10.91419,1.01392,-14.32377,6.91798,-11.68615,9.54526,7.53135
4,R1107_5,G,5,0.77151,3.77795,-11.44262,-10.94124,-22.26543,7.6388,7.53758,6.75297,7.00821,-1.47926,-9.50194,5.91292,-6.61575,7.48247,9.00864
5,R1107_6,C,6,5.67863,5.70963,-10.1188,-14.21652,-17.9473,6.72871,7.51149,3.88782,2.32127,-2.03465,-5.31495,2.3708,-1.27673,6.21076,7.60576
6,R1107_7,C,7,11.02761,5.55795,-9.97404,-15.23629,-12.76198,5.73029,3.95808,2.32277,-1.16555,-1.87518,-5.07159,-2.84111,2.58904,6.95155,4.03119
7,R1107_8,A,8,18.04744,4.86127,-8.03881,-13.39502,-8.17929,4.32908,-1.08548,3.56857,-2.47697,0.02936,-0.74087,-5.05757,4.13797,9.39085,-0.29584
8,R1107_9,C,9,15.15094,2.14722,-3.6201,-9.69963,-5.92144,1.72035,-6.7481,3.35136,-3.70763,2.87785,3.48328,-7.68874,5.36403,13.4962,-4.02454
9,R1107_10,A,10,13.85003,1.55464,1.60036,-6.66339,-6.83633,-2.01346,-12.40604,3.41132,-6.09684,4.12825,6.77217,-11.9894,5.90802,11.5565,-11.28776
