# Notebook cleanup cell: clear everything held in Spark's in-memory cache,
# then stop the session.
# NOTE(review): `spark` is only assigned further down in this file, so this
# cell presumably relies on a session left over from an earlier run; executed
# strictly top to bottom it would fail with a NameError — confirm intent.
spark.catalog.clearCache()
spark.stop()

In [1]:
import findspark
findspark.init()

# Spark imports
from pyspark.sql.functions import pandas_udf, udf
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, StringType
import pyspark.sql.functions as F
from pyspark.sql.functions import when
from pyspark import Row
from pyspark.sql.window import Window
from pyspark import StorageLevel

# Standard-library and third-party Python imports
import re
import subprocess
import numpy as np
import pandas
import pyarrow

# Start the Spark session against the standalone master.
# NOTE: the per-executor core count is configured with "spark.executor.cores"
# (plural). The key used before, "spark.executor.core", is not a Spark
# property, so the value 3 never took effect.
# The two arrow.* settings turn on Arrow-based pandas conversion and let it
# fall back to the non-Arrow path when Arrow fails.
spark = SparkSession.builder.master("spark://master:7077")\
                        .appName("gVCF_combine")\
                        .config("spark.driver.memory", "8G")\
                        .config("spark.driver.maxResultSize", "8G")\
                        .config("spark.executor.memory", "24G")\
                        .config("spark.executor.cores", 3)\
                        .config("spark.sql.execution.arrow.enabled", "true")\
                        .config("spark.sql.execution.arrow.fallback.enabled", "true")\
                        .config("spark.network.timeout", 10000000)\
                        .config("spark.sql.shuffle.partitions", 40)\
                        .getOrCreate()

