# Fetch interactions from pdb

**procecc:**

1. Read pre-defined data
2. Re-arrange the same data by pdb vs. ligand
3. Fetch pdb header (via single thread 100 structure: 18 sec.)
4. Process header 
5. Return data in proper format

In [109]:
import json
from json import JSONDecodeError

import requests
from functools import reduce
import pandas as pd
from pyspark.sql.functions import (
    col, udf, struct, lit, split, expr, collect_set, struct, 
    regexp_replace, min as pyspark_min, explode, when,
    array_contains, count, first, element_at, size, sum as pyspark_sum, array, array_sort
)
from pyspark.sql.types import (
    FloatType, ArrayType, StructType, StructField, BooleanType, StringType, IntegerType
)
from pyspark.sql import SparkSession
from pyspark.conf import SparkConf
from collections import defaultdict
from pyspark.context import SparkContext

# establish spark connection
spark = (
    SparkSession.builder
    .master('local[*]')
    .getOrCreate()
)

In [6]:
data = '/Users/dsuveges/project/random_notebooks/issue-1891_extracting_drug-ligand_complex/molecules_w_targets/'

In [14]:
# Dataset witht all the details, produced earlier:
input_dataset = (
    spark.read.parquet(data)
    .persist()
)

# This dataset is grouped by pdb id to get all the compounds:
data_to_look_up = (
    input_dataset
    .filter(col('ensembl_gene_id').startswith('ENSG'))
    .groupby('pdb_structure_id')
    .agg(collect_set(col('pdb_compound_id')))
    .persist()
)

data_to_look_up.show()
print(data_to_look_up.count())

+----------------+----------------------------+
|pdb_structure_id|collect_set(pdb_compound_id)|
+----------------+----------------------------+
|            1avd|                       [BTN]|
|            1d5m|                       [ALC]|
|            1d6q|                       [GOL]|
|            1e9b|                       [ATM]|
|            1ere|                       [EST]|
|            1j3z|                       [CMO]|
|            1jan|                        [ZN]|
|            1ln2|                       [MSE]|
|            1lq8|                  [NDG, IPA]|
|            1ozj|                        [ZN]|
|            1qxe|                  [FUX, OXY]|
|            1raz|                        [ZN]|
|            1t2v|                       [SEP]|
|            1t9s|                   [ZN, 5GP]|
|            1y8q|                   [ATP, ZN]|
|            1ydb|                   [AZM, ZN]|
|            1yxu|                       [AMP]|
|            1z0f|                      

In [30]:


import urllib

def fetch_pdb(pdbId: str) -> None:
    """This function fetches a single PDB structure from PDBEurope
    
    Args:
        pdb_strcture_id: string, a single PDB structure identifier
        folder: string, a folder to save the structure
    """
    try:
        urllib.request.urlretrieve(f'https://www.ebi.ac.uk/pdbe/static/entry/download/{pdbId}.header', f'./{pdbId}.pdb')
    except:
        # logging.warning(f'Failed to fetch PDB structure: {pdb_strcture_id}')
        print(f'Failed to fetch PDB structure: {pdbId}')

    

@udf(IntegerType())
def test_parser(pdbId):
    try:
        with open(f'{pdbId}.pdb', 'rt') as f:
            lines = f.readlines()
    except:
        return 0

    binding_sites_lines = filter(lambda line: line.startswith('REMARK 800') and ':' in line, lines)
    return len([line for line in binding_sites_lines])


# Download data:
(
    data_to_look_up
    .limit(1000)
    .toPandas()
    .pdb_structure_id
    .apply(fetch_pdb)
)

parsing_test = (
    data_to_look_up
    .limit(1000)
    .withColumn('test', test_parser(col('pdb_structure_id')))
    .persist()
)

Failed to fetch PDB structure: 4y14
Failed to fetch PDB structure: 7ozy
Failed to fetch PDB structure: 4cql
Failed to fetch PDB structure: 4xol
Failed to fetch PDB structure: 5dak
Failed to fetch PDB structure: 5e83
Failed to fetch PDB structure: 5em3
Failed to fetch PDB structure: 5epz
Failed to fetch PDB structure: 5gli
Failed to fetch PDB structure: 5hou
Failed to fetch PDB structure: 5j9x
Failed to fetch PDB structure: 5l60
Failed to fetch PDB structure: 6tks
Failed to fetch PDB structure: 6wjd
Failed to fetch PDB structure: 6zel
Failed to fetch PDB structure: 6yjv
Failed to fetch PDB structure: 6ts0
Failed to fetch PDB structure: 7bfa
Failed to fetch PDB structure: 6t8v
Failed to fetch PDB structure: 6wov
Failed to fetch PDB structure: 7l20
Failed to fetch PDB structure: 4ug0
Failed to fetch PDB structure: 7ai1


