Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions alphatrion/log/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ async def load_checkpoint(
) -> list[str]:
"""
Load checkpoint from artifact registry, the path is expected to be in the format of:
- OCI: "org_id/team_id/exp_id/ckpt:version"
- S3: "org_id/team_id/exp_id/ckpt/filename"
- OCI: "org_id/team_id/exp_id/ckpt:version_or_filename", it should be a version.
- S3: "org_id/team_id/exp_id/ckpt/version_or_filename", it should be a filename.

:param id: the id of the experiment.
:param version_or_filename: the version or filename of the checkpoint to load, default is "latest".
Expand Down
5 changes: 0 additions & 5 deletions alphatrion/log/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,6 @@ async def log_metrics(metrics: dict[str, float]):
exp.done()


# log_records is used to log a list of records, which is similar to log_metrics
# but for tracing the execution of the code.
# async def log_records():


async def log_dataset(
name: str,
data_or_path: dict[str, Any] | str | list[str],
Expand Down
87 changes: 26 additions & 61 deletions alphatrion/server/graphql/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,7 @@ def get_team(info: Info[GraphQLContext, None], id: strawberry.ID) -> Team | None
user_id = info.context.user_id

metadb = runtime.storage_runtime().metadb
if not metadb.team_is_accessible_to_user(
team_id=id, user_id=user_id, org_id=uuid.UUID(info.context.org_id)
):
if not metadb.team_is_accessible_to_user(team_id=id, user_id=user_id):
raise RuntimeError(
"Not allowed to access team that user does not belong to"
)
Expand Down Expand Up @@ -206,9 +204,7 @@ def list_experiments(
) -> list[Experiment]:
user_id = info.context.user_id
metadb = runtime.storage_runtime().metadb
if not metadb.team_is_accessible_to_user(
team_id=team_id, user_id=user_id, org_id=uuid.UUID(info.context.org_id)
):
if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id):
raise RuntimeError(
"Not allowed to access team that user does not belong to"
)
Expand Down Expand Up @@ -352,9 +348,7 @@ def list_agents(
user_id = info.context.user_id

metadb = runtime.storage_runtime().metadb
if not metadb.team_is_accessible_to_user(
team_id=team_id, user_id=user_id, org_id=uuid.UUID(info.context.org_id)
):
if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id):
raise RuntimeError(
"Not allowed to access team that user does not belong to"
)
Expand Down Expand Up @@ -517,9 +511,7 @@ def list_runs_by_session_id(
def total_agents(info: Info[GraphQLContext, None], team_id: strawberry.ID) -> int:
user_id = info.context.user_id
metadb = runtime.storage_runtime().metadb
if not metadb.team_is_accessible_to_user(
team_id=team_id, user_id=user_id, org_id=uuid.UUID(info.context.org_id)
):
if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id):
raise RuntimeError(
"Not allowed to access team that user does not belong to"
)
Expand All @@ -530,9 +522,7 @@ def total_agents(info: Info[GraphQLContext, None], team_id: strawberry.ID) -> in
def total_sessions(info: Info[GraphQLContext, None], team_id: strawberry.ID) -> int:
user_id = info.context.user_id
metadb = runtime.storage_runtime().metadb
if not metadb.team_is_accessible_to_user(
team_id=team_id, user_id=user_id, org_id=uuid.UUID(info.context.org_id)
):
if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id):
raise RuntimeError(
"Not allowed to access team that user does not belong to"
)
Expand Down Expand Up @@ -596,9 +586,7 @@ def total_experiments(
) -> int:
user_id = info.context.user_id
metadb = runtime.storage_runtime().metadb
if not metadb.team_is_accessible_to_user(
team_id=team_id, user_id=user_id, org_id=uuid.UUID(info.context.org_id)
):
if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id):
raise RuntimeError(
"Not allowed to access team that user does not belong to"
)
Expand All @@ -609,9 +597,7 @@ def total_experiments(
def total_runs(info: Info[GraphQLContext, None], team_id: strawberry.ID) -> int:
user_id = info.context.user_id
metadb = runtime.storage_runtime().metadb
if not metadb.team_is_accessible_to_user(
team_id=team_id, user_id=user_id, org_id=uuid.UUID(info.context.org_id)
):
if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id):
raise RuntimeError(
"Not allowed to access team that user does not belong to"
)
Expand All @@ -622,9 +608,7 @@ def total_runs(info: Info[GraphQLContext, None], team_id: strawberry.ID) -> int:
def total_datasets(info: Info[GraphQLContext, None], team_id: strawberry.ID) -> int:
user_id = info.context.user_id
metadb = runtime.storage_runtime().metadb
if not metadb.team_is_accessible_to_user(
team_id=team_id, user_id=user_id, org_id=uuid.UUID(info.context.org_id)
):
if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id):
raise RuntimeError(
"Not allowed to access team that user does not belong to"
)
Expand All @@ -648,9 +632,7 @@ def aggregate_team_usage(
org_id = uuid.UUID(info.context.org_id)
user_id = uuid.UUID(info.context.user_id)
metadb = runtime.storage_runtime().metadb
if not metadb.team_is_accessible_to_user(
team_id=team_id, user_id=user_id, org_id=org_id
):
if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id):
raise RuntimeError(
"Not allowed to access team that user does not belong to"
)
Expand Down Expand Up @@ -831,9 +813,7 @@ def aggregate_model_distributions(
org_id = uuid.UUID(ctx.org_id)
user_id = uuid.UUID(ctx.user_id)
metadb = runtime.storage_runtime().metadb
if not metadb.team_is_accessible_to_user(
team_id=team_id, user_id=user_id, org_id=org_id
):
if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id):
raise RuntimeError(
"Not allowed to access team that user does not belong to"
)
Expand All @@ -856,9 +836,7 @@ def list_exps_by_timeframe(
) -> list[Experiment]:
user_id = info.context.user_id
metadb = runtime.storage_runtime().metadb
if not metadb.team_is_accessible_to_user(
team_id=team_id, user_id=user_id, org_id=uuid.UUID(info.context.org_id)
):
if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id):
raise RuntimeError(
"Not allowed to access team that user does not belong to"
)
Expand Down Expand Up @@ -919,9 +897,7 @@ async def list_artifact_tags(
"""List tags for a repository."""
user_id = info.context.user_id
metadb = runtime.storage_runtime().metadb
if not metadb.team_is_accessible_to_user(
team_id=team_id, user_id=user_id, org_id=uuid.UUID(info.context.org_id)
):
if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id):
raise RuntimeError(
"Not allowed to access team that user does not belong to"
)
Expand All @@ -940,9 +916,7 @@ async def list_artifact_files(
"""List files in an artifact without loading content."""
user_id = info.context.user_id
metadb = runtime.storage_runtime().metadb
if not metadb.team_is_accessible_to_user(
team_id=team_id, user_id=user_id, org_id=uuid.UUID(info.context.org_id)
):
if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id):
raise RuntimeError(
"Not allowed to access team that user does not belong to"
)
Expand Down Expand Up @@ -995,9 +969,7 @@ async def get_artifact_content(
"""Get artifact content from registry."""
user_id = info.context.user_id
metadb = runtime.storage_runtime().metadb
if not metadb.team_is_accessible_to_user(
team_id=team_id, user_id=user_id, org_id=uuid.UUID(info.context.org_id)
):
if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id):
raise RuntimeError(
"Not allowed to access team that user does not belong to"
)
Expand Down Expand Up @@ -1091,7 +1063,6 @@ async def get_artifact_download_urls(
if not metadb.team_is_accessible_to_user(
team_id=dataset.team_id,
user_id=uuid.UUID(user_id),
org_id=uuid.UUID(info.context.org_id),
):
raise RuntimeError(
"Not allowed to access dataset that user does not belong to"
Expand Down Expand Up @@ -1515,9 +1486,7 @@ def get_daily_cost_usage(
org_id = uuid.UUID(ctx.org_id)
user_id = uuid.UUID(ctx.user_id)
metadb = runtime.storage_runtime().metadb
if not metadb.team_is_accessible_to_user(
team_id=team_id, user_id=user_id, org_id=org_id
):
if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id):
return []

try:
Expand Down Expand Up @@ -1593,9 +1562,7 @@ def list_datasets(
) -> list[Dataset]:
user_id = info.context.user_id
metadb = runtime.storage_runtime().metadb
if not metadb.team_is_accessible_to_user(
team_id=team_id, user_id=user_id, org_id=uuid.UUID(info.context.org_id)
):
if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id):
raise RuntimeError(
"Not allowed to access team that user does not belong to"
)
Expand Down Expand Up @@ -1838,6 +1805,13 @@ def add_user_to_team(
):
raise RuntimeError("Only super admin can add users to teams")

if not metadb.user_and_team_in_same_org(
user_id=uuid.UUID(input.user_id),
team_id=uuid.UUID(input.team_id),
target_org_id=uuid.UUID(info.context.org_id),
):
raise RuntimeError("User and team must belong to the same organization")

user_id = uuid.UUID(input.user_id)
team_id = uuid.UUID(input.team_id)

Expand All @@ -1856,7 +1830,6 @@ def remove_user_from_team(
if not metadb.team_is_accessible_to_user(
team_id=team_id,
user_id=uuid.UUID(info.context.user_id),
org_id=uuid.UUID(info.context.org_id),
):
raise RuntimeError(
"Not allowed to modify team that user does not belong to"
Expand All @@ -1878,9 +1851,7 @@ def create_experiment(
metadb = runtime.storage_runtime().metadb

# Verify user has access to the team
if not metadb.team_is_accessible_to_user(
team_id=team_id, user_id=user_id, org_id=org_id
):
if not metadb.team_is_accessible_to_user(team_id=team_id, user_id=user_id):
raise RuntimeError(
"Not allowed to create experiments in team that user does not belong to"
)
Expand Down Expand Up @@ -1936,7 +1907,6 @@ def update_experiment(
"""Update an existing experiment."""

user_id = uuid.UUID(info.context.user_id)
org_id = uuid.UUID(info.context.org_id)
experiment_id = uuid.UUID(input.id)

metadb = runtime.storage_runtime().metadb
Expand All @@ -1946,9 +1916,7 @@ def update_experiment(
if not exp:
raise RuntimeError(f"Experiment with id '{input.id}' not found")

if not metadb.team_is_accessible_to_user(
team_id=exp.team_id, user_id=user_id, org_id=org_id
):
if not metadb.team_is_accessible_to_user(team_id=exp.team_id, user_id=user_id):
raise RuntimeError(
"Not allowed to update experiment in team that user does not belong to"
)
Expand Down Expand Up @@ -2066,7 +2034,6 @@ def abort_experiment(
Only works if the experiment is in PENDING status."""

