@@ -2,9 +2,10 @@ settings:
model_name: google/gemma-2-9b-it
num_gpus: 1
tokenizer_name_or_path: ${settings.model_name}
provider: local
paths:
raw_data_file_paths:
- data/deu_Latn_sampled_500k_first200.jsonl
- data/test.jsonl
start_indexes:
- 150
output_directory_path: data/output
45 changes: 45 additions & 0 deletions configs/score_documents/adult_content_config_openai.yaml
@@ -0,0 +1,45 @@
settings:
model_name: gpt-4o
num_gpus: 1
tokenizer_name_or_path: google/gemma-2-9b-it
provider: openai
openai:
api_key: ${env:OPENAI_API_KEY}
base_url: https://api.openai.com/v1

paths:
raw_data_file_paths:
- /raid/fromm/ml_filter/data/test.jsonl
start_indexes:
- 0
output_directory_path: /raid/fromm/ml_filter/data/output
prompt_template_file_path: /raid/fromm/ml_filter/data/prompts/adult/adult_content_scoring_prompt.yaml
llm_rest_client:
max_tokens: 8192
sampling_params:
max_tokens: 500
temperature: 0.7
n: 3
top_p: 0.9
max_pool_connections: 1000
max_pool_maxsize: 1000
max_retries: 2
backoff_factor: 0.4
timeout: 100
verbose: false
num_gpus: ${settings.num_gpus}
max_new_tokens: 500
tokenizer:
pretrained_model_name_or_path: ${settings.tokenizer_name_or_path}
special_tokens: null
add_generation_prompt: true
prompt_builder:
prompt_template_file_path: ${settings.paths.prompt_template_file_path}
max_prompt_length: 7690
document_processor:
output_directory_path: ${settings.paths.output_directory_path}
queue_size: 1000
num_processes: 10
score_metric_name: adult_score
strings_to_remove: []
jq_language_pattern: .language
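
For reference, a minimal sketch of how a config such as `adult_content_config_openai.yaml` could be loaded and validated. It assumes the `${...}` interpolations are resolved with OmegaConf, that an `env` resolver backs `${env:OPENAI_API_KEY}`, and that the referenced data and output paths exist; the `load_annotation_config` helper is hypothetical and not part of this PR.

```python
# Hypothetical loader, assuming OmegaConf-style interpolation and the Pydantic
# models added in annotation_pipeline_config.py further down in this diff.
import os

from omegaconf import OmegaConf

from ml_filter.config.annotation_pipeline_config import AnnotationPipelineConfig

# Assumption: an "env" resolver backs the ${env:OPENAI_API_KEY} interpolation.
OmegaConf.register_new_resolver("env", lambda name: os.environ.get(name, ""), replace=True)


def load_annotation_config(path: str) -> AnnotationPipelineConfig:
    cfg = OmegaConf.load(path)                            # parse the YAML file
    resolved = OmegaConf.to_container(cfg, resolve=True)  # resolve ${settings.*} and ${env:*}
    return AnnotationPipelineConfig(**resolved)           # Pydantic validation (extra = "forbid")


config = load_annotation_config("configs/score_documents/adult_content_config_openai.yaml")
print(config.settings.provider)  # -> "openai"
```
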
4 changes: 2 additions & 2 deletions configs/score_documents/lorem_ipsum.yaml
@@ -5,7 +5,8 @@ settings:
# we need to set this here manually as this is specified only when hosting the model
num_gpus: 1
tokenizer_name_or_path: ${settings.model_name}

provider: local

paths:
raw_data_file_paths:
- data/test_fineweb2_dump.jsonl
@@ -15,7 +16,6 @@ settings:
- 10

llm_rest_client:
model_name: ${settings.model_name}
max_tokens: 8192 # The maximum total number of tokens supported by the model (input + output)
sampling_params:
max_tokens: 500 # The maximum number of tokens to generate
87 changes: 58 additions & 29 deletions src/ml_filter/config/annotation_pipeline_config.py
@@ -3,56 +3,85 @@
from pydantic import BaseModel, DirectoryPath, Field, FilePath


class OpenAIConfig(BaseModel):
api_key: str = Field(..., description="OpenAI API key, typically sourced from an environment variable.")
base_url: str = Field(default="https://api.openai.com/v1", description="OpenAI API base URL.")

class Config:
extra = "forbid"


class PathsConfig(BaseModel):
raw_data_file_paths: List[FilePath] = Field(default_factory=list)
output_directory_path: DirectoryPath
prompt_template_file_path: FilePath
start_indexes: List[int] = Field(default_factory=list)

class Config:
extra = "forbid"


class SettingsConfig(BaseModel):
model_name: str
num_gpus: int
tokenizer_name_or_path: str
model_name: str = Field(..., description="Model name (e.g., 'google/gemma-2-9b-it' for local, 'gpt-4o' for OpenAI).")
num_gpus: int = Field(..., ge=0, description="Number of GPUs for local LLM (ignored for OpenAI).")
tokenizer_name_or_path: str = Field(..., description="Tokenizer name or path.")
paths: PathsConfig
provider: str = Field(..., description="LLM provider: 'local' or 'openai'.")
openai: Optional[OpenAIConfig] = Field(
default=None, description="OpenAI-specific configuration, required if provider is 'openai'."
)

class Config:
extra = "forbid"


class LLMRestClientConfig(BaseModel):
model_name: str
max_tokens: int
max_pool_connections: int
max_pool_maxsize: int
max_retries: int
backoff_factor: float
timeout: int
verbose: bool
num_gpus: int
sampling_params: dict
max_tokens: int = Field(..., ge=1, description="Maximum total tokens (input + output).")
max_pool_connections: int = Field(..., ge=1, description="Maximum pool connections for local LLM.")
max_pool_maxsize: int = Field(..., ge=1, description="Maximum pool size for local LLM.")
max_retries: int = Field(..., ge=0, description="Maximum number of retries for API requests.")
backoff_factor: float = Field(..., ge=0.0, description="Backoff factor for retry delays.")
timeout: int = Field(..., ge=1, description="Request timeout in seconds.")
verbose: bool = Field(..., description="Enable verbose logging.")
num_gpus: int = Field(..., ge=0, description="Number of GPUs for local LLM (ignored for OpenAI).")
sampling_params: dict = Field(..., description="Sampling parameters for text generation.")


class TokenizerConfig(BaseModel):
pretrained_model_name_or_path: str
special_tokens: Optional[dict]
add_generation_prompt: bool
pretrained_model_name_or_path: str = Field(..., description="Pretrained tokenizer name or path.")
special_tokens: Optional[dict] = Field(default=None, description="Special tokens for tokenizer.")
add_generation_prompt: bool = Field(..., description="Whether to add generation prompt.")

class Config:
extra = "forbid"


class PromptBuilderConfig(BaseModel):
prompt_template_file_path: str
max_prompt_length: int
prompt_template_file_path: str = Field(..., description="Path to the prompt template file.")
max_prompt_length: int = Field(..., ge=1, description="Maximum length of the prompt.")

class Config:
extra = "forbid"


class DocumentProcessorConfig(BaseModel):
output_directory_path: DirectoryPath
queue_size: int
num_processes: int
score_metric_name: str
strings_to_remove: List[str] = Field(default_factory=list)
jq_language_pattern: str
output_directory_path: DirectoryPath = Field(..., description="Output directory for processed documents.")
queue_size: int = Field(..., ge=1, description="Size of the processing queue.")
num_processes: int = Field(..., ge=1, description="Number of processes for document processing.")
score_metric_name: str = Field(..., description="Name of the score metric.")
strings_to_remove: List[str] = Field(default_factory=list, description="Strings to remove from documents.")
jq_language_pattern: str = Field(..., description="JQ pattern for language metadata.")

class Config:
extra = "forbid"


class AnnotationPipelineConfig(BaseModel):
settings: SettingsConfig
llm_rest_client: LLMRestClientConfig
tokenizer: TokenizerConfig
prompt_builder: PromptBuilderConfig
document_processor: DocumentProcessorConfig
settings: SettingsConfig = Field(..., description="General settings for the pipeline.")
llm_rest_client: LLMRestClientConfig = Field(..., description="Configuration for LLM REST client.")
tokenizer: TokenizerConfig = Field(..., description="Configuration for tokenizer.")
prompt_builder: PromptBuilderConfig = Field(..., description="Configuration for prompt builder.")
document_processor: DocumentProcessorConfig = Field(..., description="Configuration for document processor.")

class Config:
extra = "forbid"
121 changes: 85 additions & 36 deletions src/ml_filter/llm_api/llm_rest_client.py
@@ -7,16 +7,14 @@

from requests import RequestException, Session
from requests.adapters import HTTPAdapter
from openai import OpenAI, OpenAIError

from ml_filter.data_processing.document import DocumentProcessingStatus, ProcessedDocument
from ml_filter.utils.logging import get_logger


class LLMRestClient:
""" "A class representing a REST client for the LLM service.
This class is responsible for sending requests to the LLM service
(hosted TGI container given the endpoint) and returning the response.
"""
"""A class representing a REST client for LLM services, supporting both local TGI and OpenAI API."""

def __init__(
self,
@@ -31,6 +29,9 @@ def __init__(
max_tokens: int,
verbose: bool,
sampling_params: Dict[str, Any],
provider: str = "local",
openai_api_key: str = None,
openai_base_url: str = None,
):
"""Initializes the LLMRestClient."""
self.max_retries = max_retries
@@ -40,32 +41,84 @@ def __init__(
self.max_tokens = max_tokens
self.verbose = verbose
self.logger = get_logger(name=self.__class__.__name__, level=logging.INFO)
self.session = session
self.sampling_params = sampling_params
self.provider = provider.lower()

if self.provider == "openai":
if not openai_api_key:
raise ValueError("OpenAI API key is required when provider is 'openai'.")
self.openai_client = OpenAI(
api_key=openai_api_key,
base_url=openai_base_url or "https://api.openai.com/v1",
)
else:
self.session = session
self.session.mount(
"http://", HTTPAdapter(pool_connections=max_pool_connections, pool_maxsize=max_pool_maxsize)
)
self.rest_endpoint_generate = (
f"{rest_endpoint}v1/completions" if rest_endpoint.endswith("/") else f"{rest_endpoint}/v1/completions"
)
self.logger.info(f"Using rest endpoint at {self.rest_endpoint_generate}")

# TODO: Not entirely sure why this is needed now, but it worked fine previously
self.session.mount("http://", HTTPAdapter(pool_connections=max_pool_connections, pool_maxsize=max_pool_maxsize))

self.rest_endpoint_generate = (
f"{rest_endpoint}v1/completions" if rest_endpoint.endswith("/") else f"{rest_endpoint}/v1/completions"
)
def generate(self, processed_document: ProcessedDocument) -> List[ProcessedDocument]:
"""Generates a response based on the given prompt."""
if self.provider == "openai":
return self._generate_openai(processed_document)
else:
return self._generate_local(processed_document)

def _generate_openai(self, processed_document: ProcessedDocument) -> List[ProcessedDocument]:
"""Generates a response using the OpenAI API."""
request = {
"model": self.model_name,
"messages": [{"role": "user", "content": processed_document.prompt}],
"max_tokens": self.sampling_params.get("max_tokens", 500),
"temperature": self.sampling_params.get("temperature", 0.7),
"top_p": self.sampling_params.get("top_p", 0.9),
"n": self.sampling_params.get("n", 1),
}
start_time_generation = time.time()
new_documents = []

self.logger.info(f"Using rest endpoint at {self.rest_endpoint_generate}")
for i in range(self.max_retries):
try:
response = self.openai_client.chat.completions.create(**request)
break
except OpenAIError as e:
self.logger.error(f"OpenAI API request failed with {e}, retrying... ({i+1}/{self.max_retries})")
time.sleep(self.backoff_factor * (2**i))
if i == self.max_retries - 1:
processed_document.document_processing_status = DocumentProcessingStatus.ERROR_SERVER
processed_document.errors.append(str(e))
self.logger.error(f"OpenAI API request failed after {self.max_retries} retries.")
return [processed_document]

def generate(self, processed_document: ProcessedDocument) -> List[ProcessedDocument]:
"""Generates a response based on the given prompt.
Args:
processed_document (ProcessedDocument): The processed document.
generated_texts = self._parse_openai_response(response)
for generated_text in generated_texts:
new_document = copy.deepcopy(processed_document)
if generated_text is not None:
new_document.generated_text = generated_text
else:
new_document.document_processing_status = DocumentProcessingStatus.ERROR_NO_GENERATED_TEXT
new_document.errors.append("No generated text in OpenAI response.")
end_time_generation = time.time()
time_diff_generation = end_time_generation - start_time_generation
completion_tokens = response.usage.completion_tokens if hasattr(response.usage, "completion_tokens") else 0
out_token_per_second = completion_tokens / time_diff_generation if time_diff_generation > 0 else 0
new_document.out_tokens_per_second = out_token_per_second
new_document.timestamp = int(end_time_generation)
new_documents.append(new_document)

Returns:
Dict[str, Any]: A dictionary containing the generated response.
"""
return new_documents

request = dict(
model=self.model_name,
prompt=processed_document.prompt,
def _generate_local(self, processed_document: ProcessedDocument) -> List[ProcessedDocument]:
"""Generates a response using the local TGI endpoint."""
request = {
"model": self.model_name,
"prompt": processed_document.prompt,
**self.sampling_params,
)
}
start_time_generation = time.time()
for i in range(self.max_retries):
try:
@@ -79,7 +132,6 @@ def generate(self, processed_document: ProcessedDocument) -> List[ProcessedDocum
traceback.print_exc()
print(f"Request failed with {e}, retrying...{i}")
time.sleep(self.backoff_factor * (2**i))

if i == self.max_retries - 1:
processed_document.document_processing_status = DocumentProcessingStatus.ERROR_SERVER
processed_document.errors.append(str(e))
@@ -99,8 +151,6 @@ def generate(self, processed_document: ProcessedDocument) -> List[ProcessedDocum
new_document.errors.append(f"Response could not be parsed: {response_dict}")
end_time_generation = time.time()
time_diff_generation = end_time_generation - start_time_generation
# note, we only get 'prompt_tokens', 'total_tokens' and 'completion_tokens' on request basis and
# measure time for the full request. We cannot decompose the time for the different parts of the request
out_token_per_second = response_dict["usage"]["completion_tokens"] / time_diff_generation
new_document.out_tokens_per_second = out_token_per_second
new_document.timestamp = int(end_time_generation)
@@ -114,16 +164,15 @@ def generate(self, processed_document: ProcessedDocument) -> List[ProcessedDocum
return new_documents

def parse_response(self, response_dict: dict) -> List[str] | None:
"""Parses the response from the LLM service.

Args:
response_dict (dict): The response dictionary.

Returns:
str: The generated text.
"""
"""Parses the response from the local LLM service."""
choices = response_dict.get("choices")
if choices is None or len(choices) == 0:
return None
else:
return [choice.get("text") for choice in choices]
return [choice.get("text") for choice in choices]

def _parse_openai_response(self, response: Any) -> List[str] | None:
"""Parses the response from the OpenAI API."""
choices = response.choices
if choices is None or len(choices) == 0:
return None
return [choice.message.content for choice in choices]
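
As a reading aid for the new OpenAI path: with `n: 3` in `sampling_params`, each chat-completion response carries three choices, and `_generate_openai` fans each one out into its own copy of the `ProcessedDocument`. A minimal sketch of that parsing step, using a mocked response object whose attribute shape mirrors what `_parse_openai_response` expects:

```python
# Mocked response; field names mirror the openai-python ChatCompletion shape
# used in _parse_openai_response (choices[i].message.content, usage.completion_tokens).
from types import SimpleNamespace

response = SimpleNamespace(
    choices=[
        SimpleNamespace(message=SimpleNamespace(content=f"score: {s}")) for s in (1, 2, 1)
    ],
    usage=SimpleNamespace(completion_tokens=42),
)

generated_texts = [choice.message.content for choice in response.choices]
print(generated_texts)  # ['score: 1', 'score: 2', 'score: 1'] -> three new documents per input
```
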