Skip to content

Commit

Permalink
Merge pull request #9 from arena-ai/8-add-the-ability-to-download-eve…
Browse files Browse the repository at this point in the history
…nts-in-a-file

8 add the ability to download events in a file
  • Loading branch information
ngrislain committed May 22, 2024
2 parents d925e76 + f24a589 commit c8f0726
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 12 deletions.
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.2.0
0.2.1
61 changes: 58 additions & 3 deletions backend/app/api/routes/events.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from typing import Any
from typing import Any, Literal

from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, HTTPException, Response
from sqlmodel import func, select, desc
from sqlalchemy.orm import aliased
from sqlalchemy.sql.functions import coalesce
import pyarrow as pa
import pyarrow.parquet as pq
import pyarrow.csv as pc

from app.api.deps import CurrentUser, SessionDep
from app import crud
Expand All @@ -18,7 +22,6 @@ def read_events(
"""
Retrieve Events.
"""

if current_user.is_superuser:
statement = select(func.count()).select_from(Event)
count = session.exec(statement).one()
Expand Down Expand Up @@ -199,3 +202,55 @@ def delete_event_attribute(session: SessionDep, current_user: CurrentUser, id: i
session.delete(event_attribute)
session.commit()
return Message(message="Event attribute deleted successfully")


@router.get("/download/{format}")
def download_events(
session: SessionDep, current_user: CurrentUser, format: Literal["arrow", "csv"], skip: int = 0, limit: int = 1000000
) -> Any:
"""
Retrieve Events.
"""
if current_user.is_superuser:
request = select(Event).where(Event.name == "request").offset(skip).limit(limit).cte()
else:
request = select(Event).where(Event.name == "request").where(Event.owner_id == current_user.id).offset(skip).limit(limit).cte()
modified_request = select(Event).where(Event.name == "modified_request").cte()
response = select(Event).where(Event.name == "response").cte()
user_evaluation = select(Event).where(Event.name == "user_evaluation").cte()
lm_judge_evaluation = select(Event).where(Event.name == "lm_judge_evaluation").cte()
lm_config = select(Event).where(Event.name == "lm_config").cte()
statement = (
select(
request.c.id,
request.c.timestamp,
request.c.owner_id,
request.c.content.label("request"),
modified_request.c.content.label("modified_request"),
response.c.content.label("response"),
user_evaluation.c.content.label("user_evaluation"),
lm_judge_evaluation.c.content.label("lm_judge_evaluation"),
lm_config.c.content.label("lm_config"),
)
.outerjoin(modified_request, request.c.id == modified_request.c.parent_id)
.outerjoin(response, request.c.id == response.c.parent_id)
.outerjoin(user_evaluation, request.c.id == user_evaluation.c.parent_id)
.outerjoin(lm_judge_evaluation, request.c.id == lm_judge_evaluation.c.parent_id)
.outerjoin(lm_config, request.c.id == lm_config.c.parent_id)
)
# Execute the query
result = session.exec(statement)
events = result.all()
# Arrange them in a Table
table = pa.Table.from_pylist([dict(zip(result.keys(), event)) for event in events])
# Write table to a parquet format in memory
buf = pa.BufferOutputStream()
match format:
case "arrow":
pq.write_table(table, buf)
case "csv":
pc.write_csv(table, buf)
# Get the buffer value
buf = buf.getvalue().to_pybytes()
# Return a file as the response
return Response(content=buf, media_type='application/octet-stream')
13 changes: 8 additions & 5 deletions backend/app/lm/api/routes/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from app.services import Request, Response
from app.ops import cst, tup, Computation
from app.ops.settings import openai_api_key, mistral_api_key, anthropic_api_key, language_models_api_keys, lm_config
from app.ops.events import log_request, LogRequest, log_response, create_event_identifier, log_lm_judge_evaluation
from app.ops.events import log_request, LogRequest, log_response, create_event_identifier, log_lm_judge_evaluation, log_lm_config
from app.ops.lm import openai, openai_request, mistral, mistral_request, anthropic, anthropic_request, chat, chat_request, judge
from app.ops.masking import masking, replace_masking
from app.ops.session import session, user, event
Expand Down Expand Up @@ -54,10 +54,12 @@ def lm_response(self, ses: Computation[Session], usr: Computation[UserOut], requ
async def process_request(self) -> Resp:
ses = session()
usr = user(ses, self.user.id)
# We need the config now
config = await self.config(ses, usr).evaluate(session=self.session)
# Arena request
arena_request = self.arena_request()
arena_request_event = log_request(ses, usr, None, arena_request)
# We need the config now
config = await self.config(ses, usr).evaluate(session=self.session)
config_event = log_lm_config(ses, usr, arena_request_event, config)
# Build the request
lm_request = await self.lm_request().evaluate(session=self.session)
lm_request_event = arena_request_event
Expand All @@ -83,7 +85,7 @@ async def set_content():
chat_completion_response = lm_response.content
event_identifier = create_event_identifier(ses, usr, arena_request_event, chat_completion_response.id)
# Evaluate before post-processing
arena_request_event, lm_request_event, lm_response_event, event_identifier, chat_completion_response = await tup(arena_request_event, lm_request_event, lm_response_event, event_identifier, chat_completion_response).evaluate(session=self.session)
arena_request_event, config_event, lm_request_event, lm_response_event, event_identifier, chat_completion_response = await tup(arena_request_event, config_event, lm_request_event, lm_response_event, event_identifier, chat_completion_response).evaluate(session=self.session)
# post-process the (request, response) pair
if config.judge_evaluation:
judge_score = judge(
Expand Down Expand Up @@ -227,11 +229,12 @@ async def chat_completion_response(
usr = user(ses, current_user.id)
request_event = event(ses, chat_completion_request_event_response.request_event_id)
config = await lm_config(ses, usr).evaluate(session=session_dep)
config_event = log_lm_config(ses, usr, request_event, config)
lm_response = Response(status_code=200, headers={}, content=chat_completion_request_event_response.response)
lm_response_event = log_response(ses, usr, request_event, lm_response)
event_identifier = create_event_identifier(ses, usr, request_event, chat_completion_request_event_response.response.id)
# Evaluate before post-processing
lm_response_event, event_identifier = await tup(lm_response_event, event_identifier).evaluate(session=session_dep)
config_event, lm_response_event, event_identifier = await tup(config_event, lm_response_event, event_identifier).evaluate(session=session_dep)
# post-process the (request, response) pair
if config.judge_evaluation:
judge_score = judge(language_models_api_keys(ses, usr), chat_completion_request_event_response.request, chat_completion_request_event_response.response)
Expand Down
9 changes: 7 additions & 2 deletions backend/app/ops/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from app.ops import Op
import app.crud as crud
from app.models import EventCreate, EventOut, User, EventIdentifier
from app.lm.models import Score
from app.lm.models import Score, LMConfig
from app.services import Request, Response


Expand Down Expand Up @@ -52,4 +52,9 @@ class LogLMJudgeEvaluation(LogEvent[Score]):
class LogUserEvaluation(LogEvent[Score]):
name: str = "user_evaluation"

log_user_evaluation = LogUserEvaluation()
log_user_evaluation = LogUserEvaluation()

class LogLMConfig(LogEvent[LMConfig]):
name: str = "lm_config"

log_lm_config = LogLMConfig()
1 change: 1 addition & 0 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ anthropic = "^0.25"
mistralai = "^0.1"
anyio = {extras = ["trio"], version = "^4.3.0"}
faker = "^25.0.0"
pyarrow = "^16.1.0"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.3"
Expand Down
2 changes: 1 addition & 1 deletion client/examples/pii_removal/what.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def chat_completion_request(self) -> dict[str, Any]:
)

print("\n[bold blue]Activate masking")
arena.lm_config(lm_config=LMConfig(pii_removal="masking", judge_evaluation=True, judge_with_pii=True))
arena.lm_config(lm_config=LMConfig(pii_removal="masking", judge_evaluation=True, judge_with_pii=False))

print("\n[bold blue]Run experiments with masking")
for i in range(20):
Expand Down

0 comments on commit c8f0726

Please sign in to comment.