In [33]:
(
    parsing_test
    .filter(col('test') == 0)
    .show(truncate=False)
)

+----------------+-------------------------------------------------------+----+
|pdb_structure_id|collect_set(pdb_compound_id)                           |test|
+----------------+-------------------------------------------------------+----+
|1d5m            |[ALC]                                                  |0   |
|1d6q            |[GOL]                                                  |0   |
|1lq8            |[NDG, IPA]                                             |0   |
|1t2v            |[SEP]                                                  |0   |
|2b02            |[MSE]                                                  |0   |
|2qki            |[GOL]                                                  |0   |
|3jz2            |[GOL]                                                  |0   |
|4lbo            |[SIA, BGC]                                             |0   |
|4oc5            |[ZN]                                                   |0   |
|5czx            |[EDO]                 

In [40]:
test_data = [
    {
        "study_id": "study_1",
        "tag_variant_id": "var1",
        "lead_variant_id": "var1",
    },
    {
        "study_id": "study_1",
        "tag_variant_id": "var2",
        "lead_variant_id": "var1",
    },
    {
        "study_id": "study_1",
        "tag_variant_id": "var3",
        "lead_variant_id": "var1",
    },
    {
        "study_id": "study_1",
        "tag_variant_id": "var4",
        "lead_variant_id": "var4",
    },
    {
        "study_id": "study_1",
        "tag_variant_id": "var5",
        "lead_variant_id": "var4",
    },
    {
        "study_id": "study_1",
        "tag_variant_id": "var6",
        "lead_variant_id": "var6",
    },
    {
        "study_id": "study_1",
        "tag_variant_id": "var7",
        "lead_variant_id": "var6",
    },
    {
        "study_id": "study_1",
        "tag_variant_id": "var8",
        "lead_variant_id": "var6",
    },
#    
#
    {
        "study_id": "study_2",
        "tag_variant_id": "var9",
        "lead_variant_id": "var9",
    },
    {
        "study_id": "study_2",
        "tag_variant_id": "var2",
        "lead_variant_id": "var9",
    },
    {
        "study_id": "study_2",
        "tag_variant_id": "var3",
        "lead_variant_id": "var9",
    },
#
#
    {
        "study_id": "study_3",
        "tag_variant_id": "var4",
        "lead_variant_id": "var10",
    },
    {
        "study_id": "study_3",
        "tag_variant_id": "var5",
        "lead_variant_id": "var10",
    },
    {
        "study_id": "study_3",
        "tag_variant_id": "var10",
        "lead_variant_id": "var10",
    },
    {
        "study_id": "study_3",
        "tag_variant_id": "var11",
        "lead_variant_id": "var11",
    },
    {
        "study_id": "study_3",
        "tag_variant_id": "var2",
        "lead_variant_id": "var11",
    },
]

df = (
    spark.createDataFrame(test_data)
    .select('study_id', 'lead_variant_id', 'tag_variant_id')
    .persist()
)

df.show(40)

+--------+---------------+--------------+
|study_id|lead_variant_id|tag_variant_id|
+--------+---------------+--------------+
| study_1|           var1|          var1|
| study_1|           var1|          var2|
| study_1|           var1|          var3|
| study_1|           var4|          var4|
| study_1|           var4|          var5|
| study_1|           var6|          var6|
| study_1|           var6|          var7|
| study_1|           var6|          var8|
| study_2|           var9|          var9|
| study_2|           var9|          var2|
| study_2|           var9|          var3|
| study_3|          var10|          var4|
| study_3|          var10|          var5|
| study_3|          var10|         var10|
| study_3|          var11|         var11|
| study_3|          var11|          var2|
+--------+---------------+--------------+



In [28]:
columns = ['study_id', 'lead_variant_id']

