In [None]:
# Shell escape: prints the version of the `python` executable found on the shell's PATH.
!python --version

In [None]:
import sagemaker, boto3, json
from sagemaker.session import Session

# A single SageMaker session is shared by the whole notebook.
sagemaker_session = Session()
# ARN of the identity the notebook is running as, and the region configured for boto3.
aws_role = sagemaker_session.get_caller_identity_arn()
aws_region = boto3.Session().region_name
# `sess` used to be a second, separately constructed `sagemaker.Session()`.
# It is kept as an alias of the same object so both names stay usable
# without building the session twice.
sess = sagemaker_session

In [None]:
# Building blocks for the printed output: a line break and the ANSI escape
# sequences that switch bold text on and off.
newline = "\n"
bold = "\033[1m"
unbold = "\033[0m"


def query_endpoint(encoded_text, endpoint_name, content_type="application/x-text"):
    """Send an already-encoded request body to a SageMaker endpoint.

    Args:
        encoded_text: Request body, passed to ``invoke_endpoint`` as ``Body``.
        endpoint_name: Name of the endpoint to invoke.
        content_type: Value sent as ``ContentType``. Defaults to
            ``"application/x-text"``, which is what this function always sent
            before the parameter existed.

    Returns:
        The response returned by ``invoke_endpoint``, unchanged.
    """
    # "sagemaker-runtime" is the documented boto3 service name for the
    # runtime API (the older "runtime.sagemaker" spelling is a legacy alias).
    client = boto3.client("sagemaker-runtime")
    return client.invoke_endpoint(
        EndpointName=endpoint_name, ContentType=content_type, Body=encoded_text
    )


def parse_response(query_response):
    """Return the ``"generated_text"`` entry of an endpoint response.

    The ``Body`` of ``query_response`` is read in full, decoded as JSON, and
    the value stored under ``"generated_text"`` is returned.
    """
    raw_body = query_response["Body"].read()
    return json.loads(raw_body)["generated_text"]

In [None]:
# The request body must be JSON: the prompt text plus the generation settings.
payload = dict(
    text_inputs="Tell me the steps to make a pizza",
    max_length=100,
    num_return_sequences=3,
    top_k=50,
    top_p=0.95,
    do_sample=True,
)

# Name of the endpoint that the requests in this notebook are sent to.
endpoint_name = "example-huggingface-text2text-flan-t5-l-2023-07-11-02-51-19-194"

def query_endpoint_with_json_payload(encoded_json, endpoint_name):
    """Send an encoded JSON request body to a SageMaker endpoint.

    Args:
        encoded_json: Request body, passed to ``invoke_endpoint`` as ``Body``
            with ``ContentType="application/json"``.
        endpoint_name: Name of the endpoint to invoke.

    Returns:
        The response returned by ``invoke_endpoint``, unchanged.
    """
    # "sagemaker-runtime" is the documented boto3 service name for the
    # runtime API (the older "runtime.sagemaker" spelling is a legacy alias).
    client = boto3.client("sagemaker-runtime")
    return client.invoke_endpoint(
        EndpointName=endpoint_name, ContentType="application/json", Body=encoded_json
    )


# Serialise the payload to UTF-8 encoded JSON and send it to the endpoint.
query_response = query_endpoint_with_json_payload(
    encoded_json=bytes(json.dumps(payload), "utf-8"),
    endpoint_name=endpoint_name,
)


def parse_response_multiple_texts(query_response):
    """Return the ``"generated_texts"`` entry of an endpoint response.

    The ``Body`` of ``query_response`` is read in full, decoded as JSON, and
    the value stored under ``"generated_texts"`` is returned.
    """
    raw_body = query_response["Body"].read()
    return json.loads(raw_body)["generated_texts"]


# Decode the endpoint response and show the generated texts.
generated_texts = parse_response_multiple_texts(query_response=query_response)
print(generated_texts)

In [None]:
# Placeholder: replace the text inside the triple quotes with the article the
# questions should be answered from.
context = """Paste the context text here. \n"""


In [None]:
# Prompt templates to compare. "{context}" and "{question}" are placeholders
# that are filled in before each request. Commented-out entries are
# alternative templates that are currently disabled.
prompts = [
    # "Answer based on context:\n\n{context}\n\n{question}",
    # "{context}\n\nAnswer this question based on the article: {question}",
    # "{context}\n\n{question}",
    # "{context}\nAnswer this question: {question}",
    "Read this article and answer this question {context}\n{question}",
    "{context}\n\nBased on the above article, answer a question. {question}",
    # "Write an article that answers the following question: {question} {context}",
]

# Alternative questions; uncomment one (and comment out the active one) to try it.
# question = "what are the key features of this product?"
# question = "who owns this product?"
# question = "can I install the product my self?"
# question = "do I need an electrician to install the product?"
question = "what is DIUS?"

# Generation settings that are merged into every request payload.
parameters = dict(
    max_length=200,
    num_return_sequences=1,
    top_k=50,
    top_p=0.95,
    do_sample=True,
)


# Ask the same question once per prompt template and print what the endpoint returns.
for each_prompt in prompts:
    # str.format fills both placeholders in a single pass, so the inserted
    # context/question text is taken verbatim. Chained str.replace calls would
    # re-scan the already inserted context and substitute any literal
    # "{question}" it happens to contain.
    # NOTE: with str.format, a template that needs literal braces must write
    # them as "{{" and "}}".
    input_text = each_prompt.format(context=context, question=question)
    print(f"{bold} For prompt{unbold}: '{each_prompt}'{newline}")
    # Request body: the filled-in prompt plus the shared generation settings.
    payload = {"text_inputs": input_text, **parameters}
    query_response = query_endpoint_with_json_payload(
        json.dumps(payload).encode("utf-8"), endpoint_name=endpoint_name
    )
    generated_texts = parse_response_multiple_texts(query_response)
    print(f"{bold} The reasoning result is{unbold}: '{generated_texts}'{newline}")