In [2]:
import pyspark.sql.functions as f
from pyspark.sql import Window, SparkSession
import os, sys
os.environ['PYSPARK_PYTHON'] = sys.executable
os.environ['PYSPARK_DRIVER_PYTHON'] = sys.executable

In [3]:
from pyspark import SparkConf

In [4]:
conf = SparkConf()
conf.set("spark.app.name","ML Intro")
conf.set("spark.master","local[6]")

spark = SparkSession.builder\
                    .config(conf=conf)\
                    .getOrCreate()

In [5]:
parsed = spark.read\
              .format("csv")\
              .option("header","true")\
              .option("inferSchema","true")\
              .option("nullValue","?")\
              .load(r"C:\Users\blais\Documents\ML\data\linkage\block*\*.csv")

In [6]:
parsed.printSchema()

root
 |-- id_1: integer (nullable = true)
 |-- id_2: integer (nullable = true)
 |-- cmp_fname_c1: double (nullable = true)
 |-- cmp_fname_c2: double (nullable = true)
 |-- cmp_lname_c1: double (nullable = true)
 |-- cmp_lname_c2: double (nullable = true)
 |-- cmp_sex: integer (nullable = true)
 |-- cmp_bd: integer (nullable = true)
 |-- cmp_bm: integer (nullable = true)
 |-- cmp_by: integer (nullable = true)
 |-- cmp_plz: integer (nullable = true)
 |-- is_match: boolean (nullable = true)



In [7]:
parsed.first()

Row(id_1=3148, id_2=8326, cmp_fname_c1=1.0, cmp_fname_c2=None, cmp_lname_c1=1.0, cmp_lname_c2=None, cmp_sex=1, cmp_bd=1, cmp_bm=1, cmp_by=1, cmp_plz=1, is_match=True)

In [8]:
parsed.show(5)

+-----+-----+------------+------------+------------+------------+-------+------+------+------+-------+--------+
| id_1| id_2|cmp_fname_c1|cmp_fname_c2|cmp_lname_c1|cmp_lname_c2|cmp_sex|cmp_bd|cmp_bm|cmp_by|cmp_plz|is_match|
+-----+-----+------------+------------+------------+------------+-------+------+------+------+-------+--------+
| 3148| 8326|         1.0|        NULL|         1.0|        NULL|      1|     1|     1|     1|      1|    true|
|14055|94934|         1.0|        NULL|         1.0|        NULL|      1|     1|     1|     1|      1|    true|
|33948|34740|         1.0|        NULL|         1.0|        NULL|      1|     1|     1|     1|      1|    true|
|  946|71870|         1.0|        NULL|         1.0|        NULL|      1|     1|     1|     1|      1|    true|
|64880|71676|         1.0|        NULL|         1.0|        NULL|      1|     1|     1|     1|      1|    true|
+-----+-----+------------+------------+------------+------------+-------+------+------+------+-------+--

In [9]:
parsed.count()

5749132

Introducing caching:
- Once the data has been parsed once, we'd like to save the data in its parsed form on the cluster so that we don't have to reparse it every time we want to ask a new question. Spark supports this use case by allowing us to signal that a given dataframe should be cached in memory after it is generated by calling the cache method on the instance. 

In [10]:
parsed.cache()

DataFrame[id_1: int, id_2: int, cmp_fname_c1: double, cmp_fname_c2: double, cmp_lname_c1: double, cmp_lname_c2: double, cmp_sex: int, cmp_bd: int, cmp_bm: int, cmp_by: int, cmp_plz: int, is_match: boolean]

Our DF has now been cached and the next thing we want to know is the relative fraction of records that were matches vs those that were non matches:

In [11]:
parsed.select('is_match').distinct().show()

+--------+
|is_match|
+--------+
|    true|
|   false|
+--------+



In [12]:
parsed.groupBy("is_match").count().orderBy(f.col("count").desc()).show()

+--------+-------+
|is_match|  count|
+--------+-------+
|   false|5728201|
|    true|  20931|
+--------+-------+



