Merge pull request #203 from dstansby/fix-squeezing
Squeezing fixes
dstansby committed May 26, 2023
2 parents 41cd64d + 9324436 commit ae214f6
Showing 6 changed files with 68 additions and 68 deletions.
10 changes: 1 addition & 9 deletions cdflib/cdf_to_xarray.py
@@ -33,15 +33,7 @@ def _convert_cdf_time_types(data, atts: Dict[str, AttData], properties: VDRInfo,
    # If nothing, ALL CDF_EPOCH16 types are converted to CDF_EPOCH, because xarray can't handle int64s
    """

-    data = np.squeeze(data)
-
-    if not hasattr(data, "__len__"):
-        data = [data]
-
-    try:
-        len(data)
-    except Exception:
-        data = [data]
+    data = np.atleast_1d(data)

    if to_datetime and to_unixtime:
        print("Cannot convert to both unixtime and datetime. Continuing with conversion to unixtime.")
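A minimal NumPy-only sketch of why np.atleast_1d replaces the old squeeze-and-rewrap block: it guarantees an indexable array for scalar input without silently dropping singleton axes from larger arrays. The scalar/column values below are illustrative and not part of the diff.

import numpy as np

scalar = np.float64(3.0)
column = np.zeros((5, 1))

# Old approach: squeeze first, then re-wrap anything that lost its length.
print(np.squeeze(column).shape)      # (5,)   -- the singleton axis is silently dropped
# New approach: only guarantee at least one dimension.
print(np.atleast_1d(scalar).shape)   # (1,)   -- a bare scalar becomes a length-1 array
print(np.atleast_1d(column).shape)   # (5, 1) -- existing shapes are left untouched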
70 changes: 35 additions & 35 deletions cdflib/epochs.py
@@ -8,6 +8,8 @@
import numpy as np
import numpy.typing as npt

+from cdflib.utils import _squeeze_or_scalar_complex, _squeeze_or_scalar_real
+
LEAPSEC_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "CDFLeapSeconds.txt")

epochs_type = Union[str, List[float], List[int], List[complex], Tuple[float, ...], Tuple[int, ...], Tuple[complex, ...], np.ndarray]
@@ -186,7 +188,7 @@ def to_datetime(cls, cdf_time: epoch_types) -> npt.NDArray[np.datetime64]:
        return cls._compose_date(*times.T[:9]).astype("datetime64[us]")

    @staticmethod
-    def unixtime(cdf_time: npt.ArrayLike) -> npt.NDArray:
+    def unixtime(cdf_time: npt.ArrayLike) -> Union[float, npt.NDArray]:
        """
        Converts CDF epoch argument into seconds after 1970-01-01. This method
        converts a scalar, or array-like. Precision is only kept to the
@@ -211,39 +213,39 @@ def unixtime(cdf_time: npt.ArrayLike) -> npt.NDArray:
            unixtime.append(
                datetime.datetime(date[0], date[1], date[2], date[3], date[4], date[5], date[6], tzinfo=utc).timestamp()
            )
-        return np.squeeze(unixtime)
+        return _squeeze_or_scalar_real(unixtime)

    @staticmethod
-    def compute(datetimes: npt.ArrayLike) -> npt.NDArray:
+    def compute(datetimes: npt.ArrayLike) -> Union[float, complex, npt.NDArray]:
        """
        Computes the provided date/time components into CDF epoch value(s).

        For CDF_EPOCH:
-            For computing into CDF_EPOCH value, each date/time elements should
-            have exactly seven (7) components, as year, month, day, hour, minute,
-            second and millisecond, in a list. For example:
-            [[2017,1,1,1,1,1,111],[2017,2,2,2,2,2,222]]
-            Or, call function compute_epoch directly, instead, with at least three
-            (3) first (up to seven) components. The last component, if
-            not the 7th, can be a float that can have a fraction of the unit.
+            For computing into CDF_EPOCH value, each date/time elements should
+            have exactly seven (7) components, as year, month, day, hour, minute,
+            second and millisecond, in a list. For example:
+            [[2017,1,1,1,1,1,111],[2017,2,2,2,2,2,222]]
+            Or, call function compute_epoch directly, instead, with at least three
+            (3) first (up to seven) components. The last component, if
+            not the 7th, can be a float that can have a fraction of the unit.
        For CDF_EPOCH16:
-            They should have exactly ten (10) components, as year,
-            month, day, hour, minute, second, millisecond, microsecond, nanosecond
-            and picosecond, in a list. For example:
-            [[2017,1,1,1,1,1,123,456,789,999],[2017,2,2,2,2,2,987,654,321,999]]
-            Or, call function compute_epoch directly, instead, with at least three
-            (3) first (up to ten) components. The last component, if
-            not the 10th, can be a float that can have a fraction of the unit.
+            They should have exactly ten (10) components, as year,
+            month, day, hour, minute, second, millisecond, microsecond, nanosecond
+            and picosecond, in a list. For example:
+            [[2017,1,1,1,1,1,123,456,789,999],[2017,2,2,2,2,2,987,654,321,999]]
+            Or, call function compute_epoch directly, instead, with at least three
+            (3) first (up to ten) components. The last component, if
+            not the 10th, can be a float that can have a fraction of the unit.
        For TT2000:
-            Each TT2000 typed date/time should have exactly nine (9) components, as
-            year, month, day, hour, minute, second, millisecond, microsecond,
-            and nanosecond, in a list. For example:
-            [[2017,1,1,1,1,1,123,456,789],[2017,2,2,2,2,2,987,654,321]]
-            Or, call function compute_tt2000 directly, instead, with at least three
-            (3) first (up to nine) components. The last component, if
-            not the 9th, can be a float that can have a fraction of the unit.
+            Each TT2000 typed date/time should have exactly nine (9) components, as
+            year, month, day, hour, minute, second, millisecond, microsecond,
+            and nanosecond, in a list. For example:
+            [[2017,1,1,1,1,1,123,456,789],[2017,2,2,2,2,2,987,654,321]]
+            Or, call function compute_tt2000 directly, instead, with at least three
+            (3) first (up to nine) components. The last component, if
+            not the 9th, can be a float that can have a fraction of the unit.
        """

        if not isinstance(datetimes, (list, tuple, np.ndarray)):
@@ -253,16 +255,14 @@ def compute(datetimes: npt.ArrayLike) -> npt.NDArray:
        items = datetimes.shape[1]

        if items == 7:
-            ret = CDFepoch.compute_epoch(datetimes)
+            return _squeeze_or_scalar_real(CDFepoch.compute_epoch(datetimes))
        elif items == 10:
-            ret = CDFepoch.compute_epoch16(datetimes)
+            return _squeeze_or_scalar_complex(CDFepoch.compute_epoch16(datetimes))
        elif items == 9:
-            ret = CDFepoch.compute_tt2000(datetimes)
+            return _squeeze_or_scalar_real(CDFepoch.compute_tt2000(datetimes))
        else:
            raise TypeError("Unknown input")

-        return np.squeeze(ret)
-
    @staticmethod
    def findepochrange(
        epochs: epochs_type, starttime: Optional[epoch_types] = None, endtime: Optional[epoch_types] = None
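Taken together with the input format described in the docstring, the dispatch now hands single inputs back as plain Python scalars and multi-row inputs as arrays. A hedged usage sketch — the component values are the illustrative ones from the docstring, and the return types assume the _squeeze_or_scalar_* behaviour introduced in this commit:

from cdflib.epochs import CDFepoch as cdfepoch

# Seven components -> CDF_EPOCH; a single row should now come back as a plain float.
single = cdfepoch.compute([[2017, 1, 1, 1, 1, 1, 111]])
print(type(single))            # expected: <class 'float'>

# Several rows still come back as a 1-D numpy array.
pair = cdfepoch.compute([[2017, 1, 1, 1, 1, 1, 111], [2017, 2, 2, 2, 2, 2, 222]])
print(type(pair), pair.shape)  # expected: <class 'numpy.ndarray'> (2,)

# Nine components dispatch to TT2000 (int), ten to CDF_EPOCH16 (complex).
print(type(cdfepoch.compute([[2017, 1, 1, 1, 1, 1, 123, 456, 789]])))       # expected: int
print(type(cdfepoch.compute([[2017, 1, 1, 1, 1, 1, 123, 456, 789, 999]])))  # expected: complex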
@@ -490,7 +490,7 @@ def breakdown_tt2000(tt2000: cdf_tt2000_type) -> np.ndarray:
        return np.squeeze(toutcs.T)

    @staticmethod
-    def compute_tt2000(datetimes: npt.ArrayLike) -> npt.NDArray[np.int64]:
+    def compute_tt2000(datetimes: npt.ArrayLike) -> Union[int, npt.NDArray[np.int64]]:
        if not isinstance(datetimes, (list, tuple, np.ndarray)):
            raise TypeError("datetime must be in list form")

@@ -829,7 +829,7 @@ def _JulianDay(y: int, m: int, d: int) -> int:
        return 367 * y - a1 - a2 + a3 + d + 1721029

    @staticmethod
-    def compute_epoch16(datetimes: npt.ArrayLike) -> npt.NDArray[np.complex128]:
+    def compute_epoch16(datetimes: npt.ArrayLike) -> Union[complex, npt.NDArray[np.complex128]]:
        new_dates = np.atleast_2d(datetimes)
        count = len(new_dates)
        epochs = []
@@ -992,7 +992,7 @@ def compute_epoch16(datetimes: npt.ArrayLike) -> npt.NDArray[np.complex128]:
            cepoch = complex(epoch[0], epoch[1])
            epochs.append(cepoch)

-        return np.squeeze(epochs)
+        return _squeeze_or_scalar_complex(epochs)

    @staticmethod
    def _calc_from_julian(epoch0: npt.ArrayLike, epoch1: npt.ArrayLike) -> npt.NDArray:
@@ -1276,8 +1276,8 @@ def _encodex_epoch(epoch: cdf_epoch_type, iso_8601: bool = True) -> str:
        return encoded

    @staticmethod
-    def compute_epoch(dates: npt.ArrayLike) -> np.ndarray:
-        # TODOL Add docstring. What is the output format?
+    def compute_epoch(dates: npt.ArrayLike) -> Union[float, npt.NDArray]:
+        # TODO Add docstring. What is the output format?

        new_dates = np.atleast_2d(dates)
        count = new_dates.shape[0]
@@ -1366,7 +1366,7 @@ def compute_epoch(dates: npt.ArrayLike) -> np.ndarray:
                return np.array(86400000.0 * daysSince0AD + msecInDay)
            epochs.append(86400000.0 * daysSince0AD + msecInDay)

-        return np.squeeze(epochs)
+        return _squeeze_or_scalar_real(epochs)

    @staticmethod
    def _computeEpoch(y: int, m: int, d: int, h: int, mn: int, s: int, ms: int) -> float:
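The practical effect on unixtime mirrors the updated tests further down: array-like input still yields an ndarray, while a single epoch now yields a bare float instead of a one-element array. A small sketch using the same values as the tests in this diff:

import numpy as np

from cdflib.epochs import CDFepoch as cdfepoch

many = cdfepoch.unixtime([500000000100, 123456789101112131])
assert isinstance(many, np.ndarray)   # array in, array out

one = cdfepoch.unixtime(cdfepoch.compute_tt2000([[2000, 1, 1]]))
assert one == 946684800.0             # single epoch in, plain float out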
20 changes: 20 additions & 0 deletions cdflib/utils.py
@@ -0,0 +1,20 @@
+from typing import Union
+
+import numpy as np
+import numpy.typing as npt
+
+
+def _squeeze_or_scalar_real(arr: npt.ArrayLike) -> Union[npt.NDArray, float]:
+    arr = np.squeeze(arr)
+    if arr.ndim == 0:
+        return arr.item()
+    else:
+        return arr
+
+
+def _squeeze_or_scalar_complex(arr: npt.ArrayLike) -> Union[npt.NDArray, complex]:
+    arr = np.squeeze(arr)
+    if arr.ndim == 0:
+        return arr.item()
+    else:
+        return arr
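A quick usage sketch of the new helpers in cdflib.utils (plain NumPy behaviour, nothing beyond what the function bodies above already show):

import numpy as np

from cdflib.utils import _squeeze_or_scalar_real

print(_squeeze_or_scalar_real([3.5]))           # 3.5 -- a single element collapses to a Python scalar
print(_squeeze_or_scalar_real([[1.0, 2.0]]))    # array([1., 2.]) -- singleton axes are squeezed away
print(_squeeze_or_scalar_real(np.arange(3.0)))  # array([0., 1., 2.]) -- multi-element input passes through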
2 changes: 1 addition & 1 deletion cdflib/xarray_to_cdf.py
@@ -455,7 +455,7 @@ def _unixtime_to_tt2000(unixtime_data):
                int(dt.microsecond % 1000),
                0,
            ]
-            converted_data = float(cdfepoch.compute(dt_to_convert))
+            converted_data = cdfepoch.compute(dt_to_convert)
        else:
            converted_data = np.nan

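Since compute now returns a plain Python scalar for a single nine-component row, the explicit float() cast is no longer needed. A hedged sketch of the call being simplified — the dt value is illustrative, and the component list is built roughly as _unixtime_to_tt2000 does:

import datetime

from cdflib.epochs import CDFepoch as cdfepoch

dt = datetime.datetime(2015, 10, 16, 13, 3, 34, 500000)
dt_to_convert = [dt.year, dt.month, dt.day, dt.hour, dt.minute, dt.second,
                 int(dt.microsecond / 1000), int(dt.microsecond % 1000), 0]

converted_data = cdfepoch.compute(dt_to_convert)   # nine components -> TT2000
print(type(converted_data))                        # expected: a plain int, so no float() wrapper is required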
3 changes: 2 additions & 1 deletion tests/test_epochs.py
@@ -81,6 +81,7 @@ def test_encode_cdftt2000():

def test_unixtime():
    x = cdfepoch.unixtime([500000000100, 123456789101112131])
+    assert isinstance(x, np.ndarray)
    assert x[0] == 946728435.816
    assert x[1] == 1070184724.917112

@@ -93,7 +94,7 @@ def test_unixtime_roundtrip(tzone):
        y, m, d = 2000, 1, 1
        epoch = cdfepoch.compute_tt2000([[y, m, d]])
        unixtime = cdfepoch.unixtime(epoch)
-        assert unixtime == [946684800.0]
+        assert unixtime == 946684800.0
    finally:
        os.environ.clear()
        os.environ.update(_environ)
31 changes: 9 additions & 22 deletions tests/test_xarray_reader_writer.py
@@ -19,38 +19,25 @@


@pytest.mark.remote_data
-def test_mms_fpi():
+def test_mms_fpi(tmp_path):
    fname = "mms1_fpi_brst_l2_des-moms_20151016130334_v3.3.0.cdf"
-    url = (
-        "https://lasp.colorado.edu/maven/sdc/public/data/sdc/web/cdflib_testing/mms1_fpi_brst_l2_des-moms_20151016130334_v3.3.0.cdf"
-    )
+    url = f"https://lasp.colorado.edu/maven/sdc/public/data/sdc/web/cdflib_testing/{fname}"
    if not os.path.exists(fname):
        urllib.request.urlretrieve(url, fname)

-    a = cdflib.cdf_to_xarray("mms1_fpi_brst_l2_des-moms_20151016130334_v3.3.0.cdf", to_unixtime=True, fillval_to_nan=True)
+    a = cdflib.cdf_to_xarray(fname, to_unixtime=True, fillval_to_nan=True)

-    cdflib.xarray_to_cdf(a, "mms1_fpi_brst_l2_des-moms_20151016130334_v3.3.0-created-from-cdf-input.cdf", from_unixtime=True)
-    b = cdflib.cdf_to_xarray(
-        "mms1_fpi_brst_l2_des-moms_20151016130334_v3.3.0-created-from-cdf-input.cdf", to_unixtime=True, fillval_to_nan=True
-    )
-    os.remove("mms1_fpi_brst_l2_des-moms_20151016130334_v3.3.0-created-from-cdf-input.cdf")
-    os.remove("mms1_fpi_brst_l2_des-moms_20151016130334_v3.3.0.cdf")
+    cdflib.xarray_to_cdf(a, tmp_path / fname, from_unixtime=True)
+    b = cdflib.cdf_to_xarray(tmp_path / fname, to_unixtime=True, fillval_to_nan=True)

    fname = "mms1_fpi_brst_l2_des-moms_20151016130334_v3.3.0.nc"
-    url = (
-        "https://lasp.colorado.edu/maven/sdc/public/data/sdc/web/cdflib_testing/mms1_fpi_brst_l2_des-moms_20151016130334_v3.3.0.nc"
-    )
+    url = f"https://lasp.colorado.edu/maven/sdc/public/data/sdc/web/cdflib_testing/{fname}"
    if not os.path.exists(fname):
        urllib.request.urlretrieve(url, fname)

-    c = xr.load_dataset("mms1_fpi_brst_l2_des-moms_20151016130334_v3.3.0.nc")
-
-    cdflib.xarray_to_cdf(c, "mms1_fpi_brst_l2_des-moms_20151016130334_v3.3.0-created-from-netcdf-input.cdf")
-    d = cdflib.cdf_to_xarray(
-        "mms1_fpi_brst_l2_des-moms_20151016130334_v3.3.0-created-from-netcdf-input.cdf", to_unixtime=True, fillval_to_nan=True
-    )
-    os.remove("mms1_fpi_brst_l2_des-moms_20151016130334_v3.3.0-created-from-netcdf-input.cdf")
-    os.remove("mms1_fpi_brst_l2_des-moms_20151016130334_v3.3.0.nc")
+    c = xr.load_dataset(fname)
+    cdflib.xarray_to_cdf(c, tmp_path / fname)
+    d = cdflib.cdf_to_xarray(tmp_path / fname, to_unixtime=True, fillval_to_nan=True)


@pytest.mark.remote_data
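The rewritten test relies on pytest's built-in tmp_path fixture instead of hand-rolled os.remove cleanup. A minimal, standalone illustration of that pattern (test name, file name, and contents are hypothetical):

def test_writes_into_temporary_directory(tmp_path):
    # tmp_path is a pathlib.Path to a per-test temporary directory that pytest
    # creates for you, so the test never has to delete its own output files.
    out_file = tmp_path / "output.cdf"
    out_file.write_text("placeholder contents")
    assert out_file.exists()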
