# Tutorial: introduction to data analysis with PySpark

In [43]:
import findspark

In [44]:
# Local Spark installation folder. Raw string: the backslashes of the Windows
# path must not be interpreted as escape sequences.
SPARK_HOME = r"C:\Spark\spark-3.0.1-bin-hadoop2.7"
findspark.init(SPARK_HOME)

In [45]:
import pyspark
from pyspark import SparkContext
from pyspark.sql import functions as F
from pyspark.sql import SparkSession
from pyspark.sql import SQLContext

from pyspark.sql.functions import broadcast
from pyspark.sql.types import StringType, IntegerType
from pyspark.sql.window import Window

# Configuration & Initialization

* SparkContext — provides connection to Spark with the ability to create RDDs
* SQLContext — provides connection to Spark with the ability to run SQL queries on data
* SparkSession — all-encompassing context which includes coverage for SparkContext, SQLContext and HiveContext. SparkSession is the entry point for programming Spark applications. It lets you interact with the Dataset and DataFrame APIs provided by Spark. We set the application name by calling appName. The getOrCreate() method returns the existing SparkSession if there is one, and otherwise creates a new one.

In [46]:
# Build the SparkSession for this tutorial (an already running session is reused)
spark = (
    SparkSession.builder
    .appName("data analysis with PySpark")
    .getOrCreate()
)

In [47]:
# get the SparkContext (getOrCreate returns the running one, or instantiates it),
# which allows the Spark Application to access the Spark Cluster with the help
# of a resource manager which is usually YARN or Mesos
sc = SparkContext.getOrCreate()

In [48]:
# create a SQLContext instance to access the SQL query engine built on top of Spark.
# The constructor takes the SparkContext first and the SparkSession second, so
# both are taken from the session created above.
# NOTE: SQLContext is kept for backward compatibility; `spark.sql(...)` on the
# session offers the same functionality.
sqlContext = SQLContext(spark.sparkContext, spark)

# Reading

We can start by loading the files in our dataset using the spark.read.load command. This command reads parquet files, which is the default file format for spark, but you can add the parameter format to read other formats.

In [49]:
# Load the COVID cases file: comma separated, first row is the header,
# column types are inferred from the data
cases = (
    spark.read
    .format("csv")
    .options(sep=",", inferSchema="true", header="true")
    .load("data/Case.csv")
)
cases.show(6)

+--------+--------+------------+-----+--------------------+---------+---------+----------+
| case_id|province|        city|group|      infection_case|confirmed| latitude| longitude|
+--------+--------+------------+-----+--------------------+---------+---------+----------+
| 1000001|   Seoul|  Yongsan-gu| true|       Itaewon Clubs|      139|37.538621|126.992652|
| 1000002|   Seoul|   Gwanak-gu| true|             Richway|      119| 37.48208|126.901384|
| 1000003|   Seoul|     Guro-gu| true| Guro-gu Call Center|       95|37.508163|126.884387|
| 1000004|   Seoul|Yangcheon-gu| true|Yangcheon Table T...|       43|37.546061|126.874209|
| 1000005|   Seoul|   Dobong-gu| true|     Day Care Center|       43|37.679422|127.044374|
| 1000006|   Seoul|     Guro-gu| true|Manmin Central Ch...|       41|37.481059|126.894343|
+--------+--------+------------+-----+--------------------+---------+---------+----------+
only showing top 6 rows



In [50]:
# Load the per-province time series: comma separated, first row is the header,
# column types are inferred from the data
time_province = (
    spark.read
    .format("csv")
    .options(sep=",", inferSchema="true", header="true")
    .load("data/TimeProvince.csv")
)
time_province.show(6)

+----------+----+--------+---------+--------+--------+
|      date|time|province|confirmed|released|deceased|
+----------+----+--------+---------+--------+--------+
|2020-01-20|  16|   Seoul|        0|       0|       0|
|2020-01-20|  16|   Busan|        0|       0|       0|
|2020-01-20|  16|   Daegu|        0|       0|       0|
|2020-01-20|  16| Incheon|        1|       0|       0|
|2020-01-20|  16| Gwangju|        0|       0|       0|
|2020-01-20|  16| Daejeon|        0|       0|       0|
+----------+----+--------+---------+--------+--------+
only showing top 6 rows



