# Leftover cleanup from a previous run of this notebook: release cached data
# and stop any existing Spark session. Guarded so that a fresh kernel, where
# `spark` has not been created yet, does not fail with NameError.
if "spark" in globals():
    spark.catalog.clearCache()
    spark.stop()

In [1]:
import findspark
findspark.init()

# Spark function
from pyspark import Row, StorageLevel
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import pandas_udf, udf, explode, array, when
from pyspark.sql.types import IntegerType, StringType, ArrayType,BooleanType
from pyspark.sql.window import Window
import pyspark.sql.functions as F

# Python function
import re
import subprocess
import numpy as np
import pandas
import pyarrow
from functools import reduce 
import copy
import time

appname = input("appname, folder name : ")
# Strings are immutable, so a plain assignment is enough (no deepcopy needed).
folder_name = appname
gvcf_count = int(input("gvcf count : "))

# Start the Spark session for the merge stage.
spark = (
    SparkSession.builder.master("spark://master:7077")
    .appName(appname)
    .config("spark.driver.memory", "8G")
    .config("spark.driver.maxResultSize", "5G")
    .config("spark.executor.memory", "25G")
    .config("spark.memory.fraction", 0.2)
    .config("spark.sql.shuffle.partitions", 100)
    .config("spark.eventLog.enabled", "true")
    .getOrCreate()
)

appname, folder name : gvcf_indel_20_sep
gvcf count : 20


