From 45118d3c4425bf77c7465eb85e3d0216e97f80d7 Mon Sep 17 00:00:00 2001 From: "b.nativi" Date: Thu, 23 May 2024 21:36:01 +0000 Subject: [PATCH] add MistralValorClient --- .../backend/metrics/test_text_generation.py | 16 ++-- api/valor_api/backend/metrics/llm_call.py | 86 ++++++++++++++++++- .../backend/metrics/text_generation.py | 4 +- 3 files changed, 95 insertions(+), 11 deletions(-) diff --git a/api/tests/functional-tests/backend/metrics/test_text_generation.py b/api/tests/functional-tests/backend/metrics/test_text_generation.py index ce585ef25..ade82723f 100644 --- a/api/tests/functional-tests/backend/metrics/test_text_generation.py +++ b/api/tests/functional-tests/backend/metrics/test_text_generation.py @@ -164,11 +164,11 @@ def text_generation_test_data(db: Session, dataset_name: str, model_name: str): # TODO Make this text work on the github checks. It currently works locally. # def test_openai_api_request(): # """ -# Tests the OpenAIClient class. +# Tests the OpenAIValorClient class. -# Just tests that a CoherenceMetric is correctly built from an OpenAIClient.coherence call. +# Just tests that a CoherenceMetric is correctly built from an OpenAIValorClient.coherence call. # """ -# client = OpenAIClient( +# client = OpenAIValorClient( # seed=2024, # TODO Should we have a seed here? # ) @@ -182,14 +182,14 @@ def text_generation_test_data(db: Session, dataset_name: str, model_name: str): @patch( - "valor_api.backend.metrics.llm_call.OpenAIClient.connect", + "valor_api.backend.metrics.llm_call.OpenAIValorClient.connect", mocked_connection, ) @patch( - "valor_api.backend.metrics.llm_call.OpenAIClient.coherence", + "valor_api.backend.metrics.llm_call.OpenAIValorClient.coherence", mocked_coherence, ) -# @patch.object(llm_call.OpenAIClient, mocked_OpenAIClient) +# @patch.object(llm_call.OpenAIValorClient, mocked_OpenAIValorClient) def test_compute_text_generation( db: Session, dataset_name: str, @@ -269,11 +269,11 @@ def test_compute_text_generation( @patch( - "valor_api.backend.metrics.llm_call.OpenAIClient.connect", + "valor_api.backend.metrics.llm_call.OpenAIValorClient.connect", mocked_connection, ) @patch( - "valor_api.backend.metrics.llm_call.OpenAIClient.coherence", + "valor_api.backend.metrics.llm_call.OpenAIValorClient.coherence", mocked_coherence, ) def test_text_generation( diff --git a/api/valor_api/backend/metrics/llm_call.py b/api/valor_api/backend/metrics/llm_call.py index c87e2b300..fa55bdfe0 100644 --- a/api/valor_api/backend/metrics/llm_call.py +++ b/api/valor_api/backend/metrics/llm_call.py @@ -1,6 +1,8 @@ from typing import Any import openai +from mistralai.client import MistralClient +from mistralai.models.chat_completion import ChatMessage COHERENCE_INSTRUCTION = """You are a helpful assistant. You will grade the user's text. Your task is to rate the text based on its coherence. Please make sure you read and understand these instructions carefully. Please keep this document open while reviewing, and refer to it as needed. Evaluation Criteria: @@ -96,7 +98,7 @@ def coherence( return ret -class OpenAIClient(LLMClient): +class OpenAIValorClient(LLMClient): """ Wrapper for calls to OpenAI's API. @@ -172,6 +174,7 @@ def __call__( finish_reason = openai_response.choices[0].finish_reason response = openai_response.choices[0].message.content + # TODO Only keep these if we can test them. if finish_reason == "length": raise ValueError( "OpenAI response reached max token limit. Resulting evaluation is likely invalid or of low quality." @@ -182,3 +185,84 @@ def __call__( ) return response + + +class MistralValorClient(LLMClient): + """ + Wrapper for calls to Mistral's API. + + Parameters + ---------- + api_key : str, optional + The Mistral API key to use. If not specified, then the MISTRAL_API_KEY environment variable will be used. + model_name : str + The model to use. Defaults to "mistral-small-latest". + """ + + # url: str + api_key: str | None = None + model_name: str = ( + "mistral-small-latest" # mistral-small-latest mistral-large-latest + ) + + def __init__( + self, + api_key: str | None = None, + ): + """ + TODO should we use an __attrs_post_init__ instead? + """ + self.api_key = api_key + + def connect( + self, + ): + """ + TODO This is separated for now because I want to mock connecting to the Mistral API. + """ + if self.api_key is None: + self.client = MistralClient() + else: + self.client = MistralClient(api_key=self.api_key) + + def process_messages( + self, + messages: list[dict[str, str]], + ) -> Any: + """ + All messages should be formatted according to the standard set by OpenAI, and should be modified + as needed for other models. This function takes in messages in the OpenAI standard format and converts + them to the format required by the model. + """ + ret = [] + for i in range(len(messages)): + ret.append( + ChatMessage( + role=messages[i]["role"], content=messages[i]["content"] + ) + ) + return ret + + def __call__( + self, + messages: list[dict[str, str]], + ) -> Any: + """ + Call to the API. + + TODO possibly change this to a call with the API. This would remove the openai python dependence. + """ + processed_messages = self.process_messages(messages) + mistral_response = self.client.chat( + model=self.model_name, + messages=processed_messages, + ) + # TODO Are there any errors we should catch in a try except block? + + # token_usage = mistral_response.usage # TODO Could report token usage to user. Could use token length to determine if input is too larger, although this would require us to know the model's context window size. + # finish_reason = mistral_response.choices[0].finish_reason + response = mistral_response.choices[0].message.content + + # TODO Possibly add errors depending on the finish reason? + + return response diff --git a/api/valor_api/backend/metrics/text_generation.py b/api/valor_api/backend/metrics/text_generation.py index 4e49264a3..2c950f367 100644 --- a/api/valor_api/backend/metrics/text_generation.py +++ b/api/valor_api/backend/metrics/text_generation.py @@ -7,7 +7,7 @@ from valor_api import schemas from valor_api.backend import core, models -from valor_api.backend.metrics.llm_call import OpenAIClient +from valor_api.backend.metrics.llm_call import OpenAIValorClient from valor_api.backend.metrics.metric_utils import ( # log_evaluation_item_counts, create_metric_mappings, get_or_create_row, @@ -161,7 +161,7 @@ def _compute_text_generation_metrics( res = db.execute(total_query).all() - client = OpenAIClient() + client = OpenAIValorClient() client.connect() ret = []