Dataframe Aggregation functions:
- In addition to count, we also can compute more complex aggregations. for example to find the mean and standard deciation of the cmp_sex field in the overall parsed df - we could use:

In [13]:
parsed.select(f.avg("cmp_sex"), f.stddev("cmp_sex")).show()

+-----------------+-------------------+
|     avg(cmp_sex)|    stddev(cmp_sex)|
+-----------------+-------------------+
|0.955001381078048|0.20730111116897781|
+-----------------+-------------------+



In [14]:
parsed.createOrReplaceTempView("linkage")

In [15]:
spark.sql("""
    SELECT is_match, COUNT(*)
    FROM linkage
    GROUP BY is_match
    ORDER BY COUNT(*) DESC
""").show()

+--------+--------+
|is_match|count(1)|
+--------+--------+
|   false| 5728201|
|    true|   20931|
+--------+--------+



Fact Summary Statistics for DataFrames:

In [16]:
summary = parsed.describe()

In [17]:
summary.show()

+-------+------------------+------------------+-------------------+------------------+------------------+-------------------+-------------------+-------------------+-------------------+------------------+-------------------+
|summary|              id_1|              id_2|       cmp_fname_c1|      cmp_fname_c2|      cmp_lname_c1|       cmp_lname_c2|            cmp_sex|             cmp_bd|             cmp_bm|            cmp_by|            cmp_plz|
+-------+------------------+------------------+-------------------+------------------+------------------+-------------------+-------------------+-------------------+-------------------+------------------+-------------------+
|  count|           5749132|           5749132|            5748125|            103698|           5749132|               2464|            5749132|            5748337|            5748337|           5748337|            5736289|
|   mean| 33324.48559643438| 66587.43558331935| 0.7129024704437266|0.9000176718903189|0.315627819308

In [18]:
summary.select("summary","cmp_fname_c1","cmp_fname_c2").show()

+-------+-------------------+------------------+
|summary|       cmp_fname_c1|      cmp_fname_c2|
+-------+-------------------+------------------+
|  count|            5748125|            103698|
|   mean| 0.7129024704437266|0.9000176718903189|
| stddev|0.38875835961628014|0.2713176105782334|
|    min|                0.0|               0.0|
|    max|                1.0|               1.0|
+-------+-------------------+------------------+



In [19]:
matches = parsed.where("is_match = true")
match_summary = matches.describe()

In [20]:
misses = parsed.filter(f.col("is_match")==False)
miss_summary = misses.describe()

Pivoting and Reshaping DataFrames:

In [21]:
summary_p = summary.toPandas()

In [22]:
summary_p.head()

Unnamed: 0,summary,id_1,id_2,cmp_fname_c1,cmp_fname_c2,cmp_lname_c1,cmp_lname_c2,cmp_sex,cmp_bd,cmp_bm,cmp_by,cmp_plz
0,count,5749132.0,5749132.0,5748125.0,103698.0,5749132.0,2464.0,5749132.0,5748337.0,5748337.0,5748337.0,5736289.0
1,mean,33324.48559643438,66587.43558331935,0.7129024704437266,0.9000176718903189,0.3156278193080383,0.3184128315317443,0.955001381078048,0.2244652670850717,0.488855298497635,0.2227485966810923,0.0055286614743434
2,stddev,23659.859374488064,23620.48761326969,0.3887583596162801,0.2713176105782334,0.3342336339615828,0.3685670662006653,0.2073011111689778,0.4172297223846263,0.4998758236779031,0.4160909629831756,0.0741491492542004
3,min,1.0,6.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,max,99980.0,100000.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [23]:
summary_p.shape

(5, 12)

In [24]:
summary_p = summary_p.set_index("summary").transpose().reset_index()

In [25]:
summary_p = summary_p.rename(columns={'index':'field'})

In [26]:
summary_p = summary_p.rename_axis(None, axis=1)

In [27]:
summary_p

