In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, isnan, when, count, countDistinct

import matplotlib.pyplot as plt
import seaborn as sns
from pyspark.sql.functions import min, max

In [2]:
spark = SparkSession.builder \
    .appName("EDA on Large Parquet Files") \
    .getOrCreate()


In [3]:
df1 = spark.read.parquet("sample-data/mx_submitsline.parquet")
df1.printSchema()

root
 |-- attending_provider_npi: string (nullable = true)
 |-- billing_provider_address_cbsa_name: string (nullable = true)
 |-- billing_provider_address_city: string (nullable = true)
 |-- billing_provider_address_county: string (nullable = true)
 |-- billing_provider_address_key: string (nullable = true)
 |-- billing_provider_address_latitude: string (nullable = true)
 |-- billing_provider_address_longitude: string (nullable = true)
 |-- billing_provider_address_precision: string (nullable = true)
 |-- billing_provider_address_region: string (nullable = true)
 |-- billing_provider_address_state: string (nullable = true)
 |-- billing_provider_address_street: string (nullable = true)
 |-- billing_provider_address_zipcode: string (nullable = true)
 |-- billing_provider_npi: string (nullable = true)
 |-- claim_all_diagnosis_codes: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- diagnosis_code_set: string (nullable = true)
 |    |    |-- diagnosis_c

In [12]:
df1.show(5)

+----------------------+----------------------------------+-----------------------------+-------------------------------+----------------------------+---------------------------------+----------------------------------+----------------------------------+-------------------------------+------------------------------+-------------------------------+--------------------------------+--------------------+-------------------------+---------------------------+----------------------------------+-------------------------------+--------------------+-----------------------------------+--------------------+-------------------------+-----------------------+-------------------------+---------------------------+--------------------+--------------------+-------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+--------+-----------------------------------+-----------------------------------+---------------------

In [4]:
df1.describe().show()

24/09/19 21:49:10 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
24/09/19 22:00:18 WARN DAGScheduler: Broadcasting large task binary with size 1163.1 KiB
                                                                                

+-------+--------------------+--------------------+--------------------+----------------------+----------------------------------+-----------------------------+-------------------------------+----------------------------+---------------------------------+----------------------------------+----------------------------------+-------------------------------+------------------------------+-------------------------------+--------------------------------+--------------------+---------------------------+----------------------------------+-------------------------------+--------------------+-----------------------------------+-------------------------+-----------------+------------------+-----------------+-----------------+-----------------+------------------+------------------+------------------+------------------+-----------------------------------+-----------------------------------+------------------------------+--------------------------------+-----------------------------+---------------

In [8]:
# Get the list of float and timestamp columns
float_and_timestamp_cols = [col_name for col_name, dtype in df1.dtypes if dtype in ['float', 'int', 'date']]

# Describe the float and timestamp columns
df1.select(float_and_timestamp_cols).describe().show()



+-------+------------------+-------------------------------+
|summary|patient_birth_year|patient_location_preferred_type|
+-------+------------------+-------------------------------+
|  count|          15337531|                       13767756|
|   mean|1983.8176939952068|              1.146122795900799|
| stddev| 352.8397439976159|            0.35322929306822703|
|    min|              1931|                              1|
|    max|              9999|                              2|
+-------+------------------+-------------------------------+



                                                                                

In [14]:
float_and_timestamp_cols

['claim_statement_from_date',
 'claim_statement_to_date',
 'clearinghouse_received_date',
 'inpatient_admission_date',
 'patient_birth_year',
 'patient_location_preferred_type',
 'line_level_from_date',
 'line_level_to_date']

In [22]:
from pyspark.sql import Row

def date_describe(date_columns):
    results = []
    for c in date_columns:
        result = df1.select(
            count(c).alias("total_rows"),
            countDistinct(c).alias("distinct_dates"),
            count(col(c).isNull()).alias("null_dates"),
            min(c).alias("min_date"),
            max(c).alias("max_date")
        ).collect()[0]
        results.append(Row(column=c, **result.asDict()))
    
    return spark.createDataFrame(results)

