Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions atompack-py/python/atompack/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ class Molecule:
"""
...

def get_property(self, key: str) -> float | int | str | npt.NDArray:
def get_property(self, key: str) -> float | int | str | npt.NDArray | None:
"""
Get a custom property by key.

Expand All @@ -326,7 +326,7 @@ class Molecule:

Returns
-------
float, int, str, or ndarray
float, int, str, ndarray, or None
Property value

Raises
Expand All @@ -336,19 +336,19 @@ class Molecule:
"""
...

def set_property(self, key: str, value: float | int | str | npt.NDArray) -> None:
def set_property(self, key: str, value: float | int | str | npt.NDArray | None) -> None:
"""
Set a custom property.

Supported types: float, int, str, 1D float32/float64/int32/int64 arrays,
Supported types: None, float, int, str, 1D float32/float64/int32/int64 arrays,
and 2D float32/float64 arrays with shape (n, 3). Input dtype is preserved.
The key 'stress' is reserved; use the dedicated ``stress`` property instead.

Parameters
----------
key : str
Property key
value : float, int, str, or ndarray
value : float, int, str, ndarray, or None
Property value

Raises
Expand Down
11 changes: 7 additions & 4 deletions atompack-py/python/atompack/ase_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
_ASE_RESERVED_ARRAYS = {"numbers", "positions"}
_ASE_TYPES = None
_CALC_MODES = {"singlepoint", "nocopy", "none"}
_UNSUPPORTED_PROPERTY = object()


def _voigt6_to_mat3x3(stress):
Expand Down Expand Up @@ -53,6 +54,8 @@ def _get_stress(atoms):


def _coerce_property(value, n_atoms):
if value is None:
return None
if isinstance(value, (str, bool, int, float, np.integer, np.floating)):
if isinstance(value, str):
return value
Expand All @@ -76,7 +79,7 @@ def _coerce_property(value, n_atoms):
if arr.dtype == np.float32:
return arr.astype(np.float32, copy=False)
return arr.astype(np.float64, copy=False)
return None
return _UNSUPPORTED_PROPERTY


def _merge_properties(properties, builtins, values, n_atoms):
Expand All @@ -93,7 +96,7 @@ def _merge_properties(properties, builtins, values, n_atoms):
builtins["stress"] = arr.astype(np.float64, copy=False)
continue
coerced = _coerce_property(value, n_atoms)
if coerced is not None:
if coerced is not _UNSUPPORTED_PROPERTY:
properties[key] = coerced


