# Detect Attributions with Amazon Bedrock
Goodbye Hallucinations, hello Attributions.  When an LLM generates output, there's always a risk that it hallucinates, or makes up facts that are not true.  This causes users to lose trust in LLM based solutions.  While there are a number of methods that attempt to detect these hallucinations, this notebook is targeted towards the opposite; detecting attributions.  For every fact or claim in an LLM's output, we will try to detect exactly where that fact came from.  In this way we both give a user confidence that any fact is grounded in truth, and detect by omission any hallucinations, or claims that are not supported by facts.  This system is intended to work in the context of RAG, where we assume every fact needed to create the output is present in the input, and we want the LLM to base its response only on facts present in the input.
  
This notebook has three main parts:
  1) Set up the environment.  We import libraries and create basic building blocks for part two
  2) Build attribution functionality.  Here we build the capability to analyze LLM output, and list attributions.
  3) Testing and examples.  Test the method on real data.
  
NOTE:  This notebook by default will load a cache of calls to Claude into memeory so that a duplicated request will instantly return the previous result, rather than asking Claude again.  This is helpful when testing and demoing, but should be turned off if you would like to generate new responses of the same request.

## 1) Set up the environment
These are basic functions and libraries that you might find in any application that uses Bedrock.  Note this must be run from a machine or account role that has permission to access Bedrock.

First, let's install some dependances:

Note: We install the Anthropic SDK to do local token counting, but do not actually use this in any way to call out to Anthropic.

In [3]:
#!pip install anthropic

In [4]:
from anthropic import Anthropic
client = Anthropic()
def count_tokens(text):
    return client.count_tokens(text)
#count_tokens('Hello world what is up?!')

We'll install the HuggingFace datasets library, for use in supplying sample data for testing.

In [21]:
#!pip install datasets

Collecting datasets
  Downloading datasets-2.16.1-py3-none-any.whl.metadata (20 kB)
