In [43]:
import findspark
findspark.init()

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, col, desc, asc, coalesce, broadcast
from pyspark import Row
import re

# Create (or reuse) a session against the standalone master at spark://master:7077,
# registered under the application name "vcf_merge_1212".
# NOTE(review): "spark.worker.memory" does not look like an application property that
# Spark reads; executor memory is normally requested with "spark.executor.memory".
# Confirm this setting has any effect before relying on it.
spark = SparkSession.builder.master("spark://master:7077")\
                    .appName("vcf_merge_1212")\
                    .config("spark.worker.memory", "14g")\
                    .getOrCreate()

#  .config("driver.memory", "20g")\

In [None]:
#spark.range(10000).toDF("hi").show(200)

In [24]:
# Drop everything this session currently holds in its in-memory cache.
spark.catalog.clearCache()

In [38]:
# Shut the session down.
# NOTE(review): the cells further down call spark.sparkContext again, so on a
# top-to-bottom run a session has to be created again before they can work.
spark.stop()

# 실 데이터를 이용한 VCF merge

In [40]:
## filter def
def chr_remove(chrom):
    """Turn a VCF chromosome label into the string that is later cast to an integer key.

    Every occurrence of "chr" is removed from the label, then the remaining
    text is translated: "X" -> "23", "Y" -> "24", "XY" and "M" -> "-99".
    Any other text is returned as it is after the "chr" removal.

    Args:
        chrom: chromosome label as a string, or None.

    Returns:
        The translated label as a string, or None when `chrom` is None.
    """
    if chrom is None:
        # A NULL column value reaches a Python UDF as None; re.sub would raise TypeError on it.
        return None

    special_labels = {"X": "23", "Y": "24", "XY": "-99", "M": "-99"}
    stripped = re.sub("chr", "", chrom) # "chr" to ""
    return special_labels.get(stripped, stripped)

# Wrap chr_remove as a Spark UDF. No return type is passed to udf(), so the
# result column uses udf()'s default type; the loading cells cast it to Integer.
chr_remove_udf = udf(chr_remove)

