# Ground Truth Generator
- For retriever
- For generator

## Setting
 - Auto Reload
 - path for utils

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys, os
module_path = "../../.."
sys.path.append(os.path.abspath(module_path))

## 1. Bedrock Client 생성

In [3]:
import json
import boto3
from pprint import pprint
from termcolor import colored
from utils import bedrock, print_ww
from utils.bedrock import bedrock_info

### ---- ⚠️ Un-comment and edit the below lines as needed for your AWS setup ⚠️ ----
- os.environ["AWS_DEFAULT_REGION"] = "<REGION_NAME>"  # E.g. "us-east-1"
- os.environ["AWS_PROFILE"] = "<YOUR_PROFILE>"
- os.environ["BEDROCK_ASSUME_ROLE"] = "<YOUR_ROLE_ARN>"  # E.g. "arn:aws:..."
- os.environ["BEDROCK_ENDPOINT_URL"] = "<YOUR_ENDPOINT_URL>"  # E.g. "https://..."

In [4]:
boto3_bedrock = bedrock.get_bedrock_client(
    assumed_role=os.environ.get("BEDROCK_ASSUME_ROLE", None),
    endpoint_url=os.environ.get("BEDROCK_ENDPOINT_URL", None),
    region=os.environ.get("AWS_DEFAULT_REGION", None),
)

aws_region = os.environ.get("AWS_DEFAULT_REGION", None)
print (colored("\n== FM lists ==", "green"))
pprint (bedrock_info.get_list_fm_models())

Create new client
  Using region: None
  Using profile: None
