Skip to content

Commit

Permalink
Merge pull request #292 from DeepRank/fix-grid-dimension-setting-multi
Browse files Browse the repository at this point in the history
Add the option to have a grid box of different x,y and z dimensions

closes #290
  • Loading branch information
cbaakman committed Dec 15, 2022
2 parents 2401ea5 + 5570782 commit 11262a1
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 24 deletions.
52 changes: 29 additions & 23 deletions deeprankcore/utils/grid.py
Expand Up @@ -4,7 +4,7 @@


from enum import Enum
from typing import Dict, Union
from typing import Dict, Union, List
import numpy as np
import h5py
import itertools
Expand All @@ -27,26 +27,30 @@ class GridSettings:
"""Objects of this class hold the settings to build a grid.
The grid is basically a multi-divided 3D cube with
the following properties:
- points_count: the number of points on one edge of the cube
- size: the length in Å of one edge of the cube
- resolution: the size in Å of one edge subdivision. Also the distance between two points on the edge.
- sizes: x, y, z sizes of the box in Å
- points_counts: the number of points on the x, y, z edges of the cube
- resolutions: the size in Å of one x, y, z edge subdivision. Also the distance between two points on the edge.
"""

def __init__(self, points_count: int, size: float):
self._points_count = points_count
self._size = size
def __init__(self, points_counts: List[int], sizes: List[float]):
assert len(points_counts) == 3
assert len(sizes) == 3

self._points_counts = points_counts
self._sizes = sizes

@property
def resolution(self) -> float:
return self._size / self._points_count
def resolutions(self) -> List[float]:
return [self._sizes[i] / self._points_counts[i] for i in range(3)]

@property
def size(self) -> float:
return self._size
def sizes(self) -> List[float]:
return self._sizes.tolist()

@property
def points_count(self) -> int:
return self._points_count
def points_counts(self) -> List[int]:
return self._points_counts


class Grid:
Expand All @@ -69,19 +73,21 @@ def __init__(self, id_: str, settings: GridSettings, center: np.array):
def _set_mesh(self, settings: GridSettings, center: np.array):
"builds the grid points"

half_size = settings.size / 2
half_size_x = settings.sizes[0] / 2
half_size_y = settings.sizes[1] / 2
half_size_z = settings.sizes[2] / 2

min_x = center[0] - half_size
max_x = center[0] + half_size
self._xs = np.linspace(min_x, max_x, num=settings.points_count)
min_x = center[0] - half_size_x
max_x = center[0] + half_size_x
self._xs = np.linspace(min_x, max_x, num=settings.points_counts[0])

min_y = center[1] - half_size
max_y = center[1] + half_size
self._ys = np.linspace(min_y, max_y, num=settings.points_count)
min_y = center[1] - half_size_y
max_y = center[1] + half_size_y
self._ys = np.linspace(min_y, max_y, num=settings.points_counts[1])

min_z = center[2] - half_size
max_z = center[2] + half_size
self._zs = np.linspace(min_z, max_z, num=settings.points_count)
min_z = center[2] - half_size_z
max_z = center[2] + half_size_z
self._zs = np.linspace(min_z, max_z, num=settings.points_counts[2])

self._ygrid, self._xgrid, self._zgrid = np.meshgrid(
self._ys, self._xs, self._zs
Expand Down
5 changes: 4 additions & 1 deletion tests/utils/test_graph.py
Expand Up @@ -61,7 +61,9 @@ def test_graph_build_and_export(): # pylint: disable=too-many-locals
graph.write_to_hdf5(hdf5_path)

# export grid to hdf5
grid_settings = GridSettings(20, 20.0)
grid_settings = GridSettings(np.array((20, 21, 21)), np.array((20.0, 21.0, 21.0)))
assert np.all(grid_settings.resolutions == np.array((1.0, 1.0, 1.0)))

graph.write_as_grid_to_hdf5(hdf5_path, grid_settings, MapMethod.FAST_GAUSSIAN)

# check the contents of the hdf5 file
Expand Down Expand Up @@ -95,5 +97,6 @@ def test_graph_build_and_export(): # pylint: disable=too-many-locals
assert "value" in mapped_group[feature_name]
data = mapped_group[feature_name]["value"][()]
assert len(np.nonzero(data)) > 0, f"{feature_name}: all zero"
assert np.all(data.shape == grid_settings.points_counts)
finally:
shutil.rmtree(tmp_dir_path) # clean up after the test

0 comments on commit 11262a1

Please sign in to comment.