In [1]:
import pyspark.sql.functions as F
from pyspark import SparkConf
from pyspark.sql import SparkSession
from functools import reduce 

sparkConf = SparkConf()

# establish spark connection
spark = (
    SparkSession.builder
    .config(conf=sparkConf)
    .master('local[*]')
    .getOrCreate()
)

credSet22Path = '/Users/dsuveges/project_data/test_credible_set_chr22.parquet/'

credSet = (
    spark.read.parquet(credSet22Path)

    .filter(F.col('logABF').isNotNull())
    .distinct()

    # Create a study specific key: <- why do we need the type?
    .withColumn("studyKey", F.concat_ws('_', *['type', 'study_id', 'phenotype_id', 'bio_feature']))
)

credSet.show()

+--------------------+------------+------------+--------+----------+--------+---------+--------------------+------------------+------------------+--------------------+-----------------+-----------------+----------------+-------+------------------+------------------+---------+--------+-----------------+-----------------+-------+-----------------+-----------------+----------------+----+--------------------+
|         bio_feature|is95_credset|is99_credset|lead_alt|lead_chrom|lead_pos| lead_ref|     lead_variant_id|            logABF|multisignal_method|        phenotype_id|         postprob|  postprob_cumsum|        study_id|tag_alt|          tag_beta|     tag_beta_cond|tag_chrom| tag_pos|         tag_pval|    tag_pval_cond|tag_ref|           tag_se|      tag_se_cond|  tag_variant_id|type|            studyKey|
+--------------------+------------+------------+--------+----------+--------+---------+--------------------+------------------+------------------+--------------------+---------------

In [5]:
# priors
priors = spark.createDataFrame([{
    "priorc1": 1e-4,  # priorc1 Prior on variant being causal for trait 1
    "priorc2": 1e-4,  # priorc2 Prior on variant being causal for trait 2
    "priorc12": 1e-5  # priorc12 Prior on variant being causal for traits 1 and 2
}])

# Joining the credset table with itself based on tag variants:
columnsToJoin = [
    "studyKey","tag_variant_id", "lead_variant_id", "type", "logABF" 
]
rename_columns = [
    "studyKey", "lead_variant_id", "type", "logABF"
]

# Overlapping signals (exploded at the tag variant level)
leftDf = reduce(lambda DF, col:
    DF.withColumnRenamed(col, 'left_' + col), rename_columns, credSet.select(columnsToJoin).distinct())
rightDf = reduce(lambda DF, col: 
    DF.withColumnRenamed(col, 'right_' + col), rename_columns, credSet.select(columnsToJoin).distinct())

## TODO: Resolve so that biofeatures are always right

# priors
priors = spark.createDataFrame([{
    "priorc1": 1e-4,  # priorc1 Prior on variant being causal for trait 1
    "priorc2": 1e-4,  # priorc2 Prior on variant being causal for trait 2
    "priorc12": 1e-5  # priorc12 Prior on variant being causal for traits 1 and 2
}])

# Joining the credset table with itself based on tag variants:
columnsToJoin = [
    "studyKey","tag_variant_id", "lead_variant_id", "type", "logABF" 
]
rename_columns = [
    "studyKey", "lead_variant_id", "type", "logABF"
]

## TODO: Resolve so that biofeatures are always right

# Overlapping signals (exploded at the tag variant level)
leftDf = reduce(
    lambda DF, col: 
        DF.withColumnRenamed(col, 'left_' + col), rename_columns, credSet.select(columnsToJoin))
rightDf = reduce(
    lambda DF, col: 
        DF.withColumnRenamed(col, 'right_' + col), rename_columns, credSet.select(columnsToJoin))

overlappingPeaks = (
    leftDf
    # molecular traits always on the right-side
    .filter(F.col("left_type") == "gwas")
    
    # Get all study/peak pairs where at least one tagging variant overlap:
    .join(rightDf, on='tag_variant_id', how='inner')
    .filter(
        # Remove rows with identical study:
        (F.col('left_studyKey') != F.col('right_studyKey')) 
    )
    # Keep only the upper triangle where both study is gwas
    .filter(
        (F.col('right_type') != 'gwas') | 
        (F.col('left_studyKey') > F.col('right_studyKey'))
    )
    # remove overlapping tag variant info
    .drop("left_logABF", "right_logABF", "tag_variant_id")
    # distinct to get study-pair info
    .distinct()
    .persist()
)


