In [None]:
# Release cached data and stop the session left over from an earlier run of
# this notebook. On a fresh kernel `spark` has not been created yet (it is
# built in a later cell), so there is nothing to tear down and the cell is a
# no-op instead of raising NameError.
if "spark" in globals():
    spark.catalog.clearCache()
    spark.stop()

In [1]:
import findspark
findspark.init()

# Spark & python function
import pandas
import pyarrow
from pyspark.sql.functions import pandas_udf, udf
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, StringType
import pyspark.sql.functions as F
from pyspark.sql.functions import when
from pyspark import Row
from pyspark.sql.window import Window
from pyspark import StorageLevel

import re
import subprocess

# Build (or reuse) the session against the standalone master.
spark = (
    SparkSession.builder.master("spark://master:7077")
    .appName("gVCF_combine")
    .config("spark.executor.memory", "24G")
    # The property name is "spark.executor.cores"; the earlier spelling
    # "spark.executor.core" is not a Spark property, so the value was ignored.
    .config("spark.executor.cores", "3")
    .config("spark.sql.shuffle.partitions", 20)
    .config("spark.driver.memory", "8G")
    .config("spark.driver.maxResultSize", "8G")
    # NOTE(review): newer Spark releases name this key
    # "spark.sql.execution.arrow.pyspark.enabled" - confirm against the
    # cluster's Spark version before changing it.
    .config("spark.sql.execution.arrow.enabled", "true")
    .getOrCreate()
)

In [2]:
def preVCF(hdfs, flag, spark):
    """Load one VCF/gVCF text file into a Spark DataFrame.

    Every data line is split on tabs and the resulting fields are named after
    the fields of the ``#CHROM`` header line. ``POS`` is cast to integer.

    Args:
        hdfs: Path/URI handed to ``spark.sparkContext.textFile``.
        flag: When ``1``, each of the first nine columns other than
            ``#CHROM``, ``POS`` and ``REF`` is renamed with a ``_temp``
            suffix. Any other value leaves the column names untouched.
        spark: Session whose ``sparkContext`` reads the file.

    Returns:
        The DataFrame built from the data lines.
    """
    vcf = spark.sparkContext.textFile(hdfs)
    col_name = vcf.filter(lambda x: x.startswith("#CHROM")).first().split("\t")

    # Keep only lines whose first two characters are both not '#'. This skips
    # the "##" meta lines and the "#CHROM" header, and also any line shorter
    # than two characters.
    vcf_data = vcf.filter(lambda x: re.match("[^#][^#]", x))\
                  .map(lambda x: x.split("\t"))\
                  .toDF(col_name)\
                  .withColumn("POS", F.col("POS").cast(IntegerType()))

    if flag == 1:
        # These three columns keep their names; the rest of the first nine
        # columns get the "_temp" suffix.
        key_columns = ("#CHROM", "POS", "REF")
        for name in vcf_data.columns[:9]:
            if name not in key_columns:
                vcf_data = vcf_data.withColumnRenamed(name, name + "_temp")

    return vcf_data

def hadoop_list(length, hdfs):
    """Return the first ``length`` paths reported by ``hdfs dfs -ls``.

    Args:
        length: Upper slice bound applied to the listing.
        hdfs: Path passed to ``hdfs dfs -ls``. It is handed over as a single
            argument without going through a shell, so no local shell
            expansion or word splitting is applied to it.

    Returns:
        A list of ``bytes`` paths in listing order (the 8th
        whitespace-separated field of each listing line).

    Raises:
        RuntimeError: If the command exits with a non-zero status and listed
            nothing; the message carries the command's stderr.
    """
    proc = subprocess.run(
        ["hdfs", "dfs", "-ls", hdfs],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )

    all_dart_dirs = []
    for line in proc.stdout.splitlines():
        fields = line.split()
        # Lines with fewer than eight fields (e.g. the "Found N items"
        # summary) carry no path.
        if len(fields) >= 8:
            all_dart_dirs.append(fields[7])

    # Surface a failed listing here instead of returning an empty list.
    if proc.returncode != 0 and not all_dart_dirs:
        raise RuntimeError(
            "hdfs dfs -ls " + str(hdfs) + " failed: "
            + proc.stderr.decode("utf-8", "replace").strip()
        )

    return all_dart_dirs[:length]

def selectNotNull(left, right):
    """Return ``left`` unless it is ``None``, in which case return ``right``.

    The check is an identity test (``is None``), so a value whose ``__eq__``
    claims equality with ``None`` is still treated as present.
    """
    if left is None:
        return right
    return left
