In [1]:
# Build (or reuse) the Spark session used by every cell in this notebook
from pyspark.sql import SparkSession
from os.path import abspath

# Managed databases and tables are stored under this directory
warehouse_location = abspath('spark-warehouse')

session_builder = (
    SparkSession.builder
    .appName("Factor of cores")
    .master("local[*]")
    .config("spark.sql.warehouse.dir", warehouse_location)
    .config("spark.executor.instances", "4")
    .config("spark.executor.cores", "2")
    .config("spark.executor.memory", "1G")
    .config("spark.driver.memory", "4G")
)
spark = session_builder.getOrCreate()

# Leave the session as the last expression so the notebook renders its summary
spark

In [3]:
# Peek at the first few raw lines of one part file to see its on-disk layout
from itertools import islice

# Specify the file path
file_path = "./Input/Sales_order/part-00000-2482645b-99e0-4e87-ba6c-4fc20fed5d14-c000.csv"

# Maximum number of lines to read (including the header if the file has one)
n = 5

# islice stops at end-of-file, so a file shorter than n lines yields only the
# lines that actually exist instead of being padded with empty strings
with open(file_path, 'r') as file:
    lines = [raw_line.strip() for raw_line in islice(file, n)]

# Print the sampled lines
for line in lines:
    print(line)

ORD00001,2024-06-15,C1,128,422,54016,South,Uruguay
ORD00002,2021-06-16,C2,198,359,71082,East,China
ORD00003,2023-10-08,C5,93,305,28365,North,Bermuda
ORD00004,2021-10-07,C4,19,564,10716,North,Greenland
ORD00005,2022-10-11,C3,110,384,42240,South,Sri Lanka


In [2]:
# DDL-style schema string for the sales order CSV files: one "name type" pair per
# column, in the same order as the fields appear in each CSV row.
order_schema = """
                Order_Id string,
                Order_Date date,
                Customer_ID string,
                Qty integer,
                Price integer,
                Amount integer,
                Sales_Region string,
                Country string
                """

In [4]:
# Load every part file under ./Input/Sales_order/ with the explicit order schema.
# The part files carry no header row (the raw sample printed earlier starts directly
# with the ORD00001 record), so header must be False: with header=True the first
# record of every part file would be consumed as a header and silently dropped.
sales_df = (
    spark.read.format("csv")
    .option("header", False)
    .schema(order_schema)
    .load("./Input/Sales_order/")
)
print(f"Number of Partition -> {sales_df.rdd.getNumPartitions()}")

Number of Partition -> 22


In [5]:
# Confirm that the column names and types from order_schema were applied on load
sales_df.printSchema()

root
 |-- Order_Id: string (nullable = true)
 |-- Order_Date: date (nullable = true)
 |-- Customer_ID: string (nullable = true)
 |-- Qty: integer (nullable = true)
 |-- Price: integer (nullable = true)
 |-- Amount: integer (nullable = true)
 |-- Sales_Region: string (nullable = true)
 |-- Country: string (nullable = true)



In [6]:
# Small in-memory lookup table: one [Customer_ID, Customer_Name] pair per customer.
# The ids C1..C5 are the same ids that appear in the Customer_ID field of the
# sales order sample rows.
customer_data = [["C1","Pratap"],
                ["C2","Sruthi"],
                 ["C3","Nirupama"],
                 ["C4","Kiyanshitha"],
                 ["C5","Chand"] ]

# DDL-style schema string matching the two fields of each customer_data row
customer_schema = """ Customer_ID string, 
                    Customer_Name string """

In [7]:
# Turn the in-memory customer rows into a DataFrame using the explicit schema
customer_df = spark.createDataFrame(data=customer_data, schema=customer_schema)

In [8]:
# Lets create a simple Python decorator - {get_time} to get the execution timings
# If you dont know about Python decorators - check out : https://www.geeksforgeeks.org/decorators-in-python/
import time

def get_time(func):
    """Run ``func`` once, right away, and print how long the call took.

    Unlike a conventional decorator this does not wrap ``func`` for later use:
    the timing happens at decoration time, so executing a cell that contains
    ``@get_time`` above a ``def`` is what triggers the benchmark.

    Args:
        func: A callable that takes no arguments. Its return value is ignored.

    Returns:
        ``func`` itself, unchanged, so the decorated name still refers to a
        callable afterwards instead of being rebound to ``None``.

    Raises:
        Whatever ``func`` raises; in that case nothing is printed.
    """
    def inner_get_time() -> str:
        # perf_counter is monotonic and high resolution; time.time() follows the
        # wall clock and can jump if the system clock is adjusted mid-benchmark.
        start_time = time.perf_counter()
        func()
        end_time = time.perf_counter()
        print("-"*80)
        return (f"Execution time: {(end_time - start_time)*1000} ms")
    print(inner_get_time())
    print("-"*80)
    return func

