# Remove everything held in the session's in-memory cache, then stop the
# session.
# NOTE(review): `spark` is not assigned above this point in the visible
# source; presumably it is left over from an earlier notebook run -- confirm.
spark.catalog.clearCache()

spark.stop()

In [1]:
import findspark
findspark.init()

# Spark function
from pyspark import Row, StorageLevel
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import pandas_udf, udf, explode, array, when
from pyspark.sql.types import IntegerType, StringType, ArrayType,BooleanType
from pyspark.sql.window import Window
import pyspark.sql.functions as F

# Python function
import re
import subprocess
import numpy as np
import pandas
import pyarrow
from functools import reduce 
import copy

# Prompt for the run name and for the number of gVCF files to process.
# The run name is used as the Spark application name and, through
# `folder_name`, as part of the output path.
appname = input("appname, folder name : ")
# str is immutable, so plain assignment is equivalent to a deep copy.
folder_name = appname
gvcf_count = int(input("gvcf count : "))

# Start the Spark session
spark = (
    SparkSession.builder.master("spark://master:7077")
    .appName(appname)
    .config("spark.driver.memory", "8G")
    .config("spark.driver.maxResultSize", "8G")
    .config("spark.executor.memory", "24G")
    # The property is "spark.executor.cores"; the former key
    # "spark.executor.core" is not a Spark property and was silently ignored.
    .config("spark.executor.cores", 3)
    .config("spark.sql.execution.arrow.enabled", "false")
    .config("spark.sql.execution.arrow.fallback.enabled", "false")
    .config("spark.network.timeout", "9999s")
    .config("spark.files.fetchTimeout", "9999s")
    .config("spark.sql.shuffle.partitions", 40)
    .config("spark.eventLog.enabled", "true")
    .getOrCreate()
)

appname, folder name : gvcf_indel_20
gvcf count : 20


In [2]:
def hadoop_list(length, hdfs):
    """Return the first ``length`` paths printed by ``hdfs dfs -ls``.

    The command is run without a shell, so ``hdfs`` can no longer inject
    extra shell commands. The eighth whitespace-separated field of each
    output line is taken as the path; lines with fewer than eight fields
    are skipped.

    Args:
        length: Maximum number of paths to return.
        hdfs: Path argument(s) for ``hdfs dfs -ls``. The string is split
            with shell-like quoting rules but is never evaluated by a shell.

    Returns:
        A list of at most ``length`` paths as ``bytes``.

    Raises:
        RuntimeError: If the command exits with a non-zero status.
        OSError: If the ``hdfs`` executable cannot be started.
    """
    import shlex  # local import keeps this block self-contained

    command = ["hdfs", "dfs", "-ls"] + shlex.split(hdfs)
    proc = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if proc.returncode != 0:
        # Surface the failure instead of silently returning an empty list.
        raise RuntimeError(
            "'{}' failed with exit code {}: {}".format(
                " ".join(command),
                proc.returncode,
                proc.stderr.decode("utf-8", errors="replace").strip(),
            )
        )

    all_dart_dirs = []
    for line in proc.stdout.splitlines():
        fields = line.split()
        if len(fields) >= 8:
            all_dart_dirs.append(fields[7])

    return all_dart_dirs[:length]

def preVCF(hdfs, flag, spark):
    """Read one VCF text file into a DataFrame.

    The line starting with "#CHROM" supplies the column names. Lines whose
    first two characters are both not '#' are split on tabs and become the
    rows, and POS is cast to an integer.

    When ``flag`` is 1, every column except the last one is renamed with a
    "_temp" suffix, apart from "#CHROM", "POS" and "REF".

    Args:
        hdfs: Location of the file, passed to ``sparkContext.textFile``.
        flag: 1 to apply the "_temp" renaming, anything else to skip it.
        spark: Session used to read the file.

    Returns:
        The resulting DataFrame.
    """
    lines = spark.sparkContext.textFile(hdfs)
    header = lines.filter(lambda x: x.startswith("#CHROM")).first().split("\t")
    records = lines.filter(lambda x: re.match("[^#][^#]", x)).map(lambda x: x.split("\t"))
    vcf_data = records.toDF(header).withColumn("POS", F.col("POS").cast(IntegerType()))

    if flag != 1:
        return vcf_data

    untouched = ["#CHROM", "POS", "REF"]
    # The last column is deliberately left out of the renaming.
    for position in range(len(vcf_data.columns) - 1):
        current = vcf_data.columns[position]
        if current not in untouched:
            vcf_data = vcf_data.withColumnRenamed(current, current + "_temp")

    return vcf_data

def chunks(lst, n):
    """Yield consecutive slices of ``lst`` holding at most ``n`` items each."""
    yield from (lst[start:start + n] for start in range(0, len(lst), n))
        
def addIndex(POS, size):
    """Return ``POS`` itself when it equals 1, else ``int(POS / size + 1)``."""
    return POS if POS == 1 else int(POS / size + 1)
