In [3]:
import pyspark
import os
import sys
from pyspark import SparkContext
os.environ['PYSPARK_PYTHON'] = sys.executable
os.environ['PYSPARK_DRIVER_PYTHON'] = sys.executable
from pyspark.sql import SparkSession
spark = SparkSession.builder.config("spark.driver.memory", "16g").appName('chapter_2').getOrCreate()

In [4]:
prev = spark.read.option("recursiveFileLookup","true").csv("databases/*.csv")
prev

DataFrame[_c0: string, _c1: string, _c2: string, _c3: string, _c4: string, _c5: string, _c6: string, _c7: string, _c8: string, _c9: string, _c10: string, _c11: string]

In [5]:
prev.show(2)

+-----+-----+-----------------+------------+------------+------------+-------+------+------+------+-------+--------+
|  _c0|  _c1|              _c2|         _c3|         _c4|         _c5|    _c6|   _c7|   _c8|   _c9|   _c10|    _c11|
+-----+-----+-----------------+------------+------------+------------+-------+------+------+------+-------+--------+
| id_1| id_2|     cmp_fname_c1|cmp_fname_c2|cmp_lname_c1|cmp_lname_c2|cmp_sex|cmp_bd|cmp_bm|cmp_by|cmp_plz|is_match|
|37291|53113|0.833333333333333|           ?|           1|           ?|      1|     1|     1|     1|      0|    TRUE|
+-----+-----+-----------------+------------+------------+------------+-------+------+------+------+-------+--------+
only showing top 2 rows



In [6]:
parsed = spark.read.option("header", "true").option("nullValue", "?").option("inferSchema", "true").csv("databases/*.csv")

In [7]:
parsed.printSchema()
parsed.show(5)

root
 |-- id_1: string (nullable = true)
 |-- id_2: string (nullable = true)
 |-- cmp_fname_c1: string (nullable = true)
 |-- cmp_fname_c2: string (nullable = true)
 |-- cmp_lname_c1: string (nullable = true)
 |-- cmp_lname_c2: string (nullable = true)
 |-- cmp_sex: string (nullable = true)
 |-- cmp_bd: string (nullable = true)
 |-- cmp_bm: string (nullable = true)
 |-- cmp_by: integer (nullable = true)
 |-- cmp_plz: integer (nullable = true)
 |-- is_match: boolean (nullable = true)

+-----+-----+-----------------+------------+------------+------------+-------+------+------+------+-------+--------+
| id_1| id_2|     cmp_fname_c1|cmp_fname_c2|cmp_lname_c1|cmp_lname_c2|cmp_sex|cmp_bd|cmp_bm|cmp_by|cmp_plz|is_match|
+-----+-----+-----------------+------------+------------+------------+-------+------+------+------+-------+--------+
|37291|53113|0.833333333333333|        null|           1|        null|      1|     1|     1|     1|      0|    true|
|39086|47614|                1|        null

In [8]:
parsed.count()

5749133

In [9]:
parsed.cache()

DataFrame[id_1: string, id_2: string, cmp_fname_c1: string, cmp_fname_c2: string, cmp_lname_c1: string, cmp_lname_c2: string, cmp_sex: string, cmp_bd: string, cmp_bm: string, cmp_by: int, cmp_plz: int, is_match: boolean]

In [10]:
from pyspark.sql.functions import col
parsed.groupBy("is_match").count().orderBy(col("count").desc()).show()

+--------+-------+
|is_match|  count|
+--------+-------+
|   false|5728201|
|    true|  20931|
|    null|      1|
+--------+-------+



In [11]:
parsed.createOrReplaceTempView("linkage")

In [12]:
spark.sql("""
    SELECT is_match, COUNT(*) cnt
    FROM linkage
    GROUP BY is_match
    ORDER BY cnt DESC
""").show()

+--------+-------+
|is_match|    cnt|
+--------+-------+
|   false|5728201|
|    true|  20931|
|    null|      1|
+--------+-------+



In [13]:
summary = parsed.describe()

In [14]:
summary.select("summary", "cmp_fname_c1", "cmp_fname_c2").show()

+-------+--------------------+------------------+
|summary|        cmp_fname_c1|      cmp_fname_c2|
+-------+--------------------+------------------+
|  count|             5748126|            103699|
|   mean|  0.7129023464241683|0.9000089989364238|
| stddev|  0.3887584395082916|0.2713306768152377|
|    min|                   0|                 0|
|    max|2.68694413843136e-05|                 1|
+-------+--------------------+------------------+



In [15]:
matches = parsed.where("is_match = true")
match_summary = matches.describe()
misses = parsed.filter(col("is_match") == False)
miss_summary = misses.describe()

