1
+ from fastapi import Depends , HTTPException , status
2
+ from fastapi .security import OAuth2PasswordBearer
3
+ from jose import JWTError , jwt
4
+ from sqlalchemy .orm import Session
5
+
6
+ from . import schemas
7
+ from .config import settings
8
+ from .database import get_db
9
+ from .models import User
10
+
11
+ oauth2_scheme = OAuth2PasswordBearer (tokenUrl = "token" )
12
+
13
+ SECRET_KEY = settings .secret_key
14
+ ALGORITHM = settings .algorithm
15
+ ACCESS_TOKEN_EXPIRE_MINUTES = settings .access_token_expire_minutes
16
+
17
+
18
+ def create_access_token (data : dict ) -> str :
19
+ """
20
+ Generates a JWT access token.
21
+
22
+ Args:
23
+ data (dict): A dictionary containing user data to include in the token.
24
+
25
+ Returns:
26
+ str: The encoded JWT access token.
27
+ """
28
+ to_encode = data .copy ()
29
+ expire = datetime .utcnow () + timedelta (minutes = ACCESS_TOKEN_EXPIRE_MINUTES )
30
+ to_encode .update ({"exp" : expire })
31
+ encoded_jwt = jwt .encode (to_encode , SECRET_KEY , algorithm = ALGORITHM )
32
+ return encoded_jwt
33
+
34
+
35
+ def verify_access_token (token : str , credentials_exception ) -> schemas .TokenData :
36
+ """
37
+ Verifies a JWT access token.
38
+
39
+ Args:
40
+ token (str): The JWT access token to verify.
41
+ credentials_exception: An HTTPException to raise if the token is invalid.
42
+
43
+ Returns:
44
+ schemas.TokenData: The decoded token data.
45
+
46
+ Raises:
47
+ HTTPException: If the token is invalid or expired.
48
+ """
49
+ try :
50
+ payload = jwt .decode (token , SECRET_KEY , algorithms = [ALGORITHM ])
51
+ id : str = payload .get ("sub" )
52
+ if id is None :
53
+ raise credentials_exception
54
+ token_data = schemas .TokenData (id = id )
55
+ except JWTError :
56
+ raise credentials_exception
57
+ return token_data
58
+
59
+
60
+ def get_current_user (token : str = Depends (oauth2_scheme ), db : Session = Depends (get_db )) -> User :
61
+ """
62
+ Retrieves the current user from the database using the provided access token.
63
+
64
+ Args:
65
+ token (str): The JWT access token.
66
+ db (Session): The database session.
67
+
68
+ Returns:
69
+ User: The current user object.
70
+
71
+ Raises:
72
+ HTTPException: If the token is invalid or the user is not found.
73
+ """
74
+ credentials_exception = HTTPException (
75
+ status_code = status .HTTP_401_UNAUTHORIZED ,
76
+ detail = "Could not validate credentials" ,
77
+ headers = {"WWW-Authenticate" : "Bearer" },
78
+ )
79
+ token_data = verify_access_token (token , credentials_exception )
80
+ user = db .query (User ).filter (User .id == token_data .id ).first ()
81
+ if user is None :
82
+ raise credentials_exception
83
+ return user
0 commit comments