diff --git a/alphatrion/log/load.py b/alphatrion/log/load.py index 74569058..530c66f8 100644 --- a/alphatrion/log/load.py +++ b/alphatrion/log/load.py @@ -38,8 +38,8 @@ async def load_checkpoint( ) -> list[str]: """ Load checkpoint from artifact registry, the path is expected to be in the format of: - - OCI: "org_id/team_id/exp_id/ckpt:version" - - S3: "org_id/team_id/exp_id/ckpt/filename" + - OCI: "org_id/team_id/exp_id/ckpt:version_or_filename", it should be a version. + - S3: "org_id/team_id/exp_id/ckpt/version_or_filename", it should be a filename. :param id: the id of the experiment. :param version_or_filename: the version or filename of the checkpoint to load, default is "latest". diff --git a/alphatrion/log/log.py b/alphatrion/log/log.py index 96a90373..b0a3bc8f 100644 --- a/alphatrion/log/log.py +++ b/alphatrion/log/log.py @@ -182,11 +182,6 @@ async def log_metrics(metrics: dict[str, float]): exp.done() -# log_records is used to log a list of records, which is similar to log_metrics -# but for tracing the execution of the code. -# async def log_records(): - - async def log_dataset( name: str, data_or_path: dict[str, Any] | str | list[str], diff --git a/alphatrion/server/graphql/resolvers.py b/alphatrion/server/graphql/resolvers.py index 2f8fe759..5078d93a 100644 --- a/alphatrion/server/graphql/resolvers.py +++ b/alphatrion/server/graphql/resolvers.py @@ -78,9 +78,7 @@ def get_team(info: Info[GraphQLContext, None], id: strawberry.ID) -> Team | None user_id = info.context.user_id metadb = runtime.storage_runtime().metadb - if not metadb.team_is_accessible_to_user( - team_id=id, user_id=user_id, org_id=uuid.UUID(info.context.org_id) - ): + if not metadb.team_is_accessible_to_user(team_id=id, user_id=user_id): raise RuntimeError( "Not allowed to access team that user does not belong to" ) @@ -206,9 +204,7 @@ def list_experiments( ) -> list[Experiment]: user_id = info.context.user_id metadb = runtime.storage_runtime().metadb - if not metadb.team_is_accessible_to_user( - team_id=team_id, user_id=user_id, org_id=uuid.UUID(info.context.org_id) - ): + if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id): raise RuntimeError( "Not allowed to access team that user does not belong to" ) @@ -352,9 +348,7 @@ def list_agents( user_id = info.context.user_id metadb = runtime.storage_runtime().metadb - if not metadb.team_is_accessible_to_user( - team_id=team_id, user_id=user_id, org_id=uuid.UUID(info.context.org_id) - ): + if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id): raise RuntimeError( "Not allowed to access team that user does not belong to" ) @@ -517,9 +511,7 @@ def list_runs_by_session_id( def total_agents(info: Info[GraphQLContext, None], team_id: strawberry.ID) -> int: user_id = info.context.user_id metadb = runtime.storage_runtime().metadb - if not metadb.team_is_accessible_to_user( - team_id=team_id, user_id=user_id, org_id=uuid.UUID(info.context.org_id) - ): + if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id): raise RuntimeError( "Not allowed to access team that user does not belong to" ) @@ -530,9 +522,7 @@ def total_agents(info: Info[GraphQLContext, None], team_id: strawberry.ID) -> in def total_sessions(info: Info[GraphQLContext, None], team_id: strawberry.ID) -> int: user_id = info.context.user_id metadb = runtime.storage_runtime().metadb - if not metadb.team_is_accessible_to_user( - team_id=team_id, user_id=user_id, org_id=uuid.UUID(info.context.org_id) - ): + if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id): raise RuntimeError( "Not allowed to access team that user does not belong to" ) @@ -596,9 +586,7 @@ def total_experiments( ) -> int: user_id = info.context.user_id metadb = runtime.storage_runtime().metadb - if not metadb.team_is_accessible_to_user( - team_id=team_id, user_id=user_id, org_id=uuid.UUID(info.context.org_id) - ): + if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id): raise RuntimeError( "Not allowed to access team that user does not belong to" ) @@ -609,9 +597,7 @@ def total_experiments( def total_runs(info: Info[GraphQLContext, None], team_id: strawberry.ID) -> int: user_id = info.context.user_id metadb = runtime.storage_runtime().metadb - if not metadb.team_is_accessible_to_user( - team_id=team_id, user_id=user_id, org_id=uuid.UUID(info.context.org_id) - ): + if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id): raise RuntimeError( "Not allowed to access team that user does not belong to" ) @@ -622,9 +608,7 @@ def total_runs(info: Info[GraphQLContext, None], team_id: strawberry.ID) -> int: def total_datasets(info: Info[GraphQLContext, None], team_id: strawberry.ID) -> int: user_id = info.context.user_id metadb = runtime.storage_runtime().metadb - if not metadb.team_is_accessible_to_user( - team_id=team_id, user_id=user_id, org_id=uuid.UUID(info.context.org_id) - ): + if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id): raise RuntimeError( "Not allowed to access team that user does not belong to" ) @@ -648,9 +632,7 @@ def aggregate_team_usage( org_id = uuid.UUID(info.context.org_id) user_id = uuid.UUID(info.context.user_id) metadb = runtime.storage_runtime().metadb - if not metadb.team_is_accessible_to_user( - team_id=team_id, user_id=user_id, org_id=org_id - ): + if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id): raise RuntimeError( "Not allowed to access team that user does not belong to" ) @@ -831,9 +813,7 @@ def aggregate_model_distributions( org_id = uuid.UUID(ctx.org_id) user_id = uuid.UUID(ctx.user_id) metadb = runtime.storage_runtime().metadb - if not metadb.team_is_accessible_to_user( - team_id=team_id, user_id=user_id, org_id=org_id - ): + if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id): raise RuntimeError( "Not allowed to access team that user does not belong to" ) @@ -856,9 +836,7 @@ def list_exps_by_timeframe( ) -> list[Experiment]: user_id = info.context.user_id metadb = runtime.storage_runtime().metadb - if not metadb.team_is_accessible_to_user( - team_id=team_id, user_id=user_id, org_id=uuid.UUID(info.context.org_id) - ): + if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id): raise RuntimeError( "Not allowed to access team that user does not belong to" ) @@ -919,9 +897,7 @@ async def list_artifact_tags( """List tags for a repository.""" user_id = info.context.user_id metadb = runtime.storage_runtime().metadb - if not metadb.team_is_accessible_to_user( - team_id=team_id, user_id=user_id, org_id=uuid.UUID(info.context.org_id) - ): + if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id): raise RuntimeError( "Not allowed to access team that user does not belong to" ) @@ -940,9 +916,7 @@ async def list_artifact_files( """List files in an artifact without loading content.""" user_id = info.context.user_id metadb = runtime.storage_runtime().metadb - if not metadb.team_is_accessible_to_user( - team_id=team_id, user_id=user_id, org_id=uuid.UUID(info.context.org_id) - ): + if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id): raise RuntimeError( "Not allowed to access team that user does not belong to" ) @@ -995,9 +969,7 @@ async def get_artifact_content( """Get artifact content from registry.""" user_id = info.context.user_id metadb = runtime.storage_runtime().metadb - if not metadb.team_is_accessible_to_user( - team_id=team_id, user_id=user_id, org_id=uuid.UUID(info.context.org_id) - ): + if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id): raise RuntimeError( "Not allowed to access team that user does not belong to" ) @@ -1091,7 +1063,6 @@ async def get_artifact_download_urls( if not metadb.team_is_accessible_to_user( team_id=dataset.team_id, user_id=uuid.UUID(user_id), - org_id=uuid.UUID(info.context.org_id), ): raise RuntimeError( "Not allowed to access dataset that user does not belong to" @@ -1515,9 +1486,7 @@ def get_daily_cost_usage( org_id = uuid.UUID(ctx.org_id) user_id = uuid.UUID(ctx.user_id) metadb = runtime.storage_runtime().metadb - if not metadb.team_is_accessible_to_user( - team_id=team_id, user_id=user_id, org_id=org_id - ): + if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id): return [] try: @@ -1593,9 +1562,7 @@ def list_datasets( ) -> list[Dataset]: user_id = info.context.user_id metadb = runtime.storage_runtime().metadb - if not metadb.team_is_accessible_to_user( - team_id=team_id, user_id=user_id, org_id=uuid.UUID(info.context.org_id) - ): + if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id): raise RuntimeError( "Not allowed to access team that user does not belong to" ) @@ -1838,6 +1805,13 @@ def add_user_to_team( ): raise RuntimeError("Only super admin can add users to teams") + if not metadb.user_and_team_in_same_org( + user_id=uuid.UUID(input.user_id), + team_id=uuid.UUID(input.team_id), + target_org_id=uuid.UUID(info.context.org_id), + ): + raise RuntimeError("User and team must belong to the same organization") + user_id = uuid.UUID(input.user_id) team_id = uuid.UUID(input.team_id) @@ -1856,7 +1830,6 @@ def remove_user_from_team( if not metadb.team_is_accessible_to_user( team_id=team_id, user_id=uuid.UUID(info.context.user_id), - org_id=uuid.UUID(info.context.org_id), ): raise RuntimeError( "Not allowed to modify team that user does not belong to" @@ -1878,9 +1851,7 @@ def create_experiment( metadb = runtime.storage_runtime().metadb # Verify user has access to the team - if not metadb.team_is_accessible_to_user( - team_id=team_id, user_id=user_id, org_id=org_id - ): + if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id): raise RuntimeError( "Not allowed to create experiments in team that user does not belong to" ) @@ -1936,7 +1907,6 @@ def update_experiment( """Update an existing experiment.""" user_id = uuid.UUID(info.context.user_id) - org_id = uuid.UUID(info.context.org_id) experiment_id = uuid.UUID(input.id) metadb = runtime.storage_runtime().metadb @@ -1946,9 +1916,7 @@ def update_experiment( if not exp: raise RuntimeError(f"Experiment with id '{input.id}' not found") - if not metadb.team_is_accessible_to_user( - team_id=exp.team_id, user_id=user_id, org_id=org_id - ): + if not metadb.team_is_accessible_to_user(team_id=exp.team_id, user_id=user_id): raise RuntimeError( "Not allowed to update experiment in team that user does not belong to" ) @@ -2066,7 +2034,6 @@ def abort_experiment( Only works if the experiment is in PENDING status.""" user_id = uuid.UUID(info.context.user_id) - org_id = uuid.UUID(info.context.org_id) experiment_id_uuid = uuid.UUID(experiment_id) metadb = runtime.storage_runtime().metadb @@ -2076,9 +2043,7 @@ def abort_experiment( if not exp: raise RuntimeError(f"Experiment with id '{experiment_id}' not found") - if not metadb.team_is_accessible_to_user( - team_id=exp.team_id, user_id=user_id, org_id=org_id - ): + if not metadb.team_is_accessible_to_user(team_id=exp.team_id, user_id=user_id): raise RuntimeError( "Not allowed to update experiment in team that user does not belong to" ) diff --git a/alphatrion/storage/sqlstore.py b/alphatrion/storage/sqlstore.py index eac71df3..890b919d 100644 --- a/alphatrion/storage/sqlstore.py +++ b/alphatrion/storage/sqlstore.py @@ -1324,12 +1324,17 @@ def user_is_super_admin_in_org(self, user_id: uuid.UUID, org_id: uuid.UUID) -> b return admin_membership is not None def team_is_accessible_to_user( - self, team_id: uuid.UUID, user_id: uuid.UUID, org_id: uuid.UUID + self, team_id: uuid.UUID, user_id: uuid.UUID ) -> bool: session = self._session() + team = self.get_team(team_id) + if team is None: + session.close() + return False + # Check if user is super admin - if self.user_is_super_admin_in_org(user_id, org_id): + if self.user_is_super_admin_in_org(user_id, team.org_id): session.close() return True @@ -1355,6 +1360,29 @@ def org_is_accessible_to_user(self, org_id: uuid.UUID, user_id: uuid.UUID) -> bo session.close() return user is not None + def user_and_team_in_same_org( + self, user_id: uuid.UUID, team_id: uuid.UUID, target_org_id: uuid.UUID + ) -> bool: + session = self._session() + team = ( + session.query(Team).filter(Team.uuid == team_id, Team.is_del == 0).first() + ) + if team is None: + session.close() + return False + + if team.org_id != target_org_id: + session.close() + return False + + user = ( + session.query(User) + .filter(User.uuid == user_id, User.org_id == team.org_id, User.is_del == 0) + .first() + ) + session.close() + return user is not None + def experiment_is_accessible_to_user( self, experiment_id: uuid.UUID, user_id: uuid.UUID ) -> bool: @@ -1471,7 +1499,7 @@ def dataset_is_accessible_to_user( session = self._session() dst = ( session.query(Dataset) - .filter(Dataset.uuid == dataset_id, Agent.is_del == 0) + .filter(Dataset.uuid == dataset_id, Dataset.is_del == 0) .first() ) if dst is None: diff --git a/tests/integration/server/test_graphql_mutation.py b/tests/integration/server/test_graphql_mutation.py index a3debc5f..86f11882 100644 --- a/tests/integration/server/test_graphql_mutation.py +++ b/tests/integration/server/test_graphql_mutation.py @@ -326,7 +326,9 @@ def test_add_user_to_team_with_invalid_team( user_id=test_user_id, ) assert response.errors is not None - assert "not found" in str(response.errors[0]) + # When team doesn't exist, user_and_team_in_same_org returns False + # which triggers "must belong to the same organization" error + assert "same organization" in str(response.errors[0]).lower() def test_add_user_to_team_with_invalid_user( @@ -355,7 +357,9 @@ def test_add_user_to_team_with_invalid_user( user_id=test_user_id, ) assert response.errors is not None - assert "not found" in str(response.errors[0]) + # When user doesn't exist, user_and_team_in_same_org returns False + # which triggers "must belong to the same organization" error + assert "same organization" in str(response.errors[0]).lower() def test_user_workflow(execute_graphql, test_org_id, test_user_id, test_team_id): diff --git a/tests/integration/test_s3_backend.py b/tests/integration/test_s3_backend.py index 7346f002..51d32e4d 100644 --- a/tests/integration/test_s3_backend.py +++ b/tests/integration/test_s3_backend.py @@ -46,7 +46,9 @@ def s3_env_vars(): else: os.environ[key] = value - storage_runtime_module.__STORAGE_RUNTIME__ = None # Reset again to clear any cached runtime + storage_runtime_module.__STORAGE_RUNTIME__ = ( + None # Reset again to clear any cached runtime + ) @pytest.fixture diff --git a/tests/unit/storage/test_sql.py b/tests/unit/storage/test_sql.py index 62e9af98..805029fc 100644 --- a/tests/unit/storage/test_sql.py +++ b/tests/unit/storage/test_sql.py @@ -164,3 +164,79 @@ def test_create_metrics_batch(db): assert metric_dict["loss"] == 0.1 assert metric_dict["precision"] == 0.92 assert metric_dict["recall"] == 0.88 + + +def test_user_and_team_in_same_org_success(db): + """Test that user and team in same org returns True""" + org_id = uuid.uuid4() + + # Create team + team_id = db.create_team(org_id=org_id, name="Test Team") + + # Create user in same org + user_id = db.create_user( + org_id=org_id, name="test_user", email="user@example.com", team_id=team_id + ) + + # Verify user and team are in same org + assert db.user_and_team_in_same_org(user_id, team_id, org_id) is True + + +def test_user_and_team_in_same_org_different_orgs(db): + """Test that user and team in different orgs returns False""" + org1_id = uuid.uuid4() + org2_id = uuid.uuid4() + + # Create team in org1 + team_id = db.create_team(org_id=org1_id, name="Team in Org1") + + # Create user in org2 (different org) + user_id = db.create_user(org_id=org2_id, name="test_user", email="user@example.com") + + # Verify user and team are NOT in same org + assert db.user_and_team_in_same_org(user_id, team_id, org1_id) is False + + +def test_user_and_team_in_same_org_nonexistent_team(db): + """Test that nonexistent team returns False""" + org_id = uuid.uuid4() + + # Create user + user_id = db.create_user(org_id=org_id, name="test_user", email="user@example.com") + + # Use nonexistent team ID + nonexistent_team_id = uuid.uuid4() + + # Verify returns False for nonexistent team + assert db.user_and_team_in_same_org(user_id, nonexistent_team_id, org_id) is False + + +def test_user_and_team_in_same_org_nonexistent_user(db): + """Test that nonexistent user returns False""" + org_id = uuid.uuid4() + + # Create team + team_id = db.create_team(org_id=org_id, name="Test Team") + + # Use nonexistent user ID + nonexistent_user_id = uuid.uuid4() + + # Verify returns False for nonexistent user + assert db.user_and_team_in_same_org(nonexistent_user_id, team_id, org_id) is False + + +def test_user_and_team_in_same_org_wrong_target_org(db): + """Test that wrong target org returns False""" + org_id = uuid.uuid4() + wrong_org_id = uuid.uuid4() + + # Create team in org_id + team_id = db.create_team(org_id=org_id, name="Test Team") + + # Create user in same org + user_id = db.create_user( + org_id=org_id, name="test_user", email="user@example.com", team_id=team_id + ) + + # Verify returns False when checking against wrong target org + assert db.user_and_team_in_same_org(user_id, team_id, wrong_org_id) is False