From 8ae87ca4493aa2aa75435cc2448f7759c38f372f Mon Sep 17 00:00:00 2001 From: AkshayPall Date: Sun, 11 Sep 2022 17:01:40 -0700 Subject: [PATCH] Add User Status Checks on Endpoints --- ambuda/auth.py | 3 +- ambuda/checks.py | 1 + ambuda/models/auth.py | 21 ++++++++ ambuda/queries.py | 6 ++- ambuda/utils/user_mixins.py | 8 +++ ambuda/views/auth.py | 9 ++++ .../99379162619e_add_user_statuses.py | 54 +++++++++++++++++++ test/ambuda/conftest.py | 31 +++++++++++ test/ambuda/test_models.py | 31 +++++++++++ test/ambuda/views/test_admin.py | 10 ++++ test/ambuda/views/test_auth.py | 15 ++++++ 11 files changed, 187 insertions(+), 2 deletions(-) create mode 100644 migrations/versions/99379162619e_add_user_statuses.py diff --git a/ambuda/auth.py b/ambuda/auth.py index e5945ce0..bb190bc9 100644 --- a/ambuda/auth.py +++ b/ambuda/auth.py @@ -19,7 +19,8 @@ def _load_user(user_id: int) -> Optional[User]: import current_user`) and as a template variable injected into each template. """ session = get_session() - return session.query(User).get(int(user_id)) + user = session.query(User).get(int(user_id)) + return user if user and user.is_ok else None def _unauthorized(): diff --git a/ambuda/checks.py b/ambuda/checks.py index 965d0aa4..6856aec5 100644 --- a/ambuda/checks.py +++ b/ambuda/checks.py @@ -134,6 +134,7 @@ def _check_lookup_tables(session) -> list[str]: def _check_bot_user(session) -> list[str]: """Check that the ambuda-bot user exists.""" username = consts.BOT_USERNAME + # Assume bot user is active bot_user = session.query(db.User).filter_by(username=username).first() if bot_user: return [] diff --git a/ambuda/models/auth.py b/ambuda/models/auth.py index f50e8687..c4d7692b 100644 --- a/ambuda/models/auth.py +++ b/ambuda/models/auth.py @@ -28,6 +28,15 @@ class User(AmbudaUserMixin, Base): #: The user's self-description. description = Column(Text_, nullable=False, default="") + #: If the user deleted their account. + is_deleted = Column(Boolean, nullable=False, default=False) + + #: If the user was banned.. + is_banned = Column(Boolean, nullable=False, default=False) + + #: If the user has verified their email. + is_verified = Column(Boolean, nullable=False, default=False) + #: All roles available for this user. roles = relationship("Role", secondary="user_roles") @@ -35,6 +44,18 @@ def set_password(self, raw_password: str): """Hash and save the given password.""" self.password_hash = generate_password_hash(raw_password) + def set_is_deleted(self, is_deleted: bool): + """Update is_deleted.""" + self.is_deleted = is_deleted + + def set_is_banned(self, is_banned: bool): + """Update is_banned.""" + self.is_banned = is_banned + + def set_is_verified(self, is_verified: bool): + """Update is_verified.""" + self.is_verified = is_verified + def check_password(self, raw_password: str) -> bool: """Check if the given password matches the user's hash.""" return check_password_hash(self.password_hash, raw_password) diff --git a/ambuda/queries.py b/ambuda/queries.py index e093b16c..a3a7a6f2 100644 --- a/ambuda/queries.py +++ b/ambuda/queries.py @@ -204,7 +204,11 @@ def page(project_id, page_slug: str) -> Optional[db.Page]: def user(username: str) -> Optional[db.User]: session = get_session() - return session.query(db.User).filter_by(username=username).first() + return ( + session.query(db.User) + .filter_by(username=username, is_deleted=False, is_banned=False) + .first() + ) def create_user(*, username: str, email: str, raw_password: str) -> db.User: diff --git a/ambuda/utils/user_mixins.py b/ambuda/utils/user_mixins.py index a0e07bfb..7c9cd3c0 100644 --- a/ambuda/utils/user_mixins.py +++ b/ambuda/utils/user_mixins.py @@ -31,6 +31,10 @@ def is_moderator(self) -> bool: def is_admin(self) -> bool: return False + @property + def is_ok(self) -> bool: + return True + class AmbudaUserMixin(UserMixin): def has_role(self, role: SiteRole) -> bool: @@ -59,3 +63,7 @@ def is_moderator(self) -> bool: @property def is_admin(self) -> bool: return self.has_role(SiteRole.ADMIN) + + @property + def is_ok(self) -> bool: + return not (self.is_deleted or self.is_banned) diff --git a/ambuda/views/auth.py b/ambuda/views/auth.py index 071dd80d..47d5c5dd 100644 --- a/ambuda/views/auth.py +++ b/ambuda/views/auth.py @@ -129,6 +129,7 @@ class ResetPasswordFromTokenForm(FlaskForm): @bp.route("/register", methods=["GET", "POST"]) def register(): if current_user.is_authenticated: + logout_if_not_ok() return redirect(url_for("site.index")) form = SignupForm() @@ -152,6 +153,7 @@ def register(): @bp.route("/sign-in", methods=["GET", "POST"]) def sign_in(): if current_user.is_authenticated: + logout_if_not_ok() return redirect(url_for("site.index")) form = SignInForm() @@ -165,6 +167,13 @@ def sign_in(): return render_template("auth/sign-in.html", form=form) +def logout_if_not_ok(): + # Check if user is now deleted or banned + user = q.user(username=current_user.username) + if user and not user.is_ok: + logout_user() + + @bp.route("/sign-out") def sign_out(): logout_user() diff --git a/migrations/versions/99379162619e_add_user_statuses.py b/migrations/versions/99379162619e_add_user_statuses.py new file mode 100644 index 00000000..f20de024 --- /dev/null +++ b/migrations/versions/99379162619e_add_user_statuses.py @@ -0,0 +1,54 @@ +"""Add user status + +Revision ID: 99379162619e +Revises: bc48af5ec2e6 +Create Date: 2022-09-11 17:20:02.341713 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy import orm + +# revision identifiers, used by Alembic. +revision = "99379162619e" +down_revision = "bc48af5ec2e6" +branch_labels = None +depends_on = None + +Base = orm.declarative_base() + + +class User(Base): + __tablename__ = "users" + id = sa.Column(sa.Integer, primary_key=True) + is_deleted = sa.Column(sa.Boolean, nullable=False, default=False) + is_banned = sa.Column(sa.Boolean, nullable=False, default=False) + is_verified = sa.Column(sa.Boolean, nullable=False, default=False) + + +def upgrade() -> None: + op.add_column("users", sa.Column("is_deleted", sa.Boolean, nullable=True)) + op.add_column("users", sa.Column("is_banned", sa.Boolean, nullable=True)) + op.add_column("users", sa.Column("is_verified", sa.Boolean, nullable=True)) + + bind = op.get_bind() + session = orm.Session(bind=bind) + for user in session.query(User).all(): + user.is_deleted = False + user.is_banned = False + user.is_verified = False + session.add(user) + session.commit() + + with op.batch_alter_table("users") as batch_op: + batch_op.alter_column("is_deleted", existing_type=sa.BOOLEAN(), nullable=False) + batch_op.alter_column("is_banned", existing_type=sa.BOOLEAN(), nullable=False) + batch_op.alter_column("is_verified", existing_type=sa.BOOLEAN(), nullable=False) + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("users", "is_deleted") + op.drop_column("users", "is_banned") + op.drop_column("users", "is_verified") + # ### end Alembic commands ### diff --git a/test/ambuda/conftest.py b/test/ambuda/conftest.py index 670c59cf..981582a3 100644 --- a/test/ambuda/conftest.py +++ b/test/ambuda/conftest.py @@ -73,6 +73,19 @@ def initialize_test_db(): session.add(admin) session.flush() + # Deleted and Banned + deleted_admin = db.User(username="sandrocottus-deleted", email="cgm@ambuda.org") + deleted_admin.set_password("maurya") + deleted_admin.set_is_deleted(True) + + banned = db.User(username="sikander-banned", email="alex@ambuda.org") + banned.set_password("onesicritus") + banned.set_is_banned(True) + + session.add(deleted_admin) + session.add(banned) + session.flush() + # Roles p1_role = db.Role(name=db.SiteRole.P1.value) p2_role = db.Role(name=db.SiteRole.P2.value) @@ -84,8 +97,12 @@ def initialize_test_db(): rama.roles = [p1_role, p2_role] admin.roles = [p1_role, p2_role, admin_role] + deleted_admin.roles = [p1_role, p2_role, admin_role] + banned.roles = [p1_role] session.add(rama) session.add(admin) + session.add(deleted_admin) + session.add(banned) session.flush() # Proofreading @@ -152,3 +169,17 @@ def admin_client(flask_app): session = get_session() user = session.query(db.User).filter_by(username="akprasad").first() return flask_app.test_client(user=user) + + +@pytest.fixture() +def deleted_client(flask_app): + session = get_session() + user = session.query(db.User).filter_by(username="sandrocottus-deleted").first() + return flask_app.test_client(user=user) + + +@pytest.fixture() +def banned_client(flask_app): + session = get_session() + user = session.query(db.User).filter_by(username="sikander-banned").first() + return flask_app.test_client(user=user) diff --git a/test/ambuda/test_models.py b/test/ambuda/test_models.py index 59a04416..4b8651bb 100644 --- a/test/ambuda/test_models.py +++ b/test/ambuda/test_models.py @@ -8,6 +8,18 @@ def _cleanup(session, *objects): session.commit() +def test_user__is_ok_when_created(client): + session = get_session() + user = db.User(username="test", email="test@ambuda.org") + user.set_password("my-password") + session.add(user) + session.commit() + + assert user.is_ok + + _cleanup(session, user) + + def test_user__set_and_check_password(client): session = get_session() user = db.User(username="test", email="test@ambuda.org") @@ -38,6 +50,25 @@ def test_user__set_and_check_role(client): _cleanup(session, user) +def test_user__deletion(client): + session = get_session() + + # Check active user + user = db.User(username="test", email="test@ambuda.org") + user.set_password("my-password") + session.add(user) + session.commit() + assert user.is_ok + + # Deleted + user.set_is_deleted(True) + session.add(user) + session.commit() + assert not user.is_ok + + _cleanup(session, user) + + def test_role__repr(client): role = db.Role(name="foo") assert repr(role) == "" diff --git a/test/ambuda/views/test_admin.py b/test/ambuda/views/test_admin.py index ad8d2cb4..a341676f 100644 --- a/test/ambuda/views/test_admin.py +++ b/test/ambuda/views/test_admin.py @@ -8,6 +8,16 @@ def test_admin_index__auth(admin_client): assert resp.status_code == 200 +def test_admin_index__inactive(deleted_client, banned_client): + assert deleted_client.get("/admin/").status_code == 404 + assert banned_client.get("/admin/").status_code == 404 + + def test_admin_text__unauth(client): resp = client.get("/admin/text/") assert resp.status_code == 404 + + +def test_admin_text__inactive(deleted_client, banned_client): + assert deleted_client.get("/admin/text/").status_code == 404 + assert banned_client.get("/admin/text/").status_code == 404 diff --git a/test/ambuda/views/test_auth.py b/test/ambuda/views/test_auth.py index cbcf1a6d..32d0134c 100644 --- a/test/ambuda/views/test_auth.py +++ b/test/ambuda/views/test_auth.py @@ -65,6 +65,7 @@ def test_register__unauth_post__ok(client): r = client.post("/register", data=data) assert r.status_code == 302 assert current_user.username == "krishna" + assert current_user.is_ok def test_register__auth(rama_client): @@ -72,6 +73,13 @@ def test_register__auth(rama_client): assert r.status_code == 302 +def test_register__banned(banned_client): + with banned_client: + r = banned_client.get("/register") + assert r.status_code == 200 + assert current_user.is_anonymous + + def test_sign_in__unauth(client): r = client.get("/sign-in") assert ">Sign in to Ambuda<" in r.text @@ -115,6 +123,13 @@ def test_sign_in__auth(rama_client): assert r.status_code == 302 +def test_sign_in__banned(banned_client): + with banned_client: + r = banned_client.get("/sign-in") + assert r.status_code == 200 + assert current_user.is_anonymous + + def test_sign_out__unauth(client): with client: r = client.get("/sign-out")