In [16]:
summary_p = summary.toPandas()

In [17]:
summary_p.head()
...
summary_p.shape
...

(5, 12)

In [18]:
summary_p = summary_p.set_index('summary').transpose().reset_index()
...
summary_p = summary_p.rename(columns={'index':'field'})
...
summary_p = summary_p.rename_axis(None, axis=1)
...
summary_p.shape

(11, 6)

In [19]:
summaryT = spark.createDataFrame(summary_p)
...
summaryT

DataFrame[field: string, count: string, mean: string, stddev: string, min: string, max: string]

In [20]:
summaryT.printSchema()

root
 |-- field: string (nullable = true)
 |-- count: string (nullable = true)
 |-- mean: string (nullable = true)
 |-- stddev: string (nullable = true)
 |-- min: string (nullable = true)
 |-- max: string (nullable = true)



In [21]:
from pyspark.sql.types import DoubleType
for c in summaryT.columns:
    if c == 'field':
        continue
    summaryT = summaryT.withColumn(c, summaryT[c].cast(DoubleType()))
summaryT.printSchema()

root
 |-- field: string (nullable = true)
 |-- count: double (nullable = true)
 |-- mean: double (nullable = true)
 |-- stddev: double (nullable = true)
 |-- min: double (nullable = true)
 |-- max: double (nullable = true)



In [22]:
from pyspark.sql import DataFrame
from pyspark.sql.types import DoubleType
def pivot_summary(desc):
    # convert to pandas dataframe
    desc_p = desc.toPandas()
    # transpose
    desc_p = desc_p.set_index('summary').transpose().reset_index()
    desc_p = desc_p.rename(columns={'index':'field'})
    desc_p = desc_p.rename_axis(None, axis=1)
    # convert to Spark dataframe
    descT = spark.createDataFrame(desc_p)
    # convert metric columns to double from string
    for c in descT.columns:
        if c == 'field':
            continue
        else:
            descT = descT.withColumn(c, descT[c].cast(DoubleType()))
        return descT

In [23]:
match_summaryT = pivot_summary(match_summary)
miss_summaryT = pivot_summary(miss_summary)

In [24]:
match_summaryT.createOrReplaceTempView("match_desc")
miss_summaryT.createOrReplaceTempView("miss_desc")
spark.sql("""
SELECT a.field, a.count + b.count total, a.mean - b.mean delta
FROM match_desc a INNER JOIN miss_desc b ON a.field = b.field
WHERE a.field NOT IN ("id_1", "id_2")
ORDER BY delta DESC, total DESC
""")

DataFrame[field: string, total: double, delta: double]

In [25]:
good_features = ["cmp_lname_c1", "cmp_plz", "cmp_by", "cmp_bd", "cmp_bm"]
...
sum_expression = " + ".join(good_features)
...
sum_expression

'cmp_lname_c1 + cmp_plz + cmp_by + cmp_bd + cmp_bm'

In [26]:
from pyspark.sql.functions import expr
scored = parsed.fillna(0, subset=good_features).\
                withColumn('score', expr(sum_expression)).\
                select('score', 'is_match')
...
scored.show()

+-----+--------+
|score|is_match|
+-----+--------+
|  4.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  4.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  4.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
+-----+--------+
only showing top 20 rows



In [27]:
def crossTabs(scored: DataFrame, t: DoubleType) -> DataFrame:
    return scored.selectExpr(f"score >= {t} as above", "is_match").\
        groupBy("above").pivot("is_match", ("true", "false")).\
        count()

In [28]:
crossTabs(scored, 4.0).show()

+-----+-----+-------+
|above| true|  false|
+-----+-----+-------+
| null|    6|    789|
| true|20871|    637|
|false|   54|5726775|
+-----+-----+-------+



In [29]:
confused=crossTabs(scored, 2.0)
confused.show()

+-----+-----+-------+
|above| true|  false|
+-----+-----+-------+
| null|    6|    789|
| true|20925| 596413|
|false| null|5130999|
+-----+-----+-------+



In [30]:
tp = confused.filter("above = true").select("true").collect()[0].true
fp = confused.filter("above = true").select("false").collect()[0].false
fn = confused.filter("above = false").select("true").fillna(0).collect()[0].true
tn = confused.filter("above = false").select("false").collect()[0].false

In [31]:
precision = tp/(tp+fp)
recall = tp/(tp+fn)
f1 = 2*precision*recall/(precision+recall)

In [32]:
precision

0.03389553210720869

In [33]:
recall

1.0

In [34]:
f1

0.06556858223020917

In [35]:
fn

0