### tabularizePatientTable
This notebook extracts all data from the Patient table (`patient`) in the Master Patient Index and tabularizes the data into Lists of Lists of Lists (LoLoL).

In [None]:
pip install azure-identity phdi recordlinkage azure-keyvault-secrets

In [None]:
# IMPORTS AND CONSTANTS

from azure.identity import DefaultAzureCredential

# Ground-truth labeling imports
import time
import pandas as pd
import recordlinkage as rl
from recordlinkage.base import BaseCompareFeature
import numpy as np
from phdi.harmonization import compare_strings

# Key Vault settings: the vault name and linked service are passed to
# TokenLibrary.getSecret further down to fetch the MPI database password
vault_name = "$KEY_VAULT"
KEY_VAULT_URL = f"https://{vault_name}.vault.azure.net"
vault_linked_service = "$KEY_VAULT_LINKED_SERVICE"

# Connection settings for the MPI database; host, port and name are combined
# into the JDBC URL that Spark reads the patient table from
DB_NAME = "DibbsMpiDB"
DB_USER = "postgres"
DB_HOST = "$MPI_DB_HOST"
DB_PORT = "5432"
DB_TABLE_PATIENT = "patient"
# NOTE(review): the person table name is not referenced in the visible cells
DB_TABLE_PERSON= "person"

# Adjust data volume for scaling
# LABELING_SIZE caps the rows kept for ground-truth labeling; EVALUATION_SIZE
# caps the FHIR records that get linked. The evaluation sample is drawn from the
# labeling sample, so evaluation size must be less than labeling size!
from random import sample
LABELING_SIZE = 10000
EVALUATION_SIZE = 1000

# Ground-truth labeling parameters
# WINDOW_INDEX_SIZE is the window handed to every sorted-neighbourhood indexer;
# JARO_THRESHOLD is the similarity cutoff used by the fuzzy labeling comparisons
WINDOW_INDEX_SIZE = 5
JARO_THRESHOLD = 0.85

In [None]:
# OPEN PARALLEL CONNECTION TO MPI, READ DB INTO COMPRESSED MEMORY

from pyspark.sql import SparkSession, Row
import json

# Access the MPI Database
# NOTE(review): `credential` is not used in the visible cells, and TokenLibrary
# is not imported here -- presumably it is injected by the notebook runtime; confirm
credential = DefaultAzureCredential()
db_password =  TokenLibrary.getSecret(vault_name,"mpi-db-password",vault_linked_service)

# JDBC connection string and properties for the Postgres driver
url = f"jdbc:postgresql://{DB_HOST}:{DB_PORT}/{DB_NAME}"
db_props = {
    "user": DB_USER,
    "password": db_password,
    "driver": "org.postgresql.Driver"
}

# Get (or reuse) a Spark session; master is set to local[*]
spark = (
    SparkSession.builder.master("local[*]")
    .appName("Build sub-sampled MPI")
    .getOrCreate()
)

# Read the patient table over JDBC and register it as a temp view so that the
# later cells can query it with Spark SQL by name
view_name = "patient_view"
full_mpi_data = spark.read.jdbc(url, DB_TABLE_PATIENT, properties=db_props)
full_mpi_data.createOrReplaceTempView(view_name)

In [None]:
# MPI ACCESS AND TABULATION FUNCTIONS

import pyspark.sql.functions as F

# Maps each output column name to the JSON path that generate_query() feeds to
# get_json_object() against the `patient_resource` column. "codes" and "vals"
# are intermediate columns: extract_mrn() combines them into "mrn", and they are
# excluded when the final column list is built.
FIELDS_TO_JSONPATHS = {
    "address": """$.address[*].line""",
    "birthdate": "$.birthDate",
    "city": """$.address[*].city""",
    "first_name": """$.name[*].given""",
    "last_name": """$.name[*].family""",
    "codes": """$.identifier[*].type.coding[0].code""",
    "vals": """$.identifier[*].value""",
    "sex": "$.gender",
    "state": """$.address[*].state""",
    "zip": """$.address[*].postalCode""",
}

def generate_query(view, fields_to_jsonpaths=None):
    """
    Build the Spark SQL statement that extracts the tabular patient fields from a view.

    The statement selects `patient_id` and `person_id` as-is, followed by one
    `get_json_object(patient_resource, <path>) as <column>` expression per entry
    of the field mapping, in the mapping's iteration order. Running it against
    the parallelized MPI extract yields a DataFrame representation of the MPI
    without any further calls to the DB instance.

    :param view: Name of the Spark view (or table) to select from. It is
      interpolated into the SQL text verbatim.
    :param fields_to_jsonpaths: Optional mapping of output column name to JSON
      path. Defaults to the notebook-level FIELDS_TO_JSONPATHS.
    :return: The SQL query as a string, terminated by a semicolon.
    """
    if fields_to_jsonpaths is None:
        fields_to_jsonpaths = FIELDS_TO_JSONPATHS

    select_stubs = [
        f"get_json_object(patient_resource,'{json_path}') as {column}"
        for column, json_path in fields_to_jsonpaths.items()
    ]
    select_clause = "SELECT patient_id, person_id, " + ", ".join(select_stubs)
    return select_clause + f" FROM {view};"


def extract_mrn(row):
    """
    Mapping function that derives a patient's MRN from the extracted 'codes' and
    'vals' columns and re-packs the row without those two columns.

    Pyspark doesn't support JSONB path querying, so we can't directly find the
    struct containing the MRN within the JSON array of a patient's identifiers.
    The identifier type codes and identifier values are extracted as two
    separate JSON arrays instead; since that extraction preserves order, the MRN
    is taken to be the value at the position of the first "MR" code.

    :param row: A row with keys 'codes', 'vals', 'patient_id', 'person_id',
      'address', 'birthdate', 'city', 'first_name', 'last_name', 'sex', 'state'
      and 'zip'.
    :return: A positional Row of (patient_id, person_id, address, birthdate,
      city, first_name, last_name, mrn, sex, state, zip); mrn is None when no
      MRN could be determined.
    """
    mrn = None
    raw_codes = row['codes']
    raw_vals = row['vals']

    # Either column can be NULL (e.g. a patient with no identifiers), and
    # json.loads(None) raises, so only decode when both are present
    if raw_codes is not None and raw_vals is not None:
        codes = json.loads(raw_codes)
        vals = json.loads(raw_vals)
        mr_idx = [pos for (pos, v) in enumerate(codes) if v == "MR"]
        # Guard the lookup: the two arrays are extracted independently, so the
        # values array may be shorter than the codes array
        if mr_idx and mr_idx[0] < len(vals):
            mrn = vals[mr_idx[0]]

    return Row(row['patient_id'], row['person_id'], row['address'], row['birthdate'], row['city'], row['first_name'], row['last_name'], mrn, row['sex'], row['state'], row['zip'])

