# <center> spark-advanced-with-covid-19-example </center>

## Below capabilities will be demoed in this section
* read csv data file
* select & filter
* change column name
* join
* orderBy
* distinct
* aggregation
* alias
* head
* Expr
* UDF - User Defined Function

In [1]:
import findspark
import os

findspark.init(os.environ['SPARK_HOME'])

In [58]:
from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F
import pyspark.sql.types as T

In [3]:
spark = SparkSession.builder.appName('covid-19').getOrCreate()

In [114]:
df = spark.read.csv('./data/time-series-19-covid-combined.csv', header=True, sep=',', inferSchema=True)

In [31]:
print(df.printSchema())
df.show(5)

root
 |-- Date: timestamp (nullable = true)
 |-- Country/Region: string (nullable = true)
 |-- Province/State: string (nullable = true)
 |-- Lat: double (nullable = true)
 |-- Long: double (nullable = true)
 |-- Confirmed: integer (nullable = true)
 |-- Recovered: integer (nullable = true)
 |-- Deaths: integer (nullable = true)

None
+-------------------+--------------+--------------+----+----+---------+---------+------+
|               Date|Country/Region|Province/State| Lat|Long|Confirmed|Recovered|Deaths|
+-------------------+--------------+--------------+----+----+---------+---------+------+
|2020-01-22 00:00:00|   Afghanistan|          null|33.0|65.0|        0|        0|     0|
|2020-01-23 00:00:00|   Afghanistan|          null|33.0|65.0|        0|        0|     0|
|2020-01-24 00:00:00|   Afghanistan|          null|33.0|65.0|        0|        0|     0|
|2020-01-25 00:00:00|   Afghanistan|          null|33.0|65.0|        0|        0|     0|
|2020-01-26 00:00:00|   Afghanistan|     

In [115]:
df = df.withColumnRenamed('Country/Region', 'CountryRegion') \
       .withColumnRenamed('Province/State', 'ProvinceState')
df.columns

['Date',
 'CountryRegion',
 'ProvinceState',
 'Lat',
 'Long',
 'Confirmed',
 'Recovered',
 'Deaths']

In [116]:
df = df.withColumn('Date', F.to_date(F.col('Date')))
df.printSchema()

root
 |-- Date: date (nullable = true)
 |-- CountryRegion: string (nullable = true)
 |-- ProvinceState: string (nullable = true)
 |-- Lat: double (nullable = true)
 |-- Long: double (nullable = true)
 |-- Confirmed: integer (nullable = true)
 |-- Recovered: integer (nullable = true)
 |-- Deaths: integer (nullable = true)



In [117]:
df = df.cache()

In [118]:
df.count()

22176

In [119]:
df.filter(F.col('Confirmed') > 0).show()

+----------+-------------+-------------+----+----+---------+---------+------+
|      Date|CountryRegion|ProvinceState| Lat|Long|Confirmed|Recovered|Deaths|
+----------+-------------+-------------+----+----+---------+---------+------+
|2020-02-24|  Afghanistan|         null|33.0|65.0|        1|        0|     0|
|2020-02-25|  Afghanistan|         null|33.0|65.0|        1|        0|     0|
|2020-02-26|  Afghanistan|         null|33.0|65.0|        1|        0|     0|
|2020-02-27|  Afghanistan|         null|33.0|65.0|        1|        0|     0|
|2020-02-28|  Afghanistan|         null|33.0|65.0|        1|        0|     0|
|2020-02-29|  Afghanistan|         null|33.0|65.0|        1|        0|     0|
|2020-03-01|  Afghanistan|         null|33.0|65.0|        1|        0|     0|
|2020-03-02|  Afghanistan|         null|33.0|65.0|        1|        0|     0|
|2020-03-03|  Afghanistan|         null|33.0|65.0|        1|        0|     0|
|2020-03-04|  Afghanistan|         null|33.0|65.0|        1|    

In [120]:
df.where("Confirmed > 0 AND CountryRegion = 'Australia'").show()

