In [1]:
from pyspark.sql import SparkSession
import os
import pandas as pd
import time
import string
import pathlib
import random
import threading
import time
from urllib.parse import urlsplit, urlunsplit
import requests
import json
from py4j.protocol import Py4JJavaError, Py4JError

In [2]:
# Global configuration
# Spark sizing: create_spark() renders these as `local[SPARK_CORES]` and `<SPARK_MEMORY>g`.
SPARK_CORES = 60
SPARK_MEMORY = 900  # gigabytes ('g' suffix), used for both driver and executor memory

# Host name of the PostgreSQL server, inserted into the JDBC URLs built by import_db().
DBHOST = 'postgres'

# Seconds a benchmarked query may run before its Spark job group is cancelled (20 minutes).
QUERY_TIMEOUT = 20 * 60

# When True, create_spark() issues `SET spark.sql.yannakakis.countGroupInLeaves = true`.
group_in_leaves = False

In [3]:
def create_spark():
    """Build (or reuse) the local SparkSession used by the benchmarks.

    The session runs on a ``local[SPARK_CORES]`` master, gives the driver and
    the executor ``SPARK_MEMORY`` gigabytes each and puts the PostgreSQL JDBC
    driver jar on the classpath. When the module-level ``group_in_leaves``
    flag is truthy, ``spark.sql.yannakakis.countGroupInLeaves`` is switched on
    and the result of that ``SET`` statement is displayed.

    Returns:
        The SparkSession returned by ``getOrCreate()``.
    """
    memory = f'{SPARK_MEMORY}g'
    settings = (
        ("spark.driver.memory", memory),
        ("spark.executor.memory", memory),
        ("spark.jars", "postgresql-42.3.3.jar"),
    )

    builder = SparkSession.builder.appName("app").master(f'local[{SPARK_CORES}]')
    for key, value in settings:
        builder = builder.config(key, value)
    session = builder.getOrCreate()

    if group_in_leaves:
        session.sql("SET spark.sql.yannakakis.countGroupInLeaves = true").show()
    return session

In [4]:
def _local_api_url(ui_web_url):
    """Return the REST API root (``.../api/v1``) of a Spark UI URL, with the host replaced by localhost."""
    parts = urlsplit(ui_web_url)
    # Use the parsed port instead of searching for ':' in the netloc: the
    # string search breaks on URLs without a port and on IPv6 literal hosts.
    netloc = 'localhost' if parts.port is None else f'localhost:{parts.port}'
    return urlunsplit(parts._replace(netloc=netloc)) + '/api/v1'


def extract_metrics(spark, group_id):
    """Fetch execution metrics for one SQL query from the Spark monitoring REST API.

    The query is looked up by its description, which must equal ``group_id``
    (the job group id and description set before the query was started).

    Args:
        spark: SparkSession whose UI / REST API is queried. The UI host is
            replaced by ``localhost``; the port is kept.
        group_id: Job group id (also used as the description) of the query.

    Returns:
        ``None`` when the Spark UI URL is unavailable or no SQL execution with
        that description exists. Otherwise a tuple
        ``(query_details, job_details, job_stages)`` where ``query_details`` is
        the JSON of the SQL execution, ``job_details`` is a list with the JSON
        of every successful, running and failed job of that execution, and
        ``job_stages`` maps each job id to the list of stage JSON documents
        that could be fetched.
    """
    ui_web_url = spark.sparkContext.uiWebUrl
    if not ui_web_url:
        print('Spark UI is not available, cannot extract metrics')
        return None

    app_url = f'{_local_api_url(ui_web_url)}/applications/{spark.sparkContext.applicationId}'

    # Request a large page so that the execution we are looking for is included.
    sql_queries = requests.get(app_url + '/sql', params={'length': '100000'}).json()
    query_ids = [q['id'] for q in sql_queries if q.get('description') == group_id]
    if not query_ids:
        print(f'query with group {group_id} not found')
        return None
    query_id = query_ids[0]
    print(f'query id: {query_id}')

    query_details = requests.get(f'{app_url}/sql/{query_id}',
                                 params={'details': 'true', 'planDescription': 'true'}).json()

    job_ids = (query_details['successJobIds']
               + query_details['runningJobIds']
               + query_details['failedJobIds'])
    job_details = [requests.get(f'{app_url}/jobs/{jid}').json() for jid in job_ids]

    stage_params = {'details': 'true', 'withSummaries': 'true'}
    job_stages = {}
    for job in job_details:
        responses = [requests.get(f'{app_url}/stages/{sid}', params=stage_params)
                     for sid in job['stageIds']]
        # A stage lookup can answer 404; such stages are left out.
        job_stages[job['jobId']] = [r.json() for r in responses if r.status_code == 200]

    return query_details, job_details, job_stages

