Skip to content
This repository has been archived by the owner on Nov 28, 2023. It is now read-only.

Commit

Permalink
Merge pull request #200 from DeepRank/issue198
Browse files Browse the repository at this point in the history
Remove hardcoded chainIDs in issue 198
  • Loading branch information
CunliangGeng authored Dec 17, 2020
2 parents 3249d87 + d91cad0 commit 66fd866
Show file tree
Hide file tree
Showing 36 changed files with 17,538 additions and 17,995 deletions.
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Byte-compiled / optimized
deeprank.egg-info
database*
dist
build
build

# specific architure files
deeprank/learn/arch_*
Expand Down Expand Up @@ -60,3 +60,7 @@ test/atomic_pair_interaction.dat
# Mac OSX files
.DS_Store

# test coverage
htmlcov
coverage.xml
.coverage
58 changes: 34 additions & 24 deletions deeprank/features/AtomicFeature.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,18 @@

class AtomicFeature(FeatureClass):

def __init__(self, pdbfile, param_charge=None, param_vdw=None,
patch_file=None, contact_cutoff=8.5, verbose=False):
def __init__(self, pdbfile, chain1='A', chain2='B', param_charge=None,
param_vdw=None, patch_file=None, contact_cutoff=8.5,
verbose=False):
"""Compute the Coulomb, van der Waals interaction and charges.
Args:
pdbfile (str): pdb file of the molecule
chain1 (str): First chain ID, defaults to 'A'
chain2 (str): Second chain ID, defaults to 'B'
param_charge (str): file name of the force field file
containing the charges e.g. protein-allhdg5.4_new.top.
Must be of the format:
Expand Down Expand Up @@ -68,6 +72,8 @@ def __init__(self, pdbfile, param_charge=None, param_vdw=None,

# set a few things
self.pdbfile = pdbfile
self.chain1 = chain1
self.chain2 = chain2
self.param_charge = param_charge
self.param_vdw = param_vdw
self.patch_file = patch_file
Expand Down Expand Up @@ -222,16 +228,16 @@ def get_contact_atoms(self):
# but need to add a filter parameter to filter out ligand.

# position of the chains
xyz1 = np.array(self.sqldb.get('x,y,z', chainID='A'))
xyz2 = np.array(self.sqldb.get('x,y,z', chainID='B'))
xyz1 = np.array(self.sqldb.get('x,y,z', chainID=self.chain1))
xyz2 = np.array(self.sqldb.get('x,y,z', chainID=self.chain2))

# rowID of the chains
index_a = self.sqldb.get('rowID', chainID='A')
index_b = self.sqldb.get('rowID', chainID='B')
index_a = self.sqldb.get('rowID', chainID=self.chain1)
index_b = self.sqldb.get('rowID', chainID=self.chain2)

# resName of the chains
resName1 = np.array(self.sqldb.get('resName', chainID='A'))
resName2 = np.array(self.sqldb.get('resName', chainID='B'))
resName1 = np.array(self.sqldb.get('resName', chainID=self.chain1))
resName2 = np.array(self.sqldb.get('resName', chainID=self.chain2))

# declare the contact atoms
self.contact_atoms_A = []
Expand Down Expand Up @@ -536,7 +542,7 @@ def evaluate_charges(self, extend_contact_to_residue=False):
charge_data[key] = [charge[i]]

# xyz format
chain_dict = [{'A': 0, 'B': 1}[key[0]]]
chain_dict = [{self.chain1: 0, self.chain2: 1}[key[0]]]
key = tuple(chain_dict + xyz[i, :].tolist())
charge_data_xyz[key] = [charge[i]]

Expand Down Expand Up @@ -582,8 +588,8 @@ def evaluate_pair_interaction(self, print_interactions=False,
vdw_data_xyz = {}

# define the matrices
natA, natB = len(self.sqldb.get('x', chainID='A')), len(
self.sqldb.get('x', chainID='B'))
natA, natB = len(self.sqldb.get('x', chainID=self.chain1)), len(
self.sqldb.get('x', chainID=self.chain2))
matrix_elec = np.zeros((natA, natB))
matrix_vdw = np.zeros((natA, natB))

Expand Down Expand Up @@ -753,16 +759,16 @@ def compute_coulomb_interchain_only(self, dosum=True, contact_only=False):

else:

xyzA = np.array(self.sqldb.get('x,y,z', chainID='A'))
xyzB = np.array(self.sqldb.get('x,y,z', chainID='B'))
xyzA = np.array(self.sqldb.get('x,y,z', chainID=self.chain1))
xyzB = np.array(self.sqldb.get('x,y,z', chainID=self.chain2))

chargeA = np.array(self.sqldb.get('CHARGE', chainID='A'))
chargeB = np.array(self.sqldb.get('CHARGE', chainID='B'))
chargeA = np.array(self.sqldb.get('CHARGE', chainID=self.chain1))
chargeB = np.array(self.sqldb.get('CHARGE', chainID=self.chain2))

atinfoA = self.sqldb.get(
self.atom_key, chainID='A')
self.atom_key, chainID=self.chain1)
atinfoB = self.sqldb.get(
self.atom_key, chainID='B')
self.atom_key, chainID=self.chain2)

natA, natB = len(xyzA), len(xyzB)
matrix = np.zeros((natA, natB))
Expand Down Expand Up @@ -837,19 +843,19 @@ def compute_vdw_interchain_only(self, dosum=True, contact_only=False):

else:

xyzA = np.array(self.sqldb.get('x,y,z', chainID='A'))
xyzB = np.array(self.sqldb.get('x,y,z', chainID='B'))
xyzA = np.array(self.sqldb.get('x,y,z', chainID=self.chain1))
xyzB = np.array(self.sqldb.get('x,y,z', chainID=self.chain2))

vdwA = np.array(self.sqldb.get('eps,sig', chainID='A'))
vdwB = np.array(self.sqldb.get('eps,sig', chainID='B'))
vdwA = np.array(self.sqldb.get('eps,sig', chainID=self.chain1))
vdwB = np.array(self.sqldb.get('eps,sig', chainID=self.chain2))

epsA, sigA = vdwA[:, 0], vdwA[:, 1]
epsB, sigB = vdwB[:, 0], vdwB[:, 1]

atinfoA = self.sqldb.get(
self.atom_key, chainID='A')
self.atom_key, chainID=self.chain1)
atinfoB = self.sqldb.get(
self.atom_key, chainID='B')
self.atom_key, chainID=self.chain2)

natA, natB = len(xyzA), len(xyzB)
matrix = np.zeros((natA, natB))
Expand Down Expand Up @@ -908,18 +914,22 @@ def _prefactor_vdw(r):
#
########################################################################

def __compute_feature__(pdb_data, featgrp, featgrp_raw):
def __compute_feature__(pdb_data, featgrp, featgrp_raw, chain1, chain2):
"""Main function called in deeprank for the feature calculations.
Args:
pdb_data (list(bytes)): pdb information
featgrp (str): name of the group where to save xyz-val data
featgrp_raw (str): name of the group where to save human readable data
chain1 (str): First chain ID
chain2 (str): Second chain ID
"""
path = os.path.dirname(os.path.realpath(__file__))
FF = path + '/forcefield/'

atfeat = AtomicFeature(pdb_data,
chain1=chain1,
chain2=chain2,
param_charge=FF + 'protein-allhdg5-4_new.top',
param_vdw=FF + 'protein-allhdg5-4_new.param',
patch_file=FF + 'patch.top')
Expand Down
31 changes: 21 additions & 10 deletions deeprank/features/BSA.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

class BSA(FeatureClass):

def __init__(self, pdb_data, chainA='A', chainB='B'):
def __init__(self, pdb_data, chain1='A', chain2='B'):
"""Compute the burried surface area feature.
Freesasa is required for this feature.
Expand All @@ -24,8 +24,8 @@ def __init__(self, pdb_data, chainA='A', chainB='B'):
Args :
pdb_data (list(byte) or str): pdb data or pdb filename
chainA (str, optional): name of the first chain
chainB (str, optional): name of the second chain
chain1 (str, optional): name of the first chain
chain2 (str, optional): name of the second chain
Example :
>>> bsa = BSA('1AK4.pdb')
Expand All @@ -35,7 +35,9 @@ def __init__(self, pdb_data, chainA='A', chainB='B'):
"""
self.pdb_data = pdb_data
self.sql = pdb2sql.interface(pdb_data)
self.chains_label = [chainA, chainB]
self.chain1 = chain1
self.chain2 = chain2
self.chains_label = [chain1, chain2]

self.feature_data = {}
self.feature_data_xyz = {}
Expand Down Expand Up @@ -84,8 +86,8 @@ def get_contact_residue_sasa(self, cutoff=5.5):
self.bsa_data = {}
self.bsa_data_xyz = {}

ctc_res = self.sql.get_contact_residues(cutoff=cutoff)
ctc_res = ctc_res["A"] + ctc_res["B"]
ctc_res = self.sql.get_contact_residues(cutoff=cutoff, chain1=self.chain1, chain2=self.chain2)
ctc_res = ctc_res[self.chain1] + ctc_res[self.chain2]

# handle with small interface or no interface
total_res = len(ctc_res)
Expand Down Expand Up @@ -115,9 +117,9 @@ def get_contact_residue_sasa(self, cutoff=5.5):
bsa = asa_unbound - asa_complex

# define the xyz key : (chain,x,y,z)
chain = {'A': 0, 'B': 1}[res[0]]
chain = {self.chain1: 0, self.chain2: 1}[res[0]]

# get the center
# get the center
_, xyz = self.get_residue_center(self.sql, res=res)
xyzkey = tuple([chain] + xyz[0])

Expand All @@ -136,10 +138,19 @@ def get_contact_residue_sasa(self, cutoff=5.5):
########################################################################


def __compute_feature__(pdb_data, featgrp, featgrp_raw):
def __compute_feature__(pdb_data, featgrp, featgrp_raw, chain1, chain2):
"""Main function called in deeprank for the feature calculations.
Args:
pdb_data (list(bytes)): pdb information
featgrp (str): name of the group where to save xyz-val data
featgrp_raw (str): name of the group where to save human readable data
chain1 (str): First chain ID
chain2 (str): Second chain ID
"""

# create the BSA instance
bsa = BSA(pdb_data)
bsa = BSA(pdb_data, chain1, chain2)

# get the structure/calc
bsa.get_structure()
Expand Down
30 changes: 23 additions & 7 deletions deeprank/features/FullPSSM.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

class FullPSSM(FeatureClass):

def __init__(self, mol_name=None, pdb_file=None, pssm_path=None,
pssm_format='new', out_type='pssmvalue'):
def __init__(self, mol_name=None, pdb_file=None, chain1='A', chain2='B',
pssm_path=None, pssm_format='new', out_type='pssmvalue'):
"""Compute all the PSSM data.
Simply extracts all the PSSM information and
Expand All @@ -26,6 +26,8 @@ def __init__(self, mol_name=None, pdb_file=None, pssm_path=None,
Args:
mol_name (str): name of the molecule. Defaults to None.
pdb_file (str): name of the pdb_file. Defaults to None.
chain1 (str): First chain ID. Defaults to 'A'
chain2 (str): Second chain ID. Defaults to 'B'
pssm_path (str): path to the pssm data. Defaults to None.
pssm_format (str): "old" or "new" pssm format.
Defaults to 'new'.
Expand All @@ -50,6 +52,8 @@ def __init__(self, mol_name=None, pdb_file=None, pssm_path=None,
self.pssm_path = pssm_path
self.pssm_format = pssm_format
self.out_type = out_type.lower()
self.chain1 = chain1
self.chain2 = chain2

if isinstance(pdb_file, str) and mol_name is None:
self.mol_name = os.path.basename(pdb_file).split('.')[0]
Expand Down Expand Up @@ -174,9 +178,10 @@ def get_feature_value(self, cutoff=5.5):

# get interface contact residues
# ctc_res = {"A":[chain 1 residues], "B": [chain2 residues]}
ctc_res = sql.get_contact_residues(cutoff=cutoff)
ctc_res = sql.get_contact_residues(cutoff=cutoff,
chain1=self.chain1, chain2=self.chain2)
sql._close()
ctc_res = ctc_res["A"] + ctc_res["B"]
ctc_res = ctc_res[self.chain1] + ctc_res[self.chain2]

# handle with small interface or no interface
total_res = len(ctc_res)
Expand Down Expand Up @@ -213,7 +218,7 @@ def get_feature_value(self, cutoff=5.5):

# get feature values
for res in ctc_res_with_pssm:
chain = {'A': 0, 'B': 1}[res[0]]
chain = {self.chain1: 0, self.chain2: 1}[res[0]]
key = tuple([chain] + xyz_dict[res])
for name, value in zip(self.feature_names, self.pssm[res]):
# Make sure the feature_names and pssm[res] have
Expand All @@ -233,7 +238,17 @@ def get_feature_value(self, cutoff=5.5):
########################################################################


def __compute_feature__(pdb_data, featgrp, featgrp_raw, out_type='pssmvalue'):
def __compute_feature__(pdb_data, featgrp, featgrp_raw, chain1, chain2, out_type='pssmvalue'):
"""Main function called in deeprank for the feature calculations.
Args:
pdb_data (list(bytes)): pdb information
featgrp (str): name of the group where to save xyz-val data
featgrp_raw (str): name of the group where to save human readable data
chain1 (str): First chain ID
chain2 (str): Second chain ID
out_type (str): which feature to generate, 'pssmvalue' or 'pssmic'.
"""

if config.PATH_PSSM_SOURCE is None:
raise FileExistsError(f"No available PSSM source, "
Expand All @@ -244,7 +259,8 @@ def __compute_feature__(pdb_data, featgrp, featgrp_raw, out_type='pssmvalue'):
mol_name = os.path.split(featgrp.name)[0]
mol_name = mol_name.lstrip('/')

pssm = FullPSSM(mol_name, pdb_data, path, out_type=out_type)
pssm = FullPSSM(mol_name, pdb_data, chain1=chain1, chain2=chain2,
pssm_path=path, out_type=out_type)

# read the raw data
pssm.read_PSSM_data()
Expand Down
5 changes: 3 additions & 2 deletions deeprank/features/PSSM_IC.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ class PSSM_IC(FullPSSM):
#
##########################################################################

def __compute_feature__(pdb_data, featgrp, featgrp_raw):
def __compute_feature__(pdb_data, featgrp, featgrp_raw, chain1, chain2):

func(pdb_data, featgrp, featgrp_raw, out_type='pssmic')
func(pdb_data, featgrp, featgrp_raw, chain1=chain1, chain2=chain2,
out_type='pssmic')

##########################################################################
#
Expand Down
Loading

0 comments on commit 66fd866

Please sign in to comment.