In [114]:
import json
from json import JSONDecodeError

import requests
from functools import reduce
import pandas as pd
from pyspark.conf import SparkConf

from pyspark.sql.functions import (
    col, udf, struct, lit, split, expr, collect_set, struct, 
    regexp_replace, min as pyspark_min, explode, when,
    array_contains, count, first, element_at, size, sum as pyspark_sum
)
from pyspark.sql.types import FloatType, ArrayType, StructType, StructField, BooleanType, StringType
from pyspark.sql import SparkSession, DataFrame
from collections import defaultdict

# Establish a local Spark session.
# Memory settings are plain string literals: there is nothing to interpolate,
# so no f-string prefix is needed.
spark_conf = (
    SparkConf()
    .set('spark.driver.memory', '8g')
    .set('spark.executor.memory', '8g')
    # '0' removes the cap on the size of results collected to the driver.
    .set('spark.driver.maxResultSize', '0')
    .set('spark.debug.maxToStringFields', '2000')
    .set('spark.sql.execution.arrow.maxRecordsPerBatch', '500000')
)
spark = (
    SparkSession.builder
    .config(conf=spark_conf)
    # Run locally, using all available cores.
    .master('local[*]')
    .config("spark.driver.bindAddress", "127.0.0.1")
    .getOrCreate()
)


In [2]:
# Load the variant index and keep it cached for the exploratory cells that follow.
vi = spark.read.parquet('/Users/dsuveges/project_data/variant-index').persist()

vi.printSchema()
# show() prints the rows itself and returns None, so it must not be wrapped in print().
vi.show(n=1, truncate=False, vertical=True)

root
 |-- chr_id: string (nullable = true)
 |-- position: integer (nullable = true)
 |-- ref_allele: string (nullable = true)
 |-- alt_allele: string (nullable = true)
 |-- chr_id_b37: string (nullable = true)
 |-- position_b37: integer (nullable = true)
 |-- rs_id: string (nullable = true)
 |-- most_severe_consequence: string (nullable = true)
 |-- cadd: struct (nullable = true)
 |    |-- raw: double (nullable = true)
 |    |-- phred: double (nullable = true)
 |-- af: struct (nullable = true)
 |    |-- gnomad_afr: double (nullable = true)
 |    |-- gnomad_amr: double (nullable = true)
 |    |-- gnomad_asj: double (nullable = true)
 |    |-- gnomad_eas: double (nullable = true)
 |    |-- gnomad_fin: double (nullable = true)
 |    |-- gnomad_nfe: double (nullable = true)
 |    |-- gnomad_nfe_est: double (nullable = true)
 |    |-- gnomad_nfe_nwe: double (nullable = true)
 |    |-- gnomad_nfe_onf: double (nullable = true)
 |    |-- gnomad_nfe_seu: double (nullable = true)
 |    |-- gnoma

In [3]:
# Total number of rows in the variant index.
vi.count()

72878709

In [4]:
# Variants that carry an rsID; reused by the three summaries below.
vi_with_rsid = vi.filter(col('rs_id').isNotNull())

# Number of rows with an rsID, followed by the number of distinct rsIDs:
print(vi_with_rsid.count())
print(vi_with_rsid.select(col('rs_id')).distinct().count())

# rsIDs whose rows point to more than one gene in `gene_id_any`.
# show() prints the table itself and returns None, so it is not wrapped in print().
(
    vi_with_rsid
    .groupby('rs_id')
    .agg(
        count(col('rs_id')).alias('count'),
        collect_set(col('most_severe_consequence')).alias('most_severe_consequences'),
        collect_set(col('gene_id_any')).alias('gene_id_any'),
        collect_set(col('gene_id_prot_coding')).alias('gene_id_prot_coding')
    )
    .filter(size(col('gene_id_any')) > 1)
    .show()
)

72056558
66727621
+-----+-----+------------------------+-----------+-------------------+
|rs_id|count|most_severe_consequences|gene_id_any|gene_id_prot_coding|
+-----+-----+------------------------+-----------+-------------------+
+-----+-----+------------------------+-----------+-------------------+

None


In [11]:
# Look at the index rows of a single rsID; the nested `af` and `cadd`
# structs are dropped to keep the vertical printout short.
vi.filter(col('rs_id') == 'rs1491385527').drop('af', 'cadd').show(
    n=2,
    truncate=False,
    vertical=True,
)

