Feature/add log filtration (#342)
* filter logs

* cleanup
emrgnt-cmplxty committed Apr 30, 2024
1 parent 3d56b38 commit 806b2c2
Showing 6 changed files with 107 additions and 43 deletions.
16 changes: 10 additions & 6 deletions r2r/client/base.py
@@ -190,14 +190,18 @@ def filtered_deletion(self, key: str, value: Union[bool, int, str]):
         response = requests.delete(url, params={"key": key, "value": value})
         return response.json()
 
-    def get_logs(self):
-        url = f"{self.base_url}/logs"
-        response = requests.get(url)
+    def get_logs(self, pipeline_type=None):
+        params = {}
+        if pipeline_type:
+            params["pipeline_type"] = pipeline_type
+        response = requests.get(f"{self.base_url}/logs", params=params)
         return response.json()
 
-    def get_logs_summary(self):
-        url = f"{self.base_url}/logs_summary"
-        response = requests.get(url)
+    def get_logs_summary(self, pipeline_type=None):
+        params = {}
+        if pipeline_type:
+            params["pipeline_type"] = pipeline_type
+        response = requests.get(f"{self.base_url}/logs_summary", params=params)
         return response.json()
 
     def get_user_ids(self):
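
With this change, callers of the Python client can scope both log endpoints to a single run type. A minimal usage sketch; the R2RClient class name, import path, and constructor arguments are assumptions and may differ from the project's actual client export, while the get_logs/get_logs_summary signatures match the diff above:

from r2r.client import R2RClient  # assumed import path and class name; adjust to the actual client export

client = R2RClient(base_url="http://localhost:8000")  # assumed constructor

# Fetch only logs recorded by the "rag" pipeline; omitting pipeline_type returns all logs.
rag_logs = client.get_logs(pipeline_type="rag")

# The same optional filter applies to the aggregated summary endpoint.
rag_summary = client.get_logs_summary(pipeline_type="rag")

Per the Redis provider change below, pipeline_type should be one of the run types listed in RUN_TYPES (ingestion, embedding, search, rag, evaluation, scraper).
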
89 changes: 64 additions & 25 deletions r2r/core/utils/logging.py
@@ -5,9 +5,20 @@
 import types
 from abc import ABC, abstractmethod
 from datetime import datetime
+from typing import Optional
 
 logger = logging.getLogger(__name__)
 
+# TODO - Move Run Types to a global enum
+RUN_TYPES = [
+    "ingestion",
+    "embedding",
+    "search",
+    "rag",
+    "evaluation",
+    "scraper",
+]
+
 
 class LoggingProvider(ABC):
     @abstractmethod
@@ -27,7 +38,9 @@ def log(
         pass
 
     @abstractmethod
-    def get_logs(self, max_logs: int) -> list:
+    def get_logs(
+        self, max_logs: int, pipeline_run_type: Optional[str] = None
+    ) -> list:
         pass
 
 
@@ -114,22 +127,22 @@ def log(
                 f"Error occurred while logging to the PostgreSQL database: {str(e)}"
             )
 
-    def get_logs(self, max_logs: int) -> list:
+    def get_logs(self, max_logs: int, pipeline_run_type=None) -> list:
         logs = []
-        with self.db_module.connect(
-            dbname=os.getenv("POSTGRES_DBNAME"),
-            user=os.getenv("POSTGRES_USER"),
-            password=os.getenv("POSTGRES_PASSWORD"),
-            host=os.getenv("POSTGRES_HOST"),
-            port=os.getenv("POSTGRES_PORT"),
-        ) as conn:
-            with conn.cursor() as cur:
+        with self.conn.cursor() as cur:
+            if pipeline_run_type:
+                cur.execute(
+                    f"SELECT * FROM {self.collection_name} WHERE pipeline_run_type = %s ORDER BY timestamp DESC LIMIT %s",
+                    (pipeline_run_type, max_logs),
+                )
+            else:
                 cur.execute(
                     f"SELECT * FROM {self.collection_name} ORDER BY timestamp DESC LIMIT %s",
                     (max_logs,),
                 )
-                colnames = [desc[0] for desc in cur.description]
-                logs = [dict(zip(colnames, row)) for row in cur.fetchall()]
+            colnames = [desc[0] for desc in cur.description]
+            logs = [dict(zip(colnames, row)) for row in cur.fetchall()]
+        self.conn.commit()
         return logs
 
 
@@ -203,13 +216,20 @@ def log(
                 f"Error occurred while logging to the local database: {str(e)}"
             )
 
-    def get_logs(self, max_logs: int) -> list:
+    def get_logs(self, max_logs: int, pipeline_run_type=None) -> list:
         logs = []
         with self.db_module.connect(self.logging_path) as conn:
-            cur = conn.execute(
-                f"SELECT * FROM {self.collection_name} ORDER BY timestamp DESC LIMIT ?",
-                (max_logs,),
-            )
+            cur = conn.cursor()
+            if pipeline_run_type:
+                cur.execute(
+                    f"SELECT * FROM {self.collection_name} WHERE pipeline_run_type = ? ORDER BY timestamp DESC LIMIT ?",
+                    (pipeline_run_type, max_logs),
+                )
+            else:
+                cur.execute(
+                    f"SELECT * FROM {self.collection_name} ORDER BY timestamp DESC LIMIT ?",
+                    (max_logs,),
+                )
             colnames = [desc[0] for desc in cur.description]
             results = cur.fetchall()
             logs = [dict(zip(colnames, row)) for row in results]
@@ -254,7 +274,6 @@ def log(
         log_level,
     ):
         timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-
         log_entry = {
             "timestamp": timestamp,
             "pipeline_run_id": str(pipeline_run_id),
@@ -264,14 +283,32 @@ def log(
             "log_level": log_level,
         }
         try:
-            self.redis.lpush(self.log_key, json.dumps(log_entry))
+            # Save log entry under a key that includes the pipeline_run_type
+            type_specific_key = f"{self.log_key}:{pipeline_run_type}"
+            self.redis.lpush(type_specific_key, json.dumps(log_entry))
         except Exception as e:
             # Handle any exceptions that occur during the logging process
             logger.error(f"Error occurred while logging to Redis: {str(e)}")
 
-    def get_logs(self, max_logs: int) -> list:
-        logs = self.redis.lrange(self.log_key, 0, max_logs - 1)
-        return [json.loads(log) for log in logs]
+    def get_logs(self, max_logs: int, pipeline_run_type=None) -> list:
+        if pipeline_run_type:
+            if pipeline_run_type not in RUN_TYPES:
+                raise ValueError(
+                    f"Error, `{pipeline_run_type}` is not in LoggingDatabaseConnection's list of supported run types."
+                )
+            # Fetch logs for a specific type
+            key_to_fetch = f"{self.log_key}:{pipeline_run_type}"
+            logs = self.redis.lrange(key_to_fetch, 0, max_logs - 1)
+            return [json.loads(log) for log in logs]
+        else:
+            # Fetch logs for all types
+            all_logs = []
+            for run_type in RUN_TYPES:
+                key_to_fetch = f"{self.log_key}:{run_type}"
+                logs = self.redis.lrange(key_to_fetch, 0, max_logs - 1)
+                all_logs.extend([json.loads(log) for log in logs])
+            # Sort logs by timestamp if needed and slice to max_logs
+            all_logs.sort(key=lambda x: x["timestamp"], reverse=True)
+            return all_logs[:max_logs]
 
 
 class LoggingDatabaseConnection:
@@ -312,8 +349,10 @@ def __enter__(self):
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.logging_provider.close()
 
-    def get_logs(self, max_logs: int) -> list:
-        return self.logging_provider.get_logs(max_logs)
+    def get_logs(
+        self, max_logs: int, pipeline_run_type: Optional[str]
+    ) -> list:
+        return self.logging_provider.get_logs(max_logs, pipeline_run_type)
 
 
 def log_execution_to_db(func):
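
A design note on the Redis provider above: instead of filtering one global list, each entry is now pushed onto a per-run-type list keyed as f"{log_key}:{pipeline_run_type}", so a filtered read is a single LRANGE while an unfiltered read merges the per-type lists and re-sorts by timestamp. One consequence worth noting is that entries written under the old single log_key before this change would no longer be returned by these reads. A small self-contained sketch of the merge step, with plain dicts standing in for entries already fetched from Redis (all names here are illustrative, not the project's API):

import json

RUN_TYPES = ["ingestion", "embedding", "search", "rag", "evaluation", "scraper"]
max_logs = 100

# Pretend these lists came back from redis.lrange(f"{log_key}:{run_type}", 0, max_logs - 1).
fetched = {
    "rag": [json.dumps({"timestamp": "2024-04-30 12:02:00", "pipeline_run_type": "rag"})],
    "search": [json.dumps({"timestamp": "2024-04-30 12:01:00", "pipeline_run_type": "search"})],
}

all_logs = []
for run_type in RUN_TYPES:
    all_logs.extend(json.loads(entry) for entry in fetched.get(run_type, []))

# Newest first, capped at max_logs, mirroring the unfiltered branch of get_logs above.
all_logs.sort(key=lambda entry: entry["timestamp"], reverse=True)
print(all_logs[:max_logs])
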
4 changes: 4 additions & 0 deletions r2r/examples/clients/run_test_client.py
@@ -133,3 +133,7 @@ async def stream_rag_completion():
 print("Fetching logs summary after all steps...")
 logs_summary_response = client.get_logs_summary()
 print(f"Logs summary response:\n{logs_summary_response}\n")
+
+print("Fetching 'rag' logs after all steps...")
+rag_logs_response = client.get_logs(pipeline_type="rag")
+print(f"'Rag' Logs response:\n{rag_logs_response}\n")
33 changes: 22 additions & 11 deletions r2r/main/app.py
@@ -7,6 +7,7 @@
 import requests
 from fastapi import (
     BackgroundTasks,
+    Depends,
     FastAPI,
     File,
     Form,
@@ -38,6 +39,7 @@
     AddEntriesRequest,
     AddEntryRequest,
     EvalPayloadModel,
+    LogFilterModel,
     LogModel,
     RAGMessageModel,
     SettingsModel,
@@ -353,7 +355,11 @@ async def _stream_rag_completion(
             url = "http://localhost:8000"
         else:
             url = str(url).split("/rag_completion")[0]
-        if "localhost" not in url and "127.0.0.1" not in url:
+        if (
+            "localhost" not in url
+            and "127.0.0.1" not in url
+            and "0.0.0.0" not in url
+        ):
             url = url.replace("http://", "https://")
 
         # Pass the payload to the /eval endpoint
@@ -386,7 +392,7 @@ async def _stream_rag_completion(
 
     @app.post("/eval")
     async def eval(payload: EvalPayloadModel):
-        # try:
+        try:
             logging.info(
                 f"Received evaluation payload: {payload.dict(exclude_none=True)}"
             )
@@ -401,11 +407,12 @@ async def eval(payload: EvalPayloadModel):
             )
 
             return {"message": "Evaluation completed successfully."}
-        # except Exception as e:
-        #     logging.error(
-        #         f":eval_endpoint: [Error](payload={payload}, error={str(e)})"
-        #     )
-        #     raise HTTPException(status_code=500, detail=str(e))
+
+        except Exception as e:
+            logging.error(
+                f":eval_endpoint: [Error](payload={payload}, error={str(e)})"
+            )
+            raise HTTPException(status_code=500, detail=str(e))
 
     @app.delete("/filtered_deletion/")
     async def filtered_deletion(key: str, value: Union[bool, int, str]):
@@ -448,13 +455,15 @@ async def get_user_documents(user_id: str):
             raise HTTPException(status_code=500, detail=str(e))
 
     @app.get("/logs")
-    async def logs():
+    async def logs(filter: LogFilterModel = Depends()):
         try:
             if logging_connection is None:
                 raise HTTPException(
                     status_code=404, detail="Logging provider not found."
                 )
-            logs = logging_connection.get_logs(config.app.get("max_logs", 100))
+            logs = logging_connection.get_logs(
+                config.app.get("max_logs", 100), filter.pipeline_type
+            )
             for log in logs:
                 LogModel(**log).dict(by_alias=True)
             return {
@@ -465,13 +474,15 @@ async def logs():
             raise HTTPException(status_code=500, detail=str(e))
 
     @app.get("/logs_summary")
-    async def logs_summary():
+    async def logs_summary(filter: LogFilterModel = Depends()):
         try:
             if logging_connection is None:
                 raise HTTPException(
                     status_code=404, detail="Logging provider not found."
                 )
-            logs = logging_connection.get_logs(config.app.get("max_logs", 100))
+            logs = logging_connection.get_logs(
+                config.app.get("max_logs", 100), filter.pipeline_type
+            )
             logs_summary = process_logs(logs)
             events_summary = [
                 SummaryLogModel(**log).dict(by_alias=True)
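
On the server side the filter arrives as an ordinary query parameter, so the endpoints can also be exercised directly over HTTP. A minimal sketch using the requests library, assuming the server is running locally on port 8000 as in the example client; omit the params argument to fetch logs for every run type:

import requests

# Only "rag" pipeline logs.
response = requests.get(
    "http://localhost:8000/logs",
    params={"pipeline_type": "rag"},
)
print(response.json())

# The summary endpoint accepts the same filter.
summary = requests.get(
    "http://localhost:8000/logs_summary",
    params={"pipeline_type": "rag"},
)
print(summary.json())
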
4 changes: 4 additions & 0 deletions r2r/main/models.py
@@ -118,3 +118,7 @@ class SummaryLogModel(BaseModel):
     class Config:
         alias_generator = to_camel
         populate_by_name = True
+
+
+class LogFilterModel(BaseModel):
+    pipeline_type: Optional[str] = None
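
LogFilterModel is consumed in app.py via filter: LogFilterModel = Depends(), which lets FastAPI expose each model field as an optional query parameter. A minimal standalone sketch of that pattern, not the project's actual app, to show how the query string maps onto the model:

from typing import Optional

from fastapi import Depends, FastAPI
from pydantic import BaseModel

app = FastAPI()


class LogFilterModel(BaseModel):
    pipeline_type: Optional[str] = None


@app.get("/logs")
async def logs(filter: LogFilterModel = Depends()):
    # GET /logs?pipeline_type=rag populates filter.pipeline_type == "rag";
    # a bare GET /logs leaves it as None, matching the "no filter" behavior above.
    return {"pipeline_type": filter.pipeline_type}
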
4 changes: 3 additions & 1 deletion r2r/pipelines/core/eval.py
@@ -16,7 +16,9 @@ def __init__(
         *args,
         **kwargs,
     ):
-        super().__init__(eval_provider, logging_connection=logging_connection, **kwargs)
+        super().__init__(
+            eval_provider, logging_connection=logging_connection, **kwargs
+        )
 
     @log_execution_to_db
     def evaluate(self, query: str, context: str, completion: str) -> Any:
