0. User config

In [None]:
from logger import get_logger
from models.settings import CommonsDep
from models.users import User

# Module-level logger named after this module; used by create_user below.
logger = get_logger(__name__)


def create_user(commons: CommonsDep, user: User, date):
    """Insert the first usage row for ``user`` on ``date`` into the ``users`` table.

    The new row starts with ``requests_count`` set to 1.

    Args:
        commons: Shared dependencies; only ``commons["supabase"]`` is used.
        user: The user whose ``id`` and ``email`` are stored in the row.
        date: Value written to the ``date`` column (its format is chosen by the
            caller).

    Returns:
        The result of the Supabase ``execute()`` call.
    """
    logger.info(f"New user entry in db document for user {user.email}")

    row = {
        # The User model declares ``id`` as ``uuid.UUID``. The insert payload is
        # JSON-encoded, and the stdlib JSON encoder cannot serialise UUID
        # objects, so send the canonical string form instead.
        "user_id": str(user.id),
        "email": user.email,
        "date": date,
        "requests_count": 1,
    }
    return commons["supabase"].table("users").insert(row).execute()

1. API key handler 

In [None]:
from datetime import datetime

from fastapi import HTTPException
from models.settings import common_dependencies
from models.users import User
from pydantic import DateError


def _parse_creation_date(raw: str):
    """Return the (UTC) calendar date of an ``api_keys.creation_time`` value.

    Accepts the plain ``YYYY-MM-DDTHH:MM:SS`` form first. It then falls back to
    ISO-8601 parsing, so fractional seconds and UTC offsets (``+00:00`` or a
    trailing ``Z``) are understood too.

    Raises:
        ValueError: if ``raw`` is not a recognisable timestamp.
        TypeError: if ``raw`` is not a string (e.g. None).
    """
    from datetime import timezone

    try:
        parsed = datetime.strptime(raw, "%Y-%m-%dT%H:%M:%S")
    except ValueError:
        # fromisoformat only accepts a trailing "Z" from Python 3.11 onwards.
        text = raw[:-1] + "+00:00" if raw.endswith("Z") else raw
        parsed = datetime.fromisoformat(text)
    if parsed.tzinfo is not None:
        # Normalise offset-aware timestamps to UTC so the month comparison uses
        # the same clock as the caller's UTC "today".
        parsed = parsed.astimezone(timezone.utc)
    return parsed.date()


async def verify_api_key(
    api_key: str,
) -> bool:
    """Return True if ``api_key`` is active and was created in the current UTC month.

    The key is looked up in the ``api_keys`` table with ``is_active`` = True.
    A key with a missing or unparseable ``creation_time`` is treated as invalid
    (False) rather than raising.
    """
    # Use UTC time to avoid timezone issues
    current_date = datetime.utcnow().date()
    commons = common_dependencies()
    result = (
        commons["supabase"]
        .table("api_keys")
        .select("api_key", "creation_time")
        .filter("api_key", "eq", api_key)
        .filter("is_active", "eq", True)
        .execute()
    )
    if not result.data:
        return False

    try:
        creation_date = _parse_creation_date(result.data[0]["creation_time"])
    except (KeyError, TypeError, ValueError):
        # strptime/fromisoformat raise ValueError (not pydantic's DateError) on
        # malformed input; any unusable timestamp means the key is not valid.
        return False

    # Check if the API key was created in the month of the current date
    return (creation_date.year, creation_date.month) == (
        current_date.year,
        current_date.month,
    )


async def get_user_from_api_key(
    api_key: str,
) -> User:
    """Resolve the ``User`` that owns ``api_key``.

    The key is mapped to a ``user_id`` through the ``api_keys`` table. The email
    is then read from the first matching ``users`` row; it is None when no such
    row exists.

    Raises:
        HTTPException: 400 if no ``api_keys`` row matches ``api_key``.
    """
    supabase = common_dependencies()["supabase"]

    # Step 1: api_key -> user_id.
    key_rows = (
        supabase.table("api_keys")
        .select("user_id")
        .filter("api_key", "eq", api_key)
        .execute()
        .data
    )
    if not key_rows:
        raise HTTPException(status_code=400, detail="Invalid API key.")
    owner_id = key_rows[0]["user_id"]

    # Step 2: user_id -> email.
    # TODO: remove this lookup and use user_id alone for credentials.
    email_rows = (
        supabase.table("users")
        .select("email")
        .filter("user_id", "eq", owner_id)
        .execute()
        .data
    )
    owner_email = email_rows[0]["email"] if email_rows else None

    return User(email=owner_email, id=owner_id)

