diff --git a/configs/score_documents/adult_content_config.yaml b/configs/score_documents/adult_content_config_local.yaml
similarity index 95%
rename from configs/score_documents/adult_content_config.yaml
rename to configs/score_documents/adult_content_config_local.yaml
index a69d12f4..7ce0a0e9 100644
--- a/configs/score_documents/adult_content_config.yaml
+++ b/configs/score_documents/adult_content_config_local.yaml
@@ -2,9 +2,10 @@ settings:
   model_name: google/gemma-2-9b-it
   num_gpus: 1
   tokenizer_name_or_path: ${settings.model_name}
+  provider: local
   paths:
     raw_data_file_paths:
-      - data/deu_Latn_sampled_500k_first200.jsonl
+      - data/test.jsonl
     start_indexes:
      - 150
     output_directory_path: data/output
diff --git a/configs/score_documents/adult_content_config_openai.yaml b/configs/score_documents/adult_content_config_openai.yaml
new file mode 100644
index 00000000..5c751067
--- /dev/null
+++ b/configs/score_documents/adult_content_config_openai.yaml
@@ -0,0 +1,45 @@
+settings:
+  model_name: gpt-4o
+  num_gpus: 1
+  tokenizer_name_or_path: google/gemma-2-9b-it
+  provider: openai
+  openai:
+    api_key: ${env:OPENAI_API_KEY}
+    base_url: https://api.openai.com/v1
+
+  paths:
+    raw_data_file_paths:
+      - /raid/fromm/ml_filter/data/test.jsonl
+    start_indexes:
+      - 0
+    output_directory_path: /raid/fromm/ml_filter/data/output
+    prompt_template_file_path: /raid/fromm/ml_filter/data/prompts/adult/adult_content_scoring_prompt.yaml
+llm_rest_client:
+  max_tokens: 8192
+  sampling_params:
+    max_tokens: 500
+    temperature: 0.7
+    n: 3
+    top_p: 0.9
+  max_pool_connections: 1000
+  max_pool_maxsize: 1000
+  max_retries: 2
+  backoff_factor: 0.4
+  timeout: 100
+  verbose: false
+  num_gpus: ${settings.num_gpus}
+  max_new_tokens: 500
+tokenizer:
+  pretrained_model_name_or_path: ${settings.tokenizer_name_or_path}
+  special_tokens: null
+  add_generation_prompt: true
+prompt_builder:
+  prompt_template_file_path: ${settings.paths.prompt_template_file_path}
+  max_prompt_length: 7690
+document_processor:
+  output_directory_path: ${settings.paths.output_directory_path}
+  queue_size: 1000
+  num_processes: 10
+  score_metric_name: adult_score
+  strings_to_remove: []
+  jq_language_pattern: .language
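Usage note: the ${env:OPENAI_API_KEY} interpolation in the new config is handled by a custom OmegaConf resolver that this patch registers in src/ml_filter/llm_client.py further down. A minimal, self-contained sketch of loading the file above, assuming only that OPENAI_API_KEY is exported in the environment:

import os
from omegaconf import OmegaConf

# Same resolver the patch registers in llm_client.py: ${env:VAR} -> os.environ.get("VAR")
OmegaConf.register_new_resolver("env", lambda key: os.environ.get(key), replace=True)

cfg = OmegaConf.load("configs/score_documents/adult_content_config_openai.yaml")
resolved = OmegaConf.to_container(cfg, resolve=True)
print(resolved["settings"]["provider"])  # "openai"
# resolved["settings"]["openai"]["api_key"] now holds the value of $OPENAI_API_KEY
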
diff --git a/configs/score_documents/lorem_ipsum.yaml b/configs/score_documents/lorem_ipsum.yaml
index 9f936081..1ad3ee65 100644
--- a/configs/score_documents/lorem_ipsum.yaml
+++ b/configs/score_documents/lorem_ipsum.yaml
@@ -5,7 +5,8 @@ settings:
   # we need to set this here manually as this is specified only when hosting the model
   num_gpus: 1
   tokenizer_name_or_path: ${settings.model_name}
-
+  provider: local
+
   paths:
     raw_data_file_paths:
       - data/test_fineweb2_dump.jsonl
@@ -15,7 +16,6 @@ settings:
       - 10
 
 llm_rest_client:
-  model_name: ${settings.model_name}
   max_tokens: 8192 # The maximum total number of tokens supported by the model (input + output)
   sampling_params:
     max_tokens: 500 # The maximum number of tokens to generate
diff --git a/src/ml_filter/config/annotation_pipeline_config.py b/src/ml_filter/config/annotation_pipeline_config.py
index 48ebedd3..2d2cd616 100644
--- a/src/ml_filter/config/annotation_pipeline_config.py
+++ b/src/ml_filter/config/annotation_pipeline_config.py
@@ -3,56 +3,85 @@
 from pydantic import BaseModel, DirectoryPath, Field, FilePath
 
 
+class OpenAIConfig(BaseModel):
+    api_key: str = Field(..., description="OpenAI API key, typically sourced from an environment variable.")
+    base_url: str = Field(default="https://api.openai.com/v1", description="OpenAI API base URL.")
+
+    class Config:
+        extra = "forbid"
+
+
 class PathsConfig(BaseModel):
     raw_data_file_paths: List[FilePath] = Field(default_factory=list)
     output_directory_path: DirectoryPath
     prompt_template_file_path: FilePath
     start_indexes: List[int] = Field(default_factory=list)
 
+    class Config:
+        extra = "forbid"
+
 
 class SettingsConfig(BaseModel):
-    model_name: str
-    num_gpus: int
-    tokenizer_name_or_path: str
+    model_name: str = Field(..., description="Model name (e.g., 'google/gemma-2-9b-it' for local, 'gpt-4o' for OpenAI).")
+    num_gpus: int = Field(..., ge=0, description="Number of GPUs for local LLM (ignored for OpenAI).")
+    tokenizer_name_or_path: str = Field(..., description="Tokenizer name or path.")
     paths: PathsConfig
+    provider: str = Field(..., description="LLM provider: 'local' or 'openai'.")
+    openai: Optional[OpenAIConfig] = Field(
+        default=None, description="OpenAI-specific configuration, required if provider is 'openai'."
+    )
+
+    class Config:
+        extra = "forbid"
 
 
 class LLMRestClientConfig(BaseModel):
-    model_name: str
-    max_tokens: int
-    max_pool_connections: int
-    max_pool_maxsize: int
-    max_retries: int
-    backoff_factor: float
-    timeout: int
-    verbose: bool
-    num_gpus: int
-    sampling_params: dict
+    max_tokens: int = Field(..., ge=1, description="Maximum total tokens (input + output).")
+    max_pool_connections: int = Field(..., ge=1, description="Maximum pool connections for local LLM.")
+    max_pool_maxsize: int = Field(..., ge=1, description="Maximum pool size for local LLM.")
+    max_retries: int = Field(..., ge=0, description="Maximum number of retries for API requests.")
+    backoff_factor: float = Field(..., ge=0.0, description="Backoff factor for retry delays.")
+    timeout: int = Field(..., ge=1, description="Request timeout in seconds.")
+    verbose: bool = Field(..., description="Enable verbose logging.")
+    num_gpus: int = Field(..., ge=0, description="Number of GPUs for local LLM (ignored for OpenAI).")
+    sampling_params: dict = Field(..., description="Sampling parameters for text generation.")
 
 
 class TokenizerConfig(BaseModel):
-    pretrained_model_name_or_path: str
-    special_tokens: Optional[dict]
-    add_generation_prompt: bool
+    pretrained_model_name_or_path: str = Field(..., description="Pretrained tokenizer name or path.")
+    special_tokens: Optional[dict] = Field(default=None, description="Special tokens for tokenizer.")
+    add_generation_prompt: bool = Field(..., description="Whether to add generation prompt.")
+
+    class Config:
+        extra = "forbid"
 
 
 class PromptBuilderConfig(BaseModel):
-    prompt_template_file_path: str
-    max_prompt_length: int
+    prompt_template_file_path: str = Field(..., description="Path to the prompt template file.")
+    max_prompt_length: int = Field(..., ge=1, description="Maximum length of the prompt.")
+
+    class Config:
+        extra = "forbid"
 
 
 class DocumentProcessorConfig(BaseModel):
-    output_directory_path: DirectoryPath
-    queue_size: int
-    num_processes: int
-    score_metric_name: str
-    strings_to_remove: List[str] = Field(default_factory=list)
-    jq_language_pattern: str
+    output_directory_path: DirectoryPath = Field(..., description="Output directory for processed documents.")
+    queue_size: int = Field(..., ge=1, description="Size of the processing queue.")
+    num_processes: int = Field(..., ge=1, description="Number of processes for document processing.")
+    score_metric_name: str = Field(..., description="Name of the score metric.")
+    strings_to_remove: List[str] = Field(default_factory=list, description="Strings to remove from documents.")
+    jq_language_pattern: str = Field(..., description="JQ pattern for language metadata.")
+
+    class Config:
+        extra = "forbid"
 
 
 class AnnotationPipelineConfig(BaseModel):
-    settings: SettingsConfig
-    llm_rest_client: LLMRestClientConfig
-    tokenizer: TokenizerConfig
-    prompt_builder: PromptBuilderConfig
-    document_processor: DocumentProcessorConfig
+    settings: SettingsConfig = Field(..., description="General settings for the pipeline.")
+    llm_rest_client: LLMRestClientConfig = Field(..., description="Configuration for LLM REST client.")
+    tokenizer: TokenizerConfig = Field(..., description="Configuration for tokenizer.")
+    prompt_builder: PromptBuilderConfig = Field(..., description="Configuration for prompt builder.")
+    document_processor: DocumentProcessorConfig = Field(..., description="Configuration for document processor.")
+
+    class Config:
+        extra = "forbid"
\ No newline at end of file
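For reference, the stricter models above reject unknown keys. A minimal sketch against OpenAIConfig; the "sk-example" and "organization" values are made up for illustration:

from pydantic import ValidationError

from ml_filter.config.annotation_pipeline_config import OpenAIConfig

# Valid: base_url falls back to its declared default.
cfg = OpenAIConfig.model_validate({"api_key": "sk-example"})
print(cfg.base_url)  # https://api.openai.com/v1

# Invalid: extra = "forbid" turns unknown keys into a ValidationError.
try:
    OpenAIConfig.model_validate({"api_key": "sk-example", "organization": "acme"})
except ValidationError as err:
    print(err.error_count(), "validation error")
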
diff --git a/src/ml_filter/llm_api/llm_rest_client.py b/src/ml_filter/llm_api/llm_rest_client.py
index fb41d70f..f4c9fe5c 100644
--- a/src/ml_filter/llm_api/llm_rest_client.py
+++ b/src/ml_filter/llm_api/llm_rest_client.py
@@ -7,16 +7,14 @@
 from requests import RequestException, Session
 from requests.adapters import HTTPAdapter
+from openai import OpenAI, OpenAIError
 
 from ml_filter.data_processing.document import DocumentProcessingStatus, ProcessedDocument
 from ml_filter.utils.logging import get_logger
 
 
 class LLMRestClient:
-    """ "A class representing a REST client for the LLM service.
-    This class is responsible for sending requests to the LLM service
-    (hosted TGI container given the endpoint) and returning the response.
-    """
+    """A class representing a REST client for LLM services, supporting both local TGI and OpenAI API."""
 
     def __init__(
         self,
@@ -31,6 +29,9 @@ def __init__(
         max_tokens: int,
         verbose: bool,
         sampling_params: Dict[str, Any],
+        provider: str = "local",
+        openai_api_key: str | None = None,
+        openai_base_url: str | None = None,
     ):
         """Initializes the LLMRestClient."""
         self.max_retries = max_retries
@@ -40,32 +41,84 @@ def __init__(
         self.max_tokens = max_tokens
         self.verbose = verbose
         self.logger = get_logger(name=self.__class__.__name__, level=logging.INFO)
-        self.session = session
         self.sampling_params = sampling_params
+        self.provider = provider.lower()
+
+        if self.provider == "openai":
+            if not openai_api_key:
+                raise ValueError("OpenAI API key is required when provider is 'openai'.")
+            self.openai_client = OpenAI(
+                api_key=openai_api_key,
+                base_url=openai_base_url or "https://api.openai.com/v1",
+            )
+        else:
+            self.session = session
+            self.session.mount(
+                "http://", HTTPAdapter(pool_connections=max_pool_connections, pool_maxsize=max_pool_maxsize)
+            )
+            self.rest_endpoint_generate = (
+                f"{rest_endpoint}v1/completions" if rest_endpoint.endswith("/") else f"{rest_endpoint}/v1/completions"
+            )
+            self.logger.info(f"Using rest endpoint at {self.rest_endpoint_generate}")
 
-        # TODO: Not entirely sure why this is needed now, but it worked fine previously
-        self.session.mount("http://", HTTPAdapter(pool_connections=max_pool_connections, pool_maxsize=max_pool_maxsize))
-
-        self.rest_endpoint_generate = (
-            f"{rest_endpoint}v1/completions" if rest_endpoint.endswith("/") else f"{rest_endpoint}/v1/completions"
-        )
+    def generate(self, processed_document: ProcessedDocument) -> List[ProcessedDocument]:
+        """Generates a response based on the given prompt."""
+        if self.provider == "openai":
+            return self._generate_openai(processed_document)
+        else:
+            return self._generate_local(processed_document)
+
+    def _generate_openai(self, processed_document: ProcessedDocument) -> List[ProcessedDocument]:
+        """Generates a response using the OpenAI API."""
+        request = {
+            "model": self.model_name,
+            "messages": [{"role": "user", "content": processed_document.prompt}],
+            "max_tokens": self.sampling_params.get("max_tokens", 500),
+            "temperature": self.sampling_params.get("temperature", 0.7),
+            "top_p": self.sampling_params.get("top_p", 0.9),
+            "n": self.sampling_params.get("n", 1),
+        }
+        start_time_generation = time.time()
+        new_documents = []
 
-        self.logger.info(f"Using rest endpoint at {self.rest_endpoint_generate}")
+        for i in range(self.max_retries):
+            try:
+                response = self.openai_client.chat.completions.create(**request)
+                break
+            except OpenAIError as e:
+                self.logger.error(f"OpenAI API request failed with {e}, retrying... ({i+1}/{self.max_retries})")
+                time.sleep(self.backoff_factor * (2**i))
+                if i == self.max_retries - 1:
+                    processed_document.document_processing_status = DocumentProcessingStatus.ERROR_SERVER
+                    processed_document.errors.append(str(e))
+                    self.logger.error(f"OpenAI API request failed after {self.max_retries} retries.")
+                    return [processed_document]
 
-    def generate(self, processed_document: ProcessedDocument) -> List[ProcessedDocument]:
-        """Generates a response based on the given prompt.
-        Args:
-            processed_document (ProcessedDocument): The processed document.
+        generated_texts = self._parse_openai_response(response)
+        for generated_text in generated_texts:
+            new_document = copy.deepcopy(processed_document)
+            if generated_text is not None:
+                new_document.generated_text = generated_text
+            else:
+                new_document.document_processing_status = DocumentProcessingStatus.ERROR_NO_GENERATED_TEXT
+                new_document.errors.append("No generated text in OpenAI response.")
+            end_time_generation = time.time()
+            time_diff_generation = end_time_generation - start_time_generation
+            completion_tokens = response.usage.completion_tokens if hasattr(response.usage, "completion_tokens") else 0
+            out_token_per_second = completion_tokens / time_diff_generation if time_diff_generation > 0 else 0
+            new_document.out_tokens_per_second = out_token_per_second
+            new_document.timestamp = int(end_time_generation)
+            new_documents.append(new_document)
 
-        Returns:
-            Dict[str, Any]: A dictionary containing the generated response.
-        """
+        return new_documents
 
-        request = dict(
-            model=self.model_name,
-            prompt=processed_document.prompt,
+    def _generate_local(self, processed_document: ProcessedDocument) -> List[ProcessedDocument]:
+        """Generates a response using the local TGI endpoint."""
+        request = {
+            "model": self.model_name,
+            "prompt": processed_document.prompt,
             **self.sampling_params,
-        )
+        }
         start_time_generation = time.time()
         for i in range(self.max_retries):
             try:
@@ -79,7 +132,6 @@ def generate(self, processed_document: ProcessedDocum
                 traceback.print_exc()
                 print(f"Request failed with {e}, retrying...{i}")
                 time.sleep(self.backoff_factor * (2**i))
-
                 if i == self.max_retries - 1:
                     processed_document.document_processing_status = DocumentProcessingStatus.ERROR_SERVER
                     processed_document.errors.append(str(e))
@@ -99,8 +151,6 @@ def generate(self, processed_document: ProcessedDocum
                 new_document.errors.append(f"Response could not be parsed: {response_dict}")
             end_time_generation = time.time()
             time_diff_generation = end_time_generation - start_time_generation
-            # note, we only get 'prompt_tokens', 'total_tokens' and 'completion_tokens' on request basis and
-            # measure time for the full request. We cannot decompose the time for the different parts of the request
             out_token_per_second = response_dict["usage"]["completion_tokens"] / time_diff_generation
             new_document.out_tokens_per_second = out_token_per_second
             new_document.timestamp = int(end_time_generation)
@@ -114,16 +164,15 @@ def generate(self, processed_document: ProcessedDocum
         return new_documents
 
     def parse_response(self, response_dict: dict) -> List[str] | None:
-        """Parses the response from the LLM service.
-
-        Args:
-            response_dict (dict): The response dictionary.
-
-        Returns:
-            str: The generated text.
-        """
+        """Parses the response from the local LLM service."""
         choices = response_dict.get("choices")
         if choices is None or len(choices) == 0:
             return None
-        else:
-            return [choice.get("text") for choice in choices]
+        return [choice.get("text") for choice in choices]
+
+    def _parse_openai_response(self, response: Any) -> List[str] | None:
+        """Parses the response from the OpenAI API."""
+        choices = response.choices
+        if choices is None or len(choices) == 0:
+            return None
+        return [choice.message.content for choice in choices]
\ No newline at end of file
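Illustration: for one document, the two code paths above build different request payloads from the same sampling_params. A sketch using the values from adult_content_config_openai.yaml; the prompt string is a placeholder:

sampling_params = {"max_tokens": 500, "temperature": 0.7, "n": 3, "top_p": 0.9}
prompt = "<chat-templated document prompt>"

# _generate_local: payload posted to the TGI /v1/completions endpoint
local_request = {"model": "google/gemma-2-9b-it", "prompt": prompt, **sampling_params}

# _generate_openai: payload passed to chat.completions.create
openai_request = {
    "model": "gpt-4o",
    "messages": [{"role": "user", "content": prompt}],
    "max_tokens": sampling_params.get("max_tokens", 500),
    "temperature": sampling_params.get("temperature", 0.7),
    "top_p": sampling_params.get("top_p", 0.9),
    "n": sampling_params.get("n", 1),
}

# Both paths sleep backoff_factor * 2**i between attempts; with backoff_factor=0.4 and
# max_retries=2 that is 0.4 s after the first failure and 0.8 s after the second.
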
diff --git a/src/ml_filter/llm_client.py b/src/ml_filter/llm_client.py
index 2c54cb00..b04a9c1b 100644
--- a/src/ml_filter/llm_client.py
+++ b/src/ml_filter/llm_client.py
@@ -1,5 +1,6 @@
 import logging
 import shutil
+import os  # needed for the ${env:...} resolver registered in __init__
 from pathlib import Path
 
 from omegaconf import OmegaConf
@@ -20,13 +21,22 @@ def __init__(self, config_file_path: Path, experiment_id: str, rest_endpoint: st
         self.experiment_id = experiment_id
         self.rest_endpoint = rest_endpoint
 
+        # Register custom resolver for environment variables
+        OmegaConf.register_new_resolver(
+            "env",
+            lambda key: os.environ.get(key),
+            replace=True
+        )
+
+        # Resolver for eval expressions in config values
         OmegaConf.register_new_resolver("eval", eval)
+
+        # Load and validate configuration
         config_omegaconf = OmegaConf.load(config_file_path)
         config_resolved = OmegaConf.to_container(config_omegaconf, resolve=True)
         cfg = AnnotationPipelineConfig.model_validate(config_resolved)
 
         self.prompt_template_file_path = Path(cfg.prompt_builder.prompt_template_file_path)
-        # Create experiment directory and store the config as backup
         self.experiment_dir_path = Path(cfg.settings.paths.output_directory_path) / self.experiment_id
         self.experiment_dir_path.mkdir(parents=True, exist_ok=True)
@@ -36,21 +46,22 @@ def __init__(self, config_file_path: Path, experiment_id: str, rest_endpoint: st
             cfg.prompt_builder.prompt_template_file_path,
             self.experiment_dir_path / Path(self.prompt_template_file_path).name,
         )
-        # Dataset related variables
         self.raw_data_file_paths = [Path(path) for path in cfg.settings.paths.raw_data_file_paths]
         self.start_indexes = [int(index) for index in cfg.settings.paths.start_indexes]
 
         # LLMRestClient related variables
         self.max_retries = cfg.llm_rest_client.max_retries
         self.backoff_factor = cfg.llm_rest_client.backoff_factor
-        self.model_name = cfg.llm_rest_client.model_name
+        self.model_name = cfg.settings.model_name
         self.timeout = cfg.llm_rest_client.timeout
         self.max_pool_connections = cfg.llm_rest_client.max_pool_connections
         self.max_pool_maxsize = cfg.llm_rest_client.max_pool_maxsize
         self.max_tokens = cfg.llm_rest_client.max_tokens
         self.verbose = cfg.llm_rest_client.verbose
         self.sampling_params = cfg.llm_rest_client.sampling_params
+        self.provider = cfg.settings.provider
+        self.openai_api_key = cfg.settings.openai.api_key if cfg.settings.openai else None
+        self.openai_base_url = cfg.settings.openai.base_url if cfg.settings.openai else None
 
         # Tokenizer related variables
         self.pretrained_model_name_or_path = Path(cfg.tokenizer.pretrained_model_name_or_path)
@@ -65,14 +76,7 @@ def __init__(self, config_file_path: Path, experiment_id: str, rest_endpoint: st
         self.jq_language_pattern = cfg.document_processor.jq_language_pattern
 
     def run(self):
-        """Runs the LLM service.
-
-        This method loads the dataset, initializes the tokenizer, LLMRestClient, and DocumentProcessor,
-        and then runs the document processing on the loaded data to obtain the model responses.
-        """
-
-        # Get Tokenizer
-        # This tokenizer is only used for applying the chat template, but is not applied within TGI.
+        """Runs the LLM service."""
         tokenizer = PreTrainedHFTokenizer(
             pretrained_model_name_or_path=self.pretrained_model_name_or_path,
             truncation=False,
@@ -82,7 +86,6 @@ def run(self):
             add_generation_prompt=self.add_generation_prompt,
         )
 
-        # Get LLMRestClient
         llm_rest_client = LLMRestClient(
             max_retries=self.max_retries,
             backoff_factor=self.backoff_factor,
@@ -95,9 +98,11 @@ def run(self):
             max_tokens=self.max_tokens,
             sampling_params=self.sampling_params,
             verbose=self.verbose,
+            provider=self.provider,
+            openai_api_key=self.openai_api_key,
+            openai_base_url=self.openai_base_url,
         )
 
-        # Get DocumentProcessor
         document_processor = DocumentProcessor(
             llm_rest_client=llm_rest_client,
             prompt_builder=PromptBuilder(
@@ -112,4 +117,4 @@ def run(self):
             jq_language_pattern=self.jq_language_pattern,
         )
 
-        document_processor.run()
+        document_processor.run()
\ No newline at end of file
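Example: an end-to-end sketch of driving the OpenAI-backed pipeline. The class name LLMClient and the literal argument values are assumptions (the class definition lies outside the hunks above); only the constructor parameters visible in the hunk headers are taken from the diff:

from pathlib import Path

from ml_filter.llm_client import LLMClient  # assumed class name; adjust to the actual one

# export OPENAI_API_KEY=... beforehand so the ${env:OPENAI_API_KEY} interpolation resolves
client = LLMClient(
    config_file_path=Path("configs/score_documents/adult_content_config_openai.yaml"),
    experiment_id="adult_openai_smoke_test",  # placeholder experiment id
    rest_endpoint="http://localhost:8080",  # passed through, but not used by the OpenAI branch of LLMRestClient
)
client.run()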