Expand Down Expand Up @@ -187,7 +190,7 @@ def _extract_ase_record(
if key in _ASE_RESERVED_ARRAYS or key in _BUILTIN_FIELDS:
continue
coerced = _coerce_property(value, n_atoms)
if coerced is not None:
if coerced is not _UNSUPPORTED_PROPERTY:
properties[key] = coerced

calc = getattr(atoms, "calc", None)
Expand All @@ -196,7 +199,7 @@ def _extract_ase_record(
for key, value in results.items():
if key not in _BUILTIN_FIELDS:
coerced = _coerce_property(value, n_atoms)
if coerced is not None:
if coerced is not _UNSUPPORTED_PROPERTY:
properties[key] = coerced

if copy_info and getattr(atoms, "info", None):
Expand Down
6 changes: 6 additions & 0 deletions atompack-py/src/database_flat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,9 @@ pub(super) fn get_molecules_flat_soa_impl<'py>(
}

if schema_entry.slot_bytes == 0 {
if schema_entry.type_tag == TYPE_NONE {
continue;
}
if let Some(ref mtx) = string_mutexes[section_idx] {
let val = Some(
std::str::from_utf8(sec.payload)
Expand Down Expand Up @@ -449,6 +452,9 @@ pub(super) fn get_molecules_flat_soa_impl<'py>(
}

if schema_entry.slot_bytes == 0 {
if schema_entry.type_tag == TYPE_NONE {
continue;
}
if let Some(ref mtx) = string_mutexes[section_idx] {
let val = Some(
std::str::from_utf8(sec.payload)
Expand Down
1 change: 1 addition & 0 deletions atompack-py/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ const TYPE_BOOL3: u8 = 9;
const TYPE_MAT3X3_F64: u8 = 10;
const TYPE_FLOAT32: u8 = 11;
const TYPE_MAT3X3_F32: u8 = 12;
const TYPE_NONE: u8 = 13;

const RECORD_FORMAT_SOA_V2: u32 = 2;
const RECORD_FORMAT_SOA_V3: u32 = 3;
Expand Down
8 changes: 8 additions & 0 deletions atompack-py/src/molecule_helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ pub(super) fn property_value_to_pyobject(
value: &PropertyValue,
) -> PyResult<Py<PyAny>> {
Ok(match value {
PropertyValue::None => py.None(),
PropertyValue::Float(v) => into_py_any(py, *v)?,
PropertyValue::Int(v) => into_py_any(py, *v)?,
PropertyValue::String(v) => into_py_any(py, v)?,
Expand Down Expand Up @@ -411,6 +412,12 @@ pub(super) fn property_section_to_pyobject<'py>(
Ok(payload.len() / stride)
};
Ok(match section.type_tag {
TYPE_NONE => {
if !payload.is_empty() {
return Err(PyValueError::new_err("Null property payload must be empty"));
}
py.None()
}
TYPE_FLOAT => into_py_any(py, read_f64_scalar(payload)?)?,
TYPE_INT => into_py_any(py, read_i64_scalar(payload)?)?,
TYPE_STRING => into_py_any(
Expand Down Expand Up @@ -448,6 +455,7 @@ pub(super) fn property_section_to_pyobject<'py>(

fn property_value_is_atom_array(value: &PropertyValue, n_atoms: usize) -> bool {
match value {
PropertyValue::None => false,
PropertyValue::FloatArray(values) => values.len() == n_atoms,
PropertyValue::Vec3Array(values) => values.len() == n_atoms,
PropertyValue::IntArray(values) => values.len() == n_atoms,
Expand Down
5 changes: 4 additions & 1 deletion atompack-py/src/py_dtypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,9 @@ pub(crate) fn parse_mat3_field(value: &Bound<'_, PyAny>, label: &str) -> PyResul
}

pub(crate) fn parse_property_value(value: &Bound<'_, PyAny>) -> PyResult<PropertyValue> {
if value.is_none() {
return Ok(PropertyValue::None);
}
if let Ok(v) = value.extract::<i64>() {
return Ok(PropertyValue::Int(v));
}
Expand Down Expand Up @@ -305,6 +308,6 @@ pub(crate) fn parse_property_value(value: &Bound<'_, PyAny>) -> PyResult<Propert
});
}
Err(PyValueError::new_err(
"Unsupported property type. Supported: float, int, str, ndarray",
"Unsupported property type. Supported: None, float, int, str, ndarray",
))
}
23 changes: 20 additions & 3 deletions atompack-py/src/soa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ pub(crate) fn section_schema_from_ref(
) -> atompack::Result<SectionSchema> {
let per_atom = is_per_atom(section.kind, section.key, section.type_tag);
let elem_bytes = match section.type_tag {
TYPE_NONE => 0,
TYPE_STRING => 0,
tag if per_atom => {
let elem_bytes = type_tag_elem_bytes(tag);
Expand All @@ -168,7 +169,7 @@ pub(crate) fn section_schema_from_ref(
TYPE_MAT3X3_F64 => 72,
_ => section.payload.len(),
};
let slot_bytes = if section.type_tag == TYPE_STRING {
let slot_bytes = if matches!(section.type_tag, TYPE_STRING | TYPE_NONE) {
0
} else if per_atom {
elem_bytes
Expand Down Expand Up @@ -196,6 +197,15 @@ pub(crate) fn validate_section_payload(
n_atoms: usize,
) -> atompack::Result<()> {
match section.type_tag {
TYPE_NONE => {
if !section.payload.is_empty() {
return Err(invalid_data(format!(
"Section '{}' has invalid payload length {} (expected 0)",
section.key,
section.payload.len()
)));
}
}
TYPE_STRING => {
std::str::from_utf8(section.payload)
.map_err(|_| invalid_data(format!("Invalid UTF-8 in section '{}'", section.key)))?;
Expand Down Expand Up @@ -254,6 +264,7 @@ pub(crate) fn validate_section_payload(
/// Element size in bytes for a given type tag. Returns 0 for variable-length types.
pub(crate) fn type_tag_elem_bytes(tag: u8) -> usize {
match tag {
TYPE_NONE => 0,
TYPE_FLOAT => 8,
TYPE_INT => 8,
TYPE_STRING => 0,
Expand Down Expand Up @@ -289,7 +300,7 @@ fn database_schema_section(
n_atoms: usize,
) -> PyResult<DatabaseSchemaSection> {
let per_atom = is_per_atom(kind, key, type_tag);
let elem_bytes = if type_tag == TYPE_STRING {
let elem_bytes = if matches!(type_tag, TYPE_STRING | TYPE_NONE) {
0
} else {
let elem_bytes = type_tag_elem_bytes(type_tag);
Expand All @@ -301,7 +312,7 @@ fn database_schema_section(
}
elem_bytes
};
let slot_bytes = if type_tag == TYPE_STRING {
let slot_bytes = if matches!(type_tag, TYPE_STRING | TYPE_NONE) {
0
} else if per_atom {
let expected = n_atoms.checked_mul(elem_bytes).ok_or_else(|| {
Expand Down Expand Up @@ -1096,6 +1107,12 @@ fn decode_mat3x3_f32(payload: &[u8]) -> PyResult<[[f32; 3]; 3]> {

fn decode_property_value(type_tag: u8, payload: &[u8]) -> PyResult<PropertyValue> {
Ok(match type_tag {
TYPE_NONE => {
if !payload.is_empty() {
return Err(PyValueError::new_err("Null property payload must be empty"));
}
PropertyValue::None
}
TYPE_FLOAT => PropertyValue::Float(read_f64_scalar(payload)?),
TYPE_INT => PropertyValue::Int(read_i64_scalar(payload)?),
TYPE_STRING => PropertyValue::String(
Expand Down
4 changes: 4 additions & 0 deletions atompack-py/tests/test_atom_molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def test_molecule_custom_properties() -> None:
mol.set_property("int_vec32", np.array([3, 4], dtype=np.int32))
mol.set_property("vec3", np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32))
mol.set_property("vec3_f64", np.array([[1.1, 1.2, 1.3], [2.1, 2.2, 2.3]], dtype=np.float64))
mol.set_property("optional_label", None)
mol.stress = np.eye(3, dtype=np.float64) * 3.0

assert mol.get_property("temperature") == pytest.approx(300.0)
Expand Down Expand Up @@ -208,10 +209,12 @@ def test_molecule_custom_properties() -> None:
np.testing.assert_allclose(
vec3_f64, np.array([[1.1, 1.2, 1.3], [2.1, 2.2, 2.3]], dtype=np.float64)
)
assert mol.get_property("optional_label") is None

np.testing.assert_allclose(mol.stress, np.eye(3, dtype=np.float64) * 3.0)

assert mol.has_property("method") is True
assert mol.has_property("optional_label") is True
assert mol.has_property("stress") is False
assert set(mol.property_keys()) >= {
"temperature",
Expand All @@ -223,6 +226,7 @@ def test_molecule_custom_properties() -> None:
"int_vec32",
"vec3",
"vec3_f64",
"optional_label",
}

with pytest.raises(KeyError, match=r"not found"):
Expand Down
2 changes: 2 additions & 0 deletions atompack-py/tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def _make_molecule(energy: float) -> atompack.Molecule:
mol.velocities = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype=np.float32)
mol.cell = np.eye(3, dtype=np.float64) * 2.0
mol.set_property("tag", "train")
mol.set_property("optional", None)
mol.set_property("ids", np.array([1, 2], dtype=np.int64))
return mol

Expand Down Expand Up @@ -48,6 +49,7 @@ def test_database_roundtrip(tmp_path: Path, compression: str) -> None:
np.testing.assert_allclose(mol1_r.velocities, mol1.velocities)
np.testing.assert_allclose(mol1_r.cell, mol1.cell)
assert mol1_r.get_property("tag") == "train"
assert mol1_r.get_property("optional") is None
np.testing.assert_array_equal(mol1_r.get_property("ids"), np.array([1, 2], dtype=np.int64))

batch = db2.get_molecules([0, 1])
Expand Down
18 changes: 18 additions & 0 deletions atompack-py/tests/test_from_ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def test_from_ase_extracts_core_fields() -> None:
"int_vec32": np.array([3, 4], dtype=np.int32),
"vec3": np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32),
"vec3_f64": np.array([[1.0, 1.1, 1.2], [2.0, 2.1, 2.2]], dtype=np.float64),
"nullable": None,
"stress": np.array(
[[1.1, 1.2, 1.3], [2.1, 2.2, 2.3], [3.1, 3.2, 3.3]], dtype=np.float64
),
Expand Down Expand Up @@ -126,6 +127,7 @@ def test_from_ase_extracts_core_fields() -> None:
np.testing.assert_allclose(
mol.get_property("vec3_f64"), np.array([[1.0, 1.1, 1.2], [2.0, 2.1, 2.2]], dtype=np.float64)
)
assert mol.get_property("nullable") is None
np.testing.assert_allclose(
mol.stress,
np.array([[1.1, 1.2, 1.3], [2.1, 2.2, 2.3], [3.1, 3.2, 3.3]], dtype=np.float64),
Expand Down Expand Up @@ -360,6 +362,22 @@ def test_to_ase_owned_maps_builtins_and_properties() -> None:
np.testing.assert_array_equal(atoms.arrays["tags"], np.array([3, 4], dtype=np.int32))


def test_to_ase_roundtrip_preserves_none_custom_property() -> None:
mol = atompack.Molecule.from_arrays(
np.array([[0.0, 0.0, 0.0]], dtype=np.float32),
np.array([1], dtype=np.uint8),
)
mol.set_property("nullable", None)

atoms = mol.to_ase()
assert "nullable" in atoms.info
assert atoms.info["nullable"] is None

roundtrip = atompack.from_ase(atoms)
assert roundtrip.has_property("nullable") is True
assert roundtrip.get_property("nullable") is None


def test_to_ase_calc_modes() -> None:
mol = atompack.Molecule.from_arrays(
np.array([[0.0, 0.0, 0.0], [1.0, 0.5, 0.0]], dtype=np.float32),
Expand Down
11 changes: 11 additions & 0 deletions atompack/src/storage/dtypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ pub(super) fn float_scalar_payload_len(value: &FloatScalarData) -> usize {

pub(super) fn property_value_type_tag(value: &PropertyValue) -> u8 {
match value {
PropertyValue::None => TYPE_NONE,
PropertyValue::Float(_) => TYPE_FLOAT,
PropertyValue::Int(_) => TYPE_INT,
PropertyValue::String(_) => TYPE_STRING,
Expand All @@ -83,6 +84,7 @@ pub(super) fn property_value_type_tag(value: &PropertyValue) -> u8 {

pub(super) fn property_value_payload_len(value: &PropertyValue) -> usize {
match value {
PropertyValue::None => 0,
PropertyValue::Float(_) | PropertyValue::Int(_) => 8,
PropertyValue::String(value) => value.len(),
PropertyValue::FloatArray(values) => values.len() * 8,
Expand Down Expand Up @@ -120,6 +122,7 @@ fn extend_i32(buf: &mut Vec<u8>, values: &[i32]) {

pub(super) fn property_value_to_bytes(value: &PropertyValue) -> Vec<u8> {
match value {
PropertyValue::None => Vec::new(),
PropertyValue::Float(value) => value.to_le_bytes().to_vec(),
PropertyValue::Int(value) => value.to_le_bytes().to_vec(),
PropertyValue::String(value) => value.as_bytes().to_vec(),
Expand Down Expand Up @@ -291,6 +294,14 @@ pub(super) fn decode_mat3x3_f64(payload: &[u8]) -> Result<[[f64; 3]; 3]> {

pub(super) fn decode_property_value(type_tag: u8, payload: &[u8]) -> Result<PropertyValue> {
Ok(match type_tag {
TYPE_NONE => {
if !payload.is_empty() {
return Err(Error::InvalidData(
"null property payload must be empty".into(),
));
}
PropertyValue::None
}
TYPE_FLOAT => {
if payload.len() < 8 {
return Err(Error::InvalidData("f64 property truncated".into()));
Expand Down
Loading
Loading