+----------+-------------+--------------------+--------+--------+---------+---------+------+
|      Date|CountryRegion|       ProvinceState|     Lat|    Long|Confirmed|Recovered|Deaths|
+----------+-------------+--------------------+--------+--------+---------+---------+------+
|2020-03-13|    Australia|Australian Capita...|-35.4735|149.0124|        1|        0|     0|
|2020-03-14|    Australia|Australian Capita...|-35.4735|149.0124|        1|        0|     0|
|2020-03-15|    Australia|Australian Capita...|-35.4735|149.0124|        1|        0|     0|
|2020-03-16|    Australia|Australian Capita...|-35.4735|149.0124|        2|        0|     0|
|2020-03-17|    Australia|Australian Capita...|-35.4735|149.0124|        2|        0|     0|
|2020-03-18|    Australia|Australian Capita...|-35.4735|149.0124|        3|        0|     0|
|2020-03-19|    Australia|Australian Capita...|-35.4735|149.0124|        4|        0|     0|
|2020-03-20|    Australia|Australian Capita...|-35.4735|149.0124|     

In [121]:
df.where("CountryRegion = 'China'").sort([F.desc('Date'), 'ProvinceState']).show()
# df.where("CountryRegion = 'China'").orderBy([F.col('Date').desc(), F.col('ProvinceState')]).show()
# df.where("CountryRegion = 'China'").orderBy([F.col('Date'), F.col('ProvinceState')], ascending=[False, True]).show()

+----------+-------------+--------------+-------+--------+---------+---------+------+
|      Date|CountryRegion| ProvinceState|    Lat|    Long|Confirmed|Recovered|Deaths|
+----------+-------------+--------------+-------+--------+---------+---------+------+
|2020-04-14|        China|         Anhui|31.8257|117.2264|      991|      984|     6|
|2020-04-14|        China|       Beijing|40.1824|116.4142|      589|      491|     8|
|2020-04-14|        China|     Chongqing|30.0572| 107.874|      579|      570|     6|
|2020-04-14|        China|        Fujian|26.0789|117.9874|      353|      329|     1|
|2020-04-14|        China|         Gansu|37.8099|101.0583|      139|      136|     2|
|2020-04-14|        China|     Guangdong|23.3417|113.4244|     1564|     1458|     8|
|2020-04-14|        China|       Guangxi|23.8298|108.7881|      254|      252|     2|
|2020-04-14|        China|       Guizhou|26.8154|106.8748|      146|      144|     2|
|2020-04-14|        China|        Hainan|19.1959|109.7

In [122]:
df.filter(F.col('CountryRegion') == 'China').where("Date = '2020-03-01 00:00:00'").show()

+----------+-------------+--------------+-------+--------+---------+---------+------+
|      Date|CountryRegion| ProvinceState|    Lat|    Long|Confirmed|Recovered|Deaths|
+----------+-------------+--------------+-------+--------+---------+---------+------+
|2020-03-01|        China|         Anhui|31.8257|117.2264|      990|      873|     6|
|2020-03-01|        China|       Beijing|40.1824|116.4142|      413|      276|     8|
|2020-03-01|        China|     Chongqing|30.0572| 107.874|      576|      450|     6|
|2020-03-01|        China|        Fujian|26.0789|117.9874|      296|      247|     1|
|2020-03-01|        China|         Gansu|37.8099|101.0583|       91|       84|     2|
|2020-03-01|        China|     Guangdong|23.3417|113.4244|     1349|     1016|     7|
|2020-03-01|        China|       Guangxi|23.8298|108.7881|      252|      181|     2|
|2020-03-01|        China|       Guizhou|26.8154|106.8748|      146|      112|     2|
|2020-03-01|        China|        Hainan|19.1959|109.7

In [123]:
df.where("CountryRegion = 'China'").select('Date').dropDuplicates().count()

84

In [124]:
df.where("CountryRegion = 'China'") \
  .groupBy('Date').agg(F.sum(F.col('Confirmed')).alias('sumConfirmed')) \
  .sort(F.desc('Date')).show()

