In [6]:
import pyspark
import os
import sys
from pyspark import SparkContext
os.environ['PYSPARK_PYTHON'] = sys.executable
os.environ['PYSPARK_DRIVER_PYTHON'] = sys.executable

In [7]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.config("spark.driver.memory", "16g")\
.appName('entitty').getOrCreate()

In [11]:
data = spark.read.option("recursifeFileLookup", "true").csv("data/*.csv")

In [12]:
data

DataFrame[_c0: string, _c1: string, _c2: string, _c3: string, _c4: string, _c5: string, _c6: string, _c7: string, _c8: string, _c9: string, _c10: string, _c11: string]

In [13]:
data.show(2)

+-----+-----+-----------------+------------+------------+------------+-------+------+------+------+-------+--------+
|  _c0|  _c1|              _c2|         _c3|         _c4|         _c5|    _c6|   _c7|   _c8|   _c9|   _c10|    _c11|
+-----+-----+-----------------+------------+------------+------------+-------+------+------+------+-------+--------+
| id_1| id_2|     cmp_fname_c1|cmp_fname_c2|cmp_lname_c1|cmp_lname_c2|cmp_sex|cmp_bd|cmp_bm|cmp_by|cmp_plz|is_match|
|37291|53113|0.833333333333333|           ?|           1|           ?|      1|     1|     1|     1|      0|    TRUE|
+-----+-----+-----------------+------------+------------+------------+-------+------+------+------+-------+--------+
only showing top 2 rows



In [14]:
parsed =  spark.read.option("header","true").option("nullValue","?").\
option("inferSchema","true").csv("data/*.csv")

In [18]:
parsed.printSchema()
parsed.show(5)

root
 |-- id_1: string (nullable = true)
 |-- id_2: string (nullable = true)
 |-- cmp_fname_c1: string (nullable = true)
 |-- cmp_fname_c2: string (nullable = true)
 |-- cmp_lname_c1: string (nullable = true)
 |-- cmp_lname_c2: string (nullable = true)
 |-- cmp_sex: string (nullable = true)
 |-- cmp_bd: string (nullable = true)
 |-- cmp_bm: string (nullable = true)
 |-- cmp_by: integer (nullable = true)
 |-- cmp_plz: integer (nullable = true)
 |-- is_match: boolean (nullable = true)

+-----+-----+-----------------+------------+------------+------------+-------+------+------+------+-------+--------+
| id_1| id_2|     cmp_fname_c1|cmp_fname_c2|cmp_lname_c1|cmp_lname_c2|cmp_sex|cmp_bd|cmp_bm|cmp_by|cmp_plz|is_match|
+-----+-----+-----------------+------------+------------+------------+-------+------+------+------+-------+--------+
|37291|53113|0.833333333333333|        null|           1|        null|      1|     1|     1|     1|      0|    true|
|39086|47614|                1|        null

In [19]:
parsed.count()

5749133

In [20]:
parsed.cache()

DataFrame[id_1: string, id_2: string, cmp_fname_c1: string, cmp_fname_c2: string, cmp_lname_c1: string, cmp_lname_c2: string, cmp_sex: string, cmp_bd: string, cmp_bm: string, cmp_by: int, cmp_plz: int, is_match: boolean]

In [21]:
from pyspark.sql.functions import col
parsed.groupBy("is_match").count().orderBy(col("count").desc()).show()

+--------+-------+
|is_match|  count|
+--------+-------+
|   false|5728201|
|    true|  20931|
|    null|      1|
+--------+-------+



In [22]:
parsed.createOrReplaceTempView("linkage")

In [23]:
spark.sql("""
SELECT is_match, COUNT(*) cnt
FROM linkage
GROUP BY is_match
ORDER BY cnt DESC
""").show()

+--------+-------+
|is_match|    cnt|
+--------+-------+
|   false|5728201|
|    true|  20931|
|    null|      1|
+--------+-------+



In [24]:
summary = parsed.describe()

In [25]:
summary.select("summary","cmp_fname_c1",'cmp_fname_c2').show()

