# Amazon Comprehend custom document classification

## Step 1: Create Amazon Comprehend Classification Training Job <a id="step1"></a>

In this step, we will import some necessary libraries that will be used throughout this notebook.

We will then use a prepared dataset of the appropriate file type (.csv) and structure: one column containing the label of a document, and the other column containing the raw text of that document.

The custom classification model we are going to train is in [Multi-class mode](https://docs.aws.amazon.com/comprehend/latest/dg/prep-classifier-data-multi-class.html), and we will use a CSV file to train the model. You can also train the model with an augmented manifest file; see the Amazon Comprehend documentation for details.

We will look at the CSV training data in the subsequent sections.

In [153]:
import boto3
import botocore
import sagemaker
import json
import os
import io
import datetime
import pandas as pd
from PIL import Image as PILImage  # aliased so IPython.display.Image (imported below) does not shadow it
from pathlib import Path
import multiprocessing as mp
from sagemaker import get_execution_role
from IPython.display import Image, display, HTML, JSON

# variables
data_bucket = 'infiniteloopbucket'
region = boto3.session.Session().region_name
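# STS (AWS Security Token Service) reports the identity behind the current credentials, including the account ID
account_id = boto3.client('sts').get_caller_identity().get('Account')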

os.environ["BUCKET"] = data_bucket
os.environ["REGION"] = region
role = sagemaker.get_execution_role()
#role = arn:aws:iam::426377748928:role/service-role/AmazonComprehendServiceRole-classifier_role2

print(f"SageMaker role is: {role}\nDefault SageMaker Bucket: s3://{data_bucket}")

s3=boto3.client('s3')
comprehend=boto3.client('comprehend', region_name=region)

SageMaker role is: arn:aws:iam::426377748928:role/TeamRole
Default SageMaker Bucket: s3://infiniteloopbucket


In [125]:
role

'arn:aws:iam::426377748928:role/TeamRole'

We will use the pre-prepared dataset and upload it to Amazon S3. The dataset is in `CSV` format and is named `HRemail_df.csv`.

The following code cells upload the training data to the S3 bucket and create a custom Comprehend classifier. You can also create a custom classifier manually from the Amazon Comprehend console.

In [126]:
# Upload the Comprehend training data to S3
key = 'HRemail_df.csv'  # S3 object key for the training file
s3.upload_file(Filename='HRemail_df.csv',  # local file to upload
               Bucket=data_bucket,
               Key=key)


Let's review the training data.

In [127]:
df = pd.read_csv('HRemail_df.csv', names=["Class", "Document"])
df

Unnamed: 0,Class,Document
0,Cold emails,"Reaching out - Factmata\nDear Mr. Cuban,\nApol..."
1,Cold emails,John Lee (MIT graduate) - Inquiry into Biomedi...
2,Cold emails,Editorial internship inquiry\nCarolina VonKamp...
3,Cold emails,Interested in Eight Sleep Marketing Internship...
4,Cold emails,"Hey, my name's Niraj Pant.\nI understand your ..."
...,...,...
417,interview acceptance,"To:\nFrom:\nSubject:\nMs. Decision Maker,\nI a..."
418,interview acceptance,The following template can be used for email i...
419,interview acceptance,To:\nFrom:\nSubject:\nMr./Ms. [Recruiter or Hi...
420,interview acceptance,"Dear Y,\nThank you very much for your mail.\nI..."


In [128]:
classes = df['Class'].unique()
classes_df = pd.DataFrame(classes, columns = ['Classes'])
classes_df

Unnamed: 0,Classes
0,Cold emails
1,Rejection
2,acceptance
3,follow up
4,interview acceptance


Our training dataset contains 5 classes that we will use to train the custom classifier. The first column in the CSV is the class label, and the second column is the document's text, so each line of the file contains a single class and the text of a document that demonstrates that class. If your samples are PDF, PNG, JPG, or TIFF files, you can use an OCR service such as [Amazon Textract](https://docs.aws.amazon.com/textract/latest/dg/what-is.html) to extract the text and prepare the CSV training data.
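
Before uploading, it can help to confirm that the file really has this two-column shape. The sketch below uses only the standard library; the function name `summarize_training_csv` and the `min_per_class` threshold are hypothetical, chosen for illustration (check the Comprehend guidelines for the actual minimum number of documents per class). It assumes the file has no header row, which is what the `COMPREHEND_CSV` format expects.

```python
import csv
import io
from collections import Counter


def summarize_training_csv(csv_text, min_per_class=10):
    """Count documents per class in a two-column (label, text) CSV without a header.

    min_per_class is an illustrative threshold, not an official limit.
    Returns (counts, problems), where problems is a list of readable messages.
    """
    counts = Counter()
    problems = []
    for row_no, row in enumerate(csv.reader(io.StringIO(csv_text)), start=1):
        if len(row) != 2:
            problems.append(f"row {row_no}: expected 2 columns, found {len(row)}")
            continue
        label, text = row
        if not label.strip() or not text.strip():
            problems.append(f"row {row_no}: empty label or document")
            continue
        counts[label] += 1
    for label, n in counts.items():
        if n < min_per_class:
            problems.append(f"class {label!r}: only {n} document(s)")
    return counts, problems


# Tiny made-up sample in the same shape as the training file.
sample = (
    'Rejection,"Dear applicant, we regret to inform you..."\n'
    'acceptance,I am pleased to accept the offer.\n'
)
counts, problems = summarize_training_csv(sample, min_per_class=1)
print(dict(counts))
print(problems)
```

An extra index column or a header row in the file would show up here as a three-column row or as an unexpected class name, which is easier to spot locally than after a failed training job.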

---

Once we have a labeled dataset ready, we are going to create and train an [Amazon Comprehend custom classification model](https://docs.aws.amazon.com/comprehend/latest/dg/how-document-classification.html) with the dataset.

### Create Amazon Comprehend custom classification Training Job

<div class="alert alert-block alert-warning"> <b>💡 NOTE:</b> <p>Executing the model training code block below will start a training job which can take 40 to 60 minutes to complete.</p> </div>

We will use Amazon Comprehend custom classification to train our own model for classifying the documents. The Amazon Comprehend `CreateDocumentClassifier` API creates a classifier and trains a custom model from the labeled CSV file we uploaded above. The training data contains text that was extracted using Amazon Textract and then labeled.

In [158]:
import uuid
from botocore.exceptions import ClientError

uuid_id = uuid.uuid1()

# Look up the account ID of the credentials in use
account_id = boto3.client('sts').get_caller_identity().get('Account')

document_classifier_name = f"custom-doc-class-{uuid_id}"
document_classifier_version = 'v1'
document_classifier_arn = ''

try:
    print(f'Starting training job in region: {region} for account ID: {account_id}, with training data s3://{data_bucket}/{key}')
    create_response = comprehend.create_document_classifier(
        InputDataConfig={
            'DataFormat': 'COMPREHEND_CSV',
            'S3Uri': f's3://{data_bucket}/{key}'  # the object uploaded earlier
        },
        DataAccessRoleArn=role,
        DocumentClassifierName=document_classifier_name,
        VersionName=document_classifier_version,
        LanguageCode='en',
        Mode='MULTI_CLASS'
    )

    document_classifier_arn = create_response['DocumentClassifierArn']
    %store document_classifier_arn
    print(f"Comprehend Custom Classifier created with ARN: {document_classifier_arn}")
except ClientError as error:
    # Only AWS service errors carry a `response` dict with an error code
    if error.response['Error']['Code'] == 'ResourceInUseException':
        print(f'A classifier with the name "{document_classifier_name}" already exists.')
        document_classifier_arn = f'arn:aws:comprehend:{region}:{account_id}:document-classifier/{document_classifier_name}/version/{document_classifier_version}'
        print(f'The classifier ARN is: "{document_classifier_arn}"')
    else:
        print(error)

Starting training job in region: ap-southeast-1 for account ID: 426377748928, with training data s3://infiniteloopbucket/HRemail_df.csv
Stored 'document_classifier_arn' (str)
Comprehend Custom Classifier created with ARN: arn:aws:comprehend:ap-southeast-1:426377748928:document-classifier/custom-doc-class-ea8a86ca-c0a7-11ed-b5cb-b58e85463887/version/v1


This job can take 30 minutes or more to complete. Once the training job is complete, move on to the next step.

### Check status of the Comprehend Custom Classification Job

Let's check the status of the training job.

In [157]:
%%time
# Poll once a minute until training finishes or fails
import time
from datetime import datetime

jobArn = create_response['DocumentClassifierArn']

max_time = time.time() + 3*60*60  # give up after 3 hours
while time.time() < max_time:
    current_time = datetime.now().strftime("%H:%M:%S")
    describe_custom_classifier = comprehend.describe_document_classifier(
        DocumentClassifierArn=jobArn
    )
    status = describe_custom_classifier["DocumentClassifierProperties"]["Status"]
    print(f"{current_time} : Custom document classifier: {status}")

    if status in ("TRAINED", "IN_ERROR"):
        break

    time.sleep(60)

07:27:09 : Custom document classifier: SUBMITTED
07:28:09 : Custom document classifier: IN_ERROR
CPU times: user 19.9 ms, sys: 1.86 ms, total: 21.7 ms
Wall time: 1min


Alternatively, you can also check the status of the training job from the Amazon Comprehend console. Navigate to the [Amazon Comprehend console](https://console.aws.amazon.com/comprehend) screen and click _"Custom classification"_ under the _"Customization"_ menu on the left panel.

In [144]:
import boto3
# Instantiate the Boto3 client. 'region' is a placeholder and must be replaced
# with a real region name before this cell can run.
client = boto3.client('comprehend', region_name='region')
# Create a document classifier (bucket, file, account, and role names are placeholders)
create_response = client.create_document_classifier(
    InputDataConfig={
        'S3Uri': 's3://S3Bucket/docclass/file name'
    },
    DataAccessRoleArn='arn:aws:iam::account number:role/resource name',
    DocumentClassifierName='SampleCodeClassifier1',
    LanguageCode='en'
)
print(f"Create response: {create_response}\n")
# Check the status of the classifier
describe_response = client.describe_document_classifier(
    DocumentClassifierArn=create_response['DocumentClassifierArn'])
print(f"Describe response: {describe_response}\n")
# List all classifiers in the account
list_response = client.list_document_classifiers()
print(f"List response: {list_response}\n")

EndpointConnectionError: Could not connect to the endpoint URL: "https://comprehend.region.amazonaws.com/"

---
## Step 2: Classify Documents using the custom classifier asynchronous analysis job<a id="step2"></a>

In this step we will use the Comprehend classifier model that we just trained to classify a group of unidentified documents. We will use the Comprehend [StartDocumentClassificationJob](https://docs.aws.amazon.com/comprehend/latest/APIReference/API_StartDocumentClassificationJob.html) API to run an asynchronous job that will classify our documents.

Amazon Comprehend async classification works with PDF, PNG, and JPEG files, as well as UTF-8 encoded plaintext files. Since our sample documents under the `sample-docs` directory are in either JPEG, PNG, or PDF format, we will specify a `DocumentReadAction` of `TEXTRACT_DETECT_DOCUMENT_TEXT`. This tells Amazon Comprehend to use the Amazon Textract [DetectDocumentText](https://docs.aws.amazon.com/textract/latest/dg/API_DetectDocumentText.html) API behind the scenes to extract the text and then perform classification. For `InputFormat`, we will use `ONE_DOC_PER_FILE`, which signifies that each file is a single document. The other mode, `ONE_DOC_PER_LINE`, treats every line in a plaintext file as a document and is best suited to small documents such as product reviews or customer service chat transcripts. For more on this, see the [documentation](https://docs.aws.amazon.com/comprehend/latest/dg/how-class-run.html).
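
The configuration described above can be sketched as a small function that builds the `InputDataConfig` dictionary and adds a `DocumentReaderConfig` only when native (non-plaintext) files are present. The function name `build_input_config`, the `NATIVE_SUFFIXES` set, and the example bucket name are hypothetical; the keys and values are the ones used in this notebook.

```python
from pathlib import PurePosixPath

# File types named in the paragraph above as native document formats.
NATIVE_SUFFIXES = {".pdf", ".png", ".jpg", ".jpeg"}


def build_input_config(s3_uri, file_names):
    """Build an InputDataConfig dict for start_document_classification_job."""
    config = {"S3Uri": s3_uri, "InputFormat": "ONE_DOC_PER_FILE"}
    suffixes = {PurePosixPath(name).suffix.lower() for name in file_names}
    if suffixes & NATIVE_SUFFIXES:
        # Native documents need text extraction before classification.
        config["DocumentReaderConfig"] = {
            "DocumentReadAction": "TEXTRACT_DETECT_DOCUMENT_TEXT",
            "DocumentReadMode": "FORCE_DOCUMENT_READ_ACTION",
        }
    return config


print(build_input_config("s3://example-bucket/comprehend/doc-class-samples/",
                         ["invoice.PDF", "scan.png"]))
print(build_input_config("s3://example-bucket/comprehend/doc-class-samples/",
                         ["notes.txt"]))
```

Plaintext-only input leaves `DocumentReaderConfig` out, since that setting is optional.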

To begin with the classification of the sample documents, first let's upload them into the S3 bucket.

In [43]:
# Upload data to S3 bucket:
!aws s3 sync ./sample-docs s3://{data_bucket}/comprehend/doc-class-samples/


The user-provided path ./sample-docs does not exist.


Once the documents are uploaded, we will start a classification job using the [StartDocumentClassificationJob](https://docs.aws.amazon.com/comprehend/latest/APIReference/API_StartDocumentClassificationJob.html) API and the configuration discussed above.

In [44]:
import uuid

jobname = f'classification-job-{uuid.uuid1()}'
print(f'Starting Comprehend Classification job {jobname} with model {document_classifier_arn}')

response = comprehend.start_document_classification_job(
    JobName=jobname,
    DocumentClassifierArn=document_classifier_arn,
    InputDataConfig={
        'S3Uri': f's3://{data_bucket}/comprehend/doc-class-samples/',
        'InputFormat': 'ONE_DOC_PER_FILE',
        'DocumentReaderConfig': {
            'DocumentReadAction': 'TEXTRACT_DETECT_DOCUMENT_TEXT',
            'DocumentReadMode': 'FORCE_DOCUMENT_READ_ACTION'
        }
    },
    OutputDataConfig={
        'S3Uri': f's3://{data_bucket}/comprehend/doc-class-output/'
    },
    DataAccessRoleArn=role
)

JSON(response)

Starting Comprehend Classification job classification-job-ca9fd656-c013-11ed-b5cb-b58e85463887 with model arn:aws:comprehend:ap-southeast-1:426377748928:document-classifier/custom-doc-class-3b58d8e4-c00e-11ed-b5cb-b58e85463887/version/v1


ResourceUnavailableException: An error occurred (ResourceUnavailableException) when calling the StartDocumentClassificationJob operation: CLASSIFIER_NOT_TRAINED_MESSAGE: Classifier (Optional[arn:aws:comprehend:ap-southeast-1:426377748928:document-classifier/custom-doc-class-3b58d8e4-c00e-11ed-b5cb-b58e85463887/version/v1]) must have trained status.

### Check status of the classification job

The code block below checks the status of the classification job. If the job completes, it downloads the output predictions. The output is a compressed tar archive that contains the inference result for each of the documents being classified, along with the output of the Textract operation performed by Amazon Comprehend.

In [None]:
%%time
# Poll until the classification job finishes (can take up to 10 minutes),
# then download and unpack its output
import json
import os
import tarfile
import time
from datetime import datetime

classify_response = response
max_time = time.time() + 3*60*60  # give up after 3 hours
documents = []

while time.time() < max_time:
    current_time = datetime.now().strftime("%H:%M:%S")
    describe_job = comprehend.describe_document_classification_job(
        JobId=classify_response['JobId']
    )
    status = describe_job["DocumentClassificationJobProperties"]["JobStatus"]

    print(f"{current_time} : Custom document classifier Job: {status}")

    if status == "COMPLETED":
        classify_output_file = describe_job["DocumentClassificationJobProperties"]["OutputDataConfig"]["S3Uri"]
        print(f'Output generated - {classify_output_file}')
        !mkdir -p classification-output
        !aws s3 cp {classify_output_file} ./classification-output

        # Extract the downloaded archive
        opfile = os.path.basename(classify_output_file)
        with tarfile.open(f'./classification-output/{opfile}') as archive:
            archive.extractall('./classification-output')

        # Collect the predicted classes from every .out file
        for name in os.listdir('./classification-output'):
            if name.endswith('.out'):
                with open(f'./classification-output/{name}', 'r') as f:
                    documents.append(dict(file=name, classification_output=json.load(f)['Classes']))
        break
    if status == "FAILED":
        print("Classification job failed")
        print(describe_job)
        break

    time.sleep(10)

Let's take a look at the Amazon Comprehend classification output. We have collected the output for all the files in a `documents` variable. The script above downloads and extracts the archive locally, so you can navigate into the `classification-output` directory from the file browser panel on the left and inspect the files manually.

In [None]:
for doc in documents:
    print(f"File: {doc['file']}")
    for doc_class in doc['classification_output']:
        print(f"└── Class: {doc_class['Name']} , Score: {round(doc_class['Score'] * 100, 2)}%")
    print("\n")
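
To reduce each document to a single predicted label, one option is to keep the highest-scoring class and treat low scores as uncertain. The sketch below assumes each entry has the `Name` and `Score` keys printed above; the function name `top_class`, the `min_score` threshold, and the example record are hypothetical.

```python
def top_class(classes, min_score=0.0):
    """Return the highest-scoring class dict, or None if it is below min_score."""
    if not classes:
        return None
    best = max(classes, key=lambda c: c["Score"])
    return best if best["Score"] >= min_score else None


# Made-up record in the same shape as an entry of `documents` above.
example = {
    "file": "example.out",
    "classification_output": [
        {"Name": "Rejection", "Score": 0.12},
        {"Name": "follow up", "Score": 0.81},
        {"Name": "acceptance", "Score": 0.07},
    ],
}
best = top_class(example["classification_output"], min_score=0.5)
print(example["file"], "->", best["Name"] if best else "uncertain")
```

Returning `None` below the threshold makes it easy to route uncertain documents to manual review instead of forcing a label.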

---

## Step 3: Create Document classification real-time endpoint

<div class="alert alert-block alert-warning">
    <b>⚠️ Note:</b> Creation of a real-time endpoint can take up to 15 minutes.
</div>


Once our Comprehend custom classifier is fully trained (i.e. status = `TRAINED`), you can also create a real-time endpoint and use it to classify documents in real time. The following code cells use the `comprehend` Boto3 client to create an endpoint, but you can also create one manually via the console. Instructions on how to do that can be found in the subsequent section.

In [None]:
# Create a Comprehend endpoint
import uuid
from botocore.exceptions import ClientError

temp_id = str(uuid.uuid1())
model_arn = document_classifier_arn
ep_name = f'classifier-endpoint-{temp_id.split("-")[0]}'

try:
    endpoint_response = comprehend.create_endpoint(
        EndpointName=ep_name,
        ModelArn=model_arn,
        DesiredInferenceUnits=1,
        DataAccessRoleArn=role
    )
    ENDPOINT_ARN = endpoint_response['EndpointArn']
    print(f'Endpoint created with ARN: {ENDPOINT_ARN}')
except ClientError as error:
    # Only AWS service errors carry a `response` dict with an error code
    if error.response['Error']['Code'] == 'ResourceInUseException':
        print(f'An endpoint with the name "{ep_name}" already exists.')
        ENDPOINT_ARN = f'arn:aws:comprehend:{region}:{account_id}:document-classifier-endpoint/{ep_name}'
        print(f'The classifier endpoint ARN is: "{ENDPOINT_ARN}"')
        %store ENDPOINT_ARN
    else:
        print(error)
    

In [None]:
%store ENDPOINT_ARN

In [None]:
JSON(endpoint_response)

Alternatively, use the steps below to create a Comprehend endpoint using the AWS console.

- Go to [Comprehend on AWS Console](https://console.aws.amazon.com/comprehend/v2/home?region=us-east-1#endpoints) and click on Endpoints in the left menu.
- Click on "Create endpoint"
- Give an Endpoint name; for Custom model type select Custom classification; for version select no version or the latest version of the model.
- For Classifier model select from the drop down menu
- For Inference Unit select 1
- Check "Acknowledge"
- Click "Create endpoint"

It may take ~15 minutes for the endpoint to be created; you can follow its progress on the [Endpoints page](https://console.aws.amazon.com/comprehend/v2/home?region=us-east-1#endpoints) of the console. The code cell below checks the creation status.


In [None]:
%%time
# Poll until the endpoint is in service or has failed
import time
from datetime import datetime

ep_arn = ENDPOINT_ARN  # set in the endpoint creation cell above

max_time = time.time() + 3*60*60  # give up after 3 hours
while time.time() < max_time:
    current_time = datetime.now().strftime("%H:%M:%S")
    describe_endpoint_resp = comprehend.describe_endpoint(
        EndpointArn=ep_arn
    )
    status = describe_endpoint_resp["EndpointProperties"]["Status"]
    print(f"{current_time} : Custom document classifier endpoint: {status}")

    if status in ("IN_SERVICE", "FAILED"):
        break

    time.sleep(10)
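
The polling loops in this notebook (training, classification job, endpoint) all share the same shape. As a sketch only, a hypothetical helper named `wait_for_status` could factor that pattern out; the `sleep` and `clock` parameters are an assumption added here so the example can run instantly without AWS access.

```python
import time


def wait_for_status(fetch_status, terminal_states, interval=10, timeout=3 * 60 * 60,
                    sleep=time.sleep, clock=time.time):
    """Call fetch_status() until it returns a terminal state or the timeout elapses."""
    deadline = clock() + timeout
    while True:
        status = fetch_status()
        if status in terminal_states:
            return status
        if clock() >= deadline:
            raise TimeoutError(f"still {status!r} after {timeout} seconds")
        sleep(interval)


# Offline demonstration with a canned sequence of statuses.
_statuses = iter(["CREATING", "CREATING", "IN_SERVICE"])
final = wait_for_status(lambda: next(_statuses), {"IN_SERVICE", "FAILED"},
                        interval=0, sleep=lambda seconds: None)
print(final)
```

In the notebook, `fetch_status` could be a lambda such as `lambda: comprehend.describe_endpoint(EndpointArn=ENDPOINT_ARN)["EndpointProperties"]["Status"]`.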
    

---
## Step 4: Classify Documents using the real-time endpoint <a id="step4"></a>

Once the endpoint has been created, we will take some sample documents from the `sample-docs` directory and try to classify them.

In [None]:
from botocore.exceptions import ClientError

# Replace this with any document name in the sample-docs directory
document = "CMS1500.png"

with open(f"./sample-docs/{document}", mode='rb') as file:
    document_bytes = file.read()
try:
    response = comprehend.classify_document(
        Bytes=document_bytes,
        DocumentReaderConfig={
            "DocumentReadAction": "TEXTRACT_ANALYZE_DOCUMENT",
            "DocumentReadMode": "FORCE_DOCUMENT_READ_ACTION",
            "FeatureTypes": ["FORMS"]
        },
        EndpointArn=ENDPOINT_ARN
    )
    classes = response['Classes']
    # ExtractedCharacters has one entry per page; take the first page
    metadata = response['DocumentMetadata']['ExtractedCharacters'][0]
    print(f"File: {document}")
    print(f"Page: {metadata['Page']}, Character count: {metadata['Count']}")
    for doc_class in classes:
        print(f"└── Class: {doc_class['Name']} , Score: {round(doc_class['Score'] * 100, 2)}%")
except ClientError as e:
    print(e)
    # Reason and Detail are not present on every error, so read them defensively
    print("Error", e.response.get('Reason'), e.response.get('Detail', {}).get('Reason'))

In the code cell above, we classified a document in real time using the endpoint we created earlier. Real-time endpoints suit use cases with low-latency requirements. One important limit: when you pass native semi-structured documents to the real-time classify document API, the maximum number of pages supported is one, so the real-time endpoint is suitable only for single-page documents. If you have more than 20 documents, or multi-page documents, use the asynchronous analysis job API shown earlier in this notebook.
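
The rule of thumb above can be written down as a small decision function. This is only an illustration of the guidance in the text: the function name `choose_analysis_mode` is hypothetical, and the 20-document threshold is the figure quoted above, exposed as a parameter.

```python
def choose_analysis_mode(page_counts, max_realtime_documents=20):
    """Suggest 'real-time' or 'async' from a list of per-document page counts."""
    multi_page = any(pages > 1 for pages in page_counts)
    too_many = len(page_counts) > max_realtime_documents
    return "async" if multi_page or too_many else "real-time"


print(choose_analysis_mode([1, 1, 1]))   # a few single-page documents
print(choose_analysis_mode([1, 3]))      # one multi-page document
print(choose_analysis_mode([1] * 25))    # many single-page documents
```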

---

## Cleanup

In this step we will delete the document classification real-time endpoint, since you are charged for any endpoint that remains deployed. It can take ~5 to 10 minutes to delete the endpoint.

In [None]:
ep_del_response = comprehend.delete_endpoint(EndpointArn=ENDPOINT_ARN)
JSON(ep_del_response)

Once the endpoint is fully deleted, let's delete the trained document classifier model.

In [None]:
dc_del_response = comprehend.delete_document_classifier(DocumentClassifierArn = document_classifier_arn)
JSON(dc_del_response)

Finally, delete the sample documents and classification output files from S3.

In [None]:
!aws s3 rm s3://{data_bucket}/comprehend/ --recursive

---
## Conclusion

In this notebook we learned how to train an Amazon Comprehend custom classifier using a pre-prepared dataset, which was constructed from sample documents by extracting their text with Amazon Textract and labeling the data in a CSV file. We then trained an Amazon Comprehend custom classifier with the extracted text and created an Amazon Comprehend classifier real-time endpoint to perform classification of documents. Finally, we passed documents in their native format (JPG, PNG, PDF) directly to the classification APIs, without any prior extraction or conversion, to determine the document class with both an asynchronous analysis job and a real-time endpoint.