# Load & Standardize Inputs

In [3]:
# ============================================================
# NOTEBOOK-1 / STAGE 1 — Load & Standardize Inputs (ONE CELL)
# Stanford RNA 3D Folding Part 2
#
# Goals:
# - Define canonical paths (MSA / PDB_RNA / extra / CSVs)
# - Load train/val/test sequences + train/val labels + sample_submission
# - Standardize types + add basic derived columns (L, temporal_cutoff_dt)
# - Perform strict sanity checks (columns, missing files, sequence chars)
# - Print a compact dataset summary for debugging
#
# Outputs (globals):
# - PATHS, COMP_ROOT, MSA_DIR, PDB_RNA_DIR, EXTRA_DIR, OUT_DIR
# - df_train_seq, df_val_seq, df_test_seq
# - df_train_lbl, df_val_lbl
# - df_sample_sub
# - LABEL_COORD_COLS_TRAIN, LABEL_COORD_COLS_VAL
# - SEQ_ALLOWED
# ============================================================

import os, re, json, math, warnings
from pathlib import Path

import numpy as np
import pandas as pd

warnings.filterwarnings("ignore", category=pd.errors.DtypeWarning)
pd.set_option("display.max_columns", 200)
pd.set_option("display.width", 160)

# ----------------------------
# 0) Canonical paths (as you provided)
# ----------------------------
# Root of the competition dataset; every input path below is derived from it.
COMP_ROOT    = Path("/kaggle/input/stanford-rna-3d-folding-2")
# Dataset sub-directories (existence is checked with _require further down).
MSA_DIR      = COMP_ROOT / "MSA"
PDB_RNA_DIR  = COMP_ROOT / "PDB_RNA"
EXTRA_DIR    = COMP_ROOT / "extra"

# Short-name -> location mapping. Values are plain strings (not Path objects)
# so the whole dict can be serialized with json.dumps in the summary section.
PATHS = {
    "COMP_ROOT": str(COMP_ROOT),
    "MSA_DIR": str(MSA_DIR),
    "PDB_RNA_DIR": str(PDB_RNA_DIR),
    "EXTRA_DIR": str(EXTRA_DIR),
    "SAMPLE_SUB": str(COMP_ROOT / "sample_submission.csv"),
    "TEST_SEQ": str(COMP_ROOT / "test_sequences.csv"),
    "TRAIN_LBL": str(COMP_ROOT / "train_labels.csv"),
    "TRAIN_SEQ": str(COMP_ROOT / "train_sequences.csv"),
    "VAL_LBL": str(COMP_ROOT / "validation_labels.csv"),
    "VAL_SEQ": str(COMP_ROOT / "validation_sequences.csv"),
    # The three entries under EXTRA_DIR are optional: they are only probed with
    # Path.exists() later, never passed to _require.
    "PARSE_FASTA_PY": str(EXTRA_DIR / "parse_fasta_py.py"),
    "RNA_METADATA": str(EXTRA_DIR / "rna_metadata.csv"),
    "EXTRA_README": str(EXTRA_DIR / "README.md"),
}

# Write outputs here (versioned folder recommended)
OUT_DIR = Path("/kaggle/working/rna3d_artifacts_v1")
# Created eagerly, parents included; an already existing directory is not an error.
OUT_DIR.mkdir(parents=True, exist_ok=True)

def _require(p: str | Path, what: str = ""):
    p = Path(p)
    if not p.exists():
        raise FileNotFoundError(f"Missing {what or 'file/dir'}: {p}")
    return p

# Fail fast when a mandatory dataset directory or CSV is absent.
# Directories are checked first, then the six CSV inputs, in a fixed order so
# the first missing item reported is deterministic.
for _required_dir, _dir_label in (
    (MSA_DIR, "MSA_DIR"),
    (PDB_RNA_DIR, "PDB_RNA_DIR"),
    (EXTRA_DIR, "EXTRA_DIR"),
):
    _require(_required_dir, _dir_label)
for k in ("SAMPLE_SUB", "TEST_SEQ", "TRAIN_LBL", "TRAIN_SEQ", "VAL_LBL", "VAL_SEQ"):
    _require(PATHS[k], k)

# ----------------------------
# 1) Helpers
# ----------------------------
SEQ_ALLOWED = set("ACGU")

def _clean_str_series(s: pd.Series) -> pd.Series:
    return s.astype("string").fillna("").str.strip()

def _add_seq_len_and_time(df: pd.DataFrame, name: str) -> pd.DataFrame:
    df = df.copy()
    # Standardize key cols
    for c in ["target_id","sequence","temporal_cutoff","description","stoichiometry","all_sequences","ligand_ids","ligand_SMILES"]:
        if c in df.columns:
            df[c] = _clean_str_series(df[c])
    # Derived columns
    if "sequence" in df.columns:
        df["L"] = df["sequence"].str.len().astype("int32")
    if "temporal_cutoff" in df.columns:
        df["temporal_cutoff_dt"] = pd.to_datetime(df["temporal_cutoff"], errors="coerce", utc=False)
    # Basic validity checks
    if "target_id" in df.columns:
        if df["target_id"].isna().any():
            raise ValueError(f"[{name}] Found NaN target_id.")
        if df["target_id"].duplicated().any():
            # sequences csv should be unique per target_id within split
            dups = df.loc[df["target_id"].duplicated(), "target_id"].head(10).tolist()
            raise ValueError(f"[{name}] Duplicate target_id detected (showing up to 10): {dups}")
    if "sequence" in df.columns:
        bad_mask = ~df["sequence"].str.fullmatch(r"[ACGU]*")
        if bad_mask.any():
            bad = df.loc[bad_mask, ["target_id","sequence"]].head(10)
            raise ValueError(
                f"[{name}] Found non-ACGU characters in `sequence` (showing up to 10 rows):\n{bad}"
            )
    return df

def _read_labels_fast(csv_path: str | Path, split_name: str):
    """
    Load a labels CSV with an explicit dtype per column and sanity-check it.

    Column dtypes requested from the CSV reader:
    - columns named x_<k> / y_<k> / z_<k>  -> float32
    - `resid`                              -> int32
    - `copy`                               -> int16
    - every other column (ID, resname, chain, ...) -> string

    Two helper columns are derived from `ID` by splitting on its last
    underscore: `target_id` (left part) and `resid_from_id` (right part as a
    nullable Int32). Both are set to NA when the split does not yield exactly
    two parts.

    `split_name` only prefixes error messages.

    Returns (dataframe, list of coordinate column names in file order).

    Raises ValueError when a required column is missing, when `resid` has
    missing or non-positive values, or when the number of coordinate columns
    is not a multiple of three.
    """
    csv_path = Path(csv_path)
    # Header-only read: cheap way to learn the column names before choosing dtypes.
    header = pd.read_csv(csv_path, nrows=0).columns.tolist()

    is_coord = re.compile(r"^[xyz]_\d+$").match
    coord_cols = [col for col in header if is_coord(col)]

    absent = sorted({"ID", "resname", "resid"}.difference(header))
    if absent:
        raise ValueError(f"[{split_name}] labels missing required columns: {absent}")

    # Coordinates win first, then the two integer columns, everything else is text.
    integer_dtypes = {"resid": "int32", "copy": "int16"}
    coord_lookup = set(coord_cols)
    dtype = {
        col: ("float32" if col in coord_lookup else integer_dtypes.get(col, "string"))
        for col in header
    }

    df = pd.read_csv(csv_path, dtype=dtype)

    for col in ("ID", "resname", "chain"):
        if col in df.columns:
            df[col] = _clean_str_series(df[col])

    # ID is assumed to look like {target_id}_{resid}; split on the LAST underscore
    # because target ids may themselves contain underscores.
    if "ID" in df.columns:
        id_parts = df["ID"].str.rsplit("_", n=1, expand=True)
        if id_parts.shape[1] != 2:
            df["target_id"] = pd.NA
            df["resid_from_id"] = pd.NA
        else:
            df["target_id"] = id_parts[0].astype("string")
            # kept as a nullable integer so it can be compared with `resid` later
            df["resid_from_id"] = pd.to_numeric(id_parts[1], errors="coerce").astype("Int32")

    resid = df["resid"]
    if resid.isna().any():
        raise ValueError(f"[{split_name}] labels: NaN resid found.")
    non_positive = resid <= 0
    if non_positive.any():
        bad = df.loc[non_positive, ["ID","resid"]].head(10)
        raise ValueError(f"[{split_name}] labels: resid must be 1-based (>0). Bad rows:\n{bad}")

    # One x/y/z triple per structure, so the count must divide evenly by three.
    if len(coord_cols) % 3:
        raise ValueError(f"[{split_name}] labels: number of coord cols is not multiple of 3: {len(coord_cols)}")

    return df, coord_cols

def _print_seq_summary(df: pd.DataFrame, name: str):
    L = df["L"].to_numpy()
    print(f"\n[{name}] n_targets={len(df):,} | L: min={L.min():,}  p50={int(np.median(L)):,}  p90={int(np.quantile(L,0.9)):,}  max={L.max():,}")
    if "temporal_cutoff_dt" in df.columns:
        tmin = df["temporal_cutoff_dt"].min()
        tmax = df["temporal_cutoff_dt"].max()
        nbad = df["temporal_cutoff_dt"].isna().sum()
        print(f"[{name}] temporal_cutoff_dt: min={tmin}  max={tmax}  invalid_dates={nbad:,}")

# ----------------------------
# 2) Load sequences
# ----------------------------
# Columns every sequences CSV must provide for the later stages.
SEQ_REQUIRED = ["target_id","sequence","temporal_cutoff","stoichiometry","all_sequences"]

df_train_seq = pd.read_csv(PATHS["TRAIN_SEQ"])
df_val_seq   = pd.read_csv(PATHS["VAL_SEQ"])
df_test_seq  = pd.read_csv(PATHS["TEST_SEQ"])

# Verify required columns exist in sequences BEFORE standardizing: the
# standardization helper silently skips columns that are absent, so checking
# first makes a malformed file fail immediately with a clear message.
for name, df in [("train_sequences", df_train_seq), ("validation_sequences", df_val_seq), ("test_sequences", df_test_seq)]:
    miss = [c for c in SEQ_REQUIRED if c not in df.columns]
    if miss:
        raise ValueError(f"[{name}] missing required columns: {miss}")

# Clean text columns, add L / temporal_cutoff_dt and run the per-table checks.
df_train_seq = _add_seq_len_and_time(df_train_seq, "train_sequences")
df_val_seq   = _add_seq_len_and_time(df_val_seq, "validation_sequences")
df_test_seq  = _add_seq_len_and_time(df_test_seq, "test_sequences")

# ----------------------------
# 3) Load labels (fast dtype)
# ----------------------------
# Each call returns (labels dataframe, list of coordinate column names found in that file).
df_train_lbl, LABEL_COORD_COLS_TRAIN = _read_labels_fast(PATHS["TRAIN_LBL"], "train_labels")
df_val_lbl,   LABEL_COORD_COLS_VAL   = _read_labels_fast(PATHS["VAL_LBL"], "validation_labels")

# ----------------------------
# 4) Load sample_submission (and basic checks)
# ----------------------------
df_sample_sub = pd.read_csv(PATHS["SAMPLE_SUB"])

# Must contain ID + coordinates for 5 predictions (x_1..z_5)
if "ID" not in df_sample_sub.columns:
    raise ValueError("sample_submission.csv must contain column `ID`.")

sub_cols = df_sample_sub.columns.tolist()
# Stricter than the labels pattern: only suffixes 1..5 count as submission coordinates.
sub_coord_pat = re.compile(r"^[xyz]_[1-5]$")
sub_coord_cols = [c for c in sub_cols if sub_coord_pat.match(c)]
if len(sub_coord_cols) != 15:
    # x_1..x_5 (5) + y_1..y_5 (5) + z_1..z_5 (5) = 15
    raise ValueError(f"sample_submission must contain 15 coord cols (x_1..z_5). Found {len(sub_coord_cols)}")