# Filter out non-date columns
date_columns = [col_name for col_name, dtype in df1.dtypes if dtype == 'date']
date_describe(date_columns).show()

                                                                                

+--------------------+----------+--------------+----------+----------+----------+
|              column|total_rows|distinct_dates|null_dates|  min_date|  max_date|
+--------------------+----------+--------------+----------+----------+----------+
|claim_statement_f...|  15482079|          3548|  15501334|2013-11-13|2024-09-05|
|claim_statement_t...|  15482079|          3548|  15501334|2013-11-13|2024-09-05|
|clearinghouse_rec...|  15501334|          3537|  15501334|2015-01-01|2024-09-06|
|inpatient_admissi...|   1984190|          3878|  15501334|1920-08-02|2024-09-03|
|line_level_from_date|  15401550|          3548|  15501334|2013-11-13|2024-09-05|
|  line_level_to_date|  15401544|          3552|  15501334|2013-11-13|2024-09-26|
+--------------------+----------+--------------+----------+----------+----------+



In [5]:
print(f"Number of rows: {df1.count()}")
print(f"Number of columns: {len(df1.columns)}")



Number of rows: 15501334
Number of columns: 150


                                                                                

In [9]:
df1.groupBy("patient_location_preferred_type").count().orderBy("count", ascending=False).show()



+-------------------------------+--------+
|patient_location_preferred_type|   count|
+-------------------------------+--------+
|                              1|11755973|
|                              2| 2011783|
|                           NULL| 1733578|
+-------------------------------+--------+



                                                                                

In [13]:
greater_null_percentages = [
    'attending_provider_npi', 'claim_all_diagnosis_codes', 'drg_code', 
    'inpatient_admitting_diagnosis_code', 'operating_provider_npi', 'principal_diagnosis_classification', 
    'line_level_drug_packaging_information'
]


def is_numeric_type(df, col_name):
    return dict(df.dtypes)[col_name] in ['int', 'bigint', 'float', 'double', 'decimal']

exprs = []
for col_name in greater_null_percentages:
    if is_numeric_type(df1, col_name):
        exprs.append(count(when(col(col_name).isNull() | isnan(col(col_name)), col_name)).alias(col_name))
    else:
        exprs.append(count(when(col(col_name).isNull(), col_name)).alias(col_name))

null_counts = df1.select(exprs)

null_counts.show()



+----------------------+-------------------------+--------+----------------------------------+----------------------+----------------------------------+-------------------------------------+
|attending_provider_npi|claim_all_diagnosis_codes|drg_code|inpatient_admitting_diagnosis_code|operating_provider_npi|principal_diagnosis_classification|line_level_drug_packaging_information|
+----------------------+-------------------------+--------+----------------------------------+----------------------+----------------------------------+-------------------------------------+
|              13969235|                    21850|15220257|                          14961229|              14940599|                          15498021|                             15088418|
+----------------------+-------------------------+--------+----------------------------------+----------------------+----------------------------------+-------------------------------------+



                                                                                

In [14]:
# Calculate the total number of rows in the dataframe
total_rows = df1.count()

# Calculate the percentage of null counts for each column
null_percentages = null_counts.select(
    [(col(c) / total_rows * 100).alias(c) for c in null_counts.columns]
)

null_percentages.show()



+----------------------+-------------------------+-----------------+----------------------------------+----------------------+----------------------------------+-------------------------------------+
|attending_provider_npi|claim_all_diagnosis_codes|         drg_code|inpatient_admitting_diagnosis_code|operating_provider_npi|principal_diagnosis_classification|line_level_drug_packaging_information|
+----------------------+-------------------------+-----------------+----------------------------------+----------------------+----------------------------------+-------------------------------------+
|     90.11634095491394|      0.14095561065905685|98.18675605596266|                  96.5157514830659|     96.38266616279606|                 99.97862764585294|                    97.33625506037093|
+----------------------+-------------------------+-----------------+----------------------------------+----------------------+----------------------------------+-------------------------------------+


                                                                                

