In [None]:
# !pip install --upgrade boto3 
# !pip install evaluate
# !pip install rouge_score

In [None]:
import json
import boto3
import pandas as pd
import re
from typing import Dict, List
from tqdm import tqdm

In [None]:
endpoint_name = "meta-textgeneration-llama-2-70b-f-2024-02-12-18-43-47-048"

# Shared `sagemaker-runtime` client, created on first use. `query_endpoint` is called once per
# CSV row, so reusing one client avoids building a new boto3 client for every request.
# Re-run this cell to force a fresh client (e.g. after changing AWS region/credentials config).
_sagemaker_runtime_client = None


def _get_sagemaker_runtime_client():
    """Return the shared `sagemaker-runtime` client, creating it on the first call."""
    global _sagemaker_runtime_client
    if _sagemaker_runtime_client is None:
        _sagemaker_runtime_client = boto3.client("sagemaker-runtime")
    return _sagemaker_runtime_client


def query_endpoint(payload, endpoint=None):
    """Send `payload` as JSON to a SageMaker endpoint and return the decoded JSON response.

    Args:
        payload: JSON-serialisable request body (in this notebook, the
            ``{"inputs": ..., "parameters": ...}`` dict built by ``predict``).
        endpoint: Endpoint name to invoke; defaults to the module-level ``endpoint_name``.

    Returns:
        The response body parsed with ``json.loads`` (UTF-8 decoded).
    """
    response = _get_sagemaker_runtime_client().invoke_endpoint(
        EndpointName=endpoint if endpoint is not None else endpoint_name,
        ContentType="application/json",
        Body=json.dumps(payload),
    )
    return json.loads(response["Body"].read().decode("utf8"))

def format_messages(messages: List[Dict[str, str]]) -> str:
    """Render a chat history as a single Llama-2 chat prompt string.

    Llama-2 chat models accept only the roles 'system', 'user' and 'assistant'. An optional
    leading 'system' message is folded into the first 'user' turn inside ``<<SYS>>`` markers.
    The remaining messages must alternate user/assistant/user/..., and the last one must
    come from 'user'.

    Args:
        messages: Chat history as a list of ``{"role": ..., "content": ...}`` dicts.
            The list passed in is not modified.

    Returns:
        Each completed exchange rendered as ``<s>[INST] {user} [/INST] {assistant}</s>``
        (contents stripped), followed by ``<s>[INST] {last user message} [/INST] ``
        for the model to continue.

    Raises:
        ValueError: if ``messages`` is empty, a 'system' message has no following message,
            the roles do not alternate user/assistant, or the last message is not from 'user'.
    """
    if not messages:
        raise ValueError("messages must contain at least one 'user' message")

    if messages[0]["role"] == "system":
        if len(messages) < 2:
            raise ValueError("a 'system' message must be followed by a 'user' message")
        # Llama-2 has no standalone system turn: prepend it to the first user message.
        # Rebinding `messages` (rather than mutating it) leaves the caller's list untouched.
        content = "".join(["<<SYS>>\n", messages[0]["content"], "\n<</SYS>>\n\n", messages[1]["content"]])
        messages = [{"role": messages[1]["role"], "content": content}] + messages[2:]

    # Without these checks a misordered history silently yields a malformed prompt
    # (e.g. a trailing assistant message would be wrapped in [INST] tags as if it were user input).
    for position, message in enumerate(messages):
        expected_role = "user" if position % 2 == 0 else "assistant"
        if message["role"] != expected_role:
            raise ValueError(
                f"message {position} (after folding any system message) has role "
                f"{message['role']!r}; expected {expected_role!r}"
            )
    if messages[-1]["role"] != "user":
        raise ValueError("the last message must be from 'user'")

    prompt: List[str] = []
    # Pair each user turn with the assistant reply that follows it; the final, unanswered
    # user turn is left out by zip() and appended separately below.
    for user, answer in zip(messages[::2], messages[1::2]):
        prompt.extend(["<s>", "[INST] ", user["content"].strip(), " [/INST] ", answer["content"].strip(), "</s>"])

    prompt.extend(["<s>", "[INST] ", messages[-1]["content"].strip(), " [/INST] "])

    return "".join(prompt)

In [None]:
def predict(prompt_string, context, max_new_tokens=712, top_p=0.9, temperature=0.6):
    """Generate an impression for one findings section with the Llama-2 endpoint.

    Args:
        prompt_string: Template with a single positional ``{}`` placeholder, filled with
            ``context`` via ``str.format``.
        context: Findings text for one report.
        max_new_tokens: Generation length cap forwarded to the endpoint.
        top_p: Nucleus-sampling parameter forwarded to the endpoint.
        temperature: Sampling temperature forwarded to the endpoint; with a value > 0,
            repeated calls can return different text.

    Returns:
        The text between the first ``<impression>...</impression>`` pair in the model
        output, or the full generated text if the model did not emit those tags.
    """
    dialog = [{"role": "user", "content": prompt_string.format(context)}]
    payload = {
        "inputs": format_messages(dialog),
        "parameters": {"max_new_tokens": max_new_tokens, "top_p": top_p, "temperature": temperature},
    }
    # NOTE(review): assumes the endpoint returns a list whose first element carries
    # 'generated_text' (the shape the original code indexed) — confirm for other containers.
    generated_text = query_endpoint(payload)[0]["generated_text"]
    match = re.search(r"<impression>(.*?)</impression>", generated_text, re.DOTALL)
    return match.group(1) if match else generated_text

In [None]:
def generate_impressions(prompt, filename, output_column="Llama-generated-impressions"):
    """Run ``predict`` on every row of a CSV and return the frame with generated impressions.

    Args:
        prompt: Prompt template passed to ``predict`` (one ``{}`` placeholder for the findings).
        filename: Path to a CSV file with a ``findings`` column.
        output_column: Name of the column the generated impressions are written to.

    Returns:
        The DataFrame read from ``filename``, with ``output_column`` holding one generated
        impression per row, in row order.
    """
    dev_df = pd.read_csv(filename)
    print(f"num rows: {len(dev_df)}")

    # One endpoint call per row. Iterating the column directly avoids building a Series
    # per row with iterrows() (which can also upcast dtypes).
    generated_impressions = [
        predict(prompt, findings)
        for findings in tqdm(dev_df["findings"], total=len(dev_df))
    ]

    dev_df[output_column] = generated_impressions
    return dev_df

In [None]:
def calculate_rouge_scores(dev_df, prompt, filename, prediction_column="Llama-generated-impressions"):
    """Print and return ROUGE scores of generated impressions against the reference impressions.

    Args:
        dev_df: DataFrame with reference impressions in ``impression`` and model output in
            ``prediction_column``.
        prompt: Prompt template used for generation; only printed for the record.
        filename: Name of the evaluated CSV; only used in the printed header.
        prediction_column: Column holding the generated impressions. Defaults to the column
            written by ``generate_impressions``. The previous hard-coded column
            "claude-v2-generated-impressions" does not exist in this notebook and raised KeyError.

    Returns:
        The dict returned by the ``evaluate`` ROUGE metric's ``compute``.
    """
    # `evaluate` was used here but never imported, so this function raised NameError.
    import evaluate

    print('PROMPT used:\n', prompt)
    print("-" * 50)
    rouge_metric = evaluate.load("rouge")
    result = rouge_metric.compute(
        predictions=dev_df[prediction_column].tolist(),
        references=dev_df["impression"].tolist(),
    )
    print(f"ROUGE Score for Llama-2 model ({prediction_column}) on {filename}")
    print(result)
    print("-" * 50)
    return result

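## Zero-shot prompting

The prompt contains only the instruction and the findings, with no example findings/impression pairs.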

In [None]:
# Zero-shot template: instruction plus findings, no worked examples. `predict` fills the single
# `{}` placeholder with the findings text via str.format, so any other literal braces added
# here would need escaping as `{{` / `}}`.
prompt_zero_shot = (
    "Human: Generate radiology report impressions based on the following findings "
    "and output it within <impression> tags. Findings: {}"
    "\n\n"
    "Assistant:"
)

#### Generate impressions for `dev1_MIMICXR.csv`

In [None]:
# Zero-shot generation for dev1_MIMICXR.csv (one endpoint call per row).
filename = "dev1_MIMICXR.csv"
dev1_df_zero_shot = generate_impressions(prompt=prompt_zero_shot, filename=filename)

##### ROUGE score computation for `dev1_MIMICXR.csv`

In [None]:
# Score the zero-shot impressions generated for dev1_MIMICXR.csv.
calculate_rouge_scores(
    dev1_df_zero_shot,
    prompt=prompt_zero_shot,
    filename="dev1_MIMICXR.csv",
)

#### Generate impressions for `dev2_Indiana.csv`

In [None]:
# Zero-shot generation for dev2_Indiana.csv (one endpoint call per row).
filename = "dev2_Indiana.csv"
dev2_df_zero_shot = generate_impressions(prompt=prompt_zero_shot, filename=filename)

##### ROUGE score computation for `dev2_Indiana.csv`

In [None]:
# Score the zero-shot impressions generated for dev2_Indiana.csv.
calculate_rouge_scores(
    dev2_df_zero_shot,
    prompt=prompt_zero_shot,
    filename="dev2_Indiana.csv",
)