In [1]:
# Spark Session
from pyspark.sql import SparkSession


def _build_spark_session():
    """Create (or reuse, via ``getOrCreate``) the notebook's SparkSession.

    The session runs in local mode on all available cores.

    NOTE(review): per the Spark configuration docs, ``spark.cores.max`` and the
    ``spark.executor.*`` settings are aimed at cluster deployments. Under
    ``local[*]`` the executor runs inside the driver JVM, so these values are
    likely ignored here. Confirm before relying on them.
    """
    builder = (
        SparkSession.builder
        .appName("Distributed Shared Variables")
        .master("local[*]")
    )
    # Applied in the same order as listed.
    settings = (
        ("spark.cores.max", 16),
        ("spark.executor.cores", 4),
        ("spark.executor.memory", "512M"),
    )
    for key, value in settings:
        builder = builder.config(key, value)
    return builder.getOrCreate()


spark = _build_spark_session()

spark

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/01/13 09:15:24 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [4]:
# Read EMP CSV data

# DDL schema string, assembled one column per line so each column/type is easy to scan.
_schema = ", ".join([
    "first_name string",
    "last_name string",
    "job_title string",
    "dob string",
    "email string",
    "phone string",
    "salary double",
    "department_id int",
])

# Explicit schema, so Spark does not need an inference pass; the first CSV line is a header.
emp = (
    spark.read
    .format("csv")
    .schema(_schema)
    .option("header", True)
    .load("data/input/employee_records.csv")
)

In [14]:
# Variable (Lookup)
# Department id -> display name for ids 1..10. Every name follows the
# "Department <id>" pattern, so a comprehension builds the same dict.
dept_names = {dept_id: f"Department {dept_id}" for dept_id in range(1, 11)}

In [15]:
# Broadcast the variable
# Cache one read-only copy of the lookup on each executor, rather than shipping it
# along with every task that needs it. Tasks read it back through `.value`.
dept_names_broadcast = spark.sparkContext.broadcast(value=dept_names)

In [17]:
# Check the value of the variable
# `.value` returns the broadcast dict itself (see the output below).
# A broadcast variable lets every task look values up locally. It can replace a
# join against a small lookup table, which could otherwise require a shuffle.

dept_names_broadcast.value

{1: 'Department 1',
 2: 'Department 2',
 3: 'Department 3',
 4: 'Department 4',
 5: 'Department 5',
 6: 'Department 6',
 7: 'Department 7',
 8: 'Department 8',
 9: 'Department 9',
 10: 'Department 10'}

In [25]:
# Create UDF to return Department name
from pyspark.sql.functions import col, udf


@udf(returnType="string")
def dept_name_mapping(dept_id):
    """Map a department id to its display name via the broadcast lookup.

    The function reads ``dept_names_broadcast.value`` when the UDF is evaluated.
    Ids missing from the lookup, including null ids, return ``None``, which
    becomes a null in the result column.
    """
    lookup = dept_names_broadcast.value
    return lookup.get(dept_id)


emp_final = emp.withColumn(
    "dept_name",
    dept_name_mapping(col("department_id")),
)
emp_final.show(1)


+----------+---------+--------------------+----------+--------------------+-------------+--------+-------------+------------+
|first_name|last_name|           job_title|       dob|               email|        phone|  salary|department_id|   dept_name|
+----------+---------+--------------------+----------+--------------------+-------------+--------+-------------+------------+
|   Richard| Morrison|Public relations ...|1973-05-05|melissagarcia@exa...|(699)525-4827|512653.0|            8|Department 8|
+----------+---------+--------------------+----------+--------------------+-------------+--------+-------------+------------+
only showing top 1 row



In [34]:
# Calculate total salary of Department 6 with a regular DataFrame aggregation.
# Import the functions module instead of `sum` itself, so Python's builtin `sum`
# is not shadowed for the rest of the notebook.
from pyspark.sql import functions as F

emp_dept_6_sal = (
    emp
    .where(F.col("department_id") == 6)
    .groupBy("department_id")
    .agg(F.sum("salary").cast("long").alias("total_salary"))
)
# `.show()` only prints and returns None, so keep the DataFrame in the
# variable and display it as a separate step.
emp_dept_6_sal.show()

[Stage 4:>                                                          (0 + 8) / 8]

+-------------+---------------------------+
|department_id|CAST(sum(salary) AS BIGINT)|
+-------------+---------------------------+
|            6|                50294510721|
+-------------+---------------------------+



                                                                                

In [37]:
# Accumulators
# Start from a float zero so the accumulator's type matches the `salary double`
# column it sums. With an int zero, `.value` stays the int 0 when no row matches
# but turns into a float as soon as any salary is added.

emp_accumulators = spark.sparkContext.accumulator(0.0)

In [40]:
# Use foreach
# foreach is an action. For accumulator updates made inside actions, Spark applies
# each task's update only once, so the driver-side total is not inflated by task retries.

def calculate_dept_salary(department_id, salary, target_department_id=6):
    """Add ``salary`` to ``emp_accumulators`` when the row is in the target department.

    Parameters
    ----------
    department_id : int or None
        The row's department; null ids never match.
    salary : float or None
        The amount to add. Null salaries are skipped, matching how Spark's ``sum``
        aggregate ignores nulls. Adding ``None`` to the accumulator would raise
        ``TypeError`` inside the task.
    target_department_id : int, default 6
        Department whose salaries are accumulated.
    """
    if department_id == target_department_id and salary is not None:
        emp_accumulators.add(salary)


# Reset on the driver so re-running this cell does not double-count into the
# accumulator created in the previous cell.
emp_accumulators.value = 0.0

# Only two columns are needed. Selecting them avoids shipping whole rows to the Python workers.
emp.select("department_id", "salary").foreach(
    lambda row: calculate_dept_salary(row.department_id, row.salary)
)

                                                                                

In [41]:
# Read the merged total. Only the driver program can read an accumulator's value.
# In the run shown, it agrees with the department-6 aggregate computed earlier.
emp_accumulators.value

50294510721.0

In [42]:
# Stop Spark Session
# Stops the underlying SparkContext. Any further Spark code needs a new
# session, for example by re-running the first cell.
spark.stop()