Basic implementation of a plugin system for OA #2765

Merged · 19 commits · May 2, 2023
6 changes: 4 additions & 2 deletions inference/full-dev-setup.sh
@@ -13,10 +13,12 @@ else
INFERENCE_TAG=latest
fi

POSTGRES_PORT=${POSTGRES_PORT:-5432}

# Creates a tmux window with splits for the individual services

tmux new-session -d -s "inference-dev-setup"
tmux send-keys "docker run --rm -it -p 5732:5432 -e POSTGRES_PASSWORD=postgres --name postgres postgres" C-m
tmux send-keys "docker run --rm -it -p $POSTGRES_PORT:5432 -e POSTGRES_PASSWORD=postgres --name postgres postgres" C-m
tmux split-window -h
tmux send-keys "docker run --rm -it -p 6779:6379 --name redis redis" C-m

@@ -30,7 +32,7 @@ fi

tmux split-window -h
tmux send-keys "cd server" C-m
tmux send-keys "LOGURU_LEVEL=$LOGLEVEL POSTGRES_PORT=5732 REDIS_PORT=6779 DEBUG_API_KEYS='0000,0001' ALLOW_DEBUG_AUTH=True TRUSTED_CLIENT_KEYS=6969 uvicorn main:app" C-m
tmux send-keys "LOGURU_LEVEL=$LOGLEVEL POSTGRES_PORT=$POSTGRES_PORT REDIS_PORT=6779 DEBUG_API_KEYS='0000,0001' ALLOW_DEBUG_AUTH=True TRUSTED_CLIENT_KEYS=6969 uvicorn main:app" C-m
tmux split-window -h
tmux send-keys "cd text-client" C-m
tmux send-keys "sleep 5" C-m
@@ -0,0 +1,28 @@
"""added used plugin to message

Revision ID: 5b4211625a9f
Revises: ea19bbc743f9
Create Date: 2023-05-01 22:53:16.297495

"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "5b4211625a9f"
down_revision = "ea19bbc743f9"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("message", sa.Column("used_plugin", postgresql.JSONB(astext_type=sa.Text()), nullable=True))
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("message", "used_plugin")
# ### end Alembic commands ###
@@ -88,12 +88,13 @@ async def abort_work(self, message_id: str, reason: str) -> models.DbMessage:
await self.session.refresh(message)
return message

async def complete_work(self, message_id: str, content: str) -> models.DbMessage:
async def complete_work(self, message_id: str, content: str, used_plugin: inference.PluginUsed) -> models.DbMessage:
logger.debug(f"Completing work on message {message_id}")
message = await self.get_assistant_message_by_id(message_id)
message.state = inference.MessageState.complete
message.work_end_at = datetime.datetime.utcnow()
message.content = content
message.used_plugin = used_plugin
await self.session.commit()
logger.debug(f"Completed work on message {message_id}")
await self.session.refresh(message)
2 changes: 2 additions & 0 deletions inference/server/oasst_inference_server/database.py
@@ -43,6 +43,8 @@ def custom_json_deserializer(s):
return chat_schema.CreateMessageRequest.parse_obj(d)
case "WorkRequest":
return inference.WorkRequest.parse_obj(d)
case "PluginUsed":
return inference.PluginUsed.parse_obj(d)
case None:
return d
case _:
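
The new `case "PluginUsed"` arm lets the custom JSON deserializer rebuild the pydantic object when a message row is loaded from its JSONB column. As a rough illustration of the type-tagged round-trip this relies on (the repository's actual serializer is not shown in this diff, and the `PluginUsed` fields below are assumptions):

```python
# Illustrative only: a minimal type-tagged (de)serialization pair in the same
# spirit as custom_json_deserializer above. Key and field names are assumptions.
import json

import pydantic


class PluginUsed(pydantic.BaseModel):  # stand-in for inference.PluginUsed
    name: str | None = None
    url: str | None = None


def tagged_dumps(obj: pydantic.BaseModel) -> str:
    d = obj.dict()
    d["__type__"] = type(obj).__name__  # tag used to pick the right model on load
    return json.dumps(d)


def tagged_loads(s: str):
    d = json.loads(s)
    match d.pop("__type__", None):
        case "PluginUsed":
            return PluginUsed.parse_obj(d)
        case None:
            return d
        case other:
            raise ValueError(f"Unknown type tag: {other}")
```
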
3 changes: 3 additions & 0 deletions inference/server/oasst_inference_server/models/chat.py
@@ -28,6 +28,8 @@ class DbMessage(SQLModel, table=True):
safety_label: str | None = Field(None)
safety_rots: str | None = Field(None)

used_plugin: inference.PluginUsed | None = Field(None, sa_column=sa.Column(pg.JSONB))

state: inference.MessageState = Field(inference.MessageState.manual)
work_parameters: inference.WorkParameters = Field(None, sa_column=sa.Column(pg.JSONB))
work_begin_at: datetime.datetime | None = Field(None)
@@ -68,6 +70,7 @@ def to_read(self) -> inference.MessageRead:
safety_level=self.safety_level,
safety_label=self.safety_label,
safety_rots=self.safety_rots,
used_plugin=self.used_plugin,
)


1 change: 1 addition & 0 deletions inference/server/oasst_inference_server/routes/chats.py
@@ -140,6 +140,7 @@ async def create_assistant_message(
work_parameters = inference.WorkParameters(
model_config=model_config,
sampling_parameters=request.sampling_parameters,
plugins=request.plugins,
)
assistant_message = await ucr.initiate_assistant_message(
parent_id=request.parent_id,
91 changes: 91 additions & 0 deletions inference/server/oasst_inference_server/routes/configs.py
@@ -1,9 +1,19 @@
import asyncio

import aiohttp
import fastapi
import pydantic
import yaml
from aiohttp.client_exceptions import ClientConnectorError, ServerTimeoutError
from fastapi import HTTPException
from loguru import logger
from oasst_inference_server.settings import settings
from oasst_shared import model_configs
from oasst_shared.schemas import inference

# NOTE: Populate this with plugins that we will provide out of the box
OA_PLUGINS = []

router = fastapi.APIRouter(
prefix="/configs",
tags=["configs"],
@@ -63,6 +73,16 @@ class ModelConfigInfo(pydantic.BaseModel):
repetition_penalty=1.2,
),
),
ParameterConfig(
name="k50-Plugins",
description="Top-k sampling with k=50 and temperature=0.35",
sampling_parameters=inference.SamplingParameters(
max_new_tokens=1024,
temperature=0.35,
top_k=50,
repetition_penalty=(1 / 0.90),
),
),
ParameterConfig(
name="nucleus9",
description="Nucleus sampling with p=0.9",
@@ -93,6 +113,44 @@ class ModelConfigInfo(pydantic.BaseModel):
]


async def fetch_plugin(url: str, retries: int = 3, timeout: float = 5.0) -> inference.PluginConfig:
async with aiohttp.ClientSession() as session:
for attempt in range(retries):
try:
async with session.get(url, timeout=timeout) as response:
content_type = response.headers.get("Content-Type")

if response.status == 200:
if "application/json" in content_type or url.endswith(".json"):
config = await response.json()
elif (
"application/yaml" in content_type
or "application/x-yaml" in content_type
or url.endswith(".yaml")
or url.endswith(".yml")
):
config = yaml.safe_load(await response.text())
else:
raise HTTPException(
status_code=400,
detail=f"Unsupported content type: {content_type}. Only JSON and YAML are supported.",
)

return inference.PluginConfig(**config)
elif response.status == 404:
raise HTTPException(status_code=404, detail="Plugin not found")
else:
raise HTTPException(status_code=response.status, detail="Unexpected status code")
except (ClientConnectorError, ServerTimeoutError) as e:
if attempt == retries - 1: # last attempt
raise HTTPException(status_code=500, detail=f"Request failed after {retries} retries: {e}")
await asyncio.sleep(2**attempt) # exponential backoff

except aiohttp.ClientError as e:
raise HTTPException(status_code=500, detail=f"Request failed: {e}")
raise HTTPException(status_code=500, detail="Failed to fetch plugin")


@router.get("/model_configs")
async def get_model_configs() -> list[ModelConfigInfo]:
return [
@@ -103,3 +161,36 @@ async def get_model_configs() -> list[ModelConfigInfo]:
for model_config_name in model_configs.MODEL_CONFIGS
if (settings.allowed_model_config_names == "*" or model_config_name in settings.allowed_model_config_names_list)
]


@router.post("/plugin_config")
async def get_plugin_config(plugin: inference.PluginEntry) -> inference.PluginEntry:
try:
plugin_config = await fetch_plugin(plugin.url)
except HTTPException as e:
logger.warning(f"Failed to fetch plugin config from {plugin.url}: {e.detail}")
raise fastapi.HTTPException(status_code=e.status_code, detail=e.detail)

return inference.PluginEntry(url=plugin.url, enabled=plugin.enabled, plugin_config=plugin_config)


@router.get("/builtin_plugins")
async def get_builtin_plugins() -> list[inference.PluginEntry]:
plugins = []

for plugin in OA_PLUGINS:
try:
plugin_config = await fetch_plugin(plugin.url)
except HTTPException as e:
logger.warning(f"Failed to fetch plugin config from {plugin.url}: {e.detail}")
continue

final_plugin: inference.PluginEntry = inference.PluginEntry(
url=plugin.url,
enabled=plugin.enabled,
trusted=plugin.trusted,
plugin_config=plugin_config,
)
plugins.append(final_plugin)

return plugins
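
For orientation, a minimal client-side sketch of exercising the two new routes; the base URL is an assumption, and `url`/`enabled` are the `PluginEntry` fields visible in this diff:

```python
# Illustrative client calls against the new /configs routes.
# The base URL is an assumption (local inference server).
import requests

BASE_URL = "http://localhost:8000"

# Ask the server to fetch and validate a third-party plugin descriptor.
resp = requests.post(
    f"{BASE_URL}/configs/plugin_config",
    json={"url": "https://www.klarna.com/.well-known/ai-plugin.json", "enabled": True},
    timeout=30,
)
resp.raise_for_status()
print(resp.json()["plugin_config"])

# List the out-of-the-box plugins (empty as long as OA_PLUGINS is []).
print(requests.get(f"{BASE_URL}/configs/builtin_plugins", timeout=30).json())
```
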
@@ -350,6 +350,7 @@ async def handle_generated_text_response(
message = await cr.complete_work(
message_id=message_id,
content=response.text,
used_plugin=response.used_plugin,
)
logger.info(f"Completed work for {message_id=}")
message_packet = inference.InternalFinishedMessageResponse(
2 changes: 2 additions & 0 deletions inference/server/oasst_inference_server/schemas/chat.py
@@ -14,6 +14,8 @@ class CreateAssistantMessageRequest(pydantic.BaseModel):
parent_id: str
model_config_name: str
sampling_parameters: inference.SamplingParameters = pydantic.Field(default_factory=inference.SamplingParameters)
plugins: list[inference.PluginEntry] = pydantic.Field(default_factory=list[inference.PluginEntry])
used_plugin: inference.PluginUsed | None = None


class PendingResponseEvent(pydantic.BaseModel):
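
As a sketch of how the new field is used, an assistant-message request body can now carry a list of plugin entries; the field names below come from this diff, while the values are placeholders:

```python
# Illustrative CreateAssistantMessageRequest payload with one plugin enabled.
# Field names follow this diff; all values are placeholders.
create_assistant_message_body = {
    "parent_id": "<parent-message-id>",
    "model_config_name": "<model-config-name>",
    "sampling_parameters": {"max_new_tokens": 1024, "temperature": 0.35, "top_k": 50},
    "plugins": [
        {"url": "https://www.klarna.com/.well-known/ai-plugin.json", "enabled": True}
    ],
}
```
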
72 changes: 72 additions & 0 deletions inference/worker/PLUGINS.md
@@ -0,0 +1,72 @@
# Plugin system for OA

This is a basic implementation of support for external augmentation and
OpenAI/ChatGPT plugins in Open-Assistant. In its current state it is more of a
proof of concept and should be used behind an experimental flag.

## Architecture

There is now a middleware layer between work.py (the worker) and the final
prompt that is passed to the inference server for generation and streaming.
This middleware checks whether a plugin is enabled in the UI and, if so, takes
over the job of creating curated pre-prompts for plugin usage and of issuing
the subsequent LLM calls (inner monologues) needed to produce the final,
externally **augmented** prompt. That prompt is passed back to the worker and
on to inference for the final generation/streaming of tokens to the frontend.
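
A rough outline of that flow in Python follows; all names are illustrative
stand-ins for the logic in chat_chain.py, not its actual API.

```python
# Hypothetical outline of the middleware described above. All names are
# illustrative stand-ins for the logic in chat_chain.py, not its real API.
from dataclasses import dataclass, field


@dataclass
class PluginEntry:
    url: str
    enabled: bool = True


@dataclass
class WorkParameters:
    plugins: list[PluginEntry] = field(default_factory=list)


def run_chat_chain(plugin: PluginEntry, prompt: str) -> str:
    # Stand-in for the inner-monologue loop: pick an endpoint from the plugin's
    # OpenAPI spec, call it, and fold the response into a curated pre-prompt.
    return f"{prompt}\n\n[context fetched via {plugin.url}]"


def prepare_prompt(params: WorkParameters, prompt: str) -> str:
    enabled = [p for p in params.plugins if p.enabled]
    if not enabled:
        return prompt  # no plugin enabled: prompt goes to inference unchanged
    return run_chat_chain(enabled[0], prompt)
```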

## Plugins

Plugins are, in essence, thin wrappers around APIs that help the LLM use those
APIs more precisely and reliably, which makes them useful and powerful
augmentation tools for Open-Assistant. A plugin has two main parts: the
ai-plugin.json file, which is the plugin's main descriptor, and the OpenAPI
specification of the plugin's API.

The OpenAI plugin
[specification](https://platform.openai.com/docs/plugins/getting-started) is
currently partially supported by this system.

For now, only plugins that require no authentication and expose only **GET**
endpoints are supported. Some examples:

- https://www.klarna.com/.well-known/ai-plugin.json
- https://www.joinmilo.com/.well-known/ai-plugin.json
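
For illustration, a small sketch of reading such a descriptor and the OpenAPI
spec it points to (field names follow the OpenAI plugin spec; the server-side
counterpart is `fetch_plugin` in configs.py):

```python
# Minimal sketch: fetch a plugin's ai-plugin.json, then the OpenAPI spec it
# points to. Illustrative; not the server's fetch_plugin implementation.
import requests
import yaml

descriptor = requests.get(
    "https://www.klarna.com/.well-known/ai-plugin.json", timeout=5
).json()
print(descriptor["name_for_model"], "-", descriptor["description_for_model"])

# The descriptor's api.url points at the OpenAPI spec describing the endpoints.
spec_text = requests.get(descriptor["api"]["url"], timeout=5).text
openapi_spec = yaml.safe_load(spec_text)  # safe_load handles JSON as well as YAML
print(list(openapi_spec.get("paths", {}).keys()))
```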

Adding support for the other request types would be quite tricky with the
current approach. It would be best to drop the current hand-holding explanation
of the API to the LLM and just show it the complete JSON/YAML content, but for
that to be as reliable as the current approach we would need a larger context
size and somewhat more capable models.

Quite a few more plugins can be found at the
[wellknown.ai plugin "store"](https://www.wellknown.ai/).

One idea behind the plugin system is that we can ship some internal OA plugins
out of the box, while an unlimited number of third-party, community-developed
plugins can be added as well.

### Notes on reliability, performance, and limitations of the plugin system

Performance can vary a lot depending on the models and plugins used; some work
better than others, but this should improve as models get better. The biggest
limitations at the moment are context size and instruction-following
capability. These are mitigated with prompt tricks, truncation of the plugin
OpenAPI descriptions, and dynamically including/excluding parts of the prompts
during the internal processing of the intermediate texts (inner monologues).
Further limitations and possible alternatives are explained in code comments.
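
As one concrete example of those tricks, a hedged sketch of trimming an
endpoint description to fit a prompt budget (the actual logic in
chat_chain_utils.py may differ):

```python
# Illustrative helper in the spirit of the truncation trick mentioned above;
# the actual prompt-budget handling in chat_chain_utils.py may differ.
def truncate_description(text: str, max_chars: int = 300) -> str:
    """Trim an OpenAPI endpoint description so it fits the prompt budget."""
    if len(text) <= max_chars:
        return text
    return text[: max_chars - 3].rstrip() + "..."
```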

The current approach is somewhat hybrid and relies on the zero-shot
capabilities of the model. There will be another branch of the plugin system
that takes a slightly different approach, utilizing smaller embedding
transformer models and vector stores, so we can A/B test the system alongside
new OA model releases.

## Relevant files for the inference side of the plugin system

- chat_chain.py
- chat_chain_utils.py _(tweaking tools/plugin description string generation can
  help for some models)_
- chat_chain_prompts.py _(tweaking prompts can also help)_