diff --git a/mcpgateway/admin.py b/mcpgateway/admin.py index 5ae92503a..51a5923c4 100644 --- a/mcpgateway/admin.py +++ b/mcpgateway/admin.py @@ -98,8 +98,8 @@ from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.oauth_manager import OAuthManager from mcpgateway.services.plugin_service import get_plugin_service -from mcpgateway.services.prompt_service import PromptNotFoundError, PromptService -from mcpgateway.services.resource_service import ResourceNotFoundError, ResourceService +from mcpgateway.services.prompt_service import PromptNameConflictError, PromptNotFoundError, PromptService +from mcpgateway.services.resource_service import ResourceNotFoundError, ResourceService, ResourceURIConflictError from mcpgateway.services.root_service import RootService from mcpgateway.services.server_service import ServerError, ServerNameConflictError, ServerNotFoundError, ServerService from mcpgateway.services.tag_service import TagService @@ -5340,6 +5340,7 @@ async def admin_edit_tool( user_email = get_user_email(user) # Determine personal team for default assignment team_id = form.get("team_id", None) + LOGGER.info(f"before Verifying team for user {user_email} with team_id {team_id}") team_service = TeamManagementService(db) team_id = await team_service.verify_team_for_user(user_email, team_id) @@ -6411,12 +6412,12 @@ async def admin_delete_gateway(gateway_id: str, request: Request, db: Session = return RedirectResponse(f"{root_path}/admin#gateways", status_code=303) -@admin_router.get("/resources/{uri:path}") -async def admin_get_resource(uri: str, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> Dict[str, Any]: +@admin_router.get("/resources/{resource_id}") +async def admin_get_resource(resource_id: int, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> Dict[str, Any]: """Get resource details for the admin UI. Args: - uri: Resource URI. + resource_id: Resource ID. db: Database session. user: Authenticated user. @@ -6438,10 +6439,11 @@ async def admin_get_resource(uri: str, db: Session = Depends(get_db), user=Depen >>> mock_db = MagicMock() >>> mock_user = {"email": "test_user", "db": mock_db} >>> resource_uri = "test://resource/get" + >>> resource_id = 1 >>> >>> # Mock resource data >>> mock_resource = ResourceRead( - ... id=1, uri=resource_uri, name="Get Resource", description="Test", + ... id=resource_id, uri=resource_uri, name="Get Resource", description="Test", ... mime_type="text/plain", size=10, created_at=datetime.now(timezone.utc), ... updated_at=datetime.now(timezone.utc), is_active=True, metrics=ResourceMetrics( ... total_executions=0, successful_executions=0, failed_executions=0, @@ -6450,27 +6452,27 @@ async def admin_get_resource(uri: str, db: Session = Depends(get_db), user=Depen ... ), ... tags=[] ... ) - >>> mock_content = ResourceContent(type="resource", uri=resource_uri, mime_type="text/plain", text="Hello content") + >>> mock_content = ResourceContent(id=str(resource_id), type="resource", uri=resource_uri, mime_type="text/plain", text="Hello content") >>> >>> # Mock service methods - >>> original_get_resource_by_uri = resource_service.get_resource_by_uri + >>> original_get_resource_by_id = resource_service.get_resource_by_id >>> original_read_resource = resource_service.read_resource - >>> resource_service.get_resource_by_uri = AsyncMock(return_value=mock_resource) + >>> resource_service.get_resource_by_id = AsyncMock(return_value=mock_resource) >>> resource_service.read_resource = AsyncMock(return_value=mock_content) >>> >>> # Test successful retrieval >>> async def test_admin_get_resource_success(): - ... result = await admin_get_resource(resource_uri, mock_db, mock_user) - ... return isinstance(result, dict) and result['resource']['uri'] == resource_uri and result['content'].text == "Hello content" # Corrected to .text + ... result = await admin_get_resource(resource_id, mock_db, mock_user) + ... return isinstance(result, dict) and result['resource']['id'] == resource_id and result['content'].text == "Hello content" # Corrected to .text >>> >>> asyncio.run(test_admin_get_resource_success()) True >>> >>> # Test resource not found - >>> resource_service.get_resource_by_uri = AsyncMock(side_effect=ResourceNotFoundError("Resource not found")) + >>> resource_service.get_resource_by_id = AsyncMock(side_effect=ResourceNotFoundError("Resource not found")) >>> async def test_admin_get_resource_not_found(): ... try: - ... await admin_get_resource("nonexistent://uri", mock_db, mock_user) + ... await admin_get_resource(999, mock_db, mock_user) ... return False ... except HTTPException as e: ... return e.status_code == 404 and "Resource not found" in e.detail @@ -6479,11 +6481,11 @@ async def admin_get_resource(uri: str, db: Session = Depends(get_db), user=Depen True >>> >>> # Test exception during content read (resource found but content fails) - >>> resource_service.get_resource_by_uri = AsyncMock(return_value=mock_resource) # Resource found + >>> resource_service.get_resource_by_id = AsyncMock(return_value=mock_resource) # Resource found >>> resource_service.read_resource = AsyncMock(side_effect=Exception("Content read error")) >>> async def test_admin_get_resource_content_error(): ... try: - ... await admin_get_resource(resource_uri, mock_db, mock_user) + ... await admin_get_resource(resource_id, mock_db, mock_user) ... return False ... except Exception as e: ... return str(e) == "Content read error" @@ -6492,18 +6494,18 @@ async def admin_get_resource(uri: str, db: Session = Depends(get_db), user=Depen True >>> >>> # Restore original methods - >>> resource_service.get_resource_by_uri = original_get_resource_by_uri + >>> resource_service.get_resource_by_id = original_get_resource_by_id >>> resource_service.read_resource = original_read_resource """ - LOGGER.debug(f"User {get_user_email(user)} requested details for resource URI {uri}") + LOGGER.debug(f"User {get_user_email(user)} requested details for resource ID {resource_id}") try: - resource = await resource_service.get_resource_by_uri(db, uri) - content = await resource_service.read_resource(db, uri) + resource = await resource_service.get_resource_by_id(db, resource_id) + content = await resource_service.read_resource(db, resource_id) return {"resource": resource.model_dump(by_alias=True), "content": content} except ResourceNotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) except Exception as e: - LOGGER.error(f"Error getting resource {uri}: {e}") + LOGGER.error(f"Error getting resource {resource_id}: {e}") raise e @@ -6596,6 +6598,9 @@ async def admin_add_resource(request: Request, db: Session = Depends(get_db), us created_user_agent=metadata["created_user_agent"], import_batch_id=metadata["import_batch_id"], federation_source=metadata["federation_source"], + team_id=team_id, + owner_email=user_email, + visibility=visibility, ) return JSONResponse( content={"message": "Add resource registered successfully!", "success": True}, @@ -6609,14 +6614,16 @@ async def admin_add_resource(request: Request, db: Session = Depends(get_db), us error_message = ErrorFormatter.format_database_error(ex) LOGGER.error(f"IntegrityError in admin_add_resource: {error_message}") return JSONResponse(status_code=409, content=error_message) - + if isinstance(ex, ResourceURIConflictError): + LOGGER.error(f"ResourceURIConflictError in admin_add_resource: {ex}") + return JSONResponse(content={"message": str(ex), "success": False}, status_code=409) LOGGER.error(f"Error in admin_add_resource: {ex}") return JSONResponse(content={"message": str(ex), "success": False}, status_code=500) -@admin_router.post("/resources/{uri:path}/edit") +@admin_router.post("/resources/{resource_id}/edit") async def admin_edit_resource( - uri: str, + resource_id: str, request: Request, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions), @@ -6631,7 +6638,7 @@ async def admin_edit_resource( - content Args: - uri: Resource URI. + resource_id: Resource ID. request: FastAPI request containing form data. db: Database session. user: Authenticated user. @@ -6705,9 +6712,9 @@ async def admin_edit_resource( >>> # Reset mock >>> resource_service.update_resource = original_update_resource """ - LOGGER.debug(f"User {get_user_email(user)} is editing resource URI {uri}") + LOGGER.debug(f"User {get_user_email(user)} is editing resource ID {resource_id}") form = await request.form() - + LOGGER.info(f"Form data received for resource edit: {form}") visibility = str(form.get("visibility", "private")) # Parse tags from comma-separated string tags_str = str(form.get("tags", "")) @@ -6716,17 +6723,19 @@ async def admin_edit_resource( try: mod_metadata = MetadataCapture.extract_modification_metadata(request, user, 0) resource = ResourceUpdate( - name=str(form["name"]), + uri=str(form.get("uri", "")), + name=str(form.get("name", "")), description=str(form.get("description")), mime_type=str(form.get("mimeType")), - content=str(form["content"]), + content=str(form.get("content", "")), template=str(form.get("template")), tags=tags, visibility=visibility, ) + LOGGER.info(f"ResourceUpdate object created: {resource}") await resource_service.update_resource( db, - uri, + resource_id, resource, modified_by=mod_metadata["modified_by"], modified_from_ip=mod_metadata["modified_from_ip"], @@ -6749,21 +6758,24 @@ async def admin_edit_resource( error_message = ErrorFormatter.format_database_error(ex) LOGGER.error(f"IntegrityError in admin_edit_resource: {error_message}") return JSONResponse(status_code=409, content=error_message) + if isinstance(ex, ResourceURIConflictError): + LOGGER.error(f"ResourceURIConflictError in admin_edit_resource: {ex}") + return JSONResponse(status_code=409, content={"message": str(ex), "success": False}) LOGGER.error(f"Error in admin_edit_resource: {ex}") return JSONResponse(content={"message": str(ex), "success": False}, status_code=500) -@admin_router.post("/resources/{uri:path}/delete") -async def admin_delete_resource(uri: str, request: Request, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> RedirectResponse: +@admin_router.post("/resources/{resource_id}/delete") +async def admin_delete_resource(resource_id: str, request: Request, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> RedirectResponse: """ Delete a resource via the admin UI. - This endpoint permanently removes a resource from the database using its URI. + This endpoint permanently removes a resource from the database using its resource ID. The operation is irreversible and should be used with caution. It requires user authentication and logs the deletion attempt. Args: - uri (str): The URI of the resource to delete. + resource_id (str): The ID of the resource to delete. request (Request): FastAPI request object (not used directly but required by the route signature). db (Session): Database session dependency. user (str): Authenticated user dependency. @@ -6808,18 +6820,18 @@ async def admin_delete_resource(uri: str, request: Request, db: Session = Depend True >>> resource_service.delete_resource = original_delete_resource """ + user_email = get_user_email(user) - LOGGER.debug(f"User {user_email} is deleting resource URI {uri}") + LOGGER.debug(f"User {get_user_email(user)} is deleting resource ID {resource_id}") error_message = None try: - await resource_service.delete_resource(db, uri, user_email=user_email) + await resource_service.delete_resource(user["db"] if isinstance(user, dict) else db, resource_id) except PermissionError as e: - LOGGER.warning(f"Permission denied for user {user_email} deleting resource {uri}: {e}") + LOGGER.warning(f"Permission denied for user {user_email} deleting resource {resource_id}: {e}") error_message = str(e) except Exception as e: LOGGER.error(f"Error deleting resource: {e}") error_message = "Failed to delete resource. Please try again." - form = await request.form() is_inactive_checked: str = str(form.get("is_inactive_checked", "false")) root_path = request.scope.get("root_path", "") @@ -6961,12 +6973,12 @@ async def admin_toggle_resource( return RedirectResponse(f"{root_path}/admin#resources", status_code=303) -@admin_router.get("/prompts/{name}") -async def admin_get_prompt(name: str, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> Dict[str, Any]: +@admin_router.get("/prompts/{prompt_id}") +async def admin_get_prompt(prompt_id: int, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> Dict[str, Any]: """Get prompt details for the admin UI. Args: - name: Prompt name. + prompt_id: Prompt ID. db: Database session. user: Authenticated user. @@ -7049,16 +7061,16 @@ async def admin_get_prompt(name: str, db: Session = Depends(get_db), user=Depend >>> >>> prompt_service.get_prompt_details = original_get_prompt_details """ - LOGGER.debug(f"User {get_user_email(user)} requested details for prompt name {name}") + LOGGER.info(f"User {get_user_email(user)} requested details for prompt ID {prompt_id}") try: - prompt_details = await prompt_service.get_prompt_details(db, name) + prompt_details = await prompt_service.get_prompt_details(db, prompt_id) prompt = PromptRead.model_validate(prompt_details) return prompt.model_dump(by_alias=True) except PromptNotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) except Exception as e: - LOGGER.error(f"Error getting prompt {name}: {e}") - raise e + LOGGER.error(f"Error getting prompt {prompt_id}: {e}") + raise @admin_router.post("/prompts") @@ -7151,6 +7163,9 @@ async def admin_add_prompt(request: Request, db: Session = Depends(get_db), user created_user_agent=metadata["created_user_agent"], import_batch_id=metadata["import_batch_id"], federation_source=metadata["federation_source"], + team_id=team_id, + owner_email=user_email, + visibility=visibility, ) return JSONResponse( content={"message": "Prompt registered successfully!", "success": True}, @@ -7164,13 +7179,16 @@ async def admin_add_prompt(request: Request, db: Session = Depends(get_db), user error_message = ErrorFormatter.format_database_error(ex) LOGGER.error(f"IntegrityError in admin_add_prompt: {error_message}") return JSONResponse(status_code=409, content=error_message) + if isinstance(ex, PromptNameConflictError): + LOGGER.error(f"PromptNameConflictError in admin_add_prompt: {ex}") + return JSONResponse(status_code=409, content={"message": str(ex), "success": False}) LOGGER.error(f"Error in admin_add_prompt: {ex}") return JSONResponse(content={"message": str(ex), "success": False}, status_code=500) -@admin_router.post("/prompts/{name}/edit") +@admin_router.post("/prompts/{prompt_id}/edit") async def admin_edit_prompt( - name: str, + prompt_id: str, request: Request, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions), @@ -7178,21 +7196,21 @@ async def admin_edit_prompt( """Edit a prompt via the admin UI. Expects form fields: - - name - - description (optional) - - template - - arguments (as a JSON string representing a list) + - name + - description (optional) + - template + - arguments (as a JSON string representing a list) Args: - name: Prompt name. + prompt_id: Prompt ID. request: FastAPI request containing form data. db: Database session. user: Authenticated user. Returns: - JSONResponse: A JSON response indicating success or failure of the server update operation. + JSONResponse: A JSON response indicating success or failure of the server update operation. - Examples: + Examples: >>> import asyncio >>> from unittest.mock import AsyncMock, MagicMock >>> from fastapi import Request @@ -7240,15 +7258,18 @@ async def admin_edit_prompt( True >>> prompt_service.update_prompt = original_update_prompt """ - LOGGER.debug(f"User {get_user_email(user)} is editing prompt name {name}") + LOGGER.debug(f"User {get_user_email(user)} is editing prompt {prompt_id}") form = await request.form() + LOGGER.info(f"form data: {form}") visibility = str(form.get("visibility", "private")) user_email = get_user_email(user) # Determine personal team for default assignment team_id = form.get("team_id", None) + LOGGER.info(f"befor Verifying team for user {user_email} with team_id {team_id}") team_service = TeamManagementService(db) team_id = await team_service.verify_team_for_user(user_email, team_id) + LOGGER.info(f"Verifying team for user {user_email} with team_id {team_id}") args_json: str = str(form.get("arguments")) or "[]" arguments = json.loads(args_json) @@ -7269,7 +7290,7 @@ async def admin_edit_prompt( ) await prompt_service.update_prompt( db, - name, + prompt_id, prompt, modified_by=mod_metadata["modified_by"], modified_from_ip=mod_metadata["modified_from_ip"], @@ -7292,21 +7313,24 @@ async def admin_edit_prompt( error_message = ErrorFormatter.format_database_error(ex) LOGGER.error(f"IntegrityError in admin_edit_prompt: {error_message}") return JSONResponse(status_code=409, content=error_message) + if isinstance(ex, PromptNameConflictError): + LOGGER.error(f"PromptNameConflictError in admin_edit_prompt: {ex}") + return JSONResponse(status_code=409, content={"message": str(ex), "success": False}) LOGGER.error(f"Error in admin_edit_prompt: {ex}") return JSONResponse(content={"message": str(ex), "success": False}, status_code=500) -@admin_router.post("/prompts/{name}/delete") -async def admin_delete_prompt(name: str, request: Request, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> RedirectResponse: +@admin_router.post("/prompts/{prompt_id}/delete") +async def admin_delete_prompt(prompt_id: str, request: Request, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> RedirectResponse: """ Delete a prompt via the admin UI. - This endpoint permanently deletes a prompt from the database using its name. + This endpoint permanently deletes a prompt from the database using its ID. Deletion is irreversible and requires authentication. All actions are logged for administrative auditing. Args: - name (str): The name of the prompt to delete. + prompt_id (str): The ID of the prompt to delete. request (Request): FastAPI request object (not used directly but required by the route signature). db (Session): Database session dependency. user (str): Authenticated user dependency. @@ -7352,17 +7376,16 @@ async def admin_delete_prompt(name: str, request: Request, db: Session = Depends >>> prompt_service.delete_prompt = original_delete_prompt """ user_email = get_user_email(user) - LOGGER.debug(f"User {user_email} is deleting prompt name {name}") + LOGGER.info(f"User {get_user_email(user)} is deleting prompt id {prompt_id}") error_message = None try: - await prompt_service.delete_prompt(db, name, user_email=user_email) + await prompt_service.delete_prompt(db, prompt_id, user_email=user_email) except PermissionError as e: - LOGGER.warning(f"Permission denied for user {user_email} deleting prompt {name}: {e}") + LOGGER.warning(f"Permission denied for user {user_email} deleting prompt {prompt_id}: {e}") error_message = str(e) except Exception as e: LOGGER.error(f"Error deleting prompt: {e}") error_message = "Failed to delete prompt. Please try again." - form = await request.form() is_inactive_checked: str = str(form.get("is_inactive_checked", "false")) root_path = request.scope.get("root_path", "") @@ -9212,6 +9235,9 @@ async def admin_add_a2a_agent( created_user_agent=metadata["created_user_agent"], import_batch_id=metadata["import_batch_id"], federation_source=metadata["federation_source"], + team_id=team_id, + owner_email=user_email, + visibility=form.get("visibility", "private"), ) return JSONResponse( diff --git a/mcpgateway/alembic/versions/e5a59c16e041_unique_const_changes_for_prompt_and_.py b/mcpgateway/alembic/versions/e5a59c16e041_unique_const_changes_for_prompt_and_.py new file mode 100644 index 000000000..00aa8c4a8 --- /dev/null +++ b/mcpgateway/alembic/versions/e5a59c16e041_unique_const_changes_for_prompt_and_.py @@ -0,0 +1,183 @@ +# -*- coding: utf-8 -*- +"""unique const changes for prompt and resource + +Revision ID: e5a59c16e041 +Revises: 8a2934be50c0 +Create Date: 2025-10-15 11:20:53.888488 + +""" + +# Standard +from typing import Sequence, Union + +# Third-Party +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = "e5a59c16e041" +down_revision: Union[str, Sequence[str], None] = "8a2934be50c0" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """ + Apply schema changes to add or update unique constraints for prompts, resources and a2a agents. + This migration recreates tables with updated unique constraints and preserves data. + Compatible with SQLite, MySQL, and PostgreSQL. + """ + bind = op.get_bind() + inspector = sa.inspect(bind) + + # ### commands auto generated by Alembic - please adjust! ### + for tbl, constraints in { + "prompts": [("name", "uq_team_owner_name_prompts")], + "resources": [("uri", "uq_team_owner_uri_resources")], + "a2a_agents": [("slug", "uq_team_owner_slug_a2a_agents")], + }.items(): + try: + print(f"Processing {tbl} for unique constraint update...") + + # Get table metadata using SQLAlchemy + metadata = sa.MetaData() + table = sa.Table(tbl, metadata, autoload_with=bind) + + # Create temporary table name + tmp_table = f"{tbl}_tmp_nounique" + + # Drop temp table if it exists + if inspector.has_table(tmp_table): + op.drop_table(tmp_table) + + # Create new table structure with same columns but no old unique constraints + new_table = sa.Table(tmp_table, metadata) + + for column in table.columns: + # Copy column with same properties + new_column = column.copy() + new_table.append_column(new_column) + + # Copy foreign key constraints + for fk in table.foreign_keys: + new_table.append_constraint(fk.constraint.copy()) + uqs_to_copy = [] + # # # Copy unique constraints that we're not replacing, and skip any unique constraint only on 'name' + if tbl == "prompts": + uqs_to_copy = [] + for uq in table.constraints: + if isinstance(uq, sa.UniqueConstraint) and set([col.name for col in uq.columns]) != {"name"} and not any(uq.name == c[1] if uq.name else False for c in constraints): + uqs_to_copy.append(uq) + # Copy unique constraints that we're not replacing, and skip any unique constraint only on 'name' + if tbl == "resources": + uqs_to_copy = [ + uq + for uq in table.constraints + if isinstance(uq, sa.UniqueConstraint) and set([col.name for col in uq.columns]) != {"uri"} and not any(uq.name == c[1] if uq.name else False for c in constraints) + ] + + # For a2a_agents, also drop any unique constraint on just 'name' + if tbl == "a2a_agents": + uqs_to_copy = [ + uq + for uq in table.constraints + if isinstance(uq, sa.UniqueConstraint) + and set([col.name for col in uq.columns]) != {"name"} + and set([col.name for col in uq.columns]) != {"slug"} + and not any(uq.name == c[1] if uq.name else False for c in constraints) + ] + for uq in uqs_to_copy: + if uq is not None: + new_table.append_constraint(uq.copy()) + + # Create the temporary table + new_table.create(bind) + + # Copy data + column_names = [c.name for c in table.columns] + insert_stmt = new_table.insert().from_select(column_names, sa.select(*[table.c[name] for name in column_names])) + bind.execute(insert_stmt) + + # Add new unique constraints using batch operations for SQLite compatibility + with op.batch_alter_table(tmp_table, schema=None) as batch_op: + for col, constraint_name in constraints: + cols = ["team_id", "owner_email", col] + batch_op.create_unique_constraint(constraint_name, cols) + + # Drop original table and rename temp table + op.drop_table(tbl) + op.rename_table(tmp_table, tbl) + + except Exception as e: + print(f"Warning: Could not update unique constraint on {tbl} table: {e}") + # ### end Alembic commands ### + + +def downgrade() -> None: + """ + Revert schema changes, restoring previous unique constraints for prompts, resources and a2a_agents. + This migration recreates tables with the original unique constraints and preserves data. + Compatible with SQLite, MySQL, and PostgreSQL. + """ + bind = op.get_bind() + inspector = sa.inspect(bind) + + for tbl, constraints in { + "prompts": [("name", "uq_team_owner_name_prompts")], + "resources": [("uri", "uq_team_owner_uri_resources")], + "a2a_agents": [("slug", "uq_team_owner_slug_a2a_agents")], + }.items(): + try: + print(f"Processing {tbl} for unique constraint revert...") + + # Get table metadata using SQLAlchemy + metadata = sa.MetaData() + table = sa.Table(tbl, metadata, autoload_with=bind) + + # Create temporary table name + tmp_table = f"{tbl}_tmp_revert" + + # Drop temp table if it exists + if inspector.has_table(tmp_table): + op.drop_table(tmp_table) + + # Create new table structure with same columns but original unique constraints + new_table = sa.Table(tmp_table, metadata) + + for column in table.columns: + # Copy column with same properties + new_column = column.copy() + new_table.append_column(new_column) + + # Copy foreign key constraints + for fk in table.foreign_keys: + new_table.append_constraint(fk.constraint.copy()) + + # Copy unique constraints that we're not reverting + uqs_to_copy = [uq for uq in table.constraints if isinstance(uq, sa.UniqueConstraint) and not any(uq.name == c[1] if uq.name else False for c in constraints)] + for uq in uqs_to_copy: + new_table.append_constraint(uq.copy()) + + # Add back the original single-column unique constraints + + for col, _ in constraints: + if col in [c.name for c in table.columns]: + new_table.append_constraint(sa.UniqueConstraint(col)) + if tbl == "a2a_agents": + # Also re-add unique constraint on 'name' for a2a_agents + new_table.append_constraint(sa.UniqueConstraint("name")) + # Create the temporary table + new_table.create(bind) + + # Copy data + column_names = [c.name for c in table.columns] + insert_stmt = new_table.insert().from_select(column_names, sa.select(*[table.c[name] for name in column_names])) + bind.execute(insert_stmt) + + # Drop original table and rename temp table + op.drop_table(tbl) + op.rename_table(tmp_table, tbl) + + except Exception as e: + print(f"Warning: Could not revert unique constraint on {tbl} table: {e}") + # ### end Alembic commands ### diff --git a/mcpgateway/db.py b/mcpgateway/db.py index 91853c292..d2b4f4542 100644 --- a/mcpgateway/db.py +++ b/mcpgateway/db.py @@ -1840,7 +1840,7 @@ class Resource(Base): __tablename__ = "resources" id: Mapped[int] = mapped_column(primary_key=True) - uri: Mapped[str] = mapped_column(String(767), unique=True) + uri: Mapped[str] = mapped_column(String(767), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) description: Mapped[Optional[str]] = mapped_column(Text, nullable=True) mime_type: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) @@ -1881,6 +1881,7 @@ class Resource(Base): # Many-to-many relationship with Servers servers: Mapped[List["Server"]] = relationship("Server", secondary=server_resource_association, back_populates="resources") + __table_args__ = (UniqueConstraint("team_id", "owner_email", "uri", name="uq_team_owner_uri_resource"),) @property def content(self) -> "ResourceContent": @@ -1927,6 +1928,7 @@ def content(self) -> "ResourceContent": if self.text_content is not None: return ResourceContent( type="resource", + id=str(self.id), uri=self.uri, mime_type=self.mime_type, text=self.text_content, @@ -1934,6 +1936,7 @@ def content(self) -> "ResourceContent": if self.binary_content is not None: return ResourceContent( type="resource", + id=str(self.id), uri=self.uri, mime_type=self.mime_type or "application/octet-stream", blob=self.binary_content, @@ -2078,7 +2081,7 @@ class Prompt(Base): __tablename__ = "prompts" id: Mapped[int] = mapped_column(primary_key=True) - name: Mapped[str] = mapped_column(String(255), unique=True) + name: Mapped[str] = mapped_column(String(255), nullable=False) description: Mapped[Optional[str]] = mapped_column(Text, nullable=True) template: Mapped[str] = mapped_column(Text) argument_schema: Mapped[Dict[str, Any]] = mapped_column(JSON) @@ -2116,6 +2119,8 @@ class Prompt(Base): owner_email: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) visibility: Mapped[str] = mapped_column(String(20), nullable=False, default="public") + __table_args__ = (UniqueConstraint("team_id", "owner_email", "name", name="uq_team_owner_name_prompt"),) + def validate_arguments(self, args: Dict[str, str]) -> None: """ Validate prompt arguments against the argument schema. @@ -2548,8 +2553,8 @@ class A2AAgent(Base): __tablename__ = "a2a_agents" id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: uuid.uuid4().hex) - name: Mapped[str] = mapped_column(String(255), nullable=False, unique=True) - slug: Mapped[str] = mapped_column(String(255), nullable=False, unique=True) + name: Mapped[str] = mapped_column(String(255), nullable=False) + slug: Mapped[str] = mapped_column(String(255), nullable=False) description: Mapped[Optional[str]] = mapped_column(Text) endpoint_url: Mapped[str] = mapped_column(String(767), nullable=False) agent_type: Mapped[str] = mapped_column(String(50), nullable=False, default="generic") # e.g., "openai", "anthropic", "custom" @@ -2594,6 +2599,7 @@ class A2AAgent(Base): # Relationships servers: Mapped[List["Server"]] = relationship("Server", secondary=server_a2a_association, back_populates="a2a_agents") metrics: Mapped[List["A2AAgentMetric"]] = relationship("A2AAgentMetric", back_populates="a2a_agent", cascade="all, delete-orphan") + __table_args__ = (UniqueConstraint("team_id", "owner_email", "slug", name="uq_team_owner_slug_a2a_agent"),) @property def execution_count(self) -> int: diff --git a/mcpgateway/main.py b/mcpgateway/main.py index 726190c43..12c849054 100644 --- a/mcpgateway/main.py +++ b/mcpgateway/main.py @@ -2560,14 +2560,14 @@ async def create_resource( raise HTTPException(status_code=409, detail=ErrorFormatter.format_database_error(e)) -@resource_router.get("/{uri:path}") +@resource_router.get("/{resource_id}") @require_permission("resources.read") -async def read_resource(uri: str, request: Request, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> Any: +async def read_resource(resource_id: str, request: Request, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> Any: """ - Read a resource by its URI with plugin support. + Read a resource by its ID with plugin support. Args: - uri (str): URI of the resource. + resource_id (str): ID of the resource. request (Request): FastAPI request object for context. db (Session): Database session. user (str): Authenticated user. @@ -2582,20 +2582,20 @@ async def read_resource(uri: str, request: Request, db: Session = Depends(get_db request_id = request.headers.get("X-Request-ID", str(uuid.uuid4())) server_id = request.headers.get("X-Server-ID") - logger.debug(f"User {user} requested resource with URI {uri} (request_id: {request_id})") + logger.debug(f"User {user} requested resource with ID {resource_id} (request_id: {request_id})") # Check cache - if cached := resource_cache.get(uri): + if cached := resource_cache.get(resource_id): return cached try: # Call service with context for plugin support - content = await resource_service.read_resource(db, uri, request_id=request_id, user=user, server_id=server_id) + content = await resource_service.read_resource(db, resource_id, request_id=request_id, user=user, server_id=server_id) except (ResourceNotFoundError, ResourceError) as exc: # Translate to FastAPI HTTP error raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc - resource_cache.set(uri, content) + resource_cache.set(resource_id, content) # Ensure a plain JSON-serializable structure try: # First-Party @@ -2608,36 +2608,36 @@ async def read_resource(uri: str, request: Request, db: Session = Depends(get_db # If TextContent, wrap into resource envelope with text if isinstance(content, TextContent): - return {"type": "resource", "uri": uri, "text": content.text} + return {"type": "resource", "id": resource_id, "uri": content.uri, "text": content.text} except Exception: pass # nosec B110 - Intentionally continue with fallback resource content handling if isinstance(content, bytes): - return {"type": "resource", "uri": uri, "blob": content.decode("utf-8", errors="ignore")} + return {"type": "resource", "id": resource_id, "uri": content.uri, "blob": content.decode("utf-8", errors="ignore")} if isinstance(content, str): - return {"type": "resource", "uri": uri, "text": content} + return {"type": "resource", "id": resource_id, "uri": content.uri, "text": content} # Objects with a 'text' attribute (e.g., mocks) – best-effort mapping if hasattr(content, "text"): - return {"type": "resource", "uri": uri, "text": getattr(content, "text")} + return {"type": "resource", "id": resource_id, "uri": content.uri, "text": getattr(content, "text")} - return {"type": "resource", "uri": uri, "text": str(content)} + return {"type": "resource", "id": resource_id, "uri": content.uri, "text": str(content)} -@resource_router.put("/{uri:path}", response_model=ResourceRead) +@resource_router.put("/{resource_id}", response_model=ResourceRead) @require_permission("resources.update") async def update_resource( - uri: str, + resource_id: str, resource: ResourceUpdate, request: Request, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions), ) -> ResourceRead: """ - Update a resource identified by its URI. + Update a resource identified by its ID. Args: - uri (str): URI of the resource. + resource_id (str): ID of the resource. resource (ResourceUpdate): New resource data. request (Request): The FastAPI request object for metadata extraction. db (Session): Database session. @@ -2650,14 +2650,14 @@ async def update_resource( HTTPException: If the resource is not found or update fails. """ try: - logger.debug(f"User {user} is updating resource with URI {uri}") + logger.debug(f"User {user} is updating resource with ID {resource_id}") # Extract modification metadata mod_metadata = MetadataCapture.extract_modification_metadata(request, user, 0) # Version will be incremented in service user_email = user.get("email") if isinstance(user, dict) else str(user) result = await resource_service.update_resource( db, - uri, + resource_id, resource, modified_by=mod_metadata["modified_by"], modified_from_ip=mod_metadata["modified_from_ip"], @@ -2670,23 +2670,25 @@ async def update_resource( except ResourceNotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) except ValidationError as e: - logger.error(f"Validation error while updating resource {uri}: {e}") + logger.error(f"Validation error while updating resource {resource_id}: {e}") raise HTTPException(status_code=422, detail=ErrorFormatter.format_validation_error(e)) except IntegrityError as e: - logger.error(f"Integrity error while updating resource {uri}: {e}") + logger.error(f"Integrity error while updating resource {resource_id}: {e}") raise HTTPException(status_code=409, detail=ErrorFormatter.format_database_error(e)) - await invalidate_resource_cache(uri) + except ResourceURIConflictError as e: + raise HTTPException(status_code=409, detail=str(e)) + await invalidate_resource_cache(resource_id) return result -@resource_router.delete("/{uri:path}") +@resource_router.delete("/{resource_id}") @require_permission("resources.delete") -async def delete_resource(uri: str, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> Dict[str, str]: +async def delete_resource(resource_id: str, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> Dict[str, str]: """ - Delete a resource by its URI. + Delete a resource by its ID. Args: - uri (str): URI of the resource to delete. + resource_id (str): ID of the resource to delete. db (Session): Database session. user (str): Authenticated user. @@ -2697,11 +2699,11 @@ async def delete_resource(uri: str, db: Session = Depends(get_db), user=Depends( HTTPException: If the resource is not found or deletion fails. """ try: - logger.debug(f"User {user} is deleting resource with URI {uri}") + logger.debug(f"User {user} is deleting resource with id {resource_id}") user_email = user.get("email") if isinstance(user, dict) else str(user) - await resource_service.delete_resource(db, uri, user_email=user_email) - await invalidate_resource_cache(uri) - return {"status": "success", "message": f"Resource {uri} deleted"} + await resource_service.delete_resource(db, resource_id, user_email=user_email) + await invalidate_resource_cache(resource_id) + return {"status": "success", "message": f"Resource {resource_id} deleted"} except PermissionError as e: raise HTTPException(status_code=403, detail=str(e)) except ResourceNotFoundError as e: @@ -2710,21 +2712,21 @@ async def delete_resource(uri: str, db: Session = Depends(get_db), user=Depends( raise HTTPException(status_code=400, detail=str(e)) -@resource_router.post("/subscribe/{uri:path}") +@resource_router.post("/subscribe/{resource_id}") @require_permission("resources.read") -async def subscribe_resource(uri: str, user=Depends(get_current_user_with_permissions)) -> StreamingResponse: +async def subscribe_resource(resource_id: str, user=Depends(get_current_user_with_permissions)) -> StreamingResponse: """ Subscribe to server-sent events (SSE) for a specific resource. Args: - uri (str): URI of the resource to subscribe to. + resource_id (str): ID of the resource to subscribe to. user (str): Authenticated user. Returns: StreamingResponse: A streaming response with event updates. """ - logger.debug(f"User {user} is subscribing to resource with URI {uri}") - return StreamingResponse(resource_service.subscribe_events(uri), media_type="text/event-stream") + logger.debug(f"User {user} is subscribing to resource with resource_id {resource_id}") + return StreamingResponse(resource_service.subscribe_events(resource_id), media_type="text/event-stream") ############### @@ -2896,22 +2898,22 @@ async def create_prompt( raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while creating the prompt") -@prompt_router.post("/{name}") +@prompt_router.post("/{prompt_id}") @require_permission("prompts.read") async def get_prompt( - name: str, + prompt_id: str, args: Dict[str, str] = Body({}), db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions), ) -> Any: - """Get a prompt by name with arguments. + """Get a prompt by prompt_id with arguments. This implements the prompts/get functionality from the MCP spec, which requires a POST request with arguments in the body. Args: - name: Name of the prompt. + prompt_id: ID of the prompt. args: Template arguments. db: Database session. user: Authenticated user. @@ -2922,14 +2924,14 @@ async def get_prompt( Raises: Exception: Re-raised if not a handled exception type. """ - logger.debug(f"User: {user} requested prompt: {name} with args={args}") + logger.debug(f"User: {user} requested prompt: {prompt_id} with args={args}") try: PromptExecuteArgs(args=args) - result = await prompt_service.get_prompt(db, name, args) - logger.debug(f"Prompt execution successful for '{name}'") + result = await prompt_service.get_prompt(db, prompt_id, args) + logger.debug(f"Prompt execution successful for '{prompt_id}'") except Exception as ex: - logger.error(f"Could not retrieve prompt {name}: {ex}") + logger.error(f"Could not retrieve prompt {prompt_id}: {ex}") if isinstance(ex, PluginViolationError): # Return the actual plugin violation message return JSONResponse(content={"message": ex.message, "details": str(ex.violation) if hasattr(ex, "violation") else None}, status_code=422) @@ -2941,19 +2943,19 @@ async def get_prompt( return result -@prompt_router.get("/{name}") +@prompt_router.get("/{prompt_id}") @require_permission("prompts.read") async def get_prompt_no_args( - name: str, + prompt_id: str, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions), ) -> Any: - """Get a prompt by name without arguments. + """Get a prompt by ID without arguments. This endpoint is for convenience when no arguments are needed. Args: - name: The name of the prompt to retrieve + prompt_id: The ID of the prompt to retrieve db: Database session user: Authenticated user @@ -2963,14 +2965,14 @@ async def get_prompt_no_args( Raises: Exception: Re-raised from prompt service. """ - logger.debug(f"User: {user} requested prompt: {name} with no arguments") - return await prompt_service.get_prompt(db, name, {}) + logger.debug(f"User: {user} requested prompt: {prompt_id} with no arguments") + return await prompt_service.get_prompt(db, prompt_id, {}) -@prompt_router.put("/{name}", response_model=PromptRead) +@prompt_router.put("/{prompt_id}", response_model=PromptRead) @require_permission("prompts.update") async def update_prompt( - name: str, + prompt_id: str, prompt: PromptUpdate, request: Request, db: Session = Depends(get_db), @@ -2980,7 +2982,7 @@ async def update_prompt( Update (overwrite) an existing prompt definition. Args: - name (str): Identifier of the prompt to update. + prompt_id (str): Identifier of the prompt to update. prompt (PromptUpdate): New prompt content and metadata. request (Request): The FastAPI request object for metadata extraction. db (Session): Active SQLAlchemy session. @@ -2993,8 +2995,7 @@ async def update_prompt( HTTPException: * **409 Conflict** - a different prompt with the same *name* already exists and is still active. * **400 Bad Request** - validation or persistence error raised by :pyclass:`~mcpgateway.services.prompt_service.PromptService`. """ - logger.info(f"User: {user} requested to update prompt: {name} with data={prompt}") - logger.debug(f"User: {user} requested to update prompt: {name} with data={prompt}") + logger.debug(f"User: {user} requested to update prompt: {prompt_id} with data={prompt}") try: # Extract modification metadata mod_metadata = MetadataCapture.extract_modification_metadata(request, user, 0) # Version will be incremented in service @@ -3002,7 +3003,7 @@ async def update_prompt( user_email = user.get("email") if isinstance(user, dict) else str(user) return await prompt_service.update_prompt( db, - name, + prompt_id, prompt, modified_by=mod_metadata["modified_by"], modified_from_ip=mod_metadata["modified_from_ip"], @@ -3032,14 +3033,14 @@ async def update_prompt( raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while updating the prompt") -@prompt_router.delete("/{name}") +@prompt_router.delete("/{prompt_id}") @require_permission("prompts.delete") -async def delete_prompt(name: str, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> Dict[str, str]: +async def delete_prompt(prompt_id: str, db: Session = Depends(get_db), user=Depends(get_current_user_with_permissions)) -> Dict[str, str]: """ - Delete a prompt by name. + Delete a prompt by ID. Args: - name: Name of the prompt. + prompt_id: ID of the prompt. db: Database session. user: Authenticated user. @@ -3049,11 +3050,11 @@ async def delete_prompt(name: str, db: Session = Depends(get_db), user=Depends(g Raises: HTTPException: If the prompt is not found, a prompt error occurs, or an unexpected error occurs during deletion. """ - logger.debug(f"User: {user} requested deletion of prompt {name}") + logger.debug(f"User: {user} requested deletion of prompt {prompt_id}") try: user_email = user.get("email") if isinstance(user, dict) else str(user) - await prompt_service.delete_prompt(db, name, user_email=user_email) - return {"status": "success", "message": f"Prompt {name} deleted"} + await prompt_service.delete_prompt(db, prompt_id, user_email=user_email) + return {"status": "success", "message": f"Prompt {prompt_id} deleted"} except Exception as e: if isinstance(e, PermissionError): raise HTTPException(status_code=403, detail=str(e)) @@ -3061,7 +3062,7 @@ async def delete_prompt(name: str, db: Session = Depends(get_db), user=Depends(g raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) if isinstance(e, PromptError): raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) - logger.error(f"Unexpected error while deleting prompt {name}: {e}") + logger.error(f"Unexpected error while deleting prompt {prompt_id}: {e}") raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while deleting the prompt") # except PromptNotFoundError as e: diff --git a/mcpgateway/models.py b/mcpgateway/models.py index 51ffe1406..20faca993 100644 --- a/mcpgateway/models.py +++ b/mcpgateway/models.py @@ -139,13 +139,15 @@ class ResourceContent(BaseModel): Attributes: type (Literal["resource"]): The fixed content type identifier for resources. - uri (str): The URI identifying the resource. + id (str): The ID identifying the resource. + uri (str): The URI of the resource. mime_type (Optional[str]): The MIME type of the resource, if known. text (Optional[str]): A textual representation of the resource, if applicable. blob (Optional[bytes]): Binary data of the resource, if applicable. """ type: Literal["resource"] + id: str uri: str mime_type: Optional[str] = None text: Optional[str] = None diff --git a/mcpgateway/plugins/framework/external/mcp/server/server.py b/mcpgateway/plugins/framework/external/mcp/server/server.py index 3772d03c1..78dba8ce9 100644 --- a/mcpgateway/plugins/framework/external/mcp/server/server.py +++ b/mcpgateway/plugins/framework/external/mcp/server/server.py @@ -123,7 +123,7 @@ async def invoke_hook( >>> server = ExternalPluginServer(config_path="./tests/unit/mcpgateway/plugins/fixtures/configs/valid_multiple_plugins_filter.yaml") >>> def prompt_pre_fetch_func(plugin: Plugin, payload: PromptPrehookPayload, context: PluginContext) -> PromptPrehookResult: ... return plugin.prompt_pre_fetch(payload, context) - >>> payload = PromptPrehookPayload(name="test_prompt", args={"user": "This is so innovative"}) + >>> payload = PromptPrehookPayload(prompt_id="test_id", args={"user": "This is so innovative"}) >>> context = PluginContext(global_context=GlobalContext(request_id="1", server_id="2")) >>> initialized = asyncio.run(server.initialize()) >>> initialized diff --git a/mcpgateway/plugins/framework/external/mcp/tls_utils.py b/mcpgateway/plugins/framework/external/mcp/tls_utils.py index 370cbb7df..91b04cfb0 100644 --- a/mcpgateway/plugins/framework/external/mcp/tls_utils.py +++ b/mcpgateway/plugins/framework/external/mcp/tls_utils.py @@ -79,6 +79,9 @@ def create_ssl_context(tls_config: MCPClientTLSConfig, plugin_name: str) -> ssl. # - Automatic expiration checking (notBefore/notAfter per RFC 5280) ssl_context = ssl.create_default_context() + # Enforce TLS 1.2 or higher for security + ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 + if not tls_config.verify: # Disable certificate verification (not recommended for production) logger.warning(f"Certificate verification disabled for plugin '{plugin_name}'. This is not recommended for production use.") diff --git a/mcpgateway/plugins/framework/manager.py b/mcpgateway/plugins/framework/manager.py index 20005ab50..9287effee 100644 --- a/mcpgateway/plugins/framework/manager.py +++ b/mcpgateway/plugins/framework/manager.py @@ -21,7 +21,7 @@ >>> # Create test payload and context >>> from mcpgateway.plugins.framework.models import PromptPrehookPayload, GlobalContext - >>> payload = PromptPrehookPayload(name="test", args={"user": "input"}) + >>> payload = PromptPrehookPayload(prompt_id="test", name="test", args={"user": "input"}) >>> context = GlobalContext(request_id="123") >>> # result, contexts = await manager.prompt_pre_fetch(payload, context) # Called in async context """ @@ -172,7 +172,7 @@ async def execute( >>> # In async context: >>> # result, contexts = await executor.execute( >>> # plugins=plugins, - >>> # payload=PromptPrehookPayload(name="test", args={}), + >>> # payload=PromptPrehookPayload(prompt_id="123", args={}), >>> # global_context=GlobalContext(request_id="123"), >>> # plugin_run=pre_prompt_fetch, >>> # compare=pre_prompt_matches @@ -328,7 +328,7 @@ async def pre_prompt_fetch(plugin: PluginRef, payload: PromptPrehookPayload, con >>> from mcpgateway.plugins.framework import GlobalContext, Plugin, PromptPrehookPayload, PluginContext, GlobalContext >>> # Assuming you have a plugin instance: >>> # plugin_ref = PluginRef(my_plugin) - >>> payload = PromptPrehookPayload(name="test", args={"key": "value"}) + >>> payload = PromptPrehookPayload(prompt_id="123", args={"key": "value"}) >>> context = PluginContext(global_context=GlobalContext(request_id="123")) >>> # In async context: >>> # result = await pre_prompt_fetch(plugin_ref, payload, context) @@ -354,7 +354,7 @@ async def post_prompt_fetch(plugin: PluginRef, payload: PromptPosthookPayload, c >>> # Assuming you have a plugin instance: >>> # plugin_ref = PluginRef(my_plugin) >>> result = PromptResult(messages=[]) - >>> payload = PromptPosthookPayload(name="test", result=result) + >>> payload = PromptPosthookPayload(prompt_id="123", result=result) >>> context = PluginContext(global_context=GlobalContext(request_id="123")) >>> # In async context: >>> # result = await post_prompt_fetch(plugin_ref, payload, context) @@ -451,7 +451,7 @@ async def post_resource_fetch(plugin: PluginRef, payload: ResourcePostFetchPaylo >>> from mcpgateway.models import ResourceContent >>> # Assuming you have a plugin instance: >>> # plugin_ref = PluginRef(my_plugin) - >>> content = ResourceContent(type="resource", uri="file:///data.txt", text="Data") + >>> content = ResourceContent(type="resource", id="res-1", uri="file:///data.txt", text="Data") >>> payload = ResourcePostFetchPayload(uri="file:///data.txt", content=content) >>> context = PluginContext(global_context=GlobalContext(request_id="123")) >>> # In async context: @@ -484,7 +484,7 @@ class PluginManager: >>> >>> # Execute prompt hooks >>> from mcpgateway.plugins.framework import PromptPrehookPayload, GlobalContext - >>> payload = PromptPrehookPayload(name="test", args={}) + >>> payload = PromptPrehookPayload(prompt_id="123", args={}) >>> context = GlobalContext(request_id="req-123") >>> # In async context: >>> # result, contexts = await manager.prompt_pre_fetch(payload, context) @@ -713,6 +713,7 @@ async def prompt_pre_fetch( >>> >>> from mcpgateway.plugins.framework import PromptPrehookPayload, GlobalContext >>> payload = PromptPrehookPayload( + ... prompt_id="123", ... name="greeting", ... args={"user": "Alice"} ... ) @@ -774,7 +775,7 @@ async def prompt_post_fetch( >>> prompt_result = PromptResult(messages=[message]) >>> >>> post_payload = PromptPosthookPayload( - ... name="greeting", + ... prompt_id="123", ... result=prompt_result ... ) >>> @@ -974,7 +975,7 @@ async def resource_post_fetch( >>> # In async context: >>> # await manager.initialize() >>> # from mcpgateway.models import ResourceContent - >>> # content = ResourceContent(type="resource", uri="file:///data.txt", text="Data") + >>> # content = ResourceContent(type="resource",id="res-1", uri="file:///data.txt", text="Data") >>> # payload = ResourcePostFetchPayload("file:///data.txt", content) >>> # context = GlobalContext(request_id="123", server_id="srv1") >>> # contexts = self._context_store.get("123") # From pre-fetch diff --git a/mcpgateway/plugins/framework/models.py b/mcpgateway/plugins/framework/models.py index 7febad705..1d02eb3c9 100644 --- a/mcpgateway/plugins/framework/models.py +++ b/mcpgateway/plugins/framework/models.py @@ -769,26 +769,26 @@ class PromptPrehookPayload(BaseModel): """A prompt payload for a prompt prehook. Attributes: - name (str): The name of the prompt template. + prompt_id (str): The ID of the prompt template. args (dic[str,str]): The prompt template arguments. Examples: - >>> payload = PromptPrehookPayload(name="test_prompt", args={"user": "alice"}) - >>> payload.name - 'test_prompt' + >>> payload = PromptPrehookPayload(prompt_id="123", args={"user": "alice"}) + >>> payload.prompt_id + '123' >>> payload.args {'user': 'alice'} - >>> payload2 = PromptPrehookPayload(name="empty") + >>> payload2 = PromptPrehookPayload(prompt_id="empty") >>> payload2.args {} - >>> p = PromptPrehookPayload(name="greeting", args={"name": "Bob", "time": "morning"}) - >>> p.name - 'greeting' + >>> p = PromptPrehookPayload(prompt_id="123", args={"name": "Bob", "time": "morning"}) + >>> p.prompt_id + '123' >>> p.args["name"] 'Bob' """ - name: str + prompt_id: str args: Optional[dict[str, str]] = Field(default_factory=dict) @@ -796,27 +796,27 @@ class PromptPosthookPayload(BaseModel): """A prompt payload for a prompt posthook. Attributes: - name (str): The prompt name. + prompt_id (str): The prompt ID. result (PromptResult): The prompt after its template is rendered. Examples: >>> from mcpgateway.models import PromptResult, Message, TextContent >>> msg = Message(role="user", content=TextContent(type="text", text="Hello World")) >>> result = PromptResult(messages=[msg]) - >>> payload = PromptPosthookPayload(name="greeting", result=result) - >>> payload.name - 'greeting' + >>> payload = PromptPosthookPayload(prompt_id="123", result=result) + >>> payload.prompt_id + '123' >>> payload.result.messages[0].content.text 'Hello World' >>> from mcpgateway.models import PromptResult, Message, TextContent >>> msg = Message(role="assistant", content=TextContent(type="text", text="Test output")) >>> r = PromptResult(messages=[msg]) - >>> p = PromptPosthookPayload(name="test", result=r) - >>> p.name - 'test' + >>> p = PromptPosthookPayload(prompt_id="123", result=r) + >>> p.prompt_id + '123' """ - name: str + prompt_id: str result: PromptResult @@ -1096,7 +1096,7 @@ class ResourcePostFetchPayload(BaseModel): Examples: >>> from mcpgateway.models import ResourceContent - >>> content = ResourceContent(type="resource", uri="file:///data.txt", + >>> content = ResourceContent(type="resource", id="res-1", uri="file:///data.txt", ... text="Hello World") >>> payload = ResourcePostFetchPayload(uri="file:///data.txt", content=content) >>> payload.uri @@ -1104,7 +1104,7 @@ class ResourcePostFetchPayload(BaseModel): >>> payload.content.text 'Hello World' >>> from mcpgateway.models import ResourceContent - >>> resource_content = ResourceContent(type="resource", uri="test://resource", text="Test data") + >>> resource_content = ResourceContent(type="resource", id="res-2", uri="test://resource", text="Test data") >>> p = ResourcePostFetchPayload(uri="test://resource", content=resource_content) >>> p.uri 'test://resource' diff --git a/mcpgateway/plugins/framework/utils.py b/mcpgateway/plugins/framework/utils.py index 325cbbd29..17f561fb1 100644 --- a/mcpgateway/plugins/framework/utils.py +++ b/mcpgateway/plugins/framework/utils.py @@ -124,12 +124,12 @@ def pre_prompt_matches(payload: PromptPrehookPayload, conditions: list[PluginCon Examples: >>> from mcpgateway.plugins.framework import PluginCondition, PromptPrehookPayload, GlobalContext - >>> payload = PromptPrehookPayload(name="greeting", args={}) - >>> cond = PluginCondition(prompts={"greeting"}) + >>> payload = PromptPrehookPayload(prompt_id="id1", args={}) + >>> cond = PluginCondition(prompts={"id1"}) >>> ctx = GlobalContext(request_id="req1") >>> pre_prompt_matches(payload, [cond], ctx) True - >>> payload2 = PromptPrehookPayload(name="other", args={}) + >>> payload2 = PromptPrehookPayload(prompt_id="id2", args={}) >>> pre_prompt_matches(payload2, [cond], ctx) False """ @@ -138,7 +138,7 @@ def pre_prompt_matches(payload: PromptPrehookPayload, conditions: list[PluginCon if not matches(condition, context): current_result = False - if condition.prompts and payload.name not in condition.prompts: + if condition.prompts and payload.prompt_id not in condition.prompts: current_result = False if current_result: return True @@ -163,7 +163,7 @@ def post_prompt_matches(payload: PromptPosthookPayload, conditions: list[PluginC if not matches(condition, context): current_result = False - if condition.prompts and payload.name not in condition.prompts: + if condition.prompts and payload.prompt_id not in condition.prompts: current_result = False if current_result: return True @@ -294,8 +294,8 @@ def post_resource_matches(payload: ResourcePostFetchPayload, conditions: list[Pl Examples: >>> from mcpgateway.plugins.framework import PluginCondition, ResourcePostFetchPayload, GlobalContext >>> from mcpgateway.models import ResourceContent - >>> content = ResourceContent(type="resource", uri="file:///data.txt", text="Test") - >>> payload = ResourcePostFetchPayload(uri="file:///data.txt", content=content) + >>> content = ResourceContent(type="resource", id="123", uri="file:///data.txt", text="Test") + >>> payload = ResourcePostFetchPayload(id="123",uri="file:///data.txt", content=content) >>> cond = PluginCondition(resources={"file:///data.txt"}) >>> ctx = GlobalContext(request_id="req1") >>> post_resource_matches(payload, [cond], ctx) diff --git a/mcpgateway/schemas.py b/mcpgateway/schemas.py index 368e8508a..2b5a66d6b 100644 --- a/mcpgateway/schemas.py +++ b/mcpgateway/schemas.py @@ -1742,6 +1742,7 @@ class ResourceUpdate(BaseModelWithConfigDict): Similar to ResourceCreate but URI is not required and all fields are optional. """ + uri: Optional[str] = Field(None, description="Unique URI for the resource") name: Optional[str] = Field(None, description="Human-readable resource name") description: Optional[str] = Field(None, description="Resource description") mime_type: Optional[str] = Field(None, description="Resource MIME type") @@ -3928,6 +3929,7 @@ class A2AAgentCreate(BaseModel): model_config = ConfigDict(str_strip_whitespace=True) name: str = Field(..., description="Unique name for the agent") + slug: Optional[str] = Field(None, description="Optional slug for the agent (auto-generated if not provided)") description: Optional[str] = Field(None, description="Agent description") endpoint_url: str = Field(..., description="URL endpoint for the agent") agent_type: str = Field(default="generic", description="Type of agent (e.g., 'openai', 'anthropic', 'custom')") diff --git a/mcpgateway/services/a2a_service.py b/mcpgateway/services/a2a_service.py index 26861f341..7d9dd4350 100644 --- a/mcpgateway/services/a2a_service.py +++ b/mcpgateway/services/a2a_service.py @@ -18,6 +18,7 @@ # Third-Party import httpx from sqlalchemy import and_, case, delete, desc, func, or_, select +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session # First-Party @@ -27,6 +28,7 @@ from mcpgateway.services.logging_service import LoggingService from mcpgateway.services.team_management_service import TeamManagementService from mcpgateway.services.tool_service import ToolService +from mcpgateway.utils.create_slug import slugify # Initialize logging service first logging_service = LoggingService() @@ -70,7 +72,7 @@ class A2AAgentNotFoundError(A2AAgentError): class A2AAgentNameConflictError(A2AAgentError): """Raised when an A2A agent name conflicts with an existing one.""" - def __init__(self, name: str, is_active: bool = True, agent_id: Optional[str] = None): + def __init__(self, name: str, is_active: bool = True, agent_id: Optional[str] = None, visibility: Optional[str] = "public"): """Initialize an A2AAgentNameConflictError exception. Creates an exception that indicates an agent name conflict, with additional @@ -80,6 +82,7 @@ def __init__(self, name: str, is_active: bool = True, agent_id: Optional[str] = name: The agent name that caused the conflict. is_active: Whether the conflicting agent is currently active. agent_id: The ID of the conflicting agent, if known. + visibility: The visibility level of the conflicting agent (private, team, public). Examples: >>> error = A2AAgentNameConflictError("test-agent") @@ -106,7 +109,7 @@ def __init__(self, name: str, is_active: bool = True, agent_id: Optional[str] = self.name = name self.is_active = is_active self.agent_id = agent_id - message = f"A2A Agent already exists with name: {name}" + message = f"{visibility.capitalize()} A2A Agent already exists with name: {name}" if not is_active: message += f" (currently inactive, ID: {agent_id})" super().__init__(message) @@ -170,55 +173,78 @@ async def register_agent( Raises: A2AAgentNameConflictError: If an agent with the same name already exists. + IntegrityError: If a database integrity error occurs. + A2AAgentError: For other errors during registration. """ - # Check for existing agent with same name - existing_query = select(DbA2AAgent).where(DbA2AAgent.name == agent_data.name) - existing_agent = db.execute(existing_query).scalar_one_or_none() - - if existing_agent: - raise A2AAgentNameConflictError(name=agent_data.name, is_active=existing_agent.enabled, agent_id=existing_agent.id) - - # Create new agent - new_agent = DbA2AAgent( - name=agent_data.name, - description=agent_data.description, - endpoint_url=agent_data.endpoint_url, - agent_type=agent_data.agent_type, - protocol_version=agent_data.protocol_version, - capabilities=agent_data.capabilities, - config=agent_data.config, - auth_type=agent_data.auth_type, - auth_value=agent_data.auth_value, # This should be encrypted in practice - tags=agent_data.tags, - # Team scoping fields - use schema values if provided, otherwise fallback to parameters - team_id=getattr(agent_data, "team_id", None) or team_id, - owner_email=getattr(agent_data, "owner_email", None) or owner_email or created_by, - visibility=getattr(agent_data, "visibility", None) or visibility, - created_by=created_by, - created_from_ip=created_from_ip, - created_via=created_via, - created_user_agent=created_user_agent, - import_batch_id=import_batch_id, - federation_source=federation_source, - ) - - db.add(new_agent) - db.commit() - db.refresh(new_agent) - - # Automatically create a tool for the A2A agent if not already present - tool_service = ToolService() - await tool_service.create_tool_from_a2a_agent( - db=db, - agent=new_agent, - created_by=created_by, - created_from_ip=created_from_ip, - created_via=created_via, - created_user_agent=created_user_agent, - ) - - logger.info(f"Registered new A2A agent: {new_agent.name} (ID: {new_agent.id})") - return self._db_to_schema(new_agent) + try: + agent_data.slug = slugify(agent_data.name) + # Check for existing server with the same slug within the same team or public scope + if visibility.lower() == "public": + logger.info(f"visibility.lower(): {visibility.lower()}") + logger.info(f"agent_data.name: {agent_data.name}") + logger.info(f"agent_data.slug: {agent_data.slug}") + # Check for existing public a2a agent with the same slug + existing_agent = db.execute(select(DbA2AAgent).where(DbA2AAgent.slug == agent_data.slug, DbA2AAgent.visibility == "public")).scalar_one_or_none() + if existing_agent: + raise A2AAgentNameConflictError(name=agent_data.slug, is_active=existing_agent.enabled, agent_id=existing_agent.id, visibility=existing_agent.visibility) + elif visibility.lower() == "team" and team_id: + # Check for existing team a2a agent with the same slug + existing_agent = db.execute(select(DbA2AAgent).where(DbA2AAgent.slug == agent_data.slug, DbA2AAgent.visibility == "team", DbA2AAgent.team_id == team_id)).scalar_one_or_none() + if existing_agent: + raise A2AAgentNameConflictError(name=agent_data.slug, is_active=existing_agent.enabled, agent_id=existing_agent.id, visibility=existing_agent.visibility) + + # Create new agent + new_agent = DbA2AAgent( + name=agent_data.name, + slug=agent_data.slug, + description=agent_data.description, + endpoint_url=agent_data.endpoint_url, + agent_type=agent_data.agent_type, + protocol_version=agent_data.protocol_version, + capabilities=agent_data.capabilities, + config=agent_data.config, + auth_type=agent_data.auth_type, + auth_value=agent_data.auth_value, # This should be encrypted in practice + tags=agent_data.tags, + # Team scoping fields - use schema values if provided, otherwise fallback to parameters + team_id=getattr(agent_data, "team_id", None) or team_id, + owner_email=getattr(agent_data, "owner_email", None) or owner_email or created_by, + visibility=getattr(agent_data, "visibility", None) or visibility, + created_by=created_by, + created_from_ip=created_from_ip, + created_via=created_via, + created_user_agent=created_user_agent, + import_batch_id=import_batch_id, + federation_source=federation_source, + ) + + db.add(new_agent) + db.commit() + db.refresh(new_agent) + + # Automatically create a tool for the A2A agent if not already present + tool_service = ToolService() + await tool_service.create_tool_from_a2a_agent( + db=db, + agent=new_agent, + created_by=created_by, + created_from_ip=created_from_ip, + created_via=created_via, + created_user_agent=created_user_agent, + ) + + logger.info(f"Registered new A2A agent: {new_agent.name} (ID: {new_agent.id})") + return self._db_to_schema(new_agent) + except A2AAgentNameConflictError as ie: + db.rollback() + raise ie + except IntegrityError as ie: + db.rollback() + logger.error(f"IntegrityErrors in group: {ie}") + raise ie + except Exception as e: + db.rollback() + raise A2AAgentError(f"Failed to register A2A agent: {str(e)}") async def list_agents(self, db: Session, cursor: Optional[str] = None, include_inactive: bool = False, tags: Optional[List[str]] = None) -> List[A2AAgentRead]: # pylint: disable=unused-argument """List A2A agents with optional filtering. @@ -396,6 +422,8 @@ async def update_agent( A2AAgentNotFoundError: If the agent is not found. PermissionError: If user doesn't own the agent. A2AAgentNameConflictError: If name conflicts with another agent. + A2AAgentError: For other errors during update. + IntegrityError: If a database integrity error occurs. """ try: query = select(DbA2AAgent).where(DbA2AAgent.id == agent_id) @@ -412,15 +440,21 @@ async def update_agent( permission_service = PermissionService(db) if not await permission_service.check_resource_ownership(user_email, agent): raise PermissionError("Only the owner can update this agent") - # Check for name conflict if name is being updated if agent_data.name and agent_data.name != agent.name: - existing_query = select(DbA2AAgent).where(DbA2AAgent.name == agent_data.name, DbA2AAgent.id != agent_id) - existing_agent = db.execute(existing_query).scalar_one_or_none() - - if existing_agent: - raise A2AAgentNameConflictError(name=agent_data.name, is_active=existing_agent.enabled, agent_id=existing_agent.id) - + visibility = agent_data.visibility or agent.visibility + team_id = agent_data.team_id or agent.team_id + # Check for existing server with the same slug within the same team or public scope + if visibility.lower() == "public": + # Check for existing public a2a agent with the same slug + existing_agent = db.execute(select(DbA2AAgent).where(DbA2AAgent.slug == agent_data.slug, DbA2AAgent.visibility == "public")).scalar_one_or_none() + if existing_agent: + raise A2AAgentNameConflictError(name=agent_data.slug, is_active=existing_agent.enabled, agent_id=existing_agent.id, visibility=existing_agent.visibility) + elif visibility.lower() == "team" and team_id: + # Check for existing team a2a agent with the same slug + existing_agent = db.execute(select(DbA2AAgent).where(DbA2AAgent.slug == agent_data.slug, DbA2AAgent.visibility == "team", DbA2AAgent.team_id == team_id)).scalar_one_or_none() + if existing_agent: + raise A2AAgentNameConflictError(name=agent_data.slug, is_active=existing_agent.enabled, agent_id=existing_agent.id, visibility=existing_agent.visibility) # Update fields update_data = agent_data.model_dump(exclude_unset=True) for field, value in update_data.items(): @@ -447,6 +481,19 @@ async def update_agent( except PermissionError: db.rollback() raise + except A2AAgentNameConflictError as ie: + db.rollback() + raise ie + except A2AAgentNotFoundError as nf: + db.rollback() + raise nf + except IntegrityError as ie: + db.rollback() + logger.error(f"IntegrityErrors in group: {ie}") + raise ie + except Exception as e: + db.rollback() + raise A2AAgentError(f"Failed to update A2A agent: {str(e)}") async def toggle_agent_status(self, db: Session, agent_id: str, activate: bool, reachable: Optional[bool] = None, user_email: Optional[str] = None) -> A2AAgentRead: """Toggle the activation status of an A2A agent. diff --git a/mcpgateway/services/prompt_service.py b/mcpgateway/services/prompt_service.py index 346a51b94..208b0ccc5 100644 --- a/mcpgateway/services/prompt_service.py +++ b/mcpgateway/services/prompt_service.py @@ -20,7 +20,7 @@ import os from string import Formatter import time -from typing import Any, AsyncGenerator, Dict, List, Optional, Set +from typing import Any, AsyncGenerator, Dict, List, Optional, Set, Union import uuid # Third-Party @@ -58,13 +58,14 @@ class PromptNotFoundError(PromptError): class PromptNameConflictError(PromptError): """Raised when a prompt name conflicts with existing (active or inactive) prompt.""" - def __init__(self, name: str, is_active: bool = True, prompt_id: Optional[int] = None): + def __init__(self, name: str, is_active: bool = True, prompt_id: Optional[int] = None, visibility: str = "public") -> None: """Initialize the error with prompt information. Args: name: The conflicting prompt name is_active: Whether the existing prompt is active prompt_id: ID of the existing prompt if available + visibility: Prompt visibility level (private, team, public). Examples: >>> from mcpgateway.services.prompt_service import PromptNameConflictError @@ -84,7 +85,7 @@ def __init__(self, name: str, is_active: bool = True, prompt_id: Optional[int] = self.name = name self.is_active = is_active self.prompt_id = prompt_id - message = f"Prompt already exists with name: {name}" + message = f"{visibility.capitalize()} Prompt already exists with name: {name}" if not is_active: message += f" (currently inactive, ID: {prompt_id})" super().__init__(message) @@ -317,6 +318,7 @@ async def register_prompt( Raises: IntegrityError: If a database integrity error occurs. + PromptNameConflictError: If a prompt with the same name already exists. PromptError: For other prompt registration errors Examples: @@ -376,6 +378,17 @@ async def register_prompt( owner_email=getattr(prompt, "owner_email", None) or owner_email or created_by, visibility=getattr(prompt, "visibility", None) or visibility, ) + # Check for existing server with the same name + if visibility.lower() == "public": + # Check for existing public prompt with the same name + existing_prompt = db.execute(select(DbPrompt).where(DbPrompt.name == prompt.name, DbPrompt.visibility == "public")).scalar_one_or_none() + if existing_prompt: + raise PromptNameConflictError(prompt.name, is_active=existing_prompt.is_active, prompt_id=existing_prompt.id, visibility=existing_prompt.visibility) + elif visibility.lower() == "team": + # Check for existing team prompt with the same name + existing_prompt = db.execute(select(DbPrompt).where(DbPrompt.name == prompt.name, DbPrompt.visibility == "team", DbPrompt.team_id == team_id)).scalar_one_or_none() + if existing_prompt: + raise PromptNameConflictError(prompt.name, is_active=existing_prompt.is_active, prompt_id=existing_prompt.id, visibility=existing_prompt.visibility) # Add to DB db.add(db_prompt) @@ -392,6 +405,9 @@ async def register_prompt( except IntegrityError as ie: logger.error(f"IntegrityErrors in group: {ie}") raise ie + except PromptNameConflictError as se: + db.rollback() + raise se except Exception as e: db.rollback() raise PromptError(f"Failed to register prompt: {str(e)}") @@ -596,7 +612,7 @@ async def _record_prompt_metric(self, db: Session, prompt: DbPrompt, start_time: async def get_prompt( self, db: Session, - name: str, + prompt_id: Union[int, str], arguments: Optional[Dict[str, str]] = None, user: Optional[str] = None, tenant_id: Optional[str] = None, @@ -607,7 +623,7 @@ async def get_prompt( Args: db: Database session - name: Name of prompt to get + prompt_id: ID of the prompt to retrieve arguments: Optional arguments for rendering user: Optional user identifier for plugin context tenant_id: Optional tenant identifier for plugin context @@ -631,7 +647,7 @@ async def get_prompt( >>> db.execute.return_value.scalar_one_or_none.return_value = MagicMock() >>> import asyncio >>> try: - ... asyncio.run(service.get_prompt(db, 'prompt_name')) + ... asyncio.run(service.get_prompt(db, 'prompt_id')) ... except Exception: ... pass """ @@ -645,7 +661,7 @@ async def get_prompt( with create_span( "prompt.render", { - "prompt.name": name, + "prompt.id": prompt_id, "arguments_count": len(arguments) if arguments else 0, "user": user or "anonymous", "server_id": server_id, @@ -654,29 +670,32 @@ async def get_prompt( }, ) as span: try: + # Ensure prompt_id is an int for database operations + prompt_id_int = int(prompt_id) if isinstance(prompt_id, str) else prompt_id + if self._plugin_manager: if not request_id: request_id = uuid.uuid4().hex global_context = GlobalContext(request_id=request_id, user=user, server_id=server_id, tenant_id=tenant_id) pre_result, context_table = await self._plugin_manager.prompt_pre_fetch( - payload=PromptPrehookPayload(name=name, args=arguments), global_context=global_context, local_contexts=None, violations_as_exceptions=True + payload=PromptPrehookPayload(prompt_id=str(prompt_id), args=arguments), global_context=global_context, local_contexts=None, violations_as_exceptions=True ) # Use modified payload if provided if pre_result.modified_payload: payload = pre_result.modified_payload - name = payload.name + prompt_id_int = int(payload.prompt_id) if isinstance(payload.prompt_id, str) else payload.prompt_id arguments = payload.args # Find prompt - prompt = db.execute(select(DbPrompt).where(DbPrompt.name == name).where(DbPrompt.is_active)).scalar_one_or_none() + prompt = db.execute(select(DbPrompt).where(DbPrompt.id == prompt_id_int).where(DbPrompt.is_active)).scalar_one_or_none() if not prompt: - inactive_prompt = db.execute(select(DbPrompt).where(DbPrompt.name == name).where(not_(DbPrompt.is_active))).scalar_one_or_none() + inactive_prompt = db.execute(select(DbPrompt).where(DbPrompt.id == prompt_id_int).where(not_(DbPrompt.is_active))).scalar_one_or_none() if inactive_prompt: - raise PromptNotFoundError(f"Prompt '{name}' exists but is inactive") + raise PromptNotFoundError(f"Prompt '{prompt_id_int}' exists but is inactive") - raise PromptNotFoundError(f"Prompt not found: {name}") + raise PromptNotFoundError(f"Prompt not found: {prompt_id_int}") if not arguments: result = PromptResult( @@ -702,7 +721,7 @@ async def get_prompt( if self._plugin_manager: post_result, _ = await self._plugin_manager.prompt_post_fetch( - payload=PromptPosthookPayload(name=name, result=result), global_context=global_context, local_contexts=context_table, violations_as_exceptions=True + payload=PromptPosthookPayload(prompt_id=str(prompt_id_int), result=result), global_context=global_context, local_contexts=context_table, violations_as_exceptions=True ) # Use modified payload if provided result = post_result.modified_payload.result if post_result.modified_payload else result @@ -732,7 +751,7 @@ async def get_prompt( async def update_prompt( self, db: Session, - name: str, + prompt_id: Union[int, str], prompt_update: PromptUpdate, modified_by: Optional[str] = None, modified_from_ip: Optional[str] = None, @@ -745,7 +764,7 @@ async def update_prompt( Args: db: Database session - name: Name of prompt to update + prompt_id: ID of prompt to update prompt_update: Prompt update object modified_by: Username of the person modifying the prompt modified_from_ip: IP address where the modification originated @@ -760,6 +779,7 @@ async def update_prompt( PromptNotFoundError: If the prompt is not found PermissionError: If user doesn't own the prompt IntegrityError: If a database integrity error occurs. + PromptNameConflictError: If a prompt with the same name already exists. PromptError: For other update errors Examples: @@ -779,14 +799,25 @@ async def update_prompt( ... pass """ try: - prompt = db.execute(select(DbPrompt).where(DbPrompt.name == name).where(DbPrompt.is_active)).scalar_one_or_none() + prompt = db.get(DbPrompt, prompt_id) if not prompt: - inactive_prompt = db.execute(select(DbPrompt).where(DbPrompt.name == name).where(not_(DbPrompt.is_active))).scalar_one_or_none() - - if inactive_prompt: - raise PromptNotFoundError(f"Prompt '{name}' exists but is inactive") + raise PromptNotFoundError(f"Prompt not found: {prompt_id}") - raise PromptNotFoundError(f"Prompt not found: {name}") + # # Check for name conflict if name is being changed and visibility is public + if prompt_update.name and prompt_update.name != prompt.name: + visibility = prompt_update.visibility or prompt.visibility + team_id = prompt_update.team_id or prompt.team_id + if visibility.lower() == "public": + # Check for existing public prompts with the same name + existing_prompt = db.execute(select(DbPrompt).where(DbPrompt.name == prompt_update.name, DbPrompt.visibility == "public")).scalar_one_or_none() + if existing_prompt: + raise PromptNameConflictError(prompt_update.name, is_active=existing_prompt.is_active, prompt_id=existing_prompt.id, visibility=existing_prompt.visibility) + elif visibility.lower() == "team" and team_id: + # Check for existing team prompt with the same name + existing_prompt = db.execute(select(DbPrompt).where(DbPrompt.name == prompt_update.name, DbPrompt.visibility == "team", DbPrompt.team_id == team_id)).scalar_one_or_none() + logger.info(f"Existing prompt check result: {existing_prompt}") + if existing_prompt: + raise PromptNameConflictError(prompt_update.name, is_active=existing_prompt.is_active, prompt_id=existing_prompt.id, visibility=existing_prompt.visibility) # Check ownership if user_email provided if user_email: @@ -858,6 +889,10 @@ async def update_prompt( db.rollback() logger.error(f"Prompt not found: {e}") raise e + except PromptNameConflictError as pnce: + db.rollback() + logger.error(f"Prompt name conflict: {pnce}") + raise pnce except Exception as e: db.rollback() raise PromptError(f"Failed to update prompt: {str(e)}") @@ -930,13 +965,13 @@ async def toggle_prompt_status(self, db: Session, prompt_id: int, activate: bool raise PromptError(f"Failed to toggle prompt status: {str(e)}") # Get prompt details for admin ui - async def get_prompt_details(self, db: Session, name: str, include_inactive: bool = False) -> Dict[str, Any]: + async def get_prompt_details(self, db: Session, prompt_id: Union[int, str], include_inactive: bool = False) -> Dict[str, Any]: # pylint: disable=unused-argument """ - Get prompt details by name. + Get prompt details by ID. Args: db: Database session - name: Name of prompt + prompt_id: ID of prompt include_inactive: Whether to include inactive prompts Returns: @@ -958,34 +993,28 @@ async def get_prompt_details(self, db: Session, name: str, include_inactive: boo >>> result == prompt_dict True """ - query = select(DbPrompt).where(DbPrompt.name == name) - if not include_inactive: - query = query.where(DbPrompt.is_active) - prompt = db.execute(query).scalar_one_or_none() + logger.info(f"prompt_id:::{prompt_id}") + prompt = db.get(DbPrompt, prompt_id) if not prompt: - if not include_inactive: - inactive_prompt = db.execute(select(DbPrompt).where(DbPrompt.name == name).where(not_(DbPrompt.is_active))).scalar_one_or_none() - if inactive_prompt: - raise PromptNotFoundError(f"Prompt '{name}' exists but is inactive") - raise PromptNotFoundError(f"Prompt not found: {name}") + raise PromptNotFoundError(f"Prompt not found: {prompt_id}") # Return the fully converted prompt including metrics prompt.team = self._get_team_name(db, prompt.team_id) return self._convert_db_prompt(prompt) - async def delete_prompt(self, db: Session, name: str, user_email: Optional[str] = None) -> None: + async def delete_prompt(self, db: Session, prompt_id: Union[int, str], user_email: Optional[str] = None) -> None: """ - Delete a prompt template. + Delete a prompt template by its ID. Args: - db: Database session - name: Name of prompt to delete - user_email: Email of user performing delete (for ownership check) + db (Session): Database session. + prompt_id (str): ID of the prompt to delete. + user_email (Optional[str]): Email of user performing delete (for ownership check). Raises: - PromptNotFoundError: If the prompt is not found - PermissionError: If user doesn't own the prompt - PromptError: For other deletion errors - Exception: For unexpected errors + PromptNotFoundError: If the prompt is not found. + PermissionError: If user doesn't own the prompt. + PromptError: For other deletion errors. + Exception: For unexpected errors. Examples: >>> from mcpgateway.services.prompt_service import PromptService @@ -999,14 +1028,14 @@ async def delete_prompt(self, db: Session, name: str, user_email: Optional[str] >>> service._notify_prompt_deleted = MagicMock() >>> import asyncio >>> try: - ... asyncio.run(service.delete_prompt(db, 'prompt_name')) + ... asyncio.run(service.delete_prompt(db, '123')) ... except Exception: ... pass """ try: - prompt = db.execute(select(DbPrompt).where(DbPrompt.name == name)).scalar_one_or_none() + prompt = db.get(DbPrompt, prompt_id) if not prompt: - raise PromptNotFoundError(f"Prompt not found: {name}") + raise PromptNotFoundError(f"Prompt not found: {prompt_id}") # Check ownership if user_email provided if user_email: @@ -1021,7 +1050,7 @@ async def delete_prompt(self, db: Session, name: str, user_email: Optional[str] db.delete(prompt) db.commit() await self._notify_prompt_deleted(prompt_info) - logger.info(f"Permanently deleted prompt: {name}") + logger.info(f"Deleted prompt: {prompt_info['name']}") except PermissionError: db.rollback() raise diff --git a/mcpgateway/services/resource_service.py b/mcpgateway/services/resource_service.py index 7c28847df..9a31e5237 100644 --- a/mcpgateway/services/resource_service.py +++ b/mcpgateway/services/resource_service.py @@ -78,18 +78,20 @@ class ResourceNotFoundError(ResourceError): class ResourceURIConflictError(ResourceError): """Raised when a resource URI conflicts with existing (active or inactive) resource.""" - def __init__(self, uri: str, is_active: bool = True, resource_id: Optional[int] = None): + def __init__(self, uri: str, is_active: bool = True, resource_id: Optional[int] = None, visibility: str = "public") -> None: """Initialize the error with resource information. Args: uri: The conflicting resource URI is_active: Whether the existing resource is active resource_id: ID of the existing resource if available + visibility: Visibility status of the resource """ self.uri = uri self.is_active = is_active self.resource_id = resource_id - message = f"Resource already exists with URI: {uri}" + message = f"{visibility.capitalize()} Resource already exists with URI: {uri}" + logger.info(f"ResourceURIConflictError: {message}") if not is_active: message += f" (currently inactive, ID: {resource_id})" super().__init__(message) @@ -312,6 +314,7 @@ async def register_resource( Raises: IntegrityError: If a database integrity error occurs. + ResourceURIConflictError: If a resource with the same URI already exists. ResourceError: For other resource registration errors Examples: @@ -333,6 +336,20 @@ async def register_resource( 'resource_read' """ try: + logger.info(f"Registering resource: {resource.uri}") + # Check for existing server with the same uri + if visibility.lower() == "public": + logger.info(f"visibility:: {visibility}") + # Check for existing public resource with the same uri + existing_resource = db.execute(select(DbResource).where(DbResource.uri == resource.uri, DbResource.visibility == "public")).scalar_one_or_none() + if existing_resource: + raise ResourceURIConflictError(resource.uri, is_active=existing_resource.is_active, resource_id=existing_resource.id, visibility=existing_resource.visibility) + elif visibility.lower() == "team" and team_id: + # Check for existing team resource with the same uri + existing_resource = db.execute(select(DbResource).where(DbResource.uri == resource.uri, DbResource.visibility == "team", DbResource.team_id == team_id)).scalar_one_or_none() + if existing_resource: + raise ResourceURIConflictError(resource.uri, is_active=existing_resource.is_active, resource_id=existing_resource.id, visibility=existing_resource.visibility) + # Detect mime type if not provided mime_type = resource.mime_type if not mime_type: @@ -379,6 +396,9 @@ async def register_resource( except IntegrityError as ie: logger.error(f"IntegrityErrors in group: {ie}") raise ie + except ResourceURIConflictError as rce: + logger.error(f"ResourceURIConflictError in group: {resource.uri}") + raise rce except Exception as e: db.rollback() raise ResourceError(f"Failed to register resource: {str(e)}") @@ -616,12 +636,12 @@ async def _record_resource_metric(self, db: Session, resource: DbResource, start db.add(metric) db.commit() - async def read_resource(self, db: Session, uri: str, request_id: Optional[str] = None, user: Optional[str] = None, server_id: Optional[str] = None) -> ResourceContent: + async def read_resource(self, db: Session, resource_id: Union[int, str], request_id: Optional[str] = None, user: Optional[str] = None, server_id: Optional[str] = None) -> ResourceContent: """Read a resource's content with plugin hook support. Args: db: Database session - uri: Resource URI to read + resource_id: ID of the resource to read request_id: Optional request ID for tracing user: Optional user making the request server_id: Optional server ID for context @@ -642,7 +662,10 @@ async def read_resource(self, db: Session, uri: str, request_id: Optional[str] = >>> service = ResourceService() >>> db = MagicMock() >>> uri = 'http://example.com/resource.txt' - >>> db.execute.return_value.scalar_one_or_none.return_value = MagicMock(content='test') + >>> import types + >>> mock_resource = types.SimpleNamespace(content='test', uri=uri) + >>> db.execute.return_value.scalar_one_or_none.return_value = mock_resource + >>> db.get.return_value = mock_resource # Ensure uri is a string, not None >>> import asyncio >>> result = asyncio.run(service.read_resource(db, uri)) >>> isinstance(result, ResourceContent) @@ -663,7 +686,8 @@ async def read_resource(self, db: Session, uri: str, request_id: Optional[str] = success = False error_message = None resource = None - + resource_db = db.get(DbResource, resource_id) + uri = resource_db.uri if resource_db else None # Create trace span for resource reading with create_span( "resource.read", @@ -672,8 +696,8 @@ async def read_resource(self, db: Session, uri: str, request_id: Optional[str] = "user": user or "anonymous", "server_id": server_id, "request_id": request_id, - "http.url": uri if uri.startswith("http") else None, - "resource.type": "template" if ("{" in uri and "}" in uri) else "static", + "http.url": uri if uri is not None and uri.startswith("http") else None, + "resource.type": "template" if (uri is not None and "{" in uri and "}" in uri) else "static", }, ) as span: try: @@ -685,7 +709,7 @@ async def read_resource(self, db: Session, uri: str, request_id: Optional[str] = contexts = None # Call pre-fetch hooks if plugin manager is available - plugin_eligible = bool(self._plugin_manager and PLUGINS_AVAILABLE and ("://" in uri)) + plugin_eligible = bool(self._plugin_manager and PLUGINS_AVAILABLE and uri and ("://" in uri)) if plugin_eligible: # Initialize plugin manager if needed # pylint: disable=protected-access @@ -718,21 +742,20 @@ async def read_resource(self, db: Session, uri: str, request_id: Optional[str] = logger.debug(f"Resource URI modified by plugin: {original_uri} -> {uri}") # Original resource fetching logic + logger.info(f"Fetching resource: {resource_id} (URI: {uri})") # Check for template - if "{" in uri and "}" in uri: + if uri is not None and "{" in uri and "}" in uri: content = await self._read_template_resource(uri) else: # Find resource - resource = db.execute(select(DbResource).where(DbResource.uri == uri).where(DbResource.is_active)).scalar_one_or_none() - + resource = db.execute(select(DbResource).where(DbResource.id == resource_id).where(DbResource.is_active)).scalar_one_or_none() if not resource: # Check if inactive resource exists - inactive_resource = db.execute(select(DbResource).where(DbResource.uri == uri).where(not_(DbResource.is_active))).scalar_one_or_none() - + inactive_resource = db.execute(select(DbResource).where(DbResource.id == resource_id).where(not_(DbResource.is_active))).scalar_one_or_none() if inactive_resource: - raise ResourceNotFoundError(f"Resource '{uri}' exists but is inactive") + raise ResourceNotFoundError(f"Resource '{resource_id}' exists but is inactive") - raise ResourceNotFoundError(f"Resource not found: {uri}") + raise ResourceNotFoundError(f"Resource not found: {resource_id}") content = resource.content @@ -747,8 +770,6 @@ async def read_resource(self, db: Session, uri: str, request_id: Optional[str] = # Use modified content if plugin changed it if post_result.modified_payload: content = post_result.modified_payload.content - logger.debug(f"Resource content modified by plugin for URI: {original_uri}") - # Set success attributes on span if span: span.set_attribute("success", True) @@ -765,19 +786,18 @@ async def read_resource(self, db: Session, uri: str, request_id: Optional[str] = # If content is already a Pydantic content model, return as-is if isinstance(content, (ResourceContent, TextContent)): return content - # If content is any object that quacks like content (e.g., MagicMock with .text/.blob), return as-is if hasattr(content, "text") or hasattr(content, "blob"): return content # Normalize primitive types to ResourceContent if isinstance(content, bytes): - return ResourceContent(type="resource", uri=original_uri, blob=content) + return ResourceContent(type="resource", id=resource_id, uri=original_uri, blob=content) if isinstance(content, str): - return ResourceContent(type="resource", uri=original_uri, text=content) + return ResourceContent(type="resource", id=resource_id, uri=original_uri, text=content) # Fallback to stringified content - return ResourceContent(type="resource", uri=original_uri, text=str(content)) + return ResourceContent(type="resource", id=resource_id, uri=original_uri, text=str(content)) except Exception as e: success = False @@ -947,7 +967,7 @@ async def unsubscribe_resource(self, db: Session, subscription: ResourceSubscrip async def update_resource( self, db: Session, - uri: str, + resource_id: Union[int, str], resource_update: ResourceUpdate, modified_by: Optional[str] = None, modified_from_ip: Optional[str] = None, @@ -960,7 +980,7 @@ async def update_resource( Args: db: Database session - uri: Resource URI + resource_id: Resource ID resource_update: Resource update object modified_by: Username of the person modifying the resource modified_from_ip: IP address where the modification request originated @@ -973,12 +993,13 @@ async def update_resource( Raises: ResourceNotFoundError: If the resource is not found + ResourceURIConflictError: If a resource with the same URI already exists. PermissionError: If user doesn't own the resource ResourceError: For other update errors IntegrityError: If a database integrity error occurs. Exception: For unexpected errors - Examples: + Example: >>> from mcpgateway.services.resource_service import ResourceService >>> from unittest.mock import MagicMock, AsyncMock >>> from mcpgateway.schemas import ResourceRead @@ -992,21 +1013,29 @@ async def update_resource( >>> service._convert_resource_to_read = MagicMock(return_value='resource_read') >>> ResourceRead.model_validate = MagicMock(return_value='resource_read') >>> import asyncio - >>> asyncio.run(service.update_resource(db, 'uri', MagicMock())) + >>> asyncio.run(service.update_resource(db, 'resource_id', MagicMock())) 'resource_read' """ try: - # Find resource - resource = db.execute(select(DbResource).where(DbResource.uri == uri).where(DbResource.is_active)).scalar_one_or_none() - + logger.info(f"Updating resource: {resource_id}") + resource = db.get(DbResource, resource_id) if not resource: - # Check if inactive resource exists - inactive_resource = db.execute(select(DbResource).where(DbResource.uri == uri).where(not_(DbResource.is_active))).scalar_one_or_none() - - if inactive_resource: - raise ResourceNotFoundError(f"Resource '{uri}' exists but is inactive") + raise ResourceNotFoundError(f"Resource not found: {resource_id}") - raise ResourceNotFoundError(f"Resource not found: {uri}") + # # Check for uri conflict if uri is being changed and visibility is public + if resource_update.uri and resource_update.uri != resource.uri: + visibility = resource_update.visibility or resource.visibility + team_id = resource_update.team_id or resource.team_id + if visibility.lower() == "public": + # Check for existing public resources with the same uri + existing_resource = db.execute(select(DbResource).where(DbResource.uri == resource_update.uri, DbResource.visibility == "public")).scalar_one_or_none() + if existing_resource: + raise ResourceURIConflictError(resource_update.uri, is_active=existing_resource.is_active, resource_id=existing_resource.id, visibility=existing_resource.visibility) + elif visibility.lower() == "team" and team_id: + # Check for existing team resource with the same uri + existing_resource = db.execute(select(DbResource).where(DbResource.uri == resource_update.uri, DbResource.visibility == "team", DbResource.team_id == team_id)).scalar_one_or_none() + if existing_resource: + raise ResourceURIConflictError(resource_update.uri, is_active=existing_resource.is_active, resource_id=existing_resource.id, visibility=existing_resource.visibility) # Check ownership if user_email provided if user_email: @@ -1018,6 +1047,8 @@ async def update_resource( raise PermissionError("Only the owner can update this resource") # Update fields if provided + if resource_update.uri is not None: + resource.uri = resource_update.uri if resource_update.name is not None: resource.name = resource_update.name if resource_update.description is not None: @@ -1064,7 +1095,7 @@ async def update_resource( # Notify subscribers await self._notify_resource_updated(resource) - logger.info(f"Updated resource: {uri}") + logger.info(f"Updated resource: {resource.uri}") return self._convert_resource_to_read(resource) except PermissionError: db.rollback() @@ -1073,19 +1104,22 @@ async def update_resource( db.rollback() logger.error(f"IntegrityErrors in group: {ie}") raise ie + except ResourceURIConflictError as pe: + logger.error(f"Resource URI conflict: {pe}") + raise pe except Exception as e: db.rollback() if isinstance(e, ResourceNotFoundError): raise e raise ResourceError(f"Failed to update resource: {str(e)}") - async def delete_resource(self, db: Session, uri: str, user_email: Optional[str] = None) -> None: + async def delete_resource(self, db: Session, resource_id: Union[int, str], user_email: Optional[str] = None) -> None: """ Delete a resource. Args: db: Database session - uri: Resource URI + resource_id: Resource ID user_email: Email of user performing delete (for ownership check) Raises: @@ -1093,7 +1127,7 @@ async def delete_resource(self, db: Session, uri: str, user_email: Optional[str] PermissionError: If user doesn't own the resource ResourceError: For other deletion errors - Examples: + Example: >>> from mcpgateway.services.resource_service import ResourceService >>> from unittest.mock import MagicMock, AsyncMock >>> service = ResourceService() @@ -1104,16 +1138,16 @@ async def delete_resource(self, db: Session, uri: str, user_email: Optional[str] >>> db.commit = MagicMock() >>> service._notify_resource_deleted = AsyncMock() >>> import asyncio - >>> asyncio.run(service.delete_resource(db, 'uri')) + >>> asyncio.run(service.delete_resource(db, 'resource_id')) """ try: # Find resource by its URI. - resource = db.execute(select(DbResource).where(DbResource.uri == uri)).scalar_one_or_none() + resource = db.execute(select(DbResource).where(DbResource.id == resource_id)).scalar_one_or_none() if not resource: # If resource doesn't exist, rollback and re-raise a ResourceNotFoundError. db.rollback() - raise ResourceNotFoundError(f"Resource not found: {uri}") + raise ResourceNotFoundError(f"Resource not found: {resource_id}") # Check ownership if user_email provided if user_email: @@ -1141,7 +1175,7 @@ async def delete_resource(self, db: Session, uri: str, user_email: Optional[str] # Notify subscribers. await self._notify_resource_deleted(resource_info) - logger.info(f"Permanently deleted resource: {uri}") + logger.info(f"Permanently deleted resource: {resource.uri}") except PermissionError: db.rollback() @@ -1153,22 +1187,22 @@ async def delete_resource(self, db: Session, uri: str, user_email: Optional[str] db.rollback() raise ResourceError(f"Failed to delete resource: {str(e)}") - async def get_resource_by_uri(self, db: Session, uri: str, include_inactive: bool = False) -> ResourceRead: + async def get_resource_by_id(self, db: Session, resource_id: int, include_inactive: bool = False) -> ResourceRead: """ - Get a resource by URI. + Get a resource by ID. Args: db: Database session - uri: Resource URI + resource_id: Resource ID include_inactive: Whether to include inactive resources Returns: - ResourceRead object + ResourceRead: The resource object Raises: ResourceNotFoundError: If the resource is not found - Examples: + Example: >>> from mcpgateway.services.resource_service import ResourceService >>> from unittest.mock import MagicMock >>> service = ResourceService() @@ -1177,10 +1211,10 @@ async def get_resource_by_uri(self, db: Session, uri: str, include_inactive: boo >>> db.execute.return_value.scalar_one_or_none.return_value = resource >>> service._convert_resource_to_read = MagicMock(return_value='resource_read') >>> import asyncio - >>> asyncio.run(service.get_resource_by_uri(db, 'uri')) + >>> asyncio.run(service.get_resource_by_id(db, 999)) 'resource_read' """ - query = select(DbResource).where(DbResource.uri == uri) + query = select(DbResource).where(DbResource.id == resource_id) if not include_inactive: query = query.where(DbResource.is_active) @@ -1190,12 +1224,12 @@ async def get_resource_by_uri(self, db: Session, uri: str, include_inactive: boo if not resource: if not include_inactive: # Check if inactive resource exists - inactive_resource = db.execute(select(DbResource).where(DbResource.uri == uri).where(not_(DbResource.is_active))).scalar_one_or_none() + inactive_resource = db.execute(select(DbResource).where(DbResource.id == resource_id).where(not_(DbResource.is_active))).scalar_one_or_none() if inactive_resource: - raise ResourceNotFoundError(f"Resource '{uri}' exists but is inactive") + raise ResourceNotFoundError(f"Resource '{resource_id}' exists but is inactive") - raise ResourceNotFoundError(f"Resource not found: {uri}") + raise ResourceNotFoundError(f"Resource not found: {resource_id}") return self._convert_resource_to_read(resource) diff --git a/mcpgateway/static/admin.js b/mcpgateway/static/admin.js index 8a5fc75ee..75c65872d 100644 --- a/mcpgateway/static/admin.js +++ b/mcpgateway/static/admin.js @@ -3452,12 +3452,12 @@ async function viewPrompt(promptName) { /** * SECURE: Edit Prompt function with validation */ -async function editPrompt(promptName) { +async function editPrompt(promptId) { try { - console.log(`Editing prompt: ${promptName}`); + console.log(`Editing prompt: ${promptId}`); const response = await fetchWithTimeout( - `${window.ROOT_PATH}/admin/prompts/${encodeURIComponent(promptName)}`, + `${window.ROOT_PATH}/admin/prompts/${encodeURIComponent(promptId)}`, ); if (!response.ok) { @@ -3513,7 +3513,22 @@ async function editPrompt(promptName) { // Set form action and populate fields with validation const editForm = safeGetElement("edit-prompt-form"); if (editForm) { - editForm.action = `${window.ROOT_PATH}/admin/prompts/${encodeURIComponent(promptName)}/edit`; + editForm.action = `${window.ROOT_PATH}/admin/prompts/${encodeURIComponent(promptId)}/edit`; + // Add or update hidden team_id input if present in URL + const teamId = new URL(window.location.href).searchParams.get( + "team_id", + ); + if (teamId) { + let teamInput = safeGetElement("edit-prompt-team-id"); + if (!teamInput) { + teamInput = document.createElement("input"); + teamInput.type = "hidden"; + teamInput.name = "team_id"; + teamInput.id = "edit-prompt-team-id"; + editForm.appendChild(teamInput); + } + teamInput.value = teamId; + } } // Validate prompt name @@ -8009,7 +8024,13 @@ async function handlePromptFormSubmit(e) { async function handleEditPromptFormSubmit(e) { e.preventDefault(); const form = e.target; + const formData = new FormData(form); + // Add team_id from URL if present (like handleEditToolFormSubmit) + const teamId = new URL(window.location.href).searchParams.get("team_id"); + if (teamId) { + formData.set("team_id", teamId); + } try { // Validate inputs diff --git a/mcpgateway/templates/admin.html b/mcpgateway/templates/admin.html index 44e52a5fe..517cd8f16 100644 --- a/mcpgateway/templates/admin.html +++ b/mcpgateway/templates/admin.html @@ -3618,14 +3618,14 @@