In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit, when,regexp_replace, split, trim, array_compact, transform, get_json_object, lower
spark = SparkSession.builder.appName("ride-hailing platform").getOrCreate()
from pyspark.sql.types import (StructType, StructField, StringType,LongType,IntegerType,ArrayType,MapType)

In [2]:
raw_drivers = [
("D001","Ramesh","35","Hyderabad","Car,Bike"),
("D002","Suresh","Forty","Bangalore","Auto"),
("D003","Anita",None,"Mumbai",["Car"]),
("D004","Kiran","29","Delhi","Car|Bike"),
("D005","", "42","Chennai",None)
]

In [3]:
driver_schema = StructType([
    StructField("driverid", StringType(), nullable=False),
    StructField("name", StringType(), nullable=True),
    StructField("age", StringType(), nullable=True),
    StructField("city", StringType(), nullable=True),
    StructField("vechile", StringType(), nullable=True)
])
df = spark.createDataFrame(raw_drivers,driver_schema)
df.show()

+--------+------+-----+---------+--------+
|driverid|  name|  age|     city| vechile|
+--------+------+-----+---------+--------+
|    D001|Ramesh|   35|Hyderabad|Car,Bike|
|    D002|Suresh|Forty|Bangalore|    Auto|
|    D003| Anita| NULL|   Mumbai|   [Car]|
|    D004| Kiran|   29|    Delhi|Car|Bike|
|    D005|      |   42|  Chennai|    NULL|
+--------+------+-----+---------+--------+



In [4]:
clean_age = df.withColumn("age", when(col("age") == "", None)
    .when(col("age").rlike(r"^\d+$"),
          col("age").cast(IntegerType()))
    .otherwise(None))

clean_name_city_vechile = clean_age.withColumn("name", when(col("name") == "", None)
    .otherwise(col("name"))) \
.withColumn("city",trim(col("city")))\
.withColumn(
    "vechile",
    (when(
        col("vechile").isNull(),
        None
    ).otherwise(
        array_compact(
            transform(
                split(
                    regexp_replace(
                        col("vechile"),
                        r"\[|\]|'|\|", ","),
                    ","),
                lambda x: when(trim(x) != lit(""), trim(x)).otherwise(lit(None))
            )
        )
    )).cast(ArrayType(StringType()))
)

clean_name_city_vechile.show()

driver_df = clean_name_city_vechile

+--------+------+----+---------+-----------+
|driverid|  name| age|     city|    vechile|
+--------+------+----+---------+-----------+
|    D001|Ramesh|  35|Hyderabad|[Car, Bike]|
|    D002|Suresh|NULL|Bangalore|     [Auto]|
|    D003| Anita|NULL|   Mumbai|      [Car]|
|    D004| Kiran|  29|    Delhi|[Car, Bike]|
|    D005|  NULL|  42|  Chennai|       NULL|
+--------+------+----+---------+-----------+



In [5]:
raw_cities = [
("Hyderabad","South"),
("Bangalore","South"),
("Mumbai","West"),
("Delhi","North"),
("Chennai","South")
]

In [6]:
city_schema = StructType([
    StructField("city", StringType(), nullable=True),
    StructField("region", StringType(), nullable=True)
])
city_df = spark.createDataFrame(raw_cities,city_schema)
city_df.show()

+---------+------+
|     city|region|
+---------+------+
|Hyderabad| South|
|Bangalore| South|
|   Mumbai|  West|
|    Delhi| North|
|  Chennai| South|
+---------+------+



Small reference dataset city_df

Intended for broadcast join

In [7]:
from pyspark.sql.functions import broadcast

In [8]:
driver_join  = driver_df.join(broadcast(city_df), "city", "inner")
driver_join.show()

+---------+--------+------+----+-----------+------+
|     city|driverid|  name| age|    vechile|region|
+---------+--------+------+----+-----------+------+
|Hyderabad|    D001|Ramesh|  35|[Car, Bike]| South|
|Bangalore|    D002|Suresh|NULL|     [Auto]| South|
|   Mumbai|    D003| Anita|NULL|      [Car]|  West|
|    Delhi|    D004| Kiran|  29|[Car, Bike]| North|
|  Chennai|    D005|  NULL|  42|       NULL| South|
+---------+--------+------+----+-----------+------+



