In [213]:
import pandas as pd

# Input list: a headerless, comma-separated file holding a single column of PDB ids.
pdb_list_file = 'pdb_list.txt'
# Parquet dataset (molecules with their targets) produced by an earlier step.
# NOTE(review): absolute, user-specific path; consider making it configurable.
data = '/Users/dsuveges/project/random_notebooks/issue-1891_extracting_drug-ligand_complex/molecules_w_targets/'

pdb_ids = pd.read_csv(
    pdb_list_file,
    header=None,
    names=['pdbId'],
    sep=',',
)
for summary in (pdb_ids.head(), len(pdb_ids)):
    print(summary)


  pdbId
0  13gs
1  1avd
2  1b86
3  1bzm
4  1bzs
988


In [214]:
import json
import os
from collections import defaultdict
from functools import reduce
from json import JSONDecodeError

import pandas as pd
import requests
from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col, udf, struct, lit, split, expr, collect_set,
    regexp_replace, min as pyspark_min, explode, when,
    array_contains, count, first, element_at, size, sum as pyspark_sum, array
)
from pyspark.sql.types import (
    FloatType, ArrayType, StructType, StructField, BooleanType, StringType, IntegerType
)

# Local Spark session using every available core.
spark = SparkSession.builder.master('local[*]').getOrCreate()

# Full dataset with all the details, produced by an earlier step; kept in memory.
input_dataset = spark.read.parquet(data).persist()

# Keep rows whose Ensembl gene id starts with 'ENSG' (presumably human genes),
# then gather the distinct compound ids observed for each PDB structure.
ensg_rows = input_dataset.filter(col('ensembl_gene_id').startswith('ENSG'))
compound_set = collect_set(col('pdb_compound_id')).alias('coumpound_ids')
data_to_look_up = (
    ensg_rows
    .groupby('pdb_structure_id')
    .agg(compound_set)
    .persist()
)

data_to_look_up.show()
print(data_to_look_up.count())

+----------------+-------------+
|pdb_structure_id|coumpound_ids|
+----------------+-------------+
|            1avd|        [BTN]|
|            1d5m|        [ALC]|
|            1d6q|        [GOL]|
|            1e9b|        [ATM]|
|            1ere|        [EST]|
|            1j3z|        [CMO]|
|            1jan|         [ZN]|
|            1ln2|        [MSE]|
|            1lq8|   [NDG, IPA]|
|            1ozj|         [ZN]|
|            1qxe|   [FUX, OXY]|
|            1raz|         [ZN]|
|            1t2v|        [SEP]|
|            1t9s|    [ZN, 5GP]|
|            1y8q|    [ATP, ZN]|
|            1ydb|    [AZM, ZN]|
|            1yxu|        [AMP]|
|            1z0f|        [GDP]|
|            1z89|   [62P, NAP]|
|            2b02|        [MSE]|
+----------------+-------------+
only showing top 20 rows

27315


In [215]:
# Bring the per-structure compound sets into pandas under the same key name
# as pdb_ids, then right-join so every requested PDB id is kept; ids without
# compounds end up with a missing value in 'coumpound_ids'.
compounds_by_structure = (
    data_to_look_up
    .withColumnRenamed('pdb_structure_id', 'pdbId')
    .toPandas()
)
pdb_w_compound = compounds_by_structure.merge(pdb_ids, on='pdbId', how='right')

print(pdb_w_compound.head())
print(len(pdb_w_compound))

  pdbId    coumpound_ids
0  13gs  [SAS, MES, GSH]
1  1avd            [BTN]
2  1b86            [OXY]
3  1bzm             [ZN]
4  1bzs             [ZN]
988


In [216]:
from plip.structure.preparation import PDBComplex
from plip.exchange.report import BindingSiteReport
from plip.basic import config


