Skip to content
This repository has been archived by the owner on Feb 12, 2024. It is now read-only.

Commit

Permalink
Feature/add ollama (#138)
Browse files Browse the repository at this point in the history
  • Loading branch information
richardblythman committed Jan 13, 2024
1 parent d45900f commit 49de39d
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 1 deletion.
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ retrying = "^1.3.4"
anthropic_support = ["anthropic"]
hf_support = ["accelerate", "datasets", "torch", "transformers"]
vllm_support = ["accelerate", "torch", "vllm"]
ollama_support = ["litellm"]

all = [
# anthropic
Expand All @@ -37,6 +38,7 @@ all = [
"datasets",
"torch",
"transformers",
"litellm"
]
all_with_extras = [
# all
Expand All @@ -49,6 +51,8 @@ all_with_extras = [
"transformers",
# More Extras
"vllm",
# ollama
"litellm"
]
# To export dependencies to pip, use:
# poetry export -f requirements.txt --with dev --without-hashes --output requirements-dev.txt
Expand Down
1 change: 1 addition & 0 deletions synthesizer/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class LLMProviderName(Enum):
LLAMACPP = "llamacpp"
LITE_LLM = "lite-llm"
SCIPHI = "sciphi"
OLLAMA = "ollama"


class RAGProviderName(Enum):
Expand Down
2 changes: 2 additions & 0 deletions synthesizer/interface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from synthesizer.interface.llm.sciphi_interface import SciPhiLLMInterface
from synthesizer.interface.llm.vllm_interface import vLLMInterface
from synthesizer.interface.llm_interface_manager import LLMInterfaceManager
from synthesizer.interface.llm.ollama_interface import OllamaLLMInterface
from synthesizer.interface.rag.agent_search import (
AgentSearchRAGConfig,
AgentSearchRAGInterface,
Expand All @@ -38,6 +39,7 @@
"OpenAILLMInterface",
"SciPhiLLMInterface",
"vLLMInterface",
"OllamaLLMInterface",
# RAG
"RAGInterfaceManager",
"RAGProviderConfig",
Expand Down
60 changes: 60 additions & 0 deletions synthesizer/interface/llm/ollama_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""A module for interfacing with Ollama"""
import logging

from synthesizer.interface.base import LLMInterface, LLMProviderName
from synthesizer.interface.llm_interface_manager import llm_interface
from synthesizer.llm import GenerationConfig, OllamaConfig, OllamaLLM

logger = logging.getLogger(__name__)


@llm_interface
class OllamaLLMInterface(LLMInterface):
"""A class to interface with Ollama."""

provider_name = LLMProviderName.OLLAMA
system_message = "You are a helpful assistant."

def __init__(
self,
config: OllamaConfig,
*args,
**kwargs,
) -> None:
self.config = config
self._model = OllamaLLM(config)

def get_completion(
self, prompt: str, generation_config: GenerationConfig
) -> str:
"""Get a completion from the Ollama based on the provided prompt."""

logger.debug(
f"Getting completion from Ollama for model={generation_config.model_name}"
)
if "instruct" in generation_config.model_name:
return self.model.get_instruct_completion(
prompt, generation_config
)
else:
return self._model.get_chat_completion(
[
{
"role": "system",
"content": OllamaLLMInterface.system_message,
},
{"role": "user", "content": prompt},
],
generation_config,
)

def get_chat_completion(
self, conversation: list[dict], generation_config: GenerationConfig
) -> str:
raise NotImplementedError(
"Chat completion not yet implemented for Ollama."
)

@property
def model(self) -> OllamaLLM:
return self._model
2 changes: 1 addition & 1 deletion synthesizer/interface/rag/bing_search/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
) -> None:
super().__init__(config)
self.config: BingRAGConfig = config
print('self.config = ', self.config)
print("self.config = ", self.config)
api_key = self.config.api_key or os.getenv("BING_API_KEY")
if not api_key:
raise ValueError(
Expand Down
3 changes: 3 additions & 0 deletions synthesizer/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from synthesizer.llm.models.openai_llm import OpenAIConfig, OpenAILLM
from synthesizer.llm.models.sciphi_llm import SciPhiConfig, SciPhiLLM
from synthesizer.llm.models.vllm_llm import vLLM, vLLMConfig
from synthesizer.llm.models.ollama_llm import OllamaConfig, OllamaLLM

__all__ = [
# Base
Expand All @@ -26,4 +27,6 @@
"SciPhiLLM",
"vLLMConfig",
"vLLM",
"OllamaConfig",
"OllamaLLM",
]
65 changes: 65 additions & 0 deletions synthesizer/llm/models/ollama_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""A module for creating Ollama model abstractions."""
import os
from dataclasses import dataclass

from litellm import completion

from synthesizer.core import LLMProviderName
from synthesizer.llm.base import LLM, GenerationConfig, LLMConfig
from synthesizer.llm.config_manager import model_config


@model_config
@dataclass
class OllamaConfig(LLMConfig):
"""Configuration for Ollama models."""

# Base
provider_name: LLMProviderName = LLMProviderName.OLLAMA
api_base: str = "http://localhost:11434"


class OllamaLLM(LLM):
"""A concrete class for creating Ollama models."""

def __init__(
self,
config: OllamaConfig,
*args,
**kwargs,
) -> None:
super().__init__()
self.config: OllamaConfig = config

# set the config here, again, for typing purposes
if not isinstance(self.config, OllamaConfig):
raise ValueError(
"The provided config must be an instance of OllamaConfig."
)

def get_chat_completion(
self,
messages: list[dict[str, str]],
generation_config: GenerationConfig,
) -> str:
"""Get a chat completion from Ollama based on the provided prompt."""

# Create the chat completion
response = completion(
model="ollama/mistral",
messages=messages,
api_base=self.config.api_base,
stream=generation_config.do_stream,
)

return response.choices[0].message["content"]

def get_instruct_completion(
self,
messages: list[dict[str, str]],
generation_config: GenerationConfig,
) -> str:
"""Get an instruction completion from Ollama."""
raise NotImplementedError(
"Instruction completion is not yet supported for Ollama."
)

0 comments on commit 49de39d

Please sign in to comment.