In [1]:
from pyspark.sql import SparkSession, functions as f, types as t, DataFrame, Column

spark = SparkSession.builder.getOrCreate()
association_file = 'gs://genetics_etl_python_playground/XX.XX/output/python_etl/parquet/gwas_catalog_associations'

df = (
    spark.read.parquet(association_file)
    .distinct()
    .withColumn('riskAllele', f.upper(f.col('riskAllele')))
    .persist()
)
df.show(1, False, True)

print(df.count())
print(df.select('associationId').distinct().count())

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/12/12 14:45:04 INFO org.apache.spark.SparkEnv: Registering MapOutputTracker
22/12/12 14:45:04 INFO org.apache.spark.SparkEnv: Registering BlockManagerMaster
22/12/12 14:45:04 INFO org.apache.spark.SparkEnv: Registering BlockManagerMasterHeartbeat
22/12/12 14:45:04 INFO org.apache.spark.SparkEnv: Registering OutputCommitCoordinator
22/12/12 14:45:18 WARN org.apache.spark.sql.catalyst.util.package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
                                                                                

-RECORD 0-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
 variantId               | 16_79539854_C_T                                                                                                                                                                                                                                                              
 chromosome              | 16                                                                                                                                                                                                                                                                           
 position                | 79539854                                                                          

                                                                                

434351




434351


                                                                                

In [53]:
# Is the intersecion of the gnomad and gwas rsid list empty?
def check_rsids(gnomad: Column, gwas: Column) -> Column:
    return f.when(f.size(f.array_intersect(gnomad, gwas)) > 0, True).otherwise(False)


def check_concordance(riskAllele: Column, referenceAllele: Column, alternateAllele: Column):
    # Calculating the reverse complement of the risk allele:
    riskAlleleReverseComplement = f.when(
        riskAllele.rlike(r"^[ACTG]+$"),
        f.reverse(f.translate(riskAllele, "ACTG", "TGAC")),
    ).otherwise(riskAllele)

    # OK, is the risk allele or the reverse complent is the same as the mapped alleles:
    return (
        f.when(
            (riskAllele == referenceAllele)
            | (riskAllele == alternateAllele),
            True,
        )
        # If risk allele is found on the negative strand:
        .when(
            (riskAlleleReverseComplement == referenceAllele)
            | (riskAlleleReverseComplement == alternateAllele),
            True,
        )
        # If risk allele is ambiguous, still accepted: < This condition could be reconsidered
        .when(riskAllele == "?", True)
        # If the association could not be mapped we keep it:
        .when(referenceAllele.isNull(), True)
        # Allele is discordant:
        .otherwise(False)
    )
    



mappings = (
    df
    .withColumn(
        'isRsIdMatched', 
        check_rsids(
            f.col('rsIdsGnomad'), 
            f.col('rsIdsGwasCatalog')
        )
    )
    .withColumn(
        'isConcordant', 
        check_concordance(
            f.col('riskAllele'), 
            f.col('referenceAllele'), 
            f.col('alternateAllele')
        )
    )
)


mafs = (
    mappings
    .select('variantId', f.explode_outer(f.col('alleleFrequencies')).alias('af'))
    .distinct()
    .withColumn(
        "maf",
        f.when(
            f.col("af.alleleFrequency") > 0.5, 1 - f.col("af.alleleFrequency")
        ).otherwise(f.col("af.alleleFrequency")),
    )
    .groupBy('variantId')
    .agg(f.max('maf').alias('maxMaf'))
)

processed = mappings.join(mafs, on='variantId', how='left').persist()
processed.show(1, False, True)



                                                                                

-RECORD 0---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
 variantId               | 4_78701519_C_T                                                                                                                                                                                                                                                                                     
 chromosome              | 4                                                                                                                                                                                                                                                                                                  
 position                | 78701519        

In [66]:
from pyspark.sql.window import Window

def _find_mappings_to_drop(associationId: Column, filterColumn: Column) -> Column:
    """Flagging mappings to drop for each association.
    
    Some associations have multiple mappings. Some has matching rsId others don't. We only
    want to drop the non-matching mappings, when a matching is available for the given association.
    This logic can be generalised for other measures eg. allele concordance.

    
    """
    w = Window.partitionBy(associationId)

    # Generating a boolean column informing if the filter column contains true anywhere for the association:
    aggregated_filter = f.when(f.array_contains(f.collect_set(filterColumn).over(w), True),True).otherwise(False)

    # Generate a filter column:
    return f.when(aggregated_filter & (filterColumn == False), False).otherwise(True)

def _keep_mapping_with_top_maf(associationId: Column, maf_column: Column) -> Column:
    w = Window.partitionBy(associationId).orderBy(f.desc(maf_column))
    row_numbers = f.row_number().over(w)
    return f.when(row_numbers == 1, True).otherwise(False)

