From f073617d98775d65518a758825ee5e47f3ca7e0c Mon Sep 17 00:00:00 2001 From: Alex Thewsey Date: Tue, 10 Jan 2023 00:38:19 +0800 Subject: [PATCH 01/14] fix(src): Handling textract URIs with no manifest Fix enumeration of input files when no data channel manifest provided --- notebooks/src/code/data/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/notebooks/src/code/data/base.py b/notebooks/src/code/data/base.py index 4b965d4..ecd2da8 100644 --- a/notebooks/src/code/data/base.py +++ b/notebooks/src/code/data/base.py @@ -496,12 +496,13 @@ def prepare_base_dataset( ds_raw = datasets.Dataset.from_dict( { "textract-ref": [ - os.path.join(currpath, file) + # Output paths *relative* to textract_path: + os.path.join(os.path.relpath(currpath, textract_path), file) for currpath, _, files in os.walk(textract_path) for file in files ] }, - cache_dir=cache_dir, + # At writing, from_dict() doesn't support setting cache_dir ) if not datasets.utils.is_progress_bar_enabled(): From 1511c9000ff755207ada76c29b5291f4146e9ba3 Mon Sep 17 00:00:00 2001 From: Alex Thewsey Date: Tue, 10 Jan 2023 00:44:32 +0800 Subject: [PATCH 02/14] fix(src): allow MLM model_param_names=None Resolve issue where mlm get_task() and class TextractLayoutLMDataCollatorForLanguageModelling would throw errors if model_param_names argument was missing, even though it was marked as optional. These will now raise warnings instead. --- notebooks/src/code/data/mlm.py | 39 ++++++++++++++++++++++++++++------ 1 file changed, 32 insertions(+), 7 deletions(-) diff --git a/notebooks/src/code/data/mlm.py b/notebooks/src/code/data/mlm.py index 23cb375..0f5f838 100644 --- a/notebooks/src/code/data/mlm.py +++ b/notebooks/src/code/data/mlm.py @@ -54,20 +54,20 @@ def __post_init__(self): # Configuration diagnostics: if self.tim_probability > 0: - if "image_mask_label" not in self.model_param_names: + if self.model_param_names is None or "image_mask_label" not in self.model_param_names: logger.warning( "model_param_names does not contain image_mask_label: Ignoring configured " "tim_probability. Text-Image Matching will be disabled." ) - elif "image_mask_label" in self.model_param_names: + elif self.model_param_names is not None and ("image_mask_label" in self.model_param_names): logger.warning("Pre-training with Text-Image Matching disabled (tim_probability = 0)") if self.tiam_probability > 0: - if "imline_mask_label" not in self.model_param_names: + if self.model_param_names is None or "imline_mask_label" not in self.model_param_names: logger.warning( "model_param_names does not contain imline_mask_label: Ignoring configured " "tiam_probability. Text-Image Alignment will be disabled." 
) - elif "imline_mask_label" in self.model_param_names: + elif self.model_param_names is not None and "imline_mask_label" in self.model_param_names: logger.warning("Pre-training with Text-Image Alignment disabled (tiam_probability = 0)") return super().__post_init__() @@ -82,6 +82,22 @@ def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict "Custom Textract MLM data collator has not been implemented for TensorFlow" ) + def _is_tia_enabled(self) -> bool: + """Safely check whether current configuration enables Text-Image Alignment (TIA) task""" + return ( + self.model_param_names is not None + and self.tiam_probability > 0 + and "imline_mask_label" in self.model_param_names + ) + + def _is_tim_enabled(self) -> bool: + """Safely check whether current configuration enables Text-Image Matching (TIM) task""" + return ( + self.model_param_names is not None + and self.tim_probability > 0 + and "image_mask_label" in self.model_param_names + ) + def torch_call(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]: if isinstance(batch, list): batch = {k: [ex[k] for ex in batch] for k in batch[0]} @@ -126,7 +142,7 @@ def torch_call(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]: # (This fn is cheap to call if TIM is turned off) image_mask_labels, image_indices = self.torch_permute_images(tokenized["image"]) - if ("imline_mask_label" in self.model_param_names) and (self.tiam_probability > 0): + if self._is_tia_enabled(): # Text Image Alignment (TIA): For each image in the batch (including reassigned # ones), mask some text in the image. masked_images = [] @@ -164,7 +180,7 @@ def torch_call(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]: shuffled_images = tokenized["image"][image_indices] tokenized["image"] = shuffled_images - if ("image_mask_label" in self.model_param_names) and (self.tim_probability > 0): + if self._is_tim_enabled(): tokenized["image_mask_label"] = image_mask_labels return tokenized @@ -327,7 +343,16 @@ def get_task( """Load datasets and data collators for MLM model training""" logger.info("Getting MLM datasets") # We only need to track line IDs for each word when TIAM is enabled: - tiam_enabled = ("imline_mask_label" in model_param_names) and (data_args.tiam_probability > 0) + if model_param_names is None: + tiam_enabled = False + logger.warning( + "Skipping generation of text line IDs in dataset: Since model_param_names is not " + "provided, we don't know if this field is required/supported by the model." + ) + else: + tiam_enabled = ("imline_mask_label" in model_param_names) and ( + data_args.tiam_probability > 0 + ) train_dataset = prepare_dataset( data_args.textract, tokenizer=tokenizer, From 3ebd3aab3109df2220d970a6097aacc684d78831 Mon Sep 17 00:00:00 2001 From: Alex Thewsey Date: Tue, 10 Jan 2023 03:00:39 +0800 Subject: [PATCH 03/14] feat: First draft seq2seq model integration Text-only seq2seq field normalization model using T5 to normalize date fields to YYYY-MM-DD format. Includes integration to pipeline (via postproc Lambda + field configuration SSM param) and setup NB (in Optional Extras.ipnyb), but no updates to readme/customization guide yet. 
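To illustrate the intended behaviour (a minimal sketch, not part of this patch): after
fine-tuning, the model receives a prompt combining the target format and the raw detected
text, and generates the normalized value. The checkpoint path below is hypothetical, and a
plain `t5-base` checkpoint without fine-tuning will not produce correct normalizations.

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Hypothetical local path to a checkpoint fine-tuned on the synthetic date dataset:
model_dir = "./t5-datenorm"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)

# Prompt = target-format instruction + raw detected date mention:
prompt = "Convert dates to YYYY-MM-DD: Sunday Dec 31st 2000"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
output_ids = model.generate(input_ids, max_new_tokens=16)
# Should yield "2000-12-31" once the model has been fine-tuned on the task:
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```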
--- notebooks/Optional Extras.ipynb | 656 +++++++++++++++++- notebooks/src/code/config.py | 35 +- notebooks/src/code/data/__init__.py | 9 +- notebooks/src/code/data/seq2seq/__init__.py | 9 + .../code/data/seq2seq/date_normalization.py | 223 ++++++ .../src/code/data/seq2seq/task_builder.py | 224 ++++++ notebooks/src/code/inference_seq2seq.py | 130 ++++ notebooks/src/code/models/__init__.py | 3 + notebooks/src/code/train.py | 41 +- notebooks/src/inference_seq2seq.py | 5 + .../postprocessing/fn-postprocess/main.py | 101 +-- .../fn-postprocess/util/config.py | 16 + .../fn-postprocess/util/extract.py | 131 ++++ .../fn-postprocess/util/normalize.py | 91 +++ 14 files changed, 1541 insertions(+), 133 deletions(-) create mode 100644 notebooks/src/code/data/seq2seq/__init__.py create mode 100644 notebooks/src/code/data/seq2seq/date_normalization.py create mode 100644 notebooks/src/code/data/seq2seq/task_builder.py create mode 100644 notebooks/src/code/inference_seq2seq.py create mode 100644 notebooks/src/code/models/__init__.py create mode 100644 notebooks/src/inference_seq2seq.py create mode 100644 pipeline/postprocessing/fn-postprocess/util/extract.py create mode 100644 pipeline/postprocessing/fn-postprocess/util/normalize.py diff --git a/notebooks/Optional Extras.ipynb b/notebooks/Optional Extras.ipynb index 97a29c6..5bb3765 100644 --- a/notebooks/Optional Extras.ipynb +++ b/notebooks/Optional Extras.ipynb @@ -11,13 +11,15 @@ "\n", "# Optional Extras\n", "\n", - "> *This notebook works well with the `Data Science 3.0 (Python 3)` kernel on SageMaker Studio - use the same as for NB1*\n", + "> *This notebook works well with the `PyTorch 1.10 Python 3.8 CPU Optimized (Python 3)` kernel on SageMaker Studio - **different** from the others in the series*\n", "\n", - "This notebook discusses optional extra/alternative steps separate from the typical pipeline setup flow. You won't typically need to run these steps unless specifically instructed.\n", + "This notebook discusses optional extra/alternative steps separate from the typical pipeline setup flow. You won't typically need to run these steps unless specifically guided, or you're digging deeper into customization.\n", "\n", "## Common setup\n", "\n", - "First, as usual, we'll set up and import required libraries. You should run these cells regardless of which optional section(s) you're using:" + "First, as usual, we'll set up and import required libraries. You should run these cells regardless of which optional section(s) you're using:\n", + "\n", + "The Hugging Face `datasets` and `transformers` installs here are used specifically for dataset preparation in the seq2seq section. If you have problems with these libraries and aren't tackling this section, you may be able to omit them. If you regularly need to install several custom libraries in Studio notebooks, refer to the [documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/studio-byoi-create.html) and [samples](https://github.com/aws-samples/sagemaker-studio-custom-image-samples) on building **Custom kernel images** for SageMaker." 
] }, { @@ -25,11 +27,17 @@ "execution_count": null, "id": "01620969-ea29-46c7-ba84-9006ca9e73d1", "metadata": { + "scrolled": true, "tags": [] }, "outputs": [], "source": [ - "!pip install sagemaker-studio-image-build \"sagemaker>=2.87,<3\"" + "!pip install amazon-textract-response-parser \\\n", + " \"datasets>=2.4,<3\" \\\n", + " \"ipywidgets>=7,<8\" \\\n", + " sagemaker-studio-image-build \\\n", + " \"sagemaker>=2.87,<3\" \\\n", + " \"transformers>=4.25,<4.26\"" ] }, { @@ -52,6 +60,8 @@ "\n", "# External Dependencies:\n", "import boto3 # General-purpose AWS SDK for Python\n", + "import numpy as np # Matrix/math utilities\n", + "import pandas as pd # Data table / dataframe utilities\n", "import sagemaker # High-level Python SDK for Amazon SageMaker\n", "\n", "# Local Dependencies:\n", @@ -82,12 +92,13 @@ "\n", "- **[Manual thumbnail generator setup](#Manual-thumbnail-generator-setup)**: Customise online page thumbnail generation endpoint\n", "- **[Optimise costs with endpoint auto-scaling](#Optimise-costs-with-endpoint-auto-scaling)**: Configure your SageMaker endpoint(s) to auto-scale based on incoming request volume\n", - "- **[Experimenting with alternative OCR engines](#Experimenting-with-alternative-OCR-engines)**: Substitute Amazon Textract with open-source OCR tools, for use with unsupported languages" + "- **[Experimenting with alternative OCR engines](#Experimenting-with-alternative-OCR-engines)**: Substitute Amazon Textract with open-source OCR tools, for use with unsupported languages\n", + "- **[Exploring sequence-to-sequence models](#Exploring-sequence-to-sequence-models)**: Use generative models to automatically re-format detected fields and fix common OCR error patterns" ] }, { "cell_type": "markdown", - "id": "c90d7dbd-95b3-42ab-a9f1-906742a6c860", + "id": "b2de05db-4a6a-44c0-afc6-203ecdbd999d", "metadata": { "tags": [] }, @@ -100,8 +111,16 @@ ">\n", "> You may find it useful if you want to customise the container image or script used by this process, or if you deployed your pipeline without thumbnailing support but want to experiment with image-based models from notebooks.\n", ">\n", - "> ⚠️ **Note:** Deploying and registering a thumbnailing endpoint from the notebook will still not turn on thumbnail generation in a pipeline deployed without support for it. Instead, refer to your CDK app parameters to ensure the pipeline state machine gets updated to include a thumbnail generation step.\n", - "\n", + "> ⚠️ **Note:** Deploying and registering a thumbnailing endpoint from the notebook will still not turn on thumbnail generation in a pipeline deployed without support for it. Instead, refer to your CDK app parameters to ensure the pipeline state machine gets updated to include a thumbnail generation step." + ] + }, + { + "cell_type": "markdown", + "id": "71fa91ae-c79b-4ecc-b470-d7daa79033ce", + "metadata": { + "tags": [] + }, + "source": [ "### Build and register custom container image\n", "\n", "The tools we use to read PDF files aren't installed by default in the pre-built SageMaker containers and aren't `pip install`able, so the thumbnail generator will need a custom container image. 
We can derive a custom image from an existing AWS DLC serving container, to minimise boilerplate code because a SageMaker-compatible serving stack will already be included.\n", @@ -378,8 +397,10 @@ }, { "cell_type": "markdown", - "id": "3c185e93-fc9f-451d-ab3d-d597665ea4a4", - "metadata": {}, + "id": "77093526-6286-401e-97d4-2a408cbbd15b", + "metadata": { + "tags": [] + }, "source": [ "---\n", "\n", @@ -393,8 +414,16 @@ "\n", "SageMaker Async Inference endpoints support [auto-scaling down to zero instances](https://docs.aws.amazon.com/sagemaker/latest/dg/async-inference-autoscale.html) when not in use, which can provide significant cost-savings for use cases where document processing is occasional and the pipeline is often idle.\n", "\n", - "⏰ **However:** You should be aware that enabling scale-to-zero can introduce cold-start delays of **several minutes** if requests arrive when all instances backing your endpoint have been shut down.\n", - "\n", + "⏰ **However:** You should be aware that enabling scale-to-zero can introduce cold-start delays of **several minutes** if requests arrive when all instances backing your endpoint have been shut down." + ] + }, + { + "cell_type": "markdown", + "id": "6365f067-0180-4b1a-b10d-fb1c6c04a482", + "metadata": { + "tags": [] + }, + "source": [ "### Setting up auto-scaling\n", "\n", "You can configure auto-scaling for your endpoint(s) by first registering them with the [application auto-scaling service](https://docs.aws.amazon.com/autoscaling/application/userguide/what-is-application-auto-scaling.html) and then applying a scaling policy as shown in the following cells.\n", @@ -606,8 +635,10 @@ }, { "cell_type": "markdown", - "id": "4c980dcd-8263-4d5e-bbcc-0eeaf43e022f", - "metadata": {}, + "id": "7d3c5912-dfe8-409f-8543-90b53d12c688", + "metadata": { + "tags": [] + }, "source": [ "---\n", "\n", @@ -617,7 +648,17 @@ "\n", "> This section demonstrates how to process a batch of documents using alternative, open-source-based OCR engines on Amazon SageMaker - in case you have a use case requiring languages not yet supported by Amazon Textract.\n", "\n", - "As detailed further in the [Customization Guide](../CUSTOMIZATION_GUIDE.md) - You can use alternative, open-source-based OCR engines with this solution if needed, by packaging them to produce Amazon Textract-compatible result formats and integrating them with the pipeline, for which we use Amazon SageMaker Asynchronous Inference for consistency with other steps.\n", + "As detailed further in the [Customization Guide](../CUSTOMIZATION_GUIDE.md) - You can use alternative, open-source-based OCR engines with this solution if needed, by packaging them to produce Amazon Textract-compatible result formats and integrating them with the pipeline, for which we use Amazon SageMaker Asynchronous Inference for consistency with other steps." + ] + }, + { + "cell_type": "markdown", + "id": "f0ad4044-fca1-4579-8b48-ba9d1322f367", + "metadata": { + "tags": [] + }, + "source": [ + "### Deploy the alternative engine(s)\n", "\n", "First, (re)-deploy your solution with the `BUILD_SM_OCRS` variable set, to create container image(s) and SageMaker model(s) for your chosen OCR engine(s).\n", "\n", @@ -649,6 +690,8 @@ "id": "16185606-2a2b-416f-9764-bc022dba5bdb", "metadata": {}, "source": [ + "### Extract documents in batch\n", + "\n", "Just like with batch page image generation in notebook 1, we'll use a SageMaker Processing Job to run the work on a scalable cluster of instances. 
The input document locations are specified the same way as for page image generation, so the code below takes the whole corpus (S3 prefix) for simplicity.\n", "\n", "> ⏰ If you'd like to select **just a subset of documents**, you can instead set `ocr_inputs` using the same manifest-based \"OPTION 2\" approach shown to set `preproc_inputs` in the *Extract clean input images* section of [Notebook 1](1.%20Data%20Preparation.ipynb)." @@ -792,11 +835,584 @@ }, { "cell_type": "markdown", - "id": "76c1ba26-e982-4f8a-8bf0-fc0902a2d917", + "id": "048ef094-9f94-466d-bb93-729c4ac82348", + "metadata": {}, + "source": [ + "### Integrate with the document pipeline\n", + "\n", + "The above steps demonstrate how to process documents in batch with alternative, open-source OCR engines, to produce datasets ready for experimenting with multi-lingual model architectures like LayoutXLM. To actually deploy the alternative OCR into your document pipeline, use the `DEPLOY_SM_OCR` and `USE_SM_OCR` variables at CDK deployment. You'll likely want to update `OCR_DEFAULT_LANGUAGES` in [/pipeline/ocr/sagemaker_ocr.py](../pipeline/ocr/sagemaker_ocr.py) to align with your use case's language needs." + ] + }, + { + "cell_type": "markdown", + "id": "88994b13-bd18-4793-a4b9-ac3978d2fdf5", + "metadata": { + "tags": [] + }, + "source": [ + "---\n", + "\n", + "*[Back to contents](#Contents)*\n", + "\n", + "## Exploring sequence-to-sequence models\n", + "\n", + "> This section demonstrates training a (non-layout-aware) model that edits extracted text fields to normalize data types or fix common OCR error patterns.\n", + "\n", + "Since the main flow of this solution focusses on \"extractive\" entity recognition models, you might reasonably wonder whether the same layout-aware ideas could be extended to \"generative\" models capable of actually editing the OCR'd text: For example to reformat fields or fix errors. The answer to this is **\"probably yes, but...\"**:\n", + "\n", + "1. Care needs to be taken with large generative models to address bias and privacy concerns: For example will it be possible to extract sensitive or PII data the model was trained on, when it's deployed? Will it be biased to predicting certain patterns that aren't representative of your documents, or are representative on average but leave some user groups with consistently poorer service?\n", + "2. Published, pre-trained, layout-aware document models have most often provided a decoder-only stack to date: so finding pre-trained initial weights for a generative output module may be challenging. Due to their large size, training these modules from scratch could be resource-intensive.\n", + "\n", + "Here we show a more basic approach to start realizing some of the same benefits: Pairing the layout-aware NER model **alongside text-only seq2seq models** to normalize and standardize extracted fields." + ] + }, + { + "cell_type": "markdown", + "id": "1cf54db0-354b-4e3c-acba-c81c15f6c1ca", "metadata": {}, "source": [ - "The above steps demonstrate how to process documents in batch with alternative, open-source OCR engines, to produce datasets ready for experimenting with multi-lingual model architectures like LayoutXLM. To actually deploy the alternative OCR into your document pipeline, use the `DEPLOY_SM_OCR` and `USE_SM_OCR` variables at CDK deployment. 
You'll likely want to update `OCR_DEFAULT_LANGUAGES` in [/pipeline/ocr/sagemaker_ocr.py](../pipeline/ocr/sagemaker_ocr.py) to align with your use case's language needs.\n", + "### Collect datasets\n", "\n", + "In this example we'll demonstrate **normalizing dates** to a consistent format. Text-to-text models can tackle this in a flexible, example-driven and statistics-oriented way. Although maximum achievable accuracy might sometimes be higher with rule-based approaches, we'll show how the ML-based approach can yield good performance quickly without needing to build lots of rules and parsing expressions.\n", + "\n", + "This task can be tackled via **synthetic dataset generation**: randomly generating dates and input prompts, according to expected statistical distribution of your target data.\n", + "\n", + "Run the cell below to generate a training and evaluation dataset. As shown in the preview, the data will include a wide range of source date formats but **also** support multiple different *target* formats - controllable via the first part of the prompt:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "67265da0-8ec5-46ce-837a-7883d435a7b6", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from src.code.data.seq2seq.date_normalization import generate_seq2seq_date_norm_dataset\n", + "\n", + "rng = np.random.default_rng(42)\n", + "train_dataset = generate_seq2seq_date_norm_dataset(n=1000, rng=rng)\n", + "eval_dataset = generate_seq2seq_date_norm_dataset(n=200, rng=rng)\n", + "\n", + "train_dataset.save_to_disk(\"data/seq2seq-train\")\n", + "eval_dataset.save_to_disk(\"data/seq2seq-validation\")\n", + "\n", + "print(\"Dataset sample (top 10 records):\")\n", + "pd.DataFrame(train_dataset[0:10])" + ] + }, + { + "cell_type": "markdown", + "id": "38389ddb-80b6-45fe-94d9-2f569becbe9c", + "metadata": {}, + "source": [ + "As usual with SageMaker, once the datasets are prepared we'll stage them to Amazon S3 ready to use in model training:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e1a21963-a264-49f7-8fae-17f5b5a502c1", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "train_s3uri = f\"s3://{bucket_name}/{bucket_prefix}seq2seq/date-norm/train\"\n", + "validation_s3uri = f\"s3://{bucket_name}/{bucket_prefix}seq2seq/date-norm/validation\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "03e5ad33-d8d5-4e85-9f69-f922dcebf0ac", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "!aws s3 sync --delete data/seq2seq-train {train_s3uri}\n", + "!aws s3 sync --delete data/seq2seq-validation {validation_s3uri}" + ] + }, + { + "cell_type": "markdown", + "id": "3f120616-0ae6-49c0-9821-893e5501cf9d", + "metadata": {}, + "source": [ + "### Look up custom container images\n", + "\n", + "The training and inference jobs in this section will use the same customized container images created in the main notebook series for model training and deployment (see [Notebook 2 Model Training](2.%20Model%20Training.ipynb)): so you need to have built those first.\n", + "\n", + "The code below will check the container images are already prepared and staged in your account's Amazon Elastic Container Registry (ECR)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eb38b61f-4af4-4190-87a7-ce07a8ab2273", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Configurations:\n", + "train_repo_name = \"sm-ocr-training\"\n", + "train_repo_tag = \"hf-pt-gpu-custom\" # TODO: Check this matches your ECR repo name and tagging\n", + "inf_repo_name = \"sm-ocr-inference\"\n", + "inf_repo_tag = train_repo_tag\n", + "\n", + "account_id = sagemaker.Session().account_id()\n", + "region = os.environ[\"AWS_REGION\"]\n", + "\n", + "# Combine together into the final URIs:\n", + "train_image_uri = f\"{account_id}.dkr.ecr.{region}.amazonaws.com/{train_repo_name}:{train_repo_tag}\"\n", + "print(f\"Target training image: {train_image_uri}\")\n", + "inf_image_uri = f\"{account_id}.dkr.ecr.{region}.amazonaws.com/{inf_repo_name}:{inf_repo_tag}\"\n", + "print(f\"Target inference image: {inf_image_uri}\")\n", + "\n", + "# Check from notebook whether the images were successfully created:\n", + "ecr = boto3.client(\"ecr\")\n", + "for repo, tag, uri in (\n", + " (train_repo_name, train_repo_tag, train_image_uri),\n", + " (inf_repo_name, inf_repo_tag, inf_image_uri)\n", + "):\n", + " imgs_desc = ecr.describe_images(\n", + " registryId=account_id,\n", + " repositoryName=repo,\n", + " imageIds=[{\"imageTag\": tag}],\n", + " )\n", + " assert len(imgs_desc[\"imageDetails\"]) > 0, f\"Couldn't find ECR image {uri} after build\"\n", + " print(f\"Found {uri}\")" + ] + }, + { + "cell_type": "markdown", + "id": "d991db9b-0daf-4517-a3da-3e6b24c30a7d", + "metadata": {}, + "source": [ + "### Train a model\n", + "\n", + "With data prepared, model training is very similar to the setup from the main notebooks. Some key differences include:\n", + "\n", + "- Setting `task_name: seq2seq` to indicate we're training a sequence-to-sequence model instead of the usual layout-aware `ner`.\n", + "- Choosing a text-only pre-trained base model compatible with text generation, in this case `t5-base`.\n", + "- Since the data is synthetic, we can easily generate quite a large dataset in comparison to the amount of training we want to run: So logging, evaluation, and model saving will be controlled in terms of number of training steps rather than number of epochs (passes through the whole dataset)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "264ec056-d96f-440f-8f00-e4df0763980d", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from sagemaker.huggingface.estimator import HuggingFace as HuggingFaceEstimator\n", + "\n", + "hyperparameters = {\n", + " \"model_name_or_path\": \"t5-base\",\n", + " \"task_name\": \"seq2seq\",\n", + " \"logging_steps\": 100,\n", + " \"evaluation_strategy\": \"steps\",\n", + " \"eval_steps\": 200,\n", + " # Only need to set do_eval when validation channel is not provided and want to generate:\n", + " \"do_eval\": \"1\",\n", + " \"save_strategy\": \"steps\",\n", + " \"save_steps\": 200,\n", + " \"learning_rate\": 5e-4,\n", + " \"per_device_train_batch_size\": 2,\n", + " \"per_device_eval_batch_size\": 4,\n", + " \"seed\": 1337,\n", + "\n", + " \"num_train_epochs\": 5, # Set high to drive via early stopping\n", + " \"early_stopping_patience\": 4, # Usually stops after <25 epochs on this sample data+config\n", + " \"metric_for_best_model\": \"eval_acc\",\n", + " # \"greater_is_better\": \"false\",\n", + " # # Early stopping implies checkpointing every evaluation (epoch), so limit the total checkpoints\n", + " # # kept to avoid filling up disk:\n", + " \"save_total_limit\": 10,\n", + "}\n", + "\n", + "\n", + "metric_definitions = [\n", + " {\"Name\": \"epoch\", \"Regex\": util.training.get_hf_metric_regex(\"epoch\")},\n", + " {\"Name\": \"learning_rate\", \"Regex\": util.training.get_hf_metric_regex(\"learning_rate\")},\n", + " {\"Name\": \"train:loss\", \"Regex\": util.training.get_hf_metric_regex(\"loss\")},\n", + " {\n", + " \"Name\": \"validation:n_examples\",\n", + " \"Regex\": util.training.get_hf_metric_regex(\"eval_n_examples\"),\n", + " },\n", + " {\"Name\": \"validation:loss_avg\", \"Regex\": util.training.get_hf_metric_regex(\"eval_loss\")},\n", + " {\"Name\": \"validation:acc\", \"Regex\": util.training.get_hf_metric_regex(\"eval_acc\")},\n", + "]\n", + "\n", + "estimator = HuggingFaceEstimator(\n", + " role=sagemaker.get_execution_role(),\n", + " entry_point=\"train.py\",\n", + " source_dir=\"src\",\n", + " py_version=None,\n", + " pytorch_version=None,\n", + " transformers_version=None,\n", + " image_uri=train_image_uri, # Use the customized training container image\n", + "\n", + " base_job_name=\"t5-datenorm\",\n", + " output_path=f\"s3://{bucket_name}/{bucket_prefix}trainjobs\",\n", + "\n", + " instance_type=\"ml.g4dn.xlarge\", # Could also consider ml.p3.2xlarge\n", + " instance_count=1,\n", + " volume_size=40,\n", + "\n", + " debugger_hook_config=False,\n", + "\n", + " hyperparameters=hyperparameters,\n", + " metric_definitions=metric_definitions,\n", + " environment={\n", + " # Required for our custom dataset loading code (which depends on tokenizer):\n", + " \"TOKENIZERS_PARALLELISM\": \"false\",\n", + " },\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "f19603d2-351b-42ad-8d2e-4e709ca85c46", + "metadata": {}, + "source": [ + "There is no `textract` input data channel for this job, as both the `training` and `validation` datasets simply provide plain text.\n", + "\n", + "Run the cell below to kick off the job and view logs.\n", + "\n", + "> ⏰ In our tests, the training took about 30 minutes to complete on an `ml.g4dn.xlarge` instance in default configuration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38eafacf-0f50-4a5f-9bc6-cdda0e315bb7", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "inputs = {\n", + " \"train\": 
train_s3uri,\n", + " \"validation\": validation_s3uri,\n", + "}\n", + "\n", + "estimator.fit(inputs)" + ] + }, + { + "cell_type": "markdown", + "id": "e12f9b2d-3ad3-4845-a127-b10131a03902", + "metadata": {}, + "source": [ + "Once the training is complete, you have a model ready to normalize detected dates to specific target formats.\n", + "\n", + "As discussed in the main solution notebooks, you can also 'attach' the notebook to previously-completed training jobs as shown below:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20e00d35-f89c-48d2-ae2f-4ea253fab41b", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "#estimator = HuggingFaceEstimator.attach(\"t5-datenorm-2023-01-09-12-19-12-377\")" + ] + }, + { + "cell_type": "markdown", + "id": "1cbf1dda-4099-4dfc-89f7-dcc4314b89eb", + "metadata": {}, + "source": [ + "### Deploy for inference\n", + "\n", + "Model deployment is similar to the entity recognition and other models shown in this solution. Note that for this endpoint we'll set up a [real-time inference endpoint](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html) (not specifying an `async_inference_config` as with some other examples), and use a separate [inference_seq2seq.py](src/inference_seq2seq.py) entrypoint because the handling logic is quite different from standard `inference.py` models that consume Amazon Textract JSON." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "70b63c81-a95d-4cf2-a99d-ab215ef6afbd", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from sagemaker.huggingface import HuggingFaceModel\n", + "\n", + "# Look up the model artifact location from the training job:\n", + "training_job_desc = estimator.latest_training_job.describe()\n", + "model_s3uri = training_job_desc[\"ModelArtifacts\"][\"S3ModelArtifacts\"]\n", + "model_name = training_job_desc[\"TrainingJobName\"]\n", + "\n", + "# Make sure we don't accidentally re-use same model:\n", + "try:\n", + " smclient.delete_model(ModelName=model_name)\n", + " print(f\"Deleted existing model {model_name}\")\n", + "except smclient.exceptions.ClientError as e:\n", + " if not (\n", + " e.response[\"Error\"][\"Code\"] in (404, \"404\")\n", + " or e.response[\"Error\"].get(\"Message\", \"\").startswith(\"Could not find model\")\n", + " ):\n", + " raise e\n", + "\n", + "model = HuggingFaceModel(\n", + " name=model_name,\n", + " model_data=model_s3uri,\n", + " role=sagemaker.get_execution_role(),\n", + " source_dir=\"src/\",\n", + " entry_point=\"inference_seq2seq.py\",\n", + " py_version=None,\n", + " pytorch_version=None,\n", + " transformers_version=None,\n", + " image_uri=inf_image_uri,\n", + " env={\n", + " \"PYTHONUNBUFFERED\": \"1\", # TODO: Disable once debugging is done\n", + " },\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "93e06e94-6e4c-4d04-a03e-91c55789e432", + "metadata": {}, + "outputs": [], + "source": [ + "# Delete previous endpoint, if already in use:\n", + "try:\n", + " predictor.delete_endpoint(delete_endpoint_config=True)\n", + " print(\"Deleting previous endpoint...\")\n", + " time.sleep(8)\n", + "except (NameError, smclient.exceptions.ResourceNotFound):\n", + " pass # No existing endpoint to delete\n", + "except smclient.exceptions.ClientError as e:\n", + " if \"Could not find\" not in e.response[\"Error\"].get(\"Message\", \"\"):\n", + " raise e\n", + "\n", + "print(\"Deploying model...\")\n", + "predictor = model.deploy(\n", + " 
endpoint_name=training_job_desc[\"TrainingJobName\"],\n", + " initial_instance_count=1,\n", + " instance_type=\"ml.m5.large\",\n", + " serializer=sagemaker.serializers.JSONSerializer(),\n", + " deserializer=sagemaker.deserializers.JSONDeserializer(),\n", + ")\n", + "print(\"\\nDone!\")" + ] + }, + { + "cell_type": "markdown", + "id": "4242c33f-0539-44eb-994a-b55d9451a684", + "metadata": {}, + "source": [ + "### Validate the endpoint\n", + "\n", + "Once the model is deployed, we can run (some or all of) the evaluation dataset through it to validate performance - as shown below.\n", + "\n", + "> ⏰ In our tests, it took about a minute to run the full evaluation dataset through the model. For a faster turnaround, you could process just the first N samples of the dataset by instead running e.g. `eval_results = eval_dataset.select(range(N)).map(...`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "961278f4-1d87-42f6-9675-d7c03642ab91", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import datasets\n", + "\n", + "eval_dataset = datasets.load_from_disk(\"data/seq2seq-validation\")\n", + "\n", + "\n", + "def predict_batch(batch):\n", + " \"\"\"Run a dataset batch through the SageMaker endpoint and check per-example correctness\"\"\"\n", + " input_texts = batch[\"src_texts\"]\n", + " result = predictor.predict({\"inputs\": input_texts})\n", + " result[\"correct\"] = [\n", + " gen == batch[\"tgt_texts\"][ix] for ix, gen in enumerate(result[\"generated_text\"])\n", + " ]\n", + " return {**batch, **result}\n", + "\n", + "\n", + "eval_results = eval_dataset.map(\n", + " predict_batch,\n", + " desc=\"Running inference\",\n", + " batched=True,\n", + " batch_size=16,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "9c5c1fef-1ba1-49d8-903a-8b601a5f28c3", + "metadata": {}, + "source": [ + "Below we measure overall \"accuracy\" on this evaluation set and print out some examples, to demonstrate performance:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b26141de-2304-4ee9-a4a0-e3db89fa41ab", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Calculate overall accuracy:\n", + "n_correct = sum(eval_results[\"correct\"])\n", + "n_total = len(eval_results)\n", + "print(\n", + " \"{} of {} samples correct.\\n Overall accuracy: {:.2%}\".format(\n", + " n_correct, n_total, n_correct / n_total\n", + " )\n", + ")\n", + "\n", + "# Present some examples from the dataset:\n", + "pd.DataFrame(eval_results)" + ] + }, + { + "cell_type": "markdown", + "id": "23c9c260-9633-43c5-b9f5-e01ca23431bb", + "metadata": {}, + "source": [ + "As shown above, this text-to-text model can take in a raw detected date mention (e.g. `Sunday Dec 31st 2000`) with a prompt prefix (e.g. `Convert dates to YYYY-MM-DD: `) and attempt to output the desired normalized format (e.g. `2000-12-31`)." 
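You can also run a quick ad-hoc spot check against the live endpoint. The sketch below reuses the `predictor` deployed above; the exact text generated depends on your trained model.

```python
# Single-example request. The request/response contract matches inference_seq2seq.py:
# {"inputs": <str or list of str>} in, {"generated_text": <str or list of str>} out.
response = predictor.predict(
    {"inputs": ["Convert dates to YYYY-MM-DD: Sunday Dec 31st 2000"]}
)
print(response["generated_text"][0])  # Should print "2000-12-31" for a well-trained model
```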
+ ] + }, + { + "cell_type": "markdown", + "id": "868d1622-e1f2-4c51-8f56-cce0db913676", + "metadata": { + "tags": [] + }, + "source": [ + "### Integrate with processing pipeline\n", + "\n", + "So how can such a field normalizing model be integrated with the overall document processing pipeline?\n", + "\n", + "In fact, the **post-processing Lambda function** invoked after our entity recognition model to extract and consolidate entities, is able to call out to additional \"normalizing\" models where required.\n", + "\n", + "These are configured through the same **entity/field type configuration** we originally set up for the pipeline in [Notebook 1 (Data Preparation)](1.%20Data%20Preparation.ipynb).\n", + "\n", + "First, we can load up the current pipeline entity configuration:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "147b502b-d807-44ce-8818-0345eb059fe4", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "print(\"Loading current pipeline field configuration...\")\n", + "# Load JSON text from AWS SSM Parameter Store:\n", + "# (If this fails, you could also try reading from data/field-config.json)\n", + "fields_json = ssm.get_parameter(Name=config.entity_config_param)[\"Parameter\"][\"Value\"]\n", + "# Parse the JSON into Python config classes:\n", + "fields = [\n", + " util.postproc.config.FieldConfiguration.from_dict(cfg)\n", + " for cfg in json.loads(fields_json)\n", + "]\n", + "print(\"Done\")" + ] + }, + { + "cell_type": "markdown", + "id": "4d53907a-2610-40b2-a28a-c51687228bc6", + "metadata": {}, + "source": [ + "Next, find any entity type that looks like a date (any with 'date' in the name), and configure the normalizer for those fields:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "014d5107-40c2-4a5d-afd3-3be4a9e2a6d3", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "for f in fields:\n", + " if \"date\" in f.name.lower():\n", + " print(f\"Found date field: {f.name}\")\n", + " f.normalizer_endpoint = predictor.endpoint_name\n", + " print(f\" - Setting normalizer_endpoint = '{f.normalizer_endpoint}'\")\n", + " f.normalizer_prompt = \"Convert dates to YYYY-MM-DD: \"\n", + " print(f\" - Setting normalizer_prompt = '{f.normalizer_prompt}'\")" + ] + }, + { + "cell_type": "markdown", + "id": "bbca20a2-0d9e-46e5-bc35-e3934dd28997", + "metadata": {}, + "source": [ + "When you're happy with the updated field configuration, you can run the below to update the pipeline parameter:\n", + "\n", + "You may also like to check these updates in the [AWS Systems Manager Parameter Store console](https://console.aws.amazon.com/systems-manager/parameters/?&tab=Table)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1b9b7dff-8070-4d7a-a59a-fcc82807ab66", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "print(\"Saving new field configuration locally...\")\n", + "with open(\"data/field-config.json\", \"w\") as f:\n", + " f.write(json.dumps(\n", + " [cfg.to_dict() for cfg in fields],\n", + " indent=2,\n", + " ))\n", + "\n", + "print(\"Uploading new field configuration to pipeline...\")\n", + "pipeline_entity_config = json.dumps([f.to_dict(omit=[\"annotation_guidance\"]) for f in fields], indent=2)\n", + "ssm.put_parameter(\n", + " Name=config.entity_config_param,\n", + " Overwrite=True,\n", + " Value=pipeline_entity_config,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "91826510-9e89-479a-a5f8-054b0c4074d3", + "metadata": {}, + "source": [ + "After updating your pipeline's field configuration SSM parameter to set `normalizer_endpoint` and `normalizer_prompt` on your target entity types, your pipeline's Post-processing Lambda function should automatically start calling your SageMaker model endpoints to normalize mentions on the relevant fields. For example with the Credit Card Agreements sample data, you should see that `Agreement Effective Date` results start to show in `YYYY-MM-DD` format instead of the document's source format, when reviewing results in Amazon A2I or the Step Functions console.\n", + "\n", + "> ⚠️ **Note:** There may be a few minutes' delay before normalization starts to take effect, if your post-processing Lambda is configured to cache the SSM configuration. Check your AWS Lambda logs for error messages, in case normalization model calls are failing.\n", + "\n", + "This example of normalizing individual extracted date fields is just one option in a spectrum of ways you could combine generative and extractive models for document understanding. 
For example, you could:\n", + "\n", + "- Train additional normalization types, for example for other data types or to fix common OCR error patterns\n", + "- Include more context from around the original mention, to help the model perform better (such as interpreting whether a raw date is likely to be DD/MM or MM/DD given other information)\n", + "- Explore linking generative and layout-aware aspects into one end-to-end trainable model" + ] + }, + { + "cell_type": "markdown", + "id": "5a9b4159-ad1a-4ecf-886c-7efe183b37c4", + "metadata": {}, + "source": [ "---\n", "\n", "*[Back to contents](#Contents)*" @@ -806,9 +1422,9 @@ "metadata": { "instance_type": "ml.t3.medium", "kernelspec": { - "display_name": "Python 3 (Data Science 3.0)", + "display_name": "Python 3 (PyTorch 1.10 Python 3.8 CPU Optimized)", "language": "python", - "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/sagemaker-data-science-310-v1" + "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/pytorch-1.10-cpu-py38" }, "language_info": { "codemirror_mode": { @@ -820,7 +1436,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.8.10" } }, "nbformat": 4, diff --git a/notebooks/src/code/config.py b/notebooks/src/code/config.py index aae9c05..bfc9f9d 100644 --- a/notebooks/src/code/config.py +++ b/notebooks/src/code/config.py @@ -7,7 +7,7 @@ from dataclasses import dataclass, field import os import tarfile -from typing import Optional +from typing import Optional, Sequence, Tuple # External Dependencies: try: @@ -19,15 +19,15 @@ from transformers.trainer_utils import IntervalStrategy -def get_n_cpus(): +def get_n_cpus() -> int: return int(os.environ.get("SM_NUM_CPUS", len(os.sched_getaffinity(0)))) -def get_n_gpus(): +def get_n_gpus() -> int: return int(os.environ.get("SM_NUM_GPUS", 0)) -def get_default_num_workers(): +def get_default_num_workers() -> int: """Choose a sensible default dataloader_num_workers based on available hardware""" n_cpus = get_n_cpus() n_gpus = get_n_gpus() @@ -91,14 +91,15 @@ class SageMakerTrainingArguments(TrainingArguments): metadata={"help": "TQDM progress bars are disabled by default for SageMaker/CloudWatch."}, ) do_eval: bool = field( - # Users should not set this typical param directly - default=True, + default=None, metadata={ - "help": "This value is overridden by presence or absence of the `validation` channel" + "help": ( + "This value is normally set by the presence or absence of the 'validation' " + "channel, but can be explicitly overridden." + ) }, ) do_train: bool = field( - # Users should not set this typical param directly default=True, metadata={"help": "Set false to disable training (for validation-only jobs)"}, ) @@ -344,7 +345,8 @@ class DataTrainingArguments: default="ner", metadata={ "help": "The name of the task. This script currently supports 'ner' for entity " - "recognition or 'mlm' for pre-training (masked language modelling)." + "recognition, 'mlm' for pre-training (masked language modelling), or 'seq2seq' for " + "experimental (non-layout-aware) sequence-to-sequence data normalizations." 
}, ) textract: Optional[str] = field( @@ -408,19 +410,24 @@ class DataTrainingArguments: ) def __post_init__(self): - if not self.textract: - raise ValueError("'textract' (Folder of Textract result JSONs) channel is mandatory") self.task_name = self.task_name.lower() + if (not self.textract) and (self.task_name != "seq2seq"): + raise ValueError("'textract' (Folder of Textract result JSONs) channel is mandatory") -def parse_args(cmd_args=None): +def parse_args( + cmd_args: Optional[Sequence[str]] = None, +) -> Tuple[ModelArguments, DataTrainingArguments, SageMakerTrainingArguments]: """Parse config arguments from the command line, or cmd_args instead if supplied""" parser = HfArgumentParser((ModelArguments, DataTrainingArguments, SageMakerTrainingArguments)) model_args, data_args, training_args = parser.parse_args_into_dataclasses(args=cmd_args) - # Auto-set activities depending which channels were provided: - training_args.do_eval = bool(data_args.validation) + # Auto-set activities depending which channels were provided. + # By only overriding do_eval if it's not explicitly specified, we allow override e.g. to force + # validation in a job where no external dataset is provided but a synthetic one can be generated + if training_args.do_eval is None: + training_args.do_eval = bool(data_args.validation) if not training_args.do_eval: training_args.evaluation_strategy = "no" diff --git a/notebooks/src/code/data/__init__.py b/notebooks/src/code/data/__init__.py index b1d1f85..fac6ccc 100644 --- a/notebooks/src/code/data/__init__.py +++ b/notebooks/src/code/data/__init__.py @@ -16,6 +16,7 @@ from .base import TaskData from .mlm import get_task as get_mlm_task from .ner import get_task as get_ner_task +from .seq2seq import get_task as get_seq2seq_task def get_datasets( @@ -40,5 +41,11 @@ def get_datasets( return get_ner_task( data_args, tokenizer, processor, n_workers=n_workers, cache_dir=cache_dir ) + elif data_args.task_name == "seq2seq": + return get_seq2seq_task( + data_args, tokenizer, processor, n_workers=n_workers, cache_dir=cache_dir + ) else: - raise ValueError("Unknown task '%s' is not 'mlm' or 'ner'" % data_args.task_name) + raise ValueError( + "Unknown task '%s' is not in 'mlm', 'ner', 'seq2seq'" % data_args.task_name + ) diff --git a/notebooks/src/code/data/seq2seq/__init__.py b/notebooks/src/code/data/seq2seq/__init__.py new file mode 100644 index 0000000..95dc524 --- /dev/null +++ b/notebooks/src/code/data/seq2seq/__init__.py @@ -0,0 +1,9 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: MIT-0 +"""Data utilities for generative, sequence-to-sequence tasks + +This task is experimental, and does not currently support layout-aware models. As shown in the +'Optional Extras' notebook, you can use it to train separate post-processing models to normalize +extracted fields: For example converting the format of dates. +""" +from .task_builder import get_task diff --git a/notebooks/src/code/data/seq2seq/date_normalization.py b/notebooks/src/code/data/seq2seq/date_normalization.py new file mode 100644 index 0000000..581f8cf --- /dev/null +++ b/notebooks/src/code/data/seq2seq/date_normalization.py @@ -0,0 +1,223 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: MIT-0 +"""Synthetic dataset generation for date normalization via seq2seq text modelling + +This script provides utilities for tackling date field format normalization as a conditional +language modelling task. 
For example, training a model with input like: + + "Convert dates to YYYY-MM-DD: 31/12/2000" + +...into a target output sequence like "2000-12-31". + +In the plain text case, it's relatively straightforward to generate synthetic data for this task +as shown here. By modifying the distribution of randomly generated dates, the likelihood of +different observed formats in the source document and target formats in the prompt, we can tailor +model performance to match target use-case without having to write extensive text parsing rules. +""" +# Python Built-Ins: +from dataclasses import dataclass +from logging import getLogger +import time +from typing import List, Optional, Sequence + +# External Dependencies: +from datasets import Dataset, DatasetInfo +import numpy as np + + +logger = getLogger("data.seq2seq.dates") + + +@dataclass +class DateFormatConfig: + """Configuration describing a date format for date normalization tasks + + Parameters + ---------- + format_str : + A formal `time.strftime`-compatible specifier for the date format, for example `%Y-%m-%d`. + format_name : + A human-friendly identifier for the format, as might be used in task prompts. For example + `YYYY-MM-DD` for a prompt like "Convert dates to YYYY-MM-DD". + observed_weight : + Weight/frequency with which this date format will be observed in content, for synthetic data + generation. Does not need to be normalized to 1.0 across all your configured formats, + because the dataset generator will ensure this for you. + target_weight : + Weight/frequency with which this date format will be used as the target for prompting, for + synthetic data generation. Does not need to be normalized to 1.0 across all your configured + formats, because the dataset generator will ensure this for you. + """ + + format_str: str + format_name: str + observed_weight: float + target_weight: float + + +# Format configuration for synthetic date normalization training data generation +DATE_FORMAT_CONFIGS = [ + DateFormatConfig("%Y-%m-%d", "YYYY-MM-DD", observed_weight=0.1, target_weight=0.7), + DateFormatConfig("%d/%m/%y", "DD/MM/YY", observed_weight=0.35, target_weight=0.05), + DateFormatConfig("%d/%m/%Y", "DD/MM/YYYY", observed_weight=0.35, target_weight=0.2), + DateFormatConfig("%m/%d/%y", "MM/DD/YY", observed_weight=0.1, target_weight=0.02), + DateFormatConfig("%m/%d/%Y", "MM/DD/YYYY", observed_weight=0.05, target_weight=0.03), + # Including day names: + DateFormatConfig("%a %d %b %y", "DDD DD MM YY", observed_weight=0.05, target_weight=0.0), + DateFormatConfig("%a. %d %b %y", "DDD. 
DD MM YY", observed_weight=0.05, target_weight=0.0), + DateFormatConfig("%A %b %d %y", "DDDD MM DD YY", observed_weight=0.01, target_weight=0.0), + DateFormatConfig("%A %b %d %y", "DDDD MM DDst YY", observed_weight=0.01, target_weight=0.0), + DateFormatConfig("%A %b %d %y", "DDDD MM DDnd YY", observed_weight=0.01, target_weight=0.0), + DateFormatConfig("%A %b %d %y", "DDDD MM DDrd YY", observed_weight=0.01, target_weight=0.0), + DateFormatConfig("%A %b %d %y", "DDDD MM DDth YY", observed_weight=0.01, target_weight=0.0), + DateFormatConfig("%A, %b %d %y", "DDDD, MM DD YY", observed_weight=0.02, target_weight=0.0), + # Including times: + DateFormatConfig( + "%Y-%m-%d %H:%M:%S", "YYYY-MM-DD HH:mm:ss", observed_weight=0.02, target_weight=0.0 + ), + DateFormatConfig("%d/%m/%y %H:%M", "DD/MM/YY HH:mm", observed_weight=0.02, target_weight=0.0), + DateFormatConfig("%H:%M %d/%m/%y", "HH:mm DD/MM/YY", observed_weight=0.02, target_weight=0.0), + DateFormatConfig( + "%I:%M%p %d/%m/%Y", "hh:mmp DD/MM/YYYY", observed_weight=0.02, target_weight=0.0 + ), + DateFormatConfig("%H:%M %d/%m/%Y", "HH:mm DD/MM/YYYY", observed_weight=0.02, target_weight=0.0), + DateFormatConfig( + "%d/%m/%Y %I:%M%p", "DD/MM/YYYY hh:mmp", observed_weight=0.02, target_weight=0.0 + ), + DateFormatConfig("%d/%m/%Y %H:%M", "DD/MM/YYYY HH:mm", observed_weight=0.02, target_weight=0.0), + DateFormatConfig("%m/%d/%y", "MM/DD/YY", observed_weight=0.02, target_weight=0.0), + DateFormatConfig( + "%d/%m/%y %I:%M%p", "DD/MM/YY hh:mmp", observed_weight=0.02, target_weight=0.0 + ), + DateFormatConfig("%d/%m/%y %H:%M", "DD/MM/YY HH:mm", observed_weight=0.02, target_weight=0.0), +] + + +def random_times_between( + start: time.struct_time, + end: time.struct_time, + n: int = 1, + rng: Optional[np.random.Generator] = None, +) -> List[time.struct_time]: + """Generate uniformly random datetimes between `start` and `end` + + Parameters + ---------- + start : + Start of the date/time window (Generate with e.g. `time.strptime()`). + end : + End of the date/time window (Generate with e.g. `time.strptime()`). + n : + Number of samples to generate. + rng : + Optional numpy random generator. Provide this to speed things up and enable reproducibility. + + Returns + ------- + datetimes : + List of `n` generated date/times in the given window. You can convert these to string + representations via e.g. `time.strftime()`. 
+ """ + # Create a RNG if one was not provided: + if rng is None: + rng = np.random.default_rng() + + # To treat the struct_times as numeric (so we can add randomized offsets), convert them into + # timestamps via mktime(): + start = time.mktime(start) + end = time.mktime(end) + + # Generate random offsets as a 0-1 proportion through the window: + props = rng.uniform(size=n) + + # localtime() is the inverse of mktime(), converting timestamps back to full time structs: + max_offset = end - start + return [time.localtime(start + p * max_offset) for p in props] + + +def generate_seq2seq_date_norm_dataset( + n: int, + configs: Sequence[DateFormatConfig] = DATE_FORMAT_CONFIGS, + from_date: time.struct_time = time.strptime("1950-01-01", "%Y-%m-%d"), + to_date: time.struct_time = time.strptime("2050-01-01", "%Y-%m-%d"), + rng: Optional[np.random.Generator] = None, +) -> Dataset: + """Generate a synthetic seq2seq task dataset for date normalization in text + + Parameters + ---------- + n : + Number of examples to generate + configs : + Sequence of date format configuration objects describing the date formats to use and their + relative frequencies in source texts and target requests. + from_date : + Start of the date window that randomly generated dates should fall within. + to_date : + End of the date window that randomly generated dates should fall within. + rng : + Optional numpy random generator object. Provide this if you want reproducibility. + + Returns + ------- + dataset : + Hugging Face datasets.Dataset with fields `src_texts` (the input prompts) and `tgt_texts` + (the target outputs) for each generated example. + """ + # Create a RNG if one was not provided: + if rng is None: + rng = np.random.default_rng() + + # Normalize the observed_weights of the date format configurations: + observed_weights = [fmt.observed_weight for fmt in configs] + observed_weights_total = sum(observed_weights) + if observed_weights_total != 1.0: + logger.info(f"Normalizing observed_weights (summed to {observed_weights_total})") + observed_weights = [w / observed_weights_total for w in observed_weights] + # Select an observed format for the `n` input texts: + obs_choices = rng.choice( + len(observed_weights), + p=observed_weights, + size=(n,), + replace=True, + ) + + # Normalize the target_weights of the date format configurations + target_weights = [fmt.target_weight for fmt in configs] + target_weights_total = sum(target_weights) + if target_weights_total != 1.0: + logger.info(f"Normalizing target_weights (summed to {target_weights_total})") + target_weights = [w / target_weights_total for w in target_weights] + # Select a requested format for the `n` prompts: + target_choices = rng.choice( + len(target_weights), + p=target_weights, + size=(n,), + replace=True, + ) + + # Generate the `n` prompts & answers: + random_dates = random_times_between(from_date, to_date, n=n, rng=rng) + prompts = [] + answers = [] + for ix in range(n): + obs_config = configs[obs_choices[ix]] + target_config = configs[target_choices[ix]] + random_date = random_dates[ix] + prompt = "Convert dates to %s: %s" % ( + target_config.format_name, + time.strftime(obs_config.format_str, random_date), + ) + answer = time.strftime(target_config.format_str, random_date) + prompts.append(prompt) + answers.append(answer) + + return Dataset.from_dict( + { + "src_texts": prompts, + "tgt_texts": answers, + }, + info=DatasetInfo( + description="Synthetic dataset for T5-style seq2seq normalization of dates", + ), + ) diff --git 
a/notebooks/src/code/data/seq2seq/task_builder.py b/notebooks/src/code/data/seq2seq/task_builder.py new file mode 100644 index 0000000..5cbfe9c --- /dev/null +++ b/notebooks/src/code/data/seq2seq/task_builder.py @@ -0,0 +1,224 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: MIT-0 +"""Main 'task' builder for seq2seq tasks used by the model training script + +Collects seq2seq data (e.g. date normalization) into overall task format expected by the training +script (matching other tasks like NER, MLM). + +One interesting aspect of the conditional prompting framing of this seq2seq task, is that you could +train a single model to perform multiple kinds of field normalization by using different prompts. +For example "Convert date ...: ..." vs "Normalize currency ...: ..." and so on. + +Here we just show a single-task date normalizing example. +""" +# Python Built-Ins: +from logging import getLogger +from numbers import Real +import os +from typing import Callable, Dict, Optional, Union + +# External Dependencies: +import datasets +import numpy as np +from transformers import EvalPrediction, PreTrainedTokenizerBase +from transformers.processing_utils import ProcessorMixin +from transformers.utils.generic import PaddingStrategy, TensorType +from transformers.tokenization_utils_base import TruncationStrategy + +# Local Dependencies: +from ...config import DataTrainingArguments +from ..base import TaskData +from .date_normalization import generate_seq2seq_date_norm_dataset + + +logger = getLogger("data.seq2seq") + + +def _preprocess_seq2seq_dataset( + batch: Dict[str, list], + tokenizer: PreTrainedTokenizerBase, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_input_length: Optional[int] = None, + max_output_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Union[str, TensorType, None] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, +) -> Dict[str, list]: + """map fn to tokenize a seq2seq dataset ready for use in training + + TODO: Should we use a DataCollator for per-batch tokenization instead? 
+ """ + # encode the documents + prompts = batch["src_texts"] + answers = batch["tgt_texts"] + + # Encode the inputs: + model_inputs = tokenizer( + prompts, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_input_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + ) + + # Encode the targets: + labels = tokenizer( + answers, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_output_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + ).input_ids + + # important: we need to replace the index of the padding tokens by -100 + # such that they are not taken into account by the CrossEntropyLoss + labels_with_ignore_index = [] + for labels_example in labels: + labels_example = [label if label != 0 else -100 for label in labels_example] + labels_with_ignore_index.append(labels_example) + + model_inputs["labels"] = labels_with_ignore_index + return model_inputs + + +def get_metric_computer( + tokenizer: PreTrainedTokenizerBase, +) -> Callable[[EvalPrediction], Dict[str, Real]]: + """An 'accuracy' computer for seq2seq tasks that ignores outer whitespace and case. + + For our example task, it's reasonable to measure exact-match accuracy (since we're normalising + small text spans - not e.g. summarizing long texts to shorter paragraphs). Therefore this metric + computer checks exact accuracy, while allowing for variations in case and leading/trailing + whitespace. 
+ """ + + def compute_metrics(p: EvalPrediction) -> Dict[str, Real]: + # Convert model output probs/logits to predicted token IDs: + predicted_token_ids = np.argmax(p.predictions[0], axis=2) + # Replace everything from the first token onward with padding (as eos + # would terminate generation in a normal generate() call) + for ix_batch, seq in enumerate(predicted_token_ids): + eos_token_matches = np.where(seq == tokenizer.eos_token_id) + if len(eos_token_matches) and len(eos_token_matches[0]): + first_eos_posn = eos_token_matches[0][0] + predicted_token_ids[ix_batch, first_eos_posn:] = tokenizer.pad_token_id + + gen_texts = [ + s.strip().lower() + for s in tokenizer.batch_decode(predicted_token_ids, skip_special_tokens=True) + ] + + target_texts = [ + s.strip().lower() + for s in tokenizer.batch_decode( + # Replace label '-100' tokens (ignore index for BinaryCrossEntropy) with '0' ( + # token), to avoid an OverflowError when trying to decode the target text: + np.maximum(0, p.label_ids), + skip_special_tokens=True, + ) + ] + + n_examples = len(gen_texts) + n_correct = sum(1 for gen, target in zip(gen_texts, target_texts) if gen == target) + return { + "n_examples": len(gen_texts), + "acc": n_correct / n_examples, + } + + return compute_metrics + + +def get_task( + data_args: DataTrainingArguments, + tokenizer: PreTrainedTokenizerBase, + processor: Optional[ProcessorMixin] = None, + # model_param_names: Optional[Iterable[str]] = None, + n_workers: Optional[int] = None, + cache_dir: Optional[str] = None, +) -> TaskData: + """Load datasets and data collators for seq2seq model training""" + logger.info("Getting seq2seq datasets") + + # TODO: Currently non-reproducible, but we don't have access to CLI arg seed here + # Best practice for now would be to generate your dataset before running training anyway, + # instead of relying on ephemeral dataset generation within the job. + rng = np.random.default_rng() + + # Load or create the training and validation datasets: + if data_args.train: + logger.info("Loading seq2seq training dataset from disk %s", data_args.train) + train_dataset = datasets.load_from_disk(data_args.train) + else: + logger.info("Generating new synthetic seq2seq training dataset") + train_dataset = generate_seq2seq_date_norm_dataset(n=1000, rng=rng) + + if data_args.validation: + logger.info("Loading seq2seq validation dataset from disk %s", data_args.validation) + eval_dataset = datasets.load_from_disk(data_args.validation) + else: + logger.info("Generating new synthetic seq2seq validation dataset") + eval_dataset = generate_seq2seq_date_norm_dataset(n=200, rng=rng) + + # Pre-process the datasets with the tokenizer: + preproc_kwargs = { + "max_input_length": data_args.max_seq_length - 2, # To allow for CLS+SEP in final + "max_output_length": 64, # TODO: Parameterize? 
+ "pad_to_multiple_of": data_args.pad_to_multiple_of, + "padding": "max_length", + "tokenizer": tokenizer, + } + train_dataset = train_dataset.map( + _preprocess_seq2seq_dataset, + batched=True, + cache_file_name=(os.path.join(cache_dir, "seq2seqtrain.arrow") if cache_dir else None), + num_proc=n_workers, + remove_columns=train_dataset.column_names, + fn_kwargs=preproc_kwargs, + ) + eval_dataset = eval_dataset.map( + _preprocess_seq2seq_dataset, + batched=True, + cache_file_name=(os.path.join(cache_dir, "seq2seqeval.arrow") if cache_dir else None), + num_proc=n_workers, + remove_columns=eval_dataset.column_names, + fn_kwargs=preproc_kwargs, + ) + + return TaskData( + train_dataset=train_dataset, + data_collator=None, + eval_dataset=eval_dataset, + metric_computer=get_metric_computer(tokenizer), + ) diff --git a/notebooks/src/code/inference_seq2seq.py b/notebooks/src/code/inference_seq2seq.py new file mode 100644 index 0000000..d5f8096 --- /dev/null +++ b/notebooks/src/code/inference_seq2seq.py @@ -0,0 +1,130 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: MIT-0 +"""Alternative SageMaker inference wrapper for text-only (non-multimodal) seq2seq models + +These models are optionally deployed alongside the core layout-aware NER model, to normalize +detected entity mentions. + +API Usage +--------- + +All requests and responses in 'application/json'. The model takes a dict with key `inputs` which +may be a text string or a list of strings. It will return a dict with key `generated_text` +containing either a text string or a list of strings (as per the input). +""" + +# Python Built-Ins: +import json +import os +from typing import Dict, List, Union + +# External Dependencies: +from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline +import torch + +# Local Dependencies: +from . import logging_utils + +logger = logging_utils.getLogger("infcustom") +logger.info("Loading custom inference handlers") +# If you need to debug this script and aren't seeing any logging in CloudWatch, try setting the +# following on the Model to force flushing log calls through: env={ "PYTHONUNBUFFERED": "1" } + +# Configurations: +INFERENCE_BATCH_SIZE = int(os.environ.get("INFERENCE_BATCH_SIZE", "4")) +PAD_TO_MULTIPLE_OF = os.environ.get("PAD_TO_MULTIPLE_OF", "8") +PAD_TO_MULTIPLE_OF = None if PAD_TO_MULTIPLE_OF in ("None", "") else int(PAD_TO_MULTIPLE_OF) + + +def input_fn(input_bytes, content_type: str): + """Deserialize and pre-process model request JSON + + Requests must be of type application/json. See module-level docstring for API details. + """ + logger.info(f"Received request of type:{content_type}") + if content_type != "application/json": + raise ValueError("Content type must be application/json") + + req_json = json.loads(input_bytes) + if "inputs" not in req_json: + raise ValueError( + "Request JSON must contain field 'inputs' with either a text string or a list of text " + "strings" + ) + return req_json["inputs"] + + +# No custom output_fn needed as result is plain JSON fully prepared by predict_fn + + +def model_fn(model_dir) -> dict: + """Load model artifacts from model_dir into a dict + + Returns + ------- + pipeline : transformers.pipeline + HF Pipeline for text generation inference + """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + tokenizer = AutoTokenizer.from_pretrained( + model_dir, + pad_to_multiple_of=PAD_TO_MULTIPLE_OF, + # TODO: Is it helpful to use_fast=True? 
+ ) + model = AutoModelForSeq2SeqLM.from_pretrained(model_dir) + model.eval() + model.to(device) + + pl = pipeline( + "text2text-generation", + model=model, + tokenizer=tokenizer, + batch_size=INFERENCE_BATCH_SIZE, + # num_workers as per default + device=model.device, + ) + + logger.info("Model loaded") + return { + # Could return other objects e.g. `model` and `tokenizer`` for debugging + "pipeline": pl, + } + + +def predict_fn( + input_data: Union[str, List[str]], + model_data: dict, +) -> Dict[str, Union[str, List[str]]]: + """Generate text outputs from an input or list of inputs + + Parameters + ---------- + input_data : + Input text string or list of input text strings (including prompts if needed) + model_data : { pipeline } + Trained model data loaded by model_fn, including a `pipeline`. + + Returns + ------- + result : + Dict including key `generated_text`, which will either be a text string (if `input_data` was + a single string) or a list of strings (if `input_data` was a list). + """ + pl = model_data["pipeline"] + + # Use transformers Pipelines to simplify the inference process and handle e.g. batching and + # tokenization for us: + result = pl(input_data, clean_up_tokenization_spaces=True) + + # Convert output from list of dicts to dict of lists: + result = {k: [r[k] for r in result] for k in result[0].keys()} + # Strip any leading/trailing whitespace from results: + result["generated_text"] = [t.strip() for t in result["generated_text"]] + + # If input was a plain string (instead of a list of strings), remove the batch dimension from + # outputs too: + if isinstance(input_data, str): + for k in result: + result[k] = result[k][0] + + return result diff --git a/notebooks/src/code/models/__init__.py b/notebooks/src/code/models/__init__.py new file mode 100644 index 0000000..74ca562 --- /dev/null +++ b/notebooks/src/code/models/__init__.py @@ -0,0 +1,3 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: MIT-0 +"""Some model implementation customizations""" diff --git a/notebooks/src/code/train.py b/notebooks/src/code/train.py index 5b9288a..2c6998d 100644 --- a/notebooks/src/code/train.py +++ b/notebooks/src/code/train.py @@ -1,16 +1,21 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: MIT-0 """Train HuggingFace LayoutLM on Amazon Textract results + +(This script also allows for training non-layout-aware models e.g. 
T5 on seq2seq conditional text +generation task) """ # Python Built-Ins: from inspect import signature import os import shutil +from typing import Optional, Tuple # External Dependencies: from torch import distributed as dist from transformers import ( AutoConfig, + AutoModelForSeq2SeqLM, AutoModelForMaskedLM, AutoModelForTokenClassification, AutoProcessor, @@ -19,8 +24,13 @@ LayoutLMv2Config, LayoutXLMProcessor, LayoutXLMTokenizerFast, + PretrainedConfig, + PreTrainedModel, + PreTrainedTokenizerBase, PreTrainedTokenizerFast, + ProcessorMixin, set_seed, + Trainer, ) from transformers.file_utils import EntryNotFoundError from transformers.trainer_utils import get_last_checkpoint @@ -35,8 +45,10 @@ logger = logging_utils.getLogger("main") -def get_model(model_args: config.ModelArguments, data_args: config.DataTrainingArguments): - """Load pre-trained Config, Model and Tokenizer""" +def get_model( + model_args: config.ModelArguments, data_args: config.DataTrainingArguments +) -> Tuple[PretrainedConfig, PreTrainedModel, PreTrainedTokenizerFast, Optional[ProcessorMixin]]: + """Load pre-trained Config, Model, Tokenizer, and Processor if one exists""" config = AutoConfig.from_pretrained( model_args.config_name if model_args.config_name else model_args.model_name_or_path, num_labels=data_args.num_labels, @@ -75,7 +87,15 @@ def get_model(model_args: config.ModelArguments, data_args: config.DataTrainingA revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, ) - tokenizer = processor.tokenizer + if hasattr(processor, "tokenizer"): + tokenizer = processor.tokenizer + elif isinstance(processor, PreTrainedTokenizerBase): + # AutoProcessor loaded something, but it's just a standard tokenizer. + # This happens e.g. with t5-base model as at HF transformers==4.25.1 + tokenizer = processor + processor = None + else: + tokenizer = None except (EntryNotFoundError, OSError): processor = None tokenizer = None @@ -154,6 +174,15 @@ def get_model(model_args: config.ModelArguments, data_args: config.DataTrainingA revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, ) + elif data_args.task_name == "seq2seq": + model = AutoModelForSeq2SeqLM.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) else: raise ValueError( f"Unknown data_args.task_name '{data_args.task_name}' not in ('mlm', 'ner')" @@ -165,7 +194,7 @@ def train( model_args: config.ModelArguments, data_args: config.DataTrainingArguments, training_args: config.SageMakerTrainingArguments, -): +) -> Trainer: training_args._setup_devices # Force distributed setup if applicable and not already done logger.info("Started with local_rank %s", training_args.local_rank) # Don't strictly need this around the model setup too, but keeps logs more understandable: @@ -239,7 +268,7 @@ def train( model=model, args=training_args, train_dataset=datasets.train_dataset, - eval_dataset=datasets.eval_dataset if data_args.validation else None, + eval_dataset=datasets.eval_dataset, # No `tokenizer`, as either the dataset or the data_collator does it for us data_collator=datasets.data_collator, callbacks=[ @@ -323,7 +352,7 @@ def train( return trainer -def main(): +def main() -> None: """CLI script entry point to parse arguments and run training""" model_args, data_args, training_args = 
config.parse_args() diff --git a/notebooks/src/inference_seq2seq.py b/notebooks/src/inference_seq2seq.py new file mode 100644 index 0000000..09e01d1 --- /dev/null +++ b/notebooks/src/inference_seq2seq.py @@ -0,0 +1,5 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: MIT-0 +"""Load alternative inference handlers for seq2seq model deployment +""" +from code.inference_seq2seq import * diff --git a/pipeline/postprocessing/fn-postprocess/main.py b/pipeline/postprocessing/fn-postprocess/main.py index 48a133f..b8f8bb3 100644 --- a/pipeline/postprocessing/fn-postprocess/main.py +++ b/pipeline/postprocessing/fn-postprocess/main.py @@ -27,19 +27,21 @@ import json import logging import os -from typing import List # External Dependencies: import boto3 # General-purpose AWS SDK for Python import trp # Amazon Textract Response Parser +# Set up logging before local imports: +logger = logging.getLogger() +logger.setLevel(logging.INFO) + # Local Dependencies -from util.boxes import UniversalBox from util.config import FieldConfiguration +from util.extract import extract_entities +from util.normalize import normalize_detections -logger = logging.getLogger() -logger.setLevel(logging.INFO) s3 = boto3.resource("s3") ssm = boto3.client("ssm") @@ -74,7 +76,10 @@ def handler(event, context): doc = json.loads(s3.Bucket(srcbucket).Object(srckey).get()["Body"].read()) doc = trp.Document(doc) + # Pull out the entities from the Amazon Textract-format doc: entities = extract_entities(doc, entity_config) + # Normalize entity values, if any per-type normalizations are configured: + normalize_detections(entities, entity_config) result_fields = {} for ixtype, cfg in enumerate(cfg for cfg in entity_config if not cfg.ignore): @@ -173,91 +178,3 @@ def handler(event, context): ), "Fields": result_fields, } - - -class EntityDetection: - def __init__(self, trp_words, cls_id: int, cls_name: str, page_num: int): - self.cls_id = cls_id - self.cls_name = cls_name - self.page_num = page_num - - if len(trp_words) and not hasattr(trp_words[0], "id"): - trp_words_by_line = trp_words - trp_words_flat = [w for ws in trp_words for w in ws] - - else: - trp_words_by_line = [trp_words] - trp_words_flat = trp_words - self.bbox = UniversalBox.aggregate( - boxes=[UniversalBox(box=w.geometry.boundingBox) for w in trp_words_flat], - ) - self.blocks = list(map(lambda w: w.id, trp_words_flat)) - self.confidence = min( - map( - lambda w: min( - w._block.get("PredictedClassConfidence", 1.0), - w.confidence, - ), - trp_words_flat, - ) - ) - self.text = "\n".join( - map( - lambda words: " ".join([w.text for w in words]), - trp_words_by_line, - ) - ) - - def to_dict(self): - return { - "ClassId": self.cls_id, - "ClassName": self.cls_name, - "Confidence": self.confidence, - "Blocks": self.blocks, - "BoundingBox": self.bbox.to_dict(), - "PageNum": self.page_num, - "Text": self.text, - } - - def __repr__(self): - return json.dumps(self.to_dict()) - - -def extract_entities( - doc: trp.Document, - entity_config: List[FieldConfiguration], -) -> List[EntityDetection]: - entity_classes = {c.class_id: c.name for c in entity_config if not c.ignore} - detections = [] - - current_cls = None - current_entity = [] - for ixpage, page in enumerate(doc.pages): - for line in page.lines: # TODO: Lines InReadingOrder? 
- current_entity.append([]) - for word in line.words: - pred_cls = word._block.get("PredictedClass") - if pred_cls not in entity_classes: - pred_cls = None # Treat all non-config'd entities as "other" - - if pred_cls != current_cls: - if current_cls is not None: - detections.append( - EntityDetection( - trp_words=list( - filter( - lambda ws: len(ws), - current_entity, - ) - ), - cls_id=current_cls, - cls_name=entity_classes[current_cls], - page_num=ixpage + 1, - ) - ) - current_cls = pred_cls - current_entity = [[]] if pred_cls is None else [[word]] - elif pred_cls is not None: - current_entity[-1].append(word) - - return detections diff --git a/pipeline/postprocessing/fn-postprocess/util/config.py b/pipeline/postprocessing/fn-postprocess/util/config.py index 08aa3da..b4b1d9f 100644 --- a/pipeline/postprocessing/fn-postprocess/util/config.py +++ b/pipeline/postprocessing/fn-postprocess/util/config.py @@ -39,6 +39,8 @@ def __init__( optional: Optional[bool] = None, select: Optional[str] = None, annotation_guidance: Optional[str] = None, + normalizer_endpoint: Optional[str] = None, + normalizer_prompt: Optional[str] = None, ): """Create a FieldConfiguration @@ -61,12 +63,20 @@ def __init__( annotation_guidance : Optional[str] HTML-tagged guidance detailing the specific scope for this entity: I.e. what should and should not be included for consistent labelling. + normalizer_endpoint : Optional[str] + An optional deployed SageMaker seq2seq endpoint for field value normalization, if one + should be used (You'll have to train and deploy this endpoint separately). + normalizer_prompt : Optional[str] + The prompting prefix for the seq2seq field value normalization requests on this field, + if enabled. For example, "Convert dates to YYYY-MM-DD: " """ self.class_id = class_id self.name = name self.ignore = ignore self.optional = optional self.annotation_guidance = annotation_guidance + self.normalizer_endpoint = normalizer_endpoint + self.normalizer_prompt = normalizer_prompt try: self.select = FieldSelectionMethods[select.upper()].value if select else None except KeyError as e: @@ -77,3 +87,9 @@ def __init__( [fsm.name for fsm in FieldSelectionMethods], ) ) from e + if bool(self.normalizer_endpoint) ^ bool(self.normalizer_prompt): + raise ValueError( + "Cannot provide only one of `normalizer_endpoint` and `normalizer_prompt` without " + "setting both. Got: '%s' and '%s'" + % (self.normalizer_endpoint, self.normalizer_prompt) + ) diff --git a/pipeline/postprocessing/fn-postprocess/util/extract.py b/pipeline/postprocessing/fn-postprocess/util/extract.py new file mode 100644 index 0000000..3e6dc40 --- /dev/null +++ b/pipeline/postprocessing/fn-postprocess/util/extract.py @@ -0,0 +1,131 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: MIT-0 +"""Utils to extract entity mentions from SageMaker Textract WORD-tagging model results + +As a simple heuristic, consecutive WORD blocks of the same tagged entity class are tagged as +belonging to the same mention. This means that in cases where the normal human reading order +diverges from the Amazon Textract block output order, mentions may get split up. 
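+
+For example (illustrative words and classes only), if the words of one text LINE are tagged::
+
+    Agreement   dated    January   1,    2023    between
+    (other)     (other)  DATE      DATE  DATE    (other)
+
+then the three consecutive DATE words are grouped into a single mention with text "January 1, 2023".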
+""" +# Python Built-Ins: +import json +from typing import List, Optional, Sequence + +# External Dependencies: +import trp # Amazon Textract Response Parser + +# Local Dependencies: +from .boxes import UniversalBox +from .config import FieldConfiguration + + +class EntityDetection: + """Object describing an entity mention in a document + + If property `raw_text` (or 'RawText' in the JSON-ified equivalent) is set, this mention has + been normalized. Otherwise, `text` is as per the original document. + """ + + raw_text: Optional[str] + + def __init__(self, trp_words: Sequence[trp.Word], cls_id: int, cls_name: str, page_num: int): + self.cls_id = cls_id + self.cls_name = cls_name + self.page_num = page_num + + if len(trp_words) and not hasattr(trp_words[0], "id"): + trp_words_by_line = trp_words + trp_words_flat = [w for ws in trp_words for w in ws] + + else: + trp_words_by_line = [trp_words] + trp_words_flat = trp_words + self.bbox = UniversalBox.aggregate( + boxes=[UniversalBox(box=w.geometry.boundingBox) for w in trp_words_flat], + ) + self.blocks = list(map(lambda w: w.id, trp_words_flat)) + self.confidence = min( + map( + lambda w: min( + w._block.get("PredictedClassConfidence", 1.0), + w.confidence, + ), + trp_words_flat, + ) + ) + self.text = "\n".join( + map( + lambda words: " ".join([w.text for w in words]), + trp_words_by_line, + ) + ) + self.raw_text = None + + def normalize(self, normalized_text: str) -> None: + """Update the detection with a new normalized text value + + Only the original raw_text value will be preserved, so if you normalize() multiple times no + record of the intermediate normalized_text values will be kept. + """ + if self.raw_text is None: + self.raw_text = self.text + # Otherwise keep original 'raw' text (normalize called multiple times) + self.text = normalized_text + + def to_dict(self) -> dict: + """Represent this mention as a PascalCase JSON-able object""" + result = { + "ClassId": self.cls_id, + "ClassName": self.cls_name, + "Confidence": self.confidence, + "Blocks": self.blocks, + "BoundingBox": self.bbox.to_dict(), + "PageNum": self.page_num, + "Text": self.text, + } + if self.raw_text is not None: + result["RawText"] = self.raw_text + return result + + def __repr__(self) -> str: + return json.dumps(self.to_dict()) + + +def extract_entities( + doc: trp.Document, + entity_config: List[FieldConfiguration], +) -> List[EntityDetection]: + """Collect EntityDetections from an NER-enriched Textract JSON doc into a flat list""" + entity_classes = {c.class_id: c.name for c in entity_config if not c.ignore} + detections = [] + + current_cls = None + current_entity = [] + for ixpage, page in enumerate(doc.pages): + for line in page.lines: # TODO: Lines InReadingOrder? 
+ current_entity.append([]) + for word in line.words: + pred_cls = word._block.get("PredictedClass") + if pred_cls not in entity_classes: + pred_cls = None # Treat all non-config'd entities as "other" + + if pred_cls != current_cls: + if current_cls is not None: + detections.append( + EntityDetection( + trp_words=list( + filter( + lambda ws: len(ws), + current_entity, + ) + ), + cls_id=current_cls, + cls_name=entity_classes[current_cls], + page_num=ixpage + 1, + ) + ) + current_cls = pred_cls + current_entity = [[]] if pred_cls is None else [[word]] + elif pred_cls is not None: + current_entity[-1].append(word) + + return detections diff --git a/pipeline/postprocessing/fn-postprocess/util/normalize.py b/pipeline/postprocessing/fn-postprocess/util/normalize.py new file mode 100644 index 0000000..14fd520 --- /dev/null +++ b/pipeline/postprocessing/fn-postprocess/util/normalize.py @@ -0,0 +1,91 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: MIT-0 +"""Utils to normalize detected entity text by calling SageMaker sequence-to-sequence model endpoints + +`normalizer_endpoint` on a FieldConfiguration is assumed to be a deployed real-time SageMaker +endpoint that accepts batched 'application/json' requests of structure: +`{"inputs": ["list", "of", "strings"]}`, and returns 'application/json' responses of structure: +`{"generated_text": ["corresponding", "result", "strings"]}` +""" +# Python Built-Ins: +import json +from logging import getLogger +from typing import Dict, List, Sequence + +# External Dependencies: +import boto3 # General-purpose AWS SDK for Python + +# Local Dependencies: +from .config import FieldConfiguration +from .extract import EntityDetection + +logger = getLogger("postproc") +smruntime = boto3.client("sagemaker-runtime") + + +def normalize_detections( + detections: Sequence[EntityDetection], + entity_config: Sequence[FieldConfiguration], +) -> None: + """Normalize detected entities in-place via batched requests to SageMaker normalizer endpoint(s) + + Due to the high likelihood of one document featuring multiple matches of the same text for the + same entity class, we de-duplicate requests by target endpoint and input text - and duplicate + the result across all linked detections. 
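+
+    For example (an illustrative sketch - the endpoint name and texts are made up), if detections
+    0 and 3 both read "1 Jan 23" for a field configured with prompt "Convert dates to YYYY-MM-DD: "
+    and endpoint "my-normalizer", only one request entry is built::
+
+        {"my-normalizer": {"Convert dates to YYYY-MM-DD: 1 Jan 23": [0, 3]}}
+
+    and the single generated output is written back to both detections via `normalize()`.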
+ """ + entity_config_by_clsid = {c.class_id: c for c in entity_config if not c.ignore} + + # Batched normalization requests: + # - By target endpoint name + # - By input text (after adding prompt prefix) + # - List of which detections (indexes) correspond to the request + norm_requests: Dict[str, Dict[str, List[int]]] = {} + + # Collect required normalization requests from the detections: + for ixdet, detection in enumerate(detections): + config = entity_config_by_clsid.get(detection.cls_id) + if not config: + continue # Ignore any detections in non-configured classes + if not config.normalizer_endpoint: + continue # This entity class configuration has no normalizer + if config.normalizer_endpoint not in norm_requests: + norm_requests[config.normalizer_endpoint] = {} + + norm_input_text = config.normalizer_prompt + detection.text + if norm_input_text in norm_requests[config.normalizer_endpoint]: + norm_requests[config.normalizer_endpoint][norm_input_text].append(ixdet) + else: + norm_requests[config.normalizer_endpoint][norm_input_text] = [ixdet] + + # Call out to the SageMaker endpoints and update the detections with the results: + for endpoint_name in norm_requests: + req_dict = norm_requests[endpoint_name] + input_texts = [k for k in req_dict] + try: + norm_resp = smruntime.invoke_endpoint( + EndpointName=endpoint_name, + Body=json.dumps( + { + "inputs": input_texts, + } + ), + ContentType="application/json", + Accept="application/json", + ) + # Response should be JSON dict containing list 'generated_text' of outputs: + output_texts = json.loads(norm_resp["Body"].read())["generated_text"] + except Exception: + # Log the failure, but continue on: + logger.exception( + "Entity normalization call failed: %s texts to endpoint '%s'", + len(input_texts), + endpoint_name, + ) + continue + + for ixtext, output in enumerate(output_texts): + for ixdetection in req_dict[input_texts[ixtext]]: + detections[ixdetection].normalize(output) + + # Return nothing to explicitly indicate that detections are modified in-place + return From 5a0d82331406bb812396c24517b9b1be0f85580f Mon Sep 17 00:00:00 2001 From: Alex Thewsey Date: Fri, 13 Jan 2023 02:08:20 +0800 Subject: [PATCH 04/14] fix(src): Allow non-fast seq2seq tok for byT5 Allow non-fast tokenizers for seq2seq modelling tasks, which is necessary if wanting to use byT5 model in seq2seq mode. The previous restriction requiring Fast Tokenizers matters for ner and mlm tasks where custom data collation expects a particular API, but not for seq2seq. --- notebooks/src/code/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/notebooks/src/code/train.py b/notebooks/src/code/train.py index 2c6998d..bd67beb 100644 --- a/notebooks/src/code/train.py +++ b/notebooks/src/code/train.py @@ -221,8 +221,8 @@ def train( # https://sagemaker.readthedocs.io/en/stable/api/training/smd_data_parallel_use_sm_pysdk.html # For SM Distributed, ddp_launcher.py is not necessary - point straight to train.py - # Tokenizer check: this script requires a fast tokenizer. - if not isinstance(tokenizer, PreTrainedTokenizerFast): + # Tokenizer check: Our MLM/NER data prep requires a fast tokenizer. + if data_args.task_name in ("mlm", "ner") and not isinstance(tokenizer, PreTrainedTokenizerFast): raise ValueError( "This example script only works for models that have a fast tokenizer. See the list " "at https://huggingface.co/transformers/index.html#supported-frameworks for details." 
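
Note: byT5 checkpoints ship only a "slow" ByT5Tokenizer (no fast implementation as at HF
transformers 4.25), so the fast-tokenizer requirement cannot apply to seq2seq. A quick way to check
which kind a given checkpoint loads (a sketch, assuming default AutoTokenizer settings):

    from transformers import AutoTokenizer, PreTrainedTokenizerFast

    tok = AutoTokenizer.from_pretrained("google/byt5-small")
    print(isinstance(tok, PreTrainedTokenizerFast))  # False for byT5; True for e.g. "t5-small"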
From 461fcbabe9a17a9fc905fccd25f1683a24f96b9b Mon Sep 17 00:00:00 2001 From: Alex Thewsey Date: Fri, 13 Jan 2023 02:12:07 +0800 Subject: [PATCH 05/14] style(lint): black-format src folder --- notebooks/src/code/train.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/notebooks/src/code/train.py b/notebooks/src/code/train.py index bd67beb..dcee936 100644 --- a/notebooks/src/code/train.py +++ b/notebooks/src/code/train.py @@ -222,7 +222,9 @@ def train( # For SM Distributed, ddp_launcher.py is not necessary - point straight to train.py # Tokenizer check: Our MLM/NER data prep requires a fast tokenizer. - if data_args.task_name in ("mlm", "ner") and not isinstance(tokenizer, PreTrainedTokenizerFast): + if data_args.task_name in ("mlm", "ner") and not isinstance( + tokenizer, PreTrainedTokenizerFast + ): raise ValueError( "This example script only works for models that have a fast tokenizer. See the list " "at https://huggingface.co/transformers/index.html#supported-frameworks for details." From 69bf9339d5dd24f4574052d4570723bb96a3e960 Mon Sep 17 00:00:00 2001 From: Alex Thewsey Date: Fri, 13 Jan 2023 02:14:07 +0800 Subject: [PATCH 06/14] feat(seq2seq): Weight sample dates to US formats Re-weight sample date re-formatting task from UK date formats towards US (month-first) formats, for better consistency with the CFPB credit cards sample documents. --- .../code/data/seq2seq/date_normalization.py | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/notebooks/src/code/data/seq2seq/date_normalization.py b/notebooks/src/code/data/seq2seq/date_normalization.py index 581f8cf..a5c0567 100644 --- a/notebooks/src/code/data/seq2seq/date_normalization.py +++ b/notebooks/src/code/data/seq2seq/date_normalization.py @@ -58,19 +58,21 @@ class DateFormatConfig: # Format configuration for synthetic date normalization training data generation DATE_FORMAT_CONFIGS = [ DateFormatConfig("%Y-%m-%d", "YYYY-MM-DD", observed_weight=0.1, target_weight=0.7), - DateFormatConfig("%d/%m/%y", "DD/MM/YY", observed_weight=0.35, target_weight=0.05), - DateFormatConfig("%d/%m/%Y", "DD/MM/YYYY", observed_weight=0.35, target_weight=0.2), - DateFormatConfig("%m/%d/%y", "MM/DD/YY", observed_weight=0.1, target_weight=0.02), - DateFormatConfig("%m/%d/%Y", "MM/DD/YYYY", observed_weight=0.05, target_weight=0.03), - # Including day names: - DateFormatConfig("%a %d %b %y", "DDD DD MM YY", observed_weight=0.05, target_weight=0.0), - DateFormatConfig("%a. %d %b %y", "DDD. 
DD MM YY", observed_weight=0.05, target_weight=0.0), - DateFormatConfig("%A %b %d %y", "DDDD MM DD YY", observed_weight=0.01, target_weight=0.0), - DateFormatConfig("%A %b %d %y", "DDDD MM DDst YY", observed_weight=0.01, target_weight=0.0), - DateFormatConfig("%A %b %d %y", "DDDD MM DDnd YY", observed_weight=0.01, target_weight=0.0), - DateFormatConfig("%A %b %d %y", "DDDD MM DDrd YY", observed_weight=0.01, target_weight=0.0), - DateFormatConfig("%A %b %d %y", "DDDD MM DDth YY", observed_weight=0.01, target_weight=0.0), - DateFormatConfig("%A, %b %d %y", "DDDD, MM DD YY", observed_weight=0.02, target_weight=0.0), + DateFormatConfig("%m/%d/%y", "MM/DD/YY", observed_weight=0.35, target_weight=0.05), + DateFormatConfig("%m/%d/%Y", "MM/DD/YYYY", observed_weight=0.35, target_weight=0.2), + DateFormatConfig("%d/%m/%y", "DD/MM/YY", observed_weight=0.05, target_weight=0.02), + DateFormatConfig("%d/%m/%Y", "DD/MM/YYYY", observed_weight=0.04, target_weight=0.03), + # Including day names and month names: + DateFormatConfig("%A %b %d %y", "DDDD MMM DD YY", observed_weight=0.03, target_weight=0.0), + DateFormatConfig("%A, %b %d %y", "DDDD, MMM DD YY", observed_weight=0.02, target_weight=0.0), + DateFormatConfig("%a %b %d, %y", "DDD MMM DD, YY", observed_weight=0.02, target_weight=0.0), + DateFormatConfig("%a. %b %d %y", "DDD. MMM DD YY", observed_weight=0.02, target_weight=0.0), + DateFormatConfig("%A %b %d %y", "DDDD MMM DDst YY", observed_weight=0.01, target_weight=0.0), + DateFormatConfig("%A %b %d %y", "DDDD MMM DDnd YY", observed_weight=0.01, target_weight=0.0), + DateFormatConfig("%A %b %d %y", "DDDD MMM DDrd YY", observed_weight=0.01, target_weight=0.0), + DateFormatConfig("%A %b %d %y", "DDDD MMM DDth YY", observed_weight=0.01, target_weight=0.0), + DateFormatConfig("%a %d %b %y", "DDD DD MMM YY", observed_weight=0.02, target_weight=0.0), + DateFormatConfig("%a. %d %b %y", "DDD. DD MMM YY", observed_weight=0.02, target_weight=0.0), # Including times: DateFormatConfig( "%Y-%m-%d %H:%M:%S", "YYYY-MM-DD HH:mm:ss", observed_weight=0.02, target_weight=0.0 From 9bb41826e44e4eca25975ac3b3606ceb262a6382 Mon Sep 17 00:00:00 2001 From: Alex Thewsey Date: Fri, 13 Jan 2023 21:45:20 +0800 Subject: [PATCH 07/14] feat(label): Support custom tpl SMGT jobs from NB1 Enable/streamline creating a SMGT labelling job using the custom (transcription reviews) template from data preparation notebook. With addition of the seq2seq model, more users are likely to be interested in collecting text normalizations. --- annotation/__init__.py | 18 +++++-- notebooks/1. 
Data Preparation.ipynb | 31 ++++++------ notebooks/util/smgt.py | 73 +++++++++++++++++++---------- 3 files changed, 81 insertions(+), 41 deletions(-) diff --git a/annotation/__init__.py b/annotation/__init__.py index bd0c2d9..0fea1e8 100644 --- a/annotation/__init__.py +++ b/annotation/__init__.py @@ -192,9 +192,10 @@ def __init__(self, scope: Construct, id: str, **kwargs): self._pre_lambda = PythonFunction( self, # Include 'LabelingFunction' in the name so the entities with the - # AmazoSageMakerGroundTruthExecution policy will automatically have access to call it: + # AmazonSageMakerGroundTruthExecution policy will automatically have access to call it: # https://console.aws.amazon.com/iam/home?#/policies/arn:aws:iam::aws:policy/AmazonSageMakerGroundTruthExecution - "SMGT-LabelingFunction-Pre", + # (Of course this won't work if construct so deeply nested the name is cut off) + "PreLabelingFunction", entry=PRE_LAMBDA_PATH, index="main.py", handler="handler", @@ -205,7 +206,7 @@ def __init__(self, scope: Construct, id: str, **kwargs): ) self._post_lambda = PythonFunction( self, - "SMGT-LabelingFunction-Post", + "PostLabelingFunction", entry=POST_LAMBDA_PATH, index="main.py", handler="handler", @@ -247,4 +248,15 @@ def get_data_science_policy_statements(self) -> List[PolicyStatement]: effect=Effect.ALLOW, resources=["arn:aws:codebuild:*:*:project/sagemaker-studio*"], ), + PolicyStatement( + sid="InvokeCustomSMGTLambdas", + actions=[ + "lambda:InvokeFunction", + ], + effect=Effect.ALLOW, + resources=[ + self.pre_lambda.function_arn, + self.post_lambda.function_arn, + ], + ), ] diff --git a/notebooks/1. Data Preparation.ipynb b/notebooks/1. Data Preparation.ipynb index ffede11..6d11ccf 100644 --- a/notebooks/1. Data Preparation.ipynb +++ b/notebooks/1. Data Preparation.ipynb @@ -1505,19 +1505,13 @@ "source": [ "Note that:\n", "\n", - "- When you draw a bounding box on the page image, a new OCR result is populated in the left sidebar prompting you to review (and if necessary correct) Textract's transcription of the text in that region.\n", - "- Overlapping bounding boxes of the same type are consolidated, allowing us to highlight non-square regions of text (for example a particular sentence over multiple lines within a paragraph).\n", - "- Transcription review fields are mandatory: The template should not let you submit the result until all transcriptions have been reviewed.\n", + "- When you draw a bounding box on the page image, a new OCR result is populated in the left sidebar prompting you to review (and if necessary, correct) Textract's transcription of the text in that region.\n", + "- Overlapping bounding boxes of the same type are combined, allowing you to highlight non-square regions of text.\n", + "- Transcription review fields are mandatory: The template should not let you submit the result until all transcriptions have been reviewed in the sidebar.\n", "\n", - "You should aim to follow these same conventions when annotating the sample data, even with the built-in task type. Under the hood, the ML model code applies similar logic to map your bounding box annotations to the Textract detected `WORD`s and `LINE`s.\n", + "You should aim to follow the same overlapping conventions when annotating the sample data, even with the built-in task type. 
Under the hood, the ML model code applies similar logic to map your bounding box annotations to the Textract detected `WORD`s and `LINE`s.\n", "\n", - "To use this custom template in a data labeling job, you can adjust the instructions below (which assume you'll use the faster built-in template) as follows:\n", - "\n", - "- Select task category 'Custom' > task type 'Custom', instead of 'Image > Bounding Box'\n", - "- For template body, copy the contents of the `*.liquid.html` file above (**NOT** the `*.tpl.liquid.html`, which has placeholders e.g. for the list of classes)\n", - "- In the tool configuration step, select the `SMGT-Pre` and `SMGT-Post` Lambda functions that have been created for you by the solution stack: These should appear in the drop-down options.\n", - "\n", - "In practice, while it's important to explore how the bounding boxes are being interpreted, we'd recommend to use the simpler built-in template for this walkthrough: To help you complete your data annotation faster." + "Since reviewing transcriptions makes labelling take longer, you'll probably want to stick to the built-in task template instead of this custom one, unless you have a use case for collecting the text data. The [Exploring sequence-to-sequence models section](Optional%20Extras.ipynb#Exploring-sequence-to-sequence-models) of the **Optional Extras** notebook discusses training generative models to actually normalize and correct OCR transcriptions." ] }, { @@ -1723,7 +1717,7 @@ "\n", "To minimize the risk of errors and get started quickly, you're recommended to create your labeling job by running the utility function provided below.\n", "\n", - "This will set up a job with the default pre-built bounding box template:" + "This will set up a job with the **default pre-built bounding box** template by default, for faster labelling. If you want to collect OCR transcription reviews as well (for example, to measure OCR accuracy on your fields of interest or train text normalizing models), you can un-comment and fill in the extra optional arguments to use the custom template we saw earlier." 
] }, { @@ -1743,9 +1737,15 @@ " input_manifest_s3uri=input_manifest_s3uri,\n", " output_s3uri=annotations_base_s3uri,\n", " workteam_arn=workteam_arn,\n", + " s3_inputs_prefix=f\"{bucket_prefix}data/manifests\",\n", + "\n", " # To create a review/adjustment job from a manifest with existing labels in:\n", " # reviewing_attribute_name=\"label\",\n", - " s3_inputs_prefix=f\"{bucket_prefix}data/manifests\",\n", + "\n", + " # To use the custom task template (adding transcription review):\n", + " # task_template=\"annotation/ocr-bbox-and-validation.liquid.html\",\n", + " # pre_lambda_arn=\"arn:aws...{AnnotationInfra->PreLabelingFunction Lambda from your CFn stack}\",\n", + " # post_lambda_arn=\"arn:aws...{AnnotationInfra->PostLabelingFunction Lambda from your CFn stack}\",\n", ")\n", "print(f\"\\nLABELLING JOB STARTED:\\n{create_labeling_job_resp['LabelingJobArn']}\")" ] @@ -1761,9 +1761,12 @@ " - The `input_manifest_s3uri` (`s3://[...].jsonl`) from above for the input location\n", " - The `annotations_base_s3uri` (`s3://[...]/data/annotations`) with **no trailing slash** for the output location\n", "- Select or create any **SageMaker IAM execution role** that has access to the `bucket_name` we're using.\n", - "- For **task type**, select *Image > Bounding Box*\n", + "- For **task type**, select *Image > Bounding Box* (for the default bounding-box UI) or *Custom* (for the custom UI with OCR transcription reviews)\n", "- On the second screen, be sure to use **worker type** *Private* and select the workteam we made earlier from the dropdown.\n", "- For the built-in task type, you'll need to enter the **labels** manually exactly in the order that we defined them in this notebook.\n", + "- For the custom task type:\n", + " - Copy the contents of the `*.liquid.html` file above into the *Template body* section (**NOT** the `*.tpl.liquid.html`, which has placeholders e.g. for the list of classes)\n", + " - Select the `PreLabelingFunction` and `PostLabelingFunction` Lambda functions created by the Pipeline CDK stack (you can find these via the [CloudFormation console](https://console.aws.amazon.com/cloudformation/home?#/stacks))\n", "\n", "The cell below prints out some of these values to help:" ] diff --git a/notebooks/util/smgt.py b/notebooks/util/smgt.py index 6b73fdf..4e416cf 100644 --- a/notebooks/util/smgt.py +++ b/notebooks/util/smgt.py @@ -296,35 +296,49 @@ def create_bbox_labeling_job( local_inputs_folder: str = os.path.join("data", "manifests"), reviewing_attribute_name: Optional[str] = None, s3_inputs_prefix: str = "data/manifests", + task_template: Optional[str] = None, + pre_lambda_arn: Optional[str] = None, + post_lambda_arn: Optional[str] = None, ) -> dict: """Create a SageMaker Ground Truth labelling job with the built-in Bounding Box task UI Parameters ---------- - job_name : str + job_name : Name of the job to create (must be unique in your AWS Account+Region) - bucket_name : str + bucket_name : Name of the S3 bucket where input/output manifests and job metadata will be stored - execution_role_arn : str + execution_role_arn : ARN of the SageMaker Execution Role (in AWS IAM) that the labelling job will run as. The role must have permission to access your selected `bucket_name`. fields : Iterable[FieldConfiguration] Field/entity types list - input_manifest_s3uri : str + input_manifest_s3uri : 's3://...' URI where the input JSON-Lines manifest file is (already) stored - output_s3uri : str + output_s3uri : 's3://...' 
URI where the job output should be stored (SMGT will add a job subfolder) - workteam_arn : str + workteam_arn : ARN of the SageMaker Ground Truth workteam who will be performing the task - local_inputs_folder : str + local_inputs_folder : Local folder where configuration files for SMGT will be stored before uploading to S3. (Default 'data/manifests') reviewing_attribute_name : Optional[str] Set the name of the manifest attribute where existing labels are stored, to trigger an adjustment job on pre-existing labels. (Default None) - s3_inputs_prefix : str + s3_inputs_prefix : Key prefix (with or without trailing slash) under which configuration files for SMGT will be uploaded to the S3 bucket_name. (Default 'data/manifests') + task_template : + Optional custom task template file (local path). If not provided, the standard SMGT Bounding + Box task UI will be used. + pre_lambda_arn : + Override AWS Lambda ARN for Ground Truth task pre-processing. When unset, the default + pre-processing Lambda for SMGT Bounding Box (adjustment) task UI will be used. Set this + parameter to use your own function instead. + post_lambda_arn : + Override AWS Lambda ARN for Ground Truth task post-processing. When unset, the default + post-processing Lambda for SMGT Bounding Box (adjustment) task UI will be used. Set this + parameter to use your own function instead. Returns ------- @@ -354,23 +368,26 @@ def create_bbox_labeling_job( bucket.upload_file(input_category_file, input_category_s3key) print(f"Uploaded Labeling Category Config {input_category_file} to:\n{input_category_s3uri}") - # Generate and upload the task template; - task_template_file = os.path.join(local_inputs_folder, f"{job_name}.liquid.html") + # Generate and upload the task template: task_template_s3key = "/".join((s3_inputs_prefix, f"{job_name}.liquid.html")) task_template_s3uri = f"s3://{bucket_name}/{task_template_s3key}" - with open(task_template_file, "w") as f: - f.write( - get_bbox_template( - header="Highlight the entities with bounding boxes", - instructions_short=( - label_category_config.get("instructions", {}).get("shortInstruction", "") - ), - instructions_full=( - label_category_config.get("instructions", {}).get("fullInstruction", "") - ), - reviewing_attribute_name=reviewing_attribute_name, + if task_template is None: + task_template_file = os.path.join(local_inputs_folder, f"{job_name}.liquid.html") + with open(task_template_file, "w") as f: + f.write( + get_bbox_template( + header="Highlight the entities with bounding boxes", + instructions_short=( + label_category_config.get("instructions", {}).get("shortInstruction", "") + ), + instructions_full=( + label_category_config.get("instructions", {}).get("fullInstruction", "") + ), + reviewing_attribute_name=reviewing_attribute_name, + ) ) - ) + else: + task_template_file = task_template bucket.upload_file(task_template_file, task_template_s3key) print(f"Uploaded resolved task UI template {task_template_file} to:\n{task_template_s3uri}") @@ -403,7 +420,11 @@ def create_bbox_labeling_job( "UiConfig": { "UiTemplateS3Uri": task_template_s3uri, }, - "PreHumanTaskLambdaArn": get_smgt_lambda_arn(pre=True, task=task), + "PreHumanTaskLambdaArn": ( + get_smgt_lambda_arn(pre=True, task=task) + if pre_lambda_arn is None + else pre_lambda_arn + ), "TaskTitle": "Credit Card Agreement Entities", "TaskDescription": "Highlight the entities with bounding boxes", "NumberOfHumanWorkersPerDataObject": 1, @@ -411,7 +432,11 @@ def create_bbox_labeling_job( "TaskAvailabilityLifetimeInSeconds": 10 * 24 * 
60 * 60, "MaxConcurrentTaskCount": 250, "AnnotationConsolidationConfig": { - "AnnotationConsolidationLambdaArn": get_smgt_lambda_arn(pre=False, task=task), + "AnnotationConsolidationLambdaArn": ( + get_smgt_lambda_arn(pre=False, task=task) + if post_lambda_arn is None + else post_lambda_arn + ), }, }, ) From d3c2ca8f5520bd65879fac318094d7865302d18c Mon Sep 17 00:00:00 2001 From: Alex Thewsey Date: Fri, 13 Jan 2023 22:15:57 +0800 Subject: [PATCH 08/14] fix(seq2seq): 1st/Nth date format generator fix Date normalization generators were not correctly adding ordinal suffixes to day numbers in some formats as intended: Just decorating the format name without changing the actual format string. --- notebooks/src/code/data/seq2seq/date_normalization.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/notebooks/src/code/data/seq2seq/date_normalization.py b/notebooks/src/code/data/seq2seq/date_normalization.py index a5c0567..d58d293 100644 --- a/notebooks/src/code/data/seq2seq/date_normalization.py +++ b/notebooks/src/code/data/seq2seq/date_normalization.py @@ -67,10 +67,10 @@ class DateFormatConfig: DateFormatConfig("%A, %b %d %y", "DDDD, MMM DD YY", observed_weight=0.02, target_weight=0.0), DateFormatConfig("%a %b %d, %y", "DDD MMM DD, YY", observed_weight=0.02, target_weight=0.0), DateFormatConfig("%a. %b %d %y", "DDD. MMM DD YY", observed_weight=0.02, target_weight=0.0), - DateFormatConfig("%A %b %d %y", "DDDD MMM DDst YY", observed_weight=0.01, target_weight=0.0), - DateFormatConfig("%A %b %d %y", "DDDD MMM DDnd YY", observed_weight=0.01, target_weight=0.0), - DateFormatConfig("%A %b %d %y", "DDDD MMM DDrd YY", observed_weight=0.01, target_weight=0.0), - DateFormatConfig("%A %b %d %y", "DDDD MMM DDth YY", observed_weight=0.01, target_weight=0.0), + DateFormatConfig("%A %b %dst %y", "DDDD MMM DDst YY", observed_weight=0.01, target_weight=0.0), + DateFormatConfig("%A %b %dnd %y", "DDDD MMM DDnd YY", observed_weight=0.01, target_weight=0.0), + DateFormatConfig("%A %b %drd %y", "DDDD MMM DDrd YY", observed_weight=0.01, target_weight=0.0), + DateFormatConfig("%A %b %dth %y", "DDDD MMM DDth YY", observed_weight=0.01, target_weight=0.0), DateFormatConfig("%a %d %b %y", "DDD DD MMM YY", observed_weight=0.02, target_weight=0.0), DateFormatConfig("%a. %d %b %y", "DDD. DD MMM YY", observed_weight=0.02, target_weight=0.0), # Including times: From 9c2401ccc515554a44f53f289a3ec4b0e9465875 Mon Sep 17 00:00:00 2001 From: Alex Thewsey Date: Mon, 16 Jan 2023 14:19:07 +0800 Subject: [PATCH 09/14] fix(smgt): Py logging setup on pre/post Lambda Explicitly set up Python logging level in pre/post SMGT Lambda functions to ensure (only) expected messages are generated. Fixes issue of .info() calls not coming through from postproc Lambda by default. --- annotation/fn-SMGT-Post/main.py | 1 + annotation/fn-SMGT-Pre/main.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/annotation/fn-SMGT-Post/main.py b/annotation/fn-SMGT-Post/main.py index aa867e5..3740ac5 100644 --- a/annotation/fn-SMGT-Post/main.py +++ b/annotation/fn-SMGT-Post/main.py @@ -12,6 +12,7 @@ import boto3 # AWS SDK for Python logger = logging.getLogger() +logger.setLevel(logging.INFO) s3 = boto3.client("s3") diff --git a/annotation/fn-SMGT-Pre/main.py b/annotation/fn-SMGT-Pre/main.py index 340e07c..fc0cf40 100644 --- a/annotation/fn-SMGT-Pre/main.py +++ b/annotation/fn-SMGT-Pre/main.py @@ -1,10 +1,13 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# SPDX-License-Identifier: MIT-0 """A minimal Lambda function for pre-processing SageMaker Ground Truth custom annotation tasks + +Just passes through the event's `dataObject` unchanged. """ import logging logger = logging.getLogger() +logger.setLevel(logging.INFO) def handler(event, context): From 47e7769e203a8d39de7eda963ec58f4e7be9ecb9 Mon Sep 17 00:00:00 2001 From: Alex Thewsey Date: Tue, 17 Jan 2023 11:51:54 +0800 Subject: [PATCH 10/14] feat(smgt): Simpler custom task output format Simplify and improve the output format of labelling jobs using the custom (bboxes + OCR transcript reviews) task UI. The new template avoids duplicating the bboxes and recording Textract word IDs (which are very verbose), and makes it easier to pull out source/OCR text vs target/corrected text for each field (useful for seq2seq). --- annotation/fn-SMGT-Post/data_model.py | 279 +++++++++++++++++ annotation/fn-SMGT-Post/main.py | 212 ++++++------- annotation/fn-SMGT-Post/smgt.py | 288 ++++++++++++++++++ .../ocr-bbox-and-validation.liquid.tpl.html | 23 +- 4 files changed, 678 insertions(+), 124 deletions(-) create mode 100644 annotation/fn-SMGT-Post/data_model.py create mode 100644 annotation/fn-SMGT-Post/smgt.py diff --git a/annotation/fn-SMGT-Post/data_model.py b/annotation/fn-SMGT-Post/data_model.py new file mode 100644 index 0000000..5b65341 --- /dev/null +++ b/annotation/fn-SMGT-Post/data_model.py @@ -0,0 +1,279 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: MIT-0 +"""Parsers, data models, and utilities for our custom (OCR-oriented) components + +This module contains code for parsing and translating data objects from our custom OCR review +SMGT task template, ready for consolidation into the final output manifest. +""" +# Python Built-Ins: +from __future__ import annotations +from dataclasses import dataclass +from enum import Enum +import json +from logging import getLogger +import re +from typing import List, Optional + +# Local Dependencies: +from smgt import BaseJsonable, BaseObjectParser, SMGTOutputBoundingBox + +logger = getLogger("data_model") + + +class OCRReviewStatus(str, Enum): + """Ternary status for OCR transcription review""" + + correct = "correct" + unclear = "unclear" + wrong = "wrong" + + +@dataclass +class SMGTOCREntity(BaseJsonable, BaseObjectParser): + """BBox+transcript review OCR entity annotation, as used in consolidation + + This class `parse()`s from raw template output format and serializes to final output manifest + format - so it's a bit specific to consolidation/post-processing Lambda as written. + + Attributes + ---------- + detection_id : + Auto-generated identifier assigned to each bounding box cluster/group by the UI template. + ocr_status : + Parsed status of the OCR transcription review (correct, unclear, wrong). + box_ixs : + Indexes of the bounding boxes in the main crowd-bounding-box result that this entity + corresponds to. + class_id : + Numeric ID of the entity type/class (either this or string label should be known). + label : + String name of the entity type/class (either this or the numeric class_id should be known). + raw_text : + The raw text for the entity as detected by OCR tool. + target_text : + The target/normalized text as overridden by the user. 
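+
+    Examples
+    --------
+    Typical usage during consolidation (a sketch; `parent_obj` is one worker's raw answer data)::
+
+        for det_id in SMGTOCREntity.find_detection_ids(parent_obj):
+            entity = SMGTOCREntity.parse(parent_obj, det_id)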
+ """ + + detection_id: str + ocr_status: OCRReviewStatus + box_ixs: List[int] + class_id: Optional[int] = None + label: Optional[str] = None + raw_text: Optional[str] = None + target_text: Optional[str] = None + + @classmethod + def find_detection_ids(cls, parent_obj: dict) -> List[str]: + """Find all auto-generated entity/detection IDs in top-level custom task output data + + Because of the mechanics of the SM Crowd HTML Elements and the template, there are multiple + keys in the annotation output storing each entity's raw data. This function discovers + available entity/detection IDs in a result. + + Parameters + ---------- + parent_obj : + Top-level annotation data object as output by the UI task template, containing multiple + fields. + """ + return sorted( + set( + map( + lambda m: m.group(1), + filter( + lambda m: m, + map( + lambda key: re.match(r"ocr-(.*)-[a-z]+", key, flags=re.IGNORECASE), + parent_obj.keys(), + ), + ), + ), + ), + ) + + @classmethod + def parse( + cls, + parent_obj: dict, + detection_id: str, + boxes: Optional[List[SMGTOutputBoundingBox]] = None, + ) -> SMGTOCREntity: + """Parse the entity with given ID from the *whole annotation object* + + Use the `find_detection_ids()` method to look up available IDs in the top-level annotation + data, then this parser to extract each ID. + + Parameters + ---------- + parent_obj : + Top-level annotation data object as output by the UI task template, containing multiple + fields. + detection_id : + Specific entity/group ID to extract for this entity + boxes : + If provided, these will simply be used to validate the tagged `boxIxs` in the entity + annotation are within range of the crowd-bounding-box tool's output. + + Raises + ------ + ValueError + If missing data or inconsistencies prevent the entity from being parsed from raw data. + """ + meta_field_key = f"ocr-{detection_id}-meta" + if meta_field_key not in parent_obj: + raise ValueError( + "OCR annotation metadata key %s not found in raw data" % meta_field_key, + ) + + meta = json.loads(parent_obj[meta_field_key]) + box_ixs = meta["boxIxs"] + if len(box_ixs) < 1: + raise ValueError( + "OCR annotation has no linked box annotations: %s" % detection_id, + ) + label = meta.get("label") + class_id = meta.get("labelId") + raw_text = meta.get("ocrText") + if boxes is not None: + n_boxes = len(boxes) + illegal_box_ixs = [ix >= 0 and ix < n_boxes for ix in box_ixs] + if len(illegal_box_ixs) > 0: + raise ValueError( + "OCR annotation '%s' links to boxIxs outside the range 0-%s: %s" + % (detection_id, n_boxes, illegal_box_ixs) + ) + if label is None: + label = boxes[box_ixs[0]].label + if class_id is None: + class_id = boxes[box_ixs[0]].class_id + + OCR_STATUSES = tuple(s.value for s in OCRReviewStatus) # String enum to Tuple[str] + ocr_status_fields = [f"ocr-{detection_id}-{status}" for status in OCR_STATUSES] + unknown_statuses = [ + s for ix, s in enumerate(OCR_STATUSES) if ocr_status_fields[ix] not in parent_obj + ] + if len(unknown_statuses): + logger.warning( + "OCR annotation %s could not determine whether the following statuses were " + "selected: %s", + detection_id, + unknown_statuses, + ) + selected_statuses = [ + s + for ix, s in enumerate(OCR_STATUSES) + if parent_obj.get(ocr_status_fields[ix], {}).get("on") + ] + n_selected_statuses = len(selected_statuses) + if n_selected_statuses == 1: + parsed_status = OCRReviewStatus[selected_statuses[0]] + elif n_selected_statuses >= 1: + logger.warning( + "OCR annotation %s selected %s statuses: %s. 
Marking as 'unclear'", + detection_id, + n_selected_statuses, + selected_statuses, + ) + parsed_status = OCRReviewStatus.unclear + else: # (0 selected statuses) + logger.warning( # TODO: push warnings through to output manifest? + "Missing OCR review status for annotation %s. Assuming 'unclear'", + detection_id, + ) + parsed_status = OCRReviewStatus.unclear + + if parsed_status == OCRReviewStatus.correct: + target_text = raw_text + else: + correction_field_key = f"ocr-{detection_id}-override" + target_text = parent_obj.get(correction_field_key) + if parsed_status == OCRReviewStatus.wrong and correction_field_key not in parent_obj: + logger.warning( + "OCR annotation %s tagged as 'wrong', but target text field %s is missing", + detection_id, + correction_field_key, + ) + + return SMGTOCREntity( + detection_id=detection_id, + ocr_status=parsed_status, + box_ixs=box_ixs, + class_id=class_id, + label=label, + raw_text=raw_text, + target_text=target_text, + ) + + +@dataclass +class SMGTWorkerAnnotation(BaseJsonable, BaseObjectParser): + """One worker's full annotation for a page using the custom bbox+transcript review task UI + + This class `parse()`s from raw template output format and serializes to final output manifest + format - so it's a bit specific to consolidation/post-processing Lambda as written. + + Attributes + ---------- + boxes : + Parsed SMGT crowd-bounding-box boxes as labelled + entities : + Parsed OCR "entities" (bounding box groupings with transcription accuracy reviews) + image_height : + Input image height in pixels + image_width : + Input image width in pixels + image_depth : + Input image number of channels (usually 1 grayscale or 3 RGB) if known. + """ + + boxes: List[SMGTOutputBoundingBox] + entities: List[SMGTOCREntity] + image_height: int + image_width: int + image_depth: Optional[int] = None + + @classmethod + def parse( + cls, + obj: dict, + class_list: Optional[List[str]] = None, + crowd_bounding_box_name: str = "boxtool", + ) -> SMGTWorkerAnnotation: + boxtool_data = obj[crowd_bounding_box_name] + image_props = boxtool_data["inputImageProperties"] + image_height = image_props["height"] + image_width = image_props["width"] + image_depth = image_props.get("depth") + + boxes = [ + SMGTOutputBoundingBox.parse(box, class_list=class_list) + for box in boxtool_data["boundingBoxes"] + ] + entity_detection_ids = SMGTOCREntity.find_detection_ids(obj) + entities = [] + for det_id in entity_detection_ids: + try: + entities.append(SMGTOCREntity.parse(obj, det_id)) + except Exception: + logger.exception("Failed to load annotated entity %s", det_id) + # TODO: Propagate failed entity extractions as warnings to output too? 
+ + return cls( + boxes=boxes, + entities=entities, + image_height=image_height, + image_width=image_width, + image_depth=image_depth, + ) + + def to_jsonable(self) -> dict: + img_meta = {"height": self.image_height, "width": self.image_width} + if self.image_depth is not None: + img_meta["depth"] = self.image_depth + return { + # Image metadata and bounding boxes in format compatible with built-in BBox task: + "image_size": [img_meta], + "annotations": [box.to_jsonable() for box in self.boxes], + # Additional data for OCR transcription reviews: + "entities": [entity.to_jsonable() for entity in self.entities], + } diff --git a/annotation/fn-SMGT-Post/main.py b/annotation/fn-SMGT-Post/main.py index 3740ac5..44d4e0c 100644 --- a/annotation/fn-SMGT-Post/main.py +++ b/annotation/fn-SMGT-Post/main.py @@ -5,128 +5,110 @@ # Python Built-Ins: import json import logging -import re -from urllib.parse import urlparse +from typing import List, Optional # External Dependencies: import boto3 # AWS SDK for Python +# Set up logger before local imports: logger = logging.getLogger() logger.setLevel(logging.INFO) + +# Local Dependencies: +from data_model import SMGTWorkerAnnotation # Custom task data model (edit if needed!) +from smgt import ( # Generic SageMaker Ground Truth parsers/utilities + ConsolidationRequest, + ObjectAnnotationResult, + PostConsolidationDatum, +) + + s3 = boto3.client("s3") -def handler(event, context): - consolidated_labels = [] - - parsed_url = urlparse(event["payload"]["s3Uri"]) - logger.info("Consolidating labels from %s", event["payload"]["s3Uri"]) - textFile = s3.get_object(Bucket=parsed_url.netloc, Key=parsed_url.path[1:]) - filecont = textFile["Body"].read() - annotations = json.loads(filecont) - - for dataset in annotations: - dataset_worker_anns = [] - consolidated_label = { - "workerAnnotations": dataset_worker_anns, - } - dataset_warnings = [] - - label = { - "datasetObjectId": dataset["datasetObjectId"], - "consolidatedAnnotation": { - "content": { - event["labelAttributeName"]: consolidated_label, - }, - }, - } - - for annotation in dataset["annotations"]: - ann_raw = json.loads(annotation["annotationData"]["content"]) - ann_data = json.loads(annotation["annotationData"]["content"]) # (Deep clone of raw) - ann_data["workerId"] = annotation["workerId"] - # Find the unique OCR annotation IDs: - ann_ocr_ids = set( - map( - lambda m: m.group(1), - filter( - lambda m: m, - map( - lambda key: re.match(r"ocr-(.*)-[a-z]+", key, flags=re.IGNORECASE), - ann_raw.keys(), - ), - ), - ), +def consolidate_object_annotations( + object_data: ObjectAnnotationResult, + label_attribute_name: str, + label_categories: Optional[List[str]] = None, +) -> PostConsolidationDatum: + """Consolidate the (potentially multiple) raw worker annotations for a dataset object + + TODO: Actual consolidation/reconciliation of multiple labels is not yet supported! + + This function just takes the "first" (not necessarily clock-first) worker's result and outputs + a warning if others were found. + + Parameters + ---------- + object_data : + Object describing the raw annotations and metadata for a particular task in the SMGT job + label_attribute_name : + Target attribute on the output object to store consolidated label results (note this may + not be the *only* attribute set/updated on the output object, hence provided as a param + rather than abstracted away). + label_categories : + Label categories specified when creating the labelling job. 
If provided, this is used to + translate from class names to numeric class_id similarly to SMGT's built-in bounding box + task result. + """ + warn_msgs: List[str] = [] + worker_anns: List[SMGTWorkerAnnotation] = [] + for worker_ann in object_data.annotations: + ann_raw = worker_ann.fetch_data() + worker_anns.append(SMGTWorkerAnnotation.parse(ann_raw, class_list=label_categories)) + + if len(worker_anns) > 1: + warn_msg = ( + "Reconciliation of multiple worker annotations is not currently implemented for this " + "post-processor. Outputting annotation from worker %s and ignoring labels from %s" + % ( + object_data.annotations[0].worker_id, + [a.worker_id for a in object_data.annotations[1:]], ) - # Normalize the OCR labels for this annotation: - ocr_ann_data = [] - ann_data["ocrAnnotations"] = ocr_ann_data - for ocr_id in ann_ocr_ids: - meta_field_key = f"ocr-{ocr_id}-meta" - if meta_field_key in ann_data: - ocr_datum = json.loads(ann_data[meta_field_key]) - del ann_data[meta_field_key] - else: - ocr_datum = {} - ocr_datum["annotationId"] = ocr_id - - # Consolidate the field's status from (potentially missing/inconsistent) radios: - ocr_statuses = ("correct", "unclear", "wrong") - ocr_status_fields = [f"ocr-{ocr_id}-{s}" for s in ocr_statuses] - unknown_statuses = [ - s for ix, s in enumerate(ocr_statuses) if ocr_status_fields[ix] not in ann_data - ] - selected_statuses = [ - s - for ix, s in enumerate(ocr_statuses) - if ann_data.get(ocr_status_fields[ix], {}).get("on") - ] - if len(selected_statuses) >= 1: - ocr_datum["status"] = selected_statuses[0] - else: - dataset_warnings.append( - f"Missing correct/unclear/wrong status for OCR field {ocr_id}", - ) - if len(selected_statuses) > 1: - dataset_warnings.append( - "OCR field {} tagged to multiple statuses {}: Taking first value".format( - ocr_id, - selected_statuses, - ) - ) - if len(unknown_statuses): - dataset_warnings.append( - "".join( - "Could not determine whether the following statuses were selected ", - "for OCR field {}: {}", - ).format( - ocr_id, - unknown_statuses, - ) - ) - for key in ocr_status_fields: - if key in ann_data: - del ann_data[key] - - # Load in the correction text, if provided: - correction_field_key = f"ocr-{ocr_id}-override" - if correction_field_key in ann_data: - # Ignore correction if 'wrong' was not selected: - if "wrong" in selected_statuses: - ocr_datum["correction"] = ann_data[correction_field_key] - # Tidy up the raw field regardless: - del ann_data[correction_field_key] - - ocr_ann_data.append(ocr_datum) - dataset_worker_anns.append(ann_data) - - if len(dataset_warnings): - consolidated_label["consolidationWarnings"] = dataset_warnings - if len(dataset_worker_anns): - # Take first annotation as 'consolidated' value: - for key in dataset_worker_anns[0]: - if key not in consolidated_label: - consolidated_label[key] = dataset_worker_anns[0][key] - consolidated_labels.append(label) - - return consolidated_labels + ) + logger.warning(warn_msg) + warn_msgs.append(warn_msg) + + consolidated_label = worker_anns[0].to_jsonable() + if len(warn_msgs): + consolidated_label["consolidationWarnings"] = warn_msgs + + return PostConsolidationDatum( + dataset_object_id=object_data.dataset_object_id, + consolidated_content={ + label_attribute_name: consolidated_label, + # Note: In our tests it's not possible to add a f"{label_attribute_name}-meta" field + # here - it gets replaced by whatever post-processing happens, instead of merged. 
+ }, + ) + + +def handler(event: dict, context) -> List[dict]: + """Main Lambda handler for consolidation of SMGT worker annotations + + This function receives a batched request to consolidate (multiple?) workers' annotations for + multiple objects, and outputs the consolidated results per object. For more docs see: + + https://docs.aws.amazon.com/sagemaker/latest/dg/sms-custom-templates-step3-lambda-requirements.html + """ + logger.info("Received event: %s", json.dumps(event)) + req = ConsolidationRequest.parse(event) + if req.label_categories and len(req.label_categories) > 0: + label_cats = req.label_categories + else: + logger.warning( + "Label categories list (see CreateLabelingJob.LabelCategoryConfigS3Uri) was not " + "provided when creating this job. Post-consolidation outputs will be incompatible with " + "built-in Bounding Box task, because we're unable to map class names to numeric IDs." + ) + label_cats = None + + # Loop through the objects in this batch, consolidating annotations for each: + return [ + consolidate_object_annotations( + object_data, + label_attribute_name=req.label_attribute_name, + label_categories=label_cats, + ).to_jsonable() + for object_data in req.fetch_object_annotations() + ] diff --git a/annotation/fn-SMGT-Post/smgt.py b/annotation/fn-SMGT-Post/smgt.py new file mode 100644 index 0000000..8be2948 --- /dev/null +++ b/annotation/fn-SMGT-Post/smgt.py @@ -0,0 +1,288 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: MIT-0 +"""Parsers, data models, and utilities for generic SageMaker Ground Truth objects + +This module contains code for dealing with SMGT intermediate formats and consolidation requests +in general (not specific to our particular custom template design). +""" +# Python Built-Ins: +from __future__ import annotations +from abc import ABC +from dataclasses import dataclass +import json +import logging +from typing import List, Optional, Union +from urllib.parse import urlparse + +# External Dependencies: +import boto3 # AWS SDK for Python + +logger = logging.getLogger("smgt") +s3client = boto3.client("s3") + + +class BaseObjectParser(ABC): + """Base interface for classes that can be created by parse()ing some (JSON?) 
object""" + + @classmethod + def parse(cls, obj: Union[dict, float, int, list, str]) -> BaseObjectParser: + raise NotImplementedError("Parsers must implement parse() method") + + +class BaseJsonable(ABC): + """Base interface for classes that can be represented by a JSON-serializable object""" + + def to_jsonable(self) -> Union[dict, float, int, list, str]: + raise NotImplementedError("BaseJsonable classes must implement to_jsonable() method") + + +class S3OrInlineObject: + """Wrapper class for API dicts that contain either `content` (inline) or `s3Uri`""" + + def __init__(self, obj: dict): + if "content" in obj: + self._inline = True + self._raw = obj["content"] + elif "s3Uri" in obj: + self._inline = False + self._raw = obj["s3Uri"] + else: + raise ValueError("API object expected to contain either 'content' or 's3 key: %s" % obj) + + def fetch(self) -> Union[bytes, str]: + """Load the text content (either inline or from S3 object)""" + if self._inline: + return self._raw + else: + logger.info("Fetching S3 object %s", self._raw) + parsed_url = urlparse(self._raw) + text_file = s3client.get_object(Bucket=parsed_url.netloc, Key=parsed_url.path[1:]) + return text_file["Body"].read() + + +@dataclass +class WorkerAnnotation(BaseObjectParser): + """One worker's raw annotation for an object + + Attributes + ---------- + worker_id : + Opaque worker identifier, for example something like "private.us-east-1.e47e1e0123456789" + for an internal workforce in us-east-1. + """ + + worker_id: str + _annotation_data: S3OrInlineObject + + @classmethod + def parse(cls, obj: dict) -> WorkerAnnotation: + return cls( + worker_id=obj["workerId"], _annotation_data=S3OrInlineObject(obj["annotationData"]) + ) + + def fetch_data(self) -> dict: + """Fetch (and JSON-parse) the worker's annotation for this object""" + return json.loads(self._annotation_data.fetch()) + + +@dataclass +class ObjectAnnotationResult(BaseObjectParser): + """One dataset object's pre-consolidation annotations (from multiple workers) + + Attributes + ---------- + dataset_object_id : + Index of the object in the SMGT job dataset + data_object : + Main input object (i.e. just the source-ref or source, not the whole manifest line) for the + task + annotations : + List of raw annotations from possibly multiple workers. + """ + + dataset_object_id: str + data_object: S3OrInlineObject + annotations: List[WorkerAnnotation] + + @classmethod + def parse(cls, obj: dict) -> ObjectAnnotationResult: + return cls( + dataset_object_id=obj["datasetObjectId"], + data_object=S3OrInlineObject(obj["dataObject"]), + annotations=[WorkerAnnotation.parse(o) for o in obj["annotations"]], + ) + + +@dataclass +class ConsolidationRequest(BaseObjectParser): + """Loaded `event` for this post-annotation Lambda function + + See: + https://docs.aws.amazon.com/sagemaker/latest/dg/sms-custom-templates-step3-lambda-requirements.html#sms-custom-templates-step3-postlambda + + Attributes + ---------- + version : + A version number used internally by Ground Truth + labelingJobArn : + The Amazon Resource Name, or ARN, of your labeling job. This ARN can be used to reference + the labeling job when using Ground Truth API operations such as DescribeLabelingJob. + labelCategories : + Includes the label categories and other attributes you either specified in the console, or + that you include in the label category configuration file. + labelAttributeName : + Either the name of your labeling job, or the label attribute name you specify when you + create the labeling job. 
+ roleArn : + The Amazon Resource Name (ARN) of the IAM execution role you specify when you create the + labeling job. + payload : + Raw annotation data for this request. + """ + + version: str + labeling_job_arn: str + label_categories: List[str] + label_attribute_name: str + role_arn: str + payload: S3OrInlineObject + + @classmethod + def parse(cls, obj: dict) -> ConsolidationRequest: + return cls( + version=obj["version"], + labeling_job_arn=obj["labelingJobArn"], + label_categories=obj.get("labelCategories", []), + label_attribute_name=obj["labelAttributeName"], + role_arn=obj["roleArn"], + payload=S3OrInlineObject(obj["payload"]), + ) + + def fetch_object_annotations(self) -> List[ObjectAnnotationResult]: + """Fetch and parse the list of object raw annotation data included for this request""" + logger.info("Fetching consolidation request payload (raw annotation data)") + payload_data = self.payload.fetch() + logger.info("Parsing raw annotation list") + payload_data = json.loads(payload_data) + if not isinstance(payload_data, list): + raise ValueError( + "Expected consolidation request.payload to point to a JSON list file, but top-level" + "object after parsing was of type: %s" % type(payload_data) + ) + return [ObjectAnnotationResult.parse(object_data) for object_data in payload_data] + + +@dataclass +class PostConsolidationDatum(BaseJsonable): + """Expected output format for each dataset object in the consolidation request""" + + dataset_object_id: str + consolidated_content: Union[dict, BaseJsonable] + + def to_jsonable(self) -> dict: + if hasattr(self.consolidated_content, "to_jsonable"): + content = self.consolidated_content.to_jsonable() + else: + content = self.consolidated_content + return { + "datasetObjectId": self.dataset_object_id, + "consolidatedAnnotation": {"content": content}, + } + + +@dataclass +class SMGTOutputBoundingBox(BaseJsonable, BaseObjectParser): + """Maybe-slightly-customized SageMaker Ground Truth bounding box data model + + TODO: Review whether we should have two separate models here? + + SMGT built-in bounding box jobs produce boxes with numeric `class_id` in the output... But the + crowd-bounding-box annotator tool produces raw outputs with string `label` to identify the + selected class. + + It's possible to convert between the two in the post-annotation Lambda function, + assuming your SMGT job was set up with the `LabelCategoryConfigS3Uri` parameter (in which case + the Lambda will receive the class name list). + + *However*, the built-in task outputs a `"class-map": {"0": "Name0", ...}` key in the `-meta` + field to link from numeric IDs back to string labels. AFAICT we can't output `-meta` field data + with custom task post-processing Lambdas, as it seems to get overwritten in the output. + + So instead, this class allows you to specify the class list at parse() time and will serialize + *both* class_id number and label string into the output box, if both are known. 
+ + Attributes + ---------- + top : + Absolute top of the bounding box relative to the page origin (in pixels) + left : + Absolute left of the bounding box relative to the page origin (in pixels) + height : + Absolute height of the bounding box in pixels + width : + Absolute width of the bounding box in pixels + label : + String class/type name of the bounding box (if known) + class_id : + 0-based integer class/type ID of the bounding box (if known) + """ + + top: int + left: int + height: int + width: int + label: Optional[str] = None + class_id: Optional[int] = None + + def to_jsonable(self) -> dict: + """Serialize the box to a JSON-able plain dictionary + + Whichever of `class_id` or `label` (or both) are known will be included. + """ + result = {"top": self.top, "left": self.left, "height": self.height, "width": self.width} + if self.class_id is not None: + result["class_id"] = self.class_id + if self.label is not None: + result["label"] = self.label + return result + + @classmethod + def parse(cls, obj: dict, class_list: Optional[List[str]] = None) -> SMGTOutputBoundingBox: + """Parse the bounding box from a SMGT box dictionary + + Parameters + ---------- + obj : + Dictionary specifying the box as generated by SageMaker Ground Truth labelling tool + class_list : + Optional list of class names, which if provided will be used to map between numeric + `class_id` and string `label` in cases where only one is provided in the raw data. + """ + label = obj.get("label") + class_id = obj.get("class_id") + if class_list and len(class_list) > 0: + if label is None and class_id is not None: + if class_id >= 0 and class_id < len(class_list): + label = class_list[class_id] + else: + logger.warning( + "Box class ID %s is out of range 0-%s: Could not infer class name", + class_id, + len(class_list), + ) + elif class_id is None and label is not None: + try: + class_id = class_list.index(label) + except ValueError: + logger.warning( + "Box class name '%s' not in provided list: Could not infer class ID" + ) + + return cls( + top=obj["top"], + left=obj["left"], + height=obj["height"], + width=obj["width"], + label=label, + class_id=class_id, + ) diff --git a/notebooks/annotation/ocr-bbox-and-validation.liquid.tpl.html b/notebooks/annotation/ocr-bbox-and-validation.liquid.tpl.html index 0b64602..954ce94 100644 --- a/notebooks/annotation/ocr-bbox-and-validation.liquid.tpl.html +++ b/notebooks/annotation/ocr-bbox-and-validation.liquid.tpl.html @@ -25,6 +25,7 @@ > + Using the tool } function consolidateBboxAnnotations(annotations) { - var groups = []; // { label, boxes } + var groups = []; // { label, boxes, boxIxs } annotations.forEach(function (box, ixBox) { var className = box.label; var matchedGroup = null; @@ -329,10 +330,12 @@

Using the tool

matchedGroup = group; // TODO: Might be nice to check if any boxes are completely contained->redundant matchedGroup.boxes.push(box); + matchedGroup.boxIxs.push(ixBox); newGroups.push(matchedGroup); } else { // This box merges `group` into `matchedGroup` matchedGroup.boxes = matchedGroup.boxes.concat(group.boxes); + matchedGroup.boxIxs = matchedGroup.boxIxs.concat(group.boxIxs); } } else { // This group does not overlap with this annotation or is different class - pass through @@ -344,6 +347,7 @@

Using the tool

newGroups.push({ label: className, boxes: [box], + boxIxs: [ixBox], }); } groups = newGroups; @@ -448,6 +452,7 @@

Using the tool

label: annGroup.label, labelId: annGroup.labelId, boxes: annGroup.boxes, + boxIxs: annGroup.boxIxs, words: annGroup.words, ocrText: annGroup.ocrText, }); @@ -458,8 +463,7 @@

Using the tool

label=annGroup.label, labelId=annGroup.labelId, labelColor=labelColors[annGroup.label], - boxes=annGroup.boxes, - wordIds=annGroup.words.map(function(w) { return w.id; }), + boxIxs=annGroup.boxIxs, ); document.getElementById("ocr-fields").appendChild(ocrFieldEl); } else { @@ -490,14 +494,15 @@

Using the tool

radio.removeAttribute("checked"); }); ocrField.ocrText = annGroup.ocrText; - metaField.wordIds = annGroup.words.map(function(w) { return w.id; }); + metaField.ocrText = annGroup.ocrText; metaChanged = true; } - if (JSON.stringify(ocrField.boxes) !== JSON.stringify(annGroup.boxes)) { - metaField.boxes = annGroup.boxes; + if (JSON.stringify(ocrField.boxIxs) !== JSON.stringify(annGroup.boxIxs)) { + metaField.boxIxs = annGroup.boxIxs; metaChanged = true; } ocrField.boxes = annGroup.boxes; + ocrField.boxIxs = annGroup.boxIxs; ocrField.words = annGroup.words; if (labelColors[annGroup.label]) { container.style.borderColor = labelColors[annGroup.label]; @@ -527,7 +532,7 @@

Using the tool

window.ocrFields = newOcrFields; } - function renderOcrField(id, ocrText="", label=null, labelId=null, labelColor=null, boxes=[], wordIds=[]) { + function renderOcrField(id, ocrText="", label=null, labelId=null, labelColor=null, boxIxs=[]) { var container = document.createElement("div"); container.id = "ocr-field-" + id; container.classList.add("field-container"); @@ -541,8 +546,8 @@

Using the tool

JSON.stringify({ label: label, labelId: labelId, - boxes: boxes, - wordIds: wordIds, + boxIxs: boxIxs, + ocrText: ocrText, }), ); metastore.style.display = "none"; From 6d6eb4cbfa119775c20be0b7e9a09377fa7e958c Mon Sep 17 00:00:00 2001 From: Alex Thewsey Date: Tue, 17 Jan 2023 12:14:07 +0800 Subject: [PATCH 11/14] fix(smgt): Missing JSON serializer in postproc --- annotation/fn-SMGT-Post/data_model.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/annotation/fn-SMGT-Post/data_model.py b/annotation/fn-SMGT-Post/data_model.py index 5b65341..298217b 100644 --- a/annotation/fn-SMGT-Post/data_model.py +++ b/annotation/fn-SMGT-Post/data_model.py @@ -204,6 +204,21 @@ def parse( target_text=target_text, ) + def to_jsonable(self) -> dict: + return { + k: v + for k, v in { + "detectionId": self.detection_id, + "ocrStatus": self.ocr_status, + "boxIxs": self.box_ixs, + "classId": self.class_id, + "label": self.label, + "rawText": self.raw_text, + "targetText": self.target_text, + }.items() + if v is not None + } + @dataclass class SMGTWorkerAnnotation(BaseJsonable, BaseObjectParser): From 16900a328b17356d316ff824a3a5745359a43907 Mon Sep 17 00:00:00 2001 From: Alex Thewsey Date: Tue, 17 Jan 2023 15:15:42 +0800 Subject: [PATCH 12/14] feat(seq2seq): Enable seq2seq from SMGT labels Update the training script data loaders to support training seq2seq model (still plain text) from SMGT custom task UI entity OCR validation results. Update seq2seq in Extras notebook to match setup used in demo. Add mention of SMGT data usage to Extras notebook. --- notebooks/Optional Extras.ipynb | 57 +++- notebooks/src/code/data/base.py | 20 ++ notebooks/src/code/data/geometry.py | 172 ----------- notebooks/src/code/data/ner.py | 8 +- notebooks/src/code/data/seq2seq/metrics.py | 57 ++++ .../src/code/data/seq2seq/task_builder.py | 273 +++++++++++++----- notebooks/src/code/data/smgt.py | 260 +++++++++++++++++ notebooks/src/code/data/splitting.py | 24 ++ 8 files changed, 602 insertions(+), 269 deletions(-) create mode 100644 notebooks/src/code/data/seq2seq/metrics.py create mode 100644 notebooks/src/code/data/smgt.py diff --git a/notebooks/Optional Extras.ipynb b/notebooks/Optional Extras.ipynb index 5bb3765..2b645d6 100644 --- a/notebooks/Optional Extras.ipynb +++ b/notebooks/Optional Extras.ipynb @@ -898,8 +898,8 @@ "train_dataset.save_to_disk(\"data/seq2seq-train\")\n", "eval_dataset.save_to_disk(\"data/seq2seq-validation\")\n", "\n", - "print(\"Dataset sample (top 10 records):\")\n", - "pd.DataFrame(train_dataset[0:10])" + "print(\"Dataset sample (top 15 records):\")\n", + "pd.DataFrame(train_dataset[0:15])" ] }, { @@ -936,6 +936,29 @@ "!aws s3 sync --delete data/seq2seq-validation {validation_s3uri}" ] }, + { + "cell_type": "markdown", + "id": "de96d940-87fb-4b76-b3dc-f4c30fd93473", + "metadata": {}, + "source": [ + "### 🧪 (Experimental) Training with annotated documents\n", + "\n", + "If you annotated your documents using the **custom** SageMaker Ground Truth task UI in Notebook 1 (with OCR transcript reviews), instead of the default (bounding-box-only) UI, you should also be able to directly train the seq2seq model on your manually-annotated data.\n", + "\n", + "To do this, set your `train`, `textract` and `validation` channels as shown in Notebook 2 instead of the synthetic/augmented dataset used below. 
The script will build seq2seq examples from your annotated entity types, raw OCR text, and corrected OCR texts - something like:\n", + "\n", + "```json\n", + "{\n", + " \"src_texts\": \"Normalize Card Name: mycool Credit Card.\",\n", + " \"tgt_texts\": \"MyCool Credit Card\"\n", + "}\n", + "```\n", + "\n", + "In the *Integrate with processing pipeline* section below, you'd then configure your normalization prompts to be of the format `Normalize {YourFieldLabel}: ` for each field where you wanted to turn the normalizing model on, instead of the `Convert dates...` prompt we use.\n", + "\n", + "You'll probably find it easiest to run through this example with the generated date-normalization dataset first to understand the flow, before trying to use your SMGT annotations instead." + ] + }, { "cell_type": "markdown", "id": "3f120616-0ae6-49c0-9821-893e5501cf9d", @@ -1013,30 +1036,28 @@ "from sagemaker.huggingface.estimator import HuggingFace as HuggingFaceEstimator\n", "\n", "hyperparameters = {\n", - " \"model_name_or_path\": \"t5-base\",\n", + " \"model_name_or_path\": \"google/byt5-base\",\n", " \"task_name\": \"seq2seq\",\n", " \"logging_steps\": 100,\n", " \"evaluation_strategy\": \"steps\",\n", - " \"eval_steps\": 200,\n", + " \"eval_steps\": 250, # (=Twice per epoch, at 1000 data points & batch size 2)\n", " # Only need to set do_eval when validation channel is not provided and want to generate:\n", " \"do_eval\": \"1\",\n", " \"save_strategy\": \"steps\",\n", - " \"save_steps\": 200,\n", - " \"learning_rate\": 5e-4,\n", + " \"save_steps\": 250,\n", + " \"learning_rate\": 1e-4,\n", " \"per_device_train_batch_size\": 2,\n", " \"per_device_eval_batch_size\": 4,\n", " \"seed\": 1337,\n", "\n", - " \"num_train_epochs\": 5, # Set high to drive via early stopping\n", - " \"early_stopping_patience\": 4, # Usually stops after <25 epochs on this sample data+config\n", + " \"num_train_epochs\": 5.01, # Make sure the epoch==5.0 evaluation gets taken\n", + " \"early_stopping_patience\": 4,\n", " \"metric_for_best_model\": \"eval_acc\",\n", " # \"greater_is_better\": \"false\",\n", - " # # Early stopping implies checkpointing every evaluation (epoch), so limit the total checkpoints\n", - " # # kept to avoid filling up disk:\n", + " # Avoid filling up disk with too many saved model checkpoints:\n", " \"save_total_limit\": 10,\n", "}\n", "\n", - "\n", "metric_definitions = [\n", " {\"Name\": \"epoch\", \"Regex\": util.training.get_hf_metric_regex(\"epoch\")},\n", " {\"Name\": \"learning_rate\", \"Regex\": util.training.get_hf_metric_regex(\"learning_rate\")},\n", @@ -1058,12 +1079,12 @@ " transformers_version=None,\n", " image_uri=train_image_uri, # Use the customized training container image\n", "\n", - " base_job_name=\"t5-datenorm\",\n", + " base_job_name=\"byt5-datenorm\",\n", " output_path=f\"s3://{bucket_name}/{bucket_prefix}trainjobs\",\n", "\n", - " instance_type=\"ml.g4dn.xlarge\", # Could also consider ml.p3.2xlarge\n", + " instance_type=\"ml.p3.2xlarge\", # t5-base fits on ml.g4dn.xlarge GPU, but not byt5-base\n", " instance_count=1,\n", - " volume_size=40,\n", + " volume_size=80,\n", "\n", " debugger_hook_config=False,\n", "\n", @@ -1289,7 +1310,9 @@ "id": "23c9c260-9633-43c5-b9f5-e01ca23431bb", "metadata": {}, "source": [ - "As shown above, this text-to-text model can take in a raw detected date mention (e.g. `Sunday Dec 31st 2000`) with a prompt prefix (e.g. `Convert dates to YYYY-MM-DD: `) and attempt to output the desired normalized format (e.g. `2000-12-31`)." 
+ "As shown above, this text-to-text model can take in a raw detected date mention (e.g. `Sunday Dec 31st 2000`) with a prompt prefix (e.g. `Convert dates to YYYY-MM-DD: `) and attempt to output the desired normalized format (e.g. `2000-12-31`).\n", + "\n", + "Note that the \"overall accuracy\" metric reported above should match with the `eval_acc` metric emitted by the training job, since the same validation dataset is used." ] }, { @@ -1336,7 +1359,9 @@ "id": "4d53907a-2610-40b2-a28a-c51687228bc6", "metadata": {}, "source": [ - "Next, find any entity type that looks like a date (any with 'date' in the name), and configure the normalizer for those fields:" + "Next, find any entity type that looks like a date (any with 'date' in the name), and configure the normalizer for those fields:\n", + "\n", + "> ⚠️ **Note:** Check the way you prompt your normalization model matches how it was trained, for good results!" ] }, { diff --git a/notebooks/src/code/data/base.py b/notebooks/src/code/data/base.py index ecd2da8..6e3d503 100644 --- a/notebooks/src/code/data/base.py +++ b/notebooks/src/code/data/base.py @@ -106,6 +106,25 @@ def normalize_asset_ref( return asset_ref +def looks_like_hf_dataset(folder: str) -> bool: + """Check if a local folder looks like a HuggingFace `Dataset` from save_to_disk(), or not""" + if not os.path.isfile(os.path.join(folder, "dataset_info.json")): + logger.debug( + "Folder missing dataset_info.json does not appear to be HF Dataset: %s", + folder, + ) + return False + elif not os.path.isfile(os.path.join(folder, "state.json")): + logger.debug( + "Folder missing state.json does not appear to be HF Dataset: %s", + folder, + ) + return False + else: + logger.debug("Folder appears to be saved Hugging Face Dataset: %s", folder) + return True + + def find_images_from_textract_path( textract_file_path: str, images_path: str, @@ -201,6 +220,7 @@ def map_load_text_and_images( images_prefix: str = "", textract_path: str = "", textract_prefix: str = "", + # TODO: output_line_ids seems to be broken atm? At least for NER/seq2seq it's throwing errors output_line_ids: bool = False, ) -> Dict[str, List]: """datasets.map function to load examples for a manifest-file-like batch diff --git a/notebooks/src/code/data/geometry.py b/notebooks/src/code/data/geometry.py index f67a6a1..b6cfb9d 100644 --- a/notebooks/src/code/data/geometry.py +++ b/notebooks/src/code/data/geometry.py @@ -11,178 +11,6 @@ import trp -class AnnotationBoundingBox: - """Class to parse a bounding box annotated by SageMaker Ground Truth Object Detection - - Pre-calculates all box TLHWBR metrics (both absolute and relative) on init, for efficient and - easy processing later. 
- """ - - def __init__(self, manifest_box: dict, image_height: int, image_width: int): - self._class_id = manifest_box["class_id"] - self._abs_top = manifest_box["top"] - self._abs_left = manifest_box["left"] - self._abs_height = manifest_box["height"] - self._abs_width = manifest_box["width"] - self._abs_bottom = self.abs_top + self.abs_height - self._abs_right = self.abs_left + self.abs_width - self._rel_top = self._abs_top / image_height - self._rel_left = self._abs_left / image_width - self._rel_height = self._abs_height / image_height - self._rel_width = self._abs_width / image_width - self._rel_bottom = self._abs_bottom / image_height - self._rel_right = self._abs_right / image_width - - @property - def class_id(self): - return self._class_id - - @property - def abs_top(self): - return self._abs_top - - @property - def abs_left(self): - return self._abs_left - - @property - def abs_height(self): - return self._abs_height - - @property - def abs_width(self): - return self._abs_width - - @property - def abs_bottom(self): - return self._abs_bottom - - @property - def abs_right(self): - return self._abs_right - - @property - def rel_top(self): - return self._rel_top - - @property - def rel_left(self): - return self._rel_left - - @property - def rel_height(self): - return self._rel_height - - @property - def rel_width(self): - return self._rel_width - - @property - def rel_bottom(self): - return self._rel_bottom - - @property - def rel_right(self): - return self._rel_right - - -class BoundingBoxAnnotationResult: - """Class to parse the result field saved by a SageMaker Ground Truth Object Detection job""" - - def __init__(self, manifest_obj: dict): - """Initialize a BoundingBoxAnnotationResult - - Arguments - --------- - manifest_obj : dict - The contents of the output field of a record in a SMGT Object Detection labelling job - output manifest, or equivalent. - """ - try: - image_size_spec = manifest_obj["image_size"][0] - self._image_height = int(image_size_spec["height"]) - self._image_width = int(image_size_spec["width"]) - self._image_depth = ( - int(image_size_spec["depth"]) if "depth" in image_size_spec else None - ) - except Exception as e: - raise ValueError( - "".join( - ( - "manifest_obj must be a dictionary including 'image_size': a list of ", - "length 1 whose first/only element is a dict with integer properties ", - f"'height' and 'width', optionally also 'depth'. Got: {manifest_obj}", - ) - ) - ) from e - assert ( - len(manifest_obj["image_size"]) == 1 - ), f"manifest_obj['image_size'] must be a list of len 1. Got: {manifest_obj['image_size']}" - - try: - self._boxes = [ - AnnotationBoundingBox( - b, - image_height=self._image_height, - image_width=self._image_width, - ) - for b in manifest_obj["annotations"] - ] - except Exception as e: - raise ValueError( - "".join( - ( - "manifest_obj['annotations'] must be a list-like of absolute TLHW bounding box ", - f"dicts with class_id. 
Got {manifest_obj['annotations']}", - ) - ) - ) from e - - @property - def image_height(self): - return self._image_height - - @property - def image_width(self): - return self._image_width - - @property - def image_depth(self): - return self._image_depth - - @property - def boxes(self): - return self._boxes - - def normalized_boxes( - self, - return_tensors: Optional[str] = None, - ): - """Annotation boxes in 0-1000 normalized x0,y0,x1,y1 array/tensor format as per LayoutLM""" - raw_zero_to_one_list = [ - [ - box.rel_left, - box.rel_top, - box.rel_right, - box.rel_bottom, - ] - for box in self._boxes - ] - if return_tensors == "np" or not return_tensors: - if len(raw_zero_to_one_list) == 0: - npresult = np.zeros((0, 4), dtype="long") - else: - npresult = (np.array(raw_zero_to_one_list) * 1000).astype("long") - return npresult if return_tensors else npresult.tolist() - elif return_tensors == "pt": - if len(raw_zero_to_one_list) == 0: - return torch.zeros((0, 4), dtype=torch.long) - else: - return (torch.FloatTensor(raw_zero_to_one_list) * 1000).long() - else: - raise ValueError("return_tensors must be None, 'np' or 'pt'. Got: %s" % return_tensors) - - def layoutlm_boxes_from_trp_blocks( textract_blocks: Iterable[ Union[ diff --git a/notebooks/src/code/data/ner.py b/notebooks/src/code/data/ner.py index bc3047a..e8fa04a 100644 --- a/notebooks/src/code/data/ner.py +++ b/notebooks/src/code/data/ner.py @@ -29,7 +29,7 @@ split_long_dataset_samples, TaskData, ) -from .geometry import BoundingBoxAnnotationResult +from .smgt import BoundingBoxAnnotationResult logger = getLogger("data.ner") @@ -159,11 +159,13 @@ def map_smgt_boxes_to_word_labels( n_classes: int, ): """datasets.map function to tag "word_labels" from word "boxes" and SMGT bbox annotation data""" + # TODO: Check if manifest-line diagnostic feed-through is actually working? Seems broken manifest_lines = batch.get("manifest-line") if annotation_attr not in batch: raise ValueError( - "Bounding box label attribute '%s' missing from batch%s" - % (annotation_attr, f" (batch from manifest line {manifest_lines[0]}") + "Bounding box label attribute '{}' missing from batch{}".format( + annotation_attr, " (Manifest lines {manifest_lines})." if manifest_lines else "." + ) ) # TODO: More useful error messages if one fails? annotations = [BoundingBoxAnnotationResult(ann) for ann in batch[annotation_attr]] diff --git a/notebooks/src/code/data/seq2seq/metrics.py b/notebooks/src/code/data/seq2seq/metrics.py new file mode 100644 index 0000000..4702b78 --- /dev/null +++ b/notebooks/src/code/data/seq2seq/metrics.py @@ -0,0 +1,57 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: MIT-0 +"""Validation/accuracy metric callbacks for seq2seq modelling tasks""" +# Python Built-Ins: +from numbers import Real +from typing import Callable, Dict + +# External Dependencies: +import numpy as np +from transformers import EvalPrediction, PreTrainedTokenizerBase + + +def get_metric_computer( + tokenizer: PreTrainedTokenizerBase, +) -> Callable[[EvalPrediction], Dict[str, Real]]: + """An 'accuracy' computer for seq2seq tasks that ignores outer whitespace and case. + + For our example task, it's reasonable to measure exact-match accuracy (since we're normalising + small text spans - not e.g. summarizing long texts to shorter paragraphs). Therefore this metric + computer checks exact accuracy, while allowing for variations in case and leading/trailing + whitespace. 
+ """ + + def compute_metrics(p: EvalPrediction) -> Dict[str, Real]: + # Convert model output probs/logits to predicted token IDs: + predicted_token_ids = np.argmax(p.predictions[0], axis=2) + # Replace everything from the first token onward with padding (as eos + # would terminate generation in a normal generate() call) + for ix_batch, seq in enumerate(predicted_token_ids): + eos_token_matches = np.where(seq == tokenizer.eos_token_id) + if len(eos_token_matches) and len(eos_token_matches[0]): + first_eos_posn = eos_token_matches[0][0] + predicted_token_ids[ix_batch, first_eos_posn:] = tokenizer.pad_token_id + + gen_texts = [ + s.strip().lower() + for s in tokenizer.batch_decode(predicted_token_ids, skip_special_tokens=True) + ] + + target_texts = [ + s.strip().lower() + for s in tokenizer.batch_decode( + # Replace label '-100' tokens (ignore index for BinaryCrossEntropy) with '0' ( + # token), to avoid an OverflowError when trying to decode the target text: + np.maximum(0, p.label_ids), + skip_special_tokens=True, + ) + ] + + n_examples = len(gen_texts) + n_correct = sum(1 for gen, target in zip(gen_texts, target_texts) if gen == target) + return { + "n_examples": len(gen_texts), + "acc": n_correct / n_examples, + } + + return compute_metrics diff --git a/notebooks/src/code/data/seq2seq/task_builder.py b/notebooks/src/code/data/seq2seq/task_builder.py index 5cbfe9c..ccb51fc 100644 --- a/notebooks/src/code/data/seq2seq/task_builder.py +++ b/notebooks/src/code/data/seq2seq/task_builder.py @@ -13,32 +13,34 @@ """ # Python Built-Ins: from logging import getLogger -from numbers import Real import os -from typing import Callable, Dict, Optional, Union +from typing import Dict, Optional, Union # External Dependencies: import datasets import numpy as np -from transformers import EvalPrediction, PreTrainedTokenizerBase +from transformers import PreTrainedTokenizerBase from transformers.processing_utils import ProcessorMixin from transformers.utils.generic import PaddingStrategy, TensorType from transformers.tokenization_utils_base import TruncationStrategy # Local Dependencies: from ...config import DataTrainingArguments -from ..base import TaskData +from ..base import looks_like_hf_dataset, prepare_base_dataset, TaskData +from ..smgt import BBoxesWithTranscriptReviewsAnnotationResult +from ..splitting import duplicate_batch_record, remove_batch_records from .date_normalization import generate_seq2seq_date_norm_dataset +from .metrics import get_metric_computer logger = getLogger("data.seq2seq") -def _preprocess_seq2seq_dataset( +def _map_collate_seq2seq_dataset( batch: Dict[str, list], tokenizer: PreTrainedTokenizerBase, add_special_tokens: bool = True, - padding: Union[bool, str, PaddingStrategy] = False, + padding: Union[bool, str, PaddingStrategy] = "max_length", truncation: Union[bool, str, TruncationStrategy] = None, max_input_length: Optional[int] = None, max_output_length: Optional[int] = None, @@ -54,10 +56,7 @@ def _preprocess_seq2seq_dataset( return_length: bool = False, verbose: bool = True, ) -> Dict[str, list]: - """map fn to tokenize a seq2seq dataset ready for use in training - - TODO: Should we use a DataCollator for per-batch tokenization instead? 
- """ + """map fn to tokenize a seq2seq dataset ready for use in training""" # encode the documents prompts = batch["src_texts"] answers = batch["tgt_texts"] @@ -113,51 +112,143 @@ def _preprocess_seq2seq_dataset( return model_inputs -def get_metric_computer( +def collate_seq2seq_dataset( + dataset: datasets.Dataset, tokenizer: PreTrainedTokenizerBase, -) -> Callable[[EvalPrediction], Dict[str, Real]]: - """An 'accuracy' computer for seq2seq tasks that ignores outer whitespace and case. + max_input_len: Optional[int] = None, + max_output_len: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + padding: Union[bool, str, PaddingStrategy] = "max_length", + cache_dir: Optional[str] = None, + cache_file_prefix: Optional[str] = None, + num_workers: Optional[int] = None, +) -> datasets.Dataset: + """Tokenize a seq2seq dataset ready for use in training - For our example task, it's reasonable to measure exact-match accuracy (since we're normalising - small text spans - not e.g. summarizing long texts to shorter paragraphs). Therefore this metric - computer checks exact accuracy, while allowing for variations in case and leading/trailing - whitespace. + TODO: Should we use a DataCollator for per-batch tokenization instead? """ + preproc_kwargs = { + "max_input_length": max_input_len, + "max_output_length": max_output_len, + "pad_to_multiple_of": pad_to_multiple_of, + "padding": padding, + "tokenizer": tokenizer, + } + return dataset.map( + _map_collate_seq2seq_dataset, + batched=True, + cache_file_name=( + os.path.join(cache_dir, f"{cache_file_prefix}_collated.arrow") + if (cache_dir and cache_file_prefix) + else None + ), + num_proc=num_workers, + remove_columns=dataset.column_names, + fn_kwargs=preproc_kwargs, + ) - def compute_metrics(p: EvalPrediction) -> Dict[str, Real]: - # Convert model output probs/logits to predicted token IDs: - predicted_token_ids = np.argmax(p.predictions[0], axis=2) - # Replace everything from the first token onward with padding (as eos - # would terminate generation in a normal generate() call) - for ix_batch, seq in enumerate(predicted_token_ids): - eos_token_matches = np.where(seq == tokenizer.eos_token_id) - if len(eos_token_matches) and len(eos_token_matches[0]): - first_eos_posn = eos_token_matches[0][0] - predicted_token_ids[ix_batch, first_eos_posn:] = tokenizer.pad_token_id - - gen_texts = [ - s.strip().lower() - for s in tokenizer.batch_decode(predicted_token_ids, skip_special_tokens=True) - ] - target_texts = [ - s.strip().lower() - for s in tokenizer.batch_decode( - # Replace label '-100' tokens (ignore index for BinaryCrossEntropy) with '0' ( - # token), to avoid an OverflowError when trying to decode the target text: - np.maximum(0, p.label_ids), - skip_special_tokens=True, - ) +def map_smgt_data_to_fieldnorm_seq2seq( + batch: Dict[str, list], # TODO: Support List[Any]? Union[Dict[List], List[Any]], + annotation_attr: str, +): + """Map base Textract+SMGT dataset with custom task UI inputs to a seq2seq field normalizing task + + Between the already-extracted Textract data and the boxes available on the SMGT result, you + should have everything you need here to support validating the raw text matches the source doc + at given locations and pulling through the source word layout boxes (similar to what we do in + NER data prep) - but since our seq2seq models are all text-only it's not been done for now. 
+ """ + if annotation_attr not in batch: + raise ValueError(f"Ground Truth label attribute '{annotation_attr}' missing from batch.") + + anns_orig = batch[annotation_attr][:] + # Create placeholders in batch for fields to be built: + batch["class_name"] = [None for _ in anns_orig] + batch["src_texts"] = [None for _ in anns_orig] + batch["tgt_texts"] = [None for _ in anns_orig] + + # Process the batch, expanding it as we go (to one record per entity): + ix_offset = 0 + for ix_orig, ann in enumerate(anns_orig): + ix_cur = ix_orig + ix_offset + ann = BBoxesWithTranscriptReviewsAnnotationResult(ann) + valid_entities = [ + ent + for ent in ann.entities + if ent.label is not None and ent.raw_text is not None and ent.target_text is not None ] + n_valid_entities = len(valid_entities) + if n_valid_entities == 0: + batch = remove_batch_records(batch, ix_cur, n=1) + ix_offset -= 1 + else: + batch = duplicate_batch_record( + batch, + ix_cur, + n_valid_entities, + feature_overrides={ + "class_name": [ent.label for ent in valid_entities], + "src_texts": [ + f"Normalize {ent.label}: {ent.raw_text}" for ent in valid_entities + ], + "tgt_texts": [ent.target_text for ent in valid_entities], + }, + ) + ix_offset += n_valid_entities - 1 - n_examples = len(gen_texts) - n_correct = sum(1 for gen, target in zip(gen_texts, target_texts) if gen == target) - return { - "n_examples": len(gen_texts), - "acc": n_correct / n_examples, - } + return batch - return compute_metrics + +def prepare_dataset( + data_path: str, + annotation_attr: Optional[str] = None, + textract_path: Optional[str] = None, + images_path: Optional[str] = None, + images_prefix: str = "", + textract_prefix: str = "", + num_workers: Optional[int] = None, + batch_size: int = 16, + cache_dir: Optional[str] = None, + cache_file_prefix: Optional[str] = None, +) -> datasets.Dataset: + + if looks_like_hf_dataset(data_path): + # Pre-prepared dataset, just load and return: + return datasets.load_from_disk(data_path) + + # Else we need to prepare the dataset from Textract/SMGT files. + dataset = prepare_base_dataset( + textract_path=textract_path, + manifest_file_path=data_path, + images_path=images_path, + images_prefix=images_prefix, + textract_prefix=textract_prefix, + num_workers=num_workers, + batch_size=batch_size, + cache_dir=cache_dir, + map_cache_file_name=( + os.path.join(cache_dir, f"{cache_file_prefix}_1base.arrow") + if (cache_dir and cache_file_prefix) + else None + ), + ).map( + map_smgt_data_to_fieldnorm_seq2seq, + batched=True, + batch_size=batch_size, + fn_kwargs={"annotation_attr": annotation_attr}, + num_proc=num_workers, + desc="Extracting seq2seq examples from Ground Truth annotations", + cache_file_name=( + os.path.join(cache_dir, f"{cache_file_prefix}_2label.arrow") + if (cache_dir and cache_file_prefix) + else None + ), + ) + + # Since this is a field-text normalization task, splitting long samples is not supported (e.g. 
+ # with `split_long_dataset_samples()` as in other tasks) + return dataset def get_task( @@ -178,43 +269,69 @@ def get_task( # Load or create the training and validation datasets: if data_args.train: - logger.info("Loading seq2seq training dataset from disk %s", data_args.train) - train_dataset = datasets.load_from_disk(data_args.train) + train_dataset = prepare_dataset( + data_path=data_args.train, + annotation_attr=data_args.annotation_attr, + textract_path=data_args.textract, + images_path=data_args.images, + images_prefix=data_args.images_prefix, + textract_prefix=data_args.textract_prefix, + num_workers=n_workers, + batch_size=data_args.dataproc_batch_size, + cache_dir=cache_dir, + cache_file_prefix="seq2seqtrain", + ) + logger.info("Train dataset ready: %s", train_dataset) else: + # TODO: Factor generation+preprocessing into separate generate_dataset fn? logger.info("Generating new synthetic seq2seq training dataset") - train_dataset = generate_seq2seq_date_norm_dataset(n=1000, rng=rng) + train_dataset = generate_seq2seq_date_norm_dataset( + n=1000, + rng=rng, + ) + + train_dataset = collate_seq2seq_dataset( + train_dataset, + tokenizer, + max_input_len=data_args.max_seq_length - 2, + max_output_len=64, # TODO: Parameterize? + pad_to_multiple_of=data_args.pad_to_multiple_of, + cache_dir=cache_dir, + cache_file_prefix="seq2seqtrain", + ) if data_args.validation: - logger.info("Loading seq2seq validation dataset from disk %s", data_args.validation) - eval_dataset = datasets.load_from_disk(data_args.validation) + eval_dataset = prepare_dataset( + data_path=data_args.validation, + annotation_attr=data_args.annotation_attr, + textract_path=data_args.textract, + images_path=data_args.images, + images_prefix=data_args.images_prefix, + textract_prefix=data_args.textract_prefix, + num_workers=n_workers, + batch_size=data_args.dataproc_batch_size, + cache_dir=cache_dir, + cache_file_prefix="seq2seqval", + ) + logger.info("Validation dataset ready: %s", eval_dataset) else: - logger.info("Generating new synthetic seq2seq validation dataset") - eval_dataset = generate_seq2seq_date_norm_dataset(n=200, rng=rng) + if not data_args.train: + logger.info("Generating new synthetic seq2seq validation dataset") + eval_dataset = generate_seq2seq_date_norm_dataset(n=200, rng=rng) + else: + # Can't assume it's the date norm task: Leave no val set + eval_dataset = None - # Pre-process the datasets with the tokenizer: - preproc_kwargs = { - "max_input_length": data_args.max_seq_length - 2, # To allow for CLS+SEP in final - "max_output_length": 64, # TODO: Parameterize? - "pad_to_multiple_of": data_args.pad_to_multiple_of, - "padding": "max_length", - "tokenizer": tokenizer, - } - train_dataset = train_dataset.map( - _preprocess_seq2seq_dataset, - batched=True, - cache_file_name=(os.path.join(cache_dir, "seq2seqtrain.arrow") if cache_dir else None), - num_proc=n_workers, - remove_columns=train_dataset.column_names, - fn_kwargs=preproc_kwargs, - ) - eval_dataset = eval_dataset.map( - _preprocess_seq2seq_dataset, - batched=True, - cache_file_name=(os.path.join(cache_dir, "seq2seqeval.arrow") if cache_dir else None), - num_proc=n_workers, - remove_columns=eval_dataset.column_names, - fn_kwargs=preproc_kwargs, - ) + if eval_dataset: + eval_dataset = collate_seq2seq_dataset( + eval_dataset, + tokenizer, + max_input_len=data_args.max_seq_length - 2, + max_output_len=128, # TODO: Parameterize? 
+ pad_to_multiple_of=data_args.pad_to_multiple_of, + cache_dir=cache_dir, + cache_file_prefix="seq2seqval", + ) return TaskData( train_dataset=train_dataset, diff --git a/notebooks/src/code/data/smgt.py b/notebooks/src/code/data/smgt.py new file mode 100644 index 0000000..8d0c9bf --- /dev/null +++ b/notebooks/src/code/data/smgt.py @@ -0,0 +1,260 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: MIT-0 +"""Data models for working with SageMaker Ground Truth in general and our specific custom task UI. + +Includes parsing e.g. bounding box results from built-in task type or the crowd-bounding-box tag. +""" +# Python Built-Ins: +from __future__ import annotations +from dataclasses import dataclass +from enum import Enum +from typing import List, Optional + +# External Dependencies: +import numpy as np +import torch + + +class AnnotationBoundingBox: + """Class to parse a bounding box annotated by SageMaker Ground Truth Object Detection + + Pre-calculates all box TLHWBR metrics (both absolute and relative) on init, for efficient and + easy processing later. + """ + + def __init__(self, manifest_box: dict, image_height: int, image_width: int): + self._class_id = manifest_box["class_id"] + self._abs_top = manifest_box["top"] + self._abs_left = manifest_box["left"] + self._abs_height = manifest_box["height"] + self._abs_width = manifest_box["width"] + self._abs_bottom = self.abs_top + self.abs_height + self._abs_right = self.abs_left + self.abs_width + self._rel_top = self._abs_top / image_height + self._rel_left = self._abs_left / image_width + self._rel_height = self._abs_height / image_height + self._rel_width = self._abs_width / image_width + self._rel_bottom = self._abs_bottom / image_height + self._rel_right = self._abs_right / image_width + + @property + def class_id(self): + return self._class_id + + @property + def abs_top(self): + return self._abs_top + + @property + def abs_left(self): + return self._abs_left + + @property + def abs_height(self): + return self._abs_height + + @property + def abs_width(self): + return self._abs_width + + @property + def abs_bottom(self): + return self._abs_bottom + + @property + def abs_right(self): + return self._abs_right + + @property + def rel_top(self): + return self._rel_top + + @property + def rel_left(self): + return self._rel_left + + @property + def rel_height(self): + return self._rel_height + + @property + def rel_width(self): + return self._rel_width + + @property + def rel_bottom(self): + return self._rel_bottom + + @property + def rel_right(self): + return self._rel_right + + +class BoundingBoxAnnotationResult: + """Class to parse the result field saved by a SageMaker Ground Truth Object Detection job""" + + def __init__(self, manifest_obj: dict): + """Initialize a BoundingBoxAnnotationResult + + Arguments + --------- + manifest_obj : dict + The contents of the output field of a record in a SMGT Object Detection labelling job + output manifest, or equivalent. + """ + try: + image_size_spec = manifest_obj["image_size"][0] + self._image_height = int(image_size_spec["height"]) + self._image_width = int(image_size_spec["width"]) + self._image_depth = ( + int(image_size_spec["depth"]) if "depth" in image_size_spec else None + ) + except Exception as e: + raise ValueError( + "".join( + ( + "manifest_obj must be a dictionary including 'image_size': a list of ", + "length 1 whose first/only element is a dict with integer properties ", + f"'height' and 'width', optionally also 'depth'. 
Got: {manifest_obj}", + ) + ) + ) from e + assert ( + len(manifest_obj["image_size"]) == 1 + ), f"manifest_obj['image_size'] must be a list of len 1. Got: {manifest_obj['image_size']}" + + try: + self._boxes = [ + AnnotationBoundingBox( + b, + image_height=self._image_height, + image_width=self._image_width, + ) + for b in manifest_obj["annotations"] + ] + except Exception as e: + raise ValueError( + "".join( + ( + "manifest_obj['annotations'] must be a list-like of absolute TLHW bounding box ", + f"dicts with class_id. Got {manifest_obj['annotations']}", + ) + ) + ) from e + + @property + def image_height(self): + return self._image_height + + @property + def image_width(self): + return self._image_width + + @property + def image_depth(self): + return self._image_depth + + @property + def boxes(self): + return self._boxes + + def normalized_boxes( + self, + return_tensors: Optional[str] = None, + ): + """Annotation boxes in 0-1000 normalized x0,y0,x1,y1 array/tensor format as per LayoutLM""" + raw_zero_to_one_list = [ + [ + box.rel_left, + box.rel_top, + box.rel_right, + box.rel_bottom, + ] + for box in self._boxes + ] + if return_tensors == "np" or not return_tensors: + if len(raw_zero_to_one_list) == 0: + npresult = np.zeros((0, 4), dtype="long") + else: + npresult = (np.array(raw_zero_to_one_list) * 1000).astype("long") + return npresult if return_tensors else npresult.tolist() + elif return_tensors == "pt": + if len(raw_zero_to_one_list) == 0: + return torch.zeros((0, 4), dtype=torch.long) + else: + return (torch.FloatTensor(raw_zero_to_one_list) * 1000).long() + else: + raise ValueError("return_tensors must be None, 'np' or 'pt'. Got: %s" % return_tensors) + + +class OCRReviewStatus(str, Enum): + """Ternary status for OCR transcription review + + TODO: Merge/share with postproc Lambda function if possible? + """ + + correct = "correct" + unclear = "unclear" + wrong = "wrong" + + +@dataclass +class OCREntityWithTranscriptReview: + detection_id: str + ocr_status: OCRReviewStatus + box_ixs: List[int] + class_id: int # TODO: This is optional in postproc Lambda's data model, re-align + raw_text: str # TODO: This is optional in postproc Lambda's data model, re-align + target_text: Optional[str] + label: Optional[str] + + @classmethod + def from_dict(cls, raw: dict) -> OCREntityWithTranscriptReview: + """Parse an individual entity annotation as produced by custom SMGT task UI+post-proc""" + raw_text = raw["rawText"] + ocr_status = OCRReviewStatus[raw["ocrStatus"]] + target_text = raw.get("targetText") + if target_text is None: + if ocr_status != OCRReviewStatus.wrong: + target_text = raw_text + else: + raise ValueError( + "Entity annotation is missing targetText field, but is tagged with ocrStatus " + "'wrong' so we can't take the rawText as target: %s" % raw + ) + + return cls( + detection_id=raw["detectionId"], + ocr_status=ocr_status, + box_ixs=raw["boxIxs"], + class_id=raw["classId"], + raw_text=raw_text, + target_text=target_text, + label=raw.get("label"), + ) + + +class BBoxesWithTranscriptReviewsAnnotationResult(BoundingBoxAnnotationResult): + """Result field saved by an SMGT job using the custom entities-with-transcription-reviews task + + This custom task, introduced via the notebooks and implemented by custom Liquid HTML template + and pre/post-processing Lambda functions, outputs data compatible with the standard bounding box + task UI but enriched with consolidated (overlapping) per-class regions and transcription reviews + for each region. 
+ """ + + entities: List[OCREntityWithTranscriptReview] + + def __init__(self, manifest_obj: dict): + # Parse the bounding boxes themselves via superclass: + super().__init__(manifest_obj) + # Parse the OCR entities: + if "entities" not in manifest_obj: + raise ValueError( + "SMGT manifest is missing 'entities' key, which should be generated by the custom " + "task template but not the built-in bounding box UI. Using bbox-only annotations " + "for a seq2seq training job is not currently supported." + ) + self.entities = [ + OCREntityWithTranscriptReview.from_dict(e) for e in manifest_obj["entities"] + ] diff --git a/notebooks/src/code/data/splitting.py b/notebooks/src/code/data/splitting.py index 343bf79..7518d68 100644 --- a/notebooks/src/code/data/splitting.py +++ b/notebooks/src/code/data/splitting.py @@ -180,6 +180,30 @@ def duplicate_batch_record( } +def remove_batch_records( + batch: Dict[str, List[Any]], + ix_start: int, + n: int = 1, +) -> Dict[str, List[Any]]: + """Remove one or more records from a batch + + Parameters + ---------- + batch : + Input data, dictionary by feature name of value lists. + ix_start : + 0-based index of first record to remove + n : + Number of records to remove + + Returns + ------- + result : + A shallow copy of the batch with the target record(s) removed. + """ + return {name: (values[:ix_start] + values[ix_start + n :]) for name, values in batch.items()} + + def split_batch_record( batch: Dict[str, List[Any]], ix: int, From 155bdaba860fff2bbacfb65c80741da62785c39d Mon Sep 17 00:00:00 2001 From: Alex Thewsey Date: Tue, 17 Jan 2023 15:16:59 +0800 Subject: [PATCH 13/14] doc(seq2seq): Mention seq2seq sample on readme Update README and CUSTOMIZATION_GUIDE to mention new seq2seq entity text normalization training option. --- CUSTOMIZATION_GUIDE.md | 2 ++ README.md | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/CUSTOMIZATION_GUIDE.md b/CUSTOMIZATION_GUIDE.md index 0d318fd..abe8811 100644 --- a/CUSTOMIZATION_GUIDE.md +++ b/CUSTOMIZATION_GUIDE.md @@ -90,6 +90,8 @@ Consider editing the `select_examples()` function to customize how the set of ca ### Step 8: Proceed with data annotation and subsequent steps +If you're planning to review OCR accuracy as part of your PoC or train models to normalize from the raw detected text to standardised values (for example, normalising dates or number representations), you might find it useful to use the custom Ground Truth UI presented in Notebook 1 - instead of the default built-in (bounding-box only) UI. + From the labelling job onwards (through notebook 2 and beyond), the flow should be essentially the same as with the sample data. Just remember to edit the `include_jobs` list in notebook 2 to reflect the actual annotation jobs you performed. If your dataset is particularly tiny (more like e.g. 30 labelled pages than 100), it might be helpful to try increasing the `early_stopping_patience` hyperparameter to force the training job to re-process the same examples for longer. You could also explore hyperparameter tuning. However, it'd likely have a bigger impact to spend that time annotatting more data instead! diff --git a/README.md b/README.md index f9d391a..c15f603 100644 --- a/README.md +++ b/README.md @@ -202,7 +202,7 @@ The approach should work well for many different document types, and the solutio However, there are many more opportunities to extend the approach. 
For example: -- Rather than token/word classification, alternative '**sequence-to-sequence**' ML tasks such as could be applied: Perhaps to fix common OCR error patterns, or to build general question-answering models on documents. +- Rather than token/word classification, alternative '**sequence-to-sequence**' ML tasks could be applied: Perhaps to fix common OCR error patterns, or to build general question-answering models on documents. Training seq2seq models is discussed further in the [Optional Extras notebook](notebooks/Optional%20Extras.ipynb). - Just as the BERT-based model was extended to consider coordinates as input, perhaps **source OCR confidence scores** (also available from Textract) would be useful model inputs. - The post-processing Lambda function could be extended to perform more complex validations on detected fields: For example to extract numerics, enforce regular expression matching, or even call some additional AI service such as Amazon Comprehend. From 8b99fb5aa6b9be06c5fdcbbd0a06b4b5daea761b Mon Sep 17 00:00:00 2001 From: Alex Thewsey Date: Tue, 17 Jan 2023 15:26:18 +0800 Subject: [PATCH 14/14] build(git): gitignore SM Experiments tmpfiles --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 72becb2..12aca3b 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,9 @@ .cdk.staging cdk.out +# SageMaker Experiments/Debugger (if try running locally): +tmp_trainer/ + # Working data folders and notebook-built assets: # (With some specific exclusions) notebooks/data/*