In [1]:
# Plugin jar (with bundled dependencies) that is put on the Spark session's classpath via `spark.jars`.
package_jar: str = (
    '../target/spark-data-repair-plugin_2.12_spark3.2_0.1.0-EXPERIMENTAL-with-dependencies.jar'
)

In [64]:
import warnings

import numpy as np
import pandas as pd
from pyspark.sql import *
from pyspark.sql.types import *
from pyspark.sql import functions as f

# Start (or reuse) a Hive-enabled Spark session with the data-repair plugin jar on its classpath.
# The key was previously misspelled as 'spark.deriver.memory', so the 8g setting was ignored.
# NOTE(review): per the Spark configuration docs, driver memory only takes effect when set before
# the driver JVM starts, i.e. presumably only for the first session created in this kernel.
spark = (
    SparkSession.builder
    .config('spark.jars', package_jar)
    .config('spark.driver.memory', '8g')
    .enableHiveSupport()
    .getOrCreate()
)

# Suppresses user warning messages in Python
warnings.simplefilter("ignore", UserWarning)

# Suppresses `WARN` messages in JVM
spark.sparkContext.setLogLevel("ERROR")

In [65]:
# Show which version of the data-repair Python API is loaded.
from repair.api import Scavenger

Scavenger().version()

'0.1.0-spark3.2-EXPERIMENTAL'

In [66]:
# Explicit schema for the Boston housing CSV. INDUS, CHAS, NOX, AGE, RAD and PTRATIO are read as
# strings; they are ordinal-encoded as categorical features further below.
boston_schema = (
    "tid string, CRIM double, ZN int, INDUS string, CHAS string, "
    "NOX string, RM double, AGE string, DIS double, RAD string, TAX int, "
    "PTRATIO string, B double, LSTAT double"
)

# Register the CSV as the temporary view `boston` and confirm the resulting schema.
(spark.read
    .option("header", True)
    .schema(boston_schema)
    .csv("../testdata/boston.csv")
    .createOrReplaceTempView("boston"))
spark.table('boston').printSchema()

root
 |-- tid: string (nullable = true)
 |-- CRIM: double (nullable = true)
 |-- ZN: integer (nullable = true)
 |-- INDUS: string (nullable = true)
 |-- CHAS: string (nullable = true)
 |-- NOX: string (nullable = true)
 |-- RM: double (nullable = true)
 |-- AGE: string (nullable = true)
 |-- DIS: double (nullable = true)
 |-- RAD: string (nullable = true)
 |-- TAX: integer (nullable = true)
 |-- PTRATIO: string (nullable = true)
 |-- B: double (nullable = true)
 |-- LSTAT: double (nullable = true)



In [67]:
# Preview the raw `boston` table as a pandas frame (the notebook front-end truncates long output).
spark.table('boston').toPandas()

Unnamed: 0,tid,CRIM,ZN,INDUS,CHAS,NOX,RM,AGE,DIS,RAD,TAX,PTRATIO,B,LSTAT
0,0,0.00632,18,2.31,0.0,0.538,6.575,65.2,4.0900,1,296.0,15.3,396.90,4.98
1,1,0.02731,0,7.07,0.0,0.469,6.421,78.9,4.9671,2,242.0,17.8,396.90,9.14
2,2,0.02729,0,7.07,0.0,0.469,7.185,61.1,4.9671,2,242.0,17.8,392.83,4.03
3,3,0.03237,0,2.18,0.0,0.458,6.998,45.8,6.0622,3,222.0,18.7,394.63,2.94
4,4,0.06905,0,2.18,0.0,0.458,7.147,54.2,6.0622,3,222.0,18.7,396.90,5.33
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
501,501,0.06263,0,11.93,0.0,0.573,6.593,69.1,2.4786,1,273.0,21.0,391.99,9.67
502,502,0.04527,0,11.93,0.0,0.573,6.120,76.7,2.2875,1,273.0,21.0,396.90,9.08
503,503,0.06076,0,11.93,0.0,0.573,6.976,91.0,2.1675,1,273.0,21.0,396.90,5.64
504,504,0.10959,0,11.93,0.0,0.573,6.794,89.3,2.3889,1,273.0,21.0,393.45,6.48