+-------+--------------------+------------------+
|summary|        cmp_fname_c1|      cmp_fname_c2|
+-------+--------------------+------------------+
|  count|             5748126|            103699|
|   mean|  0.7129023464241683|0.9000089989364238|
| stddev|  0.3887584395082916|0.2713306768152377|
|    min|                   0|                 0|
|    max|2.68694413843136e-05|                 1|
+-------+--------------------+------------------+



In [27]:
matches = parsed.where('is_match = true')
matchSummary = matches.describe()
misses = parsed.filter(col('is_match')==False)
missSummary = misses.describe()

In [47]:
from pyspark.sql import DataFrame
from pyspark.sql.types import DoubleType
def pivot(d):
    dp = d.toPandas()
    dp = dp.set_index("summary").transpose().reset_index()\
.rename(columns={'index':'field'}).rename_axis(None, axis = 1)
    dt = spark.createDataFrame(dp)
    for c in dt.columns:
        if c=='field':
            continue
        dt = dt.withColumn(c, dt[c].cast(DoubleType()))
    return dt

In [48]:
matchSummaryT = pivot(matchSummary)
missSummaryT = pivot(missSummary)

In [148]:
matchSummaryT.createOrReplaceTempView("match_desc")
missSummaryT.createOrReplaceTempView("miss_desc")
spark.sql("""
SELECT a.field, a.count + b.count total, a.mean - b.mean delta
FROM match_desc a INNER JOIN miss_desc b ON a.field = b.field
WHERE a.field NOT IN ("id_1", "id_2")
ORDER BY delta DESC, total DESC
""").show()

+------------+---------+--------------------+
|       field|    total|               delta|
+------------+---------+--------------------+
|     cmp_plz|5736289.0|  0.9563812499852176|
|cmp_lname_c2|   2464.0|  0.8064147192926268|
|      cmp_by|5748337.0|  0.7762059675300512|
|      cmp_bd|5748337.0|   0.775442311783404|
|cmp_lname_c1|5749132.0|  0.6838772482599225|
|      cmp_bm|5748337.0|  0.5109496938298685|
|cmp_fname_c1|5748125.0| 0.28545290574676274|
|cmp_fname_c2| 103698.0| 0.09104268062279941|
|     cmp_sex|5749132.0|0.032408185250332844|
+------------+---------+--------------------+



In [51]:
good_features = ["cmp_lname_c1", "cmp_plz", "cmp_by", "cmp_bd", "cmp_bm"]
...
sum_expression = " + ".join(good_features)
...
sum_expression

'cmp_lname_c1 + cmp_plz + cmp_by + cmp_bd + cmp_bm'

In [58]:
from pyspark.sql.functions import expr
scored = parsed.fillna(0, subset=good_features).\
withColumn('score', expr(sum_expression)).\
select('score', 'is_match')

scored.show()

+-----+--------+
|score|is_match|
+-----+--------+
|  4.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  4.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  4.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
+-----+--------+
only showing top 20 rows



In [55]:
def crossTabs(scored: DataFrame, t: DoubleType) -> DataFrame:
    return scored.selectExpr(f"score >= {t} as above", "is_match").\
groupBy("above").pivot("is_match", ("true", "false")).\
count()

In [56]:
crossTabs(scored, 4.0).show()

+-----+-----+-------+
|above| true|  false|
+-----+-----+-------+
| null|    6|    789|
| true|20871|    637|
|false|   54|5726775|
+-----+-----+-------+



In [78]:
confused = crossTabs(scored, 2.0)

In [131]:
tp = confused.filter("above = true").select("true").collect()[0].true
fp = confused.filter("above = true").select("false").collect()[0].false
fn = confused.filter("above = false").select("true").fillna(0).collect()[0].true
tn = confused.filter("above = false").select("false").collect()[0].false

In [139]:
precision = tp/(tp+fp)
recall = tp/(tp+fn)
f1 = 2*precision*recall/(precision+recall)

In [145]:
precision

0.03389553210720869

In [146]:
recall

1.0

In [147]:
f1

0.06556858223020917

In [149]:
confused.show()

+-----+-----+-------+
|above| true|  false|
+-----+-----+-------+
| null|    6|    789|
| true|20925| 596413|
|false| null|5130999|
+-----+-----+-------+