# ----------------------------
# 5) Optional: detect aux files for later stages (no heavy parsing here)
# ----------------------------
# Presence flags only; a missing aux file is reported in the summary, not treated as an error.
has_parse_fasta = Path(PATHS["PARSE_FASTA_PY"]).exists()
has_rna_metadata = Path(PATHS["RNA_METADATA"]).exists()

# Sorted so the "example" file shown in the summary is deterministic.
msa_files = sorted(MSA_DIR.glob("*.MSA.fasta"))
cif_files = sorted(PDB_RNA_DIR.glob("*.cif"))

# ----------------------------
# 6) Summary prints (compact)
# ----------------------------
print("=== PATHS ===")
print(json.dumps(PATHS, indent=2))

# Per-split sequence length and temporal cutoff overview.
_print_seq_summary(df_train_seq, "train_sequences")
_print_seq_summary(df_val_seq,   "validation_sequences")
_print_seq_summary(df_test_seq,  "test_sequences")

# refs = number of x/y/z triples, i.e. coordinate columns divided by three.
print(f"\n[train_labels] rows={len(df_train_lbl):,} | coord_cols={len(LABEL_COORD_COLS_TRAIN)} (refs={len(LABEL_COORD_COLS_TRAIN)//3})")
print(f"[val_labels]   rows={len(df_val_lbl):,} | coord_cols={len(LABEL_COORD_COLS_VAL)} (refs={len(LABEL_COORD_COLS_VAL)//3})")

print(f"\n[sample_submission] rows={len(df_sample_sub):,} | coord_cols={len(sub_coord_cols)} (must be 15)")

print("\n=== AUX FILES (for next stages) ===")
print(f"MSA_DIR exists: {MSA_DIR.exists()} | n_msa_files(train+val provided)={len(msa_files):,} | example={msa_files[0].name if msa_files else None}")
print(f"PDB_RNA_DIR exists: {PDB_RNA_DIR.exists()} | n_cif_files={len(cif_files):,} | example={cif_files[0].name if cif_files else None}")
print(f"extra/parse_fasta_py.py exists: {has_parse_fasta}")
print(f"extra/rna_metadata.csv exists: {has_rna_metadata}")
print(f"OUT_DIR: {OUT_DIR}")

# Quick alignment sanity: labels target_id should be subset of sequences target_id for the same split
train_targets = set(df_train_seq["target_id"].tolist())
val_targets   = set(df_val_seq["target_id"].tolist())
# Label rows whose ID could not be split carry an NA target_id; they are dropped
# here and therefore not reported by this check.
train_lbl_targets = set(df_train_lbl["target_id"].dropna().unique().tolist())
val_lbl_targets   = set(df_val_lbl["target_id"].dropna().unique().tolist())

# Only the one-directional difference is checked: sequences without labels are allowed.
missing_train = sorted(list(train_lbl_targets - train_targets))[:10]
missing_val   = sorted(list(val_lbl_targets - val_targets))[:10]

if missing_train:
    raise ValueError(f"train_labels contain target_id not found in train_sequences (show up to 10): {missing_train}")
if missing_val:
    raise ValueError(f"validation_labels contain target_id not found in validation_sequences (show up to 10): {missing_val}")

print("\n[OK] Stage 1 complete: inputs loaded & standardized.")


=== PATHS ===
{
  "COMP_ROOT": "/kaggle/input/stanford-rna-3d-folding-2",
  "MSA_DIR": "/kaggle/input/stanford-rna-3d-folding-2/MSA",
  "PDB_RNA_DIR": "/kaggle/input/stanford-rna-3d-folding-2/PDB_RNA",
  "EXTRA_DIR": "/kaggle/input/stanford-rna-3d-folding-2/extra",
  "SAMPLE_SUB": "/kaggle/input/stanford-rna-3d-folding-2/sample_submission.csv",
  "TEST_SEQ": "/kaggle/input/stanford-rna-3d-folding-2/test_sequences.csv",
  "TRAIN_LBL": "/kaggle/input/stanford-rna-3d-folding-2/train_labels.csv",
  "TRAIN_SEQ": "/kaggle/input/stanford-rna-3d-folding-2/train_sequences.csv",
  "VAL_LBL": "/kaggle/input/stanford-rna-3d-folding-2/validation_labels.csv",
  "VAL_SEQ": "/kaggle/input/stanford-rna-3d-folding-2/validation_sequences.csv",
  "PARSE_FASTA_PY": "/kaggle/input/stanford-rna-3d-folding-2/extra/parse_fasta_py.py",
  "RNA_METADATA": "/kaggle/input/stanford-rna-3d-folding-2/extra/rna_metadata.csv",
  "EXTRA_README": "/kaggle/input/stanford-rna-3d-folding-2/extra/README.md"
}

