# Clinical Trials Gen AI Workshop Part 3

**Note:** This notebook was designed to be run in a JupyterLab space in SageMaker Studio running *SageMaker Distribution 1.4* on an *ml.t3.medium* instance with 5GB of storage, although other configurations may be supported.

## Introduction

This notebook demonstrates how to apply techniques learned in previous workshop sections to perform processing at scale. It requires an AWS HealthLake Datastore, created by following the [guide](https://docs.aws.amazon.com/healthlake/latest/devguide/getting-started.html) with **SYNTHEA** as the selected preloaded data type. If you do not have the correct role permissions for Amazon Bedrock Batch Inference or an AWS HealthLake Datastore, you can still view the final result HTML page by running all the initial cells through 1.2, and then continuing execution from 1.12 with USE_SAMPLE_BATCH_INFERENCE_OUTPUT = True.

## Setup

First, install and import the prerequisite libraries.

In [None]:
!pip install 'boto3>=1.35.1'
!pip install 'awswrangler>=3.9.1'
!pip show boto3
!pip show awswrangler

In [None]:
import requests
from xml.etree import ElementTree
import boto3
import json
import awswrangler as wr
import pandas as pd
import random
import string
import concurrent.futures
import time
import sagemaker
from datetime import date
import html
from pathlib import Path
print("All pre-requisite libraries successfully imported")

Setting VERBOSE_LOGGING to ***True*** generates much more cell output, but it makes full prompts visible and can be a useful troubleshooting tool. We use ***False*** by default to keep the output easier to read.

In [None]:
VERBOSE_LOGGING = False
print(f"Current verbose logging setting is {VERBOSE_LOGGING}")

## 1. Clinical Trial Information Processing

### 1.1 Helper function to convert XML to JSON

We define a helper function below called parseXmlToJson.

It converts an XML like


    <eligibility>
        <criteria>
            <textblock>
                TEXT ABOUT ELIGIBILITY CRITERIA
            </textblock>
        </criteria>
    </eligibility>


to

    {
    	"eligibility": {
    		"criteria": {
    			"textblock": "TEXT ABOUT ELIGIBILITY CRITERIA"
    		}
    	}
    }

In [None]:
def parseXmlToJson(xml):
  response = {}

  for child in list(xml):
    if len(list(child)) > 0:
      response[child.tag] = parseXmlToJson(child)
    else:
      response[child.tag] = child.text or ''

  return response
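
As a quick check, the helper can be exercised on the sample XML shown above. The sketch below repeats the function under a hypothetical name, `parse_xml_to_json`, so that it runs on its own, and wraps the sample in an assumed `<clinical_study>` root element. It illustrates two properties of this approach: the root tag itself does not appear in the output, and repeated sibling tags overwrite one another because they share a dictionary key.

```python
from xml.etree import ElementTree


def parse_xml_to_json(xml):
    """Standalone copy of parseXmlToJson so this cell runs by itself."""
    response = {}
    for child in list(xml):
        if len(list(child)) > 0:
            response[child.tag] = parse_xml_to_json(child)
        else:
            response[child.tag] = child.text or ''
    return response


# Assumed wrapper element; the function only walks the children of the root.
sample_xml = """<clinical_study>
    <eligibility>
        <criteria>
            <textblock>TEXT ABOUT ELIGIBILITY CRITERIA</textblock>
        </criteria>
    </eligibility>
</clinical_study>"""

parsed = parse_xml_to_json(ElementTree.fromstring(sample_xml))
print(parsed)

# Repeated sibling tags share one key, so only the last value is kept.
repeated = parse_xml_to_json(
    ElementTree.fromstring("<root><tag>a</tag><tag>b</tag></root>")
)
print(repeated)
```

For the eligibility text block used in this notebook the overwrite behavior does not matter, but it would for elements that legitimately repeat.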

### 1.2 Download the clinical trial information

We identify a clinical trial of interest. 

Here we've chosen [NCT02697071](https://classic.clinicaltrials.gov/ct2/show/NCT02697071) titled ***Ketamine for Acute Migraine in the Emergency Department***

In [None]:
# store the clinicaltrials.gov id in the variable
clinical_trial_id = 'NCT02697071'

# We download the clinical trial information
# from clinicaltrials.gov in xml format.
query_string = f"https://clinicaltrials.gov/ct2/show/{clinical_trial_id}?displayxml=true"
response = requests.get(query_string)

# Parse the XML to JSON object using the helper function
# we defined above
events = ElementTree.fromstring(response.content)
jsonObj = parseXmlToJson(events)

# The clinical trials website/information download includes 
# a lot of information about the clinical trial, we are only 
# interested in the criteria for the study

criteria = jsonObj["eligibility"]["criteria"]["textblock"]
print(criteria)

### 1.3 Helper function for invoking an LLM using Amazon Bedrock

Here we define a helper function that uses Amazon Bedrock, which provides an API to interact with Large Language Models (LLMs). We have chosen Claude 3 Sonnet. We also print an estimate of the number of input tokens, which measures the length of the input prompt to the LLM. This estimate is calculated by dividing the number of characters in the input prompt by 6. Claude 3 Sonnet supports a 200,000 input token context limit, which roughly translates to 500 pages, or about the size of a book.
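
The estimate described above is simple arithmetic, and a minimal sketch of it follows. The function names `estimate_input_tokens` and `fits_in_context` are hypothetical, and the characters-per-token ratio is only the rule of thumb used in this notebook. Actual token counts depend on the model's tokenizer and on the text, so treat the result as a rough guide.

```python
CONTEXT_LIMIT_TOKENS = 200_000  # input limit quoted above for Claude 3 Sonnet


def estimate_input_tokens(prompt, chars_per_token=6):
    """Rough estimate: character count divided by an assumed chars-per-token ratio."""
    return int(len(prompt) / chars_per_token)


def fits_in_context(prompt, limit=CONTEXT_LIMIT_TOKENS):
    """Return True when the estimated token count is within the limit."""
    return estimate_input_tokens(prompt) <= limit


example_prompt = "x" * 12_000  # a 12,000-character placeholder prompt
estimated = estimate_input_tokens(example_prompt)
print(estimated, fits_in_context(example_prompt))
```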

Claude uses a system role for enhanced accuracy and improved focus. Learn more details at https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/system-prompts.

We set the inference parameter temperature to 0 which reduces randomness (creativity) in responses so the model is more likely to provide the same output when invoked with the same input. You can find additional information about inference parameters at https://docs.aws.amazon.com/bedrock/latest/userguide/inference-parameters.html.

As part of this workshop, you must add model access for the Claude 3 Sonnet model in the current region via the Bedrock console if you have not done so previously. Follow the steps at https://docs.aws.amazon.com/bedrock/latest/userguide/model-access-modify.html.

In [None]:
BEDROCK_CLIENT = boto3.client(service_name='bedrock-runtime')
CLAUDE_3_SONNET_MODEL_ID = 'anthropic.claude-3-sonnet-20240229-v1:0'

def invoke_claude_v3(input_prompt, system_message="You are acting as a researcher evaluating if a patient profile is suitable for a clinical trial."):
    modelId = CLAUDE_3_SONNET_MODEL_ID
    
    body=json.dumps(
        {
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": 2000,
            "temperature": 0, # see https://docs.aws.amazon.com/bedrock/latest/userguide/inference-parameters.html for more information
            "system": system_message,
            "messages": [{"role": "user", "content": input_prompt}]
        }  
    )  

    truncate_index = 1500
    if VERBOSE_LOGGING or len(body) <= truncate_index:
        print(f"Invoking {modelId} with input: {body}")
    else:
        # Long requests: show only the first and last truncate_index characters
        print(f"Invoking {modelId} with truncated input (set VERBOSE_LOGGING to True for full input): {body[:truncate_index]}\n\n...\n\n{body[-truncate_index:]}")

    input_token_estimate = int(len(input_prompt)/6)
    print(f"\n\n Estimated input tokens: {input_token_estimate}")

    print("\n\n Response below (Please wait for the response output before proceeding to the next cell as the request may take time to process):\n")

    response = BEDROCK_CLIENT.invoke_model(body=body, modelId=modelId)
    response_body = json.loads(response.get('body').read())

    response = response_body['content'][0]['text']

    print(response)
   
    return response

# Validate bedrock model access
try:
    invoke_claude_v3("ping, respond with \"pong\"", "")
    print("\n")
    print("Bedrock model access confirmed")
except Exception as e:
    print(e)
    print("\n")
    print("Bedrock model access failed. Please ensure you have the correct permissions and the model is available in your region. Follow the documentation at https://docs.aws.amazon.com/bedrock/latest/userguide/model-access-modify.html to add model access.")

### 1.4 Parsing Clinical Trial Eligibility

The eligibility criteria that we download from the clinicaltrials.gov website come as a text block. This can contain new line characters, bullet points, etc. We first want to transform this into a structured format, like a JSON object, so that we can iterate through each study criterion using code. We could potentially write a rules-based engine to perform this work, but there are hundreds of variations in how researchers describe study criteria. Instead, we will leverage an LLM Assistant to perform this transformation.

In [None]:
input_prompt = f"""

We are looking to process the inclusion and exclusion criteria for a scientific study as a json object. 
Inside <study></study> XML tags is the inclusion and exclusion criteria for a scientific study. 
Put your answers to the user inside <answer></answer> XML tags.

<study>{criteria}</study>
"""

input_prompt += """

Here is an example:
<example>
{

"Inclusion Criteria":[
"Inclusion Criteria 1",
"Inclusion Criteria 2"],
"Exclusion Criteria":[
"Exclusion Criteria 1",
"Exclusion Criteria 2"]

}
</example>
"""

llm_study_response = invoke_claude_v3(input_prompt)

### 1.5 Parsing Clinical Trial Eligibility Response from LLM to Python object

We will now parse the output of the LLM assistant so that we can interact with the structured criteria as a Python object.

In [None]:
raw_claude_output = llm_study_response.replace("\n","")
answer_open_tag = "<answer>"
xml_start_loc = raw_claude_output.find(answer_open_tag)
xml_end_loc = raw_claude_output.find("</answer>")

# find() returns -1 when a tag is missing, so check before adding the tag length
if xml_start_loc == -1 or xml_end_loc == -1:
    raise Exception("No <answer> xml tag detected in claude output")

study_criteria_json = json.loads(raw_claude_output[xml_start_loc + len(answer_open_tag):xml_end_loc])

study_criteria = json.dumps(study_criteria_json, indent=4)
print(study_criteria)
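
The same extraction can be written with a regular expression, which avoids manual index arithmetic. The sketch below is one possible alternative rather than the notebook's method: `extract_answer_json` is a hypothetical helper name and the sample model output is made up. Newlines do not need to be stripped first because `json.loads` tolerates whitespace between JSON tokens.

```python
import json
import re


def extract_answer_json(llm_text):
    """Return the JSON object found between <answer> and </answer> tags."""
    match = re.search(r"<answer>(.*?)</answer>", llm_text, flags=re.DOTALL)
    if match is None:
        raise ValueError("No <answer> XML tag detected in model output")
    return json.loads(match.group(1))


# Made-up model output, used only to exercise the helper.
sample_output = """Here is the structured criteria.
<answer>
{"Inclusion Criteria": ["Age 18-65"], "Exclusion Criteria": ["Pregnancy"]}
</answer>"""

parsed_answer = extract_answer_json(sample_output)
print(parsed_answer["Inclusion Criteria"])
```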

## 2. Batch Clinical Trial Eligibility Processing

When processing large amounts of data using Amazon Bedrock, we will use the [batch inference](https://docs.aws.amazon.com/bedrock/latest/userguide/batch-inference.html) feature, which offers 50% off the on-demand inference price. Batch inference runs multiple inference requests asynchronously and improves the performance of model inference on large datasets. Completion time depends on factors such as the size of the job, but you can expect a typical job to complete within 24 hours. Learn more by reviewing the [launch announcement](https://aws.amazon.com/about-aws/whats-new/2024/08/amazon-bedrock-fms-batch-inference-50-price/) of this feature.



### 2.1 Configuration Setup

We initialize all of our AWS SDK clients here. We also retrieve the S3 bucket that is created by default in this SageMaker domain.

In [None]:
BEDROCK_CLIENT = boto3.client(service_name='bedrock')
S3_CLIENT = boto3.client('s3')
SAGEMAKER_SESSION = sagemaker.Session()
SAGEMAKER_S3_BUCKET = SAGEMAKER_SESSION.default_bucket()
print(f"Sagemaker S3 Bucket is {SAGEMAKER_S3_BUCKET}")
STS_CLIENT = boto3.client("sts")
ACCOUNT_ID = STS_CLIENT.get_caller_identity()["Account"]
print(f"Current ACCOUNT_ID is {ACCOUNT_ID}")
part_3_resource_folder = "[Resources] Clinical Trials Gen AI Part 3"
print(f"part_3_resource_folder is {part_3_resource_folder}")

### 2.2 Define Helper Functions for FHIR Retrieval

We define helper functions that extract key data fields from patient, encounter, observation, condition, procedure, and medication FHIR resources for each patient's record.

In [None]:
def parseFHIRObj(text, search):
    if VERBOSE_LOGGING:
        print(f'text: {text}')
        print(f'search: {search}')
    if not isinstance(text, str) or len(text) < len(search) or text == 'Not Recorded':
        if VERBOSE_LOGGING:
            print('Text not found')
        return "Not Recorded"
    
    index = text.find(f'{search}=')
    if index == -1:
        # The requested key is not present in this value
        return "Not Recorded"
    found_text = text[index + len(search) + 1 : text.find(',', index)]

    if VERBOSE_LOGGING:
        print(f'found_text: {found_text}')
    
    return found_text

def compact_patient_resource(df):  
    compact_patient = {
        "gender" : df["gender"].iloc[0],
        "birthdate" : df["birthdate"].iloc[0]
    }
    return compact_patient

def compact_encounters_resource(df):
    scoped_df = df[["type","period","reasoncode"]].fillna('Not Recorded')

    new_encounter_list = []
    
    for index, row in scoped_df.iterrows():
        encounter_type = parseFHIRObj(row['type'], 'text')

        encounter_start = parseFHIRObj(row['period'], 'start')

        encounter_end = parseFHIRObj(row['period'], 'end')

        encounter_reasoncode = parseFHIRObj(row['reasoncode'], 'display')

        encounter = {
            'encounter_type': encounter_type,
            'encounter_start': encounter_start,
            'encounter_end': encounter_end,
            'encounter_reasoncode': encounter_reasoncode
        }

        new_encounter_list.append(encounter)

    if len(new_encounter_list) == 0:
        return "None"
    else:
        return new_encounter_list
    
def compact_observations_resource(df):
    scoped_df = df[["category","code","effectivedatetime","valuequantity", "component"]].fillna('Not Recorded')
    
    new_observations_list = []
    
    for index, row in scoped_df.iterrows():
        observation_category = parseFHIRObj(row['category'], 'code')

        observation_code = parseFHIRObj(row['code'], 'text')

        observation_time = row['effectivedatetime']

        if row['component'] == 'Not Recorded':
            observation_value = parseFHIRObj(row['valuequantity'], 'value')
        else:
            observation_value = row['component']

        observation_unit = parseFHIRObj(row['valuequantity'], 'unit')
    

        observation = {
            'observation_category': observation_category,
            'observation_code': observation_code,
            'observation_time': observation_time,
            'observation_value': observation_value,
            'observation_unit': observation_unit,
        }

        new_observations_list.append(observation)

    if len(new_observations_list) == 0:
        return "None"
    else:
        return new_observations_list

def compact_conditions_resource(df):
    scoped_df = df[["code","onsetdatetime","abatementdatetime","recordeddate"]].fillna('Not Recorded')
    
    new_conditions_list = []
    
    for index, row in scoped_df.iterrows():
        condition_code = parseFHIRObj(row['code'], 'text')

        condition_onsetdatetime = row['onsetdatetime']

        condition_abatementdatetime = row['abatementdatetime']

        condition_recordeddate = row['recordeddate']

        condition = {
            'condition_code': condition_code,
            'condition_onsetdatetime': condition_onsetdatetime,
            'condition_abatementdatetime': condition_abatementdatetime,
            'condition_recordeddate': condition_recordeddate,
        }

        new_conditions_list.append(condition)

    if len(new_conditions_list) == 0:
        return "None"
    else:
        return new_conditions_list

def compact_procedures_resource(df):
    scoped_df = df[["code","performedperiod","reasonreference"]].fillna('Not Recorded')
    
    new_procedures_list = []
    
    for index, row in scoped_df.iterrows():
        procedure_code = parseFHIRObj(row['code'], 'text')

        procedure_start = parseFHIRObj(row['performedperiod'], 'start')
        
        procedure_end = parseFHIRObj(row['performedperiod'], 'end')

        procedure_reason = parseFHIRObj(row['reasonreference'], 'display')

        procedure = {
            'procedure_code': procedure_code,
            'procedure_start': procedure_start,
            'procedure_end': procedure_end,
            'procedure_reason': procedure_reason,
        }

        new_procedures_list.append(procedure)
    
    if len(new_procedures_list) == 0:
        return "None"
    else:
        return new_procedures_list

def compact_medications_resource(df):
    scoped_df = df[["medicationcodeableconcept","authoredon"]].fillna('Not Recorded')
    
    new_medications_list = []
    
    for index, row in scoped_df.iterrows():
        medication_name = parseFHIRObj(row['medicationcodeableconcept'], 'text')

        medication_prescribed_time = row['authoredon']

        medication = {
            'medication_name': medication_name,
            'medication_prescribed_time': medication_prescribed_time,
        }

        new_medications_list.append(medication)

    if len(new_medications_list) == 0:
        return "None"
    else:
        return new_medications_list
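
To make the string parsing in `parseFHIRObj` concrete, the sketch below runs a simplified stand-in on a made-up `period` value. The name `parse_struct_field` is hypothetical, and the input format (`{key=value, key=value}`) is an assumption about how nested FHIR fields appear once the query result is read as a string. The stand-in also returns "Not Recorded" when the key is absent.

```python
def parse_struct_field(text, search):
    """Simplified stand-in for parseFHIRObj: return the value after 'search='."""
    if not isinstance(text, str) or text == 'Not Recorded':
        return "Not Recorded"
    index = text.find(f'{search}=')
    if index == -1:
        return "Not Recorded"
    # When no later comma exists, find() returns -1 and the slice drops the
    # final character, which is the closing brace in this example.
    return text[index + len(search) + 1 : text.find(',', index)]


# Assumed shape of a 'period' value rendered as a string.
period = "{start=2011-03-04T10:15:00Z, end=2011-03-04T10:45:00Z}"

start_value = parse_struct_field(period, 'start')
end_value = parse_struct_field(period, 'end')
missing_value = parse_struct_field(period, 'display')
print(start_value, end_value, missing_value)
```

Because the value is cut at the next comma, any value that itself contains a comma would be truncated by this approach.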

### 2.3 Define Helper Functions for Batch Processing

Here, we define a helper function to create a random ID that can be used as a record ID for Amazon Bedrock batch processing.

In [None]:
def generate_random_record_id():
    return ''.join(random.choices(string.ascii_uppercase + string.digits, k=12))

### 2.4 Find AWS HealthLake Athena Database

This Jupyter notebook assumes you have created an AWS HealthLake Datastore by following the [documentation](https://docs.aws.amazon.com/healthlake/latest/devguide/create-data-store.html) with the "Preload sample data" checkbox selected. This automatically sets up Athena access if the IAM role you're using contains the "AWSLakeFormationDataAdmin" policy. If your IAM role does not contain the necessary permissions, ask your AWS IAM administrator to set up the Athena integration as described in the [documentation](https://docs.aws.amazon.com/healthlake/latest/devguide/search-healthlake.html).

Here, we list all of the Athena databases that end with "_healthlake_view". By default, we select the first database in the list, but you can change this selection by editing the value of DATABASE_SELECTION_INDEX.

In [None]:
def list_filtered_athena_databases(catalog_name):
    # Initialize the Athena client
    athena_client = boto3.client('athena')

    # Create a paginator for the list_databases operation
    paginator = athena_client.get_paginator('list_databases')

    # Initialize an empty list to store database names
    database_names = []

    # Paginate through the results
    for page in paginator.paginate(CatalogName=catalog_name):
        for database in page['DatabaseList']:
            db_name = database['Name']
            # Filter databases that end with "_healthlake_view"
            if db_name.endswith("_healthlake_view"):
                database_names.append(db_name)

    return database_names

ATHENA_CATALOG_NAME = 'AwsDataCatalog'
filtered_databases = list_filtered_athena_databases(ATHENA_CATALOG_NAME)
print(f"Found {len(filtered_databases)} matching databases:")
print(filtered_databases)
DATABASE_SELECTION_INDEX = 0
print(f"Selected DATABASE_SELECTION_INDEX: {DATABASE_SELECTION_INDEX}")
HEALTHLAKE_ATHENA_DATABASE = filtered_databases[DATABASE_SELECTION_INDEX]
print(f"Using HEALTHLAKE_ATHENA_DATABASE: {HEALTHLAKE_ATHENA_DATABASE}")
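
If no matching database exists, indexing the list in the cell above raises a bare `IndexError`. A small guard can surface a clearer message. The sketch below is a hypothetical helper, `select_database`, shown with made-up database names; it is not part of the workshop code.

```python
def select_database(database_names, selection_index=0):
    """Return the chosen database name, or raise a descriptive error."""
    if not database_names:
        raise RuntimeError(
            "No Athena database ending in '_healthlake_view' was found. "
            "Check that the HealthLake data store exists and Athena access is set up."
        )
    if not 0 <= selection_index < len(database_names):
        raise IndexError(
            f"selection_index {selection_index} is out of range for "
            f"{len(database_names)} database(s)"
        )
    return database_names[selection_index]


# Made-up database names, for illustration only.
example_names = ["datastore_a_healthlake_view", "datastore_b_healthlake_view"]
chosen = select_database(example_names, 1)
print(chosen)
```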

### 2.5 Define Helper Function for Patient Information Retrieval from AWS HealthLake

This helper function uses the [AWS SDK for pandas (awswrangler)](https://aws-sdk-pandas.readthedocs.io/en/stable/) to retrieve patient data from AWS HealthLake using SQL queries.

In [None]:
def get_llm_input_prompt_patient(input_prompt, patient_id, study_criteria):
    if VERBOSE_LOGGING:
        print(f"Starting Athena queries to retrieve patient data from HealthLake for {patient_id}:")

    patient_query = f"select * from patient where id = '{patient_id}'"
    df = wr.athena.read_sql_query(patient_query, workgroup="primary", database=HEALTHLAKE_ATHENA_DATABASE, ctas_approach=False, athena_cache_settings={"max_cache_seconds": 2628000, "max_cache_query_inspections": 500})
    sample_patient_resource = compact_patient_resource(df)

    encounters_query = f"select * from encounter where subject.reference = 'Patient/{patient_id}'"
    df = wr.athena.read_sql_query(encounters_query, workgroup="primary", database=HEALTHLAKE_ATHENA_DATABASE, ctas_approach=False, athena_cache_settings={"max_cache_seconds": 2628000, "max_cache_query_inspections": 500})
    sample_patient_encounters = compact_encounters_resource(df)

    observations_query = f"select * from observation where subject.reference = 'Patient/{patient_id}'"
    df = wr.athena.read_sql_query(observations_query, workgroup="primary", database=HEALTHLAKE_ATHENA_DATABASE, ctas_approach=False, athena_cache_settings={"max_cache_seconds": 2628000, "max_cache_query_inspections": 500})
    sample_patient_observations = compact_observations_resource(df)

    conditions_query = f"select * from condition where subject.reference = 'Patient/{patient_id}'"
    df = wr.athena.read_sql_query(conditions_query, workgroup="primary", database=HEALTHLAKE_ATHENA_DATABASE, ctas_approach=False, athena_cache_settings={"max_cache_seconds": 2628000, "max_cache_query_inspections": 500})
    sample_patient_conditions = compact_conditions_resource(df)

    procedure_query = f"select * from procedure where subject.reference = 'Patient/{patient_id}'"
    df = wr.athena.read_sql_query(procedure_query, workgroup="primary", database=HEALTHLAKE_ATHENA_DATABASE, ctas_approach=False, athena_cache_settings={"max_cache_seconds": 2628000, "max_cache_query_inspections": 500})
    sample_patient_procedures = compact_procedures_resource(df)

    medication_request_query = f"select * from medicationrequest where subject.reference = 'Patient/{patient_id}'"
    df = wr.athena.read_sql_query(medication_request_query, workgroup="primary", database=HEALTHLAKE_ATHENA_DATABASE, ctas_approach=False, athena_cache_settings={"max_cache_seconds": 2628000, "max_cache_query_inspections": 500})
    sample_patient_medication_requests = compact_medications_resource(df)
    
    input_prompt = input_prompt.replace("<patient>REPLACE</patient>",f"<patient>{sample_patient_resource}</patient>")
    input_prompt = input_prompt.replace("<encounter>REPLACE</encounter>",f"<encounter>{sample_patient_encounters}</encounter>")
    input_prompt = input_prompt.replace("<condition>REPLACE</condition>",f"<condition>{sample_patient_conditions}</condition>")
    input_prompt = input_prompt.replace("<observation>REPLACE</observation>",f"<observation>{sample_patient_observations}</observation>")
    input_prompt = input_prompt.replace("<procedure>REPLACE</procedure>",f"<procedure>{sample_patient_procedures}</procedure>")
    input_prompt = input_prompt.replace("<medicationRequest>REPLACE</medicationRequest>",f"<medicationRequest>{sample_patient_medication_requests}</medicationRequest>")
    input_prompt = input_prompt.replace("<study>REPLACE</study>",f"<study>{study_criteria}</study>")

    bedrock_body= {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 2000,
        "temperature": 0,
        "system": "You are acting as a researcher evaluating if a patient profile is suitable for a clinical trial.",
        "messages": [{"role": "user", "content": input_prompt}]
    }   

    consolidated_patient_object = {
        "FHIR Patient ID": patient_id,
        "patient": sample_patient_resource,
        "encounters": sample_patient_encounters,
        "conditions": sample_patient_conditions,
        "observations": sample_patient_observations,
        "procedures": sample_patient_procedures,
        "medication requests": sample_patient_medication_requests,
    }

    if VERBOSE_LOGGING:
        print(f"Finished retrieving patient data from HealthLake for {patient_id}:")
    
    return (patient_id, generate_random_record_id(), bedrock_body, consolidated_patient_object)
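
The six queries above differ only in the table name, and each one interpolates the patient ID directly into the SQL text. As a sketch of one way to centralize this, the hypothetical helper `build_patient_queries` below builds the same query strings after checking the ID against the character set that the FHIR specification allows for resource ids. It only constructs strings and does not call Athena.

```python
import re

# FHIR resource ids are limited to letters, digits, '-' and '.', up to 64 characters.
FHIR_ID_PATTERN = re.compile(r"[A-Za-z0-9\-.]{1,64}")

# Tables that reference the patient through subject.reference in the queries above.
RESOURCE_TABLES = ["encounter", "observation", "condition", "procedure", "medicationrequest"]


def build_patient_queries(patient_id):
    """Return a dict of table name -> SQL string for one validated patient id."""
    if not FHIR_ID_PATTERN.fullmatch(patient_id):
        raise ValueError(f"Unexpected characters in patient id: {patient_id!r}")
    queries = {"patient": f"select * from patient where id = '{patient_id}'"}
    for table in RESOURCE_TABLES:
        queries[table] = (
            f"select * from {table} where subject.reference = 'Patient/{patient_id}'"
        )
    return queries


# Made-up patient id, for illustration only.
example_queries = build_patient_queries("123e4567-e89b-12d3-a456-426614174000")
for table, sql in example_queries.items():
    print(table, "->", sql)
```

Validating the ID before interpolation rejects values containing quotes or spaces, which matters if patient IDs are ever read from a user-supplied file.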

### 2.6 Retrieve Patient ID List

We either read a list of patient IDs from a file or retrieve all patients in our AWS HealthLake database that we are interested in examining for clinical trial eligibility.

In [None]:
USE_FULL_HEALTHLAKE_DATABASE = True # set this to False to use a patient id list in txt format

if not USE_FULL_HEALTHLAKE_DATABASE:
    def read_list_of_patient_ids(filename):
        # Use a context manager so the file is closed after reading
        with open(filename, 'r') as file:
            return [line.strip() for line in file]

    patient_id_list = read_list_of_patient_ids(f"{part_3_resource_folder}/patient_id_list.txt")

else:
    patient_list_query = "select id from patient"
    df = wr.athena.read_sql_query(patient_list_query, workgroup="primary", database=HEALTHLAKE_ATHENA_DATABASE, ctas_approach=False, athena_cache_settings={"max_cache_seconds": 2628000, "max_cache_query_inspections": 500})
    patient_id_list = df['id'].tolist()

print(patient_id_list)

### 2.7 Multi-threaded Prompt Retrieval for Patients

We now define the input prompt to our LLM and use Python's `concurrent.futures` thread pool to retrieve patient data in parallel. This takes around 2 minutes on the recommended *t3.medium* instance.

In [None]:
input_prompt = """

We are looking to evaluate if a patient is a good candidate based on the criteria for a scientific study.

Inside <patient></patient> XML tags is the patient profile. 
Inside <encounter></encounter> XML tags is a list of the patient's encounters with healthcare providers.
Inside <observation></observation> XML tags is a list of the patient's observations from healthcare providers.
Inside <condition></condition> XML tags is a list of the patient's conditions diagnosed by healthcare providers.
Inside <procedure></procedure> XML tags is a list of the patient's procedures from healthcare providers.
Inside <medicationRequest></medicationRequest> XML tags is a list of the patient's prescription medication requests.
Inside <study></study> XML tags is the criteria of a scientific study

<patient>REPLACE</patient>

<encounter>REPLACE</encounter>

<observation>REPLACE</observation>

<condition>REPLACE</condition>

<procedure>REPLACE</procedure>

<medicationRequest>REPLACE</medicationRequest>

<study>REPLACE</study>

Please provide your response for each of the study criteria as

(good candidate) Yes this patient is a good candidate for the scientific study.
(likely good candidate) It is likely that this patient is a good candidate for the scientific study but more information might be needed.
(likely bad candidate) It is unlikely that this patient is a good candidate for the scientific study but more information might be needed.
(bad candidate) No this patient is not a good candidate for the scientific study.

and an explanation with references to the patient information for why you chose each response

in a JSON response.

If there is no evidence in the patient record for an exclusion condition, choose good candidate.

If the inclusion criterion you are evaluating is an age range, choose good candidate if the patient falls within the age range, choose bad candidate if the patient falls outside the age range.

""" + f"Today's date is {date.today()}" + """

Put your answers to the user inside <answer></answer> XML tags

Here is an example:
<example>
{
    "Inclusion Criteria":{
        "Age between 12-55": {
            "Answer":"good candidate",
            "Explanation":"The person's birthdate was 1996-12-10 which makes the patient 28 years old"
        },
        "History of Disease A":{
            "Answer":"likely good candidate",
            "Explanation":"There is evidence that the person has had Disease A in the encounter on 2022-01-31"
        }
    },
    "Exclusion Criteria":{
        "Age greater than 55": {
            "Answer":"good candidate",
            "Explanation":"The person's birthdate was 1996-12-10 which makes the patient 28 years old which is not an age greater than 55"
        },
        "History of Disease B": {
            "Answer":"likely bad candidate",
            "Explanation":"I did not find evidence of the patient for Disease B in the previous encounters but they had a related disease of Disease C as referenced by the doctor's visit on 2021-10-13"
        }
    }
}
</example>

"""

print("Starting parallel retrieval of patient data from HealthLake. Please wait for completion before moving to the next cell.")

with concurrent.futures.ThreadPoolExecutor(max_workers = 20) as executor:
    # Submit tasks to the executor
    futures = [executor.submit(get_llm_input_prompt_patient, input_prompt, patient_id, study_criteria) for patient_id in patient_id_list]
    # Collect the results
    results = [future.result() for future in concurrent.futures.as_completed(futures)]

print("Finished retrieving all patient data from HealthLake.")
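
Two properties of the pattern above are worth noting: `as_completed` yields futures in the order they finish rather than the order they were submitted, and `future.result()` re-raises any exception from the worker, so a single failed patient stops the collection step. The sketch below shows one way to record failures per patient instead. `fetch_record` is a hypothetical stand-in for `get_llm_input_prompt_patient`, and the patient IDs are made up.

```python
import concurrent.futures


def fetch_record(patient_id):
    """Stand-in for get_llm_input_prompt_patient; fails for one id on purpose."""
    if patient_id == "bad-id":
        raise ValueError("simulated query failure")
    return (patient_id, f"prompt for {patient_id}")


patient_ids = ["p1", "p2", "bad-id", "p3"]
completed, failed = [], []

with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    # Map each future back to its patient id so failures can be reported.
    future_to_id = {executor.submit(fetch_record, pid): pid for pid in patient_ids}
    for future in concurrent.futures.as_completed(future_to_id):
        pid = future_to_id[future]
        try:
            completed.append(future.result())
        except Exception as error:
            failed.append((pid, str(error)))

print(f"{len(completed)} succeeded, {len(failed)} failed: {failed}")
```

Because each result tuple carries its own patient ID, the completion order does not affect later steps.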


### 2.8 Generate Bedrock Batch Inference Input File

For each patient we examine, we store the patient information as individual files in S3 so we can view them later. For each patient, we store the full text of the prompt we want the LLM to evaluate as a new line in JSON format. One of the requirements for Bedrock Batch Inference is a minimum of 1000 records per job. Since our demo doesn't have this many records, we add placeholder records to meet the minimum required per job. Additional quotas and requirements can be found in the [documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/quotas.html#quotas-batch).

In [None]:
def dummy_record():
    return {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 1,
        "temperature": 0,
        "messages": [{"role": "user", "content": "I"}]
    }

record_id_to_patient_info_map = {}

batch_job_dict = []

for result in results:
    patient_id = result[0]
    
    consolidated_patient_info = json.dumps(result[3], indent=4)

    record_id = result[1]

    record_id_to_patient_info_map[record_id] = consolidated_patient_info

    llm_input = result[2]
    
    patient = {
        "recordId" : record_id,
        "modelInput" : llm_input
    }
    batch_job_dict.append(patient)

# Add dummy records to meet the minimum required per job
while len(batch_job_dict) < 1000:
    batch_job_dict.append({
        "recordId" : generate_random_record_id(),
        "modelInput" : dummy_record()
    })

print(f"batch_job_dict length: {len(batch_job_dict)}")

batch_inference_input_filename = "batch_inference_input.jsonl"
with open(batch_inference_input_filename, "w") as outfile:
    for job in batch_job_dict:
        outfile.write(json.dumps(job) + "\n")

print(f"Generated batch inference input as {batch_inference_input_filename}")

record_id_to_patient_info_map_filename = "record_id_to_patient_info_map.json"
with open(record_id_to_patient_info_map_filename, "w") as outfile:
    outfile.write(json.dumps(record_id_to_patient_info_map))

print(f"Generated record id to patient info map as {record_id_to_patient_info_map_filename}")

if VERBOSE_LOGGING:
    print(f"record_id_to_patient_info_map: {json.dumps(record_id_to_patient_info_map)}")
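
The input file is in JSON Lines format: one JSON object per line, each with a `recordId` and a `modelInput`. The sketch below builds two made-up records in the same layout as the cell above and reads them back to check that every line parses on its own and that record IDs are unique. `make_batch_record` is a hypothetical helper, and an in-memory buffer stands in for the file.

```python
import io
import json


def make_batch_record(record_id, prompt, max_tokens=2000):
    """Build one batch input record in the recordId / modelInput layout used above."""
    return {
        "recordId": record_id,
        "modelInput": {
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": max_tokens,
            "temperature": 0,
            "messages": [{"role": "user", "content": prompt}],
        },
    }


# Made-up records written to an in-memory buffer instead of a file.
records = [
    make_batch_record("RECORD000001", "First prompt\nwith a line break"),
    make_batch_record("RECORD000002", "Second prompt"),
]
buffer = io.StringIO()
for record in records:
    buffer.write(json.dumps(record) + "\n")

# Read the lines back: each must parse alone and record ids must be unique.
lines = buffer.getvalue().splitlines()
parsed_ids = [json.loads(line)["recordId"] for line in lines]
assert len(parsed_ids) == len(set(parsed_ids)), "duplicate recordId found"
print(f"{len(lines)} valid JSONL lines")
```

`json.dumps` escapes line breaks inside the prompt, which is what keeps each record on a single line.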

### 2.9 Upload Bedrock Batch Inference Input File

Batch inference requires the input file to be stored in S3.

In [None]:
bedrock_batch_processing_input_folder = "Clinical_Trials_Gen_AI_Workshop_Batch_Inference_Input"
bedrock_batch_processing_s3_input_filepath = f"{bedrock_batch_processing_input_folder}/{batch_inference_input_filename}"
S3_CLIENT.upload_file(batch_inference_input_filename, SAGEMAKER_S3_BUCKET, bedrock_batch_processing_s3_input_filepath)
print(f"Successfully uploaded {batch_inference_input_filename} to s3://{SAGEMAKER_S3_BUCKET}/{bedrock_batch_processing_s3_input_filepath}")

### 2.10 Bedrock Batch Inference Role Configuration

Bedrock Batch Inference requires a role that has permissions to read the S3 input file and write the output as a file in the designated location in S3. By default, we will use the name **"Bedrock-Batch-Inference-Role"**, but this can be changed by editing the *bedrock_batch_inference_role_name* variable. This role can be created using the instructions [here](https://docs.aws.amazon.com/bedrock/latest/userguide/batch-inference-permissions.html). If you are running this notebook as part of an AWS workshop event, this role may have been pre-created for you. If you are unable to create this role, you can still continue through this workshop using the sample output.

In [None]:
from botocore.exceptions import ClientError  # raised by iam.get_role when the role is missing

bedrock_batch_inference_role_name = "Bedrock-Batch-Inference-Role"
bedrock_batch_inference_role_arn = f"arn:aws:iam::{ACCOUNT_ID}:role/{bedrock_batch_inference_role_name}"

def check_role_exists(role_name):
    iam = boto3.client('iam')
    try:
        iam.get_role(RoleName=role_name)
        return True
    except ClientError as e:
        if e.response['Error']['Code'] == 'NoSuchEntity':
            return False
        raise

if check_role_exists(bedrock_batch_inference_role_name):
    print(f"Role '{bedrock_batch_inference_role_name}' exists")
else:
    print(f"Role '{bedrock_batch_inference_role_name}' does not exist")
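
For orientation, the role needs a trust policy that allows the Amazon Bedrock service to assume it. The sketch below builds such a policy as a Python dictionary. It reflects the general shape described in the linked permissions documentation as best understood here, so verify the statement, and the separate S3 read/write permissions policy, against that page before use. `build_batch_trust_policy` is a hypothetical helper, and the account ID and region are placeholders.

```python
import json


def build_batch_trust_policy(account_id, region):
    """Return a trust policy dict that lets Amazon Bedrock assume the role."""
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Principal": {"Service": "bedrock.amazonaws.com"},
                "Action": "sts:AssumeRole",
                # Conditions restrict use to batch jobs in this account (assumed shape).
                "Condition": {
                    "StringEquals": {"aws:SourceAccount": account_id},
                    "ArnEquals": {
                        "aws:SourceArn": (
                            f"arn:aws:bedrock:{region}:{account_id}:model-invocation-job/*"
                        )
                    },
                },
            }
        ],
    }


# Placeholder account id and region, for illustration only.
trust_policy = build_batch_trust_policy("111122223333", "us-east-1")
print(json.dumps(trust_policy, indent=4))
```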

### 2.11 Executing Bedrock Batch Inference

Bedrock Batch Inference can take some time to finish execution. This section may take 30-60+ minutes to complete as Amazon Bedrock processes the Batch Inference Job. To skip this live batch processing, by default we will use sample output data. 

In [None]:
USE_SAMPLE_BATCH_INFERENCE_OUTPUT = True # Set to "False" to use live batch inference output.
if USE_SAMPLE_BATCH_INFERENCE_OUTPUT:
    print("We will use the stored sample_batch_inference_output.jsonl to demonstrate patient analysis.")
else:
    print("We will send the input file to Bedrock Batch Inference and wait for the results.")

#### 2.11.1 Start Bedrock Batch Inference Job

Bedrock Batch Inference jobs are asynchronous, so we immediately receive the job ARN (Amazon Resource Name) after submitting the job. This step is skipped if we are using the sample batch inference output.

In [None]:
if not USE_SAMPLE_BATCH_INFERENCE_OUTPUT:
    # Input data configuration
    input_data_config = {
      "s3InputDataConfig": {
        "s3Uri": f"s3://{SAGEMAKER_S3_BUCKET}/{bedrock_batch_processing_s3_input_filepath}"
      }
    }
    print(f"Input data config: {input_data_config}")
    
    bedrock_batch_processing_output_folder = "Clinical_Trials_Gen_AI_Workshop_Batch_Inference_Ouput"
    bedrock_batch_processing_s3_output_uri = f"s3://{SAGEMAKER_S3_BUCKET}/{bedrock_batch_processing_output_folder}/"
    # Output data configuration
    output_data_config = {
      "s3OutputDataConfig": {
        "s3Uri": bedrock_batch_processing_s3_output_uri
      }
    }
    print(f"Output data config: {output_data_config}")


    # Create batch inference job
    response = BEDROCK_CLIENT.create_model_invocation_job(
      roleArn=bedrock_batch_inference_role_arn,
      modelId=CLAUDE_3_SONNET_MODEL_ID,
      inputDataConfig=input_data_config,
      outputDataConfig=output_data_config,
      jobName="BatchInferenceJob"+generate_random_record_id()
    )
    
    job_arn = response["jobArn"]
    print(f"Job ARN: {job_arn}")
else:
    print("Using sample batch inference output.")

#### 2.11.2 Query Bedrock Batch Inference Job Status

We query the status of our Bedrock Batch Inference Job and print it every 5 seconds until the job completes. If the job ends in a failed, stopped, expired, or partially completed state, the cell raises an error instead of waiting indefinitely. This step is skipped if we are using the sample batch inference output.

In [None]:
if not USE_SAMPLE_BATCH_INFERENCE_OUTPUT:
    batch_inference_output_filename = "batch_inference_output.jsonl"
    # Terminal statuses after which the job will never reach "Completed"
    unsuccessful_terminal_statuses = {"Failed", "Stopped", "Expired", "PartiallyCompleted"}
    counter = 0
    while True:
        full_status = BEDROCK_CLIENT.get_model_invocation_job(jobIdentifier=job_arn)
        status = full_status["status"]
        print(f"seconds: {counter}")
        if VERBOSE_LOGGING:
            print(f"Full status: {full_status}")
        else:
            print(f"Status: {status}")
        if status == "Completed":
            break
        if status in unsuccessful_terminal_statuses:
            raise RuntimeError(f"Batch inference job ended with status {status}")
        time.sleep(5)
        counter += 5
    
    job_id = job_arn[job_arn.rfind("/") + 1:]
    print(f"job_id: {job_id}")
    s3_output_file_location = f"{bedrock_batch_processing_output_folder}/{job_id}/{batch_inference_input_filename}.out"
    #print(f"s3_output_file_location: {s3_output_file_location}")
    S3_CLIENT.download_file(SAGEMAKER_S3_BUCKET, s3_output_file_location, batch_inference_output_filename)
    print(f"Downloaded batch inference output from s3://{SAGEMAKER_S3_BUCKET}/{s3_output_file_location} to {batch_inference_output_filename}")
else:
    batch_inference_output_filename = f"{part_3_resource_folder}/sample_batch_inference_output.jsonl"
    print(f"Using sample batch inference output at {batch_inference_output_filename}.")
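
The polling pattern used in this step can be sketched in isolation as follows. This is a minimal illustration, not the notebook's implementation: `wait_for_job` and its parameters are hypothetical, the status source is a stand-in for `BEDROCK_CLIENT.get_model_invocation_job(...)["status"]`, and the set of terminal statuses is an assumption that should be checked against the Bedrock API reference.

```python
import time

# Assumed terminal statuses; verify against the Bedrock API reference.
TERMINAL_STATUSES = {"Completed", "Failed", "Stopped", "PartiallyCompleted", "Expired"}

def wait_for_job(get_status, poll_seconds=5, max_polls=1000, sleep=time.sleep):
    """Poll get_status() until it returns a terminal status or max_polls is reached."""
    for _ in range(max_polls):
        status = get_status()
        if status in TERMINAL_STATUSES:
            return status
        sleep(poll_seconds)
    raise TimeoutError("Job did not reach a terminal status in time")

# Stand-in status source so the sketch runs without AWS access.
fake_statuses = iter(["Submitted", "InProgress", "InProgress", "Completed"])
final_status = wait_for_job(lambda: next(fake_statuses), sleep=lambda seconds: None)
print(final_status)  # Completed
```

Bounding the number of polls and treating every terminal status as an exit condition keeps the cell from hanging if the job fails or is stopped.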

### 2.12 Helper Function for flattening result output

We will now parse the response of the LLM assistant, which is currently a string containing a multi-level JSON object wrapped in `<answer>` tags. To streamline the analysis process, we will flatten this multi-level object to a single level by joining the keys at each level with an underscore (`_`) character.

An example of this is transforming 

```
"Exclusion Criteria": {
        "Known adverse reaction or tolerance to study medication": {
            "Answer": "good candidate",
            "Explanation": "There is no information provided about any known adverse reactions or tolerances to study medications."
        },
        "Headache due to trauma": {
            "Answer": "good candidate",
            "Explanation": "There is no evidence in the patient record of a headache due to trauma."
        }
}
```

to

```
{
    "Exclusion Criteria_Known adverse reaction or tolerance to study medication_Answer": "good candidate",
    "Exclusion Criteria_Known adverse reaction or tolerance to study medication_Explanation": "There is no information provided about any known adverse reactions or tolerances to study medications.",
    "Exclusion Criteria_Headache due to trauma_Answer": "good candidate",
    "Exclusion Criteria_Headache due to trauma_Explanation": "There is no evidence in the patient record of a headache due to trauma."
}
```

In [None]:
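def flatten_patient_result(model_completion):
    if VERBOSE_LOGGING:
        print(model_completion)
    raw_claude_output = model_completion.replace("\n", "")
    start_tag = "<answer>"
    start_tag_loc = raw_claude_output.find(start_tag)
    xml_end_loc = raw_claude_output.find("</answer>")

    # Check for missing tags before offsetting past the opening tag;
    # adding the offset first would hide a -1 result from find()
    if start_tag_loc == -1 or xml_end_loc == -1:
        raise ValueError("No <answer> xml tag detected in claude output")
    xml_start_loc = start_tag_loc + len(start_tag)

    def flat_keys(obj, new_obj=None, keys=None):
        # Use None defaults to avoid sharing mutable default arguments
        new_obj = {} if new_obj is None else new_obj
        keys = [] if keys is None else keys
        for key, value in obj.items():
            if isinstance(value, dict):
                flat_keys(value, new_obj, keys + [key])
            else:
                new_obj['_'.join(keys + [key])] = value
        return new_obj

    flattened_patient_result = flat_keys(json.loads(raw_claude_output[xml_start_loc:xml_end_loc]))
    return flattened_patient_result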
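
To see the flattening step on its own, here is a minimal, self-contained sketch. The `flatten` helper and the shortened sample completion are illustrative stand-ins, not the notebook's function or real model output.

```python
import json

def flatten(obj, parent_keys=()):
    """Flatten nested dicts, joining the keys at each level with underscores."""
    flat = {}
    for key, value in obj.items():
        path = parent_keys + (key,)
        if isinstance(value, dict):
            flat.update(flatten(value, path))
        else:
            flat["_".join(path)] = value
    return flat

# Shortened, made-up completion in the <answer>...</answer> format.
sample_completion = (
    '<answer>{"Exclusion Criteria": {"Headache due to trauma": '
    '{"Answer": "good candidate", "Explanation": "No evidence of trauma."}}}</answer>'
)

start = sample_completion.find("<answer>")
end = sample_completion.find("</answer>")
payload = json.loads(sample_completion[start + len("<answer>"):end])
flat = flatten(payload)

for key, value in flat.items():
    print(f"{key}: {value}")
```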

### 2.13 Generate Patient Suitability Score

We can now use the answer for each study criterion to generate a patient suitability score. Each answer earns the points shown in the table below. The score is the patient's total points expressed as a percentage of the maximum possible (3 points per criterion), rounded to two decimal places.

| LLM Response of Study Criteria    | Score |
| -------- | ------- |
| (good candidate) Yes this patient is a good candidate for the scientific study. | 3 |
| (likely good candidate) It is likely that this patient is a good candidate for the scientific study but more information might be needed. | 2 |
| (likely bad candidate) It is unlikely that this patient is a good candidate for the scientific study but more information might be needed. | 1 |
| (bad candidate) No this patient is not a good candidate for the scientific study. | 0 |

In [None]:
def get_patient_score(flattened_patient_obj):
    total_possible_score = 0
    patient_score = 0
    for criteria, result in flattened_patient_obj.items():
    
        if criteria.endswith("_Answer"):
            total_possible_score += 3
        
            match result:
                case "good candidate":
                    patient_score += 3
                case "likely good candidate":
                    patient_score += 2
                case "likely bad candidate":
                    patient_score += 1
                case "bad candidate":
                    patient_score += 0
                case _:
                    raise Exception(f"Unexpected result {result} in flattened_patient_obj")
    
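    # Guard against division by zero when no "_Answer" keys were found
    if total_possible_score == 0:
        raise Exception("No criteria answers found in flattened_patient_obj")
    patient_suitability_percentage = round((patient_score/total_possible_score) * 100, 2)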
    return patient_suitability_percentage
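
As a worked example of the scoring arithmetic, consider a hypothetical patient with four criteria. The `suitability_percentage` helper and the answer list are illustrative only and mirror the point values in the table above.

```python
# Points per answer, matching the scoring table.
ANSWER_POINTS = {
    "good candidate": 3,
    "likely good candidate": 2,
    "likely bad candidate": 1,
    "bad candidate": 0,
}

def suitability_percentage(answers):
    """Return earned points as a percentage of the maximum (3 per criterion)."""
    if not answers:
        raise ValueError("At least one answer is required")
    earned = sum(ANSWER_POINTS[answer] for answer in answers)
    return round(earned / (3 * len(answers)) * 100, 2)

# Hypothetical answers for four criteria: 3 + 3 + 2 + 0 = 8 of 12 points.
example_answers = [
    "good candidate",
    "good candidate",
    "likely good candidate",
    "bad candidate",
]
example_score = suitability_percentage(example_answers)
print(example_score)  # 66.67
```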

### 2.14 Consolidate Patient Profile

For each patient processed, we collect the record ID (used later to look up the consolidated patient information), the patient score, and the answers and explanations for each study criterion. Entries with empty model output are discarded and counted.

In [None]:
with open(batch_inference_output_filename) as f:
    batch_output_raw = [json.loads(line) for line in f]

consolidated_patient_result_output = []
discarded_entries = 0
for batch_output_item in batch_output_raw:
    if len(batch_output_item["modelOutput"]['content'][0]['text']) < 5:
        discarded_entries += 1
        continue
        
    patient_dict = flatten_patient_result(batch_output_item["modelOutput"]['content'][0]['text'])
    patient_dict["Score"] = get_patient_score(patient_dict)
    # if USE_SAMPLE_BATCH_INFERENCE_OUTPUT:
    patient_dict["patient_id"] = batch_output_item["recordId"]
    # else:
    #     patient_dict["patient_id"] = record_id_to_patient_id_map[batch_output_item["recordId"]]

    if VERBOSE_LOGGING:
        print(patient_dict["patient_id"])
        print(patient_dict["Score"])
    consolidated_patient_result_output.append(patient_dict)

print(f"Discarded {discarded_entries} entries due to empty LLM output.")
print(f"Finished processing {len(consolidated_patient_result_output)} patients.")
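
The record shape this step relies on can be illustrated with a small sketch. The two JSONL lines below are made up, and only the fields the notebook actually reads (`recordId` and `modelOutput.content[0].text`) are shown. Real batch output records may carry additional fields.

```python
import io
import json

# Two made-up output records: one with model text, one with empty text.
sample_jsonl = io.StringIO(
    '{"recordId": "REC0000001", "modelOutput": {"content": [{"text": "<answer>{}</answer>"}]}}\n'
    '{"recordId": "REC0000002", "modelOutput": {"content": [{"text": ""}]}}\n'
)

kept = []
discarded = 0
for line in sample_jsonl:
    record = json.loads(line)
    text = record["modelOutput"]["content"][0]["text"]
    # Same threshold as the notebook: very short text is treated as empty.
    if len(text) < 5:
        discarded += 1
        continue
    kept.append((record["recordId"], text))

print(f"kept={len(kept)} discarded={discarded}")  # kept=1 discarded=1
```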

### 2.15 Helper function for Color Coding answers

We define a helper function to change the background color of a cell based on the answer from the LLM on clinical trial eligibility.

In [None]:
# Define a color mapping
color_mapping = {
    'good candidate': '#4CAF50',  # Green
    'likely good candidate': '#FFEB3B',  # Yellow
    'likely bad candidate': '#FF6D00',  # Orange
    'bad candidate': '#D50000',  # Red
}

# Function to generate cell style
def get_cell_style(val):
    color = color_mapping.get(str(val), "#FFFFFF")
    return f'background-color: {color}; font-weight: bold;' if color != "#FFFFFF" else ''


# Function to safely render cell content
def render_cell(val):
    if isinstance(val, str) and val.strip().startswith('<') and val.strip().endswith('>'):
        return val  # Render as HTML
    return html.escape(str(val))  # Escape as plain text
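
A minimal sketch of how styling and escaping combine into a single table cell is shown below. `build_cell` is a hypothetical helper written for this illustration. Unlike `render_cell`, it always escapes its input.

```python
import html

# Same answer-to-color mapping as above.
ANSWER_COLORS = {
    "good candidate": "#4CAF50",
    "likely good candidate": "#FFEB3B",
    "likely bad candidate": "#FF6D00",
    "bad candidate": "#D50000",
}

def build_cell(val):
    """Return a <td> element, colored when val is a known answer."""
    color = ANSWER_COLORS.get(str(val))
    style = f"background-color: {color}; font-weight: bold;" if color else ""
    return f'<td style="{style}">{html.escape(str(val))}</td>'

answer_cell = build_cell("good candidate")
text_cell = build_cell("BMI < 30 & no trauma")
print(answer_cell)
print(text_cell)
```

Escaping free-text explanations keeps characters such as `<` and `&` from being interpreted as markup in the generated page.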

### 2.16 Display consolidated Patient Findings

We display all of the patients and data we have processed in a consolidated webpage.

In [None]:
consolidated_patient_df = pd.DataFrame(consolidated_patient_result_output)
score_column = consolidated_patient_df.pop("Score")
patient_id_column = consolidated_patient_df.pop("patient_id")
consolidated_patient_df.insert(0, "Score", score_column)
consolidated_patient_df.insert(0, "patient_id", patient_id_column)
consolidated_patient_df = consolidated_patient_df.sort_values(by="Score", ascending=False)
# display(consolidated_patient_df)

# Generate HTML
html_template = '''
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Patient Information</title>
    <style>
        body, html {{
            height: 100%;
            margin: 0;
            padding: 0;
            font-family: Arial, sans-serif;
        }}
        .container {{
            display: flex;
            height: 100vh;
        }}
        .table-container {{
            flex: 1;
            overflow: auto;
        }}
        .side-panel {{
            width: 0;
            background-color: #f1f1f1;
            position: fixed;
            top: 0;
            right: 0;
            height: 100%;
            z-index: 1000;
            box-shadow: -2px 0 5px rgba(0,0,0,0.2);
            overflow: hidden;
            transition: width 0.3s;
            display: flex;
            flex-direction: column;
        }}
        .resize-handle {{
            width: 10px;
            background-color: #ccc;
            cursor: ew-resize;
            position: absolute;
            top: 0;
            bottom: 0;
            left: 0;
        }}
        .side-panel-header {{
            position: sticky;
            top: 0;
            background-color: #f1f1f1;
            padding: 20px;
            z-index: 2;
            border-bottom: 1px solid #ddd;
        }}
        .side-panel-content {{
            flex-grow: 1;
            overflow-y: auto;
            overflow-x: hidden;
            padding: 0 20px 20px 20px;
        }}
        .close-btn {{
            position: sticky;
            top: 10px;
            float: right;
            font-size: 36px;
            cursor: pointer;
        }}
        table {{
            border-collapse: separate;
            border-spacing: 0;
            width: 100%;
        }}
        th, td {{
            border: 1px solid #ddd;
            padding: 8px;
            text-align: center;
        }}
        th {{
            background-color: #f2f2f2;
            position: sticky;
            top: 0;
            z-index: 10;
        }}
        .sticky-col {{
            position: sticky;
            background-color: #f2f2f2;
            z-index: 5;
        }}
        .sticky-col.header {{
            z-index: 11;
        }}
        th:first-child, td:first-child {{
            left: 0;
        }}
        th:nth-child(2), td:nth-child(2) {{
            left: 40px;
        }}
        .patient-id {{
            cursor: pointer;
            color: blue;
            text-decoration: underline;
        }}
        .patient-info-container {{
            overflow-x: auto;
            white-space: nowrap;
        }}
        #patientInfo {{
            display: inline-block;
            white-space: pre;
            font-family: monospace;
        }}
        .horizontal-scroll {{
            overflow-x: scroll;
            overflow-y: hidden;
            height: 12px;
            background-color: #ddd;
            position: sticky;
            top: 0;
            z-index: 3;
            opacity: 1;
            transition: opacity 0.3s;
        }}
        .horizontal-scroll::-webkit-scrollbar {{
            height: 12px;
        }}
        .horizontal-scroll::-webkit-scrollbar-thumb {{
            background-color: #888;
            border-radius: 6px;
        }}
        .horizontal-scroll::-webkit-scrollbar-track {{
            background-color: #ddd;
            height: 5px !important;
        }}
        .scroll-content {{
            height: 1px;
        }}
    </style>
</head>
<body>
    <div class="container">
        <div class="table-container">
            <table>
                <thead>
                    <tr>
                        <th class="sticky-col header" style="min-width: 40px;">#</th>
                        <th class="sticky-col header" style="min-width: 100px;">patient_id</th>
                        {table_headers}
                    </tr>
                </thead>
                <tbody>
                    {table_rows}
                </tbody>
            </table>
        </div>
        <div id="sidePanelContainer" class="side-panel">
            <div id="resizeHandle" class="resize-handle"></div>
            <div class="side-panel-header">
                <span class="close-btn" onclick="closeNav()">&times;</span>
                <h2>Patient Information</h2>
            </div>
            <div class="horizontal-scroll">
                <div class="scroll-content"></div>
            </div>
            <div class="side-panel-content">
                <div class="patient-info-container">
                    <pre id="patientInfo"></pre>
                </div>
            </div>
        </div>
    </div>

    <script>
        const patientInfoMap = {patient_info_json};
        const sidePanel = document.getElementById("sidePanelContainer");
        const resizeHandle = document.getElementById('resizeHandle');
        const horizontalScroll = document.querySelector('.horizontal-scroll');
        const patientInfoContainer = document.querySelector('.patient-info-container');
        const scrollContent = document.querySelector('.scroll-content');
        const patientInfoElement = document.getElementById("patientInfo");
        
        let isResizing = false;
        let startX;
        let startWidth;

        function showPatientInfo(patientId) {{
            if (patientInfoMap[patientId]) {{
                let patientData = JSON.parse(patientInfoMap[patientId]);
                let formattedInfo = JSON.stringify(patientData, null, 4);
                patientInfoElement.textContent = formattedInfo;
                sidePanel.style.width = "400px";
                updateScrollContent();
            }} else {{
                patientInfoElement.textContent = "No information available for this patient.";
                sidePanel.style.width = "400px";
            }}
        }}

        function closeNav() {{
            sidePanel.style.width = "0";
        }}

        function updateScrollContent() {{
            requestAnimationFrame(() => {{
                scrollContent.style.width = `${{patientInfoElement.scrollWidth}}px`;
            }});
        }}

        resizeHandle.addEventListener('mousedown', (e) => {{
            isResizing = true;
            startX = e.clientX;
            startWidth = parseInt(getComputedStyle(sidePanel).width, 10);
            document.addEventListener('mousemove', resize);
            document.addEventListener('mouseup', stopResize);
        }});

        function resize(e) {{
            if (!isResizing) return;
            
            const width = startWidth + (startX - e.clientX);
            
            if (width > 200 && width < window.innerWidth - 100) {{
                sidePanel.style.width = `${{width}}px`;
                updateScrollContent();
            }}
        }}

        function stopResize() {{
            isResizing = false;
            document.removeEventListener('mousemove', resize);
            document.removeEventListener('mouseup', stopResize);
        }}

        horizontalScroll.addEventListener('scroll', () => {{
            patientInfoContainer.scrollLeft = horizontalScroll.scrollLeft;
        }});

        patientInfoContainer.addEventListener('scroll', () => {{
            horizontalScroll.scrollLeft = patientInfoContainer.scrollLeft;
        }});

        // Refresh scroll content width on window resize
        window.addEventListener('resize', updateScrollContent);
    </script>
</body>
</html>
'''

# Generate table headers
table_headers = ' '.join(f'<th>{html.escape(str(col))}</th>' for col in consolidated_patient_df.columns[1:])

# Generate table rows
table_rows = ''
for index, row in consolidated_patient_df.iterrows():
    table_rows += '<tr>'
    table_rows += f'<td class="sticky-col">{index}</td>'
    table_rows += f'<td class="sticky-col patient-id" onclick="showPatientInfo(\'{row["patient_id"]}\')">{row["patient_id"]}</td>'
    table_rows += ' '.join(f'<td style="{get_cell_style(val)}">{render_cell(val)}</td>' for val in row[1:])
    table_rows += '</tr>'
    
# Load or create patient_info_json
if not USE_SAMPLE_BATCH_INFERENCE_OUTPUT:
    patient_info_json = json.dumps(record_id_to_patient_info_map)
else:
    sample_record_id_to_patient_info_map_filename = f"{part_3_resource_folder}/sample_record_id_to_patient_info_map.json"
    with open(sample_record_id_to_patient_info_map_filename, 'r') as file:
        patient_info_json = file.read()

# Fill in the template
html_content = html_template.format(
    table_headers=table_headers,
    table_rows=table_rows,
    patient_info_json=patient_info_json
)

# Save the HTML content to a file
consolidated_patient_table_name = "consolidated_patient_table.html"
with open(consolidated_patient_table_name, 'w') as f:
    f.write(html_content)

print(f"Finished generating {consolidated_patient_table_name}.")
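
The template above is filled in with `str.format`, which is why every literal CSS and JavaScript brace in it is doubled (`{{` and `}}`) while the placeholders use single braces. A small sketch of that behavior, using a made-up miniature template:

```python
# Doubled braces become literal braces; single-brace names are placeholders.
mini_template = "<style>td {{ padding: {pad}px; }}</style><tr>{row}</tr>"

rendered = mini_template.format(pad=8, row="<td>1</td>")
print(rendered)  # <style>td { padding: 8px; }</style><tr><td>1</td></tr>
```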

### 2.17 View Consolidated Patient Findings

Now let's walk through the consolidated patient table together. To open the Consolidated Patient Table HTML file, you can either:

1. Right-click the HTML file in the sidebar and click download. Open this file from your computer's download folder in any browser of your choice.
2. Double-click the HTML file in the sidebar to open it in a new tab in JupyterLab. Click "Trust HTML" in the top left corner if this is your first time opening the file.

You'll notice that our sample study and output have produced a wide variety of responses across patients and clinical study criteria. We have sorted the table by our calculated score, with every criterion equally weighted. This view lets you investigate each patient and criterion in depth to determine the best patient(s) to select for a clinical trial. Clicking the blue hyperlinked patient_id shows the details that we fed as input to our LLM for that patient.