In [None]:
import json
from typing import Dict
from langchain_community.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint
from langchain_core.prompts import PromptTemplate
import openpyxl

In [None]:
# Endpoint name and region handed to the SageMaker client further down.
endpoint_name = "###"  # input
region_name = "###"  # input

# AWS credentials profile; it is only passed to the client when non-empty.
profile_name = "###"  # input

# Generation settings sent along with every prompt as the request "parameters".
parameters = dict(max_new_tokens=1024, temperature=0.1, top_k=10)

In [None]:
def load_jsonl(file_path):
    """Read a JSON Lines file and return its records as a list.

    Args:
        file_path: Path to a text file holding one JSON document per line.

    Returns:
        A list with one decoded object per non-blank line, in file order.
        An empty file yields an empty list.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Pin the encoding so the result does not depend on the platform default.
    with open(file_path, 'r', encoding='utf-8') as file:
        # Blank or whitespace-only lines (e.g. a trailing empty line) are
        # skipped; json.loads would otherwise raise on them.
        return [json.loads(line) for line in file if line.strip()]

# Example usage: load the recall validation dataset from a JSONL file.
# The path is relative, so it is resolved against the current working directory.
file_path = 'recall_validation_dataset.jsonl'
test_data = load_jsonl(file_path)

In [None]:
# Show the first loaded record as a quick check of the data's structure.
test_data[0]

In [None]:
def build_prompt(model_name):
    """Return the PromptTemplate laid out for the given model's chat format.

    Passing ``'mistral-inst'`` selects the ``[INST]``-wrapped template. Any
    other value falls back to the zephyr-7b-beta template. Both templates
    take the ``context`` and ``question`` input variables.
    """
    mistral_template = """ <s>[INST]
You are a helpful assistant that provides direct and concise answers based only on the provided information.
Use the following information from the course information to answer the user's question. If the answer is not present in the provided information, your answer must only be 'I do not know the answer'.
Do not refer to the fact that there are provided course documents in your answer, just directly answer the question.
< -- COURSE INFORMATION -- >
{context}
< -- END COURSE INFORMATION -- >
< -- QUESTION -- >
{question}
< -- END QUESTION -- >
Solution:
[/INST]"""
    zephyr_template = """<|system|> You are a helpful assistant that provides direct and concise answers based only on the provided information.</s>
<|user|> Use the following information from the course documents to answer the user's question. If the answer is not present in the provided information, your answer must only be 'I do not know the answer'.
< -- QUESTION -- >
{question}
< -- END QUESTION -- >
< -- COURSE INFORMATION -- >
{context}
< -- END COURSE INFORMATION -- >
</s>
<|assistant|> """
    # Every name other than 'mistral-inst' gets the zephyr-7b-beta layout.
    chosen = mistral_template if model_name == 'mistral-inst' else zephyr_template
    return PromptTemplate(
        input_variables=["context", "question"],
        template=chosen,
    )

In [None]:
class ContentHandler(LLMContentHandler):
    """Translate between prompt strings and the endpoint's JSON wire format."""

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        """Build the UTF-8 encoded JSON request body.

        Args:
            prompt: Fully rendered prompt text, sent under the "inputs" key.
            model_kwargs: Generation settings, sent under the "parameters" key.

        Returns:
            The JSON document encoded as UTF-8 bytes.
        """
        payload = {"inputs": prompt, "parameters": model_kwargs}
        return json.dumps(payload).encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        """Extract the generated text from the endpoint response.

        Args:
            output: The response body. Although annotated as ``bytes``, the
                original code read it as a file-like object (presumably a
                botocore StreamingBody -- confirm). Both file-like objects
                and bytes-like payloads are accepted here.

        Returns:
            The "generated_text" value of the first element of the JSON
            array in the response.

        Raises:
            json.JSONDecodeError: If the body is not valid JSON.
            KeyError, IndexError, TypeError: If the decoded JSON does not
                have the expected ``[{"generated_text": ...}]`` shape.
        """
        # Read from stream-like bodies; use bytes-like payloads directly.
        raw = output.read() if hasattr(output, "read") else output
        response_json = json.loads(raw.decode("utf-8"))
        return response_json[0]["generated_text"]

# One handler instance does the JSON (de)serialisation for the endpoint.
content_handler = ContentHandler()

# Keyword arguments for the SageMaker client, assembled up front so the
# optional credentials profile can be added conditionally.
llm_open_args = dict(
    endpoint_name=endpoint_name,
    region_name=region_name,
    model_kwargs=parameters,
    content_handler=content_handler,
)

# Only pass a credentials profile when one has been configured.
if profile_name != '':
    llm_open_args["credentials_profile_name"] = profile_name

llm_open = SagemakerEndpoint(**llm_open_args)

In [None]:
# Run every validation record through the zephyr-style prompt and store the
# model's responses in a spreadsheet next to the reference answers.
prompt = build_prompt('zephyr-7b-beta')
chain = prompt | llm_open

workbook = openpyxl.Workbook()
sheet = workbook.active

# Header row. The 'score' column is only labelled here; this cell never
# writes any values into it.
sheet.append(['context', 'question', 'answer', 'response', 'score'])

try:
    for i, data in enumerate(test_data):
        if i % 10 == 0:
            print(f'Processing {i}th data')
        context = data['context']
        question = data['question']
        answer = data['answer']
        response = chain.invoke({"context": context, "question": question})
        # Record i lands on row i + 2, directly below the header row.
        sheet.append([context, question, answer, response])

    print('processed all data')
finally:
    # Save even when a call to the endpoint fails part-way through, so the
    # responses collected so far are not lost; the error still propagates.
    workbook.save('example-zeph.xlsx')

In [None]:
# Try a single hand-written example with the mistral-style prompt.
prompt = build_prompt('mistral-inst')
chain = prompt | llm_open

context = "Midterm 1 is worth 15%. Midterm 2 is worth 20%. The final exam is worth 30%. The homework is worth 35%."
question, answer = "Is the final next week?", "No"

sample_inputs = {"context": context, "question": question}
response = chain.invoke(sample_inputs)

# Print the inputs, the reference answer and the model response, one per line.
print(
    f'Context: {context}',
    f'Question: {question}',
    f'Answer: {answer}',
    f'Response: {response}',
    sep='\n',
)