In [24]:
# Teardown: drop all cached tables/DataFrames and stop the Spark session.
# This cell ran last (In [24]) but sits at the top of the notebook; the guard
# keeps a fresh "Run All" from failing with NameError before `spark` exists.
if "spark" in globals():
    spark.catalog.clearCache()
    spark.stop()

In [9]:
import findspark
findspark.init()

# Spark & python function
import pandas
import pyarrow
from pyspark.sql.functions import pandas_udf, udf
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, StringType
import pyspark.sql.functions as F
from pyspark.sql.functions import when
from pyspark import Row
from pyspark.sql.window import Window
from pyspark import StorageLevel

import re
import subprocess

# Build (or reuse) the session against the standalone cluster master.
# NOTE(review): spark.network.timeout = 10000000 effectively disables network
# timeouts (a bare number is presumably read as seconds -- confirm intended).
spark = (
    SparkSession.builder
    .master("spark://master:7077")
    .appName("gVCF_combine")
    .config("spark.driver.memory", "8G")
    .config("spark.driver.maxResultSize", "8G")
    .config("spark.executor.memory", "24G")
    .config("spark.sql.execution.arrow.enabled", "true")
    .config("spark.network.timeout", 10000000)
    .getOrCreate()
)

# Directory Spark writes checkpoint() data to (no checkpoint() call is visible
# in this part of the notebook).
spark.sparkContext.setCheckpointDir("/usr/local/etc/SparkVCFtools/work_jupyter/checkpoints/")

# Settings tried earlier and currently disabled (NOT applied to the session):
#   .config("spark.sql.shuffle.partitions", 6)
#   .config("spark.redis.host", "210.115.229.97")
#   .config("spark.redis.port", "6379")
#   .config("spark.jars", "/spark-redis/target/spark-redis-2.4.1-SNAPSHOT-jar-with-dependencies.jar")

