# Tutorial: introduction to data analysis with PySpark

In [1]:
import findspark

In [2]:
# Raw string: the Windows path contains backslashes ("\S", "\s") that would
# otherwise be read as invalid escape sequences.
SPARK_HOME = r"C:\Spark\spark-3.0.1-bin-hadoop2.7"
findspark.init(SPARK_HOME)

In [3]:
import pyspark
from pyspark import SparkContext
from pyspark.sql import functions as F
from pyspark.sql import SparkSession
from pyspark.sql import SQLContext

from pyspark.sql.functions import broadcast
from pyspark.sql.types import StringType, IntegerType
from pyspark.sql.window import Window

# Configuration & Initialization

* SparkContext — provides connection to Spark with the ability to create RDDs
* SQLContext — provides connection to Spark with the ability to run SQL queries on data
* SparkSession — all-encompassing context which covers SparkContext, SQLContext and HiveContext. SparkSession is the entry point for programming Spark applications. It lets you interact with the Dataset and DataFrame APIs provided by Spark. We set the application name by calling appName. The getOrCreate() method returns the existing SparkSession if there is one, and creates a new one otherwise.

In [4]:
# create the Spark session for this tutorial, or reuse it if it already exists
spark = SparkSession.builder.appName("data analysis with PySpark").getOrCreate()

In [5]:
# get the SparkContext, which lets the Spark application access the Spark
# cluster with the help of a resource manager (usually YARN or Mesos);
# getOrCreate() reuses the running context if one already exists
sc = SparkContext.getOrCreate()

In [6]:
# create a SQLContext instance to access the SQL query engine built on top of Spark.
# SQLContext takes a SparkContext as its first argument; the session goes
# second so the context wraps the session created above
sqlContext = SQLContext(spark.sparkContext, spark)

# Reading

We can start by loading the files in our dataset using the spark.read.load command. This command reads parquet files, which is the default file format for spark, but you can add the parameter format to read other formats.

In [None]:
# covid cases dataset: comma-separated CSV with a header row, column types inferred
csv_options = dict(format="csv", sep=",", inferSchema="true", header="true")
cases = spark.read.load("data/Case.csv", **csv_options)
cases.show(6)

In [None]:
# cases timing dataset: comma-separated CSV with a header row, column types inferred
csv_options = dict(format="csv", sep=",", inferSchema="true", header="true")
time_province = spark.read.load("data/TimeProvince.csv", **csv_options)
time_province.show(6)

In [None]:
# regions dataset: comma-separated CSV with a header row, column types inferred
csv_options = dict(format="csv", sep=",", inferSchema="true", header="true")
regions = spark.read.load("data/Region.csv", **csv_options)
regions.show(6)

In [None]:
# alternative reading options
#df = spark.read.csv("Case.csv", header=True, inferSchema=True)
#df = spark.read.format('csv').options(header=True,inferSchema=True).load("Case.csv")

# for other file formats:
#df = spark.read.text(path_to_file)
#df = spark.read.json(path_to_file)
#df = spark.read.parquet(path_to_file)

## Overview

In [None]:
# show the first 10 rows without truncating long cell values
cases.show(10, truncate=False)

In [None]:
# list the column names
cases.columns

In [None]:
# print the data schema: column names and their data types
cases.printSchema()

# Basic Functions

Get a descriptive overview of data fields

In [None]:
# summary statistics for a subset of the columns
columns_to_describe = ["province", "city", "infection_case", "confirmed", "latitude", "longitude"]
describe_cases = cases.select(*columns_to_describe)
describe_cases.describe().show()

We can filter a DataFrame on several conditions at once by combining them with AND (&), OR (|) and NOT (~)

In [None]:
# filter rows: keep only the cases whose province is Seoul, then count them
is_seoul = cases.province == "Seoul"
seoul_cases = cases.filter(is_seoul)
seoul_cases.count()

# alternative: where() is an alias of filter()
#seoul_cases = cases.where(F.col("province")=="Seoul")
#seoul_cases.count()

In [None]:
# combine two conditions with &; each condition needs its own parentheses
over_100_in_daegu = (F.col("confirmed") > 100) & (F.col("province") == 'Daegu')
cases.filter(over_100_in_daegu).show()

We can select a subset of columns using the select keyword

In [None]:
# keep only the province and city columns
cases.select("province", "city").show(5)

We can change a single column name

In [None]:
# rename columns: existing name -> new name
renamed_columns = {
    "latitude": "lat",
    "longitude": "long",
    "infection_case": "infection_source",
}
for old_name, new_name in renamed_columns.items():
    cases = cases.withColumnRenamed(old_name, new_name)

cases.show(10)

We can sort data in increasing (default) or decreasing order

In [None]:
# sorting: Column.asc() gives ascending order, Column.desc() descending order

cases.sort(F.col("confirmed").desc()).show()

# alternative
#cases.orderBy(F.desc("confirmed")).show()

