Add LlamaCppGenerator
awinml committed Jan 5, 2024
1 parent c51ed18 commit 3d70140
Showing 5 changed files with 330 additions and 0 deletions.
File renamed without changes.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -180,6 +180,7 @@ exclude_lines = [

[tool.pytest.ini_options]
minversion = "6.0"
addopts = "-vv"
markers = [
"unit: unit tests",
"integration: integration tests"
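The `integration` marker registered above gates the slower tests in `tests/test_generator.py` that download a real GGUF model. As a minimal sketch (assuming only that pytest is installed), the same marker expression can be passed programmatically instead of on the command line:

```python
import sys

import pytest

# Run only the fast, mocked tests; skip anything marked "integration".
# Equivalent to invoking: pytest -vv -m "not integration" tests/
exit_code = pytest.main(["-m", "not integration", "tests/"])
sys.exit(int(exit_code))
```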
98 changes: 98 additions & 0 deletions src/llama_cpp_haystack/generator.py
@@ -0,0 +1,98 @@
import logging
from typing import Any, Dict, List, Optional

from haystack import component
from llama_cpp import Llama

logger = logging.getLogger(__name__)


@component
class LlamaCppGenerator:
"""
Generator for using a model with Llama.cpp.
This component provides an interface for generating text with a quantized (GGUF) model via llama.cpp.
Usage example:
```python
from llama_cpp_haystack import LlamaCppGenerator
generator = LlamaCppGenerator(model_path="zephyr-7b-beta.Q4_0.gguf", n_ctx=2048, n_batch=512)
print(generator.run("Who is the best American actor?", generation_kwargs={"max_tokens": 128}))
# {'replies': ['John Cusack'], 'meta': [{"object": "text_completion", ...}]}
```
"""

def __init__(
self,
model_path: str,
n_ctx: Optional[int] = 0,
n_batch: Optional[int] = 512,
model_kwargs: Optional[Dict[str, Any]] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
):
"""
:param model_path: The path to a quantized model file for text generation,
for example, "zephyr-7b-beta.Q4_0.gguf".
If `model_path` is also specified in `model_kwargs`, this parameter is ignored.
:param n_ctx: The number of tokens in the context. When set to 0, the context is taken from the model.
If `n_ctx` is also specified in `model_kwargs`, this parameter is ignored.
:param n_batch: Prompt processing maximum batch size. Defaults to 512.
If `n_batch` is also specified in `model_kwargs`, this parameter is ignored.
:param model_kwargs: Dictionary containing keyword arguments used to initialize the LLM for text generation.
These keyword arguments provide fine-grained control over the model loading.
In case of duplication, these kwargs override `model_path`, `n_ctx`, and `n_batch` init parameters.
See Llama.cpp's [documentation](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__init__)
for more information on the available kwargs.
:param generation_kwargs: A dictionary containing keyword arguments to customize text generation.
Some examples: `max_tokens`, `temperature`, `top_k`, `top_p`,...
See Llama.cpp's documentation for more information:
https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_completion
"""

model_kwargs = model_kwargs or {}
generation_kwargs = generation_kwargs or {}

# check if the model_kwargs contain the essential parameters
# otherwise, populate them with values from init parameters
model_kwargs.setdefault("model_path", model_path)
model_kwargs.setdefault("n_ctx", n_ctx)
model_kwargs.setdefault("n_batch", n_batch)

self.model_path = model_path
self.n_ctx = n_ctx
self.n_batch = n_batch
self.model_kwargs = model_kwargs
self.generation_kwargs = generation_kwargs
self.model = None

def warm_up(self):
if self.model is None:
self.model = Llama(**self.model_kwargs)

@component.output_types(replies=List[str], meta=List[Dict[str, Any]])
def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
"""
Run the text generation model on the given prompt.
:param prompt: A string representing the prompt.
:param generation_kwargs: A dictionary containing keyword arguments to customize text generation.
Some examples: `max_tokens`, `temperature`, `top_k`, `top_p`,...
See Llama.cpp's documentation for more information:
https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_completion
:return: A dictionary of the returned responses and metadata.
"""
if self.model is None:
error_msg = "The model has not been loaded. Please call warm_up() before running."
raise RuntimeError(error_msg)

if not prompt:
return {"replies": []}

# merge generation kwargs from init method with those from run method
updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

output = self.model.create_completion(prompt=prompt, **updated_generation_kwargs)
replies = [output["choices"][0]["text"]]

return {"replies": replies, "meta": [output]}
2 changes: 2 additions & 0 deletions tests/models/.gitignore
@@ -0,0 +1,2 @@
*
!.gitignore
229 changes: 229 additions & 0 deletions tests/test_generator.py
@@ -0,0 +1,229 @@
import os
import urllib.request
from pathlib import Path
from unittest.mock import MagicMock

import pytest
from haystack import Document, Pipeline
from haystack.components.builders.answer_builder import AnswerBuilder
from haystack.components.builders.prompt_builder import PromptBuilder
from haystack.components.retrievers import InMemoryBM25Retriever
from haystack.document_stores import InMemoryDocumentStore

from llama_cpp_haystack import LlamaCppGenerator


@pytest.fixture
def model_path():
return Path(__file__).parent / "models"


def download_file(file_link, filename, capsys):
# Checks if the file already exists before downloading
if not os.path.isfile(filename):
urllib.request.urlretrieve(file_link, filename) # noqa: S310
with capsys.disabled():
print("\nModel file downloaded successfully.")
else:
with capsys.disabled():
print("\nModel file already exists.")


class TestLlamaCppGenerator:
@pytest.fixture
def generator(self, model_path, capsys):
ggml_model_path = "https://huggingface.co/TheBloke/phi-2-GGUF/resolve/main/phi-2.Q3_K_S.gguf"
filename = "phi-2.Q3_K_S.gguf"

# Download GGUF model from HuggingFace
download_file(ggml_model_path, str(model_path / filename), capsys)

model_path = str(model_path / filename)
generator = LlamaCppGenerator(model_path=model_path, n_ctx=128, n_batch=128)
generator.warm_up()
return generator

@pytest.fixture
def generator_mock(self):
mock_model = MagicMock()
generator = LlamaCppGenerator(model_path="test_model.gguf", n_ctx=2048, n_batch=512)
generator.model = mock_model
return generator, mock_model

def test_default_init(self):
"""
Test default initialization parameters.
"""
generator = LlamaCppGenerator(model_path="test_model.gguf")

assert generator.model_path == "test_model.gguf"
assert generator.n_ctx == 0
assert generator.n_batch == 512
assert generator.model_kwargs == {"model_path": "test_model.gguf", "n_ctx": 0, "n_batch": 512}
assert generator.generation_kwargs == {}

def test_custom_init(self):
"""
Test custom initialization parameters.
"""
generator = LlamaCppGenerator(
model_path="test_model.gguf",
n_ctx=2048,
n_batch=512,
)

assert generator.model_path == "test_model.gguf"
assert generator.n_ctx == 2048
assert generator.n_batch == 512
assert generator.model_kwargs == {"model_path": "test_model.gguf", "n_ctx": 2048, "n_batch": 512}
assert generator.generation_kwargs == {}

def test_ignores_model_path_if_specified_in_model_kwargs(self, model_path):
"""
Test that model_path is ignored if already specified in model_kwargs.
"""
generator = LlamaCppGenerator(
model_path=str(model_path / "phi-2.Q3_K_S.gguf"),
n_ctx=512,
n_batch=512,
model_kwargs={"model_path": "other_model.gguf"},
)
assert generator.model_kwargs["model_path"] == "other_model.gguf"

def test_ignores_n_ctx_if_specified_in_model_kwargs(self, model_path):
"""
Test that n_ctx is ignored if already specified in model_kwargs.
"""
generator = LlamaCppGenerator(
model_path=str(model_path / "phi-2.Q3_K_S.gguf"), n_ctx=512, n_batch=512, model_kwargs={"n_ctx": 1024}
)
assert generator.model_kwargs["n_ctx"] == 1024

def test_ignores_n_batch_if_specified_in_model_kwargs(self, model_path):
"""
Test that n_batch is ignored if already specified in model_kwargs.
"""
generator = LlamaCppGenerator(
model_path=str(model_path / "phi-2.Q3_K_S.gguf"), n_ctx=512, n_batch=512, model_kwargs={"n_batch": 1024}
)
assert generator.model_kwargs["n_batch"] == 1024

def test_raises_error_without_warm_up(self, model_path):
"""
Test that the generator raises an error if warm_up() is not called before running.
"""
generator = LlamaCppGenerator(model_path=str(model_path / "phi-2.Q3_K_S.gguf"), n_ctx=512, n_batch=512)
with pytest.raises(RuntimeError):
generator.run("What is the capital of China?")

def test_run_with_empty_prompt(self, generator_mock):
"""
Test that an empty prompt returns an empty list of replies.
"""
generator, _ = generator_mock
result = generator.run("")
assert result["replies"] == []

def test_run_with_valid_prompt(self, generator_mock):
"""
Test that a valid prompt returns a list of replies.
"""
generator, mock_model = generator_mock
mock_output = {
"choices": [{"text": "Generated text"}],
"metadata": {"other_info": "Some metadata"},
}
mock_model.create_completion.return_value = mock_output
result = generator.run("Test prompt")
assert result["replies"] == ["Generated text"]
assert result["meta"] == [mock_output]

def test_run_with_generation_kwargs(self, generator_mock):
"""
Test that a valid prompt with generation kwargs returns a list of replies.
"""
generator, mock_model = generator_mock
mock_output = {
"choices": [{"text": "Generated text"}],
"metadata": {"other_info": "Some metadata"},
}
mock_model.create_completion.return_value = mock_output
generation_kwargs = {"max_tokens": 128}
result = generator.run("Test prompt", generation_kwargs)
assert result["replies"] == ["Generated text"]
assert result["meta"] == [mock_output]

@pytest.mark.integration
def test_run(self, generator):
"""
Test that a valid prompt returns a list of replies.
"""
questions_and_answers = [
("What's the capital of France?", "Paris"),
("What is the capital of Canada?", "Ottawa"),
("What is the capital of Ghana?", "Accra"),
]

for question, answer in questions_and_answers:
prompt = f"""Instruct: Answer in a single word. {question} \n Output:"""
result = generator.run(prompt)

assert "replies" in result
assert isinstance(result["replies"], list)
assert len(result["replies"]) > 0
assert answer.lower() in result["replies"][0].lower().strip()

@pytest.mark.integration
def test_run_rag_pipeline(self, generator):
"""
Test that the generator works end-to-end inside a RAG pipeline.
"""
prompt_template = """
Instruct: Given these documents, answer the question.\nDocuments:
{% for doc in documents %}
{{ doc.content }}
{% endfor %}
\nQuestion: {{question}}
\nOutput:
"""
rag_pipeline = Pipeline()
rag_pipeline.add_component(
instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore(), top_k=1), name="retriever"
)
rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder")
rag_pipeline.add_component(instance=generator, name="llm")
rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder")
rag_pipeline.connect("retriever", "prompt_builder.documents")
rag_pipeline.connect("prompt_builder", "llm")
rag_pipeline.connect("llm.replies", "answer_builder.replies")
rag_pipeline.connect("retriever", "answer_builder.documents")

# Populate the document store
documents = [
Document(content="My name is Jean and I live in Paris."),
Document(content="My name is Mark and I live in Berlin."),
Document(content="My name is Giorgio and I live in Rome."),
]
rag_pipeline.get_component("retriever").document_store.write_documents(documents)

# Query and assert
questions = ["Who lives in Paris?", "Who lives in Berlin?", "Who lives in Rome?"]
answers_spywords = ["Jean", "Mark", "Giorgio"]

for question, spyword in zip(questions, answers_spywords):
result = rag_pipeline.run(
{
"retriever": {"query": question},
"prompt_builder": {"question": question},
"llm": {"generation_kwargs": {"temperature": 0.1}},
"answer_builder": {"query": question},
}
)

assert len(result["answer_builder"]["answers"]) == 1
generated_answer = result["answer_builder"]["answers"][0]
assert spyword in generated_answer.data
assert generated_answer.query == question
assert hasattr(generated_answer, "documents")
assert hasattr(generated_answer, "meta")
