Skip to content

Commit

Permalink
Move to pytorch==2.3.0, add auth bearer token support (#552)
Browse files Browse the repository at this point in the history
<!-- Thank you for your contribution! Please review
https://github.com/autonomi-ai/nos/blob/main/docs/CONTRIBUTING.md before
opening a pull request. -->

<!-- Please add a reviewer to the assignee section when you create a PR.
If you don't have access to it, we will shortly find a reviewer and
assign them to your PR. -->

## Summary

<!-- Please give a short summary of the change and the problem this
solves. -->

## Related issues

<!-- For example: "Closes #1234" -->

## Checks

- [ ] `make lint`: I've run `make lint` to lint the changes in this PR.
- [ ] `make test`: I've made sure the tests (`make test-cpu` or `make
test`) are passing.
- Additional tests:
   - [ ] Benchmark tests (when contributing new models)
   - [ ] GPU/HW tests
  • Loading branch information
spillai authored Jun 1, 2024
1 parent e9a4c2c commit 55411c4
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 15 deletions.
2 changes: 1 addition & 1 deletion docker/agibuild.cpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ images:
- ./scripts/entrypoint.sh:/app/entrypoint.sh
- ./requirements/requirements.server.txt:/tmp/requirements.server.txt
run:
- mamba install pytorch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 cpuonly -c pytorch
- mamba install pytorch=2.3.0 torchvision torchaudio cpuonly -c pytorch
- pip config set global.extra-index-url https://download.pytorch.org/whl/cpu
- pip install -r /tmp/requirements.server.txt && rm -rf /tmp/requirements.server.txt
- mamba install -y -c conda-forge x264=='1!161.3030' ffmpeg=4.3.2
Expand Down
2 changes: 1 addition & 1 deletion docker/agibuild.cu121.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ images:
- git
conda:
- accelerate
- pytorch
- pytorch=2.3.0
- torchvision
- pytorch-cuda=12.1
- -c pytorch -c nvidia
Expand Down
8 changes: 4 additions & 4 deletions docker/agibuild.gpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ images:
- build-essential
- git
conda:
- accelerate>0.18.0
- pytorch==2.1.1
- torchaudio==2.1.1
- torchvision==0.16.1
- accelerate
- pytorch=2.3.0
- torchaudio
- torchvision
- pytorch-cuda=11.8
- cudatoolkit=11.8
- -c pytorch -c nvidia
Expand Down
28 changes: 21 additions & 7 deletions nos/server/http/_security.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
from typing import Optional

from fastapi import Depends, HTTPException, Request, status
from fastapi.security import APIKeyHeader
from fastapi.security import APIKeyHeader, HTTPBearer
from loguru import logger


Expand All @@ -10,24 +11,37 @@
if len(key) > 0:
logger.debug(f"Adding valid_m2m_keys [key={key}]")
valid_m2m_keys[key] = key

# Security scheme objects for the two supported credential locations: the
# `X-Api-Key` header and the HTTP `Authorization: Bearer` scheme. Both are
# created with `auto_error=False`.
# NOTE(review): in the code visible here, `get_api_key` reads the raw request
# headers and does not reference either object -- confirm whether they are
# still needed (e.g. to advertise the schemes in the OpenAPI docs).
api_key_header = APIKeyHeader(name="X-Api-Key", auto_error=False)
api_key_bearer = HTTPBearer(auto_error=False)


async def validate_m2m_key(request: Request, api_key: str = Depends(api_key_header)) -> bool:
logger.debug(f"validate_m2m_key [api_key={api_key}]")
async def get_api_key(request: Request) -> Optional[str]:
    """Extract the caller's API key from the request headers.

    The `X-Api-Key` header takes precedence. When it is absent or empty, the
    `Authorization` header is consulted and its credentials are used only if
    the header has the exact form `Bearer <token>` (scheme matched
    case-insensitively).

    Args:
        request: Incoming request whose headers are inspected.

    Returns:
        The API key / bearer token, or None when no usable credential is
        present (including a malformed or non-bearer `Authorization` header).
    """
    api_key_header_value = request.headers.get("X-Api-Key")
    if api_key_header_value:
        return api_key_header_value

    authorization: Optional[str] = request.headers.get("Authorization")
    if not authorization:
        return None

    # The header is client-controlled and may contain any number of
    # whitespace-separated parts (e.g. "Bearer" alone). Unpacking it blindly
    # raises ValueError, which would surface as a server error rather than an
    # authentication failure, so check the shape explicitly.
    parts = authorization.split()
    if len(parts) == 2 and parts[0].lower() == "bearer":
        return parts[1]
    return None


async def validate_m2m_key(request: Request, api_key: Optional[str] = Depends(get_api_key)) -> bool:
    """Validate the machine-to-machine key resolved by `get_api_key`.

    Args:
        request: Incoming request. It is not inspected here; the parameter is
            kept so the dependency signature stays unchanged.
        api_key: Key resolved by the `get_api_key` dependency, or None when the
            request carried no usable credential.

    Returns:
        True when `api_key` is one of the configured `valid_m2m_keys`.

    Raises:
        HTTPException: 401 when no key was supplied, 403 when the supplied key
            is not in `valid_m2m_keys`.
    """
    # Never write the credential itself to the logs; record only whether one
    # was supplied.
    logger.debug(f"validate_m2m_key [api_key_provided={bool(api_key)}]")
    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing authentication token (use `X-Api-Key` or `Authorization: Bearer <token>`)",
        )

    if api_key not in valid_m2m_keys:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid authentication token (use `X-Api-Key` or `Authorization: Bearer <token>`)",
        )

    return True

Expand Down
35 changes: 34 additions & 1 deletion nos/server/http/_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import dataclasses
import os
import time
import uuid
from dataclasses import field
from functools import lru_cache
from pathlib import Path
Expand All @@ -14,10 +15,11 @@
from PIL import Image
from pydantic import ConfigDict
from pydantic.dataclasses import dataclass
from tqdm import tqdm

from nos.client import Client
from nos.common.tasks import TaskType
from nos.constants import DEFAULT_GRPC_ADDRESS
from nos.constants import DEFAULT_GRPC_ADDRESS, NOS_TMP_DIR
from nos.logging import logger
from nos.protoc import import_module
from nos.version import __version__
Expand Down Expand Up @@ -136,6 +138,7 @@ def app_factory(version: str = HTTP_API_VERSION, address: str = DEFAULT_GRPC_ADD

svc = InferenceService(address=address)
logger.info(f"app_factory [env={env}]: Adding CORS middleware ...")

app = FastAPI(
title="NOS REST API",
description=f"NOS REST API (version={__version__}, api_version={version})",
Expand All @@ -156,6 +159,9 @@ def app_factory(version: str = HTTP_API_VERSION, address: str = DEFAULT_GRPC_ADD
app.middleware("http")(default_exception_middleware)
app.add_exception_handler(Exception, default_exception_handler)

NOS_TMP_FILES_DIR = Path(NOS_TMP_DIR) / "uploaded_files"
NOS_TMP_FILES_DIR.mkdir(parents=True, exist_ok=True)

def get_client() -> Client:
    """Return the client held by the enclosing `InferenceService` instance (`svc`)."""
    inference_client = svc.client
    return inference_client
Expand Down Expand Up @@ -221,6 +227,33 @@ def model_info(
except KeyError:
raise HTTPException(status_code=400, detail=f"Invalid model {model}")

# TODO (delete file after processing)
@app.post(f"/{version}/file/upload", status_code=201)
def upload_file(file: UploadFile = File(...), client: Client = Depends(get_client)) -> JSONResponse:
    """Store an uploaded file under `NOS_TMP_FILES_DIR` and return its identifier.

    The file is saved as `<uuid4>-<basename of the client filename>`; only the
    final path component of the client-supplied name is used.

    Args:
        file: The uploaded file.
        client: Inference client dependency (not used by this handler).

    Returns:
        A mapping with `file_id` (the generated UUID as a string) and
        `filename` (the stored basename).

    Raises:
        HTTPException: 500 when the file could not be written.
    """
    uid = uuid.uuid4()
    # The client may omit the filename (None) or send an empty one; fall back
    # to a fixed name instead of letting `Path(None)` raise.
    original_name = Path(file.filename or "upload").name or "upload"
    basename = f"{uid}-{original_name}"
    path = NOS_TMP_FILES_DIR / basename
    try:
        logger.debug(f"Uploading file: [local={file.filename}, path={path}]")
        file.file.seek(0)
        with path.open("wb") as f:
            # Read in 1 KiB chunks so each progress-bar tick corresponds to one "KB" unit.
            for chunk in tqdm(
                iter(lambda: file.file.read(1024), b""),
                desc="Uploading file",
                unit="KB",
                unit_scale=True,
                unit_divisor=1024,
            ):
                f.write(chunk)
        logger.info(f"Successfully uploaded file [path={path}]")
    except Exception as exc:
        logger.error(f"Failed to upload file [file={file.filename}, exc={exc}]")
        # Do not leave a truncated upload behind in the shared directory.
        try:
            path.unlink(missing_ok=True)
        except OSError as cleanup_exc:
            logger.warning(f"Failed to remove partial upload [path={path}, exc={cleanup_exc}]")
        raise HTTPException(status_code=500, detail="Failed to upload file.") from exc
    return {
        "file_id": str(uid),
        "filename": basename,
    }

@app.post(f"/{version}/chat/completions", status_code=status.HTTP_201_CREATED)
def chat(
request: ChatCompletionsRequest,
Expand Down
2 changes: 1 addition & 1 deletion nos/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.3.0"
__version__ = "0.4.1"
1 change: 1 addition & 0 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ python-multipart
pyyaml
rich>=12.5.1
sentry-sdk[loguru]
setuptools==69.5.1
tqdm
typer>=0.7.0
typing_extensions>=4.5.0

0 comments on commit 55411c4

Please sign in to comment.