Skip to content
This repository has been archived by the owner on Nov 28, 2023. It is now read-only.

Commit

Permalink
Merge 5135b33 into ab7d452
Browse files Browse the repository at this point in the history
  • Loading branch information
manonreau authored Nov 16, 2021
2 parents ab7d452 + 5135b33 commit 695c2b7
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 130 deletions.
117 changes: 62 additions & 55 deletions deeprank/learn/DataSet.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@ class DataSet():

def __init__(self, train_database, valid_database=None, test_database=None,
chain1='A', chain2='B',
mapfly=True, grid_info=None,
mapfly=True, grid_info={
'number_of_points': (10, 10, 10),
'resolution': (3, 3, 3)},
use_rotation=None,
select_feature='all', select_target='DOCKQ',
normalize_features=True, normalize_targets=True,
target_ordering=None,
dict_filter=None, pair_chain_feature=None,
transform_to_2D=False, projection=0,
grid_shape=None,
clip_features=True, clip_factor=1.5,
rotation_seed=None,
tqdm=False,
Expand Down Expand Up @@ -105,9 +106,6 @@ def __init__(self, train_database, valid_database=None, test_database=None,
projection (int): Projection axis from 3D to 2D:
Mapping: 0 -> yz, 1 -> xz, 2 -> xy
Default = 0
grid_shape (None or tuple(int), optional):
Shape of the grid in the hdf5 file. Is not necessary
if the grid points are still present in the HDF5 file.
clip_features (bool, optional):
Remove too large values of the grid.
Can be needed for native complexes where the coulomb
Expand Down Expand Up @@ -185,7 +183,6 @@ def __init__(self, train_database, valid_database=None, test_database=None,
# shape of the data
self.input_shape = None
self.data_shape = None
self.grid_shape = grid_shape

# the possible pairing of the ind features
self.pair_chain_feature = pair_chain_feature
Expand Down Expand Up @@ -291,12 +288,12 @@ def process_dataset(self):
self.get_input_shape()

# get renormalization factor
if self.normalize_features or self.normalize_targets:
if self.normalize_features or self.normalize_targets or self.clip_features:
if self.mapfly:
self.compute_norm()
else:
self.get_norm()

logger.info('\n')
logger.info(" Data Set Info:")
logger.info(
Expand Down Expand Up @@ -342,7 +339,7 @@ def __getitem__(self, index):

if self.normalize_features:
feature = self._normalize_feature(feature)

if self.normalize_targets:
target = self._normalize_target(target)

Expand Down Expand Up @@ -820,20 +817,17 @@ def get_grid_shape(self):
mol_data = fh5.get(mol)

# get the grid size
if self.grid_shape is None:

if 'grid_points' in mol_data:
nx = mol_data['grid_points']['x'].shape[0]
ny = mol_data['grid_points']['y'].shape[0]
nz = mol_data['grid_points']['z'].shape[0]
self.grid_shape = (nx, ny, nz)

else:
raise ValueError(
f'Impossible to determine sparse grid shape.\n '
f'Specify argument grid_shape=(x,y,z)')
if 'grid_points' in mol_data:
nx = mol_data['grid_points']['x'].shape[0]
ny = mol_data['grid_points']['y'].shape[0]
nz = mol_data['grid_points']['z'].shape[0]
self.grid_shape = (nx, ny, nz)
else:
raise ValueError(
f'Impossible to determine sparse grid shape.\n '
f'Specify argument grid_shape=(x,y,z)')

fh5.close()
fh5.close()

elif self.grid_info is not None:
self.grid_shape = self.grid_info['number_of_points']
Expand Down Expand Up @@ -897,7 +891,7 @@ def compute_norm(self):
self.target_min = self.param_norm['targets'].min[0]
self.target_max = self.param_norm['targets'].max[0]

logger.info(self.target_min, self.target_max)
logger.info(f'{self.target_min}, {self.target_max}')

def get_norm(self):
"""Get the normalization values for the features."""
Expand Down Expand Up @@ -1045,20 +1039,23 @@ def _normalize_feature(self, feature):
) / self.feature_std[ic]
return feature


def _clip_feature(self, feature):
"""Clip the value of the features at +/- mean + clip_factor * std.
Args:
feature (np.array): raw feature values
Returns:
np.array: clipped feature values
"""

w = self.clip_factor
for ic in range(self.data_shape[0]):
minv = self.feature_mean[ic] - w * self.feature_std[ic]
maxv = self.feature_mean[ic] + w * self.feature_std[ic]
feature[ic] = np.clip(feature[ic], minv, maxv)
#feature[ic] = self._mad_based_outliers(feature[ic],minv,maxv)
if len(feature[ic]) > 0:
minv = self.feature_mean[ic] - w * self.feature_std[ic]
maxv = self.feature_mean[ic] + w * self.feature_std[ic]
if minv != maxv:
feature[ic] = np.clip(feature[ic], minv, maxv)
#feature[ic] = self._mad_based_outliers(feature[ic],minv,maxv)
return feature

@staticmethod
Expand Down Expand Up @@ -1168,7 +1165,10 @@ def load_one_molecule(self, fname, mol=None):
feature.append(mat)

# get the target value
target = mol_data.get('targets/' + self.select_target)[()]
try:
target = mol_data.get('targets/' + self.select_target)[()]
except Exception:
logger.exception(f'No target value for: {fname} - not required for the test set')

# close
fh5.close()
Expand Down Expand Up @@ -1214,7 +1214,10 @@ def map_one_molecule(self, fname, mol=None, angle=None, axis=None):
feature += data

# get the target value
target = mol_data.get('targets/' + self.select_target)[()]
try:
target = mol_data.get('targets/' + self.select_target)[()]
except Exception:
logger.exception(f'No target value for: {fname} - not required for the test set')

# close
fh5.close()
Expand Down Expand Up @@ -1431,32 +1434,36 @@ def map_feature(self, feat_names, mol_data, grid, npts, angle, axis):
tmp_feat_ser = [np.zeros(npts), np.zeros(npts)]
tmp_feat_vect = [np.zeros(npts), np.zeros(npts)]
data = np.array(mol_data['features/' + name][()])

if data.shape[0]==0:
logger.warning(f'No {name} retrieved at the protein/protein interface')

chain = data[:, 0]
pos = data[:, 1:4]
feat_value = data[:, 4]

if angle is not None:
pos = pdb2sql.transform.rot_xyz_around_axis(
pos, axis, angle, center)

if __vectorize__ or __vectorize__ == 'both':

for chainID in [0, 1]:
tmp_feat_vect[chainID] = np.sum(
vmap(pos[chain == chainID, :],
feat_value[chain == chainID]),
0)

if not __vectorize__ or __vectorize__ == 'both':

for chainID, xyz, val in zip(chain, pos, feat_value):
tmp_feat_ser[int(chainID)] += \
self._featgrid(xyz, val, grid, npts)

if __vectorize__ == 'both':
assert np.allclose(tmp_feat_ser, tmp_feat_vect)

else:
chain = data[:, 0]
pos = data[:, 1:4]
feat_value = data[:, 4]

if angle is not None:
pos = pdb2sql.transform.rot_xyz_around_axis(
pos, axis, angle, center)

if __vectorize__ or __vectorize__ == 'both':

for chainID in [0, 1]:
tmp_feat_vect[chainID] = np.sum(
vmap(pos[chain == chainID, :],
feat_value[chain == chainID]),
0)

if not __vectorize__ or __vectorize__ == 'both':

for chainID, xyz, val in zip(chain, pos, feat_value):
tmp_feat_ser[int(chainID)] += \
self._featgrid(xyz, val, grid, npts)

if __vectorize__ == 'both':
assert np.allclose(tmp_feat_ser, tmp_feat_vect)

if __vectorize__:
feat += tmp_feat_vect
else:
Expand Down
Loading

0 comments on commit 695c2b7

Please sign in to comment.