class GetPDB:
    """Fetch PDB-format structure files from the PDBe server with an on-disk cache.

    Files are stored as ``pdb<id>.ent`` inside ``data_folder``. A cached copy
    is always preferred over a network request.
    """

    PDB_URL = 'https://www.ebi.ac.uk/pdbe/entry-files/download/pdb{}.ent'
    # Seconds to wait for the server before giving up on a download.
    REQUEST_TIMEOUT = 60

    def __init__(self, data_folder: str) -> None:
        """
        Args:
            data_folder (str): directory used as the local file cache.
        """
        self.data_folder = data_folder

    def _cache_path(self, pdb_structure_id: str) -> str:
        """Return the cache file location for the given structure id."""
        return os.path.join(self.data_folder, f'pdb{pdb_structure_id}.ent')

    def get_pdb(self, pdb_structure_id: str) -> str:
        """Read a structure from the cache, or fetch and cache it if not found.

        Args:
            pdb_structure_id (str): PDB identifier, e.g. '1avd'.
        Returns:
            The file content as a string, or '' if it could not be fetched.
        """
        cache_path = self._cache_path(pdb_structure_id)
        try:
            with open(cache_path, 'rt') as f:
                return f.read()
        except FileNotFoundError:
            pass

        data = self.fetch_pdb(pdb_structure_id)

        # Only cache real content: an empty file would be read back on every
        # later call, turning a transient download failure into a permanent one.
        if data:
            if self.data_folder:
                os.makedirs(self.data_folder, exist_ok=True)
            with open(cache_path, 'wt') as f:
                f.write(data)

        return data

    def fetch_pdb(self, pdb_structure_id: str) -> str:
        """Download a structure file from the PDBe server as a string.

        Args:
            pdb_structure_id (str): PDB identifier, e.g. '1avd'.
        Returns:
            The PDB-format file content, or '' when the id is empty, the
            request fails or times out, or the server answers with an HTTP
            error status.
        """
        if not pdb_structure_id:
            return ''

        try:
            response = requests.get(
                self.PDB_URL.format(pdb_structure_id),
                timeout=self.REQUEST_TIMEOUT,
            )
            # The body of an HTTP error response (e.g. a 404 page) is not structure data.
            response.raise_for_status()
        except (requests.exceptions.RequestException, ConnectionError):
            # requests raises its own exception hierarchy, which does not derive
            # from the builtin ConnectionError, so both are caught here.
            return ''

        return response.text


def run_plip(row, pdb_fetcher=None):
    """Report how far the fetch/parse pipeline gets for a single structure.

    Diagnostic helper: it does not characterise any interaction. It only checks
    that the structure can be fetched and loaded by plip.

    Args:
        row: (pdb_structure_id, compound_ids) pair, e.g. a two-column row of
            ``pdb_w_compound``. The compound ids are not used.
        pdb_fetcher: object with a ``get_pdb(pdb_id) -> str`` method. Defaults
            to the module-level ``gpdb`` instance.
    Returns:
        str: 'failed_to_fetch.', 'returned empty', 'parsed alright' or
        'parsing failed'.
    """
    (structure, _drugs) = row

    try:
        # The default fetcher is resolved lazily, so a missing global is
        # reported as a fetch failure (as before) rather than raised.
        fetcher = gpdb if pdb_fetcher is None else pdb_fetcher
        pdb = fetcher.get_pdb(structure)
    except Exception:
        return 'failed_to_fetch.'

    if not pdb:
        return 'returned empty'

    protlig = PDBComplex()

    try:
        protlig.load_pdb(pdb, as_string=True)  # load the pdb file
    except Exception:
        return 'parsing failed'
    return 'parsed alright'


def parse_interaction(interaction: tuple, compound_id: str, pdb_id: str) -> dict:
    """Flatten one plip interaction record into a plain dictionary.

    Args:
        interaction: a single record from a plip interaction set's
            ``all_itypes``. Presumably a namedtuple whose type name is the
            interaction kind (e.g. 'hbond'), carrying ``resnr``, ``restype``
            and ``reschain`` fields.
        compound_id (str): ligand (het) identifier the interaction belongs to.
        pdb_id (str): PDB structure identifier.
    Returns:
        dict with the structure id, compound id, interaction type and the
        interacting protein residue. Water bridges are skipped and yield an
        empty dict.
    """
    # A namedtuple class carries the name it was created with, e.g. 'hbond'.
    interaction_type = type(interaction).__name__

    if interaction_type == 'waterbridge':
        return {}

    # Parsing data from the interaction:
    return {
        'pdb_structure_id': pdb_id,
        'compound_id': compound_id,
        'interaction_type': interaction_type,
        'prot_residue_number': interaction.resnr,
        'prot_residue_type': interaction.restype,
        'prot_chain_id': interaction.reschain
    }