boto3 Bedrock client successfully created!
bedrock-runtime(https://bedrock-runtime.us-east-1.amazonaws.com)
[32m
== FM lists ==[0m
{'Claude-Instant-V1': 'anthropic.claude-instant-v1',
 'Claude-V1': 'anthropic.claude-v1',
 'Claude-V2': 'anthropic.claude-v2',
 'Claude-V2-1': 'anthropic.claude-v2:1',
 'Cohere-Embeddings-En': 'cohere.embed-english-v3',
 'Cohere-Embeddings-Multilingual': 'cohere.embed-multilingual-v3',
 'Command': 'cohere.command-text-v14',
 'Command-Light': 'cohere.command-light-text-v14',
 'Jurassic-2-Mid': 'ai21.j2-mid-v1',
 'Jurassic-2-Ultra': 'ai21.j2-ultra-v1',
 'Llama2-13b-Chat': 'meta.llama2-13b-chat-v1',
 'Titan-Embeddings-G1': 'amazon.titan-embed-text-v1',
 'Titan-Text-G1': 'amazon.titan-text-express-v1',
 'Titan-Text-G1-Light': 'amazon.titan-text-lite-v1'}


## 2.LLM 로딩 (Claude-v2.1 for retriever, Jurassic for reasoning)

In [5]:
from langchain.llms.bedrock import Bedrock
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

In [6]:
llm_claude = Bedrock(
    model_id=bedrock_info.get_model_id(model_name="Claude-V2-1"),
    client=boto3_bedrock,
    model_kwargs={
        "max_tokens_to_sample": 512
    },
    streaming=False,
    callbacks=[StreamingStdOutCallbackHandler()]
)

llm_jurassic = Bedrock(
    model_id=bedrock_info.get_model_id(model_name="Jurassic-2-Ultra"),
    client=boto3_bedrock,
    model_kwargs={
        #"max_tokens": 512,
        "maxTokens": 512
    },
    streaming=False,
    callbacks=[StreamingStdOutCallbackHandler()]
)

## 2. OpenSearch 정의
### 선수 조건
- 01_preprocess_docs/02_load_docs_opensearch.ipynb를 통해서 OpenSearch Index 가 생성이 되어 있어야 합니다.
#### [중요] 아래에 aws parameter store 에 아래 인증정보가 먼저 입력되어 있어야 합니다.
- 01_preprocess_docs/01_parameter_store_example.ipynb 참고

In [7]:
from utils.proc_docs import get_parameter

In [8]:
aws_region = "us-east-1"
ssm = boto3.client("ssm", aws_region)

opensearch_domain_endpoint = get_parameter(
    boto3_client = ssm,
    parameter_name = 'knox_opensearch_domain_endpoint',
)

opensearch_user_id = get_parameter(
    boto3_client = ssm,
    parameter_name = 'knox_opensearch_userid',
)

opensearch_user_password = get_parameter(
    boto3_client = ssm,
    parameter_name = 'knox_opensearch_password',
)
http_auth = (opensearch_user_id, opensearch_user_password) # Master username, Master password

### Index 이름 셋팅
- 이전 노트북 01_preprocess_docs/02_load_docs_opensearch.ipynb를 통해서 생성된 OpenSearch Index name 입력

In [20]:
index_name = "v17-genai-poc-knox-kor-eval-parent-doc-retriever"
#index_name = "v16-genai-poc-knox-eval-parent-doc-retriever"

### OpenSearch Client 생성

In [21]:
from utils.opensearch import opensearch_utils

In [22]:
os_client = opensearch_utils.create_aws_opensearch_client(
    aws_region,
    opensearch_domain_endpoint,
    http_auth
)

## 3. Ground Truth Generator 정의

In [23]:
import pandas as pd
from termcolor import colored
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

In [28]:
retriever_prompt_template = """
\n\nHuman: Here is the context information, inside <context></context> XML tags.

<context>
{context}
</context>

Given the context information and not prior knowledge.
generate only questions based on the below query.

You are a Professor. Your task is to setup \
{num_questions_per_chunk} questions for an upcoming \
quiz/examination. The questions should be diverse in nature \
across the document. The questions should not contain options, start with "-"
Restrict the questions to the context information provided.
Write in Korean.

\n\nAssistant:"""

PROMPT_RETRIEVER = PromptTemplate(
    template=retriever_prompt_template, input_variables=["context", "num_questions_per_chunk"]
)

In [29]:
generation_prompt_template = """
Here is the context, inside <context></context> XML tags.

<context>
{context}
</context>

Only using the context as above, answer the following question with the rules as below:
    - Don't insert XML tag such as <context> and </context> when answering.
    - Write as much as you can
    - Be courteous and polite
    - Only answer the question if you can find the answer in the context with certainty.
    - Skip the preamble
    - Use three sentences maximum and keep the answer concise.
    - If the answer is not in the context, just say "Could not find answer in given contexts."
    - Answer in Korean.

Question:
{question}

Answer:"""

PROMPT_GENERATION = PromptTemplate(
    template=generation_prompt_template, input_variables=["context", "question"]
)

In [36]:
def GTGenerator(os_client, llm_retriever, llm_generation, prompt_retriever, \
                prompt_generation,  docs_per_request, parent_document=False, num_questions_per_chunk=2):

    is_done = False
    offset = 0
    count = 0
    limit = docs_per_request = 5
    fetched_count = 0
    loop_count = 0

    if parent_document:
        query = {
            "query": {
                    "bool" : {
                        "must" : {
                            "match_all": {}
                        },
                        "filter": [
                            {'term': {'metadata.family_tree': 'child'}}
                        ]
                    }
                }
            }
    else:
        query = {"query": {"match_all": {}}}

    llm_chain_retriever = LLMChain(llm=llm_retriever, prompt=prompt_retriever)
    llm_chain_generation = LLMChain(llm=llm_generation, prompt=prompt_generation)
    gt = [] # [question, 정답 id, 정답 text]

    while not is_done:
        try:
            fetched_count += 1
            fetched_docs = os_client.search(
                index=index_name,
                body=query,
                size=limit,
                from_=offset
            )
            fetched_count = 0
        except Exception as e:
            if fetched_count == 3:
                print("Terminating script as connection is timeout more than 3 times.")
                break
            print ("{} Couldn't get records trying again for limit:{} and offset:{}".format(e, limit, offset))
            continue

        fetched_docs = fetched_docs["hits"]["hits"]
        loop_count += 1

        for index, doc in enumerate(fetched_docs):
            # Process the doc here.
            doc_id = doc["_id"]
            doc_text = doc["_source"]["text"]
            #print (colored(f'DOC ID: {doc_id}', "green"))
            #print (colored(f'TEXT: {doc_text}', "blue"))

            questions = llm_chain_retriever.predict(context=doc_text, num_questions_per_chunk=str(num_questions_per_chunk))
            #print (questions)
            questions = questions.split("\n\n-")
            if len(questions) <= num_questions_per_chunk + 1:

                if len(questions) == num_questions_per_chunk:
                    questions = list(map(lambda x:x.strip(), questions))
                else:
                    questions = list(map(lambda x:x.strip(), questions[1:]))
                for q in questions:
                    answer = llm_chain_generation.predict(question=q, context=doc_text)
                    answer = answer.strip()
                    #answer = answer[1:-1].strip()
                    #print (colored(f'question: {q}', "green"))
                    #print (colored(f'answer: {answer}', "blue"))
                    gt.append([q, answer, doc_id, doc_text])
            else:
                print ("err")
                print (questions)

            #print ("==")
        #break
        #if loop_count == 10: break
        offset += docs_per_request
        if len(fetched_docs) < docs_per_request:
            print("This is last batch.")
            is_done = True

        print("batch {} completed".format(count))
    return gt

In [37]:
%%time
gt = GTGenerator(
    os_client=os_client,
    llm_retriever=llm_claude,
    llm_generation=llm_claude, #llm_jurassic,
    prompt_retriever=PROMPT_RETRIEVER,
    prompt_generation=PROMPT_GENERATION,
    docs_per_request=5,
    parent_document=True,
    num_questions_per_chunk=1
)

eval_dataset_retriever = pd.DataFrame(gt, columns=["question", "answer", "doc_id", "doc"])
eval_dataset_retriever.to_csv("eval_dataset_v17.csv", index=False)
#eval_dataset_retriever.to_pickle("eval_dataset.pkl")

batch 0 completed
batch 0 completed
batch 0 completed
batch 0 completed
batch 0 completed
batch 0 completed
batch 0 completed
batch 0 completed
batch 0 completed
batch 0 completed
batch 0 completed
batch 0 completed
batch 0 completed
batch 0 completed
batch 0 completed
batch 0 completed
batch 0 completed
batch 0 completed
batch 0 completed
batch 0 completed
batch 0 completed
batch 0 completed
batch 0 completed
batch 0 completed
batch 0 completed
batch 0 completed
batch 0 completed
batch 0 completed
batch 0 completed
batch 0 completed
This is last batch.
batch 0 completed
CPU times: user 1.39 s, sys: 35.4 ms, total: 1.42 s
Wall time: 26min 51s