(
    processed
    # Dropping rows, where rsId doesn't match, but matching rsId available:
    .withColumn('rsidFilter', _find_mappings_to_drop(f.col('associationId'), f.col('isRsIdMatched')))
    .filter(f.col('rsidFilter'))
    # Dropping rows, where alleles aren't concordant, but concordant alleles available:
    .withColumn('concordanceFilter', _find_mappings_to_drop(f.col('associationId'), f.col('isConcordant')))
    .filter(f.col('concordanceFilter'))
    # Out of the remaining mappings, keeping the one with the highest MAF:
    .withColumn('mafFilter', _keep_mapping_with_top_maf(f.col('associationId'), f.col('maxMaf')))
    .filter('mafFilter')
    .drop('rsidFilter', 'concordanceFilter', 'mafFilter', 'isRsIdMatched', 'maxMaf', 'isRsIdMatched', 'isConcordant')
    .show(1, False, True)
)


-RECORD 0----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
 variantId               | 8_19986711_A_G                                                                                                                                                                                                                                                                          
 chromosome              | 8                                                                                                                                                                                                                                                                                       
 position                | 19986711                                         

In [68]:
(
    processed
    # .filter(~f.col('isRsIdMatched') & f.col('referenceAllele').isNotNull())
    # .show(1, False, True)
    .count(1, False, True)
)

TypeError: count() takes 1 positional argument but 4 were given

In [13]:
assoc = spark.read.parquet('gs://genetics_etl_python_playground/XX.XX/output/python_etl/parquet//gwas_catalog_associations').persist()
assoc.show(1, False, True)


[Stage 2:>                                                          (0 + 1) / 1]

-RECORD 0---------------------------------
 chromosome          | 8                  
 position            | 127472793          
 referenceAllele     | A                  
 alternateAllele     | C                  
 variantId           | 8_127472793_A_C    
 studyId             | GCST000017         
 pValueMantissa      | 2.0                
 pValueExponent      | -14                
 beta                | null               
 beta_ci_lower       | null               
 beta_ci_upper       | null               
 odds_ratio          | 0.6993006993006994 
 odds_ratio_ci_lower | 0.6380705725634574 
 odds_ratio_ci_upper | 0.7664065529268912 
 qualityControl      | []                 
only showing top 1 row



                                                                                

In [75]:
# assoc.filter(f.size(f.col('qualityControl')) == 0).count() # 315122
# assoc.filter(f.size(f.col('qualityControl')) != 0).count() # 119229
assoc.select(f.explode(f.col('qualityControl'))).groupBy('col').count().show(1000, truncate=False)


+--------------------------+-----+
|col                       |count|
+--------------------------+-----+
|Composite association     |285  |
|Variant inconsistency     |1278 |
|No mapping in GnomAd      |33312|
|Incomplete genomic mapping|19310|
|Subsignificant p-value    |96372|
+--------------------------+-----+



In [5]:
# cicaful = (
#     assoc
#     .filter(f.col('confidenceInterval').contains('decrease') & f.col('beta').isNull())
#     .select('studyAccession', 'variantId', 'effectSize', 'confidenceInterval', 
#     'beta', 'alternateAllele', 'referenceAllele', 'riskAllele', 'pValueMantissa', 'pValueExponent')
#     .persist()
# )
import sys
from scipy.stats import norm

def _pval_to_zscore(pvalcol: Column) -> Column:
    """Convert p-value column to z-score column.

    Args:
        pvalcol (Column): pvalues to be casted to floats.

    Returns:
        Column: p-values transformed to z-scores

    Examples:
        >>> d = [{"id": "t1", "pval": "1"}, {"id": "t2", "pval": "0.9"}, {"id": "t3", "pval": "0.05"}, {"id": "t4", "pval": "1e-300"}, {"id": "t5", "pval": "1e-1000"}, {"id": "t6", "pval": "NA"}]
        >>> df = spark.createDataFrame(d)
        >>> df.withColumn("zscore", pval_to_zscore(f.col("pval"))).show()
        +---+-------+----------+
        | id|   pval|    zscore|
        +---+-------+----------+
        | t1|      1|       0.0|
        | t2|    0.9|0.12566137|
        | t3|   0.05|  1.959964|
        | t4| 1e-300| 37.537838|
        | t5|1e-1000| 37.537838|
        | t6|     NA|      null|
        +---+-------+----------+
        <BLANKLINE>

    """
    pvalue_float = pvalcol.cast(t.FloatType())
    pvalue_nozero = f.when(pvalue_float == 0, sys.float_info.min).otherwise(
        pvalue_float
    )
    return f.udf(
        lambda pv: float(abs(norm.ppf((float(pv)) / 2))) if pv else None,
        t.FloatType(),
    )(pvalue_nozero)


def _get_reverse_complement(allele_col: Column) -> Column:

    return f.when(
        allele_col.rlike("[ACTG]+"),
        f.reverse(f.translate(allele_col, "ACTG", "TGAC")),
    ).otherwise(allele_col)

