In [1]:
from pyspark.sql import SparkSession

In [2]:
spark = SparkSession.builder.appName("pyspark optimization").getOrCreate()

In [3]:
from google.colab import files

uploaded = files.upload()

Saving 2015-summary.csv to 2015-summary.csv


In [4]:
df = spark.read.csv(  "2015-summary.csv", header=True, inferSchema=True)

In [5]:
df.show()

+--------------------+-------------------+-----+
|   DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|
+--------------------+-------------------+-----+
|       United States|            Romania|   15|
|       United States|            Croatia|    1|
|       United States|            Ireland|  344|
|               Egypt|      United States|   15|
|       United States|              India|   62|
|       United States|          Singapore|    1|
|       United States|            Grenada|   62|
|          Costa Rica|      United States|  588|
|             Senegal|      United States|   40|
|             Moldova|      United States|    1|
|       United States|       Sint Maarten|  325|
|       United States|   Marshall Islands|   39|
|              Guyana|      United States|   64|
|               Malta|      United States|    1|
|            Anguilla|      United States|   41|
|             Bolivia|      United States|   30|
|       United States|           Paraguay|    6|
|             Algeri

In [15]:
df.count()

256

In [17]:
df.rdd.getNumPartitions()

5

In [18]:
from google.colab import files

uploaded = files.upload()

Saving 2015-summary.csv to 2015-summary (1).csv


In [19]:
df = spark.read.csv(  "2015-summary.csv", header=True, inferSchema=True)

In [20]:
df.show()

+--------------------+-------------------+-----+
|   DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|
+--------------------+-------------------+-----+
|       United States|            Romania|   15|
|       United States|            Croatia|    1|
|       United States|            Ireland|  344|
|               Egypt|      United States|   15|
|       United States|              India|   62|
|       United States|          Singapore|    1|
|       United States|            Grenada|   62|
|          Costa Rica|      United States|  588|
|             Senegal|      United States|   40|
|             Moldova|      United States|    1|
|       United States|       Sint Maarten|  325|
|       United States|   Marshall Islands|   39|
|              Guyana|      United States|   64|
|               Malta|      United States|    1|
|            Anguilla|      United States|   41|
|             Bolivia|      United States|   30|
|       United States|           Paraguay|    6|
|             Algeri

In [21]:
df.count()

256

In [22]:
df.rdd.getNumPartitions()

1

Repartitions and Coalesce

In [31]:
repartition_df= df.repartition(4)
repartition_df.rdd.getNumPartitions()

4

In [33]:
from pyspark.sql.functions import spark_partition_id
repartition_df.withColumn("partitionId", spark_partition_id()).groupBy("partitionId").count().show()

+-----------+-----+
|partitionId|count|
+-----------+-----+
|          0|   64|
|          1|   64|
|          2|   64|
|          3|   64|
+-----------+-----+



In [35]:
partitioned_on_column = df.repartition(25, "DEST_COUNTRY_NAME")
partitioned_on_column.rdd.getNumPartitions()

25

In [36]:
partitioned_on_column.withColumn("partitionId",spark_partition_id()).groupBy("partitionId").count().show(25)

+-----------+-----+
|partitionId|count|
+-----------+-----+
|          0|    5|
|          1|    6|
|          2|    7|
|          3|    5|
|          4|    6|
|          5|    7|
|          6|    8|
|          7|    5|
|          8|    3|
|          9|  132|
|         10|    4|
|         11|    3|
|         12|    6|
|         13|    8|
|         14|    2|
|         15|   11|
|         16|    4|
|         17|    4|
|         18|    3|
|         19|    4|
|         20|    8|
|         21|    4|
|         22|    4|
|         23|    5|
|         24|    2|
+-----------+-----+



In [40]:
coalesce_df = repartition_df.repartition(8)

In [41]:
coalesce_df.withColumn("partitonId",spark_partition_id()).groupBy("partitonId").count().show()

+----------+-----+
|partitonId|count|
+----------+-----+
|         0|   32|
|         1|   32|
|         2|   32|
|         3|   32|
|         4|   32|
|         5|   32|
|         6|   32|
|         7|   32|
+----------+-----+



In [42]:
three_coalese_df = coalesce_df.coalesce(3)

In [44]:
three_coalese_df.withColumn("partitonId",spark_partition_id()).groupBy("partitonId").count().show()

+----------+-----+
|partitonId|count|
+----------+-----+
|         0|   64|
|         1|   96|
|         2|   96|
+----------+-----+



#cache & persist

In [45]:
from pyspark.sql.functions import col
origin_df = df.filter(col("ORIGIN_COUNTRY_NAME")=="origincountry names")

In [47]:
origin_df.cache()
origin_df.count()    #for initiang or calling cache

origin_df.groupBy("DEST_COUNTRY_NAME").sum("count").show()

+-----------------+----------+
|DEST_COUNTRY_NAME|sum(count)|
+-----------------+----------+
+-----------------+----------+



In [48]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.appName("Cache Example").getOrCreate()

# Original dataset
df = spark.createDataFrame([
    ("Alice", "Electronics", 500),
    ("Bob", "Electronics", 300),
    ("Alice", "Clothing", 200),
    ("Bob", "Clothing", 100),
    ("Charlie", "Electronics", 150)
], ["customer", "category", "amount"])

# Filter for Electronics (used multiple times)
electronics_df = df.filter(col("category") == "Electronics")