In [2]:
def hadoop_list(length, hdfs):
    """Return up to `length` paths listed by ``hdfs dfs -ls <hdfs>``.

    Args:
        length: maximum number of paths to return.
        hdfs: HDFS directory to list.

    Returns:
        A list of ``bytes`` paths in listing order. Each path is the 8th
        whitespace-separated field of an ``ls`` output line. Callers decode
        the paths themselves.

    Raises:
        RuntimeError: if ``hdfs dfs -ls`` exits with a non-zero status.
            Previously a failure silently produced an empty list.
    """
    # Run without a shell so the directory argument is never interpreted by
    # /bin/sh (no injection, no globbing or word-splitting surprises).
    proc = subprocess.run(
        ["hdfs", "dfs", "-ls", hdfs],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if proc.returncode != 0:
        raise RuntimeError(
            "hdfs dfs -ls %s failed (exit %d): %s"
            % (hdfs, proc.returncode, proc.stderr.decode("utf-8", errors="replace").strip())
        )
    paths = []
    for line in proc.stdout.splitlines():
        fields = line.split()
        # Lines such as the "Found N items" header have fewer than 8 fields and
        # are skipped, matching the former awk '{print $8}' pipeline.
        if len(fields) >= 8:
            paths.append(fields[7])
    return paths[:length]

def preVCF(hdfs, spark):
    """Load one (g)VCF text file into a per-sample DataFrame.

    The line starting with ``#CHROM`` provides the column names. Data rows are
    the lines whose first two characters are both something other than ``#``.
    ``POS`` is cast to int, and ``QUAL`` and ``FILTER`` are dropped.

    All remaining columns except ``#CHROM``, ``POS`` and the trailing sample
    column get the suffix ``"_" + <sample name>``. The sample name is the last
    header column. The suffix keeps names unique when several files are
    joined side by side.
    """
    lines = spark.sparkContext.textFile(hdfs)
    header = lines.filter(lambda line: line.startswith("#CHROM")).first().split("\t")
    records = (
        lines.filter(lambda line: re.match("[^#][^#]", line))
        .map(lambda line: line.split("\t"))
        .toDF(header)
        .withColumn("POS", F.col("POS").cast(IntegerType()))
    )
    sample = records.columns[-1]
    records = records.drop("QUAL", "FILTER")

    keep_as_is = ("#CHROM", "POS")
    *leading, trailing = records.columns
    # Rename positionally in one projection instead of one rename per column.
    renamed = [name if name in keep_as_is else name + "_" + sample for name in leading]
    return records.toDF(*renamed, trailing)

def chunks(lst, n):
    """Lazily yield consecutive slices of `lst`, each at most `n` items long."""
    yield from (lst[start:start + n] for start in range(0, len(lst), n))
        
def addIndex(POS, size):
    """Return ``POS // size + 1``, apparently a 1-based block index of `POS`.

    Position 1 is returned unchanged. Integer floor division replaces the old
    ``int(POS / size + 1)``, so large positions are not rounded through a
    float. A null position returns None instead of failing the Spark task.
    """
    if POS is None:
        return None
    if POS == 1:
        return POS
    return int(POS // size) + 1
# Spark UDF wrapper around addIndex with an integer result column.
addIndex_udf = udf(f=addIndex, returnType=IntegerType())

def unionAll(*dfs):
    """Stack all given DataFrames by column name, first to last.

    Uses ``DataFrame.unionByName`` pairwise. At least one DataFrame is
    required.
    """
    union_by_name = DataFrame.unionByName
    return reduce(union_by_name, dfs)

# --- indel helpers ---
# True when the string has at least two characters (a multi-base REF).
# A null value yields False instead of failing the Spark task.
word_len = udf(lambda col: col is not None and len(col) >= 2, returnType=BooleanType())
# Characters of `ref` after the first one, as a list of single characters.
# A null REF stays null.
ref_melt = udf(lambda ref: None if ref is None else list(ref)[1:], ArrayType(StringType()))

def ref_concat(temp):
    """Tag each element with its 1-based position as ``"<element>_<position>"``.

    ``["T", "G"]`` becomes ``["T_1", "G_2"]``. A null input returns None so
    the Spark UDF propagates nulls instead of failing.
    """
    if temp is None:
        return None
    return [item + "_" + str(offset) for offset, item in enumerate(temp, start=1)]
# Rebind ref_concat to a Spark UDF returning array<string>.
ref_concat = udf(ref_concat, returnType=ArrayType(StringType()))

def ref_max(left, right):
    """Return the longer of two strings, treating None as missing.

    If one side is None the other side is returned. On equal length the left
    value wins.
    """
    if left is None:
        return right
    if right is None:
        return left
    return left if len(left) >= len(right) else right
# Rebind ref_max to a Spark UDF returning a string.
ref_max = udf(ref_max, returnType=StringType())

def info_change(temp):
    """Return the first ``DP=...`` entry of a ``;``-separated INFO string.

    Returns None when `temp` is None or contains no ``DP=`` entry. The
    previous version raised IndexError or AttributeError, failing the task.
    """
    if temp is None:
        return None
    return next((field for field in temp.split(";") if field.startswith("DP=")), None)
# Rebind info_change to a Spark UDF returning a string.
info_change = udf(info_change, returnType=StringType())

# for sample value
def _mask_genotype(value):
    """Replace the first ``:``-separated sub-field (GT) of a sample value with ``./.``.

    ``"0/1:12:99"`` becomes ``"./.:12:99"``, and a value without ``:`` becomes
    ``"./."``. The previous lambda replaced exactly three characters, which
    corrupted GT values of other lengths (``"0:5"``, ``"10/12:..."``).
    None is passed through unchanged.
    """
    if value is None:
        return None
    _genotype, sep, rest = value.partition(":")
    return "./." + sep + rest


value_change = udf(_mask_genotype, StringType())

# for POS index
def sampling_func(data, ran):
    """Return every `ran`-th element of `data`, starting at index 0.

    `data` must support ``len()`` and a ``take(indices)`` method, such as a
    NumPy array.
    """
    step_indices = range(0, len(data), ran)
    return data.take(step_indices)

def _drop_sample_column(df):
    """Drop `df`'s trailing raw sample column if it still has one.

    A frame straight from preVCF ends with the sample column ``<s>`` next to
    ``FORMAT_<s>``. An already-joined frame ends with a ``FORMAT_<s>``
    column, which must be kept.
    """
    last = df.columns[-1]
    if "FORMAT_" + last in df.columns:
        return df.drop(last)
    return df


def reduce_join(left, right):
    """Full-outer-join two per-sample frames on (#CHROM, POS).

    Raw sample (genotype) columns are dropped before joining. Only a trailing
    column that is a real sample column is removed. The old version
    unconditionally dropped the last column, so in chained joins, e.g.
    ``reduce(reduce_join, frames)``, it discarded an already-merged
    ``FORMAT_<s>`` column at every step.

    NOTE(review): assumes frames without FORMAT columns (sites-only VCFs)
    are not passed here; for those nothing is dropped.
    """
    return _drop_sample_column(left).join(
        _drop_sample_column(right), ["#CHROM", "POS"], "full"
    )

def reduce_inner_join(left, right):
    """Inner-join two frames on the (#CHROM, POS) key columns."""
    keys = ["#CHROM", "POS"]
    return left.join(right, on=keys, how="inner")

def column_name(df_col, name):
    """Return the entries of `df_col` that start with `name`, in their original order."""
    return list(filter(lambda col: col.startswith(name), df_col))

def max_value(value):
    """Return the largest non-empty entry of `value`, or None if there is none.

    Null and empty-string entries are ignored.
    """
    return max((item for item in value if item), default=None)
# Rebind max_value to a Spark UDF returning a string.
max_value = udf(max_value, returnType=StringType())

def info_min(value):
    """Combine per-sample INFO strings, leaving out ``END=`` entries.

    Null and empty entries are ignored, and any entry starting with
    ``"END="`` is dropped. The remaining entries are joined with ``"%"``.
    Returns None when nothing is left.
    """
    kept = [info for info in value if info and not info.startswith("END=")]
    return "%".join(kept) or None
# Rebind info_min to a Spark UDF returning a string.
info_min = udf(info_min, returnType=StringType())

def format_value(value):
    """Merge several FORMAT key lists (``"GT:DP:..."``) into one.

    Null and empty entries are ignored. If exactly one entry remains, the
    default ``"GT:DP:GQ:MIN_DP:PL"`` is merged in as a second entry. Entries
    are then folded pairwise: the entry with more keys is kept in its order
    (on a tie, the right-hand one), and keys that only the other entry has
    are appended.

    Returns None when no non-empty entry is given. Previously this raised
    TypeError from ``reduce`` on an empty sequence.
    """
    # Header reference for the SB key:
    # ##FORMAT=<ID=SB,Number=4,Type=Integer,Description="Per-sample component statistics which comprise the Fisher's Exact Test to detect strand bias.">
    entries = [entry for entry in value if entry]
    if not entries:
        return None
    if len(entries) == 1:
        entries.append("GT:DP:GQ:MIN_DP:PL")

    def format_reduce(left, right):
        left_keys, right_keys = left.split(":"), right.split(":")
        # Use the longer key list as the base; ties go to the right-hand side.
        if len(left_keys) <= len(right_keys):
            left_keys, right_keys = right_keys, left_keys
        for key in right_keys:
            if key not in left_keys:
                left_keys.append(key)
        return ":".join(left_keys)

    return reduce(format_reduce, entries)
# Rebind format_value to a Spark UDF returning a string.
format_value = udf(format_value, returnType=StringType())

def with_vale(temp):
    """Collapse the per-sample REF/ID/ALT/INFO/FORMAT columns into one column each.

    For each field, all columns of `temp` whose names start with the field
    name are gathered into an array and combined:
      * REF, ID, ALT: ``max_value`` (largest non-empty value),
      * INFO: ``info_min`` (non-``END=`` entries joined by ``%``),
      * FORMAT: kept as the plain array of all gathered values.
    The ``<FIELD>_<sample>`` columns are then dropped. Column lists are taken
    from `temp` as passed in, before any of these steps run.
    """
    source_cols = temp.columns
    merges = [
        ("REF", max_value),
        ("ID", max_value),
        ("ALT", max_value),
        ("INFO", info_min),
        # NOTE(review): FORMAT stays an array column; format_value is not
        # applied here. Confirm this is intended.
        ("FORMAT", None),
    ]
    result = temp
    for field, combine in merges:
        gathered = F.array(column_name(source_cols, field))
        result = result.withColumn(field, gathered if combine is None else combine(gathered))
        result = result.drop(*column_name(source_cols, field + "_"))
    return result

def indel_union(temp):
    """Expand multi-character REF values into one row per trailing base.

    For every row whose REF has at least two characters, each character after
    the first becomes its own row at ``POS + k`` (k = 1, 2, ...). In that row
    REF is the single character, ID is ``"."`` and ALT is ``"*,<NON_REF>"``.
    All other columns are copied from the source row. Rows with a
    single-character or null REF are not returned.
    """
    split_col = F.split("REF_temp", '_')
    expanded = (
        # Native length check: null-safe and avoids a Python UDF round-trip.
        temp.filter(F.length(F.col("REF")) >= 2)
        .withColumn("REF", ref_melt(F.col("REF")))
        .withColumn("REF", ref_concat(F.col("REF")))
        .withColumn("REF", explode(F.col("REF")))
        .withColumnRenamed("REF", "REF_temp")
        .withColumn('REF', split_col.getItem(0))
        # Cast the "_<k>" suffix explicitly instead of relying on implicit
        # string-to-double coercion in the addition below.
        .withColumn('POS_var', split_col.getItem(1).cast(IntegerType()))
        .drop(F.col("REF_temp"))
        .withColumn("POS", (F.col("POS") + F.col("POS_var")).cast(IntegerType()))
        .drop(F.col("POS_var"))
        .withColumn('ID', F.lit("."))
        .withColumn('ALT', F.lit("*,<NON_REF>"))
    )
    return expanded

# Designed around batches of 10 frames (as passed by the merge cell).
def join_split(v_list, chunk_size=5):
    """Full-outer-join a list of per-sample frames in two stages.

    Frames are first joined within consecutive chunks of `chunk_size`
    (default 5). The chunk results are then joined together. This
    presumably keeps each join plan shallower than one long chain.

    Returns:
        A single DataFrame. A one-element list yields that element; the old
        version returned the list itself in that case, and also appended a
        list for any one-frame chunk.

    Raises:
        ValueError: if `v_list` is empty.
    """
    if not v_list:
        raise ValueError("join_split needs at least one DataFrame")
    # reduce() over a one-element chunk returns that element unchanged, so
    # single-frame chunks need no special case.
    stage1 = [reduce(reduce_join, chunk) for chunk in chunks(v_list, chunk_size)]
    return reduce(reduce_join, stage1)

def join_split_inner(v_list, num):
    """Inner-join `v_list` in consecutive groups of `num` frames.

    Returns a list with one joined DataFrame per group, joined on
    (#CHROM, POS) via reduce_inner_join. A group of one frame is returned
    as-is.
    """
    return [reduce(reduce_inner_join, group) for group in chunks(v_list, num)]

def find_duplicate(temp):
    """Flag (#CHROM, POS) keys that occur more than once.

    Returns one row per key with an int column ``e``: 1 if the key is
    duplicated, 0 otherwise. Rows are ordered by ``e`` descending, so
    duplicated keys come first.
    """
    is_dup = (F.count("*") > 1).cast("int").alias("e")
    per_key = temp.groupBy(F.col("#CHROM"), F.col("POS")).agg(is_dup)
    return per_key.orderBy(F.col("e"), ascending=False)

In [3]:
# List up to `gvcf_count` gVCF files under /raw_data/gvcf and load each one as
# a cached per-sample DataFrame.
hdfs = "hdfs://master:9000"
hdfs_list = hadoop_list(gvcf_count, "/raw_data/gvcf")
vcf_list = [preVCF(hdfs + path.decode("UTF-8"), spark).cache() for path in hdfs_list]

In [4]:
# Merge all per-sample frames with full outer joins on (#CHROM, POS). The
# join tree is hand-picked for each supported gVCF count.
if gvcf_count in (15, 20):
    temp1 = join_split(vcf_list[:10])
    temp2 = join_split(vcf_list[10:])
    vcf = reduce(reduce_join, [temp1, temp2])
elif gvcf_count == 25:
    temp1 = join_split(vcf_list[:10])
    temp2 = join_split(vcf_list[10:20])
    temp3 = join_split(vcf_list[20:])
    temp4 = reduce(reduce_join, [temp1, temp2])
    vcf = reduce(reduce_join, [temp3, temp4])
elif 5 < gvcf_count <= 10:
    temp1 = reduce(reduce_join, vcf_list[:5])
    temp2 = reduce(reduce_join, vcf_list[5:])
    vcf = reduce(reduce_join, [temp1, temp2])
elif gvcf_count <= 5:
    vcf = reduce(reduce_join, vcf_list)
else:
    # Any other count used to fall through with `vcf` undefined (NameError)
    # or, in a reused kernel, silently stale from an earlier run.
    raise ValueError(
        "unsupported gvcf count: %d (supported: <= 10, 15, 20 or 25)" % gvcf_count
    )

vcf = with_vale(vcf).cache()

# Rows whose merged INFO is null get "END=<next POS in the chromosome - 1>".
# On the last row of a chromosome lead() is null, so INFO stays null.
info_window = Window.partitionBy("#CHROM").orderBy("POS")
vcf_not_indel = vcf.withColumn(
    "INFO",
    when(
        F.col("INFO").isNull(),
        F.concat(F.lit("END="), F.lead("POS", 1).over(info_window) - 1),
    ).otherwise(F.col("INFO")),
)

# Add the expanded indel rows and keep one row per (#CHROM, POS).
# NOTE: dropDuplicates keeps an arbitrary row for each key, so an indel row
# can be lost in favour of another row at the same position. Not fixed yet.
# The sort runs after de-duplication: an orderBy placed before
# dropDuplicates does not survive its shuffle.
unionAll(indel_union(vcf), vcf_not_indel) \
    .dropDuplicates(["#CHROM", "POS"]) \
    .orderBy(F.col("#CHROM"), F.col("POS")) \
    .write.mode('overwrite') \
    .parquet("/raw_data/output/gvcf_output/" + folder_name + "//info.g.vcf")

spark.catalog.clearCache()
spark.stop()

In [5]:
# Second Spark session, used for the per-sample stage and its parquet writes.
# It uses a smaller spark.memory.fraction than the first session and sets
# spark.cleaner.periodicGC.interval to 15min.
spark = (
    SparkSession.builder.master("spark://master:7077")
    .appName(appname + "_sample")
    .config("spark.driver.memory", "8G")
    .config("spark.driver.maxResultSize", "5G")
    .config("spark.executor.memory", "25G")
    .config("spark.sql.shuffle.partitions", 100)
    .config("spark.eventLog.enabled", "true")
    .config("spark.memory.fraction", 0.1)
    .config("spark.cleaner.periodicGC.interval", "15min")
    .getOrCreate()
)

In [6]:
# Read the merged info output, keep only the (#CHROM, POS) keys, cache them,
# and force materialization with count().
indel_com = (
    spark.read.parquet("/raw_data/output/gvcf_output/" + folder_name + "//info.g.vcf")
    .select(["#CHROM", "POS"])
    .orderBy(F.col("#CHROM"), F.col("POS"))
    .cache()
)
indel_com.count()

# Re-load the per-sample gVCFs with the new session (not cached this time).
vcf_list = [preVCF(hdfs + path.decode("UTF-8"), spark) for path in hdfs_list]

In [8]:
# Forward-fill window: all rows of the same chromosome up to and including
# the current POS (rows with equal POS are peers in a range frame).
sample_w = Window.partitionBy(F.col("#CHROM")) \
    .orderBy(F.col("POS")) \
    .rangeBetween(Window.unboundedPreceding, Window.currentRow)
parquet_list = []

for sample_vcf in vcf_list:
    sample_col = sample_vcf.columns[-1]
    # Keep the keys plus the columns from index 7 on (the sample column), and
    # pass the sample value through value_change.
    # NOTE(review): value_change runs before the join, so the sample's own
    # records are rewritten too, not only the forward-filled gaps. Confirm.
    sample_values = sample_vcf.select(["#CHROM", "POS"] + sample_vcf.columns[7:]) \
        .withColumn(sample_col, value_change(F.col(sample_col)))
    temp = indel_com.join(sample_values, ["#CHROM", "POS"], "full")

    # Fill rows without a record for this sample with its last non-null value.
    filled_col = temp.columns[-1]
    temp = temp.withColumn(filled_col, F.last(filled_col, ignorenulls=True).over(sample_w))
    parquet_list.append(temp)

In [None]:
# Inner-join the per-sample frames in groups of `num` and write each group to
# sample_<index of its first frame>.g.vcf (sample_0, sample_3, ... for num = 3).
num = 3
cnt = 0
for parquet in join_split_inner(parquet_list, num):
    out_path = "/raw_data/output/gvcf_output/"+ folder_name + "//" + "sample_" + str(cnt) + ".g.vcf"
    parquet.write.mode('overwrite').parquet(out_path)
    cnt += num

In [None]:
# Sanity-check the written outputs: row counts, then a preview of each.
# This must run BEFORE the session is stopped; reading through a stopped
# session fails.
output_dir = "/raw_data/output/gvcf_output/" + folder_name
sample_0_df = spark.read.parquet(output_dir + "/sample_0.g.vcf")
sample_3_df = spark.read.parquet(output_dir + "/sample_3.g.vcf")
info_df = spark.read.parquet(output_dir + "/info.g.vcf")

for checked_df in (sample_0_df, sample_3_df, info_df):
    print(checked_df.count())

for checked_df in (sample_0_df, sample_3_df, info_df):
    # show() prints the table itself and returns None, so it is not wrapped
    # in print().
    checked_df.show()

spark.catalog.clearCache()
spark.stop()