https://github.com/truera/trulens/blob/main/trulens_eval/examples/expositional/models/litellm_quickstart.ipynb

结论：

- 使用litellm定制后，虽然还是openai，但是终于能得到评估结果了
- 初步认为，是TruLens代码中的bug，还在默认使用openai的设置，造成连接错误


In [1]:
%%time

import os

base_url = "http://ape:3000/v1"
api_key = "sk-bJP6QSnUfjAYeYeE505d3eBf63A643BeB0B8E350Df9b7750"

os.environ["OPENAI_API_KEY"] = api_key
os.environ["OPENAI_API_BASE"] = base_url

CPU times: user 9 µs, sys: 2 µs, total: 11 µs
Wall time: 12.6 µs


In [2]:
university_info = """
The University of Washington, founded in 1861 in Seattle, is a public research university
with over 45,000 students across three campuses in Seattle, Tacoma, and Bothell.
As the flagship institution of the six public universities in Washington state,
UW encompasses over 500 buildings and 20 million square feet of space,
including one of the largest library systems in the world.
"""

In [3]:
import nest_asyncio
nest_asyncio.apply()

In [4]:
%%time

from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction

embedding_function = OpenAIEmbeddingFunction(
    api_key=api_key,
    model_name="text-embedding-ada-002",
    api_base=base_url,
)

CPU times: user 880 ms, sys: 175 ms, total: 1.06 s
Wall time: 863 ms


In [5]:
import chromadb

chroma_client = chromadb.Client()
vector_store = chroma_client.get_or_create_collection(
    name="Universities",
    embedding_function=embedding_function
)

In [6]:
vector_store.add(
    "uni_info",
    documents=university_info,
)

In [7]:
%%time

from trulens_eval import Tru
from trulens_eval.tru_custom_app import instrument

tru = Tru()
tru.reset_database()

🦑 Tru initialized with db url sqlite:///default.sqlite .
🛑 Secret keys may be written to the database. See the `database_redact_keys` option of `Tru` to prevent this.
CPU times: user 6.04 s, sys: 411 ms, total: 6.45 s
Wall time: 6.29 s


In [8]:
%%time

import litellm
from litellm import embedding


class RAG_from_scratch:
    @instrument
    def retrieve(self, query: str) -> list:
        """
        Retrieve relevant text from vector store.
        """
        results = vector_store.query(
            query_embeddings=embedding(
                model="text-embedding-ada-002", 
                input=query
            ).data[0]["embedding"],
            n_results=2,
        )
        return results["documents"]

    @instrument
    def generate_completion(self, query: str, context_str: list) -> str:
        """
        Generate answer from context.
        """
        completion = (
            litellm.completion(
                model="gpt-3.5-turbo",
                api_base=base_url,
                temperature=0,
                messages=[
                    {
                        "role": "user",
                        "content": f"We have provided context information below. \n"
                        f"---------------------\n"
                        f"{context_str}"
                        f"\n---------------------\n"
                        f"Given this information, please answer the question: {query}",
                    }
                ],
            )
            .choices[0]
            .message.content
        )
        return completion

    @instrument
    def query(self, query: str) -> str:
        context_str = self.retrieve(query)
        completion = self.generate_completion(query, context_str)
        return completion


rag = RAG_from_scratch()

CPU times: user 41 µs, sys: 3 µs, total: 44 µs
Wall time: 46.3 µs


In [9]:
%%time

import numpy as np

from trulens_eval import Feedback
from trulens_eval import LiteLLM
from trulens_eval import Select

# Initialize LiteLLM-based feedback function collection class:
provider = LiteLLM(
    model_engine="gpt-3.5-turbo",
    # api_base="http://monkey:11434",
    # api_key =api_key,
)

# Define a groundedness feedback function
f_groundedness = (
    Feedback(
        provider.groundedness_measure_with_cot_reasons, name="Groundedness"
    )
    .on(Select.RecordCalls.retrieve.rets.collect())
    .on_output()
)

# Question/answer relevance between overall question and answer.
f_answer_relevance = (
    Feedback(provider.relevance_with_cot_reasons, name="Answer Relevance")
    .on(Select.RecordCalls.retrieve.args.query)
    .on_output()
)

# Question/statement relevance between question and each context chunk.
f_context_relevance = (
    Feedback(
        provider.context_relevance_with_cot_reasons, name="Context Relevance"
    )
    .on(Select.RecordCalls.retrieve.args.query)
    .on(Select.RecordCalls.retrieve.rets.collect())
    .aggregate(np.mean)
)

f_coherence = Feedback(
    provider.coherence_with_cot_reasons, name="coherence"
).on_output()

✅ In Groundedness, input source will be set to __record__.app.retrieve.rets.collect() .
✅ In Groundedness, input statement will be set to __record__.main_output or `Select.RecordOutput` .
✅ In Answer Relevance, input prompt will be set to __record__.app.retrieve.args.query .
✅ In Answer Relevance, input response will be set to __record__.main_output or `Select.RecordOutput` .
✅ In Context Relevance, input question will be set to __record__.app.retrieve.args.query .
✅ In Context Relevance, input context will be set to __record__.app.retrieve.rets.collect() .
✅ In coherence, input text will be set to __record__.main_output or `Select.RecordOutput` .
CPU times: user 84.3 ms, sys: 31 µs, total: 84.3 ms
Wall time: 83.8 ms


In [13]:
from trulens_eval import TruCustomApp

tru_rag = TruCustomApp(
    rag,
    app_id="RAG v1",
    feedbacks=[
        f_groundedness,
        f_answer_relevance,
        f_context_relevance,
        f_coherence,
    ],
)

In [12]:
import nltk
# [nltk_data] Error loading punkt: <urlopen error [Errno 111] Connection
nltk.set_proxy('http://myproxy:7890')

In [14]:
with tru_rag as recording:
    rag.query("Give me a long history of U Dub")

In [15]:
tru.get_leaderboard(app_ids=["RAG v1"])

Unnamed: 0_level_0,Answer Relevance,Context Relevance,Groundedness,coherence,latency,total_cost
app_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
RAG v1,1.0,0.75,0.612222,0.9,6.0,0.000619