Unnamed: 0,field,count,mean,stddev,min,max
0,id_1,5749132,33324.48559643438,23659.859374488064,1.0,99980.0
1,id_2,5749132,66587.43558331935,23620.48761326969,6.0,100000.0
2,cmp_fname_c1,5748125,0.7129024704437266,0.3887583596162801,0.0,1.0
3,cmp_fname_c2,103698,0.9000176718903189,0.2713176105782334,0.0,1.0
4,cmp_lname_c1,5749132,0.3156278193080383,0.3342336339615828,0.0,1.0
5,cmp_lname_c2,2464,0.3184128315317443,0.3685670662006653,0.0,1.0
6,cmp_sex,5749132,0.955001381078048,0.2073011111689778,0.0,1.0
7,cmp_bd,5748337,0.2244652670850717,0.4172297223846263,0.0,1.0
8,cmp_bm,5748337,0.488855298497635,0.4998758236779031,0.0,1.0
9,cmp_by,5748337,0.2227485966810923,0.4160909629831756,0.0,1.0


In [28]:
summaryT = spark.createDataFrame(summary_p)

In [29]:
summaryT.show()

+------------+-------+-------------------+-------------------+---+------+
|       field|  count|               mean|             stddev|min|   max|
+------------+-------+-------------------+-------------------+---+------+
|        id_1|5749132|  33324.48559643438| 23659.859374488064|  1| 99980|
|        id_2|5749132|  66587.43558331935| 23620.487613269695|  6|100000|
|cmp_fname_c1|5748125| 0.7129024704437266|0.38875835961628014|0.0|   1.0|
|cmp_fname_c2| 103698| 0.9000176718903189| 0.2713176105782334|0.0|   1.0|
|cmp_lname_c1|5749132| 0.3156278193080383| 0.3342336339615828|0.0|   1.0|
|cmp_lname_c2|   2464| 0.3184128315317443|0.36856706620066537|0.0|   1.0|
|     cmp_sex|5749132|  0.955001381078048|0.20730111116897781|  0|     1|
|      cmp_bd|5748337|0.22446526708507172|0.41722972238462636|  0|     1|
|      cmp_bm|5748337|0.48885529849763504| 0.4998758236779031|  0|     1|
|      cmp_by|5748337| 0.2227485966810923| 0.4160909629831756|  0|     1|
|     cmp_plz|5736289|0.00552866147434

In [30]:
summaryT.printSchema()

root
 |-- field: string (nullable = true)
 |-- count: string (nullable = true)
 |-- mean: string (nullable = true)
 |-- stddev: string (nullable = true)
 |-- min: string (nullable = true)
 |-- max: string (nullable = true)



In [31]:
from pyspark.sql.types import DoubleType

In [32]:
for c in summaryT.columns:
    if c == 'field':
        continue
    summaryT = summaryT.withColumn(c, summaryT[c].cast(DoubleType()))

In [33]:
summaryT.printSchema()

root
 |-- field: string (nullable = true)
 |-- count: double (nullable = true)
 |-- mean: double (nullable = true)
 |-- stddev: double (nullable = true)
 |-- min: double (nullable = true)
 |-- max: double (nullable = true)



In [34]:
from pyspark.sql import DataFrame
from pyspark.sql.types import DoubleType

In [35]:
def pivot_summary(desc):
    # convert to pandas dataframe
    desc_p = desc.toPandas()
    # transpose
    desc_p = desc_p.set_index("summary").transpose().reset_index()
    desc_p = desc_p.rename(columns={'index':'field'})
    desc_p = desc_p.rename_axis(None, axis=1)
    # convert to spark dataframe
    descT = spark.createDataFrame(desc_p)
    # convert metric columns to double from string
    for c in descT.columns:
        if c == 'field':
            continue
        else:
            descT = descT.withColumn(c, descT[c].cast(DoubleType()))
    return descT

In [36]:
match_summaryT = pivot_summary(match_summary)
miss_summaryT = pivot_summary(miss_summary)

