In [1]:
# Start (or reuse) a local Spark session that runs on all available cores.
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName("broadcast")
    .master("local[*]")
    .getOrCreate()
)

In [3]:
# Load the employee records using an explicit DDL schema (no inference); the file's first line is a header.
_schema = "first_name string, last_name string, job_title string, dob string, email string, phone string, salary double, department_id int"

df = spark.read.csv('employee_records.txt', schema=_schema, header=True)

In [6]:
# Lookup table from department_id (1..10) to a display name, e.g. 3 -> 'Department 3'.
dept_names = {dept_id: f'Department {dept_id}' for dept_id in range(1, 11)}

If we use this dictionary directly in a lookup on the employee DataFrame, it is captured in the closure of every task, so the driver ships a copy to the executors once per task (i.e., once per partition).
Instead, we can broadcast it to all executors beforehand, so that it is sent to each executor only once and reused by all of that executor's tasks.

In [8]:
# Ship dept_names to the executors once; tasks then read the cached copy through .value.
broadcast_depnames = spark.sparkContext.broadcast(value=dept_names)

In [9]:
# The broadcast variable is no longer a plain dict: broadcast() wrapped the Python object
# in a pyspark Broadcast handle, which is what type() reports below.
type(broadcast_depnames)

pyspark.broadcast.Broadcast

In [10]:
# .value returns the wrapped Python object (here the dept_names dict), as the output below shows.
broadcast_depnames.value

{1: 'Department 1',
 2: 'Department 2',
 3: 'Department 3',
 4: 'Department 4',
 5: 'Department 5',
 6: 'Department 6',
 7: 'Department 7',
 8: 'Department 8',
 9: 'Department 9',
 10: 'Department 10'}

In [12]:
# To perform the lookup we need a function that takes a department id and returns the
# department name. dept_names is a Python dict, not a DataFrame, so we can't join against
# it; instead we wrap the lookup in a UDF.

# Create UDF to return Department name
from pyspark.sql.functions import udf

def get_dept_names(dep_id):
    """Return the department name for ``dep_id`` from the broadcast lookup table.

    Reads ``broadcast_depnames.value`` (the copy cached on the executor) rather than the
    driver-side ``dept_names`` dict, so the dict itself is not captured in each task's closure.
    Returns None when ``dep_id`` is not a key (including a null department_id).
    """
    lookup_table = broadcast_depnames.value
    return lookup_table.get(dep_id)

# "string" is the same StringType that udf() uses when no returnType is given.
get_dept_names_udf = udf(get_dept_names, returnType="string")

In [13]:
# Add a dep_name column by applying the lookup UDF to each row's department_id.
final_df = df.withColumn(
    "dep_name",
    get_dept_names_udf(df["department_id"]),
)

In [14]:
# dep_name is now populated for each row. The UDF runs row by row inside each partition,
# so adding this column did not require a shuffle.
final_df.show()

+----------+----------+--------------------+----------+--------------------+--------------------+--------+-------------+-------------+
|first_name| last_name|           job_title|       dob|               email|               phone|  salary|department_id|     dep_name|
+----------+----------+--------------------+----------+--------------------+--------------------+--------+-------------+-------------+
|   Richard|  Morrison|Public relations ...|1973-05-05|melissagarcia@exa...|       (699)525-4827|512653.0|            8| Department 8|
|     Bobby|  Mccarthy|   Barrister's clerk|1974-04-25|   llara@example.net|  (750)846-1602x7458|999836.0|            7| Department 7|
|    Dennis|    Norman|Land/geomatics su...|1990-06-24| jturner@example.net|    873.820.0518x825|131900.0|           10|Department 10|
|      John|    Monroe|        Retail buyer|1968-06-16|  erik33@example.net|    820-813-0557x624|485506.0|            1| Department 1|
|  Michelle|   Elliott|      Air cabin crew|1975-03-31|

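# Accumulators
Let's see an example of how accumulators can be used. An accumulator is a shared variable that tasks running on the executors can only add to, and whose aggregated value only the driver can read.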

In [15]:
# Calculate total salary of department 6 with a regular filter + groupBy aggregation.

# Import the functions module under an alias: `from pyspark.sql.functions import sum`
# would shadow Python's built-in sum() in every later cell of this notebook.
from pyspark.sql import functions as F

# This approach involves a shuffle: department-6 records are spread across partitions and
# executors, so they must be brought together before they can be aggregated.
final_df.where("department_id=6").groupBy("department_id").agg(F.sum("salary").cast("long")).show()

+-------------+---------------------------+
|department_id|CAST(sum(salary) AS BIGINT)|
+-------------+---------------------------+
|            6|                50294510721|
+-------------+---------------------------+



In [23]:
# Let's see how we can avoid the shuffle by using an accumulator.

# Create the accumulator with an initial value of 0; executors add to it and the driver reads it.
dept_sal = spark.sparkContext.accumulator(value=0)

In [24]:
# groupBy forces a shuffle, so instead we run a plain Python function on every row with
# DataFrame.foreach and let each task add its salaries into the accumulator.

def calculate_salary(department_id, salary, target_department_id=6):
    """Add ``salary`` to the ``dept_sal`` accumulator if the row belongs to the target department.

    Called on the executors for each row passed to ``foreach``.

    :param department_id: department id of the row (None for a null department_id).
    :param salary: salary of the row. Null (None) salaries are skipped: ``dept_sal.add(None)``
        would raise TypeError and fail the task, while the groupBy ``sum`` above ignores nulls.
    :param target_department_id: department whose salaries are accumulated (default 6).
    """
    if department_id == target_department_id and salary is not None:
        dept_sal.add(salary)

# foreach is an action: it runs the given function on each row of the DataFrame.
# The where() pre-filter keeps non-department-6 rows out of Python entirely; the check inside
# calculate_salary keeps the function correct even on unfiltered data.
# NOTE: re-running this cell without re-creating dept_sal adds the total a second time.
df.where("department_id=6").foreach(lambda row: calculate_salary(row.department_id, row.salary))

# Using the accumulator, no shuffle was involved.

In [25]:
# Read the accumulated total on the driver with .value. It is a float because salary is
# declared as double in the schema.
dept_sal.value

50294510721.0

In [26]:
# Stop the SparkSession (and its underlying SparkContext) to release local resources.
spark.stop()