In [110]:
import findspark
findspark.init()

# Spark & python function
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, StringType
from pyspark.sql.functions import * as F
from pyspark import Row
import re
import psycopg2

spark = SparkSession.builder.master("spark://master:7077")\
                        .appName("gVCF_combine")\
                        .config("spark.executor.memory", "18G")\
                        .config("spark.executor.core", "3")\
                        .config("spark.sql.shuffle.partitions", 20)\
                        .config("spark.driver.memory", "13G")\
                        .config("spark.driver.maxResultSize", "10G")\
                        .config("spark.driver.extraClassPath", "/root/postgresql-9.4.1207.jar")\
                        .config("spark.executor.extraClassPath", "/root/postgresql-9.4.1207.jar")\
                        .config("spark.jars", "/root/postgresql-9.4.1207.jar")\
                        .getOrCreate()

#.config("spark.daemon.memory", "14G")\
#.config("spark.driver.port", "9797")\
#.config("spark.blockManager.port", "9898")\
#.config("spark.driver.blockManager.port", "9898")\

# redis
#.config("spark.redis.host", "210.115.229.97")\
#.config("spark.redis.port", "6379")\
#.config("spark.jars", "/spark-redis/target/spark-redis-2.4.1-SNAPSHOT-jar-with-dependencies.jar")\

# postgresql
#.config("spark.driver.extraClassPath", "/postgresql-9.4.1207.jar")\
#.config("spark.executor.extraClassPath", "/postgresql-9.4.1207.jar")\
#.config("spark.jars", "/postgresql-9.4.1207.jar")\
#.config("spark.repl.local.jars", "/postgresql-9.4.1207.jar")\

In [109]:
spark.catalog.clearCache()
spark.stop()

In [None]:
!hdfs dfs -ls /vcf/gvcf

In [39]:
def preVCF(hdfs, flag, spark):
    vcf = spark.sparkContext.textFile(hdfs)

    header_contig = vcf.filter(lambda x : re.match("^#", x))
    col_name = vcf.filter(lambda x : x.startswith("#CHROM")).first().split("\t")
    vcf_data = vcf.filter(lambda x : re.match("[^#][^#]", x))\
                       .map(lambda x : x.split("\t"))\
                       .toDF(col_name)\
                       .withColumn("POS", col("POS").cast(IntegerType()))


    ## lookup table
    lookup = vcf_data.filter(substring(col("INFO"), 0, 3).isin("END"))\
                .select(col("#CHROM").alias("chr"), col("POS").alias("start"),
                        regexp_replace(col("INFO"), "END=", "").cast(IntegerType()).alias("end"), vcf_data[9])\
                .coalesce(10).cache()
    
    #.withColumn("INDEX", concat(col("CHR"), lit("_"), col("START")))\

    # flag에 따라 column 명 변경
    if flag == 1:
        for index in range(len(vcf_data.columns[:9])):
            compared_arr = ["#CHROM", "POS", "REF"]
            if vcf_data.columns[index] in compared_arr:
                continue
            vcf_data = vcf_data.withColumnRenamed(vcf_data.columns[index], vcf_data.columns[index] + "_temp") 
    
    return (vcf_data, lookup)

In [111]:
# main
left, lookup_left = preVCF("hdfs://master:9000/raw_data/gvcf/ND02798_eg.raw.vcf", 0, spark)
right, lookup_right = preVCF("hdfs://master:9000/raw_data/gvcf/ND02809_eg.raw.vcf", 1, spark)

#lookup_left.count()
#lookup_right.count()

# redis
# lookup_left.write.format("org.apache.spark.sql.redis").option("table", "left").option("key.column", "INDEX").save()
# lookup_right.write.format("org.apache.spark.sql.redis").option("table", "right").option("key.column", "INDEX").save()
# import redis
#conn = redis.StrictRedis(host='210.115.229.97', port=6379,db=0)
# conn.hgetall("left:chrM_4636")[b"CHR"]

#spark.sparkContext.parallelize(range(10)).map(lambda x : spark.sql("SELECT * FROM lookup_left WHERE END <= {0}".format(x)))

In [41]:
lookup_left.write.format("jdbc")\
           .option("url", "jdbc:postgresql://210.115.229.97:5432/testdb")\
           .option("dbtable", "lookup.left")\
           .option("user", "postgres")\
           .option("password", "sempre813!")\
           .option("batchsize", "100000")\
           .save()