def _qual_as_float(value):
    """Return `value` converted to float, or None when it is missing or not numeric."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def selectCol(row, lhs_len, rhs_len):
    """Collapse one row of the case/control outer join into a single merged record.

    `row` holds the `lhs_len` case columns followed by the `rhs_len` control
    columns. In each half the first nine positions are the fixed fields
    (CHROM, POS, ID, REF, ALT, QUAL, FILTER, INFO, FORMAT) and the remaining
    positions are per-sample genotype strings.

    A side counts as absent when its first sample column is None. Absent
    samples are padded with "0/0". INFO is rebuilt as "AC=..;AN=..;SF=..":
    a genotype containing "0/1:" adds 1 to AC, "1/1:" adds 2, "0/0:" adds 0,
    and AN is twice the number of such genotypes plus padded samples. SF is
    "0" (case only), "1" (control only) or "0,1" (both).

    QUAL is always emitted as a float (or None when it cannot be parsed) so
    the column has one type in every branch; when both sides are present it
    is the average of the two QUAL values weighted by sample count.

    Args:
        row: joined record supporting integer/slice indexing and the
            attributes QUAL, QUAL_temp and FORMAT_temp.
        lhs_len: number of case columns in `row`.
        rhs_len: number of control columns in `row`.

    Returns:
        A tuple: nine fixed fields, then case samples, then control samples.
    """
    # Presence is decided on the sample columns, not on CHROM / CHROM_temp:
    # the CHROM key can itself be None on a row that does exist.
    lhs_missing = row[9] is None
    rhs_missing = row[lhs_len + 9] is None

    if lhs_missing:
        genotypes = row[lhs_len + 9:]
    elif rhs_missing:
        genotypes = row[9:lhs_len]
    else:
        genotypes = row[9:lhs_len] + row[lhs_len + 9:]

    AC, AN = 0, 0
    for genotype in genotypes:
        if genotype is None:
            break
        if "0/1:" in genotype:
            AC += 1
            AN += 1
        elif "1/1:" in genotype:
            AC += 2
            AN += 1
        elif "0/0:" in genotype:
            AN += 1

    # Control side absent: keep the case record and pad the control samples.
    if rhs_missing:
        padding = ("0/0",) * (rhs_len - 9)
        AN = (AN + len(padding)) * 2
        info = ("AC="+str(AC)+";AN="+str(AN)+";SF=0",)
        return (row[:5] + (_qual_as_float(row[5]),) + (row[6],) + info + (row[8],)
                + row[9:lhs_len] + padding)

    # Case side absent: keep the control record and pad the case samples.
    if lhs_missing:
        padding = ("0/0",) * (lhs_len - 9)
        AN = (AN + len(padding)) * 2
        info = ("AC="+str(AC)+";AN="+str(AN)+";SF=1",)
        return (row[lhs_len:lhs_len + 5] + (_qual_as_float(row[lhs_len + 5]),)
                + (row[lhs_len + 6],) + info + (row.FORMAT_temp,)
                + padding + row[lhs_len + 9:])

    # Both sides present: QUAL is the sample-count-weighted average of the two values.
    col_total = lhs_len + rhs_len - 18
    lhs_QUAL = _qual_as_float(row.QUAL)
    rhs_QUAL = _qual_as_float(row.QUAL_temp)
    if lhs_QUAL is None or rhs_QUAL is None:
        QUAL = None
    else:
        QUAL = (lhs_QUAL * ((lhs_len - 9) / col_total)
                + rhs_QUAL * ((rhs_len - 9) / col_total))

    AN *= 2
    info = ("AC="+str(AC)+";AN="+str(AN)+";SF=0,1",)
    return row[:5]+(QUAL,)+(row[6],)+info+(row[8],)+row[9:lhs_len]+row[lhs_len + 9:]

In [41]:
# Load the case VCF from HDFS: every line is split on tabs and the first line
# supplies the column names.
case_lines = spark.sparkContext.textFile("hdfs://master:9000/vcf/case_vcf")\
                  .map(lambda x : x.split("\t"))
case_header = case_lines.first()
# The header is bound as a default argument so the filter keeps comparing against
# this file's header even after other names are reassigned.
case_raw = case_lines.filter(lambda row, first=case_header : row != first).toDF(case_header)
# CHROM becomes an integer key, "#CHROM" is dropped and only FILTER == "PASS" rows stay.
case = case_raw.select(chr_remove_udf(case_raw["#CHROM"]).cast("Integer").alias("CHROM"), "*")\
               .drop("#CHROM").filter(case_raw["FILTER"] == "PASS").coalesce(20)
case.cache()

# Load the control VCF the same way.
control_lines = spark.sparkContext.textFile("hdfs://master:9000/vcf/control_vcf")\
                     .map(lambda x : x.split("\t"))
control_header = control_lines.first()
control_raw = control_lines.filter(lambda row, first=control_header : row != first)\
                           .toDF(control_header)
control = control_raw.select(chr_remove_udf(control_raw["#CHROM"]).cast("Integer").alias("CHROM"), "*")\
                     .drop("#CHROM").filter(control_raw["FILTER"] == "PASS")
# Suffix every control column with "_temp" in one pass so the join can address
# the control side by name (CHROM_temp, POS_temp, ...).
control = control.toDF(*[name + "_temp" for name in control.columns])

control = control.coalesce(20)
control.cache()
#control.count() # cache action


### join expression: a row matches when chromosome, position and reference allele agree
joinEX = [
              case['CHROM'] == control['CHROM_temp'],
              case['POS'] == control['POS_temp'],
              case['REF'] == control['REF_temp']
         ]

join_result = case.join(control, joinEX, 'outer')


# Column counts of each side; selectCol uses them to find the control half of a row.
case_col = len(case.columns)
control_col = len(control.columns)

# Output schema: all case columns followed by the control sample columns only.
# The column list gets its own name so the imported col() function is not overwritten.
joined_columns = join_result.columns
header = joined_columns[:case_col] + joined_columns[case_col + 9:]

join_rdd = join_result.rdd.map(lambda row : selectCol(row, case_col, control_col))\
                      .coalesce(10).toDF(header)
join_rdd.cache()
join_rdd.count() # cache action

889929

In [42]:
# write HDFS
# The merged table is written as CSV to the path below; an earlier parquet
# target is kept here commented out.
#join_rdd.write.parquet("hdfs://210.115.229.91:9000/vcf/merge.parquet")
join_rdd.write.csv("hdfs://master:9000/vcf/merge_1216.csv")

# Release cached data, then stop the session.
# NOTE(review): cells below use `spark` again and need a live session to run.
spark.catalog.clearCache()
spark.stop()

### case preprocessing

In [None]:
# Load the case VCF from HDFS: every line is split on tabs and the first line
# supplies the column names.
case = spark.sparkContext.textFile("hdfs://master:9000/vcf/case_vcf")\
            .map(lambda x : x.split("\t"))
header = case.first()
step1 = case.filter(lambda row : row != header).toDF(header)
# CHROM becomes an integer key, "#CHROM" is dropped and only FILTER == "PASS" rows stay.
# Columns are addressed by name / through step1 so this cell does not depend on the
# name `col` still referring to the pyspark function.
case = step1.select(chr_remove_udf(step1["#CHROM"]).cast("Integer").alias("CHROM"), "*")\
             .drop("#CHROM").filter(step1["FILTER"] == "PASS").coalesce(20)
case.cache()

In [None]:
# Running an action computes `case`; this is what fills the cache requested by case.cache().
case.count()

## control preprocessing

In [44]:
# control vcf load
control = spark.sparkContext.textFile("hdfs://master:9000/vcf/control_vcf")\
            .map(lambda x : x.split("\t"))
# Use the first row as the header.
header = control.first()
step1 = control.filter(lambda row : row != header).toDF(header)
# CHROM becomes an integer key, "#CHROM" is dropped and only FILTER == "PASS" rows stay.
# Columns are addressed by name / through step1 so this cell does not depend on the
# name `col` still referring to the pyspark function.
control = step1.select(chr_remove_udf(step1["#CHROM"]).cast("Integer").alias("CHROM"), "*")\
             .drop("#CHROM").filter(step1["FILTER"] == "PASS")
### Rename every column to <name>_temp in a single pass.
control = control.toDF(*[name + "_temp" for name in control.columns])

control = control.coalesce(20)
control.cache()

DataFrame[CHROM_temp: int, POS_temp: string, ID_temp: string, REF_temp: string, ALT_temp: string, QUAL_temp: string, FILTER_temp: string, INFO_temp: string, FORMAT_temp: string, BLSA-1579_temp: string, BLSA-1775_temp: string, BLSA-1839_temp: string, BLSA-1883_temp: string, BLSA-1924_temp: string, BLSA-2037_temp: string, BLSA-2069_temp: string, BLSA-745_temp: string, BLSA-827_temp: string, JHU-705_temp: string, JHU-710_temp: string, JHU-719_temp: string, MIAMI-2112_temp: string, MIAMI-2843_temp: string, MIAMI-2852_temp: string, MIAMI-3216_temp: string, MIAMI-3231_temp: string, MIAMI-3296_temp: string, MIAMI-3410_temp: string, MIAMI-3643_temp: string, MIAMI-3651_temp: string, MIAMI-3747_temp: string, MIAMI-3772_temp: string, MIAMI-3797_temp: string, MIAMI-3799_temp: string, MIAMI-3860_temp: string, MIAMI-4022_temp: string, MIAMI-4042_temp: string, SH-00-34_temp: string, SH-00-38_temp: string, SH-00-49_temp: string, SH-01-14_temp: string, SH-01-31_temp: string, SH-01-37_temp: string, SH-0

In [None]:
# Running an action computes `control`; this is what fills the cache requested by control.cache().
control.count()

## outer join

In [19]:
%time
### join expresion
joinEX = [
              case['CHROM'] == control['CHROM_temp'],
              case['POS'] == control['POS_temp'],
              case['REF'] == control['REF_temp']
         ]
join_result = case.join(control, joinEX, 'outer')   


# case & control indexing
case_col = len(case.columns)
control_col = len(control.columns)

# schema
col = join_result.columns
header = col[:case_col] + col[case_col + 9:]
 
    
#join_rdd = join_result.rdd.map(lambda row : selectCol(row, case_col, control_col))

join_rdd = join_result.rdd.map(lambda row : selectCol(row, case_col, control_col)).toDF(header)

join_rdd.cache()
join_rdd.count() # cache action

CPU times: user 0 ns, sys: 5 µs, total: 5 µs
Wall time: 11.2 µs


889929

In [23]:
# Row count after keeping one row per (CHROM, POS) pair; a value below the full
# row count means some positions occur on more than one row.
join_rdd.dropDuplicates(['CHROM', 'POS']).count()

886402

In [22]:
# Show the column names and the types inferred for the merged table.
join_rdd.printSchema()

root
 |-- CHROM: long (nullable = true)
 |-- POS: string (nullable = true)
 |-- ID: string (nullable = true)
 |-- REF: string (nullable = true)
 |-- ALT: string (nullable = true)
 |-- QUAL: double (nullable = true)
 |-- FILTER: string (nullable = true)
 |-- INFO: string (nullable = true)
 |-- FORMAT: string (nullable = true)
 |-- ND00002: string (nullable = true)
 |-- ND00003: string (nullable = true)
 |-- ND00004: string (nullable = true)
 |-- ND00006: string (nullable = true)
 |-- ND00007: string (nullable = true)
 |-- ND00008: string (nullable = true)
 |-- ND00010: string (nullable = true)
 |-- ND00015: string (nullable = true)
 |-- ND00021: string (nullable = true)
 |-- ND00023: string (nullable = true)
 |-- ND00029: string (nullable = true)
 |-- ND00033: string (nullable = true)
 |-- ND00035: string (nullable = true)
 |-- ND00039: string (nullable = true)
 |-- ND00043: string (nullable = true)
 |-- ND00055: string (nullable = true)
 |-- ND00058: string (nullable = true)
 |-- ND000

In [20]:
# Preview the first 10 merged rows.
join_rdd.show(10)

+-----+---------+-----------+---+---+-----------------+------+-------------------+--------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+

In [16]:
def _qual_as_float(value):
    """Return `value` converted to float, or None when it is missing or not numeric."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def selectCol(row, lhs_len, rhs_len):
    """Collapse one row of the case/control outer join into a single merged record.

    `row` holds the `lhs_len` case columns followed by the `rhs_len` control
    columns. In each half the first nine positions are the fixed fields
    (CHROM, POS, ID, REF, ALT, QUAL, FILTER, INFO, FORMAT) and the remaining
    positions are per-sample genotype strings.

    A side counts as absent when its first sample column is None. Absent
    samples are padded with "0/0". INFO is rebuilt as "AC=..;AN=..;SF=..":
    a genotype containing "0/1:" adds 1 to AC, "1/1:" adds 2, "0/0:" adds 0,
    and AN is twice the number of such genotypes plus padded samples. SF is
    "0" (case only), "1" (control only) or "0,1" (both).

    QUAL is always emitted as a float (or None when it cannot be parsed) so
    the column has one type in every branch; when both sides are present it
    is the average of the two QUAL values weighted by sample count.

    Args:
        row: joined record supporting integer/slice indexing and the
            attributes QUAL, QUAL_temp and FORMAT_temp.
        lhs_len: number of case columns in `row`.
        rhs_len: number of control columns in `row`.

    Returns:
        A tuple: nine fixed fields, then case samples, then control samples.
    """
    # Presence is decided on the sample columns, not on CHROM / CHROM_temp:
    # the CHROM key can itself be None on a row that does exist.
    lhs_missing = row[9] is None
    rhs_missing = row[lhs_len + 9] is None

    if lhs_missing:
        genotypes = row[lhs_len + 9:]
    elif rhs_missing:
        genotypes = row[9:lhs_len]
    else:
        genotypes = row[9:lhs_len] + row[lhs_len + 9:]

    AC, AN = 0, 0
    for genotype in genotypes:
        if genotype is None:
            break
        if "0/1:" in genotype:
            AC += 1
            AN += 1
        elif "1/1:" in genotype:
            AC += 2
            AN += 1
        elif "0/0:" in genotype:
            AN += 1

    # Control side absent: keep the case record and pad the control samples.
    if rhs_missing:
        padding = ("0/0",) * (rhs_len - 9)
        AN = (AN + len(padding)) * 2
        info = ("AC="+str(AC)+";AN="+str(AN)+";SF=0",)
        return (row[:5] + (_qual_as_float(row[5]),) + (row[6],) + info + (row[8],)
                + row[9:lhs_len] + padding)

    # Case side absent: keep the control record and pad the case samples.
    if lhs_missing:
        padding = ("0/0",) * (lhs_len - 9)
        AN = (AN + len(padding)) * 2
        info = ("AC="+str(AC)+";AN="+str(AN)+";SF=1",)
        return (row[lhs_len:lhs_len + 5] + (_qual_as_float(row[lhs_len + 5]),)
                + (row[lhs_len + 6],) + info + (row.FORMAT_temp,)
                + padding + row[lhs_len + 9:])

    # Both sides present: QUAL is the sample-count-weighted average of the two values.
    col_total = lhs_len + rhs_len - 18
    lhs_QUAL = _qual_as_float(row.QUAL)
    rhs_QUAL = _qual_as_float(row.QUAL_temp)
    if lhs_QUAL is None or rhs_QUAL is None:
        QUAL = None
    else:
        QUAL = (lhs_QUAL * ((lhs_len - 9) / col_total)
                + rhs_QUAL * ((rhs_len - 9) / col_total))

    AN *= 2
    info = ("AC="+str(AC)+";AN="+str(AN)+";SF=0,1",)
    return row[:5]+(QUAL,)+(row[6],)+info+(row[8],)+row[9:lhs_len]+row[lhs_len + 9:]

