Skip to content

Commit

Permalink
Develop OAuth logout (#732)
Browse files Browse the repository at this point in the history
  • Loading branch information
Repumba committed Feb 20, 2023
1 parent 093ad05 commit 4456c1a
Show file tree
Hide file tree
Showing 16 changed files with 246 additions and 22 deletions.
29 changes: 29 additions & 0 deletions dev/oidc/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,34 @@
"Content-Type": "application/json",
}

print("[+] Creating user foo in MWDB with password foofoofoo")

response = mwdb_session.post(
"http://127.0.0.1/api/user/foo",
json={
"email": "foo@mwdb.local",
"feed_quality": "high",
"send_email": False,
"additional_info": "string"
}
)
response.raise_for_status()

response = mwdb_session.get(
"http://127.0.0.1/api/user/foo/change_password"
)
response.raise_for_status()
password_token = response.json()['token']

response = requests.post(
"http://127.0.0.1/api/auth/change_password",
json={
"password": "foofoofoo",
"token": password_token
}
)
response.raise_for_status()

print("[+] Registering new OIDC provider")

response = mwdb_session.post(
Expand All @@ -99,6 +127,7 @@
"userinfo_endpoint": "http://keycloak.:8080/realms/mwdb-oidc-dev/protocol/openid-connect/userinfo",
"token_endpoint": "http://keycloak.:8080/realms/mwdb-oidc-dev/protocol/openid-connect/token",
"jwks_endpoint": "http://keycloak.:8080/realms/mwdb-oidc-dev/protocol/openid-connect/certs",
"logout_endpoint": "http://127.0.0.1:8080/realms/mwdb-oidc-dev/protocol/openid-connect/logout",
},
)
response.raise_for_status()
Expand Down
2 changes: 2 additions & 0 deletions docker-compose-oidc-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ services:
env_file:
# NOTE: use gen_vars.sh in order to generate this file
- postgres-vars.env
ports:
- "127.0.0.1:54322:5432"
redis:
image: redis:alpine
mailhog:
Expand Down
7 changes: 6 additions & 1 deletion mwdb/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
OpenIDAuthenticateResource,
OpenIDAuthorizeResource,
OpenIDBindAccountResource,
OpenIDLogoutResource,
OpenIDProviderResource,
OpenIDRegisterUserResource,
OpenIDSingleProviderResource,
Expand Down Expand Up @@ -166,10 +167,13 @@ def require_auth():
auth = request.headers.get("Authorization")

g.auth_user = None
g.auth_provider = None

if auth and auth.startswith("Bearer "):
token = auth.split(" ", 1)[1]
g.auth_user = User.verify_session_token(token)
result = User.verify_session_token(token)
if result is not None:
g.auth_user, g.auth_provider = result
# Not a session token? Maybe APIKey token
if g.auth_user is None:
g.auth_user = APIKey.verify_token(token)
Expand Down Expand Up @@ -340,6 +344,7 @@ def require_auth():
api.add_resource(OpenIDAuthorizeResource, "/oauth/<provider_name>/authorize")
api.add_resource(OpenIDBindAccountResource, "/oauth/<provider_name>/bind_account")
api.add_resource(OpenIDRegisterUserResource, "/oauth/<provider_name>/register")
api.add_resource(OpenIDLogoutResource, "/oauth/<provider_name>/logout")

# Remote endpoints
api.add_resource(RemoteListResource, "/remote")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""create_logout_endpoint_column
Revision ID: bd93d1497694
Revises: 717b5da712b8
Create Date: 2023-01-04 17:14:22.271856
"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "bd93d1497694"
down_revision = "717b5da712b8"
branch_labels = None
depends_on = None


def upgrade():
op.execute(
"""
ALTER TABLE public.openid_provider
ADD logout_endpoint text;
"""
)


def downgrade():
op.execute(
"""
ALTER TABLE public.openid_provider
DROP COLUMN logout_endpoint;
"""
)
1 change: 1 addition & 0 deletions mwdb/model/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class OpenIDProvider(db.Model):
token_endpoint = db.Column(db.Text, nullable=False)
userinfo_endpoint = db.Column(db.Text, nullable=False)
jwks_endpoint = db.Column(db.Text, nullable=True)
logout_endpoint = db.Column(db.Text, nullable=True)

