Skip to content

Commit

Permalink
Merge pull request #947 from ipelupessy/fix_856
Browse files Browse the repository at this point in the history
make sure vectorattributes definitions are stored and copied
  • Loading branch information
ipelupessy committed Apr 30, 2023
2 parents 0eac17b + 15469d8 commit 5daf44e
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 3 deletions.
2 changes: 2 additions & 0 deletions src/amuse/datamodel/particles.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ def copy(self, memento = None, keep_structure = False, filter_attributes = lambd
converted.append(x)
result.add_particles_to_store(keys, attributes, converted)

object.__setattr__(result, "_derived_attributes", CompositeDictionary(self._derived_attributes))
result._private.collection_attributes = self._private.collection_attributes._copy_for_collection(result)

return result
Expand Down Expand Up @@ -425,6 +426,7 @@ def copy_to_new_particles(self, keys = None, keys_generator = None, memento = No

result.add_particles_to_store(particle_keys, attributes, converted)

object.__setattr__(result, "_derived_attributes", CompositeDictionary(self._derived_attributes))
result._private.collection_attributes = self._private.collection_attributes._copy_for_collection(result)

return result
Expand Down
25 changes: 22 additions & 3 deletions src/amuse/io/store_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from amuse.datamodel import Grid
from amuse.datamodel import GridPoint
from amuse.datamodel import AbstractSet
from amuse.datamodel.base import VectorAttribute


from amuse.io import store_v1

Expand Down Expand Up @@ -847,6 +849,7 @@ def store_particles(self, particles, extra_attributes = {}, parent=None, mapping
self.hdf5file.flush()
self.store_collection_attributes(particles, group, extra_attributes, links)
self.store_values(particles, group, links)
self.store_selected_derived_attributes(particles,group)

mapping_from_setid_to_group[id(particles)] = group

Expand All @@ -868,6 +871,7 @@ def store_grid(self, grid, extra_attributes = {}, parent=None, mapping_from_seti
compression_opts=self.compression_opts,
)

self.store_selected_derived_attributes(grid, group)
self.store_collection_attributes(grid, group, extra_attributes, links)
self.store_values(grid, group, links)

Expand Down Expand Up @@ -924,8 +928,15 @@ def store_values(self, container, group, links = []):
)
dataset.attrs["units"] = "none".encode('ascii')

def store_selected_derived_attributes(self, container, group):
saving=dict()
for key in container._derived_attributes.keys():
attr=container._derived_attributes[key]
if key not in container.GLOBAL_DERIVED_ATTRIBUTES and isinstance(attr, VectorAttribute):
saving[key]=attr
if len(saving):
group.attrs["extra_vector_attributes"]=pickle_to_string(saving)


def store_linked_array(self, attribute, attributes_group, quantity, group, links):
subgroup = attributes_group.create_group(attribute)
shape = quantity.shape
Expand Down Expand Up @@ -1127,8 +1138,9 @@ def load_particles_from_group(self, group):

self.mapping_from_groupid_to_set[group.id] = particles
self.load_collection_attributes(particles, group)


if "extra_vector_attributes" in group.attrs.keys():
self.load_extra_derived_attributes(particles, group)

return particles

def load_grid_from_group(self, group):
Expand All @@ -1143,9 +1155,16 @@ def load_grid_from_group(self, group):
container._private.attribute_storage = HDF5GridAttributeStorage(shape, group, self)
self.mapping_from_groupid_to_set[group.id] = container
self.load_collection_attributes(container, group)
if "extra_vector_attributes" in group.attrs.keys():
self.load_extra_derived_attributes(container, group)

return container

def load_extra_derived_attributes(self, container, group):
attrs=unpickle_from_string(group.attrs["extra_vector_attributes"])
for key, attr in attrs.items():
container._derived_attributes[key]=attr

def load_from_group(self, group):
container_type = group.attrs['type'] if isinstance(group.attrs['type'], str) else group.attrs['type'].decode('ascii')

Expand Down
32 changes: 32 additions & 0 deletions src/amuse/test/suite/ticket_tests/test_github856.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import os
from amuse.test import amusetest

from amuse.datamodel import new_cartesian_grid, Particles

from amuse.io import read_set_from_file, write_set_to_file

class test_github856(amusetest.TestCase):
def test1(self):
filename=os.path.join(self.get_path_to_results(),"github856.amuse")

g1=new_cartesian_grid((5,5), 1)
write_set_to_file(g1,filename,"amuse")
del g1
g2=read_set_from_file(filename,"amuse")
self.assertEquals(g2.get_axes_names(),"xy")

def test2(self):
g1=Particles(lon=[1,2], lat=[3,4])
g1.add_vector_attribute("lonlat",["lon","lat"])
g2=g1.copy()
self.assertEquals(g2.lonlat,[[1,3],[2,4]])

def test3(self):
filename=os.path.join(self.get_path_to_results(),"github856_2.amuse")

g1=Particles(lon=[1,2], lat=[3,4])
g1.add_vector_attribute("lonlat",["lon","lat"])
write_set_to_file(g1,filename,"amuse")
del g1
g2=read_set_from_file(filename,"amuse")
self.assertEquals(g2.lonlat,[[1,3],[2,4]])

0 comments on commit 5daf44e

Please sign in to comment.