## Init Spark using Standalone

In [1]:
import findspark
findspark.init()

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, col, desc, asc, coalesce, broadcast
from pyspark import Row
import re

# Connect to the standalone master and create (or reuse) the session.
# The per-executor core count is configured with "spark.executor.cores"
# (plural); the singular spelling is not a Spark property and is ignored.
spark = SparkSession.builder.master("spark://master:7077")\
                    .appName("vcf_merge_1212")\
                    .config("spark.executor.memory", "14G")\
                    .config("spark.executor.cores", "2")\
                    .config("spark.sql.shuffle.partitions", 20)\
                    .getOrCreate()

#  .config("spark.driver.memory", "20g")\

In [6]:
# Drop every cached table/DataFrame held by this session's in-memory cache.
spark.catalog.clearCache()

In [7]:
# Shut down the Spark session created above.
spark.stop()

## VCF merge Function

In [2]:
def chr_remove(chrom):
    """Normalize a chromosome label to the string code used by the merge.

    Every occurrence of the substring "chr" is deleted from ``chrom``. The
    remainder is then translated: "X" -> "23", "Y" -> "24", and "XY" or
    "M" -> "-99". Any other value is returned as-is.
    """
    special_codes = {"X": "23", "Y": "24", "XY": "-99", "M": "-99"}
    stripped = re.sub("chr", "", chrom)
    return special_codes.get(stripped, stripped)
# Spark SQL wrapper around chr_remove. No return type is passed to udf(), so
# its default applies; preVCF casts the result to Integer afterwards.
chr_remove_udf = udf(chr_remove)

def alt_filter(row):
    """Put a multi-allelic ALT field (index 4) into sorted order.

    ``row`` is a list of field strings. When the ALT field contains a comma,
    its comma-separated alleles are sorted and a new list with the rewritten
    field is returned; otherwise the incoming ``row`` is returned untouched.
    """
    alt = row[4]
    if "," not in alt:
        return row

    sorted_alt = ",".join(sorted(alt.split(",")))
    return row[:4] + [sorted_alt] + row[5:]

def preVCF(hdfs, flag):
    """Load a tab-separated VCF from HDFS into a DataFrame of PASS records.

    Args:
        hdfs: Path/URI handed to ``sparkContext.textFile`` (e.g. ``hdfs://...``).
        flag: 0 for the left-hand side of the merge, 1 for the right-hand
            side. With 1, the first nine columns get a ``_temp`` suffix so
            they do not collide with the left-hand columns after the join.

    Returns:
        A DataFrame whose first column is ``CHROM`` (``#CHROM`` passed
        through ``chr_remove_udf`` and cast to Integer), restricted to rows
        with ``FILTER == "PASS"`` and coalesced to 20 partitions.
    """
    lines = spark.sparkContext.textFile(hdfs)

    # VCF meta-information lines start with "##". Skip them so that the first
    # remaining line is the "#CHROM ..." column header; files without such
    # lines are unaffected.
    vcf = lines.filter(lambda line: not line.startswith("##"))\
               .map(lambda x: x.split("\t"))
    header = vcf.first()
    step1 = vcf.filter(lambda row: row != header).map(alt_filter).toDF(header)

    # Columns are referenced through the DataFrame (not the imported `col`
    # helper) so this function keeps working even if a notebook cell rebinds
    # that name.
    with_chrom = step1.select(chr_remove_udf(step1["#CHROM"]).cast("Integer").alias("CHROM"), "*")\
                      .drop("#CHROM")
    return_vcf = with_chrom.filter(with_chrom["FILTER"] == "PASS")

    if flag == 1:
        # Suffix the nine leading (non-sample) columns: CHROM ... FORMAT.
        for name in return_vcf.columns[:9]:
            return_vcf = return_vcf.withColumnRenamed(name, name + "_temp")
    return return_vcf.coalesce(20)

def rowTodict(format_, row):
    """Turn colon-separated sample strings into dicts keyed by FORMAT name.

    Each string in ``row`` is split on ":" and its i-th piece is stored under
    ``format_[i]``. Returns a list with one dict per input string, in order.
    A string with more pieces than ``format_`` has entries raises IndexError.
    """
    return [
        {format_[position]: value
         for position, value in enumerate(sample.split(":"))}
        for sample in row
    ]

def dictToFormat(col_value, d_format):
    """Serialize sample dicts back to colon-separated strings.

    For every dict in ``col_value`` the values are emitted in the key order
    given by ``d_format``; a key the dict lacks is written as ".". Returns a
    tuple with one string per input dict.
    """
    return tuple(
        ":".join(sample[key] if key in sample else "." for key in d_format)
        for sample in col_value
    )