In [5]:
def import_db(spark, dbname):
    """Register every table of the ``public`` schema of a PostgreSQL database as a Spark temp view.

    The database is reached over JDBC at ``DBHOST:5432``; the database name is
    also used as user name and as password. Each registered table name is
    printed.

    Args:
        spark: SparkSession used for the JDBC reads and the view registration.
        dbname: Name of the database (and of the login role / its password).
    """
    jdbc_url = f'jdbc:postgresql://{DBHOST}:5432/{dbname}'

    def load_jdbc(table):
        # One JDBC-backed DataFrame per call; the option order is kept stable.
        reader = spark.read.format("jdbc")
        reader = reader.option("url", jdbc_url)
        reader = reader.option("driver", "org.postgresql.Driver")
        reader = reader.option("dbtable", table)
        reader = reader.option("user", dbname)
        reader = reader.option("password", dbname)
        return reader.load()

    # The catalog view lists the tables of all schemas; only "public" is imported.
    catalog = load_jdbc("information_schema.tables").toPandas()

    for _, entry in catalog.iterrows():
        if entry.table_schema != 'public':
            continue
        name = entry.table_name
        table_df = load_jdbc(name)

        print(name)
        table_df.createOrReplaceTempView(name)

def random_str(size=16, chars=string.ascii_uppercase + string.digits):
    """Return a string of ``size`` characters drawn independently at random from ``chars``.

    By default the string has 16 characters taken from the upper-case ASCII
    letters and the decimal digits.
    """
    picked = []
    for _ in range(size):
        picked.append(random.choice(chars))
    return ''.join(picked)

def set_group_id(spark):
    """Assign a fresh random job group to the Spark context and return its id.

    The generated id is used both as the group id and as the group
    description.
    """
    job_group = random_str()
    context = spark.sparkContext
    context.setJobGroup(job_group, job_group)
    return job_group

def cancel_query(seconds, group_id, spark_session=None):
    """Sleep for ``seconds`` and then cancel all Spark jobs of ``group_id``.

    Args:
        seconds: Delay before the cancellation is issued.
        group_id: Id of the job group to cancel.
        spark_session: SparkSession whose context owns the job group. When
            omitted, the notebook-global ``spark`` is used, which keeps the
            previous two-argument call form working.
    """
    time.sleep(seconds)
    session = spark if spark_session is None else spark_session
    print("cancelling jobs with id " + group_id)
    print(session.sparkContext.cancelJobGroup(group_id))
    print("cancelled job")


def cancel_query_after(spark, seconds):
    """Start a new job group on ``spark`` and schedule its cancellation.

    A background thread cancels the group after ``seconds`` seconds. The
    cancellation is issued on the session passed in here, not on whatever
    global session happens to exist.

    Args:
        spark: SparkSession the following query will run on.
        seconds: Timeout after which the job group is cancelled.

    Returns:
        The id of the newly created job group.
    """
    group_id = set_group_id(spark)
    # Daemon thread: a still-sleeping watchdog must not keep the interpreter
    # alive once the query (and everything else) has finished.
    watchdog = threading.Thread(target=cancel_query,
                                args=(seconds, group_id, spark),
                                daemon=True)
    watchdog.start()
    return group_id
    
