In [2]:
%%configure -f
{
    "driverMemory": "16G", "driverCores": 8,
    "executorMemory": "8G", "executorCores": 6, "numExecutors": 2
}

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,User,Current session?
24,,pyspark,idle,,,,✔


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

SparkSession available as 'spark'.


ID,YARN Application ID,Kind,State,Spark UI,Driver log,User,Current session?
24,,pyspark,idle,,,,✔


In [3]:
from typing import List

from pyspark import SparkFiles
from subprocess import call
import sys


def install_deps(deps: List[str]) -> None:
    """Install ``deps`` with pip into the directory given by ``SparkFiles.getRootDirectory()``.

    Args:
        deps: pip requirement specifiers (e.g. ``'numpy'``). An empty list is a no-op.

    Raises:
        RuntimeError: if the pip subprocess exits with a non-zero status.
    """
    if not deps:
        # pip fails when invoked without any requirement, so there is nothing to do.
        return

    target_dir = SparkFiles.getRootDirectory()
    exit_code = call([sys.executable, '-m', 'pip', 'install', '-q', '-t', target_dir, *deps])
    if exit_code != 0:
        # The exit status used to be discarded, so a failed install only showed up
        # later as an ImportError in an unrelated cell.
        raise RuntimeError(f'pip install failed with exit code {exit_code} for: {", ".join(deps)}')


# Packages installed up front for the rest of the notebook (versions are not pinned).
required_packages = ['numpy', 'matplotlib', 'pandas', 'scipy', 'seaborn', 'statsmodels', 'pyarrow', 'pymongo']
install_deps(required_packages)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [4]:
from pyspark.sql import functions as F, types as T
import numpy as np
from scipy import stats
from statsmodels.sandbox.stats.multicomp import multipletests

@F.udf(T.ArrayType(T.DoubleType()))
def diff(A, B):
    """Element-wise absolute difference ``|A - B|``, returned as a plain Python list."""
    left = np.array(A)
    right = np.array(B)
    absolute_difference = np.abs(left - right)
    return absolute_difference.tolist()

@F.udf(T.DoubleType())
def var(A):
    """Variance of the values in ``A`` (``np.var`` with its defaults) as a Python float."""
    variance = np.var(A)
    return float(variance)

@F.udf(T.DoubleType())
def avg(A):
    """Arithmetic mean of the values in ``A`` (``np.mean``) as a Python float."""
    mean_value = np.mean(A)
    return float(mean_value)

@F.udf(T.DoubleType())
def mannwhiteneyu(ref, mod):
    """Two-sided Mann-Whitney U test p-value for the samples ``ref`` and ``mod``.

    Args:
        ref: array column value holding the first sample.
        mod: array column value holding the second sample.

    Returns:
        The p-value as a Python float, or None (SQL NULL) when either sample is
        NULL or empty, since the test is undefined for those inputs.
    """
    # Guard first: a NULL/empty array would otherwise raise inside the UDF and
    # fail the whole Spark task.
    if ref is None or mod is None or len(ref) == 0 or len(mod) == 0:
        return None

    result = stats.mannwhitneyu(np.array(ref), np.array(mod), alternative='two-sided')
    return float(result.pvalue)

@F.udf(T.DoubleType())
def bonferroni_correction(pvalues, alpha=0.05):
    """Mean of the Bonferroni-corrected p-values for ``pvalues``.

    Args:
        pvalues: array column value holding the raw p-values.
        alpha: family-wise error rate forwarded to ``multipletests``.

    Returns:
        The mean corrected p-value as a Python float, or None (SQL NULL) when
        ``pvalues`` is NULL or empty.
    """
    if pvalues is None or len(pvalues) == 0:
        return None

    # multipletests returns (reject, pvals_corrected, alphacSidak, alphacBonf);
    # only the corrected p-values are used here.
    pvals_corrected = multipletests(pvalues, alpha=alpha, method='bonferroni')[1]
    return float(np.mean(pvals_corrected))


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [5]:
from pyspark import SparkContext
from pyspark.sql import SparkSession

sc: SparkContext
spark: SparkSession