In [6]:
df1.dtypes

[('attending_provider_npi', 'string'),
 ('billing_provider_address_cbsa_name', 'string'),
 ('billing_provider_address_city', 'string'),
 ('billing_provider_address_county', 'string'),
 ('billing_provider_address_key', 'string'),
 ('billing_provider_address_latitude', 'string'),
 ('billing_provider_address_longitude', 'string'),
 ('billing_provider_address_precision', 'string'),
 ('billing_provider_address_region', 'string'),
 ('billing_provider_address_state', 'string'),
 ('billing_provider_address_street', 'string'),
 ('billing_provider_address_zipcode', 'string'),
 ('billing_provider_npi', 'string'),
 ('claim_all_diagnosis_codes',
  'array<struct<diagnosis_code_set:string,diagnosis_code:string,diagnosis_pointer:string>>'),
 ('claim_filing_indicator_code', 'string'),
 ('claim_filing_indicator_description', 'string'),
 ('claim_filing_indicator_pay_type', 'string'),
 ('claim_id', 'string'),
 ('claim_institutional_or_professional', 'string'),
 ('claim_number', 'string'),
 ('claim_stateme

In [23]:
null_counts_all_columns = df1.select([count(when(col(c).isNull(), c)).alias(c) for c in df1.columns])

null_counts_all_columns.show()

24/09/23 16:01:57 ERROR Executor: Exception in task 4.0 in stage 198.0 (TID 827)
java.lang.OutOfMemoryError: Java heap space
	at java.base/java.lang.invoke.MethodTypeForm.setCachedLambdaForm(MethodTypeForm.java:144)
	at java.base/java.lang.invoke.Invokers.callSiteForm(Invokers.java:608)
	at java.base/java.lang.invoke.Invokers.linkToTargetMethod(Invokers.java:572)
	at java.base/java.lang.invoke.MethodHandleNatives.linkCallSiteImpl(MethodHandleNatives.java:288)
	at java.base/java.lang.invoke.MethodHandleNatives.linkCallSite(MethodHandleNatives.java:265)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator.initializeAggregateFunctions(AggregationIterator.scala:81)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator.<init>(AggregationIterator.scala:118)
	at org.apache.spark.sql.execution.aggregate.TungstenAggregationIterator.<init>(TungstenAggregationIterator.scala:106)
	at org.apache.spark.sql.execution.aggregate.HashAggregateExec.$anonfun$doExecute$1(HashAggregat

Py4JError: py4j does not exist in the JVM

ERROR:root:Exception while sending command.
Traceback (most recent call last):
  File "/Users/sahil/coding/doctor-right/.venv/lib/python3.12/site-packages/py4j/clientserver.py", line 516, in send_command
    raise Py4JNetworkError("Answer from Java side is empty")
py4j.protocol.Py4JNetworkError: Answer from Java side is empty

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/sahil/coding/doctor-right/.venv/lib/python3.12/site-packages/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sahil/coding/doctor-right/.venv/lib/python3.12/site-packages/py4j/clientserver.py", line 539, in send_command
    raise Py4JNetworkError(
py4j.protocol.Py4JNetworkError: Error while sending or receiving


ERROR:root:Exception while sending command.
Traceback (most recent call last):
  File "/Users/sahil/coding/doctor-right/.venv/lib/python3.12/site-packages/py4j/clientserver.py", line 516, in send_command
    raise Py4JNetworkError("Answer from Java side is empty")
py4j.protocol.Py4JNetworkError: Answer from Java side is empty

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/sahil/coding/doctor-right/.venv/lib/python3.12/site-packages/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sahil/coding/doctor-right/.venv/lib/python3.12/site-packages/py4j/clientserver.py", line 539, in send_command
    raise Py4JNetworkError(
py4j.protocol.Py4JNetworkError: Error while sending or receiving