# Build the Spark SQL extraction query and run it against the temp view; the
# result still holds the raw get_json_object output for every field
query = generate_query(view_name)
extracted_mpi_data = spark.sql(query)

# Cache the DF here so we don't have to re-query the pyspark view every time
# the lineage is evaluated
extracted_mpi_data.cache()

# Derive each patient's MRN and format the columns we want to map the data into.
# The column list is: patient_id, person_id, then the remaining field names plus
# "mrn" in sorted order. That order has to line up with the positional order of
# the values in the Row returned by extract_mrn.
formatted_mpi_data = extracted_mpi_data.rdd.map(extract_mrn)
formatted_cols = [x for x in FIELDS_TO_JSONPATHS.keys() if x != "codes" and x != "vals"]
formatted_cols.append("mrn")
formatted_cols = sorted(formatted_cols)
formatted_cols.insert(0, 'person_id')
formatted_cols.insert(0, 'patient_id')

# Pyspark will read the JSON data as de-serialized strings, which means it will
# actually save the quotes into the string-proper so that it can be re-serialized if
# necessary; however, we don't ever need to re-serialize the MPI (that would be super
# slow) so strip the quotes to make comparisons easier.
# first_name and address are deliberately skipped: they stay JSON-encoded so
# that later code can json.loads() them or match on the bracket/quote framing.
mpi_df = formatted_mpi_data.toDF(formatted_cols)
for col in formatted_cols:
    if col != "first_name" and col != "address":
        mpi_df = mpi_df.withColumn(col, F.regexp_replace(col, "\"", ""))

# Cache the updated version so we can filter on it later
mpi_df.cache()

# Make a random sampling generator for use on both the labeling and the eval set,
# since the eval set *must* be included in the label set.
# No seed is supplied, so the samples differ from run to run.
rng = np.random.default_rng()

# Apply sampling, if desired, and convert into pandas DF for labeling.
# This collects the whole MPI extract to the driver; when sampling applies, the
# kept rows retain their original relative order and the DataFrame gets a fresh
# 0..n-1 index.
# NOTE(review): first_name and address arrive here as the JSON strings left
# untouched above -- confirm that is what the pandas comparers downstream expect.
labeling_set = [list(x) for x in mpi_df.collect()]
sampled_idx = None
if LABELING_SIZE < len(labeling_set):
    sampled_idx = rng.choice(len(labeling_set), LABELING_SIZE, replace=False)
    sampled_idx = set(sampled_idx)
    labeling_set = [v for (i, v) in enumerate(labeling_set) if i in sampled_idx]
labeling_set = pd.DataFrame(labeling_set, columns=formatted_cols)

# Now, we need a copy of the data in a FHIR format for the linkage algorithms
fhir_query = "select patient_resource from " + view_name + ";"
fhir_pull = spark.sql(fhir_query)
evaluation_set = [json.loads(x['patient_resource']) for x in fhir_pull.collect()]

# As above, apply sampling, but draw from the positions already sampled for the
# labeling set so that the eval set is a subset of the labeling set.
# NOTE(review): positions from mpi_df.collect() are applied to the rows of
# fhir_pull.collect(); this presumes both collects return rows in the same
# order -- confirm.
if EVALUATION_SIZE < len(evaluation_set):
    #Re-use labeling sampler, if one was created
    if sampled_idx is not None:
        eval_sample = rng.choice(list(sampled_idx), EVALUATION_SIZE, replace=False, shuffle=False)
    # We kept the whole MPI, so any record can be evaluated
    else:
        eval_sample = rng.choice(len(evaluation_set), EVALUATION_SIZE, replace=False)
    eval_sample = set(eval_sample)
    evaluation_set = [v for (i, v) in enumerate(evaluation_set) if i in eval_sample]

In [None]:
# GROUND TRUTH LABELING: VIRGINIA FUNCTIONS

from recordlinkage.index import SortedNeighbourhood, BaseIndexAlgorithm
from recordlinkage.utils import listify