identities = db.relationship(
"OpenIDUserIdentity",
Expand Down
29 changes: 17 additions & 12 deletions mwdb/model/user.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import os
from typing import Optional, Tuple

import bcrypt
from flask import g
Expand Down Expand Up @@ -154,19 +155,19 @@ def create(
db.session.commit()
return user

def _generate_token(self, fields, scope, expiration):
def _generate_token(self, user_fields, scope, expiration, **extra_fields):
token_data = {"login": self.login, **extra_fields}
for field in user_fields:
token_data[field] = getattr(self, field)
token = generate_token(
dict(
[("login", self.login)]
+ [(field, getattr(self, field)) for field in fields]
),
token_data,
scope,
expiration,
)
return token

@staticmethod
def _verify_token(token, fields, scope):
def _verify_token(token, fields, scope) -> Optional[Tuple["User", Optional[str]]]:
data = verify_token(token, scope)
if data is None:
return None
Expand All @@ -182,13 +183,14 @@ def _verify_token(token, fields, scope):
if data[field] != getattr(user_obj, field):
return None

return user_obj
return user_obj, data.get("provider")

def generate_session_token(self):
def generate_session_token(self, provider=None):
return self._generate_token(
["password_ver", "identity_ver"],
scope=AuthScope.session,
expiration=24 * 3600,
provider=provider,
)

def generate_set_password_token(self):
Expand All @@ -199,18 +201,21 @@ def generate_set_password_token(self):
)

@staticmethod
def verify_session_token(token):
def verify_session_token(token) -> Optional[Tuple["User", Optional[str]]]:
return User._verify_token(
token, ["password_ver", "identity_ver"], scope=AuthScope.session
token,
["password_ver", "identity_ver"],
scope=AuthScope.session,
)

@staticmethod
def verify_set_password_token(token):
return User._verify_token(
def verify_set_password_token(token) -> Optional["User"]:
result = User._verify_token(
token,
["password_ver"],
scope=AuthScope.set_password,
)
return None if result is None else result[0]

@staticmethod
def verify_legacy_token(token):
Expand Down
5 changes: 4 additions & 1 deletion mwdb/resources/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def post(self):
"token": auth_token,
"capabilities": user.capabilities,
"groups": user.group_names,
"provider": None,
}
)

Expand Down Expand Up @@ -235,6 +236,7 @@ def post(self):
schema = AuthSetPasswordRequestSchema()
obj = loads_schema(request.get_data(as_text=True), schema)

# verify_set_password_token return tuple (user_obj, auth_provider)
user = User.verify_set_password_token(obj["token"])
if user is None:
raise Forbidden("Set password token expired")
Expand Down Expand Up @@ -422,9 +424,10 @@ def post(self):
return schema.dump(
{
"login": user.login,
"token": user.generate_session_token(),
"token": user.generate_session_token(provider=g.auth_provider),
"capabilities": user.capabilities,
"groups": user.group_names,
"provider": g.auth_provider,
}
)

Expand Down
65 changes: 63 additions & 2 deletions mwdb/resources/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from mwdb.schema.oauth import (
OpenIDAuthorizeRequestSchema,
OpenIDLoginResponseSchema,
OpenIDLogoutLinkResponseSchema,
OpenIDProviderCreateRequestSchema,
OpenIDProviderItemResponseSchema,
OpenIDProviderListResponseSchema,
Expand Down Expand Up @@ -100,6 +101,10 @@ def post(self):
if obj["jwks_endpoint"]:
jwks_endpoint = obj["jwks_endpoint"]

logout_endpoint = None
if obj["logout_endpoint"]:
logout_endpoint = obj["logout_endpoint"]

if db.session.query(
exists().where(and_(OpenIDProvider.name == obj["name"]))
).scalar():
Expand All @@ -115,6 +120,7 @@ def post(self):
token_endpoint=obj["token_endpoint"],
userinfo_endpoint=obj["userinfo_endpoint"],
jwks_endpoint=jwks_endpoint,
logout_endpoint=logout_endpoint,
)

group_name = ("OpenID_" + obj["name"])[:32]
Expand Down Expand Up @@ -250,6 +256,10 @@ def put(self, provider_name):
if jwks_endpoint is not None:
provider.jwks_endpoint = jwks_endpoint