# 93 vcf

In [None]:
## filter def
def chr_remove(chrom):
    """Turn a VCF chromosome label into the string that is later cast to an integer key.

    Every occurrence of "chr" is removed from the label, then the remaining
    text is translated: "X" -> "23", "Y" -> "24", "XY" -> "25", "M" -> "26".
    Any other text is returned as it is after the "chr" removal.

    Args:
        chrom: chromosome label as a string, or None.

    Returns:
        The translated label as a string, or None when `chrom` is None.
    """
    if chrom is None:
        # A NULL column value reaches a Python UDF as None; re.sub would raise TypeError on it.
        return None

    special_labels = {"X": "23", "Y": "24", "XY": "25", "M": "26"}
    stripped = re.sub("chr", "", chrom) # "chr" to ""
    return special_labels.get(stripped, stripped)
# Wrap chr_remove as a Spark UDF. No return type is passed to udf(), so the
# result column uses udf()'s default type; the loading cell casts it to Integer.
chr_remove_udf = udf(chr_remove)

# After the join, pick the non-null side of each row.
# When the left side is absent, every already-merged sample is filled with "0/0:0".
# When the right side is absent, the new sample is filled with "0/0:0".
def selectCol(row, count):
    """Collapse one row of the (running table, new individual) outer join.

    The left half of `row` is nine fixed fields followed by `count + 1`
    sample columns; the right half is nine fixed fields followed by one
    sample column, which is the last element of `row`.

    A side counts as absent when its sample column is None (the CHROM keys
    are not used for this because they can be None on rows that do exist).

    Args:
        row: joined record supporting integer and slice indexing.
        count: number of samples already merged on the left side, minus one.

    Returns:
        A tuple: nine fixed fields, the `count + 1` left samples, then the
        right sample. When both sides are present the INFO field is the left
        and right INFO joined with ";".
    """
    left_samples = row[9:9 + count + 1]
    right_sample = row[-1]

    # Right side absent: keep the left record and pad the new sample.
    if right_sample is None:
        return row[:9] + left_samples + ("0/0:0",)

    # Left side absent: take the right record's fixed fields and pad all left samples.
    if row[9] is None:
        padding = ("0/0:0",) * len(left_samples)
        return row[10 + count:-1] + padding + (right_sample,)

    # Both present: row[17 + count] is the right side's INFO field.
    merged_info = row[7] + ";" + row[17 + count]
    return row[:7] + (merged_info,) + (row[8],) + left_samples + (right_sample,)