Collecting pyarrow-hotfix (from datasets)
  Downloading pyarrow_hotfix-0.6-py3-none-any.whl.metadata (3.6 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting fsspec<=2023.10.0,>=2023.1.0 (from fsspec[http]<=2023.10.0,>=2023.1.0->datasets)
  Downloading fsspec-2023.10.0-py3-none-any.whl.metadata (6.8 kB)
Collecting huggingface-hub>=0.19.4 (from datasets)
  Downloading huggingface_hub-0.20.3-py3-none-any.whl.metadata (12 kB)
Downloading datasets-2.16.1-py3-none-any.whl (507 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m507.1/507.1 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[?25hDownloading fsspec-2023.10.0-py3-none-any.whl (166 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m166.4/166.4 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[?25hDownload

In [36]:
#grab the sentance tokenizer, for use in assigning each sentance a number.  (may also need to pip install nltk)
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [5]:
import pickle, os, re, json
#we'll use time to track how long Bedrock takes to respond, which helps to estimate how long a job will take.
import time

#for ask_claude_threaded, import our threading libraries.
from queue import Queue
from threading import Thread

In [6]:
#set CACHE_RESPONCES to true to save all responses from Claude for reuse later.
#when true, and request to Claude that has been made before will be served from this cache, rather then sending a request to Bedrock.
CACHE_RESPONCES = True
if CACHE_RESPONCES: print ("WARNING: Claude Cache is enabled.  Responses may be stale, and this should be turned off in helper_functions during production use to prevent memory overflow.")



Next, let's set up the connection to Bedrock:
If needed, install at least 1.28.57 of Boto3 so that Bedrock is included.

In [8]:
#!pip install update boto3==1.28.57

In [9]:
#for connecting with Bedrock, use Boto3
import boto3
from botocore.config import Config

#increase the standard time out limits in boto3, because Bedrock may take a while to respond to large requests.
my_config = Config(
    connect_timeout=60*3,
    read_timeout=60*3,
)

In [10]:
bedrock = boto3.client(service_name='bedrock-runtime',config=my_config)
bedrock_service = boto3.client(service_name='bedrock',config=my_config)

In [12]:
#check that it's working:
models = bedrock_service.list_foundation_models()
if "anthropic.claude-v2" in str(models):
    pass#print("Claud-v2 found!")
else:
    print ("Error, no model found.")
max_token_count = 100000 #property of Claude 2

In [14]:
#Save our cache of calls to Claude
#this speeds things up when testing, because we're often making the same calls to Claude over and over.
claude_cache_pickel = "claude_cache.pkl"
    
def save_calls(claude_cache):
    with open(claude_cache_pickel, 'wb') as file:
        pickle.dump(claude_cache,file)
#load our cached calls to Claude
def load_calls():
    with open(claude_cache_pickel, 'rb') as file:
        return pickle.load(file)
def clear_cache():
    claude_cache = {}
    save_calls()
#a cache of recent requests, to speed up itteration while testing
claude_cache = {}

if not os.path.exists(claude_cache_pickel):
    print ("Creating new, empty cache of Claude calls.")
    save_calls(claude_cache)

if CACHE_RESPONCES:
    claude_cache = load_calls()

Creating new, empty cache of Claude calls.


In [15]:
MAX_ATTEMPTS = 5 #how many times to retry if Claude is not working.
def ask_claude(prompt_text, DEBUG=False):
    '''
    Send a prompt to Bedrock, and return the response.  Debug is used to see exactly what is being sent to and from Bedrock.
    '''
    #usually, the prompt will have "human" and "assistant" tags already.  These are required, so if they are not there, add them in.
    if not "Assistant:" in prompt_text:
        prompt_text = "\n\nHuman:"+prompt_text+"\n\Assistant: "
        
    promt_json = {
        "prompt": prompt_text,
        "max_tokens_to_sample": 3000,
        "temperature": 0.7,
        "top_k": 250,
        "top_p": 0.7,
        "stop_sequences": ["\n\nHuman:"]
    }
    body = json.dumps(promt_json)
    
    #returned cashed results, if any
    if body in claude_cache:
        return claude_cache[body]
    
    if DEBUG: print("sending:",prompt_text)
    modelId = 'anthropic.claude-v2'
    accept = 'application/json'
    contentType = 'application/json'
    
    start_time = time.time()
    attempt = 1
    while True:
        try:
            query_start_time = time.time()
            response = bedrock.invoke_model(body=body, modelId=modelId, accept=accept, contentType=contentType)
            response_body = json.loads(response.get('body').read())

            raw_results = response_body.get("completion").strip()

            #strip out HTML tags that Claude sometimes adds, such as <text>
            #results = re.sub('<[^<]+?>', '', raw_results)
            results = raw_results
            
            request_time = round(time.time()-start_time,2)
            if DEBUG:
                print("Recieved:",results)
                print("request time (sec):",request_time)
            #total_tokens = count_tokens(prompt_text+raw_results)
            #output_tokens = count_tokens(raw_results)
            #tokens_per_sec = round(total_tokens/request_time,2)
            break
        except Exception as e:
            print("Error with calling Bedrock: "+str(e))
            attempt+=1
            if attempt>MAX_ATTEMPTS:
                print("Max attempts reached!")
                results = str(e)
                request_time = -1
                #total_tokens = -1
                #output_tokens = -1
                #tokens_per_sec = -1
                break
            else:#retry in 10 seconds
                time.sleep(10)
    #store in cache only if it was not an error:
    if request_time>0:
        claude_cache[body] = (prompt_text,results,request_time)
    
    return(prompt_text,results,request_time)

In the next cell, we add queue handleing.  This allows us to make multiple requests to Bedrock at the same time.

In [16]:
from queue import Queue
from threading import Thread

# Threaded function for queue processing.
def thread_request(q, result):
    while not q.empty():
        work = q.get()                      #fetch new work from the Queue
        thread_start_time = time.time()
        try:
            data = ask_claude(work[1])
            result[work[0]] = data          #Store data back at correct index
        except Exception as e:
            error_time = time.time()
            print('Error with prompt!',str(e))
            result[work[0]] = (work[1],str(e),round(error_time-thread_start_time,2))
        #signal to the queue that task has been processed
        q.task_done()
    return True

def ask_claude_threaded(prompts,DEBUG=False):
    '''
    Call ask_claude, but multi-threaded.
    Returns a dict of the prompts and responces.
    '''
    q = Queue(maxsize=0)
    num_theads = min(50, len(prompts))
    
    #Populating Queue with tasks
    results = [{} for x in prompts];
    #load up the queue with the promts to fetch and the index for each job (as a tuple):
    for i in range(len(prompts)):
        #need the index and the url in each queue item.
        q.put((i,prompts[i]))
        
    #Starting worker threads on queue processing
    for i in range(num_theads):
        #print('Starting thread ', i)
        worker = Thread(target=thread_request, args=(q,results))
        worker.setDaemon(True)    #setting threads as "daemon" allows main program to 
                                  #exit eventually even if these dont finish 
                                  #correctly.
        worker.start()

    #now we wait until the queue has been processed
    q.join()

    if DEBUG:print('All tasks completed.')
    return results

In [17]:
#test if a singe Claude call is working
#print("Testing!  Is Claude working? "+ask_claude("Is Claude working?")[1])

Testing!  Is Claude working? I'm afraid I don't have enough information to know if someone named Claude is working or not. As an AI assistant without personal knowledge of Claude, I can't make assumptions about what specific people are currently doing.


In [18]:
#test if our threaded Claude calls are working
#print(ask_claude_threaded(["Please say the number one.","Please say the number two."]))

  worker.setDaemon(True)    #setting threads as "daemon" allows main program to


[('\n\nHuman:Please say the number one.\n\\Assistant: ', 'One.', 1.84), ('\n\nHuman:Please say the number two.\n\\Assistant: ', 'Two.', 2.89)]


## 2) Build attribution functionality
These are the functions specific to detecting attribution.

In [41]:
def add_line_numbers(text):
    '''
    This function takes a text, and adds XML style line numbers to each sentance.
    This allows the LLM to have a way to reference which lines support attribution.
    
    The text is passed in as a plain text string, and is returned as a string with tags added.
    '''
    lines = nltk.sent_tokenize(text)
    
    labeled_text = ""
    for line_number,line in enumerate(lines):
        labeled_text += "<line number=%s>%s</line>"%(line_number+1,line)
        
    return labeled_text

In [73]:
def add_line_numbers_dict(text):
    '''
    This function takes a text, and adds XML style line numbers to each sentance.
    This allows the LLM to have a way to reference which lines support attribution.
    
    The text is passed in as a plain text string, but the results are returned as a dict.
    This is useful for printing out the results of attribution.
    '''
    lines = nltk.sent_tokenize(text)
    temp_dict = {}
    for line_number,line in enumerate(lines):
        temp_dict[line_number+1]=line
        
    return temp_dict

In [193]:
detect_attribution_template = """\n\nHuman:  You are given a reference source, and a statement based on that source.  Your job is to find every single line in the reference which helps to support the claims made in the statement.
Here is the source and statement:
<source>
{{INPUT}}
</source>
<statement>
{{OUTPUT}}
</statement>
For each line in the statement, please list all the line numbers from the source, if any, which support that line in the statement.
Your response should be in JSON format, with a key for each line number in the statement, and values which are the supporting line numbers from the source.
\nAssistant:  Here is what you asked for:
"""


def get_attribution_prompt(text,summary):
    '''
    create the prompt for detecting attribution
    '''
    
    final_prompt = detect_attribution_template.replace("{{INPUT}}",text).replace("{{OUTPUT}}",summary)
    return final_prompt

In [194]:
detect_fact_score_template = """\n\nHuman: You will be given a statement.  It may contain false content, and therefore is being reviewed in a two step process.
The first step is to identify all lines that make factual claims, and the second step is to find sources that support those claims.
You are a detail oriented expert in charge of the first step of this process.
You job is to look at each line in the statement, and score how likely it is to require proof from trustworthy sources.
For example, the sentence "I have a hat." makes a factual claim, and should score high on the need for support, while the sentence "Here is a summary." is a helper sentence, and therefore does not need external support and should score lower.
Here is the statement:
<statement>
{{TEXT}}
</statement>
For each line in the statement, please score each line on a scale from 0 to 100, where 0 indicates a helper sentence that requires no support and makes no factual claims, and 100 is a factual claim that requires support from trustworthy sources.
Your response should be in JSON format, with a key for each line number in the statement, and values which are the score for that statement.
\nAssistant:  Here is what you asked for:
"""


def get_fact_score_prompt(text):
    '''
    create the prompt for scoring each line's need for factual support.
    '''
    final_prompt = detect_fact_score_template.replace("{{TEXT}}",text)
    return final_prompt

In [140]:
def get_attribution(text,summary):
    labeled_article = add_line_numbers(text)
    labeled_summary = add_line_numbers(summary)

    prompt = get_attribution_prompt(labeled_article,labeled_summary)
    result = ask_claude(prompt)
    return result
def get_fact_score(text):
    labeled_article = add_line_numbers(text)

    prompt = get_fact_score_prompt(labeled_article)
    result = ask_claude(prompt)
    return result

In [167]:
def get_attribution_print_results(text,summary):
    """
    Bundle everything up and print out results, for testing and seeing quick results.
    """
    result = get_attribution(text,summary)[1]
    score_result = get_fact_score(summary)[1]
    
    test_article_lines = add_line_numbers_dict(text)
    test_highlights_lines = add_line_numbers_dict(summary)
    result_json = json.loads(result)
    score_result_json = json.loads(score_result) 
    
    for line in result_json:
        print ("Output line %s: (Facts required score: %s)"%(line,score_result_json[line]))
        print(test_highlights_lines[int(line)])
        print ("Supporting facts:")
        if len(result_json[line])==0:
            hallucination = ""
            if score_result_json[line]>50:
                hallucination = "  Likely hallucination detected!"
            print ("  No supporting facts found."+hallucination)
        else:
            for line_2 in result_json[line]:

                print("  %s: "%line_2+test_article_lines[int(line_2)])
        print ("")

## 3). Testing and Examples

In [23]:
#we'll start by downloading some sample data for testing.
#we use https://huggingface.co/datasets/cnn_dailymail which is a collection of news articles and their highlights.
from datasets import load_dataset
dataset = load_dataset("cnn_dailymail",'1.0.0')

Downloading data:   0%|          | 0.00/256M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/257M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/259M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/34.7M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/30.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/287113 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/13368 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/11490 [00:00<?, ? examples/s]

In [192]:
#let's pick an article for testing, grab one about weather so it's not too depressing...
article_num = 8
test_article = dataset['test'][article_num]['article']
test_article = test_article.replace("Ahead of the storm.","Ahead of the storm,")#typo correction
test_article = test_article.replace("Gov.","Gov")#easier to read
test_highlight = dataset['test'][article_num]['highlights']
fake_test_highlight = test_highlight + "  I love to eat donuts for breakfast."
print (test_article)
print (" ")
print ("Human generated highlights:")
print (fake_test_highlight)

(CNN)Filipinos are being warned to be on guard for flash floods and landslides as tropical storm Maysak approached the Asian island nation Saturday. Just a few days ago, Maysak gained super typhoon status thanks to its sustained 150 mph winds. It has since lost a lot of steam as it has spun west in the Pacific Ocean. It's now classified as a tropical storm, according to the Philippine national weather service, which calls it a different name, Chedeng. It boasts steady winds of more than 70 mph (115 kph) and gusts up to 90 mph as of 5 p.m. (5 a.m. ET) Saturday. Still, that doesn't mean Maysak won't pack a wallop. Authorities took preemptive steps to keep people safe such as barring outdoor activities like swimming, surfing, diving and boating in some locales, as well as a number of precautionary evacuations. Gabriel Llave, a disaster official, told PNA that tourists who arrive Saturday in and around the coastal town of Aurora "will not be accepted by the owners of hotels, resorts, inns 

In [142]:
print(add_line_numbers(test_article))
print ("")
print(add_line_numbers(fake_test_highlight))

<line number=1>(CNN)Filipinos are being warned to be on guard for flash floods and landslides as tropical storm Maysak approached the Asian island nation Saturday.</line><line number=2>Just a few days ago, Maysak gained super typhoon status thanks to its sustained 150 mph winds.</line><line number=3>It has since lost a lot of steam as it has spun west in the Pacific Ocean.</line><line number=4>It's now classified as a tropical storm, according to the Philippine national weather service, which calls it a different name, Chedeng.</line><line number=5>It boasts steady winds of more than 70 mph (115 kph) and gusts up to 90 mph as of 5 p.m. (5 a.m.</line><line number=6>ET) Saturday.</line><line number=7>Still, that doesn't mean Maysak won't pack a wallop.</line><line number=8>Authorities took preemptive steps to keep people safe such as barring outdoor activities like swimming, surfing, diving and boating in some locales, as well as a number of precautionary evacuations.</line><line number=

In [143]:
result = get_attribution(test_article,fake_test_highlight)[1]
print(result)

{
  "1": [2, 5],
  "2": [7, 8, 9, 12, 13],
  "3": []
}


That's not quite human readable, so let's print it out again, with the accociated text.

In [144]:
test_article_lines = add_line_numbers_dict(test_article)
test_highlights_lines = add_line_numbers_dict(fake_test_highlight)

In [145]:
result_json = json.loads(result) 

In [146]:
for line in result_json:
    print ("Output line %s:"%line)
    print(test_highlights_lines[int(line)])
    print ("Supporting facts:")
    if len(result_json[line])==0:
        print ("  No supporting facts found.")
    else:
        for line_2 in result_json[line]:
            print("  %s: "%line_2+test_article_lines[int(line_2)])
    print ("")

Output line 1:
Once a super typhoon, Maysak is now a tropical storm with 70 mph winds .
Supporting facts:
  2: Just a few days ago, Maysak gained super typhoon status thanks to its sustained 150 mph winds.
  5: It boasts steady winds of more than 70 mph (115 kph) and gusts up to 90 mph as of 5 p.m. (5 a.m.

Output line 2:
It could still cause flooding, landslides and other problems in the Philippines .
Supporting facts:
  7: Still, that doesn't mean Maysak won't pack a wallop.
  8: Authorities took preemptive steps to keep people safe such as barring outdoor activities like swimming, surfing, diving and boating in some locales, as well as a number of precautionary evacuations.
  9: Gabriel Llave, a disaster official, told PNA that tourists who arrive Saturday in and around the coastal town of Aurora "will not be accepted by the owners of hotels, resorts, inns and the like ... and will be advised to return to their respective places."
  12: It's expected to make landfall Sunday morning 

Now let's take the next step, and detect hallucinations.  We'll add an additional sentence to the highlights, so that we have one extra sentance that is not a hallucination, "Here are some highlights." and one that is "I love to eat donuts for breakfast."

In [195]:
fake_test_highlight = "Here are some highlights.  " + fake_test_highlight

In [196]:
score_result = get_fact_score(fake_test_highlight)[1]
print(score_result)

{
  "1": 0,
  "2": 100,
  "3": 100,
  "4": 100
}


Finally, we combind these results with the previous attributions to detect hallucinations.

In [197]:
#rerun the previous attribution analysis because we've added an extra line.
result = get_attribution(test_article,fake_test_highlight)[1]
print (result)
test_article_lines = add_line_numbers_dict(test_article)
test_highlights_lines = add_line_numbers_dict(fake_test_highlight)
result_json = json.loads(result)
score_result_json = json.loads(score_result) 

{
  "1": [],
  "2": [2, 5], 
  "3": [7, 8, 9],
  "4": []
}


In [198]:
for line in result_json:
    print ("Output line %s: (Facts required score: %s)"%(line,score_result_json[line]))
    print(test_highlights_lines[int(line)])
    print ("Supporting facts:")
    if len(result_json[line])==0:
        hallucination = ""
        if score_result_json[line]>50:
            hallucination = "  Likely hallucination detected!"
        print ("  No supporting facts found."+hallucination)
    else:
        for line_2 in result_json[line]:
            
            print("  %s: "%line_2+test_article_lines[int(line_2)])
    print ("")

Output line 1: (Facts required score: 0)
Here are some highlights.
Supporting facts:
  No supporting facts found.

Output line 2: (Facts required score: 100)
Once a super typhoon, Maysak is now a tropical storm with 70 mph winds .
Supporting facts:
  2: Just a few days ago, Maysak gained super typhoon status thanks to its sustained 150 mph winds.
  5: It boasts steady winds of more than 70 mph (115 kph) and gusts up to 90 mph as of 5 p.m. (5 a.m.

Output line 3: (Facts required score: 100)
It could still cause flooding, landslides and other problems in the Philippines .
Supporting facts:
  7: Still, that doesn't mean Maysak won't pack a wallop.
  8: Authorities took preemptive steps to keep people safe such as barring outdoor activities like swimming, surfing, diving and boating in some locales, as well as a number of precautionary evacuations.
  9: Gabriel Llave, a disaster official, told PNA that tourists who arrive Saturday in and around the coastal town of Aurora "will not be accep