# Summarize Scientific Documents with Amazon Comprehend and HuggingFace

Researchers must stay up-to-date on their fields of interest. However, it's difficult to keep track of the large number of journals, whitepapers, and research pre-prints generated in many areas. In response, many research groups have turned to AI/ML tools to summarize and classify new documents.

In this workshop, we'll use several AWS AI/ML services to process scientific documents from the [NIH NCBI PMC Article Dataset](https://registry.opendata.aws/ncbi-pmc/) on the Registry of Open Data. This is a free full-text archive of biomedical and life sciences journal articles at the U.S. National Institutes of Health's National Library of Medicine.

# 1. Import Libraries and Create Clients

In [None]:
import boto3
import sagemaker
import os
import json
import pprint
import pandas as pd
from random import sample
import re

# A single boto3 session backs both the raw S3 client and the SageMaker session.
boto_session = boto3.Session()
sm_session = sagemaker.Session(boto_session=boto_session)
s3 = boto_session.client("s3")

# Working location inside the account's SageMaker default bucket.
s3_bucket, s3_prefix = sm_session.default_bucket(), "sci-docs/data"
print(f"S3 path is {s3_bucket}/{s3_prefix}")

# 2. Download Documents from the NIH NCBI PMC Article Dataset

Copy 25 randomly chosen articles from the PubMed Central (PMC) open-access dataset (https://registry.opendata.aws/ncbi-pmc/) into the SageMaker default bucket for this account.

In [None]:
pmc_bucket = "pmc-oa-opendata"
pmc_prefix = "oa_comm/txt/all/"
local_raw_data_dir = "data/raw/"

# Choose 25 article files at random from a single listing of the public bucket.
# NOTE(review): one list_objects_v2 call returns only the first page of keys,
# so the sample is not drawn from the whole prefix.
listing = s3.list_objects_v2(Bucket=pmc_bucket, Prefix=pmc_prefix)
picked = sample(listing["Contents"], 25)
article_names = [os.path.basename(obj["Key"]) for obj in picked]

# Pull each chosen article into the local raw-data directory.
for article in article_names:
    print(article)
    sm_session.download_data(
        local_raw_data_dir,
        bucket=pmc_bucket,
        key_prefix=f"{pmc_prefix}{article}",
    )

# Once all files have been downloaded, upload them all to the S3 bucket for your project.
# NOTE(review): the whole directory is uploaded, including any files left over
# from earlier runs of this notebook.
sm_session.upload_data(
    local_raw_data_dir,
    bucket=s3_bucket,
    key_prefix=f"{s3_prefix}/raw",
)

Look at a randomly chosen example:

In [None]:
art = sample(article_names,1)[0]
print(art)
!head data/raw/{art}

# 3. Summarize the Documents Using Amazon Comprehend Topic Modelling

Submit an Amazon Comprehend topic modelling job

In [None]:
comprehend = boto_session.client(service_name="comprehend")

# Comprehend treats every file under .../raw as one document and writes its
# results under .../output.
input_s3_url = sagemaker.s3.s3_path_join("s3://", s3_bucket, s3_prefix, "raw")
input_doc_format = "ONE_DOC_PER_FILE"
output_s3_url = sagemaker.s3.s3_path_join("s3://", s3_bucket, s3_prefix, "output")
# NOTE(review): Comprehend assumes this role to access the bucket, so the role
# must trust comprehend.amazonaws.com -- confirm in IAM.
data_access_role_arn = sagemaker.session.get_execution_role()
number_of_topics = 25

input_data_config = {"S3Uri": input_s3_url, "InputFormat": input_doc_format}
output_data_config = {"S3Uri": output_s3_url}

# The job runs in the background; keep its JobId so a later cell can poll it.
start_topics_detection_job_result = comprehend.start_topics_detection_job(
    NumberOfTopics=number_of_topics,
    InputDataConfig=input_data_config,
    OutputDataConfig=output_data_config,
    DataAccessRoleArn=data_access_role_arn,
)

job_id = start_topics_detection_job_result["JobId"]
print(f"Job {job_id} submitted")

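Once the job has finished, download and unpack the results. The cell below only fetches results when the job status is `COMPLETED`, so re-run it until then.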

In [None]:
# Check the topic-modelling job. Results are only fetched once it has completed,
# so re-run this cell until the printed status is COMPLETED.
describe_topics_detection_job_result = comprehend.describe_topics_detection_job(JobId=job_id)["TopicsDetectionJobProperties"]
print(f"Job {job_id} status is {describe_topics_detection_job_result['JobStatus']}")

if describe_topics_detection_job_result["JobStatus"] == "COMPLETED":
    import tarfile

    # OutputDataConfig.S3Uri points at the job's archive; it is saved as data/output.tar.gz.
    output_url = sagemaker.s3.parse_s3_url(describe_topics_detection_job_result["OutputDataConfig"]["S3Uri"])
    sm_session.download_data(
        "data",
        bucket=output_url[0],
        key_prefix=output_url[1],
    )

    # Unpack in-process instead of shelling out; the archive is expected to hold
    # topic-terms.csv and doc-topics.csv. extractall creates output_dir as needed.
    archive_path = os.path.join("data", "output.tar.gz")
    output_dir = os.path.join("data", "output")
    with tarfile.open(archive_path, "r:gz") as archive:
        archive.extractall(output_dir)
    os.remove(archive_path)

    # One row per topic: its terms joined in descending weight order.
    topics = (
        pd.read_csv(os.path.join(output_dir, "topic-terms.csv"))
        .sort_values(["topic", "weight"], ascending=[True, False])
        .groupby(["topic"])["term"]
        .agg(", ".join)
    )
    # Per-document topic assignments, strongest topic first within each document.
    docs = pd.read_csv(os.path.join(output_dir, "doc-topics.csv")).sort_values(
        ["docname", "proportion"], ascending=[True, False]
    )
    results = pd.merge(docs, topics, how="left", on="topic")
    display(results)

Let's look at some specific examples:

In [None]:
if describe_topics_detection_job_result["JobStatus"] == "COMPLETED":
    from itertools import islice

    input_url = sagemaker.s3.parse_s3_url(describe_topics_detection_job_result["InputDataConfig"]["S3Uri"])
    # Pick one random (document, topic) row. The name `sample_row` avoids
    # overwriting the `random.sample` function that later cells still call.
    sample_row = results.sample()
    docname, topic, score, terms = sample_row.iloc[0, :]

    print(f"Document name is {docname}")
    print(f"Identified terms are {terms}")

    # Fetch the source document from the job's input location.
    sm_session.download_data(
        "data",
        bucket=input_url[0],
        key_prefix=os.path.join(input_url[1], docname),
    )
    # Show its first 25 lines, without interpolating docname into a shell command.
    with open(os.path.join("data", docname), "r", encoding="utf-8", errors="replace") as f:
        for line in islice(f, 25):
            print(line, end="")

# 4. Generate TLDR Summaries Using a Pre-Trained NLP Model from HuggingFace

In [None]:
import sagemaker
from sagemaker.huggingface import HuggingFaceModel
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
from sagemaker.async_inference.waiter_config import WaiterConfig

# IAM role the model will run under (this notebook's execution role).
role = sagemaker.get_execution_role()

# Hub Model configuration. https://huggingface.co/models
# Handed to the serving container as environment variables via env=hub.
hub = {
    "HF_MODEL_ID": "alk/pegasus-scitldr",
    "HF_TASK": "text2text-generation",
}

# create Hugging Face Model Class with pinned framework versions
huggingface_model = HuggingFaceModel(
    env=hub,
    role=role,
    transformers_version="4.17.0",
    pytorch_version="1.10.2",
    py_version="py38",
)

In [None]:
# deploy model to SageMaker Inference
async_config = AsyncInferenceConfig(
    output_path=f"s3://{s3_bucket}/{s3_prefix}/tldr_output",
    max_concurrent_invocations_per_instance=4
)

predictor = huggingface_model.deploy(
    async_inference_config=async_config,
	initial_instance_count=1, # number of instances
	instance_type='ml.m5.4xlarge', # ec2 instance type
    wait=True
)

Extract the background section of a random article and wrap it in a JSON payload for the endpoint.

In [None]:
# Find an article with well-defined background information
result = None
while result is None:
    art = sample(article_names,1)[0]
    print(art)

    with open(f"data/raw/{art}", "r", encoding="utf-8", errors="replace") as f:
        text = f.read().replace("\n", " ").replace("\t", " ")
        result = re.search("Background (.{,1000})", text)

dict = {"inputs": result.group(1)} # Search for background infomation
print(dict)

In [None]:
# Submit the payload to the asynchronous endpoint.
async_response = predictor.predict_async(data=dict)

# Poll for the output every 15 s, giving up after 24 attempts (about 6 minutes).
waiter_config = WaiterConfig(max_attempts=24, delay=15)
result = async_response.get_result(waiter_config)
pprint.pprint(result)

# 5. Clean Up

In [None]:
# Remove the SageMaker model as well as the endpoint; otherwise the model
# resource is left behind. The model is deleted first because
# Predictor.delete_model presumably resolves model names through the endpoint
# config that delete_endpoint removes -- confirm against the SDK docs.
predictor.delete_model()
predictor.delete_endpoint()

# Delete the S3 objects this notebook created (raw/, output/, tldr_output/ all
# live under s3_prefix). The trailing "/" stops the prefix from also matching
# unrelated keys such as "sci-docs-backup/...".
bucket = boto_session.resource("s3").Bucket(s3_bucket)
bucket.objects.filter(Prefix=f"{s3_prefix}/").delete()

# Remove the local working directory without spawning a shell.
import shutil

shutil.rmtree("data", ignore_errors=True)