Skip to content

Commit

Permalink
return improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
Tomas Stolker committed Mar 2, 2019
1 parent fe8433e commit 700cade
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 93 deletions.
3 changes: 3 additions & 0 deletions pynpoint/core/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
def get_attributes():
"""
Function to get a dictionary with all attributes.
:return: Attribute information.
:rtype: dict
"""

attr = {'PIXSCALE':{'attribute':'static',
Expand Down
173 changes: 81 additions & 92 deletions pynpoint/core/dataio.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,9 @@ def close_connection(self):
:return: None
"""

if not self.m_open:
return None

self.m_data_bank.close()
self.m_open = False
if self.m_open:
self.m_data_bank.close()
self.m_open = False


class Port(six.with_metaclass(ABCMeta)):
Expand Down Expand Up @@ -681,8 +679,6 @@ def _set_all_key(self,
for key, value in six.iteritems(tmp_attributes):
self._m_data_storage.m_data_bank[tag].attrs[key] = value

return None

def _append_key(self,
tag,
data,
Expand Down Expand Up @@ -712,22 +708,26 @@ def _append_key(self,

# if the dimension offset is 1 add that dimension (e.g. save 2D image in 3D image stack)
if data.ndim + 1 == data_dim:
if data_dim == 3:
data = data[np.newaxis, :, :]

if data_dim == 2:
data = data[np.newaxis, :]
elif data_dim == 3:
data = data[np.newaxis, :, :]

def _type_check():
check_result = False

if tmp_dim == data.ndim:
if tmp_dim == 3:
return (tmp_shape[1] == data.shape[1]) and (tmp_shape[2] == data.shape[2])

if tmp_dim == 1:
check_result = True
elif tmp_dim == 2:
return tmp_shape[1] == data.shape[1]
check_result = tmp_shape[1] == data.shape[1]
elif tmp_dim == 3:
check_result = (tmp_shape[1] == data.shape[1]) and \
(tmp_shape[2] == data.shape[2])

return True

return False
return check_result

if _type_check():
# YES -> dim and type match
Expand Down Expand Up @@ -764,21 +764,19 @@ def __setitem__(self, key, value):
:return: None
"""

if not self._check_status_and_activate():
return None
if self._check_status_and_activate():

self._m_data_storage.m_data_bank[self._m_tag][key] = value
self._m_data_storage.m_data_bank[self._m_tag][key] = value

def del_all_data(self):
"""
Delete all data belonging to the database tag.
"""

if not self._check_status_and_activate():
return None
if self._check_status_and_activate():

if self._m_tag in self._m_data_storage.m_data_bank:
del self._m_data_storage.m_data_bank[self._m_tag]
if self._m_tag in self._m_data_storage.m_data_bank:
del self._m_data_storage.m_data_bank[self._m_tag]

def set_all(self,
data,
Expand Down Expand Up @@ -831,14 +829,12 @@ def set_all(self,

data = np.asarray(data)

# check if port is ready to use
if not self._check_status_and_activate():
return None
if self._check_status_and_activate():

self._set_all_key(self._m_tag,
data,
data_dim,
keep_attributes)
self._set_all_key(tag=self._m_tag,
data=data,
data_dim=data_dim,
keep_attributes=keep_attributes)

def append(self,
data,
Expand Down Expand Up @@ -872,13 +868,12 @@ def append(self,
:return: None
"""

if not self._check_status_and_activate():
return None
if self._check_status_and_activate():

self._append_key(self._m_tag,
data=data,
data_dim=data_dim,
force=force)
self._append_key(self._m_tag,
data=data,
data_dim=data_dim,
force=force)

def activate(self):
"""
Expand Down Expand Up @@ -928,19 +923,18 @@ def add_attribute(self,
:return: None
"""

if not self._check_status_and_activate():
return None
if self._check_status_and_activate():

if self._m_tag not in self._m_data_storage.m_data_bank:
warnings.warn("Can not save attribute while no data exists.")
return None
if self._m_tag not in self._m_data_storage.m_data_bank:
warnings.warn("Can not store attribute if data tag does not exist.")

if static:
self._m_data_storage.m_data_bank[self._m_tag].attrs[name] = value
else:
if static:
self._m_data_storage.m_data_bank[self._m_tag].attrs[name] = value

else:
self._set_all_key(tag=("header_" + self._m_tag + "/" + name),
data=np.asarray(value))
else:
self._set_all_key(tag=("header_" + self._m_tag + "/" + name),
data=np.asarray(value))

def append_attribute_data(self,
name,
Expand All @@ -955,11 +949,10 @@ def append_attribute_data(self,
:return: None
"""

if not self._check_status_and_activate():
return None
if self._check_status_and_activate():

self._append_key(tag=("header_" + self._m_tag + "/" + name),
data=np.asarray([value, ]))
self._append_key(tag=("header_" + self._m_tag + "/" + name),
data=np.asarray([value, ]))

def add_value_to_static_attribute(self,
name,
Expand All @@ -975,17 +968,16 @@ def add_value_to_static_attribute(self,
:return: None
"""

if not self._check_status_and_activate():
return None
if self._check_status_and_activate():

if not isinstance(value, int) or isinstance(value, float):
raise ValueError("Only integer and float values can be added to an existing "
"attribute.")
if not isinstance(value, int) or isinstance(value, float):
raise ValueError("Only integer and float values can be added to an existing "
"attribute.")

if name not in self._m_data_storage.m_data_bank[self._m_tag].attrs:
raise AttributeError("Value can not be added to a not existing attribute.")
if name not in self._m_data_storage.m_data_bank[self._m_tag].attrs:
raise AttributeError("Value can not be added to a not existing attribute.")

self._m_data_storage.m_data_bank[self._m_tag].attrs[name] += value
self._m_data_storage.m_data_bank[self._m_tag].attrs[name] += value

def copy_attributes(self,
input_port):
Expand All @@ -1001,29 +993,28 @@ def copy_attributes(self,
:return: None
"""

if input_port.tag == self._m_tag:
return None
if self._check_status_and_activate() and input_port.tag != self._m_tag:

if not self._check_status_and_activate():
return None
# link non-static attributes
if "header_" + input_port.tag + "/" in self._m_data_storage.m_data_bank:

# link non-static attributes
if "header_" + input_port.tag + "/" in self._m_data_storage.m_data_bank:
for attr_name, attr_data in six.iteritems(self._m_data_storage\
.m_data_bank["header_" + input_port.tag + "/"]):
for attr_name, attr_data in six.iteritems(self._m_data_storage\
.m_data_bank["header_" + input_port.tag + "/"]):

# overwrite existing header information in the database
if "header_" + self._m_tag + "/" + attr_name in self._m_data_storage.m_data_bank:
del self._m_data_storage.m_data_bank["header_" + self._m_tag + "/" + attr_name]
database_name = "header_"+self._m_tag+"/"+attr_name

self._m_data_storage.m_data_bank["header_"+self._m_tag+"/"+attr_name] = attr_data
# overwrite existing header information in the database
if database_name in self._m_data_storage.m_data_bank:
del self._m_data_storage.m_data_bank[database_name]

# copy static attributes
attributes = input_port.get_all_static_attributes()
for attr_name, attr_val in six.iteritems(attributes):
self.add_attribute(attr_name, attr_val)
self._m_data_storage.m_data_bank[database_name] = attr_data

self._m_data_storage.m_data_bank.flush()
# copy static attributes
attributes = input_port.get_all_static_attributes()
for attr_name, attr_val in six.iteritems(attributes):
self.add_attribute(attr_name, attr_val)

self._m_data_storage.m_data_bank.flush()

def del_attribute(self,
name):
Expand All @@ -1037,19 +1028,18 @@ def del_attribute(self,
:return: None
"""

if not self._check_status_and_activate():
return None
if self._check_status_and_activate():

# check if attribute is static
if name in self._m_data_storage.m_data_bank[self._m_tag].attrs:
del self._m_data_storage.m_data_bank[self._m_tag].attrs[name]
# check if attribute is static
if name in self._m_data_storage.m_data_bank[self._m_tag].attrs:
del self._m_data_storage.m_data_bank[self._m_tag].attrs[name]

elif "header_"+self._m_tag+"/"+name in self._m_data_storage.m_data_bank:
# remove non-static attribute
del self._m_data_storage.m_data_bank[("header_" + self._m_tag + "/" + name)]
elif "header_"+self._m_tag+"/"+name in self._m_data_storage.m_data_bank:
# remove non-static attribute
del self._m_data_storage.m_data_bank[("header_" + self._m_tag + "/" + name)]

else:
warnings.warn("Attribute '%s' does not exist and could not be deleted." % name)
else:
warnings.warn("Attribute '%s' does not exist and could not be deleted." % name)

def del_all_attributes(self):
"""
Expand All @@ -1058,16 +1048,15 @@ def del_all_attributes(self):
:return: None
"""

if not self._check_status_and_activate():
return None
if self._check_status_and_activate():

# static attributes
if self._m_tag in self._m_data_storage.m_data_bank:
self._m_data_storage.m_data_bank[self._m_tag].attrs.clear()
# static attributes
if self._m_tag in self._m_data_storage.m_data_bank:
self._m_data_storage.m_data_bank[self._m_tag].attrs.clear()

# non-static attributes
if "header_" + self._m_tag + "/" in self._m_data_storage.m_data_bank:
del self._m_data_storage.m_data_bank[("header_" + self._m_tag + "/")]
# non-static attributes
if "header_" + self._m_tag + "/" in self._m_data_storage.m_data_bank:
del self._m_data_storage.m_data_bank[("header_" + self._m_tag + "/")]

def check_static_attribute(self,
name,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_core/test_outputport.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def test_add_static_attribute_error(self):
# check that only one warning was raised
assert len(warning) == 1
# check that the message matches
assert warning[0].message.args[0] == "Can not save attribute while no data exists."
assert warning[0].message.args[0] == "Can not store attribute if data tag does not exist."

out_port.del_all_attributes()
out_port.del_all_data()
Expand Down

0 comments on commit 700cade

Please sign in to comment.