Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
tom-andersson committed Jun 9, 2023
1 parent 59dc7c0 commit 2762ad8
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 9 deletions.
2 changes: 1 addition & 1 deletion deepsensor/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
"""

# Magnitude of diagonal to regularise matrices with in `backends` library used by `neuralprocesses`
DEFAULT_LAB_EPSILON = 1e-6
DEFAULT_LAB_EPSILON = 1e-6
10 changes: 6 additions & 4 deletions deepsensor/data/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,20 +75,22 @@ def _validate_maps(self, x1_map, x2_map):
return x1_map, x2_map

def _validate_xr(self, data: Union[xr.DataArray, xr.Dataset]):

def _validate_da(da: xr.DataArray):
coord_names = [self.norm_params["coords"][coord]["name"] for coord in ["time", "x1", "x2"]]
coord_names = [
self.norm_params["coords"][coord]["name"]
for coord in ["time", "x1", "x2"]
]
if coord_names[0] not in da.dims:
# We don't have a time dimension.
coord_names = coord_names[1:]
if list(da.dims) != coord_names:
raise ValueError(
f"Dimensions of {da.name} need to be {coord_names} but are {list(da.dims)}."
)

if isinstance(data, xr.DataArray):
_validate_da(data)

elif isinstance(data, xr.Dataset):
for var_ID, da in data.data_vars.items():
_validate_da(da)
Expand Down
3 changes: 2 additions & 1 deletion deepsensor/tensorflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
def convert_to_tensor(arr):
return tf.convert_to_tensor(arr)


from deepsensor import config as deepsensor_config
from deepsensor import backend

Expand All @@ -27,4 +28,4 @@ def convert_to_tensor(arr):
backend.convert_to_tensor = convert_to_tensor
backend.str = "tf"

B.epsilon = deepsensor_config.DEFAULT_LAB_EPSILON
B.epsilon = deepsensor_config.DEFAULT_LAB_EPSILON
4 changes: 1 addition & 3 deletions deepsensor/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@ def set_gpu_default_device():
# print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

else:
raise NotImplementedError(
f"Backend {deepsensor.backend.str} not implemented"
)
raise NotImplementedError(f"Backend {deepsensor.backend.str} not implemented")


def train_epoch(
Expand Down
1 change: 1 addition & 0 deletions tests/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def test_same_names_xr(self):
self.assert_allclose_xr(da_unnorm, da_raw),
f"Original {type(da_raw).__name__} not restored.",
)

def test_wrong_order_xr_ds(self):
"""Order of dimensions in xarray must be: time, x1, x2"""
ds_raw = _gen_data_xr(dims=("time", "lat", "lon"), data_vars=["var1", "var2"])
Expand Down

0 comments on commit 2762ad8

Please sign in to comment.