# In [None]:
import pyspark.sql
from pyspark.sql.types import *
from pyspark.sql.functions import *

@udf(StringType())
def split_last(data):
    """Return the part of a string after its last '/'.

    Used on the ``efo_code`` and ``pmid`` columns to strip a path/URL-style
    prefix from the identifier. The values are presumably of the form
    ``.../EFO_0000270`` -> ``EFO_0000270``; confirm against the data.

    A string containing no '/' is returned unchanged. A null input (passed
    to the UDF as ``None``) yields null.
    """
    if data is None:
        # Keep SQL NULLs as NULL instead of failing inside the UDF.
        return None
    # rsplit with maxsplit=1 splits only at the final '/'.
    return data.rsplit('/', 1)[-1]


# Parquet file with all the data.
# NOTE(review): `literature_data` is not used by the load/save steps in this
# script, which read `evidence_file` and write `out_file`. Confirm whether
# this path was meant to be one of those.
literature_data = '/Users/dsuveges/project/evidences/20.06_evidence_data.filtered.parquet'

# Reuse the active Spark session if there is one; otherwise start a new one.
spark = pyspark.sql.SparkSession.builder.getOrCreate()
print('Spark version: ', spark.version)

# Read the raw evidence, keep only literature rows, and flatten the nested
# fields we need into descriptively named top-level columns.
# NOTE(review): `evidence_file` is not defined in this section; confirm it is
# set (e.g. in an earlier cell) before this runs.
data = (
    spark.read.json(evidence_file)
    .filter(col('type') == 'literature')
    .select(*[
        col(source_field).alias(column_name)
        for source_field, column_name in (
            ('type', 'evidence_type'),
            ('unique_association_fields.datasource', 'data_source'),
            ('disease.efo_info.efo_id', 'efo_code'),
            ('target.gene_info.geneid', 'gene_id'),
            ('scores.association_score', 'assoc_score'),
            ('unique_association_fields.publication_id', 'pmid'),
        )
    ])
    # Keep only the segment after the last '/' of each identifier column.
    .withColumn('efo_code', split_last(col('efo_code')))
    .withColumn('pmid', split_last(col('pmid')))
)

# Save data: range-partition on the key columns before writing, so each
# Parquet part file holds a contiguous range of
# (evidence_type, data_source, efo_code, gene_id) values.
# NOTE(review): `out_file` is not defined in this section; confirm it is set
# before this runs.
data.repartitionByRange(
    'evidence_type', 'data_source', 'efo_code', 'gene_id'
).write.parquet(out_file)
