# GenAI-based rule recommendations for AWS Entity Resolution
This notebook walks you through building a GenAI-based rule recommendation for [AWS Entity Resolution](https://aws.amazon.com/entity-resolution/) using a Large Language Model (LLM) hosted on [Amazon Bedrock](https://aws.amazon.com/bedrock/). We use an existing AWS Entity Resolution [Schema Mapping](https://docs.aws.amazon.com/entityresolution/latest/userguide/schema-mapping.html) as the source, generate data quality metrics for the source data, and provide those metrics as input to the LLM prompt. The LLM then recommends a matching rule and explains its reasoning.

<div class="alert alert-block alert-info">
<b>Pre-requisites for this notebook:</b>
    <ul>
        <li>IAM Role:
            <ul>
                <li>Access to AWS Entity Resolution</li>
                <li>Access to invoke the Foundation Models on Amazon Bedrock</li>
                <li>Read access to Amazon S3 source data bucket</li>
                <li><i>Optional:</i>
                    <ul><li>If you want to use <a href="https://docs.aws.amazon.com/codewhisperer/latest/userguide/what-is-cwspr.html">Amazon CodeWhisperer</a>, a machine learning-powered code generator that provides code recommendations, follow the instructions at this <a href='https://docs.aws.amazon.com/codewhisperer/latest/userguide/glue-setup.html'>link</a> to set it up.</li></ul>
            </ul>
        </li>
        <li>This notebook is recommended to be run with <i>Glue 5.0</i> and boto3 version <i>1.34.131</i> or above</li>
        <li>At the time of writing this notebook, Amazon Bedrock was only available in <a href="https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html#bedrock-regions">these supported AWS Regions</a>. If you are running this notebook from any other AWS Region, you must set the Amazon Bedrock client's Region and/or endpoint URL to one of those supported AWS Regions. Follow the guidance in the <i>AWS Region and boto3 config</i> section of this notebook.</li>
    </ul>
</div>

## Setting the PySpark session

In [None]:
%additional_python_modules networkx
%idle_timeout 90
%glue_version 5.0
%worker_type G.1X
%number_of_workers 5

In [None]:
spark

**Table of Contents:**

1. [Pre-requisites](#1)

    1.a. [Set AWS Region and boto3 config](#1.a)
  
    1.b. [Organize imports](#1.b)
       
    1.c. [Enable model access in Amazon Bedrock](#1.c)
    
    1.d. [Check and configure security permissions](#1.d)
    
 2. [Data Preparation](#2)
 
    2.a. [Fetch the AWS Entity Resolution schema mapping for your input dataset](#2.a)        
        
    2.b. [Start the interactive PySpark session](#2.b)
    
    2.c. [Read and prepare input data from the AWS Glue table defined in the AWS Glue Data Catalog](#2.c)
    
 3. [Generate Data Quality metrics](#3)
    
    3.a. [Calculate the percentage distribution of empty/missing values for every column](#3.a)
    
    3.b. [Calculate the frequency distribution of Top 3 values for every column](#3.b)
    
 4. [Invoke the LLM Model](#4)
    
    4.a. [Provide the model id of the LLM model to use with Amazon Bedrock](#4.a)    

    4.b. [Prepare the LLM prompt for getting the rule recommendation](#4.b)
    
    4.c. [Invoke the LLM model for getting the rule recommendation](#4.c)
    
    4.d. [Extract and parse the rule generated by LLM](#4.d)
    
    4.e. [Validate the rule on input dataset](#4.e)

## 1. Pre-requisites <a id='1'></a>

#### 1.a AWS Region and boto3 config <a id="1.a"></a>
Get the current AWS Region (where this notebook is running). This will be used to initiate some of the clients to AWS services using the boto3 APIs.

<div class="alert alert-block alert-info">
    <b>Note:</b> All the AWS services used by this notebook except Amazon Bedrock use the current AWS Region. For Amazon Bedrock, follow the guidance in the note below.
</div>

<div class="alert alert-block alert-warning">  
<b>Note:</b> At the time of writing this notebook, Amazon Bedrock was only available in <a href="https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html#bedrock-regions">these supported AWS Regions</a>. If you are running this notebook from any other AWS Region, you must set the Amazon Bedrock client's Region and/or endpoint URL to one of those supported AWS Regions. To do this, the notebook uses the value of the environment variable <mark>AMAZON_BEDROCK_REGION</mark>. If this variable is not set, the notebook defaults to <mark>us-east-1 (N. Virginia)</mark> for Amazon Bedrock.
</div>


In [None]:
import boto3
from botocore.config import Config
import os

print("boto3 version: {} and expected is 1.34.131 or above".format(boto3.__version__))

my_session = boto3.session.Session()

my_region = my_session.region_name
print("Current AWS Region: {}".format(my_region))

# Explicitly set the AWS Region for the Amazon Bedrock clients.
# Fall back to the default Region when AMAZON_BEDROCK_REGION is unset or empty.
AMAZON_BEDROCK_DEFAULT_REGION = "us-east-1"
br_region = os.environ.get('AMAZON_BEDROCK_REGION') or AMAZON_BEDROCK_DEFAULT_REGION
print("AWS Region for Amazon Bedrock: {}".format(br_region))

Set the timeout and retry configurations that will be applied to all the boto3 clients used in this notebook.

In [None]:
# Increase the standard time out limits in the boto3 client from 1 minute to 3 minutes
# and set the retry limits
my_boto3_config = Config(
    connect_timeout = (60 * 3),
    read_timeout = (60 * 3),
    retries = {
        'max_attempts': 600,
        'mode': 'adaptive'
    }
)

#### 1.b Organize imports <a id='1.b'></a>

In [None]:
import sys
from awsglue.transforms import *
from awsglue.utils import getResolvedOptions
from pyspark.context import SparkContext
from awsglue.context import GlueContext
from awsglue.job import Job

from pyspark.sql import DataFrame
# Note: importing pyspark's round shadows Python's built-in round in this session
from pyspark.sql.functions import concat, concat_ws, col, lit, trim, lower, isnan, when, count, round


#### 1.c Enable model access in Amazon Bedrock <a id='1.c'></a>

<div class="alert alert-block alert-info">
    <b>Note:</b> Before invoking any model in Amazon Bedrock, enable access to that model by following the instructions <a href="https://docs.aws.amazon.com/bedrock/latest/userguide/model-access.html">here</a>. In addition, for Anthropic models, you need to submit the use case details. Otherwise, you will get an authorization error.
    <br /><br />
    By default, this notebook expects access to the <b>Claude 3.5 Sonnet v2</b> model. If you do not have access to it yet, request access using the link printed by the next cell.
</div>

Run the following cell to print the Amazon Bedrock model access page URL for the AWS Region that was selected earlier.

In [None]:
# Print the Amazon Bedrock model access page URL
print("Amazon Bedrock model access page - https://{}.console.aws.amazon.com/bedrock/home?region={}#/modelaccess"
             .format(br_region, br_region))

## 2. Data Preparation <a id="2" />

#### 2.a Fetch the AWS Entity Resolution schema mapping for your input dataset <a id='2.a'></a>

<div class="alert alert-info alert-block"><b>Note: </b>Update the AWS Glue Database, AWS Glue Table, and the AWS Entity Resolution schema mapping name based on your environment</div>

In [None]:
schemaName = 'sourceSchemaMapping'
awsGlueDatabase = 'entityresolution'
awsGlueTable = 'sourceGlueTable'

In [None]:
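# Use the current Region and the timeout/retry settings defined in section 1.a
aerClient = boto3.client('entityresolution', region_name=my_region, config=my_boto3_config)

# Fetch the schema mapping definition specified above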
aerSchemaResponse = aerClient.get_schema_mapping(
    schemaName=schemaName
)
print(aerSchemaResponse)

#### Parse the schema mapping definition and extract the relevant fields, their types and the group definitions

In [None]:
schemaFieldList = aerSchemaResponse['mappedInputFields']

# a static ordering of how the Name group is created within AWS Entity Resolution
nameOrder = ['NAME_FIRST','NAME_MIDDLE','NAME_LAST']

# a static ordering of how the Address group is created within AWS Entity Resolution
addressOrder = ['ADDRESS_STREET1','ADDRESS_STREET2','ADDRESS_STREET3','ADDRESS_CITY','ADDRESS_STATE','ADDRESS_POSTALCODE','ADDRESS_COUNTRY']

addressDict = []
nameDict = []
fields = []
uniqueIdColumn = ''

addressGroupName = ''
nameGroupName = ''

for field in schemaFieldList:
    
    if(field["type"] == 'UNIQUE_ID'):
        uniqueIdColumn = field['fieldName']
        fields.append(field['fieldName'])
    
    if('matchKey' in field):
        fields.append(field['fieldName'])

        if("groupName" in field):
            if(field['matchKey'] == 'Address'):
                addressGroupName = field['groupName']
                addressDict.append(field)
                fields.remove(field['fieldName'])
            if(field['matchKey'] == 'Name'):
                nameGroupName = field['groupName']        
                nameDict.append(field)
                fields.remove(field['fieldName'])

def orderedFieldListForGroup(typeList,typeDict):
    finalList = []
    for field in typeList:
        for field2 in typeDict:
            if(field2['type'] == field):
                finalList.append(field2['fieldName'])
    return finalList

addressList = orderedFieldListForGroup(addressOrder, addressDict)
nameList = orderedFieldListForGroup(nameOrder, nameDict)
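To make the parsing logic concrete, here is a minimal pure-Python sketch that applies the same rules to a hand-written `mappedInputFields` list. The field names (`id`, `email`, `first_name`, ...), the group names (`fullName`, `fullAddress`), and the helpers `parse_schema_mapping` and `ordered_members` are hypothetical; like the cell above, the sketch assumes the Name and Address groups use the match keys `Name` and `Address`.

```python
# Hypothetical sketch: re-implements the parsing cell above on a made-up schema mapping.
NAME_ORDER = ['NAME_FIRST', 'NAME_MIDDLE', 'NAME_LAST']
ADDRESS_ORDER = ['ADDRESS_STREET1', 'ADDRESS_STREET2', 'ADDRESS_STREET3', 'ADDRESS_CITY',
                 'ADDRESS_STATE', 'ADDRESS_POSTALCODE', 'ADDRESS_COUNTRY']

def ordered_members(order, members):
    """Return member field names sorted by the static type order of the group."""
    return [m['fieldName'] for t in order for m in members if m['type'] == t]

def parse_schema_mapping(mapped_fields):
    """Split mapped fields into the unique ID, standalone match fields, and Name/Address groups."""
    unique_id, standalone = '', []
    group_names = {'Name': '', 'Address': ''}
    group_members = {'Name': [], 'Address': []}
    for field in mapped_fields:
        if field['type'] == 'UNIQUE_ID':
            unique_id = field['fieldName']
            standalone.append(field['fieldName'])
        elif 'matchKey' in field:
            key = field['matchKey']
            if 'groupName' in field and key in group_members:
                # Grouped Name/Address fields are concatenated later instead of matched alone
                group_names[key] = field['groupName']
                group_members[key].append(field)
            else:
                standalone.append(field['fieldName'])
    return {
        'uniqueId': unique_id,
        'fields': standalone,
        'name': (group_names['Name'], ordered_members(NAME_ORDER, group_members['Name'])),
        'address': (group_names['Address'], ordered_members(ADDRESS_ORDER, group_members['Address'])),
    }

# Hypothetical schema mapping fields (not taken from a real AWS Entity Resolution account)
sample_fields = [
    {'fieldName': 'id', 'type': 'UNIQUE_ID'},
    {'fieldName': 'email', 'type': 'EMAIL_ADDRESS', 'matchKey': 'Email'},
    {'fieldName': 'last_name', 'type': 'NAME_LAST', 'matchKey': 'Name', 'groupName': 'fullName'},
    {'fieldName': 'first_name', 'type': 'NAME_FIRST', 'matchKey': 'Name', 'groupName': 'fullName'},
    {'fieldName': 'city', 'type': 'ADDRESS_CITY', 'matchKey': 'Address', 'groupName': 'fullAddress'},
    {'fieldName': 'street', 'type': 'ADDRESS_STREET1', 'matchKey': 'Address', 'groupName': 'fullAddress'},
]
print(parse_schema_mapping(sample_fields))
```

Even though `last_name` appears before `first_name` in the input, the Name group comes out as `first_name`, `last_name`, following the static order.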


#### 2.b Start the PySpark interactive session <a id='2.b'></a>

In [None]:
sc = SparkContext.getOrCreate()
glueContext = GlueContext(sc)
spark = glueContext.spark_session
job = Job(glueContext)

#### 2.c Read and prepare the input data from the AWS Glue table defined in the AWS Glue Data Catalog <a id="2.c" />

The next few cells read the input data into an AWS Glue DynamicFrame and convert it into a Spark DataFrame. They then apply transformations to create a final DataFrame whose shape matches the schema defined in the AWS Entity Resolution schema mapping.


In [None]:
dyf = glueContext.create_dynamic_frame.from_catalog(database=awsGlueDatabase, table_name=awsGlueTable)
dyf.printSchema()

#### Convert the DynamicFrame to a Spark DataFrame and display a sample of the data


In [None]:
df = dyf.toDF()
df.show()

#### Only select the columns that are part of the schema mapping definition and create the concatenated group based on the schema definition

In [None]:
from pyspark.sql.functions import concat, concat_ws, col, lit, trim, lower, isnan, when, count, round

filteredDF = df.select(*fields,*addressList, *nameList)

if(addressGroupName!=''):
    filteredDF = filteredDF.withColumn(addressGroupName,concat_ws(" ",*addressList)).drop(*addressList)
    
if(nameGroupName!=''):    
    filteredDF = filteredDF.withColumn(nameGroupName,concat_ws(" ",*nameList)).drop(*nameList)


def trim_all_string_columns(df: DataFrame) -> DataFrame:
    return df\
        .select(
            *[trim(col(c[0])).alias(c[0]) if c[1] == 'string' else col(c[0]) for c in df.dtypes]
        )

def lowercase_all_string_columns(df: DataFrame) -> DataFrame:
    return df\
        .select(
            *[lower(col(c[0])).alias(c[0]) if c[1] == 'string' else col(c[0]) for c in df.dtypes]
        )

    

# Mimic AWS Entity Resolution normalization by converting every string column to lowercase.
# To also remove leading and trailing whitespace, uncomment the trim step below.
filteredDF = filteredDF.transform(lowercase_all_string_columns)
#filteredDF = filteredDF.transform(trim_all_string_columns)


filteredDF.printSchema()
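Spark's `concat_ws` skips `NULL` inputs but keeps empty strings, so the group columns built above depend on how missing values are stored: a missing middle name stored as `NULL` yields `'jon doe'`, one stored as `''` yields `'jon  doe'` with a doubled space, and a group whose members are all `NULL` becomes an empty string. The pure-Python functions below are hypothetical stand-ins that reproduce this behavior on plain values; `blank_to_none` illustrates one possible cleanup, not a step the notebook performs.

```python
# Hypothetical pure-Python stand-ins used only to illustrate concat_ws semantics.
def concat_ws_like(sep, *values):
    """Join values with sep, skipping None (Spark NULL) but keeping empty strings."""
    return sep.join(v for v in values if v is not None)

def blank_to_none(value):
    """Treat an empty or whitespace-only string as missing (None)."""
    return None if value is None or value.strip() == '' else value

print(repr(concat_ws_like(' ', 'jon', None, 'doe')))                       # 'jon doe'
print(repr(concat_ws_like(' ', 'jon', '', 'doe')))                         # 'jon  doe'
print(repr(concat_ws_like(' ', *map(blank_to_none, ['jon', '', 'doe']))))  # 'jon doe'
print(repr(concat_ws_like(' ', None, None)))                               # ''
```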


#### View sample records

In [None]:
filteredDF.show()

## 3. Generate Data Quality metrics <a id="3" />


#### Calculate the total number of records in the input dataset

In [None]:
recordCount = filteredDF.count()
print(recordCount)

#### 3.a Calculate the percentage distribution of empty/missing values for every column <a id="3.a" />

In [None]:
from pyspark.sql.functions import col, lower, isnan, when, count, round

# A value counts as missing when it is NULL, an empty string, NaN, or equal to the
# placeholder 'none'/'null'. filteredDF is already lowercased, so the placeholders are
# compared in lowercase (checking for 'None'/'NULL' would never succeed here).
emptyValuesDF = filteredDF.select([
    round(count(when(col(c).isNull() |
                     (col(c) == '') |
                     lower(col(c)).isin('none', 'null') |
                     isnan(c), c)) / recordCount * 100, 2).alias(c)
    for c in filteredDF.columns
])

# Collect the single row of percentages into {column: percentage}
emptyValuesDict = emptyValuesDF.first().asDict()

emptyValuesDF.show()
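For checking what the Spark expression above counts, a minimal pure-Python version of the same metric on a toy dataset may help. The function names and toy values are hypothetical; columns are represented as a plain dict of lists rather than a DataFrame.

```python
import builtins
import math

# Hypothetical pure-Python version of the missing-value metric above, for toy data.
PLACEHOLDERS = {'', 'none', 'null'}

def is_missing(value):
    """True for None, NaN, an empty string, or a 'none'/'null' placeholder (any case)."""
    if value is None:
        return True
    if isinstance(value, float) and math.isnan(value):
        return True
    return isinstance(value, str) and value.lower() in PLACEHOLDERS

def empty_value_percentages(columns):
    """Map each column name to the percentage (0-100, 2 decimals) of missing values."""
    result = {}
    for name, values in columns.items():
        missing = sum(1 for v in values if is_missing(v))
        # builtins.round: the notebook imports pyspark's round, which shadows the built-in
        result[name] = builtins.round(missing / len(values) * 100, 2)
    return result

toy_columns = {
    'email': ['a@example.com', None, '', 'NULL'],
    'phone': ['8977771414', '8977771415', float('nan'), '8977771416'],
}
print(empty_value_percentages(toy_columns))  # {'email': 75.0, 'phone': 25.0}
```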

#### Draw a graph of the percentage distribution

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

#sns.set_theme(style="whitegrid")

# Set the font size before drawing so that it applies to this figure
plt.rcParams.update({'font.size': 12})

# emptyValuesDF holds a single row with one percentage per input column
emptyPercentPD = emptyValuesDF.toPandas()

fig, ax = plt.subplots(figsize=(8, 8))
sns.barplot(data=emptyPercentPD, ax=ax, palette="ch:start=.2,rot=-.3")

plt.xticks(rotation=90)
ax.set_ylabel('percentage')
ax.set_title('Empty Value percentage')
fig.subplots_adjust(bottom=0.2)

%matplot plt

#### 3.b Calculate the frequency distribution of Top 3 values for every column <a id="3.b" />
Calculate how often the most frequent values occur in every column; the number of values reported is set by the variable `noOfTop` (3 by default). The output gives each value's occurrence as a percentage of all records in the dataset.

In [None]:
from pyspark.sql.functions import col, count

# Number of most frequent values to report for every column
noOfTop = 3

def describe_categorical_1d(df, column, recordCount):
    """Return a JSON object (as a string) mapping each of the noOfTop most frequent
    non-NULL values of `column` to its share of all records, e.g. {"ny":"42.86%"}."""
    value_counts = (df.select(column).na.drop()
                    .groupBy(column)
                    .agg(count(col(column)).alias("frequency"))
                    .orderBy("frequency", ascending=False))

    # Only the top noOfTop rows are brought back to the driver
    top_freq = value_counts.limit(noOfTop).toPandas()

    # Express each count as a percentage of all records, rounded to 2 decimals
    top_freq["frequency"] = ((top_freq["frequency"] / recordCount * 100)
                             .round(2).astype(str) + '%')

    # A column that is entirely NULL has no rows here and serializes to an empty object
    return top_freq.set_index(column)["frequency"].to_json(orient="index")
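The same top-N computation can be sketched in pure Python with `collections.Counter`, which is handy for reasoning about the JSON that ends up in the prompt. `top_values_json` and the sample values are hypothetical. Ties may be ordered differently from Spark, whose `orderBy` does not guarantee an order among equal counts.

```python
import builtins
import json
from collections import Counter

def top_values_json(values, record_count, n=3):
    """Hypothetical pure-Python version of describe_categorical_1d for a list of values."""
    counts = Counter(v for v in values if v is not None)  # NULLs are excluded, like na.drop()
    top = counts.most_common(n)                            # ties keep first-seen order
    # builtins.round: the notebook imports pyspark's round, which shadows the built-in
    shares = {str(v): f"{builtins.round(c / record_count * 100, 2)}%" for v, c in top}
    return json.dumps(shares, separators=(',', ':'))

cities = ['ny', 'ny', 'sf', None, 'la', 'ny', 'sf']
print(top_values_json(cities, record_count=len(cities)))
# {"ny":"42.86%","sf":"28.57%","la":"14.29%"}
```

Note that the percentages are relative to all records, including NULLs, so they need not sum to 100%.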

In [None]:
# Build {column: JSON of the top noOfTop values and their share of all records}
frequencyDict = dict()
for colname in filteredDF.columns:
    frequencyDict[colname] = describe_categorical_1d(filteredDF, colname, recordCount)

In [None]:
print(frequencyDict)

## 4. Invoke the LLM Model <a id="4" />

#### 4.a Provide the model id of the LLM model to use with Amazon Bedrock <a id="4.a" />
<div class="alert alert-info alert-block"><b>Note: </b>Refer to this <a href='https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html'>link</a> to get the model ID of the LLM. In certain cases, the model ID is an ARN; for Anthropic's Claude 3.5 Sonnet v2, the cell below uses the ARN of an inference profile.<br /><br />Update the <i>region</i> and <i>AWSAccountNumber</i> in the cell below. The Region should match the Amazon Bedrock Region selected in section 1.a.</div>

In [None]:
model_id = 'arn:aws:bedrock:[region]:[AWSAccountNumber]:inference-profile/us.anthropic.claude-3-5-sonnet-20241022-v2:0'
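Rather than editing the placeholders by hand, the ARN could be assembled from the Bedrock Region and your account number. The helper below is a hypothetical sketch that follows the ARN format used in the cell above; the commented lines show how the account number could be looked up with STS `get_caller_identity`. An inference profile ID such as `us.anthropic.claude-3-5-sonnet-20241022-v2:0` can also be passed directly as `modelId`.

```python
import re

def inference_profile_arn(region: str, account_id: str, profile_id: str) -> str:
    """Hypothetical helper: build an inference-profile ARN in the format used above."""
    if not re.fullmatch(r"\d{12}", account_id):
        raise ValueError("account_id must be a 12-digit AWS account number")
    return f"arn:aws:bedrock:{region}:{account_id}:inference-profile/{profile_id}"

# In the notebook, for example:
# account_id = boto3.client("sts").get_caller_identity()["Account"]
# model_id = inference_profile_arn(br_region, account_id,
#                                  "us.anthropic.claude-3-5-sonnet-20241022-v2:0")
print(inference_profile_arn("us-east-1", "111122223333",
                            "us.anthropic.claude-3-5-sonnet-20241022-v2:0"))
```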

#### 4.b Prepare the prompt for the LLM model to generate the rule recommendation <a id="4.b" />

The prompt uses few-shot prompting: worked examples make the LLM aware of the constraints of AWS Entity Resolution rule-based matching, so that the recommended rule is specific to the capabilities of the service rather than generic.

In [None]:
prompt = f"""Entity resolution is the process of determining when multiple records belong to the same person, despite differences in how they are described or inconsistencies in how data was entered.

As an example, these 2 records belong to the same person:
Record 1:
Name: Jon Doe II
Address: 123 East Main St.
Phone: +1(897)777-1414
Email: jondoe@gmail.com

Record 2:
Name: Jonathan Doe Jr.
Address: 123 East Main St.
Phone: 897.777.1414
Email:

The conclusion is that the above two records belong to the same person because they both share similar names, Jonathan and Jon being common alias/nickname, and share the same address and phone number even though the second record is missing the email identifier. A rule that can be applied deterministically to confirm this would be: (name) AND (Address Or Email Or Phone)

Entity resolution can also determine when multiple records do not belong to the same person. For example, if the name in one record has the suffix Sr. while the other has the suffix Jr., the records can be determined to belong to two different individuals despite sharing the same address and phone, primarily because this would indicate a father and a son living in the same house.

While determining the rule, it is important to understand the quality of the input records. If any column has a significant percentage of missing values, the rule should either skip that column or use it in an OR condition with another column. The rule should try to consider all the input columns.

Using this information, generate a rule that can be applied to determine if the records belong to the same person. The rule should be deterministic, with no consideration for fuzzy or geo-mapping. Also provide the reasoning for the rule; no code snippets are required. If you ignore any of the input columns, explain why in your reasoning.

Data Quality for the input dataset is defined for each of the input columns. The JSON for each column lists the top {noOfTop} occurrences and their frequency from 0 to 100%. For any column, if the empty value percentage is very high, you may ignore that column in the rule. If, for any column, a particular value occurs with high frequency, it may result in poor accuracy."""

for column in emptyValuesDict:
    prompt += f"""\n
Column: {column}
Empty Value: {emptyValuesDict[column]}%
Frequency: {frequencyDict[column]}
"""

# Append the output-format instructions once, after all the column statistics
prompt += """
The generated rule should only choose the operator from [ExactMatch, AND, OR], where ExactMatch(column) means the values in that column need to match exactly. Output your rule after "Recommended Rule:"
If possible, make the rule simple and easy to understand.
"""

print(prompt)

#### 4.c Invoke the LLM model with the given prompt <a id="4.c" />

In [None]:
import json
from botocore.exceptions import ClientError

# Create the Bedrock runtime client in the Region selected in section 1.a,
# with the timeout and retry settings defined there
bedrockClient = boto3.client("bedrock-runtime", region_name=br_region, config=my_boto3_config)

accept = 'application/json'
contentType = 'application/json'

# Format the request payload using the model's native (Anthropic Messages) structure.
native_request = {
    "anthropic_version": "bedrock-2023-05-31",
    "max_tokens": 2000,
    # Lower values (e.g. 0) make the recommendation less random between runs
    "temperature": 1,
    "messages": [
        {
            "role": "user",
            "content": [{"type": "text", "text": prompt}],
        }
    ]
}

# Convert the native request to JSON.
request = json.dumps(native_request)

try:
    # Invoke the model with the request.
    response = bedrockClient.invoke_model(modelId=model_id, body=request,
                                          contentType=contentType, accept=accept)
except ClientError as e:
    print(f"ERROR: Can't invoke '{model_id}'. Reason: {e}")
    # Re-raise so the cell fails visibly instead of calling exit()
    raise

# Decode the response body.
model_response = json.loads(response["body"].read())

# Extract and print the response text.
response_text = model_response["content"][0]["text"]
print(response_text)
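The cell above reads only the first content block of the response. A slightly more defensive sketch joins every text block and flags a truncated answer, which the Messages API signals with `stop_reason` set to `"max_tokens"`. `extract_text` and the sample body are hypothetical; the sample is hand-written, not a real model response.

```python
import json

def extract_text(response_body):
    """Join all text blocks of a Messages API response body and report truncation."""
    payload = json.loads(response_body)
    text = "".join(block["text"] for block in payload.get("content", [])
                   if block.get("type") == "text")
    truncated = payload.get("stop_reason") == "max_tokens"
    return text, truncated

sample_body = json.dumps({
    "role": "assistant",
    "content": [{"type": "text", "text": "Recommended Rule: ExactMatch(email)"}],
    "stop_reason": "end_turn",
}).encode("utf-8")
print(extract_text(sample_body))
```

In the notebook, `extract_text(response["body"].read())` would replace the two decoding lines, since the streaming body can be read only once.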

#### 4.d Extract and parse the rule suggested by the LLM <a id="4.d" />

In [None]:
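generated_rule = ""
try:
    # Take the first paragraph after "Recommended Rule:", join wrapped lines with a space
    # (joining without one would glue "AND" to the next ExactMatch), and strip markdown
    # emphasis or code ticks around the rule.
    generated_rule = (response_text.split("Recommended Rule:")[1]
                      .lstrip('*` \n')
                      .split('\n\n')[0]
                      .replace('\n', ' ')
                      .strip()
                      .strip('*` '))
except IndexError:
    print('The response does not contain "Recommended Rule:"; check response_text above.')

generated_rule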
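A regular expression offers an alternative way to pull out the rule when the model wraps the marker in markdown (for example `**Recommended Rule:**`) or puts the rule on the next line. This is a hypothetical sketch: `extract_rule` and the sample text are illustrative, and the pattern assumes the rule ends at the first blank line.

```python
import re

# "Recommended Rule:" plus optional bold markers, then everything up to a blank line or the end
RULE_PATTERN = re.compile(r"Recommended Rule:\**\s*(.+?)(?:\n\s*\n|\Z)",
                          re.IGNORECASE | re.DOTALL)

def extract_rule(text):
    """Return the recommended rule as a single line, or '' when no marker is found."""
    match = RULE_PATTERN.search(text)
    if match is None:
        return ""
    rule = " ".join(match.group(1).split())  # collapse newlines and repeated spaces
    return rule.strip("*` ")                 # drop markdown emphasis and code ticks

sample = ("Some reasoning.\n\n**Recommended Rule:**\n"
          "`ExactMatch(email) OR (ExactMatch(fullName) AND ExactMatch(phone))`\n\n"
          "Reasoning: email is sparse.")
print(extract_rule(sample))
```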

In [None]:
from pyspark.sql import functions as F
import re

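def parse_rule_to_condition(rule_str, alias1="df1", alias2="df2"):
    # Replace 'ExactMatch(col)' with Spark's NULL-safe equality check.
    # eqNullSafe treats two NULLs as equal, so records both missing a column match on it.
    rule_str = re.sub(
        r'ExactMatch\((\w+)\)',
        lambda m: f'(F.col("{alias1}.{m.group(1)}").eqNullSafe(F.col("{alias2}.{m.group(1)}")))',
        rule_str
    )
    # Replace the logical operators with PySpark's & and |. Word boundaries keep column
    # names that contain "AND" or "OR" (e.g. VENDOR_ID) intact.
    rule_str = re.sub(r'\bAND\b', '&', rule_str)
    rule_str = re.sub(r'\bOR\b', '|', rule_str)
    # Evaluate the string into a PySpark Column condition.
    # eval runs text produced by the LLM, so review generated_rule before running this.
    return eval(rule_str)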

df1 = filteredDF.alias("df1")
df2 = filteredDF.alias("df2")
condition = parse_rule_to_condition(generated_rule, "df1", "df2")
print(condition)
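Because `parse_rule_to_condition` passes LLM output to `eval`, it may be worth checking the rule's grammar first. The sketch below swaps in a small recursive-descent parser (not part of the original notebook) that accepts only `ExactMatch(column)`, `AND`, `OR`, and parentheses, rejects unknown columns, and evaluates the rule on two records held as Python dicts. `AND` binds tighter than `OR`, matching `&` and `|` in the PySpark translation, and, as with `eqNullSafe`, two missing values compare as equal. The function names and sample records are hypothetical.

```python
import re

# Tokens: ExactMatch(name), AND, OR, parentheses (leading whitespace is skipped)
TOKEN_RE = re.compile(r"\s*(ExactMatch\(\s*(\w+)\s*\)|AND\b|OR\b|\(|\))")

def tokenize(rule):
    """Split a rule into ('COL', name), ('AND', None), ('OR', None), ('(', None), (')', None)."""
    rule = rule.strip()
    tokens, pos = [], 0
    while pos < len(rule):
        m = TOKEN_RE.match(rule, pos)
        if m is None:
            raise ValueError(f"Unexpected text at position {pos}: {rule[pos:]!r}")
        tokens.append(("COL", m.group(2)) if m.group(2) else (m.group(1), None))
        pos = m.end()
    return tokens

def parse(tokens, columns):
    """Build a tree of ('MATCH', col), ('AND', l, r), ('OR', l, r) nodes."""
    pos = 0

    def peek():
        return tokens[pos][0] if pos < len(tokens) else None

    def take(kind):
        nonlocal pos
        if peek() != kind:
            raise ValueError(f"Expected {kind}, found {peek()}")
        pos += 1
        return tokens[pos - 1]

    def expr():    # expr := term (OR term)*
        node = term()
        while peek() == "OR":
            take("OR")
            node = ("OR", node, term())
        return node

    def term():    # term := factor (AND factor)*
        node = factor()
        while peek() == "AND":
            take("AND")
            node = ("AND", node, factor())
        return node

    def factor():  # factor := ExactMatch(col) | '(' expr ')'
        if peek() == "(":
            take("(")
            node = expr()
            take(")")
            return node
        name = take("COL")[1]
        if name not in columns:
            raise ValueError(f"Unknown column: {name}")
        return ("MATCH", name)

    tree = expr()
    if pos != len(tokens):
        raise ValueError(f"Trailing tokens: {tokens[pos:]}")
    return tree

def matches(tree, rec1, rec2):
    """Evaluate the rule tree on two records; None == None counts as a match."""
    if tree[0] == "MATCH":
        return rec1.get(tree[1]) == rec2.get(tree[1])
    left, right = matches(tree[1], rec1, rec2), matches(tree[2], rec1, rec2)
    return (left and right) if tree[0] == "AND" else (left or right)

rule_columns = {"email", "phone", "fullName"}
tree = parse(tokenize("ExactMatch(email) OR (ExactMatch(fullName) AND ExactMatch(phone))"), rule_columns)
a = {"email": None, "phone": "8977771414", "fullName": "jon doe"}
b = {"email": "jondoe@gmail.com", "phone": "8977771414", "fullName": "jon doe"}
print(matches(tree, a, b))  # True
```

In the notebook, `parse(tokenize(generated_rule), set(filteredDF.columns))` could serve as a pre-check that raises `ValueError` before `eval` ever runs.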

#### 4.e Validate the rule on the input dataset <a id="4.e" />
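Apply the rule to the input dataset with a self-join to find the record pairs that it matches. The extra condition requiring the first record's unique ID to be smaller than the second's removes self-matches and keeps each pair only once. Rules that contain OR generally prevent Spark from using an equi-join, so this step can be slow on large datasets.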

In [None]:
colNameA = "df1."+uniqueIdColumn
colNameB = "df2."+uniqueIdColumn

condition = condition & (F.col(colNameA) < F.col(colNameB))

found_pairs = df1.join(df2, condition).select(
    F.col(colNameA).alias("id1"),
    F.col(colNameB).alias("id2")
)

print(f"this rule found {found_pairs.count()} matches")

In [None]:
found_pairs.show()

#### Draw a connected graph with sample records

In [None]:
pd = found_pairs.limit(17).toPandas()
print(pd)

In [None]:
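import networkx as nx
import matplotlib.pyplot as plt

plt.clf()
G = nx.DiGraph()

# Add every sampled record ID as a node and every matched pair as an edge
G.add_nodes_from(pd['id1'])
G.add_nodes_from(pd['id2'])
edges = [(row['id1'], row['id2']) for index, row in pd.iterrows()]
G.add_edges_from(edges)

# spring_layout works for any graph; planar_layout raises an error if the sample is not planar
pos = nx.spring_layout(G, seed=42)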

options = {
    "font_size": 8,
    "node_size": 1000,
    "node_color": "skyblue",
    "edgecolors": "black",
    "linewidths": 1,
    "width": 1,
}
nx.draw_networkx(G, pos, **options)

# Set margins for the axes so that nodes aren't clipped
ax = plt.gca()
ax.margins(0.01)
plt.axis("off")

%matplot plt
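The graph shows that pairwise matches can chain together (A matches B, B matches C). A rough way to see the groups these chains form is to take connected components of the pairs. The sketch below uses a small union-find, a swapped-in technique rather than anything the notebook or AWS Entity Resolution is documented to use here; `cluster_pairs` and the sample IDs are hypothetical.

```python
def cluster_pairs(pairs):
    """Group record IDs so that IDs linked by any chain of matched pairs share a cluster."""
    parent = {}

    def find(x):
        parent.setdefault(x, x)
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving keeps the trees shallow
            x = parent[x]
        return x

    for a, b in pairs:
        root_a, root_b = find(a), find(b)
        if root_a != root_b:
            parent[root_b] = root_a

    clusters = {}
    for node in parent:
        clusters.setdefault(find(node), set()).add(node)
    return sorted((sorted(c) for c in clusters.values()), key=lambda c: c[0])

print(cluster_pairs([(1, 2), (2, 3), (4, 5)]))  # [[1, 2, 3], [4, 5]]
```

In the notebook the input could be `[(r.id1, r.id2) for r in found_pairs.collect()]`, which is only sensible when the number of matched pairs is small enough to bring to the driver.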

#### Randomly sample a matched pair and verify the result

In [None]:
# Randomly sample one matched pair for quality checking
sampled_pair = found_pairs.orderBy(F.rand()).limit(1).first()

if sampled_pair is None:
    print("The rule did not match any pairs, so there is nothing to sample.")
else:
    id1 = sampled_pair.id1
    id2 = sampled_pair.id2
    print(f"generated rule :\n{generated_rule}")
    print(f"sample pair: id1 = {id1}, id2 = {id2}")
    row1 = filteredDF.filter(F.col(uniqueIdColumn) == id1).first()
    row2 = filteredDF.filter(F.col(uniqueIdColumn) == id2).first()

    # column_name (not col) avoids shadowing pyspark's col() function used in earlier cells
    for column_name in filteredDF.columns:
        val1 = row1[column_name] if row1[column_name] is not None else "NULL"
        val2 = row2[column_name] if row2[column_name] is not None else "NULL"
        match = "Match" if val1 == val2 else "No-Match"
        print(f"| {column_name:20} | {match:10} | {str(val1):20} | {str(val2):20} |")