In [1]:
import os
from pyspark.sql import SparkSession
import mlstream_spark_udfs as udfs

# Create a spark session and add the jar file with mlstream-spark-udfs
spark = SparkSession.builder \
    .appName('demo approx_topk') \
    .master("local[*]") \
    .config("spark.driver.memory","4G") \
    .config("spark.executor.memory","4G")\
    .config("spark.jars", os.environ["MLSTREAM_UDFS_JAR_PATH"]) \
    .getOrCreate()

# Note: Spark UI should be available at http://localhost:4040
udfs.registerUDFs(spark)

In [67]:
# Create two sequence of integers (0, 1, ..., n - 1] and (0, 1, 2,...,4],
# then take the union. It's clear that the duplicate elements are (0, 1, ...,4]
n = 200000000
# If you increase n even more, spark is likely to crash with "OutOfMemory" exception
spark.range(0, n,  step = 1, numPartitions = 4).createOrReplaceTempView("lots_of_data")
spark.range(0, 5,  step = 1, numPartitions = 4).createOrReplaceTempView("duplicates")
spark.sql("""(
(SELECT * FROM lots_of_data) UNION ALL 
(SELECT * FROM duplicates)
)""").createOrReplaceTempView("test_table")
spark.sql("""SELECT * FROM test_table""").show(10)

+---+
| id|
+---+
|  0|
|  1|
|  2|
|  3|
|  4|
|  5|
|  6|
|  7|
|  8|
|  9|
+---+
only showing top 10 rows



In [71]:
def query_1():
    spark.sql("""
        SELECT extract_approx_topk(d) FROM
            (SELECT 
                approx_topk(CAST (id AS LONG), CAST(1.0 AS FLOAT), 1000, "1MB") AS d 
             FROM test_table) tmp
    """).show(10)
%time query_1()

+---------+-----+------+
|      key|value| error|
+---------+-----+------+
|        0|  1.0|1525.0|
|        1|  1.0|1525.0|
|        2|  1.0|1525.0|
|        3|  1.0|1525.0|
|        4|  1.0|1525.0|
|149946368|  1.0|1525.0|
|149946369|  1.0|1525.0|
|149946370|  1.0|1525.0|
|149946371|  1.0|1525.0|
|149946372|  1.0|1525.0|
+---------+-----+------+
only showing top 10 rows

CPU times: user 4 ms, sys: 0 ns, total: 4 ms
Wall time: 8.27 s


In [72]:
# You can note in the spark logs (or the docker console) the following messsage
# 19/11/19 17:44:10 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
# Older versions of spark would crash under this circumstance
def query_2():
    spark.sql("""
        SELECT id, COUNT(*) AS cnt
        FROM test_table
        GROUP BY id
        ORDER BY cnt DESC
        LIMIT 1000
    """).show(10)
%time query_2() 

+----+---+
|  id|cnt|
+----+---+
|   0|  2|
|   1|  2|
|   3|  2|
|   2|  2|
|   4|  2|
|  26|  1|
|  29|  1|
| 474|  1|
| 964|  1|
|1677|  1|
+----+---+
only showing top 10 rows

CPU times: user 8 ms, sys: 4 ms, total: 12 ms
Wall time: 1min 14s


In [73]:
# Exact calculation where the filter comes from approx_topk
def query_3():
    spark.sql("""
        SELECT extract_approx_topk(d) FROM
            (SELECT 
                approx_topk(CAST (id AS LONG), CAST(1.0 AS FLOAT), 1000, "10MB") AS d 
             FROM test_table) tmp
    """).createOrReplaceTempView("candidate_ids")
    
    spark.sql("""
            SELECT id, COUNT(id) AS cnt FROM candidate_ids c
            LEFT JOIN (SELECT * FROM test_table) t ON c.key = t.id
            GROUP BY id
            ORDER BY cnt DESC
            LIMIT 1000            
    """).show(10)    

%time query_3()

+---------+---+
|       id|cnt|
+---------+---+
|        0|  2|
|        1|  2|
|        3|  2|
|        2|  2|
|        4|  2|
|199229684|  1|
|199230068|  1|
|199230140|  1|
|199230304|  1|
|199230375|  1|
+---------+---+
only showing top 10 rows

CPU times: user 0 ns, sys: 4 ms, total: 4 ms
Wall time: 37.9 s


In [82]:
# Exact calculation where the filter comes from approx_topk
def query_4():
    print("Compute Candiates")
    spark.sql("""
        SELECT extract_approx_topk(d) FROM
            (SELECT 
                approx_topk(CAST (id AS LONG), CAST(1.0 AS FLOAT), 1000, "10MB") AS d 
             FROM test_table) tmp
    """).show(10)
    
    print("Compute Exact Values For Candidates")
    spark.sql("""
            SELECT id, COUNT(id) AS cnt 
            FROM test_table t
            WHERE id IN (1,2,3,4) -- TODO: add function approx_topk_is_frequent(d) so no hard-coding is needed
            GROUP BY id
            ORDER BY cnt DESC
            LIMIT 1000            
    """).show(10)    

%time query_4()

Compute Candiates
+---------+-----+-----+
|      key|value|error|
+---------+-----+-----+
|        0|  1.0| 95.0|
|        1|  1.0| 95.0|
|        2|  1.0| 95.0|
|        3|  1.0| 95.0|
|        4|  1.0| 95.0|
|199229440|  1.0| 95.0|
|199229441|  1.0| 95.0|
|199229442|  1.0| 95.0|
|199229443|  1.0| 95.0|
|199229444|  1.0| 95.0|
+---------+-----+-----+
only showing top 10 rows

Compute Exact Values For Candiates
+---+---+
| id|cnt|
+---+---+
|  1|  2|
|  3|  2|
|  2|  2|
|  4|  2|
+---+---+

CPU times: user 8 ms, sys: 0 ns, total: 8 ms
Wall time: 11.2 s


In [None]:
# TODO: 
# - improve errors to report not the "global" errors, but errors for each element
# - rename approx_topk -> approx_most_frequent