In [1]:
import json
import os
from urllib.error import HTTPError
from urllib.parse import urlencode, urlsplit, urlunsplit
from urllib.request import urlopen

from pyspark.sql import SparkSession


In [2]:
# Local-mode session with 30 worker threads and very large driver/executor
# heaps; the PostgreSQL JDBC driver jar is added so tables can be read via JDBC.
spark = (
    SparkSession.builder
    .master("local[30]")
    .appName("app")
    .config("spark.driver.memory", "900g")
    .config("spark.executor.memory", "900g")
    .config("spark.memory.offHeap.enabled", False)
    .config("spark.jars", "postgresql-42.3.3.jar")
    .getOrCreate()
)
# For verbose troubleshooting: spark.sparkContext.setLogLevel("DEBUG")

23/11/11 23:17:19 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


In [None]:
# Keep driver logging at WARN. Other documented levels: ALL, DEBUG, ERROR,
# FATAL, INFO, OFF, TRACE (e.g. switch to "ALL" when debugging the optimizer).
spark.sparkContext.setLogLevel("WARN")

In [3]:
# Connection settings come from the environment, with TPC-H container defaults.
# NOTE(review): USERNAME is also set by some operating systems (e.g. Windows) to
# the login user; confirm it is the intended variable for the DB user.
username = os.environ.get('USERNAME', 'tpch')
password = os.environ.get('PASSWORD', 'tpch')
dbname = os.environ.get('DBNAME', 'tpch')
dbhost = os.environ.get('DBHOST', 'postgres')
jdbc_url = f'jdbc:postgresql://{dbhost}:5432/{dbname}'


def read_jdbc_table(dbtable):
    """Return a lazy Spark DataFrame over ``dbtable`` in the configured Postgres DB.

    Args:
        dbtable: Table (or schema-qualified table) name passed as the JDBC
            ``dbtable`` option.
    """
    return (spark.read.format("jdbc")
            .option("url", jdbc_url)
            .option("driver", "org.postgresql.Driver")
            .option("dbtable", dbtable)
            .option("user", username)
            .option("password", password)
            .load())


df_tables = read_jdbc_table("information_schema.tables")

# Filter in Spark so only the public-schema table names reach the driver,
# instead of collecting the whole catalog table through pandas.
public_table_rows = (df_tables
                     .filter(df_tables.table_schema == 'public')
                     .select('table_name')
                     .collect())

for row in public_table_rows:
    table_name = row.table_name
    df = read_jdbc_table(table_name)

    print(table_name)
    # Register each table under its Postgres name so SQL cells can use it, and
    # mark it for caching so later queries read it from Spark's cache.
    df.createOrReplaceTempView(table_name)
    spark.catalog.cacheTable(table_name)

part
supplier
partsupp
customer
orders
lineitem
nation
region


In [None]:
def _get_json(url, params=None):
    """Send an HTTP GET to ``url`` and return the decoded JSON body.

    Args:
        url: Absolute URL of a Spark REST API endpoint.
        params: Optional mapping that is URL-encoded into the query string.

    Raises:
        urllib.error.HTTPError: for non-2xx responses (e.g. 404).
    """
    if params:
        url = f'{url}?{urlencode(params)}'
    with urlopen(url) as response:
        return json.loads(response.read().decode('utf-8'))


def extract_metrics(spark, group_id):
    """Fetch REST-API metrics for the SQL execution whose description is ``group_id``.

    Queries the driver's monitoring REST API (``/api/v1``) for the first SQL
    execution whose description equals ``group_id``, then fetches that
    execution's jobs (succeeded, running and failed) and each job's stages.

    Args:
        spark: Active SparkSession with the web UI enabled.
        group_id: SQL execution description to match exactly.

    Returns:
        ``(query_details, job_details, job_stages)``, where ``job_stages`` maps
        job id -> list of stage JSON documents. Returns ``None`` when the UI is
        unavailable or no execution matches.
    """
    ui_url = spark.sparkContext.uiWebUrl
    if ui_url is None:
        print('Spark UI is not available; cannot query the REST API')
        return None

    # The UI advertises the container hostname (e.g. 'a8dc08c6eef4:4040'),
    # which may not resolve from here: keep scheme and port, target localhost.
    parts = urlsplit(ui_url)
    netloc = 'localhost' if parts.port is None else f'localhost:{parts.port}'
    api_url = f'{urlunsplit(parts._replace(netloc=netloc))}/api/v1'
    app_url = f'{api_url}/applications/{spark.sparkContext.applicationId}'

    # 'length' raises the page size so older executions are included too.
    sql_queries = _get_json(f'{app_url}/sql', {'length': '100000'})
    query_ids = [q['id'] for q in sql_queries if q['description'] == group_id]
    if not query_ids:
        print(f'query with group {group_id} not found')
        return None
    query_id = query_ids[0]
    print(f'query id: {query_id}')

    query_details = _get_json(f'{app_url}/sql/{query_id}',
                              {'details': 'true', 'planDescription': 'true'})

    job_ids = (query_details['successJobIds']
               + query_details['runningJobIds']
               + query_details['failedJobIds'])
    job_details = [_get_json(f'{app_url}/jobs/{jid}') for jid in job_ids]

    stage_params = {'details': 'true', 'withSummaries': 'true'}
    job_stages = {}
    for job in job_details:
        stages = []
        for sid in job['stageIds']:
            try:
                stages.append(_get_json(f'{app_url}/stages/{sid}', stage_params))
            except HTTPError:
                # Some stage ids answer with an error status (e.g. 404);
                # skip them rather than failing the whole collection.
                continue
        job_stages[job['jobId']] = stages

    return query_details, job_details, job_stages

