In [9]:
# Tear down a session left over from a previous run so the builder in the next
# cell starts clean. Guarded so this cell also runs on a fresh kernel, where
# `spark` does not exist yet.
if "spark" in globals():
    spark.catalog.clearCache()
    spark.stop()

In [10]:
import findspark
findspark.init()

# Spark function
from pyspark import Row, StorageLevel
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import pandas_udf, udf, explode, array, when
from pyspark.sql.types import IntegerType, StringType, ArrayType,BooleanType
from pyspark.sql.window import Window
import pyspark.sql.functions as F

# Python function
import re
import subprocess
import numpy as np
import pandas
import pyarrow
from functools import reduce 
import copy
import time

appname = input("appname, folder name : ")
# Python strings are immutable, so a plain alias is equivalent to a deep copy.
folder_name = appname
gvcf_count = int(input("gvcf count : "))

# Spark settings applied to the session builder, one entry per .config() call.
SPARK_CONF = {
    "spark.driver.memory": "8G",
    "spark.driver.maxResultSize": "5G",
    "spark.executor.memory": "25G",
    "spark.sql.execution.arrow.enabled": "true",
    "spark.sql.execution.arrow.fallback.enabled": "true",
    "spark.network.timeout": "9999s",
    "spark.files.fetchTimeout": "9999s",
    "spark.sql.shuffle.partitions": 100,
    "spark.eventLog.enabled": "true",
}

# Start the Spark session against the standalone master.
session_builder = SparkSession.builder.master("spark://master:7077").appName(appname)
for conf_key, conf_value in SPARK_CONF.items():
    session_builder = session_builder.config(conf_key, conf_value)
spark = session_builder.getOrCreate()

appname, folder name : gvcf_indel_seperate_20
gvcf count : 20


