# Tear down a SparkSession left over from an earlier run in this kernel before
# the next cell (re)creates it. On a fresh kernel `spark` is not defined yet
# (it is first assigned by the session-builder cell below), so skip the cleanup
# instead of failing with NameError.
if "spark" in globals():
    spark.catalog.clearCache()
    spark.stop()

In [1]:
import findspark
findspark.init()

# Spark function
from pyspark.sql.functions import pandas_udf, udf
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, StringType
import pyspark.sql.functions as F
from pyspark.sql.functions import when
from pyspark import Row
from pyspark.sql.window import Window
from pyspark import StorageLevel

# Python function
import re
import subprocess
import numpy as np
import pandas
import pyarrow

# Create (or reuse, via getOrCreate) the SparkSession on the standalone master.
# Arrow-based pandas conversion is switched off, and so is its fallback flag.
spark = (
    SparkSession.builder
    .master("spark://master:7077")
    .appName("gVCF_combine")
    .config("spark.driver.memory", "8G")
    .config("spark.driver.maxResultSize", "8G")
    .config("spark.executor.memory", "24G")
    .config("spark.sql.execution.arrow.enabled", "false")
    .config("spark.sql.execution.arrow.fallback.enabled", "false")
    .config("spark.network.timeout", 10000000)
    .config("spark.sql.shuffle.partitions", 40)
    .getOrCreate()
)

In [2]:
def hadoop_list(length, hdfs):
    """Return up to ``length`` paths printed by ``hdfs dfs -ls <hdfs>``.

    The paths are ``bytes`` (callers ``.decode("UTF-8")`` them), in the order
    the listing prints them.

    Args:
        length: maximum number of paths to return (used as a slice bound).
        hdfs: directory to list.

    Raises:
        OSError: the ``hdfs`` executable could not be started.
        subprocess.CalledProcessError: ``hdfs dfs -ls`` exited non-zero; the
            captured stderr is available on the exception.
    """
    # Pass an argument list (no shell), so the directory name is never
    # interpreted by a shell. check=True makes a failed listing raise instead
    # of silently producing an empty list.
    proc = subprocess.run(
        ["hdfs", "dfs", "-ls", hdfs],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        check=True,
    )
    paths = []
    for line in proc.stdout.splitlines():
        fields = line.split()
        # Same as the former `awk '{print $8}'`: keep the 8th whitespace field.
        # Lines with fewer fields (e.g. a summary header) yield nothing.
        if len(fields) >= 8:
            paths.append(fields[7])
    return paths[:length]

def preVCF(hdfs, flag, spark):
    """Load one tab-separated (g)VCF text file as a Spark DataFrame.

    Column names come from the ``#CHROM`` header line. Data lines are split on
    tabs, and ``POS`` is cast to ``IntegerType``.

    Args:
        hdfs: path/URI of the file, passed to ``SparkContext.textFile``.
        flag: when ``1``, every column except ``#CHROM``, ``POS``, ``REF`` and
            the last column gets a ``_temp`` suffix. This lets the frame be
            full-joined on those three keys without name clashes. The last
            column (presumably the sample column) keeps its name. Any other
            value leaves all names unchanged.
        spark: the active ``SparkSession``.

    Raises:
        ValueError: the file contains no ``#CHROM`` header line.
    """
    vcf = spark.sparkContext.textFile(hdfs)
    header = vcf.filter(lambda line: line.startswith("#CHROM")).take(1)
    if not header:
        raise ValueError("no '#CHROM' header line found in " + hdfs)
    col_name = header[0].split("\t")

    if flag == 1:
        # Compute the suffixed names once and name the columns in toDF.
        # This avoids one withColumnRenamed projection per column.
        join_keys = ("#CHROM", "POS", "REF")
        last = len(col_name) - 1
        col_name = [
            name if name in join_keys or index == last else name + "_temp"
            for index, name in enumerate(col_name)
        ]

    # Keep only lines whose first two characters are both not '#'. This drops
    # the '##' meta lines, the '#CHROM' header and lines shorter than 2 chars.
    vcf_data = vcf.filter(lambda x: re.match("[^#][^#]", x))\
                  .map(lambda x: x.split("\t"))\
                  .toDF(col_name)\
                  .withColumn("POS", F.col("POS").cast(IntegerType()))

    return vcf_data

def chunks(lst, n):
    """Yield consecutive slices of ``lst`` of length ``n``; the last may be shorter."""
    starts = range(0, len(lst), n)
    yield from (lst[start:start + n] for start in starts)
        