-RECORD 0------------------------------------------
 chr_id                       | 18                 
 position                     | 50330765           
 ref_allele                   | T                  
 alt_allele                   | TATATATATATATATATA 
 chr_id_b37                   | 18                 
 position_b37                 | 47857135           
 rs_id                        | rs1491385527       
 most_severe_consequence      | intergenic_variant 
 gene_id_any_distance         | 42926              
 gene_id_any                  | ENSG00000154832    
 gene_id_prot_coding_distance | 42926              
 gene_id_prot_coding          | ENSG00000154832    
-RECORD 1------------------------------------------
 chr_id                       | 18                 
 position                     | 50330765           
 ref_allele                   | TTTTTTTTTTACTGG    
 alt_allele                   | T                  
 chr_id_b37                   | 18                 
 position_b3

In [7]:
import json
# JSON text is UTF-8; stating the encoding avoids depending on the platform default.
with open('/Users/dsuveges/Downloads/vep-annotated-complex.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

In [29]:
import pandas as pd

# One row per key of `data`: the key goes to `variant`, its value to `genes`.
df = (
    spark
    .createDataFrame(
        pd.DataFrame({"variant": data.keys(), "genes": data.values()})
    )
    .persist()
)

df.show(truncate=False)

+-------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|variant                                                      |genes                                                                                                                                                                                      |
+-------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|NC_000001.11:g.20810811_20829210del                          |[ENSG00000075151]                                                                                                                                                                    

In [13]:
# Row count, then the number of distinct variants, one value per line.
print(
    df.count(),
    df.select('variant').distinct().count(),
    sep='\n',
)

6731
6731


In [14]:
# Number of distinct genes across all variants.
df.select(explode('genes')).distinct().count()

10448

In [30]:
# Variants annotated with at most two genes ...
print(df.filter(size('genes') <= 2).count())

# ... and the number of distinct genes those variants cover.
print(
    df.filter(size('genes') <= 2)
    .select(explode('genes'))
    .distinct()
    .count()
)

4707
1869


In [26]:
# Variants annotated with at most five genes ...
print(df.filter(size('genes') <= 5).count())

# ... and the number of distinct genes those variants cover.
print(
    df.filter(size('genes') <= 5)
    .select(explode('genes'))
    .distinct()
    .count()
)

6123
3515


In [1]:
from collections import defaultdict
from functools import reduce
import re
import requests

import pandas as pd

from pyspark.conf import SparkConf
from pyspark.sql.functions import (
    col, split, expr, lit, struct, element_at
)
from pyspark.sql import SparkSession, DataFrame

# Spark configuration: each option is applied with its own statement
# (SparkConf.set updates the configuration object in place).
spark_conf = SparkConf()
spark_conf.set('spark.driver.memory', '8g')
spark_conf.set('spark.executor.memory', '8g')
spark_conf.set('spark.driver.maxResultSize', '0')
spark_conf.set('spark.debug.maxToStringFields', '2000')
spark_conf.set('spark.sql.execution.arrow.maxRecordsPerBatch', '500000')

# Local session using every available core, bound to the loopback address.
spark = (
    SparkSession.builder.config(conf=spark_conf)
    .master('local[*]')
    .config("spark.driver.bindAddress", "127.0.0.1")
    .getOrCreate()
)

def process_interacting_residues(sites_lines: zip) -> DataFrame:
    '''Build a long-format table of the residues listed in PDB SITE records.

    Considerations:
    * Interacting residues are listed behind SITE tokens
    * Binding sites are identified by their AC## site identifier.
    * Only up to four residues are listed in one row.
    * For each interacting residues the residue name, chainId and residueNo provided
    * For whatever reason the chain and the residue number is concatenated for interacting HOH
    - HOH interactions are dropped.
    - Interactions with other ligands are kept.

    Args:
        sites_lines: iterable of SITE lines. Any iterable of strings is
            accepted; it is consumed once.

    Returns:
        Persisted dataframe with the string columns siteId, residue, chain and
        residueNo, one row per interacting residue.
    '''
    # For each interacting residue we have the following fields:
    fields = ['residue', 'chain', 'residueNo']

    # A SITE row has room for four residues:
    slots = range(1, 5)

    # The schema is spelled out instead of being inferred from the parsed rows.
    # Rows with fewer than four residues yield dictionaries without the trailing
    # keys; with inference the `*_4` (or `*_3`, ...) columns would not exist at all
    # if no row of the entry happened to be full, and an empty input could not be
    # converted. With an explicit schema missing keys simply become nulls.
    column_names = ['token', 'rowNumber', 'siteId', 'interactorCount'] + [
        f'{field}_{index}' for index in slots for field in fields
    ]
    schema = ', '.join(f'{name} string' for name in column_names)

    # Process each line and create a dataframe:
    # How one line looks like:
    # SITE     2 AC1 16 TRP A 760  TYR A 813  ILE A 825  GLU A 826
    sites = spark.createDataFrame(
        [parsing_SITE_row(line) for line in sites_lines],
        schema=schema
    )

    # The shape of the dataframe is not good, we have to collate the three
    # fields of every residue slot in a single struct column (_1 .. _4):
    expressions = [
        (f'_{index}', struct([col(f'{field}_{index}').alias(field) for field in fields]))
        for index in slots
    ]
    res_df = reduce(lambda DF, value: DF.withColumn(*value), expressions, sites)

    # Stack the previously generated columns, turning the four slots into rows:
    unpivot_expression = f'''stack(4, {", ".join([f"'_{index}', _{index}" for index in slots])} ) as (index, interaction)'''

    return (
        res_df

        # Unpivot:
        .select('siteId', expr(unpivot_expression))

        # Extracting columns:
        .select(
            'siteId',
            col('interaction.residue').alias('residue'),
            col('interaction.chain').alias('chain'),
            col('interaction.residueNo').alias('residueNo')
        )
        # Removing water and empty slots:
        .filter(
            (col('residue') != 'HOH') & (col('residue').isNotNull())
        )
        .persist()
    )

def parsing_SITE_row(row: str) -> dict:
    """ Parsing the PDB SITE token data
    Rows look like this:
    SITE     3 AC1 14 ASN A 981  ILE A 982  LEU A 983  GLY A 993
    SITE     4 AC1 14 ASP A 994  HOH A4137
    Be aware how the HOH residues are annotated!
    Also not all position is filled!

    When the residue number fills its whole column it is glued to the chain
    identifier (A4137). Such tokens are split after the first character. This
    also covers negative residue numbers (A-100) and lower-case chain
    identifiers (b1001), which would otherwise shift every following field.

    Args:
        row: one SITE line.

    Returns:
        Dictionary mapping field names to the string values found in the row.
        Only the fields present in the row are included, so an empty row
        gives an empty dictionary.
    """

    # Field names for the full row listing all fields up to 4 interacting residues:
    field_names = [
        'token',  # SITE
        'rowNumber',  # Number of SITE rows for the given siteId
        'siteId',  # AC## the code identifying the binding site
        'interactorCount',  # The number of interacting residues for the given binding site

        # The following 3 fields are repeated for 4 interacting residues:
        'residue_1',
        'chain_1',
        'residueNo_1',

        'residue_2',
        'chain_2',
        'residueNo_2',

        'residue_3',
        'chain_3',
        'residueNo_3',

        'residue_4',
        'chain_4',
        'residueNo_4',
    ]

    fields = []
    for word in row.split():
        # Concatenated chain + residue number: A1102 -> A, 1102
        # The length check keeps three character tokens (e.g. the AC1 site id) intact.
        if len(word) > 3 and re.match(r'[A-Za-z]+-?[0-9]+', word):
            fields.extend((word[0], word[1:]))
        else:
            fields.append(word)

    # zip stops at the shorter sequence, so only the fields present are named:
    return dict(zip(field_names, fields))


def parsing_binding_sites(binding_sites_lines: filter) -> DataFrame:
    '''Binding sites are annotated after the REMARK 800 token.
    The rows processed one by one, then converted to a spark dataframe

    Args:
        binding_sites_lines: iterable of `REMARK 800 KEY: value` lines.

    Returns:
        Persisted dataframe with the columns siteId, compoundId, compoundChainId
        and compoundResidueNo. The last three are the last three space separated
        words of the SITE_DESCRIPTION value.
    '''

    # Collecting the annotation in a dictionary, one list of values per key:
    binding_sites_dict = defaultdict(list)
    for line in binding_sites_lines:
        # Only the first colon separates key from value, so a value that itself
        # contains a colon, or is empty, no longer breaks the unpacking.
        key, separator, value = line.strip().replace('REMARK 800 ', '').partition(':')

        # Lines without a colon carry no key/value annotation:
        if not separator:
            continue

        binding_sites_dict[key.strip()].append(value.strip())

    # Convert dictionary into a spark dataframe.
    # NOTE(review): pandas needs every key to have the same number of values,
    # i.e. every site must provide the same set of annotations - confirm for other entries.
    return (
        spark.createDataFrame(pd.DataFrame(binding_sites_dict))

        # Parse ligand information:
        .withColumn('SITE_DESCRIPTION', split(col('SITE_DESCRIPTION'), ' '))
        .withColumn('compoundId', element_at(col('SITE_DESCRIPTION'), -3))
        .withColumn('compoundChainId', element_at(col('SITE_DESCRIPTION'), -2))
        .withColumn('compoundResidueNo', element_at(col('SITE_DESCRIPTION'), -1))

        .drop('EVIDENCE_CODE', 'SITE_DESCRIPTION')
        .withColumnRenamed('SITE_IDENTIFIER', 'siteId')
        .persist()
    )


##
## Input:
##
pdbId = '5ddp'

##
## Fetch PDB header data:
##
url = f'https://www.ebi.ac.uk/pdbe/static/entry/download/{pdbId}.header'

# A timeout keeps the cell from hanging on a stalled connection, and the status
# check stops the run on an HTTP error instead of parsing an error page as a header.
data = requests.get(url, timeout=30)
data.raise_for_status()
lines = data.text.split('\n')

## 
## Extract annotation for binding sites from REMARK 800 token:
##
binding_sites_lines = filter(lambda line: line.startswith('REMARK 800') and ':' in line, lines)

# Get binding site list:
binding_sites = parsing_binding_sites(binding_sites_lines)
binding_sites.show()

##
## Get interacting residues from the SITE token:
##

# Interacting residues are stored after the SITE token:
sites_lines = filter(lambda line: line.startswith('SITE'), lines)
interacting_residues = process_interacting_residues(sites_lines)

interacting_residues.show()

##
## Joining binding site data + interacting residues together
##
(
    binding_sites
    .join(interacting_residues, on='siteId')
    .withColumn('pdbId', lit(pdbId))
    .show(100)
)


+------+----------+---------------+-----------------+
|siteId|compoundId|compoundChainId|compoundResidueNo|
+------+----------+---------------+-----------------+
|   AC1|       GLN|              A|              101|
|   AC2|        MG|              A|              102|
|   AC3|        MG|              A|              103|
|   AC4|        MG|              A|              104|
|   AC5|        MG|              A|              105|
|   AC6|        MG|              A|              106|
|   AC7|        NA|              A|              107|
|   AC8|        NA|              A|              108|
|   AC9|        NA|              A|              109|
|   AD1|        NA|              A|              110|
|   AD2|        NA|              A|              111|
|   AD3|       GLN|              B|              101|
|   AD4|        MG|              B|              102|
|   AD5|        MG|              B|              103|
|   AD6|        MG|              B|              104|
|   AD7|        MG|         

In [2]:
# Joined binding site / residue table narrowed down to the AD3 site.
(
    binding_sites.join(interacting_residues, on='siteId')
    .withColumn('pdbId', lit(pdbId))
    .filter(col('siteId') == 'AD3')
    .show(n=100)
)

+------+----------+---------------+-----------------+-------+-----+---------+-----+
|siteId|compoundId|compoundChainId|compoundResidueNo|residue|chain|residueNo|pdbId|
+------+----------+---------------+-----------------+-------+-----+---------+-----+
|   AD3|       GLN|              B|              101|      C|    B|        1| 5ddp|
|   AD3|       GLN|              B|              101|      G|    B|       22| 5ddp|
|   AD3|       GLN|              B|              101|      G|    B|       23| 5ddp|
|   AD3|       GLN|              B|              101|      A|    B|       24| 5ddp|
|   AD3|       GLN|              B|              101|      G|    B|       54| 5ddp|
|   AD3|       GLN|              B|              101|      C|    B|       58| 5ddp|
|   AD3|       GLN|              B|              101|      G|    B|       59| 5ddp|
|   AD3|       GLN|              B|              101|      C|    B|       60| 5ddp|
|   AD3|       GLN|              B|              101|     MG|    B|      102