# Drop every cached table/DataFrame and stop the Spark session.
# NOTE(review): `spark` is not defined above this point in the file; this
# presumably targets a session left over from an earlier run - confirm.
spark.catalog.clearCache()
spark.stop()

In [1]:
import findspark
findspark.init()

# Spark function
from pyspark import Row, StorageLevel
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import pandas_udf, udf, explode, array, when, PandasUDFType
from pyspark.sql.types import IntegerType, StringType, ArrayType, BooleanType, MapType
from pyspark.sql.window import Window
import pyspark.sql.functions as F

# Python function
import re
import subprocess
import pandas as pd
import pyarrow
from functools import reduce 
from collections import Counter
import copy
import operator
import itertools

# Prompt for a single name that is used both as the Spark application name
# and as the output folder name.
appname = input("appname, folder name : ")
# NOTE(review): strings are immutable, so deepcopy yields an equal string;
# a plain assignment would presumably behave the same.
folder_name = copy.deepcopy(appname) 
# Number of gVCF files to process; passed to hadoop_list() as the list length.
gvcf_count = int(input("gvcf count : "))

# Start for Spark Session
# Connects to the standalone master at spark://master:7077 and reuses an
# existing session if one is already active (getOrCreate).
spark = SparkSession.builder.master("spark://master:7077")\
                        .appName(appname)\
                        .config("spark.driver.memory", "8G")\
                        .config("spark.driver.maxResultSize", "5G")\
                        .config("spark.executor.memory", "25G")\
                        .config("spark.memory.fraction", 0.2)\
                        .config("spark.sql.shuffle.partitions", 100)\
                        .config("spark.eventLog.enabled", "true")\
                        .config("spark.cleaner.periodicGC.interval", "15min")\
                        .getOrCreate()

appname, folder name : gvcf_indel_0311
gvcf count : 20