def addIndex(POS, size):
    """Map a position to a bucket number for buckets of ``size`` positions.

    ``POS == 1`` returns 1. Any other ``POS`` returns ``int(POS / size + 1)``.
    A ``None`` POS (a null value when called through ``addIndex_udf``) returns
    ``None`` rather than raising TypeError, which would fail the Spark task.
    """
    if POS is None:
        return None
    if POS == 1:
        return POS
    return int(POS / size + 1)
# Spark UDF wrapper around addIndex; results are typed as IntegerType.
addIndex_udf = udf(addIndex, IntegerType())

# Row sampler applied per chromosome group when building the position index.
def sampling_func(data, ran):
    """Return every ``ran``-th row of ``data`` by position, starting at row 0.

    Rows are selected with ``data.take``, so the selected rows keep their
    original index labels.
    """
    return data.take(range(0, len(data), ran))

In [3]:
hdfs = "hdfs://master:9000"
hdfs_list = hadoop_list(10, "/raw_data/gvcf")
info_window = Window.partitionBy("#CHROM").orderBy("POS")
join_keys = ["#CHROM", "POS", "REF"]

# Fail early with a clear message. With no input files, join_vcf would never
# be assigned and the END= step below would raise NameError instead.
if not hdfs_list:
    raise ValueError("no gVCF files listed under /raw_data/gvcf")

paths = [hdfs + path.decode("UTF-8") for path in hdfs_list]

# The first file seeds both the merged frame and the set of shared positions.
join_vcf = preVCF(paths[0], 0, spark)
inner_pos = join_vcf.select(F.col("#CHROM"), F.col("POS"), F.col("REF"))

for path in paths[1:]:
    # flag=1: this file's non-key columns (except the last) get a "_temp" suffix.
    joiner = preVCF(path, 1, spark)
    inner_pos_right = joiner.select(F.col("#CHROM"), F.col("POS"), F.col("REF"))

    # Full outer join on the site key. Shared VCF columns are filled from
    # whichever side is non-null, then the right-hand copies are dropped.
    join_vcf = (
        join_vcf.join(joiner, join_keys, "full")
        .withColumn("ID", F.coalesce(F.col("ID"), F.col("ID_temp")))
        .withColumn("ALT", F.coalesce(F.col("ALT"), F.col("ALT_temp")))
        .withColumn("FORMAT", F.coalesce(F.col("FORMAT"), F.col("FORMAT_temp")))
        .withColumn("QUAL", F.lit("."))
        .withColumn("FILTER", F.lit("."))
        # Take the first INFO (left, then right) that does not start with
        # "END". If neither qualifies, INFO becomes null and is filled below.
        .withColumn(
            "INFO",
            when(~F.col("INFO").startswith("END"), F.col("INFO"))
            .when(~F.col("INFO_temp").startswith("END"), F.col("INFO_temp")),
        )
        .drop("INFO_temp", "ID_temp", "ALT_temp", "FORMAT_temp", "QUAL_temp", "FILTER_temp")
    )

    # Keep only sites present in every file read so far.
    inner_pos = inner_pos.join(inner_pos_right, join_keys, "inner")

# A null INFO becomes "END=<next POS on the same chromosome - 1>".
# NOTE(review): on the last row of each chromosome lead() is null, so concat()
# is null and INFO stays null there. Confirm this is intended.
join_vcf = join_vcf.withColumn(
    "INFO",
    when(
        F.col("INFO").isNull(),
        F.concat(F.lit("END="), F.lead("POS", 1).over(info_window) - 1),
    ).otherwise(F.col("INFO")),
)
join_vcf = join_vcf.orderBy(F.col("#CHROM"), F.col("POS")).cache()
join_vcf.count()

226934049

In [4]:
# Down-sample the sites shared by all files, on the driver via pandas:
# - sort by chromosome and position;
# - keep rows 0, 14, 28, ... of each chromosome (sampling_func with ran=14);
# - sort_index() restores the sorted order, because group_keys=False keeps the
#   original index.
sampled = (
    inner_pos.drop(F.col("REF"))
    .orderBy(F.col("#CHROM"), F.col("POS"))
    .toPandas()
    .groupby("#CHROM", group_keys=False)
    .apply(sampling_func, ran=14)
    .sort_index()
)
inner_pos = spark.createDataFrame(sampled)\
            .withColumnRenamed("#CHROM", "chr_temp")\
            .withColumnRenamed("POS", "pos_temp")
# The columns were just renamed, so order by the new names; "#CHROM"/"POS"
# no longer exist on this frame and cannot be resolved.
inner_pos = inner_pos.orderBy(F.col("chr_temp"), F.col("pos_temp")).cache()
inner_pos.count()

3871

