In [1]:
## Trying out code for the active user logic

import datetime
from datetime import timedelta
from jose import jwt, JWTError, ExpiredSignatureError
from fastapi import HTTPException, status, Depends
from sqlalchemy.orm import Session
from typing import Annotated
from fastapi_jwt_sql.database.database import SessionLocal
from fastapi_jwt_sql.database.sqlmodels.jwt_tokens import TokenBlacklist


In [11]:
(datetime.datetime.now(datetime.UTC) + timedelta(minutes=2)).timestamp()

1717485952.075249

In [12]:
token_input = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyLCJleHAiOjE3MTc0ODU5NTIsImp0aSI6IjEyM2F2YyJ9.N1NOl0UdGV-DMCF3SVPc--aORSTi-GOeaTqXHM8xaJg"

In [4]:
SECRET_KEY = "hi"
ALGORITHM = "HS256"


## Dependencies:

In [5]:

def get_db():
    """Create a single session for you database request.

    Yields:
        The database connection.
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


db_dependency = Annotated[Session, Depends(get_db)]

In [6]:


def add_jti_to_blacklist(db: db_dependency, jti: str, expires_at: datetime.datetime) -> None:
    jti_to_add = TokenBlacklist(
        token_jti=jti,
        expires_at=expires_at,
        added_on=datetime.datetime.now(datetime.UTC),
    )
    db.add(jti_to_add)
    db.commit()


def is_jti_blacklisted(db: db_dependency, jti) -> bool:
    jti = db.query(TokenBlacklist).filter_by(token_jti=jti).scalar()
    return not (jti is None)

## Code logic:

In [None]:
def is_login_valid(token: str, secret, algorithms):
    payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    # automatically raises expired error if expired.
    username = payload.get("sub")
    jti: str = payload.get("jti")  # type:ignore
    if username is None or jti is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials. (2)")
    if is_jti_blacklisted(db=db, jti=jti):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="You logged out. Log in again.")
    return payload

In [None]:

@auth_router.post("/token", response_model=Token)
def set_access_token(token: str) -> dict[str, str]:
    # return needs to match the structure of the fastapi pydantic model of type Token

    return {
        "access_token": token,
        "token_type": "bearer"
    }

In [None]:
from fastapi_jwt_sql.api.auth import REFRESH_SECRET
from fastapi_jwt_sql.database.sqlmodels.jwt_tokens import RefreshTokens


def get_user_access_token(token: Annotated[str, Depends(oauth2_bearer)], db: db_dependency):
    try:
        payload = is_login_valid(token, secret=SECRET_KEY, algorithms=[ALGORITHM])
        username = payload.get("sub")
        return payload
    except ExpiredSignatureError as expired:
        # find the refresh_token.
        username = payload.get("sub")
        refresh_token = db.query(RefreshTokens).filter_by(
            username=username).scalar()
        try:
            payload = is_login_valid(refresh_token, secret=REFRESH_SECRET, algorithms=[ALGORITHM])
                # generate new auth token, link it to auth url.
            user_id = payload.get("user_id")
            new_auth_token = _create_jwt_token(username=username,
                                               user_id=user_id,
                                               type="auth",
                                               expires_delta=timedelta(minutes=ACCESS_AUTH_TOKEN_EXPIRATION_MINUTES),)
            set_access_token(new_auth_token)
            return {"msg": "success"}
        except JWTError as er:
            # remove active user from active_user_db. 
            raise er
    except JWTError as e:
        raise e

In [None]:

def login_for_access_token(form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
                           db: db_dependency) -> dict[str, str]:
    try:
        user: Users = authenticate_user(
            form_data.username, form_data.password, db)
    except:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate user. (1)")

    auth_token = _create_jwt_token(
        username=user.username,
        user_id=user.id,
        type="auth",
        expires_delta=timedelta(minutes=ACCESS_AUTH_TOKEN_EXPIRATION_MINUTES),
    )
    # active_user = LoggedUser(username=user.username,
    #                          logged_in_on=datetime.datetime.now(datetime.UTC))
    # db.add(active_user)
    # db.commit()

    set_access_token(auth_token)
    return {"msg": "login success"}


In [7]:
# from fastapi_jwt_sql.api.auth import REFRESH_SECRET
# from fastapi_jwt_sql.database.sqlmodels.jwt_tokens import RefreshTokens


# def get_user(token: str, db:db_dependency):

#     try:
#         payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
#         # automatically raises expired error if expired.
#         username = payload.get("sub")
#         jti: str = payload.get("jti")  # type:ignore
#         if username is None or jti is None:
#             raise HTTPException(
#                 status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials. (2)")
#         if is_jti_blacklisted(db=db, jti=jti):
#             raise HTTPException(
#                 status_code=status.HTTP_401_UNAUTHORIZED, detail="You logged out. Log in again.")
#         return payload
#     except ExpiredSignatureError as expired:
#         # find the refresh_token.
#         refresh_token = db.query(RefreshTokens).filter_by(
#             username=username).scalar()
#         try:
#             refresh_payload = jwt.decode(refresh_token, REFRESH_SECRET, algorithms=[ALGORITHM])
#             refresh_jti = refresh_payload.get("jti")
#             refresh_username = refresh_payload.get("sub")
#             if refresh_username != username:
#                 raise ValueError
#             user_id =refresh_payload.get("user_id")
#             if user_id is None or refresh_jti is None:
#                 raise HTTPException(
#                     status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials. (2)")
#             if is_jti_blacklisted(db=db, jti=refresh_jti):
#                 raise HTTPException(
#                     status_code=status.HTTP_401_UNAUTHORIZED, detail="You logged out. Log in again.")
#                 #generate new auth token, link it to auth url. 
#             new_auth_token = _create_jwt_token(username=username,
#                                                user_id=user_id,
#                                                type="auth",
#                                                expires_delta=timedelta(minutes=ACCESS_AUTH_TOKEN_EXPIRATION_MINUTES),)
#         except JWTError as er:
#             raise er
#     except JWTError as e:
#         raise e

IndentationError: expected an indented block after 'else' statement on line 21 (3703356073.py, line 24)

In [8]:
def get_active_user(token: str):
    """Checks the validity of the token.

    Args:
        token:
            Dependency of the OAuth2PasswordBearer, this is where the token is stored.

    Raises an HTTPException when the user is not logged in, could not be verified or if JWT cannot be verified.

    Returns a dict with the key user information.
    """
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        # automatically raises expired error if expired.
        username = payload.get("sub")
        jti: str = payload.get("jti")  # type:ignore
        exp = payload.get("exp")
        if username is None or jti is None:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials. (2)")
        return {"username": username, "jti": jti, "exp": exp}
    except ExpiredSignatureError as expired:
        raise expired
    except JWTError as e:
        raise e
    

def get_current_user(db: db_dependency, token: str):
    """Logic to validate user, by checking for active user and if the jti is not blacklisted.
    TODO: implement logic with refresh token, (also not blacklisted), if valid generate a new auth token.

    Args:
        db:
            DB dependency, to check if jti is blacklisted
        token:
            The authentication token, retrieved from the token url.

    Returns:
        The active user dict {username, jti, exp}

    Raises HTTPException when login is expired, or could not validate the user.
    """


    try:
        active_user = get_active_user(token)
        jti = active_user.get("jti")
        print(active_user)
    except ExpiredSignatureError:
        print("expired")
        pass
    except JWTError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate user. (2)")
    if is_jti_blacklisted(db=db, jti=jti):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="You logged out. Log in again.")
    else:

        active_user = get_active_user(token)

        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Login Expired, log in again. (2)")
    # # Third, if not valid, Check if Refresh token is still valid. If yes, generate new AUTH token.
    # return active_user


In [9]:
async def logout(active_user: Annotated[dict, Depends(get_active_user)], db: db_dependency):
    jti= active_user.get("jti")
    exp= active_user.get("exp")
    if exp and jti is not None:
        exp_timestamp: datetime.date = datetime.datetime.fromtimestamp(exp)
        add_jti_to_blacklist(jti=jti, expires_at=exp_timestamp,
                             db=db)
    return {"logout successful": active_user}

In [13]:
get_active_user(token=token_input)

{'username': '1234567890', 'jti': '123avc', 'exp': 1717485952}

In [14]:
get_active_user(token=token_input)

{'username': '1234567890', 'jti': '123avc', 'exp': 1717485952}

In [15]:
get_current_user(db=db_dependency, token=token_input)

{'username': '1234567890', 'jti': '123avc', 'exp': 1717485952}


AttributeError: type object 'TokenBlacklist' has no attribute '_query_cls'