In [51]:
# Load the regions file: comma separated, first row is the header,
# column types are inferred from the data
regions = (
    spark.read
    .format("csv")
    .options(sep=",", inferSchema="true", header="true")
    .load("data/Region.csv")
)
regions.show(6)

+-----+--------+-----------+---------+----------+-----------------------+------------------+----------------+-------------+------------------------+-------------------+------------------+
| code|province|       city| latitude| longitude|elementary_school_count|kindergarten_count|university_count|academy_ratio|elderly_population_ratio|elderly_alone_ratio|nursing_home_count|
+-----+--------+-----------+---------+----------+-----------------------+------------------+----------------+-------------+------------------------+-------------------+------------------+
|10000|   Seoul|      Seoul|37.566953|126.977977|                    607|               830|              48|         1.44|                   15.38|                5.8|             22739|
|10010|   Seoul| Gangnam-gu|37.518421|127.047222|                     33|                38|               0|         4.18|                   13.17|                4.3|              3088|
|10020|   Seoul|Gangdong-gu|37.530492|127.123837|           

In [52]:
# alternative reading options
#df = spark.read.csv("Case.csv", header=True, inferSchema=True)
#df = spark.read.format('csv').options(header=True,inferSchema=True).load("Case.csv")

# for other file formats:
#df = spark.read.text(path_to_file)
#df = spark.read.json(path_to_file)
#df = spark.read.parquet(path_to_file)

## Overview

In [53]:
# show the first 10 rows without truncating long cell values
cases.show(10, truncate=False)

+--------+--------+---------------+-----+--------------------------------+---------+---------+----------+
| case_id|province|city           |group|infection_case                  |confirmed|latitude |longitude |
+--------+--------+---------------+-----+--------------------------------+---------+---------+----------+
|1000001 |Seoul   |Yongsan-gu     |true |Itaewon Clubs                   |139      |37.538621|126.992652|
|1000002 |Seoul   |Gwanak-gu      |true |Richway                         |119      |37.48208 |126.901384|
|1000003 |Seoul   |Guro-gu        |true |Guro-gu Call Center             |95       |37.508163|126.884387|
|1000004 |Seoul   |Yangcheon-gu   |true |Yangcheon Table Tennis Club     |43       |37.546061|126.874209|
|1000005 |Seoul   |Dobong-gu      |true |Day Care Center                 |43       |37.679422|127.044374|
|1000006 |Seoul   |Guro-gu        |true |Manmin Central Church           |41       |37.481059|126.894343|
|1000007 |Seoul   |from other city|true |SMR N

In [54]:
# list the column names (note the leading space in ' case_id' in the output)
cases.columns

[' case_id',
 'province',
 'city',
 'group',
 'infection_case',
 'confirmed',
 'latitude',
 'longitude']

In [55]:
# print column names, inferred data types and nullability
cases.printSchema()

root
 |--  case_id: integer (nullable = true)
 |-- province: string (nullable = true)
 |-- city: string (nullable = true)
 |-- group: boolean (nullable = true)
 |-- infection_case: string (nullable = true)
 |-- confirmed: integer (nullable = true)
 |-- latitude: string (nullable = true)
 |-- longitude: string (nullable = true)



# Basic Functions

Get a descriptive overview of data fields

In [56]:
# Summary statistics (count, mean, stddev, min, max) for the descriptive fields
described_fields = ["province", "city", "infection_case", "confirmed", "latitude", "longitude"]
describe_cases = cases.select(*described_fields)
describe_cases.describe().show()

+-------+--------+---------------+--------------------+------------------+------------------+------------------+
|summary|province|           city|      infection_case|         confirmed|          latitude|         longitude|
+-------+--------+---------------+--------------------+------------------+------------------+------------------+
|  count|     174|            174|                 174|               174|               174|               174|
|   mean|    null|           null|                null| 65.48850574712644| 36.69405111076924|127.58488500461536|
| stddev|    null|           null|                null|355.09765388939746|0.9114662922487264| 0.823086807800544|
|    min|   Busan|              -|Anyang Gunpo Past...|                 0|                 -|                 -|
|    max|   Ulsan|from other city|     overseas inflow|              4511|         37.758635|          129.1256|
+-------+--------+---------------+--------------------+------------------+------------------+---

