Skip to content

Commit

Permalink
use jose
Browse files Browse the repository at this point in the history
  • Loading branch information
Mropat committed Mar 16, 2022
1 parent 15c4653 commit 2984a4d
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 10 deletions.
1 change: 0 additions & 1 deletion genotype_api/api/endpoints/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def delete_user(
)
session.delete(user)
session.commit()
session.flush()
return JSONResponse(content="User deleted successfully", status_code=status.HTTP_200_OK)


Expand Down
2 changes: 2 additions & 0 deletions genotype_api/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pathlib import Path
from typing import Literal

from pydantic import BaseSettings

Expand All @@ -24,6 +25,7 @@ class SecuritySettings(BaseSettings):

client_id = ""
algorithm = ""
jwks_uri = "https://www.googleapis.com/oauth2/v3/certs"

class Config:
env_file = str(ENV_FILE)
Expand Down
31 changes: 22 additions & 9 deletions genotype_api/security.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,29 @@
from fastapi import HTTPException, Security, Depends
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlmodel import Session
from starlette import status

from starlette.requests import Request

from genotype_api.database import get_session
from genotype_api.models import User
from genotype_api.config import security_settings
from genotype_api.crud.users import get_user_by_email

from google.oauth2 import id_token
from google.auth.transport import requests
from jose import jwt
import requests


def decode_id_token(token: str):
try:
return id_token.verify_oauth2_token(token, requests.Request(), security_settings.client_id)
return jwt.decode(
token,
key=requests.get(security_settings.jwks_uri).json(),
algorithms=[security_settings.algorithm],
audience=security_settings.client_id,
options={
"verify_at_hash": False,
},
)
except:
return

Expand All @@ -27,15 +35,20 @@ def __init__(self, auto_error: bool = True):
async def __call__(self, request: Request):
credentials: HTTPAuthorizationCredentials = await super(JWTBearer, self).__call__(request)
if not credentials:
raise HTTPException(status_code=403, detail="Invalid authorization code.")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Invalid authorization code."
)
if credentials.scheme != "Bearer":
raise HTTPException(status_code=403, detail="Invalid authentication scheme.")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Invalid authentication scheme."
)
if not self.verify_jwt(credentials.credentials):
raise HTTPException(status_code=403, detail="Invalid token or expired token.")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Invalid token or expired token."
)
return credentials.credentials

def verify_jwt(self, jwtoken: str) -> bool:
isTokenValid: bool = False
try:
payload = decode_id_token(jwtoken)
except:
Expand All @@ -54,5 +67,5 @@ async def get_active_user(
user = User.parse_obj(decode_id_token(token))
db_user: User = get_user_by_email(session=session, email=user.email)
if not db_user:
raise HTTPException(status_code=400, detail="User not in DB")
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="User not in DB")
return user

0 comments on commit 2984a4d

Please sign in to comment.