In [50]:
df.show()

+--------+-----------+------+
|customer|   category|amount|
+--------+-----------+------+
|   Alice|Electronics|   500|
|     Bob|Electronics|   300|
|   Alice|   Clothing|   200|
|     Bob|   Clothing|   100|
| Charlie|Electronics|   150|
+--------+-----------+------+



In [51]:
electronics_df.cache()
electronics_df.count()  # triggers cache

electronics_df.groupBy("customer").sum("amount").show()
electronics_df.agg({"amount": "avg"}).show()

+--------+-----------+
|customer|sum(amount)|
+--------+-----------+
|     Bob|        300|
|   Alice|        500|
| Charlie|        150|
+--------+-----------+

+-----------------+
|      avg(amount)|
+-----------------+
|316.6666666666667|
+-----------------+



In [54]:
import time
from pyspark import StorageLevel
filtered_df = df.filter(df["amount"] > 150).persist(StorageLevel.MEMORY_AND_DISK)

# First action triggers caching
start = time.time()
filtered_df.count()
print("With persist - first action (triggers cache):", round(time.time() - start, 2), "sec")

# Second action reuses the cache
start = time.time()
filtered_df.groupBy("customer").sum("amount").show()
print("With persist - second action (faster):", round(time.time() - start, 2), "sec")

With persist - first action (triggers cache): 0.93 sec
+--------+-----------+
|customer|sum(amount)|
+--------+-----------+
|     Bob|        300|
|   Alice|        700|
+--------+-----------+

With persist - second action (faster): 0.43 sec


In [57]:
filtered_df.unpersist()

DataFrame[customer: string, category: string, amount: bigint]

#broadcasting join Vs Shuffle join

In [59]:
#big dataframe
df_transactions = spark.createDataFrame([
    (1,"us", 1000),
    (1,"us", 1000),
    (1,"us", 1000),
    (1,"us", 1000),
],  ["id","country_code","amount"])

#small dataset

df_countries = spark.createDataFrame([
    ("us","united states"),
    ("In", "india"),
    ("uk","united kigdom"),
], ["country-code", "country_name"])

In [60]:
df_transactions.show()

+---+------------+------+
| id|country_code|amount|
+---+------------+------+
|  1|          us|  1000|
|  1|          us|  1000|
|  1|          us|  1000|
|  1|          us|  1000|
+---+------------+------+



In [61]:
df_countries.show()

+------------+-------------+
|country-code| country_name|
+------------+-------------+
|          us|united states|
|          In|        india|
|          uk|united kigdom|
+------------+-------------+



In [62]:
df_join= df_transactions.join(df_countries, df_transactions["country_code"]==df_countries["country-code"],"inner")

In [63]:
df_join.show()

+---+------------+------+------------+-------------+
| id|country_code|amount|country-code| country_name|
+---+------------+------+------------+-------------+
|  1|          us|  1000|          us|united states|
|  1|          us|  1000|          us|united states|
|  1|          us|  1000|          us|united states|
|  1|          us|  1000|          us|united states|
+---+------------+------+------------+-------------+



In [None]:
#broadcast joins dont shuffle instead send the small tables to each of the executor node
from pyspark.sql.functions import broadcast

df_join = df_transactions.join(
    broadcast(df_countries),
    df_transactions["country_code"] == df_countries["country-code"],
    "inner")

#Task

In [67]:
from pyspark.sql.functions import broadcast
import time
spark = SparkSession.builder.appName("SimpleOptimization").getOrCreate()

# Making Small lookup table
products = [(1, "computer"), (2, "Mouse"), (3, "Keyboard")]
df_products = spark.createDataFrame(products, ["product_id", "product_name"])

# Now, Small order table
orders = [(101, 1), (102, 2), (103, 3)]
df_orders = spark.createDataFrame(orders, ["order_id", "product_id"])

# by repeating 10 times(large orders)
df_large_orders = df_orders
for _ in range(10):
    df_large_orders = df_large_orders.union(df_orders)

# No optimization
start = time.time()
df_join1 = df_large_orders.join(df_products, "product_id")
df_join1.show()
print("Time without optimization:", round(time.time() - start, 2), "sec")

# Optimization
df_large_orders = df_large_orders.repartition(4)
df_large_orders.cache()

start = time.time()
df_join2 = df_large_orders.join(broadcast(df_products), "product_id")
df_join2.show()
print("Time with optimization:", round(time.time() - start, 2), "sec")

+----------+--------+------------+
|product_id|order_id|product_name|
+----------+--------+------------+
|         1|     101|    computer|
|         1|     101|    computer|
|         1|     101|    computer|
|         1|     101|    computer|
|         1|     101|    computer|
|         1|     101|    computer|
|         1|     101|    computer|
|         1|     101|    computer|
|         1|     101|    computer|
|         1|     101|    computer|
|         1|     101|    computer|
|         2|     102|       Mouse|
|         2|     102|       Mouse|
|         2|     102|       Mouse|
|         2|     102|       Mouse|
|         2|     102|       Mouse|
|         2|     102|       Mouse|
|         2|     102|       Mouse|
|         2|     102|       Mouse|
|         2|     102|       Mouse|
+----------+--------+------------+
only showing top 20 rows

Time without optimization: 5.12 sec
+----------+--------+------------+
|product_id|order_id|product_name|
+----------+--------+-------