# Spark UDF wrapper around selectNotNull, declared to return a string.
selectNotNull_u = udf(selectNotNull, StringType())
def qual_filter(none = None):
    """Return the string "." regardless of the (ignored) argument."""
    placeholder = "."
    return placeholder
# Spark UDF wrapper around qual_filter, declared to return a string.
qual_filter_u = udf(qual_filter, StringType())

def chunks(lst, n):
    """Lazily yield consecutive slices of ``lst`` holding ``n`` items each.

    The last slice is shorter when ``len(lst)`` is not a multiple of ``n``.
    """
    yield from (lst[start:start + n] for start in range(0, len(lst), n))

In [4]:
# main
# URI prefix prepended to every listed path, and the first five entries
# listed under /raw_data/gvcf.
hdfs = "hdfs://master:9000"
hdfs_list = hadoop_list(5, "/raw_data/gvcf")

# Rows of one chromosome ordered by position.
w = Window.partitionBy("#CHROM").orderBy("POS")
# Same partitioning and ordering, with the frame running from the start of
# the partition up to the current row.
sample_w = w.rangeBetween(Window.unboundedPreceding, Window.currentRow)

# all files join
# Stop here with a clear message when nothing was listed; otherwise
# `join_vcf` would never be assigned and the next statement would fail with
# an unrelated NameError.
if not hdfs_list:
    raise ValueError("hdfs_list is empty: no gVCF files to join")

# The first file keeps its column names; every further file is loaded with
# "_temp"-suffixed columns and full-outer joined on (#CHROM, POS, REF).
join_vcf = preVCF(hdfs + hdfs_list[0].decode("UTF-8"), 0, spark)
for file_path in hdfs_list[1:]:
    other_vcf = preVCF(hdfs + file_path.decode("UTF-8"), 1, spark)
    join_vcf = (
        join_vcf.join(other_vcf, ["#CHROM", "POS", "REF"], "full")
        # Prefer the value already present; fall back to the new file's value.
        .withColumn("ID", F.coalesce(F.col("ID"), F.col("ID_temp")))
        .withColumn("ALT", F.coalesce(F.col("ALT"), F.col("ALT_temp")))
        .withColumn("FORMAT", F.coalesce(F.col("FORMAT"), F.col("FORMAT_temp")))
        .withColumn("QUAL", F.lit("."))
        .withColumn("FILTER", F.lit("."))
        # Keep the first INFO value that does not start with "END"; when
        # neither side qualifies the column becomes null and is filled below.
        .withColumn(
            "INFO",
            when(~F.col("INFO").startswith("END"), F.col("INFO"))
            .when(~F.col("INFO_temp").startswith("END"), F.col("INFO_temp")),
        )
        .drop("INFO_temp", "ID_temp", "ALT_temp", "FORMAT_temp", "QUAL_temp", "FILTER_temp")
    )

# Fill each null INFO with "END=<next POS in the same chromosome - 1>".
join_vcf = join_vcf.withColumn(
    "INFO",
    F.coalesce(
        F.col("INFO"),
        F.concat(F.lit("END="), F.lead("POS", 1).over(w) - 1),
    ),
)

# per sample value update(block) using SQL window
# Columns from index 9 onward are processed three at a time. Each of them is
# replaced by its last non-null value over `sample_w`.
sample_list = []
for chunk_no, sample_name in enumerate(chunks(join_vcf.columns[9:], 3)):
    # The first frame carries the nine leading columns; later frames carry
    # only the two join keys next to their sample columns.
    leading = join_vcf.columns[:9] if chunk_no == 0 else ["#CHROM", "POS"]
    frame = join_vcf.select(leading + list(sample_name))
    for column in sample_name:
        frame = frame.withColumn(column, F.last(column, ignorenulls=True).over(sample_w))
    sample_list.append(frame)
    

# finally join
# Start from the first frame and inner-join every remaining frame onto it.
# The first frame must not be joined with itself: that duplicates all of its
# non-key columns in the result.
if not sample_list:
    raise ValueError("sample_list is empty: no sample columns to combine")

result = sample_list[0]
for sample_frame in sample_list[1:]:
    # NOTE(review): frames are matched on (#CHROM, POS) only - confirm this
    # pair is unique per row, since the upstream join also keys on REF.
    result = result.join(sample_frame, ["#CHROM", "POS"], "inner")

In [None]:
# show() prints the rows and returns None, so its return value must not be
# assigned back to `result`; that would discard the DataFrame and break any
# re-run of this cell.
result.orderBy(F.col("#CHROM"), F.col("POS")).show(300)

In [None]:
# NOTE(review): bare attribute reference - this only evaluates to the bound
# method and never calls it; looks like an unfinished scratch cell. Confirm
# whether it should be completed or removed.
join_vcf.orderBy