Skip to content

Commit

Permalink
Merge pull request #189 from dstansby/ADRInfo
Browse files: browse the repository at this point in the history
Add ADRInfo
  • Loading branch information
dstansby committed May 22, 2023
2 parents e0d2a05 + 9fa462b commit cd3b7e8
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 58 deletions.
108 changes: 52 additions & 56 deletions cdflib/cdfread.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import numpy as np

import cdflib.epochs as epoch
from cdflib.dataclasses import AEDR, VDR, CDRInfo, GDRInfo
from cdflib.dataclasses import AEDR, VDR, ADRInfo, CDRInfo, GDRInfo

__all__ = ["CDF"]

Expand Down Expand Up @@ -285,7 +285,7 @@ def varinq(self, variable):

return var

def attinq(self, attribute=None):
def attinq(self, attribute=None) -> ADRInfo:
"""
Get attribute information.
Expand Down Expand Up @@ -372,7 +372,7 @@ def attget(self, attribute=None, entry=None):
name, next_adr = self._read_adr_fast(position)
if name.strip().lower() == attribute.strip().lower():
adr_info = self._read_adr(position)
if isinstance(entry, str) and adr_info["scope"] == 1:
if isinstance(entry, str) and adr_info.scope == 1:
# If the user has specified a string entry, they are obviously looking for a variable attribute.
# Filter out any global attributes that may have the same name.
adr_info = None
Expand All @@ -399,7 +399,7 @@ def attget(self, attribute=None, entry=None):
raise ValueError("Please set attribute keyword equal to " "the name or number of an attribute")

# Find the correct entry from the "entry" variable
if adr_info["scope"] == 1:
if adr_info.scope == 1:
if not isinstance(entry, int):
raise ValueError('"entry" must be an integer')
num_entry_string = "num_gr_entry"
Expand Down Expand Up @@ -445,9 +445,9 @@ def attget(self, attribute=None, entry=None):
num_entry_string = "num_gr_entry"
first_entry_string = "first_gr_entry"
max_entry_string = "max_gr_entry"
if entry_num > adr_info[max_entry_string]:
if entry_num > getattr(adr_info, max_entry_string):
raise ValueError("The entry does not exist")
return self._get_attdata(adr_info, entry_num, adr_info[num_entry_string], adr_info[first_entry_string])
return self._get_attdata(adr_info, entry_num, getattr(adr_info, num_entry_string), getattr(adr_info, first_entry_string))

def varget(self, variable=None, epoch=None, starttime=None, endtime=None, startrec=0, endrec=None, record_range_only=False):
"""
Expand Down Expand Up @@ -584,15 +584,15 @@ def globalattsget(self):
return_dict: Dict[str, List[Union[str, int]]] = {}
for _ in range(self._num_att):
adr_info = self._read_adr(byte_loc)
if adr_info["scope"] != 1:
byte_loc = adr_info["next_adr_location"]
if adr_info.scope != 1:
byte_loc = adr_info.next_adr_loc
continue
if adr_info["num_gr_entry"] == 0:
byte_loc = adr_info["next_adr_location"]
if adr_info.num_gr_entry == 0:
byte_loc = adr_info.next_adr_loc
continue
entries = []
aedr_byte_loc = adr_info["first_gr_entry"]
for _ in range(adr_info["num_gr_entry"]):
aedr_byte_loc = adr_info.first_gr_entry
for _ in range(adr_info.num_gr_entry):
aedr_info = self._read_aedr(aedr_byte_loc)
entryData = aedr_info.entry
# This exists to get rid of extraneous numpy arrays
Expand All @@ -602,8 +602,8 @@ def globalattsget(self):

entries.append(entryData)

return_dict[adr_info["name"]] = entries
byte_loc = adr_info["next_adr_location"]
return_dict[adr_info.name] = entries
byte_loc = adr_info.next_adr_loc

return return_dict

Expand Down Expand Up @@ -858,9 +858,9 @@ def _get_attnames(self):
for _ in range(0, self._num_att):
attr = {}
adr_info = self._read_adr(position)
attr[adr_info["name"]] = self._scope_token(int(adr_info["scope"]))
attr[adr_info.name] = self._scope_token(adr_info.scope)
attrs.append(attr)
position = adr_info["next_adr_location"]
position = adr_info.next_adr_loc
return attrs

def _read_cdr(self, byte_loc: int) -> Tuple[CDRInfo, int]:
Expand Down Expand Up @@ -1016,15 +1016,15 @@ def _read_varatts(self, var_num, zVar):
return_dict = {}
for z in range(0, self._num_att):
adr_info = self._read_adr(byte_loc)
if adr_info["scope"] == 1:
byte_loc = adr_info["next_adr_location"]
if adr_info.scope == 1:
byte_loc = adr_info.next_adr_loc
continue
if zVar:
byte_loc = adr_info["first_z_entry"]
num_entry = adr_info["num_z_entry"]
byte_loc = adr_info.first_z_entry
num_entry = adr_info.num_z_entry
else:
byte_loc = adr_info["first_gr_entry"]
num_entry = adr_info["num_gr_entry"]
byte_loc = adr_info.first_gr_entry
num_entry = adr_info.num_gr_entry
found = 0
for _ in range(0, num_entry):
entryNum, byte_next = self._read_aedr_fast(byte_loc)
Expand All @@ -1037,15 +1037,15 @@ def _read_varatts(self, var_num, zVar):
if isinstance(entryData, np.ndarray):
if len(entryData) == 1:
entryData = entryData[0]
return_dict[adr_info["name"]] = entryData
return_dict[adr_info.name] = entryData
found = 1
break
byte_loc = adr_info["next_adr_location"]
byte_loc = adr_info.next_adr_loc
if found == 0:
return_dict[adr_info["name"]] = None
return_dict[adr_info.name] = None
return return_dict