We can use groupBy function with a spark DataFrame, too. Pretty much same as the pandas groupBy with the exception that you will need to import pyspark.sql.functions

In [None]:
# grouping operations: number of rows per province, largest groups first

cases.groupBy("province").count().orderBy(F.desc("count")).show()

If you don’t like the new grouped column names, you can use the alias keyword to rename columns in the agg command itself

In [None]:
# grouping operations: total and max confirmed for each (province, city) pair

confirmed_stats = [
    F.sum("confirmed").alias("tot confirmed"),
    F.max("confirmed").alias("max confirmed"),
]
cases.groupBy("province", "city").agg(*confirmed_stats).show(10)

We can use .withColumn along with PySpark SQL functions (like when/otherwise) to create a new column. In essence, you can find String functions, Date functions, and Math functions already implemented using Spark functions

In [None]:
# replace the "-" placeholder with null in every column of cases
# if else: when / otherwise
# A single select builds all columns in one projection; calling withColumn in
# a loop adds one projection per column and inflates the query plan.

cases = cases.select(
    [F.when(F.col(c) == "-", None).otherwise(F.col(c)).alias(c) for c in cases.columns]
)

cases.select("city").distinct().show()

## Cache / Persist

Spark works on the lazy execution principle. What that means is that nothing really gets executed until you use an action function like the .count() on a dataframe. And if you do a .count function, it generally helps to cache at this step. So you might want to cache() or persist() your dataframes when you do a .count() operation.

In [None]:
# persist data in memory
# persist() is lazy: mark the DataFrame first, then run an action (count) so
# that the action actually fills the cache
cases = cases.persist()
cases.count()
# alternative
#cases.persist().count()

# alternative
#cases.cache().count()
#cases = cases.cache()
#cases.count()

In [None]:
# to unpersist:
#cases.unpersist()

In [None]:
# count nulls by column: when() yields a non-null value only for null cells,
# and count() counts only the non-null values it receives

null_counts = [
    F.count(F.when(F.col(column_name).isNull(), column_name)).alias(column_name)
    for column_name in cases.columns
]
cases.select(null_counts).show()


## Some more advanced processing: Windows

Window functions let us create new columns computed over groups of rows, without writing explicit for loops over the groups.

In [None]:
# summing: total confirmed per province, attached to every row of that province
w_sum = Window.partitionBy("province")
province_total = F.sum("confirmed").over(w_sum)

cases = cases.withColumn("tot_cases_by_prov", province_total)
cases.show(15)

In [None]:
# ranking: rank the rows of each province by confirmed cases, highest first
w_rank = Window.partitionBy('province').orderBy(F.col('confirmed').desc())
rank_in_province = F.rank().over(w_rank)

cases_ranked = cases.withColumn("rank_cases_by_prov", rank_in_province).drop("lat", "long")

cases_ranked.show(10)

In [None]:
# rank-1 row(s) of each province: keep city and confirmed, drop rows containing nulls
top_ranked = cases_ranked.filter(F.col("rank_cases_by_prov") == 1)
top_ranked.select("city", "confirmed").dropna().show()

In [None]:
# first and last date in the dataset, computed in one aggregation job
# (avoids converting the DataFrame to an RDD and scanning it twice)
start, end = time_province.agg(F.min("date"), F.max("date")).first()

print("Dataset time range: {} - {}".format(start, end))

In [None]:
# lagging: value of confirmed 7 rows earlier within the same province, ordered by date
w_lag = Window.partitionBy('province').orderBy('date')
confirmed_7_rows_back = F.lag("confirmed", 7).over(w_lag)

time_province = time_province.withColumn("lag_7", confirmed_7_rows_back)
time_province.filter(F.col("date") > '2020-03-10').show()

In [None]:
# rolling aggregation (mean) over the last 7 days
# include current day: rowsBetween(-6,0)
# exclude current day: rowsBetween(-7,-1)

w_roll = Window.partitionBy('province').orderBy('date').rowsBetween(-6, 0)
rolling_mean = F.round(F.mean("confirmed").over(w_roll), 2)

time_province = time_province.withColumn("roll_7_confirmed", rolling_mean)
time_province.filter(F.col("date") > '2020-03-10').show(10)

In [None]:
# print the schema, which now includes the lag_7 and roll_7_confirmed columns
time_province.printSchema()

## Some more advanced processing: UDFs

Sometimes we want to do complicated things with one or more columns. While Spark SQL functions cover many column-creation use cases, we can write Spark UDFs (user-defined functions) when we need the full flexibility of plain Python code.

In [None]:
# persist() is lazy: mark the DataFrame first, then run an action (count) so
# that the action actually fills the cache
cases = cases.persist()
cases.count()

In [None]:
def get_confirmed_level(confirmed):
    """
    Classify a confirmed-cases count as "low" or "high".

    Returns "low" when confirmed is below 50 and "high" when it is 50 or
    more. A missing value (None) has no level, so None is returned instead
    of raising TypeError on the comparison.
    """
    if confirmed is None:
        return None
    return 'low' if confirmed < 50 else 'high'
    