In [None]:
%%time
# dataFrame load function
for indv in range(0,10):
    # vcf load
    temp = spark.sparkContext.textFile("hdfs://210.115.229.91:9000/1000g_vcf/" + str(indv + 1))\
             .map(lambda x : x.split("\t"))
    # 첫 번째 row를 header로 바꾸기
    header = temp.first()
    step1 = temp.filter(lambda row : row != header).toDF(header)
    step2 = step1.select(chr_remove_udf(step1["#CHROM"]).cast("Integer").alias("CHROM"), "*")\
                .drop(step1["#CHROM"])
    ind_vcf = step2.drop(step1.columns[-1]).filter(step2["FILTER"] == "PASS")
    if(ind_vcf.rdd.getNumPartitions() > 10):
        ind_vcf = ind_vcf.coalesce(20) # partition 조절
    
    if(indv != 0):
        ### colname 수정 ---> []_temp
        for index in range(len(ind_vcf.columns[:-1])):
            ind_vcf = ind_vcf.withColumnRenamed(ind_vcf.columns[index], 
                                                ind_vcf.columns[index] + "_temp") 
        ### join expresion
        joinEX = [join_rdd_df['CHROM'] == ind_vcf['CHROM_temp'], 
                  join_rdd_df['POS'] == ind_vcf['POS_temp'],
                  join_rdd_df['REF'] == ind_vcf['REF_temp']]
        join_result = join_rdd_df.join(broadcast(ind_vcf), joinEX, 'outer')
                
        col = join_result.columns
        header = col[:9] + col[9:9 + indv] + [col[-1]]
        
        join_rdd = join_result.rdd.map(lambda row : selectCol(row, indv - 1))
        join_rdd_df = join_rdd.toDF(header)

        # partition 수 조절
        if(join_rdd_df.rdd.getNumPartitions() > 10):
            join_rdd_df = join_rdd_df.coalesce(20)
        
    else:
        join_rdd_df = ind_vcf
        
    join_rdd_df.cache()
    print("VCF merge stage = ", indv + 1)   

