Skip to content

Commit

Permalink
Merge pull request #205 from dstansby/untyped-cdf
Browse files — browse the repository at this point in the history
More typing!
  • Branch information:
dstansby committed May 26, 2023
2 parents: df6eebf + a9c5ec3 · commit 1b528a1
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 66 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,4 @@ repos:
rev: 'v1.3.0'
hooks:
- id: mypy
additional_dependencies: [xarray]
64 changes: 35 additions & 29 deletions cdflib/cdf_to_xarray.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import re
from typing import Dict, Union
from typing import Any, Dict, List, Tuple, Union

import numpy as np
import numpy.typing as npt
import xarray as xr

from cdflib import CDF
from cdflib.dataclasses import AttData, VDRInfo
from cdflib.dataclasses import VDRInfo
from cdflib.epochs import CDFepoch as cdfepoch

ISTP_TO_XARRAY_ATTRS = {"FIELDNAM": "standard_name", "LABLAXIS": "long_name", "UNITS": "units"}


def _find_xarray_plotting_values(var_att_dict):
def _find_xarray_plotting_values(var_att_dict) -> Dict[str, str]:
"""
This is a simple function that looks through a variable attribute dictionary for ISTP attributes that are similar
to ones used natively by Xarray, specifically their plotting routines. If some are found, this returns a dictionary
Expand All @@ -27,7 +29,9 @@ def _find_xarray_plotting_values(var_att_dict):
return xarray_att_dict


def _convert_cdf_time_types(data, atts: Dict[str, AttData], properties: VDRInfo, to_datetime=False, to_unixtime=False):
def _convert_cdf_time_types(
data: npt.ArrayLike, atts, properties: VDRInfo, to_datetime: bool = False, to_unixtime: bool = False
) -> Tuple[npt.NDArray, Dict[str, Any]]:
"""
# Converts CDF time types into either datetime objects, unixtime, or nothing
# If nothing, ALL CDF_EPOCH16 types are converted to CDF_EPOCH, because xarray can't handle int64s
Expand Down Expand Up @@ -62,9 +66,7 @@ def _convert_cdf_time_types(data, atts: Dict[str, AttData], properties: VDRInfo,
new_atts = {}
for att in atts:
data_type = atts[att].Data_Type
data = atts[att].Data
if not hasattr(data, "__len__"):
data = [data]
data = np.atleast_1d(atts[att]["Data"])
if len(data) == 0 or data_type not in ("CDF_EPOCH", "CDF_EPOCH16", "CDF_TIME_TT2000"):
new_atts[att] = data
else:
Expand All @@ -81,7 +83,9 @@ def _convert_cdf_time_types(data, atts: Dict[str, AttData], properties: VDRInfo,
return new_data, new_atts


def _convert_cdf_to_dicts(filename, to_datetime=False, to_unixtime=False):
def _convert_cdf_to_dicts(
filename, to_datetime: bool = False, to_unixtime: bool = False
) -> Tuple[Dict[str, List[Union[str, np.ndarray]]], Dict[str, Any], Dict[str, npt.NDArray], Dict[str, VDRInfo]]:
# Open the CDF file
# Converts the entire CDF file into python dictionary objects

Expand All @@ -96,9 +100,9 @@ def _convert_cdf_to_dicts(filename, to_datetime=False, to_unixtime=False):
gatt = {}

# Gather all information about the CDF file, and store in the below dictionaries
variable_data = {}
variable_attributes = {}
variable_properties = {}
variable_data: Dict[str, npt.NDArray] = {}
variable_attributes: Dict[str, Any] = {}
variable_properties: Dict[str, VDRInfo] = {}

for var_name in all_cdf_variables:
var_attribute_list = cdf_file.varattsget(var_name)
Expand All @@ -122,7 +126,7 @@ def _convert_cdf_to_dicts(filename, to_datetime=False, to_unixtime=False):

def _verify_depend_dimensions(
dataset, dimension_number, primary_variable_name, coordinate_variable_name, primary_variable_properties: VDRInfo
):
) -> bool:
primary_data = np.array(dataset[primary_variable_name])
coordinate_data = np.array(dataset[coordinate_variable_name])

Expand Down Expand Up @@ -183,7 +187,7 @@ def _verify_depend_dimensions(
return True


def _discover_depend_variables(vardata, varatts, varprops):
def _discover_depend_variables(vardata, varatts, varprops) -> List[str]:
# This loops through the variable attributes to discover which variables are the coordinates of other variables,
# Unfortunately, there is no easy way to tell this by looking at the variable ITSELF,
# you need to look at all variables and see if one points to it.
Expand All @@ -202,7 +206,7 @@ def _discover_depend_variables(vardata, varatts, varprops):
return list(set(list_of_depend_vars))


def _discover_uncertainty_variables(varatts):
def _discover_uncertainty_variables(varatts) -> Dict[str, str]:
# This loops through the variable attributes to discover which variables are the labels of other variables
# Unfortunately, there is no easy way to tell this by looking at the label variable itself
# This returns a KEY:VALUE pair, with the LABEL VARIABLE corresponding to which dimension it covers.
Expand All @@ -217,12 +221,12 @@ def _discover_uncertainty_variables(varatts):
return list_of_label_vars


def _discover_label_variables(varatts, all_variable_properties: Dict[str, VDRInfo], all_variable_data):
def _discover_label_variables(varatts, all_variable_properties: Dict[str, VDRInfo], all_variable_data) -> Dict[str, str]:
# This loops through the variable attributes to discover which variables are the labels of other variables
# Unfortunately, there is no easy way to tell this by looking at the label variable itself
# This returns a KEY:VALUE pair, with the LABEL VARIABLE corresponding to which dimension it covers.

list_of_label_vars = {}
list_of_label_vars: Dict[str, str] = {}

for v in varatts:
label_keys = [x for x in list(varatts[v].keys()) if x.startswith("LABL_PTR_")]
Expand Down Expand Up @@ -256,7 +260,7 @@ def _discover_label_variables(varatts, all_variable_properties: Dict[str, VDRInf
return list_of_label_vars


def _convert_fillvals_to_nan(var_data, var_atts, var_properties: VDRInfo):
def _convert_fillvals_to_nan(var_data, var_atts, var_properties: VDRInfo) -> npt.NDArray:
if var_atts is None:
return var_data
if var_data is None:
Expand Down Expand Up @@ -291,7 +295,7 @@ def _determine_record_dimensions(
all_variable_data,
all_variable_properties,
created_unlimited_dims,
):
) -> Tuple[str, bool, bool]:
"""
Determines the name of the
:param var_name:
Expand Down Expand Up @@ -392,7 +396,7 @@ def _determine_dimension_names(
all_variable_properties,
created_regular_dims,
record_name_found,
):
) -> List[Tuple[str, int, bool, bool]]:
"""
:param var_name:
:param var_atts:
Expand Down Expand Up @@ -522,7 +526,7 @@ def _determine_dimension_names(
return return_list


def _reformat_variable_dims_and_data(var_dims, var_data):
def _reformat_variable_dims_and_data(var_dims, var_data) -> Tuple[str, npt.NDArray]:
if len(var_dims) > 0 and var_data is None:
var_data = np.array([])

Expand All @@ -540,7 +544,9 @@ def _reformat_variable_dims_and_data(var_dims, var_data):
return var_dims, var_data


def _generate_xarray_data_variables(all_variable_data, all_variable_attributes, all_variable_properties, fillval_to_nan):
def _generate_xarray_data_variables(
all_variable_data, all_variable_attributes, all_variable_properties, fillval_to_nan
) -> Tuple[Dict[str, "xr.Variable"], Dict[str, int]]:
# Import here to avoid xarray as a dependency of all of cdflib
import xarray as xr

Expand All @@ -551,10 +557,10 @@ def _generate_xarray_data_variables(all_variable_data, all_variable_attributes,
created_regular_dims: Dict[
str, int
] = {} # These hold the records of the names/lengths of the standard dimensions of the variable
depend_dimensions = (
{}
) # This will be used after the creation of DataArrays, to determine which are "data" and which are "coordinates"
created_vars = {}
depend_dimensions: Dict[
str, int
] = {} # This will be used after the creation of DataArrays, to determine which are "data" and which are "coordinates"
created_vars: Dict[str, xr.Variable] = {}

for var_name in all_variable_data:
var_dims = []
Expand Down Expand Up @@ -616,14 +622,14 @@ def _generate_xarray_data_variables(all_variable_data, all_variable_attributes,

# Finally, create the new variable
try:
created_vars[var_name] = xr.Variable(var_dims, var_data, attrs=var_atts)
created_vars[var_name] = xr.Variable(var_dims, var_data, attrs=var_atts) # type: ignore[no-untyped-call]
except Exception as e:
print(f"ERROR: Creating Variable {var_name} ran into exception: {e}")

return created_vars, depend_dimensions


def _verify_dimension_sizes(created_data_vars, created_coord_vars):
def _verify_dimension_sizes(created_data_vars, created_coord_vars) -> None:
for var in created_data_vars:
for d in created_data_vars[var].dims:
if d in created_data_vars:
Expand Down Expand Up @@ -666,7 +672,7 @@ def _verify_dimension_sizes(created_data_vars, created_coord_vars):
)


def cdf_to_xarray(filename, to_datetime=False, to_unixtime=False, fillval_to_nan=False):
def cdf_to_xarray(filename, to_datetime=False, to_unixtime=False, fillval_to_nan=False) -> xr.Dataset:
"""
This function converts CDF files into XArray Dataset Objects.
Expand Down Expand Up @@ -780,7 +786,7 @@ def cdf_to_xarray(filename, to_datetime=False, to_unixtime=False, fillval_to_nan
else:
created_vars[lab].dims = created_vars[var_name].dims
else:
created_vars[lab].dims = created_vars[var_name].dims[-1]
created_vars[lab].dims = (created_vars[var_name].dims[-1],)
# Add the labels to the coordinates as well
created_coord_vars[lab] = created_vars[lab]
elif var_name in uncertainty_variables:
Expand Down
19 changes: 10 additions & 9 deletions cdflib/epochs_astropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""
import datetime
from datetime import timezone
from typing import List, Union, Optional
from typing import List, Optional, Tuple, Union

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -54,7 +54,7 @@ class CDFAstropy:
increment = 0

@staticmethod
def convert_to_astropy(epochs: Union[Time, npt.ArrayLike], format: Optional[str]=None) -> Time:
def convert_to_astropy(epochs: Union[Time, npt.ArrayLike], format: Optional[str] = None) -> Time:
"""
Convert CDF epochs to astropy time objects.
Expand Down Expand Up @@ -91,7 +91,7 @@ def encode(epochs: npt.ArrayLike, iso_8601: bool = True) -> npt.NDArray[np.str_]
return epochs.strftime("%d-%b-%Y %H:%M:%S.%f")

@staticmethod
def breakdown(epochs):
def breakdown(epochs: Union[Time, npt.ArrayLike]) -> npt.NDArray:
# Returns either a single array, or a array of arrays depending on the input
epochs = CDFAstropy.convert_to_astropy(epochs)
if epochs.format == "cdf_tt2000":
Expand All @@ -108,7 +108,7 @@ def to_datetime(cdf_time: npt.ArrayLike) -> Time:
return cdf_time.datetime

@staticmethod
def unixtime(cdf_time): # @NoSelf
def unixtime(cdf_time: Union[Time, npt.ArrayLike]) -> npt.NDArray:
"""
Encodes the epoch(s) into seconds after 1970-01-01. Precision is only
kept to the nearest microsecond.
Expand Down Expand Up @@ -137,7 +137,9 @@ def compute(datetimes: npt.ArrayLike) -> npt.NDArray:
return np.squeeze(cdf_time)

@staticmethod
def findepochrange(epochs, starttime=None, endtime=None): # @NoSelf
def findepochrange(
epochs: Union[Time, npt.ArrayLike], starttime: Optional[npt.ArrayLike] = None, endtime: Optional[npt.ArrayLike] = None
) -> Tuple[int, int]:
if isinstance(starttime, list):
start = CDFAstropy.compute(starttime)
if isinstance(endtime, list):
Expand Down Expand Up @@ -233,7 +235,7 @@ def breakdown_epoch(epochs: Time) -> npt.NDArray:
return np.squeeze(times)

@staticmethod
def parse(value):
def parse(value: npt.ArrayLike) -> npt.NDArray:
"""
Parses the provided date/time string(s) into CDF epoch value(s).
Expand All @@ -251,8 +253,7 @@ def parse(value):
'yyyy-mm-dd hh:mm:ss.mmmuuunnn' (in iso_8601). The string is
the output from encode function.
"""
if not isinstance(value, (list, np.ndarray)):
value = [value]
value = np.atleast_1d(value)

time_list = []

Expand All @@ -265,4 +266,4 @@ def parse(value):
if len(subs) == 9:
time_list.append(int(Time(t, precision=9).cdf_tt2000))

return np.array(time_list)
return np.squeeze(time_list)

0 comments on commit 1b528a1

Please sign in to comment.