In [29]:
from sqlglot import parse_one
from sqlglot.planner import *
from sqlglot.optimizer import optimize
from sqlglot.expressions import *


In [40]:
sql = """
select s.name, ad.address, count(*) as total_count from Students s, Addresses ad 
where s.student_id = ad.student_id and s.year > 3 order by s.name Asc;
"""
parsed = parse_one(sql)
optimized = optimize(parsed)
plan = Plan(optimized)

In [41]:
# Traverse in topological order
def topological_traverse(plan):
    dag = plan.dag
    visited = set()
    result = []

    def visit(step):
        if step in visited:
            return
        visited.add(step)
        for dep in dag[step]:
            visit(dep)
        result.append(step)

    visit(plan.root)
    return result
topological_traverse(plan)

[- Scan: ad (2060794747680)
     Context:
       Source: "addresses" AS "ad"
     Projections:,
 - Scan: s (2060815953472)
     Context:
       Source: "students" AS "s"
     Projections:,
 - Join: s (2060815949296)
     Context:
       Source: s
       ad: INNER
       Key: "ad"."student_id"
       On: TRUE AND TRUE
     Projections:
     Condition: "s"."year" > 3
     Dependencies:
     - Scan: ad (2060794747680)
       Context:
         Source: "addresses" AS "ad"
       Projections:
     - Scan: s (2060815953472)
       Context:
         Source: "students" AS "s"
       Projections:,
 - Aggregate: s (2060815954624)
     Context:
       Aggregations:
         - COUNT("_a_0") AS "total_count"
       Operands:
         - * AS _a_0
     Projections:
     Dependencies:
     - Join: s (2060815949296)
       Context:
         Source: s
         ad: INNER
         Key: "ad"."student_id"
         On: TRUE AND TRUE
       Projections:
       Condition: "s"."year" > 3
       Dependencies:
   

In [None]:
import sqlglot.planner


def topological_traverse_to_plan(plan):
    dag = plan.dag
    visited = set()
    result = []

    def visit(step):
        if step in visited:
            return
        visited.add(step)
        for dep in dag[step]:
            visit(dep)
        result.append(step)

    visit(plan.root)
    execution_plan = []
    for step in result:
        if isinstance(step, Scan):
            execution_plan.append({"operation": "SELECT", "column": step.name})
        if isinstance(step, Join):
            execution_plan.append({"operation": "JOIN", "left": step.source_name, "right": step.alias_column_names})
    return execution_plan


execution_plan = topological_traverse_to_plan(plan)
print(execution_plan)

[{'operation': 'SELECT', 'column': 'ad'}, {'operation': 'SELECT', 'column': 's'}]


In [10]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("PySpark SQL on UCI Datasets").getOrCreate()
spark.read.csv("Address.csv", header=True, inferSchema=True).createOrReplaceTempView("Addresses")
spark.read.csv("Student.csv", header=True, inferSchema=True).createOrReplaceTempView("Students")


In [30]:
query = spark.sql('''
SELECT s.name, a.address, COUNT(*) as count
FROM Students s, Addresses a
WHERE s.id = a.id AND s.year > 3
GROUP BY s.name, a.address
ORDER BY s.name ASC;
            ''')
# query.show()
logical = query._jdf.queryExecution().logical()
print(logical)

'Sort ['s.name ASC NULLS FIRST], true
+- 'Aggregate ['s.name, 'a.address], ['s.name, 'a.address, 'COUNT(1) AS count#377]
   +- 'Filter (('s.id = 'a.id) AND ('s.year > 3))
      +- 'Join Inner
         :- 'SubqueryAlias s
         :  +- 'UnresolvedRelation [Students], [], false
         +- 'SubqueryAlias a
            +- 'UnresolvedRelation [Addresses], [], false

