In [8]:
import pyspark
from pyspark.sql import (
    functions as f,
    SparkSession,
    types as t
)

# Spark session for this notebook; dynamic partition pruning is set to 'true' explicitly.
spark_settings = [('spark.sql.optimizer.dynamicPartitionPruning.enabled', 'true')]
conf = pyspark.SparkConf().setAll(spark_settings)
spark = (
    SparkSession.builder
    .appName("partition_pruning")
    .config(conf=conf)
    .getOrCreate()
)

# Explicit schema for the order CSV: (column name, Spark type) pairs, all nullable.
table_schema = t.StructType([
    t.StructField(column_name, column_type, True)
    for column_name, column_type in (
        ("date", t.StringType()),
        ("name", t.StringType()),
        ("region", t.IntegerType()),
        ("price", t.IntegerType()),
    )
])

# Read the CSV with the schema above instead of letting Spark infer one.
csv_file_path = "file:///home/jovyan/work/sample/ecommerce_order.csv"
df = spark.read.csv(csv_file_path, schema=table_schema)

In [10]:
# Write the orders as Parquet partitioned by `region`, replacing any previous output.
df.write.parquet(
    "/home/jovyan/work/output/partition_pruning",
    mode="overwrite",
    partitionBy="region",
)


In [22]:
# Re-read the partitioned Parquet output, keep only region 2, and total its prices.
read_df = spark.read.parquet("/home/jovyan/work/output/partition_pruning")
sales_total_df = (
    read_df
    .where("region==2")  # filter on the partition column
    .agg(
        # `price` is an integer column, so the sum is integral and round(..., 2) leaves it unchanged.
        f.round(f.sum("price"), 2).alias("sales")
    )
)

In [23]:
# Print the formatted physical plan; the scan node's PartitionFilters entry shows the region filter.
sales_total_df.explain("formatted")

== Physical Plan ==
AdaptiveSparkPlan (6)
+- HashAggregate (5)
   +- Exchange (4)
      +- HashAggregate (3)
         +- Project (2)
            +- Scan parquet  (1)


(1) Scan parquet 
Output [2]: [price#230, region#231]
Batched: true
Location: InMemoryFileIndex [file:/home/jovyan/work/output/partition_pruning]
PartitionFilters: [isnotnull(region#231), (region#231 = 2)]
ReadSchema: struct<price:int>

(2) Project
Output [1]: [price#230]
Input [2]: [price#230, region#231]

(3) HashAggregate
Input [1]: [price#230]
Keys: []
Functions [1]: [partial_sum(price#230)]
Aggregate Attributes [1]: [sum#244L]
Results [1]: [sum#245L]

(4) Exchange
Input [1]: [sum#245L]
Arguments: SinglePartition, ENSURE_REQUIREMENTS, [plan_id=187]

(5) HashAggregate
Input [1]: [sum#245L]
Keys: []
Functions [1]: [sum(price#230)]
Aggregate Attributes [1]: [sum(price#230)#237L]
Results [1]: [round(sum(price#230)#237L, 2) AS sales#238L]

(6) AdaptiveSparkPlan
Output [1]: [sales#238L]
Arguments: isFinalPlan=false