In [9]:
raw_trips = [
("T001","D001","Hyderabad","2024-01-05","Completed","450"),
("T002","D002","Bangalore","05/01/2024","Cancelled","0"),
("T003","D003","Mumbai","2024/01/06","Completed","620"),
("T004","D004","Delhi","invalid_date","Completed","540"),
("T005","D001","Hyderabad","2024-01-10","Completed","700"),
("T006","D005","Chennai","2024-01-12","Completed","350")
]

PART A — DATA CLEANING & STRUCTURING

In [10]:
trips_schema = StructType([
    StructField("userid", StringType(), nullable=False),
    StructField("driverid", StringType(), nullable=False),
    StructField("city", StringType(), nullable=True),
    StructField("date", StringType(), nullable=True),
    StructField("status", StringType(), nullable=True),
    StructField("amount", StringType(), nullable=True),
])
trips_df = spark.createDataFrame(raw_trips,trips_schema)
trips_df.show()

+------+--------+---------+------------+---------+------+
|userid|driverid|     city|        date|   status|amount|
+------+--------+---------+------------+---------+------+
|  T001|    D001|Hyderabad|  2024-01-05|Completed|   450|
|  T002|    D002|Bangalore|  05/01/2024|Cancelled|     0|
|  T003|    D003|   Mumbai|  2024/01/06|Completed|   620|
|  T004|    D004|    Delhi|invalid_date|Completed|   540|
|  T005|    D001|Hyderabad|  2024-01-10|Completed|   700|
|  T006|    D005|  Chennai|  2024-01-12|Completed|   350|
+------+--------+---------+------------+---------+------+



In [11]:
from pyspark.sql.functions import col, to_date, coalesce, split, lit, array_remove, try_to_timestamp

In [12]:
clean_date_amount = trips_df.withColumn("amount", col("amount").cast(IntegerType()))\
.withColumn(
    "date",
    coalesce(
        to_date(try_to_timestamp(col("date"), lit("yyyy-MM-dd"))),
        to_date(try_to_timestamp(col("date"), lit("dd/MM/yyyy"))),
        to_date(try_to_timestamp(col("date"), lit("yyyy/MM/dd")))
    )
)

clean_date_amount = clean_date_amount.filter(col("amount") > 0)

clean_date_amount.show()
tripsdf=clean_date_amount

+------+--------+---------+----------+---------+------+
|userid|driverid|     city|      date|   status|amount|
+------+--------+---------+----------+---------+------+
|  T001|    D001|Hyderabad|2024-01-05|Completed|   450|
|  T003|    D003|   Mumbai|2024-01-06|Completed|   620|
|  T004|    D004|    Delhi|      NULL|Completed|   540|
|  T005|    D001|Hyderabad|2024-01-10|Completed|   700|
|  T006|    D005|  Chennai|2024-01-12|Completed|   350|
+------+--------+---------+----------+---------+------+



In [13]:
raw_activity = [
("D001","login,accept_trip,logout","{'device':'mobile'}",180),
("D002",["login","logout"],"device=laptop",60),
("D003","login|accept_trip",None,120),
("D004",None,"{'device':'tablet'}",90),
("D005","login","{'device':'mobile'}",30)
]

In [14]:
activity_schema = StructType([
    StructField("userid", StringType(), nullable=False),
    StructField("actions", StringType(), nullable=True),
    StructField("device", StringType(), nullable=True),
    StructField("amount", IntegerType(), nullable=True),
])
activity_df = spark.createDataFrame(raw_activity,activity_schema)
activity_df.show()

+------+--------------------+-------------------+------+
|userid|             actions|             device|amount|
+------+--------------------+-------------------+------+
|  D001|login,accept_trip...|{'device':'mobile'}|   180|
|  D002|     [login, logout]|      device=laptop|    60|
|  D003|   login|accept_trip|               NULL|   120|
|  D004|                NULL|{'device':'tablet'}|    90|
|  D005|               login|{'device':'mobile'}|    30|
+------+--------------------+-------------------+------+