def _harmonize_beta(effect_size: Column, confidence_interval: Column, needs_harmonization: Column) -> Column:
    beta = f.when(confidence_interval.contains('increase') | confidence_interval.contains('decrease'), effect_size).otherwise(None)
    return f.when(
        (
            confidence_interval.contains("increase")
            & needs_harmonization
        )
        | (
            confidence_interval.contains("decrease")
            & ~ needs_harmonization
        ),
        beta * -1,
    ).otherwise(beta)

def _calculate_beta_ci(beta: Column, zscore: Column, direction: Column) -> Column: 
    zscore_95 = f.lit(1.96)
    return f.when(
        direction == 'upper', beta + f.abs(zscore_95 * beta) / zscore
    ).when(direction == 'lower', beta - f.abs(zscore_95 * beta) / zscore).otherwise(None)

def _harmonize_odds_ratio(odds_ratio: Column, needsHarmonization: Column) -> Column: 
    return f.when(needsHarmonization, 1/odds_ratio).otherwise(odds_ratio)

def _calculate_or_ci(odds_ratio: Column, zscore: Column, direction: Column) -> Column:
    zscore_95 = f.lit(1.96)
    odds_ratio_estimate = f.log(odds_ratio)
    odds_ratio_se = odds_ratio_estimate / zscore
    return (
        f.when(direction == 'upper', f.exp(odds_ratio_estimate + f.abs(zscore_95 * odds_ratio_se)))
        .when(direction == 'lower', f.exp(odds_ratio_estimate - f.abs(zscore_95 * odds_ratio_se)))
    )
#  effectSize              | 0.161                                                                                                                                                                                                                                                                        
#  confidenceInterval      | unit increase   
(
    df
    # Adding a flag indicating if effect harmonization is required:
    .withColumn(
        "needsHarmonisation",
        # If the alleles are palindrom - the reference and alt alleles are reverse complement of each other:
        # eg. T -> A: in such cases we cannot disambiguate the effect, which means we cannot be sure if
        # the effect is given to the alt allele on the positive strand or the ref allele on
        # The negative strand. We assume, we don't need to harminze.
        f.when(
            (
                f.col("referenceAllele")
                == _get_reverse_complement(f.col("alternateAllele"))
            ), False
        )
        .when(
            # As we are calculating effect on the alternate allele, we have to harmonise effect
            # if the risk allele is reference allele or the reverse complement of the reference allele
            (f.col("riskAllele") == f.col("referenceAllele"))
            | (
                (
                    f.col("riskAllele")
                    == _get_reverse_complement(f.col("referenceAllele"))
                )
            ),
            True,
        ).otherwise(False),
    )
    # Z-score is needed to calculate 95% confidence interval:
    .withColumn(
        "zscore",
        _pval_to_zscore(
            f.concat_ws("E", f.col("pValueMantissa"), f.col("pValueExponent"))
        ),
    )
    # Harmonizing betas:
    .withColumn('beta', _harmonize_beta(f.col('effectSize'), f.col('confidenceInterval'), f.col("needsHarmonisation")))
    .withColumn('beta_ci_upper', _calculate_beta_ci(f.col('beta'), f.col("zscore"), f.lit('upper')))
    .withColumn('beta_ci_lower', _calculate_beta_ci(f.col('beta'), f.col("zscore"), f.lit('lower')))

    # # Harmonizing odds-ratios:
    .withColumn('odds_ratio', _harmonize_odds_ratio(f.col('effectSize'), f.col('confidenceInterval'), f.col("needsHarmonisation")))
    .withColumn('odds_ratio_ci_upper', _calculate_or_ci(f.col('odds_ratio'), f.col("zscore"), f.lit('upper')))
    .withColumn('odds_ratio_ci_lower', _calculate_or_ci(f.col('odds_ratio'), f.col("zscore"), f.lit('lower')))

    .show(1, False, True)
)



22/12/12 15:00:57 WARN org.apache.spark.sql.Column: Constructing trivially true equals predicate, 'upper = upper'. Perhaps you need to use aliases.
22/12/12 15:00:57 WARN org.apache.spark.sql.Column: Constructing trivially true equals predicate, 'lower = lower'. Perhaps you need to use aliases.


-RECORD 0-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
 variantId               | 16_79539854_C_T                                                                                                                                                                                                                                                              
 chromosome              | 16                                                                                                                                                                                                                                                                           
 position                | 79539854                                                                          

                                                                                

In [11]:
from math import log10
zscore = 4
odds_ratio = 0.3
zscore_95 = 1.96
odds_ratio_estimate = log10(odds_ratio)
odds_ratio_se = odds_ratio_estimate / zscore
zscore_95 * odds_ratio_se


-0.25621058518736545

In [4]:
-1.2 - abs(1.96*-1.2)/4

-1.7879999999999998