+----------+------------+
|      Date|sumConfirmed|
+----------+------------+
|2020-04-14|       83306|
|2020-04-13|       83213|
|2020-04-12|       83134|
|2020-04-11|       83014|
|2020-04-10|       82941|
|2020-04-09|       82883|
|2020-04-08|       82809|
|2020-04-07|       82718|
|2020-04-06|       82665|
|2020-04-05|       82602|
|2020-04-04|       82543|
|2020-04-03|       82511|
|2020-04-02|       82432|
|2020-04-01|       82361|
|2020-03-31|       82279|
|2020-03-30|       82198|
|2020-03-29|       82122|
|2020-03-28|       81999|
|2020-03-27|       81897|
|2020-03-26|       81782|
+----------+------------+
only showing top 20 rows



### which date has the biggest death toll in China?

In [43]:
df.printSchema()

root
 |-- Date: date (nullable = true)
 |-- CountryRegion: string (nullable = true)
 |-- ProvinceState: string (nullable = true)
 |-- Lat: double (nullable = true)
 |-- Long: double (nullable = true)
 |-- Confirmed: integer (nullable = true)
 |-- Recovered: integer (nullable = true)
 |-- Deaths: integer (nullable = true)



In [141]:
total_deaths = df.where("CountryRegion = 'China'") \
                 .groupBy('Date').agg(F.sum(F.col('Deaths')).alias('sumDeaths')) \
                 .sort([F.desc('sumDeaths'), 'Date'])
total_deaths.show(5)

+----------+---------+
|      Date|sumDeaths|
+----------+---------+
|2020-04-13|     3345|
|2020-04-14|     3345|
|2020-04-11|     3343|
|2020-04-12|     3343|
|2020-04-10|     3340|
+----------+---------+
only showing top 5 rows



In [142]:
total_deaths_diff = total_deaths.withColumnRenamed('sumDeaths', 'total_deaths_today') \
                                .sort(F.desc('Date'))
total_deaths_diff.show(5)

+----------+------------------+
|      Date|total_deaths_today|
+----------+------------------+
|2020-04-14|              3345|
|2020-04-13|              3345|
|2020-04-12|              3343|
|2020-04-11|              3343|
|2020-04-10|              3340|
+----------+------------------+
only showing top 5 rows



In [143]:
total_deaths_diff = total_deaths_diff.withColumn(
    'total_deaths_yesterday', 
    F.lead(F.col('total_deaths_today'), 1).over(
            Window.orderBy(F.desc('Date'))
    ),
)
total_deaths_diff.show()

+----------+------------------+----------------------+
|      Date|total_deaths_today|total_deaths_yesterday|
+----------+------------------+----------------------+
|2020-04-14|              3345|                  3345|
|2020-04-13|              3345|                  3343|
|2020-04-12|              3343|                  3343|
|2020-04-11|              3343|                  3340|
|2020-04-10|              3340|                  3339|
|2020-04-09|              3339|                  3337|
|2020-04-08|              3337|                  3335|
|2020-04-07|              3335|                  3335|
|2020-04-06|              3335|                  3333|
|2020-04-05|              3333|                  3330|
|2020-04-04|              3330|                  3326|
|2020-04-03|              3326|                  3322|
|2020-04-02|              3322|                  3316|
|2020-04-01|              3316|                  3309|
|2020-03-31|              3309|                  3308|
|2020-03-3

In [144]:
total_deaths_diff = total_deaths_diff.withColumn('diff', F.col('total_deaths_today') - F.col('total_deaths_yesterday')) \
                                     .sort(F.desc('diff'))
total_deaths_diff.show()

+----------+------------------+----------------------+----+
|      Date|total_deaths_today|total_deaths_yesterday|diff|
+----------+------------------+----------------------+----+
|2020-02-13|              1369|                  1117| 252|
|2020-02-22|              2443|                  2238| 205|
|2020-02-14|              1521|                  1369| 152|
|2020-02-24|              2595|                  2445| 150|
|2020-02-15|              1663|                  1521| 142|
|2020-02-18|              2003|                  1864| 139|
|2020-02-20|              2238|                  2116| 122|
|2020-02-19|              2116|                  2003| 113|
|2020-02-10|              1012|                   905| 107|
|2020-02-16|              1766|                  1663| 103|
|2020-02-02|               361|                   259| 102|
|2020-02-11|              1112|                  1012| 100|
|2020-02-09|               905|                   805| 100|
|2020-02-17|              1864|         