We can filter a data frame on one or more conditions, combining them with AND (&), OR (|) and NOT (~)

In [57]:
# Keep only the rows whose province is Seoul, then count them
is_seoul = F.col("province") == "Seoul"
seoul_cases = cases.filter(is_seoul)
seoul_cases.count()

# alternative: where() does the same as filter()
#seoul_cases = cases.where(F.col("province")=="Seoul")
#seoul_cases.count()

38

In [58]:
# Rows from Daegu with more than 100 confirmed cases (both conditions must hold)
over_100_confirmed = F.col("confirmed") > 100
in_daegu = F.col("province") == 'Daegu'
cases.where(over_100_confirmed & in_daegu).show()

+--------+--------+------------+-----+--------------------+---------+---------+----------+
| case_id|province|        city|group|      infection_case|confirmed| latitude| longitude|
+--------+--------+------------+-----+--------------------+---------+---------+----------+
| 1200001|   Daegu|      Nam-gu| true|  Shincheonji Church|     4511| 35.84008|  128.5667|
| 1200002|   Daegu|Dalseong-gun| true|Second Mi-Ju Hosp...|      196|35.857375|128.466651|
| 1200003|   Daegu|      Seo-gu| true|Hansarang Convale...|      124|35.885592|128.556649|
| 1200004|   Daegu|Dalseong-gun| true|Daesil Convalesce...|      101|35.857393|128.466653|
| 1200009|   Daegu|           -|false|contact with patient|      917|        -|         -|
| 1200010|   Daegu|           -|false|                 etc|      747|        -|         -|
+--------+--------+------------+-----+--------------------+---------+---------+----------+



We can select a subset of columns using the select keyword

In [59]:
# Project the DataFrame onto two columns and show the first 5 rows
cases.select("province", "city").show(5)

+--------+------------+
|province|        city|
+--------+------------+
|   Seoul|  Yongsan-gu|
|   Seoul|   Gwanak-gu|
|   Seoul|     Guro-gu|
|   Seoul|Yangcheon-gu|
|   Seoul|   Dobong-gu|
+--------+------------+
only showing top 5 rows



We can change a single column name

In [60]:
# Rename columns: first argument is the existing name, second is the new name
cases = (
    cases
    .withColumnRenamed("latitude", "lat")
    .withColumnRenamed("longitude", "long")
    .withColumnRenamed("infection_case", "infection_source")
)

cases.show(10)

+--------+--------+---------------+-----+--------------------+---------+---------+----------+
| case_id|province|           city|group|    infection_source|confirmed|      lat|      long|
+--------+--------+---------------+-----+--------------------+---------+---------+----------+
| 1000001|   Seoul|     Yongsan-gu| true|       Itaewon Clubs|      139|37.538621|126.992652|
| 1000002|   Seoul|      Gwanak-gu| true|             Richway|      119| 37.48208|126.901384|
| 1000003|   Seoul|        Guro-gu| true| Guro-gu Call Center|       95|37.508163|126.884387|
| 1000004|   Seoul|   Yangcheon-gu| true|Yangcheon Table T...|       43|37.546061|126.874209|
| 1000005|   Seoul|      Dobong-gu| true|     Day Care Center|       43|37.679422|127.044374|
| 1000006|   Seoul|        Guro-gu| true|Manmin Central Ch...|       41|37.481059|126.894343|
| 1000007|   Seoul|from other city| true|SMR Newly Planted...|       36|        -|         -|
| 1000008|   Seoul|  Dongdaemun-gu| true|       Dongan Churc

We can sort data in increasing (default) or decreasing order

In [61]:
# Sorting: call .asc() on a column for ascending or .desc() for descending order

cases.sort(F.col("confirmed").desc()).show()

# alternative
#cases.orderBy(F.desc("confirmed")).show()