In [9]:
# Number of partitions backing the in-memory customer DataFrame
customer_df.rdd.getNumPartitions()

8

In [13]:
# A threshold of -1 turns automatic broadcast joins off, so a join without an
# explicit broadcast hint is not planned as a broadcast join behind the scenes.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)

In [12]:
# Inspect the current auto-broadcast threshold.
# NOTE(review): this cell has execution count 12 while the set(...) cell placed
# before it has count 13, so the recorded output is the value from before the
# override; a top-to-bottom run would presumably print the overridden value.
spark.conf.get("spark.sql.autoBroadcastJoinThreshold")

'10485760b'

In [14]:
# Benchmark: left outer join of sales with customers, no broadcast hint
from pyspark.sql.functions import count, lit
@get_time
def x():
    """Join sales to customers on Customer_ID and push the result through the noop sink."""
    same_customer = sales_df["Customer_ID"] == customer_df["Customer_ID"]
    joined = sales_df.join(customer_df, same_customer, "left_outer")
    joined.write.format("noop").mode("overwrite").save()

--------------------------------------------------------------------------------
Execution time: 230696.02060317993 ms
--------------------------------------------------------------------------------


In [15]:
# Benchmark: same left outer join, with the customer side explicitly broadcast
from pyspark.sql.functions import count, lit , broadcast
@get_time
def x():
    """Join sales to a broadcast copy of customers and push the result through the noop sink."""
    same_customer = sales_df["Customer_ID"] == customer_df["Customer_ID"]
    joined = sales_df.join(broadcast(customer_df), same_customer, "left_outer")
    joined.write.format("noop").mode("overwrite").save()

--------------------------------------------------------------------------------
Execution time: 113922.34110832214 ms
--------------------------------------------------------------------------------


In [None]:
# Code for benchmarking: same left outer join, expressed in SQL with a BROADCAST hint
from pyspark.sql.functions import count, lit , broadcast

# The query below refers to `sales` and `customer` by name. No visible cell registers
# those names, so expose the two DataFrames as temp views first; doing it outside the
# timed function keeps the registration out of the measurement.
sales_df.createOrReplaceTempView("sales")
customer_df.createOrReplaceTempView("customer")

@get_time
def x(): 
    """Run the hinted SQL join and push the result through the noop sink."""
    df_joined = spark.sql("""SELECT /*+ BROADCAST(customer) */
                            sales.*,
                            customer.*
                            FROM sales
                            LEFT OUTER JOIN customer
                            ON sales.Customer_ID = customer.Customer_ID
                                """)
    df_joined.write.format("noop").mode("overwrite").save()

## Big Table vs Big Table

In [18]:
# Read Sales data with an explicit schema
sales_schema = "transacted_at string, trx_id string, retailer_id string, description string, amount double, city_id string"

sales_reader = spark.read.format("csv").option("header", True).schema(sales_schema)
sales = sales_reader.load("./Input/new_sales.csv")

In [19]:
# Read City data with an explicit schema

city_schema = "city_id string, city string, state string, state_abv string, country string"

city_reader = spark.read.format("csv").option("header", True).schema(city_schema)
city = city_reader.load("./Input/cities.csv")

In [20]:
# Join Data: keep every sales row and attach the matching city row, if any

same_city = sales["city_id"] == city["city_id"]
df_sales_joined = sales.join(city, same_city, "left_outer")

In [21]:
# Benchmark: sales/city join read straight from the CSV inputs
from pyspark.sql.functions import count, lit
@get_time
def x():
    """Push the sales/city join through the noop sink."""
    noop_writer = df_sales_joined.write.format("noop").mode("overwrite")
    noop_writer.save()

--------------------------------------------------------------------------------
Execution time: 30005.728721618652 ms
--------------------------------------------------------------------------------


In [22]:
# Persist Sales as a table split into 4 buckets on city_id
(
    sales.write
    .format("csv")
    .option("header", True)
    .mode("overwrite")
    .bucketBy(4, "city_id")
    .saveAsTable("sales_bucket")
)

In [23]:
# Persist City as a table split into 4 buckets on city_id
(
    city.write
    .format("csv")
    .option("header", True)
    .mode("overwrite")
    .bucketBy(4, "city_id")
    .saveAsTable("city_bucket")
)

In [24]:
# Check tables: list what is registered in the default database, to confirm that
# both sales_bucket and city_bucket were saved

spark.sql("show tables in default").show()

+---------+------------+-----------+
|namespace|   tableName|isTemporary|
+---------+------------+-----------+
|  default| city_bucket|      false|
|  default|sales_bucket|      false|
+---------+------------+-----------+