user_id = uuid.UUID(info.context.user_id)
org_id = uuid.UUID(info.context.org_id)
experiment_id_uuid = uuid.UUID(experiment_id)

metadb = runtime.storage_runtime().metadb
Expand All @@ -2076,9 +2043,7 @@ def abort_experiment(
if not exp:
raise RuntimeError(f"Experiment with id '{experiment_id}' not found")

if not metadb.team_is_accessible_to_user(
team_id=exp.team_id, user_id=user_id, org_id=org_id
):
if not metadb.team_is_accessible_to_user(team_id=exp.team_id, user_id=user_id):
raise RuntimeError(
"Not allowed to update experiment in team that user does not belong to"
)
Expand Down
34 changes: 31 additions & 3 deletions alphatrion/storage/sqlstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -1324,12 +1324,17 @@ def user_is_super_admin_in_org(self, user_id: uuid.UUID, org_id: uuid.UUID) -> b
return admin_membership is not None

def team_is_accessible_to_user(
self, team_id: uuid.UUID, user_id: uuid.UUID, org_id: uuid.UUID
self, team_id: uuid.UUID, user_id: uuid.UUID
) -> bool:
session = self._session()

team = self.get_team(team_id)
if team is None:
session.close()
return False

# Check if user is super admin
if self.user_is_super_admin_in_org(user_id, org_id):
if self.user_is_super_admin_in_org(user_id, team.org_id):
session.close()
return True

Expand All @@ -1355,6 +1360,29 @@ def org_is_accessible_to_user(self, org_id: uuid.UUID, user_id: uuid.UUID) -> bo
session.close()
return user is not None

def user_and_team_in_same_org(
self, user_id: uuid.UUID, team_id: uuid.UUID, target_org_id: uuid.UUID
) -> bool:
session = self._session()
team = (
session.query(Team).filter(Team.uuid == team_id, Team.is_del == 0).first()
)
if team is None:
session.close()
return False

if team.org_id != target_org_id:
session.close()
return False

user = (
session.query(User)
.filter(User.uuid == user_id, User.org_id == team.org_id, User.is_del == 0)
.first()
)
session.close()
return user is not None

def experiment_is_accessible_to_user(
self, experiment_id: uuid.UUID, user_id: uuid.UUID
) -> bool:
Expand Down Expand Up @@ -1471,7 +1499,7 @@ def dataset_is_accessible_to_user(
session = self._session()
dst = (
session.query(Dataset)
.filter(Dataset.uuid == dataset_id, Agent.is_del == 0)
.filter(Dataset.uuid == dataset_id, Dataset.is_del == 0)
.first()
)
if dst is None:
Expand Down
8 changes: 6 additions & 2 deletions tests/integration/server/test_graphql_mutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,9 @@ def test_add_user_to_team_with_invalid_team(
user_id=test_user_id,
)
assert response.errors is not None
assert "not found" in str(response.errors[0])
# When team doesn't exist, user_and_team_in_same_org returns False
# which triggers "must belong to the same organization" error
assert "same organization" in str(response.errors[0]).lower()


def test_add_user_to_team_with_invalid_user(
Expand Down Expand Up @@ -355,7 +357,9 @@ def test_add_user_to_team_with_invalid_user(
user_id=test_user_id,
)
assert response.errors is not None
assert "not found" in str(response.errors[0])
# When user doesn't exist, user_and_team_in_same_org returns False
# which triggers "must belong to the same organization" error
assert "same organization" in str(response.errors[0]).lower()


def test_user_workflow(execute_graphql, test_org_id, test_user_id, test_team_id):
Expand Down
4 changes: 3 additions & 1 deletion tests/integration/test_s3_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def s3_env_vars():
else:
os.environ[key] = value

storage_runtime_module.__STORAGE_RUNTIME__ = None # Reset again to clear any cached runtime
storage_runtime_module.__STORAGE_RUNTIME__ = (
None # Reset again to clear any cached runtime
)


@pytest.fixture
Expand Down
Loading
Loading