# analyzeLinkageAlgorithms
This notebook serves as a one-stop shop for record linkage analysis, experimentation, and evaluation. By adjusting the parameter settings and runtime mode below, a user can perform the following:

* Run a comparative analysis of three record linkage algorithms: a Python implementation of LAC's current algorithm (in general form, without post-processing heuristics), the DIBBs Basic algorithm, and the DIBBs Log-Odds Enhanced algorithm.
* Train an updated set of DIBBs Enhanced Algorithm log-odds weights to suit a specific suite of production data. Then, visually profile and identify cutoff thresholds for each pass of the DIBBs Enhanced Algorithm using previously trained or newly trained log-odds weights.
* Simultaneously test a suite of different fuzzy matching thresholds for one or more fields of patient data.

Much of the code used in these tasks overlaps, but to make the notebook easy to use, it is divided into many smaller sections with called-out settings and explanations specific to each task.

To begin, run the cells below to install the notebook's dependencies; then, select which mode you'd like to run the notebook in.

In [None]:
%pip install --upgrade pip

In [None]:
%pip install psycopg2-binary azure-identity recordlinkage azure-keyvault-secrets rapidfuzz numpy pandas matplotlib fhirpathpy

## Mode Selection
In the cell below, enter the string corresponding to the mode you'd like to run the notebook in:

* `"compare_algorithms"`: this mode computationally compares the performance of the three algorithms of interest (a Python port of LAC's current algorithm, DIBBs Basic, and DIBBs Enhanced). Individual algorithm results are saved along the way, and final results are both displayed in the notebook and exported to a file.
* `"train_weights"`: this mode computes population-appropriate weights for each field of patient data in the extract loaded from the MPI. These new weights serve as the log-odds scores in the DIBBs Enhanced algorithm. Generating the weights does not require running a record linkage algorithm and is fairly quick; graphically profiling the impact of the trained weights, however, does. Once the weights are computed, the notebook runs each pass of the DIBBs Enhanced algorithm in isolation, accumulating the scores each record _would have earned_ if linkage were being run in full. The distributions of these scores are used together with the labels the notebook generates to visually separate the true matches from the non-matches, so that a user can identify a cutoff to use as a new parameter setting in the Enhanced algorithm.
* `"test_thresholds"`: this mode evaluates a list of different fuzzy matching thresholds for a single field of patient data (e.g. `birthdate` or `first_name`). Each provided possible threshold has a subset of performance statistics computed for it, and when the testing framework completes, these values are reported for each tested threshold so that a user can determine which threshold performed the best.
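The weight-training code itself is not part of this section. As a rough illustration only, a common way to derive per-field log-odds weights (the Fellegi-Sunter style ratio) is to compare how often a field agrees among true-match pairs (`m`) with how often it agrees among sampled non-match pairs (`u`, which is what a setting like `NEG_SAMPLES` would control), and take `log(m / u)`. The sketch below assumes that formulation, natural logarithms, and exact agreement for simplicity; the function name and toy pairs are hypothetical, and the notebook may instead count fuzzy agreement.

```python
import math
from typing import List, Tuple


def log_odds_weight(
    match_pairs: List[Tuple[str, str]],
    non_match_pairs: List[Tuple[str, str]],
    smoothing: float = 1e-6,
) -> float:
    """Return log(m / u) for one field (hypothetical helper).

    m: fraction of true-match pairs whose values agree exactly.
    u: fraction of sampled non-match pairs whose values agree exactly.
    A small smoothing term avoids division by zero and log(0).
    """
    m = sum(a == b for a, b in match_pairs) / len(match_pairs)
    u = sum(a == b for a, b in non_match_pairs) / len(non_match_pairs)
    return math.log((m + smoothing) / (u + smoothing))


# Toy example: birthdates agree in 9 of 10 true-match pairs (m = 0.9)
# but in only 1 of 100 sampled non-match pairs (u = 0.01)
true_pairs = [("1990-01-01", "1990-01-01")] * 9 + [("1990-01-01", "1990-01-02")]
sampled_non_pairs = [("1990-01-01", "1985-05-05")] * 99 + [("1990-01-01", "1990-01-01")]
weight = log_odds_weight(true_pairs, sampled_non_pairs)  # roughly log(90)
```

Under this formulation, fields that often agree by chance among non-matches (such as `sex` or `state`) receive small weights, which is consistent with the low `sex` and `state` values in `DIBBS_ENHANCED_LOG_ODDS_SCORES` later in the notebook.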

These modes are mutually exclusive within a single notebook run; e.g., running `train_weights` precludes running `test_thresholds`. Each mode has its own mode-specific parameter settings, which can be configured in the sections below.

Once the MPI has been seeded, the three modes can be run in any order, but if you intend to use all of them, the most logical sequence is the following:

* for each field of interest, run the notebook in `test_thresholds` mode to identify the best fuzzy-matching threshold for each field to be used in the desired algorithm; `test_thresholds` mode relies only on the DIBBs Basic algorithm, so it needs no additional inputs or functions from weight training or the Enhanced algorithm
* modify the settings for `train_weights` mode by adding these experimentally determined per-field thresholds to the `COLS_TO_THRESHOLDS` parameter as key-value pairs (e.g. `"first_name": 0.92`)
* run the notebook in `train_weights` mode once for each pass of the algorithm to adjust (for the DIBBs algorithm, this will be twice), passing in the appropriate `COLS_TO_PROFILE` values for the pass of interest; for each such run, use the profiling graph to determine the best cutoff threshold for that pass
* modify the settings for `compare_algorithms` mode by updating the final cutoff threshold parameters and the log-odds weights with these experimentally determined values
* run the notebook in `compare_algorithms` mode to see the performance impacts of the tuning
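The first step above leaves the choice of "best" threshold to the user. One reasonable rule, assuming the reported statistics include true-positive, false-positive, and false-negative counts per threshold, is to pick the threshold with the highest F1 score and write it into the thresholds dictionary. The counts and helper names below are made up for illustration and are not output from the notebook.

```python
from typing import Dict, Tuple


def f1_score(tp: int, fp: int, fn: int) -> float:
    """Harmonic mean of precision and recall; 0.0 when there are no true positives."""
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


def best_threshold(results: Dict[float, Tuple[int, int, int]]) -> float:
    """Return the threshold whose (tp, fp, fn) counts give the highest F1."""
    return max(results, key=lambda t: f1_score(*results[t]))


# Hypothetical (tp, fp, fn) counts from a first_name threshold sweep
sweep = {
    0.85: (950, 120, 50),
    0.90: (930, 40, 70),
    0.95: (860, 10, 140),
}
cols_to_thresholds = {"first_name": 0.9, "last_name": 0.9}
cols_to_thresholds["first_name"] = best_threshold(sweep)
```

F1 weighs false positives and false negatives equally; if mislinking two patients is costlier than missing a link, a precision-weighted score may be a better selection rule.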

In [None]:
# MODE SELECTION
# Options are: "compare_algorithms", "train_weights", and "test_thresholds"
NOTEBOOK_MODE = "compare_algorithms"
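Because the mode is a free-form string, a typo would silently skip every mode-specific branch later in the notebook. A small guard like the one below (a hypothetical addition, not part of the original notebook) fails fast instead; in the notebook it would be called as `validate_mode(NOTEBOOK_MODE)` right after the mode is set.

```python
VALID_MODES = ("compare_algorithms", "train_weights", "test_thresholds")


def validate_mode(mode: str) -> str:
    """Return `mode` unchanged, or raise ValueError if it is not a supported mode."""
    if mode not in VALID_MODES:
        raise ValueError(f"NOTEBOOK_MODE must be one of {VALID_MODES}, got {mode!r}")
    return mode


validate_mode("compare_algorithms")
```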

## MPI Access Configuration
The cell below contains the settings that must be configured so the notebook can access the MPI in which production data is stored. Configure the db_client information and table names to reflect the structure of the MPI.

In [None]:
# Set your Key Vault information
vault_name = "$KEY_VAULT"
KEY_VAULT_URL = f"https://{vault_name}.vault.azure.net"
vault_linked_service = "$KEY_VAULT_LINKED_SERVICE"

# Set up db_client
DB_NAME = "DibbsMpiDB"
DB_USER = "postgres"
DB_HOST = "$MPI_DB_HOST"
DB_PORT = "5432"

# MPI table names
DB_TABLE_PATIENT = "patient"
DB_TABLE_PERSON = "person"
DB_TABLE_NAME = "name"
DB_TABLE_GIVEN_NAME = "given_name"
DB_TABLE_ADDRESS = "address"
DB_TABLE_IDENTIFIER = "identifier"
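Several values in this cell (`$KEY_VAULT`, `$KEY_VAULT_LINKED_SERVICE`, `$MPI_DB_HOST`) look like deployment-time placeholders meant to be substituted before the notebook runs. If one is left unsubstituted, the failure only surfaces later as a confusing connection error. The hypothetical check below assumes unsubstituted placeholders still begin with `$`; the example host name is invented.

```python
from typing import Dict, List


def find_unresolved_placeholders(settings: Dict[str, str]) -> List[str]:
    """Return the names of settings whose values still look like `$PLACEHOLDER`."""
    return [name for name, value in settings.items() if str(value).startswith("$")]


example_settings = {
    "vault_name": "$KEY_VAULT",           # never substituted
    "DB_HOST": "mpi.example.internal",    # invented resolved value
    "DB_PORT": "5432",
}
unresolved = find_unresolved_placeholders(example_settings)
if unresolved:
    print(f"Unresolved placeholders: {unresolved}")
```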

## General Notebook Settings
The cell below contains a selection of imports and configuration options used across all modes. These settings include the size of the dataset to use, the fuzzy-matching thresholds for ground-truth labeling, and the window size of the neighborhood to search around each field when building the indexer (see below).
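For intuition about `WINDOW_INDEX_SIZE`, the sketch below shows the classic sorted-neighbourhood idea in plain Python: sort records on one field and only propose pairs whose positions in sorted order are fewer than `window` apart, so a larger window yields more candidate pairs (and more labeling work). This is an illustration under our own simplifying assumptions (one key field, ties kept in input order), not the notebook's indexer; libraries such as `recordlinkage` may define the window differently (e.g. centered on each record).

```python
from typing import List, Set, Tuple


def sorted_neighbourhood_pairs(keys: List[str], window: int) -> Set[Tuple[int, int]]:
    """Return candidate (i, j) index pairs, i < j, for records whose positions
    in key-sorted order are fewer than `window` apart."""
    if window < 2:
        return set()
    order = sorted(range(len(keys)), key=lambda idx: keys[idx])
    pairs = set()
    for pos, i in enumerate(order):
        # Pair each record with the next (window - 1) records in sorted order
        for j in order[pos + 1 : pos + window]:
            pairs.add((min(i, j), max(i, j)))
    return pairs


last_names = ["SMITH", "SMYTH", "JONES", "SMITHE", "JOHNSON"]
candidates = sorted_neighbourhood_pairs(last_names, window=2)
```

Note that neighbours in sorted order are not necessarily similar (here `JONES` and `SMITH` end up adjacent); the indexer only proposes candidates, and the labelers decide which candidates are true matches.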

In [None]:
# Imports for secure access and ground-truth labeling
from azure.identity import DefaultAzureCredential
import time
import copy
import pandas as pd
import recordlinkage as rl
from recordlinkage.base import BaseCompareFeature
import numpy as np

# Adjust data volume for scaling
# Make sure evaluation size is less than or equal to labeling size!
# If running in `compare_algorithms` mode or `train_weights` mode, we recommend both sizes be set to 150000.
# If running in `test_thresholds` mode, we recommend any value between 50000 and 150000.
LABELING_SIZE = 200000
EVALUATION_SIZE = 200000

# Adjust the parameters for ground truth labeling.
# Window size is the number of nearest neighbors that will be chosen as possible candidates
# during the indexing of each field. We recommend a value no larger than 7.
# Jaro Threshold and Birthday threshold are "fuzzy match" thresholds used by the true-match
# labelers to decide if two candidates are indeed a match. Birthday should not be set below 0.95.
# Jaro can be set anywhere above 0.85, but we recommend at least 0.9.
WINDOW_INDEX_SIZE = 5
JARO_THRESHOLD = 0.9
BIRTHDAY_THRESHOLD = 0.95
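`JARO_THRESHOLD` is compared against a string-similarity score between 0 and 1. The notebook installs `rapidfuzz` for this kind of metric; to show what the number means, below is a self-contained, clarity-over-speed sketch of the common Jaro-Winkler formula (prefix scale 0.1, prefix capped at 4 characters, bonus applied only when the Jaro score exceeds 0.7) and how a threshold turns it into a match decision. It is not necessarily identical to the similarity function the labelers call.

```python
def jaro(s1: str, s2: str) -> float:
    """Jaro similarity: shared characters within a window, penalized by transpositions."""
    if s1 == s2:
        return 1.0
    len1, len2 = len(s1), len(s2)
    if len1 == 0 or len2 == 0:
        return 0.0
    match_distance = max(max(len1, len2) // 2 - 1, 0)
    s1_matches = [False] * len1
    s2_matches = [False] * len2
    matches = 0
    for i, ch in enumerate(s1):
        start = max(0, i - match_distance)
        end = min(i + match_distance + 1, len2)
        for j in range(start, end):
            if not s2_matches[j] and s2[j] == ch:
                s1_matches[i] = s2_matches[j] = True
                matches += 1
                break
    if matches == 0:
        return 0.0
    # Count matched characters that appear in a different order in the two strings
    transpositions = 0
    k = 0
    for i in range(len1):
        if not s1_matches[i]:
            continue
        while not s2_matches[k]:
            k += 1
        if s1[i] != s2[k]:
            transpositions += 1
        k += 1
    t = transpositions / 2
    return (matches / len1 + matches / len2 + (matches - t) / matches) / 3


def jaro_winkler(s1: str, s2: str, prefix_scale: float = 0.1,
                 max_prefix: int = 4, boost_threshold: float = 0.7) -> float:
    """Jaro score boosted by the length of the common prefix (up to max_prefix)."""
    sim = jaro(s1, s2)
    if sim <= boost_threshold:
        return sim
    prefix = 0
    for a, b in zip(s1[:max_prefix], s2[:max_prefix]):
        if a != b:
            break
        prefix += 1
    return sim + prefix * prefix_scale * (1.0 - sim)


example_threshold = 0.9  # same value as JARO_THRESHOLD above
is_match = jaro_winkler("JONATHAN", "JONATHON") >= example_threshold
```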

In [None]:
# DEFAULT DIBBS ALGORITHMS

# UPDATED DIBBs ALGORITHMS
# These algorithms and log odds scores are the updated values developed after
# substantial statistical tuning. Older algorithms can be found below.
DIBBS_ENHANCED_LOG_ODDS_SCORES = {
    'address': 8.438284928858774,
    'birthdate': 10.126641103800338,
    'city': 2.438553006137189,
    'first_name': 6.849475906891162,
    'last_name': 6.350720397426025,
    'mrn': 0.3051262572525359,
    'sex': 0.7510419059643679,
    'state': 0.022376768992488694,
    'zip': 4.975031471124867
}
FUZZY_THRESHOLDS = {
    "first_name": 0.9,
    "last_name": 0.9,
    "birthdate": 0.95,
    "address": 0.9,
    "city": 0.92,
    "zip": 0.95
}

DIBBS_BASIC = [
    {
        "funcs": {
            "first_name": "feature_match_fuzzy_string",
            "last_name": "feature_match_exact",
        },
        "blocks": [
            {"value": "birthdate"},
            {"value": "mrn", "transformation": "last4"},
            {"value": "sex"}
        ],
        "matching_rule": "eval_perfect_match",
        "cluster_ratio": 0.9,
        "kwargs": {
            "thresholds": FUZZY_THRESHOLDS
        }
    },
    {
        "funcs": {
            "address": "feature_match_fuzzy_string",
            "birthdate": "feature_match_exact",
        },
        "blocks": [
            {"value": "zip"},
            {"value": "first_name", "transformation": "first4"},
            {"value": "last_name", "transformation": "first4"},
            {"value": "sex"},
        ],
        "matching_rule": "eval_perfect_match",
        "cluster_ratio": 0.9,
        "kwargs": {
            "thresholds": FUZZY_THRESHOLDS
        }
    }
]

DIBBS_ENHANCED = [
    {
        "funcs": {
            "first_name": "feature_match_log_odds_fuzzy_compare",
            "last_name": "feature_match_log_odds_fuzzy_compare",
        },
        "blocks": [
            {"value": "birthdate"},
            {"value": "mrn", "transformation": "last4"},
            {"value": "sex"}
        ],
        "matching_rule": "eval_log_odds_cutoff",
        "cluster_ratio": 0.9,
        "kwargs": {
            "similarity_measure": "JaroWinkler",
            "thresholds": FUZZY_THRESHOLDS,
            "true_match_threshold": 12.2,
            "log_odds": DIBBS_ENHANCED_LOG_ODDS_SCORES,
        },
    },
    {
        "funcs": {
            "address": "feature_match_log_odds_fuzzy_compare",
            "birthdate": "feature_match_log_odds_fuzzy_compare",
        },
        "blocks": [
            {"value": "zip"},
            {"value": "first_name", "transformation": "first4"},
            {"value": "last_name", "transformation": "first4"},
            {"value": "sex"},
        ],
        "matching_rule": "eval_log_odds_cutoff",
        "cluster_ratio": 0.9,
        "kwargs": {
            "similarity_measure": "JaroWinkler",
            "thresholds": FUZZY_THRESHOLDS,
            "true_match_threshold": 17.0,
            "log_odds": DIBBS_ENHANCED_LOG_ODDS_SCORES,
        },
    }
]

# OLD DIBBs ALGORITHMS
# The algorithms and information listed below represent the preliminary
# algorithms DIBBs developed, prior to tuning and experimentation. They
# are included here mostly for posterity.

# DIBBS_ENHANCED_LOG_ODDS_SCORES = {
#     "birthdate": 9.944142836217619,
#     "first_name": 8.009121400325398,
#     "last_name": 5.327681398982514,
#     "sex": 0.6964525713514773,
#     "address": 5.769942276960749,
#     "city": 1.8002552875091014,
#     "state": 0.0,
#     "zip": 4.909466232098861,
#     "mrn": 1.464232660081324,
# }

# DIBBS_BASIC = [
#     {
#         "funcs": {
#             "first_name": "feature_match_fuzzy_string",
#             "last_name": "feature_match_fuzzy_string",
#             "birthdate": "feature_match_fuzzy_string",
#         },
#         "blocks": [
#             {"value": "mrn", "transformation": "last4"},
#             {"value": "address", "transformation": "first4"},
#         ],
#         "matching_rule": "eval_perfect_match",
#         "cluster_ratio": 0.9,
#     },
#     {
#         "funcs": {
#             "address": "feature_match_fuzzy_string",
#             "city": "feature_match_fuzzy_string",
#         },
#         "blocks": [
#             {"value": "first_name", "transformation": "first4"},
#             {"value": "last_name", "transformation": "first4"},
#         ],
#         "matching_rule": "eval_perfect_match",
#         "cluster_ratio": 0.9,
#     },
# ]

# DIBBS_ENHANCED = [
#     {
#         "funcs": {
#             "birthdate": "feature_match_log_odds_fuzzy_compare",
#             "first_name": "feature_match_log_odds_fuzzy_compare",
#             "last_name": "feature_match_log_odds_fuzzy_compare",
#         },
#         "blocks": [
#             {"value": "mrn", "transformation": "last4"},
#             {"value": "address", "transformation": "first4"},
#         ],
#         "matching_rule": "eval_log_odds_cutoff",
#         "cluster_ratio": 0.9,
#         "kwargs": {
#             "similarity_measure": "JaroWinkler",
#             "threshold": 0.7,
#             "true_match_threshold": 16.5,
#             "log_odds": DIBBS_ENHANCED_LOG_ODDS_SCORES,
#         },
#     },
#     {
#         "funcs": {
#             "zip": "feature_match_log_odds_fuzzy_compare",
#             "city": "feature_match_log_odds_fuzzy_compare",
#         },
#         "blocks": [
#             {"value": "first_name", "transformation": "first4"},
#             {"value": "last_name", "transformation": "first4"},
#         ],
#         "matching_rule": "eval_log_odds_cutoff",
#         "cluster_ratio": 0.9,
#         "kwargs": {
#             "similarity_measure": "JaroWinkler",
#             "threshold": 0.7,
#             "true_match_threshold": 7.0,
#             "log_odds": DIBBS_ENHANCED_LOG_ODDS_SCORES,
#         },
#     },
# ]
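Each pass in `DIBBS_ENHANCED` pairs a set of compared fields with a `true_match_threshold`. The functions named in `funcs` (such as `feature_match_log_odds_fuzzy_compare`) are defined elsewhere, so the sketch below only illustrates the idea under an explicit assumption: a field contributes its log-odds weight scaled by its similarity when that similarity meets the field's fuzzy threshold, contributes nothing otherwise, and a candidate pair is a match when the summed score meets the pass cutoff. The weights and cutoff are copied from pass 1 above; the similarity values are invented.

```python
from typing import Dict


def log_odds_pass_score(
    similarities: Dict[str, float],
    log_odds: Dict[str, float],
    thresholds: Dict[str, float],
) -> float:
    """Sum the evidence for one candidate pair in one pass.

    Assumption: a field adds similarity * log-odds weight when its similarity
    reaches the field's fuzzy threshold, and adds 0 otherwise.
    """
    score = 0.0
    for field, sim in similarities.items():
        if sim >= thresholds.get(field, 0.0):
            score += sim * log_odds[field]
    return score


pass_1_log_odds = {"first_name": 6.849475906891162, "last_name": 6.350720397426025}
pass_1_thresholds = {"first_name": 0.9, "last_name": 0.9}
pass_1_cutoff = 12.2

exact = log_odds_pass_score({"first_name": 1.0, "last_name": 1.0}, pass_1_log_odds, pass_1_thresholds)
fuzzy_first = log_odds_pass_score({"first_name": 0.92, "last_name": 1.0}, pass_1_log_odds, pass_1_thresholds)
last_mismatch = log_odds_pass_score({"first_name": 1.0, "last_name": 0.5}, pass_1_log_odds, pass_1_thresholds)
print(exact >= pass_1_cutoff, fuzzy_first >= pass_1_cutoff, last_mismatch >= pass_1_cutoff)
```

Under this assumption, pass 1's cutoff of 12.2 effectively requires both names to agree closely, since either name alone scores below 7.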


## Mode-Specific Settings
The cell below contains additional parameter settings that can be tuned for each specific mode.

In [None]:
# SETTINGS FOR `test_thresholds` MODE
# Testing field can be any supported field name in a patient resource
# We recommend testing no more than 7 values at once
TESTING_FIELD = "first_name"
TESTING_VALS = [0.85, 0.88, 0.90, 0.92, 0.95]

# SETTINGS FOR `train_weights` MODE
# For small data sets (< 10k records), use 50k neg samples.
# For mid sized (10k - 25k records), use 75k samples.
# Any larger, don't go above 100k samples.
NEG_SAMPLES = 100000
# Cols_to_profile should be an array of field names corresponding to one pass
# of a linkage algorithm. This means to run profiling for DIBBs,
# two separate runs will have to be made (since each pass has its own cutoff).
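# Pass 1 should be ["first_name", "last_name", "birthdate"]
# Pass 2 should be ["address", "city"]
# Note: these lists match the old DIBBs Enhanced passes listed above; the updated
# DIBBS_ENHANCED compares ["first_name", "last_name"] in Pass 1 and
# ["address", "birthdate"] in Pass 2.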
COLS_TO_PROFILE = ["first_name", "last_name", "birthdate"]
# Cols_to_thresholds should be a dictionary mapping field name strings (e.g.
# "first_name") to floats giving the minimum fuzzy matching threshold permissible
# to count for enhanced weight scoring. These numbers can be experimentally
# determined using the `test_thresholds` mode.
COLS_TO_THRESHOLDS = {
    "first_name": 0.9,
    "last_name": 0.9,
    "birthdate": 0.95,
    "address": 0.9,
    "city": 0.92,
    "zip": 0.95
}

# SETTINGS FOR `compare_algorithms` MODE
# If True, before running any algorithm, check the `linkage-notebook-outputs` container
# of the storage account for the environment this notebook is running in for a
# previously saved linkage result for that algorithm, and load it instead of
# rerunning the algorithm. This is safe whenever the order of data in the MPI
# hasn't changed. To keep this setting True for some algorithms but force others
# to re-run, manually delete that algorithm's output file from the Azure container
# that holds linkage results.
LOAD_PREVIOUS_RUNS = True
# Final_cols_to_thresholds should be a dictionary with the same structure and K-V
# pairs as COLS_TO_THRESHOLDS determined above; this parameter simply uses those
# values in the algorithm comparison evaluation
FINAL_COLS_TO_THRESHOLDS = COLS_TO_THRESHOLDS
# The experimentally determined values for the updated log-odds weights, calculated
# using `train_weights` mode
FINAL_LOG_ODDS = {
    'address': 8.438284928858774,
    'birthdate': 10.126641103800338,
    'city': 2.438553006137189,
    'first_name': 6.849475906891162,
    'last_name': 6.350720397426025,
    'mrn': 0.3051262572525359,
    'sex': 0.7510419059643679,
    'state': 0.022376768992488694,
    'zip': 4.975031471124867
}

# Each of these parameters should be the cutoff threshold value determined with 
# experimental profiling in `train_weights` mode. To use the default DIBBs enhanced
# cutoffs, set them to None.
FINAL_PASS_1_THRESHOLD = 12.2
FINAL_PASS_2_THRESHOLD = 17.0
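The settings above say that a `None` cutoff means "use the default DIBBs Enhanced cutoff". The code that applies these overrides is not in this section; one way to do it without mutating the original configuration is sketched below. The function name is hypothetical, and it assumes each pass stores its cutoff under `kwargs["true_match_threshold"]`, as `DIBBS_ENHANCED` does above.

```python
import copy
from typing import Dict, List, Optional


def apply_final_settings(
    algorithm: List[dict],
    log_odds: Dict[str, float],
    thresholds: Dict[str, float],
    pass_cutoffs: List[Optional[float]],
) -> List[dict]:
    """Return a tuned deep copy of `algorithm`; a None cutoff keeps the pass default."""
    tuned = copy.deepcopy(algorithm)
    for linkage_pass, cutoff in zip(tuned, pass_cutoffs):
        kwargs = linkage_pass.setdefault("kwargs", {})
        kwargs["log_odds"] = log_odds
        kwargs["thresholds"] = thresholds
        if cutoff is not None:
            kwargs["true_match_threshold"] = cutoff
    return tuned


default_passes = [
    {"kwargs": {"true_match_threshold": 12.2}},
    {"kwargs": {"true_match_threshold": 17.0}},
]
tuned = apply_final_settings(default_passes, {"first_name": 6.8}, {"first_name": 0.9}, [None, 16.5])
```

In the notebook this would be called as `apply_final_settings(DIBBS_ENHANCED, FINAL_LOG_ODDS, FINAL_COLS_TO_THRESHOLDS, [FINAL_PASS_1_THRESHOLD, FINAL_PASS_2_THRESHOLD])`. Deep-copying keeps the untuned `DIBBS_ENHANCED` available for side-by-side comparison.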

## File System Mounting
The cell below creates a Spark session for use in the remainder of the notebook. It also mounts the Azure storage container that holds linkage outputs over a secure connection. This lets us save the results we accumulate, so that if the connection is disrupted, the notebook times out, or we simply wish to view previous results, we can reload them with ease.

In [None]:
import json
from notebookutils import mssparkutils
from pyspark.sql import SparkSession

# Set paths
STORAGE_ACCOUNT = "$STORAGE_ACCOUNT"
LINKAGE_OUTPUTS_FILESYSTEM = f"abfss://linkage-notebook-outputs@{STORAGE_ACCOUNT}.dfs.core.windows.net/"
BLOB_STORAGE_LINKED_SERVICE = "$BLOB_STORAGE_LINKED_SERVICE"

"""
Function that writes the output of a linkage algorithm to a json file.
"""
def write_linkage_results(fname, results):
    # Results come in as a dict of ints to sets; JSON supports neither int keys
    # nor sets, so convert keys and set members to strings before dumping
    res_to_write = {str(k):[str(x) for x in list(v)] for (k,v) in results.items()}
    res_to_write = json.dumps(res_to_write)
    mssparkutils.fs.put(LINKAGE_OUTPUTS_FILESYSTEM + fname + ".json", res_to_write, True)


"""
Function that loads the output of a linkage algorithm from a json file using spark.read
and converts it into the same format as the results dict (ints to sets).
"""
def load_linkage_results(spark_session, fname):
    try:
        res = spark_session.read.json(LINKAGE_OUTPUTS_FILESYSTEM + fname + ".json")
    except Exception:
        print("Existing results not found.")
        return None
    
    print("Existing results found!")
    res = res.toPandas()
    res = res.to_dict()
    res = {int(k):set([int(x) for x in v[0]]) for (k,v) in res.items()}
    return res


# Instantiate a spark session for use with the rest of the notebook
credential = DefaultAzureCredential()
db_password = TokenLibrary.getSecret(vault_name, "mpi-db-password", vault_linked_service)
url = f"jdbc:postgresql://{DB_HOST}:{DB_PORT}/{DB_NAME}"
db_props = {
    "user": DB_USER,
    "password": db_password,
    "driver": "org.postgresql.Driver"
}
spark = (
    SparkSession.builder.master("local[*]")
    .appName("Build sub-sampled MPI")
    .getOrCreate()
)

# Set up for writing to blob storage
linkage_bucket_name = "linkage-notebook-outputs"
blob_sas_token = mssparkutils.credentials.getConnectionStringOrCreds(BLOB_STORAGE_LINKED_SERVICE)
wasb_path = 'wasbs://%s@%s.blob.core.windows.net/' % (linkage_bucket_name, STORAGE_ACCOUNT)
spark.conf.set('fs.azure.sas.%s.%s.blob.core.windows.net' % (linkage_bucket_name, STORAGE_ACCOUNT), blob_sas_token)

# Try mounting the remote storage directory at the mount point
try:
    mssparkutils.fs.mount(
        wasb_path,
        "/",
        {"LinkedService": BLOB_STORAGE_LINKED_SERVICE}
    )
except Exception:
    print("Already mounted")
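`write_linkage_results` and `load_linkage_results` convert between the in-memory result format (a dict mapping integer IDs to sets of integer IDs) and JSON, which has neither integer keys nor sets. The pure-Python round trip below mirrors that conversion without Spark or Azure storage, which can be handy for checking the format locally; the helper names are hypothetical, and sorting the members is our choice to make the output deterministic.

```python
import json
from typing import Dict, Set


def results_to_json(results: Dict[int, Set[int]]) -> str:
    """Serialize {int: set[int]} as JSON with string keys and lists of strings."""
    return json.dumps({str(k): [str(x) for x in sorted(v)] for k, v in results.items()})


def results_from_json(text: str) -> Dict[int, Set[int]]:
    """Invert results_to_json, restoring integer keys and sets of integers."""
    return {int(k): {int(x) for x in v} for k, v in json.loads(text).items()}


linkage_results = {0: {0, 3}, 1: {1}, 2: {2, 4, 5}}
restored = results_from_json(results_to_json(linkage_results))
```

When Spark reads the same single-line JSON object back, each key becomes a column of a one-row DataFrame, which is why `load_linkage_results` takes `v[0]` after `toPandas().to_dict()`.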

## Extracting the MPI
Now we get into the meat of the notebook: reading and analyzing data from the MPI. The following cells open a connection to the MPI specified by the configuration options above, extract patient data from it, and format it into an appropriately sized sample for the rest of the notebook. All three notebook modes rely on clean, properly formatted patient data pulled from the MPI.

This is accomplished with `pyspark` using Spark "temporary views". We can read a dataset distributed across multiple tables into memory in parallel, and by registering a view for each table we want to access, `pyspark` lets us run queries or joins against that table without modifying the data actually stored in the MPI. Once we have these views, we start from the `patient` table and iteratively left join each other table onto its patient records, grouping and collating their information along the way, so that the result is a single table with one row per reconstructed patient health care record.

In [None]:
# MPI ACCESS
# Register a temporary view for each MPI table so we can extract them in parallel.
# A temporary view is a session-scoped name for a DataFrame read over JDBC, not a
# copy of the data; we only read through it, so the data stored in the database is
# never modified. We need one view for each table we want to access.
patient_view = "patient_view"
mpi_patient_data = spark.read.jdbc(url, DB_TABLE_PATIENT, properties=db_props)
mpi_patient_data.createOrReplaceTempView(patient_view)

name_view = "name_view"
mpi_name_data = spark.read.jdbc(url, DB_TABLE_NAME, properties=db_props)
mpi_name_data.createOrReplaceTempView(name_view)

given_name_view = "given_name_view"
mpi_given_name_data = spark.read.jdbc(url, DB_TABLE_GIVEN_NAME, properties=db_props)
mpi_given_name_data.createOrReplaceTempView(given_name_view)

address_view = "address_view"
mpi_address_data = spark.read.jdbc(url, DB_TABLE_ADDRESS, properties=db_props)
mpi_address_data.createOrReplaceTempView(address_view)

identifier_view = "identifier_view"
mpi_identifier_data = spark.read.jdbc(url, DB_TABLE_IDENTIFIER, properties=db_props)
mpi_identifier_data.createOrReplaceTempView(identifier_view)

In [None]:
# MPI EXTRACTION
# Pull all the various sources of patient information out of the MPI and
# massage them into a single table. This allows us to format the data for
# training and testing easily off the same spark DF.

import pyspark.sql.functions as F
from pyspark.sql.functions import struct

'''
Helper function to construct a complete string representation of a patient's given 
name from the various fields of a row struct in a pyspark dataframe pulled from
the MPI.
'''
def construct_full_given_name(row):
    gn = ""
    if row["given_name_list"] is not None:
        sorted_structs = sorted(row["given_name_list"], key=lambda x: x.given_name_index)
        gn = [x.given_name for x in sorted_structs]
        gn = [x for x in gn if x is not None]
        gn = " ".join(gn) if len(gn) > 0 else ""
    return row["name_id"], row["patient_id"], row["last_name"], gn


# Start with table with 1 row per patient so that when we left join, we preserve that
extracted_patient_data = spark.sql(f"SELECT * from {DB_TABLE_PATIENT}_view")

# Given names don't have a patient_id field, so compile the given names
# associated with each name_id entry in preparation to join to last names
extracted_given_names = spark.sql(f"SELECT * from {DB_TABLE_GIVEN_NAME}_view")
extracted_given_names = extracted_given_names.withColumn(
    "name_structs",
    struct(extracted_given_names.given_name, extracted_given_names.given_name_index)
)
extracted_given_names = extracted_given_names.groupBy("name_id").agg(F.collect_list("name_structs").alias("given_name_list"))
extracted_given_names.cache()

# Last names are 1:1 with a name_id representing all associated given names
# Last names also map back to patient_ids in the patient_table, so use left
# joins to preserve all present info and the 1 row per patient structure
extracted_name_data = spark.sql(f"SELECT * from {DB_TABLE_NAME}_view")
full_name_table = extracted_name_data.join(extracted_given_names, "name_id", "left")
full_name_table = full_name_table.rdd.map(construct_full_given_name).toDF(["name_id", "patient_id", "last_name", "given_name"])
full_name_table = full_name_table.withColumn("full_name_structs", struct(full_name_table.given_name, full_name_table.last_name))
full_name_table = full_name_table.groupBy("patient_id").agg(F.collect_list("full_name_structs").alias("full_name_list"))
extracted_patient_data = extracted_patient_data.join(full_name_table, "patient_id", "left")
extracted_patient_data.cache()

# Identifier table needs to be aggregated (patients can have multiple IDs, such as
# an MRN and an SSN) before it is left joined back to patients
extracted_identifier_data = spark.sql(f"SELECT * from {DB_TABLE_IDENTIFIER}_view")
extracted_identifier_data = extracted_identifier_data.withColumn("id_structs", struct(
    extracted_identifier_data.patient_identifier, extracted_identifier_data.type_code
))
extracted_identifier_data = extracted_identifier_data.groupBy("patient_id").agg(F.collect_list("id_structs").alias("ids_list"))
extracted_patient_data = extracted_patient_data.join(extracted_identifier_data, "patient_id", "left")

# Pack each address's fields into a struct and collect all of a patient's
# addresses into a list, then left join this back onto the patient table
extracted_address_data = spark.sql(f"SELECT * from {DB_TABLE_ADDRESS}_view")
extracted_address_data = extracted_address_data.withColumn("address_structs", struct(
    extracted_address_data.line_1,
    extracted_address_data.line_2,
    extracted_address_data.city,
    extracted_address_data.state,
    extracted_address_data.zip_code
))
extracted_address_data = extracted_address_data.groupBy("patient_id").agg(F.collect_list("address_structs").alias("address_list"))
extracted_patient_data = extracted_patient_data.join(extracted_address_data, "patient_id", "left")
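The extraction cell above repeats one pattern: `groupBy("patient_id")` plus `collect_list(...)` to gather a table's rows into a per-patient list, then a left join onto the patient table so every patient keeps exactly one row even when the other table has nothing for them. The plain-Python sketch below mimics that pattern on invented toy data to show the resulting shape; it uses none of the notebook's Spark objects.

```python
from collections import defaultdict
from typing import Dict, List


def collect_by_patient(rows: List[dict], value_key: str) -> Dict[str, list]:
    """Mimic groupBy("patient_id").agg(collect_list(value_key))."""
    grouped: Dict[str, list] = defaultdict(list)
    for row in rows:
        grouped[row["patient_id"]].append(row[value_key])
    return dict(grouped)


def left_join(patients: List[dict], grouped: Dict[str, list], column: str) -> List[dict]:
    """Mimic a left join on patient_id: every patient keeps exactly one row,
    and patients with no matching rows get None in the new column."""
    return [{**patient, column: grouped.get(patient["patient_id"])} for patient in patients]


# Toy data: one patient with two addresses, one patient with none
patients = [{"patient_id": "p1", "dob": "1990-01-01"}, {"patient_id": "p2", "dob": None}]
address_rows = [
    {"patient_id": "p1", "address_struct": ("1 Main St", None, "Springfield", "IL", "62701")},
    {"patient_id": "p1", "address_struct": ("9 Oak Ave", "Apt 2", "Springfield", "IL", "62704")},
]
joined = left_join(patients, collect_by_patient(address_rows, "address_struct"), "address_list")
```

The `None` for `p2` is why `create_patient_resource_from_spark_row` substitutes an empty list when `address_list`, `full_name_list`, or `ids_list` is `None`. Unlike this sketch, Spark does not guarantee the order of `collect_list` output, which is presumably why given-name structs carry `given_name_index` and are sorted in `construct_full_given_name`.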

## Data Formatting And Set Creation
With the data extracted and properly joined out of the MPI, the next cell transforms the pyspark DataFrames currently holding the scattered patient information into the FHIR- and pandas-based formats the remainder of the notebook expects. `None` typed data and empty strings are handled, then data is converted to a FHIR format for any downstream flattening. The cell generates two different formats of the same set of data:

* a list of FHIR-formatted data consisting of `EVALUATION_SIZE` records on which the notebook will perform record linkage (this can be considered the "testing set" version of the data)
* a pandas DataFrame of flattened list-structured data consisting of `LABELING_SIZE` records, on which the notebook will perform ground-truth labeling (so that we can assess algorithm performance relative to a "known" baseline).
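To make the two formats concrete, the sketch below builds a tiny FHIR `Patient` resource in the same shape that `create_patient_resource_from_spark_row` produces and flattens it with plain dictionary access into the `[id, None, ...]` layout that `flatten_patient_resource` returns, following the `FIELD_COLS` order. The real notebook evaluates FHIRPath expressions with `fhirpathpy`; this hand-written version, its simplified missing-value handling, and the sample patient are only an approximation for illustration.

```python
from typing import Any, List


def flatten_patient_sketch(resource: dict) -> List[Any]:
    """Approximate layout: [id, None, address, birthdate, city, first_name,
    last_name, mrn, sex, state, zip]. Missing list fields become [""] and
    missing scalar fields become ""."""
    names = resource.get("name") or [{}]
    addresses = resource.get("address") or [{}]
    # MRN: value of the first identifier whose type coding has code "MR"
    mrn = next(
        (
            ident.get("value", "")
            for ident in resource.get("identifier", [])
            if any(c.get("code") == "MR" for c in ident.get("type", {}).get("coding", []))
        ),
        "",
    )
    return [
        resource["id"],
        None,
        [" ".join(a.get("line", [])) for a in addresses],
        resource.get("birthDate", ""),
        addresses[0].get("city", ""),
        [g for n in names for g in n.get("given", [])] or [""],
        names[0].get("family", ""),
        mrn,
        resource.get("gender", ""),
        addresses[0].get("state", ""),
        addresses[0].get("postalCode", ""),
    ]


patient = {
    "resourceType": "Patient",
    "id": "42",
    "identifier": [{
        "type": {"coding": [{"system": "http://terminology.hl7.org/CodeSystem/v2-0203", "code": "MR"}]},
        "value": "000123",
    }],
    "name": [{"family": "DOE", "given": ["JANE", "Q"]}],
    "gender": "female",
    "birthDate": "1980-02-29",
    "address": [{"line": ["1 Main St", "Apt 2"], "city": "Springfield", "state": "IL", "postalCode": "62701"}],
}
flat = flatten_patient_sketch(patient)
```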

In [None]:
# TRAIN/TEST CREATION
# Use the extracted information from the MPI to create two sets of data, one
# in flattened array form in a pandas DF for labeling (training) and one as a 
# list of FHIR bundles for evaluation (testing).

from pyspark.sql.types import StructType, StructField, StringType
import fhirpathpy
from typing import Any, Callable, List, Literal, Union


selection_criteria_types = Literal["first", "last", "all"]

FIELD_COLS = ["address", "birthdate", "city", "first_name", "last_name", "mrn", "sex", "state", "zip"]
FIELD_COLS_TO_IDX = dict(zip(FIELD_COLS, range(len(FIELD_COLS))))
LINKING_FIELDS_TO_FHIRPATHS = {
    "first_name": "Patient.name.given",
    "last_name": "Patient.name.family",
    "birthdate": "Patient.birthDate",
    "address": "Patient.address.line",
    "zip": "Patient.address.postalCode",
    "city": "Patient.address.city",
    "state": "Patient.address.state",
    "sex": "Patient.gender",
    "mrn": "Patient.identifier.where(type.coding.code='MR').value",
}

"""
Returns value(s), according to the selection criteria, from a given list of values
parsed from a FHIR resource. With "all", the full list is returned unchanged; with
"first" or "last", a single value is returned, and if that value is a complex
structure (list or dict), it is converted to a string.
"""
def apply_selection_criteria(
    value: List[Any],
    selection_criteria: selection_criteria_types,
) -> str | List:
    if selection_criteria == "first":
        value = value[0]
    elif selection_criteria == "last":
        value = value[-1]
    elif selection_criteria == "all":
        return value
    else:
        raise ValueError(
            f'Selection criteria {selection_criteria} is not a valid option. Must be one of "first", "last", or "all".'  # noqa
        )

    if type(value) is dict:
        value = json.dumps(value)
    elif type(value) is list:
        value = ",".join(value)
    return value


"""
Returns the value(s) at the provided FHIRPath `path` in a resource, filtered by
`selection_criteria`. If the path doesn't map to an extant value in the resource,
returns `None` instead.
"""
def extract_value_with_resource_path(
    resource: dict,
    path: str,
    selection_criteria: selection_criteria_types = "first",
) -> Union[Any, None]:
    
    parse_function = get_fhirpathpy_parser(path)
    value = parse_function(resource)
    if len(value) == 0:
        return None
    else:
        value = apply_selection_criteria(value, selection_criteria)
        return value


"""
Accepts a FHIRPath expression, and returns a callable function
which returns the evaluated value at fhirpath_expression for
a specified FHIR resource.
"""
def get_fhirpathpy_parser(fhirpath_expression: str) -> Callable:
    return fhirpathpy.compile(fhirpath_expression)


"""
Formatting function to account for patient resources that have multiple
associated addresses. Each address is a self-contained object, replete
with its own `line` property that can hold a list of strings. This
function condenses that `line` into a single concatenated string, for
each address object, and returns the result in a properly formatted
list.
"""
def _condense_extract_address_from_resource(resource: dict, field: str):
    expanded_address_fhirpath = LINKING_FIELDS_TO_FHIRPATHS[field]
    expanded_address_fhirpath = ".".join(expanded_address_fhirpath.split(".")[:-1])
    list_of_address_objects = extract_value_with_resource_path(
        resource, expanded_address_fhirpath, "all"
    )
    list_of_usable_address_elements = []
    if field == "address":
        if list_of_address_objects is not None:
            list_of_address_lists = [
                ao.get(LINKING_FIELDS_TO_FHIRPATHS[field].split(".")[-1], [])
                for ao in list_of_address_objects
            ]
            list_of_usable_address_elements = [
                " ".join(obj) for obj in list_of_address_lists
            ]
    else:
        if list_of_address_objects is not None:
            for address_object in list_of_address_objects:
                list_of_usable_address_elements.append(
                    address_object.get(LINKING_FIELDS_TO_FHIRPATHS[field].split(".")[-1])
                )
    return list_of_usable_address_elements


"""
Helper method that flattens an incoming patient resource into a list: the
resource's `id`, a `None` placeholder, and then one flattened value per field,
ordered according to the keys of `col_to_idx`.
"""
def flatten_patient_resource(resource: dict, col_to_idx: dict) -> List:
    flattened_record = [
        flatten_patient_field_helper(resource, f) for f in col_to_idx.keys()
    ]
    flattened_record = [resource["id"], None] + flattened_record
    return flattened_record


"""
Helper function that determines the correct way to flatten a patient's
FHIR field based on the specific field in question. Names and Addresses,
because their lists can hold multiple objects, are fetched completely,
whereas other fields just have their first element used (since historical
information doesn't matter there).

For any field for which the value would be `None`, instead use an empty string
(if the field isn't first_name or address) or a list with one element, the
empty string (if the field is first_name or address). This ensures that
future loops over the elements don't disrupt the flow of the matching
algorithm.
"""
def flatten_patient_field_helper(resource: dict, field: str) -> Any:
    if field == "first_name":
        vals = extract_value_with_resource_path(
            resource, LINKING_FIELDS_TO_FHIRPATHS[field], selection_criteria="all"
        )
        return vals if vals is not None else [""]
    elif field in ["address", "city", "zip", "state"]:
        vals = _condense_extract_address_from_resource(resource, field)
        if field == "address":
            return vals if (vals is not None and len(vals) > 0) else [""]
        else:
            return vals[0] if (vals is not None and len(vals) > 0) else ""
    else:
        val = extract_value_with_resource_path(
            resource, LINKING_FIELDS_TO_FHIRPATHS[field], selection_criteria="first"
        )
        return val if val is not None else ""


"""
Function that transforms an aggregated row of patient information from joins of
MPI tables into a single FHIR Patient resource, with all information present
in the table loaded into appropriate fields and sub-fields. Returns a
`(person_id, patient_resource)` tuple.
"""
def create_patient_resource_from_spark_row(row):
    # Pull out the various fields from the passed-in row
    extracted_pid = row["patient_id"]
    extracted_birthdate = row["dob"] if row["dob"] is not None else ""
    extracted_gender = row["sex"] if row["sex"] is not None else ""
    extracted_names = row["full_name_list"] if row["full_name_list"] is not None else []
    extracted_identifiers = row["ids_list"] if row["ids_list"] is not None else []
    extracted_addresses = row["address_list"] if row["address_list"] is not None else []

    # Initialize a patient resource to append fields into
    patient_resource = {
        "resourceType": "Patient",
        "id": f"{extracted_pid}",
        "identifier": [],
        "name": [],
        "gender": f"{extracted_gender}",
        "birthDate": f"{extracted_birthdate}",
        "address": [],
    }

    # Append the appropriate list construct for each name present
    for en in extracted_names:
        givens = en.given_name if en.given_name is not None else ""
        givens = str(givens).split()
        last = en.last_name if en.last_name is not None else ""
        patient_resource["name"].append({
            "family": f"{last}",
            "given": givens
        })
    
    # Do the same for identifiers, following the correct coding scheme
    for ident in extracted_identifiers:
        patient_resource["identifier"].append({
            "type": {
                "coding": [
                    {
                        "system": "http://terminology.hl7.org/CodeSystem/v2-0203",
                        "code": ident.type_code
                    }
                ]
            },
            "value": ident.patient_identifier
        })
    
    # Finally, repeat for each address present in the patient row, making
    # sure to capture `line` elements appropriately
    for addr in extracted_addresses:
        l1 = addr.line_1 if addr.line_1 is not None else ""
        l2 = addr.line_2 if addr.line_2 is not None else ""
        lines = [x for x in [l1, l2] if x != "" and x != "None"]
        city = addr.city if addr.city is not None else ""
        state = addr.state if addr.state is not None else ""
        zipcode = addr.zip_code if addr.zip_code is not None else ""
        patient_resource["address"].append({
            "line": lines if len(lines) > 0 else [""],
            "city": f"{city}",
            "state": f"{state}",
            "postalCode": f"{zipcode}"
        })
    
    return (row["person_id"], patient_resource)


"""
Simple helper used for iteratively returning a single patient during parallel
processing.
"""
def yield_patient_resource(row):
    return row[1]


"""
Simple helper used for iteratively building a flattened representation of a 
FHIR patient resource, augmented with the person_id of the person to which
this patient is assigned in the MPI.
"""
def yield_flattened_patient_with_person_id(row):
    pid = row[0]
    fp = flatten_patient_resource(row[1], FIELD_COLS_TO_IDX)
    fp[1] = pid
    return fp


# Build the base FHIR and flattened row groups of data
fhir_mapped_data = extracted_patient_data.rdd.map(create_patient_resource_from_spark_row)
fhir_mapped_data.cache()
flattened_patient_data = fhir_mapped_data.map(yield_flattened_patient_with_person_id)
fhir_mapped_data = fhir_mapped_data.map(yield_patient_resource)

# Construct the labeling set. Mapping over the RDD discarded the DataFrame
# schema, so we cast it back with an explicit (all-string) schema
formatted_cols = ["patient_id", "person_id"] + FIELD_COLS
pyspark_schema = StructType([
    StructField(x, StringType(), True) for x in formatted_cols
])
flattened_patient_data = flattened_patient_data.toDF(pyspark_schema)
labeling_set = [list(x) for x in flattened_patient_data.collect()]
if LABELING_SIZE is not None and LABELING_SIZE < len(labeling_set):
    labeling_set = labeling_set[:LABELING_SIZE]
labeling_set = pd.DataFrame(labeling_set, columns=formatted_cols)
del flattened_patient_data

# Construct the evaluation set
evaluation_set = [x for x in fhir_mapped_data.collect()]
if EVALUATION_SIZE is not None and EVALUATION_SIZE < len(evaluation_set):
    evaluation_set = evaluation_set[:EVALUATION_SIZE]
evaluation_set = spark.sparkContext.parallelize(evaluation_set, numSlices=512)
del fhir_mapped_data

# Force cleanup of DB connections
del extracted_patient_data
del extracted_address_data
del extracted_identifier_data
del extracted_name_data
del full_name_table
del extracted_given_names
del mpi_patient_data
del mpi_name_data
del mpi_given_name_data
del mpi_address_data
del mpi_identifier_data

## Ground-Truth Labeling
This large section of the notebook creates the ground-truth labels that the notebook uses for evaluation. While the labels (which represent matches between candidate pairs of records in the Labeling Set) aren't guaranteed to be matches in reality, that isn't a problem for algorithm comparison, since all algorithms are measured against the same standard. Labeling follows this outline:

* build an index of possible candidate matches, which are the pairs of records in the Labeling Set that are "worth" considering for full match status (this is done by sliding a window over a sorted list of patients and pairing each record with some number of its neighbors)
* generate numerical comparisons of fuzzy match scores across different fields for each candidate pair in the index
* apply a variety of filtering rules to find candidate pairs with sufficiently high matching scores
* label these subsets as "true" matches for the purpose of the notebook

The code in the cells of this section will need to run regardless of the selected notebook mode, but importantly, the values here should generally **not** require any user input or changing. All Labeling tuning should be done using the settings at the top of the notebook.
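The sliding-window idea in the first bullet can be sketched in plain Python. This is a simplified illustration, not the recordlinkage implementation: it ranks the distinct key values, then pairs each record with every record whose key rank lies within `(window - 1) // 2` positions, which is how the `FirstNameSortedNeighborhood` class in the next cell treats its window. The function `sorted_neighbourhood_pairs` is hypothetical, and missing keys are skipped much like the indexer's `dropna` step.

```python
from typing import List, Optional, Set, Tuple


def sorted_neighbourhood_pairs(
    keys: List[Optional[str]], window: int = 3
) -> Set[Tuple[int, int]]:
    """Return candidate pairs (i, j), i < j, whose sorting keys sit within
    (window - 1) // 2 positions of each other among the sorted distinct keys."""
    if window < 1 or window % 2 == 0:
        raise ValueError("window must be a positive odd integer")
    half = (window - 1) // 2

    # Rank each distinct, non-missing key value in sorted order
    distinct = sorted({k for k in keys if k is not None})
    ranks = {value: rank for rank, value in enumerate(distinct)}

    # Group record positions by the rank of their key; missing keys are skipped
    members_by_rank = {}
    for idx, key in enumerate(keys):
        if key is not None:
            members_by_rank.setdefault(ranks[key], []).append(idx)

    pairs = set()
    for rank, members in members_by_rank.items():
        # Offsets 0..half suffice because the window is symmetric
        for offset in range(half + 1):
            for j in members_by_rank.get(rank + offset, []):
                for i in members:
                    if i != j:
                        pairs.add((min(i, j), max(i, j)))
    return pairs


toy_keys = ["smith", "smyth", "jones", "smith", "zhang"]
print(sorted(sorted_neighbourhood_pairs(toy_keys, window=3)))
# [(0, 1), (0, 2), (0, 3), (1, 3), (1, 4), (2, 3)]
```

Records 2 ("jones") and 4 ("zhang") are never compared because three distinct keys separate them; widening `window` trades more comparisons for fewer missed matches.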

In [None]:
# CANDIDATE INDEXING
# Generates tuples of all possible candidate pairs that the labeler will compute
# match likelihoods for.

from recordlinkage.index import SortedNeighbourhood, BaseIndexAlgorithm
from recordlinkage.utils import listify

'''
A custom indexing class built to operate on the first_name column returned from
the MPI. Since that column holds a list of strings (someone could have multiple
given names), we need a way to join these entries so that the same fuzzy
blocking window a regular column of strings would get can be applied. This joins
the names in copies of that column and then slides a sorted-neighbourhood window
over the joined values to find fuzzy blocking candidates.
'''
class FirstNameSortedNeighborhood(BaseIndexAlgorithm):
    def __init__(
        self,
        left_on=None,
        right_on=None,
        window=3,
        sorting_key_values=None,
        block_on=[],
        block_left_on=[],
        block_right_on=[],
        **kwargs
    ):
        super(FirstNameSortedNeighborhood, self).__init__(**kwargs)

        # variables to block on
        self.left_on = left_on
        self.right_on = right_on
        self.window = window
        self.sorting_key_values = sorting_key_values
        self.block_on = block_on
        self.block_left_on = block_left_on
        self.block_right_on = block_right_on

    def _get_left_and_right_on(self):
        """
        We only care about the de-dupe case which involves no self.right, but this
        still needs to be implemented for super compatibility.
        """
        if self.right_on is None:
            return (self.left_on, self.left_on)
        else:
            return (self.left_on, self.right_on)

    def _get_sorting_key_values(self, array1, array2):
        """
        Return the sorting key values as a series. This function is required by the"
        package for multi-index neighborhood filtering according to some papers it's"
        built on.
        """

        concat_arrays = np.concatenate([array1, array2])
        return np.unique(concat_arrays)

    def _link_index(self, df_a, df_b):
        df_a = df_a.copy()
        df_b = df_b.copy()
        df_a["first_name"] = df_a["first_name"].str.join(" ")
        df_b["first_name"] = df_a["first_name"].str.join(" ")
        left_on, right_on = self._get_left_and_right_on()
        left_on = listify(left_on)
        right_on = listify(right_on)

        window = self.window

        # Correctly generate blocking keys
        block_left_on = listify(self.block_left_on)
        block_right_on = listify(self.block_right_on)

        if self.block_on:
            block_left_on = listify(self.block_on)
            block_right_on = listify(self.block_on)

        blocking_keys = ["sorting_key"] + [
            "blocking_key_%d" % i for i, v in enumerate(block_left_on)
        ]

        # Format the data to thread with index pairs
        data_left = pd.DataFrame(df_a[listify(left_on) + block_left_on], copy=False)
        data_left.columns = blocking_keys
        data_left["index_x"] = np.arange(len(df_a))
        data_left.dropna(axis=0, how="any", subset=blocking_keys, inplace=True)

        data_right = pd.DataFrame(df_b[listify(right_on) + block_right_on], copy=False)
        data_right.columns = blocking_keys
        data_right["index_y"] = np.arange(len(df_b))
        data_right.dropna(axis=0, how="any", subset=blocking_keys, inplace=True)

        # sorting_key_values is the terminology in Data Matching [Christen,
        # 2012]
        if self.sorting_key_values is None:
            self.sorting_key_values = self._get_sorting_key_values(
                data_left["sorting_key"].values, data_right["sorting_key"].values
            )

        sorting_key_factors = pd.Series(
            np.arange(len(self.sorting_key_values)), index=self.sorting_key_values
        )

        data_left["sorting_key"] = data_left["sorting_key"].map(sorting_key_factors)
        data_right["sorting_key"] = data_right["sorting_key"].map(sorting_key_factors)

        # Internal window size
        _window = int((window - 1) / 2)

        def merge_lagged(x, y, w):
            """Merge two dataframes with a lag on in the sorting key."""

            y = y.copy()
            y["sorting_key"] = y["sorting_key"] + w

            return x.merge(y, how="inner")

        pairs_concat = [
            merge_lagged(data_left, data_right, w) for w in range(-_window, _window + 1)
        ]

        pairs_df = pd.concat(pairs_concat, axis=0)

        return pd.MultiIndex(
            levels=[df_a.index.values, df_b.index.values],
            codes=[pairs_df["index_x"].values, pairs_df["index_y"].values],
            verify_integrity=False,
        )


"""
Function that builds an indexed list of the possible matches between candidate
pairs of records that we should consider for full match evaluation.
"""
def find_candidate_links(data):
    start = time.time()
    # Create a windowed neighborhood index on the patient table, because a
    # full index over all pairs is too expensive
    indexer = rl.Index()
    # Adding multiple different neighborhoods takes their union so we don't over-block
    indexer.add(SortedNeighbourhood('last_name', window=WINDOW_INDEX_SIZE))
    indexer.add(SortedNeighbourhood('birthdate', window=WINDOW_INDEX_SIZE))
    indexer.add(SortedNeighbourhood('mrn', window=WINDOW_INDEX_SIZE))
    indexer.add(FirstNameSortedNeighborhood('first_name', window=WINDOW_INDEX_SIZE))
    candidate_links = indexer.index(data)
    print(len(candidate_links), "candidate pairs identified")

    # Note: using a multi-indexer treats the row number as the index, so
    # results will automatically be in acceptable eval format
    end = time.time()
    print("Identifying possible candidate pairs took ", str(round(end - start, 2)), "seconds")
    return candidate_links


"""
Function that transforms a recordlinkage toolkit multi-index into a set of
candidate tuples, and constructs a dictionary mapping the "lower indexed" record 
of each pair to all "higher indexed" records that are linked to it. This structure
allows us to perform efficient scoring later by taking set differences.
"""
def get_pred_match_dict_from_multi_idx(mltidx, n_rows):
    candidate_tuples = mltidx.to_list()
    pred_matches = {k: set() for k in range(n_rows)}
    for pair in candidate_tuples:
        reference_record = min(pair)
        linked_record = max(pair)
        pred_matches[reference_record].add(linked_record)
    return pred_matches

# Convert empty strings in every column (e.g. a missing MRN) into real Nones so
# the indexers drop them; the later spark-block stages need this reversed,
# since their string masks expect strings rather than None
labeling_set = labeling_set.replace({"": None})

candidate_links = find_candidate_links(labeling_set)

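The docstring of `get_pred_match_dict_from_multi_idx` notes that keying each pair by its lower-indexed record makes scoring a matter of set differences. A minimal sketch of that idea, assuming predictions and labels both use this dictionary format; `pairs_to_match_dict` and `confusion_counts` are hypothetical helpers, and the notebook's own evaluation code may differ:

```python
from typing import Dict, Iterable, Set, Tuple

MatchDict = Dict[int, Set[int]]


def pairs_to_match_dict(pairs: Iterable[Tuple[int, int]], n_rows: int) -> MatchDict:
    """Key each pair by its lower index, as get_pred_match_dict_from_multi_idx does."""
    match_dict = {k: set() for k in range(n_rows)}
    for a, b in pairs:
        match_dict[min(a, b)].add(max(a, b))
    return match_dict


def confusion_counts(predicted: MatchDict, labels: MatchDict) -> Tuple[int, int, int]:
    """Count true positives, false positives, and false negatives with set operations."""
    tp = fp = fn = 0
    for k, true_links in labels.items():
        pred_links = predicted.get(k, set())
        tp += len(pred_links & true_links)  # linked and labeled
        fp += len(pred_links - true_links)  # linked but not labeled
        fn += len(true_links - pred_links)  # labeled but missed
    return tp, fp, fn


toy_labels = pairs_to_match_dict([(0, 1), (2, 3)], n_rows=4)
toy_predicted = pairs_to_match_dict([(1, 0), (0, 2)], n_rows=4)
print(confusion_counts(toy_predicted, toy_labels))  # (1, 1, 1)
```

Because the pair `(1, 0)` is normalized to key `0`, pair order does not matter, and precision and recall follow directly from the three counts.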
In [None]:
# FEATURE COMPARATOR
# Generate string similarity scores for all features in all candidate pairs.

import rapidfuzz
from recordlinkage.base import BaseCompareFeature

"""
Returns the normalized similarity measure between string1 and string2, as
determined by the similarlity measure. The higher the normalized similarity measure
(up to 1.0), the more similar string1 and string2 are. A normalized similarity
measure of 0.0 means string1 and string 2 are not at all similar. This function
expects basic text cleaning (e.g. removal of numeric characters, trimming of spaces,
etc.) to already have been performed on the input strings.
"""
def compare_strings(
    string1: str,
    string2: str,
    similarity_measure: Literal[
        "JaroWinkler", "Levenshtein", "DamerauLevenshtein"
    ] = "JaroWinkler",
) -> float:
    if similarity_measure == "JaroWinkler":
        return rapidfuzz.distance.JaroWinkler.normalized_similarity(string1, string2)
    elif similarity_measure == "Levenshtein":
        return rapidfuzz.distance.Levenshtein.normalized_similarity(string1, string2)
    elif similarity_measure == "DamerauLevenshtein":
        return rapidfuzz.distance.DamerauLevenshtein.normalized_similarity(
            string1, string2
        )

"""
A special class for comparing LoL concatenated elements. Use the full concatenation
of all values to account for multiple entries like given names.
"""
class CompareNestedString(BaseCompareFeature):
    def _compute_vectorized(self, s1, s2):
        strrep1 = s1.str.lstrip('[').str.rstrip(']').str.split(',')
        strrep2 = s2.str.lstrip('[').str.rstrip(']').str.split(',')
        return (strrep1.str[0] == strrep2.str[0]).astype(float)


""""
A special class for comparing LoL first name elements. Use the full concatenation
of all names to account for multiple given names.
"""
class CompareFirstName(BaseCompareFeature):
    def _compute_vectorized(self, s1, s2):
        strrep1 = s1.str.lstrip('[').str.rstrip(']').str.split(',')
        strrep2 = s2.str.lstrip('[').str.rstrip(']').str.split(',')
        jarowinklers = np.vectorize(compare_strings)(strrep1.str.join(" "), strrep2.str.join(" "))
        return jarowinklers


"""
A special class for comparing LoL address line elements. Check each address
line against each other address line to account for patients who have changed
residence.
"""
class CompareAddress(BaseCompareFeature):
    def _compute_vectorized(self, s1, s2):
        def comp_address_fields(a1_list, a2_list):
            best_score = 0.0
            for a1 in a1_list:
                for a2 in a2_list:
                    score = compare_strings(a1, a2)
                    if score >= best_score:
                        best_score = score
            return best_score

        strrep1 = s1.str.lstrip('[').str.rstrip(']').str.split(',')
        strrep2 = s2.str.lstrip('[').str.rstrip(']').str.split(',')
        jarowinklers = np.vectorize(comp_address_fields)(strrep1, strrep2)
        return jarowinklers


'''
Produces a dataframe with a multi-index, in which each tuple of row indices
denotes one potential candidate match. The value in each column of the DF
is the fuzzy match similarity score between the two records given by the 
multi-index.
'''
def compute_comparator_matrix(data, candidate_links):
    start = time.time()

    # Apply feature comparisons on each supported field from the MPI
    comp = rl.Compare()
    comp.add(CompareFirstName("first_name", "first_name",label="first_name"))
    comp.string(
        "last_name", "last_name", method="jarowinkler", label="last_name"
    )
    comp.string("mrn", "mrn", method="jarowinkler", label="mrn")
    comp.string(
        "birthdate", "birthdate", method="jarowinkler", label="birthdate"
    )
    comp.add(CompareAddress("address", "address", label="address"))
    comp.string("city", "city", method="jarowinkler", label="city")
    comp.string("zip", "zip", method="jarowinkler", label="zip")
    comp.string("sex", "sex", method="jarowinkler", label="sex")
    features = comp.compute(candidate_links, data)

    end = time.time()
    print("Computation took", str(round(end - start, 2)), "seconds")
    return features

features = compute_comparator_matrix(labeling_set, candidate_links)

### Labeling Method 1: Virginia Labels
The VA Labels come from DIBBs' early work with a very simple record linkage system. Under the VA Labeling Scheme, two records match if and only if they exactly agree on first name, last name, date of birth, and address (for the list-valued first name and address fields, `CompareNestedString` compares only the first entry). The VA Labels aren't intended to reflect "real" match criteria, but rather to illustrate performance on data "in the best case." Record pairs linked by VA labels are the very easy cases that any other record linkage algorithm should readily pick up.
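Given a feature frame with one 0/1 exact-agreement column per field (the shape `get_va_labels` builds before summing), the VA rule reduces to requiring every column to equal 1. A minimal sketch on made-up toy data; `va_match_pairs` is a hypothetical name:

```python
import pandas as pd

VA_FIELDS = ["first_name", "last_name", "birthdate", "address"]


def va_match_pairs(features: pd.DataFrame) -> list:
    """Return candidate pairs whose exact-agreement score is 1.0 on every VA field."""
    agrees_everywhere = (features[VA_FIELDS] == 1.0).all(axis=1)
    return list(features.index[agrees_everywhere.to_numpy()])


# Toy feature frame: one row per candidate pair, 1.0 = exact agreement
toy_index = pd.MultiIndex.from_tuples([(0, 1), (0, 2), (1, 3)])
toy_features = pd.DataFrame(
    {
        "first_name": [1.0, 1.0, 0.0],
        "last_name": [1.0, 1.0, 1.0],
        "birthdate": [1.0, 0.0, 1.0],
        "address": [1.0, 1.0, 1.0],
    },
    index=toy_index,
)
print(va_match_pairs(toy_features))  # only the (0, 1) pair agrees on all four fields
```

Checking `all(axis=1)` is equivalent to the notebook's `sum(axis=1) == 4` when every score is 0 or 1, but it does not depend on the number of fields.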

In [None]:
def get_va_labels(data, candidate_links):
    start = time.time()

    # Apply feature comparisons on each supported field from the MPI
    comp = rl.Compare()
    comp.add(CompareNestedString("first_name", "first_name",label="first_name"))
    comp.exact("last_name", "last_name", label="last_name")
    comp.exact("birthdate", "birthdate", label="birthdate")
    comp.add(CompareNestedString("address", "address", label="address"))
    features = comp.compute(candidate_links, data)
    matches = features[features.sum(axis=1) == 4]

    end = time.time()
    print("Comparing candidates took", str(round(end - start, 2)), "seconds")

    matches = get_pred_match_dict_from_multi_idx(matches.index, len(data))
    return matches


va_labels = get_va_labels(labeling_set, candidate_links)

### Labeling Method 2: UK NHS Labels
The United Kingdom's National Health Service (NHS) uses a straightforward deterministic matching procedure to classify patients as matches. Its baseline deterministic model performs linkage in three steps:

1) exact match on DOB, exact match on sex, exact match on NHS number (which for our purposes is equivalent to MRN)
2) of the unlinked records, new matches are linked using partial/fuzzy match on DOB, exact match on sex, exact match on postal code, and a partial match on MRN
3) of the remaining unlinked records, new matches are found by exact match on DOB, exact match on sex, and exact match on postal code

This system's use of a variety of fields in different combinations gives good insight into the "average case" performance of a linkage algorithm. Since fields can be missing or inexact in one pass and still match in another with a new combination, this labeling scheme gives a good representation of a linkage algorithm's performance on data if we otherwise knew nothing about the nature and field quality of the data.
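Applied pair by pair, the three steps amount to checking the conditions in order and stopping at the first one that holds, so the set of linked pairs is the union of the three conditions, which is how `get_uk_nhs_labels` combines them. A minimal sketch over one pair's similarity scores; `nhs_pass` is a hypothetical helper, and its default thresholds are placeholders rather than the notebook's `BIRTHDAY_THRESHOLD` and `JARO_THRESHOLD` settings:

```python
from typing import Dict


def nhs_pass(
    scores: Dict[str, float],
    birthday_threshold: float = 0.95,  # placeholder, not the notebook's setting
    jaro_threshold: float = 0.85,  # placeholder, not the notebook's setting
) -> int:
    """Return the first NHS pass (1-3) whose condition a candidate pair meets, else 0."""
    dob, sex = scores["birthdate"], scores["sex"]
    mrn, postal = scores["mrn"], scores["zip"]
    # Pass 1: exact DOB, sex, and MRN
    if dob == 1.0 and sex == 1.0 and mrn == 1.0:
        return 1
    # Pass 2: fuzzy DOB, exact sex, exact postal code, fuzzy MRN
    if dob >= birthday_threshold and sex == 1.0 and postal == 1.0 and mrn >= jaro_threshold:
        return 2
    # Pass 3: exact DOB, sex, and postal code
    if dob == 1.0 and sex == 1.0 and postal == 1.0:
        return 3
    return 0


print(nhs_pass({"birthdate": 1.0, "sex": 1.0, "mrn": 0.2, "zip": 1.0}))  # 3
```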

In [None]:
'''
Helper function that combines two match dictionaries, each already keyed by the
lower-indexed record of every pair (the format produced by
get_pred_match_dict_from_multi_idx), by taking the per-key union of their sets.
'''
def combine_match_dicts(m1, m2):
    m3 = {}
    for k in m1:
        union_set = set()
        union_set = union_set.union(m1[k])
        union_set = union_set.union(m2[k])
        m3[k] = union_set
    return m3


'''
Generate the UK's National Health Service labels from three match conditions.
The conditions use different combinations of exact and fuzzy field agreement,
and a pair is labeled a match if it satisfies any of them.
'''
def get_uk_nhs_labels(data, features):
    matches_type_1 = features.loc[
        (features['birthdate'] == 1.0) &
        (features['sex'] == 1.0) &
        (features['mrn'] == 1.0)
    ]

    matches_type_2 = features.loc[
        (features['birthdate'] >= BIRTHDAY_THRESHOLD) & 
        (features['sex'] == 1.0) &
        (features['zip'] == 1.0) &
        (features['mrn'] >= JARO_THRESHOLD)
    ]

    matches_type_3 = features.loc[
        (features['birthdate'] == 1.0) &
        (features['sex'] == 1.0) &
        (features['zip'] == 1.0)
    ]

    m1_dict = get_pred_match_dict_from_multi_idx(matches_type_1.index, len(data))
    m2_dict = get_pred_match_dict_from_multi_idx(matches_type_2.index, len(data))
    m3_dict = get_pred_match_dict_from_multi_idx(matches_type_3.index, len(data))
    pred_matches = combine_match_dicts(m1_dict, m2_dict)
    pred_matches = combine_match_dicts(pred_matches, m3_dict)
    
    return pred_matches

uk_labels = get_uk_nhs_labels(labeling_set, features)

### Labeling Method 3: NCI SEER Labels
The National Cancer Institute's Surveillance, Epidemiology, and End Results program provides the third and final set of labeling criteria our notebook uses for ground-truth classification. This system employs the following rules:

1) two records should be linked if they are an exact match on SSN (MRN for our purposes), as well as fuzzy match on at least 2 of first name, last name, and birthdate
2) of the remaining unlinked records, a pair should be considered a match if it is a fuzzy match on first name, last name, and sex, and is additionally a fuzzy match on either MRN or birthdate.

The SEER labels are what we consider our "best" or most accurate/representative labels. They use fields that suit production data well, and they keep a good balance of sensitivity and specificity. Performance on these labels is our benchmark of which algorithm is functioning best.
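The two SEER rules can be sketched over one pair's similarity scores. This follows the prose above; the second condition in `get_seer_labels` is stricter, requiring either an exact birthdate or fuzzy agreement on both MRN and birthdate, so the example pair printed below (fuzzy MRN, non-matching birthdate) passes here but not in the notebook's labeler. `seer_match` is a hypothetical helper, and its default thresholds are placeholders:

```python
from typing import Dict


def seer_match(
    scores: Dict[str, float],
    jaro_threshold: float = 0.85,  # placeholder, not the notebook's JARO_THRESHOLD
    birthday_threshold: float = 0.95,  # placeholder, not the notebook's BIRTHDAY_THRESHOLD
) -> bool:
    """Apply the two SEER rules, as described in prose, to one pair's similarity scores."""
    fuzzy_first = scores["first_name"] >= jaro_threshold
    fuzzy_last = scores["last_name"] >= jaro_threshold
    fuzzy_dob = scores["birthdate"] >= birthday_threshold

    # Rule 1: exact MRN plus fuzzy agreement on at least two of the three fields
    if scores["mrn"] == 1.0 and sum([fuzzy_first, fuzzy_last, fuzzy_dob]) >= 2:
        return True

    # Rule 2: fuzzy first name, last name, and sex, plus fuzzy MRN or fuzzy birthdate
    fuzzy_sex = scores["sex"] >= jaro_threshold
    fuzzy_mrn = scores["mrn"] >= jaro_threshold
    return fuzzy_first and fuzzy_last and fuzzy_sex and (fuzzy_mrn or fuzzy_dob)


print(seer_match({"mrn": 0.9, "first_name": 0.9, "last_name": 0.9, "birthdate": 0.5, "sex": 1.0}))  # True
```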

In [None]:
'''
Generate the NCI's SEER Labels using two types of matches, based on whether or not
the candidate pair has a perfectly agreeing MRN.
'''
def get_seer_labels(data, features):
    mrn_matches = features.loc[features['mrn'] == 1.0]
    matches_type_1 = mrn_matches.loc[
        ((mrn_matches['first_name'] >= JARO_THRESHOLD) & (mrn_matches['last_name'] >= JARO_THRESHOLD)) |
        ((mrn_matches['first_name'] >= JARO_THRESHOLD) & (mrn_matches['birthdate'] >= BIRTHDAY_THRESHOLD)) |
        ((mrn_matches['birthdate'] >= BIRTHDAY_THRESHOLD) & (mrn_matches['last_name'] >= JARO_THRESHOLD))
    ]

    matches_type_2 = features.loc[
        (features['first_name'] >= JARO_THRESHOLD) &
        (features['last_name'] >= JARO_THRESHOLD) & 
        (features['sex'] >= JARO_THRESHOLD) &
        (
            ((features['mrn'] >= JARO_THRESHOLD) & (features['birthdate'] >= BIRTHDAY_THRESHOLD)) | 
            (features['birthdate'] == 1.0)
        )
    ]

    m1_dict = get_pred_match_dict_from_multi_idx(matches_type_1.index, len(data))
    m2_dict = get_pred_match_dict_from_multi_idx(matches_type_2.index, len(data))
    pred_matches = combine_match_dicts(m1_dict, m2_dict)

    return pred_matches

seer_labels = get_seer_labels(labeling_set, features)

## Parallel Linkage Code
The following massive cell contains all of the code that makes our record linkage analysis work. While we follow the same general format and structure as the versions of these functions in the SDK, this notebook needs the parallel processing power Spark offers to make large-scale analysis computationally tractable. The functions below are heavily commented, but in general, an incoming record from the Evaluation Set undergoes the following process (records are processed in parallel to avoid looping over the large testing list):

* fetch candidate records in a block from the Spark-extracted version of the MPI
* for each candidate record in the block, use the provided comparator functions to determine if the candidate is a match to the incoming record
* map all matched candidates from the block back to their indexed ID in the original Labeling Set
* deduplicate any match pairs that were found in multiple passes of the algorithm.

In [None]:
# LINKAGE DRIVER FUNCTIONS

import re

"""
Extracts values from a given patient record for eventual use in database
record linkage blocking. A list of fields to block on, along with any desired
transformation of each field's extracted value, is used to parse those values
out of the incoming patient record via their FHIRPaths.
"""
def extract_blocking_values_from_record(
    record: dict, blocking_fields: List[dict]
) -> dict:
    transform_funcs = {
        "first4": lambda x: x[:4] if len(x) >= 4 else x,
        "last4": lambda x: x[-4:] if len(x) >= 4 else x,
    }

    block_vals = dict.fromkeys([b.get("value") for b in blocking_fields], "")
    transform_blocks = [b for b in blocking_fields if "transformation" in b]
    transformations = dict(
        zip(
            [b.get("value") for b in transform_blocks],
            [b.get("transformation") for b in transform_blocks],
        )
    )
    for block_dict in blocking_fields:
        block = block_dict.get("value")

        # Apply utility extractor for safe parsing
        value = extract_value_with_resource_path(
            record,
            LINKING_FIELDS_TO_FHIRPATHS[block],
            selection_criteria="first",
        )
        if value:
            if block in transformations:
                value = transform_funcs[transformations[block]](value)
                block_vals[block] = {
                    "value": value,
                    "transformation": transformations[block],
                }
            else:
                block_vals[block] = {"value": value}

    # Account for any incoming FHIR resources that return no data
    # for a field--don't count this against records to-block
    keys_to_pop = []
    for field in block_vals:
        if _is_empty_extraction_field(block_vals, field):
            keys_to_pop.append(field)
    for k in keys_to_pop:
        block_vals.pop(k)

    return block_vals


"""
Helper method that determines when a field extracted from an incoming
record should be considered "empty" for the purpose of blocking.
Fields whose values are either `None` or the empty string should not
be used when retrieving blocked records from the MPI, since that
would impose an artificial constraint (e.g. if an incoming record
has no `last_name` field, we don't want to retrieve only records
from the MPI that also have no `last_name`).
"""
def _is_empty_extraction_field(block_vals: dict, field: str):
    # Means the value extractor found no data in the FHIR resource
    if block_vals[field] == "":
        return True
    # Alternatively, there was "data" there, but it's empty
    elif (
        block_vals[field].get("value") is None
        or block_vals[field].get("value") == ""
        or block_vals[field].get("value") == [""]
    ):
        return True  # pragma: no cover
    return False


"""
Function that uses a pandas DataFrame construct of an extracted MPI to efficiently
filter down candidates into appropriate blocks. While the filtering itself is not
parallelized, since it occurs on the worker nodes, each executor is performing
linkage for one or more test records simultaneously. As a result, a pandas DF
provides an appropriate level of speed to use .loc retrieval.
"""
def spark_block(block_vals: dict, labeling_set: pd.DataFrame):

    # We'll sequentially apply each blocking filter, since that's equivalent to finding
    # their intersection all at once
    result = labeling_set
    for blocking_criterion in block_vals:
        props = block_vals[blocking_criterion]
        if props["value"] is None or props["value"] == "":
            continue

        # Special case if we're blocking on first_name or address: after the all-string
        # schema cast, these lists are stored as bracketed strings (e.g. "[John, Paul]"),
        # so we need to account for the brackets '[' and ']'
        if blocking_criterion == "first_name" or blocking_criterion == "address":
            if "transformation" in props:
                if props["transformation"] == "first4":
                    result = result.loc[result[blocking_criterion].str.startswith("[" + props["value"])]
                elif props["transformation"] == "last4":
                    result = result.loc[result[blocking_criterion].str.endswith(props["value"] + "]")]
            else:
                result = result.loc[result[blocking_criterion] == "[" + props["value"] + "]"]

        # Regular case is just a straight string comparison since we've already stripped the 
        # de-serialization quotes
        else:
            if "transformation" in props:
                if props["transformation"] == "first4":
                    result = result.loc[result[blocking_criterion].str.startswith(props["value"])]
                elif props["transformation"] == "last4":
                    result = result.loc[result[blocking_criterion].str.endswith(props["value"])]
            else:
                result = result.loc[result[blocking_criterion] == props["value"]]
    return result


"""
Function that compares a single blocked candidate from the MPI with the
incoming, now flattened, record. Comparison functions for evaluating the linkage
match are applied iteratively, and a net score is accumulated giving the
total strength of the linkage match. This function is applied sequentially to
each of the candidate records returned in the block.
"""
def spark_compare_df_helper(row, flattened_record, funcs, col_to_idx, matching_rule, **kwargs):

    # Iteratively accumulate results of each feature-wise comparison
    match_score = 0.0
    for col in funcs:
        func = funcs[col]
        feature_idx_in_record = col_to_idx[col]
        feature_in_record = flattened_record[feature_idx_in_record]

        if "fuzzy" in func:
            similarity_measure, fuzzy_threshold = _get_fuzzy_comp_params(col, **kwargs)

            # Given name is a list (possibly including middle name), so our logic says
            # concatenate all the values together and then fuzzy compare
            if col == "first_name":
                feature_in_record = " ".join(feature_in_record)
                feature_in_mpi = re.sub(r'\[|\]', "", row[col])
                feature_in_mpi = feature_in_mpi.split(", ")
                feature_in_mpi = " ".join(feature_in_mpi)
                feature_score = compare_strings(feature_in_mpi, feature_in_record, similarity_measure)
                match_score = _apply_score_contribution(
                    feature_score, col, fuzzy_threshold, match_score, matching_rule, **kwargs
                )

            # Address is also a list, but rather than concatenate them all, we check if each
            # line of an incoming address matches any line of an MPI address; this accounts for
            # a patient's change of residence history
            elif col == "address":
                feature_in_mpi = re.sub(r'\[|\]', "", row[col])
                feature_in_mpi = feature_in_mpi.split(", ")
                best_score = 0.0
                for r in feature_in_record:
                    for m in feature_in_mpi:
                        feature_comp = compare_strings(r, m, similarity_measure)
                        if feature_comp > best_score:
                            best_score = feature_comp
                match_score = _apply_score_contribution(
                    best_score, col, fuzzy_threshold, match_score, matching_rule, **kwargs
                )
            
            # Regular case: straight string comparison on the fields
            else:
                feature_in_mpi = row[col]
                feature_score = compare_strings(feature_in_mpi, feature_in_record, similarity_measure)
                match_score = _apply_score_contribution(
                    feature_score, col, fuzzy_threshold, match_score, matching_rule, **kwargs
                )
        else:
            feature_in_mpi = row[col]
            feature_score = compare_strings(feature_in_mpi, feature_in_record, similarity_measure)
            match_score = _apply_score_contribution(
                feature_score, col, 1.0, match_score, matching_rule, **kwargs
            )

    return pd.Series([row['patient_id'], match_score])


"""
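Quick helper to extract the similarity metric and threshold used in fuzzy string
comparisons, which keeps this lookup out of the main analytic function. Defaults are
Jaro-Winkler and 0.7; a per-column "thresholds" dict takes precedence over a single
universal "threshold" (columns missing from that dict fall back to 0.7).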
"""
def _get_fuzzy_comp_params(col, **kwargs):
    similarity_measure = "JaroWinkler"
    if "similarity_measure" in kwargs:
        similarity_measure = kwargs["similarity_measure"]
    threshold = 0.7
    
    # Optional unique threshold per column in the data
    if "thresholds" in kwargs:
        if col in kwargs["thresholds"]:
            threshold = kwargs["thresholds"][col]
    
    # Single universal threshold for all fields
    elif "threshold" in kwargs:
        threshold = kwargs["threshold"]
        
    return similarity_measure, threshold


"""
Helper to apply the result of a feature-wise comparison between an incoming record and a 
candidate from the MPI to the accumulated 'match score' of the two. In a 'normal' case
where we're not using log-odds, this is just a count of the number of feature comparisons
that satisfy the fuzzy string threshold. In the log-odds case, this is an accumulation of
the weighted probability score that the two records are a match.
"""
def _apply_score_contribution(feature_score, col, fuzzy_threshold, match_score, match_rule, **kwargs):
    if "log" in match_rule:
        col_odds = kwargs["log_odds"][col]
        match_score += (feature_score * col_odds)
    else:
        if feature_score >= fuzzy_threshold:
            match_score += 1.0
    return match_score


"""
Orchestrator function for evaluating one block of candidates. Each candidate row of the
(pandas) block is scored against the incoming record, and only those candidates whose
score satisfies the provided matching rule are kept (be that "all feature-wise
comparisons are true" or "total log-odds score meets the true-match cutoff").
"""
def spark_compare(data_block: pd.DataFrame, record: List, funcs: dict, col_to_idx: dict, matching_rule, **kwargs):
    res = data_block.apply(lambda x: spark_compare_df_helper(x, record, funcs, col_to_idx, matching_rule, **kwargs), axis=1)
    if "log" in matching_rule:
        match_cutoff = kwargs["true_match_threshold"]
    else:
        match_cutoff = len(funcs)
    match_list = res.loc[res[1] >= match_cutoff]
    match_list = list(match_list[0])
    return match_list


'''
Main driver function that's applied in parallel to each record of the incoming
evaluation set. The procedure is much the same as if the record were being
processed in real time, except that a pandas dataframe (rather than a networked
DB) is used to retrieve the candidate block for speed purposes.
'''
def parallel_eval(record, algo_config: List[dict], labeling_set: pd.DataFrame, testing_field=None, testing_vals=None):

    # Flatten incoming resource and replace any lurking None values with [""] placeholders
    flattened_record = flatten_patient_resource(record, FIELD_COLS_TO_IDX)
    if flattened_record[2] is None:
        flattened_record[2] = [""]
    if flattened_record[5] is None:
        flattened_record[5] = [""]

    if testing_field:
        matches = {str(x): [] for x in testing_vals}
    else:
        matches = []

    for linkage_pass in algo_config:
        blocking_fields = linkage_pass["blocks"]
        field_blocks = extract_blocking_values_from_record(record, blocking_fields)
        if len(field_blocks) == 0:
            continue
        
        # Use the extract of the MPI to quickly filter down a block of candidates
        data_block = spark_block(field_blocks, labeling_set)
        col_to_idx = {v: k for k, v in enumerate(formatted_cols)}

        # Parallel process the candidates to find any matches
        kwargs = linkage_pass.get("kwargs", {})

        if testing_field:
            for tv in testing_vals:
                vkwargs = copy.deepcopy(kwargs)
                vkwargs["thresholds"][testing_field] = tv
                matching_records = spark_compare(
                    data_block, flattened_record, linkage_pass["funcs"], col_to_idx, linkage_pass["matching_rule"], **vkwargs
                )
                matches[str(tv)] += matching_records
        else:
            matching_records = spark_compare(
                data_block, flattened_record, linkage_pass["funcs"], col_to_idx, linkage_pass["matching_rule"], **kwargs
            )
            matches += matching_records

    if testing_field:
        return flattened_record[0], [matches[str(tv)] for tv in testing_vals]
    else:
        return flattened_record[0], matches


'''
Turn the patient_ids of identified "found matches" into the row indices of the
labeling data, which is the indexing scheme the ground-truth labels use. This way,
all indices are expressed in the same scheme for statistical comparison.
'''
def map_patient_ids_to_idxs(pids: List, data: pd.DataFrame):
    record_idxs = []
    for pid in pids:
        row_idx = data[data['patient_id'] == pid].index.values
        if len(row_idx) > 0:
            record_idxs.append(row_idx[0])
    return record_idxs


'''
Find existing patient records in a dataset that link to each incoming record in an
RDD of FHIR data. Since the FHIR data itself is pulled from the MPI, we can freely use
the MPI extract to query for blocks without risk of encountering unrecognized records.
'''
def link_all_fhir_records_block_dataset(records, algo_config: List[dict], label_set: pd.DataFrame, testing_field=None, testing_vals=None):
    if testing_field:
        if not testing_vals or len(testing_vals) == 0:
            print("Must supply list of threshold values to test")
            return
        found_matches = { str(x): {} for x in testing_vals }
    else:
        found_matches = {}
    start = time.time()
    res = records.map(lambda x: parallel_eval(x, algo_config, label_set, testing_field, testing_vals))
    res.cache()

    if testing_field:
        for x in res.collect():
            ridx = map_patient_ids_to_idxs([x[0]], label_set)[0]
            for tv in range(len(x[1])):
                linked_rs_at_threshold = x[1][tv]
                lidx = set(linked_rs_at_threshold)
                lidx = map_patient_ids_to_idxs(lidx, label_set)
                found_matches[str(testing_vals[tv])][ridx] = set(lidx)
            
        print("finished linking ", str(time.time() - start))
        return found_matches

    else:
        for x in res.collect():
            ridx = map_patient_ids_to_idxs([x[0]], label_set)[0]
            lidx = set(x[1])
            lidx = map_patient_ids_to_idxs(lidx, label_set)
            found_matches[ridx] = set(lidx)

        print("finished linking ", str(time.time() - start))
        return found_matches

'''
Because every record is linked independently, each match is found from both sides,
i.e. we record the link (i,j) and the link (j,i), and each record also links to itself;
both would skew our stats. This function keeps each link only under its lower index
(dropping self-links) so that each link is counted once.
'''
def dedupe_match_double_counts(match_dict, is_fuzzy_test=False):
    # In fuzzy-test mode, match_dict maps each tested threshold to its own match dict
    sub_dicts = match_dict.values() if is_fuzzy_test else [match_dict]
    for sub_dict in sub_dicts:
        for k in sub_dict:
            # Keep only links to higher row indices; this drops (j, k) duplicates
            # of (k, j) links as well as the record's link to itself
            sub_dict[k] = {j for j in sub_dict[k] if j > k}
    return match_dict


# Change the real None-type values back into their placeholders so we can mass-boolean filter
labeling_set = labeling_set.replace({None: ""})
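The two scoring modes in `_apply_score_contribution`, and the cutoff that `spark_compare` applies afterwards, can be seen side by side in the minimal sketch below. It is a standalone simplification rather than the notebook's own code: it assumes the per-field similarity scores have already been computed (by `compare_strings` in the notebook), and the helper names `accumulate_score` and `is_match`, along with the scores, thresholds, and weights, are hypothetical.

```python
from typing import Dict, Optional


# Hypothetical standalone helpers that mirror (but are not) the notebook's functions
def accumulate_score(
    feature_scores: Dict[str, float],
    thresholds: Dict[str, float],
    log_odds: Optional[Dict[str, float]] = None,
) -> float:
    """Count fields that clear their threshold, or sum similarity * weight in log-odds mode."""
    total = 0.0
    for col, sim in feature_scores.items():
        if log_odds is not None:
            # Log-odds mode: every field contributes similarity * weight, no threshold
            total += sim * log_odds[col]
        elif sim >= thresholds[col]:
            # Counting mode: each field clearing its fuzzy threshold adds 1
            total += 1.0
    return total


def is_match(
    feature_scores: Dict[str, float],
    thresholds: Dict[str, float],
    log_odds: Optional[Dict[str, float]] = None,
    true_match_threshold: Optional[float] = None,
) -> bool:
    score = accumulate_score(feature_scores, thresholds, log_odds)
    if log_odds is not None:
        if true_match_threshold is None:
            raise ValueError("log-odds mode needs a true_match_threshold")
        return score >= true_match_threshold
    # Counting mode ("all feature-wise comparisons are true") needs every field to pass
    return score >= len(feature_scores)


# Made-up similarity scores, fuzzy thresholds, and log-odds weights for one candidate pair
scores = {"first_name": 1.0, "last_name": 0.875, "address": 0.75}
thresholds = {"first_name": 0.85, "last_name": 0.85, "address": 0.85}
weights = {"first_name": 6.0, "last_name": 8.0, "address": 8.0}

print(accumulate_score(scores, thresholds))           # 2.0 (address misses its threshold)
print(is_match(scores, thresholds))                   # False
print(accumulate_score(scores, thresholds, weights))  # 19.0
print(is_match(scores, thresholds, weights, 12.0))    # True
```

Here the address similarity of 0.75 fails the 0.85 fuzzy threshold, so the counting rule rejects the pair, while under log-odds the strong name agreement still carries the total past the illustrative cutoff of 12.0.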

## `test_thresholds` Mode
The following cell holds the code necessary to test fuzzy matching thresholds for a single patient field (`TESTING_FIELD`) across a list of candidate values (`TESTING_VALS`). When run in this mode, the performance statistics for each tested value (true positives, false positives, false negatives, sensitivity, PPV, F1, and F0.5) are displayed so that a user can choose the threshold that best meets their needs.
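Alongside F1, `score_fuzzy_test` reports F0.5, which weights PPV (precision) more heavily than sensitivity. The sketch below restates the same F-beta formula, (1 + beta^2) * TP / ((1 + beta^2) * TP + beta^2 * FN + FP), as a hypothetical standalone helper `f_beta`, and applies it to made-up counts for two tested thresholds to show how the two metrics can disagree.

```python
def f_beta(tp: float, fp: float, fn: float, beta: float) -> float:
    """F-beta score from pair counts; beta < 1 weights PPV (precision) above sensitivity."""
    b2 = beta ** 2
    num = (1.0 + b2) * tp
    denom = num + b2 * fn + fp
    # Guard added here (not in score_fuzzy_test): return 0.0 when all counts are zero
    return num / denom if denom > 0 else 0.0


# Hypothetical counts for two tested thresholds on one field
results = {
    "0.80": {"tp": 90.0, "fp": 20.0, "fn": 10.0},  # looser: more links, more false ones
    "0.90": {"tp": 75.0, "fp": 5.0, "fn": 25.0},   # stricter: fewer, cleaner links
}
for tv, c in results.items():
    f1 = f_beta(c["tp"], c["fp"], c["fn"], beta=1.0)
    f_half = f_beta(c["tp"], c["fp"], c["fn"], beta=0.5)
    print(tv, "F1 =", round(f1, 3), "F0.5 =", round(f_half, 3))
```

With these counts F1 prefers 0.80 (0.857 vs. 0.833) while F0.5 prefers 0.90 (0.893 vs. 0.833), which is the kind of trade-off the extra statistic is meant to surface.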

In [None]:
"""
Helper function that scores a fuzzy matching threshold test on a subset of relevant
performance statistics.
"""
def score_fuzzy_test(found_matches, true_matches, records_in_dataset, testing_vals):
    scores = {}
    for tv in testing_vals:
        true_positives = 0.0
        false_positives = 0.0
        false_negatives = 0.0
        matches_at_threshold = found_matches[str(tv)]

        for root_record in true_matches:
            if root_record in matches_at_threshold:
                true_positives += len(
                    true_matches[root_record].intersection(matches_at_threshold[root_record])
                )
                false_positives += len(
                    matches_at_threshold[root_record].difference(true_matches[root_record])
                )
                false_negatives += len(
                    true_matches[root_record].difference(matches_at_threshold[root_record])
                )
            else:
                false_negatives += len(true_matches[root_record])
        for record in set(set(matches_at_threshold.keys()).difference(true_matches.keys())):
            false_positives += len(matches_at_threshold[record])

        sensitivity = round(true_positives / (true_positives + false_negatives), 3)
        ppv = round(true_positives / (true_positives + false_positives), 3)
        f1 = round(
            (2 * true_positives) / (2 * true_positives + false_negatives + false_positives),
            3,
        )
        f_half_num = (1.0 + 0.5**2) * true_positives
        f_half_denom_new = (0.5**2) * false_negatives + false_positives
        f_half = round(f_half_num / (f_half_num + f_half_denom_new), 3)
        scores[str(tv)] = {
            "tp": true_positives,
            "fp": false_positives,
            "fn": false_negatives,
            "sens": sensitivity,
            "ppv": ppv,
            "f1": f1,
            "f_half": f_half
        }
    
    return scores


if NOTEBOOK_MODE == "test_thresholds":
    new_algo = copy.deepcopy(DIBBS_BASIC)
    col_thresholds = {
        "address": 0.85,
        "birthdate": 0.85,
        "city": 0.85,
        "first_name": 0.85,
        "last_name": 0.85,
        "mrn": 0.85,
        "sex": 0.85,
        "state": 0.85,
        "zip": 0.85
    }
    new_algo[0]["kwargs"] = { "thresholds": col_thresholds }
    new_algo[1]["kwargs"] = { "thresholds": col_thresholds }

    found_matches_dibbs_basic = link_all_fhir_records_block_dataset(evaluation_set, new_algo, labeling_set, TESTING_FIELD, TESTING_VALS)
    found_matches_dibbs_basic = dedupe_match_double_counts(found_matches_dibbs_basic, True)
    eval_scores = score_fuzzy_test(found_matches_dibbs_basic, seer_labels, EVALUATION_SIZE, TESTING_VALS)
    for t in eval_scores:
        print(t, eval_scores[t])
        print()
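After the scores are printed, picking a threshold amounts to ranking the tested values by whichever statistic matters most for the use case. The hypothetical helper `best_threshold` below sketches that step, assuming each entry is keyed by the threshold string and carries the statistics produced by `score_fuzzy_test`; the values in `example_scores` are made up (and trimmed to four statistics) for illustration.

```python
from typing import Dict, Tuple


def best_threshold(scores_by_threshold: Dict[str, Dict[str, float]], metric: str = "f1") -> Tuple[str, float]:
    """Return the tested threshold (string key) with the highest value of `metric`."""
    best_key = max(scores_by_threshold, key=lambda tv: scores_by_threshold[tv][metric])
    return best_key, scores_by_threshold[best_key][metric]


# Made-up statistics in the shape score_fuzzy_test produces (trimmed to four keys)
example_scores = {
    "0.75": {"sens": 0.95, "ppv": 0.70, "f1": 0.806, "f_half": 0.739},
    "0.85": {"sens": 0.90, "ppv": 0.85, "f1": 0.874, "f_half": 0.860},
    "0.95": {"sens": 0.70, "ppv": 0.98, "f1": 0.817, "f_half": 0.907},
}
print(best_threshold(example_scores, "f1"))      # ('0.85', 0.874)
print(best_threshold(example_scores, "f_half"))  # ('0.95', 0.907)
```

Maximizing F0.5 favors the stricter 0.95 threshold in this example because it rewards PPV more than sensitivity; which statistic to optimize remains a judgment call for the use case.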

## `train_weights` Mode
The cells below contain code necessary to run the notebook in log-odds training mode, which allows a user to recompute the population weights for each field of patient data for later use with the DIBBs Enhanced algorithm. Weights training is made up of several steps:

* recalculate the m- and u-probabilities, which measure the probability that two records have the same value in a field given that they do (m) or do not (u) match
* recompute the log-odds, which is the logarithm of the ratio of the m-probability to the u-probability
* compute the distribution of total match scores for candidate pairs which do match and for a random sample of candidate pairs which don't match
* visualize these distributions to identify the cutoff score separating matches from non-matches
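To make the first two steps concrete, here is a toy calculation on a hypothetical four-record table with a single true-match pair. It uses the same add-one initialization as `calculate_m_probs` and `calculate_u_probs` (counts start at 1 and the denominator at the number of pairs plus 1) and the natural log, as in `calculate_log_odds`. For determinism it enumerates every non-matching pair instead of sampling them at random, which is only practical for tiny data.

```python
from itertools import combinations
from math import log

import pandas as pd

# Hypothetical toy data: rows 0 and 1 are the same person
toy_data = pd.DataFrame({
    "last_name": ["SMITH", "SMITH", "SMITH", "JONES"],
    "zip": ["90001", "90001", "90210", "90210"],
})
toy_true_matches = {0: {1}}


def agreement_probs(pairs, cols):
    """Share of pairs agreeing on each column, with add-one smoothing as in the notebook."""
    counts = {c: 1.0 for c in cols}
    for i, j in pairs:
        for c in cols:
            if toy_data[c].iloc[i] == toy_data[c].iloc[j]:
                counts[c] += 1.0
    return {c: counts[c] / (len(pairs) + 1.0) for c in cols}


cols = list(toy_data.columns)
match_pairs = [(i, j) for i, linked in toy_true_matches.items() for j in linked]
# Enumerate every other pair instead of sampling (fine for a toy table)
non_match_pairs = [p for p in combinations(range(len(toy_data)), 2) if p not in match_pairs]

toy_m = agreement_probs(match_pairs, cols)      # both fields: 2 / 2 = 1.0
toy_u = agreement_probs(non_match_pairs, cols)  # last_name: 3 / 6, zip: 2 / 6
toy_log_odds = {c: log(toy_m[c]) - log(toy_u[c]) for c in cols}
print({c: round(w, 3) for c, w in toy_log_odds.items()})  # {'last_name': 0.693, 'zip': 1.099}
```

Because the zip code agrees less often among non-matching pairs than the common last name does, it earns the larger weight (ln 3 versus ln 2).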

In [None]:
# RECOMPUTE AND EXPORT LOG-ODDS
# Estimates the m and u probabilities, computes their log-odds ratio,
# then saves the output to a file for later loading.

from random import randint
from math import log

"""
For a given set of patient records, calculate the per-field
m-probability. The m-probability for field X is defined as the
probability that a pair of records A and B have the same value in
X, given that A and B are a true matching pair. This function
incorporates a form of add-one (Laplace) smoothing, starting every
count and the pair total at 1, to account for unseen data and to
avoid taking the logarithm of 0 later.
"""
def calculate_m_probs(
    data: pd.DataFrame,
    true_matches: dict,
    cols: Union[List[str], None] = None,
):
    if cols is None:
        cols = data.columns
    m_probs = {c: 1.0 for c in cols}
    total_pairs = 1.0
    for root_record, paired_records in true_matches.items():
        total_pairs += len(paired_records)
        for pr in paired_records:
            for c in cols:
                if data[c].iloc[root_record] == data[c].iloc[pr]:
                    m_probs[c] += 1
    for c in cols:
        m_probs[c] /= total_pairs
    return m_probs

"""
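Function to estimate the per-field u-probabilities of a set of data: the probability
that a pair of records A and B have the same value in field X, given that A and B are
not a true matching pair. Negative pairs are drawn by random sampling rather than
enumerated, with the same add-one smoothing as the m-probabilities.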
"""
def calculate_u_probs(
    data: pd.DataFrame,
    true_matches: dict,
    n_samples: int,
):

    # Quick heuristic check to make sure we can generate enough
    # negative samples to satisfy the parameter request
    max_combos = (len(data.index) * (len(data.index) - 1)) / 2.0
    # Based on bernoulli limits for deterministic runtimes, don't worry about the ln(2)
    # This is how many neg pairs you can expect to generate in "reasonable" time
    runtime_sample_neg_ceiling = np.log(2) * 0.10 * max_combos
    if n_samples >= runtime_sample_neg_ceiling:
        print("Too many samples requested for data size. Lower n_samples parameter.")
        return

    u_probs = {c: 1.0 for c in data.columns}
    neg_pairs = set()

    # Use random number generation to sample from all possible non-match pairs
    # without explicitly constructing the list; self-pairs are skipped because a
    # record compared with itself agrees on every field and would inflate u
    while len(neg_pairs) < n_samples:
        idx1 = randint(0, len(data.index)-1)
        idx2 = randint(0, len(data.index)-1)
        if idx1 == idx2:
            continue
        root = min(idx1, idx2)
        ref = max(idx1, idx2)
        if root not in true_matches or ref not in true_matches[root]:
            neg_pairs.add((root, ref))

    neg_pairs = list(neg_pairs)

    # Count up the number of candidate pairs that have a field that matches,
    # then normalize per field
    for root, ref in neg_pairs:
        for c in data.columns:
            if data[c].iloc[root] == data[c].iloc[ref]:
                u_probs[c] += 1.0
    for c in data.columns:
        u_probs[c] = u_probs[c] / (n_samples + 1.0)

    return u_probs


"""
Calculate the per-field log odds ratio score that two records will
match in a given field. Measures the likelihood that two records
match on a column due to being a true match as opposed to random
chance.
"""
def calculate_log_odds(
    m_probs: dict,
    u_probs: dict,
):
    if m_probs.keys() != u_probs.keys():
        raise ValueError(
            "m- and u- probability dictionaries must contain the same set of keys"
        )
    log_odds = {}
    for k in m_probs:
        log_odds[k] = log(m_probs[k]) - log(u_probs[k])
    return log_odds


if NOTEBOOK_MODE == "train_weights":
    m_probs = calculate_m_probs(labeling_set, seer_labels)
    u_probs = calculate_u_probs(labeling_set, seer_labels, n_samples=NEG_SAMPLES)
    log_odds = calculate_log_odds(m_probs, u_probs)
    log_odds.pop("patient_id")
    log_odds.pop("person_id")
    print(log_odds)
    mssparkutils.fs.put(LINKAGE_OUTPUTS_FILESYSTEM + "updated_log_odds.json", json.dumps(log_odds), True)
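Both `calculate_u_probs` and `profile_log_odds_computation` draw random index pairs and keep those that are not labeled as true matches. The standalone sketch below isolates that rejection-sampling step in a hypothetical helper, `sample_negative_pairs`, using a seeded `random.Random` so results are reproducible. It also rejects self-pairs (i, i), which would otherwise agree on every field. It assumes, as the notebook's membership check does, that each true pair is stored under its lower row index.

```python
import random
from typing import Dict, List, Set, Tuple


def sample_negative_pairs(
    n_records: int,
    true_matches: Dict[int, Set[int]],
    n_samples: int,
    seed: int = 0,
) -> List[Tuple[int, int]]:
    """Rejection-sample distinct (root, ref) pairs with root < ref that are not true matches."""
    max_pairs = n_records * (n_records - 1) // 2
    n_true = sum(len(v) for v in true_matches.values())
    if n_samples > max_pairs - n_true:
        raise ValueError("Not enough non-matching pairs to draw that many samples")
    rng = random.Random(seed)
    neg_pairs: Set[Tuple[int, int]] = set()
    while len(neg_pairs) < n_samples:
        i, j = rng.randrange(n_records), rng.randrange(n_records)
        if i == j:
            continue  # a record paired with itself is not a non-match
        root, ref = min(i, j), max(i, j)
        if ref not in true_matches.get(root, set()):
            neg_pairs.add((root, ref))
    return sorted(neg_pairs)


pairs = sample_negative_pairs(n_records=10, true_matches={0: {1}, 2: {3}}, n_samples=20, seed=42)
print(len(pairs), pairs[:3])
```

Rejection sampling stays cheap as long as the requested sample is a small fraction of all pairs, which is what the ceiling check in `calculate_u_probs` is guarding.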

In [None]:
# PROFILE LOG-ODDS WEIGHTS FOR CUTOFF DETERMINATION
# Run computations on the log-odds cutoff scores for both matches and non-matches
# for later graphical evaluation.

"""
Helper function that computes the match score between two records in the same
data set, given a set of columns on which to perform fuzzy match evaluation.
"""
def profiling_df_helper(data, idx_i, idx_j, fuzzy_cols, log_odds, col_to_idx, cols_to_thresholds=None):

    # Iteratively accumulate results of each feature-wise comparison
    match_score = 0.0
    ri = data[idx_i]
    rj = data[idx_j]

    for col in fuzzy_cols:
        col_odds = log_odds[col]
        cidx = col_to_idx[col]
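        similarity_measure = "JaroWinkler"
        # Note: min_sim_threshold is not applied below; as in the log-odds branch of
        # _apply_score_contribution, the raw similarity is weighted directly
        min_sim_threshold = 0.85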
        if cols_to_thresholds is not None:
            if col in cols_to_thresholds:
                min_sim_threshold = cols_to_thresholds[col]

        # Given name is a list (possibly including middle name), so our logic says
        # concatenate all the values together and then fuzzy compare
        if col == "first_name":
            feature_in_record = re.sub(r'\[|\]', "", ri[cidx])
            feature_in_record = feature_in_record.split(", ")
            feature_in_record = " ".join(feature_in_record)
            feature_in_mpi = re.sub(r'\[|\]', "", rj[cidx])
            feature_in_mpi = feature_in_mpi.split(", ")
            feature_in_mpi = " ".join(feature_in_mpi)
            feature_score = compare_strings(feature_in_mpi, feature_in_record, similarity_measure)

        # Address is also a list, but rather than concatenate them all, we check if each
        # line of an incoming address matches any line of an MPI address; this accounts for
        # a patient's change of residence history
        elif col == "address":
            feature_in_record = re.sub(r'\[|\]', "", ri[cidx])
            feature_in_record = feature_in_record.split(", ")
            feature_in_mpi = re.sub(r'\[|\]', "", rj[cidx])
            feature_in_mpi = feature_in_mpi.split(", ")
            feature_score = 0.0
            for r in feature_in_record:
                for m in feature_in_mpi:
                    feature_comp = compare_strings(r, m, similarity_measure)
                    if feature_comp > feature_score:
                        feature_score = feature_comp
        
        # Regular case: straight string comparison on the fields
        else:
            feature_in_record = ri[cidx]
            feature_in_mpi = rj[cidx]
            feature_score = compare_strings(feature_in_mpi, feature_in_record, similarity_measure)
        
        match_score += feature_score * col_odds

    return match_score


"""
Function to generate net feature scores for a set of patient record data.
These scores can later be used to build distributions of matches vs non-
matches for graphical separation.
"""
def profile_log_odds_computation(
    data: pd.DataFrame,
    true_matches: dict,
    log_odds: dict,
    fuzzy_cols,
    neg_samples: int = 50000,
    cols_to_thresholds = None,
):
    neg_pairs = set()
    while len(neg_pairs) < neg_samples:
        idx1 = randint(0, len(data.index)-1)
        idx2 = randint(0, len(data.index)-1)
        # Skip self-pairs: a record compared with itself is not a non-match
        if idx1 == idx2:
            continue
        root = min(idx1, idx2)
        ref = max(idx1, idx2)
        if root not in true_matches or ref not in true_matches[root]:
            neg_pairs.add((root, ref))

    neg_pairs = list(neg_pairs)

    data_cols = list(data.columns)
    col_to_idx = dict(zip(data_cols, range(len(data_cols))))
    data = data.values.tolist()

    true_match_scores = []
    for root_record, paired_records in true_matches.items():
        for pr in paired_records:
            score = profiling_df_helper(data, root_record, pr, fuzzy_cols, log_odds, col_to_idx, cols_to_thresholds)
            true_match_scores.append(score)

    non_match_scores = []
    for record_1, record_2 in neg_pairs:
        score = profiling_df_helper(data, record_1, record_2, fuzzy_cols, log_odds, col_to_idx, cols_to_thresholds)
        non_match_scores.append(score)
    
    return true_match_scores, non_match_scores


if NOTEBOOK_MODE == "train_weights":
    match_scores, non_match_scores = profile_log_odds_computation(labeling_set, seer_labels, log_odds, COLS_TO_PROFILE, NEG_SAMPLES, COLS_TO_THRESHOLDS)
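`profiling_df_helper` scores a pair as the sum, over the profiled fields, of similarity times log-odds weight, where given names are joined before comparison and the address similarity is the best score over all pairs of address lines. The sketch below reproduces that arithmetic with the standard library's `difflib.SequenceMatcher.ratio()` standing in for the notebook's Jaro-Winkler `compare_strings` call, so real similarity values will differ. It also operates on already-parsed Python lists rather than the bracketed strings the notebook strips with `re.sub`; the records and weights are hypothetical.

```python
from difflib import SequenceMatcher
from typing import Dict


def similarity(a: str, b: str) -> float:
    # Stand-in for compare_strings(..., "JaroWinkler"); returns a value in [0, 1]
    return SequenceMatcher(None, a, b).ratio()


def pair_score(rec_a: dict, rec_b: dict, log_odds: Dict[str, float]) -> float:
    total = 0.0
    for col, weight in log_odds.items():
        if col == "address":
            # Best match across every pair of address lines (residence history)
            sim = max(similarity(x, y) for x in rec_a[col] for y in rec_b[col])
        elif col == "first_name":
            # Given names are joined into one string before comparing
            sim = similarity(" ".join(rec_a[col]), " ".join(rec_b[col]))
        else:
            sim = similarity(rec_a[col], rec_b[col])
        total += sim * weight
    return total


a = {"first_name": ["ANA", "MARIA"], "last_name": "LOPEZ", "address": ["12 ELM ST", "9 OAK AVE"]}
b = {"first_name": ["ANA", "MARIA"], "last_name": "LOPEZ", "address": ["9 OAK AVE"]}
weights = {"first_name": 6.8, "last_name": 6.3, "address": 8.4}
print(round(pair_score(a, b, weights), 2))  # 21.5: every field has an exact counterpart
```

The best-of-all-lines rule is what lets a patient who moved still earn the full address weight, as long as one historical address line agrees.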

### Visualizing Cutoff Scores
The cell below creates a visualization of the distributions of match scores and non-match scores. Since matplotlib can't zoom interactively in Synapse, a set of controls at the top of the cell lets you regenerate the graph with different ranges, which has the same effect as zooming in.

IMPORTANT: Once a `train_weights` run is completed to this point, you can change the graph to zoom in or out **without** re-running the expensive scoring code above. The scores are already held in memory by the previous cell (`match_scores` and `non_match_scores`), so this cell's only responsibility is to plot them. Simply re-run this individual cell with different axis values to get a fast graphical update before moving on to changing the columns at the top of the notebook.
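Reading the crossover off the histogram can be cross-checked numerically. One simple option, a substitute heuristic rather than something this notebook does, is to scan the observed scores as candidate cutoffs and keep the one that misclassifies the fewest pairs: true matches scoring below the cutoff plus sampled non-matches scoring at or above it. Because the non-matches are a random sample, the error count depends on `NEG_SAMPLES`. The helper name `error_minimizing_cutoff` and the score lists below are hypothetical.

```python
from bisect import bisect_left
from typing import Optional, Sequence, Tuple


def error_minimizing_cutoff(
    match_scores: Sequence[float],
    non_match_scores: Sequence[float],
) -> Tuple[Optional[float], int]:
    """Return (cutoff, errors) for the observed score that misclassifies the fewest pairs.

    A pair is called a match when its score >= cutoff.
    """
    m_sorted = sorted(match_scores)
    n_sorted = sorted(non_match_scores)
    best_cutoff, best_errors = None, len(m_sorted) + len(n_sorted) + 1
    for cutoff in sorted(set(m_sorted) | set(n_sorted)):
        missed = bisect_left(m_sorted, cutoff)                    # matches below the cutoff
        accepted = len(n_sorted) - bisect_left(n_sorted, cutoff)  # non-matches at/above it
        if missed + accepted < best_errors:
            best_cutoff, best_errors = cutoff, missed + accepted
    return best_cutoff, best_errors


example_matches = [14.2, 15.8, 16.1, 17.5, 12.9, 18.0]
example_non_matches = [2.1, 3.4, 5.0, 7.7, 11.0, 13.5, 4.2]
print(error_minimizing_cutoff(example_matches, example_non_matches))  # (12.9, 1)
```

Sorting once and using `bisect_left` keeps the scan at O(n log n), which matters when tens of thousands of non-match scores are profiled.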

In [None]:
# VISUALIZE LOG-ODDS GRAPH
# Graphically displays histograms of the log-odds total match scores for both
# true matches and sampled true non-matches (as determined by SEER labels) so that
# the cutoff score separating the two distributions can be determined.

# Number of bars to show in the histogram; around 75 to 100 is a reasonable start
N_BINS = 100
# Range of the x axis; start with 0 to 25, and if you can't see the two
# distribution peaks, increase the right boundary
X_RANGE = [0, 25]
# Density of tick marks along the axis; once you start zooming in, adjust
# this value to get a more precise reading of the separation point
AXIS_TICK_DENSITY = 20
# Display limits for the axes; moving these closer together has the
# effect of zooming the graph in because the chart still fills the whole
# figure space
AXIS_LIMITS = [0, 25]
YLIM = [0, 50]

import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

def show_profiling_graph(match_scores, non_match_scores):
    fig, ax = plt.subplots()
    fig.set_size_inches(12, 6)
    # Share bin edges between the two histograms so the bars line up
    _, bins, _ = ax.hist(match_scores, bins=N_BINS, range=X_RANGE, label="True matches")
    ax.hist(non_match_scores, bins=bins, alpha=0.5, label="Sampled non-matches")

    # Adjust the density of tick marks here to find the best separation boundary
    ax.xaxis.set_major_locator(MaxNLocator(AXIS_TICK_DENSITY))

    # Use this min and max of the x axis to effectively zoom in
    ax.set_xlim(AXIS_LIMITS)
    ax.set_ylim(YLIM)
    ax.set_xlabel("Total log-odds match score")
    ax.set_ylabel("Number of record pairs")
    ax.legend()
    plt.show()

if NOTEBOOK_MODE == "train_weights":
    show_profiling_graph(match_scores, non_match_scores)

## `compare_algorithms` Mode
This final collection of cells contains code that executes the notebook in comparative mode. We run three algorithms and collect statistics on each of them for comparison. Along the way, each algorithm's results are saved into a linked container so that they can be loaded during future runs if desired. The three algorithms run are:

* a Python port of LAC's current blocking and matching criteria from their linking R script; since we're testing the general principles of linkage, we included no post-processing heuristics or additional rules that might boost performance
* the DIBBs Basic algorithm
* the DIBBs Enhanced algorithm, which has the same blocks and passes as DIBBs Basic but weights each field comparison by a log-odds score and accepts a candidate when the weighted total meets a cutoff
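Each pass in these configurations first narrows the MPI to a block of candidates that agree with the incoming record on the `blocks` fields (optionally after a transformation such as `first4`), and only then runs the fuzzy comparisons in `funcs`. The notebook's own `extract_blocking_values_from_record` and `spark_block` are not shown in this section, so the pandas sketch below is a hypothetical illustration of that filtering step. It assumes `first4` means "first four characters" and that block values are compared by exact equality.

```python
from typing import Dict, List

import pandas as pd


def block_candidates(mpi: pd.DataFrame, record: Dict[str, str], blocks: List[Dict[str, str]]) -> pd.DataFrame:
    """Keep MPI rows whose blocking values equal the incoming record's (hypothetical sketch)."""
    mask = pd.Series(True, index=mpi.index)
    for block in blocks:
        col = block["value"]
        if block.get("transformation") == "first4":
            # Assumed meaning: compare only the first four characters
            mask &= mpi[col].str[:4] == record[col][:4]
        else:
            mask &= mpi[col] == record[col]
    return mpi[mask]


mpi = pd.DataFrame({
    "first_name": ["JONATHAN", "JONAS", "JANE"],
    "last_name": ["SMITHERS", "SMITH", "SMITH"],
    "birthdate": ["1980-01-01", "1980-01-01", "1980-01-01"],
})
incoming = {"first_name": "JONATHON", "last_name": "SMITHE", "birthdate": "1980-01-01"}
pass_1_blocks = [
    {"value": "first_name", "transformation": "first4"},
    {"value": "last_name", "transformation": "first4"},
    {"value": "birthdate"},
]
print(block_candidates(mpi, incoming, pass_1_blocks).index.tolist())  # [0, 1]
```

Blocking on truncated names tolerates spelling variation at the end of a name ("JONATHON" still reaches "JONATHAN"), while the fuzzy comparisons in `funcs` decide which of the surviving candidates actually link.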

In [None]:
# ALGORITHM EVALUATION: LAC EXISTING

LAC_ALGO = [
    {
        "funcs": {
            "first_name": "feature_match_fuzzy_string",
            "last_name": "feature_match_fuzzy_string",
            "address": "feature_match_fuzzy_string",
            "mrn": "feature_match_fuzzy_string",
        },
        "blocks": [
            {"value": "first_name", "transformation": "first4"},
            {"value": "last_name", "transformation": "first4"},
            {"value": "birthdate"},
        ],
        "matching_rule": "eval_perfect_match",
        "cluster_ratio": 0.9,
    },
    {
        "funcs": {
            "first_name": "feature_match_fuzzy_string",
            "last_name": "feature_match_fuzzy_string",
            "address": "feature_match_fuzzy_string",
            "mrn": "feature_match_fuzzy_string",
        },
        "blocks": [
            {"value": "first_name", "transformation": "first4"},
            {"value": "last_name", "transformation": "first4"},
            {"value": "address", "transformation": "first4"},
        ],
        "matching_rule": "eval_perfect_match",
        "cluster_ratio": 0.9,
    },
    {
        "funcs": {
            "first_name": "feature_match_fuzzy_string",
            "last_name": "feature_match_fuzzy_string",
            "address": "feature_match_fuzzy_string",
            "mrn": "feature_match_fuzzy_string",
        },
        "blocks": [
            {"value": "birthdate"},
        ],
        "matching_rule": "eval_perfect_match",
        "cluster_ratio": 0.9,
    },
]

if NOTEBOOK_MODE == "compare_algorithms":
    if LOAD_PREVIOUS_RUNS:
        found_matches_lac = load_linkage_results(spark, "lac_algorithm_results")
    if not LOAD_PREVIOUS_RUNS or found_matches_lac is None:
        found_matches_lac = link_all_fhir_records_block_dataset(evaluation_set, LAC_ALGO, labeling_set)
        found_matches_lac = dedupe_match_double_counts(found_matches_lac)
        write_linkage_results("lac_algorithm_results", found_matches_lac)

In [None]:
# ALGORITHM EVALUATION: DIBBs BASIC

new_algo = copy.deepcopy(DIBBS_BASIC)
col_thresholds = {
    "address": 0.85,
    "birthdate": 0.85,
    "city": 0.85,
    "first_name": 0.85,
    "last_name": 0.85,
    "mrn": 0.85,
    "sex": 0.85,
    "state": 0.85,
    "zip": 0.85
}
if len(FINAL_COLS_TO_THRESHOLDS) > 0:
    print("Using experimentally updated fuzzy thresholds")
    new_algo[0]["kwargs"] = { "thresholds": FINAL_COLS_TO_THRESHOLDS }
    new_algo[1]["kwargs"] = { "thresholds": FINAL_COLS_TO_THRESHOLDS }
else:
    print("Using default uniform fuzzy threshold")
    new_algo[0]["kwargs"] = { "thresholds": col_thresholds }
    new_algo[1]["kwargs"] = { "thresholds": col_thresholds }

if NOTEBOOK_MODE == "compare_algorithms":
    if LOAD_PREVIOUS_RUNS:
        found_matches_dibbs_basic = load_linkage_results(spark, "dibbs_basic_algorithm_results")
    if not LOAD_PREVIOUS_RUNS or found_matches_dibbs_basic is None:
        found_matches_dibbs_basic = link_all_fhir_records_block_dataset(evaluation_set, new_algo, labeling_set)
        found_matches_dibbs_basic = dedupe_match_double_counts(found_matches_dibbs_basic)
        write_linkage_results("dibbs_basic_algorithm_results", found_matches_dibbs_basic)

In [None]:
# ALGORITHM EVALUATION: DIBBs ENHANCED
new_algo = copy.deepcopy(DIBBS_ENHANCED)
col_thresholds = {
    "address": 0.85,
    "birthdate": 0.85,
    "city": 0.85,
    "first_name": 0.85,
    "last_name": 0.85,
    "mrn": 0.85,
    "sex": 0.85,
    "state": 0.85,
    "zip": 0.85
}
del new_algo[0]["kwargs"]["threshold"]
del new_algo[1]["kwargs"]["threshold"]

if len(FINAL_LOG_ODDS) > 0:
    print("Using experimentally calculated log-odds weights")
    new_algo[0]["kwargs"]["log_odds"] = FINAL_LOG_ODDS
    new_algo[1]["kwargs"]["log_odds"] = FINAL_LOG_ODDS
else:
    print("Using default synthetically trained log-odds weights")

if len(FINAL_COLS_TO_THRESHOLDS) > 0:
    print("Using experimentally updated fuzzy thresholds")
    new_algo[0]["kwargs"]["thresholds"] = FINAL_COLS_TO_THRESHOLDS
    new_algo[1]["kwargs"]["thresholds"] = FINAL_COLS_TO_THRESHOLDS
else:
    print("Using default uniform fuzzy threshold")
    new_algo[0]["kwargs"]["thresholds"] = col_thresholds
    new_algo[1]["kwargs"]["thresholds"] = col_thresholds

if FINAL_PASS_1_THRESHOLD is not None:
    print("Using experimental cutoff threshold 1")
    new_algo[0]["kwargs"]["true_match_threshold"] = FINAL_PASS_1_THRESHOLD
else:
    print("Using default cutoff threshold 1")
if FINAL_PASS_2_THRESHOLD is not None:
    print("Using experimental cutoff threshold 2")
    new_algo[1]["kwargs"]["true_match_threshold"] = FINAL_PASS_2_THRESHOLD
else:
    print("Using default cutoff threshold 2")

if NOTEBOOK_MODE == "compare_algorithms":
    if LOAD_PREVIOUS_RUNS:
        found_matches_dibbs_enhanced = load_linkage_results(spark, "dibbs_enhanced_algorithm_results")
    if not LOAD_PREVIOUS_RUNS or found_matches_dibbs_enhanced is None:
        found_matches_dibbs_enhanced = link_all_fhir_records_block_dataset(evaluation_set, new_algo, labeling_set)
        found_matches_dibbs_enhanced = dedupe_match_double_counts(found_matches_dibbs_enhanced)
        write_linkage_results("dibbs_enhanced_algorithm_results", found_matches_dibbs_enhanced)
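`score_linkage_vs_truth` in the next cell expects both the found and the true matches as dictionaries from a row number to the set of row numbers it links to. If links are instead available as a flat list of (i, j) pairs, a small conversion like the hypothetical `pairs_to_match_dict` below produces that shape. It stores each link once under the lower row number, the same convention `dedupe_match_double_counts` enforces, and keeps every row as a key.

```python
from typing import Dict, Iterable, Set, Tuple


def pairs_to_match_dict(pairs: Iterable[Tuple[int, int]], n_records: int) -> Dict[int, Set[int]]:
    """Build {row: set(linked rows with a higher row number)} with every row present as a key."""
    match_dict: Dict[int, Set[int]] = {i: set() for i in range(n_records)}
    for i, j in pairs:
        if i == j:
            continue  # self-links carry no information
        lo, hi = min(i, j), max(i, j)
        match_dict[lo].add(hi)
    return match_dict


links = [(0, 2), (2, 0), (3, 1), (4, 4)]
print(pairs_to_match_dict(links, 5))  # {0: {2}, 1: {3}, 2: set(), 3: set(), 4: set()}
```

Storing each link under one canonical key is what keeps the symmetric (i, j) and (j, i) findings from being counted twice in the statistics.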

In [None]:
# RUN THE NUMBERS AND GET THE STATS FUNCTIONS

'''
To ensure accurate statistics, the matches and the true matches dictionaries
in the statistical evaluation function should have the following form:

{
    row_num_of_record_in_data: set(row_nums_of_linked_records)
}

Each row in the data should be represented as a key in both dictionaries.
The value for each of these keys should be a set that contains all other
row numbers for records in the data set that link to the key record.
'''

def score_linkage_vs_truth(found_matches, true_matches, records_in_dataset):

    # Need division by 2 because ordering is irrelevant, matches are symmetric
    total_possible_matches = (records_in_dataset * (records_in_dataset - 1)) / 2.0
    true_positives = 0.0
    false_positives = 0.0
    false_negatives = 0.0

    for root_record in true_matches:
        if root_record in found_matches:
            true_positives += len(
                true_matches[root_record].intersection(found_matches[root_record])
            )
            false_positives += len(
                found_matches[root_record].difference(true_matches[root_record])
            )
            false_negatives += len(
                true_matches[root_record].difference(found_matches[root_record])
            )
        else:
            false_negatives += len(true_matches[root_record])
    for record in set(set(found_matches.keys()).difference(true_matches.keys())):
        false_positives += len(found_matches[record])

    true_negatives = (
        total_possible_matches - true_positives - false_positives - false_negatives
    )
    npv = round((true_negatives / (true_negatives + false_negatives)), 3)
    sensitivity = round(true_positives / (true_positives + false_negatives), 3)
    specificity = round(true_negatives / (true_negatives + false_positives), 3)
    ppv = round(true_positives / (true_positives + false_positives), 3)
    f1 = round(
        (2 * true_positives) / (2 * true_positives + false_negatives + false_positives),
        3,
    )

    return {
        "tp": true_positives,
        "fp": false_positives,
        "fn": false_negatives,
        "sens": sensitivity,
        "spec": specificity,
        "ppv": ppv,
        "npv": npv,
        "f1": f1
    }

if NOTEBOOK_MODE == "compare_algorithms":
    display_str = ""
    n_records = EVALUATION_SIZE

    for lbl_type in ["va", "uk-nhs", "nci-seer"]:
        if lbl_type == "va":
            labels = va_labels
        elif lbl_type == "nci-seer":
            labels = seer_labels
        else:
            labels = uk_labels
        
        stats_dict_lac = score_linkage_vs_truth(found_matches_lac, labels, n_records)
        stats_dict_dibbs_b = score_linkage_vs_truth(found_matches_dibbs_basic, labels, n_records)
        stats_dict_dibbs_e = score_linkage_vs_truth(found_matches_dibbs_enhanced, labels, n_records)

        display_str += "DISPLAYING EVALUATION ON " + lbl_type.upper() + " LABELS:\n"
        display_str += "\n"

        for algo in ["lac", "basic", "enhanced"]:
            if algo == "lac":
                display_str += "LAC Existing Algorithm:\n"
                scores = stats_dict_lac
            elif algo == "basic":
                display_str += "DIBBs Basic Algorithm:\n"
                scores = stats_dict_dibbs_b
            else:
                display_str += "DIBBs Log-Odds Algorithm:\n"
                scores = stats_dict_dibbs_e

            display_str += "True Positives: " + str(scores["tp"]) + "\n"
            display_str += "False Positives: " + str(scores["fp"]) + "\n"
            display_str += "False Negatives: " + str(scores["fn"]) + "\n"
            display_str += "Sensitivity: " + str(scores["sens"]) + "\n"
            display_str += "Specificity: " + str(scores["spec"]) + "\n"
            display_str += "PPV: " + str(scores["ppv"]) + "\n"
            display_str += "NPV: " + str(scores["npv"]) + "\n"
            display_str += "F1: " + str(scores["f1"]) + "\n"
            display_str += "\n"
        
        display_str += "\n"

    print(display_str)

    mssparkutils.fs.put(LINKAGE_OUTPUTS_FILESYSTEM + "results.txt", display_str, True)
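Because linkage is scored over record pairs, `score_linkage_vs_truth` derives true negatives from the n(n-1)/2 possible unordered pairs rather than counting them directly. The worked example below restates those formulas in a hypothetical helper, `pairwise_stats`, with made-up counts, to show why specificity and NPV sit very close to 1 even for a weak linker; sensitivity, PPV, and F1 therefore tend to be the more discriminating rows in the printed results.

```python
def pairwise_stats(tp: float, fp: float, fn: float, n_records: int) -> dict:
    """Pairwise confusion-matrix statistics; TN is inferred from all unordered pairs."""
    total_pairs = n_records * (n_records - 1) / 2.0
    tn = total_pairs - tp - fp - fn
    return {
        "tn": tn,
        "sens": tp / (tp + fn),
        "spec": tn / (tn + fp),
        "ppv": tp / (tp + fp),
        "npv": tn / (tn + fn),
        "f1": 2 * tp / (2 * tp + fn + fp),
    }


# Hypothetical run: 10,000 records, 600 true links; the linker finds 400 of them plus 400 false links
stats = pairwise_stats(tp=400, fp=400, fn=200, n_records=10_000)
print({k: round(v, 4) for k, v in stats.items()})
```

With nearly 50 million possible pairs, specificity and NPV both round to 1.0 here even though half of the reported links are wrong (PPV = 0.5).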