class FirstNameSortedNeighborhood(BaseIndexAlgorithm):
    '''
    A custom indexing algorithm built to operate on the first_name column
    returned from the MPI. Since that column holds a list of strings per record
    (because someone could have multiple given names), each list is first joined
    into a single space-separated string on copies of the input frames. The
    joined strings are then used as the sorting key of a sorted-neighbourhood
    index: unique key values are ranked in sorted order, and two records become
    a candidate pair when their ranks fall within the configured window (and any
    extra blocking keys agree exactly).
    '''
    def __init__(
        self,
        left_on=None,
        right_on=None,
        window=3,
        sorting_key_values=None,
        block_on=None,
        block_left_on=None,
        block_right_on=None,
        **kwargs
    ):
        """
        :param left_on: Sorting-key column(s) of the left frame.
        :param right_on: Sorting-key column(s) of the right frame; falls back to
          `left_on` when None.
        :param window: Width of the neighbourhood window, in sorted-key ranks.
        :param sorting_key_values: Optional precomputed ordering of key values;
          derived from the data on first use when None.
        :param block_on: Column(s) that must match exactly in both frames.
          Takes precedence over block_left_on / block_right_on when non-empty.
        :param block_left_on: Exact-match blocking column(s) of the left frame.
        :param block_right_on: Exact-match blocking column(s) of the right frame.
        :param kwargs: Forwarded to BaseIndexAlgorithm.
        """
        super(FirstNameSortedNeighborhood, self).__init__(**kwargs)

        # variables to block on
        self.left_on = left_on
        self.right_on = right_on
        self.window = window
        self.sorting_key_values = sorting_key_values
        # None defaults avoid sharing one list object between all instances
        self.block_on = [] if block_on is None else block_on
        self.block_left_on = [] if block_left_on is None else block_left_on
        self.block_right_on = [] if block_right_on is None else block_right_on

    def _get_left_and_right_on(self):
        """
        We only care about the de-dupe case which involves no self.right, but this
        still needs to be implemented for super compatibility.
        """
        if self.right_on is None:
            return (self.left_on, self.left_on)
        else:
            return (self.left_on, self.right_on)

    def _get_sorting_key_values(self, array1, array2):
        """
        Return the sorted, de-duplicated union of the sorting key values of both
        sides. The position of a value in this array is its rank, which is what
        the neighbourhood window is measured in.
        """

        concat_arrays = np.concatenate([array1, array2])
        return np.unique(concat_arrays)

    def _link_index(self, df_a, df_b):
        """
        Build the candidate pairs between df_a and df_b.

        :param df_a: Left frame; must contain a list-valued "first_name" column.
        :param df_b: Right frame; must contain a list-valued "first_name" column.
        :return: A pandas MultiIndex of (df_a label, df_b label) candidate pairs.
        """
        # Work on copies so the callers' frames keep their list-valued names
        df_a = df_a.copy()
        df_b = df_b.copy()
        df_a["first_name"] = df_a["first_name"].str.join(" ")
        # Each side is joined from its OWN column. Joining df_a's already-joined
        # strings again would insert a space between every character and leave
        # the two sides with keys that can never line up.
        df_b["first_name"] = df_b["first_name"].str.join(" ")
        left_on, right_on = self._get_left_and_right_on()
        left_on = listify(left_on)
        right_on = listify(right_on)

        window = self.window

        # Correctly generate blocking keys
        block_left_on = listify(self.block_left_on)
        block_right_on = listify(self.block_right_on)

        if self.block_on:
            block_left_on = listify(self.block_on)
            block_right_on = listify(self.block_on)

        blocking_keys = ["sorting_key"] + [
            "blocking_key_%d" % i for i, v in enumerate(block_left_on)
        ]

        # Format the data to thread with index pairs; index_x / index_y hold the
        # positional row numbers that become the MultiIndex codes below
        data_left = pd.DataFrame(df_a[listify(left_on) + block_left_on], copy=False)
        data_left.columns = blocking_keys
        data_left["index_x"] = np.arange(len(df_a))
        data_left.dropna(axis=0, how="any", subset=blocking_keys, inplace=True)

        data_right = pd.DataFrame(df_b[listify(right_on) + block_right_on], copy=False)
        data_right.columns = blocking_keys
        data_right["index_y"] = np.arange(len(df_b))
        data_right.dropna(axis=0, how="any", subset=blocking_keys, inplace=True)

        # sorting_key_values is the terminology in Data Matching [Christen,
        # 2012]. Once derived it is kept on the instance and reused by later calls.
        if self.sorting_key_values is None:
            self.sorting_key_values = self._get_sorting_key_values(
                data_left["sorting_key"].values, data_right["sorting_key"].values
            )

        # Replace every key value by its rank in the sorted unique values
        sorting_key_factors = pd.Series(
            np.arange(len(self.sorting_key_values)), index=self.sorting_key_values
        )

        data_left["sorting_key"] = data_left["sorting_key"].map(sorting_key_factors)
        data_right["sorting_key"] = data_right["sorting_key"].map(sorting_key_factors)

        # Internal window size: number of ranks to look at on either side
        _window = int((window - 1) / 2)

        def merge_lagged(x, y, w):
            """Merge two dataframes with a lag on in the sorting key."""

            y = y.copy()
            y["sorting_key"] = y["sorting_key"] + w

            return x.merge(y, how="inner")

        # One inner merge per rank offset inside the window, then stack them all
        pairs_concat = [
            merge_lagged(data_left, data_right, w) for w in range(-_window, _window + 1)
        ]

        pairs_df = pd.concat(pairs_concat, axis=0)

        return pd.MultiIndex(
            levels=[df_a.index.values, df_b.index.values],
            codes=[pairs_df["index_x"].values, pairs_df["index_y"].values],
            verify_integrity=False,
        )


def get_pred_match_dict_from_multi_idx(mltidx, n_rows):
    """
    Turn a multi-index of linked pairs into a per-row dictionary of matches.

    Every row number in range(n_rows) gets a key. Each pair is filed under its
    smaller member, with the larger member added to that key's set, so a link is
    recorded once regardless of the order it appears in.

    :param mltidx: Multi-index whose entries are pairs of row numbers.
    :param n_rows: Number of rows in the data set the pairs refer to.
    :return: Dict mapping row number to the set of higher row numbers linked to it.
    """
    pred_matches = dict((row_num, set()) for row_num in range(n_rows))
    for linked_pair in mltidx.to_list():
        pred_matches[min(linked_pair)].add(max(linked_pair))
    return pred_matches


class CompareNestedString(BaseCompareFeature):
    """
    Exact-match comparison for nested (list-like) column values.

    Only the element at position 0 of each value takes part in the comparison;
    the result is 1.0 where those leading elements are equal and 0.0 otherwise.
    """
    def _compute_vectorized(self, s1, s2):
        left_heads = s1.str[0]
        right_heads = s2.str[0]
        heads_agree = left_heads == right_heads
        return heads_agree.astype(float)


