In [1]:
import pyspark
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()



## Setting up the data and analyzing it

In [2]:
prev = spark.read.option("recursiveFileLookup","true").csv("data/linkage/donation/block_1.csv")

prev.show(3)

parsed = spark.read.option("header","true").option("nullValue","?").option("inferSchema","true").csv("data//linkage//donation//block_1.csv")
parsed.printSchema()
parsed.count()
parsed.cache()

parsed.show(5)

from pyspark.sql.functions import col
parsed.groupBy("is_match").count().orderBy(col("count").desc()).show()

parsed.createOrReplaceTempView("linkage")

+-----+-----+-----------------+------------+------------+------------+-------+------+------+------+-------+--------+
|  _c0|  _c1|              _c2|         _c3|         _c4|         _c5|    _c6|   _c7|   _c8|   _c9|   _c10|    _c11|
+-----+-----+-----------------+------------+------------+------------+-------+------+------+------+-------+--------+
| id_1| id_2|     cmp_fname_c1|cmp_fname_c2|cmp_lname_c1|cmp_lname_c2|cmp_sex|cmp_bd|cmp_bm|cmp_by|cmp_plz|is_match|
|37291|53113|0.833333333333333|           ?|           1|           ?|      1|     1|     1|     1|      0|    TRUE|
|39086|47614|                1|           ?|           1|           ?|      1|     1|     1|     1|      1|    TRUE|
+-----+-----+-----------------+------------+------------+------------+-------+------+------+------+-------+--------+
only showing top 3 rows

root
 |-- id_1: integer (nullable = true)
 |-- id_2: integer (nullable = true)
 |-- cmp_fname_c1: double (nullable = true)
 |-- cmp_fname_c2: double (nullab

In [3]:
spark.sql("""
    SELECT is_match,count(*) cnt
    FROM linkage
    GROUP BY is_match
    ORDER BY cnt DESC""").show()

+--------+------+
|is_match|   cnt|
+--------+------+
|   false|572820|
|    true|  2093|
+--------+------+



## Fast Summary Statistics, Plotting and Reshaping DataFrames

In [4]:
summary = parsed.describe()
summary.select("summary","cmp_fname_c1","cmp_fname_c2").show()

matches = parsed.where("is_match = true")
match_summary = matches.describe()

misses = parsed.filter(col('is_match') == False)
miss_summary = misses.describe()

+-------+------------------+------------------+
|summary|      cmp_fname_c1|      cmp_fname_c2|
+-------+------------------+------------------+
|  count|            574811|             10325|
|   mean|0.7127592938253411|0.8977586763518969|
| stddev|0.3889286452463531|0.2742577520430532|
|    min|               0.0|               0.0|
|    max|               1.0|               1.0|
+-------+------------------+------------------+



In [5]:
summary_p = summary.toPandas()
summary_p.head()
summary_p.shape

summary_p = summary_p.set_index('summary').transpose().reset_index()
summary_p = summary_p.rename(columns={'index':'field'})
summary_p = summary_p.rename_axis(None,axis=1)
summary_p.shape

summaryT = spark.createDataFrame(summary_p)
summaryT
summaryT.printSchema()

root
 |-- field: string (nullable = true)
 |-- count: string (nullable = true)
 |-- mean: string (nullable = true)
 |-- stddev: string (nullable = true)
 |-- min: string (nullable = true)
 |-- max: string (nullable = true)



In [6]:
from pyspark.sql.types import DoubleType
for c in summaryT.columns:
    if c == 'field':
        continue
    summaryT = summaryT.withColumn(c,summaryT[c].cast(DoubleType()))
    
summaryT.printSchema()

root
 |-- field: string (nullable = true)
 |-- count: double (nullable = true)
 |-- mean: double (nullable = true)
 |-- stddev: double (nullable = true)
 |-- min: double (nullable = true)
 |-- max: double (nullable = true)



In [7]:
from pyspark.sql import DataFrame
from pyspark.sql.types import DoubleType

def pivot_summary(desc):
    desc_p = desc.toPandas()
    desc_p = desc_p.set_index('summary').transpose().reset_index()
    desc_p = desc_p.rename(columns={'index':'field'})
    desc_p = desc_p.rename_axis(None,axis = 1)
    
    descT = spark.createDataFrame(desc_p)
    
    for c in descT.columns:
        if c == 'field':
            continue
        else:
            descT = descT.withColumn(c,descT[c].cast(DoubleType()))    
        return descT

match_summaryT = pivot_summary(match_summary)
miss_summaryT = pivot_summary(miss_summary)
match_summaryT.createOrReplaceTempView("match_desc")
miss_summaryT.createOrReplaceTempView("miss_desc")

In [8]:
spark.sql("""
    SELECT a.field, a.count + b.count total, a.mean - b.mean delta
    FROM match_desc a INNER JOIN miss_desc b ON a.field = b.field
    WHERE a.field NOT IN ("id_1","id_2")
    ORDER BY delta DESC, total DESC
""")

DataFrame[field: string, total: double, delta: double]

In [9]:
good_features = ["cmp_lname_c1","cmp_plz","cmp_by","cmp_bd","cmp_bm"]

sum_expression = "+".join(good_features)

sum_expression

'cmp_lname_c1+cmp_plz+cmp_by+cmp_bd+cmp_bm'

In [10]:
from pyspark.sql.functions import expr

scored = parsed.fillna(0,subset=good_features).withColumn('score',expr(sum_expression)).select('score','is_match')
scored.show()

def crossTabs(scored: DataFrame, t: DoubleType) -> DataFrame:
    return scored.selectExpr(f"score>={t} as above","is_match").groupBy("above").pivot("is_match",("true","false")).count()

cm1 = crossTabs(scored,4.0)
cm2 = crossTabs(scored,2.0)
crossTabs(scored,4.0).show()
crossTabs(scored,4.0).show()

+-----+--------+
|score|is_match|
+-----+--------+
|  4.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  4.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  4.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
+-----+--------+
only showing top 20 rows

+-----+----+------+
|above|true| false|
+-----+----+------+
| true|2087|    66|
|false|   6|572754|
+-----+----+------+

+-----+----+------+
|above|true| false|
+-----+----+------+
| true|2087|    66|
|false|   6|572754|
+-----+----+------+



In [11]:
TP=cm1.filter("above==true").select("true").collect()[0].true
TN=cm1.filter("above==true").select("false").collect()[0].false
FP=cm1.filter("above==false").select("true").collect()[0].true
FN=cm1.filter("above==false").select("false").collect()[0].false

precision = TP/(TP + FP)
recall = TP/(TP + FN)
f1score = 2*precision*recall/(precision+recall)

print(f"Precision->{precision}\nRecall->{recall}\nF1-Score->{f1score}")

Precision->0.9971333014811276
Recall->0.0036305691486863325
F1-Score->0.007234796354522354