#### define a python function to find top n deadlist days for a given country
* if country not provided, find whole world
* if n not provided, find top 1 day

In [63]:
print(df.printSchema())
df.show(5)

root
 |-- Date: date (nullable = true)
 |-- CountryRegion: string (nullable = true)
 |-- ProvinceState: string (nullable = true)
 |-- Lat: double (nullable = true)
 |-- Long: double (nullable = true)
 |-- Confirmed: integer (nullable = true)
 |-- Recovered: integer (nullable = true)
 |-- Deaths: integer (nullable = true)

None
+----------+-------------+-------------+----+----+---------+---------+------+
|      Date|CountryRegion|ProvinceState| Lat|Long|Confirmed|Recovered|Deaths|
+----------+-------------+-------------+----+----+---------+---------+------+
|2020-01-22|  Afghanistan|         null|33.0|65.0|        0|        0|     0|
|2020-01-23|  Afghanistan|         null|33.0|65.0|        0|        0|     0|
|2020-01-24|  Afghanistan|         null|33.0|65.0|        0|        0|     0|
|2020-01-25|  Afghanistan|         null|33.0|65.0|        0|        0|     0|
|2020-01-26|  Afghanistan|         null|33.0|65.0|        0|        0|     0|
+----------+-------------+-------------+----+--

In [129]:
def top_n_deadlist_day(df, country=None, n=1):
    if country is not None:
        df = df.where(f"CountryRegion = '{country}'")
    
    results = df.groupBy('Date').agg(F.sum(F.col('Deaths')).alias('sumDeaths')) \
                .sort([F.desc('sumDeaths'), 'Date']).collect()[:n]
    
    return [(row.Date.strftime('%Y-%m-%d'), row.sumDeaths) for row in results]

In [130]:
top_n_deadlist_day(df, country='China', n=10)
# top_n_deadlist_day(df)
# top_n_deadlist_day(df, country='China')
# top_n_deadlist_day(df, n=5)

[('2020-04-13', 3345),
 ('2020-04-14', 3345),
 ('2020-04-11', 3343),
 ('2020-04-12', 3343),
 ('2020-04-10', 3340),
 ('2020-04-09', 3339),
 ('2020-04-08', 3337),
 ('2020-04-06', 3335),
 ('2020-04-07', 3335),
 ('2020-04-05', 3333)]

### Spark User Defined Function

In [81]:
from math import sqrt

In [82]:
math = spark.createDataFrame([('Alex', 50), ('Bob', 30), ('Charlie', 36), ('Dan', 80)], ['name', 'score'])
math.show()

+-------+-----+
|   name|score|
+-------+-----+
|   Alex|   50|
|    Bob|   30|
|Charlie|   36|
|    Dan|   80|
+-------+-----+



In [83]:
my_udf = F.udf(lambda x: sqrt(x) * 10, T.FloatType())

In [85]:
math1 = math.withColumn('new_score', my_udf(F.col('score')))
math1.show()

+-------+-----+---------+
|   name|score|new_score|
+-------+-----+---------+
|   Alex|   50| 70.71068|
|    Bob|   30|54.772255|
|Charlie|   36|     60.0|
|    Dan|   80| 89.44272|
+-------+-----+---------+



#### Challenge - round the new_score to int ?

In [86]:
math2 = math1.withColumn('new_score', F.col('new_score').cast('int'))
print(math2.printSchema())
math2.show()

root
 |-- name: string (nullable = true)
 |-- score: long (nullable = true)
 |-- new_score: integer (nullable = true)

None
+-------+-----+---------+
|   name|score|new_score|
+-------+-----+---------+
|   Alex|   50|       70|
|    Bob|   30|       54|
|Charlie|   36|       60|
|    Dan|   80|       89|
+-------+-----+---------+



### Use UDF to address the same issue above

In [87]:
from datetime import datetime, timedelta

In [131]:
total_deaths.printSchema()

root
 |-- Date: date (nullable = true)
 |-- sumDeaths: long (nullable = true)



In [132]:
total_deaths.show(5)

