In [1]:
from datasets import Dataset 
import os
from ragas import evaluate
from ragas.metrics import faithfulness, answer_correctness

# os.environ["OPENAI_API_KEY"] = "your-openai-key"
from dotenv import load_dotenv, find_dotenv

_ = load_dotenv(find_dotenv()) 

data_samples = {
    'question': ['When was the first super bowl?', 'Who won the most super bowls?'],
    'answer': ['The first superbowl was held on Jan 15, 1967', 'The most super bowls have been won by The New England Patriots'],
    'contexts' : [['The First AFL–NFL World Championship Game was an American football game played on January 15, 1967, at the Los Angeles Memorial Coliseum in Los Angeles,'], 
    ['The Green Bay Packers...Green Bay, Wisconsin.','The Packers compete...Football Conference']],
    'ground_truth': ['The first superbowl was held on January 15, 1967', 'The New England Patriots have won the Super Bowl a record six times']
}

dataset = Dataset.from_dict(data_samples)

score = evaluate(dataset,metrics=[faithfulness,answer_correctness])
print(score.to_pandas())

Evaluating:   0%|          | 0/4 [00:00<?, ?it/s]

                         question  \
0  When was the first super bowl?   
1   Who won the most super bowls?   

                                              answer  \
0       The first superbowl was held on Jan 15, 1967   
1  The most super bowls have been won by The New ...   

                                            contexts  \
