diff --git a/atompack-py/python/atompack/__init__.pyi b/atompack-py/python/atompack/__init__.pyi index a28c97e..9d49536 100644 --- a/atompack-py/python/atompack/__init__.pyi +++ b/atompack-py/python/atompack/__init__.pyi @@ -315,7 +315,7 @@ class Molecule: """ ... - def get_property(self, key: str) -> float | int | str | npt.NDArray: + def get_property(self, key: str) -> float | int | str | npt.NDArray | None: """ Get a custom property by key. @@ -326,7 +326,7 @@ class Molecule: Returns ------- - float, int, str, or ndarray + float, int, str, ndarray, or None Property value Raises @@ -336,11 +336,11 @@ class Molecule: """ ... - def set_property(self, key: str, value: float | int | str | npt.NDArray) -> None: + def set_property(self, key: str, value: float | int | str | npt.NDArray | None) -> None: """ Set a custom property. - Supported types: float, int, str, 1D float32/float64/int32/int64 arrays, + Supported types: None, float, int, str, 1D float32/float64/int32/int64 arrays, and 2D float32/float64 arrays with shape (n, 3). Input dtype is preserved. The key 'stress' is reserved; use the dedicated ``stress`` property instead. @@ -348,7 +348,7 @@ class Molecule: ---------- key : str Property key - value : float, int, str, or ndarray + value : float, int, str, ndarray, or None Property value Raises diff --git a/atompack-py/python/atompack/ase_bridge.py b/atompack-py/python/atompack/ase_bridge.py index b268580..a1cc155 100644 --- a/atompack-py/python/atompack/ase_bridge.py +++ b/atompack-py/python/atompack/ase_bridge.py @@ -21,6 +21,7 @@ _ASE_RESERVED_ARRAYS = {"numbers", "positions"} _ASE_TYPES = None _CALC_MODES = {"singlepoint", "nocopy", "none"} +_UNSUPPORTED_PROPERTY = object() def _voigt6_to_mat3x3(stress): @@ -53,6 +54,8 @@ def _get_stress(atoms): def _coerce_property(value, n_atoms): + if value is None: + return None if isinstance(value, (str, bool, int, float, np.integer, np.floating)): if isinstance(value, str): return value @@ -76,7 +79,7 @@ def _coerce_property(value, n_atoms): if arr.dtype == np.float32: return arr.astype(np.float32, copy=False) return arr.astype(np.float64, copy=False) - return None + return _UNSUPPORTED_PROPERTY def _merge_properties(properties, builtins, values, n_atoms): @@ -93,7 +96,7 @@ def _merge_properties(properties, builtins, values, n_atoms): builtins["stress"] = arr.astype(np.float64, copy=False) continue coerced = _coerce_property(value, n_atoms) - if coerced is not None: + if coerced is not _UNSUPPORTED_PROPERTY: properties[key] = coerced @@ -187,7 +190,7 @@ def _extract_ase_record( if key in _ASE_RESERVED_ARRAYS or key in _BUILTIN_FIELDS: continue coerced = _coerce_property(value, n_atoms) - if coerced is not None: + if coerced is not _UNSUPPORTED_PROPERTY: properties[key] = coerced calc = getattr(atoms, "calc", None) @@ -196,7 +199,7 @@ def _extract_ase_record( for key, value in results.items(): if key not in _BUILTIN_FIELDS: coerced = _coerce_property(value, n_atoms) - if coerced is not None: + if coerced is not _UNSUPPORTED_PROPERTY: properties[key] = coerced if copy_info and getattr(atoms, "info", None): diff --git a/atompack-py/src/database_flat.rs b/atompack-py/src/database_flat.rs index a8ada30..8486633 100644 --- a/atompack-py/src/database_flat.rs +++ b/atompack-py/src/database_flat.rs @@ -368,6 +368,9 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( } if schema_entry.slot_bytes == 0 { + if schema_entry.type_tag == TYPE_NONE { + continue; + } if let Some(ref mtx) = string_mutexes[section_idx] { let val = Some( std::str::from_utf8(sec.payload) @@ -449,6 +452,9 @@ pub(super) fn get_molecules_flat_soa_impl<'py>( } if schema_entry.slot_bytes == 0 { + if schema_entry.type_tag == TYPE_NONE { + continue; + } if let Some(ref mtx) = string_mutexes[section_idx] { let val = Some( std::str::from_utf8(sec.payload) diff --git a/atompack-py/src/lib.rs b/atompack-py/src/lib.rs index 3e6b6e3..92639ea 100644 --- a/atompack-py/src/lib.rs +++ b/atompack-py/src/lib.rs @@ -113,6 +113,7 @@ const TYPE_BOOL3: u8 = 9; const TYPE_MAT3X3_F64: u8 = 10; const TYPE_FLOAT32: u8 = 11; const TYPE_MAT3X3_F32: u8 = 12; +const TYPE_NONE: u8 = 13; const RECORD_FORMAT_SOA_V2: u32 = 2; const RECORD_FORMAT_SOA_V3: u32 = 3; diff --git a/atompack-py/src/molecule_helpers.rs b/atompack-py/src/molecule_helpers.rs index b48c3e8..12aa81a 100644 --- a/atompack-py/src/molecule_helpers.rs +++ b/atompack-py/src/molecule_helpers.rs @@ -374,6 +374,7 @@ pub(super) fn property_value_to_pyobject( value: &PropertyValue, ) -> PyResult> { Ok(match value { + PropertyValue::None => py.None(), PropertyValue::Float(v) => into_py_any(py, *v)?, PropertyValue::Int(v) => into_py_any(py, *v)?, PropertyValue::String(v) => into_py_any(py, v)?, @@ -411,6 +412,12 @@ pub(super) fn property_section_to_pyobject<'py>( Ok(payload.len() / stride) }; Ok(match section.type_tag { + TYPE_NONE => { + if !payload.is_empty() { + return Err(PyValueError::new_err("Null property payload must be empty")); + } + py.None() + } TYPE_FLOAT => into_py_any(py, read_f64_scalar(payload)?)?, TYPE_INT => into_py_any(py, read_i64_scalar(payload)?)?, TYPE_STRING => into_py_any( @@ -448,6 +455,7 @@ pub(super) fn property_section_to_pyobject<'py>( fn property_value_is_atom_array(value: &PropertyValue, n_atoms: usize) -> bool { match value { + PropertyValue::None => false, PropertyValue::FloatArray(values) => values.len() == n_atoms, PropertyValue::Vec3Array(values) => values.len() == n_atoms, PropertyValue::IntArray(values) => values.len() == n_atoms, diff --git a/atompack-py/src/py_dtypes.rs b/atompack-py/src/py_dtypes.rs index e05f112..95a0bd3 100644 --- a/atompack-py/src/py_dtypes.rs +++ b/atompack-py/src/py_dtypes.rs @@ -243,6 +243,9 @@ pub(crate) fn parse_mat3_field(value: &Bound<'_, PyAny>, label: &str) -> PyResul } pub(crate) fn parse_property_value(value: &Bound<'_, PyAny>) -> PyResult { + if value.is_none() { + return Ok(PropertyValue::None); + } if let Ok(v) = value.extract::() { return Ok(PropertyValue::Int(v)); } @@ -305,6 +308,6 @@ pub(crate) fn parse_property_value(value: &Bound<'_, PyAny>) -> PyResult atompack::Result { let per_atom = is_per_atom(section.kind, section.key, section.type_tag); let elem_bytes = match section.type_tag { + TYPE_NONE => 0, TYPE_STRING => 0, tag if per_atom => { let elem_bytes = type_tag_elem_bytes(tag); @@ -168,7 +169,7 @@ pub(crate) fn section_schema_from_ref( TYPE_MAT3X3_F64 => 72, _ => section.payload.len(), }; - let slot_bytes = if section.type_tag == TYPE_STRING { + let slot_bytes = if matches!(section.type_tag, TYPE_STRING | TYPE_NONE) { 0 } else if per_atom { elem_bytes @@ -196,6 +197,15 @@ pub(crate) fn validate_section_payload( n_atoms: usize, ) -> atompack::Result<()> { match section.type_tag { + TYPE_NONE => { + if !section.payload.is_empty() { + return Err(invalid_data(format!( + "Section '{}' has invalid payload length {} (expected 0)", + section.key, + section.payload.len() + ))); + } + } TYPE_STRING => { std::str::from_utf8(section.payload) .map_err(|_| invalid_data(format!("Invalid UTF-8 in section '{}'", section.key)))?; @@ -254,6 +264,7 @@ pub(crate) fn validate_section_payload( /// Element size in bytes for a given type tag. Returns 0 for variable-length types. pub(crate) fn type_tag_elem_bytes(tag: u8) -> usize { match tag { + TYPE_NONE => 0, TYPE_FLOAT => 8, TYPE_INT => 8, TYPE_STRING => 0, @@ -289,7 +300,7 @@ fn database_schema_section( n_atoms: usize, ) -> PyResult { let per_atom = is_per_atom(kind, key, type_tag); - let elem_bytes = if type_tag == TYPE_STRING { + let elem_bytes = if matches!(type_tag, TYPE_STRING | TYPE_NONE) { 0 } else { let elem_bytes = type_tag_elem_bytes(type_tag); @@ -301,7 +312,7 @@ fn database_schema_section( } elem_bytes }; - let slot_bytes = if type_tag == TYPE_STRING { + let slot_bytes = if matches!(type_tag, TYPE_STRING | TYPE_NONE) { 0 } else if per_atom { let expected = n_atoms.checked_mul(elem_bytes).ok_or_else(|| { @@ -1096,6 +1107,12 @@ fn decode_mat3x3_f32(payload: &[u8]) -> PyResult<[[f32; 3]; 3]> { fn decode_property_value(type_tag: u8, payload: &[u8]) -> PyResult { Ok(match type_tag { + TYPE_NONE => { + if !payload.is_empty() { + return Err(PyValueError::new_err("Null property payload must be empty")); + } + PropertyValue::None + } TYPE_FLOAT => PropertyValue::Float(read_f64_scalar(payload)?), TYPE_INT => PropertyValue::Int(read_i64_scalar(payload)?), TYPE_STRING => PropertyValue::String( diff --git a/atompack-py/tests/test_atom_molecule.py b/atompack-py/tests/test_atom_molecule.py index 5674ca0..6c03f83 100644 --- a/atompack-py/tests/test_atom_molecule.py +++ b/atompack-py/tests/test_atom_molecule.py @@ -169,6 +169,7 @@ def test_molecule_custom_properties() -> None: mol.set_property("int_vec32", np.array([3, 4], dtype=np.int32)) mol.set_property("vec3", np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) mol.set_property("vec3_f64", np.array([[1.1, 1.2, 1.3], [2.1, 2.2, 2.3]], dtype=np.float64)) + mol.set_property("optional_label", None) mol.stress = np.eye(3, dtype=np.float64) * 3.0 assert mol.get_property("temperature") == pytest.approx(300.0) @@ -208,10 +209,12 @@ def test_molecule_custom_properties() -> None: np.testing.assert_allclose( vec3_f64, np.array([[1.1, 1.2, 1.3], [2.1, 2.2, 2.3]], dtype=np.float64) ) + assert mol.get_property("optional_label") is None np.testing.assert_allclose(mol.stress, np.eye(3, dtype=np.float64) * 3.0) assert mol.has_property("method") is True + assert mol.has_property("optional_label") is True assert mol.has_property("stress") is False assert set(mol.property_keys()) >= { "temperature", @@ -223,6 +226,7 @@ def test_molecule_custom_properties() -> None: "int_vec32", "vec3", "vec3_f64", + "optional_label", } with pytest.raises(KeyError, match=r"not found"): diff --git a/atompack-py/tests/test_database.py b/atompack-py/tests/test_database.py index 4d83dfd..ccaac1a 100644 --- a/atompack-py/tests/test_database.py +++ b/atompack-py/tests/test_database.py @@ -20,6 +20,7 @@ def _make_molecule(energy: float) -> atompack.Molecule: mol.velocities = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype=np.float32) mol.cell = np.eye(3, dtype=np.float64) * 2.0 mol.set_property("tag", "train") + mol.set_property("optional", None) mol.set_property("ids", np.array([1, 2], dtype=np.int64)) return mol @@ -48,6 +49,7 @@ def test_database_roundtrip(tmp_path: Path, compression: str) -> None: np.testing.assert_allclose(mol1_r.velocities, mol1.velocities) np.testing.assert_allclose(mol1_r.cell, mol1.cell) assert mol1_r.get_property("tag") == "train" + assert mol1_r.get_property("optional") is None np.testing.assert_array_equal(mol1_r.get_property("ids"), np.array([1, 2], dtype=np.int64)) batch = db2.get_molecules([0, 1]) diff --git a/atompack-py/tests/test_from_ase.py b/atompack-py/tests/test_from_ase.py index 8505310..c485cd4 100644 --- a/atompack-py/tests/test_from_ase.py +++ b/atompack-py/tests/test_from_ase.py @@ -87,6 +87,7 @@ def test_from_ase_extracts_core_fields() -> None: "int_vec32": np.array([3, 4], dtype=np.int32), "vec3": np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32), "vec3_f64": np.array([[1.0, 1.1, 1.2], [2.0, 2.1, 2.2]], dtype=np.float64), + "nullable": None, "stress": np.array( [[1.1, 1.2, 1.3], [2.1, 2.2, 2.3], [3.1, 3.2, 3.3]], dtype=np.float64 ), @@ -126,6 +127,7 @@ def test_from_ase_extracts_core_fields() -> None: np.testing.assert_allclose( mol.get_property("vec3_f64"), np.array([[1.0, 1.1, 1.2], [2.0, 2.1, 2.2]], dtype=np.float64) ) + assert mol.get_property("nullable") is None np.testing.assert_allclose( mol.stress, np.array([[1.1, 1.2, 1.3], [2.1, 2.2, 2.3], [3.1, 3.2, 3.3]], dtype=np.float64), @@ -360,6 +362,22 @@ def test_to_ase_owned_maps_builtins_and_properties() -> None: np.testing.assert_array_equal(atoms.arrays["tags"], np.array([3, 4], dtype=np.int32)) +def test_to_ase_roundtrip_preserves_none_custom_property() -> None: + mol = atompack.Molecule.from_arrays( + np.array([[0.0, 0.0, 0.0]], dtype=np.float32), + np.array([1], dtype=np.uint8), + ) + mol.set_property("nullable", None) + + atoms = mol.to_ase() + assert "nullable" in atoms.info + assert atoms.info["nullable"] is None + + roundtrip = atompack.from_ase(atoms) + assert roundtrip.has_property("nullable") is True + assert roundtrip.get_property("nullable") is None + + def test_to_ase_calc_modes() -> None: mol = atompack.Molecule.from_arrays( np.array([[0.0, 0.0, 0.0], [1.0, 0.5, 0.0]], dtype=np.float32), diff --git a/atompack/src/storage/dtypes.rs b/atompack/src/storage/dtypes.rs index 950b9b0..309eec1 100644 --- a/atompack/src/storage/dtypes.rs +++ b/atompack/src/storage/dtypes.rs @@ -69,6 +69,7 @@ pub(super) fn float_scalar_payload_len(value: &FloatScalarData) -> usize { pub(super) fn property_value_type_tag(value: &PropertyValue) -> u8 { match value { + PropertyValue::None => TYPE_NONE, PropertyValue::Float(_) => TYPE_FLOAT, PropertyValue::Int(_) => TYPE_INT, PropertyValue::String(_) => TYPE_STRING, @@ -83,6 +84,7 @@ pub(super) fn property_value_type_tag(value: &PropertyValue) -> u8 { pub(super) fn property_value_payload_len(value: &PropertyValue) -> usize { match value { + PropertyValue::None => 0, PropertyValue::Float(_) | PropertyValue::Int(_) => 8, PropertyValue::String(value) => value.len(), PropertyValue::FloatArray(values) => values.len() * 8, @@ -120,6 +122,7 @@ fn extend_i32(buf: &mut Vec, values: &[i32]) { pub(super) fn property_value_to_bytes(value: &PropertyValue) -> Vec { match value { + PropertyValue::None => Vec::new(), PropertyValue::Float(value) => value.to_le_bytes().to_vec(), PropertyValue::Int(value) => value.to_le_bytes().to_vec(), PropertyValue::String(value) => value.as_bytes().to_vec(), @@ -291,6 +294,14 @@ pub(super) fn decode_mat3x3_f64(payload: &[u8]) -> Result<[[f64; 3]; 3]> { pub(super) fn decode_property_value(type_tag: u8, payload: &[u8]) -> Result { Ok(match type_tag { + TYPE_NONE => { + if !payload.is_empty() { + return Err(Error::InvalidData( + "null property payload must be empty".into(), + )); + } + PropertyValue::None + } TYPE_FLOAT => { if payload.len() < 8 { return Err(Error::InvalidData("f64 property truncated".into())); diff --git a/atompack/src/storage/mod.rs b/atompack/src/storage/mod.rs index 886aceb..405bb4c 100644 --- a/atompack/src/storage/mod.rs +++ b/atompack/src/storage/mod.rs @@ -79,6 +79,7 @@ const TYPE_BOOL3: u8 = 9; // [bool; 3] const TYPE_MAT3X3_F64: u8 = 10; // [[f64; 3]; 3] const TYPE_FLOAT32: u8 = 11; // f32 scalar const TYPE_MAT3X3_F32: u8 = 12; // [[f32; 3]; 3] +const TYPE_NONE: u8 = 13; // explicit null property // Two redundant page-aligned header slots for crash safety. const HEADER_SLOT_SIZE: usize = 4096; @@ -1425,6 +1426,8 @@ mod tests { "i32arr".to_string(), PropertyValue::Int32Array(vec![100, -200]), ); + mol.properties + .insert("none_val".to_string(), PropertyValue::None); { let mut db = AtomDatabase::create(&path, CompressionType::Lz4).unwrap(); @@ -1463,7 +1466,7 @@ mod tests { } // Verify properties - assert_eq!(r.properties.len(), 9); + assert_eq!(r.properties.len(), 10); match r.properties.get("scalar_f").unwrap() { PropertyValue::Float(v) => assert_eq!(*v, 99.9), other => panic!("expected Float, got {:?}", other), @@ -1476,6 +1479,10 @@ mod tests { PropertyValue::String(v) => assert_eq!(v, "hello"), other => panic!("expected String, got {:?}", other), } + match r.properties.get("none_val").unwrap() { + PropertyValue::None => {} + other => panic!("expected None, got {:?}", other), + } } #[test] diff --git a/atompack/src/storage/schema.rs b/atompack/src/storage/schema.rs index 0f87334..7ba950f 100644 --- a/atompack/src/storage/schema.rs +++ b/atompack/src/storage/schema.rs @@ -125,6 +125,7 @@ pub(super) fn decode_schema_lock(bytes: &[u8]) -> Result { fn schema_type_tag_elem_bytes(tag: u8) -> Result { match tag { + TYPE_NONE => Ok(0), TYPE_FLOAT => Ok(8), TYPE_INT => Ok(8), TYPE_STRING => Ok(0), @@ -163,7 +164,7 @@ fn schema_entry( ) -> Result { let per_atom = schema_is_per_atom(kind, key); let elem_bytes = schema_type_tag_elem_bytes(type_tag)?; - let slot_bytes = if type_tag == TYPE_STRING { + let slot_bytes = if matches!(type_tag, TYPE_STRING | TYPE_NONE) { 0 } else if per_atom { match type_tag { diff --git a/atompack/src/types.rs b/atompack/src/types.rs index 506e4af..e6d8b54 100644 --- a/atompack/src/types.rs +++ b/atompack/src/types.rs @@ -3,6 +3,7 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize)] pub enum PropertyValue { + None, Float(f64), Int(i64), String(String),