In [None]:
# Set spark.sql.yannakakis.countGroupInLeaves to false and display the key/value row.
(spark
 .sql("SET spark.sql.yannakakis.countGroupInLeaves = false")
 .show())

In [None]:
# Turn the Yannakakis rewrite off and display the resulting setting.
(spark
 .sql("SET spark.sql.yannakakis.enabled = false")
 .show())

In [None]:
# Turn the Yannakakis rewrite on with countGroupInLeaves off, showing each setting.
for setting in ("spark.sql.yannakakis.enabled = true",
                "spark.sql.yannakakis.countGroupInLeaves = false"):
    spark.sql(f"SET {setting}").show()

In [None]:
# Display the current value of spark.local.dir.
(spark
 .sql("SET spark.local.dir")
 .show())

In [None]:
# Compute table-level statistics for `part` and display the (empty) result.
(spark
 .sql("ANALYZE TABLE part COMPUTE STATISTICS;")
 .show())

In [11]:
# Full rows of part JOIN partsupp JOIN supplier JOIN nation JOIN region for
# parts priced above the average retail price (uncorrelated scalar subquery).
df = spark.sql("""select *
        		from
            part,
			partsupp,
			supplier,
			nation,
			region
		where
			p_partkey = ps_partkey
			and s_suppkey = ps_suppkey
			and s_nationkey = n_nationkey
			and n_regionkey = r_regionkey
            AND p_retailprice >
                (SELECT avg (p_retailprice) FROM part)""")
df.show(n=500)
df.explain(extended=True)

23/11/11 10:51:53 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.

+---------+--------------------+--------------------+----------+--------------------+------+-----------+--------------------+--------------------+----------+----------+-----------+--------------------+--------------------+---------+--------------------+--------------------+-----------+---------------+--------------------+--------------------+-----------+--------------------+-----------+--------------------+-----------+--------------------+--------------------+
|p_partkey|              p_name|              p_mfgr|   p_brand|              p_type|p_size|p_container|       p_retailprice|           p_comment|ps_partkey|ps_suppkey|ps_availqty|       ps_supplycost|          ps_comment|s_suppkey|              s_name|           s_address|s_nationkey|        s_phone|           s_acctbal|           s_comment|n_nationkey|              n_name|n_regionkey|           n_comment|r_regionkey|              r_name|           r_comment|
+---------+--------------------+--------------------+----------+------

                                                                                

