Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#47- Create API endpoints that return leaderboards #250

Merged
merged 10 commits into from
Jan 2, 2023
12 changes: 11 additions & 1 deletion backend/oasst_backend/api/v1/api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
from fastapi import APIRouter
from oasst_backend.api.v1 import frontend_messages, frontend_users, messages, stats, tasks, text_labels, users
from oasst_backend.api.v1 import (
frontend_messages,
frontend_users,
leaderboards,
messages,
stats,
tasks,
text_labels,
users,
)

api_router = APIRouter()
api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"])
Expand All @@ -9,3 +18,4 @@
api_router.include_router(users.router, prefix="/users", tags=["users"])
api_router.include_router(frontend_users.router, prefix="/frontend_users", tags=["frontend_users"])
api_router.include_router(stats.router, prefix="/stats", tags=["stats"])
api_router.include_router(leaderboards.router, prefix="/experimental/leaderboards", tags=["leaderboards"])
25 changes: 25 additions & 0 deletions backend/oasst_backend/api/v1/leaderboards.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from fastapi import APIRouter, Depends
from oasst_backend.api import deps
from oasst_backend.models import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from sqlmodel import Session

router = APIRouter()


@router.get("/create/assistant")
def get_assistant_leaderboard(
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
):
pr = PromptRepository(db, api_client, None)
return pr.get_user_leaderboard(role="assistant")


@router.get("/create/prompter")
def get_prompter_leaderboard(
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
):
pr = PromptRepository(db, api_client, None)
return pr.get_user_leaderboard(role="prompter")
23 changes: 22 additions & 1 deletion backend/oasst_backend/prompt_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from oasst_backend.models import ApiClient, Message, MessageReaction, Task, TextLabels, User
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import SystemStats
from oasst_shared.schemas.protocol import LeaderboardStats, SystemStats
from sqlalchemy import update
from sqlmodel import Session, func
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
Expand Down Expand Up @@ -705,3 +705,24 @@ def get_stats(self) -> SystemStats:
deleted=result.get(True, 0),
message_trees=result.get(None, 0),
)

def get_user_leaderboard(self, role: str) -> LeaderboardStats:
"""
Get leaderboard stats for Messages created,
separate leaderboard for prompts & assistants

"""
query = (
self.db.query(Message.user_id, User.username, User.display_name, func.count(Message.user_id))
.join(User, User.id == Message.user_id, isouter=True)
.filter(Message.deleted is not True, Message.role == role)
.group_by(Message.user_id, User.username, User.display_name)
.order_by(func.count(Message.user_id).desc())
)

result = [
{"ranking": i, "user_id": j[0], "username": j[1], "display_name": j[2], "score": j[3]}
for i, j in enumerate(query.all(), start=1)
]

return LeaderboardStats(leaderboard=result)
14 changes: 13 additions & 1 deletion oasst-shared/oasst_shared/schemas/protocol.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import enum
from datetime import datetime
from typing import Literal, Optional, Union
from typing import List, Literal, Optional, Union
from uuid import UUID, uuid4

import pydantic
Expand Down Expand Up @@ -281,3 +281,15 @@ class SystemStats(BaseModel):
active: int = 0
deleted: int = 0
message_trees: int = 0


class UserScore(BaseModel):
ranking: int
user_id: UUID
username: str
display_name: str
score: int


class LeaderboardStats(BaseModel):
leaderboard: List[UserScore]