Skip to content

Commit

Permalink
Merge pull request #76 from CambioML/remove-ngrok-chatbot
Browse files Browse the repository at this point in the history
remove unused dependencies
  • Loading branch information
jwilber committed Oct 6, 2023
2 parents f87f830 + 6617c4c commit a86867e
Showing 1 changed file with 72 additions and 133 deletions.
205 changes: 72 additions & 133 deletions pykoi/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
import asyncio
import os
import re
import socket
import subprocess
import threading
import time

from datetime import datetime
Expand All @@ -15,9 +13,7 @@
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from pyngrok import ngrok
from starlette.middleware.cors import CORSMiddleware
from pykoi.interactives.chatbot import Chatbot
from pykoi.telemetry.telemetry import Telemetry
from pykoi.telemetry.events import AppStartEvent, AppStopEvent
from pykoi.chat.db.constants import RAG_LIST_SEPARATOR
Expand All @@ -30,48 +26,60 @@ class UpdateQATable(BaseModel):
id: int
vote_status: str


class UpdateRAGTable(BaseModel):
id: int
vote_status: str


class UpdateQATableAnswer(BaseModel):
id: int
new_answer: str


class UpdateRAGTableAnswer(BaseModel):
id: int
new_answer: str


class RankingTableUpdate(BaseModel):
question: str
up_ranking_answer: str
low_ranking_answer: str


class InferenceRankingTable(BaseModel):
n: Optional[int] = 2


class ModelAnswer(BaseModel):
model: str
qid: int
rank: int
answer: str


class ComparatorInsertRequest(BaseModel):
data: List[ModelAnswer]


class RetrievalNewMessage(BaseModel):
prompt: str
file_names: List[str]


class QATableToCSV(BaseModel):
file_name: str


class RAGTableToCSV(BaseModel):
file_name: str


class ComparatorTableToCSV(BaseModel):
file_name: str


