Skip to content

Commit

Permalink
feat: Add bunnet dependency and update openai version
Browse files Browse the repository at this point in the history
The code changes include adding the `bunnet` dependency to the `requirements.txt` file and updating the `openai` dependency to version 1.40.2. These changes are necessary to support the new functionality introduced in the code.

Recent user commits:
- popup on upload
- multi file upload
- Add a celery task queue to backend (#82)
- Update README.md
- Merge pull request #81 from Watts-Lab/markwhiting-patch-1
- Update README.md
- assign docker command to task definition
- clean up docker compose for development env
- Login window (#76)
- Update README.md
- upload button

Recent repository commits:
- popup on upload
- multi file upload
- Add a celery task queue to backend (#82)
- Update README.md
- Merge pull request #81 from Watts-Lab/markwhiting-patch-1
- Update README.md
- assign docker command to task definition
- clean up docker compose for development env
- Login window (#76)
- Update README.md
- upload button
  • Loading branch information
amirrr committed Aug 10, 2024
1 parent fb930ec commit 3c41d64
Show file tree
Hide file tree
Showing 10 changed files with 371 additions and 142 deletions.
57 changes: 53 additions & 4 deletions server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,64 @@

import argparse
import os
from celery import Celery
from celery.result import EagerResult
from dotenv import load_dotenv
import jwt
from sanic import Sanic, json as json_response
from sanic.request import Request
from sanic.worker.manager import WorkerManager
from sanic_cors import CORS
from sanic_jwt import initialize, exceptions
from config.app_config import AppConfig
from controllers.login import login_user, validate_user
from database.database import init_db
from database.models.users import User
import socketio

from celery_worker import get_paper_info, run_assistant, celery
from celery.result import EagerResult
from celery_worker import get_paper_info, run_assistant


load_dotenv()

WorkerManager.THRESHOLD = 600


async def authenticate(request: Request, *args, **kwargs):
"""
Authenticate the user.
"""
token = request.credentials.token
decoded_token = jwt.decode(token, app.config.JWT_SECRET, algorithms=["HS256"])
email = decoded_token.get("email")
user = User.find_one(User.email == email).run()
if user:
return user.to_dict()
else:
raise exceptions.AuthenticationFailed("User not found.")


async def retrieve_user(request: Request, payload, *args, **kwargs):
"""
Retrieve the user.
"""
if payload:
user_id = payload.get("user_id", None)
user = User.find_one(id=user_id).run()
return user.to_dict()
else:
return None


# Initialize the Sanic app
app = Sanic("Atlas", config=AppConfig())
sanicjwt = initialize(
app,
authenticate=authenticate,
retrieve_user=retrieve_user,
url_prefix="/api/auth",
secret=app.config.JWT_SECRET,
)


# Initialize CORS
CORS(app, resources={r"/*": {"origins": "*"}})
Expand All @@ -45,7 +83,18 @@ async def attach_db(_app, _loop):
"""
Initialize the database connection.
"""
await init_db()
init_db()


@app.route("/api/protected", methods=["GET"])
@sanicjwt.protected()
async def protected_route(request: Request):
"""
A protected route.
"""
print("userrrrr", request.app)
print("request to api/protedced aaaaaa", request)
return json_response({"protected": True})


@app.route("/api/login", methods=["POST"])
Expand Down
148 changes: 142 additions & 6 deletions server/celery_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,48 @@
import os
import sys
from celery import Celery, Task
from celery.signals import worker_process_init
from dotenv import load_dotenv
import requests
import socketio
import boto3

from controllers.assisstant import run_assistant_api
from database.database import init_db
from database.models.papers import Papers
from database.models.results import Result
from database.models.users import User
from utils.assistant_retriever import Assistant


load_dotenv()

sys.path.append(os.getcwd())

celery = Celery(__name__)
celery.conf.broker_url = os.getenv("CELERY_BROKER_URL")
celery.conf.result_backend = os.getenv("CELERY_RESULT_BACKEND")
celery = Celery(
__name__,
broker=os.getenv("CELERY_BROKER_URL"),
backend=os.getenv("CELERY_RESULT_BACKEND"),
)


AWS_S3_BUCKET = os.getenv("AWS_S3_BUCKET")
AWS_S3_KEY = os.getenv("AWS_S3_KEY")
AWS_S3_SECRET = os.getenv("AWS_S3_SECRET")

# Global variable to hold the initialized database
DB_INITIALIZED = False


@worker_process_init.connect
def init_celery_worker(**kwargs):
"""
Initialize the Beanie/MongoDB connection for each worker process.
"""
global DB_INITIALIZED
if not DB_INITIALIZED:
init_db()
DB_INITIALIZED = True


@celery.task(bind=True, name="get_paper_info")
Expand All @@ -29,6 +57,7 @@ def get_paper_info(paper_path: str):
external_sio = socketio.RedisManager(
os.getenv("CELERY_BROKER_URL"), write_only=True
)

print("Task created.", paper_path)

external_sio.emit("task_created", {"task": "Task created."})
Expand All @@ -43,11 +72,29 @@ def get_paper_info(paper_path: str):
return paper_info


def save_paper_info(paper_info: dict) -> Papers:
"""
Save the paper info to the database.
"""
user = User.find_one(User.email == "amirhossein.nakhaei@rwth-aachen.de").run()

new_paper = Papers(
user=user,
title=paper_info["title"],
run_ids=[paper_info["run_id"]],
truth_ids=[],
s3_url=paper_info["s3_url"],
)

return new_paper.create()


@celery.task(bind=True, name="run_assistant")
def run_assistant(self: Task, paper_path: str, socket_id: str):
"""
Task to create a task.
"""

external_sio = socketio.RedisManager(
os.getenv("CELERY_BROKER_URL"), write_only=True
)
Expand All @@ -57,7 +104,7 @@ def run_assistant(self: Task, paper_path: str, socket_id: str):
external_sio.emit(
"status",
{
"status": "Updating assistant...",
"status": "Starting assistant...",
"progress": 0,
"task_id": task_id,
"done": False,
Expand All @@ -66,10 +113,63 @@ def run_assistant(self: Task, paper_path: str, socket_id: str):
namespace="/home",
)

res = run_assistant_api(
s3 = boto3.client(
"s3",
aws_access_key_id=AWS_S3_KEY,
aws_secret_access_key=AWS_S3_SECRET,
)

try:
s3.upload_file(
paper_path,
AWS_S3_BUCKET,
f"amirhossein.nakhaei@rwth-aachen.de/{paper_path}",
ExtraArgs=None,
Callback=None,
Config=None,
)

res = {
"title": paper_path.replace("paper/", "").replace(f"{socket_id}-", ""),
"run_id": task_id,
"s3_url": f"https://{AWS_S3_BUCKET}.s3.amazonaws.com/amirhossein.nakhaei@rwth-aachen.de/{paper_path}",
}

# Save the paper info to the database
new_paper = save_paper_info(res)

print(f"File uploaded to S3: {res}")
except Exception as e:
print(f"Error uploading file to S3: {e}")

open_ai_res = run_assistant_api(
file_path=paper_path, sid=socket_id, sio=external_sio, task_id=task_id
)

result_obj = Result(
user=User.find_one(User.email == "amirhossein.nakhaei@rwth-aachen.de").run(),
json_response=open_ai_res["output"]["result"],
prompt_token=open_ai_res["output"]["prompt_token"],
completion_token=open_ai_res["output"]["completion_token"],
quality=0.9,
feature_list=default_experiments_features,
run_id=task_id,
)

external_sio.emit(
"status",
{
"status": "Saving results to database...",
"progress": 90,
"task_id": task_id,
"done": False,
},
to=socket_id,
namespace="/home",
)

result_obj.create()

external_sio.emit(
"status",
{
Expand All @@ -82,7 +182,11 @@ def run_assistant(self: Task, paper_path: str, socket_id: str):
namespace="/home",
)

return res
return {
"message": "Success",
"file_name": open_ai_res["file_name"],
"experiments": open_ai_res["output"]["result"]["experiments"],
}


def search_paper_by_title(title, filtering=None, sort=None, order=None):
Expand Down Expand Up @@ -111,3 +215,35 @@ def search_paper_by_title(title, filtering=None, sort=None, order=None):
print("No paper found with the given title.")
else:
print("Error occurred while searching for the paper.")


default_experiments_features = [
"experiments.name",
"experiments.description",
"experiments.participant_source",
"experiments.participant_source_category",
"experiments.units_randomized",
"experiments.units_analyzed",
"experiments.sample_size_randomized",
"experiments.sample_size_analyzed",
"experiments.sample_size_notes",
"experiments.adults",
"experiments.age_mean",
"experiments.age_sd",
"experiments.female_perc",
"experiments.male_perc",
"experiments.gender_other",
"experiments.language",
"experiments.language_secondary",
"experiments.compensation",
"experiments.demographics_conditions",
"experiments.population_other",
"experiments.conditions.name",
"experiments.conditions.description",
"experiments.conditions.type",
"experiments.conditions.message",
"experiments.conditions.behaviors.name",
"experiments.conditions.behaviors.description",
"experiments.conditions.behaviors.priority",
"experiments.conditions.behaviors.focal",
]
11 changes: 7 additions & 4 deletions server/controllers/assisstant.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""

import os
from sanic import json as json_response
import socketio

from gpt_assistant import AssistantException, call_asssistant_api
Expand All @@ -29,9 +28,9 @@ def run_assistant_api(
file_name = file_path.replace("paper/", "").replace(f"{sid}-", "")

response_data = {
"message": "File successfully uploaded",
"message": "Success",
"file_name": file_name,
"experiments": result["experiments"],
"output": result,
}

if os.path.isfile(file_path):
Expand All @@ -43,5 +42,9 @@ def run_assistant_api(

return response_data
except AssistantException as e:
response_data = {"error": str(e)}
response_data = {
"message": "Failed",
"file_name": "failed",
"output": str(e),
}
return response_data
10 changes: 5 additions & 5 deletions server/controllers/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async def login_user(email: str):
return json_response(body=response_data, status=400)

# Fetch user asynchronously
user = await User.find_one(User.email == email)
user = User.find_one(User.email == email).run()
expiration_time = datetime.now(UTC) + timedelta(hours=1)

if user:
Expand All @@ -31,7 +31,7 @@ async def login_user(email: str):
user.magic_link_expired = False
user.magic_link_expiration_date = expiration_time
user.updated_at = datetime.now(UTC)
await user.save() # Save the User asynchronously
user.save() # Save the User asynchronously

# Send magic link via email
send_magic_link(email=email, token=magic_link)
Expand All @@ -48,7 +48,7 @@ async def login_user(email: str):
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
await new_user.create() # Create the User asynchronously
new_user.create() # Create the User asynchronously

# Send magic link via email
send_magic_link(email=email, token=new_user.magic_link)
Expand Down Expand Up @@ -76,7 +76,7 @@ async def validate_user(email: str, token: str):
return json_response(body=response_data, status=400)

# Fetch user asynchronously
user = await User.find_one(User.email == email)
user = User.find_one(User.email == email).run()

if user:
if user.magic_link == token and not user.magic_link_expired:
Expand All @@ -87,7 +87,7 @@ async def validate_user(email: str, token: str):
# user.magic_link_expired = True
user.updated_at = datetime.now(UTC)
user.magic_link_expired = True
await user.save() # Save the User asynchronously
user.save() # Save the User asynchronously

response_data = {"message": "Magic link validated."}
header = {
Expand Down
Loading

0 comments on commit 3c41d64

Please sign in to comment.