In [1]:
import pyspark.conf
import pyspark.sql
SparkConf = pyspark.conf.SparkConf
SparkSession = pyspark.sql.SparkSession
spark = SparkSession.builder \
            .appName("Intro") \
            .config('spark.executor.memory', '2g') \
            .config('spark.driver.memory','8g') \
            .getOrCreate()

In [2]:
from pyspark.sql import functions as F

# This may take some time if you work locally
#prev = spark.read.csv("linkage/block*.csv")
prev = spark.read.csv("linkage/block_1.csv")
prev.show()

+-----+-----+-----------------+------------+------------+------------+-------+------+------+------+-------+--------+
|  _c0|  _c1|              _c2|         _c3|         _c4|         _c5|    _c6|   _c7|   _c8|   _c9|   _c10|    _c11|
+-----+-----+-----------------+------------+------------+------------+-------+------+------+------+-------+--------+
| id_1| id_2|     cmp_fname_c1|cmp_fname_c2|cmp_lname_c1|cmp_lname_c2|cmp_sex|cmp_bd|cmp_bm|cmp_by|cmp_plz|is_match|
|37291|53113|0.833333333333333|           ?|           1|           ?|      1|     1|     1|     1|      0|    TRUE|
|39086|47614|                1|           ?|           1|           ?|      1|     1|     1|     1|      1|    TRUE|
|70031|70237|                1|           ?|           1|           ?|      1|     1|     1|     1|      1|    TRUE|
|84795|97439|                1|           ?|           1|           ?|      1|     1|     1|     1|      1|    TRUE|
|36950|42116|                1|           ?|           1|       

The ? are missing values and the header wasn't read correctly

In [3]:
# We can infer the schema else it is string
# Note: To perform the schema inference. Spark must pass over the ddata set twice (detecting and parsing schema)
# If you know the schema you can define it using StructField and StructType:
"""
from pyspark.sql.types import *
schema = StructType([StructField("id_1", IntegerType(), False),
StructField("id_2", StringType(), False),
StructField("cmp_fname_c1", DoubleType(), False)])
spark.read.schema(schema).csv("...")
or
schema = "id_1 INT, id_2 INT, cmp_fname_c1 DOUBLE"
"""
prev.printSchema()
parsed = spark.read \
            .option("header", "true") \
            .option("nullValue", "?") \
            .option("inferSchema", "true") \
            .csv("linkage/block_1.csv")

root
 |-- _c0: string (nullable = true)
 |-- _c1: string (nullable = true)
 |-- _c2: string (nullable = true)
 |-- _c3: string (nullable = true)
 |-- _c4: string (nullable = true)
 |-- _c5: string (nullable = true)
 |-- _c6: string (nullable = true)
 |-- _c7: string (nullable = true)
 |-- _c8: string (nullable = true)
 |-- _c9: string (nullable = true)
 |-- _c10: string (nullable = true)
 |-- _c11: string (nullable = true)



In [4]:
parsed.show()

+-----+-----+-----------------+------------+------------+------------+-------+------+------+------+-------+--------+
| id_1| id_2|     cmp_fname_c1|cmp_fname_c2|cmp_lname_c1|cmp_lname_c2|cmp_sex|cmp_bd|cmp_bm|cmp_by|cmp_plz|is_match|
+-----+-----+-----------------+------------+------------+------------+-------+------+------+------+-------+--------+
|37291|53113|0.833333333333333|        null|         1.0|        null|      1|     1|     1|     1|      0|    true|
|39086|47614|              1.0|        null|         1.0|        null|      1|     1|     1|     1|      1|    true|
|70031|70237|              1.0|        null|         1.0|        null|      1|     1|     1|     1|      1|    true|
|84795|97439|              1.0|        null|         1.0|        null|      1|     1|     1|     1|      1|    true|
|36950|42116|              1.0|        null|         1.0|         1.0|      1|     1|     1|     1|      1|    true|
|42413|48491|              1.0|        null|         1.0|       

In [6]:
# sanity check
parsed.first()

Row(id_1=37291, id_2=53113, cmp_fname_c1=0.833333333333333, cmp_fname_c2=None, cmp_lname_c1=1.0, cmp_lname_c2=None, cmp_sex=1, cmp_bd=1, cmp_bm=1, cmp_by=1, cmp_plz=0, is_match=True)

## Analyzing Data with the DataFrame API

In [6]:
parsed.count()

574913

In [7]:
parsed.show(5)