In [2]:
def hadoop_list(length, hdfs):
    """Return the first ``length`` paths listed under an HDFS directory.

    Runs ``hdfs dfs -ls <hdfs>`` and keeps the eighth whitespace-separated
    field of every output line. Lines with fewer than eight fields (for
    example a summary line) contribute nothing.

    Args:
        length: Maximum number of paths to return.
        hdfs: HDFS path to list.

    Returns:
        A list of paths as ``bytes`` objects, in listing order, truncated to
        ``length`` entries. Callers decode them themselves.

    Raises:
        RuntimeError: If the ``hdfs`` command exits with a non-zero status.
    """
    # An argument list without a shell keeps spaces or shell metacharacters
    # in ``hdfs`` from being interpreted as part of the command line.
    proc = subprocess.run(
        ["hdfs", "dfs", "-ls", hdfs],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if proc.returncode != 0:
        # Fail loudly instead of returning an empty list that only breaks
        # much later, far away from the real cause.
        message = proc.stderr.decode("utf-8", errors="replace").strip()
        raise RuntimeError(
            "hdfs dfs -ls " + str(hdfs) + " failed with exit code "
            + str(proc.returncode) + ": " + message
        )

    paths = []
    for line in proc.stdout.splitlines():
        fields = line.split()
        # Field 8 (index 7) is the path column; shorter lines have no path.
        if len(fields) >= 8:
            paths.append(fields[7])
    return paths[:length]
    
def preVCF(hdfs, spark):
    """Load one VCF text file into a DataFrame of per-site array columns.

    The column names come from the ``#CHROM`` header line. Data rows are the
    lines whose first two characters are both not ``#``. After parsing:

    * ``POS`` is cast to integer.
    * ``FORMAT`` is split on ``:`` and the ``GT`` entry is removed.
    * ``INFO`` is set to null when ``info_remove`` flags it.
    * ``QUAL``, ``FILTER`` and the last (sample) column are dropped.
    * Every remaining column except ``#CHROM`` and ``POS`` is wrapped in an
      array and renamed to ``<column>_<sample name>``.

    Args:
        hdfs: Path of the VCF text file to read.
        spark: Active SparkSession.

    Returns:
        The transformed DataFrame.
    """
    raw_lines = spark.sparkContext.textFile(hdfs)
    header = raw_lines.filter(lambda line: line.startswith("#CHROM")).first().split("\t")
    records = raw_lines.filter(lambda line: re.match("[^#][^#]", line))\
                       .map(lambda line: line.split("\t"))

    table = records.toDF(header)
    table = table.withColumn("POS", F.col("POS").cast(IntegerType()))
    table = table.withColumn("FORMAT", F.array_remove(F.split(F.col("FORMAT"), ":"), "GT"))
    table = table.withColumn("INFO", when(info_remove(F.col("INFO")), None).otherwise(F.col("INFO")))

    # The sample column is the last one in the header.
    sample_name = table.columns[-1]
    table = table.drop("QUAL", "FILTER", sample_name)

    key_columns = ("#CHROM", "POS")
    # Iterate over a snapshot of the names; renaming keeps column positions.
    for column in table.columns:
        if column in key_columns:
            continue
        table = table.withColumn(column, F.array(column))
        table = table.withColumnRenamed(column, column + "_" + sample_name)
    return table

# UDF: return the input array without its first element.
firstremove = udf(lambda parts: parts[1:], ArrayType(StringType()))
def sampleVCF(hdfs, spark):
    """Load one VCF text file and return its per-site sample values as a map.

    The column names come from the ``#CHROM`` header line and the sample
    column is the last one. ``FORMAT`` is split on ``:`` with ``GT`` removed,
    the sample value is split on ``:`` with its first element removed, and
    the two arrays are zipped into a map column named after the sample.

    Args:
        hdfs: Path of the VCF text file to read.
        spark: Active SparkSession.

    Returns:
        A DataFrame with columns ``#CHROM``, ``POS`` and the sample map.
    """
    raw_lines = spark.sparkContext.textFile(hdfs)
    header = raw_lines.filter(lambda line: line.startswith("#CHROM")).first().split("\t")
    records = raw_lines.filter(lambda line: re.match("[^#][^#]", line))\
                       .map(lambda line: line.split("\t"))

    table = records.toDF(header)
    table = table.withColumn("POS", F.col("POS").cast(IntegerType()))

    sample_name = table.columns[-1]
    table = table.select(F.col("#CHROM"), F.col("POS"), F.col("FORMAT"), F.col(sample_name))

    format_keys = F.array_remove(F.split(F.col("FORMAT"), ":"), "GT")
    table = table.withColumn("FORMAT", format_keys)

    sample_values = firstremove(F.split(F.col(sample_name), ":"))
    table = table.withColumn(sample_name, sample_values)

    # Pair each remaining FORMAT key with the sample value at the same position.
    table = table.withColumn(sample_name, F.map_from_arrays(F.col("FORMAT"), F.col(sample_name)))
    return table.select("#CHROM", "POS", sample_name)

def gatkVCF(hdfs, spark):
    """Load one VCF text file into a DataFrame of string columns.

    The column names come from the ``#CHROM`` header line. Data rows are the
    lines whose first two characters are both not ``#``, split on tabs.

    Args:
        hdfs: Path of the VCF text file to read.
        spark: Active SparkSession.

    Returns:
        The parsed DataFrame, with no further transformation.
    """
    raw_lines = spark.sparkContext.textFile(hdfs)
    header = raw_lines.filter(lambda line: line.startswith("#CHROM")).first().split("\t")
    records = raw_lines.filter(lambda line: re.match("[^#][^#]", line))
    fields = records.map(lambda line: line.split("\t"))
    return fields.toDF(header)

def chunks(lst, n):
    """Yield consecutive slices of ``lst`` that are ``n`` items long.

    The final slice is shorter when ``len(lst)`` is not a multiple of ``n``.
    """
    yield from (lst[start:start + n] for start in range(0, len(lst), n))

def unionAll(*dfs):
    """Fold all given DataFrames into one with ``DataFrame.unionByName``.

    The DataFrames are combined left to right, matching columns by name.
    """
    union_by_name = DataFrame.unionByName
    return reduce(union_by_name, dfs)

# for indel
# UDF: True when the value's length is at least 2.
word_len = udf(lambda text: len(text) >= 2, returnType=BooleanType())
# UDF: break the value into a list of its items and drop the first one.
ref_melt = udf(lambda bases: [*bases][1:], ArrayType(StringType()))

def ref_concat(temp):
    """Append ``_<position>`` to each element, with positions starting at 1."""
    return [base + "_" + str(position) for position, base in enumerate(temp, start=1)]
# Expose the function as a Spark UDF returning an array of strings.
ref_concat = udf(ref_concat, ArrayType(StringType()))

#@pandas_udf(ArrayType(StringType()))
def list_flat(value):
    """Flatten the non-empty inner sequences of ``value`` and drop the first item.

    Falsy entries (``None``, empty sequences) are skipped before flattening.
    Raises ``IndexError`` when nothing is left to drop.
    """
    flattened = [item for group in value if group for item in group]
    del flattened[0]
    return flattened
# Expose the function as a Spark UDF returning an array of strings.
list_flat = udf(list_flat, ArrayType(StringType()))

def info_change(temp):
    """Return the first ``;``-separated field of ``temp`` that starts with ``DP=``.

    Raises ``IndexError`` when no such field exists.
    """
    depth_fields = list(filter(lambda field: field.startswith('DP='), temp.split(";")))
    return depth_fields[0]
# Expose the function as a Spark UDF returning a string.
info_change = udf(info_change, StringType())

# for sample value
# UDF: replace the first three characters of the value with "./.".
value_change = udf(lambda text: "./." + text[3:], StringType())

def reduce_join(left, right):
    """Full-outer-join two per-site tables and merge their annotation columns.

    The tables are joined on ``#CHROM`` and ``POS``. For each of ``REF``,
    ``ID``, ``ALT``, ``INFO`` and ``FORMAT`` the first column of ``left``
    whose name starts with that prefix is overwritten with:

    * the matching ``right`` column when the ``left`` value is null,
    * the ``left`` value when the ``right`` value is null,
    * otherwise the ``array_union`` of the joined columns with that prefix.

    ``INFO`` is merged once, exactly like the other columns. Running the
    identical merge a second time leaves duplicate-free arrays unchanged, so
    the extra pass only added work to the query plan.

    Args:
        left: DataFrame whose merged columns are kept in the result.
        right: DataFrame merged into ``left``. Its columns from the third
            onwards are dropped afterwards; this assumes its first two
            columns are the join keys.

    Returns:
        The joined DataFrame carrying the key columns plus ``left``'s
        (now merged) columns.
    """
    joined = left.join(right, ["#CHROM", "POS"], "full")

    left_names = left.columns
    right_names = right.columns
    # Snapshot of the joined schema; array_union is fed every column that
    # carries the current prefix.
    joined_names = joined.columns

    for name in ("REF", "ID", "ALT", "INFO", "FORMAT"):
        left_col = column_name(left_names, name)[0]
        right_col = column_name(right_names, name)[0]
        merged = when(F.isnull(left_col), F.col(right_col))\
            .when(F.isnull(right_col), F.col(left_col))\
            .otherwise(F.array_union(*column_name(joined_names, name)))
        joined = joined.withColumn(left_col, merged)

    # Everything right contributed beyond the join keys is now folded into
    # left's columns.
    return joined.drop(*right_names[2:])

def column_rename(vcf):
    """Rename the first column matching each annotation prefix to the bare prefix.

    For ``REF``, ``ID``, ``ALT``, ``INFO`` and ``FORMAT`` in that order, the
    first column whose name starts with the prefix is renamed to it.
    """
    for base in ("REF", "ID", "ALT", "INFO", "FORMAT"):
        current = column_name(vcf.columns, base)[0]
        vcf = vcf.withColumnRenamed(current, base)
    return vcf
    

def reduce_inner_join(left, right):
    """Inner-join two DataFrames on the ``#CHROM`` and ``POS`` columns."""
    join_keys = ["#CHROM", "POS"]
    return left.join(right, join_keys, "inner")

def column_name(df_col, name):
    """Return the entries of ``df_col`` that start with ``name``, in order."""
    return [candidate for candidate in df_col if candidate.startswith(name)]

# UDF: True when the value starts with "END=".
info_remove = udf(lambda info: info.startswith("END="), BooleanType())

def column_revalue(vcf):
    """Collapse the merged array columns into single string values.

    ``ID``, ``REF``, ``ALT`` and ``INFO`` are each replaced by the maximum
    element of their array. ``FORMAT`` is flattened, de-duplicated, sorted,
    joined with ``:`` and prefixed with ``GT:``.
    """
    # NOTE: the INFO value still needs to be revised.
    for name in ("ID", "REF", "ALT", "INFO", "FORMAT"):
        if name != "FORMAT":
            vcf = vcf.withColumn(name, F.array_max(F.col(name)))
            continue
        unique_keys = F.array_sort(F.array_distinct(F.flatten(F.col(name))))
        vcf = vcf.withColumn(name, unique_keys)
        vcf = vcf.withColumn(name, F.concat(F.lit("GT:"), F.array_join(F.col(name), ":")))
    return vcf

def indel_union(temp):
    """Expand every multi-character REF row into one row per trailing character.

    Rows are kept when ``word_len`` accepts their ``REF``. The ``REF`` value
    is split by ``ref_melt``, tagged by ``ref_concat`` as
    ``<item>_<offset>``, and exploded into one row per item. Each new row
    gets that item as ``REF``, ``POS`` increased by the offset, ``ID`` set
    to ``.`` and ``ALT`` set to ``*,<NON_REF>``.
    """
    expanded = temp.filter(word_len(F.col("REF")))
    expanded = expanded.withColumn("REF", ref_melt(F.col("REF")))
    expanded = expanded.withColumn("REF", ref_concat(F.col("REF")))
    expanded = expanded.withColumn("REF", explode(F.col("REF")))
    expanded = expanded.withColumnRenamed("REF", "REF_temp")

    # "<item>_<offset>" -> item at index 0, offset at index 1.
    parts = F.split("REF_temp", '_')
    expanded = expanded.withColumn('REF', parts.getItem(0))
    expanded = expanded.withColumn('POS_var', parts.getItem(1))
    expanded = expanded.drop(F.col("REF_temp"))

    shifted_pos = (F.col("POS") + F.col("POS_var")).cast(IntegerType())
    expanded = expanded.withColumn("POS", shifted_pos)
    expanded = expanded.drop(F.col("POS_var"))

    expanded = expanded.withColumn('ID', F.lit("."))
    return expanded.withColumn('ALT', F.lit("*,<NON_REF>"))

# Based on groups of 10 inputs.
def join_split(v_list):
    """Join a list of per-site tables in two stages.

    The inputs are split into groups of five, each group is folded with
    ``reduce_join``, and the group results are folded with ``reduce_join``
    again.

    A single table, whether it is the only input or the only member of a
    group, is passed through as that table rather than wrapped in a list, so
    later joins always receive a DataFrame.

    Args:
        v_list: Non-empty sequence of DataFrames to join.

    Returns:
        One DataFrame holding the join of every input.

    Raises:
        ValueError: If ``v_list`` is empty.
    """
    if not v_list:
        raise ValueError("join_split() needs at least one DataFrame")
    if len(v_list) == 1:
        return v_list[0]

    # reduce() hands back the lone element of a one-item group untouched,
    # so no special case is needed for a short final group.
    stage1 = [reduce(reduce_join, group) for group in chunks(v_list, 5)]
    return reduce(reduce_join, stage1)

def join_split_inner(v_list, num):
    stage1_list = list(chunks(v_list, num))
    stage1_list = list(map())
    stage1 = []
    for vcf in stage1_list:
        stage1.append(reduce(reduce_inner_join, vcf))
    return stage1

def value_merge(find_value, sample):
    """Join a sample's values with ``:`` after removing its ``GT`` entry.

    ``sample`` is modified in place: its ``"GT"`` key is deleted, which
    raises ``KeyError`` when the key is absent. When ``find_value`` and the
    remaining ``sample`` have the same length, the sample's values are
    joined in the mapping's own order. Otherwise the value for each key in
    ``find_value`` is looked up, with ``.`` standing in for missing entries.
    """
    del sample["GT"]
    if len(find_value) != len(sample):
        fields = []
        for key in find_value:
            found = sample.get(key)
            fields.append("." if found is None else found)
        return ":".join(fields)
    return ":".join(list(sample.values()))
# Expose the function as a Spark UDF returning a string.
value_merge = udf(value_merge, StringType())

def index2dict(value, index):
    """Join a sample's values in key order, filling missing keys with ``.``.

    Args:
        value: Mapping of key to string value for the sample.
        index: Keys that should also appear in the output; each one is
            given the placeholder ``.``.

    Returns:
        The values joined with ``:`` in ascending key order, or an empty
        string when both inputs are empty.
    """
    placeholders = dict.fromkeys(index, ".")
    # Real sample values take precedence, so a key present in both inputs
    # is never overwritten by a placeholder.
    merged = {**placeholders, **value}
    # str.join handles zero or one entry directly and avoids rebuilding the
    # string on every step.
    return ":".join(merged[key] for key in sorted(merged))
# Expose the function as a Spark UDF returning a string.
index2dict = udf(index2dict, StringType())

# UDF: True when the value has length zero.
nullCheck = udf(lambda items: len(items) == 0, BooleanType())

def str2list(value):
    """Return a list of ``.`` placeholders, one per item of ``value``."""
    return ["."] * len(value)
# Expose the function as a Spark UDF returning an array of strings.
str2list = udf(str2list, ArrayType(StringType()))

def dictsort(value):
    """Join a mapping's values with ``:`` in ascending key order.

    Args:
        value: Mapping of key to string value.

    Returns:
        The joined string, or an empty string for an empty mapping.
    """
    # str.join handles zero or one entry directly, where the earlier
    # reduce-based fold raised TypeError on an empty mapping.
    return ":".join(value[key] for key in sorted(value))
# Expose the function as a Spark UDF returning a string.
dictsort = udf(dictsort, StringType())

def find_duplicate(temp):
    """Flag every ``(#CHROM, POS)`` pair that occurs more than once.

    Returns one row per pair with an integer column ``e`` that is 1 when the
    pair occurs more than once and 0 otherwise, ordered with the 1s first.
    """
    is_repeated = (F.count("*") > 1).cast("int").alias("e")
    per_site = temp.groupBy(F.col("#CHROM"), F.col("POS")).agg(is_repeated)
    return per_site.orderBy(F.col("e"), ascending=False)

def vcf_join(v_list):
    """Join all per-site tables into one DataFrame.

    The inputs are split into groups of ten and each group is joined with
    ``join_split``. While more than two partial results remain they are
    joined pairwise, and the final one or two results are folded with
    ``reduce_join``.

    Args:
        v_list: Non-empty sequence of DataFrames to join.

    Returns:
        A single DataFrame. Ten or fewer inputs also give a DataFrame rather
        than a one-element list, so the result can be used directly.

    Raises:
        ValueError: If ``v_list`` is empty.
    """
    if not v_list:
        raise ValueError("vcf_join() needs at least one DataFrame")

    def merge_group(group):
        # A lone table needs no join; pass it through as a DataFrame.
        return group[0] if len(group) == 1 else join_split(group)

    level = [merge_group(group) for group in chunks(v_list, 10)]
    while len(level) > 2:
        level = [merge_group(pair) for pair in chunks(level, 2)]
    # reduce() returns a single remaining table unchanged.
    return reduce(reduce_join, level)
    
def list2dict_del(value):
    """Map every item of ``value`` to the placeholder ``.``."""
    return dict.fromkeys(value, ".")
# Expose the function as a Spark UDF returning a string-to-string map.
list2dict_del = udf(list2dict_del, returnType=MapType(StringType(), StringType()))

def index2dict_del(value, FORMAT):
    """Copy matching entries of ``value`` into ``FORMAT`` and join its values.

    ``FORMAT`` is updated in place: every key of ``value`` that already
    exists in ``FORMAT`` has its value overwritten. The values of ``FORMAT``
    are then joined with ``:`` in the mapping's own order.
    """
    for key, entry in list(value.items()):
        if key in FORMAT:
            FORMAT[key] = entry
    return ":".join(FORMAT.values())
# Expose the function as a Spark UDF returning a string.
index2dict_del = udf(index2dict_del, StringType())

In [None]:
#gvcf_count = 20
# Namenode prefix that is prepended to every path returned by hadoop_list().
hdfs = "hdfs://master:9000"
# First `gvcf_count` entries listed under /raw_data/gvcf (returned as bytes).
hdfs_list = hadoop_list(gvcf_count, "/raw_data/gvcf")
vcf_list = list()

# Parse each listed file with preVCF() and mark the result for caching.
for index in range(len(hdfs_list)):
    vcf_list.append(preVCF(hdfs + hdfs_list[index].decode("UTF-8"), spark).cache())
    
# Join all per-file tables, restore the plain column names, then collapse the
# merged array columns into single values.
vcf = column_rename(vcf_join(vcf_list))
vcf = column_revalue(vcf).persist(StorageLevel.MEMORY_ONLY)
# count() is an action, so it forces the persisted table to be computed.
vcf.count()

In [None]:
#window
# Per-chromosome window ordered by position, used to read the next row's POS.
info_window = Window.partitionBy("#CHROM").orderBy("POS")
# Where INFO is null, replace it with "END=" followed by (next row's POS - 1);
# rows that already have an INFO value keep it.
vcf_not_indel = vcf.withColumn("INFO", when(F.col("INFO").isNull(), F.concat(F.lit("END="), F.lead("POS", 1).over(info_window) - 1))\
                              .otherwise(F.col("INFO")))

# Known issue: dropDuplicates can remove indel rows in some cases; not fixed yet.
# indel union & parquet write
unionAll(*[indel_union(vcf), vcf_not_indel])\
                 .orderBy(F.col("#CHROM"), F.col("POS"))\
                 .dropDuplicates(["#CHROM", "POS"])\
                 .write.mode('overwrite')\
                 .parquet("/raw_data/output/gvcf_output/" + folder_name + "//info.g.vcf")

# Release cached data and stop this session; a new one is created below.
spark.catalog.clearCache()
spark.stop()

## Parquet write for sample

In [33]:
# cnt: running index used in the output file names of the sample parquet files.
cnt = 0
# num: number of sample tables joined into each output file (see join_split_inner).
num = 4
# part_num: value used for spark.sql.shuffle.partitions in this session.
part_num = 80
# Start for Spark Session with write
# The application name encodes `num` and `part_num` alongside the app name.
spark = SparkSession.builder.master("spark://master:7077")\
                        .appName(appname + "_sample" + str(num) + "_" + str(part_num))\
                        .config("spark.driver.memory", "8G")\
                        .config("spark.driver.maxResultSize", "5G")\
                        .config("spark.executor.memory", "25G")\
                        .config("spark.sql.shuffle.partitions", part_num)\
                        .config("spark.eventLog.enabled", "true")\
                        .config("spark.memory.fraction", 0.05)\
                        .config("spark.cleaner.periodicGC.interval", "15min")\
                        .getOrCreate()

#.config("spark.cleaner.periodicGC.interval", "1min")

In [34]:
# Namenode prefix that is prepended to every path returned by hadoop_list().
hdfs = "hdfs://master:9000"
# First `gvcf_count` entries listed under /raw_data/gvcf (returned as bytes).
hdfs_list = hadoop_list(gvcf_count, "/raw_data/gvcf")
vcf_list = list()
# Build one (#CHROM, POS, sample map) table per listed file.
for index in range(len(hdfs_list)):
    vcf_list.append(sampleVCF(hdfs + hdfs_list[index].decode("UTF-8"), spark))

In [35]:
# info parquet read and cache()
# Only #CHROM, POS and FORMAT are kept; FORMAT is split on ":" with the "GT"
# entry removed, so it becomes an array of the remaining keys.
indel_com = spark.read.parquet("/raw_data/output/gvcf_output/" + folder_name + "//info.g.vcf")\
                 .select(["#CHROM","POS","FORMAT"])\
                 .withColumn("FORMAT", F.array_remove(F.split(F.col("FORMAT"), ":"), "GT"))\
                 .orderBy(F.col("#CHROM"), F.col("POS")).persist(StorageLevel.MEMORY_ONLY)
# count() is an action, so it forces the persisted table to be computed.
indel_com.count()

364124918

In [36]:
# Per-chromosome window ordered by POS that spans from the start of the
# partition up to the current row.
sample_w = Window.partitionBy(F.col("#CHROM")).orderBy(F.col("POS")).rangeBetween(Window.unboundedPreceding, Window.currentRow)  
parquet_list = list()

# Build one output table per sample, aligned to the sites in indel_com.
for index in range(len(vcf_list)):
    # Full outer join on the site keys; the sample map is the last column.
    temp = indel_com.join(vcf_list[index], ["#CHROM", "POS"], "full").repartition(F.col("#CHROM"))
    sample_name = temp.columns[-1]
    
    # Fill each row's sample map with the last non-null map seen so far in the
    # window, then compute "index": the FORMAT keys that the sample map lacks,
    # with "SB" removed.
    temp = temp.withColumn(sample_name, F.last(sample_name, ignorenulls=True).over(sample_w))\
             .withColumn("key", F.map_keys(F.col(sample_name)))\
             .withColumn("index", F.array_remove(F.array_except("FORMAT", "key"), "SB"))\
             .drop("key")

    # Rows whose sample keys differ from FORMAT: pad the missing keys through
    # index2dict() and prefix the result with "./.:".
    # NOTE(review): FORMAT is dropped right after the filter and again at the
    # end of the chain; the second drop looks redundant - confirm.
    null_not_value = temp.filter(F.map_keys(F.col(sample_name)) != F.col("FORMAT")).drop("FORMAT")\
                     .repartition(F.col("#CHROM"), F.col("POS"), F.col(sample_name), F.col("index"))\
                     .withColumn(sample_name, F.concat(F.lit("./.:"), index2dict(F.col(sample_name), F.col("index"))))\
                     .drop("index").drop("FORMAT")\
    
    # Rows whose sample keys equal FORMAT: join the map values with ":" and
    # prefix the result with "./.:".
    null_value = temp.filter(F.map_keys(F.col(sample_name)) == F.col("FORMAT")).drop("FORMAT", "index")\
             .withColumn(sample_name, F.concat(F.lit("./.:"), F.array_join(F.map_values(F.col(sample_name)), ":")))
    
    # Recombine both row sets (union matches columns by position).
    value_union = null_not_value.union(null_value)
    parquet_list.append(value_union)    

In [37]:
# Inner-join the sample tables in groups of `num` and write each group to its
# own parquet path; `cnt` advances by `num` so the file name carries the index
# of the group's first sample.
for parquet in join_split_inner(parquet_list, num):
    parquet.write.mode('overwrite')\
            .parquet("/raw_data/output/gvcf_output/"+ folder_name + "//" + "sample_" + str(cnt) + ".g.vcf")
    cnt += num
    #.withColumn(sample_name, value_change(F.col(sample_name)))

In [32]:
# Drop every cached table/DataFrame and stop the Spark session.
spark.catalog.clearCache()
spark.stop()