In [1]:
# Plugin jar (bundled with its dependencies), given relative to the notebook's working directory.
package_jar = (
    "../target/"
    "spark-data-repair-plugin_2.12_spark3.1_0.1.0-EXPERIMENTAL-with-dependencies.jar"
)

In [2]:
import numpy as np
import pandas as pd
from pyspark.sql import *
from pyspark.sql.types import *
from pyspark.sql import functions as f

# Start (or reuse) a Hive-enabled SparkSession with the repair plugin jar on its classpath.
# NOTE: driver memory is fixed when the driver JVM starts; if a session already exists in this
# kernel, getOrCreate() returns it and 'spark.driver.memory' has no effect (restart the kernel).
spark = (
    SparkSession.builder
    .config('spark.jars', package_jar)
    .config('spark.driver.memory', '8g')  # was misspelled 'spark.deriver.memory' and silently ignored
    .enableHiveSupport()
    .getOrCreate()
)

# Suppresses user warning messages in Python
import warnings
warnings.simplefilter("ignore", UserWarning)

# Suppresses `WARN` messages in JVM
spark.sparkContext.setLogLevel("ERROR")

In [3]:
from repair.api import Scavenger

# Display the version string reported by the repair plugin's Scavenger API.
scavenger = Scavenger()
scavenger.version()

'0.1.0-spark3.1-EXPERIMENTAL'

In [4]:
# Register the hospital CSV (first row is the header; no schema inference, so every
# column is a string) as the temp view `hospital`, then show its schema.
# NOTE(review): this file is read from '../testdata/' while the error-cells file is read
# from '../bin/testdata/' — confirm both locations are intended.
(
    spark.read
    .option("header", True)
    .csv("../testdata/hospital.csv")
    .createOrReplaceTempView("hospital")
)
spark.table('hospital').printSchema()

root
 |-- tid: string (nullable = true)
 |-- ProviderNumber: string (nullable = true)
 |-- HospitalName: string (nullable = true)
 |-- Address1: string (nullable = true)
 |-- Address2: string (nullable = true)
 |-- Address3: string (nullable = true)
 |-- City: string (nullable = true)
 |-- State: string (nullable = true)
 |-- ZipCode: string (nullable = true)
 |-- CountyName: string (nullable = true)
 |-- PhoneNumber: string (nullable = true)
 |-- HospitalType: string (nullable = true)
 |-- HospitalOwner: string (nullable = true)
 |-- EmergencyService: string (nullable = true)
 |-- Condition: string (nullable = true)
 |-- MeasureCode: string (nullable = true)
 |-- MeasureName: string (nullable = true)
 |-- Score: string (nullable = true)
 |-- Sample: string (nullable = true)
 |-- Stateavg: string (nullable = true)



In [5]:
import altair as alt

# One value-frequency bar chart per column of the raw `hospital` table, laid out in a row.
pdf = spark.table('hospital').toPandas()
cols = ['ProviderNumber', 'HospitalName', 'Address1', 'City', 'State', 'ZipCode', 'CountyName', 'PhoneNumber', 'HospitalType', 'HospitalOwner', 'EmergencyService', 'Condition', 'MeasureCode', 'MeasureName', 'Score', 'Sample', 'Stateavg']

freq_axis = alt.Y('count()', axis=alt.Axis(title='freq'))
charts = [
    alt.Chart(pdf)
    .mark_bar()
    .encode(x=alt.X(c), y=freq_axis)
    .properties(width=300, height=300)
    for c in cols
]

alt.hconcat(*charts)

In [6]:
# Register the error-cells CSV (header row; columns tid / attribute / correct_val, all
# strings) as the temp view `hospital_error_cells`, then show its schema.
# NOTE(review): this path uses '../bin/testdata/' whereas hospital.csv is read from
# '../testdata/' — confirm which directory is correct.
(
    spark.read
    .option("header", True)
    .csv("../bin/testdata/hospital_error_cells.csv")
    .createOrReplaceTempView("hospital_error_cells")
)
spark.table('hospital_error_cells').printSchema()

root
 |-- tid: string (nullable = true)
 |-- attribute: string (nullable = true)
 |-- correct_val: string (nullable = true)



In [7]:
from repair.model import RepairModel

# Repair model over the `hospital` view, keyed by the `tid` row-id column.
model = (
    RepairModel()
    .setTableName('hospital')
    .setRowId('tid')
    .setDiscreteThreshold(100)
)
# Columns passed to the repair preparation step (note: 'Sample' is not in this list).
target_columns = ['ProviderNumber', 'HospitalName', 'Address1', 'City', 'State', 'ZipCode', 'CountyName', 'PhoneNumber', 'HospitalType', 'HospitalOwner', 'EmergencyService', 'Condition', 'MeasureCode', 'MeasureName', 'Score', 'Stateavg']
error_cells_df = spark.table('hospital_error_cells')
# NOTE(review): `_prepare_repair_base_cells` is a private RepairModel method and may change
# between releases; the meaning of the positional arguments 1000 and 20 is not visible
# here — confirm against the plugin source and consider naming them.
repair_base_df = model._prepare_repair_base_cells(
    'hospital', error_cells_df, target_columns, 1000, 20)