def get_va_labels(data):
    """
    Label matches by requiring exact agreement on first name, last name,
    birthdate and address among windowed candidate pairs.

    :param data: Labeling DataFrame with first_name, last_name, birthdate, mrn
      and address columns.
    :return: Tuple of (match dictionary keyed by row number, candidate pair
      multi-index that was evaluated).
    """
    started_at = time.time()

    # A full cross-comparison of the patient table is too expensive, so only
    # look at pairs that are neighbours under at least one sorting key.
    # Adding multiple different neighborhoods takes their union so we don't over-block
    indexer = rl.Index()
    for sort_column in ('last_name', 'birthdate', 'mrn'):
        indexer.add(SortedNeighbourhood(sort_column, window=WINDOW_INDEX_SIZE))
    indexer.add(FirstNameSortedNeighborhood('first_name', window=WINDOW_INDEX_SIZE))
    candidate_links = indexer.index(data)

    # Note: using a multi-indexer treats the row number as the index, so
    # results will automatically be in acceptable eval format
    print(len(candidate_links), "candidate pairs identified")

    # One agreement feature per compared field
    comp = rl.Compare()
    comp.add(CompareNestedString("first_name", "first_name",label="first_name"))
    for exact_column in ("last_name", "birthdate"):
        comp.exact(exact_column, exact_column, label=exact_column)
    comp.add(CompareNestedString("address", "address", label="address"))
    features = comp.compute(candidate_links, data)

    # A pair is a match only when all four features agree
    agrees_everywhere = features.sum(axis=1) == 4
    matches = features[agrees_everywhere]

    elapsed = time.time() - started_at
    print("Computation took", str(round(elapsed, 2)), "seconds")

    match_dict = get_pred_match_dict_from_multi_idx(matches.index, len(data))
    return match_dict, candidate_links


# Label the sampled data; the candidate pairs are kept for the next labeler
va_labels, candidate_links = get_va_labels(labeling_set)

In [None]:
# GROUND-TRUTH LABELING: RECORD LINKAGE TOOLKIT FUNCTIONS

class CompareFirstName(BaseCompareFeature):
    """
    Fuzzy comparison for list-valued first names.

    All given names of a record are joined with spaces so that multiple given
    names are compared as one string; a pair agrees when the string similarity
    reaches JARO_THRESHOLD.
    """
    def _compute_vectorized(self, s1, s2):
        left_full_names = s1.str.join(" ")
        right_full_names = s2.str.join(" ")
        score_pairwise = np.vectorize(compare_strings)
        similarities = score_pairwise(left_full_names, right_full_names)
        return similarities >= JARO_THRESHOLD


class CompareAddress(BaseCompareFeature):
    """
    Fuzzy comparison for list-valued address lines.

    Every address line of one record is compared against every address line of
    the other (to account for moving); a pair agrees when the best of those
    similarities reaches JARO_THRESHOLD.
    """
    def _compute_vectorized(self, s1, s2):

        def comp_address_fields(a1_list, a2_list):
            # Best similarity over the cross product of lines; stays at 0.0
            # when either side has no lines
            top_score = 0.0
            for left_line in a1_list:
                for right_line in a2_list:
                    top_score = max(compare_strings(left_line, right_line), top_score)
            return top_score

        best_per_pair = np.vectorize(comp_address_fields)(s1, s2)
        return best_per_pair >= JARO_THRESHOLD
    

def predict_third_party_labels(data, candidate_links):
    """
    Label matches with the record linkage toolkit's ECM classifier, fitted on
    thresholded fuzzy-agreement features of the candidate pairs.

    :param data: Labeling DataFrame with first_name, last_name, mrn, birthdate,
      address, city and zip columns.
    :param candidate_links: Multi-index of candidate pairs to compare.
    :return: Match dictionary keyed by row number.
    """
    started_at = time.time()

    comp = rl.Compare()

    def add_jarowinkler(column):
        # Thresholded Jaro-Winkler agreement on a plain string column
        comp.string(
            column, column, method="jarowinkler", threshold=JARO_THRESHOLD, label=column
        )

    # Features are registered in a fixed order: first_name, last_name, mrn,
    # birthdate, address, city, zip
    comp.add(CompareFirstName("first_name", "first_name",label="first_name"))
    for column in ("last_name", "mrn", "birthdate"):
        add_jarowinkler(column)
    comp.add(CompareAddress("address", "address", label="address"))
    for column in ("city", "zip"):
        add_jarowinkler(column)
    features = comp.compute(candidate_links, data)

    # Create an EM Predictor and label the binary training vectors
    classifier = rl.ECMClassifier()
    pred_links = classifier.fit_predict(features)

    elapsed = time.time() - started_at
    print("Computation took", str(round(elapsed, 2)), "seconds")

    return get_pred_match_dict_from_multi_idx(pred_links, len(data))


# Second set of labels, computed over the same candidate pairs as va_labels
third_party_labels = predict_third_party_labels(labeling_set, candidate_links)

In [None]:
# LINKAGE DRIVER FUNCTIONS

from phdi.linkage.link import _flatten_patient_resource, extract_blocking_values_from_record
from typing import List


def spark_block(block_vals: dict, mpi_df):
    """
    Filter the Spark extract of the MPI down to the block of candidates for one
    incoming record.

    Each blocking criterion is applied as a `where` filter on its column, one
    after another, which is equivalent to taking the intersection of all
    criteria. No database call is made; everything runs against `mpi_df`.

    :param block_vals: Dict of column name -> {"value": ..., and optionally
      "transformation": "first4" | "last4"}. Without a transformation the column
      must equal the value; "first4" / "last4" match on prefix / suffix. Any
      other transformation adds no filter.
    :param mpi_df: Spark DataFrame holding the tabularized MPI.
    :return: The filtered Spark DataFrame.
    """
    # These columns are stored as JSON array strings, so the brackets and the
    # quotes around the elements have to be part of what we match on
    json_array_columns = ("first_name", "address")

    result = mpi_df
    result.cache()
    for column_name, props in block_vals.items():
        is_json_array = column_name in json_array_columns

        if "transformation" not in props:
            # Exact match on the whole stored string
            wanted = props["value"]
            if is_json_array:
                wanted = "[\"" + wanted + "\"]"
            result = result.where(result[column_name] == wanted)
        elif props["transformation"] == "first4":
            # Prefix match; for JSON arrays the prefix follows the opening '["'
            prefix = props["value"]
            if is_json_array:
                prefix = "[\"" + prefix
            result = result.where(result[column_name].startswith(prefix))
        elif props["transformation"] == "last4":
            # Suffix match; for JSON arrays the suffix precedes the closing '"]'
            suffix = props["value"]
            if is_json_array:
                suffix = suffix + "\"]"
            result = result.where(result[column_name].endswith(suffix))
    return result