# Creating the two dataset to be joined:
left_df = reduce(lambda DF, col: DF.withColumnRenamed(col, col+'_left'), columns, df)
right_df = reduce(lambda DF, col: DF.withColumnRenamed(col, col+'_right'), columns, df)


overlapping_signals = (
    left_df
    .join(right_df, on='tag_variant_id', how='inner')
    .filter(col('study_id_right') != col('study_id_left'))
    .filter(col('lead_variant_id_left') > col('lead_variant_id_right'))
    .drop('tag_variant_id')
    .distinct()
    .persist()
)

overlapping_signals.show()

+--------------------+-------------+---------------------+--------------+
|lead_variant_id_left|study_id_left|lead_variant_id_right|study_id_right|
+--------------------+-------------+---------------------+--------------+
|                var9|      study_2|                 var1|       study_1|
|                var4|      study_1|                var10|       study_3|
|               var11|      study_3|                 var1|       study_1|
|                var9|      study_2|                var11|       study_3|
+--------------------+-------------+---------------------+--------------+



In [32]:
(
    overlapping_signals
    .join(left_df.withColumnRenamed('tag_variant_id', 'tag_variant_id_left'), on='lead_variant_id_left', how='inner')
    .join(right_df.withColumnRenamed('tag_variant_id', 'tag_variant_id_right'), on='lead_variant_id_right', how='inner')
    .show(30)
)

+---------------------+--------------------+-------------+--------------+-------------+-------------------+--------------+--------------------+
|lead_variant_id_right|lead_variant_id_left|study_id_left|study_id_right|study_id_left|tag_variant_id_left|study_id_right|tag_variant_id_right|
+---------------------+--------------------+-------------+--------------+-------------+-------------------+--------------+--------------------+
|                var10|                var4|      study_1|       study_3|      study_1|               var4|       study_3|               var10|
|                var10|                var4|      study_1|       study_3|      study_1|               var4|       study_3|                var5|
|                var10|                var4|      study_1|       study_3|      study_1|               var4|       study_3|                var4|
|                var10|                var4|      study_1|       study_3|      study_1|               var5|       study_3|              

In [36]:
    (
        pd.read_csv('http://ftp.ebi.ac.uk/pub/databases/gwas/summary_statistics/GCST90082001-GCST90083000/GCST90082654/GCST90082654_buildGRCh38.tsv.gz', sep='\t', compression='infer')
        .Model.unique()
    )

array(['ADD-WGR-FIRTH'], dtype=object)

In [39]:
(
    spark.read.parquet('/Users/dsuveges/project_data/molecule/')
    .filter(col('inchiKey') == 'WHMQZCPGFZBLBG-UHFFFAOYSA-N')
    .show()
)

+---+---------------+--------+--------+---------------+----+-------------------+-------------------------+--------+----------------+----------+---------------+----------+--------+---------------+--------------+-------------+--------------+-----------+
+---+---------------+--------+--------+---------------+----+-------------------+-------------------------+--------+----------------+----------+---------------+----------+--------+---------------+--------------+-------------+--------------+-----------+
+---+---------------+--------+--------+---------------+----+-------------------+-------------------------+--------+----------------+----------+---------------+----------+--------+---------------+--------------+-------------+--------------+-----------+



In [45]:
(
    spark.createDataFrame([
        {'s_l': 'study1', 't_l':['v1', 'v2', 'v3'], 's_r': 'study2', 't_r':['v1', 'v4', 'v5', 'v3']}
    ])
    .withColumn(collect_set([col('t_r'), col('t_l')]).alias('variants'))
    .show()
)

TypeError: Invalid argument, not a string or column: [Column<'t_r'>, Column<'t_l'>] of type <class 'list'>. For column literals, use 'lit', 'array', 'struct' or 'create_map' function.

In [97]:
from pyspark.sql.functions import sum as spark_sum, explode, max as pyspark_max, expr, collect_list, exp
from pyspark.sql.window import Window

df = (
    spark
    .createDataFrame(
        [(1, [1, 2, 3, 4]), (2, [3, 4, 5])],
        ("key", "values"))
    .withColumn("values", explode("values"))
)

#Df mormalised no-log
w1 = Window.partitionBy("key")

df.show()

+---+------+
|key|values|
+---+------+
|  1|     1|
|  1|     2|
|  1|     3|
|  1|     4|
|  2|     3|
|  2|     4|
|  2|     5|
+---+------+