def characerize_complex(row: tuple) -> list:
    """Compute plip protein-ligand interactions for the compounds of interest.

    Args:
        row: (pdb_id, compounds) pair, i.e. a PDB structure id and the het ids
            of the ligands to characterise. ``compounds`` is list-like, but can
            be missing (NaN/None) for ids that gained no compounds in the right
            merge that builds ``pdb_w_compound``.
    Returns:
        list of dicts produced by ``parse_interaction``, one per interaction.
        Empty when there are no compounds or the structure could not be fetched.
    """
    pdb_id, compounds = row

    # Nothing to characterise: skip the download entirely.
    if not pd.api.types.is_list_like(compounds) or len(compounds) == 0:
        return []

    # Get pdb data (cached on disk by the module-level fetcher):
    pdb_data = gpdb.get_pdb(pdb_id)
    if not pdb_data:
        return []

    # Load into plip:
    mol_complex = PDBComplex()
    mol_complex.load_pdb(pdb_data, as_string=True)

    # Characterise only the ligands we are interested in; this populates
    # mol_complex.interaction_sets.
    wanted = set(compounds)
    for ligand in mol_complex.ligands:
        if ligand.hetid in wanted:
            mol_complex.characterize_complex(ligand)

    # Interaction-set keys are colon-separated; the first field is taken as the compound id.
    return [
        parse_interaction(interaction, compound.split(':')[0], pdb_id)
        for compound, interaction_set in mol_complex.interaction_sets.items()
        for interaction in interaction_set.all_itypes
    ]


In [218]:
import dask.dataframe as dd

# On-disk cache of downloaded structures; read by characerize_complex.
gpdb = GetPDB(data_folder='pdbs/')

ddf = dd.from_pandas(pdb_w_compound, npartitions=30)

# Partitions are processed in separate worker processes. Each row yields a
# Python list of interaction dicts, so the new column's dtype is object
# (the previous 'f8' meta misdeclared it as float).
res_df = (
    ddf
    .assign(
        new_col=ddf.map_partitions(
            lambda df: df.apply(characerize_complex, axis=1),
            meta=(None, 'object'),
        )
    )
    .compute(scheduler='processes')
)

res_df.head()


Unnamed: 0,pdbId,coumpound_ids,new_col
0,13gs,"[SAS, MES, GSH]","[{'pdb_structure_id': '13gs', 'compound_id': '..."
1,1avd,[BTN],"[{'pdb_structure_id': '1avd', 'compound_id': '..."
2,1b86,[OXY],[]
3,1bzm,[ZN],[]
4,1bzs,[ZN],"[{'pdb_structure_id': '1bzs', 'compound_id': '..."


In [219]:
from itertools import chain

# Flatten the per-structure interaction lists into a single table, dropping
# the empty dicts that parse_interaction returns for skipped water bridges.
interactions = pd.DataFrame([
    record
    for record in chain.from_iterable(res_df.new_col)
    if record != {}
])
interactions



Unnamed: 0,pdb_structure_id,compound_id,interaction_type,prot_residue_number,prot_residue_type,prot_chain_id
0,13gs,GSH,saltbridge,13,ARG,A
1,13gs,GSH,saltbridge,44,LYS,A
2,13gs,GSH,hbond,65,SER,A
3,13gs,GSH,hbond,51,GLN,A
4,13gs,GSH,hbond,52,LEU,A
...,...,...,...,...,...,...
65,1avd,BTN,hydroph_interaction,70,TRP,B
66,1bzs,ZN,metal_complex,147,HIS,A
67,1bzs,ZN,metal_complex,149,ASP,A
68,1bzs,ZN,metal_complex,162,HIS,A