+-----+-----+-----------------+------------+------------+------------+-------+------+------+------+-------+--------+
| id_1| id_2|     cmp_fname_c1|cmp_fname_c2|cmp_lname_c1|cmp_lname_c2|cmp_sex|cmp_bd|cmp_bm|cmp_by|cmp_plz|is_match|
+-----+-----+-----------------+------------+------------+------------+-------+------+------+------+-------+--------+
|37291|53113|0.833333333333333|        null|         1.0|        null|      1|     1|     1|     1|      0|    true|
|39086|47614|              1.0|        null|         1.0|        null|      1|     1|     1|     1|      1|    true|
|70031|70237|              1.0|        null|         1.0|        null|      1|     1|     1|     1|      1|    true|
|84795|97439|              1.0|        null|         1.0|        null|      1|     1|     1|     1|      1|    true|
|36950|42116|              1.0|        null|         1.0|         1.0|      1|     1|     1|     1|      1|    true|
+-----+-----+-----------------+------------+------------+-------

In [5]:
parsed.printSchema()

root
 |-- id_1: integer (nullable = true)
 |-- id_2: integer (nullable = true)
 |-- cmp_fname_c1: double (nullable = true)
 |-- cmp_fname_c2: double (nullable = true)
 |-- cmp_lname_c1: double (nullable = true)
 |-- cmp_lname_c2: double (nullable = true)
 |-- cmp_sex: integer (nullable = true)
 |-- cmp_bd: integer (nullable = true)
 |-- cmp_bm: integer (nullable = true)
 |-- cmp_by: integer (nullable = true)
 |-- cmp_plz: integer (nullable = true)
 |-- is_match: boolean (nullable = true)



In [10]:
from pyspark.sql.functions import col
parsed.groupBy("is_match").count().orderBy(col("count").desc()).show()

+--------+------+
|is_match| count|
+--------+------+
|   false|572820|
|    true|  2093|
+--------+------+



In [11]:
parsed.cache()

DataFrame[id_1: int, id_2: int, cmp_fname_c1: double, cmp_fname_c2: double, cmp_lname_c1: double, cmp_lname_c2: double, cmp_sex: int, cmp_bd: int, cmp_bm: int, cmp_by: int, cmp_plz: int, is_match: boolean]

In [12]:
# To use SQL syntax create temporary table
parsed.createOrReplaceTempView("linkage")

In [13]:
spark.sql("""
            SELECT is_match, COUNT(*) cnt
            FROM linkage
            GROUP BY is_match
            ORDER BY cnt DESC
        """).show()

+--------+------+
|is_match|   cnt|
+--------+------+
|   false|572820|
|    true|  2093|
+--------+------+



In [15]:
summary = parsed.describe()
summary.select("summary", "cmp_fname_c1", "cmp_fname_c2").show()

+-------+------------------+------------------+
|summary|      cmp_fname_c1|      cmp_fname_c2|
+-------+------------------+------------------+
|  count|            574811|             10325|
|   mean|0.7127592938253411|0.8977586763518969|
| stddev|0.3889286452463531|0.2742577520430532|
|    min|               0.0|               0.0|
|    max|               1.0|               1.0|
+-------+------------------+------------------+



In [25]:
matches = parsed.where("is_match = true")
misses  = parsed.where("is_match = false")
matchSummary = matches.describe()
missSummary = misses.describe()

In [18]:
# Panda conversion
summary_p = summary.toPandas()
summary_p = summary_p.set_index('summary').transpose().reset_index()
summary_p = summary_p.rename(columns={'index':'field'})
summary_p = summary_p.rename_axis(None, axis=1)
summaryT = spark.createDataFrame(summary_p)
summaryT.show()

+------------+------+--------------------+-------------------+---+------+
|       field| count|                mean|             stddev|min|   max|
+------------+------+--------------------+-------------------+---+------+
|        id_1|574913|  33271.962171667714| 23622.669425933756|  1| 99894|
|        id_2|574913|    66564.6636865056|  23642.00230967228|  6|100000|
|cmp_fname_c1|574811|  0.7127592938253411| 0.3889286452463531|0.0|   1.0|
|cmp_fname_c2| 10325|  0.8977586763518969| 0.2742577520430532|0.0|   1.0|
|cmp_lname_c1|574913|  0.3155724578100624| 0.3342494687554245|0.0|   1.0|
|cmp_lname_c2|   239|  0.3269155414552904| 0.3783092020540671|0.0|   1.0|
|     cmp_sex|574913|  0.9550923357099248|0.20710152240504442|  0|     1|
|      cmp_bd|574851| 0.22475563232907309| 0.4174216587235557|  0|     1|
|      cmp_bm|574851|  0.4886361857246487| 0.4998712818281637|  0|     1|
|      cmp_by|574851| 0.22266639529199742| 0.4160365041645591|  0|     1|
|     cmp_plz|573618|0.005494946113964

In [19]:
summaryT.printSchema()