In [25]:
# Read Sales table back by its table name rather than by file path

sales_bucket = spark.read.table("sales_bucket")

In [26]:
# Read City table back by its table name rather than by file path

city_bucket = spark.read.table("city_bucket")

In [None]:
# Show a sample of the distinct city_id values, first for the bucketed sales
# table and then for the bucketed city table
for bucketed_table in (sales_bucket, city_bucket):
    bucketed_table.select("city_id").distinct().show()

In [27]:
from pyspark.sql import functions as F
# Both tables were written with 4 buckets on city_id. Keep only the rows whose
# city_id hash leaves remainder 0 when divided by 4.
# NOTE(review): presumably this mirrors how Spark assigns rows to bucket 0 - confirm.
is_bucket_zero = (F.hash("city_id") % 4) == 0

# City rows that satisfy the remainder-0 predicate
city_bucket_0 = city_bucket.where(is_bucket_zero)

# Sales rows that satisfy the remainder-0 predicate
sales_bucket_0 = sales_bucket.where(is_bucket_zero)

# Look up one city_id on the city side, then list the sales side ordered by city_id
city_bucket_0.filter("city_id = '1030993386'").show()
sales_bucket_0.sort("city_id").show()


+----------+----------+-----+---------+-------+
|   city_id|      city|state|state_abv|country|
+----------+----------+-----+---------+-------+
|1030993386|Montevideo| NULL|     NULL|Uruguay|
+----------+----------+-----+---------+-------+

+--------------------+----------+-----------+--------------------+-------+----------+
|       transacted_at|    trx_id|retailer_id|         description| amount|   city_id|
+--------------------+----------+-----------+--------------------+-------+----------+
|2017-12-09T12:00:...|1614079909|  847200066|            Wal-Mart|   5.62|1030993386|
|2017-12-24T22:00:...| 745524396|  386167994|Wendy's          ...|  324.3|1030993386|
|2017-12-14T23:00:...|1080477925| 1295306792|Dick's Sporting G...|  84.38|1030993386|
|2017-11-25T23:00:...|1729862678|  847200066|Wal-Mart    arc i...|   47.6|1030993386|
|2017-12-09T12:00:...|1614145382|  902350112|DineEquity   ccd ...|2796.67|1030993386|
|2017-01-15T19:00:...|1344235985| 1334799521|TJ Max    ppd id:...|2590.

In [28]:
# Join the two bucketed tables on city_id, keeping every sales row

same_city_bucketed = sales_bucket["city_id"] == city_bucket["city_id"]
df_joined_bucket = sales_bucket.join(city_bucket, same_city_bucketed, "left_outer")

In [29]:
# Write dataset
# Benchmark: the same sales/city join, this time read from the bucketed tables
from pyspark.sql.functions import count, lit
@get_time
def x():
    """Push the bucketed sales/city join through the noop sink."""
    noop_writer = df_joined_bucket.write.format("noop").mode("overwrite")
    noop_writer.save()

--------------------------------------------------------------------------------
Execution time: 19095.71671485901 ms
--------------------------------------------------------------------------------


In [31]:
# Shut down the Spark session.
# NOTE(review): this cell has execution count 31, i.e. it ran after the explain()
# cell that follows it in the notebook (count 30). On a top-to-bottom run the
# later cell would execute against a stopped session - consider moving this
# cell to the very end.
spark.stop()

In [30]:
# Print the physical plan chosen for the bucketed join.
# NOTE(review): this cell sits after spark.stop() in notebook order but was
# executed before it (count 30 vs 31); it likely needs a live session - confirm
# and reorder if so.
df_joined_bucket.explain()

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- SortMergeJoin [city_id#289], [city_id#296], LeftOuter
   :- Sort [city_id#289 ASC NULLS FIRST], false, 0
   :  +- FileScan csv spark_catalog.default.sales_bucket[transacted_at#284,trx_id#285,retailer_id#286,description#287,amount#288,city_id#289] Batched: false, Bucketed: true, DataFilters: [], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/home/jupyter/PySpark Tutorials - Qbex/07 - Performance Tuning/sp..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<transacted_at:string,trx_id:string,retailer_id:string,description:string,amount:double,cit..., SelectedBucketsCount: 4 out of 4
   +- Sort [city_id#296 ASC NULLS FIRST], false, 0
      +- Filter isnotnull(city_id#296)
         +- FileScan csv spark_catalog.default.city_bucket[city_id#296,city#297,state#298,state_abv#299,country#300] Batched: false, Bucketed: true, DataFilters: [isnotnull(city_id#296)], Format: CSV, Location: InMemoryFileIndex(1 paths)[file