In [11]:
def hadoop_list(length, hdfs):
    """List entries under an HDFS directory with ``hdfs dfs -ls``.

    Args:
        length: maximum number of paths to return (slice bound).
        hdfs: directory path passed to ``hdfs dfs -ls``.

    Returns:
        Up to ``length`` paths as ``bytes`` (callers decode them), in the order
        printed by ``hdfs dfs -ls``. The "Found N items" header line is skipped
        because it has fewer than 8 whitespace-separated fields.

    Raises:
        RuntimeError: if ``hdfs dfs -ls`` exits with a non-zero status, so a
            missing directory is reported here instead of surfacing later as an
            empty file list.
    """
    # Argument list instead of a shell string: no quoting/injection issues with
    # the path, and the exit status is hdfs's own (not awk's).
    proc = subprocess.Popen(["hdfs", "dfs", "-ls", hdfs],
                            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    s_output, s_err = proc.communicate()
    if proc.returncode != 0:
        raise RuntimeError("hdfs dfs -ls {} failed (exit {}): {}".format(
            hdfs, proc.returncode, s_err.decode("utf-8", "replace").strip()))

    all_dart_dirs = []
    for line in s_output.splitlines():
        fields = line.split()
        # Field 8 (index 7) of an `hdfs dfs -ls` entry line is the path.
        if len(fields) >= 8:
            all_dart_dirs.append(fields[7])
    return all_dart_dirs[:length]

def preVCF(hdfs, spark):
    """Load one (g)VCF text file into a DataFrame.

    The column names come from the ``#CHROM`` header line. Data lines are those
    matching ``[^#][^#]``, split on tabs. ``POS`` is cast to integer, and
    ``QUAL``/``FILTER`` are dropped. Every other column except ``#CHROM``,
    ``POS`` and the last one (taken as the sample column) gets the suffix
    ``"_" + <sample name>``, e.g. ``REF`` -> ``REF_<sample>``.
    """
    raw_lines = spark.sparkContext.textFile(hdfs)
    header = raw_lines.filter(lambda line: line.startswith("#CHROM")).first().split("\t")

    table = raw_lines.filter(lambda line: re.match("[^#][^#]", line)) \
                     .map(lambda line: line.split("\t")) \
                     .toDF(header) \
                     .withColumn("POS", F.col("POS").cast(IntegerType()))
    sample_name = table.columns[-1]
    table = table.drop("QUAL", "FILTER")

    key_columns = ("#CHROM", "POS")
    last_position = len(table.columns) - 1
    # Positional rename in one projection: keys and the sample column keep their names.
    renamed = [
        column if position == last_position or column in key_columns
        else column + "_" + sample_name
        for position, column in enumerate(table.columns)
    ]
    return table.toDF(*renamed)

def chunks(lst, n):
    """Yield consecutive slices of ``lst`` holding ``n`` items each (the last may be shorter)."""
    yield from (lst[start:start + n] for start in range(0, len(lst), n))
        
def addIndex(POS, size):
    """Map a position to ``int(POS / size + 1)`` (presumably a 1-based bucket index).

    Position 1 is returned unchanged.
    """
    return POS if POS == 1 else int(POS / size + 1)
# Spark UDF wrapper around addIndex, producing an integer column.
addIndex_udf = udf(addIndex, IntegerType())

def unionAll(*dfs):
    """Union all given DataFrames by column name, folding left to right."""
    return reduce(lambda merged, nxt: merged.unionByName(nxt), dfs)

# Indel helpers.
# True when the value has at least two characters (e.g. a multi-base REF allele).
word_len = udf(lambda col: len(col) >= 2, returnType=BooleanType())
# Characters after the first one, as a list: "ACG" -> ["C", "G"].
ref_melt = udf(lambda ref: list(ref[1:]), ArrayType(StringType()))

def ref_concat(temp):
    """Suffix each element with its 1-based position: ["C", "G"] -> ["C_1", "G_2"]."""
    return [base + "_" + str(position) for position, base in enumerate(temp, 1)]
# Rebind the name to its Spark UDF form (array<string> result).
ref_concat = udf(ref_concat, returnType=ArrayType(StringType()))

def ref_max(left, right):
    """Return the longer of two strings, treating None as missing.

    If one side is None the other side is returned. On equal length ``left`` wins.
    """
    if left is None:
        return right
    if right is None:
        return left
    return right if len(right) > len(left) else left
# Rebind the name to its Spark UDF form (string result).
ref_max = udf(ref_max, returnType=StringType())

def info_change(temp):
    """Extract the first ``DP=<value>`` entry from a semicolon-separated INFO string.

    Args:
        temp: INFO field text such as ``"AC=1;DP=10;MQ=60"``, or None.

    Returns:
        The first ``DP=...`` field, or None when ``temp`` is None or contains no
        ``DP=`` field. Returning None instead of raising keeps a single record
        without DP from failing the whole Spark task when this runs as a UDF.
    """
    if temp is None:
        return None
    return next((field for field in temp.split(";") if field.startswith("DP=")), None)
# Rebind the name to its Spark UDF form (string result).
info_change = udf(info_change, returnType=StringType())

# Sample-value helper: replace the first three characters (presumably the GT
# sub-field, e.g. "0/1") with "./." and keep the rest unchanged.
value_change = udf(lambda value: "./." + value[3:], returnType=StringType())

# POS index helper.
def sampling_func(data, ran):
    """Return every ``ran``-th element of ``data`` from index 0, via ``data.take``."""
    return data.take(range(0, len(data), ran))

def reduce_join(left, right):
    """Full-outer-join two VCF frames on ("#CHROM", "POS"), dropping raw sample columns.

    A raw sample column is a column ``X`` for which ``"FORMAT_" + X`` also exists.
    preVCF keeps the sample column under its bare name and renames FORMAT to
    ``FORMAT_<sample>``. A frame produced by an earlier join no longer has such a
    column, so it is joined unchanged and none of its ``FORMAT_<sample>``
    columns are lost. Dropping the last column unconditionally would remove
    one FORMAT column on every fold step after the first.
    """
    return _drop_sample_columns(left).join(
        _drop_sample_columns(right), ["#CHROM", "POS"], "full")


def _drop_sample_columns(frame):
    """Drop every column X whose companion ``FORMAT_X`` column is present."""
    present = set(frame.columns)
    samples = [column for column in frame.columns if "FORMAT_" + column in present]
    return frame.drop(*samples) if samples else frame

def reduce_inner_join(left, right):
    """Inner-join two frames on ("#CHROM", "POS"), keeping all remaining columns."""
    join_keys = ["#CHROM", "POS"]
    return left.join(right, join_keys, "inner")

def column_name(df_col, name):
    """Return the entries of ``df_col`` that start with ``name``, preserving order."""
    return [col for col in df_col if col.startswith(name)]

def max_value(value):
    """Return the largest truthy element of ``value``, or None when there is none."""
    present = [item for item in value if item]
    return max(present) if present else None
# Rebind the name to its Spark UDF form (string result).
max_value = udf(max_value, returnType=StringType())

def info_min(value):
    """Join the truthy, non-``END=`` entries of ``value`` with "%".

    Returns None when nothing remains after filtering.
    """
    kept = [info for info in value if info and not info.startswith("END=")]
    return "%".join(kept) or None
# Rebind the name to its Spark UDF form (string result).
info_min = udf(info_min, returnType=StringType())

def format_value(value):
    """Merge several colon-separated FORMAT strings into one key list.

    Falsy entries (None, "") are ignored. If only one FORMAT string remains,
    "GT:DP:GQ:MIN_DP:PL" is merged in as a second one. The strings are folded
    pairwise: the longer key list is the base, and keys from the other list that
    are missing from it are appended in order.

    Returns:
        The merged FORMAT string, or None when every input is None/empty.
        Previously that case made reduce() raise TypeError inside the UDF.
    """
    # ##FORMAT=<ID=SB,Number=4,Type=Integer,Description="Per-sample component statistics which comprise the Fisher's Exact Test to detect strand bias.">
    present = [fmt for fmt in value if fmt]
    if not present:
        return None
    if len(present) == 1:
        present.append("GT:DP:GQ:MIN_DP:PL")

    def format_reduce(left, right):
        left, right = left.split(":"), right.split(":")
        # Use the longer key list as the base; on a tie the right-hand list is the base.
        if len(left) <= len(right):
            left, right = right, left
        for key in right:
            if key not in left:
                left.append(key)
        return ":".join(left)

    return str(reduce(format_reduce, present))
# Rebind the name to its Spark UDF form (string result).
format_value = udf(format_value, returnType=StringType())

def with_vale(temp):
    """Collapse per-sample ID/REF/ALT/INFO/FORMAT columns into single columns.

    For each field, all columns of ``temp`` whose name starts with the field
    name are packed into an array. REF, ID and ALT are reduced with
    ``max_value``, INFO with ``info_min``, and FORMAT is kept as the raw array.
    The ``<field>_*`` source columns are then dropped. Column lists are always
    taken from the input frame, not from the partially rewritten one.
    """
    # NOTE(review): format_value is defined in this notebook but FORMAT is kept
    # as an array here — confirm whether the merged FORMAT string was intended.
    source_columns = temp.columns
    merged = temp
    for field, combine in (("REF", max_value), ("ID", max_value), ("ALT", max_value),
                           ("INFO", info_min), ("FORMAT", None)):
        packed = F.array(column_name(source_columns, field))
        merged = merged.withColumn(field, packed if combine is None else combine(packed))
        merged = merged.drop(*column_name(source_columns, field + "_"))
    return merged

def indel_union(temp):
    """Expand each multi-base REF record into one row per trailing reference base.

    Only rows whose REF has at least two characters are kept. For the k-th base
    after the first, the emitted row has REF = that base, POS = POS + k,
    ID = "." and ALT = "*,<NON_REF>". All other columns are copied from the
    source row.
    """
    # Built locally so the function no longer depends on a module-level
    # `split_col` that is only defined in a later cell.
    split_col = F.split("REF_temp", '_')
    temp = temp.filter(word_len(F.col("REF")))\
            .withColumn("REF", ref_melt(F.col("REF"))).withColumn("REF", ref_concat(F.col("REF")))\
            .withColumn("REF", explode(F.col("REF"))).withColumnRenamed("REF", "REF_temp")\
            .withColumn('REF', split_col.getItem(0)).withColumn('POS_var', split_col.getItem(1))\
            .drop(F.col("REF_temp")).withColumn("POS", (F.col("POS") + F.col("POS_var")).cast(IntegerType()))\
            .drop(F.col("POS_var"))\
            .withColumn('ID', F.lit("."))\
            .withColumn('ALT', F.lit("*,<NON_REF>"))
    return temp

# Designed around 10 frames per call (two chunks of 5).
def join_split(v_list, size=5):
    """Full-join a list of VCF frames into one frame with a two-level fold.

    Frames are folded with ``reduce_join`` in consecutive chunks of ``size``
    (default 5), and then the chunk results are folded together.

    Returns:
        A single joined frame. A one-element list yields that element itself,
        not the list.

    Raises:
        ValueError: if ``v_list`` is empty.
    """
    if not v_list:
        raise ValueError("join_split needs at least one frame")
    # reduce() over a one-element chunk returns that element unchanged, so
    # single-frame chunks need no special case.
    stage1 = [reduce(reduce_join, chunk) for chunk in chunks(v_list, size)]
    return reduce(reduce_join, stage1)

def join_split_inner(v_list):
    """Inner-join frames in consecutive groups of three; return one frame per group."""
    return [reduce(reduce_inner_join, group) for group in chunks(v_list, 3)]

def find_duplicate(temp):
    """Flag (#CHROM, POS) keys that occur more than once, duplicates first.

    Returns one row per key with integer column ``e``: 1 when the key appears
    in more than one row of ``temp``, otherwise 0. Rows are sorted by ``e``
    descending.
    """
    is_duplicate = (F.count("*") > 1).cast("int").alias("e")
    per_key = temp.groupBy(F.col("#CHROM"), F.col("POS")).agg(is_duplicate)
    return per_key.orderBy(F.col("e").desc())

In [12]:
# Locate the first `gvcf_count` gVCF files on HDFS and load each one as a cached DataFrame.
hdfs = "hdfs://master:9000"
hdfs_list = hadoop_list(gvcf_count, "/raw_data/gvcf")
vcf_list = [preVCF(hdfs + path.decode("UTF-8"), spark).cache() for path in hdfs_list]

In [13]:
# Full-outer-join every per-sample frame on (#CHROM, POS). Frames are folded in
# chunks of JOIN_CHUNK first and the chunk results are folded afterwards. This
# works for any number of loaded files, not only for 1-10, 15, 20 or 25.
if not vcf_list:
    raise ValueError("no gVCF files were loaded from /raw_data/gvcf")
JOIN_CHUNK = 5
chunk_frames = [reduce(reduce_join, vcf_list[start:start + JOIN_CHUNK])
                for start in range(0, len(vcf_list), JOIN_CHUNK)]
vcf = reduce(reduce_join, chunk_frames)

vcf = with_vale(vcf).cache()
vcf.count()

# Rows with no merged INFO get INFO = "END=<next POS in the chromosome - 1>".
info_window = Window.partitionBy("#CHROM").orderBy("POS")
vcf_not_indel = vcf.withColumn(
    "INFO",
    when(F.col("INFO").isNull(),
         F.concat(F.lit("END="), F.lead("POS", 1).over(info_window) - 1))
    .otherwise(F.col("INFO")),
)

# Indel expansion. `split_col` stays defined at module level for any version
# of indel_union that still reads it as a global.
split_col = F.split("REF_temp", '_')
indel = indel_union(vcf)

# Union of indel rows and the remaining rows.
# NOTE: dropDuplicates sometimes removes the indel rows. orderBy before
# dropDuplicates does not determine which duplicate survives — TODO pick the
# preferred row explicitly if that matters.
indel_com = unionAll(indel, vcf_not_indel)\
                 .orderBy(F.col("#CHROM"), F.col("POS"))\
                 .dropDuplicates(["#CHROM", "POS"])\
                 .cache()
indel_com.count()

indel.unpersist()
vcf.unpersist()

Py4JJavaError: An error occurred while calling o5107.count.
: org.apache.spark.SparkException: Job aborted due to stage failure: Master removed our application: KILLED
	at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1891)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1879)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1878)
	at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1878)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:927)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:927)
	at scala.Option.foreach(Option.scala:257)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:927)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2112)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2061)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2050)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:738)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2082)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2101)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2126)
	at org.apache.spark.rdd.RDD$$anonfun$collect$1.apply(RDD.scala:990)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:385)
	at org.apache.spark.rdd.RDD.collect(RDD.scala:989)
	at org.apache.spark.sql.execution.SparkPlan.executeCollect(SparkPlan.scala:299)
	at org.apache.spark.sql.Dataset$$anonfun$count$1.apply(Dataset.scala:2836)
	at org.apache.spark.sql.Dataset$$anonfun$count$1.apply(Dataset.scala:2835)
	at org.apache.spark.sql.Dataset$$anonfun$52.apply(Dataset.scala:3370)
	at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:80)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:127)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:75)
	at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3369)
	at org.apache.spark.sql.Dataset.count(Dataset.scala:2835)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:748)