[train_sequence

# Parse all_sequences + Resolve stoichiometry + Build Chain Boundaries

In [6]:
# ============================================================
# NOTEBOOK-1 / STAGE 2 — Parse all_sequences + Resolve stoichiometry + Build Chain Boundaries (ONE CELL)
# FULL REVISION (fixes the 3 failing train targets)
#
# Key fixes vs previous:
# 1) More robust FASTA header parsing: collects BOTH chain_id and auth_chain_id as aliases
#    e.g. "Chain A[auth B]" -> aliases include "A" and "B"
# 2) Normalize chain sequences to match canonical target sequence:
#    - uppercase
#    - convert T -> U
# 3) Do NOT partially write segments for a target that fails validation (atomic per-target segment commit)
# 4) Hard-fail policy is adjustable; default keeps going but prints failing targets clearly
#
# Requires globals from STAGE 1:
# - PATHS, OUT_DIR
# - df_train_seq, df_val_seq, df_test_seq
#
# Writes:
# - OUT_DIR/tables/targets_*_stage2.parquet
# - OUT_DIR/tables/segments_*.parquet
# - OUT_DIR/meta/qa_stage2_*.csv
#
# Outputs (globals):
# - df_train_seq2, df_val_seq2, df_test_seq2
# - seg_train, seg_val, seg_test
# - qa_train2, qa_val2, qa_test2
# ============================================================

import re, json, hashlib, warnings
from pathlib import Path
from typing import Dict, Tuple, List, Optional

import numpy as np
import pandas as pd

warnings.filterwarnings("ignore", category=pd.errors.DtypeWarning)

# ----------------------------
# 0) Require STAGE 1 globals
# ----------------------------
# Stage 2 cannot run without the objects Stage 1 leaves in the notebook namespace.
for need in ("PATHS", "OUT_DIR", "df_train_seq", "df_val_seq", "df_test_seq"):
    if need in globals():
        continue
    raise RuntimeError(f"Missing `{need}`. Run STAGE 1 first.")

# Normalize OUT_DIR to a Path and make sure both artifact folders exist.
OUT_DIR = Path(OUT_DIR)
TABLE_DIR, META_DIR = OUT_DIR / "tables", OUT_DIR / "meta"
TABLE_DIR.mkdir(parents=True, exist_ok=True)
META_DIR.mkdir(parents=True, exist_ok=True)

# ----------------------------
# 1) Robust FASTA parser (extends provided idea):
#    - extracts chain_id and auth_chain_id aliases for each chain entry
# ----------------------------
RE_CHAINS_PREFIX = re.compile(r"^Chains?\s+", re.IGNORECASE)
RE_AUTH = re.compile(r"\[auth\s+([^\]]+)\]", re.IGNORECASE)

def _norm_seq(s: str) -> str:
    # canonicalize: uppercase + DNA->RNA
    return str(s).strip().upper().replace("T", "U")

def _parse_chain_token(token: str) -> Tuple[Optional[str], Optional[str]]:
    """
    token examples:
      "A[auth B]"
      "A [auth B]"
      "A"
      "B[auth 1]" (multi-char auth allowed)
    returns:
      (chain_id, auth_chain_id)
    """
    t = token.strip()
    if not t:
        return None, None
    # chain_id is the leading chunk before '[' (or whole string if no '[')
    chain_id = t.split("[", 1)[0].strip()
    chain_id = chain_id.split()[0].strip() if chain_id else None

    m = RE_AUTH.search(t)
    auth_id = m.group(1).strip() if m else None
    return chain_id, auth_id

def parse_fasta_allseq(fasta_content: str) -> Dict[str, Tuple[str, List[str]]]:
    """
    Parse a multi-record FASTA string into {primary_key: (sequence, aliases)}.

    Header handling: the header is split on '|' and its second field is read
    as the chain list (an optional leading "Chain"/"Chains" is removed, then
    the list is split on commas). Each chain token is decoded with
    `_parse_chain_token`.

    - primary_key: the first auth_chain_id seen in the header, otherwise the
      first chain_id.
    - aliases: every chain_id and auth_chain_id listed in the header,
      de-duplicated in first-seen order.
    - sequence: all lines up to the next header, concatenated and normalized
      with `_norm_seq`.

    Records whose header yields no usable identifier are dropped. If two
    records share a primary key, the later one replaces the earlier one.
    Returns an empty dict when `fasta_content` is None.
    """
    result: Dict[str, Tuple[str, List[str]]] = {}
    if fasta_content is None:
        return result

    # Strip each line exactly once so that "is this a header?" gives the same
    # answer in the outer scan and in the sequence-collection loop. Testing the
    # raw line in one place and the stripped line in the other would let a
    # header with leading whitespace be appended to the previous sequence.
    lines = [raw.strip() for raw in str(fasta_content).strip().splitlines()]
    n_lines = len(lines)

    i = 0
    while i < n_lines:
        line = lines[i]
        if not line.startswith(">"):
            # stray text before the first header (or a blank line): skip it
            i += 1
            continue

        parts = line.split("|")
        chains_part = parts[1].strip() if len(parts) >= 2 else ""
        chains_part = RE_CHAINS_PREFIX.sub("", chains_part).strip()

        # split by comma for "Chains A[auth A], B[auth B]"
        chain_tokens = [c.strip() for c in chains_part.split(",") if c.strip()] if chains_part else []

        # collect both chain_id and auth_id as aliases
        aliases: List[str] = []
        primary_auth: Optional[str] = None
        primary_chain: Optional[str] = None
        for tok in chain_tokens:
            chain_id, auth_id = _parse_chain_token(tok)
            if primary_chain is None and chain_id:
                primary_chain = chain_id
            if primary_auth is None and auth_id:
                primary_auth = auth_id
            if chain_id:
                aliases.append(chain_id)
            if auth_id:
                aliases.append(auth_id)

        # if header malformed, aliases empty -> primary stays None and the record is dropped
        primary = primary_auth or primary_chain

        # read sequence lines until next header; collect chunks and join once
        # instead of growing a string with += for every line
        i += 1
        seq_chunks: List[str] = []
        while i < n_lines and not lines[i].startswith(">"):
            seq_chunks.append(lines[i])
            i += 1
        seq = _norm_seq("".join(seq_chunks))

        if primary:
            # de-dup aliases, keep first-seen order (dict preserves insertion order)
            cleaned = (str(a).strip() for a in aliases)
            aliases_u = list(dict.fromkeys(a for a in cleaned if a))
            # ensure primary included
            if primary not in aliases_u:
                aliases_u.insert(0, primary)
            result[str(primary).strip()] = (seq, aliases_u)

    return result

# ----------------------------
# 2) Stoichiometry parsing + hashing utils
# ----------------------------
STOICH_ITEM_RE = re.compile(r"^\s*([A-Za-z0-9]+)\s*:\s*([0-9]+)\s*$")

def parse_stoichiometry(stoich: str):
    """
    Parse a stoichiometry string such as "A:2;B:1" into [("A", 2), ("B", 1)].

    Curly braces are ignored, items are separated by ';', and empty items are
    skipped. None or a blank string yields an empty list.

    Raises ValueError for an item that is not `<alphanumeric>:<digits>` or
    whose copy count is not positive.
    """
    if stoich is None:
        return []
    text = str(stoich).strip()
    if text == "":
        return []

    text = text.replace("{", "").replace("}", "")
    parsed = []
    for raw_item in text.split(";"):
        part = raw_item.strip()
        if part == "":
            continue
        found = STOICH_ITEM_RE.match(part)
        if found is None:
            raise ValueError(f"Bad stoichiometry token: `{part}` (from `{stoich}`)")
        n_copies = int(found.group(2))
        if n_copies <= 0:
            raise ValueError(f"Stoichiometry copies must be >0: `{part}` (from `{stoich}`)")
        parsed.append((found.group(1).strip(), n_copies))
    return parsed

def md5_str(s: str) -> str:
    """Return the hexadecimal MD5 digest of `s` encoded as UTF-8."""
    return hashlib.md5(s.encode("utf-8")).hexdigest()

def md5_pieces(pieces: List[str]) -> str:
    """
    Return the hexadecimal MD5 digest of the concatenation of `pieces`
    (each encoded as UTF-8) without building the concatenated string.
    """
    digest = hashlib.md5()
    encoded_chunks = (piece.encode("utf-8") for piece in pieces)
    for chunk in encoded_chunks:
        digest.update(chunk)
    return digest.hexdigest()

def build_alias_to_primary(chain_dict: Dict[str, Tuple[str, List[str]]]):
    """
    Build lookup tables from a parsed FASTA chain dictionary.

    chain_dict: {primary: (seq, aliases)}; None is treated as empty.

    Returns (primary_to_seq, alias_to_primary):
    - primary_to_seq: primary key -> sequence normalized with `_norm_seq`
    - alias_to_primary: every alias (and every primary) -> primary key

    Resolution rules:
    - A primary key always maps to itself. Primaries are registered before any
      alias so that an alias of an earlier record can never redirect a token
      that exactly names another record.
    - Among plain aliases the first record that declares an alias wins.
    - Entries whose primary key is blank are skipped.
    """
    primary_to_seq = {}
    alias_to_primary = {}
    pending_aliases = []

    # Pass 1: sequences and identity mappings for every primary key.
    for primary, (seq, aliases) in (chain_dict or {}).items():
        primary = str(primary).strip()
        if not primary:
            continue
        primary_to_seq[primary] = _norm_seq(seq)
        alias_to_primary[primary] = primary
        pending_aliases.append((primary, aliases or [primary]))

    # Pass 2: remaining aliases, first declaration wins, never overriding a primary.
    for primary, aliases in pending_aliases:
        for a in aliases:
            a = str(a).strip()
            if a:
                alias_to_primary.setdefault(a, primary)

    return primary_to_seq, alias_to_primary

# ----------------------------
# 3) Core processing per split (atomic per-target segments)
# ----------------------------
def process_split(df_seq: pd.DataFrame, split_name: str, join_threshold: int = 20000, hard_fail: bool = False):
    """
    Resolve stoichiometry into contiguous chain segments for every target of one split.

    For each row of `df_seq` the function:
    1. parses `all_sequences` into chains (`parse_fasta_allseq`) and builds the
       alias lookup (`build_alias_to_primary`);
    2. parses `stoichiometry` (`parse_stoichiometry`) and lays out one segment
       per chain copy, in stoichiometry order, with 1-based inclusive bounds;
    3. checks that the summed segment length equals `L` and that the
       concatenated chain sequences equal `sequence` (compared directly up to
       `join_threshold` residues, via MD5 above it).

    Segments of a target are committed only when all checks pass. Any
    exception raised while processing a target is recorded as text in
    `err_stage2` / `err` and processing continues with the next target.
    NOTE: `boundary_starts_json`, `n_segments` and `L_calc` reflect whatever
    was computed before a failure, so they may be filled for a target whose
    `rebuild_ok` is False.

    Parameters:
    - df_seq: sequences table with columns target_id, sequence, L,
      stoichiometry, all_sequences (not modified; a copy is annotated).
    - split_name: label stored in the `split` column and used in messages.
    - join_threshold: maximum length compared by direct string equality.
    - hard_fail: when True, raise after the summary if any target failed.

    Returns (targets_df, segments_df, qa_df). `segments_df` and `qa_df` always
    carry their full column set, even when they contain no rows.

    Raises ValueError when a required column is missing, RuntimeError when
    `hard_fail` is set and at least one target failed.
    """
    required = ["target_id","sequence","L","stoichiometry","all_sequences"]
    miss = [c for c in required if c not in df_seq.columns]
    if miss:
        raise ValueError(f"[{split_name}] missing required columns: {miss}")

    # Fixed output schemas. Building a DataFrame from an empty list of records
    # would otherwise give a frame with NO columns, which breaks any consumer
    # that indexes columns such as "target_id".
    seg_columns = [
        "split", "target_id", "seg_idx", "chain_token", "chain_primary",
        "copy_idx", "start_1based", "end_1based", "seg_len",
    ]
    qa_columns = [
        "split", "target_id", "L", "L_calc", "n_segments", "n_primary_chains",
        "parse_ok", "stoich_ok", "rebuild_ok", "err",
    ]

    out = df_seq.copy()
    # normalize target sequence too (should already be ACGU)
    out["sequence"] = out["sequence"].astype("string").fillna("").str.strip().str.upper().str.replace("T", "U", regex=False)

    # new cols
    out["stoich_items_json"] = ""
    out["n_segments"] = 0
    out["n_primary_chains"] = 0
    out["L_calc"] = pd.NA
    out["boundary_starts_json"] = ""
    out["parse_ok"] = False
    out["stoich_ok"] = False
    out["rebuild_ok"] = False
    out["err_stage2"] = ""

    seg_rows_all = []
    qa_rows = []

    for row in out.itertuples(index=True):
        idx = row.Index
        tid = row.target_id
        seq_given = row.sequence
        L_given = int(row.L)

        parse_ok = False
        stoich_ok = False
        rebuild_ok = False
        err = ""
        L_calc = None
        boundary_starts = []
        stoich_items = None
        n_segments = 0
        n_primary = 0

        # per-target temp holders to avoid partial commit
        seg_rows_tmp = []
        pieces = []
        primaries_used = set()

        try:
            chain_dict = parse_fasta_allseq(row.all_sequences)
            primary_to_seq, alias_to_primary = build_alias_to_primary(chain_dict)
            if not primary_to_seq:
                raise ValueError("parse_fasta_allseq returned empty chain dictionary")
            parse_ok = True

            stoich_items = parse_stoichiometry(row.stoichiometry)
            if not stoich_items:
                raise ValueError("empty stoichiometry after parsing")

            start = 1
            seg_idx = 0

            for chain_token, copies in stoich_items:
                # try direct, then uppercase (defensive)
                ct = str(chain_token).strip()
                if ct not in alias_to_primary and ct.upper() in alias_to_primary:
                    ct = ct.upper()

                if ct not in alias_to_primary:
                    raise KeyError(f"stoichiometry chain `{chain_token}` not found among aliases (n_aliases={len(alias_to_primary)})")
                primary = alias_to_primary[ct]

                chain_seq = _norm_seq(primary_to_seq.get(primary, ""))
                if not chain_seq:
                    raise ValueError(f"resolved primary chain `{primary}` has empty sequence")

                primaries_used.add(primary)

                # the chain length is the same for every copy, so compute it once
                seg_len = len(chain_seq)
                for copy_idx in range(1, copies + 1):
                    end = start + seg_len - 1
                    seg_idx += 1
                    boundary_starts.append(start)
                    pieces.append(chain_seq)

                    seg_rows_tmp.append({
                        "split": split_name,
                        "target_id": tid,
                        "seg_idx": seg_idx,
                        "chain_token": chain_token,
                        "chain_primary": primary,
                        "copy_idx": copy_idx,
                        "start_1based": start,
                        "end_1based": end,
                        "seg_len": seg_len,
                    })
                    start = end + 1

            L_calc = start - 1
            n_segments = seg_idx
            n_primary = len(primaries_used)
            stoich_ok = True

            # length check
            if L_calc != L_given:
                raise ValueError(f"L mismatch: L_calc={L_calc} vs L_given={L_given}")

            # sequence check (direct for moderate, md5 for huge)
            if L_calc <= join_threshold:
                seq_rebuilt = "".join(pieces)
                if seq_rebuilt != seq_given:
                    # lengths already agree, so a mismatch here means mapping/order issue
                    raise ValueError("sequence mismatch (direct compare)")
            else:
                if md5_pieces(pieces) != md5_str(seq_given):
                    raise ValueError("sequence mismatch (md5 compare)")

            rebuild_ok = True

            # commit segments only if OK
            seg_rows_all.extend(seg_rows_tmp)

        except Exception as e:
            # Deliberate best-effort: one bad target must not stop the split.
            # The message is kept for the QA table and the failure preview.
            err = str(e)

        out.at[idx, "stoich_items_json"] = json.dumps(stoich_items) if stoich_items is not None else ""
        out.at[idx, "n_segments"] = int(n_segments)
        out.at[idx, "n_primary_chains"] = int(n_primary)
        out.at[idx, "L_calc"] = int(L_calc) if L_calc is not None else pd.NA
        out.at[idx, "boundary_starts_json"] = json.dumps(boundary_starts) if boundary_starts else ""
        out.at[idx, "parse_ok"] = bool(parse_ok)
        out.at[idx, "stoich_ok"] = bool(stoich_ok)
        out.at[idx, "rebuild_ok"] = bool(rebuild_ok)
        out.at[idx, "err_stage2"] = err

        qa_rows.append({
            "split": split_name,
            "target_id": tid,
            "L": L_given,
            "L_calc": (int(L_calc) if L_calc is not None else np.nan),
            "n_segments": int(n_segments),
            "n_primary_chains": int(n_primary),
            "parse_ok": bool(parse_ok),
            "stoich_ok": bool(stoich_ok),
            "rebuild_ok": bool(rebuild_ok),
            "err": err,
        })

    qa_df = pd.DataFrame(qa_rows, columns=qa_columns)
    seg_df = pd.DataFrame(seg_rows_all, columns=seg_columns)

    n_total = len(out)
    n_ok = int(out["rebuild_ok"].sum())
    n_bad = n_total - n_ok
    print(f"[{split_name}] STAGE2 rebuild_ok: {n_ok:,}/{n_total:,} (bad={n_bad:,})")

    if n_bad > 0:
        bad_preview = qa_df.loc[~qa_df["rebuild_ok"]].head(20)[
            ["target_id","L","L_calc","parse_ok","stoich_ok","err"]
        ]
        print(f"\n[{split_name}] Failures (up to 20):")
        print(bad_preview.to_string(index=False))
        if hard_fail:
            raise RuntimeError(f"[{split_name}] STAGE2 found {n_bad} failing targets. Fix mapping before proceeding.")

    return out, seg_df, qa_df

# ----------------------------
# 4) Run STAGE2 (default hard_fail=False to allow progress; set True if you want strict)
# ----------------------------
# Train tolerates failing targets (they are reported and kept with rebuild_ok=False);
# val and test abort on the first split that contains any failure.
df_train_seq2, seg_train, qa_train2 = process_split(df_train_seq, "train", join_threshold=20000, hard_fail=False)
df_val_seq2,   seg_val,   qa_val2   = process_split(df_val_seq,   "val",   join_threshold=20000, hard_fail=True)   # val/test should be clean
df_test_seq2,  seg_test,  qa_test2  = process_split(df_test_seq,  "test",  join_threshold=20000, hard_fail=True)

# ----------------------------
# 5) Save artifacts
# ----------------------------
# Annotated target tables (one row per target, including failed ones).
df_train_seq2.to_parquet(TABLE_DIR / "targets_train_stage2.parquet", index=False)
df_val_seq2.to_parquet(  TABLE_DIR / "targets_val_stage2.parquet",   index=False)
df_test_seq2.to_parquet( TABLE_DIR / "targets_test_stage2.parquet",  index=False)

# Segment tables (one row per chain copy; only targets that passed validation).
seg_train.to_parquet(TABLE_DIR / "segments_train.parquet", index=False)
seg_val.to_parquet(  TABLE_DIR / "segments_val.parquet",   index=False)
seg_test.to_parquet( TABLE_DIR / "segments_test.parquet",  index=False)

# QA tables as CSV so they can be inspected without a parquet reader.
qa_train2.to_csv(META_DIR / "qa_stage2_train.csv", index=False)
qa_val2.to_csv(  META_DIR / "qa_stage2_val.csv",   index=False)
qa_test2.to_csv( META_DIR / "qa_stage2_test.csv",  index=False)

# Helpful summary: how many targets were skipped in train (if any)
n_train_bad = int((~df_train_seq2["rebuild_ok"]).sum())
if n_train_bad > 0:
    bad_ids = df_train_seq2.loc[~df_train_seq2["rebuild_ok"], "target_id"].head(20).tolist()
    print(f"\n[train] WARNING: {n_train_bad} targets could not be segmented/rebuilt. Example ids (up to 20): {bad_ids}")
    print("They remain in targets_train_stage2.parquet with rebuild_ok=False, but will NOT appear in segments_train.parquet.")

print("\n[OK] STAGE 2 complete.")
print(f"Saved targets:  {TABLE_DIR}/targets_*_stage2.parquet")
print(f"Saved segments: {TABLE_DIR}/segments_*.parquet")
print(f"Saved QA:       {META_DIR}/qa_stage2_*.csv")


[train] STAGE2 rebuild_ok: 5,550/5,716 (bad=166)

[train] Failures (up to 20):
target_id    L  L_calc  parse_ok  stoich_ok                                     err
     1FEU   40    42.0      True       True     L mismatch: L_calc=42 vs L_given=40
     1M5K  113   184.0      True       True   L mismatch: L_calc=184 vs L_given=113
     1N35   15    20.0      True       True     L mismatch: L_calc=20 vs L_given=15
     1TFY   46    44.0      True       True     L mismatch: L_calc=44 vs L_given=46
     1YSH  163    96.0      True       True    L mismatch: L_calc=96 vs L_given=163
     2E9T   15    14.0      True       True     L mismatch: L_calc=14 vs L_given=15
     3BO2  222    41.0      True       True    L mismatch: L_calc=41 vs L_given=222
     3BO3  222    41.0      True       True    L mismatch: L_calc=41 vs L_given=222
     3DEG  313   367.0      True       True   L mismatch: L_calc=367 vs L_given=313
     3HAX   77    28.0      True       True     L mismatch: L_calc=28 vs L_given=

# Rebuild & Validate Target Sequence + Residue Index Table

In [7]:
# ============================================================
# NOTEBOOK-1 / STAGE 3 — Rebuild-Validate (coverage) + Residue Index Table (ONE CELL)
#
# Requires STAGE 2 artifacts already saved in:
#   OUT_DIR=/kaggle/working/rna3d_artifacts_v1
#   - tables/targets_*_stage2.parquet
#   - tables/segments_*.parquet
#
# What this stage does:
# 1) Load targets_stage2 + segments tables
# 2) For each split:
#    - keep only targets with rebuild_ok=True (train has some False; val/test all True)
#    - validate segment coverage: start at 1, contiguous, last end == L
#    - build a per-residue index table:
#        columns: target_id, resid, resname, chain_primary, copy_idx, seg_idx, is_chain_start, is_chain_end
#    - write as partitioned parquet parts (safe memory):
#        OUT_DIR/residue_index/{split}/part-xxxxx.parquet
#    - write QA summary:
#        OUT_DIR/meta/qa_stage3_{split}.csv
#
# Outputs (globals):
# - targets2_train, targets2_val, targets2_test
# - seg_train, seg_val, seg_test
# - qa3_train, qa3_val, qa3_test
# - RESIDX_DIR
# ============================================================

import gc, math, warnings
from pathlib import Path
import numpy as np
import pandas as pd

warnings.filterwarnings("ignore", category=pd.errors.DtypeWarning)

# ----------------------------
# 0) Paths + Load STAGE2 artifacts
# ----------------------------
# Reuse OUT_DIR from an earlier stage when it is defined; otherwise fall back to
# the default artifact folder so this cell can run in a fresh kernel.
OUT_DIR = Path(OUT_DIR) if "OUT_DIR" in globals() else Path("/kaggle/working/rna3d_artifacts_v1")

TABLE_DIR = OUT_DIR.joinpath("tables")
META_DIR = OUT_DIR.joinpath("meta")
RESIDX_DIR = OUT_DIR.joinpath("residue_index")

# Both folders are written by STAGE 2; their absence means it has not been run.
for p in (TABLE_DIR, META_DIR):
    if p.exists():
        continue
    raise FileNotFoundError(f"Missing STAGE2 output directory: {p}. Run STAGE 2 first.")

def _req(p: Path):
    if not p.exists():
        raise FileNotFoundError(f"Missing file: {p}")
    return p

# Target tables written by STAGE 2 (one row per target).
targets2_train = pd.read_parquet(_req(TABLE_DIR.joinpath("targets_train_stage2.parquet")))
targets2_val   = pd.read_parquet(_req(TABLE_DIR.joinpath("targets_val_stage2.parquet")))
targets2_test  = pd.read_parquet(_req(TABLE_DIR.joinpath("targets_test_stage2.parquet")))

# Segment tables written by STAGE 2 (one row per chain copy).
seg_train = pd.read_parquet(_req(TABLE_DIR.joinpath("segments_train.parquet")))
seg_val   = pd.read_parquet(_req(TABLE_DIR.joinpath("segments_val.parquet")))
seg_test  = pd.read_parquet(_req(TABLE_DIR.joinpath("segments_test.parquet")))

# Re-assert column dtypes after the parquet round-trip (frames are updated in place).
for df in (targets2_train, targets2_val, targets2_test):
    for c, kind in (("target_id", "string"), ("sequence", "string"), ("L", "int32")):
        df[c] = df[c].astype(kind)
    if "rebuild_ok" in df.columns:
        df["rebuild_ok"] = df["rebuild_ok"].astype("bool")

for df in (seg_train, seg_val, seg_test):
    df["target_id"] = df["target_id"].astype("string")
    for c in ("seg_idx", "copy_idx", "start_1based", "end_1based", "seg_len"):
        if c not in df.columns:
            continue
        df[c] = df[c].astype("int32")

# ----------------------------
# 1) Builder (batched) to avoid memory spikes
# ----------------------------
# Column order of the QA table; fixed so the table keeps its schema even with zero rows.
_RESIDX_QA_COLUMNS = ["split", "target_id", "L", "coverage_ok", "n_segments", "n_rows", "err"]


def _segment_coverage_error(starts, ends, L: int) -> str:
    """Return "" when the sorted segments tile residues 1..L exactly, else the first violation."""
    first_start = int(starts[0])
    if first_start != 1:
        return f"segments do not start at 1 (start={first_start})"
    prev_end = 0
    for j, (st, en) in enumerate(zip(starts, ends)):
        st, en = int(st), int(en)
        if j > 0 and st != prev_end + 1:
            return f"non-contiguous segments at j={j} (start={st}, expected={prev_end+1})"
        if en < st:
            # Such a segment would contribute no residues and break the per-residue build.
            return f"empty or reversed segment at j={j} (start={st}, end={en})"
        prev_end = en
    if prev_end != L:
        return f"segments end != L (end={prev_end}, L={L})"
    return ""


def build_residue_index(
    split_name: str,
    targets_df: pd.DataFrame,
    seg_df: pd.DataFrame,
    out_root: Path,
    batch_targets: int = 200,
    hard_fail: bool = True,
    meta_dir: "Path | None" = None,
):
    """
    Build the per-residue index for one split and write it as parquet parts.

    Writes partitioned parquet parts to:
      out_root/{split}/part-00000.parquet, ...
    and a QA table to:
      {meta_dir}/qa_stage3_{split}.csv

    Parameters:
      split_name    : label used in paths, log lines and the QA "split" column.
      targets_df    : needs target_id, sequence, L; rows with rebuild_ok != True are
                      skipped when that column is present.
      seg_df        : needs target_id, seg_idx, copy_idx, start_1based, end_1based,
                      chain_primary.
      out_root      : parent directory of the per-split output directory.
      batch_targets : number of targets written per parquet part (>= 1).
      hard_fail     : raise RuntimeError on the first bad target instead of recording
                      it in the QA table and moving on.
      meta_dir      : directory for the QA csv; defaults to the module-level META_DIR.

    Returns the QA dataframe (one row per processed target).

    Raises:
      ValueError   : batch_targets < 1.
      RuntimeError : a target fails validation and hard_fail is True.
    """
    if batch_targets < 1:
        raise ValueError(f"batch_targets must be >= 1 (got {batch_targets})")

    out_dir = Path(out_root) / split_name
    out_dir.mkdir(parents=True, exist_ok=True)
    # Parts are numbered from 0 on every run; drop leftovers from an earlier run so a
    # "part-*.parquet" glob over this directory never mixes old and new rows.
    for stale_part in out_dir.glob("part-*.parquet"):
        stale_part.unlink()

    qa_dir = META_DIR if meta_dir is None else Path(meta_dir)
    qa_dir.mkdir(parents=True, exist_ok=True)

    # Filter to clean targets only
    if "rebuild_ok" in targets_df.columns:
        targets_ok = targets_df[targets_df["rebuild_ok"] == True]
    else:
        targets_ok = targets_df

    # Quick mapping: target_id -> (sequence, L)  (avoids heavy merges)
    tid_list = targets_ok["target_id"].tolist()
    seq_map = dict(zip(tid_list, targets_ok["sequence"].tolist()))
    L_map = dict(zip(tid_list, targets_ok["L"].tolist()))

    seg_g = seg_df.groupby("target_id", sort=False)
    ids_with_segments = set(seg_df["target_id"].unique())

    qa_rows = []
    part_idx = 0

    def _fail(tid, L, nseg, err, exc_msg):
        """Record a failed target in the QA rows; raise when hard_fail is set."""
        qa_rows.append({
            "split": split_name, "target_id": tid, "L": L,
            "coverage_ok": False, "n_segments": nseg, "n_rows": 0,
            "err": err,
        })
        if hard_fail:
            raise RuntimeError(exc_msg)

    # Batch processing (one parquet part per batch keeps peak memory bounded)
    for b0 in range(0, len(tid_list), batch_targets):
        batch = tid_list[b0:b0 + batch_targets]

        cols_target_id = []
        cols_resid = []
        cols_resname = []
        cols_chain = []
        cols_copy = []
        cols_segidx = []
        cols_is_start = []
        cols_is_end = []

        for tid in batch:
            seq = str(seq_map[tid])
            L = int(L_map[tid])

            if tid not in ids_with_segments:
                _fail(tid, L, 0, "missing segments for target_id",
                      f"[{split_name}] Missing segments for target_id={tid}")
                continue

            segs = seg_g.get_group(tid).sort_values(["start_1based", "seg_idx"])
            nseg = len(segs)

            # ---- coverage validation: segments must tile 1..L without gaps/overlaps ----
            try:
                starts = segs["start_1based"].to_numpy(dtype=np.int64)
                ends = segs["end_1based"].to_numpy(dtype=np.int64)
                err = _segment_coverage_error(starts, ends, L)
            except (TypeError, ValueError) as exc:
                # e.g. missing start/end values that cannot be read as integers
                err = str(exc)
            if err:
                _fail(tid, L, nseg, err,
                      f"[{split_name}] Coverage validation failed for {tid}: {err}")
                continue

            # ---- sequence as an array of single bytes, one per residue ----
            try:
                seq_bytes = np.frombuffer(seq.encode("ascii"), dtype="S1")  # shape (len(seq),)
            except UnicodeEncodeError as exc:
                _fail(tid, L, nseg, f"sequence is not ASCII: {exc}",
                      f"[{split_name}] non-ASCII sequence for {tid}")
                continue
            if seq_bytes.shape[0] != L:
                _fail(tid, L, nseg,
                      f"sequence length mismatch in memory (len(seq)={seq_bytes.shape[0]}, L={L})",
                      f"[{split_name}] sequence length mismatch for {tid}")
                continue

            # ---- build per-residue rows ----
            # The validated segments cover 1..L in order, so resid/resname are simply the
            # full range / full sequence, and per-segment attributes are repeated by length.
            seg_lens = ends - starts + 1
            is_start = np.zeros(L, dtype=bool)
            is_start[starts - 1] = True
            is_end = np.zeros(L, dtype=bool)
            is_end[ends - 1] = True

            cols_target_id.append(np.full(L, tid, dtype=object))
            cols_resid.append(np.arange(1, L + 1, dtype=np.int32))
            cols_resname.append(seq_bytes.astype("U1"))  # bytes -> 1-char str array
            cols_chain.append(np.repeat(segs["chain_primary"].to_numpy(dtype=object), seg_lens))
            cols_copy.append(np.repeat(segs["copy_idx"].to_numpy(dtype=np.int32), seg_lens))
            cols_segidx.append(np.repeat(segs["seg_idx"].to_numpy(dtype=np.int32), seg_lens))
            cols_is_start.append(is_start)
            cols_is_end.append(is_end)

            qa_rows.append({
                "split": split_name, "target_id": tid, "L": L,
                "coverage_ok": True, "n_segments": nseg, "n_rows": L,
                "err": "",
            })

        # Write this batch if any data
        if cols_target_id:
            df_part = pd.DataFrame({
                "target_id": np.concatenate(cols_target_id),
                "resid":     np.concatenate(cols_resid).astype("int32"),
                "resname":   pd.Categorical(np.concatenate(cols_resname)),
                "chain_primary": pd.Categorical(np.concatenate(cols_chain)),
                "copy_idx":  np.concatenate(cols_copy).astype("int16"),
                "seg_idx":   np.concatenate(cols_segidx).astype("int16"),
                "is_chain_start": np.concatenate(cols_is_start).astype(bool),
                "is_chain_end":   np.concatenate(cols_is_end).astype(bool),
            })

            part_path = out_dir / f"part-{part_idx:05d}.parquet"
            df_part.to_parquet(part_path, index=False)
            part_idx += 1

            # free the batch frame before building the next one
            del df_part
            gc.collect()

        if (b0 // batch_targets) % 10 == 0:
            print(f"[{split_name}] processed {min(b0+batch_targets, len(tid_list)):,}/{len(tid_list):,} targets | parts={part_idx}")

    # Explicit columns keep the schema (and the summary below) valid with zero rows.
    qa_df = pd.DataFrame(qa_rows, columns=_RESIDX_QA_COLUMNS)

    # Save QA
    qa_path = qa_dir / f"qa_stage3_{split_name}.csv"
    qa_df.to_csv(qa_path, index=False)

    # Print summary
    n_total = len(tid_list)
    n_ok = int(qa_df["coverage_ok"].eq(True).sum())
    n_bad = n_total - n_ok
    print(f"\n[{split_name}] STAGE3 coverage_ok: {n_ok:,}/{n_total:,} (bad={n_bad:,})")
    if n_bad > 0:
        print(f"[{split_name}] Examples of failures (up to 10):")
        print(qa_df.loc[qa_df["coverage_ok"].ne(True)].head(10)[["target_id","L","n_segments","err"]].to_string(index=False))

    print(f"[{split_name}] residue_index parts written to: {out_dir}")
    print(f"[{split_name}] QA saved: {qa_path}")

    return qa_df

# ----------------------------
# 2) Run STAGE3 for train/val/test
# ----------------------------
# json is needed for the pointer file below, but this cell's own import line does not
# bring it in; import it here so the cell also runs in a fresh kernel.
import json

print("=== STAGE 3: Build residue index tables ===")
# All splits run with hard_fail=True: build_residue_index already drops targets whose
# rebuild_ok is not True, so any remaining coverage failure should stop the run.
qa3_train = build_residue_index("train", targets2_train, seg_train, RESIDX_DIR, batch_targets=200, hard_fail=True)
qa3_val   = build_residue_index("val",   targets2_val,   seg_val,   RESIDX_DIR, batch_targets=200, hard_fail=True)
qa3_test  = build_residue_index("test",  targets2_test,  seg_test,  RESIDX_DIR, batch_targets=200, hard_fail=True)

# ----------------------------
# 3) Save a small pointer file listing where residue parts are
# ----------------------------
ptr = {
    "residue_index_dir": str(RESIDX_DIR),
    "train_glob": str((RESIDX_DIR/"train"/"part-*.parquet")),
    "val_glob":   str((RESIDX_DIR/"val"/"part-*.parquet")),
    "test_glob":  str((RESIDX_DIR/"test"/"part-*.parquet")),
}
(META_DIR / "residue_index_paths_stage3.json").write_text(json.dumps(ptr, indent=2))

print("\n[OK] STAGE 3 complete.")
print(json.dumps(ptr, indent=2))


=== STAGE 3: Build residue index tables ===
[train] processed 200/5,550 targets | parts=1
[train] processed 2,200/5,550 targets | parts=11
[train] processed 4,200/5,550 targets | parts=21

[train] STAGE3 coverage_ok: 5,550/5,550 (bad=0)
[train] residue_index parts written to: /kaggle/working/rna3d_artifacts_v1/residue_index/train
[train] QA saved: /kaggle/working/rna3d_artifacts_v1/meta/qa_stage3_train.csv
[val] processed 28/28 targets | parts=1

[val] STAGE3 coverage_ok: 28/28 (bad=0)
[val] residue_index parts written to: /kaggle/working/rna3d_artifacts_v1/residue_index/val
[val] QA saved: /kaggle/working/rna3d_artifacts_v1/meta/qa_stage3_val.csv
[test] processed 28/28 targets | parts=1

[test] STAGE3 coverage_ok: 28/28 (bad=0)
[test] residue_index parts written to: /kaggle/working/rna3d_artifacts_v1/residue_index/test
[test] QA saved: /kaggle/working/rna3d_artifacts_v1/meta/qa_stage3_test.csv

[OK] STAGE 3 complete.
{
  "residue_index_dir": "/kaggle/working/rna3d_artifacts_v1/residue

# Build Clean Label Tensors (Multi-Reference) for Train/Validation

In [8]:
# ============================================================
# NOTEBOOK-1 / STAGE 4 — Build Clean Label Tensors (Multi-Reference) (ONE CELL)
#
# Requires:
# - OUT_DIR from previous stages (default /kaggle/working/rna3d_artifacts_v1)
# - STAGE 2 outputs:
#     tables/targets_train_stage2.parquet
#     tables/targets_val_stage2.parquet
# - Labels in memory from STAGE 1 (preferred):
#     df_train_lbl, df_val_lbl
#   If missing, will reload from PATHS["TRAIN_LBL"] / PATHS["VAL_LBL"].
#
# Outputs:
# - labels_npz/train/<target_id>.npz
# - labels_npz/val/<target_id>.npz
# - meta/qa_stage4_train.csv
# - meta/qa_stage4_val.csv
# - meta/labels_manifest_train.parquet
# - meta/labels_manifest_val.parquet
# ============================================================

import os, re, gc, json, warnings
from pathlib import Path
import numpy as np
import pandas as pd

warnings.filterwarnings("ignore", category=pd.errors.DtypeWarning)

# ----------------------------
# 0) Paths
# ----------------------------
# Reuse OUT_DIR left in the kernel by an earlier stage; otherwise use the default location.
OUT_DIR = Path(globals().get("OUT_DIR", "/kaggle/working/rna3d_artifacts_v1"))

TABLE_DIR = OUT_DIR / "tables"
META_DIR = OUT_DIR / "meta"
LBL_DIR = OUT_DIR / "labels_npz"
LBL_DIR_TRAIN = LBL_DIR / "train"
LBL_DIR_VAL = LBL_DIR / "val"
for p in (TABLE_DIR, META_DIR, LBL_DIR_TRAIN, LBL_DIR_VAL):
    p.mkdir(parents=True, exist_ok=True)


def _req(p: Path):
    """Return ``p`` unchanged; raise FileNotFoundError when it does not exist."""
    if p.exists():
        return p
    raise FileNotFoundError(f"Missing file: {p}")


targets2_train = pd.read_parquet(_req(TABLE_DIR / "targets_train_stage2.parquet"))
targets2_val = pd.read_parquet(_req(TABLE_DIR / "targets_val_stage2.parquet"))

# Normalise both tables the same way: string ids, upper-case sequences with T -> U, int32 L.
for _tbl in (targets2_train, targets2_val):
    _tbl["target_id"] = _tbl["target_id"].astype("string")
    _tbl["sequence"] = _tbl["sequence"].astype("string").str.upper().str.replace("T", "U", regex=False)
    _tbl["L"] = _tbl["L"].astype("int32")

# Train: keep only rebuild_ok=True (train had some failures in Stage 2)
targets2_train_ok = (
    targets2_train[targets2_train["rebuild_ok"] == True].copy()
    if "rebuild_ok" in targets2_train.columns
    else targets2_train.copy()
)

# Val is used unfiltered (expected to be all ok).
targets2_val_ok = targets2_val.copy()

print(f"[train] targets total={len(targets2_train):,} | using rebuild_ok={len(targets2_train_ok):,}")
print(f"[val]   targets total={len(targets2_val_ok):,}")

# ----------------------------
# 1) Get labels DF (use Stage 1 globals if present, else reload)
# ----------------------------
COORD_PAT = re.compile(r"^[xyz]_\d+$")

def _read_labels_fast(csv_path: str | Path, split_name: str):
    csv_path = Path(csv_path)
    cols = pd.read_csv(csv_path, nrows=0).columns.tolist()
    coord_cols = [c for c in cols if COORD_PAT.match(c)]
    required = {"ID","resname","resid"}
    missing_req = sorted(list(required - set(cols)))
    if missing_req:
        raise ValueError(f"[{split_name}] labels missing required columns: {missing_req}")
    dtype = {}
    for c in cols:
        if c in coord_cols:
            dtype[c] = "float32"
        elif c == "resid":
            dtype[c] = "int32"
        elif c == "copy":
            dtype[c] = "int16"
        else:
            dtype[c] = "string"
    df = pd.read_csv(csv_path, dtype=dtype)
    # standardize
    df["ID"] = df["ID"].astype("string").fillna("").str.strip()
    df["resname"] = df["resname"].astype("string").fillna("").str.strip().str.upper().str.replace("T","U", regex=False)
    # extract target_id from ID
    sp = df["ID"].str.rsplit("_", n=1, expand=True)
    if sp.shape[1] == 2:
        df["target_id"] = sp[0].astype("string")
        df["resid_from_id"] = pd.to_numeric(sp[1], errors="coerce").astype("Int32")
    else:
        df["target_id"] = pd.NA
        df["resid_from_id"] = pd.NA
    return df, coord_cols

# Fallback label locations, only needed when STAGE 1 did not leave PATHS in the kernel.
if "PATHS" not in globals():
    COMP_ROOT = Path("/kaggle/input/stanford-rna-3d-folding-2")
    PATHS = {
        "TRAIN_LBL": str(COMP_ROOT / "train_labels.csv"),
        "VAL_LBL": str(COMP_ROOT / "validation_labels.csv"),
    }

# Prefer the label frames STAGE 1 left in memory; reload from CSV otherwise.
_stage1_train_lbl = globals().get("df_train_lbl")
if isinstance(_stage1_train_lbl, pd.DataFrame):
    df_train_lbl_use = _stage1_train_lbl
    train_coord_cols = [c for c in df_train_lbl_use.columns if COORD_PAT.match(c)]
else:
    df_train_lbl_use, train_coord_cols = _read_labels_fast(PATHS["TRAIN_LBL"], "train_labels")

_stage1_val_lbl = globals().get("df_val_lbl")
if isinstance(_stage1_val_lbl, pd.DataFrame):
    df_val_lbl_use = _stage1_val_lbl
    val_coord_cols = [c for c in df_val_lbl_use.columns if COORD_PAT.match(c)]
else:
    df_val_lbl_use, val_coord_cols = _read_labels_fast(PATHS["VAL_LBL"], "validation_labels")

# ensure target_id exists (in case Stage1 used different extraction):
# it is the part of ID before the last underscore
for df in (df_train_lbl_use, df_val_lbl_use):
    if "target_id" in df.columns:
        continue
    id_parts = df["ID"].astype("string").str.rsplit("_", n=1, expand=True)
    df["target_id"] = id_parts[0].astype("string")
# determine n_ref
def _nref(coord_cols):
    if len(coord_cols) % 3 != 0:
        raise ValueError(f"Coord cols not multiple of 3: {len(coord_cols)}")
    return len(coord_cols) // 3

# Number of reference structures per split (coordinate columns / 3), printed for a quick
# sanity check of the label schema.
nref_train = _nref(train_coord_cols)
nref_val   = _nref(val_coord_cols)
print(f"[train_labels] coord_cols={len(train_coord_cols)} => n_ref={nref_train}")
print(f"[val_labels]   coord_cols={len(val_coord_cols)} => n_ref={nref_val}")

# build ordered triplets per ref k: (x_k,y_k,z_k)
def _coord_triplets(coord_cols):
    # build dict k -> {x,y,z}
    d = {}
    for c in coord_cols:
        axis, k = c.split("_", 1)
        k = int(k)
        d.setdefault(k, {})[axis] = c
    ks = sorted(d.keys())
    triplets = []
    for k in ks:
        if not all(a in d[k] for a in ["x","y","z"]):
            raise ValueError(f"Missing xyz for ref={k}")
        triplets.append((k, d[k]["x"], d[k]["y"], d[k]["z"]))
    return triplets

# Ordered (k, x_col, y_col, z_col) tuples for each split; passed to build_labels_npz below.
train_triplets = _coord_triplets(train_coord_cols)
val_triplets   = _coord_triplets(val_coord_cols)

# ----------------------------
# 2) Core builder: per target -> npz
# ----------------------------
# Fixed schemas so both result tables keep their columns even when they have zero rows.
_LABEL_QA_COLUMNS = ["split", "target_id", "L", "n_ref", "ok", "npz_path", "err"]
_LABEL_MANIFEST_COLUMNS = ["target_id", "L", "n_ref", "npz_path"]


def build_labels_npz(split_name: str, targets_ok: pd.DataFrame, df_lbl: pd.DataFrame, triplets, out_dir: Path,
                     hard_fail: bool = True, print_every: int = 200):
    """
    Validate the label rows of every target and write one <target_id>.npz per target.

    Parameters:
      split_name  : label used in log lines and the QA "split" column.
      targets_ok  : needs target_id, sequence, L.
      df_lbl      : label rows with target_id, resid, resname and the coordinate columns
                    named in ``triplets``. It is not modified.
      triplets    : iterable of (k, x_col, y_col, z_col), one per reference structure.
      out_dir     : directory receiving the npz files (created if missing).
      hard_fail   : re-raise the first per-target error instead of recording it.
      print_every : progress interval in targets; values <= 0 disable progress output.

    A target passes when its label rows are exactly resid 1..L (consecutive) and every
    residue name equals the corresponding sequence character. Each npz stores:
      target_id, L, coords_ref (n_ref, L, 3) float32, mask_valid (n_ref, L) bool
      (True where x, y and z are all finite), resname_seq (L,).

    Returns (qa_df, manifest_df): one QA row per target, one manifest row per written npz.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # maps for quick access
    tids = targets_ok["target_id"].tolist()
    seq_map = dict(zip(tids, targets_ok["sequence"].tolist()))
    L_map = dict(zip(tids, targets_ok["L"].tolist()))

    # Filter to the requested targets BEFORE copying, so only the needed rows are duplicated.
    wanted_rows = df_lbl["target_id"].astype("string").isin(set(tids))
    df = df_lbl.loc[wanted_rows].copy()
    df["target_id"] = df["target_id"].astype("string")

    # ensure resid/resname types
    df["resid"] = pd.to_numeric(df["resid"], errors="coerce").astype("Int32")
    df["resname"] = df["resname"].astype("string").fillna("").str.strip().str.upper().str.replace("T","U", regex=False)

    g = df.groupby("target_id", sort=False)
    ids_with_labels = set(df["target_id"].unique())

    n_ref = len(triplets)
    qa_rows = []
    manifest_rows = []

    done = 0
    for tid in tids:
        seq = str(seq_map[tid])
        L = int(L_map[tid])

        ok = True
        err = ""
        npz_path = out_dir / f"{tid}.npz"

        try:
            if tid not in ids_with_labels:
                raise ValueError("no label rows for target_id")

            # stable sort by resid
            part = g.get_group(tid).sort_values("resid", kind="mergesort")

            # coverage check (a missing resid makes this conversion raise)
            resid = part["resid"].to_numpy(dtype=np.int32, copy=False)
            if resid.size != L:
                raise ValueError(f"row count != L (rows={resid.size}, L={L})")
            if resid[0] != 1 or resid[-1] != L:
                raise ValueError(f"resid range not 1..L (min={resid[0]}, max={resid[-1]})")
            # unique & consecutive
            if np.any(np.diff(resid) != 1):
                raise ValueError("resid not strictly consecutive 1..L")

            # resname check against sequence
            seq_arr = np.frombuffer(seq.encode("ascii"), dtype="S1").astype("U1")
            if seq_arr.shape[0] != L:
                raise ValueError("sequence length mismatch in memory")
            # Compare full residue names: casting to a 1-char dtype would truncate a
            # multi-character name and let it match on its first character alone.
            resname = np.asarray(part["resname"].tolist(), dtype=str)
            mismatch = np.flatnonzero(resname != seq_arr)
            if mismatch.size > 0:
                j = int(mismatch[0])
                raise ValueError(f"resname mismatch at resid={j+1}: label={resname[j]} vs seq={seq_arr[j]}")

            # build coords_ref (n_ref, L, 3)
            coords = np.empty((n_ref, L, 3), dtype=np.float32)
            for ri, (_k, xcol, ycol, zcol) in enumerate(triplets):
                coords[ri] = part[[xcol, ycol, zcol]].to_numpy(dtype=np.float32)

            mask_valid = np.isfinite(coords).all(axis=2)  # (n_ref, L)

            # store resname_seq so later stages don't need to re-read sequences parquet for validation
            np.savez(
                npz_path,
                target_id=str(tid),
                L=np.int32(L),
                coords_ref=coords,             # (n_ref,L,3)
                mask_valid=mask_valid,         # (n_ref,L)
                resname_seq=seq_arr,           # (L,)
            )

        except Exception as e:
            # Per-target boundary: record the failure and keep going unless hard_fail is set.
            ok = False
            err = str(e)
            if hard_fail:
                raise

        qa_rows.append({
            "split": split_name,
            "target_id": tid,
            "L": L,
            "n_ref": n_ref,
            "ok": ok,
            "npz_path": str(npz_path) if ok else "",
            "err": err,
        })

        if ok:
            manifest_rows.append({
                "target_id": tid,
                "L": L,
                "n_ref": n_ref,
                "npz_path": str(npz_path),
            })

        done += 1
        if print_every > 0:
            if done % print_every == 0:
                print(f"[{split_name}] written {done:,}/{len(tids):,} npz")

            # periodic GC
            if done % (print_every * 5) == 0:
                gc.collect()

    qa_df = pd.DataFrame(qa_rows, columns=_LABEL_QA_COLUMNS)
    manifest_df = pd.DataFrame(manifest_rows, columns=_LABEL_MANIFEST_COLUMNS)

    # summary
    n_ok = int(qa_df["ok"].eq(True).sum())
    n_bad = len(qa_df) - n_ok
    print(f"\n[{split_name}] STAGE4 labels ok: {n_ok:,}/{len(qa_df):,} (bad={n_bad:,})")
    if n_bad > 0:
        print(f"[{split_name}] failures (up to 10):")
        print(qa_df.loc[qa_df["ok"].ne(True)].head(10)[["target_id","L","err"]].to_string(index=False))

    return qa_df, manifest_df

# ----------------------------
# 3) Run for train + val, saving QA + manifest per split
# ----------------------------
print("\n=== STAGE 4: Build label tensors (npz) ===")
qa4_train, manifest_train = build_labels_npz(
    split_name="train",
    targets_ok=targets2_train_ok,
    df_lbl=df_train_lbl_use,
    triplets=train_triplets,
    out_dir=LBL_DIR_TRAIN,
    hard_fail=False,      # keep running and report; you can set True once you trust everything
    print_every=200
)

# Persist the train QA + manifest before starting val: val runs with hard_fail=True, so a
# val failure raises and would otherwise leave the finished train run without metadata on disk.
qa4_train.to_csv(META_DIR / "qa_stage4_train.csv", index=False)
manifest_train.to_parquet(META_DIR / "labels_manifest_train.parquet", index=False)

qa4_val, manifest_val = build_labels_npz(
    split_name="val",
    targets_ok=targets2_val_ok,
    df_lbl=df_val_lbl_use,
    triplets=val_triplets,
    out_dir=LBL_DIR_VAL,
    hard_fail=True,
    print_every=28
)

qa4_val.to_csv(META_DIR / "qa_stage4_val.csv", index=False)
manifest_val.to_parquet(META_DIR / "labels_manifest_val.parquet", index=False)

# Save pointers
ptr = {
    "labels_npz_train_dir": str(LBL_DIR_TRAIN),
    "labels_npz_val_dir": str(LBL_DIR_VAL),
    "manifest_train": str(META_DIR / "labels_manifest_train.parquet"),
    "manifest_val": str(META_DIR / "labels_manifest_val.parquet"),
    "qa_train": str(META_DIR / "qa_stage4_train.csv"),
    "qa_val": str(META_DIR / "qa_stage4_val.csv"),
}
(META_DIR / "labels_paths_stage4.json").write_text(json.dumps(ptr, indent=2))

print("\n[OK] STAGE 4 complete.")
print(json.dumps(ptr, indent=2))


[train] targets total=5,716 | using rebuild_ok=5,550
[val]   targets total=28
[train_labels] coord_cols=3 => n_ref=1
[val_labels]   coord_cols=120 => n_ref=40

=== STAGE 4: Build label tensors (npz) ===
[train] written 200/5,550 npz
[train] written 400/5,550 npz
[train] written 600/5,550 npz
[train] written 800/5,550 npz
[train] written 1,000/5,550 npz
[train] written 1,200/5,550 npz
[train] written 1,400/5,550 npz
[train] written 1,600/5,550 npz
[train] written 1,800/5,550 npz
[train] written 2,000/5,550 npz
[train] written 2,200/5,550 npz
[train] written 2,400/5,550 npz
[train] written 2,600/5,550 npz
[train] written 2,800/5,550 npz
[train] written 3,000/5,550 npz
[train] written 3,200/5,550 npz
[train] written 3,400/5,550 npz
[train] written 3,600/5,550 npz
[train] written 3,800/5,550 npz
[train] written 4,000/5,550 npz
[train] written 4,200/5,550 npz
[train] written 4,400/5,550 npz
[train] written 4,600/5,550 npz
[train] written 4,800/5,550 npz
[train] written 5,000/5,550 npz
[trai

# Build TBM Template Index + Parse MSA (Train/Val) + Export Artifacts Dataset

In [9]:
# ============================================================
# NOTEBOOK-1 / STAGE 5 — Build TBM Template Index + Parse MSA (Train/Val) + Export Artifacts Manifest (ONE CELL)
#
# Inputs (competition):
# - /kaggle/input/stanford-rna-3d-folding-2/extra/rna_metadata.csv
# - /kaggle/input/stanford-rna-3d-folding-2/PDB_RNA/*.cif
# - /kaggle/input/stanford-rna-3d-folding-2/MSA/{target_id}.MSA.fasta
#
# Requires previous stages outputs in OUT_DIR (default: /kaggle/working/rna3d_artifacts_v1):
# - tables/targets_train_stage2.parquet
# - tables/targets_val_stage2.parquet
# - meta/qa_stage2_*.csv, meta/qa_stage3_*.csv, meta/qa_stage4_*.csv (optional but expected)
#
# Outputs:
# - OUT_DIR/tbm/template_index.parquet
# - OUT_DIR/msa/msa_index_train.parquet
# - OUT_DIR/msa/msa_index_val.parquet
# - OUT_DIR/meta/artifacts_manifest_stage5.json
# - OUT_DIR/meta/stage5_config.json
#
# Notes:
# - Train has some rebuild_ok=False from Stage2; we index MSAs for rebuild_ok=True only (recommended).
# - MSA parsing here is "light": sanity + stats + stable pointers. Full MSA usage happens in Notebook-2/3.
# ============================================================

import os, re, gc, json, time, warnings
from pathlib import Path
from typing import Dict, Any, Tuple, Optional

import numpy as np
import pandas as pd

warnings.filterwarnings("ignore", category=pd.errors.DtypeWarning)

# ----------------------------
# 0) Paths
# ----------------------------
# Fallback input locations, only needed when no earlier stage left PATHS in the kernel.
if "PATHS" not in globals():
    COMP_ROOT = Path("/kaggle/input/stanford-rna-3d-folding-2")
    PATHS = {
        "COMP_ROOT": str(COMP_ROOT),
        "MSA_DIR": str(COMP_ROOT / "MSA"),
        "PDB_RNA_DIR": str(COMP_ROOT / "PDB_RNA"),
        "EXTRA_DIR": str(COMP_ROOT / "extra"),
        "RNA_METADATA": str(COMP_ROOT / "extra" / "rna_metadata.csv"),
    }

# Each location falls back to its conventional place when PATHS lacks the key.
COMP_ROOT = Path(PATHS.get("COMP_ROOT", "/kaggle/input/stanford-rna-3d-folding-2"))
MSA_DIR = Path(PATHS.get("MSA_DIR", COMP_ROOT / "MSA"))
PDB_RNA_DIR = Path(PATHS.get("PDB_RNA_DIR", COMP_ROOT / "PDB_RNA"))
EXTRA_DIR = Path(PATHS.get("EXTRA_DIR", COMP_ROOT / "extra"))
RNA_META_CSV = Path(PATHS.get("RNA_METADATA", EXTRA_DIR / "rna_metadata.csv"))

# Reuse OUT_DIR left in the kernel by an earlier stage; otherwise use the default location.
OUT_DIR = Path(globals().get("OUT_DIR", "/kaggle/working/rna3d_artifacts_v1"))

TABLE_DIR = OUT_DIR / "tables"
META_DIR = OUT_DIR / "meta"
TBM_DIR = OUT_DIR / "tbm"
MSA_OUT = OUT_DIR / "msa"

for p in (TABLE_DIR, META_DIR, TBM_DIR, MSA_OUT):
    p.mkdir(parents=True, exist_ok=True)


def _req(p: Path, what="file/dir"):
    """Return ``p`` unchanged; raise FileNotFoundError naming ``what`` when it is missing."""
    if p.exists():
        return p
    raise FileNotFoundError(f"Missing {what}: {p}")


# Fail early if any competition input or STAGE 2 table is absent.
for _needed, _label in (
    (MSA_DIR, "MSA_DIR"),
    (PDB_RNA_DIR, "PDB_RNA_DIR"),
    (RNA_META_CSV, "rna_metadata.csv"),
    (TABLE_DIR / "targets_train_stage2.parquet", "targets_train_stage2.parquet"),
    (TABLE_DIR / "targets_val_stage2.parquet", "targets_val_stage2.parquet"),
):
    _req(_needed, _label)

# ----------------------------
# 1) Build TBM template index from rna_metadata.csv + has_cif
# ----------------------------
print("=== STAGE 5.1: Build TBM template_index.parquet ===")

# Lower-cased stems of every .cif file in PDB_RNA_DIR, for case-insensitive lookup.
cif_stems = {
    entry.stem.lower()
    for entry in PDB_RNA_DIR.iterdir()
    if entry.is_file() and entry.suffix.lower() == ".cif"
}

print(f"[PDB_RNA] cif files found: {len(cif_stems):,} (stems cached)")

# Read only the header first, then load just the wanted columns that actually exist.
meta_cols = pd.read_csv(RNA_META_CSV, nrows=0).columns.tolist()

wanted = [
    "target_id","pdb_id","chain_id","auth_chain_id","entity_id","entity_type",
    "sequence","canonical_sequence","full_sequence",
    "temporal_cutoff","resolution","method","title","keyword_ribosome",
    "group_id","seq_group_id",
    "mmseqs_0.900","mmseqs_0.950","mmseqs_0.850",
    "composition_rna_fraction","composition_na_hybrid_fraction",
    "length","length_observed","length_expected","fraction_observed",
    "missing_residues","nonstandard_residues","undefined_residues","unexpected_residues",
    "total_structuredness_adjusted","intra_chain_structuredness_adjusted","inter_chain_structuredness_adjusted",
]
# "pdb_id" and "sequence" are part of `wanted`, so they are kept whenever the file has them.
usecols = [c for c in wanted if c in meta_cols]

# dtype suggestions (best-effort); any column not listed here is read as string
_boolean_cols = {"keyword_ribosome", "missing_residues", "nonstandard_residues",
                 "undefined_residues", "unexpected_residues"}
_integer_cols = {"length", "length_observed", "length_expected"}
_float_cols = {"resolution", "composition_rna_fraction", "composition_na_hybrid_fraction",
               "fraction_observed", "total_structuredness_adjusted",
               "intra_chain_structuredness_adjusted", "inter_chain_structuredness_adjusted"}

dtype = {}
for c in usecols:
    if c in _boolean_cols:
        dtype[c] = "boolean"
    elif c in _integer_cols:
        dtype[c] = "Int32"
    elif c in _float_cols:
        dtype[c] = "float32"
    else:
        dtype[c] = "string"

t0 = time.time()
df_meta = pd.read_csv(RNA_META_CSV, usecols=usecols, dtype=dtype, low_memory=False)
print(f"[rna_metadata] loaded rows={len(df_meta):,} cols={len(df_meta.columns)} in {time.time()-t0:.1f}s")

# Normalize key fields
_has_pdb_id = "pdb_id" in df_meta.columns
if _has_pdb_id:
    df_meta["pdb_id"] = df_meta["pdb_id"].astype("string").str.strip()
    df_meta["pdb_id_lc"] = df_meta["pdb_id"].str.lower()

if "sequence" in df_meta.columns:
    _seq = df_meta["sequence"].astype("string").str.strip().str.upper()
    df_meta["sequence"] = _seq.str.replace("T", "U", regex=False)
    df_meta["seq_len"] = df_meta["sequence"].str.len().astype("Int32")

# has_cif flag: True when the lower-cased pdb_id matches a cached CIF stem
df_meta["has_cif"] = df_meta["pdb_id_lc"].isin(cif_stems) if "pdb_id_lc" in df_meta.columns else False

# Normalize temporal_cutoff if present (unparseable values become NaT)
if "temporal_cutoff" in df_meta.columns:
    df_meta["temporal_cutoff_dt"] = pd.to_datetime(df_meta["temporal_cutoff"], errors="coerce")

# Save template index
template_path = TBM_DIR / "template_index.parquet"
df_meta.to_parquet(template_path, index=False)
print(f"[OK] Saved: {template_path}")

# ----------------------------
# 2) Parse MSA for Train/Val: build msa_index_{split}.parquet (stats + pointers)
# ----------------------------
print("\n=== STAGE 5.2: Parse MSA (Train/Val) -> msa_index_{split}.parquet ===")

targets_train2 = pd.read_parquet(TABLE_DIR / "targets_train_stage2.parquet")
targets_val2 = pd.read_parquet(TABLE_DIR / "targets_val_stage2.parquet")

for _tbl in (targets_train2, targets_val2):
    _tbl["target_id"] = _tbl["target_id"].astype("string")

# Use only rebuild_ok=True for train (all train targets when the column is absent)
_train_ids = targets_train2["target_id"]
if "rebuild_ok" in targets_train2.columns:
    _train_ids = _train_ids[targets_train2["rebuild_ok"] == True]
train_ids = _train_ids.tolist()

val_ids = targets_val2["target_id"].tolist()

print(f"[train] indexing MSA for targets: {len(train_ids):,}")
print(f"[val]   indexing MSA for targets: {len(val_ids):,}")

# Lightweight MSA scanner (fast + safe)
# We do NOT build per-position arrays here; we just validate + compute robust summary stats.
def scan_msa_file(path: Path, max_chars: int = 3_000_000) -> Dict[str, Any]:
    """
    Scan a FASTA-style MSA file and return lightweight summary statistics.

    Lines starting with ">" are counted as record headers. Every other
    non-blank line that follows the first header is treated as sequence text
    (after stripping surrounding whitespace); text before the first header is
    ignored.

    Args:
      path: file to scan.
      max_chars: soft cap on scanned sequence characters. Scanning stops at the
        first sequence line read once the running total has reached the cap.

    Returns:
      dict with keys
        exists, parse_ok, n_seqs_est, aln_len_first, chars_scanned,
        gaps_scanned, gap_frac_est, truncated, err
      - aln_len_first is the length of the first record that contains any
        sequence characters (NaN when there is none).
      - gap_frac_est is gaps_scanned / chars_scanned (NaN when nothing was scanned).
      - err is "missing" for an absent file, the exception text on a read
        failure, and "" otherwise.
    Notes:
      - If the file is huge, we stop after max_chars of sequence characters to keep runtime reasonable.
      - n_seqs_est (and aln_len_first, if the cap is hit inside the first record)
        may be undercounted when truncated=True.
    """
    if not path.exists():
        return dict(exists=False, parse_ok=False, n_seqs_est=0, aln_len_first=np.nan,
                    chars_scanned=0, gaps_scanned=0, gap_frac_est=np.nan, truncated=False, err="missing")

    n_headers = 0
    cur_len_first = 0        # residues accumulated for the first non-empty record
    first_seq_done = False   # True once a later header closes that record
    chars = 0
    gaps = 0
    truncated = False

    try:
        with path.open("r", encoding="utf-8", errors="replace") as f:
            for line in f:
                if line.startswith(">"):
                    # A new header ends the first record as soon as that record has
                    # residues; without this, the "first" length would keep growing
                    # across every following record.
                    if cur_len_first > 0:
                        first_seq_done = True
                    n_headers += 1
                    continue
                if n_headers == 0:
                    # Ignore anything that precedes the first header.
                    continue
                s = line.strip()
                if not s:
                    continue

                if not first_seq_done:
                    cur_len_first += len(s)

                # Scan stats are capped by max_chars.
                if chars < max_chars:
                    gaps += s.count("-")
                    chars += len(s)
                else:
                    # Counting the remaining headers would mean reading the rest
                    # of the file; keep it fast and stop here.
                    truncated = True
                    break
    except Exception as e:
        # Per-file boundary: record the failure in the index row instead of aborting the run.
        return dict(exists=True, parse_ok=False, n_seqs_est=n_headers, aln_len_first=np.nan,
                    chars_scanned=int(chars), gaps_scanned=int(gaps), gap_frac_est=np.nan,
                    truncated=bool(truncated), err=str(e))

    return dict(
        exists=True,
        parse_ok=True,
        n_seqs_est=n_headers,
        aln_len_first=float(cur_len_first) if cur_len_first > 0 else np.nan,
        chars_scanned=int(chars),
        gaps_scanned=int(gaps),
        gap_frac_est=float(gaps / chars) if chars > 0 else np.nan,
        truncated=bool(truncated),
        err=""
    )

def build_msa_index(split_name: str, target_ids, max_chars: int = 3_000_000, print_every: int = 500):
    """
    Build a per-target MSA index table for one split.

    For every target id the file ``MSA_DIR / f"{tid}.MSA.fasta"`` is scanned with
    ``scan_msa_file`` and one row is emitted holding the split name, the target
    id, the file path (as a string) and the scan statistics.

    Args:
      split_name: label stored in the "split" column and used in log lines.
      target_ids: iterable of target ids (any iterable; it is materialized once).
      max_chars: forwarded to scan_msa_file.
      print_every: emit a progress line every this many targets; values <= 0
        disable progress output.

    Returns:
      pandas.DataFrame with one row per target. When target_ids is empty, an
      empty frame with the usual columns is returned instead of a column-less one,
      so the QA summary and downstream column access still work.
    """
    # Column layout used only for the empty-input case; non-empty results take
    # their columns from the scanned rows.
    empty_schema = {
        "split": "object",
        "target_id": "object",
        "msa_path": "object",
        "exists": "bool",
        "parse_ok": "bool",
        "n_seqs_est": "int64",
        "aln_len_first": "float64",
        "chars_scanned": "int64",
        "gaps_scanned": "int64",
        "gap_frac_est": "float64",
        "truncated": "bool",
        "err": "object",
    }

    target_ids = list(target_ids)  # allow generators / Series; len() is needed for progress
    n_targets = len(target_ids)

    rows = []
    ok = 0     # rows with exists and parse_ok
    miss = 0   # rows whose file does not exist
    t0 = time.time()

    for i, tid in enumerate(target_ids, 1):
        msa_path = MSA_DIR / f"{tid}.MSA.fasta"
        stats = scan_msa_file(msa_path, max_chars=max_chars)

        rows.append({
            "split": split_name,
            "target_id": str(tid),
            "msa_path": str(msa_path),
            **stats
        })

        # Running counters avoid re-scanning all rows at every progress report.
        if not stats["exists"]:
            miss += 1
        elif stats["parse_ok"]:
            ok += 1

        if print_every > 0 and i % print_every == 0:
            print(f"[{split_name}] {i:,}/{n_targets:,} | ok={ok:,} | missing={miss:,} | elapsed={time.time()-t0:.1f}s")

    if rows:
        df = pd.DataFrame(rows)
    else:
        df = pd.DataFrame({col: pd.Series(dtype=dt) for col, dt in empty_schema.items()})

    # simple QA summary
    exists_mask = df["exists"].astype(bool)
    n_total = len(df)
    n_exist = int(exists_mask.sum())
    n_ok = int((exists_mask & df["parse_ok"].astype(bool)).sum())
    n_trunc = int(df["truncated"].astype(bool).sum())
    print(f"\n[{split_name}] msa_index: total={n_total:,} | exists={n_exist:,} | parse_ok={n_ok:,} | truncated={n_trunc:,}")
    return df

# Output locations for the two per-split index tables.
msa_train_path = MSA_OUT / "msa_index_train.parquet"
msa_val_path = MSA_OUT / "msa_index_val.parquet"

# Scan both splits first (train reports every 500 targets, val every 28) ...
msa_train_df = build_msa_index(split_name="train", target_ids=train_ids, max_chars=3_000_000, print_every=500)
msa_val_df = build_msa_index(split_name="val", target_ids=val_ids, max_chars=3_000_000, print_every=28)

# ... then write both tables, and only afterwards report success.
msa_train_df.to_parquet(msa_train_path, index=False)
msa_val_df.to_parquet(msa_val_path, index=False)

print(f"[OK] Saved: {msa_train_path}")
print(f"[OK] Saved: {msa_val_path}")

# ----------------------------
# 3) Export manifest/config (for Notebook-2/3 stability)
# ----------------------------
print("\n=== STAGE 5.3: Write manifest/config ===")

# cfg: snapshot of this Stage-5 run — a UTC timestamp (time.gmtime), the
# directory/file locations as plain strings, free-text policy notes, and
# a few counts taken from the objects built earlier in this cell.
cfg = {
    "stage5_time_utc": time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()),
    "paths": {
        "OUT_DIR": str(OUT_DIR),
        "TBM_DIR": str(TBM_DIR),
        "MSA_OUT": str(MSA_OUT),
        "PDB_RNA_DIR": str(PDB_RNA_DIR),
        "MSA_DIR": str(MSA_DIR),
        "RNA_META_CSV": str(RNA_META_CSV),
    },
    "notes": {
        "train_policy": "Use only rebuild_ok=True targets for training downstream.",
        "msa_policy": "MSA parsed to stable pointers + lightweight stats (full MSA usage occurs in Notebook-2/3).",
    },
    "stats": {
        # Row count of the full stage-2 train table vs. the ids actually indexed above.
        "n_train_targets_stage2_total": int(len(targets_train2)),
        "n_train_targets_rebuild_ok": int(len(train_ids)),
        "n_val_targets": int(len(val_ids)),
        "n_template_rows": int(len(df_meta)),
        "n_pdb_cif_files": int(len(cif_stems)),
    }
}

# manifest: artifact name -> path string (or glob pattern) for downstream notebooks.
# Entries guarded by `.exists()` are written as "" when the file is absent at this point.
# NOTE(review): the remaining entries (e.g. segments_*, targets_train/val, residue_index
# globs) are recorded without an existence check — confirm earlier stages always produce them.
manifest = {
    "tables": {
        "targets_train_stage2": str(TABLE_DIR / "targets_train_stage2.parquet"),
        "targets_val_stage2": str(TABLE_DIR / "targets_val_stage2.parquet"),
        "targets_test_stage2": str(TABLE_DIR / "targets_test_stage2.parquet") if (TABLE_DIR / "targets_test_stage2.parquet").exists() else "",
        "segments_train": str(TABLE_DIR / "segments_train.parquet"),
        "segments_val": str(TABLE_DIR / "segments_val.parquet"),
        "segments_test": str(TABLE_DIR / "segments_test.parquet"),
    },
    "labels": {
        "labels_npz_train_dir": str(OUT_DIR / "labels_npz" / "train"),
        "labels_npz_val_dir": str(OUT_DIR / "labels_npz" / "val"),
        "manifest_train": str(META_DIR / "labels_manifest_train.parquet") if (META_DIR / "labels_manifest_train.parquet").exists() else "",
        "manifest_val": str(META_DIR / "labels_manifest_val.parquet") if (META_DIR / "labels_manifest_val.parquet").exists() else "",
    },
    "residue_index": {
        "residue_index_dir": str(OUT_DIR / "residue_index"),
        # Glob patterns (not expanded here) for the per-split part files.
        "train_glob": str(OUT_DIR / "residue_index" / "train" / "part-*.parquet"),
        "val_glob": str(OUT_DIR / "residue_index" / "val" / "part-*.parquet"),
        "test_glob": str(OUT_DIR / "residue_index" / "test" / "part-*.parquet"),
    },
    "tbm": {
        # Path of the template index written in stage 5.1 of this cell.
        "template_index": str(template_path),
    },
    "msa": {
        # Paths of the per-split MSA index tables written in stage 5.2 of this cell.
        "msa_index_train": str(msa_train_path),
        "msa_index_val": str(msa_val_path),
    }
}

# Atomic JSON write to avoid partial JSON issues
def write_json_atomic(path: Path, obj: Dict[str, Any]):
    """
    Write ``obj`` to ``path`` as 2-space-indented UTF-8 JSON via a sibling temp file.

    The payload is first written to ``<path name>.tmp`` in the same directory and
    then moved over ``path`` with ``Path.replace``, so ``path`` itself is never
    opened for writing while the JSON text is being produced.

    Args:
      path: destination file.
      obj: JSON-serializable mapping.
    """
    payload = json.dumps(obj, indent=2)
    staging = path.with_name(path.name + ".tmp")
    staging.write_text(payload, encoding="utf-8")
    staging.replace(path)

# Persist the run configuration and the artifact manifest under META_DIR.
write_json_atomic(META_DIR / "stage5_config.json", cfg)
write_json_atomic(META_DIR / "artifacts_manifest_stage5.json", manifest)

# Confirmation lines are printed only after both writes have returned.
print(f"[OK] Wrote: {META_DIR/'stage5_config.json'}")
print(f"[OK] Wrote: {META_DIR/'artifacts_manifest_stage5.json'}")

# Closing banner plus a reminder of the output directory to publish.
print("\n[OK] STAGE 5 complete.")
print("Next (recommended): Click 'Save Version' on this notebook output to create a reusable Kaggle Dataset from:")
print(f"  {OUT_DIR}")


=== STAGE 5.1: Build TBM template_index.parquet ===
[PDB_RNA] cif files found: 9,564 (stems cached)
[rna_metadata] loaded rows=26,255 cols=31 in 1.8s
[OK] Saved: /kaggle/working/rna3d_artifacts_v1/tbm/template_index.parquet

=== STAGE 5.2: Parse MSA (Train/Val) -> msa_index_{split}.parquet ===
[train] indexing MSA for targets: 5,550
[val]   indexing MSA for targets: 28
[train] 500/5,550 | ok=500 | missing=0 | elapsed=10.9s
[train] 1,000/5,550 | ok=1,000 | missing=0 | elapsed=28.5s
[train] 1,500/5,550 | ok=1,500 | missing=0 | elapsed=43.2s
[train] 2,000/5,550 | ok=2,000 | missing=0 | elapsed=65.8s
[train] 2,500/5,550 | ok=2,500 | missing=0 | elapsed=82.4s
[train] 3,000/5,550 | ok=3,000 | missing=0 | elapsed=105.0s
[train] 3,500/5,550 | ok=3,500 | missing=0 | elapsed=122.1s
[train] 4,000/5,550 | ok=4,000 | missing=0 | elapsed=144.9s
[train] 4,500/5,550 | ok=4,500 | missing=0 | elapsed=166.6s
[train] 5,000/5,550 | ok=5,000 | missing=0 | elapsed=182.1s
[train] 5,500/5,550 | ok=5,500 | miss