# Bedrock Knowledge Base Retrieval and Generation for ReVIEW

In [1]:
import json

import sys

from pydantic import BaseModel

sys.path.append("../frontend/")
from components.bedrock_utils import get_bedrock_client


In [2]:
# Notebook-wide settings used by the retrieval and generation cells below.

# Bedrock model ID; passed to generate() as the model to invoke.
FOUNDATION_MODEL = "anthropic.claude-3-sonnet-20240229-v1:0"
# AWS region handed to get_bedrock_client() for both clients.
REGION_NAME = "us-east-1"
# ID of the Bedrock Knowledge Base that retrieve() queries.
KNOWLEDGE_BASE_ID = "VX6KXQMHZ2"
# Number of chunks requested per retrieval (sent as numberOfResults).
NUM_CHUNKS = 5
# Retrieval is always filtered on this username.
USERNAME = "demouser"
# Optional media file name to restrict retrieval to; None means no media filter.
MEDIA_NAME = None
# Question used as the retrieval query text in the first example.
QUERY = "Did they mention Nvidia?"


In [3]:
# Used for retrieval.
# agent=True produces a client for the bedrock-agent-runtime endpoint
# (see the endpoint printed in this cell's output).
bedrock_agent_runtime_client = get_bedrock_client(region=REGION_NAME, agent=True)

# Used for generation.
# agent=False produces a client for the bedrock-runtime endpoint
# (see the endpoint printed in this cell's output).
bedrock_client = get_bedrock_client(region=REGION_NAME, agent=False)


Create new client
  Using region: us-east-1