"""
Quick helper to extract the threshold and metric used in fuzzy string comparisons.
We have this to not clutter the main analytic function.
"""
def _get_fuzzy_comp_params(**kwargs):
    similarity_measure = "JaroWinkler"
    if "similarity_measure" in kwargs:
        similarity_measure = kwargs["similarity_measure"]
    threshold = 0.7
    if "threshold" in kwargs:
        threshold = kwargs["threshold"]
    return similarity_measure, threshold


"""
Helper to apply the result of a feature-wise comparison between an incoming record and a 
candidate from the MPI to the accumulated 'match score' of the two. In a 'normal' case
where we're not using log-odds, this is just a count of the number of feature comparisons
that satisfy the fuzzy string threshold. In the log-odds case, this is an accumulation of
the weighted probability score that the two records are a match.
"""
def _apply_score_contribution(feature_score, col, fuzzy_threshold, match_score, match_rule, **kwargs):
    if "log" in match_rule:
        col_odds = kwargs["log_odds"][col]
        match_score += (feature_score * col_odds)
    else:
        if feature_score >= fuzzy_threshold:
            match_score += 1.0
    return match_score


def spark_compare_map_helper(row, flattened_record, funcs, col_to_idx, matching_rule, **kwargs):
    """
    Score one MPI candidate row against the incoming record.

    This is the function that gets map-applied to each row of the RDD built from
    a candidate block. Every column listed in `funcs` whose comparison function
    name contains "fuzzy" is compared with a fuzzy string metric, and each result
    is accumulated into a total score that the matching rule later uses to decide
    whether the two records should be linked. Non-fuzzy functions contribute
    nothing.

    :param row: Candidate row from the MPI extract.
    :param flattened_record: The incoming record as a flat list of values.
    :param funcs: Dict of column name -> comparison function name.
    :param col_to_idx: Dict of column name -> position in `flattened_record`.
    :param matching_rule: Name of the matching rule in use.
    :param kwargs: Optional fuzzy parameters and, for log-odds rules, the weights.
    :return: Tuple of (the row's patient_id, accumulated match score).
    """
    match_score = 0.0
    for col, func in funcs.items():
        feature_in_record = flattened_record[col_to_idx[col]]

        # Only fuzzy comparisons are scored
        if "fuzzy" not in func:
            continue

        similarity_measure, fuzzy_threshold = _get_fuzzy_comp_params(**kwargs)

        if col == "first_name":
            # Given name is a list (possibly including middle name), so join all
            # the values on both sides and fuzzy compare the joined strings
            record_name = " ".join(feature_in_record)
            mpi_name = " ".join(json.loads(row[col]))
            feature_score = compare_strings(mpi_name, record_name, similarity_measure)
        elif col == "address":
            # Address is also a list, but rather than join the lines we keep the
            # best score of any incoming line against any MPI line; this accounts
            # for a patient's change of residence history
            mpi_lines = json.loads(row[col])
            feature_score = 0.0
            for record_line in feature_in_record:
                for mpi_line in mpi_lines:
                    feature_score = max(
                        feature_score,
                        compare_strings(record_line, mpi_line, similarity_measure),
                    )
        else:
            # Regular case: straight string comparison on the fields
            feature_score = compare_strings(row[col], feature_in_record, similarity_measure)

        match_score = _apply_score_contribution(
            feature_score, col, fuzzy_threshold, match_score, matching_rule, **kwargs
        )

    return (row['patient_id'], match_score)
    

def spark_compare(data_block, flattened_record, funcs, col_to_idx, matching_rule, **kwargs):
    """
    Score every candidate in a block against the incoming record and keep the
    ones that satisfy the matching rule.

    For a rule whose name contains "log" the cutoff is
    kwargs["true_match_threshold"]; otherwise every compared feature has to
    agree, i.e. the cutoff is the number of entries in `funcs`.

    :param data_block: Spark DataFrame of blocked MPI candidates.
    :param flattened_record: The incoming record as a flat list of values.
    :param funcs: Dict of column name -> comparison function name.
    :param col_to_idx: Dict of column name -> position in `flattened_record`.
    :param matching_rule: Name of the matching rule in use.
    :param kwargs: Extra parameters of the linkage pass.
    :return: RDD of the patient_ids of the matching candidates.
    """
    if "log" in matching_rule:
        match_cutoff = kwargs["true_match_threshold"]
    else:
        match_cutoff = len(funcs)

    def score_candidate(candidate_row):
        return spark_compare_map_helper(
            candidate_row, flattened_record, funcs, col_to_idx, matching_rule, **kwargs
        )

    scored = data_block.rdd.map(score_candidate)
    passing = scored.filter(lambda id_and_score: id_and_score[1] >= match_cutoff)
    return passing.map(lambda id_and_score: id_and_score[0])


def link_fhir_record_from_dataset(
    record: dict,
    algo_config: List[dict],
    formatted_cols,
    mpi_df
) -> List:
    '''
    Find the candidates in the Spark extract of the MPI that match one incoming
    FHIR record, across every pass of a linkage algorithm.

    :param record: The incoming FHIR patient resource.
    :param algo_config: List of linkage passes; each pass supplies "blocks",
      "funcs", "matching_rule" and optionally "kwargs".
    :param formatted_cols: Column names, in the order of the flattened record.
    :param mpi_df: Spark DataFrame holding the tabularized MPI.
    :return: A list with one RDD per executed pass, each holding
      (incoming record's first flattened value, matched patient_id) tuples.
    '''
    # Flatten incoming resource and remove any lurking None's from positions 2
    # and 5. With the notebook's column list those positions line up with
    # address and first_name, which the comparison code iterates over.
    flattened_record = _flatten_patient_resource(record)
    for list_position in (2, 5):
        if flattened_record[list_position] is None:
            flattened_record[list_position] = [""]

    col_to_idx = {name: position for position, name in enumerate(formatted_cols)}

    # Accumulate all matches across all passes to return
    # We'll do that as a list of RDDs, since we can leverage schema-less
    # unions later on to rapidly put the partitions adjacent to one another
    compiled_rdds = []
    for linkage_pass in algo_config:
        field_blocks = extract_blocking_values_from_record(record, linkage_pass["blocks"])
        # No usable blocking values for this pass: nothing to compare against
        if len(field_blocks) == 0:
            continue

        # Use the extract of the MPI to quickly filter down a block of candidates
        data_block = spark_block(field_blocks, mpi_df)

        # Parallel process the candidates to find any matches
        # Make sure that we cache the results *prior* to explosion, since this will save
        # us very expensive lineage evaluations down the road
        pass_kwargs = linkage_pass.get("kwargs", {})
        matching_records = spark_compare(
            data_block,
            flattened_record,
            linkage_pass["funcs"],
            col_to_idx,
            linkage_pass["matching_rule"],
            **pass_kwargs
        )
        matching_records.cache()
        compiled_rdds.append(
            matching_records.map(lambda matched_id: (flattened_record[0], matched_id))
        )

    return compiled_rdds