+--------+-----------------+---------------+-----+--------------------+---------+---------+----------+
| case_id|         province|           city|group|    infection_source|confirmed|      lat|      long|
+--------+-----------------+---------------+-----+--------------------+---------+---------+----------+
| 1200001|            Daegu|         Nam-gu| true|  Shincheonji Church|     4511| 35.84008|  128.5667|
| 1200009|            Daegu|              -|false|contact with patient|      917|        -|         -|
| 1200010|            Daegu|              -|false|                 etc|      747|        -|         -|
| 6000001| Gyeongsangbuk-do|from other city| true|  Shincheonji Church|      566|        -|         -|
| 2000020|      Gyeonggi-do|              -|false|     overseas inflow|      305|        -|         -|
| 1000036|            Seoul|              -|false|     overseas inflow|      298|        -|         -|
| 1200002|            Daegu|   Dalseong-gun| true|Second Mi-Ju Hosp...|  

We can use groupBy function with a spark DataFrame, too. Pretty much same as the pandas groupBy with the exception that you will need to import pyspark.sql.functions

In [62]:
# Number of case records per province, largest group first

(
    cases
    .groupBy("province")
    .count()
    .orderBy(F.col("count").desc())
    .show()
)

+-----------------+-----+
|         province|count|
+-----------------+-----+
|            Seoul|   38|
|      Gyeonggi-do|   22|
| Gyeongsangbuk-do|   13|
| Gyeongsangnam-do|   12|
|            Busan|   10|
|          Daejeon|   10|
|            Daegu|   10|
|Chungcheongnam-do|    8|
|       Gangwon-do|    8|
|Chungcheongbuk-do|    7|
|          Incheon|    7|
|           Sejong|    6|
|     Jeollabuk-do|    5|
|     Jeollanam-do|    5|
|          Gwangju|    5|
|            Ulsan|    4|
|          Jeju-do|    4|
+-----------------+-----+



If you don’t like the new grouped column names, you can use the alias keyword to rename columns in the agg command itself

In [63]:
# Total and maximum confirmed cases for each (province, city) pair

confirmed_total = F.sum("confirmed").alias("tot confirmed")
confirmed_max = F.max("confirmed").alias("max confirmed")
cases.groupBy("province", "city").agg(confirmed_total, confirmed_max).show(10)

+----------------+---------------+-------------+-------------+
|        province|           city|tot confirmed|max confirmed|
+----------------+---------------+-------------+-------------+
|Gyeongsangnam-do|       Jinju-si|            9|            9|
|           Seoul|        Guro-gu|          139|           95|
|           Seoul|     Gangnam-gu|           18|            7|
|         Daejeon|              -|          100|           55|
|    Jeollabuk-do|from other city|            6|            3|
|Gyeongsangnam-do|Changnyeong-gun|            7|            7|
|           Seoul|              -|          561|          298|
|         Jeju-do|from other city|            1|            1|
|Gyeongsangbuk-do|              -|          345|          190|
|Gyeongsangnam-do|   Geochang-gun|           18|           10|
+----------------+---------------+-------------+-------------+
only showing top 10 rows



We can use .withColumn along with PySpark SQL functions (like when/otherwise) to create a new column. In essence, you can find String functions, Date functions, and Math functions already implemented using Spark functions

In [64]:
# Replace the "-" placeholder with null in every column of `cases`
# if/else logic: F.when(condition, value).otherwise(fallback)
# All columns are rewritten in one select rather than one withColumn call per
# column, which keeps the query plan small; names and order are unchanged.

cases = cases.select([
    F.when(F.col(c) == "-", None).otherwise(F.col(c)).alias(c)
    for c in cases.columns
])

cases.select("city").distinct().show()

+---------------+
|           city|
+---------------+
|     Gangnam-gu|
|     Cheonan-si|
|from other city|
|      Anyang-si|
|      Gwanak-gu|
|     Yongsan-gu|
|        Dong-gu|
|         Sejong|
|     Gangseo-gu|
|       Wonju-si|
|     Suyeong-gu|
|   Geochang-gun|
|           null|
|  Dongdaemun-gu|
|     Dongnae-gu|
|         Jin-gu|
|     Yangsan-si|
|    Changwon-si|
|         Nam-gu|
|   Gyeongsan-si|
+---------------+
only showing top 20 rows



## Cache / Persist

Spark works on the lazy execution principle. What that means is that nothing really gets executed until you use an action function like .count() on a dataframe. cache() and persist() are lazy as well: they only mark the dataframe for caching, and the data is actually stored when the next action runs. So if a dataframe is going to be reused, call cache() or persist() on it first and then run an action such as .count() to fill the cache.

