reprebot/feat: #3 add hugging face model to llm client
- use `ChatHuggingFace` to instantiate a `Hugging Face` model
- use `HuggingFaceEndpoint` to connect to the `Hugging Face` model
- add `hugging-face` to `LLMClient`
- add `TestHuggingFaceLLMClient` unit test #20
- update `langchain` to version `0.1.12` #23
- install `transformers` dependency
- install `Jinja2` dependency
- add `HUGGINGFACEHUB_API_TOKEN` to `pipeline.yml`
este6an13 committed Mar 16, 2024
1 parent 8429150 commit b6e18cf
Showing 5 changed files with 53 additions and 14 deletions.
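
Taken together, the changes below let callers select the new backend the same way as the existing `openai` and `fake` options. A rough usage sketch based on the `LLMClient` interface exercised in the unit tests further down; the import name of `LLMClient` is an assumption taken from the repository layout (`src/llm_client/__init__.py`):

import os

# LLMClient lives in src/llm_client/__init__.py; the import name here assumes
# that directory is on the Python path.
from llm_client import LLMClient

# The Hugging Face backend authenticates via this environment variable,
# which the commit also adds to pipeline.yml for CI.
os.environ.setdefault("HUGGINGFACEHUB_API_TOKEN", "<your token>")

client = LLMClient(model_type="hugging-face")  # new model type added by this commit
answer = client.query(user_input="What does Reprebot do?")
print(answer)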
2 changes: 2 additions & 0 deletions README.md
@@ -47,8 +47,10 @@ There are various ways to contribute to **Reprebot**:
 Reprebot relies on the following dependencies:
 
 - `chromadb`
+- `Jinja2`
 - `langchain`
 - `langchain-openai`
 - `pytest`
+- `transformers`
 
 For a comprehensive list of dependencies, including both direct and transitive dependencies, please refer to the `requirements.txt` file.
6 changes: 4 additions & 2 deletions direct-requirements.txt
@@ -1,4 +1,6 @@
 chromadb==0.4.24
-langchain==0.1.11
+Jinja2==3.1.3
+langchain==0.1.12
 langchain-openai==0.0.8
-pytest==8.1.1
+pytest==8.1.1
+transformers==4.38.2
20 changes: 12 additions & 8 deletions requirements.txt
@@ -22,7 +22,7 @@ fastapi==0.110.0
 filelock==3.13.1
 flatbuffers==24.3.7
 frozenlist==1.4.1
-fsspec==2024.2.0
+fsspec==2024.3.0
 google-auth==2.28.2
 googleapis-common-protos==1.63.0
 greenlet==3.0.3
@@ -35,17 +35,19 @@ huggingface-hub==0.21.4
 humanfriendly==10.0
 idna==3.6
 importlib-metadata==6.11.0
-importlib_resources==6.1.3
+importlib_resources==6.3.0
 iniconfig==2.0.0
+Jinja2==3.1.3
 jsonpatch==1.33
 jsonpointer==2.4
 kubernetes==29.0.0
-langchain==0.1.11
-langchain-community==0.0.27
-langchain-core==0.1.30
+langchain==0.1.12
+langchain-community==0.0.28
+langchain-core==0.1.32
 langchain-openai==0.0.8
 langchain-text-splitters==0.0.1
-langsmith==0.1.23
+langsmith==0.1.27
+MarkupSafe==2.1.5
 marshmallow==3.21.1
 mmh3==4.1.0
 monotonic==1.6
@@ -55,7 +57,7 @@ mypy-extensions==1.0.0
 numpy==1.26.4
 oauthlib==3.2.2
 onnxruntime==1.17.1
-openai==1.13.3
+openai==1.14.1
 opentelemetry-api==1.23.0
 opentelemetry-exporter-otlp-proto-common==1.23.0
 opentelemetry-exporter-otlp-proto-grpc==1.23.0
@@ -88,6 +90,7 @@ regex==2023.12.25
 requests==2.31.0
 requests-oauthlib==1.4.0
 rsa==4.9
+safetensors==0.4.2
 six==1.16.0
 sniffio==1.3.1
 SQLAlchemy==2.0.28
@@ -97,6 +100,7 @@ tenacity==8.2.3
 tiktoken==0.6.0
 tokenizers==0.15.2
 tqdm==4.66.2
+transformers==4.38.2
 typer==0.9.0
 typing-inspect==0.9.0
 typing_extensions==4.10.0
@@ -107,4 +111,4 @@ websocket-client==1.7.0
 websockets==12.0
 wrapt==1.16.0
 yarl==1.9.4
-zipp==3.17.0
+zipp==3.18.1
5 changes: 2 additions & 3 deletions src/llm_client/__init__.py
@@ -7,6 +7,8 @@
 from langchain_community.vectorstores import Chroma
 from langchain_community.embeddings import FakeEmbeddings
 from langchain.chat_models.fake import FakeListChatModel
+from langchain_community.chat_models.huggingface import ChatHuggingFace
+from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
 
 class LLMClient:
     def __init__(self, model_type: str):
@@ -25,14 +27,11 @@ def setup_model(self, temperature=0):
             )
         elif self.model_type == "fake":
             model = FakeListChatModel(responses=["Hello",])
-        """
         elif self.model_type == "hugging-face":
             llm = HuggingFaceEndpoint(
                 repo_id="google/gemma-7b",
             )
             model = ChatHuggingFace(llm=llm)
-            # https://github.com/langchain-ai/langchain/issues/18639
-        """
         return model
 
     def setup_chain(self, retriever, prompt, model):
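
For context, the newly uncommented branch builds the chat model in two steps: `HuggingFaceEndpoint` connects to a hosted model (here `google/gemma-7b`), and `ChatHuggingFace` wraps that endpoint with a chat interface. A minimal standalone sketch of the same pattern, assuming a valid `HUGGINGFACEHUB_API_TOKEN` is available in the environment (as the commit provides for CI via `pipeline.yml`):

import os

from langchain_community.chat_models.huggingface import ChatHuggingFace
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint

# Assumes the token is already exported, e.g. in pipeline.yml or a local shell.
assert os.environ.get("HUGGINGFACEHUB_API_TOKEN"), "HUGGINGFACEHUB_API_TOKEN must be set"

# Connect to the hosted model, then wrap it with a chat interface.
llm = HuggingFaceEndpoint(repo_id="google/gemma-7b")
model = ChatHuggingFace(llm=llm)

print(model.invoke("Hello").content)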
34 changes: 33 additions & 1 deletion test/unit/src/test_llm_client.py
@@ -7,6 +7,7 @@
 from langchain_community.embeddings import FakeEmbeddings
 from langchain_core.runnables.base import RunnableSequence
 from langchain.chat_models.fake import FakeListChatModel
+from langchain_community.chat_models.huggingface import ChatHuggingFace
 
 
 class TestFakeLLMClient:
@@ -38,4 +39,35 @@ def test_setup_chain(self, llm_client):
     def test_query(self, llm_client):
         response = llm_client.query(user_input="")
         assert isinstance(response, str)
-        assert response == "Hello"
+        assert response == "Hello"
+
+
+class TestHuggingFaceLLMClient:
+    @pytest.fixture
+    def llm_client(self):
+        return LLMClient(model_type="hugging-face")
+
+    def test_init(self, llm_client):
+        assert llm_client.model_type == "hugging-face"
+
+    def test_setup_model(self, llm_client):
+        model = llm_client.setup_model()
+        assert isinstance(model, ChatHuggingFace)
+
+    def test_setup_chain(self, llm_client):
+        # Empty retriever for testing
+        retriever = Chroma.from_documents(
+            documents=[Document(page_content="")],
+            embedding=FakeEmbeddings(size=1),
+        ).as_retriever(search_kwargs={"k": 1})
+        # Empty prompt for testing
+        prompt = ChatPromptTemplate.from_messages([""])
+        # Hugging Face model used in this test
+        model = llm_client.setup_model()
+        # RAG chain
+        chain = llm_client.setup_chain(retriever=retriever, prompt=prompt, model=model)
+        assert isinstance(chain, RunnableSequence)
+
+    def test_query(self, llm_client):
+        response = llm_client.query(user_input="")
+        assert isinstance(response, str)
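
A note on running the new tests: `TestHuggingFaceLLMClient.test_setup_model` instantiates a real `ChatHuggingFace` over `HuggingFaceEndpoint`, so a valid `HUGGINGFACEHUB_API_TOKEN` must be available, as the commit arranges for CI in `pipeline.yml`. A hedged local sketch, assuming you keep the token under a different environment variable name of your own choosing:

import os
import pytest

# Hypothetical local setup: reuse a token exported under another name (HF_TOKEN is an example).
os.environ.setdefault("HUGGINGFACEHUB_API_TOKEN", os.environ.get("HF_TOKEN", ""))

# Run only the new Hugging Face tests from the repository root.
pytest.main(["test/unit/src/test_llm_client.py", "-k", "TestHuggingFaceLLMClient", "-v"])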