def run_query(file, spark_session=None):
    """Read a SQL file, drop its ``limit`` and ``-`` lines and run it with Spark SQL.

    Lines that start with ``limit`` or with ``-`` (which covers SQL ``--``
    comment lines) are removed; the check is case-sensitive and only looks at
    the very first characters of the line. The remaining lines are joined
    with single newlines, printed and executed.

    Args:
        file: Path of the SQL file.
        spark_session: SparkSession to run the query on. When omitted, the
            notebook-global ``spark`` is used, which keeps the previous
            one-argument call form working.

    Returns:
        The DataFrame returned by ``sql()`` for the filtered query text.
    """
    session = spark if spark_session is None else spark_session

    with open(file, 'r') as f:
        # Strip the line terminator before joining; joining the raw lines
        # with '\n' would double every line break in the query text.
        kept = [line.rstrip('\n') for line in f
                if not line.startswith(('limit', '-'))]
    query = '\n'.join(kept)

    print("running query: \n" + query)
    return session.sql(query)

In [6]:
def benchmark_query(spark, query, respath):
    """Run one query file under a timeout, store its Spark metrics and return the elapsed time.

    A job group with a cancellation timer of ``QUERY_TIMEOUT`` seconds is
    started, the query is executed and its result shown. Afterwards the
    metrics of the execution are fetched and, when available, written to
    ``execution.json``, ``jobs.json`` and ``stages.json`` inside ``respath``.

    Args:
        spark: SparkSession used for the job group and the metrics lookup.
        query: Path of the SQL file to run.
        respath: Existing directory that receives the JSON metric files.

    Returns:
        Seconds elapsed from before the job group was started until the
        result had been shown.
    """
    # perf_counter is monotonic, so the measurement cannot be distorted by
    # wall-clock adjustments during a long-running query.
    start_time = time.perf_counter()

    group_id = cancel_query_after(spark, QUERY_TIMEOUT)
    df1 = run_query(query)
    df1.show()

    diff_time = time.perf_counter() - start_time

    # extract_metrics returns None when the execution cannot be found, so the
    # result must not be unpacked unconditionally.
    metrics = extract_metrics(spark, group_id)
    if metrics is None:
        return diff_time

    execution, jobs, job_stages = metrics
    if execution is None:
        return diff_time

    outputs = (
        ('execution.json', execution),
        ('jobs.json', jobs),
        ('stages.json', job_stages),
    )
    for filename, payload in outputs:
        with open(f'{respath}/{filename}', 'w') as f:
            json.dump(payload, f, indent=2)
    return diff_time

def benchmark(spark, dbname, query_file, mode):
    """Benchmark one query file with the Yannakakis rewrite switched on or off.

    No warm-up run is performed. Results of the run are stored below
    ``benchmark-results-<dbname>/<query file name>/<mode>``.

    Args:
        spark: SparkSession to run on.
        dbname: Database name, used only for the result directory name.
        query_file: Path of the SQL file to benchmark.
        mode: ``"opt"`` sets ``spark.sql.yannakakis.enabled`` to true,
            ``"ref"`` sets it to false.

    Returns:
        ``[query_name, runtime, mode]``, with ``runtime`` set to ``None``
        when the run raised a ``Py4JError`` (presumably because the timeout
        cancelled the job group, or because of another JVM-side failure), or
        an empty list for an unknown ``mode``.
    """
    # Validate the mode first so that an unknown mode leaves no empty result
    # directory behind.
    if mode == "opt":
        yannakakis_enabled = "true"
    elif mode == "ref":
        yannakakis_enabled = "false"
    else:
        print(f'unknown benchmark mode: {mode!r}')
        return []

    query_name = os.path.basename(query_file)

    respath = f'benchmark-results-{dbname}/' + query_name + "/" + mode
    pathlib.Path(respath).mkdir(parents=True, exist_ok=True)

    spark.sql(f"SET spark.sql.yannakakis.enabled = {yannakakis_enabled}").show()

    try:
        runtime = benchmark_query(spark, query_file, respath)
    except Py4JError as e:
        # Format the exception instead of concatenating it to a str, which
        # raised a TypeError inside this handler.
        print(f'timeout or error: {e}')
        return [query_name, None, mode]
    return [query_name, runtime, mode]

## SNAP Benchmark

In [7]:
# Database to benchmark; import_db() also uses this name as JDBC user and password.
dbname = 'snap'

# Start the Spark session and register the tables of the database as temp views.
spark = create_spark()
import_db(spark, dbname)

