In [131]:
from huggingface_hub import InferenceClient
import numpy as np

In [132]:
def hallucination_check_prompt(task_type, source_info, response):
    if task_type == 'Summary':
        article = source_info
        prompt = 'Below is the original news:\n'
        prompt += f'{article}'+'\n'
        prompt += 'Below is a summary of the news:\n'
        prompt += f'{response}'+'\n'
        prompt += 'Your task is to identify and label any hallucinated statements in the summary that are unsupported or contradicted by the original news. '
    elif task_type == 'Data2txt':
        business_info = source_info
        prompt = 'Below is a structured data in the JSON format:\n'
        prompt += f'{business_info}'+'\n'
        prompt += 'Below is an overview article written in accordance with the structured data:\n'
        prompt += f'{response}'+'\n'
        prompt += 'Your task is to identify and label any hallucinated statements in the overview that are unsupported or contradicted by the structured data. '
    elif task_type == 'QA':
        question, passages = source_info['question'], source_info['passages']
        prompt = 'Below is a question:\n'
        prompt += f'{question}'+'\n\n'
        prompt += 'Below are related passages:\n'
        prompt += f'{passages}'+'\n'
        prompt += 'Below is an answer:\n'
        prompt += f'{response}'+'\n\n'
        prompt += 'Your task is to identify and label any hallucinated statements in the answer that are unsupported or contradicted by the passages. '
    prompt += 'Then, compile the labeled hallucinated spans into a JSON list, with each list item representing a separate hallucinated span.\n'
    prompt += 'Output:'
    return prompt

In [133]:
source_info = {'question': 'What is Trump', 'passages': 'passage1: Trump is overweight.\npassage2: Trump is handsome.'}
response = 'According to three passages, Trump is a dog.'
input_str = hallucination_check_prompt('QA', source_info, response)
input_str

'Below is a question:\nWhat is Trump\n\nBelow are related passages:\npassage1: Trump is overweight.\npassage2: Trump is handsome.\nBelow is an answer:\nAccording to three passages, Trump is a dog.\n\nYour task is to identify and label any hallucinated statements in the answer that are unsupported or contradicted by the passages. Then, compile the labeled hallucinated spans into a JSON list, with each list item representing a separate hallucinated span.\nOutput:'

In [134]:
client = InferenceClient(model='http://localhost:8346', timeout=100)
output = client.text_generation(input_str, 
                                max_new_tokens=512, 
                                stream=False,
                                do_sample=True,
                                temperature=0.1,
                                details=True
                                )
output_details = [{'id': i.id, 'text': i.text, 'logprob': np.exp(i.logprob), 'special': i.special} for i in output.details.tokens]

In [135]:
output

TextGenerationResponse(generated_text=' ["According to three passages, Trump is a dog."]', details=Details(finish_reason=<FinishReason.EndOfSequenceToken: 'eos_token'>, generated_tokens=15, seed=11784489630023641976, prefill=[], tokens=[Token(id=6796, text=' ["', logprob=0.0, special=False), Token(id=7504, text='Acc', logprob=-0.018188477, special=False), Token(id=3278, text='ording', logprob=0.0, special=False), Token(id=304, text=' to', logprob=0.0, special=False), Token(id=2211, text=' three', logprob=0.0, special=False), Token(id=1209, text=' pass', logprob=0.0, special=False), Token(id=1179, text='ages', logprob=0.0, special=False), Token(id=29892, text=',', logprob=0.0, special=False), Token(id=27504, text=' Trump', logprob=0.0, special=False), Token(id=338, text=' is', logprob=0.0, special=False), Token(id=263, text=' a', logprob=0.0, special=False), Token(id=11203, text=' dog', logprob=0.0, special=False), Token(id=1213, text='."', logprob=0.0, special=False), Token(id=29962, t

In [136]:
output_details

[{'id': 6796, 'text': ' ["', 'logprob': 1.0, 'special': False},
 {'id': 7504, 'text': 'Acc', 'logprob': 0.9819759350372468, 'special': False},
 {'id': 3278, 'text': 'ording', 'logprob': 1.0, 'special': False},
 {'id': 304, 'text': ' to', 'logprob': 1.0, 'special': False},
 {'id': 2211, 'text': ' three', 'logprob': 1.0, 'special': False},
 {'id': 1209, 'text': ' pass', 'logprob': 1.0, 'special': False},
 {'id': 1179, 'text': 'ages', 'logprob': 1.0, 'special': False},
 {'id': 29892, 'text': ',', 'logprob': 1.0, 'special': False},
 {'id': 27504, 'text': ' Trump', 'logprob': 1.0, 'special': False},
 {'id': 338, 'text': ' is', 'logprob': 1.0, 'special': False},
 {'id': 263, 'text': ' a', 'logprob': 1.0, 'special': False},
 {'id': 11203, 'text': ' dog', 'logprob': 1.0, 'special': False},
 {'id': 1213, 'text': '."', 'logprob': 1.0, 'special': False},
 {'id': 29962, 'text': ']', 'logprob': 1.0, 'special': False},
 {'id': 2, 'text': '</s>', 'logprob': 1.0, 'special': True}]