def map_patient_ids_to_idxs(pids: List, data: pd.DataFrame):
    '''
    Translate patient_ids into the index labels of the matching rows of `data`,
    so that found matches are expressed in the same scheme as the ground-truth
    labels.

    :param pids: Patient ids to look up.
    :param data: DataFrame with a 'patient_id' column.
    :return: List with the index label of the first matching row for every id
      that is present, in the order of `pids`; ids without a row are skipped.
    '''
    rows_per_pid = (
        data.loc[data['patient_id'] == pid].index.values for pid in pids
    )
    return [row_labels[0] for row_labels in rows_per_pid if len(row_labels) > 0]


def _map_ids(full_ids, label_set):
    """
    Convert a list of MPI patient IDs into the indices of those patients in the
    labeling dataset. Thin alias for map_patient_ids_to_idxs.

    :param full_ids: Patient ids to convert.
    :param label_set: Labeling DataFrame with a 'patient_id' column.
    :return: List of row indices for the ids found in `label_set`.
    """
    label_indices = map_patient_ids_to_idxs(full_ids, label_set)
    return label_indices


def _matches_formatter(row, label_set):
    """
    Reshape one group of a groupBy over (incoming id, matched id) tuples into a
    single (row index, set of linked row indices) entry.

    :param row: Tuple of (incoming record id, iterable of (incoming id, matched
      patient_id) tuples).
    :param label_set: Labeling DataFrame used to translate ids into row indices.
    :return: Tuple of (index of the incoming record, set of indices of the
      records it linked to, excluding itself).
    :raises IndexError: If the incoming record id has no row in `label_set`.
    """
    incoming_id = row[0]
    # De-duplicate the tuples first, then keep only the matched id of each
    matched_ids = [id_pair[1] for id_pair in list(set(row[1]))]

    incoming_idx = map_patient_ids_to_idxs([incoming_id], label_set)[0]
    matched_idxs = _map_ids(matched_ids, label_set)

    # A record trivially matches itself; that is not a link
    links = set([idx for idx in matched_idxs if idx != incoming_idx])
    return (incoming_idx, links)


def link_all_fhir_records_block_dataset(records: List, algo_config: List[dict], label_set, formatted_cols, mpi_df, spark):
    '''
    Find existing patient records in a dataset that map to each incoming record
    in a block of FHIR data. Since the FHIR data itself is pulled from the MPI,
    we can freely use it for querying for blocks without risk of finding
    unrecognized data.

    :param records: FHIR patient resources to link.
    :param algo_config: List of linkage passes to run for every record.
    :param label_set: Labeling DataFrame used to translate ids into row indices.
    :param formatted_cols: Column names, in the order of the flattened record.
    :param mpi_df: Spark DataFrame holding the tabularized MPI.
    :param spark: Active SparkSession.
    :return: Dict of row index of an incoming record -> set of row indices of
      the records it linked to. Empty when no linkage pass produced an RDD.
    :raises IndexError: If a record's id has no row in `label_set`.
    '''
    linked_rdds = []
    start = time.time()
    for record in records:
        # The index itself is not needed here, but the lookup fails fast on the
        # driver (IndexError) for a record that is missing from the label set
        # instead of failing later inside the Spark job
        _ = map_patient_ids_to_idxs([record.get("id")], label_set)[0]
        linked_rdds += link_fhir_record_from_dataset(record, algo_config, formatted_cols, mpi_df)

    # Nothing was blocked at all, so there is nothing to union or collect
    if not linked_rdds:
        print("finished linking ", str(time.time() - start))
        return {}

    # Take giant union over all RDDs to truncate the lineage chain, caching again to avoid
    # re-computing the feature evaluations
    sc = spark.sparkContext
    all_matches = sc.union(linked_rdds)
    all_matches.cache()
    match_groups = all_matches.groupBy(lambda x: x[0])

    # Now we have one giant rdd where each row is a tuple of (patient_id, IterableOverMatches)
    # Re-format that, collect the results, and we're done!
    formatted_matches = match_groups.map(lambda row: _matches_formatter(row, label_set))

    # Everything above only builds the execution plan; the collect is what runs
    # the job, so the elapsed time is reported after it
    found_matches = formatted_matches.collectAsMap()
    print("finished linking ", str(time.time() - start))
    return found_matches


def dedupe_match_double_counts(match_dict):
    '''
    Due to transforming patient_ids back into indices, multiple tuples get
    inserted for each match, i.e. we record the link (i,j) and the link (j,i),
    which would skew our stats. This function eliminates these redundancies by
    dropping, from the set of every key k > 0, all row numbers in [0, k).

    The dictionary is updated in place and also returned.

    :param match_dict: Dict of row number -> iterable of linked row numbers.
    :return: The same dict, with every value replaced by a set that only holds
      entries outside [0, k).
    '''
    for k in match_dict:
        if k > 0:
            # Test each linked row directly rather than building set(range(k))
            # for every key, which costs O(k) per key
            match_dict[k] = {
                linked for linked in match_dict[k] if not (0 <= linked < k)
            }
    return match_dict