print(overlappingPeaks.count())
overlappingPeaks.show()

244396
+--------------------+--------------------+---------+--------------------+---------------------+----------+
|       left_studyKey|left_lead_variant_id|left_type|      right_studyKey|right_lead_variant_id|right_type|
+--------------------+--------------------+---------+--------------------+---------------------+----------+
|   gwas_GCST90002357|     22:19972675:G:C|     gwas|eqtl_Schmiedel_20...|      22:19990499:C:T|      eqtl|
|gwas_NEALE2_30000...|     22:19984029:T:C|     gwas|sqtl_GTEx-sQTL_ch...|      22:19973083:G:C|      sqtl|
|   gwas_GCST90002381|     22:20148740:G:A|     gwas|eqtl_GENCORD_ENSG...|      22:19990499:C:T|      eqtl|
|     gwas_GCST004607|     22:20023636:C:T|     gwas|sqtl_GTEx-sQTL_ch...|      22:19975440:G:T|      sqtl|
|   gwas_GCST90012110|     22:20059164:G:A|     gwas|   gwas_GCST90012109|      22:20150299:C:G|      gwas|
|   gwas_GCST90012109|     22:20150299:C:G|     gwas|eqtl_CEDAR_ILMN_1...|      22:20148856:C:A|      eqtl|
|   gwas_GCST90002405

In [6]:

overlappingLeft = (
    overlappingPeaks
    .join(
        leftDf
        .select("left_studyKey", "left_lead_variant_id", "tag_variant_id", "left_logABF"),
        on=["left_studyKey", "left_lead_variant_id"], 
        how='inner'
    )
)
overlappingRight = (
    overlappingPeaks
    .join(
        rightDf
        .select("right_studyKey", "right_lead_variant_id", "tag_variant_id", "right_logABF"),
        on=['right_studyKey', 'right_lead_variant_id'],
        how='inner'
    )
)

overlappingSignals = (
    overlappingLeft.alias("a")
    .join(
        overlappingRight.alias("b"),
        on = [
            "tag_variant_id",
            "left_lead_variant_id",
            "right_lead_variant_id",
            "left_studyKey",
            "right_studyKey",
            "right_type",
            "left_type"
        ],

        how='outer'
    )
)

signalPairsCols = ["studyKey", "lead_variant_id", "type"]


In [7]:
# Colocalisation analysis
coloc = (
    overlappingSignals

    # Before summarizing logABF columns nulls need to be filled with 0:
    .fillna(0,  subset=['left_logABF', 'right_logABF'])
    # Grouping data by peak and collect list of the sums:
    .withColumn('sum_logABF', F.col('left_logABF') + F.col('right_logABF'))
    
    # TODO: group by key column and keep rest of columns:
    .groupBy(*["left_" + col for col in signalPairsCols] + ["right_" + col for col in signalPairsCols])
    .agg(
        F.count('*').alias('coloc_n_vars'),
        F.collect_list(F.col('left_logABF')).alias('left_logABF_array'),
        F.collect_list(F.col('right_logABF')).alias('right_logABF_array'),
        F.collect_list(F.col('sum_logABF')).alias('sum_logABF_array')
    )
    .persist()
)

In [10]:
test_set = coloc.limit(1000).persist()
test_set.count()

1000

In [21]:
from  pyspark.ml.param import TypeConverters
from pyspark.ml.feature import VectorAssembler


In [22]:
assembler = VectorAssembler(inputCols=['left_logABF_array'], outputCol="left_logABF_vector")
selected_features = assembler.transform(test_set).select('left_logABF_vector')
selected_features.collect()

IllegalArgumentException: Data type array<double> of column left_logABF_array is not supported.