+----------+---------+
|      Date|sumDeaths|
+----------+---------+
|2020-04-13|     3345|
|2020-04-14|     3345|
|2020-04-11|     3343|
|2020-04-12|     3343|
|2020-04-10|     3340|
+----------+---------+
only showing top 5 rows



In [133]:
get_yesterday_udf = F.udf(lambda x: x + timedelta(days=-1), T.DateType())

In [134]:
total_deaths.withColumn('Yesterday_Date', get_yesterday_udf(F.col('Date'))).show(10)

+----------+---------+--------------+
|      Date|sumDeaths|Yesterday_Date|
+----------+---------+--------------+
|2020-04-13|     3345|    2020-04-12|
|2020-04-14|     3345|    2020-04-13|
|2020-04-11|     3343|    2020-04-10|
|2020-04-12|     3343|    2020-04-11|
|2020-04-10|     3340|    2020-04-09|
|2020-04-09|     3339|    2020-04-08|
|2020-04-08|     3337|    2020-04-07|
|2020-04-06|     3335|    2020-04-05|
|2020-04-07|     3335|    2020-04-06|
|2020-04-05|     3333|    2020-04-04|
+----------+---------+--------------+
only showing top 10 rows



### Use window function to address the same issue above

In [135]:
total_deaths.show(5)

+----------+---------+
|      Date|sumDeaths|
+----------+---------+
|2020-04-13|     3345|
|2020-04-14|     3345|
|2020-04-11|     3343|
|2020-04-12|     3343|
|2020-04-10|     3340|
+----------+---------+
only showing top 5 rows



In [136]:
total_deaths = total_deaths.withColumn(
    'rank', 
    F.dense_rank().over(
        Window.partitionBy().orderBy(F.desc('sumDeaths'))
    ),
)

total_deaths.show()

+----------+---------+----+
|      Date|sumDeaths|rank|
+----------+---------+----+
|2020-04-13|     3345|   1|
|2020-04-14|     3345|   1|
|2020-04-12|     3343|   2|
|2020-04-11|     3343|   2|
|2020-04-10|     3340|   3|
|2020-04-09|     3339|   4|
|2020-04-08|     3337|   5|
|2020-04-07|     3335|   6|
|2020-04-06|     3335|   6|
|2020-04-05|     3333|   7|
|2020-04-04|     3330|   8|
|2020-04-03|     3326|   9|
|2020-04-02|     3322|  10|
|2020-04-01|     3316|  11|
|2020-03-31|     3309|  12|
|2020-03-30|     3308|  13|
|2020-03-29|     3304|  14|
|2020-03-28|     3299|  15|
|2020-03-27|     3296|  16|
|2020-03-26|     3291|  17|
+----------+---------+----+
only showing top 20 rows



In [146]:
total_deaths_diff.show(5)

+----------+------------------+----------------------+----+
|      Date|total_deaths_today|total_deaths_yesterday|diff|
+----------+------------------+----------------------+----+
|2020-02-13|              1369|                  1117| 252|
|2020-02-22|              2443|                  2238| 205|
|2020-02-14|              1521|                  1369| 152|
|2020-02-24|              2595|                  2445| 150|
|2020-02-15|              1663|                  1521| 142|
+----------+------------------+----------------------+----+
only showing top 5 rows



In [147]:
total_deaths_diff.withColumn(
    'rank',
    F.dense_rank().over(
        Window.partitionBy().orderBy(F.desc('diff'))
    ),
).show()

+----------+------------------+----------------------+----+----+
|      Date|total_deaths_today|total_deaths_yesterday|diff|rank|
+----------+------------------+----------------------+----+----+
|2020-02-13|              1369|                  1117| 252|   1|
|2020-02-22|              2443|                  2238| 205|   2|
|2020-02-14|              1521|                  1369| 152|   3|
|2020-02-24|              2595|                  2445| 150|   4|
|2020-02-15|              1663|                  1521| 142|   5|
|2020-02-18|              2003|                  1864| 139|   6|
|2020-02-20|              2238|                  2116| 122|   7|
|2020-02-19|              2116|                  2003| 113|   8|
|2020-02-10|              1012|                   905| 107|   9|
|2020-02-16|              1766|                  1663| 103|  10|
|2020-02-02|               361|                   259| 102|  11|
|2020-02-11|              1112|                  1012| 100|  12|
|2020-02-09|             