# Full list of SNAP patents queries.
queries = ['snap-queries/patents/path02.sql',
          'snap-queries/patents/path03.sql',
          'snap-queries/patents/path04.sql',
          'snap-queries/patents/path05.sql',
          'snap-queries/patents/path06.sql',
          'snap-queries/patents/path07.sql',
          'snap-queries/patents/path08.sql',
          'snap-queries/patents/tree01.sql',
          'snap-queries/patents/tree02.sql',
          'snap-queries/patents/tree03.sql']

# NOTE: this assignment replaces the full list above, so only tree01 is run.
queries = ['snap-queries/patents/tree01.sql']

# 'ref' runs with spark.sql.yannakakis.enabled = false, 'opt' with true (see benchmark()).
mode = 'ref'

# One [query, runtime, mode] row per benchmarked query.
results = []

for q in queries:
    results = results + [benchmark(spark, dbname, q, mode)]
    # The table is rebuilt and the CSV rewritten after every query,
    # presumably so that partial results are kept if a later query fails.
    df = pd.DataFrame(results, columns = ['query', 'runtime', 'mode'])
    df.to_csv(f'results-{mode}.csv')
    print(df)

23/11/14 22:45:05 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/11/14 22:45:06 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


patents
+--------------------+-----+
|                 key|value|
+--------------------+-----+
|spark.sql.yannaka...|false|
+--------------------+-----+

running query: 
SELECT COUNT(*) FROM patents p1, patents p2, patents p3, patents p4a, patents p4b

WHERE p1.toNode = p2.fromNode

AND p2.toNode = p3.fromNode

    AND p3.toNode = p4a.fromNode

    AND p3.toNode = p4b.fromNode



                                                                                

+-----------+
|   count(1)|
+-----------+
|15961425879|
+-----------+

query id: 4
        query     runtime mode
0  tree01.sql  145.612812  ref


In [None]:
# Optional tuning cell; it requires the `spark` session created by an earlier cell.
# The three commented-out settings below are disabled alternatives; only the
# shuffle-partition override is active.
#spark.conf.set("spark.sql.legacy.setCommandRejectsSparkCoreConfs","false")
#spark.conf.set("spark.executor.cores", "1")
#spark.conf.set("spark.executor.instances", "1")
# Sets spark.sql.shuffle.partitions to 1 for this session.
spark.conf.set("spark.sql.shuffle.partitions", "1")

In [None]:
## Compare result
# Runs the same query file twice on the global `spark` session, first with the
# Yannakakis rewrite enabled and then disabled, and prints both wall-clock times.
# Each timed region covers run_query(), show() and explain().
# NOTE(review): `time` is already imported in the first cell; this import is redundant.
import time
# Alternative query files; the last uncommented assignment to `query` wins.
#query = 'tpch-kit/dbgen/queries/postgres/2.sql'
#query = 'tpch-kit/dbgen/queries/postgres/13.sql'
#query = 'count-3.sql'
#query = 'tpch-kit/dbgen/queries/postgres/11.sql'
#query = '11-simple.sql'
# NOTE(review): this assignment is overwritten a few lines below and has no effect.
query = 'median-2-hint.sql'
#query = 'tpch-kit/dbgen/queries/postgres/7.sql'
#query = '13-simple.sql'
#query = 'dsb-queries/agg_queries/query101.sql'
query = 'snap-queries/patents/tree01.sql'

# First run: Yannakakis rewrite enabled.
spark.sql("SET spark.sql.yannakakis.enabled = true").show()

start_time = time.time()

df1 = run_query(query)
df1.show()
df1.explain(mode="extended")

end_time = time.time()
yannakakis_time = end_time - start_time

# Second run: Yannakakis rewrite disabled (reference).
spark.sql("SET spark.sql.yannakakis.enabled = false").show()

start_time = time.time()

df2 = run_query(query)
df2.show()
df2.explain(mode="extended")

end_time = time.time()
ref_time = end_time - start_time

#print(f'row count: {df1.count()} vs. {df2.count()}' )
print(f'time ref: {ref_time}\ntime yannakakis: {yannakakis_time}')