# Using AI21 Grammatical Error Correction (GEC) on SageMaker through Model Packages

This sample notebook shows you how to deploy **AI21 Grammatical Error Correction (GEC)** using Amazon SageMaker.

## Pre-requisites:
1. Before running this notebook, make sure you obtained it from the model catalog in the SageMaker AWS Management Console.
1. **Note**: This notebook contains elements that render correctly in the Jupyter interface. Open it from an Amazon SageMaker Notebook Instance or Amazon SageMaker Studio.
1. Ensure that the IAM role you use has the **AmazonSageMakerFullAccess** policy attached.
1. This notebook is intended to work with **boto3 v1.25.4** or higher.

## Contents:
1. [Select model package](#1.-Select-model-package)
1. [Create an endpoint and perform real-time inference](#2.-Create-an-endpoint-and-perform-real-time-inference)
   1. [Create an endpoint](#A.-Create-an-endpoint)
   1. [Proofread and detect mistakes](#B.-Proofread-and-detect-mistakes)
1. [Clean-up](#3.-Clean-up)
   1. [Delete the endpoint](#A.-Delete-the-endpoint)
   1. [Delete the model](#B.-Delete-the-model)


## Usage instructions
You can run this notebook one cell at a time (press Shift+Enter to run a cell).

## 1. Select model package
Confirm that you received this notebook from the model catalog in the SageMaker AWS Management Console.

In [13]:
model_package_map = {
    "us-east-1": "arn:aws:sagemaker:us-east-1:865070037744:model-package/gec-1-0-004-60323e2bffcf3443a9b89444c23ee912",
    "us-east-2": "arn:aws:sagemaker:us-east-2:057799348421:model-package/gec-1-0-004-60323e2bffcf3443a9b89444c23ee912",
    "us-west-1": "arn:aws:sagemaker:us-west-1:382657785993:model-package/gec-1-0-004-60323e2bffcf3443a9b89444c23ee912",
    "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/gec-1-0-004-60323e2bffcf3443a9b89444c23ee912",
    "ca-central-1": "arn:aws:sagemaker:ca-central-1:470592106596:model-package/gec-1-0-004-60323e2bffcf3443a9b89444c23ee912",
    "eu-central-1": "arn:aws:sagemaker:eu-central-1:446921602837:model-package/gec-1-0-004-60323e2bffcf3443a9b89444c23ee912",
    "eu-west-1": "arn:aws:sagemaker:eu-west-1:985815980388:model-package/gec-1-0-004-60323e2bffcf3443a9b89444c23ee912",
    "eu-west-2": "arn:aws:sagemaker:eu-west-2:856760150666:model-package/gec-1-0-004-60323e2bffcf3443a9b89444c23ee912",
    "eu-west-3": "arn:aws:sagemaker:eu-west-3:843114510376:model-package/gec-1-0-004-60323e2bffcf3443a9b89444c23ee912",
    "eu-north-1": "arn:aws:sagemaker:eu-north-1:136758871317:model-package/gec-1-0-004-60323e2bffcf3443a9b89444c23ee912",
    "ap-southeast-1": "arn:aws:sagemaker:ap-southeast-1:192199979996:model-package/gec-1-0-004-60323e2bffcf3443a9b89444c23ee912",
    "ap-southeast-2": "arn:aws:sagemaker:ap-southeast-2:666831318237:model-package/gec-1-0-004-60323e2bffcf3443a9b89444c23ee912",
    "ap-northeast-2": "arn:aws:sagemaker:ap-northeast-2:745090734665:model-package/gec-1-0-004-60323e2bffcf3443a9b89444c23ee912",
    "ap-northeast-1": "arn:aws:sagemaker:ap-northeast-1:977537786026:model-package/gec-1-0-004-60323e2bffcf3443a9b89444c23ee912",
    "ap-south-1": "arn:aws:sagemaker:ap-south-1:077584701553:model-package/gec-1-0-004-60323e2bffcf3443a9b89444c23ee912",
    "sa-east-1": "arn:aws:sagemaker:sa-east-1:270155090741:model-package/gec-1-0-004-60323e2bffcf3443a9b89444c23ee912"
}

In [14]:
import json
from sagemaker import ModelPackage
from sagemaker import get_execution_role
import sagemaker as sage
import boto3

### Check the version of boto3 - must be v1.25.4 or higher
If you see a lower version number, switch to another kernel (Python 3.8 or above) to run the notebook.

In [15]:
boto3.__version__

'1.26.74'
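
The version requirement above can also be checked in code rather than by eye. The following is a minimal sketch, assuming the version string consists of purely numeric dot-separated parts; `version_tuple` and `meets_minimum` are hypothetical helper names, not part of boto3. Comparing tuples of integers avoids the pitfall of comparing version strings directly, where `"1.9.100"` would sort after `"1.25.4"`.

```python
def version_tuple(version):
    """Convert a dotted version string such as '1.26.74' into a tuple of ints.

    Assumption: every part is purely numeric (no 'rc1' or 'dev' suffixes).
    """
    return tuple(int(part) for part in version.split(".")[:3])


def meets_minimum(installed, minimum="1.25.4"):
    """Return True if the installed version is at least the minimum version."""
    return version_tuple(installed) >= version_tuple(minimum)


# In the notebook you could call: meets_minimum(boto3.__version__)
print(meets_minimum("1.26.74"))
```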

### Install ai21 python SDK

In [16]:
! pip install -U "ai21[SM]"
import ai21

Looking in indexes: https://pypi.org/simple, https://pip.repos.neuron.amazonaws.com


In [17]:
region = boto3.Session().region_name
if region not in model_package_map:
    # Raise a real exception; raising a plain string fails with a TypeError
    raise ValueError(f"Unsupported region: {region}")

model_package_arn = model_package_map[region]

In [18]:
role = get_execution_role()
sagemaker_session = sage.Session()

runtime_sm_client = boto3.client("runtime.sagemaker")

## 2. Create an endpoint and perform real-time inference

If you want to understand how real-time inference with Amazon SageMaker works, see the [documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/deploy-model.html).

In [19]:
endpoint_name = "gec"

content_type = "application/json"

real_time_inference_instance_type = (
    "ml.g4dn.2xlarge"
)

### A. Create an endpoint

In [20]:
# Create a deployable model from the model package
model = ModelPackage(
    role=role, model_package_arn=model_package_arn, sagemaker_session=sagemaker_session
)

# Deploy the model to a single instance (both timeouts are in seconds)
predictor = model.deploy(
    initial_instance_count=1,
    instance_type=real_time_inference_instance_type,
    endpoint_name=endpoint_name,
    model_data_download_timeout=3600,
    container_startup_health_check_timeout=600,
)

-----------------!

Once the endpoint has been created, you can perform real-time inference.

### B. Proofread and detect mistakes

**AI21 Grammatical Error Correction (GEC) model** offers access to our world-class proofreading engine. It has been specifically developed for providing suggestions on how to correct spelling mistakes, grammar, punctuation, word misuse, and more.

This model takes a piece of text and returns a list of suggested corrections for detected mistakes, along with the type of mistake.

The input text should contain **no more than 500 characters**. The output will include all detected mistakes and suggested corrections.
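
Texts longer than the 500-character limit need to be split before they are sent to the endpoint. The following is a minimal sketch of one way to do that, preferring sentence boundaries and falling back to a hard cut when none is found. `split_for_gec` is a hypothetical helper, not part of the ai21 SDK, and the sentence-boundary rule (punctuation followed by whitespace) is a simplifying assumption.

```python
import re

MAX_CHARS = 500  # limit stated above for a single GEC request


def split_for_gec(text, max_chars=MAX_CHARS):
    """Split text into chunks of at most max_chars characters.

    Returns a list of (offset, chunk) pairs, where offset is the position
    of the chunk's first character in the original text.
    """
    chunks = []
    start = 0
    while start < len(text):
        end = min(start + max_chars, len(text))
        if end < len(text):
            # Prefer to cut after the last sentence-ending punctuation
            # followed by whitespace inside the current window.
            window = text[start:end]
            matches = list(re.finditer(r"[.!?]\s", window))
            if matches:
                end = start + matches[-1].end()
        chunks.append((start, text[start:end]))
        start = end
    return chunks


long_text = "This is a sentence. " * 60  # 1200 characters
for offset, chunk in split_for_gec(long_text):
    print(offset, len(chunk))
```

Each chunk can then be sent as a separate request. Because the indices in a response are relative to the chunk that was sent, adding the chunk's offset to `startIndex` and `endIndex` maps them back to positions in the full text.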

The following is a simple example of detecting different types of grammatical errors and suggesting corrections for them.

In [9]:
text = "Your grammer is a reflection of you image. Good or bad: you have made an an impression. And like all impressions, you in total control."

response = ai21.GEC.execute(
    text=text,
    sm_endpoint=endpoint_name
)

print(response.corrections)

[{'suggestion': 'grammar', 'startIndex': 5, 'endIndex': 12, 'originalText': 'grammer', 'correctionType': 'Spelling'}, {'suggestion': 'your', 'startIndex': 32, 'endIndex': 35, 'originalText': 'you', 'correctionType': 'Grammar'}, {'suggestion': 'bad,', 'startIndex': 51, 'endIndex': 55, 'originalText': 'bad:', 'correctionType': 'Punctuation'}, {'suggestion': 'an impression.', 'startIndex': 70, 'endIndex': 87, 'originalText': 'an an impression.', 'correctionType': 'Word Repetition'}, {'suggestion': 'you are in', 'startIndex': 114, 'endIndex': 120, 'originalText': 'you in', 'correctionType': 'Missing Word'}]
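
Since every correction carries a `correctionType`, the list can be summarized by mistake type. The following is a small sketch that uses a hard-coded copy of the response above (reduced to the field it needs) so it runs without an endpoint. `count_by_type` is a hypothetical helper name.

```python
from collections import Counter

# Correction types copied from the response printed above.
corrections = [
    {"correctionType": "Spelling"},
    {"correctionType": "Grammar"},
    {"correctionType": "Punctuation"},
    {"correctionType": "Word Repetition"},
    {"correctionType": "Missing Word"},
]


def count_by_type(corrections):
    """Count how many corrections were suggested for each mistake type."""
    return Counter(c["correctionType"] for c in corrections)


for correction_type, count in count_by_type(corrections).items():
    print(f"{correction_type}: {count}")
```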


Let's print the suggested corrections in a more readable format:

In [10]:
print("Original text (with errors):")
print(text)
print("============================")
print("Suggested corrections:")
print("\n".join([
    "- {correctionType} error detected at position ({startIndex},{endIndex}): {originalText} --> {suggestion}".format(
        correctionType=x['correctionType'],
        startIndex=x['startIndex'],
        endIndex=x['endIndex'],
        originalText=x['originalText'],
        suggestion=x['suggestion']
    ) for x in response.corrections]))

Original text (with errors):
Your grammer is a reflection of you image. Good or bad: you have made an an impression. And like all impressions, you in total control.
============================
Suggested corrections:
- Spelling error detected at position (5,12): grammer --> grammar
- Grammar error detected at position (32,35): you --> your
- Punctuation error detected at position (51,55): bad: --> bad,
- Word Repetition error detected at position (70,87): an an impression. --> an impression.
- Missing Word error detected at position (114,120): you in --> you are in
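
The suggestions can also be applied to produce a corrected version of the text. The following is a minimal sketch, assuming that `endIndex` is exclusive (which matches the sample output above) and that the corrected spans do not overlap. `apply_corrections` is a hypothetical helper, not part of the ai21 SDK, and the corrections are hard-coded from the response above so the example runs without an endpoint.

```python
text = (
    "Your grammer is a reflection of you image. Good or bad: you have made "
    "an an impression. And like all impressions, you in total control."
)

# Corrections copied from the response printed above (only the fields used here).
corrections = [
    {"suggestion": "grammar", "startIndex": 5, "endIndex": 12},
    {"suggestion": "your", "startIndex": 32, "endIndex": 35},
    {"suggestion": "bad,", "startIndex": 51, "endIndex": 55},
    {"suggestion": "an impression.", "startIndex": 70, "endIndex": 87},
    {"suggestion": "you are in", "startIndex": 114, "endIndex": 120},
]


def apply_corrections(text, corrections):
    """Return the text with every suggested correction applied."""
    corrected = text
    # Work from the end of the text backwards so that replacing one span
    # does not shift the indices of the spans still to be processed.
    for c in sorted(corrections, key=lambda c: c["startIndex"], reverse=True):
        corrected = (
            corrected[: c["startIndex"]] + c["suggestion"] + corrected[c["endIndex"]:]
        )
    return corrected


print(apply_corrections(text, corrections))
```

In the notebook, `response.corrections` could be passed in place of the hard-coded list.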


### Interested in learning more?
Take a look at our [guide](https://docs.ai21.com/docs/gec-api) to understand all the capabilities of the AI21 Grammatical Error Correction (GEC) model.

## 3. Clean-up

### A. Delete the endpoint

Now that you have successfully performed real-time inference, you no longer need the endpoint. Delete it to avoid further charges.

In [11]:
model.sagemaker_session.delete_endpoint(endpoint_name)
model.sagemaker_session.delete_endpoint_config(endpoint_name)

### B. Delete the model

In [12]:
model.delete_model()