In [None]:
    
"""
if(indv != 1):
        # rhs -> column rename str+"_temp"
        for index in range(len(step3.columns[:-1])):
            step3 = step3.withColumnRenamed(step3.columns[index], step3.columns[index] + "_temp") 
    vcf_list.append(step3)    
for index in range(len(vcf_list)):
    vcf_list[index].rdd.repartition(5).cache()
    
# vcf join function
if count == 0:
    joinEX = [vcf_list[count]['CHROM'] == vcf_list[count + 1]['CHROM_temp'], 
              vcf_list[count]['POS'] == vcf_list[count + 1]['POS_temp'],
              vcf_list[count]['REF'] == vcf_list[count + 1]['REF_temp']]
    join_result = vcf_list[count].join(vcf_list[count + 1], joinEX, 'outer')
else:
    joinEX = [join_rdd_df['CHROM'] == vcf_list[count + 1]['CHROM_temp'], 
          join_rdd_df['POS'] == vcf_list[count + 1]['POS_temp'],
          join_rdd_df['REF'] == vcf_list[count + 1]['REF_temp']]
    join_result = join_rdd_df.join(vcf_list[count + 1], joinEX, 'outer')
    
col = join_result.columns
header = col[:9] + col[9:9 + count + 1] + [col[-1]]
join_rdd = join_result.rdd.map(lambda row : selectCol(row, count))
join_rdd_df = join_rdd.toDF(header)

# partition 수 조절
if(join_rdd_df.rdd.getNumPartitions() > 10):
    join_rdd_df = join_rdd_df.coalesce(5)

join_rdd_df.cache()
count+=1
print(count)
"""   
"""
%%time
# dataFrame load function
# case vcf load
case = spark.sparkContext.textFile("hdfs://210.115.229.91:9000/1000g_vcf/" + str(indv + 1))\
            .map(lambda x : x.split("\t"))
# 첫 번째 row를 header로 바꾸기
header = temp.first()
step1 = temp.filter(lambda row : row != header).toDF(header)
step2 = step1.select(chr_remove_udf(step1["#CHROM"]).cast("Integer").alias("CHROM"), "*")\
             .drop(step1["#CHROM"])
case = step2.drop(step1.columns[-1]).filter(step2["FILTER"] == "PASS")
if(ind_vcf.rdd.getNumPartitions() > 10):
    case = case.coalesce(20) # partition 조절

# control vcf load
control = spark.sparkContext.textFile("hdfs://210.115.229.91:9000/1000g_vcf/" + str(indv + 1))\
            .map(lambda x : x.split("\t"))
# 첫 번째 row를 header로 바꾸기
header = temp.first()
step1 = temp.filter(lambda row : row != header).toDF(header)
step2 = step1.select(chr_remove_udf(step1["#CHROM"]).cast("Integer").alias("CHROM"), "*")\
             .drop(step1["#CHROM"])
control = step2.drop(step1.columns[-1]).filter(step2["FILTER"] == "PASS")
if(control.rdd.getNumPartitions() > 10):
    control = control.coalesce(20) # partition 조절    
    
### colname 수정 ---> []_temp
for index in range(len(control.columns[:-1])):
        control = control.withColumnRenamed(control.columns[index], 
                                            control.columns[index] + "_temp")  

### join expresion
joinEX = [case['CHROM'] == control['CHROM_temp'],
          case['POS'] == control['POS_temp'],
          case['REF'] == control['REF_temp']]
join_result = case.join(broadcast(control), joinEX, 'outer')                

# schema
col = join_result.columns
header = col[:9] + col[9:9 + indv] + [col[-1]]

# case & control indexing
case_col = len(case.columns[9:-1]
control_col = len(control.columns[9:-1])
                                            
join_rdd = join_result.rdd.map(lambda row : selectCol(row, indv - 1))
join_rdd_df = join_rdd.toDF(header)

        # partition 수 조절
if(join_rdd_df.rdd.getNumPartitions() > 10):
    join_rdd_df = join_rdd_df.coalesce(20)

        
join_rdd_df.cache()
"""

In [None]:
# Write the table built by the iterative merge loop to HDFS as parquet.
join_rdd_df.write.parquet("hdfs://210.115.229.91:9000/vcf_merge/merge_10_V2.parquet")