class UserInDB:
def __init__(self, username: str, hashed_password: str):
self.username = username
Expand All @@ -97,7 +105,7 @@ def __init__(
Initialize the Application.
Args:
share (bool, optional): If True, the application will be shared via ngrok. Defaults to False.
share (bool, optional): If True, the application will be shared via localhost.run. Defaults to False.
debug (bool, optional): If True, the application will run in debug mode. Defaults to False.
username (str, optional): The username for authentication. Defaults to None.
password (str, optional): The password for authentication. Defaults to None.
Expand Down Expand Up @@ -245,8 +253,16 @@ async def update_qa_table_response(
component["component"].database.update_answer(
request_body.id, request_body.new_answer
)
print("/chat/qa_table/update_answer", request_body.id, request_body.new_answer)
return {"log": "Table response updated", "new_answer": request_body.new_answer, "status": "200"}
print(
"/chat/qa_table/update_answer",
request_body.id,
request_body.new_answer,
)
return {
"log": "Table response updated",
"new_answer": request_body.new_answer,
"status": "200",
}
except Exception as ex:
return {"log": f"Table update failed: {ex}", "status": "500"}

Expand All @@ -257,7 +273,10 @@ async def save_qa_table_to_csv(
):
try:
component["component"].database.save_to_csv(request_body.file_name)
return {"log": f"Saved to {request_body.file_name}.csv", "status": "200"}
return {
"log": f"Saved to {request_body.file_name}.csv",
"status": "200",
}
except Exception as ex:
return {"log": f"Save to CSV failed: {ex}", "status": "500"}

Expand Down Expand Up @@ -339,8 +358,16 @@ async def update_rag_table_response(
component["component"].database.update_answer(
request_body.id, request_body.new_answer
)
print("/chat/rag_table/update_answer", request_body.id, request_body.new_answer)
return {"log": "Table response updated", "new_answer": request_body.new_answer, "status": "200"}
print(
"/chat/rag_table/update_answer",
request_body.id,
request_body.new_answer,
)
return {
"log": "Table response updated",
"new_answer": request_body.new_answer,
"status": "200",
}
except Exception as ex:
return {"log": f"Table update failed: {ex}", "status": "500"}

Expand All @@ -356,8 +383,14 @@ async def retrieve_rag_table(
row_list[5] = row_list[5].split(RAG_LIST_SEPARATOR)
row_list[6] = row_list[6].split(RAG_LIST_SEPARATOR)
row_list[7] = row_list[7].split(RAG_LIST_SEPARATOR)
modified_rows.append(row_list) # Append the modified list to the new list
return {"rows": modified_rows, "log": "RAG Table retrieved", "status": "200"}
modified_rows.append(
row_list
) # Append the modified list to the new list
return {
"rows": modified_rows,
"log": "RAG Table retrieved",
"status": "200",
}
except Exception as ex:
return {"log": f"Table retrieval failed: {ex}", "status": "500"}

Expand All @@ -368,7 +401,10 @@ async def save_rag_table_to_csv(
):
try:
component["component"].database.save_to_csv(request_body.file_name)
return {"log": f"Saved to {request_body.file_name}.csv", "status": "200"}
return {
"log": f"Saved to {request_body.file_name}.csv",
"status": "200",
}
except Exception as ex:
return {"log": f"Save to CSV failed: {ex}", "status": "500"}

Expand Down Expand Up @@ -464,7 +500,9 @@ async def retrieve_comparator(
user: Union[None, UserInDB] = Depends(self.get_auth_dependency())
):
try:
rows = component["component"].comparator_db.retrieve_all_question_answers()
rows = component[
"component"
].comparator_db.retrieve_all_question_answers()
data = []
for row in rows:
a_id, model_name, qid, question, answer, rank, _ = row
Expand Down Expand Up @@ -502,11 +540,13 @@ async def save_comparator_table_to_csv(
try:
print("Saving Comparator to CSV", request_body.file_name)
component["component"].comparator_db.save_to_csv(request_body.file_name)
return {"log": f"Saved to {request_body.file_name}.csv", "status": "200"}
return {
"log": f"Saved to {request_body.file_name}.csv",
"status": "200",
}
except Exception as ex:
return {"log": f"Save to CSV failed: {ex}", "status": "500"}


def create_qa_retrieval_route(self, app: FastAPI, component: Dict[str, Any]):
"""
Create QA retrieval routes for the application.
Expand Down Expand Up @@ -604,19 +644,31 @@ async def inference(
try:
print("[/retrieval]: model inference.....", request_body.prompt)
component["component"].retrieval_model.re_init(request_body.file_names)
output = component["component"].retrieval_model.run_with_return_source_documents({"query": request_body.prompt})
print('output', output, output["result"])
output = component[
"component"
].retrieval_model.run_with_return_source_documents(
{"query": request_body.prompt}
)
print("output", output, output["result"])
if output["source_documents"] == []:
source = ["N/A"]
source_content = ["N/A"]
else:
source = []
source_content = []
for source_document in output["source_documents"]:
source.append(source_document.metadata.get('file_name', 'No file name found'))
source.append(
source_document.metadata.get(
"file_name", "No file name found"
)
)
source_content.append(source_document.page_content)
id = component["component"].database.insert_question_answer(
request_body.prompt, output["result"], request_body.file_names, source, source_content
request_body.prompt,
output["result"],
request_body.file_names,
source,
source_content,
)
return {
"id": id,
Expand Down Expand Up @@ -783,7 +835,6 @@ async def read_item(
# debug mode should be set to False in production because
# it will start two processes when debug mode is enabled.

# Set the ngrok tunnel if share is True
start_event = AppStartEvent(
start_time=time.time(), date_time=datetime.utcfromtimestamp(time.time())
)
Expand Down Expand Up @@ -833,115 +884,3 @@ async def read_item(
duration=time.time() - start_event.start_time,
)
)

def display(self):
"""
Run the application.
"""
print("hey2")
import nest_asyncio

nest_asyncio.apply()
app = FastAPI()

app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)

@app.post("/token")
def login(credentials: HTTPBasicCredentials = Depends(oauth_scheme)):
user = self.authenticate_user(
self._fake_users_db, credentials.username, credentials.password
)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Basic"},
)
return {"message": "Logged in successfully"}

@app.get("/components")
async def get_components(
user: Union[None, UserInDB] = Depends(self.get_auth_dependency())
):
return JSONResponse(
[
{
"id": component["id"],
"svelte_component": component["svelte_component"],
"props": component["props"],
}
for component in self.components
]
)

def create_data_route(id: str, data_source: Any):
"""
Create data route for the application.
Args:
id (str): The id of the data source.
data_source (Any): The data source.
"""

@app.get(f"/data/{id}")
async def get_data(
user: Union[None, UserInDB] = Depends(self.get_auth_dependency())
):
data = data_source.fetch_func()
return JSONResponse(data)

for id, data_source in self.data_sources.items():
create_data_route(id, data_source)

for component in self.components:
if component["svelte_component"] == "Chatbot":
self.create_chatbot_route(app, component)
if component["svelte_component"] == "Feedback":
self.create_feedback_route(app, component)
if component["svelte_component"] == "Compare":
self.create_chatbot_comparator_route(app, component)

app.mount(
"/",
StaticFiles(
directory=os.path.join(
os.path.dirname(os.path.realpath(__file__)), "frontend/dist"
),
html=True,
),
name="static",
)

@app.get("/{path:path}")
async def read_item(
path: str, user: Union[None, UserInDB] = Depends(self.get_auth_dependency())
):
return {"path": path}

# debug mode should be set to False in production because
# it will start two processes when debug mode is enabled.

# Set the ngrok tunnel if share is True
if self._share:
public_url = ngrok.connect(self._port)
print("Public URL:", public_url)
import uvicorn

uvicorn.run(app, host=self._host, port=self._port)
print("Stopping server...")
ngrok.disconnect(public_url)
else:
import uvicorn

def run_uvicorn():
uvicorn.run(app, host=self._host, port=self._port)

t = threading.Thread(target=run_uvicorn)
t.start()
return Chatbot()(port=self._host)

0 comments on commit a86867e

Please sign in to comment.