In [12]:
from pyspark.sql import SparkSession 
from pyspark.sql.functions import rand, when, pandas_udf, PandasUDFType
from pyspark.sql.types import BooleanType
import pandas as pd

In [2]:
# Configure the session builder step by step, then obtain the SparkSession
# (getOrCreate hands back an already-active session when there is one).
session_builder = SparkSession.builder.appName("broadcast-variables")
session_builder = session_builder.master("spark://spark-master:7077")
session_builder = session_builder.config("spark.executor.memory", "512m")
spark = session_builder.getOrCreate()

# From here on, only ERROR-level messages are logged
spark.sparkContext.setLogLevel("ERROR")

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/02/21 12:33:51 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [7]:
# Build one large sample DataFrame (1 million rows) with synthetic columns:
#   salary       - a multiple of 100 in [0, 9900]
#   gender       - "M" or "F"
#   country_code - "US", "CN", "IN", "BR", or null (unknown country)

# Seeds make the random columns repeatable for a fixed partitioning; each
# column gets its own seed so the columns are not correlated with each other.
SEED = 42

# Every rand() call is a separate, independent random column. If each `when`
# branch called rand() itself, the branches would test different draws for the
# same row, and rows could fall through all of them by accident. The country
# bucket is therefore drawn ONCE and materialized as a helper column.
sample_base_df = (spark.range(0, 1000000)
                  .withColumn("salary", 100 * (rand(SEED) * 100).cast("int"))
                  .withColumn("gender",
                              when((rand(SEED + 1) * 2).cast("int") == 0, "M").otherwise("F"))
                  .withColumn("country_bucket", (rand(SEED + 2) * 5).cast("int")))

# Buckets 0-3 each map to one country code. Bucket 4 is deliberately left
# without a branch, so about one row in five has a null country_code; this
# gives the lookup / filter cells further down "unknown" rows to handle.
country_bucket = sample_base_df["country_bucket"]
large_df = (sample_base_df
            .withColumn("country_code",
                        when(country_bucket == 0, "US")
                        .when(country_bucket == 1, "CN")
                        .when(country_bucket == 2, "IN")
                        .when(country_bucket == 3, "BR"))
            .drop("country_bucket"))  # helper column is not part of the sample data
large_df.show(5)

[Stage 0:>                                                          (0 + 1) / 1]

+---+------+------+------------+
| id|salary|gender|country_code|
+---+------+------+------------+
|  0|  8000|     M|          US|
|  1|  3500|     F|        null|
|  2|  9700|     F|        null|
|  3|  4800|     F|        null|
|  4|  9100|     F|        null|
+---+------+------+------------+
only showing top 5 rows



                                                                                

In [8]:
# Country code -> country name lookup table
lookup = {
    "US": "United States",
    "CN": "China",
    "IN": "India",
    "BR": "Brazil",
    "RU": "Russia",
}

# Wrap the lookup table in a broadcast variable
spark_context = spark.sparkContext
broadcast_lookup = spark_context.broadcast(lookup)

In [9]:
@pandas_udf('string')
def country_convert(s: pd.Series) -> pd.Series:
    """Translate a batch of country codes into country names.

    The UDF kind (Series -> Series) is declared through the type hints
    rather than the deprecated ``PandasUDFType.SCALAR`` argument.

    Args:
        s: Country codes for one batch of rows.

    Returns:
        A Series of the same length holding the name found for each code in
        the broadcast lookup table; codes that are missing from the table
        (including nulls) come back as null.
    """
    # .value reads the dict wrapped by the broadcast variable.
    return s.map(broadcast_lookup.value)



In [10]:
# Attach the looked-up country name as a new column and preview the first rows
(large_df
 .withColumn("country_name", country_convert(large_df["country_code"]))
 .show(5))

[Stage 1:>                                                          (0 + 1) / 1]

+---+------+------+------------+-------------+
| id|salary|gender|country_code| country_name|
+---+------+------+------------+-------------+
|  0|  8000|     M|          US|United States|
|  1|  3500|     F|        null|         null|
|  2|  9700|     F|        null|         null|
|  3|  4800|     F|        null|         null|
|  4|  9100|     F|        null|         null|
+---+------+------+------------+-------------+
only showing top 5 rows



                                                                                

In [13]:
@pandas_udf(BooleanType())
def filter_unknown_country(s: pd.Series) -> pd.Series:
    """Flag which country codes in a batch are present in the lookup table.

    Despite the name, the result is True for KNOWN codes, so passing it to
    ``DataFrame.filter`` keeps rows with a known country and drops the rest.

    The UDF kind (Series -> Series) is declared through the type hints
    rather than the deprecated ``PandasUDFType.SCALAR`` argument.

    Args:
        s: Country codes for one batch of rows.

    Returns:
        A boolean Series of the same length: True where the code is a key of
        the broadcast lookup table, False otherwise.
    """
    # Membership is tested against the dict's keys; spell that out instead of
    # relying on isin() iterating the dict implicitly.
    known_codes = list(broadcast_lookup.value)
    return s.isin(known_codes)

In [14]:
# Keep only the rows the UDF flags as True and preview the first rows
(large_df
 .filter(filter_unknown_country(large_df["country_code"]))
 .show(5))

[Stage 2:>                                                          (0 + 1) / 1]

+---+------+------+------------+
| id|salary|gender|country_code|
+---+------+------+------------+
|  0|  8000|     M|          US|
|  6|  3400|     F|          US|
|  7|  8400|     M|          CN|
|  8|  1100|     F|          US|
|  9|  2900|     M|          CN|
+---+------+------+------------+
only showing top 5 rows



                                                                                

In [15]:
# Stop the SparkSession now that the notebook is finished with it
spark.stop()