In [65]:
# Persist the data (see the persist options to choose a specific storage level).
# persist() is lazy, so the DataFrame is marked for caching first and the
# count() action that follows is what actually fills the cache.
cases = cases.persist()
cases.count()

# alternative (same thing in one line)
#cases.persist().count()

# alternatively, cache() is shorthand for persist() with the default storage level
#cases.cache().count()

In [66]:
# remember to unpersist data when it is not needed anymore
#cases.unpersist()

In [67]:
# Nulls per column: when(...) yields a value only where the column is null,
# and F.count counts just those non-null results

null_counts = [
    F.count(F.when(F.col(name).isNull(), name)).alias(name)
    for name in cases.columns
]
cases.select(null_counts).show()


+--------+--------+----+-----+----------------+---------+---+----+
| case_id|province|city|group|infection_source|confirmed|lat|long|
+--------+--------+----+-----+----------------+---------+---+----+
|       0|       0|  53|    0|               0|        0|109| 109|
+--------+--------+----+-----+----------------+---------+---+----+



## Some more advanced processing: Windows

Window functions allow us to create new columns based on groups of values, avoiding for loops over data groups.

In [68]:
# Total confirmed cases of the province, attached to every row of that province
w_sum = Window.partitionBy("province")
province_total = F.sum("confirmed").over(w_sum)

cases = cases.withColumn("tot_cases_by_prov", province_total)
cases.show(15)

+--------+-----------------+---------------+-----+--------------------+---------+---------+----------+-----------------+
| case_id|         province|           city|group|    infection_source|confirmed|      lat|      long|tot_cases_by_prov|
+--------+-----------------+---------------+-----+--------------------+---------+---------+----------+-----------------+
| 1700001|           Sejong|         Sejong| true|Ministry of Ocean...|       31|36.504713|127.265172|               49|
| 1700002|           Sejong|         Sejong| true|gym facility in S...|        8| 36.48025|   127.289|               49|
| 1700003|           Sejong|from other city| true|  Shincheonji Church|        1|     null|      null|               49|
| 1700004|           Sejong|           null|false|     overseas inflow|        5|     null|      null|               49|
| 1700005|           Sejong|           null|false|contact with patient|        3|     null|      null|               49|
| 1700006|           Sejong|    

In [69]:
# Rank the rows of each province by confirmed cases, highest first
w_rank = Window.partitionBy("province").orderBy(F.col("confirmed").desc())
rank_in_province = F.rank().over(w_rank)

cases_ranked = (
    cases
    .withColumn("rank_cases_by_prov", rank_in_province)
    .drop("lat", "long")
)

cases_ranked.show(10)

+--------+--------+---------------+-----+--------------------+---------+-----------------+------------------+
| case_id|province|           city|group|    infection_source|confirmed|tot_cases_by_prov|rank_cases_by_prov|
+--------+--------+---------------+-----+--------------------+---------+-----------------+------------------+
| 1700001|  Sejong|         Sejong| true|Ministry of Ocean...|       31|               49|                 1|
| 1700002|  Sejong|         Sejong| true|gym facility in S...|        8|               49|                 2|
| 1700004|  Sejong|           null|false|     overseas inflow|        5|               49|                 3|
| 1700005|  Sejong|           null|false|contact with patient|        3|               49|                 4|
| 1700003|  Sejong|from other city| true|  Shincheonji Church|        1|               49|                 5|
| 1700006|  Sejong|           null|false|                 etc|        1|               49|                 5|
| 1600002|

In [70]:
# Rank-1 row(s) of each province; rows with a null city or confirmed value are dropped
top_ranked = cases_ranked.where(F.col("rank_cases_by_prov") == 1)
top_ranked.select("city", "confirmed").dropna().show()

+---------------+---------+
|           city|confirmed|
+---------------+---------+
|         Sejong|       31|
|from other city|       17|
|from other city|      566|
|         Nam-gu|     4511|
|from other city|       32|
|     Dongnae-gu|       39|
|     Cheonan-si|      103|
+---------------+---------+



