# Creating a PubMed Chatbot with Llama2

## Overview

In this tutorial we create a PubMed chatbot that answers questions by gathering information from documents we provide through an index. The model we will be using is a pretrained Llama2 model from SageMaker JumpStart.

## Learning Objectives
- Introduce langchain
- Explain the differences between zero-shot, one-shot, and few-shot prompting
- Practice using different document retrievers

## Prerequisites
You need access to SageMaker, SageMaker JumpStart, S3, and Kendra.

## Get Started

### Deploy the Model

Identify which model we want to deploy from JumpStart. In this case we are using the Llama2 chat model with 13 billion parameters (the `13b-f` in the model ID below).

In [None]:
model_id, model_version = "meta-textgeneration-llama-2-13b-f", "2.*"

Deploy the model to a SageMaker endpoint so that we can communicate with it, send inputs, and retrieve outputs.

In [None]:
from sagemaker.jumpstart.model import JumpStartModel

model = JumpStartModel(model_id=model_id, model_version=model_version)
predictor = model.deploy()


Next we will print the endpoint name, which we will need later to run our chatbot.

In [None]:
endpoint_id = predictor.endpoint_name
print(endpoint_id)

### PubMed API vs Kendra Index

Our chatbot responds to prompts based on the documents we supply. This occurs via a **vector index**. A vector index is a data structure composed of vectorized embeddings (generated from our inputs) that enables fast and accurate search and retrieval from a large dataset of objects. We will explore two methods of generating our index: the PubMed API and a Kendra index.

**What is the difference?**

The **PubMed API** is free to use and is available through LangChain's PubMed retriever. It connects your model to more than **35 million citations** for biomedical literature from MEDLINE, life science journals, and online books. A **Kendra index** is an AWS product that gives you more **security and control** over which documents you supply to your model.

We will explore both methods to see which produces the best results!
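
To make the idea of a vector index more concrete, here is a minimal sketch in plain Python. It is only an illustration: the word-count `embed` function, the three sample sentences, and the helper names are assumptions made for this example, and neither the PubMed API nor Kendra is claimed to work this way internally.

```python
import math
from collections import Counter

def embed(text, vocabulary):
    """Toy embedding: count how often each vocabulary word appears in the text."""
    counts = Counter(text.lower().split())
    return [counts[word] for word in vocabulary]

def cosine_similarity(a, b):
    """Cosine of the angle between two vectors; 0.0 if either vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)

def build_index(documents):
    """Return (vocabulary, list of (document, vector)) for a list of strings."""
    vocabulary = sorted({word for doc in documents for word in doc.lower().split()})
    return vocabulary, [(doc, embed(doc, vocabulary)) for doc in documents]

def search(query, vocabulary, index, top_k=1):
    """Rank indexed documents by similarity to the query and return the best top_k."""
    query_vector = embed(query, vocabulary)
    ranked = sorted(
        index,
        key=lambda item: cosine_similarity(query_vector, item[1]),
        reverse=True,
    )
    return [doc for doc, _ in ranked[:top_k]]

# Made-up sample documents standing in for indexed articles
documents = [
    "glioma is a tumor of the brain",
    "insulin regulates blood sugar",
    "the heart pumps blood through the body",
]
vocabulary, index = build_index(documents)
print(search("brain tumor", vocabulary, index))
```

Real systems replace the word counts with learned embeddings and use approximate nearest-neighbor search, but the retrieve-by-similarity step is the same in spirit.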

#### Setting up a Kendra Index

If you choose to use a Kendra index to supply documents to your model, follow the instructions below:

AWS Marketplace provides a PubMed database named **PubMed Central® (PMC)**, a free full-text archive of biomedical and life sciences journal articles at the U.S. National Institutes of Health's National Library of Medicine (NIH/NLM). We will subset this database to add documents to our Kendra index. Ensure that you have the correct roles and policies to allow your environment to connect to S3 buckets, SageMaker, and Kendra.

The first step will be to create a bucket that we will later use as our data source for our index.

In [None]:
#make bucket
bucket = 'pubmed-chat-docs'
! aws s3 mb s3://{bucket}

Next we will download the metadata file from the PMC bucket. This metadata file will list all of the articles within the PMC bucket and their paths. We will use these data to subset the database into our own bucket.

In [None]:
#download the metadata file
! aws s3 cp s3://pmc-oa-opendata/oa_comm/txt/metadata/csv/oa_comm.filelist.csv . --sse

We only want the metadata of the first 50 files to keep this tutorial short.

In [None]:
#import the file as a dataframe
import pandas as pd
import os

df = pd.read_csv('oa_comm.filelist.csv')

#first 50 files
first_50=df[0:50]
#save new metadata
first_50.to_csv('oa_comm.filelist_50.csv', index=False)

Let's look at our metadata! We can see that the bucket path to each file is under the **Key** column. This column is what we will use to loop through the PMC bucket and copy the first 50 files to our bucket.

In [None]:
first_50

In [None]:
import os
#gather path to files in bucket
for i in first_50['Key']:
    os.system(f'aws s3 cp s3://pmc-oa-opendata/{i} s3://{bucket}/docs/ --sse')
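
Before copying anything, it can help to do a dry run that only prints the commands. The sketch below uses the standard-library `csv` module on a tiny in-memory file list. The sample keys are invented for illustration, and `first_keys` and `build_copy_commands` are hypothetical helper names, not part of the tutorial's code.

```python
import csv
import io

# Hypothetical stand-in for oa_comm.filelist.csv; these keys are invented
SAMPLE_FILELIST = """Key,PMID
oa_comm/txt/all/PMC0000001.txt,1
oa_comm/txt/all/PMC0000002.txt,2
oa_comm/txt/all/PMC0000003.txt,3
"""

def first_keys(csv_text, limit):
    """Return the 'Key' column of the first `limit` rows."""
    reader = csv.DictReader(io.StringIO(csv_text))
    return [row["Key"] for _, row in zip(range(limit), reader)]

def build_copy_commands(keys, source_bucket, target_bucket, prefix="docs/"):
    """Build one `aws s3 cp` argument list per key without running anything."""
    return [
        ["aws", "s3", "cp",
         f"s3://{source_bucket}/{key}",
         f"s3://{target_bucket}/{prefix}",
         "--sse"]
        for key in keys
    ]

keys = first_keys(SAMPLE_FILELIST, limit=2)
commands = build_copy_commands(keys, "pmc-oa-opendata", "pubmed-chat-docs")
for command in commands:
    print(" ".join(command))
```

Once the printed commands look right, each argument list could be passed to `subprocess.run(command, check=True)`, which raises an error if a copy fails instead of silently continuing.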

We will also save our new metadata file to our bucket to help Kendra index our files.

In [None]:
! aws s3 cp oa_comm.filelist_50.csv s3://{bucket}/docs/

Now we can create our Kendra index. Use the instructions in [Creating a Kendra index](https://docs.aws.amazon.com/kendra/latest/dg/create-index.html) to create an index via the console or command line. To connect our bucket as a data source, follow the instructions provided [here](https://docs.aws.amazon.com/kendra/latest/dg/data-source-s3.html) to do so via the console or AWS Python SDK.

### Creating an Inference Script

For us to fluidly send input and receive outputs from our chatbot we need to create an [**inference script**](https://docs.aws.amazon.com/sagemaker/latest/dg/deploy-model.html#deploy-model-options) that will format inputs in a way that the chatbot can understand and format outputs in a way we can understand. We will also be supplying instructions to the chatbot through the script.

Our script will utilize **LangChain** tools and packages to enable our model to:
- **Connect to sources of context** (e.g. providing our model with tasks and examples)
- **Rely on reason** (e.g. instruct our model how to answer based on provided context)

**Warning**: The inference script must be run on the terminal via the command `python YOUR_SCRIPT.py`.

**Warning:** The following tools must be installed via your terminal:

`pip install "langchain" "langchain-community" "xmltodict"`

The first part of our script imports all the tools that are required:
- **PubMedRetriever:** Utilizes the LangChain retriever tool to specifically retrieve PubMed documents from the PubMed API.
- **AmazonKendraRetriever:** Utilizes the LangChain retriever tool to specifically retrieve documents stored in your Kendra index.
- **ConversationalRetrievalChain:** Allows the user to hold a conversation with the model by sending inputs to the model and retrieving its outputs.
- **PromptTemplate:** Allows the user to give the model instructions. This is where zero-shot and few-shot prompting are set up.
- **LLMContentHandler:** Handles the content to and from the model by transforming the input into a format that the model can accept and transforming the output from the model into the string that the LLM class expects.
- **SagemakerEndpoint:** Connects to our endpoint in SageMaker and allows all of the tools listed above to connect to our model.


```python
from langchain_community.retrievers import PubMedRetriever
from langchain_community.retrievers import AmazonKendraRetriever
from langchain_community.llms import SagemakerEndpoint
from langchain_community.llms.sagemaker_endpoint import LLMContentHandler
from langchain.chains import ConversationalRetrievalChain
from langchain.prompts import PromptTemplate
import sys
import json
import os
```

Second, we will build a small helper class named `bcolors` that holds terminal color codes. We will use these later to add color to our text conversation with the chatbot.

```python
class bcolors:
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKCYAN = '\033[96m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
```

Next we create a function named `build_chain` that gathers the information needed to connect to our model:
- AWS Region
- Kendra index ID **(only if you are using Kendra instead of the PubMed API)**
- Endpoint name or ID

```python
def build_chain():
  region = os.environ["AWS_REGION"]
  kendra_index_id = os.environ["KENDRA_INDEX_ID"] #only needed if using a Kendra index instead of the PubMed API
  endpoint_name = os.environ["LLAMA_2_ENDPOINT"]
```

Next we will create a class named **ContentHandler** that transforms our inputs and outputs into JSON. For Llama2 to understand and accept our inputs, we need to structure them in a specific manner, which is done by the **transform_input** function:
```
{
    "inputs": [[
        {"role": "user", "content": prompt},
    ]],
    **model_kwargs
}
```
Where `prompt` will be our instructions to our model (what the model is expected to do with our input) and `**model_kwargs` is where we provide our parameters.

Our input is then encoded as **UTF-8**, which converts our string into bytes that can be sent to the endpoint.

The next function in our class is named **transform_output**. This function takes the outputs sent from our model and decodes them from bytes back into a string.

```python
class ContentHandler(LLMContentHandler):
      content_type = "application/json"
      accepts = "application/json"

      def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
          input_str = json.dumps({"inputs": 
                                  [[
                                    {"role": "user", "content": prompt},
                                  ]],
                                  **model_kwargs
                                  })
          
          return input_str.encode('utf-8')
      
      def transform_output(self, output: bytes) -> str:
          response_json = json.loads(output.read().decode("utf-8"))
          
          return response_json[0]['generation']['content']
```
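
The round trip can be tried without an endpoint. The sketch below repeats the two transformations as plain functions and feeds `transform_output` a fake response. The fake response is an assumption, shaped only to match the `[0]['generation']['content']` indexing used above, and `io.BytesIO` stands in for the response body because it also offers a `.read()` method.

```python
import io
import json

def transform_input(prompt, model_kwargs):
    """Serialize a prompt and parameters the same way the handler above does."""
    payload = {"inputs": [[{"role": "user", "content": prompt}]], **model_kwargs}
    return json.dumps(payload).encode("utf-8")

def transform_output(output):
    """Read a file-like response body and pull out the generated text."""
    response_json = json.loads(output.read().decode("utf-8"))
    return response_json[0]["generation"]["content"]

request_body = transform_input("What is a cell?", {"parameters": {"max_new_tokens": 64}})
print(request_body.decode("utf-8"))

# Hypothetical response body; a real endpoint response may carry more fields
fake_response = io.BytesIO(
    json.dumps(
        [{"generation": {"role": "assistant", "content": "A cell is ..."}}]
    ).encode("utf-8")
)
print(transform_output(fake_response))
```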

Now that we have a class that handles our inputs and outputs in a format that our model can understand, we use the **SagemakerEndpoint** tool to connect to the endpoint that we made in SageMaker.

Notice that we set our class as the **content_handler** and specified **model_kwargs** as our parameters to control the temperature, top p, and maximum number of new tokens the model generates in its response.


```python
  content_handler = ContentHandler()

  llm = SagemakerEndpoint(
      endpoint_name=endpoint_name,
      region_name=region,
      model_kwargs={"parameters": {"max_new_tokens": 1000, "top_p": 0.9, "temperature": 0.6}},
      endpoint_kwargs={"CustomAttributes": "accept_eula=true"},
      content_handler=content_handler,
  )
```
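
To give a feel for what `temperature` and `top_p` control, here is a toy sketch of temperature scaling followed by top-p (nucleus) filtering. It follows the textbook description of these two ideas. The logits are made up, the function names are hypothetical, and the endpoint's actual sampling code may differ in its details.

```python
import math

def apply_temperature(logits, temperature):
    """Divide logits by the temperature and convert them to probabilities (softmax)."""
    scaled = [value / temperature for value in logits]
    highest = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(value - highest) for value in scaled]
    total = sum(exps)
    return [value / total for value in exps]

def top_p_filter(probabilities, top_p):
    """Keep the most likely tokens until their cumulative probability reaches top_p."""
    ranked = sorted(enumerate(probabilities), key=lambda item: item[1], reverse=True)
    kept, cumulative = [], 0.0
    for token_id, probability in ranked:
        kept.append((token_id, probability))
        cumulative += probability
        if cumulative >= top_p:
            break
    # Renormalize so the surviving probabilities sum to 1
    total = sum(probability for _, probability in kept)
    return {token_id: probability / total for token_id, probability in kept}

logits = [2.0, 1.0, 0.5, -1.0]  # made-up scores for a four-token vocabulary
probabilities = apply_temperature(logits, temperature=0.6)
candidates = top_p_filter(probabilities, top_p=0.9)
print(sorted(candidates))
```

A lower temperature concentrates probability on the top tokens, and a lower `top_p` shrinks the set of tokens the model may pick from, so both push the output toward more predictable text.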

Next we specify our retriever. Both the PubMed and Kendra retrievers are listed below; please keep only one per script.

```python
  retriever = PubMedRetriever() #option 1: PubMed API
  #retriever = AmazonKendraRetriever(index_id=kendra_index_id, region_name=region) #option 2: Kendra index, use instead of option 1
```

Here we construct our **prompt_template**. This is where we can try zero-shot or few-shot prompting. Only add one method per script.

#### Zero-shot prompting

Zero-shot prompting does not require any additional training. Instead, it gives a pretrained language model a task or query and asks it to generate text (our output) without showing it any examples. The model relies on its general language understanding and the patterns it learned during training to produce relevant output. In our script we have connected our model to a **retriever** so that it gathers information from that retriever (this can be the PubMed API or Kendra).

Below, the task reads like a set of instructions: it tells the model that it will be asked questions and should answer them based on the scientific documents supplied by the retriever (the PubMed API or the Kendra index). All of this information is established as a **prompt template** for our model to receive.

```python
  prompt_template = """
  Ignore everything before.
  
  Instructions:
  I will provide you with research papers on a specific topic in English, and you will create a cumulative summary. 
  The summary should be concise and should accurately and objectively communicate the takeaway of the papers related to the topic. 
  You should not include any personal opinions or interpretations in your summary, but rather focus on objectively presenting the information from the papers. 
  Your summary should be written in your own words and ensure that your summary is clear, concise, and accurately reflects the content of the original papers. First, provide a concise summary then citations at the end.
  
  {question} Answer "don't know" if not present in the document. 
  {context}
  Solution:"""
  PROMPT = PromptTemplate(
      template=prompt_template, input_variables=["context", "question"],
  )
```

#### One-shot and Few-shot Prompting

One-shot and few-shot prompting are similar to zero-shot prompting, except that in addition to giving our model a task we also supply one or more examples of how the model should structure its output.

See below that we have implemented one-shot prompting in our script.

```python
  prompt_template = """
  Instructions:
  I will provide you with research papers on a specific topic in English, and you will create a cumulative summary. 
  The summary should be concise and should accurately and objectively communicate the takeaway of the papers related to the topic. 
  You should not include any personal opinions or interpretations in your summary, but rather focus on objectively presenting the information from the papers. 
  Your summary should be written in your own words and ensure that your summary is clear, concise, and accurately reflects the content of the original papers. First, provide a concise summary then citations at the end. 
  Examples:
  Question: What is a cell?
  Answer: '''
  Cell, in biology, the basic membrane-bound unit that contains the fundamental molecules of life and of which all living things are composed. 
  Sources: 
  Chow, Christopher , Laskey, Ronald A. , Cooper, John A. , Alberts, Bruce M. , Staehelin, L. Andrew , 
  Stein, Wilfred D. , Bernfield, Merton R. , Lodish, Harvey F. , Cuffe, Michael and Slack, Jonathan M.W.. 
  "cell". Encyclopedia Britannica, 26 Sep. 2023, https://www.britannica.com/science/cell-biology. Accessed 9 November 2023.
  '''
  
  {question} Answer "don't know" if not present in the document. 
  {context}
  

  
  Solution:"""
  PROMPT = PromptTemplate(
      template=prompt_template, input_variables=["context", "question"],
  )
```
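
The difference between the two prompting styles comes down to whether examples are included in the text sent to the model. The sketch below shows this with plain string handling. `build_prompt`, the shortened instruction text, and the sample question are assumptions for illustration, not the templates used by the script.

```python
INSTRUCTIONS = "Summarize the provided papers objectively and list citations at the end."

def build_prompt(question, context, examples=()):
    """Assemble a prompt; no examples gives zero-shot, one or more gives one/few-shot."""
    parts = ["Instructions:", INSTRUCTIONS]
    if examples:
        parts.append("Examples:")
        for example_question, example_answer in examples:
            parts.append(f"Question: {example_question}")
            parts.append(f"Answer: {example_answer}")
    parts.append(f'{question} Answer "don\'t know" if not present in the document.')
    parts.append(context)
    parts.append("Solution:")
    return "\n".join(parts)

# Zero-shot: task only
zero_shot = build_prompt("What is glioma?", "<retrieved documents>")

# One-shot: task plus a single worked example
one_shot = build_prompt(
    "What is glioma?",
    "<retrieved documents>",
    examples=[("What is a cell?", "The basic membrane-bound unit of life.")],
)
print(one_shot)
```

Passing more than one `(question, answer)` pair turns the same function into a few-shot prompt builder.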

The following set of commands controls the chat history, essentially telling the model to expect another question after it finishes answering the previous one. Follow-up questions can contain references to past chat history, so the **ConversationalRetrievalChain** combines the chat history and the follow-up question into a standalone question, then looks up relevant documents from the retriever, and finally passes those documents and the question to a question-answering chain to return a response.

All of these pieces, such as our conversational chain, prompt, and chat history, are passed through a function called **run_chain** so that our model can return its response. We have also set the length of our chat history to one, meaning that our model can only refer to the previous exchange as a reference.

```python
  condense_qa_template = """
  Chat History:
  {chat_history}
  Here is a new question for you: {question}
  Standalone question:"""
  standalone_question_prompt = PromptTemplate.from_template(condense_qa_template)

  qa = ConversationalRetrievalChain.from_llm(
      llm=llm,
      retriever=retriever,
      condense_question_prompt=standalone_question_prompt,
      return_source_documents=True,
      combine_docs_chain_kwargs={"prompt": PROMPT},
  )
  return qa

def run_chain(chain, prompt: str, history=None):
    #avoid a mutable default argument; start with an empty history when none is given
    history = [] if history is None else history
    print(prompt)
    return chain({"question": prompt, "chat_history": history})

MAX_HISTORY_LENGTH = 1 #increase to refer to more previous chats
```
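
To see what the condensing step hands to the model, the sketch below fills the same template by hand. The `Human:`/`Assistant:` labels and the helper names are assumptions made for this example; LangChain formats the history internally and its exact wording may differ.

```python
CONDENSE_TEMPLATE = """Chat History:
{chat_history}
Here is a new question for you: {question}
Standalone question:"""

def format_history(chat_history):
    """Turn (question, answer) tuples into alternating Human/Assistant lines."""
    lines = []
    for human, assistant in chat_history:
        lines.append(f"Human: {human}")
        lines.append(f"Assistant: {assistant}")
    return "\n".join(lines)

def build_condense_prompt(chat_history, question):
    """Fill the template that asks the model for a standalone question."""
    return CONDENSE_TEMPLATE.format(
        chat_history=format_history(chat_history),
        question=question,
    )

# A follow-up such as "How is it treated?" only makes sense alongside the history
history = [("What is glioma?", "A tumor that starts in glial cells.")]
print(build_condense_prompt(history, "How is it treated?"))
```

The model's reply to this prompt would ideally be a self-contained question such as "How is glioma treated?", which is then used for document retrieval.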

The final part of our script ties everything together and uses our `bcolors` class to add a bit of flair to our conversation with the model. When first started, the script greets the user with **"Hello! How can I help you?"** and then instructs the user to ask a question or exit the session: **"Ask a question, start a New search: or CTRL-D to exit."** A query that begins with `New search:` clears the chat history first. For every question we then run the `run_chain` function to get the model's answer and add that answer to the **chat history**.

```python
if __name__ == "__main__":
  chat_history = []
  qa = build_chain()
  print(bcolors.OKBLUE + "Hello! How can I help you?" + bcolors.ENDC)
  print(bcolors.OKCYAN + "Ask a question, start a New search: or CTRL-D to exit." + bcolors.ENDC)
  print(">", end=" ", flush=True)
  for query in sys.stdin:
    if (query.strip().lower().startswith("new search:")):
      query = query.strip().lower().replace("new search:","")
      chat_history = []
    elif (len(chat_history) == MAX_HISTORY_LENGTH):
      chat_history.pop(0)
    result = run_chain(qa, query, chat_history)
    chat_history.append((query, result["answer"]))
    print(bcolors.OKGREEN + result['answer'] + bcolors.ENDC)
    ###this if statement is not needed for PubMed retriever users
    if 'source_documents' in result:
      print(bcolors.OKGREEN + 'Sources:')
      for d in result['source_documents']:
        #fall back to the full metadata when a document has no 'source' key
        print(d.metadata.get('source', d.metadata))
    ###
    print(bcolors.ENDC)
    print(bcolors.OKCYAN + "Ask a question, start a New search: or CTRL-D to exit." + bcolors.ENDC)
    print(">", end=" ", flush=True)
  print(bcolors.OKBLUE + "Bye" + bcolors.ENDC)
```
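
The history handling in the loop above can also be expressed with a bounded `collections.deque`, which drops the oldest turn automatically. This is an alternative sketch rather than the script's own approach: `new_history` and `parse_query` are hypothetical helpers, a placeholder string replaces the model call, and unlike the loop above it keeps the original casing of the query.

```python
from collections import deque

MAX_HISTORY_LENGTH = 1

def new_history():
    """Create a history that silently drops the oldest turn once it is full."""
    return deque(maxlen=MAX_HISTORY_LENGTH)

def parse_query(raw_line):
    """Return (query, starts_new_search) for one line typed by the user."""
    text = raw_line.strip()
    prefix = "new search:"
    if text.lower().startswith(prefix):
        return text[len(prefix):].strip(), True
    return text, False

history = new_history()
for line in ["What is glioma?\n", "How is it treated?\n", "New search: What is insulin?\n"]:
    query, starts_new_search = parse_query(line)
    if starts_new_search:
        history.clear()
    # Placeholder answer instead of a real call to run_chain
    history.append((query, f"<answer to: {query}>"))
    print(list(history))
```

When calling the chain, the deque would be converted with `list(history)` so that the chain receives an ordinary list of `(question, answer)` tuples.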

Running our script in the terminal requires us to export the following environment variables first. Don't forget to run your Python script in the terminal using the command `python NAME_OF_YOUR_SCRIPT.py`. For more guidance, take a look at our **example inference scripts** for the [PubMed API](/example_scripts/langchain_chat_llama_2_zeroshot.py) and [Kendra](/example_scripts/kendra_chat_llama_2.py).

In [None]:
#retrieve our endpoint id
endpoint_id

In [None]:
#enter these environment variables in your terminal
export AWS_REGION='<Enter_location>'
export LLAMA_2_ENDPOINT='<Enter_endpoint_id>'
export KENDRA_INDEX_ID='<Enter_kendra_index>'
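
A missing variable otherwise only shows up as a `KeyError` once the script is running, so it can be handy to check them up front. The sketch below is one possible way to do that. `read_settings` is a hypothetical helper, and the values in `example_env` are placeholders rather than real resource names.

```python
import os

REQUIRED = ("AWS_REGION", "LLAMA_2_ENDPOINT")

def read_settings(env=None, use_kendra=False):
    """Collect the variables the script needs, or raise one error naming all that are missing."""
    env = os.environ if env is None else env
    names = REQUIRED + (("KENDRA_INDEX_ID",) if use_kendra else ())
    missing = [name for name in names if not env.get(name)]
    if missing:
        raise RuntimeError("Missing environment variables: " + ", ".join(missing))
    return {name: env[name] for name in names}

# A plain dict stands in for os.environ here; the values are placeholders
example_env = {"AWS_REGION": "us-east-1", "LLAMA_2_ENDPOINT": "my-endpoint"}
print(read_settings(example_env))
```

Calling `read_settings(use_kendra=True)` with no `env` argument would read the real environment and also require `KENDRA_INDEX_ID`.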

You should see similar results on the terminal. In this example we ask the chatbot to explain brain cancer!

![PubMed Chatbot Results](../../docs/images/PubMed_chatbot_results.png)

## Conclusions
Here you learned how to deploy a model, create a vector database (index) from PubMed documents, and then interact with the model through an inference script to produce answers.

### Clean Up

**Warning:** Don't forget to delete the resources we just made to avoid accruing additional costs!

In [None]:
#Delete model and endpoint
predictor.delete_model()
predictor.delete_endpoint()

In [None]:
#Delete bucket
! aws s3 rb s3://{bucket} --force