In [4]:
def run_query(file):
    """Load a SQL file, drop comment and LIMIT lines, echo it, and run it in Spark.

    Lines starting with ``-`` (SQL ``--`` comments) are removed, as are lines
    starting with the ``limit`` keyword in any letter case, so the query is
    evaluated without a row cap. Remaining lines keep their own newlines.

    Args:
        file: Path to a ``.sql`` file.

    Returns:
        The lazy DataFrame returned by ``spark.sql`` for the cleaned query.
    """
    with open(file, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    # readlines() keeps each trailing '\n', so join with '' (joining with
    # '\n' would double every line break of the query).
    query = ''.join(
        line for line in lines
        if not line.lower().startswith('limit') and not line.startswith('-')
    )

    print("running query: \n" + query)
    return spark.sql(query)

In [None]:
# Smoke test: read the small region table and print its logical/physical plans.
df = spark.sql("\nSELECT * FROM region;\n")
df.show()
df.explain(extended=True)

In [4]:
# Nations in the EUROPE or ASIA regions (nation joined to region on the region key).
df = spark.sql(
    "SELECT * from nation, region where n_regionkey = r_regionkey and r_name IN ('EUROPE', 'ASIA')"
)
df.show()

+-----------+--------------------+-----------+--------------------+-----------+--------------------+--------------------+
|n_nationkey|              n_name|n_regionkey|           n_comment|r_regionkey|              r_name|           r_comment|
+-----------+--------------------+-----------+--------------------+-----------+--------------------+--------------------+
|         21|VIETNAM          ...|          2|hely enticingly e...|          2|ASIA             ...|ges. thinly even ...|
|         18|CHINA            ...|          2|c dependencies. f...|          2|ASIA             ...|ges. thinly even ...|
|         12|JAPAN            ...|          2|ously. final, exp...|          2|ASIA             ...|ges. thinly even ...|
|          9|INDONESIA        ...|          2| slyly express as...|          2|ASIA             ...|ges. thinly even ...|
|          8|INDIA            ...|          2|ss excuses cajole...|          2|ASIA             ...|ges. thinly even ...|
|         23|UNITED KING

In [None]:
# Tiny hand-made relations, registered as views t1/t2/t3, for checking the
# Yannakakis rewrite against the regular plan on small inputs.
df_t1 = spark.createDataFrame(
    [(1, 1), (2, 1), (2, 2), (3, 2), (3, 3), (4, 3), (4, 3), (5, 2), (5, 1), (6, 4)],
    schema=("a", "b"),
)
df_t1.createOrReplaceTempView("t1")

df_t2 = spark.createDataFrame(
    [(1, 1), (2, 1), (3, 2), (3, 2), (3, 3), (3, 3), (4, 3), (4, 2), (5, 1), (6, 4)],
    schema=("c", "d"),
)
df_t2.createOrReplaceTempView("t2")

df_t3 = spark.createDataFrame(
    [(1, 1), (2, 1), (3, 2), (3, 2), (3, 3), (3, 3), (4, 3), (4, 2), (5, 1), (6, 4)],
    schema=("e", "f"),
)
df_t3.createOrReplaceTempView("t3")

query = "select median(a) from t1, t2 where b = c"
# Other probe queries tried here:
#query = "select percentile(a, 0.5, b) from t1, t2 where b = c"
#query = "select median(a) from t1 where EXISTS (SELECT 1 FROM t2 WHERE b = c)"
#query = "select count(*) from t1, t2 where b = c"
#query = "select *a from t1 where EXISTS (SELECT 1 FROM t2 WHERE b = c)"

# Run once with the rewrite off (reference), then once with it on; the two
# displayed results should agree.
for yannakakis_flag in ("false", "true"):
    spark.sql(f"SET spark.sql.yannakakis.enabled = {yannakakis_flag}").show()
    df = spark.sql(query)
    df.show()

In [None]:
# Let SET / spark.conf.set accept keys registered as Spark core configs, and
# use 6 shuffle partitions. (spark.executor.cores / spark.executor.instances
# = "6" were also tried here and left disabled.)
for conf_key, conf_value in (
    ("spark.sql.legacy.setCommandRejectsSparkCoreConfs", "false"),
    ("spark.sql.shuffle.partitions", "6"),
):
    spark.conf.set(conf_key, conf_value)

In [7]:
# Display the Spark UI address as advertised by the driver (here the container
# hostname); extract_metrics rewrites the host part to localhost.
spark.sparkContext.uiWebUrl

'http://a8dc08c6eef4:4040'

In [None]:
import pandas as pd
import time


def _timed_run(query, yannakakis_enabled):
    """Run ``query`` with the Yannakakis rewrite on/off; return elapsed seconds.

    The SET statement is issued before the timer starts, so only building the
    DataFrame and its ``show()`` action are measured.
    NOTE(review): ``show()`` only fetches the first 20 rows, so for queries
    with large results this may not measure the complete execution.
    """
    flag = 'true' if yannakakis_enabled else 'false'
    spark.sql(f"SET spark.sql.yannakakis.enabled = {flag}").show()

    # perf_counter is monotonic and high resolution, unlike wall-clock time.time().
    start_time = time.perf_counter()
    result = run_query(query)
    result.show()
    return time.perf_counter() - start_time


def benchmark(query):
    """Time one SQL file with and without the Yannakakis rewrite.

    An untimed warm-up run comes first, under whatever setting is currently
    active (presumably to populate the table cache before measuring).

    Args:
        query: Path to the ``.sql`` file to run.

    Returns:
        ``[query, ref_time, yannakakis_time]`` with times in seconds.
    """
    run_query(query).show()  # warm-up, not timed

    yannakakis_time = _timed_run(query, yannakakis_enabled=True)
    ref_time = _timed_run(query, yannakakis_enabled=False)
    return [query, ref_time, yannakakis_time]


# Workload: TPC-H Q2 and Q11, the median queries, and their `-hint` variants;
# the whole list is repeated so each query is measured several times.
TPCH_QUERY_DIR = 'tpch-kit/dbgen/queries/postgres'
REPETITIONS = 4
RESULT_COLUMNS = ['query', 'ref_time', 'yannakakis_time']

queries = ([f'{TPCH_QUERY_DIR}/2.sql',
            f'{TPCH_QUERY_DIR}/11.sql',
            f'{TPCH_QUERY_DIR}/11-hint.sql']
           + [f'median-{i}.sql' for i in range(1, 6)]
           + [f'median-{i}-hint.sql' for i in range(1, 6)]) * REPETITIONS

results = []
for q in queries:
    results.append(benchmark(q))
    # Checkpoint after every query so a failure late in a long run does not
    # lose the timings collected so far.
    pd.DataFrame(results, columns=RESULT_COLUMNS).to_csv("results.csv")

df = pd.DataFrame(results, columns=RESULT_COLUMNS)

print(df)

df.to_csv("results.csv")

running query: 




select

	s_acctbal,

	s_name,

	n_name,

	p_partkey,

	p_mfgr,

	s_address,

	s_phone,

	s_comment

from

	part,

	supplier,

	partsupp,

	nation,

	region

where

	p_partkey = ps_partkey

	and s_suppkey = ps_suppkey

	and p_size = 15

	and p_type like '%BRASS'

	and s_nationkey = n_nationkey

	and n_regionkey = r_regionkey

	and r_name = 'EUROPE'

	and ps_supplycost = (

		select

			min(ps_supplycost)

		from

			partsupp,

			supplier,

			nation,

			region

		where

			p_partkey = ps_partkey

			and s_suppkey = ps_suppkey

			and s_nationkey = n_nationkey

			and n_regionkey = r_regionkey

			and r_name = 'EUROPE'

	)

order by

	s_acctbal desc,

	n_name,

	s_name,

	p_partkey;



[Stage 2:>                  (0 + 1) / 1][Stage 3:>                  (0 + 1) / 1]

In [None]:
# Use a single shuffle partition. (The core-config override and
# spark.executor.cores / spark.executor.instances = "1" were tried here too
# and left disabled.)
spark.conf.set(key="spark.sql.shuffle.partitions", value="1")

In [None]:
# SparkSession has no `uiWebUrl` attribute; the UI address is exposed by the
# underlying SparkContext.
spark.sparkContext.uiWebUrl

In [7]:
## Compare result
# Time one query with the Yannakakis rewrite enabled, then disabled, and print
# both plans. Each timer stops before explain() so plan printing is not counted.
import time

# Query under test. Other files used during development:
#   'tpch-kit/dbgen/queries/postgres/2.sql', 'tpch-kit/dbgen/queries/postgres/13.sql',
#   'tpch-kit/dbgen/queries/postgres/11.sql', 'tpch-kit/dbgen/queries/postgres/7.sql',
#   'count-3.sql', '11-simple.sql', '13-simple.sql', 'median-1.sql',
#   'subselect-exists.sql', 'min-1.sql'
query = 'median-1-hint.sql'

spark.sql("SET spark.sql.yannakakis.enabled = true").show()

start_time = time.perf_counter()
df1 = run_query(query)
df1.show()
end_time = time.perf_counter()
yannakakis_time = end_time - start_time
# explain() may plan df1's own query execution and prints a large plan; keep
# that work out of the measured interval.
df1.explain(mode="extended")

spark.sql("SET spark.sql.yannakakis.enabled = false").show()

start_time = time.perf_counter()
df2 = run_query(query)
df2.show()
end_time = time.perf_counter()
ref_time = end_time - start_time
df2.explain(mode="extended")

# TODO: compare results, e.g. df1.count() vs. df2.count()
print(f'time ref: {ref_time}\ntime yannakakis: {yannakakis_time}')

+--------------------+-----+
|                 key|value|
+--------------------+-----+
|spark.sql.yannaka...| true|
+--------------------+-----+

running query: 
select

    /*+ FK(ps_partkey, p_partkey), FK(n_regionkey, r_regionkey), FK(ps_suppkey, s_suppkey), FK(s_nationkey, n_nationkey), PK(ps_partkey, ps_suppkey) */

        median(s_acctbal)

		from

            part,

			partsupp,

			supplier,

			nation,

			region

		where

			p_partkey = ps_partkey

			and s_suppkey = ps_suppkey

			and s_nationkey = n_nationkey

			and n_regionkey = r_regionkey

            AND p_retailprice >

                (SELECT avg (p_retailprice) FROM part)

            and r_name IN ('EUROPE', 'ASIA')


23/11/11 22:10:23 WARN RewriteJoinsAsSemijoins: not applicable to aggregate: Aggregate [avg(p_retailprice#7376) AS avg(p_retailprice)#7368]
+- Project [p_retailprice#7376]
   +- InMemoryRelation [p_partkey#7369, p_name#7370, p_mfgr#33, p_brand#34, p_type#7373, p_size#7374, p_container#35, p_retailprice#7376, p_comment#7377], StorageLevel(disk, memory, deserialized, 1 replicas)
         +- *(1) Project [p_partkey#24, p_name#25, staticinvoke(class org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils, StringType, readSidePadding, p_mfgr#26, 25, true, false, true) AS p_mfgr#33, staticinvoke(class org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils, StringType, readSidePadding, p_brand#27, 10, true, false, true) AS p_brand#34, p_type#28, p_size#29, staticinvoke(class org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils, StringType, readSidePadding, p_container#30, 10, true, false, true) AS p_container#35, p_retailprice#31, p_comment#32]
            +- *(1) Scan JDBCRelati

+-----------------+
|median(s_acctbal)|
+-----------------+
|          4494.43|
+-----------------+

== Parsed Logical Plan ==
'UnresolvedHint FK, ['ps_partkey, 'p_partkey]
+- 'UnresolvedHint FK, ['n_regionkey, 'r_regionkey]
   +- 'UnresolvedHint FK, ['ps_suppkey, 's_suppkey]
      +- 'UnresolvedHint FK, ['s_nationkey, 'n_nationkey]
         +- 'UnresolvedHint PK, ['ps_partkey, 'ps_suppkey]
            +- 'Project [unresolvedalias('median('s_acctbal), None)]
               +- 'Filter (((('p_partkey = 'ps_partkey) AND ('s_suppkey = 'ps_suppkey)) AND ('s_nationkey = 'n_nationkey)) AND ((('n_regionkey = 'r_regionkey) AND ('p_retailprice > scalar-subquery#7366 [])) AND 'r_name IN (EUROPE,ASIA)))
                  :  +- 'Project [unresolvedalias('avg('p_retailprice), None)]
                  :     +- 'UnresolvedRelation [part], [], false
                  +- 'Join Inner
                     :- 'Join Inner
                     :  :- 'Join Inner
                     :  :  :- 'Join Inner
     

23/11/11 22:10:53 WARN DAGScheduler: Broadcasting large task binary with size 2.1 MiB
23/11/11 22:11:16 WARN DAGScheduler: Broadcasting large task binary with size 2.1 MiB
23/11/11 22:11:24 WARN DAGScheduler: Broadcasting large task binary with size 2.1 MiB
[Stage 164:>                                                        (0 + 1) / 1]

+-----------------+
|median(s_acctbal)|
+-----------------+
|          4494.43|
+-----------------+

== Parsed Logical Plan ==
'UnresolvedHint FK, ['ps_partkey, 'p_partkey]
+- 'UnresolvedHint FK, ['n_regionkey, 'r_regionkey]
   +- 'UnresolvedHint FK, ['ps_suppkey, 's_suppkey]
      +- 'UnresolvedHint FK, ['s_nationkey, 'n_nationkey]
         +- 'UnresolvedHint PK, ['ps_partkey, 'ps_suppkey]
            +- 'Project [unresolvedalias('median('s_acctbal), None)]
               +- 'Filter (((('p_partkey = 'ps_partkey) AND ('s_suppkey = 'ps_suppkey)) AND ('s_nationkey = 'n_nationkey)) AND ((('n_regionkey = 'r_regionkey) AND ('p_retailprice > scalar-subquery#9055 [])) AND 'r_name IN (EUROPE,ASIA)))
                  :  +- 'Project [unresolvedalias('avg('p_retailprice), None)]
                  :     +- 'UnresolvedRelation [part], [], false
                  +- 'Join Inner
                     :- 'Join Inner
                     :  :- 'Join Inner
                     :  :  :- 'Join Inner
     

                                                                                