Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into project-BasicRAG
Browse files Browse the repository at this point in the history
  • Loading branch information
sanchitvj committed Mar 27, 2024
2 parents cecce62 + 3025d59 commit e3a3edb
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 9 deletions.
11 changes: 10 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
# GRAG
<h1 align="center">GRAG</h1>

[![License: AGPL v3](https://img.shields.io/badge/License-AGPL_v3-blue.svg)](https://www.gnu.org/licenses/agpl-3.0)
![Static Badge](https://img.shields.io/badge/docstring%20style-google-pink?labelColor=white)
![Static Badge](https://img.shields.io/badge/linter-ruff-yellow?labelColor=white)
![Docs](https://img.shields.io/github/actions/workflow/status/arjbingly/Capstone_5/ruff_linting.yml)
![Static Badge](https://img.shields.io/badge/buildstyle-hatchling-purple?labelColor=white)
![Static Badge](https://img.shields.io/badge/codestyle-pyflake-purple?labelColor=white)
![GitHub Issues or Pull Requests](https://img.shields.io/github/issues-pr/arjbingly/Capstone_5)


## Project Overview

Expand Down
1 change: 1 addition & 0 deletions cookbook/Basic-RAG/BasicRAG_stuff.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

client = DeepLakeClient(collection_name="test")
retriever = Retriever(vectordb=client)

rag = BasicRAG(doc_chain="stuff", retriever=retriever)

if __name__ == "__main__":
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,6 @@ docstring-code-format = true

[tool.ruff.lint.pydocstyle]
convention = "google"

[tool.mypy]
ignore_missing_imports = true
2 changes: 1 addition & 1 deletion src/config.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[llm]
model_name : Llama-2-7b-chat
model_name : Llama-2-13b-chat
# meta-llama/Llama-2-70b-chat-hf Mixtral-8x7B-Instruct-v0.1
quantization : Q5_K_M
pipeline : llama_cpp
Expand Down
2 changes: 1 addition & 1 deletion src/grag/components/multivec_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
store_path: str = multivec_retriever_conf["store_path"],
id_key: str = multivec_retriever_conf["id_key"],
namespace: str = multivec_retriever_conf["namespace"],
top_k=1,
top_k=int(multivec_retriever_conf["top_k"]),
client_kwargs: Optional[Dict[str, Any]] = None,
):
"""Initialize the Retriever.
Expand Down
16 changes: 10 additions & 6 deletions src/tests/rag/basic_rag_test.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
from typing import Text, List
from typing import List, Text

from grag.components.multivec_retriever import Retriever
from grag.components.vectordb.deeplake_client import DeepLakeClient
from grag.rag.basic_rag import BasicRAG

client = DeepLakeClient(collection_name="test")
retriever = Retriever(vectordb=client)


def test_rag_stuff():
rag = BasicRAG(doc_chain="stuff")
response, sources = rag("What is simulated annealing?")
rag = BasicRAG(doc_chain="stuff", retriever=retriever)
response, sources = rag("What is Flash Attention?")
assert isinstance(response, Text)
assert isinstance(sources, List)
assert all(isinstance(s, str) for s in sources)
del rag.llm


def test_rag_refine():
rag = BasicRAG(doc_chain="refine")
response, sources = rag("What is simulated annealing?")
# assert isinstance(response, Text)
rag = BasicRAG(doc_chain="refine", retriever=retriever)
response, sources = rag("What is Flash Attention?")
assert isinstance(response, List)
assert all(isinstance(s, str) for s in response)
assert isinstance(sources, List)
Expand Down

0 comments on commit e3a3edb

Please sign in to comment.