In [68]:
import altair as alt

# Frequency bar chart for every column except those listed in `excluded_columns`, shown side by side.
excluded_columns = {'tid', 'CRIM', 'RM', 'DIS', 'TAX', 'B', 'LSTAT', 'INDUS'}
pdf = spark.table('boston').toPandas()

charts = [
    alt.Chart(pdf)
    .mark_bar()
    .encode(x=alt.X(column), y=alt.Y('count()', axis=alt.Axis(title='freq')))
    .properties(width=300, height=300)
    for column in pdf.columns
    if column not in excluded_columns
]

alt.hconcat(*charts)

In [69]:
# Attribute whose missing values are predicted (repaired) in the cells below.
target: str = 'RAD'

In [70]:
pdf = spark.table('boston').toPandas()

# Rows whose target is missing are the ones to repair; the remaining rows are training data.
target_missing = pdf[target].isna()

# Rows to predict keep `tid` so their predictions can be joined back to the table later.
X_test = pdf.loc[target_missing].drop(columns=[target]).reset_index(drop=True)

# Training data: features without the row id, plus the known target values.
pdf = pdf.loc[~target_missing]
X = pdf.drop(columns=['tid', target]).reset_index(drop=True)
y = pdf[target].reset_index(drop=True)

In [71]:
import category_encoders as ce

# Ordinal-encode the string-typed feature columns. The target is not a column of X, so it is skipped.
cols = ['CHAS', 'NOX', 'AGE', 'RAD', 'PTRATIO', 'INDUS']
encoded_cols = [name for name in cols if name != target]
se = ce.OrdinalEncoder(cols=encoded_cols, handle_unknown='impute')

# Fit the mapping on the training features, then apply the same mapping to the rows to repair.
X = se.fit_transform(X)
_X_test = se.transform(X_test[X.columns]).copy(deep=True)

# Put the row identifiers back in front of the encoded features.
X_test = pd.concat([X_test[['tid']], _X_test], axis=1)

In [72]:
import altair as alt

# Scatter-plot matrix over the encoded features, coloured by the target class.
cols = ['CHAS', 'NOX', 'AGE', 'RAD', 'PTRATIO', 'CRIM', 'RM', 'DIS', 'TAX', 'B', 'LSTAT', 'INDUS']

# Replace every target label by an integer code, numbered in order of first appearance.
_y = y.replace({label: code for code, label in enumerate(y.unique())})
pdf = pd.concat([X, _y], axis=1)

(
    alt.Chart(pdf)
    .mark_circle()
    .encode(
        alt.X(alt.repeat("column"), type='quantitative'),
        alt.Y(alt.repeat("row"), type='quantitative'),
        color=f'{target}:N',
    )
    .properties(width=200, height=200)
    .repeat(row=cols, column=cols)
)

In [73]:
import itertools

from minepy import MINE

# Maximal Information Coefficient (MIC) for every pair of feature columns; print the 3 strongest pairs.
cols = ['CHAS', 'NOX', 'AGE', 'RAD', 'PTRATIO', 'CRIM', 'RM', 'DIS', 'TAX', 'B', 'LSTAT', 'INDUS']
# The target was dropped from X, so exclude it. A filter is used instead of `cols.remove(target)`,
# which raised ValueError whenever `target` is not one of the listed columns.
cols = [c for c in cols if c != target]

mine = MINE(alpha=0.6, c=15, est="mic_approx")

results = []
for c1, c2 in itertools.combinations(cols, 2):
    mine.compute_score(X[c1], X[c2])
    results.append(((c1, c2), mine.mic()))

print(sorted(results, key=lambda x: x[1], reverse=True)[0:3])

[(('PTRATIO', 'INDUS'), 0.9857112305824217), (('NOX', 'INDUS'), 0.9458702128069584), (('NOX', 'PTRATIO'), 0.9421340157534075)]


