# Amazon Comprehend - Custom Entity Detection Example

Reference:

- [Custom entity recognition](https://docs.aws.amazon.com/comprehend/latest/dg/custom-entity-recognition.html)

## 1. Create Fake Dataset

In this section, we create a dummy dataset that includes two important entities:

- Service Date (target entity)
- Receive Date (distracter entity)

Each document has several paragraphs, and each paragraph has several sentences. Every document always contains exactly one sentence with ``service date`` information and one sentence with ``receive date`` information.

In this example, we use the [Plain-text annotations](https://docs.aws.amazon.com/comprehend/latest/dg/cer-annotation-csv.html) format for annotation.

In [121]:
import typing as T
import random
from datetime import datetime

import polars as pl
from faker import Faker
from s3pathlib import S3Path, context
from boto_session_manager import BotoSesManager
from rich import print as rprint

In [122]:
class Config:
    """
    Central place for every tunable value used by this notebook.

    The values are plain class attributes and are read directly as
    ``Config.<name>``.
    """

    # ------------------------------------------------------------------
    # AWS account
    # ------------------------------------------------------------------
    aws_profile: str = "aws_data_lab_sanhe_us_east_1"
    s3_bucket: str = "669508176277-us-east-1-data"
    comprehend_iam_role: str = (
        "arn:aws:iam::669508176277:role/sanhe-comprehend-admin-access"
    )

    # ------------------------------------------------------------------
    # Comprehend model
    # ------------------------------------------------------------------
    model_name: str = "medical-service-report-entities"

    # ------------------------------------------------------------------
    # Synthetic dataset size; every lower / upper pair is fed to
    # ``random.randint``, so both ends are inclusive
    # ------------------------------------------------------------------
    n_document: int = 200
    # NOTE: must stay >= 3, document generation samples two distinct
    # insert positions out of ``range(1, n_paragraph)``
    n_paragraph_per_doc_lower: int = 3
    n_paragraph_per_doc_upper: int = 6
    n_sentence_per_paragraph_lower: int = 3
    n_sentence_per_paragraph_upper: int = 10
    n_word_per_sentence_lower: int = 6
    n_word_per_sentence_upper: int = 20

    # ------------------------------------------------------------------
    # Entity wording: the phrase that precedes each date, and the
    # strftime patterns a date may be rendered with
    # ------------------------------------------------------------------
    service_date_key_options: T.List[str] = [
        "service date",
        "date of service",
    ]
    receive_date_key: str = "receive date"
    date_format_options: T.List[str] = [
        "%Y-%m-%d",
        "%m/%d/%Y",
    ]

In [123]:
# Build the boto session manager from the named AWS CLI profile in Config
bsm = BotoSesManager(profile_name=Config.aws_profile)
# Hand the same boto session to s3pathlib
# (presumably so every S3Path call below uses these credentials - confirm)
context.attach_boto_session(bsm.boto_ses)

# Shared Faker instance used by all text generators below.
# NOTE: no seed is set here, so each run produces a different dataset.
fake = Faker()

In [124]:
def create_initial_sentences() -> T.List[str]:
    """
    Generate the filler sentences of one paragraph.

    The number of sentences is drawn from
    ``Config.n_sentence_per_paragraph_lower .. _upper`` (inclusive), and
    for every sentence a word count is drawn from
    ``Config.n_word_per_sentence_lower .. _upper`` (inclusive) and passed
    to Faker with ``variable_nb_words=False``.

    :return: list of fake sentences, in generation order.
    """
    sentence_count = random.randint(
        Config.n_sentence_per_paragraph_lower,
        Config.n_sentence_per_paragraph_upper,
    )
    sentences = []
    for _ in range(sentence_count):
        word_count = random.randint(
            Config.n_word_per_sentence_lower,
            Config.n_word_per_sentence_upper,
        )
        sentences.append(
            fake.sentence(nb_words=word_count, variable_nb_words=False)
        )
    return sentences

print(create_initial_sentences())

['Likely myself reach turn story.', 'Environmental choice challenge senior tough manage.', 'Husband once especially return including truth.', 'Owner mention human little should energy.']


In [125]:
def create_paragraph() -> str:
    """
    Build one filler paragraph (no entity inside).

    :return: the sentences from :func:`create_initial_sentences`
        joined with a single space.
    """
    sentences = create_initial_sentences()
    paragraph = " ".join(sentences)
    return paragraph

print(create_paragraph())

Summer strong social. Strong three forget power education measure. Scientist here kitchen power short bill. Enter now camera answer administration.


In [126]:
def create_important_paragraph(key: str) -> T.Tuple[str, int, int]:
    """
    Build a paragraph in which exactly one sentence carries ``key``
    followed (not necessarily immediately) by a random date.

    :param key: the phrase announcing the date, e.g. ``"service date"``.
        It is inserted as a single token even if it contains spaces.

    :return: ``(paragraph, begin_offset, end_offset)`` where
        ``paragraph[begin_offset:end_offset]`` is the date string.
    """
    filler_sentences = create_initial_sentences()
    word_count = random.randint(
        Config.n_word_per_sentence_lower,
        Config.n_word_per_sentence_upper,
    )
    tokens = [fake.word() for _ in range(word_count)]

    # Faker gives an ISO date string; re-render it in a randomly chosen format
    parsed_date = datetime.strptime(fake.date(), "%Y-%m-%d")
    date_text = parsed_date.strftime(random.choice(Config.date_format_options))

    # the key is never the first token, and the date always comes after the key
    key_pos = random.randint(1, word_count)
    tokens.insert(key_pos, key)
    date_pos = random.randint(key_pos + 1, word_count + 1)
    tokens.insert(date_pos, date_text)
    key_sentence = " ".join(tokens) + "."

    # IMPORTANT: offset of the date inside its own sentence =
    # length of everything before it, plus the one separating space
    text_before_date = " ".join(tokens[:date_pos])
    start_in_sentence = len(text_before_date) + 1

    # the key sentence is never the first sentence of the paragraph
    sentence_pos = random.randint(1, len(filler_sentences))
    filler_sentences.insert(sentence_pos, key_sentence)

    # IMPORTANT: shift by everything before the key sentence, plus one space
    text_before_sentence = " ".join(filler_sentences[:sentence_pos])
    shift = len(text_before_sentence) + 1

    begin_offset = start_in_sentence + shift
    end_offset = begin_offset + len(date_text)
    paragraph = " ".join(filler_sentences)

    # sanity check when debugging:
    # paragraph[begin_offset:end_offset] should equal date_text
    return paragraph, begin_offset, end_offset

print(create_important_paragraph("service date"))

('Son president wind eye. Without law around. Pattern international character. author service date call 1984-12-14 growth understand animal serve. North whatever machine maybe never there.', 102, 112)


In [128]:
def create_document() -> T.Tuple[str, int, int, int, int]:
    """
    Build one synthetic document: filler paragraphs plus one paragraph
    with a service date and one paragraph with a receive date.
    Paragraphs are separated by ``"\\n"``.

    The service date paragraph always comes before the receive date
    paragraph, and neither is the first paragraph of the document.

    :return: ``(document, service_begin, service_end, receive_begin,
        receive_end)``; both offset pairs are relative to the start of
        ``document`` and ``document[begin:end]`` is the date string.
    """
    paragraph_count = random.randint(
        Config.n_paragraph_per_doc_lower,
        Config.n_paragraph_per_doc_upper,
    )
    paragraphs = [create_paragraph() for _ in range(paragraph_count)]

    service_paragraph, service_begin, service_end = create_important_paragraph(
        key=random.choice(Config.service_date_key_options),
    )
    receive_paragraph, receive_begin, receive_end = create_important_paragraph(
        key=Config.receive_date_key,
    )

    # two distinct insert positions (relative to the filler-only list),
    # the smaller one is used for the service date paragraph
    service_slot, receive_slot = sorted(
        random.sample(range(1, paragraph_count), 2)
    )

    # IMPORTANT: shift = length of all text before the inserted paragraph,
    # plus the one "\n" that separates it from that text
    paragraphs.insert(service_slot, service_paragraph)
    service_shift = len("\n".join(paragraphs[:service_slot])) + 1
    service_begin += service_shift
    service_end += service_shift

    # the service paragraph now sits in front, so the receive paragraph
    # lands one position further than its sampled slot
    receive_index = receive_slot + 1
    paragraphs.insert(receive_index, receive_paragraph)
    receive_shift = len("\n".join(paragraphs[:receive_index])) + 1
    receive_begin += receive_shift
    receive_end += receive_shift

    document = "\n".join(paragraphs)

    # sanity check when debugging: document[service_begin:service_end] and
    # document[receive_begin:receive_end] should both be date strings
    return (
        document,
        service_begin,
        service_end,
        receive_begin,
        receive_end,
    )

print(create_document())

('Charge culture become PM. Already certainly child agent. Serious as case show development drive.\nTell conference all meet. Impact small value any. example date of service police 1988-01-15 according human marriage. Rich four prevent high street investment. Per current security price loss.\nMonth hold serious. Television exist reality eight thing. Inside citizen enjoy college hear. Event laugh join. Form street ok after billion.\nStation simply be phone. current receive date land 1974-10-06 very find small. Congress plant which up. Fight cost late stuff. Sell fund identify perform century.\nBad story huge fine. Woman describe continue able myself well. Suffer question design. Study sport argue.', 178, 188, 482, 492)


## 2. Write Training / Testing to S3


In [131]:
# Define S3 locations
# Layout under the project prefix:
#   documents/train/      one .txt file per training document
#   documents/test/       one .txt file per testing document
#   annotation-train.csv  annotations for the training documents
#   annotation-test.csv   annotations for the testing documents
# NOTE(review): the prefix says "custom-document-classification-example" although
# this notebook is about entity detection - presumably copied from another
# notebook; confirm it does not collide with that example's data.
s3dir_project = S3Path.from_s3_uri(f"s3://{Config.s3_bucket}/poc/2023-04-09-custom-document-classification-example/")
s3dir_doc = s3dir_project.joinpath("documents").to_dir()
s3dir_doc_train = s3dir_doc.joinpath("train").to_dir()
s3dir_doc_test = s3dir_doc.joinpath("test").to_dir()
s3path_annotation_train = s3dir_project.joinpath("annotation-train.csv")
s3path_annotation_test = s3dir_project.joinpath("annotation-test.csv")

# Print the AWS console link so the uploaded documents can be inspected
print(f"preview s3dir_document: {s3dir_doc.console_url}")

# Uncomment to wipe everything under the project prefix before a fresh run
# _ = s3dir_project.delete_if_exists()

preview s3dir_document: https://console.aws.amazon.com/s3/buckets/669508176277-us-east-1-data?prefix=poc/2023-04-09-custom-document-classification-example/documents/


In [132]:
# Generate every document once and route it to the training set
# (about 70%) or the testing set (about 30%).
# The split is random per document, so the exact sizes vary per run.
training = []
testing = []
for ith_doc in range(1, Config.n_document + 1):
    doc_tuple = create_document()
    # zero-padded, 6-digit file name, e.g. 000001.txt
    filename = f"{ith_doc:06d}.txt"
    sample = (filename, doc_tuple)
    # one uniform draw in 1..100 decides which split receives the sample
    destination = training if random.randint(1, 100) <= 70 else testing
    destination.append(sample)

In [133]:
# Write data and annotation to S3
# For each split (train / test):
#   1. upload every document as its own text file under the split folder
#   2. collect two annotation rows per document (service date, receive date)
#   3. upload all rows of the split as one CSV annotation file
# NOTE(review): the Comprehend plain-text annotation documentation describes a
# "Line" column with line-relative offsets; this notebook writes
# document-relative offsets without that column - confirm the recognizer
# accepts this layout.
for s3dir_document, s3path_annotation, samples in [
    (s3dir_doc_train, s3path_annotation_train, training),
    (s3dir_doc_test, s3path_annotation_test, testing),
]:
    annotation_rows = []
    for (
        filename,
        (
            document,
            service_date_begin_offset,
            service_date_end_offset,
            receive_date_begin_offset,
            receive_date_end_offset,
        ),
    ) in samples:
        s3path = s3dir_document.joinpath(filename)
        s3path.write_text(document)
        annotation_rows.append((
            filename,
            service_date_begin_offset,
            service_date_end_offset,
            "SERVICE_DATE",
        ))
        annotation_rows.append((
            filename,
            receive_date_begin_offset,
            receive_date_end_offset,
            "RECEIVE_DATE",
        ))
    # Each tuple is one CSV row. State the orientation explicitly: when it
    # is left out polars infers it from the data / schema dimensions, which
    # is ambiguous when the number of rows equals the number of columns
    # (4 rows = a split with exactly 2 documents).
    annotation_df = pl.DataFrame(
        annotation_rows,
        schema=["File", "Begin Offset", "End Offset", "Type"],
        orient="row",
    )
    with s3path_annotation.open("wb") as f:
        # The header row is written by default. The explicit ``has_header``
        # keyword is omitted on purpose because later polars releases
        # renamed it to ``include_header``.
        annotation_df.write_csv(f, separator=",")

## 3. Train Comprehend Model

In this example, we use the [comprehend.create_entity_recognizer](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/comprehend/client/create_entity_recognizer.html) API to create a custom entity recognizer.


In [135]:
# Submit the request that creates a new version of the custom entity
# recognizer from the documents and annotation files uploaded above.
# The response (see cell output) only carries the recognizer ARN.
bsm.comprehend_client.create_entity_recognizer(
    RecognizerName=Config.model_name,
    # NOTE(review): the version name is hard-coded; presumably it has to be
    # changed for every new training run of the same recognizer - confirm.
    VersionName="v000002",
    DataAccessRoleArn=Config.comprehend_iam_role,
    InputDataConfig=dict(
        # ----------------------------------------
        # Define Data Format
        # ----------------------------------------
        DataFormat="COMPREHEND_CSV",
        # DataFormat="AUGMENTED_MANIFEST",
        # ----------------------------------------
        # The entity types in the labeled training data that Amazon Comprehend
        # uses to train the custom entity recognizer. Any entity types that you
        # don't specify are ignored.
        #
        # A maximum of 25 entity types can be used at one time to train an
        # entity recognizer. Entity types must not contain the following
        # invalid characters: \n (line break), \\n (escaped line break),
        # \r (carriage return), \\r (escaped carriage return), \t (tab),
        # \\t (escaped tab), space, and , (comma).
        # ----------------------------------------
        EntityTypes=[
            dict(Type="SERVICE_DATE"),
            dict(Type="RECEIVE_DATE"),
        ],
        # ----------------------------------------
        # The S3 location of the folder that contains the training documents
        # for your custom entity recognizer (S3Uri) and the folder that
        # contains the testing documents (TestS3Uri).
        #
        # This parameter is required if you set DataFormat to COMPREHEND_CSV.
        # ----------------------------------------
        Documents=dict(
            S3Uri=s3dir_doc_train.uri,
            TestS3Uri=s3dir_doc_test.uri,
            InputFormat="ONE_DOC_PER_FILE",
        ),
        # ----------------------------------------
        # The S3 location of the CSV file that annotates your training
        # documents (S3Uri) and of the one that annotates your testing
        # documents (TestS3Uri).
        # ----------------------------------------
        Annotations=dict(
            S3Uri=s3path_annotation_train.uri,
            TestS3Uri=s3path_annotation_test.uri,
        ),
    ),
    LanguageCode="en",
)

{'EntityRecognizerArn': 'arn:aws:comprehend:us-east-1:669508176277:entity-recognizer/medical-service-report-entities/version/v000002',
 'ResponseMetadata': {'RequestId': 'c7352d2d-691c-4813-9584-64556870a90f',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': 'c7352d2d-691c-4813-9584-64556870a90f',
   'content-type': 'application/x-amz-json-1.1',
   'content-length': '133',
   'date': 'Mon, 10 Apr 2023 05:30:10 GMT'},
  'RetryAttempts': 0}}