Found credentials in shared credentials file: ~/.aws/credentials
boto3 Bedrock client successfully created!
bedrock_client._endpoint=bedrock-agent-runtime(https://bedrock-agent-runtime.us-east-1.amazonaws.com)
Create new client
  Using region: us-east-1
Found credentials in shared credentials file: ~/.aws/credentials
boto3 Bedrock client successfully created!
bedrock_client._endpoint=bedrock-runtime(https://bedrock-runtime.us-east-1.amazonaws.com)


In [4]:
def retrieve(
    agent_client, query, username, media_name, num_chunks, knowledge_base_id=None
):
    """Retrieve transcript chunks matching ``query`` from a Bedrock Knowledge Base.

    Args:
        agent_client: Client exposing ``retrieve(knowledgeBaseId=...,
            retrievalConfiguration=..., retrievalQuery=...)``.
        query: Text to search for.
        username: Owner of the documents; results are always restricted to
            chunks whose ``username`` metadata equals this value.
        media_name: Optional media file name. When truthy, results are further
            restricted to chunks whose ``media_name`` metadata equals it; when
            falsy (``None`` or ``""``), only the username filter applies.
        num_chunks: Maximum number of results to request.
        knowledge_base_id: Knowledge Base to query. Defaults to the
            module-level ``KNOWLEDGE_BASE_ID`` when omitted.

    Returns:
        The response returned by ``agent_client.retrieve`` unchanged.
    """
    # Always filter on username to prevent people from querying other users' data
    username_filter = {"equals": {"key": "username", "value": username}}

    # Optionally filter on media name if user wants to chat with just one media file.
    # Branch on the *argument*: checking the notebook-level MEDIA_NAME global here
    # would silently ignore whatever media_name the caller passed in.
    if not media_name:
        retrieval_filter = username_filter
    else:
        retrieval_filter = {
            "andAll": [
                username_filter,
                {"equals": {"key": "media_name", "value": media_name}},
            ]
        }

    retrieval_config = {
        "vectorSearchConfiguration": {
            "numberOfResults": num_chunks,
            "filter": retrieval_filter,
        },
    }

    # Fall back to the notebook-level constant only when no ID was supplied,
    # so existing five-argument callers behave as before.
    if knowledge_base_id is None:
        knowledge_base_id = KNOWLEDGE_BASE_ID

    return agent_client.retrieve(
        knowledgeBaseId=knowledge_base_id,
        retrievalConfiguration=retrieval_config,
        retrievalQuery={"text": query},
    )

In [5]:
# Run one retrieval with the notebook-level settings; `res` holds the raw
# response that the next cell inspects.
res = retrieve(
    agent_client=bedrock_agent_runtime_client,
    query=QUERY,
    username=USERNAME,
    media_name=MEDIA_NAME,
    num_chunks=NUM_CHUNKS,
)

In [6]:
# Summarize the retrieval response: how many chunks came back and what the
# first one looks like (text preview, score, S3 location, media_name metadata).
kazu = res["retrievalResults"]
print(f"{len(kazu)} chunks retrieved.")
# The filter can legitimately match nothing; indexing [0] on an empty list
# would raise IndexError, so only show details when a chunk exists.
if kazu:
    c0 = kazu[0]
    print("First chunk:")
    print(f"  Text: {c0['content']['text'][:50]} ...")
    print(f"  Score: {c0['score']}")
    print(f"  Location: {c0['location']['s3Location']['uri']}")
    print(f"  Custom Meta: {c0['metadata']['media_name']}")

5 chunks retrieved.
First chunk:
  Text: [0] from scoping requirements to defining evaluati ...
  Score: 0.3786281
  Location: s3://kazu-dev-339712833620-assets/transcripts-txt/demouser/a88fb9b4-5253-40db-8f0d-6a34e9f1f00b.txt
  Custom Meta: test-5min-vid.mp4


In [7]:
def build_chunks_string(retrieve_response: dict) -> str:
    """Render every retrieved chunk as a numbered XML-like block.

    Reads ``retrieve_response["retrievalResults"]`` and, for each entry, emits
    its ``metadata.media_name`` and ``content.text`` wrapped in tags. Chunks
    are numbered from 1 in the order they appear, and each block is followed
    by a blank line:

    <chunk_1>
    <media_name>
    foo-bar-vid.mp4
    </media_name>
    <transcript>
    [0] blah blah [12] blah blah blah
    </transcript>
    </chunk_1>

    <chunk_2>
    ...

    Returns an empty string when there are no results.
    """
    return "".join(
        f"<chunk_{i+1}>\n<media_name>\n{chunk['metadata']['media_name']}\n</media_name>\n<transcript>\n{chunk['content']['text']}\n</transcript>\n</chunk_{i+1}>\n\n"
        for i, chunk in enumerate(retrieve_response["retrievalResults"])
    )

In [8]:
def generate(br_client, model_id, query, retrieval_response, **kwargs) -> str:
    """Ask the model to answer ``query`` using only the retrieved transcript chunks.

    Builds a prompt from the chunks in ``retrieval_response`` (via
    ``build_chunks_string``), sends it with ``br_client.invoke_model``, and
    returns the text of the first content item in the decoded response.

    Args:
        br_client: Client exposing ``invoke_model(modelId=..., body=...)``.
        model_id: Model identifier passed as ``modelId``.
        query: The user's question, inserted into the prompt.
        retrieval_response: Retrieval response containing ``retrievalResults``.
        **kwargs: Extra request-body fields merged in last, so they can
            override the defaults set here (e.g. ``temperature``,
            ``max_tokens``).

    Returns:
        The model's reply text; the prompt asks for a JSON string with
        ``answer`` and ``citations`` keys.
    """
    SYSTEM_PROMPT = """You are an intelligent AI which attempts to answer questions based on retrieved chunks of automatically generated transcripts."""

    MESSAGE_TEMPLATE = """
I will provide you with retrieved chunks of transcripts. The user will provide you with a question. Using only information in the provided transcript chunks, you will attempt to answer the user's question.

Each chunk may or may not be relevant to answering the question. Each chunk will include a <media_name> block which contains the parent file that the transcript came from. Each line in the transcript chunk begins with an integer timestamp (in seconds) within square brackets, followed by a transcribed sentence. When answering the question, you will need to provide the timestamp you got the answer from.

Here are the retrieved chunks of transcripts in numbered order:

<transcript_chunks>
{chunks}
</transcript_chunks>

When you answer the question, your answer must be a parsable json string. The json should have two keys. One key, "answer", is your answer to the user's question. The second key, "citations" is a list of dicts which contain a "media_name" key and a "timestamp" key, which correspond to the resources used to answer the question. For example, if you got your answer from only one chunk, then the "citations" list will be only one element long, with the media_name of the chunk from which you got the answer, and the relevant timestamp within that chunk's transcript. If you used information from three chunks, the "citations" list will be three elements long.

If you are unable to answer the question using information provided in any of the chunks, your response should include no citations like this:
{{"answer": "I am unable to answer the question based on the provided media file(s).", "citations": []}}

Here is the user's question:
<question>
{query}
</question>
    """

    chunks_str = build_chunks_string(retrieval_response)
    message_content = MESSAGE_TEMPLATE.format(query=query, chunks=chunks_str)

    body = {
        "system": SYSTEM_PROMPT,
        "messages": [{"role": "user", "content": message_content}],
        # The Bedrock Anthropic Messages API documents this field as required
        # with this exact value; an empty string is outside that contract.
        "anthropic_version": "bedrock-2023-05-31",
        **kwargs,
    }
    raw_response = br_client.invoke_model(modelId=model_id, body=json.dumps(body))
    # The response body is a byte stream containing UTF-8 encoded JSON.
    payload = json.loads(raw_response["body"].read().decode("utf-8"))

    return payload["content"][0]["text"]

In [9]:
from pydantic import BaseModel
from typing import List

class Citation(BaseModel):
    # One source reference backing an answer, matching the "citations" items
    # the generation prompt asks the model to return.

    # Name of the media file the cited transcript chunk came from.
    media_name: str
    # Timestamp within that transcript; the prompt describes transcript
    # timestamps as integers in seconds.
    timestamp: int

class LLMAnswer(BaseModel):
    # Structured form of the model's JSON reply: the answer text plus the
    # citations that support it (an empty list when nothing could be cited).
    answer: str
    citations: List[Citation]

    def pprint(self):
        """Print the answer and its citations to stdout under an ``LLMAnswer:`` header."""
        summary = f"LLMAnswer:\n Answer={self.answer}\n Citations={self.citations}"
        print(summary)

def parse_generation(generation_response: str) -> LLMAnswer:
    """Decode the model's JSON reply and validate it as an ``LLMAnswer``.

    Args:
        generation_response: JSON text with ``answer`` and ``citations`` keys.

    Returns:
        The validated ``LLMAnswer``.

    Raises:
        json.JSONDecodeError: If ``generation_response`` is not valid JSON.
        Errors raised by ``LLMAnswer`` construction propagate unchanged.
    """
    payload = json.loads(generation_response)
    return LLMAnswer(**payload)

In [10]:
#########################
# Full workflow example #
#########################
# End-to-end run: configure, create clients, retrieve chunks, generate an
# answer, then parse it into an LLMAnswer. The settings and clients are
# re-declared here with the same values as the earlier cells.

FOUNDATION_MODEL = "anthropic.claude-3-sonnet-20240229-v1:0"
REGION_NAME = "us-east-1"
KNOWLEDGE_BASE_ID = "VX6KXQMHZ2"
NUM_CHUNKS = 5
USERNAME = "demouser"
# None means retrieval is filtered by username only, not by media file.
MEDIA_NAME = None

query = "What AWS services are mentioned?"

# Used for retrieval
bedrock_agent_runtime_client = get_bedrock_client(region=REGION_NAME, agent=True)
# Used for generation
bedrock_client = get_bedrock_client(region=REGION_NAME, agent=False)

# Step 1: fetch the chunks relevant to the question.
retrieval_result = retrieve(
    agent_client=bedrock_agent_runtime_client,
    query=query,
    username=USERNAME,
    media_name=MEDIA_NAME,
    num_chunks=NUM_CHUNKS,
)

# Step 2: have the model answer from those chunks. temperature and max_tokens
# are forwarded into the request body through generate()'s **kwargs.
# NOTE(review): max_tokens=200 may cut a long JSON reply short, which would
# make parse_generation fail. Confirm the limit is sufficient.
generate_result = generate(
    br_client=bedrock_client,
    model_id=FOUNDATION_MODEL,
    query=query,
    retrieval_response=retrieval_result,
    temperature=0.1,
    max_tokens=200,
)

# Step 3: turn the JSON reply into a validated LLMAnswer.
answer: LLMAnswer = parse_generation(generate_result)

Create new client
  Using region: us-east-1
Found credentials in shared credentials file: ~/.aws/credentials
boto3 Bedrock client successfully created!
bedrock_client._endpoint=bedrock-agent-runtime(https://bedrock-agent-runtime.us-east-1.amazonaws.com)
Create new client
  Using region: us-east-1
Found credentials in shared credentials file: ~/.aws/credentials
boto3 Bedrock client successfully created!
bedrock_client._endpoint=bedrock-runtime(https://bedrock-runtime.us-east-1.amazonaws.com)


In [11]:
# Print the parsed answer and its citations.
answer.pprint()

LLMAnswer:
 Answer=The AWS services mentioned in the provided transcript chunks are Amazon Transcribe, Amazon SageMaker, AWS Trainium, AWS Inferentia, Amazon Bedrock, and Amazon Cognito.
 Citations=[Citation(media_name='test-vid1.mp4', timestamp=31), Citation(media_name='test-5min-vid.mp4', timestamp=63), Citation(media_name='test-5min-vid.mp4', timestamp=79), Citation(media_name='test-vid1.mp4', timestamp=31)]
