In [1]:
import os
import sys
import spark_utils as sut
import pandas as pd

os.environ['PYSPARK_PYTHON'] = sys.executable
os.environ['PYSPARK_DRIVER_PYTHON'] = sys.executable

from pyspark.sql.types import StructType, StructField, IntegerType, StringType
import pyspark.sql.functions as F
spark = sut.get_spark_session()

## Introduction

This notebook demos the custom PySpark functions discussed in [Speed up Your ML Projects With Spark -- Handy Custom {pySpark} Functions (I)](https://medium.com/towards-artificial-intelligence/speed-up-your-ml-projects-with-spark-09183e054d3a) published on [Towards AI](https://pub.towardsai.net/). 

The relevant functions are saved in [spark_utils.py](spark_utils.py) and imported into this notebook with `import spark_utils as sut`.

## Demo dataframe

While Spark is famous for its ability to work with big data, for demo purposes, I have created a small dataset with an obvious duplicate issue. Do you notice that the two ID fields, ID1 and ID2, do not form a primary key? We will use this table to demo and test our custom functions.

In [2]:
# Define schema
schema = StructType([
    StructField("ID1", IntegerType(), True),
    StructField("ID2", StringType(), True),
    StructField("Name", StringType(), True),
    StructField("DOB", StringType(), True),
    StructField("City", StringType(), True)
])

# Sample data with duplicates based on ID1 and ID2
data = [
    (101, 'A', 'Alice', '2000-01-01', 'New York'),
    (102, 'B', 'Bob', '1990-01-01', 'Los Angeles'),
    (103, 'E', 'Elly', '1982-01-01', 'San Francisco'),
    (104, 'J', 'Jesse', '1995-01-01', 'Chicago'),
    (105, 'B', 'Bingo', '1987-01-01', 'Los Angeles'),
    (101, 'A', 'Alice', '2000-01-01', 'NY'),
    (102, 'B', 'Bob', '1990-01-01', 'LA'),
    (105, 'B', 'Binggy', '1987-01-01', 'Los Angeles'),  
]

# Create DataFrame with schema
df = spark.createDataFrame(data, schema=schema)

# Convert DOB column to DateType
df = df.withColumn("DOB", F.to_date(df["DOB"], "yyyy-MM-dd"))

df = df.orderBy('ID1','ID2')

# Show DataFrame
df.show()


+---+---+------+----------+-------------+
|ID1|ID2|  Name|       DOB|         City|
+---+---+------+----------+-------------+
|101|  A| Alice|2000-01-01|     New York|
|101|  A| Alice|2000-01-01|           NY|
|102|  B|   Bob|1990-01-01|  Los Angeles|
|102|  B|   Bob|1990-01-01|           LA|
|103|  E|  Elly|1982-01-01|San Francisco|
|104|  J| Jesse|1995-01-01|      Chicago|
|105|  B|Binggy|1987-01-01|  Los Angeles|
|105|  B| Bingo|1987-01-01|  Los Angeles|
+---+---+------+----------+-------------+



# Custom pySpark functions

## shape

I find the shape attribute of pandas dataframes pretty convenient, so I created a custom function to get the shape of Spark dataframes too. A few things to note:

* This custom shape function prints out comma-formatted numbers, which can be especially helpful for big datasets.
* It can return the shape tuple for further programmatic use when the print_only parameter is set to False.

BTW, you might be delighted to learn that all the functions in this article are equipped with 1) Docstring documentation and 2) Type hints. You are welcome 😁

In [3]:
# create a longer dataframe to demo the comma-formatted print out
df_ = df.crossJoin(spark.range(1, 1234567)).drop("id")

In [4]:
num_row, num_col = sut.shape(df_, print_only = False)
print(num_row, num_col)

Number of rows: 9,876,528
Number of columns: 5
9876528 5
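
For readers curious how such a helper could look, here is a minimal sketch. It is not the code in spark_utils.py: the name `shape_sketch` and the `print_only=True` default are assumptions made for illustration. It relies only on `df.count()` and `df.columns`, which every Spark DataFrame exposes.

```python
from typing import Optional, Tuple


def shape_sketch(df, print_only: bool = True) -> Optional[Tuple[int, int]]:
    """Print row and column counts with thousands separators.

    Hypothetical stand-in for sut.shape (assumed name and default).
    """
    num_rows = df.count()        # triggers a Spark job that scans the data
    num_cols = len(df.columns)   # metadata only, no job needed
    print(f"Number of rows: {num_rows:,}")
    print(f"Number of columns: {num_cols:,}")
    # Hand back the tuple only when the caller asks for it
    return None if print_only else (num_rows, num_cols)
```

Keep in mind that `count()` is a full action in Spark, so on very large tables a helper like this is not free the way `DataFrame.shape` is in pandas.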


## print schema alphabetically

In PySpark, there is a built-in printSchema method. However, when working with very wide tables, I prefer to have the column names sorted alphabetically so I can find a given column more quickly. Here is the function for that.

In [5]:
sut.print_schema_alphabetically(df)

root
 |-- City: string (nullable = true)
 |-- DOB: date (nullable = true)
 |-- ID1: integer (nullable = true)
 |-- ID2: string (nullable = true)
 |-- Name: string (nullable = true)
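
One simple way to get this behaviour, shown here as a sketch rather than the actual spark_utils.py implementation, is to reorder the columns with `select` and then call the built-in `printSchema`. The function name below is hypothetical, and the sketch assumes top-level column names without dots or other characters that would need backtick quoting.

```python
def print_schema_alphabetically_sketch(df) -> None:
    """Print the schema with top-level columns in alphabetical order.

    Hypothetical helper; nested struct fields keep their original order.
    """
    sorted_cols = sorted(df.columns)        # plain Python sort on the names
    df.select(*sorted_cols).printSchema()   # reuse Spark's own schema printer
```

Note that `sorted` is case-sensitive, so uppercase names come before lowercase ones; pass `key=str.lower` if a case-insensitive order is preferred.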



## verify primary key

A common EDA task is to check the primary key(s) and troubleshoot duplicates. The three functions below were created for this purpose. First, let’s look at the is_primary_key function. As its name indicates, this function checks whether the specified column(s) form a primary key in the DataFrame. A few things to note:

* It returns False when the dataframe is empty, or when any of the specified columns are missing from the dataframe.
* It checks for missing values in any of the specified columns and excludes relevant rows from the row counts.
* Using the verbose parameter, users can specify whether to print out or suppress detailed info during the function run.

In [6]:
id_cols = ['ID1', 'ID2']
sut.is_primary_key(df, id_cols)

✅ No missing values found in columns: ID1, ID2
ℹ️ Total row count: 8
ℹ️ Unique row count: 5
❌ The column(s) ID1, ID2 do not form a primary key.


False

In [7]:
# suppress the detailed info during the function run
sut.is_primary_key(df, id_cols, verbose = False)

False

In [8]:
id_cols_ = ['ID3']
sut.is_primary_key(df, id_cols_)

❌ Column(s) ID3 do not exist in the DataFrame.


False

In [9]:
# accept one single column name as a string
sut.is_primary_key(df, 'ID1')

✅ No missing values found in columns: ID1
ℹ️ Total row count: 8
ℹ️ Unique row count: 5
❌ The column(s) ID1 do not form a primary key.


False
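
The checks described in this section can be sketched roughly as follows. This is an illustration under assumptions, not the code in spark_utils.py: the function name, the message wording, and the choice to return False when every row has a missing key are mine. It uses only standard DataFrame methods (`columns`, `count`, `dropna`, `select`, `distinct`).

```python
from typing import List, Union


def is_primary_key_sketch(df, cols: Union[str, List[str]], verbose: bool = True) -> bool:
    """Return True if `cols` uniquely identify every row with a non-null key."""
    cols = [cols] if isinstance(cols, str) else list(cols)

    def log(msg: str) -> None:
        if verbose:
            print(msg)

    # 1) All requested columns must exist
    missing = [c for c in cols if c not in df.columns]
    if missing:
        log(f"Column(s) {', '.join(missing)} do not exist in the DataFrame.")
        return False

    # 2) An empty DataFrame has no primary key
    if df.count() == 0:
        log("DataFrame is empty.")
        return False

    # 3) Ignore rows with a missing value in any key column
    non_null = df.dropna(subset=cols)

    # 4) Compare the total row count with the distinct key count
    total = non_null.count()
    unique = non_null.select(*cols).distinct().count()
    log(f"Total row count: {total:,}")
    log(f"Unique row count: {unique:,}")
    return total > 0 and total == unique
```

Each `count()` is a separate Spark action, so caching `non_null` may be worthwhile on large inputs.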

## find duplicates

Consistent with our inspection of the dummy table, the two ID fields do not form a primary key. Of course, duplicates can exist in real data too. Below is the function to identify them. 🔎

In [10]:
if not sut.is_primary_key(df, id_cols, verbose = False):
    dups = sut.find_duplicates(df, id_cols)
    dups.show()

+-----+---+---+------+----------+-----------+
|count|ID1|ID2|  Name|       DOB|       City|
+-----+---+---+------+----------+-----------+
|    2|101|  A| Alice|2000-01-01|   New York|
|    2|101|  A| Alice|2000-01-01|         NY|
|    2|102|  B|   Bob|1990-01-01|Los Angeles|
|    2|102|  B|   Bob|1990-01-01|         LA|
|    2|105|  B| Bingo|1987-01-01|Los Angeles|
|    2|105|  B|Binggy|1987-01-01|Los Angeles|
+-----+---+---+------+----------+-----------+
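
To make the semantics of the output concrete, here is a plain-Python illustration of the same idea on a list of row dictionaries. It is not Spark code and not the spark_utils.py implementation (which works on DataFrames); the function name is hypothetical. It counts how often each key combination occurs and keeps every row whose key appears more than once, with the count attached as the first field.

```python
from collections import Counter
from typing import Any, Dict, List, Sequence

Row = Dict[str, Any]


def find_duplicates_rows(rows: List[Row], id_cols: Sequence[str]) -> List[Row]:
    """Return all rows whose id_cols combination occurs more than once."""
    def key(r: Row) -> tuple:
        return tuple(r[c] for c in id_cols)

    counts = Counter(key(r) for r in rows)  # occurrences per key combination
    # Keep every member of a duplicated group, not just the "extra" rows
    return [{"count": counts[key(r)], **r} for r in rows if counts[key(r)] > 1]
```

In Spark, the equivalent is usually written as a group-by count joined back to the original table, or as a window count partitioned by the key columns.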



## identify columns responsible for dups

From the above table, it is fairly easy to tell which columns are responsible for duplications in our data.

* 🔎 The City column is responsible for the differences in 101-A and 102-B ID combinations. For example, the dup in 101-A is because the City is recorded both as “New York” and “NY”.
* 🔎 The Name column is responsible for the difference in the 105-B ID combination, where the person’s name is “Bingo” in one record and “Binggy” in another.

Identifying the root cause of the dups is important for troubleshooting. For instance, based on the discovery above, we should consolidate both city and person names in our data.

You can imagine that when we have very wide tables and many more dups, identifying and summarizing the root cause using human eyes 👀 like we did above becomes much trickier.

The cols_responsible_for_id_dups function comes to the rescue by summarizing the difference_counts for each column based on the primary key(s) provided. 😎 For example, in the output below, we can easily see that the City field is responsible for differences in two unique ID combinations, while the Name column is responsible for the dups in one ID pair.

In [11]:
if not sut.is_primary_key(df, id_cols, verbose = False):
    dup_cols = sut.cols_responsible_for_id_dups(df, id_cols)
    dup_cols.show()

+--------+-----------------+
|col_name|difference_counts|
+--------+-----------------+
|    City|                2|
|    Name|                1|
|     DOB|                0|
+--------+-----------------+
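
The counting rule behind this summary can be illustrated in plain Python. The sketch below is not the spark_utils.py implementation and its name is hypothetical; it assumes that difference_counts means "the number of key combinations in which a column takes more than one distinct value", which matches the numbers shown above.

```python
from collections import defaultdict
from typing import Any, Dict, List, Sequence, Tuple

Row = Dict[str, Any]


def cols_responsible_rows(rows: List[Row], id_cols: Sequence[str]) -> List[Tuple[str, int]]:
    """For each non-key column, count the key groups where its values disagree."""
    groups = defaultdict(list)
    for r in rows:
        groups[tuple(r[c] for c in id_cols)].append(r)

    other_cols = [c for c in rows[0] if c not in id_cols] if rows else []
    diff_counts = {c: 0 for c in other_cols}
    for group in groups.values():
        for c in other_cols:
            # More than one distinct value means this column differs within the key
            if len({r[c] for r in group}) > 1:
                diff_counts[c] += 1

    # Largest counts first, like the summary table
    return sorted(diff_counts.items(), key=lambda kv: kv[1], reverse=True)
```

On a Spark DataFrame the same rule would typically be expressed with a group-by on the key columns and a distinct count per remaining column.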



## Dedupe the Dataframe

In [13]:
dedup_df = sut.deduplicate_by_rank(df, id_cols, 'City', ascending=False, tiebreaker_col='Name', verbose=False)
dedup_df.show()

+---+---+------+----------+-------------+
|ID1|ID2|  Name|       DOB|         City|
+---+---+------+----------+-------------+
|101|  A| Alice|2000-01-01|     New York|
|102|  B|   Bob|1990-01-01|  Los Angeles|
|103|  E|  Elly|1982-01-01|San Francisco|
|104|  J| Jesse|1995-01-01|      Chicago|
|105|  B|Binggy|1987-01-01|  Los Angeles|
+---+---+------+----------+-------------+



In [16]:
sut.is_primary_key(dedup_df, id_cols)

✅ No missing values found in columns: ID1, ID2
ℹ️ Total row count: 5
ℹ️ Unique row count: 5
🔑 The column(s) ID1, ID2 form a primary key.


True
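
Judging from the call and its output, deduplicate_by_rank keeps one row per key, chosen by ranking on City in descending order and using Name to break ties. The plain-Python sketch below illustrates that rule on row dictionaries. It is not the spark_utils.py implementation: the function name is hypothetical, and the ascending order of the tiebreaker is an assumption that happens to agree with the output above (Binggy is kept over Bingo). Missing values are not handled.

```python
from typing import Any, Dict, List, Optional, Sequence

Row = Dict[str, Any]


def deduplicate_by_rank_rows(
    rows: List[Row],
    id_cols: Sequence[str],
    ranking_col: str,
    ascending: bool = True,
    tiebreaker_col: Optional[str] = None,
) -> List[Row]:
    """Keep the top-ranked row for each id_cols combination."""
    ordered = list(rows)
    # Python's sort is stable, so sorting by the tiebreaker first and the
    # ranking column second gives "rank, then tiebreaker" ordering.
    if tiebreaker_col is not None:
        ordered.sort(key=lambda r: r[tiebreaker_col])
    ordered.sort(key=lambda r: r[ranking_col], reverse=not ascending)

    seen = set()
    kept = []
    for r in ordered:
        k = tuple(r[c] for c in id_cols)
        if k not in seen:  # the first row seen per key is the best-ranked one
            seen.add(k)
            kept.append(r)
    return kept
```

In Spark this pattern is commonly written with a `row_number()` window partitioned by the key columns and ordered by the ranking columns, keeping the rows where the row number equals 1.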

## value counts with percent

In the summary table returned by cols_responsible_for_id_dups above, the columns responsible for the most duplicates are listed at the top. We can then analyze these columns further for troubleshooting. If you have a very wide table, narrowing down the investigation like this can be pretty handy. For example, you can zoom in by checking the relevant columns’ value counts among the dups. And of course, I have a custom function ready for you to do just this. 😜 This function is very much like value_counts in pandas, with two additional features:

* percentage for each unique value
* comma-formatted numbers in the printout

Let’s see it in action.

In [None]:
sut.value_counts_with_pct(dups, 'City')

+-----------+-----+-----+
|       City|count|  pct|
+-----------+-----+-----+
|Los Angeles|    3| 50.0|
|         LA|    1|16.67|
|         NY|    1|16.67|
|   New York|    1|16.67|
+-----------+-----+-----+



DataFrame[City: string, count: bigint, pct: double]
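
The arithmetic behind the pct column can be illustrated in a few lines of plain Python. This sketch is not the spark_utils.py implementation and the function name is hypothetical; it assumes the percentage is each value's count divided by the total number of rows, rounded to two decimals, which is consistent with the table above (3 of 6 rows gives 50.0, 1 of 6 gives 16.67).

```python
from collections import Counter
from typing import Any, Iterable, List, Tuple


def value_counts_with_pct_values(values: Iterable[Any]) -> List[Tuple[Any, int, float]]:
    """Return (value, count, pct) tuples, most frequent value first."""
    counts = Counter(values)
    total = sum(counts.values())
    # most_common() orders by count, descending
    return [(v, n, round(100 * n / total, 2)) for v, n in counts.most_common()]


# Usage on the City values of the six duplicated rows
cities = ["New York", "NY", "Los Angeles", "LA", "Los Angeles", "Los Angeles"]
for value, n, pct in value_counts_with_pct_values(cities):
    print(f"{value}: {n:,} ({pct}%)")
```

Because the percentages are rounded individually, they may not sum to exactly 100.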