In [71]:
# Earliest and latest date in the dataset, computed with a single DataFrame
# aggregation instead of two separate passes over the underlying RDD.
# An empty DataFrame yields None for both values instead of raising.
date_range = time_province.agg(
    F.min("date").alias("start"),
    F.max("date").alias("end")
).first()
start, end = date_range["start"], date_range["end"]

print("Dataset time range: {} - {}".format(start, end))

Dataset time range: 2020-01-20 - 2020-06-30


In [72]:
# Value of `confirmed` 7 rows earlier within each province, rows ordered by date
w_lag = Window.partitionBy("province").orderBy("date")
confirmed_7_rows_back = F.lag("confirmed", 7).over(w_lag)

time_province = time_province.withColumn("lag_7", confirmed_7_rows_back)
time_province.filter(F.col("date") > '2020-03-10').show()

+----------+----+--------+---------+--------+--------+-----+
|      date|time|province|confirmed|released|deceased|lag_7|
+----------+----+--------+---------+--------+--------+-----+
|2020-03-11|   0|  Sejong|       10|       0|       0|    1|
|2020-03-12|   0|  Sejong|       15|       0|       0|    1|
|2020-03-13|   0|  Sejong|       32|       0|       0|    1|
|2020-03-14|   0|  Sejong|       38|       0|       0|    2|
|2020-03-15|   0|  Sejong|       39|       0|       0|    3|
|2020-03-16|   0|  Sejong|       40|       0|       0|    6|
|2020-03-17|   0|  Sejong|       40|       0|       0|    8|
|2020-03-18|   0|  Sejong|       41|       0|       0|   10|
|2020-03-19|   0|  Sejong|       41|       0|       0|   15|
|2020-03-20|   0|  Sejong|       41|       0|       0|   32|
|2020-03-21|   0|  Sejong|       41|       2|       0|   38|
|2020-03-22|   0|  Sejong|       41|       3|       0|   39|
|2020-03-23|   0|  Sejong|       42|       3|       0|   40|
|2020-03-24|   0|  Sejon

In [73]:
# Rolling mean of `confirmed` over a 7-row frame per province, rows ordered by date
# include current row: rowsBetween(-6,0)
# exclude current row: rowsBetween(-7,-1)
# NOTE(review): the frame counts rows, not calendar days; it equals "the last
# 7 days" only if every province has exactly one row per date - confirm.

w_roll = Window.partitionBy("province").orderBy("date").rowsBetween(-6, 0)
rolling_mean = F.round(F.mean("confirmed").over(w_roll), 2)

time_province = time_province.withColumn("roll_7_confirmed", rolling_mean)
time_province.filter(F.col("date") > '2020-03-10').show(10)

+----------+----+--------+---------+--------+--------+-----+----------------+
|      date|time|province|confirmed|released|deceased|lag_7|roll_7_confirmed|
+----------+----+--------+---------+--------+--------+-----+----------------+
|2020-03-11|   0|  Sejong|       10|       0|       0|    1|            4.43|
|2020-03-12|   0|  Sejong|       15|       0|       0|    1|            6.43|
|2020-03-13|   0|  Sejong|       32|       0|       0|    1|           10.86|
|2020-03-14|   0|  Sejong|       38|       0|       0|    2|            16.0|
|2020-03-15|   0|  Sejong|       39|       0|       0|    3|           21.14|
|2020-03-16|   0|  Sejong|       40|       0|       0|    6|            26.0|
|2020-03-17|   0|  Sejong|       40|       0|       0|    8|           30.57|
|2020-03-18|   0|  Sejong|       41|       0|       0|   10|            35.0|
|2020-03-19|   0|  Sejong|       41|       0|       0|   15|           38.71|
|2020-03-20|   0|  Sejong|       41|       0|       0|   32|    

In [74]:
# check the schema after adding the window columns
time_province.printSchema()

root
 |-- date: string (nullable = true)
 |-- time: integer (nullable = true)
 |-- province: string (nullable = true)
 |-- confirmed: integer (nullable = true)
 |-- released: integer (nullable = true)
 |-- deceased: integer (nullable = true)
 |-- lag_7: integer (nullable = true)
 |-- roll_7_confirmed: double (nullable = true)



## Some more advanced processing: UDFs

