
Commit

fix: fixing #7 and #5
lorr1 committed Jun 11, 2022
1 parent 74b9302 commit f568875
Showing 14 changed files with 406 additions and 1,875 deletions.
1 change: 0 additions & 1 deletion Makefile
@@ -1,7 +1,6 @@
dev:
poetry install
poetry run pre-commit install
-	poetry run mypy --install-types

test: dev check
poetry run pytest tests
47 changes: 47 additions & 0 deletions manifest/caches/noop.py
@@ -0,0 +1,47 @@
"""Noop cache."""
from typing import Any, Dict, Union

from manifest.caches import Cache


class NoopCache(Cache):
"""A Noop cache that caches nothing for request/response pairs."""

def connect(self, connection_str: str, cache_args: Dict[str, Any]) -> None:
"""
Connect to client.
Args:
connection_str: connection string.
cache_args: cache arguments.
"""
pass

def close(self) -> None:
"""Close the client."""
pass

def get_key(self, key: str, table: str = "default") -> Union[str, None]:
"""
        Return None, as keys are never in the cache.
Args:
key: key for cache.
table: table to get key in.
"""
return None

def set_key(self, key: str, value: str, table: str = "default") -> None:
"""
        Do nothing, as there is no cache to write to.
Args:
key: key for cache.
value: new value for key.
table: table to set key in.
"""
pass

def commit(self) -> None:
"""Commit any results."""
pass
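
For context, a minimal sketch of how this no-op cache behaves in use. The zero-argument constructor and bare `connect` call are assumptions; the `Cache` base class is not shown in this diff.

from manifest.caches.noop import NoopCache

cache = NoopCache()  # assumed constructor; Cache.__init__ is not shown here
cache.connect("", {})                        # no-op
cache.set_key("prompt-hash", "response")     # silently discarded
assert cache.get_key("prompt-hash") is None  # every lookup is a miss
cache.commit()                               # nothing to flush
cache.close()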
103 changes: 103 additions & 0 deletions manifest/clients/ai21.py
@@ -0,0 +1,103 @@
"""OpenAI client."""
import logging
import os
from typing import Any, Callable, Dict, Optional, Tuple

import requests

from manifest.clients.client import Client

logger = logging.getLogger(__name__)

AI21_ENGINES = {
"j1-jumbo",
"j1-grande",
"j1-large",
}


class AI21Client(Client):
"""AI21Client client."""

def connect(
self,
connection_str: Optional[str] = None,
client_args: Dict[str, Any] = {},
) -> None:
"""
Connect to the AI21 server.
        connection_str is used as AI21_API_KEY if the environment variable is not set.
Args:
connection_str: connection string.
client_args: client arguments.
"""
# Taken from https://studio.ai21.com/docs/api/
self.host = "https://api.ai21.com/studio/v1"
self.api_key = os.environ.get("AI21_API_KEY", connection_str)
if self.api_key is None:
raise ValueError(
"AI21 API key not set. Set AI21_API_KEY environment "
"variable or pass through `connection_str`."
)
self.engine = client_args.pop("engine", "j1-large")
if self.engine not in AI21_ENGINES:
raise ValueError(f"Invalid engine {self.engine}. Must be {AI21_ENGINES}.")
self.temperature = client_args.pop("temperature", 0.0)
self.max_tokens = client_args.pop("max_tokens", 10)
self.top_k_return = client_args.pop("topKReturn", 1.0)
self.num_results = client_args.pop("numResults", 1)
self.top_p = client_args.pop("topP", 1.0)

def close(self) -> None:
"""Close the client."""
pass

def get_model_params(self) -> Dict:
"""
Get model params.
By getting model params from the server, we can add to request
and make sure cache keys are unique to model.
Returns:
model params.
"""
return {"model_name": "ai21", "engine": self.engine}

def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function.
Args:
query: query string.
Returns:
request function that takes no input.
request parameters as dict.
"""
request_params = {
"engine": kwargs.get("engine", self.engine),
"prompt": query,
"temperature": kwargs.get("temperature", self.temperature),
"maxTokens": kwargs.get("maxTokens", self.max_tokens),
"topKReturn": kwargs.get("topKReturn", self.top_k_return),
"numResults": kwargs.get("numResults", self.num_results),
"topP": kwargs.get("topP", self.top_p),
}

def _run_completion() -> Dict:
post_str = self.host + "/" + self.engine + "/complete"
res = requests.post(
post_str,
headers={"Authorization": f"Bearer {self.api_key}"},
json=request_params,
)
return res.json()

return _run_completion, request_params
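
A hedged usage sketch of the pattern above. The no-argument constructor is an assumption (`Client.__init__` is not part of this diff), and a real call needs a valid AI21 API key.

from manifest.clients.ai21 import AI21Client

client = AI21Client()  # assumed constructor; base Client.__init__ not shown
client.connect(connection_str=None, client_args={"engine": "j1-large"})

# get_request returns a zero-argument callable plus the exact parameters,
# so the caller can hash request_params for caching before running it.
run_completion, request_params = client.get_request(
    "Write a haiku about caching:", temperature=0.7
)
response = run_completion()  # POST to https://api.ai21.com/studio/v1/j1-large/complete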
13 changes: 13 additions & 0 deletions manifest/clients/client.py
@@ -28,6 +28,19 @@ def close(self) -> None:
"""Close the client."""
raise NotImplementedError()

@abstractmethod
def get_model_params(self) -> Dict:
"""
Get model params.
By getting model params from the server, we can add to request
and make sure cache keys are unique to model.
Returns:
model params.
"""
raise NotImplementedError()

@abstractmethod
def connect(self, connection_str: str, client_args: Dict[str, Any]) -> None:
"""
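
The docstring explains why get_model_params exists: responses are cached, and the model identity must be folded into the cache key so identical prompts sent to different models do not collide. A standalone illustration of that idea; the hashing scheme here is illustrative, not the library's actual implementation.

import hashlib
import json
from typing import Dict

def cache_key(request_params: Dict, model_params: Dict) -> str:
    # Illustrative only: merge model params into the request before hashing,
    # so the same prompt against different engines yields distinct keys.
    payload = json.dumps({**request_params, **model_params}, sort_keys=True)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()

key_a = cache_key({"prompt": "hi"}, {"model_name": "openai", "engine": "ada"})
key_b = cache_key({"prompt": "hi"}, {"model_name": "ai21", "engine": "j1-large"})
assert key_a != key_b  # same prompt, different model -> different cache entries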
138 changes: 138 additions & 0 deletions manifest/clients/crfm.py
@@ -0,0 +1,138 @@
"""OpenAI client."""
import logging
import os
import sys
from typing import Any, Callable, Dict, Optional, Tuple

from manifest.clients.client import Client

crfm_code_dir = os.environ.get("CRFM_CODE_DIR", "/home/code/benchmarking")
sys.path.append(crfm_code_dir)

from src.common.authentication import Authentication # type: ignore
from src.common.request import Request, RequestResult # type: ignore
from src.proxy.remote_service import RemoteService # type: ignore

logger = logging.getLogger(__name__)

CRFM_ENGINES = {
"ai21/j1-jumbo",
"ai21/j1-grande",
"ai21/j1-large",
}


class CRFMClient(Client):
"""CRFMClient client."""

def connect(
self,
connection_str: Optional[str] = None,
client_args: Dict[str, Any] = {},
) -> None:
"""
Connect to the CRFM endpoint.
        connection_str is used as CRFM_API_KEY if the environment variable is not set.
Args:
connection_str: connection string.
client_args: client arguments.
"""
self.service = RemoteService("https://crfm-models.stanford.edu")
api_key = os.environ.get("CRFM_API_KEY", connection_str)
if api_key is None:
raise ValueError(
"CRFM API key not set. Set CRFM_API_KEY environment "
"variable or pass through `connection_str`."
)
self.auth = Authentication(api_key=api_key)
self.engine = client_args.pop("engine", "ai21/j1-large")
if self.engine not in CRFM_ENGINES:
raise ValueError(f"Invalid engine {self.engine}. Must be {CRFM_ENGINES}.")
self.temperature = client_args.pop("temperature", 0.0)
self.max_tokens = client_args.pop("max_tokens", 10)
self.top_k_per_token = client_args.pop("top_k_per_token", 1)
self.num_completions = client_args.pop("num_completions", 1)
self.stop_sequences = client_args.pop("stop_sequences", [])
self.top_p = client_args.pop("top_p", 1.0)
self.presence_penalty = client_args.pop("presence_penalty", 1.0)
self.frequency_penalty = client_args.pop("frequency_penalty", 1.0)

def close(self) -> None:
"""Close the client."""
pass

def get_model_params(self) -> Dict:
"""
Get model params.
By getting model params from the server, we can add to request
and make sure cache keys are unique to model.
Returns:
model params.
"""
return {"model_name": "crfm", "engine": self.engine}

def format_response(self, response: RequestResult) -> Dict[str, Any]:
"""
Format RequestResult to dict.
Args:
            response: RequestResult.
        Returns:
            response as dict.
"""
return {
"object": "text_completion",
"model": self.engine,
"choices": [
{
"text": text.text,
# TODO: Add in more metadata for HF models
# "logprobs": {
# "tokens": result["tokens"],
# "token_logprobs": result["token_scores"],
# "text_offset": result["text_offset"],
# "top_logprobs": result["top_logprobs"],
# "finish_reason": "length",
# },
}
for text in response.completions
],
}

def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function.
Args:
query: query string.
Returns:
request function that takes no input.
request parameters as dict.
"""
request_params = {
"model": kwargs.get("engine", self.engine),
"prompt": query,
"temperature": kwargs.get("temperature", self.temperature),
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
"top_k_per_token": kwargs.get("top_k_per_token", self.top_k_per_token),
"num_completions": kwargs.get("num_completions", self.num_completions),
"stop_sequences": kwargs.get("stop_sequences", self.stop_sequences),
"top_p": kwargs.get("top_p", self.top_p),
"presence_penalty": kwargs.get("presence_penalty", self.presence_penalty),
"frequency_penalty": kwargs.get(
"frequency_penalty", self.frequency_penalty
),
}

def _run_completion() -> Dict:
request = Request(**request_params)
request_result = self.service.make_request(self.auth, request)
return self.format_response(request_result)

return _run_completion, request_params
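
As with the AI21 client, a sketch of how this client might be driven end to end. It assumes the CRFM benchmarking checkout is importable via CRFM_CODE_DIR, CRFM_API_KEY is set, and the client takes no constructor arguments; none of this is confirmed by the diff.

from manifest.clients.crfm import CRFMClient

client = CRFMClient()  # assumed no-arg constructor, as with the other clients
client.connect(client_args={"engine": "ai21/j1-large"})

run_completion, params = client.get_request("Translate to French: cheese")
result = run_completion()  # dict shaped by format_response above
for choice in result["choices"]:
    print(choice["text"])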
12 changes: 12 additions & 0 deletions manifest/clients/dummy.py
@@ -30,6 +30,12 @@ def close(self) -> None:
"""Close the client."""
pass

def get_model_params(self) -> Dict:
"""
Get model params.
By getting model params from the server, we can add to request
and make sure cache keys are unique to model.
Returns:
model params.
"""
return {"engine": "dummy"}

def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function.
16 changes: 16 additions & 0 deletions manifest/clients/openai.py
@@ -49,6 +49,8 @@ def connect(
self.temperature = client_args.pop("temperature", 0.0)
self.max_tokens = client_args.pop("max_tokens", 10)
self.top_p = client_args.pop("top_p", 1.0)
self.logprobs = client_args.pop("logprobs", None)
self.best_of = client_args.pop("best_of", 1)
self.frequency_penalty = client_args.pop("frequency_penalty", 0.0)
self.presence_penalty = client_args.pop("presence_penalty", 0.0)
self.n = client_args.pop("n", 1)
@@ -57,6 +59,18 @@ def close(self) -> None:
"""Close the client."""
pass

def get_model_params(self) -> Dict:
"""
Get model params.
By getting model params from the server, we can add to request
and make sure cache keys are unique to model.
Returns:
model params.
"""
return {"model_name": "openai", "engine": self.engine}

def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function.
@@ -77,6 +91,8 @@ def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
"frequency_penalty": kwargs.get(
"frequency_penalty", self.frequency_penalty
),
"logprobs": kwargs.get("logprobs", self.logprobs),
"best_of": kwargs.get("best_of", self.best_of),
"presence_penalty": kwargs.get("presence_penalty", self.presence_penalty),
"n": kwargs.get("n", self.n),
}
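
The two new knobs, logprobs and best_of, follow the file's existing pattern: connect() pops a default out of client_args, and get_request() lets per-call kwargs override it. A self-contained illustration of that pattern, not the library's code:

from typing import Any, Dict

class Defaults:
    # Mirrors the connect()/get_request() split above: pop defaults once,
    # then let per-call kwargs override them at request time.
    def __init__(self, client_args: Dict[str, Any]):
        self.logprobs = client_args.pop("logprobs", None)
        self.best_of = client_args.pop("best_of", 1)

    def request_params(self, **kwargs: Any) -> Dict[str, Any]:
        return {
            "logprobs": kwargs.get("logprobs", self.logprobs),
            "best_of": kwargs.get("best_of", self.best_of),
        }

d = Defaults({"best_of": 3})
assert d.request_params() == {"logprobs": None, "best_of": 3}
assert d.request_params(best_of=1)["best_of"] == 1  # per-call override wins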
12 changes: 12 additions & 0 deletions manifest/clients/opt.py
@@ -34,6 +34,12 @@ def close(self) -> None:
"""Close the client."""
pass

def get_model_params(self) -> Dict:
"""
Get model params.
By getting model params from the server, we can add to request
and make sure cache keys are unique to model.
Returns:
model params.
"""
return {"model_name": "opt"}

def get_request(self, query: str, **kwargs: Any) -> Tuple[Callable[[], Dict], Dict]:
"""
Get request string function.
