In [None]:
# Tear down a Spark session left over from an earlier run of this notebook.
# On a fresh kernel `spark` has not been created yet (the session is built in
# the next cell), so the cleanup is skipped instead of raising NameError.
if "spark" in globals():
    spark.catalog.clearCache()
    spark.stop()

In [1]:
import findspark
findspark.init()

# Spark function
from pyspark import Row, StorageLevel
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import pandas_udf, udf, explode, array, when
from pyspark.sql.types import IntegerType, StringType, ArrayType, BooleanType, StructType, StructField
from pyspark.sql.window import Window
import pyspark.sql.functions as F
from pyspark.sql.functions import pandas_udf, PandasUDFType

# Python function
import re
import subprocess
import numpy as np
import pandas as pd
import pyarrow
from functools import reduce 
import copy

# The Spark application name is asked for interactively.
appname = input("appname, folder name : ")

# Start the Spark session against the standalone master.
# NOTE: the executor core count key is "spark.executor.cores" (plural); the
# previous "spark.executor.core" is not a Spark setting and was ignored.
spark = (
    SparkSession.builder.master("spark://master:7077")
    .appName(appname)
    .config("spark.driver.memory", "8G")
    .config("spark.driver.maxResultSize", "8G")
    .config("spark.executor.memory", "24G")
    .config("spark.executor.cores", 3)
    .config("spark.sql.execution.arrow.enabled", "true")
    .config("spark.network.timeout", "9999s")
    .config("spark.files.fetchTimeout", "9999s")
    .config("spark.sql.shuffle.partitions", 80)
    .config("spark.eventLog.enabled", "true")
    .config("spark.sql.execution.arrow.maxRecordsPerBatch", 10000000)
    .getOrCreate()
)

appname, folder name : 