Sometimes we want to do complicated things with one or more columns. While Spark SQL functions solve many column-creation use cases, we can create Spark UDFs when we need more mature or custom Python functionality.

In [75]:
# persist() is lazy: mark the current `cases` DataFrame for caching first,
# then run an action so the cache is actually filled
cases = cases.persist()
cases.count()

In [76]:
def get_confirmed_level(confirmed):
    """
    Classify a confirmed-cases count.

    Returns "low" when `confirmed` is below 50 and "high" when it is
    50 or more. A missing value (None, i.e. a null cell) yields None
    instead of raising a TypeError on the comparison.
    """
    if confirmed is None:
        return None
    return 'low' if confirmed < 50 else 'high'
    
# Wrap the Python function as a Spark UDF, declaring the type of its return value
confirmed_udf = F.udf(get_confirmed_level, returnType=StringType())

confirmed_level = confirmed_udf(F.col("confirmed"))
cases = cases.withColumn("confirmed_level", confirmed_level)
cases.groupBy("confirmed_level").count().show()

+---------------+-----+
|confirmed_level|count|
+---------------+-----+
|            low|  143|
|           high|   31|
+---------------+-----+



In [77]:
def year_month(date):
    """
    Encode a datetime.date object as an integer in yyyymm format
    (e.g. 2020-03-15 -> 202003). Returns None when `date` is None.
    """
    if date is None:
        return None
    # the month is zero-padded to two digits and appended to the year
    return int("{}{:02d}".format(date.year, date.month))

        
year_month_udf = F.udf(year_month, returnType=IntegerType())

# Convert the `date` column with to_date, then derive the yyyymm key from it
time_province = (
    time_province
    .withColumn("date", F.to_date(F.col("date")))
    .withColumn("year_month", year_month_udf(F.col("date")))
)

# Sum of released and deceased per month
# NOTE(review): if these columns hold cumulative daily counts, summing them
# over a month counts the same people repeatedly - confirm the intended metric.
monthly_totals = time_province.groupBy("year_month").agg(
    F.sum("released").alias("released"),
    F.sum("deceased").alias("deceased"),
)
monthly_totals.orderBy("year_month").show()

+----------+--------+--------+
|year_month|released|deceased|
+----------+--------+--------+
|    202001|       0|       0|
|    202002|     311|      85|
|    202003|   57024|    2587|
|    202004|  226480|    6525|
|    202005|  294969|    8081|
|    202006|  309949|    8326|
+----------+--------+--------+



# Sort Merge Join

A Sort Merge Join enables an all-to-all communication strategy among the nodes: the Driver Node will orchestrate the Executors, each of which will hold a particular set of joining keys. Before running the actual operation, the partitions are first sorted (this operation is obviously heavy itself). As you can imagine this kind of strategy can be expensive: nodes need to use the network to share data.

In [78]:
# The "merge" hint asks Spark to use a sort merge join; without it a small
# `regions` table may be broadcast automatically.
cases_with_region = cases.join(regions.hint("merge"), ['province','city'], how='left')

# persist() is lazy: mark the result for caching first, so the count() action
# below fills the cache
cases_with_region = cases_with_region.persist()
print("Join records {}".format(cases_with_region.count()))

cases_with_region.printSchema()

Join records 174
root
 |-- province: string (nullable = true)
 |-- city: string (nullable = true)
 |--  case_id: integer (nullable = true)
 |-- group: boolean (nullable = true)
 |-- infection_source: string (nullable = true)
 |-- confirmed: integer (nullable = true)
 |-- lat: string (nullable = true)
 |-- long: string (nullable = true)
 |-- tot_cases_by_prov: long (nullable = true)
 |-- confirmed_level: string (nullable = true)
 |-- code: integer (nullable = true)
 |-- latitude: double (nullable = true)
 |-- longitude: double (nullable = true)
 |-- elementary_school_count: integer (nullable = true)
 |-- kindergarten_count: integer (nullable = true)
 |-- university_count: integer (nullable = true)
 |-- academy_ratio: double (nullable = true)
 |-- elderly_population_ratio: double (nullable = true)
 |-- elderly_alone_ratio: double (nullable = true)
 |-- nursing_home_count: integer (nullable = true)