# Read the `project_configuration` collection of the `enhancer3d` database through
# the "mongodb" Spark data source.
# NOTE(review): the connection URI is presumably supplied by the Spark session config - confirm.
project_configuration_df = (
    spark.read
    .format("mongodb")
    .options(database="enhancer3d", collection="project_configuration")
    .load()
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [6]:
# File name of the Parquet file each row was read from (last path segment).
links_file_name = F.element_at(F.split(F.input_file_name(), "/"), -1)
# First "_"-separated token of the file name:
# GM12878_EP_hg38_liftovered.parquet -> GM12878
links_cell_line = F.element_at(F.split(links_file_name, "_"), 1)

links_df = (
    spark
    .read
    # `header` / `inferSchema` are CSV reader options; Parquet files carry their own
    # schema, so those no-op options were removed.
    .parquet("/work/data/links/")
    .withColumn("cell_line", links_cell_line)
    .alias("links")
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [113]:
%%pretty
# Preview the first five link rows.
links_df.show(5)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

enh_id,gene_id,pval,qval,cell_line
chr1:777020-778280,ENSG00000197049,1.49719,1.0,HFFC6
chr1:1020370-1022200,ENSG00000197049,0.668822,2.0,HFFC6
chr1:925800-925920,ENSG00000188976,3.024489,3.0,HFFC6
chr1:939460-941140,ENSG00000188976,2.660575,4.0,HFFC6
chr1:941310-942130,ENSG00000188976,2.350741,5.0,HFFC6


In [7]:
# One output row per (project, ensemble): `datasets.ensemble_id` is an array per
# project document, so it is exploded directly inside the projection.
ensembles_list_by_project_df = project_configuration_df.select(
    F.col('_id.project_id').alias('project_id'),
    F.explode(F.col('datasets.ensemble_id')).alias('ensemble_id'),
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [115]:
%%pretty
# Preview five rows without truncating the long ensemble ids.
ensembles_list_by_project_df.show(5, False)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

project_id,ensemble_id
8k_models_project_GM12878,models3D_GM12878_Deni_models3D_GM12878_Deni_mod_results_GM12878_Deni_chr7_54723172_57700542
8k_models_project_GM12878,models3D_GM12878_Deni_models3D_GM12878_Deni_ref2_results_GM12878_Deni_chr7_0_1069141
8k_models_project_GM12878,models3D_GM12878_Deni_models3D_GM12878_Deni_results_GM12878_Deni_chr12_5480623_8570102
8k_models_project_GM12878,models3D_GM12878_Nean_models3D_GM12878_Nean_mod_results_GM12878_Nean_chr3_127434292_130363324
8k_models_project_GM12878,models3D_GM12878_Deni_models3D_GM12878_Deni_mod_results_GM12878_Deni_chr21_12254928_14481319


In [8]:
# Cell lines whose `..._Nean_results...` ensembles are kept. The original filter
# listed GM12878 twice; the duplicate branch was redundant and has been dropped.
relevant_cell_lines = ['GM12878', 'HFFC6', 'H1ESC']


def _nean_results_filter(cell_lines):
    """OR together, per cell line, 'project matches AND ensemble id has the Nean results prefix'."""
    combined = None
    for cell_line in cell_lines:
        clause = (
            (F.col('project_id') == f'8k_models_project_{cell_line}')
            # startswith() instead of LIKE: in a LIKE pattern every `_` is a
            # single-character wildcard, which is looser than the literal prefix intended.
            & F.col('ensemble_id').startswith(
                f'models3D_{cell_line}_Nean_models3D_{cell_line}_Nean_results'
            )
        )
        combined = clause if combined is None else (combined | clause)
    return combined


all_relevant_ensembles_df = ensembles_list_by_project_df.where(
    _nean_results_filter(relevant_cell_lines)
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [117]:
%%pretty
# Number of rows in all_relevant_ensembles_df (count() runs a Spark job).
all_relevant_ensembles_df.count()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

1037

In [44]:
import pymongo
from pyspark.sql import Row
from typing import List

def load_mongo_batch(root_condition: str = 'or', projection: List[str] = None, structure: T.StructType = None):
    """Build a Spark UDF that fetches the MongoDB documents matching a batch of criteria.

    Args:
        root_condition: name of the MongoDB logical operator (without the leading
            ``$``) used to combine the per-row criteria, e.g. ``'or'``.
        projection: field names to include in the returned documents; ``None``
            returns whole documents.
        structure: Spark type of a single returned document; when omitted the
            documents are typed as ``map<string, string>``.

    Returns:
        A UDF that takes an array-of-struct column (each struct is one criterion,
        its field names being the MongoDB field paths) and returns an array of the
        matching documents.

    The connection is configured through the ``MONGO_URI``, ``MONGO_DATABASE`` and
    ``MONGO_COLLECTION`` environment variables, read where the UDF executes.
    """
    element_type = structure or T.MapType(T.StringType(), T.StringType())
    # Built once here instead of on every UDF invocation.
    mongo_projection = None if projection is None else {field: 1 for field in projection}

    @F.udf(T.ArrayType(element_type))
    def load_mongo_batch_internal(criteria: List[Row]) -> List[dict]:
        # Imported locally so the UDF does not rely on an `import os` executed in
        # some other notebook cell.
        import os

        if not criteria:
            # MongoDB rejects `$or` / `$and` with an empty array; nothing to fetch.
            return []

        # FIXME(security): the fallback URI embeds a credential. It is kept only so
        # existing runs keep working; set MONGO_URI and remove the default.
        mongo_uri = os.environ.get("MONGO_URI", "mongodb://mongo:Flkj234KJFsdzipArch@mongo:27017")
        database_name = os.environ.get("MONGO_DATABASE", "enhancer3d")
        collection_name = os.environ.get("MONGO_COLLECTION", "distance_calculation")

        query = {
            f'${root_condition}': [dict(item.asDict()) for item in criteria]
        }

        client = pymongo.MongoClient(mongo_uri)
        try:
            documents = client[database_name][collection_name]
            if mongo_projection is None:
                cursor = documents.find(query)
            else:
                cursor = documents.find(query, mongo_projection)
            # Materialise before the client is closed.
            return list(cursor)
        finally:
            # Previously the client was never closed, leaking a connection per call.
            client.close()

    return load_mongo_batch_internal

# Fields requested from MongoDB for every distance document.
distance_projection = [
    '_id.project_id',
    '_id.ensemble_id',
    '_id.region_id',
    '_id.gene_id',
    '_id.enh_id',
    'gene_type',
    'avg_dist',
    'enh_tSS_distance',
    'project_cell_lines',
]

# Spark schema of one returned document (mirrors the projection above).
distance_id_schema = T.StructType([
    T.StructField('project_id', T.StringType(), True),
    T.StructField('ensemble_id', T.StringType(), True),
    T.StructField('region_id', T.StringType(), True),
    T.StructField('gene_id', T.StringType(), True),
    T.StructField('enh_id', T.StringType(), True),
])
distance_document_schema = T.StructType([
    T.StructField('_id', distance_id_schema),
    T.StructField('gene_type', T.StringType(), True),
    T.StructField('avg_dist', T.DoubleType(), True),
    T.StructField('enh_tSS_distance', T.DoubleType(), True),
    T.StructField('project_cell_lines', T.ArrayType(T.StringType()), True),
])

fetch_distance_batch = load_mongo_batch(
    root_condition='or',
    projection=distance_projection,
    structure=distance_document_schema,
)

# One criterion per ensemble; the struct field names are the MongoDB field paths.
ensemble_criterion = F.struct(
    F.col('project_id').alias('_id.project_id'),
    F.col('ensemble_id').alias('_id.ensemble_id'),
)

distances_query_df = (
    all_relevant_ensembles_df
    # Spread the ensembles over 18 batches, one MongoDB query per batch.
    # NOTE(review): monotonically_increasing_id() is not consecutive, so batch
    # sizes may be uneven - confirm this is acceptable.
    .withColumn('batch_id', F.monotonically_increasing_id() % 18)
    .groupBy('batch_id')
    .agg(F.collect_list(ensemble_criterion).alias('criteria'))
    # Load the documents for each batch
    .select(fetch_distance_batch(F.col('criteria')).alias('data'))
    # One row per returned document
    .select(F.explode(F.col('data')).alias('data'))
    # Lift the document fields to top-level columns
    .select(*[
        F.col(f'data.{field}').alias(field)
        for field in ('_id', 'gene_type', 'avg_dist', 'enh_tSS_distance', 'project_cell_lines')
    ])
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [45]:
# Flatten the protein-coding distance documents into plain columns.
distances_df = (
    distances_query_df
    .where(
        (F.col('gene_type') == 'protein_coding')
        # Optional extra filter, currently disabled:
        # & (F.col('enh_tSS_distance') < 20_000)
    )
    .select(
        F.col('_id.project_id').alias('project_id'),
        F.col('_id.ensemble_id').alias('ensemble_id'),
        F.col('_id.region_id').alias('region_id'),
        F.col('_id.gene_id').alias('gene_id'),
        F.col('_id.enh_id').alias('enh_id'),
        # The first entry of project_cell_lines is used as the row's cell line.
        F.element_at(F.col('project_cell_lines'), 1).alias('cell_line'),
        'avg_dist',
        'enh_tSS_distance'
    )
    # Keep only the part of gene_id before the first '.', e.g. XXXX.N -> XXXX.
    # Raw string: '\.' in a normal string literal is an invalid escape sequence.
    .withColumn('gene_id', F.split(F.col('gene_id'), r'\.')[0])
    .alias("distances")
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…



In [46]:
# A distance row and a link row match when cell line, gene and enhancer all agree.
link_match = (
    (F.col('distances.cell_line') == F.col('links.cell_line'))
    & (F.col('distances.gene_id') == F.col('links.gene_id'))
    & (F.col('distances.enh_id') == F.col('links.enh_id'))
)

distances_with_links_df = (
    distances_df
    # NOTE(review): a full outer join also emits link-only rows, whose distance
    # columns are all NULL and has_link is True - confirm this is intended
    # rather than a left join.
    .join(links_df, on=link_match, how="outer")
    .select(
        F.col('distances.project_id'),
        F.col('distances.ensemble_id'),
        F.col('distances.cell_line'),
        F.col('distances.region_id'),
        F.col('distances.gene_id'),
        F.col('distances.enh_id'),
        F.col('distances.avg_dist').alias('dist_avg_dist'),
        # True when the row found a partner on the links side, False otherwise.
        F.col('links.gene_id').isNotNull().alias('has_link'),
    )
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [47]:
import os

# Persist the joined result as a single Parquet file under experiment_4.
output_dir = "/work/playground/links/experiment_4"
output_file = "/work/playground/links/experiment_4/distances_with_links.parquet"
os.makedirs(output_dir, exist_ok=True)

# toPandas() collects the whole DataFrame onto the driver before it is written.
distances_with_links_pdf = distances_with_links_df.toPandas()
distances_with_links_pdf.to_parquet(output_file, index=False)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…