In [5]:
# For every sample, full-join its sample column onto all (#CHROM, POS) keys of
# indel_com, then fill gaps with the last non-null value seen so far within
# the chromosome (ordered by POS).
sample_w = Window.partitionBy(F.col("#CHROM")).orderBy(F.col("POS"))\
                 .rangeBetween(Window.unboundedPreceding, Window.currentRow)
positions = indel_com.select(["#CHROM", "POS"])
parquet_list = []
for sample_vcf in vcf_list:
    # Columns from index 7 on; for a standard preVCF frame that is only the
    # sample column (assumes the usual 8-column layout — confirm for other inputs).
    aligned = positions.join(sample_vcf.select(["#CHROM", "POS"] + sample_vcf.columns[7:]),
                             ["#CHROM", "POS"], "full")
    sample_col = aligned.columns[-1]
    aligned = aligned.withColumn(sample_col, F.last(sample_col, ignorenulls=True).over(sample_w))
    parquet_list.append(aligned)

In [6]:
# Write each inner-joined group of three aligned samples as its own parquet
# dataset, then release the cached per-sample frames that fed that group.
OUTPUT_DIR = "/raw_data/output/gvcf_output/"
cnt = 0
for parquet in join_split_inner(parquet_list):
    # NOTE(review): fixed 30 s pause before each write; the reason is not
    # evident from this notebook — confirm it is still needed.
    time.sleep(30)
    parquet.write.mode('overwrite')\
           .parquet(OUTPUT_DIR + folder_name + "//" + "sample_" + str(cnt) + ".g.vcf")
    for cached_vcf in vcf_list[cnt:cnt + 3]:
        cached_vcf.unpersist()
    cnt += 3

indel_com.write.mode('overwrite')\
         .parquet(OUTPUT_DIR + folder_name + "//info.g.vcf")
# Disabled variant: .withColumn(sample_name, value_change(F.col(sample_name)))

In [7]:
#spark.read.parquet("/raw_data/output/gvcf_output/3/sample_0.g.vcf").count()