## Parallel Call Variants Stage of DeepVariant

### Start Spark with dependencies

In [10]:
import pyspark
from pyspark import SparkConf
from pyspark.sql import SparkSession

# Start (or reuse) a SparkSession on the "local[8]" master. The project
# assembly jar is supplied through "spark.jars" and the driver memory is
# set to 8G.
spark = (
    SparkSession.builder
    .appName("Spark Deep Variant")
    .config("spark.driver.memory", "8G")
    .config("spark.jars", "../target/scala-2.12/spark-deepvariant-assembly-0.0.2.jar")
    .master("local[8]")
    .getOrCreate()
)
# Leave the session as the last expression so the cell displays it.
spark

### Load model

In [11]:
from sparkdv.transformers import VariantCallerModel

In [12]:
# Load the variant-caller model stored at ../variant_caller_fast_onnx.
variant_caller = VariantCallerModel().load("../variant_caller_fast_onnx")

### Load data

In [14]:
from sparkdv.utils.schemas import example_schema
import os

# Location (glob) of the make_examples TFRecord files. The original
# machine-specific path is kept as the default; set SPARKDV_EXAMPLES_PATH to
# point the notebook at another location without editing this cell.
input_path = os.environ.get(
    "SPARKDV_EXAMPLES_PATH",
    "/home/jose/genomics/deepvariant/output/intermediate_results_dir/make_examples*",
)
# Read the files with the "tfrecord" data source, record type "Example",
# using the schema imported from sparkdv.utils.schemas.
input_ds = (
    spark.read.format("tfrecord")
    .schema(example_schema)
    .option("recordType", "Example")
    .load(input_path)
)

### Check the number of partitions
The number of partitions determines how many tasks can actually run in parallel.

In [None]:
# Partition count of the dataset's underlying RDD, shown as the cell output.
num_partitions = input_ds.rdd.getNumPartitions()
num_partitions

### Call the model on a few examples - store to disk

In [16]:
# Let's compute on a small portion: run the model on at most 1024 examples
# and keep only the "probabilities" column.
# mode("overwrite") keeps this cell re-runnable: with the default save mode
# the write raises once ./result_probs.parquet exists from a previous run.
(
    variant_caller.transform(input_ds.limit(1024))
    .select("probabilities")
    .write.mode("overwrite")
    .parquet("./result_probs.parquet")
)

### Sample one datapoint

In [14]:
# Apply the model to at most 1024 examples and keep only the
# "probabilities" column.
output = (
    variant_caller
    .transform(input_ds.limit(1024))
    .select("probabilities")
)
# Fetch only the first row through a local iterator and display it.
next(output.toLocalIterator())

Row(probabilities=[0.01451485138386488, 0.9837851524353027, 0.0017000219086185098])

In [15]:
# Total number of rows in the loaded dataset, shown as the cell output.
example_count = input_ds.count()
example_count

229635