# Wrap addIndex as a Spark UDF whose declared return type is IntegerType.
addIndex_udf = udf(addIndex, returnType=IntegerType())

# for indel
# True when the value has two or more characters.
word_len = udf(lambda col: len(col) >= 2, returnType=BooleanType())
# Every character of the value after the first one, as an array of strings.
ref_melt = udf(lambda ref: list(ref[1:]), ArrayType(StringType()))

def ref_concat(temp):
    """Suffix each element of ``temp`` with "_" and its 1-based position."""
    return [item + "_" + str(position) for position, item in enumerate(temp, 1)]
# Rebind the name to a Spark UDF (declared return type: array of strings).
# After this line `ref_concat` refers to the UDF, not the plain function.
ref_concat = udf(ref_concat, ArrayType(StringType()))

def info_change(temp):
    """Return the first ``DP=`` entry of a ';'-separated string.

    Args:
        temp: The ';'-separated text to search, or None.

    Returns:
        The first entry that starts with "DP=", or None when ``temp`` is
        None or contains no such entry. These two cases previously raised
        AttributeError and IndexError respectively.
    """
    if temp is None:
        return None
    return next(
        (entry for entry in temp.split(";") if entry.startswith("DP=")),
        None,
    )
# Rebind the name to a Spark UDF (declared return type: string).
# After this line `info_change` refers to the UDF, not the plain function.
info_change = udf(info_change, StringType())

def unionAll(*dfs):
    """Union the given DataFrames by column name, folding left to right."""

    def _merge(combined, following):
        return DataFrame.unionByName(combined, following)

    return reduce(_merge, dfs)

# for sample value
def _replace_first_three(value):
    """Return ``value`` with its first three characters replaced by "./."."""
    return "./." + value[3:]


value_change = udf(_replace_first_three, StringType())

# for POS index
def sampling_func(data, ran):
    """Return the elements of ``data`` at positions 0, ran, 2*ran, ...

    The positions are handed to ``data.take``, so ``data`` must support
    both ``len()`` and ``take()``.
    """
    positions = range(0, len(data), ran)
    return data.take(positions)

def sample_join(left, right):
    """Full-outer-join two frames on "#CHROM", "POS" and "REF".

    From each side only those three columns and that side's last column
    are kept before joining.
    """
    keys = ["#CHROM", "POS", "REF"]
    left_part = left.select(*[F.col(key) for key in keys], left.columns[-1])
    right_part = right.select(*[F.col(key) for key in keys], right.columns[-1])
    return left_part.join(right_part, ["#CHROM", "POS", "REF"], "full")

In [None]:
# Load up to `gvcf_count` files listed under /raw_data/gvcf and
# full-outer-join them, one after another, on (#CHROM, POS, REF).
hdfs = "hdfs://master:9000"
# hadoop_list returns the paths as bytes, hence the .decode() calls below.
hdfs_list = hadoop_list(gvcf_count, "/raw_data/gvcf")
vcf_list = list()
sample_list = list()

for index in range(len(hdfs_list)):
    if index == 0:
        # The first file keeps its original column names (flag 0) and seeds
        # the running join.
        vcf_list.append(preVCF(hdfs + hdfs_list[index].decode("UTF-8"), 0, spark).cache())
        join_vcf = vcf_list[index]
    else:
        # Later files are read with flag 1, so all of their columns except
        # the join keys and the last column carry a "_temp" suffix.
        vcf_list.append(preVCF(hdfs + hdfs_list[index].decode("UTF-8"), 1, spark).cache())
        # After the join: ID, ALT and FORMAT fall back to the "_temp" value
        # when null; QUAL and FILTER are overwritten with "."; INFO takes the
        # first of INFO / INFO_temp that does not start with "END" (there is
        # no otherwise(), so it is null when neither qualifies); finally the
        # "_temp" helper columns are dropped.
        join_vcf = join_vcf.join(vcf_list[index], ["#CHROM", "POS", "REF"], "full")\
            .withColumn("ID", when(F.col("ID").isNull(), F.col("ID_temp")).otherwise(F.col("ID")))\
            .withColumn("ALT",when(F.col("ALT").isNull(), F.col("ALT_temp")).otherwise(F.col("ALT")))\
            .withColumn("FORMAT", when(F.col("FORMAT").isNull(), F.col("FORMAT_temp")).otherwise(F.col("FORMAT")))\
            .withColumn("QUAL", F.lit(".")).withColumn("FILTER", F.lit("."))\
            .withColumn("INFO", when(F.col("INFO").startswith("END") == False, F.col("INFO"))\
                        .when(F.col("INFO_temp").startswith("END") == False, F.col("INFO_temp")))\
            .drop("INFO_temp", "ID_temp", "ALT_temp", "FORMAT_temp", "QUAL_temp", "FILTER_temp")
        
        if index == len(hdfs_list) - 1:
            # NOTE(review): `info` is not assigned anywhere in the visible
            # source before this line, so this raises NameError unless it is
            # defined in a cell not shown here -- confirm where it comes from.
            info = info.dropDuplicates()         