In [2]:
def hadoop_list(length, hdfs):
    """Return the first ``length`` paths listed under an HDFS location.

    Runs ``hdfs dfs -ls <hdfs>`` and collects the eighth whitespace-separated
    field of every output line that has at least eight fields. Shorter lines
    are skipped (the previous ``awk '{print $8}'`` printed an empty line for
    those, which the following ``split()`` discarded).

    Args:
        length: Maximum number of entries to return.
        hdfs: Path handed to ``hdfs dfs -ls``.

    Returns:
        A list of ``bytes`` paths in listing order (callers decode them).

    Raises:
        RuntimeError: If the ``hdfs`` command exits with a non-zero status.
        OSError: If the ``hdfs`` executable cannot be started.
    """
    # An argument list without a shell keeps spaces and shell metacharacters
    # in the path from being interpreted by /bin/sh.
    proc = subprocess.Popen(
        ["hdfs", "dfs", "-ls", hdfs],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    s_output, s_err = proc.communicate()

    # Fail loudly instead of returning an empty list: an empty result would
    # only surface later as an unrelated error in the caller.
    if proc.returncode != 0:
        raise RuntimeError(
            "hdfs dfs -ls failed for %r: %s"
            % (hdfs, s_err.decode("utf-8", "replace").strip())
        )

    all_dart_dirs = []
    for line in s_output.splitlines():
        fields = line.split()
        if len(fields) >= 8:
            all_dart_dirs.append(fields[7])

    return all_dart_dirs[:length]

def preVCF(hdfs, flag, spark):
    """Load a tab-separated VCF text file into a Spark DataFrame.

    The column names come from the line starting with ``#CHROM``. Rows are
    the lines whose first two characters are both not ``#``, split on tabs.
    ``POS`` is cast to an integer; every other column stays a string.

    Args:
        hdfs: Location passed to ``sparkContext.textFile``.
        flag: When equal to 1, every column except the last one and the keys
            ``#CHROM``/``POS``/``REF`` gets a ``_temp`` suffix, so the frame
            can be joined against another VCF without name clashes.
        spark: Active SparkSession.

    Returns:
        The resulting DataFrame.
    """
    key_columns = ("#CHROM", "POS", "REF")

    raw_lines = spark.sparkContext.textFile(hdfs)
    header_line = raw_lines.filter(lambda line: line.startswith("#CHROM")).first()
    col_name = header_line.split("\t")

    records = raw_lines.filter(lambda line: re.match("[^#][^#]", line))
    split_records = records.map(lambda line: line.split("\t"))
    vcf_data = split_records.toDF(col_name)
    vcf_data = vcf_data.withColumn("POS", F.col("POS").cast(IntegerType()))

    if flag != 1:
        return vcf_data

    # The last column is left alone; everything else that is not a join key
    # is renamed.
    for name in vcf_data.columns[:-1]:
        if name not in key_columns:
            vcf_data = vcf_data.withColumnRenamed(name, name + "_temp")

    return vcf_data

def chunks(lst, n):
    """Yield consecutive slices of ``lst``, each holding at most ``n`` items."""
    yield from (lst[start:start + n] for start in range(0, len(lst), n))
        
def addIndex(POS, size):
    """Map a position to a bucket number for buckets of width ``size``.

    A position equal to 1 is returned unchanged; any other position yields
    ``int(POS / size + 1)``.
    """
    if POS != 1:
        return int(POS / size + 1)
    return POS
# Spark UDF wrapper around addIndex; the result column is an integer.
addIndex_udf = udf(addIndex, IntegerType())

# Row thinning used to build the POS index.
def sampling_func(data, ran):
    """Return every ``ran``-th row of ``data`` by position, starting at row 0."""
    row_positions = range(0, len(data), ran)
    return data.take(row_positions)

In [3]:
# Merge every listed gVCF into one wide table (join_vcf) and, alongside it,
# keep the (#CHROM, POS, REF) sites that occur in all of them (inner_pos).
hdfs = "hdfs://master:9000"
hdfs_list = hadoop_list(2, "/raw_data/gvcf")
info_window = Window.partitionBy("#CHROM").orderBy("POS")
vcf_list = []
site_keys = ["#CHROM", "POS", "REF"]
temp_columns = ["INFO_temp", "ID_temp", "ALT_temp", "FORMAT_temp", "QUAL_temp", "FILTER_temp"]

for index, listed_path in enumerate(hdfs_list):
    vcf_path = hdfs + listed_path.decode("UTF-8")

    if index == 0:
        # The first file is the base table and keeps its original column names.
        join_vcf = preVCF(vcf_path, 0, spark).cache()
        inner_pos = join_vcf.select(*site_keys)
        continue

    # Later files are loaded with "_temp"-suffixed columns so they can be
    # reconciled with the base columns after the join.
    right_vcf = preVCF(vcf_path, 1, spark).cache()
    vcf_list.append(right_vcf)
    inner_pos_right = right_vcf.select(*site_keys)

    merged = join_vcf.join(right_vcf, site_keys, "full")
    # Rows that exist only in the right-hand file have nulls in the base
    # columns; take the right-hand value there.
    for shared in ("ID", "ALT", "FORMAT"):
        merged = merged.withColumn(
            shared,
            when(F.col(shared).isNotNull(), F.col(shared)).otherwise(F.col(shared + "_temp")),
        )
    merged = merged.withColumn("QUAL", F.lit(".")).withColumn("FILTER", F.lit("."))
    # Keep the first INFO value that does not start with "END"; when neither
    # side qualifies the column is left null and filled in after the loop.
    merged = merged.withColumn(
        "INFO",
        when(~F.col("INFO").startswith("END"), F.col("INFO"))
        .when(~F.col("INFO_temp").startswith("END"), F.col("INFO_temp")),
    )
    join_vcf = merged.drop(*temp_columns)

    # Sites common to every file seen so far.
    inner_pos = inner_pos.join(inner_pos_right, site_keys, "inner")

# A null INFO becomes "END=<next POS in the same #CHROM, minus one>".
end_before_next_site = F.concat(F.lit("END="), F.lead("POS", 1).over(info_window) - 1)
join_vcf = join_vcf.withColumn(
    "INFO",
    when(F.col("INFO").isNotNull(), F.col("INFO")).otherwise(end_before_next_site),
)
join_vcf = join_vcf.orderBy("#CHROM", "POS").cache()
# count() is an action, so it forces the cached table to be computed.
join_vcf.count()

81337377

In [4]:
# Thin the shared sites to every 13th row per chromosome. This is done in
# pandas, so the sites are first pulled into local memory with toPandas().
shared_sites_pd = inner_pos.drop(F.col("REF")).orderBy(F.col("#CHROM"), F.col("POS")).toPandas()
sampled_sites_pd = shared_sites_pd.groupby("#CHROM", group_keys=False)\
            .apply(sampling_func, ran = 13).sort_index()

# NOTE(review): the traceback recorded under this cell is raised from the
# Arrow reader behind createDataFrame; a pyarrow / Spark version mismatch is a
# likely cause — confirm. The row count recorded after it suggests the
# conversion still completed.
inner_pos = spark.createDataFrame(sampled_sites_pd)\
            .withColumnRenamed("#CHROM", "chr_temp")\
            .withColumnRenamed("POS", "pos_temp")

# The columns were renamed just above, so the sort has to use the new names;
# "#CHROM" and "POS" no longer exist on this DataFrame.
inner_pos = inner_pos.orderBy(F.col("chr_temp"), F.col("pos_temp")).cache()
inner_pos.count()

  An error occurred while calling z:org.apache.spark.sql.api.python.PythonSQLUtils.readArrowStreamFromFile.
: java.lang.IllegalArgumentException
	at java.nio.ByteBuffer.allocate(ByteBuffer.java:334)
	at org.apache.arrow.vector.ipc.message.MessageSerializer.readMessage(MessageSerializer.java:543)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anon$3.readNextBatch(ArrowConverters.scala:243)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anon$3.<init>(ArrowConverters.scala:229)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$.getBatchesFromStream(ArrowConverters.scala:228)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anonfun$readArrowStreamFromFile$2.apply(ArrowConverters.scala:216)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anonfun$readArrowStreamFromFile$2.apply(ArrowConverters.scala:214)
	at org.apache.spark.util.Utils$.tryWithResource(Utils.scala:2543)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$.readArrowStreamFr

132236

In [5]:
# Release the per-file DataFrames that were cached while merging.
for cached_vcf in vcf_list:
    cached_vcf.unpersist()

In [6]:
# Give every (#CHROM, POS) row of join_vcf a POS_INDEX: its own position when
# it is one of the sampled sites, otherwise the latest sampled position at or
# before it within the same chromosome.
pos_index = Window.partitionBy("#CHROM").orderBy("POS").rangeBetween(Window.unboundedPreceding, Window.currentRow)
ex = [join_vcf["#CHROM"] == inner_pos["chr_temp"], join_vcf["POS"] == inner_pos["pos_temp"]]

latest_sampled_pos = F.last(F.col("pos_temp"), ignorenulls=True).over(pos_index)
sites_with_samples = join_vcf.select("#CHROM", "POS").join(inner_pos, ex, "full")
sites_with_samples = sites_with_samples.drop(F.col("chr_temp"))

temp = sites_with_samples.withColumn(
    "POS_INDEX",
    when(F.col("pos_temp").isNotNull(), F.col("pos_temp")).otherwise(latest_sampled_pos),
)
temp = temp.drop(F.col("pos_temp")).orderBy("#CHROM", "POS").cache()
temp.count()

81337377

In [7]:
# Attach POS_INDEX to the merged table and co-locate the rows of each
# (#CHROM, POS_INDEX) bucket in one partition.
# No orderBy here: the repartition redistributes the rows by hash of the given
# columns, so a global sort placed before it would be discarded and only add a
# full shuffle of the table.
# dropDuplicates() removes rows that are identical in every column.
join_vcf_index = join_vcf.join(temp, ["#CHROM", "POS"], "inner").dropDuplicates()\
                         .repartition(F.col("#CHROM"), F.col("POS_INDEX")).cache()
join_vcf_index.count()
# The merged table is no longer needed once the indexed copy is cached.
join_vcf.unpersist()

DataFrame[#CHROM: string, POS: int, REF: string, ID: string, ALT: string, QUAL: string, FILTER: string, INFO: string, FORMAT: string, ND02798: string, ND02809: string]

In [8]:
# Within each (#CHROM, POS_INDEX) bucket, ordered by POS, replace a null in a
# selected column with the latest non-null value seen so far, then write the
# result without the helper column.
sample_w = Window.partitionBy(F.col("#CHROM"), F.col("POS_INDEX")).orderBy(F.col("POS")).rangeBetween(Window.unboundedPreceding, Window.currentRow)

# columns[9:-1] skips the first nine columns and the last one.
# NOTE(review): presumably the nine fixed VCF columns and POS_INDEX, leaving
# only the per-sample columns — confirm against the schema.
sample_columns = join_vcf_index.columns[9:-1]
for sample_name in sample_columns:
    current_value = F.col(sample_name)
    carried_value = F.last(sample_name, ignorenulls=True).over(sample_w)
    join_vcf_index = join_vcf_index.withColumn(
        sample_name,
        when(current_value.isNotNull(), current_value).otherwise(carried_value),
    )

output_frame = join_vcf_index.drop(F.col("POS_INDEX"))
output_frame.write.mode('overwrite').parquet("/raw_data/output/gvcf_2")