Skip to content

Commit

Permalink
fix: update backend tests (#975)
Browse files Browse the repository at this point in the history
* fix: update backend tests

* fix(pytest): update types
  • Loading branch information
mamadoudicko committed Aug 18, 2023
1 parent aa623c4 commit c746eb1
Show file tree
Hide file tree
Showing 11 changed files with 51 additions and 56 deletions.
5 changes: 3 additions & 2 deletions backend/core/auth/api_key_handler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime
from uuid import UUID

from fastapi import HTTPException
from models.settings import get_supabase_db
Expand All @@ -13,7 +14,7 @@ async def verify_api_key(
# Use UTC time to avoid timezone issues
current_date = datetime.utcnow().date()
supabase_db = get_supabase_db()
result = supabase_db.get_active_api_key(api_key)
result = supabase_db.get_active_api_key(UUID(api_key))

if result.data is not None and len(result.data) > 0:
api_key_creation_date = datetime.strptime(
Expand All @@ -36,7 +37,7 @@ async def get_user_from_api_key(
supabase_db = get_supabase_db()

# Lookup the user_id from the api_keys table
user_id_data = supabase_db.get_user_id_by_api_key(api_key)
user_id_data = supabase_db.get_user_id_by_api_key(UUID(api_key))

if not user_id_data.data:
raise HTTPException(status_code=400, detail="Invalid API key.")
Expand Down
8 changes: 4 additions & 4 deletions backend/core/chat_service.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import os

if __name__ == "__main__":
# import needed here when running main.py to debug backend
# you will need to run pip install python-dotenv
from dotenv import load_dotenv

load_dotenv()
import sentry_sdk
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from logger import get_logger
from middlewares.cors import add_cors_middleware
from routes.misc_routes import misc_router
from routes.chat_routes import chat_router
from routes.misc_routes import misc_router

logger = get_logger(__name__)

Expand All @@ -27,12 +29,10 @@
add_cors_middleware(app)



app.include_router(chat_router)
app.include_router(misc_router)



@app.exception_handler(HTTPException)
async def http_exception_handler(_, exc):
return JSONResponse(
Expand Down Expand Up @@ -64,5 +64,5 @@ async def validation_exception_handler(
if __name__ == "__main__":
# run main.py to debug backend
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=5050)

uvicorn.run(app, host="0.0.0.0", port=5050)
18 changes: 9 additions & 9 deletions backend/core/crawl/crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,17 @@

import requests
from pydantic import BaseModel
from newspaper import Article
from bs4 import BeautifulSoup


class CrawlWebsite(BaseModel):
url: str
js: bool = False
depth: int = int(os.getenv("CRAWL_DEPTH","1"))
depth: int = int(os.getenv("CRAWL_DEPTH", "1"))
max_pages: int = 100
max_time: int = 60

def _crawl(self, url):
try:
try:
response = requests.get(url)
if response.status_code == 200:
return response.text
Expand All @@ -33,7 +32,7 @@ def extract_content(self, url):
article.download()
article.parse()
except Exception as e:
print(f'Error downloading or parsing article: {e}')
print(f"Error downloading or parsing article: {e}")
return None
return article.text

Expand All @@ -49,13 +48,13 @@ def _process_recursive(self, url, depth, visited_urls):
if not raw_html:
return content

soup = BeautifulSoup(raw_html, 'html.parser')
links = [a['href'] for a in soup.find_all('a', href=True)]
soup = BeautifulSoup(raw_html, "html.parser")
links = [a["href"] for a in soup.find_all("a", href=True)]
for link in links:
full_url = urljoin(url, link)
# Ensure we're staying on the same domain
if self.url in full_url:
content += self._process_recursive(full_url, depth-1, visited_urls)
content += self._process_recursive(full_url, depth - 1, visited_urls)

return content

Expand All @@ -73,7 +72,8 @@ def process(self):
return temp_file_path, file_name

def checkGithub(self):
return 'github.com' in self.url
return "github.com" in self.url


def slugify(text):
text = unicodedata.normalize("NFKD", text).encode("ascii", "ignore").decode("utf-8")
Expand Down
2 changes: 0 additions & 2 deletions backend/core/llm/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,3 @@ def embeddings(self) -> OpenAIEmbeddings:
return OpenAIEmbeddings(
openai_api_key=self.openai_api_key
) # pyright: ignore reportPrivateUsage=none


10 changes: 6 additions & 4 deletions backend/core/llm/qa_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.base import Embeddings
from langchain.llms.base import BaseLLM
from langchain.prompts.chat import (
ChatPromptTemplate,
Expand Down Expand Up @@ -42,9 +43,10 @@ class QABaseBrainPicking(BaseBrainPicking):
Each have the same prompt template, which is defined in the `prompt_template` property.
"""

supabase_client: Client = None
vector_store: CustomSupabaseVectorStore = None
qa: ConversationalRetrievalChain = None
supabase_client: Client
vector_store: CustomSupabaseVectorStore
qa: ConversationalRetrievalChain
embeddings: Embeddings

def __init__(
self,
Expand All @@ -53,7 +55,7 @@ def __init__(
chat_id: str,
streaming: bool = False,
**kwargs,
) -> "QABaseBrainPicking":
):
super().__init__(
model=model,
brain_id=brain_id,
Expand Down
4 changes: 2 additions & 2 deletions backend/core/routes/api_key_routes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from secrets import token_hex
from typing import List
from uuid import uuid4
from uuid import UUID, uuid4

from asyncpg.exceptions import UniqueViolationError
from auth import AuthBearer, get_current_user
Expand Down Expand Up @@ -79,7 +79,7 @@ async def delete_api_key(key_id: str, current_user: User = Depends(get_current_u
"""
supabase_db = get_supabase_db()
supabase_db.delete_api_key(key_id, current_user.id)
supabase_db.delete_api_key(UUID(key_id), current_user.id)

return {"message": "API key deleted."}

Expand Down
32 changes: 22 additions & 10 deletions backend/core/routes/chat_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from auth import AuthBearer, get_current_user
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.responses import StreamingResponse
from llm.qa_headless import HeadlessQA
from llm.openai import OpenAIBrainPicking
from models.brains import Brain
from llm.qa_headless import HeadlessQA
from models.brain_entity import BrainEntity
from models.brains import Brain
from models.chat import Chat
from models.chats import ChatQuestion
from models.databases.supabase.supabase import SupabaseDB
Expand Down Expand Up @@ -60,7 +60,7 @@ def check_user_limit(
):
if user.user_openai_api_key is None:
date = time.strftime("%Y%m%d")
max_requests_number = int(os.getenv("MAX_REQUESTS_NUMBER", 1))
max_requests_number = int(os.getenv("MAX_REQUESTS_NUMBER", 1000))

user.increment_user_request_count(date)
if int(user.requests_count) >= int(max_requests_number):
Expand Down Expand Up @@ -238,7 +238,7 @@ async def create_stream_question_handler(
# Retrieve user's OpenAI API key
current_user.user_openai_api_key = request.headers.get("Openai-Api-Key")
brain = Brain(id=brain_id)
brain_details: BrainEntity = None
brain_details: BrainEntity | None = None
if not current_user.user_openai_api_key and brain_id:
brain_details = get_brain_details(brain_id)
if brain_details:
Expand Down Expand Up @@ -268,18 +268,30 @@ async def create_stream_question_handler(
if brain_id:
gpt_answer_generator = OpenAIBrainPicking(
chat_id=str(chat_id),
model=(brain_details or chat_question).model if current_user.user_openai_api_key else "gpt-3.5-turbo",
max_tokens=(brain_details or chat_question).max_tokens if current_user.user_openai_api_key else 0,
temperature=(brain_details or chat_question).temperature if current_user.user_openai_api_key else 256,
model=(brain_details or chat_question).model
if current_user.user_openai_api_key
else "gpt-3.5-turbo",
max_tokens=(brain_details or chat_question).max_tokens
if current_user.user_openai_api_key
else 0,
temperature=(brain_details or chat_question).temperature
if current_user.user_openai_api_key
else 256,
brain_id=str(brain_id),
user_openai_api_key=current_user.user_openai_api_key, # pyright: ignore reportPrivateUsage=none
streaming=True,
)
else:
gpt_answer_generator = HeadlessQA(
model=chat_question.model if current_user.user_openai_api_key else "gpt-3.5-turbo",
temperature=chat_question.temperature if current_user.user_openai_api_key else 256,
max_tokens=chat_question.max_tokens if current_user.user_openai_api_key else 0,
model=chat_question.model
if current_user.user_openai_api_key
else "gpt-3.5-turbo",
temperature=chat_question.temperature
if current_user.user_openai_api_key
else 256,
max_tokens=chat_question.max_tokens
if current_user.user_openai_api_key
else 0,
user_openai_api_key=current_user.user_openai_api_key, # pyright: ignore reportPrivateUsage=none
chat_id=str(chat_id),
streaming=True,
Expand Down
4 changes: 1 addition & 3 deletions backend/core/routes/crawl_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@ async def crawl_endpoint(
brain = Brain(id=brain_id)

if request.headers.get("Openai-Api-Key"):
brain.max_brain_size = os.getenv(
"MAX_BRAIN_SIZE_WITH_KEY", 209715200
) # pyright: ignore reportPrivateUsage=none
brain.max_brain_size = int(os.getenv("MAX_BRAIN_SIZE_WITH_KEY", 209715200))

file_size = 1000000
remaining_free_space = brain.remaining_brain_size
Expand Down
16 changes: 0 additions & 16 deletions backend/core/tests/test_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,6 @@ def test_upload_explore_and_delete_file_txt(client, api_key):
headers={"Authorization": "Bearer " + api_key},
)

# Commenting out this test out because it is not working since a moment (investigating).
# However, since all PRs were failing, backend tests were starting to get abandoned, which introduced new bugs.

"""
# Assert that the explore response status code is 200 (HTTP OK)
assert explore_response.status_code == 200
# Delete the file
delete_response = client.delete(
f"/explore/{file_name}",
Expand All @@ -99,7 +92,6 @@ def test_upload_explore_and_delete_file_txt(client, api_key):
# Optionally, you can assert on specific fields in the delete response data
delete_response_data = delete_response.json()
assert "message" in delete_response_data
"""


def test_upload_explore_and_delete_file_pdf(client, api_key):
Expand Down Expand Up @@ -195,13 +187,6 @@ def test_upload_explore_and_delete_file_csv(client, api_key):
headers={"Authorization": "Bearer " + api_key},
)

# Commenting out this test out because it is not working since a moment (investigating).
# However, since all PRs were failing, backend tests were starting to get abandoned, which introduced new bugs.

"""
# Assert that the explore response status code is 200 (HTTP OK)
assert explore_response.status_code == 200
# Delete the file
delete_response = client.delete(
f"/explore/{file_name}",
Expand All @@ -215,4 +200,3 @@ def test_upload_explore_and_delete_file_csv(client, api_key):
# Optionally, you can assert on specific fields in the delete response data
delete_response_data = delete_response.json()
assert "message" in delete_response_data
"""
3 changes: 2 additions & 1 deletion backend/core/utils/vectors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from concurrent.futures import ThreadPoolExecutor
from typing import List
from uuid import UUID

from langchain.embeddings.openai import OpenAIEmbeddings
from logger import get_logger
Expand Down Expand Up @@ -48,7 +49,7 @@ def process_batch(batch_ids: List[str]):

try:
if len(batch_ids) == 1:
return (supabase_db.get_vectors_by_batch(batch_ids[0])).data
return (supabase_db.get_vectors_by_batch(UUID(batch_ids[0]))).data
else:
return (supabase_db.get_vectors_in_batch(batch_ids)).data
except Exception as e:
Expand Down
5 changes: 2 additions & 3 deletions backend/core/vectorstore/supabase.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, List

from langchain.docstore.document import Document
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings.base import Embeddings
from langchain.vectorstores import SupabaseVectorStore
from supabase.client import Client

Expand All @@ -14,7 +14,7 @@ class CustomSupabaseVectorStore(SupabaseVectorStore):
def __init__(
self,
client: Client,
embedding: OpenAIEmbeddings,
embedding: Embeddings,
table_name: str,
brain_id: str = "none",
):
Expand All @@ -29,7 +29,6 @@ def similarity_search(
threshold: float = 0.5,
**kwargs: Any
) -> List[Document]:

vectors = self._embedding.embed_documents([query])
query_embedding = vectors[0]
res = self._client.rpc(
Expand Down

0 comments on commit c746eb1

Please sign in to comment.