2 changes: 1 addition & 1 deletion .github/workflows/build.yaml
@@ -19,5 +19,5 @@ jobs:
         working-directory: ./app
         run: |
           docker buildx build \
-            --platform linux/amd64,linux/arm64 \
+            --platform linux/amd64 \
             .
2 changes: 2 additions & 0 deletions app/.dockerignore
@@ -1,5 +1,7 @@
 myenv
 venv
+.direnv
+.envrc
 __pycache__
 sentence-transformers
 .tmp
8 changes: 4 additions & 4 deletions app/Dockerfile
@@ -12,10 +12,10 @@ RUN apt-get update && apt-get install -y \
     git \
     && rm -rf /var/lib/apt/lists/*
 
-# Install Go for ARM architecture (latest supported version 1.21)
-RUN curl -OL https://golang.org/dl/go1.21.1.linux-arm64.tar.gz && \
-    tar -C /usr/local -xzf go1.21.1.linux-arm64.tar.gz && \
-    rm go1.21.1.linux-arm64.tar.gz
+# Install Go for x86 architecture (latest supported version 1.21)
+RUN curl -OL https://golang.org/dl/go1.21.1.linux-amd64.tar.gz && \
+    tar -C /usr/local -xzf go1.21.1.linux-amd64.tar.gz && \
+    rm go1.21.1.linux-amd64.tar.gz
 
 # Set Go environment variables
 ENV PATH="/usr/local/go/bin:${PATH}"
20 changes: 7 additions & 13 deletions app/rag_system.py
@@ -30,17 +30,14 @@ def embed_knowledge_base(self):
     def normalize_query(self, query):
         return query.lower().strip()
 
-    def get_query_embedding(self, query, use_cpu=True):
+    def get_query_embedding(self, query):
         normalized_query = self.normalize_query(query)
         query_embedding = self.model.encode([normalized_query], convert_to_tensor=True)
-        if use_cpu:
-            query_embedding = query_embedding.cpu()
+        query_embedding = query_embedding.cpu()
         return query_embedding
 
-    def get_doc_embeddings(self, use_cpu=True):
-        if use_cpu:
-            return self.doc_embeddings.cpu()
-        return self.doc_embeddings
+    def get_doc_embeddings(self):
+        return self.doc_embeddings.cpu()
 
     def compute_document_scores(self, query_embedding, doc_embeddings, high_match_threshold):
         text_similarities = cosine_similarity(query_embedding, doc_embeddings)[0]
@@ -66,12 +63,9 @@ def compute_document_scores(self, query_embedding, doc_embeddings, high_match_threshold):
 
         return result
 
-    def retrieve(self, query, similarity_threshold=0.4, high_match_threshold=0.8, max_docs=5, use_cpu=True):
-        # Note: Set use_cpu=True to run on CPU, which is useful for testing or environments without a GPU.
-        # Set use_cpu=False to leverage GPU for better performance in production.
-
-        query_embedding = self.get_query_embedding(query, use_cpu)
-        doc_embeddings = self.get_doc_embeddings(use_cpu)
+    def retrieve(self, query, similarity_threshold=0.4, high_match_threshold=0.8, max_docs=5):
+        query_embedding = self.get_query_embedding(query)
+        doc_embeddings = self.get_doc_embeddings()
 
         doc_scores = self.compute_document_scores(query_embedding, doc_embeddings, high_match_threshold)
         retrieved_docs = self.get_top_docs(doc_scores, similarity_threshold, max_docs)
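With use_cpu removed, retrieve() now always moves the query and document embeddings to CPU before scoring. A minimal sketch of why that works, assuming scikit-learn's cosine_similarity (which converts its inputs to NumPy arrays) and a hypothetical 384-dimensional embedding size:

```python
import torch
from sklearn.metrics.pairwise import cosine_similarity

# Hypothetical stand-ins for the query/document embeddings; 384 dims is an
# assumption (typical for small sentence-transformers models).
query_embedding = torch.rand(1, 384)
doc_embeddings = torch.rand(10, 384)

# cosine_similarity converts its inputs to NumPy, which only works for CPU
# tensors; .cpu() is a no-op when the tensor is already on the CPU.
scores = cosine_similarity(query_embedding.cpu(), doc_embeddings.cpu())[0]
print(scores.shape)  # (10,)
```

A CUDA tensor would fail that NumPy conversion, and with the CPU-only torch wheel pinned in requirements.txt below, the .cpu() calls are effectively no-ops.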
3 changes: 2 additions & 1 deletion app/requirements.txt
@@ -5,7 +5,8 @@ scikit-learn==1.2.2
 segment-analytics-python==2.3.3
 numpy==1.24.4
 sentence-transformers==2.3.1
-torch==2.0.1
+--find-links https://download.pytorch.org/whl/cpu/torch_stable.html
+torch==2.0.1+cpu
 huggingface_hub==0.15.1
 openai==0.28.0
 PyYAML==6.0.2
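With the +cpu wheel pinned via the extra --find-links index, a quick sanity check (a sketch, not part of the PR) that the CPU-only build is the one actually installed:

```python
import torch

# The CPU-only wheel reports a "+cpu" local version and ships without CUDA support.
print(torch.__version__)          # expected: 2.0.1+cpu
print(torch.cuda.is_available())  # expected: False
```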
13 changes: 5 additions & 8 deletions app/test_rag_system.py
@@ -53,8 +53,7 @@ def test_get_doc_embeddings(self):
     def test_retrieve_fallback(self):
         # test a query that should return the fallback response
         query = "Hello"
-        # set use_cpu to True, as testing has no GPU calculations
-        result = self.rag_system.retrieve(query, use_cpu=True)
+        result = self.rag_system.retrieve(query)
         self.assertIsInstance(result, list)
         self.assertGreater(len(result), 0)
         self.assertEqual(len(result), 1) # should return one result for fallback
@@ -67,8 +66,7 @@ def test_retrieve_fallback(self):
     def test_retrieve_actual_response(self):
         # test a query that should return an actual response from the knowledge base
         query = "What is Defang?"
-        # set use_cpu to True, as testing has no GPU calculations
-        result = self.rag_system.retrieve(query, use_cpu=True)
+        result = self.rag_system.retrieve(query)
         self.assertIsInstance(result, list)
         self.assertGreater(len(result), 0)
         self.assertLessEqual(len(result), 5) # should return up to max_docs (5)
@@ -80,9 +78,8 @@ def test_retrieve_actual_response(self):
 
     def test_compute_document_scores(self):
         query = "Does Defang have an MCP sample?"
-        # get embeddings and move them to CPU, as testing has no GPU calculations
-        query_embedding = self.rag_system.get_query_embedding(query, use_cpu=True)
-        doc_embeddings = self.rag_system.get_doc_embeddings(use_cpu=True)
+        query_embedding = self.rag_system.get_query_embedding(query)
+        doc_embeddings = self.rag_system.get_doc_embeddings()
 
         # call function and get results
         result = self.rag_system.compute_document_scores(query_embedding, doc_embeddings, high_match_threshold=0.8)
@@ -105,4 +102,4 @@ def test_compute_document_scores(self):
         print("Test for compute_document_scores passed successfully!")
 
 if __name__ == '__main__':
-    unittest.main()
+    unittest.main()
3 changes: 2 additions & 1 deletion compose.yaml
@@ -3,9 +3,10 @@ services:
     restart: always
     domainname: ask.defang.io
     x-defang-dns-role: arn:aws:iam::258338292852:role/dnsadmin-39a19c3
+    platform: linux/amd64
     build:
       context: ./app
-      shm_size: "30gb"
+      dockerfile: Dockerfile
     ports:
       - target: 5050
         published: 5050