# Amazon Comprehend custom document classification

In [153]:
import boto3
import botocore
import sagemaker
import json
import os
import io
import datetime
import pandas as pd
from PIL import Image
from pathlib import Path
import multiprocessing as mp
from sagemaker import get_execution_role
from IPython.display import Image, display, HTML, JSON

# variables
data_bucket = 'infiniteloopbucket'
region = boto3.session.Session().region_name
account_id = boto3.client('sts').get_caller_identity().get('Account') #WHAT IS STS

os.environ["BUCKET"] = data_bucket
os.environ["REGION"] = region
role = sagemaker.get_execution_role()
#role = arn:aws:iam::426377748928:role/service-role/AmazonComprehendServiceRole-classifier_role2

print(f"SageMaker role is: {role}\nDefault SageMaker Bucket: s3://{data_bucket}")

s3=boto3.client('s3')
comprehend=boto3.client('comprehend', region_name=region)

SageMaker role is: arn:aws:iam::426377748928:role/TeamRole
Default SageMaker Bucket: s3://infiniteloopbucket


In [125]:
role

'arn:aws:iam::426377748928:role/TeamRole'

In [126]:
# Upload Comprehend training data to S3
key='HRemail_df.csv' #PATH TO CSV LOCATION
s3.upload_file(Filename='HRemail_df.csv', #FILE NAME
               Bucket=data_bucket, 
               Key=key)


Let's review the training data

In [127]:
df = pd.read_csv('HRemail_df.csv', names=["Class", "Document"])
df

Unnamed: 0,Class,Document
0,Cold emails,"Reaching out - Factmata\nDear Mr. Cuban,\nApol..."
1,Cold emails,John Lee (MIT graduate) - Inquiry into Biomedi...
2,Cold emails,Editorial internship inquiry\nCarolina VonKamp...
3,Cold emails,Interested in Eight Sleep Marketing Internship...
4,Cold emails,"Hey, my name's Niraj Pant.\nI understand your ..."
...,...,...
417,interview acceptance,"To:\nFrom:\nSubject:\nMs. Decision Maker,\nI a..."
418,interview acceptance,The following template can be used for email i...
419,interview acceptance,To:\nFrom:\nSubject:\nMr./Ms. [Recruiter or Hi...
420,interview acceptance,"Dear Y,\nThank you very much for your mail.\nI..."


In [128]:
classes = df['Class'].unique()
classes_df = pd.DataFrame(classes, columns = ['Classes'])
classes_df

Unnamed: 0,Classes
0,Cold emails
1,Rejection
2,acceptance
3,follow up
4,interview acceptance


---

Once we have a labeled dataset ready we are going to create and train a [Amazon Comprehend custom classification model](https://docs.aws.amazon.com/comprehend/latest/dg/how-document-classification.html) with the dataset.

### Check status of the Comprehend Custom Classification Job

Let's check the status of the training job.

In [157]:
%%time
# Loop through and wait for the training to complete.
import time
from datetime import datetime
import pprint



jobArn = create_response['DocumentClassifierArn']

max_time = time.time() + 3*60*60 # 3 hours
while time.time() < max_time:
    now = datetime.now()
    current_time = now.strftime("%H:%M:%S")
    describe_custom_classifier = comprehend.describe_document_classifier(
        DocumentClassifierArn = jobArn
        
    )
    #describe_custom_classifier = comprehend.describe_document_classifier(DocumentClassifierArn = create_response['DocumentClassifierArn'])
    #print("describe: %s\n", describe_custom_classifier)
    
    status = describe_custom_classifier["DocumentClassifierProperties"]["Status"]
    print(f"{current_time} : Custom document classifier: {status}")
    
    if status == "TRAINED" or status == "IN_ERROR":
        break
    
    #pprint.pprint(describe_custom_classifier)
        
    time.sleep(60)
    

07:27:09 : Custom document classifier: SUBMITTED
07:28:09 : Custom document classifier: IN_ERROR
CPU times: user 19.9 ms, sys: 1.86 ms, total: 21.7 ms
Wall time: 1min


In [144]:
import boto3
# Instantiate Boto3 SDK:
client = boto3.client('comprehend', region_name='region')
# Create a document classifier
create_response = client.create_document_classifier(
    InputDataConfig={
    'S3Uri': 's3://S3Bucket/docclass/file name'
    },
    DataAccessRoleArn='arn:aws:iam::account number:role/resource name',
    DocumentClassifierName='SampleCodeClassifier1',
    LanguageCode='en'
)
print("Create response: %s\n", create_response)
# Check the status of the classifier
describe_response = client.describe_document_classifier(
DocumentClassifierArn=create_response['DocumentClassifierArn'])
print("Describe response: %s\n", describe_response)
# List all classifiers in account
list_response = client.list_document_classifiers()
print("List response: %s\n", list_response)

EndpointConnectionError: Could not connect to the endpoint URL: "https://comprehend.region.amazonaws.com/"