diff --git a/src/dispatch/auth/models.py b/src/dispatch/auth/models.py index c5f3d30267ed..68092be34840 100644 --- a/src/dispatch/auth/models.py +++ b/src/dispatch/auth/models.py @@ -4,6 +4,7 @@ import secrets from datetime import datetime, timedelta from uuid import uuid4 +from typing import Optional import bcrypt from jose import jwt @@ -52,6 +53,7 @@ def hash_password(password: str): class DispatchUser(Base, TimeStampMixin): """SQLAlchemy model for a Dispatch user.""" + __table_args__ = {"schema": "dispatch_core"} id = Column(Integer, primary_key=True) @@ -104,6 +106,7 @@ def get_organization_role(self, organization_slug: OrganizationSlug): class DispatchUserOrganization(Base, TimeStampMixin): """SQLAlchemy model for the relationship between users and organizations.""" + __table_args__ = {"schema": "dispatch_core"} dispatch_user_id = Column(Integer, ForeignKey(DispatchUser.id), primary_key=True) dispatch_user = relationship(DispatchUser, backref="organizations") @@ -116,6 +119,7 @@ class DispatchUserOrganization(Base, TimeStampMixin): class DispatchUserProject(Base, TimeStampMixin): """SQLAlchemy model for the relationship between users and projects.""" + dispatch_user_id = Column(Integer, ForeignKey(DispatchUser.id), primary_key=True) dispatch_user = relationship(DispatchUser, backref="projects") @@ -129,6 +133,7 @@ class DispatchUserProject(Base, TimeStampMixin): class UserProject(DispatchBase): """Pydantic model for a user's project membership.""" + project: ProjectRead default: bool | None = False role: str | None = None @@ -136,6 +141,7 @@ class UserProject(DispatchBase): class UserOrganization(DispatchBase): """Pydantic model for a user's organization membership.""" + organization: OrganizationRead default: bool | None = False role: str | None = None @@ -143,6 +149,7 @@ class UserOrganization(DispatchBase): class UserBase(DispatchBase): """Base Pydantic model for user data.""" + email: EmailStr projects: list[UserProject] | None = [] organizations: list[UserOrganization] | None = [] @@ -158,6 +165,7 @@ def email_required(cls, v): class UserLogin(UserBase): """Pydantic model for user login data.""" + password: str @field_validator("password") @@ -171,6 +179,7 @@ def password_required(cls, v): class UserRegister(UserLogin): """Pydantic model for user registration data.""" + password: str = None @field_validator("password", mode="before") @@ -183,12 +192,14 @@ def password_required(cls, v): class UserLoginResponse(DispatchBase): """Pydantic model for the response after user login.""" + projects: list[UserProject] | None token: str | None = None class UserRead(UserBase): """Pydantic model for reading user data.""" + id: PrimaryKey role: str | None = None experimental_features: bool | None @@ -196,15 +207,17 @@ class UserRead(UserBase): class UserUpdate(DispatchBase): """Pydantic model for updating user data.""" + id: PrimaryKey - projects: list[UserProject] | None - organizations: list[UserOrganization] | None - experimental_features: bool | None - role: str | None = None + projects: Optional[list[UserProject]] = None + organizations: Optional[list[UserOrganization]] + experimental_features: Optional[bool] = None + role: Optional[str] = None class UserPasswordUpdate(DispatchBase): """Pydantic model for password updates only.""" + current_password: str new_password: str @@ -231,6 +244,7 @@ def password_required(cls, v): class AdminPasswordReset(DispatchBase): """Pydantic model for admin password resets.""" + new_password: str @field_validator("new_password") @@ -248,6 +262,7 @@ def validate_password(cls, v): class UserCreate(DispatchBase): """Pydantic model for creating a new user.""" + email: EmailStr password: str | None = None projects: list[UserProject] | None @@ -263,16 +278,19 @@ def hash(cls, v): class UserRegisterResponse(DispatchBase): """Pydantic model for the response after user registration.""" + token: str | None = None class UserPagination(Pagination): """Pydantic model for paginated user results.""" + items: list[UserRead] = [] class MfaChallengeStatus(DispatchEnum): """Enumeration of possible MFA challenge statuses.""" + APPROVED = "approved" DENIED = "denied" EXPIRED = "expired" @@ -281,6 +299,7 @@ class MfaChallengeStatus(DispatchEnum): class MfaChallenge(Base, TimeStampMixin): """SQLAlchemy model for an MFA challenge event.""" + id = Column(Integer, primary_key=True, autoincrement=True) valid = Column(Boolean, default=False) reason = Column(String, nullable=True) @@ -293,11 +312,13 @@ class MfaChallenge(Base, TimeStampMixin): class MfaPayloadResponse(DispatchBase): """Pydantic model for the response to an MFA challenge payload.""" + status: str class MfaPayload(DispatchBase): """Pydantic model for an MFA challenge payload.""" + action: str project_id: int challenge_id: str diff --git a/src/dispatch/cli.py b/src/dispatch/cli.py index e8fc869d9e4b..1129f79164e8 100644 --- a/src/dispatch/cli.py +++ b/src/dispatch/cli.py @@ -285,6 +285,7 @@ def prompt_for_confirmation(command: str) -> bool: f"Warning: You are about to {command} a remote database.", fg="yellow", ) + database_name = click.prompt(f"Please enter the database name (env = {DATABASE_NAME})") if database_name != DATABASE_NAME: click.secho( @@ -292,8 +293,11 @@ def prompt_for_confirmation(command: str) -> bool: fg="red", ) return False - sqlalchemy_database_uri = f"postgresql+psycopg2://{config._DATABASE_CREDENTIAL_USER}:{config._QUOTED_DATABASE_PASSWORD}@{database_hostname}:{config.DATABASE_PORT}/{database_name}" + if command != "drop": + return True + + sqlalchemy_database_uri = f"postgresql+psycopg2://{config._DATABASE_CREDENTIAL_USER}:{config._QUOTED_DATABASE_PASSWORD}@{database_hostname}:{config.DATABASE_PORT}/{database_name}" if database_exists(str(sqlalchemy_database_uri)): if click.confirm( f"Are you sure you want to {command} database: '{database_hostname}:{database_name}'?" @@ -301,7 +305,7 @@ def prompt_for_confirmation(command: str) -> bool: return True else: click.secho(f"Database '{database_hostname}:{database_name}' does not exist!!!", fg="red") - return False + return False @dispatch_database.command("init") diff --git a/src/dispatch/signal/service.py b/src/dispatch/signal/service.py index f99c0f5672f0..728ef5c2af1b 100644 --- a/src/dispatch/signal/service.py +++ b/src/dispatch/signal/service.py @@ -87,12 +87,14 @@ def get_signal_engagement_by_name_or_raise( ) if not signal_engagement: - raise ValidationError([ - { - "msg": "Signal engagement not found.", - "loc": "signalEngagement", - } - ]) + raise ValidationError( + [ + { + "msg": "Signal engagement not found.", + "loc": "signalEngagement", + } + ] + ) return signal_engagement @@ -254,12 +256,14 @@ def get_signal_filter_by_name_or_raise( ) if not signal_filter: - raise ValidationError([ - { - "msg": "Signal Filter not found.", - "loc": "signalFilter", - } - ]) + raise ValidationError( + [ + { + "msg": "Signal Filter not found.", + "loc": "signalFilter", + } + ] + ) return signal_filter @@ -303,9 +307,7 @@ def get_default(*, db_session: Session, project_id: int) -> Signal | None: ) -def get_by_primary_or_external_id( - *, db_session: Session, signal_id: str | int -) -> Signal | None: +def get_by_primary_or_external_id(*, db_session: Session, signal_id: str | int) -> Signal | None: """Gets a signal by id or external_id.""" if is_valid_uuid(signal_id): signal = db_session.query(Signal).filter(Signal.external_id == signal_id).one_or_none() @@ -475,6 +477,7 @@ def update( signal: Signal, signal_in: SignalUpdate, user: DispatchUser | None = None, + update_filters: bool = False, ) -> Signal: """Updates a signal.""" signal_data = signal.dict() @@ -533,23 +536,21 @@ def update( updates["engagements-removed"].append(se.name) signal.engagements = engagements - is_filters_updated = {filter.id for filter in signal.filters} != { - filter.id for filter in signal_in.filters - } - - if is_filters_updated: - filters = [] - for f in signal_in.filters: - signal_filter = get_signal_filter_by_name_or_raise( - db_session=db_session, project_id=signal.project.id, signal_filter_in=f - ) - if signal_filter not in signal.filters: - updates["filters-added"].append(signal_filter.name) - filters.append(signal_filter) - for f in signal.filters: - if f not in filters: - updates["filters-removed"].append(f.name) - signal.filters = filters + # if update_filters, use only the filters from the signal_in, otherwise use the existing filters and add new filters + filter_set = set() if update_filters else set(signal.filters) + for f in signal_in.filters: + signal_filter = get_signal_filter_by_name_or_raise( + db_session=db_session, project_id=signal.project.id, signal_filter_in=f + ) + if signal_filter not in signal.filters: + updates["filters-added"].append(signal_filter.name) + filter_set.add(signal_filter) + elif update_filters: + filter_set.add(signal_filter) + for f in signal.filters: + if f not in filter_set: + updates["filters-removed"].append(f.name) + signal.filters = list(filter_set) if signal_in.workflows: workflows = [] diff --git a/src/dispatch/signal/views.py b/src/dispatch/signal/views.py index 5599dd18bfb0..534158ea73f2 100644 --- a/src/dispatch/signal/views.py +++ b/src/dispatch/signal/views.py @@ -318,7 +318,7 @@ def return_single_signal_stats( "input": signal_id, "ctx": {"error": ValueError("Signal not found.")}, } - ] + ], ) signal_data = get_signal_stats( @@ -345,7 +345,7 @@ def get_signal(db_session: DbSession, signal_id: str | PrimaryKey): "input": signal_id, "ctx": {"error": ValueError("Signal not found.")}, } - ] + ], ) return signal @@ -356,18 +356,13 @@ def create_signal(db_session: DbSession, signal_in: SignalCreate, current_user: return create(db_session=db_session, signal_in=signal_in, user=current_user) -@router.put( - "/{signal_id}", - response_model=SignalRead, - dependencies=[Depends(PermissionsDependency([SensitiveProjectActionPermission]))], -) -def update_signal( +def _update_signal( db_session: DbSession, signal_id: str | PrimaryKey, signal_in: SignalUpdate, current_user: CurrentUser, + update_filters: bool = False, ): - """Updates an existing signal.""" signal = get_by_primary_or_external_id(db_session=db_session, signal_id=signal_id) if not signal: raise ValidationError.from_exception_data( @@ -379,12 +374,16 @@ def update_signal( "input": signal_id, "ctx": {"error": ValueError("Signal not found.")}, } - ] + ], ) try: signal = update( - db_session=db_session, signal=signal, signal_in=signal_in, user=current_user + db_session=db_session, + signal=signal, + signal_in=signal_in, + user=current_user, + update_filters=update_filters, ) except IntegrityError: raise ValidationError( @@ -399,6 +398,48 @@ def update_signal( return signal +@router.put( + "/{signal_id}", + response_model=SignalRead, + dependencies=[Depends(PermissionsDependency([SensitiveProjectActionPermission]))], +) +def update_signal( + db_session: DbSession, + signal_id: str | PrimaryKey, + signal_in: SignalUpdate, + current_user: CurrentUser, +): + """Updates an existing signal from API, no filters are updated.""" + return _update_signal( + db_session=db_session, + signal_id=signal_id, + signal_in=signal_in, + current_user=current_user, + update_filters=False, + ) + + +@router.put( + "/update/{signal_id}", + response_model=SignalRead, + dependencies=[Depends(PermissionsDependency([SensitiveProjectActionPermission]))], +) +def update_signal_with_filters( + db_session: DbSession, + signal_id: str | PrimaryKey, + signal_in: SignalUpdate, + current_user: CurrentUser, +): + """Updates an existing signal from the UI, also updates filters.""" + return _update_signal( + db_session=db_session, + signal_id=signal_id, + signal_in=signal_in, + current_user=current_user, + update_filters=True, + ) + + @router.delete( "/{signal_id}", response_model=None, @@ -417,6 +458,6 @@ def delete_signal(db_session: DbSession, signal_id: str | PrimaryKey): "input": signal_id, "ctx": {"error": ValueError("Signal not found.")}, } - ] + ], ) delete(db_session=db_session, signal_id=signal.id) diff --git a/src/dispatch/static/dispatch/src/signal/api.js b/src/dispatch/static/dispatch/src/signal/api.js index 427e7eba3802..849b47e9be6c 100644 --- a/src/dispatch/static/dispatch/src/signal/api.js +++ b/src/dispatch/static/dispatch/src/signal/api.js @@ -18,7 +18,7 @@ export default { }, update(signalId, payload) { - return API.put(`${resource}/${signalId}`, payload) + return API.put(`${resource}/update/${signalId}`, payload) }, delete(signalId) {