In [None]:
# ALGORITHM EVALUATION: LAC EXISTING

# Three-pass linkage configuration. Every pass fuzzy-compares the same four
# fields and differs only in how candidates are blocked:
#   pass 1: first 4 chars of first name + first 4 chars of last name + birthdate
#   pass 2: first 4 chars of first name + first 4 chars of last name + first 4 chars of address
#   pass 3: birthdate only
# No pass defines "kwargs", so the fuzzy comparison falls back to its defaults,
# and since the rule name does not contain "log", a candidate has to agree on
# all four fields to count as a match.
# NOTE(review): "cluster_ratio" is not read by the linkage driver functions
# visible in this notebook.
LAC_ALGO = [
    {
        "funcs": {
            "first_name": "feature_match_fuzzy_string",
            "last_name": "feature_match_fuzzy_string",
            "address": "feature_match_fuzzy_string",
            "mrn": "feature_match_fuzzy_string",
        },
        "blocks": [
            {"value": "first_name", "transformation": "first4"},
            {"value": "last_name", "transformation": "first4"},
            {"value": "birthdate"},
        ],
        "matching_rule": "eval_perfect_match",
        "cluster_ratio": 0.9,
    },
    {
        "funcs": {
            "first_name": "feature_match_fuzzy_string",
            "last_name": "feature_match_fuzzy_string",
            "address": "feature_match_fuzzy_string",
            "mrn": "feature_match_fuzzy_string",
        },
        "blocks": [
            {"value": "first_name", "transformation": "first4"},
            {"value": "last_name", "transformation": "first4"},
            {"value": "address", "transformation": "first4"},
        ],
        "matching_rule": "eval_perfect_match",
        "cluster_ratio": 0.9,
    },
    {
        "funcs": {
            "first_name": "feature_match_fuzzy_string",
            "last_name": "feature_match_fuzzy_string",
            "address": "feature_match_fuzzy_string",
            "mrn": "feature_match_fuzzy_string",
        },
        "blocks": [
            {"value": "birthdate"},
        ],
        "matching_rule": "eval_perfect_match",
        "cluster_ratio": 0.9,
    },
]

# Link every evaluation record against the MPI extract with this configuration,
# then remove, for each key row, links that point at lower-numbered rows
found_matches_lac = link_all_fhir_records_block_dataset(evaluation_set, LAC_ALGO, labeling_set, formatted_cols, mpi_df, spark)
found_matches_lac = dedupe_match_double_counts(found_matches_lac)

In [None]:
# ALGORITHM EVALUATION: DIBBs BASIC
from phdi.linkage import DIBBS_BASIC
# Link every evaluation record against the MPI extract using the DIBBS_BASIC
# configuration, then remove, for each key row, links that point at
# lower-numbered rows
found_matches_dibbs_basic = link_all_fhir_records_block_dataset(evaluation_set, DIBBS_BASIC, labeling_set, formatted_cols, mpi_df, spark)
found_matches_dibbs_basic = dedupe_match_double_counts(found_matches_dibbs_basic)

In [None]:
# ALGORITHM EVALUATION: DIBBs ENHANCED
from phdi.linkage import DIBBS_ENHANCED
# Link every evaluation record against the MPI extract using the DIBBS_ENHANCED
# configuration, then remove, for each key row, links that point at
# lower-numbered rows
found_matches_dibbs_enhanced = link_all_fhir_records_block_dataset(evaluation_set, DIBBS_ENHANCED, labeling_set, formatted_cols, mpi_df, spark)
found_matches_dibbs_enhanced = dedupe_match_double_counts(found_matches_dibbs_enhanced)

In [None]:
# MOUNT THE FILE SYSTEM SO WE CAN WRITE THE OUTPUTS TO FILES

# Set paths
STORAGE_ACCOUNT = "$STORAGE_ACCOUNT"
LINKAGE_OUTPUTS_FILESYSTEM = f"abfss://linkage-notebook-outputs@{STORAGE_ACCOUNT}.dfs.core.windows.net/"
BLOB_STORAGE_LINKED_SERVICE = "$BLOB_STORAGE_LINKED_SERVICE"

from notebookutils import mssparkutils

# Set up for writing to blob storage: fetch the credential for the linked
# service and register it with Spark for the output container
linkage_bucket_name = "linkage-notebook-outputs"
blob_sas_token = mssparkutils.credentials.getConnectionStringOrCreds(BLOB_STORAGE_LINKED_SERVICE)
wasb_path = 'wasbs://%s@%s.blob.core.windows.net/' % (linkage_bucket_name, STORAGE_ACCOUNT)
spark.conf.set('fs.azure.sas.%s.%s.blob.core.windows.net' % (linkage_bucket_name, STORAGE_ACCOUNT), blob_sas_token)
# Try mounting the remote storage directory at the mount point.
# NOTE(review): the LinkedService value is an f-string that begins with a
# literal dollar sign; presumably that is resolved by deployment-time
# templating -- confirm.
try:
    mssparkutils.fs.mount(
        wasb_path,
        "/",
        {"LinkedService": f"${BLOB_STORAGE_LINKED_SERVICE}"}
    )
except Exception as mount_error:
    # A failed mount is tolerated (re-running the cell on an existing mount is
    # the expected case), but the actual reason is reported so that a real
    # failure, e.g. a credential problem, is not mistaken for "already mounted".
    # Catching Exception rather than everything lets an interrupt stop the cell.
    print(f"Mount skipped (storage may already be mounted): {mount_error}")

In [None]:
# RECOMPUTE AND EXPORT LOG-ODDS

import json
from phdi.linkage import calculate_m_probs, calculate_u_probs, calculate_log_odds

# Derive m- and u-probabilities from the labeling set and the third-party
# labels, then combine them into log-odds.
m_probs = calculate_m_probs(labeling_set, third_party_labels)
u_probs = calculate_u_probs(
    labeling_set,
    third_party_labels,
    n_samples=25000,
)
log_odds = calculate_log_odds(m_probs, u_probs)

# The "patient_id" entry is dropped before export (raises KeyError if absent).
del log_odds["patient_id"]

