
Add LogProbInferenceEngines API and implement for OpenAI #909

Merged 13 commits on Jun 30, 2024
1 change: 1 addition & 0 deletions src/unitxt/dialog_operators.py
@@ -11,6 +11,7 @@
{"user": "kkk", "system": ""},
]
"""

from typing import Any, Dict, List, Optional

from .formats import SystemFormat
79 changes: 73 additions & 6 deletions src/unitxt/inference.py
@@ -3,6 +3,8 @@
from dataclasses import field
from typing import Any, Dict, List, Literal, Optional, Union

from tqdm import tqdm

from .artifact import Artifact
from .operator import PackageRequirementsMixin

@@ -15,12 +17,31 @@ def _infer(self, dataset):
"""Perform inference on the input dataset."""
pass

def infer(self, dataset):
def infer(self, dataset) -> List[str]:
"""Verifies instances of a dataset and performs inference."""
[self.verify_instance(instance) for instance in dataset]
return self._infer(dataset)


class LogProbInferenceEngine(abc.ABC, Artifact):
"""Abstract base class for inference with log probs."""

@abc.abstractmethod
def _infer_log_probs(self, dataset):
"""Perform inference on the input dataset that returns log probs."""
pass

def infer_log_probs(self, dataset) -> List[List[Dict]]:
"""Verifies instances of a dataset and performs inference that returns log probabilities of top tokens.

For each instance, returns a list of the top tokens per generated position:
[ {"top_tokens": [ {"text": ..., "logprob": ...}, ... ]}, ... ]

"""
[self.verify_instance(instance) for instance in dataset]
return self._infer_log_probs(dataset)
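
A backend needs to implement only _infer_log_probs; the public infer_log_probs wrapper above takes care of instance verification. A minimal sketch of a conforming subclass (the ConstantLogProbEngine name and its fixed scores are hypothetical, for illustration only):

from typing import Any, Dict, List

from unitxt.inference import LogProbInferenceEngine

class ConstantLogProbEngine(LogProbInferenceEngine):
    """Toy engine: one generated position per instance, one top token."""

    def _infer_log_probs(self, dataset) -> List[List[Dict[str, Any]]]:
        # One entry per instance; each entry is a list of positions,
        # each carrying its top candidate tokens and their log probs.
        return [
            [{"top_tokens": [{"text": "yes", "logprob": -0.105}]}]
            for _ in dataset
        ]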


class HFPipelineBasedInferenceEngine(InferenceEngine, PackageRequirementsMixin):
model_name: str
max_new_tokens: int
@@ -158,9 +179,12 @@ class OpenAiInferenceEngineParams(Artifact):
stop: Union[Optional[str], List[str]] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
top_logprobs: Optional[int] = 20


class OpenAiInferenceEngine(InferenceEngine, PackageRequirementsMixin):
class OpenAiInferenceEngine(
InferenceEngine, LogProbInferenceEngine, PackageRequirementsMixin
):
label: str = "openai"
model_name: str
parameters: OpenAiInferenceEngineParams = field(
@@ -169,6 +193,7 @@ class OpenAiInferenceEngine(InferenceEngine, PackageRequirementsMixin):
_requirement = {
"openai": "Install openai package using 'pip install --upgrade openai"
}
data_classification_policy = ["public"]

def prepare(self):
from openai import OpenAI
@@ -183,8 +208,9 @@ def prepare(self):
self.client = OpenAI(api_key=api_key)

def _infer(self, dataset):
return [
self.client.chat.completions.create(
outputs = []
for instance in tqdm(dataset, desc="Inferring with openAI API"):
response = self.client.chat.completions.create(
messages=[
# {
# "role": "system",
@@ -204,8 +230,49 @@ def _infer(self, dataset):
temperature=self.parameters.temperature,
top_p=self.parameters.top_p,
)
for instance in dataset
]
output = response.choices[0].message.content

outputs.append(output)

return outputs

def _infer_log_probs(self, dataset):
outputs = []
for instance in tqdm(dataset, desc="Inferring with openAI API"):
response = self.client.chat.completions.create(
messages=[
# {
# "role": "system",
# "content": self.system_prompt,
# },
{
"role": "user",
"content": instance["source"],
}
],
model=self.model_name,
frequency_penalty=self.parameters.frequency_penalty,
presence_penalty=self.parameters.presence_penalty,
max_tokens=self.parameters.max_tokens,
seed=self.parameters.seed,
stop=self.parameters.stop,
temperature=self.parameters.temperature,
top_p=self.parameters.top_p,
logprobs=True,
top_logprobs=self.parameters.top_logprobs,
)
top_logprobs_response = response.choices[0].logprobs.content
output = [
{
"top_tokens": [
{"text": obj.token, "logprob": obj.logprob}
for obj in generated_token.top_logprobs
]
}
for generated_token in top_logprobs_response
]
outputs.append(output)
return outputs


class WMLInferenceEngineParams(Artifact):
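
End to end, using the new OpenAI engine is expected to look roughly like the sketch below. It assumes an OPENAI_API_KEY environment variable is available to prepare(), that each instance is a dict with a "source" field (the engine reads instance["source"]), and that the model name is an arbitrary example:

from unitxt.inference import OpenAiInferenceEngine, OpenAiInferenceEngineParams

engine = OpenAiInferenceEngine(
    model_name="gpt-4o",
    parameters=OpenAiInferenceEngineParams(max_tokens=256, top_logprobs=5),
)

dataset = [{"source": "Answer yes or no: is the sky blue?"}]

# Plain generation: one string per instance.
predictions = engine.infer(dataset)

# Log-prob generation: per instance, a list of generated positions, each
# holding its top candidate tokens, e.g.
#   log_probs[0][0]["top_tokens"][0] -> {"text": "Yes", "logprob": -0.02}
log_probs = engine.infer_log_probs(dataset)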
1 change: 1 addition & 0 deletions src/unitxt/loaders.py
@@ -30,6 +30,7 @@

------------------------
"""

import fnmatch
import itertools
import os
1 change: 1 addition & 0 deletions src/unitxt/service/metrics/tokens.py
@@ -4,6 +4,7 @@
Then, save the value in the environment variable UNITXT_METRICS_MASTER_KEY_TOKEN.
To create tokens that have access to the master key, use create_token(..), as shown in main().
"""

from datetime import datetime, timedelta

from fastapi import Depends, HTTPException
1 change: 1 addition & 0 deletions src/unitxt/struct_data_operators.py
@@ -14,6 +14,7 @@
{"key1": "value1", "key2": value2, "key3": "value3"}
------------------------
"""

import json
import random
from abc import ABC, abstractmethod
1 change: 1 addition & 0 deletions utils/compare_unitxt_datasets_between_versions.py
@@ -10,6 +10,7 @@
Done by running create_examples_for_recipes_file.
4. Compare dir A and dir B using generate_diff_html (defined in a separate file).
"""

import concurrent.futures
import itertools
import json