## Broadcast / Map Side Joins

Use a broadcast join when you face a scenario where you need to join a very big table (about 1B rows) with a very small table (about 100–200 rows). In this type of join, you broadcast the small table to each machine/node when you perform the join with the big table. The broadcasting operation is itself quite expensive (it means that all the nodes need to receive a copy of the table), so it’s not surprising that if we increase the number of executors that need to receive the table, we increase the broadcasting cost. If we have more executors available, a sort merge join may be more efficient.

In [79]:
# broadcast() marks `regions` to be copied to every executor for the join
cases_reg_broad = cases.join(broadcast(regions), ['province','city'], how='left')

# persist() is lazy: mark the result for caching, then count the broadcast-join
# result itself so the action fills its cache
cases_reg_broad = cases_reg_broad.persist()
print("Join records {}".format(cases_reg_broad.count()))

cases_reg_broad.printSchema()

Join records 174
root
 |-- province: string (nullable = true)
 |-- city: string (nullable = true)
 |--  case_id: integer (nullable = true)
 |-- group: boolean (nullable = true)
 |-- infection_source: string (nullable = true)
 |-- confirmed: integer (nullable = true)
 |-- lat: string (nullable = true)
 |-- long: string (nullable = true)
 |-- tot_cases_by_prov: long (nullable = true)
 |-- confirmed_level: string (nullable = true)
 |-- code: integer (nullable = true)
 |-- latitude: double (nullable = true)
 |-- longitude: double (nullable = true)
 |-- elementary_school_count: integer (nullable = true)
 |-- kindergarten_count: integer (nullable = true)
 |-- university_count: integer (nullable = true)
 |-- academy_ratio: double (nullable = true)
 |-- elderly_population_ratio: double (nullable = true)
 |-- elderly_alone_ratio: double (nullable = true)
 |-- nursing_home_count: integer (nullable = true)



# Writing

In [None]:
# Convert the joined data to a pandas DataFrame, then save it as a csv file
cases_with_region_pd = cases_with_region.toPandas()
cases_with_region_pd.to_csv("saved_cases.csv")

In [None]:
# alternatively, save as parquet file

# First remove the blank space from the " case_id" column name to avoid column
# writing errors. The joined DataFrame is the one written, matching the output
# file name; a separate variable keeps cases_with_region itself untouched.
cases_with_region_out = cases_with_region.withColumnRenamed(" case_id", "case_id")
cases_with_region_out.coalesce(1).write.format("parquet").mode("overwrite").save("cases_with_region.parquet")

In [None]:
# release the cached data once it has been saved
# NOTE(review): later cells still use cases_with_region after this call
cases_with_region.unpersist()

# Coalesce / Repartition

With too few partitions you will not utilize all of the cores available in the cluster.

With too many partitions there will be excessive overhead in managing many small tasks.

Between the two, the first one is far more impactful on performance. Scheduling too many small tasks has a relatively small impact for partition counts below 1000. If you have on the order of tens of thousands of partitions, then Spark gets very slow.

Have your number of partitions set to 3 or 4 times the number of CPU cores in your cluster so that the work gets distributed more evenly among the available CPU cores.

In [80]:
# get the number of partitions of the RDD underlying the data frame
cases_with_region.rdd.getNumPartitions()

200

In [81]:
# check out the distribution of records in a partition by using the glom function
#cases_with_region.rdd.glom().map(len).collect()

In [82]:
# coalesce partitions: can only reduce the number of partitions (here down to 4)

cases_with_region = cases_with_region.coalesce(4)
cases_with_region.rdd.getNumPartitions()

4

In [83]:
# repartition by cores: increase or decrease the number of partitions
cores = 8
n = 3  # partitions per core
partitions = cores*n

# pass the integer itself: the string "partitions" would be read as a column name
cases_with_region = cases_with_region.repartition(partitions)
cases_with_region.rdd.getNumPartitions()

24

In [84]:
# repartition by column: rows are distributed according to the value of `province`.
# The partition count is NOT the number of distinct provinces - the output below
# shows 200, which matches the default of spark.sql.shuffle.partitions.

cases_with_region = cases_with_region.repartition("province")
cases_with_region.rdd.getNumPartitions()

200