In [74]:
from minepy import MINE

# MIC between the integer-coded target and each feature column; print the 3 highest scores.
cols = ['CHAS', 'NOX', 'AGE', 'RAD', 'PTRATIO', 'CRIM', 'RM', 'DIS', 'TAX', 'B', 'LSTAT', 'INDUS']
_y = y.replace({label: code for code, label in enumerate(y.unique())})

mine = MINE(alpha=0.6, c=15, est="mic_approx")
results = []
for c in cols:
    if c == target:
        continue  # the target itself is not a column of X
    mine.compute_score(_y, X[c])
    results.append(((target, c), mine.mic()))

print(sorted(results, key=lambda pair: pair[1], reverse=True)[:3])

[(('RAD', 'NOX'), 0.9541832132990704), (('RAD', 'INDUS'), 0.9533622598514608), (('RAD', 'TAX'), 0.8399754036616242)]


In [75]:
from sklearn.ensemble import RandomForestClassifier
from boruta import BorutaPy

# Seed for every random forest below. Without it the printed scores and the set of
# Boruta-selected features could change on each Restart & Run All.
SEED = 0
# RandomForestClassifier cannot handle NaN correctly, so missing values are replaced by this sentinel.
NAN_SENTINEL = -255.0


def _make_rf():
    """Return a fresh, seeded random forest with the configuration shared by this cell."""
    return RandomForestClassifier(n_jobs=-1, max_depth=5, random_state=SEED)


_X = X.fillna(NAN_SENTINEL)

# NOTE: both SCORE lines report accuracy on the training data itself (no hold-out split),
# so they are optimistic.
rf = _make_rf()
rf.fit(_X, y)
print('SCORE with ALL Features: %1.2f' % rf.score(_X, y))

# Boruta all-relevant feature selection around a random forest.
fs = BorutaPy(_make_rf(), n_estimators='auto', random_state=SEED, perc=80, two_step=False, max_iter=500)
fs.fit(_X.values, y.values)

# Boolean mask over the columns of `_X` (and therefore of X) marking the confirmed features.
selected = fs.support_
print('Selected Features: %s' % ','.join(_X.columns[selected]))

X_selected = _X[_X.columns[selected]]
rf = _make_rf()
rf.fit(X_selected, y)
print('SCORE with selected Features: %1.2f' % rf.score(X_selected, y))

SCORE with ALL Features: 0.90
Selected Features: CRIM,ZN,INDUS,NOX,RM,AGE,DIS,TAX,PTRATIO,B,LSTAT
SCORE with selected Features: 0.89


In [76]:
# t-SNE, a non-linear embedding from scikit-learn: project the selected features to 2-D
# to visualise how the target classes are spread.
from sklearn.manifold import TSNE
import altair as alt

# NOTE(review): recent scikit-learn releases rename `n_iter` to `max_iter`; update if this warns or fails.
tsne = TSNE(n_components=2, random_state=0, perplexity=300, n_iter=1000)
_X = X_selected.dropna()
# Take the labels of exactly the rows that survived `dropna`. Passing the full `y`
# would misalign (or fail) as soon as any row is dropped.
labels = y.loc[_X.index].to_numpy()
embedding = tsne.fit_transform(_X)
print('KL divergence: {}'.format(tsne.kl_divergence_))

_X = pd.DataFrame({'tSNE-X': embedding[:, 0], 'tSNE-Y': embedding[:, 1], target: labels})
alt.Chart(_X).mark_point().encode(x='tSNE-X', y='tSNE-Y', color=f'{target}:N').properties(width=600, height=400).interactive()

KL divergence: 0.055440932512283325


In [85]:
import json

from repair import train

# Tune and fit a classifier for the target on the Boruta-selected feature columns.
params = {'hp.timeout': '3600', 'hp.no_progress_loss': '30'}
(clf, score), _ = train.build_model(X[X.columns[selected]], y, is_discrete=True, num_class=len(y.unique()), n_jobs=-1, opts=params)
print(f'Score: {score}')