# Every column from position 9 onward is cached as its own single-column
# frame in `sample_list` and removed from `join_vcf`.
for sample in join_vcf.columns[9:]:
    sample_list.append(join_vcf.select(F.col(sample)).cache())
    join_vcf = join_vcf.drop(F.col(sample))



# indel union
# Build extra rows from the rows of `info` whose REF has two or more
# characters:
#   1. ref_melt keeps the characters after the first one; ref_concat tags
#      each with "_<1-based offset>"; explode yields one row per character.
#   2. The tagged value is split on "_" back into REF (the character) and
#      POS_var (the offset), and POS becomes POS + POS_var.
#   3. ID is set to ".", ALT to "*,<NON_REF>", and INFO is replaced by the
#      result of info_change.
# NOTE(review): `info` is not assigned anywhere in the visible source;
# confirm it is defined in a cell not shown here.
split_col = F.split("REF_temp", '_')
indel = info.filter(word_len(F.col("REF")))\
            .withColumn("REF", ref_melt(F.col("REF"))).withColumn("REF", ref_concat(F.col("REF")))\
            .withColumn("REF", explode(F.col("REF"))).withColumnRenamed("REF", "REF_temp")\
            .withColumn('REF', split_col.getItem(0)).withColumn('POS_var', split_col.getItem(1))\
            .drop(F.col("REF_temp")).withColumn("POS", (F.col("POS") + F.col("POS_var")).cast(IntegerType()))\
            .drop(F.col("POS_var"))\
            .withColumn('ID', F.lit("."))\
            .withColumn('ALT', F.lit("*,<NON_REF>"))\
            .withColumn("INFO", info_change(F.col("INFO"))).cache()
            
# window
# Rows are ordered by POS within each #CHROM.
info_window = Window.partitionBy("#CHROM").orderBy("POS")
# A null INFO becomes "END=" followed by (POS of the next row in the same
# #CHROM) - 1; non-null INFO values are kept. The frame is then sorted by
# #CHROM and POS.
join_vcf = join_vcf.withColumn("INFO", when(F.col("INFO").isNull(), F.concat(F.lit("END="), F.lead("POS", 1).over(info_window) - 1))\
                              .otherwise(F.col("INFO")))\
           .orderBy(F.col("#CHROM"), F.col("POS"))

# count() is an action; presumably it is called here to materialise the cache.
join_vcf = join_vcf.cache()
join_vcf.count()

# NOTE(review): the union is built from `info` and `indel`, not from the
# `join_vcf` frame prepared above, and `info` is not assigned in the visible
# source -- confirm this is intended.
join_union = unionAll(*[info, indel]).orderBy(F.col("#CHROM"), F.col("POS")).cache()
join_union.count()
join_vcf.unpersist()
print("done")

In [None]:
# For each frame in `vcf_join_list`: full-outer-join it onto the key and
# POS_INDEX columns of `info_index`, then, for every column from position 4
# onward, replace nulls with the last non-null value seen so far within the
# same #CHROM (ordered by POS) and pass the result through value_change.
vcf_list_pos_index = []
sample_w = (
    Window.partitionBy(F.col("#CHROM"))
    .orderBy(F.col("POS"))
    .rangeBetween(Window.unboundedPreceding, Window.currentRow)
)

for index, vcf_frame in enumerate(vcf_join_list):
    index_part = info_index.select(["#CHROM","POS","REF","POS_INDEX"])
    sample_part = vcf_frame.select(["#CHROM", "POS", "REF"] + vcf_frame.columns[3:])
    temp = index_part.join(sample_part, ["#CHROM", "POS", "REF"], "full")

    # sample window
    for sample_name in temp.columns[4:]:
        carried = F.last(sample_name, ignorenulls=True).over(sample_w)
        filled = when(F.col(sample_name).isNull(), carried).otherwise(F.col(sample_name))
        temp = temp.withColumn(sample_name, filled)
        temp = temp.withColumn(sample_name, value_change(F.col(sample_name)))

    vcf_list_pos_index.append(temp)

In [None]:
# Write every frame in `vcf_list_pos_index`, and then `info_index`, as
# parquet partitioned by #CHROM under the run's output folder. POS_INDEX is
# dropped before writing and existing output is overwritten.
output_base = "/raw_data/output/gvcf_output/" + folder_name

for write_parquet in vcf_list_pos_index:
    # The target name is built from the column names at position 4 onward
    # of the frame as it is before POS_INDEX is dropped.
    target = output_base + "//" + "_".join(write_parquet.columns[4:]) + ".g.vcf"
    sample_writer = write_parquet.drop(F.col("POS_INDEX")).write.partitionBy("#CHROM")
    sample_writer.mode('overwrite').parquet(target)

info_writer = info_index.drop(F.col("POS_INDEX")).write.partitionBy("#CHROM")
info_writer.mode('overwrite').parquet(output_base + "//info.g.vcf")