From c15a1852054df80574f6f603aa69d666fe4f46ae Mon Sep 17 00:00:00 2001 From: Sudeep Pillai Date: Wed, 22 May 2024 23:54:57 -0700 Subject: [PATCH] Move to `pytorch==2.3.0`, add auth bearer token support --- docker/agibuild.cu121.yaml | 2 +- nos/server/http/_security.py | 28 +++++++++++++++++++++------- nos/version.py | 2 +- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/docker/agibuild.cu121.yaml b/docker/agibuild.cu121.yaml index b7cded62..a3d98900 100644 --- a/docker/agibuild.cu121.yaml +++ b/docker/agibuild.cu121.yaml @@ -12,7 +12,7 @@ images: - git conda: - accelerate - - pytorch + - pytorch=2.3.0 - torchvision - pytorch-cuda=12.1 - -c pytorch -c nvidia diff --git a/nos/server/http/_security.py b/nos/server/http/_security.py index 491f9214..4fb97e77 100644 --- a/nos/server/http/_security.py +++ b/nos/server/http/_security.py @@ -1,7 +1,8 @@ import os +from typing import Optional from fastapi import Depends, HTTPException, Request, status -from fastapi.security import APIKeyHeader +from fastapi.security import APIKeyHeader, HTTPBearer from loguru import logger @@ -10,24 +11,37 @@ if len(key) > 0: logger.debug(f"Adding valid_m2m_keys [key={key}]") valid_m2m_keys[key] = key + api_key_header = APIKeyHeader(name="X-Api-Key", auto_error=False) +api_key_bearer = HTTPBearer(auto_error=False) -async def validate_m2m_key(request: Request, api_key: str = Depends(api_key_header)) -> bool: - logger.debug(f"validate_m2m_key [api_key={api_key}]") +async def get_api_key(request: Request) -> Optional[str]: + api_key: Optional[str] = None + api_key_header_value = request.headers.get("X-Api-Key") + if api_key_header_value: + api_key = api_key_header_value + else: + authorization: Optional[str] = request.headers.get("Authorization") + if authorization: + scheme, credentials = authorization.split() + if scheme.lower() == "bearer": + api_key = credentials + return api_key + +async def validate_m2m_key(request: Request, api_key: Optional[str] = Depends(get_api_key)) -> bool: + logger.debug(f"validate_m2m_key [api_key={api_key}]") if not api_key: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Missing X-Api-Key Key header", + detail="Missing authentication token (use `X-Api-Key` or `Authorization: Bearer `)", ) - if api_key not in valid_m2m_keys: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="Invalid Machine-to-Machine Key", + detail="Invalid authentication token (use `X-Api-Key` or `Authorization: Bearer `)", ) - assert isinstance(api_key, str) return True diff --git a/nos/version.py b/nos/version.py index 493f7415..6a9beea8 100644 --- a/nos/version.py +++ b/nos/version.py @@ -1 +1 @@ -__version__ = "0.3.0" +__version__ = "0.4.0"