# import lightgbm as lgb
# obj = 'multiclass' if len(y.unique()) > 2 else 'binary'
# clf = lgb.LGBMClassifier(objective=obj, num_leaves=64, min_child_samples=20, max_depth=7)
# clf.fit(X[X.columns[selected]], y)

# For every row to repair, build a probability mass function (PMF) over the classes as JSON.
# Then keep the `top_k` most likely (class, probability) pairs.
top_k = 3
probs = clf.predict_proba(X_test[X.columns[selected]])
classes = clf.classes_.tolist()
pmf = [json.dumps({"classes": classes, "probs": p.tolist()}) for p in probs]
df = spark.createDataFrame(pd.DataFrame({'tid': X_test['tid'], 'pmf': pd.Series(pmf)}))
df = df.selectExpr('tid', 'from_json(pmf, "classes array<string>, probs array<double>") pmf')
df = df.selectExpr('tid', 'arrays_zip(pmf.classes, pmf.probs) pmf')
# Sort the pairs by descending probability. The comparator must return negative/zero/positive,
# and 0 for ties. The previous `if(l < r, 1, -1)` returned -1 for equal probabilities, which is
# inconsistent and can give arbitrary tie order or a sort failure.
by_prob_desc = ('(left, right) -> case when left.`1` < right.`1` then 1 '
                'when left.`1` > right.`1` then -1 else 0 end')
df = df.selectExpr('tid', f'slice(array_sort(pmf, {by_prob_desc}), 1, {top_k}) top_k_pmf')
# The most likely class becomes the predicted value of the target column.
df = df.selectExpr('tid', f'top_k_pmf[0].`0` `{target}`', 'top_k_pmf')
predicted = df.toPandas()

Score: 0.9222122552684354


In [83]:
# Ground truth: the corrected values recorded for cells of the target attribute.
(spark.read
    .option("header", True)
    .csv("../testdata/boston_clean.csv")
    .createOrReplaceTempView("boston_clean"))
pdf_clean = (
    spark.table('boston_clean')
    .where(f'attribute = "{target}"')
    .selectExpr('tid', 'correct_val')
    .toPandas()
)

# Pair each prediction with the ground-truth value for the same `tid` and flag matches.
result = predicted.merge(pdf_clean, on='tid')
result['is_correct'] = result[target] == result['correct_val']

pd.set_option("display.max_colwidth", 300)
result

Unnamed: 0,tid,RAD,top_k_pmf,correct_val,is_correct
0,72,4,"[(4, 0.9473800432317989), (5, 0.017509841380478806), (6, 0.011693077501512114)]",4,True
1,94,4,"[(4, 0.6545920254937155), (2, 0.2772878519246377), (7, 0.01331131534753019)]",4,True
2,98,2,"[(2, 0.9782369427418603), (5, 0.005641408556065768), (4, 0.004562545208971053)]",2,True
3,147,5,"[(5, 0.9838303224088093), (24, 0.0027471293969857646), (2, 0.0024388910034314868)]",5,True
4,162,5,"[(5, 0.9808974848174248), (4, 0.00592823808583792), (2, 0.0024881447709252536)]",5,True
5,226,8,"[(8, 0.9838265200580221), (5, 0.003357011279786091), (4, 0.002433643007997613)]",8,True
6,336,5,"[(5, 0.9685706950824172), (1, 0.00878301414338974), (3, 0.006654767911506455)]",5,True
7,339,5,"[(5, 0.9678465828994787), (1, 0.00937782107987187), (3, 0.0066344867509307084)]",5,True
8,365,24,"[(24, 0.9874359793298935), (4, 0.0022707630048804913), (5, 0.0017174635816137509)]",24,True
9,451,24,"[(24, 0.9877091186975115), (4, 0.0020676924221626728), (5, 0.0017079166716726809)]",24,True


In [84]:
# Fraction of repaired cells whose predicted value matches the ground truth.
n_correct = int(result['is_correct'].sum())
print(f'Accuracy: {n_correct / len(result)}')

Accuracy: 1.0
