**Post-Processing Amazon Textract with Location-Aware Transformers**

# Part 2: Data Consolidation and Model Training/Deployment

> *This notebook works well with the `Python 3 (Data Science)` kernel on SageMaker Studio*

In the [first notebook](1.%20Data%20Preparation.ipynb) we worked through preparing a corpus with Amazon Textract and labelling a small sample to highlight entities of interest.

In this second part, we'll consolidate the labelling results with a pre-prepared augmentation set, then train and deploy a SageMaker model for word classification.

First, as in the previous notebook, we'll start by importing the required libraries and loading configuration:

In [None]:
%load_ext autoreload
%autoreload 2

# Python Built-Ins:
from datetime import datetime
import json
from logging import getLogger
import os
import random
import time

# External Dependencies:
import boto3  # AWS SDK for Python
import sagemaker
from sagemaker.huggingface import HuggingFace as HuggingFaceEstimator
from tqdm.notebook import tqdm  # Progress bars

# Local Dependencies:
import util

logger = getLogger()

# Manual configuration (check this matches notebook 1):
bucket_name = sagemaker.Session().default_bucket()
bucket_prefix = "textract-transformers/"
print(f"Working in bucket s3://{bucket_name}/{bucket_prefix}")
config = util.project.init("ocr-transformers-demo")
print(config)

# Field configuration saved from first notebook:
with open("data/field-config.json", "r") as f:
    fields = [
        util.postproc.config.FieldConfiguration.from_dict(cfg)
        for cfg in json.loads(f.read())
    ]
entity_classes = [f.name for f in fields]

# S3 URIs as per first notebook:
imgs_s3uri = f"s3://{bucket_name}/{bucket_prefix}data/imgs-clean"
textract_s3uri = f"s3://{bucket_name}/{bucket_prefix}data/textracted"
annotations_base_s3uri = f"s3://{bucket_name}/{bucket_prefix}data/annotations"

## Data Consolidation

To construct a training set, we'll typically need to consolidate the results of multiple SageMaker Ground Truth labelling jobs: perhaps because the work was split into more manageable chunks, or because additional review/adjustment jobs were run to improve label quality.

First, we'll download the output folders of all our labelling jobs to the local `data/annotations` folder. (The code here assumes you configured the same `annotations_base_s3uri` output folder for each job in SageMaker Ground Truth.)

In [None]:
!aws s3 sync --quiet $annotations_base_s3uri ./data/annotations

Inside this folder, you'll find some **pre-annotated augmentation data** provided for you already (in the `augmentation-` subfolders). These datasets are not especially large or externally useful, but will help you train a better model without too much (or even any!) manual annotation effort.

▶️ **Edit** the `include_jobs` line below to control which datasets (pre-provided and your own) will be included:

In [None]:
include_jobs = [
    "augmentation-1",
    "augmentation-2",
    # TODO: Adjust the below to match the labelling jobs you created, or comment out if you didn't:
    "cfpb-boxes-1",
]


source_manifests = []
for job_name in sorted(filter(
    lambda n: os.path.isdir(f"data/annotations/{n}"),
    os.listdir("data/annotations")
)):
    if job_name not in include_jobs:
        logger.warning(f"Skipping {job_name} (not in include_jobs list)")
        continue
    job_manifest_path = f"data/annotations/{job_name}/manifests/output/output.manifest"
    if not os.path.isfile(job_manifest_path):
        raise RuntimeError(f"Could not find job output manifest {job_manifest_path}")
    source_manifests.append({ "job_name": job_name, "manifest_path": job_manifest_path })

print(f"Got {len(source_manifests)} annotated manifests:")
print("\n".join(map(lambda o: o["manifest_path"], source_manifests)))

Now that the results are downloaded, we're ready to consolidate the **output manifest files** from each one into a combined manifest file.

Note that to combine multiple output manifests to a single dataset:

- We need to ensure the labels are stored in the same attribute on every record (records use the labeling job name by default, which will be different between jobs).
- If importing data collected from some other account (like the `augmentation-` sets), we'll need to **map the S3 URIs** to equivalent links on your own bucket.
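To make these two rules concrete, here is a minimal standalone sketch that rewrites `-ref` URIs through bucket and prefix mappings and renames a job's label attribute (plus its `-metadata` companion) to a standard name. It mirrors the consolidation loop in the next cell but is not a replacement for it: the function names `remap_s3_uri` and `standardize_record`, the bucket `my-bucket` and the sample record are all hypothetical. One deliberate difference: this sketch stops after the first matching prefix, so a URI is never re-mapped twice.

```python
def remap_s3_uri(uri, bucket_mappings, prefix_mappings):
    """Rewrite an s3:// URI by swapping its bucket and (first matching) key prefix."""
    if not uri.lower().startswith("s3://"):
        raise ValueError(f"Not an S3 URI: {uri}")
    bucket, _, key = uri[len("s3://"):].partition("/")
    bucket = bucket_mappings.get(bucket, bucket)
    for old_prefix, new_prefix in prefix_mappings.items():
        if key.startswith(old_prefix):
            key = new_prefix + key[len(old_prefix):]
            break  # Apply at most one prefix mapping per URI
    return f"s3://{bucket}/{key}"


def standardize_record(record, job_name, label_field, bucket_mappings, prefix_mappings):
    """Return a copy of a manifest record with remapped -ref URIs and a standard label attribute."""
    out = dict(record)  # Shallow copy: the input record is left unchanged
    for k in [k for k in out if k.endswith("-ref")]:
        out[k] = remap_s3_uri(out[k], bucket_mappings, prefix_mappings)
    if job_name in out:
        out[label_field] = out.pop(job_name)
        metadata = out.pop(f"{job_name}-metadata", None)
        if metadata is not None:
            out[f"{label_field}-metadata"] = metadata
    elif label_field not in out:
        raise KeyError(f"No label attribute found for job {job_name}")
    return out


# Hypothetical record, as it might appear in an imported augmentation manifest:
record = {
    "source-ref": "s3://DOC-EXAMPLE-BUCKET/EXAMPLE-PREFIX/data/imgs-clean/doc-0001.png",
    "augmentation-1": {"annotations": []},
    "augmentation-1-metadata": {"note": "hypothetical metadata"},
}
std = standardize_record(
    record,
    job_name="augmentation-1",
    label_field="label",
    bucket_mappings={"DOC-EXAMPLE-BUCKET": "my-bucket"},
    prefix_mappings={"EXAMPLE-PREFIX/": "textract-transformers/"},
)
print(std["source-ref"])  # s3://my-bucket/textract-transformers/data/imgs-clean/doc-0001.png
print(sorted(std))  # ['label', 'label-metadata', 'source-ref']
```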

In [None]:
# Annotations/labels will be standardized to this field on all records:
standard_label_field = "label"

# To import a manifest from somebody else, we of course need to map their bucket names and prefixes
# to ours (and have equivalent files stored in the same locations after the mapping):
BUCKET_MAPPINGS = { "DOC-EXAMPLE-BUCKET": bucket_name }
PREFIX_MAPPINGS = { "EXAMPLE-PREFIX/": bucket_prefix }

annotated_page_imgs = {}
print("Writing data/annotations/annotations-all.manifest.jsonl")
with open("data/annotations/annotations-all.manifest.jsonl", "w") as fout:
    for source in tqdm(source_manifests, desc="Consolidating manifests..."):
        with open(source["manifest_path"], "r") as fin:
            for line in filter(lambda l: l.strip(), fin):  # Skip any blank lines
                obj = json.loads(line)

                # Import refs by applying BUCKET_MAPPINGS and PREFIX_MAPPINGS:
                for k in filter(lambda k: k.endswith("-ref"), obj.keys()):
                    if not obj[k].lower().startswith("s3://"):
                        raise RuntimeError(
                            "Attr {} ends with -ref but does not start with 's3://'\n{}".format(
                                k,
                                obj
                            )
                        )
                    obj_bucket, _, obj_key = obj[k][len("s3://"):].partition("/")
                    obj_bucket = BUCKET_MAPPINGS.get(obj_bucket, obj_bucket)
                    for old_prefix in PREFIX_MAPPINGS:
                        if obj_key.startswith(old_prefix):
                            obj_key = (
                                PREFIX_MAPPINGS[old_prefix]
                                + obj_key[len(old_prefix):]
                            )
                    obj[k] = f"s3://{obj_bucket}/{obj_key}"
                
                # Find the job output field:
                if source["job_name"] in obj:
                    source_label_attr = source["job_name"]
                elif standard_label_field in obj:
                    source_label_attr = standard_label_field
                else:
                    raise RuntimeError("Couldn't find label field for entry in {}:\n{}".format(
                        source["job_name"],
                        obj,
                    ))
                # Rename to standard:
                obj[standard_label_field] = obj.pop(source_label_attr)
                source_meta_attr = f"{source_label_attr}-metadata"
                if source_meta_attr in obj:
                    obj[f"{standard_label_field}-metadata"] = obj.pop(source_meta_attr)
                # Write to output manifest:
                fout.write(json.dumps(obj) + "\n")


### Split training and test sets

To get some insight on how well our model is generalizing to real-world data, we'll need to reserve some annotated data as a testing/validation set.

Below, we randomly partition the data into training and test sets and then upload the two manifests to S3:

In [None]:
def split_manifest(f_in, f_train, f_test, train_pct=0.9, random_seed=1337):
    logger.info(f"Reading {f_in}")
    with open(f_in, "r") as fin:
        lines = [l for l in fin if l.strip()]  # Skip any blank lines
    logger.info("Shuffling records")
    random.Random(random_seed).shuffle(lines)
    n_train = round(len(lines) * train_pct)

    with open(f_train, "w") as ftrain:
        logger.info(f"Writing {n_train} records to {f_train}")
        for l in lines[:n_train]:
            ftrain.write(l)
    with open(f_test, "w") as ftest:
        logger.info(f"Writing {len(lines) - n_train} records to {f_test}")
        for l in lines[n_train:]:
            ftest.write(l)

split_manifest(
    "data/annotations/annotations-all.manifest.jsonl",
    "data/annotations/annotations-train.manifest.jsonl",
    "data/annotations/annotations-test.manifest.jsonl",
)

In [None]:
train_manifest_s3uri = f"s3://{bucket_name}/{bucket_prefix}data/annotations/annotations-train.manifest.jsonl"
!aws s3 cp data/annotations/annotations-train.manifest.jsonl $train_manifest_s3uri

test_manifest_s3uri = f"s3://{bucket_name}/{bucket_prefix}data/annotations/annotations-test.manifest.jsonl"
!aws s3 cp data/annotations/annotations-test.manifest.jsonl $test_manifest_s3uri

### Visualize the data

Before training the model, we'll sense-check the data by plotting a few examples.

The utility function below will overlay the page image with the annotated bounding boxes, the locations of `WORD` blocks detected from the Amazon Textract results, and the resulting classification of individual Textract `WORD`s.
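The exact rule `util.viz` uses to turn annotated boxes into per-`WORD` classes isn't shown in this notebook. As an illustration only, here is one plausible rule: give each `WORD` the class of the first annotated box that contains the word's centre, else "other". It assumes Textract's normalized `Geometry.BoundingBox` fields (`Left`, `Top`, `Width`, `Height` as fractions of the page) and the standard Ground Truth bounding-box output layout (pixel boxes under `annotations`, page size under `image_size`). The function names, class name and sample data are hypothetical.

```python
def word_center(word_block):
    """Centre of a Textract WORD block, as (x, y) fractions of page width/height."""
    box = word_block["Geometry"]["BoundingBox"]
    return box["Left"] + box["Width"] / 2, box["Top"] + box["Height"] / 2


def classify_words(word_blocks, annotation, class_names, other_label="other"):
    """Assign each WORD the class of the first annotated box containing its centre.

    Assumes `annotation` follows the Ground Truth bounding-box output layout.
    """
    size = annotation["image_size"][0]
    page_w, page_h = size["width"], size["height"]
    labels = []
    for word in word_blocks:
        cx, cy = word_center(word)
        # Convert normalized Textract coordinates to pixels to compare with the boxes:
        px, py = cx * page_w, cy * page_h
        label = other_label
        for box in annotation["annotations"]:
            if (box["left"] <= px <= box["left"] + box["width"]
                    and box["top"] <= py <= box["top"] + box["height"]):
                label = class_names[box["class_id"]]
                break
        labels.append(label)
    return labels


# Hypothetical example: one word inside the annotated box, one outside it
words = [
    {"Text": "Acme", "Geometry": {"BoundingBox": {"Left": 0.1, "Top": 0.1, "Width": 0.1, "Height": 0.02}}},
    {"Text": "Total", "Geometry": {"BoundingBox": {"Left": 0.6, "Top": 0.8, "Width": 0.1, "Height": 0.02}}},
]
annotation = {
    "image_size": [{"width": 1000, "height": 1000, "depth": 3}],
    "annotations": [{"class_id": 0, "left": 90, "top": 90, "width": 150, "height": 50}],
}
print(classify_words(words, annotation, ["Company name"]))  # ['Company name', 'other']
```

A centre-point test is cheap and avoids double-assigning a word that straddles two boxes; an overlap-ratio threshold would be a reasonable alternative if your boxes are drawn loosely.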

> ⏰ If you Textracted a large number of documents and haven't previously synced them to the notebook, the initial download here may take a few minutes to complete. For our sample set of 120, typically only ~20s is needed.

In [None]:
%%time

!aws s3 sync --quiet $textract_s3uri ./data/textracted

> ⚠️ **Note:** For the interactive visualization widgets in this notebook to work correctly, you'll need the [IPyWidgets extension for JupyterLab](https://ipywidgets.readthedocs.io/en/latest/user_install.html).
>
> On [SageMaker Studio](https://aws.amazon.com/sagemaker/studio/), this should be installed by default.
>
> On the classic [SageMaker Notebook Instances](https://docs.aws.amazon.com/sagemaker/latest/dg/nbi.html) though, you'll need to install the `@jupyter-widgets/jupyterlab-manager` extension (from `Settings > Extension Manager`, or using a [lifecycle configuration](https://docs.aws.amazon.com/sagemaker/latest/dg/notebook-lifecycle-config.html) similar to [this sample](https://github.com/aws-samples/amazon-sagemaker-notebook-instance-lifecycle-config-samples/tree/master/scripts/install-lab-extension)) - or just use plain `Jupyter` instead of `JupyterLab`.

In [None]:
with open("data/annotations/annotations-test.manifest.jsonl", "r") as fman:
    test_examples = [json.loads(l) for l in fman if l.strip()]

util.viz.draw_from_manifest_items(
    test_examples,
    standard_label_field,
    entity_classes,
    imgs_s3uri[len("s3://"):].partition("/")[2],
    textract_s3key_prefix=textract_s3uri[len("s3://"):].partition("/")[2],
    imgs_local_prefix="data/imgs-clean",
    textract_local_prefix="data/textracted",
)

## Self-supervised pre-training

In many cases, businesses have a great deal more relevant *unlabelled* data available in addition to the manually labeled dataset. For example, you might have many more historical documents available (with OCR results already, or able to be processed with Amazon Textract) than you're reasonably able to annotate entities on - just as we do in this example!

Large-scale language models like LayoutLM are typically **pre-trained** on unlabelled data in a **self-supervised** pattern: teaching the model to perform some task implicit in the data, for example masking a few words on the page and predicting which words should go in the gaps.

This pre-training doesn't directly teach the model to perform the target task (classifying entities), but forces the core of the model to learn intrinsic patterns in the data. When we then replace the output layers and **fine-tune** towards the target task with human-labelled data, the model is able to learn the target task more effectively.
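The masking idea can be sketched in a few lines: hide a random subset of words and keep the originals as prediction targets, so a loss would only be computed where something was hidden. This is a simplified, word-level illustration with a hypothetical `mask_words` function; real masked-language-model recipes (BERT's, for example) operate on subword tokens and sometimes substitute a random or unchanged token instead of `[MASK]`, and the provided training script's own masking may differ.

```python
import random


def mask_words(words, mask_prob=0.15, mask_token="[MASK]", seed=None):
    """Randomly hide words, returning (masked_words, targets).

    targets[i] is the original word wherever position i was masked, else None,
    so a training loss would only be computed on the hidden positions.
    """
    rng = random.Random(seed)
    masked, targets = [], []
    for word in words:
        if rng.random() < mask_prob:
            masked.append(mask_token)
            targets.append(word)  # The model must recover this from context
        else:
            masked.append(word)
            targets.append(None)
    return masked, targets


words = "Your account is past due please contact us".split()
masked, targets = mask_words(words, mask_prob=0.3, seed=7)
print(masked)
print(targets)
```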

**In this example, pre-training is optional**:

- By default, for speed, the configuration below will use a public pre-trained model from the [Hugging Face Transformers model repository](https://huggingface.co/models?search=layoutlm). This allows us to focus immediately on fine-tuning to our task; but also means accuracy may be degraded if our documents are very different from the original corpus the model was trained on.
- Alternatively, set `pretrain = True` below to *further* pre-train this same base public model on your own Textracted documents first.

Pre-training is more likely to be valuable where you have a broader range of data available than the core supervised/annotated dataset, and where the language or layouts used in your domain are unusual or specialized. If you followed [Notebook 1](1.%20Data%20Preparation.ipynb) with the default settings, you'll have processed only a small sample of the documents with Amazon Textract: you may like to go back, increase `N_DOCS_KEPT`, and Textract more of the source documents first.

> ⚠️ **Note:** Refer to the [Amazon SageMaker Pricing Page](https://aws.amazon.com/sagemaker/pricing/) for up-to-date guidance before running large pre-training jobs.
>
> In our tests at the time of writing:
>
> - Pre-training on only the 120 "sample" documents to 25 epochs took about 30 minutes on an `ml.p3.8xlarge` instance with per-device batch size 4
> - Pre-training on a larger 500-document subset with the same infrastructure and settings took about an hour
> - Although the observed effect on downstream (entity recognition) accuracy metrics was generally positive in either case, it was small compared to variation over random seed initializations in fine-tuning.

In [None]:
pretrain = False  # Set this True instead to run an additional pre-training job.

pretrained_s3_uri = None  # Will be overwritten later if pretrain is enabled

For self-supervised pre-training, you can utilize the full available corpus of Textract-processed documents: Not just the subset of documents and pages you have annotations for. Reserving some documents for validation is still a good idea though, to understand if and when the model starts to over-fit.

Arguably, including pages from the entity recognition validation dataset in pre-training constitutes [leakage](https://en.wikipedia.org/wiki/Leakage_(machine_learning)): Because even though we're not including any information about the entity labels the NER model will predict, we're teaching the model information about patterns of content in the hold-out pages.

Therefore, the below code takes a conservative view to avoid possibly over-estimating the added benefits of pre-training: Constructing manifests to route *any document with pages in the entity recognition validation set* to also be in the validation set for pre-training.

In [None]:
selfsup_train_manifest_s3uri = f"s3://{bucket_name}/{bucket_prefix}data/docs-train.manifest.jsonl"
selfsup_val_manifest_s3uri = f"s3://{bucket_name}/{bucket_prefix}data/docs-val.manifest.jsonl"

# To avoid information leakage, take the validation set = the set of all documents with *any* pages
# mentioned in the validation set:
val_textract_s3uris = set()
with open("data/annotations/annotations-test.manifest.jsonl", "r") as f:
    for line in f:
        val_textract_s3uris.add(json.loads(line)["textract-ref"])
with open("data/docs-val.manifest.jsonl", "w") as f:
    for uri in val_textract_s3uris:
        f.write(json.dumps({"textract-ref": uri}) + "\n")
print(f"Added {len(val_textract_s3uris)} docs to pre-training validation set")

# Any Textracted docs not mentioned in validation can go to training:
train_textract_s3uris = set()
with open("data/textracted-all.manifest.jsonl", "r") as fner:
    with open("data/docs-train.manifest.jsonl", "w") as f:
        for line in fner:
            uri = json.loads(line)["textract-ref"]
            if (uri in val_textract_s3uris) or (uri in train_textract_s3uris):
                continue
            else:
                train_textract_s3uris.add(uri)
                f.write(json.dumps({"textract-ref": uri}) + "\n")
print(f"Added {len(train_textract_s3uris)} docs to pre-training set")

In [None]:
!aws s3 cp data/docs-train.manifest.jsonl {selfsup_train_manifest_s3uri}
!aws s3 cp data/docs-val.manifest.jsonl {selfsup_val_manifest_s3uri}

With the Amazon Textract JSONs prepared on S3 and split between training and validation via manifests, we're ready to run the pre-training.

> ▶️ See the following *Fine-tuning on annotated data* section for more details and links on how model training works in SageMaker - which are omitted here since this section is optional.

Since customized inputs for this job might be more variable than fine-tuning (because annotating data requires effort, but scaling up your unlabelled corpus may be easy), it's worth mentioning some relevant parameter options:

- **`instance_type`**: While `ml.g4dn.xlarge` is a nice, low-hourly-cost, GPU-enabled option for our small-data fine-tuning job later, the larger data volume in pre-training makes the speed-up available from `ml.p3.2xlarge` more significant. The provided script is multi-GPU capable, so for bigger jobs you may find that `ml.p3.8xlarge` and larger instances give more acceptable run-times.
- **`per_device_train_batch_size`**: Controls *per-accelerator* batching, so bear in mind that moving up to a multi-GPU instance type (such as the 4 GPUs in an `ml.p3.8xlarge`) implicitly increases the overall batch size for learning.
- Other hyperparameters are available, as the implementation is generally based on the [Hugging Face TrainingArguments parser](https://huggingface.co/transformers/main_classes/trainer.html#transformers.TrainingArguments) with [customizations applied in src/code/config.py](src/code/config.py).
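The batch-size point is simple arithmetic under data-parallel training, where every GPU processes its own per-device batch for each optimizer update. The sketch below (with a hypothetical helper name) shows the effect; `gradient_accumulation_steps` is a standard Hugging Face `TrainingArguments` field, but whether the provided script exposes it unchanged is an assumption.

```python
def effective_batch_size(per_device_batch_size, gpus_per_instance, instance_count=1,
                         gradient_accumulation_steps=1):
    """Examples contributing to each optimizer update under data-parallel training."""
    return (per_device_batch_size * gpus_per_instance * instance_count
            * gradient_accumulation_steps)


print(effective_batch_size(4, gpus_per_instance=1))  # e.g. ml.p3.2xlarge (1 GPU): 4
print(effective_batch_size(4, gpus_per_instance=4))  # e.g. ml.p3.8xlarge (4 GPUs): 16
print(effective_batch_size(1, gpus_per_instance=4))  # keep the overall batch at 4 on 4 GPUs
```

If you move to a bigger instance but want comparable learning dynamics, you may need to lower the per-device batch size or revisit `learning_rate` accordingly.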

In [None]:
hyperparameters = {
    # (See src/code/config.py for more info on script parameters)
    "task_name": "mlm",
    "textract_prefix": textract_s3uri[len("s3://"):].partition("/")[2],

    "model_name_or_path": "microsoft/layoutlm-base-uncased",
    
    "learning_rate": 5e-5,
    "per_device_train_batch_size": 4,

    "num_train_epochs": 25,
    "early_stopping_patience": 10,
    "metric_for_best_model": "eval_loss",
    "greater_is_better": "false",
    
    # Early stopping implies checkpointing every evaluation (epoch), so limit the total checkpoints
    # kept to avoid filling up disk:
    "save_total_limit": 10,
    "seed": 42,
}

metric_definitions = [
    { "Name": "epoch", "Regex": util.training.get_hf_metric_regex("epoch") },
    { "Name": "learning_rate", "Regex": util.training.get_hf_metric_regex("learning_rate") },
    { "Name": "train:loss", "Regex": util.training.get_hf_metric_regex("loss") },
    { "Name": "validation:loss", "Regex": util.training.get_hf_metric_regex("eval_loss") },
    {
        "Name": "validation:samples_per_sec",
        "Regex": util.training.get_hf_metric_regex("eval_samples_per_second"),
    },
]

pre_estimator = HuggingFaceEstimator(
    role=sagemaker.get_execution_role(),
    entry_point="train.py",
    source_dir="src",
    py_version="py38",
    pytorch_version="1.9",
    transformers_version="4.11",

    base_job_name="layoutlm-cfpb-pretrain",
    output_path=f"s3://{bucket_name}/{bucket_prefix}trainjobs",

    instance_type="ml.p3.8xlarge",
    instance_count=1,
    volume_size=50,

    debugger_hook_config=False,
#     profiler_config=sagemaker.debugger.ProfilerConfig(
#         framework_profile_params=sagemaker.debugger.FrameworkProfile(),
#     ),

    hyperparameters=hyperparameters,
    metric_definitions=metric_definitions,
    # Required for our custom dataset loading code (which depends on tokenizer):
    environment={ "TOKENIZERS_PARALLELISM": "false", },
)
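Because the `hyperparameters` dict reaches `train.py` as command-line arguments, every value arrives as text. The sketch below assumes a simple `--key value` flattening and parses a few of the values with plain `argparse`. The real script uses a Hugging Face-based parser with the customizations in `src/code/config.py`, so `to_cli_args` and `str2bool` are hypothetical stand-ins. It illustrates why `greater_is_better` is given as the string `"false"` and must be parsed explicitly: `bool("false")` is `True` in Python.

```python
import argparse


def to_cli_args(hyperparameters):
    """Flatten a hyperparameter dict into `--key value` tokens (assumed container behaviour)."""
    args = []
    for key, value in hyperparameters.items():
        args.extend([f"--{key}", str(value)])
    return args


def str2bool(text):
    """Parse 'true'/'false'-style strings, since CLI values always arrive as text."""
    lowered = text.strip().lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean, got {text!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--learning_rate", type=float)
parser.add_argument("--per_device_train_batch_size", type=int)
parser.add_argument("--greater_is_better", type=str2bool)

example = {"learning_rate": 5e-5, "per_device_train_batch_size": 4, "greater_is_better": "false"}
cli_args = to_cli_args(example)
parsed = parser.parse_args(cli_args)
print(cli_args)
print(parsed.learning_rate, parsed.per_device_train_batch_size, parsed.greater_is_better)
```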

In [None]:
if pretrain:
    pre_estimator.fit(
        inputs={
            "train": selfsup_train_manifest_s3uri,
            "textract": textract_s3uri + "/",
            "validation": selfsup_val_manifest_s3uri,
        },
        #wait=False,
    )

Once the pre-training is complete, fetch the output model S3 URI to use as input for the fine-tuning stage:

In [None]:
if pretrain:
    # Un-comment this first line to load a previous pre-training job instead:
    # pre_estimator = HuggingFaceEstimator.attach("layoutlm-cfpb-pretrain-2021-11-17-01-53-05-786")

    pretraining_job_desc = pre_estimator.latest_training_job.describe()
    pretrained_s3_uri = pretraining_job_desc["ModelArtifacts"]["S3ModelArtifacts"]

print(f"Custom pre-trained model: {pretrained_s3_uri}")

## Fine-tuning on annotated data

In this section we'll run a [SageMaker Training Job](https://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-training.html) to fine-tune the model on our annotated dataset.

In this process:

- SageMaker will run the job on a dedicated, managed instance of a type we choose (we'll use `ml.p*` or `ml.g*` GPU-accelerated types), allowing us to keep this notebook's resources modest and pay only for the seconds of GPU time the training job needs.
- The data as specified in the manifest files will be downloaded from Amazon S3.
- The bundle of scripts we provide (in `src/`) will be transparently uploaded to S3 and then run inside the specified SageMaker-provided [framework container](https://docs.aws.amazon.com/sagemaker/latest/dg/docker-containers-prebuilt.html). There's no need for us to build our own container image or implement a serving stack for inference (although fully-custom containers are [also supported](https://docs.aws.amazon.com/sagemaker/latest/dg/docker-containers.html)).
- Job hyperparameters will be passed through to our `src/` scripts as CLI arguments.
- SageMaker will analyze the logs from the job (i.e. `print()` or `logger` calls from our script) with the regular expressions specified in `metric_definitions`, to scrape structured timeseries metrics like loss and accuracy.
- When the job finishes, the contents of the `model` folder in the container will be automatically tarballed and uploaded to a `model.tar.gz` in Amazon S3.
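To sanity-check `metric_definitions` before launching a job, you can run the patterns locally against a sample log line. The sketch below re-creates the notebook's `get_hf_metric_regex` helper (from the fine-tuning configuration cell) and applies it with Python's `re` module. The log line is a made-up example in the dict-print style described in the helper's docstring. SageMaker does its own matching on the service side, so treat this purely as a local check of the patterns.

```python
import re


def get_hf_metric_regex(metric_name):
    """Same pattern as the notebook helper: capture the number after 'metric_name': """
    scientific_number_exp = r"(-?[0-9]+(\.[0-9]+)?(e[+\-][0-9]+)?)"
    return "'" + metric_name + "': " + scientific_number_exp + "[,}]"


# Made-up log line in the Hugging Face dict-print style:
log_line = "{'eval_loss': 0.3940396010875702, 'eval_acc': 0.91, 'learning_rate': 4.5e-05, 'epoch': 1.0}"

results = {}
for name in ("eval_loss", "eval_acc", "learning_rate", "epoch", "loss"):
    match = re.search(get_hf_metric_regex(name), log_line)
    # group(1) is the whole number; the inner groups are its decimal and exponent parts
    results[name] = float(match.group(1)) if match else None
    print(name, results[name])
```

Note that `loss` does not match inside `'eval_loss'`, because the pattern requires an opening quote directly before the metric name: this is what keeps `train:loss` and `validation:loss_avg` separate.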

Rather than orchestrating this process through the low-level [SageMaker API](https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html) (e.g. via [boto3](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_training_job)), we'll use the open-source [SageMaker Python SDK](https://sagemaker.readthedocs.io/en/stable/) (`sagemaker`) for convenience.

Rather than using the base [SageMaker PyTorch framework containers](https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html), we'll take advantage of the [tailored containers for Hugging Face](https://sagemaker.readthedocs.io/en/stable/frameworks/huggingface/index.html). You can also refer to [Hugging Face's own docs for training on SageMaker](https://huggingface.co/transformers/sagemaker.html) for more information, and of course the implementation of our training script here in the `src/` folder.

First, we'll configure some parameters you may **sometimes wish to re-use across training jobs**. Continuation jobs may want to reuse the same checkpoint location in S3, while from-scratch training should start with a fresh one.

▶️ You can choose when to re-run this cell between experiments:

In [None]:
checkpoint_collection_name = "checkpoints-" + datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
print(f"Saving checkpoints to collection {checkpoint_collection_name}")

checkpoint_s3_uri = f"s3://{bucket_name}/{bucket_prefix}checkpoints/{checkpoint_collection_name}"

Next, we'll define the core configuration for our training job:

▶️ This should usually be re-run for every new training job.

In [None]:
hyperparameters = {
    # (See src/code/config.py for more info on script parameters)
    "annotation_attr": standard_label_field,
    "textract_prefix": textract_s3uri[len("s3://"):].partition("/")[2],
    "num_labels": len(fields) + 1,  # +1 for "other"

    "num_train_epochs": 150,  # Set high for automatic HP tuning later
    "early_stopping_patience": 5,  # Usually stops after <20 epochs on this sample data+config
    "metric_for_best_model": "eval_focus_else_acc_minus_one",
    "greater_is_better": "true",

    # Early stopping implies checkpointing every evaluation (epoch), so limit the total checkpoints
    # kept to avoid filling up disk:
    "save_total_limit": 10,
}
if not pretrained_s3_uri:
    hyperparameters["model_name_or_path"] = "microsoft/layoutlm-base-uncased"

def get_hf_metric_regex(metric_name):
    """Build RegEx string to extract a numeric HuggingFace Transformers metric from logs

    HF metric log lines look like a Python dict print e.g:
    {'eval_loss': 0.3940396010875702, ..., 'epoch': 1.0}
    """
    scientific_number_exp = r"(-?[0-9]+(\.[0-9]+)?(e[+\-][0-9]+)?)"
    return "".join((
        "'",
        metric_name,
        "': ",
        scientific_number_exp,
        "[,}]",
    ))

metric_definitions = [
    { "Name": "epoch", "Regex": get_hf_metric_regex("epoch") },
    { "Name": "learning_rate", "Regex": get_hf_metric_regex("learning_rate") },
    { "Name": "train:loss", "Regex": get_hf_metric_regex("loss") },
    { "Name": "validation:n_examples", "Regex": get_hf_metric_regex("eval_n_examples") },
    { "Name": "validation:loss_avg", "Regex": get_hf_metric_regex("eval_loss") },
    { "Name": "validation:acc", "Regex": get_hf_metric_regex("eval_acc") },
    { "Name": "validation:focus_acc", "Regex": get_hf_metric_regex("eval_focus_acc") },
    { "Name": "validation:target", "Regex": get_hf_metric_regex("eval_focus_else_acc_minus_one") },
]

estimator = HuggingFaceEstimator(
    role=sagemaker.get_execution_role(),
    entry_point="train.py",
    source_dir="src",
    py_version="py38",
    pytorch_version="1.9",
    transformers_version="4.11",

    base_job_name="layoutlm-cfpb-hf",
    output_path=f"s3://{bucket_name}/{bucket_prefix}trainjobs",
    #checkpoint_s3_uri=checkpoint_s3_uri,  # Un-comment to turn on checkpoint upload to S3

    instance_type="ml.g4dn.xlarge",  # Could also consider ml.p3.2xlarge
    instance_count=1,
    volume_size=50,

    #debugger_hook_config=False,
    profiler_config=sagemaker.debugger.ProfilerConfig(
        framework_profile_params=sagemaker.debugger.FrameworkProfile(),
    ),

    hyperparameters=hyperparameters,
    metric_definitions=metric_definitions,
    # Required for our custom dataset loading code (which depends on tokenizer):
    environment={ "TOKENIZERS_PARALLELISM": "false", },
)

Finally, the below cell will actually kick off the training job and stream logs from the running container.

> ℹ️ You'll also be able to check the status of the job in the [Training jobs page of the SageMaker Console](https://console.aws.amazon.com/sagemaker/home?#/jobs).

In [None]:
inputs = {
    "train": train_manifest_s3uri,
    "textract": textract_s3uri + "/",
    "validation": test_manifest_s3uri,
}
if pretrained_s3_uri:
    print(f"Using custom pre-trained model {pretrained_s3_uri}")
    inputs["model_name_or_path"] = pretrained_s3_uri

estimator.fit(inputs)

## (Optional) Hyperparameter tuning

Particularly when applying novel techniques or working in new domains, we'll often need to find good values for a range of different *hyperparameters* of our proposed algorithms.

Rather than spending time manually adjusting these parameters, we can use [SageMaker Automatic Model Tuning](https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning.html) which uses an intelligent [Bayesian optimization approach](https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-how-it-works.html) to efficiently and automatically search for high-performing combinations over several training jobs.

You can optionally run the cell below to kick off an HPO job for the model:

In [None]:
tuner = sagemaker.tuner.HyperparameterTuner(
    estimator,
    "validation:target",
    base_tuning_job_name="layoutlm-cfpb-hpo",
    hyperparameter_ranges={
        "learning_rate": sagemaker.parameter.ContinuousParameter(
            1e-8,
            1e-3,
            scaling_type="Logarithmic",
        ),
        "per_device_train_batch_size": sagemaker.parameter.CategoricalParameter([2, 4, 6, 8]),
        "label_smoothing_factor": sagemaker.parameter.CategoricalParameter([0.0, 1e-12, 1e-9, 1e-6]),
    },
    metric_definitions=metric_definitions,
    strategy="Bayesian",
    objective_type="Maximize",
    max_jobs=21,
    max_parallel_jobs=2,
    #early_stopping_type="Auto",  # Off by default - could consider turning it on
#     warm_start_config=sagemaker.tuner.WarmStartConfig(
#         warm_start_type=sagemaker.tuner.WarmStartTypes.IDENTICAL_DATA_AND_ALGORITHM,
#         parents={ "layoutlm-cfpb-hpo-210723-1625" },
#     ),
)

tuner.fit(
    inputs={
        "train": train_manifest_s3uri,
        "textract": textract_s3uri + "/",
        "validation": test_manifest_s3uri,
    },
    wait=False,
)

This job will run asynchronously, so it won't block the notebook, but you can check its status from the [Hyperparameter tuning jobs list](https://console.aws.amazon.com/sagemaker/home?#/hyper-tuning-jobs) of the SageMaker Console.
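The `scaling_type="Logarithmic"` setting on `learning_rate` matters because the range `1e-8` to `1e-3` spans five orders of magnitude. The sketch below compares log-uniform and linear-uniform sampling over that range to show how each spreads candidates across it. It illustrates the effect of the scaling only: SageMaker's Bayesian strategy picks candidates with a model rather than by sampling like this, and `sample_log_uniform` is a hypothetical helper.

```python
import math
import random


def sample_log_uniform(low, high, rng):
    """Draw a value whose logarithm is uniformly distributed on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


rng = random.Random(0)
n = 10_000
log_samples = [sample_log_uniform(1e-8, 1e-3, rng) for _ in range(n)]
linear_samples = [rng.uniform(1e-8, 1e-3) for _ in range(n)]

# Share of candidates below 1e-4, i.e. in the lower four of the range's five decades:
log_share = sum(x < 1e-4 for x in log_samples) / n
linear_share = sum(x < 1e-4 for x in linear_samples) / n
print(f"log-scaled: {log_share:.0%} below 1e-4, linear: {linear_share:.0%} below 1e-4")
```

With logarithmic scaling, roughly four fifths of candidates fall below `1e-4`, matching the four decades of the range that lie there. A linear scale puts only about a tenth there, so the tuner would rarely explore the small learning rates that fine-tuning often needs.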

## Deploy the model

Once our model is trained (or maybe even automatically hyperparameter-tuned over several training jobs), it's ready to be deployed for real-time or batch inference.

Note that if, for some reason, you need to recover the state of a previous training or tuning job after a notebook restart or similar, you can `attach()` to training or tuning jobs by name - as shown below:

In [None]:
# If needed, you can attach to a previous training job by name like this:
# estimator = HuggingFaceEstimator.attach("layoutlm-cfpb-210529-0851-006-5ee95cde")
# tuner = sagemaker.tuner.HyperparameterTuner.attach("layoutlm-cfpb-hpo-210603-0542")

### Easy one-click deployment

For straightforward deployment, you can just call `estimator.deploy()` (or equivalently, `tuner.deploy()`):

In [None]:
training_job_name = estimator.latest_training_job.describe()["TrainingJobName"]
# Or:
# training_job_name = tuner.best_training_job()

predictor = estimator.deploy(
    # Avoid us accidentally deploying the model twice by setting name per training job:
    endpoint_name=training_job_name,
    initial_instance_count=1,
    instance_type="ml.g4dn.xlarge",
    serializer=sagemaker.serializers.JSONSerializer(),
    deserializer=sagemaker.deserializers.JSONDeserializer(),
    # TODO: Disable once debugging is done
    env={ "PYTHONUNBUFFERED": "1" },
)

### (Optional) Digging deeper into the model

Alternatively, you may instead want to explore the artifacts saved by the training job, or edit the `code` script bundle before deploying the endpoint - especially for debugging any problems with inference. Let's see how:

In [None]:
smclient = boto3.client("sagemaker")

In [None]:
training_job_desc = estimator.latest_training_job.describe()
model_s3uri = training_job_desc["ModelArtifacts"]["S3ModelArtifacts"]
model_name = training_job_desc["TrainingJobName"]

In [None]:
!rm -rf ./data/model
!aws s3 cp $model_s3uri ./data/model/model.tar.gz

In [None]:
!cd data/model && tar -xzvf model.tar.gz

In [None]:
from sagemaker.huggingface import HuggingFaceModel

try:
    # Make sure we don't accidentally re-use same model:
    smclient.delete_model(ModelName=model_name)
    print(f"Deleted existing model {model_name}")
except smclient.exceptions.ClientError as e:
    if not (
        e.response["Error"]["Code"] in (404, "404")
        or e.response["Error"].get("Message", "").startswith("Could not find model")
    ):
        raise e

model = HuggingFaceModel(
    name=model_name,
    model_data=model_s3uri,
    role=sagemaker.get_execution_role(),
    source_dir="src/",
    entry_point="inference.py",
    transformers_version="4.11",
    py_version="py38",
    pytorch_version="1.9",
    # TODO: Disable once debugging is done
    env={ "PYTHONUNBUFFERED": "1" },
)

In [None]:
try:
    # Delete previous endpoint, if already in use:
    predictor.delete_endpoint(delete_endpoint_config=True)
    print("Deleting previous endpoint...")
    time.sleep(8)
except (NameError, smclient.exceptions.ResourceNotFound):
    pass  # No existing endpoint to delete
except smclient.exceptions.ClientError as e:
    if "Could not find" not in e.response["Error"].get("Message", ""):
        raise e


print("Deploying model...")
predictor = model.deploy(
    endpoint_name=training_job_desc["TrainingJobName"],
    initial_instance_count=1,
    instance_type="ml.g4dn.xlarge",
    serializer=sagemaker.serializers.JSONSerializer(),
    deserializer=sagemaker.deserializers.JSONDeserializer(),
    # (Container environment variables are already set via `env` on the model above)
)
print("\nDone!")

### (Optional) Optimize costs with Asynchronous Inference

In the examples above, we deployed the trained model to a [real-time inference endpoint](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html) on SageMaker. These real-time endpoints provide synchronous (request-response) inference, and can be configured to [auto-scale](https://docs.aws.amazon.com/sagemaker/latest/dg/endpoint-auto-scaling.html) based on demand.

However, [SageMaker asynchronous inference](https://docs.aws.amazon.com/sagemaker/latest/dg/async-inference.html) may be a better fit for many document processing use-cases:

1. Unlike real-time endpoints (at the time of writing), asynchronous endpoints can [auto-scale down to zero instances](https://docs.aws.amazon.com/sagemaker/latest/dg/async-inference-autoscale.html). This can offer substantial cost savings if your business process is low-volume and often idle: With the trade-off that overall process latency may increase, especially for cold-start requests.
1. Asynchronous inference can support longer timeouts and larger request/response payload sizes than real-time: Which can be useful in cases where an individual document may be long and take a significant time to process with a model.

The OCR pipeline stack is configured to handle either synchronous or asynchronous endpoints by default, so you can alternatively set up an auto-scaling asynchronous endpoint as shown below and use that in the pipeline later.

First, create the asynchronous endpoint:

- As detailed in the [SDK docs](https://sagemaker.readthedocs.io/en/stable/overview.html#sagemaker-asynchronous-inference), the optional `async_inference_config` parameter tells SageMaker that the endpoint will be asynchronous rather than real-time.
- For permissions integration, our async endpoint will need to store its outputs in the proper S3 location the pipeline is expecting (`output_path`). We can look that up from here in the notebook via the same SSM-based `config` we've seen before.
- To resume the pipeline when the model processes a document, our endpoint will need to notify the pipeline's SNS topic. Again, this is given on `config`.
- While the *SageMaker* limits on request/response size are higher for asynchronous endpoints than real-time, we need to also make sure the serving stack *within the container* is configured to support very large payloads. Setting the `MMS_MAX_REQUEST_SIZE` and `MMS_MAX_RESPONSE_SIZE` environment variables below prevents errors related to this. For more information see the [AWSLabs Multi-Model Server configuration doc](https://github.com/awslabs/multi-model-server/blob/master/docs/configuration.md) and corresponding page [for TorchServe](https://github.com/pytorch/serve/blob/master/docs/configuration.md#other-properties).

In [None]:
predictor_async = estimator.deploy(
    # Avoid us accidentally deploying the model twice by setting name per training job:
    endpoint_name="async-" + training_job_name,
    initial_instance_count=1,
    instance_type="ml.g4dn.xlarge",
    serializer=sagemaker.serializers.JSONSerializer(),
    deserializer=sagemaker.deserializers.JSONDeserializer(),
    env={
        "PYTHONUNBUFFERED": "1",  # TODO: Disable once debugging is done
        "MMS_MAX_REQUEST_SIZE": str(100*1024*1024),  # 100MiB instead of default ~6.2MiB
        "MMS_MAX_RESPONSE_SIZE": str(100*1024*1024),  # 100MiB instead of default ~6.2MiB
    },
    async_inference_config=sagemaker.async_inference.AsyncInferenceConfig(
        output_path=f"s3://{config.model_results_bucket}",
        max_concurrent_invocations_per_instance=2,
        notification_config={
            "SuccessTopic": config.model_callback_topic_arn,
            "ErrorTopic": config.model_callback_topic_arn,
        },
    ),
)
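Unlike the real-time `predictor`, requests to an asynchronous endpoint are passed by reference. The input payload must already be in S3, and the call returns immediately with the S3 location where the result will eventually be written. With the `notification_config` above, an SNS notification is also sent when it's ready. The OCR pipeline normally makes these calls for you, but here is a minimal sketch using the low-level `sagemaker-runtime` client's `invoke_endpoint_async`. The helper name and the input URI in the usage comment are hypothetical placeholders:

```python
def invoke_async(runtime_client, endpoint_name, input_s3uri):
    """Submit one asynchronous inference request and return the S3 URI of its future result.

    `runtime_client` is a boto3 "sagemaker-runtime" client; `input_s3uri` must point to a
    JSON request body that has already been uploaded to S3.
    """
    resp = runtime_client.invoke_endpoint_async(
        EndpointName=endpoint_name,
        InputLocation=input_s3uri,
        ContentType="application/json",
    )
    # The output object won't exist until the endpoint has finished processing the request:
    return resp["OutputLocation"]


# Usage (hypothetical input location):
# runtime = boto3.client("sagemaker-runtime")
# output_s3uri = invoke_async(runtime, predictor_async.endpoint_name, "s3://my-bucket/requests/doc-1.json")
```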

Next, [configure auto-scaling](https://docs.aws.amazon.com/sagemaker/latest/dg/async-inference-autoscale.html) for the endpoint by first registering it with the [application auto-scaling service](https://docs.aws.amazon.com/autoscaling/application/userguide/what-is-application-auto-scaling.html) and then applying a scaling policy:

In [None]:
appscaling = boto3.client("application-autoscaling")

# Define and register your endpoint variant
appscaling.register_scalable_target(
    ServiceNamespace="sagemaker",
    ResourceId=f"endpoint/{predictor_async.endpoint_name}/variant/AllTraffic",
    ScalableDimension="sagemaker:variant:DesiredInstanceCount",
    MinCapacity=0,  # (MinCapacity 0 not supported with real-time endpoints)
    MaxCapacity=5,
)
print(f"Endpoint registered with auto-scaling service: {predictor_async.endpoint_name}")

In [None]:
scaling_policy_resp = appscaling.put_scaling_policy(
    PolicyName=f"sagemaker-endpoint-{predictor_async.endpoint_name}",
    ServiceNamespace="sagemaker",
    ResourceId=f"endpoint/{predictor_async.endpoint_name}/variant/AllTraffic",
    ScalableDimension="sagemaker:variant:DesiredInstanceCount",
    PolicyType="TargetTrackingScaling",
    TargetTrackingScalingPolicyConfiguration={
        "TargetValue": 5.0,
        "CustomizedMetricSpecification": {
            "MetricName": "ApproximateBacklogSizePerInstance",
            "Namespace": "AWS/SageMaker",
            "Dimensions": [
                { "Name": "EndpointName", "Value": predictor_async.endpoint_name },
            ],
            "Statistic": "Average",
        },
        "ScaleInCooldown": 5 * 60,  # (seconds)
        "ScaleOutCooldown": 2 * 60,  # (seconds)
    }
)
print(f"Created/updated scaling policy ARN:\n{scaling_policy_resp['PolicyARN']}")
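To build intuition for the `TargetValue` above: target tracking tries to keep the *average backlog per instance* near 5 queued requests. As a rough approximation, the desired capacity is therefore the backlog divided by the target, rounded up and clamped to `MinCapacity`/`MaxCapacity`. This simplified model ignores cooldowns, metric aggregation delays, how scaling out from zero instances is actually triggered, and the details of Application Auto Scaling's real algorithm. Treat it as a back-of-envelope estimate only:

```python
import math


def approx_desired_instances(backlog, target_per_instance=5.0, min_capacity=0, max_capacity=5):
    """Back-of-envelope estimate of target-tracking capacity for a given request backlog."""
    desired = math.ceil(backlog / target_per_instance)
    return max(min_capacity, min(max_capacity, desired))


for backlog in (0, 3, 12, 40):
    print(backlog, "queued ->", approx_desired_instances(backlog), "instance(s)")
```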

If needed, the policy can be later deleted and the endpoint de-registered from auto-scaling:

In [None]:
# appscaling.delete_scaling_policy(
#     PolicyName=f"sagemaker-endpoint-{predictor_async.endpoint_name}",
#     ServiceNamespace="sagemaker",
#     ResourceId=f"endpoint/{predictor_async.endpoint_name}/variant/AllTraffic",
#     ScalableDimension="sagemaker:variant:DesiredInstanceCount",
# )
# print(f"Auto-scaling policy deleted for endpoint {predictor_async.endpoint_name}")

# appscaling.deregister_scalable_target(
#     ServiceNamespace="sagemaker",
#     ResourceId=f"endpoint/{predictor_async.endpoint_name}/variant/AllTraffic",
#     ScalableDimension="sagemaker:variant:DesiredInstanceCount",
# )
# print(f"Auto-scaling de-registered from endpoint {predictor_async.endpoint_name}")
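All four Application Auto Scaling calls above repeat the same `ServiceNamespace`/`ResourceId`/`ScalableDimension` triple, which makes it easy to introduce a typo such as a dropped comma. A small helper can build the triple once. `endpoint_scaling_target` is a hypothetical name and is not part of any SDK:

```python
def endpoint_scaling_target(endpoint_name, variant_name="AllTraffic"):
    """Keyword arguments identifying a SageMaker endpoint variant to Application Auto Scaling."""
    return {
        "ServiceNamespace": "sagemaker",
        "ResourceId": f"endpoint/{endpoint_name}/variant/{variant_name}",
        "ScalableDimension": "sagemaker:variant:DesiredInstanceCount",
    }


# Usage:
# target = endpoint_scaling_target(predictor_async.endpoint_name)
# appscaling.register_scalable_target(**target, MinCapacity=0, MaxCapacity=5)
# appscaling.deregister_scalable_target(**target)
```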

## Using the Model

Once the deployment is complete, we're ready to try it out with some real-time requests!

In [None]:
# As with estimators, you can attach the notebook to a previously deployed endpoint like this:
# from sagemaker.huggingface import HuggingFacePredictor
# predictor = HuggingFacePredictor(
#     "layoutlm-cfpb-hf-2021-09-02-01-08-11-234",
#     serializer=sagemaker.serializers.JSONSerializer(),
#     deserializer=sagemaker.deserializers.JSONDeserializer(),
# )

### Making requests and rendering results

This model accepts Textract-like JSON (e.g. as returned by [AnalyzeDocument](https://docs.aws.amazon.com/textract/latest/dg/API_AnalyzeDocument.html#API_AnalyzeDocument_ResponseSyntax) or [DetectDocumentText](https://docs.aws.amazon.com/textract/latest/dg/API_DetectDocumentText.html#API_DetectDocumentText_ResponseSyntax) APIs) and classifies each `WORD` [block](https://docs.aws.amazon.com/textract/latest/dg/API_Block.html) according to the entity classes we defined earlier: Returning the same JSON with additional fields added to indicate the predictions.
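Because the response is plain Textract-style JSON, you don't strictly need `trp` to read it. Below is a minimal sketch that pulls `(text, predicted class)` pairs from the `WORD` blocks. It assumes, as the entity-grouping code later in this notebook does, that the model adds an integer `PredictedClass` field to each word and that class ID `len(entity_classes)` means "none of the above". The toy response and class names are made up for illustration:

```python
def word_predictions(response, class_names):
    """List (word text, class name) for each WORD block; the name is None for "other"."""
    results = []
    for block in response.get("Blocks", []):
        if block.get("BlockType") != "WORD":
            continue
        cls_id = block.get("PredictedClass")
        # Any ID outside the named classes (e.g. len(class_names)) is treated as "other":
        if cls_id is not None and 0 <= cls_id < len(class_names):
            name = class_names[cls_id]
        else:
            name = None
        results.append((block.get("Text", ""), name))
    return results


toy_response = {
    "Blocks": [
        {"BlockType": "LINE", "Id": "l1", "Text": "Agreement Date: 2021"},
        {"BlockType": "WORD", "Id": "w1", "Text": "Agreement", "PredictedClass": 2},
        {"BlockType": "WORD", "Id": "w2", "Text": "Date:", "PredictedClass": 2},
        {"BlockType": "WORD", "Id": "w3", "Text": "2021", "PredictedClass": 0},
    ]
}
print(word_predictions(toy_response, ["Date", "Amount"]))
```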

We can use utility functions to render these predictions as we did the manual annotations previously:

> ⚠️ **Note:** For long multi-page documents these JSON objects can become very large (many MB), which can cross scaling thresholds such as configured payload limits (from SageMaker and/or TorchServe), inference timeouts (from SageMaker and/or TorchServe) or available memory.
>
> The handling logic in [src/code/inference.py](src/code/inference.py) supports a range of workarounds for this such as specific page selection and passing input and/or output by S3 reference instead of inline - as demonstrated below.
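To illustrate what "target page only" means, here is a sketch that trims a multi-page Textract response down to one page using only the standard library. It assumes the `Page` attribute that Textract's asynchronous (multi-page) APIs set on each block. Single-page synchronous responses may omit it, so blocks without one are treated as page 1. This is not necessarily how `inference.py` implements its page-selection options:

```python
def blocks_for_page(textract_json, page_num):
    """Return a Textract-like dict containing only the blocks on `page_num` (1-based)."""
    return {
        "Blocks": [
            block for block in textract_json.get("Blocks", [])
            if block.get("Page", 1) == page_num
        ]
    }


doc = {
    "Blocks": [
        {"BlockType": "PAGE", "Id": "p1", "Page": 1},
        {"BlockType": "WORD", "Id": "w1", "Page": 1, "Text": "Hello"},
        {"BlockType": "PAGE", "Id": "p2", "Page": 2},
        {"BlockType": "WORD", "Id": "w2", "Page": 2, "Text": "World"},
    ]
}
print([b["Id"] for b in blocks_for_page(doc, 2)["Blocks"]])  # ['p2', 'w2']
```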

In [None]:
import ipywidgets as widgets
import trp

s3 = boto3.resource("s3")

def predict_from_manifest_item(
    item,
    predictor,
    imgs_s3key_prefix=imgs_s3uri[len("s3://"):].partition("/")[2],
    textract_s3key_prefix=textract_s3uri[len("s3://"):].partition("/")[2],
    imgs_local_prefix="data/imgs-clean",
    textract_local_prefix="data/textracted",
    draw=True,
):
    paths = util.viz.local_paths_from_manifest_item(
        item,
        imgs_s3key_prefix,
        textract_s3key_prefix=textract_s3key_prefix,
        imgs_local_prefix=imgs_local_prefix,
        textract_local_prefix=textract_local_prefix,
    )

    ## Basic inline request/response may fail for large, multi-page documents (because of breaking
    ## the 6MB real-time inference payload limit; or the model running out of memory):
#     with open(paths["textract"], "r") as ftextract:
#         result_json = predictor.predict(json.loads(ftextract.read()))

    ## We can strip the JSON down to only the target page of interest like this:
#     with open(paths["textract"], "r") as ftextract:
#         result_json = predictor.predict({
#             "Blocks": trp.Document(
#                 json.loads(ftextract.read()),
#             ).pages[item["page-num"] - 1].blocks
#         })

    ## Or have the model refer directly to S3 and return us only the page of interest:
    result_json = predictor.predict({
        "S3Input": { "URI": item["textract-ref"] },
        "TargetPageNum": item["page-num"],
        "TargetPageOnly": True,
    })

    ## If we wanted, we could even have the model save results to S3 and fetch them ourselves:
    ## (Which is what the OCR pipeline does when configured with a real-time endpoint)
#     result_json = predictor.predict({
#         "S3Input": { "URI": item["textract-ref"] },
#         "TargetPageNum": item["page-num"],
#         "TargetPageOnly": True,
#         "S3Output": { "Bucket": bucket_name, "Key": f"{bucket_prefix}tmp/model-result" },
#     })
#     result_json = json.loads(
#         s3.Bucket(result_json["Bucket"]).Object(result_json["Key"]).get()["Body"].read()
#     )

    result_trp = trp.Document(result_json)

    if draw:
        util.viz.draw_smgt_annotated_page(
            paths["image"],
            entity_classes,
            annotations=[],
            textract_result=result_trp,
            # Note that page_num should be item["page-num"] if we requested the full set of pages
            # from the model above:
            page_num=1,
        )
    return result_trp


widgets.interact(
    lambda ix: predict_from_manifest_item(
        test_examples[ix],
        predictor,
    ),
    ix=widgets.IntSlider(
        min=0,
        max=len(test_examples) - 1,
        step=1,
        value=0,
        description="Example:",
    )
)
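Rather than hard-coding one request style in `predict_from_manifest_item`, you could decide between an inline request and an S3 reference based on the serialized size of the JSON. This sketch assumes a 6 MB inline limit, which is the documented maximum `InvokeEndpoint` body size for real-time endpoints at the time of writing; check the current SageMaker quotas. `fits_inline` and the 80% safety margin are hypothetical choices:

```python
import json

REALTIME_MAX_BYTES = 6 * 1024 * 1024  # Assumed real-time InvokeEndpoint body limit


def fits_inline(payload, limit_bytes=REALTIME_MAX_BYTES, margin=0.8):
    """True if `payload` serializes small enough to send inline, with some headroom."""
    return len(json.dumps(payload).encode("utf-8")) <= limit_bytes * margin


# Usage sketch inside predict_from_manifest_item:
# if fits_inline(textract_json):
#     result_json = predictor.predict(textract_json)
# else:
#     result_json = predictor.predict({"S3Input": {"URI": item["textract-ref"]}, ...})
```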

### From token classification to entity detection

You may have noticed a slight mismatch: We're talking about extracting 'fields' or 'entities' from the document, but our model just classifies individual words. Going from words to entities assumes we're able to understand which words go "together" and what order they should be read in.

Fortunately, Textract helps us out with this too as the word blocks are already collected into `LINE`s.
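Under the hood, this grouping lives in each `LINE` block's `Relationships` list: a `CHILD` relationship whose `Ids` point at the line's `WORD` blocks, in the order Textract lists them. `trp`'s `line.words` resolves this for you. Here is a dependency-free sketch of the same lookup on a made-up two-word line:

```python
def words_by_line(textract_json):
    """List (line text, [child word texts]) pairs for each LINE block."""
    blocks_by_id = {b["Id"]: b for b in textract_json.get("Blocks", [])}
    result = []
    for block in textract_json.get("Blocks", []):
        if block.get("BlockType") != "LINE":
            continue
        # Collect the IDs from all CHILD relationships of this line:
        child_ids = [
            child_id
            for rel in block.get("Relationships", [])
            if rel.get("Type") == "CHILD"
            for child_id in rel.get("Ids", [])
        ]
        words = [
            blocks_by_id[i]["Text"] for i in child_ids
            if i in blocks_by_id and blocks_by_id[i].get("BlockType") == "WORD"
        ]
        result.append((block.get("Text", ""), words))
    return result


doc = {
    "Blocks": [
        {"BlockType": "LINE", "Id": "l1", "Text": "Total Due",
         "Relationships": [{"Type": "CHILD", "Ids": ["w1", "w2"]}]},
        {"BlockType": "WORD", "Id": "w1", "Text": "Total"},
        {"BlockType": "WORD", "Id": "w2", "Text": "Due"},
    ]
}
print(words_by_line(doc))  # [('Total Due', ['Total', 'Due'])]
```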

For many straightforward applications, we can simply loop through the lines on a page and define an "entity detection" as a contiguous group of the same class - as below:

In [None]:
res = predict_from_manifest_item(
    test_examples[6],
    predictor,
    draw=False,
)

In [None]:
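other_cls = len(entity_classes)  # Class ID the model uses for "none of the above"
prev_cls = other_cls
current_entity = ""

for page in res.pages:
    for line in page.lines:
        for word in line.words:
            pred_cls = word._block["PredictedClass"]
            if pred_cls != prev_cls:
                # Class changed: print the entity we just finished (if any), then start a new one
                if prev_cls != other_cls:
                    print(f"----------\n{entity_classes[prev_cls]}:\n{current_entity}")
                prev_cls = pred_cls
                if pred_cls != other_cls:
                    current_entity = word.text
                else:
                    current_entity = ""
                continue
            # Same class as the previous word: extend the current entity
            current_entity = " ".join((current_entity, word.text))

# Print the final entity too, in case the document ends mid-entity:
if prev_cls != other_cls:
    print(f"----------\n{entity_classes[prev_cls]}:\n{current_entity}")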


Of course there may be some instances where this heuristic breaks down, but we still have access to all the position (and text) information from each `LINE` and `WORD` to write additional rules for reading order and separation if wanted.
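As an example of such an extra rule, the sketch below groups consecutive same-class words into entities, but also starts a new entity when the horizontal gap between neighbouring words on a line exceeds a threshold. It uses the `Geometry.BoundingBox` `Left`/`Width` values (fractions of page width) that Textract provides for each `WORD`. The 0.05 threshold and the simple dict input format are assumptions for illustration, not tuned values:

```python
def group_entities(lines, other_cls, max_gap=0.05):
    """Group words into (class_id, text) entities.

    `lines` is a list of lines, each a list of dicts with keys "text", "cls" (predicted
    class ID), "left" and "width" (BoundingBox fractions of page width). A new entity
    starts when the class changes, or when the gap to the previous word on the same
    line exceeds `max_gap`. Same-class runs may continue across line breaks.
    """
    entities = []
    current_cls, current_words = other_cls, []

    def flush():
        if current_cls != other_cls and current_words:
            entities.append((current_cls, " ".join(current_words)))

    for line in lines:
        prev_right = None  # Right edge of the previous word on this line
        for word in line:
            gap = None if prev_right is None else word["left"] - prev_right
            if word["cls"] != current_cls or (gap is not None and gap > max_gap):
                flush()
                current_cls, current_words = word["cls"], []
            current_words.append(word["text"])
            prev_right = word["left"] + word["width"]
    flush()
    return entities


# Usage with the `res` document from above:
# lines = [
#     [{"text": w.text, "cls": w._block["PredictedClass"],
#       "left": w._block["Geometry"]["BoundingBox"]["Left"],
#       "width": w._block["Geometry"]["BoundingBox"]["Width"]} for w in line.words]
#     for page in res.pages for line in page.lines
# ]
# for cls_id, text in group_entities(lines, other_cls=len(entity_classes)):
#     print(entity_classes[cls_id], "->", text)
```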

### Integrating the model with the OCR Pipeline

If you've deployed the **OCR pipeline stack** in your AWS Account, you can now configure it to use this endpoint as follows:

- First, identify the **endpoint name** of your deployed model. Assuming you created the predictor as above, you can simply run the following cell:

In [None]:
print(predictor.endpoint_name)

- Next, identify the **AWS Systems Manager Parameter** that configures the SageMaker endpoint for the OCR pipeline stack.

The below code should pull it through for you, but alternatively you can refer to your stack's **Outputs** in the [AWS CloudFormation Console](https://console.aws.amazon.com/cloudformation/home?#/stacks). The Output name should include `SageMakerEndpoint`.

In [None]:
print(config.sagemaker_endpoint_name_param)

- Finally, we'll update this SSM parameter to point to the deployed SageMaker endpoint.

The below code should do this for you automatically:

In [None]:
pipeline_endpoint_name = predictor.endpoint_name
# Or, if you deployed a SageMaker async endpoint to use instead:
#pipeline_endpoint_name = predictor_async.endpoint_name

print(f"Configuring pipeline with model: {pipeline_endpoint_name}")

ssm = boto3.client("ssm")
ssm.put_parameter(
    Name=config.sagemaker_endpoint_name_param,
    Overwrite=True,
    Value=pipeline_endpoint_name,
)

Alternatively, you could open the [AWS Systems Manager Parameter Store console](https://console.aws.amazon.com/systems-manager/parameters/?tab=Table) and click on the *name* of the parameter to open its detail page, then the **Edit** button in the top right corner as shown below:

![](img/ssm-param-detail-screenshot.png "Screenshot of SSM parameter detail page showing Edit button")

From this screen you can manually set the **Value** of the parameter and save the changes.

Whether you updated the SSM parameter via code or the console, your stack is now configured to use the deployed model for OCR enrichment!
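If you want to double-check that the update took effect before trying the pipeline, you could read the parameter back with the SSM `get_parameter` API. This is a minimal sketch; `check_ssm_value` is a hypothetical helper, and the client is passed in so the helper is easy to reuse:

```python
def check_ssm_value(ssm_client, name, expected):
    """Read an SSM parameter and report whether it holds the expected value."""
    actual = ssm_client.get_parameter(Name=name)["Parameter"]["Value"]
    if actual != expected:
        print(f"WARNING: {name} is {actual!r}, expected {expected!r}")
    return actual == expected


# Usage:
# check_ssm_value(boto3.client("ssm"), config.sagemaker_endpoint_name_param, predictor.endpoint_name)
```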


### Updating the pipeline entity definitions

As well as configuring the *enrichment* stage of the pipeline to reference the deployed version of the model, we need to configure the *post-processing* stage to match the model's **definition of entity/field types**.

The entity configuration is as we saved in the previous notebook, but the `annotation_guidance` attributes are not needed:

> ℹ️ **Note:** As well as the mapping from ID numbers (returned by the model) to human-readable class names, this configuration controls how the pipeline consolidates entity matches into "fields" of the document: E.g. choosing the "most likely" or "first" value between multiple detections, or setting up a multi-value field.

In [None]:
pipeline_entity_config = json.dumps([f.to_dict(omit=["annotation_guidance"]) for f in fields], indent=2)
print(pipeline_entity_config)
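SSM parameter values have size limits. At the time of writing these are 4 KB for the Standard tier and 8 KB for the Advanced tier. Pretty-printing with `indent=2` makes the config easy to read, but if your field list grows it may be worth checking the encoded size and falling back to compact JSON. This is a sketch: the 4096-byte limit is a hard-coded assumption, so check the current SSM quotas for your account:

```python
import json

SSM_STANDARD_MAX_BYTES = 4096  # Assumed Standard-tier value limit (4 KB)


def compact_if_needed(value_json, limit=SSM_STANDARD_MAX_BYTES):
    """Return `value_json`, re-serialized without whitespace if it exceeds `limit` bytes."""
    if len(value_json.encode("utf-8")) <= limit:
        return value_json
    compact = json.dumps(json.loads(value_json), separators=(",", ":"))
    size = len(compact.encode("utf-8"))
    if size > limit:
        raise ValueError(f"Config is {size} bytes even when compacted (limit {limit})")
    return compact


# Usage:
# pipeline_entity_config = compact_if_needed(pipeline_entity_config)
```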

As above, you *could* set this value manually in the SSM console for the `EntityConfig` parameter.

...But we can make the same update via code through the APIs:

In [None]:
print("Setting pipeline entity configuration")
ssm.put_parameter(
    Name=config.entity_config_param,
    Overwrite=True,
    Value=pipeline_entity_config,
)

### Trying out the pipeline

To see the pipeline in action:

▶️ **Open** the [AWS Step Functions Console](https://console.aws.amazon.com/states/home?#/statemachines) and click on the name of your *State Machine* from the list to see its details.

(If you can't find it in the list, the code below should look it up for you or you can check the *Outputs* tab of your pipeline stack in the [AWS CloudFormation Console](https://console.aws.amazon.com/cloudformation/home?#/stacks))

In [None]:
print("Your pipeline state machine is:")
print(config.pipeline_sfn_arn.rpartition(":")[2])

▶️ **Locate** your pipeline's `InputBucket` in [Amazon S3](https://s3.console.aws.amazon.com/s3/home?)

(Likewise you can look this up from CloudFormation or using the below)

In [None]:
print("Your pipeline's input S3 bucket:")
print(config.pipeline_input_bucket_name)

▶️ **Upload** a sample document (PDF) from our dataset to the S3 bucket

You can do this by dragging and dropping the file to the S3 console - or running the cells below to upload a test document through the AWS CLI:

In [None]:
pdfpaths = []
for currpath, dirs, files in os.walk("data/raw"):
    if "/." in currpath or "__" in currpath:
        continue
    pdfpaths += [
        os.path.join(currpath, f) for f in files
        if f.lower().endswith(".pdf")
    ]
pdfpaths = sorted(pdfpaths)

In [None]:
test_filepath = pdfpaths[0]
test_s3uri = f"s3://{config.pipeline_input_bucket_name}/{test_filepath}"

!aws s3 cp '{test_filepath}' '{test_s3uri}'

You should see that a new *execution* (run) of the state machine is triggered automatically:

> ℹ️ This may take a few seconds after the upload is complete. If you're not seeing it:
>
> - Check you're in the correct "pipeline" state machine, as this solution's stack creates more than one state machine
> - Try refreshing the page or the execution list

![](img/sfn-statemachine-screenshot.png "Screenshot of AWS Step Functions state machine detail page showing execution list")

Clicking through to the execution, you'll be able to see the progress through the workflow and output/error information.

Depending on your configuration, your view may look a little different to the below and you may have **either a successful execution or a failure at the review step**:

![](img/sfn-execution-status-screenshot.png "Screenshot of Step Functions execution detail view")
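If you'd rather check progress from the notebook than from the console, the Step Functions `list_executions` API returns a state machine's executions with the most recent first. This minimal sketch assumes `config.pipeline_sfn_arn` is the state machine ARN, as used above. `latest_execution_status` is a hypothetical helper:

```python
def latest_execution_status(sfn_client, state_machine_arn):
    """Return (name, status) of the most recent execution, or None if there are none."""
    resp = sfn_client.list_executions(stateMachineArn=state_machine_arn, maxResults=1)
    executions = resp.get("executions", [])
    if not executions:
        return None
    return executions[0]["name"], executions[0]["status"]


# Usage:
# sfn = boto3.client("stepfunctions")
# print(latest_execution_status(sfn, config.pipeline_sfn_arn))
```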

## Next steps

You should now have been able to train and deploy the enrichment model, and demonstrate its integration to the pipeline.

However, the final human review stage is not fully set up yet, so it may have triggered an error.

In the final notebook, we'll configure the human review functionality to finish up the flow: Open up **notebook [3. Human Review.ipynb](3.%20Human%20Review.ipynb)** to follow along.


### A note on clean-up

Note that while training, processing and transform jobs in SageMaker start and stop compute resources for the specific job being executed, deployed **endpoints** stay active (and therefore keep accumulating charges) until you turn them off.

When you're finished using an endpoint, you should delete it either through the [Amazon SageMaker Console](https://console.aws.amazon.com/sagemaker/home?#/endpoints) or via commands like the below.

(Of course, our OCR pipeline stack will throw an error if you try to run it configured with an Endpoint Name that no longer exists)

In [None]:
# predictor.delete_endpoint(delete_endpoint_config=True)

In [None]:
# predictor_async.delete_endpoint(delete_endpoint_config=True)
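`delete_endpoint(delete_endpoint_config=True)` removes the endpoint and its configuration, but it leaves the SageMaker *model* object that deployment created. That object doesn't run compute, but it does clutter your account. The sketch below uses the low-level SageMaker client to look up the endpoint config and model names before deleting all three. Only use it if no other endpoint shares the same endpoint config or model. `delete_endpoint_and_resources` is a hypothetical helper:

```python
def delete_endpoint_and_resources(sm_client, endpoint_name):
    """Delete an endpoint plus its endpoint config and model(s). Returns the deleted model names."""
    endpoint = sm_client.describe_endpoint(EndpointName=endpoint_name)
    endpoint_config_name = endpoint["EndpointConfigName"]
    endpoint_config = sm_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
    model_names = [variant["ModelName"] for variant in endpoint_config["ProductionVariants"]]

    sm_client.delete_endpoint(EndpointName=endpoint_name)
    sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
    for model_name in model_names:
        sm_client.delete_model(ModelName=model_name)
    return model_names


# Usage:
# delete_endpoint_and_resources(boto3.client("sagemaker"), predictor.endpoint_name)
```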