In [9]:
import altair as alt

# Same per-column frequency bar charts as for the raw table, now for the repair base cells.
pdf = repair_base_df.toPandas()
cols = ['ProviderNumber', 'HospitalName', 'Address1', 'City', 'State', 'ZipCode', 'CountyName', 'PhoneNumber', 'HospitalType', 'HospitalOwner', 'EmergencyService', 'Condition', 'MeasureCode', 'MeasureName', 'Score', 'Sample', 'Stateavg']

freq_axis = alt.Y('count()', axis=alt.Axis(title='freq'))
charts = [
    alt.Chart(pdf)
    .mark_bar()
    .encode(x=alt.X(c), y=freq_axis)
    .properties(width=300, height=300)
    for c in cols
]

alt.hconcat(*charts)

In [37]:
# Column used as the label (y) in the cells that follow.
target = "Condition"

In [38]:
# Bring the repair base cells into pandas, keep only rows whose target value is present,
# and split them into the feature frame X (without `tid` and the target) and label series y.
# Both are re-indexed from 0 so they line up positionally.
pdf = repair_base_df.toPandas()
pdf = pdf.loc[pdf[target].notna()]
y = pdf[target].reset_index(drop=True)
X = pdf.drop(columns=['tid', target]).reset_index(drop=True)

In [39]:
import category_encoders as ce
# Replace the categorical feature values in X with ordinal integer codes so they can be
# plotted on numeric axes below.
# NOTE(review): recent category_encoders releases document 'error' / 'return_nan' / 'value'
# for `handle_unknown`; 'impute' may be a legacy option — confirm against the installed version.
se = ce.OrdinalEncoder(handle_unknown='impute')
X = se.fit_transform(X)  # rebinds X to the encoded frame (re-running this cell re-encodes it)
X

Unnamed: 0,ProviderNumber,HospitalName,Address1,Address2,Address3,City,State,ZipCode,CountyName,PhoneNumber,HospitalType,HospitalOwner,EmergencyService,MeasureCode,MeasureName,Score,Sample,Stateavg
0,1.0,1.0,1.0,,,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,,,,1.0
1,1.0,1.0,1.0,,,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,2.0,2.0,,,2.0
2,1.0,1.0,1.0,,,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,3.0,3.0,,,3.0
3,1.0,1.0,1.0,,,,1.0,1.0,1.0,1.0,1.0,1.0,1.0,4.0,4.0,,,4.0
4,1.0,1.0,1.0,,,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,5.0,5.0,,,5.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
963,46.0,46.0,46.0,,,40.0,1.0,45.0,34.0,46.0,1.0,1.0,1.0,24.0,24.0,40.0,217.0,24.0
964,46.0,46.0,46.0,,,40.0,1.0,45.0,34.0,46.0,1.0,1.0,1.0,25.0,25.0,30.0,104.0,25.0
965,46.0,46.0,46.0,,,40.0,1.0,45.0,34.0,46.0,1.0,1.0,1.0,1.0,,33.0,71.0,1.0
966,46.0,46.0,46.0,,,40.0,1.0,45.0,34.0,46.0,1.0,1.0,1.0,2.0,2.0,30.0,73.0,2.0


In [40]:
# Re-attach the (un-encoded) label column to the encoded features; X and y share a 0-based index.
pdf = pd.concat([X, y], axis='columns')

In [41]:
import altair as alt

# Scatter-plot matrix of the ordinal-encoded features, coloured by the target value.
# The target column comes straight from `y` (raw strings, not encoded), so it cannot be
# drawn on a quantitative axis: it is excluded from the repeated axes and shown only via
# the colour channel. Deriving `cols` from `target` keeps this correct if `target` changes.
feature_cols = ['ProviderNumber', 'HospitalName', 'Address1', 'City', 'State', 'ZipCode', 'CountyName', 'PhoneNumber', 'HospitalType', 'HospitalOwner', 'EmergencyService', 'Condition', 'MeasureCode', 'MeasureName', 'Score', 'Sample', 'Stateavg']
cols = [c for c in feature_cols if c != target]

alt.Chart(pdf).mark_circle().encode(
    alt.X(alt.repeat("column"), type='quantitative'),
    alt.Y(alt.repeat("row"), type='quantitative'),
    color=f'{target}:N'
).properties(width=200, height=200).repeat(row=cols, column=cols)