Atomworks datasets is like 3-stage pipeline:
1. Dataset: index -> raw example
2. Loader: raw example -> statandardized dict, typically include atom_array, example_id, extra_info
3. Transform: standardized dict -> model-ready dict (cropping/atomize/encoding/features; keep annotation and featurization seperate)

Protocol:
1. Explore the raw data 
2. Define the rule for example_id (unique across dataset, tracable back to origin)
3. Write loader (normalize metadata + parse structure; don't add model feature here)
4. Write transform (split into annotation transforms then featurization transforms)
5. Handle failtures (save_failed_examples_to_dir / fallback)
6. Verify: random pick 10 sample for visualization + sanity stats like atom counts, chain IDs, missing coords
7. Wire training IO (DataLoader + Sample + collate strategy for variable-length structures)
8. Record idx to id to path to metadata (reversable)

The metadata schema in Atomworks:
A0. minimum you need to have:
1. example_id: str -> need to be unique across all dataset
2. path: str -> absolute str
3. assembly_id: str -> choicable, always 1 as default
4. dataset_name: str -> such as 'test_cifs'
5. file_size: int -> check missing file
6. file_ext: str (.cif/.cif.gz/.pdb)

# Unified metadata

In [40]:
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Optional, Any, Dict, Iterable

import pandas as pd

@dataclass
class BuildMetaConfig:
    dataset_name: str
    base_dir: Path
    assembly_id: str = "1"

    # Example ID function: (path, rel_path, dataset_name) -> str
    example_id_fn: Optional[Callable[[Path, Path, str], str]] = None

    # Quick checks / enrichments
    add_file_stats: bool = True
    add_has_atom_site: bool = True

    # Parse health-check (slower, but catches issues early)
    do_parse_check: bool = True
    parse_kwargs: Optional[Dict[str, Any]] = None  # passed to atomworks.io.parse
    build_assembly: tuple[str, ...] = ("1",)

    # Error string length to store in parquet
    error_maxlen: int = 300


def _default_example_id_fn(path: Path, rel: Path, dataset_name: str) -> str:
    # Make a stable, mostly-readable ID
    s = str(rel).replace("/", "_").replace("\\", "_")
    s = s.replace(".", "_")
    return f"{dataset_name}__{s}"


def _quick_has_atom_site(path: Path, max_lines: int = 4000) -> bool:
    # Heuristic: scan early lines for _atom_site.
    # Works for text cif; for .gz you should decompress or skip this check.
    try:
        with path.open("rt", errors="ignore") as f:
            for _ in range(max_lines):
                line = f.readline()
                if not line:
                    break
                if "_atom_site." in line:
                    return True
    except Exception:
        return False
    return False


def build_metadata_from_filedataset(
    ds_raw: Iterable[Any],
    out_parquet: str | Path,
    cfg: BuildMetaConfig,
) -> pd.DataFrame:
    """
    Build AtomWorks-style metadata from a FileDataset-like iterable where ds_raw[i] yields PathLike.
    """
    out_parquet = Path(out_parquet)

    base_dir = Path(cfg.base_dir).resolve() 
    exid_fn = cfg.example_id_fn or _default_example_id_fn

    # Optional parse import only if needed
    parse = None
    if cfg.do_parse_check:
        from atomworks.io import parse as _parse
        parse = _parse

    # Default parse kwargs: avoid entity/chain_info pitfalls unless you really need them
    parse_kwargs = {
        "add_missing_atoms": False,
        "fix_formal_charges": False,
    }
    if cfg.parse_kwargs:
        parse_kwargs.update(cfg.parse_kwargs)

    rows = []
    for i, raw in enumerate(ds_raw):
        p = Path(raw).resolve()
        rel = p.relative_to(base_dir) if str(p).startswith(str(base_dir)) else Path(p.name)

        row = {
            "dataset_name": cfg.dataset_name,
            "example_id": exid_fn(p, rel, cfg.dataset_name),
            "path": str(p),
            "assembly_id": str(cfg.assembly_id),
            "file_ext": "".join(p.suffixes),
        }

        if cfg.add_file_stats:
            try:
                st = p.stat()
                row["file_size"] = int(st.st_size)
                row["mtime"] = float(st.st_mtime)
            except Exception:
                row["file_size"] = None
                row["mtime"] = None

        if cfg.add_has_atom_site:
            # If gz, this heuristic won't work without decompressing. Keep False/None.
            if row["file_ext"].endswith(".gz"):
                row["has_atom_site"] = None
            else:
                row["has_atom_site"] = _quick_has_atom_site(p)

        # Parse check
        row["parse_ok"] = None
        row["n_atoms"] = None
        row["error"] = None
        if cfg.do_parse_check and parse is not None:
            try:
                out = parse(
                    filename=str(p),
                    build_assembly=cfg.build_assembly,
                    **parse_kwargs,
                )
                aa = out["assemblies"][cfg.build_assembly[0]][0]
                row["parse_ok"] = True
                row["n_atoms"] = int(len(aa))
            except Exception as e:
                row["parse_ok"] = False
                msg = repr(e)
                row["error"] = msg[: cfg.error_maxlen]

        rows.append(row)

    df = pd.DataFrame(rows)

    # Attach dataset-level metadata in attrs (AtomWorks-style)
    df.attrs["dataset_name"] = cfg.dataset_name
    df.attrs["base_dir"] = str(base_dir)
    df.attrs["schema_version"] = "atomworks_like_v1"
    df.attrs["notes"] = (
        "Built from FileDataset; includes parse_ok/n_atoms/error if do_parse_check=True."
    )

    df.to_parquet(out_parquet, index=False)
    return df


In [44]:
from atomworks.ml.datasets import FileDataset
from pathlib import Path

ds_raw = FileDataset.from_directory(
    directory="/home/eva/20251207_python_playground/20260109_playground_biotite_atomworks/example_cifs",
    name="test_cifs",
    max_depth=5,
)

cfg = BuildMetaConfig(
    dataset_name="test_cifs",
    base_dir=Path("/home/eva/20251207_python_playground/20260109_playground_biotite_atomworks/example_cifs"),
    do_parse_check=True,
    build_assembly=("1",),
    # you could add parse_kwargs to cover default values
)

df = build_metadata_from_filedataset(
    ds_raw=ds_raw,
    out_parquet="test_cifs_metadata.parquet",
    cfg=cfg,
)

print(df["parse_ok"].value_counts(dropna=False))

# print the df.attrs for dataset-level metadata
print("dataset attributes metadata:")
print(df.attrs)

print("dataset metadata:")
df.head(3)



parse_ok
True    3
Name: count, dtype: int64
dataset attributes metadata:
{'dataset_name': 'test_cifs', 'base_dir': '/home/eva/20251207_python_playground/20260109_playground_biotite_atomworks/example_cifs', 'schema_version': 'atomworks_like_v1', 'notes': 'Built from FileDataset; includes parse_ok/n_atoms/error if do_parse_check=True.'}
dataset metadata:


Unnamed: 0,dataset_name,example_id,path,assembly_id,file_ext,file_size,mtime,has_atom_site,parse_ok,n_atoms,error
0,test_cifs,test_cifs__7ee8_cif,/home/eva/20251207_python_playground/20260109_...,1,.cif,628536,1769950000.0,True,True,2628,
1,test_cifs,test_cifs__7h66_cif,/home/eva/20251207_python_playground/20260109_...,1,.cif,303738,1769950000.0,True,True,1764,
2,test_cifs,test_cifs__7h68_cif,/home/eva/20251207_python_playground/20260109_...,1,.cif,853667,1769950000.0,True,True,3575,


# To write metadata manually into FileDataset

In [10]:
from pathlib import Path
import pandas as pd
from atomworks.io import parse
from atomworks.ml.datasets import FileDataset

# Generate a dataset
ds_raw = FileDataset.from_directory(
    directory="/home/eva/20251207_python_playground/20260109_playground_biotite_atomworks/example_cifs",
    name="test_cifs",
    max_depth=5,
)

# Parse each file and collect metadata
rows = []
for i in range(len(ds_raw)):
    p = Path(ds_raw[i]).resolve()
    row = {
        "example_id": p.stem,
        "path": str(p),
        "assembly_id": "1",
        "parse_ok": False,
        "n_atoms": None,
        "error": None,
    }

    try:
        out = parse(
            filename=str(p),
            build_assembly=("1",),
            add_missing_atoms=False,
            fix_formal_charges=False,
            add_id_and_entity_annotations=False,  # Get around with bad annotations
        )
        aa = out["assemblies"]["1"][0]
        row["parse_ok"] = True
        row["n_atoms"] = len(aa)
    except Exception as e:
        row["error"] = repr(e)[:300]  

    rows.append(row)

# Save metadata to Parquet
df = pd.DataFrame(rows)
df.head(3)


Unnamed: 0,example_id,path,assembly_id,parse_ok,n_atoms,error
0,7ee8,/home/eva/20251207_python_playground/20260109_...,1,True,2628,
1,7h66,/home/eva/20251207_python_playground/20260109_...,1,True,1764,
2,7h68,/home/eva/20251207_python_playground/20260109_...,1,True,3575,


In [11]:
df.attrs["dataset_name"] = "test_cifs"
df.attrs["source"] = "local_example_cifs"
df.attrs["created_by"] = "eva"
df.attrs["notes"] = "Filtered parse_ok only"

df.to_parquet("test_cifs_metadata.parquet", index=False)

print(df["parse_ok"].value_counts())
print("Bad examples (first 5):")
print(df.loc[~df["parse_ok"], ["example_id","path","error"]].head())

parse_ok
True    3
Name: count, dtype: int64
Bad examples (first 5):
Empty DataFrame
Columns: [example_id, path, error]
Index: []


# Read parquet

In [12]:
# Load and filter parsed metadata - this will not include the attrs!!
df = pd.read_parquet("test_cifs_metadata.parquet")

# print columns and first 2 rows
print("Columns:", df.columns.tolist())
print(df.head(2))
df_ok = df[df["parse_ok"]].reset_index(drop=True)
df_ok.to_parquet("test_cifs_metadata_ok.parquet", index=False)

Columns: ['example_id', 'path', 'assembly_id', 'parse_ok', 'n_atoms', 'error']
  example_id                                               path assembly_id  \
0       7ee8  /home/eva/20251207_python_playground/20260109_...           1   
1       7h66  /home/eva/20251207_python_playground/20260109_...           1   

   parse_ok  n_atoms error  
0      True     2628  None  
1      True     1764  None  


# You could restart from here to rebuild dataset from parquet only (the path to cif file not changed)

In [13]:
from atomworks.ml.datasets import PandasDataset
from atomworks.ml.datasets.loaders import create_base_loader

out_path = "test_cifs_metadata_ok.parquet"

loader = create_base_loader(
    example_id_colname="example_id",
    path_colname="path",
    assembly_id_colname="assembly_id",
)

ds = PandasDataset(
    data=out_path,
    name="test_cifs_from_parquet_ok",
    loader=loader,
)

ex = ds[0]

# Inspect the first example
print(ex.keys())
print("example_id:", ex["example_id"])
print("atom_array atoms:", len(ex["atom_array"]))

print("dataset metadata:")
print(ex['metadata'])

print("extra info:")
print(ex['extra_info'])

print("atom_array:")
print(ex['atom_array'].get_annotation_categories())

aa = ex["atom_array"]

for name in aa.get_annotation_categories():
    values = getattr(aa, name)
    print(
        f"{name:25s}",
        f"type={values.dtype if hasattr(values, 'dtype') else type(values)}",
        f"example={values[:5]}"
    )


dict_keys(['example_id', 'path', 'assembly_id', 'extra_info', 'atom_array', 'atom_array_stack', 'chain_info', 'ligand_info', 'metadata'])
example_id: 7ee8
atom_array atoms: 5261
dataset metadata:
{'id': '7ee8', 'method': 'X-RAY_DIFFRACTION', 'deposition_date': '2021-03-17', 'release_date': '2025-03-26', 'resolution': 1.22, 'extra_metadata': None, 'crystallization_details': {'pH': [7.5, 7.5]}}
extra info:
{'PANDAS_ATTRS': '{"dataset_name": "test_cifs", "source": "local_example_cifs", "created_by": "eva", "notes": "Filtered parse_ok only"}', 'parse_ok': True, 'n_atoms': 2628, 'error': None}
atom_array:
['chain_id', 'res_id', 'ins_code', 'res_name', 'hetero', 'atom_name', 'element', 'stereo', 'alt_atom_id', 'is_aromatic', 'is_polymer', 'occupancy', 'chain_type', 'uses_alt_atom_id', 'charge', 'is_backbone_atom', 'b_factor', 'pn_unit_id', 'molecule_id', 'chain_entity', 'pn_unit_entity', 'molecule_entity', 'atomic_number', 'transformation_id', 'chain_iid', 'pn_unit_iid', 'molecule_iid']
chai

Aim:

Input: cif folder path

output 1：chains_metadata.parquet（each row is one chain）

output 2：interfaces_metadata.parquet（each row is a chain pair which satisfy the contact condition）

The parquet can be directly load with PandasDataset + create_loader_with_query_pn_units

In [14]:
from __future__ import annotations

from pathlib import Path
import itertools
import numpy as np
import pandas as pd
import biotite.structure as struc

from atomworks.io import parse

# ---------- Parameter ----------
CIF_DIR = Path("/home/eva/20251207_python_playground/20260109_playground_biotite_atomworks/example_cifs").resolve()
ASSEMBLY_ID = "1"

# Parameters for interface definition
CONTACT_CUTOFF_A = 5.0      # atom-atom distance cutoff (Å)
MIN_CONTACTS = 20           # minimum number of contacts to define an interface
CELL_SIZE = 6.0             # cell list cell size (Å)


def safe_parse(path: Path):
    """
    - add_missing_atoms=False / fix_formal_charges=False avoid modifying the structure
    - build_assembly: only build the specified assembly to save time
    """
    return parse(
        filename=str(path),
        build_assembly=(ASSEMBLY_ID,),
        add_missing_atoms=False,
        fix_formal_charges=False,
    )


def get_polymer_chain_units(atom_array: struc.AtomArray) -> list[dict]:
    """
    Extract polymer chain units from an AtomArray.
      pn_unit_iid, chain_id, atom_indices, n_atoms
    """
    # only care about polymer atoms
    poly = atom_array[atom_array.is_polymer]

    # pn_unit_iid is what we consider as "polymer chain"
    units = []
    for iid in np.unique(poly.pn_unit_iid):
        mask = (poly.pn_unit_iid == iid)
        idx = np.where(mask)[0] # local indices in polymer atom array
        chain_ids = np.unique(poly.chain_id[mask])
        chain_id = chain_ids[0] if len(chain_ids) else None

        units.append({
            "pn_unit_iid": str(iid),
            "chain_id": str(chain_id) if chain_id is not None else None,
            "poly_local_indices": idx,
            "n_atoms": int(mask.sum()),
        })
    return units


def count_chain_contacts(atom_array: struc.AtomArray, iid1: str, iid2: str,
                         cutoff: float = CONTACT_CUTOFF_A) -> tuple[int, float]:
    """
    Calculate number of contacts and minimum distance between two polymer chains.
    Contacts are defined as atom pairs within the specified cutoff distance.
    """
    aa = atom_array[atom_array.is_polymer]
    m1 = (aa.pn_unit_iid == iid1)
    m2 = (aa.pn_unit_iid == iid2)
    if not (np.any(m1) and np.any(m2)):
        return 0, float("inf")

    coord1 = aa.coord[m1]
    coord2 = aa.coord[m2]

    # remove NaN coords
    v1 = ~np.isnan(coord1).any(axis=1)
    v2 = ~np.isnan(coord2).any(axis=1)
    coord1 = coord1[v1]
    coord2 = coord2[v2]
    if len(coord1) == 0 or len(coord2) == 0:
        return 0, float("inf")

    cl = struc.CellList(coord2, cell_size=CELL_SIZE)
    near = cl.get_atoms(coord1, cutoff, as_mask=True)
    has_neighbor = np.any(near, axis=1)
    n_contacts = int(np.sum(has_neighbor))

    # calculate min distance
    min_dist = float("inf")
    if n_contacts > 0:
        # only check those with neighbors
        for i, rowmask in enumerate(near):
            if not np.any(rowmask):
                continue
            d = np.linalg.norm(coord2[rowmask] - coord1[i], axis=1)
            md = float(d.min())
            if md < min_dist:
                min_dist = md

    return n_contacts, min_dist


def make_file_id(path: Path, base_dir: Path) -> str:
    # Make a unique file ID from path relative to base_dir
    rel = path.relative_to(base_dir)
    return str(rel).replace("/", "_").replace(".", "_")


# ---------- Main ----------
chain_rows = []
iface_rows = []
bad_rows = []

cif_paths = sorted([p for p in CIF_DIR.rglob("*") if p.is_file() and "".join(p.suffixes).endswith(".cif")])
print("Found CIFs:", len(cif_paths))

for p in cif_paths:
    file_id = make_file_id(p, CIF_DIR)

    try:
        out = safe_parse(p)
        aa = out["assemblies"][ASSEMBLY_ID][0]
    except Exception as e:
        bad_rows.append({
            "file_id": file_id,
            "path": str(p),
            "error": repr(e)[:300],
        })
        continue

    # ---- chains-level ----
    units = get_polymer_chain_units(aa)

    for u in units:
        chain_rows.append({
            "example_id": f"{file_id}__{u['pn_unit_iid']}",
            "path": str(p),
            "assembly_id": ASSEMBLY_ID,
            "pn_unit_iid": u["pn_unit_iid"],
            "chain_id": u["chain_id"],
            "file_id": file_id,
            "n_atoms_chain": u["n_atoms"],
        })

    # ---- interfaces-level ----
    # 只在“聚合物链之间”找界面（你也可以扩展到 ligand）
    iids = [u["pn_unit_iid"] for u in units]
    for iid1, iid2 in itertools.combinations(sorted(iids), 2):
        n_contacts, min_dist = count_chain_contacts(aa, iid1, iid2, cutoff=CONTACT_CUTOFF_A)
        if n_contacts >= MIN_CONTACTS:
            iface_rows.append({
                "example_id": f"{file_id}__{iid1}__{iid2}",
                "path": str(p),
                "assembly_id": ASSEMBLY_ID,
                "pn_unit_1_iid": iid1,
                "pn_unit_2_iid": iid2,
                "file_id": file_id,
                "n_contacts": int(n_contacts),
                "min_dist": float(min_dist),
                "interface_type": "polymer-polymer",
            })

# ---------- Write parquet ----------
df_chains = pd.DataFrame(chain_rows)
df_ifaces = pd.DataFrame(iface_rows)
df_bad = pd.DataFrame(bad_rows)

df_chains.to_parquet("chains_metadata.parquet", index=False)
df_ifaces.to_parquet("interfaces_metadata.parquet", index=False)
df_bad.to_csv("bad_files.csv", index=False)

print("chains rows:", len(df_chains))
print("interfaces rows:", len(df_ifaces))
print("bad files:", len(df_bad))
print(df_bad.head())


Found CIFs: 3
chains rows: 5
interfaces rows: 3
bad files: 0
Empty DataFrame
Columns: []
Index: []


In [25]:
from atomworks.ml.datasets import PandasDataset
from atomworks.ml.datasets.loaders import create_loader_with_query_pn_units

loader_chain = create_loader_with_query_pn_units(
    pn_unit_iid_colnames="pn_unit_iid",
    parser_args={"add_id_and_entity_annotations": False},
)

ds_chain = PandasDataset(
    data="chains_metadata.parquet",
    name="chains_ds",
    loader=loader_chain,
)

print(f"Dataset size: {len(ds_chain)} examples")

# print all examples
for i in range(len(ds_chain)):
    ex = ds_chain[i]
    print(ex["example_id"], ex["query_pn_unit_iids"])
    print(ex.keys())
    print(ex['extra_info'])
    print(ex['atom_array'][0:2])



Dataset size: 5 examples
7ee8_cif__A_1 ['A_1']
dict_keys(['example_id', 'path', 'assembly_id', 'extra_info', 'atom_array', 'atom_array_stack', 'chain_info', 'ligand_info', 'metadata', 'query_pn_unit_iids'])
{'chain_id': 'A', 'file_id': '7ee8_cif', 'n_atoms_chain': 867}
    A       1  PRO N      N       -30.570  -19.082   12.020
    A       1  PRO CA     C       -29.427  -18.387   11.424
7ee8_cif__B_1 ['B_1']
dict_keys(['example_id', 'path', 'assembly_id', 'extra_info', 'atom_array', 'atom_array_stack', 'chain_info', 'ligand_info', 'metadata', 'query_pn_unit_iids'])
{'chain_id': 'B', 'file_id': '7ee8_cif', 'n_atoms_chain': 863}
    A       1  PRO N      N       -30.570  -19.082   12.020
    A       1  PRO CA     C       -29.427  -18.387   11.424
7ee8_cif__C_1 ['C_1']
dict_keys(['example_id', 'path', 'assembly_id', 'extra_info', 'atom_array', 'atom_array_stack', 'chain_info', 'ligand_info', 'metadata', 'query_pn_unit_iids'])
{'chain_id': 'C', 'file_id': '7ee8_cif', 'n_atoms_chain': 865}


In [26]:
loader_iface = create_loader_with_query_pn_units(
    pn_unit_iid_colnames=["pn_unit_1_iid", "pn_unit_2_iid"],
    parser_args={"add_id_and_entity_annotations": False},
)

ds_iface = PandasDataset(
    data="interfaces_metadata.parquet",
    name="ifaces_ds",
    loader=loader_iface,
)

for i in range(len(ds_iface)):
    ex = ds_iface[i]
    print(ex["example_id"], ex["query_pn_unit_iids"])
    print(ex.keys())
    print(ex['extra_info'])
    print(ex['atom_array'][0:2])


7ee8_cif__A_1__B_1 ['A_1', 'B_1']
dict_keys(['example_id', 'path', 'assembly_id', 'extra_info', 'atom_array', 'atom_array_stack', 'chain_info', 'ligand_info', 'metadata', 'query_pn_unit_iids'])
{'file_id': '7ee8_cif', 'n_contacts': 148, 'min_dist': 2.676819324493408, 'interface_type': 'polymer-polymer'}
    A       1  PRO N      N       -30.570  -19.082   12.020
    A       1  PRO CA     C       -29.427  -18.387   11.424
7ee8_cif__A_1__C_1 ['A_1', 'C_1']
dict_keys(['example_id', 'path', 'assembly_id', 'extra_info', 'atom_array', 'atom_array_stack', 'chain_info', 'ligand_info', 'metadata', 'query_pn_unit_iids'])
{'file_id': '7ee8_cif', 'n_contacts': 148, 'min_dist': 2.662930727005005, 'interface_type': 'polymer-polymer'}
    A       1  PRO N      N       -30.570  -19.082   12.020
    A       1  PRO CA     C       -29.427  -18.387   11.424
7ee8_cif__B_1__C_1 ['B_1', 'C_1']
dict_keys(['example_id', 'path', 'assembly_id', 'extra_info', 'atom_array', 'atom_array_stack', 'chain_info', 'ligan