-
Notifications
You must be signed in to change notification settings - Fork 46
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
406 additions
and
1,875 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
"""Noop cache.""" | ||
from typing import Any, Dict, Union | ||
|
||
from manifest.caches import Cache | ||
|
||
|
||
class NoopCache(Cache): | ||
"""A Noop cache that caches nothing for request/response pairs.""" | ||
|
||
def connect(self, connection_str: str, cache_args: Dict[str, Any]) -> None: | ||
""" | ||
Connect to client. | ||
Args: | ||
connection_str: connection string. | ||
cache_args: cache arguments. | ||
""" | ||
pass | ||
|
||
def close(self) -> None: | ||
"""Close the client.""" | ||
pass | ||
|
||
def get_key(self, key: str, table: str = "default") -> Union[str, None]: | ||
""" | ||
Return None key for never in cache. | ||
Args: | ||
key: key for cache. | ||
table: table to get key in. | ||
""" | ||
return None | ||
|
||
def set_key(self, key: str, value: str, table: str = "default") -> None: | ||
""" | ||
Do not set anything as no cache. | ||
Args: | ||
key: key for cache. | ||
value: new value for key. | ||
table: table to set key in. | ||
""" | ||
pass | ||
|
||
def commit(self) -> None: | ||
"""Commit any results.""" | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
"""OpenAI client.""" | ||
import logging | ||
import os | ||
from typing import Any, Callable, Dict, Optional, Tuple | ||
|
||
import requests | ||
|
||
from manifest.clients.client import Client | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
AI21_ENGINES = { | ||
"j1-jumbo", | ||
"j1-grande", | ||
"j1-large", | ||
} | ||
|
||
|
||
class AI21Client(Client): | ||
"""AI21Client client.""" | ||
|
||
def connect( | ||
self, | ||
connection_str: Optional[str] = None, | ||
client_args: Dict[str, Any] = {}, | ||
) -> None: | ||
""" | ||
Connect to the AI21 server. | ||
connection_str is passed as default AI21_API_KEY if variable not set. | ||
Args: | ||
connection_str: connection string. | ||
client_args: client arguments. | ||
""" | ||
# Taken from https://studio.ai21.com/docs/api/ | ||
self.host = "https://api.ai21.com/studio/v1" | ||
self.api_key = os.environ.get("AI21_API_KEY", connection_str) | ||
if self.api_key is None: | ||
raise ValueError( | ||
"AI21 API key not set. Set AI21_API_KEY environment " | ||
"variable or pass through `connection_str`." | ||
) | ||
self.engine = client_args.pop("engine", "j1-large") | ||
if self.engine not in AI21_ENGINES: | ||
raise ValueError(f"Invalid engine {self.engine}. Must be {AI21_ENGINES}.") | ||
self.temperature = client_args.pop("temperature", 0.0) | ||
self.max_tokens = client_args.pop("max_tokens", 10) | ||
self.top_k_return = client_args.pop("topKReturn", 1.0) | ||
self.num_results = client_args.pop("numResults", 1) | ||
self.top_p = client_args.pop("topP", 1.0) | ||
|
||
def close(self) -> None: | ||
"""Close the client.""" | ||
pass | ||
|
||
def get_model_params(self) -> Dict: | ||
""" | ||
Get model params. | ||
By getting model params from the server, we can add to request | ||
and make sure cache keys are unique to model. | ||
Returns: | ||
model params. | ||
""" | ||
return {"model_name": "ai21", "engine": self.engine} | ||
|
||
def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]: | ||
""" | ||
Get request string function. | ||
Args: | ||
query: query string. | ||
Returns: | ||
request function that takes no input. | ||
request parameters as dict. | ||
""" | ||
request_params = { | ||
"engine": kwargs.get("engine", self.engine), | ||
"prompt": query, | ||
"temperature": kwargs.get("temperature", self.temperature), | ||
"maxTokens": kwargs.get("maxTokens", self.max_tokens), | ||
"topKReturn": kwargs.get("topKReturn", self.top_k_return), | ||
"numResults": kwargs.get("numResults", self.num_results), | ||
"topP": kwargs.get("topP", self.top_p), | ||
} | ||
|
||
def _run_completion() -> Dict: | ||
post_str = self.host + "/" + self.engine + "/complete" | ||
print(self.api_key) | ||
print(post_str) | ||
print("https://api.ai21.com/studio/v1/j1-large/complete") | ||
print(request_params) | ||
res = requests.post( | ||
post_str, | ||
headers={"Authorization": f"Bearer {self.api_key}"}, | ||
json=request_params, | ||
) | ||
return res.json() | ||
|
||
return _run_completion, request_params |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
"""OpenAI client.""" | ||
import logging | ||
import os | ||
import sys | ||
from typing import Any, Callable, Dict, Optional, Tuple | ||
|
||
from manifest.clients.client import Client | ||
|
||
crfm_code_dir = os.environ.get("CRFM_CODE_DIR", "/home/code/benchmarking") | ||
sys.path.append(crfm_code_dir) | ||
|
||
from src.common.authentication import Authentication # type: ignore | ||
from src.common.request import Request, RequestResult # type: ignore | ||
from src.proxy.remote_service import RemoteService # type: ignore | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
CRFM_ENGINES = { | ||
"ai21/j1-jumbo", | ||
"ai21/j1-grande", | ||
"ai21/j1-large", | ||
} | ||
|
||
|
||
class CRFMClient(Client): | ||
"""CRFMClient client.""" | ||
|
||
def connect( | ||
self, | ||
connection_str: Optional[str] = None, | ||
client_args: Dict[str, Any] = {}, | ||
) -> None: | ||
""" | ||
Connect to the CRFM endpoint. | ||
connection_str is passed as default CRFM_API_KEY if variable not set. | ||
Args: | ||
connection_str: connection string. | ||
client_args: client arguments. | ||
""" | ||
self.service = RemoteService("https://crfm-models.stanford.edu") | ||
api_key = os.environ.get("CRFM_API_KEY", connection_str) | ||
if api_key is None: | ||
raise ValueError( | ||
"CRFM API key not set. Set CRFM_API_KEY environment " | ||
"variable or pass through `connection_str`." | ||
) | ||
self.auth = Authentication(api_key=api_key) | ||
self.engine = client_args.pop("engine", "ai21/j1-large") | ||
if self.engine not in CRFM_ENGINES: | ||
raise ValueError(f"Invalid engine {self.engine}. Must be {CRFM_ENGINES}.") | ||
self.temperature = client_args.pop("temperature", 0.0) | ||
self.max_tokens = client_args.pop("max_tokens", 10) | ||
self.top_k_per_token = client_args.pop("top_k_per_token", 1) | ||
self.num_completions = client_args.pop("num_completions", 1) | ||
self.stop_sequences = client_args.pop("stop_sequences", []) | ||
self.top_p = client_args.pop("top_p", 1.0) | ||
self.presence_penalty = client_args.pop("presence_penalty", 1.0) | ||
self.frequency_penalty = client_args.pop("frequency_penalty", 1.0) | ||
|
||
def close(self) -> None: | ||
"""Close the client.""" | ||
pass | ||
|
||
def get_model_params(self) -> Dict: | ||
""" | ||
Get model params. | ||
By getting model params from the server, we can add to request | ||
and make sure cache keys are unique to model. | ||
Returns: | ||
model params. | ||
""" | ||
return {"model_name": "crfm", "engine": self.engine} | ||
|
||
def format_response(self, response: RequestResult) -> Dict[str, Any]: | ||
""" | ||
Format RequestResult to dict. | ||
Args: | ||
response: RequestResult | ||
Return: | ||
response as dict | ||
""" | ||
return { | ||
"object": "text_completion", | ||
"model": self.engine, | ||
"choices": [ | ||
{ | ||
"text": text.text, | ||
# TODO: Add in more metadata for HF models | ||
# "logprobs": { | ||
# "tokens": result["tokens"], | ||
# "token_logprobs": result["token_scores"], | ||
# "text_offset": result["text_offset"], | ||
# "top_logprobs": result["top_logprobs"], | ||
# "finish_reason": "length", | ||
# }, | ||
} | ||
for text in response.completions | ||
], | ||
} | ||
|
||
def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]: | ||
""" | ||
Get request string function. | ||
Args: | ||
query: query string. | ||
Returns: | ||
request function that takes no input. | ||
request parameters as dict. | ||
""" | ||
request_params = { | ||
"model": kwargs.get("engine", self.engine), | ||
"prompt": query, | ||
"temperature": kwargs.get("temperature", self.temperature), | ||
"max_tokens": kwargs.get("max_tokens", self.max_tokens), | ||
"top_k_per_token": kwargs.get("top_k_per_token", self.top_k_per_token), | ||
"num_completions": kwargs.get("num_completions", self.num_completions), | ||
"stop_sequences": kwargs.get("stop_sequences", self.stop_sequences), | ||
"top_p": kwargs.get("top_p", self.top_p), | ||
"presence_penalty": kwargs.get("presence_penalty", self.presence_penalty), | ||
"frequency_penalty": kwargs.get( | ||
"frequency_penalty", self.frequency_penalty | ||
), | ||
} | ||
|
||
def _run_completion() -> Dict: | ||
request = Request(**request_params) | ||
request_result = self.service.make_request(self.auth, request) | ||
return self.format_response(request_result) | ||
|
||
return _run_completion, request_params |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.