# Serialize the remaining log-odds as JSON into the outputs file system.
mssparkutils.fs.put(
    LINKAGE_OUTPUTS_FILESYSTEM + "updated_log_odds.json",
    json.dumps(log_odds),
    True,
)

In [None]:
# RUN THE NUMBERS AND GET THE STATS FUNCTIONS

'''
To ensure accurate statistics, the matches and the true matches dictionaries
in the statistical evaluation function should have the following form:

{
    row_num_of_record_in_data: set(row_nums_of_linked_records)
}

Each row in the data should be represented as a key in both dictionaries.
The value for each of these keys should be a set that contains all other
row numbers for records in the data set that link to the key record.
'''
def score_linkage_vs_truth(found_matches, true_matches, num_eval_records, num_label_records=None):
    """
    Score the links an algorithm found against ground-truth links.

    Both dictionaries must map a record's row number to the set of row numbers
    of the records linked to it, i.e. ``{row_num: set(linked_row_nums)}``.

    :param found_matches: Links produced by the algorithm under evaluation.
    :param true_matches: Ground-truth links.
    :param num_eval_records: Number of records in the evaluation set (m).
    :param num_label_records: Number of records in the labeled set (n).
      Optional; when omitted it defaults to `num_eval_records`, in which case
      the number of candidate pairs reduces to m * (m - 1) / 2.
    :return: A dictionary with the raw counts `tp`, `fp`, `fn` (floats) and
      the statistics `sens`, `spec`, `ppv`, `npv`, `f1`. Each statistic is
      rounded to 3 decimal places, or is the string "N/A" when its
      denominator is not positive.
    """

    def _ratio_or_na(numerator, denominator):
        # Guard against division by 0 if we happened to pick a bad subsample.
        if denominator > 0:
            return round(numerator / denominator, 3)
        return "N/A"

    if num_label_records is None:
        num_label_records = num_eval_records

    # Formula is: for m=num_recs in eval_set, n=num_recs in label_set,
    # m * (n-1) - [1/2 (m * (m-1))]--each record could match with every other
    # record in the MPI besides itself; this double counts matches in the
    # eval_set though, since they're also in the MPI, so subtract those away
    # to get accurate results
    total_possible_matches = (num_eval_records * (num_label_records - 1.0)) - (
        (num_eval_records * (num_eval_records - 1.0)) / 2.0
    )
    true_positives = 0.0
    false_positives = 0.0
    false_negatives = 0.0

    for root_record, true_links in true_matches.items():
        if root_record not in found_matches:
            # Nothing was found for this record, so every true link is missed.
            false_negatives += len(true_links)
            continue
        found_links = found_matches[root_record]
        true_positives += len(true_links.intersection(found_links))
        false_positives += len(found_links.difference(true_links))
        false_negatives += len(true_links.difference(found_links))

    # Records the algorithm linked that have no ground-truth entry at all:
    # every link found for them is a false positive.
    for record in set(found_matches).difference(true_matches):
        false_positives += len(found_matches[record])

    true_negatives = (
        total_possible_matches - true_positives - false_positives - false_negatives
    )

    return {
        "tp": true_positives,
        "fp": false_positives,
        "fn": false_negatives,
        "sens": _ratio_or_na(true_positives, true_positives + false_negatives),
        "spec": _ratio_or_na(true_negatives, true_negatives + false_positives),
        "ppv": _ratio_or_na(true_positives, true_positives + false_positives),
        "npv": _ratio_or_na(true_negatives, true_negatives + false_negatives),
        "f1": _ratio_or_na(
            2 * true_positives,
            2 * true_positives + false_negatives + false_positives,
        ),
    }

# Number of evaluation records that were scored.
# NOTE(review): RECORDS_TO_SAMPLE and va_labels are not defined in the parts
# of the notebook visible here -- confirm an earlier cell provides them.
if RECORDS_TO_SAMPLE is not None:
    n_records = RECORDS_TO_SAMPLE
else:
    n_records = len(evaluation_set)

# score_linkage_vs_truth also needs the size of the labeled set (n in its
# pair-count formula); the original three-argument calls raised a TypeError.
# NOTE(review): assumes len(labeling_set) is the labeled record count -- confirm.
n_label_records = len(labeling_set)

# (report label, key in the stats dictionary), in display order.
report_fields = [
    ("True Positives", "tp"),
    ("False Positives", "fp"),
    ("False Negatives", "fn"),
    ("Sensitivity", "sens"),
    ("Specificity", "spec"),
    ("PPV", "ppv"),
    ("NPV", "npv"),
    ("F1", "f1"),
]

# Collect the report pieces and join once at the end instead of repeatedly
# concatenating onto a growing string.
report_parts = []

for lbl_type, labels in [("va", va_labels), ("emc", third_party_labels)]:
    stats_dict_lac = score_linkage_vs_truth(
        found_matches_lac, labels, n_records, n_label_records
    )
    stats_dict_dibbs_b = score_linkage_vs_truth(
        found_matches_dibbs_basic, labels, n_records, n_label_records
    )
    stats_dict_dibbs_e = score_linkage_vs_truth(
        found_matches_dibbs_enhanced, labels, n_records, n_label_records
    )

    report_parts.append(f"DISPLAYING EVALUATION ON {lbl_type.upper()} LABELS:\n")
    report_parts.append("\n")

    algorithm_scores = [
        ("LAC Existing Algorithm:\n", stats_dict_lac),
        ("DIBBs Basic Algorithm:\n", stats_dict_dibbs_b),
        ("DIBBs Log-Odds Algorithm:\n", stats_dict_dibbs_e),
    ]
    for algo_title, scores in algorithm_scores:
        report_parts.append(algo_title)
        for field_label, stat_key in report_fields:
            report_parts.append(f"{field_label}: {scores[stat_key]}\n")
        report_parts.append("\n")

    report_parts.append("\n")

display_str = "".join(report_parts)

print(display_str)

# Persist the same report text to the outputs file system.
mssparkutils.fs.put(LINKAGE_OUTPUTS_FILESYSTEM + "results.txt", display_str, True)