In [2]:
def hadoop_list(length, hdfs):
    """List entries of an HDFS directory through the `hdfs dfs -ls` CLI.

    The shell pipeline prints the 8th whitespace-separated field of every
    listing line; the captured stdout is split on whitespace.

    Args:
        length: maximum number of entries to return.
        hdfs: path handed to `hdfs dfs -ls` (concatenated into the shell command).

    Returns:
        A list of at most `length` items as *bytes* (stdout is not decoded).
        stderr is captured and discarded.
    """
    command = "hdfs dfs -ls "+ hdfs +" | awk '{print $8}'"
    listing = subprocess.Popen(command, shell=True,
                               stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    raw_stdout, _unused_stderr = listing.communicate()
    return raw_stdout.split()[:length]

def preVCF(hdfs, flag, spark):
    """Read a VCF text file into a Spark DataFrame.

    The line starting with "#CHROM" supplies the column names. Rows are the
    tab-split lines whose first two characters are both not '#'. POS is cast
    to integer and the QUAL and FILTER columns are dropped.

    Args:
        hdfs: path/URI passed to `sparkContext.textFile`.
        flag: 0 drops the last column; 1 keeps only "#CHROM", "POS" and the
            last column (presumably the per-sample column -- verify against
            the input files); any other value keeps all remaining columns.
        spark: active SparkSession.

    Returns:
        The resulting DataFrame.
    """
    vcf = spark.sparkContext.textFile(hdfs)
    col_name = vcf.filter(lambda x: x.startswith("#CHROM")).first().split("\t")
    vcf_data = (
        vcf.filter(lambda x: re.match("[^#][^#]", x))
        .map(lambda x: x.split("\t"))
        .toDF(col_name)
        .withColumn("POS", F.col("POS").cast(IntegerType()))
        .drop("QUAL", "FILTER")
    )
    if flag == 0:
        vcf_data = vcf_data.drop(vcf_data.columns[-1])
    elif flag == 1:
        # Fixed typo: this was `selecet`, which raised AttributeError.
        vcf_data = vcf_data.select(["#CHROM", "POS"] + [vcf_data.columns[-1]])
    return vcf_data

def chunks(lst, n):
    """Yield consecutive slices of `lst`, each holding at most `n` items."""
    offsets = range(0, len(lst), n)
    for start in offsets:
        stop = start + n
        yield lst[start:stop]
        
def addIndex(POS, size):
    """Map a position to a bucket index: int(POS / size + 1).

    A POS equal to 1 is returned unchanged.
    """
    return POS if POS == 1 else int(POS / size + 1)
# Spark UDF wrapper around addIndex; produces an integer column.
addIndex_udf = udf(f=addIndex, returnType=IntegerType())

# for indel
# True when the value has two or more characters.
word_len = udf(lambda allele: len(allele) >= 2, returnType=BooleanType())
# Every character of the value except the first, as an array of strings.
ref_melt = udf(lambda allele: list(allele)[1:], returnType=ArrayType(StringType()))

def ref_concat(temp):
    """Suffix each element with "_" and its 1-based position.

    Example: ["A", "C"] -> ["A_1", "C_2"].
    """
    return [base + "_" + str(offset) for offset, base in enumerate(temp, start=1)]
# Rebind the name to its Spark UDF form (returns an array of strings).
ref_concat = udf(f=ref_concat, returnType=ArrayType(StringType()))

def info_change(temp):
    """Return the first ';'-separated entry of `temp` that starts with "DP=".

    Args:
        temp: an INFO string such as "AC=1;DP=30;MQ=60", or None.

    Returns:
        The matching "DP=..." entry, or None when `temp` is None or holds no
        such entry. Previously those cases raised AttributeError/IndexError,
        which aborts the whole Spark job when this runs as a UDF; None becomes
        a null in the resulting column instead.
    """
    if temp is None:
        return None
    for entry in temp.split(";"):
        if entry.startswith('DP='):
            return entry
    return None
# Rebind the name to its Spark UDF form (returns a string).
info_change = udf(f=info_change, returnType=StringType())

def unionAll(*dfs):
    """Fold all given DataFrames together with DataFrame.unionByName, left to right."""
    return reduce(lambda merged, frame: DataFrame.unionByName(merged, frame), dfs)

# for sample value
# Replace the first three characters of the value with "./.".
value_change = udf(lambda sample_value: "./." + sample_value[3:], returnType=StringType())

# for POS index
def sampling_func(data, ran):
    """Take every `ran`-th element of `data`, starting at position 0.

    `data` must support len() and a `.take(indices)` method.
    """
    positions = range(0, len(data), ran)
    return data.take(positions)

def sample_join(left, right):
    """Full-outer-join the last column of two frames on (#CHROM, POS).

    Each side is reduced to "#CHROM", "POS" and its own last column before the
    join; the result is ordered by "#CHROM" and then "POS".
    """
    join_keys = ["#CHROM", "POS"]
    left_part = left.select(F.col("#CHROM"), F.col("POS"), left.columns[-1])
    right_part = right.select(F.col("#CHROM"), F.col("POS"), right.columns[-1])
    joined = left_part.join(right_part, join_keys, "full")
    return joined.orderBy(F.col("#CHROM"), F.col("POS"))

def union_all(dfs):
    """Fold an iterable of DataFrames together with DataFrame.unionAll, left to right."""
    return reduce(lambda merged, frame: DataFrame.unionAll(merged, frame), dfs)

def outer_union_all(dfs):
    """Union frames with differing columns, filling absent columns with nulls.

    The union of all column names is collected (and printed); every frame is
    then projected onto that column list, using a null literal aliased to the
    column name wherever the frame lacks the column, and the projections are
    combined with union_all.
    """
    column_pool = set()
    for frame in dfs:
        column_pool |= set(frame.columns)
    all_cols = list(column_pool)
    print(all_cols)

    def aligned(present):
        # Existing columns are selected by name; missing ones become nulls.
        return [
            name if name in present else F.lit(None).alias(name)
            for name in all_cols
        ]

    projected = [frame.select(aligned(frame.columns)) for frame in dfs]
    return union_all(projected)

# Plain Spark UDF: largest element of the collected values, as a string.
value_max = udf(lambda collected: max(collected), returnType=StringType())

def info_min(info):
    """Join the entries of `info` that do not start with "END=" using "%"."""
    # Adjust the info value: entries beginning with "END=" are left out.
    kept = [entry for entry in info if not entry.startswith("END=")]
    return "%".join(kept)
# Rebind the name to its Spark UDF form (returns a string).
info_min = udf(f=info_min, returnType=StringType())

def format_value(value):
    """Merge several ':'-separated FORMAT strings into one.

    When only one string is given it is merged with the default
    "GT:DP:GQ:MIN_DP:PL". Strings are folded pairwise: the longer key list
    (the right-hand one on a tie) is kept as the base and keys missing from it
    are appended in the order they occur in the other list.

    Args:
        value: iterable of FORMAT strings; it is no longer modified in place.

    Returns:
        The merged FORMAT string.

    Raises:
        TypeError: if `value` is empty (nothing to reduce).
    """
    # Work on a copy so the caller's collection is left untouched.
    formats = list(value)
    if len(formats) == 1:
        formats.append("GT:DP:GQ:MIN_DP:PL")

    def format_reduce(left, right):
        base, extra = left.split(":"), right.split(":")
        if len(base) <= len(extra):
            base, extra = extra, base
        for key in extra:
            if key not in base:
                base.append(key)
        return ":".join(base)

    return str(reduce(format_reduce, formats))
# Rebind the name to its Spark UDF form (returns a string).
format_value = udf(f=format_value, returnType=StringType())


In [3]:
# URI prefix prepended to every listed path before it is read.
hdfs = "hdfs://master:9000"
# First 15 entries listed under /raw_data/gvcf (returned as bytes).
hdfs_list = hadoop_list(length=15, hdfs="/raw_data/gvcf")

In [4]:
# Load every path in hdfs_list (flag 0: last column dropped), stack the frames
# and collapse them to one row per (#CHROM, POS).
# The stray first line `vcf[... for sample in hdfs_list]` was removed: it is a
# SyntaxError, so the cell could not run at all.
vcf = (
    union_all([preVCF(hdfs + sample.decode("UTF-8"), 0, spark) for sample in hdfs_list])
    .groupBy("#CHROM", "POS")
    .agg(
        value_max(F.collect_set("ID")).alias("ID"),
        value_max(F.collect_set("REF")).alias("REF"),
        value_max(F.collect_set("ALT")).alias("ALT"),
        info_min(F.collect_set("INFO")).alias("INFO"),
        format_value(F.collect_set("FORMAT")).alias("FORMAT"),
    )
    .cache()
)
vcf.count()

286988405

In [None]:
# Splits the "<base>_<offset>" strings produced by ref_concat.
split_col = F.split("REF_temp", '_')

# For rows whose REF has two or more characters: emit one row per character
# after the first, shift POS by that character's 1-based offset, and overwrite
# ID, ALT and INFO for the new rows.
info_indel = (
    vcf.filter(word_len(F.col("REF")))
    .withColumn("REF", ref_melt(F.col("REF")))
    .withColumn("REF", ref_concat(F.col("REF")))
    .withColumn("REF", explode(F.col("REF")))
    .withColumnRenamed("REF", "REF_temp")
    .withColumn("REF", split_col.getItem(0))
    .withColumn("POS_var", split_col.getItem(1))
    .drop("REF_temp")
    .withColumn("POS", (F.col("POS") + F.col("POS_var")).cast(IntegerType()))
    .drop("POS_var")
    .withColumn("ID", F.lit("."))
    .withColumn("ALT", F.lit("*,<NON_REF>"))
    .withColumn("INFO", info_change(F.col("INFO")))
    .orderBy("#CHROM", "POS")
    .cache()
)

In [None]:
# Preview of the indel expansion without the INFO rewrite, ordering or caching.
# Relies on `split_col` defined in the previous cell.
(
    vcf.filter(word_len(F.col("REF")))
    .withColumn("REF", ref_melt(F.col("REF")))
    .withColumn("REF", ref_concat(F.col("REF")))
    .withColumn("REF", explode(F.col("REF")))
    .withColumnRenamed("REF", "REF_temp")
    .withColumn("REF", split_col.getItem(0))
    .withColumn("POS_var", split_col.getItem(1))
    .withColumn("POS", (F.col("POS") + F.col("POS_var")).cast(IntegerType()))
    .drop("REF_temp", "POS_var")
    .withColumn("ID", F.lit("."))
    .withColumn("ALT", F.lit("*,<NON_REF>"))
    .show()
)

In [None]:
# Print the first rows of the cached indel expansion.
info_indel.show()

In [None]:
@pandas_udf("array<string>", PandasUDFType.GROUPED_AGG)
def value_collect(value):
    """Collect all values of one group into a single array.

    `value` is the pandas Series holding the group's values. A grouped
    aggregate UDF must hand back one value per group, so the Series is
    converted to a plain list (one array<string> cell) instead of being
    returned as-is.
    """
    return value.tolist()

@pandas_udf("String", PandasUDFType.GROUPED_AGG)
def value_max(value):
    """Grouped aggregate: the largest value of the group (builtin max)."""
    largest = max(value)
    return largest

@pandas_udf("array<string>", PandasUDFType.GROUPED_AGG)
def info_min(info):
    """Grouped aggregate: the group's values that do not start with "END="."""
    # Adjust the info value: entries beginning with "END=" are left out.
    return [entry for entry in info if not entry.startswith("END=")]

@pandas_udf("String", PandasUDFType.GROUPED_AGG)
def format_value(value):
    """Grouped aggregate: merge the group's ':'-separated FORMAT strings.

    When the group holds a single string it is merged with the default
    "GT:DP:GQ:MIN_DP:PL". Strings are folded pairwise: the longer key list
    (the right-hand one on a tie) is kept as the base and keys missing from it
    are appended in the order they occur in the other list.
    """
    # `value` arrives as a pandas Series. Series.append does not accept a bare
    # string and never worked in place (it was removed in pandas 2.0), so the
    # values are copied into a list before the default is added.
    formats = list(value)
    if len(formats) == 1:
        formats.append("GT:DP:GQ:MIN_DP:PL")

    def format_reduce(left, right):
        base, extra = left.split(":"), right.split(":")
        if len(base) <= len(extra):
            base, extra = extra, base
        for key in extra:
            if key not in base:
                base.append(key)
        return ":".join(base)

    return str(reduce(format_reduce, formats))

In [None]:
# Minimal grouped-aggregate pandas UDF demo: mean of `v` per `id`.
# `mean_udf` was not defined anywhere in the notebook, so it is defined here
# to keep the cell runnable on a fresh kernel.
@pandas_udf("double", PandasUDFType.GROUPED_AGG)
def mean_udf(v):
    """Return the mean of one group's values."""
    return v.mean()

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
    ("id", "v"))
df.groupby("id").agg(mean_udf(df['v'])).orderBy(F.col("mean_udf(v)")).show()

In [None]:
# URI prefix prepended to every listed path before it is read.
hdfs = "hdfs://master:9000"
# Smaller run: only the first 3 entries listed under /raw_data/gvcf (bytes).
hdfs_list = hadoop_list(length=3, hdfs="/raw_data/gvcf")

In [None]:
# Stack the per-file frames (flag 0: last column dropped) without aggregating,
# cache the result and force evaluation with count().
vcf = union_all(
    [
        preVCF(hdfs + sample.decode("UTF-8"), 0, spark)
        for sample in hdfs_list
    ]
).cache()
vcf.count()

In [None]:
# One row per (#CHROM, POS); each listed column is gathered with value_collect.
temp = (
    vcf.groupby("#CHROM", "POS")
    .agg(*[value_collect(vcf[name]) for name in ("ID", "REF", "ALT", "INFO", "FORMAT")])
    .cache()
)

In [None]:
# Force evaluation of the cached aggregation and report its row count.
temp.count()