In [15]:
df_activity_clean = activity_df.withColumn(
    "actions",
    (when(
        col("actions").isNull(),
        None
    ).otherwise(
        array_compact(
            transform(
                split(
                    regexp_replace(
                        col("actions"),
                        r"\[|\]|'|\|", ","),
                    ","),
                lambda x: when(trim(x) != lit(""), trim(x)).otherwise(lit(None))
            )
        )
    )).cast(ArrayType(StringType()))
).withColumn(
    "device",
    when(col("device").isNull(), None)
    .when(col("device").like("{'device':%}"), get_json_object(col("device"), "$.device"))
    .when(col("device").like("device=%"), split(col("device"), "=").getItem(1))
    .otherwise(None)
)

df_activity_clean.show(truncate=False)
df_activity_clean.printSchema()

+------+----------------------------+------+------+
|userid|actions                     |device|amount|
+------+----------------------------+------+------+
|D001  |[login, accept_trip, logout]|mobile|180   |
|D002  |[login, logout]             |laptop|60    |
|D003  |[login, accept_trip]        |NULL  |120   |
|D004  |NULL                        |tablet|90    |
|D005  |[login]                     |mobile|30    |
+------+----------------------------+------+------+

root
 |-- userid: string (nullable = false)
 |-- actions: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- device: string (nullable = true)
 |-- amount: integer (nullable = true)



In [16]:
df_activity_clean.show()
tripsdf.show()
driver_join.show()
driver_df.show()
city_df.show()