root
 |-- field: string (nullable = true)
 |-- count: string (nullable = true)
 |-- mean: string (nullable = true)
 |-- stddev: string (nullable = true)
 |-- min: string (nullable = true)
 |-- max: string (nullable = true)



In [22]:
from pyspark.sql.types import DoubleType
# Cast to double type
for c in summaryT.columns:
    if c == 'field':
        continue
    summaryT = summaryT.withColumn(c, summaryT[c].cast(DoubleType()))
    summaryT.printSchema()

root
 |-- field: string (nullable = true)
 |-- count: double (nullable = true)
 |-- mean: string (nullable = true)
 |-- stddev: string (nullable = true)
 |-- min: string (nullable = true)
 |-- max: string (nullable = true)

root
 |-- field: string (nullable = true)
 |-- count: double (nullable = true)
 |-- mean: double (nullable = true)
 |-- stddev: string (nullable = true)
 |-- min: string (nullable = true)
 |-- max: string (nullable = true)

root
 |-- field: string (nullable = true)
 |-- count: double (nullable = true)
 |-- mean: double (nullable = true)
 |-- stddev: double (nullable = true)
 |-- min: string (nullable = true)
 |-- max: string (nullable = true)

root
 |-- field: string (nullable = true)
 |-- count: double (nullable = true)
 |-- mean: double (nullable = true)
 |-- stddev: double (nullable = true)
 |-- min: double (nullable = true)
 |-- max: string (nullable = true)

root
 |-- field: string (nullable = true)
 |-- count: double (nullable = true)
 |-- mean: double (nullab

In [23]:
from pyspark.sql import DataFrame
from pyspark.sql.types import DoubleType
def pivot_summary(desc):
    # convert to pandas dataframe
    desc_p = desc.toPandas()
    # transpose
    desc_p = desc_p.set_index('summary').transpose().reset_index()
    desc_p = desc_p.rename(columns={'index':'field'})
    desc_p = desc_p.rename_axis(None, axis=1)
    # convert to Spark dataframe
    descT = spark.createDataFrame(desc_p)
    # convert metric columns to double from string
    for c in descT.columns:
        if c == 'field':
            continue
        else:
            descT = descT.withColumn(c, descT[c].cast(DoubleType()))
    return descT

In [28]:
match_summaryT = pivot_summary(matchSummary)
miss_summaryT = pivot_summary(missSummary)

In [29]:
match_summaryT.createOrReplaceTempView("match_desc")
miss_summaryT.createOrReplaceTempView("miss_desc")
spark.sql("""
SELECT a.field, a.count + b.count total, a.mean - b.mean delta
FROM match_desc a INNER JOIN miss_desc b ON a.field = b.field
WHERE a.field NOT IN ("id_1", "id_2")
ORDER BY delta DESC, total DESC
""").show()

+------------+--------+-------------------+
|       field|   total|              delta|
+------------+--------+-------------------+
|     cmp_plz|573618.0| 0.9524975516429005|
|cmp_lname_c2|   239.0| 0.8136949970410103|
|      cmp_by|574851.0| 0.7763379425859384|
|      cmp_bd|574851.0| 0.7732820129086737|
|cmp_lname_c1|574913.0| 0.6844795197261291|
|      cmp_bm|574851.0|  0.510834819548174|
|cmp_fname_c1|574811.0|0.28531156828518667|
|cmp_fname_c2| 10325.0|0.09900440489032714|
|     cmp_sex|574913.0|0.03452211590529575|
+------------+--------+-------------------+



## Scoring and Model Evaluation

In [30]:
good_features = ["cmp_lname_c1", "cmp_plz", "cmp_by", "cmp_bd", "cmp_bm"]
sum_expression = " + ".join(good_features)
sum_expression

'cmp_lname_c1 + cmp_plz + cmp_by + cmp_bd + cmp_bm'

In [31]:
from pyspark.sql.functions import expr
scored = parsed.fillna(0, subset=good_features).\
withColumn('score', expr(sum_expression)).\
select('score', 'is_match')
scored.show()

+-----+--------+
|score|is_match|
+-----+--------+
|  4.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  4.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  4.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
+-----+--------+
only showing top 20 rows



In [33]:
def crossTabs(scored: DataFrame, t: DoubleType) -> DataFrame:
    return scored.selectExpr(f"score >= {t} as above", "is_match").\
groupBy("above").pivot("is_match", ("true", "false")).\
count()
crossTabs(scored, 4.0).show()

+-----+----+------+
|above|true| false|
+-----+----+------+
| true|2087|    66|
|false|   6|572754|
+-----+----+------+



In [34]:
crossTabs(scored, 2.0).show()

+-----+----+------+
|above|true| false|
+-----+----+------+
| true|2093| 59729|
|false|null|513091|
+-----+----+------+