In [10]:
def hadoop_list(length, hdfs):
    """Return up to ``length`` paths listed by ``hdfs dfs -ls <hdfs>``.

    Parameters
    ----------
    length : int or None
        Maximum number of paths to return (used as a slice bound, so ``None``
        returns every path).
    hdfs : str
        HDFS directory (or glob) to list.

    Returns
    -------
    list of bytes
        The path column of each listing line, in listing order, still as raw
        bytes -- callers ``.decode("UTF-8")`` them. Lines with fewer than
        eight fields (e.g. the leading "Found N items" summary) are skipped.

    Raises
    ------
    RuntimeError
        If the ``hdfs`` command exits with a non-zero status.
    FileNotFoundError
        If the ``hdfs`` executable cannot be found.
    """
    # Argument list instead of a shell pipeline: the path cannot inject shell
    # syntax, awk is no longer needed, and failures are no longer silent.
    proc = subprocess.run(["hdfs", "dfs", "-ls", hdfs],
                          stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if proc.returncode != 0:
        raise RuntimeError("hdfs dfs -ls %s failed (exit %d): %s"
                           % (hdfs, proc.returncode,
                              proc.stderr.decode("utf-8", "replace").strip()))

    paths = []
    for line in proc.stdout.splitlines():
        # The path is the 8th column (what awk '{print $8}' selected); with
        # maxsplit=7 the last element keeps the whole path even if it has spaces.
        fields = line.split(None, 7)
        if len(fields) == 8:
            paths.append(fields[7].rstrip())
    return paths[:length]

def preVCF(hdfs, flag, spark):
    """Load the first nine columns of a tab-separated VCF/gVCF text file.

    Parameters
    ----------
    hdfs : str
        Path/URI handed to ``SparkContext.textFile``.
    flag : int
        When ``1``, every kept column other than ``#CHROM``, ``POS`` and
        ``REF`` is renamed with a ``_temp`` suffix so the frame can be joined
        onto an already loaded one without name clashes. Any other value
        leaves the column names as read.
    spark : SparkSession
        Session whose SparkContext reads the file.

    Returns
    -------
    DataFrame
        Columns named from the ``#CHROM`` header line (first nine only),
        with ``POS`` cast to integer.
    """
    lines = spark.sparkContext.textFile(hdfs)

    # The line starting with "#CHROM" supplies the column names.
    header_line = lines.filter(lambda line: line.startswith("#CHROM")).first()
    header_fields = header_line.split("\t")

    # Data lines: the first two characters must both differ from '#'.
    data_lines = lines.filter(lambda line: re.match("[^#][^#]", line))
    split_rows = data_lines.map(lambda line: line.split("\t"))
    frame = split_rows.toDF(header_fields)
    frame = frame.withColumn("POS", F.col("POS").cast(IntegerType()))
    frame = frame.select(frame.columns[:9])

    if flag != 1:
        return frame

    join_keys = ["#CHROM", "POS", "REF"]
    # Read the current name by position on each pass, exactly like the
    # original loop did after each rename.
    for position in range(len(frame.columns)):
        current = frame.columns[position]
        if current not in join_keys:
            frame = frame.withColumnRenamed(current, current + "_temp")
    return frame

def sampleVCF(hdfs, spark):
    """Load ``#CHROM``, ``POS`` and the last column of a VCF/gVCF text file.

    Parameters
    ----------
    hdfs : str
        Path/URI handed to ``SparkContext.textFile``.
    spark : SparkSession
        Session whose SparkContext reads the file.

    Returns
    -------
    DataFrame
        The first two columns plus the final column of the ``#CHROM`` header
        (presumably the single sample's genotype column -- assumes one sample
        per file; confirm), with ``POS`` cast to integer.
    """
    lines = spark.sparkContext.textFile(hdfs)
    header_fields = lines.filter(lambda line: line.startswith("#CHROM")).first().split("\t")

    # Data lines: the first two characters must both differ from '#'.
    rows = (lines
            .filter(lambda line: re.match("[^#][^#]", line))
            .map(lambda line: line.split("\t")))
    frame = rows.toDF(header_fields).withColumn("POS", F.col("POS").cast(IntegerType()))

    all_columns = frame.columns
    wanted = all_columns[:2] + all_columns[-1:]
    return frame.select(wanted)

def chunks(lst, n):
    """Yield consecutive slices of ``lst`` holding ``n`` items each.

    The final slice may be shorter. ``n == 0`` raises ``ValueError`` on the
    first ``next()`` (from ``range``); a negative ``n`` yields nothing.
    """
    yield from (lst[start:start + n] for start in range(0, len(lst), n))

In [11]:
# HDFS namenode URI and the directory holding the input gVCFs (first 10 files).
hdfs = "hdfs://master:9000"
gvcf_dir = "/raw_data/gvcf"
hdfs_list = hadoop_list(10, gvcf_dir)
if not hdfs_list:
    # Without this, the loop below never assigns join_vcf and later cells
    # fail with a confusing NameError.
    raise ValueError("no gVCF files found under " + hdfs + gvcf_dir)

# Per-chromosome window in POS order, used to derive END= for rows with no INFO.
info_window = Window.partitionBy("#CHROM").orderBy("POS")

gvcf_paths = [hdfs + path.decode("UTF-8") for path in hdfs_list]

# Start from the first file's fixed columns, then full-outer-join each further
# file on (#CHROM, POS, REF); its non-key columns arrive with a _temp suffix.
join_vcf = preVCF(gvcf_paths[0], 0, spark)
for path in gvcf_paths[1:]:
    join_vcf = (
        join_vcf.join(preVCF(path, 1, spark), ["#CHROM", "POS", "REF"], "full")
        # Keep the accumulated value; fall back to the newly joined file's value.
        .withColumn("ID", F.coalesce(F.col("ID"), F.col("ID_temp")))
        .withColumn("ALT", F.coalesce(F.col("ALT"), F.col("ALT_temp")))
        .withColumn("FORMAT", F.coalesce(F.col("FORMAT"), F.col("FORMAT_temp")))
        .withColumn("QUAL", F.lit("."))
        .withColumn("FILTER", F.lit("."))
        # Keep an INFO that does not start with "END" from either side; when
        # both start with "END" (or are missing) INFO becomes null here.
        .withColumn("INFO",
                    when(~F.col("INFO").startswith("END"), F.col("INFO"))
                    .when(~F.col("INFO_temp").startswith("END"), F.col("INFO_temp")))
        .drop("INFO_temp", "ID_temp", "ALT_temp", "FORMAT_temp", "QUAL_temp", "FILTER_temp")
    )

# Rows still lacking INFO get END=<next POS on the same chromosome> - 1. For the
# last row of a chromosome lead() is null, so INFO stays null there.
join_vcf = join_vcf.withColumn(
    "INFO",
    when(F.col("INFO").isNull(),
         F.concat(F.lit("END="), F.lead("POS", 1).over(info_window) - 1))
    .otherwise(F.col("INFO")),
)
join_vcf = join_vcf.orderBy(F.col("#CHROM"), F.col("POS")).cache()
chr_pos = join_vcf.select(F.col("#CHROM"), F.col("POS")).cache()

# Materialise both caches; the cell displays the chr_pos row count.
join_vcf.count()
chr_pos.count()

226934049

In [5]:
# Initial version
# sample: for every (#CHROM, POS) key of join_vcf, carry each sample's last
# non-null value forward along the chromosome.
sample_window = (Window.partitionBy(F.col("#CHROM"))
                 .orderBy(F.col("POS"))
                 .rangeBetween(Window.unboundedPreceding, Window.currentRow))
sample_list = list()
sample_join = list()

# join_vcf was full-joined on (#CHROM, POS, REF), so one (#CHROM, POS) can
# appear several times; duplicated keys would multiply rows in every later
# per-sample join. Keep each key once, and cache it because every sample reuses it.
chr_pos.unpersist()  # release the copy cached by the previous cell before replacing it
chr_pos = join_vcf.select(F.col("#CHROM"), F.col("POS")).distinct().cache()

for path in hdfs_list:
    sample_df = sampleVCF(hdfs + path.decode("UTF-8"), spark)
    sample_col = sample_df.columns[2]
    sample_list.append(sample_df)
    # No orderBy here: these frames are only joined afterwards, and a join
    # does not preserve row order, so a sort at this point is wasted work.
    sample_join.append(
        chr_pos.join(sample_df, ["#CHROM", "POS"], "full")
        .withColumn(sample_col, F.last(sample_col, ignorenulls=True).over(sample_window))
    )

In [6]:
# Finally join: attach every per-sample column to the merged fixed columns.
result = join_vcf
for per_sample in sample_join:
    result = result.join(per_sample, ["#CHROM", "POS"], "inner")

# Deduplicate first and sort last: dropDuplicates() shuffles the data and does
# not keep a preceding orderBy, so the original sort-then-dedupe left `result`
# unordered.
result = result.dropDuplicates().orderBy(F.col("#CHROM"), F.col("POS")).cache()
result.count()
join_vcf.unpersist()

KeyboardInterrupt: 

In [None]:
# Print the first 3000 merged rows (show() truncates long cell values by default).
result.show(n=3000)