def _read_adr(self, position):
def _read_adr(self, position: int) -> ADRInfo:
"""
Read an attribute descriptor record (ADR).
"""
Expand All @@ -1054,7 +1054,7 @@ def _read_adr(self, position):
else:
return self._read_adr2(position)

def _read_adr3(self, byte_loc):
def _read_adr3(self, byte_loc: int) -> ADRInfo:
self._f.seek(byte_loc, 0)
block_size = int.from_bytes(self._f.read(8), "big") # Block Size
adr = self._f.read(block_size - 8)
Expand All @@ -1072,22 +1072,20 @@ def _read_adr3(self, byte_loc):
name = str(adr[60:315].decode(self.string_encoding))
name = name.replace("\x00", "")

# Build the return dictionary
return_dict: Dict[str, Union[str, int]] = {}
return_dict["scope"] = scope
return_dict["next_adr_location"] = next_adr_loc
return_dict["attribute_number"] = num
return_dict["num_gr_entry"] = num_gr_entry
return_dict["max_gr_entry"] = MaxEntry
return_dict["num_z_entry"] = num_z_entry
return_dict["max_z_entry"] = MaxZEntry
return_dict["first_z_entry"] = position_next_z_entry
return_dict["first_gr_entry"] = position_next_gr_entry
return_dict["name"] = name

return return_dict
return ADRInfo(
scope,
next_adr_loc,
num,
num_gr_entry,
MaxEntry,
num_z_entry,
MaxZEntry,
position_next_z_entry,
position_next_gr_entry,
name,
)

def _read_adr2(self, byte_loc):
def _read_adr2(self, byte_loc: int) -> ADRInfo:
self._f.seek(byte_loc, 0)
block_size = int.from_bytes(self._f.read(4), "big") # Block Size
adr = self._f.read(block_size - 4)
Expand All @@ -1105,20 +1103,18 @@ def _read_adr2(self, byte_loc):
name = str(adr[48:112].decode(self.string_encoding))
name = name.replace("\x00", "")

# Build the return dictionary
return_dict: Dict[str, Union[int, str]] = {}
return_dict["scope"] = scope
return_dict["next_adr_location"] = next_adr_loc
return_dict["attribute_number"] = num
return_dict["num_gr_entry"] = num_gr_entry
return_dict["max_gr_entry"] = MaxEntry
return_dict["num_z_entry"] = num_z_entry
return_dict["max_z_entry"] = MaxZEntry
return_dict["first_z_entry"] = position_next_z_entry
return_dict["first_gr_entry"] = position_next_gr_entry
return_dict["name"] = name

return return_dict
return ADRInfo(
scope,
next_adr_loc,
num,
num_gr_entry,
MaxEntry,
num_z_entry,
MaxZEntry,
position_next_z_entry,
position_next_gr_entry,
name,
)

def _read_adr_fast(self, position):
"""
Expand Down
14 changes: 14 additions & 0 deletions cdflib/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,20 @@
import numpy as np


@dataclass
class ADRInfo:
    """
    Fields parsed from an attribute descriptor record (ADR).

    The ADR readers build this class with positional arguments, so the
    order of the fields below is part of the interface and must not change.
    """

    # Attribute scope. Readers treat ``scope == 1`` as a global attribute and
    # any other value as a variable attribute.
    scope: int
    # Byte location of the next ADR; passed back to the ADR reader to walk
    # the chain of attributes.
    next_adr_loc: int
    # Attribute number as stored in the record (presumably the attribute's
    # index within the file -- confirm against the CDF format description).
    attribute_number: int
    # Number of "g/r" entries. These are the entries read for global
    # attributes and for variables that are not zVariables.
    num_gr_entry: int
    # Largest g/r entry number; a requested entry number above this is
    # rejected as non-existent.
    max_gr_entry: int
    # Number of entries belonging to zVariables.
    num_z_entry: int
    # Largest z entry number (presumably the z counterpart of
    # ``max_gr_entry`` -- confirm).
    max_z_entry: int
    # Byte location of the first z entry record.
    first_z_entry: int
    # Byte location of the first g/r entry record.
    first_gr_entry: int
    # Attribute name, decoded from the record with NUL padding removed.
    name: str


@dataclass
class CDRInfo:
encoding: int
Expand Down
1 change: 1 addition & 0 deletions doc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Breaking changes
- The ``expand`` keyword argument to ``CDF.globalattsget`` and ``CDF.varattsget`` has been removed.
Use ``CDF.attinq`` to get attribute information instead.
- ``CDF.vdr_info`` now returns a dataclass instead of a dict.
- ``CDF.attinq`` now returns a dataclass instead of a dict.

Bugfixes
--------
Expand Down
4 changes: 2 additions & 2 deletions tests/test_cdfwrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,10 @@ def test_globalattrs(tmp_path):

# Test CDF info
attrib = reader.attinq("Global2")
assert attrib["num_gr_entry"] == 1
assert attrib.num_gr_entry == 1

attrib = reader.attinq("Global6")
assert attrib["num_gr_entry"] == 4
assert attrib.num_gr_entry == 4

entry = reader.attget("Global6", 3)
assert entry["Data_Type"] == "CDF_INT8"
Expand Down

0 comments on commit cd3b7e8

Please sign in to comment.