In [67]:
df_norm_nonLog = (
    df
    .withColumn("x/max(x)", (col("values") / pyspark_max(col("values")).over(w1)) / pyspark_sum(col("values") / pyspark_max(col("values")).over(w1)).over(w1))
#     .withColumn("x/max(x)", col("values") / col("max"))
#     .withColumn("sum(x/max(x))", pyspark_sum(col("x/max(x)")).over(w1))
#     .withColumn("x/max(x)/sum(x/max(x))", col("x/max(x)") / col("sum(x/max(x))"))   
    .show()
)

+---+------+-------------------+
|key|values|           x/max(x)|
+---+------+-------------------+
|  1|     1|                0.1|
|  1|     2|                0.2|
|  1|     3|                0.3|
|  1|     4|                0.4|
|  2|     3|               0.25|
|  2|     4|0.33333333333333337|
|  2|     5| 0.4166666666666667|
+---+------+-------------------+



In [93]:
import math

# norm1 <- function(x, log = FALSE) {
#   if (all(is.na(x))) return(x)
#   if (log) {
#     x <- x - max(x, na.rm = TRUE)
#     x <- exp(x)    
#   } else {
#     ## This does not work if x contains NaNs or +Infs
#     stopifnot(all(x >= 0, na.rm = TRUE))
#     x <- x / max(x, na.rm = TRUE)
#   }
#   return(x / sum(x, na.rm = TRUE))
# }

@udf(ArrayType(FloatType()))
def normLog (a: list) -> list:
    a = [math.exp(x - max(a)) for x in a]
    a = [x/sum(a) for x in a]
    
    return a
    
    
(
    df
    .groupBy('key')
    .agg(
        normLog(collect_list(col('values'))).alias('norm')
    )
    .withColumn('norm', explode("norm"))
    .show()
)

# df.show()

+---+-----------+
|key|       norm|
+---+-----------+
|  1|0.032058604|
|  1|0.087144315|
|  1| 0.23688282|
|  1|  0.6439143|
|  2| 0.09003057|
|  2| 0.24472848|
|  2| 0.66524094|
+---+-----------+



In [102]:
(
    df
    .withColumn("logDifMax", exp(col('values') - pyspark_max(col("values")).over(w1)))
    .withColumn('norm', col("logDifMax")/pyspark_sum(col("logDifMax")).over(w1))
    .show()
)

+---+------+--------------------+-------------------+
|key|values|           logDifMax|               norm|
+---+------+--------------------+-------------------+
|  1|     1|0.049787068367863944|0.03205860328008499|
|  1|     2|  0.1353352832366127|0.08714431874203257|
|  1|     3| 0.36787944117144233|0.23688281808991013|
|  1|     4|                 1.0| 0.6439142598879724|
|  2|     3|  0.1353352832366127|0.09003057317038046|
|  2|     4| 0.36787944117144233|0.24472847105479764|
|  2|     5|                 1.0| 0.6652409557748218|
+---+------+--------------------+-------------------+



In [104]:
df = spark.createDataFrame([{
    "study_id": "s1",
    "lead_id": "l1",
},
{
    "study_id": "s1",
    "lead_id": "l2",

},
{
    "study_id": "s2",
    "lead_id": "l1",
}]).persist()

df.show()

+-------+--------+
|lead_id|study_id|
+-------+--------+
|     l1|      s1|
|     l2|      s1|
|     l1|      s2|
+-------+--------+



In [114]:
(
    df.withColumnRenamed('study_id', 'study_id_l')
    .join(df.withColumnRenamed('study_id', 'study_id_r'), on='lead_id', how='inner')
    .filter(col('study_id_r') != col('study_id_l'))
    .withColumn('unique_test', array_sort(array(col('study_id_l'), col('study_id_r'))))
    .groupBy('unique_test')
    .agg(
        *[first(c).alias(c) for c in ['study_id_r', 'study_id_l', 'lead_id']]
    )
    .show()
)

+-----------+----------+----------+-------+
|unique_test|study_id_r|study_id_l|lead_id|
+-----------+----------+----------+-------+
|   [s1, s2]|        s2|        s1|     l1|
+-----------+----------+----------+-------+



In [113]:
df.columns

['lead_id', 'study_id']