Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add the option to have a grid box of different x,y and z dimensions #292

Merged
merged 5 commits into from Dec 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
52 changes: 29 additions & 23 deletions deeprankcore/utils/grid.py
Expand Up @@ -4,7 +4,7 @@


from enum import Enum
from typing import Dict, Union
from typing import Dict, Union, List
import numpy as np
import h5py
import itertools
Expand All @@ -27,26 +27,30 @@ class GridSettings:
"""Objects of this class hold the settings to build a grid.
The grid is basically a multi-divided 3D cube with
the following properties:
- points_count: the number of points on one edge of the cube
- size: the length in Å of one edge of the cube
- resolution: the size in Å of one edge subdivision. Also the distance between two points on the edge.

- sizes: x, y, z sizes of the box in Å
- points_counts: the number of points on the x, y, z edges of the cube
- resolutions: the size in Å of one x, y, z edge subdivision. Also the distance between two points on the edge.
"""

def __init__(self, points_count: int, size: float):
self._points_count = points_count
self._size = size
def __init__(self, points_counts: List[int], sizes: List[float]):
assert len(points_counts) == 3
assert len(sizes) == 3

self._points_counts = points_counts
self._sizes = sizes

@property
def resolution(self) -> float:
return self._size / self._points_count
def resolutions(self) -> List[float]:
return [self._sizes[i] / self._points_counts[i] for i in range(3)]

@property
def size(self) -> float:
return self._size
def sizes(self) -> List[float]:
return self._sizes.tolist()

@property
def points_count(self) -> int:
return self._points_count
def points_counts(self) -> List[int]:
return self._points_counts


class Grid:
Expand All @@ -69,19 +73,21 @@ def __init__(self, id_: str, settings: GridSettings, center: np.array):
def _set_mesh(self, settings: GridSettings, center: np.array):
"builds the grid points"

half_size = settings.size / 2
half_size_x = settings.sizes[0] / 2
half_size_y = settings.sizes[1] / 2
half_size_z = settings.sizes[2] / 2

min_x = center[0] - half_size
max_x = center[0] + half_size
self._xs = np.linspace(min_x, max_x, num=settings.points_count)
min_x = center[0] - half_size_x
max_x = center[0] + half_size_x
self._xs = np.linspace(min_x, max_x, num=settings.points_counts[0])

min_y = center[1] - half_size
max_y = center[1] + half_size
self._ys = np.linspace(min_y, max_y, num=settings.points_count)
min_y = center[1] - half_size_y
max_y = center[1] + half_size_y
self._ys = np.linspace(min_y, max_y, num=settings.points_counts[1])

min_z = center[2] - half_size
max_z = center[2] + half_size
self._zs = np.linspace(min_z, max_z, num=settings.points_count)
min_z = center[2] - half_size_z
max_z = center[2] + half_size_z
self._zs = np.linspace(min_z, max_z, num=settings.points_counts[2])

self._ygrid, self._xgrid, self._zgrid = np.meshgrid(
self._ys, self._xs, self._zs
Expand Down
5 changes: 4 additions & 1 deletion tests/utils/test_graph.py
Expand Up @@ -61,7 +61,9 @@ def test_graph_build_and_export(): # pylint: disable=too-many-locals
graph.write_to_hdf5(hdf5_path)

# export grid to hdf5
grid_settings = GridSettings(20, 20.0)
grid_settings = GridSettings(np.array((20, 21, 21)), np.array((20.0, 21.0, 21.0)))
assert np.all(grid_settings.resolutions == np.array((1.0, 1.0, 1.0)))

graph.write_as_grid_to_hdf5(hdf5_path, grid_settings, MapMethod.FAST_GAUSSIAN)

# check the contents of the hdf5 file
Expand Down Expand Up @@ -95,5 +97,6 @@ def test_graph_build_and_export(): # pylint: disable=too-many-locals
assert "value" in mapped_group[feature_name]
data = mapped_group[feature_name]["value"][()]
assert len(np.nonzero(data)) > 0, f"{feature_name}: all zero"
assert np.all(data.shape == grid_settings.points_counts)
finally:
shutil.rmtree(tmp_dir_path) # clean up after the test