From ab7ec8760709bddbef0e5c320284c5c93c325aeb Mon Sep 17 00:00:00 2001 From: microwavetoasteroven <69960049+microwavetoasteroven@users.noreply.github.com> Date: Mon, 13 May 2024 11:27:17 +0100 Subject: [PATCH 01/10] fix/base.py typo fix/base.py typo --- semantic_router/llms/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py index 963c754..7cbc798 100644 --- a/semantic_router/llms/base.py +++ b/semantic_router/llms/base.py @@ -91,7 +91,7 @@ def extract_function_inputs( === EXAMPLE_OUTPUT End === ### EXAMPLE End ### -Note: I will tip $500 for and accurate JSON output. You will be penalized for an inaccurate JSON output. +Note: I will tip $500 for an accurate JSON output. You will be penalized for an inaccurate JSON output. Provide JSON output now: """ From 8392895cd6ce8c3c064df0590a90ab0a3f7f8cb7 Mon Sep 17 00:00:00 2001 From: Siraj R Aizlewood Date: Mon, 13 May 2024 17:15:11 +0400 Subject: [PATCH 02/10] _is_valid_inputs() fixes. Now no longer requires typehints in the function signature (it wasn't using these anyway, and would break when they weren't included. Also, we now only check if mandatary arguments have been provided in input. None mandatory don't need to be present. Finally, addde a check to ensure that, if there are extra arguments provided in input not present in the signature, then these result in false being returned. --- docs/10-debugging-discord-issue.ipynb | 163 ++++++++++++++++++++++++++ semantic_router/llms/base.py | 52 ++++++-- 2 files changed, 205 insertions(+), 10 deletions(-) create mode 100644 docs/10-debugging-discord-issue.ipynb diff --git a/docs/10-debugging-discord-issue.ipynb b/docs/10-debugging-discord-issue.ipynb new file mode 100644 index 0000000..9c781d4 --- /dev/null +++ b/docs/10-debugging-discord-issue.ipynb @@ -0,0 +1,163 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "\u001b[32m2024-05-13 17:10:58 INFO semantic_router.utils.logger local\u001b[0m\n", + "\u001b[32m2024-05-13 17:10:59 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n", + "\u001b[32m2024-05-13 17:11:01 INFO semantic_router.utils.logger LLM output: {\n", + " \"location\": \"berlin\"\n", + "}\u001b[0m\n", + "\u001b[32m2024-05-13 17:11:01 INFO semantic_router.utils.logger Function inputs: {'location': 'berlin'}\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2024-05-13 15:11\n" + ] + } + ], + "source": [ + "import datetime\n", + "import pytz\n", + "from semantic_router.llms.openrouter import OpenRouterLLM\n", + "from semantic_router import Route, RouteLayer\n", + "from semantic_router.encoders import HuggingFaceEncoder\n", + "from semantic_router.utils.function_call import get_schema\n", + "import geonamescache\n", + "\n", + "class Skill:\n", + " def __init__(self):\n", + " self.geocoder = geonamescache.GeonamesCache()\n", + " self.location = self.geocode_location()\n", + " self.route = Route(\n", + " name='time',\n", + " utterances=[\n", + " \"tell me what is the time\",\n", + " \"what is the date \",\n", + " \"time in varshava\",\n", + " \"date\",\n", + " \"what date is it today\",\n", + " \"time in ny\",\n", + " \"what is the time and date in boston\",\n", + " \"time\",\n", + " \"what is the time in makhachkala\",\n", + " \"date time in st petersburg\",\n", + " \"what's the date in vienna\",\n", + " \"date time\"\n", + " ],\n", + " function_schema=get_schema(self.run),\n", + "\n", + " )\n", + "\n", + " self.rl = RouteLayer(\n", + " encoder=HuggingFaceEncoder(),\n", + " routes=[self.route],\n", + " llm=OpenRouterLLM(\n", + " name='mistralai/mistral-7b-instruct:free',\n", + " openrouter_api_key='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'\n", + "\n", + " )\n", + " )\n", + "\n", + " def geocode_location(self, location_name=None):\n", + " if location_name:\n", + " location_name = location_name.title()\n", + " location = list(\n", + " self.geocoder.get_cities_by_name(location_name)[0].values() if self.geocoder.get_cities_by_name(\n", + " location_name) else self.geocoder.get_us_states_by_names(location_name)[\n", + " 0].values() if self.geocoder.get_us_states_by_names(location_name) else\n", + " self.geocoder.get_countries_by_names(location_name)[\n", + " 0].values() if self.geocoder.get_countries_by_names(location_name) else None)[0]\n", + " return location['timezone']\n", + " else:\n", + " return ''\n", + "\n", + " def run(self, location:str=None, day:int=0, hour:int=0, minute:int=0):\n", + " \"\"\"Finds the current time in a specific location.\n", + "\n", + " :param location: The location to find the current time in, should\n", + " be a valid location. Put the place name itself\n", + " like \"rome\", or \"new york\" in the lowercase.\n", + " :type location: str\n", + "\n", + " :param day: The offset in days from the current date.\n", + " Use positive integers for future dates (e.g., day=1 for tomorrow),\n", + " negative integers for past dates (e.g., day=-1 for yesterday),\n", + " and 0 for the current date.\n", + " :type day: int\n", + "\n", + " :param hour: The offset in hours from the current time.\n", + " Use positive integers for future times (e.g., hour=1 for one hour ahead),\n", + " negative integers for past times (e.g., hour=-1 for one hour ago),\n", + " and 0 to maintain the current hour.\n", + " :type hour: int\n", + "\n", + " :param minute: The offset in minutes from the current time.\n", + " Use positive integers for future minutes (e.g., minute=20 for twenty minutes ahead),\n", + " negative integers for past minutes (e.g., minute=-20 for twenty minutes ago),\n", + " and 0 to maintain the current minute.\n", + " :type minute: int\n", + "\n", + " :return: The time in the specified location.\"\"\"\n", + " timezone = self.geocode_location(location)\n", + " if timezone:\n", + " tz = pytz.timezone(timezone)\n", + " else:\n", + " tz = None\n", + "\n", + " current_time = datetime.datetime.now(tz) + datetime.timedelta(days=day)\n", + "\n", + " # Adding hours and minutes to the current time\n", + " current_time += datetime.timedelta(hours=hour, minutes=minute)\n", + "\n", + " # Format the date and time as required\n", + " formatted_time = current_time.strftime(\"%Y-%m-%d %H:%M\")\n", + "\n", + " return formatted_time\n", + "\n", + "s = Skill()\n", + "out = s.rl('time in berlin')\n", + "print(s.run(**out.function_call))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "semantic_router_3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py index 963c754..658ed08 100644 --- a/semantic_router/llms/base.py +++ b/semantic_router/llms/base.py @@ -18,23 +18,55 @@ def __init__(self, name: str, **kwargs): def __call__(self, messages: List[Message]) -> Optional[str]: raise NotImplementedError("Subclasses must implement this method") - + + + def _check_for_mandatory_inputs(self, inputs: dict[str, Any], mandatory_params: List[str]) -> bool: + """Check for mandatory parameters in inputs""" + for name in mandatory_params: + if name not in inputs: + logger.error(f"Mandatory input {name} missing from query") + return False + return True + + def _check_for_extra_inputs(self, inputs: dict[str, Any], all_params: List[str]) -> bool: + """Check for extra parameters not defined in the signature""" + input_keys = set(inputs.keys()) + param_keys = set(all_params) + if not input_keys.issubset(param_keys): + extra_keys = input_keys - param_keys + logger.error(f"Extra inputs provided that are not in the signature: {extra_keys}") + return False + return True + def _is_valid_inputs( self, inputs: dict[str, Any], function_schema: dict[str, Any] ) -> bool: """Validate the extracted inputs against the function schema""" try: - # Extract parameter names and types from the signature string + # Extract parameter names and determine if they are optional signature = function_schema["signature"] param_info = [param.strip() for param in signature[1:-1].split(",")] - param_names = [info.split(":")[0].strip() for info in param_info] - param_types = [ - info.split(":")[1].strip().split("=")[0].strip() for info in param_info - ] - for name, type_str in zip(param_names, param_types): - if name not in inputs: - logger.error(f"Input {name} missing from query") - return False + mandatory_params = [] + all_params = [] + + for info in param_info: + parts = info.split("=") + name_type_pair = parts[0].strip() + name = name_type_pair.split(":")[0].strip() + all_params.append(name) + + # If there is no default value, it's a mandatory parameter + if len(parts) == 1: + mandatory_params.append(name) + + # Check for mandatory parameters + if not self._check_for_mandatory_inputs(inputs, mandatory_params): + return False + + # Check for extra parameters not defined in the signature + if not self._check_for_extra_inputs(inputs, all_params): + return False + return True except Exception as e: logger.error(f"Input validation error: {str(e)}") From d756c8694c433eb0f514f15e28b8685ff38a0485 Mon Sep 17 00:00:00 2001 From: Siraj R Aizlewood Date: Mon, 13 May 2024 17:49:16 +0400 Subject: [PATCH 03/10] PyTests and bug fix. --- semantic_router/llms/base.py | 5 ++++- tests/unit/llms/test_llm_base.py | 27 +++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py index 658ed08..639d3c1 100644 --- a/semantic_router/llms/base.py +++ b/semantic_router/llms/base.py @@ -52,7 +52,10 @@ def _is_valid_inputs( for info in param_info: parts = info.split("=") name_type_pair = parts[0].strip() - name = name_type_pair.split(":")[0].strip() + if ':' in name_type_pair: + name, _ = name_type_pair.split(":") + else: + name = name_type_pair all_params.append(name) # If there is no default value, it's a mandatory parameter diff --git a/tests/unit/llms/test_llm_base.py b/tests/unit/llms/test_llm_base.py index 2208928..d5f14a1 100644 --- a/tests/unit/llms/test_llm_base.py +++ b/tests/unit/llms/test_llm_base.py @@ -4,9 +4,18 @@ class TestBaseLLM: + @pytest.fixture def base_llm(self): return BaseLLM(name="TestLLM") + + @pytest.fixture + def mixed_function_schema(self): + return { + "name": "test_function", + "description": "A test function with mixed mandatory and optional parameters.", + "signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')" + } def test_base_llm_initialization(self, base_llm): assert base_llm.name == "TestLLM", "Initialization of name failed" @@ -69,3 +78,21 @@ def test_base_llm_extract_function_inputs_no_output(self, base_llm, mocker): } test_query = "What time is it in America/New_York?" base_llm.extract_function_inputs(test_schema, test_query) + + + def test_mandatory_args_only(self, base_llm, mixed_function_schema): + inputs = {"mandatory1": "value1", "mandatory2": 42} + assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == True + + def test_all_args_provided(self, base_llm, mixed_function_schema): + inputs = {"mandatory1": "value1", "mandatory2": 42, "optional1": "opt1", "optional2": "opt2"} + assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == True + + def test_missing_mandatory_arg(self, base_llm, mixed_function_schema): + inputs = {"mandatory1": "value1", "optional1": "opt1", "optional2": "opt2"} + assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == False + + def test_extra_arg_provided(self, base_llm, mixed_function_schema): + inputs = {"mandatory1": "value1", "mandatory2": 42, "optional1": "opt1", "optional2": "opt2", "extra": "value"} + assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == False + From 5c60f73bd1d400cc999b7160823c4af62ff4dd52 Mon Sep 17 00:00:00 2001 From: Siraj R Aizlewood Date: Mon, 13 May 2024 17:51:00 +0400 Subject: [PATCH 04/10] More PyTests. --- tests/unit/llms/test_llm_base.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/tests/unit/llms/test_llm_base.py b/tests/unit/llms/test_llm_base.py index d5f14a1..5e9bbe9 100644 --- a/tests/unit/llms/test_llm_base.py +++ b/tests/unit/llms/test_llm_base.py @@ -4,7 +4,7 @@ class TestBaseLLM: - + @pytest.fixture def base_llm(self): return BaseLLM(name="TestLLM") @@ -16,6 +16,14 @@ def mixed_function_schema(self): "description": "A test function with mixed mandatory and optional parameters.", "signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')" } + + @pytest.fixture + def mandatory_params(self): + return ["param1", "param2"] + + @pytest.fixture + def all_params(self): + return ["param1", "param2", "optional1"] def test_base_llm_initialization(self, base_llm): assert base_llm.name == "TestLLM", "Initialization of name failed" @@ -96,3 +104,18 @@ def test_extra_arg_provided(self, base_llm, mixed_function_schema): inputs = {"mandatory1": "value1", "mandatory2": 42, "optional1": "opt1", "optional2": "opt2", "extra": "value"} assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == False + def test_check_for_mandatory_inputs_all_present(self, base_llm, mandatory_params): + inputs = {"param1": "value1", "param2": "value2"} + assert base_llm._check_for_mandatory_inputs(inputs, mandatory_params) == True + + def test_check_for_mandatory_inputs_missing_one(self, base_llm, mandatory_params): + inputs = {"param1": "value1"} + assert base_llm._check_for_mandatory_inputs(inputs, mandatory_params) == False + + def test_check_for_extra_inputs_no_extras(self, base_llm, all_params): + inputs = {"param1": "value1", "param2": "value2"} + assert base_llm._check_for_extra_inputs(inputs, all_params) == True + + def test_check_for_extra_inputs_with_extras(self, base_llm, all_params): + inputs = {"param1": "value1", "param2": "value2", "extra_param": "extra"} + assert base_llm._check_for_extra_inputs(inputs, all_params) == False \ No newline at end of file From db6cd6012e1255589c87f370345512343488db8f Mon Sep 17 00:00:00 2001 From: Siraj R Aizlewood Date: Mon, 13 May 2024 18:05:34 +0400 Subject: [PATCH 05/10] Tested and then removed debugging notebook. --- docs/10-debugging-discord-issue.ipynb | 163 -------------------------- 1 file changed, 163 deletions(-) delete mode 100644 docs/10-debugging-discord-issue.ipynb diff --git a/docs/10-debugging-discord-issue.ipynb b/docs/10-debugging-discord-issue.ipynb deleted file mode 100644 index 9c781d4..0000000 --- a/docs/10-debugging-discord-issue.ipynb +++ /dev/null @@ -1,163 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "c:\\Users\\Siraj\\Documents\\Personal\\Work\\Aurelio\\Virtual Environments\\semantic_router_3\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "\u001b[32m2024-05-13 17:10:58 INFO semantic_router.utils.logger local\u001b[0m\n", - "\u001b[32m2024-05-13 17:10:59 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n", - "\u001b[32m2024-05-13 17:11:01 INFO semantic_router.utils.logger LLM output: {\n", - " \"location\": \"berlin\"\n", - "}\u001b[0m\n", - "\u001b[32m2024-05-13 17:11:01 INFO semantic_router.utils.logger Function inputs: {'location': 'berlin'}\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2024-05-13 15:11\n" - ] - } - ], - "source": [ - "import datetime\n", - "import pytz\n", - "from semantic_router.llms.openrouter import OpenRouterLLM\n", - "from semantic_router import Route, RouteLayer\n", - "from semantic_router.encoders import HuggingFaceEncoder\n", - "from semantic_router.utils.function_call import get_schema\n", - "import geonamescache\n", - "\n", - "class Skill:\n", - " def __init__(self):\n", - " self.geocoder = geonamescache.GeonamesCache()\n", - " self.location = self.geocode_location()\n", - " self.route = Route(\n", - " name='time',\n", - " utterances=[\n", - " \"tell me what is the time\",\n", - " \"what is the date \",\n", - " \"time in varshava\",\n", - " \"date\",\n", - " \"what date is it today\",\n", - " \"time in ny\",\n", - " \"what is the time and date in boston\",\n", - " \"time\",\n", - " \"what is the time in makhachkala\",\n", - " \"date time in st petersburg\",\n", - " \"what's the date in vienna\",\n", - " \"date time\"\n", - " ],\n", - " function_schema=get_schema(self.run),\n", - "\n", - " )\n", - "\n", - " self.rl = RouteLayer(\n", - " encoder=HuggingFaceEncoder(),\n", - " routes=[self.route],\n", - " llm=OpenRouterLLM(\n", - " name='mistralai/mistral-7b-instruct:free',\n", - " openrouter_api_key='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'\n", - "\n", - " )\n", - " )\n", - "\n", - " def geocode_location(self, location_name=None):\n", - " if location_name:\n", - " location_name = location_name.title()\n", - " location = list(\n", - " self.geocoder.get_cities_by_name(location_name)[0].values() if self.geocoder.get_cities_by_name(\n", - " location_name) else self.geocoder.get_us_states_by_names(location_name)[\n", - " 0].values() if self.geocoder.get_us_states_by_names(location_name) else\n", - " self.geocoder.get_countries_by_names(location_name)[\n", - " 0].values() if self.geocoder.get_countries_by_names(location_name) else None)[0]\n", - " return location['timezone']\n", - " else:\n", - " return ''\n", - "\n", - " def run(self, location:str=None, day:int=0, hour:int=0, minute:int=0):\n", - " \"\"\"Finds the current time in a specific location.\n", - "\n", - " :param location: The location to find the current time in, should\n", - " be a valid location. Put the place name itself\n", - " like \"rome\", or \"new york\" in the lowercase.\n", - " :type location: str\n", - "\n", - " :param day: The offset in days from the current date.\n", - " Use positive integers for future dates (e.g., day=1 for tomorrow),\n", - " negative integers for past dates (e.g., day=-1 for yesterday),\n", - " and 0 for the current date.\n", - " :type day: int\n", - "\n", - " :param hour: The offset in hours from the current time.\n", - " Use positive integers for future times (e.g., hour=1 for one hour ahead),\n", - " negative integers for past times (e.g., hour=-1 for one hour ago),\n", - " and 0 to maintain the current hour.\n", - " :type hour: int\n", - "\n", - " :param minute: The offset in minutes from the current time.\n", - " Use positive integers for future minutes (e.g., minute=20 for twenty minutes ahead),\n", - " negative integers for past minutes (e.g., minute=-20 for twenty minutes ago),\n", - " and 0 to maintain the current minute.\n", - " :type minute: int\n", - "\n", - " :return: The time in the specified location.\"\"\"\n", - " timezone = self.geocode_location(location)\n", - " if timezone:\n", - " tz = pytz.timezone(timezone)\n", - " else:\n", - " tz = None\n", - "\n", - " current_time = datetime.datetime.now(tz) + datetime.timedelta(days=day)\n", - "\n", - " # Adding hours and minutes to the current time\n", - " current_time += datetime.timedelta(hours=hour, minutes=minute)\n", - "\n", - " # Format the date and time as required\n", - " formatted_time = current_time.strftime(\"%Y-%m-%d %H:%M\")\n", - "\n", - " return formatted_time\n", - "\n", - "s = Skill()\n", - "out = s.rl('time in berlin')\n", - "print(s.run(**out.function_call))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "semantic_router_3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.4" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} From 75ee1b7ddb2d80aba0b57a83bb7564cad78fc5eb Mon Sep 17 00:00:00 2001 From: Siraj R Aizlewood Date: Mon, 13 May 2024 18:07:14 +0400 Subject: [PATCH 06/10] Linting. --- semantic_router/llms/base.py | 19 ++++++++++++------- tests/unit/llms/test_llm_base.py | 24 +++++++++++++++++------- 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py index 639d3c1..604d8ad 100644 --- a/semantic_router/llms/base.py +++ b/semantic_router/llms/base.py @@ -18,26 +18,31 @@ def __init__(self, name: str, **kwargs): def __call__(self, messages: List[Message]) -> Optional[str]: raise NotImplementedError("Subclasses must implement this method") - - def _check_for_mandatory_inputs(self, inputs: dict[str, Any], mandatory_params: List[str]) -> bool: + def _check_for_mandatory_inputs( + self, inputs: dict[str, Any], mandatory_params: List[str] + ) -> bool: """Check for mandatory parameters in inputs""" for name in mandatory_params: if name not in inputs: logger.error(f"Mandatory input {name} missing from query") return False return True - - def _check_for_extra_inputs(self, inputs: dict[str, Any], all_params: List[str]) -> bool: + + def _check_for_extra_inputs( + self, inputs: dict[str, Any], all_params: List[str] + ) -> bool: """Check for extra parameters not defined in the signature""" input_keys = set(inputs.keys()) param_keys = set(all_params) if not input_keys.issubset(param_keys): extra_keys = input_keys - param_keys - logger.error(f"Extra inputs provided that are not in the signature: {extra_keys}") + logger.error( + f"Extra inputs provided that are not in the signature: {extra_keys}" + ) return False return True - + def _is_valid_inputs( self, inputs: dict[str, Any], function_schema: dict[str, Any] ) -> bool: @@ -52,7 +57,7 @@ def _is_valid_inputs( for info in param_info: parts = info.split("=") name_type_pair = parts[0].strip() - if ':' in name_type_pair: + if ":" in name_type_pair: name, _ = name_type_pair.split(":") else: name = name_type_pair diff --git a/tests/unit/llms/test_llm_base.py b/tests/unit/llms/test_llm_base.py index 5e9bbe9..7c8dbf3 100644 --- a/tests/unit/llms/test_llm_base.py +++ b/tests/unit/llms/test_llm_base.py @@ -8,15 +8,15 @@ class TestBaseLLM: @pytest.fixture def base_llm(self): return BaseLLM(name="TestLLM") - + @pytest.fixture def mixed_function_schema(self): return { "name": "test_function", "description": "A test function with mixed mandatory and optional parameters.", - "signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')" + "signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')", } - + @pytest.fixture def mandatory_params(self): return ["param1", "param2"] @@ -87,13 +87,17 @@ def test_base_llm_extract_function_inputs_no_output(self, base_llm, mocker): test_query = "What time is it in America/New_York?" base_llm.extract_function_inputs(test_schema, test_query) - def test_mandatory_args_only(self, base_llm, mixed_function_schema): inputs = {"mandatory1": "value1", "mandatory2": 42} assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == True def test_all_args_provided(self, base_llm, mixed_function_schema): - inputs = {"mandatory1": "value1", "mandatory2": 42, "optional1": "opt1", "optional2": "opt2"} + inputs = { + "mandatory1": "value1", + "mandatory2": 42, + "optional1": "opt1", + "optional2": "opt2", + } assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == True def test_missing_mandatory_arg(self, base_llm, mixed_function_schema): @@ -101,7 +105,13 @@ def test_missing_mandatory_arg(self, base_llm, mixed_function_schema): assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == False def test_extra_arg_provided(self, base_llm, mixed_function_schema): - inputs = {"mandatory1": "value1", "mandatory2": 42, "optional1": "opt1", "optional2": "opt2", "extra": "value"} + inputs = { + "mandatory1": "value1", + "mandatory2": 42, + "optional1": "opt1", + "optional2": "opt2", + "extra": "value", + } assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == False def test_check_for_mandatory_inputs_all_present(self, base_llm, mandatory_params): @@ -118,4 +128,4 @@ def test_check_for_extra_inputs_no_extras(self, base_llm, all_params): def test_check_for_extra_inputs_with_extras(self, base_llm, all_params): inputs = {"param1": "value1", "param2": "value2", "extra_param": "extra"} - assert base_llm._check_for_extra_inputs(inputs, all_params) == False \ No newline at end of file + assert base_llm._check_for_extra_inputs(inputs, all_params) == False From bfeeb1d90eea22c607065f7b47fff938078981f1 Mon Sep 17 00:00:00 2001 From: Siraj R Aizlewood Date: Mon, 13 May 2024 18:08:34 +0400 Subject: [PATCH 07/10] Linting. --- tests/unit/llms/test_llm_base.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/unit/llms/test_llm_base.py b/tests/unit/llms/test_llm_base.py index 7c8dbf3..a2ec618 100644 --- a/tests/unit/llms/test_llm_base.py +++ b/tests/unit/llms/test_llm_base.py @@ -89,7 +89,7 @@ def test_base_llm_extract_function_inputs_no_output(self, base_llm, mocker): def test_mandatory_args_only(self, base_llm, mixed_function_schema): inputs = {"mandatory1": "value1", "mandatory2": 42} - assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == True + assert base_llm._is_valid_inputs(inputs, mixed_function_schema) # True is implied def test_all_args_provided(self, base_llm, mixed_function_schema): inputs = { @@ -98,11 +98,11 @@ def test_all_args_provided(self, base_llm, mixed_function_schema): "optional1": "opt1", "optional2": "opt2", } - assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == True + assert base_llm._is_valid_inputs(inputs, mixed_function_schema) # True is implied def test_missing_mandatory_arg(self, base_llm, mixed_function_schema): inputs = {"mandatory1": "value1", "optional1": "opt1", "optional2": "opt2"} - assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == False + assert not base_llm._is_valid_inputs(inputs, mixed_function_schema) def test_extra_arg_provided(self, base_llm, mixed_function_schema): inputs = { @@ -112,20 +112,20 @@ def test_extra_arg_provided(self, base_llm, mixed_function_schema): "optional2": "opt2", "extra": "value", } - assert base_llm._is_valid_inputs(inputs, mixed_function_schema) == False + assert not base_llm._is_valid_inputs(inputs, mixed_function_schema) def test_check_for_mandatory_inputs_all_present(self, base_llm, mandatory_params): inputs = {"param1": "value1", "param2": "value2"} - assert base_llm._check_for_mandatory_inputs(inputs, mandatory_params) == True + assert base_llm._check_for_mandatory_inputs(inputs, mandatory_params) # True is implied def test_check_for_mandatory_inputs_missing_one(self, base_llm, mandatory_params): inputs = {"param1": "value1"} - assert base_llm._check_for_mandatory_inputs(inputs, mandatory_params) == False + assert not base_llm._check_for_mandatory_inputs(inputs, mandatory_params) def test_check_for_extra_inputs_no_extras(self, base_llm, all_params): inputs = {"param1": "value1", "param2": "value2"} - assert base_llm._check_for_extra_inputs(inputs, all_params) == True + assert base_llm._check_for_extra_inputs(inputs, all_params) # True is implied def test_check_for_extra_inputs_with_extras(self, base_llm, all_params): inputs = {"param1": "value1", "param2": "value2", "extra_param": "extra"} - assert base_llm._check_for_extra_inputs(inputs, all_params) == False + assert not base_llm._check_for_extra_inputs(inputs, all_params) From 21dd2fdf31c1fe46df6239ceeb415f891a336ea4 Mon Sep 17 00:00:00 2001 From: Siraj R Aizlewood Date: Mon, 13 May 2024 18:10:44 +0400 Subject: [PATCH 08/10] Linting. --- tests/unit/llms/test_llm_base.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/unit/llms/test_llm_base.py b/tests/unit/llms/test_llm_base.py index a2ec618..944a0f8 100644 --- a/tests/unit/llms/test_llm_base.py +++ b/tests/unit/llms/test_llm_base.py @@ -89,7 +89,9 @@ def test_base_llm_extract_function_inputs_no_output(self, base_llm, mocker): def test_mandatory_args_only(self, base_llm, mixed_function_schema): inputs = {"mandatory1": "value1", "mandatory2": 42} - assert base_llm._is_valid_inputs(inputs, mixed_function_schema) # True is implied + assert base_llm._is_valid_inputs( + inputs, mixed_function_schema + ) # True is implied def test_all_args_provided(self, base_llm, mixed_function_schema): inputs = { @@ -98,7 +100,9 @@ def test_all_args_provided(self, base_llm, mixed_function_schema): "optional1": "opt1", "optional2": "opt2", } - assert base_llm._is_valid_inputs(inputs, mixed_function_schema) # True is implied + assert base_llm._is_valid_inputs( + inputs, mixed_function_schema + ) # True is implied def test_missing_mandatory_arg(self, base_llm, mixed_function_schema): inputs = {"mandatory1": "value1", "optional1": "opt1", "optional2": "opt2"} @@ -116,7 +120,9 @@ def test_extra_arg_provided(self, base_llm, mixed_function_schema): def test_check_for_mandatory_inputs_all_present(self, base_llm, mandatory_params): inputs = {"param1": "value1", "param2": "value2"} - assert base_llm._check_for_mandatory_inputs(inputs, mandatory_params) # True is implied + assert base_llm._check_for_mandatory_inputs( + inputs, mandatory_params + ) # True is implied def test_check_for_mandatory_inputs_missing_one(self, base_llm, mandatory_params): inputs = {"param1": "value1"} From b040c168c7d825085c23896dbe519fd94f54da73 Mon Sep 17 00:00:00 2001 From: Siraj R Aizlewood Date: Mon, 13 May 2024 21:30:20 +0400 Subject: [PATCH 09/10] Fixed pytests following resolution of conflicts with main. --- tests/unit/llms/test_llm_base.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/unit/llms/test_llm_base.py b/tests/unit/llms/test_llm_base.py index 38d2ee7..a52a341 100644 --- a/tests/unit/llms/test_llm_base.py +++ b/tests/unit/llms/test_llm_base.py @@ -11,11 +11,11 @@ def base_llm(self): @pytest.fixture def mixed_function_schema(self): - return { + return [{ "name": "test_function", "description": "A test function with mixed mandatory and optional parameters.", "signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')", - } + }] @pytest.fixture def mandatory_params(self): @@ -90,34 +90,34 @@ def test_base_llm_extract_function_inputs_no_output(self, base_llm, mocker): base_llm.extract_function_inputs(test_schema, test_query) def test_mandatory_args_only(self, base_llm, mixed_function_schema): - inputs = {"mandatory1": "value1", "mandatory2": 42} + inputs = [{"mandatory1": "value1", "mandatory2": 42}] assert base_llm._is_valid_inputs( inputs, mixed_function_schema ) # True is implied def test_all_args_provided(self, base_llm, mixed_function_schema): - inputs = { + inputs = [{ "mandatory1": "value1", "mandatory2": 42, "optional1": "opt1", "optional2": "opt2", - } + }] assert base_llm._is_valid_inputs( inputs, mixed_function_schema ) # True is implied def test_missing_mandatory_arg(self, base_llm, mixed_function_schema): - inputs = {"mandatory1": "value1", "optional1": "opt1", "optional2": "opt2"} + inputs = [{"mandatory1": "value1", "optional1": "opt1", "optional2": "opt2"}] assert not base_llm._is_valid_inputs(inputs, mixed_function_schema) def test_extra_arg_provided(self, base_llm, mixed_function_schema): - inputs = { + inputs = [{ "mandatory1": "value1", "mandatory2": 42, "optional1": "opt1", "optional2": "opt2", "extra": "value", - } + }] assert not base_llm._is_valid_inputs(inputs, mixed_function_schema) def test_check_for_mandatory_inputs_all_present(self, base_llm, mandatory_params): @@ -201,11 +201,11 @@ def test_validate_single_function_inputs_exception_handling(self, base_llm, mock mocked_logger = mocker.patch("semantic_router.utils.logger.logger.error") # Prepare inputs and a malformed function schema - test_inputs = {"timezone": "America/New_York"} + test_inputs = {"timezone": "America/New_York"} malformed_function_schema = { "name": "get_time", "description": "Finds the current time in a specific timezone.", - "signature": "(timezone str)", # Malformed signature missing colon + "signiture": "(timezone: str)", # Malformed key name "output": "", } @@ -218,7 +218,7 @@ def test_validate_single_function_inputs_exception_handling(self, base_llm, mock assert not result, "Method should return False when an exception occurs" # Check that the appropriate error message was logged - expected_error_message = "Single input validation error: list index out of range" # Adjust based on the actual exception message + expected_error_message = "Single input validation error: 'signature'" # Adjust based on the actual exception message mocked_logger.assert_called_once_with(expected_error_message) def test_extract_parameter_info_valid(self, base_llm): From c6e9f85f60167e66fc4aa5617ed37d3382b730ba Mon Sep 17 00:00:00 2001 From: Siraj R Aizlewood Date: Mon, 13 May 2024 21:34:35 +0400 Subject: [PATCH 10/10] Linting. --- tests/unit/llms/test_llm_base.py | 44 ++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/tests/unit/llms/test_llm_base.py b/tests/unit/llms/test_llm_base.py index a52a341..3699ded 100644 --- a/tests/unit/llms/test_llm_base.py +++ b/tests/unit/llms/test_llm_base.py @@ -11,11 +11,13 @@ def base_llm(self): @pytest.fixture def mixed_function_schema(self): - return [{ - "name": "test_function", - "description": "A test function with mixed mandatory and optional parameters.", - "signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')", - }] + return [ + { + "name": "test_function", + "description": "A test function with mixed mandatory and optional parameters.", + "signature": "(mandatory1, mandatory2: int, optional1=None, optional2: str = 'default')", + } + ] @pytest.fixture def mandatory_params(self): @@ -96,12 +98,14 @@ def test_mandatory_args_only(self, base_llm, mixed_function_schema): ) # True is implied def test_all_args_provided(self, base_llm, mixed_function_schema): - inputs = [{ - "mandatory1": "value1", - "mandatory2": 42, - "optional1": "opt1", - "optional2": "opt2", - }] + inputs = [ + { + "mandatory1": "value1", + "mandatory2": 42, + "optional1": "opt1", + "optional2": "opt2", + } + ] assert base_llm._is_valid_inputs( inputs, mixed_function_schema ) # True is implied @@ -111,13 +115,15 @@ def test_missing_mandatory_arg(self, base_llm, mixed_function_schema): assert not base_llm._is_valid_inputs(inputs, mixed_function_schema) def test_extra_arg_provided(self, base_llm, mixed_function_schema): - inputs = [{ - "mandatory1": "value1", - "mandatory2": 42, - "optional1": "opt1", - "optional2": "opt2", - "extra": "value", - }] + inputs = [ + { + "mandatory1": "value1", + "mandatory2": 42, + "optional1": "opt1", + "optional2": "opt2", + "extra": "value", + } + ] assert not base_llm._is_valid_inputs(inputs, mixed_function_schema) def test_check_for_mandatory_inputs_all_present(self, base_llm, mandatory_params): @@ -201,7 +207,7 @@ def test_validate_single_function_inputs_exception_handling(self, base_llm, mock mocked_logger = mocker.patch("semantic_router.utils.logger.logger.error") # Prepare inputs and a malformed function schema - test_inputs = {"timezone": "America/New_York"} + test_inputs = {"timezone": "America/New_York"} malformed_function_schema = { "name": "get_time", "description": "Finds the current time in a specific timezone.",