Joining Dataframes:

In [37]:
match_summaryT.createOrReplaceTempView("match_desc")
miss_summaryT.createOrReplaceTempView("miss_desc")

In [38]:
spark.sql("""
    SELECT a.field, a.count + b.count total, a.mean - b.mean delta
  FROM match_desc a INNER JOIN miss_desc b ON a.field = b.field
  WHERE a.field NOT IN ("id_1", "id_2")
  ORDER BY delta DESC, total DESC
""").show()

+------------+---------+--------------------+
|       field|    total|               delta|
+------------+---------+--------------------+
|     cmp_plz|5736289.0|  0.9563812499852176|
|cmp_lname_c2|   2464.0|  0.8064147192926266|
|      cmp_by|5748337.0|  0.7762059675300512|
|      cmp_bd|5748337.0|   0.775442311783404|
|cmp_lname_c1|5749132.0|  0.6838772482594513|
|      cmp_bm|5748337.0|  0.5109496938298685|
|cmp_fname_c1|5748125.0|  0.2854529057459947|
|cmp_fname_c2| 103698.0| 0.09104268062280174|
|     cmp_sex|5749132.0|0.032408185250332844|
+------------+---------+--------------------+



A good feature has 2 properties: it tends to have significantly different values for matches and nonmatches (so the difference between the means ought to be large). A good feature will also occur very often in the data that we can rely on it to be regulary available for any pair of records. 
- By this measure, cmp_fname_c2 isn't very useful. 
- cmp_sex isn't also useful. 
- cmp_plz and cmp_by on the other hand are excellent. 

Scoring and Model Evaluation:

In [39]:
good_features = ["cmp_lname_c1", "cmp_plz", "cmp_by", "cmp_bd", "cmp_bm"]
sum_expression = "+".join(good_features)

In [40]:
sum_expression

'cmp_lname_c1+cmp_plz+cmp_by+cmp_bd+cmp_bm'

In [41]:
scored = parsed.fillna(0, subset=good_features)\
                .withColumn('score',f.expr(sum_expression))\
                .select('score','is_match')

In [42]:
scored.show()

+-----+--------+
|score|is_match|
+-----+--------+
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  4.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
|  4.0|    true|
|  5.0|    true|
|  5.0|    true|
|  5.0|    true|
+-----+--------+
only showing top 20 rows



In [43]:
scored.groupBy('is_match').agg(f.avg("score"),f.stddev("score")).show()

+--------+------------------+------------------+
|is_match|        avg(score)|     stddev(score)|
+--------+------------------+------------------+
|    true| 4.944413854980717|0.2418236049726355|
|   false|1.2436102486054796|0.5538759352302625|
+--------+------------------+------------------+



Final Step in creating our scoring function is to decide what threshold the score must exceed in order to predict that the 2 records represent a match. 
To help choose a threshold, its helpful to create a contingency table - sometimes known as a cross tabulation or cross tab. that counts the number of records whose scores fall above/below the threshold value crossed with the number of records in each of those categories that were/were not matches. 

In [44]:
# function that takes the scored dataframe as a param and computes the crosstabs using the dataframe api
def crossTabs(scored: DataFrame, t: DoubleType) -> DataFrame:
    return scored.selectExpr(f"score >= {t} as above","is_match")\
                 .groupBy("above").pivot("is_match",("true","false")).count()

In [45]:
crossTabs(scored, 4.0).show()

+-----+-----+-------+
|above| true|  false|
+-----+-----+-------+
| true|20871|    637|
|false|   60|5727564|
+-----+-----+-------+



In [46]:
crossTabs(scored, 2.0).show()

+-----+-----+-------+
|above| true|  false|
+-----+-----+-------+
| true|20931| 596414|
|false| NULL|5131787|
+-----+-----+-------+



The trade-off between false positives and negatives is evident from the above tables - raising the threshold high or lowering it has an impact on this. final threshold is application dependent.