logout_endpoint = obj["logout_endpoint"]
if logout_endpoint is not None:
provider.logout_endpoint = logout_endpoint

db.session.commit()

logger.info("Provider updated", extra={"provider": provider_name})
Expand Down Expand Up @@ -394,7 +404,7 @@ def post(self, provider_name):
user.logged_on = datetime.datetime.now()
db.session.commit()

auth_token = user.generate_session_token()
auth_token = user.generate_session_token(provider=provider_name)

logger.info(
"User logged in via OpenID Provider",
Expand All @@ -407,6 +417,7 @@ def post(self, provider_name):
"token": auth_token,
"capabilities": user.capabilities,
"groups": user.group_names,
"provider": provider_name,
}
)

Expand Down Expand Up @@ -486,7 +497,7 @@ def post(self, provider_name):
user.logged_on = datetime.datetime.now()
db.session.commit()

auth_token = user.generate_session_token()
auth_token = user.generate_session_token(provider=provider_name)

user_private_group = next(
(g for g in user.groups if g.name == user.login), None
Expand All @@ -506,6 +517,7 @@ def post(self, provider_name):
"token": auth_token,
"capabilities": user.capabilities,
"groups": user.group_names,
"provider": provider_name,
}
)

Expand Down Expand Up @@ -620,3 +632,52 @@ def get(self):
identity.provider.name for identity in g.auth_user.openid_identities
]
return OpenIDProviderListResponseSchema().dump({"providers": identities})


@rate_limited_resource
class OpenIDLogoutResource(Resource):
@requires_authorization
def get(self, provider_name):
"""
---
summary: Get logout endpoint url
description: |
Get logout endpoint url
security:
- bearerAuth: []
tags:
- auth
parameters:
- in: path
name: provider_name
schema:
type: string
description: OpenID provider name.
responses:
200:
description: When logout endpoint was found
content:
application/json:
schema: OpenIDLogoutLinkResponseSchema
404:
description: Requested provider doesn't exist
412:
description: |
Logout endpoint is not specified for this provider
503:
description: |
Request canceled due to database statement timeout.
"""
provider = (
db.session.query(OpenIDProvider)
.filter(OpenIDProvider.name == provider_name)
.first()
)
if not provider:
raise NotFound(f"Requested provider name '{provider_name}' not found")

if not provider.logout_endpoint:
raise NotFound(f"Logout endpoint is not configured for '{provider_name}'")

schema = OpenIDLogoutLinkResponseSchema()
return schema.dump({"url": provider.logout_endpoint})
1 change: 1 addition & 0 deletions mwdb/schema/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class AuthSuccessResponseSchema(UserLoginSchemaBase):
token = fields.Str(required=True, allow_none=False)
capabilities = fields.List(fields.Str(), required=True, allow_none=False)
groups = fields.List(fields.Str(), required=True, allow_none=False)
provider = fields.Str(required=True, allow_none=True)


class AuthValidateTokenResponseSchema(UserLoginSchemaBase):
Expand Down
6 changes: 6 additions & 0 deletions mwdb/schema/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class OpenIDProviderCreateRequestSchema(Schema):
token_endpoint = fields.Str(required=True, allow_none=False)
userinfo_endpoint = fields.Str(required=True, allow_none=False)
jwks_endpoint = fields.Str(required=True, allow_none=True)
logout_endpoint = fields.Str(required=False, allow_none=True)


class OpenIDProviderItemResponseSchema(OpenIDProviderCreateRequestSchema):
Expand All @@ -23,6 +24,7 @@ class OpenIDProviderUpdateRequestSchema(Schema):
token_endpoint = fields.Str(missing=None)
userinfo_endpoint = fields.Str(missing=None)
jwks_endpoint = fields.Str(missing=None)
logout_endpoint = fields.Str(missing=None)


class OpenIDAuthorizeRequestSchema(Schema):
Expand All @@ -43,3 +45,7 @@ class OpenIDLoginResponseSchema(Schema):
authorization_url = fields.Str(required=True, allow_none=False)
state = fields.Str(required=True, allow_none=False)
nonce = fields.Str(required=True, allow_none=False)


class OpenIDLogoutLinkResponseSchema(Schema):
url = fields.Str(required=True, allow_none=False)

0 comments on commit 4456c1a

Please sign in to comment.