In [87]:
# global variable, lookup_left, lookup_right
def selectCol(row):
    if(row.INFO == None):
        vcf_column = (row["#CHROM"], row["POS"], row.ID_temp, row["REF"],row.ALT_temp, row.QUAL_temp, row.FILTER_temp)
        chr_, pos = row["#CHROM"], row["POS"]
        
        ### 0 ---> CHR, 1 ---> START, 2 ---> END, 3: ---> SAMPLE
        value = getValue("lookup.left", chr_, pos)[0]
        end_right = int(row.INFO_temp.replace("END=", ""))
        
        if(value[2] > end_right): re_pos = end_right
        elif(value[2] < end_right): re_pos = value[2]
        else: re_pos =  end_right
        
        vcf_column = vcf_column + ("END=" + str(re_pos), row.FORMAT_temp) + value[3:] + row[16:]
        return vcf_column
    
    elif(row.INFO_temp == None):
        vcf_column = (row["#CHROM"], row["POS"], row.ID, row["REF"],row.ALT, row.QUAL, row.FILTER)
        chr_, pos = row["#CHROM"], row["POS"]
        
        value = getValue("lookup.right", chr_, pos)[0]
        end_right = int(row.INFO.replace("END=", ""))
        
        if(value[2] > end_right): re_pos = end_right
        elif(value[2] < end_right): re_pos = value[2]
        else: re_pos =  end_right
        
        vcf_column = vcf_column + ("END=" + str(re_pos), row.FORMAT) + (row[9],) + value[3:] 
        return vcf_column

In [113]:
join_vcf = left.join(right, ["#CHROM", "POS", "REF"], "full").orderBy(col("#CHROM"), col("POS")).cache()
join_vcf.count()

81337377

In [115]:
join_vcf.show(50)

+------+-----+---+----+---------+----+------+---------+------------------+--------------------+-------+---------+---------+-----------+---------+------------------+--------------------+
|#CHROM|  POS|REF|  ID|      ALT|QUAL|FILTER|     INFO|            FORMAT|             ND02798|ID_temp| ALT_temp|QUAL_temp|FILTER_temp|INFO_temp|       FORMAT_temp|             ND02809|
+------+-----+---+----+---------+----+------+---------+------------------+--------------------+-------+---------+---------+-----------+---------+------------------+--------------------+
|  chr1|    1|  N|   .|<NON_REF>|   .|     .|END=10081|GT:DP:GQ:MIN_DP:PL|     0/0:0:0:0:0,0,0|      .|<NON_REF>|        .|          .|END=10098|GT:DP:GQ:MIN_DP:PL|     0/0:0:0:0:0,0,0|
|  chr1|10082|  C|   .|<NON_REF>|   .|     .|END=10108|GT:DP:GQ:MIN_DP:PL|    0/0:2:5:2:0,6,49|   null|     null|     null|       null|     null|              null|                null|
|  chr1|10099|  A|null|     null|null|  null|     null|              n

In [132]:
from pyspark.sql.window import Window
from pyspark.sql.functions import last

lookup_window = Window.partitionBy("#CHROM").orderBy("POS").rangeBetween(Window.unboundedPreceding, 0)

In [136]:
join_vcf\
        .withColumn("ND02798", last("ND02798", ignorenulls  = True).over(lookup_window))\
        .withColumn("INFO", last("INFO", ignorenulls  = True).over(lookup_window))\
        .withColumn("ND02809", last("ND02798", ignorenulls  = True).over(lookup_window))\
        .withColumn("INFO_temp", last("INFO", ignorenulls  = True).over(lookup_window))\
        .show()

+------+-----+---+----+---------+----+------+---------+------------------+-----------------+-------+---------+---------+-----------+---------+------------------+-----------------+
|#CHROM|  POS|REF|  ID|      ALT|QUAL|FILTER|     INFO|            FORMAT|          ND02798|ID_temp| ALT_temp|QUAL_temp|FILTER_temp|INFO_temp|       FORMAT_temp|          ND02809|
+------+-----+---+----+---------+----+------+---------+------------------+-----------------+-------+---------+---------+-----------+---------+------------------+-----------------+
| chr12|    1|  N|   .|<NON_REF>|   .|     .|END=60131|GT:DP:GQ:MIN_DP:PL|  0/0:0:0:0:0,0,0|      .|<NON_REF>|        .|          .|END=60131|GT:DP:GQ:MIN_DP:PL|  0/0:0:0:0:0,0,0|
| chr12|60112|  A|null|     null|null|  null|END=60131|              null|  0/0:0:0:0:0,0,0|      .|<NON_REF>|        .|          .|END=60131|GT:DP:GQ:MIN_DP:PL|  0/0:0:0:0:0,0,0|
| chr12|60132|  G|   .|<NON_REF>|   .|     .|END=60137|GT:DP:GQ:MIN_DP:PL| 0/0:2:6:2:0,6,58|   null|

In [141]:
def getValue(chr_, pos):
    value = lookup_left.where((col("CHR") == chr_) & (col("START") <= pos) & (col("END") >= pos))\
                          .rdd.collect()[0]
    return value

In [89]:
 def getValue(table_name, chr_, pos):
    import psycopg2
    with psycopg2.connect(database = "testdb", user = "postgres", password="sempre813!", host = "210.115.229.97", port = "5432") as f:
        cur = f.cursor()
        cur.execute("""SELECT * FROM {0} WHERE chr='{1}' AND start < {2} AND "end" > {2}""".format(table_name,chr_, pos))
        rows = cur.fetchall()
    return rows