Skip to content
This repository has been archived by the owner on Nov 28, 2023. It is now read-only.

Commit

Permalink
Merge 31a6110 into 4ff7834
Browse files Browse the repository at this point in the history
  • Loading branch information
NicoRenaud committed Mar 31, 2020
2 parents 4ff7834 + 31a6110 commit a959975
Show file tree
Hide file tree
Showing 17 changed files with 233 additions and 42 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ before_install:
# Useful for debugging any issues with conda
- conda info -a

# python
# python
- conda install python=3.7

# install openmpi for mpi4py
Expand Down
4 changes: 2 additions & 2 deletions deeprank/features/AtomicFeature.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(self, pdbfile, param_charge=None, param_vdw=None,
>>> atfeat.evaluate_pair_interaction(save_interactions=test_name)
>>>
>>> # close the db
>>> atfeat.sqldb.close()
>>> atfeat.sqldb._close()
"""

super().__init__("Atomic")
Expand Down Expand Up @@ -939,7 +939,7 @@ def __compute_feature__(pdb_data, featgrp, featgrp_raw):
atfeat.export_data_hdf5(featgrp_raw)

# close
atfeat.sqldb.close()
atfeat.sqldb._close()


########################################################################
Expand Down
6 changes: 3 additions & 3 deletions deeprank/features/BSA.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, pdb_data, chainA='A', chainB='B'):
>>> bsa = BSA('1AK4.pdb')
>>> bsa.get_structure()
>>> bsa.get_contact_residue_sasa()
>>> bsa.sql.close()
>>> bsa.sql._close()
"""
self.pdb_data = pdb_data
self.sql = pdb2sql.interface(pdb_data)
Expand Down Expand Up @@ -156,7 +156,7 @@ def __compute_feature__(pdb_data, featgrp, featgrp_raw):
bsa.export_data_hdf5(featgrp_raw)

# close the file
bsa.sql.close()
bsa.sql._close()


########################################################################
Expand All @@ -177,7 +177,7 @@ def __compute_feature__(pdb_data, featgrp, featgrp_raw):
bsa = BSA(pdb_file)
bsa.get_structure()
bsa.get_contact_residue_sasa()
bsa.sql.close()
bsa.sql._close()

pprint(bsa.feature_data)
print()
Expand Down
2 changes: 1 addition & 1 deletion deeprank/features/FullPSSM.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def get_feature_value(self, cutoff=5.5):
# get interface contact residues
# ctc_res = {"A":[chain 1 residues], "B": [chain2 residues]}
ctc_res = sql.get_contact_residues(cutoff=cutoff)
sql.close()
sql._close()
ctc_res = ctc_res["A"] + ctc_res["B"]

# handle with small interface or no interface
Expand Down
2 changes: 1 addition & 1 deletion deeprank/features/NaivePSSM.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def get_feature_value(self, contact_only=True):

contact_residue = sql.get_contact_residue(cutoff=5.5)
contact_residue = contact_residue["A"] + contact_residue["B"]
sql.close()
sql._close()

pssm_data_xyz = {}
pssm_data = {}
Expand Down
6 changes: 3 additions & 3 deletions deeprank/features/ResidueDensity.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def get(self, cutoff=5.5):
# handle with small interface or no interface
if total_ctc == 0:
# first close the sql
self.sql.close()
self.sql._close()

raise ValueError(
f"No residue contact found with the cutoff {cutoff}Å. "
Expand Down Expand Up @@ -179,7 +179,7 @@ def __compute_feature__(pdb_data, featgrp, featgrp_raw):
resdens.export_data_hdf5(featgrp_raw)

# close sql
resdens.sql.close()
resdens.sql._close()

########################################################################
#
Expand All @@ -203,7 +203,7 @@ def __compute_feature__(pdb_data, featgrp, featgrp_raw):

resdens.get()
resdens.extract_features()
resdens.sql.close()
resdens.sql._close()

pprint(resdens.feature_data)
print()
Expand Down
134 changes: 115 additions & 19 deletions deeprank/generate/DataGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from deeprank.config import logger
from deeprank.generate import GridTools as gt
import pdb2sql
from pdb2sql.align import align as align_along_axis
from pdb2sql.align import align_interface

try:
from tqdm import tqdm
Expand All @@ -33,7 +35,7 @@ def _printif(string, cond): return print(string) if cond else None
class DataGenerator(object):

def __init__(self, pdb_select=None, pdb_source=None,
pdb_native=None, pssm_source=None,
pdb_native=None, pssm_source=None, align=None,
compute_targets=None, compute_features=None,
data_augmentation=None, hdf5='database.h5', mpi_comm=None):
"""Generate the data (features/targets/maps) required for deeprank.
Expand Down Expand Up @@ -77,6 +79,7 @@ def __init__(self, pdb_select=None, pdb_source=None,
self.pdb_select = pdb_select or []
self.pdb_source = pdb_source or []
self.pdb_native = pdb_native or []
self.align = align

if pssm_source is not None:
config.PATH_PSSM_SOURCE = pssm_source
Expand Down Expand Up @@ -239,6 +242,7 @@ def create_database(

# names of the molecule
mol_name = os.path.splitext(os.path.basename(cplx))[0]
mol_name = mol_name.replace('-', '_')
mol_aug_name_list = []

try:
Expand Down Expand Up @@ -395,7 +399,7 @@ def create_database(
f' with {self.data_augmentation} times...')

# loop over the complexes
for _, mol_aug_name in enumerate(mol_aug_name_list):
for mol_aug_name in mol_aug_name_list:

# create a subgroup for the molecule
molgrp = self.f5.require_group(mol_aug_name)
Expand All @@ -406,10 +410,14 @@ def create_database(
self._add_pdb(molgrp, ref, 'native')

# get the rotation axis and angle
axis, angle = pdb2sql.transform.get_rot_axis_angle(random_seed)
if self.align is None:
axis, angle = pdb2sql.transform.get_rot_axis_angle(random_seed)
else:
axis, angle = self._get_aligned_rotation_axis_angle(random_seed,
self.align)

# create the new pdb and get molecule center
# molecule center is the origin of rotation
# molecule center is the origin of rotation
mol_center = self._add_aug_pdb(
molgrp, cplx, 'complex', axis, angle)

Expand Down Expand Up @@ -733,7 +741,7 @@ def _get_grid_center(pdb, contact_distance):
center_contact = np.mean(
np.array(sqldb.get('x,y,z', rowID=contact_atoms)), 0)

sqldb.close()
sqldb._close()

return center_contact

Expand Down Expand Up @@ -888,6 +896,7 @@ def map_features(self, grid_info={},
for m in modes:
if m not in grid_info:
grid_info[m] = 'ind'

################################################################
#
################################################################
Expand Down Expand Up @@ -1300,50 +1309,137 @@ def _compute_targets(targ_list, pdb_data, targrp):
# ADD PDB FILE
#
# ====================================================================================

@staticmethod
def _add_pdb(molgrp, pdbfile, name):

def _add_pdb(self, molgrp, pdbfile, name):
"""Add a pdb to a molgrp.
Args:
molgrp (str): mopl group where tp add the pdb
pdbfile (str): psb file to add
name (str): dataset name in the hdf5 molgroup
"""
# read the pdb and extract the ATOM lines
with open(pdbfile, 'r') as fi:
data = [line.split('\n')[0]
for line in fi if line.startswith('ATOM')]

# no alignement
if self.align is None:
# read the pdb and extract the ATOM lines
with open(pdbfile, 'r') as fi:
data = [line.split('\n')[0]
for line in fi if line.startswith('ATOM')]

# some alignement
elif isinstance(self.align, dict):

sqldb = self._get_aligned_sqldb(pdbfile, self.align)
data = sqldb.sql2pdb()

# PDB default line length is 80
# http://www.wwpdb.org/documentation/file-format
data = np.array(data).astype('|S78')
molgrp.create_dataset(name, data=data)

@staticmethod
def _get_aligned_sqldb(pdbfile, dict_align):
    """Return a sqldb of the pdb that is aligned as specified in the dict.

    ``dict_align`` is only read, never modified, so the caller's
    alignment options keep exactly the keys the user provided.

    Arguments:
        pdbfile {str} -- path to the pdb
        dict_align {dict} -- dictionary of options to align the pdb:
            'selection': the string 'interface' to align the interface,
                otherwise a dict of keyword arguments forwarded to the
                axis alignment (default: {})
            'export': forwarded as ``export`` (default: False)
            'plane': required when 'selection' is 'interface'
            'axis': required otherwise
            'chain1', 'chain2': forwarded to the interface alignment
                only when both are present

    Returns:
        the object returned by ``align_interface`` or
        ``align_along_axis`` (used elsewhere in this file as a pdb2sql
        database)

    Raises:
        KeyError: if the required 'plane' / 'axis' entry is missing
    """
    # defaults are resolved locally instead of being written back
    selection = dict_align.get('selection', {})
    export = dict_align.get('export', False)

    if selection == 'interface':

        # chain names are only forwarded as a complete pair
        if 'chain1' in dict_align and 'chain2' in dict_align:
            chains = {'chain1': dict_align['chain1'],
                      'chain2': dict_align['chain2']}
        else:
            chains = {}

        return align_interface(pdbfile, plane=dict_align['plane'],
                               export=export,
                               **chains)

    return align_along_axis(pdbfile, axis=dict_align['axis'],
                            export=export,
                            **selection)

# ====================================================================================
#
# AUGMENTED DATA
#
# ====================================================================================

# add a rotated pdb structure to the database

@staticmethod
def _add_aug_pdb(molgrp, pdbfile, name, axis, angle):
def _get_aligned_rotation_axis_angle(random_seed, dict_align):
"""Returns the axis and angle of rotation for data
augmentation with aligned complexes
Arguments:
random_seed {int} -- random seed of rotation
dict_align {dict} -- the dict describing the alignement
Returns:
list(float): axis of rotation
float: angle of rotation
"""

if random_seed is not None:
np.random.seed(random_seed)

angle = 2 * np.pi * np.random.rand()

if 'plane' in dict_align.keys():
if dict_align['plane'] == 'xy':
axis = [0.,0.,1.]
elif dict_align['plane'] == 'xz':
axis = [0.,1.,0.]
elif dict_align['plane'] == 'yz':
axis = [1.,0.,0.]
else:
raise ValueError("plane must be xy, xz or yz")

elif 'axis' in dict_align.keys():
if dict_align['axis'] == 'x':
axis = [1.,0.,0.]
elif dict_align['axis'] == 'y':
axis = [0.,1.,0.]
elif dict_align['axis'] == 'z':
axis = [0.,0.,1.]
else:
raise ValueError("axis must be x, y or z")
else:
raise ValueError('dict_align must contains plane or axis')

return axis, angle

# add a rotated pdb structure to the database
def _add_aug_pdb(self, molgrp, pdbfile, name, axis, angle):
"""Add augmented pdbs to the dataset.
Args:
molgrp (str): name of the molgroup
pdbfile (str): pdb file name
name (str): name of the dataset
axis (list(float)): axis of rotation
angle (folat): angle of rotation
angle (float): angle of rotation
dict_align (dict) : dict for alignement of the original pdb
Returns:
list(float): center of the molecule
"""
# create tthe sqldb and extract positions
sqldb = pdb2sql.pdb2sql(pdbfile)
# create the sqldb and extract positions
if self.align is None:
sqldb = pdb2sql.pdb2sql(pdbfile)
else:
sqldb = self._get_aligned_sqldb(pdbfile, self.align)

# rotate the positions
pdb2sql.transform.rot_axis(sqldb, axis, angle)
Expand All @@ -1358,7 +1454,7 @@ def _add_aug_pdb(molgrp, pdbfile, name, axis, angle):
molgrp.create_dataset(name, data=data)

# close the db
sqldb.close()
sqldb._close()

return center

Expand Down
4 changes: 2 additions & 2 deletions deeprank/generate/GridTools.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def create_new_data(self):
self.add_all_atomic_densities()

# close the db file
self.sqldb.close()
self.sqldb._close()

################################################################

Expand Down Expand Up @@ -199,7 +199,7 @@ def update_feature(self):
self.add_all_atomic_densities()

# close the db file
self.sqldb.close()
self.sqldb._close()

################################################################

Expand Down
2 changes: 1 addition & 1 deletion deeprank/learn/DataSet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1325,7 +1325,7 @@ def map_atomic_densities(

densities += [atdensA, atdensB]

sql.close()
sql._close()

return densities

Expand Down
4 changes: 2 additions & 2 deletions deeprank/tools/sasa.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def get_residue_center(self, chainA='A', chainB='B'):
for r in resB[:, :2]:
if tuple(r) not in self.resinfo[chainB]:
self.resinfo[chainB].append(tuple(r))
sql.close()
sql._close()

def get_residue_carbon_beta(self, chainA='A', chainB='B'):
"""Extract the position of the carbon beta of each residue.
Expand All @@ -104,7 +104,7 @@ def get_residue_carbon_beta(self, chainA='A', chainB='B'):
'resSeq,resName,x,y,z',
name='CB',
chainID=chainB))
sql.close()
sql._close()

assert len(resA[:, 0].astype(np.int).tolist()) == len(
np.unique(resA[:, 0].astype(np.int)).tolist())
Expand Down
2 changes: 1 addition & 1 deletion deeprank/utils/visualize3Ddata.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def visualize3Ddata(hdf5=None, mol_name=None, out=None):
# create the pdb file
sqldb = pdb2sql.pdb2sql(molgrp['complex'][:])
sqldb.exportpdb(outdir + '/complex.pdb')
sqldb.close()
sqldb._close()

# get the grid
grid = {}
Expand Down
Loading

0 comments on commit a959975

Please sign in to comment.