Skip to content

Commit

Permalink
add MistralValorClient
Browse files Browse the repository at this point in the history
  • Loading branch information
b.nativi committed May 23, 2024
1 parent 59fca69 commit 45118d3
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,11 @@ def text_generation_test_data(db: Session, dataset_name: str, model_name: str):
# TODO Make this test work on the github checks. It currently works locally.
# def test_openai_api_request():
# """
# Tests the OpenAIClient class.
# Tests the OpenAIValorClient class.

# Just tests that a CoherenceMetric is correctly built from an OpenAIClient.coherence call.
# Just tests that a CoherenceMetric is correctly built from an OpenAIValorClient.coherence call.
# """
# client = OpenAIClient(
# client = OpenAIValorClient(
# seed=2024, # TODO Should we have a seed here?
# )

Expand All @@ -182,14 +182,14 @@ def text_generation_test_data(db: Session, dataset_name: str, model_name: str):


@patch(
"valor_api.backend.metrics.llm_call.OpenAIClient.connect",
"valor_api.backend.metrics.llm_call.OpenAIValorClient.connect",
mocked_connection,
)
@patch(
"valor_api.backend.metrics.llm_call.OpenAIClient.coherence",
"valor_api.backend.metrics.llm_call.OpenAIValorClient.coherence",
mocked_coherence,
)
# @patch.object(llm_call.OpenAIClient, mocked_OpenAIClient)
# @patch.object(llm_call.OpenAIValorClient, mocked_OpenAIValorClient)
def test_compute_text_generation(
db: Session,
dataset_name: str,
Expand Down Expand Up @@ -269,11 +269,11 @@ def test_compute_text_generation(


@patch(
"valor_api.backend.metrics.llm_call.OpenAIClient.connect",
"valor_api.backend.metrics.llm_call.OpenAIValorClient.connect",
mocked_connection,
)
@patch(
"valor_api.backend.metrics.llm_call.OpenAIClient.coherence",
"valor_api.backend.metrics.llm_call.OpenAIValorClient.coherence",
mocked_coherence,
)
def test_text_generation(
Expand Down
86 changes: 85 additions & 1 deletion api/valor_api/backend/metrics/llm_call.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Any

import openai
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

COHERENCE_INSTRUCTION = """You are a helpful assistant. You will grade the user's text. Your task is to rate the text based on its coherence. Please make sure you read and understand these instructions carefully. Please keep this document open while reviewing, and refer to it as needed.
Evaluation Criteria:
Expand Down Expand Up @@ -96,7 +98,7 @@ def coherence(
return ret


class OpenAIClient(LLMClient):
class OpenAIValorClient(LLMClient):
"""
Wrapper for calls to OpenAI's API.
Expand Down Expand Up @@ -172,6 +174,7 @@ def __call__(
finish_reason = openai_response.choices[0].finish_reason
response = openai_response.choices[0].message.content

# TODO Only keep these if we can test them.
if finish_reason == "length":
raise ValueError(
"OpenAI response reached max token limit. Resulting evaluation is likely invalid or of low quality."
Expand All @@ -182,3 +185,84 @@ def __call__(
)

return response


class MistralValorClient(LLMClient):
    """
    Wrapper for calls to Mistral's API.

    Parameters
    ----------
    api_key : str, optional
        The Mistral API key to use. If not specified, then the MISTRAL_API_KEY environment variable will be used.
    model_name : str
        The model to use. Defaults to "mistral-small-latest".
    """

    # url: str
    # Default to the small model; "mistral-large-latest" is the larger alternative.
    api_key: str | None = None
    model_name: str = "mistral-small-latest"

    def __init__(
        self,
        api_key: str | None = None,
    ):
        """
        Store the API key; connecting is deferred to `connect`.

        TODO should we use an __attrs_post_init__ instead?

        Parameters
        ----------
        api_key : str, optional
            The Mistral API key to use. If not specified, then the
            MISTRAL_API_KEY environment variable will be used.
        """
        self.api_key = api_key

    def connect(
        self,
    ):
        """
        Create the underlying MistralClient.

        TODO This is separated for now because I want to mock connecting to the Mistral API.
        """
        if self.api_key is None:
            # MistralClient reads the MISTRAL_API_KEY environment variable itself.
            self.client = MistralClient()
        else:
            self.client = MistralClient(api_key=self.api_key)

    def process_messages(
        self,
        messages: list[dict[str, str]],
    ) -> list[ChatMessage]:
        """
        Convert messages to the format required by Mistral.

        All messages should be formatted according to the standard set by OpenAI, and should be modified
        as needed for other models. This function takes in messages in the OpenAI standard format and converts
        them to the format required by the model.

        Parameters
        ----------
        messages : list[dict[str, str]]
            Messages in the OpenAI standard format, each with "role" and "content" keys.

        Returns
        -------
        list[ChatMessage]
            The same messages as Mistral ChatMessage objects.
        """
        return [
            ChatMessage(role=message["role"], content=message["content"])
            for message in messages
        ]

    def __call__(
        self,
        messages: list[dict[str, str]],
    ) -> Any:
        """
        Call to the API.

        TODO possibly change this to a call with the API. This would remove the openai python dependence.

        Parameters
        ----------
        messages : list[dict[str, str]]
            Messages in the OpenAI standard format.

        Returns
        -------
        Any
            The content of the first choice in the model's response.
        """
        processed_messages = self.process_messages(messages)
        mistral_response = self.client.chat(
            model=self.model_name,
            messages=processed_messages,
        )
        # TODO Are there any errors we should catch in a try except block?

        # token_usage = mistral_response.usage  # TODO Could report token usage to user. Could use token length to determine if input is too large, although this would require us to know the model's context window size.
        # finish_reason = mistral_response.choices[0].finish_reason
        response = mistral_response.choices[0].message.content

        # TODO Possibly add errors depending on the finish reason?

        return response
4 changes: 2 additions & 2 deletions api/valor_api/backend/metrics/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from valor_api import schemas
from valor_api.backend import core, models
from valor_api.backend.metrics.llm_call import OpenAIClient
from valor_api.backend.metrics.llm_call import OpenAIValorClient
from valor_api.backend.metrics.metric_utils import ( # log_evaluation_item_counts,
create_metric_mappings,
get_or_create_row,
Expand Down Expand Up @@ -161,7 +161,7 @@ def _compute_text_generation_metrics(

res = db.execute(total_query).all()

client = OpenAIClient()
client = OpenAIValorClient()
client.connect()
ret = []

Expand Down

0 comments on commit 45118d3

Please sign in to comment.