0  [The First AFL–NFL World Championship Game was...   
1  [The Green Bay Packers...Green Bay, Wisconsin....   

                                        ground_truth  faithfulness  \
0   The first superbowl was held on January 15, 1967           0.0   
1  The New England Patriots have won the Super Bo...           0.0   

   answer_correctness  
0            0.749095  
1            0.731078  


In [87]:
import pandas as pd
df = pd.read_csv("evaluation_questions.csv")
df

Unnamed: 0,Questions,Groud_answer,Difficulty
0,How many countries and areas were enrolled in ...,"By the end of the 2021, 109 countries plus two...",2
1,What are the strategic objectives of the globa...,there are five main obejectives: to improve aw...,2
2,How many countries have developed their own na...,"Up to 2021-22, there are 149 countries release...",2
3,Which sector in One health would be influenced...,All sectors of One Health would be influenced ...,3
4,What is One Health?,"One Health is an integrated, unifying approach...",1
5,What is antimicrobial resistance (AMR)?,"AMR occurs when microorganisms, such as bacter...",1
6,Is it still possible to control antimicrobial ...,Yes. But it requires multisectoral and multidi...,5
7,What are the issues with financing in controll...,The financing challenges associated with comba...,5
8,What is the status of MRSA infection in Europe?,"For S. aureus, a significant decrease in the E...",5
9,Which regions will be disproportionally affect...,Low- and middle-income countries such as Afric...,5


In [31]:
from vectordb_utils_shule import ShuleVectorDB
from prompt_utils import build_prompt
import pickle
import importlib
# importlib.reload(openai_utils)
# from openai_utils import get_completion_openai, init_openai

index_path = "../index/index_n179366_03171945.pickle"
top_n = 8
recall_n = 80

init_openai()
with open(index_path, 'rb') as file:
        vec_db_shule = pickle.load(file)
    
def search_db(user_input, source_type):
    search_results = []
    
    if search_strategy == "hnsw":
        search_labels = vec_db_shule.search_bge(user_input, top_n)
        texts, pages, titles, years, countries, ORGs = vec_db_shule.get_context_by_labels(search_labels)
        search_results = texts
        res = [texts[i] if countries[i] == 'xxx' else f'In {countries[i]}, {texts[i]}' for i in range(top_n)]

        search_field = "\n\n".join([f"{i+1}. [Reference: {titles[i]}, Page: {pages[i]}, ORG: {ORGs[i]}, Year: {years[i]}]\n{texts[i]}" for i in range(top_n)])
        prompt = build_prompt(source_type=source_type, info=[f"{res[i]} [Reference: Page {pages[i]}, {titles[i]}, {years[i]}, {ORGs[i]}]" for i in range(top_n)], query=user_input)
        
    elif search_strategy == "rerank":
        scores, texts, pages, titles, years, countries, ORGs = rerank(user_input, top_n, recall_n)
        search_results = texts
        res = [texts[i] if countries[i] == 'xxx' else f'In {countries[i]}, {texts[i]}' for i in range(top_n)]

        search_field = "\n\n".join([f"{i+1}. [Reference: {titles[i]}, Page: {pages[i]}, ORG: {ORGs[i]}, Year: {years[i]}]\n{texts[i]}" for i in range(top_n)])
        prompt = build_prompt(source_type=source_type, info=[f"{res[i]} [Reference: Page {pages[i]}, {titles[i]}, {years[i]}, {ORGs[i]}]" for i in range(top_n)], query=user_input)
        
    elif search_strategy == "fusion":
        log_warning("Not support yet.")
        return
            
    log_info(f"prompt content built:\n{prompt}")
    return prompt, search_results


def rerank(user_input, top_n, recall_n):
    search_labels = vec_db_shule.search_bge(user_input, recall_n)
    t0 = time.time()
    texts, pages, titles, years, countries, ORGs = vec_db_shule.get_context_by_labels(search_labels)
    t1 = time.time()
    log_info(f"vec_db_shule.get_context_by_labels costs: {t1 - t0}")

    documents = [texts[i] if countries[i] == 'xxx' else f'In {countries[i]}, {texts[i]}' for i in range(len(pages))]
    res = rerank_model.rank(documents = documents,
                            query=user_input,
                            batch_size = 1,
                            return_documents = False,
                            show_progress_bar = False)
    t2 = time.time()
    log_info(f"rerank_model.predict costs: {t2 - t1}")
    
    ids = [i['corpus_id'] for i in res][:top_n]
    scores = [i['score'] for i in res][:top_n]

    log_info(f"finish rerank {recall_n} texts, return highest {top_n} texts")
    return scores, [texts[i] for i in ids], [pages[i] for i in ids], [titles[i] for i in ids], [years[i] for i in ids], [countries[i] for i in ids], [ORGs[i] for i in ids]



In [88]:
questions = df["Questions"].to_list()
ground_truths = df["Groud_answer"].to_list()
answers = []
contexts = []
retrieve_time = []
completion_time = []

In [90]:
from logger import log_info
import importlib
import openai_utils
import time
# 在需要重新加载的地方调用 reload() 函数
importlib.reload(openai_utils)
from openai_utils import get_completion_openai, init_openai

search_strategy = "hnsw"
source = "Hybrid"
model= "GPT-3.5"

init_openai()
for i, query in enumerate(questions[:2]):
    print(i, "begins")
    t0 = time.time()
    prompt, search_results = search_db(query, source)
    t1 = time.time()
    response = get_completion_openai(prompt, model=model)
    t2 = time.time()
    contexts.append(search_results) 
    answers.append(response)
    retrieve_time.append(t1 - t0)
    completion_time.append(t2 - t1)


0 How many countries and areas were enrolled in GLASS-AMR? begins
1 What are the strategic objectives of the global action plan (GAP) on antimicrobial resistance (AMR) in 2015? begins


In [75]:
data = {
    "question": questions,
    "answer": answers,
    "contexts": contexts,
    "ground_truth": ground_truths
}
dataset = Dataset.from_dict(data)

In [82]:
data = {
    "question": questions,
    "answer": answers,
    "contexts": contexts,
    "ground_truth": ground_truths,
    "retrieve_time": retrieve_time,
    "completion_time": completion_time, 
}
dataset = Dataset.from_dict(data)

NameError: name 'retrieve_time' is not defined

In [80]:
questions[:2]

['How many countries and areas were enrolled in GLASS-AMR?',
 'What are the strategic objectives of the global action plan (GAP) on antimicrobial resistance (AMR) in 2015?']

In [78]:
from ragas import evaluate
from ragas.metrics import (
    faithfulness,
    answer_relevancy,
    context_recall,
    context_precision,
    context_relevancy,
    context_entity_recall,
    answer_similarity
)

result = evaluate(
    dataset = dataset,
    metrics=[context_precision,
            context_recall,
            faithfulness,
            answer_relevancy,
            context_relevancy,
            context_entity_recall,
            answer_similarity],
    raise_exceptions=False
)

df = result.to_pandas()

Evaluating:   0%|          | 0/72 [00:00<?, ?it/s]

KeyboardInterrupt: 

Exception in thread Exception in threading.excepthook:
Exception ignored in thread started by: <bound method Thread._bootstrap of <Runner(Thread-11, stopped 140307882550848)>>
Traceback (most recent call last):
  File "/usr/lib/python3.10/threading.py", line 973, in _bootstrap
    self._bootstrap_inner()
  File "/usr/lib/python3.10/threading.py", line 1018, in _bootstrap_inner
    self._invoke_excepthook(self)
  File "/usr/lib/python3.10/threading.py", line 1336, in invoke_excepthook
    local_print("Exception in threading.excepthook:",
  File "/home/ubuntu/.local/lib/python3.10/site-packages/ipykernel/iostream.py", line 604, in flush
    self.pub_thread.schedule(self._flush)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/ipykernel/iostream.py", line 267, in schedule
    self._event_pipe.send(b"")
  File "/home/ubuntu/.local/lib/python3.10/site-packages/zmq/sugar/socket.py", line 696, in send
    return super().send(data, flags=flags, copy=copy, track=track)
  File "zmq/backe

In [83]:
df2 = df
df2

Unnamed: 0,question,answer,contexts,ground_truth,context_precision,context_recall,faithfulness,answer_relevancy
0,How many countries and areas were enrolled in ...,"Based on the information provided, as of today...",[AMR findings are thus presented and 1.3 Key f...,"By the end of the 2021, 109 countries plus two...",1.0,1.0,1.0,0.97457
1,What are the strategic objectives of the globa...,The strategic objectives of the global action ...,"[In 2015, recognizing the urgent need to tackl...",there are five main obejectives: to improve aw...,1.0,1.0,1.0,0.965291
2,How many countries have developed their own na...,"Based on available data, more than 115 countri...",[Executive summary Since the Global Action Pla...,"Up to 2021-22, there are 149 countries release...",1.0,0.0,0.666667,
3,Which sector in One health would be influenced...,The sector in One Health that would be influen...,[Introduction WHO implementation handbook for ...,All sectors of One Health would be influenced ...,1.0,1.0,1.0,
4,What is One Health?,"One Health is a collaborative, multisectoral, ...",[2 ‘One Health’ is an approach to designing an...,"One Health is an integrated, unifying approach...",1.0,1.0,1.0,1.0
5,What is antimicrobial resistance (AMR)?,Antimicrobial resistance (AMR) refers to the a...,[Antimicrobial resistance in livestock in the ...,"AMR occurs when microorganisms, such as bacter...",1.0,1.0,1.0,0.977347
6,Is it still possible to control antimicrobial ...,"Yes, it is still possible to control antimicro...",[SURVEILLANCE REPORT Antimicrobial resistance ...,Yes. But it requires multisectoral and multidi...,1.0,1.0,1.0,0.976911
7,What are the issues with financing in controll...,The issues with financing in controlling antim...,[Insufficient funding devoted to AMR is anothe...,The financing challenges associated with comba...,1.0,1.0,1.0,0.927155
8,What is the status of MRSA infection in Europe?,The status of MRSA infection in Europe varies ...,[MRSA is currently the most commonly The propo...,"For S. aureus, a significant decrease in the E...",1.0,1.0,1.0,
9,Which regions will be disproportionally affect...,Regions that will be disproportionately affect...,[It has been estimated that failure to address...,Low- and middle-income countries such as Afric...,0.909354,1.0,1.0,