2. Auth bearer (request authorization)

In [None]:
import os
from typing import Optional

from auth.api_key_handler import get_user_from_api_key, verify_api_key
from auth.jwt_token_handler import decode_access_token, verify_token
from fastapi import Depends, HTTPException, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from models.users import User


class AuthBearer(HTTPBearer):
    """FastAPI dependency that authenticates a request from its bearer token.

    A token is accepted if ``verify_token`` succeeds (it is then decoded into a
    ``User``). Failing that, it is accepted if ``verify_api_key`` succeeds (the
    user is then looked up by API key). Setting the ``AUTHENTICATE`` environment
    variable to ``"false"`` bypasses all checks and returns a fixed test user.
    """

    def __init__(self, auto_error: bool = True):
        super().__init__(auto_error=auto_error)

    async def __call__(
        self,
        request: Request,
    ):
        credentials: Optional[HTTPAuthorizationCredentials] = await super().__call__(
            request
        )
        self.check_scheme(credentials)
        token = credentials.credentials
        return await self.authenticate(
            token,
        )

    def check_scheme(self, credentials):
        """Validate the parsed ``Authorization`` header.

        Raises:
            HTTPException: 403 if no credentials were supplied, 401 if the scheme
                is not Bearer.
        """
        if not credentials:
            raise HTTPException(
                status_code=403, detail="Authentication credentials missing"
            )
        # Authentication schemes are case-insensitive (RFC 7235 section 2.1).
        # FastAPI's HTTPBearer also compares the scheme case-insensitively
        # (verify against the installed version), so "bearer <token>" must not be
        # rejected here.
        if credentials.scheme.lower() != "bearer":
            raise HTTPException(status_code=401, detail="Token must be Bearer")

    async def authenticate(
        self,
        token: str,
    ) -> User:
        """Return the ``User`` identified by ``token``.

        Raises:
            HTTPException: 401 if ``token`` is neither a valid JWT nor a valid
                API key.
        """
        if os.environ.get("AUTHENTICATE") == "false":
            return self.get_test_user()
        if verify_token(token):
            return decode_access_token(token)
        if await verify_api_key(token):
            return await get_user_from_api_key(token)
        raise HTTPException(status_code=401, detail="Invalid token or api key.")

    def get_test_user(self) -> User:
        """Return the placeholder user used when authentication is disabled."""
        return User(
            email="test@example.com", id="XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX"
        )  # replace with test user information


def get_current_user(user: User = Depends(AuthBearer())) -> User:
    """Return the user resolved by the ``AuthBearer`` dependency.

    Intended for route signatures, e.g. ``Depends(get_current_user)``. Requests
    that fail authentication are rejected by ``AuthBearer`` before this runs.
    """
    return user

3. JWT token handler 

In [None]:
import os
from datetime import datetime, timedelta
from typing import Optional

from jose import jwt
from jose.exceptions import JWTError
from models.users import User

