This repository has been archived by the owner on Nov 28, 2023. It is now read-only.

Commit eac8a6a: Merge 754060b into ab7d452
manonreau committed Nov 15, 2021
2 parents: ab7d452 + 754060b
Showing 2 changed files with 116 additions and 90 deletions.
111 changes: 61 additions & 50 deletions deeprank/learn/DataSet.py
@@ -285,18 +285,19 @@ def process_dataset(self):
         self.get_pairing_feature()
 
         # get grid shape
-        self.get_grid_shape()
+        if self.grid_shape is None:
+            self.get_grid_shape()
 
         # get the input shape
         self.get_input_shape()
 
         # get renormalization factor
-        if self.normalize_features or self.normalize_targets:
+        if self.normalize_features or self.normalize_targets or self.clip_features:
             if self.mapfly:
                 self.compute_norm()
             else:
                 self.get_norm()
 
         logger.info('\n')
         logger.info(" Data Set Info:")
         logger.info(
@@ -342,7 +343,7 @@ def __getitem__(self, index):
 
         if self.normalize_features:
             feature = self._normalize_feature(feature)
-
+
         if self.normalize_targets:
             target = self._normalize_target(target)
 
@@ -820,20 +821,17 @@ def get_grid_shape(self):
             mol_data = fh5.get(mol)
 
             # get the grid size
-            if self.grid_shape is None:
-
-                if 'grid_points' in mol_data:
-                    nx = mol_data['grid_points']['x'].shape[0]
-                    ny = mol_data['grid_points']['y'].shape[0]
-                    nz = mol_data['grid_points']['z'].shape[0]
-                    self.grid_shape = (nx, ny, nz)
-
-                else:
-                    raise ValueError(
-                        f'Impossible to determine sparse grid shape.\n '
-                        f'Specify argument grid_shape=(x,y,z)')
+            if 'grid_points' in mol_data:
+                nx = mol_data['grid_points']['x'].shape[0]
+                ny = mol_data['grid_points']['y'].shape[0]
+                nz = mol_data['grid_points']['z'].shape[0]
+                self.grid_shape = (nx, ny, nz)
+            else:
+                raise ValueError(
+                    f'Impossible to determine sparse grid shape.\n '
+                    f'Specify argument grid_shape=(x,y,z)')
 
-                fh5.close()
+            fh5.close()
 
         elif self.grid_info is not None:
             self.grid_shape = self.grid_info['number_of_points']
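Editorial note: with the None-check hoisted into process_dataset (first hunk above), get_grid_shape itself now simply infers the grid dimensions from the 'grid_points' datasets. A minimal standalone sketch of that inference, assuming the same HDF5 layout (the helper name, file name, and molecule name are made up for illustration):

    import h5py
    import numpy as np

    def infer_grid_shape(fname, mol, grid_shape=None):
        """Return grid_shape as-is, or infer it from the 'grid_points' datasets."""
        if grid_shape is not None:
            return grid_shape
        with h5py.File(fname, 'r') as fh5:
            mol_data = fh5[mol]
            if 'grid_points' in mol_data:
                return tuple(mol_data['grid_points'][axis].shape[0]
                             for axis in ('x', 'y', 'z'))
        raise ValueError('Impossible to determine sparse grid shape.\n '
                         'Specify argument grid_shape=(x,y,z)')

    # tiny demo with a synthetic file
    with h5py.File('demo.hdf5', 'w') as f:
        g = f.create_group('mol1/grid_points')
        g['x'], g['y'], g['z'] = np.arange(30), np.arange(30), np.arange(30)
    print(infer_grid_shape('demo.hdf5', 'mol1'))  # (30, 30, 30)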
@@ -897,7 +895,7 @@ def compute_norm(self):
             self.target_min = self.param_norm['targets'].min[0]
             self.target_max = self.param_norm['targets'].max[0]
 
-        logger.info(self.target_min, self.target_max)
+        logger.info(f'{self.target_min}, {self.target_max}')
 
     def get_norm(self):
         """Get the normalization values for the features."""
@@ -1045,20 +1043,23 @@ def _normalize_feature(self, feature):
                 ) / self.feature_std[ic]
         return feature
 
+
     def _clip_feature(self, feature):
         """Clip the value of the features at +/- mean + clip_factor * std.
         Args:
             feature (np.array): raw feature values
         Returns:
             np.array: clipped feature values
         """
 
         w = self.clip_factor
         for ic in range(self.data_shape[0]):
-            minv = self.feature_mean[ic] - w * self.feature_std[ic]
-            maxv = self.feature_mean[ic] + w * self.feature_std[ic]
-            feature[ic] = np.clip(feature[ic], minv, maxv)
-            #feature[ic] = self._mad_based_outliers(feature[ic],minv,maxv)
+            if len(feature[ic]) > 0:
+                minv = self.feature_mean[ic] - w * self.feature_std[ic]
+                maxv = self.feature_mean[ic] + w * self.feature_std[ic]
+                if minv != maxv:
+                    feature[ic] = np.clip(feature[ic], minv, maxv)
+                    #feature[ic] = self._mad_based_outliers(feature[ic],minv,maxv)
         return feature
 
     @staticmethod
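Editorial note: the two new guards in _clip_feature skip empty feature channels and constant channels where std is zero (so minv == maxv). A self-contained sketch of the same logic, with hypothetical names and toy data:

    import numpy as np

    def clip_channels(feature, mean, std, w=1.5):
        """Clip each channel to mean +/- w*std, skipping empty or constant ones."""
        for ic in range(len(feature)):
            if len(feature[ic]) > 0:
                minv = mean[ic] - w * std[ic]
                maxv = mean[ic] + w * std[ic]
                if minv != maxv:  # std == 0 would give a degenerate clip range
                    feature[ic] = np.clip(feature[ic], minv, maxv)
        return feature

    # one normal, one constant (std=0), and one empty channel
    channels = [np.array([-10., 0., 10.]), np.array([5., 5.]), np.array([])]
    print(clip_channels(channels, mean=[0., 5., 0.], std=[2., 0., 1.]))
    # [array([-3., 0., 3.]), array([5., 5.]), array([], dtype=float64)]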
@@ -1168,7 +1169,10 @@ def load_one_molecule(self, fname, mol=None):
             feature.append(mat)
 
         # get the target value
-        target = mol_data.get('targets/' + self.select_target)[()]
+        try:
+            target = mol_data.get('targets/' + self.select_target)[()]
+        except Exception:
+            logger.exception(f'No target value for: {fname} - not required for the test set')
 
         # close
         fh5.close()
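Editorial note: the loaders now tolerate molecules without a stored target, which test-set HDF5 files may legitimately lack. mol_data.get() returns None for a missing key, the trailing [()] then raises, and the except block downgrades that to a logged message. A sketch of the same pattern as a standalone helper (helper and file names are illustrative):

    import logging
    import h5py

    logging.basicConfig()
    logger = logging.getLogger(__name__)

    def read_target(fname, mol, select_target):
        """Return the stored target value, or None for sets without one."""
        target = None
        with h5py.File(fname, 'r') as fh5:
            try:
                target = fh5[mol]['targets/' + select_target][()]
            except Exception:
                logger.exception(f'No target value for: {fname} - '
                                 'not required for the test set')
        return target

    # demo: a molecule with no 'targets' group yields None plus a logged error
    with h5py.File('demo_mol.hdf5', 'w') as f:
        f.create_group('mol1')
    print(read_target('demo_mol.hdf5', 'mol1', 'irmsd'))  # None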
@@ -1214,7 +1218,10 @@ def map_one_molecule(self, fname, mol=None, angle=None, axis=None):
             feature += data
 
         # get the target value
-        target = mol_data.get('targets/' + self.select_target)[()]
+        try:
+            target = mol_data.get('targets/' + self.select_target)[()]
+        except Exception:
+            logger.exception(f'No target value for: {fname} - not required for the test set')
 
         # close
         fh5.close()
@@ -1431,32 +1438,36 @@ def map_feature(self, feat_names, mol_data, grid, npts, angle, axis):
             tmp_feat_ser = [np.zeros(npts), np.zeros(npts)]
             tmp_feat_vect = [np.zeros(npts), np.zeros(npts)]
             data = np.array(mol_data['features/' + name][()])
 
+            if data.shape[0]==0:
+                logger.warning(f'No {name} retrieved at the protein/protein interface')
+
-            chain = data[:, 0]
-            pos = data[:, 1:4]
-            feat_value = data[:, 4]
-
-            if angle is not None:
-                pos = pdb2sql.transform.rot_xyz_around_axis(
-                    pos, axis, angle, center)
-
-            if __vectorize__ or __vectorize__ == 'both':
-
-                for chainID in [0, 1]:
-                    tmp_feat_vect[chainID] = np.sum(
-                        vmap(pos[chain == chainID, :],
-                             feat_value[chain == chainID]),
-                        0)
-
-            if not __vectorize__ or __vectorize__ == 'both':
-
-                for chainID, xyz, val in zip(chain, pos, feat_value):
-                    tmp_feat_ser[int(chainID)] += \
-                        self._featgrid(xyz, val, grid, npts)
-
-            if __vectorize__ == 'both':
-                assert np.allclose(tmp_feat_ser, tmp_feat_vect)
-
+            else:
+                chain = data[:, 0]
+                pos = data[:, 1:4]
+                feat_value = data[:, 4]
+
+                if angle is not None:
+                    pos = pdb2sql.transform.rot_xyz_around_axis(
+                        pos, axis, angle, center)
+
+                if __vectorize__ or __vectorize__ == 'both':
+
+                    for chainID in [0, 1]:
+                        tmp_feat_vect[chainID] = np.sum(
+                            vmap(pos[chain == chainID, :],
+                                 feat_value[chain == chainID]),
+                            0)
+
+                if not __vectorize__ or __vectorize__ == 'both':
+
+                    for chainID, xyz, val in zip(chain, pos, feat_value):
+                        tmp_feat_ser[int(chainID)] += \
+                            self._featgrid(xyz, val, grid, npts)
+
+                if __vectorize__ == 'both':
+                    assert np.allclose(tmp_feat_ser, tmp_feat_vect)
+
             if __vectorize__:
                 feat += tmp_feat_vect
             else:
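Editorial note: the map_feature change above wraps the whole mapping in an else branch, so an empty feature set at the interface now produces a warning instead of an exception. A toy sketch of that guard, with the actual grid mapping replaced by a simple per-chain sum (names and data are illustrative):

    import logging
    import numpy as np

    logging.basicConfig(level=logging.WARNING)
    logger = logging.getLogger(__name__)

    def sum_feature_per_chain(data, name='coulomb'):
        """Columns: chain id, x, y, z, value. Warn and skip when no rows exist."""
        per_chain = np.zeros(2)
        if data.shape[0] == 0:
            logger.warning(f'No {name} retrieved at the protein/protein interface')
        else:
            chain, feat_value = data[:, 0], data[:, 4]
            for chain_id in (0, 1):
                per_chain[chain_id] = feat_value[chain == chain_id].sum()
        return per_chain

    print(sum_feature_per_chain(np.empty((0, 5))))  # warns, returns [0. 0.]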
95 changes: 55 additions & 40 deletions deeprank/learn/NeuralNet.py
@@ -34,8 +34,8 @@ def __init__(self, data_set, model,
                  chain1='A',
                  chain2='B',
                  cuda=False, ngpu=0,
-                 plot=True,
-                 save_hitrate=True,
+                 plot=False,
+                 save_hitrate=False,
                  save_classmetrics=False,
                  outdir='./'):
         """Train a Convolutional Neural Network for DeepRank.
@@ -177,7 +177,7 @@ def __init__(self, data_set, model,
             self.ngpu = 1
 
         # ------------------------------------------
-        # Regression or classifiation
+        # Regression or classification
         # ------------------------------------------
 
         # task to accomplish
@@ -215,9 +215,8 @@ def __init__(self, data_set, model,
 
         # output directory
         self.outdir = outdir
-        if self.plot:
-            if not os.path.isdir(self.outdir):
-                os.mkdir(outdir)
+        if not os.path.isdir(self.outdir):
+            os.mkdir(outdir)
 
         # ------------------------------------------
         # Network
@@ -424,12 +423,12 @@ def test(self, hdf5='test_data.hdf5'):
         # do test
         self.data = {}
         _, self.data['test'] = self._epoch(loader, train_model=False)
-        if self.task == 'reg':
-            self._plot_scatter_reg(os.path.join(self.outdir, 'prediction.png'))
-        else:
-            self._plot_boxplot_class(os.path.join(self.outdir, 'prediction.png'))
-
-        self.plot_hit_rate(os.path.join(self.outdir + 'hitrate.png'))
 
+        # plot results
+        if self.plot is True :
+            self._plot_scatter(os.path.join(self.outdir, 'prediction.png'))
+        if self.save_hitrate:
+            self.plot_hit_rate(os.path.join(self.outdir + 'hitrate.png'))
+
         self._export_epoch_hdf5(0, self.data)
         self.f5.close()
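Editorial note: test() no longer plots unconditionally. The scatter plot is gated on self.plot and the hit-rate plot on self.save_hitrate, both of which now default to False (see the changed __init__ defaults above). A toy sketch of this opt-in gating with illustrative names; os.makedirs(exist_ok=True) is used here instead of the isdir/mkdir pair in the diff:

    import os

    def report_test(outdir='./', plot=False, save_hitrate=False):
        """Collect only the artifacts the caller opted into."""
        os.makedirs(outdir, exist_ok=True)
        artifacts = []
        if plot:
            artifacts.append(os.path.join(outdir, 'prediction.png'))
        if save_hitrate:
            artifacts.append(os.path.join(outdir, 'hitrate.png'))
        return artifacts

    print(report_test(plot=True))  # ['./prediction.png']
    print(report_test())           # [] - nothing is written by default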
@@ -975,10 +974,14 @@ def _plot_scatter_reg(self, figname):
         yvalues = np.array([])
 
         for l in labels:
 
             if l in self.data:
-
-                targ = self.data[l]['targets'].flatten()
+                try:
+                    targ = self.data[l]['targets']
+                except Exception:
+                    logger.exception(f'No target values are provided for the {l} set \n Skip {l} in the scatter plot')
+                    continue
+
                 out = self.data[l]['outputs'].flatten()
 
                 xvalues = np.append(xvalues, targ)
@@ -1024,10 +1027,13 @@ def _plot_boxplot_class(self, figname):
         for l in labels:
 
             if l in self.data:
+                try:
+                    tar = self.data[l]['targets']
+                except Exception:
+                    logger.exception(f'No target values are provided for the {l} set \n Skip {l} in the boxplot')
+                    continue
 
-                tar = self.data[l]['targets']
                 out = self.data[l]['outputs']
-
                 data = [[], []]
                 confusion = [[0, 0], [0, 0]]
                 for pts, t in zip(out, tar):
@@ -1054,11 +1060,11 @@ def plot_hit_rate(self, figname):
         Args:
             figname (str): filename for the plot
-            irmsd_thr (float, optional): threshold for 'good' models
+            target_thr (float, optional): threshold for 'good' models
 
         """
         if self.plot is False:
             return
 
         logger.info(f'\n --> Hitrate plot: {figname}\n')
 
         color_plot = {'train': 'red', 'valid': 'blue', 'test': 'green'}
@@ -1067,8 +1073,14 @@
         fig, ax = plt.subplots()
         for l in labels:
             if l in self.data:
+                try:
+                    hits = self.data[l]['hit']
+                except Exception:
+                    logger.exception(f'No hitrate computed for the {l} set')
+                    continue
+
-                if 'hit' in self.data[l]:
-                    hitrate = rankingMetrics.hitrate(self.data[l]['hit'])
+                hitrate = rankingMetrics.hitrate(hits)
+
                 m = len(hitrate)
                 x = np.linspace(0, 100, m)
                 plt.plot(x, hitrate, c=color_plot[l], label=f"{l} M={m}")
@@ -1083,7 +1095,7 @@
         fig.savefig(figname)
         plt.close()
 
-    def _compute_hitrate(self, irmsd_thr=4.0):
+    def _compute_hitrate(self, target_thr=4.0):
 
         labels = ['train', 'valid', 'test']
         self.hitrate = {}
@@ -1100,14 +1112,17 @@ def _compute_hitrate(self, irmsd_thr=4.0):
                 # get the target values
                 out = self.data[l]['outputs']
 
-                # get the irmsd
-                irmsd = []
-                for fname, mol in self.data[l]['mol']:
-
-                    f5 = h5py.File(fname, 'r')
-                    irmsd.append(f5[mol + '/targets/IRMSD'][()])
-                    f5.close()
-
+                # get the target vaues
+                targets = []
+                try:
+                    for fname, mol in self.data[l]['mol']:
+                        f5 = h5py.File(fname, 'r')
+                        targets.append(f5[mol + f'/targets/self.data_set.select_target'][()])
+                        f5.close()
+                except Exception:
+                    logger.exception(f'No target value ({self.data_set.select_target}) provided for for the {l} set. Skip Hitrate computation for the {l} set.')
+                    continue
+
                 # sort the data
                 if self.task == 'class':
                     out = F.softmax(torch.FloatTensor(out),
@@ -1117,16 +1132,16 @@
                 if not inverse:
                     ind_sort = ind_sort[::-1]
 
-                # get the irmsd of the recommendation
-                irmsd = np.array(irmsd)[ind_sort]
+                # get the targets of the recommendation
+                targets = np.array(targets)[ind_sort]
 
                 # make a binary list out of that
-                binary_recomendation = (irmsd <= irmsd_thr).astype('int')
+                binary_recomendation = (targets <= target_thr).astype('int')
 
                 # number of recommended hit
                 npos = np.sum(binary_recomendation)
                 if npos == 0:
-                    npos = len(irmsd)
+                    npos = len(targets)
                     warnings.warn(
                         f'Non positive decoys found in {l} for hitrate plot')
 
@@ -1135,7 +1150,7 @@ def _compute_hitrate(self, irmsd_thr=4.0):
                     binary_recomendation, npos)
                 self.data[l]['relevance'] = binary_recomendation
 
-    def _get_relevance(self, data, irmsd_thr=4.0):
+    def _get_relevance(self, data, target_thr=4.0):
 
         # get the target ordering
         inverse = self.data_set.target_ordering == 'lower'
@@ -1145,12 +1160,12 @@ def _get_relevance(self, data, irmsd_thr=4.0):
         # get the target values
         out = data['outputs']
 
-        # get the irmsd
-        irmsd = []
+        # get the targets
+        targets = []
         for fname, mol in data['mol']:
 
             f5 = h5py.File(fname, 'r')
-            irmsd.append(f5[mol + '/targets/IRMSD'][()])
+            targets.append(f5[mol + f'/targets/self.data_set.select_target'][()])
             f5.close()
 
         # sort the data
@@ -1161,11 +1176,11 @@ def _get_relevance(self, data, irmsd_thr=4.0):
         if not inverse:
             ind_sort = ind_sort[::-1]
 
-        # get the irmsd of the recommendation
-        irmsd = np.array(irmsd)[ind_sort]
+        # get the targets of the recommendation
+        targets = np.array(targets)[ind_sort]
 
         # make a binary list out of that
-        return (irmsd <= irmsd_thr).astype('int')
+        return (targets <= target_thr).astype('int')
 
     def _get_classmetrics(self, data, metricname):
 
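Editorial note: these last hunks replace the hard-coded IRMSD with the dataset's selected target throughout the hit-rate pipeline. Note that the committed f-string f'/targets/self.data_set.select_target' has no braces, so it looks up a literal key; presumably f'/targets/{self.data_set.select_target}' was intended. A self-contained sketch of the rank-then-threshold idea, with a simplified stand-in for rankingMetrics.hitrate:

    import numpy as np

    def binary_relevance(outputs, targets, target_thr=4.0, lower_is_better=True):
        """Rank decoys by predicted score, then flag hits (target <= threshold)."""
        order = np.argsort(outputs)
        if not lower_is_better:
            order = order[::-1]
        ranked = np.asarray(targets)[order]
        return (ranked <= target_thr).astype('int')

    def hitrate(relevance):
        """Fraction of all hits recovered within the top-m ranked decoys."""
        relevance = np.asarray(relevance)
        return np.cumsum(relevance) / max(relevance.sum(), 1)

    rel = binary_relevance(outputs=[0.1, 0.9, 0.4], targets=[2.0, 8.0, 3.5])
    print(rel)           # [1 1 0]
    print(hitrate(rel))  # [0.5 1.  1. ]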