Skip to content
This repository has been archived by the owner on Nov 28, 2023. It is now read-only.

Commit

Permalink
Merge 65b37a4 into 1c02ecb
Browse files Browse the repository at this point in the history
  • Loading branch information
CunliangGeng committed Jun 2, 2020
2 parents 1c02ecb + 65b37a4 commit 2af9afc
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 53 deletions.
205 changes: 153 additions & 52 deletions deeprank/generate/DataGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self, pdb_select=None, pdb_source=None,
pdb_native (list(str), optional): List of folders where to find the native comformations,
nust set it if having targets to compute in parameter "compute_targets".
pssm_source (list(str), optional): List of folders where to find the PSSM files
align (dict, optional): Dicitionary to align the compexes,
align (dict, optional): Dicitionary to align the compexes,
e.g. align = {"selection":{"chainID":["A","B"]},"axis":"z"}}
e.g. align = {"selection":"interface","plane":"xy"}
if "selection" is not specified the entire complex is used for alignement
Expand Down Expand Up @@ -83,7 +83,7 @@ def __init__(self, pdb_select=None, pdb_source=None,
self.pdb_source = pdb_source or []
self.pdb_native = pdb_native or []
self.pssm_source = pssm_source
self.align = align
self.align = align

if self.pssm_source is not None:
config.PATH_PSSM_SOURCE = self.pssm_source
Expand Down Expand Up @@ -234,7 +234,7 @@ def create_database(
self.f5.attrs['pssm_source'] = os.path.abspath(self.pssm_source)
self.f5.attrs['features'] = self.compute_features
self.f5.attrs['targets'] = self.compute_targets

##################################################
# Start generating HDF5 database
##################################################
Expand Down Expand Up @@ -504,6 +504,106 @@ def create_database(
self.f5.close()
self.logger.info(f'\n# Successfully created database: {self.hdf5}\n')

def aug_data(self, augmentation, keep_existing_aug=True, random_seed=None):
"""Augment exiting original PDB data and features.
Args:
augmentation(int): Times of augmentation
keep_existing_aug (bool, optional): Keep existing augmentated data.
If False, existing aug will be removed. Defaults to True.
Examples:
>>> database = DataGenerator(h5='database.h5')
>>> database.aug_data(augmentation=3, append=True)
>>> grid_info = {
>>> 'number_of_points' : [30,30,30],
>>> 'resolution' : [1.,1.,1.],
>>> 'atomic_densities' : {'C':1.7, 'N':1.55, 'O':1.52, 'S':1.8},
>>> }
>>> database.map_features(grid_info)
"""

# check if file exists
if not os.path.isfile(self.hdf5):
raise FileNotFoundError('File %s does not exists' % self.hdf5)

# get the folder names
f5 = h5py.File(self.hdf5, 'a')
fnames = f5.keys()

# get the non rotated ones
fnames_original = list(
filter(lambda x: not re.search(r'_r\d+$', x), fnames))

# get the rotated ones
fnames_augmented = list(
filter(lambda x: re.search(r'_r\d+$', x), fnames))

aug_id_start = 0
if keep_existing_aug:
exiting_augs = list(
filter(lambda x: re.search(fnames_original[0]+ r'_r\d+$', x), fnames_augmented))
aug_id_start += len(exiting_augs)
else:
for i in fnames_augmented:
del f5[i]

self.logger.info(
f'{"":s}\n# Start augmenting data'
f' with {augmentation} times...')

# GET ALL THE NAMES
for mol_name in fnames_original:
mol_aug_name_list = [
mol_name + '_r%03d' % (idir + 1) for idir in
range(aug_id_start, aug_id_start + augmentation)]

# loop over the complexes
for mol_aug_name in mol_aug_name_list:

# crete a subgroup for the molecule
molgrp = f5.require_group(mol_aug_name)
molgrp.attrs['type'] = 'molecule'

# copy the ref into it
if 'native' in f5[mol_name]:
f5.copy(mol_name + '/native', molgrp)

# get the rotation axis and angle
if self.align is None:
axis, angle = pdb2sql.transform.get_rot_axis_angle(random_seed)
else:
axis, angle = self._get_aligned_rotation_axis_angle(random_seed,
self.align)

# create the new pdb and get molecule center
# molecule center is the origin of rotation)
mol_center = self._add_aug_pdb(
molgrp, f5[mol_name + '/complex'][()], 'complex', axis, angle)

# copy the targets/features
if 'targets' in f5[mol_name]:
f5.copy(mol_name + '/targets/', molgrp)
f5.copy(mol_name + '/features/', molgrp)

# rotate the feature
self._rotate_feature(molgrp, axis, angle, mol_center)

# grid center used to create grid box
molgrp.require_group('grid_points')
center = pdb2sql.transform.rot_xyz_around_axis(
f5[mol_name + '/grid_points/center'],
axis, angle, mol_center)

molgrp['grid_points'].create_dataset('center', data=center)

# store the rotation axis/angl/center as attriutes
# in case we need them later
molgrp.attrs['axis'] = axis
molgrp.attrs['angle'] = angle
molgrp.attrs['center'] = mol_center
f5.close()
self.logger.info(f'\n# Successfully augmented data in {self.hdf5}')

# ====================================================================================
#
Expand Down Expand Up @@ -729,17 +829,17 @@ def add_target(self, prog_bar=False):

def realign_complexes(self, align, compute_features=None, pssm_source=None):
"""Align all the complexes already present in the HDF5.
Arguments:
align {dict} -- alignement dictionary (see __init__)
Keyword Arguments:
compute_features {list} -- list of features to be computed
if None computes the features specified in
the attrs['features'] of the file (if present)
pssm_source {str} -- path of the pssm files. If None the source specfied in
the attrs['pssm_source'] will be used (if present) (default: {None})
Raises:
ValueError: If no PSSM detected
Expand All @@ -748,13 +848,13 @@ def realign_complexes(self, align, compute_features=None, pssm_source=None):
>>> database = DataGenerator(hdf5='1ak4.hdf5')
>>> # if comute_features and pssm_source are not specified
>>> # the values in hdf5.attrs['features'] and hdf5.attrs['pssm_source'] will be used
>>> database.realign_complex(align={'axis':'x'},
>>> compute_features['deeprank.features.X'],
>>> database.realign_complex(align={'axis':'x'},
>>> compute_features['deeprank.features.X'],
>>> pssm_source='./1ak4_pssm/')
"""

f5 = h5py.File(self.hdf5,'a')

mol_names = f5.keys()
self.logger.info(f'\n# Start aligning the HDF5 database: {self.hdf5}')

Expand All @@ -772,7 +872,7 @@ def realign_complexes(self, align, compute_features=None, pssm_source=None):

elif pssm_source is not None:
config.PATH_PSSM_SOURCE = pssm_source

elif 'pssm_source' in f5.attrs:
config.PATH_PSSM_SOURCE = f5.attrs['pssm_source']
else :
Expand All @@ -797,7 +897,7 @@ def realign_complexes(self, align, compute_features=None, pssm_source=None):
for od in old_dir:
if od in molgrp:
del molgrp[od]

# the internal features
molgrp.require_group('features')
molgrp.require_group('features_raw')
Expand Down Expand Up @@ -892,6 +992,10 @@ def map_features(self, grid_info={},
remove_error=True):
"""Map the feature on a grid of points centered at the interface.
If features to map are not given, they will be are automatically
determined for each molecule. Otherwise, given features will be mapped
for all molecules (i.e. existing mapped features will be recalculated).
Args:
grid_info (dict): Informaton for the grid.
See deeprank.generate.GridTools.py for details.
Expand Down Expand Up @@ -947,38 +1051,12 @@ def map_features(self, grid_info={},
# Check grid_info
################################################################
# fills in the grid data if not provided : default = NONE
grid_info_ref = dict(grid_info) # deep copy
grinfo = ['number_of_points', 'resolution']
for gr in grinfo:
if gr not in grid_info:
grid_info[gr] = None

# Dtermine which feature to map
if 'feature' not in grid_info:

# get the mol group
mol = list(f5.keys())[0]

# if we havent mapped anything yet or if we reset
if 'mapped_features' not in list(f5[mol].keys()) or reset:
grid_info['feature'] = list(f5[mol + '/features'].keys())

# if we have already mapped stuff
elif 'mapped_features' in list(f5[mol].keys()):

# feature name
all_feat = list(f5[mol + '/features'].keys())

# feature already mapped
mapped_feat = list(
f5[mol + '/mapped_features/Feature_ind'].keys())

# we select only the feture that were not mapped yet
grid_info['feature'] = []
for feat_name in all_feat:
if not any(map(lambda x: x.startswith(feat_name + '_'),
mapped_feat)):
grid_info['feature'].append(feat_name)

# by default we do not map atomic densities
if 'atomic_densities' not in grid_info:
grid_info['atomic_densities'] = None
Expand All @@ -988,7 +1066,7 @@ def map_features(self, grid_info={},
for m in modes:
if m not in grid_info:
grid_info[m] = 'ind'

################################################################
#
################################################################
Expand Down Expand Up @@ -1025,6 +1103,30 @@ def map_features(self, grid_info={},
for mol in mol_tqdm:
mol_tqdm.set_postfix(mol=mol)

# Determine which feature to map
# if feature not given, then determine it for each molecule
if 'feature' not in grid_info_ref:
# if we havent mapped anything yet or if we reset
if 'mapped_features' not in list(f5[mol].keys()) or reset:
grid_info['feature'] = list(f5[mol + '/features'].keys())

# if we have already mapped stuff
elif 'mapped_features' in list(f5[mol].keys()):

# feature name
all_feat = list(f5[mol + '/features'].keys())

# feature already mapped
mapped_feat = list(
f5[mol + '/mapped_features/Feature_ind'].keys())

# we select only the feture that were not mapped yet
grid_info['feature'] = []
for feat_name in all_feat:
if not any(map(lambda x: x.startswith(feat_name + '_'),
mapped_feat)):
grid_info['feature'].append(feat_name)

try:
# compute the data we want on the grid
gt.GridTools(
Expand Down Expand Up @@ -1063,7 +1165,6 @@ def map_features(self, grid_info={},
# close he hdf5 file
f5.close()


# ====================================================================================
#
# REMOVE DATA FROM THE DATA SET
Expand Down Expand Up @@ -1402,7 +1503,7 @@ def _compute_targets(targ_list, pdb_data, targrp):
# ADD PDB FILE
#
# ====================================================================================

def _add_pdb(self, molgrp, pdbfile, name):
"""Add a pdb to a molgrp.
Expand All @@ -1418,10 +1519,10 @@ def _add_pdb(self, molgrp, pdbfile, name):
with open(pdbfile, 'r') as fi:
data = [line.split('\n')[0]
for line in fi if line.startswith('ATOM')]

# some alignement
elif isinstance(self.align, dict):

sqldb = self._get_aligned_sqldb(pdbfile, self.align)
data = sqldb.sql2pdb()

Expand All @@ -1433,7 +1534,7 @@ def _add_pdb(self, molgrp, pdbfile, name):
@staticmethod
def _get_aligned_sqldb(pdbfile, dict_align):
"""return a sqldb of the pdb that is aligned as specified in the dict
Arguments:
pdbfile {str} -- path ot the pdb
dict_align {dict} -- dictionanry of options to align the pdb
Expand All @@ -1451,14 +1552,14 @@ def _get_aligned_sqldb(pdbfile, dict_align):
'chain2' : dict_align['chain2']}
else:
chains = {}
sqldb = align_interface(pdbfile, plane=dict_align['plane'],

sqldb = align_interface(pdbfile, plane=dict_align['plane'],
export=dict_align['export'],
**chains)

else:

sqldb = align_along_axis(pdbfile, axis=dict_align['axis'],
sqldb = align_along_axis(pdbfile, axis=dict_align['axis'],
export = dict_align['export'],
**dict_align['selection'])

Expand All @@ -1474,8 +1575,8 @@ def _get_aligned_sqldb(pdbfile, dict_align):
def _get_aligned_rotation_axis_angle(random_seed, dict_align):
"""Returns the axis and angle of rotation for data
augmentation with aligned complexes
Arguments:
Arguments:
random_seed {int} -- random seed of rotation
dict_align {dict} -- the dict describing the alignement
Expand All @@ -1498,7 +1599,7 @@ def _get_aligned_rotation_axis_angle(random_seed, dict_align):
axis = [1.,0.,0.]
else:
raise ValueError("plane must be xy, xz or yz")

elif 'axis' in dict_align.keys():
if dict_align['axis'] == 'x':
axis = [1.,0.,0.]
Expand Down Expand Up @@ -1572,10 +1673,10 @@ def _rotate_feature(molgrp, axis, angle, center, feat_name='all'):
feat = list(feat)

for fn in feat:

# extract the data
data = molgrp['features/' + fn][()]

# if data not empty
if data.shape[0] != 0:

Expand Down
Loading

0 comments on commit 2af9afc

Please sign in to comment.