[ENH] Add support for Ollama assistants #376

Merged

Changes from all commits (32 commits)

868febb  Added almost empty `OllamaApiAssistant` (smokestacklightnin, Mar 22, 2024)
a0c8499  Add `_make_system_content` method (smokestacklightnin, Mar 22, 2024)
ae0720a  Add preliminary (untested) `_call_api` method (smokestacklightnin, Mar 22, 2024)
eeba8a1  Using JSONL for responses (smokestacklightnin, Apr 2, 2024)
420c9e8  Add kwargs for compatibility and TODO messages to remove in a future … (smokestacklightnin, Mar 30, 2024)
20fb764  Add Ollama gemma:2b model (smokestacklightnin, Mar 30, 2024)
906f2a1  Fix `OllamaApiAssistant._call_api` signature by adding types (smokestacklightnin, Mar 31, 2024)
50f19a1  Add temperature option (smokestacklightnin, Mar 31, 2024)
0ce77d8  Add `_assert_api_call_is_success()` (smokestacklightnin, Apr 7, 2024)
7bbbafb  Add `answer()` (smokestacklightnin, Apr 7, 2024)
301c815  Add `__init__()` (smokestacklightnin, Apr 7, 2024)
a4a2608  Set url through initializer or environment variable (smokestacklightnin, Apr 7, 2024)
14e14c5  Add `is_available()` (smokestacklightnin, Apr 7, 2024)
0cae498  Rename Gemma2B to OllamaGemma2B (smokestacklightnin, Apr 7, 2024)
d02b501  Remove unnecessary `else` clause (smokestacklightnin, Apr 10, 2024)
1ce1982  Handle error in http response (smokestacklightnin, Apr 10, 2024)
fd5c34b  Remove unnecessary `_call_api()` abstraction (smokestacklightnin, Apr 10, 2024)
6f2055c  Fix typing errors (smokestacklightnin, Apr 10, 2024)
e5e8e30  Add docstring (smokestacklightnin, Apr 10, 2024)
72161a0  Add `OllamaPhi2` (smokestacklightnin, Apr 10, 2024)
c5e79e0  Remove unnecessary exclusion from test (smokestacklightnin, Apr 10, 2024)
f6edb19  Simplify check for availability of Ollama model (smokestacklightnin, Apr 10, 2024)
6460d1e  Simplify call to superclass `is_available()` (smokestacklightnin, Apr 10, 2024)
c9b2e01  Correct incorrect grammar on system instruction (smokestacklightnin, Apr 10, 2024)
086ce23  Add several Ollama models (smokestacklightnin, Apr 11, 2024)
9724dd6  Order alphabetically (smokestacklightnin, Apr 11, 2024)
9bebbb0  Add Ollama to listings in docs (smokestacklightnin, Apr 12, 2024)
bc211d3  Merge branch 'main' into assistants/ollama/basic-functionality (pmeier, May 28, 2024)
4a737e0  refactor streaming again (pmeier, May 28, 2024)
3e2a682  more (pmeier, May 28, 2024)
5a4d89d  fix (pmeier, May 28, 2024)
9de0920  cleanup (pmeier, May 30, 2024)

8 changes: 8 additions & 0 deletions docs/examples/gallery_streaming.py
@@ -31,6 +31,14 @@
 # - [ragna.assistants.Gpt4][]
 # - [llamafile](https://github.com/Mozilla-Ocho/llamafile)
 # - [ragna.assistants.LlamafileAssistant][]
+# - [Ollama](https://ollama.com/)
+# - [ragna.assistants.OllamaGemma2B][]
+# - [ragna.assistants.OllamaLlama2][]
+# - [ragna.assistants.OllamaLlava][]
+# - [ragna.assistants.OllamaMistral][]
+# - [ragna.assistants.OllamaMixtral][]
+# - [ragna.assistants.OllamaOrcaMini][]
+# - [ragna.assistants.OllamaPhi2][]
 
 from ragna import assistants
 
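These gallery entries are the user-facing side of the PR. A minimal usage sketch for one of the new assistants in the streaming gallery, assuming a local Ollama server with the `mistral` model already pulled, ragna's demo source storage, and an async context such as a notebook (the file name `ragna.txt` is a placeholder):

```python
from ragna import Rag, assistants, source_storages

# Minimal streaming sketch; `documents` points at any local text file.
chat = Rag().chat(
    documents=["ragna.txt"],
    source_storage=source_storages.RagnaDemoSourceStorage,
    assistant=assistants.OllamaMistral,
)
await chat.prepare()

# With stream=True the returned message yields chunks as they arrive.
message = await chat.answer("What is Ragna?", stream=True)
async for chunk in message:
    print(chunk, end="")
```
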
8 changes: 8 additions & 0 deletions docs/tutorials/gallery_python_api.py
@@ -87,6 +87,14 @@
 # - [ragna.assistants.Jurassic2Ultra][]
 # - [llamafile](https://github.com/Mozilla-Ocho/llamafile)
 # - [ragna.assistants.LlamafileAssistant][]
+# - [Ollama](https://ollama.com/)
+# - [ragna.assistants.OllamaGemma2B][]
+# - [ragna.assistants.OllamaLlama2][]
+# - [ragna.assistants.OllamaLlava][]
+# - [ragna.assistants.OllamaMistral][]
+# - [ragna.assistants.OllamaMixtral][]
+# - [ragna.assistants.OllamaOrcaMini][]
+# - [ragna.assistants.OllamaPhi2][]
 #
 # !!! note
 #
16 changes: 16 additions & 0 deletions ragna/assistants/__init__.py
@@ -6,6 +6,13 @@
     "CommandLight",
     "GeminiPro",
     "GeminiUltra",
+    "OllamaGemma2B",
+    "OllamaPhi2",
+    "OllamaLlama2",
+    "OllamaLlava",
+    "OllamaMistral",
+    "OllamaMixtral",
+    "OllamaOrcaMini",
     "Gpt35Turbo16k",
     "Gpt4",
     "Jurassic2Ultra",
@@ -19,6 +26,15 @@
 from ._demo import RagnaDemoAssistant
 from ._google import GeminiPro, GeminiUltra
 from ._llamafile import LlamafileAssistant
+from ._ollama import (
+    OllamaGemma2B,
+    OllamaLlama2,
+    OllamaLlava,
+    OllamaMistral,
+    OllamaMixtral,
+    OllamaOrcaMini,
+    OllamaPhi2,
+)
 from ._openai import Gpt4, Gpt35Turbo16k
 
 # isort: split
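The new `ragna/assistants/_ollama.py` module itself is not shown in this diff excerpt. Judging from the commit messages (JSONL responses, URL set through the initializer or an environment variable, `_make_system_content`, one subclass per model), its shape is presumably something like the sketch below. The environment variable name `RAGNA_OLLAMA_URL`, the default URL, the request payload, and the `super().__init__()` call are assumptions, not the PR's actual identifiers:

```python
import os
from typing import AsyncIterator, cast

from ragna.core import Source

from ._http_api import HttpApiAssistant, HttpStreamingProtocol


class OllamaApiAssistant(HttpApiAssistant):
    """Sketch of a base class for assistants backed by a local Ollama server."""

    _API_KEY_ENV_VAR = None  # assumes the base class allows key-less assistants
    _STREAMING_PROTOCOL = HttpStreamingProtocol.JSONL  # Ollama streams JSON lines
    _MODEL: str

    @classmethod
    def display_name(cls) -> str:
        return f"Ollama/{cls._MODEL}"

    def __init__(self, url: str = "http://localhost:11434/api/chat") -> None:
        super().__init__()
        # Hypothetical variable name; per the commits, the URL comes from the
        # initializer or an environment variable.
        self._url = os.environ.get("RAGNA_OLLAMA_URL", url)

    def _make_system_content(self, sources: list[Source]) -> str:
        # Instruct the model to answer only from the retrieved sources.
        instruction = "Answer the question based only on the following sources:\n\n"
        return instruction + "\n".join(source.content for source in sources)

    async def answer(
        self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
    ) -> AsyncIterator[str]:
        # Ollama's /api/chat endpoint emits one JSON object per line while
        # streaming; the final object carries "done": true.
        async for data in self._call_api(
            "POST",
            self._url,
            json={
                "model": self._MODEL,
                "messages": [
                    {"role": "system", "content": self._make_system_content(sources)},
                    {"role": "user", "content": prompt},
                ],
                "stream": True,
                "options": {"num_predict": max_new_tokens},
            },
        ):
            if not data.get("done", False):
                yield cast(str, data["message"]["content"])


class OllamaGemma2B(OllamaApiAssistant):
    _MODEL = "gemma:2b"
```
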
10 changes: 5 additions & 5 deletions ragna/assistants/_ai21labs.py
@@ -7,6 +7,7 @@
 
 class Ai21LabsAssistant(HttpApiAssistant):
     _API_KEY_ENV_VAR = "AI21_API_KEY"
+    _STREAMING_PROTOCOL = None
     _MODEL_TYPE: str
 
     @classmethod
@@ -27,7 +28,8 @@ async def answer(
         # See https://docs.ai21.com/reference/j2-chat-api#chat-api-parameters
         # See https://docs.ai21.com/reference/j2-complete-api-ref#api-parameters
         # See https://docs.ai21.com/reference/j2-chat-api#understanding-the-response
-        response = await self._client.post(
+        async for data in self._call_api(
+            "POST",
             f"https://api.ai21.com/studio/v1/j2-{self._MODEL_TYPE}/chat",
             headers={
                 "accept": "application/json",
@@ -46,10 +48,8 @@
                 ],
                 "system": self._make_system_content(sources),
             },
-        )
-        await self._assert_api_call_is_success(response)
-
-        yield cast(str, response.json()["outputs"][0]["text"])
+        ):
+            yield cast(str, data["outputs"][0]["text"])
 
 
 # The Jurassic2Mid assistant receives a 500 internal service error from the remote
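`ragna/assistants/_http_api.py`, where `_call_api` and `HttpStreamingProtocol` live, is also outside this excerpt. The call sites suggest a single entry point that dispatches on `_STREAMING_PROTOCOL` and always yields parsed JSON payloads, so even a non-streaming API like AI21's is consumed with `async for`. A sketch of the non-streaming branch, with the function name and error handling as illustrative stand-ins:

```python
import enum
from typing import Any, AsyncIterator

import httpx


class HttpStreamingProtocol(enum.Enum):
    SSE = enum.auto()
    JSONL = enum.auto()
    JSON = enum.auto()


async def call_api_unstreamed(
    client: httpx.AsyncClient, method: str, url: str, **kwargs: Any
) -> AsyncIterator[Any]:
    # The _STREAMING_PROTOCOL = None case: perform one regular request and
    # yield its parsed body exactly once, giving callers the same `async for`
    # shape as genuinely streaming assistants.
    response = await client.request(method, url, **kwargs)
    response.raise_for_status()  # stand-in for _assert_api_call_is_success()
    yield response.json()
```
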
5 changes: 3 additions & 2 deletions ragna/assistants/_anthropic.py
@@ -2,11 +2,12 @@
 
 from ragna.core import PackageRequirement, RagnaException, Requirement, Source
 
-from ._http_api import HttpApiAssistant
+from ._http_api import HttpApiAssistant, HttpStreamingProtocol
 
 
 class AnthropicAssistant(HttpApiAssistant):
     _API_KEY_ENV_VAR = "ANTHROPIC_API_KEY"
+    _STREAMING_PROTOCOL = HttpStreamingProtocol.SSE
     _MODEL: str
 
     @classmethod
@@ -40,7 +41,7 @@
     ) -> AsyncIterator[str]:
         # See https://docs.anthropic.com/claude/reference/messages_post
         # See https://docs.anthropic.com/claude/reference/streaming
-        async for data in self._stream_sse(
+        async for data in self._call_api(
             "POST",
             "https://api.anthropic.com/v1/messages",
             headers={
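For `HttpStreamingProtocol.SSE`, the shared helper presumably absorbs what `_stream_sse` did before. A minimal sketch of such a branch using the `httpx-sse` package; whether the PR's helper uses exactly this library internally is an assumption:

```python
import json
from typing import Any, AsyncIterator

import httpx
from httpx_sse import aconnect_sse


async def stream_sse(
    client: httpx.AsyncClient, method: str, url: str, **kwargs: Any
) -> AsyncIterator[Any]:
    # Each server-sent event carries a JSON document in its data field,
    # e.g. Anthropic's message deltas.
    async with aconnect_sse(client, method, url, **kwargs) as event_source:
        async for sse in event_source.aiter_sse():
            yield json.loads(sse.data)
```
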
5 changes: 3 additions & 2 deletions ragna/assistants/_cohere.py
@@ -2,11 +2,12 @@
 
 from ragna.core import RagnaException, Source
 
-from ._http_api import HttpApiAssistant
+from ._http_api import HttpApiAssistant, HttpStreamingProtocol
 
 
 class CohereAssistant(HttpApiAssistant):
     _API_KEY_ENV_VAR = "COHERE_API_KEY"
+    _STREAMING_PROTOCOL = HttpStreamingProtocol.JSONL
     _MODEL: str
 
     @classmethod
@@ -29,7 +30,7 @@
         # See https://docs.cohere.com/docs/cochat-beta
         # See https://docs.cohere.com/reference/chat
         # See https://docs.cohere.com/docs/retrieval-augmented-generation-rag
-        async for event in self._stream_jsonl(
+        async for event in self._call_api(
             "POST",
             "https://api.cohere.ai/v1/chat",
             headers={
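`HttpStreamingProtocol.JSONL` is the simplest streaming flavor: the server emits one complete JSON document per line, so no incremental parser is needed. A sketch under the same caveats as above:

```python
import json
from typing import Any, AsyncIterator

import httpx


async def stream_jsonl(
    client: httpx.AsyncClient, method: str, url: str, **kwargs: Any
) -> AsyncIterator[Any]:
    # Newline-delimited JSON: every non-empty line is a complete document.
    async with client.stream(method, url, **kwargs) as response:
        response.raise_for_status()
        async for line in response.aiter_lines():
            if line.strip():
                yield json.loads(line)
```
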
49 changes: 11 additions & 38 deletions ragna/assistants/_google.py
@@ -1,38 +1,15 @@
 from typing import AsyncIterator
 
-from ragna._compat import anext
-from ragna.core import PackageRequirement, Requirement, Source
+from ragna.core import Source
 
-from ._http_api import HttpApiAssistant
-
-
-# ijson does not support reading from an (async) iterator, but only from file-like
-# objects, i.e. https://docs.python.org/3/tutorial/inputoutput.html#methods-of-file-objects.
-# See https://github.com/ICRAR/ijson/issues/44 for details.
-# ijson actually doesn't care about most of the file interface and only requires the
-# read() method to be present.
-class AsyncIteratorReader:
-    def __init__(self, ait: AsyncIterator[bytes]) -> None:
-        self._ait = ait
-
-    async def read(self, n: int) -> bytes:
-        # n is usually used to indicate how many bytes to read, but since we want to
-        # return a chunk as soon as it is available, we ignore the value of n. The only
-        # exception is n == 0, which is used by ijson to probe the return type and
-        # set up decoding.
-        if n == 0:
-            return b""
-        return await anext(self._ait, b"")  # type: ignore[call-arg]
+from ._http_api import HttpApiAssistant, HttpStreamingProtocol
 
 
 class GoogleAssistant(HttpApiAssistant):
     _API_KEY_ENV_VAR = "GOOGLE_API_KEY"
+    _STREAMING_PROTOCOL = HttpStreamingProtocol.JSON
     _MODEL: str
 
-    @classmethod
-    def _extra_requirements(cls) -> list[Requirement]:
-        return [PackageRequirement("ijson")]
-
     @classmethod
     def display_name(cls) -> str:
         return f"Google/{cls._MODEL}"
@@ -51,9 +28,7 @@ def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str:
     async def answer(
         self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
     ) -> AsyncIterator[str]:
-        import ijson
-
-        async with self._client.stream(
+        async for chunk in self._call_api(
             "POST",
             f"https://generativelanguage.googleapis.com/v1beta/models/{self._MODEL}:streamGenerateContent",
             params={"key": self._api_key},
@@ -64,7 +39,10 @@ async def answer(
                 ],
                 # https://ai.google.dev/docs/safety_setting_gemini
                 "safetySettings": [
-                    {"category": f"HARM_CATEGORY_{category}", "threshold": "BLOCK_NONE"}
+                    {
+                        "category": f"HARM_CATEGORY_{category}",
+                        "threshold": "BLOCK_NONE",
+                    }
                     for category in [
                         "HARASSMENT",
                         "HATE_SPEECH",
@@ -78,14 +56,9 @@
                     "maxOutputTokens": max_new_tokens,
                 },
             },
-        ) as response:
-            await self._assert_api_call_is_success(response)
-
-            async for chunk in ijson.items(
-                AsyncIteratorReader(response.aiter_bytes(1024)),
-                "item.candidates.item.content.parts.item.text",
-            ):
-                yield chunk
+            parse_kwargs=dict(item="item.candidates.item.content.parts.item.text"),
+        ):
+            yield chunk
 
 
 class GeminiPro(GoogleAssistant):
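For `HttpStreamingProtocol.JSON` the response body is a single JSON array that arrives incrementally, which is why the `ijson` machinery deleted above has presumably moved into the shared helper, with the item path now supplied via `parse_kwargs`. A sketch that reuses the removed `AsyncIteratorReader` nearly verbatim (the `stream_json` wrapper itself is illustrative):

```python
from typing import Any, AsyncIterator

import httpx
import ijson  # incremental JSON parser


class AsyncIteratorReader:
    # Unchanged from the code removed above: adapts an async byte iterator
    # to the minimal file-like interface (read()) that ijson expects.
    def __init__(self, ait: AsyncIterator[bytes]) -> None:
        self._ait = ait

    async def read(self, n: int) -> bytes:
        # ijson probes with n == 0 to detect the return type and set up
        # decoding; otherwise we hand back the next chunk regardless of the
        # requested size.
        if n == 0:
            return b""
        # anext() is a Python 3.10+ builtin; ragna ships a _compat shim for
        # older versions, as the removed import shows.
        return await anext(self._ait, b"")


async def stream_json(
    client: httpx.AsyncClient, method: str, url: str, *, item: str, **kwargs: Any
) -> AsyncIterator[Any]:
    # `item` matches the parse_kwargs value seen in the Google assistant,
    # e.g. "item.candidates.item.content.parts.item.text".
    async with client.stream(method, url, **kwargs) as response:
        response.raise_for_status()
        async for chunk in ijson.items(
            AsyncIteratorReader(response.aiter_bytes(1024)), item
        ):
            yield chunk
```
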