# Symmetric key used to sign and verify JWTs; None when JWT_SECRET_KEY is unset.
SECRET_KEY = os.getenv("JWT_SECRET_KEY")
# Signing algorithm name passed to jose's jwt.encode / jwt.decode.
ALGORITHM = "HS256"


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Encode ``data`` as a signed JWT carrying an ``exp`` claim.

    Args:
        data: Claims to include. The dict is copied, not mutated.
        expires_delta: Token lifetime measured from now (UTC). When omitted
            (None), it defaults to 15 minutes. An explicit ``timedelta(0)`` is
            honoured rather than being replaced by the default.

    Returns:
        The encoded token, signed with ``SECRET_KEY`` using ``ALGORITHM``.
    """
    # Compare against None: a zero timedelta is falsy but is a deliberate value.
    if expires_delta is None:
        expires_delta = timedelta(minutes=15)
    to_encode = {**data, "exp": datetime.utcnow() + expires_delta}
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)


def decode_access_token(token: str) -> Optional[User]:
    """Decode ``token`` and return the ``User`` it identifies, or None.

    None is returned in two cases:

    - jose rejects the token (``JWTError``: bad signature, expired, malformed).
    - The verified payload cannot form a ``User``, e.g. a token signed with the
      same secret but with no (or a non-UUID) ``sub`` claim.

    Audience verification is disabled.
    """
    try:
        payload = jwt.decode(
            token, SECRET_KEY, algorithms=[ALGORITHM], options={"verify_aud": False}
        )
    except JWTError:
        return None

    user_id = payload.get("sub")
    if not user_id:
        # Without a subject there is no user to authenticate as.
        return None
    try:
        return User(email=payload.get("email"), id=user_id)
    except ValueError:
        # pydantic's ValidationError subclasses ValueError; an unusable ``sub``
        # means this is not a user token, so report "not valid" instead of a 500.
        return None


def verify_token(token: str):
    """Return True when ``decode_access_token`` accepts ``token``, otherwise False."""
    return decode_access_token(token) is not None

4. Users

In [None]:
from typing import Optional
from uuid import UUID

from logger import get_logger
from models.settings import common_dependencies
from pydantic import BaseModel

# Module-level logger named after this module; used by the User methods below.
logger = get_logger(__name__)


class User(BaseModel):
    """An authenticated user plus helpers for its rows in the ``users`` table.

    Rows in ``users`` are looked up by ``user_id`` together with ``date``, and
    each carries a ``requests_count`` for that date.
    """

    id: UUID
    email: Optional[str]
    user_openai_api_key: Optional[str] = None
    # In-memory copy of the count last written by increment_user_request_count.
    requests_count: int = 0

    # [TODO] Rename the user table and its references to 'user_usage'
    def create_user(self, date):
        """Insert this user's row for ``date`` with ``requests_count`` set to 1.

        Returns:
            The result of the Supabase ``execute()`` call.
        """
        commons = common_dependencies()
        logger.info(f"New user entry in db document for user {self.email}")

        return (
            commons["supabase"]
            .table("users")
            .insert(
                {
                    # ``id`` is a UUID object. The insert payload is JSON-encoded,
                    # and the stdlib encoder cannot serialise UUIDs, so send the
                    # string form.
                    "user_id": str(self.id),
                    "email": self.email,
                    "date": date,
                    "requests_count": 1,
                }
            )
            .execute()
        )

    def get_user_request_stats(self):
        """Return all ``users`` rows for this user (every date) as ``response.data``."""
        commons = common_dependencies()
        requests_stats = (
            commons["supabase"]
            .from_("users")
            .select("*")
            .filter("user_id", "eq", self.id)
            .execute()
        )
        return requests_stats.data

    def fetch_user_requests_count(self, date):
        """Return the stored ``requests_count`` for ``date``, or 0 if no row matches."""
        commons = common_dependencies()
        response = (
            commons["supabase"]
            .from_("users")
            .select("*")
            .filter("user_id", "eq", self.id)
            .filter("date", "eq", date)
            .execute()
        )
        # First matching row, or a zero-count stand-in when there is none.
        user_item = next(iter(response.data or []), {"requests_count": 0})

        return user_item["requests_count"]

    def increment_user_request_count(self, date):
        """Add one to the stored count for ``date`` and mirror it on ``self``.

        This only updates an existing row. If no row exists for ``date``, the
        update matches nothing, while ``self.requests_count`` still becomes 1.
        Presumably callers insert via ``create_user`` first (confirm with the
        caller). The read and the write are separate queries, so the increment
        is not atomic.
        """
        commons = common_dependencies()
        requests_count = self.fetch_user_requests_count(date) + 1
        logger.info(f"User {self.email} request count updated to {requests_count}")
        commons["supabase"].table("users").update(
            {"requests_count": requests_count}
        ).match({"user_id": self.id, "date": date}).execute()
        self.requests_count = requests_count

5. Pyright config

In [None]:
{
  "exclude": [
    "supabase"
  ]
}