def selectCol(row, lhs_len, rhs_len):
    """Build one merged VCF record from a full-outer-joined case/control row.

    Args:
        row: Joined row made of ``lhs_len`` left-hand columns followed by
            ``rhs_len`` right-hand columns. On each side the first nine
            columns are the fixed fields (CHROM ... FORMAT) and the rest are
            per-sample columns. Right-hand fixed fields are read through
            their ``*_temp`` names.
        lhs_len: Number of left-hand columns in ``row``.
        rhs_len: Number of right-hand columns in ``row``.

    Returns:
        A tuple: CHROM, POS, ID, REF, ALT, QUAL (float), FILTER, a rebuilt
        INFO string ``AC=..;AN=..;SF=..``, FORMAT, then the left-hand samples
        followed by the right-hand samples. Samples of a side that is absent
        from the join are filled with "0/0".
    """
    lhs_samples = row[9:lhs_len]
    rhs_samples = row[lhs_len + 9:]

    # Decide which side is absent from POS, which comes straight from the
    # input text. CHROM / CHROM_temp are not reliable for this: preVCF casts
    # them to Integer, so they can be null on a row that does exist, and a
    # right-only row with a null CHROM_temp would otherwise be handled as
    # "rhs is null" and fail on float(None).
    lhs_missing = row[1] is None
    rhs_missing = row[lhs_len + 1] is None

    if lhs_missing:
        observed = rhs_samples
    elif rhs_missing:
        observed = lhs_samples
    else:
        observed = lhs_samples + rhs_samples

    # AC counts alternate alleles; AN counts genotyped samples here and is
    # doubled into an allele count just before INFO is written.
    AC, AN = 0, 0
    for genotype in observed:
        if genotype is None:
            break
        if "0/1:" in genotype:
            AC += 1
            AN += 1
        elif "1/1:" in genotype:
            AC += 2
            AN += 1
        elif "0/0:" in genotype:
            AN += 1

    # Right-hand side is null: keep the left-hand record, pad the right.
    if rhs_missing:
        padding = ("0/0",) * (rhs_len - 9)
        AN = (AN + len(padding)) * 2
        info = ("AC=" + str(AC) + ";AN=" + str(AN) + ";SF=0",)
        return (row[:5] + (float(row.QUAL),) + (row.FILTER,) + info
                + (row[8],) + lhs_samples + padding)

    # Left-hand side is null: keep the right-hand record, pad the left.
    if lhs_missing:
        padding = ("0/0",) * (lhs_len - 9)
        AN = (AN + len(padding)) * 2
        info = ("AC=" + str(AC) + ";AN=" + str(AN) + ";SF=1",)
        return (row[lhs_len:lhs_len + 5] + (float(row.QUAL_temp),)
                + (row.FILTER_temp,) + info + (row.FORMAT_temp,)
                + padding + rhs_samples)

    # Both sides are present.
    lhs_format = row[8].split(":")
    rhs_format = row[lhs_len + 8].split(":")

    # Union of both FORMAT key lists, first occurrence wins the position.
    merged_format = []
    for key in lhs_format + rhs_format:
        if key not in merged_format:
            merged_format.append(key)

    # Re-serialize every sample against the merged FORMAT ("." for gaps).
    result_lhs = dictToFormat(rowTodict(lhs_format, lhs_samples), merged_format)
    result_rhs = dictToFormat(rowTodict(rhs_format, rhs_samples), merged_format)

    # QUAL is the average of both sides weighted by their sample counts.
    col_total = lhs_len + rhs_len - 18
    lhs_QUAL = float(row.QUAL) * ((lhs_len - 9) / col_total)
    rhs_QUAL = float(row.QUAL_temp) * ((rhs_len - 9) / col_total)
    QUAL = lhs_QUAL + rhs_QUAL

    AN *= 2
    info = ("AC=" + str(AC) + ";AN=" + str(AN) + ";SF=0,1",)

    return (row[:5] + (QUAL,) + (row[6],) + info
            + (":".join(merged_format),) + result_lhs + result_rhs)

## Run VCF merge

In [3]:
# Load the case VCF from HDFS; count() forces the cached data to materialize.
case = preVCF("hdfs://master:9000/vcf/case_merge_vcf", 0).cache()
case.count()

# Load the control VCF (flag 1 renames its first nine columns to *_temp).
control = preVCF("hdfs://master:9000/vcf/control_vcf", 1).cache()
control.count()

# Column counts let selectCol locate each side's fields inside a joined row.
case_col = len(case.columns)
control_col = len(control.columns)

# Merged schema: every case column plus the control sample columns (the
# control's first nine *_temp columns are left out).
# The column list is intentionally not stored in a variable named `col`:
# that would shadow pyspark.sql.functions.col, which preVCF relies on, and
# break any later preVCF call or a re-run of this cell.
header = case.columns + control.columns[9:]

# Join expression: records match on chromosome, position and reference allele.
joinEX = [
              case['CHROM'] == control['CHROM_temp'],
              case['POS'] == control['POS_temp'],
              case['REF'] == control['REF_temp']
         ]

join_result = case.join(control, joinEX, 'full').cache()
join_result.count()

# Unpersist the inputs now that the joined result is cached.
case.unpersist()
control.unpersist()

# Merge each joined row and write the result as tab-delimited text.
join_result.rdd.map(lambda row : selectCol(row, case_col, control_col))\
               .toDF(header).dropDuplicates(['CHROM', 'POS'])\
               .write.option("delimiter", "\t").csv("hdfs://master:9000/vcf/merge_1216_re6.txt")

In [5]:
# Number of rows in the full-outer-joined (case + control) DataFrame.
join_result.count()

886402