In [5]:
# Per-chromosome window ordered by POS, from the chromosome start up to and
# including the current POS value.
pos_index = (
    Window.partitionBy("#CHROM")
    .orderBy("POS")
    .rangeBetween(Window.unboundedPreceding, Window.currentRow)
)
# Match each (#CHROM, POS) of the merged data against the sampled positions.
ex = [
    join_vcf["#CHROM"] == inner_pos["chr_temp"],
    join_vcf["POS"] == inner_pos["pos_temp"],
]
temp = (
    join_vcf.select(F.col("#CHROM"), F.col("POS"))
    .join(inner_pos, ex, "full")
    .drop(F.col("chr_temp"))
    # A sampled row keeps its own pos_temp. Any other row takes the last
    # non-null pos_temp at or before its POS on the same chromosome.
    .withColumn(
        "POS_INDEX",
        F.coalesce(
            F.col("pos_temp"),
            F.last(F.col("pos_temp"), ignorenulls=True).over(pos_index),
        ),
    )
    .drop(F.col("pos_temp"))
    .orderBy(F.col("#CHROM"), F.col("POS"))
    .cache()
)
temp.count()

226934049

In [6]:
# Window per (chromosome, POS_INDEX bucket), ordered by POS, from the bucket
# start up to and including the current POS value.
sample_w = (
    Window.partitionBy(F.col("#CHROM"), F.col("POS_INDEX"))
    .orderBy(F.col("POS"))
    .rangeBetween(Window.unboundedPreceding, Window.currentRow)
)
# No sort before dropDuplicates()/repartition(): both shuffle, so an earlier
# orderBy would be a full, discarded sort of the joined data. The window's
# partitioning matches this repartition, and the final orderBy below does the
# actual sort.
join_vcf_index = (
    join_vcf.join(temp, ["#CHROM", "POS"], "inner")
    .dropDuplicates()
    .repartition(F.col("#CHROM"), F.col("POS_INDEX"))
)

# Columns between the first 9 and the trailing POS_INDEX are treated as
# per-sample columns. NOTE(review): this assumes the first 9 columns are the
# fixed VCF columns; confirm the column order.
sample_cols = set(join_vcf_index.columns[9:-1])
# Forward-fill each sample column inside its bucket: a null takes the last
# non-null value at or before the current POS. A single select builds every
# fill at once instead of one withColumn projection per sample.
join_vcf_index = join_vcf_index.select(
    *[
        F.coalesce(F.col(name), F.last(name, ignorenulls=True).over(sample_w)).alias(name)
        if name in sample_cols
        else F.col(name)
        for name in join_vcf_index.columns
    ]
)

join_vcf_index = join_vcf_index.drop(F.col("POS_INDEX")).orderBy(F.col("#CHROM"), F.col("POS")).cache()
join_vcf_index.count()

Py4JJavaError: An error occurred while calling o1560.count.
: org.apache.spark.SparkException: Job 32 cancelled because killed via the Web UI
	at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1889)
	at org.apache.spark.scheduler.DAGScheduler.handleJobCancellation(DAGScheduler.scala:1824)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleStageCancellation$1.apply$mcVI$sp(DAGScheduler.scala:1813)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleStageCancellation$1.apply(DAGScheduler.scala:1806)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleStageCancellation$1.apply(DAGScheduler.scala:1806)
	at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
	at scala.collection.mutable.ArrayOps$ofInt.foreach(ArrayOps.scala:234)
	at org.apache.spark.scheduler.DAGScheduler.handleStageCancellation(DAGScheduler.scala:1806)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2073)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2059)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2048)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:737)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2082)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2101)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2126)
	at org.apache.spark.rdd.RDD$$anonfun$collect$1.apply(RDD.scala:945)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:363)
	at org.apache.spark.rdd.RDD.collect(RDD.scala:944)
	at org.apache.spark.sql.execution.SparkPlan.executeCollect(SparkPlan.scala:299)
	at org.apache.spark.sql.Dataset$$anonfun$count$1.apply(Dataset.scala:2836)
	at org.apache.spark.sql.Dataset$$anonfun$count$1.apply(Dataset.scala:2835)
	at org.apache.spark.sql.Dataset$$anonfun$52.apply(Dataset.scala:3370)
	at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:78)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:125)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:73)
	at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3369)
	at org.apache.spark.sql.Dataset.count(Dataset.scala:2835)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:748)


In [None]:
# Release every DataFrame this notebook marked with .cache(), not just
# join_vcf and temp. inner_pos and join_vcf_index are cached too, and
# unpersist() removes a cache registration even if its action never finished.
for cached_frame in (join_vcf_index, temp, inner_pos, join_vcf):
    cached_frame.unpersist()