Skip to content

Commit

Permalink
more serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
andped10 committed May 16, 2024
1 parent 2e65a01 commit e6ae923
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 7 deletions.
17 changes: 10 additions & 7 deletions src/easyreflectometry/experiment/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,17 +206,20 @@ def as_dict(self, skip: list = None) -> dict:
return this_dict

@classmethod
def from_dict(cls, data: dict) -> Model:
def from_dict(cls, this_dict: dict) -> Model:
"""
Create a Model from a dictionary.
:param data: dictionary of the Model
:param this_dict: dictionary of the Model
:return: Model
"""
model = super().from_dict(data)
resolution_function = ResolutionFunction.from_dict(this_dict['resolution_function'])
del this_dict['resolution_function']
sample = Sample.from_dict(this_dict['sample'])
del this_dict['sample']

# Ensure that the sample is also converted
# TODO Should probably be handled in easyscience
model.sample = model.sample.__class__.from_dict(data['sample'])
model.resolution_function = ResolutionFunction.from_dict(data['resolution_function'])
model = super().from_dict(this_dict)

model.sample = sample
model.resolution_function = resolution_function
return model
19 changes: 19 additions & 0 deletions src/easyreflectometry/experiment/model_collection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

__author__ = 'github.com/arm61'

from typing import Optional
Expand Down Expand Up @@ -39,3 +41,20 @@ def remove_model(self, idx: int):
:param idx: Index of the model to remove
"""
del self[idx]

@classmethod
def from_dict(cls, this_dict: dict) -> ModelCollection:
"""
Create an instance of a collection from a dictionary.
:param data: The dictionary for the collection
:return: An instance of the collection
"""
collection = super().from_dict(this_dict) # type: ModelCollection

if len(collection) != len(this_dict['data']):
raise ValueError(f"Expected {len(collection)} models, got {len(this_dict['data'])}")
for i, model_data in enumerate(this_dict['data']):
collection[i] = Model.from_dict(model_data)

return collection
20 changes: 20 additions & 0 deletions tests/experiment/test_model_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,23 @@ def test_as_dict(self):

# Expect
assert dict_repr['data'][0]['resolution_function'] == {'smearing': 'PercentageFhwm', 'constant': 5.0}

def test_dict_round_trip(self):
# When
model_1 = Model(name='Model1')
model_2 = Model(name='Model2')
model_3 = Model(name='Model3')

# Then
collection = ModelCollection(model_1, model_2, model_3)

src_dict = collection.as_dict()

# Then
collection_from_dict = ModelCollection.from_dict(src_dict)

# Expect
assert collection.as_data_dict(skip=['resolution_function', 'interface']) == collection_from_dict.as_data_dict(
skip=['resolution_function', 'interface']
)
assert collection[0]._resolution_function.smearing(5.5) == collection_from_dict[0]._resolution_function.smearing(5.5)
6 changes: 6 additions & 0 deletions tests/sample/elements/materials/test_material_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ def test_from_pars(self):
assert p[0].name == 'Boron'
assert p[1].name == 'Potassium'

def test_empty_list(self):
p = MaterialCollection([])
assert p.name == 'EasyMaterials'
assert p.interface is None
assert len(p) == 0

def test_dict_repr(self):
p = MaterialCollection()
assert p._dict_repr == {
Expand Down

0 comments on commit e6ae923

Please sign in to comment.