From a97c6f436c2411a814d9686293053db4acf4eef3 Mon Sep 17 00:00:00 2001 From: emrgnt-cmplxty <68796651+emrgnt-cmplxty@users.noreply.github.com> Date: Sun, 17 Dec 2023 12:37:01 -0500 Subject: [PATCH] Feature/add back scripts (#128) * Add back sciphi * cleanup scripts * expand scripts --- docs/source/index.rst | 2 +- docs/source/setup/quickstart.rst | 52 +++-- pyproject.toml | 3 +- .../config/prompts/answer_question.yaml | 25 +++ .../config/prompts/question_and_answer.yaml | 27 --- .../rag_science_evaluator.py | 35 ++-- synthesizer/interface/__init__.py | 2 + synthesizer/interface/llm/sciphi_interface.py | 60 ++++++ synthesizer/interface/rag/agent_search.py | 1 + synthesizer/llm/__init__.py | 6 +- synthesizer/llm/base.py | 25 --- synthesizer/llm/config_manager.py | 4 + synthesizer/llm/models/openai_llm.py | 3 - synthesizer/llm/models/sciphi_llm.py | 123 ++++++++++++ synthesizer/scripts/data_augmenter.py | 186 ++++++++++++++++++ synthesizer/scripts/rag_harness.py | 92 +++++++++ 16 files changed, 554 insertions(+), 92 deletions(-) create mode 100644 synthesizer/config/prompts/answer_question.yaml delete mode 100644 synthesizer/config/prompts/question_and_answer.yaml create mode 100644 synthesizer/interface/llm/sciphi_interface.py create mode 100644 synthesizer/llm/models/sciphi_llm.py create mode 100644 synthesizer/scripts/data_augmenter.py create mode 100644 synthesizer/scripts/rag_harness.py diff --git a/docs/source/index.rst b/docs/source/index.rst index b0f2141..01c3a59 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -23,7 +23,7 @@ Welcome to Synthesizer 💡 A multi-purpose LLM framework for inference, RAG, and data creation. -Looking for the AgentSearch documentation? `Click Here `_. +Looking for the AgentSearch documentation? `Click Here `_. With Synthesizer, you can: diff --git a/docs/source/setup/quickstart.rst b/docs/source/setup/quickstart.rst index 4f78bdd..9a48cb6 100644 --- a/docs/source/setup/quickstart.rst +++ b/docs/source/setup/quickstart.rst @@ -7,9 +7,7 @@ Welcome to the Synthesizer quickstart guide! Synthesizer, or ΨΦ, is your porta This guide will introduce you to: -- Generating data tailored to your needs. - Using the RAG provider interface. -- Creating RAG-enhanced textbooks. - Evaluating your RAG pipeline. @@ -26,10 +24,43 @@ Before you start, ensure you've installed Synthesizer: For additional details, refer to the `installation guide `_. -Instantiate Your LLM and RAG Provider +Using Synthesizer +----------------- + +1. **Generate synthetic question answer pairs** + + .. code-block:: bash + + export SCIPHI_API_KEY=MY_SCIPHI_API_KEY + python -m synthesizer.scripts.data_augmenter run --dataset="wiki_qa" + + .. code-block:: bash + + tail augmented_output/config_name_eq_answer_question__dataset_name_eq_wiki_qa.jsonl + { "formatted_prompt": "... ### Question:\nwhat country did wine originate in\n\n### Input:\n1. URL: https://en.wikipedia.org/wiki/History%20of%20wine (Score: 0.85)\nTitle:History of wine....", + { "completion": Wine originated in the South Caucasus, which is now part of modern-day Armenia ... + +2. **Evaluate RAG pipeline performance** + + .. code-block:: bash + + export SCIPHI_API_KEY=MY_SCIPHI_API_KEY + python -m synthesizer.scripts.rag_harness --rag_provider="agent-search" --llm_provider_name="sciphi" --n_samples=25 + + .. code-block:: bash + ... + INFO:__main__:Now generating completions... + 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:29<00:00, 3.40it/s] + INFO:__main__:Final Accuracy=0.42 + +.. note:: + This is a basic introduction to Synthesizer. Check back later for more detailed and intricate documentation that delves deeper into advanced features and customization options. + + +Developing with Synthesizer ------------------------------------- -Here's how you can use Synthesizer to quickly set up and retrieve chat completions, without diving deep into intricate configurations: +Here's how you can use Synthesizer to quickly set up and RAG augmented generation, without diving deep into intricate configurations: .. code-block:: python @@ -45,8 +76,7 @@ Here's how you can use Synthesizer to quickly set up and retrieve chat completio # RAG Provider Settings rag_interface = RAGInterfaceManager.get_interface_from_args( - RAGProviderName(rag_provider_name), - api_base=rag_api_base, + RAGProviderName("agent-search"), limit_hierarchical_url_results=rag_limit_hierarchical_url_results, limit_final_pagerank_results=rag_limit_final_pagerank_results, ) @@ -65,13 +95,7 @@ Here's how you can use Synthesizer to quickly set up and retrieve chat completio # other generation params here ... ) - formatted_prompt = rag_prompt.format(rag_context=rag_context) + formatted_prompt = raw_prompt.format(rag_context=rag_context) completion = llm_interface.get_completion( formatted_prompt, generation_config - ) - print(completion) - - ### Output: - # Fermat's Last Theorem was proven by British mathematician Andrew Wiles in 1994 (Wikipedia). Wiles's proof was based on a special case of the modularity theorem for elliptic curves, along with Ribet's theorem (Wikipedia). The modularity theorem and Fermat's Last Theorem were previously considered inaccessible to proof by contemporaneous mathematicians (Wikipedia). However, Wiles's proof provided a solution to Fermat's Last Theorem, which had remained unproved for over 300 years (PlanetMath). Wiles's proof is widely accepted and has been recognized with numerous awards, including the Abel Prize in 2016 (Wikipedia). - - # It is important to note that Wiles's proof of Fermat's Last Theorem is a mathematical proof and not related to the science fiction novel "The Last Theorem" by Arthur C. Clarke and Frederik Pohl (Wikipedia). The novel is a work of fiction and does not provide a real mathematical proof for Fermat's Last Theorem (Wikipedia). Additionally, there have been other attempts to prove Fermat's Last Theorem, such as Sophie Germain's approach, but Wiles's proof is the most widely accepted and recognized (Math Stack Exchange). + ) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 4f6d752..6bc6d2a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ authors = ["Owen Colegrove "] license = "Apache-2.0" readme = "README.md" name = 'sciphi-synthesizer' -version = '1.0.0' +version = '1.0.1' packages = [ { include = "synthesizer" } ] @@ -24,7 +24,6 @@ fire = "^0.5.0" openai = { version = "0.27.8" } pyyaml = "^6.0.1" retrying = "^1.3.4" -tqdm = "^4.66.1" # Begin optional dependencies diff --git a/synthesizer/config/prompts/answer_question.yaml b/synthesizer/config/prompts/answer_question.yaml new file mode 100644 index 0000000..d5439fe --- /dev/null +++ b/synthesizer/config/prompts/answer_question.yaml @@ -0,0 +1,25 @@ +raw_text: | + Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. + + ### Instruction + Given the context provided after this instruction, extract and return an answer to the given questin. + + ### Question: + {question} + + ### Input: + {rag_context} + + ### Response + +dataset_supplied_inputs_map: + question: question + +user_supplied_inputs: [rag_context] + +default_dataset_name: "wiki_qa" + +output_format: "jsonl" + +default_user_inputs_map: + rag_context: "" diff --git a/synthesizer/config/prompts/question_and_answer.yaml b/synthesizer/config/prompts/question_and_answer.yaml deleted file mode 100644 index c0b04b9..0000000 --- a/synthesizer/config/prompts/question_and_answer.yaml +++ /dev/null @@ -1,27 +0,0 @@ -raw_text: | - Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. - - ### Instruction - Given the context provided after this instruction, extract and return question and answer pairs in JSONL format. The expected format is: - - ```jsonl - [{{\"question\": X, \"answer\": Y}}, ...] - ``` - - ### Input: - {context} - - {user_supplied_suffix} - ### Response - -dataset_supplied_inputs_map: - text: context - -user_supplied_inputs: [user_supplied_suffix] - -default_dataset_name: "ContextualAI/tiny-wiki100-chunks" - -output_format: "jsonl" - -default_user_inputs_map: - user_supplied_suffix: "_Note_ - Ensure all question and answer pairs are implied by the context above." diff --git a/synthesizer/eval/rag/science_multiple_choice/rag_science_evaluator.py b/synthesizer/eval/rag/science_multiple_choice/rag_science_evaluator.py index 7ca367b..e40e7aa 100644 --- a/synthesizer/eval/rag/science_multiple_choice/rag_science_evaluator.py +++ b/synthesizer/eval/rag/science_multiple_choice/rag_science_evaluator.py @@ -30,7 +30,7 @@ E: {E} #### Wikipedia Context: - {wiki_context} + {search_context} #### Answer: """ @@ -70,17 +70,16 @@ def __init__( "evals", f"{ScienceMultipleChoiceEvaluator.NAME.lower().replace(' ', '_')}.csv", ) - ) + ).head(n_samples) def initialize_prompts(self): contexts = ( - self.rag_interface.get_rag_context( - list( - self.evals[ - ScienceMultipleChoiceEvaluator.PROMPT_FIELD - ].values - ) - ) + [ + self.rag_interface.get_rag_context(prompt) + for prompt in self.evals[ + ScienceMultipleChoiceEvaluator.PROMPT_FIELD + ].values + ] if self.rag_interface else [ScienceMultipleChoiceEvaluator.RAG_DISABLED_RESPONSE] * len(self.evals) @@ -98,7 +97,7 @@ def build_prompt(self, entry: dict, context: str) -> str: + "\n" + SCIENCE_QUESTION_TEMPLATE.format( example_number=self.n_few_shot + 1, - wiki_context=context, + search_context=context, **entry, ) ) @@ -119,8 +118,8 @@ def n_shot_science_template(self) -> str: SCIENCE_QUESTION_TEMPLATE.format( example_number=example_number, prompt=example_prompt, - wiki_context=self.rag_interface.get_rag_context( - [example_prompt] + search_context=self.rag_interface.get_rag_context( + example_prompt ) if self.rag_interface else ScienceMultipleChoiceEvaluator.RAG_DISABLED_RESPONSE, @@ -140,9 +139,9 @@ def n_shot_science_template(self) -> str: SCIENCE_QUESTION_TEMPLATE.format( example_number=example_number, prompt=example_prompt, - wiki_context=self.rag_interface.get_rag_context( - [example_prompt] - )[0] + search_context=self.rag_interface.get_rag_context( + example_prompt + ) if self.rag_interface else ScienceMultipleChoiceEvaluator.RAG_DISABLED_RESPONSE, A="Mitochondria are primarily responsible for protein synthesis using ribosomes.", @@ -161,9 +160,9 @@ def n_shot_science_template(self) -> str: SCIENCE_QUESTION_TEMPLATE.format( example_number=example_number, prompt=example_prompt, - wiki_context=self.rag_interface.get_rag_context( - [example_prompt] - )[0] + search_context=self.rag_interface.get_rag_context( + example_prompt + ) if self.rag_interface else ScienceMultipleChoiceEvaluator.RAG_DISABLED_RESPONSE, A="Oxidation involves the addition of oxygen to a molecule or the loss of electrons from an atom or molecule.", diff --git a/synthesizer/interface/__init__.py b/synthesizer/interface/__init__.py index 40dda45..467da0a 100644 --- a/synthesizer/interface/__init__.py +++ b/synthesizer/interface/__init__.py @@ -9,6 +9,7 @@ HuggingFaceLLMInterface, ) from synthesizer.interface.llm.openai_interface import OpenAILLMInterface +from synthesizer.interface.llm.sciphi_interface import SciPhiLLMInterface from synthesizer.interface.llm.vllm_interface import vLLMInterface from synthesizer.interface.llm_interface_manager import LLMInterfaceManager from synthesizer.interface.rag.agent_search import ( @@ -27,6 +28,7 @@ "AnthropicLLMInterface", "HuggingFaceLLMInterface", "OpenAILLMInterface", + "SciPhiLLMInterface", "vLLMInterface", # RAG "RAGInterfaceManager", diff --git a/synthesizer/interface/llm/sciphi_interface.py b/synthesizer/interface/llm/sciphi_interface.py new file mode 100644 index 0000000..d529236 --- /dev/null +++ b/synthesizer/interface/llm/sciphi_interface.py @@ -0,0 +1,60 @@ +"""A module for interfacing with the SciPhi API""" +import logging + +from synthesizer.interface.base import LLMInterface, LLMProviderName +from synthesizer.interface.llm_interface_manager import llm_interface +from synthesizer.llm import GenerationConfig, SciPhiConfig, SciPhiLLM + +logger = logging.getLogger(__name__) + + +@llm_interface +class SciPhiLLMInterface(LLMInterface): + """A class to interface with the SciPhi API.""" + + provider_name = LLMProviderName.SCIPHI + system_message = "You are a helpful assistant." + + def __init__( + self, + config: SciPhiConfig, + *args, + **kwargs, + ) -> None: + self.config = config + self._model = SciPhiLLM(config) + + def get_completion( + self, prompt: str, generation_config: GenerationConfig + ) -> str: + """Get a completion from the SciPhi API based on the provided prompt.""" + + logger.debug( + f"Getting completion from SciPhi API for model={generation_config.model_name}" + ) + if "instruct" in generation_config.model_name: + return self.model.get_instruct_completion( + prompt, generation_config + ) + else: + return self._model.get_chat_completion( + [ + { + "role": "system", + "content": SciPhiLLMInterface.system_message, + }, + {"role": "user", "content": prompt}, + ], + generation_config, + ) + + def get_chat_completion( + self, conversation: list[dict], generation_config: GenerationConfig + ) -> str: + raise NotImplementedError( + "Chat completion not yet implemented for SciPhi." + ) + + @property + def model(self) -> SciPhiLLM: + return self._model diff --git a/synthesizer/interface/rag/agent_search.py b/synthesizer/interface/rag/agent_search.py index ad02e41..f09a7f5 100644 --- a/synthesizer/interface/rag/agent_search.py +++ b/synthesizer/interface/rag/agent_search.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from agent_search.core import SERPClient + from synthesizer.core import RAGProviderName from synthesizer.interface.base import RAGInterface, RAGProviderConfig from synthesizer.interface.rag_interface_manager import ( diff --git a/synthesizer/llm/__init__.py b/synthesizer/llm/__init__.py index 163660c..f4b2a04 100644 --- a/synthesizer/llm/__init__.py +++ b/synthesizer/llm/__init__.py @@ -1,4 +1,4 @@ -from synthesizer.llm.base import LLM, GenerationConfig, LLMConfig, ModelName +from synthesizer.llm.base import LLM, GenerationConfig, LLMConfig from synthesizer.llm.config_manager import LLMConfigManager from synthesizer.llm.models.anthropic_llm import AnthropicConfig, AnthropicLLM from synthesizer.llm.models.hugging_face_llm import ( @@ -6,12 +6,12 @@ HuggingFaceLLM, ) from synthesizer.llm.models.openai_llm import OpenAIConfig, OpenAILLM +from synthesizer.llm.models.sciphi_llm import SciPhiConfig, SciPhiLLM from synthesizer.llm.models.vllm_llm import vLLM, vLLMConfig __all__ = [ # Base "LLM", - "ModelName", "LLMConfig", "LLMConfigManager", "GenerationConfig", @@ -22,6 +22,8 @@ "HuggingFaceLLM", "OpenAIConfig", "OpenAILLM", + "SciPhiConfig", + "SciPhiLLM", "vLLMConfig", "vLLM", ] diff --git a/synthesizer/llm/base.py b/synthesizer/llm/base.py index f286969..0d6669c 100644 --- a/synthesizer/llm/base.py +++ b/synthesizer/llm/base.py @@ -1,36 +1,11 @@ """Base classes for language model providers.""" from abc import ABC, abstractmethod from dataclasses import dataclass, field, fields -from enum import Enum from typing import Optional from synthesizer.core import LLMProviderName -class ModelName(Enum): - """An enum to hold the names of supported models.""" - - # OpenAI Models - - ## GPT-3.5 - GPT_3p5_TURBO_0301 = "gpt-3.5-turbo-0301" - GPT_3p5_TURBO_0613 = "gpt-3.5-turbo-0613" - GPT_3p5_TURBO_16k_0613 = "gpt-3.5-turbo-16k-0613" - GPT_3p5_TURBO = "gpt-3.5-turbo" - GPT_3p5_TURBO_INSTRUCT = "gpt-3.5-turbo-instruct" - - ## GPT-4 - GPT_4_0314 = "gpt-4-0314" - GPT_4_0613 = "gpt-4-0613" - GPT_4 = "gpt-4" - GPT_4_32k = "gpt-4-32k" - - # Anthropic Models - - CLAUDE_INSTANT_1 = "claude-instant-1" - CLAUDE_2 = "claude-2" - - @dataclass class LLMConfig(ABC): provider_name: LLMProviderName diff --git a/synthesizer/llm/config_manager.py b/synthesizer/llm/config_manager.py index cff0c00..031acb0 100644 --- a/synthesizer/llm/config_manager.py +++ b/synthesizer/llm/config_manager.py @@ -15,6 +15,10 @@ def get_config_for_provider( llm_provider_name: LLMProviderName, ) -> Type[LLMConfig]: """Get the configuration class for a given model.""" + print( + "LLMConfigManager.config_registry = ", + LLMConfigManager.config_registry, + ) config_class = LLMConfigManager.config_registry.get(llm_provider_name) if not config_class: diff --git a/synthesizer/llm/models/openai_llm.py b/synthesizer/llm/models/openai_llm.py index 7abf0ec..7227bdf 100644 --- a/synthesizer/llm/models/openai_llm.py +++ b/synthesizer/llm/models/openai_llm.py @@ -1,7 +1,4 @@ """A module for creating OpenAI model abstractions.""" -# TODO - Will we face issues if a user attempts to access -# OpenAI + vLLM / SciPhi remote in the same session? -# My guess is yes, but need to test + workaround. from dataclasses import dataclass from synthesizer.core import LLMProviderName diff --git a/synthesizer/llm/models/sciphi_llm.py b/synthesizer/llm/models/sciphi_llm.py new file mode 100644 index 0000000..d6fb36b --- /dev/null +++ b/synthesizer/llm/models/sciphi_llm.py @@ -0,0 +1,123 @@ +"""A module for creating SciPhi model abstractions.""" +import os +from dataclasses import dataclass + +from synthesizer.core import LLMProviderName +from synthesizer.llm.base import LLM, GenerationConfig, LLMConfig +from synthesizer.llm.config_manager import model_config + + +@model_config +@dataclass +class SciPhiConfig(LLMConfig): + """Configuration for SciPhi models.""" + + # Base + provider_name: LLMProviderName = LLMProviderName.SCIPHI + + api_base: str = "https://api.sciphi.ai/v1" + + +class SciPhiLLM(LLM): + """A concrete class for creating SciPhi models.""" + + PROMPT_MEASUREMENT_PREFIX = ( + "" + ) + + def __init__( + self, + config: SciPhiConfig, + *args, + **kwargs, + ) -> None: + super().__init__() + self.config: SciPhiConfig = config + + try: + import openai + except ImportError: + raise ImportError( + "Please install the synthesizer package before attempting to run with an SciPhi model. This can be accomplished via `pip install openai`." + ) + + sciphi_api_key = os.getenv("SCIPHI_API_KEY") + openai.api_base = self.config.api_base + if not sciphi_api_key: + raise ValueError( + "Please set the environment variable SCIPHI_API_KEY before attempting to run with the SciPhi provider." + ) + openai.api_key = sciphi_api_key + # set the config here, again, for typing purposes + if not isinstance(self.config, SciPhiConfig): + raise ValueError( + "The provided config must be an instance of SciPhiConfig." + ) + + def get_chat_completion( + self, + messages: list[dict[str, str]], + generation_config: GenerationConfig, + ) -> str: + """Get a completion from the SciPhi API based on the provided messages.""" + import openai + + # Create a dictionary with the default arguments + args = self._get_base_args( + generation_config, + SciPhiLLM.PROMPT_MEASUREMENT_PREFIX + + f"{SciPhiLLM.PROMPT_MEASUREMENT_PREFIX}\n\n".join( + [m["content"] for m in messages] + ), + ) + + args["messages"] = messages + + # Conditionally add the 'functions' argument if it's not None + if generation_config.functions is not None: + args["functions"] = generation_config.functions + + # Create the chat completion + response = openai.ChatCompletion.create(**args) + return response.choices[0].message["content"] + + def get_instruct_completion( + self, prompt: str, generation_config: GenerationConfig + ) -> str: + """Get an instruction completion from the SciPhi API based on the provided prompt.""" + import openai + + args = self._get_base_args(generation_config, prompt) + + args["prompt"] = prompt + + # Create the instruction completion + response = openai.Completion.create(**args) + return response.choices[0].text + + def _get_base_args( + self, + generation_config: GenerationConfig, + prompt=None, + ) -> dict: + """Get the base arguments for the SciPhi API.""" + + args = { + "model": generation_config.model_name, + "temperature": generation_config.temperature, + "top_p": generation_config.top_p, + "stream": generation_config.do_stream, + # TODO - We need to cap this to avoid potential errors when exceed max allowable context + "max_tokens": generation_config.max_tokens_to_sample, + } + + # Check if were using SciPhi api with re-routed base + if self.config.provider_name in [ + LLMProviderName.VLLM, + LLMProviderName.SCIPHI, + ]: + args["top_k"] = generation_config.top_k + args["skip_special_tokens"] = generation_config.skip_special_tokens + args["stop"] = generation_config.stop_token + + return args diff --git a/synthesizer/scripts/data_augmenter.py b/synthesizer/scripts/data_augmenter.py new file mode 100644 index 0000000..e91bbb2 --- /dev/null +++ b/synthesizer/scripts/data_augmenter.py @@ -0,0 +1,186 @@ +"""A module which facilitates data augmentation.""" "" +import json +import logging +import os +from typing import Optional + +import dotenv +import fire +import yaml +from datasets import Dataset, load_dataset +from tqdm import tqdm + +from synthesizer.core import ( + JsonlDataWriter, + LLMProviderName, + Prompt, + RAGProviderName, +) +from synthesizer.core.utils import get_config_dir +from synthesizer.interface import LLMInterfaceManager, RAGInterfaceManager +from synthesizer.llm import GenerationConfig + +# Load environment variables +dotenv.load_dotenv() + +# Logging setup +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def get_output_path(output_dir: str, output_name: str) -> str: + """Returns the complete path where the output will be saved.""" + return ( + os.path.join(output_dir, output_name) + if os.path.isabs(output_dir) + else os.path.join(os.getcwd(), output_dir, output_name) + ) + + +def ensure_directory_exists(filepath: str): + """Ensures the directory of the given filepath exists.""" + directory = os.path.dirname(filepath) + if not os.path.exists(directory): + os.makedirs(directory) + + +SHUFFLE_SEED = 42 + + +def main( + # Run Settings + output_dir="augmented_output", + output_name=None, + shuffle: bool = True, + n_samples: int = 100, + # LLM Settings + llm_provider_name="sciphi", + llm_model_name="SciPhi/SciPhi-Mistral-7B-32k", + llm_max_tokens_to_sample=128, + llm_temperature=0.1, + llm_top_k=100, + llm_api_base: Optional[str] = "https://api.sciphi.ai/v1", + llm_skip_special_tokens: bool = False, + # RAG Settings + rag_enabled=True, + rag_provider_name="agent-search", + rag_api_base="https://api.sciphi.ai", + # Dataset Settings + dataset_name: Optional[str] = None, + dataset_split: str = "train", + # Prompt Settings + config_name: Optional[str] = "answer_question", + config_path: Optional[str] = None, + ## Additional user inputs (like user_supplied_suffix) + ## can be passed in through Fire as kwargs, e.g. + ## --user_supplied_inputs "{'your_input': 'Your Input Value'}" + user_supplied_inputs: Optional[dict] = None, + **kwargs, +): + """Run the data augmenter.""" + + # Validate configurations + if config_name and config_path: + raise ValueError( + "Must provide either a config name or a config path, but not both." + ) + + # Initialize configuration and output path + config_path_from_name_or_default = config_path or os.path.join( + get_config_dir(), "prompts", f"{config_name}.yaml" + ) + with open(config_path_from_name_or_default, "r") as yaml_file: + config = yaml.safe_load(yaml_file) + + # Set default dataset name if not provided + dataset_name = dataset_name or config["default_dataset_name"] + + # Set the default user supplied inputs + user_supplied_inputs = ( + user_supplied_inputs or config["default_user_inputs_map"] + ) + + # Set up output settings + + ## Get the configuration name and the output file name + config_name = config_path_from_name_or_default.split(os.path.sep)[ + -1 + ].replace(".yaml", "") + output_name = ( + output_name + or f"config_name_eq_{config_name}__dataset_name_eq_{dataset_name.replace('/','_')}.jsonl" + ) + logger.info( + f"Augmenting dataset {dataset_name} with prompt {config_name}." + ) + + logger.info(f"Saving the output to {output_name}.") + output_path = get_output_path(output_dir, output_name) + + ## Ensure output directory exists + ensure_directory_exists(output_path) + writer = JsonlDataWriter(output_path) + + prompt = Prompt(config=config) + rag_interface = ( + RAGInterfaceManager.get_interface_from_args( + RAGProviderName(rag_provider_name), + api_base=rag_api_base or llm_api_base, + ) + if rag_enabled + else None + ) + + llm_interface = LLMInterfaceManager.get_interface_from_args( + LLMProviderName(llm_provider_name), + api_base=llm_api_base, + # Currently only consumed by SciPhi + rag_interface=rag_interface, + # Consumed by single-load providers + model_name=llm_model_name, + ) + + llm_generation_config = GenerationConfig( + temperature=llm_temperature, + top_k=llm_top_k, + max_tokens_to_sample=llm_max_tokens_to_sample, + model_name=llm_model_name, + ) + # Prepare the samples + dataset: Dataset = load_dataset(dataset_name)[dataset_split] + + if shuffle: + dataset = dataset.shuffle(seed=SHUFFLE_SEED) + n_samples = min(n_samples, len(dataset)) + samples = dataset.select(range(n_samples)) + + logger.info(f"Now running over {n_samples} samples.") + for entry in tqdm(samples): + user_supplied_inputs["rag_context"] = rag_interface.get_rag_context( + entry["question"] + ) + formatted_prompt = prompt.format( + dataset_entry=entry, **user_supplied_inputs + ) + completion = llm_interface.get_completion( + formatted_prompt, llm_generation_config + ) + if config["output_format"] == "jsonl": + try: + data = { + "formatted_prompt": formatted_prompt, + "completion": completion, + } + writer.write([data]) + except json.decoder.JSONDecodeError: + logger.error( + f"Failed to decode JSON response from LLM: {completion}" + ) + else: + writer.write( + [{"prompt": formatted_prompt, "completion": completion}] + ) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/synthesizer/scripts/rag_harness.py b/synthesizer/scripts/rag_harness.py new file mode 100644 index 0000000..bd13ccc --- /dev/null +++ b/synthesizer/scripts/rag_harness.py @@ -0,0 +1,92 @@ +import logging +from typing import Optional + +import dotenv +import fire +from tqdm import tqdm + +from synthesizer.core import LLMProviderName, RAGProviderName +from synthesizer.eval.rag import ScienceMultipleChoiceEvaluator +from synthesizer.interface import LLMInterfaceManager, RAGInterfaceManager +from synthesizer.llm import GenerationConfig + +# Load environment variables +dotenv.load_dotenv() + +# Logging setup +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def main( + # LLM Settings + llm_provider_name="sciphi", + llm_model_name="SciPhi/SciPhi-Mistral-7B-32k", + llm_max_tokens_to_sample=32, + llm_temperature=0.1, + llm_top_k=100, + # RAG Settings + rag_enabled=True, + rag_provider_name="agent-search", + # Evaluation Settings + n_few_shot=3, + n_samples=100, + evals_to_run="science_multiple_choice", + *args, + **kwargs, +): + rag_interface = ( + RAGInterfaceManager.get_interface_from_args( + RAGProviderName(rag_provider_name), + ) + if rag_enabled + else None + ) + llm_interface = LLMInterfaceManager.get_interface_from_args( + LLMProviderName(llm_provider_name), + # Currently only consumed by SciPhi + rag_interface=rag_interface, + # Consumed by single-load providers + model_name=llm_model_name, + ) + + llm_generation_config = GenerationConfig( + temperature=llm_temperature, + top_k=llm_top_k, + max_tokens_to_sample=llm_max_tokens_to_sample, + model_name=llm_model_name, + ) + + for eval in evals_to_run.split(","): + logger.info(f"Running eval: {eval}") + if eval == "science_multiple_choice": + evaluator = ScienceMultipleChoiceEvaluator( + llm_interface=llm_interface, + rag_interface=rag_interface, + n_few_shot=n_few_shot, + n_samples=n_samples, + ) + + # TODO - Implement other evals + + logger.info(f"Instruction:\n\n{evaluator.instruction}") + logger.info("Now building prompts...") + evaluator.initialize_prompts() + logger.info(f"Example Prompt:\n\n{evaluator.prompts[0]}") + + logger.info("Now generating completions...") + counts = 0 + for i in tqdm(range(n_samples)): + logger.debug( + f"Processing sample {i} with prompt:\n{evaluator.prompts[i]}" + ) + response = llm_interface.get_completion( + prompt=evaluator.prompts[i], + generation_config=llm_generation_config, + ) + counts += int(evaluator.evaluate_response(response, i)) + logger.info(f"Final Accuracy={(counts) / (i + 1)}") + + +if __name__ == "__main__": + fire.Fire(main)