+------+--------------------+------+------+
|userid|             actions|device|amount|
+------+--------------------+------+------+
|  D001|[login, accept_tr...|mobile|   180|
|  D002|     [login, logout]|laptop|    60|
|  D003|[login, accept_trip]|  NULL|   120|
|  D004|                NULL|tablet|    90|
|  D005|             [login]|mobile|    30|
+------+--------------------+------+------+

+------+--------+---------+----------+---------+------+
|userid|driverid|     city|      date|   status|amount|
+------+--------+---------+----------+---------+------+
|  T001|    D001|Hyderabad|2024-01-05|Completed|   450|
|  T003|    D003|   Mumbai|2024-01-06|Completed|   620|
|  T004|    D004|    Delhi|      NULL|Completed|   540|
|  T005|    D001|Hyderabad|2024-01-10|Completed|   700|
|  T006|    D005|  Chennai|2024-01-12|Completed|   350|
+------+--------+---------+----------+---------+------+

+---------+--------+------+----+-----------+------+
|     city|driverid|  name| age|    vechile|re

PART B — DATA INTEGRATION (JOINS)


In [17]:
city_df.show()

trips_city_join  = tripsdf.join(broadcast(city_df), "city", "inner")
trips_city_join.show()

trips_city_join.explain(True)

ophan = trips_city_join.filter(~trips_city_join["date"].isNull())
ophan.show()

+---------+------+
|     city|region|
+---------+------+
|Hyderabad| South|
|Bangalore| South|
|   Mumbai|  West|
|    Delhi| North|
|  Chennai| South|
+---------+------+

+---------+------+--------+----------+---------+------+------+
|     city|userid|driverid|      date|   status|amount|region|
+---------+------+--------+----------+---------+------+------+
|Hyderabad|  T001|    D001|2024-01-05|Completed|   450| South|
|   Mumbai|  T003|    D003|2024-01-06|Completed|   620|  West|
|    Delhi|  T004|    D004|      NULL|Completed|   540| North|
|Hyderabad|  T005|    D001|2024-01-10|Completed|   700| South|
|  Chennai|  T006|    D005|2024-01-12|Completed|   350| South|
+---------+------+--------+----------+---------+------+------+

== Parsed Logical Plan ==
'Join UsingJoin(Inner, [city])
:- Filter (amount#97 > 0)
:  +- Project [userid#72, driverid#73, city#74, coalesce(to_date(try_to_timestamp(date#75, Some(yyyy-MM-dd), TimestampType, Some(Etc/UTC), false), None, Some(Etc/UTC), true), to

PART C — ANALYTICS & AGGREGATIONS

In [18]:
from pyspark.sql.functions import count

total_trips_per_city = (
    trips_df
    .groupBy("city")
    .agg(count("*").alias("Total revenue per city"))
)

total_trips_per_city.show()


+---------+----------------------+
|     city|Total revenue per city|
+---------+----------------------+
|Bangalore|                     1|
|   Mumbai|                     1|
|Hyderabad|                     2|
|  Chennai|                     1|
|    Delhi|                     1|
+---------+----------------------+



In [19]:
from pyspark.sql.functions import sum

total_rev_per_city = (
    trips_df
    .groupBy("city")
    .agg(sum("amount").alias("total_trips"))
)

total_rev_per_city.show()


+---------+-----------+
|     city|total_trips|
+---------+-----------+
|Bangalore|        0.0|
|   Mumbai|      620.0|
|Hyderabad|     1150.0|
|  Chennai|      350.0|
|    Delhi|      540.0|
+---------+-----------+



In [20]:
from pyspark.sql.functions import count

total_trips_completed = (
    trips_df
    .groupBy("driverid")
    .agg(count("*").alias("total_trips_completed"))
)

total_trips_completed.show()


+--------+---------------------+
|driverid|total_trips_completed|
+--------+---------------------+
|    D002|                    1|
|    D003|                    1|
|    D001|                    2|
|    D004|                    1|
|    D005|                    1|
+--------+---------------------+



In [21]:
trips_df.filter(trips_df["status"] == "Not Completed").show()

+------+--------+----+----+------+------+
|userid|driverid|city|date|status|amount|
+------+--------+----+----+------+------+
+------+--------+----+----+------+------+



PART D — WINDOW FUNCTIONS

In [22]:
trips_df.show()

+------+--------+---------+------------+---------+------+
|userid|driverid|     city|        date|   status|amount|
+------+--------+---------+------------+---------+------+
|  T001|    D001|Hyderabad|  2024-01-05|Completed|   450|
|  T002|    D002|Bangalore|  05/01/2024|Cancelled|     0|
|  T003|    D003|   Mumbai|  2024/01/06|Completed|   620|
|  T004|    D004|    Delhi|invalid_date|Completed|   540|
|  T005|    D001|Hyderabad|  2024-01-10|Completed|   700|
|  T006|    D005|  Chennai|  2024-01-12|Completed|   350|
+------+--------+---------+------------+---------+------+



In [23]:
from pyspark.sql import functions as F

driver_revenue = trips_df.groupBy("driverid") \
    .agg(F.sum("amount").alias("total_revenue")) \
    .orderBy(F.desc("total_revenue"))

driver_revenue.show()

+--------+-------------+
|driverid|total_revenue|
+--------+-------------+
|    D001|       1150.0|
|    D003|        620.0|
|    D004|        540.0|
|    D005|        350.0|
|    D002|          0.0|
+--------+-------------+



In [24]:
from pyspark.sql.window import Window

In [25]:
city_driver_rank = trips_df.groupBy("city", "driverid") \
    .agg(F.sum("amount").alias("city_revenue")) \
    .withColumn("rank", F.rank().over(Window.partitionBy("city").orderBy(F.desc("city_revenue"))))
city_driver_rank.show()

+---------+--------+------------+----+
|     city|driverid|city_revenue|rank|
+---------+--------+------------+----+
|Bangalore|    D002|         0.0|   1|
|  Chennai|    D005|       350.0|   1|
|    Delhi|    D004|       540.0|   1|
|Hyderabad|    D001|      1150.0|   1|
|   Mumbai|    D003|       620.0|   1|
+---------+--------+------------+----+



In [26]:
from pyspark.sql.window import Window

city_date_window = Window.partitionBy("city").orderBy("date") \
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)

running_revenue = trips_df.groupBy("city", "date") \
    .agg(F.sum("amount").alias("daily_revenue")) \
    .withColumn("running_revenue", F.sum("daily_revenue").over(city_date_window))

. Compare GroupBy vs Window for one metric
GroupBy: Aggregates data into fewer rows (e.g., total revenue per driver).
Window: Keeps original granularity but adds computed columns (e.g., rank, cumulative sum).
Use Case:
GroupBy → summary reports.
Window → analytics like ranking, running totals without collapsing rows

PART E — UDF (ONLY IF REQUIRED)

In [27]:

from pyspark.sql.functions import udf, col
from pyspark.sql.types import StringType

# Define classification based on revenue
def classify_revenue(revenue):
    if revenue >= 1000:
        return "High"
    elif revenue >= 500:
        return "Medium"
    else:
        return "Low"

# Register UDF
classify_revenue_udf = udf(classify_revenue, StringType())

# Apply UDF on city_revenue column
city_driver_rank.withColumn("revenue_grade", classify_revenue_udf(col("city_revenue"))).show()


+---------+--------+------------+----+-------------+
|     city|driverid|city_revenue|rank|revenue_grade|
+---------+--------+------------+----+-------------+
|Bangalore|    D002|         0.0|   1|          Low|
|  Chennai|    D005|       350.0|   1|          Low|
|    Delhi|    D004|       540.0|   1|       Medium|
|Hyderabad|    D001|      1150.0|   1|         High|
|   Mumbai|    D003|       620.0|   1|       Medium|
+---------+--------+------------+----+-------------+



PART F — SORTING & ORDERING

In [28]:
from pyspark.sql.functions import desc

sorted_cities_by_revenue = (
     trips_df
    .groupBy("city")
    .agg(sum("amount").alias("total_revenue"))
    .orderBy(desc("total_revenue"))
)

sorted_cities_by_revenue.show()


+---------+-------------+
|     city|total_revenue|
+---------+-------------+
|Hyderabad|       1150.0|
|   Mumbai|        620.0|
|    Delhi|        540.0|
|  Chennai|        350.0|
|Bangalore|          0.0|
+---------+-------------+



In [29]:
from pyspark.sql.functions import sum

driver_city_revenue = (
    trips_df
    .groupBy("city", "driverid")
    .agg(sum("amount").alias("driver_revenue"))
)

driver_city_revenue.show()

from pyspark.sql.window import Window
from pyspark.sql.functions import desc, row_number

city_window = Window.partitionBy("city").orderBy(desc("driver_revenue"))

sorted_drivers_within_city = (
    driver_city_revenue
    .withColumn("rank", row_number().over(city_window))
    .orderBy("city", "rank")
)

sorted_drivers_within_city.show()



+---------+--------+--------------+
|     city|driverid|driver_revenue|
+---------+--------+--------------+
|Bangalore|    D002|           0.0|
|   Mumbai|    D003|         620.0|
|Hyderabad|    D001|        1150.0|
|  Chennai|    D005|         350.0|
|    Delhi|    D004|         540.0|
+---------+--------+--------------+

+---------+--------+--------------+----+
|     city|driverid|driver_revenue|rank|
+---------+--------+--------------+----+
|Bangalore|    D002|           0.0|   1|
|  Chennai|    D005|         350.0|   1|
|    Delhi|    D004|         540.0|   1|
|Hyderabad|    D001|        1150.0|   1|
|   Mumbai|    D003|         620.0|   1|
+---------+--------+--------------+----+



Sorting causes a shuffle because Spark must move data across partitions to establish a global or partition-level order.

PART G — SET OPERATIONS

In [30]:
completed_drivers_df = (
    trips_df
    .filter(trips_df.status == "Completed")
    .select("driverid")
    .distinct()
)

completed_drivers_df.show()


+--------+
|driverid|
+--------+
|    D003|
|    D001|
|    D004|
|    D005|
+--------+



In [31]:
df_activity_clean.show(truncate=False)

+------+----------------------------+------+------+
|userid|actions                     |device|amount|
+------+----------------------------+------+------+
|D001  |[login, accept_trip, logout]|mobile|180   |
|D002  |[login, logout]             |laptop|60    |
|D003  |[login, accept_trip]        |NULL  |120   |
|D004  |NULL                        |tablet|90    |
|D005  |[login]                     |mobile|30    |
+------+----------------------------+------+------+



In [32]:
from pyspark.sql.functions import array_contains

active_drivers_df = (
    df_activity_clean
    .filter(array_contains("actions", "login")  & ~array_contains("actions", "logout"))
    .select("userid")
    .distinct()
)

active_drivers_df.show()


+------+
|userid|
+------+
|  D003|
|  D005|
+------+



Set operations work on entire rows and treat DataFrames as mathematical sets, while joins combine columns based on matching keys.

PART H — DAG & PERFORMANCE ANALYSIS

In [33]:
trip_city_join_df = trips_df.join(
    city_df,
    on="city",
    how="left"
)

trip_city_join_df.explain(True)

== Parsed Logical Plan ==
'Join UsingJoin(LeftOuter, [city])
:- LogicalRDD [userid#72, driverid#73, city#74, date#75, status#76, amount#77], false
+- LogicalRDD [city#44, region#45], false

== Analyzed Logical Plan ==
city: string, userid: string, driverid: string, date: string, status: string, amount: string, region: string
Project [city#74, userid#72, driverid#73, date#75, status#76, amount#77, region#45]
+- Join LeftOuter, (city#74 = city#44)
   :- LogicalRDD [userid#72, driverid#73, city#74, date#75, status#76, amount#77], false
   +- LogicalRDD [city#44, region#45], false

== Optimized Logical Plan ==
Project [city#74, userid#72, driverid#73, date#75, status#76, amount#77, region#45]
+- Join LeftOuter, (city#74 = city#44)
   :- LogicalRDD [userid#72, driverid#73, city#74, date#75, status#76, amount#77], false
   +- Filter isnotnull(city#44)
      +- LogicalRDD [city#44, region#45], false

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [city#74, userid#72, drive

In [34]:
sorted_drivers_within_city.explain(True)


== Parsed Logical Plan ==
'Sort ['city ASC NULLS FIRST, 'rank ASC NULLS FIRST], true
+- Project [city#74, driverid#73, driver_revenue#471, rank#492]
   +- Project [city#74, driverid#73, driver_revenue#471, rank#492, rank#492]
      +- Window [row_number() windowspecdefinition(city#74, driver_revenue#471 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rank#492], [city#74], [driver_revenue#471 DESC NULLS LAST]
         +- Project [city#74, driverid#73, driver_revenue#471]
            +- Aggregate [city#74, driverid#73], [city#74, driverid#73, sum(cast(amount#77 as double)) AS driver_revenue#471]
               +- LogicalRDD [userid#72, driverid#73, city#74, date#75, status#76, amount#77], false

== Analyzed Logical Plan ==
city: string, driverid: string, driver_revenue: double, rank: int
Sort [city#74 ASC NULLS FIRST, rank#492 ASC NULLS FIRST], true
+- Project [city#74, driverid#73, driver_revenue#471, rank#492]
   +- Project [city#74, driverid#7

In [35]:
from pyspark.sql.functions import broadcast

optimized_join_df = trips_df.join(
    broadcast(city_df),
    on="city",
    how="left"
)


In [36]:
print("\n--- Identifying Shuffles, Broadcast Joins, and Sort Stages ---\n")

# 1. Shuffles (often occur with wide transformations like non-broadcast joins, groupBy, orderBy)
# For example, a non-broadcast join like trip_city_join_df will typically involve shuffles (SortMergeJoin)
print("\n--- Shuffles (e.g., in trip_city_join_df) ---")
trip_city_join_df.explain(True)
print("\nNote: In the Physical Plan above, look for 'Exchange hashpartitioning' or 'Exchange rangepartitioning' which indicate a shuffle.")
print("A 'SortMergeJoin' also implies shuffles for sorting before merging.")

# Another example for shuffles and sorts from window functions/orderBy
print("\n--- Shuffles and Sorts (e.g., in sorted_drivers_within_city) ---")
sorted_drivers_within_city.explain(True)
print("\nNote: In the Physical Plan above, 'Exchange rangepartitioning' and 'Sort' operations are prominent, indicating shuffles and sorts.")

# 2. Broadcast Joins (explicitly used with `broadcast` hint)
print("\n--- Broadcast Joins (e.g., in optimized_join_df) ---")
optimized_join_df.explain(True)
print("\nNote: In the Physical Plan above, look for 'BroadcastHashJoin' and 'BroadcastExchange' which confirm a broadcast join.")

# 3. Sort Stages (occur with orderBy, window functions requiring order, or SortMergeJoin)
print("\n--- Sort Stages (e.g., in trip_city_join_df for SortMergeJoin) ---")
trip_city_join_df.explain(True)
print("\nNote: In the Physical Plan above, 'Sort' within 'SortMergeJoin' indicates sort stages.")

print("\n--- Sort Stages (e.g., in sorted_drivers_within_city for window functions and final ordering) ---")
sorted_drivers_within_city.explain(True)
print("\nNote: In the Physical Plan above, 'Sort' is explicitly shown for ordering the data within partitions before the window function, and again for the final global sort.")


--- Identifying Shuffles, Broadcast Joins, and Sort Stages ---


--- Shuffles (e.g., in trip_city_join_df) ---
== Parsed Logical Plan ==
'Join UsingJoin(LeftOuter, [city])
:- LogicalRDD [userid#72, driverid#73, city#74, date#75, status#76, amount#77], false
+- LogicalRDD [city#44, region#45], false

== Analyzed Logical Plan ==
city: string, userid: string, driverid: string, date: string, status: string, amount: string, region: string
Project [city#74, userid#72, driverid#73, date#75, status#76, amount#77, region#45]
+- Join LeftOuter, (city#74 = city#44)
   :- LogicalRDD [userid#72, driverid#73, city#74, date#75, status#76, amount#77], false
   +- LogicalRDD [city#44, region#45], false

== Optimized Logical Plan ==
Project [city#74, userid#72, driverid#73, date#75, status#76, amount#77, region#45]
+- Join LeftOuter, (city#74 = city#44)
   :- LogicalRDD [userid#72, driverid#73, city#74, date#75, status#76, amount#77], false
   +- Filter isnotnull(city#44)
      +- LogicalRDD [city#44, 

In [37]:
from pyspark.sql.functions import broadcast

# Assuming trips_df is the larger DataFrame and city_df is the smaller one
# This will broadcast city_df to all worker nodes to avoid shuffling trips_df
optimized_trips_city_join = trips_df.join(broadcast(city_df), on="city", how="inner")

print("\n--- Optimized Join using Broadcast Hint ---")
optimized_trips_city_join.explain(True)
optimized_trips_city_join.show()


--- Optimized Join using Broadcast Hint ---
== Parsed Logical Plan ==
'Join UsingJoin(Inner, [city])
:- LogicalRDD [userid#72, driverid#73, city#74, date#75, status#76, amount#77], false
+- ResolvedHint (strategy=broadcast)
   +- LogicalRDD [city#44, region#45], false

== Analyzed Logical Plan ==
city: string, userid: string, driverid: string, date: string, status: string, amount: string, region: string
Project [city#74, userid#72, driverid#73, date#75, status#76, amount#77, region#45]
+- Join Inner, (city#74 = city#44)
   :- LogicalRDD [userid#72, driverid#73, city#74, date#75, status#76, amount#77], false
   +- ResolvedHint (strategy=broadcast)
      +- LogicalRDD [city#44, region#45], false

== Optimized Logical Plan ==
Project [city#74, userid#72, driverid#73, date#75, status#76, amount#77, region#45]
+- Join Inner, (city#74 = city#44), rightHint=(strategy=broadcast)
   :- Filter isnotnull(city#74)
   :  +- LogicalRDD [userid#72, driverid#73, city#74, date#75, status#76, amount#77