In [148]:
df.show(5)

+----------+-------------+-------------+----+----+---------+---------+------+
|      Date|CountryRegion|ProvinceState| Lat|Long|Confirmed|Recovered|Deaths|
+----------+-------------+-------------+----+----+---------+---------+------+
|2020-01-22|  Afghanistan|         null|33.0|65.0|        0|        0|     0|
|2020-01-23|  Afghanistan|         null|33.0|65.0|        0|        0|     0|
|2020-01-24|  Afghanistan|         null|33.0|65.0|        0|        0|     0|
|2020-01-25|  Afghanistan|         null|33.0|65.0|        0|        0|     0|
|2020-01-26|  Afghanistan|         null|33.0|65.0|        0|        0|     0|
+----------+-------------+-------------+----+----+---------+---------+------+
only showing top 5 rows



In [150]:
df1 = df.groupBy(['CountryRegion', 'Date']).agg(
    F.sum(F.col('Confirmed')).alias('Confirmed'),
    F.sum(F.col('Deaths')).alias('Deaths'),
).sort(['CountryRegion', 'Date'])

df1.where("CountryRegion = 'China'").show()

+-------------+----------+---------+------+
|CountryRegion|      Date|Confirmed|Deaths|
+-------------+----------+---------+------+
|        China|2020-01-22|      548|    17|
|        China|2020-01-23|      643|    18|
|        China|2020-01-24|      920|    26|
|        China|2020-01-25|     1406|    42|
|        China|2020-01-26|     2075|    56|
|        China|2020-01-27|     2877|    82|
|        China|2020-01-28|     5509|   131|
|        China|2020-01-29|     6087|   133|
|        China|2020-01-30|     8141|   171|
|        China|2020-01-31|     9802|   213|
|        China|2020-02-01|    11891|   259|
|        China|2020-02-02|    16630|   361|
|        China|2020-02-03|    19716|   425|
|        China|2020-02-04|    23707|   491|
|        China|2020-02-05|    27440|   563|
|        China|2020-02-06|    30587|   633|
|        China|2020-02-07|    34110|   718|
|        China|2020-02-08|    36814|   805|
|        China|2020-02-09|    39829|   905|
|        China|2020-02-10|    42

In [164]:
df2 = df1.withColumn(
    'Rank',
    F.dense_rank().over(
        Window.partitionBy('CountryRegion').orderBy('Date')
    ),
).withColumn(
    'RunningTotal_Confirmed',
    F.sum(F.col('Confirmed')).over(
        Window.partitionBy('CountryRegion').orderBy('Date') \
              .rowsBetween(Window.unboundedPreceding, 0)
    ),
).withColumn(
    'MovingAvg_Deaths',
    F.avg(F.col('Deaths')).over(
        Window.partitionBy('CountryRegion').orderBy('Date') \
              .rowsBetween(-4, 0)
    ),
)

In [165]:
df2.where("CountryRegion IN ('China', 'Canada')").orderBy([F.col('CountryRegion').desc(), F.col('Date')]).show(100)

+-------------+----------+---------+------+----+----------------------+------------------+
|CountryRegion|      Date|Confirmed|Deaths|Rank|RunningTotal_Confirmed|  MovingAvg_Deaths|
+-------------+----------+---------+------+----+----------------------+------------------+
|        China|2020-01-22|      548|    17|   1|                   548|              17.0|
|        China|2020-01-23|      643|    18|   2|                  1191|              17.5|
|        China|2020-01-24|      920|    26|   3|                  2111|20.333333333333332|
|        China|2020-01-25|     1406|    42|   4|                  3517|             25.75|
|        China|2020-01-26|     2075|    56|   5|                  5592|              31.8|
|        China|2020-01-27|     2877|    82|   6|                  8469|              44.8|
|        China|2020-01-28|     5509|   131|   7|                 13978|              67.4|
|        China|2020-01-29|     6087|   133|   8|                 20065|              88.8|