# wrap the Python function as a Spark UDF, declaring the type it returns
confirmed_udf = F.udf(get_confirmed_level, returnType=StringType())

cases = cases.withColumn("confirmed_level", confirmed_udf("confirmed"))
cases.groupBy("confirmed_level").count().show()

In [None]:
def year_month(date):
    """
    Encode a datetime.date object as the integer yyyymm
    (for example March 2020 becomes 202003).
    Returns None when no date is given.
    """
    if date is None:
        return None
    return date.year * 100 + date.month

        
year_month_udf = F.udf(year_month, returnType=IntegerType())

# parse the date column, then derive the yyyymm key with the UDF
time_province = (
    time_province
    .withColumn("date", F.to_date("date"))
    .withColumn("year_month", year_month_udf("date"))
)

# sums of released and deceased per year_month, in chronological order
monthly_totals = time_province.groupBy("year_month").agg(
    F.sum("released").alias("released"),
    F.sum("deceased").alias("deceased"),
)
monthly_totals.orderBy("year_month").show()

# Sort Merge Join

A Sort Merge Join enables an all-to-all communication strategy among the nodes: the Driver Node will orchestrate the Executors, each of which will hold a particular set of joining keys. Before running the actual operation, the partitions are first sorted (this operation is obviously heavy itself). As you can imagine this kind of strategy can be expensive: nodes need to use the network to share data.

In [None]:
# left join: keep every case row and attach the region columns where
# (province, city) matches
cases_with_region = cases.join(regions, ['province','city'], how='left')
# persist before counting so that this first action fills the cache
cases_with_region.persist()
print("Join records {}".format(cases_with_region.count()))

cases_with_region.printSchema()

## Broadcast/Map Side Joins

Use a broadcast join when you need to join a very big table (about 1B rows) with a very small table (about 100–200 rows). In this type of join, you broadcast the small table to each machine/node when you perform the join with the big table. The broadcast operation is itself quite expensive (all the nodes need to receive a copy of the table), so it’s not surprising that increasing the number of executors that need to receive the table increases the broadcasting cost. If we have more executors available, a sort merge join may be more efficient.

In [None]:
# same left join, with a broadcast hint on the small regions table
cases_reg_broad = cases.join(broadcast(regions), ['province','city'], how='left')
# persist before counting so that this first action fills the cache
cases_reg_broad.persist()
# count the rows of the broadcast-join result itself
print("Join records {}".format(cases_reg_broad.count()))

cases_reg_broad.printSchema()

# Writing

In [None]:
# remove blank space from column name to avoid writing errors
#cases_with_region = cases_with_region.withColumnRenamed(" case_id", "case_id")

# convert to a pandas DataFrame (this collects all rows to the driver) and save as a csv file
cases_with_region.toPandas().to_csv("saved_cases.csv")

In [None]:
import sys
import os

os.environ['HADOOP_HOME'] = "C:/hadoop"
sys.path.append("C:/hadoop/bin")

In [None]:
# alternatively, save as parquet
cases = cases.withColumnRenamed(" case_id", "case_id").persist()
cases.coalesce(1).write.format("parquet").mode("overwrite").save("cases_with_region.parquet")

In [None]:
# release the cached joined DataFrame now that it has been saved
cases_with_region.unpersist()

# Coalesce / Repartition

With too few partitions, you will not utilize all of the cores available in the cluster.

With too many partitions, there will be excessive overhead in managing many small tasks.

Of the two, the first has by far the larger impact on performance. Scheduling too many small tasks has a relatively small impact for partition counts below 1000, but with tens of thousands of partitions Spark gets very slow.

Have your number of partitions set to 3 or 4 times the number of CPU cores in your cluster so that the work gets distributed more evenly among the available CPU cores.

In [None]:
# get the number of partitions of the DataFrame (read from its underlying RDD)
cases_with_region.rdd.getNumPartitions()

In [None]:
# check out the distribution of records in a partition by using the glom function
#cases_with_region.rdd.glom().map(len).collect()

In [None]:
# coalesce: bring the DataFrame down to 4 partitions, then check the result

target_partitions = 4
cases_with_region = cases_with_region.coalesce(target_partitions)
cases_with_region.rdd.getNumPartitions()

In [None]:
# repartition by cores: use n partitions for each CPU core
cores = 8  # hard-coded core count; adjust to the cluster actually in use
n = 3  # partitions per core
partitions = cores*n

# redistribute the rows into the requested number of partitions and check the result
cases_with_region = cases_with_region.repartition(partitions)
cases_with_region.rdd.getNumPartitions()

In [None]:
# repartition by column: partition the rows on the value of province,
# then check how many partitions result

cases_with_region = cases_with_region.repartition("province")
cases_with_region.rdd.getNumPartitions()